Reputation: 5201
I am well aware of saving the state dictionary and then having an instance of the model class load the old dictionary of parameters (e.g. this great question & answer). Unfortunately, when I have a torch.nn.Sequential I of course do not have a class definition for it.
So I wanted to double-check: what is the proper way to do it? I believe torch.save is sufficient (so far my code has not broken), though these things can be more subtle than one might expect (e.g. I get a warning when I use pickle directly, but torch.save uses pickle internally, so it's confusing). Also, numpy has its own save functions (e.g. see this answer), which tend to be more efficient, so there might be a subtle trade-off I am overlooking.
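For reference, here is a minimal sketch of the state_dict pattern I am referring to (the layer sizes and the checkpoint.pt file name are just placeholders I picked for illustration):
import torch
import torch.nn as nn
# save only the parameters, not the module object itself
model = nn.Sequential(nn.Linear(1, 1), nn.SELU())
torch.save(model.state_dict(), 'checkpoint.pt')
# rebuild the same architecture and load the parameters back into it
model2 = nn.Sequential(nn.Linear(1, 1), nn.SELU())
model2.load_state_dict(torch.load('checkpoint.pt'))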
My test code:
# creating data and running through a nn and saving it
import torch
import torch.nn as nn
from pathlib import Path
from collections import OrderedDict
import numpy as np
import pickle
path = Path('~/data/tmp/').expanduser()
path.mkdir(parents=True, exist_ok=True)
num_samples = 3
Din, Dout = 1, 1
lb, ub = -1, 1
x = torch.distributions.Uniform(low=lb, high=ub).sample((num_samples, Din))
f = nn.Sequential(OrderedDict([
    ('f1', nn.Linear(Din, Dout)),
    ('out', nn.SELU())
]))
y = f(x)
# save data torch to numpy
x_np, y_np = x.detach().cpu().numpy(), y.detach().cpu().numpy()
np.savez(path / 'db', x=x_np, y=y_np)
print(x_np)
# save model
with open('db_saving_seq', 'wb') as file:
    pickle.dump({'f': f}, file)
# load model
with open('db_saving_seq', 'rb') as file:
    db = pickle.load(file)
f2 = db['f']
# test that it outputs the right thing
y2 = f2(x)
y_eq_y2 = y == y2
print(y_eq_y2)
db2 = {'f': f, 'x': x, 'y': y}
torch.save(db2, path / 'db_f_x_y')
print('Done')
db3 = torch.load(path / 'db_f_x_y')
f3 = db3['f']
x3 = db3['x']
y3 = db3['y']
yy3 = f3(x3)
y_eq_y3 = y == y3
print(y_eq_y3)
y_eq_yy3 = y == yy3
print(y_eq_yy3)
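As a side note on the checks above: y == y2 prints an elementwise boolean tensor; to collapse the comparison into a single yes/no, something like torch.equal or torch.allclose should work (a small sketch):
# single-boolean comparisons instead of elementwise tensors
print(torch.equal(y, y2))      # exact match of shape and values
print(torch.allclose(y, yy3))  # match up to a small numerical tolerance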
Upvotes: 3
Views: 1992
Reputation: 388
As can be seen in the source code, torch.nn.Sequential is based on torch.nn.Module:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential
So you can use
f = torch.nn.Sequential(...)
torch.save(f.state_dict(), path)
just like with any other torch.nn.Module.
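To make the round trip explicit, here is a sketch of loading that state_dict back into a freshly built Sequential; the layer names have to match the saved keys, so I reuse the names from the question's f ('f1' and 'out'):
from collections import OrderedDict
import torch
import torch.nn as nn
# rebuild the same architecture (names must match the state_dict keys, e.g. 'f1.weight')
f_loaded = nn.Sequential(OrderedDict([('f1', nn.Linear(1, 1)), ('out', nn.SELU())]))
f_loaded.load_state_dict(torch.load(path))  # path: wherever the state_dict was saved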
Upvotes: 1