this repo has no description
1use axum::{ 2 extract::FromRequestParts, 3 http::{header::AUTHORIZATION, request::Parts}, 4 response::{IntoResponse, Response}, 5}; 6 7use super::{ 8 AuthenticatedUser, TokenValidationError, validate_bearer_token_allow_takendown, 9 validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated, 10 validate_token_with_dpop, 11}; 12use crate::api::error::ApiError; 13use crate::state::AppState; 14use crate::util::build_full_url; 15 16pub struct BearerAuth(pub AuthenticatedUser); 17 18#[derive(Debug)] 19pub enum AuthError { 20 MissingToken, 21 InvalidFormat, 22 AuthenticationFailed, 23 TokenExpired, 24 AccountDeactivated, 25 AccountTakedown, 26 AdminRequired, 27} 28 29impl IntoResponse for AuthError { 30 fn into_response(self) -> Response { 31 ApiError::from(self).into_response() 32 } 33} 34 35#[cfg(test)] 36fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 37 let auth_header = auth_header.trim(); 38 39 if auth_header.len() < 8 { 40 return Err(AuthError::InvalidFormat); 41 } 42 43 let prefix = &auth_header[..7]; 44 if !prefix.eq_ignore_ascii_case("bearer ") { 45 return Err(AuthError::InvalidFormat); 46 } 47 48 let token = auth_header[7..].trim(); 49 if token.is_empty() { 50 return Err(AuthError::InvalidFormat); 51 } 52 53 Ok(token) 54} 55 56pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 57 let header = auth_header?; 58 let header = header.trim(); 59 60 if header.len() < 7 { 61 return None; 62 } 63 64 if !header[..7].eq_ignore_ascii_case("bearer ") { 65 return None; 66 } 67 68 let token = header[7..].trim(); 69 if token.is_empty() { 70 return None; 71 } 72 73 Some(token.to_string()) 74} 75 76pub struct ExtractedToken { 77 pub token: String, 78 pub is_dpop: bool, 79} 80 81pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 82 let header = auth_header?; 83 let header = header.trim(); 84 85 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") { 86 let token = header[7..].trim(); 87 if token.is_empty() { 88 return None; 89 } 90 return Some(ExtractedToken { 91 token: token.to_string(), 92 is_dpop: false, 93 }); 94 } 95 96 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { 97 let token = header[5..].trim(); 98 if token.is_empty() { 99 return None; 100 } 101 return Some(ExtractedToken { 102 token: token.to_string(), 103 is_dpop: true, 104 }); 105 } 106 107 None 108} 109 110impl FromRequestParts<AppState> for BearerAuth { 111 type Rejection = AuthError; 112 113 async fn from_request_parts( 114 parts: &mut Parts, 115 state: &AppState, 116 ) -> Result<Self, Self::Rejection> { 117 let auth_header = parts 118 .headers 119 .get(AUTHORIZATION) 120 .ok_or(AuthError::MissingToken)? 121 .to_str() 122 .map_err(|_| AuthError::InvalidFormat)?; 123 124 let extracted = 125 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 126 127 if extracted.is_dpop { 128 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 129 let method = parts.method.as_str(); 130 let uri = build_full_url(&parts.uri.to_string()); 131 132 match validate_token_with_dpop( 133 &state.db, 134 &extracted.token, 135 true, 136 dpop_proof, 137 method, 138 &uri, 139 false, 140 false, 141 ) 142 .await 143 { 144 Ok(user) => Ok(BearerAuth(user)), 145 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 146 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 147 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 148 Err(_) => Err(AuthError::AuthenticationFailed), 149 } 150 } else { 151 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token) 152 .await 153 { 154 Ok(user) => Ok(BearerAuth(user)), 155 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 156 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 157 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 158 Err(_) => Err(AuthError::AuthenticationFailed), 159 } 160 } 161 } 162} 163 164pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 165 166impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 167 type Rejection = AuthError; 168 169 async fn from_request_parts( 170 parts: &mut Parts, 171 state: &AppState, 172 ) -> Result<Self, Self::Rejection> { 173 let auth_header = parts 174 .headers 175 .get(AUTHORIZATION) 176 .ok_or(AuthError::MissingToken)? 177 .to_str() 178 .map_err(|_| AuthError::InvalidFormat)?; 179 180 let extracted = 181 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 182 183 if extracted.is_dpop { 184 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 185 let method = parts.method.as_str(); 186 let uri = build_full_url(&parts.uri.to_string()); 187 188 match validate_token_with_dpop( 189 &state.db, 190 &extracted.token, 191 true, 192 dpop_proof, 193 method, 194 &uri, 195 true, 196 false, 197 ) 198 .await 199 { 200 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 201 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 202 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 203 Err(_) => Err(AuthError::AuthenticationFailed), 204 } 205 } else { 206 match validate_bearer_token_cached_allow_deactivated( 207 &state.db, 208 state.cache.as_ref(), 209 &extracted.token, 210 ) 211 .await 212 { 213 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 214 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 215 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 216 Err(_) => Err(AuthError::AuthenticationFailed), 217 } 218 } 219 } 220} 221 222pub struct BearerAuthAllowTakendown(pub AuthenticatedUser); 223 224impl FromRequestParts<AppState> for BearerAuthAllowTakendown { 225 type Rejection = AuthError; 226 227 async fn from_request_parts( 228 parts: &mut Parts, 229 state: &AppState, 230 ) -> Result<Self, Self::Rejection> { 231 let auth_header = parts 232 .headers 233 .get(AUTHORIZATION) 234 .ok_or(AuthError::MissingToken)? 235 .to_str() 236 .map_err(|_| AuthError::InvalidFormat)?; 237 238 let extracted = 239 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 240 241 if extracted.is_dpop { 242 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 243 let method = parts.method.as_str(); 244 let uri = build_full_url(&parts.uri.to_string()); 245 246 match validate_token_with_dpop( 247 &state.db, 248 &extracted.token, 249 true, 250 dpop_proof, 251 method, 252 &uri, 253 false, 254 true, 255 ) 256 .await 257 { 258 Ok(user) => Ok(BearerAuthAllowTakendown(user)), 259 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 260 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 261 Err(_) => Err(AuthError::AuthenticationFailed), 262 } 263 } else { 264 match validate_bearer_token_allow_takendown(&state.db, &extracted.token).await { 265 Ok(user) => Ok(BearerAuthAllowTakendown(user)), 266 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 267 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 268 Err(_) => Err(AuthError::AuthenticationFailed), 269 } 270 } 271 } 272} 273 274pub struct BearerAuthAdmin(pub AuthenticatedUser); 275 276impl FromRequestParts<AppState> for BearerAuthAdmin { 277 type Rejection = AuthError; 278 279 async fn from_request_parts( 280 parts: &mut Parts, 281 state: &AppState, 282 ) -> Result<Self, Self::Rejection> { 283 let auth_header = parts 284 .headers 285 .get(AUTHORIZATION) 286 .ok_or(AuthError::MissingToken)? 287 .to_str() 288 .map_err(|_| AuthError::InvalidFormat)?; 289 290 let extracted = 291 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 292 293 let user = if extracted.is_dpop { 294 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 295 let method = parts.method.as_str(); 296 let uri = build_full_url(&parts.uri.to_string()); 297 298 match validate_token_with_dpop( 299 &state.db, 300 &extracted.token, 301 true, 302 dpop_proof, 303 method, 304 &uri, 305 false, 306 false, 307 ) 308 .await 309 { 310 Ok(user) => user, 311 Err(TokenValidationError::AccountDeactivated) => { 312 return Err(AuthError::AccountDeactivated); 313 } 314 Err(TokenValidationError::AccountTakedown) => { 315 return Err(AuthError::AccountTakedown); 316 } 317 Err(TokenValidationError::TokenExpired) => { 318 return Err(AuthError::TokenExpired); 319 } 320 Err(_) => return Err(AuthError::AuthenticationFailed), 321 } 322 } else { 323 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token) 324 .await 325 { 326 Ok(user) => user, 327 Err(TokenValidationError::AccountDeactivated) => { 328 return Err(AuthError::AccountDeactivated); 329 } 330 Err(TokenValidationError::AccountTakedown) => { 331 return Err(AuthError::AccountTakedown); 332 } 333 Err(TokenValidationError::TokenExpired) => { 334 return Err(AuthError::TokenExpired); 335 } 336 Err(_) => return Err(AuthError::AuthenticationFailed), 337 } 338 }; 339 340 if !user.is_admin { 341 return Err(AuthError::AdminRequired); 342 } 343 Ok(BearerAuthAdmin(user)) 344 } 345} 346 347#[cfg(test)] 348mod tests { 349 use super::*; 350 351 #[test] 352 fn test_extract_bearer_token() { 353 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 354 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 355 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 356 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 357 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 358 359 assert!(extract_bearer_token("Basic abc123").is_err()); 360 assert!(extract_bearer_token("Bearer").is_err()); 361 assert!(extract_bearer_token("Bearer ").is_err()); 362 assert!(extract_bearer_token("abc123").is_err()); 363 assert!(extract_bearer_token("").is_err()); 364 } 365}