this repo has no description
1use axum::{
2 Json,
3 extract::FromRequestParts,
4 http::{StatusCode, header::AUTHORIZATION, request::Parts},
5 response::{IntoResponse, Response},
6};
7use serde_json::json;
8
9use super::{
10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
11 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop,
12};
13use crate::state::AppState;
14
15pub struct BearerAuth(pub AuthenticatedUser);
16
/// Reasons an authentication extractor in this module rejects a request.
///
/// Each variant is rendered as a JSON error response by the `IntoResponse` implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// No `Authorization` header was sent.
    MissingToken,
    /// The `Authorization` header could not be read as a string, or did not contain a
    /// supported scheme followed by a non-empty token.
    InvalidFormat,
    /// The token failed validation for a reason without a more specific variant.
    AuthenticationFailed,
    /// The token has expired.
    TokenExpired,
    /// The account behind the token is deactivated.
    AccountDeactivated,
    /// The account behind the token has been taken down.
    AccountTakedown,
    /// The token is valid but the account is not an administrator.
    AdminRequired,
}
27
28impl IntoResponse for AuthError {
29 fn into_response(self) -> Response {
30 let (status, error, message) = match self {
31 AuthError::MissingToken => (
32 StatusCode::UNAUTHORIZED,
33 "AuthenticationRequired",
34 "Authorization header is required",
35 ),
36 AuthError::InvalidFormat => (
37 StatusCode::UNAUTHORIZED,
38 "InvalidToken",
39 "Invalid authorization header format",
40 ),
41 AuthError::AuthenticationFailed => (
42 StatusCode::UNAUTHORIZED,
43 "InvalidToken",
44 "Token could not be verified",
45 ),
46 AuthError::TokenExpired => (
47 StatusCode::UNAUTHORIZED,
48 "ExpiredToken",
49 "Token has expired",
50 ),
51 AuthError::AccountDeactivated => (
52 StatusCode::UNAUTHORIZED,
53 "AccountDeactivated",
54 "Account is deactivated",
55 ),
56 AuthError::AccountTakedown => (
57 StatusCode::UNAUTHORIZED,
58 "AccountTakedown",
59 "Account has been taken down",
60 ),
61 AuthError::AdminRequired => (
62 StatusCode::FORBIDDEN,
63 "AdminRequired",
64 "This action requires admin privileges",
65 ),
66 };
67
68 (status, Json(json!({ "error": error, "message": message }))).into_response()
69 }
70}
71
/// Test-only parser for an `Authorization: Bearer <token>` header value.
///
/// The `Bearer` scheme is matched case-insensitively, and whitespace around the header and
/// around the token is ignored. The token is returned as a slice of the input.
///
/// # Errors
///
/// Returns [`AuthError::InvalidFormat`] when the scheme is not `Bearer` or the token is empty.
#[cfg(test)]
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
    let auth_header = auth_header.trim();

    // `get` yields `None` instead of panicking when byte 7 falls inside a multi-byte
    // UTF-8 character, so arbitrary input is rejected rather than crashing.
    let is_bearer = auth_header
        .get(..7)
        .is_some_and(|scheme| scheme.eq_ignore_ascii_case("bearer "));
    if !is_bearer {
        return Err(AuthError::InvalidFormat);
    }

    // Index 7 is a char boundary: `get(..7)` above succeeded.
    let token = auth_header[7..].trim();
    if token.is_empty() {
        return Err(AuthError::InvalidFormat);
    }
    Ok(token)
}
92
/// Extracts the token from an `Authorization: Bearer <token>` header value.
///
/// The `Bearer` scheme is matched case-insensitively, and whitespace around the header and
/// around the token is ignored.
///
/// Returns `None` in these cases:
/// - the header is absent;
/// - it uses another scheme (including `DPoP`);
/// - the token is empty.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    let header = auth_header?.trim();

    // `get` returns `None` rather than panicking when byte 7 lands inside a multi-byte
    // UTF-8 character; this function is public and may receive arbitrary strings.
    let scheme = header.get(..7)?;
    if !scheme.eq_ignore_ascii_case("bearer ") {
        return None;
    }

    // Index 7 is a char boundary because `get(..7)` succeeded.
    let token = header[7..].trim();
    (!token.is_empty()).then(|| token.to_string())
}
112
/// A credential taken from an `Authorization` header by `extract_auth_token_from_header`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedToken {
    /// The token text with the scheme prefix and surrounding whitespace removed.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}
