Reputation: 1361
I’d like to implement an infinite loop Dataset & DataLoader. Here’s what I tried:
class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
        # return 1 << 30  # This causes huge memory usage.

    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(),
    batch_size=HPARAMS.batch_size,
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"
As you can see, the main challenge here is the __len__() method. If I put a large enough number there, like 1 << 30, the symptom is that memory usage jumps to 10+ GB on the first iteration of the training loop. After a while the workers are killed, presumably due to OOM.
If I put a small number there, like 1 or BATCH_SIZE, the sampled "data" in the training loop is periodically duplicated. This is not what I want, as I'd like new data to be generated and trained on at every iteration.
I'm guessing the culprit of the excessive memory usage is that somewhere in the stack a bunch of things are cached. From a casual look at the Python side of things I can't pinpoint where.
Can someone advise on the best way to implement what I want? (Use DataLoader's parallel loading while simultaneously guaranteeing that every batch loaded is entirely new.)
Upvotes: 10
Views: 11949
Reputation: 1
This loader iterates over a list an infinite number of times; if the shuffle flag is set to True, the list elements are reshuffled at the start of each pass.
from torch.utils.data import DataLoader, Dataset, Sampler
import random

class listDataset(Dataset):
    def __init__(self):
        self.varList = [1, 2, 3, 4]

    def __len__(self):
        return len(self.varList)

    def __getitem__(self, idx):
        return self.varList[idx]

class customSampler(Sampler):
    def __init__(self, dataset, shuffle):
        assert len(dataset) > 0
        self.dataset = dataset
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range(len(self.dataset)))
        idx = 0
        while True:
            yield order[idx]
            idx += 1
            if idx == len(order):
                if self.shuffle:
                    random.shuffle(order)
                idx = 0

dset = listDataset()
sampler = customSampler(dset, shuffle=True)
loader = iter(DataLoader(dataset=dset, sampler=sampler, batch_size=6, num_workers=2))
for x in range(10):
    i = next(loader)
    print(i)
Upvotes: 0
Reputation: 9806
This seems to be working without periodically duplicating the data:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):
    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))

data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)

    # forward + backward on "data"

    if batch_count == 5:
        break
Result:
Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])
So I think the problem is in your function sample_func_to_be_parallelized().

Edit: If instead of torch.randint(0, 10, (3,)) I use np.random.randint(10, size=3) in __getitem__ (as an example of the sample_func_to_be_parallelized()), then the data is indeed duplicated at each batch. See this issue.

So if you are using NumPy's RNG somewhere in your sample_func_to_be_parallelized(), then the workaround is to use

worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id)

and to reset the seed with np.random.seed() before each call of data = next(iter(data_loader)).
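Put together with the snippet from the question, the workaround might look like the sketch below. HPARAMS and sample_func_to_be_parallelized are the question's placeholders, and I'm assuming the sampling function draws from NumPy's global RNG; this is one possible wiring, not the only one.

import numpy as np
from torch.utils.data import Dataset, DataLoader

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size  # placeholder from the question

    def __getitem__(self, idx):
        # Assumed to draw its randomness from NumPy's global RNG.
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(),
    batch_size=HPARAMS.batch_size,
    num_workers=16,
    # Derive each worker's seed from the main process's current RNG state.
    worker_init_fn=lambda worker_id: np.random.seed(
        np.random.get_state()[1][0] + worker_id),
)

while True:
    np.random.seed()  # re-randomize the main process before fresh workers are created
    data = next(iter(infinite_loader))
    # forward + backward on "data"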
Upvotes: 5
Reputation: 8981
Try to use cycle from itertools. Here is an example for a simple dataset:
Code:
from itertools import cycle
import torch
from torch.utils.data import Dataset, DataLoader

# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])

class DataSet(Dataset):
    """Our dataset. Iterates over tensor data"""

    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]

bs = 1       # batch size
workers = 1  # number of workers

dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)

# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)
Output:
batch size: 1 | number of workers: 1
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
...

batch size: 2 | number of workers: 2
0 tensor([[0, 0],
          [1, 1]])
1 tensor([[2, 2],
          [3, 3]])
0 tensor([[0, 0],
          [1, 1]])
1 tensor([[2, 2],
...
Upvotes: 1
Reputation: 13113
DataLoader samples your dataset without replacement. To do this, it generates a random permutation of indices between 0 and len(dataset). My guess is that this permutation is responsible for eating up most of your memory. I don't think the PyTorch APIs support infinite collections, but you could try forking the code in DataLoader and doing it yourself.

You could use the batch_sampler param and pass in a custom variant implemented based on RandomSampler. This will allow you to keep the parallel loading part of DataLoader.
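For illustration, a minimal sketch of such a batch_sampler: an iterable that yields batches of random indices forever, drawn with replacement rather than from a full permutation. The class name and the sizes in the usage comment are made up, and this is a simplification of the RandomSampler-based variant suggested above.

import torch
from torch.utils.data import DataLoader

class InfiniteRandomBatchSampler:
    """Yields lists of random indices forever (with replacement),
    so no full permutation of the dataset is ever materialized."""
    def __init__(self, dataset_len, batch_size):
        self.dataset_len = dataset_len
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            yield torch.randint(self.dataset_len, (self.batch_size,)).tolist()

# Hypothetical usage (batch_sampler replaces batch_size/shuffle/sampler):
# loader = DataLoader(my_dataset, num_workers=16,
#                     batch_sampler=InfiniteRandomBatchSampler(len(my_dataset), 32))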
That being said, the protocol of iteration based on __len__ and __getitem__ just isn't suited for infinite collections. You may be better off reimplementing your Dataset.__len__ to just return 1, your Dataset.__getitem__ to always return a new sample regardless of the index, and then sampling n times with replacement from this dataset. Technically, it will ask n times for the 0-th sample, but since you override __getitem__ to return different samples, this will effectively do what you're looking for.
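One way to read that suggestion as code, using the built-in RandomSampler with replacement; the sample function (torch.randn here) and the num_samples value are placeholders, not part of the original answer:

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

class Infinite(Dataset):
    def __len__(self):
        return 1  # a single "virtual" index

    def __getitem__(self, idx):
        # Ignore idx and generate a fresh sample on every call.
        return torch.randn(3)

dataset = Infinite()
# Ask for index 0 a large number of times; __getitem__ returns new data each time.
sampler = RandomSampler(dataset, replacement=True, num_samples=1_000_000)
loader = DataLoader(dataset, sampler=sampler, batch_size=32, num_workers=4)

for data in loader:
    pass  # forward + backward on "data"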
Upvotes: 1