ahhhaghost
ahhhaghost

Reputation: 21

C++20 coroutine use after free issue

I'm making an attempt to learn and implement C++20 coroutines, and I'm experiencing a bug.

Generator class:

// Generator-style coroutine return type. It wraps a std::coroutine_handle to
// its nested promise_type and exposes begin()/end() so the coroutine can be
// consumed with a range-based for loop.
//
// NOTE(review): the default argument `void` does not look usable as written:
// promise_type declares a `ReturnType current_value{}` member and the iterator
// forms `ReturnType&` / `ReturnType*`.
template<class ReturnType = void>
    class enumerable
    {
    public:
        class promise_type;
        using handle_type = std::coroutine_handle<promise_type>;

        // Promise object living inside the coroutine frame. It stores the most
        // recently yielded value and decides where the coroutine suspends.
        class promise_type
        {
        public:
            // Latest co_yield-ed value; iterators hand out references/pointers to it.
            ReturnType current_value{};

            // Builds the enumerable returned to the caller from this promise's handle.
            auto get_return_object()
            {
                return enumerable{ handle_type::from_promise(*this) };
            }

            // Always suspend before the first statement of the coroutine body:
            // nothing in the body runs until someone calls resume() (begin() does).
            auto initial_suspend()
            {
                return std::suspend_always{};
            }

            // Always suspend after the body finishes, so the handle can still be
            // queried with done() afterwards.
            auto final_suspend() noexcept
            {
                return std::suspend_always();
            }

            // Empty for now: an exception escaping the coroutine body is discarded.
            void unhandled_exception()
            {
                // TODO:
            }

            // The coroutine produces no final value; falling off the end is allowed.
            void return_void()
            {

            }

            // co_yield of an lvalue: the value is *moved* into current_value
            // (the caller's object is left moved-from), then the coroutine suspends.
            auto yield_value(ReturnType& value) noexcept
            {
                current_value = std::move(value);
                return std::suspend_always{};
            }

            // co_yield of an rvalue: forwards to the lvalue overload above.
            auto yield_value(ReturnType&& value) noexcept
            {
                return yield_value(value);
            }
        };

        // Input-style iterator over the values yielded by the coroutine.
        // It does not own the coroutine; it only holds a copy of the handle.
        class iterator
        {
            // NOTE(review): these aliases appear before any access specifier in a
            // `class`, so they are private and not visible to std::iterator_traits.
            using iterator_category = std::forward_iterator_tag;
            using difference_type = std::ptrdiff_t;
            using value_type = ReturnType;
            using pointer = ReturnType*;
            using reference = ReturnType&;
        private:
            handle_type handle;

        public:
            iterator(handle_type handle)
                : handle(handle)
            {

            }

            // Reference to the value currently stored in the promise.
            reference operator*() const
            {
                return this->handle.promise().current_value;
            }

            // Pointer to the value currently stored in the promise.
            pointer operator->()
            {
                return &this->handle.promise().current_value;
            }

            // Advancing resumes the coroutine until its next suspension point.
            iterator& operator++()
            {
                this->handle.resume();
                return *this;
            }

            // An iterator equals the sentinel when it has no handle or the
            // coroutine has run to completion.
            friend bool operator==(const iterator& it, std::default_sentinel_t s) noexcept
            {
                return !it.handle || it.handle.done();
            }

            friend bool operator!=(const iterator& it, std::default_sentinel_t s) noexcept
            {
                return !(it == s);
            }

            friend bool operator==(std::default_sentinel_t s, const iterator& it) noexcept
            {
                return (it == s);
            }

            friend bool operator!=(std::default_sentinel_t s, const iterator& it) noexcept
            {
                return it != s;
            }
        };

        // Handle to the coroutine frame. Nothing in this class calls
        // handle.destroy(), and no copy/move operations are declared, so copies of
        // an enumerable hold the same handle value and the frame is never freed here.
        handle_type handle;

        enumerable() = delete;

        // Wraps an existing coroutine handle and logs the object/frame addresses.
        enumerable(handle_type h)
            : handle(h)
        {
            std::cout << "enumerable constructed: " << this << " : " << this->handle.address() << '\n';
        };

        // Resumes the coroutine once (running it up to its first co_yield or to
        // completion) and returns an iterator positioned on that first value.
        iterator begin()
        {
            this->handle.resume();
            return iterator(this->handle);
        }

        // End-of-sequence marker compared against by iterator::operator==.
        std::default_sentinel_t end()
        {
            return {};
        }

        //Filters a sequence of values based on a predicate.
        //
        // This member function is itself a coroutine. Its frame refers to the
        // object it was called on through `this`, and `pred` is a reference
        // parameter, so the frame holds references rather than copies of either.
        // Because initial_suspend() always suspends, none of the body below
        // (including the trace output) runs until the returned enumerable is first
        // resumed; if `*this` or the predicate was a temporary, it can already be
        // gone by then.
        template<class Predicate>
        enumerable<ReturnType> where(Predicate&& pred)
        {
            std::cout << "where: " << this << " : " << this->handle.address()  << '\n';
            for (auto& i : *this)
            {
                if(pred(i))
                    co_yield i;
            }
        }

        // Only logs; the coroutine frame referenced by `handle` is not destroyed.
        ~enumerable()
        {
            std::cout << "enumerable destructed: " << this << " : " << this->handle.address() << '\n';
        }
    };

