this repo has no description
1use axum::{
2 extract::FromRequestParts,
3 http::{StatusCode, request::Parts, header::AUTHORIZATION},
4 response::{IntoResponse, Response},
5 Json,
6};
7use serde_json::json;
8
9use crate::state::AppState;
10use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated};
11
/// Axum extractor that authenticates a request from its `Authorization: Bearer <token>` header.
///
/// The wrapped value is the user returned by `validate_bearer_token_cached`. Extraction is
/// rejected with an [`AuthError`] when the header is missing or malformed, when the token
/// fails validation, or when the account is deactivated or taken down.
pub struct BearerAuth(pub AuthenticatedUser);
13
/// Reasons a bearer-token extractor can reject a request.
///
/// Each variant is rendered as an HTTP response by its `IntoResponse` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The request carried no `Authorization` header at all.
    MissingToken,
    /// The header could not be read as a string, did not use the `Bearer` scheme, or
    /// carried an empty token.
    InvalidFormat,
    /// Token validation failed for a reason not covered by the other variants.
    AuthenticationFailed,
    /// Token validation reported that the account is deactivated.
    AccountDeactivated,
    /// Token validation reported that the account has been taken down.
    AccountTakedown,
}
22
23impl IntoResponse for AuthError {
24 fn into_response(self) -> Response {
25 let (status, error, message) = match self {
26 AuthError::MissingToken => (
27 StatusCode::UNAUTHORIZED,
28 "AuthenticationRequired",
29 "Authorization header is required",
30 ),
31 AuthError::InvalidFormat => (
32 StatusCode::UNAUTHORIZED,
33 "InvalidToken",
34 "Invalid authorization header format",
35 ),
36 AuthError::AuthenticationFailed => (
37 StatusCode::UNAUTHORIZED,
38 "AuthenticationFailed",
39 "Invalid or expired token",
40 ),
41 AuthError::AccountDeactivated => (
42 StatusCode::UNAUTHORIZED,
43 "AccountDeactivated",
44 "Account is deactivated",
45 ),
46 AuthError::AccountTakedown => (
47 StatusCode::UNAUTHORIZED,
48 "AccountTakedown",
49 "Account has been taken down",
50 ),
51 };
52
53 (status, Json(json!({ "error": error, "message": message }))).into_response()
54 }
55}
56
57fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
58 let auth_header = auth_header.trim();
59
60 if auth_header.len() < 8 {
61 return Err(AuthError::InvalidFormat);
62 }
63
64 let prefix = &auth_header[..7];
65 if !prefix.eq_ignore_ascii_case("bearer ") {
66 return Err(AuthError::InvalidFormat);
67 }
68
69 let token = auth_header[7..].trim();
70 if token.is_empty() {
71 return Err(AuthError::InvalidFormat);
72 }
73
74 Ok(token)
75}
76
/// Returns the bearer token carried by an optional `Authorization` header value.
///
/// The value is trimmed, the `Bearer ` prefix is matched case-insensitively, and the
/// token itself is trimmed before being returned as an owned `String`.
///
/// Returns `None` when:
/// - the header is absent;
/// - it uses a different scheme;
/// - it is too short to hold the prefix;
/// - the token is empty.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    let header = auth_header?.trim();

    // `get` returns `None` rather than panicking when byte 7 falls inside a multi-byte
    // character, so arbitrary (non-ASCII) input cannot crash the caller.
    let scheme = header.get(..7)?;
    if !scheme.eq_ignore_ascii_case("bearer ") {
        return None;
    }

    // Index 7 is known to be a char boundary because `get(..7)` succeeded.
    let token = header[7..].trim();
    (!token.is_empty()).then(|| token.to_string())
}
96
/// An access token pulled out of an `Authorization` header, tagged with its scheme.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedToken {
    /// The trimmed, non-empty token that followed the scheme name.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}

