Adding test case.
Fixing a couple of bugs.
This commit is contained in:
parent
0ea01af309
commit
969add5846
2 changed files with 64 additions and 21 deletions
|
|
@ -29,7 +29,7 @@ public class VpTree<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private class WithDistance implements Comparable<WithDistance> {
|
public class WithDistance implements Comparable<WithDistance> {
|
||||||
|
|
||||||
private T value;
|
private T value;
|
||||||
private double distance;
|
private double distance;
|
||||||
|
|
@ -52,13 +52,18 @@ public class VpTree<T> {
|
||||||
double delta = other.distance - this.distance;
|
double delta = other.distance - this.distance;
|
||||||
return (int) Math.signum(delta);
|
return (int) Math.signum(delta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return String.valueOf(value) + " (" + distance + ")";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<WithDistance> search(final T target, int numResults) {
|
public List<WithDistance> search(final T target, int numResults) {
|
||||||
|
|
||||||
PriorityQueue<WithDistance> heap = new PriorityQueue<WithDistance>(numResults);
|
Nearest acc = new Nearest(target, numResults);
|
||||||
root.search(new Nearest(target, numResults));
|
root.search(acc);
|
||||||
|
PriorityQueue<WithDistance> heap = acc.matches;
|
||||||
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
|
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
|
||||||
while (!heap.isEmpty()) {
|
while (!heap.isEmpty()) {
|
||||||
results.add(heap.remove());
|
results.add(heap.remove());
|
||||||
|
|
@ -76,23 +81,27 @@ public class VpTree<T> {
|
||||||
|
|
||||||
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
|
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
|
||||||
|
|
||||||
if (right <= left) {
|
WithDistance vp = points.get(left);
|
||||||
return null;
|
Node leftChild = null;
|
||||||
|
Node rightChild = null;
|
||||||
|
double threshold = 0d;
|
||||||
|
|
||||||
|
if (right - left > 1) {
|
||||||
|
// choose an arbitrary point and move it to the start
|
||||||
|
int len = right - left;
|
||||||
|
int pos = left + random.nextInt(len);
|
||||||
|
Collections.swap(points, left, pos);
|
||||||
|
vp = points.get(left);
|
||||||
|
for (int i = left + 1; i < right; i++) {
|
||||||
|
WithDistance wd = points.get(i);
|
||||||
|
wd.distance = metric.distance(vp.value, wd.value);
|
||||||
|
}
|
||||||
|
Collections.sort(points.subList(left + 1, right));
|
||||||
|
int midpoint = (left + 1 + right) / 2;
|
||||||
|
threshold = points.get(midpoint).distance;
|
||||||
|
leftChild = buildFromPoints(points, left + 1, midpoint);
|
||||||
|
rightChild = buildFromPoints(points, midpoint, right);
|
||||||
}
|
}
|
||||||
// choose an arbitrary point and move it to the start
|
|
||||||
int len = right - left;
|
|
||||||
int pos = left + random.nextInt(len);
|
|
||||||
Collections.swap(points, left, pos);
|
|
||||||
final WithDistance vp = points.get(left);
|
|
||||||
for (int i = left + 1; i < right; i++) {
|
|
||||||
WithDistance wd = points.get(i);
|
|
||||||
wd.distance = metric.distance(vp.value, wd.value);
|
|
||||||
}
|
|
||||||
Collections.sort(points.subList(left + 1, right));
|
|
||||||
int midpoint = (left + 1 + right) / 2;
|
|
||||||
double threshold = points.get(midpoint).distance;
|
|
||||||
Node leftChild = buildFromPoints(points, left + 1, midpoint);
|
|
||||||
Node rightChild = buildFromPoints(points, midpoint, right);
|
|
||||||
return new Node(vp.value, threshold, leftChild, rightChild);
|
return new Node(vp.value, threshold, leftChild, rightChild);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -105,7 +114,7 @@ public class VpTree<T> {
|
||||||
public Nearest(T query, int numMatches) {
|
public Nearest(T query, int numMatches) {
|
||||||
this.query = query;
|
this.query = query;
|
||||||
this.numMatches = numMatches;
|
this.numMatches = numMatches;
|
||||||
this.matches = new PriorityQueue<WithDistance>(numMatches, Collections.reverseOrder());
|
this.matches = new PriorityQueue<WithDistance>(numMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double distanceFrom(T value) {
|
public double distanceFrom(T value) {
|
||||||
|
|
|
||||||
34
src/main/java/org/forkalsrud/util/VpTreeTest.java
Normal file
34
src/main/java/org/forkalsrud/util/VpTreeTest.java
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
package org.forkalsrud.util;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.forkalsrud.util.VpTree.Metric;
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class VpTreeTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSearch() {
|
||||||
|
|
||||||
|
double epsilon = 0.0003d;
|
||||||
|
|
||||||
|
VpTree<Integer> vp = new VpTree<Integer>(Arrays.asList(4, 1, 7, 2, 9, 3, 5), new Metric<Integer>() {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double distance(Integer a, Integer b) {
|
||||||
|
return Math.abs(b - a);
|
||||||
|
}
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
List<VpTree<Integer>.WithDistance> result = vp.search(6, 3);
|
||||||
|
assertEquals(3, result.size());
|
||||||
|
assertEquals(1d, result.get(0).getDistance(), epsilon);
|
||||||
|
assertEquals(1d, result.get(1).getDistance(), epsilon);
|
||||||
|
assertEquals(2d, result.get(2).getDistance(), epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue