this repo has no description
1use axum::{
2 Json,
3 extract::FromRequestParts,
4 http::{StatusCode, header::AUTHORIZATION, request::Parts},
5 response::{IntoResponse, Response},
6};
7use serde_json::json;
8
9use super::{
10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
11 validate_bearer_token_cached_allow_deactivated,
12};
13use crate::state::AppState;
14
15pub struct BearerAuth(pub AuthenticatedUser);
16
/// Rejection produced by the bearer-token extractors in this module.
///
/// Every variant is rendered by the `IntoResponse` impl as a JSON body of the
/// form `{"error": ..., "message": ...}`; `AdminRequired` maps to
/// `403 Forbidden` and all other variants to `401 Unauthorized`.
///
/// All variants are field-less, so the type is cheap to copy and can be
/// compared directly (e.g. `assert_eq!(err, AuthError::InvalidFormat)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The request carried no `Authorization` header.
    MissingToken,
    /// The header was not valid text, did not use the `Bearer` scheme, or
    /// carried an empty token.
    InvalidFormat,
    /// Token validation failed for a reason other than deactivation or takedown.
    AuthenticationFailed,
    /// The token belongs to a deactivated account.
    AccountDeactivated,
    /// The token belongs to an account that has been taken down.
    AccountTakedown,
    /// The caller authenticated successfully but is not an admin.
    AdminRequired,
}
26
27impl IntoResponse for AuthError {
28 fn into_response(self) -> Response {
29 let (status, error, message) = match self {
30 AuthError::MissingToken => (
31 StatusCode::UNAUTHORIZED,
32 "AuthenticationRequired",
33 "Authorization header is required",
34 ),
35 AuthError::InvalidFormat => (
36 StatusCode::UNAUTHORIZED,
37 "InvalidToken",
38 "Invalid authorization header format",
39 ),
40 AuthError::AuthenticationFailed => (
41 StatusCode::UNAUTHORIZED,
42 "AuthenticationFailed",
43 "Invalid or expired token",
44 ),
45 AuthError::AccountDeactivated => (
46 StatusCode::UNAUTHORIZED,
47 "AccountDeactivated",
48 "Account is deactivated",
49 ),
50 AuthError::AccountTakedown => (
51 StatusCode::UNAUTHORIZED,
52 "AccountTakedown",
53 "Account has been taken down",
54 ),
55 AuthError::AdminRequired => (
56 StatusCode::FORBIDDEN,
57 "AdminRequired",
58 "This action requires admin privileges",
59 ),
60 };
61
62 (status, Json(json!({ "error": error, "message": message }))).into_response()
63 }
64}
65
66fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
67 let auth_header = auth_header.trim();
68
69 if auth_header.len() < 8 {
70 return Err(AuthError::InvalidFormat);
71 }
72
73 let prefix = &auth_header[..7];
74 if !prefix.eq_ignore_ascii_case("bearer ") {
75 return Err(AuthError::InvalidFormat);
76 }
77
78 let token = auth_header[7..].trim();
79 if token.is_empty() {
80 return Err(AuthError::InvalidFormat);
81 }
82
83 Ok(token)
84}
85
/// Extracts the token from an optional `Authorization` header value that uses
/// the `Bearer` scheme (matched case-insensitively).
///
/// Surrounding whitespace on the header and on the token is ignored. Returns
/// `None` when the header is absent, uses a different scheme, or carries an
/// empty token. Arbitrary (including non-ASCII) input never panics.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    const SCHEME: &str = "bearer ";

    let header = auth_header?.trim();

    // `str::get` returns `None` rather than panicking when the header is too
    // short or when byte `SCHEME.len()` is not a UTF-8 char boundary.
    let prefix = header.get(..SCHEME.len())?;
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    let token = header[SCHEME.len()..].trim();
    (!token.is_empty()).then(|| token.to_string())
}
105
/// An access token taken from an `Authorization` header, tagged with the
/// scheme it was presented under.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedToken {
    /// The token text, with surrounding whitespace removed.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}

