I have a tensor of values val with shape (b, n) and a tensor of indexes ind with shape (b, m), where n > m. My goal is to take, for each row, the values in val that correspond to the indexes in ind. I've tried using val[ind], but it only expanded the dimensions of val rather than taking only the relevant items:
import torch

val = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12],
                    [13, 14, 15]])
ind = torch.tensor([[1, 2],
                    [0, 2],
                    [0, 1],
                    [1, 2],
                    [0, 1]])
val[ind]  # shape (5, 2, 3), but I need (5, 2)
The desired output is:
torch.tensor([[2, 3],
              [4, 6],
              [7, 8],
              [11, 12],
              [13, 14]])
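For reference, the shape of val[ind] follows from PyTorch's advanced indexing rules: a single integer tensor used as an index selects along dim 0 only, so each entry of ind picks out an entire row of val, and the result has shape ind.shape + val.shape[1:], i.e. (5, 2, 3). A minimal sketch that reproduces this, using the tensors from the question:

```python
import torch

val = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12],
                    [13, 14, 15]])
ind = torch.tensor([[1, 2],
                    [0, 2],
                    [0, 1],
                    [1, 2],
                    [0, 1]])

# One integer index tensor indexes dim 0 only: every entry of ind
# selects a whole row of val, giving shape ind.shape + val.shape[1:].
rows = val[ind]
print(rows.shape)  # torch.Size([5, 2, 3])

# ind[0] = [1, 2], so rows[0] holds rows 1 and 2 of val, not two scalars.
print(rows[0])
```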
You can perform this operation using torch.gather:
>>> val.gather(dim=1, index=ind)
tensor([[ 2,  3],
        [ 4,  6],
        [ 7,  8],
        [11, 12],
        [13, 14]])
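The gather call works for the general shapes in the question, not just this 5×3 example. A minimal sketch, assuming arbitrary sizes b, n, m (chosen here for illustration) and random data: the result always has the shape of the index tensor, (b, m). The index must be an int64 tensor, which is what torch.randint returns by default.

```python
import torch

torch.manual_seed(0)
b, n, m = 4, 6, 3  # illustrative sizes with n > m

val = torch.randn(b, n)
ind = torch.randint(0, n, (b, m))  # int64 column indexes in [0, n)

# torch.gather(val, 1, ind) is the functional form of val.gather(1, ind).
out = torch.gather(val, 1, ind)
print(out.shape)  # torch.Size([4, 3])
```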
This essentially indexes the second dimension (dim=1) of val with the values of ind. The returned tensor out has the same shape as ind, and each element is given by:
out[i][j] = val[i][ind[i][j]]
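To make the rule out[i][j] = val[i][ind[i][j]] concrete, here is a minimal sketch that computes it with explicit loops and checks the result against gather. It also shows advanced indexing with a broadcast row index, an alternative not mentioned in the answer that produces the same result for 2-D tensors. The helper names gather_loop and gather_advanced are hypothetical, chosen for illustration.

```python
import torch


def gather_loop(val, ind):
    # Hypothetical helper: literal translation of out[i][j] = val[i][ind[i][j]].
    b, m = ind.shape
    out = torch.empty(b, m, dtype=val.dtype, device=val.device)
    for i in range(b):
        for j in range(m):
            out[i, j] = val[i, ind[i, j]]
    return out


def gather_advanced(val, ind):
    # Hypothetical helper: a row-index column of shape (b, 1) broadcasts
    # against ind of shape (b, m), pairing row i with the columns ind[i].
    rows = torch.arange(val.size(0), device=val.device).unsqueeze(1)
    return val[rows, ind]


val = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12],
                    [13, 14, 15]])
ind = torch.tensor([[1, 2],
                    [0, 2],
                    [0, 1],
                    [1, 2],
                    [0, 1]])

out_gather = val.gather(dim=1, index=ind)
out_loop = gather_loop(val, ind)
out_adv = gather_advanced(val, ind)
print(torch.equal(out_gather, out_loop), torch.equal(out_gather, out_adv))  # True True
```

The loop version is only a readable reference and gets slow for large b and m; gather and the advanced-indexing form both do the whole selection in a single vectorized call.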