I pass a list of custom module classes (the layers of a neural network) to nn.Sequential. When the nn.Sequential runs, it calls each module's forward function. However, each of these modules also has another function that I would like to run while the nn.Sequential runs. How can I access and call this function during the nn.Sequential pass?
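A minimal sketch of this setup follows. It assumes a hypothetical layer class ScaledLinear with an extra method extra_step(); neither name comes from the question. Calling the Sequential runs only forward(), so extra_step() is never invoked.

```python
import torch
import torch.nn as nn


class ScaledLinear(nn.Module):
    """Hypothetical layer: nn.Linear followed by a fixed scale factor."""

    def __init__(self, in_features, out_features, scale=2.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = scale
        self.extra_calls = 0  # how many times extra_step() has run

    def forward(self, x):
        return self.linear(x) * self.scale

    def extra_step(self, output):
        # Hypothetical extra function we would like to run on each pass.
        self.extra_calls += 1
        return output.abs().mean()


model = nn.Sequential(ScaledLinear(4, 8), nn.ReLU(), ScaledLinear(8, 2))
out = model(torch.randn(3, 4))

# nn.Sequential only calls forward(); extra_step() never runs.
print(out.shape)  # torch.Size([3, 2])
print([m.extra_calls for m in model if isinstance(m, ScaledLinear)])  # [0, 0]
```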
You can use a forward hook for that. Here is an example using the VGG16 model from torchvision.
You can inspect the network architecture with print(net). Say we want to monitor the input and output of layer (2) in the features Sequential, which in VGG16 is a Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) layer.
To do this, we register a forward hook named my_hook on that layer. PyTorch calls it every time the layer's forward pass finishes, passing the module, a tuple of its inputs, and its output:
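```python
import torch
from torchvision.models import vgg16


def my_hook(module, inputs, output):
    # Called right after module.forward(); inputs is a tuple of the positional args.
    print("my_hook's output")
    print('input: ', inputs)
    print('output: ', output)


# Sample net:
net = vgg16()

# Register the forward hook on layer (2) of the features Sequential:
handle = net.features[2].register_forward_hook(my_hook)

# Test:
img = torch.randn(1, 3, 512, 512)
out = net(img)  # Triggers my_hook, which prints the data you are looking for

# Remove the hook when you no longer need it:
handle.remove()
```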
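Applied to the question's setup, the same idea can call each layer's extra function automatically. This is a sketch assuming a hypothetical layer class ScaledLinear whose extra method is named extra_step(). The forward hook just calls that method with the layer's output, so it runs every time the Sequential passes through the layer.

```python
import torch
import torch.nn as nn


class ScaledLinear(nn.Module):
    """Hypothetical layer with an extra method besides forward()."""

    def __init__(self, in_features, out_features, scale=2.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = scale
        self.last_stat = None  # filled in by extra_step()

    def forward(self, x):
        return self.linear(x) * self.scale

    def extra_step(self, output):
        # Hypothetical extra function: record the mean absolute activation.
        self.last_stat = output.abs().mean().item()


def call_extra_step(module, inputs, output):
    # Forward hook signature: (module, inputs, output); inputs is a tuple.
    # Returning None leaves the module's output unchanged.
    module.extra_step(output)


model = nn.Sequential(ScaledLinear(4, 8), nn.ReLU(), ScaledLinear(8, 2))

# Register the hook on every child module that defines extra_step().
handles = [
    m.register_forward_hook(call_extra_step)
    for m in model
    if hasattr(m, "extra_step")
]

with torch.no_grad():
    out = model(torch.randn(3, 4))  # each hook runs right after its module's forward()

print([m.last_stat is not None for m in model if isinstance(m, ScaledLinear)])  # [True, True]

# register_forward_hook returns a handle; remove the hooks when they are no longer needed.
for h in handles:
    h.remove()
```

A forward hook runs after the layer's forward(). If the extra function has to run before it, register_forward_pre_hook accepts a hook with the signature (module, inputs) instead.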