I would like to reorder the rows of a b x m x n PyTorch tensor (where b is the batch size) by the value in the k-th column of each row. My input tensor is b x m x n, and my output tensor should also be b x m x n, with the rows of each m x n matrix rearranged according to their k-th column value.
For example, if my original tensor is:
a = torch.as_tensor([[[1, 3, 7, 6], [9, 0, 6, 2], [3, 0, 5, 8]], [[1, 0, 1, 0], [2, 1, 0, 3], [0, 0, 6, 1]]])
My sorted tensor should be:
sorted_dim = 1 # sort by rows, preserving each row
sorted_column = 2 # sort rows on value of 3rd column of each row
sorted_a = torch.as_tensor([[[3, 0, 5, 8], [9, 0, 6, 2], [1, 3, 7, 6]], [[2, 1, 0, 3], [1, 0, 1, 0], [0, 0, 6, 1]]])
Thanks!
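For reference, here is a loop-free sketch of the operation described above, assuming an arbitrary column index k and an optional descending flag. The function name `sort_rows_by_column` is hypothetical and not part of PyTorch. `torch.argsort` along dim 1 gives the row order for each batch, and indexing with a broadcast batch index pulls out whole rows at once.

```python
import torch

def sort_rows_by_column(t: torch.Tensor, k: int, descending: bool = False) -> torch.Tensor:
    """Hypothetical helper: reorder the rows of each m x n matrix in a b x m x n tensor by column k."""
    # order has shape (b, m): for each batch, the row indices sorted by column k
    order = torch.argsort(t[:, :, k], dim=1, descending=descending)
    # batch index of shape (b, 1) broadcasts against order of shape (b, m)
    batch = torch.arange(t.shape[0], device=t.device).unsqueeze(1)
    # advanced indexing on the first two dims returns shape (b, m, n)
    return t[batch, order]

a = torch.as_tensor([[[1, 3, 7, 6], [9, 0, 6, 2], [3, 0, 5, 8]],
                     [[1, 0, 1, 0], [2, 1, 0, 3], [0, 0, 6, 1]]])
print(sort_rows_by_column(a, 2))
```

Passing `descending=True` reverses the row order within each batch.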
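Try this (the column index 2 is hard-coded here):
import torch

a = torch.as_tensor([[[1, 3, 7, 6], [9, 0, 6, 2], [3, 0, 5, 8]], [[1, 0, 1, 0], [2, 1, 0, 3], [0, 0, 6, 1]]])
# for each batch, the row order that sorts column 2 in ascending order
b = torch.argsort(a[:, :, 2])
# reorder the rows of each m x n matrix, then stack the results back into b x m x n
sorted_a = torch.stack([a[i, b[i], :] for i in range(a.shape[0])])
print(sorted_a)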
Output:
tensor([[[3, 0, 5, 8],
         [9, 0, 6, 2],
         [1, 3, 7, 6]],

        [[2, 1, 0, 3],
         [1, 0, 1, 0],
         [0, 0, 6, 1]]])
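The per-batch Python loop can also be replaced by `torch.gather`. This is a sketch assuming the same tensor `a` and a column index `k = 2`. The argsort indices are expanded along the last dimension so that every element of a row follows the same permutation.

```python
import torch

a = torch.as_tensor([[[1, 3, 7, 6], [9, 0, 6, 2], [3, 0, 5, 8]],
                     [[1, 0, 1, 0], [2, 1, 0, 3], [0, 0, 6, 1]]])
k = 2  # column whose values decide the row order

# idx has shape (b, m): per-batch row order that sorts column k ascending
idx = torch.argsort(a[:, :, k], dim=1)
# expand to (b, m, n) so gather moves whole rows, not single elements
idx = idx.unsqueeze(-1).expand(-1, -1, a.shape[2])
# out[i][j][l] = a[i][idx[i][j][l]][l]
sorted_a = torch.gather(a, 1, idx)
print(sorted_a)
```

Rows with equal values in column k are not guaranteed to keep their original relative order, since the argsort here is not requested to be stable.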