Fluffy

Reputation: 922

Returning std::function held in a map of std::variant

I have a map of std::variant holding several std::function specializations, such as:

// note the different return types
using function_t = std::variant<std::function<int(void)>, std::function<void(int)>>;
std::map<int, function_t> callbacks;
callbacks[0] = [](){ return 9; };

How do I write a caller(...) helper function that returns a reference to the std::function held in the variant mapped at a given index, allowing a call similar to:

int value = caller(callbacks, 0)();

A simple visitor doesn't work because the std::function types held in function_t have different return types, so the visitor has no single return type, e.g.:

// cannot compile
auto caller(std::map<int, function_t> callbacks, int idx) {
    return std::visit([](const auto& arg) { return arg; }, callbacks[idx]);    
}

Upvotes: 1

Views: 570

Answers (1)

Yakk - Adam Nevraumont

Reputation: 275310

The first part is being able to call a function only if the arguments match:

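#include <functional>
#include <optional>
#include <type_traits>
#include <utility>

// stand-in result for functions returning void, since std::optional<void> is not allowed
struct void_t {};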

template<class R, class...Args, class...Ts,
  // in C++20 do requires
  std::enable_if_t<sizeof...(Args)==sizeof...(Ts), bool> = true,
  class R0=std::conditional_t< std::is_same_v<R,void>, void_t, R >
>
std::optional<R0> call_me_maybe( std::function<R(Args...)> const& f, Ts&&...ts ) {

  if constexpr ( (std::is_convertible_v<Ts&&, Args> && ... ))
  {
    if constexpr (std::is_same_v<R, void>) {
      f(std::forward<Ts>(ts)...);
      return void_t{};
    } else {
      return f(std::forward<Ts>(ts)...);
    }
  }
  else
  {
    return std::nullopt;
  }
}
template<class R, class...Args, class...Ts,
  // in C++20 do requires
  std::enable_if_t<sizeof...(Args)!=sizeof...(Ts), bool> = true,
  class R0=std::conditional_t< std::is_same_v<R,void>, void_t, R >
>
constexpr std::optional<R0> call_me_maybe( std::function<R(Args...)> const& f, Ts&&...ts ) {
  return std::nullopt;
}
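As a minimal usage sketch of the two overloads above (repeated here only so the block stands alone), the following shows the three outcomes: a matching call, a call with the wrong number of arguments, and a call with an argument of the wrong type. The `demo_*` helper names are hypothetical and not part of the original answer.

```cpp
#include <functional>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>

struct void_t {};

// Same arity: call f if every argument converts, otherwise return nullopt.
template<class R, class...Args, class...Ts,
  std::enable_if_t<sizeof...(Args)==sizeof...(Ts), bool> = true,
  class R0=std::conditional_t< std::is_same_v<R,void>, void_t, R >
>
std::optional<R0> call_me_maybe( std::function<R(Args...)> const& f, Ts&&...ts ) {
  if constexpr ( (std::is_convertible_v<Ts&&, Args> && ... )) {
    if constexpr (std::is_same_v<R, void>) {
      f(std::forward<Ts>(ts)...);
      return void_t{};
    } else {
      return f(std::forward<Ts>(ts)...);
    }
  } else {
    return std::nullopt;
  }
}

// Different arity: never callable, always nullopt.
template<class R, class...Args, class...Ts,
  std::enable_if_t<sizeof...(Args)!=sizeof...(Ts), bool> = true,
  class R0=std::conditional_t< std::is_same_v<R,void>, void_t, R >
>
constexpr std::optional<R0> call_me_maybe( std::function<R(Args...)> const&, Ts&&... ) {
  return std::nullopt;
}

// Hypothetical helpers, for illustration only.
inline std::optional<int> demo_match() {
  std::function<int(int)> add_one = [](int x) { return x + 1; };
  return call_me_maybe(add_one, 41);               // arity and type match
}
inline std::optional<int> demo_arity_mismatch() {
  std::function<int(int)> add_one = [](int x) { return x + 1; };
  return call_me_maybe(add_one);                   // zero arguments for int(int)
}
inline std::optional<int> demo_type_mismatch() {
  std::function<int(int)> add_one = [](int x) { return x + 1; };
  return call_me_maybe(add_one, std::string("x")); // std::string does not convert to int
}
inline bool demo_void_call() {
  bool called = false;
  std::function<void()> f = [&] { called = true; };
  std::optional<void_t> r = call_me_maybe(f);      // void result becomes void_t
  return called && r.has_value();
}
```

Note that the failure cases are decided at compile time; the optional only carries that decision to the caller in a uniform type.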

The second part involves some work with variants:

template<std::size_t I>
using index_t = std::integral_constant<std::size_t, I>;
template<std::size_t I>
constexpr index_t<I> index = {};

template<std::size_t...Is>
using variant_index_t = std::variant< index_t<Is>... >;
template<std::size_t...Is, class R=variant_index_t<Is...>>
constexpr R make_variant_index( std::size_t I, std::index_sequence<Is...> ) {
  constexpr R retvals[] = {
    R( index<Is> )...
  };
  return retvals[I];
}
template<std::size_t N>
constexpr auto make_variant_index( std::size_t I ) {
  return make_variant_index( I, std::make_index_sequence<N>{} );
}
template<class...Ts>
constexpr auto get_variant_index( std::variant<Ts...> const& v ) {
  return make_variant_index<sizeof...(Ts)>( v.index() );
}

That lets you work with variant indexes in a more compile-time friendly way.

template<class...Ts>
std::optional<std::variant<Ts...>> var_opt_flip( std::variant<std::optional<Ts>...> const& var ) {
  return std::visit( [&](auto I)->std::optional<std::variant<Ts...>> {
    if (std::get<I>(var))
      return std::variant<Ts...>(std::in_place_index_t<I>{}, *std::get<I>(var));
    else
      return std::nullopt;
  }, get_variant_index(var) );
}

This lets us take a variant<optional<Ts>...> and produce an optional<variant<Ts...>>, even if there are duplicate types.
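A small sketch of the flip in use, assuming the definitions above (condensed and repeated here so the block compiles on its own). It uses decltype(I)::value rather than I directly, for compilers that are picky about that. The aliases `input_t` and `flipped_t` and the helpers `flip_engaged` and `flip_empty` are hypothetical names introduced for illustration.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

template<std::size_t I>
using index_t = std::integral_constant<std::size_t, I>;

// Builds a variant<index_t<0>, ..., index_t<N-1>> whose active alternative is number i.
template<std::size_t...Is>
constexpr auto make_variant_index( std::size_t i, std::index_sequence<Is...> ) {
  using R = std::variant< index_t<Is>... >;
  constexpr R table[] = { R( index_t<Is>{} )... };
  return table[i];
}
template<class...Ts>
constexpr auto get_variant_index( std::variant<Ts...> const& v ) {
  return make_variant_index( v.index(), std::make_index_sequence<sizeof...(Ts)>{} );
}

template<class...Ts>
std::optional<std::variant<Ts...>> var_opt_flip( std::variant<std::optional<Ts>...> const& var ) {
  return std::visit( [&](auto I)->std::optional<std::variant<Ts...>> {
    constexpr std::size_t i = decltype(I)::value;  // compile-time index of the active alternative
    if (std::get<i>(var))
      return std::variant<Ts...>(std::in_place_index<i>, *std::get<i>(var));
    return std::nullopt;
  }, get_variant_index(var) );
}

// Hypothetical aliases and helpers, for illustration only.
using input_t   = std::variant<std::optional<int>, std::optional<std::string>>;
using flipped_t = std::optional<std::variant<int, std::string>>;

inline flipped_t flip_engaged() {
  input_t v(std::in_place_index<1>, std::string("hi"));  // alternative 1, engaged
  return var_opt_flip(v);                                // engaged, holds alternative 1
}
inline flipped_t flip_empty() {
  input_t v(std::in_place_index<0>, std::nullopt);       // alternative 0, disengaged
  return var_opt_flip(v);                                // nullopt
}
```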

We now need to be able to build the right return value.

With those pieces in place, we can write a function that takes a variant of functions plus some arguments, and maybe calls the active one:

template<class...Sigs, class...Ts>
auto call_maybe( std::variant<std::function<Sigs>...> const& vf, Ts&&...ts )
{
  using R0 = std::variant< decltype(call_me_maybe(std::function<Sigs>{}, std::forward<Ts>(ts)...))... >;
  R0 retval = std::visit(
    [&](auto I)->R0 {
      return R0( std::in_place_index_t<I>{}, call_me_maybe(std::get<I>(vf), std::forward<Ts>(ts)... ) );
    },
    get_variant_index(vf)
  );
  return var_opt_flip( std::move(retval) );
}

Then we rewrite caller to use it:

using function_t = std::variant< std::function< void() >, std::function< int(int) > >;

template<class...Ts>
auto caller(std::map<int, function_t> const& callbacks, int idx, Ts&&...ts) {
  auto it = callbacks.find(idx);
  using R = decltype(call_maybe( it->second, std::forward<Ts>(ts)... ));
  // wrong index:
  if (it == callbacks.end())
    return R(std::nullopt);
  // ok, give it a try:
  return call_maybe( it->second, std::forward<Ts>(ts)... );
}

Some compilers will reject what I did with auto I (using a function parameter as a template argument); on those, replacing I with decltype(I)::value might help (what can I say, not all compilers are C++ compliant).

The basic idea is that we create a variant, with matching indexes, of the possible return values of the functions. We then return an optional of that variant, to deal with the fact that failure is definitely a possibility (at runtime).

call_me_maybe is (beyond a song reference) a way to be able to pretend we can call anything. That is where void_t is useful: it stands in for the result when R is void.

The variant_index_t is a trick I use to deal with variants as generic sum types that may contain repeated types.

First we define a compile time integer called an index. It is based off of the existing std::integral_constant.

Then we make a variant of these, such that alternative 3 is the compile-time index 3.

We can then use std::visit( [&](auto I){/*...*/}, get_variant_index(var) ) to work with the index of a variant as a compile time constant.

If var has 4 alternatives and holds alternative 2, then get_variant_index returns a std::variant<index_t<0>, index_t<1>, index_t<2>, index_t<3>> that has index_t<2> populated in it.

(At runtime this is plausibly going to be represented by a 64 bit integer 2. I find this funny.)

When we std::visit this variant_index, the lambda we pass receives an index_t<I>, so it gets a compile-time constant as its argument. In a compiler that isn't dumb, you can extract the value in a constant expression through the implicit constexpr conversion to std::size_t that index_t<I> inherits from std::integral_constant. For dumb compilers, you have to write std::decay_t<decltype(I)>::value, which is the same compile-time integer.
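To make the compile-time nature of the index concrete, here is a small sketch that uses it as a template argument. It repeats the index helpers in condensed form so it stands alone; `active_size` is a hypothetical helper, not part of the original answer.

```cpp
#include <cstddef>
#include <type_traits>
#include <utility>
#include <variant>

template<std::size_t I>
using index_t = std::integral_constant<std::size_t, I>;

// Builds a variant<index_t<0>, ..., index_t<N-1>> whose active alternative is number i.
template<std::size_t...Is>
constexpr auto make_variant_index( std::size_t i, std::index_sequence<Is...> ) {
  using R = std::variant< index_t<Is>... >;
  constexpr R table[] = { R( index_t<Is>{} )... };
  return table[i];
}
template<class...Ts>
constexpr auto get_variant_index( std::variant<Ts...> const& v ) {
  return make_variant_index( v.index(), std::make_index_sequence<sizeof...(Ts)>{} );
}

// Hypothetical helper: returns sizeof of whichever alternative is active.
// The runtime index is turned into a compile-time constant by the visit,
// so it can be used as a template argument inside the lambda.
template<class...Ts>
std::size_t active_size( std::variant<Ts...> const& v ) {
  return std::visit( [](auto I)->std::size_t {
    constexpr std::size_t i = decltype(I)::value;
    return sizeof(std::variant_alternative_t<i, std::variant<Ts...>>);
  }, get_variant_index(v) );
}
```

Calling `active_size` on a `std::variant<char, double>` that holds a double yields `sizeof(double)`, selected by the runtime index but computed from a compile-time one.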

Using that compile-time integer we can std::get<I>(var) the value inside the lambda (and guarantee the one at the right spot) and we can use it to construct another variant at the same alternative, even if that other variant has ambiguous alternatives. In your case, you'd see that if you had

std::function<int(int)>
std::function<int(int,int)>

the "variant of results" looks like std::variant<int,int> -- which is different from a std::variant<int>.
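A short sketch of why the index matters when the result types repeat. The alias `results_t` and both helper functions are hypothetical names used only for illustration.

```cpp
#include <cstddef>
#include <type_traits>
#include <variant>

// Result variant for two callbacks that both return int.
using results_t = std::variant<int, int>;

// The duplicate is not collapsed: this is a distinct type from variant<int>.
static_assert(!std::is_same_v<results_t, std::variant<int>>);

// Hypothetical helper: store a result as "came from the second callback".
// Plain conversion from an int (results_t r = 5;) is not usable here, because
// the alternative would be ambiguous; the index has to be named explicitly.
inline results_t from_second_callback(int value) {
  return results_t(std::in_place_index<1>, value);
}

// Hypothetical helper: the index records which callback produced the value.
inline std::size_t which_callback(results_t const& r) {
  return r.index();
}
```

Reading the value back likewise has to go through `std::get<1>` (or a visit), since `std::get<int>` cannot tell the two alternatives apart.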

(As an additional step, you could remove duplicate types from this variant, but I'd advise doing that separately)

Each of the call_me_maybe calls returns an optional<R>. But a variant<optional<R>...> is dumb, so I flip it to an optional<variant<R...>>.

This means you can quickly check whether the function call worked, and if it did, you can see what value you got out of it.


Test code:

    std::map<int, function_t> callbacks = {
        { 0, []{ std::cout << 0 << "\n"; } },
        { 1, [](int x){ std::cout << "1:" << x << "\n"; return x+1; } },
    };
    std::optional<std::variant<void_t, int>> results[] = {
        caller(callbacks, 0),
        caller(callbacks, 0, 1),
        caller(callbacks, 1),
        caller(callbacks, 1, 1),
    };
    for (auto&& op:results) {
        std::cout << (bool)op;
    }
    std::cout << "\n";
    auto printer = [](auto val) {
        if constexpr (std::is_same_v<decltype(val), void_t>) {
            std::cout << "void_t";
        } else {
            std::cout << val;
        }
    };
    int count = 0;
    for (auto&& op:results) {
        
        std::cout << count << ":";
        if (!op) {
            std::cout << "nullopt\n";
        } else {
            std::visit( printer, *op );
            std::cout << "\n";
        }
        ++count;
    }

I get this output:

0
1:1
1001
0:void_t
1:nullopt
2:nullopt
3:2

The first two lines are the void() and int(int) std::functions logging their call.

The third line shows which calls succeeded -- the 0 argument call to void() and the 1 argument call to int(int).

The last 4 lines are the stored results. In the first, the optional<variant> is engaged and holds a void_t. The 2nd and 3rd calls failed, so they are nullopt, and the last one contains the result of passing 1 to the function that returns x+1, which is 2.


From the return value, you can see if the call worked (see if the outer optional is engaged), determine which callback was called if one was (the variant index), and get the value of the called variant (do a visit on it).
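Those three checks can be sketched as follows, assuming the result type from the test code above. The alias `result_t` and the function `summarize` are hypothetical names, and the result values are built by hand here rather than produced by caller.

```cpp
#include <optional>
#include <string>
#include <type_traits>
#include <variant>

struct void_t {};
using result_t = std::optional<std::variant<void_t, int>>;

// Hypothetical helper: describes a result in words.
inline std::string summarize(result_t const& r) {
  // 1. Did the call work? Check whether the outer optional is engaged.
  if (!r) return "call failed";
  // 2. Which callback type ran? That is the variant index.
  std::string which = "alternative " + std::to_string(r->index()) + " returned ";
  // 3. What did it return? Visit the variant.
  return std::visit([&](auto const& val) -> std::string {
    if constexpr (std::is_same_v<std::decay_t<decltype(val)>, void_t>)
      return which + "nothing";
    else
      return which + std::to_string(val);
  }, *r);
}
```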


If the number of function types is large, there is an optimization you should consider.

The above has two nested std::visits of a variant index, both guaranteed to return the same value. This means O(n^2) code is being generated where only O(n) is required, where n is the number of alternatives in function_t.

You can clean that up by passing the variant index "down" to call_maybe and var_opt_flip as an extra argument. In theory a compiler could work out that the other n^2-n generated code elements are unreachable, but that both requires a lot of work on the part of the compiler and would be fragile even if it worked.
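One possible shape for the flip step once the index is passed down, as a rough sketch only: `var_opt_flip_at` is a hypothetical name, and this is one reading of the suggestion rather than code from the original answer. Because the index arrives as a compile-time argument, no second std::visit is needed.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <type_traits>
#include <variant>

template<std::size_t I>
using index_t = std::integral_constant<std::size_t, I>;

// Hypothetical variant of var_opt_flip that is told which alternative is active.
// Assumption: the caller already knows I is the active index (for example because
// it is inside a visit over get_variant_index); std::get throws otherwise.
template<std::size_t I, class...Ts>
std::optional<std::variant<Ts...>>
var_opt_flip_at( index_t<I>, std::variant<std::optional<Ts>...> const& var ) {
  auto const& opt = std::get<I>(var);
  if (!opt) return std::nullopt;
  return std::variant<Ts...>(std::in_place_index<I>, *opt);
}

// Hypothetical helper, for illustration only.
inline std::optional<std::variant<int, std::string>> flip_known_index() {
  std::variant<std::optional<int>, std::optional<std::string>> v(std::in_place_index<0>, 42);
  return var_opt_flip_at(index_t<0>{}, v);
}
```

Inside call_maybe, the lambda that already receives I could call such a function directly on the value it just built, so each alternative is instantiated once.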

Doing so will reduce build times (and this kind of tomfoolery can cost build times; don't call this in a commonly included public header!), and could reduce runtime executable size.

Most programming languages and most uses of C++ don't permit O(n) code to generate more than O(n) binary; but templates, and std::variant in particular, are powerful enough to generate O(n^2) and even O(n^3) binary code output. So some care should be taken.

Upvotes: 4
