this repo has no description
1use axum::{ 2 Json, 3 extract::FromRequestParts, 4 http::{StatusCode, header::AUTHORIZATION, request::Parts}, 5 response::{IntoResponse, Response}, 6}; 7use serde_json::json; 8 9use super::{ 10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop, 12}; 13use crate::state::AppState; 14 15pub struct BearerAuth(pub AuthenticatedUser); 16 17#[derive(Debug)] 18pub enum AuthError { 19 MissingToken, 20 InvalidFormat, 21 AuthenticationFailed, 22 TokenExpired, 23 AccountDeactivated, 24 AccountTakedown, 25 AdminRequired, 26} 27 28impl IntoResponse for AuthError { 29 fn into_response(self) -> Response { 30 let (status, error, message) = match self { 31 AuthError::MissingToken => ( 32 StatusCode::UNAUTHORIZED, 33 "AuthenticationRequired", 34 "Authorization header is required", 35 ), 36 AuthError::InvalidFormat => ( 37 StatusCode::UNAUTHORIZED, 38 "InvalidToken", 39 "Invalid authorization header format", 40 ), 41 AuthError::AuthenticationFailed => ( 42 StatusCode::UNAUTHORIZED, 43 "InvalidToken", 44 "Token could not be verified", 45 ), 46 AuthError::TokenExpired => ( 47 StatusCode::UNAUTHORIZED, 48 "ExpiredToken", 49 "Token has expired", 50 ), 51 AuthError::AccountDeactivated => ( 52 StatusCode::UNAUTHORIZED, 53 "AccountDeactivated", 54 "Account is deactivated", 55 ), 56 AuthError::AccountTakedown => ( 57 StatusCode::UNAUTHORIZED, 58 "AccountTakedown", 59 "Account has been taken down", 60 ), 61 AuthError::AdminRequired => ( 62 StatusCode::FORBIDDEN, 63 "AdminRequired", 64 "This action requires admin privileges", 65 ), 66 }; 67 68 (status, Json(json!({ "error": error, "message": message }))).into_response() 69 } 70} 71 72#[cfg(test)] 73fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 74 let auth_header = auth_header.trim(); 75 76 if auth_header.len() < 8 { 77 return Err(AuthError::InvalidFormat); 78 } 79 80 let prefix = &auth_header[..7]; 81 if !prefix.eq_ignore_ascii_case("bearer ") { 82 return Err(AuthError::InvalidFormat); 83 } 84 85 let token = auth_header[7..].trim(); 86 if token.is_empty() { 87 return Err(AuthError::InvalidFormat); 88 } 89 90 Ok(token) 91} 92 93pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 94 let header = auth_header?; 95 let header = header.trim(); 96 97 if header.len() < 7 { 98 return None; 99 } 100 101 if !header[..7].eq_ignore_ascii_case("bearer ") { 102 return None; 103 } 104 105 let token = header[7..].trim(); 106 if token.is_empty() { 107 return None; 108 } 109 110 Some(token.to_string()) 111} 112 113pub struct ExtractedToken { 114 pub token: String, 115 pub is_dpop: bool, 116} 117 118pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 119 let header = auth_header?; 120 let header = header.trim(); 121 122 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") { 123 let token = header[7..].trim(); 124 if token.is_empty() { 125 return None; 126 } 127 return Some(ExtractedToken { 128 token: token.to_string(), 129 is_dpop: false, 130 }); 131 } 132 133 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { 134 let token = header[5..].trim(); 135 if token.is_empty() { 136 return None; 137 } 138 return Some(ExtractedToken { 139 token: token.to_string(), 140 is_dpop: true, 141 }); 142 } 143 144 None 145} 146 147impl FromRequestParts<AppState> for BearerAuth { 148 type Rejection = AuthError; 149 150 async fn from_request_parts( 151 parts: &mut Parts, 152 state: &AppState, 153 ) -> Result<Self, Self::Rejection> { 154 let auth_header = parts 155 .headers 156 .get(AUTHORIZATION) 157 .ok_or(AuthError::MissingToken)? 158 .to_str() 159 .map_err(|_| AuthError::InvalidFormat)?; 160 161 let extracted = 162 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 163 164 if extracted.is_dpop { 165 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 166 let method = parts.method.as_str(); 167 let uri = parts.uri.to_string(); 168 169 match validate_token_with_dpop( 170 &state.db, 171 &extracted.token, 172 true, 173 dpop_proof, 174 method, 175 &uri, 176 false, 177 ) 178 .await 179 { 180 Ok(user) => Ok(BearerAuth(user)), 181 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 182 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 183 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 184 Err(_) => Err(AuthError::AuthenticationFailed), 185 } 186 } else { 187 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 188 Ok(user) => Ok(BearerAuth(user)), 189 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 190 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 191 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 192 Err(_) => Err(AuthError::AuthenticationFailed), 193 } 194 } 195 } 196} 197 198pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 199 200impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 201 type Rejection = AuthError; 202 203 async fn from_request_parts( 204 parts: &mut Parts, 205 state: &AppState, 206 ) -> Result<Self, Self::Rejection> { 207 let auth_header = parts 208 .headers 209 .get(AUTHORIZATION) 210 .ok_or(AuthError::MissingToken)? 211 .to_str() 212 .map_err(|_| AuthError::InvalidFormat)?; 213 214 let extracted = 215 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 216 217 if extracted.is_dpop { 218 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 219 let method = parts.method.as_str(); 220 let uri = parts.uri.to_string(); 221 222 match validate_token_with_dpop( 223 &state.db, 224 &extracted.token, 225 true, 226 dpop_proof, 227 method, 228 &uri, 229 true, 230 ) 231 .await 232 { 233 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 234 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 235 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 236 Err(_) => Err(AuthError::AuthenticationFailed), 237 } 238 } else { 239 match validate_bearer_token_cached_allow_deactivated( 240 &state.db, 241 &state.cache, 242 &extracted.token, 243 ) 244 .await 245 { 246 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 247 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 248 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 249 Err(_) => Err(AuthError::AuthenticationFailed), 250 } 251 } 252 } 253} 254 255pub struct BearerAuthAdmin(pub AuthenticatedUser); 256 257impl FromRequestParts<AppState> for BearerAuthAdmin { 258 type Rejection = AuthError; 259 260 async fn from_request_parts( 261 parts: &mut Parts, 262 state: &AppState, 263 ) -> Result<Self, Self::Rejection> { 264 let auth_header = parts 265 .headers 266 .get(AUTHORIZATION) 267 .ok_or(AuthError::MissingToken)? 268 .to_str() 269 .map_err(|_| AuthError::InvalidFormat)?; 270 271 let extracted = 272 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 273 274 let user = if extracted.is_dpop { 275 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 276 let method = parts.method.as_str(); 277 let uri = parts.uri.to_string(); 278 279 match validate_token_with_dpop( 280 &state.db, 281 &extracted.token, 282 true, 283 dpop_proof, 284 method, 285 &uri, 286 false, 287 ) 288 .await 289 { 290 Ok(user) => user, 291 Err(TokenValidationError::AccountDeactivated) => { 292 return Err(AuthError::AccountDeactivated); 293 } 294 Err(TokenValidationError::AccountTakedown) => { 295 return Err(AuthError::AccountTakedown); 296 } 297 Err(TokenValidationError::TokenExpired) => { 298 return Err(AuthError::TokenExpired); 299 } 300 Err(_) => return Err(AuthError::AuthenticationFailed), 301 } 302 } else { 303 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 304 Ok(user) => user, 305 Err(TokenValidationError::AccountDeactivated) => { 306 return Err(AuthError::AccountDeactivated); 307 } 308 Err(TokenValidationError::AccountTakedown) => { 309 return Err(AuthError::AccountTakedown); 310 } 311 Err(TokenValidationError::TokenExpired) => { 312 return Err(AuthError::TokenExpired); 313 } 314 Err(_) => return Err(AuthError::AuthenticationFailed), 315 } 316 }; 317 318 if !user.is_admin { 319 return Err(AuthError::AdminRequired); 320 } 321 Ok(BearerAuthAdmin(user)) 322 } 323} 324 325#[cfg(test)] 326mod tests { 327 use super::*; 328 329 #[test] 330 fn test_extract_bearer_token() { 331 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 332 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 333 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 334 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 335 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 336 337 assert!(extract_bearer_token("Basic abc123").is_err()); 338 assert!(extract_bearer_token("Bearer").is_err()); 339 assert!(extract_bearer_token("Bearer ").is_err()); 340 assert!(extract_bearer_token("abc123").is_err()); 341 assert!(extract_bearer_token("").is_err()); 342 } 343}