Charlie Parker

Reputation: 5217

How does one reset the dataloader in pytorch?

I was trying to reset the dataloader manually but was unable to. I tried everything in this thread, https://discuss.pytorch.org/t/how-could-i-reset-dataloader-or-count-data-batch-with-iter-instead-of-epoch/22902/4, but had no luck. Does anyone know how to reset the data loader without breaking the shuffling/randomness of the batches?

Upvotes: 16

Views: 16335

Answers (1)

Reactgular

Reputation: 54821

To reset a DataLoader, just iterate over it again. Each call to enumerate(loader) creates a new iterator that starts from the beginning, and with shuffle=True each new iterator draws a fresh permutation.
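As a rough illustration of the point above, here is a minimal sketch that iterates a small loader twice and then wraps it in a generator for step-based training. The toy dataset and the helper names `epoch_order` and `cycle` are hypothetical, chosen only for this example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of the integers 0..9 (illustrative only).
dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=5, shuffle=True)

def epoch_order(loader):
    """Iterate the loader once and return sample values in the order seen."""
    return [int(v) for (batch,) in loader for v in batch]

first = epoch_order(loader)
second = epoch_order(loader)  # a new loop restarts from the beginning

def cycle(loader):
    """Yield batches forever; each pass over the loader creates a new
    iterator, so shuffling is redone at every epoch boundary."""
    while True:
        for batch in loader:
            yield batch

stream = cycle(loader)
# Two batches per epoch, so five steps cross two epoch boundaries.
batches = [next(stream) for _ in range(5)]
```

Both `first` and `second` cover every sample exactly once, typically in different orders, and `cycle` never raises StopIteration, which suits loops that count steps rather than epochs.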

To keep transforms that use random values reproducible, reset the random seeds in worker_init_fn, which runs in every worker process each time the DataLoader starts its workers.

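import random

import numpy as np
import torch

# `args` is assumed to be the parsed command-line arguments holding a base seed.
def seed_init_fn(worker_id):
    # Give each worker its own deterministic seed.
    seed = args.seed + worker_id
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

# worker_init_fn is only called in worker processes, i.e. when num_workers > 0.
loader = torch.utils.data.DataLoader(...., worker_init_fn=seed_init_fn)

while True:
    for i, data in enumerate(loader):
        # Each pass restarts from the first batch, and the workers start
        # with the same seeds, so random transforms repeat.
        pass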
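If the shuffle order itself should also be repeatable from run to run, one option that the snippet above does not cover is to pass a seeded torch.Generator through the DataLoader's `generator` argument. The sketch below assumes a PyTorch version that supports that argument. `make_loader`, `epoch_order`, and the toy dataset are hypothetical names for illustration.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # Inside a worker process, torch.initial_seed() returns the seed that
    # PyTorch assigned to that worker; reuse it for NumPy and `random`.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def make_loader(seed):
    # The generator drives the shuffle order drawn in the main process.
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(
        TensorDataset(torch.arange(10)),
        batch_size=5,
        shuffle=True,
        num_workers=0,  # seed_worker only runs when num_workers > 0
        worker_init_fn=seed_worker,
        generator=g,
    )

def epoch_order(loader):
    """Iterate the loader once and return sample values in the order seen."""
    return [int(v) for (batch,) in loader for v in batch]

# Two independent "runs" built from the same seed, two epochs each.
loader_a = make_loader(seed=0)
run_a = [epoch_order(loader_a), epoch_order(loader_a)]
loader_b = make_loader(seed=0)
run_b = [epoch_order(loader_b), epoch_order(loader_b)]
```

With this setup the sequence of epochs is the same in both runs, while consecutive epochs within one run are still shuffled independently because the generator keeps advancing.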

See worker_init_fn in the documentation:

https://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader

Here is a better example:

https://github.com/pytorch/pytorch/issues/5059#issuecomment-404232359

Upvotes: 14
