Adding test case.

Fixing a couple of bugs.
This commit is contained in:
Knut Forkalsrud 2013-11-07 23:01:46 -08:00
parent 0ea01af309
commit 969add5846
2 changed files with 64 additions and 21 deletions

View file

@ -29,7 +29,7 @@ public class VpTree<T> {
}
private class WithDistance implements Comparable<WithDistance> {
public class WithDistance implements Comparable<WithDistance> {
private T value;
private double distance;
@ -52,13 +52,18 @@ public class VpTree<T> {
double delta = other.distance - this.distance;
return (int) Math.signum(delta);
}
@Override
public String toString() {
return String.valueOf(value) + " (" + distance + ")";
}
}
public List<WithDistance> search(final T target, int numResults) {
PriorityQueue<WithDistance> heap = new PriorityQueue<WithDistance>(numResults);
root.search(new Nearest(target, numResults));
Nearest acc = new Nearest(target, numResults);
root.search(acc);
PriorityQueue<WithDistance> heap = acc.matches;
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
while (!heap.isEmpty()) {
results.add(heap.remove());
@ -76,23 +81,27 @@ public class VpTree<T> {
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
if (right <= left) {
return null;
}
WithDistance vp = points.get(left);
Node leftChild = null;
Node rightChild = null;
double threshold = 0d;
if (right - left > 1) {
// choose an arbitrary point and move it to the start
int len = right - left;
int pos = left + random.nextInt(len);
Collections.swap(points, left, pos);
final WithDistance vp = points.get(left);
vp = points.get(left);
for (int i = left + 1; i < right; i++) {
WithDistance wd = points.get(i);
wd.distance = metric.distance(vp.value, wd.value);
}
Collections.sort(points.subList(left + 1, right));
int midpoint = (left + 1 + right) / 2;
double threshold = points.get(midpoint).distance;
Node leftChild = buildFromPoints(points, left + 1, midpoint);
Node rightChild = buildFromPoints(points, midpoint, right);
threshold = points.get(midpoint).distance;
leftChild = buildFromPoints(points, left + 1, midpoint);
rightChild = buildFromPoints(points, midpoint, right);
}
return new Node(vp.value, threshold, leftChild, rightChild);
}
@ -105,7 +114,7 @@ public class VpTree<T> {
public Nearest(T query, int numMatches) {
this.query = query;
this.numMatches = numMatches;
this.matches = new PriorityQueue<WithDistance>(numMatches, Collections.reverseOrder());
this.matches = new PriorityQueue<WithDistance>(numMatches);
}
public double distanceFrom(T value) {

View file

@ -0,0 +1,34 @@
package org.forkalsrud.util;
import java.util.Arrays;
import java.util.List;
import org.forkalsrud.util.VpTree.Metric;
import static org.junit.Assert.*;
import org.junit.Test;
public class VpTreeTest {
@Test
public void testSearch() {
double epsilon = 0.0003d;
VpTree<Integer> vp = new VpTree<Integer>(Arrays.asList(4, 1, 7, 2, 9, 3, 5), new Metric<Integer>() {
@Override
public double distance(Integer a, Integer b) {
return Math.abs(b - a);
}
});
List<VpTree<Integer>.WithDistance> result = vp.search(6, 3);
assertEquals(3, result.size());
assertEquals(1d, result.get(0).getDistance(), epsilon);
assertEquals(1d, result.get(1).getDistance(), epsilon);
assertEquals(2d, result.get(2).getDistance(), epsilon);
}
}