117
118pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
119 let header = auth_header?;
120 let header = header.trim();
121
122 if header.len() >= 7 && header[..7].eq_ignore_ascii_case("bearer ") {
123 let token = header[7..].trim();
124 if token.is_empty() {
125 return None;
126 }
127 return Some(ExtractedToken {
128 token: token.to_string(),
129 is_dpop: false,
130 });
131 }
132
133 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") {
134 let token = header[5..].trim();
135 if token.is_empty() {
136 return None;
137 }
138 return Some(ExtractedToken {
139 token: token.to_string(),
140 is_dpop: true,
141 });
142 }
143
144 None
145}
146
147impl FromRequestParts<AppState> for BearerAuth {
148 type Rejection = AuthError;
149
150 async fn from_request_parts(
151 parts: &mut Parts,
152 state: &AppState,
153 ) -> Result<Self, Self::Rejection> {
154 let auth_header = parts
155 .headers
156 .get(AUTHORIZATION)
157 .ok_or(AuthError::MissingToken)?
158 .to_str()
159 .map_err(|_| AuthError::InvalidFormat)?;
160
161 let extracted =
162 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
163
164 if extracted.is_dpop {
165 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
166 let method = parts.method.as_str();
167 let uri = parts.uri.to_string();
168
169 match validate_token_with_dpop(
170 &state.db,
171 &extracted.token,
172 true,
173 dpop_proof,
174 method,
175 &uri,
176 false,
177 )
178 .await
179 {
180 Ok(user) => Ok(BearerAuth(user)),
181 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
182 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
183 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
184 Err(_) => Err(AuthError::AuthenticationFailed),
185 }
186 } else {
187 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
188 Ok(user) => Ok(BearerAuth(user)),
189 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
190 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
191 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
192 Err(_) => Err(AuthError::AuthenticationFailed),
193 }
194 }
195 }
196}
197
198pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
199
200impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
201 type Rejection = AuthError;
202
203 async fn from_request_parts(
204 parts: &mut Parts,
205 state: &AppState,
206 ) -> Result<Self, Self::Rejection> {
207 let auth_header = parts
208 .headers
209 .get(AUTHORIZATION)
210 .ok_or(AuthError::MissingToken)?
211 .to_str()
212 .map_err(|_| AuthError::InvalidFormat)?;
213
214 let extracted =
215 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
216
217 if extracted.is_dpop {
218 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
219 let method = parts.method.as_str();
220 let uri = parts.uri.to_string();
221
222 match validate_token_with_dpop(
223 &state.db,
224 &extracted.token,
225 true,
226 dpop_proof,
227 method,
228 &uri,
229 true,
230 )
231 .await
232 {
233 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
234 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
235 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
236 Err(_) => Err(AuthError::AuthenticationFailed),
237 }
238 } else {
239 match validate_bearer_token_cached_allow_deactivated(
240 &state.db,
241 &state.cache,
242 &extracted.token,
243 )
244 .await
245 {
246 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
247 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
248 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
249 Err(_) => Err(AuthError::AuthenticationFailed),
250 }
251 }
252 }
253}
254
255pub struct BearerAuthAdmin(pub AuthenticatedUser);
256
257impl FromRequestParts<AppState> for BearerAuthAdmin {
258 type Rejection = AuthError;
259
260 async fn from_request_parts(
261 parts: &mut Parts,
262 state: &AppState,
263 ) -> Result<Self, Self::Rejection> {
264 let auth_header = parts
265 .headers
266 .get(AUTHORIZATION)
267 .ok_or(AuthError::MissingToken)?
268 .to_str()
269 .map_err(|_| AuthError::InvalidFormat)?;
270
271 let extracted =
272 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
273
274 let user = if extracted.is_dpop {
275 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
276 let method = parts.method.as_str();
277 let uri = parts.uri.to_string();
278
279 match validate_token_with_dpop(
280 &state.db,
281 &extracted.token,
282 true,
283 dpop_proof,
284 method,
285 &uri,
286 false,
287 )
288 .await
289 {
290 Ok(user) => user,
291 Err(TokenValidationError::AccountDeactivated) => {
292 return Err(AuthError::AccountDeactivated);
293 }
294 Err(TokenValidationError::AccountTakedown) => {
295 return Err(AuthError::AccountTakedown);
296 }
297 Err(TokenValidationError::TokenExpired) => {
298 return Err(AuthError::TokenExpired);
299 }
300 Err(_) => return Err(AuthError::AuthenticationFailed),
301 }
302 } else {
303 match validate_bearer_token_cached(&state.db, &state.cache, &extracted.token).await {
304 Ok(user) => user,
305 Err(TokenValidationError::AccountDeactivated) => {
306 return Err(AuthError::AccountDeactivated);
307 }
308 Err(TokenValidationError::AccountTakedown) => {
309 return Err(AuthError::AccountTakedown);
310 }
311 Err(TokenValidationError::TokenExpired) => {
312 return Err(AuthError::TokenExpired);
313 }
314 Err(_) => return Err(AuthError::AuthenticationFailed),
315 }
316 };
317
318 if !user.is_admin {
319 return Err(AuthError::AdminRequired);
320 }
321 Ok(BearerAuthAdmin(user))
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra spaces between the scheme and the token are tolerated.
        assert_eq!(extract_bearer_token("Bearer    abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");

        assert!(extract_bearer_token("Basic abc123").is_err());
        assert!(extract_bearer_token("Bearer").is_err());
        assert!(extract_bearer_token("Bearer ").is_err());
        assert!(extract_bearer_token("abc123").is_err());
        assert!(extract_bearer_token("").is_err());
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bearer   abc123  ")).as_deref(),
            Some("abc123")
        );
        assert_eq!(extract_bearer_token_from_header(None), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer ")), None);
        assert_eq!(extract_bearer_token_from_header(Some("DPoP abc123")), None);
    }

    #[test]
    fn test_extract_auth_token_from_header() {
        let bearer = extract_auth_token_from_header(Some("Bearer abc123")).unwrap();
        assert_eq!(bearer.token, "abc123");
        assert!(!bearer.is_dpop);

        let dpop = extract_auth_token_from_header(Some("DPoP xyz789")).unwrap();
        assert_eq!(dpop.token, "xyz789");
        assert!(dpop.is_dpop);

        assert!(extract_auth_token_from_header(Some("dpop ")).is_none());
        assert!(extract_auth_token_from_header(Some("Basic abc123")).is_none());
        assert!(extract_auth_token_from_header(None).is_none());
    }

    #[test]
    fn test_auth_error_status_codes() {
        assert_eq!(
            AuthError::MissingToken.into_response().status(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::TokenExpired.into_response().status(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::AdminRequired.into_response().status(),
            StatusCode::FORBIDDEN
        );
    }
}