/* vim: set expandtab tabstop=4 shiftwidth=4 : */ public class KdTree { private static final double XMIN = 0; private static final double YMIN = 0; private static final double XMAX = 1; private static final double YMAX = 1; private Node root; private int size; private class Node { private Point2D p; private Node left, right; public Node(Point2D p) { this.p = p; } } // construct an empty tree of points public KdTree() { size = 0; root = null; } // is the tree empty? public boolean isEmpty() { return (size == 0); } // number of points in the tree public int size() { return size; } // add the point p to the tree (if it is not already in the tree) public void insert(Point2D p) { if (p == null) return; if (root == null) { root = new Node(p); size = 1; } else if (put(root, p, true)) size += 1; } private boolean put(Node n, Point2D p, boolean vsplit) { int cmp; if (vsplit) cmp = Point2D.X_ORDER.compare(p, n.p); else cmp = Point2D.Y_ORDER.compare(p, n.p); if (cmp < 0) { // add to left subtree if (n.left != null) return put(n.left, p, !vsplit); n.left = new Node(p); return true; } if (cmp == 0) { if (vsplit) { if (Point2D.Y_ORDER.compare(p, n.p) == 0) return false; } else if (Point2D.X_ORDER.compare(p, n.p) == 0) return false; } // add to right subtree if (n.right != null) return put(n.right, p, !vsplit); n.right = new Node(p); return true; } // does the tree contain the point p? public boolean contains(Point2D p) { if ((p == null) || (root == null)) return false; return get(root, p, true); } private boolean get(Node n, Point2D p, boolean vsplit) { if (p.compareTo(n.p) == 0) return true; int cmp; if (vsplit) cmp = Point2D.X_ORDER.compare(p, n.p); else cmp = Point2D.Y_ORDER.compare(p, n.p); if (cmp < 0) { if (n.left == null) return false; return get(n.left, p, !vsplit); } if (n.right == null) return false; return get(n.right, p, !vsplit); } // draw all of the points to standard draw public void draw() { StdDraw.setPenRadius(); StdDraw.setPenColor(StdDraw.BLACK); StdDraw.rectangle(XMAX/2.0, YMAX/2.0, XMAX/2.0, YMAX/2.0); draw(root, true, new RectHV(XMIN, YMIN, XMAX, YMAX)); } private void draw(Node n, boolean vsplit, RectHV r) { if (n == null) return; if (vsplit) { StdDraw.setPenRadius(); StdDraw.setPenColor(StdDraw.RED); StdDraw.line(n.p.x(), r.ymin(), n.p.x(), r.ymax()); if (n.left != null) draw(n.left, !vsplit, new RectHV(r.xmin(), r.ymin(), n.p.x(), r.ymax())); if (n.right != null) draw(n.right, !vsplit, new RectHV(n.p.x(), r.ymin(), r.xmax(), r.ymax())); } else { StdDraw.setPenRadius(); StdDraw.setPenColor(StdDraw.BLUE); StdDraw.line(r.xmin(), n.p.y(), r.xmax(), n.p.y()); if (n.left != null) draw(n.left, !vsplit, new RectHV(r.xmin(), r.ymin(), r.xmax(), n.p.y())); if (n.right != null) draw(n.right, !vsplit, new RectHV(r.xmin(), n.p.y(), r.xmax(), r.ymax())); } StdDraw.setPenRadius(.01); StdDraw.setPenColor(StdDraw.BLACK); n.p.draw(); } // all points in the tree that are inside the rectangle public Iterable range(RectHV rect) { Stack stack = new Stack(); if ((rect == null) || (root == null)) return stack; range(root, rect, stack, true); return stack; } private void range(Node n, RectHV r, Stack s, boolean vsplit) { if (r.contains(n.p)) s.push(n.p); if (vsplit) { if (n.left != null && r.xmin() < n.p.x()) range(n.left, r, s, !vsplit); if (n.right != null && r.xmax() >= n.p.x()) range(n.right, r, s, !vsplit); } else { if (n.left != null && r.ymin() < n.p.y()) range(n.left, r, s, !vsplit); if (n.right != null && r.ymax() >= n.p.y()) range(n.right, r, s, !vsplit); } } // a nearest neighbor in the tree to p; null if tree is empty public Point2D nearest(Point2D p) { if ((p == null) || (root == null)) return null; return nearest(root, p, null, Double.MAX_VALUE, true); } private Point2D nearest(Node n, Point2D p, Point2D cp, double d, boolean vsplit) { Point2D cp0 = cp; double d1 = d; double d2 = p.distanceSquaredTo(n.p); if (d2 < d1) { cp0 = n.p; d1 = d2; } if (vsplit) { d2 = n.p.x() - p.x(); if (d2 > 0) { if (n.left != null) { cp0 = nearest(n.left, p, cp0, d1, !vsplit); d1 = p.distanceSquaredTo(cp0); } if (n.right != null && ((d2 * d2) < d1)) cp0 = nearest(n.right, p, cp0, d1, !vsplit); } else { if (n.right != null) { cp0 = nearest(n.right, p, cp0, d1, !vsplit); d1 = p.distanceSquaredTo(cp0); } if (n.left != null && ((d2 * d2) < d1)) cp0 = nearest(n.left, p, cp0, d1, !vsplit); } } else { d2 = n.p.y() - p.y(); if (d2 > 0) { if (n.left != null) { cp0 = nearest(n.left, p, cp0, d1, !vsplit); d1 = p.distanceSquaredTo(cp0); } if (n.right != null && ((d2 * d2) < d1)) cp0 = nearest(n.right, p, cp0, d1, !vsplit); } else { if (n.right != null) { cp0 = nearest(n.right, p, cp0, d1, !vsplit); d1 = p.distanceSquaredTo(cp0); } if (n.left != null && ((d2 * d2) < d1)) cp0 = nearest(n.left, p, cp0, d1, !vsplit); } } return cp0; } }