Tomas Lawton

Reputation: 113

How to do multiprocessing with non-blocking fastapi websockets

Background: I'm making a websocket interface which controls some CPU-bound processes in a ProcessPoolExecutor. These processes send regular updates to the client via a queue and only terminate when a "stop" message is received (they are long-running).

Problem: After reading the docs, I haven't been able to get ProcessPoolExecutor to work so that (a) the socket remains unblocked (required for parallel calls to update_cicada()) and (b) the process can be terminated with a websocket message.

What am I missing here? The default threaded executor works, but in this case MP parallelism is faster (minimal I/O). Useful information (no websockets):

  • How to do multiprocessing in FastAPI
  • Threading rather than multiprocessing
  • Make a CPU-bound task asynchronous for FastAPI WebSockets

Example

@app.websocket_route("/ws")
async def websocket_endpoint(websocket: WebSocket):
    pool = ProcessPoolExecutor() #used instead of default/threadpool
    loop = asyncio.get_event_loop()
    queue = asyncio.Queue()
    s = CICADA(clip_class, queue)

    await websocket.accept()

    while True:
        data = await websocket.receive_json()

        #should be non-blocking and terminate on "stop" message
        loop.run_in_executor(pool, update_cicada(data, s, queue))

        #update_cicada adds to queue thus updating client
        result = await queue.get()
        websocket.send_json(result) #so process can update client without terminating

Upvotes: 4

Views: 6776

Answers (2)

fbzyx

Reputation: 349

I also had some problems implementing FastAPI with websockets for multiple clients in parallel. As the documentation notes (at the end, in the "Tip" section), you should have a look at the Starlette websocket example for more complex implementations.

They have a full working example.

Upvotes: 0

JarroVGIT

Reputation: 5359

I read this question 2 days ago and couldn't shake it. There are multiple concepts in play here, and some (if not all) are quite complex. I created a prototype for myself to fully understand what was happening, and this is a working example. I added many comments to explain what or why something is happening.

A couple of pointers right off the bat:

  • An asyncio.Queue is not thread or process safe. That means you can (and probably will) get a corrupt state when sharing such an object across processes. Such a Queue is good for sharing state across Tasks, as they all run on the same thread in the event loop.
  • multiprocessing.Queue is thread and process safe, but you will need a Manager() to handle the specifics. It essentially creates another subprocess to handle all communication with the Queue (from other processes).
  • Make sure that your code is not blocking other requests. In my example below, I use asyncio.sleep() to yield control back to the event loop, to allow other tasks in the event loop to continue processing. If I hadn't, it would block the current task in the infinite while loop.

I tested the below with 4 concurrent requests (I used wscat for that from the command line). Please note that I am in no way an expert in asyncio or multiprocessing, so I am not claiming that these are best practices.

import asyncio
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp
from queue import Empty
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import time

app = FastAPI()

#do not re-create the pool with every request, only create it once
pool = ProcessPoolExecutor()


def long_running_task(q: mp.Queue) -> str:
    # This would be your update_cicada function
    for i in range(5):
        #this represents some blocking IO
        time.sleep(3)
        q.put(f"'result': 'Iteration {i}'")
    return "done!"


@app.websocket_route("/ws")
async def websocket_endpoint(websocket: WebSocket):
    loop = asyncio.get_event_loop()
    
    #To use Queue's across processes, you need to use the mp.Manager()
    m = mp.Manager()
    q = m.Queue()
    
    await websocket.accept()
    
    #run_in_executor will return a Future object. Normally, you would await
    #such a method, but we want a bit more control over it.
    result = loop.run_in_executor(pool, long_running_task, q)
    while True:
        
        #None of the coroutines called in this block (e.g. send_json()) 
        # will yield back control. asyncio.sleep() does, and so it will allow
        # the event loop to switch context and serve multiple requests 
        # concurrently.
        await asyncio.sleep(0)

        try:
            #see if our long running task has some intermediate result.
            # q.get(block=False) raises Empty if there isn't any.
            q_result = q.get(block=False)
        except Empty:
            #if q.get() throws Empty exception, then nothing was 
            # available (yet!).
            q_result = None

        #If there is an intermediate result, let's send it to the client.
        if q_result:
            try:
                await websocket.send_json(q_result)
            except WebSocketDisconnect:
                #This happens if client has moved on, we should stop the long
                #  running task
                result.cancel()
                #break out of the while loop.
                break
        
        #We want to stop the connection when the long running task is done.
        if result.done():
            try:
                await websocket.send_json(result.result())
                await websocket.close()  
            except WebSocketDisconnect:
                #This happens if client has moved on, we should stop the long
                #  running task
                result.cancel()
            finally:
                #Make sure we break out of the infinite While loop.
                break
            
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
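As an aside, the effect of await asyncio.sleep(0) in the loop above can be seen in a standalone toy example (my own, not from the question): two tasks with tight loops only interleave because each iteration yields control back to the event loop.

```python
import asyncio


async def spinner(name: str, results: list, n: int = 3):
    for i in range(n):
        results.append(f"{name}:{i}")
        # Yield to the event loop so the other task gets a turn;
        # without this await, each task would run its whole loop
        # to completion before the other one could start.
        await asyncio.sleep(0)


async def main():
    results = []
    await asyncio.gather(spinner("a", results), spinner("b", results))
    return results


interleaved = asyncio.run(main())
print(interleaved)  # updates from both tasks interleave
```

The same mechanism is what keeps the websocket endpoint responsive to other clients while it polls the queue.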

Upvotes: 4
