From a07627f48990aa58f0a0667b06f5779daee55254 Mon Sep 17 00:00:00 2001 From: Knut Forkalsrud Date: Tue, 1 Jan 2013 16:33:26 -0800 Subject: [PATCH] Experiment with figuring out similar images. --- src/main/java/org/forkalsrud/util/VpTree.java | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 src/main/java/org/forkalsrud/util/VpTree.java diff --git a/src/main/java/org/forkalsrud/util/VpTree.java b/src/main/java/org/forkalsrud/util/VpTree.java new file mode 100644 index 0000000..b875e02 --- /dev/null +++ b/src/main/java/org/forkalsrud/util/VpTree.java @@ -0,0 +1,167 @@ +package org.forkalsrud.util; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.PriorityQueue; +import java.util.Random; + +public class VpTree { + + public static interface Metric { + public double distance(T a, T b); + } + + private Metric metric; + private Random random; + private Node root; + + public VpTree(Collection items, Metric metric) { + + this.metric = metric; + this.random = new Random(); + ArrayList points = new ArrayList(items.size()); + for (T t : items) { + points.add(new WithDistance(t, 0d)); + } + this.root = buildFromPoints(points, 0, points.size()); + } + + + private class WithDistance implements Comparable { + + private T value; + private double distance; + + public WithDistance(T value, double distance) { + this.value = value; + this.distance = distance; + } + + public T getValue() { + return this.value; + } + + public double getDistance() { + return this.distance; + } + + @Override + public int compareTo(WithDistance other) { + double delta = other.distance - this.distance; + return (int) Math.signum(delta); + } + } + + public List search(final T target, int numResults) { + + PriorityQueue heap = new PriorityQueue(numResults); + root.search(new Nearest(target, numResults)); + + ArrayList results = new ArrayList(heap.size()); + while (!heap.isEmpty()) { + results.add(heap.remove()); + } + Collections.reverse(results); + return results; + } + + static void swap(ArrayList points, int a, int b) { + Object temp = points.get(a); + points.set(a, points.get(b)); + points.set(b, temp); + } + + + Node buildFromPoints(ArrayList points, int left, int right) { + + if (right <= left) { + return null; + } + // choose an arbitrary point and move it to the start + int len = right - left; + int pos = left + random.nextInt(len); + Collections.swap(points, left, pos); + final WithDistance vp = points.get(left); + for (int i = left + 1; i < right; i++) { + WithDistance wd = points.get(i); + wd.distance = metric.distance(vp.value, wd.value); + } + Collections.sort(points.subList(left + 1, right)); + int midpoint = (left + 1 + right) / 2; + double threshold = points.get(midpoint).distance; + Node leftChild = buildFromPoints(points, left + 1, midpoint); + Node rightChild = buildFromPoints(points, midpoint, right); + return new Node(vp.value, threshold, leftChild, rightChild); + } + + protected class Nearest { + + private T query; + private int numMatches; + private PriorityQueue matches; + + public Nearest(T query, int numMatches) { + this.query = query; + this.numMatches = numMatches; + this.matches = new PriorityQueue(numMatches, Collections.reverseOrder()); + } + + public double distanceFrom(T value) { + return metric.distance(query, value); + } + + public double tau() { + return matches.peek().distance; + } + + public void add(T value, double distance) { + if (matches.size() < numMatches || distance < tau()) { + matches.add(new WithDistance(value, distance)); + } + if (matches.size() > numMatches) { + matches.poll(); + } + } + } + + protected class Node { + + private T value; + private double threshold; + private Node left; + private Node right; + + public Node(T value, double threshold, Node left, Node right) { + this.value = value; + this.threshold = threshold; + this.left = left; + this.right = right; + } + + void search(Nearest qip) { + + double distance = qip.distanceFrom(value); + qip.add(value, distance); + if (distance < threshold) { + // try left child first + if (left != null && distance - qip.tau() <= threshold) { + left.search(qip); + } + if (right != null && distance + qip.tau() >= threshold) { + right.search(qip); + } + } else { + // try right child first + if (right != null && distance + qip.tau() >= threshold) { + right.search(qip); + } + if (left != null && distance - qip.tau() <= threshold) { + left.search(qip); + } + } + } + } + +}