this repo has no description
1use axum::{ 2 Json, 3 extract::FromRequestParts, 4 http::{StatusCode, header::AUTHORIZATION, request::Parts}, 5 response::{IntoResponse, Response}, 6}; 7use serde_json::json; 8 9use super::{ 10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop, 12}; 13use crate::state::AppState; 14use crate::util::build_full_url; 15 16pub struct BearerAuth(pub AuthenticatedUser); 17 18#[derive(Debug)] 19pub enum AuthError { 20 MissingToken, 21 InvalidFormat, 22 AuthenticationFailed, 23 TokenExpired, 24 AccountDeactivated, 25 AccountTakedown, 26 AdminRequired, 27} 28 29impl IntoResponse for AuthError { 30 fn into_response(self) -> Response { 31 let (status, error, message) = match self { 32 AuthError::MissingToken => ( 33 StatusCode::UNAUTHORIZED, 34 "AuthenticationRequired", 35 "Authorization header is required", 36 ), 37 AuthError::InvalidFormat => ( 38 StatusCode::UNAUTHORIZED, 39 "InvalidToken", 40 "Invalid authorization header format", 41 ), 42 AuthError::AuthenticationFailed => ( 43 StatusCode::UNAUTHORIZED, 44 "InvalidToken", 45 "Token could not be verified", 46 ), 47 AuthError::TokenExpired => ( 48 StatusCode::UNAUTHORIZED, 49 "ExpiredToken", 50 "Token has expired", 51 ), 52 AuthError::AccountDeactivated => ( 53 StatusCode::UNAUTHORIZED, 54 "AccountDeactivated", 55 "Account is deactivated", 56 ), 57 AuthError::AccountTakedown => ( 58 StatusCode::UNAUTHORIZED, 59 "AccountTakedown", 60 "Account has been taken down", 61 ), 62 AuthError::AdminRequired => ( 63 StatusCode::FORBIDDEN, 64 "AdminRequired", 65 "This action requires admin privileges", 66 ), 67 }; 68 69 (status, Json(json!({ "error": error, "message": message }))).into_response() 70 } 71} 72 73#[cfg(test)] 74fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 75 let auth_header = auth_header.trim(); 76 77 if auth_header.len() < 8 { 78 return Err(AuthError::InvalidFormat); 79 } 80 81 let prefix = &auth_header[..7]; 82 if !prefix.eq_ignore_ascii_case("bearer ") { 83 return Err(AuthError::InvalidFormat); 84 } 85 86 let token = auth_header[7..].trim(); 87 if token.is_empty() { 88 return Err(AuthError::InvalidFormat); 89 } 90 91 Ok(token) 92} 93 94pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 95 let header = auth_header?; 96 let header = header.trim(); 97 98 if header.len() < 7 { 99 return None; 100 } 101 102 if !header[..7].eq_ignore_ascii_case("bearer ") { 103 return None; 104 } 105 106 let token = header[7..].trim(); 107 if token.is_empty() { 108 return None; 109 } 110 111 Some(token.to_string()) 112} 113 114pub struct ExtractedToken { 115 pub token: String, 116 pub is_dpop: bool, 117} 118 119pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 120 let header = auth_header?; 121 let header = header.trim(); 122 123 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") { 124 let token = header[7..].trim(); 125 if token.is_empty() { 126 return None; 127 } 128 return Some(ExtractedToken { 129 token: token.to_string(), 130 is_dpop: false, 131 }); 132 } 133 134 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { 135 let token = header[5..].trim(); 136 if token.is_empty() { 137 return None; 138 } 139 return Some(ExtractedToken { 140 token: token.to_string(), 141 is_dpop: true, 142 }); 143 } 144 145 None 146} 147 148impl FromRequestParts<AppState> for BearerAuth { 149 type Rejection = AuthError; 150 151 async fn from_request_parts( 152 parts: &mut Parts, 153 state: &AppState, 154 ) -> Result<Self, Self::Rejection> { 155 let auth_header = parts 156 .headers 157 .get(AUTHORIZATION) 158 .ok_or(AuthError::MissingToken)? 159 .to_str() 160 .map_err(|_| AuthError::InvalidFormat)?; 161 162 let extracted = 163 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 164 165 if extracted.is_dpop { 166 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 167 let method = parts.method.as_str(); 168 let uri = build_full_url(&parts.uri.to_string()); 169 170 match validate_token_with_dpop( 171 &state.db, 172 &extracted.token, 173 true, 174 dpop_proof, 175 method, 176 &uri, 177 false, 178 ) 179 .await 180 { 181 Ok(user) => Ok(BearerAuth(user)), 182 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 183 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 184 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 185 Err(_) => Err(AuthError::AuthenticationFailed), 186 } 187 } else { 188 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 189 Ok(user) => Ok(BearerAuth(user)), 190 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 191 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 192 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 193 Err(_) => Err(AuthError::AuthenticationFailed), 194 } 195 } 196 } 197} 198 199pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 200 201impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 202 type Rejection = AuthError; 203 204 async fn from_request_parts( 205 parts: &mut Parts, 206 state: &AppState, 207 ) -> Result<Self, Self::Rejection> { 208 let auth_header = parts 209 .headers 210 .get(AUTHORIZATION) 211 .ok_or(AuthError::MissingToken)? 212 .to_str() 213 .map_err(|_| AuthError::InvalidFormat)?; 214 215 let extracted = 216 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 217 218 if extracted.is_dpop { 219 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 220 let method = parts.method.as_str(); 221 let uri = build_full_url(&parts.uri.to_string()); 222 223 match validate_token_with_dpop( 224 &state.db, 225 &extracted.token, 226 true, 227 dpop_proof, 228 method, 229 &uri, 230 true, 231 ) 232 .await 233 { 234 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 235 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 236 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 237 Err(_) => Err(AuthError::AuthenticationFailed), 238 } 239 } else { 240 match validate_bearer_token_cached_allow_deactivated( 241 &state.db, 242 &state.cache, 243 &extracted.token, 244 ) 245 .await 246 { 247 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 248 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 249 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 250 Err(_) => Err(AuthError::AuthenticationFailed), 251 } 252 } 253 } 254} 255 256pub struct BearerAuthAdmin(pub AuthenticatedUser); 257 258impl FromRequestParts<AppState> for BearerAuthAdmin { 259 type Rejection = AuthError; 260 261 async fn from_request_parts( 262 parts: &mut Parts, 263 state: &AppState, 264 ) -> Result<Self, Self::Rejection> { 265 let auth_header = parts 266 .headers 267 .get(AUTHORIZATION) 268 .ok_or(AuthError::MissingToken)? 269 .to_str() 270 .map_err(|_| AuthError::InvalidFormat)?; 271 272 let extracted = 273 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 274 275 let user = if extracted.is_dpop { 276 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 277 let method = parts.method.as_str(); 278 let uri = build_full_url(&parts.uri.to_string()); 279 280 match validate_token_with_dpop( 281 &state.db, 282 &extracted.token, 283 true, 284 dpop_proof, 285 method, 286 &uri, 287 false, 288 ) 289 .await 290 { 291 Ok(user) => user, 292 Err(TokenValidationError::AccountDeactivated) => { 293 return Err(AuthError::AccountDeactivated); 294 } 295 Err(TokenValidationError::AccountTakedown) => { 296 return Err(AuthError::AccountTakedown); 297 } 298 Err(TokenValidationError::TokenExpired) => { 299 return Err(AuthError::TokenExpired); 300 } 301 Err(_) => return Err(AuthError::AuthenticationFailed), 302 } 303 } else { 304 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 305 Ok(user) => user, 306 Err(TokenValidationError::AccountDeactivated) => { 307 return Err(AuthError::AccountDeactivated); 308 } 309 Err(TokenValidationError::AccountTakedown) => { 310 return Err(AuthError::AccountTakedown); 311 } 312 Err(TokenValidationError::TokenExpired) => { 313 return Err(AuthError::TokenExpired); 314 } 315 Err(_) => return Err(AuthError::AuthenticationFailed), 316 } 317 }; 318 319 if !user.is_admin { 320 return Err(AuthError::AdminRequired); 321 } 322 Ok(BearerAuthAdmin(user)) 323 } 324} 325 326#[cfg(test)] 327mod tests { 328 use super::*; 329 330 #[test] 331 fn test_extract_bearer_token() { 332 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 333 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 334 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 335 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 336 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 337 338 assert!(extract_bearer_token("Basic abc123").is_err()); 339 assert!(extract_bearer_token("Bearer").is_err()); 340 assert!(extract_bearer_token("Bearer ").is_err()); 341 assert!(extract_bearer_token("abc123").is_err()); 342 assert!(extract_bearer_token("").is_err()); 343 } 344}