BugsBuggy

Reputation: 159

How to get the indices of equal elements in two PyTorch tensors of different sizes?

Let's say I have two PyTorch tensors:

t_1d = torch.Tensor([6, 5, 1, 7, 8, 4, 7, 1, 0, 4, 11, 7, 4, 7, 4, 1])
t = torch.Tensor([4, 7])

I want to get the indices in t_1d where the tensor t occurs as an exact, contiguous match.

Desired output for t_1d and t: [5, 12] (the first index of each exact match)

Preferably this should run on the GPU for large tensors, so no Python loops or NumPy conversions.
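To pin down the intended semantics, here is a plain loop-based reference. It is a sketch only: `find_matches_reference` is a hypothetical helper name, and because it deliberately uses a Python loop it is useful only for checking a vectorized GPU solution on small inputs, not as the solution itself.

```python
from typing import List

import torch


def find_matches_reference(haystack: torch.Tensor, pattern: torch.Tensor) -> List[int]:
    """Loop-based reference: start indices where `pattern` occurs contiguously in `haystack`."""
    n, m = haystack.numel(), pattern.numel()
    starts = []
    for i in range(n - m + 1):
        # Compare the length-m window starting at i with the pattern.
        if torch.equal(haystack[i:i + m], pattern):
            starts.append(i)
    return starts


t_1d = torch.tensor([6, 5, 1, 7, 8, 4, 7, 1, 0, 4, 11, 7, 4, 7, 4, 1])
t = torch.tensor([4, 7])
print(find_matches_reference(t_1d, t))  # [5, 12]
```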

Upvotes: 1

Views: 1584

Answers (1)

Szymon Maszke

Reputation: 24691

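In general, we can check where each element of t is equal to the elements of t_1d, which gives one boolean mask per element of t.

After that, shift each mask back by that element's offset within t (in the general case the mask for t[k] is shifted by -k; here only the mask for t[1] needs shifting, by -1) and combine the masks with a logical AND. Positions where the result is True are the indices at which a full match starts: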

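import torch
intersection = (t_1d == t[0]) & torch.roll(t_1d == t[1], shifts=-1)
intersection[-1] = False  # torch.roll wraps around, so the last position could report a false match
torch.where(intersection)[0]  # tensor([5, 12])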
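The code above handles a two-element t. A minimal sketch of the general case, assuming 1-D tensors on the same device (both helper names are hypothetical): `match_starts_roll` ANDs one rolled mask per element of t, looping only over the pattern length and never over t_1d, while `match_starts_unfold` is a swapped-in alternative that compares all sliding windows produced by `Tensor.unfold` in one operation.

```python
import torch


def match_starts_roll(haystack: torch.Tensor, pattern: torch.Tensor) -> torch.Tensor:
    """Start indices where `pattern` occurs contiguously in 1-D `haystack` (roll-based)."""
    n, m = haystack.numel(), pattern.numel()
    if m > n:
        return torch.empty(0, dtype=torch.long, device=haystack.device)
    mask = torch.ones(n, dtype=torch.bool, device=haystack.device)
    for k in range(m):  # loops over the pattern only, never over haystack
        # Rolling by -k moves the result for haystack[i + k] back to index i,
        # so mask[i] stays True only if haystack[i + k] == pattern[k].
        mask &= torch.roll(haystack == pattern[k], shifts=-k)
    # torch.roll wraps around: the last m - 1 positions may compare against
    # elements from the start of haystack, and no match can start there anyway.
    mask[n - m + 1:] = False
    return torch.where(mask)[0]


def match_starts_unfold(haystack: torch.Tensor, pattern: torch.Tensor) -> torch.Tensor:
    """Same result via sliding windows instead of rolled masks."""
    n, m = haystack.numel(), pattern.numel()
    if m > n:
        return torch.empty(0, dtype=torch.long, device=haystack.device)
    windows = haystack.unfold(0, m, 1)  # (n - m + 1, m) view of all windows
    return torch.where((windows == pattern).all(dim=1))[0]


t_1d = torch.tensor([6, 5, 1, 7, 8, 4, 7, 1, 0, 4, 11, 7, 4, 7, 4, 1])
t = torch.tensor([4, 7])
print(match_starts_roll(t_1d, t))    # tensor([5, 12])
print(match_starts_unfold(t_1d, t))  # tensor([5, 12])
```

The roll version keeps extra memory at O(n) but launches a few small operations per pattern element; the unfold version does a single comparison but materializes an (n - m + 1) × m boolean tensor, which can matter for long patterns.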

Upvotes: 2
