usr-N

Reputation: 25

Find k nearest neighbors using kd-tree in python when coordinates are held in objects

I need to find the k nearest neighbors of each object in a set. Each object stores its coordinates as properties. To solve the task, I am trying to use spatial.KDTree from scipy. It works fine if I use a list or tuple to represent a point, but it doesn't work for objects. I implemented the __getitem__ and __len__ methods in my class, but the KDTree implementation asks my objects for a non-existent coordinate axis (e.g. the 3rd coordinate of a 2-dimensional point).

Here is a simple script to reproduce the problem:

from scipy import spatial

class Unit:

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        else:
            raise Exception('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2

#points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
#points = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]

tree = spatial.KDTree(points)

#result = tree.query((6,6), 3)
result = tree.query(Unit(6,6), 3)

print(result)
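A small tracing sketch can show where the request for a third coordinate comes from. It assumes, as the answers below explain, that the failure happens while the list of objects is converted to an array, a step that iterates over each object; for a class without __iter__, Python falls back to calling __getitem__ with 0, 1, 2, ... until it gets an IndexError. TracedUnit is a hypothetical copy of Unit that records every requested index, and list() stands in for the array conversion.

```python
# Hypothetical tracing variant of the question's Unit class: it records every
# index that __getitem__ receives, then behaves exactly like the original.
class TracedUnit:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.requested = []  # indices passed to __getitem__, in call order

    def __getitem__(self, index):
        self.requested.append(index)
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        else:
            raise Exception('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2


u = TracedUnit(1, 1)
try:
    # No __iter__, so list() uses the legacy protocol: __getitem__(0), (1), (2), ...
    # It only stops cleanly on IndexError, so the plain Exception escapes here.
    list(u)
except Exception as exc:
    print(u.requested, exc)
# prints: [0, 1, 2] Unit coordinates are 2 dimensional
```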

I don't have to use this specific implementation, library, or even algorithm; the only requirement is that it works with objects.

P.S. I could add an id field to each object and move all coordinates into a separate array indexed by object id, but I would prefer to avoid that approach if possible.
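For comparison, the approach in the P.S. does not strictly need an id field: if the coordinate array is built from the list of objects in order, the row index returned by the query is also the object's position in that list. A minimal sketch, assuming the objects are kept in a list that is not reordered after the tree is built:

```python
import numpy as np
from scipy import spatial


class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y


units = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]

# Row i of coords holds the coordinates of units[i], so the list position
# plays the role of the id field (assumes `units` is not reordered later).
coords = np.array([(u.x, u.y) for u in units], dtype=float)
tree = spatial.KDTree(coords)

distances, indices = tree.query((6, 6), k=3)
nearest = [units[i] for i in indices]  # map row indices back to the objects
print([(u.x, u.y) for u in nearest])
# prints: [(5, 5), (4, 4), (3, 3)]
```

The objects themselves stay untouched; only a throwaway array of their coordinates is built for the tree.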

Upvotes: 2

Views: 3367

Answers (2)

alkasm

Reputation: 23002

The docs for scipy.spatial.KDTree state that the data parameter should be array_like, which generally means "convertible to a numpy array." And indeed, the first line of __init__ converts the data to a numpy array, as you can see in the source code:

class KDTree(object):
    """ ... """
    def __init__(self, data, leafsize=10):
        self.data = np.asarray(data)

So what you want is an object such that a list of them converts cleanly into a numpy array. That is hard to define exactly, because numpy tries many ways to turn your object into an array. However, an iterable containing sequences that all have the same length definitely qualifies.
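Because everything hinges on that np.asarray call, one way to check whether a list of your objects qualifies is to run the same conversion yourself and inspect the result before building the tree. as_point_array below is a hypothetical helper, not part of scipy; it assumes the points should end up as a 2-D numeric array of shape (n_points, n_dims).

```python
import numpy as np


def as_point_array(objects):
    """Hypothetical helper: apply the same np.asarray conversion KDTree uses
    and check that it produced a 2-D (n_points, n_dims) numeric array."""
    data = np.asarray(objects)
    if data.ndim != 2 or not np.issubdtype(data.dtype, np.number):
        raise ValueError(f"expected an (n, d) numeric array, got shape "
                         f"{data.shape} with dtype {data.dtype}")
    return data


print(as_point_array([(1, 1), (2, 2), (3, 3)]).shape)  # (3, 2)
print(as_point_array([[1, 1], [2, 2], [3, 3]]).shape)  # (3, 2)
```

Objects that numpy cannot treat as sequences typically end up as a 1-D array of dtype object, which this check rejects.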

Your Unit object is basically a sequence, since it implements __len__ and __getitem__ and is indexed with consecutive integers starting at 0. When a class has no __iter__, Python iterates over it by calling __getitem__ with 0, 1, 2, ... and treats an IndexError as the end of the sequence. Your __getitem__ raises a plain Exception on a bad index instead, so that mechanism breaks at index 2. Raise an IndexError instead, and the conversion works:

class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        # IndexError (not a bare Exception) tells Python's legacy
        # iteration protocol that the sequence has ended
        raise IndexError('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2

Now we can check that a list of these converts into a numpy array without problems:

In [5]: np.array([Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)])
Out[5]:
array([[1, 1],
       [2, 2],
       [3, 3],
       [4, 4],
       [5, 5]])

So initializing the KDTree should now work. This is also why you'd be fine if you stored the coordinates in an internal list and delegated __getitem__ to it, or simply represented the coordinates as a plain sequence such as a list or tuple: those already raise IndexError for an out-of-range index.
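Putting it together, here is a sketch of the fixed class going through the whole KDTree round trip, using the same points and query as the question:

```python
from scipy import spatial


class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        # IndexError ends legacy __getitem__ iteration cleanly
        raise IndexError('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2


print(list(Unit(1, 2)))  # [1, 2]: iteration now stops cleanly at index 2

points = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]
tree = spatial.KDTree(points)  # np.asarray(points) now yields a (5, 2) array
distances, indices = tree.query(Unit(6, 6), 3)
print(indices)  # [4 3 2]
```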

For simple classes like this, an easier approach would be a namedtuple or something similar; for more complex objects, making them behave as sequences is a good way to go.
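A minimal namedtuple version, assuming the objects only need to carry their coordinates (a namedtuple is immutable, and extra behavior would have to be added by subclassing it):

```python
from collections import namedtuple

from scipy import spatial

# A namedtuple is already a tuple, so numpy converts a list of them directly,
# while the fields stay accessible by name (p.x, p.y).
Unit = namedtuple('Unit', ['x', 'y'])

points = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]
tree = spatial.KDTree(points)

distances, indices = tree.query(Unit(6, 6), 3)
print([points[i] for i in indices])
# prints: [Unit(x=5, y=5), Unit(x=4, y=4), Unit(x=3, y=3)]
```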

Upvotes: 3

Horace

Reputation: 1054

The KDTree class probably needs to access slices of the object, but with your definition it is impossible to use a slice (try Unit(6, 6)[:]; it raises the same error).

One way to deal with this is to hold the x and y values in a list and delegate __getitem__ to it:

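from scipy import spatial

class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        # Indexing and slicing are delegated to a real list, which raises
        # IndexError for out-of-range indices. Note: reassigning self.x or
        # self.y later does not update self.data.
        self.data = [x, y]

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return 2

points = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]
tree = spatial.KDTree(points)
result = tree.query(Unit(6, 6), 3)

print(result)
# (array([1.41421356, 2.82842712, 4.24264069]), array([4, 3, 2]))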
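Since the question asks for the k nearest neighbors of every object in the set, the same tree can be queried with all the points at once: KDTree.query accepts an array of points and returns one row of distances and indices per query point. In this sketch each point finds itself at distance 0, so it asks for k + 1 neighbors and drops the first column. That assumes no two objects share the same coordinates; with duplicates, a different object at distance 0 could come first. The order among equally distant neighbors is also not something to rely on.

```python
import numpy as np
from scipy import spatial


class Unit:
    """List-backed Unit, as in the answer above."""
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.data = [x, y]

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return 2


points = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]
tree = spatial.KDTree(points)

k = 2
# One query row per point; column 0 is the point itself (distance 0), assuming
# no duplicate coordinates, so request k + 1 neighbours and drop that column.
distances, indices = tree.query(points, k=k + 1)
neighbours = {
    i: [points[j] for j in row[1:]]  # map indices back to Unit objects
    for i, row in enumerate(indices)
}
for i, nbrs in neighbours.items():
    print((points[i].x, points[i].y), '->', [(u.x, u.y) for u in nbrs])
```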

Upvotes: 1
