Is there a way of doing the following without looping?
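import torch

S, N, H = 9, 7, 4
a = torch.randn(S, N, H)
# tensor of shape (N,) with integer values in [1, S] (randint's upper bound is exclusive)
lens = torch.randint(1, S + 1, (N,))
res = torch.zeros(N, H)
for i in range(N):
    # for column i, take the row at position lens[i] - 1 along the first (S) dimension
    res[i] = a[lens[i] - 1, i, :]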
Yes, I believe this works, using advanced (integer-array) indexing:
import torch

S, N, H = 9, 7, 4
a = torch.randn(S, N, H)
# tensor of shape (N,) with integer values in [0, S - 1] (0-based indices)
lens = torch.randint(0, S, (N,))
# column indices 0..N-1; avoids hardcoding 7
i = torch.arange(N)
# pairs lens[k] with i[k] for every k, giving a result of shape (N, H)
res = a[lens, i, :]
print(res)
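As a quick sanity check (a sketch, not part of the original answer), the indexed result can be compared against the loop from the question, adapted to 0-based `lens`. The names `loop_res`, `vec_res`, and `j` are illustrative only.

```python
import torch

S, N, H = 9, 7, 4
a = torch.randn(S, N, H)
lens = torch.randint(0, S, (N,))  # 0-based indices in [0, S - 1]

# Loop version from the question, adapted to 0-based lens
loop_res = torch.zeros(N, H)
for j in range(N):
    loop_res[j] = a[lens[j], j, :]

# Vectorized version: pair lens[j] with column j for every j
vec_res = a[lens, torch.arange(N), :]

print(torch.equal(loop_res, vec_res))
```

Both versions copy the same elements of `a`, so the comparison is exact rather than approximate.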
Why did you make lens range from 1 to S (inclusive) and then index with lens[i] - 1? I changed it so that lens ranges from 0 to S - 1 for convenience. If you do need lens in the range 1 to S, change
res = a[lens, i, :]
to
res = a[lens - 1, i, :]
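Another option, not from the original answer, is `torch.gather` along dimension 0. This sketch assumes `lens` is 1-based as in the question; `idx`, `res_index`, and `res_gather` are illustrative names. `gather` requires the index tensor to have the same number of dimensions as `a`, so the per-column row index is reshaped to `(1, N, 1)` and expanded to `(1, N, H)`.

```python
import torch

S, N, H = 9, 7, 4
a = torch.randn(S, N, H)
lens = torch.randint(1, S + 1, (N,))  # 1-based lengths in [1, S], as in the question

# Advanced indexing with the -1 shift
res_index = a[lens - 1, torch.arange(N), :]

# Alternative: gather along dim 0.
# out[0, j, k] = a[idx[0, j, k], j, k] = a[lens[j] - 1, j, k]
idx = (lens - 1).view(1, N, 1).expand(1, N, H)
res_gather = torch.gather(a, 0, idx).squeeze(0)  # (1, N, H) -> (N, H)

print(torch.equal(res_index, res_gather))
```

For this problem plain advanced indexing is the shorter choice; `gather` is mainly useful when the index already has the same shape as the output.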