// this repo has no description
1use serde::{Deserialize, Serialize}; 2use sqlx::PgPool; 3use std::fmt; 4use std::time::Duration; 5 6use crate::AccountStatus; 7use crate::cache::Cache; 8use crate::oauth::scopes::ScopePermissions; 9use crate::types::Did; 10 11pub mod extractor; 12pub mod scope_check; 13pub mod service; 14pub mod token; 15pub mod totp; 16pub mod verification_token; 17pub mod verify; 18pub mod webauthn; 19 20pub use extractor::{ 21 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 22 extract_auth_token_from_header, extract_bearer_token_from_header, 23}; 24pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 25pub use token::{ 26 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 27 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, 28 create_access_token_with_delegation, create_access_token_with_metadata, 29 create_access_token_with_scope_metadata, create_refresh_token, 30 create_refresh_token_with_metadata, create_service_token, 31}; 32pub use verify::{ 33 TokenVerifyError, get_did_from_token, get_jti_from_token, verify_access_token, 34 verify_access_token_typed, verify_refresh_token, verify_token, 35}; 36 37const KEY_CACHE_TTL_SECS: u64 = 300; 38const SESSION_CACHE_TTL_SECS: u64 = 60; 39const USER_STATUS_CACHE_TTL_SECS: u64 = 60; 40 41#[derive(Serialize, Deserialize)] 42struct CachedUserStatus { 43 deactivated: bool, 44 takendown: bool, 45 is_admin: bool, 46} 47 48#[derive(Debug, Clone, Copy, PartialEq, Eq)] 49pub enum TokenValidationError { 50 AccountDeactivated, 51 AccountTakedown, 52 KeyDecryptionFailed, 53 AuthenticationFailed, 54 TokenExpired, 55} 56 57impl fmt::Display for TokenValidationError { 58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 59 match self { 60 Self::AccountDeactivated => write!(f, "AccountDeactivated"), 61 Self::AccountTakedown => write!(f, "AccountTakedown"), 62 Self::KeyDecryptionFailed => write!(f, 
"KeyDecryptionFailed"), 63 Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), 64 Self::TokenExpired => write!(f, "ExpiredToken"), 65 } 66 } 67} 68 69pub struct AuthenticatedUser { 70 pub did: Did, 71 pub key_bytes: Option<Vec<u8>>, 72 pub is_oauth: bool, 73 pub is_admin: bool, 74 pub status: AccountStatus, 75 pub scope: Option<String>, 76 pub controller_did: Option<Did>, 77} 78 79impl AuthenticatedUser { 80 pub fn permissions(&self) -> ScopePermissions { 81 if let Some(ref scope) = self.scope 82 && scope != SCOPE_ACCESS 83 { 84 return ScopePermissions::from_scope_string(Some(scope)); 85 } 86 if !self.is_oauth { 87 return ScopePermissions::from_scope_string(Some("atproto")); 88 } 89 ScopePermissions::from_scope_string(self.scope.as_deref()) 90 } 91 92 pub fn is_takendown(&self) -> bool { 93 self.status.is_takendown() 94 } 95} 96 97pub async fn validate_bearer_token( 98 db: &PgPool, 99 token: &str, 100) -> Result<AuthenticatedUser, TokenValidationError> { 101 validate_bearer_token_with_options_internal(db, None, token, false, false).await 102} 103 104pub async fn validate_bearer_token_allow_deactivated( 105 db: &PgPool, 106 token: &str, 107) -> Result<AuthenticatedUser, TokenValidationError> { 108 validate_bearer_token_with_options_internal(db, None, token, true, false).await 109} 110 111pub async fn validate_bearer_token_cached( 112 db: &PgPool, 113 cache: &dyn Cache, 114 token: &str, 115) -> Result<AuthenticatedUser, TokenValidationError> { 116 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await 117} 118 119pub async fn validate_bearer_token_cached_allow_deactivated( 120 db: &PgPool, 121 cache: &dyn Cache, 122 token: &str, 123) -> Result<AuthenticatedUser, TokenValidationError> { 124 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await 125} 126 127pub async fn validate_bearer_token_for_service_auth( 128 db: &PgPool, 129 token: &str, 130) -> Result<AuthenticatedUser, 
TokenValidationError> { 131 validate_bearer_token_with_options_internal(db, None, token, true, true).await 132} 133 134pub async fn validate_bearer_token_allow_takendown( 135 db: &PgPool, 136 token: &str, 137) -> Result<AuthenticatedUser, TokenValidationError> { 138 validate_bearer_token_with_options_internal(db, None, token, false, true).await 139} 140 141async fn validate_bearer_token_with_options_internal( 142 db: &PgPool, 143 cache: Option<&dyn Cache>, 144 token: &str, 145 allow_deactivated: bool, 146 allow_takendown: bool, 147) -> Result<AuthenticatedUser, TokenValidationError> { 148 let did_from_token = get_did_from_token(token).ok(); 149 150 if let Some(ref did) = did_from_token { 151 let key_cache_key = format!("auth:key:{}", did); 152 let mut cached_key: Option<Vec<u8>> = None; 153 154 if let Some(c) = cache { 155 cached_key = c.get_bytes(&key_cache_key).await; 156 if cached_key.is_some() { 157 crate::metrics::record_auth_cache_hit("key"); 158 } else { 159 crate::metrics::record_auth_cache_miss("key"); 160 } 161 } 162 163 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 164 { 165 let status_cache_key = format!("auth:status:{}", did); 166 let cached_status: Option<CachedUserStatus> = if let Some(c) = cache { 167 c.get(&status_cache_key) 168 .await 169 .and_then(|s| serde_json::from_str(&s).ok()) 170 } else { 171 None 172 }; 173 174 if let Some(status) = cached_status { 175 ( 176 Some(key), 177 if status.deactivated { 178 Some(chrono::Utc::now()) 179 } else { 180 None 181 }, 182 if status.takendown { 183 Some("takendown".to_string()) 184 } else { 185 None 186 }, 187 status.is_admin, 188 ) 189 } else { 190 let user_status = sqlx::query!( 191 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 192 did 193 ) 194 .fetch_optional(db) 195 .await 196 .ok() 197 .flatten(); 198 199 match user_status { 200 Some(status) => { 201 if let Some(c) = cache { 202 let cached = CachedUserStatus { 203 
deactivated: status.deactivated_at.is_some(), 204 takendown: status.takedown_ref.is_some(), 205 is_admin: status.is_admin, 206 }; 207 if let Ok(json) = serde_json::to_string(&cached) { 208 let _ = c 209 .set( 210 &status_cache_key, 211 &json, 212 Duration::from_secs(USER_STATUS_CACHE_TTL_SECS), 213 ) 214 .await; 215 } 216 } 217 ( 218 Some(key), 219 status.deactivated_at, 220 status.takedown_ref, 221 status.is_admin, 222 ) 223 } 224 None => (None, None, None, false), 225 } 226 } 227 } else if let Some(user) = sqlx::query!( 228 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin 229 FROM users u 230 JOIN user_keys k ON u.id = k.user_id 231 WHERE u.did = $1", 232 did 233 ) 234 .fetch_optional(db) 235 .await 236 .ok() 237 .flatten() 238 { 239 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 240 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 241 242 if let Some(c) = cache { 243 let _ = c 244 .set_bytes( 245 &key_cache_key, 246 &key, 247 Duration::from_secs(KEY_CACHE_TTL_SECS), 248 ) 249 .await; 250 251 let status_cache_key = format!("auth:status:{}", did); 252 let cached = CachedUserStatus { 253 deactivated: user.deactivated_at.is_some(), 254 takendown: user.takedown_ref.is_some(), 255 is_admin: user.is_admin, 256 }; 257 if let Ok(json) = serde_json::to_string(&cached) { 258 let _ = c 259 .set( 260 &status_cache_key, 261 &json, 262 Duration::from_secs(USER_STATUS_CACHE_TTL_SECS), 263 ) 264 .await; 265 } 266 } 267 268 ( 269 Some(key), 270 user.deactivated_at, 271 user.takedown_ref, 272 user.is_admin, 273 ) 274 } else { 275 (None, None, None, false) 276 }; 277 278 if let Some(decrypted_key) = decrypted_key { 279 if !allow_deactivated && deactivated_at.is_some() { 280 return Err(TokenValidationError::AccountDeactivated); 281 } 282 283 if !allow_takendown && takedown_ref.is_some() { 284 return Err(TokenValidationError::AccountTakedown); 285 } 286 287 match verify_access_token_typed(token, 
&decrypted_key) { 288 Ok(token_data) => { 289 let jti = &token_data.claims.jti; 290 let session_cache_key = format!("auth:session:{}:{}", did, jti); 291 let mut session_valid = false; 292 293 if let Some(c) = cache { 294 if let Some(cached_value) = c.get(&session_cache_key).await { 295 session_valid = cached_value == "1"; 296 crate::metrics::record_auth_cache_hit("session"); 297 } else { 298 crate::metrics::record_auth_cache_miss("session"); 299 } 300 } 301 302 if !session_valid { 303 let session_row = sqlx::query!( 304 "SELECT access_expires_at FROM session_tokens WHERE did = $1 AND access_jti = $2", 305 did, 306 jti 307 ) 308 .fetch_optional(db) 309 .await 310 .ok() 311 .flatten(); 312 313 if let Some(row) = session_row { 314 if row.access_expires_at > chrono::Utc::now() { 315 session_valid = true; 316 if let Some(c) = cache { 317 let _ = c 318 .set( 319 &session_cache_key, 320 "1", 321 Duration::from_secs(SESSION_CACHE_TTL_SECS), 322 ) 323 .await; 324 } 325 } else { 326 return Err(TokenValidationError::TokenExpired); 327 } 328 } 329 } 330 331 if session_valid { 332 let controller_did = token_data 333 .claims 334 .act 335 .as_ref() 336 .map(|a| Did::new_unchecked(a.sub.clone())); 337 let status = 338 AccountStatus::from_db_fields(takedown_ref.as_deref(), deactivated_at); 339 return Ok(AuthenticatedUser { 340 did: Did::new_unchecked(did.clone()), 341 key_bytes: Some(decrypted_key), 342 is_oauth: false, 343 is_admin, 344 status, 345 scope: token_data.claims.scope.clone(), 346 controller_did, 347 }); 348 } 349 } 350 Err(verify::TokenVerifyError::Expired) => { 351 return Err(TokenValidationError::TokenExpired); 352 } 353 Err(verify::TokenVerifyError::Invalid) => {} 354 } 355 } 356 } 357 358 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) 359 && let Some(oauth_token) = sqlx::query!( 360 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin, 361 k.key_bytes as "key_bytes?", k.encryption_version as 
"encryption_version?" 362 FROM oauth_token t 363 JOIN users u ON t.did = u.did 364 LEFT JOIN user_keys k ON u.id = k.user_id 365 WHERE t.token_id = $1"#, 366 oauth_info.token_id 367 ) 368 .fetch_optional(db) 369 .await 370 .ok() 371 .flatten() 372 { 373 let status = AccountStatus::from_db_fields( 374 oauth_token.takedown_ref.as_deref(), 375 oauth_token.deactivated_at, 376 ); 377 378 if !allow_deactivated && status.is_deactivated() { 379 return Err(TokenValidationError::AccountDeactivated); 380 } 381 382 if !allow_takendown && status.is_takendown() { 383 return Err(TokenValidationError::AccountTakedown); 384 } 385 386 let now = chrono::Utc::now(); 387 if oauth_token.expires_at > now { 388 let key_bytes = if let (Some(kb), Some(ev)) = 389 (&oauth_token.key_bytes, oauth_token.encryption_version) 390 { 391 crate::config::decrypt_key(kb, Some(ev)).ok() 392 } else { 393 None 394 }; 395 return Ok(AuthenticatedUser { 396 did: Did::new_unchecked(oauth_token.did), 397 key_bytes, 398 is_oauth: true, 399 is_admin: oauth_token.is_admin, 400 status, 401 scope: oauth_info.scope, 402 controller_did: oauth_info.controller_did.map(Did::new_unchecked), 403 }); 404 } else { 405 return Err(TokenValidationError::TokenExpired); 406 } 407 } 408 409 Err(TokenValidationError::AuthenticationFailed) 410} 411 412pub async fn invalidate_auth_cache(cache: &dyn Cache, did: &str) { 413 let key_cache_key = format!("auth:key:{}", did); 414 let status_cache_key = format!("auth:status:{}", did); 415 let _ = cache.delete(&key_cache_key).await; 416 let _ = cache.delete(&status_cache_key).await; 417} 418 419#[allow(clippy::too_many_arguments)] 420pub async fn validate_token_with_dpop( 421 db: &PgPool, 422 token: &str, 423 is_dpop_token: bool, 424 dpop_proof: Option<&str>, 425 http_method: &str, 426 http_uri: &str, 427 allow_deactivated: bool, 428 allow_takendown: bool, 429) -> Result<AuthenticatedUser, TokenValidationError> { 430 if !is_dpop_token { 431 if allow_takendown { 432 return 
validate_bearer_token_allow_takendown(db, token).await; 433 } else if allow_deactivated { 434 return validate_bearer_token_allow_deactivated(db, token).await; 435 } else { 436 return validate_bearer_token(db, token).await; 437 } 438 } 439 match crate::oauth::verify::verify_oauth_access_token( 440 db, 441 token, 442 dpop_proof, 443 http_method, 444 http_uri, 445 ) 446 .await 447 { 448 Ok(result) => { 449 let user_info = sqlx::query!( 450 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin, 451 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 452 FROM users u 453 LEFT JOIN user_keys k ON u.id = k.user_id 454 WHERE u.did = $1"#, 455 result.did 456 ) 457 .fetch_optional(db) 458 .await 459 .ok() 460 .flatten(); 461 let Some(user_info) = user_info else { 462 return Err(TokenValidationError::AuthenticationFailed); 463 }; 464 let status = AccountStatus::from_db_fields( 465 user_info.takedown_ref.as_deref(), 466 user_info.deactivated_at, 467 ); 468 if !allow_deactivated && status.is_deactivated() { 469 return Err(TokenValidationError::AccountDeactivated); 470 } 471 if !allow_takendown && status.is_takendown() { 472 return Err(TokenValidationError::AccountTakedown); 473 } 474 let key_bytes = if let (Some(kb), Some(ev)) = 475 (&user_info.key_bytes, user_info.encryption_version) 476 { 477 crate::config::decrypt_key(kb, Some(ev)).ok() 478 } else { 479 None 480 }; 481 Ok(AuthenticatedUser { 482 did: Did::new_unchecked(result.did), 483 key_bytes, 484 is_oauth: true, 485 is_admin: user_info.is_admin, 486 status, 487 scope: result.scope, 488 controller_did: None, 489 }) 490 } 491 Err(crate::oauth::OAuthError::ExpiredToken(_)) => Err(TokenValidationError::TokenExpired), 492 Err(_) => Err(TokenValidationError::AuthenticationFailed), 493 } 494} 495 496#[derive(Debug, Clone, Serialize, Deserialize)] 497pub struct ActClaim { 498 pub sub: String, 499} 500 501#[derive(Debug, Serialize, Deserialize)] 502pub struct Claims { 503 pub iss: String, 504 pub sub: 
String, 505 pub aud: String, 506 pub exp: usize, 507 pub iat: usize, 508 #[serde(skip_serializing_if = "Option::is_none")] 509 pub scope: Option<String>, 510 #[serde(skip_serializing_if = "Option::is_none")] 511 pub lxm: Option<String>, 512 pub jti: String, 513 #[serde(skip_serializing_if = "Option::is_none")] 514 pub act: Option<ActClaim>, 515} 516 517#[derive(Debug, Serialize, Deserialize)] 518pub struct Header { 519 pub alg: String, 520 pub typ: String, 521} 522 523#[derive(Debug, Serialize, Deserialize)] 524pub struct UnsafeClaims { 525 pub iss: String, 526 pub sub: Option<String>, 527} 528 529pub struct TokenData<T> { 530 pub claims: T, 531}