I am trying to remove the top layer in the efficientnet-pytorch implementation. However, if I simply replace the final _fc layer with my own fully connected layer, as suggested by the author in this github comment, I am worried that a swish activation is still applied after this layer, whereas I expected nothing to follow it. When I print the model, the final lines are as follows:
  (_bn1): BatchNorm2d(1280, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_avg_pooling): AdaptiveAvgPool2d(output_size=1)
  (_dropout): Dropout(p=0.2, inplace=False)
  (_fc): Sequential(
    (0): Linear(in_features=1280, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=512, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.25, inplace=False)
    (6): Linear(in_features=128, out_features=1, bias=True)
  )
  (_swish): MemoryEfficientSwish()
)
where _fc is my replaced module.
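What I hoped to do was:
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
base_model = EfficientNet.from_pretrained('efficientnet-b3')
model = nn.Sequential(*list(base_model.children())[:-3])  # list() first: children() is a generator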
I assumed that base_model.children() would flatten the nested structure of the model. However, I cannot use the resulting model: when I pass a dummy input, x = torch.randn(1, 3, 255, 255), I get the error:
TypeError: forward() takes 1 positional argument but 2 were given
Note that model[:2](x) works, but model[:3](x) does not. model[2] appears to be the container holding the mobile (MBConv) blocks.
Here is a colab notebook with the above code.
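For reference, children() yields only the immediate submodules of a model; it does not flatten nested containers. A minimal sketch with a hypothetical toy model (NestedDemo is an illustration, not part of efficientnet-pytorch) compares children() with the recursive modules():

```python
import torch.nn as nn


class NestedDemo(nn.Module):
    """Hypothetical toy model: a stem, a ModuleList of blocks, and a head."""

    def __init__(self):
        super().__init__()
        self._conv_stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self._blocks = nn.ModuleList(
            [nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(3)]
        )
        self._fc = nn.Linear(8, 1)


model = NestedDemo()
top_level = list(model.children())    # direct children only
every_module = list(model.modules())  # recursive, includes model itself

print(len(top_level))               # 3: _conv_stem, _blocks, _fc
print(type(top_level[1]).__name__)  # ModuleList: the blocks stay nested
print(len(every_module))            # 7: model + stem + ModuleList + 3 blocks + fc
```

An nn.ModuleList is a container with no forward of its own, and an nn.Sequential built from children() also skips any logic that lives only in the original forward (such as a loop over the blocks). If efficientnet-pytorch stores its MBConv blocks in such a ModuleList (its _blocks attribute appears to be one), that would plausibly explain why model[:2](x) works but model[:3](x) fails.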
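This is a common misunderstanding of what print(net) actually does: it lists submodules in the order in which they were registered, which says nothing about whether, or in what order, they are called.
The fact that the _swish module appears after _fc simply means that _swish was registered after _fc. You can check that in the code: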
class EfficientNet(nn.Module):
    def __init__(self, blocks_args=None, global_params=None):
        # [...]
        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()
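A minimal sketch of the same effect with a hypothetical toy module (RegistrationOrderDemo is not part of efficientnet-pytorch, and nn.SiLU stands in for MemoryEfficientSwish, since both compute x * sigmoid(x)). The activation is registered last but never called, yet it is still listed last:

```python
import torch.nn as nn


class RegistrationOrderDemo(nn.Module):
    """Hypothetical toy module copying only the attribute order of the excerpt above."""

    def __init__(self, in_channels=1280, num_classes=10):
        super().__init__()
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(0.2)
        self._fc = nn.Linear(in_channels, num_classes)
        self._swish = nn.SiLU()  # stand-in for MemoryEfficientSwish, registered last

    def forward(self, x):
        # _swish is never called here, yet print() still lists it after _fc
        x = self._avg_pooling(x).flatten(start_dim=1)
        return self._fc(self._dropout(x))


net = RegistrationOrderDemo()
print(net)
names = [name for name, _ in net.named_children()]
print(names)  # ['_avg_pooling', '_dropout', '_fc', '_swish']
```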
The order in which the submodules are defined is the order in which they are printed. To see what is actually executed, you have to check the forward method:
def forward(self, inputs):
    # Convolution layers
    x = self.extract_features(inputs)
    # Pooling and final linear layer
    x = self._avg_pooling(x)
    x = x.flatten(start_dim=1)
    x = self._dropout(x)
    x = self._fc(x)
    return x
As you can see, nothing is applied after self._fc(x), which means no Swish activation follows your replacement _fc.
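A minimal sketch that checks this numerically, using a hypothetical TinyHead that reproduces only the pooling, dropout and fc steps of the forward above (nn.SiLU again stands in for MemoryEfficientSwish, and a random tensor stands in for the output of extract_features). It also shows one way to drop the top entirely, assuming the real model follows this forward: replace _fc with nn.Identity(), so the model presumably returns the pooled features directly.

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """Hypothetical stand-in for the EfficientNet head; not the library class."""

    def __init__(self, in_channels=1280, num_classes=10):
        super().__init__()
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(0.2)
        self._fc = nn.Linear(in_channels, num_classes)
        self._swish = nn.SiLU()  # registered last, but unused in forward

    def forward(self, x):
        # Same steps as the forward above, minus extract_features
        x = self._avg_pooling(x)
        x = x.flatten(start_dim=1)
        x = self._dropout(x)
        x = self._fc(x)
        return x


torch.manual_seed(0)
net = TinyHead().eval()  # eval() turns dropout into a no-op
feats = torch.randn(2, 1280, 7, 7)  # pretend output of extract_features

with torch.no_grad():
    out = net(feats)
    # Pooling + linear only, with no activation afterwards
    manual = net._fc(net._avg_pooling(feats).flatten(start_dim=1))
    with_swish = net._swish(out)
    print(torch.allclose(out, manual))      # True
    print(torch.allclose(out, with_swish))  # False: Swish would change the logits

    # Dropping the "top" entirely: replace _fc with an identity
    net._fc = nn.Identity()
    pooled = net(feats)
    print(pooled.shape)  # torch.Size([2, 1280])
```

Swapping in nn.Identity() keeps the pretrained pooling and dropout in place, which is usually simpler than rebuilding the network with nn.Sequential.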