Some comments an naming fixes

This commit is contained in:
Knut Forkalsrud 2014-01-04 16:11:40 -08:00
parent 969add5846
commit 70c928b703

View file

@ -61,9 +61,9 @@ public class VpTree<T> {
public List<WithDistance> search(final T target, int numResults) { public List<WithDistance> search(final T target, int numResults) {
Nearest acc = new Nearest(target, numResults); Searcher searcher = new Searcher(target, numResults);
root.search(acc); root.search(searcher);
PriorityQueue<WithDistance> heap = acc.matches; PriorityQueue<WithDistance> heap = searcher.matches;
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size()); ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
while (!heap.isEmpty()) { while (!heap.isEmpty()) {
results.add(heap.remove()); results.add(heap.remove());
@ -82,8 +82,8 @@ public class VpTree<T> {
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) { Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
WithDistance vp = points.get(left); WithDistance vp = points.get(left);
Node leftChild = null; Node nearChild = null;
Node rightChild = null; Node farChild = null;
double threshold = 0d; double threshold = 0d;
if (right - left > 1) { if (right - left > 1) {
@ -92,35 +92,42 @@ public class VpTree<T> {
int pos = left + random.nextInt(len); int pos = left + random.nextInt(len);
Collections.swap(points, left, pos); Collections.swap(points, left, pos);
vp = points.get(left); vp = points.get(left);
// sort the rest by distance from said point
for (int i = left + 1; i < right; i++) { for (int i = left + 1; i < right; i++) {
WithDistance wd = points.get(i); WithDistance wd = points.get(i);
wd.distance = metric.distance(vp.value, wd.value); wd.distance = metric.distance(vp.value, wd.value);
} }
Collections.sort(points.subList(left + 1, right)); Collections.sort(points.subList(left + 1, right));
// Use the median distance as a split
int midpoint = (left + 1 + right) / 2; int midpoint = (left + 1 + right) / 2;
threshold = points.get(midpoint).distance; threshold = points.get(midpoint).distance;
leftChild = buildFromPoints(points, left + 1, midpoint); // The leftmost points are closer, the rightmost are farher away
rightChild = buildFromPoints(points, midpoint, right); nearChild = buildFromPoints(points, left + 1, midpoint);
farChild = buildFromPoints(points, midpoint, right);
} }
return new Node(vp.value, threshold, leftChild, rightChild); return new Node(vp.value, threshold, nearChild, farChild);
} }
protected class Nearest { protected class Searcher {
private T query; private final T query;
private int numMatches; private final int numMatches;
private PriorityQueue<WithDistance> matches; private final PriorityQueue<WithDistance> matches;
public Nearest(T query, int numMatches) { public Searcher(T query, int numMatches) {
this.query = query; this.query = query;
this.numMatches = numMatches; this.numMatches = numMatches;
this.matches = new PriorityQueue<WithDistance>(numMatches); this.matches = new PriorityQueue<WithDistance>(numMatches);
} }
public double distanceFrom(T value) { public double queryDistance(T value) {
return metric.distance(query, value); return metric.distance(query, value);
} }
/**
* Farthest match distance from query
* @return
*/
public double tau() { public double tau() {
return matches.peek().distance; return matches.peek().distance;
} }
@ -130,6 +137,7 @@ public class VpTree<T> {
matches.add(new WithDistance(value, distance)); matches.add(new WithDistance(value, distance));
} }
if (matches.size() > numMatches) { if (matches.size() > numMatches) {
// drop the one with the largest distance
matches.poll(); matches.poll();
} }
} }
@ -139,35 +147,35 @@ public class VpTree<T> {
private T value; private T value;
private double threshold; private double threshold;
private Node left; private Node near;
private Node right; private Node far;
public Node(T value, double threshold, Node left, Node right) { public Node(T value, double threshold, Node left, Node right) {
this.value = value; this.value = value;
this.threshold = threshold; this.threshold = threshold;
this.left = left; this.near = left;
this.right = right; this.far = right;
} }
void search(Nearest qip) { void search(Searcher searcher) {
double distance = qip.distanceFrom(value); double distance = searcher.queryDistance(value);
qip.add(value, distance); searcher.add(value, distance);
if (distance < threshold) { if (distance < threshold) {
// try left child first // try near child first
if (left != null && distance - qip.tau() <= threshold) { if (near != null && distance - searcher.tau() <= threshold) {
left.search(qip); near.search(searcher);
} }
if (right != null && distance + qip.tau() >= threshold) { if (far != null && distance + searcher.tau() >= threshold) {
right.search(qip); far.search(searcher);
} }
} else { } else {
// try right child first // try far child first
if (right != null && distance + qip.tau() >= threshold) { if (far != null && distance + searcher.tau() >= threshold) {
right.search(qip); far.search(searcher);
} }
if (left != null && distance - qip.tau() <= threshold) { if (near != null && distance - searcher.tau() <= threshold) {
left.search(qip); near.search(searcher);
} }
} }
} }