MikeMx7f

Reputation: 937

Dynamic Dispatch to Template Function C++

I have a template function (in my case a CUDA kernel) with a small number of boolean template parameters that can be chosen at runtime. I am happy to instantiate all permutations at compile time and dispatch dynamically, like so (for booleans b0, b1, b2):

if (b0) {
    if (b1) {
        if (b2) {
            myFunc<true,true,true,otherArgs>(args);
        } else {
            myFunc<true,true,false,otherArgs>(args);
        }
    } else {
        if(b2) {
            myFunc<true,false,true,otherArgs>(args);
        } else {
            myFunc<true,false,false,otherArgs>(args);
        }
    }
} else {
    if(b1) {
        if(b2) {
            myFunc<false,true,true,otherArgs>(args);
        } else {
            myFunc<false,true,false,otherArgs>(args);
        }
    } else {
        if(b2) {
            myFunc<false,false,true,otherArgs>(args);
        } else {
            myFunc<false,false,false,otherArgs>(args);
        }
    }
}

This is annoying to write, and gets exponentially worse if I end up with a b3 and b4.

Is there a simple way to rewrite this in a more concise way in C++11/14 without bringing in large external libraries (like boost)? Something like:

const auto dispatcher = construct_dispatcher<bool, 3>(myFunc);

...

dispatcher(b0,b1,b2,otherArgs,args);

Upvotes: 7

Views: 1783

Answers (2)

Yakk - Adam Nevraumont

Reputation: 275500

No problem.

#include <array>
#include <cstddef>
#include <type_traits>

template<bool b>
using kbool = std::integral_constant<bool, b>;

template<std::size_t max>
struct dispatch_bools {
  template<std::size_t N, class F, class...Bools>
  void operator()( std::array<bool, N> const& input, F&& continuation, Bools... )
  {
    if (input[max-1])
      dispatch_bools<max-1>{}( input, continuation, kbool<true>{}, Bools{}... );
    else
      dispatch_bools<max-1>{}( input, continuation, kbool<false>{}, Bools{}... );
  }
};
template<>
struct dispatch_bools<0> {
  template<std::size_t N, class F, class...Bools>
  void operator()( std::array<bool, N> const& input, F&& continuation, Bools... )
  {
     continuation( Bools{}... );
  }
};


So kbool is a type alias that represents a compile-time constant boolean. dispatch_bools is a helper struct that has an operator().

This operator() takes an array of runtime bools and, starting at index max-1, nests max levels of if/else branches, each one recursing into a call to dispatch_bools<max-1> with one more compile-time bool computed.

This generates 2^max code paths: exactly the code you don't want to write by hand.

The continuation is passed all the way down to the bottom recursion (where max=0). At that point, all of the compile-time bools have been built up -- we call continuation::operator() passing in those compile-time bools as function parameters.

Hopefully continuation::operator() is a template function that can accept compile-time bools. If it is, it has 2^max instantiations, one for each possible true/false combination.
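To make the recursion order concrete, here is a self-contained C++11 sketch that restates dispatch_bools and pairs it with a hypothetical continuation, MaskRecorder, plus a hypothetical to_mask helper (neither is part of the answer above). The continuation packs the compile-time bools it receives into an integer, first bool in bit 0, so you can check which instantiation ran.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

template<bool b>
using kbool = std::integral_constant<bool, b>;

// Recursive case: branch on input[max-1] at runtime and prepend the
// matching compile-time constant to the bools collected so far.
template<std::size_t max>
struct dispatch_bools {
  template<std::size_t N, class F, class... Bools>
  void operator()(std::array<bool, N> const& input, F&& continuation, Bools...) {
    if (input[max-1])
      dispatch_bools<max-1>{}(input, continuation, kbool<true>{}, Bools{}...);
    else
      dispatch_bools<max-1>{}(input, continuation, kbool<false>{}, Bools{}...);
  }
};

// Base case: every runtime bool is now a type; hand them to the continuation.
template<>
struct dispatch_bools<0> {
  template<std::size_t N, class F, class... Bools>
  void operator()(std::array<bool, N> const&, F&& continuation, Bools...) {
    continuation(Bools{}...);
  }
};

// Hypothetical helper: fold a pack of kbool values into a bit mask,
// first argument in bit 0 (single-return constexpr, valid in C++11).
constexpr unsigned to_mask() { return 0u; }

template<class B, class... Rest>
constexpr unsigned to_mask(B, Rest... rest) {
  return (B::value ? 1u : 0u) | (to_mask(rest...) << 1);
}

// Hypothetical continuation: records which instantiation was reached.
struct MaskRecorder {
  unsigned* out;
  template<class... Bools>
  void operator()(Bools... bs) const { *out = to_mask(bs...); }
};
```

With {{true, false, true}} the continuation receives kbool<true>, kbool<false>, kbool<true> in that order and records 5, so the first array element ends up as the first compile-time bool.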


To use this to solve your problem in C++14 you just do:

std::array<bool, 3> bargs={{b0, b1, b2}};
dispatch_bools<3>{}(bargs, [&](auto...Bargs){
  myFunc<decltype(Bargs)::value...,otherArgs>(args);
});

This is easy because C++14 has auto lambdas, so a lambda can have a template operator(). Turning those compile-time bool arguments back into non-type template arguments is just a matter of expanding decltype(Bargs)::value... in the template argument list.
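A self-contained C++14 sketch of that usage, assuming a hypothetical myFunc<bool, bool, bool>(int) that returns an int (the real kernel in the question returns nothing and takes other arguments). Because this dispatch_bools returns void, the lambda stores the result in a captured local.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

template<bool b>
using kbool = std::integral_constant<bool, b>;

template<std::size_t max>
struct dispatch_bools {
  template<std::size_t N, class F, class... Bools>
  void operator()(std::array<bool, N> const& input, F&& continuation, Bools...) {
    if (input[max-1])
      dispatch_bools<max-1>{}(input, continuation, kbool<true>{}, Bools{}...);
    else
      dispatch_bools<max-1>{}(input, continuation, kbool<false>{}, Bools{}...);
  }
};

template<>
struct dispatch_bools<0> {
  template<std::size_t N, class F, class... Bools>
  void operator()(std::array<bool, N> const&, F&& continuation, Bools...) {
    continuation(Bools{}...);
  }
};

// Hypothetical stand-in for the kernel: encodes its template arguments
// in the decimal digits of the result.
template<bool A, bool B, bool C>
int myFunc(int x) {
  return x + (A ? 100 : 0) + (B ? 10 : 0) + (C ? 1 : 0);
}

// Runtime bools in, compile-time bools out: the generic lambda's
// operator() is instantiated once per true/false combination.
int run(std::array<bool, 3> const& flags, int x) {
  int result = 0;
  dispatch_bools<3>{}(flags, [&](auto... Bargs) {
    result = myFunc<decltype(Bargs)::value...>(x);
  });
  return result;
}
```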

Note that many nominally C++11 compilers support auto-lambdas, because they were easy to implement. However, if yours lacks them, you can still solve this in C++11 with a helper struct:

template<class OtherArgs, class Args>
struct callMyFunc {
  Args args;
  template<class...Bools>
  void operator()(Bools...){
    myFunc<Bools::value...,OtherArgs>(args);
  }
};

Now the usage is:

std::array<bool, 3> bargs={{b0, b1, b2}};
dispatch_bools<3>{}(bargs, callMyFunc<otherArgs, decltype(args)>{args});

This is basically manually writing what the lambda would do.
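The C++11 route can be sketched the same way. The struct below is a simplified, hypothetical callMyFunc that hard-codes an int argument and writes myFunc's result through a pointer; myFunc is again a hypothetical int-returning stand-in, not the question's kernel.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

template<bool b>
using kbool = std::integral_constant<bool, b>;

template<std::size_t max>
struct dispatch_bools {
  template<std::size_t N, class F, class... Bools>
  void operator()(std::array<bool, N> const& input, F&& continuation, Bools...) {
    if (input[max-1])
      dispatch_bools<max-1>{}(input, continuation, kbool<true>{}, Bools{}...);
    else
      dispatch_bools<max-1>{}(input, continuation, kbool<false>{}, Bools{}...);
  }
};

template<>
struct dispatch_bools<0> {
  template<std::size_t N, class F, class... Bools>
  void operator()(std::array<bool, N> const&, F&& continuation, Bools...) {
    continuation(Bools{}...);
  }
};

// Hypothetical stand-in for the kernel.
template<bool A, bool B, bool C>
int myFunc(int x) {
  return x + (A ? 100 : 0) + (B ? 10 : 0) + (C ? 1 : 0);
}

// Hand-written equivalent of the C++14 generic lambda.
struct callMyFunc {
  int x;        // runtime argument forwarded to myFunc
  int* result;  // where myFunc's return value is stored
  template<class... Bools>
  void operator()(Bools...) const {
    *result = myFunc<Bools::value...>(x);
  }
};

int run(std::array<bool, 3> const& flags, int x) {
  int result = 0;
  dispatch_bools<3>{}(flags, callMyFunc{x, &result});
  return result;
}
```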


In C++14 you can replace void with auto and return the result of the recursive calls instead of just making them, and the compiler will deduce a return type for you reasonably well.
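A sketch of that C++14 change, assuming every instantiation of the continuation returns the same type (here int, from a hypothetical myFunc). If two instantiations returned different types, the two return statements would deduce conflicting types and the code would not compile.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

template<bool b>
using kbool = std::integral_constant<bool, b>;

template<std::size_t max>
struct dispatch_bools {
  template<std::size_t N, class F, class... Bools>
  auto operator()(std::array<bool, N> const& input, F&& continuation, Bools...) {
    // Both returns must deduce the same type for auto to work.
    if (input[max-1])
      return dispatch_bools<max-1>{}(input, continuation, kbool<true>{}, Bools{}...);
    else
      return dispatch_bools<max-1>{}(input, continuation, kbool<false>{}, Bools{}...);
  }
};

template<>
struct dispatch_bools<0> {
  template<std::size_t N, class F, class... Bools>
  auto operator()(std::array<bool, N> const&, F&& continuation, Bools...) {
    return continuation(Bools{}...);
  }
};

// Hypothetical stand-in for the kernel.
template<bool A, bool B, bool C>
int myFunc(int x) {
  return x + (A ? 100 : 0) + (B ? 10 : 0) + (C ? 1 : 0);
}

// The result now flows straight back out; no captured local needed.
int run(std::array<bool, 3> const& flags, int x) {
  return dispatch_bools<3>{}(flags, [&](auto... Bargs) {
    return myFunc<decltype(Bargs)::value...>(x);
  });
}
```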

If you want that feature in C++11, you can either write a lot of decltype code or use this macro:

#define RETURNS(...) \
  noexcept(noexcept(__VA_ARGS__)) \
  -> decltype(__VA_ARGS__) \
  { return __VA_ARGS__; }

and write the body of dispatch_bools like:

template<class T, std::size_t N, class F, class...Bools>
auto operator()( std::array<T, N> const& input, F&& continuation, Bools... )
RETURNS(
 (input[max-1])?
    dispatch_bools<max-1>{}( input, continuation, kbool<true>{}, Bools{}... )
 :
    dispatch_bools<max-1>{}( input, continuation, kbool<false>{}, Bools{}... )
)

Do the same for the <0> specialization, and you get C++14-style return type deduction in C++11.

RETURNS makes deducing return types of one-liner functions trivial.
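Putting the RETURNS pieces together, here is a sketch of a complete C++11 dispatch_bools, including the <0> specialization. The myFunc and callMyFunc here are hypothetical stand-ins that return int; the conditional operator requires all 2^max instantiations of the continuation to yield a common type.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

#define RETURNS(...) \
  noexcept(noexcept(__VA_ARGS__)) \
  -> decltype(__VA_ARGS__) \
  { return __VA_ARGS__; }

template<bool b>
using kbool = std::integral_constant<bool, b>;

template<std::size_t max>
struct dispatch_bools {
  // The return type is whatever the innermost continuation call returns.
  template<class T, std::size_t N, class F, class... Bools>
  auto operator()(std::array<T, N> const& input, F&& continuation, Bools...)
  RETURNS(
    input[max-1] ?
      dispatch_bools<max-1>{}(input, continuation, kbool<true>{}, Bools{}...)
    :
      dispatch_bools<max-1>{}(input, continuation, kbool<false>{}, Bools{}...)
  )
};

// The <0> specialization, written with RETURNS as well.
template<>
struct dispatch_bools<0> {
  template<class T, std::size_t N, class F, class... Bools>
  auto operator()(std::array<T, N> const&, F&& continuation, Bools...)
  RETURNS( continuation(Bools{}...) )
};

// Hypothetical stand-in for the kernel.
template<bool A, bool B, bool C>
int myFunc(int x) {
  return x + (A ? 100 : 0) + (B ? 10 : 0) + (C ? 1 : 0);
}

// Hypothetical C++11 continuation; the explicit int return type keeps
// every instantiation's type identical.
struct callMyFunc {
  int x;
  template<class... Bools>
  int operator()(Bools...) const { return myFunc<Bools::value...>(x); }
};
```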

Upvotes: 5

Silvio Mayolo

Reputation: 70277

Is there a simple way? No. Can it be done using an unholy mess of garbled templates? Sure, why not.

Implementation

First, this is going to be a bit easier if we have a class rather than a function, because a class template can be passed as a template template parameter and a function template cannot. So I'm going to write a trivial wrapper around your myFunc.

template <bool... Acc>
struct MyFuncWrapper {
  template <typename T>
  auto operator()(T&& extra) const {
    return myFunc<Acc...>(std::forward<T>(extra));
  }
};

This is just a class for which MyFuncWrapper<...>()(extra) is equivalent to myFunc<...>(extra).

Now let's make our dispatcher.

template <template <bool...> class Func, typename Args, bool... Acc>
struct Dispatcher {

  auto dispatch(Args&& args) const {
    return Func<Acc...>()(std::forward<Args>(args));
  }

  template <typename... Bools>
  auto dispatch(Args&& args, bool head, Bools... tail) const {
    return head ?
      Dispatcher<Func, Args, Acc..., true >().dispatch(std::forward<Args>(args), tail...) :
      Dispatcher<Func, Args, Acc..., false>().dispatch(std::forward<Args>(args), tail...);
  }

};

Whew, there's quite a bit to explain there. The Dispatcher class has two template arguments and then a variadic list. The first two arguments are simple: the function we want to call (as a class) and the "extra" argument type. The variadic argument will start out empty, and we'll use it as an accumulator during the recursion (similar to an accumulator when you're doing tail call optimization) to accumulate the template Boolean list.

dispatch is just a recursive template function. The base case is when we don't have any arguments left, so we just call the function with the arguments we've accumulated so far. The recursive case involves a conditional, where we accumulate a true if the Boolean is true and a false if it's false.
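To watch the accumulator at work, here is a small C++14 sketch that restates the Dispatcher with a hypothetical Func, BoolsToString, which appends its accumulated pack to the extra argument as a string of 1s and 0s. It also shows a non-void return value passing back up through every level.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

template <template <bool...> class Func, typename Args, bool... Acc>
struct Dispatcher {
  // Base case: no runtime bools left, so call Func with the accumulated pack.
  auto dispatch(Args&& args) const {
    return Func<Acc...>()(std::forward<Args>(args));
  }

  // Recursive case: move the first runtime bool into the compile-time pack.
  template <typename... Bools>
  auto dispatch(Args&& args, bool head, Bools... tail) const {
    return head ?
      Dispatcher<Func, Args, Acc..., true >().dispatch(std::forward<Args>(args), tail...) :
      Dispatcher<Func, Args, Acc..., false>().dispatch(std::forward<Args>(args), tail...);
  }
};

// Hypothetical Func for this sketch: renders the accumulated pack as
// '1'/'0' characters after a prefix, so the accumulation order is visible.
template <bool... Acc>
struct BoolsToString {
  std::string operator()(std::string prefix) const {
    const bool bits[] = {Acc..., false};  // trailing sentinel keeps the array non-empty
    for (std::size_t i = 0; i < sizeof...(Acc); ++i)
      prefix += bits[i] ? '1' : '0';
    return prefix;
  }
};
```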

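We can call this with the following, where TypeOfExtraArgument must be a reference type (e.g. int&) if extraArgument is a named variable, because dispatch takes Args&&, which is an rvalue reference unless Args is itself an lvalue reference: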

Dispatcher<MyFuncWrapper, TypeOfExtraArgument>()
    .dispatch(extraArgument, true, true, false);

However, this is a bit verbose, so we can write a macro to make it a bit more approachable.[1]

// decltype((A)) yields T& for a named variable, so lvalues bind to Args&&
#define DISPATCH(F, A, ...) Dispatcher<F, decltype((A))>().dispatch(A, __VA_ARGS__)

Now our call is

DISPATCH(MyFuncWrapper, extraArgument, true, true, false);
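If the macro bothers you, a small variadic function template can play the same role. This is a swapped-in alternative rather than part of the answer: dispatch_to and CountTrue are hypothetical names, and because A is deduced from the argument, the sketch accepts named variables, temporaries, and an empty list of Booleans.

```cpp
#include <cassert>
#include <utility>

template <template <bool...> class Func, typename Args, bool... Acc>
struct Dispatcher {
  auto dispatch(Args&& args) const {
    return Func<Acc...>()(std::forward<Args>(args));
  }

  template <typename... Bools>
  auto dispatch(Args&& args, bool head, Bools... tail) const {
    return head ?
      Dispatcher<Func, Args, Acc..., true >().dispatch(std::forward<Args>(args), tail...) :
      Dispatcher<Func, Args, Acc..., false>().dispatch(std::forward<Args>(args), tail...);
  }
};

// Hypothetical Func for this sketch: adds the number of true flags to base.
template <bool... Acc>
struct CountTrue {
  int operator()(int base) const {
    const bool bits[] = {Acc..., false};  // sentinel: valid even for an empty pack
    for (bool b : bits) base += b ? 1 : 0;
    return base;
  }
};

// Hypothetical macro-free entry point. A deduces to T& for lvalues and
// T for rvalues, so either binds to Dispatcher's Args&& parameter.
template <template <bool...> class Func, typename A, typename... Bools>
auto dispatch_to(A&& a, Bools... bools) {
  return Dispatcher<Func, A>().dispatch(std::forward<A>(a), bools...);
}
```

Usage mirrors the macro: dispatch_to<CountTrue>(x, true, true, false).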

Complete Runnable Example

Includes a sample myFunc implementation.

#include <utility>
#include <iostream>

// decltype((A)) yields T& for a named variable, so lvalues bind to Args&&
#define DISPATCH(F, A, ...) Dispatcher<F, decltype((A))>().dispatch(A, __VA_ARGS__)

template <bool a, bool b, bool c, typename T>
void myFunc(T&& extra) {
  std::cout << a << " " << b << " " << c << " " << extra << std::endl;
}

template <bool... Acc>
struct MyFuncWrapper {
  template <typename T>
  auto operator()(T&& extra) const {
    return myFunc<Acc...>(std::forward<T>(extra));
  }
};

template <template <bool...> class Func, typename Args, bool... Acc>
struct Dispatcher {

  auto dispatch(Args&& args) const {
    return Func<Acc...>()(std::forward<Args>(args));
  }

  template <typename... Bools>
  auto dispatch(Args&& args, bool head, Bools... tail) const {
    return head ?
      Dispatcher<Func, Args, Acc..., true >().dispatch(std::forward<Args>(args), tail...) :
      Dispatcher<Func, Args, Acc..., false>().dispatch(std::forward<Args>(args), tail...);
  }

};

int main() {
  DISPATCH(MyFuncWrapper, 17, true, true, false);
  DISPATCH(MyFuncWrapper, 22, true, false, true);
  DISPATCH(MyFuncWrapper, -9, false, false, false);
}

Closing Notes

The implementation provided above will let myFunc return values as well, although your example only included a return type of void, so I'm not sure you'll need this. As written, the implementation requires C++14 for auto return types. If you want to do this under C++11, you can either change all the return types to void (can't return anything from myFunc anymore) or you can try to hack together the return types with decltype. If you want to do this in C++98, ... ... ... ... good luck


[1] This macro is susceptible to the trailing-comma problem: with zero Booleans, __VA_ARGS__ is empty and the expansion ends in dispatch(A, ), which won't compile. But if you're not going to pass any Booleans, you probably shouldn't be going through this process anyway.

Upvotes: 3
