this repo has no description
1use serde::{Deserialize, Serialize}; 2use sqlx::PgPool; 3use std::fmt; 4use std::sync::Arc; 5use std::time::Duration; 6 7use crate::cache::Cache; 8use crate::oauth::scopes::ScopePermissions; 9 10pub mod extractor; 11pub mod scope_check; 12pub mod service; 13pub mod token; 14pub mod totp; 15pub mod verification_token; 16pub mod verify; 17pub mod webauthn; 18 19pub use extractor::{ 20 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 21 extract_auth_token_from_header, extract_bearer_token_from_header, 22}; 23pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 24pub use token::{ 25 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 26 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, 27 create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata, 28 create_service_token, 29}; 30pub use verify::{ 31 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token, 32}; 33 34const KEY_CACHE_TTL_SECS: u64 = 300; 35const SESSION_CACHE_TTL_SECS: u64 = 60; 36 37#[derive(Debug, Clone, Copy, PartialEq, Eq)] 38pub enum TokenValidationError { 39 AccountDeactivated, 40 AccountTakedown, 41 KeyDecryptionFailed, 42 AuthenticationFailed, 43} 44 45impl fmt::Display for TokenValidationError { 46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 47 match self { 48 Self::AccountDeactivated => write!(f, "AccountDeactivated"), 49 Self::AccountTakedown => write!(f, "AccountTakedown"), 50 Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), 51 Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), 52 } 53 } 54} 55 56pub struct AuthenticatedUser { 57 pub did: String, 58 pub key_bytes: Option<Vec<u8>>, 59 pub is_oauth: bool, 60 pub is_admin: bool, 61 pub scope: Option<String>, 62} 63 64impl AuthenticatedUser { 65 pub fn permissions(&self) -> ScopePermissions { 66 if !self.is_oauth { 67 return ScopePermissions::from_scope_string(Some("atproto")); 68 } 69 ScopePermissions::from_scope_string(self.scope.as_deref()) 70 } 71} 72 73pub async fn validate_bearer_token( 74 db: &PgPool, 75 token: &str, 76) -> Result<AuthenticatedUser, TokenValidationError> { 77 validate_bearer_token_with_options_internal(db, None, token, false, false).await 78} 79 80pub async fn validate_bearer_token_allow_deactivated( 81 db: &PgPool, 82 token: &str, 83) -> Result<AuthenticatedUser, TokenValidationError> { 84 validate_bearer_token_with_options_internal(db, None, token, true, false).await 85} 86 87pub async fn validate_bearer_token_cached( 88 db: &PgPool, 89 cache: &Arc<dyn Cache>, 90 token: &str, 91) -> Result<AuthenticatedUser, TokenValidationError> { 92 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await 93} 94 95pub async fn validate_bearer_token_cached_allow_deactivated( 96 db: &PgPool, 97 cache: &Arc<dyn Cache>, 98 token: &str, 99) -> Result<AuthenticatedUser, TokenValidationError> { 100 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await 101} 102 103pub async fn validate_bearer_token_for_service_auth( 104 db: &PgPool, 105 token: &str, 106) -> Result<AuthenticatedUser, TokenValidationError> { 107 validate_bearer_token_with_options_internal(db, None, token, true, true).await 108} 109 110async fn validate_bearer_token_with_options_internal( 111 db: &PgPool, 112 cache: Option<&Arc<dyn Cache>>, 113 token: &str, 114 allow_deactivated: bool, 115 allow_takendown: bool, 116) -> Result<AuthenticatedUser, TokenValidationError> { 117 let did_from_token = get_did_from_token(token).ok(); 118 119 if let Some(ref did) = did_from_token { 120 let key_cache_key = format!("auth:key:{}", did); 121 let mut cached_key: Option<Vec<u8>> = None; 122 123 if let Some(c) = cache { 124 cached_key = c.get_bytes(&key_cache_key).await; 125 if cached_key.is_some() { 126 crate::metrics::record_auth_cache_hit("key"); 127 } else { 128 crate::metrics::record_auth_cache_miss("key"); 129 } 130 } 131 132 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 133 { 134 let user_status = sqlx::query!( 135 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 136 did 137 ) 138 .fetch_optional(db) 139 .await 140 .ok() 141 .flatten(); 142 143 match user_status { 144 Some(status) => ( 145 Some(key), 146 status.deactivated_at, 147 status.takedown_ref, 148 status.is_admin, 149 ), 150 None => (None, None, None, false), 151 } 152 } else if let Some(user) = sqlx::query!( 153 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin 154 FROM users u 155 JOIN user_keys k ON u.id = k.user_id 156 WHERE u.did = $1", 157 did 158 ) 159 .fetch_optional(db) 160 .await 161 .ok() 162 .flatten() 163 { 164 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 165 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 166 167 if let Some(c) = cache { 168 let _ = c 169 .set_bytes( 170 &key_cache_key, 171 &key, 172 Duration::from_secs(KEY_CACHE_TTL_SECS), 173 ) 174 .await; 175 } 176 177 ( 178 Some(key), 179 user.deactivated_at, 180 user.takedown_ref, 181 user.is_admin, 182 ) 183 } else { 184 (None, None, None, false) 185 }; 186 187 if let Some(decrypted_key) = decrypted_key { 188 if !allow_deactivated && deactivated_at.is_some() { 189 return Err(TokenValidationError::AccountDeactivated); 190 } 191 192 if !allow_takendown && takedown_ref.is_some() { 193 return Err(TokenValidationError::AccountTakedown); 194 } 195 196 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 197 let jti = &token_data.claims.jti; 198 let session_cache_key = format!("auth:session:{}:{}", did, jti); 199 let mut session_valid = false; 200 201 if let Some(c) = cache { 202 if let Some(cached_value) = c.get(&session_cache_key).await { 203 session_valid = cached_value == "1"; 204 crate::metrics::record_auth_cache_hit("session"); 205 } else { 206 crate::metrics::record_auth_cache_miss("session"); 207 } 208 } 209 210 if !session_valid { 211 let session_exists = sqlx::query_scalar!( 212 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", 213 did, 214 jti 215 ) 216 .fetch_optional(db) 217 .await 218 .ok() 219 .flatten(); 220 221 session_valid = session_exists.is_some(); 222 223 if session_valid && let Some(c) = cache { 224 let _ = c 225 .set( 226 &session_cache_key, 227 "1", 228 Duration::from_secs(SESSION_CACHE_TTL_SECS), 229 ) 230 .await; 231 } 232 } 233 234 if session_valid { 235 return Ok(AuthenticatedUser { 236 did: did.clone(), 237 key_bytes: Some(decrypted_key), 238 is_oauth: false, 239 is_admin, 240 scope: None, 241 }); 242 } 243 } 244 } 245 } 246 247 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) 248 && let Some(oauth_token) = sqlx::query!( 249 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin, 250 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 251 FROM oauth_token t 252 JOIN users u ON t.did = u.did 253 LEFT JOIN user_keys k ON u.id = k.user_id 254 WHERE t.token_id = $1"#, 255 oauth_info.token_id 256 ) 257 .fetch_optional(db) 258 .await 259 .ok() 260 .flatten() 261 { 262 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 263 return Err(TokenValidationError::AccountDeactivated); 264 } 265 266 if oauth_token.takedown_ref.is_some() { 267 return Err(TokenValidationError::AccountTakedown); 268 } 269 270 let now = chrono::Utc::now(); 271 if oauth_token.expires_at > now { 272 let key_bytes = if let (Some(kb), Some(ev)) = 273 (&oauth_token.key_bytes, oauth_token.encryption_version) 274 { 275 crate::config::decrypt_key(kb, Some(ev)).ok() 276 } else { 277 None 278 }; 279 return Ok(AuthenticatedUser { 280 did: oauth_token.did, 281 key_bytes, 282 is_oauth: true, 283 is_admin: oauth_token.is_admin, 284 scope: oauth_info.scope, 285 }); 286 } 287 } 288 289 Err(TokenValidationError::AuthenticationFailed) 290} 291 292pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 293 let key_cache_key = format!("auth:key:{}", did); 294 let _ = cache.delete(&key_cache_key).await; 295} 296 297pub async fn validate_token_with_dpop( 298 db: &PgPool, 299 token: &str, 300 is_dpop_token: bool, 301 dpop_proof: Option<&str>, 302 http_method: &str, 303 http_uri: &str, 304 allow_deactivated: bool, 305) -> Result<AuthenticatedUser, TokenValidationError> { 306 if !is_dpop_token { 307 if allow_deactivated { 308 return validate_bearer_token_allow_deactivated(db, token).await; 309 } else { 310 return validate_bearer_token(db, token).await; 311 } 312 } 313 match crate::oauth::verify::verify_oauth_access_token( 314 db, 315 token, 316 dpop_proof, 317 http_method, 318 http_uri, 319 ) 320 .await 321 { 322 Ok(result) => { 323 let user_info = sqlx::query!( 324 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin, 325 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 326 FROM users u 327 LEFT JOIN user_keys k ON u.id = k.user_id 328 WHERE u.did = $1"#, 329 result.did 330 ) 331 .fetch_optional(db) 332 .await 333 .ok() 334 .flatten(); 335 let Some(user_info) = user_info else { 336 return Err(TokenValidationError::AuthenticationFailed); 337 }; 338 if !allow_deactivated && user_info.deactivated_at.is_some() { 339 return Err(TokenValidationError::AccountDeactivated); 340 } 341 if user_info.takedown_ref.is_some() { 342 return Err(TokenValidationError::AccountTakedown); 343 } 344 let key_bytes = if let (Some(kb), Some(ev)) = 345 (&user_info.key_bytes, user_info.encryption_version) 346 { 347 crate::config::decrypt_key(kb, Some(ev)).ok() 348 } else { 349 None 350 }; 351 Ok(AuthenticatedUser { 352 did: result.did, 353 key_bytes, 354 is_oauth: true, 355 is_admin: user_info.is_admin, 356 scope: result.scope, 357 }) 358 } 359 Err(_) => Err(TokenValidationError::AuthenticationFailed), 360 } 361} 362 363#[derive(Debug, Serialize, Deserialize)] 364pub struct Claims { 365 pub iss: String, 366 pub sub: String, 367 pub aud: String, 368 pub exp: usize, 369 pub iat: usize, 370 #[serde(skip_serializing_if = "Option::is_none")] 371 pub scope: Option<String>, 372 #[serde(skip_serializing_if = "Option::is_none")] 373 pub lxm: Option<String>, 374 pub jti: String, 375} 376 377#[derive(Debug, Serialize, Deserialize)] 378pub struct Header { 379 pub alg: String, 380 pub typ: String, 381} 382 383#[derive(Debug, Serialize, Deserialize)] 384pub struct UnsafeClaims { 385 pub iss: String, 386 pub sub: Option<String>, 387} 388 389pub struct TokenData<T> { 390 pub claims: T, 391}