gutzcha

How to load data from multiple datasets in PyTorch

I have two datasets of images, indoors and outdoors, and they don't have the same number of examples.

Each dataset has images containing a certain number of classes (minimum 1, maximum 4). These classes can appear in both datasets, and each class has 4 categories: red, blue, green and white. For example, Indoor has cats, dogs and horses; Outdoor has dogs and humans.

I am trying to train a model where I tell it, "here is an image that contains a cat, tell me its color", regardless of where the image was taken (indoors, outdoors, in a car, on the moon).

To do that, I need to present examples to my model so that every batch contains only one class (cat, dog, horse or human), but I want to sample from all datasets (two in this case) that contain that class and mix them. How can I do this?

It has to take into account that the datasets have different numbers of examples, that some classes appear in only one dataset while others appear in more than one, and that each batch must contain only one class.

I would appreciate any help, I have been trying to solve this for a few days now.

Answers (1)

Matthew R.

Assuming the question is:

  1. Combine 2+ data sets with potentially overlapping categories of objects (distinguishable by label)
  2. Each object category has 4 color "subcategories" (distinguishable by label)
  3. Each batch should only contain a single object category

The first step will be to ensure consistency of the object labels from both data sets, if not already consistent. For example, if the dog class is label 0 in the first data set but label 2 in the second data set, then we need to make sure the two dog categories are correctly merged. We can do this "translation" with a simple data set wrapper:

from torch.utils.data import Dataset

class TranslatedDataset(Dataset):
  """
  Args:
    dataset: The original dataset.
    translate_label: A lambda (function) that maps the original
      dataset label to the label it should have in the combined data set
  """
  def __init__(self, dataset, translate_label):
    super().__init__()
    self._dataset = dataset
    self._translate_label = translate_label

  def __len__(self):
    return len(self._dataset)

  def __getitem__(self, idx):
    inputs, target = self._dataset[idx]
    return inputs, self._translate_label(target)
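The `translate_label` argument can be any callable, not only a lambda. With more than a couple of classes, a lookup table built from class names is easier to check than a chained conditional. A minimal sketch, assuming you know the class-name order each original dataset uses; `make_label_translator` and the name lists are hypothetical, taken from the question's example:

```python
def make_label_translator(local_names, shared_names):
    """Build a translate_label callable from class-name lists (hypothetical helper).

    local_names:  class names in the order the original dataset numbers them.
    shared_names: class names in the order used by the combined dataset.
    """
    shared_index = {name: i for i, name in enumerate(shared_names)}
    # Precompute "local label -> shared label" once; raises KeyError early
    # if a local class is missing from the shared label space.
    table = [shared_index[name] for name in local_names]

    def translate(label):
        # int() also accepts NumPy integers and 0-d torch tensors.
        return table[int(label)]

    return translate


# Hypothetical label orders, based on the question's example.
SHARED = ["cat", "dog", "horse", "human"]
indoor_translate = make_label_translator(["cat", "dog", "horse"], SHARED)
outdoor_translate = make_label_translator(["dog", "human"], SHARED)

print(indoor_translate(1), outdoor_translate(0))  # "dog" in both datasets -> 1 1
print(outdoor_translate(1))                       # "human" -> 3
```

The result plugs into the wrapper as `TranslatedDataset(indoor_dataset, indoor_translate)`, where `indoor_dataset` stands in for your own dataset object.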

The next step is combining the translated data sets together, which can be done easily with a ConcatDataset:

from torch.utils.data import ConcatDataset

first_original_dataset = ...
second_original_dataset = ...

first_translated = TranslatedDataset(
  first_original_dataset,
  # swap labels 0 and 2 (or similar); compare values with ==, not `is`
  lambda y: 0 if y == 2 else 2 if y == 0 else y,
)
second_translated = TranslatedDataset(
  second_original_dataset,
  lambda y: y, # or similar
)

combined = ConcatDataset([first_translated, second_translated])
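ConcatDataset does not need the datasets to be the same size. It numbers the samples of the first dataset first, then those of the second, and maps a global index back to a (dataset, local index) pair using the running totals of the lengths (its `cumulative_sizes` attribute) and `bisect`. The snippet below re-implements that lookup in plain Python purely as an illustration; `locate` is a hypothetical name and the sizes are made up:

```python
import bisect
from itertools import accumulate


def locate(cumulative_sizes, idx):
    """Map a global index to (dataset number, index inside that dataset).

    Illustrates the lookup ConcatDataset performs internally.
    """
    if not 0 <= idx < cumulative_sizes[-1]:
        raise IndexError(idx)
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    start = cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, idx - start


# Made-up sizes: 5 indoor images and 3 outdoor images.
sizes = [5, 3]
cumulative = list(accumulate(sizes))  # [5, 8], like ConcatDataset.cumulative_sizes

print(locate(cumulative, 4))  # (0, 4): last indoor image
print(locate(cumulative, 5))  # (1, 0): first outdoor image
print(locate(cumulative, 7))  # (1, 2): last outdoor image
```

So indices 0 to 4 belong to the first dataset and 5 to 7 to the second. A sampler over `combined` can treat it as one long list, however unbalanced the two datasets are.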

Finally, we need to restrict batch sampling to the same class, which is possible with a custom Sampler when creating the data loader.

import torch

class SingleClassSampler(torch.utils.data.Sampler):
  def __init__(self, dataset, batch_size):
    # Sampler.__init__ does no work, and calling it with no arguments
    # fails on older PyTorch versions, so it is not called here.
    # We need to create sequential groups
    # with batch_size elements from the same class
    indices_for_target = {} # dict to store a list of indices for each target

    for i, (_, target) in enumerate(dataset):
      # converting to string since Tensors hash by reference, not value
      str_targ = str(target)
      if str_targ not in indices_for_target:
        indices_for_target[str_targ] = []
      indices_for_target[str_targ] += [i]

    # make sure we have a whole number of batches for each class
    # (cut from the front, since v[:-0] would be an empty list)
    trimmed = {
      k: v[:len(v) - len(v) % batch_size]
      for k, v in indices_for_target.items()
    }

    # concatenate the lists of indices for each class
    # (sum() needs [] as its start value to add lists together)
    self._indices = sum(trimmed.values(), [])

  def __len__(self):
    return len(self._indices)

  def __iter__(self):
    yield from self._indices
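The constructor above reads every sample, images included, just to find its label, which can be slow for a large image set. If each original dataset can give you its labels without loading the images (torchvision's `ImageFolder`, for example, keeps them in a `targets` list), the per-class index lists can be built from the labels alone. A minimal sketch, assuming one label list per dataset in the same order as the datasets passed to `ConcatDataset`, plus the translation function used for each; `group_indices_by_label` is a hypothetical helper:

```python
from collections import defaultdict


def group_indices_by_label(label_lists, translators):
    """Return {shared_label: [global indices]} for a ConcatDataset.

    label_lists: one list of original labels per dataset, in ConcatDataset order.
    translators: one label-translation function per dataset (same order).
    """
    groups = defaultdict(list)
    offset = 0  # global index of the first sample of the current dataset
    for labels, translate in zip(label_lists, translators):
        for local_idx, label in enumerate(labels):
            groups[translate(label)].append(offset + local_idx)
        offset += len(labels)
    return dict(groups)


# Made-up labels: indoor uses 0=cat, 1=dog; outdoor uses 0=dog, 1=human.
indoor_labels = [0, 1, 1, 0]
outdoor_labels = [0, 1, 0]
groups = group_indices_by_label(
    [indoor_labels, outdoor_labels],
    # shared labels: 0=cat, 1=dog, 3=human
    [lambda y: {0: 0, 1: 1}[y], lambda y: {0: 1, 1: 3}[y]],
)
print(groups)  # {0: [0, 3], 1: [1, 2, 4, 6], 3: [5]}
```

With `ImageFolder`-style datasets the call would look like `group_indices_by_label([indoor.targets, outdoor.targets], [indoor_translate, outdoor_translate])`, where `indoor` and `outdoor` stand in for your datasets. `targets` holds the untranslated labels, which is why the translators are applied here.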

Then to use the sampler:

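from torch.utils.data import DataLoader

loader = DataLoader(
  combined,
  sampler=SingleClassSampler(combined, 64),
  batch_size=64,
  # no shuffle=True: DataLoader raises a ValueError if it is combined with a custom sampler
)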
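DataLoader does not accept `shuffle=True` together with a custom `sampler`, and without shuffling `SingleClassSampler` yields the same batches in the same order every epoch. Because each class's indices are stored dataset by dataset, consecutive batches also tend to come from a single dataset rather than mixing indoor and outdoor images. One alternative is DataLoader's `batch_sampler` argument, which takes an iterable that yields whole lists of indices. The sketch below is one possible approach; the class name and the `seed` parameter are hypothetical. It takes a `{label: [indices]}` dict like `indices_for_target`, shuffles within each class, drops each class's incomplete final batch, and shuffles the batch order on every pass:

```python
import random


class ShuffledSingleClassBatchSampler:
    """Yields lists of indices; every list holds indices of a single class.

    Hypothetical helper. Use as DataLoader(dataset, batch_sampler=...); the
    loader must then not also be given batch_size, shuffle, sampler or drop_last.
    """

    def __init__(self, indices_for_target, batch_size, seed=None):
        self.groups = {k: list(v) for k, v in indices_for_target.items()}
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __iter__(self):
        batches = []
        for indices in self.groups.values():
            indices = indices[:]        # copy so the stored lists stay intact
            self.rng.shuffle(indices)   # mix samples from all source datasets
            # keep only full batches of this class
            for start in range(0, len(indices) - self.batch_size + 1, self.batch_size):
                batches.append(indices[start:start + self.batch_size])
        self.rng.shuffle(batches)       # classes appear in random order
        yield from batches

    def __len__(self):
        return sum(len(v) // self.batch_size for v in self.groups.values())


# Tiny demo with made-up indices and batch_size=2.
sampler = ShuffledSingleClassBatchSampler(
    {"cat": [0, 3, 7], "dog": [1, 2, 4, 6]}, batch_size=2, seed=0
)
for batch in sampler:
    print(batch)      # each list contains indices of one class only
print(len(sampler))   # 3
# loader = DataLoader(combined, batch_sampler=sampler)
```

Because the shuffling happens inside `__iter__`, each new epoch (each new iteration over the loader) produces a different set of batches.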

I haven't run this code, so it might not be exactly right, but hopefully it will put you on the right track.


torch.utils.data Docs