Reputation: 2046
I'd like to partition a neural network into two sub-networks using PyTorch. To make things concrete, consider this image:
In it, I have a 3x4x1 neural network (3 inputs, 4 hidden units, 1 output). During epoch 1, for example, I'd like to update only the weights in sub-network 1, i.e., the weights that appear in sub-network 2 must be frozen. Then, in epoch 2, I'd like to train the weights that appear in sub-network 2 while the rest stay frozen.
How can I do that?
Upvotes: 2
Views: 2527
Reputation: 2748
You can do this easily if your subnet is a subset of layers. That is, you do not need to freeze any partial layers. It is all or nothing.
For your example that would mean dividing the hidden layer into two different 2-node layers. Each would belong to exactly one of the subnetworks, which gets us back to all or nothing.
With that done, you can toggle individual layers using requires_grad. Setting this to False on the parameters will disable training and freeze the weights. To do this for an entire model, sub-model, or Module, you loop through the model.parameters().
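As a minimal sketch of that toggle on a single layer (the layer sizes and the random input here are illustrative assumptions, not part of the question):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)  # hypothetical stand-in for one half of the hidden layer

# Freeze: autograd no longer tracks gradients for these parameters.
for param in layer.parameters():
    param.requires_grad = False
frozen_flags = [p.requires_grad for p in layer.parameters()]

# Unfreeze: gradients are computed again on the next backward pass.
for param in layer.parameters():
    param.requires_grad = True
layer(torch.randn(5, 3)).sum().backward()
has_grads = [p.grad is not None for p in layer.parameters()]
```

The same effect should be available in one call through nn.Module.requires_grad_(False) on the layer, which sets the flag on all of its parameters.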
For your example, with 3 inputs, 1 output, and a now split 2x2 hidden layer, it might look something like this:
import torch
import torch.nn as nn
import torch.nn.functional as F


def set_grad(model, grad):
    # Enable (True) or disable (False) gradient tracking for every parameter.
    for param in model.parameters():
        param.requires_grad = grad


class HalfFrozenModel(nn.Module):
    def __init__(self):
        super().__init__()
        # The 4-node hidden layer, split into two independent 2-node halves.
        self.hid1 = nn.Linear(3, 2)
        self.hid2 = nn.Linear(3, 2)
        self.out = nn.Linear(4, 1)

    def set_freeze(self, hid1=False, hid2=False):
        # True freezes that half; a half left at False is (re-)enabled for training.
        set_grad(self.hid1, not hid1)
        set_grad(self.hid2, not hid2)

    def forward(self, inp):
        hid1 = self.hid1(inp)
        hid2 = self.hid2(inp)
        # Concatenate both halves along the feature dimension: (batch, 4).
        hidden = torch.cat([hid1, hid2], 1)
        return self.out(F.relu(hidden))
Then you can train one half or the other like so:
model = HalfFrozenModel()
model.set_freeze(hid1=True)
# Do some training.
model.set_freeze(hid2=True)
# Do some more training.
# ...
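To make the "do some training" steps concrete, here is a rough sketch of alternating the frozen half per epoch. The random data, the SGD optimizer, the learning rate, and the one-step "epoch" are all assumptions for illustration. The model is repeated so the snippet runs on its own, and it uses requires_grad_ in place of the set_grad helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HalfFrozenModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.hid1 = nn.Linear(3, 2)
        self.hid2 = nn.Linear(3, 2)
        self.out = nn.Linear(4, 1)

    def set_freeze(self, hid1=False, hid2=False):
        # requires_grad_ sets the flag on every parameter of the sub-module.
        self.hid1.requires_grad_(not hid1)
        self.hid2.requires_grad_(not hid2)

    def forward(self, inp):
        hidden = torch.cat([self.hid1(inp), self.hid2(inp)], dim=1)
        return self.out(F.relu(hidden))


torch.manual_seed(0)
model = HalfFrozenModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed optimizer
x, y = torch.randn(16, 3), torch.randn(16, 1)            # made-up data

changed = []  # per epoch: (hid1 weights changed?, hid2 weights changed?)
for epoch in range(2):
    # Even epochs train hid1 (hid2 frozen); odd epochs train hid2 (hid1 frozen).
    train_first = epoch % 2 == 0
    model.set_freeze(hid1=not train_first, hid2=train_first)

    before1 = model.hid1.weight.detach().clone()
    before2 = model.hid2.weight.detach().clone()

    # Clearing gradients to None means the frozen half has no gradient at all,
    # so the optimizer skips it instead of stepping on a stale or zero gradient.
    optimizer.zero_grad(set_to_none=True)
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    changed.append((
        not torch.equal(before1, model.hid1.weight),
        not torch.equal(before2, model.hid2.weight),
    ))
```

Note that the output layer is never toggled here, so it trains in both phases. If the frozen half kept a zero (rather than None) gradient, an optimizer with weight decay or momentum could still move its weights, which is why the sketch clears gradients to None.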
If you happen to use fastai, then there is a concept of layer groups that is also used for this. The fastai documentation goes into some detail about how that works.
Upvotes: 5