/// Extracts a `Bearer` or `DPoP` token from an optional `Authorization` header value.
///
/// The value is trimmed and the scheme name is matched case-insensitively; `Bearer` is
/// checked first. Returns `None` when:
/// - the header is absent;
/// - it uses any other scheme;
/// - the token after the scheme is empty.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    /// Returns the trimmed, non-empty remainder of `header` after a case-insensitive
    /// `scheme` prefix (the prefix includes its trailing space).
    fn strip_scheme<'a>(header: &'a str, scheme: &str) -> Option<&'a str> {
        // `get` avoids the panic that `[..n]` triggers when byte `n` is not a char boundary.
        let prefix = header.get(..scheme.len())?;
        if !prefix.eq_ignore_ascii_case(scheme) {
            return None;
        }
        // The successful `get` above proves this index is a char boundary.
        let token = header[scheme.len()..].trim();
        (!token.is_empty()).then_some(token)
    }

    let header = auth_header?.trim();

    if let Some(token) = strip_scheme(header, "bearer ") {
        return Some(ExtractedToken { token: token.to_string(), is_dpop: false });
    }

    strip_scheme(header, "dpop ").map(|token| ExtractedToken {
        token: token.to_string(),
        is_dpop: true,
    })
}
124
125impl FromRequestParts<AppState> for BearerAuth {
126 type Rejection = AuthError;
127
128 async fn from_request_parts(
129 parts: &mut Parts,
130 state: &AppState,
131 ) -> Result<Self, Self::Rejection> {
132 let auth_header = parts
133 .headers
134 .get(AUTHORIZATION)
135 .ok_or(AuthError::MissingToken)?
136 .to_str()
137 .map_err(|_| AuthError::InvalidFormat)?;
138
139 let token = extract_bearer_token(auth_header)?;
140
141 match validate_bearer_token_cached(&state.db, &state.cache, token).await {
142 Ok(user) => Ok(BearerAuth(user)),
143 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
144 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
145 Err(_) => Err(AuthError::AuthenticationFailed),
146 }
147 }
148}
149
/// Variant of [`BearerAuth`] whose extractor validates the token with
/// `validate_bearer_token_cached_allow_deactivated` instead of `validate_bearer_token_cached`.
///
/// Taken-down accounts are still rejected with [`AuthError::AccountTakedown`]. Presumably
/// deactivated accounts are accepted, as the validator's name suggests — confirm against
/// that function.
pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
151
152impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
153 type Rejection = AuthError;
154
155 async fn from_request_parts(
156 parts: &mut Parts,
157 state: &AppState,
158 ) -> Result<Self, Self::Rejection> {
159 let auth_header = parts
160 .headers
161 .get(AUTHORIZATION)
162 .ok_or(AuthError::MissingToken)?
163 .to_str()
164 .map_err(|_| AuthError::InvalidFormat)?;
165
166 let token = extract_bearer_token(auth_header)?;
167
168 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
169 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
170 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
171 Err(_) => Err(AuthError::AuthenticationFailed),
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        // Scheme matching is case-insensitive; surrounding and inner padding is ignored.
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("Bearer   abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");

        assert!(extract_bearer_token("Basic abc123").is_err());
        assert!(extract_bearer_token("Bearer").is_err());
        assert!(extract_bearer_token("Bearer ").is_err());
        assert!(extract_bearer_token("abc123").is_err());
        assert!(extract_bearer_token("").is_err());
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bearer   abc123 ")).as_deref(),
            Some("abc123")
        );
        assert_eq!(extract_bearer_token_from_header(None), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer ")), None);
        assert_eq!(extract_bearer_token_from_header(Some("DPoP abc123")), None);
    }

    #[test]
    fn test_extract_auth_token_from_header() {
        let bearer = extract_auth_token_from_header(Some("Bearer abc123")).unwrap();
        assert_eq!(bearer.token, "abc123");
        assert!(!bearer.is_dpop);

        let dpop = extract_auth_token_from_header(Some("dpop xyz")).unwrap();
        assert_eq!(dpop.token, "xyz");
        assert!(dpop.is_dpop);

        assert!(extract_auth_token_from_header(None).is_none());
        assert!(extract_auth_token_from_header(Some("DPoP ")).is_none());
        assert!(extract_auth_token_from_header(Some("Basic abc123")).is_none());
    }
}