Reputation: 1
I am fine-tuning the M2M-100 model, using the 1.2B model's last checkpoint as the starting point. But while training the model I am getting this error saying that it cannot load the parameters and that the model architectures should match:
Traceback (most recent call last):
  File "/home/krish/.local/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 75, in _wrap
    fn(i, *args)
  File "/home/krish/.local/lib/python3.10/site-packages/fairseq/distributed/utils.py", line 328, in distributed_main
    main(cfg, **kwargs)
  File "/home/krish/content/train.py", line 165, in main
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
  File "/home/krish/.local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py", line 248, in load_checkpoint
    extra_state = trainer.load_checkpoint(
  File "/home/krish/.local/lib/python3.10/site-packages/fairseq/trainer.py", line 580, in load_checkpoint
    raise Exception(
Exception: Cannot load model parameters from checkpoint /home/krish/content/1.2B_last_checkpoint.pt; please ensure that the architectures match.
The code that I executed:
train_command = 'CUDA_VISIBLE_DEVICES="0" python /home/krish/content/train.py /home/krish/content/Hindi_Marathi/wmt22_spm/wmt22_bin \
--arch transformer_wmt_en_de_big \
--task translation_multi_simple_epoch \
--finetune-from-model /home/krish/content/1.2B_last_checkpoint.pt \
--save-dir /home/krish/content/Hindi_Marathi/checkpoint \
--langs \'hi,mr\' \
--lang-pairs \'hi-mr\' \
--max-tokens 1200 \
--encoder-normalize-before --decoder-normalize-before \
--sampling-method temperature --sampling-temperature 1.5 \
--encoder-langtok src --decoder-langtok \
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
--optimizer adam --adam-eps 1e-06 --adam-betas \'(0.9, 0.98)\' \
--lr-scheduler inverse_sqrt --lr 3e-05 \
--warmup-updates 2500 --max-update 40000 \
--dropout 0.3 --attention-dropout 0.1 \
--weight-decay 0.0 \
--update-freq 2 --save-interval 5 \
--save-interval-updates 5000 --keep-interval-updates 3 \
--no-epoch-checkpoints \
--seed 222 \
--log-format simple \
--log-interval 2 \
--encoder-layers 12 --decoder-layers 12 \
--encoder-layerdrop 0.05 --decoder-layerdrop 0.05 \
--share-decoder-input-output-embed \
--share-all-embeddings \
--ddp-backend no_c10d'
This is the training script. I have checked the model, and the arch is "transformer_wmt_en_de_big". How should I proceed to resolve this?
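For reference, one way to double-check what the checkpoint actually contains is to inspect the stored weight tensors and count the layers and the embedding size, then compare them with the --encoder-layers/--decoder-layers flags and the arch defaults above. A minimal sketch, assuming the standard fairseq checkpoint layout with the weights under the 'model' key and the usual encoder.layers.N.* / decoder.layers.N.* parameter names:

import torch

# Minimal inspection sketch (assumes the usual fairseq checkpoint layout,
# with the model weights stored under the 'model' key).
ckpt = torch.load('/home/krish/content/1.2B_last_checkpoint.pt', map_location='cpu')
state = ckpt['model']

# The embedding shape shows the vocabulary size and embedding dimension.
print('encoder.embed_tokens:', tuple(state['encoder.embed_tokens.weight'].shape))

# Counting distinct layer indices shows the encoder/decoder depth.
enc = {k.split('.')[2] for k in state if k.startswith('encoder.layers.')}
dec = {k.split('.')[2] for k in state if k.startswith('decoder.layers.')}
print('encoder layers:', len(enc), '| decoder layers:', len(dec))

If these numbers do not agree with the command-line flags, the checkpoint cannot be loaded even when the --arch name itself is the same.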
Upvotes: 0
Views: 99
Reputation: 1
import torch

# Load the checkpoint on CPU and read the architecture it was trained with.
checkpoint = torch.load('/home/prajna/krish/content/1.2B_last_checkpoint.pt', map_location='cpu')
print(checkpoint['args'].arch)
This will print the architecture that was used to train the checkpoint; it has to match the --arch you pass to train.py.
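Note that depending on the fairseq version that saved the checkpoint, the configuration may be stored under a 'cfg' key instead of 'args'. A hedged variant that checks both (the cfg['model'] layout and the '_name' field are assumptions about the newer format):

import torch

checkpoint = torch.load('/home/prajna/krish/content/1.2B_last_checkpoint.pt', map_location='cpu')

if checkpoint.get('args') is not None:
    # Older fairseq checkpoints store an argparse Namespace under 'args'.
    print(checkpoint['args'].arch)
elif 'cfg' in checkpoint:
    # Newer checkpoints keep the model settings under cfg['model'];
    # '_name' usually holds the architecture name when 'arch' is absent.
    model_cfg = checkpoint['cfg']['model']
    print(getattr(model_cfg, 'arch', None) or getattr(model_cfg, '_name', None))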
Upvotes: 0