Reputation: 1577
I'm new to PyTorch and to model/AI programming in general.
I have a library that needs a checkpoint in the form of a state_dict from a model to run.
I have the .pt model (more precisely, the RADTTS pre-trained model) and I need to extract the state dictionary from it to use as the checkpoint.
From what I understand from the PyTorch documentation, I should be able to load the model and save the state_dict with torch.save(model.state_dict(), PATH).
First of all, is that correct? And if so, how do I load the model in PyTorch to extract the state_dict?
Upvotes: 0
Views: 385
Reputation: 40738
If you load the checkpoint you linked to (hifigan_libritts100360_generator0p5.pt), you will see that the file contains a single key/value pair: the key "generator", whose value is the state dictionary itself.
>>> import torch
>>> pt = torch.load('hifigan_libritts100360_generator0p5.pt')
>>> pt.keys()
dict_keys(['generator'])
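Since the state dictionary sits under that one key, you can pull it out and save it on its own. Here is a minimal sketch of that. The helper name extract_state_dict and its key parameter are my own placeholders, not part of PyTorch or RADTTS, and it assumes the file holds plain tensors that torch.load can read on CPU.

```python
import torch


def extract_state_dict(checkpoint_path, output_path, key="generator"):
    """Load a checkpoint and save the state_dict found under `key` by itself.

    Hypothetical helper for illustration; `key` defaults to the "generator"
    entry seen in the checkpoint discussed above.
    """
    # map_location="cpu" lets the file load on a machine without a GPU.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # If the file is a dict that wraps the weights under `key`, unwrap it;
    # otherwise assume (an assumption!) the file already is a state_dict.
    if isinstance(checkpoint, dict) and key in checkpoint:
        state_dict = checkpoint[key]
    else:
        state_dict = checkpoint

    # Write only the state_dict, which is what the library expects.
    torch.save(state_dict, output_path)
    return state_dict
```

For the file above, that would be something like extract_state_dict('hifigan_libritts100360_generator0p5.pt', 'generator_state_dict.pt'), where the output file name is just an example. If you have a live model object instead of a checkpoint file, torch.save(model.state_dict(), PATH) from your question is the usual way to write out its weights.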
Upvotes: 1