Deon Hargrove

Reputation: 281

Freeze certain layers of an existing model in PyTorch

I am using MobileNetV2 and I only want to freeze part of the model. I know I can use the following code to freeze the entire model:

MobileNet = models.mobilenet_v2(pretrained = True)

for param in MobileNet.parameters():
    param.requires_grad = False

but I want everything from layer (15) onward to remain unfrozen. How can I selectively freeze everything before that layer?

    (15): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): ConvBNReLU(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (16): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): ConvBNReLU(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (17): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): ConvBNReLU(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (18): ConvBNReLU(
      (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
  )
  (classifier): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=1280, out_features=1000, bias=True)
  )
)

Upvotes: 26

Views: 55511

Answers (6)

Flicic Suo

Reputation: 797

Start by freezing everything:

for param in MobileNet.parameters():
    param.requires_grad = False

Then, unfreeze parameters in (15):

for param in MobileNet.features[15].parameters():
    param.requires_grad = True

Or, unfreeze parameters in (15)-(18):

for i in range(15, 19):
    for param in MobileNet.features[i].parameters():
        param.requires_grad = True
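
A common follow-up is to hand only the still-trainable parameters to the optimizer. Here is a minimal sketch of that pattern, assuming SGD and assuming you also want the classifier head trainable (neither is stated in the answer):

import torch
import torchvision.models as models

MobileNet = models.mobilenet_v2(pretrained=True)

# Freeze everything, then unfreeze feature blocks (15)-(18) as above
for param in MobileNet.parameters():
    param.requires_grad = False
for i in range(15, 19):
    for param in MobileNet.features[i].parameters():
        param.requires_grad = True
# Assumption: keep the classifier trainable too, since it is usually fine-tuned
for param in MobileNet.classifier.parameters():
    param.requires_grad = True

# Pass only the trainable parameters to the optimizer
trainable_params = [p for p in MobileNet.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=1e-3, momentum=0.9)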

Upvotes: 43

Ayanabha

Reputation: 51

You can iterate over the required layers using this snippet (note that named_children() yields string names, and the numbered blocks live under MobileNet.features):

for name, child in MobileNet.features.named_children():
    if int(name) < 15:
        for param in child.parameters():
            param.requires_grad = False

Upvotes: 1

Alex Punnen

Reputation: 6244

Since there are different types of models, sometimes setting requires_grad = True on the blocks alone does not work.*

Option 1

# freeze everything
for param in model.parameters():
    param.requires_grad = False

# and un-freeze the lower 4 layers of the encoder
for i in range(0, num_encoder_layers - 8):
    for param in model.encoder.block[i].parameters():
        param.requires_grad = True

# verify
for name, param in model.named_parameters():
    print(name, param.requires_grad)

Option 2

# Freeze the upper 3 layers of the encoder (the lower layers stay unfrozen)
for i in range(num_encoder_layers - 1, num_encoder_layers - 4, -1):
    for param in model.encoder.block[i].parameters():
        param.requires_grad = False

# Freeze all layers of the decoder
for i in range(num_decoder_layers):
    for param in model.decoder.block[i].parameters():
        param.requires_grad = False

for name, param in model.named_parameters():
    print(name, param.requires_grad)

Depending on what you have frozen, you get something like this:

Output

shared.weight False
encoder.block.0.layer.0.SelfAttention.q.weight True
encoder.block.0.layer.0.SelfAttention.k.weight True
encoder.block.0.layer.0.SelfAttention.v.weight True
encoder.block.0.layer.0.SelfAttention.o.weight True
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight True
encoder.block.0.layer.0.layer_norm.weight True
encoder.block.0.layer.1.DenseReluDense.wi_0.weight True
encoder.block.0.layer.1.DenseReluDense.wi_1.weight True
encoder.block.0.layer.1.DenseReluDense.wo.weight True
encoder.block.0.layer.1.layer_norm.weight True
encoder.block.1.layer.0.SelfAttention.q.weight True
encoder.block.1.layer.0.SelfAttention.k.weight True
encoder.block.1.layer.0.SelfAttention.v.weight True
encoder.block.1.layer.0.SelfAttention.o.weight True
encoder.block.1.layer.0.layer_norm.weight True
encoder.block.1.layer.1.DenseReluDense.wi_0.weight True
encoder.block.1.layer.1.DenseReluDense.wi_1.weight True
encoder.block.1.layer.1.DenseReluDense.wo.weight True
encoder.block.1.layer.1.layer_norm.weight True
encoder.block.2.layer.0.SelfAttention.q.weight True
encoder.block.2.layer.0.SelfAttention.k.weight True
encoder.block.2.layer.0.SelfAttention.v.weight True
encoder.block.2.layer.0.SelfAttention.o.weight True
encoder.block.2.layer.0.layer_norm.weight True
encoder.block.2.layer.1.DenseReluDense.wi_0.weight True
encoder.block.2.layer.1.DenseReluDense.wi_1.weight True
encoder.block.2.layer.1.DenseReluDense.wo.weight True
encoder.block.2.layer.1.layer_norm.weight True
encoder.block.3.layer.0.SelfAttention.q.weight True
encoder.block.3.layer.0.SelfAttention.k.weight True
encoder.block.3.layer.0.SelfAttention.v.weight True
encoder.block.3.layer.0.SelfAttention.o.weight True
encoder.block.3.layer.0.layer_norm.weight True
encoder.block.3.layer.1.DenseReluDense.wi_0.weight True
encoder.block.3.layer.1.DenseReluDense.wi_1.weight True
encoder.block.3.layer.1.DenseReluDense.wo.weight True
encoder.block.3.layer.1.layer_norm.weight True
encoder.block.4.layer.0.SelfAttention.q.weight False
encoder.block.4.layer.0.SelfAttention.k.weight False
encoder.block.4.layer.0.SelfAttention.v.weight False
encoder.block.4.layer.0.SelfAttention.o.weight False
encoder.block.4.layer.0.layer_norm.weight False
encoder.block.4.layer.1.DenseReluDense.wi_0.weight False
encoder.block.4.layer.1.DenseReluDense.wi_1.weight False
encoder.block.4.layer.1.DenseReluDense.wo.weight False
encoder.block.4.layer.1.layer_norm.weight False
encoder.block.5.layer.0.SelfAttention.q.weight False
encoder.block.5.layer.0.SelfAttention.k.weight False
encoder.block.5.layer.0.SelfAttention.v.weight False
encoder.block.5.layer.0.SelfAttention.o.weight False
encoder.block.5.layer.0.layer_norm.weight False
encoder.block.5.layer.1.DenseReluDense.wi_0.weight False
encoder.block.5.layer.1.DenseReluDense.wi_1.weight False
encoder.block.5.layer.1.DenseReluDense.wo.weight False
encoder.block.5.layer.1.layer_norm.weight False
encoder.block.6.layer.0.SelfAttention.q.weight False
encoder.block.6.layer.0.SelfAttention.k.weight False
encoder.block.6.layer.0.SelfAttention.v.weight False
encoder.block.6.layer.0.SelfAttention.o.weight False
encoder.block.6.layer.0.layer_norm.weight False
encoder.block.6.layer.1.DenseReluDense.wi_0.weight False
encoder.block.6.layer.1.DenseReluDense.wi_1.weight False
encoder.block.6.layer.1.DenseReluDense.wo.weight False
encoder.block.6.layer.1.layer_norm.weight False
encoder.block.7.layer.0.SelfAttention.q.weight False
encoder.block.7.layer.0.SelfAttention.k.weight False
encoder.block.7.layer.0.SelfAttention.v.weight False
encoder.block.7.layer.0.SelfAttention.o.weight False
encoder.block.7.layer.0.layer_norm.weight False
encoder.block.7.layer.1.DenseReluDense.wi_0.weight False
encoder.block.7.layer.1.DenseReluDense.wi_1.weight False
encoder.block.7.layer.1.DenseReluDense.wo.weight False
encoder.block.7.layer.1.layer_norm.weight False
encoder.block.8.layer.0.SelfAttention.q.weight False
encoder.block.8.layer.0.SelfAttention.k.weight False
encoder.block.8.layer.0.SelfAttention.v.weight False
encoder.block.8.layer.0.SelfAttention.o.weight False
encoder.block.8.layer.0.layer_norm.weight False
encoder.block.8.layer.1.DenseReluDense.wi_0.weight False
encoder.block.8.layer.1.DenseReluDense.wi_1.weight False
encoder.block.8.layer.1.DenseReluDense.wo.weight False
encoder.block.8.layer.1.layer_norm.weight False
encoder.block.9.layer.0.SelfAttention.q.weight False
encoder.block.9.layer.0.SelfAttention.k.weight False
encoder.block.9.layer.0.SelfAttention.v.weight False
encoder.block.9.layer.0.SelfAttention.o.weight False
encoder.block.9.layer.0.layer_norm.weight False
encoder.block.9.layer.1.DenseReluDense.wi_0.weight False
encoder.block.9.layer.1.DenseReluDense.wi_1.weight False
encoder.block.9.layer.1.DenseReluDense.wo.weight False
encoder.block.9.layer.1.layer_norm.weight False
encoder.block.10.layer.0.SelfAttention.q.weight False
encoder.block.10.layer.0.SelfAttention.k.weight False
encoder.block.10.layer.0.SelfAttention.v.weight False
encoder.block.10.layer.0.SelfAttention.o.weight False
encoder.block.10.layer.0.layer_norm.weight False
encoder.block.10.layer.1.DenseReluDense.wi_0.weight False
encoder.block.10.layer.1.DenseReluDense.wi_1.weight False
encoder.block.10.layer.1.DenseReluDense.wo.weight False
encoder.block.10.layer.1.layer_norm.weight False
encoder.block.11.layer.0.SelfAttention.q.weight False
encoder.block.11.layer.0.SelfAttention.k.weight False
encoder.block.11.layer.0.SelfAttention.v.weight False
encoder.block.11.layer.0.SelfAttention.o.weight False
encoder.block.11.layer.0.layer_norm.weight False
encoder.block.11.layer.1.DenseReluDense.wi_0.weight False
encoder.block.11.layer.1.DenseReluDense.wi_1.weight False
encoder.block.11.layer.1.DenseReluDense.wo.weight False
encoder.block.11.layer.1.layer_norm.weight False
encoder.final_layer_norm.weight False
decoder.block.0.layer.0.SelfAttention.q.weight False
decoder.block.0.layer.0.SelfAttention.k.weight False
decoder.block.0.layer.0.SelfAttention.v.weight False
decoder.block.0.layer.0.SelfAttention.o.weight False
decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight False
decoder.block.0.layer.0.layer_norm.weight False
decoder.block.0.layer.1.EncDecAttention.q.weight False
decoder.block.0.layer.1.EncDecAttention.k.weight False
decoder.block.0.layer.1.EncDecAttention.v.weight False
decoder.block.0.layer.1.EncDecAttention.o.weight False
decoder.block.0.layer.1.layer_norm.weight False
decoder.block.0.layer.2.DenseReluDense.wi_0.weight False
decoder.block.0.layer.2.DenseReluDense.wi_1.weight False
decoder.block.0.layer.2.DenseReluDense.wo.weight False
decoder.block.0.layer.2.layer_norm.weight False
decoder.block.1.layer.0.SelfAttention.q.weight False
decoder.block.1.layer.0.SelfAttention.k.weight False
decoder.block.1.layer.0.SelfAttention.v.weight False
decoder.block.1.layer.0.SelfAttention.o.weight False
decoder.block.1.layer.0.layer_norm.weight False
decoder.block.1.layer.1.EncDecAttention.q.weight False
decoder.block.1.layer.1.EncDecAttention.k.weight False
decoder.block.1.layer.1.EncDecAttention.v.weight False
decoder.block.1.layer.1.EncDecAttention.o.weight False
decoder.block.1.layer.1.layer_norm.weight False
decoder.block.1.layer.2.DenseReluDense.wi_0.weight False
decoder.block.1.layer.2.DenseReluDense.wi_1.weight False
decoder.block.1.layer.2.DenseReluDense.wo.weight False
decoder.block.1.layer.2.layer_norm.weight False
decoder.block.2.layer.0.SelfAttention.q.weight False
decoder.block.2.layer.0.SelfAttention.k.weight False
decoder.block.2.layer.0.SelfAttention.v.weight False
decoder.block.2.layer.0.SelfAttention.o.weight False
decoder.block.2.layer.0.layer_norm.weight False
decoder.block.2.layer.1.EncDecAttention.q.weight False
decoder.block.2.layer.1.EncDecAttention.k.weight False
decoder.block.2.layer.1.EncDecAttention.v.weight False
decoder.block.2.layer.1.EncDecAttention.o.weight False
decoder.block.2.layer.1.layer_norm.weight False
decoder.block.2.layer.2.DenseReluDense.wi_0.weight False
decoder.block.2.layer.2.DenseReluDense.wi_1.weight False
decoder.block.2.layer.2.DenseReluDense.wo.weight False
decoder.block.2.layer.2.layer_norm.weight False
decoder.block.3.layer.0.SelfAttention.q.weight False
decoder.block.3.layer.0.SelfAttention.k.weight False
decoder.block.3.layer.0.SelfAttention.v.weight False
decoder.block.3.layer.0.SelfAttention.o.weight False
decoder.block.3.layer.0.layer_norm.weight False
decoder.block.3.layer.1.EncDecAttention.q.weight False
decoder.block.3.layer.1.EncDecAttention.k.weight False
decoder.block.3.layer.1.EncDecAttention.v.weight False
decoder.block.3.layer.1.EncDecAttention.o.weight False
decoder.block.3.layer.1.layer_norm.weight False
decoder.block.3.layer.2.DenseReluDense.wi_0.weight False
decoder.block.3.layer.2.DenseReluDense.wi_1.weight False
decoder.block.3.layer.2.DenseReluDense.wo.weight False
decoder.block.3.layer.2.layer_norm.weight False
decoder.block.4.layer.0.SelfAttention.q.weight False
decoder.block.4.layer.0.SelfAttention.k.weight False
decoder.block.4.layer.0.SelfAttention.v.weight False
decoder.block.4.layer.0.SelfAttention.o.weight False
decoder.block.4.layer.0.layer_norm.weight False
decoder.block.4.layer.1.EncDecAttention.q.weight False
decoder.block.4.layer.1.EncDecAttention.k.weight False
decoder.block.4.layer.1.EncDecAttention.v.weight False
decoder.block.4.layer.1.EncDecAttention.o.weight False
decoder.block.4.layer.1.layer_norm.weight False
decoder.block.4.layer.2.DenseReluDense.wi_0.weight False
decoder.block.4.layer.2.DenseReluDense.wi_1.weight False
decoder.block.4.layer.2.DenseReluDense.wo.weight False
decoder.block.4.layer.2.layer_norm.weight False
decoder.block.5.layer.0.SelfAttention.q.weight False
decoder.block.5.layer.0.SelfAttention.k.weight False
decoder.block.5.layer.0.SelfAttention.v.weight False
decoder.block.5.layer.0.SelfAttention.o.weight False
decoder.block.5.layer.0.layer_norm.weight False
decoder.block.5.layer.1.EncDecAttention.q.weight False
decoder.block.5.layer.1.EncDecAttention.k.weight False
decoder.block.5.layer.1.EncDecAttention.v.weight False
decoder.block.5.layer.1.EncDecAttention.o.weight False
decoder.block.5.layer.1.layer_norm.weight False
decoder.block.5.layer.2.DenseReluDense.wi_0.weight False
decoder.block.5.layer.2.DenseReluDense.wi_1.weight False
decoder.block.5.layer.2.DenseReluDense.wo.weight False
decoder.block.5.layer.2.layer_norm.weight False
decoder.block.6.layer.0.SelfAttention.q.weight False
decoder.block.6.layer.0.SelfAttention.k.weight False
decoder.block.6.layer.0.SelfAttention.v.weight False
decoder.block.6.layer.0.SelfAttention.o.weight False
decoder.block.6.layer.0.layer_norm.weight False
decoder.block.6.layer.1.EncDecAttention.q.weight False
decoder.block.6.layer.1.EncDecAttention.k.weight False
decoder.block.6.layer.1.EncDecAttention.v.weight False
decoder.block.6.layer.1.EncDecAttention.o.weight False
decoder.block.6.layer.1.layer_norm.weight False
decoder.block.6.layer.2.DenseReluDense.wi_0.weight False
decoder.block.6.layer.2.DenseReluDense.wi_1.weight False
decoder.block.6.layer.2.DenseReluDense.wo.weight False
decoder.block.6.layer.2.layer_norm.weight False
decoder.block.7.layer.0.SelfAttention.q.weight False
decoder.block.7.layer.0.SelfAttention.k.weight False
decoder.block.7.layer.0.SelfAttention.v.weight False
decoder.block.7.layer.0.SelfAttention.o.weight False
decoder.block.7.layer.0.layer_norm.weight False
decoder.block.7.layer.1.EncDecAttention.q.weight False
decoder.block.7.layer.1.EncDecAttention.k.weight False
decoder.block.7.layer.1.EncDecAttention.v.weight False
decoder.block.7.layer.1.EncDecAttention.o.weight False
decoder.block.7.layer.1.layer_norm.weight False
decoder.block.7.layer.2.DenseReluDense.wi_0.weight False
decoder.block.7.layer.2.DenseReluDense.wi_1.weight False
decoder.block.7.layer.2.DenseReluDense.wo.weight False
decoder.block.7.layer.2.layer_norm.weight False
decoder.block.8.layer.0.SelfAttention.q.weight False
decoder.block.8.layer.0.SelfAttention.k.weight False
decoder.block.8.layer.0.SelfAttention.v.weight False
decoder.block.8.layer.0.SelfAttention.o.weight False
decoder.block.8.layer.0.layer_norm.weight False
decoder.block.8.layer.1.EncDecAttention.q.weight False
decoder.block.8.layer.1.EncDecAttention.k.weight False
decoder.block.8.layer.1.EncDecAttention.v.weight False
decoder.block.8.layer.1.EncDecAttention.o.weight False
decoder.block.8.layer.1.layer_norm.weight False
decoder.block.8.layer.2.DenseReluDense.wi_0.weight False
decoder.block.8.layer.2.DenseReluDense.wi_1.weight False
decoder.block.8.layer.2.DenseReluDense.wo.weight False
decoder.block.8.layer.2.layer_norm.weight False
decoder.block.9.layer.0.SelfAttention.q.weight False
decoder.block.9.layer.0.SelfAttention.k.weight False
decoder.block.9.layer.0.SelfAttention.v.weight False
decoder.block.9.layer.0.SelfAttention.o.weight False
decoder.block.9.layer.0.layer_norm.weight False
decoder.block.9.layer.1.EncDecAttention.q.weight False
decoder.block.9.layer.1.EncDecAttention.k.weight False
decoder.block.9.layer.1.EncDecAttention.v.weight False
decoder.block.9.layer.1.EncDecAttention.o.weight False
decoder.block.9.layer.1.layer_norm.weight False
decoder.block.9.layer.2.DenseReluDense.wi_0.weight False
decoder.block.9.layer.2.DenseReluDense.wi_1.weight False
decoder.block.9.layer.2.DenseReluDense.wo.weight False
decoder.block.9.layer.2.layer_norm.weight False
decoder.block.10.layer.0.SelfAttention.q.weight False
decoder.block.10.layer.0.SelfAttention.k.weight False
decoder.block.10.layer.0.SelfAttention.v.weight False
decoder.block.10.layer.0.SelfAttention.o.weight False
decoder.block.10.layer.0.layer_norm.weight False
decoder.block.10.layer.1.EncDecAttention.q.weight False
decoder.block.10.layer.1.EncDecAttention.k.weight False
decoder.block.10.layer.1.EncDecAttention.v.weight False
decoder.block.10.layer.1.EncDecAttention.o.weight False
decoder.block.10.layer.1.layer_norm.weight False
decoder.block.10.layer.2.DenseReluDense.wi_0.weight False
decoder.block.10.layer.2.DenseReluDense.wi_1.weight False
decoder.block.10.layer.2.DenseReluDense.wo.weight False
decoder.block.10.layer.2.layer_norm.weight False
decoder.block.11.layer.0.SelfAttention.q.weight False
decoder.block.11.layer.0.SelfAttention.k.weight False
decoder.block.11.layer.0.SelfAttention.v.weight False
decoder.block.11.layer.0.SelfAttention.o.weight False
decoder.block.11.layer.0.layer_norm.weight False
decoder.block.11.layer.1.EncDecAttention.q.weight False
decoder.block.11.layer.1.EncDecAttention.k.weight False
decoder.block.11.layer.1.EncDecAttention.v.weight False
decoder.block.11.layer.1.EncDecAttention.o.weight False
decoder.block.11.layer.1.layer_norm.weight False
decoder.block.11.layer.2.DenseReluDense.wi_0.weight False
decoder.block.11.layer.2.DenseReluDense.wi_1.weight False
decoder.block.11.layer.2.DenseReluDense.wo.weight False
decoder.block.11.layer.2.layer_norm.weight False
decoder.final_layer_norm.weight False
lm_head.weight False
* For example, this does NOT work: assigning to requires_grad on a module only sets a plain Python attribute and never reaches the module's parameters.

for i in range(-num_encoder_layers//2, 0):
    model.encoder.block[i].requires_grad = True
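
If you do want to toggle whole blocks without looping over their parameters, the in-place Module.requires_grad_() method does recurse into the parameters, so a working version of the same idea would look like this (a sketch, reusing model and num_encoder_layers from the snippets above):

# requires_grad_() sets requires_grad on every parameter of the module,
# unlike assigning to the plain requires_grad attribute
for i in range(-num_encoder_layers // 2, 0):
    model.encoder.block[i].requires_grad_(True)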

Upvotes: 3

Hamzah Al-Qadasi

Reputation: 9826

An optimization of the first answer above is to freeze only the first 15 blocks [0-14], because parameters of a freshly loaded model are trainable by default (param.requires_grad = True), so blocks [15-18] and the classifier need no extra code.

Therefore, we only need:

MobileNet = torchvision.models.mobilenet_v2(pretrained = True)

for param in MobileNet.features[:15].parameters():  # blocks 0-14
    param.requires_grad = False
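
A quick optional check that only the intended slice was frozen (a verification sketch, not part of the original answer):

# Count trainable vs. total parameters after freezing features[:15]
trainable = sum(p.numel() for p in MobileNet.parameters() if p.requires_grad)
total = sum(p.numel() for p in MobileNet.parameters())
print(f"trainable parameters: {trainable} / {total}")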

Upvotes: 0

joba2ca

Reputation: 366

If you want to define some layers by name and then unfreeze them, I propose a variant of @JVGD's answer:

class RetinaNet(torch.nn.Module):
    def __init__(self, ...):
        self.backbone = ResNet(...)
        self.fpn = FPN(...)
        self.box_head = torch.nn.Sequential(...)
        self.cls_head = torch.nn.Sequential(...)

# Getting the model
retinanet = RetinaNet(...)

# The param name is f'{module_name}.weight' or f'{module_name}.bias'.
# Some layers, e.g., batch norm, have additional params.
# In some circumstances, e.g., when using DataParallel(), 
# the param name is prefixed by 'module.'.
params_to_train = ['cls_head.weight', 'cls_head.bias']
for name, param in retinanet.named_parameters():
    # Set True only for params in the list 'params_to_train'
    param.requires_grad = True if name in params_to_train else False
...

The advantage is that you can define all layers to unfreeze in a single iterable.
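
Applied to the MobileNetV2 from the question, the same name-based filtering might look like this (a sketch; the prefixes are read off the module names in the printout above, and which parts stay trainable is my assumption):

# Train only feature blocks 15-18 and the classifier; freeze everything else
prefixes_to_train = tuple(f'features.{i}.' for i in range(15, 19)) + ('classifier.',)
for name, param in MobileNet.named_parameters():
    param.requires_grad = name.startswith(prefixes_to_train)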

Upvotes: 1

JVGD

Reputation: 737

Just adding this here for completeness. You can also freeze parameters in place, without iterating over them, using Module.requires_grad_().

For example, say you have a RetinaNet and want to fine-tune only the heads:

class RetinaNet(torch.nn.Module):
    def __init__(self, ...):
        self.backbone = ResNet(...)
        self.fpn = FPN(...)
        self.box_head = torch.nn.Sequential(...)
        self.cls_head = torch.nn.Sequential(...)

Then you could freeze the backbone and FPN like this:

# Getting the model
retinanet = RetinaNet(...)

# Freezing backbone and FPN
retinanet.backbone.requires_grad_(False)
retinanet.fpn.requires_grad_(False)
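
Translated to the question's MobileNetV2, the same in-place approach could freeze blocks 0-14 in one call (a sketch; slicing an nn.Sequential returns a new Sequential wrapping the same underlying modules, so this affects the original parameters):

# Freeze feature blocks 0-14; blocks 15-18 and the classifier stay trainable
MobileNet.features[:15].requires_grad_(False)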

Upvotes: 27
