SorushN

Reputation: 33

PyTorch: batching from multiple datasets

I have multiple datasets that I want to use for training. Each batch should come from a single dataset, but every epoch should include batches from (possibly) all of the datasets.

Merging the datasets into a single Dataset object and using the default DataLoader produces batches that mix samples from different datasets.

My guess is that I need a separate Dataset object for each dataset and have to override the DataLoader or the sampler, but I don't know how to do that.

Upvotes: 3

Views: 2234

Answers (1)

Shai

Reputation: 114786

I think the best way to solve your problem is to have a single merged dataset with a single data loader, but with a custom BatchSampler that builds each batch of indices from only one of the datasets inside the merged dataset.
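A minimal sketch of such a batch sampler, under some assumptions: the merged dataset stores the sub-datasets back to back (which is how `torch.utils.data.ConcatDataset` lays them out), and the class name `PerDatasetBatchSampler` and its parameters are made up for illustration. It is written in plain Python, since `DataLoader` accepts any iterable of index lists as `batch_sampler`.

```python
import random


class PerDatasetBatchSampler:
    """Yield lists of indices where every list comes from one sub-dataset.

    Assumption: sub-dataset k occupies the contiguous index range
    [offset_k, offset_k + dataset_sizes[k]) of the merged dataset.
    """

    def __init__(self, dataset_sizes, batch_size, shuffle=True,
                 drop_last=False, seed=None):
        self.dataset_sizes = list(dataset_sizes)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.rng = random.Random(seed)

    def __iter__(self):
        batches = []
        offset = 0
        for size in self.dataset_sizes:
            # Indices that belong to this sub-dataset only.
            indices = list(range(offset, offset + size))
            if self.shuffle:
                self.rng.shuffle(indices)
            # Cut them into batches; a batch never crosses a dataset boundary.
            for start in range(0, size, self.batch_size):
                batch = indices[start:start + self.batch_size]
                if self.drop_last and len(batch) < self.batch_size:
                    continue
                batches.append(batch)
            offset += size
        # Shuffle the order of the batches so the datasets are interleaved
        # within an epoch.
        if self.shuffle:
            self.rng.shuffle(batches)
        return iter(batches)

    def __len__(self):
        if self.drop_last:
            return sum(n // self.batch_size for n in self.dataset_sizes)
        return sum(-(-n // self.batch_size) for n in self.dataset_sizes)
```

Assuming two hypothetical datasets `ds_a` and `ds_b`, it could be used as `DataLoader(ConcatDataset([ds_a, ds_b]), batch_sampler=PerDatasetBatchSampler([len(ds_a), len(ds_b)], batch_size=32))`. When `batch_sampler` is given, `batch_size`, `shuffle`, `sampler`, and `drop_last` must not be passed to the `DataLoader` itself.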

Upvotes: 4
