Reputation: 169
I have an issue applying argmax to an array that has multiple levels of brackets (extra dimensions). In real life I get this array from a PyTorch tensor. Here is an example:
import numpy as np
a = np.array([[1.0, 1.1],[2.1,2.0]])
np.argmax(a,axis=1)
array([1, 0])
This is correct. But with one extra level of brackets:
a = np.array([[[1.0, 1.1]],[[2.1,2.0]]])
np.argmax(a,axis=1)
array([[0, 0],
       [0, 0]])
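Checking the shapes shows why this happens: the extra brackets add a length-1 axis, so axis=1 no longer points at the pair of values being compared. A minimal sketch:

```python
import numpy as np

a = np.array([[[1.0, 1.1]], [[2.1, 2.0]]])
print(a.shape)  # (2, 1, 2): axis 1 now has length 1

# argmax over a length-1 axis can only ever return index 0
print(np.argmax(a, axis=1))

# the values to compare live on the last axis; result has shape (2, 1)
print(np.argmax(a, axis=-1))
```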
This is not what I expect, which is the same [1, 0] as in the first example. Note that in reality I have even more levels of inner brackets:
a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])
Upvotes: 0
Views: 192
Reputation: 14399
Use .squeeze() and a negative axis index (axis=-1 always refers to the last axis, no matter how many inner brackets there are):
a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])
np.argmax(a, axis = -1).squeeze()
array([1, 0], dtype=int32)
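One caveat with a bare .squeeze(): it drops every length-1 axis, including the batch axis when the batch happens to contain a single sample. The sketch below assumes the inner length-1 axes are always axes 1 and 2, as in the example above, and squeezes only those:

```python
import numpy as np

a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])  # shape (2, 1, 1, 2)

# squeeze only the inner length-1 axes, then take argmax over the last axis
idx = np.argmax(a.squeeze(axis=(1, 2)), axis=-1)
print(idx)  # [1 0]

# with a single-sample batch, a bare squeeze() also removes the batch axis
single = a[:1]  # shape (1, 1, 1, 2)
print(np.argmax(single, axis=-1).squeeze().shape)             # () -> 0-d scalar
print(np.argmax(single.squeeze(axis=(1, 2)), axis=-1).shape)  # (1,)
```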
Upvotes: 1
Reputation: 169
A possible solution is to increase the axis value so that it points at the last axis:
a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])
np.argmax(a,axis=3)
array([[[1]],

       [[0]]])
But the result still has inner brackets (its shape is (2, 1, 1)).
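The leftover brackets are just length-1 axes, so flattening the result removes them. A second option, sketched below, collapses the inner axes before calling argmax; this is only equivalent here because every axis between the first and the last has length 1:

```python
import numpy as np

a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])  # shape (2, 1, 1, 2)

res = np.argmax(a, axis=3)  # shape (2, 1, 1), as in the post
print(res.ravel())          # [1 0]; same as res.reshape(-1)

# alternative: collapse the inner axes first (valid only because the
# axes between the first and the last all have length 1)
flat = a.reshape(a.shape[0], -1)  # shape (2, 2)
print(flat.argmax(axis=1))        # [1 0]
```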
Upvotes: 0