user118967

Reputation: 5722

How can I have submodules of a PyTorch Module that are not attributes of the module?

I would like to have a PyTorch subclass of Module that keeps its submodules in a list, because the number of submodules varies with the constructor's arguments. I create this list as follows:

self.hidden_layers = [torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)]

According to this question and this one, __setattr__ registers a submodule only when a Module object is assigned to an attribute of self. Because hidden_layers is assigned a plain Python list rather than a Module, the layers in the list are not registered as submodules, and as a result self.parameters() does not iterate over their parameters.
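A minimal sketch of the problem described above, assuming a hypothetical `ListMLP` class. `pairwise` is replaced here with the equivalent `zip(sizes[:-1], sizes[1:])` so the snippet does not depend on Python 3.10's `itertools.pairwise`. Because the attribute holds a plain list, `parameters()` finds nothing.

```python
import torch
from torch import nn


class ListMLP(nn.Module):
    """Hypothetical example: layers kept in a plain list are NOT registered."""

    def __init__(self, layer_sizes):
        super().__init__()
        # zip(sizes[:-1], sizes[1:]) yields the same consecutive pairs as pairwise()
        self.hidden_layers = [
            nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])
        ]


model = ListMLP([4, 8, 2])
print(len(model.hidden_layers))       # 2 layers exist in the list...
print(len(list(model.parameters())))  # 0: ...but the module cannot see them
```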

I suppose I could explicitly call __setattr__ for each element of the list, but that would be quite ugly. Is there a more correct way to register a submodule that is not a direct attribute of the Module?
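For comparison, the manual per-element approach can be sketched with `nn.Module.add_module(name, module)`, which registers a child module under a string name, much as `__setattr__` does when a Module is assigned to an attribute. The class name `ManualMLP` and the `hidden_{idx}` naming scheme are hypothetical; the plain list is kept only for convenient iteration in `forward`.

```python
import torch
from torch import nn


class ManualMLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.hidden_layers = []  # plain list, used only for iteration in forward()
        for idx, (i, o) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            layer = nn.Linear(i, o)
            # add_module registers the layer as a child under a (hypothetical) name
            self.add_module(f"hidden_{idx}", layer)
            self.hidden_layers.append(layer)

    def forward(self, x):
        for layer in self.hidden_layers[:-1]:
            x = torch.relu(layer(x))
        return self.hidden_layers[-1](x)


model = ManualMLP([4, 8, 2])
print(len(list(model.parameters())))  # 4: weight and bias of two Linear layers
```

This works, but it needs invented attribute names and a second bookkeeping list, which is the ugliness the question wants to avoid.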

Upvotes: 3

Views: 3548

Answers (2)

Nopileos

Reputation: 2117

As already answered, nn.ModuleList is what you want.

Alternatively, you can use nn.Sequential: build a list of layers and pass them to nn.Sequential, which acts as a wrapper that combines all the layers into essentially one layer/module. The advantage is that a single call runs the input through all the layers, which is convenient when the number of modules is dynamic, because you don't have to write the loop yourself.
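A minimal sketch of the nn.Sequential approach, assuming a hypothetical `make_mlp` helper and a ReLU between hidden layers (the activation choice is an assumption, not part of the answer). Note that nn.Sequential takes modules as positional arguments, so the list is unpacked with `*`.

```python
import torch
from torch import nn


def make_mlp(layer_sizes):
    """Build an MLP as one nn.Sequential from a variable-length size list."""
    pairs = list(zip(layer_sizes[:-1], layer_sizes[1:]))
    layers = []
    for idx, (i, o) in enumerate(pairs):
        layers.append(nn.Linear(i, o))
        if idx < len(pairs) - 1:  # no activation after the output layer
            layers.append(nn.ReLU())
    # Unpack the list: nn.Sequential registers each positional module as a child
    return nn.Sequential(*layers)


model = make_mlp([4, 8, 8, 2])
out = model(torch.randn(5, 4))  # one call runs every layer in order
print(out.shape)                     # torch.Size([5, 2])
print(len(list(model.parameters())))  # 6: weight and bias of three Linear layers
```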

One example is the ResNet code in torchvision: https://github.com/pytorch/vision/blob/497744b9d510ff2df756f479ee5a19fce0d579b6/torchvision/models/resnet.py#L177

Upvotes: 1

hkchengrex

Reputation: 4826

Use nn.ModuleList.

self.hidden_layers = nn.ModuleList([torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)])
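A fuller sketch of the nn.ModuleList answer, assuming a hypothetical `MLP` class with a ReLU between layers. Unlike nn.Sequential, nn.ModuleList defines no `forward()` of its own, so the loop over layers is written explicitly.

```python
import torch
from torch import nn


class MLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layer_sizes = layer_sizes
        # nn.ModuleList registers every element, so parameters() sees them
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
        )

    def forward(self, x):
        # ModuleList supports slicing and indexing like a Python list
        for layer in self.hidden_layers[:-1]:
            x = torch.relu(layer(x))
        return self.hidden_layers[-1](x)


model = MLP([4, 8, 2])
print(len(list(model.parameters())))  # 4
print(model(torch.randn(3, 4)).shape)  # torch.Size([3, 2])
```

The registered layers also appear in `state_dict()` under keys such as `hidden_layers.0.weight`, so saving and loading work without extra code.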

Upvotes: 4
