KRISH MANTRI

Reputation: 1

Exception: Cannot load model parameters from checkpoint /home/krish/content/1.2B_last_checkpoint.pt; please ensure that the architectures match

I am fine-tuning the M2M model, using the 1.2B model as the last checkpoint. While training, I get this error saying that it cannot load the parameters and that the model architectures must match.

Traceback (most recent call last):
  File "/home/krish/.local/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 75, in _wrap
    fn(i, *args)
  File "/home/krish/.local/lib/python3.10/site-packages/fairseq/distributed/utils.py", line 328, in distributed_main
    main(cfg, **kwargs)
  File "/home/krish/content/train.py", line 165, in main
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
  File "/home/krish/.local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py", line 248, in load_checkpoint
    extra_state = trainer.load_checkpoint(
  File "/home/krish/.local/lib/python3.10/site-packages/fairseq/trainer.py", line 580, in load_checkpoint
    raise Exception(
Exception: Cannot load model parameters from checkpoint /home/krish/content/1.2B_last_checkpoint.pt; please ensure that the architectures match.

The code that I executed:

try:

train_command = 'CUDA_VISIBLE_DEVICES="0" python /home/krish/content/train.py /home/krish/content/Hindi_Marathi/wmt22_spm/wmt22_bin \
        --arch transformer_wmt_en_de_big \
        --task translation_multi_simple_epoch \
        --finetune-from-model /home/krish/content/1.2B_last_checkpoint.pt \
        --save-dir /home/krish/content/Hindi_Marathi/checkpoint \
        --langs \'hi,mr\' \
        --lang-pairs \'hi-mr\' \
        --max-tokens 1200 \
        --encoder-normalize-before --decoder-normalize-before \
        --sampling-method temperature --sampling-temperature 1.5 \
        --encoder-langtok src --decoder-langtok \
        --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
        --optimizer adam --adam-eps 1e-06 --adam-betas \'(0.9, 0.98)\' \
        --lr-scheduler inverse_sqrt --lr 3e-05 \
        --warmup-updates 2500 --max-update 40000 \
        --dropout 0.3 --attention-dropout 0.1 \
        --weight-decay 0.0 \
        --update-freq 2 --save-interval 5 \
        --save-interval-updates 5000 --keep-interval-updates 3 \
        --no-epoch-checkpoints \
        --seed 222 \
        --log-format simple \
        --log-interval 2 \
        --encoder-layers 12 --decoder-layers 12 \
        --encoder-layerdrop 0.05 --decoder-layerdrop 0.05 \
        --share-decoder-input-output-embed \
        --share-all-embeddings \
        --ddp-backend no_c10d'

This is the training script. I have checked the model, and its arch is "transformer_wmt_en_de_big". How should I proceed to resolve this?

Upvotes: 0

Views: 99

Answers (1)

Faouzia V

Reputation: 1

import torch

# Load the checkpoint on the CPU and inspect the architecture it was trained with
checkpoint = torch.load('/home/prajna/krish/content/1.2B_last_checkpoint.pt', map_location='cpu')
print(checkpoint['args'].arch)

This will give you the architecture that was used to train the checkpoint; it needs to match the --arch flag in your fine-tuning command.
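Note that newer fairseq versions store the checkpoint configuration under a 'cfg' key rather than 'args', in which case checkpoint['args'] is None. Below is a minimal sketch, assuming the checkpoint follows the standard fairseq layout, that handles both cases and prints the fields that have to agree with the fine-tuning flags:

import torch

# Load on CPU; only the stored training configuration is needed, not the weights.
ckpt = torch.load('/home/krish/content/1.2B_last_checkpoint.pt', map_location='cpu')

# Older fairseq checkpoints keep the full argparse Namespace under 'args';
# newer (Hydra-based) checkpoints keep a config object under 'cfg' instead.
train_args = ckpt.get('args')
if train_args is not None:
    # Compare these against the flags in the fine-tuning command:
    # --arch, --encoder-layers/--decoder-layers, embedding and FFN sizes.
    print(train_args.arch)
    print(train_args.encoder_layers, train_args.decoder_layers)
    print(train_args.encoder_embed_dim, train_args.encoder_ffn_embed_dim)
else:
    # Newer checkpoint format: dump the model section of the stored config.
    print(ckpt['cfg']['model'])

If the values printed here differ from what the fine-tuning command builds (for example a different layer count than --encoder-layers 12 --decoder-layers 12), that mismatch is what triggers the "please ensure that the architectures match" exception.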

Upvotes: 0
