Duane

Using expand_dims in PyTorch

I'm trying to tile a length-18 one-hot vector into a 40x40 grid.

Looking at the PyTorch docs, expand seems to be what I need.

But I cannot get it to work. Any idea what I'm doing wrong?

>>> import torch
>>> one_hot = torch.zeros(18).unsqueeze(0)
>>> one_hot[0, 1] = 1.0
>>> one_hot
tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
>>> one_hot.expand(-1, -1, 40, 40)
Traceback (most recent call last):
  File "<input>", line 1, in <module>
RuntimeError: The expanded size of the tensor (40) must match the existing size (18) at non-singleton dimension 3

I'm expecting a tensor of shape (1, 18, 40, 40).

Answers (1)

Shai

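expand only works along singleton dimensions (dimensions of size 1) of the input tensor, and it can add new dimensions only at the front. When you pass four sizes to your 1-by-18 tensor, PyTorch lines its two existing dimensions up with the last two sizes, so the size-18 dimension is matched against 40 at dimension 3. Since 18 is not 1, that dimension cannot be expanded, which is why you get the error. The only singleton dimension you have is the first one; the two 40-sized dimensions you want do not exist yet.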
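
The alignment rule can be checked directly. This is a minimal sketch, assuming a reasonably recent PyTorch release: expand lines the tensor's existing dimensions up with the trailing entries of the requested size, may only prepend new dimensions (and does not accept -1 for them), and can only enlarge dimensions whose size is 1. The tensor x is a stand-in for the question's one_hot; the names a, b and raised are illustrative.

```python
import torch

x = torch.zeros(1, 18)  # same shape as the question's one_hot

# Dimension 0 has size 1, so it can be broadcast to any size.
a = x.expand(3, 18)
print(a.shape)  # torch.Size([3, 18])

# A new dimension can be prepended; the existing (1, 18) lines up
# with the last two requested sizes.
b = x.expand(40, 1, 18)
print(b.shape)  # torch.Size([40, 1, 18])

# With (-1, -1, 40, 40) the size-18 dimension lines up with the final 40.
# 18 is not a singleton, so PyTorch raises a RuntimeError.
raised = False
try:
    x.expand(-1, -1, 40, 40)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```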

Fix:

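import torch

one_hot = torch.zeros(1, 18, 1, 1, dtype=torch.float)  # create the tensor with all singleton dimensions in place
one_hot[0, 1, 0, 0] = 1.  # set the "hot" entry
one_hot.expand(-1, -1, 40, 40)  # -1 keeps dims 0 and 1; the size-1 dims 2 and 3 become 40 -> shape (1, 18, 40, 40)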
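
If you already have the (1, 18) tensor from the question, a sketch of the same fix is to append the two missing singleton dimensions with unsqueeze and then expand (one_hot[:, :, None, None] or one_hot.view(1, 18, 1, 1) should work equally well). The names one_hot_4d, tiled and tiled_copy are illustrative.

```python
import torch

# Start from the question's (1, 18) one-hot tensor.
one_hot = torch.zeros(18).unsqueeze(0)
one_hot[0, 1] = 1.0

# Append two trailing singleton dimensions: (1, 18) -> (1, 18, 1, 1).
one_hot_4d = one_hot.unsqueeze(-1).unsqueeze(-1)

# The last two dimensions now have size 1, so expand can broadcast them.
tiled = one_hot_4d.expand(-1, -1, 40, 40)
print(tiled.shape)  # torch.Size([1, 18, 40, 40])

# expand returns a view: the two new 40-sized dimensions have stride 0,
# so every grid position reads the same 18 values and nothing is copied.
print(tiled.stride())

# Materialize an independent copy if you need to write into the result.
tiled_copy = tiled.contiguous()
```

Because the expanded view aliases one underlying vector, in-place writes into tiled are unsafe; contiguous() (or one_hot_4d.repeat(1, 1, 40, 40)) gives a real 1x18x40x40 tensor at the cost of the extra memory.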
