Kevin Fang

Reputation: 2012

numpy argmax when values are equal

I have a NumPy matrix and I want to get the index of the maximum value in each row. E.g.

[[1,2,3],[1,3,2],[3,2,1]]

will return

[2,1,0]

However, when a row contains more than one maximum value, numpy.argmax returns only the smallest index. E.g.

[[0,0,0],[0,0,0],[0,0,0]]

will return

[0,0,0]

Can I change the default (smallest index) to some other value? E.g. when there are equal maximum values, return 1 or None, so that the above result will be

[1,1,1]
or
[None, None, None]

It would be even better if I could do this in TensorFlow.

Thanks!

Upvotes: 3

Views: 13518

Answers (2)

kuppern87

Reputation: 1135

You can use np.partition to find the two largest values in each row and check whether they are equal, then use that as a mask in np.where to set the default value:

In [228]: a = np.array([[1, 2, 3, 2], [3, 1, 3, 2], [3, 5, 2, 1]])

In [229]: twomax = np.partition(a, -2)[:, -2:].T

In [230]: default = -1

In [231]: argmax = np.where(twomax[0] != twomax[1], np.argmax(a, -1), default)

In [232]: argmax
Out[232]: array([ 2, -1,  1])
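The same idea can be wrapped in a small helper and run on the two matrices from the question. This is only a sketch: the name argmax_or_default and its default parameter are made up for illustration, and it assumes a 2-D input with at least two columns.

```python
import numpy as np

def argmax_or_default(a, default=-1):
    """Row-wise argmax; rows whose maximum is tied get `default` instead.

    Hypothetical helper, assumes a 2-D array with at least two columns.
    """
    a = np.asarray(a)
    # After partitioning at position -2, the last two columns of each row
    # hold that row's two largest values
    top_two = np.partition(a, -2, axis=1)[:, -2:]
    # If the two largest values are equal, the maximum is not unique
    tied = top_two[:, 0] == top_two[:, 1]
    return np.where(tied, default, a.argmax(axis=1))

print(argmax_or_default([[1, 2, 3], [1, 3, 2], [3, 2, 1]]))             # [2 1 0]
print(argmax_or_default([[0, 0, 0], [0, 0, 0], [0, 0, 0]], default=1))  # [1 1 1]
```

Passing default=1 reproduces the [1,1,1] result asked for in the question, although -1 is probably the safer choice because argmax never returns it on its own.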

Upvotes: 3

user6655984

Reputation:

A convenient value of "default" is -1, as argmax will not return that on its own. None does not fit in an integer array. A masked array is also an option, but I didn't go that far. Here is a NumPy implementation:

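def my_argmax(a):
    # Row index of every element equal to its row's maximum (sorted ascending)
    rows = np.where(a == a.max(axis=1)[:, None])[0]
    # A row index that repeats consecutively has more than one maximum
    rows_multiple_max = rows[:-1][rows[:-1] == rows[1:]]
    result = a.argmax(axis=1)
    # Overwrite the argmax of tied rows with the default value
    result[rows_multiple_max] = -1
    return result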

Example of use:

import numpy as np
a = np.array([[0, 0, 0], [4, 5, 3], [3, 4, 4], [6, 2, 1]])
my_argmax(a)   #  array([-1,  1, -1,  0])

Explanation: np.where selects the indexes of all maximal elements in each row. If a row has multiple maxima, its row number appears more than once in the rows array. Since this array is already sorted, such repetition is detected by comparing consecutive elements. This identifies the rows with multiple maxima, which are then set to -1 in the output of NumPy's argmax method.
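For the masked-array option mentioned earlier, a rough sketch might look like the following. The function name masked_argmax is hypothetical, and instead of comparing consecutive row indexes it detects ties by counting how many entries equal the row maximum, which should flag the same rows.

```python
import numpy as np

def masked_argmax(a):
    """Row-wise argmax as a masked array; tied rows are masked.

    Hypothetical helper, assumes a 2-D numeric array.
    """
    a = np.asarray(a)
    # Count how many entries in each row equal that row's maximum
    n_max = (a == a.max(axis=1, keepdims=True)).sum(axis=1)
    # Mask the argmax wherever the maximum occurs more than once
    return np.ma.masked_array(a.argmax(axis=1), mask=n_max > 1)

a = np.array([[0, 0, 0], [4, 5, 3], [3, 4, 4], [6, 2, 1]])
result = masked_argmax(a)
print(result.tolist())    # [None, 1, None, 0]
print(result.filled(-1))  # [-1  1 -1  0]
```

Converting with tolist() gives None for the masked entries, which is close to the [None, None, None] output the question asks about, while filled(-1) recovers the integer array with -1 as the default.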

Upvotes: 1
