marcman

Reputation: 3383

Indexing one PyTorch tensor by another using index_select

I have a 3 x 3 PyTorch LongTensor that looks something like this:

A = 
    [0, 0, 0]
    [1, 2, 2]
    [1, 2, 3]

I want to use it to index a 4 x 2 FloatTensor like this one:

B = 
    [0.4, 0.5]
    [1.2, 1.4]
    [0.8, 1.9]
    [2.4, 2.9]

My intended output is the 2 x 3 x 3 FloatTensor below:

C[0,:,:] = 
    [0.4, 0.4, 0.4]
    [1.2, 0.8, 0.8]
    [1.2, 0.8, 2.4]

C[1,:,:] =
    [0.5, 0.5, 0.5]
    [1.4, 1.9, 1.9]
    [1.4, 1.9, 2.9]

In other words, matrix A is indexing and broadcasting matrix B. A is the matrix of indices of B, so this operation is essentially an indexing operation.

How can this be done using the torch.index_select() function? If the solution involves adding or permuting dimensions, that's fine.

Upvotes: 2

Views: 3925

Answers (1)

marcman

Reputation: 3383

index_select() requires the index to be a 1-D tensor (a vector), so the index matrix has to be flattened first. Once it is, the function gathers the corresponding rows of B for you. The last step is reshaping the output, which is needed because the flattened index yields one row per index value rather than the 3 x 3 layout of A.

The one-liner that performs this operation is:

torch.index_select(B, 0, A.view(-1)).view(3,-1,2).permute(2,0,1)

A.view(-1) vectorizes the indices matrix.

__.view(3,-1,2) reshapes back to the shape of the indexing matrix, but accounting for the new extra dimension of size 2 (as I am indexing an N x 2 matrix).

Lastly, __.permute(2,0,1) reorders the dimensions (it does not reshape) so that each column of B ends up in its own leading slice, C[0] and C[1], rather than in the last dimension.
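For reference, here is a minimal runnable sketch of the steps above, using the A and B from the question. The variable names flat, C and C_alt are mine, and the comparison against plain advanced indexing (B[A]) is an addition I included only as a sanity check; it is not part of the index_select() approach itself.

```python
import torch

# Index matrix (integer literals give a LongTensor) and the 4 x 2 table to index.
A = torch.tensor([[0, 0, 0],
                  [1, 2, 2],
                  [1, 2, 3]])
B = torch.tensor([[0.4, 0.5],
                  [1.2, 1.4],
                  [0.8, 1.9],
                  [2.4, 2.9]])

# Step 1: flatten A, because index_select() expects a 1-D index.
flat = torch.index_select(B, 0, A.view(-1))      # shape (9, 2)

# Step 2: restore the 3 x 3 layout of A, keeping B's 2 columns last.
# Step 3: move the column dimension to the front.
C = flat.view(3, -1, 2).permute(2, 0, 1)         # shape (2, 3, 3)

# Sanity check (my addition): advanced indexing gives the same result.
C_alt = B[A].permute(2, 0, 1)

print(C.shape)
print(torch.equal(C, C_alt))
```

To avoid hard-coding the sizes, the reshape could also be written as flat.view(*A.shape, B.shape[1]), assuming A is contiguous and B is 2-D as in the question.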

Upvotes: 1
