xosp7tom

Reputation: 2183

Argument forwarding for a template function

Numerical libraries usually provide type-specific functions with nearly identical names and arguments, such as cblas_[sdcz]gemm, where the prefix letter depends on the type of the inputs. To allow function overloading, so that every variant can be called as cblas_tgemm, I wrote wrappers around those functions, like:

inline
void cblas_tgemm(const  CBLAS_LAYOUT Layout, const  CBLAS_TRANSPOSE TransA,
                 const  CBLAS_TRANSPOSE TransB, const INT M, const INT N,
                 const INT K, const float alpha, const float *A,
                 const INT lda, const float *B, const INT ldb,
                 const float beta, float *C, const INT ldc)
{
    cblas_sgemm(Layout, TransA,
                 TransB, M, N,
                 K, alpha, A,
                 lda, B, ldb,
                 beta, C, ldc);
}
inline
void cblas_tgemm(const  CBLAS_LAYOUT Layout, const  CBLAS_TRANSPOSE TransA,
                 const  CBLAS_TRANSPOSE TransB, const INT M, const INT N,
                 const INT K, const double alpha, const double *A,
                 const INT lda, const double *B, const INT ldb,
                 const double beta, double *C, const INT ldc)
{
    cblas_dgemm(Layout, TransA,
                 TransB, M, N,
                 K, alpha, A,
                 lda, B, ldb,
                 beta, C, ldc);
}

Obviously, this is quite tedious, as I need to repeat every function argument in each wrapper. Is there a better way to forward all arguments? The following code, for example, unfortunately does not work; it fails with error: redefinition of ‘template<class ... Params> void {anonymous}::cblas_tgemm(Params&& ...)’

template <typename ...Params>
void cblas_tgemm(Params&&... params)
{
    cblas_sgemm(std::forward<Params>(params)...);
}
template <typename ...Params>
void cblas_tgemm(Params&&... params)
{
    cblas_dgemm(std::forward<Params>(params)...);
}

Upvotes: 2

Views: 62

Answers (1)

Yakk - Adam Nevraumont

Reputation: 275730

SFINAE might help:

#define RETURNS(...) \
  -> decltype(__VA_ARGS__ )\
  { return __VA_ARGS__; }

template <class ...Params>
auto cblas_tgemm(Params&&... params)
RETURNS(cblas_sgemm(std::forward<Params>(params)...))

Now this overload participates in overload resolution only if the expression cblas_sgemm(std::forward<Params>(params)...) is valid.

There can be problems with duplicate signatures, with multiple overloads being valid at the same time, and so on, but this is a starting point.
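As a rough sketch of how two such overloads can coexist, here is a self-contained version that does not need a BLAS installation. The functions s_axpy and d_axpy are hypothetical stand-ins for a pair of type-specific C routines (they are not real CBLAS functions), and t_axpy is the hypothetical type-generic wrapper. The trailing return type is written out by hand instead of using the RETURNS macro, but the idea is the same: the two templates differ in their decltype expression, so they are distinct declarations, and each one drops out when its call expression does not compile.

```cpp
#include <utility>

// Hypothetical stand-ins for type-specific C routines (not real CBLAS calls).
// Each computes y[i] += alpha * x[i] for i in [0, n).
inline void s_axpy(int n, float alpha, const float* x, float* y) {
    for (int i = 0; i < n; ++i) y[i] += alpha * x[i];
}
inline void d_axpy(int n, double alpha, const double* x, double* y) {
    for (int i = 0; i < n; ++i) y[i] += alpha * x[i];
}

// Viable only when s_axpy can be called with the forwarded arguments.
template <class... Params>
auto t_axpy(Params&&... params)
    -> decltype(s_axpy(std::forward<Params>(params)...))
{
    return s_axpy(std::forward<Params>(params)...);
}

// Viable only when d_axpy can be called with the forwarded arguments.
template <class... Params>
auto t_axpy(Params&&... params)
    -> decltype(d_axpy(std::forward<Params>(params)...))
{
    return d_axpy(std::forward<Params>(params)...);
}
```

With float arrays x and y, t_axpy(2, 2.0f, x, y) can only pick the s_axpy overload, because a float* does not convert to a const double*. The pointer arguments are what disambiguate here. A call such as t_axpy(0, 1.0, nullptr, nullptr) is valid for both overloads and is therefore ambiguous, which is the "multiple overloads being valid" problem mentioned above.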

Upvotes: 2
