DsCpp

Reputation: 2489

PyTorch gather 3D source with 2D index

Given a source tensor of shape [B,N,F] and an index tensor of shape [B,k], where index[i][j] is the index of a specific feature in source[i][j], is there a way to extract an output tensor such that:

output[i][j] = source[i][j][index[i][j]]

torch.gather requires index to have the same number of dimensions as source, while here source has one dimension more than index.

source = [
[[0.1,0.2],[0.2,0.3]],
[[0.4,0.5],[0.6,0.7]],
[[0.7,0.6],[0.8,0.9]]
]
index = [
[1,0],
[0,0],
[1,1]
]

desired_output = [
[0.2,0.2],
[0.4,0.6],
[0.6,0.9]
]
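To make the definition concrete, here is a minimal reference sketch that spells out output[i][j] = source[i][j][index[i][j]] as a plain Python loop over the example data. The function name gather_features_loop is hypothetical and only meant as a slow baseline to check a vectorised solution against.

```python
import torch

def gather_features_loop(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference implementation: output[i][j] = source[i][j][index[i][j]]."""
    B, k = index.shape
    output = torch.empty(B, k, dtype=source.dtype, device=source.device)
    for i in range(B):
        for j in range(k):
            # pick one feature out of the F features of source[i][j]
            output[i, j] = source[i, j, index[i, j]]
    return output

# example data from the question: source is [B, N, F] = [3, 2, 2]
source = torch.tensor([
    [[0.1, 0.2], [0.2, 0.3]],
    [[0.4, 0.5], [0.6, 0.7]],
    [[0.7, 0.6], [0.8, 0.9]],
])
# index is [B, k] = [3, 2]; Python ints give an int64 tensor
index = torch.tensor([
    [1, 0],
    [0, 0],
    [1, 1],
])

print(gather_features_loop(source, index))
```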

Upvotes: 1

Views: 1628

Answers (1)

DsCpp

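For future reference, the solution is to add a trailing dimension to index so that it has the same number of dimensions as source, gather along the feature dimension (dim 2), and then squeeze that dimension away: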

source.gather(2,index.unsqueeze(2)).squeeze(2)
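A short sketch of that one-liner applied to the example from the question, with the intermediate shapes annotated. As a cross-check it also computes the same result with advanced indexing, where broadcast batch and row indices select one feature per (i, j); that second approach is an alternative added for comparison, not part of the original answer, and the names rows_b, rows_n and output_adv are illustrative only.

```python
import torch

# example data from the question
source = torch.tensor([
    [[0.1, 0.2], [0.2, 0.3]],
    [[0.4, 0.5], [0.6, 0.7]],
    [[0.7, 0.6], [0.8, 0.9]],
])                                  # [B, N, F] = [3, 2, 2]
index = torch.tensor([
    [1, 0],
    [0, 0],
    [1, 1],
])                                  # [B, k] = [3, 2], int64 as gather requires

# gather needs an index with as many dims as source: [B, k] -> [B, k, 1]
picked = source.gather(2, index.unsqueeze(2))   # [B, k, 1]
output = picked.squeeze(2)                      # [B, k]

# alternative: advanced indexing with broadcast index tensors
B, k = index.shape
rows_b = torch.arange(B).unsqueeze(1)           # [B, 1] batch indices
rows_n = torch.arange(k).unsqueeze(0)           # [1, k] row indices
output_adv = source[rows_b, rows_n, index]      # broadcasts to [B, k]

print(output)
print(torch.equal(output, output_adv))
```

Both versions select exactly the same elements, so the choice is mostly about readability; the gather form avoids building the explicit arange index tensors.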

Upvotes: 4
