Sumit Kumar
Sumit Kumar

Reputation: 422

Handling 5K websocket connections with rust axum server

I am using an axum server to handle WebSocket connections. I read messages from a Kafka topic, and if a WebSocket connection has been established to receive messages for a particular session, the message is relayed to that connection. I am running into issues when load testing the application with 5000 connections (4 CPUs and 8 GB RAM): I get i/o timeout errors on the client side. I am benchmarking this service against a Java Spring Boot implementation; with the same configuration, the Java service works fine with very low latency.

I am sharing the important code snippets below; most of it is the standard code from the examples.

lazy_static! {
    /// Process-wide registry of live WebSocket sessions.
    ///
    /// Keys are the `call_id` strings taken from the request path; values are
    /// the write half (`SplitSink`) of the corresponding socket. Entries are
    /// inserted and removed by `handle_socket`, and looked up by the Kafka
    /// consumer tasks when they relay a message to a session.
    pub static ref ACTIVE_SESSIONS: DashMap<String, SplitSink<WebSocket, axum::extract::ws::Message>> =
        DashMap::new();
}

// Run on Tokio's multi-threaded runtime with a fixed pool of 16 worker threads.
// NOTE(review): the deployment described has 4 CPUs — confirm that 16 workers
// is intentional rather than left over from experimentation.
#[tokio::main(flavor = "multi_thread", worker_threads = 16)]
// Spawn one Kafka consumer task per partition. Every task receives its own
// owned copy of the connection settings and the same `random_group_id`, so
// all consumers join a single consumer group.
// NOTE(review): presumably the broker then spreads the topic's partitions
// across these consumers — verify against the consumer configuration.
let consumer_handles = (0..num_partitions)
        .map(|_| {
            tokio::spawn(run_async_processor(
                broker_url.to_owned(),
                random_group_id.to_owned(),
                topic.to_owned(),
                username.to_owned(),
                password.to_owned(),
                app_config.kafka.consumer.enable_auto_commit.to_owned(),
                offset.to_owned(),
            ))
        })
        // Keep the join handles so the spawned tasks can be awaited later.
        .collect::<Vec<_>>();

// Bind the listening socket, then serve the axum application on it.
match TcpListener::bind(&ADDR).await {
        Ok(listener) => {
            // Log the port that was actually bound.
            let local_addr = listener.local_addr().expect("Failed to get local address");
            tracing::info!("server is running on port {}", local_addr.port());
            // Serve until the server future resolves; TCP_NODELAY is enabled on
            // accepted connections. An error from the server aborts via `expect`.
            axum::serve(listener, app.into_make_service())
                .tcp_nodelay(true)
                .await
                .expect("Server failed to start");
        }
        Err(e) => {
            // Without a listener the service cannot run: log and exit non-zero.
            tracing::error!("Failed to initiate the listener: {}", e);
            std::process::exit(1);
        }
    }
    
// handler function
pub async fn handler(
    ws: WebSocketUpgrade,
    Path(call_id): Path<String>,
    auth_header: Option<TypedHeader<Authorization<Bearer>>>,
    token: Option<Query<TokenQuery>>,
    State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
    let auth_token = match (auth_header, token) {
        (Some(auth_header), _) => auth_header.token().to_string(),
        (None, Some(token)) => token.token.clone(),
        _ => return (StatusCode::UNAUTHORIZED, "auth token is missing").into_response(),
    };

    // Check for valid JWT format
    if !is_valid_jwt(&auth_token) {
        return (StatusCode::UNAUTHORIZED, "invalid JWT format").into_response();
    }

    let account_info =
        match auth_handler::get_account_info(&auth_token, &state.config).await {
            Ok(account_info) => account_info,
            Err(e) => {
                tracing::error!("failed to get the account info {}", e);
                return (StatusCode::UNAUTHORIZED, e.to_string()).into_response();
            }
        };
    let _auth_context = match auth_handler::get_auth_context(
        &account_info,
        &auth_token,
        None,
        &state.auth_properties,
    )
    .await
    {
        Ok(auth_context) => auth_context,
        Err(e) => {
            tracing::error!("failed to authorize the user");
            return (StatusCode::UNAUTHORIZED, e.to_string()).into_response();
        }
    };

    ws.on_upgrade(move |socket| handle_socket(socket, call_id))
}

