this repo has no description
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::cache::Cache;
8
9pub mod extractor;
10pub mod token;
11pub mod verify;
12
13pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header};
14pub use token::{
15 create_access_token, create_refresh_token, create_service_token,
16 create_access_token_with_metadata, create_refresh_token_with_metadata,
17 TokenWithMetadata,
18 TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE,
19 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
20};
21pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
22
/// Lifetime of a cached decrypted account signing key (`auth:key:*` entries): 5 minutes.
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Lifetime of a cached positive session lookup (`auth:session:*` entries): 1 minute.
const SESSION_CACHE_TTL_SECS: u64 = 60;
25
/// Reason a bearer token was rejected by the `validate_bearer_token*` functions.
///
/// `Display` renders the bare variant name (e.g. `"AccountDeactivated"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account is deactivated and the caller did not opt in to deactivated accounts.
    AccountDeactivated,
    /// The account carries a takedown reference; this is rejected unconditionally.
    AccountTakedown,
    /// The account's stored signing key could not be decrypted.
    KeyDecryptionFailed,
    /// The token was accepted neither as a session access token nor as an OAuth token.
    AuthenticationFailed,
}

impl fmt::Display for TokenValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::AccountDeactivated => "AccountDeactivated",
            Self::AccountTakedown => "AccountTakedown",
            Self::KeyDecryptionFailed => "KeyDecryptionFailed",
            Self::AuthenticationFailed => "AuthenticationFailed",
        };
        f.write_str(name)
    }
}

