Simon Rochester

Reputation: 33

run conv2d against a list

I tested conv2d with the following code:

import torch
import torch.nn as nn

x = torch.randint(500, (256,))
conv = nn.Conv2d(1, 6, 5, padding=1)
y = x.view(1, 1, 16, 16)
z = conv(y)
print(z.shape)

and I got error:

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in conv2d_forward(self, input, weight)
    340                             _pair(0), self.dilation, self.groups)
    341         return F.conv2d(input, weight, self.bias, self.stride,
--> 342                         self.padding, self.dilation, self.groups)
    343 
    344     def forward(self, input):

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'weight' in call to _thnn_conv2d_forward

How can I fix it?

Upvotes: 1

Views: 120

Answers (1)

Mercury

Reputation: 4181

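In PyTorch, nn.Conv2d needs floating-point input. The layer's weight and bias are float32 by default, while torch.randint returns an int64 (Long) tensor by default. That mismatch is what the error reports: the input is Long, but argument #2 'weight' is Float. You can fix it with a simple edit: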

x = torch.randint(500, (256,), dtype=torch.float32)
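A minimal sketch to confirm the diagnosis, assuming a recent PyTorch where torch.randint defaults to int64: it prints the dtypes of the default randint output and of the layer's weight, then runs the layer on a tensor created with dtype=torch.float32.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 6, 5, padding=1)

# Default torch.randint output is int64 (Long); the layer's weights are float32
x_long = torch.randint(500, (256,))
print(x_long.dtype, conv.weight.dtype)  # torch.int64 torch.float32

# Requesting float32 at creation time makes the input match the weights
x = torch.randint(500, (256,), dtype=torch.float32)
z = conv(x.view(1, 1, 16, 16))
print(z.shape)  # torch.Size([1, 6, 14, 14])
```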

Alternatively, you can convert the tensor after creating it:

x = torch.randint(500, (256,))
x = x.float()
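A variant of the same conversion, sketched here as a suggestion rather than part of the original answer: cast the input to whatever dtype the layer's parameters use with Tensor.to(conv.weight.dtype). This keeps working if the model is later converted to another floating-point dtype. The comment also shows where the 14x14 output size comes from.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 6, 5, padding=1)
x = torch.randint(500, (256,))  # int64 by default

# Cast the reshaped input to the dtype of the layer's weights (float32 here)
y = x.view(1, 1, 16, 16).to(conv.weight.dtype)
z = conv(y)

# Output size: (16 + 2*padding - kernel_size) // stride + 1 = (16 + 2 - 5) // 1 + 1 = 14
print(z.shape)  # torch.Size([1, 6, 14, 14])
```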

Upvotes: 1
