I'm working in PyTorch and I have a tensor x of size batch_size x d x S. It should be interpreted as a batch of sequences of length S, where every sequence element is d-dimensional. Every sequence is actually the overlap of multiple sub-sequences, in the following sense:
- Every sub-sequence has length past_size + present_size, i.e. past_size d-dimensional elements followed by another present_size elements.
- The present_size sections are equispaced by present_size elements, and they are placed in the right-most positions.

To give an example with batch_size=1 and d=1, consider x = [1,2,3,4,5,6,7,8,9], where present_size = 2 and past_size = 3. The resulting subsequences would be:
[1,2,3,4,5]
[3,4,5,6,7]
[5,6,7,8,9]
The end goal is to split every sequence into its, say, N sub-sequences, obtaining a tensor of shape batch_size*N x d x (past_size+present_size).
My second try is the following:
import torch

def seq(x, present_size, total_size, N):
    # (batch_size, d, S) -> (batch_size, d, N, total_size): sliding windows
    z = x.unfold(-1, total_size, present_size)
    # flatten the window axes -> (batch_size, d, N*total_size)
    v = torch.flatten(z, start_dim=2)
    # split off the N windows and stack them along the batch axis
    s = torch.cat(torch.chunk(v, N, -1), 0)
    return s
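For reference, with this fixed signature, running it on the toy example above reproduces the three sub-sequences:

>>> x = torch.arange(1, 10, dtype=torch.float).view(1, 1, 9)  # batch_size=1, d=1, S=9
>>> seq(x, present_size=2, total_size=5, N=3)
tensor([[[1., 2., 3., 4., 5.]],
        [[3., 4., 5., 6., 7.]],
        [[5., 6., 7., 8., 9.]]])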
Edit
In the above example, N = 3.
Moreover, we have the following relation: N*present_size + past_size = S
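In code, N can therefore be recovered from the other quantities (assuming the relation above holds exactly):

N = (S - past_size) // present_size  # e.g. (9 - 3) // 2 = 3 in the toy example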
Input-output
Here is an example with N=4, present_size = 1, past_size = 2.
x = torch.rand(4,8,6) # batch_size = 4, d = 8, S = 6 = N*present_size + past_size
>>> tensor([[[0.5667, 0.5300, 0.2460, 0.4327, 0.4727, 0.5649],
[0.0360, 0.6687, 0.0167, 0.5359, 0.9804, 0.8778],
[0.3703, 0.4884, 0.1505, 0.5463, 0.8114, 0.3270],
[0.2932, 0.4928, 0.3933, 0.2433, 0.7053, 0.5222],
[0.6667, 0.2014, 0.7107, 0.7535, 0.2816, 0.6515],
[0.5285, 0.4150, 0.2557, 0.2144, 0.8317, 0.5448],
[0.7971, 0.6609, 0.1811, 0.7788, 0.6649, 0.1848],
[0.6902, 0.3999, 0.8719, 0.7624, 0.5216, 0.3494]],
[[0.0196, 0.7850, 0.2796, 0.4173, 0.8076, 0.5709],
[0.4566, 0.4814, 0.0568, 0.8568, 0.9119, 0.4030],
[0.4031, 0.8887, 0.3782, 0.8015, 0.9835, 0.6043],
[0.3557, 0.5960, 0.2102, 0.8165, 0.1938, 0.4948],
[0.8163, 0.7907, 0.3711, 0.6835, 0.8021, 0.1897],
[0.7790, 0.2621, 0.3769, 0.3830, 0.7140, 0.2309],
[0.5831, 0.0246, 0.6548, 0.8694, 0.1988, 0.5470],
[0.1192, 0.2928, 0.4240, 0.2624, 0.7959, 0.4091]],
[[0.7959, 0.7144, 0.4523, 0.5090, 0.6053, 0.4071],
[0.4742, 0.0224, 0.9939, 0.9757, 0.0732, 0.6213],
[0.5211, 0.1149, 0.8218, 0.7061, 0.1807, 0.2822],
[0.1456, 0.7331, 0.9107, 0.9533, 0.2438, 0.4031],
[0.0958, 0.2623, 0.0828, 0.2861, 0.0474, 0.8349],
[0.1740, 0.3658, 0.2416, 0.6735, 0.4013, 0.8896],
[0.6934, 0.8709, 0.4017, 0.6121, 0.5824, 0.5803],
[0.4811, 0.1036, 0.4356, 0.6441, 0.5859, 0.4683]],
[[0.2479, 0.9247, 0.3216, 0.6844, 0.1701, 0.4609],
[0.3320, 0.4908, 0.0458, 0.9887, 0.4725, 0.7511],
[0.0594, 0.1978, 0.8830, 0.9126, 0.4821, 0.7731],
[0.3729, 0.4921, 0.9266, 0.7827, 0.8101, 0.6258],
[0.4998, 0.7596, 0.1160, 0.3928, 0.4773, 0.7892],
[0.0215, 0.1325, 0.5940, 0.2094, 0.3109, 0.9281],
[0.7960, 0.1707, 0.1793, 0.7335, 0.2065, 0.6204],
[0.6350, 0.9696, 0.5099, 0.7375, 0.7601, 0.1405]]])
r = seq(x, 1, 2+1, 4)
>>> tensor([[[0.5667, 0.5300, 0.2460],
[0.0360, 0.6687, 0.0167],
[0.3703, 0.4884, 0.1505],
[0.2932, 0.4928, 0.3933],
[0.6667, 0.2014, 0.7107],
[0.5285, 0.4150, 0.2557],
[0.7971, 0.6609, 0.1811],
[0.6902, 0.3999, 0.8719]],
[[0.0196, 0.7850, 0.2796],
[0.4566, 0.4814, 0.0568],
[0.4031, 0.8887, 0.3782],
[0.3557, 0.5960, 0.2102],
[0.8163, 0.7907, 0.3711],
[0.7790, 0.2621, 0.3769],
[0.5831, 0.0246, 0.6548],
[0.1192, 0.2928, 0.4240]],
[[0.7959, 0.7144, 0.4523],
[0.4742, 0.0224, 0.9939],
[0.5211, 0.1149, 0.8218],
[0.1456, 0.7331, 0.9107],
[0.0958, 0.2623, 0.0828],
[0.1740, 0.3658, 0.2416],
[0.6934, 0.8709, 0.4017],
[0.4811, 0.1036, 0.4356]],
[[0.2479, 0.9247, 0.3216],
[0.3320, 0.4908, 0.0458],
[0.0594, 0.1978, 0.8830],
[0.3729, 0.4921, 0.9266],
[0.4998, 0.7596, 0.1160],
[0.0215, 0.1325, 0.5940],
[0.7960, 0.1707, 0.1793],
[0.6350, 0.9696, 0.5099]],
[[0.5300, 0.2460, 0.4327],
[0.6687, 0.0167, 0.5359],
[0.4884, 0.1505, 0.5463],
[0.4928, 0.3933, 0.2433],
[0.2014, 0.7107, 0.7535],
[0.4150, 0.2557, 0.2144],
[0.6609, 0.1811, 0.7788],
[0.3999, 0.8719, 0.7624]],
[[0.7850, 0.2796, 0.4173],
[0.4814, 0.0568, 0.8568],
[0.8887, 0.3782, 0.8015],
[0.5960, 0.2102, 0.8165],
[0.7907, 0.3711, 0.6835],
[0.2621, 0.3769, 0.3830],
[0.0246, 0.6548, 0.8694],
[0.2928, 0.4240, 0.2624]],
[[0.7144, 0.4523, 0.5090],
[0.0224, 0.9939, 0.9757],
[0.1149, 0.8218, 0.7061],
[0.7331, 0.9107, 0.9533],
[0.2623, 0.0828, 0.2861],
[0.3658, 0.2416, 0.6735],
[0.8709, 0.4017, 0.6121],
[0.1036, 0.4356, 0.6441]],
[[0.9247, 0.3216, 0.6844],
[0.4908, 0.0458, 0.9887],
[0.1978, 0.8830, 0.9126],
[0.4921, 0.9266, 0.7827],
[0.7596, 0.1160, 0.3928],
[0.1325, 0.5940, 0.2094],
[0.1707, 0.1793, 0.7335],
[0.9696, 0.5099, 0.7375]],
[[0.2460, 0.4327, 0.4727],
[0.0167, 0.5359, 0.9804],
[0.1505, 0.5463, 0.8114],
[0.3933, 0.2433, 0.7053],
[0.7107, 0.7535, 0.2816],
[0.2557, 0.2144, 0.8317],
[0.1811, 0.7788, 0.6649],
[0.8719, 0.7624, 0.5216]],
[[0.2796, 0.4173, 0.8076],
[0.0568, 0.8568, 0.9119],
[0.3782, 0.8015, 0.9835],
[0.2102, 0.8165, 0.1938],
[0.3711, 0.6835, 0.8021],
[0.3769, 0.3830, 0.7140],
[0.6548, 0.8694, 0.1988],
[0.4240, 0.2624, 0.7959]],
[[0.4523, 0.5090, 0.6053],
[0.9939, 0.9757, 0.0732],
[0.8218, 0.7061, 0.1807],
[0.9107, 0.9533, 0.2438],
[0.0828, 0.2861, 0.0474],
[0.2416, 0.6735, 0.4013],
[0.4017, 0.6121, 0.5824],
[0.4356, 0.6441, 0.5859]],
[[0.3216, 0.6844, 0.1701],
[0.0458, 0.9887, 0.4725],
[0.8830, 0.9126, 0.4821],
[0.9266, 0.7827, 0.8101],
[0.1160, 0.3928, 0.4773],
[0.5940, 0.2094, 0.3109],
[0.1793, 0.7335, 0.2065],
[0.5099, 0.7375, 0.7601]],
[[0.4327, 0.4727, 0.5649],
[0.5359, 0.9804, 0.8778],
[0.5463, 0.8114, 0.3270],
[0.2433, 0.7053, 0.5222],
[0.7535, 0.2816, 0.6515],
[0.2144, 0.8317, 0.5448],
[0.7788, 0.6649, 0.1848],
[0.7624, 0.5216, 0.3494]],
[[0.4173, 0.8076, 0.5709],
[0.8568, 0.9119, 0.4030],
[0.8015, 0.9835, 0.6043],
[0.8165, 0.1938, 0.4948],
[0.6835, 0.8021, 0.1897],
[0.3830, 0.7140, 0.2309],
[0.8694, 0.1988, 0.5470],
[0.2624, 0.7959, 0.4091]],
[[0.5090, 0.6053, 0.4071],
[0.9757, 0.0732, 0.6213],
[0.7061, 0.1807, 0.2822],
[0.9533, 0.2438, 0.4031],
[0.2861, 0.0474, 0.8349],
[0.6735, 0.4013, 0.8896],
[0.6121, 0.5824, 0.5803],
[0.6441, 0.5859, 0.4683]],
[[0.6844, 0.1701, 0.4609],
[0.9887, 0.4725, 0.7511],
[0.9126, 0.4821, 0.7731],
[0.7827, 0.8101, 0.6258],
[0.3928, 0.4773, 0.7892],
[0.2094, 0.3109, 0.9281],
[0.7335, 0.2065, 0.6204],
[0.7375, 0.7601, 0.1405]]])
Answer
torch.gather
You can see this problem as reassigning each element to a new position. This has to be done using a tensor containing the indices of the permutation you wish to apply.
If you look at the indices of the last dimension for input x (we will take your example with x.shape = (4, 8, 6)), you have them ordered this way:
tensor([[[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
... 4 more
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5]],
... 2 more
[[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
... 4 more
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5]]])
Now the permutation of indices should look like the following (considering N=4, present_size=1, and past_size=2). Keep in mind I'm only representing two of the four dimensions involved:
tensor([[0, 1, 2],
[1, 2, 3],
[2, 3, 4],
[3, 4, 5]])
From there it will be easy to construct the new tensor using torch.gather. The operation will effectively create a tensor out defined in the following way:
out[i][j][k][l] = x[i][j][k][indices[i, j, k, l]]
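As a minimal 1-D illustration of these semantics, gather picks out[l] = t[idx[l]]:

>>> t = torch.tensor([10, 20, 30])
>>> torch.gather(t, 0, torch.tensor([2, 0, 0]))
tensor([30, 10, 10])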
In order to construct such a tensor of indices, we will use torch.arange. The following are the base indices:
>>> arr = torch.arange(total_size)[None].repeat(N, 1)
tensor([[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2]])
to which we add a displacement of present_size
accumulated over the rows:
>>> disp = torch.arange(0, total_size + 1, step=present_size)[None].T
tensor([[0],
[1],
[2],
[3]])
The resulting minimal tensor of indices is:
>>> indices = arr + disp
tensor([[0, 1, 2],
[1, 2, 3],
[2, 3, 4],
[3, 4, 5]])
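As a side note, broadcasting can build the same tensor in one step, by adding a column of offsets to a row of base indices:

>>> torch.arange(total_size) + present_size * torch.arange(N)[:, None]
tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4],
        [3, 4, 5]])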
torch.gather
First, we need to replicate x along a new leading dimension of size N: the number of rows in the tensor of indices.
>>> x_r = x[None].expand(N, *(-1,)*x.ndim)
>>> x.shape, x_r.shape
(torch.Size([4, 8, 6]), torch.Size([4, 4, 8, 6]))
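Note that Tensor.expand does not copy any data: it returns a view with stride 0 along the expanded dimension, sharing the original storage:

>>> x_r.data_ptr() == x.data_ptr()
True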
In order to use torch.gather, we need the input and the tensor of indices to have the same number of dimensions, with matching sizes on every axis except the one we gather along. To achieve this we can make views of our tensors using Tensor.expand: here we insert two additional dimensions on indices and expand them to match the sizes of x's first and second axes.
>>> i_r = indices[:, None, None, :].expand(-1, x.size(0), x.size(1), -1)
>>> indices.shape, i_r.shape
(torch.Size([4, 3]), torch.Size([4, 4, 8, 3]))
Then apply the gather function on the last axis:
>>> torch.gather(x_r, dim=-1, index=i_r)
tensor([[[[0.5667, 0.5300, 0.2460],
[0.0360, 0.6687, 0.0167],
[0.3703, 0.4884, 0.1505],
[0.2932, 0.4928, 0.3933],
[0.6667, 0.2014, 0.7107],
[0.5285, 0.4150, 0.2557],
[0.7971, 0.6609, 0.1811],
[0.6902, 0.3999, 0.8719]],
...
[[0.6844, 0.1701, 0.4609],
[0.9887, 0.4725, 0.7511],
[0.9126, 0.4821, 0.7731],
[0.7827, 0.8101, 0.6258],
[0.3928, 0.4773, 0.7892],
[0.2094, 0.3109, 0.9281],
[0.7335, 0.2065, 0.6204],
[0.7375, 0.7601, 0.1405]]]])
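The result above has shape (N, batch_size, d, total_size) = (4, 4, 8, 3). To get the batch_size*N x d x (past_size+present_size) shape requested in the question, it should suffice to merge the first two dimensions, which also matches the ordering of the expected output (first all batches for window 0, then window 1, and so on):

>>> out = torch.gather(x_r, dim=-1, index=i_r)
>>> out.flatten(0, 1).shape
torch.Size([16, 8, 3])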
If you have any questions, please don't hesitate to ask!