sachinruk

Reputation: 9869

Flattening Efficientnet model

I am trying to remove the top layer in the efficientnet-pytorch implementation. However, if I simply replace the final _fc layer with my own fully connected layer, as suggested by the author in this GitHub comment, I am worried that a Swish activation is still applied after this layer, rather than nothing as I expected. When I print the model, the final lines are as follows:

(_bn1): BatchNorm2d(1280, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_avg_pooling): AdaptiveAvgPool2d(output_size=1)
    (_dropout): Dropout(p=0.2, inplace=False)
    (_fc): Sequential(
      (0): Linear(in_features=1280, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.25, inplace=False)
      (3): Linear(in_features=512, out_features=128, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.25, inplace=False)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
    (_swish): MemoryEfficientSwish()
  )
)

where _fc is my replaced module.
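For reference, the replacement was done roughly like this (a minimal sketch reconstructed from the printout above, so the exact head is an assumption):

import torch.nn as nn
from efficientnet_pytorch import EfficientNet

base_model = EfficientNet.from_pretrained('efficientnet-b3')

# Read the backbone's feature size before swapping the classifier head.
in_features = base_model._fc.in_features
base_model._fc = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(),
    nn.Dropout(p=0.25),
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Dropout(p=0.25),
    nn.Linear(128, 1),
)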

What I hoped to do was:

base_model = EfficientNet.from_pretrained('efficientnet-b3')
model = nn.Sequential(*list(base_model.children())[:-3])

where in my mind base_model.children() flattens the model from the nested structure. However, I now cannot seem to use the model: if I pass a dummy input, x = torch.randn(1, 3, 255, 255), I get the error: TypeError: forward() takes 1 positional argument but 2 were given.

It should be noted that model[:2](x) works, but model[:3](x) does not; model[2] seems to be the MBConv (mobile) blocks.

Here is a colab notebook with the above code.

Upvotes: 1

Views: 1333

Answers (1)

Berriel

Reputation: 13641

This is a common misunderstanding of what print(net) actually does.

The fact that there is a _swish module after the _fc simply means that the former was registered after the latter. You can check that in the code:

class EfficientNet(nn.Module):
    def __init__(self, blocks_args=None, global_params=None):

        # [...]

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

The order in which they are defined is the order in which they will be printed. To see what is actually executed, you have to check the forward:

def forward(self, inputs):
    # Convolution layers
    x = self.extract_features(inputs)

    # Pooling and final linear layer
    x = self._avg_pooling(x)
    x = x.flatten(start_dim=1)
    x = self._dropout(x)
    x = self._fc(x)

    return x

and, as you can see, there is nothing after self._fc(x), which means no Swish will be applied.
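As a quick illustration (a toy module, not EfficientNet), a submodule can show up in print(net) without ever being called in forward:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.act = nn.ReLU()  # registered after fc, so printed after it

    def forward(self, x):
        return self.fc(x)     # self.act is never applied

net = Toy()
print(net)                    # lists both fc and act, in registration order
out = net(torch.randn(1, 4))  # output is not passed through ReLU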

Upvotes: 3
