Marcel Kopera

Reputation: 11

Error in signature when trying to pass shared state to my authorization middleware in Axum

I'm trying to pass the DB pool, which lives in State(app_state): State<AppState>, to the authorize fn in my auth.rs. I call it from api.rs, where I use this middleware for JWT checks to protect my URLs, like this:

api.rs

pub fn api_routes() -> Router<AppState> {
    let route = Router::new()
        .route("/api", get(main_page_get).post(main_page_post))
        .route("/api/{query}", get(id_page_get))
        .route("/demo.json", get(get_demo_json).put(put_demo_json)
        .layer(middleware::from_fn(auth::authorize))
    );

    return route;
}

auth.rs

pub async fn authorize(
    mut req: Request<Body>,
    State(app_state): State<AppState>,
    next: Next
) -> Result<Response<Body>, AuthError> {

I'm getting an error on the line middleware::from_fn(auth::authorize):

the trait bound `axum::middleware::FromFn<fn(http::Request<Body>, axum::extract::State<AppState>, Next) -> impl Future<Output = Result<Response<Body>, AuthError>> {authorize}, (), Route, _>: tower_service::Service<http::Request<Body>>` is not satisfied
the trait `tower_service::Service<http::Request<Body>>` is not implemented for `FromFn<fn(Request<Body>, ..., ...) -> ... {authorize}, ..., ..., ...>`
the following other types implement trait `tower_service::Service<Request>`:
  axum::middleware::FromFn<F, S, I, (T1, T2)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8, T9)>
and 8 others
api.rs(17, 10): required by a bound introduced by this call
method_routing.rs(967, 21): required by a bound in `MethodRouter::<S, E>::layer`

And I'm not sure why, because the authorize signature looks fine.

EDIT: My shared state is implemented like this

use std::sync::Arc;
use sqlx::postgres::PgPool;

#[derive(Clone)]
pub struct AppState {
    pub html_path: Arc<String>,
    pub db_pool: Arc<PgPool>,
}

impl AppState {
    pub fn new(html_path: String, db_pool: PgPool) -> Self {
        AppState {
            html_path: Arc::new(html_path),
            db_pool: Arc::new(db_pool),
        }
    }
}

Upvotes: 0

Views: 37

Answers (1)

Revanth Shalon

Reputation: 141

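This is a relatively simple error. middleware::from_fn builds a middleware whose state type is () (that's the () in FromFn<..., (), Route, _> in the error message), so it cannot provide State<AppState> to authorize. Use middleware::from_fn_with_state and pass the state in explicitly. The code should look like this: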

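pub fn api_routes(app_state: AppState) -> Router {
    Router::new()
        .route("/api", get(main_page_get).post(main_page_post))
        .route("/api/{query}", get(id_page_get))
        .route(
            "/demo.json",
            get(get_demo_json)
                .put(put_demo_json)
                // from_fn_with_state hands a clone of the state to the middleware's State extractor
                .layer(middleware::from_fn_with_state(app_state.clone(), auth::authorize)),
        )
        // Give the handlers their state too; this turns Router<AppState> into a plain Router
        .with_state(app_state)
}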

Also, PgPool is already cheap to clone: it is a reference-counted handle to the underlying pool, so wrapping it in another Arc adds nothing. String is Clone as well; cloning it copies the text, but for a short path that cost is negligible. Arc fields become necessary when a field has to be modified at runtime, in which case you combine Arc with a RwLock or Mutex. When your app state is not expected to change, deriving Clone is enough.
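To make the clone cost concrete: axum clones the state whenever a State extractor runs, so it helps to know what a clone actually does. The std-only sketch below (the path value is made up) compares String::clone, which copies the bytes into a new buffer, with Arc::clone, which only increments a reference count. A PgPool clone behaves like the Arc case, since sqlx documents the pool as a reference-counted handle.

```rust
use std::sync::Arc;

/// Clones a String and an Arc<String> once each and reports, for each pair,
/// whether the clone still points at the original heap data.
pub fn clone_shares_buffer() -> (bool, bool) {
    // String::clone allocates a fresh buffer and copies the bytes into it.
    let plain = String::from("/var/www/html");
    let plain_clone = plain.clone();
    let plain_shares = plain.as_ptr() == plain_clone.as_ptr();

    // Arc::clone only increments the reference count; both handles share one String.
    let shared = Arc::new(String::from("/var/www/html"));
    let shared_clone = Arc::clone(&shared);
    let arc_shares = Arc::ptr_eq(&shared, &shared_clone) && Arc::strong_count(&shared) == 2;

    (plain_shares, arc_shares)
}
```

clone_shares_buffer() returns (false, true): the String clone got its own buffer, while both Arc handles point at the same String.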

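use sqlx::postgres::PgPool;

#[derive(Debug, Clone)]
pub struct AppState {
    html_path: String,
    db_pool: PgPool,
}

impl AppState {
    // `AppConfig` stands for your own configuration type.
    pub async fn new(config: &AppConfig) -> anyhow::Result<Self> {
        let html_path: String = todo!("get it from your config");
        let db_pool: PgPool = todo!("initialize your pool here");
        Ok(Self { html_path, db_pool })
    }

    pub fn html_path(&self) -> String {
        self.html_path.clone()
    }

    // Cheap: cloning a PgPool only clones a reference-counted handle.
    pub fn db_pool(&self) -> PgPool {
        self.db_pool.clone()
    }
}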

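Also, in auth.rs the order of the parameters matters: a from_fn middleware takes its FromRequestParts extractors (such as State) first, then the Request as the second-to-last argument, and Next last.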

pub async fn authorize(
    State(app_state): State<AppState>,
    mut req: Request,
    next: Next
) -> Result<Response<Body>, AuthError> {

EDIT:

This is how I pass my app state to the router function:

use axum::Router;

mod branch;
mod health;

use crate::state::SharedAppState;

/// Initializes the routes for the StaffHub application.
///
/// This function sets up the routing for the application, including health check routes
/// and API versioning.
///
/// # Arguments
///
/// * `app_state` - A shared application state used across routes.
///
/// # Returns
///
/// A `Router` instance with the configured routes.
pub fn init_routes(app_state: SharedAppState) -> Router {
    // Initialize the health check and branch routes under "/health" and "/branch".
    let health_routes_v1 =
        Router::new().nest("/health", health::init_health_routes(app_state.clone()));
    let branch_routes_v1 =
        Router::new().nest("/branch", branch::init_branch_routes(app_state.clone()));

    // Merge the branch and health check routes.
    let merged_routes_v1 = Router::new()
        .merge(branch_routes_v1)
        .merge(health_routes_v1);

    // Nest the merged routes under the "/v1" version path.
    let v1 = Router::new().nest("/v1", merged_routes_v1);

    // Nest the versioned routes under the "/api" path.
    let api_routes_v1 = Router::new().nest("/api", v1);

    // Nest the API routes under the "/staff-hub" base path.
    Router::new().nest("/staff-hub", api_routes_v1)
}

state.rs

mod jwks;
mod keycloak;

use crate::config::StaffHubConfig;
use crate::state::jwks::JwksClient;
use crate::state::keycloak::KeycloakClient;
use anyhow::Context;
use sqlx::{PgPool, postgres::PgPoolOptions};
use std::{sync::Arc, time::Duration};

/// Application state structure.
///
/// This structure holds the state of the application, including the database connection pool,
/// the Keycloak client, and the JWKS client.
///
/// Fields:
/// - `pool`: The PostgreSQL connection pool.
/// - `keycloak`: The Keycloak client.
/// - `jwks`: The JWKS client.
pub struct AppState {
    pool: PgPool,
    keycloak: KeycloakClient,
    jwks: JwksClient,
}

/// Shared application state type.
///
/// This type represents a shared reference to the application state.
pub type SharedAppState = Arc<AppState>;

impl AppState {
    /// Initializes the application state.
    ///
    /// This function creates a new `AppState` instance by configuring and initializing
    /// the PostgreSQL connection pool using the provided configuration.
    ///
    /// # Parameters
    ///
    /// - `config`: A reference to the `StaffHubConfig` containing the configuration details.
    ///
    /// # Returns
    ///
    /// An `anyhow::Result` containing the initialized `AppState` instance.
    pub async fn init_state(config: &StaffHubConfig) -> anyhow::Result<Self> {
        let pool = PgPoolOptions::new()
            .min_connections(config.database().min_connections())
            .max_connections(config.database().max_connections())
            .acquire_timeout(Duration::from_secs(config.database().connect_timeout()))
            .max_lifetime(Some(Duration::from_secs(config.database().max_lifetime())))
            .idle_timeout(Duration::from_secs(config.database().idle_timeout()))
            .connect_lazy(config.database().connection_string().as_str())?;
        let keycloak = KeycloakClient::new(config.keycloak());
        let jwks = JwksClient::new(config.keycloak().jwks_uri())
            .await
            .context("Jwks Client Setup")?;

        Ok(Self {
            pool,
            keycloak,
            jwks,
        })
    }

    /// Returns a shared reference to the application state.
    ///
    /// This function wraps the `AppState` instance in an `Arc` to create a shared reference.
    ///
    /// # Returns
    ///
    /// A `SharedAppState` representing the shared application state.
    pub fn get_shared_state(self) -> SharedAppState {
        Arc::new(self)
    }

    /// Returns the PostgreSQL connection pool.
    ///
    /// # Returns
    ///
    /// A `PgPool` representing the PostgreSQL connection pool.
    pub fn pool(&self) -> PgPool {
        self.pool.clone()
    }

    /// Returns a reference to the Keycloak client.
    ///
    /// # Returns
    ///
    /// A reference to the `KeycloakClient`.
    pub fn keycloak(&self) -> &KeycloakClient {
        &self.keycloak
    }

    /// Returns a reference to the JWKS client.
    ///
    /// # Returns
    ///
    /// A reference to the `JwksClient`.
    pub fn jwks(&self) -> &JwksClient {
        &self.jwks
    }
}

main.rs

use crate::config::StaffHubConfig;
use crate::state::AppState;
use anyhow::Context;
use tokio::net::TcpListener;

mod config;
mod dtos;
mod entities;
mod errors;
mod handlers;
mod middlewares;
mod routes;
mod state;
mod utils;

/// Initializes and starts the StaffHub service.
///
/// This asynchronous function performs the following steps:
/// 1. Loads the application configuration.
/// 2. Initializes the application state.
/// 3. Sets up the application routes.
/// 4. Binds the server to the specified address and port.
/// 5. Starts serving the application.
///
/// # Returns
///
/// An `anyhow::Result` which is `Ok` if the service starts successfully, or an error if any step fails.
pub async fn init_service() -> anyhow::Result<()> {
    // Load the application configuration.
    let app_config = StaffHubConfig::load_config().context("Configuration Load")?;

    // Initialize the application state.
    let app_state = AppState::init_state(&app_config)
        .await
        .context("App State Initialization")?;

    let shared_state = app_state.get_shared_state();

    // Setting up background tasks
    tokio::spawn(utils::background_jwks_refresh(shared_state.clone()));

    // Set up the application routes.
    let app_routes = routes::init_routes(shared_state.clone());

    // Bind the server to the specified address and port.
    let listener = TcpListener::bind(app_config.server().addr())
        .await
        .context("Server Port Bind")?;

    // Start serving the application.
    axum::serve(listener, app_routes.into_make_service())
        .await
        .context("Application Serve")?;

    Ok(())
}

Upvotes: 1