Reputation: 25
I need to find the k nearest neighbors for each object in a set. Each object has its coordinates as properties.
To solve the task, I am trying to use spatial.KDTree from scipy. It works fine if I use a list or tuple to represent a point, but it doesn't work for objects.
I implemented the __getitem__ and __len__ methods in my class, but the KDTree implementation asks my objects for a non-existent coordinate axis (say, for the 3rd coordinate of a 2-dimensional point).
Here is a simple script to reproduce the problem:
from scipy import spatial

class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        else:
            raise Exception('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2

#points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
#points = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
#result = tree.query((6,6), 3)
result = tree.query(Unit(6,6), 3)
print(result)
It is not necessary for me to use this specific implementation or library or even algorithm, but the requirement is to deal with objects.
P.S. I could add an id field to each object and move all coordinates into a separate array where the index is the object id. But I still want to avoid that approach if possible.
Upvotes: 2
Views: 3367
Reputation: 23002
The docs for scipy.spatial.KDTree state that the data parameter should be array_like, which generally means "convertible to a numpy array." And indeed, the first line of initialization tries to convert the data to a numpy array, as you can see in the source code:
class KDTree(object):
    """ ... """
    def __init__(self, data, leafsize=10):
        self.data = np.asarray(data)
So what you want is an object such that a list of them converts nicely to a numpy array. This is hard to define exactly, as numpy tries many ways to make your object into an array. However, an iterable containing many sequences of the same length definitely qualifies.
Your Unit object is basically a sequence, since it implements __len__ and __getitem__ and indexes with sequential integers starting at 0. When a class has no __iter__, Python iterates over it by calling __getitem__ with 0, 1, 2, ... and knows the sequence has ended when it hits an IndexError. But your __getitem__ raises a plain Exception on a bad index instead, so the normal mechanism for providing sequential iteration from those two methods breaks. Instead, raise an IndexError, and you'll convert nicely:
class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        raise IndexError('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2
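The end-of-sequence behaviour can be seen without numpy at all. The following is a minimal sketch in plain Python; BrokenUnit is a hypothetical name used only here to stand for the version from the question that raises a generic Exception.

```python
class Unit:
    """2-D point that behaves like a fixed-length sequence."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        # IndexError is the signal that the sequence has ended.
        raise IndexError('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2


class BrokenUnit(Unit):
    """Hypothetical variant that mimics the class from the question."""

    def __getitem__(self, index):
        if index in (0, 1):
            return super().__getitem__(index)
        # A generic Exception is not treated as "end of sequence".
        raise Exception('Unit coordinates are 2 dimensional')


# No __iter__ is defined, so Python falls back to calling
# __getitem__(0), __getitem__(1), ... until IndexError is raised.
coords = list(Unit(3, 4))
x, y = Unit(3, 4)

# With the generic Exception, the same iteration propagates the error.
try:
    list(BrokenUnit(3, 4))
    error = None
except Exception as exc:
    error = exc
```

Here list() asks for index 2 after the two real coordinates, which is the same "3rd coordinate of a 2-dimensional point" request described in the question.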
Now we can check that a list of these converts into a numpy array with no problems:
In [5]: np.array([Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)])
Out[5]:
array([[1, 1],
       [2, 2],
       [3, 3],
       [4, 4],
       [5, 5]])
So, we should have no problem initializing the KDTree now. This is why, if you stored the coords in an internal list and just deferred __getitem__ to that list, or simply treated your coords as some simple sequence like a list or tuple, you'd be fine.
An easier method with simple classes like this would be using namedtuples or similar, but for more complex objects, turning them into sequences is a good way to go.
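As a rough sketch of the namedtuple route, assuming the object only needs to carry x and y (the field names are taken from the question's Unit class):

```python
from collections import namedtuple

# A namedtuple is a real tuple, so it already supports len(), indexing,
# slicing, and iteration, while still exposing .x and .y attributes.
Unit = namedtuple('Unit', ['x', 'y'])

points = [Unit(1, 1), Unit(2, 2), Unit(3, 3)]
first = points[0]

by_name = (first.x, first.y)      # attribute access
by_index = (first[0], first[1])   # sequence access
as_slice = first[:]               # slicing returns a plain tuple
```

Because each point is an ordinary tuple underneath, a list of them should be accepted by spatial.KDTree in the same way as the list of plain tuples that the question reports as working.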
Upvotes: 3
Reputation: 1054
The class probably needs to access slices of the object. But with your definition, it is impossible to use a slice (try Unit(6, 6)[:], it will raise the same exception).
One way to deal with this is to hold the x and y variables in a list:
from scipy import spatial

class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        # Keep the coordinates in a list so indexing, slicing, and the
        # IndexError on a bad index all come from the list itself.
        self.data = [x, y]

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return 2

points = [Unit(1,1), Unit(2,2), Unit(3,3), Unit(4,4), Unit(5,5)]
tree = spatial.KDTree(points)
result = tree.query(Unit(6,6), 3)
print(result)
(array([1.41421356, 2.82842712, 4.24264069]), array([4, 3, 2]))
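The second array in that output holds positions in the list that was passed to KDTree, so the neighbors can be mapped back to the original objects without adding an id field. A small sketch, assuming the points list is kept around and its order is not changed after the tree is built (nearest_units is a name introduced here for illustration):

```python
from scipy import spatial


class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.data = [x, y]

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


points = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]
tree = spatial.KDTree(points)

# query returns (distances, indices); the indices refer to positions
# in the sequence the tree was built from.
distances, indices = tree.query(Unit(6, 6), 3)

# Look the objects up by position to get Unit instances back.
nearest_units = [points[i] for i in indices]
```

With the data above, nearest_units holds the Unit objects at positions 4, 3, and 2, matching the indices printed in the output.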
Upvotes: 1