Question Context:
I am working on a PyTorch CUDA extension that processes complex-valued tensors. I use the following code snippet to launch the CUDA kernel:
AT_DISPATCH_COMPLEX_TYPES(x.scalar_type(), "my_kernel_function_cuda",
([&] {
my_kernel_function<scalar_t><<<gridDim, blockDim>>>(
x_.data_ptr<scalar_t>(), h_.data_ptr<scalar_t>(), o_.data_ptr<scalar_t>()
);
})
);
AT_DISPATCH_COMPLEX_TYPES essentially instantiates the lambda, and therefore the kernel, for c10::complex<float> and c10::complex<double>, and at runtime calls the instantiation that matches x.scalar_type().
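Roughly speaking, such a dispatch macro turns the runtime dtype into a compile-time type and compiles the lambda body once per supported type. The following is a simplified, hypothetical model of that idea in plain C++, not PyTorch's actual implementation: ScalarType, type_tag, dispatch_complex and ElementSize are invented names, and std::complex stands in for c10::complex so the snippet builds without PyTorch headers.

```cpp
#include <complex>
#include <cstddef>

// Hypothetical stand-in for the runtime dtype (not at::ScalarType).
enum class ScalarType { ComplexFloat, ComplexDouble };

// Carries a type as a value so a generic callable can recover it.
template <class T>
struct type_tag {
    using type = T;
};

// Hypothetical dispatcher: both branches are compiled, so f is instantiated
// for std::complex<float> AND std::complex<double>; the runtime value t only
// selects which instantiation runs.
template <class F>
constexpr auto dispatch_complex(ScalarType t, F f) {
    if (t == ScalarType::ComplexFloat) {
        return f(type_tag<std::complex<float>>{});
    }
    return f(type_tag<std::complex<double>>{});
}

// Plays the role of the lambda body: inside it, scalar_t is a concrete
// complex type whose value_type can be queried.
struct ElementSize {
    template <class Tag>
    constexpr std::size_t operator()(Tag) const {
        using scalar_t = typename Tag::type;
        return sizeof(typename scalar_t::value_type);
    }
};
```

In the extension, the lambda passed to the macro plays the role of ElementSize, which is why my_kernel_function<scalar_t> gets instantiated for both complex types.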
A templated kernel function usually looks like this:
template<typename scalar_t>
__global__ void my_kernel_function (
const scalar_t* __restrict__ x, const scalar_t* __restrict__ h,
scalar_t* __restrict__ o
) {
// Function Body
}
Question Problem:
I need to store these complex values in __shared__ memory. However, the compiler rejects a declaration like __shared__ scalar_t As[32] when scalar_t is not a basic type (int, float, double, etc.).
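As far as I know, nvcc rejects __shared__ declarations whose type needs dynamic initialization, i.e. whose default constructor does real work, and complex wrappers typically zero-initialize their members. This host-side check approximates that rule with std::is_trivially_default_constructible; shared_friendly is a hypothetical name, std::complex stands in for c10::complex (assumed here to zero-initialize as well), and CUDA's actual "empty constructor" rule is somewhat more permissive than this trait.

```cpp
#include <complex>
#include <type_traits>

// Approximation of the __shared__ restriction: a trivial default constructor
// means no initialization code has to run for the array, so the declaration
// is accepted. std::complex's constructor sets both parts to zero, so it fails.
template <class T>
constexpr bool shared_friendly = std::is_trivially_default_constructible<T>::value;

static_assert(shared_friendly<float> && shared_friendly<double>,
              "plain floating-point types are fine in __shared__ memory");
static_assert(!shared_friendly<std::complex<float>>,
              "a zero-initializing complex type is not");
```

This is why storing the real and imaginary parts in arrays of the underlying real type sidesteps the problem.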
Since the compiler ends up generating two functions, one with scalar_t = c10::complex<float> and one with scalar_t = c10::complex<double>, I am looking for a way to extract c10::complex<T>::value_type (let's call it ctype) so that my kernel code can contain something like:
...
__shared__ ctype As_real[32];
__shared__ ctype As_imag[32];
...
My failed attempt:
I tried several variations of my code based on the answers to "Difference of keywords 'typename' and 'class' in templates?" and "What are some uses of template template parameters?", but without any success, which shows that I do not fully understand how template template parameters work.
Here is one of my attempts at a single function template for my kernel:
template < template < typename > class ComplexContainer, typename ComplexType>
__global__ void my_kernel_function (
const ComplexContainer<ComplexType> * __restrict__ x, const ComplexContainer<ComplexType>* __restrict__ h,
ComplexContainer<ComplexType>* __restrict__ o
) {
printf("Template of template with complex class\n");
}
It compiles on its own, but it is clearly not the solution I need, since the compiler cannot find a matching instance at the launch site:
error: no instance of function template "<unnamed>::my_kernel_function" matches the argument list
argument types are: (c10::complex<float> *, c10::complex<float> *, c10::complex<float> *)
error: no instance of function template "<unnamed>::my_kernel_function" matches the argument list
argument types are: (c10::complex<double> *, c10::complex<double> *, c10::complex<double> *)
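A likely cause, assuming the launch still reads my_kernel_function<scalar_t><<<...>>>(...) as in the dispatch snippet: the explicit <scalar_t> is matched against the first template parameter, ComplexContainer, which expects a class template, while scalar_t is a type such as c10::complex<float>, so no instance matches. If both parameters are instead deduced from the pointer arguments, the template template signature does work. A host-side sketch with the hypothetical name element_size and std::complex standing in for c10::complex:

```cpp
#include <complex>
#include <cstddef>

// Same shape as the failed kernel signature, as a host function.
// From a std::complex<float>* argument, Container is deduced as
// std::complex and T as float.
template <template <typename> class Container, typename T>
constexpr std::size_t element_size(const Container<T>* /*x*/) {
    return sizeof(T);  // T is the underlying real type
}
```

Writing element_size<std::complex<float>>(p) instead, with an explicit template argument, would fail in the same way as the kernel launch.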
Current Alternative:
As a last resort, I could use explicit template specialization. However, that duplicates the whole kernel just to change a data type, as shown below:
template<typename scalar_t>
__global__ void my_kernel_function (
const scalar_t* __restrict__ x, const scalar_t* __restrict__ h,
scalar_t* __restrict__ o
) {
printf("General Case. It should not be invoked\n");
}
using cfloat = c10::complex<float>;
using cdouble = c10::complex<double>;
template<>
__global__ void my_kernel_function (
const cfloat* __restrict__ x, const cfloat* __restrict__ h,
cfloat* __restrict__ o
) {
__shared__ float As_real[32];
__shared__ float As_imag[32];
printf("Specialization 1\n");
}
template<>
__global__ void my_kernel_function (
const cdouble* __restrict__ x, const cdouble* __restrict__ h,
cdouble* __restrict__ o
) {
__shared__ double As_real[32];
__shared__ double As_imag[32];
printf("Specialization 2\n");
}
Answer:
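This example may be what you are looking for. Applied to your case, you do not need a template template parameter at all: scalar_t is already c10::complex<float> or c10::complex<double>, so you can use its c10::complex<T>::value_type directly. The typename keyword is required because value_type is a dependent name. The kernel would look like this, and the existing launch my_kernel_function<scalar_t><<<...>>>(...) works unchanged: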
template <class complex_t>
__global__ void my_kernel_function (
const complex_t * __restrict__ x, const complex_t* __restrict__ h,
complex_t* __restrict__ o
) {
// Underlying real type: float for c10::complex<float>, double for c10::complex<double>
using real_t = typename complex_t::value_type;
__shared__ real_t As_real[32];
__shared__ real_t As_imag[32];
// sizeof yields size_t, so cast it to int to match the %d specifier
printf("Template of template with complex class: %s --- sizeof: %d\n", __PRETTY_FUNCTION__, static_cast<int>(sizeof(real_t)));
}
When float is used, the output for each thread is:
Template of template with complex class: void <unnamed>::my_kernel(const scalar_t *) [with scalar_t = c10::complex<float>] --- sizeof: 4
When double is used, the output for each thread is:
...
Template of template with complex class: void <unnamed>::my_kernel(const scalar_t *) [with scalar_t = c10::complex<double>] --- sizeof: 8
...
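To show the value_type approach doing actual work, here is a host-side sketch that splits a tile of complex values into separate real and imaginary arrays, the As_real/As_imag layout from the question. SplitTile and load_tile are hypothetical names and std::complex stands in for c10::complex; in a kernel the two arrays would be __shared__, and typically each thread would copy one element before a __syncthreads().

```cpp
#include <complex>
#include <cstddef>
#include <type_traits>

// Structure-of-arrays tile: the real and imaginary parts live in two arrays
// of the underlying real type, mirroring As_real / As_imag.
template <class complex_t, std::size_t N>
struct SplitTile {
    // 'typename' is required because value_type depends on complex_t.
    using real_t = typename complex_t::value_type;
    static_assert(std::is_floating_point<real_t>::value,
                  "complex_t should wrap float or double");
    real_t re[N];
    real_t im[N];
};

// Copies N complex values into a SplitTile. In a kernel, thread i would
// execute the loop body for element i.
template <class complex_t, std::size_t N>
constexpr SplitTile<complex_t, N> load_tile(const complex_t (&src)[N]) {
    SplitTile<complex_t, N> tile{};
    for (std::size_t i = 0; i < N; ++i) {
        tile.re[i] = src[i].real();
        tile.im[i] = src[i].imag();
    }
    return tile;
}
```

Keeping real and imaginary parts in separate arrays also means each array holds only a trivially constructible type, which is what the __shared__ declaration needs.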