A.Yazdiha

Reputation: 1378

What does it mean to sort an array/matrix by the argmax as the key?

I am struggling to understand the mechanism behind the following sorting code, which uses a numpy function as the key:

    import numpy as np

    arr = [[8, 5, 9], [3, 9.5, 5], [5.5, 4, 3.5], [6, 2, 1],
           [6, 1, 2], [3, 2, 1], [8, 5, 3]]
    # sort the rows of arr, using np.argmax(row) as each row's sort key
    res = sorted(arr, key=np.argmax)

This gives me the following result:

    print(res)
    [[5.5, 4, 3.5], [6, 2, 1], [6, 1, 2],
      [3, 2, 1], [8, 5, 3], [3, 9.5, 5], [8, 5, 9]]  

I am an R user and not very familiar with Python. I have some idea of what the 'key' argument does, but I would appreciate help with this example specifically. In a simple case, if 'key' is a function that returns the first element of each row, then sorted orders the rows by their first elements, but I cannot see how this works with argmax. Thanks!
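A minimal sketch of the two cases side by side (the names by_first and by_argmax are only illustrative). In both, sorted calls the key function on each row and orders the rows by the returned values; the only difference is what the key returns.

```python
import numpy as np

arr = [[8, 5, 9], [3, 9.5, 5], [5.5, 4, 3.5], [6, 2, 1],
       [6, 1, 2], [3, 2, 1], [8, 5, 3]]

# Simple case: the key is the first element of each row,
# so rows are ordered by the value row[0].
by_first = sorted(arr, key=lambda row: row[0])

# Same mechanism with np.argmax: the key is the *position*
# of the largest element in each row, not the element itself.
by_argmax = sorted(arr, key=np.argmax)

print(by_first)
# [[3, 9.5, 5], [3, 2, 1], [5.5, 4, 3.5], [6, 2, 1], [6, 1, 2], [8, 5, 9], [8, 5, 3]]
print(by_argmax)
# [[5.5, 4, 3.5], [6, 2, 1], [6, 1, 2], [3, 2, 1], [8, 5, 3], [3, 9.5, 5], [8, 5, 9]]
```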

Upvotes: 4

Views: 544

Answers (2)

Jean-François Fabre

Reputation: 140196

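The argmax function returns the index of the largest element (the first such index if the maximum occurs more than once). Here it is used as the key for sorted, so each row is ranked by the position of its largest element rather than by its values.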
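To make the key concrete, here is a minimal pure-Python sketch of what np.argmax computes for one row. The helper argmax is hypothetical (not part of numpy) and assumes a plain 1-D list of numbers without NaN; it only shows that the key is a position, not a value.

```python
import numpy as np

def argmax(row):
    # Hypothetical pure-Python stand-in for np.argmax on a 1-D list:
    # among the positions 0..len(row)-1, pick the one whose element is
    # largest. max() returns the first maximal position on ties, which
    # matches np.argmax's "first occurrence" rule.
    return max(range(len(row)), key=row.__getitem__)

arr = [[8, 5, 9], [3, 9.5, 5], [5.5, 4, 3.5], [6, 2, 1],
       [6, 1, 2], [3, 2, 1], [8, 5, 3]]

keys = [argmax(row) for row in arr]
print(keys)  # [2, 1, 0, 0, 0, 0, 0]

# Using the hand-written key gives the same order as key=np.argmax.
print(sorted(arr, key=argmax) == sorted(arr, key=np.argmax))  # True
```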

If you print this:

    print([int(np.argmax(x)) for x in arr])

you get:

    [2, 1, 0, 0, 0, 0, 0]

which explains the sorting: the last five rows all have key 0, so they come first in your result; the second row has key 1, so it comes next; and the first row has the highest key, 2, so it comes last.
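The same ordering can be reproduced by hand, assuming only the arr from the question: compute each row's key once, pair it with the row, and order the pairs by the key alone. This is a sketch of the idea, not of how CPython implements sorted.

```python
import numpy as np

arr = [[8, 5, 9], [3, 9.5, 5], [5.5, 4, 3.5], [6, 2, 1],
       [6, 1, 2], [3, 2, 1], [8, 5, 3]]

# Pair every row with its key (int() just makes the key a plain Python int).
pairs = [(int(np.argmax(row)), row) for row in arr]

# Order by the key only; ties keep their original relative order.
pairs.sort(key=lambda pair: pair[0])

for k, row in pairs:
    print(k, row)
# 0 [5.5, 4, 3.5]
# 0 [6, 2, 1]
# 0 [6, 1, 2]
# 0 [3, 2, 1]
# 0 [8, 5, 3]
# 1 [3, 9.5, 5]
# 2 [8, 5, 9]
```

Sorting by pair[0] alone matters: sorting the tuples themselves would break ties by comparing the rows, which gives a different order.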

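Of course this is a "weak" ordering: the criterion often returns the same value, so the relative order of tied rows depends on their order in the initial list. Python's sorted is guaranteed to be stable, so tied rows keep their original relative order. (Edit: this is called stable sorting; see Bakuriu's interesting comment.)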
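A quick way to see the stability, as a small sketch using the question's arr: reverse the input, and the five key-0 rows come out in reverse order too, while the rows with distinct keys land in the same places.

```python
import numpy as np

arr = [[8, 5, 9], [3, 9.5, 5], [5.5, 4, 3.5], [6, 2, 1],
       [6, 1, 2], [3, 2, 1], [8, 5, 3]]

forward = sorted(arr, key=np.argmax)
backward = sorted(arr[::-1], key=np.argmax)

# The tied (key 0) rows keep whatever relative order they had in the input.
print(forward[:5])   # [[5.5, 4, 3.5], [6, 2, 1], [6, 1, 2], [3, 2, 1], [8, 5, 3]]
print(backward[:5])  # [[8, 5, 3], [3, 2, 1], [6, 1, 2], [6, 2, 1], [5.5, 4, 3.5]]

# Rows with unique keys (1 and 2) end up in the same positions either way.
print(forward[5:] == backward[5:])  # True
```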

Upvotes: 2

David Stansby
David Stansby

Reputation: 1911

np.argmax gives you the "arg", i.e. the index, of the maximum value. In your example, it is applied to each individual list of 3 items, for example:

    >>> np.argmax([8,5,3])
    0
    >>> np.argmax([1,2,3])
    2
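Two more properties that matter for this key, sketched as a small example: on ties np.argmax returns the first occurrence, and on a 2-D array axis=1 computes all the per-row keys in one call (the name row_keys is only illustrative).

```python
import numpy as np

# On ties, np.argmax returns the index of the first occurrence.
print(np.argmax([6, 6, 1]))  # 0
print(np.argmax([1, 7, 7]))  # 1

arr = [[8, 5, 9], [3, 9.5, 5], [5.5, 4, 3.5], [6, 2, 1],
       [6, 1, 2], [3, 2, 1], [8, 5, 3]]

# axis=1 takes the argmax along each row of the 2-D array,
# giving the same keys that sorted(arr, key=np.argmax) uses.
row_keys = np.argmax(np.array(arr), axis=1)
print(row_keys.tolist())  # [2, 1, 0, 0, 0, 0, 0]
```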

Upvotes: 0
