Reputation: 2489
I have two ndarrays where each row is a 3D point, and one array is much bigger than the other, e.g.:
array([[1., 2., 3.],
       [2.01, 5., 1.],
       [3., 3., 4.],
       [1., 4., 1.],
       [3., 6., 7.01]])

array([[3.02, 3.01, 4.0],
       [1.01, 1.99, 3.01],
       [2.98, 6.01, 7.01]])
I know that each point in the second array corresponds to a point in the first array.
I would like to get the list of indices of the corresponding points.
For this example it would be
array([2,0,4])
as the first point in the second array is similar to the third point in the first array, the second point in the second array is similar to the first point in the first array, etc.
Upvotes: 2
Views: 117
Reputation: 11602
You can do this efficiently with a KDTree.
import numpy as np
from scipy.spatial import KDTree
x = np.array([[1., 2., 3.],
              [2.01, 5., 1.],
              [3., 3., 4.],
              [1., 4., 1.],
              [3., 6., 7.01]])
y = np.array([[1.01, 1.99, 3.01],
              [3.02, 3.01, 4.0],
              [2.98, 6.01, 7.01]])
result = KDTree(x).query(y)[1]
# In [16]: result
# Out[16]: array([0, 2, 4])
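Since query() also returns the distance to each nearest neighbour, the result can be sanity-checked against a tolerance. The following is only a minimal sketch: the helper name match_indices and the tol value are assumptions chosen for this sample data, not part of the answer above. It uses the arrays in the order given in the question.

```python
import numpy as np
from scipy.spatial import KDTree

def match_indices(big, small, tol):
    """For each row of `small`, return the index of its nearest row in `big`.

    Hypothetical helper: raises if any nearest neighbour is farther than `tol`.
    """
    # query() returns (distances, indices) of the nearest neighbours
    dist, idx = KDTree(big).query(small)
    if np.any(dist > tol):
        raise ValueError("some points have no match within tol")
    return idx

x = np.array([[1., 2., 3.],
              [2.01, 5., 1.],
              [3., 3., 4.],
              [1., 4., 1.],
              [3., 6., 7.01]])
y = np.array([[3.02, 3.01, 4.0],
              [1.01, 1.99, 3.01],
              [2.98, 6.01, 7.01]])

print(match_indices(x, y, tol=0.05))  # [2 0 4]
```

The tolerance here is a Euclidean distance, so it should be somewhat larger than the largest per-coordinate deviation you expect.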
Thanks to Divakar for pointing out that scipy also provides a C implementation of KDTree, called cKDTree. It is roughly 10x faster for the following benchmark:
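import numpy as np
from scipy.spatial import KDTree, cKDTree

x = np.random.rand(100_000, 3)
y = np.random.rand(100, 3)

def benchmark(TreeClass):
    # build the tree on the big array, then look up each row of y
    return TreeClass(x).query(y)[1]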
In [23]: %timeit q.benchmark(KDTree)
322 ms ± 7.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [24]: %timeit q.benchmark(cKDTree)
36.5 ms ± 763 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
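To rerun the comparison outside IPython, something like the following standalone script could be used. It is a sketch, not the exact setup behind the timings above: the seed and the number of repetitions are arbitrary choices. Note that, according to the SciPy documentation, KDTree and cKDTree are functionally identical from SciPy 1.6.0 onwards, so on a recent installation the gap may not reproduce.

```python
import timeit

import numpy as np
from scipy.spatial import KDTree, cKDTree

rng = np.random.default_rng(0)  # fixed seed, chosen arbitrarily
x = rng.random((100_000, 3))
y = rng.random((100, 3))

def benchmark(TreeClass):
    # build the tree on the big array, then look up each row of y
    return TreeClass(x).query(y)[1]

# both classes perform an exact nearest-neighbour search
same = np.array_equal(benchmark(KDTree), benchmark(cKDTree))
print("identical indices:", same)

for cls in (KDTree, cKDTree):
    seconds = timeit.timeit(lambda: benchmark(cls), number=3) / 3
    print(f"{cls.__name__}: {seconds * 1e3:.1f} ms per call")
```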
Upvotes: 4
Reputation: 221704
We can extend one of the arrays to 3D (a[:,None]) so that it broadcasts against the other, then compare for closeness within a given tolerance (which for the given sample seems to be about 0.02) using np.isclose() or np.abs() < tolerance. Finally, require ALL coordinates to match along the last axis and get the indices of the matching pairs:
In [88]: a
Out[88]:
array([[1.  , 2.  , 3.  ],
       [2.01, 5.  , 1.  ],
       [3.  , 3.  , 4.  ],
       [1.  , 4.  , 1.  ],
       [3.  , 6.  , 7.01]])
In [89]: b
Out[89]:
array([[3.02, 3.01, 4.  ],
       [1.01, 1.99, 3.01],
       [2.98, 6.01, 7.01]])
In [90]: r,c = np.nonzero(np.isclose(a[:,None],b, atol=0.02).all(2))
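In [91]: r[np.argsort(c)]  # order by index into b; plain r[c] is only correct when c is its own inverse permutation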
Out[91]: array([2, 0, 4])
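The same idea can be wrapped in a small function that returns one index per row of b directly. This is a sketch under assumptions: the name match_by_tolerance is hypothetical, it uses the np.abs() < tolerance variant mentioned above, and atol=0.03 is picked slightly above 0.02 so that floating-point rounding in differences such as 3.02 - 3.0 does not cause a miss.

```python
import numpy as np

def match_by_tolerance(a, b, atol):
    """For each row of `b`, return the index of the first row of `a`
    whose coordinates all lie within `atol` of it (hypothetical helper)."""
    # diff has shape (len(a), len(b), 3) thanks to broadcasting
    diff = np.abs(a[:, None, :] - b[None, :, :])
    # mask[i, j] is True when every coordinate of a[i] is within atol of b[j]
    mask = (diff <= atol).all(axis=2)
    if not mask.any(axis=0).all():
        raise ValueError("some rows of b have no match in a")
    # argmax over axis 0 picks the first True entry in each column
    return mask.argmax(axis=0)

a = np.array([[1., 2., 3.],
              [2.01, 5., 1.],
              [3., 3., 4.],
              [1., 4., 1.],
              [3., 6., 7.01]])
b = np.array([[3.02, 3.01, 4.0],
              [1.01, 1.99, 3.01],
              [2.98, 6.01, 7.01]])

print(match_by_tolerance(a, b, atol=0.03))  # [2 0 4]
```

The intermediate array holds len(a) * len(b) * 3 values, so this approach suits cases where that product fits comfortably in memory.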
Upvotes: 2