Reputation: 21
I am working with LangChain and FastAPI. Basically I am creating a StreamingResponse, which streams back JSON diffs. The diffs are created by comparing a JSON document to its previous version. The various attributes of the JSON are generated from information received from OpenAI through LangChain chains built with LCEL (https://python.langchain.com/docs/expression_language/), which are also streamed.
To come up with all the attributes, I run a function called run that triggers other (async) functions. These async functions add a JSON diff to an asyncio.Queue. The __anext__ method unloads the queue.
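For context, this is roughly what such a diff looks like with the jsonpatch library (the object shapes here are made up for illustration):

from jsonpatch import JsonPatch

previous = {"meta": {"title": "Draft"}, "sub_objects": []}
current = {"meta": {"title": "Final"}, "sub_objects": [{"name": "a"}]}

# from_diff produces an RFC 6902 patch describing how to turn
# `previous` into `current`, e.g.:
# [{'op': 'replace', 'path': '/meta/title', 'value': 'Final'},
#  {'op': 'add', 'path': '/sub_objects/0', 'value': {'name': 'a'}}]
patch = JsonPatch.from_diff(previous, current).patch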
The code is quite complex; I tried to condense the relevant parts here.
This is the code that creates the async generator. The run method of RunAgent actually builds the document and fills the queue.
from abc import abstractmethod
import asyncio

from jsonpatch import JsonPatch
from pydantic import BaseModel, ConfigDict

from src.llm.chain import meta_chain, sub_object_meta_chain, sub_object_details_chain
from src.schema import MyObject


class IterableAgent(BaseModel):
    # allow non-pydantic field types such as asyncio.Queue
    model_config = ConfigDict(arbitrary_types_allowed=True)

    queue: asyncio.Queue | None = None
    done: bool = False
    timeout: int = 10

    async def __anext__(self):
        if self.done and self.queue.empty():
            raise StopAsyncIteration
        try:
            return await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
        except asyncio.TimeoutError:
            raise StopAsyncIteration

    def __aiter__(self):
        self.done = False
        self.queue = asyncio.Queue()
        asyncio.create_task(self.run())
        return self

    @abstractmethod
    async def run(self):
        raise NotImplementedError
class RunAgent(IterableAgent):
    object: MyObject | None = None
    _last_object: str = "{}"

    def _get_patch(self) -> list:
        new_object = self.object.model_dump_json()
        patch = JsonPatch.from_diff(self._last_object, new_object).patch
        self._last_object = new_object
        return patch

    def _continue_response(self) -> None:
        patch = self._get_patch()
        if patch:
            self.queue.put_nowait(patch)

    async def _add_meta(self):
        inputs = dict()  # Any inputs I need here for the chain...
        async for result in meta_chain.astream(inputs):
            self.object.meta = result
            self._continue_response()

    async def _add_sub_objects_meta(self):
        inputs = dict()  # Any inputs I need here for the chain...
        async for result in sub_object_meta_chain.astream(inputs):
            self.object.sub_objects = result
            self._continue_response()

    async def _add_sub_objects_details(self, so):
        inputs = dict(so=so)  # Any inputs I need here for the chain...
        async for result in sub_object_details_chain.astream(inputs):
            so.details = result
            self._continue_response()

    async def run(self):
        self.object = MyObject()
        await self._add_meta()
        await self._add_sub_objects_meta()
        cors = [self._add_sub_objects_details(so) for so in self.object.sub_objects]
        # I tried to create tasks instead of just having coroutines, no luck there...
        tasks = [asyncio.create_task(x) for x in cors]
        await asyncio.gather(*tasks)
        self.done = True
Then there is also a small piece of code in my FastAPI router that uses the generator to feed the StreamingResponse.
@router.post("/new")
async def new_document(req: MyRequest):
\# The agent is implemented to be an AsyncIterator
agent = RunAgent() # I would typically unpack MyRequest and feed the relevant data into the agent initialization
return StreamingResponse(agent, media_type="application/json")
Now this works quite well; however, sometimes it just stops halfway. The queue.get just times out. It seems as if tasks are just not started, while any tasks/coroutines that have started do complete. For instance, the metadata of the base object gets generated, but the metadata of the sub-objects does not. Or the details of one of the sub-objects are missing. I think there is some race condition, I just don't know where...
Upvotes: 2
Views: 329
Reputation: 110516
Python asyncio has a peculiar design choice that leads to a problem which is likely what you hit here: tasks created with .create_task are not hard-referenced by the event loop (it keeps only a weak reference to them). This sometimes does not show up in straightforward runs, or when there are few tasks, since weakly-referenced tasks are actually started on the first iteration of the asyncio loop core. But when there are lots of tasks (~2000 in some tests I made), or possibly fewer when task creation takes place in more than one place like in this code, tasks may just vanish without a trace.
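To make the failure mode concrete, here is an illustrative sketch (the work coroutine is made up, and whether an unreferenced task actually gets collected depends on garbage-collector timing, so treat this as the pattern rather than a guaranteed reproduction):

import asyncio

async def work(i):
    await asyncio.sleep(0.1)
    return i

async def main():
    # Risky: nothing holds a strong reference to these tasks, so the
    # garbage collector is allowed to destroy them before they run.
    for i in range(10):
        asyncio.create_task(work(i))

    # Safe: keep a strong reference until each task is done.
    background = set()
    for i in range(10):
        task = asyncio.create_task(work(i))
        background.add(task)
        task.add_done_callback(background.discard)

    await asyncio.sleep(1)  # give the tasks time to finish

asyncio.run(main())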
That is likely your problem, and the solution is simply to keep a reference to every task created by .create_task. A simple class-level set can hold those (or an instance-level one, if you have an __init__ method there):
...
class IterableAgent(BaseModel):
    queue: asyncio.Queue | None = None
    done: bool = False
    timeout: int = 10
    running_tasks: set[asyncio.Task] = set()
    ...

    def __aiter__(self):
        self.done = False
        self.queue = asyncio.Queue()
        task = asyncio.create_task(self.run())
        self.running_tasks.add(task)
        task.add_done_callback(self.running_tasks.discard)
        return self
    ...
(of course, use the same pattern if there are other calls to .create_task, such as the ones in RunAgent.run; see the sketch below).
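A sketch of that same pattern applied to the fan-out in the question's RunAgent.run, reusing the running_tasks set from above (gather already awaits the tasks, so this just guarantees strong references exist for their whole lifetime):

    async def run(self):
        self.object = MyObject()
        await self._add_meta()
        await self._add_sub_objects_meta()
        tasks = [
            asyncio.create_task(self._add_sub_objects_details(so))
            for so in self.object.sub_objects
        ]
        # Keep strong references so the tasks cannot be garbage-collected
        # before they finish.
        self.running_tasks.update(tasks)
        try:
            await asyncio.gather(*tasks)
        finally:
            self.running_tasks.difference_update(tasks)
        self.done = True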
As for where this behavior is documented, note the "important" note at: https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
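That same section of the docs also mentions asyncio.TaskGroup (Python 3.11+) as an alternative, since a task group holds strong references to its tasks and awaits them on exit; applied to the fan-out above, roughly:

        # Python 3.11+: TaskGroup holds strong references to its tasks
        # and implicitly awaits them all when the block exits.
        async with asyncio.TaskGroup() as tg:
            for so in self.object.sub_objects:
                tg.create_task(self._add_sub_objects_details(so))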
Upvotes: 2