Lilla

Reputation: 209

Splitting tensor into sub-tensors in overlapping fashion

I'm working in PyTorch and I have a tensor x of size batch_size x d x S. It should be understood as a batch of sequences of length S, where every sequence element is d-dimensional. Every sequence is actually the overlap of multiple sub-sequences, in the following sense:

As an example, with batch_size=1, d=1, consider x = [1,2,3,4,5,6,7,8,9], where present_size = 2, past_size = 3. The resulting sub-sequences would be:

  1. [1,2,3,4,5]
  2. [3,4,5,6,7]
  3. [5,6,7,8,9]
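
Each sub-sequence has length past_size + present_size = 5, and consecutive sub-sequences start present_size = 2 positions apart. For reference, Tensor.unfold reproduces exactly this windowing:

>>> torch.arange(1, 10).unfold(0, 5, 2)
tensor([[1, 2, 3, 4, 5],
        [3, 4, 5, 6, 7],
        [5, 6, 7, 8, 9]])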

The end goal is to split every sequence into its N sub-sequences, to get a tensor of shape batch_size*N x d x (past_size+present_size).

My second try is the following:

def seq(x, present_size, past_size, N):
    total_size = present_size + past_size
    # (batch_size, d, N, total_size): windows of length total_size, stride present_size
    z = x.unfold(-1, total_size, present_size)
    # (batch_size, d, N*total_size)
    v = torch.flatten(z, start_dim=2)
    # split the windows apart again and stack them along the batch dimension
    s = torch.cat(torch.chunk(v, N, -1), 0)
    return s

Is there a more efficient way? Is it possible to backpropagate through such a function?

Edit: In the above example, N = 3.

Moreover, we have the following relation: N*present_size + past_size = S
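
A quick sanity check on the toy example above suggests the answer to the backprop question is yes, since unfold, flatten, chunk and cat are all differentiable; overlapping positions accumulate gradient:

>>> x = torch.arange(1., 10.).reshape(1, 1, 9).requires_grad_()
>>> seq(x, 2, 3, 3).sum().backward()
>>> x.grad
tensor([[[1., 1., 2., 2., 3., 2., 2., 1., 1.]]])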


Input-output

Here is an example with N=4, present_size = 1, past_size = 2.

x = torch.rand(4, 8, 6) # batch_size = 4, d = 8, 6 = N*present_size + past_size
>>> tensor([[[0.5667, 0.5300, 0.2460, 0.4327, 0.4727, 0.5649],
     [0.0360, 0.6687, 0.0167, 0.5359, 0.9804, 0.8778],
     [0.3703, 0.4884, 0.1505, 0.5463, 0.8114, 0.3270],
     [0.2932, 0.4928, 0.3933, 0.2433, 0.7053, 0.5222],
     [0.6667, 0.2014, 0.7107, 0.7535, 0.2816, 0.6515],
     [0.5285, 0.4150, 0.2557, 0.2144, 0.8317, 0.5448],
     [0.7971, 0.6609, 0.1811, 0.7788, 0.6649, 0.1848],
     [0.6902, 0.3999, 0.8719, 0.7624, 0.5216, 0.3494]],

    [[0.0196, 0.7850, 0.2796, 0.4173, 0.8076, 0.5709],
     [0.4566, 0.4814, 0.0568, 0.8568, 0.9119, 0.4030],
     [0.4031, 0.8887, 0.3782, 0.8015, 0.9835, 0.6043],
     [0.3557, 0.5960, 0.2102, 0.8165, 0.1938, 0.4948],
     [0.8163, 0.7907, 0.3711, 0.6835, 0.8021, 0.1897],
     [0.7790, 0.2621, 0.3769, 0.3830, 0.7140, 0.2309],
     [0.5831, 0.0246, 0.6548, 0.8694, 0.1988, 0.5470],
     [0.1192, 0.2928, 0.4240, 0.2624, 0.7959, 0.4091]],

    [[0.7959, 0.7144, 0.4523, 0.5090, 0.6053, 0.4071],
     [0.4742, 0.0224, 0.9939, 0.9757, 0.0732, 0.6213],
     [0.5211, 0.1149, 0.8218, 0.7061, 0.1807, 0.2822],
     [0.1456, 0.7331, 0.9107, 0.9533, 0.2438, 0.4031],
     [0.0958, 0.2623, 0.0828, 0.2861, 0.0474, 0.8349],
     [0.1740, 0.3658, 0.2416, 0.6735, 0.4013, 0.8896],
     [0.6934, 0.8709, 0.4017, 0.6121, 0.5824, 0.5803],
     [0.4811, 0.1036, 0.4356, 0.6441, 0.5859, 0.4683]],

    [[0.2479, 0.9247, 0.3216, 0.6844, 0.1701, 0.4609],
     [0.3320, 0.4908, 0.0458, 0.9887, 0.4725, 0.7511],
     [0.0594, 0.1978, 0.8830, 0.9126, 0.4821, 0.7731],
     [0.3729, 0.4921, 0.9266, 0.7827, 0.8101, 0.6258],
     [0.4998, 0.7596, 0.1160, 0.3928, 0.4773, 0.7892],
     [0.0215, 0.1325, 0.5940, 0.2094, 0.3109, 0.9281],
     [0.7960, 0.1707, 0.1793, 0.7335, 0.2065, 0.6204],
     [0.6350, 0.9696, 0.5099, 0.7375, 0.7601, 0.1405]]])


r = seq(x, 1, 2, 4)
>>> tensor([[[0.5667, 0.5300, 0.2460],
     [0.0360, 0.6687, 0.0167],
     [0.3703, 0.4884, 0.1505],
     [0.2932, 0.4928, 0.3933],
     [0.6667, 0.2014, 0.7107],
     [0.5285, 0.4150, 0.2557],
     [0.7971, 0.6609, 0.1811],
     [0.6902, 0.3999, 0.8719]],

    [[0.0196, 0.7850, 0.2796],
     [0.4566, 0.4814, 0.0568],
     [0.4031, 0.8887, 0.3782],
     [0.3557, 0.5960, 0.2102],
     [0.8163, 0.7907, 0.3711],
     [0.7790, 0.2621, 0.3769],
     [0.5831, 0.0246, 0.6548],
     [0.1192, 0.2928, 0.4240]],

    [[0.7959, 0.7144, 0.4523],
     [0.4742, 0.0224, 0.9939],
     [0.5211, 0.1149, 0.8218],
     [0.1456, 0.7331, 0.9107],
     [0.0958, 0.2623, 0.0828],
     [0.1740, 0.3658, 0.2416],
     [0.6934, 0.8709, 0.4017],
     [0.4811, 0.1036, 0.4356]],

    [[0.2479, 0.9247, 0.3216],
     [0.3320, 0.4908, 0.0458],
     [0.0594, 0.1978, 0.8830],
     [0.3729, 0.4921, 0.9266],
     [0.4998, 0.7596, 0.1160],
     [0.0215, 0.1325, 0.5940],
     [0.7960, 0.1707, 0.1793],
     [0.6350, 0.9696, 0.5099]],

    [[0.5300, 0.2460, 0.4327],
     [0.6687, 0.0167, 0.5359],
     [0.4884, 0.1505, 0.5463],
     [0.4928, 0.3933, 0.2433],
     [0.2014, 0.7107, 0.7535],
     [0.4150, 0.2557, 0.2144],
     [0.6609, 0.1811, 0.7788],
     [0.3999, 0.8719, 0.7624]],

    [[0.7850, 0.2796, 0.4173],
     [0.4814, 0.0568, 0.8568],
     [0.8887, 0.3782, 0.8015],
     [0.5960, 0.2102, 0.8165],
     [0.7907, 0.3711, 0.6835],
     [0.2621, 0.3769, 0.3830],
     [0.0246, 0.6548, 0.8694],
     [0.2928, 0.4240, 0.2624]],

    [[0.7144, 0.4523, 0.5090],
     [0.0224, 0.9939, 0.9757],
     [0.1149, 0.8218, 0.7061],
     [0.7331, 0.9107, 0.9533],
     [0.2623, 0.0828, 0.2861],
     [0.3658, 0.2416, 0.6735],
     [0.8709, 0.4017, 0.6121],
     [0.1036, 0.4356, 0.6441]],

    [[0.9247, 0.3216, 0.6844],
     [0.4908, 0.0458, 0.9887],
     [0.1978, 0.8830, 0.9126],
     [0.4921, 0.9266, 0.7827],
     [0.7596, 0.1160, 0.3928],
     [0.1325, 0.5940, 0.2094],
     [0.1707, 0.1793, 0.7335],
     [0.9696, 0.5099, 0.7375]],

    [[0.2460, 0.4327, 0.4727],
     [0.0167, 0.5359, 0.9804],
     [0.1505, 0.5463, 0.8114],
     [0.3933, 0.2433, 0.7053],
     [0.7107, 0.7535, 0.2816],
     [0.2557, 0.2144, 0.8317],
     [0.1811, 0.7788, 0.6649],
     [0.8719, 0.7624, 0.5216]],

    [[0.2796, 0.4173, 0.8076],
     [0.0568, 0.8568, 0.9119],
     [0.3782, 0.8015, 0.9835],
     [0.2102, 0.8165, 0.1938],
     [0.3711, 0.6835, 0.8021],
     [0.3769, 0.3830, 0.7140],
     [0.6548, 0.8694, 0.1988],
     [0.4240, 0.2624, 0.7959]],

    [[0.4523, 0.5090, 0.6053],
     [0.9939, 0.9757, 0.0732],
     [0.8218, 0.7061, 0.1807],
     [0.9107, 0.9533, 0.2438],
     [0.0828, 0.2861, 0.0474],
     [0.2416, 0.6735, 0.4013],
     [0.4017, 0.6121, 0.5824],
     [0.4356, 0.6441, 0.5859]],

    [[0.3216, 0.6844, 0.1701],
     [0.0458, 0.9887, 0.4725],
     [0.8830, 0.9126, 0.4821],
     [0.9266, 0.7827, 0.8101],
     [0.1160, 0.3928, 0.4773],
     [0.5940, 0.2094, 0.3109],
     [0.1793, 0.7335, 0.2065],
     [0.5099, 0.7375, 0.7601]],

    [[0.4327, 0.4727, 0.5649],
     [0.5359, 0.9804, 0.8778],
     [0.5463, 0.8114, 0.3270],
     [0.2433, 0.7053, 0.5222],
     [0.7535, 0.2816, 0.6515],
     [0.2144, 0.8317, 0.5448],
     [0.7788, 0.6649, 0.1848],
     [0.7624, 0.5216, 0.3494]],

    [[0.4173, 0.8076, 0.5709],
     [0.8568, 0.9119, 0.4030],
     [0.8015, 0.9835, 0.6043],
     [0.8165, 0.1938, 0.4948],
     [0.6835, 0.8021, 0.1897],
     [0.3830, 0.7140, 0.2309],
     [0.8694, 0.1988, 0.5470],
     [0.2624, 0.7959, 0.4091]],

    [[0.5090, 0.6053, 0.4071],
     [0.9757, 0.0732, 0.6213],
     [0.7061, 0.1807, 0.2822],
     [0.9533, 0.2438, 0.4031],
     [0.2861, 0.0474, 0.8349],
     [0.6735, 0.4013, 0.8896],
     [0.6121, 0.5824, 0.5803],
     [0.6441, 0.5859, 0.4683]],

    [[0.6844, 0.1701, 0.4609],
     [0.9887, 0.4725, 0.7511],
     [0.9126, 0.4821, 0.7731],
     [0.7827, 0.8101, 0.6258],
     [0.3928, 0.4773, 0.7892],
     [0.2094, 0.3109, 0.9281],
     [0.7335, 0.2065, 0.6204],
     [0.7375, 0.7601, 0.1405]]])

Upvotes: 0

Views: 763

Answers (1)

Ivan

Reputation: 40658

Possible method using torch.gather

You can see this problem as reassigning each element to a new position. This can be done using a tensor containing the indices of the permutation you wish to apply.

If you look at the indices of the last dimension for input x (we will take your example with x.shape = (4, 8, 6)), you have them ordered this way:

tensor([[[0, 1, 2, 3, 4, 5],
         [0, 1, 2, 3, 4, 5],
          ... 4 more
         [0, 1, 2, 3, 4, 5],
         [0, 1, 2, 3, 4, 5]],
        
        ... 2 more

        [[0, 1, 2, 3, 4, 5],
         [0, 1, 2, 3, 4, 5],
          ... 4 more
         [0, 1, 2, 3, 4, 5],
         [0, 1, 2, 3, 4, 5]]])

Now the permutation of indices should look like the following (considering N=4, present_size=1, and past_size=2). Keep in mind I'm only representing two of the four dimensions the final index tensor will have:

tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4],
        [3, 4, 5]])

From there it will be easy to construct the new tensor using torch.gather. The operation will effectively create a tensor out defined in the following way:

out[i][j][k][l] = x[i][j][k][indices[i][j][k][l]]
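
To make the gather semantics concrete, here is a minimal 2-D example with made-up values, where out[i][j] = t[i][idx[i][j]]:

>>> t = torch.tensor([[10, 20, 30],
...                   [40, 50, 60]])
>>> idx = torch.tensor([[2, 0],
...                     [1, 1]])
>>> torch.gather(t, dim=-1, index=idx)
tensor([[30, 10],
        [50, 50]])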

1. Constructing the tensor of indices

In order to construct such a tensor of indices, we will use torch.arange. The following are the base indices:

>>> arr = torch.arange(total_size)[None].repeat(N, 1)
tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2]])

to which we add a displacement of present_size accumulated over the rows, i.e. row i is shifted by i*present_size:

>>> disp = (torch.arange(N) * present_size)[:, None]
tensor([[0],
        [1],
        [2],
        [3]])

The resulting minimal tensor of indices is:

>>> indices = arr + disp
tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4],
        [3, 4, 5]])

2. Applying torch.gather

First, we need to replicate x along a new leading dimension of size N: the number of windows in the resulting tensor.

>>> x_r = x[None].expand(N, *(-1,)*x.ndim)
>>> x.shape, x_r.shape
(torch.Size([4, 8, 6]), torch.Size([4, 4, 8, 6]))

In order to use torch.gather, the input and the tensor of indices need to have the same number of dimensions, with matching sizes on every axis except the one we gather along. To achieve this we can make views of our tensors using Tensor.expand.

So here we will insert two additional dimensions on indices and expand them to match the sizes of x's first and second axes.

>>> i_r = indices[:, None, None, :].expand(-1, x.size(0), x.size(1), -1)
>>> indices.shape, i_r.shape
(torch.Size([4, 3]), torch.Size([4, 4, 8, 3]))

Then apply torch.gather on the last axis:

>>> torch.gather(x_r, dim=-1, index=i_r)
tensor([[[[0.5667, 0.5300, 0.2460],
          [0.0360, 0.6687, 0.0167],
          [0.3703, 0.4884, 0.1505],
          [0.2932, 0.4928, 0.3933],
          [0.6667, 0.2014, 0.7107],
          [0.5285, 0.4150, 0.2557],
          [0.7971, 0.6609, 0.1811],
          [0.6902, 0.3999, 0.8719]],

         ...
           
        [[0.6844, 0.1701, 0.4609],
         [0.9887, 0.4725, 0.7511],
         [0.9126, 0.4821, 0.7731],
         [0.7827, 0.8101, 0.6258],
         [0.3928, 0.4773, 0.7892],
         [0.2094, 0.3109, 0.9281],
         [0.7335, 0.2065, 0.6204],
         [0.7375, 0.7601, 0.1405]]]])
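
Note the result is 4-dimensional, with shape (N, batch_size, d, total_size). Flattening the first two axes gives the batch_size*N x d x past_size+present_size layout asked for, in the same window-major order as the seq function above. A sketch putting the whole thing together (the name seq_gather is mine):

def seq_gather(x, present_size, past_size, N):
    total_size = present_size + past_size
    # (N, total_size): row i holds [i*present_size, ..., i*present_size + total_size - 1]
    indices = torch.arange(total_size) + (torch.arange(N) * present_size)[:, None]
    # replicate x N times and broadcast the indices to match
    x_r = x[None].expand(N, *(-1,) * x.ndim)
    i_r = indices[:, None, None, :].expand(-1, x.size(0), x.size(1), -1)
    # (N, batch_size, d, total_size) -> (N*batch_size, d, total_size)
    return torch.gather(x_r, dim=-1, index=i_r).flatten(0, 1)

Since gather is differentiable, this version can also be backpropagated through.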

If you have any questions, please don't hesitate to ask!

Upvotes: 1
