Ashwin Geet D'Sa

How to obtain the value from each row at a given position in PyTorch?

How can I obtain values from a 2-D torch tensor, given a 1-D tensor that holds the column position to pick in each row?

For example:

>>> import torch
>>> a = torch.randn((5, 5))
>>> a
tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
        [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
        [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
        [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
        [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])

>>> b = torch.randint(0, 5, (5,))
>>> b
tensor([1, 0, 1, 3, 2])

I want to get the elements of tensor a at the positions given by tensor b, one element per row.

Desired output:

tensor([-0.3129,  1.1985,  1.4047,  1.0494,  0.5244])

That is, for each row i, the element in column b[i] is selected: out[i] = a[i, b[i]].
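A minimal reference sketch of this row-wise rule, with the printed values of a and b hard-coded (instead of torch.randn/torch.randint) so the result is reproducible. It collects one element per row and calls torch.stack once, which avoids re-concatenating on every iteration; it is still a Python loop, so treat it only as a baseline to compare a vectorized version against.

```python
import torch

# The tensors printed above, hard-coded so the example is deterministic.
a = torch.tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
                  [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
                  [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
                  [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
                  [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])
b = torch.tensor([1, 0, 1, 3, 2])

# Row i contributes exactly one element: the one in column b[i].
# Each a[i, b[i]] is a 0-d tensor; stacking them yields a 1-D tensor of shape (5,).
picked = torch.stack([a[i, b[i]] for i in range(a.size(0))])
print(picked)
```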

I have tried:

val = None  # must be initialised before the loop
for index in range(b.size(-1)):
    row = a[index, b[index]].view(1, -1)  # the picked 0-d element reshaped to (1, 1)
    val = row if val is None else torch.cat((val, row), dim=0)


>>> val
tensor([[-0.3129],
        [ 1.1985],
        [ 1.4047],
        [ 1.0494],
        [ 0.5244]])

However, is there a way to do this with tensor indexing instead of a Python loop? I tried a couple of indexing-based solutions, but none of them worked.


Answers (1)

Dishin H Goyani

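You can use torch.gather along dim=1. The index tensor must have the same number of dimensions as a, so b is unsqueezed to shape (5, 1); each output entry is then out[i][0] = a[i][b[i]]: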

>>> a.gather(1, b.unsqueeze(1))
tensor([[-0.3129],
        [ 1.1985],
        [ 1.4047],
        [ 1.0494],
        [ 0.5244]])
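A self-contained version of the gather call, again with the question's printed values hard-coded, that makes the intermediate shapes explicit: gather needs an int64 index with the same number of dimensions as a, so b goes from (5,) to (5, 1), and the (5, 1) result can be squeezed back to 1-D to match the desired output in the question.

```python
import torch

a = torch.tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
                  [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
                  [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
                  [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
                  [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])
b = torch.tensor([1, 0, 1, 3, 2])  # int64 by default, which gather requires for the index

# For dim=1, gather computes out[i][j] = a[i][index[i][j]].
index = b.unsqueeze(1)        # shape (5,) -> (5, 1)
picked = a.gather(1, index)   # shape (5, 1), one column per row
flat = picked.squeeze(1)      # shape (5,), the 1-D form asked for in the question

print(index.shape, picked.shape, flat.shape)
print(flat)
```

The same call extends to several picks per row: an index of shape (5, k) returns k elements from each row.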

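Or, with advanced indexing, pairing each row index with the corresponding entry of b (drop .unsqueeze(1) to get the 1-D result from the question):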

>>> a[range(len(a)), b].unsqueeze(1)
tensor([[-0.3129],
        [ 1.1985],
        [ 1.4047],
        [ 1.0494],
        [ 0.5244]])
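A sketch of the indexing form using torch.arange for the row indices (range(len(a)), as in the answer, works the same way here). The row tensor and b are paired element-wise, so result[i] = a[rows[i], b[i]] and the result is already 1-D. The last lines check that it agrees with the squeezed gather result.

```python
import torch

a = torch.tensor([[ 0.0740, -0.3129,  0.7814, -0.0519,  1.3503],
                  [ 1.1985,  0.2098, -0.0326,  0.3922,  0.5037],
                  [-1.4334,  1.4047, -0.6607, -1.8024, -0.0088],
                  [ 1.2116,  0.5928,  1.4041,  1.0494, -0.1146],
                  [ 0.4173,  1.0482,  0.5244, -2.1767,  0.5264]])
b = torch.tensor([1, 0, 1, 3, 2])

# Integer-array indexing: rows and b have the same shape, so
# picked[i] = a[rows[i], b[i]].
rows = torch.arange(a.size(0))
picked = a[rows, b]  # shape (5,)

# Same values as the gather approach, just without the extra dimension.
via_gather = a.gather(1, b.unsqueeze(1)).squeeze(1)
print(picked, torch.equal(picked, via_gather))
```

Which form to use is mostly a matter of taste: gather keeps the result 2-D and generalizes to several columns per row, while indexing returns one value per row directly.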
