Covi

Reputation: 1361

Implementing an “infinite loop” Dataset & DataLoader in PyTorch

I’d like to implement an infinite loop Dataset & DataLoader. Here’s what I tried:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
#         return 1<<30 # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)

while True:
    for idx, data in enumerate(infinite_loader):
        ...  # forward + backward on "data"

As you can see, the main challenge here is the __len__() method. If I put a large enough number there, like 1<<30, the symptom is that memory usage jumps to 10+ GB on the first iteration of the training loop. After a while the workers are killed, presumably due to OOM.

If I put a small number there, like 1 or BATCH_SIZE, the sampled “data” in the training loop is periodically duplicated. This is not what I want, as I’d like new data to be generated and trained on at every iteration.

I’m guessing the culprit of the excessive memory usage is that somewhere in the stack a bunch of things are cached. From a casual look at the Python side of things, I can’t pinpoint where.

Can someone advise on the best way to implement what I want? (Use DataLoader’s parallel loading while guaranteeing that every batch loaded is entirely new.)

Upvotes: 10

Views: 11949

Answers (4)

Hariharan J

Reputation: 1

This loader iterates over a list infinitely; if the shuffle flag is set to True, the list elements are reshuffled before each new pass.

from torch.utils.data import DataLoader, Dataset, Sampler
import random

class listDataset(Dataset):
    def __init__(self):
        self.varList = [1, 2, 3, 4]
    def __len__(self):
        return len(self.varList)
    def __getitem__(self, idx):
        return self.varList[idx]

class customSampler(Sampler):
    def __init__(self, dataset, shuffle):
        assert len(dataset) > 0
        self.dataset = dataset
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range(len(self.dataset)))
        idx = 0
        while True:
            yield order[idx]
            idx += 1
            if idx == len(order):
                if self.shuffle:
                    random.shuffle(order)
                idx = 0

dset = listDataset()
sampler = customSampler(dset, shuffle=True)
loader = iter(DataLoader(dataset=dset, sampler=sampler, batch_size=6, num_workers=2))
for x in range(10):
    i = next(loader)
    print(i)

Upvotes: 0

kuzand

Reputation: 9806

This seems to work without periodically duplicating the data:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))


data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)
    # forward + backward on "data"  

    if batch_count == 5:
        break

Result:

Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])

So I think the problem is in your function sample_func_to_be_parallelized().


Edit: If instead of torch.randint(0, 10, (3,)) I use np.random.randint(10, size=3) in __getitem__ (as a stand-in for sample_func_to_be_parallelized()), then the data is indeed duplicated at each batch. See this issue.

So if you are using NumPy's RNG somewhere in your sample_func_to_be_parallelized(), then the workaround is to use

worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id) 

and to reset the seed with np.random.seed() before each call to data = next(iter(data_loader)).
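
Putting the two pieces together, a minimal sketch of this workaround applied to the Infinite dataset from the question might look as follows; the body of sample_func_to_be_parallelized() is just a NumPy placeholder here, and only the seeding logic is the point:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

def sample_func_to_be_parallelized():
    # Placeholder: any NumPy-based sampling routine.
    return np.random.randint(10, size=3)

class Infinite(Dataset):
    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return sample_func_to_be_parallelized()

data_loader = DataLoader(
    Infinite(),
    batch_size=BATCH_SIZE,
    num_workers=2,
    # Derive each worker's seed from the main process's current NumPy seed state.
    worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id),
)

for _ in range(5):
    np.random.seed()                # re-seed the main process before building a new iterator
    data = next(iter(data_loader))  # fresh workers -> fresh seeds -> fresh data
    print(data)
    # forward + backward on "data"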

Upvotes: 5

trsvchn

Reputation: 8981

Try using cycle from itertools. Here is an example with a simple dataset:

Code:

from itertools import cycle

import torch
from torch.utils.data import Dataset, DataLoader


# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])


class DataSet(Dataset):
    """Our dataset. Iterates over tensor data"""

    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]


bs = 1  # batch size
workers = 1  # number of workers

dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)

# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)

Output:

batch size: 1 | number of workers: 1
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
...

batch size: 2 | number of workers: 2
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
        [3, 3]])
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
...

Upvotes: 1

Jatentaki

Reputation: 13113

DataLoader samples your dataset without replacement. To do this, it generates a random permutation of indices between 0 and len(dataset). My guess is that this permutation is responsible for eating up most of your memory. I don't think the PyTorch APIs support infinite collections, but you could try forking the code in DataLoader and doing it yourself. Alternatively, you could use the batch_sampler parameter and pass in a custom variant implemented along the lines of RandomSampler. This will allow you to keep the parallel-loading part of DataLoader.
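
A rough sketch of that batch_sampler route (InfiniteRandomBatchSampler is a made-up name, and torch.randn(3) stands in for the question's sample_func_to_be_parallelized()) could look like this; no full permutation of indices is ever materialized:

import torch
from torch.utils.data import Dataset, DataLoader, Sampler

class InfiniteRandomBatchSampler(Sampler):
    """Yields batches of random indices forever."""

    def __init__(self, dataset_len, batch_size):
        self.dataset_len = dataset_len
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            # Sample indices with replacement, one batch at a time.
            yield torch.randint(self.dataset_len, (self.batch_size,)).tolist()

class Infinite(Dataset):
    def __len__(self):
        return 1  # the index is ignored anyway

    def __getitem__(self, idx):
        return torch.randn(3)  # placeholder for sample_func_to_be_parallelized()

loader = DataLoader(
    Infinite(),
    batch_sampler=InfiniteRandomBatchSampler(dataset_len=1, batch_size=4),
    num_workers=2,
)

it = iter(loader)
for _ in range(5):
    data = next(it)
    # forward + backward on "data"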

That being said, the iteration protocol based on __len__ and __getitem__ just isn't suited to infinite collections. You may be better off reimplementing your Dataset.__len__ to just return 1, your Dataset.__getitem__ to always return a new sample regardless of the index, and then sampling n times with replacement from this dataset. Technically it will ask for the 0-th sample n times, but since you override __getitem__ to return different samples, this will effectively do what you're looking for.
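
A minimal sketch of that last suggestion, again with torch.randn(3) standing in for sample_func_to_be_parallelized(); torch.utils.data.RandomSampler with replacement=True and num_samples=n handles the "sample n times with replacement" part:

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

class Infinite(Dataset):
    def __len__(self):
        return 1  # a single virtual index

    def __getitem__(self, idx):
        # idx is always 0, but each call generates a brand-new sample.
        return torch.randn(3)  # placeholder for sample_func_to_be_parallelized()

dataset = Infinite()
n = 10_000  # how many samples to draw per "epoch"
sampler = RandomSampler(dataset, replacement=True, num_samples=n)
loader = DataLoader(dataset, sampler=sampler, batch_size=4, num_workers=2)

for data in loader:
    pass  # forward + backward on "data"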

Upvotes: 1
