Adding test case.

Fixing a couple of bugs.
This commit is contained in:
Knut Forkalsrud 2013-11-07 23:01:46 -08:00
parent 0ea01af309
commit 969add5846
2 changed files with 64 additions and 21 deletions

View file

@ -29,7 +29,7 @@ public class VpTree<T> {
} }
private class WithDistance implements Comparable<WithDistance> { public class WithDistance implements Comparable<WithDistance> {
private T value; private T value;
private double distance; private double distance;
@ -52,13 +52,18 @@ public class VpTree<T> {
double delta = other.distance - this.distance; double delta = other.distance - this.distance;
return (int) Math.signum(delta); return (int) Math.signum(delta);
} }
@Override
public String toString() {
return String.valueOf(value) + " (" + distance + ")";
}
} }
public List<WithDistance> search(final T target, int numResults) { public List<WithDistance> search(final T target, int numResults) {
PriorityQueue<WithDistance> heap = new PriorityQueue<WithDistance>(numResults); Nearest acc = new Nearest(target, numResults);
root.search(new Nearest(target, numResults)); root.search(acc);
PriorityQueue<WithDistance> heap = acc.matches;
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size()); ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
while (!heap.isEmpty()) { while (!heap.isEmpty()) {
results.add(heap.remove()); results.add(heap.remove());
@ -76,23 +81,27 @@ public class VpTree<T> {
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) { Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
if (right <= left) { WithDistance vp = points.get(left);
return null; Node leftChild = null;
Node rightChild = null;
double threshold = 0d;
if (right - left > 1) {
// choose an arbitrary point and move it to the start
int len = right - left;
int pos = left + random.nextInt(len);
Collections.swap(points, left, pos);
vp = points.get(left);
for (int i = left + 1; i < right; i++) {
WithDistance wd = points.get(i);
wd.distance = metric.distance(vp.value, wd.value);
}
Collections.sort(points.subList(left + 1, right));
int midpoint = (left + 1 + right) / 2;
threshold = points.get(midpoint).distance;
leftChild = buildFromPoints(points, left + 1, midpoint);
rightChild = buildFromPoints(points, midpoint, right);
} }
// choose an arbitrary point and move it to the start
int len = right - left;
int pos = left + random.nextInt(len);
Collections.swap(points, left, pos);
final WithDistance vp = points.get(left);
for (int i = left + 1; i < right; i++) {
WithDistance wd = points.get(i);
wd.distance = metric.distance(vp.value, wd.value);
}
Collections.sort(points.subList(left + 1, right));
int midpoint = (left + 1 + right) / 2;
double threshold = points.get(midpoint).distance;
Node leftChild = buildFromPoints(points, left + 1, midpoint);
Node rightChild = buildFromPoints(points, midpoint, right);
return new Node(vp.value, threshold, leftChild, rightChild); return new Node(vp.value, threshold, leftChild, rightChild);
} }
@ -105,7 +114,7 @@ public class VpTree<T> {
public Nearest(T query, int numMatches) { public Nearest(T query, int numMatches) {
this.query = query; this.query = query;
this.numMatches = numMatches; this.numMatches = numMatches;
this.matches = new PriorityQueue<WithDistance>(numMatches, Collections.reverseOrder()); this.matches = new PriorityQueue<WithDistance>(numMatches);
} }
public double distanceFrom(T value) { public double distanceFrom(T value) {

View file

@ -0,0 +1,34 @@
package org.forkalsrud.util;
import java.util.Arrays;
import java.util.List;
import org.forkalsrud.util.VpTree.Metric;
import static org.junit.Assert.*;
import org.junit.Test;
public class VpTreeTest {
@Test
public void testSearch() {
double epsilon = 0.0003d;
VpTree<Integer> vp = new VpTree<Integer>(Arrays.asList(4, 1, 7, 2, 9, 3, 5), new Metric<Integer>() {
@Override
public double distance(Integer a, Integer b) {
return Math.abs(b - a);
}
});
List<VpTree<Integer>.WithDistance> result = vp.search(6, 3);
assertEquals(3, result.size());
assertEquals(1d, result.get(0).getDistance(), epsilon);
assertEquals(1d, result.get(1).getDistance(), epsilon);
assertEquals(2d, result.get(2).getDistance(), epsilon);
}
}