Reputation: 349
I fine-tuned a pre-trained BERT model from Hugging Face on a custom dataset for 10 epochs using PyTorch Lightning, logging with the Weights & Biases logger.
When I load the model from a checkpoint like so:
model.load_from_checkpoint("/path/to/checkpoint/epoch=9-step=590.ckpt")
I get a warning saying that some weights were re-initialized:
Some weights of the model checkpoint at tbs17/MathBERT were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at tbs17/MathBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
In my case, I fine-tuned and loaded BertForSequenceClassification and expect it to be identical to the one I fine-tuned. How do I make sure that the weights are not re-initialized like that?
Thanks.
P.S. This answer describes a similar issue and states that the warning can be ignored. Still, is there a way to verify that loading the checkpoint did not discard the fine-tuned weights?
Upvotes: 1
Views: 1080
Reputation: 380
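You can make a prediction on a fixed sample with the fine-tuned model and save the results. To test checkpoint loading, run the loaded model on the same sample (in eval mode, so dropout is disabled) and check that its output matches the saved results.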
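The check described above can be sketched as follows. This is a minimal illustration in plain PyTorch, not the asker's actual setup: `TinyClassifier`, the random `sample`, and the `model.pt` file name are hypothetical stand-ins for the fine-tuned `BertForSequenceClassification`, a real input batch, and the Lightning checkpoint. The warning is most likely printed by `from_pretrained` when the module is constructed, before the checkpoint's state dict is loaded on top of it, so comparing outputs and weights after loading is what confirms the fine-tuned weights were restored. With Lightning, keep in mind that `load_from_checkpoint` is a classmethod that returns a new model, so its return value has to be assigned.

```python
import os
import tempfile

import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the fine-tuned BertForSequenceClassification."""

    def __init__(self, in_features: int = 8, num_labels: int = 2):
        super().__init__()
        self.encoder = nn.Linear(in_features, 16)
        self.classifier = nn.Linear(16, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(torch.relu(self.encoder(x)))


def reference_logits(model: nn.Module, sample: torch.Tensor) -> torch.Tensor:
    """Run the model deterministically: eval mode, no gradients."""
    model.eval()
    with torch.no_grad():
        return model(sample)


torch.manual_seed(0)
sample = torch.randn(4, 8)  # fixed input, reused before and after loading

# 1. "Fine-tuned" model: record its output on the fixed sample and save both.
trained = TinyClassifier()
expected = reference_logits(trained, sample)

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "model.pt")  # hypothetical file name
    torch.save({"state_dict": trained.state_dict(), "expected": expected}, ckpt_path)

    # 2. Fresh model: its weights are newly initialized, as the warning describes.
    restored = TinyClassifier()
    before_load = reference_logits(restored, sample)

    # 3. Loading the state dict overwrites the newly initialized weights.
    ckpt = torch.load(ckpt_path)
    restored.load_state_dict(ckpt["state_dict"])
    after_load = reference_logits(restored, sample)

# Outputs on the fixed sample should match the saved reference.
outputs_match = torch.allclose(after_load, ckpt["expected"], atol=1e-6)

# Every restored tensor should be identical to the trained one.
trained_state = trained.state_dict()
weights_match = all(
    torch.equal(tensor, trained_state[name])
    for name, tensor in restored.state_dict().items()
)

# Sanity check: the freshly initialized model should NOT match the reference.
fresh_differs = not torch.allclose(before_load, expected)

print(outputs_match, weights_match, fresh_differs)
```

For the real model, the same comparison would apply to the logits of the module returned by `load_from_checkpoint`. Pay particular attention to `classifier.weight` and `classifier.bias`, since those are the tensors the warning reports as newly initialized.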
Upvotes: -1