I am trying to save the weights of a PyTorch model to a .txt or .json file. When I write it to a .txt file like this:
import torch
model = torch.load("model_path")
string = str(model)
with open('some_file.txt', 'w') as fp:
    fp.write(string)
I get a file in which not all the weights are saved, i.e. there are ellipses throughout the text file. I cannot write it to JSON because the model contains tensors, which are not JSON serializable (unless there is a way I don't know of?). How can I save the weights in the .pth file to a format where no information is lost and that can be easily read?
Thanks
A bit late, but hopefully this helps. This is how you can store it:
import json
from json import JSONEncoder

import torch

class EncodeTensor(JSONEncoder):
    """JSON encoder that turns tensors into nested Python lists."""
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            # Drop autograd history, move to CPU, convert to nested lists
            return obj.detach().cpu().tolist()
        return super().default(obj)

# `model` is your torch.nn.Module
with open('torch_weights.json', 'w') as json_file:
    json.dump(model.state_dict(), json_file, cls=EncodeTensor)
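Take into account that the stored values are plain Python lists, so you have to convert them back with torch.tensor(list) when you are going to use the weights. (torch.Tensor(list) also works, but it always produces float32, even for integer buffers such as BatchNorm's num_batches_tracked.)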
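To read the weights back, a minimal sketch of the full round trip might look like the following. It assumes a small stand-in nn.Linear model and two hypothetical helper names, save_state_dict_json and load_state_dict_json; each list is rebuilt with the dtype that the target model's state_dict expects, so integer buffers stay integers.

```python
import json
import os
import tempfile

import torch
import torch.nn as nn


class EncodeTensor(json.JSONEncoder):
    """Serialize tensors as nested Python lists."""

    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        return super().default(obj)


def save_state_dict_json(model, path):
    # Hypothetical helper: dump every parameter and buffer as lists
    with open(path, "w") as f:
        json.dump(model.state_dict(), f, cls=EncodeTensor)


def load_state_dict_json(model, path):
    # Hypothetical helper: turn the lists back into tensors and load them
    with open(path) as f:
        raw = json.load(f)
    expected = model.state_dict()
    state = {k: torch.tensor(v, dtype=expected[k].dtype) for k, v in raw.items()}
    model.load_state_dict(state)
    return model


torch.manual_seed(0)
src = nn.Linear(3, 2)  # stand-in for your trained model
path = os.path.join(tempfile.mkdtemp(), "torch_weights.json")
save_state_dict_json(src, path)

dst = load_state_dict_json(nn.Linear(3, 2), path)
print(torch.equal(src.weight, dst.weight))  # True
```

Because float32 values convert to Python floats exactly and json writes floats with enough digits to round-trip, the reloaded weights should match the originals exactly.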
When you call str(model.state_dict()), it recursively calls the str method of every element it contains, so the truncation comes from how each tensor builds its own string representation: by default, PyTorch summarizes any tensor with more than 1000 elements using "...". You should raise that limit so every value is printed:
torch.set_printoptions(profile="full")
See the difference with this:
import torch
import torchvision.models as models
mobilenet_v2 = models.mobilenet_v2()
torch.set_printoptions(profile="default")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
torch.set_printoptions(profile="full")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
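Applied to the question's goal of writing a .txt file, a minimal sketch could wrap the str() call in the "full" profile and restore the default afterwards. The nn.Linear(40, 30) model is only a stand-in (its 1200 weights exceed the 1000-element threshold). Note that "full" still rounds floats to 4 decimal places, so the sketch also passes a larger precision; 8 is an arbitrary choice, and the text output is still not guaranteed to be bit-exact the way the JSON approach is.

```python
import torch
import torch.nn as nn

model = nn.Linear(40, 30)  # stand-in: 1200 weights, above the default threshold

# "full" disables "..." summarization; precision=8 (arbitrary) keeps more digits
torch.set_printoptions(profile="full", precision=8)
try:
    text = str(model.state_dict())
finally:
    torch.set_printoptions(profile="default")  # restore normal printing

with open("some_file.txt", "w") as fp:
    fp.write(text)

print("..." in text)  # False
```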
Note that tensors are not JSON serializable out of the box.