Reputation: 1273
I think neither torch.cat nor torch.stack can fully satisfy my requirement.
Initially, I define an empty tensor. Then I want to append a 1D tensor to it multiple times.
x = torch.tensor([]).type(torch.DoubleTensor)
y = torch.tensor([ 0.3981, 0.6952, -1.2320]).type(torch.DoubleTensor)
x = torch.stack([x,y])
This will throw an error:
RuntimeError: stack expects each tensor to be equal size, but got [0] at entry 0 and [3] at entry 1
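For reference, torch.stack creates a new leading dimension, so every input must have exactly the same shape. A minimal sketch, independent of the snippet above:
import torch

a = torch.tensor([1.0, 2.0, 3.0])   # shape (3,)
b = torch.tensor([4.0, 5.0, 6.0])   # shape (3,)
print(torch.stack([a, b]).shape)    # torch.Size([2, 3]): a new dim 0 was added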
So I have to initialise x as torch.tensor([0,0,0]) (but can this be avoided?):
x = torch.tensor([0,0,0]).type(torch.DoubleTensor)
y = torch.tensor([ 0.3981, 0.6952, -1.2320]).type(torch.DoubleTensor)
x = torch.stack([x,y]) # this is okay
x = torch.stack([x,y]) # <--- this raises an error again
But when I run x = torch.stack([x,y]) a second time, I get this error, because x is now a 2D tensor of shape (2, 3) while y is still a 1D tensor of shape (3):
RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [3] at entry 1
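A quick shape check (a minimal sketch reproducing my snippet above) confirms the mismatch:
import torch

x = torch.tensor([0, 0, 0]).type(torch.DoubleTensor)
y = torch.tensor([ 0.3981, 0.6952, -1.2320]).type(torch.DoubleTensor)
x = torch.stack([x, y])
print(x.shape)  # torch.Size([2, 3]): x is now 2D
print(y.shape)  # torch.Size([3]): y is still 1D, so stacking them fails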
What I want to achieve is to append a 1D tensor multiple times (the appended tensor is different each time; here I use the same one for simplicity):
tensor([[ 0.3981, 0.6952, -1.2320],
        [ 0.3981, 0.6952, -1.2320],
        [ 0.3981, 0.6952, -1.2320],
        [ 0.3981, 0.6952, -1.2320],
        ...
        [ 0.3981, 0.6952, -1.2320]], dtype=torch.float64)
How can I achieve this?
Upvotes: 0
Views: 6742
Reputation: 20287
From the documentation of torch.cat: "All tensors must either have the same shape (except in the concatenating dimension) or be empty." So the easiest solution is to add one extra dimension (of size 1) to the tensor you want to append. Then you will have tensors of shape (n, whatever) and (1, whatever), which can be concatenated along the 0th dimension, meeting the requirements for torch.cat.
Code:
import torch

x = torch.empty(size=(0, 3))      # empty tensor with 0 rows and 3 columns
y = torch.tensor([ 0.3981, 0.6952, -1.2320])
for n in range(5):
    y1 = y.unsqueeze(dim=0)       # same values as y but with shape (1, 3)
    x = torch.cat([x, y1], dim=0) # append the new row along dim 0
print(x)
Output:
tensor([[ 0.3981, 0.6952, -1.2320],
        [ 0.3981, 0.6952, -1.2320],
        [ 0.3981, 0.6952, -1.2320],
        [ 0.3981, 0.6952, -1.2320],
        [ 0.3981, 0.6952, -1.2320]])
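A side note on the design choice, as a sketch rather than part of the solution above: each torch.cat call copies everything accumulated so far, so appending row by row in a loop costs O(n^2) overall. When all rows can be gathered first, it is usually cheaper to collect them in a plain Python list and call torch.stack once at the end:
import torch

rows = []                 # plain Python list of 1D tensors
for n in range(5):
    rows.append(torch.tensor([ 0.3981, 0.6952, -1.2320]))
x = torch.stack(rows)     # single allocation, shape (5, 3)
print(x.shape)            # torch.Size([5, 3])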
Upvotes: 1