this repo has no description
1use axum::{
2 Json,
3 extract::FromRequestParts,
4 http::{StatusCode, header::AUTHORIZATION, request::Parts},
5 response::{IntoResponse, Response},
6};
7use serde_json::json;
8
9use super::{
10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
11 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop,
12};
13use crate::state::AppState;
14use crate::util::build_full_url;
15
/// Request extractor wrapping the [`AuthenticatedUser`] resolved from the
/// request's `Authorization` header.
///
/// Its `FromRequestParts` implementation rejects the request with an
/// [`AuthError`] when the header is missing or malformed, or when token
/// validation fails (including deactivated and taken-down accounts).
pub struct BearerAuth(pub AuthenticatedUser);
17
/// Rejection type shared by the authentication extractors in this module.
///
/// Every variant is a plain unit variant, so the type is cheap to clone and
/// compare; `PartialEq`/`Eq` let callers and tests assert on the exact
/// failure instead of only checking `is_err()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The request carried no `Authorization` header.
    MissingToken,
    /// The `Authorization` header was not readable as text, used an
    /// unsupported scheme, or carried an empty token.
    InvalidFormat,
    /// Token validation failed for a reason not covered by a more specific
    /// variant.
    AuthenticationFailed,
    /// The token was recognised but has expired.
    TokenExpired,
    /// The account behind the token is deactivated.
    AccountDeactivated,
    /// The account behind the token has been taken down.
    AccountTakedown,
    /// The authenticated user is not an admin.
    AdminRequired,
}
28
29impl IntoResponse for AuthError {
30 fn into_response(self) -> Response {
31 let (status, error, message) = match self {
32 AuthError::MissingToken => (
33 StatusCode::UNAUTHORIZED,
34 "AuthenticationRequired",
35 "Authorization header is required",
36 ),
37 AuthError::InvalidFormat => (
38 StatusCode::UNAUTHORIZED,
39 "InvalidToken",
40 "Invalid authorization header format",
41 ),
42 AuthError::AuthenticationFailed => (
43 StatusCode::UNAUTHORIZED,
44 "InvalidToken",
45 "Token could not be verified",
46 ),
47 AuthError::TokenExpired => (
48 StatusCode::UNAUTHORIZED,
49 "ExpiredToken",
50 "Token has expired",
51 ),
52 AuthError::AccountDeactivated => (
53 StatusCode::UNAUTHORIZED,
54 "AccountDeactivated",
55 "Account is deactivated",
56 ),
57 AuthError::AccountTakedown => (
58 StatusCode::UNAUTHORIZED,
59 "AccountTakedown",
60 "Account has been taken down",
61 ),
62 AuthError::AdminRequired => (
63 StatusCode::FORBIDDEN,
64 "AdminRequired",
65 "This action requires admin privileges",
66 ),
67 };
68
69 (status, Json(json!({ "error": error, "message": message }))).into_response()
70 }
71}
72
/// Extracts the token from a `Bearer <token>` authorization header value.
///
/// Surrounding whitespace is ignored and the scheme is matched
/// case-insensitively.
///
/// # Errors
///
/// Returns [`AuthError::InvalidFormat`] when the value does not start with
/// the `bearer ` scheme or the token part is empty.
#[cfg(test)]
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
    const SCHEME: &str = "bearer ";

    let auth_header = auth_header.trim();

    // `str::get` returns `None` both when the input is too short and when
    // the cut would land inside a multi-byte character, so non-ASCII input
    // is rejected instead of panicking on a bad slice index.
    let prefix = auth_header
        .get(..SCHEME.len())
        .ok_or(AuthError::InvalidFormat)?;
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return Err(AuthError::InvalidFormat);
    }

    // Safe to index: the `get` above proved this offset is a char boundary.
    let token = auth_header[SCHEME.len()..].trim();
    if token.is_empty() {
        return Err(AuthError::InvalidFormat);
    }

    Ok(token)
}
93
/// Extracts the token from an optional `Bearer <token>` authorization header
/// value.
///
/// Surrounding whitespace is ignored and the scheme is matched
/// case-insensitively. Returns `None` when the header is absent, does not
/// use the `bearer ` scheme, or carries an empty token.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    const SCHEME: &str = "bearer ";

    let header = auth_header?.trim();

    // `str::get` returns `None` both when the input is too short and when
    // the cut would land inside a multi-byte character, so arbitrary
    // (non-ASCII) input is rejected instead of panicking.
    let prefix = header.get(..SCHEME.len())?;
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    // Safe to index: the `get` above proved this offset is a char boundary.
    let token = header[SCHEME.len()..].trim();
    if token.is_empty() {
        return None;
    }

    Some(token.to_string())
}
113
/// A token taken from an `Authorization` header together with the scheme it
/// was presented under.
pub struct ExtractedToken {
    /// The token text with surrounding whitespace removed.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}

/// Parses an optional `Authorization` header value that uses either the
/// `Bearer` or the `DPoP` scheme.
///
/// Surrounding whitespace is ignored and the scheme is matched
/// case-insensitively. Returns `None` when the header is absent, uses any
/// other scheme, or carries an empty token.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    // Case-insensitively strips `scheme` from the front of `header`.
    // `str::get` yields `None` when the header is too short or when the cut
    // would split a multi-byte character, so non-ASCII input cannot panic.
    fn strip_scheme<'a>(header: &'a str, scheme: &str) -> Option<&'a str> {
        let prefix = header.get(..scheme.len())?;
        if prefix.eq_ignore_ascii_case(scheme) {
            Some(&header[scheme.len()..])
        } else {
            None
        }
    }

    let header = auth_header?.trim();

    // `Bearer` is tried first; a matching scheme with an empty token is a
    // hard failure and does not fall through to the other scheme.
    let (rest, is_dpop) = if let Some(rest) = strip_scheme(header, "bearer ") {
        (rest, false)
    } else if let Some(rest) = strip_scheme(header, "dpop ") {
        (rest, true)
    } else {
        return None;
    };

    let token = rest.trim();
    if token.is_empty() {
        return None;
    }

    Some(ExtractedToken {
        token: token.to_string(),
        is_dpop,
    })
}
147
148impl FromRequestParts<AppState> for BearerAuth {
149 type Rejection = AuthError;
150
151 async fn from_request_parts(
152 parts: &mut Parts,
153 state: &AppState,
154 ) -> Result<Self, Self::Rejection> {
155 let auth_header = parts
156 .headers
157 .get(AUTHORIZATION)
158 .ok_or(AuthError::MissingToken)?
159 .to_str()
160 .map_err(|_| AuthError::InvalidFormat)?;
161
162 let extracted =
163 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
164
165 if extracted.is_dpop {
166 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
167 let method = parts.method.as_str();
168 let uri = build_full_url(&parts.uri.to_string());
169
170 match validate_token_with_dpop(
171 &state.db,
172 &extracted.token,
173 true,
174 dpop_proof,
175 method,
176 &uri,
177 false,
178 )
179 .await
180 {
181 Ok(user) => Ok(BearerAuth(user)),
182 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
183 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
184 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
185 Err(_) => Err(AuthError::AuthenticationFailed),
186 }
187 } else {
188 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
189 Ok(user) => Ok(BearerAuth(user)),
190 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
191 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
192 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
193 Err(_) => Err(AuthError::AuthenticationFailed),
194 }
195 }
196 }
197}
198
199pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
200
201impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
202 type Rejection = AuthError;
203
204 async fn from_request_parts(
205 parts: &mut Parts,
206 state: &AppState,
207 ) -> Result<Self, Self::Rejection> {
208 let auth_header = parts
209 .headers
210 .get(AUTHORIZATION)
211 .ok_or(AuthError::MissingToken)?
212 .to_str()
213 .map_err(|_| AuthError::InvalidFormat)?;
214
215 let extracted =
216 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
217
218 if extracted.is_dpop {
219 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
220 let method = parts.method.as_str();
221 let uri = build_full_url(&parts.uri.to_string());
222
223 match validate_token_with_dpop(
224 &state.db,
225 &extracted.token,
226 true,
227 dpop_proof,
228 method,
229 &uri,
230 true,
231 )
232 .await
233 {
234 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
235 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
236 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
237 Err(_) => Err(AuthError::AuthenticationFailed),
238 }
239 } else {
240 match validate_bearer_token_cached_allow_deactivated(
241 &state.db,
242 &state.cache,
243 &extracted.token,
244 )
245 .await
246 {
247 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
248 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
249 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
250 Err(_) => Err(AuthError::AuthenticationFailed),
251 }
252 }
253 }
254}
255
256pub struct BearerAuthAdmin(pub AuthenticatedUser);
257
258impl FromRequestParts<AppState> for BearerAuthAdmin {
259 type Rejection = AuthError;
260
261 async fn from_request_parts(
262 parts: &mut Parts,
263 state: &AppState,
264 ) -> Result<Self, Self::Rejection> {
265 let auth_header = parts
266 .headers
267 .get(AUTHORIZATION)
268 .ok_or(AuthError::MissingToken)?
269 .to_str()
270 .map_err(|_| AuthError::InvalidFormat)?;
271
272 let extracted =
273 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
274
275 let user = if extracted.is_dpop {
276 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
277 let method = parts.method.as_str();
278 let uri = build_full_url(&parts.uri.to_string());
279
280 match validate_token_with_dpop(
281 &state.db,
282 &extracted.token,
283 true,
284 dpop_proof,
285 method,
286 &uri,
287 false,
288 )
289 .await
290 {
291 Ok(user) => user,
292 Err(TokenValidationError::AccountDeactivated) => {
293 return Err(AuthError::AccountDeactivated);
294 }
295 Err(TokenValidationError::AccountTakedown) => {
296 return Err(AuthError::AccountTakedown);
297 }
298 Err(TokenValidationError::TokenExpired) => {
299 return Err(AuthError::TokenExpired);
300 }
301 Err(_) => return Err(AuthError::AuthenticationFailed),
302 }
303 } else {
304 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
305 Ok(user) => user,
306 Err(TokenValidationError::AccountDeactivated) => {
307 return Err(AuthError::AccountDeactivated);
308 }
309 Err(TokenValidationError::AccountTakedown) => {
310 return Err(AuthError::AccountTakedown);
311 }
312 Err(TokenValidationError::TokenExpired) => {
313 return Err(AuthError::TokenExpired);
314 }
315 Err(_) => return Err(AuthError::AuthenticationFailed),
316 }
317 };
318
319 if !user.is_admin {
320 return Err(AuthError::AdminRequired);
321 }
322 Ok(BearerAuthAdmin(user))
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of the accepted and rejected header shapes.
    #[test]
    fn test_extract_bearer_token() {
        // Headers that must yield the token "abc123".
        // NOTE(review): the fourth entry repeats the first; it may have been
        // meant to cover extra inner whitespace — confirm.
        let accepted = [
            "Bearer abc123",
            "bearer abc123",
            "BEARER abc123",
            "Bearer abc123",
            " Bearer abc123 ",
        ];
        for header in accepted {
            assert_eq!(extract_bearer_token(header).unwrap(), "abc123");
        }

        // Headers that must be rejected.
        let rejected = ["Basic abc123", "Bearer", "Bearer ", "abc123", ""];
        for header in rejected {
            assert!(extract_bearer_token(header).is_err());
        }
    }
}