Reputation: 17900
I have the following array:
arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])
I would like to get the indices of the rows of arr whose maximum value is greater than or equal to .9. So, for this case, the result would be [1], because the row at index 1, [.9, .1], is the only one whose max value is >= .9.
I tried:
>>> condition = np.max(arr) >= .9
>>> arr[condition]
array([ 0.5, 0.5])
But, as you can see, it gives the wrong answer.
Upvotes: 2
Views: 2962
Reputation: 9610
Use max along axis 1 to get each row's maximum value, and then np.where to get the indices of the rows whose maximum is at least 0.9:
np.where(arr.max(axis=1)>=0.9)
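A minimal runnable sketch of the same idea, wrapped in a small function. The name rows_with_max_at_least and its threshold parameter are hypothetical, introduced only for illustration. Note that np.where on a 1D mask returns a 1-tuple, so the sketch unpacks it to get a plain index array:

```python
import numpy as np

def rows_with_max_at_least(arr, threshold):
    """Hypothetical helper: indices of rows whose maximum is >= threshold."""
    row_max = arr.max(axis=1)                 # one maximum per row, shape (n_rows,)
    (idx,) = np.where(row_max >= threshold)   # np.where on a 1D mask returns a 1-tuple
    return idx

arr = np.array([[.5, .5], [.9, .1], [.8, .2]])
print(rows_with_max_at_least(arr, 0.9))  # [1]
```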
Upvotes: 1
Reputation: 176750
I think you want np.where here. This function returns the indices of any values which meet a particular condition:
>>> np.where(arr >= 0.9)[0] # here we look at the whole 2D array
array([1])
(np.where(arr >= 0.9) returns a tuple of arrays of indices, one for each axis of the array. Your expected output implies that you only want the row indices (axis 0).)
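One caveat with the whole-array version: it returns one row index per matching element, so a row containing several values >= 0.9 shows up more than once. That does not happen with the question's array. The sketch below uses a made-up array (arr2 is hypothetical) to show the repetition and one way to collapse it with np.unique:

```python
import numpy as np

# Hypothetical array: row 0 has two elements >= 0.9, row 2 has one.
arr2 = np.array([[.95, .92], [.1, .2], [.3, .97]])

# Row indices of every matching element, in row-major order.
elementwise_rows = np.where(arr2 >= 0.9)[0]
print(elementwise_rows)   # [0 0 2]  (row 0 is listed once per matching element)

# np.unique sorts and removes the duplicates.
unique_rows = np.unique(elementwise_rows)
print(unique_rows)        # [0 2]
```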
If you want to take the maximum of each row first, you can use arr.max(axis=1):
>>> np.where(arr.max(axis=1) >= 0.9)[0] # here we look at the 1D array of row maximums
array([1])
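For arrays without NaN values, "the row maximum is >= 0.9" is the same test as "at least one element in the row is >= 0.9", so (arr >= 0.9).any(axis=1) gives the same mask. The sketch below checks that equivalence on the question's array and uses np.flatnonzero, which returns a plain 1D index array, so there is no tuple to index with [0]:

```python
import numpy as np

arr = np.array([[.5, .5], [.9, .1], [.8, .2]])

# Without NaNs, a row's max is >= 0.9 exactly when any of its elements is >= 0.9.
mask_from_max = arr.max(axis=1) >= 0.9
mask_from_any = (arr >= 0.9).any(axis=1)
assert (mask_from_max == mask_from_any).all()

# np.flatnonzero returns the indices of the True entries as a 1D array.
print(np.flatnonzero(mask_from_max))  # [1]
```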
Upvotes: 3
Reputation: 686
In [18]: arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])
In [19]: numpy.argwhere(numpy.max(arr, 1) >= 0.9)
Out[19]: array([[1]])
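Note that np.argwhere returns a 2D array with one row per match and one column per dimension of its input, which is why the result above is [[1]] rather than [1]. If a flat list of row indices is wanted, .ravel() flattens it, as in this small sketch:

```python
import numpy as np

arr = np.array([[.5, .5], [.9, .1], [.8, .2]])

hits = np.argwhere(np.max(arr, axis=1) >= 0.9)
print(hits.shape)    # (1, 1): one match, and the mask is 1D
print(hits.ravel())  # [1]
```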
Upvotes: 2
Reputation: 10298
The reason you are getting the wrong answer is that np.max(arr) gives you the max of the flattened array. You want np.max(arr, axis=1) or, better yet, arr.max(axis=1).
(arr.max(axis=1)>=.9).nonzero()
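To make the difference concrete, this sketch prints the single scalar that np.max(arr) produces next to the per-row maxima from arr.max(axis=1), then applies nonzero() to the per-row mask. As with np.where, nonzero() returns a tuple with one index array per dimension, so the 1-tuple is unpacked here:

```python
import numpy as np

arr = np.array([[.5, .5], [.9, .1], [.8, .2]])

overall = np.max(arr)       # max of the flattened array: a single scalar
per_row = arr.max(axis=1)   # one maximum per row
print(overall)              # 0.9
print(per_row)              # [0.5 0.9 0.8]

# nonzero() on the 1D boolean mask returns a 1-tuple of index arrays.
(rows,) = (per_row >= .9).nonzero()
print(rows)                 # [1]
```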
Upvotes: 1