Reputation: 5540
Is it possible to fix the seed for torch.utils.data.random_split()
when splitting a dataset, so that the split (and hence the test results) can be reproduced?
Upvotes: 11
Views: 17930
Reputation: 141
generator = torch.Generator()
generator.manual_seed(0)
train, val, test = random_split(dataset=dataset,
lengths=[train_size, val_size, test_size],
generator=generator)
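A minimal, runnable sketch of the approach above (the toy dataset and split sizes are assumptions for illustration); creating a freshly seeded generator each time makes the split fully reproducible:

```python
import torch
from torch.utils.data import random_split

# Toy stand-in for a real Dataset: 10 integer "samples".
dataset = list(range(10))

def split_with_seed(seed):
    # A fresh generator seeded the same way always produces the same split.
    generator = torch.Generator()
    generator.manual_seed(seed)
    train, val, test = random_split(dataset, [6, 2, 2], generator=generator)
    return [list(s) for s in (train, val, test)]

# The same seed reproduces the exact same train/val/test split.
assert split_with_seed(0) == split_with_seed(0)
```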
Upvotes: 2
Reputation: 464
As you can see from the documentation, it is possible to pass a generator to random_split:
random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
Upvotes: 19
Reputation: 24691
You can use the torch.manual_seed function to seed the script globally:
import torch
torch.manual_seed(0)
See reproducibility documentation for more information.
If you want to specifically seed torch.utils.data.random_split, you can "reset" the seed to its initial value afterwards using torch.initial_seed(), like this:
torch.manual_seed(torch.initial_seed())
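A small sketch of this reset trick (the toy dataset and lengths are assumptions): random_split without an explicit generator draws from the global default generator, and re-seeding with torch.initial_seed() rewinds that generator so a second split matches the first:

```python
import torch
from torch.utils.data import random_split

torch.manual_seed(0)  # seed the default generator globally

dataset = range(10)
first = [s.indices for s in random_split(dataset, [3, 7])]

# torch.initial_seed() still returns 0 here, so this rewinds the
# default generator to the state it had right after torch.manual_seed(0).
torch.manual_seed(torch.initial_seed())
second = [s.indices for s in random_split(dataset, [3, 7])]

assert first == second  # the split is reproduced
```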
AFAIK, PyTorch does not provide arguments like seed
or random_state
(as seen in sklearn,
for example).
Upvotes: 17