This repository has no description.
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::cache::Cache;
8
9pub mod extractor;
10pub mod token;
11pub mod verify;
12
13pub use extractor::{
14 AuthError, BearerAuth, BearerAuthAllowDeactivated, ExtractedToken,
15 extract_auth_token_from_header, extract_bearer_token_from_header,
16};
17pub use token::{
18 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
19 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
20 create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata,
21 create_service_token,
22};
23pub use verify::{
24 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
25};
26
/// Seconds a decrypted signing key stays cached under `auth:key:{did}` (5 minutes).
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Seconds a positive session lookup stays cached under `auth:session:{did}:{jti}`.
const SESSION_CACHE_TTL_SECS: u64 = 60;
29
/// Reasons a token can be rejected by the validators in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account is deactivated and the caller did not allow deactivated accounts.
    AccountDeactivated,
    /// The account has a takedown reference set.
    AccountTakedown,
    /// The account's stored signing key could not be decrypted.
    KeyDecryptionFailed,
    /// No validation path accepted the token.
    AuthenticationFailed,
}

impl fmt::Display for TokenValidationError {
    /// Writes the bare variant name, e.g. `AccountDeactivated`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::AccountDeactivated => "AccountDeactivated",
            Self::AccountTakedown => "AccountTakedown",
            Self::KeyDecryptionFailed => "KeyDecryptionFailed",
            Self::AuthenticationFailed => "AuthenticationFailed",
        })
    }
}

