Reputation: 43491
I have some data y_hat that looks like:
[[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
...
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]]
I want to get the argmax of each row so that I end up with a vector like:
[[3]
[8]
[8]
...
[5]
[1]
[7]]
If I just do np.argmax(y_hat), it returns a single value, 1.
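Without an axis argument, np.argmax works on the flattened array and returns one index into that flattened data, which is why a single number comes back. A small sketch reproducing this with a made-up stand-in for y_hat (the values are hypothetical, not the real data):

```python
import numpy as np

# Hypothetical stand-in for y_hat: 3 one-hot rows with 4 columns each
y_hat = np.array([[0., 1., 0., 0.],
                  [0., 0., 0., 1.],
                  [1., 0., 0., 0.]])

# With no axis, argmax scans y_hat as if it were flattened and returns one flat index
flat_idx = np.argmax(y_hat)
print(flat_idx)  # 1: the first maximum in the flattened array

# np.unravel_index converts that flat index back into a (row, column) pair
print(np.unravel_index(flat_idx, y_hat.shape))
```

Because the first row's maximum sits at column 1, the flat index and the column index happen to coincide here; for the per-row result, an axis has to be given.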
Upvotes: 1
Views: 1882
Reputation: 323226
Here is one way: take argmax along axis 1, then index with [:, None] (NumPy's new-axis indexing) to turn the result into a column
a.argmax(axis=1)[:, None]
Or, adding the new axis first and taking argmax over the last axis:
a[:, None].argmax(-1)
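A quick check that both expressions produce the same column of indices. The array a below is a made-up example; a[:, None] has shape (n, 1, m), so reducing over the last axis leaves shape (n, 1).

```python
import numpy as np

# Made-up example: 4 rows, 5 columns
a = np.array([[0., 1., 0., 0., 0.],
              [0., 0., 0., 1., 0.],
              [0., 0., 0., 0., 1.],
              [1., 0., 0., 0., 0.]])

# Option 1: argmax along each row -> shape (4,), then [:, None] -> shape (4, 1)
col1 = a.argmax(axis=1)[:, None]

# Option 2: insert an axis first -> shape (4, 1, 5), then argmax over the last axis -> (4, 1)
col2 = a[:, None].argmax(-1)

print(col1.shape, col2.shape)      # (4, 1) (4, 1)
print(col1.ravel().tolist())       # [1, 3, 4, 0]
print(np.array_equal(col1, col2))  # True
```

On NumPy 1.22 or later, np.argmax(a, axis=1, keepdims=True) should give the same (n, 1) result directly, without the extra indexing step.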
Upvotes: 1
Reputation: 362557
np.argmax accepts an axis keyword argument. Use that.
axis=0 gives the argmax of each column, and axis=1 gives the argmax of each row.
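A small sketch of what each axis value returns, using a made-up 2 x 3 array: axis=0 reduces down the rows and gives one index per column, while axis=1 reduces across the columns and gives one index per row.

```python
import numpy as np

# Made-up 2 x 3 array
m = np.array([[5, 2, 9],
              [7, 8, 1]])

# axis=0: index of the max in each column -> shape (3,)
print(np.argmax(m, axis=0).tolist())  # [1, 1, 0]

# axis=1: index of the max in each row -> shape (2,)
print(np.argmax(m, axis=1).tolist())  # [2, 1]
```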
Upvotes: 2