spadel

How to sort a 3D tensor by coordinates in the last dimension (PyTorch)

I have a tensor with shape [bn, k, 2]. The last dimension holds the coordinates, and I want each batch to be sorted independently by the y coordinate ([:, :, 0]). My approach looks something like this:

import torch
a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]
a_sorted = a[:, indices]

print(a)
print(a_sorted)

So far so good, but now it sorts both batches according to both index lists, so I get 4 batches in total (a_sorted has shape [2, 2, 5, 2]):

a
tensor([[[ 0.5160,  0.3257],
         [-1.2410, -0.8361],
         [ 1.3826, -1.1308],
         [ 0.0338,  0.1665],
         [-0.9375, -0.3081]],

        [[ 0.4140, -1.0962],
         [ 0.9847, -0.7231],
         [-0.0110,  0.6437],
         [-0.4914,  0.2473],
         [-0.0938, -0.0722]]])

a_sorted
tensor([[[[-1.2410, -0.8361],
          [-0.9375, -0.3081],
          [ 0.0338,  0.1665],
          [ 0.5160,  0.3257],
          [ 1.3826, -1.1308]],

         [[ 0.0338,  0.1665],
          [-0.9375, -0.3081],
          [ 1.3826, -1.1308],
          [ 0.5160,  0.3257],
          [-1.2410, -0.8361]]],


        [[[ 0.9847, -0.7231],
          [-0.0938, -0.0722],
          [-0.4914,  0.2473],
          [ 0.4140, -1.0962],
          [-0.0110,  0.6437]],

         [[-0.4914,  0.2473],
          [-0.0938, -0.0722],
          [-0.0110,  0.6437],
          [ 0.4140, -1.0962],
          [ 0.9847, -0.7231]]]])

As you can see, I only want the 1st and the 4th batch (a_sorted[0, 0] and a_sorted[1, 1]) to be returned. How do I do that?
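A quick way to see where the four batches come from is to look at shapes. The sketch below uses random data of the question's shape (an assumption; any values behave the same). A single [bn, k] index tensor placed in dimension 1 applies every row of `indices` to every batch, so the result is [bn, bn, k, 2], and the wanted result sits on its "diagonal".

```python
import torch

# Same shapes as the question: bn=2 batches of k=5 points with 2 coordinates.
a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]   # shape [2, 5]

# One index tensor in dim 1 is used for *every* batch, and its own shape
# [2, 5] replaces that dim, so the result has shape [2, 2, 5, 2].
a_sorted = a[:, indices]
print(a_sorted.shape)  # torch.Size([2, 2, 5, 2])

# Only the "diagonal" entries pair batch i with its own index row i.
wanted = torch.stack([a_sorted[i, i] for i in range(a.shape[0])])
print(wanted.shape)  # torch.Size([2, 5, 2])
```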

Answers (1)

Multihunter

What you want: a[0, indices[0]] and a[1, indices[1]], stacked along the batch dimension.

What you coded: a[0, indices] and a[1, indices], stacked along the batch dimension.

The issue is that the indices returned by sort have the same shape as the first two dimensions of a ([bn, k]), but their values only index into the second dimension. When you use them, you want indices[0] to be applied to a[0], but PyTorch doesn't pair them up implicitly (advanced indexing is very flexible, and it needs explicit syntax to provide that flexibility). So all you have to do is give a parallel list of indices for the first dimension.

That is, you want something like: a[[[0], [1]], indices].
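To see how the two index tensors pair up, here is a small deterministic sketch (the tensor values are made up for illustration). The [2, 1] row index broadcasts against the [2, 3] `indices`, so element (i, j) of the result is a[i, indices[i, j]], and each batch is reordered by its own index row.

```python
import torch

# Hypothetical values, shape [2, 3, 2]; column 0 is the sort key.
a = torch.tensor([[[3., 30.], [1., 10.], [2., 20.]],
                  [[8., 80.], [9., 90.], [7., 70.]]])
indices = a[:, :, 0].sort(dim=1).indices   # [[1, 2, 0], [2, 0, 1]]

rows = torch.tensor([[0], [1]])            # shape [2, 1]
# rows ([2, 1]) and indices ([2, 3]) broadcast to [2, 3]; the trailing
# coordinate dim is kept, so out has shape [2, 3, 2].
out = a[rows, indices]
print(out[:, :, 0])  # values [[1., 2., 3.], [7., 8., 9.]]
```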

To generalise this a bit more, you may use something like:

n = a.shape[0]
first_indices = torch.arange(n)[:, None]
a[first_indices, indices]
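Wrapped up as a function, the same arange trick works for any number of batches. This is a minimal sketch; the name `sort_points_by_coord` and its `coord` parameter are hypothetical, not part of PyTorch. Creating the arange with `device=a.device` keeps both index tensors on the same device as `a`.

```python
import torch

def sort_points_by_coord(a: torch.Tensor, coord: int = 0) -> torch.Tensor:
    """Sort each batch of `a` (shape [bn, k, d]) by column `coord`.

    Hypothetical helper: a thin wrapper around advanced indexing.
    """
    indices = a[:, :, coord].sort(dim=1).indices                 # [bn, k]
    first_indices = torch.arange(a.shape[0], device=a.device)[:, None]  # [bn, 1]
    # The two index tensors broadcast to [bn, k]; the last dim is kept.
    return a[first_indices, indices]                             # [bn, k, d]

a = torch.randn(3, 6, 2)
a_sorted = sort_points_by_coord(a)
print(a_sorted.shape)  # torch.Size([3, 6, 2])
```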

This is a little tricky, so here's an example:

>>> a = torch.randn(2,4,2)
>>> a
tensor([[[-0.2050, -0.1651],
         [ 0.5688,  1.0082],
         [-1.5964, -0.9236],
         [ 0.3093, -0.2445]],

        [[ 1.0586,  1.0048],
         [ 0.0893,  2.4522],
         [ 2.1433, -1.2428],
         [ 0.1591,  2.4945]]])
>>> indices = a[:, :, 0].sort()[1]
>>> indices
tensor([[2, 0, 3, 1],
        [1, 3, 0, 2]])
>>> a[:, indices]
tensor([[[[-1.5964, -0.9236],
          [-0.2050, -0.1651],
          [ 0.3093, -0.2445],
          [ 0.5688,  1.0082]],

         [[ 0.5688,  1.0082],
          [ 0.3093, -0.2445],
          [-0.2050, -0.1651],
          [-1.5964, -0.9236]]],


        [[[ 2.1433, -1.2428],
          [ 1.0586,  1.0048],
          [ 0.1591,  2.4945],
          [ 0.0893,  2.4522]],

         [[ 0.0893,  2.4522],
          [ 0.1591,  2.4945],
          [ 1.0586,  1.0048],
          [ 2.1433, -1.2428]]]])
>>> a[0, indices]
tensor([[[-1.5964, -0.9236],
         [-0.2050, -0.1651],
         [ 0.3093, -0.2445],
         [ 0.5688,  1.0082]],

        [[ 0.5688,  1.0082],
         [ 0.3093, -0.2445],
         [-0.2050, -0.1651],
         [-1.5964, -0.9236]]])
>>> a[1, indices]
tensor([[[ 2.1433, -1.2428],
         [ 1.0586,  1.0048],
         [ 0.1591,  2.4945],
         [ 0.0893,  2.4522]],

        [[ 0.0893,  2.4522],
         [ 0.1591,  2.4945],
         [ 1.0586,  1.0048],
         [ 2.1433, -1.2428]]])
>>> a[0, indices[0]]
tensor([[-1.5964, -0.9236],
        [-0.2050, -0.1651],
        [ 0.3093, -0.2445],
        [ 0.5688,  1.0082]])
>>> a[1, indices[1]]
tensor([[ 0.0893,  2.4522],
        [ 0.1591,  2.4945],
        [ 1.0586,  1.0048],
        [ 2.1433, -1.2428]])
>>> a[[[0], [1]], indices]
tensor([[[-1.5964, -0.9236],
         [-0.2050, -0.1651],
         [ 0.3093, -0.2445],
         [ 0.5688,  1.0082]],

        [[ 0.0893,  2.4522],
         [ 0.1591,  2.4945],
         [ 1.0586,  1.0048],
         [ 2.1433, -1.2428]]])
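An alternative, not used in the answer above, is `torch.gather`, which takes an index tensor with the same number of dimensions as the input. The sketch below repeats each row index across the coordinate dimension and checks that the result matches the advanced-indexing version on random data.

```python
import torch

a = torch.randn(2, 4, 2)
indices = a[:, :, 0].sort(dim=1).indices          # [bn, k]

# gather needs an index of the same rank as `a`, so repeat each row index
# across the last (coordinate) dimension: [bn, k] -> [bn, k, 2].
idx = indices.unsqueeze(-1).expand(-1, -1, a.shape[-1])
a_gathered = torch.gather(a, 1, idx)              # out[i, j, c] = a[i, idx[i, j, c], c]

# Same result as the advanced-indexing approach.
a_indexed = a[torch.arange(a.shape[0])[:, None], indices]
print(torch.equal(a_gathered, a_indexed))  # True
```

Both produce a [bn, k, 2] tensor; which one reads better is largely a matter of taste.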
