Given the following two tensors:
import torch

x = torch.tensor([[[1, 2],
                   [2, 0],
                   [0, 0]],
                  [[2, 2],
                   [2, 0],
                   [3, 3]]])  # [batch_size x sequence_length x subseq_length]
y = torch.tensor([[2, 1, 0],
                  [2, 1, 2]])  # [batch_size x sequence_length]
I would like to sort the sub-sequences within each sequence of x by their lengths, in descending order (0 is used as padding). y holds the length of each sub-sequence in x. I have tried the following:
y_sorted, y_sort_idx = y.sort(dim=1, descending=True)
print(x.scatter_(dim=1, index=y_sort_idx.unsqueeze(2), src=x))
This results in:
tensor([[[1, 2],
         [2, 0],
         [0, 0]],
        [[2, 2],
         [2, 0],
         [2, 3]]])
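The difference between gather and scatter explains this output. The indices returned by sort mean "output row i comes from input row idx[i]", which is a read, i.e. what gather does; scatter_ does the opposite and writes input element i to position idx[i]. In the attempt above, the index also has size 1 in its last dimension, so scatter_ only writes the first column, and because src is x itself, which is being modified in place, values may be overwritten before they are read. A minimal 1-D illustration (the tensor values are arbitrary):

```python
import torch

src = torch.tensor([10, 20, 30])
idx = torch.tensor([2, 0, 1])

# gather reads: out[i] = src[idx[i]]
gathered = torch.gather(src, 0, idx)

# scatter writes: out[idx[i]] = src[i]
scattered = torch.zeros_like(src).scatter(0, idx, src)

print(gathered.tolist())   # [30, 10, 20]
print(scattered.tolist())  # [20, 30, 10]
```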
However, what I would like to get is:
tensor([[[1, 2],
         [2, 0],
         [0, 0]],
        [[2, 2],
         [3, 3],
         [2, 0]]])
Answer:
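This should do it: sort y, expand the sort indices to the shape of x, and gather the sub-sequences along dimension 1: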
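# stable=True (PyTorch >= 1.9) keeps equal-length sub-sequences in their original order
y_sorted, y_sort_idx = y.sort(dim=1, descending=True, stable=True)
# [B, S] -> [B, S, 1] -> [B, S, L]: every element of a sub-sequence uses the same source row
index = y_sort_idx.unsqueeze(2).expand_as(x)
x = x.gather(dim=1, index=index)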
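A self-contained version of the answer, using the tensors from the question, that checks the result against the desired output. It assumes PyTorch 1.9 or later for the stable argument of sort; without it, the tie in the second row (two sub-sequences of length 2) is not guaranteed to keep [2, 2] before [3, 3]. It also shows an equivalent formulation with advanced indexing, which avoids building the expanded index:

```python
import torch

x = torch.tensor([[[1, 2], [2, 0], [0, 0]],
                  [[2, 2], [2, 0], [3, 3]]])  # [batch, seq, subseq]
y = torch.tensor([[2, 1, 0],
                  [2, 1, 2]])                 # [batch, seq]

# Descending, stable argsort of the lengths (ties keep their original order)
_, y_sort_idx = y.sort(dim=1, descending=True, stable=True)

# gather along dim 1: out[b, i, k] = x[b, y_sort_idx[b, i], k]
x_sorted = x.gather(dim=1, index=y_sort_idx.unsqueeze(2).expand_as(x))

# Advanced indexing: batch_idx is [batch, 1] and broadcasts against
# y_sort_idx [batch, seq], selecting row y_sort_idx[b, i] from batch b
batch_idx = torch.arange(x.size(0)).unsqueeze(1)
x_sorted_adv = x[batch_idx, y_sort_idx]

print(x_sorted.tolist())
# [[[1, 2], [2, 0], [0, 0]], [[2, 2], [3, 3], [2, 0]]]
print(torch.equal(x_sorted, x_sorted_adv))  # True
```

Both versions return a new tensor rather than modifying x in place.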