Reputation: 9869
I am trying to remove the top layer in the efficientnet-pytorch implementation. However, if I simply replace the final _fc layer with my own fully connected layer, as suggested by the author in this GitHub comment, I am worried that there is still a Swish activation applied after this layer, rather than nothing, as I expected. When I print the model, the final lines are as follows:
  (_bn1): BatchNorm2d(1280, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_avg_pooling): AdaptiveAvgPool2d(output_size=1)
  (_dropout): Dropout(p=0.2, inplace=False)
  (_fc): Sequential(
    (0): Linear(in_features=1280, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=512, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.25, inplace=False)
    (6): Linear(in_features=128, out_features=1, bias=True)
  )
  (_swish): MemoryEfficientSwish()
)
)
where _fc is my replaced module.
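For reference, the replacement looks roughly like this (a sketch; the layer sizes are the ones in the printout above, and the input width is read off the original _fc rather than hard-coded):

import torch.nn as nn
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b3')
model._fc = nn.Sequential(
    nn.Linear(model._fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(p=0.25),
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Dropout(p=0.25),
    nn.Linear(128, 1),
)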
What I hoped to do was:
base_model = EfficientNet.from_pretrained('efficientnet-b3')
model = nn.Sequential(*list(base_model.children())[:-3])
where, in my mind, base_model.children() flattens the model out of its nested structure. However, I cannot seem to use the resulting model: if I pass a dummy input, x = torch.randn(1, 3, 255, 255), I get the error: TypeError: forward() takes 1 positional argument but 2 were given.
It should be noted that model[:2](x) works, but model[:3](x) does not. model[2] seems to be the mobile (MBConv) blocks.
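For reference, this is roughly how I checked which child is which (not in the notebook, just a quick sanity check):

for i, child in enumerate(base_model.children()):
    print(i, type(child).__name__)  # index 2 prints as ModuleList in my run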
Here is a colab notebook with the above code.
Upvotes: 1
Views: 1333
Reputation: 13641
This is a common misunderstanding of what print(net) actually does.
The fact that there is a _swish module after the _fc simply means that the former was registered after the latter. You can check that in the code:
class EfficientNet(nn.Module):
    def __init__(self, blocks_args=None, global_params=None):
        # [...]

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()
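To make this concrete, here is a small toy module (not from efficientnet-pytorch, purely illustrative): the ReLU is registered as a submodule, so print shows it, but forward never calls it:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.act = nn.ReLU()  # registered after fc, so printed after fc

    def forward(self, x):
        return self.fc(x)  # act is never used

net = Toy()
print(net)                     # lists fc and act, in registration order
print(net(torch.randn(1, 4)))  # raw linear output, no ReLU applied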
The order in which they are defined is the order in which they will be printed. To see what is actually executed, you have to check the forward:
def forward(self, inputs):
    # Convolution layers
    x = self.extract_features(inputs)

    # Pooling and final linear layer
    x = self._avg_pooling(x)
    x = x.flatten(start_dim=1)
    x = self._dropout(x)
    x = self._fc(x)
    return x
and, as you can see, there is nothing after self._fc(x), which means no Swish will be applied.
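So replacing _fc does exactly what you want: whatever you assign there is the last thing forward applies. For instance, if you just want the pooled features with no classification head at all, something along these lines should work (a sketch; nn.Identity simply passes its input through unchanged):

import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b3')
model._fc = nn.Identity()  # drop the classifier entirely

features = model(torch.randn(1, 3, 300, 300))
print(features.shape)  # (1, C) pooled features, with no Swish applied on the way out

If you want the spatial feature map before pooling instead, model.extract_features(inputs), the call at the top of forward, gives you that directly.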
Upvotes: 3