Shary

Reputation: 101

Getting "Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same"

import torch
import torch.nn as nn

# DEVICE is assumed here; adjust to your setup (not shown in the original snippet)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity()
        )
        
    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )
        
    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, image_channels, num_features=64, num_residuals=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(image_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True)
        )
        self.down_blocks = nn.ModuleList = ([
            ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
            ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
        ])
        
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features*4) for _ in range(num_residuals)]
        )
        
        self.up_blocks = nn.ModuleList = ([
            ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, padding=1, stride=2, output_padding=1),
            ConvBlock(num_features*2, num_features, down=False, kernel_size=3, padding=1, stride=2, output_padding=1),
        ])
        
        self.last = nn.Conv2d(num_features, image_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        
    def forward(self, x):
        x = self.initial(x)
        
        for layer in self.down_blocks:
            x = layer(x)
            
        x = self.residual_blocks(x)
        
        for layer in self.up_blocks:
            x = layer(x)
            
        return torch.tanh(self.last(x))

img_channels = 3
img_size = 256
x = torch.randn((2, img_channels, img_size, img_size))
x = x.to(DEVICE)
gen = Generator(img_channels, 9).to(DEVICE)
print(gen(x).shape)

I have implemented this model for CycleGAN. The input used here is just for demonstration purposes, to keep the code short; the actual input throws the same error. The code runs fine on the CPU, but when I move it to the GPU I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-a56669856674> in <module>
      4 x = x.to(DEVICE)
      5 gen = Generator(img_channels, 9).to(DEVICE)
----> 6 print(gen(x).shape)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-25-f9a41f6c9d12> in forward(self, x)
     26 
     27         for layer in self.down_blocks:
---> 28             x = layer(x)
     29 
     30         x = self.residual_blocks(x)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-23-e139087b9df4> in forward(self, x)
     11 
     12     def forward(self, x):
---> 13         return self.conv(x)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
    421 
    422     def forward(self, input: Tensor) -> Tensor:
--> 423         return self._conv_forward(input, self.weight)
    424 
    425 class Conv3d(_ConvNd):

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    416             return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    417                             weight, self.bias, self.stride,
--> 418                             _pair(0), self.dilation, self.groups)
    419         return F.conv2d(input, weight, self.bias, self.stride,
    420                         self.padding, self.dilation, self.groups)

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Both my model and the input are being moved to the device (CUDA), yet I have no idea what the issue is.

Upvotes: 2

Views: 243

Answers (1)

Ivan

Reputation: 40708

You have typos in your code:

self.down_blocks = nn.ModuleList = ([
...
self.up_blocks = nn.ModuleList = ([

should be:

self.down_blocks = nn.ModuleList([
...
self.up_blocks = nn.ModuleList([

You also need to restart your kernel: the chained assignment `self.down_blocks = nn.ModuleList = ([...])` never calls `nn.ModuleList` at all; it binds the plain list to both `self.down_blocks` and the name `nn.ModuleList`, so at this point you've essentially overwritten `nn.ModuleList` with a list. Because `self.down_blocks` is an ordinary Python list, the `ConvBlock`s inside it are never registered as submodules, so `Generator.to(DEVICE)` never moves their weights to the GPU. That is exactly the mismatch the error reports: the input is a `torch.cuda.FloatTensor` while those convolution weights are still `torch.FloatTensor`.
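
To see why registration matters for the device error: modules held in a plain Python list are invisible to .to(), .parameters(), and state_dict(), whereas nn.ModuleList registers them as submodules. A minimal standalone sketch to demonstrate (the WithList/WithModuleList names are mine, purely for illustration):

import torch
import torch.nn as nn

class WithList(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the Linear layer is NOT registered as a submodule
        self.blocks = [nn.Linear(4, 4)]

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers its contents, so .to() and .parameters() see them
        self.blocks = nn.ModuleList([nn.Linear(4, 4)])

print(len(list(WithList().parameters())))        # 0 -> these weights stay on the CPU after .to(device)
print(len(list(WithModuleList().parameters())))  # 2 -> weight and bias move with the model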

Upvotes: 1
