This repository has no description.
1use axum::{
2 extract::FromRequestParts,
3 http::{header::AUTHORIZATION, request::Parts},
4 response::{IntoResponse, Response},
5};
6
7use super::{
8 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
9 validate_bearer_token_cached_allow_deactivated, validate_token_with_dpop,
10};
11use crate::api::error::ApiError;
12use crate::state::AppState;
13use crate::util::build_full_url;
14
/// Extractor yielding the [`AuthenticatedUser`] behind the request's `Authorization`
/// header, which may carry either a `Bearer` or a `DPoP` token.
pub struct BearerAuth(pub AuthenticatedUser);
16
/// Rejection produced by the bearer-token extractors in this module.
///
/// It is turned into an HTTP response by converting it into `ApiError`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The request had no `Authorization` header.
    MissingToken,
    /// The `Authorization` header was not a valid string, or it did not carry a
    /// non-empty token under a recognised scheme.
    InvalidFormat,
    /// Token validation failed for a reason without a more specific variant.
    AuthenticationFailed,
    /// Mapped from `TokenValidationError::TokenExpired`.
    TokenExpired,
    /// Mapped from `TokenValidationError::AccountDeactivated`.
    AccountDeactivated,
    /// Mapped from `TokenValidationError::AccountTakedown`.
    AccountTakedown,
    /// The token was accepted, but the authenticated user is not an administrator.
    AdminRequired,
}
27
28impl IntoResponse for AuthError {
29 fn into_response(self) -> Response {
30 ApiError::from(self).into_response()
31 }
32}
33
/// Extracts the token from a `Bearer` `Authorization` header value (test-only helper).
///
/// The scheme is matched ASCII case-insensitively, and whitespace around both the
/// header and the token is ignored.
///
/// # Errors
///
/// Returns [`AuthError::InvalidFormat`] in any of these cases:
/// - the input is shorter than the scheme;
/// - it uses a different scheme;
/// - it carries an empty token.
#[cfg(test)]
fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
    const SCHEME: &str = "bearer ";

    let auth_header = auth_header.trim();

    // `str::get` returns `None` instead of panicking when the input is shorter than the
    // scheme, or when the offset lands inside a multi-byte character (non-ASCII input).
    let prefix = auth_header
        .get(..SCHEME.len())
        .ok_or(AuthError::InvalidFormat)?;
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return Err(AuthError::InvalidFormat);
    }

    // Cannot panic: the `get` above proved `SCHEME.len()` is a char boundary.
    let token = auth_header[SCHEME.len()..].trim();
    if token.is_empty() {
        return Err(AuthError::InvalidFormat);
    }

    Ok(token)
}
54
/// Extracts the token from a `Bearer` `Authorization` header value.
///
/// The scheme is matched ASCII case-insensitively, and whitespace around both the
/// header and the token is ignored.
///
/// Returns `None` in any of these cases:
/// - the header is absent;
/// - it uses another scheme (including `DPoP`);
/// - the token is empty.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    const SCHEME: &str = "bearer ";

    let header = auth_header?.trim();

    // `str::get` yields `None` rather than panicking when the header is shorter than the
    // scheme, or when the offset falls inside a multi-byte character. Such input could
    // never match the ASCII scheme anyway.
    let prefix = header.get(..SCHEME.len())?;
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    let token = header.get(SCHEME.len()..)?.trim();
    if token.is_empty() {
        return None;
    }

    Some(token.to_string())
}
74
/// A credential parsed out of an `Authorization` header value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedToken {
    /// Token text following the scheme, with surrounding whitespace trimmed.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}

