jss367

Reputation: 5381

What should __len__ be for PyTorch when generating unlimited data?

Say I am trying to use PyTorch to learn the equation y = 2x, and I want to generate an unlimited amount of data to train my model with. I am supposed to provide a __len__ function; here's an example below. What should it return in this case? How do I specify the number of mini-batch iterations per epoch? Do I just set a number arbitrarily?

import numpy as np
from torch.utils.data import Dataset

class UnlimitedData(Dataset):
    def __init__(self):
        pass
    
    def __getitem__(self, index):
        x = np.random.randint(1,10)
        y = 2 * x
        return x, y
    
    def __len__(self):
        return 1000000 # This works but is not correct

Upvotes: 2

Views: 662

Answers (1)

Szymon Maszke

Reputation: 24701

You should use torch.utils.data.IterableDataset instead of torch.utils.data.Dataset. In your case it would be:

import torch


class Dataset(torch.utils.data.IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            x = torch.randint(1, 10, (self.batch_size,))
            y = 2 * x
            yield x, y

You should use batches (probably large ones), as that speeds up computation; PyTorch is well suited to GPU computation over many samples at once.
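A minimal usage sketch (the class and numbers here are illustrative): since the dataset already yields whole batches, pass `batch_size=None` to `DataLoader` to disable its automatic batching, and cap the otherwise infinite iterator with `itertools.islice` to get a fixed number of mini-batch iterations per "epoch":

```python
import itertools

import torch


class UnlimitedData(torch.utils.data.IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        # Infinite stream of (x, y) batches for y = 2x.
        while True:
            x = torch.randint(1, 10, (self.batch_size,))
            y = 2 * x
            yield x, y


# batch_size=None tells the DataLoader the dataset already
# produces batches, so it should not batch again.
loader = torch.utils.data.DataLoader(UnlimitedData(batch_size=64), batch_size=None)

# islice truncates the infinite iterator, which is one way to
# define "iterations per epoch" for an unlimited data source.
steps_per_epoch = 100
for x, y in itertools.islice(loader, steps_per_epoch):
    pass  # training step would go here
```

With this setup the number of iterations per epoch is no longer encoded in `__len__`; it is just however many batches you choose to draw from the stream.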

Upvotes: 3