/// Extracts a token from an optional `Authorization` header value using either
/// the `Bearer` or the `DPoP` scheme (both matched case-insensitively).
///
/// Surrounding whitespace on the header and on the token is ignored. Returns
/// `None` when the header is absent, uses another scheme, or carries an empty
/// token. Arbitrary (including non-ASCII) input never panics.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    /// Returns the trimmed text after `scheme` if `header` starts with it
    /// (ASCII case-insensitive), without panicking on multi-byte input.
    fn strip_scheme<'a>(header: &'a str, scheme: &str) -> Option<&'a str> {
        // `get` returns `None` if the header is too short or the cut would
        // split a multi-byte character; plain indexing would panic there.
        let prefix = header.get(..scheme.len())?;
        if !prefix.eq_ignore_ascii_case(scheme) {
            return None;
        }
        Some(header[scheme.len()..].trim())
    }

    let header = auth_header?.trim();

    let (token, is_dpop) = match strip_scheme(header, "bearer ") {
        Some(rest) => (rest, false),
        None => (strip_scheme(header, "dpop ")?, true),
    };

    if token.is_empty() {
        return None;
    }

    Some(ExtractedToken {
        token: token.to_string(),
        is_dpop,
    })
}
139
140impl FromRequestParts<AppState> for BearerAuth {
141 type Rejection = AuthError;
142
143 async fn from_request_parts(
144 parts: &mut Parts,
145 state: &AppState,
146 ) -> Result<Self, Self::Rejection> {
147 let auth_header = parts
148 .headers
149 .get(AUTHORIZATION)
150 .ok_or(AuthError::MissingToken)?
151 .to_str()
152 .map_err(|_| AuthError::InvalidFormat)?;
153
154 let token = extract_bearer_token(auth_header)?;
155
156 match validate_bearer_token_cached(&state.db, &state.cache, token).await {
157 Ok(user) => Ok(BearerAuth(user)),
158 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
159 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
160 Err(_) => Err(AuthError::AuthenticationFailed),
161 }
162 }
163}
164
165pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
166
167impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
168 type Rejection = AuthError;
169
170 async fn from_request_parts(
171 parts: &mut Parts,
172 state: &AppState,
173 ) -> Result<Self, Self::Rejection> {
174 let auth_header = parts
175 .headers
176 .get(AUTHORIZATION)
177 .ok_or(AuthError::MissingToken)?
178 .to_str()
179 .map_err(|_| AuthError::InvalidFormat)?;
180
181 let token = extract_bearer_token(auth_header)?;
182
183 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
184 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
185 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
186 Err(_) => Err(AuthError::AuthenticationFailed),
187 }
188 }
189}
190
191pub struct BearerAuthAdmin(pub AuthenticatedUser);
192
193impl FromRequestParts<AppState> for BearerAuthAdmin {
194 type Rejection = AuthError;
195
196 async fn from_request_parts(
197 parts: &mut Parts,
198 state: &AppState,
199 ) -> Result<Self, Self::Rejection> {
200 let auth_header = parts
201 .headers
202 .get(AUTHORIZATION)
203 .ok_or(AuthError::MissingToken)?
204 .to_str()
205 .map_err(|_| AuthError::InvalidFormat)?;
206
207 let token = extract_bearer_token(auth_header)?;
208
209 match validate_bearer_token_cached(&state.db, &state.cache, token).await {
210 Ok(user) => {
211 if !user.is_admin {
212 return Err(AuthError::AdminRequired);
213 }
214 Ok(BearerAuthAdmin(user))
215 }
216 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
217 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
218 Err(_) => Err(AuthError::AuthenticationFailed),
219 }
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra whitespace between the scheme and the token is tolerated.
        assert_eq!(extract_bearer_token("Bearer   abc123").unwrap(), "abc123");
        // Whitespace around the whole header is tolerated.
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");

        // Every malformed header must map to the specific `InvalidFormat` variant.
        for bad in [
            "Basic abc123",
            "Bearer",
            "Bearer ",
            "Bearer    ",
            "Bearerabc123",
            "abc123",
            "",
        ] {
            assert!(
                matches!(extract_bearer_token(bad), Err(AuthError::InvalidFormat)),
                "expected InvalidFormat for {bad:?}"
            );
        }
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("bearer  abc123 ")).as_deref(),
            Some("abc123")
        );
        assert_eq!(extract_bearer_token_from_header(None), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer ")), None);
        assert_eq!(extract_bearer_token_from_header(Some("DPoP abc123")), None);
    }

    #[test]
    fn test_extract_auth_token_from_header() {
        let bearer = extract_auth_token_from_header(Some("Bearer abc123")).expect("bearer token");
        assert_eq!(bearer.token, "abc123");
        assert!(!bearer.is_dpop);

        let dpop = extract_auth_token_from_header(Some("dpop xyz")).expect("dpop token");
        assert_eq!(dpop.token, "xyz");
        assert!(dpop.is_dpop);

        assert!(extract_auth_token_from_header(None).is_none());
        assert!(extract_auth_token_from_header(Some("DPoP ")).is_none());
        assert!(extract_auth_token_from_header(Some("Basic abc123")).is_none());
    }
}