Given the following tensor:
tensor = torch.Tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])
and the following tensor of indices:
indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])
How can I sort tensor using the values in indices?
Trying with sorted gives the error:
TypeError: 'Tensor' object is not callable
while numpy.sort gives:
ValueError: Cannot specify order when the array has no fields.
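The NumPy error comes from np.sort's order argument, which only names fields of a structured array; it is not a list of positions. In NumPy, reordering rows by a list of positions is done with integer-array ("fancy") indexing or np.take instead. A minimal sketch, assuming the same data as a plain NumPy array (the names arr and reordered are illustrative):

```python
import numpy as np

# Same data as the tensor above, as a NumPy array (assumption for illustration).
arr = np.array([[1., 0., 0., 0., 0.],
                [0., 1., 0., 0., 0.],
                [0., 0., 1., 0., 0.],
                [0., 0., 0., 0., 1.],
                [1., 0., 0., 0., 0.],
                [1., 0., 0., 0., 0.],
                [0., 0., 0., 1., 0.],
                [0., 0., 0., 0., 1.]])
indices = np.array([2, 6, 7, 5, 4, 0, 3, 1])

# `order` expects field names of a structured array, so this raises ValueError.
try:
    np.sort(arr, order=indices)
except ValueError as exc:
    print("np.sort failed:", exc)

# Fancy indexing: row i of the result is arr[indices[i]].
reordered = arr[indices]
# np.take along axis 0 does the same thing.
assert np.array_equal(reordered, np.take(arr, indices, axis=0))
print(reordered)
```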
You can index the tensor with indices directly; row i of the result is row indices[i] of the original:
import torch

tensor = torch.Tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])
indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])

# Integer-tensor indexing picks rows in the order given by indices
sorted_tensor = tensor[indices]
print(sorted_tensor)
# output
tensor([[0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.]])
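"Sort by indices" can mean two things. tensor[indices] is a gather: result[i] = tensor[indices[i]], which torch.index_select also computes. If instead indices[i] is meant as the destination of row i (an assumption about intent, only valid when indices is a permutation of 0..n-1), invert the permutation with torch.argsort first. A minimal sketch of both readings (the names gathered, scattered and check are illustrative):

```python
import torch

tensor = torch.tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])
indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])

# Gather reading (what tensor[indices] does): result[i] = tensor[indices[i]].
gathered = torch.index_select(tensor, 0, indices)
assert torch.equal(gathered, tensor[indices])

# Scatter reading (assumption: indices[i] is where row i should end up).
# argsort of a permutation is its inverse, which turns the scatter into a gather.
scattered = tensor[torch.argsort(indices)]

# Same result written as an explicit scatter into an empty tensor.
check = torch.empty_like(tensor)
check[indices] = tensor
assert torch.equal(scattered, check)
print(scattered)
```

For the data in the question the two readings give different orders, so it is worth checking which one the indices actually encode.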