Carter Li

Reputation: 159

std::visit and std::variant usage

#include <variant>
#include <exception>
#include <type_traits>
#include <cassert>

template <typename T>
struct Promise {
    std::variant<
        std::monostate,
        std::conditional_t<std::is_void_v<T>, std::monostate, T>,
        std::exception_ptr
    > result_;

    T await_resume() const {
        assert(result_.index() > 0);
#if 1
        // old code that I want to optimise
        if (result_.index() == 2) {
            std::rethrow_exception(std::get<2>(result_));
        }
        if constexpr (!std::is_void_v<T>) {
            return std::get<1>(result_);
        }
#else
        // new code, won't compile
        return std::visit([](auto&& arg) {
            using TT = std::decay_t<decltype(arg)>;
            if constexpr (!std::is_same_v<TT, std::exception_ptr>) {
                std::rethrow_exception(arg);
            } else if constexpr (!std::is_void_v<T>) {
                return arg;
            }
        });
#endif
    }
};

template int Promise<int>::await_resume() const;
template std::exception_ptr Promise<std::exception_ptr>::await_resume() const;
template void Promise<void>::await_resume() const;

Promise::await_resume is a simple function that does the following:

  1. If the variant holds a std::exception_ptr in the error slot (index 2), rethrow the exception.
  2. If the variant holds a value of T in the value slot (index 1), return it. T is chosen by the user and may itself be std::exception_ptr. If T is void, do nothing.

Originally I implemented it using an .index() check and std::get. It works, but std::get performs extra checks internally and emits calls to std::__1::__throw_bad_variant_access(), which should never happen here: https://godbolt.org/z/YnjxDy
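For reference, the extra checks come from the contract of std::get itself: std::get<I> must verify the active index and throw std::bad_variant_access on a mismatch, whereas std::get_if<I> takes a pointer and returns nullptr instead of throwing. A minimal standalone sketch (the helper names are hypothetical, for illustration only):

```cpp
#include <cassert>
#include <exception>
#include <variant>

using Result = std::variant<std::monostate, int, std::exception_ptr>;

// Hypothetical helper: returns true when std::get<1> throws because
// the variant currently holds a different alternative.
bool get_throws_on_wrong_index(const Result& r) {
    try {
        (void)std::get<1>(r);   // checks r.index() == 1, throws otherwise
        return false;
    } catch (const std::bad_variant_access&) {
        return true;
    }
}

// Hypothetical helper: std::get_if<1> does the same index test
// but reports a mismatch by returning nullptr.
const int* peek_value(const Result& r) {
    return std::get_if<1>(&r);
}
```

Because get_if never throws, code built on it has no bad_variant_access path for the compiler to emit.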

I want to optimise the code using std::visit, as described on cppreference, but I can't get it to compile.

Another problem: when T is std::exception_ptr, how can I know whether I should throw it?

Upvotes: 1

Views: 1909

Answers (2)

Carter Li

Reputation: 159

Following @Barry's answer, this is my final version:

T await_resume() const {
    if (auto* pep = std::get_if<2>(&result_)) {
        std::rethrow_exception(*pep);
    } else {
        if constexpr (!std::is_void_v<T>) {
            auto* pv = std::get_if<1>(&result_);
            assert(pv);
            return *pv;
        }
    }
}

It generates perfect asm: no extra checks and no bad_variant_access path: https://godbolt.org/z/96gF_J
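For anyone who wants to try the final version locally, here is a self-contained sketch that combines the question's Promise layout with the get_if-based await_resume above. Only the member function comes from the answer; the surrounding struct is copied from the question, and the tests are illustrative:

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>  // std::runtime_error, used when exercising the error path
#include <type_traits>
#include <variant>

template <typename T>
struct Promise {
    std::variant<
        std::monostate,
        std::conditional_t<std::is_void_v<T>, std::monostate, T>,
        std::exception_ptr
    > result_;

    T await_resume() const {
        // Index 2 is the error slot: propagate the stored exception.
        if (auto* pep = std::get_if<2>(&result_)) {
            std::rethrow_exception(*pep);
        } else {
            if constexpr (!std::is_void_v<T>) {
                // Index 1 is the value slot; null only if nothing was stored yet.
                auto* pv = std::get_if<1>(&result_);
                assert(pv);
                return *pv;
            }
        }
    }
};
```

With T = void the value alternative is a second std::monostate, so it has to be set by index, for example result_.emplace<1>().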

Upvotes: 0

Barry

Reputation: 302942

visit doesn't "optimise" code - it's just a good pattern for matching on the variant, and it's especially useful to ensure that you didn't forget any types.

But one of the requirements of visit is that the visitor returns the same type for every alternative. That is especially problematic in your use case, since only one of your alternatives should produce a return value, so it's just not a good fit. You also need to handle the monostate case in the visit, and you really have no way to do that (besides... throwing?), so you're just out of luck.
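To make those constraints concrete, here is a sketch of a visit-based version that does compile, under the assumption that T is non-void and is neither std::monostate nor std::exception_ptr (so every alternative type is distinct). Every branch is forced to return T via an explicit return type, and the monostate branch has to throw; the std::logic_error there is a hypothetical choice, not something the original code does. Note also that std::visit must be given the variant itself (std::visit(visitor, result_)), which the snippet in the question omits, and that its is_same_v test is inverted.

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>
#include <type_traits>
#include <variant>

// Sketch only: assumes T is non-void and distinct from std::monostate and
// std::exception_ptr; with duplicate types the visitor cannot tell slots apart.
template <typename T>
T resume_with_visit(const std::variant<std::monostate, T, std::exception_ptr>& result) {
    return std::visit([](const auto& arg) -> T {  // every branch must yield T
        using Alt = std::decay_t<decltype(arg)>;
        if constexpr (std::is_same_v<Alt, std::exception_ptr>) {
            std::rethrow_exception(arg);  // [[noreturn]], so no value is needed
        } else if constexpr (std::is_same_v<Alt, std::monostate>) {
            // Hypothetical choice: no result stored yet, so report a logic error.
            throw std::logic_error("promise has no result");
        } else {
            return arg;  // the stored T
        }
    }, result);
}
```

Compared with the get_if version, this adds a throwing monostate branch and still cannot support T = void or T = std::exception_ptr, which is why it is not a good fit here.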

The version you had before is perfectly fine, I would just annotate it with types a bit to be more expressive:

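#include <cassert>
#include <exception>
#include <type_traits>
#include <variant>

struct Void { };  // stand-in value type used when T is void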

template <typename T>
struct Promise {
    using Value = std::conditional_t<std::is_void_v<T>, Void, T>;

    std::variant<
        std::monostate,
        Value,
        std::exception_ptr
    > result_;

    T await_resume() const {
        assert(not result_.valueless_by_exception());
        assert(not std::holds_alternative<std::monostate>(result_));

        if (auto* exc = std::get_if<std::exception_ptr>(&result_)) {
            std::rethrow_exception(*exc);
        } else {
            if constexpr (not std::is_void_v<T>) {
                return std::get<T>(result_);
            }
        }
    }
};

I think this is just a bit nicer than using 0, 1, and 2 explicitly.


Another problem is that when the type of T is std::exception_ptr, how can I know whether I should throw it?

Simple: You don't throw it. Don't have wildly different semantics in generic code based on your type. Promise<T>::await_resume() returns a T if it holds a T. Promise<std::exception_ptr>::await_resume() returns an exception_ptr. That's fine.

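Admittedly, with my implementation above, when T is std::exception_ptr the variant contains std::exception_ptr twice, so get_if<std::exception_ptr> (and get<T>) would no longer compile: type-based access requires the type to appear exactly once in the variant. That's unfortunate... so maybe the 0/1/2 indices are the easy way to go after all.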
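To illustrate the T = std::exception_ptr case: the variant then has two std::exception_ptr alternatives, and only their indices tell the value slot from the error slot. A small sketch following the semantics described above (index 1 is returned as a value, index 2 is rethrown); the alias EpResult and the helper resume_exception_ptr are hypothetical names:

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>
#include <variant>

// Value slot (index 1) and error slot (index 2) share the same type here.
using EpResult = std::variant<std::monostate, std::exception_ptr, std::exception_ptr>;

// Hypothetical helper: index-based access keeps the two slots distinct,
// whereas std::get_if<std::exception_ptr> would not compile for this variant.
std::exception_ptr resume_exception_ptr(const EpResult& r) {
    if (auto* err = std::get_if<2>(&r)) {
        std::rethrow_exception(*err);  // error slot: propagate
    }
    auto* val = std::get_if<1>(&r);
    assert(val);                       // a result must have been stored
    return *val;                       // value slot: hand the exception_ptr back untouched
}
```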

Upvotes: 1
