Reputation: 534
I'm working with multiple GPUs handling large amounts of data. I want to create an out-of-memory (OOM) catch system that skips the current batch on all GPUs if any one of them runs out of memory.
However, for reasons I don't understand, only the GPU that hit the OOM reaches the dist.all_reduce synchronisation point. The other ranks don't log anything beyond the first print, and execution freezes and then terminates without any further message.
I feel like I'm missing either something simple or some piece of distributed-computing trivia I don't know. If anyone could point out my error I'd be grateful.
import torch
import torch.distributed
import idr_torch  # cluster helper exposing the process rank

def train_epoch(model, loader, optimizer, device, loss_fn):
    for batch_idx, data in enumerate(loader):
        if hasattr(data, 'stores') and isinstance(data.stores, list):
            for store in data.stores:
                if 'name' in store:
                    print(f"[rank {idr_torch.rank}] Batch {train_count} contains samples with names: {store['name']}")

        # Initialize OOM flag
        oom_flag = torch.tensor(0, device=device)

        try:
            # Move data to device
            data = data.to(device)
            optimizer.zero_grad()
            # Forward pass
            pred = model(data)
            # Compute loss
            loss = loss_fn(pred, data=data, device=device)
            # Backward pass
            loss.backward()
            # Optimizer step
            optimizer.step()
        except RuntimeError as e:
            if 'CUDA out of memory' in str(e):
                print(f"[rank {idr_torch.rank}] CUDA OOM at batch {batch_idx}. Skipping batch...")
                torch.cuda.empty_cache()

                # Log problematic batch
                if hasattr(data, 'stores'):
                    for store in data.stores:
                        if 'name' in store:
                            print(f"[rank {idr_torch.rank}] Problematic batch samples: {store['name']}")

                # Set OOM flag
                oom_flag = torch.tensor(1, device=device)

                # Clear gradients and cache to prevent residue state
                optimizer.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()
            else:
                raise e  # Raise non-OOM exceptions

        # Synchronize OOM flag across ranks (ensures all GPUs check if any had an OOM)
        print(f"[rank {idr_torch.rank}] Waiting on OOM-flag synch in batch {batch_idx}...")
        torch.distributed.all_reduce(oom_flag, op=torch.distributed.ReduceOp.MAX)
        print(f"[rank {idr_torch.rank}] Synch complete at batch {batch_idx}.")

        # If any rank had OOM, skip the batch
        if oom_flag.item() > 0:
            print(f"[rank {idr_torch.rank}] Skipping synchronized batch {batch_idx} due to OOM...")
            skip_count += 1
            continue  # Skip optimizer step

    # Ensure memory cleanup after epoch
    torch.cuda.empty_cache()
    return True
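For clarity, the collective check I'm trying to build boils down to this pattern (any_rank_oom is just a made-up helper name here; it assumes the default process group is already initialised with one process per GPU):

import torch
import torch.distributed as dist

def any_rank_oom(local_oom: bool, device) -> bool:
    """Return True on every rank if at least one rank hit an OOM."""
    flag = torch.tensor(1 if local_oom else 0, device=device)
    # MAX over all ranks: the flag becomes 1 everywhere if it is 1 anywhere
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item() > 0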
Output:
<earlier prints trimmed>
2024-12-11 17:10:24,414 - INFO - [rank 14] Batch 26 contains samples with names: ['test1', 'test2']
2024-12-11 17:10:24,414 - INFO - [rank 8] Batch 26 contains samples with names: ['test2', 'test3']
2024-12-11 17:10:24,414 - INFO - [rank 15] Batch 26 contains samples with names: ['test4', 'test5']
2024-12-11 17:10:30,923 - INFO - [rank 3] CUDA OOM at batch 26. Skipping batch...
2024-12-11 17:10:30,925 - INFO - [rank 3] Problematic batch samples: ['test6', 'test7']
2024-12-11 17:10:30,932 - INFO - [rank 3] Waiting on OOM-flag synch in batch 26...
2024-12-11 17:10:30,934 - INFO - [rank 3] Synch complete at batch 26.
<execution dies silently>
I've also tried adding an extra dist.barrier() before the all_reduce (roughly as sketched below); that only causes the OOM GPU to hang on "Waiting on OOM-flag synch" instead.
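The barrier variant, for reference (same loop as above, only the synchronisation block changes):

# Variant I also tried: explicit barrier before the flag reduction
print(f"[rank {idr_torch.rank}] Waiting on OOM-flag synch in batch {batch_idx}...")
torch.distributed.barrier()
torch.distributed.all_reduce(oom_flag, op=torch.distributed.ReduceOp.MAX)
print(f"[rank {idr_torch.rank}] Synch complete at batch {batch_idx}.")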
Upvotes: 0
Views: 24