This repository has no description.
1use axum::{ 2 Json, 3 extract::FromRequestParts, 4 http::{StatusCode, header::AUTHORIZATION, request::Parts}, 5 response::{IntoResponse, Response}, 6}; 7use serde_json::json; 8 9use super::{ 10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 validate_bearer_token_cached_allow_deactivated, 12}; 13use crate::state::AppState; 14 15pub struct BearerAuth(pub AuthenticatedUser); 16 17#[derive(Debug)] 18pub enum AuthError { 19 MissingToken, 20 InvalidFormat, 21 AuthenticationFailed, 22 AccountDeactivated, 23 AccountTakedown, 24} 25 26impl IntoResponse for AuthError { 27 fn into_response(self) -> Response { 28 let (status, error, message) = match self { 29 AuthError::MissingToken => ( 30 StatusCode::UNAUTHORIZED, 31 "AuthenticationRequired", 32 "Authorization header is required", 33 ), 34 AuthError::InvalidFormat => ( 35 StatusCode::UNAUTHORIZED, 36 "InvalidToken", 37 "Invalid authorization header format", 38 ), 39 AuthError::AuthenticationFailed => ( 40 StatusCode::UNAUTHORIZED, 41 "AuthenticationFailed", 42 "Invalid or expired token", 43 ), 44 AuthError::AccountDeactivated => ( 45 StatusCode::UNAUTHORIZED, 46 "AccountDeactivated", 47 "Account is deactivated", 48 ), 49 AuthError::AccountTakedown => ( 50 StatusCode::UNAUTHORIZED, 51 "AccountTakedown", 52 "Account has been taken down", 53 ), 54 }; 55 56 (status, Json(json!({ "error": error, "message": message }))).into_response() 57 } 58} 59 60fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 61 let auth_header = auth_header.trim(); 62 63 if auth_header.len() < 8 { 64 return Err(AuthError::InvalidFormat); 65 } 66 67 let prefix = &auth_header[..7]; 68 if !prefix.eq_ignore_ascii_case("bearer ") { 69 return Err(AuthError::InvalidFormat); 70 } 71 72 let token = auth_header[7..].trim(); 73 if token.is_empty() { 74 return Err(AuthError::InvalidFormat); 75 } 76 77 Ok(token) 78} 79 80pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 81 let header 
= auth_header?; 82 let header = header.trim(); 83 84 if header.len() < 7 { 85 return None; 86 } 87 88 if !header[..7].eq_ignore_ascii_case("bearer ") { 89 return None; 90 } 91 92 let token = header[7..].trim(); 93 if token.is_empty() { 94 return None; 95 } 96 97 Some(token.to_string()) 98} 99 100pub struct ExtractedToken { 101 pub token: String, 102 pub is_dpop: bool, 103} 104 105pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 106 let header = auth_header?; 107 let header = header.trim(); 108 109 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") { 110 let token = header[7..].trim(); 111 if token.is_empty() { 112 return None; 113 } 114 return Some(ExtractedToken { 115 token: token.to_string(), 116 is_dpop: false, 117 }); 118 } 119 120 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { 121 let token = header[5..].trim(); 122 if token.is_empty() { 123 return None; 124 } 125 return Some(ExtractedToken { 126 token: token.to_string(), 127 is_dpop: true, 128 }); 129 } 130 131 None 132} 133 134impl FromRequestParts<AppState> for BearerAuth { 135 type Rejection = AuthError; 136 137 async fn from_request_parts( 138 parts: &mut Parts, 139 state: &AppState, 140 ) -> Result<Self, Self::Rejection> { 141 let auth_header = parts 142 .headers 143 .get(AUTHORIZATION) 144 .ok_or(AuthError::MissingToken)? 
145 .to_str() 146 .map_err(|_| AuthError::InvalidFormat)?; 147 148 let token = extract_bearer_token(auth_header)?; 149 150 match validate_bearer_token_cached(&state.db, &state.cache, token).await { 151 Ok(user) => Ok(BearerAuth(user)), 152 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 153 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 154 Err(_) => Err(AuthError::AuthenticationFailed), 155 } 156 } 157} 158 159pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 160 161impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 162 type Rejection = AuthError; 163 164 async fn from_request_parts( 165 parts: &mut Parts, 166 state: &AppState, 167 ) -> Result<Self, Self::Rejection> { 168 let auth_header = parts 169 .headers 170 .get(AUTHORIZATION) 171 .ok_or(AuthError::MissingToken)? 172 .to_str() 173 .map_err(|_| AuthError::InvalidFormat)?; 174 175 let token = extract_bearer_token(auth_header)?; 176 177 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 178 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 179 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 180 Err(_) => Err(AuthError::AuthenticationFailed), 181 } 182 } 183} 184 185#[cfg(test)] 186mod tests { 187 use super::*; 188 189 #[test] 190 fn test_extract_bearer_token() { 191 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 192 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 193 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 194 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 195 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 196 197 assert!(extract_bearer_token("Basic abc123").is_err()); 198 assert!(extract_bearer_token("Bearer").is_err()); 199 assert!(extract_bearer_token("Bearer ").is_err()); 200 
assert!(extract_bearer_token("abc123").is_err()); 201 assert!(extract_bearer_token("").is_err()); 202 } 203}