this repo has no description
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::time::Duration;
5
6use crate::AccountStatus;
7use crate::cache::Cache;
8use crate::oauth::scopes::ScopePermissions;
9use crate::types::Did;
10
11pub mod extractor;
12pub mod scope_check;
13pub mod service;
14pub mod token;
15pub mod totp;
16pub mod verification_token;
17pub mod verify;
18pub mod webauthn;
19
20pub use extractor::{
21 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
22 extract_auth_token_from_header, extract_bearer_token_from_header,
23};
24pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
25pub use token::{
26 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
27 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
28 create_access_token_with_delegation, create_access_token_with_metadata,
29 create_access_token_with_scope_metadata, create_refresh_token,
30 create_refresh_token_with_metadata, create_service_token,
31};
32pub use verify::{
33 TokenVerifyError, get_did_from_token, get_jti_from_token, verify_access_token,
34 verify_access_token_typed, verify_refresh_token, verify_token,
35};
36
/// Lifetime of a decrypted signing key in the auth cache (`auth:key:{did}`): five minutes.
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Lifetime of a "session is live" marker in the auth cache (`auth:session:{did}:{jti}`).
const SESSION_CACHE_TTL_SECS: u64 = 60;
/// Lifetime of cached account status flags in the auth cache (`auth:status:{did}`).
const USER_STATUS_CACHE_TTL_SECS: u64 = 60;
40
/// Account status flags cached as JSON under `auth:status:{did}`.
///
/// Only whether the underlying columns are set is recorded, not their values;
/// readers synthesize placeholder values when they need an `Option`.
#[derive(Serialize, Deserialize)]
struct CachedUserStatus {
    /// `users.deactivated_at` was non-null.
    deactivated: bool,
    /// `users.takedown_ref` was non-null.
    takendown: bool,
    /// Value of `users.is_admin`.
    is_admin: bool,
}
47
/// Reasons a token can be rejected by the validators in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account is deactivated and the caller did not allow deactivated accounts.
    AccountDeactivated,
    /// The account has a takedown and the caller did not allow taken-down accounts.
    AccountTakedown,
    /// The account's stored signing key could not be decrypted.
    KeyDecryptionFailed,
    /// No validation path accepted the token.
    AuthenticationFailed,
    /// The token, its session row, or its OAuth grant has expired.
    TokenExpired,
}

