Reputation: 52500
I'm trying to combine n matrices in a 3-dimensional PyTorch tensor of shape (n, i, j) into a single 2-dimensional matrix of shape (i, j*n). Here's a simple example where n=2, i=2, j=2:
import torch

m = torch.tensor([[[2, 3],
                   [5, 7]],
                  [[11, 13],
                   [17, 19]]])
m.reshape(2, 4)
I was hoping this would produce:
tensor([[ 2,  3, 11, 13],
        [ 5,  7, 17, 19]])
But instead it produced:
tensor([[ 2,  3,  5,  7],
        [11, 13, 17, 19]])
How do I do this? I tried torch.cat and torch.stack, but they take a sequence (such as a tuple) of tensors. I could build that tuple, but it seems inefficient. Is there a better way?
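For comparison, here is a minimal sketch of the torch.cat route. Building the tuple turns out to be cheap: Tensor.unbind(0) returns a tuple of n views of shape (i, j) that share storage with m, so only the final concatenation copies data.

```python
import torch

m = torch.tensor([[[2, 3],
                   [5, 7]],
                  [[11, 13],
                   [17, 19]]])

# unbind(0) splits m along dim 0 into a tuple of n views of shape (i, j); no copy.
parts = m.unbind(0)

# Concatenating along dim 1 places the n matrices side by side -> shape (i, j*n).
result = torch.cat(parts, dim=1)
print(result)
```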
Upvotes: 0
Views: 550
Reputation: 1660
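To merge the n and j dimensions with reshape, they need to be adjacent in the shape, with n directly before j: reshape reads elements in row-major order, so it can only combine neighbouring dimensions in the way you expect. Moving n next to j with swapaxes fixes this, turning the shape (n, i, j) into (i, n, j):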
import torch

m = torch.tensor([[[2, 3],
                   [5, 7]],
                  [[11, 13],
                   [17, 19]]])
m = m.swapaxes(0, 1)  # shape (n, i, j) -> (i, n, j)
m.reshape(2, 4)       # merge the now-adjacent n and j dimensions
tensor([[ 2,  3, 11, 13],
        [ 5,  7, 17, 19]])
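The same idea generalizes to any (n, i, j). Below is a sketch with a hypothetical helper, side_by_side, that moves n next to j with permute(1, 0, 2) (equivalent to swapaxes(0, 1) for a 3-D tensor) and then reshapes. Because the permuted tensor is in general not contiguous, reshape usually has to return a copy rather than a view. The sketch also prints the flattened input, which shows why the plain reshape(2, 4) in the question just cut 2, 3, 5, 7, 11, 13, 17, 19 into rows of four, and checks the result against the torch.cat approach on a random tensor.

```python
import torch

def side_by_side(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: lay the n matrices of an (n, i, j) tensor side by side as (i, n*j)."""
    n, i, j = t.shape
    # (n, i, j) -> (i, n, j): n now sits directly before j,
    # so a row-major reshape merges (n, j) into one axis of length n*j.
    return t.permute(1, 0, 2).reshape(i, n * j)

m = torch.tensor([[[2, 3],
                   [5, 7]],
                  [[11, 13],
                   [17, 19]]])
print(m.flatten())      # row-major order that a plain reshape follows
print(side_by_side(m))

# Larger random case: compare with concatenating the n slices along dim 1.
x = torch.randn(5, 3, 4)
out = side_by_side(x)
print(out.shape)                                         # torch.Size([3, 20])
print(torch.equal(out, torch.cat(x.unbind(0), dim=1)))   # True
```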
Upvotes: 1