noname

Reputation: 13

Error when I modify number of input channels of inception-v3

I would like to customize inception_v3 so that it works with 4-channel input. I tried to modify the first layer of Inception v3 as shown below.

import torch
import torch.nn as nn
from torchvision import models

x = torch.randn((5, 4, 299, 299))

model_ft = models.inception_v3(pretrained=True)
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
print(x.shape)
print(model_ft.Conv2d_1a_3x3.conv)
out = model_ft(x)

but it produces the following error. I think the input shape and the network have been modified correctly, so I can't understand why this error occurs. Does anyone have any advice?

torch.Size([5, 4, 299, 299])
Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

RuntimeError                              Traceback (most recent call last)
<ipython-input-118-41c045338348> in <module>
     29 print(model_ft.Conv2d_1a_3x3.conv)
     30 
---> 31 out=model_ft(x)
     32 print(out)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
    202     def forward(self, x: Tensor) -> InceptionOutputs:
    203         x = self._transform_input(x)
--> 204         x, aux = self._forward(x)
    205         aux_defined = self.training and self.aux_logits
    206         if torch.jit.is_scripting():

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in _forward(self, x)
    141     def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
    142         # N x 3 x 299 x 299
--> 143         x = self.Conv2d_1a_3x3(x)
    144         # N x 32 x 149 x 149
    145         x = self.Conv2d_2a_3x3(x)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
    474 
    475     def forward(self, x: Tensor) -> Tensor:
--> 476         x = self.conv(x)
    477         x = self.bn(x)
    478         return F.relu(x, inplace=True)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    441 
    442     def forward(self, input: Tensor) -> Tensor:
--> 443         return self._conv_forward(input, self.weight, self.bias)
    444 
    445 class Conv3d(_ConvNd):

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    438                             _pair(0), self.dilation, self.groups)
    439         return F.conv2d(input, weight, bias, self.stride,
--> 440                         self.padding, self.dilation, self.groups)
    441 
    442     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [32, 4, 3, 3], expected input[5, 3, 299, 299] to have 4 channels, but got 3 channels instead

Upvotes: 0

Views: 907

Answers (2)

noname

Reputation: 13

I found that when pretrained=True, an ImageNet normalization step is applied to the input before the rest of the network, as can be seen in the torchvision source below. That step is written for 3-channel images: it re-concatenates only the first three channels, so the fourth channel is silently dropped and Conv2d_1a_3x3 receives a 3-channel tensor. That is why the error occurs.

def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    ...
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
    ...
    return Inception3(**kwargs)


class Inception3(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
    ) -> None:
        super(Inception3, self).__init__()
        ...
        self.transform_input = transform_input
        ...

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        ...

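To make the failure concrete: _transform_input rebuilds the input from only the first three channels, so a 4-channel batch comes out of it with 3 channels, which is exactly what the RuntimeError reports. A minimal standalone sketch of just that step:

import torch

x = torch.randn((5, 4, 299, 299))

# Same arithmetic as _transform_input: only channels 0..2 survive the cat.
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
out = torch.cat((x_ch0, x_ch1, x_ch2), 1)

print(out.shape)  # torch.Size([5, 3, 299, 299]) -- the 4th channel is gone
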
I could finally use the pretrained model as shown below.

import torch
import torch.nn as nn
from torchvision import models

x = torch.randn((5, 4, 299, 299))
model_ft = models.inception_v3(pretrained=True)
model_ft.transform_input = False
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
out = model_ft(x)
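
Note that replacing Conv2d_1a_3x3.conv with a freshly initialized layer discards the pretrained first-layer weights. If you want to keep them for the RGB channels, one common approach (a sketch of my own, not part of the original fix) is to copy them into the new 4-channel conv and initialize only the extra channel:

import torch
import torch.nn as nn
from torchvision import models

model_ft = models.inception_v3(pretrained=True)
model_ft.transform_input = False

old_conv = model_ft.Conv2d_1a_3x3.conv  # pretrained Conv2d(3, 32, ...)
new_conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

with torch.no_grad():
    # Reuse the pretrained RGB filters for input channels 0..2.
    new_conv.weight[:, :3] = old_conv.weight
    # Initialize the 4th input channel as the mean of the RGB filters
    # (an arbitrary but common choice; random init also works).
    new_conv.weight[:, 3:] = old_conv.weight.mean(dim=1, keepdim=True)

model_ft.Conv2d_1a_3x3.conv = new_conv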

Upvotes: 1

Prajot Kuvalekar

Reputation: 6618

The error is due to the parameter pretrained=True.
Since you are using pretrained weights, you cannot edit the shape of the pretrained weights to make them fit a 4-channel input. Hence the error pops up.
Please use it this way (which will only load the architecture):

import torch
import torch.nn as nn
from torchvision import models

x = torch.randn((5, 4, 299, 299))
model_ft = models.inception_v3(pretrained=False)
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
print(x.shape)
print(model_ft.Conv2d_1a_3x3.conv)
out = model_ft(x)

and it will work.
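
If you still want the pretrained weights for all the unmodified layers (the snippet above only loads the architecture), one possible extension, sketched here as an assumption rather than part of this answer, is to load the pretrained state dict and skip every entry whose shape no longer matches:

import torch
import torch.nn as nn
from torchvision import models

# Build the 4-channel architecture without pretrained weights.
model_ft = models.inception_v3(pretrained=False)
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

# Copy over every pretrained tensor whose shape still matches; this
# skips the modified first conv automatically.
own_state = model_ft.state_dict()
pretrained_state = models.inception_v3(pretrained=True).state_dict()
filtered = {k: v for k, v in pretrained_state.items()
            if k in own_state and own_state[k].shape == v.shape}
model_ft.load_state_dict(filtered, strict=False)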

Upvotes: 0
