Erasiel

PyTorch conditional paired sampling from the same dataset

I have a PyTorch dataset of (x, y) pairs, where x is an input sample and y is some conditional information about x (e.g. a simple supervised classification dataset where the ys are labels).

I'm looking to create a sampling method that returns (x1, x2, y) triplets where x1 and x2 share the same conditional information y. Sticking with the supervised classification example, x1 and x2 should have the same label y. The sampling should also be random, i.e. I want to use it with a DataLoader so that the loader returns random pairs of samples that share the same conditional information. This last point is why, at least in my mind, a custom dataset with randomly generated but fixed sample pairings will not work: the DataLoader shuffles the indices, but that does not re-randomize the pairings themselves.

My thinking is that I need some kind of Sampler that has access to the dataset, generates these random conditional pairs on the fly, and returns their indices. However, a PyTorch dataset is indexed by a single index, not two. Is this a workable idea, or is there another (possibly more elegant) solution to this problem?
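For what it's worth, the single-index concern may not be a hard limit. As far as I understand the DataLoader internals, with the default automatic batching a map-style dataset is called as dataset[idx] for every object the sampler yields, so a sampler can yield (i, j) tuples as long as the dataset's __getitem__ knows how to unpack them. Below is a minimal sketch under that assumption; PairedIndexSampler and PairDataset are hypothetical names, and the labels are assumed to be precomputed, one per index.

```python
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Dataset, Sampler, TensorDataset


class PairedIndexSampler(Sampler):
    """Yields (i, j) index pairs whose samples share the same label.

    Hypothetical helper: `labels` is assumed to be a list of hashable labels,
    one per dataset index, computed ahead of time.
    """

    def __init__(self, labels, generator=None):
        self.labels = list(labels)
        self.generator = generator
        # label -> list of dataset indices carrying that label
        self.by_label = defaultdict(list)
        for idx, label in enumerate(self.labels):
            self.by_label[label].append(idx)

    def __iter__(self):
        # A fresh permutation of anchor indices on every pass (i.e. every epoch).
        order = torch.randperm(len(self.labels), generator=self.generator).tolist()
        for i in order:
            # Draw a partner uniformly among indices with the same label
            # (it may be i itself, e.g. if that label has only one sample).
            candidates = self.by_label[self.labels[i]]
            k = torch.randint(len(candidates), (1,), generator=self.generator).item()
            yield i, candidates[k]

    def __len__(self):
        return len(self.labels)


class PairDataset(Dataset):
    """Hypothetical wrapper whose __getitem__ takes an (i, j) tuple."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, pair):
        i, j = pair
        x1, y = self.dataset[i]
        x2, _ = self.dataset[j]
        return x1, x2, y

    def __len__(self):
        return len(self.dataset)


# Usage on toy data: label = value % 2, so paired values must have matching parity.
base = TensorDataset(torch.arange(10).unsqueeze(1), torch.arange(10) % 2)
labels = [int(base[i][1]) for i in range(len(base))]
loader = DataLoader(PairDataset(base), sampler=PairedIndexSampler(labels), batch_size=4)
for x1, x2, y in loader:
    print(x1.squeeze(1).tolist(), x2.squeeze(1).tolist(), y.tolist())
```

Because the partner is drawn inside __iter__, each new pass over the DataLoader produces a new set of pairings, and the sampler still composes with batch_size and the default collate function.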


Edit: I was able to "solve" this issue by creating a wrapper around the dataset that, in its __getitem__ method, randomly selects another sample with the same conditional information. I'm still interested in whether this problem can be solved elegantly with a Sampler or something similar that already exists in the PyTorch API.

My solution was the following:

import torch
from torch.utils.data import Dataset


class ConditionallyPairedDatasetWrapper(Dataset):
    def __init__(self, dataset: Dataset):
        super().__init__()
        self.dataset = dataset
        self.target_idxs = {}  # Dict[int, List[int]]: label -> indices with that label
        self._setup_target_idxs()

    def _setup_target_idxs(self):
        # Note: this loads every sample once just to read its label.
        for i in range(len(self.dataset)):
            _, y = self.dataset[i]
            # int() handles both Python ints and 0-d tensor labels, so the
            # dict keys are plain ints rather than (unhashable-by-value) tensors.
            self.target_idxs.setdefault(int(y), []).append(i)

    def _get_conditional_random_sample(self, target):
        # Draw a uniformly random index among the samples sharing `target`.
        candidates = self.target_idxs[int(target)]
        index = candidates[torch.randint(len(candidates), (1,)).item()]
        x, _ = self.dataset[index]
        return x

    def __getitem__(self, index):
        x1, y = self.dataset[index]
        x2 = self._get_conditional_random_sample(y)
        return x1, x2, y

    def __len__(self):
        return len(self.dataset)

Obviously the above code is disgusting, but it's a serviceable demonstration.
