I have a 3 x 3 PyTorch LongTensor that looks something like this:
A =
[0, 0, 0]
[1, 2, 2]
[1, 2, 3]
I want to use it to index a 4 x 2 FloatTensor like this one:
B =
[0.4, 0.5]
[1.2, 1.4]
[0.8, 1.9]
[2.4, 2.9]
My intended output is the 2 x 3 x 3 FloatTensor below:
C[0,:,:] =
[0.4, 0.4, 0.4]
[1.2, 0.8, 0.8]
[1.2, 0.8, 2.4]
C[1,:,:] =
[0.5, 0.5, 0.5]
[1.4, 1.9, 1.9]
[1.4, 1.9, 2.9]
In other words, matrix A is indexing and broadcasting matrix B: each entry A[i, j] selects row A[i, j] of B, so that C[k, i, j] = B[A[i, j], k]. Since A is the matrix of indices of B, this operation is essentially an indexing operation.
How can this be done using the torch.index_select() function? If the solution involves adding or permuting dimensions, that's fine.
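To pin down exactly what C should contain, here is a slow reference version with explicit loops. It is only a sketch that restates C[k, i, j] = B[A[i, j], k] in code; the name `C_ref` is illustrative, and the result is mainly useful as ground truth for checking a vectorized solution.

```python
import torch

# Index matrix A (int64, i.e. a LongTensor) and the 4 x 2 FloatTensor B from above.
A = torch.tensor([[0, 0, 0],
                  [1, 2, 2],
                  [1, 2, 3]])
B = torch.tensor([[0.4, 0.5],
                  [1.2, 1.4],
                  [0.8, 1.9],
                  [2.4, 2.9]])

# Slow reference: C_ref[k, i, j] = B[A[i, j], k], filled one element at a time.
C_ref = torch.empty(B.shape[1], A.shape[0], A.shape[1])
for k in range(B.shape[1]):
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            C_ref[k, i, j] = B[A[i, j].item(), k]

print(C_ref.shape)  # torch.Size([2, 3, 3])
```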
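Using index_select() requires the indices to be a 1-D vector rather than a multi-dimensional tensor, so A has to be flattened first. Once it is, the function takes care of the "broadcasting" for you by returning one full row of B for every index. The last thing that must be done is reshaping the output: index_select returns a 9 x 2 tensor (one row of B per entry of A), which has to be reshaped back to the 3 x 3 layout of A and then permuted.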
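A quick way to see that intermediate result is to check the shape index_select returns for the flattened index. This small sketch reuses A and B from the question; `idx` and `rows` are just illustrative names.

```python
import torch

A = torch.tensor([[0, 0, 0],
                  [1, 2, 2],
                  [1, 2, 3]])
B = torch.tensor([[0.4, 0.5],
                  [1.2, 1.4],
                  [0.8, 1.9],
                  [2.4, 2.9]])

idx = A.view(-1)                      # 1-D index vector with 9 entries
rows = torch.index_select(B, 0, idx)  # row idx[n] of B becomes row n of the result

print(rows.shape)  # torch.Size([9, 2])
```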
The one-liner that will do this operation successfully is
torch.index_select(B, 0, A.view(-1)).view(3,-1,2).permute(2,0,1)
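For reference, a self-contained version of the one-liner with the question's A and B filled in, compared against the expected C (the tensor name `expected` exists only for this check):

```python
import torch

A = torch.tensor([[0, 0, 0],
                  [1, 2, 2],
                  [1, 2, 3]])
B = torch.tensor([[0.4, 0.5],
                  [1.2, 1.4],
                  [0.8, 1.9],
                  [2.4, 2.9]])

# Gather rows of B, restore A's 3 x 3 layout with a trailing size-2 dim, move that dim first.
C = torch.index_select(B, 0, A.view(-1)).view(3, -1, 2).permute(2, 0, 1)

expected = torch.tensor([[[0.4, 0.4, 0.4],
                          [1.2, 0.8, 0.8],
                          [1.2, 0.8, 2.4]],
                         [[0.5, 0.5, 0.5],
                          [1.4, 1.9, 1.9],
                          [1.4, 1.9, 2.9]]])
print(torch.equal(C, expected))  # True
```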
A.view(-1) flattens the index matrix into a 1-D vector (in row-major order).
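Flattening walks A row by row, which is what lets the later reshape put each gathered row back in the right place. One caveat worth knowing: view() only works when the tensor's memory layout allows it (A here is contiguous), while reshape(-1) copies the data when necessary. A short sketch; the transposed case is only there to illustrate that difference.

```python
import torch

A = torch.tensor([[0, 0, 0],
                  [1, 2, 2],
                  [1, 2, 3]])

flat = A.view(-1)       # row-major: row 0, then row 1, then row 2
print(flat)             # tensor([0, 0, 0, 1, 2, 2, 1, 2, 3])

# A transposed tensor is not contiguous, so view(-1) would fail on it;
# reshape(-1) copies the data when necessary and works either way.
flat_t = A.t().reshape(-1)
print(flat_t)           # tensor([0, 1, 1, 0, 2, 2, 0, 2, 3])
```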
__.view(3,-1,2) reshapes the result back to the shape of the indexing matrix, while accounting for the new extra dimension of size 2 (as I am indexing an N x 2 matrix).
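To make the reshape step concrete, this sketch checks that after view(3, -1, 2) each position [i, j] holds the full row of B selected by A[i, j]; `rows` and `grid` are illustrative names.

```python
import torch

A = torch.tensor([[0, 0, 0],
                  [1, 2, 2],
                  [1, 2, 3]])
B = torch.tensor([[0.4, 0.5],
                  [1.2, 1.4],
                  [0.8, 1.9],
                  [2.4, 2.9]])

rows = torch.index_select(B, 0, A.view(-1))  # shape (9, 2)
grid = rows.view(3, -1, 2)                    # shape (3, 3, 2); -1 is inferred as 3

# Each grid[i, j] is the full row of B selected by A[i, j].
for i in range(3):
    for j in range(3):
        assert torch.equal(grid[i, j], B[A[i, j].item()])

print(grid.shape)  # torch.Size([3, 3, 2])
```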
Lastly, __.permute(2,0,1) moves that size-2 dimension to the front, so that each column of B ends up in its own channel C[k,:,:] (rather than along the last dimension).
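The 3 and 2 in view(3,-1,2) are specific to this example. Below is a sketch of a shape-agnostic variant, assuming A is 2-D and B is N x M, that reads the sizes from the tensors instead; the function name `index_channels_first` is hypothetical. As a cross-check it compares against plain advanced indexing, B[A], which is a different technique (not index_select) that produces the same (3, 3, M) intermediate directly.

```python
import torch

def index_channels_first(B: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: returns C with C[k, i, j] = B[A[i, j], k]."""
    rows = torch.index_select(B, 0, A.reshape(-1))         # (A.numel(), M)
    grid = rows.view(A.shape[0], A.shape[1], B.shape[1])   # (rows of A, cols of A, M)
    return grid.permute(2, 0, 1)                           # (M, rows of A, cols of A)

A = torch.tensor([[0, 0, 0],
                  [1, 2, 2],
                  [1, 2, 3]])
B = torch.tensor([[0.4, 0.5],
                  [1.2, 1.4],
                  [0.8, 1.9],
                  [2.4, 2.9]])

C = index_channels_first(B, A)
C_adv = B[A].permute(2, 0, 1)  # advanced indexing: B[A] already has shape (3, 3, 2)

print(C.shape)                # torch.Size([2, 3, 3])
print(torch.equal(C, C_adv))  # True
```

Because the sizes come from A and B, the same helper also handles a non-square index matrix, e.g. a 1 x 2 A gives a 2 x 1 x 2 result.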