tom c

Reputation: 140

Understanding the order when reshaping a tensor

For a tensor:

x = torch.tensor([
    [
        [[0.4495, 0.2356],
         [0.4069, 0.2361],
         [0.4224, 0.2362]],

        [[0.4357, 0.6762],
         [0.4370, 0.6779],
         [0.4406, 0.6663]]
    ],
    [
        [[0.5796, 0.4047],
         [0.5655, 0.4080],
         [0.5431, 0.4035]],

        [[0.5338, 0.6255],
         [0.5335, 0.6266],
         [0.5204, 0.6396]]
    ]
])

First, I would like to split it into 2 (x.shape[0]) tensors and then concatenate them. I don't actually have to split it, as long as I get the correct output, but it makes a lot more sense to me visually to split it and then concatenate the pieces back together.

For example:

# the shape of the splits is always the same
split1 = torch.tensor([
    [[0.4495, 0.2356],
     [0.4069, 0.2361],
     [0.4224, 0.2362]],

    [[0.4357, 0.6762],
     [0.4370, 0.6779],
     [0.4406, 0.6663]]
])
split2 = torch.tensor([
    [[0.5796, 0.4047],
     [0.5655, 0.4080],
     [0.5431, 0.4035]],

    [[0.5338, 0.6255],
     [0.5335, 0.6266],
     [0.5204, 0.6396]]
])

split1 = torch.cat((split1[0], split1[1]), dim=1)
split2 = torch.cat((split2[0], split2[1]), dim=1)
what_i_want = torch.cat((split1, split2), dim=0).reshape(x.shape[0], split1.shape[0], split1.shape[1])

(image: the desired 2×3×4 result tensor)

For the above result, I thought directly calling x.reshape([2, 3, 4]) would work; it produced the correct shape but the wrong values.

In general, I am:

  1. not sure how to split the tensor into x.shape[0] tensors.
  2. confused about how reshape works. Most of the time I am able to get the shape right, but the order of the numbers is always incorrect.
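For point 1, torch.unbind (or plain indexing) gives one tensor per slice along a dimension. A minimal sketch, using a small arange tensor of the same shape rather than the values above:

```python
import torch

# Small stand-in for the question's (2, 2, 3, 2) tensor.
x = torch.arange(24.).reshape(2, 2, 3, 2)

# torch.unbind removes the given dimension and returns a tuple with
# one tensor per slice, i.e. x.shape[0] tensors of shape (2, 3, 2) here.
splits = torch.unbind(x, dim=0)
assert len(splits) == x.shape[0]
assert splits[0].shape == (2, 3, 2)

# Plain indexing yields the same slices: splits[i] equals x[i].
assert torch.equal(splits[0], x[0])
```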

Thank you

Upvotes: 2

Views: 9671

Answers (3)

mQuan9909

Reputation: 1

As I understand it, you have a tensor x of shape (B, C, H, W) and you want to convert it into (B, H, C * W). To achieve this, you need the following two steps:

  1. Rearrange x's dimensions to (B, H, C, W) to get a new tensor y.
  2. Reshape y into (B, H, C * W) to get the final result.

The reason for rearranging x into (B, H, C, W), and especially not (B, H, W, C), is that:

  • You want each row of the result to be the concatenation of rows of the 3x2 submatrices of x that share the same row index.
  • PyTorch's reshape operates in a row-major fashion.

Therefore, the rows need to be put on top of each other for reshape to return the desired order.

With the above reasoning, the code for getting what_i_want is

what_i_want = x.permute(0, 2, 1, 3).reshape(2, 3, 4)
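As a quick sanity check (a sketch using a small arange tensor of the same shape, not the values from the question), the permute-then-reshape result matches the question's split-and-concatenate construction:

```python
import torch

# Stand-in tensor with the question's shape (B, C, H, W) = (2, 2, 3, 2).
x = torch.arange(24.).reshape(2, 2, 3, 2)

# Split-and-concatenate construction from the question.
split1 = torch.cat((x[0, 0], x[0, 1]), dim=1)   # shape (3, 4)
split2 = torch.cat((x[1, 0], x[1, 1]), dim=1)   # shape (3, 4)
want = torch.stack((split1, split2))            # shape (2, 3, 4)

# Permute-then-reshape from this answer.
got = x.permute(0, 2, 1, 3).reshape(2, 3, 4)

assert torch.equal(want, got)
```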

Upvotes: 0

hpaulj

Reputation: 231325

Your example, using numpy methods (I don't have pytorch installed):

In [559]: x = np.array([
     ...:     [
     ...:         [[0.4495, 0.2356],
     ...:           [0.4069, 0.2361],
     ...:           [0.4224, 0.2362]],
     ...: 
     ...:          [[0.4357, 0.6762],
     ...:           [0.4370, 0.6779],
     ...:           [0.4406, 0.6663]]
     ...:     ],
     ...:     [
     ...:         [[0.5796, 0.4047],
     ...:           [0.5655, 0.4080],
     ...:           [0.5431, 0.4035]],
     ...: 
     ...:          [[0.5338, 0.6255],
     ...:           [0.5335, 0.6266],
     ...:           [0.5204, 0.6396]]
     ...:     ]
     ...: ])
In [560]: x.shape
Out[560]: (2, 2, 3, 2)

In [562]: s1=np.concatenate((x[0,0],x[0,1]), axis=1)
In [563]: s2=np.concatenate((x[1,0],x[1,1]), axis=1)
In [564]: s1.shape
Out[564]: (3, 4)

In [565]: new =np.concatenate((s1,s2), axis=0)
In [566]: new.shape
Out[566]: (6, 4)
In [567]: new.reshape(2,3,4)
Out[567]: 
array([[[0.4495, 0.2356, 0.4357, 0.6762],
        [0.4069, 0.2361, 0.437 , 0.6779],
        [0.4224, 0.2362, 0.4406, 0.6663]],

       [[0.5796, 0.4047, 0.5338, 0.6255],
        [0.5655, 0.408 , 0.5335, 0.6266],
        [0.5431, 0.4035, 0.5204, 0.6396]]])

numpy has a stack function that joins arrays on a new axis, so we can skip the last concatenate and reshape:

np.stack((s1,s2))    # or
np.array((s1,s2))

The direct way to get there is to swap the middle 2 dimensions:

In [569]: x.transpose(0,2,1,3).shape
Out[569]: (2, 3, 2, 2)

In [571]: x.transpose(0,2,1,3).reshape(2,3,4)
Out[571]: 
array([[[0.4495, 0.2356, 0.4357, 0.6762],
        [0.4069, 0.2361, 0.437 , 0.6779],
        [0.4224, 0.2362, 0.4406, 0.6663]],

       [[0.5796, 0.4047, 0.5338, 0.6255],
        [0.5655, 0.408 , 0.5335, 0.6266],
        [0.5431, 0.4035, 0.5204, 0.6396]]])

reshape can be used to combine 'adjacent' dimensions, but it doesn't reorder the underlying data. That is, x.ravel() is unchanged by a reshape. While reshaping (2,2,3,2) to (2,3,4) is allowed, the apparent order of values is probably not what you want. That might be easier to see with a small reshape:

In [572]: np.arange(6).reshape(2,3)
Out[572]: 
array([[0, 1, 2],
       [3, 4, 5]])
In [573]: _.reshape(3,2)
Out[573]: 
array([[0, 1],
       [2, 3],
       [4, 5]])

compare that with a transpose:

In [574]: np.arange(6).reshape(2,3).transpose(1,0)
Out[574]: 
array([[0, 3],
       [1, 4],
       [2, 5]])

The transpose/swap that I did in [569] may be hard to understand. There are enough different ways of reordering dimensions that it's hard to generalize.

Upvotes: 1

Shai

Reputation: 114786

The order of the elements in memory in pytorch, numpy, C++, etc. is row-major:

[ first,  second
  third,  fourth ]

While in MATLAB, Fortran, etc. the order is column-major:

[ first,  third
  second, fourth ]

For higher-dimensional tensors, this means elements are ordered with the last dimension varying fastest and the first dimension varying slowest.

You can easily visualize it using torch.arange followed by .view:

a = torch.arange(24)
a.view(2,3,4)

Results with

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

As you can see, the elements are ordered along the last dimension first (within each row), then down the rows, and finally along the first dimension.

When you reshape a tensor, you do not change the underlying order of the elements, only the shape of the tensor. However, if you permute a tensor, you change the order in which its elements are traversed: the underlying memory is untouched, but the resulting view reads it out in a different order.

Look at the difference between a.view(3,2,4) and a.permute(1,0,2) - the shapes of the two resulting tensors are the same, but not the ordering of elements:

In []: a.view(3,2,4)
Out[]:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23]]])

In []: a.permute(1,0,2)
Out[]:
tensor([[[ 0,  1,  2,  3],
         [12, 13, 14, 15]],

        [[ 4,  5,  6,  7],
         [16, 17, 18, 19]],

        [[ 8,  9, 10, 11],
         [20, 21, 22, 23]]])
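One related caveat (my addition, not part of the original answer): after a permute the tensor is no longer contiguous in memory, so view raises an error; reshape, or .contiguous().view(...), copies the data when needed:

```python
import torch

a = torch.arange(24).view(2, 3, 4)
p = a.permute(1, 0, 2)          # shape (3, 2, 4); a non-contiguous view

assert not p.is_contiguous()

# p.view(3, 8) would raise a RuntimeError here, because view requires
# contiguous memory; reshape makes a copy when necessary.
flat = p.reshape(3, 8)
assert torch.equal(flat, p.contiguous().view(3, 8))
```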

Upvotes: 8
