MikeRand

Reputation: 4828

Numpy: change max in each row to 1, all other numbers to 0

I'm trying to implement a numpy function that replaces the max in each row of a 2D array with 1, and all other numbers with zero:

>>> a = np.array([[0, 1],
...               [2, 3],
...               [4, 5],
...               [6, 7],
...               [9, 8]])
>>> b = some_function(a)
>>> b
[[0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [1. 0.]]

What I've tried so far

def some_function(x):
    a = np.zeros(x.shape)
    a[:,np.argmax(x, axis=1)] = 1
    return a

>>> b = some_function(a)
>>> b
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]

Upvotes: 38

Views: 34802

Answers (4)

Mour_Ka

Reputation: 258

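# compare each row against its own maximum, not the maximum of the whole array
b = (a == np.max(a, axis=1, keepdims=True)).astype(int)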

That worked for me

Upvotes: -1

Alex Riina

Reputation: 801

I know this question is pretty ancient, but I think I have a decent solution that's a bit different from the other solutions. A plain a == np.max(a) compares against the maximum of the whole array, and a == a.max(axis=1) does not broadcast for a non-square array, so here's a tweaked version that broadcasts correctly.

# get max by row and convert from (n, ) -> (n, 1) which will broadcast
row_maxes = a.max(axis=1).reshape(-1, 1)
np.where(a == row_maxes, 1, 0)
# equivalent, without np.where
(a == row_maxes).astype(int)

If the update needs to be in place, you can do:

a[:] = np.where(a == row_maxes, 1, 0)
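
A minimal sketch of the in-place variant, assuming a float input array chosen here only for illustration. The row maxima have to be computed before `a` is overwritten, and the assigned values are cast to the dtype `a` already has:

```python
import numpy as np

# Example data is an assumption for illustration, not from the question
a = np.array([[0.5, 1.5], [2.0, 2.0], [9.0, 8.0]])

# Shape (3, 1); computed before a is modified
row_maxes = a.max(axis=1, keepdims=True)

# Slice assignment writes into the existing buffer; the booleans are cast
# to a's dtype (float64 here), so a keeps its original dtype
a[:] = a == row_maxes
```

After this runs, `a` is still a float array, so the result holds 1.0 and 0.0 rather than integers.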

Upvotes: 5

Cyclone

Reputation: 2133

I prefer using numpy.where like so:

# use the per-row maximum, and set the non-maximum entries to 0
a = np.where(a == np.max(a, axis=1, keepdims=True), 1, 0)

Upvotes: 7

DSM

Reputation: 353199

Method #1, tweaking yours:

>>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
>>> b = np.zeros_like(a)
>>> b[np.arange(len(a)), a.argmax(1)] = 1
>>> b
array([[0, 1],
       [0, 1],
       [0, 1],
       [0, 1],
       [1, 0]])

[Actually, range will work just fine; I wrote arange out of habit.]
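
To see why pairing the row indices with the argmax result fixes the attempt in the question, here is a small side-by-side sketch. The variable names `cols`, `sliced`, and `paired` are made up for this comparison:

```python
import numpy as np

a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
cols = a.argmax(axis=1)  # column of the max in each row: [1, 1, 1, 1, 0]

# Slice plus index array: selects every row, and every column listed in cols.
# Both column 0 and column 1 appear in cols, so the whole array becomes 1.
sliced = np.zeros_like(a)
sliced[:, cols] = 1

# Two index arrays of equal length are matched element by element,
# so row i is paired only with column cols[i].
paired = np.zeros_like(a)
paired[np.arange(len(a)), cols] = 1
```

`sliced` reproduces the all-ones output from the question, while `paired` has exactly one 1 per row.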

Method #2, using max instead of argmax to handle the case where multiple elements reach the maximum value:

>>> a = np.array([[0, 1], [2, 2], [4, 3]])
>>> (a == a.max(axis=1)[:,None]).astype(int)
array([[0, 1],
       [1, 1],
       [1, 0]])
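
Method #2 could be wrapped in a small helper along these lines. This is only a sketch: `row_max_mask` is a name invented here, and it uses `keepdims=True` in place of the `[:, None]` indexing above, which gives the same `(n, 1)` shape:

```python
import numpy as np

def row_max_mask(x):
    """Return an int array with 1 wherever an element equals its row's maximum."""
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis with length 1, so the (n, 1)
    # result broadcasts against the (n, m) input without a manual reshape
    return (x == x.max(axis=1, keepdims=True)).astype(int)

# Row 1 contains a tie, so both of its entries are marked
tied = row_max_mask([[0, 1], [2, 2], [4, 3]])
```

Because it compares values rather than positions, every tied maximum in a row is set to 1. Method #1 marks only the first one, since `argmax` returns the first occurrence.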

Upvotes: 44
