Wenhui

Reputation: 61

How to index a 3-D tensor with a 2-D tensor in PyTorch?

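import torch

a = torch.rand(5, 256, 120)
# Min over dim 0: both outputs have shape (256, 120)
min_values, indices = torch.min(a, dim=0)

# Loop version: pick a[indices[i, j], i, j] for every (i, j)
aa = torch.zeros(256, 120)
for i in range(256):
    for j in range(120):
        aa[i, j] = a[indices[i, j], i, j]

# Sanity check: every element of aa equals the corresponding minimum
print((aa == min_values).sum() == 256 * 120)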

How can I avoid the nested for loop when computing aa? (I want to use the indices to select elements from another 3-D tensor, so I can't just use the values returned by min directly.)

Upvotes: 0

Views: 525

Answers (1)

Anton Ganichev

Reputation: 2542

You can use torch.gather

aa = torch.gather(a, 0, indices.unsqueeze(0)).squeeze(0)  # squeeze(0) restores shape (256, 120)

as explained here: Slicing a 4D tensor with a 3D tensor-index in PyTorch
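Since the goal is to apply the indices to a different tensor, here is a minimal sketch under a few assumptions: `b` is a hypothetical second tensor with the same shape as `a`, and `keepdim=True` is passed to `torch.min` so the indices already have shape (1, 256, 120) and need no `unsqueeze`. It also shows an equivalent advanced-indexing form that broadcasts row and column index grids, and checks both against the loop from the question.

```python
import torch

# Shapes from the question: reduce over dim 0 of a (5, 256, 120) tensor.
a = torch.rand(5, 256, 120)
b = torch.rand(5, 256, 120)  # hypothetical second tensor with the same shape

# keepdim=True keeps the reduced dimension, so indices has shape (1, 256, 120)
# and can be passed to gather directly.
min_values, indices = torch.min(a, dim=0, keepdim=True)

# gather picks b[indices[0, i, j], i, j] for every (i, j); squeeze drops dim 0.
bb = torch.gather(b, 0, indices).squeeze(0)  # shape (256, 120)

# Equivalent advanced indexing: broadcast row and column index grids
# against the (256, 120) index tensor.
idx = indices.squeeze(0)
rows = torch.arange(b.shape[1]).unsqueeze(1)  # shape (256, 1)
cols = torch.arange(b.shape[2]).unsqueeze(0)  # shape (1, 120)
bb_adv = b[idx, rows, cols]  # shape (256, 120)

# Reference result from the nested loop in the question.
bb_loop = torch.zeros(256, 120)
for i in range(256):
    for j in range(120):
        bb_loop[i, j] = b[idx[i, j], i, j]

print(torch.equal(bb, bb_loop), torch.equal(bb_adv, bb_loop))  # → True True
```

gather only needs the index tensor to have the same number of dimensions as the input, which is why keeping the reduced dimension is convenient; the advanced-indexing version avoids the extra dimension but requires building the index grids explicitly.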

Upvotes: 1
