Clock Slave

Reputation: 7957

Training over an already trained transformer model

I took a pretrained BERT model and fine-tuned it for text classification on a dataset (~3M records, 46 categories).

Now I want to add some data (~5k records, 10 categories) to the model while keeping the original 46 categories. I just want the model to be up to date with the latest data.

I want to avoid retraining on the full (3M + 5k) dataset because of the time and cost involved, and also because this can be a recurring activity (3-4 times a week).

Is there a way to do this?

Below is my code setup. I am using HF's Trainer.

# imports
import torch
from transformers import TrainingArguments, Trainer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# constants
device = torch.device("cuda")
MODEL_NAME = 'bert-large-uncased'
TRAINING_EPOCHS = 20
TRAINING_BATCH_SIZE = 400
EVAL_BATCH_SIZE = 100

# datasets built from pandas dataframes (Dataset is a custom torch Dataset
# wrapping the tokenized inputs and the encoded labels)
tr_dataset = Dataset(x_tr, tr_df.label_encoded.values.tolist())
te_dataset = Dataset(x_te, te_df.label_encoded.values.tolist())

# download the base model and tokenizer (n_out = number of categories, i.e. 46)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=n_out).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# define training args
args = TrainingArguments(
    output_dir=SAVE_BERT_PATH,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    save_strategy="no",
    per_device_train_batch_size=TRAINING_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    num_train_epochs=TRAINING_EPOCHS,
    seed=42,
    fp16=True,
    dataloader_num_workers=10,
    load_best_model_at_end=False,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    logging_strategy="epoch",
    logging_first_step=True,
)

# define trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tr_dataset,
    eval_dataset=te_dataset,
    compute_metrics=compute_metrics)

# train and eval
trainer.train()
trainer.evaluate()

Upvotes: 2

Views: 947

Answers (1)

Hyeoung Ho Bae

Reputation: 21

In case you haven't found the answer yet: you can just load your trained model from the saved location (written with trainer.save_model()) or from a checkpoint. In either case the model is in its latest trained state, and you can then train it further on the new data.
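A minimal sketch of that flow, assuming the first run saved the model with trainer.save_model() into SAVE_BERT_PATH, and with new_tr_dataset / new_te_dataset as placeholder names for datasets built from the new records the same way as in the question:

from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer

# Load the previously fine-tuned model from disk. SAVE_BERT_PATH is the
# directory the earlier Trainer saved to; the classification head with
# all 46 labels is restored as-is, so no categories are lost.
model = AutoModelForSequenceClassification.from_pretrained(SAVE_BERT_PATH).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)  # tokenizer is unchanged

# Continue training on just the new ~5k records. new_tr_dataset and
# new_te_dataset are hypothetical names for the new data wrapped the
# same way as tr_dataset / te_dataset above.
trainer = Trainer(
    model=model,
    args=args,                      # the original TrainingArguments can be reused
    train_dataset=new_tr_dataset,
    eval_dataset=new_te_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()       # resumes from the fine-tuned weights, not from scratch
trainer.save_model()  # persist the updated model for the next cycle

One caveat: training only on the new slice can degrade performance on the original 46 categories (catastrophic forgetting); a low learning rate, few epochs, or mixing a sample of the old data into the new batches helps preserve the original behavior.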

Upvotes: 1
