sarnikowski

Using scatter_() to reorder sub-sequences in a tensor

Given the following two tensors:

import torch

x = torch.tensor([[[1, 2],
                   [2, 0],
                   [0, 0]],

                  [[2, 2],
                   [2, 0],
                   [3, 3]]]) # [batch_size x sequence_length x subseq_length]
y = torch.tensor([[2, 1, 0],
                  [2, 1, 2]]) # [batch_size x sequence_length]

I would like to sort the sub-sequences within each sequence of x by their lengths, in descending order (0 corresponds to padding in a sub-sequence). y holds the length of each sub-sequence in x. I have tried the following:

y_sorted, y_sort_idx = y.sort(dim=1, descending=True)
print(x.scatter_(dim=1, index=y_sort_idx.unsqueeze(2), src=x))

This results in:

tensor([[[1, 2],
         [2, 0],
         [0, 0]],

        [[2, 2],
         [2, 0],
         [2, 3]]])

However what I would like to achieve is:

tensor([[[1, 2],
         [2, 0],
         [0, 0]],

        [[2, 2],
         [3, 3],
         [2, 0]]])


Answers (1)

Emil Laursen

Use gather() instead of scatter_(), with the sort indices expanded to cover the last dimension. This should do it:

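y_sorted, y_sort_idx = y.sort(dim=1, descending=True)
# Repeat each row index across the last dimension so index has the same shape as x
index = y_sort_idx.unsqueeze(2).expand_as(x)
# x[b][i] becomes the old x[b][y_sort_idx[b][i]]
x = x.gather(dim=1, index=index)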
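A self-contained version of the gather() approach, checked against the desired output from the question. It passes stable=True to torch.sort, which the answer does not: without it, sort does not guarantee an order for ties, so the two length-2 sub-sequences in the second batch could come out swapped. stable requires a PyTorch version whose torch.sort accepts that argument.

```python
import torch

x = torch.tensor([[[1, 2], [2, 0], [0, 0]],
                  [[2, 2], [2, 0], [3, 3]]])  # [batch_size x sequence_length x subseq_length]
y = torch.tensor([[2, 1, 0],
                  [2, 1, 2]])                 # [batch_size x sequence_length]

# Longest sub-sequence first; ties keep their original relative order
y_sorted, y_sort_idx = torch.sort(y, dim=1, descending=True, stable=True)

# out[b][i][k] = x[b][index[b][i][k]][k]
index = y_sort_idx.unsqueeze(2).expand_as(x)
x_sorted = x.gather(dim=1, index=index)

expected = torch.tensor([[[1, 2], [2, 0], [0, 0]],
                         [[2, 2], [3, 3], [2, 0]]])
print(torch.equal(x_sorted, expected))  # True
```

gather() reads from the positions listed in the index, so the sort indices can be used as-is, with no inverse permutation and no in-place write into x.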
