this repo has no description
1use axum::{ 2 Json, 3 extract::FromRequestParts, 4 http::{StatusCode, header::AUTHORIZATION, request::Parts}, 5 response::{IntoResponse, Response}, 6}; 7use serde_json::json; 8 9use super::{ 10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop, 12}; 13use crate::state::AppState; 14 15pub struct BearerAuth(pub AuthenticatedUser); 16 17#[derive(Debug)] 18pub enum AuthError { 19 MissingToken, 20 InvalidFormat, 21 AuthenticationFailed, 22 AccountDeactivated, 23 AccountTakedown, 24 AdminRequired, 25} 26 27impl IntoResponse for AuthError { 28 fn into_response(self) -> Response { 29 let (status, error, message) = match self { 30 AuthError::MissingToken => ( 31 StatusCode::UNAUTHORIZED, 32 "AuthenticationRequired", 33 "Authorization header is required", 34 ), 35 AuthError::InvalidFormat => ( 36 StatusCode::UNAUTHORIZED, 37 "InvalidToken", 38 "Invalid authorization header format", 39 ), 40 AuthError::AuthenticationFailed => ( 41 StatusCode::UNAUTHORIZED, 42 "AuthenticationFailed", 43 "Invalid or expired token", 44 ), 45 AuthError::AccountDeactivated => ( 46 StatusCode::UNAUTHORIZED, 47 "AccountDeactivated", 48 "Account is deactivated", 49 ), 50 AuthError::AccountTakedown => ( 51 StatusCode::UNAUTHORIZED, 52 "AccountTakedown", 53 "Account has been taken down", 54 ), 55 AuthError::AdminRequired => ( 56 StatusCode::FORBIDDEN, 57 "AdminRequired", 58 "This action requires admin privileges", 59 ), 60 }; 61 62 (status, Json(json!({ "error": error, "message": message }))).into_response() 63 } 64} 65 66#[cfg(test)] 67fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 68 let auth_header = auth_header.trim(); 69 70 if auth_header.len() < 8 { 71 return Err(AuthError::InvalidFormat); 72 } 73 74 let prefix = &auth_header[..7]; 75 if !prefix.eq_ignore_ascii_case("bearer ") { 76 return Err(AuthError::InvalidFormat); 77 } 78 79 let token = auth_header[7..].trim(); 80 if token.is_empty() { 81 return Err(AuthError::InvalidFormat); 82 } 83 84 Ok(token) 85} 86 87pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 88 let header = auth_header?; 89 let header = header.trim(); 90 91 if header.len() < 7 { 92 return None; 93 } 94 95 if !header[..7].eq_ignore_ascii_case("bearer ") { 96 return None; 97 } 98 99 let token = header[7..].trim(); 100 if token.is_empty() { 101 return None; 102 } 103 104 Some(token.to_string()) 105} 106 107pub struct ExtractedToken { 108 pub token: String, 109 pub is_dpop: bool, 110} 111 112pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 113 let header = auth_header?; 114 let header = header.trim(); 115 116 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") { 117 let token = header[7..].trim(); 118 if token.is_empty() { 119 return None; 120 } 121 return Some(ExtractedToken { 122 token: token.to_string(), 123 is_dpop: false, 124 }); 125 } 126 127 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { 128 let token = header[5..].trim(); 129 if token.is_empty() { 130 return None; 131 } 132 return Some(ExtractedToken { 133 token: token.to_string(), 134 is_dpop: true, 135 }); 136 } 137 138 None 139} 140 141impl FromRequestParts<AppState> for BearerAuth { 142 type Rejection = AuthError; 143 144 async fn from_request_parts( 145 parts: &mut Parts, 146 state: &AppState, 147 ) -> Result<Self, Self::Rejection> { 148 let auth_header = parts 149 .headers 150 .get(AUTHORIZATION) 151 .ok_or(AuthError::MissingToken)? 152 .to_str() 153 .map_err(|_| AuthError::InvalidFormat)?; 154 155 let extracted = 156 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 157 158 if extracted.is_dpop { 159 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 160 let method = parts.method.as_str(); 161 let uri = parts.uri.to_string(); 162 163 match validate_token_with_dpop( 164 &state.db, 165 &extracted.token, 166 true, 167 dpop_proof, 168 method, 169 &uri, 170 false, 171 ) 172 .await 173 { 174 Ok(user) => Ok(BearerAuth(user)), 175 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 176 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 177 Err(_) => Err(AuthError::AuthenticationFailed), 178 } 179 } else { 180 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 181 Ok(user) => Ok(BearerAuth(user)), 182 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 183 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 184 Err(_) => Err(AuthError::AuthenticationFailed), 185 } 186 } 187 } 188} 189 190pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 191 192impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 193 type Rejection = AuthError; 194 195 async fn from_request_parts( 196 parts: &mut Parts, 197 state: &AppState, 198 ) -> Result<Self, Self::Rejection> { 199 let auth_header = parts 200 .headers 201 .get(AUTHORIZATION) 202 .ok_or(AuthError::MissingToken)? 203 .to_str() 204 .map_err(|_| AuthError::InvalidFormat)?; 205 206 let extracted = 207 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 208 209 if extracted.is_dpop { 210 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 211 let method = parts.method.as_str(); 212 let uri = parts.uri.to_string(); 213 214 match validate_token_with_dpop( 215 &state.db, 216 &extracted.token, 217 true, 218 dpop_proof, 219 method, 220 &uri, 221 true, 222 ) 223 .await 224 { 225 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 226 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 227 Err(_) => Err(AuthError::AuthenticationFailed), 228 } 229 } else { 230 match validate_bearer_token_cached_allow_deactivated( 231 &state.db, 232 &state.cache, 233 &extracted.token, 234 ) 235 .await 236 { 237 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 238 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 239 Err(_) => Err(AuthError::AuthenticationFailed), 240 } 241 } 242 } 243} 244 245pub struct BearerAuthAdmin(pub AuthenticatedUser); 246 247impl FromRequestParts<AppState> for BearerAuthAdmin { 248 type Rejection = AuthError; 249 250 async fn from_request_parts( 251 parts: &mut Parts, 252 state: &AppState, 253 ) -> Result<Self, Self::Rejection> { 254 let auth_header = parts 255 .headers 256 .get(AUTHORIZATION) 257 .ok_or(AuthError::MissingToken)? 258 .to_str() 259 .map_err(|_| AuthError::InvalidFormat)?; 260 261 let extracted = 262 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 263 264 let user = if extracted.is_dpop { 265 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 266 let method = parts.method.as_str(); 267 let uri = parts.uri.to_string(); 268 269 match validate_token_with_dpop( 270 &state.db, 271 &extracted.token, 272 true, 273 dpop_proof, 274 method, 275 &uri, 276 false, 277 ) 278 .await 279 { 280 Ok(user) => user, 281 Err(TokenValidationError::AccountDeactivated) => { 282 return Err(AuthError::AccountDeactivated); 283 } 284 Err(TokenValidationError::AccountTakedown) => { 285 return Err(AuthError::AccountTakedown); 286 } 287 Err(_) => return Err(AuthError::AuthenticationFailed), 288 } 289 } else { 290 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await { 291 Ok(user) => user, 292 Err(TokenValidationError::AccountDeactivated) => { 293 return Err(AuthError::AccountDeactivated); 294 } 295 Err(TokenValidationError::AccountTakedown) => { 296 return Err(AuthError::AccountTakedown); 297 } 298 Err(_) => return Err(AuthError::AuthenticationFailed), 299 } 300 }; 301 302 if !user.is_admin { 303 return Err(AuthError::AdminRequired); 304 } 305 Ok(BearerAuthAdmin(user)) 306 } 307} 308 309#[cfg(test)] 310mod tests { 311 use super::*; 312 313 #[test] 314 fn test_extract_bearer_token() { 315 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 316 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 317 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 318 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 319 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 320 321 assert!(extract_bearer_token("Basic abc123").is_err()); 322 assert!(extract_bearer_token("Bearer").is_err()); 323 assert!(extract_bearer_token("Bearer ").is_err()); 324 assert!(extract_bearer_token("abc123").is_err()); 325 assert!(extract_bearer_token("").is_err()); 326 } 327}