this repo has no description
1use axum::{
2 Json,
3 extract::FromRequestParts,
4 http::{StatusCode, header::AUTHORIZATION, request::Parts},
5 response::{IntoResponse, Response},
6};
7use serde_json::json;
8
9use super::{
10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
11 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop,
12};
13use crate::state::AppState;
14
15pub struct BearerAuth(pub AuthenticatedUser);
16
/// Rejection produced by the authentication extractors in this module.
#[derive(Debug)]
pub enum AuthError {
    /// The request carried no `Authorization` header.
    MissingToken,
    /// The `Authorization` header could not be read as text or parsed into a token.
    InvalidFormat,
    /// Token validation failed for a reason other than the account states below.
    AuthenticationFailed,
    /// Token validation reported the account as deactivated.
    AccountDeactivated,
    /// Token validation reported the account as taken down.
    AccountTakedown,
    /// The authenticated user is not an admin.
    AdminRequired,
}
26
27impl IntoResponse for AuthError {
28 fn into_response(self) -> Response {
29 let (status, error, message) = match self {
30 AuthError::MissingToken => (
31 StatusCode::UNAUTHORIZED,
32 "AuthenticationRequired",
33 "Authorization header is required",
34 ),
35 AuthError::InvalidFormat => (
36 StatusCode::UNAUTHORIZED,
37 "InvalidToken",
38 "Invalid authorization header format",
39 ),
40 AuthError::AuthenticationFailed => (
41 StatusCode::UNAUTHORIZED,
42 "AuthenticationFailed",
43 "Invalid or expired token",
44 ),
45 AuthError::AccountDeactivated => (
46 StatusCode::UNAUTHORIZED,
47 "AccountDeactivated",
48 "Account is deactivated",
49 ),
50 AuthError::AccountTakedown => (
51 StatusCode::UNAUTHORIZED,
52 "AccountTakedown",
53 "Account has been taken down",
54 ),
55 AuthError::AdminRequired => (
56 StatusCode::FORBIDDEN,
57 "AdminRequired",
58 "This action requires admin privileges",
59 ),
60 };
61
62 (status, Json(json!({ "error": error, "message": message }))).into_response()
63 }
64}
65
/// Extracts the token from a `Bearer <token>` authorization header value.
///
/// Surrounding whitespace on the header and on the token is ignored, and the
/// `Bearer` scheme is matched case-insensitively.
///
/// # Errors
///
/// Returns [`AuthError::InvalidFormat`] when the header does not start with
/// `bearer ` or when no token follows the scheme.
#[cfg(test)]
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
    let auth_header = auth_header.trim();

    // `str::get` rather than `[..7]`: it yields `None` both when the header is
    // too short and when byte 7 falls inside a multi-byte character, where a
    // plain slice would panic.
    let scheme = auth_header.get(..7).ok_or(AuthError::InvalidFormat)?;
    if !scheme.eq_ignore_ascii_case("bearer ") {
        return Err(AuthError::InvalidFormat);
    }

    // Byte 7 is a char boundary here because `get(..7)` succeeded.
    let token = auth_header[7..].trim();
    if token.is_empty() {
        return Err(AuthError::InvalidFormat);
    }

    Ok(token)
}
86
/// Extracts the token from an optional `Bearer <token>` authorization header value.
///
/// Surrounding whitespace on the header and on the token is ignored, and the
/// `Bearer` scheme is matched case-insensitively.
///
/// Returns `None` when the header is absent, does not start with `bearer `,
/// or carries no token after the scheme. Never panics, including on
/// non-ASCII input.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    let header = auth_header?.trim();

    // `str::get` rather than `[..7]`: it yields `None` both when the header is
    // too short and when byte 7 falls inside a multi-byte character, where a
    // plain slice would panic.
    let scheme = header.get(..7)?;
    if !scheme.eq_ignore_ascii_case("bearer ") {
        return None;
    }

    // Byte 7 is a char boundary here because `get(..7)` succeeded.
    let token = header[7..].trim();
    if token.is_empty() {
        return None;
    }

    Some(token.to_string())
}
106
/// Token pulled out of an `Authorization` header, together with the scheme it
/// was presented under.
pub struct ExtractedToken {
    /// The token text with surrounding whitespace removed.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}

/// Parses an optional `Authorization` header value of the form
/// `Bearer <token>` or `DPoP <token>`.
///
/// Surrounding whitespace on the header and on the token is ignored, and the
/// scheme is matched case-insensitively.
///
/// Returns `None` when the header is absent, uses any other scheme, or carries
/// no token after the scheme. Never panics, including on non-ASCII input.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    let header = auth_header?.trim();

    for (scheme, is_dpop) in [("bearer ", false), ("dpop ", true)] {
        // `str::get` rather than slicing: it yields `None` both when the
        // header is shorter than the scheme and when the cut would fall inside
        // a multi-byte character, where a plain slice would panic.
        let Some(prefix) = header.get(..scheme.len()) else {
            continue;
        };
        if !prefix.eq_ignore_ascii_case(scheme) {
            continue;
        }

        // `scheme.len()` is a char boundary here because `get` succeeded.
        let token = header[scheme.len()..].trim();
        if token.is_empty() {
            return None;
        }
        return Some(ExtractedToken {
            token: token.to_string(),
            is_dpop,
        });
    }

