jgauth

Reputation: 315

Obtaining the shape (4, 1, 84, 84) with PyTorch

Suppose I have four PyTorch tensors (tensor1, tensor2, tensor3, tensor4), each of shape (1, 1, 84, 84). The first dimension is the number of images in the tensor, the second is the number of color channels (1, since the images are grayscale), and the last two are the height and width of the image.

I want to stack them so that I get the shape (4, 1, 84, 84).

I tried torch.stack((tensor1, tensor2, tensor3, tensor4), dim=0), but I got a shape (4, 1, 1, 84, 84).

How can I stack those tensors so that the shape will be (4, 1, 84, 84)?

Upvotes: 2

Views: 93

Answers (1)

ccl

Reputation: 2378

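You can use torch.cat. It concatenates along an existing dimension, whereas torch.stack inserts a new one, which is where the extra dimension in your result comes from: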

import torch

a = torch.ones(1, 1, 84, 84)
b = torch.ones(1, 1, 84, 84)
c = torch.cat((a, b), 0)  # c.shape is torch.Size([2, 1, 84, 84])
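
Applied to the four tensors from the question, the same idea might look like the sketch below. The tensors are filled with arbitrary random values as stand-ins for the real images, and the names batch, stacked and batch_from_stack are only illustrative. It also shows that squeezing the extra dimension out of the torch.stack result should give the same batch, in case stack is preferred for some reason.

```python
import torch

# Stand-in tensors; the names mirror the question, the values are arbitrary.
tensor1, tensor2, tensor3, tensor4 = (torch.rand(1, 1, 84, 84) for _ in range(4))
tensors = (tensor1, tensor2, tensor3, tensor4)

# torch.cat joins along an existing dimension, so the four leading 1s add up to 4.
batch = torch.cat(tensors, dim=0)

# torch.stack inserts a new leading dimension, giving the 5-D shape from the question.
stacked = torch.stack(tensors, dim=0)

# Dropping the singleton dimension at index 1 recovers the 4-D batch.
batch_from_stack = stacked.squeeze(1)

print(batch.shape)
print(stacked.shape)
print(torch.equal(batch, batch_from_stack))
```

torch.cat is the more direct choice here because each tensor already carries a batch dimension of size 1.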

Upvotes: 5
