Reputation: 5309
I wonder if there is a way to distribute the dataloader/Dataset across many CPUs, even when using a single GPU.
Specifically, I would like to have a Dataset class whose __getitem__
function is distributed across many different CPUs (using MPI maybe, but any other way is also good).
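For illustration, a minimal (hypothetical) sketch of the kind of Dataset I mean, where __getitem__ does the expensive parsing I would like to spread over many CPUs:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # hypothetical example; names and the parsing logic are placeholders
    def __init__(self, file_paths):
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # this per-sample parsing/preprocessing is the expensive part
        # that I would like to run on many CPUs
        return torch.load(self.file_paths[idx])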
Thanks
EDIT: My title was erroneously edited; I am not trying to distribute the model itself, I only want to distribute the data loading/parsing for the model.
EDIT 2: Some interesting discussion in this direction is available here.
Upvotes: 1
Views: 2233
Reputation: 804
Fetching data from remote server in pytorch dataloader is more or less a duplicate of your question, so I can suggest the same answer.
I've written RPCDataloader to distribute dataloader workers on remote servers. It doesn't use MPI (yet) because the bandwidth of plain TCP sockets (over IB) was sufficient in my case, and I can get the node configuration from SLURM.
It takes three steps to use:
1. Start the workers on the remote nodes:
python -m rpcdataloader.launch --host=0.0.0.0 --port=xxxx
2. Instantiate the (remote) dataset:
dataset = rpcdataloader.RPCDataset(
    workers=['node01:6543', 'node02:5432'],
    dataset=torchvision.datasets.ImageFolder,
    root=args.data_path + "/train",
    transform=train_transform)
3. Instantiate the dataloader:
dataloader = rpcdataloader.RPCDataloader(
    dataset,
    batch_size=2,
    shuffle=True,
    pin_memory=True)
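It can then be iterated like a regular DataLoader; a minimal usage sketch under that assumption (the variable names are placeholders, not part of the library's API):

for images, targets in dataloader:
    # pin_memory=True allows asynchronous host-to-device copies
    images = images.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    # ... forward/backward pass on the single GPU ...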
Upvotes: 1
Reputation: 887
You can do this, of course, but mind you: it is not always very effective for general machine learning needs because of the hefty communication costs.
Use DistributedDataParallel
Implements distributed data parallelism that is based on torch.distributed package at the module level.
This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension. The module is replicated on each machine and each device, and each such replica handles a portion of the input. During the backwards pass, gradients from each node are averaged.
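For concreteness, a minimal DDP sketch (assuming a torchrun launch; the model and data here are placeholders):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR/PORT and LOCAL_RANK
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(10, 1).cuda(local_rank)      # placeholder model
    model = DDP(model, device_ids=[local_rank])

    # each replica gets a different shard of the data via DistributedSampler
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)

    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for x, y in loader:
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        loss = torch.nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()          # gradients are averaged across replicas here
        opt.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()    # launch with: torchrun --nproc_per_node=<num_gpus> this_script.py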
In practice, I'd recommend you use the pytorch_lightning
package to reduce some of the boilerplate code you need to write for this to work.
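A rough sketch of the same setup with pytorch_lightning (assuming the 1.6+ Trainer API; the module and data are again placeholders):

import pytorch_lightning as pl
import torch

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)          # placeholder model

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4)

# Lightning spawns the processes, attaches a DistributedSampler,
# and averages gradients, so the manual DDP boilerplate above disappears
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
trainer.fit(LitModel(), loader)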
Reference: DistributedDataParallel, pytorch_lightning
Upvotes: 3