Desmond Gold

Reputation: 2085

How to recursively yield generator by overloading yield_value?

I've created a generator with an overloaded operator* that converts it into a std::ranges::subrange. I also want to overload yield_value in promise_type so that it accepts a subrange and yields its elements recursively.

Source Code:

template <typename T>
class [[nodiscard]] generator {
  public:
    using value_type = T;
    struct promise_type;
    using handle_type = std::coroutine_handle<promise_type>;

  private:
    handle_type handle_ { nullptr };

    explicit generator(handle_type handle) : handle_(handle) {}

  public:
    struct promise_type {
      value_type value_;

      generator<value_type> get_return_object() {
        return generator{ handle_type::from_promise(*this) };
      }

      std::suspend_always initial_suspend() { return {}; }
            
      std::suspend_always final_suspend() noexcept { return {}; }
           
      void unhandled_exception() { std::terminate(); }

      std::suspend_always yield_value(const value_type& value) noexcept {
        value_ = value;
        return {};
      }
            
      template <typename U>
      std::suspend_never await_transform(U&&) = delete;

      void return_void() {}
    };

    generator() noexcept = default;
    generator(const generator&) = delete;
    generator(generator&& other) noexcept
    : handle_(std::move(other.handle_)) {
      other.handle_ = nullptr;
    }

    ~generator() { if (handle_) handle_.destroy(); }

    generator& operator=(const generator&) = delete;

    generator& operator=(generator&& other) noexcept {
      handle_ = std::move(other.handle_);
      other.handle_ = nullptr;
      return *this;
    }

    void swap(generator& other) noexcept {
      using std::swap;
      swap(handle_, other.handle_);
    }

    class iterator {
      private:
        handle_type handle_;
        friend generator;

        explicit iterator(handle_type handle) noexcept
        : handle_(handle) {}

      public:
        using value_type = std::remove_cvref_t<T>;
        using reference  = value_type&;
        using const_reference = const value_type&;
        using pointer = value_type*;
        using const_pointer = const value_type*;
        using size_type = std::size_t;
        using difference_type = std::ptrdiff_t;
        using iterator_category = std::input_iterator_tag;

        iterator() noexcept = default;

        friend bool operator==(const iterator& iter, std::default_sentinel_t) noexcept {
          return iter.handle_.done();
        }

        friend bool operator==(std::default_sentinel_t s, const iterator& iter) noexcept {
          return (iter == s);
        }

        iterator& operator++() {
          if (handle_.done()) handle_.promise().unhandled_exception();
          handle_.resume();
          return *this;          
        }

        iterator operator++(int) {
          auto temp = *this;
          ++*this;
          return temp;
        }

        reference operator*() noexcept {
          return handle_.promise().value_;
        }

        pointer operator->() noexcept {
          return std::addressof(operator*());
        }

    };

    iterator begin() noexcept {
      if (handle_) {
        handle_.resume();
        if (handle_.done())
          handle_.promise().unhandled_exception();
      }
      return iterator{handle_};
    }

    std::default_sentinel_t end() noexcept {
        return std::default_sentinel;
    }
};

Example:

auto generate_0(int n) -> generator<int> {
  while (n != 0)
    co_yield n--;
}

auto generate_1() -> generator<int> {
  for (const auto& elem : generate_0(10)) {
    co_yield elem;
  }
}

generate_1 obviously works, but I want the same output from a version in which each element is co_yield-ed directly inside yield_value:

auto generate_1() -> generator<int> {
  co_yield* generate_0(10);
}

To that end, in class generator:

auto operator*() {
      return std::ranges::subrange(begin(), end());
}

In nested class generator<...>::promise_type:

template <typename U>
std::suspend_always yield_value(const std::ranges::subrange<U, std::default_sentinel_t>& r) noexcept {
  /** ... **/
  return {};
}

Upvotes: 4

Views: 730

Answers (1)

HTNW

Reputation: 29193

First things first: bugs/odd bits on your end.

  • I don't think it's worth it trying to support old-style iterators. It doesn't make sense to default-construct generator<T>::iterator, and the new-style iterator concepts do not require it. You can tear out a lot of junk from iterator.
    • Also, == is magical. If x == y doesn't find a matching operator== but y == x does, then x == y is automatically rewritten to y == x. You don't need to provide both operator==s.
  • The promise_type does not need to hold T by value. An odd thing about yielding things from coroutines is that if you make yield_value take by-reference, you can get a reference to something that lives in the coroutine state. But the coroutine state is preserved until you resume it! So promise_type can instead hold T const*. Now you no longer require annoying things like copyability and default-constructibility from T.
  • It appears to be unnatural for a generator to initially suspend. Currently, if you do g.begin(); g.begin();, you will advance the generator even though you've incremented no iterator. If you make g.begin() not resume the coroutine and remove the initial suspension, everything just works. Alternatively, you could make generator track whether it has started the coroutine and only advance it to the first yield on begin(), but that's complicated.
  • While calling std::terminate() on every operation that's normally UB may be nice, it's also noisy and I'm just not going to include it in this answer. Also, please don't call it via unhandled_exception. That's just confusing: unhandled_exception has one very specific purpose and meaning and you are just not respecting that.
  • generator<T>::operator=(generator&&) leaks *this's coroutine state! Also, your swap is nonstandard because it is not a free 2-arg function. We can fix these by making operator= do what swap did and then getting rid of swap because std::swap works.

From a design/theory standpoint, I think it makes more sense to implement this syntax instead.

auto generate_1() -> generator<int> {
  co_await generate_0(10);
}

A generator can temporarily hand control to another generator and resume running once the inner generator runs out. Yielding from an arbitrary range is then easy to build on top of this: wrap the range in a generator. This also lines up with the syntax in other languages like Haskell.

Now, coroutines have no stack. That means that as soon as we cross a function call boundary away from a coroutine like generate_1, it is not possible to suspend/resume that function via the coroutine state associated with the caller. So we have to implement our own stack, where we extend our coroutine state (promise_type) with the ability to record that it is currently pulling from another coroutine instead of having its own value. (Please note this would also apply to yielding from a range: whatever function is called to receive the range from generate_1 will not be able to control generate_1's coroutine.) We do this by making promise_type hold a

std::variant<T const*, std::ranges::subrange<iterator, std::default_sentinel_t>> value;

Note that promise_type does not own the generator represented by the subrange. Most of the time (as it is in generate_1) the same trick as yield_value applies: the generator which owns the sub-coroutine's state lives inside the caller coroutine's own state, which stays alive while that coroutine is suspended.
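To make the "own value or delegate" idea concrete without any coroutines, here is a deliberately simplified model. It is an assumption-laden sketch: node and current are hypothetical names, and a raw node* stands in for the subrange that the real promise_type stores.

```cpp
#include <cassert>
#include <variant>

// Hypothetical stand-in for promise_type: either points at the value this
// level has yielded itself, or at the inner level it is currently draining.
// Neither alternative owns what it points to.
struct node {
    std::variant<int const*, node*> value;
};

// Follows the chain of delegations down to whoever holds an actual value.
int const* current(node &n) {
    struct visitor {
        int const* operator()(int const *own) const { return own; }
        int const* operator()(node *inner) const { return current(*inner); }
    };
    return std::visit(visitor{}, n.value);
}

int demo_chain() {
    int const inner_value = 42;
    int const outer_value = 7;

    node innermost{&inner_value}; // yields its own value
    node middle{&innermost};      // delegates to innermost
    node outer{&middle};          // delegates to middle

    int seen_while_delegating = *current(outer);
    assert(seen_while_delegating == 42);

    // Once the inner levels run out, the outer level yields for itself again.
    outer.value = &outer_value;
    return seen_while_delegating * 100 + *current(outer);
}
```

Dereferencing the outermost level walks the linked list until it reaches a plain pointer, which is the same recursion that operator-> performs in the real implementation.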

(This is also a point against directly implementing co_yield from a range: we need to fix the type of whatever is going into promise_type. From an API standpoint, it's understandable for co_await inside a generator<T> to accept generator<T>s. But if we implemented co_yield we'd only be able to directly handle one specific kind of range—a subrange wrapping a generator. That'd be weird. And to do otherwise we'd need to implement type-erasure; but the most obvious way to type-erase a range in this context is to make a generator. So we're back to a generator awaiting on another as being the more fundamental operation.)

The stack of running generators is now a linked-list threaded through their promise_types. Everything else just writes itself.

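#include <array>      // only for the generate_all example at the end
#include <coroutine>
#include <cstddef>
#include <exception>
#include <iostream>   // only for main()
#include <iterator>
#include <memory>
#include <ranges>
#include <type_traits>
#include <utility>
#include <variant>

struct suspend_maybe { // just a general-purpose helper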
    bool ready;
    explicit suspend_maybe(bool ready) : ready(ready) { }
    bool await_ready() const noexcept { return ready; }
    void await_suspend(std::coroutine_handle<>) const noexcept { }
    void await_resume() const noexcept { }
};

template<typename T>
class [[nodiscard]] generator {
public:
    class iterator; // declared with the same class-key as the definition below
    struct promise_type;
    using handle_type = std::coroutine_handle<promise_type>;
    using range_type = std::ranges::subrange<iterator, std::default_sentinel_t>;

private:
    handle_type handle;

    explicit generator(handle_type handle) : handle(std::move(handle)) { }
public:
    class iterator {
    private:
        handle_type handle;
        friend generator;

        explicit iterator(handle_type handle) noexcept : handle(handle) { }
    public:
        // less clutter
        using iterator_concept = std::input_iterator_tag;
        using value_type = std::remove_cvref_t<T>;
        using difference_type = std::ptrdiff_t;

        // just need the one
        bool operator==(std::default_sentinel_t) const noexcept {
            return handle.done();
        }
        // need to muck around inside promise_type for this, so the definition is pulled out to break the cycle
        inline iterator &operator++();
        void operator++(int) { operator++(); }
        // again, need to see into promise_type
        inline T const *operator->() const noexcept;
        T const &operator*() const noexcept {
          return *operator->();
        }
    };
    iterator begin() noexcept {
        return iterator{handle};
    }
    std::default_sentinel_t end() const noexcept {
        return std::default_sentinel;
    }

    struct promise_type {
        // invariant: whenever the coroutine is non-finally suspended, this is nonempty
        // either the T const* is nonnull or the range_type is nonempty
        // note that neither of these own the data (T object or generator)
        // the coroutine's suspended state is often the actual owner
        std::variant<T const*, range_type> value = nullptr;

        generator get_return_object() {
            return generator(handle_type::from_promise(*this));
        }
        // initially suspending does not play nice with the conventional asymmetry between begin() and end()
        std::suspend_never initial_suspend() { return {}; }
        std::suspend_always final_suspend() noexcept { return {}; }
        void unhandled_exception() { std::terminate(); }
        std::suspend_always yield_value(T const &x) noexcept {
            value = std::addressof(x);
            return {};
        }
        suspend_maybe await_transform(generator &&source) noexcept {
            range_type range(source);
            value = range;
            return suspend_maybe(range.empty());
        }
        void return_void() { }
    };

    generator(generator const&) = delete;
    generator(generator &&other) noexcept : handle(std::move(other.handle)) {
        other.handle = nullptr;
    }
    ~generator() { if(handle) handle.destroy(); }
    generator& operator=(generator const&) = delete;
    generator& operator=(generator &&other) noexcept {
        // idiom: implementing assignment by swapping means the impending destruction/reuse of other implicitly handles cleanup of the resource being thrown away (which originated in *this)
        std::swap(handle, other.handle);
        return *this;
    }
};

// these are both recursive because I can't be bothered otherwise
// feel free to change that if it actually bites
template<typename T>
inline auto generator<T>::iterator::operator++() -> iterator& {
    struct visitor {
        handle_type handle;
        void operator()(T const*) { handle(); }
        void operator()(range_type &r) {
            if(r.advance(1).empty()) handle();
        }
    };
    std::visit(visitor(handle), handle.promise().value);
    return *this;
}
template<typename T>
inline auto generator<T>::iterator::operator->() const noexcept -> T const* {
    struct visitor {
        T const *operator()(T const *x) { return x; }
        T const *operator()(range_type &r) {
            return r.begin().operator->();
        }
    };
    return std::visit(visitor(), handle.promise().value);
}

Nothing appears to be on fire.

static_assert(std::ranges::input_range<generator<unsigned>>); // you really don't need all that junk in iterator!
generator<unsigned> generate_0(unsigned n) {
    while(n != 0) co_yield n--;
}
generator<unsigned> generate_1(unsigned n) {
    co_yield 0;
    co_await generate_0(n);
    co_yield 0;
}
int main() {
    auto g = generate_1(5);
    for(auto i : g) std::cout << i << "\n"; // 0 5 4 3 2 1 0 as expected
    // even better, asan is happy!
}

If you want to yield values from an arbitrary range, I would just implement this type-eraser.

auto generate_all(std::ranges::input_range auto &&r) -> generator<std::ranges::range_value_t<decltype(r)>> {
    for(auto &&x : std::forward<decltype(r)>(r)) co_yield std::forward<decltype(x)>(x);
}

So you get e.g.

generator<unsigned> generate_1(unsigned n) {
    co_await generate_all(std::array{41u, 42u, 43u});
    co_await generate_0(n);
    co_yield 0;
}

Upvotes: 2
