SRobertJames

Reputation: 9263

PyTorch Dataset / Dataloader from random source

I have a source of random (non-deterministic, non-repeatable) data, that I'd like to wrap in Dataset and Dataloader for PyTorch training. How can I do this?

__len__ is not defined, as the source is infinite (with possible repetition).
__getitem__ is not defined, as the source is non-deterministic.

Upvotes: 1

Views: 453

Answers (1)

Alexander Guyer

Reputation: 2203

When defining a custom dataset class, you'd ordinarily subclass torch.utils.data.Dataset and define __len__() and __getitem__().

However, for cases where you want sequential but not random access, you can use an iterable-style dataset. To do this, you instead subclass torch.utils.data.IterableDataset and define __iter__(). Whatever is returned by __iter__() should be a proper iterator; it should maintain state (if necessary) and define __next__() to obtain the next item in the sequence. __next__() should raise StopIteration when there's nothing left to read. In your case with an infinite dataset, it never needs to do this.

Here's an example:

import torch

class MyInfiniteIterator:
    def __next__(self):
        return torch.randn(10)

class MyInfiniteDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        return MyInfiniteIterator()

dataset = MyInfiniteDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

for batch in dataloader:
    # ... Do some stuff here ...
    # ...

    # Since the dataset is infinite, you must break out of the loop
    # yourself, e.g.:
    # if some_condition:
    #     break
    pass  # a loop body of only comments is a SyntaxError; replace with real work
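As an aside, __iter__() can also be written as a generator function, since a generator is itself a valid iterator; this avoids defining a separate iterator class. A minimal sketch of the same infinite random source (the class name here is just illustrative):

```python
import torch

class MyGeneratorDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        # Generators implement __next__() automatically, so yielding
        # in an infinite loop gives an endless stream of samples.
        while True:
            yield torch.randn(10)

dataset = MyGeneratorDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
batch = next(iter(dataloader))  # tensor of shape (32, 10)
```

Note that with num_workers > 0, each worker process gets its own copy of the dataset and calls __iter__() independently; for a truly random source like this that's usually fine, since PyTorch seeds each worker differently by default.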

Upvotes: 1