/// Drives a single upgraded WebSocket connection for `call_id`.
///
/// The socket is split; the write half is registered in `ACTIVE_SESSIONS`
/// under `call_id` and the read half is polled here until the peer closes the
/// connection, the stream ends, or a receive error occurs. The registry entry
/// is removed before returning.
///
/// Incoming text frames are parsed as JSON; a `{"type": "ping"}` message is
/// answered with `{"type": "pong"}`. A `Close` frame is echoed back and ends
/// the loop. All other frame types are ignored.
///
/// NOTE(review): `ACTIVE_SESSIONS.get_mut` returns a guard into the map and it
/// is held across `ws.send(..).await` below. While that send is pending, other
/// accesses that need the same map shard will block their thread. Under load
/// this is a likely source of stalls; consider storing a channel sender per
/// session and writing to the socket from a dedicated task. That requires
/// changing the map's value type, so it is not done in this function alone.
async fn handle_socket(socket: WebSocket, call_id: String) {
    let (sender, mut receiver) = socket.split();
    ACTIVE_SESSIONS.insert(call_id.clone(), sender);

    while let Some(msg) = receiver.next().await {
        match msg {
            Ok(axum::extract::ws::Message::Text(text)) => {
                let json = match serde_json::from_str::<Value>(&text) {
                    Ok(json) => json,
                    Err(e) => {
                        // A malformed frame is logged and skipped; the session stays open.
                        tracing::error!("failed to parse JSON message: {}", e);
                        continue;
                    }
                };

                match json.get("type").and_then(Value::as_str) {
                    Some("ping") => {
                        tracing::info!("Received ping message");
                        let response = "{\"type\": \"pong\"}";
                        if let Some(mut ws) = ACTIVE_SESSIONS.get_mut(&call_id) {
                            let msg = axum::extract::ws::Message::Text(response.to_string());
                            // Report success only when the pong was actually sent.
                            match ws.send(msg).await {
                                Ok(()) => {
                                    tracing::info!("Sent pong message");
                                }
                                Err(e) => {
                                    tracing::error!(
                                        "error in sending websocket message: {}",
                                        e
                                    );
                                }
                            }
                        }
                    }
                    _ => {
                        tracing::warn!("Unknown message type");
                    }
                }
            }
            Ok(axum::extract::ws::Message::Close(frame)) => {
                // Echo the close frame back to the peer, then stop reading.
                if let Some(mut ws) = ACTIVE_SESSIONS.get_mut(&call_id) {
                    let msg = axum::extract::ws::Message::Close(frame);
                    if let Err(e) = ws.send(msg).await {
                        tracing::error!("error in sending close message: {}", e);
                    }
                }
                break;
            }
            Ok(_) => {
                // Ignore other message types (Ping, Pong, Binary)
                continue;
            }
            Err(e) => {
                tracing::error!("WebSocket error for call_id={}: {}", call_id, e);
                break;
            }
        }
    }

    // Clean up session when the loop ends
    ACTIVE_SESSIONS.remove(&call_id);
    tracing::info!("Removed session for call_id={}", call_id);
}

// kafka consumer function
async fn run_async_processor(
    brokers: String,
    group_id: String,
    topic: String,
    username: String,
    password: String,
    enable_auto_commit: String,
    offset: String,
) {
    let consumer = create_consumer(
        brokers,
        group_id,
        username,
        password,
        enable_auto_commit,
        offset,
    );

    match consumer.subscribe(&[&topic]) {
        Ok(_) => {
            tracing::info!("subscribing to topic={}", topic);
        }
        Err(_e) => {
            panic!("failed to subscribe to the topic={}", topic);
        }
    }

    loop {
        match consumer.recv().await {
            Ok(m) => {
                match m.payload_view::<str>() {
                    None => continue,
                    Some(Ok(payload)) => {
                        let payload = payload.to_owned();

                        tokio::spawn(async move {
                            match serde_json::from_str::<CallTopicMessage>(&payload) {
                                Ok(call_topic_message) => {
                                    tracing::info!(
                                        "received message for call_id={}, type={:?}",
                                        call_topic_message.call_id,
                                        call_topic_message.message_type
                                    );
                                    // let mut active_sessions = ACTIVE_SESSIONS.lock().await;
                                    if let Some(mut ws) =
                                        ACTIVE_SESSIONS.get_mut(&call_topic_message.call_id)
                                    {
                                        tracing::info!(
                                            "active_session for call_id={} found",
                                            &call_topic_message.call_id
                                        );
                                        match serde_json::to_string(
                                            &call_topic_message.agent_assist_message,
                                        ) {
                                            Ok(msg) => {
                                                let msg = axum::extract::ws::Message::Text(msg);
                                                if let Err(e) = ws.send(msg).await {
                                                    tracing::error!(
                                                        "error in sending websocket message: {}",
                                                        e
                                                    );
                                                }
                                            }
                                            Err(e) => {
                                                tracing::error!(
                                                    "error in serializing kafka message {}",
                                                    e
                                                );
                                            }
                                        }
                                    }
                                }
                                Err(e) => {
                                    tracing::error!("error in deserializing kafka message {}", e);
                                }
                            };
                        });
                    }
                    Some(Err(_e)) => {
                        // do nothing
                    }
                };
            }
            Err(e) => tracing::error!("error in consuming from Kafka: {}", e),
        }
    }
}


The WebSocket handler function adds an entry to ACTIVE_SESSIONS. On the other hand, the Kafka consumer consumes messages, checks whether a message is for an already established connection, and if so sends it to that connection. I am new to Rust and axum. Can I get some help on how to make this better?

Upvotes: 0

Views: 183

Answers (0)

Related Questions