// This repository has no description.
1use serde::{Deserialize, Serialize}; 2use sqlx::PgPool; 3use std::fmt; 4use std::sync::Arc; 5use std::time::Duration; 6 7use crate::cache::Cache; 8use crate::oauth::scopes::ScopePermissions; 9 10pub mod extractor; 11pub mod scope_check; 12pub mod service; 13pub mod token; 14pub mod totp; 15pub mod verification_token; 16pub mod verify; 17pub mod webauthn; 18 19pub use extractor::{ 20 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 21 extract_auth_token_from_header, extract_bearer_token_from_header, 22}; 23pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 24pub use token::{ 25 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 26 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, 27 create_access_token_with_delegation, create_access_token_with_metadata, 28 create_access_token_with_scope_metadata, create_refresh_token, 29 create_refresh_token_with_metadata, create_service_token, 30}; 31pub use verify::{ 32 TokenVerifyError, get_did_from_token, get_jti_from_token, verify_access_token, 33 verify_access_token_typed, verify_refresh_token, verify_token, 34}; 35 36const KEY_CACHE_TTL_SECS: u64 = 300; 37const SESSION_CACHE_TTL_SECS: u64 = 60; 38 39#[derive(Debug, Clone, Copy, PartialEq, Eq)] 40pub enum TokenValidationError { 41 AccountDeactivated, 42 AccountTakedown, 43 KeyDecryptionFailed, 44 AuthenticationFailed, 45 TokenExpired, 46} 47 48impl fmt::Display for TokenValidationError { 49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 50 match self { 51 Self::AccountDeactivated => write!(f, "AccountDeactivated"), 52 Self::AccountTakedown => write!(f, "AccountTakedown"), 53 Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), 54 Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), 55 Self::TokenExpired => write!(f, "ExpiredToken"), 56 } 57 } 58} 59 60pub struct AuthenticatedUser { 61 pub did: String, 62 pub 
key_bytes: Option<Vec<u8>>, 63 pub is_oauth: bool, 64 pub is_admin: bool, 65 pub is_takendown: bool, 66 pub scope: Option<String>, 67 pub controller_did: Option<String>, 68} 69 70impl AuthenticatedUser { 71 pub fn permissions(&self) -> ScopePermissions { 72 if let Some(ref scope) = self.scope 73 && scope != SCOPE_ACCESS 74 { 75 return ScopePermissions::from_scope_string(Some(scope)); 76 } 77 if !self.is_oauth { 78 return ScopePermissions::from_scope_string(Some("atproto")); 79 } 80 ScopePermissions::from_scope_string(self.scope.as_deref()) 81 } 82} 83 84pub async fn validate_bearer_token( 85 db: &PgPool, 86 token: &str, 87) -> Result<AuthenticatedUser, TokenValidationError> { 88 validate_bearer_token_with_options_internal(db, None, token, false, false).await 89} 90 91pub async fn validate_bearer_token_allow_deactivated( 92 db: &PgPool, 93 token: &str, 94) -> Result<AuthenticatedUser, TokenValidationError> { 95 validate_bearer_token_with_options_internal(db, None, token, true, false).await 96} 97 98pub async fn validate_bearer_token_cached( 99 db: &PgPool, 100 cache: &Arc<dyn Cache>, 101 token: &str, 102) -> Result<AuthenticatedUser, TokenValidationError> { 103 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await 104} 105 106pub async fn validate_bearer_token_cached_allow_deactivated( 107 db: &PgPool, 108 cache: &Arc<dyn Cache>, 109 token: &str, 110) -> Result<AuthenticatedUser, TokenValidationError> { 111 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await 112} 113 114pub async fn validate_bearer_token_for_service_auth( 115 db: &PgPool, 116 token: &str, 117) -> Result<AuthenticatedUser, TokenValidationError> { 118 validate_bearer_token_with_options_internal(db, None, token, true, true).await 119} 120 121pub async fn validate_bearer_token_allow_takendown( 122 db: &PgPool, 123 token: &str, 124) -> Result<AuthenticatedUser, TokenValidationError> { 125 validate_bearer_token_with_options_internal(db, 
None, token, false, true).await 126} 127 128async fn validate_bearer_token_with_options_internal( 129 db: &PgPool, 130 cache: Option<&Arc<dyn Cache>>, 131 token: &str, 132 allow_deactivated: bool, 133 allow_takendown: bool, 134) -> Result<AuthenticatedUser, TokenValidationError> { 135 let did_from_token = get_did_from_token(token).ok(); 136 137 if let Some(ref did) = did_from_token { 138 let key_cache_key = format!("auth:key:{}", did); 139 let mut cached_key: Option<Vec<u8>> = None; 140 141 if let Some(c) = cache { 142 cached_key = c.get_bytes(&key_cache_key).await; 143 if cached_key.is_some() { 144 crate::metrics::record_auth_cache_hit("key"); 145 } else { 146 crate::metrics::record_auth_cache_miss("key"); 147 } 148 } 149 150 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 151 { 152 let user_status = sqlx::query!( 153 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 154 did 155 ) 156 .fetch_optional(db) 157 .await 158 .ok() 159 .flatten(); 160 161 match user_status { 162 Some(status) => ( 163 Some(key), 164 status.deactivated_at, 165 status.takedown_ref, 166 status.is_admin, 167 ), 168 None => (None, None, None, false), 169 } 170 } else if let Some(user) = sqlx::query!( 171 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin 172 FROM users u 173 JOIN user_keys k ON u.id = k.user_id 174 WHERE u.did = $1", 175 did 176 ) 177 .fetch_optional(db) 178 .await 179 .ok() 180 .flatten() 181 { 182 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 183 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 184 185 if let Some(c) = cache { 186 let _ = c 187 .set_bytes( 188 &key_cache_key, 189 &key, 190 Duration::from_secs(KEY_CACHE_TTL_SECS), 191 ) 192 .await; 193 } 194 195 ( 196 Some(key), 197 user.deactivated_at, 198 user.takedown_ref, 199 user.is_admin, 200 ) 201 } else { 202 (None, None, None, false) 203 }; 204 205 if let 
Some(decrypted_key) = decrypted_key { 206 if !allow_deactivated && deactivated_at.is_some() { 207 return Err(TokenValidationError::AccountDeactivated); 208 } 209 210 if !allow_takendown && takedown_ref.is_some() { 211 return Err(TokenValidationError::AccountTakedown); 212 } 213 214 match verify_access_token_typed(token, &decrypted_key) { 215 Ok(token_data) => { 216 let jti = &token_data.claims.jti; 217 let session_cache_key = format!("auth:session:{}:{}", did, jti); 218 let mut session_valid = false; 219 220 if let Some(c) = cache { 221 if let Some(cached_value) = c.get(&session_cache_key).await { 222 session_valid = cached_value == "1"; 223 crate::metrics::record_auth_cache_hit("session"); 224 } else { 225 crate::metrics::record_auth_cache_miss("session"); 226 } 227 } 228 229 if !session_valid { 230 let session_row = sqlx::query!( 231 "SELECT access_expires_at FROM session_tokens WHERE did = $1 AND access_jti = $2", 232 did, 233 jti 234 ) 235 .fetch_optional(db) 236 .await 237 .ok() 238 .flatten(); 239 240 if let Some(row) = session_row { 241 if row.access_expires_at > chrono::Utc::now() { 242 session_valid = true; 243 if let Some(c) = cache { 244 let _ = c 245 .set( 246 &session_cache_key, 247 "1", 248 Duration::from_secs(SESSION_CACHE_TTL_SECS), 249 ) 250 .await; 251 } 252 } else { 253 return Err(TokenValidationError::TokenExpired); 254 } 255 } 256 } 257 258 if session_valid { 259 let controller_did = token_data.claims.act.as_ref().map(|a| a.sub.clone()); 260 return Ok(AuthenticatedUser { 261 did: did.clone(), 262 key_bytes: Some(decrypted_key), 263 is_oauth: false, 264 is_admin, 265 is_takendown: takedown_ref.is_some(), 266 scope: token_data.claims.scope.clone(), 267 controller_did, 268 }); 269 } 270 } 271 Err(verify::TokenVerifyError::Expired) => { 272 return Err(TokenValidationError::TokenExpired); 273 } 274 Err(verify::TokenVerifyError::Invalid) => {} 275 } 276 } 277 } 278 279 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) 280 
&& let Some(oauth_token) = sqlx::query!( 281 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin, 282 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 283 FROM oauth_token t 284 JOIN users u ON t.did = u.did 285 LEFT JOIN user_keys k ON u.id = k.user_id 286 WHERE t.token_id = $1"#, 287 oauth_info.token_id 288 ) 289 .fetch_optional(db) 290 .await 291 .ok() 292 .flatten() 293 { 294 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 295 return Err(TokenValidationError::AccountDeactivated); 296 } 297 298 let is_takendown = oauth_token.takedown_ref.is_some(); 299 if !allow_takendown && is_takendown { 300 return Err(TokenValidationError::AccountTakedown); 301 } 302 303 let now = chrono::Utc::now(); 304 if oauth_token.expires_at > now { 305 let key_bytes = if let (Some(kb), Some(ev)) = 306 (&oauth_token.key_bytes, oauth_token.encryption_version) 307 { 308 crate::config::decrypt_key(kb, Some(ev)).ok() 309 } else { 310 None 311 }; 312 return Ok(AuthenticatedUser { 313 did: oauth_token.did, 314 key_bytes, 315 is_oauth: true, 316 is_admin: oauth_token.is_admin, 317 is_takendown, 318 scope: oauth_info.scope, 319 controller_did: oauth_info.controller_did, 320 }); 321 } else { 322 return Err(TokenValidationError::TokenExpired); 323 } 324 } 325 326 Err(TokenValidationError::AuthenticationFailed) 327} 328 329pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 330 let key_cache_key = format!("auth:key:{}", did); 331 let _ = cache.delete(&key_cache_key).await; 332} 333 334pub async fn validate_token_with_dpop( 335 db: &PgPool, 336 token: &str, 337 is_dpop_token: bool, 338 dpop_proof: Option<&str>, 339 http_method: &str, 340 http_uri: &str, 341 allow_deactivated: bool, 342) -> Result<AuthenticatedUser, TokenValidationError> { 343 if !is_dpop_token { 344 if allow_deactivated { 345 return validate_bearer_token_allow_deactivated(db, token).await; 346 } else { 347 return validate_bearer_token(db, 
token).await; 348 } 349 } 350 match crate::oauth::verify::verify_oauth_access_token( 351 db, 352 token, 353 dpop_proof, 354 http_method, 355 http_uri, 356 ) 357 .await 358 { 359 Ok(result) => { 360 let user_info = sqlx::query!( 361 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin, 362 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 363 FROM users u 364 LEFT JOIN user_keys k ON u.id = k.user_id 365 WHERE u.did = $1"#, 366 result.did 367 ) 368 .fetch_optional(db) 369 .await 370 .ok() 371 .flatten(); 372 let Some(user_info) = user_info else { 373 return Err(TokenValidationError::AuthenticationFailed); 374 }; 375 if !allow_deactivated && user_info.deactivated_at.is_some() { 376 return Err(TokenValidationError::AccountDeactivated); 377 } 378 let is_takendown = user_info.takedown_ref.is_some(); 379 if is_takendown { 380 return Err(TokenValidationError::AccountTakedown); 381 } 382 let key_bytes = if let (Some(kb), Some(ev)) = 383 (&user_info.key_bytes, user_info.encryption_version) 384 { 385 crate::config::decrypt_key(kb, Some(ev)).ok() 386 } else { 387 None 388 }; 389 Ok(AuthenticatedUser { 390 did: result.did, 391 key_bytes, 392 is_oauth: true, 393 is_admin: user_info.is_admin, 394 is_takendown, 395 scope: result.scope, 396 controller_did: None, 397 }) 398 } 399 Err(crate::oauth::OAuthError::ExpiredToken(_)) => Err(TokenValidationError::TokenExpired), 400 Err(_) => Err(TokenValidationError::AuthenticationFailed), 401 } 402} 403 404#[derive(Debug, Clone, Serialize, Deserialize)] 405pub struct ActClaim { 406 pub sub: String, 407} 408 409#[derive(Debug, Serialize, Deserialize)] 410pub struct Claims { 411 pub iss: String, 412 pub sub: String, 413 pub aud: String, 414 pub exp: usize, 415 pub iat: usize, 416 #[serde(skip_serializing_if = "Option::is_none")] 417 pub scope: Option<String>, 418 #[serde(skip_serializing_if = "Option::is_none")] 419 pub lxm: Option<String>, 420 pub jti: String, 421 #[serde(skip_serializing_if = "Option::is_none")] 
422 pub act: Option<ActClaim>, 423} 424 425#[derive(Debug, Serialize, Deserialize)] 426pub struct Header { 427 pub alg: String, 428 pub typ: String, 429} 430 431#[derive(Debug, Serialize, Deserialize)] 432pub struct UnsafeClaims { 433 pub iss: String, 434 pub sub: Option<String>, 435} 436 437pub struct TokenData<T> { 438 pub claims: T, 439}