my-master

Reputation: 41

How to yield from an async for loop using asyncio?

I'm trying to write a simple asynchronous data batch generator, but I'm having trouble understanding how to yield from an async for loop. Here is a simple class illustrating my idea:

import asyncio
from typing import List

class AsyncSimpleIterator:
    def __init__(self, data: List[str], batch_size=None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = self.get_doc_ids()

    def get_doc_ids(self):
        return list(range(len(self.data)))

    async def get_batch_data(self, doc_ids):
        print("get_batch_data() running")
        page = [self.data[j] for j in doc_ids]
        return page

    async def get_docs(self, batch_size):
        print("get_docs() running")

        _batch_size = self.batch_size or batch_size
        batches = [self.doc2index[i:i + _batch_size] for i in
                   range(0, len(self.doc2index), _batch_size)]

        for _, doc_ids in enumerate(batches):
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    async def main(self):
        print("main() running")
        async for res in self.get_docs(batch_size=2):
            print(res)  # how to yield instead of print?

    def gen_batches(self):
        # how to get results of self.main() here?
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.main())
        loop.close()


DATA = ["Hello, world!"] * 4
iterator = AsyncSimpleIterator(DATA)
iterator.gen_batches()

So my question is: how can I yield the results from main() so that I can gather them inside gen_batches()?

When I print the result inside main(), I get the following output:

main() running
get_docs() running
get_batch_data() running
(['Hello, world!', 'Hello, world!'], [0, 1])
get_batch_data() running
(['Hello, world!', 'Hello, world!'], [2, 3])

Upvotes: 4

Views: 2375

Answers (2)

my-master

Reputation: 41

A working solution based on @user4815162342's answer to the original question:

import asyncio
from typing import List


class AsyncSimpleIterator:

    def __init__(self, data: List[str], batch_size=None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = self.get_doc_ids()

    def get_doc_ids(self):
        return list(range(len(self.data)))

    async def get_batch_data(self, doc_ids):
        print("get_batch_data() running")
        page = [self.data[j] for j in doc_ids]
        return page

    async def get_docs(self, batch_size):
        print("get_docs() running")

        _batch_size = self.batch_size or batch_size
        batches = [self.doc2index[i:i + _batch_size] for i in
                   range(0, len(self.doc2index), _batch_size)]

        for _, doc_ids in enumerate(batches):
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    def gen_batches(self):
        loop = asyncio.get_event_loop()

        async def collect():
            return [j async for j in self.get_docs(batch_size=2)]

        items = loop.run_until_complete(collect())
        loop.close()
        return items


DATA = ["Hello, world!"] * 4
iterator = AsyncSimpleIterator(DATA)
result = iterator.gen_batches()
print(result)
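On Python 3.7+, a variant of the same idea can use asyncio.run(), which creates a fresh event loop, runs the coroutine to completion, and closes the loop on every call. Calling asyncio.get_event_loop() with no running loop is deprecated in recent Python versions, and because the code above closes that loop after the first run, a second gen_batches() call in the same thread is likely to fail. This is a trimmed sketch, not the answerer's code: the print() calls and get_doc_ids() are dropped, and the batch_size parameter on gen_batches() is an illustrative addition.

```python
import asyncio
from typing import List, Optional


class AsyncSimpleIterator:
    def __init__(self, data: List[str], batch_size: Optional[int] = None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = list(range(len(data)))

    async def get_batch_data(self, doc_ids):
        # Stand-in for real async I/O: look up the documents for this batch.
        return [self.data[j] for j in doc_ids]

    async def get_docs(self, batch_size):
        # Async generator yielding (docs, doc_ids) one batch at a time.
        _batch_size = self.batch_size or batch_size
        for i in range(0, len(self.doc2index), _batch_size):
            doc_ids = self.doc2index[i:i + _batch_size]
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    def gen_batches(self, batch_size=2):
        async def collect():
            # PEP 530 async comprehension drains the async generator.
            return [item async for item in self.get_docs(batch_size)]

        # asyncio.run() creates a new loop, runs collect(), and closes the
        # loop, so gen_batches() can safely be called more than once.
        return asyncio.run(collect())


iterator = AsyncSimpleIterator(["Hello, world!"] * 4)
print(iterator.gen_batches())
```

Note that asyncio.run() raises RuntimeError if another event loop is already running in the same thread, so this synchronous wrapper is meant to be called from ordinary, non-async code.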

Upvotes: 0

user4815162342

Reputation: 154911

I'm trying to write a simple asynchronous data batch generator, but I'm having trouble understanding how to yield from an async for loop

Yielding from inside an async for loop works like a regular yield, except that the yielded values must in turn be collected by an async for (or an equivalent, such as an async comprehension). For example, the yield in get_docs() makes it an async generator. If you replace print(res) with yield res in main(), main() becomes an async generator as well.
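A minimal sketch of that change, assuming Python 3.7+ for asyncio.run(). The module-level get_docs() here is a simplified stand-in for the question's method rather than the original class, and consume() is a hypothetical helper. main() yields each batch instead of printing it, which makes it an async generator that can only be iterated with async for inside a coroutine.

```python
import asyncio


async def get_docs(data, batch_size):
    # Simplified stand-in for AsyncSimpleIterator.get_docs(): an async
    # generator yielding (docs, doc_ids) pairs, one batch at a time.
    doc_ids = list(range(len(data)))
    for i in range(0, len(doc_ids), batch_size):
        batch_ids = doc_ids[i:i + batch_size]
        await asyncio.sleep(0)  # placeholder for real async I/O
        yield [data[j] for j in batch_ids], batch_ids


async def main(data):
    # `yield` instead of `print` turns main() into an async generator too.
    async for res in get_docs(data, batch_size=2):
        yield res


async def consume(data):
    # An async generator can only be iterated with `async for`
    # (or an async comprehension) inside a coroutine.
    results = []
    async for res in main(data):
        results.append(res)
    return results


batches = asyncio.run(consume(["Hello, world!"] * 4))
print(batches)
```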

the generator in main() should exhaust in gen_batches(), so I can gather all results in gen_batches()

To collect the values produced by an async generator (such as main() with print(res) replaced with yield res), you can use a helper coroutine:

def gen_batches(self):
    loop = asyncio.get_event_loop()
    async def collect():
        return [item async for item in self.main()]
    items = loop.run_until_complete(collect())
    loop.close()
    return items

The collect() helper makes use of a PEP 530 asynchronous comprehension, which can be thought of as syntactic sugar for the more explicit:

    async def collect():
        results = []
        async for item in self.main():
            results.append(item)
        return results
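If gen_batches() should instead be an ordinary (synchronous) generator that hands out one batch at a time rather than building the whole list up front, one option, not part of the answer above, is to drive the async generator step by step on a private event loop. In this sketch agen_batches() is a hypothetical stand-in for self.main() and iter_batches() is an illustrative name; it relies only on asyncio.new_event_loop(), loop.run_until_complete(), and the async generator's __anext__() and aclose() methods.

```python
import asyncio


async def agen_batches(data, batch_size=2):
    # Hypothetical async generator standing in for self.main().
    for i in range(0, len(data), batch_size):
        await asyncio.sleep(0)  # placeholder for real async work
        yield data[i:i + batch_size]


def iter_batches(data, batch_size=2):
    # Plain generator: advances the async generator one item per next()
    # call on a private event loop, so callers can use a normal `for`.
    loop = asyncio.new_event_loop()
    agen = agen_batches(data, batch_size)
    try:
        while True:
            try:
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break  # the async generator is exhausted
    finally:
        # Finalize the async generator (also on early exit), then the loop.
        loop.run_until_complete(agen.aclose())
        loop.close()


for batch in iter_batches(["a", "b", "c", "d", "e"]):
    print(batch)
```

Each next() call blocks until the async generator produces its next item, and because run_until_complete() cannot run while another event loop is running in the same thread, iter_batches() should only be used from non-async code.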

Upvotes: 1
