This repository has no description.
1use axum::{
2 extract::FromRequestParts,
3 http::{header::AUTHORIZATION, request::Parts},
4 response::{IntoResponse, Response},
5};
6
7use super::{
8 AuthenticatedUser, TokenValidationError, validate_bearer_token_allow_takendown,
9 validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated,
10 validate_token_with_dpop,
11};
12use crate::api::error::ApiError;
13use crate::state::AppState;
14use crate::util::build_full_url;
15
16pub struct BearerAuth(pub AuthenticatedUser);
17
18#[derive(Debug)]
19pub enum AuthError {
20 MissingToken,
21 InvalidFormat,
22 AuthenticationFailed,
23 TokenExpired,
24 AccountDeactivated,
25 AccountTakedown,
26 AdminRequired,
27}
28
29impl IntoResponse for AuthError {
30 fn into_response(self) -> Response {
31 ApiError::from(self).into_response()
32 }
33}
34
/// Strips a case-insensitive authentication `scheme` prefix (including its
/// trailing separator, e.g. `"bearer "`) from an already-trimmed header value
/// and returns the whitespace-trimmed remainder.
///
/// Returns `None` when `header` does not start with `scheme`, or when only
/// whitespace follows it. `str::get` is used instead of slice indexing so a
/// non-ASCII header whose byte at `scheme.len()` is not a char boundary is
/// rejected rather than causing a panic.
fn strip_auth_scheme<'a>(header: &'a str, scheme: &str) -> Option<&'a str> {
    let prefix = header.get(..scheme.len())?;
    if !prefix.eq_ignore_ascii_case(scheme) {
        return None;
    }
    // Cannot fail: the previous `get` proved `scheme.len()` is a char boundary.
    let token = header.get(scheme.len()..)?.trim();
    (!token.is_empty()).then_some(token)
}

/// Test-only helper: returns the token of a `Bearer` header value.
///
/// # Errors
///
/// `AuthError::InvalidFormat` if the value (after trimming) does not start
/// with a case-insensitive `"bearer "` followed by a non-empty token.
#[cfg(test)]
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
    strip_auth_scheme(auth_header.trim(), "bearer ").ok_or(AuthError::InvalidFormat)
}

/// Returns the token of a `Bearer` `Authorization` header value, if any.
///
/// The scheme is matched case-insensitively and surrounding whitespace is
/// ignored. Returns `None` for a missing header, any other scheme (including
/// `DPoP`), or an empty token.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    strip_auth_scheme(auth_header?.trim(), "bearer ").map(String::from)
}

/// An access token pulled from an `Authorization` header, tagged with its scheme.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedToken {
    /// The token text following the scheme, with surrounding whitespace trimmed.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}

