user2309803

NumPy: How to retrieve the indices of the maximum values in a multidimensional array

With the following array:

In [103]: da                                                                                         
Out[103]: 
array([[[ 6, 22,  3],
        [ 4,  9, 20],
        [21, 16,  0]],

       [[ 2, 25, 11],
        [ 5, 17, 18],
        [23, 13,  7]],

       [[10, 14, 26],
        [ 8,  1, 19],
        [15, 12, 24]]])

In [104]: da.shape                                                                                   
Out[104]: (3, 3, 3)

The indices of the element with the maximum value can be determined with the following:

In [114]: np.unravel_index(np.argmax(da), da.shape)                                                  
Out[114]: (2, 0, 2)

and checked:

In [115]: da[2, 0, 2]                                                                                
Out[115]: 26
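Called without an axis, np.argmax works on the flattened array. For this array (C order, NumPy's default) that gives flat index 20, and np.unravel_index inverts the row-major layout of shape (3, 3, 3) as follows:

```latex
f = 9\,i_0 + 3\,i_1 + i_2, \qquad
\operatorname{argmax}(\mathtt{da}) = 20 = 9 \cdot 2 + 3 \cdot 0 + 2
\;\Rightarrow\; (i_0, i_1, i_2) = (2, 0, 2)
```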

But how would one determine, without looping or iterating, the indices of the maximum value in each of the 9 groups da[:, i1, i2], where i1 and i2 are each 0, 1 or 2?

For example, the group da[:, 0, 0] is 6, 2 and 10. The maximum value is 10, at index (2, 0, 0), i.e. da[2, 0, 0].
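The following self-contained sketch rebuilds da from the printed values and checks both the global lookup and the single group above. The names global_idx, group and k are illustrative and do not come from the question.

```python
import numpy as np

# Rebuild the 3x3x3 array shown in the question from its printed values.
da = np.array([[[ 6, 22,  3],
                [ 4,  9, 20],
                [21, 16,  0]],

               [[ 2, 25, 11],
                [ 5, 17, 18],
                [23, 13,  7]],

               [[10, 14, 26],
                [ 8,  1, 19],
                [15, 12, 24]]])

# Global maximum: argmax over the flattened array, then map the flat index
# back to 3-D indices (converted to plain ints for readable printing).
global_idx = tuple(int(i) for i in np.unravel_index(np.argmax(da), da.shape))
print(global_idx, da[global_idx])  # → (2, 0, 2) 26

# One group from the question: da[:, 0, 0] holds 6, 2 and 10.
group = da[:, 0, 0]
k = int(np.argmax(group))  # position of the maximum along axis 0
print(k, da[k, 0, 0])
```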

Answers (1)

Mad Physicist

The axis argument of np.argmax lets you specify a single axis along which to operate:

i0 = np.argmax(da, axis=0)

This means that i0 is a (3, 3) array containing, for each pair i1, i2, the index along axis 0 of the maximum of da[:, i1, i2]. The maximum for any given i1, i2 is then

da[i0[i1, i2], i1, i2]
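Below is a sketch of the fully vectorized version. It assumes NumPy 1.15 or later, which is needed for np.take_along_axis, and the names i1, i2, maxima and triples are illustrative. np.indices supplies the i1 and i2 coordinates that pair with each entry of i0, so one fancy-indexing expression gathers all 9 maxima. Stacking the three index arrays then yields the 9 index triples the question asks for.

```python
import numpy as np

# The array from the question.
da = np.array([[[ 6, 22,  3], [ 4,  9, 20], [21, 16,  0]],
               [[ 2, 25, 11], [ 5, 17, 18], [23, 13,  7]],
               [[10, 14, 26], [ 8,  1, 19], [15, 12, 24]]])

# Index along axis 0 of the maximum of each group da[:, i1, i2]; shape (3, 3).
# For this data: [[2, 1, 2], [2, 1, 0], [1, 0, 2]]
i0 = np.argmax(da, axis=0)

# Coordinate grids for axes 1 and 2, each of shape (3, 3).
i1, i2 = np.indices(i0.shape)

# Fancy indexing gathers all 9 maxima at once (same as da.max(axis=0)).
maxima = da[i0, i1, i2]

# Equivalent gather via take_along_axis, which needs i0 with a length-1 axis 0.
maxima_alt = np.take_along_axis(da, i0[np.newaxis], axis=0)[0]

# The 9 full index triples (i0, i1, i2), one row per (i1, i2) pair.
triples = np.stack([i0, i1, i2], axis=-1).reshape(-1, 3)

print(maxima)
print(triples)
```

For this data the first row of triples is [2, 0, 0], matching the example group da[:, 0, 0] from the question.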