phntm

How can I index each occurrence of a max value along a given axis of a numpy array?

Suppose I have the following numpy array.

import numpy as np
Q = np.array([[0, 1, 1], [1, 0, 1], [0, 2, 0]])

Question: How do I identify the position of every maximum value along axis 1? The desired output would be something like:

[[1, 2], [0, 2], [1]]  # The output does not need to be a NumPy array; a list of lists is fine.

With np.argmax I can identify the first occurrence of the maximum along the axis, but not any later occurrences.

In: np.argmax(Q, axis =1) 
Out: array([1, 0, 1])    

I've also seen answers that use np.argwhere with an expression like this:

np.argwhere(Q == np.amax(Q)) 

This will not work here either: np.amax(Q) returns the single maximum of the whole array, and I can't restrict argwhere to a single axis. I also can't just flatten the array, because the maximum differs from row to row. I need to find every occurrence of each row's maximum.
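A quick check of why the global-maximum approach misses rows: for this Q the overall maximum is 2, which only appears in row 2, while rows 0 and 1 each have a maximum of 1. A small sketch using the Q from the question:

```python
import numpy as np

Q = np.array([[0, 1, 1], [1, 0, 1], [0, 2, 0]])

# np.amax with no axis reduces the whole array to one value (2 here),
# so only the positions equal to 2 are returned; rows 0 and 1 are missed.
global_hits = np.argwhere(Q == np.amax(Q))
print(global_hits)  # [[2 1]]

# The per-row maxima differ, which is why a single global value cannot work.
row_max = Q.max(axis=1)
print(row_max)  # [1 1 2]
```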

Is there a pythonic way to do this without looping over every row of the array, or is there a function analogous to np.argwhere that accepts an axis argument?

Any insight would be appreciated, thanks!


Quang Hoang

Try with np.where:

np.where(Q == Q.max(axis=1)[:,None])

Output:

(array([0, 0, 1, 1, 2]), array([1, 2, 0, 2, 1]))

Not quite the output you want, but it contains the same information: the first array holds the row index and the second the column index of each maximum.
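To turn that into the ragged per-row list the question asks for, the column indices can be split at row boundaries. A minimal sketch; the variable names are illustrative, and it relies on np.where returning indices in row-major order. A row containing NaN would come back as an empty list, since NaN never compares equal to anything:

```python
import numpy as np

Q = np.array([[0, 1, 1], [1, 0, 1], [0, 2, 0]])

# Boolean mask: True wherever an element equals the maximum of its row.
mask = Q == Q.max(axis=1)[:, None]

# np.where yields indices in row-major order, so all column indices for
# row 0 come first, then row 1, and so on. Row indices are not needed.
_, cols = np.where(mask)

# Number of maxima per row; the cumulative sums (minus the last) are the
# positions where one row's group ends and the next begins.
counts = mask.sum(axis=1)
per_row = [c.tolist() for c in np.split(cols, np.cumsum(counts)[:-1])]

print(per_row)  # [[1, 2], [0, 2], [1]]
```

Because the result is ragged, some Python-level work per row is unavoidable; a comprehension such as `[np.flatnonzero(r).tolist() for r in mask]` gives the same result and may read more simply.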

You can also use np.argwhere, which gives the same indices zipped into (row, column) pairs:

np.argwhere(Q==Q.max(axis=1)[:,None])

Output:

array([[0, 1],
       [0, 2],
       [1, 0],
       [1, 2],
       [2, 1]])
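The `[:, None]` reshapes the row maxima from shape (3,) to (3, 1), so broadcasting compares each element with the maximum of its own row. Passing `keepdims=True` to `max` keeps the reduced axis with length 1 and does the same for any axis, which gets close to the axis-aware argwhere the question asks about. A sketch; `argmax_all` is a hypothetical helper name, not a NumPy function:

```python
import numpy as np

def argmax_all(a, axis):
    """Hypothetical helper: indices of every element equal to the max along `axis`."""
    # keepdims=True leaves the reduced axis in place with length 1,
    # so the maxima broadcast back against `a` without manual reshaping.
    mask = a == a.max(axis=axis, keepdims=True)
    return np.argwhere(mask)

Q = np.array([[0, 1, 1], [1, 0, 1], [0, 2, 0]])

by_row = argmax_all(Q, axis=1)  # same pairs as the argwhere output
by_col = argmax_all(Q, axis=0)  # maxima of each column instead

print(by_row.tolist())  # [[0, 1], [0, 2], [1, 0], [1, 2], [2, 1]]
print(by_col.tolist())  # [[0, 2], [1, 0], [1, 2], [2, 1]]
```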
