Reputation: 799
Suppose I have a dataset:
datasets = [0,1,2,3,4]
In scenario I, the code is:
import torch
from torch.utils.data import RandomSampler, DataLoader

torch.manual_seed(1)
ran_sampler = RandomSampler(data_source=datasets)
for data in ran_sampler:
    print(data)

The result is 1,3,4,0,2.
In scenario II, the code is:
torch.manual_seed(1)
seed = 1234
G = torch.Generator()
G.manual_seed(seed)
ran_sampler = RandomSampler(data_source=datasets)
dataloader = DataLoader(dataset=datasets,
                        sampler=ran_sampler,
                        generator=G)
for data in ran_sampler:
    print(data)

The result is 1,3,4,0,2. In fact, whatever value I give to the variable seed, the result is still 1,3,4,0,2.
In scenario III, the code is:
torch.manual_seed(1)
ran_sampler = RandomSampler(data_source=datasets)
dataloader = DataLoader(dataset=datasets,
                        sampler=ran_sampler)
for data in dataloader:
    print(data)

The result is 4,1,3,0,2.
I checked the source code of RandomSampler and found:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
It shows that RandomSampler creates a generator itself if no generator is given. Therefore, in theory, my scenarios I, II and III should output the same result, yet scenario III outputs a different one. Why does this happen? I am lost in the source code of DataLoader, and I am confused about the relationship between DataLoader, sampler and generator.
I have already asked a question about "The shuffling order of DataLoader in pytorch". I understand that DataLoader passes its generator to the sampler under certain conditions, but in my scenario III the RandomSampler already has a generator of its own.
Upvotes: 8
Views: 3244
Reputation: 685
Scenario 3 is actually trickier to investigate than I thought. Let's look at all scenarios one by one. In this answer, "generator" means a random number generator, i.e. an instance of torch.Generator, not a Python generator.
Scenario 1 is straightforward. When one iterates a RandomSampler created without a generator supplied, the sampler creates its own generator, as you pointed out. This creation can be seen inside the RandomSampler.__iter__ definition:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
The result is 1,3,4,0,2.
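To see that this self-created generator is what determines the order, here is a minimal sketch that reproduces the sampler's own seed draw by hand. It assumes that RandomSampler (without replacement) permutes the indices with torch.randperm; the variable names are just for illustration:

import torch

datasets = [0, 1, 2, 3, 4]

torch.manual_seed(1)
# the same draw RandomSampler.__iter__ makes when it builds its own generator
seed = int(torch.empty((), dtype=torch.int64).random_().item())
g = torch.Generator()
g.manual_seed(seed)
# assuming the sampler shuffles via torch.randperm, this should reproduce 1,3,4,0,2
print(torch.randperm(len(datasets), generator=g).tolist())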
Turning to Scenario 2: a generator passed to DataLoader is used only to (a) create a RandomSampler if sampler is not given, and (b) generate a base seed for the workers when multiprocessing is used. Both uses are described in the docstring. In your code, sampler is set to ran_sampler and multiprocessing isn't used (the default), so the passed generator G has no effect. What was presumably intended is for G to determine the random sampling of datasets. In that case, G should be passed to RandomSampler as follows:
seed = 1
G = torch.Generator()
G.manual_seed(seed)
ran_sampler = RandomSampler(data_source=datasets, generator=G)  # G is passed here
dataloader = DataLoader(dataset=datasets,
                        sampler=ran_sampler)                    # G is not passed here
for data in ran_sampler:
    print(data)
However, this code prints 0,4,2,3,1, which is different from Scenario 1. This is because the actual seed used by the random sampler in Scenario 1 is not 1 (recall that RandomSampler creates its own generator in Scenario 1). To make the output identical, we need to use the same seed:
torch.manual_seed(1)
seed = int(torch.empty((), dtype=torch.int64).random_().item())  # use the same seed as Scenario 1
G = torch.Generator()
G.manual_seed(seed)
ran_sampler = RandomSampler(data_source=datasets, generator=G)
dataloader = DataLoader(dataset=datasets,
                        sampler=ran_sampler)
for data in ran_sampler:
    print(data)
This code now outputs 1,3,4,0,2.
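For completeness, here is a small sketch of use (a) mentioned above: when no sampler is supplied and shuffle=True, the DataLoader constructs a RandomSampler itself and hands it the generator, so in this setup (unlike Scenario 2) the seed of G does affect the order. The seed value 1234 is arbitrary, chosen only for illustration:

import torch
from torch.utils.data import DataLoader

datasets = [0, 1, 2, 3, 4]

G = torch.Generator()
G.manual_seed(1234)  # arbitrary seed; changing it changes the order below
# no sampler is passed, so DataLoader builds its own RandomSampler and gives it G
dataloader = DataLoader(dataset=datasets, shuffle=True, generator=G)
for data in dataloader:
    print(data)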
In Scenario 3, one would expect the result to be the same as in Scenario 1, because they use the same random sampler and both samplers should generate the same seed for their own generators. However, the seeds are actually different, because when DataLoader.__iter__ is called, the code below (defined inside the _BaseDataLoaderIter class in torch.utils.data.dataloader) also runs:
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
Here, self._index_sampler is an instance of BatchSampler that iterates over ran_sampler when self._sampler_iter is iterated over. In other words, for ran_sampler to create its own generator, self._sampler_iter must be iterated over. After looking at the code of BatchSampler.__iter__, you may wonder why that is. The reason is that self._index_sampler.__iter__ is a Python generator, whose body is executed only when the returned generator iterator, namely self._sampler_iter, is iterated over, and that doesn't happen in the code above.
Note that a Python generator is not a random number generator like torch.Generator, but rather the object produced by a function containing yield. It's unfortunate that both use the same term, which can cause confusion.
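As a toy illustration (plain Python, not PyTorch code) of the lazy execution described above: the body of a generator function only runs once its generator iterator is advanced, which is why iter(self._index_sampler) alone triggers nothing.

def make_indices():
    print("sampler body runs now")  # analogous to RandomSampler creating its generator
    yield 0

it = iter(make_indices())  # nothing is printed yet, just like iter(self._index_sampler)
next(it)                   # only now is "sampler body runs now" printed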
Also note that a seed is generated in the code above (self._base_seed), and this happens before self._sampler_iter is iterated over; since no generator was passed to the DataLoader in Scenario 3, loader.generator is None and this draw consumes a value from the default RNG. When the loop in Scenario 3 runs, i.e. when self._sampler_iter is iterated over, ran_sampler creates its own generator. Recall from Scenario 1 that this creation executes

seed = int(torch.empty((), dtype=torch.int64).random_().item())

where this call to Tensor.random_() is now the second call of that method, making seed different from that of Scenario 1; Scenario 1's seed has instead ended up in self._base_seed, which was obtained by the first call to Tensor.random_(). In other words, self._base_seed is equal to Scenario 1's seed while seed is not. Iterating ran_sampler in Scenario 1 twice provides an indication:
torch.manual_seed(1)
ran_sampler = RandomSampler(data_source=datasets)
list(ran_sampler)  # first iteration
for data in ran_sampler:
    print(data)

The above code outputs 4,1,3,0,2.
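The same point can be checked more directly with a sketch along the lines of the Scenario 1 sketch above: reproducing the two random_() draws by hand, the first draw plays the role of self._base_seed (Scenario 1's seed), while seeding a generator with the second draw, the one the sampler actually gets in Scenario 3, should give 4,1,3,0,2. This again assumes the sampler permutes its indices with torch.randperm:

import torch

datasets = [0, 1, 2, 3, 4]

torch.manual_seed(1)
base_seed = torch.empty((), dtype=torch.int64).random_().item()  # first draw, like self._base_seed
seed = int(torch.empty((), dtype=torch.int64).random_().item())  # second draw, the sampler's seed in Scenario 3
g = torch.Generator()
g.manual_seed(seed)
print(torch.randperm(len(datasets), generator=g).tolist())  # expected to match Scenario 3: 4,1,3,0,2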
Upvotes: 7