caaax

Reputation: 460

How to use gather() in python to return values at specific indices of a tensor

I have a tensor which looks like this:

tensor([[-0.0150,  0.1234],
        [-0.0184,  0.1062],
        [-0.0139,  0.1113],
        [-0.0088,  0.0726]])

And another that looks like this:

tensor([[1.],
        [1.],
        [0.],
        [0.]])

For each row, I want to return the value from the first tensor at the column index given by the second tensor.

So the output would be:

tensor([[ 0.1234],
        [ 0.1062],
        [-0.0139],
        [-0.0088]])
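To pin down the requirement, here is a minimal sketch of the desired operation written as an explicit Python loop. The helper name select_per_row is hypothetical (it is not part of PyTorch); it computes out[i][j] = values[i][index[i][j]], which is what torch.gather computes along dim=1.

```python
import torch

def select_per_row(values: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference loop: out[i][j] = values[i][index[i][j]]."""
    out = torch.empty(index.shape, dtype=values.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            # index holds floats here, so convert each entry to a Python int
            out[i, j] = values[i, int(index[i, j])]
    return out

tensor1 = torch.tensor([[-0.0150,  0.1234],
                        [-0.0184,  0.1062],
                        [-0.0139,  0.1113],
                        [-0.0088,  0.0726]])
tensor2 = torch.tensor([[1.], [1.], [0.], [0.]])

result = select_per_row(tensor1, tensor2)
print(result)
```

The loop is only for illustration; torch.gather does the same selection in a single vectorized call.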

So far I have this code:

return torch.gather(tensor1, tensor2)

However I am getting the error:

TypeError: gather() received an invalid combination of arguments - got (Tensor, Tensor), but expected one of:
 * (Tensor input, int dim, Tensor index, *, bool sparse_grad, Tensor out)
 * (Tensor input, name dim, Tensor index, *, bool sparse_grad, Tensor out)

What am I doing wrong?

Upvotes: 1

Views: 1947

Answers (2)

Hakima Sabri

Reputation: 1

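import torch

t2 = torch.tensor([[-0.0150,  0.1234],
                   [-0.0184,  0.1062],
                   [-0.0139,  0.1113],
                   [-0.0088,  0.0726]])
# gather() needs an integer (int64) index tensor, so convert the float indices
t3 = torch.tensor([[1.],
                   [1.],
                   [0.],
                   [0.]]).type(torch.int64)
# dim=1: for each row i, pick t2[i][t3[i][0]]
res = t2.gather(1, t3)
print(res)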
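A short check of why the int64 conversion above matters, assuming a recent PyTorch release: gather rejects a floating-point index at runtime, while the converted index works. It also shows that the result keeps the (4, 1) shape of the index; squeeze(1) flattens it if a 1-D vector is preferred.

```python
import torch

t2 = torch.tensor([[-0.0150,  0.1234],
                   [-0.0184,  0.1062],
                   [-0.0139,  0.1113],
                   [-0.0088,  0.0726]])
float_index = torch.tensor([[1.], [1.], [0.], [0.]])

# A float index is rejected by gather() with a RuntimeError
try:
    t2.gather(1, float_index)
    float_index_rejected = False
except RuntimeError as err:
    float_index_rejected = True
    print("float index rejected:", err)

# After converting to int64 the call succeeds
res = t2.gather(1, float_index.long())
# squeeze(1) turns the (4, 1) result into a flat (4,) vector
flat = res.squeeze(1)
print(res.shape, flat.shape)  # torch.Size([4, 1]) torch.Size([4])
```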

Upvotes: 0

Tamir

Reputation: 1331

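You are missing the dim argument: the signature is torch.gather(input, dim, index). You can see an example here: https://pytorch.org/docs/stable/generated/torch.gather.html

For your case, return torch.gather(tensor1, 1, tensor2.long()) should work. dim=1 selects along the columns of each row, and index must be an int64 tensor, so the float tensor2 has to be converted first.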
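For comparison, the same selection can be written with advanced (integer-array) indexing instead of gather. This is a swapped-in technique, not what either answer uses: pair each row number from torch.arange with its column index. The result is 1-D, so unsqueeze(1) is applied to match gather's (4, 1) output.

```python
import torch

tensor1 = torch.tensor([[-0.0150,  0.1234],
                        [-0.0184,  0.1062],
                        [-0.0139,  0.1113],
                        [-0.0088,  0.0726]])
tensor2 = torch.tensor([[1.], [1.], [0.], [0.]])

rows = torch.arange(tensor1.shape[0])   # row numbers 0..3
cols = tensor2.squeeze(1).long()        # column index per row, as int64
picked = tensor1[rows, cols]            # shape (4,): tensor1[i, cols[i]] for each i
picked_2d = picked.unsqueeze(1)         # shape (4, 1), same layout as gather's result
print(picked_2d)
```

gather is usually the more direct choice when the index already has the same number of dimensions as the input; advanced indexing reads more naturally when you start from a flat list of column indices.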

Upvotes: 1
