cabbage_queen
cabbage_queen

Reputation: 5

Construction of KD-Tree

I'm trying to construct a KD-Tree, but I'm getting an error where the root node doesn't have any children which it should have.

I think it's a problem with the recursion, but I can't figure out why.

A minimal reproducible sample.

import numpy as np

class Node():
    def __init__(self, point, min_ax, max_ax, axis, left=None, right=None):
        self.point = point
        self.min_ax = min_ax
        self.max_ax = max_ax
        self.axis = axis

        self.left = None
        self.right = None

def construct_kd_tree(points, axis=0):
    if len(points) == 0:
        return None

    # sort triangles with the first axis of the center
    vals = list(sorted(points, key=lambda x: x[axis]))

    median = len(points) // 2

    left = construct_kd_tree(vals[:median])
    right = construct_kd_tree(vals[median+1:])

    # print(left,right)

    return Node(vals[median],
                vals[0][axis],
                vals[-1][axis],
                axis,
                left=left,
                right=right)


points = np.random.rand(10,3)
node = construct_kd_tree(points)

print(node.left)  # None
print(node.right) # None

Upvotes: 0

Views: 110

Answers (1)

Musabbir Arrafi
Musabbir Arrafi

Reputation: 1885

See if this solves your problem:

class Node():
    def __init__(self, point, min_ax, max_ax, axis, left, right):
        self.point = point
        self.min_ax = min_ax
        self.max_ax = max_ax
        self.axis = axis
        self.left = left
        self.right = right
        
def construct_kd_tree(points, axis=0):
    if len(points) == 0:
        return None

    vals = sorted(points, key=lambda x: x[axis])
    median = len(points) // 2

    return Node(vals[median], vals[0][axis], vals[-1][axis], axis,
                left=construct_kd_tree(vals[:median], (axis + 1) % len(points[0])),
                right=construct_kd_tree(vals[median+1:], (axis + 1) % len(points[0])))

# Example usage
points = [(1, 2), (5, 3), (8, 1), (3, 6)]
kd_tree = construct_kd_tree(points)

Upvotes: 0

Related Questions