Reshape PyTorch tensor so that matrices are horizontal

I'm trying to combine n matrices in a 3-dimensional PyTorch tensor of shape (n, i, j) into a single 2-dimensional matrix of shape (i, j*n). Here's a simple example where n=2, i=2, j=2:

import torch

m = torch.tensor([[[2, 3],
                   [5, 7]],
                  [[11, 13],
                   [17, 19]]])
m.reshape(2, 4)

I was hoping this would produce:

tensor([[ 2,  3, 11, 13],
        [ 5,  7, 17, 19]])

But instead it produced:

tensor([[ 2,  3,  5,  7],
        [11, 13, 17, 19]])
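A short check of why this happens, assuming an ordinary contiguous tensor like the one above: reshape reads the elements in row-major (C) order, the same order that flatten() exposes, with the last dimension varying fastest. It then refills the new shape row by row, so reshape(2, 4) simply cuts that flat sequence into rows of 4.

```python
import torch

m = torch.tensor([[[2, 3],
                   [5, 7]],
                  [[11, 13],
                   [17, 19]]])

# flatten() shows the row-major order in which reshape reads the elements
flat = m.flatten()
print(flat.tolist())  # [2, 3, 5, 7, 11, 13, 17, 19]

# reshape(2, 4) just splits that same sequence into rows of length 4
print(m.reshape(2, 4).tolist())  # [[2, 3, 5, 7], [11, 13, 17, 19]]
```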

How do I do this? I tried torch.cat and torch.stack, but they take a sequence (such as a tuple) of tensors rather than a single tensor. I could build such a tuple, but that seems inefficient. Is there a better way?


Answers (1)

Alexey Birukov

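To merge the n and j dimensions with reshape, they need to be adjacent in the shape, with n directly before j. You can arrange that with swapaxes, which turns the shape (n, i, j) into (i, n, j):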

import torch

m = torch.tensor([[[2, 3],
                   [5, 7]],
                  [[11, 13],
                   [17, 19]]])
m = m.swapaxes(0, 1)  # (n, i, j) -> (i, n, j)
m.reshape(2, 4)

tensor([[ 2,  3, 11, 13],
        [ 5,  7, 17, 19]])
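The same idea can be wrapped for any (n, i, j) tensor. This is a minimal sketch; the helper name hstack_matrices is hypothetical, not a PyTorch API. Because swapaxes returns a non-contiguous view, reshape generally has to copy the data here, so this is not free either. For comparison, the tuple route from the question can be written as torch.cat(m.unbind(0), dim=1), where unbind returns a tuple of views of the n matrices; both give the same result.

```python
import torch

def hstack_matrices(m: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: place the n matrices of an (n, i, j) tensor side by side as (i, n*j)."""
    n, i, j = m.shape
    # (n, i, j) -> (i, n, j) so that n and j are adjacent, n before j;
    # the swapped tensor is non-contiguous, so reshape usually copies
    return m.swapaxes(0, 1).reshape(i, n * j)

m = torch.tensor([[[2, 3],
                   [5, 7]],
                  [[11, 13],
                   [17, 19]]])

out = hstack_matrices(m)
print(out.tolist())  # [[2, 3, 11, 13], [5, 7, 17, 19]]

# Alternative: concatenate the n matrices (as a tuple of views) along dim 1
alt = torch.cat(m.unbind(0), dim=1)
print(torch.equal(out, alt))  # True
```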
