168 lines
4.9 KiB
Java
168 lines
4.9 KiB
Java
|
|
package org.forkalsrud.util;
|
||
|
|
|
||
|
|
import java.util.ArrayList;
|
||
|
|
import java.util.Collection;
|
||
|
|
import java.util.Collections;
|
||
|
|
import java.util.List;
|
||
|
|
import java.util.PriorityQueue;
|
||
|
|
import java.util.Random;
|
||
|
|
|
||
|
|
/**
 * Vantage-point tree for k-nearest-neighbour queries under a caller-supplied
 * {@link Metric}.
 *
 * <p>The tree is built once, in the constructor. Each node holds one item (the
 * "vantage point") that is picked at random from the remaining items; the rest
 * are sorted by their distance to it and split at the median: the nearer half
 * goes to the left subtree, the farther half to the right subtree.
 *
 * <p>Queries prune subtrees using the triangle inequality, so results are only
 * guaranteed to be exact when the supplied metric satisfies it.
 *
 * @param <T> type of the items stored in the tree
 */
public class VpTree<T> {

    /**
     * Distance function between two items.
     *
     * @param <T> type of the items being compared
     */
    public static interface Metric<T> {

        /**
         * Computes the distance between two items.
         *
         * @param a first item
         * @param b second item
         * @return the distance between {@code a} and {@code b}
         */
        public double distance(T a, T b);
    }

    /** Distance function used both while building and while searching. */
    private final Metric<T> metric;

    /** Source of randomness for choosing vantage points. */
    private final Random random;

    /** Number of items stored in the tree. */
    private final int size;

    /** Root node, or {@code null} when the tree holds no items. */
    private final Node root;

    /**
     * Builds a tree over the given items.
     *
     * @param items  items to index; may be empty
     * @param metric distance function between items
     * @throws NullPointerException if {@code items} or {@code metric} is {@code null}
     */
    public VpTree(Collection<T> items, Metric<T> metric) {
        if (items == null) {
            throw new NullPointerException("items");
        }
        if (metric == null) {
            throw new NullPointerException("metric");
        }
        this.metric = metric;
        this.random = new Random();

        ArrayList<WithDistance> points = new ArrayList<WithDistance>(items.size());
        for (T t : items) {
            points.add(new WithDistance(t, 0d));
        }
        this.size = points.size();
        this.root = buildFromPoints(points, 0, points.size());
    }

    /**
     * An item paired with a distance. Used as scratch storage while building
     * the tree and as the element type of search results.
     *
     * <p>The natural ordering is by ascending distance and is not consistent
     * with {@code equals}.
     */
    public class WithDistance implements Comparable<WithDistance> {

        private T value;
        private double distance;

        /**
         * @param value    the item
         * @param distance the distance associated with the item
         */
        public WithDistance(T value, double distance) {
            this.value = value;
            this.distance = distance;
        }

        /** @return the item */
        public T getValue() {
            return this.value;
        }

        /** @return the distance associated with the item */
        public double getDistance() {
            return this.distance;
        }

        /**
         * Orders by ascending distance.
         *
         * <p>{@link Double#compare(double, double)} is used instead of the sign
         * of a subtraction so that NaN distances still yield a consistent
         * total order.
         */
        @Override
        public int compareTo(WithDistance other) {
            return Double.compare(this.distance, other.distance);
        }
    }

    /**
     * Finds the items closest to {@code target}.
     *
     * @param target     the query item
     * @param numResults maximum number of neighbours to return; must be at least 1
     * @return up to {@code numResults} items with their distance to
     *         {@code target}, nearest first; empty when the tree holds no items
     * @throws IllegalArgumentException if {@code numResults} is less than 1
     */
    public List<WithDistance> search(final T target, int numResults) {

        Nearest nearest = new Nearest(target, numResults);
        if (root != null) {
            root.search(nearest);
        }

        // The collector's queue has no useful iteration order, so copy and sort.
        ArrayList<WithDistance> results = new ArrayList<WithDistance>(nearest.matches);
        Collections.sort(results);
        return results;
    }

    /**
     * Exchanges the elements at two positions of a list.
     *
     * @param points the list to modify
     * @param a      index of the first element
     * @param b      index of the second element
     */
    static void swap(ArrayList<?> points, int a, int b) {
        Collections.swap(points, a, b);
    }

