user541686
user541686

Reputation: 210525

How to determine if a function returns a reference in C++03?

In pre-C++11, how can I determine if a given function returns a reference or not, when called with particular arguments?

For example, if the code looks like this:

// Question sketch: the goal is to ask whether calling f with the argument 5
// yields a reference. is_reference is the (not yet written) facility the
// question is asking how to implement.
template<typename F>
bool returns_reference(F f)
{
    return is_reference(f(5));
}

then how should I implement is_reference?

Note that f may also be a functor, and its operator() may have multiple overloads -- I only care about the overload that actually gets called via my arguments here.

Upvotes: 7

Views: 143

Answers (2)

user541686
user541686

Reputation: 210525

I found the answer to my own question.

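`returns_reference` returns `static_bool<true>` (whose `sizeof` is greater than 1) if the selected function returns an lvalue reference for the given argument, and `static_bool<false>` (size 1) otherwise.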

It works in most scenarios, but the number of combinations of const and volatile increases exponentially with the number of parameters.

Whenever it doesn't work — either because there is more than one argument, or because overload resolution is unusual (for example, when the const version of operator() is viable but the non-const version isn't) — the user should apply reftype to all arguments before passing them to returns_reference.

I think the reftype version might still have a few edge cases in C++11 (regarding rvalue references that are in fact lvalues), but it's "good enough" for me for now. At that point, the question of what's actually a "reference" is itself ambiguous anyway. (Although, if we're using C++11, we can just use decltype and forget about all this. The only situation in which we would still use this under C++11 is if the compiler doesn't support decltype but does support rvalue references.)

// Compile-time boolean whose value is also encoded in its size: the member
// array holds 1 byte for false and 2 bytes for true, so the answer can be read
// back with sizeof() applied to an unevaluated call. It also converts to bool
// at run time.
template<bool B>
struct static_bool
{
    unsigned char _[B ? 2 : 1];

    // Accepts any argument, so bodies can simply `return NULL;` / `return false;`.
    static_bool(...) {}

    operator bool() const { return B; }
};

// Fallback overload: selected when no more specific returns_reference
// overload is viable (e.g. their return-type probes are ill-formed), and
// reports "not a reference".
static_bool<false> returns_reference(...)
{
    return false;
}

// C++03 core of the trait: one overload per cv-qualification of an lvalue
// argument, so T is deduced without cv-qualifiers and the probe below
// re-creates the argument with exactly the caller's cv-qualification.
//
// Probe expression (it only appears inside sizeof, so nothing is evaluated):
//   static_cast<F(*)()>(0) ()             -> a value of the callable type F
//   static_cast<T cv &(*)(void)>(0) ()    -> an lvalue of type `T cv`
//   (callable)(argument)                  -> the call whose result is inspected
//   &(call)                               -> well-formed for the built-in
//                                            address-of only when the call
//                                            yields an lvalue
// sizeof is never 0, so `!!sizeof(...)` is simply `true` whenever the probe is
// well-formed. If it is ill-formed, substitution fails (SFINAE), this overload
// disappears, and the `returns_reference(...)` fallback (static_bool<false>)
// is chosen instead.
// NOTE(review): if the call returns a class type with a member unary
// operator&, `&` resolves to that operator even for rvalues, so such a
// by-value result may be reported as a reference. Confirm this is acceptable.
// In the test code these functions are only used inside sizeof; if actually
// called, they return a static_bool built from NULL by its ellipsis constructor.
template<class T, class F> static_bool<!!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T &(*)(void)>(0) ()))> returns_reference(F, T &) { return NULL; }
template<class T, class F> static_bool<!!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T const &(*)(void)>(0) ()))> returns_reference(F, T const &) { return NULL; }
template<class T, class F> static_bool<!!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T volatile &(*)(void)>(0) ()))> returns_reference(F, T volatile &) { return NULL; }
template<class T, class F> static_bool<!!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T const volatile &(*)(void)>(0) ()))> returns_reference(F, T const volatile &) { return NULL; }

// Empty tag that carries a type (here: a cv-qualified reference type) as a
// value, so the exact type can be handed to returns_reference.
template<typename Wrapped> struct type_wrapper {};

// reftype(x) captures the declared type of x, including const/volatile, as
// type_wrapper<X cv &>. There is one overload per cv-qualification; an rvalue
// argument binds to the `const &` overload.
template<class T>
type_wrapper<T &> reftype(T &)
{
    typedef type_wrapper<T &> result_type;
    return result_type();
}

template<class T>
type_wrapper<T const &> reftype(T const &)
{
    typedef type_wrapper<T const &> result_type;
    return result_type();
}

template<class T>
type_wrapper<T volatile &> reftype(T volatile &)
{
    typedef type_wrapper<T volatile &> result_type;
    return result_type();
}

template<class T>
type_wrapper<T const volatile &> reftype(T const volatile &)
{
    typedef type_wrapper<T const volatile &> result_type;
    return result_type();
}

// Rvalue-reference counterparts of the C++03 returns_reference/reftype
// overloads. They are enabled under the same condition the test code uses for
// its rvalue-reference cases. MSVC reports __cplusplus as 199711L by default
// even though _MSC_VER >= 1700 supports rvalue references, so without the
// _MSC_VER clause the tests would exercise rvalue cases that these overloads
// never covered.
// Here `T &&` is a forwarding reference: for an lvalue argument T deduces as
// X &, and partial ordering prefers the `T &` overloads. Calls returning an
// rvalue reference yield an xvalue, which the built-in `&` rejects, so such
// results are reported as "not a reference".
#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1700)
template<class T, class F> static_bool<!!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T &&(*)(void)>(0) ()))> returns_reference(F, T &&) { return NULL; }
template<class T, class F> static_bool<!!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T const &&(*)(void)>(0) ()))> returns_reference(F, T const &&) { return NULL; }
template<class T, class F> static_bool<!!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T volatile &&(*)(void)>(0) ()))> returns_reference(F, T volatile &&) { return NULL; }
template<class T, class F> static_bool<!!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T const volatile &&(*)(void)>(0) ()))> returns_reference(F, T const volatile &&) { return NULL; }

template<class T> type_wrapper<T &&> reftype(T &&) { return type_wrapper<T &&>(); }
template<class T> type_wrapper<T const &&> reftype(T const &&) { return type_wrapper<T const &&>(); }
template<class T> type_wrapper<T volatile &&> reftype(T volatile &&) { return type_wrapper<T volatile &&>(); }
template<class T> type_wrapper<T const volatile &&> reftype(T const volatile &&) { return type_wrapper<T const volatile &&>(); }
#endif

// reftype()-based overload: both the callable and the argument arrive as
// type_wrapper<...>, so their exact types come from the template arguments
// rather than from deduction on the call arguments. Since reftype produces
// type_wrapper<X cv &> (or X cv && in C++11), F and T are normally reference
// types. static_cast<F(*)()>(0) () and static_cast<T (*)(void)>(0) () then
// reproduce the original lvalue/rvalue-ness and cv-qualification of the
// callable and of the argument.
// Yields static_bool<true> when the address of the call result can be taken.
// Otherwise SFINAE removes this overload and the `returns_reference(...)`
// fallback yields static_bool<false>.
// T is never deduced from a default argument, so the default for the second
// parameter only helps when T is given explicitly,
// e.g. returns_reference<Arg &>(reftype(f)).
template<class T, class F> static_bool<
    !!sizeof(&(static_cast<F(*)()>(0) ())(static_cast<T (*)(void)>(0) ()))
> returns_reference(type_wrapper<F>, type_wrapper<T> = type_wrapper<T>()) { return NULL; }

Test code:

// Test functor: every operator() overload hands back exactly the
// (cv-qualified) pointer or reference type it receives, so the expected
// "is it a reference?" answer follows from the argument kind.
// Only the rvalue-reference overload for plain Test has a body; the rest are
// declarations.
struct Test
{
    // Default construction plus one copy constructor per cv-qualification of
    // the source object.
    Test() {}
    Test(Test &) {}
    Test(const Test &) {}
    Test(volatile Test &) {}
    Test(const volatile Test &) {}

    // Pointer overloads: the call yields a pointer value, not a reference.
    Test *operator()(Test *ptr) const;
    const Test *operator()(const Test *ptr) const;
    volatile Test *operator()(volatile Test *ptr) const;
    const volatile Test *operator()(const volatile Test *ptr) const;

    // Lvalue-reference overloads: the call yields an lvalue reference.
    Test &operator()(Test &ref) const;
    const Test &operator()(const Test &ref) const;
    volatile Test &operator()(volatile Test &ref) const;
    const volatile Test &operator()(const volatile Test &ref) const;

#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1700)
    // Rvalue-reference overloads (C++11, or MSVC 2012+, whose default
    // __cplusplus value predates C++11).
    Test &&operator()(Test &&ref) const { return static_cast<Test &&>(ref); }
    const Test &&operator()(const Test &&ref) const;
    volatile Test &&operator()(volatile Test &&ref) const;
    const volatile Test &&operator()(const volatile Test &&ref) const;
#endif
};

// Free functions mirroring Test::operator(): each returns exactly the
// (cv-qualified) pointer or reference type it is given. They are declared
// only and are referenced solely inside sizeof by the harness.
Test *test1(Test *ptr);
const Test *test2(const Test *ptr);
volatile Test *test3(volatile Test *ptr);
const volatile Test *test4(const volatile Test *ptr);
Test &test5(Test &ref);
const Test &test6(const Test &ref);
volatile Test &test7(volatile Test &ref);
const volatile Test &test8(const volatile Test &ref);
#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1700)
Test &&test9(Test &&ref);
const Test &&test10(const Test &&ref);
volatile Test &&test11(volatile Test &&ref);
const volatile Test &&test12(const volatile Test &&ref);
#endif
int main()
{
    // One object of each cv-qualification. They serve both as the functor
    // (all of Test's operator() overloads are const) and as call arguments.
    Test t; (void)t;
    Test const tc; (void)tc;
    Test volatile tv; (void)tv;
    Test const volatile tcv; (void)tcv;

    // A macro (rather than a function template) passes each operand through
    // with its exact type and value category, which the overload set depends on.
    // Prints 1 when the selected overload yields static_bool<true> (size 2),
    // and 0 when it falls back to static_bool<false> (size 1).
#define REPORT_IS_REFERENCE(CALLEE, ARG) \
    std::cerr << (sizeof(returns_reference(CALLEE, ARG)) != sizeof(unsigned char)) << std::endl
    // Same check, with both operands first converted through reftype().
#define REPORT_VIA_REFTYPE(CALLEE, ARG) \
    REPORT_IS_REFERENCE(reftype(CALLEE), reftype(ARG))

    // Operands passed directly.
    REPORT_IS_REFERENCE(t, &t);
    REPORT_IS_REFERENCE(t, &tc);
    REPORT_IS_REFERENCE(t, &tv);
    REPORT_IS_REFERENCE(t, &tcv);
    REPORT_IS_REFERENCE(t, t);
    REPORT_IS_REFERENCE(t, tc);
    REPORT_IS_REFERENCE(t, tv);
    REPORT_IS_REFERENCE(t, tcv);
    REPORT_IS_REFERENCE(test1, &t);
    REPORT_IS_REFERENCE(test2, &tc);
    REPORT_IS_REFERENCE(test3, &tv);
    REPORT_IS_REFERENCE(test4, &tcv);
    REPORT_IS_REFERENCE(test5, t);
    REPORT_IS_REFERENCE(test6, tc);
    REPORT_IS_REFERENCE(test7, tv);
    REPORT_IS_REFERENCE(test8, tcv);
#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1700)
    REPORT_IS_REFERENCE(test9, t);
    REPORT_IS_REFERENCE(test10, tc);
    REPORT_IS_REFERENCE(test11, tv);
    REPORT_IS_REFERENCE(test12, tcv);
#endif

    // Operands wrapped with reftype().
    REPORT_VIA_REFTYPE(t, &t);
    REPORT_VIA_REFTYPE(t, &tc);
    REPORT_VIA_REFTYPE(t, &tv);
    REPORT_VIA_REFTYPE(t, &tcv);
    REPORT_VIA_REFTYPE(t, t);
    REPORT_VIA_REFTYPE(t, tc);
    REPORT_VIA_REFTYPE(t, tv);
    REPORT_VIA_REFTYPE(t, tcv);
    REPORT_VIA_REFTYPE(test1, &t);
    REPORT_VIA_REFTYPE(test2, &tc);
    REPORT_VIA_REFTYPE(test3, &tv);
    REPORT_VIA_REFTYPE(test4, &tcv);
    REPORT_VIA_REFTYPE(test5, t);
    REPORT_VIA_REFTYPE(test6, tc);
    REPORT_VIA_REFTYPE(test7, tv);
    REPORT_VIA_REFTYPE(test8, tcv);
#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1700)
    REPORT_VIA_REFTYPE(test9, t);
    REPORT_VIA_REFTYPE(test10, tc);
    REPORT_VIA_REFTYPE(test11, tv);
    REPORT_VIA_REFTYPE(test12, tcv);
#endif

#undef REPORT_VIA_REFTYPE
#undef REPORT_IS_REFERENCE
    return 0;
}

Upvotes: 1

dyp
dyp

Reputation: 39121

Here's a SFINAE-based solution that checks if a function-call expression yields an lvalue:

#include <boost/type_traits.hpp>
#include <boost/utility.hpp>
#include <cstddef>

// Compile-time check: does calling a Func with one argument of type Arg0
// produce an lvalue (i.e. does the selected overload return an lvalue
// reference)? `result` is true if so. It is false otherwise, including when
// the call is ill-formed or returns void.
// Both the callable and the argument are produced with boost::declval, i.e.
// as rvalues of their types, as in the question's f(5).
//
// Func: function (object/pointer/reference) type
// Arg0: type of the first argument to use (for overload resolution)
template<class Func, class Arg0>
struct yields_lvalue_1 // with one argument
{
    // Return types of different sizes, so sizeof reveals which `test`
    // overload was selected.
    typedef char yes[1];
    typedef char no[2];

    // decay possible function types
    // (function type or reference to function -> function pointer, so that
    // a value of F_decayed can be produced and called)
    typedef typename boost::decay<Func>::type F_decayed;

    // a type whose constructor can take any lvalue expression
    // NOTE(review): for a const-qualified class prvalue (a call returning
    // `const X` by value), T deduces as `const X`, and `const X &` binds to
    // rvalues. Such a call would therefore also be reported as an lvalue.
    // Confirm whether that case matters.
    struct Any
    {
        template<class T>
        Any(T&);
    };

    // SFINAE-test: if `Any(....)` is well-formed, this overload of `test` is
    // viable
    // (the expression depends on T, so an ill-formed call, a void result, or
    // a non-lvalue result removes this overload instead of causing an error)
    template<class T>
    static yes& test(boost::integral_constant<std::size_t,
                 sizeof(Any( boost::declval<T>()(boost::declval<Arg0>()) ))>*);
    // fall-back
    template<class T>
    static no&  test(...);

    // perform test
    // (the literal 0 converts to the pointer parameter when the SFINAE
    // overload is viable; otherwise only the ellipsis overload remains)
    static bool const result = sizeof(test<F_decayed>(0)) == sizeof(yes);
};

Some example function objects:

// Example functor with two call overloads whose results differ in value
// category, so the argument type decides the answer.
struct foo
{
    // Selected for double arguments: returns by value (not an lvalue).
    bool operator()(double);
    // Selected for int arguments: returns an lvalue reference.
    bool& operator()(int);
};

// Example functor with a templated call operator: accepts any argument by
// value and always returns a double by value (never an lvalue).
struct bar
{
    template<typename Arg>
    double operator()(Arg);
};

Usage example:

#include <iostream>
#include <iomanip>

// Writes one "expect: <e> -- result: <r>" line to std::cout. The expected
// value is right-aligned in a 5-character field so both "true" and "false"
// line up; the stream's own formatting (e.g. boolalpha) decides how bools look.
void print(bool expect, bool result)
{
    std::ostream& out = std::cout;
    out << "expect: ";
    out << std::setw(5) << expect;
    out << " -- result: " << result << '\n';
}

int main()
{
    std::cout << std::boolalpha;
    print(true , yields_lvalue_1<foo, int>   ::result);
    print(false, yields_lvalue_1<foo, double>::result);
    print(false, yields_lvalue_1<bar, int>   ::result);
    print(true , yields_lvalue_1<foo&(*)(long), int>::result);
    print(false, yields_lvalue_1<void(*)(int), short>::result);
    print(true , yields_lvalue_1<bool&(short), long>::result);
    print(false, yields_lvalue_1<void(float), int>::result);
    print(true , yields_lvalue_1<char&(&)(bool), long>::result);
    print(false, yields_lvalue_1<foo(&)(int), short>::result);
}

Upvotes: 2

Related Questions