Reputation: 41
I'm trying to write a simple asynchronous data batch generator, but I'm having trouble understanding how to yield from an async for loop. Here is a simple class illustrating my idea:
import asyncio
from typing import List


class AsyncSimpleIterator:

    def __init__(self, data: List[str], batch_size=None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = self.get_doc_ids()

    def get_doc_ids(self):
        return list(range(len(self.data)))

    async def get_batch_data(self, doc_ids):
        print("get_batch_data() running")
        page = [self.data[j] for j in doc_ids]
        return page

    async def get_docs(self, batch_size):
        print("get_docs() running")
        _batch_size = self.batch_size or batch_size
        batches = [self.doc2index[i:i + _batch_size] for i in
                   range(0, len(self.doc2index), _batch_size)]
        for _, doc_ids in enumerate(batches):
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    async def main(self):
        print("main() running")
        async for res in self.get_docs(batch_size=2):
            print(res)  # how to yield instead of print?

    def gen_batches(self):
        # how to get results of self.main() here?
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.main())
        loop.close()


DATA = ["Hello, world!"] * 4

iterator = AsyncSimpleIterator(DATA)
iterator.gen_batches()
So, my question is: how do I yield a result from main() so that I can gather it inside gen_batches()?
When I print the result inside main(), I get the following output:
main() running
get_docs() running
get_batch_data() running
(['Hello, world!', 'Hello, world!'], [0, 1])
get_batch_data() running
(['Hello, world!', 'Hello, world!'], [2, 3])
Upvotes: 4
Views: 2375
Reputation: 41
A working solution based on @user4815162342's answer to the original question:
import asyncio
from typing import List


class AsyncSimpleIterator:

    def __init__(self, data: List[str], batch_size=None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = self.get_doc_ids()

    def get_doc_ids(self):
        return list(range(len(self.data)))

    async def get_batch_data(self, doc_ids):
        print("get_batch_data() running")
        page = [self.data[j] for j in doc_ids]
        return page

    async def get_docs(self, batch_size):
        print("get_docs() running")
        _batch_size = self.batch_size or batch_size
        batches = [self.doc2index[i:i + _batch_size] for i in
                   range(0, len(self.doc2index), _batch_size)]
        for _, doc_ids in enumerate(batches):
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    def gen_batches(self):
        loop = asyncio.get_event_loop()

        async def collect():
            return [j async for j in self.get_docs(batch_size=2)]

        items = loop.run_until_complete(collect())
        loop.close()
        return items


DATA = ["Hello, world!"] * 4

iterator = AsyncSimpleIterator(DATA)
result = iterator.gen_batches()
print(result)
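On Python 3.7+, the manual loop handling in gen_batches() can also be replaced with asyncio.run(); a minimal sketch of that variant, assuming a fresh event loop per call is acceptable:

import asyncio


class AsyncSimpleIterator:
    ...  # same methods as above

    def gen_batches(self):
        async def collect():
            # Drain the async generator returned by get_docs() into an ordinary list.
            return [j async for j in self.get_docs(batch_size=2)]

        # asyncio.run() creates a new event loop, runs collect() to completion,
        # and closes the loop, replacing get_event_loop()/run_until_complete()/close().
        return asyncio.run(collect())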
Upvotes: 0
Reputation: 154911
I'm trying to write a simple asynchronous data batch generator, but I'm having trouble understanding how to yield from an async for loop
Yielding from an async for loop works like a regular yield, except that it also has to be collected by an async for or equivalent. For example, the yield in get_docs makes it an async generator. If you replace print(res) with yield res in main(), it will make main() an async generator as well.
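A minimal sketch of that change, with the rest of the class left exactly as in the question (yield res is the only edit):

async def main(self):
    print("main() running")
    async for res in self.get_docs(batch_size=2):
        # Re-yield each (docs, doc_ids) pair instead of printing it;
        # this makes main() itself an async generator.
        yield res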
the generator in main() should exhaust in gen_batches(), so I can gather all results in gen_batches()
To collect the values produced by an async generator (such as main() with print(res) replaced by yield res), you can use a helper coroutine:
def gen_batches(self):
    loop = asyncio.get_event_loop()

    async def collect():
        return [item async for item in self.main()]

    items = loop.run_until_complete(collect())
    loop.close()
    return items
The collect() helper makes use of a PEP 530 asynchronous comprehension, which can be thought of as syntactic sugar for the more explicit:
async def collect():
    l = []
    async for item in self.main():
        l.append(item)
    return l
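With main() rewritten to yield res, gen_batches() then returns the same pairs the question was printing; a usage sketch, with the expected result inferred from the output shown in the question:

DATA = ["Hello, world!"] * 4

iterator = AsyncSimpleIterator(DATA)
print(iterator.gen_batches())
# [(['Hello, world!', 'Hello, world!'], [0, 1]),
#  (['Hello, world!', 'Hello, world!'], [2, 3])]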
Upvotes: 1