I am given two NumPy arrays: one of dimensions i x m and the other of dimensions j x m. What I want to do is loop through FirstArray and compare each of its rows with each of the rows of SecondArray. By 'compare' I mean computing the Euclidean distance between a row of FirstArray and a row of SecondArray. Then I want to store the index of the row of SecondArray that is closest to the corresponding row of FirstArray, and also the index of the row of SecondArray that is second closest to it.
In code this would look somewhat similar to this:
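import numpy as np
FirstArrayRows, SecondArrayRows = FirstArray.shape[0], SecondArray.shape[0]
Closest = np.zeros(FirstArrayRows, dtype=int)
SecondClosest = np.zeros(FirstArrayRows, dtype=int)
for i in range(FirstArrayRows):
    smallest = None  # smallest distance seen so far for row i
    idx = 0          # index of the closest row of SecondArray so far
    for j in range(SecondArrayRows):
        EuclideanDistance = np.sqrt(np.sum(np.square(FirstArray[i, :] - SecondArray[j, :])))
        if smallest is None or EuclideanDistance < smallest:
            smallest = EuclideanDistance
            idx_second = idx  # the previous closest becomes the second closest
            idx = j
    Closest[i] = idx
    SecondClosest[i] = idx_second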
And I think this works. However, there are two cases when this code fails to give the correct index for the second closest element of SecondArray:
So I wonder: Is there a better way of implementing this? I know there is. Maybe someone can help me see it?
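If you want to keep the explicit double loop, one way to make the runner-up bookkeeping robust is to track the best and second-best distances separately, so that a second-closest row that appears after the closest one is still caught. This is a minimal sketch, not the asker's code; `two_nearest_loop` is a hypothetical helper name, and it assumes SecondArray has at least two rows. The sample arrays are the ones printed in the answer below.

```python
import numpy as np

def two_nearest_loop(first_arr, second_arr):
    """Hypothetical helper: for each row of first_arr, return the indices of the
    closest and second-closest rows of second_arr (needs at least 2 rows)."""
    n = first_arr.shape[0]
    closest = np.empty(n, dtype=int)
    second_closest = np.empty(n, dtype=int)
    for i in range(n):
        best_d = second_d = np.inf  # reset for every row of first_arr
        best_j = second_j = -1
        for j in range(second_arr.shape[0]):
            d = np.sqrt(np.sum(np.square(first_arr[i, :] - second_arr[j, :])))
            if d < best_d:
                # The old best is demoted to runner-up before being replaced.
                second_d, second_j = best_d, best_j
                best_d, best_j = d, j
            elif d < second_d:
                # Catches a runner-up that appears after the current best.
                second_d, second_j = d, j
        closest[i], second_closest[i] = best_j, second_j
    return closest, second_closest

a = np.array([[3, 9, 0, 2, 2], [1, 2, 9, 9, 7], [4, 0, 6, 6, 4]])
b = np.array([[9, 9, 2, 2, 3], [9, 9, 0, 2, 3], [1, 1, 6, 7, 7], [5, 7, 0, 4, 4]])
c, s = two_nearest_loop(a, b)
print(c, s)  # → [3 2 2] [1 3 3]
```

Starting both distances at `np.inf` avoids the special case where the closest row is at index 0, because the runner-up is never seeded from an index that was not actually compared.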
You could use NumPy's broadcasting to your advantage: compute the Euclidean distances from a row of the first array to all rows of the second array in a single operation. Then you can find the indices of the two smallest distances with np.argpartition.
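One detail worth knowing about np.argpartition: it only guarantees that the element at position kth is in its sorted place, with smaller-or-equal values before it in no particular order. Passing kth=1 therefore puts the closest index first and the second closest next, whereas kth=2 only guarantees you get the two nearest as an unordered pair. The distances below are made-up values for illustration.

```python
import numpy as np

# Hypothetical distances from one row of the first array to five rows of the second.
dist = np.array([5.0, 1.0, 4.0, 0.5, 3.0])

# kth=1: position 1 holds the 2nd-smallest value and position 0 holds a value <= it,
# so the first two indices are (closest, second closest) in that order.
two = np.argpartition(dist, 1)[:2]
print(two)  # → [3 1]

# A full sort gives the same pair, at O(j log j) rather than O(j) average cost.
print(np.argsort(dist)[:2])  # → [3 1]

# kth=2 only fixes position 2; the first two entries are still the two smallest
# distances, but their relative order is not guaranteed.
print(set(np.argpartition(dist, 2)[:2].tolist()))  # → {1, 3}
```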
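import numpy as np
i, j, m = 3, 4, 5
a = np.random.choice(10, (i, m))
b = np.random.choice(10, (j, m))
print('First array:\n', a)
print('Second array:\n', b)
closest, second_closest = np.zeros(i), np.zeros(i)
for row in range(a.shape[0]):
    # Broadcasting: a[row, :] has shape (m,) and b has shape (j, m), so the
    # difference is (j, m) and dist holds the distance to every row of b.
    dist = np.sqrt(((a[row, :] - b)**2).sum(axis=1))
    # kth=1 leaves the smallest distance at position 0 and the second smallest
    # at position 1, so the pair comes out in the right order.
    closest[row], second_closest[row] = np.argpartition(dist, 1)[:2]
print('Closest:', closest)
print('Second Closest:', second_closest)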
Output:
First array:
[[3 9 0 2 2]
[1 2 9 9 7]
[4 0 6 6 4]]
Second array:
[[9 9 2 2 3]
[9 9 0 2 3]
[1 1 6 7 7]
[5 7 0 4 4]]
Closest: [3. 2. 2.]
Second Closest: [1. 3. 3.]
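The loop over the rows of the first array can also be removed by broadcasting in both directions at once. This is a sketch that reuses the arrays from the output above; it computes the full i x j distance matrix and then partitions each row of that matrix, returning integer indices directly.

```python
import numpy as np

a = np.array([[3, 9, 0, 2, 2], [1, 2, 9, 9, 7], [4, 0, 6, 6, 4]])
b = np.array([[9, 9, 2, 2, 3], [9, 9, 0, 2, 3], [1, 1, 6, 7, 7], [5, 7, 0, 4, 4]])

# a[:, None, :] has shape (i, 1, m) and b[None, :, :] has shape (1, j, m);
# their difference broadcasts to (i, j, m), and the norm over the last axis
# gives the (i, j) matrix of pairwise Euclidean distances.
dist = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=2)

# Partition each row independently; kth=1 keeps the two nearest in sorted order.
nearest_two = np.argpartition(dist, 1, axis=1)[:, :2]
closest, second_closest = nearest_two[:, 0], nearest_two[:, 1]
print('Closest:', closest)                # → Closest: [3 2 2]
print('Second Closest:', second_closest)  # → Second Closest: [1 3 3]
```

The trade-off is memory: the intermediate difference array has i * j * m entries, so for large inputs the per-row loop in the answer may be the more practical choice.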