lima0

How to sort a one-hot tensor according to a tensor of indices

Given the following tensor:

tensor = torch.Tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])

and the following tensor of indices:

indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])  

How can I sort (reorder) the rows of tensor according to the values in indices?

Trying with sorted gives the error:

TypeError: 'Tensor' object is not callable

While numpy.sort gives:

ValueError: Cannot specify order when the array has no fields.

Answers (1)

Hamzah Al-Qadasi

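No sorting function is needed: sorted and numpy.sort order elements by their values, whereas here the rows just have to be rearranged into a given order. Index the tensor directly with indices. Advanced indexing along the first dimension returns a new tensor whose row i is row indices[i] of the original: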

import torch

tensor = torch.Tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])
indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])

# Advanced indexing on dim 0: row i of the result is row indices[i] of tensor
sorted_tensor = tensor[indices]
print(sorted_tensor)
# output
tensor([[0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.]])
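tensor[indices] reads indices as "row i of the result comes from row indices[i]". If the question instead means that indices[i] is the destination position of row i, the permutation has to be inverted first. Below is a minimal sketch of that alternative reading. It assumes indices is a permutation of 0..n-1, uses torch.argsort to invert it, and cross-checks the result with a scatter-style assignment:

```python
import torch

tensor = torch.tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])
indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])

# Assumption: indices is a permutation of 0..n-1 and indices[i] is the
# destination row for tensor[i] (the inverse of the tensor[indices] reading).
assert torch.equal(torch.sort(indices).values, torch.arange(len(indices)))

# argsort of a permutation is its inverse, so gathering with it
# moves tensor[i] to position indices[i].
inverse = torch.argsort(indices)
moved = tensor[inverse]

# Equivalent scatter-style assignment: write each row straight to its target slot.
moved_scatter = torch.empty_like(tensor)
moved_scatter[indices] = tensor

print(inverse)                             # tensor([5, 7, 0, 6, 4, 3, 1, 2])
print(torch.equal(moved, moved_scatter))   # True
```

Which reading is correct depends on how indices was produced. If indices lists source rows, the plain tensor[indices] above is the right call. If it lists target positions, use the inverted version.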
