NicoCaldo

Reputation: 1577

Export a state_dict checkpoint from a .pt model in PyTorch

I'm new to PyTorch and to model/AI programming in general.

I'm using a library that needs a model checkpoint in the form of a state_dict in order to run.

I have the .pt model (more precisely, the radtts pre-trained model) and I need to extract the state dictionary from it to use as the checkpoint.

From what I understand from the PyTorch documentation, I should be able to load the model and save its state_dict with torch.save(model.state_dict(), PATH).
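For reference, the save/load pattern described in the documentation looks roughly like the sketch below. It assumes you can construct the model's Python class yourself; the tiny nn.Linear and the file name model_state_dict.pt are hypothetical stand-ins, not the radtts architecture.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; in practice this is the real network class
# whose architecture matches the checkpoint.
model = nn.Linear(4, 2)

PATH = "model_state_dict.pt"  # hypothetical output path

# Save only the parameters and buffers (the state_dict), not the module object.
torch.save(model.state_dict(), PATH)

# To reuse the weights, rebuild the same architecture and load them into it.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(PATH, map_location="cpu"))
```

A state_dict is only a mapping from parameter names to tensors, so loading it always requires an instance of a matching model class.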

My questions are: first, is this approach correct? And second, how do I load the model in PyTorch so I can extract the state_dict?
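A .pt file can hold different things: a whole pickled model, a bare state_dict, or a dict that wraps a state_dict under some key. The helper below is a hypothetical sketch (extract_state_dict is not a PyTorch API) that inspects whatever torch.load() returned and pulls out a state_dict in each of those cases.

```python
import torch
import torch.nn as nn


def extract_state_dict(obj, key=None):
    """Return a state_dict from the object produced by torch.load().

    key: optional top-level key (e.g. "generator") for checkpoints that wrap
         the state_dict in a larger dict.
    """
    # Case 1: the file contained a whole pickled nn.Module.
    if isinstance(obj, nn.Module):
        return obj.state_dict()
    if isinstance(obj, dict):
        # Case 2: a wrapper dict; return the entry the caller asked for.
        if key is not None:
            return obj[key]
        # Case 3: already a state_dict (parameter names mapped to tensors).
        if all(isinstance(v, torch.Tensor) for v in obj.values()):
            return obj
        raise KeyError(
            f"Checkpoint is a dict with keys {list(obj.keys())}; pass key=..."
        )
    raise TypeError(f"Unsupported checkpoint type: {type(obj).__name__}")
```

Usage would be along the lines of extract_state_dict(torch.load("your_checkpoint.pt", map_location="cpu"), key="generator"), where the file name is a placeholder. Printing the keys of the loaded object first tells you which case applies.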

Upvotes: 0

Views: 385

Answers (1)

Ivan

Reputation: 40738

If you load the checkpoint you linked to (hifigan_libritts100360_generator0p5.pt), you will see that it contains a single key/value pair: the key "generator", whose value is the generator's state dictionary. In other words, the state_dict you are after is pt['generator']:

>>> import torch
>>> pt = torch.load('hifigan_libritts100360_generator0p5.pt', map_location='cpu')
>>> pt.keys()
dict_keys(['generator'])
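If the library expects a bare state_dict rather than the wrapper dict (an assumption; check what it actually loads), you can unwrap the "generator" entry and save it on its own. In this sketch the tiny nn.Linear stands in for the real generator so the code runs without the downloaded file, and both file names are hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the downloaded checkpoint: a dict whose only key is
# "generator", mirroring the structure shown above.
checkpoint = {"generator": nn.Linear(8, 8).state_dict()}
torch.save(checkpoint, "checkpoint_with_wrapper.pt")  # hypothetical file name

# Load the wrapper dict and pull out the nested state_dict.
pt = torch.load("checkpoint_with_wrapper.pt", map_location="cpu")
generator_state = pt["generator"]

# Save the bare state_dict so code expecting one can load it directly.
torch.save(generator_state, "generator_state_dict.pt")  # hypothetical output path
```

If the library instead reads the "generator" key itself, the original file can likely be passed in unchanged.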

Upvotes: 1
