Reputation: 263
I have the following tensor with dimensions (2, 3, 2, 2), where the dimensions represent (batch_size, channels, height, width):
tensor([[[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.]]],
[[[13., 14.],
[15., 16.]],
[[17., 18.],
[19., 20.]],
[[21., 22.],
[23., 24.]]]])
I would like to convert this into the following tensor with dimensions (8, 3):
tensor([[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11],
[ 4, 8, 12],
[13, 17, 21],
[14, 18, 22],
[15, 19, 23],
[16, 20, 24]])
Essentially, I would like each row to hold the channel values at one (batch, height, width) position, i.e. a 1D vector across the channels for every element of the matrices. I have tried many operations such as flatten and reshape, but I cannot figure out how to achieve this reshaping.
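To pin down the target layout: each output row corresponds to one (batch, height, width) position, visited in row-major order, and holds that position's values across the three channels. Below is a slow, loop-based reference sketch of that mapping. The helper name `reference_flatten` is hypothetical, and the input is rebuilt with `torch.arange` purely for convenience (it produces the same values as the tensor above).

```python
import torch

def reference_flatten(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: one row per (batch, height, width) position,
    # visited in row-major order, with one column per channel.
    b, c, h, w = x.shape
    rows = []
    for bi in range(b):
        for hi in range(h):
            for wi in range(w):
                rows.append(x[bi, :, hi, wi])  # channel values at this position, shape (c,)
    return torch.stack(rows)  # shape (b * h * w, c)

# Values 1..24 laid out as (batch_size, channels, height, width) = (2, 3, 2, 2).
x = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 2, 2)
print(reference_flatten(x))
```

This is only useful for checking a vectorized solution on small inputs; it copies every row in Python.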
Upvotes: 0
Views: 872
Reputation: 40648
You could achieve this with an axes permutation followed by flattening the resulting tensor:
1. swap axis=1 (of size 3) with the last one, axis=-1, using torch.transpose (torch.swapaxes is an alias);
2. flatten from axis=0 to axis=-2 using torch.flatten.
This looks like:
>>> x.transpose(1, -1).flatten(0, -2)
tensor([[ 1., 5., 9.],
[ 3., 7., 11.],
[ 2., 6., 10.],
[ 4., 8., 12.],
[13., 17., 21.],
[15., 19., 23.],
[14., 18., 22.],
[16., 20., 24.]])
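Note that swapping axis 1 with axis -1 also swaps height and width, which is why the rows above come out in the order 1, 3, 2, 4 rather than the 1, 2, 3, 4 asked for in the question. Moving only the channel axis to the end with torch.permute keeps height before width. A minimal sketch, rebuilding the question's tensor with `torch.arange` for convenience:

```python
import torch

# The question's (2, 3, 2, 2) tensor: values 1..24 in (batch, channels, height, width) order.
x = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 2, 2)

# Move channels last without reordering height/width: (batch, height, width, channels).
y = x.permute(0, 2, 3, 1)

# Merge batch, height and width into one axis, leaving channels as columns.
result = y.flatten(0, -2)  # shape (8, 3)
print(result)
```

`y.reshape(-1, x.shape[1])` gives the same result as `y.flatten(0, -2)`; both copy the data here because the permuted tensor is not contiguous.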
Upvotes: 0
Reputation: 13601
You can do it this way:
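import torch

x = torch.tensor(
    [
        [[[1, 2], [3, 4]],
         [[5, 6], [7, 8]],
         [[9, 10], [11, 12]]],
        [[[13, 14], [15, 16]],
         [[17, 18], [19, 20]],
         [[21, 22], [23, 24]]],
    ],
    dtype=torch.float32,
)

# (2, 3, 2, 2) -> (3, 2, 2, 2): channels first
# -> (3, 8): one row per channel, batch/height/width flattened in order
# -> (8, 3): transpose so each row holds one position's channel values
result = x.swapaxes(0, 1).reshape(3, -1).T
print(result)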
# > tensor([[ 1., 5., 9.],
# > [ 2., 6., 10.],
# > [ 3., 7., 11.],
# > [ 4., 8., 12.],
# > [13., 17., 21.],
# > [14., 18., 22.],
# > [15., 19., 23.],
# > [16., 20., 24.]])
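One detail about this approach: `.T` returns a transposed view, so `result` is not contiguous in memory. That is fine for most operations, but calling `Tensor.view` on it (e.g. to flatten it further) can raise an error; `.contiguous()` produces a compact copy if that matters. A quick check, again rebuilding the input with `torch.arange` for brevity:

```python
import torch

x = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 2, 2)

# Same steps as above: channels first, collapse the rest, then transpose.
result = x.swapaxes(0, 1).reshape(3, -1).T

print(result.is_contiguous())  # False: .T only swaps the strides
compact = result.contiguous()
print(compact.is_contiguous())  # True
print(torch.equal(result, compact))  # True: same values, new memory layout
```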
Upvotes: 1