skm

Reputation: 5679

How to change the default device in cuPy?

I am using CuPy in a Python program to perform computations on the GPU. My program consists of several functions and classes spread across multiple files.

I am using a GPU cluster (NVIDIA V100), consisting of four GPUs.

How can I select a particular GPU as the default GPU for the whole program? I found some information in the CuPy documentation about a use() method, but since there is no example, I am not sure how to apply it.

MWE:

maths_ops.py:

import cupy as cp

class MathsOps:
    def vector_addition(self, vector1, vector2):
        return cp.add(vector1, vector2)

dummy_class.py:

from maths_ops import MathsOps

class DummyClass:
    def __init__(self, maths_ops_instance):
        self.maths_ops_instance = maths_ops_instance

    def perform_operation(self, vector1, vector2):
        result = self.maths_ops_instance.vector_addition(vector1, vector2)
        return result

main.py:

import cupy as cp

from maths_ops import MathsOps
from dummy_class import DummyClass

# Create an instance of MathsOps
maths_ops_instance = MathsOps()

# Create an instance of DummyClass with the MathsOps instance
dummy_instance = DummyClass(maths_ops_instance)

# Example vectors
vector1 = cp.array([1, 2, 3])
vector2 = cp.array([4, 5, 6])

# Perform vector addition using DummyClass
result = dummy_instance.perform_operation(vector1, vector2)

# Print the result
print("Vector Addition Result:", result)

Question: How do I specify a GPU device (for example, device 1) in the main function as the default GPU device to be used by the whole program?

Upvotes: 2

Views: 397

Answers (1)

mdd

Reputation: 763

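You can call cupy.cuda.runtime.setDevice (https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.runtime.setDevice.html) to set the current device before any arrays are created. CuPy allocations and kernel launches made afterwards from that thread then run on the selected GPU.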

import cupy as cp
import numpy as np

squared_diff = cp.ElementwiseKernel(
   'float32 x, float32 y',
   'float32 z',
   'z = (x - y) * (x - y)',
   'squared_diff')


# Make GPU 1 the current device before allocating any arrays
cp.cuda.runtime.setDevice(1)

x = cp.arange(1000000000, dtype=np.float32)
y = cp.arange(1000000000, dtype=np.float32)

print(squared_diff(x, y))
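
Since the question asks about use specifically, here is a rough sketch of two other ways to pick the GPU. It is not part of the answer above. The helper names expose_only_gpu and use_gpu are made up for illustration. cupy.cuda.Device(...).use() and cupy.cuda.runtime.getDevice() are CuPy calls, and CUDA_VISIBLE_DEVICES is the standard CUDA environment variable. The sketch assumes the machine actually has a GPU with the requested index.

```python
import os


def expose_only_gpu(physical_index):
    """Hypothetical helper: hide every GPU except one from this process.

    This has to run before CUDA is initialised. The chosen GPU is then
    seen inside the process as device 0.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(physical_index)
    return os.environ["CUDA_VISIBLE_DEVICES"]


def use_gpu(device_id):
    """Hypothetical helper: make `device_id` the current CuPy device."""
    import cupy as cp  # imported lazily so expose_only_gpu() can run first

    # Device.use() makes this device current for the calling thread;
    # arrays created afterwards are allocated on it.
    cp.cuda.Device(device_id).use()

    # Report which device the CUDA runtime now considers current.
    return cp.cuda.runtime.getDevice()
```

In main.py, calling use_gpu(1) once before the cp.array lines should be enough. The classes in maths_ops.py and dummy_class.py need no changes, because they work on whatever device is current. If expose_only_gpu(1) is used instead, the safest place for it is before the first import of cupy anywhere in the program, and the GPU is then addressed as device 0. For a temporary switch, CuPy also accepts the device as a context manager: with cp.cuda.Device(1): ...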

Upvotes: 0
