tyler02

Reputation: 23

Saving the weights of a PyTorch .pth model into a .txt or .json

I am trying to save the weights of a PyTorch model to a .txt or .json file. When writing them to a .txt file,

import torch
model = torch.load("model_path")
string = str(model)
with open('some_file.txt', 'w') as fp:
    fp.write(string)

I get a file where not all the weights are saved, i.e. there are ellipses throughout the text file. I cannot write it to JSON since the model contains tensors, which are not JSON serializable (unless there is a way that I do not know of?). How can I save the weights in the .pth file to a format in which no information is lost and the values can be easily inspected?

Thanks

Upvotes: 2

Views: 7018

Answers (2)

A bit late but hopefully this helps. This is how you store it:

import json
from json import JSONEncoder

import torch

class EncodeTensor(JSONEncoder):
    # Convert tensors to nested Python lists, which json can serialize.
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return obj.cpu().detach().numpy().tolist()
        return super().default(obj)

# `model` is assumed to be an already loaded torch.nn.Module.
with open('torch_weights.json', 'w') as json_file:
    json.dump(model.state_dict(), json_file, cls=EncodeTensor)

Keep in mind that the stored values are plain Python lists, so you have to convert them back with torch.tensor(list) before you use the weights.
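For the way back, here is a minimal sketch of a full round trip. The small Linear model, the temporary file path, and the variable names are stand-ins I made up for illustration; they are not from the question. It also assumes the default float32 weights, since torch.tensor infers the dtype from the list values, so a model with other dtypes would need an explicit dtype argument.

```python
import json
import os
import tempfile

import torch

# Hypothetical stand-in for the model from the question.
model = torch.nn.Linear(3, 2)

# Hypothetical output location; any writable path works.
path = os.path.join(tempfile.mkdtemp(), "torch_weights.json")

# Save: turn every tensor into nested Python lists so json can write it.
weights = {name: t.detach().cpu().tolist() for name, t in model.state_dict().items()}
with open(path, "w") as json_file:
    json.dump(weights, json_file)

# Load: rebuild a tensor from each list and hand the dict to a fresh model.
with open(path) as json_file:
    loaded = json.load(json_file)
state_dict = {name: torch.tensor(values) for name, values in loaded.items()}

restored = torch.nn.Linear(3, 2)
restored.load_state_dict(state_dict)

# Check that the round trip reproduced the original parameters.
print(torch.equal(restored.weight, model.weight))
```

The new model has to be built with the same architecture as the one that was saved, because the JSON file only holds the parameter names and values, not the layer definitions.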

Upvotes: 2

Proko

Reputation: 2011

When you call str(model.state_dict()), it recursively uses the string representation of the elements it contains. So the problem is how the string representations of the individual tensors are built. You should raise the limit on how much of each tensor is printed:

torch.set_printoptions(profile="full")

See the difference with this:

import torch
import torchvision.models as models
mobilenet_v2 = models.mobilenet_v2()
torch.set_printoptions(profile="default")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
torch.set_printoptions(profile="full")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])

Tensors are currently not JSON serializable directly; they have to be converted to plain Python lists (for example with .tolist()) first.
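Applied to the text file from the question, this could look like the sketch below. The Linear(40, 40) model and the temporary path are hypothetical placeholders, chosen only because 1600 weights is above the default summarization threshold of 1000 elements. Note that the printed values are still rounded to the print precision (4 digits by default), so a text dump like this is good for viewing but is not a lossless copy of the weights.

```python
import os
import tempfile

import torch

# Hypothetical stand-in model with enough weights to be summarized by default.
model = torch.nn.Linear(40, 40)

# With the default profile, large tensors are abbreviated with "...".
torch.set_printoptions(profile="default")
short_text = str(model.state_dict())

# With the full profile, every element is printed.
torch.set_printoptions(profile="full")
full_text = str(model.state_dict())

# Restore the default so later prints are not flooded.
torch.set_printoptions(profile="default")

# Hypothetical output location; any writable path works.
path = os.path.join(tempfile.mkdtemp(), "some_file.txt")
with open(path, "w") as fp:
    fp.write(full_text)

print(len(short_text), len(full_text))
```

If more digits are needed in the text file, torch.set_printoptions also accepts a precision argument, which has to be passed after the profile is selected because choosing a profile resets it.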

Upvotes: 1
