Knovolt

Reputation: 115

How to detect a tie in a numpy array when using argmax

If I have an array like the one below, how can I detect a tie of 3 or more values when using np.argmax()?

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

np.argmax(examp, axis=1)

which gives an output:

array([0, 0, 3, 1, 1])

Taking the first row as an example, there is a "3-way tie": three values of 4. np.argmax returns the first index that holds the max value. But how can I detect that a "3-way tie" is occurring and have a custom function decide the tie-breaker (on the condition that at least a "3-way tie" occurs)?

So, first row: sees that there is a "3-way tie" of 4s. Custom function runs so that it can decide the tie-breaker.

Second row: "4-way tie" same thing happens.

Third row: only "2-way tie" which is less than condition of at least a "3-way tie". Can default to np.argmax.

Upvotes: 3

Views: 1774

Answers (2)

Mad Physicist

Reputation: 114478

One way to find the n-th maximum is np.partition (or np.argpartition). In this case you can do something like this:

>>> n = 3  # Size of tie
>>> i = examp.argpartition([-n, -1], axis=-1)

The third-to-last and last columns of i are guaranteed to index the values that would sit in those positions after a full sort (and therefore the second-to-last as well, but only in this limited case). If those two values are equal to each other, then the maximum occurs at least 3 times, i.e. you have at least a 3-way tie:

>>> r = np.arange(examp.shape[0])
>>> examp[r, i[:, -n]] == examp[r, i[:, -1]]
array([ True,  True, False, False, False])
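
For reference, the steps above can be collected into one self-contained function. This is only a minimal sketch: the name has_n_way_tie is made up for illustration, and it assumes a 2-D array with at least n columns.

```python
import numpy as np

def has_n_way_tie(arr, n):
    """Return a boolean mask: True where a row's maximum occurs at least n times."""
    # Partially sort each row so that positions -n and -1 index the
    # n-th largest and the largest value, respectively.
    i = arr.argpartition([-n, -1], axis=-1)
    rows = np.arange(arr.shape[0])
    # If the n-th largest value equals the largest, at least n entries share the max.
    return arr[rows, i[:, -n]] == arr[rows, i[:, -1]]

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

tie_mask = has_n_way_tie(examp, 3)
print(tie_mask)  # [ True  True False False False]
```

Changing n to 4 flags only the second row, since it is the only one whose maximum appears four times.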

You can also use np.diff to compute the mask:

>>> np.diff(examp[r[:, None], i[:, [-n, -1]]], axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])

You can get a similar result by using np.take_along_axis instead of the first index r:

>>> np.diff(np.take_along_axis(examp, i[:, -n::n-1], 1), axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])

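In all these cases, i[:, -1] is an index of the maximum value in each row, so it can stand in for argmax. Note that when the maximum is tied, it is not guaranteed to be the first such index, which is what np.argmax returns.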

Since you are already using numpy, I highly recommend that you vectorize the custom tie-breaking function as well. I've provided the output as a mask here so that you can do exactly that as efficiently as possible.
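
As a rough illustration of how such a mask could be consumed, the sketch below falls back to np.argmax on rows without a 3-way tie and applies a placeholder rule to the tied rows. The rule used here (pick the last index holding the maximum) is purely hypothetical and stands in for whatever custom tie-breaker you actually need; break_ties_last is not a NumPy function.

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

n = 3  # minimum size of tie that triggers the custom rule
i = examp.argpartition([-n, -1], axis=-1)
r = np.arange(examp.shape[0])
tie_mask = examp[r, i[:, -n]] == examp[r, i[:, -1]]

def break_ties_last(arr):
    """Hypothetical tie-breaker: index of the LAST occurrence of each row's maximum."""
    # argmax on the reversed rows finds the last maximum;
    # subtracting from the last column index converts it back to a forward index.
    return arr.shape[1] - 1 - np.argmax(arr[:, ::-1], axis=1)

default = np.argmax(examp, axis=1)  # first occurrence of the maximum
result = default.copy()
# Only rows flagged by the mask go through the custom rule.
result[tie_mask] = break_ties_last(examp[tie_mask])
print(result)  # [4 4 3 1 1]
```

Because the tie-breaker receives all tied rows at once, it runs as a single array operation instead of a Python loop over rows.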

Upvotes: 1

Kevin

Reputation: 3368

You are correct that np.argmax will only find the first max value. However, you could count how many times that max value occurs and base your logic on that number:

indices = examp.argmax(0)
cols = np.arange(examp.shape[1])  # one column index per argmax result
counts = (examp == examp[indices, cols]).sum(0)
# the same as
counts = np.count_nonzero(examp == examp[indices, cols], axis=0)

For a 3-column input (not the 5x5 examp from the question), this returns:

indices = array([0, 3, 2])
counts = array([4, 1, 2])
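
The question asks about ties along each row (axis=1), so here is the same counting idea applied row-wise to the question's array. This is a sketch rather than part of the original answer; it compares each row against its own maximum, using keepdims=True so the comparison broadcasts across columns.

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

# Maximum of each row, kept as a column vector of shape (5, 1).
row_max = examp.max(axis=1, keepdims=True)
# Number of entries in each row that equal that row's maximum.
counts = (examp == row_max).sum(axis=1)
# Rows where the maximum occurs at least 3 times.
is_tie = counts >= 3

print(counts)  # [3 4 1 1 1]
print(is_tie)  # [ True  True False False False]
```

Unlike the partition-based mask, this also reports the exact size of each tie, at the cost of comparing every element.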

Upvotes: 1
