I am looking at the example provided in the official PyTorch-Lightning documentation: https://pytorch-lightning.readthedocs.io/en/0.9.0/lightning-module.html.
Here the loss and the metric are calculated on a single batch. But when logging, one is usually not interested in the accuracy of one particular batch, which can be rather small and unrepresentative, but in the accuracy averaged over the whole epoch. Do I understand correctly that there is some code that performs the averaging over all batches passed through during the epoch?
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

class ClassificationTask(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)  # loss of this batch only
        return pl.TrainResult(loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)  # accuracy of this batch only
        result = pl.EvalResult(checkpoint_on=loss)
        result.log_dict({'val_acc': acc, 'val_loss': loss})
        return result

    def test_step(self, batch, batch_idx):
        # Reuse the validation logic and rename the logged keys
        result = self.validation_step(batch, batch_idx)
        result.rename_keys({'val_acc': 'test_acc', 'val_loss': 'test_loss'})
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)
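To make the question concrete, here is a framework-free sketch (plain Python, not Lightning code) contrasting per-batch accuracy with an epoch-level accuracy. The `BatchStats` type and the batch counts are hypothetical. It also shows that "averaging over the epoch" can mean two different things: a plain mean of the batch accuracies, or the sample-weighted accuracy over all samples. They differ when the last batch is smaller than the others.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class BatchStats:
    correct: int  # hypothetical: correctly classified samples in the batch
    size: int     # hypothetical: number of samples in the batch


def batch_accuracy(b: BatchStats) -> float:
    return b.correct / b.size


def mean_of_batch_accuracies(batches: List[BatchStats]) -> float:
    # Unweighted: every batch counts equally, regardless of its size
    return sum(batch_accuracy(b) for b in batches) / len(batches)


def epoch_accuracy(batches: List[BatchStats]) -> float:
    # Sample-weighted: identical to computing accuracy over all samples at once
    total_correct = sum(b.correct for b in batches)
    total_size = sum(b.size for b in batches)
    return total_correct / total_size


# Hypothetical epoch: two full batches of 32 and a small final batch of 4
epoch = [BatchStats(24, 32), BatchStats(28, 32), BatchStats(1, 4)]
print([batch_accuracy(b) for b in epoch])  # [0.75, 0.875, 0.25]
print(mean_of_batch_accuracies(epoch))     # 0.625
print(epoch_accuracy(epoch))               # 53 / 68, roughly 0.78
```

The tiny last batch drags the unweighted mean down. Which reduction a given Lightning version applies when logging per epoch is worth checking in its docs. The sketch only shows why the choice matters.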
If you want to average metrics over the epoch, you'll need to tell the LightningModule you've subclassed to do so. There are a few different ways to do this, such as:

1. Call result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) as shown in the docs. Passing on_epoch=True makes the training loss get averaged across the epoch, i.e.:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    result = pl.TrainResult(loss)
    # on_step logs the per-batch value, on_epoch logs the value aggregated over the epoch
    result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return result
2. Call the log method on the LightningModule itself: self.log("train_loss", loss, on_epoch=True, sync_dist=True). Passing sync_dist=True is optional; it reduces the value across devices/processes when training is distributed.

You'll want to do something similar in validation_step to get aggregated validation-set metrics, or implement the aggregation yourself in the validation_epoch_end method.
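Doing the aggregation yourself means validation_step returns raw per-batch quantities, and validation_epoch_end combines the whole list at the end of the epoch. Below is a framework-free sketch of that pattern in plain Python. The hook names mirror Lightning's, but the class and the `run_validation_epoch` driver are hypothetical stand-ins for a LightningModule and the Trainer, not Lightning's API. Batches are assumed to be (predicted labels, true labels) pairs.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical batch format: (predicted labels, true labels)
Batch = Tuple[Sequence[int], Sequence[int]]


class ManualEpochAggregation:
    """Plain-Python stand-in for a module's validation hooks (not Lightning)."""

    def validation_step(self, batch: Batch, batch_idx: int) -> Dict[str, int]:
        preds, targets = batch
        correct = sum(int(p == t) for p, t in zip(preds, targets))
        # Return raw counts rather than a ratio so the epoch end can weight by size
        return {"correct": correct, "n": len(targets)}

    def validation_epoch_end(self, outputs: List[Dict[str, int]]) -> Dict[str, float]:
        # Combine the counts from every step into one epoch-level accuracy
        total_correct = sum(o["correct"] for o in outputs)
        total_n = sum(o["n"] for o in outputs)
        return {"val_acc": total_correct / total_n}


def run_validation_epoch(module: ManualEpochAggregation,
                         batches: List[Batch]) -> Dict[str, float]:
    # Hypothetical driver mimicking the trainer: collect each step's output,
    # then hand the whole list to validation_epoch_end
    outputs = [module.validation_step(b, i) for i, b in enumerate(batches)]
    return module.validation_epoch_end(outputs)


batches = [([0, 1, 1, 2], [0, 1, 2, 2]), ([1, 0], [1, 1])]
print(run_validation_epoch(ManualEpochAggregation(), batches))
# {'val_acc': 0.6666666666666666}
```

Here 4 of 6 samples are correct. A plain mean of the two batch accuracies (0.75 and 0.5) would give 0.625 instead. Returning counts from each step is what makes the sample-weighted result possible.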