Reputation: 1825
I have a 3D tensor M of dimensions [BxLxD] and a 1D tensor idx of dimensions [B] (or [B,1]) that contains indices into the L dimension of M, each in the range [0, L-1]. I want to create a 2D tensor N of dimensions [BxD] such that N[i,j] = M[i, idx[i], j]. How can this be done efficiently?
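For reference, a direct Python loop implements this definition row by row. It is slow for large B (one Python-level indexing step per batch element), but it is handy for checking a vectorized version. The helper name get_N_loop is hypothetical, and the small arange-based M is just test data.

```python
import torch

def get_N_loop(M: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: N[i, j] = M[i, idx[i], j]."""
    B, _, D = M.shape
    idx = idx.reshape(-1)  # accept idx of shape [B] or [B, 1]
    N = torch.empty(B, D, dtype=M.dtype, device=M.device)
    for i in range(B):
        # Copy row idx[i] of the i-th [L x D] matrix
        N[i] = M[i, int(idx[i])]
    return N

M = torch.arange(2 * 4 * 2, dtype=torch.float32).reshape(2, 4, 2)
idx = torch.tensor([3, 0])
print(get_N_loop(M, idx))  # rows M[0, 3] and M[1, 0]: [[6., 7.], [8., 9.]]
```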
Example:
B,L,D = 2,4,2
M = torch.rand(B,L,D)
>
tensor([[[0.0612, 0.7385],
         [0.7675, 0.3444],
         [0.9129, 0.7601],
         [0.0567, 0.5602]],

        [[0.5450, 0.3749],
         [0.4212, 0.9243],
         [0.1965, 0.9654],
         [0.7230, 0.6295]]])
idx = torch.randint(0, L, size = (B,))
>
tensor([3, 0])
N = get_N(M, idx)
Expected output:
>
tensor([[0.0567, 0.5602],
        [0.5450, 0.3749]])
Thanks.
Upvotes: 0
Views: 769
Reputation: 2585
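import torch

def get_N(M, idx):
    # Advanced indexing on dims 0 and 1: element i of torch.arange(B) is paired
    # with idx[i], and dim 2 is kept whole, so the result has shape [B, D].
    # reshape(-1) accepts idx of shape [B] or [B, 1]; B is read from M.
    # No .squeeze(): it would also drop a dimension when B == 1 or D == 1.
    return M[torch.arange(M.size(0), device=M.device), idx.reshape(-1)]

M = torch.tensor([[[0.0612, 0.7385],
                   [0.7675, 0.3444],
                   [0.9129, 0.7601],
                   [0.0567, 0.5602]],
                  [[0.5450, 0.3749],
                   [0.4212, 0.9243],
                   [0.1965, 0.9654],
                   [0.7230, 0.6295]]])
idx = torch.tensor([3, 0])
N = get_N(M, idx)
print(N)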
result:
tensor([[0.0567, 0.5602],
        [0.5450, 0.3749]])
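This slices along the first two dimensions at once: torch.arange(B) and idx are used as paired indices, so row i of the result is M[i, idx[i], :], giving a tensor of shape [B, D].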
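An equivalent formulation uses torch.gather along dimension 1 instead of advanced indexing. This is a swapped-in technique rather than the answer's code, and the helper name get_N_gather is hypothetical. gather needs an index tensor with the same number of dimensions as M, so idx is reshaped to [B, 1, 1] and expanded to [B, 1, D].

```python
import torch

def get_N_gather(M: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: N[i, j] = M[i, idx[i], j] via torch.gather."""
    B, _, D = M.shape
    # int64 index with M's number of dims: [B] or [B, 1] -> [B, 1, 1] -> [B, 1, D]
    index = idx.reshape(B, 1, 1).expand(B, 1, D)
    # out[i, 0, j] = M[i, index[i, 0, j], j] = M[i, idx[i], j]
    return torch.gather(M, 1, index).squeeze(1)  # drop only the gathered dim -> [B, D]

M = torch.rand(2, 4, 2)
idx = torch.tensor([3, 0])
N = get_N_gather(M, idx)
print(N.shape)  # torch.Size([2, 2])
# Same values as the advanced-indexing expression
print(torch.equal(N, M[torch.arange(2), idx]))  # True
```

One shape pitfall with the indexing form: if idx has shape [B, 1] instead of [B], M[torch.arange(B), idx] broadcasts the two index tensors to [B, B] and returns a [B, B, D] tensor, which .squeeze() does not undo. Flattening idx to [B] first (e.g. idx.reshape(-1)) avoids this.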
Upvotes: 3