Reputation: 333
I want to catch the runtime error CUDA out of memory at several places in my code, so that I can rerun the whole training workflow with a smaller batch size. What is the best way to do that?
I am currently doing this:
try:
    result = model(input)
    # if the GPU runs out of memory, start the experiment again with a smaller batch size
except RuntimeError as e:
    if str(e).startswith('CUDA out of memory.') and batch_size > 10:
        raise CudaOutOfMemory(e)
    else:
        raise e
I then catch the CudaOutOfMemory error outside my main function.
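For context, the retry logic outside my main function looks roughly like this (run_training and the starting batch size are placeholders for my actual workflow):

batch_size = 64  # placeholder starting value
while True:
    try:
        run_training(batch_size)  # placeholder for the whole training workflow
        break  # training finished without running out of memory
    except CudaOutOfMemory:
        batch_size //= 2  # rerun with a smaller batch size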
However, the try/except block is a pretty long piece of code that I would need to repeat many times. Is there any way to turn it into a context manager, so that instead I can run:
with catch_cuda_out_of_mem_error:
    result = model(input)
Edit: I want to create a context manager rather than a function because the functions I need to wrap in the try/except are not always the same. In my workflow, I have many functions that use a lot of GPU memory, and I would like to catch this error in any of them.
Upvotes: 0
Views: 1182
Reputation: 333
Inspired by the post "General decorator to wrap try except in python?", I found an answer to my problem:
import torch
from contextlib import contextmanager


class CudaOutOfMemory(Exception):
    pass


@contextmanager
def catching_cuda_out_of_memory():
    """
    Context manager that raises a CudaOutOfMemory error
    if the GPU runs out of memory.
    """
    try:
        yield
    except RuntimeError as e:
        if str(e).startswith('CUDA out of memory.'):
            raise CudaOutOfMemory(e)
        else:
            raise e
def oom():
    # deliberately exhaust GPU memory by stacking large linear layers
    x = torch.randn(100, 10000, device=1)
    for _ in range(100):
        l = torch.nn.Linear(10000, 10000)
        l.to(1)
        x = l(x)
try:
    with catching_cuda_out_of_memory():
        oom()
except CudaOutOfMemory:
    print('GOTCHA!')
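As a side note, since contextmanager builds on contextlib.ContextDecorator, the same helper can also be applied as a decorator, which avoids even the with line:

@catching_cuda_out_of_memory()
def forward_pass():
    oom()

Calling forward_pass() then raises CudaOutOfMemory whenever the GPU runs out of memory.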
Upvotes: 0
Reputation: 4839
Using a context manager is about properly acquiring and releasing a resource. Here you don't really have any resource that you are acquiring and releasing, so I don't think a context manager is appropriate. How about just using a function?
def try_compute_model(input):
    try:
        # if the GPU runs out of memory, start the experiment again with a smaller batch size
        return model(input)
    except RuntimeError as e:
        if str(e).startswith('CUDA out of memory.') and batch_size > 10:
            raise CudaOutOfMemory(e)
        else:
            raise e
Then use it like this:
result = try_compute_model(input)
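If the functions you want to wrap vary, as your edit points out, the same pattern generalizes to a higher-order function; a minimal sketch (call_catching_oom is just an illustrative name):

def call_catching_oom(fn, *args, **kwargs):
    # run fn, translating CUDA OOM runtime errors into CudaOutOfMemory
    try:
        return fn(*args, **kwargs)
    except RuntimeError as e:
        if str(e).startswith('CUDA out of memory.'):
            raise CudaOutOfMemory(e)
        raise

result = call_catching_oom(model, input)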
Upvotes: 1