Reputation: 760
In this example, we see the following implementation of nn.Module:
import torch
from torch_geometric.nn import GCNConv


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        # two GCN layers produce the node embeddings z
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        # score each candidate edge by the dot product of its endpoint embeddings
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        # predict an edge for every node pair whose score is positive
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
However, the docs say that forward(*input) "Should be overridden by all subclasses."
Why is that not the case in this example?
Upvotes: 0
Views: 1210
Reputation: 40628
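This Net module seems to be meant for use through separate interfaces, encode, decode and decode_all, rather than by calling the module itself, at least it seems so. Since it has no forward implementation, then yes, it is improperly inheriting from nn.Module. However, the code is still "valid" and runs fine as long as you call those methods directly. Calling the module itself, as in net(x, edge_index), raises NotImplementedError because the default forward of nn.Module is not implemented. It may also have side effects if you are using forward hooks.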
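A minimal sketch of that difference, assuming only plain PyTorch is installed. TinyNet is a hypothetical stand-in for Net: torch.nn.Linear replaces GCNConv so the snippet needs no torch_geometric, and, like Net, it defines encode but no forward.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for Net: nn.Linear replaces GCNConv and there is no forward."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def encode(self, x):
        return self.lin(x).relu()


net = TinyNet(4, 2)
x = torch.randn(3, 4)

z = net.encode(x)  # works: encode is an ordinary method call
print(z.shape)     # torch.Size([3, 2])

try:
    net(x)  # nn.Module.__call__ dispatches to forward, which TinyNet never defined
    forward_failed = False
except NotImplementedError as err:
    forward_failed = True
    print("calling net(x) failed:", err)
```

The exact error message depends on the PyTorch version, but the exception type is NotImplementedError.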
The standard way of running inference with an nn.Module is to call the object, i.e. to invoke its __call__ method. This __call__ method is implemented by the parent class nn.Module and in turn does two things:
1. it runs any hooks registered on the module (pre-hooks before forward, forward hooks after it);
2. it calls the forward function of the class.
In other words, __call__ acts as a wrapper around forward.
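To make the wrapper idea concrete, here is a greatly simplified, pure-Python model of that dispatch. ToyModule, Doubler and triple are hypothetical names; the real nn.Module.__call__ also handles pre-hooks, backward hooks, global hooks and more, so treat this only as an illustration of "hooks around forward", not as PyTorch's implementation.

```python
class ToyModule:
    """Hypothetical, simplified model of how nn.Module.__call__ wraps forward."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, *args):
        # mirrors nn.Module: subclasses are expected to override this
        raise NotImplementedError(f"{type(self).__name__} is missing forward")

    def __call__(self, *args):
        result = self.forward(*args)          # call the user-defined forward
        for hook in self._forward_hooks:      # then run hooks on (module, inputs, output)
            hook_result = hook(self, args, result)
            if hook_result is not None:       # a hook may replace the output
                result = hook_result
        return result


class Doubler(ToyModule):
    def forward(self, x):
        return 2 * x

    def triple(self, x):  # an extra method, like Net.encode
        return 3 * x


calls = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: calls.append(out))

m(5)          # goes through __call__, so the hook fires
m.triple(5)   # plain method call, bypasses __call__, so the hook does not fire
print(calls)  # [10]
```

Only the call that goes through __call__ is seen by the hook; the direct method call is invisible to it, which is exactly what happens with Net.encode.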
This is why the user-defined nn.Module is expected to override forward. The only caveat of violating this design pattern is that any hooks registered on the module itself (for example with register_forward_hook) are effectively ignored, because encode and decode are called directly and never go through __call__.
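A sketch of one way to follow the pattern while keeping the separate interfaces: define a forward that simply chains encode and decode. LinkNet is a hypothetical variant of Net in which torch.nn.Linear stands in for GCNConv (so edge_index is dropped and no torch_geometric is needed); decode is copied from the question.

```python
import torch
from torch import nn


class LinkNet(nn.Module):
    """Hypothetical Net variant: nn.Linear stands in for GCNConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def encode(self, x):
        return self.lin(x).relu()

    def decode(self, z, edge_label_index):
        # dot product between the embeddings of each edge's two endpoints
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def forward(self, x, edge_label_index):
        # forward just chains the two existing interfaces
        return self.decode(self.encode(x), edge_label_index)


model = LinkNet(4, 8)
fired = []
model.register_forward_hook(lambda module, inputs, output: fired.append(output.shape))

x = torch.randn(5, 4)
edges = torch.tensor([[0, 1, 2], [3, 4, 0]])

model.decode(model.encode(x), edges)  # direct method calls: the hook on model does not run
print(len(fired))                     # 0
scores = model(x, edges)              # through __call__: the hook runs
print(len(fired), scores.shape)       # 1 torch.Size([3])
```

Note that hooks registered on submodules (here model.lin, in Net conv1 and conv2) still fire even when encode is called directly, because those layers are themselves invoked through their own __call__.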
Upvotes: 2