H.Jamil

Reputation: 63

How can I find multiple maximum indices of a torch tensor?

If I have a tensor that contains the maximum value more than once, how can I get all the indices of that maximum value? I have tried torch.argmax(tensor), but it only gives me the first index.

>>> a_list = [3,23,53,32,53]
>>> a_tensor = torch.Tensor(a_list)
>>> a_tensor
tensor([ 3., 23., 53., 32., 53.])
>>> torch.max(a_tensor)
tensor(53.)
>>> torch.argmax(a_tensor)
tensor(2)

I have the following function to do it, but I was wondering if there are more efficient approaches:

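def max_tensor_indices(tensor_t, max_value):
    # Assumes tensor_t has shape (1, N); only the first row is scanned.
    tensor_list = tensor_t[0]
    indices_list = []
    for i in range(len(tensor_list)):
        # Record every position whose value equals the maximum.
        if tensor_list[i] == max_value:
            indices_list.append(i)
    return indices_list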
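For comparison, here is a minimal loop-free sketch with the same inputs and the same list-of-ints result. It keeps the loop's assumption that tensor_t has shape (1, N) and only the first row is searched. The name max_tensor_indices_vectorized is hypothetical, chosen only for this illustration.

```python
import torch

def max_tensor_indices_vectorized(tensor_t, max_value):
    # Same assumption as the loop version: tensor_t has shape (1, N)
    # and only its first row is searched.
    row = tensor_t[0]
    # Element-wise comparison gives a bool tensor; nonzero(as_tuple=True)
    # returns one index tensor per dimension (a single one for a 1-D row).
    (indices,) = (row == max_value).nonzero(as_tuple=True)
    # Convert to a plain Python list to match the loop's return type.
    return indices.tolist()

t = torch.tensor([[3., 23., 53., 32., 53.]])
print(max_tensor_indices_vectorized(t, t.max()))  # [2, 4]
```

The comparison and index extraction run inside PyTorch rather than as a Python-level loop over elements, which is where the speedup for large tensors would come from.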

Upvotes: 1

Views: 1589

Answers (1)

Mateen Ulhaq

Reputation: 27201

Find the maximum value, then find all elements with that value.

(x == torch.max(x)).nonzero()

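Note: nonzero can also be called with as_tuple=True, which returns a tuple of 1-D index tensors (one per dimension) instead of a single 2-D tensor of coordinates; that tuple can be used directly to index x.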
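A short sketch of both forms on the data from the question, plus a 2-D case. The small 2-D tensor m is made up for illustration. The results are converted with tolist() so the printed values are easy to read. torch.where with a single condition argument behaves like nonzero(as_tuple=True).

```python
import torch

x = torch.tensor([3., 23., 53., 32., 53.])
mask = x == torch.max(x)  # True wherever x equals its maximum

# Default form: a (k, ndim) tensor of coordinates, here shape (2, 1).
print(mask.nonzero().tolist())               # [[2], [4]]

# as_tuple=True: one 1-D index tensor per dimension.
(idx,) = mask.nonzero(as_tuple=True)
print(idx.tolist())                          # [2, 4]

# The tuple form can index the original tensor directly.
print(x[mask.nonzero(as_tuple=True)].tolist())  # [53.0, 53.0]

# torch.where(condition) gives the same tuple as nonzero(as_tuple=True).
print(torch.where(mask)[0].tolist())         # [2, 4]

# 2-D case: each row of the default output is a (row, col) coordinate.
m = torch.tensor([[1, 7], [7, 0]])
print((m == m.max()).nonzero().tolist())     # [[0, 1], [1, 0]]
rows, cols = (m == m.max()).nonzero(as_tuple=True)
print(rows.tolist(), cols.tolist())          # [0, 1] [1, 0]
```

Exact equality is safe here even for floating-point tensors, because the maximum is itself one of the elements being compared.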

Upvotes: 2