    /**
     * Recursively builds the subtree for {@code points[left, right)}.
     *
     * <p>The given range of {@code points} is reordered and the distance stored
     * in each of its entries is overwritten.
     *
     * @param points scratch list holding the items
     * @param left   first index of the range (inclusive)
     * @param right  last index of the range (exclusive)
     * @return the root of the subtree, or {@code null} for an empty range
     */
    Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {

        if (right <= left) {
            return null;
        }
        // choose an arbitrary point and move it to the start
        int len = right - left;
        int pos = left + random.nextInt(len);
        Collections.swap(points, left, pos);
        final WithDistance vp = points.get(left);

        if (len == 1) {
            // Leaf: nothing left to partition, so there is no median to read.
            // The threshold is never used because both children are null.
            return new Node(vp.value, 0d, null, null);
        }

        int first = left + 1;
        for (int i = first; i < right; i++) {
            WithDistance wd = points.get(i);
            wd.distance = metric.distance(vp.value, wd.value);
        }
        // Ascending by distance to the vantage point.
        Collections.sort(points.subList(first, right));

        // len >= 2 guarantees first <= midpoint < right.
        int midpoint = first + (right - first) / 2;
        double threshold = points.get(midpoint).distance;

        // Left subtree: distances <= threshold. Right subtree: distances >= threshold.
        Node leftChild = buildFromPoints(points, first, midpoint);
        Node rightChild = buildFromPoints(points, midpoint, right);
        return new Node(vp.value, threshold, leftChild, rightChild);
    }

    /**
     * Collects the best matches found so far during a single query.
     */
    protected class Nearest {

        private T query;
        private int numMatches;

        /** Max-heap on distance: the head is the worst match currently kept. */
        private PriorityQueue<WithDistance> matches;

        /**
         * @param query      the query item
         * @param numMatches maximum number of matches to keep; must be at least 1
         * @throws IllegalArgumentException if {@code numMatches} is less than 1
         */
        public Nearest(T query, int numMatches) {
            if (numMatches < 1) {
                throw new IllegalArgumentException(
                        "numMatches must be at least 1, got " + numMatches);
            }
            this.query = query;
            this.numMatches = numMatches;
            // Never more matches than items, so do not pre-allocate beyond that.
            int capacity = Math.max(1, Math.min(numMatches, size));
            this.matches = new PriorityQueue<WithDistance>(
                    capacity, Collections.<WithDistance>reverseOrder());
        }

        /**
         * @param value an item of the tree
         * @return the distance from the query to {@code value}
         */
        public double distanceFrom(T value) {
            return metric.distance(query, value);
        }

        /**
         * Current search radius: a candidate must be closer than this to be kept.
         *
         * @return the distance of the worst kept match once {@code numMatches}
         *         matches have been collected, positive infinity before that
         */
        public double tau() {
            if (matches.size() < numMatches) {
                // Still room for more matches: nothing may be pruned yet.
                return Double.POSITIVE_INFINITY;
            }
            return matches.peek().distance;
        }

        /**
         * Offers a candidate; it is kept if there is room or if it beats the
         * worst kept match, which is then dropped.
         *
         * @param value    the candidate item
         * @param distance its distance to the query
         */
        public void add(T value, double distance) {
            if (matches.size() < numMatches || distance < tau()) {
                matches.add(new WithDistance(value, distance));
            }
            if (matches.size() > numMatches) {
                matches.poll();
            }
        }
    }

    /**
     * A tree node: a vantage point, the median distance that separates its two
     * subtrees, and the subtrees themselves.
     */
    protected class Node {

        private T value;
        private double threshold;
        private Node left;
        private Node right;

        /**
         * @param value     the vantage point stored in this node
         * @param threshold distance that separates the left from the right subtree
         * @param left      subtree of the nearer items, or {@code null}
         * @param right     subtree of the farther items, or {@code null}
         */
        public Node(T value, double threshold, Node left, Node right) {
            this.value = value;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        /**
         * Offers this node's item to {@code qip} and descends into every child
         * that may still hold an item within the current search radius.
         *
         * @param qip the collector of the running query
         */
        void search(Nearest qip) {

            double distance = qip.distanceFrom(value);
            qip.add(value, distance);

            // tau() is re-read before each descent because visiting the first
            // child may already have shrunk the radius.
            if (distance < threshold) {
                // try left child first
                if (left != null && distance - qip.tau() <= threshold) {
                    left.search(qip);
                }
                if (right != null && distance + qip.tau() >= threshold) {
                    right.search(qip);
                }
            } else {
                // try right child first
                if (right != null && distance + qip.tau() >= threshold) {
                    right.search(qip);
                }
                if (left != null && distance - qip.tau() <= threshold) {
                    left.search(qip);
                }
            }
        }
    }
}
|