Experiment with figuring out similar images.
This commit is contained in:
parent
5e5b8ee3d2
commit
a07627f489
1 changed files with 167 additions and 0 deletions
167
src/main/java/org/forkalsrud/util/VpTree.java
Normal file
167
src/main/java/org/forkalsrud/util/VpTree.java
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
package org.forkalsrud.util;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Random;
|
||||
|
||||
public class VpTree<T> {
|
||||
|
||||
public static interface Metric<T> {
|
||||
public double distance(T a, T b);
|
||||
}
|
||||
|
||||
private Metric metric;
|
||||
private Random random;
|
||||
private Node root;
|
||||
|
||||
public VpTree(Collection<T> items, Metric<T> metric) {
|
||||
|
||||
this.metric = metric;
|
||||
this.random = new Random();
|
||||
ArrayList<WithDistance> points = new ArrayList<WithDistance>(items.size());
|
||||
for (T t : items) {
|
||||
points.add(new WithDistance(t, 0d));
|
||||
}
|
||||
this.root = buildFromPoints(points, 0, points.size());
|
||||
}
|
||||
|
||||
|
||||
private class WithDistance implements Comparable<WithDistance> {
|
||||
|
||||
private T value;
|
||||
private double distance;
|
||||
|
||||
public WithDistance(T value, double distance) {
|
||||
this.value = value;
|
||||
this.distance = distance;
|
||||
}
|
||||
|
||||
public T getValue() {
|
||||
return this.value;
|
||||
}
|
||||
|
||||
public double getDistance() {
|
||||
return this.distance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(WithDistance other) {
|
||||
double delta = other.distance - this.distance;
|
||||
return (int) Math.signum(delta);
|
||||
}
|
||||
}
|
||||
|
||||
public List<WithDistance> search(final T target, int numResults) {
|
||||
|
||||
PriorityQueue<WithDistance> heap = new PriorityQueue<WithDistance>(numResults);
|
||||
root.search(new Nearest(target, numResults));
|
||||
|
||||
ArrayList<WithDistance> results = new ArrayList<WithDistance>(heap.size());
|
||||
while (!heap.isEmpty()) {
|
||||
results.add(heap.remove());
|
||||
}
|
||||
Collections.reverse(results);
|
||||
return results;
|
||||
}
|
||||
|
||||
static void swap(ArrayList points, int a, int b) {
|
||||
Object temp = points.get(a);
|
||||
points.set(a, points.get(b));
|
||||
points.set(b, temp);
|
||||
}
|
||||
|
||||
|
||||
Node buildFromPoints(ArrayList<WithDistance> points, int left, int right) {
|
||||
|
||||
if (right <= left) {
|
||||
return null;
|
||||
}
|
||||
// choose an arbitrary point and move it to the start
|
||||
int len = right - left;
|
||||
int pos = left + random.nextInt(len);
|
||||
Collections.swap(points, left, pos);
|
||||
final WithDistance vp = points.get(left);
|
||||
for (int i = left + 1; i < right; i++) {
|
||||
WithDistance wd = points.get(i);
|
||||
wd.distance = metric.distance(vp.value, wd.value);
|
||||
}
|
||||
Collections.sort(points.subList(left + 1, right));
|
||||
int midpoint = (left + 1 + right) / 2;
|
||||
double threshold = points.get(midpoint).distance;
|
||||
Node leftChild = buildFromPoints(points, left + 1, midpoint);
|
||||
Node rightChild = buildFromPoints(points, midpoint, right);
|
||||
return new Node(vp.value, threshold, leftChild, rightChild);
|
||||
}
|
||||
|
||||
protected class Nearest {
|
||||
|
||||
private T query;
|
||||
private int numMatches;
|
||||
private PriorityQueue<WithDistance> matches;
|
||||
|
||||
public Nearest(T query, int numMatches) {
|
||||
this.query = query;
|
||||
this.numMatches = numMatches;
|
||||
this.matches = new PriorityQueue<WithDistance>(numMatches, Collections.reverseOrder());
|
||||
}
|
||||
|
||||
public double distanceFrom(T value) {
|
||||
return metric.distance(query, value);
|
||||
}
|
||||
|
||||
public double tau() {
|
||||
return matches.peek().distance;
|
||||
}
|
||||
|
||||
public void add(T value, double distance) {
|
||||
if (matches.size() < numMatches || distance < tau()) {
|
||||
matches.add(new WithDistance(value, distance));
|
||||
}
|
||||
if (matches.size() > numMatches) {
|
||||
matches.poll();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected class Node {
|
||||
|
||||
private T value;
|
||||
private double threshold;
|
||||
private Node left;
|
||||
private Node right;
|
||||
|
||||
public Node(T value, double threshold, Node left, Node right) {
|
||||
this.value = value;
|
||||
this.threshold = threshold;
|
||||
this.left = left;
|
||||
this.right = right;
|
||||
}
|
||||
|
||||
void search(Nearest qip) {
|
||||
|
||||
double distance = qip.distanceFrom(value);
|
||||
qip.add(value, distance);
|
||||
if (distance < threshold) {
|
||||
// try left child first
|
||||
if (left != null && distance - qip.tau() <= threshold) {
|
||||
left.search(qip);
|
||||
}
|
||||
if (right != null && distance + qip.tau() >= threshold) {
|
||||
right.search(qip);
|
||||
}
|
||||
} else {
|
||||
// try right child first
|
||||
if (right != null && distance + qip.tau() >= threshold) {
|
||||
right.search(qip);
|
||||
}
|
||||
if (left != null && distance - qip.tau() <= threshold) {
|
||||
left.search(qip);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue