Alex Pi

Reputation: 826

Select on second dimension on a 3D pytorch tensor with an array of indexes

I am fairly new to NumPy and PyTorch, and I am struggling with what seem to me like the most basic operations.

For instance, given this tensor:

A = tensor([[[6, 3, 8, 3],
         [1, 0, 9, 9]],

        [[4, 9, 4, 1],
         [8, 1, 3, 5]],

        [[9, 7, 5, 6],
         [3, 7, 8, 1]]])

And this other tensor:

B = tensor([1, 0, 1])

I would like to use B as indices into the second dimension of A, taking row B[i] from A[i], so that I get a 3 by 4 tensor that looks like this:

[[1, 0, 9, 9],
 [4, 9, 4, 1],         
 [3, 7, 8, 1]]
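In other words, row i of the result should be A[i, B[i]]. A minimal loop-based sketch of that definition (slow for large tensors, but useful as a reference to check any vectorized answer against):

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# For each position i along the first dimension, take row B[i] of A[i],
# then stack the selected rows into a (len(B), 4) tensor.
expected = torch.stack([A[i, int(B[i])] for i in range(len(B))])
print(expected)
```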

Thanks!

Upvotes: 1

Views: 1302

Answers (2)

Ivan

Reputation: 40618

Alternatively, you can use torch.gather:

>>> indexer = B.view(-1, 1, 1).expand(-1, -1, 4)
>>> indexer
tensor([[[1, 1, 1, 1]],

        [[0, 0, 0, 0]],

        [[1, 1, 1, 1]]])

>>> A.gather(1, indexer).view(len(B), -1)
tensor([[1, 0, 9, 9],
        [4, 9, 4, 1],
        [3, 7, 8, 1]])
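The same torch.gather approach as a self-contained script, for reference. This is a sketch: the helper name select_rows_gather is hypothetical, it reads the last dimension from A.shape instead of hard-coding 4, and it uses squeeze(1) in place of view(len(B), -1) to drop the singleton dimension.

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

def select_rows_gather(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[i] = A[i, B[i]] for A of shape (N, M, K)."""
    n, _, k = A.shape
    # gather wants an int64 index with as many dims as A: shape (N, 1, K),
    # where index[i, 0, j] = B[i] for every column j.
    index = B.long().view(n, 1, 1).expand(n, 1, k)
    # Along dim 1: out[i, 0, j] = A[i, index[i, 0, j], j] = A[i, B[i], j].
    return A.gather(1, index).squeeze(1)

out = select_rows_gather(A, B)
print(out)
```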

Upvotes: 0

Alex Pi

Reputation: 826

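OK, my mistake was to assume that this:

A[:, B]

is equivalent to this:

A[[0, 1, 2], B]

It is not: A[:, B] picks rows B for every element of the first dimension (a result of shape (3, 3, 4)), whereas A[[0, 1, 2], B] pairs each first index with the matching entry of B.

More generally, the solution I wanted is: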

A[range(B.shape[0]), B]
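A small sketch comparing the shapes of the two expressions. It builds the first index with torch.arange (placed on B's device) rather than range, and checks that the paired result is the "diagonal" of the slice-based one, since A[:, B][i, j] is A[i, B[j]]:

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# Slice + index tensor: B is applied to every element of dim 0.
every_row = A[:, B]
print(every_row.shape)  # torch.Size([3, 3, 4])

# Two index tensors: they are paired element-wise,
# so row i of the result is A[i, B[i]].
rows = torch.arange(B.shape[0], device=B.device)
paired = A[rows, B]
print(paired.shape)  # torch.Size([3, 4])

# every_row[i, j] == A[i, B[j]], so the pairs (i, i) give the wanted rows.
assert torch.equal(paired, every_row[rows, rows])
```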

Upvotes: 2
