Reputation: 63
If I have a tensor that contains the maximum value more than once, how can I get all the indices of that maximum value? I have tried torch.argmax(tensor), but it only gives me the first index.
>>> a_list = [3,23,53,32,53]
>>> a_tensor = torch.Tensor(a_list)
>>> a_tensor
tensor([ 3., 23., 53., 32., 53.])
>>> torch.max(a_tensor)
tensor(53.)
>>> torch.argmax(a_tensor)
tensor(2)
I have the following function to do it, but I was wondering whether there are more efficient approaches:
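def max_tensor_indices(tensor_t, max_value):
    # Assumes tensor_t has shape (1, N): take its single row
    tensor_list = tensor_t[0]
    indices_list = []
    # Scan every element and record the positions equal to max_value
    for i in range(len(tensor_list)):
        if tensor_list[i] == max_value:
            indices_list.append(i)
    return indices_list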
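As written, the loop above expects a (1, N) tensor and a precomputed max_value, so calling it on the 1-D a_tensor from the session would fail at len(). A minimal loop-based sketch that accepts any shape, assuming flat (row-major) indices are acceptable, could look like this; max_indices_loop is a hypothetical name, not a PyTorch function:

```python
import torch

def max_indices_loop(tensor_t):
    # Hypothetical helper: flatten so 1-D, (1, N), or any other shape works;
    # returned indices are positions in the flattened tensor.
    values = tensor_t.flatten().tolist()
    max_value = max(values)
    # Collect every position whose value equals the maximum
    return [i for i, v in enumerate(values) if v == max_value]

a_tensor = torch.tensor([3., 23., 53., 32., 53.])
print(max_indices_loop(a_tensor))  # [2, 4]
```

This still runs a Python-level loop over every element, so it is mainly useful for small tensors; a vectorized comparison avoids that loop entirely.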
Upvotes: 1
Views: 1589
Reputation: 27201
Find the maximum value, then find all elements with that value.
(x == torch.max(x)).nonzero()
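Applied to the tensor from the question, the one-liner could be used like this (a small sketch; the intermediate names mask and indices are only for illustration):

```python
import torch

a_tensor = torch.tensor([3., 23., 53., 32., 53.])
# Boolean mask marking every element equal to the maximum
mask = a_tensor == torch.max(a_tensor)
# nonzero() returns a 2-D tensor of shape (num_matches, ndim),
# here a (2, 1) tensor holding the indices 2 and 4
indices = mask.nonzero()
# For a 1-D input, flatten to get a plain 1-D tensor of indices
print(indices.flatten())  # tensor([2, 4])
```

Because the comparison and nonzero both run as tensor operations, there is no Python loop over the elements.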
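Note: nonzero may also be called with as_tuple=True, which returns a tuple of 1-D index tensors (one per dimension) instead of a single 2-D tensor; this may be more convenient, for example for indexing.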
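For a multi-dimensional tensor, the as_tuple=True form gives one index tensor per dimension. A short sketch with an arbitrary 2-D example tensor (the values in x are made up for illustration):

```python
import torch

x = torch.tensor([[1, 7, 3],
                  [7, 0, 7]])
# With as_tuple=True, nonzero returns one 1-D index tensor per dimension,
# listed in row-major order of the matching elements
rows, cols = (x == x.max()).nonzero(as_tuple=True)
print(rows.tolist())  # [0, 1, 1]
print(cols.tolist())  # [1, 0, 2]
# The tuple can be used directly for advanced indexing
print(x[rows, cols].tolist())  # [7, 7, 7]
```

torch.where(condition) with a single argument is documented as equivalent to nonzero(as_tuple=True), so either spelling works here.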
Upvotes: 2