Reputation: 5217
I was trying to reset the dataloader manually but was unable to. I tried everything here https://discuss.pytorch.org/t/how-could-i-reset-dataloader-or-count-data-batch-with-iter-instead-of-epoch/22902/4 but no luck. Does anyone know how to reset the data loader AND also have the shuffle/randomness of the batches not be broken?
Upvotes: 16
Views: 16335
Reputation: 54821
To reset a DataLoader, just enumerate the loader again. Each call to enumerate(loader) starts from the beginning.
To avoid breaking the transforms that use random values, reset the random seed each time the DataLoader workers are initialized, via worker_init_fn.
import random

import numpy as np
import torch

def seed_init_fn(worker_id):
    # worker_id is passed in by the DataLoader; args.seed is your base seed
    seed = args.seed + worker_id
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

loader = torch.utils.data.DataLoader(..., worker_init_fn=seed_init_fn)

while True:
    for i, data in enumerate(loader):
        ...  # will always yield the same data
See worker_init_fn in the documentation:
https://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader
Here is a better example:
https://github.com/pytorch/pytorch/issues/5059#issuecomment-404232359
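Note that worker_init_fn only re-seeds the worker processes that run random transforms; the shuffle order itself is drawn from a generator in the main process. A minimal sketch of one way to make that order repeat across passes is to re-seed a torch.Generator passed to the DataLoader before each pass (the toy dataset, batch size, and seed value below are just placeholder assumptions):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
g = torch.Generator()

for epoch in range(3):
    g.manual_seed(0)  # re-seed before each pass so the shuffle order repeats
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=g)
    print([batch[0].tolist() for batch in loader])
    # every pass prints the same shuffled batches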
Upvotes: 14