cerebrou


Accessing functions in the class modules of nn.Sequential

When building an nn.Sequential, I pass in a list of custom class modules (the layers of a neural network). When the nn.Sequential runs, it calls the forward function of each module. However, each of these class modules also has another function that I would like to call when the nn.Sequential runs. How can I access and run this function while nn.Sequential is running?

Answers (1)

Gil Pinsky

You can use a hook for that. Consider the following example using VGG16:

This is the network architecture:

[Image: VGG16 network architecture]

Say we want to monitor the input and output of layer (2) in the features Sequential, which is the second Conv2d layer, Conv2d(64, 64, kernel_size=3, padding=1). To do this, we register a forward hook named my_hook, which is called on every forward pass through that layer:

import torch
from torchvision.models import vgg16

def my_hook(module, input, output):
    # module: the layer the hook is registered on (here net.features[2])
    # input:  tuple of positional inputs passed to that layer's forward()
    # output: the tensor returned by that layer's forward()
    print("my_hook's output")
    print('input: ', input)
    print('output: ', output)

# Sample net:
net = vgg16()

# Register forward hook:
net.features[2].register_forward_hook(my_hook)

# Test:
img = torch.randn(1, 3, 512, 512)
out = net(img)  # Triggers my_hook, which prints the data you are looking for
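
Because a forward hook receives the module itself as its first argument, the hook can also call any other method that module defines, which is what the question asks for. Below is a minimal sketch, assuming a hypothetical custom layer MyLayer with an extra method my_function (neither name comes from PyTorch or from the question). The hook is registered on every MyLayer inside an nn.Sequential, and the handles returned by register_forward_hook are used to detach the hooks afterwards.

```python
import torch
import torch.nn as nn


class MyLayer(nn.Module):
    """Hypothetical custom layer with an extra method besides forward()."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.calls = 0  # counts how often my_function has run

    def forward(self, x):
        return self.linear(x)

    def my_function(self, output):
        # Hypothetical extra method: records the call and the output shape.
        self.calls += 1
        self.last_shape = tuple(output.shape)


def call_my_function(module, inputs, output):
    # Forward hooks receive (module, inputs, output); `module` is the layer
    # instance, so its own methods can be called from here.
    module.my_function(output)


model = nn.Sequential(MyLayer(4, 8), nn.ReLU(), MyLayer(8, 2))

# Register the hook only on layers that actually define my_function.
handles = [
    m.register_forward_hook(call_my_function)
    for m in model
    if isinstance(m, MyLayer)
]

out = model(torch.randn(3, 4))  # each MyLayer.my_function runs after its forward

for h in handles:
    h.remove()  # detach the hooks once they are no longer needed
```

Returning None from the hook leaves the layer's output unchanged. If the extra method has to run before forward instead of after it, register_forward_pre_hook can be used the same way; its hook receives (module, inputs) without the output.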
