spiridon_the_sun_rotator

Does PyTorch Lightning average metrics over the whole epoch?

I am looking at the example provided in the official PyTorch Lightning documentation: https://pytorch-lightning.readthedocs.io/en/0.9.0/lightning-module.html.

Here the loss and the metric are calculated on a single batch. But when logging, one is usually not interested in the accuracy of a particular batch, which can be rather small and unrepresentative, but in the average over the whole epoch. Do I understand correctly that there is some code that performs the averaging over all batches processed during the epoch?

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl
    from pytorch_lightning.metrics import functional as FM

    class ClassificationTask(pl.LightningModule):

        def __init__(self, model):
            super().__init__()
            self.model = model

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.model(x)
            loss = F.cross_entropy(y_hat, y)
            return pl.TrainResult(loss)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.model(x)
            loss = F.cross_entropy(y_hat, y)
            acc = FM.accuracy(y_hat, y)
            result = pl.EvalResult(checkpoint_on=loss)
            result.log_dict({'val_acc': acc, 'val_loss': loss})
            return result

        def test_step(self, batch, batch_idx):
            result = self.validation_step(batch, batch_idx)
            result.rename_keys({'val_acc': 'test_acc', 'val_loss': 'test_loss'})
            return result

        def configure_optimizers(self):
            return torch.optim.Adam(self.model.parameters(), lr=0.02)

Answers (1)

hendryx

If you want to average metrics over the epoch, you'll need to tell the LightningModule you've subclassed to do so. There are a few ways to do this:

  1. Call result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True), as shown in the docs, with on_epoch=True so that the training loss is averaged across the epoch. I.e.:
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.TrainResult(loss)
        result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return result
  2. Alternatively, you can call the log method on the LightningModule itself: self.log("train_loss", loss, on_epoch=True, sync_dist=True), optionally passing sync_dist=True to reduce the metric across accelerators. A minimal sketch of this variant follows.
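
Here is a minimal sketch of the self.log variant in context, assuming the same ClassificationTask wrapper as in the question (self.log became the standard logging API in later Lightning releases):

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class ClassificationTask(pl.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.model(x)
            loss = F.cross_entropy(y_hat, y)
            # on_step=True logs the per-batch value; on_epoch=True additionally
            # logs the value averaged over all batches when the epoch ends.
            self.log('train_loss', loss, on_step=True, on_epoch=True,
                     prog_bar=True, logger=True, sync_dist=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.model.parameters(), lr=0.02)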

You'll want to do something similar in validation_step to get aggregated validation-set metrics, or implement the aggregation yourself in the validation_epoch_end method, as in the sketch below.
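
As a rough sketch of the manual route (these methods live inside the same LightningModule subclass, with imports as in the question):

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        return {'val_loss': F.cross_entropy(y_hat, y),
                'val_acc': FM.accuracy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # 'outputs' is the list of dicts returned by validation_step for
        # every batch; stack the per-batch values and take the mean.
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        avg_acc = torch.stack([out['val_acc'] for out in outputs]).mean()
        self.log('val_loss', avg_loss)
        self.log('val_acc', avg_acc)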
