relot

PyTorch: explain torch.argmax

Hello, I have the following code:

import torch
x = torch.zeros(1,8,4,576) # create a 4 dimensional tensor
x[0,4,2,333] = 1.0 # put a 1 at an arbitrary spot

# I want to find the index of the highest value (0,4,2,333)
print(x.argmax()) # this should return the index

This returns

tensor(10701)

How does 10701 make sense?

How do I get the actual indices (0, 4, 2, 333)?

Answers (1)

dannyadam

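The elements of the 4-dimensional tensor can be viewed as one long flat sequence in row-major order (the last dimension varies fastest), and argmax() called without a dim argument returns the index into this flat representation. For your shape (1, 8, 4, 576), that index is 4*(4*576) + 2*576 + 333 = 10701.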
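To see where 10701 comes from, here is a minimal plain-Python sketch that computes the row-major flat index by hand. The helper name ravel_index is hypothetical (it is not a PyTorch function); it just multiplies the running total by each dimension's size and adds that dimension's index.

```python
def ravel_index(index, shape):
    """Convert a multi-dimensional index into a flat row-major (C-order) index.

    Hypothetical helper for illustration, not part of PyTorch.
    """
    flat = 0
    for i, dim in zip(index, shape):
        # Scale the running total by the current dimension's size, then add its index
        flat = flat * dim + i
    return flat

shape = (1, 8, 4, 576)
print(ravel_index((0, 4, 2, 333), shape))  # 10701
```

The first dimension has size 1, so it never changes the result; only the last three indices contribute.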

NumPy has a function, np.unravel_index, for unraveling the index (converting the flat index back to the corresponding multi-dimensional indices):

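import numpy as np
# Convert flat index 10701 back to indices for an array of shape (1, 8, 4, 576)
idx = np.unravel_index(10701, (1,8,4,576))  # indices 0, 4, 2, 333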
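If you would rather not pull in NumPy, the reverse conversion is just repeated divmod, peeling off the last (fastest-varying) dimension first. This is a sketch with a hypothetical helper name, unravel_index, written in plain Python:

```python
def unravel_index(flat, shape):
    """Convert a flat row-major index back into a multi-dimensional index.

    Hypothetical helper for illustration, mirroring what np.unravel_index does.
    """
    index = []
    # Walk the dimensions from last (fastest-varying) to first
    for dim in reversed(shape):
        flat, i = divmod(flat, dim)
        index.append(i)
    # Indices were collected last-dimension-first, so flip them back
    return tuple(reversed(index))

print(unravel_index(10701, (1, 8, 4, 576)))  # (0, 4, 2, 333)
```

With a tensor x you would call it as unravel_index(x.argmax().item(), tuple(x.shape)). Recent PyTorch releases also include torch.unravel_index; check the documentation for your installed version before relying on it.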
