Reputation: 826
I am fairly new to NumPy and PyTorch, and I am struggling to understand what seem to me to be the most basic operations.
For instance, given this tensor:
A = tensor([[[6, 3, 8, 3],
             [1, 0, 9, 9]],
            [[4, 9, 4, 1],
             [8, 1, 3, 5]],
            [[9, 7, 5, 6],
             [3, 7, 8, 1]]])
And this other tensor:
B = tensor([1, 0, 1])
I would like to use B to index the second dimension of A, picking row B[i] out of each A[i], so that I get a 3 by 4 tensor that looks like this:
[[1, 0, 9, 9],
 [4, 9, 4, 1],
 [3, 7, 8, 1]]
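To pin down the intended semantics, here is a plain-loop sketch, assuming `tensor` is `torch.tensor`: row `i` of the result is `A[i, B[i]]`. It is slow for large tensors but handy as a reference to check vectorized versions against.

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# For each batch index i, pick row B[i] out of the 2x4 matrix A[i]
rows = [A[i, int(B[i])] for i in range(len(B))]
expected = torch.stack(rows)  # stack the three length-4 rows -> shape (3, 4)
print(expected)
```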
Thanks!
Upvotes: 1
Views: 1302
Reputation: 40618
Alternatively, you can use torch.gather:
>>> indexer = B.view(-1, 1, 1).expand(-1, -1, 4)
>>> indexer
tensor([[[1, 1, 1, 1]],

        [[0, 0, 0, 0]],

        [[1, 1, 1, 1]]])
>>> A.gather(1, indexer).view(len(B), -1)
tensor([[1, 0, 9, 9],
        [4, 9, 4, 1],
        [3, 7, 8, 1]])
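A self-contained version of the gather approach, as a sketch: the hard-coded 4 is replaced by `A.size(2)` so it adapts to the row length, and `.squeeze(1)` is used to drop the singleton dimension, which is equivalent here to `.view(len(B), -1)`.

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# Reshape B to (3, 1, 1), then broadcast the last dim to the row length,
# giving shape (3, 1, 4) with index[i, 0, k] == B[i] for every column k
index = B.view(-1, 1, 1).expand(-1, 1, A.size(2))

# gather along dim 1: out[i, 0, k] = A[i, index[i, 0, k], k] = A[i, B[i], k]
out = A.gather(1, index).squeeze(1)  # drop the singleton dim -> (3, 4)
print(out)
```

`gather` requires the index to have the same number of dimensions as the input, which is why B is reshaped to three dimensions first.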
Upvotes: 0
Reputation: 826
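OK, my mistake was assuming that this:
A[:, B]
is equivalent to this:
A[[0, 1, 2], B]
It is not: A[:, B] applies all of B to the second dimension of every matrix in A and returns a tensor of shape (3, 3, 4), whereas A[[0, 1, 2], B] pairs each index in [0, 1, 2] with the corresponding entry of B. More generally, the solution I wanted is:
A[range(B.shape[0]), B]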
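The difference between the two expressions shows up clearly in their shapes. This sketch builds the row index with `torch.arange` rather than `range`, a common choice that keeps the index as a tensor; the pairing behaviour is the same.

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# Slice + index tensor: B is applied to dim 1 of *every* batch,
# so each A[i] contributes len(B) rows
wrong = A[:, B]
print(wrong.shape)  # torch.Size([3, 3, 4])

# Two index tensors are paired element-wise: right[i] = A[i, B[i]]
right = A[torch.arange(len(B)), B]
print(right.shape)  # torch.Size([3, 4])
```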
Upvotes: 2