package org.forkalsrud.util;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;

/**
 * A vantage-point tree answering k-nearest-neighbor queries over a fixed set of items.
 *
 * <p>The tree is built once in the constructor and can then be queried repeatedly with
 * {@link #search}. Subtree pruning relies on the triangle inequality, so the supplied
 * {@link Metric} must be a true metric (non-negative, symmetric, triangle inequality);
 * with a non-metric distance function results may be incomplete.
 *
 * @param <T> the type of the indexed items
 */
public class VpTree<T> {

    /**
     * Distance function between two items.
     *
     * @param <T> the item type
     */
    public static interface Metric<T> {
        /** Returns the distance between {@code a} and {@code b}. */
        public double distance(T a, T b);
    }

    private final Metric<? super T> metric;
    private final Random random;
    private final Node root;
    /** Number of items indexed; used to clamp the requested result count. */
    private final int size;

    /**
     * Builds a tree over {@code items}, choosing vantage points with a fresh {@link Random}.
     *
     * @param items the items to index
     * @param metric the distance function between items
     */
    public VpTree(Collection<? extends T> items, Metric<? super T> metric) {
        this(items, metric, new Random());
    }

    /**
     * Builds a tree over {@code items}.
     *
     * @param items the items to index
     * @param metric the distance function between items
     * @param random source used to pick vantage points; pass a seeded instance for a
     *     reproducible tree shape
     * @throws NullPointerException if {@code metric} or {@code random} is null
     */
    public VpTree(Collection<? extends T> items, Metric<? super T> metric, Random random) {
        if (metric == null) {
            throw new NullPointerException("metric");
        }
        if (random == null) {
            throw new NullPointerException("random");
        }
        this.metric = metric;
        this.random = random;
        ArrayList<WithDistance> points = new ArrayList<WithDistance>(items.size());
        for (T t : items) {
            points.add(new WithDistance(t, 0d));
        }
        this.size = points.size();
        this.root = buildFromPoints(points, 0, points.size());
    }

    /** An item paired with a distance (to a vantage point or to a query). */
    private class WithDistance implements Comparable<WithDistance> {
        private final T value;
        private double distance;

        public WithDistance(T value, double distance) {
            this.value = value;
            this.distance = distance;
        }

        public T getValue() {
            return this.value;
        }

        public double getDistance() {
            return this.distance;
        }

        /**
         * Orders by ascending distance. Both the median split in {@code buildFromPoints}
         * (near points on the left) and the max-heap in {@code Nearest} depend on this.
         */
        @Override
        public int compareTo(WithDistance other) {
            return Double.compare(this.distance, other.distance);
        }
    }

    /**
     * Finds the items closest to {@code target}.
     *
     * @param target the query item
     * @param numResults the maximum number of neighbors to return
     * @return up to {@code numResults} items ordered nearest first; empty when the tree is
     *     empty or {@code numResults} is 0
     * @throws IllegalArgumentException if {@code numResults} is negative
     */
    public List<T> search(final T target, int numResults) {
        if (numResults < 0) {
            throw new IllegalArgumentException("numResults must be non-negative: " + numResults);
        }
        // Never ask for more neighbors than exist; also bounds the heap's initial capacity.
        int k = Math.min(numResults, size);
        if (k == 0 || root == null) {
            return new ArrayList<T>();
        }
        Nearest nearest = new Nearest(target, k);
        root.search(nearest);
        return nearest.drainNearestFirst();
    }

    /** Swaps the elements at positions {@code a} and {@code b} of {@code points}. */
    static <E> void swap(List<E> points, int a, int b) {
        Collections.swap(points, a, b);
    }

