Zyzyx


How to properly handle multi-GPU execution failing on one GPU due to OOM

I'm working with multiple GPUs handling large amounts of data. I want to build an out-of-memory (OOM) handling mechanism that skips the current batch on all GPUs if any of them runs out of memory.

However, for reasons I don't understand, only the GPU that ran out of memory reaches the dist.all_reduce synchronisation point. The other ranks log nothing after their first print (the batch-contents line), and execution freezes and then terminates without any further message.

I feel like I'm missing something simple, or some piece of distributed-computing trivia I don't know. If anyone could point out my error, I'd be grateful.
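For reference, the "skip on all GPUs if any GPU hits an OOM" agreement can be isolated into a small helper. This is a minimal sketch. `any_rank_oom` is a hypothetical name, and the demo at the bottom uses the CPU `gloo` backend with an in-memory `HashStore` and a single process. Those choices are assumptions made only so the reduction logic can be run without GPUs, NCCL or a launcher. Each rank contributes 0 or 1, and a `MAX` all-reduce gives every rank the same answer.

```python
import torch
import torch.distributed as dist


def any_rank_oom(local_oom: bool, device: torch.device) -> bool:
    """Return True on every rank if at least one rank reported an OOM.

    Like any collective, this must be called by all ranks, in the same
    order relative to their other collective calls.
    """
    # 1-element integer tensor: 1 if this rank hit an OOM, else 0.
    flag = torch.tensor([1 if local_oom else 0], device=device)
    # MAX over {0, 1}: every rank receives 1 as soon as any rank sent 1.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())


# Single-process demo (assumption): gloo backend on CPU with an in-memory
# store, so no GPU, NCCL or network rendezvous is needed.
dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
cpu = torch.device("cpu")
result_no_oom = any_rank_oom(False, cpu)
result_oom = any_rank_oom(True, cpu)
dist.destroy_process_group()
print(result_no_oom, result_oom)
```

In a real multi-GPU job the same helper would be called with the rank's CUDA device under an NCCL process group. With `world_size=1` the demo only checks that a 0 stays 0 and a 1 stays 1.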

import torch
import torch.distributed
import idr_torch  # provides idr_torch.rank, this process's global rank


def train_epoch(model, loader, optimizer, device, loss_fn):
    skip_count = 0  # Number of batches skipped because some rank hit an OOM

    for batch_idx, data in enumerate(loader):
        if hasattr(data, 'stores') and isinstance(data.stores, list):
            for store in data.stores:
                if 'name' in store:
                    print(f"[rank {idr_torch.rank}] Batch {batch_idx} contains samples with names: {store['name']}")

        # Initialize OOM flag
        oom_flag = torch.tensor(0, device=device)

        try:
            # Move data to device
            data = data.to(device)
            optimizer.zero_grad()

            # Forward pass
            pred = model(data)

            # Compute loss
            loss = loss_fn(pred, data=data, device=device)

            # Backward pass
            loss.backward()

            # Optimizer step
            optimizer.step()

        except RuntimeError as e:
            if 'CUDA out of memory' in str(e):
                print(f"[rank {idr_torch.rank}] CUDA OOM at batch {batch_idx}. Skipping batch...")
                torch.cuda.empty_cache()

                # Log problematic batch
                if hasattr(data, 'stores'):
                    for store in data.stores:
                        if 'name' in store:
                            print(f"[rank {idr_torch.rank}] Problematic batch samples: {store['name']}")

                # Set OOM flag
                oom_flag = torch.tensor(1, device=device)

                # Clear gradients and cache to prevent residual state
                optimizer.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()

            else:
                raise  # Re-raise non-OOM exceptions with their original traceback

        # Synchronize OOM flag across ranks (ensures all GPUs check if any had an OOM)
        print(f"[rank {idr_torch.rank}] Waiting on OOM-flag synch in batch {batch_idx}...")
        torch.distributed.all_reduce(oom_flag, op=torch.distributed.ReduceOp.MAX)
        print(f"[rank {idr_torch.rank}] Synch complete at batch {batch_idx}.")

        # If any rank had OOM, skip the batch
        if oom_flag.item() > 0:
            print(f"[rank {idr_torch.rank}] Skipping synchronized batch {batch_idx} due to OOM...")
            skip_count += 1
            continue  # Skip optimizer step

    # Ensure memory cleanup after epoch
    torch.cuda.empty_cache()
    return True

Output:

            <trimmed above prints>
2024-12-11 17:10:24,414 - INFO - [rank 14] Batch 26 contains samples with names: ['test1', 'test2']
2024-12-11 17:10:24,414 - INFO - [rank 8] Batch 26 contains samples with names: ['test2', 'test3']
2024-12-11 17:10:24,414 - INFO - [rank 15] Batch 26 contains samples with names: ['test4', 'test5']
2024-12-11 17:10:30,923 - INFO - [rank 3] CUDA OOM at batch 26. Skipping batch...
2024-12-11 17:10:30,925 - INFO - [rank 3] Problematic batch samples: ['test6', 'test7']
2024-12-11 17:10:30,932 - INFO - [rank 3] Waiting on OOM-flag synch in batch 26...
2024-12-11 17:10:30,934 - INFO - [rank 3] Synch complete at batch 26.
            <execution dies silently>     

I've also tried adding an extra dist.barrier() before the all_reduce. With that, the OOM GPU hangs after printing "Waiting on OOM-flag synch...".
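The except clause in the training loop detects an OOM by matching the text "CUDA out of memory" in the message. Recent PyTorch releases raise a dedicated `torch.cuda.OutOfMemoryError`, which subclasses `RuntimeError`, so the check can also be done by type. The following is a minimal sketch. `is_cuda_oom` is a hypothetical helper name. It uses `getattr` so that it still works on older versions that lack the dedicated exception class, where it falls back to the original string match.

```python
import torch


def is_cuda_oom(exc: BaseException) -> bool:
    """Return True if exc looks like a CUDA out-of-memory error."""
    # Newer PyTorch versions raise torch.cuda.OutOfMemoryError (a RuntimeError subclass).
    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_type is not None and isinstance(exc, oom_type):
        return True
    # Fallback for versions without the dedicated class: match the message text.
    return isinstance(exc, RuntimeError) and "CUDA out of memory" in str(exc)


print(is_cuda_oom(RuntimeError("CUDA out of memory. Tried to allocate 2.00 GiB")))
print(is_cuda_oom(RuntimeError("shape mismatch")))
```

Inside the loop this would replace the string test, e.g. `if is_cuda_oom(e): ...` in the `except RuntimeError as e:` branch.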
