Reputation: 5722
I would like to have a PyTorch sub-class of Module that keeps its sub-modules in a list (because the number of sub-modules may vary depending on the constructor's arguments). I set this list in the following way:
self.hidden_layers = [torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)]
According to this and this question, a submodule is only registered by __setattr__ when a Module object is assigned to an attribute of self. Because hidden_layers is not assigned an object of type Module, the submodules in the list are not registered, and as a result self.parameters() does not iterate over their parameters.
I suppose I could explicitly call __setattr__ for each element of the list, but that would be quite ugly. Is there a more correct way to register a submodule that is not a direct attribute of Module?
Upvotes: 3
Views: 3548
Reputation: 2117
As answered, nn.ModuleList is what you want.
What you can also use is nn.Sequential. You can create a list of layers and then combine them via nn.Sequential, which acts as a wrapper, combining all layers into essentially one module. This has the advantage that a single call forwards the input through all the layers, which is convenient when you have a dynamic count of modules, since you don't have to write the loop yourself.
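A minimal sketch of this approach (the layer sizes here are made up for illustration; zip stands in for the pairwise helper from the question):

```python
import torch
import torch.nn as nn

# Hypothetical layer sizes; any pairwise iteration over them works the same way
layer_sizes = [4, 8, 2]
layers = [nn.Linear(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]

# nn.Sequential registers every layer as a submodule and chains their forward calls
model = nn.Sequential(*layers)

x = torch.randn(3, 4)
out = model(x)          # one call runs the input through all layers in order
print(out.shape)        # torch.Size([3, 2])

# All layer parameters are visible to the wrapper module
print(len(list(model.parameters())))  # 4: weight and bias for each of 2 layers
```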
One example would be in the pytorch ResNet code: https://github.com/pytorch/vision/blob/497744b9d510ff2df756f479ee5a19fce0d579b6/torchvision/models/resnet.py#L177
Upvotes: 1
Reputation: 4826
Use nn.ModuleList
.
self.hidden_layers = nn.ModuleList([torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)])
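In context, a small sketch of how this looks inside a Module subclass (the class name and layer sizes are made up; zip stands in for the pairwise helper from the question):

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        # nn.ModuleList registers each Linear as a submodule,
        # so self.parameters() sees their weights and biases
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]
        )

    def forward(self, x):
        # Unlike nn.Sequential, ModuleList does not define forward,
        # so you iterate over the layers yourself
        for layer in self.hidden_layers:
            x = layer(x)
        return x

net = MLP([4, 8, 2])
print(len(list(net.parameters())))  # 4: weight and bias for each of 2 layers
```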
Upvotes: 4