Franz

Reputation: 360

Empty state_dict with list or tuple of layers in nn.Module

I switched to a version of my torch.nn.Module with a parametrized number of layers, like Net_par below, only to find out after quite a lot of optimizing that all the saved state_dicts were empty.

This is the recommended way to save a model (https://pytorch.org/docs/stable/notes/serialization.html#recommend-saving-models), yet layers stored in a Python list (or a tuple, for that matter) are left out when the state_dict is built.
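For reference, the state_dict workflow from the linked notes looks roughly like this. It is a minimal sketch: the single nn.Linear model and the in-memory io.BytesIO buffer (used instead of a file path) are stand-ins chosen here for illustration.

```python
import io

import torch
import torch.nn as nn

# A small stand-in model; any nn.Module is saved the same way.
model = nn.Linear(3, 4)

# Save only the registered parameters/buffers, not the pickled module object.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# To restore, build a fresh instance of the same architecture
# and copy the saved tensors into it.
buffer.seek(0)
restored = nn.Linear(3, 4)
restored.load_state_dict(torch.load(buffer))

print(sorted(restored.state_dict().keys()))  # ['bias', 'weight']
```

The relevant detail for this question: state_dict() only contains what the module has registered, so a model whose layers were never registered saves an empty dict.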

Pickling the whole model with torch.save works, in contrast, but it adds to the saved data and is less robust. This feels a bit like a bug; can anybody suggest a workaround?

Minimal example comparing a parametrized and a fixed layer count:

import torch
import torch.nn as nn

class Net_par(nn.Module):
    def __init__(self,layer_dofs):
        super(Net_par, self).__init__()
        self.layers=[]
        for i in range(len(layer_dofs)-1):
            self.layers.append(nn.Linear(layer_dofs[i],layer_dofs[i+1]))
    def forward(self, x):
        for i in range(len(self.layers)-1):
            x = torch.tanh(self.layers[i](x))
        return torch.tanh(self.layers[len(self.layers)-1](x))

class Net_sta(nn.Module):
    def __init__(self,dof1,dof2):
        super(Net_sta, self).__init__()
        self.layer=nn.Linear(dof1,dof2)
    def forward(self, x):
        return torch.tanh(self.layer(x))  # attribute is self.layer, not self.layer1

if __name__=="__main__":
    net_par=Net_par((3,4))
    net_sta=Net_sta(3,4)
    print(str(net_par.state_dict()))
    #OrderedDict()                 <------Why?!
    print(str(net_sta.state_dict()))
    #OrderedDict([('layer.weight', tensor([[...
    #   ...]])), ('layer.bias', tensor([...   ...]))])
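The empty OrderedDict comes from how nn.Module tracks submodules: assigning an nn.Module directly to an attribute registers it, while a plain list or tuple is stored as an ordinary attribute that nn.Module never looks inside. The sketch below (the class names ListHolder and DirectHolder are hypothetical, made up for this comparison) shows that such layers are missing not only from state_dict() but also from named_children() and parameters().

```python
import torch.nn as nn

class ListHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python containers are stored as ordinary attributes;
        # nn.Module does not register the layers inside them.
        self.as_list = [nn.Linear(3, 4)]
        self.as_tuple = (nn.Linear(4, 2),)

class DirectHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Module directly registers it as a submodule.
        self.layer = nn.Linear(3, 4)

for net in (ListHolder(), DirectHolder()):
    children = [name for name, _ in net.named_children()]
    n_params = len(list(net.parameters()))
    print(type(net).__name__, children, n_params, list(net.state_dict().keys()))
# ListHolder [] 0 []
# DirectHolder ['layer'] 2 ['layer.weight', 'layer.bias']
```

Because parameters() is empty as well, calls that walk the registered submodules, such as net.to(device), net.train()/net.eval(), or building an optimizer from net.parameters(), also skip layers kept in a plain list.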

Upvotes: 1

Views: 743

Answers (1)

Harshit Kumar

Reputation: 12847

You need to use nn.ModuleList instead of a plain Python list.

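class Net_par(nn.Module):
    def __init__(self, layer_dofs):
        super(Net_par, self).__init__()
        # nn.ModuleList registers each layer, so it shows up in state_dict()
        self.layers = nn.ModuleList([])
        # ... append the layers and define forward() as in the question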
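A complete version of the question's Net_par with this fix might look like the following sketch. The layer sizes (3, 4, 2) and the io.BytesIO round trip are illustrative choices, not part of the original answer; forward() applies tanh after every layer, which matches the question's loop-plus-final-tanh.

```python
import io

import torch
import torch.nn as nn

class Net_par(nn.Module):
    def __init__(self, layer_dofs):
        super(Net_par, self).__init__()
        # nn.ModuleList registers every appended layer as a submodule.
        self.layers = nn.ModuleList()
        for i in range(len(layer_dofs) - 1):
            self.layers.append(nn.Linear(layer_dofs[i], layer_dofs[i + 1]))

    def forward(self, x):
        # tanh after every layer, same behavior as the question's version.
        for layer in self.layers:
            x = torch.tanh(layer(x))
        return x

net = Net_par((3, 4, 2))
print(list(net.state_dict().keys()))
# ['layers.0.weight', 'layers.0.bias', 'layers.1.weight', 'layers.1.bias']

# Round trip through the state_dict workflow into a fresh instance.
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)
clone = Net_par((3, 4, 2))
clone.load_state_dict(torch.load(buffer))

x = torch.randn(5, 3)
print(torch.equal(net(x), clone(x)))  # True
```

The keys are prefixed with the attribute name and the index inside the ModuleList, so a checkpoint only loads into a model built with the same layer_dofs.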

Upvotes: 1
