NeuralNew

Reputation: 116

Shuffling along a given axis in PyTorch

I have a dataset that is loaded with dimensions [batch_size, seq_len, n_features] (e.g. torch.Size([16, 600, 130])).

I want to shuffle this data along the sequence-length axis (axis=1) without altering the batch ordering or the feature-vector ordering in PyTorch.

Further explanation: as an example, let's say my batch size is 3, sequence length is 3 and number of features is 2.

Given tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]]), I want to randomly shuffle it like this:

tensor([[[3,3],[1,1],[2,2]],[[6,6],[5,5],[4,4]],[[8,8],[7,7],[9,9]]])

Are there any PyTorch functions that will do that automatically for me, or does anyone know what would be a good way to implement this?

Upvotes: 3

Views: 3531

Answers (2)

Alex Gaudio

Reputation: 1944

A two-part answer to shuffle along an axis.

  • First, a direct solution that gives a different random permutation of the axis-1 "rows" for each batch element.
  • Second, a generalized shuffle "row" function to shuffle any axis.

Side note 1: Sorry my answer is several months late - I just had this question myself and I couldn't find an easy solution to the problem online, so here it is.

Side note 2: The nice answer from @GoodDeeds applies the same random permutation to every batch element. The approach here draws a different permutation for each one.

First, an intuitive example for axis=1:

Input:

>>> a
tensor([[[1, 1],
         [2, 2],
         [3, 3]],

        [[4, 4],
         [5, 5],
         [6, 6]],

        [[7, 7],
         [8, 8],
         [9, 9]]])

Select random "rows" of axis 1.

>>> z = torch.rand(a.shape[:2]).argsort(1)  # define random "row" indices
>>> z = z.unsqueeze(-1).repeat(1, 1, *(a.shape[2:]))  # reformat this for the gather operation.  Note that this works only for dim=1.
>>> output = a.gather(1, z)

Output:

>>> output
tensor([[[2, 2],
         [3, 3],
         [1, 1]],

        [[5, 5],
         [6, 6],
         [4, 4]],

        [[8, 8],
         [9, 9],
         [7, 7]]])
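
As a self-contained sketch of the same steps, the snippet below rebuilds the example tensor and checks the two properties the question asks for. It swaps `repeat` for `expand` (my substitution, to avoid copying the index tensor), and the names `idx`, `same_rows` and `features_intact` are illustrative only.

```python
import torch

torch.manual_seed(0)  # only to make this sketch reproducible

a = torch.tensor([[[1, 1], [2, 2], [3, 3]],
                  [[4, 4], [5, 5], [6, 6]],
                  [[7, 7], [8, 8], [9, 9]]])

# One independent permutation of the sequence positions per batch element.
idx = torch.rand(a.shape[:2]).argsort(dim=1)         # shape [batch, seq_len]
# Broadcast the indices over the feature axis (a view, no copy).
idx = idx.unsqueeze(-1).expand(-1, -1, a.shape[2])   # shape [batch, seq_len, n_features]
shuffled = a.gather(1, idx)

# Each batch element keeps the same set of feature vectors ...
same_rows = torch.equal(shuffled.sort(dim=1).values, a.sort(dim=1).values)
# ... and each feature vector is still intact (both entries equal in this example).
features_intact = bool((shuffled[..., 0] == shuffled[..., 1]).all())
print(shuffled)
print(same_rows, features_intact)
```

Because the permutations are drawn independently, two batch elements can still end up with the same ordering by chance, especially for short sequences.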

Second, a generalization to any axis:

It would be great if PyTorch had this function in its standard lib. I'll raise an issue and link to this post.

import torch

def shufflerow(tensor, axis):
    # One random permutation per leading slice (axis must be non-negative).
    row_perm = torch.rand(tensor.shape[:axis+1], device=tensor.device).argsort(axis)
    for _ in range(tensor.ndim-axis-1):
        row_perm.unsqueeze_(-1)  # append trailing singleton axes
    # Tile the indices over the trailing axes, as gather expects.
    row_perm = row_perm.repeat(*[1 for _ in range(axis+1)], *(tensor.shape[axis+1:]))
    return tensor.gather(axis, row_perm)

Example:

>>> x = torch.arange(2*3*4).reshape(2,3,4)
>>> x
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

Shuffle axis 0:

>>> shufflerow(x, 0)
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])

Shuffle axis 1:

>>> shufflerow(x, 1)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[16, 17, 18, 19],
         [12, 13, 14, 15],
         [20, 21, 22, 23]]])

Shuffle axis 2:

>>> shufflerow(x, 2)
tensor([[[ 2,  0,  1,  3],
         [ 5,  6,  7,  4],
         [11, 10,  9,  8]],

        [[15, 14, 13, 12],
         [18, 17, 19, 16],
         [23, 20, 22, 21]]])
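
For anyone who wants reproducible shuffles or negative axes, here is a hedged variant of the function above. `shuffle_along_axis` and its `generator` parameter are hypothetical names of mine, not part of PyTorch, and the sketch assumes any generator passed in lives on the same device as the tensor.

```python
import torch

def shuffle_along_axis(tensor, axis, generator=None):
    """Hypothetical variant of shufflerow: independent permutations along `axis`."""
    axis = axis % tensor.ndim  # map negative axes such as -1 to their positive index
    # Random keys over the leading axes up to and including `axis`.
    keys = torch.rand(tensor.shape[:axis + 1], generator=generator, device=tensor.device)
    perm = keys.argsort(dim=axis)
    # Append singleton axes, then broadcast the indices over the trailing axes.
    perm = perm.reshape(tuple(perm.shape) + (1,) * (tensor.ndim - axis - 1))
    return tensor.gather(axis, perm.expand(tensor.shape))

g = torch.Generator().manual_seed(0)
x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
y = shuffle_along_axis(x, -1, generator=g)

# x is already sorted along every axis, so sorting the shuffled result
# along the shuffled axis should give x back.
rows_preserved = torch.equal(y.sort(dim=-1).values, x)
print(y)
print(rows_preserved)
```

Using `expand` instead of `repeat` keeps the index tensor as a broadcast view, which may matter for large trailing dimensions.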

Upvotes: 2

GoodDeeds

Reputation: 8497

You can use torch.randperm.

For a tensor t, the following applies a single random permutation along axis 1, shared by all batch elements:

t[:,torch.randperm(t.shape[1]),:]

For your example:

>>> t = torch.tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
>>> t
tensor([[[1, 1],
         [2, 2],
         [3, 3]],

        [[4, 4],
         [5, 5],
         [6, 6]],

        [[7, 7],
         [8, 8],
         [9, 9]]])
>>> t[:,torch.randperm(t.shape[1]),:]
tensor([[[2, 2],
         [3, 3],
         [1, 1]],

        [[5, 5],
         [6, 6],
         [4, 4]],

        [[8, 8],
         [9, 9],
         [7, 7]]])
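
To make the behaviour of the indexing expression explicit, here is a small sketch that checks two things: every batch element is reordered by the same permutation, and `index_select` is an equivalent spelling. The variable names (`perm`, `shared`, `same_for_all`, `equivalent`) are illustrative only.

```python
import torch

torch.manual_seed(0)  # only to make this sketch reproducible

t = torch.tensor([[[1, 1], [2, 2], [3, 3]],
                  [[4, 4], [5, 5], [6, 6]],
                  [[7, 7], [8, 8], [9, 9]]])

perm = torch.randperm(t.shape[1])   # one permutation of the sequence positions
shared = t[:, perm, :]

# Every batch element is reordered by the same `perm`.
same_for_all = all(torch.equal(shared[b], t[b][perm]) for b in range(t.shape[0]))
# index_select along dim 1 performs the same reordering.
equivalent = torch.equal(shared, t.index_select(1, perm))
print(same_for_all, equivalent)
```

If each batch element needs its own ordering, a per-batch index (for example from `argsort` on random keys combined with `gather`) is needed instead.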

Upvotes: 1
