Reputation: 68
I implemented a basic nearest-neighbors search for my coursework. The plain NumPy implementation works well, but as soon as I add the '@jit' decorator (compiling with Numba), the outputs differ: some neighbors are duplicated at the end, for no reason I can see.
Here is the basic algorithm:
import numpy as np
from numba import jit

@jit(nopython=True)
def knn(p, points, k):
    '''Find the k nearest neighbors (brute force) of the point p
    in the list points (each row is a point)'''
    n = p.size  # Length of each point (number of coordinates)
    M = points.shape[0]  # Number of points
    neighbors = np.zeros((k, n))
    distances = 1e6 * np.ones(k)
    for i in range(M):
        d = 0
        pt = points[i, :]  # Point to compare
        for r in range(n):  # For each coordinate
            aux = p[r] - pt[r]
            d += aux * aux
        if d < distances[k-1]:  # We find a new neighbor
            pos = k-1
            while pos > 0 and d < distances[pos-1]:  # Find the position
                pos -= 1
            pt = points[i, :]
            # Insert neighbor and distance:
            neighbors[pos+1:, :] = neighbors[pos:-1, :]
            neighbors[pos, :] = pt
            distances[pos+1:] = distances[pos:-1]
            distances[pos] = d
    return neighbors, distances
For testing:
p = np.random.rand(10)
points = np.random.rand(250, 10)
k = 5
neighbors, distances = knn(p, points, k)
WITHOUT the @jit decorator, one gets the correct answer:
In [1]: distances
Out[1]: array([ 0.3933974 , 0.44754336, 0.54548715, 0.55619749, 0.5657846 ])
But the Numba compilation gives weird outputs:
In [2]: distances
Out[2]: array([ 0.3933974 , 0.44754336, 0.54548715, 0.54548715, 0.54548715])
Can somebody help? I don't understand why this happens.
Thank you.
Upvotes: 4
Views: 199
Reputation: 68682
I believe the issue is that Numba handles assignment between overlapping slices differently from plain NumPy. I'm not familiar with the internals of NumPy, but it presumably has special logic for the case where the source and destination share memory, and that logic does not seem to be there in Numba. Here neighbors[pos+1:, :] and neighbors[pos:-1, :] overlap, so a plain front-to-back copy would overwrite values before they are read, which would explain the repeated trailing entries. Copy the source slice before assigning, as in the following lines, and the results with the jit decorator become consistent with the plain Python version:
neighbors[pos+1:, :] = neighbors[pos:-1, :].copy()
...
distances[pos+1:] = distances[pos:-1].copy()
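To illustrate what I think is going on, here is a small pure-NumPy sketch. The helper names shift_right_naive and shift_right_safe are made up for this example, and the element-by-element loop is only my assumption about how the compiled code behaves, not something I have checked in Numba's source:

```python
import numpy as np

def shift_right_naive(a, pos):
    # Hypothetical emulation of a front-to-back copy with no overlap check:
    # each element is overwritten before it has been read as a source.
    out = a.copy()
    for j in range(pos, len(out) - 1):
        out[j + 1] = out[j]
    return out

def shift_right_safe(a, pos):
    # Copy the source slice first, as in the fix above.
    out = a.copy()
    out[pos + 1:] = out[pos:-1].copy()
    return out

d = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
pos = 2

# The destination and source slices share memory.
overlaps = np.shares_memory(d[pos + 1:], d[pos:-1])

naive = shift_right_naive(d, pos)  # value at pos smeared to the end
safe = shift_right_safe(d, pos)    # values shifted right by one slot

print(overlaps)
print(naive)
print(safe)
```

The naive version repeats the value at pos all the way to the end of the array, which is the same pattern as the duplicated 0.54548715 in the question's output. The copy-based version shifts the tail by one slot, leaving index pos free for the new distance.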
Upvotes: 1