/* vim: set expandtab tabstop=4 shiftwidth=4 : */ public class KdTree { private static final double XMIN = 0; private static final double YMIN = 0; private static final double XMAX = 1; private static final double YMAX = 1; private Node root; private int size; private class Node { private Point2D p; private Node left, right; public Node(Point2D p) { this.p = p; } } private class NearestChampion { private Point2D p; private double d; public NearestChampion(Point2D p, double d) { this.p = p; this.d = d; } } // construct an empty tree of points public KdTree() { size = 0; root = null; } // is the tree empty? public boolean isEmpty() { return (size == 0); } // number of points in the tree public int size() { return size; } // add the point p to the tree (if it is not already in the tree) public void insert(Point2D p) { if (p == null) return; if (root == null) { root = new Node(p); size = 1; } else if (put(root, p, true)) size += 1; } private boolean put(Node n, Point2D p, boolean vsplit) { int cmp; if (vsplit) cmp = Point2D.X_ORDER.compare(p, n.p); else cmp = Point2D.Y_ORDER.compare(p, n.p); if (cmp < 0) { // add to left subtree if (n.left != null) return put(n.left, p, !vsplit); n.left = new Node(p); return true; } if (cmp == 0) { if (vsplit) { if (Point2D.Y_ORDER.compare(p, n.p) == 0) return false; } else if (Point2D.X_ORDER.compare(p, n.p) == 0) return false; } // add to right subtree if (n.right != null) return put(n.right, p, !vsplit); n.right = new Node(p); return true; } // does the tree contain the point p? public boolean contains(Point2D p) { if ((p == null) || (root == null)) return false; return get(root, p, true); } private boolean get(Node n, Point2D p, boolean vsplit) { if (p.compareTo(n.p) == 0) return true; int cmp; if (vsplit) cmp = Point2D.X_ORDER.compare(p, n.p); else cmp = Point2D.Y_ORDER.compare(p, n.p); if (cmp < 0) { if (n.left == null) return false; return get(n.left, p, !vsplit); } if (n.right == null) return false; return get(n.right, p, !vsplit); } // draw all of the points to standard draw public void draw() { StdDraw.setPenRadius(); StdDraw.setPenColor(StdDraw.BLACK); StdDraw.rectangle(XMAX/2.0, YMAX/2.0, XMAX/2.0, YMAX/2.0); draw(root, true, new RectHV(XMIN, YMIN, XMAX, YMAX)); } private void draw(Node n, boolean vsplit, RectHV r) { if (n == null) return; if (vsplit) { StdDraw.setPenRadius(); StdDraw.setPenColor(StdDraw.RED); StdDraw.line(n.p.x(), r.ymin(), n.p.x(), r.ymax()); if (n.left != null) draw(n.left, !vsplit, new RectHV(r.xmin(), r.ymin(), n.p.x(), r.ymax())); if (n.right != null) draw(n.right, !vsplit, new RectHV(n.p.x(), r.ymin(), r.xmax(), r.ymax())); } else { StdDraw.setPenRadius(); StdDraw.setPenColor(StdDraw.BLUE); StdDraw.line(r.xmin(), n.p.y(), r.xmax(), n.p.y()); if (n.left != null) draw(n.left, !vsplit, new RectHV(r.xmin(), r.ymin(), r.xmax(), n.p.y())); if (n.right != null) draw(n.right, !vsplit, new RectHV(r.xmin(), n.p.y(), r.xmax(), r.ymax())); } StdDraw.setPenRadius(.01); StdDraw.setPenColor(StdDraw.BLACK); n.p.draw(); } // all points in the tree that are inside the rectangle public Iterable range(RectHV rect) { Stack stack = new Stack(); if ((rect == null) || (root == null)) return stack; range(root, rect, stack, true); return stack; } private void range(Node n, RectHV r, Stack s, boolean vsplit) { if (r.contains(n.p)) s.push(n.p); if (vsplit) { if (n.left != null && r.xmin() < n.p.x()) range(n.left, r, s, !vsplit); if (n.right != null && r.xmax() >= n.p.x()) range(n.right, r, s, !vsplit); } else { if (n.left != null && r.ymin() < n.p.y()) range(n.left, r, s, !vsplit); if (n.right != null && r.ymax() >= n.p.y()) range(n.right, r, s, !vsplit); } } // a nearest neighbor in the tree to p; null if tree is empty public Point2D nearest(Point2D p) { if ((p == null) || (root == null)) return null; NearestChampion ncp = new NearestChampion(null, Double.MAX_VALUE); nearest(root, p, ncp, true); return ncp.p; } private void nearest(Node n, Point2D p, NearestChampion ncp, boolean vsplit) { double d2 = p.distanceSquaredTo(n.p); if (d2 < ncp.d) { ncp.d = d2; ncp.p = n.p; } if (vsplit) { double d3 = n.p.x() - p.x(); if (d3 > 0) { if (n.left != null) nearest(n.left, p, ncp, !vsplit); if (n.right != null && ((d3 * d3) < ncp.d)) nearest(n.right, p, ncp, !vsplit); } else { if (n.right != null) nearest(n.right, p, ncp, !vsplit); if (n.left != null && ((d3 * d3) < ncp.d)) nearest(n.left, p, ncp, !vsplit); } } else { double d3 = n.p.y() - p.y(); if (d3 > 0) { if (n.left != null) nearest(n.left, p, ncp, !vsplit); if (n.right != null && ((d3 * d3) < ncp.d)) nearest(n.right, p, ncp, !vsplit); } else { if (n.right != null) nearest(n.right, p, ncp, !vsplit); if (n.left != null && ((d3 * d3) < ncp.d)) nearest(n.left, p, ncp, !vsplit); } } } }