user3433489

Reputation: 989

Find rows of numpy array that are members of a set

I want to find rows of a numpy array that are members of a set. For example:

wanted=set([(1,2),(8,9)])

z=np.array([[1,2],[8,8],[2,3]])

The result should be [1,2], the only row of z that is in wanted.

I could use a list comprehension:

[b for b in z if tuple(b) in wanted]

but this is slow when z has many rows and columns. Is there a faster way to do this?

Thank you!

Upvotes: 0

Views: 59

Answers (1)

Divakar

Reputation: 221574

One vectorized approach:

  • Convert the set wanted to a 2D NumPy array with map() and np.vstack.

  • Add a new axis to that array with None/np.newaxis, making it 3D, and compare it against z. Broadcasting then compares every row of wanted with every row of z.

  • Reduce with .all() along the last axis to find full-row matches, then with .any() along the first axis to flag each row of z that matches any row of wanted. The resulting mask indexes into z for the final selection.

Implementation -

# list() is needed on Python 3, where map() returns an iterator that np.vstack does not accept
wanted_arr = np.vstack(list(map(np.array,wanted)))
out = z[((wanted_arr[:,None] == z).all(2)).any(0)]
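To make the broadcasting easier to follow, here is a minimal sketch that splits the one-liner into its intermediate arrays, using the sample data from the question. The names eq, row_match and mask are only illustrative, and sorted() is used just to make the row order of wanted_arr deterministic; neither is part of the original answer.

```python
import numpy as np

wanted = {(1, 2), (8, 9)}
z = np.array([[1, 2], [8, 8], [2, 3]])

# Sorting is only for a reproducible row order; a set has no guaranteed order.
wanted_arr = np.array(sorted(wanted))   # shape (2, 2)

# wanted_arr[:, None] has shape (2, 1, 2); comparing with z (3, 2)
# broadcasts to (2, 3, 2): every wanted row against every row of z.
eq = wanted_arr[:, None] == z

# A row of z equals a wanted row only if all of its columns match.
row_match = eq.all(axis=2)              # shape (2, 3)

# Keep a row of z if it matches any of the wanted rows.
mask = row_match.any(axis=0)            # shape (3,)

out = z[mask]
print(out)
```

Note that the intermediate boolean array holds len(wanted) * len(z) * z.shape[1] elements, so memory use grows with the product of both sizes.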

Sample run -

In [64]: z
Out[64]: 
array([[1, 2],
       [8, 8],
       [2, 3]])

In [65]: wanted
Out[65]: {(1, 2), (8, 9)}

In [66]: wanted_arr = np.vstack(list(map(np.array,wanted)))

In [67]: wanted_arr
Out[67]: 
array([[1, 2],
       [8, 9]])

In [68]: z[((wanted_arr[:,None] == z).all(2)).any(0)]
Out[68]: array([[1, 2]])
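As a sanity check, the sketch below wraps the approach in a function and compares it with the list comprehension from the question on random data. The function name rows_in_set and the test data are hypothetical, and the sketch assumes wanted is non-empty and that every tuple has as many elements as z has columns.

```python
import numpy as np

def rows_in_set(z, wanted):
    """Return the rows of z that appear in wanted (hypothetical helper).

    Assumes wanted is a non-empty set of equal-length tuples whose
    length matches z.shape[1].
    """
    wanted_arr = np.array(list(wanted))
    mask = (wanted_arr[:, None] == z).all(axis=2).any(axis=0)
    return z[mask]

# Made-up test data: 1000 rows with values in 0..4.
rng = np.random.default_rng(0)
z = rng.integers(0, 5, size=(1000, 2))
wanted = {(1, 2), (3, 4), (0, 0)}

fast = rows_in_set(z, wanted)
slow = np.array([b for b in z if tuple(b) in wanted])

# Both approaches keep matching rows in their original order.
print(np.array_equal(fast, slow))
```

Whether this is actually faster than the loop depends on the sizes involved, so it is worth timing both on your own data.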

Upvotes: 2
