laridzhang

PyTorch: is there a function similar to torch.argmax that keeps the dimensions of the original data?

For example, given this code:

import torch

input = torch.randn(3, 5)
result = torch.argmax(input, dim=0, keepdim=True)

The input is

tensor([[ 1.5742,  0.8183, -2.3005, -1.1650, -0.2451],
       [ 1.0553,  0.6021, -0.4938, -1.5379, -1.2054],
       [-0.1728,  0.8372, -1.9181, -0.9110,  0.2422]])

and result is

tensor([[ 0,  2,  1,  2,  2]])

However, I want a result with the same shape as the input, with a 1 at the position of each column's maximum and 0 everywhere else:

tensor([[ 1,  0,  0,  0,  0],
        [ 0,  0,  1,  0,  0],
        [ 0,  1,  0,  1,  1]])
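For reference, one way to build this mask directly (not taken from either answer below) is to scatter ones into a zero tensor at the argmax positions. This is a minimal sketch assuming PyTorch's Tensor.scatter_ with a scalar value; the names x, idx and mask are illustrative, and the values are copied from the question.

```python
import torch

# Input values copied from the question (3 rows x 5 columns).
x = torch.tensor([[ 1.5742,  0.8183, -2.3005, -1.1650, -0.2451],
                  [ 1.0553,  0.6021, -0.4938, -1.5379, -1.2054],
                  [-0.1728,  0.8372, -1.9181, -0.9110,  0.2422]])

# Row index of the maximum in each column; keepdim=True gives shape (1, 5).
idx = torch.argmax(x, dim=0, keepdim=True)

# Start from zeros with x's shape and write 1 at every (idx[0, j], j) position.
mask = torch.zeros_like(x, dtype=torch.long).scatter_(0, idx, 1)
print(mask)
```

Because argmax returns a single index per column, each column of mask gets exactly one 1. If every tied maximum should be marked instead, comparing x == x.max(dim=0, keepdim=True).values is an alternative.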

Upvotes: 3

Answers (2)

Amir

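You can use nn.functional.one_hot and then permute the axes into the order you need, since one_hot appends the class dimension as the last axis. Pass num_classes explicitly: otherwise one_hot infers the number of classes from the largest index, and the output loses rows whenever the last row never holds a maximum.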

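import torch
from torch.nn.functional import one_hot

input = torch.randn(3, 5)
# num_classes=input.size(0) keeps all 3 rows even if some row never holds a maximum
output = one_hot(torch.argmax(input, dim=0), num_classes=input.size(0)).permute(-1, 0)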

With the input

tensor([[ 1.1320, -0.7152,  2.0861,  0.6044, -0.9840],
        [ 0.8313,  2.4974,  1.3477,  1.4260, -0.4859],
        [-0.6532,  2.5891, -1.3084,  2.0589,  1.8340]])

the output is

tensor([[1, 0, 1, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 1, 0, 1, 1]])
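The num_classes pitfall can be checked with a small sketch. The helper argmax_mask_dim0 is a hypothetical name introduced here; it wraps the one_hot approach for dim=0 and uses .T (valid for 2-D tensors) in place of permute. The second input y is made up so that its last row never holds a maximum.

```python
import torch
from torch.nn.functional import one_hot

def argmax_mask_dim0(x):
    # Hypothetical helper: column-wise argmax, shape (cols,).
    idx = torch.argmax(x, dim=0)
    # one_hot appends a class axis: shape (cols, rows) when num_classes=rows.
    # Passing num_classes explicitly keeps every row even if it never wins.
    return one_hot(idx, num_classes=x.size(0)).T  # back to (rows, cols)

# Input copied from the answer above.
x = torch.tensor([[ 1.1320, -0.7152,  2.0861,  0.6044, -0.9840],
                  [ 0.8313,  2.4974,  1.3477,  1.4260, -0.4859],
                  [-0.6532,  2.5891, -1.3084,  2.0589,  1.8340]])
print(argmax_mask_dim0(x))

# Made-up case where row 0 wins every column, so argmax is [0, 0].
y = torch.tensor([[3.0, 3.0],
                  [1.0, 1.0],
                  [0.0, 0.0]])
print(one_hot(torch.argmax(y, dim=0)).shape)  # only 1 class inferred: torch.Size([2, 1])
print(argmax_mask_dim0(y).shape)              # torch.Size([3, 2])
```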

Upvotes: 0

laridzhang

Finally, I solved it, though this solution may not be efficient, and it hardcodes one comparison per row. The code is as follows:

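import torch

input = torch.randn(3, 10)
# Row index of the maximum in each column, shape (1, 10)
result = torch.argmax(input, dim=0, keepdim=True)
# Compare against each possible row index (one boolean row per index)
result_0 = result == 0
result_1 = result == 1
result_2 = result == 2
# Stack the boolean rows back into shape (3, 10)
result = torch.cat((result_0, result_1, result_2), 0)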
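The same comparison idea can be generalized so it does not hardcode the number of rows. This is a minimal sketch, assuming broadcasting against torch.arange; argmax_mask is a hypothetical helper name, and the test input is copied from the question. It performs one broadcast comparison instead of one comparison per row followed by torch.cat.

```python
import torch

def argmax_mask(x, dim=0):
    # Hypothetical helper: index of the maximum along `dim`;
    # keepdim=True leaves a size-1 axis in that position.
    idx = torch.argmax(x, dim=dim, keepdim=True)
    # Lay out positions 0..n-1 along `dim` so they broadcast against idx.
    shape = [1] * x.dim()
    shape[dim] = x.size(dim)
    positions = torch.arange(x.size(dim), device=x.device).view(shape)
    # Replaces result == 0, result == 1, ... and the torch.cat call.
    return idx == positions

# Input copied from the question.
x = torch.tensor([[ 1.5742,  0.8183, -2.3005, -1.1650, -0.2451],
                  [ 1.0553,  0.6021, -0.4938, -1.5379, -1.2054],
                  [-0.1728,  0.8372, -1.9181, -0.9110,  0.2422]])
print(argmax_mask(x, dim=0).long())
```

The result is a boolean tensor, like the one produced by the answer's torch.cat; call .long() or .float() on it if a numeric mask is needed.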

Upvotes: 2
