this repo has no description
1use axum::{ 2 Json, 3 extract::FromRequestParts, 4 http::{StatusCode, header::AUTHORIZATION, request::Parts}, 5 response::{IntoResponse, Response}, 6}; 7use serde_json::json; 8 9use super::{ 10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 validate_bearer_token_cached_allow_deactivated, 12}; 13use crate::state::AppState; 14 15pub struct BearerAuth(pub AuthenticatedUser); 16 17#[derive(Debug)] 18pub enum AuthError { 19 MissingToken, 20 InvalidFormat, 21 AuthenticationFailed, 22 AccountDeactivated, 23 AccountTakedown, 24 AdminRequired, 25} 26 27impl IntoResponse for AuthError { 28 fn into_response(self) -> Response { 29 let (status, error, message) = match self { 30 AuthError::MissingToken => ( 31 StatusCode::UNAUTHORIZED, 32 "AuthenticationRequired", 33 "Authorization header is required", 34 ), 35 AuthError::InvalidFormat => ( 36 StatusCode::UNAUTHORIZED, 37 "InvalidToken", 38 "Invalid authorization header format", 39 ), 40 AuthError::AuthenticationFailed => ( 41 StatusCode::UNAUTHORIZED, 42 "AuthenticationFailed", 43 "Invalid or expired token", 44 ), 45 AuthError::AccountDeactivated => ( 46 StatusCode::UNAUTHORIZED, 47 "AccountDeactivated", 48 "Account is deactivated", 49 ), 50 AuthError::AccountTakedown => ( 51 StatusCode::UNAUTHORIZED, 52 "AccountTakedown", 53 "Account has been taken down", 54 ), 55 AuthError::AdminRequired => ( 56 StatusCode::FORBIDDEN, 57 "AdminRequired", 58 "This action requires admin privileges", 59 ), 60 }; 61 62 (status, Json(json!({ "error": error, "message": message }))).into_response() 63 } 64} 65 66fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 67 let auth_header = auth_header.trim(); 68 69 if auth_header.len() < 8 { 70 return Err(AuthError::InvalidFormat); 71 } 72 73 let prefix = &auth_header[..7]; 74 if !prefix.eq_ignore_ascii_case("bearer ") { 75 return Err(AuthError::InvalidFormat); 76 } 77 78 let token = auth_header[7..].trim(); 79 if token.is_empty() { 80 return Err(AuthError::InvalidFormat); 81 } 82 83 Ok(token) 84} 85 86pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 87 let header = auth_header?; 88 let header = header.trim(); 89 90 if header.len() < 7 { 91 return None; 92 } 93 94 if !header[..7].eq_ignore_ascii_case("bearer ") { 95 return None; 96 } 97 98 let token = header[7..].trim(); 99 if token.is_empty() { 100 return None; 101 } 102 103 Some(token.to_string()) 104} 105 106pub struct ExtractedToken { 107 pub token: String, 108 pub is_dpop: bool, 109} 110 111pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 112 let header = auth_header?; 113 let header = header.trim(); 114 115 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") { 116 let token = header[7..].trim(); 117 if token.is_empty() { 118 return None; 119 } 120 return Some(ExtractedToken { 121 token: token.to_string(), 122 is_dpop: false, 123 }); 124 } 125 126 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { 127 let token = header[5..].trim(); 128 if token.is_empty() { 129 return None; 130 } 131 return Some(ExtractedToken { 132 token: token.to_string(), 133 is_dpop: true, 134 }); 135 } 136 137 None 138} 139 140impl FromRequestParts<AppState> for BearerAuth { 141 type Rejection = AuthError; 142 143 async fn from_request_parts( 144 parts: &mut Parts, 145 state: &AppState, 146 ) -> Result<Self, Self::Rejection> { 147 let auth_header = parts 148 .headers 149 .get(AUTHORIZATION) 150 .ok_or(AuthError::MissingToken)? 151 .to_str() 152 .map_err(|_| AuthError::InvalidFormat)?; 153 154 let token = extract_bearer_token(auth_header)?; 155 156 match validate_bearer_token_cached(&state.db, &state.cache, token).await { 157 Ok(user) => Ok(BearerAuth(user)), 158 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 159 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 160 Err(_) => Err(AuthError::AuthenticationFailed), 161 } 162 } 163} 164 165pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 166 167impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 168 type Rejection = AuthError; 169 170 async fn from_request_parts( 171 parts: &mut Parts, 172 state: &AppState, 173 ) -> Result<Self, Self::Rejection> { 174 let auth_header = parts 175 .headers 176 .get(AUTHORIZATION) 177 .ok_or(AuthError::MissingToken)? 178 .to_str() 179 .map_err(|_| AuthError::InvalidFormat)?; 180 181 let token = extract_bearer_token(auth_header)?; 182 183 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 184 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 185 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 186 Err(_) => Err(AuthError::AuthenticationFailed), 187 } 188 } 189} 190 191pub struct BearerAuthAdmin(pub AuthenticatedUser); 192 193impl FromRequestParts<AppState> for BearerAuthAdmin { 194 type Rejection = AuthError; 195 196 async fn from_request_parts( 197 parts: &mut Parts, 198 state: &AppState, 199 ) -> Result<Self, Self::Rejection> { 200 let auth_header = parts 201 .headers 202 .get(AUTHORIZATION) 203 .ok_or(AuthError::MissingToken)? 204 .to_str() 205 .map_err(|_| AuthError::InvalidFormat)?; 206 207 let token = extract_bearer_token(auth_header)?; 208 209 match validate_bearer_token_cached(&state.db, &state.cache, token).await { 210 Ok(user) => { 211 if !user.is_admin { 212 return Err(AuthError::AdminRequired); 213 } 214 Ok(BearerAuthAdmin(user)) 215 } 216 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 217 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 218 Err(_) => Err(AuthError::AuthenticationFailed), 219 } 220 } 221} 222 223#[cfg(test)] 224mod tests { 225 use super::*; 226 227 #[test] 228 fn test_extract_bearer_token() { 229 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 230 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 231 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 232 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 233 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 234 235 assert!(extract_bearer_token("Basic abc123").is_err()); 236 assert!(extract_bearer_token("Bearer").is_err()); 237 assert!(extract_bearer_token("Bearer ").is_err()); 238 assert!(extract_bearer_token("abc123").is_err()); 239 assert!(extract_bearer_token("").is_err()); 240 } 241}