From 70a59ebbfa592771b5f7892a3d2702209786c3d3 Mon Sep 17 00:00:00 2001
From: Anton Maminov
Date: Fri, 12 Apr 2024 17:18:47 +0300
Subject: [PATCH] add detailed comments explaining the purpose and functionality of each part of the code

---
 src/kd_tree.cr | 38 +++++++++++++++++++++++++-------------
 1 file changed, 25 insertions(+), 13 deletions(-)

diff --git a/src/kd_tree.cr b/src/kd_tree.cr
index 3e47316..cb336be 100644
--- a/src/kd_tree.cr
+++ b/src/kd_tree.cr
@@ -2,30 +2,35 @@ require "priority-queue"
 require "./kd_tree/*"
 
 module Kd
+  # A generic KD-tree implementation where `T` is the type of the points.
   class Tree(T)
+    # Represents a node in the KD-tree. Each node stores a pivot point,
+    # the axis it splits, and references to its left and right children.
     class Node(T)
-      getter pivot, split, left, right
+      getter pivot : T, split : Int32, left : Node(T)?, right : Node(T)?
 
       def initialize(@pivot : T, @split : Int32, @left : self?, @right : self?)
       end
     end
 
-    getter root : Node(T)?
-    @k : Int32
+    getter root : Node(T)? # The root node of the KD-tree
+    @k : Int32             # Dimensionality of the points
 
+    # Constructor for the KD-tree. Takes an array of points of type T and builds the tree.
     def initialize(points : Array(T))
-      @k = points.first.size # assumes all points have the same dimension
+      @k = points.first.size # Assumes all points have the same dimension
       @root = build_tree(points, 0)
     end
 
+    # Recursive method to build the KD-tree from a given list of points.
     private def build_tree(points : Array(T), depth : Int32) : Node(T)?
       return if points.empty?
 
-      axis = depth % @k
-      points.sort_by!(&.[axis])
-      median = points.size // 2
+      axis = depth % @k         # Determine the axis to split on based on the current depth
+      points.sort_by!(&.[axis]) # Sort points by the current axis
+      median = points.size // 2 # Find the median index
 
-      # Create node and construct subtrees
+      # Create a new Node with the median point as pivot, and recursively build the left and right subtrees.
       Node(T).new(
         points[median],
         axis,
@@ -34,16 +39,18 @@ module Kd
       )
     end
 
+    # Method to find the nearest 'n' points to a given target point. Returns an array of these points.
     def nearest(target : T, n : Int32 = 1) : Array(T)
       return [] of T if n < 1
 
-      best_nodes = Priority::Queue(Node(T)).new
+      best_nodes = Priority::Queue(Node(T)).new # Initialize a priority queue to store the best nodes found
 
-      find_n_nearest(@root, target, 0, best_nodes, n)
+      find_n_nearest(@root, target, 0, best_nodes, n) # Recursively find the nearest nodes
 
-      best_nodes.map(&.value.pivot)
+      best_nodes.map(&.value.pivot) # Extract the pivot points from the nodes and return them
     end
 
+    # Recursive method to find the nearest nodes to a target point.
     private def find_n_nearest(
       node : Node(T)?,
       target : T,
@@ -53,24 +60,29 @@ module Kd
     )
       return unless node
 
-      axis = depth % @k
+      axis = depth % @k # Determine the axis to compare based on depth
 
+      # Determine which child node to search next, prioritizing the side closer to the target
       next_node = target[axis] < node.pivot[axis] ? node.left : node.right
       other_node = target[axis] < node.pivot[axis] ? node.right : node.left
 
+      # Recursively search the more likely side first
       find_n_nearest(next_node, target, depth + 1, best_nodes, n)
 
+      # Calculate the distance from the target to the current node's pivot and add to the queue
       best_nodes.push(distance(target, node.pivot), node)
 
+      # Ensure that only the 'n' closest nodes are kept in the queue
       best_nodes.pop if best_nodes.size > n
 
+      # Check if the other side might contain closer points and potentially search there too
       if other_node && (best_nodes.size < n || (target[axis] - node.pivot[axis]).abs ** 2 < distance(target, best_nodes.last.value.pivot))
         find_n_nearest(other_node, target, depth + 1, best_nodes, n)
       end
     end
 
+    # Calculate squared Euclidean distance between two points of type T.
     private def distance(m : T, n : T)
-      # squared euclidean distance (to avoid expensive sqrt operation)
       @k.times.sum { |i| (m[i] - n[i]) ** 2 }
     end
   end