Eric Fang

Reputation: 1

Why does using 4 GPUs for distributed inference slow down compared to 2 GPUs?

I am working on a distributed inference script to process a dataset using multiple GPUs. My setup uses PartialState from the Hugging Face accelerate library, and I split the data across GPUs with:

distributed_state = PartialState()

with distributed_state.split_between_processes(vPaths) as vPaths:
    # Processing logic here

The idea is to distribute batches of data across GPUs and perform inference in parallel. I first tested it with a small dataset and observed that 2 GPUs are about twice as fast as 1 GPU, which was expected. However, when I scaled it to 4 GPUs, the total processing time did not improve; it actually became slower than with 2 GPUs.

I suspect the issue lies with the use of the following synchronization call:

distributed_state.wait_for_everyone()

It seems like this synchronization might be introducing extra overhead, negatively impacting the inference time when using more GPUs.
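
One way to see why more GPUs can hurt here: a barrier after every batch makes each step take as long as its slowest shard, so any imbalance between GPUs is paid on every single batch. A toy calculation with hypothetical per-shard timings (these numbers are made up for illustration):

```python
# Toy model of per-batch synchronization cost (hypothetical timings).
# With a barrier after each batch, the step time is the max over the
# per-GPU shard times, and the total is the sum of those maxima.

def synced_total(per_gpu_batch_times):
    """Total time when all GPUs wait at a barrier after every batch."""
    return sum(max(step) for step in zip(*per_gpu_batch_times))

# 2 GPUs, evenly loaded: each of 3 batches takes ~1.0s on each GPU.
two_gpus = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]

# 4 GPUs, uneven shards: one straggler GPU stretches every step.
four_gpus = [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5],
             [0.5, 0.5, 0.5], [1.2, 1.2, 1.2]]

print(synced_total(two_gpus))             # 3.0
print(round(synced_total(four_gpus), 2))  # 3.6 -- slower with more GPUs
```

The barrier itself is cheap; what costs you is that the fastest GPUs idle at it while the slowest shard finishes, once per batch.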

My Question

Would it be reasonable to remove the wait_for_everyone() call and split the data independently for each GPU, letting them process their own data asynchronously? After inference, I could then aggregate the individual outputs (e.g., multiple JSONL files) into a single result.
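
The aggregation step itself is straightforward with the standard library. A minimal sketch of merging per-process JSONL shards into one file (function name and file pattern are hypothetical):

```python
import glob
import json

def merge_jsonl(pattern, out_path):
    """Concatenate all JSONL shards matching `pattern` into one file."""
    records = []
    for shard in sorted(glob.glob(pattern)):
        with open(shard) as f:
            records.extend(json.loads(line) for line in f if line.strip())
    with open(out_path, "w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")
    return records

# Usage (hypothetical file names), run once on the main process
# after all workers have finished:
# merge_jsonl("captions_rank*.jsonl", "captions.jsonl")
```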

Alternatively, is there a better strategy to optimize multi-GPU inference to avoid bottlenecks introduced by synchronization?

I followed the method from this example provided by Hugging Face. Below is a simplified version of my code:

...
    distributed_state = PartialState()

    for vPaths in tqdm(data_loader, total=len(data_loader)):
        input_paths = []
        results = []

        with distributed_state.split_between_processes(vPaths) as vPaths:
            for path in vPaths:
                path = os.path.join(root_dir, path)
                input_paths.append(path)  # keep the path for gathering
                results.append(generate(path, model, processor))

        distributed_state.wait_for_everyone()

        gathered_captions = gather_object(results)
        gathered_paths = gather_object(input_paths)

        if distributed_state.is_main_process:
            ...
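
If the per-batch barrier is the bottleneck, one alternative is to split the full path list once up front, let each process run its entire shard to completion, and only reconcile the output files at the end. The sharding itself needs no communication; a minimal sketch (in accelerate, rank and world_size would come from distributed_state.process_index and distributed_state.num_processes):

```python
def shard_for_rank(paths, rank, world_size):
    """Strided shard: process `rank` takes every world_size-th item."""
    return paths[rank::world_size]

# Hypothetical path list, sharded across 4 processes.
paths = [f"img_{i}.jpg" for i in range(10)]
shards = [shard_for_rank(paths, r, 4) for r in range(4)]

# Every path lands in exactly one shard, with no overlap.
assert sorted(p for s in shards for p in s) == sorted(paths)
```

Each process would then write its results to its own JSONL file and exit; a single wait_for_everyone() (or just waiting for all jobs to finish) before merging the files replaces the per-batch barrier.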

Upvotes: 0

Views: 64

Answers (0)
