Reputation: 1272
I have the following torch tensor:
tensor([[-0.2, 0.3],
        [-0.5, 0.1],
        [-0.4, 0.2]])
and the following NumPy array (I can convert it to something else if necessary):
[1 0 1]
I want to get the following tensor:
tensor([0.3, -0.5, 0.2])
That is, I want each entry of the NumPy array to pick one element from the corresponding row of my tensor, preferably without using a loop.
Thanks in advance
Upvotes: 7
Views: 10472
Reputation: 7693
You may want to use torch.gather
- "Gathers values along an axis specified by dim."
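import numpy as np
import torch

t = torch.tensor([[-0.2, 0.3],
                  [-0.5, 0.1],
                  [-0.4, 0.2]])
idxs = np.array([1, 0, 1])
# gather needs an int64 (Long) index with the same number of dims as t: shape (3, 1)
idxs = torch.from_numpy(idxs).long().unsqueeze(1)
# or torch.from_numpy(idxs).long().view(-1, 1)
t.gather(1, idxs)
# append .squeeze(1) to get a 1-D tensor of shape (3,)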
tensor([[ 0.3000],
        [-0.5000],
        [ 0.2000]])
Here, your index is a NumPy array, so you have to convert it to a LongTensor first: gather expects an int64 index tensor with the same number of dimensions as the input.
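With dim=1, gather picks, for every row i, out[i][j] = t[i][index[i][j]]. A minimal sketch (variable names are illustrative) that checks this against an explicit per-row loop and uses squeeze(1) to get the 1-D tensor asked for in the question:

```python
import numpy as np
import torch

t = torch.tensor([[-0.2, 0.3],
                  [-0.5, 0.1],
                  [-0.4, 0.2]])
idxs = np.array([1, 0, 1])

# Index must be int64 with the same number of dims as t, so reshape to (3, 1)
index = torch.from_numpy(idxs).long().unsqueeze(1)

# dim=1: out[i][0] = t[i][index[i][0]]
gathered = t.gather(1, index)   # shape (3, 1)
result = gathered.squeeze(1)    # shape (3,)

# Reference loop doing the same lookup one row at a time
expected = torch.stack([t[i, int(idxs[i])] for i in range(len(idxs))])

print(result)
```

The loop is only there to show what gather computes; the vectorized call is the one you would actually use.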
Upvotes: 4
Reputation: 11198
Simply use range(len(index)) as the index for the first dimension:
import torch

a = torch.tensor([[-0.2, 0.3],
                  [-0.5, 0.1],
                  [-0.4, 0.2]])
c = [1, 0, 1]
# Pair row k with column c[k]: picks a[0, 1], a[1, 0], a[2, 1]
b = a[range(len(c)), c]
print(b)
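The same idea can be written entirely with tensors, which keeps the row index on the same device as the data. A minimal sketch, assuming the column indices arrive as a NumPy array as in the question; the helper name pick_per_row is hypothetical:

```python
import numpy as np
import torch


def pick_per_row(a: torch.Tensor, cols) -> torch.Tensor:
    """Return a[i, cols[i]] for every row i (hypothetical helper)."""
    # Convert the column indices (NumPy array, list, or tensor) to int64 on a's device
    cols = torch.as_tensor(cols, dtype=torch.long, device=a.device)
    # One row index per requested column: 0, 1, ..., len(cols) - 1
    rows = torch.arange(cols.shape[0], device=a.device)
    # Advanced indexing pairs rows[k] with cols[k]
    return a[rows, cols]


a = torch.tensor([[-0.2, 0.3],
                  [-0.5, 0.1],
                  [-0.4, 0.2]])
c = np.array([1, 0, 1])
print(pick_per_row(a, c))
```

Both answers give the same values; gather generalizes directly to picking k columns per row with an (N, k) index, while advanced indexing reads closer to the question's wording.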
Upvotes: 3