linkyndy

Reputation: 17900

Numpy condition for getting nested arrays if their max is above a threshold

I have the following array:

arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])

I would like to get the indices of the rows of arr whose max value is greater than or equal to .9. So, for this case, the result would be [1], because the row at index 1, [.9, .1], is the only one whose max value is >= .9.

I tried:

>>> condition = np.max(arr) >= .9
>>> arr[condition]
array([ 0.5,  0.5])

But, as you see, it yields the wrong answer.

Upvotes: 2

Views: 2962

Answers (4)

Alan

Reputation: 9610

Use max along an axis to get each row's maximum, and then np.where to get the indices of the rows whose maximum meets the threshold:

np.where(arr.max(axis=1)>=0.9)

Upvotes: 1

Alex Riley

Reputation: 176750

I think you want np.where here. This function returns the indices of any values which meet a particular condition:

>>> np.where(arr >= 0.9)[0] # here we look at the whole 2D array
array([1])

(np.where(arr >= 0.9) returns a tuple of arrays of indices, one for each axis of the array. Your expected output implies that you only want the row indices (axis 0).)

If you want to take the maximum of each row first, you can use arr.max(axis=1):

>>> np.where(arr.max(axis=1) >= 0.9)[0] # here we look at the 1D array of row maximums
array([1])
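
The two approaches can differ when a row has more than one value at or above the threshold: the element-wise version reports that row once per matching element, while the row-maximum version lists each row once. A minimal sketch, using a made-up array `arr2` that is not from the question:

```python
import numpy as np

# Hypothetical array (not from the question): row 1 has two values >= 0.9
arr2 = np.array([[.5, .5], [.95, .92], [.8, .2]])

# Element-wise test on the 2D array: one row index per matching element
rows_elementwise = np.where(arr2 >= 0.9)[0]

# Test on the 1D array of row maximums: one row index per matching row
rows_by_max = np.where(arr2.max(axis=1) >= 0.9)[0]

print(rows_elementwise)  # [1 1]
print(rows_by_max)       # [1]
```

If you need each row index only once, the row-maximum form is probably the safer choice.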

Upvotes: 3

Mike Bessonov

Reputation: 686

In [18]: arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])

In [19]: numpy.argwhere(numpy.max(arr, 1) >= 0.9)
Out[19]: array([[1]])
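
Note that numpy.argwhere returns a 2D array with one row per match, which is why the output above is [[1]] rather than [1]. If a flat array of indices is wanted instead, numpy.flatnonzero is one option. A small sketch (the variable names `mask`, `idx_2d` and `idx_flat` are only for illustration):

```python
import numpy as np

arr = np.array([[.5, .5], [.9, .1], [.8, .2]])

# Boolean mask with one entry per row: True where the row maximum is >= 0.9
mask = np.max(arr, 1) >= 0.9

idx_2d = np.argwhere(mask)       # shape (n_matches, 1) for a 1D mask
idx_flat = np.flatnonzero(mask)  # flat 1D array of matching indices

print(idx_2d.shape)  # (1, 1)
print(idx_flat)      # [1]
```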

Upvotes: 2

TheBlackCat

Reputation: 10298

The reason you are getting the wrong answer is that np.max(arr) gives you the max of the flattened array, which is a single scalar rather than one value per row. You want np.max(arr, axis=1) or, better yet, arr.max(axis=1).

(arr.max(axis=1)>=.9).nonzero()
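
To make the difference concrete, here is a short sketch comparing the two calls on the array from the question. The variable names are only illustrative. Note that nonzero() returns a tuple with one index array per dimension, so [0] picks out the row indices:

```python
import numpy as np

arr = np.array([[.5, .5], [.9, .1], [.8, .2]])

overall_max = np.max(arr)   # max over the flattened array: a single scalar
row_max = arr.max(axis=1)   # one maximum per row

mask = row_max >= .9        # boolean mask with one entry per row
rows = mask.nonzero()[0]    # indices of the rows that pass the test
selected = arr[mask]        # the matching rows themselves

print(overall_max)  # 0.9
print(row_max)      # [0.5 0.9 0.8]
print(rows)         # [1]
print(selected)     # [[0.9 0.1]]
```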

Upvotes: 1
