Forshank

WebSockets unexpectedly disconnecting after message broadcast - FastAPI

I'm working with WebSockets and FastAPI to build a simple AI-powered chatbot.

Each conversation has one of two modes: copilot (the user's message is answered by a person using the UI) and autopilot (the message is handled and answered automatically by the server using a Chain).

The problem I'm facing is the following, and I think it has to do with the message-sending logic:

When the app with client_id=2 (the user) sends a message, the server handles it correctly regardless of the conversation's mode. The problem appears when the app with client_id=1 (me, on the other side) manually sends a message back. The user (client_id=2) receives that message correctly, but the next message I send from client_id=1 never reaches the server, and any message client_id=2 sends after that never reaches me either.

This makes me think that the client_id=1 app is somehow disconnected after sending the message. However, none of the exception handlers fire, the connection is never removed from the connection list, and its WebSocket state never changes to DISCONNECTED.
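These symptoms (no exception, the state still reported as CONNECTED, traffic stalling in both directions) can also look like a blocked event loop, so a synchronous call somewhere in the endpoint (the post mentions a Chain for autopilot) may be worth ruling out. The snippet is a generic asyncio illustration under that assumption, not the question's code: handler and heartbeat are hypothetical stand-ins for a message handler doing slow work and for another client's traffic.

```python
import asyncio
import time

events: list[str] = []


async def heartbeat() -> None:
    # Hypothetical stand-in for another client's messages being processed.
    for _ in range(3):
        events.append("tick")
        await asyncio.sleep(0.01)


async def handler(blocking: bool) -> None:
    # Hypothetical stand-in for a message handler that does slow work.
    if blocking:
        time.sleep(0.2)  # synchronous: nothing else on the event loop runs meanwhile
    else:
        await asyncio.to_thread(time.sleep, 0.2)  # offloaded to a thread; the loop stays free
    events.append("handler done")


async def main(blocking: bool) -> list[str]:
    events.clear()
    # The handler task is scheduled first, as if its message arrived first.
    await asyncio.gather(handler(blocking), heartbeat())
    return list(events)


print(asyncio.run(main(blocking=True)))
print(asyncio.run(main(blocking=False)))
```

In the blocking run, "handler done" is recorded before any "tick", because the heartbeat cannot run until the handler returns; in the asyncio.to_thread run, the first event is a "tick".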

Here is my WebSocket handling code:

from dataclasses import dataclass
from typing import List, Literal

from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState


@dataclass
class Connection:
    client_id: int
    conversation_id: str
    websocket: WebSocket


class ConnectionManager:
    def __init__(self):
        # Every open socket, tagged with its client_id and conversation_id.
        self.active_connections: List[Connection] = []

    async def connect(self, websocket: WebSocket, client_id: int, conversation_id: str):
        try:
            await websocket.accept()
            connection = Connection(
                client_id=client_id, conversation_id=conversation_id, websocket=websocket)
            self.active_connections.append(connection)
            print(f"Connected: {client_id} with conversation {conversation_id}")
        except WebSocketDisconnect:
            print(f"WebSocket for {client_id} disconnected during connect")

    def disconnect(self, websocket: WebSocket):
        # Rebuilds the list without this socket; a loop already iterating the old list is unaffected.
        self.active_connections = [
            conn for conn in self.active_connections if conn.websocket != websocket
        ]

    async def broadcast(self, message: str, AI_message: str, conversation_id: str, conversation_mode: Literal["copilot", "autopilot"], client_id: int):
        for connection in self.active_connections:
            try:
                print(f"Connection: {connection.client_id}, WebSocket State: {connection.websocket.client_state}")
                if connection.websocket.client_state == WebSocketState.CONNECTED:
                    # Message from the user (client_id=2).
                    if client_id == 2:
                        if conversation_mode == "autopilot" and connection.client_id == 1:
                            await connection.websocket.send_text(message)
                            await connection.websocket.send_text(AI_message)
                        if conversation_mode == "autopilot" and connection.client_id == 2 and connection.conversation_id == conversation_id:
                            await connection.websocket.send_text(AI_message)
                        if conversation_mode == "copilot" and connection.client_id == 1:
                            await connection.websocket.send_text(message)
                    # Manual reply from me (client_id=1).
                    if client_id == 1:
                        if connection.client_id == 2 and connection.conversation_id == conversation_id:
                            await connection.websocket.send_text(message)
            except Exception as e:
                print(f"Error in broadcasting: {e}")
                self.disconnect(connection.websocket)
                await connection.websocket.close()
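If a send stalls instead of raising, the try/except in broadcast never fires and the loop simply waits. A minimal sketch, assuming that case is worth ruling out, bounds each send with asyncio.wait_for so a stalled client surfaces as a timeout; send_with_timeout, TextSender and the 5-second default are hypothetical names and values, not part of FastAPI.

```python
import asyncio
from typing import Protocol


class TextSender(Protocol):
    # Anything with an async send_text(str) method, e.g. FastAPI's WebSocket.
    async def send_text(self, data: str) -> None: ...


async def send_with_timeout(ws: TextSender, text: str, timeout: float = 5.0) -> bool:
    """Send text, returning False instead of hanging if it takes longer than timeout seconds."""
    try:
        # wait_for cancels the pending send once the timeout expires.
        await asyncio.wait_for(ws.send_text(text), timeout=timeout)
        return True
    except asyncio.TimeoutError:
        print(f"send_text timed out after {timeout}s")
        return False
    except Exception as exc:  # e.g. the socket was already closed
        print(f"send_text failed: {exc!r}")
        return False
```

Inside broadcast, each await connection.websocket.send_text(...) could become a call to this helper, with self.disconnect(...) applied whenever it returns False.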

Then I instantiate this manager and use it in my WebSocket endpoint:


connection_manager = ConnectionManager()


@app.websocket("/ws/{client_id}/{conversation_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int, conversation_id: str):
    ...  # endpoint body not included in the post

I think the problem is in the broadcast method. This is the logic it is meant to implement:

- When the user (client_id=2) sends a message in autopilot mode, every client_id=1 connection receives the user's message followed by the AI reply, and the client_id=2 connection for that conversation receives the AI reply.
- When the user (client_id=2) sends a message in copilot mode, every client_id=1 connection receives the user's message.
- When I (client_id=1) send a message, it goes only to the client_id=2 connection for that conversation.
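The routing rules of broadcast can be pulled out into a pure function that decides who gets what without touching any socket, which makes them easy to check in isolation. This is a sketch under assumptions: Conn mirrors Connection without the websocket field, and route, Conn, OPERATOR and USER are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Literal, Tuple

Mode = Literal["copilot", "autopilot"]
OPERATOR, USER = 1, 2  # client_id values used in the question


@dataclass
class Conn:
    client_id: int
    conversation_id: str


def route(conns: List[Conn], sender_id: int, conversation_id: str, mode: Mode,
          message: str, ai_message: str) -> List[Tuple[int, str]]:
    """Return (connection index, text) pairs in the order broadcast would send them."""
    out: List[Tuple[int, str]] = []
    for i, c in enumerate(conns):
        same_conv = c.conversation_id == conversation_id
        if sender_id == USER:
            if c.client_id == OPERATOR:
                out.append((i, message))  # operators see every user message
                if mode == "autopilot":
                    out.append((i, ai_message))  # ...followed by the AI reply
            elif c.client_id == USER and same_conv and mode == "autopilot":
                out.append((i, ai_message))  # the user gets the AI reply
        elif sender_id == OPERATOR and c.client_id == USER and same_conv:
            out.append((i, message))  # a manual reply goes only to that conversation's user
    return out
```

Because it does no I/O, the same table of cases can be asserted directly, and broadcast then only has to loop over the returned pairs and send.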

Although I don't think it's related to the problem, these are the payloads being sent and received:


broadcast_data = {
    "suggestion": suggestion.get("output") if suggestion else "",
    "eventType": "message",
    "message": message_to_send,
    "createdAt": str(new_message.created_at),
    "sender": sender,
    "conversationId": str(conversation["_id"]),
    "chatAppClientId": 1,
    "phoneNumber": json_data["phoneNumber"],
}

AI_broadcast_data = {
    "suggestion": "",
    "eventType": "message",
    "message": suggestion.get("output") if suggestion else "",
    "createdAt": str(new_message.created_at),
    "sender": "assistant",
    "conversationId": str(conversation["_id"]),
    "chatAppClientId": 1,
    "phoneNumber": json_data["phoneNumber"],
}
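broadcast takes message and AI_message as str, so these dicts presumably get serialized before the call. A minimal sketch, assuming json.dumps is what is used (the post does not show that step); make_payload and the sample values are hypothetical.

```python
import json
from typing import Any, Dict


def make_payload(message: str, sender: str, conversation_id: str, created_at: str,
                 phone_number: str, suggestion: str = "") -> Dict[str, Any]:
    """Build a dict with the same keys as broadcast_data / AI_broadcast_data."""
    return {
        "suggestion": suggestion,
        "eventType": "message",
        "message": message,
        "createdAt": created_at,
        "sender": sender,
        "conversationId": conversation_id,
        "chatAppClientId": 1,
        "phoneNumber": phone_number,
    }


# Hypothetical sample values.
user_payload = make_payload("Hello", "user", "abc123", "2024-01-01 10:00:00", "+100",
                            suggestion="Hi there!")
# The AI payload carries the suggestion as its message and an empty suggestion field.
ai_payload = make_payload(user_payload["suggestion"], "assistant", "abc123",
                          "2024-01-01 10:00:00", "+100")

message = json.dumps(user_payload)
AI_message = json.dumps(ai_payload)
```

Starlette's WebSocket also provides send_json, which would do the serialization inside the send call instead.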
