Hadar

Reputation: 668

extracting tensor values given tensor index values torch

I have a tensor of values val with shape (b,n) and a tensor of indices ind with shape (b,m) (where n>m). For each row, I want to take the values in val at the positions given by the same row of ind. I've tried using val[ind], but it only expanded the dimensions of val rather than selecting the relevant items.

import torch

val = torch.tensor([[1,2,3],
                    [4,5,6],
                    [7,8,9],
                    [10,11,12],
                    [13,14,15]])   
ind = torch.tensor([[1,2],
                    [0,2],
                    [0,1],
                    [1,2],
                    [0,1]])
val[ind] # shape (5,2,3), I need (5,2)

The desired output is:

torch.tensor([[2,3],
              [4,6],
              [7,8],
              [11,12],
              [13,14]])

Upvotes: 1

Views: 1998

Answers (1)

Ivan

Reputation: 40658

You can perform this operation using torch.gather:

>>> val.gather(dim=1, index=ind)
tensor([[ 2,  3],
        [ 4,  6],
        [ 7,  8],
        [11, 12],
        [13, 14]])

This indexes val's second dimension using the values in ind. The returned tensor out satisfies:

out[i][j] = val[i][ind[i][j]]
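To check the rule out[i][j] = val[i][ind[i][j]] on the tensors from the question, here is a small sketch that compares gather against an explicit loop. The names expected, rows, and alt are only illustrative, and the advanced-indexing form at the end is an alternative I'm adding for comparison, not something gather requires:

```python
import torch

val = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12],
                    [13, 14, 15]])
ind = torch.tensor([[1, 2],
                    [0, 2],
                    [0, 1],
                    [1, 2],
                    [0, 1]])

# gather along dim=1: out[i][j] = val[i][ind[i][j]]
out = val.gather(dim=1, index=ind)

# Explicit (slow) loop applying the same rule, for comparison only
expected = torch.empty_like(ind)
for i in range(ind.shape[0]):
    for j in range(ind.shape[1]):
        expected[i, j] = val[i, ind[i, j]]

# Alternative: advanced indexing. A (b, 1) column of row numbers
# broadcasts against ind of shape (b, m), pairing each index with its row.
rows = torch.arange(val.shape[0]).unsqueeze(1)
alt = val[rows, ind]

print(torch.equal(out, expected), torch.equal(out, alt))
```

Both comparisons should agree. gather is probably the more readable choice here, since it states the dimension being indexed explicitly.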

Upvotes: 1
