Given a source tensor of shape [B,N,F] and an index tensor of shape [B,k], where index[i][j] is the index of a specific feature inside source[i][j], is there a way to extract an output tensor such that:
output[i][j] = source[i][j][index[i][j]]
torch.gather requires index to have the same number of dimensions as source, but here source has one dimension more than index.
For example:
source = [
[[0.1,0.2],[0.2,0.3]],
[[0.4,0.5],[0.6,0.7]],
[[0.7,0.6],[0.8,0.9]]
]
index = [
[1,0],
[0,0],
[1,1]
]
desired_output = [
[0.2,0.2],
[0.4,0.6],
[0.6,0.9]
]
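As a point of reference, the definition output[i][j] = source[i][j][index[i][j]] can be spelled out as a plain nested loop over the example data. This is a minimal sketch that uses Python lists instead of tensors; the function name gather_last_dim_loop is hypothetical. It is slow for large B and k, but handy for checking a vectorized version.

```python
# Hypothetical reference (non-vectorized) implementation of
#   output[i][j] = source[i][j][index[i][j]]
# using nested Python lists, for checking a vectorized solution.
def gather_last_dim_loop(source, index):
    return [
        [source[i][j][index[i][j]] for j in range(len(index[i]))]
        for i in range(len(index))
    ]

source = [
    [[0.1, 0.2], [0.2, 0.3]],
    [[0.4, 0.5], [0.6, 0.7]],
    [[0.7, 0.6], [0.8, 0.9]],
]
index = [
    [1, 0],
    [0, 0],
    [1, 1],
]

print(gather_last_dim_loop(source, index))
# [[0.2, 0.2], [0.4, 0.6], [0.6, 0.9]]
```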
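For future reference, the solution is:
source.gather(2, index.unsqueeze(2)).squeeze(2)
index.unsqueeze(2) reshapes the index from [B,k] to [B,k,1] so that it has the same number of dimensions as source. gather along dim 2 then picks source[i][j][index[i][j]] into a [B,k,1] tensor, and squeeze(2) drops the trailing singleton dimension to give [B,k]. The index tensor must be of dtype int64 (LongTensor).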
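The gather one-liner can be checked on the example data from the question. The sketch below also computes the same result with advanced indexing. That is an alternative technique, not part of the answer above, and it assumes k ≤ N so that every j is a valid row of source[i].

```python
import torch

source = torch.tensor([
    [[0.1, 0.2], [0.2, 0.3]],
    [[0.4, 0.5], [0.6, 0.7]],
    [[0.7, 0.6], [0.8, 0.9]],
])                                              # shape [B, N, F] = [3, 2, 2]
index = torch.tensor([[1, 0], [0, 0], [1, 1]])  # shape [B, k] = [3, 2], int64

# gather needs index to have as many dims as source: [B, k] -> [B, k, 1]
picked = source.gather(2, index.unsqueeze(2))   # shape [B, k, 1]
output = picked.squeeze(2)                      # shape [B, k]

# Alternative (not from the answer): advanced indexing with broadcast
# row/column indices, selecting source[i, j, index[i, j]] for each (i, j).
B, k = index.shape
rows = torch.arange(B).unsqueeze(1)             # shape [B, 1]
cols = torch.arange(k).unsqueeze(0)             # shape [1, k]
output_adv = source[rows, cols, index]          # broadcasts to [B, k]

print(output)
print(torch.equal(output, output_adv))
```

Both approaches avoid a Python loop. gather is the direct fit when the index already lines up with the leading dimensions, while advanced indexing makes the (i, j) pairing explicit.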