album/src/main/java/org/forkalsrud/util/VpTree.java

168 lines
4.9 KiB
Java
Raw Normal View History

package org.forkalsrud.util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
/**
 * A vantage-point tree for k-nearest-neighbour queries over an arbitrary distance function.
 *
 * <p>The tree is built once from a collection of items and is not modified by {@link #search}.
 * Each node holds one vantage point and a distance threshold. The other points of its subtree
 * are sorted by distance from the vantage point and split at the median into an "inside" child
 * (distances at or below the threshold) and an "outside" child (distances at or above it).
 *
 * <p>Search pruning relies on the triangle inequality, so results are exact only when the
 * supplied {@link Metric} is a true metric (non-negative, symmetric, triangle inequality).
 *
 * @param <T> the element type
 */
public class VpTree<T> {

    /**
     * Distance function used both to build the tree and to answer queries.
     *
     * @param <T> the element type
     */
    public static interface Metric<T> {
        /** Returns the distance between {@code a} and {@code b}. */
        public double distance(T a, T b);
    }

    private final Metric<T> metric;
    private final Random random;
    private final Node root;

    /**
     * Builds a tree over {@code items}, choosing vantage points with a fresh {@link Random}.
     *
     * @param items  the points to index; may be empty
     * @param metric the distance function
     */
    public VpTree(Collection<? extends T> items, Metric<T> metric) {
        this(items, metric, new Random());
    }

    /**
     * Builds a tree over {@code items}, choosing vantage points with the given random source.
     * Passing a seeded {@link Random} makes the tree shape reproducible.
     *
     * @param items  the points to index; may be empty
     * @param metric the distance function
     * @param random source of randomness for vantage-point selection
     */
    public VpTree(Collection<? extends T> items, Metric<T> metric, Random random) {
        this.metric = metric;
        this.random = random;
        ArrayList<WithDistance> points = new ArrayList<WithDistance>(items.size());
        for (T t : items) {
            points.add(new WithDistance(t, 0d));
        }
        this.root = buildFromPoints(points, 0, points.size());
    }

    /**
     * An element paired with a distance.
     *
     * <p>In search results, the distance is measured from the query. Natural ordering is
     * ascending by distance. This ordering is inconsistent with {@code equals}.
     */
    public class WithDistance implements Comparable<WithDistance> {
        private final T value;
        // Mutable: buildFromPoints reuses this slot as scratch space while partitioning.
        private double distance;

        public WithDistance(T value, double distance) {
            this.value = value;
            this.distance = distance;
        }

        public T getValue() {
            return this.value;
        }

        public double getDistance() {
            return this.distance;
        }

        @Override
        public int compareTo(WithDistance other) {
            // Ascending by distance; Double.compare also gives NaN a consistent position.
            return Double.compare(this.distance, other.distance);
        }
    }

    /**
     * Finds the {@code numResults} points closest to {@code target}.
     *
     * @param target     the query point
     * @param numResults maximum number of neighbours to return
     * @return up to {@code numResults} matches, closest first. The list is empty if the tree is
     *     empty or {@code numResults} is zero.
     * @throws IllegalArgumentException if {@code numResults} is negative
     */
    public List<WithDistance> search(final T target, int numResults) {
        if (numResults < 0) {
            throw new IllegalArgumentException("numResults must be non-negative: " + numResults);
        }
        if (root == null || numResults == 0) {
            return new ArrayList<WithDistance>(0);
        }
        Nearest nearest = new Nearest(target, numResults);
        root.search(nearest);
        // The queue's head is the farthest match, so draining yields farthest-first; reverse it.
        PriorityQueue<WithDistance> heap = nearest.matches;
        ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
        while (!heap.isEmpty()) {
            results.add(heap.poll());
        }
        Collections.reverse(results);
        return results;
    }

    /** Swaps the elements at indices {@code a} and {@code b} of {@code points}. */
    static <E> void swap(ArrayList<E> points, int a, int b) {
        E temp = points.get(a);
        points.set(a, points.get(b));
        points.set(b, temp);
    }

