learning-man

Reputation: 169

numpy argmax in array with multiple brackets

I have an issue applying argmax to an array that has multiple levels of brackets. In practice I get this array from a PyTorch tensor. Here is an example:

import numpy as np

a = np.array([[1.0, 1.1],[2.1,2.0]])
np.argmax(a,axis=1)

array([1, 0])

That is correct. But with one more level of brackets:

a = np.array([[[1.0, 1.1]],[[2.1,2.0]]])
np.argmax(a,axis=1)

array([[0, 0],
       [0, 0]])

This does not give me what I expect. In reality my array has even more levels of inner brackets:

a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])

Upvotes: 0

Views: 192

Answers (2)

Daniel F

Reputation: 14399

Use a negative axis (axis=-1 selects the last axis), then .squeeze() to drop the length-1 dimensions.

a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])
np.argmax(a, axis = -1).squeeze()

array([1, 0], dtype=int32)
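
To see why this works, it may help to print the shapes at each step. The following is only a minimal sketch using the array from the question; the variable names `along_axis1`, `idx` and `result` are illustrative and not part of the original answer.

```python
import numpy as np

a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])
print(a.shape)  # (2, 1, 1, 2)

# axis=1 has length 1, so argmax along it can only ever return 0
along_axis1 = np.argmax(a, axis=1)
print(along_axis1.shape)  # (2, 1, 2)

# axis=-1 is the last axis, which holds the values being compared
idx = np.argmax(a, axis=-1)
print(idx.shape)  # (2, 1, 1)

# squeeze() removes every axis of length 1
result = idx.squeeze()
print(result)  # [1 0]
```

The extra brackets in the question are axes of length 1, which is why argmax with axis=1 returned only zeros there.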

Upvotes: 1

learning-man

Reputation: 169

A possible solution is to increase the axis value so that it points at the innermost axis:

a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])
np.argmax(a,axis=3)

array([[[1]],
       [[0]]])

But I still have inner brackets.
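
The remaining brackets are axes of length 1, so one way to drop them is to flatten the result with reshape(-1). The sketch below assumes a flat 1-D result is wanted; `single` is a hypothetical one-row input, added only to compare reshape(-1) with .squeeze().

```python
import numpy as np

a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])

# argmax over the last axis leaves shape (2, 1, 1); reshape(-1) flattens it to 1-D
flat = np.argmax(a, axis=3).reshape(-1)
print(flat)  # [1 0]

# Hypothetical one-row input, to compare the two approaches
single = a[:1]  # shape (1, 1, 1, 2)
squeezed = np.argmax(single, axis=3).squeeze()
flattened = np.argmax(single, axis=3).reshape(-1)
print(squeezed.shape)   # ()
print(flattened.shape)  # (1,)
```

With a single row, .squeeze() collapses the result to a 0-d array, while reshape(-1) keeps it 1-D. Which behaviour is preferable depends on what the downstream code expects.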

Upvotes: 0
