Darren Cook

Reputation: 28913

Best way to cut a pytorch tensor into overlapping chunks?

If for instance I have:

eg6 = torch.tensor([
    [ 1.,  7., 13., 19.],
    [ 2.,  8., 14., 20.],
    [ 3.,  9., 15., 21.],
    [ 4., 10., 16., 22.],
    [ 5., 11., 17., 23.],
    [ 6., 12., 18., 24.]])
batch1 = eg6
batch2 = -eg6
x = torch.cat((batch1,batch2)).view(2,6,4)

And then I want to slice it up into overlapping chunks, like a sliding window function, and have the chunks be batch-processable. For example, looking just at the row dimension, I want rows 1,2,3, then 3,4,5, then 5,6 (or 5,6,0 with padding).

It seems unfold() kind of does what I want. It transposes the last two dimensions for some reason, but that is easy enough to repair. Changing it from shape [2,3,3,4] to [6,3,4] requires a memory copy, but I believe that is unavoidable?

SZ=3
x2 = x.unfold(1,SZ,2).transpose(2,3).reshape(-1,SZ,4)

This works perfectly when x is of shape [2,7,4]. But with only 6 rows, it throws away the final row.
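Concretely, checking the shapes (x7 below is just x with one extra zero row appended, purely for comparison):

# with x of shape [2, 6, 4] as built above
x7 = torch.cat((x, torch.zeros(2, 1, 4)), dim=1)              # shape [2, 7, 4]

x.unfold(1, SZ, 2).transpose(2, 3).reshape(-1, SZ, 4).shape   # torch.Size([4, 3, 4]) -- the 6th row is dropped
x7.unfold(1, SZ, 2).transpose(2, 3).reshape(-1, SZ, 4).shape  # torch.Size([6, 3, 4]) -- all rows used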

Is there a version of unfold() that can be told to use all data, ideally taking a pad character?

Or do I need to pad x before calling unfold()? What is the best way to do that? I'm wondering if "pad" is the wrong word, as I'm only finding functions that want to put padding characters at both ends, with convolutions in mind.
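Perhaps torch.nn.functional.pad is what I'm after, since it takes per-side amounts and so can pad at one end only? An untested sketch:

import torch.nn.functional as F

# the pad tuple runs (last-dim left, last-dim right, 2nd-to-last left, 2nd-to-last right),
# so (0, 0, 0, 1) appends a single zero row to the end of dim 1 only
x_padded = F.pad(x, (0, 0, 0, 1))   # shape [2, 7, 4]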


Aside: Looking at the source of unfold, it seems the strange transpose is there deliberately and explicitly?! For that reason, and the undesired chop behaviour, it made me think the correct answer to my question might be to write a new low-level function. But that is too much effort for me, at least for today... (I think a second function for the backwards pass would also need to be written.)

Upvotes: 0

Views: 1375

Answers (1)

Ivan

Reputation: 40648

The operation you are performing here is similar to what a 1D convolution does with kernel_size=SZ and stride=2. As you noticed, if you don't provide sufficient padding (you are correct on the wording), the last element won't be used.
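For instance, a Conv1d over the row dimension shows the same truncation (a minimal sketch, treating the 4 columns as channels):

>>> conv = torch.nn.Conv1d(in_channels=4, out_channels=1, kernel_size=3, stride=2)
>>> conv(torch.randn(2, 4, 6)).shape   # 6 rows -> only 2 windows, the last row is never used
torch.Size([2, 1, 2])
>>> conv(torch.randn(2, 4, 7)).shape   # 7 rows -> 3 windows, every row is used
torch.Size([2, 1, 3])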

A general approach (for any input length x.size(1) and any window size SZ, assuming adjacent windows overlap by one row, i.e. stride = SZ - 1, which is 2 here) is to figure out whether padding is necessary, and if so how much is needed. A small helper wrapping these three steps is sketched after the list.

  • The number of full windows is out = floor((x.size(1) - SZ)/(SZ - 1) + 1), i.e. the usual convolution output-size formula with stride SZ - 1 (= 2 here).

  • The number of unused trailing rows is x.size(1) - out*(SZ - 1) - 1.

  • If that number is non-zero, you need to add a padding of (out + 1)*(SZ - 1) + 1 - x.size(1) rows so that one more window fits.
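Here is a small helper wrapping those three steps (needed_padding is just an illustrative name, not a torch builtin):

from math import floor

def needed_padding(length, SZ):
    # rows to append so that a window of size SZ, sliding with stride SZ - 1,
    # covers every row of a dimension of the given length
    stride = SZ - 1
    out = floor((length - SZ) / stride + 1)    # windows that already fit
    unused = length - out * stride - 1         # trailing rows left uncovered
    return 0 if unused == 0 else (out + 1) * stride + 1 - length

needed_padding(5, 3)  # 0 -> no padding needed
needed_padding(6, 3)  # 1 -> pad with one row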


This example won't need padding:

>>> x = torch.stack((torch.tensor([
            [ 1.,  7., 13., 19.],
            [ 2.,  8., 14., 20.],
            [ 3.,  9., 15., 21.],
            [ 4., 10., 16., 22.],
            [ 5., 11., 17., 23.]]),)*2)

>>> x.shape
torch.Size([2, 5, 4])

>>> from math import floor
>>> out = floor((x.size(1) - SZ)/2 + 1)
>>> out
2

>>> unused = x.size(1) - out*(SZ-1) - 1
>>> unused
0

While this one will:

>>> x = torch.stack((torch.tensor([
          [ 1.,  7., 13., 19.],
          [ 2.,  8., 14., 20.],
          [ 3.,  9., 15., 21.],
          [ 4., 10., 16., 22.],
          [ 5., 11., 17., 23.],
          [ 6., 12., 18., 24.]]),)*2)

>>> x.shape
torch.Size([2, 6, 4])

>>> out = floor((x.size(1) - SZ)/2 + 1)
>>> out
2

>>> unused = x.size(1) - out*(SZ-1) - 1
>>> unused
1

>>> p = (out+1)*(SZ-1) + 1 - x.size(1)
>>> p
1

Now, to actually add the padding you could just use torch.cat, although I am sure the built-in nn.functional.pad would work as well...

>>> torch.cat((x, torch.zeros(x.size(0), p, x.size(2))), dim=1)
tensor([[[ 1.,  7., 13., 19.],
         [ 2.,  8., 14., 20.],
         [ 3.,  9., 15., 21.],
         [ 4., 10., 16., 22.],
         [ 5., 11., 17., 23.],
         [ 6., 12., 18., 24.],
         [ 0.,  0.,  0.,  0.]],

        [[ 1.,  7., 13., 19.],
         [ 2.,  8., 14., 20.],
         [ 3.,  9., 15., 21.],
         [ 4., 10., 16., 22.],
         [ 5., 11., 17., 23.],
         [ 6., 12., 18., 24.],
         [ 0.,  0.,  0.,  0.]]])
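From there, the unfold/transpose/reshape from your question picks up the padded row as well. Using nn.functional.pad as the equivalent of the torch.cat above (xp is just an illustrative name):

>>> import torch.nn.functional as F
>>> xp = F.pad(x, (0, 0, 0, p))   # same result as the torch.cat above, shape [2, 7, 4]
>>> xp.unfold(1, SZ, 2).transpose(2, 3).reshape(-1, SZ, 4).shape
torch.Size([6, 3, 4])

The last chunk of each batch element is rows 5 and 6 followed by the zero row.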

Upvotes: 1
