JVGD

Reputation: 737

PyTorch indexing by argmax

Dear community, I have a question about tensor indexing in PyTorch. The problem is simple: given a tensor, create an index tensor that selects its maximum value in each column.

import torch as T

x = T.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
              [0, 4, 9, 6, 7, 9, 1, 0]])

Given this tensor, I would like to build a boolean mask that selects its maximum value in each column. To be specific, I do not need the maximum values themselves (torch.max(x, dim=0)) or their indices (torch.argmax(x, dim=0)), but a boolean mask that I can use to index another tensor based on this tensor's maximum values. My ideal output would be:

# Input tensor
x
tensor([[0, 3, 0, 5, 9, 8, 2, 0],
        [0, 4, 9, 6, 7, 9, 1, 0]])

# Ideal output: boolean mask tensor (shown as 1/0)
idx
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])
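For reference, the ideal mask above can be spelled out with an explicit Python loop over the columns. This is a hypothetical baseline (argmax_mask_loop is an illustrative name, not the asker's actual solution, which isn't shown); it is mainly useful for checking a vectorized version against. Ties resolve to the first row, which matches columns 0 and 7 of the ideal output:

```python
import torch

def argmax_mask_loop(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based baseline: mark the row holding each column's maximum.

    Ties resolve to the first (lowest) row index.
    """
    mask = torch.zeros_like(x, dtype=torch.bool)
    for col in range(x.size(1)):
        best_row = 0
        for row in range(1, x.size(0)):
            # Strict '>' keeps the earliest row when values tie.
            if x[row, col] > x[best_row, col]:
                best_row = row
        mask[best_row, col] = True
    return mask

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])
idx = argmax_mask_loop(x)
print(idx.int())
```

The nested loop costs one Python-level comparison per element, which is exactly what a vectorized torch solution avoids.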

I know that values_max = x[idx] and values_max = x.max(dim=0).values select the same values (x[idx] returns them in row-major order), but I am not looking for values_max; I am looking for idx.
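A quick check of that equivalence, using the ideal mask from the question as a torch.bool tensor. Boolean-mask indexing returns a flat 1-D tensor walked in row-major order, so x[idx] holds the same values as x.max(dim=0).values but in a different order; indexing the transposes walks the elements column by column instead:

```python
import torch

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])
idx = torch.tensor([[1, 0, 0, 0, 1, 0, 1, 1],
                    [0, 1, 1, 1, 0, 1, 0, 0]], dtype=torch.bool)

col_max = x.max(dim=0).values   # one value per column, in column order
row_major = x[idx]              # masked values, walked row by row
col_major = x.T[idx.T]          # masked values, walked column by column

print(col_max)    # tensor([0, 4, 9, 6, 9, 9, 2, 0])
print(row_major)  # tensor([0, 9, 2, 0, 4, 9, 6, 9])
print(col_major)  # tensor([0, 4, 9, 6, 9, 9, 2, 0])
```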

I have built a solution, but it seems too complex, and I am sure torch has an optimized way to do this. I tried torch.index_select with the output of x.argmax(dim=0), but that failed, so I wrote a custom solution that feels too cumbersome. I am asking for help doing this in a vectorized, idiomatic torch way.
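On the torch.index_select attempt: index_select copies whole slices along a dimension, so feeding it the column-wise argmax returns entire rows rather than one element per column. torch.gather is the op that takes one index per output position. A small sketch contrasting the two (variable names are illustrative):

```python
import torch

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])
rows = x.argmax(dim=0)  # shape (8,): winning row index for each column

# index_select copies whole rows x[rows[0]], x[rows[1]], ... -> shape (8, 8)
whole_rows = torch.index_select(x, 0, rows)

# gather picks x[rows[j], j] for every column j -> shape (1, 8)
picked = torch.gather(x, 0, rows.unsqueeze(0))

print(whole_rows.shape)  # torch.Size([8, 8])
print(picked)            # tensor([[0, 4, 9, 6, 9, 9, 2, 0]])
```

gather yields the maximum values, not the mask, but it shows why index_select is the wrong shape of operation here.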

Upvotes: 0

Views: 4799

Answers (1)

Ivan

Reputation: 40628

You can do this by first extracting the row index of each column's maximum value with torch.argmax, setting keepdim=True so the reduced dimension is kept:

>>> x.argmax(0, keepdim=True)
tensor([[0, 1, 1, 1, 0, 1, 0, 0]])
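keepdim=True matters because torch.scatter expects index to have the same number of dimensions as the tensor being written into; without it, argmax returns a 1-D tensor. A minimal shape check:

```python
import torch

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])

flat = x.argmax(0)                 # reduced dim is dropped
kept = x.argmax(0, keepdim=True)   # reduced dim kept with size 1

print(flat.shape)   # torch.Size([8])
print(kept.shape)   # torch.Size([1, 8])
```

x.argmax(0).unsqueeze(0) gives the same 2-D index if you prefer to add the dimension back explicitly.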

Then you can use torch.scatter to place 1s in a zero tensor at the designated indices:

>>> torch.zeros_like(x).scatter(0, x.argmax(0, keepdim=True), value=1)
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])
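The scattered tensor has the same integer dtype as x, while the question asks for a boolean mask. A minimal sketch, assuming you want a torch.bool result to index a second, same-shaped tensor (other below is made-up example data): convert with .bool(), or build the same mask with torch.nn.functional.one_hot. Comparing against the column maximum is another option, but it marks every tied row, so it differs from the desired output in columns 0 and 7:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])
am = x.argmax(0, keepdim=True)

# Scatter-based mask from the answer, converted to torch.bool
mask = torch.zeros_like(x).scatter(0, am, value=1).bool()

# Same mask via one_hot: (8,) -> (8, 2) -> transpose to (2, 8)
mask_oh = F.one_hot(x.argmax(0), num_classes=x.size(0)).T.bool()

# Equality against the column max: ties mark *all* maximal rows
mask_eq = x == x.max(dim=0, keepdim=True).values

# Hypothetical second tensor to select from with the mask
other = torch.arange(16).reshape(2, 8)
selected = other.T[mask.T]  # one entry per column, in column order
print(selected)             # tensor([ 0,  9, 10, 11,  4, 13,  6,  7])
```

Transposing before indexing only changes the order of the selected elements; other[mask] returns the same values in row-major order.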

Upvotes: 2
