John

Reputation: 835

Freezing layers in a neural network in PyTorch

I have a cascaded neural network in which the output of the first network becomes the input of the second network. The first neural network is pretrained, so I just initialise it with those pretrained weights. However, I want to freeze the first neural network so that during training only the weights of the second neural network are updated. How can I do that? My network looks like this:


### First network

import torch
import torch.nn as nn
from functools import reduce  # needed by LambdaReduce under Python 3

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))

def get_first_model(load_weights = True):
    pretrained_model_reloaded_th = nn.Sequential( # Sequential,
        nn.Conv2d(4,300,(19, 1)),
        nn.BatchNorm2d(300),
        nn.ReLU(),
        nn.MaxPool2d((3, 1),(3, 1)),
        nn.Conv2d(300,200,(11, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1),(4, 1)),
        nn.Conv2d(200,200,(7, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1),(4, 1)),
        Lambda(lambda x: x.view(x.size(0),-1)), # Reshape,
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(2000,1000)), # Linear,
        nn.BatchNorm1d(1000,1e-05,0.1,True),#BatchNorm1d,
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(1000,1000)), # Linear,
        nn.BatchNorm1d(1000,1e-05,0.1,True),#BatchNorm1d,
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(1000,164)), # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        sd = torch.load('pretrained_model.pth')
        pretrained_model_reloaded_th.load_state_dict(sd)
    return  pretrained_model_reloaded_th

### second network

def next_model_architecture():
    next_model = nn.Sequential(
        nn.Linear(164, 64),
        nn.ReLU(),
        nn.Linear(64, 1),
        nn.Sigmoid())

    return next_model

### joining two networks
def cascading_model(first_model,next_model):
    network = nn.Sequential(first_model, next_model)
    return network


first_model = get_first_model(load_weights = True)
next_model = next_model_architecture()
network = cascading_model(first_model,next_model)

If I do:


first_model = first_model.eval()

Will this freeze my first neural network and only update the weights of the second network during training?

Upvotes: 2

Views: 4875

Answers (2)

JVGD

Reputation: 737

You can also freeze parameters in place, without iterating over them, by using requires_grad_. In your case that would be:

# Freezing network Sequential at index 0
network[0].requires_grad_(False)
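
A quick sanity check (just an illustrative sketch, assuming network is the nn.Sequential built in the question) to confirm the call took effect:

# The first sub-network should now be excluded from autograd,
# while the second sub-network stays trainable.
assert all(not p.requires_grad for p in network[0].parameters())
assert all(p.requires_grad for p in network[1].parameters())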

Normally, in more complex networks, you would have different modules. In your case, for example, you could have built the network like this:

class Network(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.first_model = get_first_model(load_weights=True)
        self.next_model = next_model_architecture()

    def forward(self, x):
        x = self.first_model(x)
        x = self.next_model(x)
        return x

# Class instance
network = Network(...)

Then you could freeze just one of the sub-models like this:

# Freezing network submodule: first_model
network.first_model.requires_grad_(False)
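
If it helps, here is a hedged sketch of passing only the still-trainable parameters to an optimizer (the Adam optimizer and learning rate are illustrative choices, not part of the answer):

import torch

# Collect only the parameters that still require gradients,
# i.e. those of network.next_model after network.first_model was frozen.
trainable_params = [p for p in network.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-3)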

Upvotes: 1

ayandas

Reputation: 2288

Freezing any parameter is done by setting its .requires_grad attribute to False. Do so by iterating over all the parameters of the module that you want to freeze:

for p in first_model.parameters():
    p.requires_grad = False
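
As a side note on the .eval() idea from the question: eval() alone does not freeze anything; it only switches layers such as Dropout and BatchNorm to inference mode, so it is typically combined with the loop above. A small illustrative check that the freeze worked:

# first_model parameters no longer require gradients,
# while next_model parameters still do.
print(any(p.requires_grad for p in first_model.parameters()))  # False
print(all(p.requires_grad for p in next_model.parameters()))   # True

# Optionally keep the frozen network's BatchNorm/Dropout in inference mode.
first_model.eval()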

Upvotes: 3