    /**
     * Recursively builds the subtree for {@code points[left, right)}. The list is reordered in
     * place.
     *
     * @return the subtree root, or {@code null} for an empty range
     */
    Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
        if (right <= left) {
            return null;
        }
        // Choose a random vantage point and move it to the start of the range.
        int pos = left + random.nextInt(right - left);
        Collections.swap(points, left, pos);
        final WithDistance vp = points.get(left);
        if (right - left == 1) {
            // Leaf: there is nothing to partition, and points.get(midpoint) would be out of range.
            return new Node(vp.value, 0d, null, null);
        }
        for (int i = left + 1; i < right; i++) {
            WithDistance wd = points.get(i);
            wd.distance = metric.distance(vp.value, wd.value);
        }
        // Ascending: the nearer half ends up before the midpoint, the farther half after it.
        Collections.sort(points.subList(left + 1, right));
        int midpoint = (left + 1 + right) >>> 1;
        double threshold = points.get(midpoint).distance;
        Node inside = buildFromPoints(points, left + 1, midpoint);
        Node outside = buildFromPoints(points, midpoint, right);
        return new Node(vp.value, threshold, inside, outside);
    }

    /** Per-query state: the query point and a bounded max-heap of the best matches so far. */
    protected class Nearest {
        private final T query;
        private final int numMatches;
        // Reverse of the ascending natural order, so the head is the farthest kept match.
        private final PriorityQueue<WithDistance> matches;

        public Nearest(T query, int numMatches) {
            this.query = query;
            this.numMatches = numMatches;
            // PriorityQueue rejects an initial capacity below 1.
            this.matches = new PriorityQueue<WithDistance>(Math.max(1, numMatches),
                    Collections.reverseOrder());
        }

        public double distanceFrom(T value) {
            return metric.distance(query, value);
        }

        /**
         * Returns the current search radius.
         *
         * <p>The radius is the distance of the farthest kept match once {@code numMatches} have
         * been found. Before that it is +infinity, so no subtree is pruned while the result set
         * is still incomplete.
         */
        public double tau() {
            if (matches.size() < numMatches || matches.isEmpty()) {
                return Double.POSITIVE_INFINITY;
            }
            return matches.peek().distance;
        }

        /** Offers a candidate. If the set is full, it replaces the farthest match when closer. */
        public void add(T value, double distance) {
            if (matches.size() < numMatches) {
                matches.add(new WithDistance(value, distance));
            } else if (numMatches > 0 && distance < matches.peek().distance) {
                matches.poll();
                matches.add(new WithDistance(value, distance));
            }
        }
    }

    /**
     * A tree node.
     *
     * <p>{@code left} holds points whose distance from {@code value} is at most
     * {@code threshold}. {@code right} holds points whose distance is at least
     * {@code threshold}.
     */
    protected class Node {
        private final T value;
        private final double threshold;
        private final Node left;
        private final Node right;

        public Node(T value, double threshold, Node left, Node right) {
            this.value = value;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        /**
         * Visits this subtree.
         *
         * <p>It skips children that cannot contain a point closer than {@code qip.tau()}. The
         * child on the query's side of the threshold is searched first. {@code tau()} is
         * re-read before the second child because the first search may have shrunk it.
         */
        void search(Nearest qip) {
            double distance = qip.distanceFrom(value);
            qip.add(value, distance);
            if (distance < threshold) {
                // Query falls inside: try the inside (left) child first.
                if (left != null && distance - qip.tau() <= threshold) {
                    left.search(qip);
                }
                if (right != null && distance + qip.tau() >= threshold) {
                    right.search(qip);
                }
            } else {
                // Query falls outside: try the outside (right) child first.
                if (right != null && distance + qip.tau() >= threshold) {
                    right.search(qip);
                }
                if (left != null && distance - qip.tau() <= threshold) {
                    left.search(qip);
                }
            }
        }
    }
}