This repository has no description.
1use axum::{ 2 extract::FromRequestParts, 3 http::{StatusCode, request::Parts, header::AUTHORIZATION}, 4 response::{IntoResponse, Response}, 5 Json, 6}; 7use serde_json::json; 8 9use crate::state::AppState; 10use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token, validate_bearer_token_allow_deactivated}; 11 12pub struct BearerAuth(pub AuthenticatedUser); 13 14#[derive(Debug)] 15pub enum AuthError { 16 MissingToken, 17 InvalidFormat, 18 AuthenticationFailed, 19 AccountDeactivated, 20 AccountTakedown, 21} 22 23impl IntoResponse for AuthError { 24 fn into_response(self) -> Response { 25 let (status, error, message) = match self { 26 AuthError::MissingToken => ( 27 StatusCode::UNAUTHORIZED, 28 "AuthenticationRequired", 29 "Authorization header is required", 30 ), 31 AuthError::InvalidFormat => ( 32 StatusCode::UNAUTHORIZED, 33 "InvalidToken", 34 "Invalid authorization header format", 35 ), 36 AuthError::AuthenticationFailed => ( 37 StatusCode::UNAUTHORIZED, 38 "AuthenticationFailed", 39 "Invalid or expired token", 40 ), 41 AuthError::AccountDeactivated => ( 42 StatusCode::UNAUTHORIZED, 43 "AccountDeactivated", 44 "Account is deactivated", 45 ), 46 AuthError::AccountTakedown => ( 47 StatusCode::UNAUTHORIZED, 48 "AccountTakedown", 49 "Account has been taken down", 50 ), 51 }; 52 53 (status, Json(json!({ "error": error, "message": message }))).into_response() 54 } 55} 56 57fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 58 let auth_header = auth_header.trim(); 59 60 if auth_header.len() < 8 { 61 return Err(AuthError::InvalidFormat); 62 } 63 64 let prefix = &auth_header[..7]; 65 if !prefix.eq_ignore_ascii_case("bearer ") { 66 return Err(AuthError::InvalidFormat); 67 } 68 69 let token = auth_header[7..].trim(); 70 if token.is_empty() { 71 return Err(AuthError::InvalidFormat); 72 } 73 74 Ok(token) 75} 76 77pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 78 let header = auth_header?; 79 let 
header = header.trim(); 80 81 if header.len() < 7 { 82 return None; 83 } 84 85 if !header[..7].eq_ignore_ascii_case("bearer ") { 86 return None; 87 } 88 89 let token = header[7..].trim(); 90 if token.is_empty() { 91 return None; 92 } 93 94 Some(token.to_string()) 95} 96 97impl FromRequestParts<AppState> for BearerAuth { 98 type Rejection = AuthError; 99 100 async fn from_request_parts( 101 parts: &mut Parts, 102 state: &AppState, 103 ) -> Result<Self, Self::Rejection> { 104 let auth_header = parts 105 .headers 106 .get(AUTHORIZATION) 107 .ok_or(AuthError::MissingToken)? 108 .to_str() 109 .map_err(|_| AuthError::InvalidFormat)?; 110 111 let token = extract_bearer_token(auth_header)?; 112 113 match validate_bearer_token(&state.db, token).await { 114 Ok(user) => Ok(BearerAuth(user)), 115 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 116 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 117 Err(_) => Err(AuthError::AuthenticationFailed), 118 } 119 } 120} 121 122pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 123 124impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 125 type Rejection = AuthError; 126 127 async fn from_request_parts( 128 parts: &mut Parts, 129 state: &AppState, 130 ) -> Result<Self, Self::Rejection> { 131 let auth_header = parts 132 .headers 133 .get(AUTHORIZATION) 134 .ok_or(AuthError::MissingToken)? 
135 .to_str() 136 .map_err(|_| AuthError::InvalidFormat)?; 137 138 let token = extract_bearer_token(auth_header)?; 139 140 match validate_bearer_token_allow_deactivated(&state.db, token).await { 141 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 142 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 143 Err(_) => Err(AuthError::AuthenticationFailed), 144 } 145 } 146} 147 148#[cfg(test)] 149mod tests { 150 use super::*; 151 152 #[test] 153 fn test_extract_bearer_token() { 154 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 155 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 156 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 157 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 158 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 159 160 assert!(extract_bearer_token("Basic abc123").is_err()); 161 assert!(extract_bearer_token("Bearer").is_err()); 162 assert!(extract_bearer_token("Bearer ").is_err()); 163 assert!(extract_bearer_token("abc123").is_err()); 164 assert!(extract_bearer_token("").is_err()); 165 } 166}