How can I get values from a 2-D torch tensor using a 1-D tensor that holds the column position to pick in each row?
For example:
>>> import torch
>>> a = torch.randn((5, 5))
>>> a
tensor([[ 0.0740, -0.3129, 0.7814, -0.0519, 1.3503],
[ 1.1985, 0.2098, -0.0326, 0.3922, 0.5037],
[-1.4334, 1.4047, -0.6607, -1.8024, -0.0088],
[ 1.2116, 0.5928, 1.4041, 1.0494, -0.1146],
[ 0.4173, 1.0482, 0.5244, -2.1767, 0.5264]])
>>> b = torch.randint(0, 5, (5,))
>>> b
tensor([1, 0, 1, 3, 2])
I want to get the elements of tensor a at the positions given by tensor b.
Desired output:
tensor([-0.3129,  1.1985,  1.4047,  1.0494,  0.5244])
That is, b is applied row-wise: for each row i, I want a[i, b[i]].
I have tried:
val = None
for index in range(b.size(-1)):
    row_val = a[index, b[index]].view(1, -1)
    val = torch.cat((val, row_val), dim=0) if val is not None else row_val
>>> val
tensor([[-0.3129],
[ 1.1985],
[ 1.4047],
[ 1.0494],
[ 0.5244]])
This works, but is there a way to do it with tensor indexing instead of a Python loop? I tried a couple of tensor-indexing approaches, but none of them worked.
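For comparison, if a loop is acceptable, a simpler version collects the per-row scalars in a list and stacks them once, instead of calling torch.cat on every iteration (each torch.cat call copies the growing result). A minimal sketch: a and b are hard-coded from the printed values above because torch.randn output is random, and the helper name pick_per_row_loop is hypothetical.

```python
import torch

# Values copied from the printed tensors above (hard-coded for reproducibility).
a = torch.tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
                  [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
                  [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
                  [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
                  [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])
b = torch.tensor([1, 0, 1, 3, 2])

def pick_per_row_loop(a, b):
    # Hypothetical helper: take a[i, b[i]] for each row i (0-d tensors),
    # then build the result with a single torch.stack call.
    picked = [a[i, b[i]] for i in range(b.size(0))]
    return torch.stack(picked).unsqueeze(1)  # shape (N, 1), like val above

val = pick_per_row_loop(a, b)
print(val.shape)  # torch.Size([5, 1])
```

This is still a Python-level loop, so it is mainly useful as a readable reference to check a vectorized version against.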
You can use torch.gather along dimension 1, with b reshaped into a column of indices:
>>> a.gather(1, b.unsqueeze(1))
tensor([[-0.3129],
[ 1.1985],
[ 1.4047],
[ 1.0494],
[ 0.5244]])
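With dim=1, torch.gather(input, 1, index) returns out[i][j] = input[i][index[i][j]], and index must have the same number of dimensions as input; that is why b (shape (5,)) is unsqueezed to shape (5, 1). A minimal sketch with the question's values hard-coded; the second call, using a made-up index tensor idx2, only illustrates that the same pattern picks several columns per row when index has shape (N, k).

```python
import torch

a = torch.tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
                  [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
                  [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
                  [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
                  [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])
b = torch.tensor([1, 0, 1, 3, 2])  # int64, as gather requires

index = b.unsqueeze(1)      # shape (5,) -> (5, 1)
out = a.gather(1, index)    # out[i][0] = a[i][b[i]], shape (5, 1)
flat = out.squeeze(1)       # shape (5,), if a 1-D result is preferred

# Hypothetical extension: two columns per row, index shape (5, 2).
idx2 = torch.tensor([[1, 0], [0, 4], [1, 1], [3, 0], [2, 2]])
out2 = a.gather(1, idx2)    # out2[i][j] = a[i][idx2[i][j]], shape (5, 2)
```

The output always has the shape of index, which is what makes gather convenient when the number of picks per row is fixed.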
Or use integer-array indexing, pairing each row index with the matching entry of b:
>>> a[range(len(a)), b].unsqueeze(1)
tensor([[-0.3129],
[ 1.1985],
[ 1.4047],
[ 1.0494],
[ 0.5244]])
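Here the two index sequences are paired element-wise, so a[rows, b] selects a[rows[i], b[i]] for each i and returns a 1-D tensor of shape (5,); .unsqueeze(1) only reshapes it to match the loop's (5, 1) output. A minimal sketch, assuming torch.arange for the row indices (a common alternative to range that can be created on the same device as b), with the question's values hard-coded. It also checks that the result matches gather and that gradients flow back only to the selected entries.

```python
import torch

a = torch.tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
                  [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
                  [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
                  [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
                  [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])
b = torch.tensor([1, 0, 1, 3, 2])

rows = torch.arange(a.size(0), device=b.device)  # tensor([0, 1, 2, 3, 4])
picked = a[rows, b]   # pairs (0,1), (1,0), (2,1), (3,3), (4,2) -> shape (5,)

# Both answers select the same elements.
assert torch.equal(picked, a.gather(1, b.unsqueeze(1)).squeeze(1))

# Indexing is differentiable: the gradient is 1 at each (i, b[i]), 0 elsewhere.
a_req = a.clone().requires_grad_(True)
a_req[rows, b].sum().backward()
```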