Reputation: 326
I am trying to train my model with 2 dataloaders built from 2 different datasets.
Since the datasets are not the same length, I set this up with cycle() and zip(),
following this answer: How to iterate over two dataloaders simultaneously using pytorch?
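For reference, the relevant part of my training loop looks roughly like this (same loader names as in the traceback below):

    from itertools import cycle

    # train_loader_1 comes from the shorter dataset, so it is cycled
    for i, (x1, x2) in enumerate(zip(cycle(train_loader_1), train_loader_2)):
        ...

This crashes with the following traceback: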
File "/home/Desktop/example/train.py", line 229, in train_2
for i, (x1, x2) in enumerate(zip(cycle(train_loader_1), train_loader_2)):
File "/home/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 346, in __next__
data = self.dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/home/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 80, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 80, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 154140672 bytes. Error code 12 (Cannot allocate memory)
I tried setting num_workers=0, decreasing the batch size, and passing pin_memory=False and shuffle=False to the DataLoader, but none of it worked. The machine has 256 GB of RAM and 4 NVIDIA Tesla V100 GPUs.
When I train on each dataloader individually instead of on both simultaneously, everything runs fine. However, my project requires this parallel training on the 2 datasets.
Upvotes: 0
Views: 2598
Reputation: 326
Based on this discussion, I avoid the error by dropping cycle() and zip() and manually re-creating the iterator whenever it is exhausted. (Note that itertools.cycle keeps a copy of every element it yields, so wrapping a DataLoader in it gradually holds all batches in memory, which explains the failed allocation.) The replacement pattern is:
    try:
        data, target = next(dataloader_iterator)
    except StopIteration:
        # the second loader is exhausted: restart it instead of cycling it
        dataloader_iterator = iter(dataloader)
        data, target = next(dataloader_iterator)
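Put together, a minimal runnable sketch of the training loop looks like this (the datasets, shapes, and batch size are placeholders, not from my actual code):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # dummy datasets of different lengths, standing in for the real ones
    dataset_1 = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
    dataset_2 = TensorDataset(torch.randn(300, 8), torch.randint(0, 2, (300,)))

    train_loader_1 = DataLoader(dataset_1, batch_size=16, shuffle=True)
    train_loader_2 = DataLoader(dataset_2, batch_size=16, shuffle=True)

    dataloader_iterator = iter(train_loader_1)

    for epoch in range(10):
        # drive the loop with the longer loader
        for x2, y2 in train_loader_2:
            try:
                x1, y1 = next(dataloader_iterator)
            except StopIteration:
                # shorter loader exhausted: restart it without caching batches
                dataloader_iterator = iter(train_loader_1)
                x1, y1 = next(dataloader_iterator)
            # ... forward/backward pass using (x1, y1) and (x2, y2) ...

Unlike cycle(), re-creating the iterator never stores old batches, so memory stays flat across epochs.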
Kudos to @srossi93 from this PyTorch forum post!
Upvotes: 4