Ammar Ul Hassan

How to concatenate a list of tensors on a specific axis?

I have a list (my_list) of tensors, all with the same shape. I want to concatenate them along the channel axis. Here is some code that shows the shapes:

for i in my_list:
    print(i.shape) #[1, 3, 128, 128] => [batch, channel, width, height]

I would like to get a new tensor, new_tensor, of shape [1, 3*len(my_list), width, height].

I don't want to use torch.stack(), because it adds a new dimension, and I can't figure out how to use torch.cat() to do this.

Answers (1)

Ivan

Given an example list containing 10 tensors of shape (1, 3, 128, 128):

>>> import torch
>>> my_list = [torch.rand(1, 3, 128, 128) for _ in range(10)]

You want to concatenate your tensors on axis=1, because the 2nd dimension (the channel axis) is the one along which the tensors should be joined. You can do this with torch.cat:

>>> res = torch.cat(my_list, dim=1)
>>> res.shape
torch.Size([1, 30, 128, 128])
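To check which channels of the result come from which input, here is a minimal self-contained sketch (assuming PyTorch is installed; the name c is just an illustrative variable for the per-tensor channel count). torch.cat keeps the inputs in list order, so channels 3*i to 3*i+2 of the result are exactly my_list[i]:

```python
import torch

# Example data: 10 tensors of shape (batch=1, channels=3, 128, 128)
my_list = [torch.rand(1, 3, 128, 128) for _ in range(10)]

# Concatenate along the channel axis (dim=1); no new dimension is created
res = torch.cat(my_list, dim=1)
print(res.shape)  # torch.Size([1, 30, 128, 128])

# c = number of channels per input tensor (3 here)
c = my_list[0].shape[1]
for i, t in enumerate(my_list):
    # The i-th block of c channels in the result is the i-th input tensor
    assert torch.equal(res[:, c * i : c * (i + 1)], t)
```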

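This is equivalent to stacking the tensors in my_list horizontally, i.e. using torch.hstack, which concatenates along the second axis for tensors with more than one dimension (torch.vstack would instead concatenate along the first axis and give a shape of [10, 3, 128, 128]):

>>> res = torch.hstack(my_list)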
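A quick sketch comparing torch.hstack and torch.vstack on these 4-D tensors (both functions require PyTorch 1.8 or newer). For inputs with at least two dimensions, hstack joins along dim 1 and vstack along dim 0, so only hstack matches torch.cat(my_list, dim=1) here:

```python
import torch

my_list = [torch.rand(1, 3, 128, 128) for _ in range(10)]

# hstack: concatenation along dim 1 for tensors with 2+ dims (the channel axis here)
h = torch.hstack(my_list)
# vstack: concatenation along dim 0 (the batch axis here)
v = torch.vstack(my_list)

print(h.shape)  # torch.Size([1, 30, 128, 128])
print(v.shape)  # torch.Size([10, 3, 128, 128])

# Each is the same as torch.cat on the corresponding dim
assert torch.equal(h, torch.cat(my_list, dim=1))
assert torch.equal(v, torch.cat(my_list, dim=0))
```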
