Reputation: 412
Generally, in the constructor we declare all the layers we want to use. In the forward function, we define how the model is going to be run, from input to output.
My question is: what if we call those predefined/built-in nn.Modules directly in the forward() function? Is this Keras functional-API style legal in PyTorch? If not, why not?
Update: a TestModel constructed in this way did run successfully, without any warning. But the training loss goes down more slowly than with the conventional way (see the sketch after the code below).
import torch.nn as nn
from cnn import CNN

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_embeddings = 2020
        self.embedding_dim = 51

    def forward(self, input):
        x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
        # CNN is a custom nn.Module subclass;
        # its constructor arguments are omitted here
        x = CNN(...)(x)
        x = nn.ReLU()(x)
        x = nn.Dropout(p=0.2)(x)
        return x
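For comparison, the conventional way would look roughly like this (a sketch; CNN's constructor arguments are still omitted):

class ConventionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        # layers are declared once here, so their weights are registered and persist
        self.embedding = nn.Embedding(2020, 51)
        self.cnn = CNN(...)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, input):
        x = self.embedding(input)
        x = self.cnn(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x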
Upvotes: 9
Views: 6167
Reputation: 24691
What you are trying to do kind of can be done, but it shouldn't be, as it's totally unnecessary in most cases. It's not more readable IMO, and it's definitely against PyTorch's way.
In your forward, the layers are reinitialized on every call, and they are not registered in your network.
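For instance, with the TestModel from the question you can see that nothing gets registered:

model = TestModel()
print(list(model.parameters()))     # [] -- nothing for an optimizer to train
print(dict(model.named_modules()))  # only the TestModel itself, no submodules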
To do it correctly you can use Module's add_module() function with a guard against reassignment (the dynamic method below):
import torch


class Example(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num_embeddings = 2020
        self.embedding_dim = 51

    def dynamic(self, name: str, module_class, *args, **kwargs):
        # create and register the submodule only on the first call
        if not hasattr(self, name):
            self.add_module(name, module_class(*args, **kwargs))
        return getattr(self, name)

    def forward(self, x):
        embedded = self.dynamic(
            "embedding",
            torch.nn.Embedding,
            self.num_embeddings,
            self.embedding_dim,
        )(x)
        return embedded
You could structure it differently, but that's the idea behind it.
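For illustration, a quick usage sketch (the input shape and values are made up):

model = Example()
print(len(list(model.parameters())))  # 0 -- nothing registered yet

x = torch.randint(0, 2020, (8, 10))   # 8 sequences of 10 token indices
out = model(x)                        # first call creates and registers the embedding
print(out.shape)                      # torch.Size([8, 10, 51])
print(len(list(model.parameters())))  # 1 -- the embedding weight is now registered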
A real use case could be when the layer's creation somehow depends on the data passed to forward, but this may indicate some flaws in program design.
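As a hypothetical sketch of such a data-dependent case (the layer's input size is only known once data arrives):

class DataDependentExample(torch.nn.Module):
    def __init__(self, out_features: int = 10):
        super().__init__()
        self.out_features = out_features

    def dynamic(self, name: str, module_class, *args, **kwargs):
        if not hasattr(self, name):
            self.add_module(name, module_class(*args, **kwargs))
        return getattr(self, name)

    def forward(self, x):
        # in_features is taken from the incoming batch on the first call
        linear = self.dynamic("linear", torch.nn.Linear,
                              x.shape[-1], self.out_features)
        return linear(x)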
Upvotes: 3
Reputation: 114786
You need to think of the scope of the trainable parameters.
If you define, say, a conv layer in the forward function of your model, then the scope of this "layer" and its trainable parameters is local to the function, and they will be discarded after every call to the forward method. You cannot update and train weights that are constantly being discarded after every forward pass.
However, when the conv layer is a member of your model, its scope extends beyond the forward method, and the trainable parameters persist as long as the model object exists. This way you can update and train the model and its weights.
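A minimal sketch of the difference (class names are made up):

import torch

class InForward(torch.nn.Module):
    def forward(self, x):
        # created locally; its weights are discarded when forward returns
        return torch.nn.Conv1d(1, 4, kernel_size=3)(x)

class AsMember(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # registered as a submodule; its weights persist with the model
        self.conv = torch.nn.Conv1d(1, 4, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

print(len(list(InForward().parameters())))  # 0 -- an optimizer would see nothing to train
print(len(list(AsMember().parameters())))   # 2 -- conv weight and bias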
Upvotes: 10