Reputation: 115
If I have an array like the one below, how can I detect that there is a tie of 3 or more values when using np.argmax()?
```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])
np.argmax(examp, axis=1)
```
which gives the output:
```
array([0, 0, 3, 1, 1])
```
Taking the first row as an example, there is a "3-way tie": three values of 4. np.argmax returns the first index that holds the max value. But how can I detect that a "3-way tie" is going on and have a custom function decide the tie-breaker (only when at least a "3-way tie" occurs)?
So, first row: detect the "3-way tie" of 4s and run the custom function to decide the tie-breaker.
Second row: "4-way tie" of 5s, so the same thing happens.
Third row: at most a "2-way tie" (and not of the maximum, 4), which is below the "3-way tie" condition, so it can default to np.argmax.
Upvotes: 3
Views: 1774
Reputation: 114478
One way to find the n-th largest value is np.partition (or np.argpartition). In this case you can do something like this:
```python
>>> n = 3  # Size of tie
>>> i = examp.argpartition([-n, -1], axis=-1)
```
After partitioning, the values in the third-to-last and last columns are guaranteed to be where a full sort would put them (and therefore so is the second-to-last, but only in this limited case, since it is the only position between them). If those two values are equal to each other, then the row has at least a 3-way tie for its maximum:
```python
>>> r = np.arange(examp.shape[0])
>>> examp[r, i[:, -n]] == examp[r, i[:, -1]]
array([ True, True, False, False, False])
```
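The steps above can be collected into one self-contained function. This is a minimal sketch: the name `tie_mask` and its `n` parameter are hypothetical, and it assumes a 2-D numeric array with at least `n` columns.

```python
import numpy as np

def tie_mask(a, n=3):
    """Hypothetical helper: True for rows whose maximum occurs at least n times."""
    # Put the values for sorted positions -n and -1 in place in every row
    i = a.argpartition([-n, -1], axis=-1)
    r = np.arange(a.shape[0])
    # Every value between positions -n and -1 lies between those two values,
    # so equality means the top n values all equal the row maximum
    return a[r, i[:, -n]] == a[r, i[:, -1]]

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

print(tie_mask(examp).tolist())       # [True, True, False, False, False]
print(tie_mask(examp, n=4).tolist())  # [False, True, False, False, False]
```

Passing `n` makes the "at least n-way tie" threshold explicit instead of hard-coding 3.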
You can also use np.diff to compute the mask:
```python
>>> np.diff(examp[r[:, None], i[:, [-n, -1]]], axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])
```
You can get a similar result by using np.take_along_axis instead of the first index r:
```python
>>> np.diff(np.take_along_axis(examp, i[:, -n::n-1], 1), axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])
```
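If only the mask is needed, the indices are not strictly required. As a variant sketch (it swaps in np.partition, which returns values rather than indices; the name `tie_mask_values` is hypothetical), you can compare the n-th largest value in each row with the row maximum. The two are equal exactly when the maximum occurs at least n times.

```python
import numpy as np

def tie_mask_values(a, n=3):
    """Hypothetical helper: True where the n-th largest value of a row equals its max."""
    # After partitioning, column -n holds the n-th largest value of each row
    nth_largest = np.partition(a, -n, axis=1)[:, -n]
    return nth_largest == a.max(axis=1)

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

print(tie_mask_values(examp).tolist())  # [True, True, False, False, False]
```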
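With any of the argpartition-based masks above, i[:, -1] is the column of a maximum value in each row. Where the maximum is unique it matches np.argmax; in tied rows it may point at a different tied column, because np.argmax always returns the first occurrence.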
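A quick sketch to check that behavior: the values at i[:, -1] always equal the row maxima, while np.argmax always reports the first occurrence. Which tied column argpartition leaves last is not guaranteed, so the code only asserts agreement on rows whose maximum is unique.

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])
n = 3
i = examp.argpartition([-n, -1], axis=-1)
r = np.arange(examp.shape[0])
row_max = examp.max(axis=1)

# The value stored at column i[:, -1] is always the row maximum
print(np.array_equal(examp[r, i[:, -1]], row_max))  # True

# np.argmax also lands on a maximum, but always on its first occurrence
first_max = examp.argmax(axis=1)
print(first_max.tolist())  # [0, 0, 3, 1, 1]

# In rows without a tied maximum the two indices must agree
unique_max = np.count_nonzero(examp == row_max[:, None], axis=1) == 1
print(np.array_equal(i[unique_max, -1], first_max[unique_max]))  # True
```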
Since you are already using numpy, I highly recommend that you vectorize the custom tie-breaking function as well. I've provided the output as a mask here so that you can do exactly that as efficiently as possible.
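Here is a sketch of combining the mask with a vectorized tie-breaker. The rule used here (pick the last occurrence of the maximum) is only a hypothetical stand-in for the custom function from the question; np.where keeps plain np.argmax for rows without at least an n-way tie.

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])
n = 3

# Mask of rows whose maximum occurs at least n times (argpartition approach)
i = examp.argpartition([-n, -1], axis=-1)
r = np.arange(examp.shape[0])
mask = examp[r, i[:, -n]] == examp[r, i[:, -1]]

# Hypothetical vectorized tie-breaker: column of the LAST occurrence of each row's max
# (argmax on the column-reversed array finds the first max from the right)
last_max = examp.shape[1] - 1 - examp[:, ::-1].argmax(axis=1)

# Apply the custom rule only where the tie condition holds, else default to argmax
result = np.where(mask, last_max, examp.argmax(axis=1))
print(result.tolist())  # [4, 4, 3, 1, 1]
```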
Upvotes: 1
Reputation: 3368
You are correct that np.argmax will only find the first max value. However, you could count how many times that maximum occurs and base your logic on that number:
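```python
indices = examp.argmax(axis=1)  # column of the first max in each row
row_max = examp[np.arange(examp.shape[0]), indices][:, None]  # row maxima as a column
counts = (examp == row_max).sum(axis=1)  # how many entries equal the row max
# the same as
counts = np.count_nonzero(examp == row_max, axis=1)
```
Will return
```
indices = array([0, 0, 3, 1, 1])
counts = array([3, 4, 1, 1, 1])
```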
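If the tie-breaker cannot be vectorized, the counts can drive a plain per-row loop. A sketch under stated assumptions: `custom_tiebreak` and `pick_columns` are hypothetical names, the placeholder rule simply picks the middle tied column, and the default threshold of 3 comes from the question.

```python
import numpy as np

def custom_tiebreak(tied_cols):
    """Hypothetical tie-breaker: pick the middle of the tied columns."""
    return tied_cols[len(tied_cols) // 2]

def pick_columns(a, min_tie=3):
    """Hypothetical driver: argmax by default, custom rule for rows with >= min_tie ties."""
    result = a.argmax(axis=1)  # default: first maximum, as np.argmax does
    row_max = a.max(axis=1, keepdims=True)
    counts = np.count_nonzero(a == row_max, axis=1)
    # Only rows that meet the tie threshold go through the custom function
    for k in np.flatnonzero(counts >= min_tie):
        tied_cols = np.flatnonzero(a[k] == row_max[k])
        result[k] = custom_tiebreak(tied_cols)
    return result

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

print(pick_columns(examp).tolist())  # [3, 3, 3, 1, 1]
```

The loop only touches tied rows, so rows without a qualifying tie keep the np.argmax result.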
Upvotes: 1