Reputation: 906
I have a list (my_list) of tensors, all with the same shape. I want to concatenate them along the channel axis. Here is the code I use to inspect them:
for i in my_list:
    print(i.shape)  # [1, 3, 128, 128] => [batch, channel, width, height]
I would like to get a new tensor, new_tensor, with shape [1, 3*len(my_list), width, height].
I don't want to use torch.stack(), because it adds a new dimension. How can I use torch.cat() to do this?
Reputation: 40768
Given an example list containing 10 tensors of shape (1, 3, 128, 128):
>>> import torch
>>> my_list = [torch.rand(1, 3, 128, 128) for _ in range(10)]
You want to concatenate your tensors along dim=1, because the second dimension is the channel axis, and that is the axis along which the tensors should be joined. You can do so using torch.cat:
>>> res = torch.cat(my_list, dim=1)
>>> res.shape
torch.Size([1, 30, 128, 128])
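A minimal sketch, assuming the same 10 random tensors as above, that checks where each input ends up in the result: tensor i occupies channels 3*i to 3*i+2, and torch.split with a chunk size of 3 along dim=1 undoes the concatenation.

```python
import torch

# Ten example tensors of shape (batch=1, channels=3, 128, 128), as in the question
my_list = [torch.rand(1, 3, 128, 128) for _ in range(10)]

# Concatenate along dim=1, the channel axis
res = torch.cat(my_list, dim=1)
print(res.shape)  # torch.Size([1, 30, 128, 128])

# Tensor i occupies channels 3*i .. 3*i+2 of the result, in list order
for i, t in enumerate(my_list):
    assert torch.equal(res[:, 3 * i:3 * (i + 1)], t)

# Splitting into chunks of 3 channels along dim=1 recovers the original tensors
parts = torch.split(res, 3, dim=1)
assert len(parts) == len(my_list)
assert all(torch.equal(p, t) for p, t in zip(parts, my_list))
```

Because the channel blocks are laid out in list order, slicing or splitting with a step of 3 is enough to get any original tensor back.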
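Concatenating along dim=1 is equivalent to stacking the tensors in my_list horizontally with torch.hstack, which joins tensors with two or more dimensions along dim 1. Note that torch.vstack would join them along dim 0 instead, giving shape (10, 3, 128, 128):
>>> res = torch.hstack(my_list)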
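For comparison, here is a sketch (assuming a PyTorch version that provides torch.vstack and torch.hstack) of the shapes the related functions produce on the same list of (1, 3, 128, 128) tensors. Only torch.hstack matches torch.cat(my_list, dim=1) here, while torch.stack adds the extra dimension the question wants to avoid.

```python
import torch

my_list = [torch.rand(1, 3, 128, 128) for _ in range(10)]

# torch.stack inserts a new dimension (dim=0 by default)
stacked = torch.stack(my_list)
# torch.vstack concatenates along dim 0 for tensors with 2 or more dims
v = torch.vstack(my_list)
# torch.hstack concatenates along dim 1 for tensors with 2 or more dims
h = torch.hstack(my_list)
# torch.cat along dim=1 is what the question asks for
c = torch.cat(my_list, dim=1)

print(stacked.shape)      # torch.Size([10, 1, 3, 128, 128])
print(v.shape)            # torch.Size([10, 3, 128, 128])
print(h.shape)            # torch.Size([1, 30, 128, 128])
print(torch.equal(h, c))  # True
```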