diff --git a/src/main/java/org/forkalsrud/util/VpTree.java b/src/main/java/org/forkalsrud/util/VpTree.java index 0e0baad..3eaf0ac 100644 --- a/src/main/java/org/forkalsrud/util/VpTree.java +++ b/src/main/java/org/forkalsrud/util/VpTree.java @@ -61,9 +61,9 @@ public class VpTree { public List search(final T target, int numResults) { - Nearest acc = new Nearest(target, numResults); - root.search(acc); - PriorityQueue heap = acc.matches; + Searcher searcher = new Searcher(target, numResults); + root.search(searcher); + PriorityQueue heap = searcher.matches; ArrayList results = new ArrayList(heap.size()); while (!heap.isEmpty()) { results.add(heap.remove()); @@ -82,8 +82,8 @@ public class VpTree { Node buildFromPoints(ArrayList points, int left, int right) { WithDistance vp = points.get(left); - Node leftChild = null; - Node rightChild = null; + Node nearChild = null; + Node farChild = null; double threshold = 0d; if (right - left > 1) { @@ -92,35 +92,42 @@ public class VpTree { int pos = left + random.nextInt(len); Collections.swap(points, left, pos); vp = points.get(left); + // sort the rest by distance from said point for (int i = left + 1; i < right; i++) { WithDistance wd = points.get(i); wd.distance = metric.distance(vp.value, wd.value); } Collections.sort(points.subList(left + 1, right)); + // Use the median distance as a split int midpoint = (left + 1 + right) / 2; threshold = points.get(midpoint).distance; - leftChild = buildFromPoints(points, left + 1, midpoint); - rightChild = buildFromPoints(points, midpoint, right); + // The leftmost points are closer, the rightmost are farher away + nearChild = buildFromPoints(points, left + 1, midpoint); + farChild = buildFromPoints(points, midpoint, right); } - return new Node(vp.value, threshold, leftChild, rightChild); + return new Node(vp.value, threshold, nearChild, farChild); } - protected class Nearest { + protected class Searcher { - private T query; - private int numMatches; - private PriorityQueue matches; + private final T query; + private final int numMatches; + private final PriorityQueue matches; - public Nearest(T query, int numMatches) { + public Searcher(T query, int numMatches) { this.query = query; this.numMatches = numMatches; this.matches = new PriorityQueue(numMatches); } - public double distanceFrom(T value) { + public double queryDistance(T value) { return metric.distance(query, value); } + /** + * Farthest match distance from query + * @return + */ public double tau() { return matches.peek().distance; } @@ -130,6 +137,7 @@ public class VpTree { matches.add(new WithDistance(value, distance)); } if (matches.size() > numMatches) { + // drop the one with the largest distance matches.poll(); } } @@ -139,35 +147,35 @@ public class VpTree { private T value; private double threshold; - private Node left; - private Node right; + private Node near; + private Node far; public Node(T value, double threshold, Node left, Node right) { this.value = value; this.threshold = threshold; - this.left = left; - this.right = right; + this.near = left; + this.far = right; } - void search(Nearest qip) { + void search(Searcher searcher) { - double distance = qip.distanceFrom(value); - qip.add(value, distance); + double distance = searcher.queryDistance(value); + searcher.add(value, distance); if (distance < threshold) { - // try left child first - if (left != null && distance - qip.tau() <= threshold) { - left.search(qip); + // try near child first + if (near != null && distance - searcher.tau() <= threshold) { + near.search(searcher); } - if (right != null && distance + qip.tau() >= threshold) { - right.search(qip); + if (far != null && distance + searcher.tau() >= threshold) { + far.search(searcher); } } else { - // try right child first - if (right != null && distance + qip.tau() >= threshold) { - right.search(qip); + // try far child first + if (far != null && distance + searcher.tau() >= threshold) { + far.search(searcher); } - if (left != null && distance - qip.tau() <= threshold) { - left.search(qip); + if (near != null && distance - searcher.tau() <= threshold) { + near.search(searcher); } } }