Blade

Using DataLoader to sample with replacement in PyTorch

I have a dataset defined as follows:

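import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, N):
        self.N = N
        self.x = torch.rand(self.N, 10)           # N random feature vectors of size 10
        self.y = torch.randint(0, 3, (self.N,))   # N integer labels in {0, 1, 2}

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]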

During training, I would like to sample batches of m training samples with replacement; e.g. the first iteration uses data indices [1, 5, 6], the second uses [12, 3, 5], and so on. The total number of iterations should therefore be an input, rather than being fixed at N/m.

Is there a way to use DataLoader to handle this? If not, is there any other method than something of the form

import numpy as np

for i in range(iter):
    idx = np.random.choice(N, m, replace=True)  # m indices from range(N), drawn with replacement

to implement this?
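For reference, here is a minimal runnable sketch of that manual loop. It swaps np.random.choice for torch.randint so the index tensor can index the data tensors directly; N, m and num_iters are hypothetical values, and x_all/y_all are stand-ins for the dataset's x and y tensors.

```python
import torch

# Hypothetical sizes for illustration.
N, m, num_iters = 20, 3, 5
x_all = torch.rand(N, 10)           # stand-in for MyDataset.x
y_all = torch.randint(0, 3, (N,))   # stand-in for MyDataset.y

batches = []
for _ in range(num_iters):
    # m indices drawn independently and uniformly from [0, N); repeats are allowed.
    idx = torch.randint(0, N, (m,))
    batches.append((x_all[idx], y_all[idx]))
```

This works, but it bypasses DataLoader entirely, so features such as collate_fn and worker processes are not used.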

Answers (1)

Ivan

You can use a RandomSampler, a utility that sits between the dataset and the DataLoader and decides which indices are drawn:

>>> from torch.utils.data import DataLoader, RandomSampler
>>> ds = MyDataset(N)
>>> sampler = RandomSampler(ds, replacement=True, num_samples=M)

Here, sampler yields a total of M indices, each drawn uniformly from range(len(ds)) with replacement (with replacement=True, num_samples can also be larger than len(ds)). In your example, M = iter*m.
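To see what the sampler itself produces, here is a small sketch that iterates it directly. N, m and num_iters are illustrative values (assumptions), chosen so that M = num_iters * m is larger than N.

```python
import torch
from torch.utils.data import Dataset, RandomSampler


class MyDataset(Dataset):
    """The dataset from the question."""

    def __init__(self, N):
        self.N = N
        self.x = torch.rand(self.N, 10)
        self.y = torch.randint(0, 3, (self.N,))

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


# Hypothetical sizes: M is deliberately larger than N.
N, m, num_iters = 5, 3, 4
M = num_iters * m  # 12 indices in total

ds = MyDataset(N)
sampler = RandomSampler(ds, replacement=True, num_samples=M)

# Iterating the sampler yields integer indices into ds, not the data itself.
indices = list(sampler)
print(len(indices))  # 12
# With M > N, at least one index must repeat (pigeonhole principle).
print(len(set(indices)) < M)  # True
```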

You can then initialize a DataLoader with this sampler (for your use case, you would set batch_size=m):

>>> dl = DataLoader(ds, sampler=sampler, batch_size=2)

Here is a possible result with N = 3, M = 2*len(ds) = 6, and batch_size = 2:

>>> for x, y in dl:
...     print(x, y)

tensor([[0.5541, 0.3596, 0.5180, 0.1511, 0.3523, 0.4001, 0.6977, 0.1218, 0.2458, 0.8735],
        [0.0407, 0.2081, 0.5510, 0.2063, 0.1499, 0.1266, 0.1928, 0.0589, 0.2789, 0.3531]]) 
tensor([1, 0])

tensor([[0.5541, 0.3596, 0.5180, 0.1511, 0.3523, 0.4001, 0.6977, 0.1218, 0.2458, 0.8735],
        [0.0431, 0.0452, 0.3286, 0.5139, 0.4620, 0.4468, 0.3490, 0.4226, 0.3930, 0.2227]]) 
tensor([1, 0])

tensor([[0.5541, 0.3596, 0.5180, 0.1511, 0.3523, 0.4001, 0.6977, 0.1218, 0.2458, 0.8735],
        [0.5541, 0.3596, 0.5180, 0.1511, 0.3523, 0.4001, 0.6977, 0.1218, 0.2458, 0.8735]]) 
tensor([1, 1])
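Putting it together for the setup in the question, a minimal sketch of a training loop with a fixed number of iterations might look like this. N, m and num_iters are hypothetical values, and the loop body only records batch shapes where a real training step would go.

```python
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler


class MyDataset(Dataset):
    """The dataset from the question: N random 10-d inputs with labels in {0, 1, 2}."""

    def __init__(self, N):
        self.N = N
        self.x = torch.rand(self.N, 10)
        self.y = torch.randint(0, 3, (self.N,))

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


# Hypothetical sizes: N data points, batch size m, fixed number of iterations.
N, m, num_iters = 10, 3, 7

ds = MyDataset(N)
# Draw exactly num_iters * m indices with replacement, so the loader yields num_iters batches.
sampler = RandomSampler(ds, replacement=True, num_samples=num_iters * m)
dl = DataLoader(ds, sampler=sampler, batch_size=m)

batch_shapes = []
for x, y in dl:
    # A real training step (forward pass, loss, optimizer step) would go here.
    batch_shapes.append((tuple(x.shape), tuple(y.shape)))

print(len(batch_shapes))  # 7, i.e. num_iters
```

Because num_samples is an exact multiple of m, every batch has m samples; otherwise the last batch would be smaller unless drop_last=True is passed to the DataLoader. Iterating over dl again draws a fresh set of num_iters * m indices.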
