Desperados

Reputation: 434

PyTorch - Fine-tuning BERT - Oscillating loss - Very bad accuracy

I have been trying to train a model for vulnerability detection in source code. After a bit of searching, I thought a good starting point would be to use a pre-trained transformer model from HuggingFace with PyTorch and PyTorch Lightning. I chose DistilBERT because it was the fastest one.

I have an imbalanced dataset, approximately 70% non-vulnerable and 30% vulnerable functions.
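(I have not compensated for this imbalance in the code below. For reference, a minimal sketch of one common way to do it, weighting the positive class via nn.BCEWithLogitsLoss and assuming a hypothetical 0/1 train_labels tensor, would be something like this; this is not what my current code does.)

import torch
import torch.nn as nn

# Hypothetical labels tensor over the training set (0 = non-vulnerable, 1 = vulnerable).
train_labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=torch.float)

# pos_weight = (#negatives / #positives) up-weights the minority (vulnerable) class.
pos_weight = (train_labels == 0).sum() / (train_labels == 1).sum()

# BCEWithLogitsLoss expects raw logits, i.e. no sigmoid applied in forward().
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)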

However, my results have been very bad. The model does not seem to learn or generalize. Specifically, during training the loss oscillates heavily, accuracy stays around 70 percent, and recall is extremely low (implying that the model almost always predicts the same label).

I was wondering whether there is anything obviously problematic with my code. This is the first time I am using a pre-trained model and PyTorch Lightning, and I cannot really tell what might have gone wrong.

class Model(pl.LightningModule):
    def __init__(self, n_classes, n_training_steps, n_warmup_steps, lr, fine_tune=False):
        super().__init__()
        self.save_hyperparameters()
        self.bert = DistilBertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        for name, param in self.bert.named_parameters():
            param.requires_grad = False
        self.classifier = nn.Linear(self.bert.config.hidden_size, self.hparams.n_classes)
        self.criterion = nn.BCELoss()

    def finetune(self):
        self.fine_tune = True
        for name, param in self.bert.named_parameters():
            if  'layer.5' in name:
                param.requires_grad = True

    def forward(self, input_ids, attention_mask, labels=None):
        x = self.bert(input_ids, attention_mask=attention_mask)
        x = x.last_hidden_state[:,0,:]
        x = self.classifier(x)
        x = torch.sigmoid(x)
        x = x.squeeze(dim=-1)
        loss = 0
        if labels is not None:
            loss = self.criterion(x, labels.float())
        return loss, x

    def training_step(self, batch, batch_idx):
        enc, labels = batch
        input_ids, attention_mask = enc
        loss, outputs = self.forward(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss, 'predictions': outputs, 'labels': labels}

    def validation_step(self, batch, batch_idx):
        enc, labels = batch
        input_ids, attention_mask = enc
        loss, outputs = self.forward(input_ids, attention_mask, labels)
        r = recall(outputs[:], labels[:])
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_recall", r, prog_bar=True, logger=True)
        return {'loss': loss, 'predictions': outputs, 'labels': labels}

    def test_step(self, batch, batch_idx):
        enc, labels = batch
        input_ids, attention_mask = enc
        loss, outputs = self.forward(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss, 'predictions': outputs, 'labels': labels}

    def training_epoch_end(self, outputs):
        labels = []
        predictions = []
        for o in outputs:
            for o_labels in o['labels'].detach().cpu():
                labels.append(o_labels)
            for o_preds in o['predictions'].detach().cpu():
                predictions.append(o_preds)
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        class_recall = recall(predictions[:], labels[:])
        self.logger.experiment.add_scalar("recall/Train", class_recall, self.current_epoch)

    def validation_epoch_end(self, outputs):
        labels = []
        predictions = []
        for o in outputs:
            for o_labels in o['labels'].detach().cpu():
                labels.append(o_labels)
            for o_preds in o['predictions'].detach().cpu():
                predictions.append(o_preds)
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        class_recall = recall(predictions[:], labels[:])
        self.logger.experiment.add_scalar("recall/Validation", class_recall, self.current_epoch)

    def test_epoch_end(self, outputs):
        labels = []
        predictions = []
        for o in outputs:
            for o_labels in o['labels'].detach().cpu():
                labels.append(o_labels)
            for o_preds in o['predictions'].detach().cpu():
                predictions.append(o_preds)
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        class_recall = recall(predictions[:], labels[:])
        self.logger.experiment.add_scalar("recall/Test", class_recall, self.current_epoch)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr if self.hparams.fine_tune == False else self.hparams.lr // 100)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.n_warmup_steps,
            num_training_steps=self.hparams.n_training_steps
        )
        return dict(
            optimizer=optimizer,
            lr_scheduler=dict(
                scheduler=scheduler,
                interval='step'
            )
        )

if __name__ == "__main__":

    data_module = SourceCodeDataModule(batch_size=BATCH_SIZE)
    steps_per_epoch = len(train_loader) // BATCH_SIZE
    total_training_steps = steps_per_epoch * N_EPOCHS
    warmup_steps = total_training_steps // 5

    model = Model(
        n_classes=1,
        n_warmup_steps = warmup_steps,
        n_training_steps=total_training_steps,
        lr=2e-5
    )

    logger = TensorBoardLogger("lightning_logs", name="bert_predictor")
    early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)
    trainer = pl.Trainer(
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        callbacks=[early_stopping_callback],
        max_epochs=N_EPOCHS,
        gpus=1 if str(device).startswith('cuda') else 0,
        progress_bar_refresh_rate=30
    )

    # First just train the final layer.
    trainer.fit(model, datamodule=data_module)
    result = trainer.test(model, datamodule=data_module)
    print(f"Result when training classifier only: {result}")

    # Then train the whole model
    model = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    model.finetune()
    trainer.fit(model, datamodule=data_module)
    result = trainer.test(model, datamodule=data_module)
    print(f"Result when fine tuning: {result}")

Upvotes: 1

Views: 761

Answers (1)

razimbres

Reputation: 5015

Here,

def finetune(self):
        self.fine_tune = True
        for name, param in self.bert.named_parameters():
            if  'layer.5' in name:
                param.requires_grad = True

try unfreezing more layers at the end of the network; maybe the weights are saturated and not learning enough. Also, pay attention to the loss you are using, as well as the activation function at the output.
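For example, a rough sketch of both ideas, assuming the DistilBERT parameter names contain layer.4/layer.5 and that self.criterion is switched to nn.BCEWithLogitsLoss() (which takes raw logits), could look like this inside your Model class:

    def finetune(self):
        self.fine_tune = True
        for name, param in self.bert.named_parameters():
            # Unfreeze the last two transformer blocks instead of only the last one.
            if 'layer.5' in name or 'layer.4' in name:
                param.requires_grad = True

    def forward(self, input_ids, attention_mask, labels=None):
        x = self.bert(input_ids, attention_mask=attention_mask)
        x = x.last_hidden_state[:, 0, :]              # embedding at the first token position
        logits = self.classifier(x).squeeze(dim=-1)   # raw logits, no sigmoid here
        loss = 0
        if labels is not None:
            # BCEWithLogitsLoss applies the sigmoid internally.
            loss = self.criterion(logits, labels.float())
        return loss, torch.sigmoid(logits)            # probabilities for the metrics

BCEWithLogitsLoss combines the sigmoid and the binary cross-entropy in one numerically stable step, which tends to behave better than applying torch.sigmoid and BCELoss separately.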

Upvotes: 1
