a.smiet

Reputation: 1825

python point indices in KDTree

Given a list of points, how can I get their indices in a KDTree?

from scipy import spatial
import numpy as np

#some data
x, y = np.mgrid[0:3, 0:3]
# list() is needed on Python 3, where zip returns a one-shot iterator
data = list(zip(x.ravel(), y.ravel()))

points = [[0,1], [2,2]]

#KDTree
tree = spatial.cKDTree(data)

# indices of points in tree should be [1, 8]

I could do something like:

>>> [tree.query_ball_point(i, r=0) for i in points]
[[1], [8]]

Does it make sense to do it that way?

Upvotes: 6

Views: 2195

Answers (1)

ali_m

Reputation: 74172

Use cKDTree.query(x, k, ...) to find the k nearest neighbours to a given set of points x:

distances, indices = tree.query(points, k=1)
print(repr(indices))
# array([1, 8])
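Note that query always returns the nearest neighbour, even when the query point is not actually in the tree. A minimal sketch of how you might guard against that, assuming integer coordinates as in the question (the extra point [5, 5] and the -1 "not found" marker are my own additions for illustration):

```python
import numpy as np
from scipy import spatial

# Same 3x3 grid as in the question, built directly as an (N, 2) array
x, y = np.mgrid[0:3, 0:3]
data = np.column_stack([x.ravel(), y.ravel()])

tree = spatial.cKDTree(data)

# The last query point is deliberately NOT in the dataset
points = np.array([[0, 1], [2, 2], [5, 5]])

distances, indices = tree.query(points, k=1)

# A distance of exactly zero means the query point coincides with a tree point.
# This is safe for integer coordinates; for floats, compare against a tolerance.
exact = distances == 0

# -1 is a hypothetical marker for "no exact match"
matched = np.where(exact, indices, -1)
print(matched)
```

For floating-point data, something like `distances < 1e-12` would probably be a safer test than strict equality.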

In a trivial case such as this, where your dataset and your set of query points are both small, and where each query point is identical to a single row within the dataset, it would be faster to use simple boolean operations with broadcasting rather than building and querying a k-D tree:

data, points = np.array(data), np.array(points)
indices = (data[..., None] == points.T).all(1).argmax(0)

data[..., None] == points.T broadcasts out to an (nrows, ndims, npoints) array, which could quickly become expensive in terms of memory for larger datasets. In such cases you might get better performance out of a normal for loop or list comprehension:

indices = [(data == p).all(1).argmax() for p in points]
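One caveat with argmax: if a query point has no matching row, the boolean mask is all False and argmax silently returns 0. A rough sketch of a loop that reports missing points instead (find_rows and the -1 marker are hypothetical names chosen for this example, and [5, 5] is added only to show the failure case):

```python
import numpy as np

# Same 3x3 grid as in the question
x, y = np.mgrid[0:3, 0:3]
data = np.column_stack([x.ravel(), y.ravel()])

# The last point does not occur in data
points = np.array([[0, 1], [2, 2], [5, 5]])


def find_rows(data, points, missing=-1):
    """Return the index of the first row of `data` equal to each point,
    or `missing` when no row matches."""
    result = []
    for p in points:
        # Indices of all rows where every coordinate equals p
        hits = np.flatnonzero((data == p).all(axis=1))
        result.append(int(hits[0]) if hits.size else missing)
    return result


indices = find_rows(data, points)
print(indices)

# For comparison: plain argmax reports row 0 for the missing point
naive = int((data == points[2]).all(axis=1).argmax())
```

Only one (nrows, ndims) boolean array is alive at a time here, so memory use stays flat regardless of how many query points there are.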

Upvotes: 2
