this repo has no description
1use axum::{
2 Json,
3 extract::FromRequestParts,
4 http::{StatusCode, header::AUTHORIZATION, request::Parts},
5 response::{IntoResponse, Response},
6};
7use serde_json::json;
8
9use super::{
10 AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
11 validate_bearer_token_cached_allow_deactivated,
12};
13use crate::state::AppState;
14
15pub struct BearerAuth(pub AuthenticatedUser);
16
/// Reasons the bearer-token extractors in this module reject a request.
///
/// Every variant is rendered as a `401 Unauthorized` JSON body by its `IntoResponse`
/// implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// No `Authorization` header was present.
    MissingToken,
    /// The `Authorization` header was not valid header text, or was not a well-formed
    /// `Bearer <token>` value.
    InvalidFormat,
    /// Token validation failed for a reason not covered by the account-state variants.
    AuthenticationFailed,
    /// Token validation reported that the account is deactivated.
    AccountDeactivated,
    /// Token validation reported that the account has been taken down.
    AccountTakedown,
}
25
26impl IntoResponse for AuthError {
27 fn into_response(self) -> Response {
28 let (status, error, message) = match self {
29 AuthError::MissingToken => (
30 StatusCode::UNAUTHORIZED,
31 "AuthenticationRequired",
32 "Authorization header is required",
33 ),
34 AuthError::InvalidFormat => (
35 StatusCode::UNAUTHORIZED,
36 "InvalidToken",
37 "Invalid authorization header format",
38 ),
39 AuthError::AuthenticationFailed => (
40 StatusCode::UNAUTHORIZED,
41 "AuthenticationFailed",
42 "Invalid or expired token",
43 ),
44 AuthError::AccountDeactivated => (
45 StatusCode::UNAUTHORIZED,
46 "AccountDeactivated",
47 "Account is deactivated",
48 ),
49 AuthError::AccountTakedown => (
50 StatusCode::UNAUTHORIZED,
51 "AccountTakedown",
52 "Account has been taken down",
53 ),
54 };
55
56 (status, Json(json!({ "error": error, "message": message }))).into_response()
57 }
58}
59
60fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
61 let auth_header = auth_header.trim();
62
63 if auth_header.len() < 8 {
64 return Err(AuthError::InvalidFormat);
65 }
66
67 let prefix = &auth_header[..7];
68 if !prefix.eq_ignore_ascii_case("bearer ") {
69 return Err(AuthError::InvalidFormat);
70 }
71
72 let token = auth_header[7..].trim();
73 if token.is_empty() {
74 return Err(AuthError::InvalidFormat);
75 }
76
77 Ok(token)
78}
79
/// Returns the bearer token carried by an optional `Authorization` header value.
///
/// The `Bearer` scheme is matched ASCII case-insensitively, and surrounding whitespace is
/// ignored. Returns `None` in three cases:
/// - the header is absent;
/// - the header does not start with `Bearer `;
/// - the token is empty.
///
/// Never panics, including on non-ASCII input.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    const PREFIX: &str = "bearer ";

    let header = auth_header?.trim();

    // `get` returns `None` (instead of panicking like `header[..7]`) when the header is
    // too short or byte 7 is not a UTF-8 char boundary.
    let scheme = header.get(..PREFIX.len())?;
    if !scheme.eq_ignore_ascii_case(PREFIX) {
        return None;
    }

    // `get` succeeded above, so `PREFIX.len()` is a valid char boundary.
    let token = header[PREFIX.len()..].trim();
    (!token.is_empty()).then(|| token.to_string())
}
99
/// An access token taken from an `Authorization` header, tagged with the scheme it used.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedToken {
    /// The token text with surrounding whitespace removed; never empty when produced by
    /// `extract_auth_token_from_header`.
    pub token: String,
    /// `true` when the header used the `DPoP` scheme, `false` for `Bearer`.
    pub is_dpop: bool,
}

