// this repo has no description
1use axum::{ 2 extract::FromRequestParts, 3 http::{header::AUTHORIZATION, request::Parts}, 4 response::{IntoResponse, Response}, 5}; 6 7use super::{ 8 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 9 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop, 10}; 11use crate::api::error::ApiError; 12use crate::state::AppState; 13use crate::util::build_full_url; 14 15pub struct BearerAuth(pub AuthenticatedUser); 16 17#[derive(Debug)] 18pub enum AuthError { 19 MissingToken, 20 InvalidFormat, 21 AuthenticationFailed, 22 TokenExpired, 23 AccountDeactivated, 24 AccountTakedown, 25 AdminRequired, 26} 27 28impl IntoResponse for AuthError { 29 fn into_response(self) -> Response { 30 ApiError::from(self).into_response() 31 } 32} 33 34#[cfg(test)] 35fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 36 let auth_header = auth_header.trim(); 37 38 if auth_header.len() < 8 { 39 return Err(AuthError::InvalidFormat); 40 } 41 42 let prefix = &auth_header[..7]; 43 if !prefix.eq_ignore_ascii_case("bearer ") { 44 return Err(AuthError::InvalidFormat); 45 } 46 47 let token = auth_header[7..].trim(); 48 if token.is_empty() { 49 return Err(AuthError::InvalidFormat); 50 } 51 52 Ok(token) 53} 54 55pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { 56 let header = auth_header?; 57 let header = header.trim(); 58 59 if header.len() < 7 { 60 return None; 61 } 62 63 if !header[..7].eq_ignore_ascii_case("bearer ") { 64 return None; 65 } 66 67 let token = header[7..].trim(); 68 if token.is_empty() { 69 return None; 70 } 71 72 Some(token.to_string()) 73} 74 75pub struct ExtractedToken { 76 pub token: String, 77 pub is_dpop: bool, 78} 79 80pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 81 let header = auth_header?; 82 let header = header.trim(); 83 84 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") { 85 let token = header[7..].trim(); 86 
if token.is_empty() { 87 return None; 88 } 89 return Some(ExtractedToken { 90 token: token.to_string(), 91 is_dpop: false, 92 }); 93 } 94 95 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { 96 let token = header[5..].trim(); 97 if token.is_empty() { 98 return None; 99 } 100 return Some(ExtractedToken { 101 token: token.to_string(), 102 is_dpop: true, 103 }); 104 } 105 106 None 107} 108 109impl FromRequestParts<AppState> for BearerAuth { 110 type Rejection = AuthError; 111 112 async fn from_request_parts( 113 parts: &mut Parts, 114 state: &AppState, 115 ) -> Result<Self, Self::Rejection> { 116 let auth_header = parts 117 .headers 118 .get(AUTHORIZATION) 119 .ok_or(AuthError::MissingToken)? 120 .to_str() 121 .map_err(|_| AuthError::InvalidFormat)?; 122 123 let extracted = 124 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 125 126 if extracted.is_dpop { 127 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 128 let method = parts.method.as_str(); 129 let uri = build_full_url(&parts.uri.to_string()); 130 131 match validate_token_with_dpop( 132 &state.db, 133 &extracted.token, 134 true, 135 dpop_proof, 136 method, 137 &uri, 138 false, 139 ) 140 .await 141 { 142 Ok(user) => Ok(BearerAuth(user)), 143 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 144 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 145 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 146 Err(_) => Err(AuthError::AuthenticationFailed), 147 } 148 } else { 149 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token).await { 150 Ok(user) => Ok(BearerAuth(user)), 151 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 152 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 153 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 154 
Err(_) => Err(AuthError::AuthenticationFailed), 155 } 156 } 157 } 158} 159 160pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 161 162impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 163 type Rejection = AuthError; 164 165 async fn from_request_parts( 166 parts: &mut Parts, 167 state: &AppState, 168 ) -> Result<Self, Self::Rejection> { 169 let auth_header = parts 170 .headers 171 .get(AUTHORIZATION) 172 .ok_or(AuthError::MissingToken)? 173 .to_str() 174 .map_err(|_| AuthError::InvalidFormat)?; 175 176 let extracted = 177 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 178 179 if extracted.is_dpop { 180 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 181 let method = parts.method.as_str(); 182 let uri = build_full_url(&parts.uri.to_string()); 183 184 match validate_token_with_dpop( 185 &state.db, 186 &extracted.token, 187 true, 188 dpop_proof, 189 method, 190 &uri, 191 true, 192 ) 193 .await 194 { 195 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 196 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 197 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 198 Err(_) => Err(AuthError::AuthenticationFailed), 199 } 200 } else { 201 match validate_bearer_token_cached_allow_deactivated( 202 &state.db, 203 state.cache.as_ref(), 204 &extracted.token, 205 ) 206 .await 207 { 208 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 209 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 210 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 211 Err(_) => Err(AuthError::AuthenticationFailed), 212 } 213 } 214 } 215} 216 217pub struct BearerAuthAdmin(pub AuthenticatedUser); 218 219impl FromRequestParts<AppState> for BearerAuthAdmin { 220 type Rejection = AuthError; 221 222 async fn from_request_parts( 223 parts: &mut Parts, 224 state: &AppState, 225 ) -> Result<Self, Self::Rejection> { 226 
let auth_header = parts 227 .headers 228 .get(AUTHORIZATION) 229 .ok_or(AuthError::MissingToken)? 230 .to_str() 231 .map_err(|_| AuthError::InvalidFormat)?; 232 233 let extracted = 234 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 235 236 let user = if extracted.is_dpop { 237 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 238 let method = parts.method.as_str(); 239 let uri = build_full_url(&parts.uri.to_string()); 240 241 match validate_token_with_dpop( 242 &state.db, 243 &extracted.token, 244 true, 245 dpop_proof, 246 method, 247 &uri, 248 false, 249 ) 250 .await 251 { 252 Ok(user) => user, 253 Err(TokenValidationError::AccountDeactivated) => { 254 return Err(AuthError::AccountDeactivated); 255 } 256 Err(TokenValidationError::AccountTakedown) => { 257 return Err(AuthError::AccountTakedown); 258 } 259 Err(TokenValidationError::TokenExpired) => { 260 return Err(AuthError::TokenExpired); 261 } 262 Err(_) => return Err(AuthError::AuthenticationFailed), 263 } 264 } else { 265 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token).await { 266 Ok(user) => user, 267 Err(TokenValidationError::AccountDeactivated) => { 268 return Err(AuthError::AccountDeactivated); 269 } 270 Err(TokenValidationError::AccountTakedown) => { 271 return Err(AuthError::AccountTakedown); 272 } 273 Err(TokenValidationError::TokenExpired) => { 274 return Err(AuthError::TokenExpired); 275 } 276 Err(_) => return Err(AuthError::AuthenticationFailed), 277 } 278 }; 279 280 if !user.is_admin { 281 return Err(AuthError::AdminRequired); 282 } 283 Ok(BearerAuthAdmin(user)) 284 } 285} 286 287#[cfg(test)] 288mod tests { 289 use super::*; 290 291 #[test] 292 fn test_extract_bearer_token() { 293 assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 294 assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123"); 295 assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123"); 296 
assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123"); 297 assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123"); 298 299 assert!(extract_bearer_token("Basic abc123").is_err()); 300 assert!(extract_bearer_token("Bearer").is_err()); 301 assert!(extract_bearer_token("Bearer ").is_err()); 302 assert!(extract_bearer_token("abc123").is_err()); 303 assert!(extract_bearer_token("").is_err()); 304 } 305}