Jaroslav Bezděk

Reputation: 7625

Python: sklearn.neighbors.KDTree not working as expected

I am writing a program that should select the points located in the neighbourhood of another point. The neighbourhood size is specified by a radius. I am using sklearn.neighbors.KDTree for this. However, it is not working as I expected.

To show you what I am dealing with, I have got two data frames:

When I hardcode what I expect from KDTree, it seems that every single point from df_example_points should be extracted by KDTree as a point that lies inside the reference point's neighbourhood.

>>> radius = 0.27
>>> x_ref, y_ref, z_ref = df_reference_point.iloc[0]
>>> x_min, x_max = x_ref - radius, x_ref + radius
>>> y_min, y_max = y_ref - radius, y_ref + radius
>>> z_min, z_max = z_ref - radius, z_ref + radius
>>> for i, (x, y, z) in df_example_points.iterrows():
...     if all([x_min <= x <= x_max, y_min <= y <= y_max, z_min <= z <= z_max]):
...         print(f'Point {i} SHOULD be extracted.')
...     else:
...         print(f'Point {i} SHOULD NOT be extracted.')
Point 0 SHOULD be extracted.
Point 1 SHOULD be extracted.
Point 2 SHOULD be extracted.
Point 3 SHOULD be extracted.

However, when I try to use KDTree, only one point is extracted.

>>> tree = KDTree(df_example_points.values)
>>> extracted_points_indices = tree.query_radius(df_reference_point.values.reshape(1, -1), radius)[0]
>>> print(f'Number of extracted points: {len(extracted_points_indices)}')
Number of extracted points: 1

I want to use KDTree because the implementation is much faster. However, I cannot use it when the result is not reliable. Could you please help me: what am I doing wrong? What am I missing?

Upvotes: 2

Views: 1130

Answers (1)

FBruzzesi

Reputation: 6485

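As @Gabriel commented, you are using two different distance metrics. The KDTree default is minkowski (with p=2, i.e. Euclidean distance), so query_radius searches a sphere of the given radius around the reference point. Your hardcoded check instead tests a box that extends radius along each axis, which corresponds to the chebyshev metric (you can check the possible sklearn metrics here: DistanceMetric).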

Passing metric='chebyshev' instead of the default will give your expected result:

from sklearn.neighbors import KDTree

# Chebyshev distance is the largest per-axis difference, so the search region is a cube
tree = KDTree(df_example_points.values, metric='chebyshev')
extracted_points_indices = tree.query_radius(df_reference_point.values.reshape(1, -1), radius)[0]

print(f'Number of extracted points: {len(extracted_points_indices)}')
Number of extracted points: 4
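
To see why the two checks can disagree, here is a minimal sketch that compares both metrics by hand using only the standard library. The reference point and the four points below are made up for illustration (they are not the data frames from the question), and the helper names euclidean and chebyshev are my own, not sklearn functions.

```python
import math

def euclidean(p, q):
    # Straight-line distance: a radius query with this metric covers a sphere
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))

def chebyshev(p, q):
    # Largest per-axis difference: a radius query with this metric covers a cube
    return max(abs(a - b) for a, b in zip(p, q))

radius = 0.27
reference = (0.0, 0.0, 0.0)  # hypothetical reference point
points = [                   # hypothetical example points
    (0.10, 0.10, 0.10),
    (0.25, 0.25, 0.25),      # inside the cube, but in a corner outside the sphere
    (0.20, 0.00, 0.00),
    (0.30, 0.00, 0.00),
]

in_sphere = [euclidean(p, reference) <= radius for p in points]
in_cube = [chebyshev(p, reference) <= radius for p in points]

for i, (s, c) in enumerate(zip(in_sphere, in_cube)):
    print(f'Point {i}: within Euclidean radius: {s}, within Chebyshev radius: {c}')
```

Any point that passes the per-axis min/max test but sits near a corner of the box, like point 1 here, is farther than radius from the reference in Euclidean terms, which is presumably what happens with three of the four points in the question.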

Upvotes: 1
