Oren

Reputation: 5309

Running a PyTorch dataloader/Dataset on multiple distributed CPUs

I wonder if there is a way to distribute the dataloader/Dataset to many CPUs, even when using a single GPU. Specifically, I would like to have a Dataset class whose __getitem__ function is distributed across many different CPUs (using MPI maybe, but any other way is also good).

Thanks

EDIT My title was erroneously edited; I am not trying to distribute the model itself, I only want to distribute the data loading/parsing.

EDIT - 2 Some interesting discussion in this direction is available here

Upvotes: 1

Views: 2233

Answers (2)

pixelou

Reputation: 804

Fetching data from remote server in pytorch dataloader is more or less a duplicate of your question, so I can suggest the same answer.

I've written RPCDataloader to distribute dataloader workers on remote servers. It's not using MPI (yet) because the bandwidth of plain TCP sockets (over IB) was sufficient in my case, and I can get the node configuration from SLURM.

It takes 3 steps to use:

  1. Start workers on the data node: python -m rpcdataloader.launch --host=0.0.0.0 --port=xxxx
  2. Create the dataset in the trainer(s); this will instantiate the actual datasets on the workers and a placeholder object in the trainer(s):
dataset = rpcdataloader.RPCDataset(
    workers=['node01:6543', 'node02:5432'],
    dataset=torchvision.datasets.ImageFolder,
    root=args.data_path + "/train",
    transform=train_transform)
  3. Create the dataloader:
dataloader = rpcdataloader.RPCDataloader(
    dataset,
    batch_size=2,
    shuffle=True,
    pin_memory=True)
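
For completeness, here is a minimal usage sketch (assuming the resulting dataloader yields (images, labels) batches like a regular torch.utils.data.DataLoader, as in the ImageFolder example above):

import torch

device = torch.device("cuda")

for images, labels in dataloader:
    # pin_memory=True lets the host-to-device copy be non-blocking
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    # ... forward/backward pass on the single GPU ...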

Upvotes: 1

Mano

Reputation: 887

You can do this, of course, but mind you: it is not always very effective for general machine learning needs, due to the hefty communication costs. Use DistributedDataParallel:

Implements distributed data parallelism that is based on torch.distributed package at the module level.

This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension. The module is replicated on each machine and each device, and each such replica handles a portion of the input. During the backwards pass, gradients from each node are averaged.
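
For illustration, here is a rough sketch of the usual DDP wiring in plain PyTorch. It assumes the script is launched with torchrun (which sets LOCAL_RANK for each process); the tiny linear model and random tensors are just placeholders:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Launched via: torchrun --nproc_per_node=<num_gpus> train.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Toy model and dataset, only to show the wiring.
model = DDP(torch.nn.Linear(10, 2).cuda(local_rank), device_ids=[local_rank])
dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))

# DistributedSampler gives each process a disjoint shard of the indices;
# the per-process DataLoader workers then run __getitem__ in parallel.
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)

loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(2):
    sampler.set_epoch(epoch)          # reshuffle differently each epoch
    for x, y in loader:
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        loss = loss_fn(model(x), y)
        opt.zero_grad()
        loss.backward()               # gradients are all-reduced across ranks here
        opt.step()

dist.destroy_process_group()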

In practice, I'd recommend using the pytorch_lightning package to reduce some of the boilerplate code you need to write for this to work.
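
A minimal sketch of what that can look like (exact Trainer arguments vary across Lightning versions; LitModel and the random-tensor dataset below are just placeholders):

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class LitModel(pl.LightningModule):
    # Placeholder module: a single linear layer with a classification loss.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))
loader = DataLoader(dataset, batch_size=32, num_workers=4)

# Lightning handles init_process_group, DistributedSampler injection and
# DDP wrapping when the "ddp" strategy is requested.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=2)
trainer.fit(LitModel(), train_dataloaders=loader)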

References: DistributedDataParallel, pytorch_lightning

Upvotes: 3
