phil

Reputation: 263

Creating 1D vectors over 3D tensors in PyTorch

I have the following tensor with dimensions (2, 3, 2, 2) where the dimensions represent (batch_size, channels, height, width):

tensor([[[[ 1.,  2.],
          [ 3.,  4.]],

         [[ 5.,  6.],
          [ 7.,  8.]],

         [[ 9., 10.],
          [11., 12.]]],


        [[[13., 14.],
          [15., 16.]],

         [[17., 18.],
          [19., 20.]],

         [[21., 22.],
          [23., 24.]]]])

I would like to convert this into the following tensor with dimensions (8, 3):

tensor([[ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11],
        [ 4,  8, 12],
        [13, 17, 21],
        [14, 18, 22],
        [15, 19, 23],
        [16, 20, 24]])

Essentially, I would like to create a 1D vector across the channels for each element of the 2x2 matrices. I have tried many operations, such as flatten and reshape, but I cannot figure out how to achieve this reshaping.
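A likely reason a plain reshape fails here is memory order: each channel's 2x2 matrix is stored contiguously, so reshaping regroups consecutive values from the same channel. The sketch below rebuilds the question's tensor with `torch.arange` (an assumption for brevity; it produces the same values) to show this:

```python
import torch

# Same values as the question's (batch_size, channels, height, width) = (2, 3, 2, 2) tensor
x = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 2, 2)

# reshape only regroups elements in memory order, and memory holds
# channel 0's matrix (1..4), then channel 1's (5..8), and so on
naive = x.reshape(8, 3)
print(naive[:2])
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])
```

The first row mixes values of a single channel instead of taking one value per channel, so the channel axis has to be moved to the end before flattening.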

Upvotes: 0

Views: 872

Answers (2)

Ivan

Reputation: 40648

You could achieve this with an axes permutation followed by flattening the resulting tensor:

  1. swap axis=1 (of size 3) with the last one, axis=-1, using torch.transpose (torch.swapaxes is an alias),
  2. flatten everything but the last axis, i.e. from axis=0 to axis=-2, using torch.flatten.
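A minimal sketch of these two steps with the intermediate shapes printed, assuming `x` is the question's tensor (rebuilt here with `torch.arange`, which gives the same values):

```python
import torch

x = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 2, 2)  # (B, C, H, W)

# Step 1: swap the channel axis (1) with the last axis (-1): (B, C, H, W) -> (B, W, H, C)
swapped = x.transpose(1, -1)
print(swapped.shape)  # torch.Size([2, 2, 2, 3])

# Step 2: merge every axis except the last one: (B, W, H, C) -> (B*W*H, C)
out = swapped.flatten(0, -2)
print(out.shape)  # torch.Size([8, 3])
```

After step 1 the axes are ordered (B, W, H, C), i.e. width now comes before height, which determines the order of the rows after flattening.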

This looks like:

>>> x.transpose(1, -1).flatten(0, -2)
tensor([[ 1.,  5.,  9.],
        [ 3.,  7., 11.],
        [ 2.,  6., 10.],
        [ 4.,  8., 12.],
        [13., 17., 21.],
        [15., 19., 23.],
        [14., 18., 22.],
        [16., 20., 24.]])
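The output above lists the rows in the order 1, 3, 2, 4 rather than the requested 1, 2, 3, 4, because `transpose(1, -1)` also puts width ahead of height. One way to keep the row-major spatial order (a variant sketched here, not part of the original answer) is to move the channel axis to the end with `permute(0, 2, 3, 1)` before flattening:

```python
import torch

x = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 2, 2)  # (B, C, H, W)

# Move channels last while keeping height before width: (B, C, H, W) -> (B, H, W, C)
channels_last = x.permute(0, 2, 3, 1)

# Merge batch and spatial axes; rows follow row-major order within each 2x2 matrix
result = channels_last.flatten(0, -2)
print(result[:4])
# tensor([[ 1.,  5.,  9.],
#         [ 2.,  6., 10.],
#         [ 3.,  7., 11.],
#         [ 4.,  8., 12.]])
```

For a general (B, C, H, W) input, `x.permute(0, 2, 3, 1).reshape(-1, x.shape[1])` gives the same (B*H*W, C) result.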

Upvotes: 0

Berriel

Reputation: 13601

You can do it this way:

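import torch

# (batch_size, channels, height, width) = (2, 3, 2, 2)
x = torch.tensor(
    [
        [[[1, 2], [3, 4]],
         [[5, 6], [7, 8]],
         [[9, 10], [11, 12]]],
        [[[13, 14], [15, 16]],
         [[17, 18], [19, 20]],
         [[21, 22], [23, 24]]],
    ],
    dtype=torch.float32,
)

# Channels first -> one row of 8 values per channel -> transpose to (8, 3)
result = x.swapaxes(0, 1).reshape(3, -1).T

print(result)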
# > tensor([[ 1.,  5.,  9.],
# >         [ 2.,  6., 10.],
# >         [ 3.,  7., 11.],
# >         [ 4.,  8., 12.],
# >         [13., 17., 21.],
# >         [14., 18., 22.],
# >         [15., 19., 23.],
# >         [16., 20., 24.]])
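Breaking the one-liner into its steps (a sketch using the same values, rebuilt here with `torch.arange`) shows what each call contributes. Note that `.T` returns a non-contiguous view, so a `.contiguous()` call may be needed before operations such as `.view` that require contiguous memory:

```python
import torch

x = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 2, 2)  # (B, C, H, W)

# Put channels first: (B, C, H, W) -> (C, B, H, W)
per_channel = x.swapaxes(0, 1)

# One row per channel holding all B*H*W values in row-major order: (C, B*H*W)
rows = per_channel.reshape(3, -1)
print(rows[0])  # tensor([ 1.,  2.,  3.,  4., 13., 14., 15., 16.])

# Transposing gives (B*H*W, C); .T only swaps strides, so the result is not contiguous
result = rows.T
print(result.is_contiguous())  # False

# Copy into contiguous memory if a later operation requires it
result = result.contiguous()
```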

Upvotes: 1
