Shew

Reputation: 1596

Numpy where with array comparison

I have an array called Y that contains class labels. I want to find all the indices of Y whose value matches any of several values given in a list, labs.

In this case:

import numpy as np

Y = np.array([1,2,3,1,2,3,1,2,3,1,2,3])
labs = [2,3]

How can I do something like np.where(Y == labs) so that it returns

array([1,2,4,5,7,8,10,11])

I know one possibility is to iterate through the list labs and do an element-wise comparison for each label, but I am looking for a more Pythonic, NumPy-based solution that avoids the loop.
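For reference, the loop-based approach described above might look like the following minimal sketch (the names mask and idx are illustrative, not from the question). It builds a boolean mask by OR-ing one comparison per label, then converts the mask to indices:

```python
import numpy as np

Y = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
labs = [2, 3]

# Start with an all-False mask and OR in one element-wise comparison per label.
mask = np.zeros(Y.shape, dtype=bool)
for lab in labs:
    mask |= (Y == lab)

# np.where on a 1-D mask returns a 1-tuple of index arrays; [0] unpacks it.
idx = np.where(mask)[0]
print(idx)  # [ 1  2  4  5  7  8 10 11]
```

The loop runs once per label, not once per element of Y, so it is already vectorized over Y; the answer below removes the explicit Python loop entirely.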

Upvotes: 1

Views: 332

Answers (1)

willeM_ Van Onsem

Reputation: 476537

You can use np.where(..) [numpy-doc] on the result of np.isin(..) [numpy-doc]:

>>> np.where(np.isin(Y, labs))[0]
array([ 1,  2,  4,  5,  7,  8, 10, 11])

np.isin(Y, labs) gives a boolean array that is True wherever the corresponding element of Y is one of the values in labs:

>>> np.isin(Y, labs)
array([False,  True,  True, False,  True,  True, False,  True,  True,
       False,  True,  True])

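np.where(..) then maps the True entries to their indices. Called with a single argument it returns a tuple with one index array per dimension, so [0] selects the array for the only axis of Y.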
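A self-contained version of the np.isin approach is sketched below. As an alternative not mentioned in the answer, np.flatnonzero returns the indices of the nonzero (True) entries of the flattened input, which for a 1-D mask is the same as np.where(mask)[0]:

```python
import numpy as np

Y = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
labs = [2, 3]

mask = np.isin(Y, labs)          # True where Y[i] is one of the values in labs
idx = np.where(mask)[0]          # unpack the 1-tuple returned by np.where
idx_flat = np.flatnonzero(mask)  # same indices, no tuple to unpack

# Both forms agree for a 1-D array like Y.
assert np.array_equal(idx, idx_flat)
```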

As @hpaulj says, when labs is small, we can also write the mask as:

np.any([Y == li for li in labs], axis=0)

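Here, for each element li in labs, Y == li checks which elements of Y equal li, and np.any(..) with axis=0 acts as a chain of logical ORs across those comparisons, folding them into a single boolean mask. Passing that mask to np.where(..)[0] gives the same indices as before.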
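The sketch below checks that the list-comprehension mask matches np.isin. It also shows a broadcasting variant (an addition, not part of the answer): comparing Y reshaped to a column (n, 1) against labs as a row (k,) yields an (n, k) table of comparisons, which .any(axis=1) reduces per element of Y. That table has n * k entries, so like the list comprehension it suits a small labs:

```python
import numpy as np

Y = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
labs = [2, 3]

# Answer's version: one boolean row per label, OR-ed together over axis 0.
mask_any = np.any([Y == li for li in labs], axis=0)

# Broadcasting variant: (n, 1) == (k,) -> (n, k), then OR across each row.
mask_bc = (Y[:, None] == np.asarray(labs)).any(axis=1)

assert np.array_equal(mask_any, mask_bc)
assert np.array_equal(mask_any, np.isin(Y, labs))

idx = np.where(mask_any)[0]
```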

Upvotes: 1
