Alexander Soare

Reputation: 3247

How to use with torch.cuda.device() conditionally

I have some code wrapped in a context manager:

with torch.cuda.device(self.device):
    # do a bunch of stuff

And in my __init__ I have:

self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

But I'm unsure how to handle the case where the device is the CPU, since torch.cuda.device is explicitly for CUDA. Should I just write a decorator for the function? That seems like overkill.

Upvotes: 3

Views: 3831

Answers (1)

jodag

Reputation: 22204

According to the documentation for torch.cuda.device:

device (torch.device or int) – device index to select. It’s a no-op if this argument is a negative integer or None.

Based on that, we could use something like:

with torch.cuda.device(self.device if self.device.type == 'cuda' else None):
    # do a bunch of stuff

which would simply be a no-op when self.device isn't a CUDA device.
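
If passing None feels too implicit, another option is to choose the context manager up front and fall back to contextlib.nullcontext() for non-CUDA devices. Below is a minimal sketch of that idea. Note that maybe_cuda_device and its cuda_ctx_factory parameter are hypothetical names made up for illustration and are not part of PyTorch; the factory is passed in so the helper itself does not need to import torch.

```python
import contextlib


def maybe_cuda_device(device, cuda_ctx_factory):
    """Return a context manager suited to `device`.

    `device` is assumed to expose a `.type` string, as torch.device does.
    `cuda_ctx_factory` is a hypothetical parameter: any callable that takes
    the device and returns a context manager (e.g. torch.cuda.device).
    """
    if getattr(device, "type", None) == "cuda":
        # CUDA device: delegate to the supplied CUDA context manager.
        return cuda_ctx_factory(device)
    # Anything else (e.g. CPU): a context manager that does nothing.
    return contextlib.nullcontext()
```

Inside the class this would be used as `with maybe_cuda_device(self.device, torch.cuda.device):`. On a CUDA device it should behave like the one-liner above, and on CPU it never calls torch.cuda.device at all.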

Upvotes: 4
