Reputation: 1767
How can I save a PyTorch model without needing the model class to be defined somewhere?
Disclaimer:
The question "Best way to save a trained model in PyTorch?" has no solution (or no working solution) for saving the model without access to the model class code.
Upvotes: 26
Views: 21282
Reputation: 1515
If you plan to do inference with the PyTorch library available (i.e. PyTorch in Python, C++, or other platforms it supports), then the best way to do this is via TorchScript.
I think the simplest thing is to use `trace = torch.jit.trace(model, typical_input)` and then `torch.jit.save(trace, path)`. You can then load the traced model with `torch.jit.load(path)`.
Here's a really simple example. We make two files:
train.py:
```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))

# Record the operations run on x and save them with the weights in one file
traced_cell = torch.jit.trace(model, x)
torch.jit.save(traced_cell, "model.pth")
```
infer.py:
```python
import torch

# No Model class is defined here: the traced file carries graph and weights
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))
```
Running these sequentially gives:
```
$ python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])
$ python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])
```
The results are the same, so we are good. (Note that the output will differ each time you run train.py, because the weights of the nn.Linear layer are randomly initialised.)
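To check the round trip automatically instead of comparing printed tensors by eye, a minimal sketch (assuming the same toy `Model` as in train.py) can save the trace to an in-memory `io.BytesIO` buffer rather than a file and compare outputs with `torch.allclose`:

```python
import io

import torch

# Same toy model as in train.py above
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = Model().eval()
x = torch.rand(2, 4)

# torch.jit.save / torch.jit.load accept file-like objects as well as paths
buffer = io.BytesIO()
torch.jit.save(torch.jit.trace(model, x), buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

with torch.no_grad():
    expected = model(x)
    actual = loaded(x)

print(torch.allclose(expected, actual))  # True
```

Using a buffer keeps the check self-contained; in practice you would save to a path exactly as train.py does.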
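TorchScript allows much more complex architectures and graph definitions (including if statements, while loops, and more) to be saved in a single file, without needing to redefine the graph at inference time. Note that `torch.jit.trace` only records the operations executed for the example input, so data-dependent control flow has to be compiled with `torch.jit.script` instead. See the TorchScript documentation for more advanced possibilities.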
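Here is a minimal sketch of the scripting route for a module with a data-dependent branch. The `Gate` module and the file name `gate.pt` are hypothetical; the branch is compiled into the saved file, so both paths work after loading without the class:

```python
import torch

class Gate(torch.nn.Module):
    """Hypothetical module whose output depends on a runtime condition."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tracing would freeze whichever branch the example input took;
        # scripting compiles both branches into the saved program.
        if x.sum() > 0:
            return torch.relu(self.linear(x))
        return -x

scripted = torch.jit.script(Gate())
torch.jit.save(scripted, "gate.pt")
loaded = torch.jit.load("gate.pt")

pos = torch.ones(1, 4)
neg = -torch.ones(1, 4)
print(loaded(neg))  # tensor([[1., 1., 1., 1.]])
print(loaded(pos))  # ReLU(linear(pos)): values depend on the random init
```

`print(loaded.code)` shows the compiled source, which is a quick way to confirm the `if` survived.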
Upvotes: 33
Reputation: 887
Quoting an official answer from one of the core PyTorch devs (smth):
There are limitations to loading a pytorch model without code.
First limitation: We only save the source code of the class definition. We do not save beyond that (like the package sources that the class is referring to).
For example:
```python
import foo

class MyModel(...):
    def forward(self, input):
        return foo.bar(input)
```
Here the package `foo` is not saved in the model checkpoint.
Second limitation: There are limitations on robustly serializing Python constructs. For example, the default picklers cannot serialize lambdas. There are helper packages that can serialize more Python constructs than the standard ones, but they still have limitations. dill is one such package.
Given these limitations, there is no robust way to have torch.load work without having the original source files.
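Both limitations can be seen with the standard `pickle` module, which `torch.save` builds on. In this minimal sketch, `MyModel` is a hypothetical plain class standing in for a model: the pickle stores only a reference to the class, so loading fails once the class is gone, and a lambda cannot be pickled at all by the standard pickler:

```python
import pickle

class MyModel:
    """Hypothetical stand-in for a model class."""

    def __init__(self):
        self.weights = [0.1, 0.2]

payload = pickle.dumps(MyModel())
# The payload names the class but contains none of its source code
print(b"MyModel" in payload)  # True

del MyModel  # simulate a process in which the class is not defined
try:
    pickle.loads(payload)
    loaded_ok = True
except (AttributeError, pickle.UnpicklingError) as err:
    loaded_ok = False
    print("load failed:", err)

# Lambdas have no importable name, so the standard pickler refuses them
square = lambda v: v * v
try:
    pickle.dumps(square)
    lambda_ok = True
except (pickle.PicklingError, AttributeError):
    lambda_ok = False
print(lambda_ok)  # False
```

The exact exception types and messages vary somewhat between Python versions, which is why the sketch catches more than one.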
Upvotes: 1
Reputation: 41
I recommend converting your PyTorch model to ONNX and saving it. This is probably the best way to store a model without access to the class.
Upvotes: 2
Reputation: 46449
"There are no solutions (or a working solution) for saving the model without access to the class."
You can save whatever you like.
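You can save the whole model with `torch.save(model, filepath)`. This pickles the model object itself, although `torch.load` will still need the model class to be importable, because pickle stores a reference to the class rather than its code.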
You can save just the model state dict:
```python
torch.save(model.state_dict(), filepath)
```
Further, you can save anything you like, since `torch.save` is just a pickle-based save:
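```python
import torch

model = torch.nn.Linear(4, 4)  # example model, stands in for yours
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # example optimizer
filepath = "checkpoint.pth"  # example path
state = {
    'hello_text': 'just the optimizer sd will be saved',
    'optimizer': optimizer.state_dict(),
}
torch.save(state, filepath)
```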
You may check what I wrote about `torch.save` some time ago.
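A minimal sketch of the full state-dict round trip, assuming a toy `Model` class and hypothetical checkpoint keys `"model"` and `"optimizer"`. Note that restoring still constructs `Model()`, so this approach does not remove the need for the class definition:

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save: a plain dict of state dicts, pickled by torch.save
torch.save({"model": model.state_dict(),
            "optimizer": optimizer.state_dict()}, "checkpoint.pth")

# Load: the Model class is needed to rebuild the module the weights go into
checkpoint = torch.load("checkpoint.pth")
restored = Model()
restored.load_state_dict(checkpoint["model"])
restored_opt = torch.optim.SGD(restored.parameters(), lr=0.1)
restored_opt.load_state_dict(checkpoint["optimizer"])

x = torch.rand(2, 4)
with torch.no_grad():
    print(torch.equal(model(x), restored(x)))  # True
```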
Upvotes: 0