// Lets the error be boxed as `dyn Error` and propagated with `?` into generic error types.
impl std::error::Error for TokenValidationError {}
44
/// Identity established by a successfully validated bearer token.
///
/// NOTE(review): no `Debug` derive here; if one is added, consider redacting `key_bytes`,
/// which holds decrypted key material.
pub struct AuthenticatedUser {
    /// DID of the account the token belongs to.
    pub did: String,
    /// Decrypted account signing key: `Some` for session access tokens, `None` for OAuth tokens.
    pub key_bytes: Option<Vec<u8>>,
    /// `true` when the token was accepted as an OAuth access token.
    pub is_oauth: bool,
}
50
51pub async fn validate_bearer_token(
52 db: &PgPool,
53 token: &str,
54) -> Result<AuthenticatedUser, TokenValidationError> {
55 validate_bearer_token_with_options_internal(db, None, token, false).await
56}
57
58pub async fn validate_bearer_token_allow_deactivated(
59 db: &PgPool,
60 token: &str,
61) -> Result<AuthenticatedUser, TokenValidationError> {
62 validate_bearer_token_with_options_internal(db, None, token, true).await
63}
64
65pub async fn validate_bearer_token_cached(
66 db: &PgPool,
67 cache: &Arc<dyn Cache>,
68 token: &str,
69) -> Result<AuthenticatedUser, TokenValidationError> {
70 validate_bearer_token_with_options_internal(db, Some(cache), token, false).await
71}
72
73pub async fn validate_bearer_token_cached_allow_deactivated(
74 db: &PgPool,
75 cache: &Arc<dyn Cache>,
76 token: &str,
77) -> Result<AuthenticatedUser, TokenValidationError> {
78 validate_bearer_token_with_options_internal(db, Some(cache), token, true).await
79}
80
81async fn validate_bearer_token_with_options_internal(
82 db: &PgPool,
83 cache: Option<&Arc<dyn Cache>>,
84 token: &str,
85 allow_deactivated: bool,
86) -> Result<AuthenticatedUser, TokenValidationError> {
87 let did_from_token = get_did_from_token(token).ok();
88
89 if let Some(ref did) = did_from_token {
90 let key_cache_key = format!("auth:key:{}", did);
91 let mut cached_key: Option<Vec<u8>> = None;
92
93 if let Some(c) = cache {
94 cached_key = c.get_bytes(&key_cache_key).await;
95 if cached_key.is_some() {
96 crate::metrics::record_auth_cache_hit("key");
97 } else {
98 crate::metrics::record_auth_cache_miss("key");
99 }
100 }
101
102 let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key {
103 let user_status = sqlx::query!(
104 "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1",
105 did
106 )
107 .fetch_optional(db)
108 .await
109 .ok()
110 .flatten();
111
112 match user_status {
113 Some(status) => (Some(key), status.deactivated_at, status.takedown_ref),
114 None => (None, None, None),
115 }
116 } else {
117 if let Some(user) = sqlx::query!(
118 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref
119 FROM users u
120 JOIN user_keys k ON u.id = k.user_id
121 WHERE u.did = $1",
122 did
123 )
124 .fetch_optional(db)
125 .await
126 .ok()
127 .flatten()
128 {
129 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
130 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
131
132 if let Some(c) = cache {
133 let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await;
134 }
135
136 (Some(key), user.deactivated_at, user.takedown_ref)
137 } else {
138 (None, None, None)
139 }
140 };
141
142 if let Some(decrypted_key) = decrypted_key {
143 if !allow_deactivated && deactivated_at.is_some() {
144 return Err(TokenValidationError::AccountDeactivated);
145 }
146
147 if takedown_ref.is_some() {
148 return Err(TokenValidationError::AccountTakedown);
149 }
150
151 if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
152 let jti = &token_data.claims.jti;
153 let session_cache_key = format!("auth:session:{}:{}", did, jti);
154 let mut session_valid = false;
155
156 if let Some(c) = cache {
157 if let Some(cached_value) = c.get(&session_cache_key).await {
158 session_valid = cached_value == "1";
159 crate::metrics::record_auth_cache_hit("session");
160 } else {
161 crate::metrics::record_auth_cache_miss("session");
162 }
163 }
164
165 if !session_valid {
166 let session_exists = sqlx::query_scalar!(
167 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
168 did,
169 jti
170 )
171 .fetch_optional(db)
172 .await
173 .ok()
174 .flatten();
175
176 session_valid = session_exists.is_some();
177
178 if session_valid {
179 if let Some(c) = cache {
180 let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await;
181 }
182 }
183 }
184
185 if session_valid {
186 return Ok(AuthenticatedUser {
187 did: did.clone(),
188 key_bytes: Some(decrypted_key),
189 is_oauth: false,
190 });
191 }
192 }
193 }
194 }
195
196 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) {
197 if let Some(oauth_token) = sqlx::query!(
198 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref
199 FROM oauth_token t
200 JOIN users u ON t.did = u.did
201 WHERE t.token_id = $1"#,
202 oauth_info.token_id
203 )
204 .fetch_optional(db)
205 .await
206 .ok()
207 .flatten()
208 {
209 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
210 return Err(TokenValidationError::AccountDeactivated);
211 }
212
213 if oauth_token.takedown_ref.is_some() {
214 return Err(TokenValidationError::AccountTakedown);
215 }
216
217 let now = chrono::Utc::now();
218 if oauth_token.expires_at > now {
219 return Ok(AuthenticatedUser {
220 did: oauth_token.did,
221 key_bytes: None,
222 is_oauth: true,
223 });
224 }
225 }
226 }
227
228 Err(TokenValidationError::AuthenticationFailed)
229}
230
231pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
232 let key_cache_key = format!("auth:key:{}", did);
233 let _ = cache.delete(&key_cache_key).await;
234}
235
236#[derive(Debug, Serialize, Deserialize)]
237pub struct Claims {
238 pub iss: String,
239 pub sub: String,
240 pub aud: String,
241 pub exp: usize,
242 pub iat: usize,
243 #[serde(skip_serializing_if = "Option::is_none")]
244 pub scope: Option<String>,
245 #[serde(skip_serializing_if = "Option::is_none")]
246 pub lxm: Option<String>,
247 pub jti: String,
248}
249
250#[derive(Debug, Serialize, Deserialize)]
251pub struct Header {
252 pub alg: String,
253 pub typ: String,
254}
255
256#[derive(Debug, Serialize, Deserialize)]
257pub struct UnsafeClaims {
258 pub iss: String,
259 pub sub: Option<String>,
260}
261
/// Wrapper around a decoded token's claims of type `T`.
pub struct TokenData<T> {
    /// The decoded claims.
    pub claims: T,
}