Reputation: 27946
Some of the official PyTorch Lightning docs contain code that annotates `stage` as `Optional[str]`, for example:
```python
from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict" or stage is None:
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)
```
When does stage take the value of None? I could find no docs describing this.
Upvotes: 1
Views: 1336
Reputation: 3426
Quoting from:
> This method expects a `stage` argument. It is used to separate setup logic for `trainer.{fit,validate,test,predict}`. If `setup` is called with `stage=None`, we assume all stages have been set-up.
Here is a simple example. Suppose you use this `LightningDataModule` snippet:
```python
import torch
from pytorch_lightning import LightningDataModule


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 256 if torch.cuda.is_available() else 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage):
        # Print whatever value the caller passes for `stage`
        print(stage)
```
and run:

```python
# mnist_model, trainer and PATH_DATASETS are assumed to be defined elsewhere
MNIST_dm = MNISTDataModule(PATH_DATASETS)
trainer.fit(mnist_model, MNIST_dm)
trainer.validate(mnist_model, MNIST_dm)
trainer.test(mnist_model, MNIST_dm)
MNIST_dm.setup(stage=None)
```
You will see this printed:

```
TrainerFn.FITTING
TrainerFn.VALIDATING
TrainerFn.TESTING
None
```
`stage` is `None` only when you explicitly call `setup(stage=None)` yourself; otherwise it takes the name of the stage it is being called for.
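The printed `TrainerFn.FITTING` may look as if it could never match `stage == "fit"`. The sketch below assumes, as that output suggests, that the enum mixes in `str`; `FakeTrainerFn` and `wants_fit` are hypothetical names, not Lightning's API. A `str`-mixin enum member compares equal to its string value, so the `"fit"` branch still runs. It also shows a related pitfall: the string `"None"` prints the same as `None` but does not satisfy `stage is None`.

```python
from enum import Enum


class FakeTrainerFn(str, Enum):
    """Hypothetical stand-in for Lightning's TrainerFn (str mixin assumed)."""

    FITTING = "fit"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"


def wants_fit(stage) -> bool:
    # The check used in the question's setup()
    return stage == "fit" or stage is None


print(wants_fit(FakeTrainerFn.FITTING))  # True: the member equals "fit"
print(wants_fit(None))                   # True: a real None
print(wants_fit("None"))                 # False: the string "None" is not None
```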
Upvotes: 0
Reputation: 886
The Trainer will never pass `stage=None` to the `setup` hook, or to any of the other hooks that take this argument. The type is annotated as optional, and the default value is `None`, for historical reasons. The value will always be one of `"fit"`, `"validate"`, `"test"`, `"predict"`.
There is an RFC to make this a required argument to avoid confusion. The link gives more context on why it has been like this in the past.
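Since the Trainer only sends those four names, a `setup` written against them should also handle `"validate"`. The question's version has no such branch, so `trainer.validate` on a fresh datamodule would presumably never create `mnist_val`. The following pure-Python sketch maps each stage name to the datasets it needs. `STAGE_DATASETS` and `datasets_for` are hypothetical helpers, not part of Lightning. `None` is kept only for manual calls such as `dm.setup()`.

```python
from typing import Dict, Optional, Tuple

# Hypothetical mapping from the stage names the Trainer sends to the
# datasets each stage needs; "validate" needs only the validation split.
STAGE_DATASETS: Dict[str, Tuple[str, ...]] = {
    "fit": ("train", "val"),
    "validate": ("val",),
    "test": ("test",),
    "predict": ("predict",),
}


def datasets_for(stage: Optional[str]) -> Tuple[str, ...]:
    """Return the datasets to build for a given stage name."""
    if stage is None:
        # Manual call: build everything, de-duplicated in first-seen order
        seen: Dict[str, None] = {}
        for names in STAGE_DATASETS.values():
            for name in names:
                seen[name] = None
        return tuple(seen)
    if stage not in STAGE_DATASETS:
        raise ValueError(f"unexpected stage: {stage!r}")
    return STAGE_DATASETS[stage]


print(datasets_for("validate"))  # ('val',)
print(datasets_for(None))        # ('train', 'val', 'test', 'predict')
```

Failing loudly on an unknown name, rather than silently building nothing, makes a typo in `stage` easy to spot.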
Upvotes: 1