// This repository has no description.
1use axum::{
2 extract::FromRequestParts,
3 http::{StatusCode, request::Parts, header::AUTHORIZATION},
4 response::{IntoResponse, Response},
5 Json,
6};
7use serde_json::json;
8use crate::state::AppState;
9use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated};
/// Axum extractor that authenticates a request from its `Authorization: Bearer <token>`
/// header and holds the resulting `AuthenticatedUser`.
///
/// Extraction failures are reported as `AuthError`.
pub struct BearerAuth(pub AuthenticatedUser);
/// Reasons a bearer-token extractor in this module can reject a request.
///
/// Every variant is rendered by this type's `IntoResponse` impl as a JSON body of the
/// form `{"error": <code>, "message": <text>}`.
///
/// The variants carry no data, so the type is cheap to copy and can be compared directly
/// (useful in tests and when callers branch on the rejection reason).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// No `Authorization` header was present on the request.
    MissingToken,
    /// The header was not a valid header string, or was not of the form `Bearer <token>`.
    InvalidFormat,
    /// Token validation failed for a reason other than deactivation or takedown.
    AuthenticationFailed,
    /// Token validation reported the account as deactivated.
    AccountDeactivated,
    /// Token validation reported the account as taken down.
    AccountTakedown,
}
19impl IntoResponse for AuthError {
20 fn into_response(self) -> Response {
21 let (status, error, message) = match self {
22 AuthError::MissingToken => (
23 StatusCode::UNAUTHORIZED,
24 "AuthenticationRequired",
25 "Authorization header is required",
26 ),
27 AuthError::InvalidFormat => (
28 StatusCode::UNAUTHORIZED,
29 "InvalidToken",
30 "Invalid authorization header format",
31 ),
32 AuthError::AuthenticationFailed => (
33 StatusCode::UNAUTHORIZED,
34 "AuthenticationFailed",
35 "Invalid or expired token",
36 ),
37 AuthError::AccountDeactivated => (
38 StatusCode::UNAUTHORIZED,
39 "AccountDeactivated",
40 "Account is deactivated",
41 ),
42 AuthError::AccountTakedown => (
43 StatusCode::UNAUTHORIZED,
44 "AccountTakedown",
45 "Account has been taken down",
46 ),
47 };
48 (status, Json(json!({ "error": error, "message": message }))).into_response()
49 }
50}
51fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
52 let auth_header = auth_header.trim();
53 if auth_header.len() < 8 {
54 return Err(AuthError::InvalidFormat);
55 }
56 let prefix = &auth_header[..7];
57 if !prefix.eq_ignore_ascii_case("bearer ") {
58 return Err(AuthError::InvalidFormat);
59 }
60 let token = auth_header[7..].trim();
61 if token.is_empty() {
62 return Err(AuthError::InvalidFormat);
63 }
64 Ok(token)
65}
/// Returns the bearer token from an optional `Authorization` header value, as an owned `String`.
///
/// The value is trimmed first. The `Bearer ` scheme prefix (including its single separating
/// space) is then matched ASCII-case-insensitively, and the remainder is trimmed again.
///
/// Returns `None` when the header is absent, does not start with `Bearer `, or carries an
/// empty (whitespace-only) token.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    const SCHEME: &str = "bearer ";
    let header = auth_header?.trim();
    // `get` yields `None` both for values shorter than the scheme and when byte 7 splits
    // a multi-byte UTF-8 character. Plain `header[..7]` would panic in the latter case.
    let scheme = header.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }
    // `get` succeeded above, so `SCHEME.len()` is a char boundary and this cannot panic.
    let token = header[SCHEME.len()..].trim();
    (!token.is_empty()).then(|| token.to_owned())
}
81impl FromRequestParts<AppState> for BearerAuth {
82 type Rejection = AuthError;
83 async fn from_request_parts(
84 parts: &mut Parts,
85 state: &AppState,
86 ) -> Result<Self, Self::Rejection> {
87 let auth_header = parts
88 .headers
89 .get(AUTHORIZATION)
90 .ok_or(AuthError::MissingToken)?
91 .to_str()
92 .map_err(|_| AuthError::InvalidFormat)?;
93 let token = extract_bearer_token(auth_header)?;
94 match validate_bearer_token_cached(&state.db, &state.cache, token).await {
95 Ok(user) => Ok(BearerAuth(user)),
96 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
97 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
98 Err(_) => Err(AuthError::AuthenticationFailed),
99 }
100 }
101}
102pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
103impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
104 type Rejection = AuthError;
105 async fn from_request_parts(
106 parts: &mut Parts,
107 state: &AppState,
108 ) -> Result<Self, Self::Rejection> {
109 let auth_header = parts
110 .headers
111 .get(AUTHORIZATION)
112 .ok_or(AuthError::MissingToken)?
113 .to_str()
114 .map_err(|_| AuthError::InvalidFormat)?;
115 let token = extract_bearer_token(auth_header)?;
116 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
117 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
118 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
119 Err(_) => Err(AuthError::AuthenticationFailed),
120 }
121 }
122}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        // The scheme is matched case-insensitively.
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra spaces between scheme and token, and around the whole value, are ignored.
        assert_eq!(extract_bearer_token("Bearer   abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");
        // Everything else is rejected specifically as `InvalidFormat`.
        for input in [
            "Basic abc123",
            "Bearer",
            "Bearer ",
            "Bearer    ",
            "Bearerabc123",
            "Bearer\tabc123",
            "abc123",
            "",
        ] {
            assert!(
                matches!(extract_bearer_token(input), Err(AuthError::InvalidFormat)),
                "expected InvalidFormat for {input:?}"
            );
        }
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bEaReR   abc123  ")).as_deref(),
            Some("abc123")
        );
        assert_eq!(extract_bearer_token_from_header(None), None);
        for input in ["Basic abc123", "Bearer", "Bearer ", "Bearerabc123", "abc123", ""] {
            assert_eq!(
                extract_bearer_token_from_header(Some(input)),
                None,
                "expected no token for {input:?}"
            );
        }
    }

    #[test]
    fn test_extractors_agree() {
        // The `Result`-returning and `Option`-returning extractors must accept and
        // reject exactly the same header values.
        for input in [
            "Bearer abc123",
            "  bearer   xyz  ",
            "Bearer",
            "Bearer ",
            "Basic abc123",
            "abc123",
            "",
        ] {
            assert_eq!(
                extract_bearer_token(input).ok().map(str::to_owned),
                extract_bearer_token_from_header(Some(input)),
                "extractors disagree on {input:?}"
            );
        }
    }

    #[test]
    fn test_auth_errors_render_as_unauthorized() {
        for err in [
            AuthError::MissingToken,
            AuthError::InvalidFormat,
            AuthError::AuthenticationFailed,
            AuthError::AccountDeactivated,
            AuthError::AccountTakedown,
        ] {
            assert_eq!(err.into_response().status(), StatusCode::UNAUTHORIZED);
        }
    }
}