This repository has no description.
1use axum::{
2 extract::FromRequestParts,
3 http::{header::AUTHORIZATION, request::Parts},
4 response::{IntoResponse, Response},
5};
6
7use super::{
8 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
9 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop,
10};
11use crate::api::error::ApiError;
12use crate::state::AppState;
13use crate::util::build_full_url;
14
/// Axum extractor yielding the [`AuthenticatedUser`] behind a request's
/// `Authorization` header, accepting either the `Bearer` or the `DPoP` scheme.
///
/// Extraction fails with an [`AuthError`] when the header is missing or
/// malformed, or when token validation fails (including deactivated and
/// taken-down accounts and expired tokens).
pub struct BearerAuth(pub AuthenticatedUser);
16
/// Rejection produced by the authentication extractors in this module.
///
/// It is turned into an HTTP response by converting it into `ApiError`
/// (see the `IntoResponse` impl for this type); the resulting status and body
/// are decided by `ApiError`, not here.
#[derive(Debug)]
pub enum AuthError {
    /// The request carried no `Authorization` header.
    MissingToken,
    /// The `Authorization` header could not be read as a string, did not use the
    /// `Bearer` or `DPoP` scheme, or carried an empty token.
    InvalidFormat,
    /// Token validation failed for a reason without a more specific variant.
    AuthenticationFailed,
    /// The token validator reported the token as expired.
    TokenExpired,
    /// The token validator reported the account as deactivated.
    AccountDeactivated,
    /// The token validator reported the account as taken down.
    AccountTakedown,
    /// The token was valid but the user is not an administrator
    /// (produced only by the `BearerAuthAdmin` extractor).
    AdminRequired,
}
27
28impl IntoResponse for AuthError {
29 fn into_response(self) -> Response {
30 ApiError::from(self).into_response()
31 }
32}
33
/// Test-only helper: returns the borrowed token from an `Authorization: Bearer <token>`
/// header value.
///
/// The scheme name is matched case-insensitively and whitespace around the header and
/// around the token is ignored. Any input that does not start with the scheme, or whose
/// token is empty, yields [`AuthError::InvalidFormat`]. Input whose seventh byte falls
/// inside a multi-byte character is rejected rather than panicking on the slice.
#[cfg(test)]
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
    const SCHEME: &str = "bearer ";

    let auth_header = auth_header.trim();
    // `str::get` returns `None` (instead of panicking like `&s[..n]`) when `n` is not
    // a UTF-8 char boundary, so non-ASCII headers are rejected cleanly.
    let token = auth_header
        .get(..SCHEME.len())
        .filter(|prefix| prefix.eq_ignore_ascii_case(SCHEME))
        // The successful `get` above proves `SCHEME.len()` is a char boundary.
        .map(|_| auth_header[SCHEME.len()..].trim())
        .ok_or(AuthError::InvalidFormat)?;

    if token.is_empty() {
        return Err(AuthError::InvalidFormat);
    }
    Ok(token)
}
54
/// Extracts the token from an `Authorization: Bearer <token>` header value.
///
/// The scheme name is matched case-insensitively, and whitespace around the whole
/// header and around the token is ignored.
///
/// Returns `None` when the header is absent, uses any other scheme (including
/// `DPoP`), or carries an empty token. Input whose seventh byte falls inside a
/// multi-byte character is rejected instead of panicking on the slice.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    const SCHEME: &str = "bearer ";

    let header = auth_header?.trim();
    // `str::get` returns `None` rather than panicking when `SCHEME.len()` is not a
    // UTF-8 char boundary; it also covers headers shorter than the scheme.
    let prefix = header.get(..SCHEME.len())?;
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    // Safe to slice: the `get` above proved this index is a char boundary.
    let token = header[SCHEME.len()..].trim();
    if token.is_empty() {
        None
    } else {
        Some(token.to_string())
    }
}
74
/// An access token taken from an `Authorization` header, together with the
/// scheme it was presented under.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedToken {
    /// The token text with surrounding whitespace removed; never empty.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}

