dkv

Reputation: 7312

Calling super's forward() method

What is the most appropriate way to call the forward() method of a parent Module? For example, if I subclass the nn.Linear module, I might do the following:

class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        y = super(Linear, self).forward(x)
        z = do_other_stuff(y)
        return z

However, the docs say not to call the forward() method directly:

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

which makes me think super(Linear, self).forward(x) could result in some unexpected errors. Is this true or am I misunderstanding inheritance?

Upvotes: 17

Views: 5711

Answers (2)

Szymon Maszke

Reputation: 24884

TLDR;

You can use super().forward(...) freely, even with hooks, and even with hooks registered via super().

Explanation

As stated in this answer, __call__ is there so that registered hooks (e.g. those added with register_forward_hook) are run.

If you inherit and want to reuse the base class's forward, e.g. like this:

import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        return super(Child, self).forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1))) # and it is 4 indeed
print(module.forward(torch.tensor(1))) # here it is 3 still

You are perfectly fine as long as you call the module itself (its __call__ method); calling forward directly won't run the hook (so you get 3, as above).

It is unlikely you would want to register a hook via super(), but let's consider such an example:

def increment_by_one(module, input, output):
    return output + 1


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # Register a hook via super(); it attaches to this same instance and adds 1 to the output
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here it is still 3

You are perfectly fine using super().forward(...), and even hooks will work correctly (that is the main point of using __call__ instead of forward).

BTW, calling super().__call__(...) from inside forward would recurse infinitely and end in a RecursionError.
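
To see why, here is a rough, simplified sketch of the dispatch logic; it is not PyTorch's real implementation (which handles many more hook types and edge cases), just enough to show how __call__ and forward relate:

class SimplifiedModule:
    '''A rough sketch of nn.Module's call/forward dispatch, not the real thing.'''

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args):
        # Dispatch to self.forward, which resolves to the *child's* forward
        # through normal attribute lookup...
        output = self.forward(*args)
        # ...then run the registered forward hooks, letting each one replace the output
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output

    def forward(self, *args):
        raise NotImplementedError

Because __call__ always goes through self.forward, super().forward(...) just runs the parent's forward body once, while super().__call__(...) from inside a child's forward re-enters that same child's forward and recurses.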

Upvotes: 8

prosti

Reputation: 46449

Here is the minimal M0 module in PyTorch, with nothing in it (no other modules). What the docs say about forward() is that you should not call it directly; instead, it is run automatically when you call the module instance, m0().

import torch
import torch.nn as nn

class M0(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self) -> None:
        print("empty module:forward")

# we create a module instance m0
m0 = M0()
m0()
# ??m0.__call__  # in IPython, inspect __call__: it is what calls forward()

Out:

empty module:forward

In case you would like to have submodules, you aggregate them:

import torch
import torch.nn as nn

class M1(nn.Module):
    '''
    Single linear layer
    '''
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(10, 100)

    def forward(self, x):
        print("M1:forward")
        x = self.l1(x)
        return x
        
# we create a module instance m1
m1 = M1()
print(m1)
inp = torch.randn(1,10)
r = m1(inp) # result
print(r.shape)

Out:

M1(
  (l1): Linear(in_features=10, out_features=100, bias=True)
)
M1:forward
torch.Size([1, 100])

Once you aggregate other modules, forward() is where you execute them: it takes the input and returns some output.

This module design was originally introduced in the Lua-based Torch library, and PyTorch adopted it.


which makes me think super(Linear, self).forward(x) could result in some unexpected errors

This is exactly why forward() should not be called directly: to avoid these unexpected errors. Instead, modules are callable, as we did in the example:

self.l1(x)
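
Applied back to the question's own example, a minimal sketch could look like this (torch.relu stands in for do_other_stuff, which is not defined in the question):

import torch
import torch.nn as nn

class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        # Reusing the parent's forward via super() is fine here;
        # hooks are handled when the module instance is called
        y = super().forward(x)
        z = torch.relu(y)  # stand-in for do_other_stuff(y)
        return z

layer = LinearWithOtherStuff(10, 100)
out = layer(torch.randn(1, 10))  # call the instance, not .forward()
print(out.shape)  # torch.Size([1, 100])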

Upvotes: 1
