Wasi Ahmad

Reputation: 37691

Indexing a 3d tensor using a 2d tensor

I have a 3d tensor, source, of shape (bsz x slen1 x nhid), and a 2d tensor, index, of shape (bsz x slen2). More specifically, I have:

source = 32 x 20 x 768
index  = 32 x 16

Each value in the index tensor lies in the range [0, 19] and is the position of the desired vector along the 2nd dimension of the source tensor.

After indexing, I expect an output tensor of shape 32 x 16 x 768.

Currently I am doing this:

bsz, _, nhid = source.size()
_, slen = index.size()

# flatten source to (bsz * slen1, nhid) and index to (bsz * slen2,),
# index the rows, then reshape back to (bsz, slen2, nhid)
source = source.reshape(-1, nhid)
source = source[index.reshape(-1), :]
source = source.reshape(bsz, slen, nhid)

So I am converting the 3d source tensor to a 2d tensor and the 2d index tensor to a 1d tensor, and then performing the indexing. Is this correct?

Is there any better way to do it?

Update

I checked, and my code does not give the expected result. To explain what I want, here is a code snippet.

source = torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820],
     [ 0.3490, -0.0198,  0.7928]],

    [[-0.0973,  2.3106, -1.8358],
     [-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])

index = torch.LongTensor([[0, 1, 2, 3], 
                          [1, 2, 3, 4]])

And I want the output tensor to be:

torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820]],

    [[-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])
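In other words, the output I want is equivalent to the following per-batch loop (just a sketch to spell out the semantics, not an efficient solution):

# For each batch b, pick the rows of source[b] listed in index[b].
expected = torch.stack([source[b, index[b]] for b in range(source.size(0))])
# expected.shape == (2, 4, 3) here, and (32, 16, 768) for the original tensors.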

Upvotes: 5

Views: 4149

Answers (3)

NND

Reputation: 1

I would like to expand on this topic because I just ran into the same issue. To also subset along the third dimension, you can use similar code:

dim_1 = torch.arange(source.shape[0]).unsqueeze(-1).unsqueeze(-1)  # (bsz, 1, 1)
dim_2 = torch.arange(source.shape[1]).unsqueeze(-1)                # (slen1, 1)
dim_3 = torch.arange(2)                                            # (2,)
subset = source[dim_1, dim_2, dim_3]                               # (bsz, slen1, 2)
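For a 2 x 5 x 3 source like the one in the question, the three index tensors broadcast to a 2 x 5 x 2 result: the first two entries along the last dimension of every row. A quick self-contained check (my own sketch, with a random source):

import torch

source = torch.randn(2, 5, 3)  # same layout as the question's example

dim_1 = torch.arange(source.shape[0]).unsqueeze(-1).unsqueeze(-1)
dim_2 = torch.arange(source.shape[1]).unsqueeze(-1)
dim_3 = torch.arange(2)

subset = source[dim_1, dim_2, dim_3]
print(subset.shape)  # torch.Size([2, 5, 2])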

Upvotes: 0

colesbury

Reputation: 136

Update:

source[torch.arange(source.shape[0]).unsqueeze(-1), index]

Note that torch.arange(source.shape[0]).unsqueeze(-1) gives:

tensor([[0],
        [1]])  # 2 x 1

and index is:

tensor([[0, 1, 2, 3],
        [1, 2, 3, 4]])  # 2 x 4

The arange indexes the batch dimension while index simultaneously indexes the slen1 dimension. The unsqueeze call adds the trailing size-1 dimension to the arange result (giving the 2 x 1 shape above) so that the two can be broadcast together.
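For completeness, the same result can also be obtained with torch.gather along dim 1 (a sketch assuming the bsz x slen1 x nhid source and bsz x slen2 index from the question; this is an alternative, not the indexing approach above):

# Expand index to (bsz, slen2, nhid) so every hidden unit is gathered,
# then gather along the slen1 dimension.
out = source.gather(1, index.unsqueeze(-1).expand(-1, -1, source.size(-1)))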

Upvotes: 6

Wasi Ahmad

Reputation: 37691

I have solved the problem. What I was missing was a per-batch offset into the flattened tensor. The following code works for me.

index = torch.LongTensor([[0, 1, 2, 3], [1, 2, 3, 4]])
offset = torch.arange(0, source.size(0) * source.size(1), source.size(1))
index = index + offset.unsqueeze(1)

source = source.reshape(-1, source.shape[-1])[index]
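A quick self-contained sketch of what the offset does (random source with the question's 2 x 5 x 3 layout; out is just a new name for the result):

import torch

source = torch.randn(2, 5, 3)
index = torch.LongTensor([[0, 1, 2, 3], [1, 2, 3, 4]])

offset = torch.arange(0, source.size(0) * source.size(1), source.size(1))
# offset is tensor([0, 5]): the flat row where each batch starts.

flat_index = index + offset.unsqueeze(1)
# flat_index is [[0, 1, 2, 3],
#                [6, 7, 8, 9]]: positions in the flattened (10, 3) view.

out = source.reshape(-1, source.shape[-1])[flat_index]
print(out.shape)  # torch.Size([2, 4, 3])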

Upvotes: 1
