Reputation: 989
I want to find rows of a numpy array that are members of a set. For example:
wanted=set([(1,2),(8,9)])
z=np.array([[1,2],[8,8],[2,3]])
The result should be [1,2].
I could use a list comprehension:
[b for b in z if tuple(b) in wanted]
but this is slow when z has many rows and columns. Is there a faster way to do this?
Thank You!
Upvotes: 0
Views: 59
Reputation: 221574
One vectorized approach would be to:
Convert the set wanted to a NumPy array with map() and np.vstack.
Extend the dimensions of the NumPy array version of wanted with None/np.newaxis to form a 3D array, and compare it against z, bringing in broadcasting.
Check for ALL True rows and ANY True first-axis match, giving us a mask that can be used to index into z for the final selection.
Implementation -
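import numpy as np
wanted_arr = np.vstack(list(map(np.array, wanted)))  # list() is needed on Python 3, where map() returns an iterator
out = z[(wanted_arr[:, None] == z).all(2).any(0)]  # keep rows of z equal to any row of wanted_arr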
Sample run -
In [64]: z
Out[64]:
array([[1, 2],
       [8, 8],
       [2, 3]])
In [65]: wanted
Out[65]: {(1, 2), (8, 9)}
In [66]: wanted_arr = np.vstack(list(map(np.array, wanted)))
In [67]: wanted_arr
Out[67]:
array([[1, 2],
       [8, 9]])
In [68]: z[((wanted_arr[:,None] == z).all(2)).any(0)]
Out[68]: array([[1, 2]])
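To make the broadcasting easier to follow, here is a sketch that unpacks the one-liner into its intermediate arrays, using the data from the question. The names eq, row_match and mask are only illustrative and are not part of the original answer.

```python
import numpy as np

# Data from the question.
wanted = {(1, 2), (8, 9)}
z = np.array([[1, 2], [8, 8], [2, 3]])

# Set of tuples -> 2D array of shape (n_wanted, n_cols).
# The row order follows set iteration order, which does not affect the result.
wanted_arr = np.vstack(list(map(np.array, wanted)))

# Add an axis so the shapes broadcast:
# (n_wanted, 1, n_cols) == (n_rows, n_cols) -> (n_wanted, n_rows, n_cols)
eq = wanted_arr[:, None] == z

# A row of z matches a wanted row only if every column is equal.
row_match = eq.all(axis=2)      # shape (n_wanted, n_rows)

# Keep a row of z if it matches any of the wanted rows.
mask = row_match.any(axis=0)    # shape (n_rows,)

out = z[mask]
print(eq.shape, row_match.shape, mask.shape)  # (2, 3, 2) (2, 3) (3,)
print(out)                                    # [[1 2]]
```

Note that the intermediate array eq holds n_wanted * n_rows * n_cols booleans, so memory use grows with the product of the sizes of wanted and z.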
Upvotes: 2