neocyber

Reputation: 3

How to perform advanced indexing in PyTorch?

Is there a way of doing the following without looping?

import torch

S, N, H = 9, 7, 4

a = torch.randn(S, N, H)

# integer tensor of shape (N,) with values in [1, S]
lens = torch.randint(1, S + 1, (N,))

res = torch.zeros(N, H)

for i in range(N):
    res[i] = a[lens[i] - 1, i, :]

Upvotes: 0

Views: 400

Answers (1)

ddoGas

Reputation: 871

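Yes, I believe this works. Indexing the first two dimensions with two integer tensors of the same shape (N,) pairs them element-wise, so row k of the result is a[lens[k], i[k], :].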

import torch

S, N, H = 9, 7, 4

a = torch.randn(S, N, H)

# integer tensor of shape (N,) with values in [0, S - 1]
lens = torch.randint(0, S, (N,))
# batch indices 0..N-1, paired element-wise with lens
i = torch.arange(N)

# res[k] = a[lens[k], k, :], so res has shape (N, H)
res = a[lens, i, :]
print(res)

Why did you draw lens from 1 to S and then index with lens[i] - 1? For convenience, I changed lens to take values from 0 to S - 1 instead. If you need lens to run from 1 to S, as in the question, change
res = a[lens, i, :]
to
res = a[lens - 1, i, :]
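
To check that the indexed version really matches the loop from the question, something like the following should do. It is only a minimal sketch: it keeps the question's 1-based lens, and the names batch_idx and expected are illustrative rather than taken from the original code.

```python
import torch

S, N, H = 9, 7, 4
a = torch.randn(S, N, H)
lens = torch.randint(1, S + 1, (N,))  # 1-based, values in [1, S]

# Reference result: the explicit loop from the question.
expected = torch.zeros(N, H)
for k in range(N):
    expected[k] = a[lens[k] - 1, k, :]

# Vectorized: pair each index lens[k] - 1 with its batch position k.
batch_idx = torch.arange(N)
res = a[lens - 1, batch_idx]  # trailing dimension is kept whole

print(res.shape)                   # torch.Size([7, 4])
print(torch.equal(res, expected))  # True
```

Using torch.arange(N) instead of a hard-coded range keeps the index tensor in sync with N if the sizes change.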

Upvotes: 2
