Luca Di Liello

Reputation: 1643

Get local world size in torch distributed training

Suppose I have 2 machines with 4 GPUs each. Suppose that each instance of the training algorithm requires 2 GPUs. I would like to run 4 processes, 2 for each machine, each process using 2 GPUs.

How can I make each process retrieve the number of local processes running on the same machine? I can detect the world size with

torch.distributed.get_world_size()

and the global rank with

torch.distributed.get_rank()

But, since I would like not to hard-code parameters, is there a way to discover that 2 processes are running on each node? This would be useful for assigning GPUs to each process equally.

Example: Suppose I know that a machine has 4 GPUs and that there are 2 processes on it; then I would assign GPUs [0, 1] to the process with local rank 0 and GPUs [2, 3] to the process with local rank 1. I know the total number of processes, but I cannot tell whether they are on the same machine, so I cannot decide how many GPUs each one is allowed to use.

I need a function that would be called torch.distributed.get_local_world_size()
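To illustrate what I want to compute, here is a sketch with placeholder values (the whole point is that I don't know how to obtain local_rank and local_world_size at runtime):

import torch

# Placeholder values: in reality these are exactly what I want to retrieve at runtime.
local_rank = 0          # e.g. 0 or 1 on each machine
local_world_size = 2    # 2 processes per machine

n_gpus = torch.cuda.device_count()              # e.g. 4 GPUs on this machine
gpus_per_process = n_gpus // local_world_size   # e.g. 4 // 2 = 2
my_gpus = list(range(local_rank * gpus_per_process,
                     (local_rank + 1) * gpus_per_process))
print(my_gpus)  # local rank 0 -> [0, 1], local rank 1 -> [2, 3]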

Upvotes: 11

Views: 20477

Answers (3)

Finncent Price

Reputation: 827

If you are using torchrun, you can get the local world size from the environment variables that torchrun sets.

The direct answer to your question, if that is the case, is

import os

local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
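torchrun also exports LOCAL_RANK, RANK and WORLD_SIZE for every worker, so a minimal sketch that reads all of them (it must be run under torchrun, otherwise the variables are not set) could look like this:

import os

# All of these are set by torchrun for each worker process.
local_rank = int(os.environ["LOCAL_RANK"])               # rank within this node
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])   # processes on this node
rank = int(os.environ["RANK"])                           # global rank
world_size = int(os.environ["WORLD_SIZE"])               # total number of processes

local_world_size can then be used to split the node's GPUs among the local processes, as described in the question.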

Upvotes: 2

Shomy

Reputation: 123

The launcher sets the environment variables at startup, and the local world size can be read from them (by default it equals the number of GPUs on the node):

# -*- coding: utf-8 -*-
import os
import torch.distributed as dist
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

dist.init_process_group('nccl')
local_rank = args.local_rank
local_world_size = os.environ["LOCAL_WORLD_SIZE"]                                                                                                                           
print(f'{local_rank = }; { local_world_size = }')

Run: python3 -m torch.distributed.launch --nproc_per_node=4 test.py

The output:

local_rank = 0;  local_world_size = '4'
local_rank = 3;  local_world_size = '4'
local_rank = 1;  local_world_size = '4'
local_rank = 2;  local_world_size = '4'

Upvotes: 7

Xander

Reputation: 144

torch.cuda.device_count() is essentially the local world size and could be useful for determining how many GPUs you have available on each machine. If you can't do that for some reason, using plain MPI might help:

from mpi4py import MPI
import torch

comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # global rank of this process

ngpus = torch.cuda.device_count()
print(ngpus, "gpus on machine, rank", rank)  # here's the local world size for each process

but I think it would work to just call torch.cuda.device_count() in any case, without adding this dependency. I am pretty new here, so if you can, please let me know how this answer can be improved.
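If you do go the MPI route, one way to get the actual number of processes per node (rather than the GPU count) is to split the world communicator by shared-memory node; a sketch, assuming the script is started with an MPI launcher such as mpirun:

from mpi4py import MPI

comm = MPI.COMM_WORLD
# Processes running on the same node end up in the same sub-communicator.
local_comm = comm.Split_type(MPI.COMM_TYPE_SHARED)
local_rank = local_comm.Get_rank()
local_world_size = local_comm.Get_size()
print(f"local rank {local_rank} of {local_world_size} on this node")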

Upvotes: 9
