Arun

Reputation: 2478

Concatenate torch tensors

I have two tensors in PyTorch as:

a.shape, b.shape
# (torch.Size([512, 28, 2]), torch.Size([512, 28, 26]))

My goal is to join/merge/concatenate them together so that I get the shape: (512, 28, 28).

I tried:

torch.stack((a, b), dim = 2).shape
torch.cat((a, b)).shape

But neither of them works.

I am using PyTorch 1.11.0 and Python 3.9.

Help?

Upvotes: 2

Views: 6624

Answers (1)

trsvchn

Reputation: 8981

Set the dim parameter to 2 to concatenate along the last dimension:

import torch

a = torch.randn(512, 28, 2)
b = torch.randn(512, 28, 26)

print(a.size(), b.size())

# dim=2 is the third (last) dimension, the only one where the two shapes differ
c = torch.cat((a, b), dim=2)

print(c.size())

Output:

torch.Size([512, 28, 2]) torch.Size([512, 28, 26])
torch.Size([512, 28, 28])
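
As for why the two attempts in the question fail, here is a minimal sketch using random tensors of the same shapes. The flag names stack_failed and cat_failed are only illustrative. As far as I know, torch.stack requires all inputs to have exactly the same shape, and torch.cat without a dim argument concatenates along dim=0, where every other dimension has to match.

```python
import torch

a = torch.randn(512, 28, 2)
b = torch.randn(512, 28, 26)

# torch.stack adds a new dimension, so every input must have the same shape.
stack_failed = False
try:
    torch.stack((a, b), dim=2)
except RuntimeError:
    stack_failed = True

# torch.cat defaults to dim=0; the sizes along dim 2 (2 vs 26) do not match.
cat_failed = False
try:
    torch.cat((a, b))
except RuntimeError:
    cat_failed = True

# dim=-1 also refers to the last dimension, so it is equivalent to dim=2 here.
c = torch.cat((a, b), dim=-1)

print(stack_failed, cat_failed, c.shape)
```

Using dim=-1 is handy if the number of leading dimensions may change later, since it always refers to the last one.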

Upvotes: 3
