Gulzar

Reputation: 27946

When is `stage is None` in pytorch lightning?

Some official PyTorch Lightning docs include code that annotates `stage` as `Optional[str]`, for example:

from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict" or stage is None:
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

When does stage take the value of None? I could find no docs describing this.

Upvotes: 1

Views: 1336

Answers (2)

Mike B

Reputation: 3426

Quoting from:

This method expects a stage argument. It is used to separate setup logic for trainer.{fit,validate,test,predict}. If setup is called with stage=None, we assume all stages have been set-up.

Here is a simple code example; if you use this LightningDataModule snippet:

import torch
from pytorch_lightning import LightningDataModule


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 256 if torch.cuda.is_available() else 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage):
        print(stage)

when running:

MNIST_dm = MNISTDataModule(PATH_DATASETS)

trainer.fit(mnist_model, MNIST_dm)
trainer.validate(mnist_model, MNIST_dm)
trainer.test(mnist_model, MNIST_dm)
MNIST_dm.setup(stage=None)

you will see this printed:

TrainerFn.FITTING
TrainerFn.VALIDATING
TrainerFn.TESTING
None

`stage` is None only when you call `setup` yourself and pass None explicitly; otherwise it takes the name of the stage the Trainer is running.
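
To make the effect of the `stage is None` branches concrete, here is a minimal dependency-free sketch. `SketchDataModule` is a hypothetical stand-in, not a real `LightningDataModule`; it only records which splits `setup` would build, using the same branching as the code in the question.

```python
from typing import List, Optional


class SketchDataModule:
    """Hypothetical stand-in for a LightningDataModule (no Lightning import)."""

    def __init__(self) -> None:
        # Names of the splits that setup() has built so far.
        self.prepared: List[str] = []

    def setup(self, stage: Optional[str] = None) -> None:
        # Same branching as the question's code: None means "build everything".
        if stage == "fit" or stage is None:
            self.prepared += ["train", "val"]
        if stage == "test" or stage is None:
            self.prepared.append("test")
        if stage == "predict" or stage is None:
            self.prepared.append("predict")


dm_fit = SketchDataModule()
dm_fit.setup("fit")  # a stage-specific call, as a trainer-driven run would make

dm_all = SketchDataModule()
dm_all.setup()  # manual call; stage falls back to its default, None

print(dm_fit.prepared)  # ['train', 'val']
print(dm_all.prepared)  # ['train', 'val', 'test', 'predict']
```

So the None branch is mostly useful when you call `setup()` by hand, for example in a notebook, and want every split available at once.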

Upvotes: 0

awaelchli

Reputation: 886

The Trainer will never send stage=None to the setup hook, or any of the other hooks that take this argument. The type is annotated optional and the default value is None for historical reasons. The values will always be one of "fit", "validate", "test", "predict".

There is an RFC to change this to a required argument to avoid confusion. The link provides more context on why it has been like this in the past.
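
If you rely on the Trainer only ever passing one of those four strings, `setup` can treat `stage` as required and drop the None branches. The following is only a sketch of that idea under that assumption: `StrictStageModule` and `VALID_STAGES` are hypothetical names, and the class deliberately does not inherit from `LightningDataModule` so it runs without Lightning installed.

```python
from typing import Set


class StrictStageModule:
    """Hypothetical stand-in showing a setup() with a required stage."""

    # The four stage names listed in the answer above.
    VALID_STAGES = ("fit", "validate", "test", "predict")

    def __init__(self) -> None:
        # Names of the splits that setup() has built so far.
        self.prepared: Set[str] = set()

    def setup(self, stage: str) -> None:
        # Fail loudly on anything unexpected, including None.
        if stage not in self.VALID_STAGES:
            raise ValueError(f"unexpected stage: {stage!r}")
        if stage == "fit":
            self.prepared.update({"train", "val"})
        elif stage == "validate":
            self.prepared.add("val")
        elif stage == "test":
            self.prepared.add("test")
        else:  # "predict"
            self.prepared.add("predict")


dm = StrictStageModule()
dm.setup("validate")
print(dm.prepared)  # {'val'}
```

Note that the question's code has no branch for "validate"; a strict version like this makes such gaps easier to spot.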

Upvotes: 1
