Reputation: 7277
I have two networks which I need to concatenate for my full model. However, my first model is pre-trained and I need to make it non-trainable when training the full model. How can I achieve this in PyTorch?
I am able to concatenate two models using this answer:
import torch
import torch.nn as nn

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)

    def forward(self, x):
        x = self.fc1(x)
        return x

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(20, 2)

    def forward(self, x):
        x = self.fc1(x)
        return x

class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB

    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2

# Create models and load the pre-trained state_dict for modelA
modelA = MyModelA()
modelB = MyModelB()
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)
Basically, I want to load the pre-trained modelA and make it non-trainable when training the ensemble model.
Upvotes: 5
Views: 13702
Reputation: 1685
One easy way to do that is to detach the output tensor of the model that you don't want to update; gradients will then not backpropagate through it to the connected model. In your case, you can simply detach the x2 tensor just before concatenating it with x1 in the forward function of the MyEnsemble model, to keep the weights of modelB unchanged.
So, the new forward function should look like the following:
def forward(self, x1, x2):
    x1 = self.modelA(x1)
    x2 = self.modelB(x2)
    x = torch.cat((x1, x2.detach()), dim=1)  # detach x2, so modelB won't be updated
    x = self.classifier(F.relu(x))           # classifier: a head on top of the concatenated features
    return x
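For completeness, here is a minimal sketch of what the whole two-input ensemble could look like. The classifier head is not part of the question's code, so the nn.Linear(4, 2) used here (2 features from each sub-model, concatenated) is just an assumption for illustration. Also note that this detaches x2 and therefore freezes modelB; if it is modelA you want frozen, detach x1 instead:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA                # e.g. MyModelA(): 10 -> 2
        self.modelB = modelB                # e.g. MyModelB(): 20 -> 2
        self.classifier = nn.Linear(4, 2)   # hypothetical head over the 2 + 2 concatenated features

    def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2.detach()), dim=1)  # no gradient flows back into modelB
        x = self.classifier(F.relu(x))
        return x

model = MyEnsemble(MyModelA(), MyModelB())
output = model(torch.randn(1, 10), torch.randn(1, 20))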
Upvotes: 8
Reputation: 3506
You can freeze all parameters of the model you don't want to train by setting requires_grad to False.
Like this:
for param in model.parameters():
    param.requires_grad = False
This should work for you.
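Applied to the question's setup, a minimal sketch might look like this (the optimizer and learning rate are just placeholders):

# Load and freeze the pre-trained modelA
modelA = MyModelA()
modelA.load_state_dict(torch.load(PATH))
for param in modelA.parameters():
    param.requires_grad = False   # no gradients will be computed for modelA

model = MyEnsemble(modelA, MyModelB())

# Hand the optimizer only the parameters that are still trainable
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=0.001
)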
Another way is to handle this in your training loop:
modelA = MyModelA()
modelB = MyModelB()

criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()
        x = modelA(samples)                      # output of the pre-trained modelA
        predictions = modelB(x)                  # is fed into modelB
        loss = criterionB(predictions, targets)
        loss.backward()
        optimizerB.step()
So you pass the output of modelA to modelB, but you optimize only modelB.
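One thing to keep in mind: since modelA's parameters were not frozen here, backward() still computes gradients through modelA; they are just never applied. If you don't need them at all, you can wrap modelA's forward pass in torch.no_grad() to skip that bookkeeping, roughly like this:

for samples, targets in dataloader:
    optimizerB.zero_grad()
    with torch.no_grad():                    # no autograd graph is built for modelA
        x = modelA(samples)
    predictions = modelB(x)
    loss = criterionB(predictions, targets)
    loss.backward()                          # gradients only reach modelB's parameters
    optimizerB.step()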
Upvotes: 2