Reputation: 5381
Say I am trying to use PyTorch to learn the equation y = 2x, and I want to generate an unlimited amount of data to train my model with. I am supposed to provide a __len__ function. Here's an example below. What should it be in this case? How do I specify the number of mini-batch iterations per epoch? Do I just set a number arbitrarily?
import numpy as np
from torch.utils.data import Dataset

class UnlimitedData(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        x = np.random.randint(1, 10)
        y = 2 * x
        return x, y

    def __len__(self):
        return 1000000  # This works but is not correct
Upvotes: 2
Views: 662
Reputation: 24701
You should use torch.utils.data.IterableDataset instead of torch.utils.data.Dataset. In your case it would be:
import torch

class Dataset(torch.utils.data.IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            x = torch.randint(1, 10, (self.batch_size,))
            y = 2 * x
            yield x, y
You should use batches (probably large ones), as that speeds up computation: PyTorch is well suited to GPU computation on many samples at once.
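To answer the second part of the question: with an infinite stream there is no natural epoch length, so the number of mini-batch iterations per "epoch" is indeed your choice. A minimal sketch of how the training loop might cap each epoch (the steps_per_epoch value of 100 is arbitrary, and batch_size=None tells DataLoader not to re-batch the already-batched samples):

```python
import itertools

import torch
from torch.utils.data import DataLoader, IterableDataset

# The infinite-stream dataset from the answer above.
class Dataset(IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            x = torch.randint(1, 10, (self.batch_size,))
            y = 2 * x
            yield x, y

# batch_size=None disables automatic batching, since the dataset
# already yields whole batches of shape (batch_size,).
loader = DataLoader(Dataset(batch_size=32), batch_size=None)

# Arbitrary choice: an "epoch" over infinite data is just a
# checkpoint/logging interval, capped here with itertools.islice.
steps_per_epoch = 100
for epoch in range(2):
    for x, y in itertools.islice(loader, steps_per_epoch):
        pass  # forward/backward pass would go here
```

islice stops each inner loop after steps_per_epoch mini-batches without exhausting the (never-ending) loader, so the stream simply continues where it left off in the next epoch.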
Upvotes: 3