raaj

Reputation: 3291

Pytorch - Distributed Data Parallel Confusion

I was just looking at the DDP Tutorial:

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

According to this:

It’s common to use torch.save and torch.load to checkpoint modules during training and recover from checkpoints. See SAVING AND LOADING MODELS for more details. When using DDP, one optimization is to save the model in only one process and then load it to all processes, reducing write overhead. This is correct because all processes start from the same parameters and gradients are synchronized in backward passes, and hence optimizers should keep setting parameters to the same values. If you use this optimization, make sure all processes do not start loading before the saving is finished. Besides, when loading the module, you need to provide an appropriate map_location argument to prevent a process to step into others’ devices. If map_location is missing, torch.load will first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. For more advanced failure recovery and elasticity support, please refer to TorchElastic.

I don't understand what this means. Shouldn't only one process/the first GPU be saving the model? Is saving and loading how the weights are shared across the processes/GPUs?

Upvotes: 2

Views: 4734

Answers (1)

Michael Jungo

Reputation: 32972

When you're using DistributedDataParallel you have the same model across multiple devices, which are being synchronised to have the exact same parameters.
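For context, this is roughly what the per-process setup looks like (a minimal sketch; it assumes the process group has already been initialised, rank is this process's index, and MyModel is a placeholder for your own module):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Each process holds a full replica of the model on its own GPU.
# DDP keeps the replicas identical by averaging gradients in backward().
model = MyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])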

When using DDP, one optimization is to save the model in only one process and then load it to all processes, reducing write overhead.

Since the models are identical, it is unnecessary to save them from all processes, as that would just write the same parameters multiple times. For example, with 4 processes/GPUs you would write the same file 4 times instead of once. That can be avoided by only saving from the main process.
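In code, that save-from-one-process pattern could look like this (a sketch; ddp_model and CHECKPOINT_PATH are assumed names, and the process group is assumed to be initialised so that dist.get_rank() works):

import torch
import torch.distributed as dist

CHECKPOINT_PATH = "model.checkpoint"

if dist.get_rank() == 0:
    # Only the main process writes the file; the other replicas hold
    # exactly the same parameters, so saving them again is redundant.
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)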

That optimisation only concerns saving the model. If you want to load the model right after saving it, you need to be more careful.

If you use this optimization, make sure all processes do not start loading before the saving is finished.

If you save it in only one process, that process needs time to write the file. In the meantime all the other processes keep running, and they might try to load the file before it has been fully written to disk, which can lead to all sorts of unexpected behaviour or failures: the file might not exist yet, you might read an incomplete file, or you might load an older version of the model (if you overwrite the same file).
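One way to guarantee that ordering is a collective barrier right after the save (a sketch, using the same assumed names as above; dist.barrier() blocks every process until all of them have reached it):

import torch.distributed as dist

if dist.get_rank() == 0:
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

# No process gets past this line until every process has reached it,
# so rank 0 has finished writing before anyone starts loading.
dist.barrier()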

Besides, when loading the module, you need to provide an appropriate map_location argument to prevent a process to step into others’ devices. If map_location is missing, torch.load will first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices.

When saving the parameters (or any tensor, for that matter), PyTorch also records the device they were stored on. Say you save them from the process that used GPU 0 (device = "cuda:0"): that device is recorded, and when you load them, the parameters are automatically put back onto it. But if you load them in the process that uses GPU 1 (device = "cuda:1"), they will still end up on "cuda:0". Now instead of using multiple GPUs, you have the same model multiple times on a single GPU. Most likely you will run out of memory, but even if you don't, you won't be utilising the other GPUs anymore.

To avoid that problem, you should pass the appropriate device as the map_location argument of torch.load.

# Load the parameters directly onto the GPU this process uses
torch.load(PATH, map_location="cuda:1")

# Or load them on the CPU and later move the model with .to(device)
torch.load(PATH, map_location="cpu")
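
Putting it all together, a per-process checkpoint round trip could look like this (a sketch along the lines of the tutorial's pattern; rank, ddp_model and CHECKPOINT_PATH are assumed to be set up as above):

import torch
import torch.distributed as dist

if dist.get_rank() == 0:
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

# Make sure the file is fully written before anyone reads it.
dist.barrier()

# Remap tensors saved from "cuda:0" onto this process's own GPU.
map_location = {"cuda:0": f"cuda:{rank}"}
ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))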

Upvotes: 3
