Reputation: 9263
I have a source of random (non-deterministic, non-repeatable) data that I'd like to wrap in a Dataset and DataLoader for PyTorch training. How can I do this?
__len__ is not defined, as the source is infinite (with possible repetition). __getitem__ is not defined, as the source is non-deterministic.
Upvotes: 1
Views: 453
Reputation: 2203
When defining a custom dataset class, you'd ordinarily subclass torch.utils.data.Dataset and define __len__() and __getitem__().
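For contrast, a minimal map-style sketch might look like this (the class name and the random tensor data are placeholders of mine, not part of the question):

import torch

class MyMapStyleDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        # Fixed, indexable data: the same index always returns the same item.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

dataset = MyMapStyleDataset(torch.randn(100, 10))

Both methods presuppose a known length and repeatable indexing, which is exactly what your source lacks.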
However, for cases where you want sequential but not random access, you can use an iterable-style dataset. To do this, you instead subclass torch.utils.data.IterableDataset and define __iter__(). Whatever is returned by __iter__() should be a proper iterator: it should maintain state (if necessary) and define __next__() to obtain the next item in the sequence. __next__() should raise StopIteration when there's nothing left to read. In your case, with an infinite dataset, it never needs to do this.
Here's an example:
import torch

class MyInfiniteIterator:
    def __next__(self):
        # Produce a fresh random sample on every call; the stream never
        # ends, so StopIteration is never raised.
        return torch.randn(10)

class MyInfiniteDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        return MyInfiniteIterator()

dataset = MyInfiniteDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

for batch in dataloader:
    # ... Do some stuff here ...
    # ...
    # Since the dataset is infinite, this loop never exhausts the
    # dataloader; break out of it yourself, e.g.:
    # if some_condition:
    #     break
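As a side note, __iter__() doesn't have to return a hand-written iterator class: a generator function returns an iterator automatically, so the same dataset can be written more compactly. Here's a sketch of that equivalent form (the class name is mine):

import torch

class MyGeneratorDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        # A generator is itself an iterator, so no separate
        # __next__() implementation is needed.
        while True:
            yield torch.randn(10)

dataloader = torch.utils.data.DataLoader(MyGeneratorDataset(), batch_size=32)

One caveat: with num_workers > 0, each worker process gets its own copy of the dataset and hence its own iterator. That's harmless here, since every sample is independently random, but it matters for sources that must not be duplicated.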
Upvotes: 1