Test code:

// Source coroutine for the test: yields the integers 1 through 4, one per
// resumption, then finishes.
enumerable<int> numbers()
{
    co_yield 1;
    co_yield 2;
    co_yield 3;
    co_yield 4;
}

// Returns the result of calling the member coroutine where() on the temporary
// produced by numbers(), with a predicate that accepts every value.
// The temporary enumerable and the lambda are destroyed when this return
// statement completes, while the where() coroutine has not started running yet
// (it suspends initially), so it later reads `this`/`pred` after they are gone.
enumerable<int> filtered_numbers()
{
    return numbers().where([](int i) { return true; });
}
// Crashes
// Crashing variant: iterates the enumerable returned by filtered_numbers() and
// prints each value. The first resume happens inside begin(), which is where
// the where() coroutine body starts executing.
int main()
{
    for (auto& i : filtered_numbers())
    {
        std::cout << "value: " << i << '\n';
    }
    return 0;
}

Output:

enumerable constructed: 000000FF0550F560:000002959E3B5290
enumerable constructed: 000000FF0550F5D8:000002959E3B6470
destructed: 000000FF0550F560 : 000002959E3B5290
where: 000000FF0550F560 : 000000FF0550F640
//Works, despite "this" inside "where" still being destructed before use, can be observed with the couts.
// Variant that appears to work: the where() call is written directly in the
// range-for. Each element is copied into `i` before being printed.
// NOTE(review): the numbers() temporary is still reported as destructed before
// where() runs (see the trace below), so this presumably only works by accident
// because the destructor leaves the handle value in place - confirm.
int main()
{
    for(auto i : numbers().where([](int i) { return true; }))
    {
        std::cout << "value: " << i << '\n';
    }
    return 0;
}

Output:

enumerable constructed: 000000C9EDD2FD78:000001DADD1A61D0
enumerable constructed: 000000C9EDD2FD28:000001DADD1A73B0
destructed: 000000C9EDD2FD78 : 000001DADD1A61D0
where: 000000C9EDD2FD78 : 000001DADD1A61D0
value: 1
value: 2
value: 3
value: 4
destructed: 000000C9EDD2FD28 : 000001DADD1A73B0

Could somebody explain what is happening here? I'd like to come up with a workaround if possible. The crash does not happen if we return "std::suspend_never" from our promise_type's "initial_suspend", but not suspending at initial_suspend isn't ideal behavior for a generator, which should be lazy.

Upvotes: 1

Views: 312

Answers (2)

Yakk - Adam Nevraumont
Yakk - Adam Nevraumont

Reputation: 275740

    //Filters a sequence of values based on a predicate.
    // Lvalue-qualified overload: only callable on a named (lvalue) enumerable.
    // The predicate is taken by value, so the coroutine frame holds its own copy;
    // `*this` is still only referred to, not copied.
    template<class Predicate>
    enumerable<ReturnType> where(Predicate pred)&
    {
        std::cout << "where: " << this << " : " << this->handle.address()  << '\n';
        for (auto& i : *this)
        {
            if(pred(i))
                co_yield i;
        }
    }
    //Captures *this as well as above.
    // Rvalue-qualified overload: selected when where() is called on a temporary.
    // The intent is to keep the source alive by storing it in the local `self`.
    // NOTE: `self` is only initialized when the coroutine body first runs, i.e. on
    // the first resume, not when where() is called; the trace line below also
    // still prints through `this` rather than `self`.
    template<class Predicate>
    enumerable<ReturnType> where(Predicate pred)&&
    {
        auto self=std::move(*this);
        std::cout << "where: " << this << " : " << this->handle.address()  << '\n';
        for (auto& i : self)
        {
            if(pred(i))
                co_yield i;
        }
    }

