finiteautomata

Reputation: 3813

Optimizer and scheduler for BERT fine-tuning

I'm trying to fine-tune a model with BERT (using transformers library), and I'm a bit unsure about the optimizer and scheduler.

First, I understand that I should use transformers.AdamW instead of PyTorch's version of it. Also, we should use a warmup scheduler as suggested in the paper, so the scheduler is created using the get_linear_schedule_with_warmup function from the transformers package.

The main questions I have are:

  1. get_linear_schedule_with_warmup has to be given a number of warmup steps. Is it OK to use 2 warmup steps out of 10 epochs?
  2. When should I call scheduler.step()? If I call it after train, the learning rate is zero for the first epoch. Should I call it for each batch?

Am I doing something wrong with this?

import torch
import torch.nn as nn
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup

N_EPOCHS = 10

model = BertGRUModel(finetune_bert=True,...)
num_training_steps = N_EPOCHS+1
num_warmup_steps = 2
warmup_proportion = float(num_warmup_steps) / float(num_training_steps)  # 2 / 11, about 0.18

optimizer = AdamW(model.parameters())
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([class_weights[1]]))


scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, 
    num_training_steps=num_training_steps
)

for epoch in range(N_EPOCHS):
    scheduler.step() #If I do after train, LR = 0 for the first epoch
    print(optimizer.param_groups[0]["lr"])

    train(...) # here we call optimizer.step()
    evaluate(...)

My model and training routine (quite similar to this notebook):

import torch
import torch.nn as nn


class BERTGRUSentiment(nn.Module):
    def __init__(self,
                 bert,
                 hidden_dim,
                 output_dim,
                 n_layers=1, 
                 bidirectional=False,
                 finetune_bert=False,
                 dropout=0.2):

        super().__init__()

        self.bert = bert

        embedding_dim = bert.config.to_dict()['hidden_size']

        self.finetune_bert = finetune_bert

        self.rnn = nn.GRU(embedding_dim,
                          hidden_dim,
                          num_layers = n_layers,
                          bidirectional = bidirectional,
                          batch_first = True,
                          dropout = 0 if n_layers < 2 else dropout)

        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)        
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):    
        #text = [batch size, sent len]

        if not self.finetune_bert:
            with torch.no_grad():
                embedded = self.bert(text)[0]
        else:
            embedded = self.bert(text)[0]
        #embedded = [batch size, sent len, emb dim]
        _, hidden = self.rnn(embedded)

        #hidden = [n layers * n directions, batch size, hid dim]

        if self.rnn.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))
        else:
            hidden = self.dropout(hidden[-1,:,:])

        #hidden = [batch size, hid dim]

        output = self.out(hidden)

        #output = [batch size, out dim]

        return output
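The hidden[-2] / hidden[-1] indexing in forward relies on the shape of the GRU's final hidden state. A quick shape check with toy sizes (the sizes and the random tensor standing in for BERT's output are made up for illustration) shows that h_n is [n layers * n directions, batch size, hid dim] even with batch_first=True, and that the last layer's forward and backward states concatenate to [batch size, 2 * hid dim]:

```python
import torch
import torch.nn as nn

# Toy sizes chosen for illustration only (not taken from the real model)
batch_size, seq_len, emb_dim, hid_dim, n_layers = 3, 5, 8, 16, 2

rnn = nn.GRU(emb_dim, hid_dim, num_layers=n_layers,
             bidirectional=True, batch_first=True)
embedded = torch.randn(batch_size, seq_len, emb_dim)  # stands in for BERT's output

_, hidden = rnn(embedded)
# h_n is [n layers * n directions, batch size, hid dim]; batch_first does not change it
print(hidden.shape)

# Last layer: index -2 is the forward direction, -1 the backward direction
last = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
print(last.shape)  # [batch size, 2 * hid dim]
```

This is why self.out takes hidden_dim * 2 inputs when bidirectional is set.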


import torch
from sklearn.metrics import accuracy_score, f1_score


def train(model, iterator, optimizer, criterion, max_grad_norm=None):
    """
    Trains the model for one full epoch
    """
    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for i, batch in enumerate(iterator):
        optimizer.zero_grad()
        text, lens = batch.text

        predictions = model(text)

        target = batch.target

        loss = criterion(predictions.squeeze(1), target)

        prob_predictions = torch.sigmoid(predictions)

        preds = torch.round(prob_predictions).detach().cpu()
        acc = accuracy_score(preds, target.cpu())

        loss.backward()
        # Gradient clipping
        if max_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += float(acc)  # accuracy_score returns a scalar, not a tensor

    return epoch_loss / len(iterator), epoch_acc / len(iterator)


Upvotes: 13

Views: 38242

Answers (2)

pashok3ddd

Reputation: 335

Here you can see a visualization of how the learning rate changes when using get_linear_schedule_with_warmup.

Referring to this comment: warmup steps are used to keep the learning rate low at the start of training, which reduces the risk of the model being pushed off course by its sudden exposure to a new dataset.

By default (for example, in the Hugging Face example scripts), the number of warmup steps is 0.

After the warmup, you take bigger steps, because you are probably not yet near a minimum. As you approach the minimum, the learning rate decays linearly, so you take smaller steps to converge to it.
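To make the shape concrete, here is a plain-Python sketch of the factor that multiplies the base learning rate at each scheduler step. It is written to mirror the behaviour described above (linear ramp from 0 to 1 during warmup, then linear decay to 0 at num_training_steps); the function name linear_warmup_multiplier is hypothetical, and this is not the library's own code:

```python
def linear_warmup_multiplier(current_step, num_warmup_steps, num_training_steps):
    """Factor applied to the base learning rate at a given scheduler step."""
    if current_step < num_warmup_steps:
        # Warmup: ramp linearly from 0 up to 1
        return current_step / max(1, num_warmup_steps)
    # Decay: fall linearly from 1 down to 0 at num_training_steps, never below 0
    return max(0.0, (num_training_steps - current_step)
               / max(1, num_training_steps - num_warmup_steps))


base_lr = 2e-5  # illustrative value
for step in range(11):
    print(step, base_lr * linear_warmup_multiplier(step, num_warmup_steps=2,
                                                   num_training_steps=10))
```

Note that the factor is 0 at step 0. That is why the learning rate reads as zero until scheduler.step() has been called at least once, and why stepping only once per epoch leaves the whole first epoch at LR 0.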

Also, note that the number of training steps is the number of batches per epoch times the number of epochs, not just the number of epochs. So num_training_steps = N_EPOCHS+1 is not correct, unless your batch_size is equal to the training set size.
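As a worked example of that arithmetic, the sketch below counts scheduler steps from a dataset size and batch size, and derives the warmup steps from a proportion (0.1 here, matching the proportion the question's code comment aims for). The helper name count_training_steps and the numbers are hypothetical; with a PyTorch DataLoader, len(train_loader) already gives the batches per epoch:

```python
import math


def count_training_steps(num_examples, batch_size, num_epochs, warmup_proportion=0.1):
    # One optimizer/scheduler step per batch; the last batch may be partial
    batches_per_epoch = math.ceil(num_examples / batch_size)
    num_training_steps = batches_per_epoch * num_epochs
    num_warmup_steps = int(warmup_proportion * num_training_steps)
    return num_training_steps, num_warmup_steps


# Hypothetical dataset and batch sizes, for illustration only
print(count_training_steps(num_examples=8000, batch_size=32, num_epochs=10))
```

Here 8000 / 32 gives 250 batches per epoch, so 10 epochs are 2500 scheduler steps, of which 250 are warmup.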

You call scheduler.step() every batch, right after optimizer.step(), to update the learning rate.
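A minimal sketch of that per-batch pattern, assuming a toy linear model and random data in place of BERT and the real iterator (both hypothetical). To keep it self-contained it uses torch.optim.AdamW (recent transformers versions deprecate their own AdamW in its favour) and a LambdaLR with a linear warmup/decay lambda; in a real script you would likely call transformers' get_linear_schedule_with_warmup(optimizer, num_warmup_steps=..., num_training_steps=...) instead, which as far as I know returns a LambdaLR with this same shape:

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

torch.manual_seed(0)

# Toy model and data standing in for BERT and the real dataset (hypothetical)
model = nn.Linear(4, 1)
data = [(torch.randn(8, 4), torch.randint(0, 2, (8,)).float()) for _ in range(5)]
criterion = nn.BCEWithLogitsLoss()

n_epochs = 2
num_training_steps = len(data) * n_epochs  # batches per epoch * epochs
num_warmup_steps = 2


def lr_lambda(step):
    # Linear warmup from 0 to 1, then linear decay to 0
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step)
               / max(1, num_training_steps - num_warmup_steps))


optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = LambdaLR(optimizer, lr_lambda)

lrs = []  # learning rate used for each update
for epoch in range(n_epochs):
    model.train()
    for text, target in data:
        optimizer.zero_grad()
        loss = criterion(model(text).squeeze(1), target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()
        scheduler.step()  # once per batch, right after optimizer.step()
```

With this ordering the very first update uses LR 0 (the warmup factor at step 0), the LR peaks after num_warmup_steps updates, and it reaches 0 exactly when training ends.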

Upvotes: 18

dennlinger

Reputation: 11430

I think it is hardly possible to give a 100% perfect answer, but you can certainly get inspiration from the way other scripts do it. The best place to start is the examples/ directory of the Hugging Face repository itself, where you can, for example, find this excerpt:

if (step + 1) % args.gradient_accumulation_steps == 0:
    if args.fp16:
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

    optimizer.step()
    scheduler.step()  # Update learning rate schedule
    model.zero_grad()
    global_step += 1

If we look at the surrounding parts, this updates the LR schedule every time the optimizer takes a step: once per batch, or once every args.gradient_accumulation_steps batches when gradient accumulation is used. In the same example, you can also look at the default value for warmup_steps, which is 0. From my understanding, warmup is not necessarily required when fine-tuning, but I am less certain about this aspect and would check other scripts as well.
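Because the scheduler in that excerpt only steps inside the if block, num_training_steps should count optimizer updates rather than batches when gradient accumulation is on. A small sketch (function names are hypothetical) that counts updates both by formula and by replaying the excerpt's (step + 1) % gradient_accumulation_steps condition, assuming step restarts at 0 each epoch:

```python
def optimizer_steps(batches_per_epoch, num_epochs, gradient_accumulation_steps=1):
    """Number of optimizer.step()/scheduler.step() calls over the whole run."""
    # Leftover batches at the end of an epoch do not trigger an update
    return (batches_per_epoch // gradient_accumulation_steps) * num_epochs


def optimizer_steps_by_simulation(batches_per_epoch, num_epochs, gradient_accumulation_steps=1):
    count = 0
    for _ in range(num_epochs):
        for step in range(batches_per_epoch):
            # Same condition as the excerpt above
            if (step + 1) % gradient_accumulation_steps == 0:
                count += 1
    return count


print(optimizer_steps(250, 10, 4), optimizer_steps_by_simulation(250, 10, 4))
```

With 250 batches per epoch and 4 accumulation steps, only 62 updates happen per epoch, so passing 2500 as num_training_steps would stop the linear decay far short of 0.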

Upvotes: 2
