Ilia

Reputation: 349

How to load Hugging Face's BERT after fine-tuning with PyTorch Lightning?

I fine-tuned a pre-trained BERT model from Hugging Face on a custom dataset for 10 epochs using PyTorch Lightning, logging with the Weights & Biases logger.

When I load from the checkpoint like so:

model.load_from_checkpoint("/path/to/checkpoint/epoch=9-step=590.ckpt")

I get a warning saying that some of the weights were newly initialized:

Some weights of the model checkpoint at tbs17/MathBERT were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at tbs17/MathBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

In my case, I fine-tuned and loaded BertForSequenceClassification and expect it to be identical to the one I fine-tuned. How do I make sure that the weights are not re-initialized like that?
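Two facts about PyTorch Lightning's `load_from_checkpoint` help explain the warning. First, it is a classmethod. It builds a new module, which runs `__init__`, and then loads the checkpoint's `state_dict` into that module. Assuming your module's `__init__` calls `BertForSequenceClassification.from_pretrained("tbs17/MathBERT", ...)`, the warning is printed during construction, before the fine-tuned weights overwrite the fresh ones, so the warning alone does not show that fine-tuning was lost. Second, it returns the new module, so the return value has to be kept, e.g. `model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint/epoch=9-step=590.ckpt")`, where `MyLightningModule` is a placeholder for your own class. The toy below does not import Lightning; `ToyClassifier` is hypothetical and only mimics that order of operations:

```python
class ToyClassifier:
    """Hypothetical stand-in for a LightningModule wrapping BertForSequenceClassification."""

    def __init__(self):
        # In the real module, __init__ would call from_pretrained(...), which is where the
        # "newly initialized: ['classifier.bias', 'classifier.weight']" message comes from.
        self.state = {"classifier.weight": 0.0}

    def load_state_dict(self, state):
        # Overwrite the freshly initialized values with the saved ones.
        self.state.update(state)

    @classmethod
    def load_from_checkpoint(cls, checkpoint):
        # Build a NEW instance, then restore the saved weights into it.
        model = cls()
        model.load_state_dict(checkpoint["state_dict"])
        return model


checkpoint = {"state_dict": {"classifier.weight": 0.42}}  # hypothetical fine-tuned value

model = ToyClassifier()
model.load_from_checkpoint(checkpoint)   # return value discarded
print(model.state["classifier.weight"])  # 0.0: `model` still holds the fresh weights

model = ToyClassifier.load_from_checkpoint(checkpoint)
print(model.state["classifier.weight"])  # 0.42: the restored weights
```

If the instance the code goes on to use is the one whose return value was discarded, it never received the checkpoint's weights.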

Thanks.

P.S. This answer describes a similar issue and states that the warning can be ignored. Still, is there a way to verify that loading the checkpoint did not discard the fine-tuned weights?

Upvotes: 1

Views: 1080

Answers (1)

Aniket Maurya

Reputation: 380

Run a prediction on a fixed sample right after training and save the result. To test checkpoint loading, run the reloaded model on the same sample and check that its output matches the saved result.
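Below is a minimal sketch of that check using only the standard library. `save_reference` and `matches_reference` are hypothetical helper names, and the logits are made-up numbers standing in for the fine-tuned model's output on one fixed input. In a real run, call `model.eval()` and run the forward pass under `torch.no_grad()` before recording outputs, because BERT's dropout layers otherwise make repeated passes differ.

```python
import json
import math
import tempfile
from pathlib import Path


def save_reference(outputs, path):
    """Write the model outputs for a fixed sample to disk as JSON."""
    Path(path).write_text(json.dumps(outputs))


def matches_reference(outputs, path, abs_tol=1e-5):
    """Return True if `outputs` agree element-wise (within abs_tol) with the saved reference."""
    reference = json.loads(Path(path).read_text())
    return len(reference) == len(outputs) and all(
        math.isclose(r, o, abs_tol=abs_tol) for r, o in zip(reference, outputs)
    )


# Hypothetical logits for one fixed input. In practice these could be the `.logits`
# returned by BertForSequenceClassification, converted with `.tolist()`.
logits_after_training = [2.314, -1.087]
logits_after_reload = [2.314, -1.087]     # same weights -> same outputs
logits_reinitialized = [0.021, -0.013]    # fresh classifier head -> different outputs

with tempfile.TemporaryDirectory() as tmp:
    ref_path = Path(tmp) / "reference_logits.json"
    save_reference(logits_after_training, ref_path)
    print(matches_reference(logits_after_reload, ref_path))   # True
    print(matches_reference(logits_reinitialized, ref_path))  # False
```

The comparison uses a small tolerance rather than exact equality because running the reloaded model on a different device (CPU vs. GPU) can change the last few digits of the logits.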

Upvotes: -1
