helium4

Reputation: 11

PyTorch data loading from multiple different-sized datasets

I have multiple datasets, each with a different number of images (and different image dimensions). In the training loop I want to load batches of images randomly from among all the datasets, but such that each batch contains images from only a single dataset. For example, with datasets A, B, C, D, each containing images 01.jpg, 02.jpg, …, n.jpg (where n depends on the dataset), and a batch size of 3, the first batch might be [B/02.jpg, B/06.jpg, B/12.jpg], the next [D/01.jpg, D/05.jpg, D/12.jpg], and so on.

So far I have considered the following:

  1. Use a separate DataLoader for each dataset, e.g. dataloaderA, dataloaderB, etc., and then in each training iteration randomly select one of the dataloaders and get a batch from it. However, this requires a for loop, and for a large number of datasets it would be very slow since it can't be split among workers to run in parallel.
  2. Use a single DataLoader with all of the images from all datasets together, but with a custom collate_fn that creates each batch using only images from the same dataset. (I'm not sure how exactly to go about this.)
  3. I have looked at the ConcatDataset class, but from its source code it looks like, if I use it, the images in a batch would be mixed from different datasets, which I don't want.

What would be the best way to do this? Thanks!

Upvotes: 1

Views: 5907

Answers (1)

Fábio Perez

Reputation: 26108

You can use ConcatDataset, and provide a batch_sampler to DataLoader.

from torch.utils.data import ConcatDataset

concat_dataset = ConcatDataset((dataset1, dataset2))

ConcatDataset.cumulative_sizes will give you the boundaries between the datasets:

ds_indices = concat_dataset.cumulative_sizes

Now you can use ds_indices to create a batch sampler. For example, if dataset1 has 5 images and dataset2 has 7, cumulative_sizes is [5, 12], so indices 0 to 4 belong to the first dataset and 5 to 11 to the second. See the source for BatchSampler for reference: your batch sampler just has to yield lists of N random indices that respect the ds_indices boundaries, which guarantees that every batch contains elements from a single dataset.
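A minimal sketch of such a sampler, assuming one plausible strategy (shuffle within each dataset, cut into batches, then shuffle the batch order); the class name SingleDatasetBatchSampler and the num_workers value are illustrative choices, and dataset1/dataset2 are the placeholder datasets from the snippet above:

import random
from torch.utils.data import ConcatDataset, DataLoader, Sampler

class SingleDatasetBatchSampler(Sampler):
    """Yields batches of indices where every batch is drawn from a
    single constituent dataset of a ConcatDataset."""

    def __init__(self, concat_dataset, batch_size):
        self.batch_size = batch_size
        # cumulative_sizes like [5, 12] means indices 0-4 belong to
        # the first dataset and 5-11 to the second.
        boundaries = [0] + concat_dataset.cumulative_sizes
        self.index_ranges = [list(range(lo, hi))
                             for lo, hi in zip(boundaries[:-1], boundaries[1:])]

    def __iter__(self):
        batches = []
        for indices in self.index_ranges:
            # Shuffle within one dataset, then cut into batches; the
            # last batch of a dataset may be smaller than batch_size.
            shuffled = random.sample(indices, len(indices))
            batches.extend(shuffled[i:i + self.batch_size]
                           for i in range(0, len(shuffled), self.batch_size))
        random.shuffle(batches)  # random dataset order, batch by batch
        return iter(batches)

    def __len__(self):
        # Number of batches across all datasets (ceiling division).
        return sum((len(r) + self.batch_size - 1) // self.batch_size
                   for r in self.index_ranges)

concat_dataset = ConcatDataset((dataset1, dataset2))
loader = DataLoader(concat_dataset,
                    batch_sampler=SingleDatasetBatchSampler(concat_dataset, 3),
                    num_workers=4)

Note that when batch_sampler is given, batch_size, shuffle, sampler, and drop_last must be left at their defaults: each list the sampler yields becomes one batch, and the workers still load the individual samples in parallel.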

Upvotes: 0
