Shamoon

Reputation: 43491

How can I use np argmax on an array of lists?

I have some data y_hat that looks like:

[[0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 ...
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]]

I want to get the argmax of each row so that I end up with a vector like:

[[3]
 [8]
 [8]
 ...
 [5]
 [1]
 [7]]

If I just call np.argmax(y_hat), it returns a single value, 1.
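Why a single value: when no axis is given, np.argmax works on the flattened array and returns one flat index (the first occurrence of the maximum). The sketch below uses a small hypothetical stand-in for y_hat, since the real data is elided, and shows how that flat index maps back to a (row, column) pair with np.unravel_index.

```python
import numpy as np

# Hypothetical stand-in for y_hat (the real values are elided in the question)
y_hat = np.array([[0., 1., 0., 0.],
                  [0., 0., 0., 1.],
                  [1., 0., 0., 0.]])

# No axis: argmax flattens the array and returns the index of the first maximum
flat_idx = np.argmax(y_hat)
print(flat_idx)  # 1

# Map the flat index back to (row, column) coordinates
row, col = np.unravel_index(flat_idx, y_hat.shape)
print(row, col)  # 0 1
```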

Upvotes: 1

Views: 1882

Answers (2)

BENY

Reputation: 323226

Here is one way: take the argmax along each row, then index with None to add an axis so the result is a column vector:

a.argmax(axis=1)[:, None]

Or

a[:, None].argmax(-1)
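A quick check that both one-liners give the same (n, 1) column of row indices. The array a here is a hypothetical 3x4 example standing in for the question's y_hat. The first form reduces to shape (3,) and then adds an axis; the second adds an axis first, giving shape (3, 1, 4), and then reduces over the last axis.

```python
import numpy as np

# Hypothetical example data standing in for `a` (the question's y_hat)
a = np.array([[0., 1., 0., 0.],
              [0., 0., 0., 1.],
              [1., 0., 0., 0.]])

# Option 1: argmax per row -> shape (3,), then [:, None] adds a column axis -> (3, 1)
col1 = a.argmax(axis=1)[:, None]

# Option 2: add the axis first -> shape (3, 1, 4), then argmax over the last axis -> (3, 1)
col2 = a[:, None].argmax(-1)

print(col1.ravel())                # [1 3 0]
print(col1.shape, col2.shape)      # (3, 1) (3, 1)
print(np.array_equal(col1, col2))  # True
```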

Upvotes: 1

wim

Reputation: 362557

np.argmax accepts an axis keyword argument; use that.

axis=0 gives the argmax of each column and axis=1 gives the argmax of each row, so np.argmax(y_hat, axis=1) is what you want here.
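To make the axis convention concrete, here is a small sketch on a hypothetical 3x4 array: axis=1 yields one index per row, axis=0 one index per column. Reshaping the per-row result with reshape(-1, 1) is one way (an illustrative choice, not the only one) to get the column-vector layout shown in the question.

```python
import numpy as np

# Hypothetical 3x4 example
y_hat = np.array([[0., 1., 0., 0.],
                  [0., 0., 0., 1.],
                  [1., 0., 2., 0.]])

per_row = np.argmax(y_hat, axis=1)  # one index per row    -> shape (3,)
per_col = np.argmax(y_hat, axis=0)  # one index per column -> shape (4,)

# Column vector of row-wise argmaxes, matching the layout in the question
per_row_col = per_row.reshape(-1, 1)

print(per_row)            # [1 3 2]
print(per_col)            # [2 0 2 1]
print(per_row_col.shape)  # (3, 1)
```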

Upvotes: 2
