This repository has no description.
1use axum::{ 2 extract::FromRequestParts, 3 http::{StatusCode, request::Parts, header::AUTHORIZATION}, 4 response::{IntoResponse, Response}, 5 Json, 6}; 7use serde_json::json; 8use crate::state::AppState; 9use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated}; 10pub struct BearerAuth(pub AuthenticatedUser); 11#[derive(Debug)] 12pub enum AuthError { 13 MissingToken, 14 InvalidFormat, 15 AuthenticationFailed, 16 AccountDeactivated, 17 AccountTakedown, 18} 19impl IntoResponse for AuthError { 20 fn into_response(self) -> Response { 21 let (status, error, message) = match self { 22 AuthError::MissingToken => ( 23 StatusCode::UNAUTHORIZED, 24 "AuthenticationRequired", 25 "Authorization header is required", 26 ), 27 AuthError::InvalidFormat => ( 28 StatusCode::UNAUTHORIZED, 29 "InvalidToken", 30 "Invalid authorization header format", 31 ), 32 AuthError::AuthenticationFailed => ( 33 StatusCode::UNAUTHORIZED, 34 "AuthenticationFailed", 35 "Invalid or expired token", 36 ), 37 AuthError::AccountDeactivated => ( 38 StatusCode::UNAUTHORIZED, 39 "AccountDeactivated", 40 "Account is deactivated", 41 ), 42 AuthError::AccountTakedown => ( 43 StatusCode::UNAUTHORIZED, 44 "AccountTakedown", 45 "Account has been taken down", 46 ), 47 }; 48 (status, Json(json!({ "error": error, "message": message }))).into_response() 49 } 50} 51fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 52 let auth_header = auth_header.trim(); 53 if auth_header.len() < 8 { 54 return Err(AuthError::InvalidFormat); 55 } 56 let prefix = &auth_header[..7]; 57 if !prefix.eq_ignore_ascii_case("bearer ") { 58 return Err(AuthError::InvalidFormat); 59 } 60 let token = auth_header[7..].trim(); 61 if token.is_empty() { 62 return Err(AuthError::InvalidFormat); 63 } 64 Ok(token) 65} 66pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 67 let header = auth_header?; 68 let header = 
header.trim(); 69 if header.len() < 7 { 70 return None; 71 } 72 if !header[..7].eq_ignore_ascii_case("bearer ") { 73 return None; 74 } 75 let token = header[7..].trim(); 76 if token.is_empty() { 77 return None; 78 } 79 Some(token.to_string()) 80} 81impl FromRequestParts<AppState> for BearerAuth { 82 type Rejection = AuthError; 83 async fn from_request_parts( 84 parts: &mut Parts, 85 state: &AppState, 86 ) -> Result<Self, Self::Rejection> { 87 let auth_header = parts 88 .headers 89 .get(AUTHORIZATION) 90 .ok_or(AuthError::MissingToken)? 91 .to_str() 92 .map_err(|_| AuthError::InvalidFormat)?; 93 let token = extract_bearer_token(auth_header)?; 94 match validate_bearer_token_cached(&state.db, &state.cache, token).await { 95 Ok(user) => Ok(BearerAuth(user)), 96 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 97 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 98 Err(_) => Err(AuthError::AuthenticationFailed), 99 } 100 } 101} 102pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 103impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 104 type Rejection = AuthError; 105 async fn from_request_parts( 106 parts: &mut Parts, 107 state: &AppState, 108 ) -> Result<Self, Self::Rejection> { 109 let auth_header = parts 110 .headers 111 .get(AUTHORIZATION) 112 .ok_or(AuthError::MissingToken)? 
113 .to_str() 114 .map_err(|_| AuthError::InvalidFormat)?; 115 let token = extract_bearer_token(auth_header)?; 116 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 117 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 118 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 119 Err(_) => Err(AuthError::AuthenticationFailed), 120 } 121 } 122} 123#[cfg(test)] 124mod tests { 125 use super::*; 126 #[test] 127 fn test_extract_bearer_token() { 128 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 129 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 130 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 131 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 132 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 133 assert!(extract_bearer_token("Basic abc123").is_err()); 134 assert!(extract_bearer_token("Bearer").is_err()); 135 assert!(extract_bearer_token("Bearer ").is_err()); 136 assert!(extract_bearer_token("abc123").is_err()); 137 assert!(extract_bearer_token("").is_err()); 138 } 139}