this repo has no description
1use axum::{
2 extract::FromRequestParts,
3 http::{StatusCode, request::Parts, header::AUTHORIZATION},
4 response::{IntoResponse, Response},
5 Json,
6};
7use serde_json::json;
8
9use crate::state::AppState;
10use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated};
11
/// Axum extractor that requires an `Authorization: Bearer <token>` header and
/// yields the [`AuthenticatedUser`] the token validates to.
///
/// Validation goes through `validate_bearer_token_cached`. When the validator
/// reports a deactivated or taken-down account, the request is rejected with
/// [`AuthError::AccountDeactivated`] or [`AuthError::AccountTakedown`].
pub struct BearerAuth(pub AuthenticatedUser);
13
/// Rejection produced by the bearer-token extractors in this module.
///
/// Every variant is rendered as `401 Unauthorized` with a JSON body by the
/// `IntoResponse` impl. The variants are plain unit values, so they are
/// `Copy` and comparable, which keeps matching and testing cheap.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The request carried no `Authorization` header.
    MissingToken,
    /// The header could not be read as a string, or is not `Bearer <token>`.
    InvalidFormat,
    /// Token validation failed for a reason not covered by another variant.
    AuthenticationFailed,
    /// The token validator reported that the account is deactivated.
    AccountDeactivated,
    /// The token validator reported that the account has been taken down.
    AccountTakedown,
}
22
23impl IntoResponse for AuthError {
24 fn into_response(self) -> Response {
25 let (status, error, message) = match self {
26 AuthError::MissingToken => (
27 StatusCode::UNAUTHORIZED,
28 "AuthenticationRequired",
29 "Authorization header is required",
30 ),
31 AuthError::InvalidFormat => (
32 StatusCode::UNAUTHORIZED,
33 "InvalidToken",
34 "Invalid authorization header format",
35 ),
36 AuthError::AuthenticationFailed => (
37 StatusCode::UNAUTHORIZED,
38 "AuthenticationFailed",
39 "Invalid or expired token",
40 ),
41 AuthError::AccountDeactivated => (
42 StatusCode::UNAUTHORIZED,
43 "AccountDeactivated",
44 "Account is deactivated",
45 ),
46 AuthError::AccountTakedown => (
47 StatusCode::UNAUTHORIZED,
48 "AccountTakedown",
49 "Account has been taken down",
50 ),
51 };
52
53 (status, Json(json!({ "error": error, "message": message }))).into_response()
54 }
55}
56
57fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
58 let auth_header = auth_header.trim();
59
60 if auth_header.len() < 8 {
61 return Err(AuthError::InvalidFormat);
62 }
63
64 let prefix = &auth_header[..7];
65 if !prefix.eq_ignore_ascii_case("bearer ") {
66 return Err(AuthError::InvalidFormat);
67 }
68
69 let token = auth_header[7..].trim();
70 if token.is_empty() {
71 return Err(AuthError::InvalidFormat);
72 }
73
74 Ok(token)
75}
76
/// Extracts the bearer token from an optional `Authorization` header value.
///
/// Accepts values of the form `Bearer <token>` with a case-insensitive scheme.
/// Whitespace is trimmed around the whole value and around the token.
///
/// Returns `None` when the header is absent, does not use the `Bearer `
/// scheme, or carries an empty token. Arbitrary (including non-ASCII) input
/// yields `None` rather than a panic.
pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> {
    const SCHEME: &str = "bearer ";

    let header = auth_header?.trim();

    // Indexing with `[..7]` would panic when the header is shorter than the
    // scheme or byte 7 is not a char boundary. `get` yields `None` instead.
    let prefix = header.get(..SCHEME.len())?;
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    // The successful `get` above guarantees this index is a char boundary.
    let token = header[SCHEME.len()..].trim();
    (!token.is_empty()).then(|| token.to_owned())
}
96
97impl FromRequestParts<AppState> for BearerAuth {
98 type Rejection = AuthError;
99
100 async fn from_request_parts(
101 parts: &mut Parts,
102 state: &AppState,
103 ) -> Result<Self, Self::Rejection> {
104 let auth_header = parts
105 .headers
106 .get(AUTHORIZATION)
107 .ok_or(AuthError::MissingToken)?
108 .to_str()
109 .map_err(|_| AuthError::InvalidFormat)?;
110
111 let token = extract_bearer_token(auth_header)?;
112
113 match validate_bearer_token_cached(&state.db, &state.cache, token).await {
114 Ok(user) => Ok(BearerAuth(user)),
115 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
116 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
117 Err(_) => Err(AuthError::AuthenticationFailed),
118 }
119 }
120}
121
/// Axum extractor similar to `BearerAuth`, but validated with
/// `validate_bearer_token_cached_allow_deactivated`. Presumably that validator
/// accepts deactivated accounts; TODO confirm against its implementation.
///
/// A taken-down account is rejected with [`AuthError::AccountTakedown`]. Every
/// other validation error becomes [`AuthError::AuthenticationFailed`].
pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
123
124impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
125 type Rejection = AuthError;
126
127 async fn from_request_parts(
128 parts: &mut Parts,
129 state: &AppState,
130 ) -> Result<Self, Self::Rejection> {
131 let auth_header = parts
132 .headers
133 .get(AUTHORIZATION)
134 .ok_or(AuthError::MissingToken)?
135 .to_str()
136 .map_err(|_| AuthError::InvalidFormat)?;
137
138 let token = extract_bearer_token(auth_header)?;
139
140 match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await {
141 Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
142 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
143 Err(_) => Err(AuthError::AuthenticationFailed),
144 }
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        // The scheme is matched case-insensitively.
        assert_eq!(extract_bearer_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("bearer abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("BEARER abc123").unwrap(), "abc123");
        // Extra whitespace after the scheme, or around the whole value, is trimmed.
        assert_eq!(extract_bearer_token("Bearer  abc123").unwrap(), "abc123");
        assert_eq!(extract_bearer_token("  Bearer abc123  ").unwrap(), "abc123");

        assert!(extract_bearer_token("Basic abc123").is_err());
        assert!(extract_bearer_token("Bearer").is_err());
        assert!(extract_bearer_token("Bearer ").is_err());
        assert!(extract_bearer_token("Bearer    ").is_err());
        assert!(extract_bearer_token("abc123").is_err());
        assert!(extract_bearer_token("").is_err());
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        assert_eq!(
            extract_bearer_token_from_header(Some("Bearer abc123")).as_deref(),
            Some("abc123")
        );
        assert_eq!(
            extract_bearer_token_from_header(Some("  bEaReR abc123  ")).as_deref(),
            Some("abc123")
        );
        assert_eq!(extract_bearer_token_from_header(None), None);
        assert_eq!(extract_bearer_token_from_header(Some("Basic abc123")), None);
        assert_eq!(extract_bearer_token_from_header(Some("Bearer ")), None);
        assert_eq!(extract_bearer_token_from_header(Some("")), None);
    }
}