This repository has no description.
1use axum::{ 2 extract::FromRequestParts, 3 http::{header::AUTHORIZATION, request::Parts}, 4 response::{IntoResponse, Response}, 5}; 6 7use super::{ 8 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 9 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop, 10}; 11use crate::api::error::ApiError; 12use crate::state::AppState; 13use crate::util::build_full_url; 14 15pub struct BearerAuth(pub AuthenticatedUser); 16 17#[derive(Debug)] 18pub enum AuthError { 19 MissingToken, 20 InvalidFormat, 21 AuthenticationFailed, 22 TokenExpired, 23 AccountDeactivated, 24 AccountTakedown, 25 AdminRequired, 26} 27 28impl IntoResponse for AuthError { 29 fn into_response(self) -> Response { 30 ApiError::from(self).into_response() 31 } 32} 33 34#[cfg(test)] 35fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 36 let auth_header = auth_header.trim(); 37 38 if auth_header.len() < 8 { 39 return Err(AuthError::InvalidFormat); 40 } 41 42 let prefix = &auth_header[..7]; 43 if !prefix.eq_ignore_ascii_case("bearer ") { 44 return Err(AuthError::InvalidFormat); 45 } 46 47 let token = auth_header[7..].trim(); 48 if token.is_empty() { 49 return Err(AuthError::InvalidFormat); 50 } 51 52 Ok(token) 53} 54 55pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 56 let header = auth_header?; 57 let header = header.trim(); 58 59 if header.len() < 7 { 60 return None; 61 } 62 63 if !header[..7].eq_ignore_ascii_case("bearer ") { 64 return None; 65 } 66 67 let token = header[7..].trim(); 68 if token.is_empty() { 69 return None; 70 } 71 72 Some(token.to_string()) 73} 74 75pub struct ExtractedToken { 76 pub token: String, 77 pub is_dpop: bool, 78} 79 80pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 81 let header = auth_header?; 82 let header = header.trim(); 83 84 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") { 85 let token = header[7..].trim(); 86 
if token.is_empty() { 87 return None; 88 } 89 return Some(ExtractedToken { 90 token: token.to_string(), 91 is_dpop: false, 92 }); 93 } 94 95 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { 96 let token = header[5..].trim(); 97 if token.is_empty() { 98 return None; 99 } 100 return Some(ExtractedToken { 101 token: token.to_string(), 102 is_dpop: true, 103 }); 104 } 105 106 None 107} 108 109impl FromRequestParts<AppState> for BearerAuth { 110 type Rejection = AuthError; 111 112 async fn from_request_parts( 113 parts: &mut Parts, 114 state: &AppState, 115 ) -> Result<Self, Self::Rejection> { 116 let auth_header = parts 117 .headers 118 .get(AUTHORIZATION) 119 .ok_or(AuthError::MissingToken)? 120 .to_str() 121 .map_err(|_| AuthError::InvalidFormat)?; 122 123 let extracted = 124 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 125 126 if extracted.is_dpop { 127 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 128 let method = parts.method.as_str(); 129 let uri = build_full_url(&parts.uri.to_string()); 130 131 match validate_token_with_dpop( 132 &state.db, 133 &extracted.token, 134 true, 135 dpop_proof, 136 method, 137 &uri, 138 false, 139 ) 140 .await 141 { 142 Ok(user) => Ok(BearerAuth(user)), 143 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 144 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 145 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 146 Err(_) => Err(AuthError::AuthenticationFailed), 147 } 148 } else { 149 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token) 150 .await 151 { 152 Ok(user) => Ok(BearerAuth(user)), 153 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 154 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 155 Err(TokenValidationError::TokenExpired) => 
Err(AuthError::TokenExpired), 156 Err(_) => Err(AuthError::AuthenticationFailed), 157 } 158 } 159 } 160} 161 162pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 163 164impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 165 type Rejection = AuthError; 166 167 async fn from_request_parts( 168 parts: &mut Parts, 169 state: &AppState, 170 ) -> Result<Self, Self::Rejection> { 171 let auth_header = parts 172 .headers 173 .get(AUTHORIZATION) 174 .ok_or(AuthError::MissingToken)? 175 .to_str() 176 .map_err(|_| AuthError::InvalidFormat)?; 177 178 let extracted = 179 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 180 181 if extracted.is_dpop { 182 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 183 let method = parts.method.as_str(); 184 let uri = build_full_url(&parts.uri.to_string()); 185 186 match validate_token_with_dpop( 187 &state.db, 188 &extracted.token, 189 true, 190 dpop_proof, 191 method, 192 &uri, 193 true, 194 ) 195 .await 196 { 197 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 198 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 199 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 200 Err(_) => Err(AuthError::AuthenticationFailed), 201 } 202 } else { 203 match validate_bearer_token_cached_allow_deactivated( 204 &state.db, 205 state.cache.as_ref(), 206 &extracted.token, 207 ) 208 .await 209 { 210 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 211 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 212 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 213 Err(_) => Err(AuthError::AuthenticationFailed), 214 } 215 } 216 } 217} 218 219pub struct BearerAuthAdmin(pub AuthenticatedUser); 220 221impl FromRequestParts<AppState> for BearerAuthAdmin { 222 type Rejection = AuthError; 223 224 async fn from_request_parts( 225 parts: &mut Parts, 226 state: &AppState, 227 ) -> 
Result<Self, Self::Rejection> { 228 let auth_header = parts 229 .headers 230 .get(AUTHORIZATION) 231 .ok_or(AuthError::MissingToken)? 232 .to_str() 233 .map_err(|_| AuthError::InvalidFormat)?; 234 235 let extracted = 236 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 237 238 let user = if extracted.is_dpop { 239 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 240 let method = parts.method.as_str(); 241 let uri = build_full_url(&parts.uri.to_string()); 242 243 match validate_token_with_dpop( 244 &state.db, 245 &extracted.token, 246 true, 247 dpop_proof, 248 method, 249 &uri, 250 false, 251 ) 252 .await 253 { 254 Ok(user) => user, 255 Err(TokenValidationError::AccountDeactivated) => { 256 return Err(AuthError::AccountDeactivated); 257 } 258 Err(TokenValidationError::AccountTakedown) => { 259 return Err(AuthError::AccountTakedown); 260 } 261 Err(TokenValidationError::TokenExpired) => { 262 return Err(AuthError::TokenExpired); 263 } 264 Err(_) => return Err(AuthError::AuthenticationFailed), 265 } 266 } else { 267 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token) 268 .await 269 { 270 Ok(user) => user, 271 Err(TokenValidationError::AccountDeactivated) => { 272 return Err(AuthError::AccountDeactivated); 273 } 274 Err(TokenValidationError::AccountTakedown) => { 275 return Err(AuthError::AccountTakedown); 276 } 277 Err(TokenValidationError::TokenExpired) => { 278 return Err(AuthError::TokenExpired); 279 } 280 Err(_) => return Err(AuthError::AuthenticationFailed), 281 } 282 }; 283 284 if !user.is_admin { 285 return Err(AuthError::AdminRequired); 286 } 287 Ok(BearerAuthAdmin(user)) 288 } 289} 290 291#[cfg(test)] 292mod tests { 293 use super::*; 294 295 #[test] 296 fn test_extract_bearer_token() { 297 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 298 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 299 
assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 300 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 301 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 302 303 assert!(extract_bearer_token("Basic abc123").is_err()); 304 assert!(extract_bearer_token("Bearer").is_err()); 305 assert!(extract_bearer_token("Bearer ").is_err()); 306 assert!(extract_bearer_token("abc123").is_err()); 307 assert!(extract_bearer_token("").is_err()); 308 } 309}