Reputation: 460
I have a tensor which looks like this:
tensor([[-0.0150, 0.1234],
[-0.0184, 0.1062],
[-0.0139, 0.1113],
[-0.0088, 0.0726]])
And another that looks like this:
tensor([[1.],
[1.],
[0.],
[0.]])
I want to return, for each row of the first tensor, the value at the index given by the second tensor.
So our output would be:
tensor([[0.1234],
[0.1062],
[-0.0139],
[-0.0088]])
So far I have this code:
return torch.gather(tensor1, tensor2)
However I am getting the error:
TypeError: gather() received an invalid combination of arguments - got (Tensor, Tensor), but expected one of:
* (Tensor input, int dim, Tensor index, *, bool sparse_grad, Tensor out)
* (Tensor input, name dim, Tensor index, *, bool sparse_grad, Tensor out)
What am I doing wrong?
Upvotes: 1
Views: 1947
Reputation: 1
import torch

t2 = torch.tensor([[-0.0150, 0.1234],
                   [-0.0184, 0.1062],
                   [-0.0139, 0.1113],
                   [-0.0088, 0.0726]])

# gather requires an integer (int64) index tensor, so cast the float indices
t3 = torch.tensor([[1.],
                   [1.],
                   [0.],
                   [0.]]).type(torch.int64)

# for each row, pick the column given by t3
res = t2.gather(1, t3)
print(res)
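Running this prints the selected value for each row, matching the output you asked for:

tensor([[ 0.1234],
        [ 0.1062],
        [-0.0139],
        [-0.0088]])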
Upvotes: 0
Reputation: 1331
You are missing the dim argument.
You can see an example here: https://pytorch.org/docs/stable/generated/torch.gather.html
For your case, I think return torch.gather(tensor1, 1, tensor2) should work.
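One caveat: torch.gather expects the index tensor to have an integer (int64) dtype, so the float tensor from the question needs a cast first. A minimal sketch, reusing the tensor1 and tensor2 names from the question:

import torch

tensor1 = torch.tensor([[-0.0150, 0.1234],
                        [-0.0184, 0.1062],
                        [-0.0139, 0.1113],
                        [-0.0088, 0.0726]])
tensor2 = torch.tensor([[1.], [1.], [0.], [0.]])

# cast the index tensor to int64, then gather along dim=1 (columns)
out = torch.gather(tensor1, 1, tensor2.long())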
Upvotes: 1