PyTorch provides APIs for concatenating tensors, such as cat and stack. But does it provide an API to concatenate PyTorch tensors alternately, i.e. interleaving their slices?
For example, suppose input1.shape = C*H*W, a1.shape = H*W, and output.shape = (3C)*H*W.
This can be achieved with a loop, but I am wondering whether any PyTorch API can do it directly.
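For reference, the loop-based version mentioned above might look like the sketch below. It assumes three inputs of identical shape C*H*W and reads "alternately" as the order input1[0], input2[0], input3[0], input1[1], ... along the first dimension; interleave_loop is a hypothetical helper name, not a PyTorch API.

```python
import torch


def interleave_loop(*tensors):
    # Hypothetical helper: assumes every tensor has the same shape (C, H, W)
    # and collects t0[0], t1[0], t2[0], t0[1], t1[1], ... along dim 0.
    slices = []
    for ch in range(tensors[0].shape[0]):
        for t in tensors:
            slices.append(t[ch])
    # Stacking N*C slices of shape (H, W) gives shape (N*C, H, W).
    return torch.stack(slices)


C, H, W = 2, 2, 2
input1 = torch.zeros(C, H, W)
input2 = torch.ones(C, H, W)
input3 = torch.full((C, H, W), 2.0)
output = interleave_loop(input1, input2, input3)
print(output.shape)  # torch.Size([6, 2, 2])
print(output[:, 0, 0])  # first pixel of each slice cycles through 0, 1, 2
```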
I'll demonstrate with a small example:
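import torch

input1 = torch.full((3, 3), 1)
input2 = torch.full((3, 3), 2)
input3 = torch.full((3, 3), 3)
# concatenate along dim 0 -> (9, 3), transpose -> (3, 9), flatten -> 27 elements
out = torch.cat((input1, input2, input3)).T.flatten()
# split into chunks of 3 and stack them as columns -> (3, 9)
torch.stack(torch.split(out, 3), dim=1).reshape(3, -1)
# output
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3, 1, 2, 3]])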
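The same effect can be obtained without split or transpose by stacking along a new dimension and then merging it with a neighbouring one. The sketch below assumes all inputs share the shape C*H*W and that the desired order along the first dimension is input1[0], input2[0], input3[0], input1[1], ...; interleave is a hypothetical helper name, while torch.stack and Tensor.reshape are standard PyTorch calls.

```python
import torch


def interleave(*tensors):
    # Hypothetical helper: assumes every tensor has the same shape (C, H, W).
    # torch.stack(..., dim=1) gives shape (C, N, H, W), where [c, i] is tensors[i][c].
    stacked = torch.stack(tensors, dim=1)
    # Merging dims 0 and 1 yields (N*C, H, W) in the order
    # tensors[0][0], tensors[1][0], ..., tensors[0][1], tensors[1][1], ...
    return stacked.reshape(-1, *tensors[0].shape[1:])


C, H, W = 4, 5, 6
input1 = torch.randn(C, H, W)
input2 = torch.randn(C, H, W)
input3 = torch.randn(C, H, W)
output = interleave(input1, input2, input3)
print(output.shape)  # torch.Size([12, 5, 6])

# For the 2-D example above, stacking on a new last dim interleaves columns,
# giving result[r, 3*k + i] == (x1, x2, x3)[i][r, k].
x1 = torch.full((3, 3), 1)
x2 = torch.full((3, 3), 2)
x3 = torch.full((3, 3), 3)
print(torch.stack((x1, x2, x3), dim=2).reshape(3, -1))
```

Because torch.stack already returns a new contiguous tensor, the following reshape only changes the view, so no Python loop is involved.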