    /**
     * Recursively builds the subtree for {@code points[left, right)}. The range is
     * reordered in place and each element's {@code distance} field is overwritten.
     *
     * @return the subtree root, or null for an empty range
     */
    Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
        if (right <= left) {
            return null;
        }
        // Choose a random vantage point and move it to the start of the range.
        int pos = left + random.nextInt(right - left);
        Collections.swap(points, left, pos);
        final WithDistance vp = points.get(left);
        if (right - left == 1) {
            // Leaf: nothing left to partition, and points.get(midpoint) below would be
            // out of range. The threshold is never consulted without children.
            return new Node(vp.value, 0d, null, null);
        }
        for (int i = left + 1; i < right; i++) {
            WithDistance wd = points.get(i);
            wd.distance = metric.distance(vp.value, wd.value);
        }
        // Ascending sort, then split at the median: [left+1, midpoint) have distance
        // <= threshold and [midpoint, right) have distance >= threshold.
        Collections.sort(points.subList(left + 1, right));
        int midpoint = (left + 1 + right) >>> 1;
        double threshold = points.get(midpoint).distance;
        Node leftChild = buildFromPoints(points, left + 1, midpoint);
        Node rightChild = buildFromPoints(points, midpoint, right);
        return new Node(vp.value, threshold, leftChild, rightChild);
    }

    /** Accumulates the best {@code numMatches} candidates seen during one search. */
    protected class Nearest {
        private final T query;
        private final int numMatches;
        /** Max-heap on distance: the head is the worst match currently kept. */
        private final PriorityQueue<WithDistance> matches;

        public Nearest(T query, int numMatches) {
            this.query = query;
            this.numMatches = numMatches;
            this.matches = new PriorityQueue<WithDistance>(
                    Math.max(1, numMatches), Collections.<WithDistance>reverseOrder());
        }

        public double distanceFrom(T value) {
            return metric.distance(query, value);
        }

        /**
         * Returns the current search radius. It is infinite until {@code numMatches}
         * candidates have been collected, because no subtree may be pruned before then.
         * Once full, it is the distance of the worst kept match.
         */
        public double tau() {
            if (matches.size() < numMatches) {
                return Double.POSITIVE_INFINITY;
            }
            WithDistance worst = matches.peek();
            // Empty only when numMatches <= 0: nothing can qualify, so prune everything.
            return worst == null ? Double.NEGATIVE_INFINITY : worst.distance;
        }

        /** Offers a candidate; keeps it if it is among the best {@code numMatches} so far. */
        public void add(T value, double distance) {
            if (matches.size() < numMatches || distance < tau()) {
                matches.add(new WithDistance(value, distance));
            }
            if (matches.size() > numMatches) {
                matches.poll();
            }
        }

        /** Empties the heap and returns the kept values ordered nearest first. */
        List<T> drainNearestFirst() {
            ArrayList<T> out = new ArrayList<T>(matches.size());
            while (!matches.isEmpty()) {
                out.add(matches.poll().value); // farthest first
            }
            Collections.reverse(out);
            return out;
        }
    }

    /**
     * A tree node: a vantage point, the median distance splitting its children, a left
     * subtree of points within {@code threshold}, and a right subtree of points beyond it.
     */
    protected class Node {
        private final T value;
        private final double threshold;
        private final Node left;
        private final Node right;

        public Node(T value, double threshold, Node left, Node right) {
            this.value = value;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        /**
         * Searches this subtree, visiting the child on the query's side first. tau() is
         * re-read before each child because it shrinks as better matches are found.
         */
        void search(Nearest qip) {
            double distance = qip.distanceFrom(value);
            qip.add(value, distance);
            if (distance < threshold) {
                // try left child first
                if (left != null && distance - qip.tau() <= threshold) {
                    left.search(qip);
                }
                if (right != null && distance + qip.tau() >= threshold) {
                    right.search(qip);
                }
            } else {
                // try right child first
                if (right != null && distance + qip.tau() >= threshold) {
                    right.search(qip);
                }
                if (left != null && distance - qip.tau() <= threshold) {
                    left.search(qip);
                }
            }
        }
    }
}