user395788

Reputation: 245

Indexing the max elements in a multidimensional tensor in PyTorch

I'm trying to index the maximum elements along the last dimension in a multidimensional tensor. For example, say I have a tensor

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

Here idx stores the maximum indices, which may look something like

>>> A
tensor([[[ 1.0503,  0.4448,  1.8663],
         [ 0.8627,  0.0685,  1.4241]],

        [[ 1.2924,  0.2456,  0.1764],
         [ 1.3777,  0.9401,  1.4637]],

        [[ 0.5235,  0.4550,  0.2476],
         [ 0.7823,  0.3004,  0.7792]],

        [[ 1.9384,  0.3291,  0.7914],
         [ 0.5211,  0.1320,  0.6330]],

        [[ 0.3292,  0.9086,  0.0078],
         [ 1.3612,  0.0610,  0.4023]]])
>>> idx
tensor([[ 2,  2],
        [ 0,  2],
        [ 0,  0],
        [ 0,  2],
        [ 1,  0]])

I want to use these indices to assign into another tensor. That is, I want to be able to do

B = A.new_zeros(A.size())
B[idx] = A[idx]

where B is 0 everywhere except where A is maximal along the last dimension. That is, B should store

>>> B
tensor([[[ 0,  0,  1.8663],
         [ 0,  0,  1.4241]],

        [[ 1.2924,  0,  0],
         [ 0,  0,  1.4637]],

        [[ 0.5235,  0,  0],
         [ 0.7823,  0,  0]],

        [[ 1.9384,  0,  0],
         [ 0,  0,  0.6330]],

        [[ 0,  0.9086,  0],
         [ 1.3612,  0,  0]]])

This is proving to be much more difficult than I expected, as idx on its own does not index A properly. So far I have been unable to find a vectorized way to use idx to index A.

Is there a good vectorized way to do this?

Upvotes: 8

Views: 7012

Answers (3)

Jingjian Wei

Reputation: 1

You can use torch.scatter here. Note that the scattered source must be the max values themselves (taken from torch.max), not the original tensor, otherwise the wrong values end up at the max positions:

>>> import torch
>>> a = torch.randn(4,2,3)
>>> a
tensor([[[ 0.1583,  0.1102, -0.8188],
         [ 0.6328, -1.9169, -0.5596]],

        [[ 0.5335,  0.4069,  0.8403],
         [-1.2537,  0.9868, -0.4947]],

        [[-1.2830,  0.4386, -0.0107],
         [ 1.3384,  0.5651,  0.2877]],

        [[-0.0334, -1.0619, -0.1144],
         [ 0.1954, -0.7371,  1.7001]]])
>>> ind = torch.max(a, 1, keepdim=True)[1]
>>> ind
tensor([[[1, 0, 1]],

        [[0, 1, 0]],

        [[1, 1, 1]],

        [[1, 1, 1]]])
>>> torch.zeros_like(a).scatter(1, ind, torch.max(a, 1, keepdim=True)[0])
tensor([[[ 0.0000,  0.1102,  0.0000],
         [ 0.6328,  0.0000, -0.5596]],

        [[ 0.5335,  0.0000,  0.8403],
         [ 0.0000,  0.9868,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 1.3384,  0.5651,  0.2877]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.1954, -0.7371,  1.7001]]])

Upvotes: 0

a_guest

Reputation: 36249

You can use torch.meshgrid to create an index tuple:

>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]

Note that you can also mimic meshgrid by hand (for the specific case of 3D):

>>> index_tuple = (
...     torch.arange(A.size(0))[:, None],
...     torch.arange(A.size(1))[None, :],
...     idx
... )

A bit more explanation: the indices in idx will look something like this:

In [173]: idx 
Out[173]: 
tensor([[2, 1],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 2]])

From this, we want to go to three indices (since our tensor is 3D, we need three numbers to retrieve each element). Basically, we want to build a grid over the first two dimensions, as shown below; that's why we use meshgrid.

In [174]: A[0, 0, 2], A[0, 1, 1]  
Out[174]: (tensor(0.6288), tensor(-0.3070))

In [175]: A[1, 0, 2], A[1, 1, 0]  
Out[175]: (tensor(1.7085), tensor(0.7818))

In [176]: A[2, 0, 2], A[2, 1, 1]  
Out[176]: (tensor(0.4823), tensor(1.1199))

In [177]: A[3, 0, 2], A[3, 1, 2]    
Out[177]: (tensor(1.6903), tensor(1.0800))

In [178]: A[4, 0, 2], A[4, 1, 2]          
Out[178]: (tensor(0.9138), tensor(0.1779))

In the five outputs above, the first two numbers in each index form the grid that we build using meshgrid, and the third number comes from idx.

i.e. the first two numbers form a grid.

 (0, 0) (0, 1)
 (1, 0) (1, 1)
 (2, 0) (2, 1)
 (3, 0) (3, 1)
 (4, 0) (4, 1)
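
Putting it all together, a minimal end-to-end sketch of this approach (same names as above; newer PyTorch versions may also want indexing='ij' passed to meshgrid):

import torch

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

# Grid over the first two dimensions, plus idx for the last dimension.
index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)

B = torch.zeros_like(A)
B[index_tuple] = A[index_tuple]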

Upvotes: 5

Jatentaki

Reputation: 13103

An ugly hackaround is to create a binary mask out of idx and use it to index the arrays. The basic code looks like this:

import torch
torch.manual_seed(0)

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

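# Compare every possible last-dim index (shape (1, 1, 3)) with the argmax
# indices (shape (5, 2, 1)); broadcasting gives a (5, 2, 3) boolean mask.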
mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
B[mask] = A[mask]
print(A)
print(B)

The trick is that torch.arange(A.size(2)) enumerates the possible values in idx, and mask is True in the places where they equal idx. Remarks:

  1. If you really discard the first output of torch.max, you can use torch.argmax instead.
  2. I assume that this is a minimal example of some wider problem, but be aware that you are currently reinventing torch.nn.functional.max_pool3d with kernel of size (1, 1, 3).
  3. Also, be aware that in-place modification of tensors with masked assignment can cause issues with autograd, so you may want to use torch.where instead (see the sketch after this list).
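
A minimal sketch of the torch.where variant (same A and mask as above):

B = torch.where(mask, A, torch.zeros_like(A))  # keep A where mask is True, zeros elsewhere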

I would expect somebody to come up with a cleaner solution (avoiding the intermediate allocation of the mask array), likely making use of torch.index_select, but I can't get it to work right now.

Upvotes: 3