/// Parses an `Authorization` header value that uses either the `Bearer` or the
/// `DPoP` scheme.
///
/// Scheme names are matched case-insensitively, and whitespace around the header
/// and around the token is ignored. Returns `None` when the header is absent, uses
/// any other scheme, or carries an empty token. Input whose scheme-length prefix
/// ends inside a multi-byte character is rejected instead of panicking.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    /// Returns the trimmed text after `scheme` (lowercase ASCII name plus a trailing
    /// space) when `header` starts with it case-insensitively, otherwise `None`.
    fn strip_scheme<'a>(header: &'a str, scheme: &str) -> Option<&'a str> {
        // `get` yields `None` instead of panicking when `scheme.len()` is not a
        // UTF-8 char boundary or exceeds the header length.
        let prefix = header.get(..scheme.len())?;
        if !prefix.eq_ignore_ascii_case(scheme) {
            return None;
        }
        // Safe to slice: the `get` above proved this index is a char boundary.
        Some(header[scheme.len()..].trim())
    }

    let header = auth_header?.trim();

    let (token, is_dpop) = if let Some(token) = strip_scheme(header, "bearer ") {
        (token, false)
    } else if let Some(token) = strip_scheme(header, "dpop ") {
        (token, true)
    } else {
        return None;
    };

    if token.is_empty() {
        return None;
    }
    Some(ExtractedToken {
        token: token.to_string(),
        is_dpop,
    })
}
108
109impl FromRequestParts<AppState> for BearerAuth {
110 type Rejection = AuthError;
111
112 async fn from_request_parts(
113 parts: &mut Parts,
114 state: &AppState,
115 ) -> Result<Self, Self::Rejection> {
116 let auth_header = parts
117 .headers
118 .get(AUTHORIZATION)
119 .ok_or(AuthError::MissingToken)?
120 .to_str()
121 .map_err(|_| AuthError::InvalidFormat)?;
122
123 let extracted =
124 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
125
126 if extracted.is_dpop {
127 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
128 let method = parts.method.as_str();
129 let uri = build_full_url(&parts.uri.to_string());
130
131 match validate_token_with_dpop(
132 &state.db,
133 &extracted.token,
134 true,
135 dpop_proof,
136 method,
137 &uri,
138 false,
139 )
140 .await
141 {
142 Ok(user) => Ok(BearerAuth(user)),
143 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
144 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
145 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
146 Err(_) => Err(AuthError::AuthenticationFailed),
147 }
148 } else {
149 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token)
150 .await
151 {
152 Ok(user) => Ok(BearerAuth(user)),
153 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
154 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
155 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
156 Err(_) => Err(AuthError::AuthenticationFailed),
157 }
158 }
159 }
160}
161
162pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
163
164impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
165 type Rejection = AuthError;
166
167 async fn from_request_parts(
168 parts: &mut Parts,
169 state: &AppState,
170 ) -> Result<Self, Self::Rejection> {
171 let auth_header = parts
172 .headers
173 .get(AUTHORIZATION)
174 .ok_or(AuthError::MissingToken)?
175 .to_str()
176 .map_err(|_| AuthError::InvalidFormat)?;
177
178 let extracted =
179 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
180
181 if extracted.is_dpop {
182 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
183 let method = parts.method.as_str();
184 let uri = build_full_url(&parts.uri.to_string());
185
186 match validate_token_with_dpop(
187 &state.db,
188 &extracted.token,
189 true,
190 dpop_proof,
191 method,
192 &uri,
193 true,
194 )
195 .await
196 {
197 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
198 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
199 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
200 Err(_) => Err(AuthError::AuthenticationFailed),
201 }
202 } else {
203 match validate_bearer_token_cached_allow_deactivated(
204 &state.db,
205 state.cache.as_ref(),
206 &extracted.token,
207 )
208 .await
209 {
210 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
211 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
212 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
213 Err(_) => Err(AuthError::AuthenticationFailed),
214 }
215 }
216 }
217}
218
219pub struct BearerAuthAdmin(pub AuthenticatedUser);
220
221impl FromRequestParts<AppState> for BearerAuthAdmin {
222 type Rejection = AuthError;
223
224 async fn from_request_parts(
225 parts: &mut Parts,
226 state: &AppState,
227 ) -> Result<Self, Self::Rejection> {
228 let auth_header = parts
229 .headers
230 .get(AUTHORIZATION)
231 .ok_or(AuthError::MissingToken)?
232 .to_str()
233 .map_err(|_| AuthError::InvalidFormat)?;
234
235 let extracted =
236 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
237
238 let user = if extracted.is_dpop {
239 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
240 let method = parts.method.as_str();
241 let uri = build_full_url(&parts.uri.to_string());
242
243 match validate_token_with_dpop(
244 &state.db,
245 &extracted.token,
246 true,
247 dpop_proof,
248 method,
249 &uri,
250 false,
251 )
252 .await
253 {
254 Ok(user) => user,
255 Err(TokenValidationError::AccountDeactivated) => {
256 return Err(AuthError::AccountDeactivated);
257 }
258 Err(TokenValidationError::AccountTakedown) => {
259 return Err(AuthError::AccountTakedown);
260 }
261 Err(TokenValidationError::TokenExpired) => {
262 return Err(AuthError::TokenExpired);
263 }
264 Err(_) => return Err(AuthError::AuthenticationFailed),
265 }
266 } else {
267 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token)
268 .await
269 {
270 Ok(user) => user,
271 Err(TokenValidationError::AccountDeactivated) => {
272 return Err(AuthError::AccountDeactivated);
273 }
274 Err(TokenValidationError::AccountTakedown) => {
275 return Err(AuthError::AccountTakedown);
276 }
277 Err(TokenValidationError::TokenExpired) => {
278 return Err(AuthError::TokenExpired);
279 }
280 Err(_) => return Err(AuthError::AuthenticationFailed),
281 }
282 };
283
284 if !user.is_admin {
285 return Err(AuthError::AdminRequired);
286 }
287 Ok(BearerAuthAdmin(user))
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra whitespace between the scheme and the token is tolerated.
        assert_eq!(extract_bearer_token("Bearer   abc123").unwrap(), "abc123");
        // Surrounding whitespace is trimmed.
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");

        for bad in ["Basic abc123", "Bearer", "Bearer ", "abc123", ""] {
            assert!(
                matches!(extract_bearer_token(bad), Err(AuthError::InvalidFormat)),
                "expected InvalidFormat for {bad:?}"
            );
        }
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bearer   abc123 ")).as_deref(),
            Some("abc123")
        );
        assert!(extract_bearer_token_from_header(None).is_none());
        assert!(extract_bearer_token_from_header(Some("Bearer ")).is_none());
        // DPoP tokens are not accepted by the bearer-only extractor.
        assert!(extract_bearer_token_from_header(Some("DPoP abc123")).is_none());
        assert!(extract_bearer_token_from_header(Some("Basic abc123")).is_none());
    }

    #[test]
    fn test_extract_auth_token_from_header() {
        let bearer =
            extract_auth_token_from_header(Some("Bearer abc123")).expect("bearer scheme");
        assert_eq!(bearer.token, "abc123");
        assert!(!bearer.is_dpop);

        let dpop = extract_auth_token_from_header(Some("  dpop   xyz ")).expect("dpop scheme");
        assert_eq!(dpop.token, "xyz");
        assert!(dpop.is_dpop);

        assert!(extract_auth_token_from_header(None).is_none());
        assert!(extract_auth_token_from_header(Some("DPoP ")).is_none());
        assert!(extract_auth_token_from_header(Some("Bearer ")).is_none());
        assert!(extract_auth_token_from_header(Some("Basic abc123")).is_none());
    }
}