Reputation: 813
With Python lists, we can do:
a = [1, 2, 3]
assert a.index(2) == 1
How can I find the .index() of a value in a PyTorch tensor directly?
Upvotes: 71
Views: 212342
Reputation: 6044
import torch
a = torch.tensor([1, 2, 3])
torch.where(a == 2)[0]
# tensor([1])
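With a single argument, torch.where(condition) behaves like nonzero(as_tuple=True): it returns one index tensor per dimension, which is why [0] is needed above. A minimal sketch on a small 2-D tensor (the tensor m is illustrative) shows the row and column tensors you get back:

```python
import torch

# torch.where(condition) returns a tuple with one index tensor per dimension
m = torch.tensor([[1, 2], [2, 3]])
rows, cols = torch.where(m == 2)
print(rows)  # tensor([0, 1])
print(cols)  # tensor([1, 0])
```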
Upvotes: 3
Reputation: 4083
In my opinion, calling tolist() is simple and easy to understand:
import torch
t = torch.Tensor([1, 2, 3])
t.tolist().index(2)  # -> 1
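Because this goes through a plain Python list, it inherits list.index() behaviour: it raises ValueError when the value is absent, and tolist() copies the whole tensor into Python objects, so it suits small tensors best. A minimal sketch with a fallback value (index_or_default is a hypothetical helper name, not a PyTorch API):

```python
import torch

t = torch.tensor([1, 2, 3])
values = t.tolist()  # plain Python list: [1, 2, 3]

def index_or_default(values, value, default=-1):
    # list.index() raises ValueError when value is absent, so fall back to a default
    try:
        return values.index(value)
    except ValueError:
        return default

print(index_or_default(values, 2))  # 1
print(index_or_default(values, 9))  # -1
```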
Upvotes: 1
Reputation: 5487
import torch
x = torch.Tensor([11, 22, 33, 22])
print((x == 22).nonzero().squeeze())
# tensor([1, 3])
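One subtlety with squeeze(): it drops every size-1 dimension, so when there is exactly one match the result collapses to a 0-d tensor instead of a 1-element list. A sketch of the difference, assuming you always want a 1-D result (flatten() is swapped in here as an alternative):

```python
import torch

x = torch.tensor([11, 22, 33, 22])
y = torch.tensor([11, 22, 33])

# squeeze() on a single match gives a 0-d tensor
print((y == 22).nonzero().squeeze())  # tensor(1)
# flatten() always returns a 1-D tensor, however many matches there are
print((x == 22).nonzero().flatten())  # tensor([1, 3])
print((y == 22).nonzero().flatten())  # tensor([1])
```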
Upvotes: 4
Reputation: 5247
The answers already given are great, but they don't handle the case where there is no match. For that, see this:
import torch
from torch import Tensor
from typing import Union

def index(tensor: Tensor, value, ith_match: int = 0) -> Union[Tensor, int]:
    """
    Returns the generalized index (i.e. location/coordinate) of an occurrence of value
    in tensor. For flat tensors (i.e. arrays/lists) it returns the index of the occurrence
    of the value you are looking for. Otherwise, it returns the "index" as a coordinate.
    If there are multiple occurrences, choose which one you want with ith_match,
    e.g. ith_match=0 gives the first occurrence. Returns -1 if there is no match.
    Reference: https://stackoverflow.com/a/67175757/1601580
    """
    # get matches as a "coordinate list" of where value occurs
    matches = (tensor == value).nonzero()  # [number_of_matches, tensor_dimension]
    if matches.size(0) == 0:  # no matches
        return -1
    # get index/coordinate of the occurrence you want (e.g. 1st occurrence: ith_match=0)
    return matches[ith_match]
Credit to this great answer: https://stackoverflow.com/a/67175757/1601580
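If you only ever need the first match, a different technique avoids materializing all matches: take argmax over the boolean mask, since torch.argmax returns the first maximal position when there are ties. This is a sketch under that assumption, not the answer's method; first_index is a hypothetical name, and the mask is cast to int because argmax may not accept bool tensors:

```python
import torch

def first_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper: flat index of the first occurrence of value, or -1."""
    mask = (t == value).flatten()
    if not mask.any():  # no match at all
        return -1
    # argmax picks the first position holding the maximum (True -> 1)
    return int(mask.int().argmax())

print(first_index(torch.tensor([4, 7, 7, 1]), 7))  # 1
print(first_index(torch.tensor([4, 7, 7, 1]), 9))  # -1
```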
Upvotes: 2
Reputation: 81
Based on others' answers:
import torch
t = torch.Tensor([1, 2, 3])
print((t == 1).nonzero().item())  # 0
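.item() only works on a one-element tensor, so this fails when the value is missing or appears more than once. A minimal guarded sketch for 1-D tensors (single_index is a hypothetical helper name) that returns None in those cases:

```python
import torch

def single_index(t: torch.Tensor, value):
    """Hypothetical helper: index of value if it occurs exactly once in a 1-D tensor, else None."""
    hits = (t == value).nonzero()  # shape [number_of_matches, 1] for a 1-D tensor
    # .item() needs exactly one element, so guard against 0 or 2+ matches
    if hits.numel() != 1:
        return None
    return hits.item()

t = torch.tensor([1, 2, 3, 2])
print(single_index(t, 1))  # 0
print(single_index(t, 2))  # None (two matches)
print(single_index(t, 9))  # None (no match)
```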
Upvotes: 2
Reputation: 2646
For multidimensional tensors you can do:
(tensor == target_value).nonzero(as_tuple=False)
The resulting tensor will be of shape number_of_matches x tensor_dimension. For example, if tensor is a 3 x 4 tensor (that means the dimension is 2), the result will be a 2-D tensor with the coordinates of one match in each row.
import torch
tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
# tensor([[0, 1],
#         [0, 2],
#         [1, 2]])
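For comparison, nonzero(as_tuple=True) returns one 1-D index tensor per dimension instead of a matches-by-dimensions matrix. A sketch on the same values (stored here as an integer tensor for illustration) shows that the tuple can be fed straight back into advanced indexing:

```python
import torch

tensor = torch.tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
# as_tuple=True: one index tensor per dimension
rows, cols = (tensor == 2).nonzero(as_tuple=True)
print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([1, 2, 2])
# the tuple works directly for advanced indexing
print(tensor[rows, cols])  # tensor([2, 2, 2])
```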
Upvotes: 26
Reputation: 2751
I think there is no direct translation from list.index() to a PyTorch function. However, you can achieve similar results using tensor == number and then the nonzero() function. For example:
import torch
t = torch.Tensor([1, 2, 3])
print((t == 2).nonzero(as_tuple=True)[0])
This piece of code prints
tensor([1])
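Building on tensor == number plus nonzero(), here is a minimal sketch of a helper that mirrors list.index() more closely for 1-D tensors: it returns a plain Python int for the first occurrence and raises ValueError when the value is missing. The name tensor_index is hypothetical, and the sketch relies on nonzero() returning positions in ascending order:

```python
import torch

def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper: list.index()-style lookup on a 1-D tensor."""
    # 1-D tensor of the positions where t equals value, in ascending order
    hits = (t == value).nonzero(as_tuple=True)[0]
    if hits.numel() == 0:
        raise ValueError(f"{value} is not in tensor")
    return int(hits[0])  # first occurrence

a = torch.tensor([1, 2, 3])
print(tensor_index(a, 2))  # 1
```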
Upvotes: 99
Reputation: 1
To find the index of an element in a 1-D tensor/array, for example the index of 5:
import torch
mat = torch.tensor([1, 8, 5, 3])
# to find the index of 5
five = 5
numb_of_col = 4
for o in range(numb_of_col):
    if mat[o] == five:
        print(torch.tensor([o]))  # tensor([2])
To find an element's index in a 2-D/3-D tensor, convert it into 1-D first, i.e. example.view(number of elements). The printed index then refers to the flattened tensor. Example:
mat = torch.tensor([[1, 2], [4, 3]])
# to find the index of 2
target = 2
mat = mat.view(4)
numb_of_col = 4
for o in range(numb_of_col):
    if mat[o] == target:
        print(torch.tensor([o]))  # tensor([1])
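The Python loop above checks elements one at a time. A sketch of a vectorized equivalent, assuming a 2-D row-major tensor, which also maps each flat index back to (row, col) with divmod; the names target, flat_hits and n_cols are illustrative:

```python
import torch

mat = torch.tensor([[1, 2], [4, 3]])
target = 2

# compare the whole flattened tensor at once instead of looping
flat_hits = (mat.reshape(-1) == target).nonzero().flatten()
print(flat_hits)  # tensor([1])

# flat index -> (row, col) using the number of columns
n_cols = mat.size(1)
coords = [divmod(int(i), n_cols) for i in flat_hits]
print(coords)  # [(0, 1)]
```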
Upvotes: -2
Reputation: 498
For floating-point tensors, I use this to get the index of an element in the tensor:
print((torch.abs(torch.max(your_tensor).item() - your_tensor) < 0.0001).nonzero())
Here I want the index of the maximum value in the float tensor. You can also put in your own value like this to get the index of any element in the tensor:
print((torch.abs(YOUR_VALUE - your_tensor) < 0.0001).nonzero())
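The same tolerance idea can be written with torch.isclose, which tests |input - other| <= atol + rtol * |other|, and for the maximum specifically torch.argmax gives the index directly. A sketch with an illustrative tensor; the tolerance values are assumptions you would tune, and other is built with full_like so it matches the input's shape and dtype:

```python
import torch

your_tensor = torch.tensor([1.0, 2.5, 3.00001, 4.0])
value = 3.0

# exact comparison misses 3.00001
exact_hits = (your_tensor == value).nonzero().flatten()

# tolerance-based comparison finds it
close = torch.isclose(your_tensor, torch.full_like(your_tensor, value), rtol=0.0, atol=1e-4)
close_hits = close.nonzero().flatten()
print(close_hits)  # tensor([2])

# index of the maximum without any tolerance trick
print(torch.argmax(your_tensor))  # tensor(3)
```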
Upvotes: -3
Reputation: 1
import torch
from torch.autograd import Variable
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0]]))
print(x_data.data[0])
# tensor([1.])
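Note that this reads the element at a given index rather than finding the index of a given value. Since PyTorch 0.4 merged Variable into Tensor, a minimal sketch of the same lookup needs only a plain tensor:

```python
import torch

# no Variable wrapper needed in PyTorch 0.4 and later
x_data = torch.tensor([[1.0], [2.0], [3.0]])
print(x_data[0])  # tensor([1.])
```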
Upvotes: -4
Reputation: 204
This can be done by converting to NumPy as follows:
import numpy as np
import torch
x = torch.arange(1, 5, dtype=torch.float32)  # torch.range is deprecated in favor of torch.arange
print(x)
# tensor([1., 2., 3., 4.])
nx = x.numpy()
np.where(nx == 3)[0][0]
# 2
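Two practical points on the NumPy route: .numpy() only works on CPU tensors that do not require grad, and np.where(...)[0][0] raises IndexError when there is no match. A sketch that detaches and moves the tensor first and uses np.flatnonzero to get every matching index (an empty array when nothing matches):

```python
import numpy as np
import torch

x = torch.arange(1, 5, dtype=torch.float32)  # tensor([1., 2., 3., 4.])

# detach() and cpu() make .numpy() safe for grad-tracking or GPU tensors
nx = x.detach().cpu().numpy()
# all indices where the condition holds
print(np.flatnonzero(nx == 3))  # [2]
print(np.flatnonzero(nx == 9))  # []
```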
Upvotes: 0