Kuo

Reputation: 412

PyTorch: can we use nn.Module layers directly in the forward() function?

Generally, in the constructor we declare all the layers we want to use. In the forward function, we define how the model runs, from input to output.

My question is: what if we call those predefined/built-in nn.Modules directly in the forward() function? Is this Keras Functional-API style legal in PyTorch? If not, why not?

Update: a TestModel constructed in this way did run successfully, without any warning. But the training loss descends slowly compared with the conventional way.

import torch.nn as nn
from cnn import CNN

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_embeddings = 2020
        self.embedding_dim = 51

    def forward(self, input):
        x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
        # CNN is a customized class and nn.Module subclassed
        # we will ignore the arguments for its instantiation 
        x = CNN(...)(x)
        x = nn.ReLU()(x)
        x = nn.Dropout(p=0.2)(x)
        return x
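
For comparison, the conventional way referenced in the update would look roughly like this (a minimal sketch with a hypothetical name ConventionalModel; the CNN arguments are elided just as in the snippet above):

import torch.nn as nn
from cnn import CNN

class ConventionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        # layers are created once here and registered as submodules
        self.embedding = nn.Embedding(2020, 51)
        self.cnn = CNN(...)  # arguments elided as in the original
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, input):
        x = self.embedding(input)
        x = self.cnn(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x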

Upvotes: 9

Views: 6167

Answers (2)

Szymon Maszke

Reputation: 24691

What you are trying to do kinda can be done, but it shouldn't be, as it's totally unnecessary in most cases. It's also not more readable IMO, and it definitely goes against PyTorch's way.

In your forward, the layers are reinitialized on every call and they are never registered in your network.
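
A minimal sketch (with hypothetical class names) illustrating the registration problem: a layer created inside forward() gets fresh random weights on every call and never shows up in parameters():

import torch.nn as nn

class InForward(nn.Module):
    def forward(self, x):
        # a brand-new Linear with random weights on every call, never registered
        return nn.Linear(4, 2)(x)

class InInit(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)  # registered as a submodule once

    def forward(self, x):
        return self.linear(x)

print(len(list(InForward().parameters())))  # 0 -> nothing for an optimizer to train
print(len(list(InInit().parameters())))     # 2 -> weight and bias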

To do it correctly you can use Module's add_module() function with a guard against reassignment (the dynamic method below):

import torch

class Example(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_embeddings = 2020
        self.embedding_dim = 51

    def dynamic(self, name: str, module_class, *args, **kwargs):
        if not hasattr(self, name):
            self.add_module(name, module_class(*args, **kwargs))
        return getattr(self, name)

    def forward(self, x):
        embedded = self.dynamic("embedding",
                     torch.nn.Embedding,
                     self.num_embeddings,
                     self.embedding_dim)(x)
        return embedded

You could structure it differently, but that's the idea behind it.
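
A quick usage sketch, assuming the Example class above. Note that the embedding only exists, and is only registered, after the first forward call, so an optimizer built before that call won't see it:

import torch

model = Example()
print(len(list(model.parameters())))   # 0 -- nothing registered yet

out = model(torch.tensor([1, 2, 3]))   # first call creates and registers "embedding"
print(out.shape)                       # torch.Size([3, 51])
print(len(list(model.parameters())))   # 1 -- the embedding weight is now trainable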

A real use case could be when the layer's creation somehow depends on the data passed to forward, but this may indicate a flaw in the program's design.

Upvotes: 3

Shai

Reputation: 114786

You need to think of the scope of the trainable parameters.

If you define, say, a conv layer in the forward function of your model, then the scope of this "layer" and its trainable parameters is local to the function, and they will be discarded after every call to the forward method. You cannot update and train weights that are constantly discarded after every forward pass.
However, when the conv layer is a member of your model, its scope extends beyond the forward method and the trainable parameters persist as long as the model object exists. This way you can update and train the model and its weights.
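
A minimal sketch (hypothetical class name) of what that scope difference means in practice: an optimizer built from the model's parameters has nothing to update when the conv layer only lives inside forward():

import torch
import torch.nn as nn

class LocalConv(nn.Module):
    def forward(self, x):
        conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # local; discarded when forward returns
        return conv(x)

model = LocalConv()
try:
    torch.optim.SGD(model.parameters(), lr=0.1)
except ValueError as e:
    print(e)  # "optimizer got an empty parameter list" -- nothing persists to train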

Upvotes: 10
