Ausrada404

Reputation: 599

PyTorch tensor: how to get the indices of elements with specific values?

I have two tensors, a and b.

I want to get the indices of all elements of a whose values appear in b.

For example:

a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])

I want the indices of 1, 2, and 4 in tensor a. I can do this with the following code:

import torch

a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
# Start with an all-False mask, then OR in the matches for each value of b.
mask = torch.zeros(a.shape, dtype=torch.bool)
print(mask)
for e in b:
    mask = mask | (a == e)
    print(mask)

How can I do this without a for loop?

Upvotes: 2

Views: 6321

Answers (2)

Hossein

Reputation: 25924

Update:

As @zaydh kindly pointed out in the comments, torch.isin() (along with many other NumPy equivalents) has been available since PyTorch 1.10, so you can simply do:

torch.isin(a, b)

which gives you:

Out[4]: tensor([ True,  True,  True, False,  True,  True,  True, False])
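
Since the question asks for indices rather than a boolean mask, here is a minimal sketch of one way to convert the mask into positions, assuming PyTorch 1.10 or later. The variable names mask and indices are only illustrative.

```python
import torch

a = torch.tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.tensor([1, 2, 4])

# Boolean mask: True where the element of a appears in b (needs PyTorch >= 1.10).
mask = torch.isin(a, b)

# Positions of the True entries, flattened to a 1-D tensor of indices.
indices = torch.nonzero(mask).flatten()
print(indices.tolist())  # [0, 1, 2, 4, 5, 6]
```

torch.nonzero returns one row per match, so flatten() is only appropriate here because a is one-dimensional.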

Old answer:

Is this what you want?

import numpy as np
np.in1d(a.numpy(), b.numpy())

This results in:

array([ True,  True,  True, False,  True,  True,  True, False])
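
If you are on a PyTorch version older than 1.10 and would rather not round-trip through NumPy, a broadcasting comparison is one possible alternative. This is only a sketch, and it builds an intermediate tensor of shape (len(a), len(b)), so it may be too memory-hungry for large inputs.

```python
import torch

a = torch.tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.tensor([1, 2, 4])

# Compare every element of a with every element of b: shape (len(a), len(b)).
matches = a.unsqueeze(1) == b

# An element of a is selected if it equals at least one element of b.
mask = matches.any(dim=1)
print(mask.tolist())  # [True, True, True, False, True, True, True, False]
```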

Upvotes: 4

pregenRobot

Reputation: 307

If you just want to avoid an explicit for loop, you can use a list comprehension (this treats the values of b as positions in a, so they must be cast to integers):

mask = [a[int(index)] for index in b]

If you do not even want to use the word "for", you can convert the tensors to NumPy and use NumPy indexing:

mask = torch.tensor(a.numpy()[b.numpy().astype(int)])

UPDATE

I might have misunderstood your question. In that case, I would say the best way to achieve this is a list comprehension (slicing will probably not achieve this):

mask = [index for index, value in enumerate(a) if value in b.tolist()]

This iterates over every element of a together with its index and keeps the index whenever the value occurs in b.
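
If you also need to know which value of b each index belongs to, the same membership test can be written as a loop-free tensor operation. This is a rough sketch rather than a definitive solution; the names pairs, matched_values, and matched_indices are made up for illustration.

```python
import torch

a = torch.tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.tensor([1, 2, 4])

# Row i of `matches` marks where b[i] occurs in a: shape (len(b), len(a)).
matches = b.unsqueeze(1) == a

# Each row of `pairs` is (position in b, position in a) for one match.
pairs = torch.nonzero(matches)

matched_values = b[pairs[:, 0]]   # the value of b behind each match
matched_indices = pairs[:, 1]     # where that value sits in a
print(matched_values.tolist())    # [1, 2, 2, 4, 4, 4]
print(matched_indices.tolist())   # [0, 1, 2, 4, 5, 6]
```

Reading the two lists side by side, 1 is at index 0, 2 is at indices 1 and 2, and 4 is at indices 4, 5, and 6.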

Upvotes: 0
