Reputation: 7312
What is the most appropriate way to call the forward() method of a parent Module? For example, if I subclass the nn.Linear module, I might do the following:
class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        y = super(Linear, self).forward(x)
        z = do_other_stuff(y)
        return z
However, the docs say not to call the forward() method directly:
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
which makes me think super(Linear, self).forward(x) could result in some unexpected errors. Is this true, or am I misunderstanding inheritance?
Upvotes: 17
Views: 5711
Reputation: 24884
You can use super().forward(...) freely, even with hooks, and even with hooks registered on the super() instance.
As stated by this answer, __call__ is there so that the registered hooks (e.g. those added via register_forward_hook) will be run.
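For intuition, calling the instance is (very roughly) a wrapper around forward, as in the sketch below. This is a simplified assumption for illustration, not PyTorch's actual implementation (which also handles pre-hooks, backward hooks, and more); _forward_hooks is an internal detail:

def call_like(module, *args):
    # Rough sketch (assumption): what calling the module instance adds on top of forward
    output = module.forward(*args)               # the user-defined recipe
    for hook in module._forward_hooks.values():  # hooks added via register_forward_hook
        result = hook(module, args, output)
        if result is not None:                   # a hook may replace the output
            output = result
    return output

Calling module.forward(...) directly skips that loop, which is exactly the "silently ignores them" from the docs.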
If you inherit and want to reuse the base class's forward, e.g.:
import torch

class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1

class Child(Parent):
    def forward(self, tensor):
        return super(Child, self).forward(tensor) + 1

module = Child()

# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)

print(module(torch.tensor(1)))  # and it is 4 indeed
print(module.forward(torch.tensor(1)))  # here it is 3 still
You are perfectly fine if you go through the __call__ method (i.e. call the module instance); calling forward directly won't run the hook (so you get 3, as above).
It is unlikely you would want to register a forward hook via the super() instance, but let's consider such an example:
def increment_by_one(module, input, output):
    return output + 1

class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1

class Child(Parent):
    def forward(self, tensor):
        # Increment by `1` from Parent
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1

module = Child()

# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)

print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here it is 3 still
You are perfectly fine using super().forward(...), and even hooks will work correctly (and that is the main point of using __call__ instead of forward).
BTW, calling super().__call__(...) would cause infinite recursion and raise a RecursionError.
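To see why, here is a sketch (BadChild is a hypothetical name, reusing the Parent class from above): nn.Module.__call__ dispatches to self.forward, which is the child's overridden forward, so the two keep calling each other:

class BadChild(Parent):
    def forward(self, tensor):
        # super().__call__ runs nn.Module.__call__, which looks up self.forward,
        # i.e. this very method again -> infinite recursion -> RecursionError
        return super().__call__(tensor)

# BadChild()(torch.tensor(1))  # uncomment to see the RecursionError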
Upvotes: 8
Reputation: 46449
Here is a minimal M0 module in PyTorch, with nothing inside (no submodules).
What the docs say about forward() is that you should not call it directly; instead it runs automatically when you call the module instance, m0():
import torch
import torch.nn as nn

class M0(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self) -> None:
        print("empty module:forward")

# we create a module instance m0
m0 = M0()
m0()
# ??m0.__call__  # (IPython) inspect __call__, which runs forward() inside
Out:
empty module:forward
In case you would like to have submodules, you aggregate them:
import torch
import torch.nn as nn

class M1(nn.Module):
    '''
    Single linear layer
    '''
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(10, 100)

    def forward(self, x):
        print("M1:forward")
        x = self.l1(x)
        return x

# we create a module instance m1
m1 = M1()
print(m1)

inp = torch.randn(1, 10)
r = m1(inp)  # result
print(r.shape)
Out:
M1(
  (l1): Linear(in_features=10, out_features=100, bias=True)
)
M1:forward
torch.Size([1, 100])
Once you aggregate other modules, you execute them by calling the module instances, which in turn runs their forward(); forward() takes the input and returns some output.
This design was originally introduced in Torch (written in Lua), and PyTorch kept it.
which makes me think super(Linear, self).forward(x) could result in some unexpected errors
This is exactly why forward() is not called directly: calling the module instance instead avoids these unexpected errors (silently skipped hooks). Modules are callable, like we did in the example: self.l1(x).
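Applied back to the question, the subclass can reuse nn.Linear's forward through super() while its users keep calling the module instance, so hooks still run. A sketch, where do_other_stuff is just a stand-in for the placeholder from the question:

import torch
import torch.nn as nn

def do_other_stuff(y):
    # stand-in for the question's placeholder
    return y * 2

class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        y = super().forward(x)  # reuse nn.Linear's forward
        return do_other_stuff(y)

layer = LinearWithOtherStuff(10, 100)
out = layer(torch.randn(1, 10))  # call the instance (not .forward()) so hooks run
print(out.shape)  # torch.Size([1, 100])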
Upvotes: 1