M.Z.

Reputation: 144

How to save model architecture in PyTorch?

I know I can save a model with torch.save(model.state_dict(), FILE) or torch.save(model, FILE). But neither of them saves the architecture of the model.

So how can we save the architecture of a model in PyTorch, like creating a .pb file in TensorFlow? I want to apply different tweaks to my model. If I can't save the architecture, is there a better way than copying the whole class definition and creating a new class every time?

Upvotes: 11

Views: 14632

Answers (5)

Prajot Kuvalekar

Reputation: 6658

Exporting and loading the model in TorchScript format is what you are looking for.

Another common way to do inference with a trained model is to use TorchScript, an intermediate representation of a PyTorch model that can be run in Python as well as in C++.

NOTE: Using the TorchScript format, you will be able to load the exported model and run inference without defining the model class.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

Export:

model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save('model_scripted.pt') # Save

Load [ Works w/o defining model class ]:

model = torch.jit.load('model_scripted.pt')
model.eval()

[Figure: the model architecture as displayed in Netron]

Upvotes: 2

elgehelge

Reputation: 2150

PyTorch's way of serializing a model for inference is to use torch.jit to compile the model to TorchScript.

TorchScript supports more advanced control flow than TensorFlow, so serialization can happen either through tracing (torch.jit.trace), which records the operations executed on an example input, or through compiling the Python model code (torch.jit.script), which preserves branches and loops.
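To illustrate the difference between the two routes, here is a minimal sketch. The Gate module and its branch are hypothetical, made up only for this example. Scripting keeps both branches, while tracing replays whichever branch the example input happened to take (PyTorch emits a TracerWarning about this).

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical module with a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 3
        return x - 1


module = Gate()
positive = torch.ones(3)
negative = -torch.ones(3)

# Scripting compiles the Python source, so both branches are kept.
scripted = torch.jit.script(module)

# Tracing records only the operations run for the example input,
# so the `x * 3` branch taken for `positive` is baked in.
traced = torch.jit.trace(module, positive)

scripted_out = scripted(negative)  # follows the `x - 1` branch
traced_out = traced(negative)      # replays the recorded `x * 3` branch
```

If the model has no data-dependent control flow, both routes should produce the same results, and tracing is often the simpler choice.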

Great references:

Upvotes: 2

Xxxo

Reputation: 1931

Regarding the actual question:

So how can we save the architecture of a model in PyTorch, like creating a .pb file in TensorFlow?

The answer is: you cannot.

Is there any way to load a trained model without declaring the class definition before ? I want the model architecture as well as parameters to be loaded.

No, you have to load the class definition first; this is a Python pickling limitation.

https://discuss.pytorch.org/t/how-to-save-load-torch-models/718/11

Though, there are other options (probably you have already seen most of those) that are listed at this PyTorch post:

https://pytorch.org/tutorials/beginner/saving_loading_models.html
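One of those options can be sketched as follows: store the constructor arguments next to the state_dict, so that variants of the architecture can be rebuilt from a single class. This is only an illustration under assumptions; the MLP class and the "config" key are made up here, and the class must still be defined or importable at load time.

```python
import io

import torch
import torch.nn as nn


class MLP(nn.Module):
    """Hypothetical model fully described by its constructor arguments."""

    def __init__(self, in_features: int, hidden: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)


config = {"in_features": 8, "hidden": 16, "out_features": 2}
model = MLP(**config)
model.eval()

# Save the constructor arguments together with the weights.
buffer = io.BytesIO()  # stands in for a file path
torch.save({"config": config, "state_dict": model.state_dict()}, buffer)

# Loading rebuilds the model from the stored arguments, then restores weights.
buffer.seek(0)
checkpoint = torch.load(buffer)
restored = MLP(**checkpoint["config"])
restored.load_state_dict(checkpoint["state_dict"])
restored.eval()
```

Changing a value in config then gives a differently shaped model without a new class, as long as the tweak can be expressed as a constructor argument.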

Upvotes: 2

Shai

Reputation: 114926

Saving all the parameters (state_dict) and all the Modules is not enough, since there are operations that manipulate the tensors but are only reflected in the actual code of the specific implementation (e.g., reshaping in ResNet).

Furthermore, the network might not have a fixed and pre-determined compute graph: You can think of a network that has branching or a loop (recurrence).

Therefore, you must save the actual code.
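A small sketch of why parameters and modules alone are not enough: the two hypothetical classes below (Plain and WithSkip, invented for this example) register exactly the same module and can share the same state_dict, yet compute different functions because the ReLU and the residual addition live only in forward.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Plain(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


class WithSkip(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # ReLU and the residual addition exist only in this code,
        # not as registered modules or parameters.
        return x + F.relu(self.fc(x))


plain, skip = Plain(), WithSkip()
skip.load_state_dict(plain.state_dict())  # identical parameters

same_keys = list(plain.state_dict()) == list(skip.state_dict())
same_children = [repr(m) for m in plain.children()] == [repr(m) for m in skip.children()]

x = torch.ones(2, 4)
same_output = torch.equal(plain(x), skip(x))  # the outputs differ
```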

Alternatively, if there are no branches/loops in the net, you may save the computation graph, see, e.g., this post.

You should also consider exporting your model to ONNX to get a representation that captures both the trained weights and the computation graph.

Upvotes: 6

Roshan Santhosh

Reputation: 687

You can refer to this article to understand how to save the classifier. To make tweaks to a model, you can create a new model that is a child of the existing model.


class newModel(oldModelClass):
    def __init__(self):
        super(newModel, self).__init__()

With this setup, newModel has all the layers as well as the forward function of oldModelClass. If you need to make tweaks, you can define new layers in the __init__ function and then override forward so that it uses them.
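A fuller sketch of this pattern, with a hypothetical OldModel standing in for oldModelClass and a dropout layer as the tweak (both are assumptions for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OldModel(nn.Module):
    """Hypothetical stand-in for oldModelClass."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class NewModel(OldModel):
    def __init__(self):
        super().__init__()                # inherits fc1 and fc2
        self.dropout = nn.Dropout(p=0.5)  # the tweak: one extra layer

    def forward(self, x):
        # Overridden forward that inserts the new layer.
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


old, new = OldModel(), NewModel()
old.eval()

# Dropout has no parameters, so the parent's weights load into the child as-is.
new.load_state_dict(old.state_dict())
new.eval()  # dropout is a no-op in eval mode
```

Because the added layer has no parameters, a state_dict saved from the parent can be reused directly; a tweak that adds parameters would need load_state_dict(..., strict=False) or separate initialization of the new layers.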

Upvotes: 6
