I'm trying to use a custom middleware in Axum to implement JWT authentication. However, I'm unable to compile it when attempting to return an error for failed validation. Below is the JWT middleware validation code I've written. Can you advise me on how to modify it to achieve the desired functionality?
//custom_middleware.rs
use axum::http::StatusCode;
use axum::{extract::Request, response::Response};
use futures_util::future::BoxFuture;
use jsonwebtoken::{
    decode, errors::Error as JwtError, Algorithm, DecodingKey, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use std::task::{Context, Poll};
use tower::{Layer, Service};

#[derive(Serialize, Deserialize)]
pub struct Claims {
    pub id: usize,
    pub exp: usize,
}

#[derive(Clone)]
pub struct MyLayer;

impl<S> Layer<S> for MyLayer {
    type Service = MyMiddleware<S>;

    fn layer(&self, inner: S) -> Self::Service {
        MyMiddleware { inner }
    }
}

#[derive(Clone)]
pub struct MyMiddleware<S> {
    inner: S,
}

impl<S> Service<Request> for MyMiddleware<S>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        match has_permission(&req) {
            Ok(_) => {
                let future = self.inner.call(req);
                Box::pin(async move {
                    let response: Response = future.await?;
                    Ok(response)
                })
            }
            Err(_) => Err((StatusCode::BAD_REQUEST, "bad request")),
        }
    }
}

fn has_permission(req: &Request) -> Result<TokenData<Claims>, (StatusCode, &'static str)> {
    let secret = "baby195lxl";
    let authorization_header_option = req.headers().get("authorization");
    if authorization_header_option.is_none() {
        return Err((StatusCode::BAD_REQUEST, "authorization header is none"));
    }
    let authentication_token: String = authorization_header_option
        .unwrap()
        .to_str()
        .unwrap_or("")
        .to_string();
    if authentication_token.is_empty() {
        return Err((StatusCode::BAD_REQUEST, "authorization header is empty"));
    }
    let token_result: Result<TokenData<Claims>, JwtError> = decode::<Claims>(
        &authentication_token,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    );
    match token_result {
        Ok(_token) => Ok(_token),
        Err(_e) => Err((StatusCode::UNAUTHORIZED, "Token Error")),
    }
}
//main.rs
use axum::{
    body::Bytes,
    extract::{Json, Request, State},
    routing::{get, post},
    Router,
};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::Deserialize;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tower_http::trace::TraceLayer;
use tracing::Span;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod custom_middleware;
use custom_middleware::Claims;
use custom_middleware::MyLayer;
mod state;
use state::AppState;

#[derive(Deserialize, Debug, PartialEq)]
struct User {
    account: usize,
    password: String,
}

async fn register(State(state): State<AppState>, Json(user): Json<User>) -> String {
    let store_user = User {
        account: 195,
        password: "world".to_string(),
    };
    if user == store_user {
        let expiration = SystemTime::now() + Duration::from_secs(30 * 60);
        let exp_timestamp = expiration.duration_since(UNIX_EPOCH).unwrap().as_secs();
        let claims = Claims {
            id: user.account,
            exp: exp_timestamp as usize,
        };
        let token = encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(state.secret.as_bytes()),
        )
        .unwrap();
        token
    } else {
        "hello, world!".to_string()
    }
}

async fn login(State(state): State<AppState>, req: Request) -> Json<Claims> {
    let token = req
        .headers()
        .get("Authorization")
        .unwrap()
        .to_str()
        .unwrap();
    let payload = decode::<Claims>(
        token,
        &DecodingKey::from_secret(state.secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    )
    .unwrap();
    Json(payload.claims)
}

async fn protected(_req: Request) -> String {
    "World!".to_string()
}

#[tokio::main]
async fn main() {
    let state = AppState {
        secret: "baby195lxl".to_string(),
    };

    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new("debug"))
        .with(tracing_subscriber::fmt::layer())
        .init();

    let app = Router::new()
        .route("/protected", get(protected))
        .layer(MyLayer)
        .route("/register", post(register))
        .route("/login", post(login))
        .with_state(state)
        .layer(TraceLayer::new_for_http().on_body_chunk(
            |chunk: &Bytes, latency: Duration, _span: &Span| {
                tracing::debug!("streaming {} bytes in {:?}", chunk.len(), latency);
            },
        ));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:5000")
        .await
        .unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}
When I compile it, rustc rejects the Err(_) arm in call():
   |
   |             Err(_) => Err((StatusCode::BAD_REQUEST, "bad request")),
   |                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `Pin<Box<dyn Future<Output = ...> + Send>>`, found `Result<_, (StatusCode, &str)>`
I have not yet found a solution that compiles. My Cargo.toml is below, and I am using rustc 1.76.0 (07dca489a 2024-02-04). I would appreciate it if someone could help me clear up this confusion; any replies are welcome. Thanks.
[dependencies]
axum = "^0.7"
tokio = { version = "^1.36", features = ["full"] }
tower-http = { version = "^0.5", features = ["trace"] }
tracing = "^0.1"
tracing-subscriber = { version = "^0.3", features = ["env-filter"] }
serde = { version = "1.0", features = ["derive"] }
jsonwebtoken = "9.2.0"
tower = "0.4.13"
futures-util = "0.3.30"
The solution to your immediate problem is to wrap your Err(...) in Box::pin(async { ... }), since you've said your Self::Future is a BoxFuture.
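As a rough std-only sketch of that type mismatch (not axum or tower code): every branch of call() has to produce the same boxed future, so the error branch needs its own Box::pin. BoxFuture is redefined locally as a stand-in for futures_util's alias, and inner_call and block_on are hypothetical helpers added only so the sketch runs on its own.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Local stand-in for futures_util::future::BoxFuture<'static, T>.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;

// Hypothetical stand-in for the inner service's future.
async fn inner_call() -> Result<String, &'static str> {
    Ok("World!".to_string())
}

// Mirrors the shape of MyMiddleware::call: both branches return a boxed future.
fn call(authorized: bool) -> BoxFuture<Result<String, &'static str>> {
    if authorized {
        let future = inner_call();
        Box::pin(async move { future.await })
    } else {
        // Returning `err` directly would be a type error (a Result is not a
        // future), so the early rejection is wrapped in a boxed future too.
        let err: Result<String, &'static str> = Err("bad request");
        Box::pin(async move { err })
    }
}

// Waker that does nothing; enough for futures that never wait on I/O.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

// Minimal executor: busy-polls the future until it completes.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}
```

Note that the sketch uses &'static str as the error type for both branches, which sidesteps the S::Error problem that remains in the real middleware.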
However, that is not the solution, since you would then discover that you're using S::Error instead of (StatusCode, &'static str) as your error type. That means you'd have to abide by the nested service's error type, which you probably don't want. You could make a wrapping error enum that can express either your own error type or one from a nested service, but that isn't the solution either because...
Axum expects its Services not to fail. You can see that Router::layer is constrained such that the layer's service's Error type is Into<Infallible>, which means it can't return an error. This is because, while tower is designed so that error types are well-described, axum expects errors to be converted into Responses immediately.
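A std-only sketch of what an Into<Infallible> error type means in practice: since Infallible has no values, a failed check can only be reported inside the Ok value, as a response carrying an error status. Response, check_auth, middleware and into_response are hypothetical simplifications, not axum's types, and the status is a plain u16 here.

```rust
use std::convert::Infallible;

// Hypothetical, simplified stand-in for an HTTP response.
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Hypothetical auth check; the error carries a status code and a message.
fn check_auth(header: Option<&str>) -> Result<(), (u16, &'static str)> {
    match header {
        None => Err((400, "authorization header is none")),
        Some("") => Err((400, "authorization header is empty")),
        Some(_) => Ok(()),
    }
}

// With Error = Infallible, failures must be turned into ordinary responses.
fn middleware(header: Option<&str>) -> Result<Response, Infallible> {
    let response = match check_auth(header) {
        Ok(()) => Response { status: 200, body: "World!".to_string() },
        Err((status, msg)) => Response { status, body: msg.to_string() },
    };
    Ok(response)
}

// Infallible has no values, so the Err arm can never be reached.
fn into_response(result: Result<Response, Infallible>) -> Response {
    match result {
        Ok(response) => response,
        Err(never) => match never {},
    }
}
```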
So you could refactor your Layer and Service implementations to follow that, but the better solution would be to use axum's middleware::from_fn helper function, which makes it all much nicer to use:
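use axum::{
    extract::Request,
    http::StatusCode,
    middleware::Next,
    response::Response,
};

// has_permission is the helper from the question.
// (StatusCode, &'static str) implements IntoResponse, so the Err is sent as a response.
async fn auth_middleware(req: Request, next: Next) -> Result<Response, (StatusCode, &'static str)> {
    match has_permission(&req) {
        Ok(_) => {
            let response = next.run(req).await;
            Ok(response)
        }
        Err(_) => Err((StatusCode::BAD_REQUEST, "bad request")),
    }
}

let app = Router::new()
    ...
    .layer(axum::middleware::from_fn(auth_middleware))
    ...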
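The from_fn version works because axum accepts a middleware return type that implements IntoResponse, and Result<T, E> does when both T and E do, which covers Response and (StatusCode, &'static str). The following std-only mock imitates that mechanism with a hypothetical local IntoResponse trait and a closure standing in for Next; it is an illustration of the idea, not axum's real implementation.

```rust
// Hypothetical, simplified response type (not axum's).
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Hypothetical local trait mirroring the idea behind axum's IntoResponse.
trait IntoResponse {
    fn into_response(self) -> Response;
}

impl IntoResponse for Response {
    fn into_response(self) -> Response {
        self
    }
}

// A (status, message) pair becomes a response, like (StatusCode, &'static str) in axum.
impl IntoResponse for (u16, &'static str) {
    fn into_response(self) -> Response {
        Response { status: self.0, body: self.1.to_string() }
    }
}

// A Result is a response whenever both of its arms are.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        match self {
            Ok(ok) => ok.into_response(),
            Err(err) => err.into_response(),
        }
    }
}

// Stand-in for auth_middleware: `next` plays the role of axum's Next and is
// only invoked when the check passes.
fn auth_middleware(
    header: Option<&str>,
    next: impl FnOnce() -> Response,
) -> Result<Response, (u16, &'static str)> {
    match header {
        Some(h) if !h.is_empty() => Ok(next()),
        _ => Err((400, "bad request")),
    }
}

// Stand-in for what the framework does with the middleware's return value.
fn run(header: Option<&str>) -> Response {
    let handler = || Response { status: 200, body: "World!".to_string() };
    auth_middleware(header, handler).into_response()
}
```

The point of the design is that the middleware can short-circuit with a plain Err value while the framework still only ever sees a response.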