/// Parses an `Authorization` header value that uses either the `Bearer` or `DPoP` scheme.
///
/// Scheme names are matched ASCII case-insensitively, and whitespace around both the
/// header and the token is ignored.
///
/// Returns `None` in any of these cases:
/// - the header is absent;
/// - it uses any other scheme;
/// - the token is empty.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    /// Returns the text after `scheme` when `header` starts with it (ASCII
    /// case-insensitive).
    ///
    /// Uses `str::get`, so a header shorter than the scheme, or a scheme-length offset
    /// inside a multi-byte character, yields `None` instead of panicking.
    fn strip_scheme<'a>(header: &'a str, scheme: &str) -> Option<&'a str> {
        let prefix = header.get(..scheme.len())?;
        if prefix.eq_ignore_ascii_case(scheme) {
            header.get(scheme.len()..)
        } else {
            None
        }
    }

    let header = auth_header?.trim();

    let (rest, is_dpop) = if let Some(rest) = strip_scheme(header, "bearer ") {
        (rest, false)
    } else if let Some(rest) = strip_scheme(header, "dpop ") {
        (rest, true)
    } else {
        return None;
    };

    let token = rest.trim();
    if token.is_empty() {
        return None;
    }

    Some(ExtractedToken {
        token: token.to_string(),
        is_dpop,
    })
}
108
109impl FromRequestParts<AppState> for BearerAuth {
110 type Rejection = AuthError;
111
112 async fn from_request_parts(
113 parts: &mut Parts,
114 state: &AppState,
115 ) -> Result<Self, Self::Rejection> {
116 let auth_header = parts
117 .headers
118 .get(AUTHORIZATION)
119 .ok_or(AuthError::MissingToken)?
120 .to_str()
121 .map_err(|_| AuthError::InvalidFormat)?;
122
123 let extracted =
124 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
125
126 if extracted.is_dpop {
127 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
128 let method = parts.method.as_str();
129 let uri = build_full_url(&parts.uri.to_string());
130
131 match validate_token_with_dpop(
132 &state.db,
133 &extracted.token,
134 true,
135 dpop_proof,
136 method,
137 &uri,
138 false,
139 )
140 .await
141 {
142 Ok(user) => Ok(BearerAuth(user)),
143 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
144 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
145 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
146 Err(_) => Err(AuthError::AuthenticationFailed),
147 }
148 } else {
149 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token).await {
150 Ok(user) => Ok(BearerAuth(user)),
151 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
152 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
153 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
154 Err(_) => Err(AuthError::AuthenticationFailed),
155 }
156 }
157 }
158}
159
160pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
161
162impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
163 type Rejection = AuthError;
164
165 async fn from_request_parts(
166 parts: &mut Parts,
167 state: &AppState,
168 ) -> Result<Self, Self::Rejection> {
169 let auth_header = parts
170 .headers
171 .get(AUTHORIZATION)
172 .ok_or(AuthError::MissingToken)?
173 .to_str()
174 .map_err(|_| AuthError::InvalidFormat)?;
175
176 let extracted =
177 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
178
179 if extracted.is_dpop {
180 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
181 let method = parts.method.as_str();
182 let uri = build_full_url(&parts.uri.to_string());
183
184 match validate_token_with_dpop(
185 &state.db,
186 &extracted.token,
187 true,
188 dpop_proof,
189 method,
190 &uri,
191 true,
192 )
193 .await
194 {
195 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
196 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
197 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
198 Err(_) => Err(AuthError::AuthenticationFailed),
199 }
200 } else {
201 match validate_bearer_token_cached_allow_deactivated(
202 &state.db,
203 state.cache.as_ref(),
204 &extracted.token,
205 )
206 .await
207 {
208 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
209 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
210 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired),
211 Err(_) => Err(AuthError::AuthenticationFailed),
212 }
213 }
214 }
215}
216
217pub struct BearerAuthAdmin(pub AuthenticatedUser);
218
219impl FromRequestParts<AppState> for BearerAuthAdmin {
220 type Rejection = AuthError;
221
222 async fn from_request_parts(
223 parts: &mut Parts,
224 state: &AppState,
225 ) -> Result<Self, Self::Rejection> {
226 let auth_header = parts
227 .headers
228 .get(AUTHORIZATION)
229 .ok_or(AuthError::MissingToken)?
230 .to_str()
231 .map_err(|_| AuthError::InvalidFormat)?;
232
233 let extracted =
234 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?;
235
236 let user = if extracted.is_dpop {
237 let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok());
238 let method = parts.method.as_str();
239 let uri = build_full_url(&parts.uri.to_string());
240
241 match validate_token_with_dpop(
242 &state.db,
243 &extracted.token,
244 true,
245 dpop_proof,
246 method,
247 &uri,
248 false,
249 )
250 .await
251 {
252 Ok(user) => user,
253 Err(TokenValidationError::AccountDeactivated) => {
254 return Err(AuthError::AccountDeactivated);
255 }
256 Err(TokenValidationError::AccountTakedown) => {
257 return Err(AuthError::AccountTakedown);
258 }
259 Err(TokenValidationError::TokenExpired) => {
260 return Err(AuthError::TokenExpired);
261 }
262 Err(_) => return Err(AuthError::AuthenticationFailed),
263 }
264 } else {
265 match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token).await {
266 Ok(user) => user,
267 Err(TokenValidationError::AccountDeactivated) => {
268 return Err(AuthError::AccountDeactivated);
269 }
270 Err(TokenValidationError::AccountTakedown) => {
271 return Err(AuthError::AccountTakedown);
272 }
273 Err(TokenValidationError::TokenExpired) => {
274 return Err(AuthError::TokenExpired);
275 }
276 Err(_) => return Err(AuthError::AuthenticationFailed),
277 }
278 };
279
280 if !user.is_admin {
281 return Err(AuthError::AdminRequired);
282 }
283 Ok(BearerAuthAdmin(user))
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra spaces between the scheme and the token are tolerated.
        assert_eq!(extract_bearer_token("Bearer   abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");

        assert!(extract_bearer_token("Basic abc123").is_err());
        assert!(extract_bearer_token("Bearer").is_err());
        assert!(extract_bearer_token("Bearer ").is_err());
        assert!(extract_bearer_token("abc123").is_err());
        assert!(extract_bearer_token("").is_err());
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bearer   abc123 ")).as_deref(),
            Some("abc123")
        );

        assert_eq!(extract_bearer_token_from_header(None), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer ")), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer")), None);
        assert_eq!(extract_bearer_token_from_header(Some("DPoP abc123")), None);
        assert_eq!(extract_bearer_token_from_header(Some("Basic abc123")), None);
    }

    #[test]
    fn test_extract_auth_token_from_header() {
        let bearer = extract_auth_token_from_header(Some("Bearer abc123")).unwrap();
        assert_eq!(bearer.token, "abc123");
        assert!(!bearer.is_dpop);

        let dpop = extract_auth_token_from_header(Some("  dpop   xyz789 ")).unwrap();
        assert_eq!(dpop.token, "xyz789");
        assert!(dpop.is_dpop);

        assert!(extract_auth_token_from_header(None).is_none());
        assert!(extract_auth_token_from_header(Some("DPoP ")).is_none());
        assert!(extract_auth_token_from_header(Some("Bearer ")).is_none());
        assert!(extract_auth_token_from_header(Some("Basic abc123")).is_none());
    }
}