impl fmt::Display for TokenValidationError {
    /// Writes the error's name. `TokenExpired` is rendered as `ExpiredToken`;
    /// every other variant is rendered as its own name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::AccountDeactivated => "AccountDeactivated",
            Self::AccountTakedown => "AccountTakedown",
            Self::KeyDecryptionFailed => "KeyDecryptionFailed",
            Self::AuthenticationFailed => "AuthenticationFailed",
            Self::TokenExpired => "ExpiredToken",
        };
        f.write_str(name)
    }
}
68
69pub struct AuthenticatedUser {
70 pub did: Did,
71 pub key_bytes: Option<Vec<u8>>,
72 pub is_oauth: bool,
73 pub is_admin: bool,
74 pub status: AccountStatus,
75 pub scope: Option<String>,
76 pub controller_did: Option<Did>,
77}
78
79impl AuthenticatedUser {
80 pub fn permissions(&self) -> ScopePermissions {
81 if let Some(ref scope) = self.scope
82 && scope != SCOPE_ACCESS
83 {
84 return ScopePermissions::from_scope_string(Some(scope));
85 }
86 if !self.is_oauth {
87 return ScopePermissions::from_scope_string(Some("atproto"));
88 }
89 ScopePermissions::from_scope_string(self.scope.as_deref())
90 }
91
92 pub fn is_takendown(&self) -> bool {
93 self.status.is_takendown()
94 }
95}
96
97pub async fn validate_bearer_token(
98 db: &PgPool,
99 token: &str,
100) -> Result<AuthenticatedUser, TokenValidationError> {
101 validate_bearer_token_with_options_internal(db, None, token, false, false).await
102}
103
104pub async fn validate_bearer_token_allow_deactivated(
105 db: &PgPool,
106 token: &str,
107) -> Result<AuthenticatedUser, TokenValidationError> {
108 validate_bearer_token_with_options_internal(db, None, token, true, false).await
109}
110
111pub async fn validate_bearer_token_cached(
112 db: &PgPool,
113 cache: &dyn Cache,
114 token: &str,
115) -> Result<AuthenticatedUser, TokenValidationError> {
116 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await
117}
118
119pub async fn validate_bearer_token_cached_allow_deactivated(
120 db: &PgPool,
121 cache: &dyn Cache,
122 token: &str,
123) -> Result<AuthenticatedUser, TokenValidationError> {
124 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await
125}
126
127pub async fn validate_bearer_token_for_service_auth(
128 db: &PgPool,
129 token: &str,
130) -> Result<AuthenticatedUser, TokenValidationError> {
131 validate_bearer_token_with_options_internal(db, None, token, true, true).await
132}
133
134pub async fn validate_bearer_token_allow_takendown(
135 db: &PgPool,
136 token: &str,
137) -> Result<AuthenticatedUser, TokenValidationError> {
138 validate_bearer_token_with_options_internal(db, None, token, false, true).await
139}
140
141async fn validate_bearer_token_with_options_internal(
142 db: &PgPool,
143 cache: Option<&dyn Cache>,
144 token: &str,
145 allow_deactivated: bool,
146 allow_takendown: bool,
147) -> Result<AuthenticatedUser, TokenValidationError> {
148 let did_from_token = get_did_from_token(token).ok();
149
150 if let Some(ref did) = did_from_token {
151 let key_cache_key = format!("auth:key:{}", did);
152 let mut cached_key: Option<Vec<u8>> = None;
153
154 if let Some(c) = cache {
155 cached_key = c.get_bytes(&key_cache_key).await;
156 if cached_key.is_some() {
157 crate::metrics::record_auth_cache_hit("key");
158 } else {
159 crate::metrics::record_auth_cache_miss("key");
160 }
161 }
162
163 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key
164 {
165 let status_cache_key = format!("auth:status:{}", did);
166 let cached_status: Option<CachedUserStatus> = if let Some(c) = cache {
167 c.get(&status_cache_key)
168 .await
169 .and_then(|s| serde_json::from_str(&s).ok())
170 } else {
171 None
172 };
173
174 if let Some(status) = cached_status {
175 (
176 Some(key),
177 if status.deactivated {
178 Some(chrono::Utc::now())
179 } else {
180 None
181 },
182 if status.takendown {
183 Some("takendown".to_string())
184 } else {
185 None
186 },
187 status.is_admin,
188 )
189 } else {
190 let user_status = sqlx::query!(
191 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
192 did
193 )
194 .fetch_optional(db)
195 .await
196 .ok()
197 .flatten();
198
199 match user_status {
200 Some(status) => {
201 if let Some(c) = cache {
202 let cached = CachedUserStatus {
203 deactivated: status.deactivated_at.is_some(),
204 takendown: status.takedown_ref.is_some(),
205 is_admin: status.is_admin,
206 };
207 if let Ok(json) = serde_json::to_string(&cached) {
208 let _ = c
209 .set(
210 &status_cache_key,
211 &json,
212 Duration::from_secs(USER_STATUS_CACHE_TTL_SECS),
213 )
214 .await;
215 }
216 }
217 (
218 Some(key),
219 status.deactivated_at,
220 status.takedown_ref,
221 status.is_admin,
222 )
223 }
224 None => (None, None, None, false),
225 }
226 }
227 } else if let Some(user) = sqlx::query!(
228 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin
229 FROM users u
230 JOIN user_keys k ON u.id = k.user_id
231 WHERE u.did = $1",
232 did
233 )
234 .fetch_optional(db)
235 .await
236 .ok()
237 .flatten()
238 {
239 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
240 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
241
242 if let Some(c) = cache {
243 let _ = c
244 .set_bytes(
245 &key_cache_key,
246 &key,
247 Duration::from_secs(KEY_CACHE_TTL_SECS),
248 )
249 .await;
250
251 let status_cache_key = format!("auth:status:{}", did);
252 let cached = CachedUserStatus {
253 deactivated: user.deactivated_at.is_some(),
254 takendown: user.takedown_ref.is_some(),
255 is_admin: user.is_admin,
256 };
257 if let Ok(json) = serde_json::to_string(&cached) {
258 let _ = c
259 .set(
260 &status_cache_key,
261 &json,
262 Duration::from_secs(USER_STATUS_CACHE_TTL_SECS),
263 )
264 .await;
265 }
266 }
267
268 (
269 Some(key),
270 user.deactivated_at,
271 user.takedown_ref,
272 user.is_admin,
273 )
274 } else {
275 (None, None, None, false)
276 };
277
278 if let Some(decrypted_key) = decrypted_key {
279 if !allow_deactivated && deactivated_at.is_some() {
280 return Err(TokenValidationError::AccountDeactivated);
281 }
282
283 if !allow_takendown && takedown_ref.is_some() {
284 return Err(TokenValidationError::AccountTakedown);
285 }
286
287 match verify_access_token_typed(token, &decrypted_key) {
288 Ok(token_data) => {
289 let jti = &token_data.claims.jti;
290 let session_cache_key = format!("auth:session:{}:{}", did, jti);
291 let mut session_valid = false;
292
293 if let Some(c) = cache {
294 if let Some(cached_value) = c.get(&session_cache_key).await {
295 session_valid = cached_value == "1";
296 crate::metrics::record_auth_cache_hit("session");
297 } else {
298 crate::metrics::record_auth_cache_miss("session");
299 }
300 }
301
302 if !session_valid {
303 let session_row = sqlx::query!(
304 "SELECT access_expires_at FROM session_tokens WHERE did = $1 AND access_jti = $2",
305 did,
306 jti
307 )
308 .fetch_optional(db)
309 .await
310 .ok()
311 .flatten();
312
313 if let Some(row) = session_row {
314 if row.access_expires_at > chrono::Utc::now() {
315 session_valid = true;
316 if let Some(c) = cache {
317 let _ = c
318 .set(
319 &session_cache_key,
320 "1",
321 Duration::from_secs(SESSION_CACHE_TTL_SECS),
322 )
323 .await;
324 }
325 } else {
326 return Err(TokenValidationError::TokenExpired);
327 }
328 }
329 }
330
331 if session_valid {
332 let controller_did = token_data
333 .claims
334 .act
335 .as_ref()
336 .map(|a| Did::new_unchecked(a.sub.clone()));
337 let status =
338 AccountStatus::from_db_fields(takedown_ref.as_deref(), deactivated_at);
339 return Ok(AuthenticatedUser {
340 did: Did::new_unchecked(did.clone()),
341 key_bytes: Some(decrypted_key),
342 is_oauth: false,
343 is_admin,
344 status,
345 scope: token_data.claims.scope.clone(),
346 controller_did,
347 });
348 }
349 }
350 Err(verify::TokenVerifyError::Expired) => {
351 return Err(TokenValidationError::TokenExpired);
352 }
353 Err(verify::TokenVerifyError::Invalid) => {}
354 }
355 }
356 }
357
358 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
359 && let Some(oauth_token) = sqlx::query!(
360 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,
361 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
362 FROM oauth_token t
363 JOIN users u ON t.did = u.did
364 LEFT JOIN user_keys k ON u.id = k.user_id
365 WHERE t.token_id = $1"#,
366 oauth_info.token_id
367 )
368 .fetch_optional(db)
369 .await
370 .ok()
371 .flatten()
372 {
373 let status = AccountStatus::from_db_fields(
374 oauth_token.takedown_ref.as_deref(),
375 oauth_token.deactivated_at,
376 );
377
378 if !allow_deactivated && status.is_deactivated() {
379 return Err(TokenValidationError::AccountDeactivated);
380 }
381
382 if !allow_takendown && status.is_takendown() {
383 return Err(TokenValidationError::AccountTakedown);
384 }
385
386 let now = chrono::Utc::now();
387 if oauth_token.expires_at > now {
388 let key_bytes = if let (Some(kb), Some(ev)) =
389 (&oauth_token.key_bytes, oauth_token.encryption_version)
390 {
391 crate::config::decrypt_key(kb, Some(ev)).ok()
392 } else {
393 None
394 };
395 return Ok(AuthenticatedUser {
396 did: Did::new_unchecked(oauth_token.did),
397 key_bytes,
398 is_oauth: true,
399 is_admin: oauth_token.is_admin,
400 status,
401 scope: oauth_info.scope,
402 controller_did: oauth_info.controller_did.map(Did::new_unchecked),
403 });
404 } else {
405 return Err(TokenValidationError::TokenExpired);
406 }
407 }
408
409 Err(TokenValidationError::AuthenticationFailed)
410}
411
412pub async fn invalidate_auth_cache(cache: &dyn Cache, did: &str) {
413 let key_cache_key = format!("auth:key:{}", did);
414 let status_cache_key = format!("auth:status:{}", did);
415 let _ = cache.delete(&key_cache_key).await;
416 let _ = cache.delete(&status_cache_key).await;
417}
418
419#[allow(clippy::too_many_arguments)]
420pub async fn validate_token_with_dpop(
421 db: &PgPool,
422 token: &str,
423 is_dpop_token: bool,
424 dpop_proof: Option<&str>,
425 http_method: &str,
426 http_uri: &str,
427 allow_deactivated: bool,
428 allow_takendown: bool,
429) -> Result<AuthenticatedUser, TokenValidationError> {
430 if !is_dpop_token {
431 if allow_takendown {
432 return validate_bearer_token_allow_takendown(db, token).await;
433 } else if allow_deactivated {
434 return validate_bearer_token_allow_deactivated(db, token).await;
435 } else {
436 return validate_bearer_token(db, token).await;
437 }
438 }
439 match crate::oauth::verify::verify_oauth_access_token(
440 db,
441 token,
442 dpop_proof,
443 http_method,
444 http_uri,
445 )
446 .await
447 {
448 Ok(result) => {
449 let user_info = sqlx::query!(
450 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin,
451 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
452 FROM users u
453 LEFT JOIN user_keys k ON u.id = k.user_id
454 WHERE u.did = $1"#,
455 result.did
456 )
457 .fetch_optional(db)
458 .await
459 .ok()
460 .flatten();
461 let Some(user_info) = user_info else {
462 return Err(TokenValidationError::AuthenticationFailed);
463 };
464 let status = AccountStatus::from_db_fields(
465 user_info.takedown_ref.as_deref(),
466 user_info.deactivated_at,
467 );
468 if !allow_deactivated && status.is_deactivated() {
469 return Err(TokenValidationError::AccountDeactivated);
470 }
471 if !allow_takendown && status.is_takendown() {
472 return Err(TokenValidationError::AccountTakedown);
473 }
474 let key_bytes = if let (Some(kb), Some(ev)) =
475 (&user_info.key_bytes, user_info.encryption_version)
476 {
477 crate::config::decrypt_key(kb, Some(ev)).ok()
478 } else {
479 None
480 };
481 Ok(AuthenticatedUser {
482 did: Did::new_unchecked(result.did),
483 key_bytes,
484 is_oauth: true,
485 is_admin: user_info.is_admin,
486 status,
487 scope: result.scope,
488 controller_did: None,
489 })
490 }
491 Err(crate::oauth::OAuthError::ExpiredToken(_)) => Err(TokenValidationError::TokenExpired),
492 Err(_) => Err(TokenValidationError::AuthenticationFailed),
493 }
494}
495
/// JWT `act` (actor) claim.
///
/// When present on a validated legacy token, its `sub` becomes
/// `AuthenticatedUser::controller_did`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActClaim {
    /// DID of the acting controller.
    pub sub: String,
}
500
/// JWT payload claims for tokens handled by this module.
///
/// Optional claims are omitted from the serialized payload when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Issuer (`iss`).
    pub iss: String,
    /// Subject (`sub`); presumably the account DID. TODO confirm against `get_did_from_token`.
    pub sub: String,
    /// Audience (`aud`).
    pub aud: String,
    /// Expiry (`exp`); NOTE(review): presumably seconds since the Unix epoch.
    pub exp: usize,
    /// Issued-at (`iat`); NOTE(review): presumably seconds since the Unix epoch.
    pub iat: usize,
    /// Token scope; surfaced as `AuthenticatedUser::scope`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// `lxm` claim; NOTE(review): presumably the lexicon method a service token is bound to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lxm: Option<String>,
    /// Token id (`jti`); legacy sessions are looked up by it in `session_tokens.access_jti`.
    pub jti: String,
    /// Actor claim for tokens used on behalf of a controller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub act: Option<ActClaim>,
}
516
/// JWT header fields.
#[derive(Debug, Serialize, Deserialize)]
pub struct Header {
    /// Signing algorithm (`alg`).
    pub alg: String,
    /// Token type (`typ`).
    pub typ: String,
}
522
/// Minimal claim subset: `iss` and an optional `sub`.
///
/// NOTE(review): the name suggests these are read before the signature is
/// verified; confirm at call sites and never trust them for authorization.
#[derive(Debug, Serialize, Deserialize)]
pub struct UnsafeClaims {
    /// Issuer (`iss`).
    pub iss: String,
    /// Subject (`sub`), if present.
    pub sub: Option<String>,
}
528
/// Decoded token wrapping its claims of type `T`.
pub struct TokenData<T> {
    /// The decoded claims.
    pub claims: T,
}