Sebastian T. Vincent

Reputation: 55

PyTorch tensors topk for every tensor across a dimension

I have the following tensor

inp = tensor([[[ 0.0000e+00,  5.7100e+02, -6.9846e+00],
               [ 0.0000e+00,  4.4070e+03, -7.1008e+00],
               [ 0.0000e+00,  3.0300e+02, -7.2226e+00],
               [ 0.0000e+00,  6.8000e+01, -7.2777e+00],
               [ 1.0000e+00,  5.7100e+02, -6.9846e+00],
               [ 1.0000e+00,  4.4070e+03, -7.1008e+00],
               [ 1.0000e+00,  3.0300e+02, -7.2226e+00],
               [ 1.0000e+00,  6.8000e+01, -7.2777e+00]],

              [[ 0.0000e+00,  2.1610e+03, -7.0754e+00],
               [ 0.0000e+00,  6.8000e+01, -7.2259e+00],
               [ 0.0000e+00,  1.0620e+03, -7.2920e+00],
               [ 0.0000e+00,  2.9330e+03, -7.3009e+00],
               [ 1.0000e+00,  2.1610e+03, -7.0754e+00],
               [ 1.0000e+00,  6.8000e+01, -7.2259e+00],
               [ 1.0000e+00,  1.0620e+03, -7.2920e+00],
               [ 1.0000e+00,  2.9330e+03, -7.3009e+00]],

              [[ 0.0000e+00,  4.4070e+03, -7.1947e+00],
               [ 0.0000e+00,  3.5600e+02, -7.2958e+00],
               [ 0.0000e+00,  3.0300e+02, -7.3232e+00],
               [ 0.0000e+00,  1.2910e+03, -7.3615e+00],
               [ 1.0000e+00,  4.4070e+03, -7.1947e+00],
               [ 1.0000e+00,  3.5600e+02, -7.2958e+00],
               [ 1.0000e+00,  3.0300e+02, -7.3232e+00],
               [ 1.0000e+00,  1.2910e+03, -7.3615e+00]]])

of shape

torch.Size([3, 8, 3])

and I would like to find the topk(k=4) elements along dim 1, where the value to sort by is the entry at index 2 of the last dimension (the negative values). The resulting tensor shape should then be:

torch.Size([3, 4, 3])

I know how to do topk for a single tensor, but how do I do this for several batches at once?
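For a single tensor (one batch slice), I would do something like this sketch (names are illustrative):

single = inp[0]                   # one batch, shape (8, 3)
_, idx = single[:, 2].topk(k=4)   # indices of the 4 largest sort keys
top_rows = single[idx]            # shape (4, 3)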

Upvotes: 0

Views: 803

Answers (2)

swag2198

Reputation: 2696

One way to do this is by combining fancy indexing and broadcasting as follows:

As the example, I take a random tensor x of shape (3, 4, 3) and k = 2.

>>> import torch
>>> x = torch.rand(3, 4, 3)
>>> x
tensor([[[0.0256, 0.7366, 0.2528],
         [0.5596, 0.9450, 0.5795],
         [0.8265, 0.5469, 0.8304],
         [0.4223, 0.5206, 0.2898]],

        [[0.2159, 0.0369, 0.6869],
         [0.4556, 0.5804, 0.3169],
         [0.8194, 0.5240, 0.0055],
         [0.8357, 0.4162, 0.3740]],

        [[0.3849, 0.0223, 0.9951],
         [0.2872, 0.5952, 0.6570],
         [0.1433, 0.8450, 0.6557],
         [0.0270, 0.9176, 0.3904]]])

Now sort the tensor by the sort key (here the last entry of each row) and keep the first k indices:

>>> _, idx = torch.sort(x[:, :, -1])
>>> k = 2
>>> idx = idx[:, :k]
>>> idx
tensor([[0, 3],
        [2, 1],
        [3, 2]])

Now generate three index arrays (i, j, k) to slice the original tensor as follows:

>>> i = torch.arange(x.shape[0]).reshape(x.shape[0], 1, 1)
>>> j = idx.reshape(x.shape[0], -1, 1)
>>> k = torch.arange(x.shape[2]).reshape(1, 1, x.shape[2])

Note that when you index x with (i, j, k), the three arrays broadcast together to the shape (x.shape[0], 2, x.shape[2]), which is the desired output shape here. (Careful: the integer k defined earlier is now shadowed by the index tensor k.)
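A quick shape check makes the broadcast visible:

>>> i.shape, j.shape, k.shape
(torch.Size([3, 1, 1]), torch.Size([3, 2, 1]), torch.Size([1, 1, 3]))

Now just index x by i, j and k: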

>>> x[i, j, k]
tensor([[[0.0256, 0.7366, 0.2528],
         [0.4223, 0.5206, 0.2898]],

        [[0.8194, 0.5240, 0.0055],
         [0.4556, 0.5804, 0.3169]],

        [[0.0270, 0.9176, 0.3904],
         [0.1433, 0.8450, 0.6557]]])

Essentially, the general recipe I follow is to build index arrays that describe the desired access pattern and then slice the tensor directly by using those arrays as indices.

I did this with an ascending sort, so here I am getting the top-k smallest elements. To get the largest instead, use torch.sort(x[:, :, -1], descending=True).
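Equivalently, torch.topk returns the k largest entries directly, without a full sort:

>>> _, idx = torch.topk(x[:, :, -1], k=2, dim=1)
>>> idx.shape
torch.Size([3, 2])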

Upvotes: 0

Sebastian T. Vincent

Reputation: 55

I did it like this:

# take the sort key, column 2 of the last dim -> shape (3, 8)
val, ind = inp[:, :, 2].topk(k=4, dim=1, sorted=True)
# expand the indices to (3, 4, 3) so gather selects whole rows
new_ind = ind.unsqueeze(-1).repeat(1, 1, 3)
result = inp.gather(1, new_ind)

I don't know if this is the best way to do this but it worked.
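Applied to the inp from the question, this gives the expected shape:

print(result.shape)  # torch.Size([3, 4, 3])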

Upvotes: 1
