Ark-kun

Reputation: 6787

PyTorch - Save just the model structure without weights and then load and train it

I want to separate model structure authoring and training. The model author designs the model structure, saves the untrained model to a file, and then sends it to a training service, which loads the model structure and trains the model.

Keras has the ability to save the model config and then load it.
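For example, something along these lines (a sketch of the Keras pattern I mean):

from tensorflow import keras

json_config = model.to_json()                          # architecture only, no weights
new_model = keras.models.model_from_json(json_config)  # fresh, untrained model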

How can the same be accomplished with PyTorch?

Upvotes: 1

Views: 2736

Answers (1)

Ayush

Reputation: 1620

You can write your own function to do that in PyTorch. Saving the weights is straightforward: you simply do torch.save(model.state_dict(), 'weightsAndBiases.pth').
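(For completeness, the matching load is the mirror image; this sketch assumes the Network class mentioned below is in scope on the loading side:)

model = Network()  # instantiate the same architecture first
model.load_state_dict(torch.load('weightsAndBiases.pth'))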
For saving the model structure, you can do this:
(Assume you have a model class named Network, and you instantiate yourModel = Network())

model_structure = {'input_size': 784,
                   'output_size': 10,
                   'hidden_layers': [each.out_features for each in yourModel.hidden_layers],
                   'state_dict': yourModel.state_dict()  # include if you also want to save the weights
                   }

torch.save(model_structure, 'model_structure.pth')

Similarly, we can write a function to load the structure.

def load_structure(filepath):
    structure = torch.load(filepath)
    model = Network(structure['input_size'],
                    structure['output_size'],
                    structure['hidden_layers'])
    # If you saved weights as well, restore them here:
    # model.load_state_dict(structure['state_dict'])
    
    return model

model = load_structure('model_structure.pth')
print(model)

Edit: Okay, the above works when you have access to the source code for your class, or when the class is simple enough that you can define a generic class like this:

import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        ''' Builds a feedforward network with arbitrary hidden layers.
        
            Arguments
            ---------
            input_size: integer, size of the input layer
            output_size: integer, size of the output layer
            hidden_layers: list of integers, the sizes of the hidden layers
            drop_p: float, dropout probability applied after each hidden layer
        
        '''
        super().__init__()
        # Input to a hidden layer
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
        
        # Add a variable number of more hidden layers
        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])
        
        self.output = nn.Linear(hidden_layers[-1], output_size)
        
        self.dropout = nn.Dropout(p=drop_p)
        
    def forward(self, x):
        ''' Forward pass through the network, returns the output logits '''
        
        for each in self.hidden_layers:
            x = F.relu(each(x))
            x = self.dropout(x)
        x = self.output(x)
        
        return F.log_softmax(x, dim=1)

However, that will only work for simple cases, so I suppose that's not what you intended.

One option is to define the model architecture in a separate .py file and import it along with any other necessities (if the model architecture is complex), or you can define the model inline right where it is needed.
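A minimal sketch of the separate-file pattern, assuming a hypothetical model_def.py that is shipped to the training service alongside the checkpoint:

# model_def.py is assumed to contain the Network class shown above.
import torch
from model_def import Network

model = Network(input_size=784, output_size=10, hidden_layers=[256, 128])
model.load_state_dict(torch.load('weightsAndBiases.pth'))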

Another option is converting your PyTorch model to ONNX and saving it.
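A minimal export sketch (the input shape is an assumption matching the 784-input example above; ONNX export traces the model with a dummy input):

import torch

dummy_input = torch.randn(1, 784)                        # shape must match the model's input
torch.onnx.export(yourModel, dummy_input, 'model.onnx')  # saves architecture and weights

Keep in mind that ONNX is primarily an inference interchange format; as far as I know, loading an ONNX graph back into PyTorch for further training is not directly supported.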

The other option is the equivalent of TensorFlow's .pb file, which defines both the architecture and the weights of the model. In PyTorch you would do it this way:

torch.save(model, filepath)

This will save the model object itself, as torch.save() is just a pickle-based save at the end of the day.

model = torch.load(filepath)

This, however, has limitations: your model class definition might not be picklable (possible in some complicated models). Because this is such an iffy workaround, the answer you'll usually get is: no, you have to declare the class definition before loading the trained model, i.e., you need access to the model class source code.

Side note: an official answer by one of the core PyTorch devs on the limitations of loading a PyTorch model without the code:

  1. We only save the source code of the class definition. We do not save beyond that (like the package sources that the class is referring to).
import foo

class MyModel(...):
    def forward(self, input):
        foo.bar(input)

Here the package foo is not saved in the model checkpoint.

  2. There are limitations on robustly serializing Python constructs. For example, the default picklers cannot serialize lambdas (see the sketch below). There are helper packages that can serialize more Python constructs than the standard library, but they still have limitations. Dill is one such package.
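(A quick illustration of the lambda limitation:)

import pickle

square = lambda x: x ** 2
pickle.dumps(square)
# raises pickle.PicklingError: Can't pickle <function <lambda> at 0x...>:
# attribute lookup <lambda> on __main__ failed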

Given these limitations, there is no robust way to have torch.load work without having the original source files.

Upvotes: 1
