user1767774

PyTorch: choosing columns from a 3D tensor according to an index tensor

I have a 3D tensor M of shape [B x L x D] and an index tensor idx of shape [B] (one index per batch element) whose values index the second dimension of M, i.e. lie in the range [0, L-1]. I want to create a 2D tensor N of shape [B x D] such that N[i,j] = M[i, idx[i], j]. How can this be done efficiently?
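To pin down the definition, here is a deliberately slow loop version to compare against. It is a minimal sketch; `get_N_loop` is a hypothetical name, not part of PyTorch, and it assumes idx holds int64 indices of shape [B] (or [B, 1], which it flattens).

```python
import torch

def get_N_loop(M: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Reference implementation of N[i, j] = M[i, idx[i], j] (hypothetical helper)."""
    B, L, D = M.shape
    idx = idx.reshape(-1)  # accept idx of shape [B] or [B, 1]
    N = torch.empty(B, D, dtype=M.dtype, device=M.device)
    for i in range(B):
        # pick row idx[i] of the i-th (L x D) matrix; all D columns are copied
        N[i] = M[i, idx[i]]
    return N

M = torch.arange(16.0).reshape(2, 4, 2)
idx = torch.tensor([3, 0])
print(get_N_loop(M, idx))  # expected values: [[6., 7.], [8., 9.]]
```

Any vectorized version should return the same values as this loop.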

Example:

B,L,D = 2,4,2

M = torch.rand(B,L,D)

>

tensor([[[0.0612, 0.7385],
         [0.7675, 0.3444],
         [0.9129, 0.7601],
         [0.0567, 0.5602]],

        [[0.5450, 0.3749],
         [0.4212, 0.9243],
         [0.1965, 0.9654],
         [0.7230, 0.6295]]])


idx = torch.randint(0, L, size = (B,))

>

tensor([3, 0])

N = get_N(M, idx)

Expected output:

>

tensor([[0.0567, 0.5602], 
       [0.5450, 0.3749]])

Thanks.


Answers (1)

one

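import torch

def get_N(M, idx):
    # Integer-tensor indexing on dims 0 and 1: element i pairs batch i with row idx[i],
    # and the last dimension (D) is kept whole, so the result has shape [B, D].
    # The batch size is taken from M (not a global B), and idx is flattened so it
    # may have shape [B] or [B, 1]. No squeeze() is needed; omitting it avoids
    # accidentally dropping a dimension when B or D is 1.
    batch = torch.arange(M.size(0), device=M.device)
    return M[batch, idx.reshape(-1)]

M = torch.tensor([[[0.0612, 0.7385],
                   [0.7675, 0.3444],
                   [0.9129, 0.7601],
                   [0.0567, 0.5602]],

                  [[0.5450, 0.3749],
                   [0.4212, 0.9243],
                   [0.1965, 0.9654],
                   [0.7230, 0.6295]]])
idx = torch.tensor([3, 0])
N = get_N(M, idx)
print(N)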

result:

tensor([[0.0567, 0.5602],
        [0.5450, 0.3749]])

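This indexes the first two dimensions at once with integer tensors: the batch indices (0, 1, ..., B-1) and idx are paired element-wise, so row i of the result is M[i, idx[i], :], and N has shape [B, D].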
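An alternative that avoids building the batch-index tensor is torch.gather along dimension 1. The sketch below is one way to do it, not the only one; `get_N_gather` and `get_N_index` are hypothetical helper names. It assumes idx is an int64 tensor (gather requires that) of shape [B] or [B, 1]. Note that if idx really has shape [B, 1], as the question's wording suggests, passing it unflattened to the indexing approach broadcasts against the [B] batch indices and yields a [B, B, D] tensor, so flatten it first.

```python
import torch

def get_N_gather(M: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """N[i, j] = M[i, idx[i], j] via torch.gather (hypothetical helper)."""
    B, L, D = M.shape
    # gather needs an index with the same number of dims as M: shape it to
    # [B, 1, D] so that every column j of batch i reads row idx[i].
    index = idx.reshape(B, 1, 1).expand(B, 1, D)
    return M.gather(1, index).squeeze(1)  # [B, 1, D] -> [B, D]

def get_N_index(M: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Advanced-indexing version, flattening idx so [B, 1] also works."""
    batch = torch.arange(M.size(0), device=M.device)
    return M[batch, idx.reshape(-1)]

M = torch.rand(2, 4, 2)
idx = torch.randint(0, 4, size=(2,))
# both approaches select exactly the same elements
assert torch.equal(get_N_gather(M, idx), get_N_index(M, idx))
```

gather is more verbose for this case, but the same pattern extends to choosing a different row for each (i, j) pair by giving the index a varying last dimension.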
