Reputation: 398
I've trained a Tacotron2 model, using Mozilla TTS, on a custom dataset. The trainer outputs a .pth file and a config.json file, but I'm having difficulty loading the trained model into PyTorch.
import torch
from torchaudio.models.tacotron2 import Tacotron2

tacotron2 = Tacotron2()
tacotron2.load_state_dict(torch.load('models/best_model.pth'))
RuntimeError: Error(s) in loading state_dict for Tacotron2: Missing key(s) in state_dict: "embedding.weight", "encoder.convolutions.0.0.weight", "encoder.convolutions.0.0.bias", "encoder.convolutions.0.1.weight", "encoder.convolutions.0.1.bias", "encoder.convolutions.0.1.running_mean", "encoder.convolutions.0.1.running_var", "encoder.convolutions.1.0.weight", "encoder.convolutions.1.0.bias", "encoder.convolutions.1.1.weight", "encoder.convolutions.1.1.bias", "encoder.convolutions.1.1.running_mean", "encoder.convolutions.1.1.running_var", "encoder.convolutions.2.0.weight", "encoder.convolutions.2.0.bias", "encoder.convolutions.2.1.weight", "encoder.convolutions.2.1.bias", "encoder.convolutions.2.1.running_mean", "encoder.convolutions.2.1.running_var", "encoder.lstm.weight_ih_l0", "encoder.lstm.weight_hh_l0", "encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0", "encoder.lstm.weight_ih_l0_reverse", "encoder.lstm.weight_hh_l0_reverse", "encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse", "decoder.prenet.layers.0.weight", "decoder.prenet.layers.1.weight", "decoder.attention_rnn.weight_ih", "decoder.attention_rnn.weight_hh", "decoder.attention_rnn.bias_ih", "decoder.attention_rnn.bias_hh", "decoder.attention_layer.query_layer.weight", "decoder.attention_layer.memory_layer.weight", "decoder.attention_layer.v.weight", "decoder.attention_layer.location_layer.location_conv.weight", "decoder.attention_layer.location_layer.location_dense.weight", "decoder.decoder_rnn.weight_ih", "decoder.decoder_rnn.weight_hh", "decoder.decoder_rnn.bias_ih", "decoder.decoder_rnn.bias_hh", "decoder.linear_projection.weight", "decoder.linear_projection.bias", "decoder.gate_layer.weight", "decoder.gate_layer.bias", "postnet.convolutions.0.0.weight", "postnet.convolutions.0.0.bias", "postnet.convolutions.0.1.weight", "postnet.convolutions.0.1.bias", "postnet.convolutions.0.1.running_mean", "postnet.convolutions.0.1.running_var", "postnet.convolutions.1.0.weight", "postnet.convolutions.1.0.bias", "postnet.convolutions.1.1.weight", "postnet.convolutions.1.1.bias", "postnet.convolutions.1.1.running_mean", "postnet.convolutions.1.1.running_var", "postnet.convolutions.2.0.weight", "postnet.convolutions.2.0.bias", "postnet.convolutions.2.1.weight", "postnet.convolutions.2.1.bias", "postnet.convolutions.2.1.running_mean", "postnet.convolutions.2.1.running_var", "postnet.convolutions.3.0.weight", "postnet.convolutions.3.0.bias", "postnet.convolutions.3.1.weight", "postnet.convolutions.3.1.bias", "postnet.convolutions.3.1.running_mean", "postnet.convolutions.3.1.running_var", "postnet.convolutions.4.0.weight", "postnet.convolutions.4.0.bias", "postnet.convolutions.4.1.weight", "postnet.convolutions.4.1.bias", "postnet.convolutions.4.1.running_mean", "postnet.convolutions.4.1.running_var". Unexpected key(s) in state_dict: "config", "model", "optimizer", "scaler", "step", "epoch", "date", "model_loss".
Upvotes: 1
Views: 648
Reputation: 3473
According to the error message, load_state_dict()
expects a dictionary whose keys are the named network parameters, like "decoder.attention_rnn.bias_hh" etc., i.e. the trained weights plus a way to identify them.
What the pth
file actually holds, however, appears to be a pickled Python dictionary containing a full training checkpoint: everything needed to resume training (model weights, optimizer state, step and epoch counters, and so on), not just the model weights. Judging by the "Unexpected key(s)" list, the actual state dict is most likely stored under the "model" key.
Try something like
checkpoint = torch.load('models/best_model.pth')  # full training checkpoint
tacotron2.load_state_dict(checkpoint["model"])    # only the model weights
and see what happens. If it doesn't work, inspect the keys of the checkpoint and of the nested dictionary checkpoint["model"]
and explore from there; a sketch of how to do that follows.
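A minimal sketch of that exploration, assuming the file loads with a plain torch.load and the state dict really sits under the "model" key:

import torch
from torchaudio.models.tacotron2 import Tacotron2

# Load on CPU so this works without a GPU
checkpoint = torch.load('models/best_model.pth', map_location='cpu')

# Top-level keys: per the error, these should include 'config', 'model', 'optimizer', 'step', 'epoch', ...
print(checkpoint.keys())

# Parameter names the trainer saved
print(list(checkpoint["model"].keys())[:10])

# Parameter names the torchaudio Tacotron2 expects
print(list(Tacotron2().state_dict().keys())[:10])

If the two sets of names don't line up (different prefixes or layer names), load_state_dict() won't accept the checkpoint as-is and you'll need to rename the keys first.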
If you passed any non-default arguments during training, you'll need to replicate them when you instantiate the model for loading as well (hint: that's what the config.json is for).
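As a rough sketch of what that could look like: the constructor argument names (n_mels, n_symbol) and the config keys below are assumptions on my part, so check them against the torchaudio docs and your own config.json.

import json
from torchaudio.models.tacotron2 import Tacotron2

# Read the config that the trainer wrote alongside the checkpoint
with open('config.json') as f:
    config = json.load(f)

# Map the relevant training settings onto the constructor arguments.
# config["audio"]["num_mels"] is where Mozilla TTS configs usually keep
# the mel-band count (assumed key); the symbol count depends on the
# character set you trained with.
tacotron2 = Tacotron2(
    n_mels=config["audio"]["num_mels"],
    n_symbol=148,  # replace with the size of your character set
)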
Upvotes: 0