Thomas

Reputation: 1123

How to overload a built-in CUDA function?

CUDA has some built-in math functions, such as norm(). I want to create my own version of the norm() function and use it throughout my code. However, when I define my own norm() function like so:

__device__ float norm(float a, float b) {
    return sqrt(a*a+b*b);
}

I get the following compilation error:

kernel.cu(9): error: more than one instance of overloaded function "norm" has "C" linkage

Is there a way I can overload the norm() function, or do I have to just give my own function a unique name?

I'm using PyCUDA to compile my CUDA code.

Upvotes: 1

Views: 706

Answers (2)

talonmies

Reputation: 72342

The problem here is the use of C linkage in your code.

You may or may not be explicitly specifying extern "C" anywhere. Either way, if you are using the PyCUDA SourceModule facility to compile your code, it (un)helpfully and automagically wraps the code you submit in extern "C".

If you look at the documentation for SourceModule, you will see the option no_extern_c. Set that to True and this problem will go away. But note that everything you compile will now be compiled with C++ linkage and symbol mangling. You will have to adapt your Python code accordingly (see here for some of the gory details).
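As a rough sketch of what the submitted source might look like once no_extern_c=True is passed to SourceModule: the overload is compiled with C++ linkage, while the kernel is wrapped in its own extern "C" block so that get_function("norm_kernel") can still find it by its unmangled name. The kernel name norm_kernel and its parameters are hypothetical, and the #ifndef __CUDACC__ fallback is only an assumption added so the device function can also be checked with a plain C++ compiler.

```cpp
#include <math.h>

// Host-only fallback (an assumption for illustration, not needed under nvcc):
// lets the device function below compile with an ordinary C++ compiler.
#ifndef __CUDACC__
#define __device__
#endif

// Compiled with C++ linkage, so this overload no longer collides with
// the C-linkage built-in norm().
__device__ float norm(float a, float b) {
    return sqrtf(a * a + b * b);
}

#ifdef __CUDACC__
// The kernel keeps C linkage so its symbol name is not mangled.
// norm_kernel is a hypothetical name used only for this sketch.
extern "C" {
__global__ void norm_kernel(const float *a, const float *b, float *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = norm(a[i], b[i]);  // resolves to the (float, float) overload
    }
}
}
#endif
```

On the Python side this would be compiled with something like SourceModule(source, no_extern_c=True), after which the kernel is looked up by name as usual.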

And after that, read the other answer, which contains some very sage advice about the perils of overloading standard libraries and a best practice alternative.

Upvotes: 3

einpoklum

Reputation: 131986

I'll make two suggestions in addition to @talonmies' answer - in case you do manage to get overloading working:

  1. General non-CUDA-specific advice: Avoid overloading the builtins / API functions of a library, unless that is absolutely necessary (which it isn't in your case).
    Reasons for this:

    • Likely to confuse other readers of your code
    • Mixing up "wrapper" code with builtins - it's not a "clean" way to code.
    • If the builtins change, your code using the builtins+overloads is likely to also have to change, sometimes in ways you didn't expect.
  2. In your case, I would seriously consider having some namespace with your utility functions, e.g.

    namespace math {
        template <typename T>  
        __device__ T norm(T a, T b) { return math::sqrt<T>(a*a+b*b); }
    }
    

    (of course you would need a math::sqrt template, which would abstract from the single-precision sqrtf(), double-precision sqrt() etc.)
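One way the math::sqrt template mentioned above could be filled in is with explicit specializations that forward to sqrtf() and sqrt(). This is only a minimal sketch under that assumption, not necessarily what the answer's author had in mind; the #ifndef __CUDACC__ fallback is again an assumption added so the code can be checked with a plain C++ compiler. Because the namespace uses C++ features, under PyCUDA it would need the no_extern_c=True option described in the other answer.

```cpp
#include <math.h>

// Host-only fallback (an assumption for illustration, not needed under nvcc).
#ifndef __CUDACC__
#define __device__
#endif

namespace math {

// Primary template is declared but not defined: unsupported types fail
// at link time instead of silently picking the wrong precision.
template <typename T>
__device__ T sqrt(T x);

// Single precision forwards to sqrtf().
template <>
__device__ inline float sqrt<float>(float x) { return ::sqrtf(x); }

// Double precision forwards to the global sqrt().
template <>
__device__ inline double sqrt<double>(double x) { return ::sqrt(x); }

// Same norm as in the answer, now usable for both float and double.
template <typename T>
__device__ T norm(T a, T b) { return math::sqrt<T>(a * a + b * b); }

}  // namespace math
```

Calls then read math::norm(x, y), which keeps the utility clearly separate from the built-in norm() and avoids the linkage clash entirely.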

Upvotes: 1
