erip

Reputation: 16995

How can I use templates to deduce the parameter types of a std::function?

I'm working on a problem to rotate an NxN matrix of type T by 90 degrees. In the spirit of DRY, I'd like the function signature of my rotate function to look like this:

template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, std::function<void(T&, T&, T&, T&)> swap_direction);

This will allow me to rotate clockwise or counterclockwise with the same function, simply by passing a different std::function<void(T&, T&, T&, T&)>.

I currently have the following code:

#include <iostream>
#include <array>
#include <functional>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

template <typename T>
void four_way_swap_clockwise(T& top_left, T& top_right, T& bottom_left, T& bottom_right) {
    T temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
}

template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, std::function<void(T&, T&, T&, T&)> swap_direction) {
    for(std::size_t i = 0; i < N/2; ++i) {
        for(std::size_t j = 0; j < (N+1)/2; ++j) {
            swap_direction(
                m[i][j],
                m[N-j-1][i],
                m[j][N-i-1],
                m[N-i-1][N-j-1]
            );
        }
    }
}

int main() {
    constexpr std::size_t N = 5;
    Matrix<int, N> m {{
        {{1,2,3,4,5}},
        {{6,7,8,9,10}},
        {{11,12,13,14,15}},
        {{16,17,18,19,20}},
        {{21,22,23,24,25}}
    }};

    std::function<void(int&, int&, int&, int&)> swap_clockwise(four_way_swap_clockwise);

    rotate_90(m, swap_clockwise);    
}

This currently doesn't compile, failing with the following error:

error: no matching function for call to 'std::function<void(int&, int&, int&, int&)>::function(<unresolved overloaded function type>)'
 std::function<void(int&, int&, int&, int&)> swap_clockwise(four_way_swap_clockwise);

However, even if it did compile, having to spell out the parameter types of the swap function (i.e., the int& in std::function<void(int&, int&, int&, int&)> swap_clockwise(four_way_swap_clockwise);) defeats the purpose of template programming.

How can I pass the std::function with the template type deduced?

Upvotes: 2

Views: 131

Answers (4)

lrm29

Reputation: 498

You could make your template function a functor:

struct four_way_swap_clockwise {
    template <typename T>
    void
    operator()(T& top_left, T& top_right, T& bottom_left, T& bottom_right) {
        T temp = top_left;
        top_left = top_right;
        top_right = bottom_right;
        bottom_right = bottom_left;
        bottom_left = temp;
    }
};

Then call:

four_way_swap_clockwise swap_clockwise;
rotate_90(m, swap_clockwise);
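With the question's rotate_90 left unchanged, a four_way_swap_clockwise object cannot be passed to it directly: T is deduced from both arguments, and it cannot be deduced from a type that is not a std::function. A minimal sketch of one way to keep that signature is to convert the functor to std::function<void(T&, T&, T&, T&)> inside a small helper whose T comes from the matrix. The helper name rotate_90_clockwise is hypothetical, not part of the answer.

```cpp
#include <array>
#include <cstddef>
#include <functional>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

// The functor from the answer: one object type whose call operator is a template.
struct four_way_swap_clockwise {
    template <typename T>
    void operator()(T& top_left, T& top_right, T& bottom_left, T& bottom_right) const {
        T temp = top_left;
        top_left = top_right;
        top_right = bottom_right;
        bottom_right = bottom_left;
        bottom_left = temp;
    }
};

// rotate_90 with the question's signature: T is deduced from BOTH arguments,
// so the second argument must already be a std::function<void(T&, T&, T&, T&)>.
template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, std::function<void(T&, T&, T&, T&)> swap_direction) {
    for (std::size_t i = 0; i < N / 2; ++i) {
        for (std::size_t j = 0; j < (N + 1) / 2; ++j) {
            swap_direction(m[i][j], m[N - j - 1][i], m[j][N - i - 1], m[N - i - 1][N - j - 1]);
        }
    }
}

// Hypothetical helper: builds the std::function from the functor using the
// matrix's own element type, so the caller never writes int (or any T).
template <typename T, std::size_t N>
void rotate_90_clockwise(Matrix<T, N>& m) {
    std::function<void(T&, T&, T&, T&)> swap_clockwise = four_way_swap_clockwise{};
    rotate_90(m, swap_clockwise);
}
```

Usage: rotate_90_clockwise(m); works for any element type the functor can swap.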

Upvotes: 2

Yakk - Adam Nevraumont

Reputation: 275966

template<class T> struct tag_t{using type=T;};
template<class T> using block_deduction=typename tag_t<T>::type;

This construct prevents C++ from deducing template arguments from the function argument whose type it wraps.

template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, block_deduction<std::function<void(T&, T&, T&, T&)>> swap_direction) {

Now T and N are deduced from the first argument alone, and the type of the 2nd argument simply follows from them!
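Put together, the block_deduction version of rotate_90 might look like the sketch below. It assumes the question's Matrix alias and four_way_swap_clockwise template; because the second parameter no longer takes part in deduction, the argument only has to be convertible to std::function<void(T&, T&, T&, T&)>.

```cpp
#include <array>
#include <cstddef>
#include <functional>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

// From the answer: the ::type member access puts T in a non-deduced context.
template<class T> struct tag_t { using type = T; };
template<class T> using block_deduction = typename tag_t<T>::type;

template <typename T>
void four_way_swap_clockwise(T& top_left, T& top_right, T& bottom_left, T& bottom_right) {
    T temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
}

// T and N come from m alone; swap_direction is then an ordinary
// std::function<void(T&, T&, T&, T&)> that accepts implicit conversions.
template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m,
               block_deduction<std::function<void(T&, T&, T&, T&)>> swap_direction) {
    for (std::size_t i = 0; i < N / 2; ++i) {
        for (std::size_t j = 0; j < (N + 1) / 2; ++j) {
            swap_direction(m[i][j], m[N - j - 1][i], m[j][N - i - 1], m[N - i - 1][N - j - 1]);
        }
    }
}
```

Usage, naming the specialization explicitly: rotate_90(m, four_way_swap_clockwise<int>);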

The next problem is that std::function doesn't disambiguate overloaded function names. An overloaded function name isn't a C++ value; it is a set of names from which (in the right context) a value is selected. Constructing a std::function is not one of those contexts.

We can extend std::function with an additional constructor like this:

template<class Sig, class F=std::function<Sig>>
struct my_func:F {
  using F::F;
  using F::operator=;
  my_func( Sig* ptr ):F(ptr) {}
  my_func& operator=( Sig* ptr ) {
    F::operator=(ptr);
    return *this;
  }
  my_func()=default;
  my_func(my_func&&)=default;
  my_func(my_func const&)=default;
  my_func& operator=(my_func&&)=default;
  my_func& operator=(my_func const&)=default;
}; 
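A sketch of how my_func might be wired into rotate_90, assuming the block_deduction helper above and the question's function template. The Sig* constructor gives the compiler a function-pointer target, so the bare name four_way_swap_clockwise can be passed without <int>.

```cpp
#include <array>
#include <cstddef>
#include <functional>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

template<class T> struct tag_t { using type = T; };
template<class T> using block_deduction = typename tag_t<T>::type;

// my_func as in the answer: std::function plus a constructor taking a plain
// function pointer, a context in which a template specialization can be picked.
template<class Sig, class F = std::function<Sig>>
struct my_func : F {
    using F::F;
    using F::operator=;
    my_func(Sig* ptr) : F(ptr) {}
    my_func& operator=(Sig* ptr) {
        F::operator=(ptr);
        return *this;
    }
    my_func() = default;
    my_func(my_func&&) = default;
    my_func(my_func const&) = default;
    my_func& operator=(my_func&&) = default;
    my_func& operator=(my_func const&) = default;
};

template <typename T>
void four_way_swap_clockwise(T& top_left, T& top_right, T& bottom_left, T& bottom_right) {
    T temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
}

// T and N are deduced from m; the template name is then converted through
// my_func(Sig*), which resolves it to four_way_swap_clockwise<T>.
template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m,
               block_deduction<my_func<void(T&, T&, T&, T&)>> swap_direction) {
    for (std::size_t i = 0; i < N / 2; ++i) {
        for (std::size_t j = 0; j < (N + 1) / 2; ++j) {
            swap_direction(m[i][j], m[N - j - 1][i], m[j][N - i - 1], m[N - i - 1][N - j - 1]);
        }
    }
}
```

Usage: rotate_90(m, four_way_swap_clockwise);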


An alternative approach is to wrap your overload set into a lambda:

auto overloads = [](auto&&...args){ return four_way_swap_clockwise(decltype(args)(args)...); };

then pass overloads to your function. This lambda represents all of the overloads of four_way_swap_clockwise at once.

We can also manually disambiguate by doing four_way_swap_clockwise<int>.

Both of these still require the block_deduction technique above.
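A minimal sketch of the lambda-wrapper route, assuming the block_deduction version of rotate_90 and the question's four_way_swap_clockwise template. The generic lambda is an ordinary object, so it converts to std::function<void(T&, T&, T&, T&)> once T is known.

```cpp
#include <array>
#include <cstddef>
#include <functional>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

template<class T> struct tag_t { using type = T; };
template<class T> using block_deduction = typename tag_t<T>::type;

template <typename T>
void four_way_swap_clockwise(T& top_left, T& top_right, T& bottom_left, T& bottom_right) {
    T temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
}

// The wrapper from the answer: forwards its arguments unchanged, so the
// matching specialization of four_way_swap_clockwise is chosen at the call.
const auto overloads = [](auto&&... args) {
    return four_way_swap_clockwise(decltype(args)(args)...);
};

template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m,
               block_deduction<std::function<void(T&, T&, T&, T&)>> swap_direction) {
    for (std::size_t i = 0; i < N / 2; ++i) {
        for (std::size_t j = 0; j < (N + 1) / 2; ++j) {
            swap_direction(m[i][j], m[N - j - 1][i], m[j][N - i - 1], m[N - i - 1][N - j - 1]);
        }
    }
}
```

Usage: rotate_90(m, overloads);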

An alternative to consider would be:

template <typename T, std::size_t N, class F>
void rotate_90(Matrix<T, N>& m, F&& swap_direction)

where we leave swap_direction completely free and let any failures occur within the algorithm. This also avoids the type erasure and indirect call of std::function, which can make it slightly faster. You still have to disambiguate four_way_swap_clockwise with either <int> or the lambda-wrapper technique.
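A sketch of the unconstrained version, assuming the question's Matrix alias and function template. F is deduced from whatever is passed, so no std::function is involved; the call is simply checked when the loop body is instantiated.

```cpp
#include <array>
#include <cstddef>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

template <typename T>
void four_way_swap_clockwise(T& top_left, T& top_right, T& bottom_left, T& bottom_right) {
    T temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
}

// F can be a function reference, a function pointer, a lambda or a functor.
// swap_direction is invoked many times, so it is used as an lvalue and is
// deliberately not std::forward-ed.
template <typename T, std::size_t N, class F>
void rotate_90(Matrix<T, N>& m, F&& swap_direction) {
    for (std::size_t i = 0; i < N / 2; ++i) {
        for (std::size_t j = 0; j < (N + 1) / 2; ++j) {
            swap_direction(m[i][j], m[N - j - 1][i], m[j][N - i - 1], m[N - i - 1][N - j - 1]);
        }
    }
}
```

Usage: rotate_90(m, four_way_swap_clockwise<int>);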

Another approach would be to make four_way_swap_clockwise a lambda itself:

auto four_way_swap_clockwise = [](auto& top_left, auto& top_right, auto& bottom_left, auto& bottom_right) {
  auto temp = top_left;
  top_left = top_right;
  top_right = bottom_right;
  bottom_right = bottom_left;
  bottom_left = temp;
};

and now it is an object with a templated operator(). Combined with block_deduction, this solves your problem.
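A sketch of that combination, assuming the block_deduction version of rotate_90. The same lambda object converts to std::function<void(T&, T&, T&, T&)> for whatever T the matrix holds, so nothing at the call site names a type.

```cpp
#include <array>
#include <cstddef>
#include <functional>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

template<class T> struct tag_t { using type = T; };
template<class T> using block_deduction = typename tag_t<T>::type;

// The generic lambda from the answer: a single object whose operator() is a template.
const auto four_way_swap_clockwise = [](auto& top_left, auto& top_right,
                                        auto& bottom_left, auto& bottom_right) {
    auto temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
};

template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m,
               block_deduction<std::function<void(T&, T&, T&, T&)>> swap_direction) {
    for (std::size_t i = 0; i < N / 2; ++i) {
        for (std::size_t j = 0; j < (N + 1) / 2; ++j) {
            swap_direction(m[i][j], m[N - j - 1][i], m[j][N - i - 1], m[N - i - 1][N - j - 1]);
        }
    }
}
```

Usage: rotate_90(m, four_way_swap_clockwise); works unchanged for a Matrix<int, N> or a Matrix<double, N>.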

In short, there are lots of ways around your problem.

Upvotes: 4

WhiZTiM

Reputation: 21576

To call the function,

template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, std::function<void(T&, T&, T&, T&)> swap_direction);

given:

template <typename T>
void four_way_swap_clockwise(T& top_left, T& top_right, T& bottom_left, T& bottom_right);

you can simply try this:

rotate_90<int>(m, four_way_swap_clockwise<int>);
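A self-contained sketch of this call against the question's unchanged rotate_90. Because T is given explicitly, the std::function parameter no longer takes part in deduction, so four_way_swap_clockwise<int> is converted to it implicitly; N is still deduced from m.

```cpp
#include <array>
#include <cstddef>
#include <functional>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

template <typename T>
void four_way_swap_clockwise(T& top_left, T& top_right, T& bottom_left, T& bottom_right) {
    T temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
}

// The question's signature, unchanged.
template <typename T, std::size_t N>
void rotate_90(Matrix<T, N>& m, std::function<void(T&, T&, T&, T&)> swap_direction) {
    for (std::size_t i = 0; i < N / 2; ++i) {
        for (std::size_t j = 0; j < (N + 1) / 2; ++j) {
            swap_direction(m[i][j], m[N - j - 1][i], m[j][N - i - 1], m[N - i - 1][N - j - 1]);
        }
    }
}
```

Usage: rotate_90<int>(m, four_way_swap_clockwise<int>);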

As to why you cannot call it like:

rotate_90(m, four_way_swap_clockwise);

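It is partly because the name four_way_swap_clockwise refers to a function template, not a function, and using such a name requires instantiating it, which I did with four_way_swap_clockwise<int>. The other part is that T cannot be deduced from the std::function parameter, which is why T is also given explicitly in rotate_90<int>.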

Better still, as per my first comment on your question, write rotate_90 like this:

template <typename T, std::size_t N, typename Func>
void rotate_90(Matrix<T, N>& m, Func swap_direction);

Upvotes: 2

rocambille

Reputation: 15996

You may prefer to make rotate_90 more generic like this:

template <typename T, std::size_t N, typename F>
void rotate_90(Matrix<T, N>& m, F swap_direction) {
    for(std::size_t i = 0; i < N/2; ++i) {
        for(std::size_t j = 0; j < (N+1)/2; ++j) {
            swap_direction(
                m[i][j],
                m[N-j-1][i],
                m[j][N-i-1],
                m[N-i-1][N-j-1]
            );
        }
    }
}
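With a generic F, the question's original goal of handling both directions with one rotate_90 works directly. A minimal sketch, assuming the swaps are written as generic lambdas (as in Yakk's last suggestion) and that four_way_swap_counterclockwise is a hypothetical counterpart that runs the same four-element cycle in the opposite order.

```cpp
#include <array>
#include <cstddef>

template <typename T, std::size_t N>
using Matrix = std::array<std::array<T, N>, N>;

// Clockwise cycle, as in the question.
const auto four_way_swap_clockwise = [](auto& top_left, auto& top_right,
                                        auto& bottom_left, auto& bottom_right) {
    auto temp = top_left;
    top_left = top_right;
    top_right = bottom_right;
    bottom_right = bottom_left;
    bottom_left = temp;
};

// Hypothetical counterclockwise cycle: the same four cells, rotated the other way.
const auto four_way_swap_counterclockwise = [](auto& top_left, auto& top_right,
                                               auto& bottom_left, auto& bottom_right) {
    auto temp = top_left;
    top_left = bottom_left;
    bottom_left = bottom_right;
    bottom_right = top_right;
    top_right = temp;
};

// rocambille's generic version: the direction is just another callable.
template <typename T, std::size_t N, typename F>
void rotate_90(Matrix<T, N>& m, F swap_direction) {
    for (std::size_t i = 0; i < N / 2; ++i) {
        for (std::size_t j = 0; j < (N + 1) / 2; ++j) {
            swap_direction(m[i][j], m[N - j - 1][i], m[j][N - i - 1], m[N - i - 1][N - j - 1]);
        }
    }
}
```

Rotating with four_way_swap_clockwise and then with four_way_swap_counterclockwise leaves the matrix as it started.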

Upvotes: 4
