Reputation: 13
I would like to customize inception_v3 so that it works with 4-channel input. I tried to modify the first layer of Inception v3 as shown below.
import torch
import torch.nn as nn
from torchvision import models

x = torch.randn((5, 4, 299, 299))
model_ft = models.inception_v3(pretrained=True)
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
print(x.shape)
print(model_ft.Conv2d_1a_3x3.conv)
out = model_ft(x)
but it produces the following error. I think the input shape and the network are modified consistently, so I can't understand why the error occurs. Does anyone have any advice?
torch.Size([5, 4, 299, 299])
Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
RuntimeError                              Traceback (most recent call last)
<ipython-input-118-41c045338348> in <module>
29 print(model_ft.Conv2d_1a_3x3.conv)
30
---> 31 out=model_ft(x)
32 print(out)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
202 def forward(self, x: Tensor) -> InceptionOutputs:
203 x = self._transform_input(x)
--> 204 x, aux = self._forward(x)
205 aux_defined = self.training and self.aux_logits
206 if torch.jit.is_scripting():
/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in _forward(self, x)
141 def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
142 # N x 3 x 299 x 299
--> 143 x = self.Conv2d_1a_3x3(x)
144 # N x 32 x 149 x 149
145 x = self.Conv2d_2a_3x3(x)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
474
475 def forward(self, x: Tensor) -> Tensor:
--> 476 x = self.conv(x)
477 x = self.bn(x)
478 return F.relu(x, inplace=True)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
441
442 def forward(self, input: Tensor) -> Tensor:
--> 443 return self._conv_forward(input, self.weight, self.bias)
444
445 class Conv3d(_ConvNd):
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
438 _pair(0), self.dilation, self.groups)
439 return F.conv2d(input, weight, bias, self.stride,
--> 440 self.padding, self.dilation, self.groups)
441
442 def forward(self, input: Tensor) -> Tensor:
RuntimeError: Given groups=1, weight of size [32, 4, 3, 3], expected input[5, 3, 299, 299] to have 4 channels, but got 3 channels instead
Upvotes: 0
Views: 907
Reputation: 13
I found that when pretrained=True, a normalization transform for the ImageNet dataset is applied to the input before it enters the network, as can be seen here, and that transform is designed for 3-channel input images. That is why the error occurs.
def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    ...
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        ...
    return Inception3(**kwargs)

class Inception3(nn.Module):

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
    ) -> None:
        super(Inception3, self).__init__()
        ...
        self.transform_input = transform_input
        ...

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        ...
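In other words, _transform_input only reads channels 0, 1 and 2 and concatenates them back together, so a 4-channel batch is silently reduced to 3 channels before it ever reaches the modified first conv layer. A minimal sketch to confirm this (it calls the private _transform_input method directly, purely for inspection; transform_input=True is the same flag that pretrained=True switches on, used here to avoid downloading weights):

import torch
from torchvision import models

# transform_input=True is what pretrained=True turns on under the hood
model = models.inception_v3(pretrained=False, transform_input=True)

x = torch.randn((5, 4, 299, 299))
# calling the private method directly, purely to inspect its effect
print(model._transform_input(x).shape)  # torch.Size([5, 3, 299, 299]) -- the 4th channel is dropped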
I could finally use the pretrained model as shown below.
x = torch.randn((5, 4, 299, 299))
model_ft = models.inception_v3(pretrained=True)
model_ft.transform_input = False
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
out = model_ft(x)
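Note that the replacement nn.Conv2d above is randomly initialized, so the pretrained weights of the very first convolution are lost. If you want to keep them, one common trick is to copy the pretrained 3-channel kernel into the first 3 input channels of the new layer. A hedged sketch, not part of the solution above; initializing the extra channel with the RGB mean is just one possible choice:

import torch
import torch.nn as nn
from torchvision import models

model_ft = models.inception_v3(pretrained=True)
model_ft.transform_input = False

old_conv = model_ft.Conv2d_1a_3x3.conv  # pretrained Conv2d(3, 32, ...)
new_conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

with torch.no_grad():
    new_conv.weight[:, :3] = old_conv.weight             # reuse the pretrained RGB kernels
    new_conv.weight[:, 3] = old_conv.weight.mean(dim=1)  # init the 4th channel, e.g. with the RGB mean

model_ft.Conv2d_1a_3x3.conv = new_conv
out = model_ft(torch.randn((5, 4, 299, 299)))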
Upvotes: 1
Reputation: 6618
The error is due to the param pretrained=True. Since you are using pretrained weights, you cannot edit the shape of the pretrained weights to make them fit a 4-channel input; hence the error pops up.
Please use it this way (which will only load the architecture):
x = torch.randn((5, 4, 299, 299))
model_ft = models.inception_v3(pretrained=False)
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
print(x.shape)
print(model_ft.Conv2d_1a_3x3.conv)
out = model_ft(x)
and it will work.
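One follow-up detail worth knowing (an addition, not from the answer itself): with the default aux_logits=True, the model returns an InceptionOutputs named tuple in training mode, while in eval mode it returns a plain logits tensor:

model_ft.eval()  # in eval mode, forward returns a plain tensor instead of InceptionOutputs
with torch.no_grad():
    out = model_ft(x)
print(out.shape)  # expected: torch.Size([5, 1000])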
Upvotes: 0