Christopher Mills

Reputation: 760

forward() not overridden in implementation of nn.Module in an example

In this example, we see the following subclass of nn.Module:

import torch
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

However, the PyTorch docs state that forward(*input) "Should be overridden by all subclasses."

Why is that not the case in this example?

Upvotes: 0

Views: 1210

Answers (1)

Ivan

Reputation: 40628

This Net module appears to be meant for use through two separate interfaces, encode and decode. Since it doesn't implement forward, then yes, it is improperly subclassing nn.Module. However, the code is still "valid" and will run properly as long as you only call encode and decode, but it may have side effects if you are using forward hooks.
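A minimal sketch of that point, assuming plain nn.Linear layers stand in for GCNConv (so torch_geometric is not needed and edge_index is dropped); TinyNet is a hypothetical name. Calling encode and decode directly works, while calling the object itself goes through nn.Module.__call__, which needs a forward and raises NotImplementedError without one.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for Net: nn.Linear replaces GCNConv."""
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.lin1 = nn.Linear(in_features, hidden)
        self.lin2 = nn.Linear(hidden, out_features)

    def encode(self, x):
        # Submodules are still invoked via their own __call__.
        return self.lin2(self.lin1(x).relu())

    def decode(self, z, edge_label_index):
        # Dot product between the embeddings of each node pair.
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

torch.manual_seed(0)
net = TinyNet(8, 16, 4)
x = torch.randn(5, 8)                          # 5 nodes, 8 features each
pairs = torch.tensor([[0, 1, 2], [3, 4, 0]])   # 3 candidate node pairs

z = net.encode(x)                # direct method calls work fine
scores = net.decode(z, pairs)
print(z.shape, scores.shape)     # torch.Size([5, 4]) torch.Size([3])

raised = False
try:
    net(x)                       # __call__ dispatches to the unimplemented forward
except NotImplementedError as err:
    raised = True
    print("NotImplementedError:", err)
```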

The standard way of performing inference with an nn.Module is to call the object itself, i.e. to invoke its __call__ method. This __call__ method is implemented by the parent class nn.Module and in turn does two things:

  • run any registered forward pre-hooks before, and forward hooks after, the inference call
  • call the forward method of the class.
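The ordering described in the list can be observed with a toy module; Doubler is a hypothetical example. register_forward_pre_hook and register_forward_hook are standard nn.Module methods; each hook here only records that it ran and returns None, so inputs and outputs are left unchanged.

```python
import torch
from torch import nn

calls = []

class Doubler(nn.Module):
    def forward(self, x):
        calls.append("forward")
        return x * 2

m = Doubler()
# Pre-hook signature: hook(module, args), runs before forward.
m.register_forward_pre_hook(lambda mod, args: calls.append("pre-hook"))
# Forward hook signature: hook(module, args, output), runs after forward.
m.register_forward_hook(lambda mod, args, out: calls.append("hook"))

m(torch.ones(3))                 # via __call__: hooks wrap forward
first_calls = list(calls)
print(first_calls)               # ['pre-hook', 'forward', 'hook']

calls.clear()
out = m.forward(torch.ones(3))   # calling forward directly skips the hooks
print(calls)                     # ['forward']
```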

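The __call__ method acts as a wrapper around forward, which is why forward is expected to be overridden by every user-defined nn.Module. The main caveat of violating this design pattern, as Net does by having callers use encode and decode directly, is that any hooks registered on the Net instance itself are effectively ignored; hooks on its submodules (conv1, conv2) still fire, because those are invoked through their own __call__. Also note that calling the model object directly, e.g. net(x, edge_index), raises NotImplementedError, since the default forward of nn.Module is not implemented.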
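One possible way to keep the encode/decode interface while still honoring the nn.Module contract is to add a forward that simply delegates to them. This is a sketch under assumptions: LinkPredictor is a hypothetical name, and a single nn.Linear stands in for the GCNConv layers. A forward hook on the model fires only when the call goes through __call__.

```python
import torch
from torch import nn

class LinkPredictor(nn.Module):
    """Hypothetical variant of Net: nn.Linear stands in for GCNConv."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.lin = nn.Linear(in_features, out_features)

    def encode(self, x):
        return self.lin(x)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def forward(self, x, edge_label_index):
        # Delegating lets __call__ (and its hooks) wrap the whole pipeline.
        return self.decode(self.encode(x), edge_label_index)

fired = []
model = LinkPredictor(8, 4)
model.register_forward_hook(lambda mod, args, out: fired.append(out.shape))

x = torch.randn(5, 8)
pairs = torch.tensor([[0, 1], [2, 3]])   # 2 candidate node pairs

model.decode(model.encode(x), pairs)     # bypasses __call__ on model
before = len(fired)
print(before)                            # 0

model(x, pairs)                          # goes through __call__
after = len(fired)
print(after)                             # 1
```

The encode and decode methods remain available for callers that need them separately; forward only adds a standard entry point that nn.Module machinery such as hooks can see.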

Upvotes: 2
