In NumPy, I would do:
import numpy as np
a = np.zeros((4, 5, 6))
a = a[:, :, np.newaxis, :]
assert a.shape == (4, 5, 1, 6)
How can I do the same in PyTorch?
Use None in the index, exactly as you would use np.newaxis in NumPy:
import torch
a = torch.zeros(4, 5, 6)
a = a[:, :, None, :]
assert a.shape == (4, 5, 1, 6)
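NumPy defines np.newaxis as an alias for None, which is why the same index expression carries over to PyTorch. A minimal sketch of a few equivalent spellings, assuming only torch is installed; it also checks that the result is a view that shares memory with the original tensor:

```python
import torch

a = torch.zeros(4, 5, 6)

# None inserts a length-1 axis at its position in the index.
b = a[:, :, None, :]
# Trailing dimensions do not need an explicit ':'.
c = a[:, :, None]
# Ellipsis (...) stands for "all leading dimensions".
d = a[..., None, :]

assert b.shape == c.shape == d.shape == (4, 5, 1, 6)

# Basic indexing returns a view, so writing through b changes a.
b[0, 0, 0, 0] = 1.0
assert a[0, 0, 0].item() == 1.0
```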
You can add a new axis with torch.unsqueeze(); its dim argument is the index of the new axis:
>>> import torch
>>> a = torch.zeros(4, 5, 6)
>>> a = a.unsqueeze(2)
>>> a.shape
torch.Size([4, 5, 1, 6])
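unsqueeze is also available as a function, accepts negative indices (counted from the end of the output shape), and returns a view rather than a copy. A short sketch comparing these forms with reshape, assuming the same (4, 5, 6) tensor:

```python
import torch

a = torch.zeros(4, 5, 6)

# Method form and function form are equivalent.
m = a.unsqueeze(2)
f = torch.unsqueeze(a, dim=2)

# A negative dim counts from the end of the result: for a 3-D input,
# dim=-2 resolves to -2 + 3 + 1 = 2.
n = a.unsqueeze(-2)

# Spelling out the full target shape gives the same result.
r = a.reshape(4, 5, 1, 6)

for t in (m, f, n, r):
    assert t.shape == (4, 5, 1, 6)
    assert torch.equal(t, m)

# unsqueeze does not copy: the result shares storage with a.
assert m.data_ptr() == a.data_ptr()
```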
Or use the in-place version, Tensor.unsqueeze_():
>>> a = torch.zeros(4, 5, 6)
>>> a.unsqueeze_(2)
>>> a.shape
torch.Size([4, 5, 1, 6])
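The practical difference between the two versions is whether the original tensor changes. A minimal sketch contrasting them; the variable names are only illustrative:

```python
import torch

# Out-of-place: unsqueeze returns a new view and leaves the input alone.
a = torch.zeros(4, 5, 6)
b = a.unsqueeze(2)
assert a.shape == (4, 5, 6)
assert b.shape == (4, 5, 1, 6)

# In-place: unsqueeze_ changes the shape of the tensor it is called on,
# so no reassignment is needed.
c = torch.zeros(4, 5, 6)
c.unsqueeze_(2)
assert c.shape == (4, 5, 1, 6)
```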
import torch
x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)
y is now tensor([[1, 2, 3, 4]]), with shape torch.Size([1, 4]).
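For a 1-D tensor, the choice of dim decides whether you get a row or a column. A small sketch, using the same x, that also shows a common use of this: broadcasting a column against a row to build an outer product (checked against torch.outer):

```python
import torch

x = torch.tensor([1, 2, 3, 4])

row = x.unsqueeze(0)  # shape (1, 4)
col = x.unsqueeze(1)  # shape (4, 1)
assert row.shape == (1, 4)
assert col.shape == (4, 1)

# Broadcasting (4, 1) against (1, 4) yields a (4, 4) table of x[i] * x[j].
table = col * row
assert table.shape == (4, 4)
assert torch.equal(table, torch.outer(x, x))
print(table[1].tolist())  # [2, 4, 6, 8]
```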
For more details, see the torch.unsqueeze documentation: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html