Two changes:

  1. I take Predicate by value, to avoid dangling reference problem.

  2. I have a && overload that copies *this (well, moves from) and stores it within the coroutine.

This still doesn't work.

The first thing that happens is that our coroutine is suspended before any of its code is run. So the copy made by auto self=std::move(*this) only happens the first time we try to get a value, and by then the temporary that *this referred to has already been destroyed.

We can work around this in a few ways. One of them is to bounce to a free function and let it copy the enumerable<int>:

// Coroutine body shared by every `where` entry point below.
// Both `self` and `pred` are taken by value, so the coroutine frame stores its
// own copies of them and does not refer back to the caller's objects.
//
// This is a static member rather than the friend itself because the member
// overloads cannot call the friend: inside the class, the unqualified name
// `where` finds the member function templates, and once ordinary lookup finds
// a class member, argument-dependent lookup is not performed, so an in-class
// friend named `where` is never considered.
template<class Predicate>
static enumerable<ReturnType> where_impl( enumerable<ReturnType> self, Predicate pred ) {
  for (auto& i: self)
    if (pred(i))
      co_yield i;
}
// Free-function spelling: where(some_enumerable, pred). Found through
// argument-dependent lookup from outside the class.
template<class Predicate>
friend enumerable<ReturnType> where( enumerable<ReturnType> self, Predicate pred ) {
  return where_impl( std::move(self), std::move(pred) );
}
//Filters a sequence of values based on a predicate.
// Lvalue overload: hands a copy of *this to the shared coroutine.
template<class Predicate>
enumerable<ReturnType> where(Predicate pred)&
{
   return where_impl( *this, std::move(pred) );
}
// Rvalue overload: passes the temporary on so the coroutine keeps its own
// enumerable after the temporary is gone.
template<class Predicate>
enumerable<ReturnType> where(Predicate pred)&&
{
   return where_impl( std::move(*this), std::move(pred) );
}

A second way is to modify enumerable<ReturnType> to support a setup phase.

// Empty tag type: `co_yield init_done{};` marks the end of a coroutine's
// setup phase.
struct init_done {};

// Do not suspend on entry: the coroutine body starts running as soon as the
// coroutine function is called.
auto initial_suspend() {
  return std::suspend_never{};
}
// Overload chosen for `co_yield init_done{}`: suspends the coroutine without
// storing any value in the promise.
auto yield_value(init_done) noexcept {
  return std::suspend_always{};
}

Then modify every function that returns enumerable<int> so that it does co_yield init_done{}; as soon as its setup is finished.

We'd do this on the first line of the numbers() coroutine, and after we copy *this into the local variable self in the where() coroutine.

This is probably simplest:

// Owning wrapper around the lvalue where(): `self` is a by-value parameter, so
// the source enumerable is stored in this coroutine's frame and stays alive
// for as long as the filtered sequence is being consumed.
template<class F>
friend
enumerable<ReturnType> where2(enumerable<ReturnType> self, F f )
{
    // `self` is an lvalue here, so this picks the &-qualified where().
    auto filtered = self.where(std::move(f));
    for (auto item : filtered)
    {
        co_yield item;
    }
}
// Rvalue overload: called on a temporary enumerable. Passes *this and the
// predicate on to where2(), which takes both by value.
template<class F>
enumerable<ReturnType> where(F f)&&
{
    return where2(std::move(*this), std::move(f));
}
// Lvalue overload: filters the values of the enumerable it is called on.
// The predicate is held by value; the enumerable itself is only referred to,
// so it must outlive the returned sequence.
template<class F>
enumerable<ReturnType> where(F f)&
{
    for (auto item : *this)
    {
        // Skip values the predicate rejects; yield the rest.
        if (!f(item))
            continue;
        co_yield item;
    }
}

Upvotes: 1

Nicol Bolas
Nicol Bolas

Reputation: 473976

This:

return numbers().where([](int i) { return true; });

Creates a temporary (numbers()), then stores a reference to that temporary in a coroutine (the *this used in the loop), and then the temporary goes away.

That's bad. If you want to do chaining of coroutines, each step in that chain needs to be an object on someone's stack. where could be a non-member function that takes an enumerable by value. That would allow the where coroutine to preserve the existence of the enumerable.

Upvotes: 1

Related Questions