Charlie Parker

Reputation: 5201

How does one save torch.nn.Sequential models in pytorch properly?

I am well aware of the approach of saving the parameter dictionary and then loading it into a new instance of the model class (e.g. this great question & answer). Unfortunately, when I have a torch.nn.Sequential I of course do not have a class definition for it.

So I wanted to double check what the proper way to do it is. I believe torch.save is sufficient (so far my code has not collapsed), though these things can be more subtle than one might expect (e.g. I get a warning when I use pickle, but torch.save uses it internally, so it's confusing). Also, numpy has its own save functions (e.g. see this answer) which tend to be more efficient, so there might be a subtle trade-off I am overlooking.


My test code:


# creating data and running through a nn and saving it

import torch
import torch.nn as nn

from pathlib import Path
from collections import OrderedDict

import numpy as np

import pickle

path = Path('~/data/tmp/').expanduser()
path.mkdir(parents=True, exist_ok=True)

num_samples = 3
Din, Dout = 1, 1
lb, ub = -1, 1

x = torch.distributions.Uniform(low=lb, high=ub).sample((num_samples, Din))

f = nn.Sequential(OrderedDict([
    ('f1', nn.Linear(Din,Dout)),
    ('out', nn.SELU())
]))
y = f(x)

# save data torch to numpy
x_np, y_np = x.detach().cpu().numpy(), y.detach().cpu().numpy()
np.savez(path / 'db', x=x_np, y=y_np)

print(x_np)
# save model
with open('db_saving_seq', 'wb') as file:
    pickle.dump({'f': f}, file)

# load model
with open('db_saving_seq', 'rb') as file:
    db = pickle.load(file)
    f2 = db['f']

# test that it outputs the right thing
y2 = f2(x)

y_eq_y2 = y == y2
print(y_eq_y2)

db2 = {'f': f, 'x': x, 'y': y}
torch.save(db2, path / 'db_f_x_y')

print('Done')

db3 = torch.load(path / 'db_f_x_y')
f3 = db3['f']
x3 = db3['x']
y3 = db3['y']
yy3 = f3(x3)

y_eq_y3 = y == y3
print(y_eq_y3)

y_eq_yy3 = y == yy3
print(y_eq_yy3)


Upvotes: 3

Views: 1992

Answers (1)

Marius

Reputation: 388

As can be seen in the source code, torch.nn.Sequential is a subclass of torch.nn.Module: https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential

So you can use

f = torch.nn.Sequential(...)
torch.save(f.state_dict(), path)

just like with any other torch.nn.Module.
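To address the "no class definition" concern from the question, here is a minimal sketch of the full round trip. It assumes you can rebuild the same nn.Sequential (same layer names and shapes) at load time. The build_model helper and the checkpoint file name are hypothetical, introduced only for illustration, and a temporary directory stands in for your real save location.

```python
import tempfile
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn as nn


def build_model(d_in: int = 1, d_out: int = 1) -> nn.Sequential:
    # Hypothetical helper: recreates the architecture from the question.
    # A state_dict stores only parameters, so the layers must be rebuilt in code.
    return nn.Sequential(OrderedDict([
        ('f1', nn.Linear(d_in, d_out)),
        ('out', nn.SELU()),
    ]))


f = build_model()
x = torch.rand(3, 1) * 2 - 1  # samples in [-1, 1), like the question's data
with torch.no_grad():
    y = f(x)

with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / 'f_state_dict.pt'  # hypothetical file name

    # Save only the parameters, not the pickled module object.
    torch.save(f.state_dict(), ckpt)

    # Rebuild the architecture, then load the saved parameters into it.
    f_restored = build_model()
    f_restored.load_state_dict(torch.load(ckpt))
    f_restored.eval()

with torch.no_grad():
    y_restored = f_restored(x)

# The keys come from the names given in the OrderedDict.
same_keys = list(f.state_dict().keys()) == list(f_restored.state_dict().keys())
outputs_match = torch.equal(y, y_restored)
print(same_keys, outputs_match)
```

The trade-off compared with torch.save(f, path) is that the checkpoint no longer depends on pickling the module object itself, but you have to keep the code that constructs the Sequential around.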

Upvotes: 1
