Reputation: 439
I have a PyTorch tensor b with shape torch.Size([10, 10, 51]). For each of the 10 entries along dimension 0, I want to select one of the 10 possible elements along dimension d=1 (the middle one), using a NumPy array of indices such as a = np.array([0,1,2,3,4,5,6,7,8,9]). This is just an arbitrary example.
I wanted to do:
b[:,a,:]
but that isn't working
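Assuming the goal is one dim-1 entry per row of dimension 0 (which is what the forum solution below computes), a plain loop pins down the expected result. It is a slow reference sketch for checking answers, not a recommended implementation; b is filled with random values.

```python
import numpy as np
import torch

b = torch.rand(10, 10, 51)
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Reference (slow) version of the intended selection:
# for row i along dim 0, keep only entry a[i] along dim 1.
wanted = torch.stack([b[i, int(a[i])] for i in range(b.size(0))])
print(wanted.shape)  # torch.Size([10, 51])
```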
Upvotes: 0
Views: 1546
Reputation: 439
I found the solution on the PyTorch forum (https://discuss.pytorch.org/t/how-to-select-specific-vector-in-3d-tensor-beautifully/37724):
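import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])
idx = torch.tensor([1, 2])
# Pair row i of dim 0 with idx[i] in dim 1: picks x[0, 1] and x[1, 2]
x[torch.arange(x.size(0)), idx]
# tensor([[ 4,  5,  6],
#         [17, 18, 19]])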
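Applied to the shapes from the question, the same trick might look like the sketch below. It assumes a holds one valid dim-1 index per row of b (so len(a) == b.size(0)) and converts it with torch.as_tensor. The torch.gather variant is an alternative added here for comparison, not something taken from the forum thread.

```python
import numpy as np
import torch

# Shapes taken from the question; the values are random.
b = torch.rand(10, 10, 51)
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Convert the NumPy index array to an int64 tensor.
idx = torch.as_tensor(a, dtype=torch.long)

# Advanced indexing: row i of dim 0 is paired with idx[i] in dim 1.
rows = torch.arange(b.size(0))
out = b[rows, idx]  # shape (10, 51)

# Equivalent formulation with torch.gather along dim 1:
# the index must have the same number of dims as b, so reshape to (10, 1, 51).
gather_idx = idx.view(-1, 1, 1).expand(-1, 1, b.size(2))
out_gather = torch.gather(b, 1, gather_idx).squeeze(1)  # shape (10, 51)

print(out.shape)                     # torch.Size([10, 51])
print(torch.equal(out, out_gather))  # True
```

Plain advanced indexing is usually the more readable choice; torch.gather can be handy when the selection is already expressed as an index tensor with the same rank as the input.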
Upvotes: 0
Reputation: 40618
Indexing b on the second axis with a should do:
>>> import numpy as np
>>> import torch
>>> b = torch.rand(10, 10, 51)
>>> a = np.array([0,1,2,3,4,5,6,7,8,9])
>>> b[:, a].shape
torch.Size([10, 10, 51])
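A minimal sketch contrasting the two indexing forms, using the b and a from the question (random values). With : in dimension 0, the slice is not paired with a, so every row receives all the indices in a and the shape stays (10, 10, 51); with this particular a (0 to 9 in order) the result is simply a copy of b. Pairing a with torch.arange over dimension 0 instead yields one dim-1 entry per row, shape (10, 51).

```python
import numpy as np
import torch

b = torch.rand(10, 10, 51)
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
idx = torch.as_tensor(a, dtype=torch.long)

# ':' in dim 0 is not paired with idx: every row gets all the indices in idx.
all_rows = b[:, idx]  # shape (10, 10, 51)

# Paired indexing: one dim-1 entry per row.
paired = b[torch.arange(b.size(0)), idx]  # shape (10, 51)

# The paired result is the "diagonal" of the unpaired one over dims 0 and 1.
diag = all_rows[torch.arange(10), torch.arange(10)]  # shape (10, 51)

print(all_rows.shape, paired.shape)  # torch.Size([10, 10, 51]) torch.Size([10, 51])
print(torch.equal(paired, diag))     # True
print(torch.equal(all_rows, b))      # True, because a is 0..9 in order
```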
Upvotes: 0