I have two tensors, a and b.
I want to get the indices of all values in a that also appear in b.
For example:
a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
I want the indices of 1, 2, and 4 in tensor a. I can do this with the following code:
import torch
a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
mask = torch.zeros(a.shape, dtype=torch.bool)
for e in b:
    mask = mask | (a == e)  # mark every position of a equal to e
print(mask)
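The loop above produces a boolean mask rather than the index positions the question asks for. A minimal sketch of turning such a mask into indices, assuming positions are what is wanted, uses torch.nonzero; the variable names simply mirror the question's code.

```python
import torch

a = torch.Tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.Tensor([1, 2, 4])

# Build the mask as in the question, combining matches with logical OR
mask = torch.zeros(a.shape, dtype=torch.bool)
for e in b:
    mask = mask | (a == e)

# nonzero(as_tuple=True) returns one index tensor per dimension; a is 1-D
indices = torch.nonzero(mask, as_tuple=True)[0]
print(indices.tolist())  # [0, 1, 2, 4, 5, 6]
```

The same nonzero call works on any boolean mask, including the ones produced by the answers below.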
How can I do it without a for loop?
As @zaydh kindly pointed out in the comments, since PyTorch 1.10, torch.isin() (along with many other NumPy equivalents) is available, so you can simply do:
torch.isin(a, b)
which gives:
tensor([ True, True, True, False, True, True, True, False])
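On PyTorch versions older than 1.10, where torch.isin is not available, a loop-free sketch is to compare every element of a against every element of b by broadcasting and then reduce with any(). This broadcasting approach is a swapped-in technique, not part of the original answer; it builds an intermediate of shape len(a) by len(b), so memory grows with both sizes.

```python
import torch

a = torch.Tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.Tensor([1, 2, 4])

# a[:, None] has shape (8, 1); comparing it with b (shape (3,)) broadcasts to (8, 3)
matches = a[:, None] == b
# An element of a is "in b" if it matches at least one column
mask = matches.any(dim=1)
print(mask.tolist())  # [True, True, True, False, True, True, True, False]
```

For small b this is usually fine; for large inputs torch.isin is the better choice where available.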
For PyTorch versions before 1.10, is this what you want?
import numpy as np
np.in1d(a.numpy(), b.numpy())
which results in:
array([ True, True, True, False, True, True, True, False])
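NumPy 2.0 deprecates np.in1d in favor of np.isin, which also keeps the shape of its first argument instead of flattening. A minimal sketch of the same round trip with np.isin, adding np.flatnonzero to get the indices the question asks for and converting the mask back to a tensor:

```python
import numpy as np
import torch

a = torch.Tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.Tensor([1, 2, 4])

# np.isin returns a boolean array with the same shape as its first argument
mask = np.isin(a.numpy(), b.numpy())
# flatnonzero gives the positions of the True entries
indices = np.flatnonzero(mask)
print(indices.tolist())  # [0, 1, 2, 4, 5, 6]

# Convert back to a tensor if the rest of the code expects one
mask_tensor = torch.from_numpy(mask)
```

Note that .numpy() only works for CPU tensors that do not require grad; other tensors need .detach().cpu() first.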
If you just want to avoid an explicit for loop, you can use a list comprehension (the values of b are floats, so they must be cast to int before being used as indices):
mask = [a[int(index)] for index in b]
If you don't want to use the word "for" at all, you can convert the tensors to NumPy and use NumPy indexing:
mask = torch.tensor(a.numpy()[b.numpy().astype(int)])
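Both snippets above read a at the positions listed in b (a gather), which is different from testing membership. If that gather is what is wanted, PyTorch can index directly with an integer tensor, so the NumPy round trip is not needed; b has to be cast to an integer dtype first because torch.Tensor creates float32 values.

```python
import torch

a = torch.Tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.Tensor([1, 2, 4])

# Index tensors must be long (or bool), so cast the float values of b
gathered = a[b.long()]
print(gathered.tolist())  # [2.0, 2.0, 4.0]
```

This is equivalent to torch.index_select(a, 0, b.long()) for a 1-D tensor.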
UPDATE
I might have misunderstood your question. In that case, I would say the best way to achieve this is with a list comprehension (slicing will probably not achieve this):
b_values = set(b.tolist())  # convert b once, outside the loop
mask = [index for index, value in enumerate(a.tolist()) if value in b_values]
This iterates over every element of a together with its index, and keeps the index whenever the value also appears in b (here [0, 1, 2, 4, 5, 6]).