Adding test case.
Fixing a couple of bugs.
This commit is contained in:
parent
0ea01af309
commit
969add5846
2 changed files with 64 additions and 21 deletions
|
|
@ -29,7 +29,7 @@ public class VpTree<T> {
|
|||
}
|
||||
|
||||
|
||||
private class WithDistance implements Comparable<WithDistance> {
|
||||
public class WithDistance implements Comparable<WithDistance> {
|
||||
|
||||
private T value;
|
||||
private double distance;
|
||||
|
|
@ -52,13 +52,18 @@ public class VpTree<T> {
|
|||
double delta = other.distance - this.distance;
|
||||
return (int) Math.signum(delta);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.valueOf(value) + " (" + distance + ")";
|
||||
}
|
||||
}
|
||||
|
||||
public List<WithDistance> search(final T target, int numResults) {
|
||||
|
||||
PriorityQueue<WithDistance> heap = new PriorityQueue<WithDistance>(numResults);
|
||||
root.search(new Nearest(target, numResults));
|
||||
|
||||
Nearest acc = new Nearest(target, numResults);
|
||||
root.search(acc);
|
||||
PriorityQueue<WithDistance> heap = acc.matches;
|
||||
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
|
||||
while (!heap.isEmpty()) {
|
||||
results.add(heap.remove());
|
||||
|
|
@ -76,23 +81,27 @@ public class VpTree<T> {
|
|||
|
||||
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
|
||||
|
||||
if (right <= left) {
|
||||
return null;
|
||||
WithDistance vp = points.get(left);
|
||||
Node leftChild = null;
|
||||
Node rightChild = null;
|
||||
double threshold = 0d;
|
||||
|
||||
if (right - left > 1) {
|
||||
// choose an arbitrary point and move it to the start
|
||||
int len = right - left;
|
||||
int pos = left + random.nextInt(len);
|
||||
Collections.swap(points, left, pos);
|
||||
vp = points.get(left);
|
||||
for (int i = left + 1; i < right; i++) {
|
||||
WithDistance wd = points.get(i);
|
||||
wd.distance = metric.distance(vp.value, wd.value);
|
||||
}
|
||||
Collections.sort(points.subList(left + 1, right));
|
||||
int midpoint = (left + 1 + right) / 2;
|
||||
threshold = points.get(midpoint).distance;
|
||||
leftChild = buildFromPoints(points, left + 1, midpoint);
|
||||
rightChild = buildFromPoints(points, midpoint, right);
|
||||
}
|
||||
// choose an arbitrary point and move it to the start
|
||||
int len = right - left;
|
||||
int pos = left + random.nextInt(len);
|
||||
Collections.swap(points, left, pos);
|
||||
final WithDistance vp = points.get(left);
|
||||
for (int i = left + 1; i < right; i++) {
|
||||
WithDistance wd = points.get(i);
|
||||
wd.distance = metric.distance(vp.value, wd.value);
|
||||
}
|
||||
Collections.sort(points.subList(left + 1, right));
|
||||
int midpoint = (left + 1 + right) / 2;
|
||||
double threshold = points.get(midpoint).distance;
|
||||
Node leftChild = buildFromPoints(points, left + 1, midpoint);
|
||||
Node rightChild = buildFromPoints(points, midpoint, right);
|
||||
return new Node(vp.value, threshold, leftChild, rightChild);
|
||||
}
|
||||
|
||||
|
|
@ -105,7 +114,7 @@ public class VpTree<T> {
|
|||
public Nearest(T query, int numMatches) {
|
||||
this.query = query;
|
||||
this.numMatches = numMatches;
|
||||
this.matches = new PriorityQueue<WithDistance>(numMatches, Collections.reverseOrder());
|
||||
this.matches = new PriorityQueue<WithDistance>(numMatches);
|
||||
}
|
||||
|
||||
public double distanceFrom(T value) {
|
||||
|
|
|
|||
34
src/main/java/org/forkalsrud/util/VpTreeTest.java
Normal file
34
src/main/java/org/forkalsrud/util/VpTreeTest.java
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package org.forkalsrud.util;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.forkalsrud.util.VpTree.Metric;
|
||||
import static org.junit.Assert.*;
|
||||
import org.junit.Test;
|
||||
|
||||
public class VpTreeTest {
|
||||
|
||||
@Test
|
||||
public void testSearch() {
|
||||
|
||||
double epsilon = 0.0003d;
|
||||
|
||||
VpTree<Integer> vp = new VpTree<Integer>(Arrays.asList(4, 1, 7, 2, 9, 3, 5), new Metric<Integer>() {
|
||||
|
||||
@Override
|
||||
public double distance(Integer a, Integer b) {
|
||||
return Math.abs(b - a);
|
||||
}
|
||||
|
||||
});
|
||||
|
||||
|
||||
List<VpTree<Integer>.WithDistance> result = vp.search(6, 3);
|
||||
assertEquals(3, result.size());
|
||||
assertEquals(1d, result.get(0).getDistance(), epsilon);
|
||||
assertEquals(1d, result.get(1).getDistance(), epsilon);
|
||||
assertEquals(2d, result.get(2).getDistance(), epsilon);
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue