Nagabhushan S N

Reputation: 7277

How to concatenate 2 pytorch models and make the first one non-trainable in PyTorch

I have two networks which I need to concatenate into my full model. However, the first model is pre-trained and I need to make it non-trainable when training the full model. How can I achieve this in PyTorch?

I am able to concatenate two models using this answer:

import torch
import torch.nn as nn


class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x
    

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(20, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        
    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2

# Create the models
modelA = MyModelA()
modelB = MyModelB()
# Load the pre-trained state dict for modelA
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)

Basically, I want to load the pre-trained modelA and make it non-trainable when training the MyEnsemble model.

Upvotes: 5

Views: 13702

Answers (2)

Kaushik Roy

Reputation: 1685

One easy way to do this is to detach the output tensor of the model that you don't want to update; then no gradient will be backpropagated through it to that model. In this case, you can simply detach the x2 tensor just before concatenating it with x1 in the forward function of the MyEnsemble model, which keeps the weights of modelB unchanged.

So the new forward function should look like the following:

def forward(self, x1, x2):
    x1 = self.modelA(x1)
    x2 = self.modelB(x2)
    x = torch.cat((x1, x2.detach()), dim=1)  # Detach x2 so modelB won't be updated
    x = self.classifier(F.relu(x))
    return x
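
If you instead want to keep the question's sequential MyEnsemble and freeze modelA (as asked), the same trick applies. Here is a minimal sketch, assuming the classes from the question, where modelA's output is detached so no gradient ever reaches its weights:

class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB

    def forward(self, x):
        x1 = self.modelA(x).detach()  # gradient flow stops here, so modelA stays frozen
        x2 = self.modelB(x1)
        return x2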

Upvotes: 8

Theodor Peifer

Reputation: 3506

You can freeze all parameters of the model you don't want to train by setting requires_grad to False, like this:

for param in model.parameters():
    param.requires_grad = False

This should work for you.
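
For example, applied to the setup from the question (a sketch, reusing the names and PATH placeholder from the question's code), freezing modelA and handing the optimizer only the remaining trainable parameters could look like this:

modelA = MyModelA()
modelA.load_state_dict(torch.load(PATH))  # load the pre-trained weights

# Freeze every parameter of the pre-trained model
for param in modelA.parameters():
    param.requires_grad = False

model = MyEnsemble(modelA, MyModelB())

# Pass only the parameters that still require gradients to the optimizer
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=0.001
)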

Another way is to handle this in your training loop:

modelA = MyModelA()
modelA.load_state_dict(torch.load(PATH))  # load the pre-trained weights
modelB = MyModelB()

criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)  # optimize modelB only

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()

        x = modelA(samples)        # forward pass through the frozen, pre-trained model
        predictions = modelB(x)    # feed modelA's output into modelB

        loss = criterionB(predictions, targets)
        loss.backward()
        optimizerB.step()          # updates only modelB's parameters

This way you pass the output of modelA to modelB, but you only optimize modelB.
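
Since modelA is never updated in this loop, you could also (as an optional refinement, not part of the answer above) run its forward pass under torch.no_grad() so that no computation graph is stored for it:

with torch.no_grad():
    x = modelA(samples)     # no graph is built for the frozen model
predictions = modelB(x)     # only this part participates in backprop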

Upvotes: 2
