Reputation: 113
Background: I'm building a websocket interface that controls some CPU-bound processes in a ProcessPoolExecutor. These processes are long-running: they send regular updates to the client through a queue and only terminate when a "stop" message is received.
Problem: After reading the docs, I haven't been able to get ProcessPoolExecutor to work so that a) the socket remains unblocked (required for parallel calls to update_cicada()), and b) the process can be terminated with a websocket message.
What am I missing here? The default threaded executor works, but in this case multiprocessing parallelism is faster (there is minimal I/O). Related questions (no websockets):
- How to do multiprocessing in FastAPI
- Threading rather than multiprocessing
- Make a CPU-bound task asynchronous for FastAPI WebSockets
Example
import asyncio
from concurrent.futures import ProcessPoolExecutor
from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket_route("/ws")
async def websocket_endpoint(websocket: WebSocket):
    pool = ProcessPoolExecutor()  # used instead of the default thread pool
    loop = asyncio.get_event_loop()
    queue = asyncio.Queue()
    s = CICADA(clip_class, queue)  # CICADA, clip_class and update_cicada are project code
    await websocket.accept()
    while True:
        data = await websocket.receive_json()
        # should be non-blocking and terminate on "stop" message
        loop.run_in_executor(pool, update_cicada(data, s, queue))
        # update_cicada adds to the queue, thus updating the client
        result = await queue.get()
        websocket.send_json(result)  # so the process can update the client without terminating
Upvotes: 4
Views: 6776
Reputation: 349
I also had some problems implementing FastAPI with websockets for multiple clients in parallel. As described in the FastAPI documentation (at the end, in the "Tip" section), you should have a look at the Starlette websocket examples for more complex implementations.
They have a full working example.
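For a rough idea, a minimal class-based endpoint in that style could look like the sketch below (the encoding attribute and on_receive handler are Starlette's WebSocketEndpoint API; the echo behaviour is just an illustration, not the full example from the docs):

from fastapi import FastAPI
from starlette.endpoints import WebSocketEndpoint

app = FastAPI()

class Echo(WebSocketEndpoint):
    encoding = "json"  # decode each incoming message as JSON

    async def on_receive(self, websocket, data):
        # Echo every JSON message back to the connected client.
        await websocket.send_json({"echo": data})

app.add_websocket_route("/ws", Echo)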
Upvotes: 0
Reputation: 5359
I read this question 2 days ago and couldn't shake it. There are multiple concepts in play here, and some (if not all) are quite complex. I created a prototype for myself to fully understand what was happening, and this is a working example. I added many comments to explain what is happening and why.
A couple of pointers right off the bat:

- asyncio.Queue is not thread or process safe. That means you can (and probably will) get a corrupt state when sharing such an object across processes. Such a queue is good for sharing state across tasks, as they all run on the same thread in the event loop.
- multiprocessing.Queue is thread and process safe, but you will need a Manager() to handle the specifics. It essentially creates another subprocess to handle all communication with the queue (from other processes).
- I used asyncio.sleep() to yield control back to the event loop, to allow other tasks in the event loop to continue processing. If I hadn't, it would block the current task in the infinite while loop.

I tested the below with 4 concurrent requests (I used wscat for that from the command line). Please note that I am in no way an expert in asyncio or multiprocessing, so I am not claiming that these are best practices.
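For reference, a wscat session against the server below would be opened like this (run it in four terminals for four concurrent connections; the host and port match the uvicorn.run() call at the bottom):

wscat -c ws://localhost:8000/ws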
import asyncio
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp
from queue import Empty

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import time

app = FastAPI()

# Do not re-create the pool with every request; create it only once.
pool = ProcessPoolExecutor()


def long_running_task(q: mp.Queue) -> str:
    # This would be your update_cicada function.
    for i in range(5):
        # This represents some blocking IO.
        time.sleep(3)
        q.put({"result": f"Iteration {i}"})
    return "done!"


@app.websocket_route("/ws")
async def websocket_endpoint(websocket: WebSocket):
    loop = asyncio.get_event_loop()
    # To use Queues across processes, you need to use the mp.Manager().
    m = mp.Manager()
    q = m.Queue()
    await websocket.accept()
    # run_in_executor will return a Future object. Normally you would await
    # such a method, but we want a bit more control over it.
    result = loop.run_in_executor(pool, long_running_task, q)
    while True:
        # None of the coroutines called in this block (e.g. send_json())
        # will yield back control. asyncio.sleep() does, and so it will allow
        # the event loop to switch context and serve multiple requests
        # concurrently.
        await asyncio.sleep(0)
        try:
            # See if our long-running task has some intermediate result.
            # q.get(block=False) raises Empty if there isn't any (yet!).
            q_result = q.get(block=False)
        except Empty:
            # Nothing was available (yet!).
            q_result = None
        # If there is an intermediate result, let's send it to the client.
        if q_result:
            try:
                await websocket.send_json(q_result)
            except WebSocketDisconnect:
                # This happens if the client has moved on; we should stop
                # the long-running task.
                result.cancel()
                # Break out of the while loop.
                break
        # We want to stop the connection when the long-running task is done.
        if result.done():
            try:
                await websocket.send_json(result.result())
                await websocket.close()
            except WebSocketDisconnect:
                # This happens if the client has moved on; we should stop
                # the long-running task.
                result.cancel()
            finally:
                # Make sure we break out of the infinite while loop.
                break


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
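A note on the design: the asyncio.sleep(0) plus q.get(block=False) combination is effectively a polling loop. A possible variant (a sketch, which I have not tested as thoroughly as the code above) is to hand the blocking q.get() call to the default thread pool, so the coroutine only wakes when something actually arrives in the queue:

# Variant sketch: await the blocking q.get() in the default thread pool
# instead of polling with q.get(block=False) + asyncio.sleep(0).
# Caveat: this waits indefinitely if the producer stops putting items.
q_result = await loop.run_in_executor(None, q.get)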
Upvotes: 4