Reputation: 23
I want to override nn.Conv2d so that prebuilt models such as resnet, alexnet, etc. use my own layer instead, without my having to replace every nn.Conv2d in the model manually.
from torchvision import models
from torch import nn
class replace_conv2d(nn.Module):
    ...  # other code
nn.Conv2d = replace_conv2d  # what I want to do
model = models.resnet18()
so that resnet18 uses the replace_conv2d class instead of nn.Conv2d.
Upvotes: 1
Views: 190
Reputation: 40648
I am not sure you can override the modules when they are loaded in. What you can do, though, is take the nn.Module and run a function over its module tree that replaces every nn.Conv2d with another layer implementation (here, for example, nn.Identity).
The only trick is that nested child layers are identified by compound keys. For example, model.layer1[0].conv2 has the keys "layer1", "0", and finally "conv2".
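A small illustration, assuming only torch is installed: the toy `TinyNet` and `Block` classes below are hypothetical and merely mimic ResNet's `layer1[0].conv2` nesting. `named_modules()` reports nested layers by dotted names, and each segment, including the numeric index inside an `nn.Sequential`, resolves with `getattr`.

```python
import torch
from torch import nn


class Block(nn.Module):
    """Hypothetical residual-style block with two convolutions."""

    def __init__(self, c):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)

    def forward(self, x):
        return x + self.conv2(torch.relu(self.conv1(x)))


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet18: a stem conv plus one stage of blocks."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.layer1 = nn.Sequential(Block(8), Block(8))

    def forward(self, x):
        return self.layer1(self.conv1(x))


model = TinyNet()
# named_modules() walks the tree; the root itself is reported with the name ''
names = [name for name, _ in model.named_modules()]
print(names)

# "layer1.0.conv2" splits into "layer1", "0", "conv2"; getattr accepts "0"
# because nn.Module looks registered children up by their string name
m = model
for part in "layer1.0.conv2".split("."):
    m = getattr(m, part)
print(m is model.layer1[0].conv2)
```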
Gather the nn.Conv2d layers and split their compound keys:
convs = []
for k, v in model.named_modules():
    if isinstance(v, nn.Conv2d):
        convs.append(k.split('.'))
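One design point, offered as a hedged aside: `isinstance` also matches subclasses of `nn.Conv2d`. If your replacement class itself subclasses `nn.Conv2d` (`MyConv2d` below is a hypothetical example), an exact type check keeps a second pass from collecting layers that were already swapped.

```python
from torch import nn


class MyConv2d(nn.Conv2d):
    """Hypothetical replacement layer that subclasses nn.Conv2d."""


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Sequential(MyConv2d(8, 8, 3), nn.Conv2d(8, 8, 3)),
)

# isinstance() matches both plain and subclassed convolutions
by_isinstance = [k.split('.') for k, v in model.named_modules()
                 if isinstance(v, nn.Conv2d)]
# an exact type check matches only the original nn.Conv2d layers
by_type = [k.split('.') for k, v in model.named_modules()
           if type(v) is nn.Conv2d]
print(by_isinstance)
print(by_type)
```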
Build a recursive function that follows every part of a compound key except the last, returning the parent module that owns the layer:
def inspect(m, k):
    # descend one key at a time; stop at the parent of the last key
    return inspect(getattr(m, k[0]), k[1:]) if len(k) > 1 else m
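If your PyTorch release provides `nn.Module.get_submodule` (it exists in recent versions), it can stand in for the recursive helper: joining every key but the last back into a dotted path yields the same parent, and an empty path returns the model itself. A sketch on a toy model; `parent_of` is a hypothetical name.

```python
from torch import nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.ReLU(), nn.Conv2d(8, 8, 3)),
)


def parent_of(model, key):
    """Hypothetical helper: return the module that owns the last part of key."""
    # get_submodule('') returns the model itself, which covers top-level layers
    return model.get_submodule('.'.join(key[:-1]))


nested = parent_of(model, ['1', '1'])  # parent of the inner conv
top = parent_of(model, ['0'])          # parent of the first conv
print(nested is model[1], top is model)
```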
Finally, iterate over the collected keys and replace each layer on its parent:
for k in convs:
    setattr(inspect(model, k), k[-1], nn.Identity())
You will see that all nn.Conv2d layers, whatever their depth, have been replaced:
>>> model.layer1[0].conv2
Identity()
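An alternative sketch, using a different technique from the collect-then-replace loop above: recurse through `named_children()` and swap matching children in place. `replace_layers` and its `factory` argument are hypothetical names; `factory` receives the old layer, so it can build the replacement from it.

```python
from torch import nn


def replace_layers(module, cls, factory):
    """Recursively replace every child of type `cls` with factory(child)."""
    # snapshot the children so the tree is not mutated while iterating it
    for name, child in list(module.named_children()):
        if isinstance(child, cls):
            setattr(module, name, factory(child))
        else:
            # descend only into children that were not replaced
            replace_layers(child, cls, factory)


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8)),
)
replace_layers(model, nn.Conv2d, lambda conv: nn.Identity())
print(model)
```

Passing the old layer to `factory` is what lets the same function rebuild layers from their hyperparameters instead of discarding them.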
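If, instead of nn.Identity, you want a replacement built from the parameters of the conv layer you are about to replace, read them from its attributes. This rebuilds each layer as an nn.Conv2d with the same settings; put your own class, such as replace_conv2d, in its place: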
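keys = 'in_channels', 'out_channels', 'kernel_size', \
       'stride', 'padding', 'dilation', 'groups', 'padding_mode'
for k in convs:
    parent = inspect(model, k)
    conv = getattr(parent, k[-1])
    kwargs = {a: getattr(conv, a) for a in keys}
    # conv.bias is a tensor or None, but the constructor expects a bool
    kwargs['bias'] = conv.bias is not None
    # note: the rebuilt layer starts with freshly initialized weights
    setattr(parent, k[-1], nn.Conv2d(**kwargs))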
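A layer rebuilt only from its hyperparameters starts with new random weights, so pretrained weights (e.g. from a torchvision model) are lost. A hedged sketch of one way to keep them, assuming the replacement subclasses `nn.Conv2d` and therefore has the same `weight`/`bias` parameters: build it from the old layer's attributes, then copy the old state with `load_state_dict`. `MyConv2d` and `from_conv` are hypothetical names.

```python
import torch
from torch import nn


class MyConv2d(nn.Conv2d):
    """Hypothetical replacement; override forward() for custom behaviour."""

    @classmethod
    def from_conv(cls, conv):
        # same hyperparameters as the original layer
        new = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                  stride=conv.stride, padding=conv.padding,
                  dilation=conv.dilation, groups=conv.groups,
                  bias=conv.bias is not None, padding_mode=conv.padding_mode)
        # copy weight (and bias, if present) so pretrained values survive
        new.load_state_dict(conv.state_dict())
        return new


old = nn.Conv2d(3, 8, 3, padding=1, bias=True)
new = MyConv2d.from_conv(old)
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(old(x), new(x)))
```

This sketch does not move the new layer to the old layer's device or dtype; call `.to(...)` on it if the model lives on a GPU.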
Upvotes: 2