Rubertos

Reputation: 129

How do I concatenate tensors along a given axis?


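import torch

# torch.Tensor(2, 2, 2) allocates uninitialized memory, so build a from explicit values
a = torch.tensor([[[1., 2.],
                   [5., 6.]],
                  [[7., 8.],
                   [9., 10.]]])

b = myfunction(a)  # myfunction is what I am trying to write

print(a)
# [[[1, 2],
#   [5, 6]],
#  [[7, 8],
#   [9, 10]]]

print(b)
# desired output:
# [[1, 2, 7, 8],
#  [5, 6, 9, 10]]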

How do I write myfunction to get b from a?

Is there a PyTorch function that transforms a in this way?

Upvotes: 0

Views: 359

Answers (1)

fuglede

Reputation: 18201

You can achieve this by using transpose to swap the first two axes (cf. np.swapaxes), and then reshape to merge the last two axes into one, giving the desired shape (2, 4):

In [12]: a
Out[12]:
tensor([[[  1.,   2.],
         [  5.,   6.]],

        [[  7.,   8.],
         [  9.,  10.]]])

In [13]: a.transpose(0, 1).reshape(2, 4)
Out[13]:
tensor([[  1.,   2.,   7.,   8.],
        [  5.,   6.,   9.,  10.]])
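The same idea generalizes to any tensor of shape (n, r, c). Below is a minimal sketch; the function name merge_first_axis_into_columns is hypothetical, chosen only for illustration. It also shows two related points: because transpose returns a non-contiguous view, view(...) cannot be used in place of reshape(...) here, and the result is identical to an explicit concatenation with torch.cat of the slices a[0], a[1], ... along axis 1, which is what the question's title asks for.

```python
import torch


def merge_first_axis_into_columns(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn shape (n, r, c) into shape (r, n * c).

    Row i of the result is a[0][i], a[1][i], ..., a[n-1][i] side by side.
    """
    n, r, c = a.shape
    # transpose(0, 1) swaps the first two axes, giving shape (r, n, c).
    # reshape then merges the last two axes; it copies the data when the
    # non-contiguous transposed view cannot be reshaped in place.
    return a.transpose(0, 1).reshape(r, n * c)


a = torch.tensor([[[1., 2.],
                   [5., 6.]],
                  [[7., 8.],
                   [9., 10.]]])

b = merge_first_axis_into_columns(a)
print(b.tolist())  # [[1.0, 2.0, 7.0, 8.0], [5.0, 6.0, 9.0, 10.0]]

# view() requires compatible strides, which the transposed tensor lacks.
try:
    a.transpose(0, 1).view(2, 4)
except RuntimeError:
    print("view failed; use reshape or .contiguous().view(...)")

# Equivalent explicit concatenation: unbind(0) splits a into a[0], a[1],
# and torch.cat joins those 2-D slices along axis 1 (the columns).
b_cat = torch.cat(a.unbind(0), dim=1)
print(torch.equal(b, b_cat))  # True
```

Either form works; transpose plus reshape is a single expression for any n, while torch.cat makes the "concatenate along an axis" intent explicit.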

Upvotes: 1
