I want to add weight normalization to a pre-trained VGG-16 from PyTorch (torchvision). One possible solution I can think of is the following:
import torch.nn as nn
from torch.nn.utils import weight_norm as wn
import torchvision.models as models

class ResnetEncoder(nn.Module):
    def __init__(self):
        super(ResnetEncoder, self).__init__()
        ...
        self.encoder = models.vgg16(pretrained=True).features
        ...

    def forward(self, input_image):
        self.features = []
        x = (input_image - self.mean) / self.std
        self.features.append(self.encoder(x))
        ...
        return self.features

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encoder = ResnetEncoder()  # this is basically VGG16
        self.decoder = DepthDecoder(self.encoder.num_ch_enc)
        # try to wrap every Conv2d of the VGG feature extractor with weight normalization
        for k, m in self.encoder.encoder._modules.items():
            if isinstance(m, nn.Conv2d):
                m = wn(m)

    def forward(self, x):
        return self.decoder(self.encoder(x))

vgg_backbone_model = Net()
vgg_backbone_model.train()
...
However, I am not sure whether this is the correct way to add weight normalization to a pre-trained VGG16.
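You should be using nn.Module.modules() (or nn.Module.named_modules() if you also need the names) instead of accessing the private _modules attribute. Unlike _modules, which only holds a module's direct children, these methods walk the whole tree of nested submodules.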
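The difference is easy to see on a small nested model. The sketch below uses a hypothetical two-level nn.Sequential (not the VGG from the question) and compares the direct children stored in _modules with what modules() returns:

```python
import torch.nn as nn

# Hypothetical nested model: the inner Sequential plays the role of a sub-encoder.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
    nn.Conv2d(8, 8, 3),
)

# _modules only holds the direct children of `model`.
direct = [type(m).__name__ for m in model._modules.values()]

# modules() recurses through the whole tree, starting with `model` itself.
everything = [type(m).__name__ for m in model.modules()]
convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

print(direct)      # ['Sequential', 'Conv2d']
print(everything)  # ['Sequential', 'Sequential', 'Conv2d', 'ReLU', 'Conv2d']
print(len(convs))  # 2
```

Iterating over _modules here would only find one of the two Conv2d layers, while modules() finds both.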
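Also note that m = wn(m) only rebinds the local variable m; it never writes anything back into the parent module. It happens to work with torch.nn.utils.weight_norm, because that function modifies the layer in place and returns the same object, but the result of any function that returns a new module would be silently discarded. To be safe, assign the result back on the layer's parent module, for example with setattr: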
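import torch.nn as nn
from torch.nn.utils import weight_norm
for k, v in list(model.named_modules()):  # list(): don't modify the tree while iterating over it
    if isinstance(v, nn.Conv2d):
        parent_name, _, child_name = k.rpartition('.')  # "encoder.encoder.0" -> "encoder.encoder", "0"
        setattr(model.get_submodule(parent_name), child_name, weight_norm(v))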
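To check that the conversion keeps the pre-trained behaviour, you can compare the outputs before and after applying weight normalization. This is a minimal sketch under stated assumptions: a small nn.Sequential with random weights stands in for models.vgg16(pretrained=True).features so it runs without a download, and the extra outer nn.Sequential only mimics the Net -> encoder nesting. Because weight_norm initialises g to the norm of the existing weight and v to the weight itself, w = g * v / ||v|| reproduces the original weight, so the outputs should match up to floating-point error.

```python
import copy

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

torch.manual_seed(0)

# Stand-in for models.vgg16(pretrained=True).features (random "pre-trained" weights).
features = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
)
model = nn.Sequential(features)  # extra nesting level, like Net -> encoder -> features

# Untouched copy to compare against.
reference = copy.deepcopy(model).eval()

# Replace every Conv2d on its parent module, as in the answer above.
for name, m in list(model.named_modules()):
    if isinstance(m, nn.Conv2d):
        parent_name, _, child_name = name.rpartition('.')
        setattr(model.get_submodule(parent_name), child_name, weight_norm(m))

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    same = torch.allclose(model.eval()(x), reference(x), atol=1e-5)

# The original `weight` parameter is replaced by `weight_g` (magnitude) and `weight_v` (direction).
conv = model[0][0]
param_names = sorted(n for n, _ in conv.named_parameters())
print(param_names)  # ['bias', 'weight_g', 'weight_v']
print(same)         # True
```

Newer PyTorch releases deprecate torch.nn.utils.weight_norm in favour of torch.nn.utils.parametrizations.weight_norm, which stores the two parameters under different names, so the printed parameter names would differ there.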