Reputation: 91
I want to know how to use torch.utils.data.DataLoader in PyTorch, especially in the multi-worker case.
I found that one batch output from the DataLoader always comes from a single worker. I expected the DataLoader to hold a queue that stores data from all of the workers and to shuffle the queued samples so that each output batch is drawn from all of them; I believe this is how tf.data.Dataset works in TensorFlow.
Can we implement a similar function in PyTorch? I want to load a dataset from big serialized files (like TFRecord) using multiple workers. In this case, mixing source files within one batch, which means mixing the workers the samples come from, is important.
Please refer to the following code:
import random
import time

import torch


class MyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 50

    def __getitem__(self, idx):
        info = torch.utils.data.get_worker_info()
        time.sleep(random.uniform(0, 1))
        print("[{}]:{}".format(info.id, idx))
        return idx, info.id


if __name__ == '__main__':
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=2)
    for batch in dataloader:
        print(batch)
Output:
[0]:0
[1]:5
[0]:1
[1]:6
[0]:2
[0]:3
[1]:7
[0]:4
[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
[1]:8
[1]:9
[tensor([5, 6, 7, 8, 9]), tensor([1, 1, 1, 1, 1])]
[0]:10
[0]:11
[1]:15
[1]:16
[0]:12
[1]:17
...
Here, [0, 1, 2, 3, 4] and [0, 0, 0, 0, 0] in [tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])] mean that this batch consists of the 0th to 4th samples, all of which came from worker id 0.
Note that shuffle=True does not solve this problem; it only changes the order of the indices, and each batch is still filled by a single worker.
In this case, I want to get a batch like [tensor([0, 5, 1, 6, 2]), tensor([0, 1, 0, 1, 0])].
Upvotes: 6
Views: 4500
Reputation: 2839
I've implemented something simple to solve a similar problem, where I have large video files as training data and each worker is responsible for loading and preprocessing a single file and then yielding samples from it. The problem is that, as the OP describes, with PyTorch's default data-loading mechanism each batch contains samples from only a single video file.
First, let's review the problem. In this simplified code example, each worker yields a single tensor containing its zero-indexed worker id. With a batch size of 32 and 4 workers, we want each batch to contain 8 zeros, 8 ones, 8 twos and 8 threes.
from collections import defaultdict

import torch as T
import torch.utils.data as tdata


class Dataset(tdata.IterableDataset):
    def __init__(self, batch_size: int):
        self._bs = batch_size

    def __iter__(self):
        worker_info = tdata.get_worker_info()
        if not worker_info:
            raise NotImplementedError('Not implemented for num_workers=0')
        for _ in range(self._bs):
            yield T.tensor([worker_info.id])


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
loader = tdata.DataLoader(dataset,
                          batch_size=batch_size,
                          num_workers=num_workers)
for batch in loader:
    counts = defaultdict(int)
    for n in batch.numpy().flatten():
        counts[n] += 1
    print(dict(counts))
Instead the code prints:
{0: 32}
{1: 32}
{2: 32}
{3: 32}
This means that the first batch contains samples only from worker 0, the second only from worker 1, and so on. To remedy this, we will set the batch size in the DataLoader to batch_size // num_workers and use a simple wrapper around the DataLoader to pool samples from each worker into our batch:
def pooled_batches(loader):
    loader_it = iter(loader)
    while True:
        samples = []
        for _ in range(loader.num_workers):
            try:
                samples.append(next(loader_it))
            except StopIteration:
                pass
        if len(samples) == 0:
            break
        else:
            yield T.cat(samples, dim=0)
batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
per_worker = batch_size // num_workers
loader = tdata.DataLoader(dataset,
                          batch_size=per_worker,
                          num_workers=num_workers)
for batch in pooled_batches(loader):
    counts = defaultdict(int)
    for n in batch.numpy().flatten():
        counts[n] += 1
    print(dict(counts))
And the code now prints
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
as expected. Note that this approach assumes batch_size is evenly divisible by num_workers; if it is not, the pooled batches will be smaller than batch_size, and if the workers' streams run out at different times, the final batches may mix workers unevenly.
Upvotes: 2
Reputation: 6864
Note that a multi-worker DataLoader with a batch_size specified is going to load multiple batches in parallel, so essentially each batch always comes from a single worker. However, I have achieved something close to what you require by doing as follows:
1. Set the batch size to 1, so every worker only yields one sample at a time.
2. Write a background process that iterates through the DataLoader, fetches one sample at a time and inserts it into a queue. This makes it possible to have the samples in a different order in the queue, rather than having worker-specific batches.
3. Have a batching mechanism, like a collate_fn, which takes a number of samples equal to your batch size from the queue and feeds them to the model (see the sketch below).
If you want to be more specific in batch creation, say, picking particular samples from specific workers, you can have multiple queues. Your collate procedure should be modified to account for the multiple queues and choose from them. But I doubt that kind of specificity is needed.
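A minimal sketch of this idea, assuming the loader yields plain tensors with a leading batch dimension of 1; the helper name queued_batches, the queue size, and the use of a background thread instead of a process are illustrative choices, not part of the answer:
import queue
import threading

import torch


def queued_batches(loader, batch_size, maxsize=128):
    # Hypothetical helper: pool the single-sample outputs of a
    # multi-worker DataLoader (built with batch_size=1) into mixed
    # batches via a shared queue.
    q = queue.Queue(maxsize=maxsize)
    sentinel = object()

    def fill():
        # Background thread: fetch one sample at a time from the
        # DataLoader and insert it into the queue.
        for sample in loader:
            q.put(sample)
        q.put(sentinel)  # signal that the loader is exhausted

    threading.Thread(target=fill, daemon=True).start()

    samples = []
    while True:
        item = q.get()
        if item is sentinel:
            break
        samples.append(item)
        if len(samples) == batch_size:
            # Each item has shape [1, ...], so concatenating along
            # dim 0 forms one batch from whichever workers produced them.
            yield torch.cat(samples, dim=0)
            samples = []
    if samples:
        yield torch.cat(samples, dim=0)  # final partial batch
It could be used roughly as for batch in queued_batches(loader, batch_size=32): ..., where loader is a DataLoader created with batch_size=1 and num_workers > 0. A multiprocessing queue would work the same way if a separate process is preferred over a thread, as the answer suggests.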
Upvotes: 0