/// Parses a `Bearer` or `DPoP` `Authorization` header value.
///
/// Schemes are matched case-insensitively and surrounding whitespace is
/// ignored. Returns `None` for a missing header, an unsupported scheme, or an
/// empty token.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    let header = auth_header?.trim();
    strip_auth_scheme(header, "bearer ")
        .map(|token| (token, false))
        .or_else(|| strip_auth_scheme(header, "dpop ").map(|token| (token, true)))
        .map(|(token, is_dpop)| ExtractedToken {
            token: token.to_string(),
            is_dpop,
        })
}
109
110impl FromRequestParts<AppState> for BearerAuth {
111 type Rejection = AuthError;
112
113 async fn from_request_parts(
114 parts: &mut Parts,
115 state: &AppState,
116 ) -> Result<Self, Self::Rejection> {
117 let auth_header = parts
118 .headers
119 .get(AUTHORIZATION)
120 .ok_or(AuthError::MissingToken)?
121 .to_str()
122 .map_err(|_| AuthError::InvalidFormat)?;
123
124 let extracted =
125 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
126
127 if extracted.is_dpop {
128 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
129 let method = parts.method.as_str();
130 let uri = build_full_url(&parts.uri.to_string());
131
132 match validate_token_with_dpop(
133 &state.db,
134 &extracted.token,
135 true,
136 dpop_proof,
137 method,
138 &uri,
139 false,
140 false,
141 )
142 .await
143 {
144 Ok(user) => Ok(BearerAuth(user)),
145 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
146 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
147 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
148 Err(_) => Err(AuthError::AuthenticationFailed),
149 }
150 } else {
151 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token)
152 .await
153 {
154 Ok(user) => Ok(BearerAuth(user)),
155 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
156 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
157 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
158 Err(_) => Err(AuthError::AuthenticationFailed),
159 }
160 }
161 }
162}
163
164pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
165
166impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
167 type Rejection = AuthError;
168
169 async fn from_request_parts(
170 parts: &mut Parts,
171 state: &AppState,
172 ) -> Result<Self, Self::Rejection> {
173 let auth_header = parts
174 .headers
175 .get(AUTHORIZATION)
176 .ok_or(AuthError::MissingToken)?
177 .to_str()
178 .map_err(|_| AuthError::InvalidFormat)?;
179
180 let extracted =
181 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
182
183 if extracted.is_dpop {
184 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
185 let method = parts.method.as_str();
186 let uri = build_full_url(&parts.uri.to_string());
187
188 match validate_token_with_dpop(
189 &state.db,
190 &extracted.token,
191 true,
192 dpop_proof,
193 method,
194 &uri,
195 true,
196 false,
197 )
198 .await
199 {
200 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
201 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
202 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
203 Err(_) => Err(AuthError::AuthenticationFailed),
204 }
205 } else {
206 match validate_bearer_token_cached_allow_deactivated(
207 &state.db,
208 state.cache.as_ref(),
209 &extracted.token,
210 )
211 .await
212 {
213 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
214 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
215 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
216 Err(_) => Err(AuthError::AuthenticationFailed),
217 }
218 }
219 }
220}
221
222pub struct BearerAuthAllowTakendown(pub AuthenticatedUser);
223
224impl FromRequestParts<AppState> for BearerAuthAllowTakendown {
225 type Rejection = AuthError;
226
227 async fn from_request_parts(
228 parts: &mut Parts,
229 state: &AppState,
230 ) -> Result<Self, Self::Rejection> {
231 let auth_header = parts
232 .headers
233 .get(AUTHORIZATION)
234 .ok_or(AuthError::MissingToken)?
235 .to_str()
236 .map_err(|_| AuthError::InvalidFormat)?;
237
238 let extracted =
239 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
240
241 if extracted.is_dpop {
242 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
243 let method = parts.method.as_str();
244 let uri = build_full_url(&parts.uri.to_string());
245
246 match validate_token_with_dpop(
247 &state.db,
248 &extracted.token,
249 true,
250 dpop_proof,
251 method,
252 &uri,
253 false,
254 true,
255 )
256 .await
257 {
258 Ok(user) => Ok(BearerAuthAllowTakendown(user)),
259 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
260 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
261 Err(_) => Err(AuthError::AuthenticationFailed),
262 }
263 } else {
264 match validate_bearer_token_allow_takendown(&state.db, &extracted.token).await {
265 Ok(user) => Ok(BearerAuthAllowTakendown(user)),
266 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
267 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
268 Err(_) => Err(AuthError::AuthenticationFailed),
269 }
270 }
271 }
272}
273
274pub struct BearerAuthAdmin(pub AuthenticatedUser);
275
276impl FromRequestParts<AppState> for BearerAuthAdmin {
277 type Rejection = AuthError;
278
279 async fn from_request_parts(
280 parts: &mut Parts,
281 state: &AppState,
282 ) -> Result<Self, Self::Rejection> {
283 let auth_header = parts
284 .headers
285 .get(AUTHORIZATION)
286 .ok_or(AuthError::MissingToken)?
287 .to_str()
288 .map_err(|_| AuthError::InvalidFormat)?;
289
290 let extracted =
291 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
292
293 let user = if extracted.is_dpop {
294 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
295 let method = parts.method.as_str();
296 let uri = build_full_url(&parts.uri.to_string());
297
298 match validate_token_with_dpop(
299 &state.db,
300 &extracted.token,
301 true,
302 dpop_proof,
303 method,
304 &uri,
305 false,
306 false,
307 )
308 .await
309 {
310 Ok(user) => user,
311 Err(TokenValidationError::AccountDeactivated) => {
312 return Err(AuthError::AccountDeactivated);
313 }
314 Err(TokenValidationError::AccountTakedown) => {
315 return Err(AuthError::AccountTakedown);
316 }
317 Err(TokenValidationError::TokenExpired) => {
318 return Err(AuthError::TokenExpired);
319 }
320 Err(_) => return Err(AuthError::AuthenticationFailed),
321 }
322 } else {
323 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token)
324 .await
325 {
326 Ok(user) => user,
327 Err(TokenValidationError::AccountDeactivated) => {
328 return Err(AuthError::AccountDeactivated);
329 }
330 Err(TokenValidationError::AccountTakedown) => {
331 return Err(AuthError::AccountTakedown);
332 }
333 Err(TokenValidationError::TokenExpired) => {
334 return Err(AuthError::TokenExpired);
335 }
336 Err(_) => return Err(AuthError::AuthenticationFailed),
337 }
338 };
339
340 if !user.is_admin {
341 return Err(AuthError::AdminRequired);
342 }
343 Ok(BearerAuthAdmin(user))
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra whitespace between the scheme and the token is tolerated.
        assert_eq!(extract_bearer_token("Bearer  abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token(" Bearer abc123 ").unwrap(), "abc123");

        assert!(extract_bearer_token("Basic abc123").is_err());
        assert!(extract_bearer_token("Bearer").is_err());
        assert!(extract_bearer_token("Bearer ").is_err());
        assert!(extract_bearer_token("abc123").is_err());
        assert!(extract_bearer_token("").is_err());
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bearer   abc123  ")).as_deref(),
            Some("abc123")
        );

        // DPoP tokens are not bearer tokens.
        assert_eq!(extract_bearer_token_from_header(Some("DPoP abc123")), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer ")), None);
        assert_eq!(extract_bearer_token_from_header(Some("Basic abc123")), None);
        assert_eq!(extract_bearer_token_from_header(None), None);
    }

    #[test]
    fn test_extract_auth_token_from_header() {
        let bearer = extract_auth_token_from_header(Some("Bearer abc123")).unwrap();
        assert_eq!(bearer.token, "abc123");
        assert!(!bearer.is_dpop);

        let dpop = extract_auth_token_from_header(Some(" dpop xyz ")).unwrap();
        assert_eq!(dpop.token, "xyz");
        assert!(dpop.is_dpop);

        let dpop_upper = extract_auth_token_from_header(Some("DPOP xyz")).unwrap();
        assert_eq!(dpop_upper.token, "xyz");
        assert!(dpop_upper.is_dpop);

        assert!(extract_auth_token_from_header(Some("DPoP ")).is_none());
        assert!(extract_auth_token_from_header(Some("Bearer ")).is_none());
        assert!(extract_auth_token_from_header(Some("Basic abc123")).is_none());
        assert!(extract_auth_token_from_header(Some("")).is_none());
        assert!(extract_auth_token_from_header(None).is_none());
    }
}