Deshwal

Reputation: 4162

How does the data splitting actually work in Multi GPU Inference for Accelerate when used in a batched inference setting?

I followed the code given in this github issue and this medium blog

I ran the batched experiment with num_processes=1 and num_processes=4 and got results, but I'm confused because I thought the results would be in order. If they are not in order, then I won't be able to map them to the ground truth.

For example, let's say my data_length=5 and my batch=3, so with num_processes=1 I get results [[1,2,3], [4,5]]. I'm expecting that with num_processes=4 I should get the same results once I flatten them.

But they are coming out of order. What am I doing wrong?

NOTE: I used zip(text, label) while passing data to the processes to get the correct mapping, BUT that is not the question.
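
To make my expectation concrete, here is a minimal, model-free sketch of how I assumed split_between_processes and gather_object behave (the items list and the rank tags below are only for illustration, they are not part of my real pipeline):

from accelerate import Accelerator, notebook_launcher
from accelerate.utils import gather_object

def ordering_check():
    accelerator = Accelerator()
    items = list(range(5))  # stand-in for my 5 (text, label) pairs

    # I assumed each process gets a contiguous slice of `items`
    with accelerator.split_between_processes(items) as shard:
        local = [f"rank{accelerator.process_index}:{x}" for x in shard]

    gathered = gather_object(local)
    if accelerator.is_main_process:
        # My expectation: flattening across processes preserves the original order 0..4
        print(gathered)

notebook_launcher(ordering_check, num_processes=2)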

Below is the code:

import os
import random

import numpy as np
import torch
from tqdm import tqdm

from accelerate import Accelerator, notebook_launcher
from accelerate.utils import gather_object, set_seed


def seed_everything(seed=13):
    # Seed every RNG source (Python, NumPy, PyTorch, Accelerate) for reproducibility
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    set_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed=13)


def test():
    accelerator = Accelerator()
    accelerator.wait_for_everyone() 
    seed_everything(seed = 13)
    
    model = load_model(model="my_model_path",
                       lora="./my_lora_checkpoint/checkpoint-8200",
                       device={"": accelerator.process_index},  # put the model on this process's GPU
                       num_labels=NUM_LABELS,
                       merge_unload=False)
    
    
    # Each process receives its own slice of the zipped (text, label) pairs
    with accelerator.split_between_processes(zipped_text_label) as prompts:

        res = {"pred_probs": [], "pred_labels": []}

        BATCH_SIZE = 10

        # Chunk this process's share of the data into batches of BATCH_SIZE
        BATCHES = [prompts[i:i + BATCH_SIZE] for i in range(0, len(prompts), BATCH_SIZE)]
        print(len(BATCHES[0]))

        for batch in tqdm(BATCHES):
            text_batch = [i[0] for i in batch]
            score_batch = [i[1] for i in batch]

            with torch.no_grad():
                inputs = tokenizer(text_batch, truncation=True, max_length=MAX_LENGTH,
                                   padding="max_length", return_tensors="pt").to(model.device)
                logits = model(**inputs).logits.cpu().to(torch.float32)
                probs = torch.softmax(logits, dim=1).numpy()
                res["pred_probs"].append(probs.tolist())
                res["pred_labels"].append(probs.argmax(axis=1).tolist())

        # Wrap in a list so gather_object concatenates one dict per process
        res = [res]
    
    # Collect the per-process result dicts into a single list
    result = gather_object(res)
    if accelerator.is_main_process:
        print(result)


notebook_launcher(test, num_processes=1)                              
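
In case it clarifies what I mean by mapping predictions back to the ground truth, here is a stripped-down sketch where every pair carries its original index through the split and the gathered results are re-sorted afterwards (the index bookkeeping and placeholder predictions are just for illustration, not what my actual code does):

def test_with_indices():
    accelerator = Accelerator()

    # Tag each (text, label) pair with its position in the original list
    # (assuming zipped_text_label is a list of (text, label) pairs)
    indexed = list(enumerate(zipped_text_label))

    with accelerator.split_between_processes(indexed) as shard:
        # Placeholder predictions; the real code would run the model here
        local = [{"idx": idx, "label": label, "pred": None}
                 for idx, (text, label) in shard]

    gathered = gather_object(local)
    if accelerator.is_main_process:
        # Sorting by the original index makes the mapping independent of gather order
        gathered.sort(key=lambda d: d["idx"])
        print([d["idx"] for d in gathered])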

Upvotes: 2

Views: 372

Answers (1)

Yaoming Xuan

Reputation: 176

In Accelerate, the proper way to control batch splitting is the DataLoaderConfiguration class. Here is an example:

import accelerate
from accelerate.utils import DataLoaderConfiguration

dataloader_config = DataLoaderConfiguration(dispatch_batches=True, split_batches=False)
accelerator = accelerate.Accelerator(dataloader_config=dataloader_config)

In this example, the batches produced by your dataloader are dispatched to the different processes: each process receives exactly one batch, and different processes receive different batches. Because you normally cannot control which asynchronous process runs faster, you should expect the batches to be dispatched in a non-deterministic order in practice.

By the way, the split_batches parameter controls whether each batch produced by your dataloader is split into smaller batches, with one smaller batch dispatched to each process.
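
For completeness, here is a rough sketch of how this configuration is typically wired up with a DataLoader passed through accelerator.prepare() (the dataset and batch size below are placeholders, not taken from the question):

import torch
from torch.utils.data import DataLoader, TensorDataset

import accelerate
from accelerate.utils import DataLoaderConfiguration

# dispatch_batches=True: batches are formed by the dataloader and dispatched, one per process;
# split_batches=False: each dispatched batch is NOT subdivided across processes.
dataloader_config = DataLoaderConfiguration(dispatch_batches=True, split_batches=False)
accelerator = accelerate.Accelerator(dataloader_config=dataloader_config)

dataset = TensorDataset(torch.arange(50, dtype=torch.float32).unsqueeze(1))  # placeholder data
loader = accelerator.prepare(DataLoader(dataset, batch_size=10, shuffle=False))

for batch in loader:
    # Each process now iterates over its own share of the batches
    ...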

Upvotes: 0
