cerebrou

Reputation: 5540

Fixing the seed for torch random_split()

Is it possible to fix the seed for torch.utils.data.random_split() when splitting a dataset so that it is possible to reproduce the test results?

Upvotes: 11

Views: 17930

Answers (3)

Watson21

Reputation: 141

import torch
from torch.utils.data import random_split

# Create a dedicated generator and fix its seed so the split is reproducible
generator = torch.Generator()
generator.manual_seed(0)

# dataset, train_size, val_size and test_size are assumed to be defined elsewhere
train, val, test = random_split(dataset=dataset,
                                lengths=[train_size, val_size, test_size],
                                generator=generator)

Upvotes: 2

Matteo Pennisi

Reputation: 464

As you can see from the documentation, it is possible to pass a generator to random_split:

random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
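A quick sketch of why this works: two calls with identically seeded generators produce identical splits, so test results can be reproduced across runs. (The toy dataset `range(10)` is just for illustration.)

```python
import torch
from torch.utils.data import random_split

dataset = range(10)  # toy stand-in for a real dataset

# Two splits with generators seeded the same way agree element-for-element
split_a = random_split(dataset, [3, 7], generator=torch.Generator().manual_seed(42))
split_b = random_split(dataset, [3, 7], generator=torch.Generator().manual_seed(42))

print(list(split_a[0]) == list(split_b[0]))  # True: same train subset both times
```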

Upvotes: 19

Szymon Maszke

Reputation: 24691

You can use the torch.manual_seed function to seed the script globally:

import torch
torch.manual_seed(0)

See reproducibility documentation for more information.

If you want to specifically seed torch.utils.data.random_split, you could "reset" the seed to its initial value afterwards. Simply use torch.initial_seed() like this:

torch.manual_seed(torch.initial_seed())
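The pattern above can be sketched end-to-end: seed the default generator, perform the split (which consumes random state), then re-seed with the initial seed so later code sees the same generator state as before the split.

```python
import torch
from torch.utils.data import random_split

torch.manual_seed(0)  # seed the default generator globally

# With no explicit generator, random_split draws from the default generator
train, test = random_split(range(10), [8, 2])

# "Reset" the default generator: initial_seed() returns the last seed
# passed to manual_seed, so this restores the post-manual_seed(0) state
torch.manual_seed(torch.initial_seed())

# A second split now reproduces the first
train2, test2 = random_split(range(10), [8, 2])
print(list(train) == list(train2))  # True
```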

AFAIK, PyTorch does not provide arguments like seed or random_state (as seen in sklearn, for example).

Upvotes: 17
