From 969add584670a8183eba5b6e40dd7f91221ced7c Mon Sep 17 00:00:00 2001 From: Knut Forkalsrud Date: Thu, 7 Nov 2013 23:01:46 -0800 Subject: [PATCH] Adding test case. Fixing a couple of bugs. --- src/main/java/org/forkalsrud/util/VpTree.java | 51 +++++++++++-------- .../java/org/forkalsrud/util/VpTreeTest.java | 34 +++++++++++++ 2 files changed, 64 insertions(+), 21 deletions(-) create mode 100644 src/main/java/org/forkalsrud/util/VpTreeTest.java diff --git a/src/main/java/org/forkalsrud/util/VpTree.java b/src/main/java/org/forkalsrud/util/VpTree.java index b875e02..0e0baad 100644 --- a/src/main/java/org/forkalsrud/util/VpTree.java +++ b/src/main/java/org/forkalsrud/util/VpTree.java @@ -29,7 +29,7 @@ public class VpTree { } - private class WithDistance implements Comparable { + public class WithDistance implements Comparable { private T value; private double distance; @@ -52,13 +52,18 @@ public class VpTree { double delta = other.distance - this.distance; return (int) Math.signum(delta); } + + @Override + public String toString() { + return String.valueOf(value) + " (" + distance + ")"; + } } public List search(final T target, int numResults) { - PriorityQueue heap = new PriorityQueue(numResults); - root.search(new Nearest(target, numResults)); - + Nearest acc = new Nearest(target, numResults); + root.search(acc); + PriorityQueue heap = acc.matches; ArrayList results = new ArrayList(heap.size()); while (!heap.isEmpty()) { results.add(heap.remove()); @@ -76,23 +81,27 @@ public class VpTree { Node buildFromPoints(ArrayList points, int left, int right) { - if (right <= left) { - return null; + WithDistance vp = points.get(left); + Node leftChild = null; + Node rightChild = null; + double threshold = 0d; + + if (right - left > 1) { + // choose an arbitrary point and move it to the start + int len = right - left; + int pos = left + random.nextInt(len); + Collections.swap(points, left, pos); + vp = points.get(left); + for (int i = left + 1; i < right; i++) { + WithDistance wd = points.get(i); + wd.distance = metric.distance(vp.value, wd.value); + } + Collections.sort(points.subList(left + 1, right)); + int midpoint = (left + 1 + right) / 2; + threshold = points.get(midpoint).distance; + leftChild = buildFromPoints(points, left + 1, midpoint); + rightChild = buildFromPoints(points, midpoint, right); } - // choose an arbitrary point and move it to the start - int len = right - left; - int pos = left + random.nextInt(len); - Collections.swap(points, left, pos); - final WithDistance vp = points.get(left); - for (int i = left + 1; i < right; i++) { - WithDistance wd = points.get(i); - wd.distance = metric.distance(vp.value, wd.value); - } - Collections.sort(points.subList(left + 1, right)); - int midpoint = (left + 1 + right) / 2; - double threshold = points.get(midpoint).distance; - Node leftChild = buildFromPoints(points, left + 1, midpoint); - Node rightChild = buildFromPoints(points, midpoint, right); return new Node(vp.value, threshold, leftChild, rightChild); } @@ -105,7 +114,7 @@ public class VpTree { public Nearest(T query, int numMatches) { this.query = query; this.numMatches = numMatches; - this.matches = new PriorityQueue(numMatches, Collections.reverseOrder()); + this.matches = new PriorityQueue(numMatches); } public double distanceFrom(T value) { diff --git a/src/main/java/org/forkalsrud/util/VpTreeTest.java b/src/main/java/org/forkalsrud/util/VpTreeTest.java new file mode 100644 index 0000000..9c67ae0 --- /dev/null +++ b/src/main/java/org/forkalsrud/util/VpTreeTest.java @@ -0,0 +1,34 @@ +package org.forkalsrud.util; + +import java.util.Arrays; +import java.util.List; + +import org.forkalsrud.util.VpTree.Metric; +import static org.junit.Assert.*; +import org.junit.Test; + +public class VpTreeTest { + + @Test + public void testSearch() { + + double epsilon = 0.0003d; + + VpTree vp = new VpTree(Arrays.asList(4, 1, 7, 2, 9, 3, 5), new Metric() { + + @Override + public double distance(Integer a, Integer b) { + return Math.abs(b - a); + } + + }); + + + List.WithDistance> result = vp.search(6, 3); + assertEquals(3, result.size()); + assertEquals(1d, result.get(0).getDistance(), epsilon); + assertEquals(1d, result.get(1).getDistance(), epsilon); + assertEquals(2d, result.get(2).getDistance(), epsilon); + } + +}