Alexander Borochkin

Reputation: 4611

EarlyStopping callback in PyTorch Lightning problem

I am trying to train a neural network model in PyTorch Lightning, and training fails at the validation step, when the EarlyStopping callback is executed.

The relevant part of the model is below. Note in particular validation_step, which is supposed to log the metrics needed for early stopping.

class DialogActsLightningModel(pl.LightningModule):

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.model = ContextAwareDAC(
            model_name=self.config['model_name'],
            hidden_size=self.config['hidden_size'],
            num_classes=self.config['num_classes'],
            device=self.config['device']
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

    def forward(self, batch):
        logits = self.model(batch)
        return logits

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
        logits = self(batch)
        loss = F.cross_entropy(logits, targets)
        acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
        f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
                "val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()
        avg_precision = torch.stack([x['val_precision'] for x in outputs]).mean()
        avg_recall = torch.stack([x['val_recall'] for x in outputs]).mean()
        wandb.log({"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
                   "val_recall": avg_recall})
        return {"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
                "val_recall": avg_recall}

When I run training in the following way:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import os

from trainer import DialogActsLightningModel
import wandb

wandb.init()

logger = WandbLogger(
    name="model_name",
    entity='myname',
    save_dir=config["save_dir"],
    project=config["project"],
    log_model=True,
)
early_stopping = EarlyStopping(
    monitor="val_accuracy",
    min_delta=0.001,
    patience=5,
)

model = DialogActsLightningModel(config=config)

trainer = pl.Trainer(
    logger=logger,
    gpus=[0],
    checkpoint_callback=True,
    callbacks=[early_stopping],
    default_root_dir=MODELS_DIRECTORY,
    max_epochs=config["epochs"],
    precision=config["precision"],
    limit_train_batches=10, # run for only 10 batches, debug mode
    limit_test_batches=10,
    limit_val_batches=10
)

trainer.fit(model)

I get the error below, even though I expected the model to calculate and log the metric "val_accuracy" during the validation step.

Epoch 0: 100%
20/20 [00:28<00:00, 1.41s/it, loss=1.95]

/opt/conda/lib/python3.9/site-packages/pytorch_lightning/callbacks/early_stopping.py in _validate_condition_metric(self, logs)
    149         if monitor_val is None:
    150             if self.strict:
--> 151                 raise RuntimeError(error_msg)
    152             if self.verbose > 0:
    153                 rank_zero_warn(error_msg, RuntimeWarning)

RuntimeError: Early stopping conditioned on metric `val_accuracy` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: ``

What am I doing wrong? How can I fix it?

Upvotes: 1

Views: 8348

Answers (1)

nadhir hasan

Reputation: 193

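If you use a recent version of pytorch-lightning, you need to log val_accuracy (or val_loss) with self.log so that EarlyStopping and similar callbacks can monitor it. Returning the value in the dict from validation_step is not enough on its own, which is why the error message lists no available metrics. Try adding the self.log call shown in the code below; I think it should help.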

def validation_step(self, batch, batch_idx):
    input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
    logits = self(batch)
    loss = F.cross_entropy(logits, targets)
    acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
    f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
    precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
    recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])

    # Log the metric under the same name that EarlyStopping is told to monitor.
    self.log("val_accuracy", torch.tensor([acc]))     # try this line

    return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
            "val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}

If this post is useful, please upvote.
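To illustrate why the returned dict is not enough, here is a minimal pure-Python sketch of the idea. It is a toy model, not Lightning's actual implementation: ToyModule, ToyEarlyStopping and run_validation_epochs are hypothetical names made up for this example. The callback only looks at a dictionary that log() fills in, so a metric that is merely returned never reaches it. The sketch also shows why the comparison direction matters for an accuracy metric.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule (illustration only)."""

    def __init__(self):
        # Only metrics passed to log() end up here.
        self.callback_metrics = {}

    def log(self, name, value):
        self.callback_metrics[name] = float(value)


class ToyEarlyStopping:
    """Hypothetical, simplified early-stopping check (illustration only)."""

    def __init__(self, monitor, mode="min", patience=2, min_delta=0.0):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.monitor = monitor
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.wait = 0

    def should_stop(self, metrics):
        # Fail loudly if the monitored metric was never logged.
        if self.monitor not in metrics:
            available = ", ".join(sorted(metrics))
            raise RuntimeError(
                f"metric `{self.monitor}` is not available; "
                f"logged metrics: `{available}`"
            )
        current = metrics[self.monitor]
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = current > self.best + self.min_delta
        else:
            improved = current < self.best - self.min_delta
        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def run_validation_epochs(accuracies, use_log, mode):
    """Return the epoch index at which training would stop, or None."""
    module = ToyModule()
    stopper = ToyEarlyStopping("val_accuracy", mode=mode, patience=2)
    for epoch, acc in enumerate(accuracies):
        returned = {"val_accuracy": acc}  # returned, but never seen by the callback
        if use_log:
            module.log("val_accuracy", acc)
        if stopper.should_stop(module.callback_metrics):
            return epoch
    return None


accuracies = [0.5, 0.6, 0.7, 0.8]  # made-up values that improve every epoch

try:
    run_validation_epochs(accuracies, use_log=False, mode="max")
except RuntimeError as err:
    print("without log():", err)

print("with log(), mode='max':", run_validation_epochs(accuracies, True, "max"))
print("with log(), mode='min':", run_validation_epochs(accuracies, True, "min"))
```

In the toy run, the "min" direction stops training even though accuracy keeps improving. Lightning's EarlyStopping has a mode argument that defaults to "min", so when monitoring val_accuracy it is probably worth passing mode="max" as well.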

Upvotes: 2
