Carter Li

Reputation: 159

C++ coroutines: call a coroutine function without co_await

struct task is modified from https://github.com/Quuxplusone/coro/blob/master/include/coro/gor_task.h. I only changed suspend_always to suspend_never on line 18 of that file (the initial_suspend).

// test.cpp
#include <exception>
#include <experimental/coroutine>
#include <variant>

template<class T>
struct task {
    struct promise_type {
        std::variant<std::monostate, T, std::exception_ptr> result_;
        std::experimental::coroutine_handle<void> waiter_;

        task get_return_object() { return task(this); }
        auto initial_suspend() { return std::experimental::suspend_never{}; } // Originally suspend_always
        auto final_suspend() {
            struct Awaiter {
                promise_type *me_;
                bool await_ready() { return false; }
                void await_suspend(std::experimental::coroutine_handle<void> caller) {
                    me_->waiter_.resume();
                }
                void await_resume() {}
            };
            return Awaiter{this};
        }
        template<class U>
        void return_value(U&& u) {
            result_.template emplace<1>(static_cast<U&&>(u));
        }
        void unhandled_exception() {
            result_.template emplace<2>(std::current_exception());
        }
    };

    bool await_ready() { return false; }
    void await_suspend(std::experimental::coroutine_handle<void> caller) {
        coro_.promise().waiter_ = caller;
        coro_.resume();
    }
    T await_resume() {
        if (coro_.promise().result_.index() == 2) {
            std::rethrow_exception(std::get<2>(coro_.promise().result_));
        }
        return std::get<1>(coro_.promise().result_);
    }

    ~task() {
        coro_.destroy();
    }
private:
    using handle_t = std::experimental::coroutine_handle<promise_type>;
    task(promise_type *p) : coro_(handle_t::from_promise(*p)) {}
    handle_t coro_;
};

#include <stdio.h>

task<int> f2() {
    puts("enter f2");
    co_return 1;
}

task<int> f1() {
    puts("enter f1");
    int a = co_await f2();
    printf("f2 return: %d\n", a);
    co_return a;
}

int main() {
    f1();
}
$ clang++ -fcoroutines-ts -std=c++17 -stdlib=libc++ -lc++ -lc++abi test.cpp -o test
$ ./test
enter f1
enter f2
fish: './test' terminated by signal SIGSEGV (Address boundary error)

Since main calls f1() without co_await and initial_suspend no longer suspends, I expected both coroutines to run to completion, print f2 return: 1, and exit normally. Instead, the program crashes with a segfault. Why, and how can I fix it?

Upvotes: 0

Views: 2437

Answers (1)

Nicol Bolas

Reputation: 473946

When a coroutine function performs co_await, it halts that function's execution and hands responsibility for resuming it to someone else. Who that "someone else" is depends on the expression being co_awaited, the promise type of the coroutine function, and the future type returned by the coroutine.

So let's look at the control flow here.

f1 calls f2. Because initial_suspend now returns suspend_never, f2's body runs immediately, inside the call expression f2() itself, and terminates via co_return. That means the coroutine for f2 is complete, so the promise type's final_suspend is called. The coroutine machinery expects final_suspend to return an awaitable type, which your task::promise_type provides (the Awaiter). All of this happens before f1's co_await ever looks at the task that f2 returns.

Except... task::promise_type::waiter_ is still a null handle. It is default-constructed, and only task::await_suspend ever assigns a value to it. Nobody has co_awaited the return value of f2 yet, so waiter_ does not refer to any coroutine.

So the moment f2 reaches its final suspend point, Awaiter::await_suspend calls resume() on a null coroutine handle. That is undefined behavior, and here it shows up as a crash.

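It would make more sense if final_suspend's Awaiter first checked whether waiter_ is null (which means nobody is waiting on the coroutine yet) and, if so, returned from await_suspend without resuming anything, leaving f2 suspended at its final suspend point. Don't do this by returning true from await_ready: the coroutine would then run off its end and its frame would be destroyed automatically, after which task::await_resume would read a destroyed promise and ~task would destroy the frame a second time. And because f2 may already be finished by the time it is co_awaited, task::await_ready should return coro_.done(), and task::await_suspend should only record the caller rather than resume a coroutine that has already started.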

Upvotes: 1