/// Extracts an access token from an optional `Authorization` header value.
///
/// Accepts the `Bearer` and `DPoP` schemes, both matched ASCII case-insensitively.
/// Surrounding whitespace is ignored. Returns `None` in three cases:
/// - the header is absent;
/// - the header uses any other scheme;
/// - the token is empty.
///
/// Never panics, including on non-ASCII input.
pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> {
    /// Returns the trimmed remainder of `header` when it starts with `scheme`
    /// (ASCII case-insensitive). `get` yields `None` instead of panicking when the
    /// header is shorter than `scheme` or the offset is not a UTF-8 char boundary.
    fn strip_scheme<'a>(header: &'a str, scheme: &str) -> Option<&'a str> {
        let prefix = header.get(..scheme.len())?;
        prefix
            .eq_ignore_ascii_case(scheme)
            .then(|| header[scheme.len()..].trim())
    }

    let header = auth_header?.trim();

    let (token, is_dpop) = if let Some(token) = strip_scheme(header, "bearer ") {
        (token, false)
    } else if let Some(token) = strip_scheme(header, "dpop ") {
        (token, true)
    } else {
        return None;
    };

    if token.is_empty() {
        return None;
    }

    Some(ExtractedToken {
        token: token.to_string(),
        is_dpop,
    })
}
133
134impl FromRequestParts<AppState> for BearerAuth {
135 type Rejection = AuthError;
136
137 async fn from_request_parts(
138 parts: &mut Parts,
139 state: &AppState,
140 ) -> Result<Self, Self::Rejection> {
141 let auth_header = parts
142 .headers
143 .get(AUTHORIZATION)
144 .ok_or(AuthError::MissingToken)?
145 .to_str()
146 .map_err(|_| AuthError::InvalidFormat)?;
147
148 let token = extract_bearer_token(auth_header)?;
149
150 match validate_bearer_token_cached(&state.db, &state.cache, token).await {
151 Ok(user) => Ok(BearerAuth(user)),
152 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
153 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
154 Err(_) => Err(AuthError::AuthenticationFailed),
155 }
156 }
157}
158
159pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
160
161impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
162 type Rejection = AuthError;
163
164 async fn from_request_parts(
165 parts: &mut Parts,
166 state: &AppState,
167 ) -> Result<Self, Self::Rejection> {
168 let auth_header = parts
169 .headers
170 .get(AUTHORIZATION)
171 .ok_or(AuthError::MissingToken)?
172 .to_str()
173 .map_err(|_| AuthError::InvalidFormat)?;
174
175 let token = extract_bearer_token(auth_header)?;
176
177 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
178 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
179 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
180 Err(_) => Err(AuthError::AuthenticationFailed),
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra whitespace between the scheme and the token is tolerated.
        assert_eq!(extract_bearer_token("Bearer   abc123").unwrap(), "abc123");
        // Whitespace around the whole header is tolerated.
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");

        assert!(extract_bearer_token("Basic abc123").is_err());
        assert!(extract_bearer_token("Bearer").is_err());
        assert!(extract_bearer_token("Bearer ").is_err());
        assert!(extract_bearer_token("abc123").is_err());
        assert!(extract_bearer_token("").is_err());
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")),
            Some("abc123".to_string())
        );
        assert_eq!(
            extract_bearer_token_from_header(Some(" bearer   abc123 ")),
            Some("abc123".to_string())
        );

        assert_eq!(extract_bearer_token_from_header(None), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer ")), None);
        assert_eq!(extract_bearer_token_from_header(Some("Basic abc123")), None);
        assert_eq!(extract_bearer_token_from_header(Some("DPoP abc123")), None);
    }

    #[test]
    fn test_extract_auth_token_from_header() {
        let bearer = extract_auth_token_from_header(Some("Bearer abc123")).unwrap();
        assert_eq!(bearer.token, "abc123");
        assert!(!bearer.is_dpop);

        let dpop = extract_auth_token_from_header(Some("DPoP xyz")).unwrap();
        assert_eq!(dpop.token, "xyz");
        assert!(dpop.is_dpop);

        let dpop_lower = extract_auth_token_from_header(Some("  dpop   xyz  ")).unwrap();
        assert_eq!(dpop_lower.token, "xyz");
        assert!(dpop_lower.is_dpop);

        assert!(extract_auth_token_from_header(None).is_none());
        assert!(extract_auth_token_from_header(Some("Bearer ")).is_none());
        assert!(extract_auth_token_from_header(Some("DPoP ")).is_none());
        assert!(extract_auth_token_from_header(Some("Basic abc")).is_none());
    }
}