    None
}
140
141impl FromRequestParts<AppState> for BearerAuth {
142 type Rejection = AuthError;
143
144 async fn from_request_parts(
145 parts: &mut Parts,
146 state: &AppState,
147 ) -> Result<Self, Self::Rejection> {
148 let auth_header = parts
149 .headers
150 .get(AUTHORIZATION)
151 .ok_or(AuthError::MissingToken)?
152 .to_str()
153 .map_err(|_| AuthError::InvalidFormat)?;
154
155 let extracted =
156 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
157
158 if extracted.is_dpop {
159 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
160 let method = parts.method.as_str();
161 let uri = parts.uri.to_string();
162
163 match validate_token_with_dpop(
164 &state.db,
165 &extracted.token,
166 true,
167 dpop_proof,
168 method,
169 &uri,
170 false,
171 )
172 .await
173 {
174 Ok(user) => Ok(BearerAuth(user)),
175 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
176 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
177 Err(_) => Err(AuthError::AuthenticationFailed),
178 }
179 } else {
180 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
181 Ok(user) => Ok(BearerAuth(user)),
182 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
183 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
184 Err(_) => Err(AuthError::AuthenticationFailed),
185 }
186 }
187 }
188}
189
190pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
191
192impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
193 type Rejection = AuthError;
194
195 async fn from_request_parts(
196 parts: &mut Parts,
197 state: &AppState,
198 ) -> Result<Self, Self::Rejection> {
199 let auth_header = parts
200 .headers
201 .get(AUTHORIZATION)
202 .ok_or(AuthError::MissingToken)?
203 .to_str()
204 .map_err(|_| AuthError::InvalidFormat)?;
205
206 let extracted =
207 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
208
209 if extracted.is_dpop {
210 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
211 let method = parts.method.as_str();
212 let uri = parts.uri.to_string();
213
214 match validate_token_with_dpop(
215 &state.db,
216 &extracted.token,
217 true,
218 dpop_proof,
219 method,
220 &uri,
221 true,
222 )
223 .await
224 {
225 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
226 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
227 Err(_) => Err(AuthError::AuthenticationFailed),
228 }
229 } else {
230 match validate_bearer_token_cached_allow_deactivated(
231 &state.db,
232 &state.cache,
233 &extracted.token,
234 )
235 .await
236 {
237 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
238 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
239 Err(_) => Err(AuthError::AuthenticationFailed),
240 }
241 }
242 }
243}
244
245pub struct BearerAuthAdmin(pub AuthenticatedUser);
246
247impl FromRequestParts<AppState> for BearerAuthAdmin {
248 type Rejection = AuthError;
249
250 async fn from_request_parts(
251 parts: &mut Parts,
252 state: &AppState,
253 ) -> Result<Self, Self::Rejection> {
254 let auth_header = parts
255 .headers
256 .get(AUTHORIZATION)
257 .ok_or(AuthError::MissingToken)?
258 .to_str()
259 .map_err(|_| AuthError::InvalidFormat)?;
260
261 let extracted =
262 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
263
264 let user = if extracted.is_dpop {
265 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
266 let method = parts.method.as_str();
267 let uri = parts.uri.to_string();
268
269 match validate_token_with_dpop(
270 &state.db,
271 &extracted.token,
272 true,
273 dpop_proof,
274 method,
275 &uri,
276 false,
277 )
278 .await
279 {
280 Ok(user) => user,
281 Err(TokenValidationError::AccountDeactivated) => {
282 return Err(AuthError::AccountDeactivated);
283 }
284 Err(TokenValidationError::AccountTakedown) => {
285 return Err(AuthError::AccountTakedown);
286 }
287 Err(_) => return Err(AuthError::AuthenticationFailed),
288 }
289 } else {
290 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
291 Ok(user) => user,
292 Err(TokenValidationError::AccountDeactivated) => {
293 return Err(AuthError::AccountDeactivated);
294 }
295 Err(TokenValidationError::AccountTakedown) => {
296 return Err(AuthError::AccountTakedown);
297 }
298 Err(_) => return Err(AuthError::AuthenticationFailed),
299 }
300 };
301
302 if !user.is_admin {
303 return Err(AuthError::AdminRequired);
304 }
305 Ok(BearerAuthAdmin(user))
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which header values `extract_bearer_token` accepts and rejects.
    #[test]
    fn test_extract_bearer_token() {
        // The scheme is case-insensitive, and whitespace around the header and
        // between the scheme and the token is ignored.
        let accepted = [
            "Bearer abc123",
            "bearer abc123",
            "BEARER abc123",
            "Bearer   abc123",
            " Bearer abc123 ",
        ];
        for header in accepted {
            assert_eq!(
                extract_bearer_token(header).unwrap(),
                "abc123",
                "header: {header:?}"
            );
        }

        // Wrong scheme, missing token, or no scheme at all.
        let rejected = ["Basic abc123", "Bearer", "Bearer ", "abc123", ""];
        for header in rejected {
            assert!(
                extract_bearer_token(header).is_err(),
                "header: {header:?}"
            );
        }
    }
}