Some comments an naming fixes
This commit is contained in:
parent
969add5846
commit
70c928b703
1 changed files with 39 additions and 31 deletions
|
|
@ -61,9 +61,9 @@ public class VpTree<T> {
|
||||||
|
|
||||||
public List<WithDistance> search(final T target, int numResults) {
|
public List<WithDistance> search(final T target, int numResults) {
|
||||||
|
|
||||||
Nearest acc = new Nearest(target, numResults);
|
Searcher searcher = new Searcher(target, numResults);
|
||||||
root.search(acc);
|
root.search(searcher);
|
||||||
PriorityQueue<WithDistance> heap = acc.matches;
|
PriorityQueue<WithDistance> heap = searcher.matches;
|
||||||
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
|
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
|
||||||
while (!heap.isEmpty()) {
|
while (!heap.isEmpty()) {
|
||||||
results.add(heap.remove());
|
results.add(heap.remove());
|
||||||
|
|
@ -82,8 +82,8 @@ public class VpTree<T> {
|
||||||
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
|
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
|
||||||
|
|
||||||
WithDistance vp = points.get(left);
|
WithDistance vp = points.get(left);
|
||||||
Node leftChild = null;
|
Node nearChild = null;
|
||||||
Node rightChild = null;
|
Node farChild = null;
|
||||||
double threshold = 0d;
|
double threshold = 0d;
|
||||||
|
|
||||||
if (right - left > 1) {
|
if (right - left > 1) {
|
||||||
|
|
@ -92,35 +92,42 @@ public class VpTree<T> {
|
||||||
int pos = left + random.nextInt(len);
|
int pos = left + random.nextInt(len);
|
||||||
Collections.swap(points, left, pos);
|
Collections.swap(points, left, pos);
|
||||||
vp = points.get(left);
|
vp = points.get(left);
|
||||||
|
// sort the rest by distance from said point
|
||||||
for (int i = left + 1; i < right; i++) {
|
for (int i = left + 1; i < right; i++) {
|
||||||
WithDistance wd = points.get(i);
|
WithDistance wd = points.get(i);
|
||||||
wd.distance = metric.distance(vp.value, wd.value);
|
wd.distance = metric.distance(vp.value, wd.value);
|
||||||
}
|
}
|
||||||
Collections.sort(points.subList(left + 1, right));
|
Collections.sort(points.subList(left + 1, right));
|
||||||
|
// Use the median distance as a split
|
||||||
int midpoint = (left + 1 + right) / 2;
|
int midpoint = (left + 1 + right) / 2;
|
||||||
threshold = points.get(midpoint).distance;
|
threshold = points.get(midpoint).distance;
|
||||||
leftChild = buildFromPoints(points, left + 1, midpoint);
|
// The leftmost points are closer, the rightmost are farher away
|
||||||
rightChild = buildFromPoints(points, midpoint, right);
|
nearChild = buildFromPoints(points, left + 1, midpoint);
|
||||||
|
farChild = buildFromPoints(points, midpoint, right);
|
||||||
}
|
}
|
||||||
return new Node(vp.value, threshold, leftChild, rightChild);
|
return new Node(vp.value, threshold, nearChild, farChild);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected class Nearest {
|
protected class Searcher {
|
||||||
|
|
||||||
private T query;
|
private final T query;
|
||||||
private int numMatches;
|
private final int numMatches;
|
||||||
private PriorityQueue<WithDistance> matches;
|
private final PriorityQueue<WithDistance> matches;
|
||||||
|
|
||||||
public Nearest(T query, int numMatches) {
|
public Searcher(T query, int numMatches) {
|
||||||
this.query = query;
|
this.query = query;
|
||||||
this.numMatches = numMatches;
|
this.numMatches = numMatches;
|
||||||
this.matches = new PriorityQueue<WithDistance>(numMatches);
|
this.matches = new PriorityQueue<WithDistance>(numMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double distanceFrom(T value) {
|
public double queryDistance(T value) {
|
||||||
return metric.distance(query, value);
|
return metric.distance(query, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Farthest match distance from query
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public double tau() {
|
public double tau() {
|
||||||
return matches.peek().distance;
|
return matches.peek().distance;
|
||||||
}
|
}
|
||||||
|
|
@ -130,6 +137,7 @@ public class VpTree<T> {
|
||||||
matches.add(new WithDistance(value, distance));
|
matches.add(new WithDistance(value, distance));
|
||||||
}
|
}
|
||||||
if (matches.size() > numMatches) {
|
if (matches.size() > numMatches) {
|
||||||
|
// drop the one with the largest distance
|
||||||
matches.poll();
|
matches.poll();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -139,35 +147,35 @@ public class VpTree<T> {
|
||||||
|
|
||||||
private T value;
|
private T value;
|
||||||
private double threshold;
|
private double threshold;
|
||||||
private Node left;
|
private Node near;
|
||||||
private Node right;
|
private Node far;
|
||||||
|
|
||||||
public Node(T value, double threshold, Node left, Node right) {
|
public Node(T value, double threshold, Node left, Node right) {
|
||||||
this.value = value;
|
this.value = value;
|
||||||
this.threshold = threshold;
|
this.threshold = threshold;
|
||||||
this.left = left;
|
this.near = left;
|
||||||
this.right = right;
|
this.far = right;
|
||||||
}
|
}
|
||||||
|
|
||||||
void search(Nearest qip) {
|
void search(Searcher searcher) {
|
||||||
|
|
||||||
double distance = qip.distanceFrom(value);
|
double distance = searcher.queryDistance(value);
|
||||||
qip.add(value, distance);
|
searcher.add(value, distance);
|
||||||
if (distance < threshold) {
|
if (distance < threshold) {
|
||||||
// try left child first
|
// try near child first
|
||||||
if (left != null && distance - qip.tau() <= threshold) {
|
if (near != null && distance - searcher.tau() <= threshold) {
|
||||||
left.search(qip);
|
near.search(searcher);
|
||||||
}
|
}
|
||||||
if (right != null && distance + qip.tau() >= threshold) {
|
if (far != null && distance + searcher.tau() >= threshold) {
|
||||||
right.search(qip);
|
far.search(searcher);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// try right child first
|
// try far child first
|
||||||
if (right != null && distance + qip.tau() >= threshold) {
|
if (far != null && distance + searcher.tau() >= threshold) {
|
||||||
right.search(qip);
|
far.search(searcher);
|
||||||
}
|
}
|
||||||
if (left != null && distance - qip.tau() <= threshold) {
|
if (near != null && distance - searcher.tau() <= threshold) {
|
||||||
left.search(qip);
|
near.search(searcher);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue