this repo has no description
1use axum::{
2 extract::FromRequestParts,
3 http::{StatusCode, request::Parts, header::AUTHORIZATION},
4 response::{IntoResponse, Response},
5 Json,
6};
7use serde_json::json;
8
9use crate::state::AppState;
10use super::{AuthenticatedUser, validate_bearer_token};
11
/// Axum extractor that authenticates a request from its `Authorization: Bearer <token>` header.
///
/// On success it wraps the [`AuthenticatedUser`] produced by `validate_bearer_token`;
/// on failure the request is rejected with an [`AuthError`].
pub struct BearerAuth(pub AuthenticatedUser);
13
/// Reasons the bearer-token extractor rejects a request.
///
/// Each variant is rendered as a `401 Unauthorized` response whose JSON body
/// carries a distinct `error` code and a human-readable `message`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// No `Authorization` header was present on the request.
    MissingToken,
    /// The `Authorization` header could not be read as a string, or is not of
    /// the form `Bearer <token>` with a non-empty token.
    InvalidFormat,
    /// Token validation failed for a reason other than the account states below.
    AuthenticationFailed,
    /// Token validation reported that the account is deactivated.
    AccountDeactivated,
    /// Token validation reported that the account has been taken down.
    AccountTakedown,
}
22
23impl IntoResponse for AuthError {
24 fn into_response(self) -> Response {
25 let (status, error, message) = match self {
26 AuthError::MissingToken => (
27 StatusCode::UNAUTHORIZED,
28 "AuthenticationRequired",
29 "Authorization header is required",
30 ),
31 AuthError::InvalidFormat => (
32 StatusCode::UNAUTHORIZED,
33 "InvalidToken",
34 "Invalid authorization header format",
35 ),
36 AuthError::AuthenticationFailed => (
37 StatusCode::UNAUTHORIZED,
38 "AuthenticationFailed",
39 "Invalid or expired token",
40 ),
41 AuthError::AccountDeactivated => (
42 StatusCode::UNAUTHORIZED,
43 "AccountDeactivated",
44 "Account is deactivated",
45 ),
46 AuthError::AccountTakedown => (
47 StatusCode::UNAUTHORIZED,
48 "AccountTakedown",
49 "Account has been taken down",
50 ),
51 };
52
53 (status, Json(json!({ "error": error, "message": message }))).into_response()
54 }
55}
56
57fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
58 let auth_header = auth_header.trim();
59
60 if auth_header.len() < 8 {
61 return Err(AuthError::InvalidFormat);
62 }
63
64 let prefix = &auth_header[..7];
65 if !prefix.eq_ignore_ascii_case("bearer ") {
66 return Err(AuthError::InvalidFormat);
67 }
68
69 let token = auth_header[7..].trim();
70 if token.is_empty() {
71 return Err(AuthError::InvalidFormat);
72 }
73
74 Ok(token)
75}
76
/// Returns the bearer token carried by an optional `Authorization` header value.
///
/// Whitespace around the value is ignored. The `Bearer` scheme is matched
/// case-insensitively and must be followed by an ASCII space, and the token itself
/// must be non-empty after trimming. Returns `None` when the header is absent or
/// malformed. Never panics, including on non-ASCII input.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    const SCHEME: &str = "bearer ";

    let header = auth_header?.trim();

    // `get` yields `None` instead of panicking when the scheme length falls
    // inside a multi-byte UTF-8 character (e.g. "Beareré ...").
    let scheme = header.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    // The successful `get` above proves `SCHEME.len()` is a char boundary.
    let token = header[SCHEME.len()..].trim();
    (!token.is_empty()).then(|| token.to_owned())
}
96
97impl FromRequestParts<AppState> for BearerAuth {
98 type Rejection = AuthError;
99
100 async fn from_request_parts(
101 parts: &mut Parts,
102 state: &AppState,
103 ) -> Result<Self, Self::Rejection> {
104 let auth_header = parts
105 .headers
106 .get(AUTHORIZATION)
107 .ok_or(AuthError::MissingToken)?
108 .to_str()
109 .map_err(|_| AuthError::InvalidFormat)?;
110
111 let token = extract_bearer_token(auth_header)?;
112
113 match validate_bearer_token(&state.db, token).await {
114 Ok(user) => Ok(BearerAuth(user)),
115 Err("AccountDeactivated") => Err(AuthError::AccountDeactivated),
116 Err("AccountTakedown") => Err(AuthError::AccountTakedown),
117 Err(_) => Err(AuthError::AuthenticationFailed),
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        // The scheme is matched case-insensitively.
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra whitespace after the scheme, and around the whole value, is trimmed.
        assert_eq!(extract_bearer_token("Bearer   abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");

        // Every malformed shape is reported specifically as `InvalidFormat`.
        for bad in ["Basic abc123", "Bearer", "Bearer ", "abc123", ""] {
            assert!(
                matches!(extract_bearer_token(bad), Err(AuthError::InvalidFormat)),
                "{bad:?}"
            );
        }
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(extract_bearer_token_from_header(None), None);
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bearer   abc123  ")).as_deref(),
            Some("abc123")
        );
        for bad in ["Basic abc123", "Bearer", "Bearer ", "abc123", ""] {
            assert_eq!(extract_bearer_token_from_header(Some(bad)), None, "{bad:?}");
        }
    }

    #[test]
    fn auth_errors_render_as_unauthorized() {
        for err in [
            AuthError::MissingToken,
            AuthError::InvalidFormat,
            AuthError::AuthenticationFailed,
            AuthError::AccountDeactivated,
            AuthError::AccountTakedown,
        ] {
            assert_eq!(err.into_response().status(), StatusCode::UNAUTHORIZED);
        }
    }
}