Jimmy2027

Reputation: 333

How to write a context manager to throw and catch errors

I want to catch the runtime error "CUDA out of memory" in multiple places in my code, so that I can rerun the whole training workflow with a lower batch size. What is the best way to do that?

I am currently doing this:

try:
    result = model(input)
except RuntimeError as e:
    # if the GPU runs out of memory, start the experiment again with a smaller batch size
    if str(e).startswith('CUDA out of memory.') and batch_size > 10:
        raise CudaOutOfMemory(e)
    else:
        raise  # re-raise anything else unchanged

I then catch the error CudaOutOfMemory outside my main function.

However, this is a pretty long piece of code that I need to repeat many times. Is there any way to make a context manager for this,

so that instead I can run:

with catch_cuda_out_of_mem_error():
    result = model(input)

Edit: I want to create a context manager instead of a function because the functions I want to wrap the "try, except" around are not always the same. In my workflow, I have many functions that use a lot of GPU memory, and I would like to catch this error in any of them.

Upvotes: 0

Views: 1182

Answers (2)

Jimmy2027

Reputation: 333

Inspired by the post General decorator to wrap try except in python?, I found an answer to my problem:

import torch
from contextlib import contextmanager


class CudaOutOfMemory(Exception):
    pass


@contextmanager
def catching_cuda_out_of_memory():
    """
    Context manager that raises CudaOutOfMemory if the GPU runs out of memory.
    """
    try:
        yield
    except RuntimeError as e:
        if str(e).startswith('CUDA out of memory.'):
            raise CudaOutOfMemory(e)
        else:
            raise  # re-raise anything else unchanged


def oom():
    # allocate ever more GPU memory until the device runs out
    x = torch.randn(100, 10000, device=1)
    for _ in range(100):
        l = torch.nn.Linear(10000, 10000)
        l.to(1)
        x = l(x)


try:
    with catching_cuda_out_of_memory():
        oom()
except CudaOutOfMemory:
    print('GOTCHA!')
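
To actually restart the workflow with a smaller batch size, the outer loop might look like this. This is a minimal sketch; run_training and the halving policy are hypothetical stand-ins for whatever your main function does:

batch_size = 128
while batch_size > 10:
    try:
        # run_training is a hypothetical placeholder for the workflow;
        # it is expected to raise CudaOutOfMemory via the context manager
        run_training(batch_size)
        break  # training finished without running out of memory
    except CudaOutOfMemory:
        batch_size //= 2
        print(f'GPU out of memory, retrying with batch_size={batch_size}')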

Upvotes: 0

mCoding
mCoding

Reputation: 4839

Using a context manager is about properly acquiring and releasing a resource. Here you don't really have any resource that you are acquiring and releasing, so I don't think a context manager is appropriate. How about just using a function?

def try_compute_model(input):
    try:
        return model(input)
    except RuntimeError as e:
        # if the GPU runs out of memory, start the experiment again with a smaller batch size
        if str(e).startswith('CUDA out of memory.') and batch_size > 10:
            raise CudaOutOfMemory(e)
        else:
            raise  # re-raise anything else unchanged

Then use it like:

result = try_compute_model(input)
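
Since the functions you want to wrap are not always the same, the same idea generalizes to a wrapper that takes an arbitrary callable. A minimal sketch, reusing the CudaOutOfMemory class from the question (call_catching_oom is a hypothetical name):

def call_catching_oom(fn, *args, **kwargs):
    # call any function, translating a CUDA OOM RuntimeError
    # into the custom CudaOutOfMemory exception
    try:
        return fn(*args, **kwargs)
    except RuntimeError as e:
        if str(e).startswith('CUDA out of memory.'):
            raise CudaOutOfMemory(e)
        raise

result = call_catching_oom(model, input)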

Upvotes: 1