// Makes the error usable as `Box<dyn std::error::Error>` and with `?` in
// functions returning boxed/dynamic errors; there is no underlying source.
impl std::error::Error for TokenValidationError {}
48
/// Identity established by a successful token validation.
pub struct AuthenticatedUser {
    /// DID of the authenticated account.
    pub did: String,
    /// The account's decrypted signing key, when it could be loaded.
    /// This is secret key material; avoid logging or serializing it.
    pub key_bytes: Option<Vec<u8>>,
    /// `true` when authenticated via an OAuth access token, `false` for a session token.
    pub is_oauth: bool,
}
54
55pub async fn validate_bearer_token(
56 db: &PgPool,
57 token: &str,
58) -> Result<AuthenticatedUser, TokenValidationError> {
59 validate_bearer_token_with_options_internal(db, None, token, false).await
60}
61
62pub async fn validate_bearer_token_allow_deactivated(
63 db: &PgPool,
64 token: &str,
65) -> Result<AuthenticatedUser, TokenValidationError> {
66 validate_bearer_token_with_options_internal(db, None, token, true).await
67}
68
69pub async fn validate_bearer_token_cached(
70 db: &PgPool,
71 cache: &Arc<dyn Cache>,
72 token: &str,
73) -> Result<AuthenticatedUser, TokenValidationError> {
74 validate_bearer_token_with_options_internal(db, Some(cache), token, false).await
75}
76
77pub async fn validate_bearer_token_cached_allow_deactivated(
78 db: &PgPool,
79 cache: &Arc<dyn Cache>,
80 token: &str,
81) -> Result<AuthenticatedUser, TokenValidationError> {
82 validate_bearer_token_with_options_internal(db, Some(cache), token, true).await
83}
84
85async fn validate_bearer_token_with_options_internal(
86 db: &PgPool,
87 cache: Option<&Arc<dyn Cache>>,
88 token: &str,
89 allow_deactivated: bool,
90) -> Result<AuthenticatedUser, TokenValidationError> {
91 let did_from_token = get_did_from_token(token).ok();
92
93 if let Some(ref did) = did_from_token {
94 let key_cache_key = format!("auth:key:{}", did);
95 let mut cached_key: Option<Vec<u8>> = None;
96
97 if let Some(c) = cache {
98 cached_key = c.get_bytes(&key_cache_key).await;
99 if cached_key.is_some() {
100 crate::metrics::record_auth_cache_hit("key");
101 } else {
102 crate::metrics::record_auth_cache_miss("key");
103 }
104 }
105
106 let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key {
107 let user_status = sqlx::query!(
108 "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1",
109 did
110 )
111 .fetch_optional(db)
112 .await
113 .ok()
114 .flatten();
115
116 match user_status {
117 Some(status) => (Some(key), status.deactivated_at, status.takedown_ref),
118 None => (None, None, None),
119 }
120 } else if let Some(user) = sqlx::query!(
121 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref
122 FROM users u
123 JOIN user_keys k ON u.id = k.user_id
124 WHERE u.did = $1",
125 did
126 )
127 .fetch_optional(db)
128 .await
129 .ok()
130 .flatten()
131 {
132 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
133 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
134
135 if let Some(c) = cache {
136 let _ = c
137 .set_bytes(
138 &key_cache_key,
139 &key,
140 Duration::from_secs(KEY_CACHE_TTL_SECS),
141 )
142 .await;
143 }
144
145 (Some(key), user.deactivated_at, user.takedown_ref)
146 } else {
147 (None, None, None)
148 };
149
150 if let Some(decrypted_key) = decrypted_key {
151 if !allow_deactivated && deactivated_at.is_some() {
152 return Err(TokenValidationError::AccountDeactivated);
153 }
154
155 if takedown_ref.is_some() {
156 return Err(TokenValidationError::AccountTakedown);
157 }
158
159 if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
160 let jti = &token_data.claims.jti;
161 let session_cache_key = format!("auth:session:{}:{}", did, jti);
162 let mut session_valid = false;
163
164 if let Some(c) = cache {
165 if let Some(cached_value) = c.get(&session_cache_key).await {
166 session_valid = cached_value == "1";
167 crate::metrics::record_auth_cache_hit("session");
168 } else {
169 crate::metrics::record_auth_cache_miss("session");
170 }
171 }
172
173 if !session_valid {
174 let session_exists = sqlx::query_scalar!(
175 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
176 did,
177 jti
178 )
179 .fetch_optional(db)
180 .await
181 .ok()
182 .flatten();
183
184 session_valid = session_exists.is_some();
185
186 if session_valid
187 && let Some(c) = cache {
188 let _ = c
189 .set(
190 &session_cache_key,
191 "1",
192 Duration::from_secs(SESSION_CACHE_TTL_SECS),
193 )
194 .await;
195 }
196 }
197
198 if session_valid {
199 return Ok(AuthenticatedUser {
200 did: did.clone(),
201 key_bytes: Some(decrypted_key),
202 is_oauth: false,
203 });
204 }
205 }
206 }
207 }
208
209 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
210 && let Some(oauth_token) = sqlx::query!(
211 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref,
212 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
213 FROM oauth_token t
214 JOIN users u ON t.did = u.did
215 LEFT JOIN user_keys k ON u.id = k.user_id
216 WHERE t.token_id = $1"#,
217 oauth_info.token_id
218 )
219 .fetch_optional(db)
220 .await
221 .ok()
222 .flatten()
223 {
224 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
225 return Err(TokenValidationError::AccountDeactivated);
226 }
227
228 if oauth_token.takedown_ref.is_some() {
229 return Err(TokenValidationError::AccountTakedown);
230 }
231
232 let now = chrono::Utc::now();
233 if oauth_token.expires_at > now {
234 let key_bytes = if let (Some(kb), Some(ev)) =
235 (&oauth_token.key_bytes, oauth_token.encryption_version)
236 {
237 crate::config::decrypt_key(kb, Some(ev)).ok()
238 } else {
239 None
240 };
241 return Ok(AuthenticatedUser {
242 did: oauth_token.did,
243 key_bytes,
244 is_oauth: true,
245 });
246 }
247 }
248
249 Err(TokenValidationError::AuthenticationFailed)
250}
251
252pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
253 let key_cache_key = format!("auth:key:{}", did);
254 let _ = cache.delete(&key_cache_key).await;
255}
256
257pub async fn validate_token_with_dpop(
258 db: &PgPool,
259 token: &str,
260 is_dpop_token: bool,
261 dpop_proof: Option<&str>,
262 http_method: &str,
263 http_uri: &str,
264 allow_deactivated: bool,
265) -> Result<AuthenticatedUser, TokenValidationError> {
266 if !is_dpop_token {
267 if allow_deactivated {
268 return validate_bearer_token_allow_deactivated(db, token).await;
269 } else {
270 return validate_bearer_token(db, token).await;
271 }
272 }
273 match crate::oauth::verify::verify_oauth_access_token(
274 db,
275 token,
276 dpop_proof,
277 http_method,
278 http_uri,
279 )
280 .await
281 {
282 Ok(result) => {
283 if !allow_deactivated {
284 let deactivated = sqlx::query_scalar!(
285 "SELECT deactivated_at FROM users WHERE did = $1",
286 result.did
287 )
288 .fetch_optional(db)
289 .await
290 .ok()
291 .flatten()
292 .flatten();
293 if deactivated.is_some() {
294 return Err(TokenValidationError::AccountDeactivated);
295 }
296 }
297 let takedown =
298 sqlx::query_scalar!("SELECT takedown_ref FROM users WHERE did = $1", result.did)
299 .fetch_optional(db)
300 .await
301 .ok()
302 .flatten()
303 .flatten();
304 if takedown.is_some() {
305 return Err(TokenValidationError::AccountTakedown);
306 }
307 let key_bytes = sqlx::query!(
308 "SELECT k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
309 result.did
310 )
311 .fetch_optional(db)
312 .await
313 .ok()
314 .flatten()
315 .and_then(|row| crate::config::decrypt_key(&row.key_bytes, row.encryption_version).ok());
316 Ok(AuthenticatedUser {
317 did: result.did,
318 key_bytes,
319 is_oauth: true,
320 })
321 }
322 Err(_) => Err(TokenValidationError::AuthenticationFailed),
323 }
324}
325
326#[derive(Debug, Serialize, Deserialize)]
327pub struct Claims {
328 pub iss: String,
329 pub sub: String,
330 pub aud: String,
331 pub exp: usize,
332 pub iat: usize,
333 #[serde(skip_serializing_if = "Option::is_none")]
334 pub scope: Option<String>,
335 #[serde(skip_serializing_if = "Option::is_none")]
336 pub lxm: Option<String>,
337 pub jti: String,
338}
339
340#[derive(Debug, Serialize, Deserialize)]
341pub struct Header {
342 pub alg: String,
343 pub typ: String,
344}
345
346#[derive(Debug, Serialize, Deserialize)]
347pub struct UnsafeClaims {
348 pub iss: String,
349 pub sub: Option<String>,
350}
351
/// Wrapper around the claims of a token. Presumably produced after
/// decoding/verification in `verify`; confirm there.
pub struct TokenData<T> {
    /// The token's claims.
    pub claims: T,
}