A KD-tree (k-dimensional tree) is a binary tree that partitions k-dimensional space by cycling through the coordinate axes as splitting dimensions, one axis per level. It supports fast nearest-neighbor and range queries in low dimensions, and it is the data structure behind scipy.spatial.cKDTree and many computer-graphics systems.

| Case | Access | Search | Insertion | Deletion | Space |
|---|---|---|---|---|---|
| Average | Θ(log n) | Θ(log n) | Θ(log n) | Θ(log n) | |
| Worst | O(n) | O(n) | O(n) | O(n) | O(n) |

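
The gap between the average and worst-case columns comes from tree shape. As a rough illustration, the sketch below inserts points one at a time without rebalancing and compares the resulting heights. The helper names (`insert`, `height`, `build_by_insertion`) and the list-based node layout are assumptions made for this example, not part of any library.

```python
def insert(node, point, depth=0):
    """Insert `point` without rebalancing and return the subtree root.

    A node is a list [point, left, right] to keep the sketch short.
    """
    if node is None:
        return [point, None, None]
    axis = depth % len(point)
    # Index 1 is the left child, index 2 is the right child.
    side = 1 if point[axis] < node[0][axis] else 2
    node[side] = insert(node[side], point, depth + 1)
    return node


def height(node):
    """Number of nodes on the longest root-to-leaf path."""
    if node is None:
        return 0
    return 1 + max(height(node[1]), height(node[2]))


def build_by_insertion(points):
    root = None
    for p in points:
        root = insert(root, p)
    return root


# Points that increase on every axis: each insert goes right, forming a chain.
sorted_pts = [(i, i) for i in range(64)]
chain = build_by_insertion(sorted_pts)

# The same points in a scrambled order (37 is coprime with 64, so this is a permutation).
mixed_pts = [sorted_pts[(i * 37) % 64] for i in range(64)]
bushy = build_by_insertion(mixed_pts)

print(height(chain))  # 64: the tree has degenerated into a linked list
print(height(bushy))  # smaller than 64, because some nodes have two children
```

With a chain of height n, every operation that walks from the root to a leaf costs O(n), which is the worst-case column. Building from the median, as described below, avoids this for a static point set.
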
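Construction: a node at depth d splits on axis d mod k. Pick the median point along that axis as the root of the subtree, then recurse on the points to either side of it. Nearest-neighbor query: descend toward the leaf region containing the target, tracking the closest point seen so far. While unwinding, compare the target's distance to each node's splitting hyperplane with the best distance found; if the hyperplane is closer, the far side could still hold a closer point, so recurse into it as well. Queries take O(log n) time on average for low k, but pruning loses its effect as k grows, and performance degrades sharply once k exceeds roughly 20, approaching a linear scan (the curse of dimensionality).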
class Node:
    __slots__ = ("point", "axis", "left", "right")

    def __init__(self, point, axis):
        self.point, self.axis = point, axis
        self.left = self.right = None


def build(points, depth=0):
    if not points:
        return None
    k = len(points[0])
    axis = depth % k  # cycle through the axes by depth
    # Sort a copy so the caller's list is not reordered.
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2  # the median becomes this subtree's root
    n = Node(points[mid], axis)
    n.left = build(points[:mid], depth + 1)
    n.right = build(points[mid + 1:], depth + 1)
    return n


def _dist2(a, b):
    # Squared Euclidean distance; comparing squares avoids the square root.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(node, target, best=None):
    # Returns (point, squared_distance), or None for an empty tree.
    if node is None:
        return best
    d = _dist2(node.point, target)
    if best is None or d < best[1]:
        best = (node.point, d)
    # Signed offset of the target from this node's splitting plane.
    diff = target[node.axis] - node.point[node.axis]
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best = nearest(near, target, best)
    # Visit the far side only if the plane is closer than the best match so far.
    if diff * diff < best[1]:
        best = nearest(far, target, best)
    return best


pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build(pts)
print(nearest(tree, (9, 2)))  # ((8, 1), 2), where 2 is the squared distance
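
The code above covers nearest-neighbor search only. Range queries can reuse the same splitting structure, and the following is a minimal self-contained sketch of one, assuming an axis-aligned query box given by inclusive corner points `lo` and `hi`. The function name `range_query` and the tuple-based node layout are choices made for this example; the build step is repeated in compact form so the block runs on its own.

```python
def build(points, depth=0):
    """Median-split build; a node is a tuple (point, axis, left, right)."""
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    return (
        ordered[mid],
        axis,
        build(ordered[:mid], depth + 1),
        build(ordered[mid + 1:], depth + 1),
    )


def range_query(node, lo, hi, out=None):
    """Collect every point p with lo[i] <= p[i] <= hi[i] on all axes i."""
    if out is None:
        out = []
    if node is None:
        return out
    point, axis, left, right = node
    if all(l <= x <= h for l, x, h in zip(lo, point, hi)):
        out.append(point)
    # The left subtree holds coordinates <= point[axis] on this axis,
    # so it can be skipped when the box starts above the split.
    if lo[axis] <= point[axis]:
        range_query(left, lo, hi, out)
    # The right subtree holds coordinates >= point[axis] on this axis,
    # so it can be skipped when the box ends below the split.
    if hi[axis] >= point[axis]:
        range_query(right, lo, hi, out)
    return out


pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build(pts)
print(sorted(range_query(tree, lo=(4, 1), hi=(8, 4))))  # [(5, 4), (7, 2), (8, 1)]
```

The pruning is the same idea as in the nearest-neighbor search: a subtree is visited only when the query region can overlap its side of the splitting plane.
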
faiss).