I have a dataset that gets loaded with dimensions [batch_size, seq_len, n_features]
(e.g. torch.Size([16, 600, 130])).
I want to shuffle this data along the sequence-length axis (axis=1)
without altering the batch ordering or the feature-vector ordering in PyTorch.
Further explanation: for example, say my batch size is 3, the sequence length is 3 and the number of features is 2:
tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
I want to be able to shuffle it randomly, like this:
tensor([[[3,3],[1,1],[2,2]],[[6,6],[5,5],[4,4]],[[8,8],[7,7],[9,9]]])
Are there any PyTorch functions that will do that automatically for me, or does anyone know what would be a good way to implement this?
Side note 1: Sorry, my answer is several months late. I just had this question myself and couldn't find an easy solution online, so here it is.
Side note 2: The nice answer from @GoodDeeds applies the same random permutation to every index of the other axes (e.g. every batch element). This answer draws a different permutation for each of them.
Input:
>>> a
tensor([[[1, 1],
         [2, 2],
         [3, 3]],

        [[4, 4],
         [5, 5],
         [6, 6]],

        [[7, 7],
         [8, 8],
         [9, 9]]])
Select random "rows" along axis 1:
>>> z = torch.rand(a.shape[:2]).argsort(1) # define random "row" indices
>>> z = z.unsqueeze(-1).repeat(1, 1, *(a.shape[2:])) # reformat this for the gather operation. Note that this works only for dim=1.
>>> output = a.gather(1, z)
Output:
>>> output
tensor([[[2, 2],
         [3, 3],
         [1, 1]],

        [[5, 5],
         [6, 6],
         [4, 4]],

        [[8, 8],
         [9, 9],
         [7, 7]]])
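A minimal variant of the same gather approach, assuming a 3-D [batch_size, seq_len, n_features] tensor: `expand` broadcasts the index tensor as a view, whereas `repeat` copies it n_features times, which can matter for shapes like [16, 600, 130]. The name `idx` is just illustrative.

```python
import torch

a = torch.tensor([[[1, 1], [2, 2], [3, 3]],
                  [[4, 4], [5, 5], [6, 6]],
                  [[7, 7], [8, 8], [9, 9]]])

# one independent random permutation of the seq_len axis per batch element: [B, L]
z = torch.rand(a.shape[:2]).argsort(dim=1)
# expand creates a broadcast view [B, L, F] instead of copying the indices
idx = z.unsqueeze(-1).expand(-1, -1, a.shape[2])
# output[b, l, f] = a[b, z[b, l], f], so whole feature vectors move together
output = a.gather(1, idx)
```

The result is the same as with `repeat`; only the memory used for the index tensor differs.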
It would be great if PyTorch had this function in its standard library. I'll raise an issue and link to this post.
A generalized version that shuffles along any axis:
import torch

def shufflerow(tensor, axis):
    # independent random permutation along `axis` for every index of the leading axes
    row_perm = torch.rand(tensor.shape[:axis+1], device=tensor.device).argsort(axis)
    for _ in range(tensor.ndim - axis - 1):
        row_perm.unsqueeze_(-1)  # add trailing singleton dims to match tensor.ndim
    # copy the indices over the trailing axes so gather moves whole "rows"
    row_perm = row_perm.repeat(*[1] * (axis + 1), *tensor.shape[axis+1:])
    return tensor.gather(axis, row_perm)
Example:
>>> x = torch.arange(2*3*4).reshape(2,3,4)
>>> x
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
Shuffle axis 0:
>>> shufflerow(x, 0)
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])
Shuffle axis 1:
>>> shufflerow(x, 1)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[16, 17, 18, 19],
         [12, 13, 14, 15],
         [20, 21, 22, 23]]])
Shuffle axis 2:
>>> shufflerow(x, 2)
tensor([[[ 2,  0,  1,  3],
         [ 5,  6,  7,  4],
         [11, 10,  9,  8]],

        [[15, 14, 13, 12],
         [18, 17, 19, 16],
         [23, 20, 22, 21]]])
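To check that a shuffled tensor only reorders positions within each batch element and keeps every feature vector intact (the requirement in the question), a small check could compare the multiset of rows per batch element. This is a sketch under the assumption of a 3-D [batch_size, seq_len, n_features] tensor; `is_row_shuffle` is a hypothetical helper, not a PyTorch function.

```python
import torch

def is_row_shuffle(original: torch.Tensor, shuffled: torch.Tensor) -> bool:
    # Hypothetical helper: True if, for every batch element b, shuffled[b] holds exactly
    # the same feature vectors (rows along axis 1) as original[b], in any order.
    if original.shape != shuffled.shape:
        return False
    for orig_b, shuf_b in zip(original, shuffled):
        # sorted unique rows plus their counts describe the multiset of rows
        rows_o, counts_o = torch.unique(orig_b, dim=0, return_counts=True)
        rows_s, counts_s = torch.unique(shuf_b, dim=0, return_counts=True)
        if not (torch.equal(rows_o, rows_s) and torch.equal(counts_o, counts_s)):
            return False
    return True

a = torch.tensor([[[1, 1], [2, 2], [3, 3]],
                  [[4, 4], [5, 5], [6, 6]],
                  [[7, 7], [8, 8], [9, 9]]])
# the desired output from the question: rows reordered within each batch element
wanted = torch.tensor([[[3, 3], [1, 1], [2, 2]],
                       [[6, 6], [5, 5], [4, 4]],
                       [[8, 8], [7, 7], [9, 9]]])
# rows moved between batch elements, which breaks the batch ordering
mixed = torch.tensor([[[4, 4], [1, 1], [2, 2]],
                      [[6, 6], [5, 5], [3, 3]],
                      [[8, 8], [7, 7], [9, 9]]])

print(is_row_shuffle(a, wanted))  # True
print(is_row_shuffle(a, mixed))   # False
```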
You can use torch.randperm. For a tensor t, you can use:
t[:, torch.randperm(t.shape[1]), :]
This applies the same permutation to every batch element.
For your example:
>>> t = torch.tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
>>> t
tensor([[[1, 1],
         [2, 2],
         [3, 3]],

        [[4, 4],
         [5, 5],
         [6, 6]],

        [[7, 7],
         [8, 8],
         [9, 9]]])
>>> t[:,torch.randperm(t.shape[1]),:]
tensor([[[2, 2],
         [3, 3],
         [1, 1]],

        [[5, 5],
         [6, 6],
         [4, 4]],

        [[8, 8],
         [9, 9],
         [7, 7]]])
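The randperm indexing above reorders every batch element the same way. If you want an independent permutation per batch element, as in the question's example, advanced indexing can do it without gather. A minimal sketch, assuming a tensor of shape [batch_size, seq_len, n_features]; `shuffle_per_batch` is a hypothetical name, not a PyTorch function.

```python
import torch

def shuffle_per_batch(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: permute axis 1 independently for each batch element,
    # keeping the batch order and every feature vector intact.
    batch_size, seq_len = t.shape[:2]
    # argsort of uniform noise gives one random permutation per batch element: [B, L]
    perm = torch.rand(batch_size, seq_len, device=t.device).argsort(dim=1)
    # [B, 1] batch indices broadcast against perm [B, L]: result[b, l] = t[b, perm[b, l]]
    batch_idx = torch.arange(batch_size, device=t.device).unsqueeze(1)
    return t[batch_idx, perm]

t = torch.tensor([[[1, 1], [2, 2], [3, 3]],
                  [[4, 4], [5, 5], [6, 6]],
                  [[7, 7], [8, 8], [9, 9]]])
out = shuffle_per_batch(t)
print(out.shape)  # torch.Size([3, 3, 2])
```

Because `perm` has one row per batch element, each element gets its own ordering, while the trailing feature axis is carried along untouched.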