Reputation: 399
I’m trying to train a BERT model from scratch on my own dataset using the HuggingFace library. I would like to train the model so that it has the exact architecture of the original BERT model.
In the original paper, it is stated that: “BERT is trained on two tasks: predicting randomly masked tokens (MLM) and predicting whether two sentences follow each other (NSP). SCIBERT follows the same architecture as BERT but is instead pretrained on scientific text.”
I’m trying to understand how to train the model on the two tasks above. At the moment, I have initialised the model as below:
from transformers import BertForMaskedLM
model = BertForMaskedLM(config=config)
However, that would just be for MLM and not NSP. How can I initialize and train the model with NSP as well, or is my original approach fine as it is?
My assumptions would be either:

1. Initialize with BertForPreTraining (for both MLM and NSP), OR
2. After finishing training with BertForMaskedLM, initialize the same model and train again with BertForNextSentencePrediction (but this approach’s computation and resources would cost twice…)
I’m not sure which one is the correct way. Any insights or advice would be greatly appreciated.
Upvotes: 14
Views: 13206
Reputation: 1491
I would suggest doing the following:
First, pre-train BERT on the MLM objective. HuggingFace provides a script specifically for training BERT on the MLM objective on your own data. You can find it here. As you can see in the run_mlm.py script, they use AutoModelForMaskedLM, and you can specify any architecture you want.
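For reference, a minimal sketch of roughly what the script does when training from scratch (a default BertConfig already corresponds to the BERT-base architecture; adjust it if you need different dimensions):

from transformers import BertConfig, AutoModelForMaskedLM

# Default BertConfig = BERT-base (12 layers, hidden size 768, 12 attention heads)
config = BertConfig()

# Build an untrained masked-LM model from the config; no pretrained weights are loaded
model = AutoModelForMaskedLM.from_config(config)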
Second, if you want to train on the next sentence prediction task, you can define a BertForPreTraining model (which has both the MLM and NSP heads on top), load in the weights from the model you trained in step 1, and then further pre-train it on the next sentence prediction task.
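A sketch of that second step (the checkpoint path is a placeholder for wherever step 1 saved your MLM model): from_pretrained will reuse the shared encoder and MLM weights, and the NSP head will be reported as newly initialized in the loading warning.

from transformers import BertForPreTraining

# Loads the encoder + MLM head trained in step 1; the NSP head is initialized randomly
model = BertForPreTraining.from_pretrained("/path/to/your/mlm/checkpoint")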
UPDATE: apparently the next sentence prediction task did help improve the performance of BERT on some GLUE tasks. See this talk by the author of BERT.
Upvotes: 11
Reputation: 660
You can easily train BERT from scratch on both the MLM and NSP tasks using a combination of BertForPreTraining, TextDatasetForNextSentencePrediction, DataCollatorForLanguageModeling and Trainer.
I wouldn't suggest first training your model on MLM and then on NSP, which might lead to catastrophic forgetting: the model basically forgets what it learnt from the previous training.
from transformers import BertTokenizer

# Load the cased tokenizer you prepared for the new domain
bert_cased_tokenizer = BertTokenizer.from_pretrained("/path/to/pre-trained/tokenizer/for/new/domain", do_lower_case=False)
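If you don't have a tokenizer for your new domain yet, one way to build a cased WordPiece vocabulary is with the tokenizers library. This is just a sketch; the corpus path, vocabulary size and output directory are placeholders you'd replace with your own:

from tokenizers import BertWordPieceTokenizer

# Train a cased WordPiece vocabulary on your raw domain text
domain_tokenizer = BertWordPieceTokenizer(lowercase=False)
domain_tokenizer.train(files=["/path/to/raw/domain/corpus.txt"], vocab_size=30522)

# Writes vocab.txt, which BertTokenizer.from_pretrained can then load
domain_tokenizer.save_model("/path/to/pre-trained/tokenizer/for/new/domain")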
BertForPreTraining
from transformers import BertConfig, BertForPreTraining
config = BertConfig()
model = BertForPreTraining(config)
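Note that BertConfig() with its default values already matches the original BERT-base architecture. If you want to be explicit about the dimensions, and in particular set vocab_size to match your own tokenizer, a sketch could look like this (the numbers below are simply the BERT-base defaults):

from transformers import BertConfig, BertForPreTraining

# Explicit BERT-base dimensions; vocab_size is taken from your domain tokenizer
config = BertConfig(
    vocab_size=len(bert_cased_tokenizer),
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
)
model = BertForPreTraining(config)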
TextDatasetForNextSentencePrediction will tokenize the text and create the labels for the sentence pairs. Your dataset should be in the following format (or you could just modify the existing code):
(1) One sentence per line. These should ideally be actual sentences.
(2) A blank line between documents.
Sentence-1 From Document-1
Sentence-2 From Document-1
Sentence-3 From Document-1
...

Sentence-1 From Document-2
Sentence-2 From Document-2
Sentence-3 From Document-2
from transformers import TextDatasetForNextSentencePrediction

# Builds sentence-pair examples from the file above; by default, half of the
# "next sentences" are sampled randomly from other documents
dataset = TextDatasetForNextSentencePrediction(
    tokenizer=bert_cased_tokenizer,
    file_path="/path/to/your/dataset",
    block_size=256,
)
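Depending on your transformers version, each example produced this way is a dictionary with input_ids, token_type_ids and a next_sentence_label, which is what BertForPreTraining expects; a quick sanity check might look like:

# Inspect the first example (exact keys may vary slightly across transformers versions)
example = dataset[0]
print(example.keys())
print(example["next_sentence_label"])  # 0 = real next sentence, 1 = random sentence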
DataCollatorForLanguageModeling does the masking and passes along the labels that are generated by TextDatasetForNextSentencePrediction. DataCollatorForNextSentencePrediction has been removed, since it was doing the same thing as DataCollatorForLanguageModeling.
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=bert_cased_tokenizer,
    mlm=True,
    mlm_probability=0.15,
)
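If you want to see the masking for yourself, you can run the collator on a couple of examples manually (this is roughly what Trainer does internally for each batch); positions that were not selected for masking get a label of -100 so they are ignored by the MLM loss:

# Optional sanity check: collate two examples and inspect the tensors
batch = data_collator([dataset[0], dataset[1]])
print(batch["input_ids"][0])  # some token ids replaced by the [MASK] id
print(batch["labels"][0])     # original ids at the selected positions, -100 elsewhere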
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="/path/to/output/dir/for/training/arguments",
    overwrite_output_dir=True,
    num_train_epochs=2,
    per_gpu_train_batch_size=16,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model("path/to/your/model")
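One small follow-up worth adding: save the tokenizer next to the model so the whole checkpoint can be reloaded with from_pretrained later; the sequence-classification class below is just one example of a downstream head you might fine-tune afterwards.

# Save the tokenizer alongside the model weights and config
bert_cased_tokenizer.save_pretrained("path/to/your/model")

# Later, load the pre-trained encoder with a task-specific head for fine-tuning
from transformers import BertForSequenceClassification
finetune_model = BertForSequenceClassification.from_pretrained("path/to/your/model", num_labels=2)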
Upvotes: 30