this repo has no description
1use serde::{Deserialize, Serialize}; 2use sqlx::PgPool; 3use std::fmt; 4use std::sync::Arc; 5use std::time::Duration; 6 7use crate::cache::Cache; 8use crate::oauth::scopes::ScopePermissions; 9 10pub mod extractor; 11pub mod scope_check; 12pub mod service; 13pub mod token; 14pub mod totp; 15pub mod verify; 16pub mod webauthn; 17 18pub use extractor::{ 19 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 20 extract_auth_token_from_header, extract_bearer_token_from_header, 21}; 22pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 23pub use token::{ 24 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 25 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, 26 create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata, 27 create_service_token, 28}; 29pub use verify::{ 30 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token, 31}; 32 33const KEY_CACHE_TTL_SECS: u64 = 300; 34const SESSION_CACHE_TTL_SECS: u64 = 60; 35 36#[derive(Debug, Clone, Copy, PartialEq, Eq)] 37pub enum TokenValidationError { 38 AccountDeactivated, 39 AccountTakedown, 40 KeyDecryptionFailed, 41 AuthenticationFailed, 42} 43 44impl fmt::Display for TokenValidationError { 45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 46 match self { 47 Self::AccountDeactivated => write!(f, "AccountDeactivated"), 48 Self::AccountTakedown => write!(f, "AccountTakedown"), 49 Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), 50 Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), 51 } 52 } 53} 54 55pub struct AuthenticatedUser { 56 pub did: String, 57 pub key_bytes: Option<Vec<u8>>, 58 pub is_oauth: bool, 59 pub is_admin: bool, 60 pub scope: Option<String>, 61} 62 63impl AuthenticatedUser { 64 pub fn permissions(&self) -> ScopePermissions { 65 if !self.is_oauth { 66 return ScopePermissions::from_scope_string(Some("atproto")); 67 } 68 ScopePermissions::from_scope_string(self.scope.as_deref()) 69 } 70} 71 72pub async fn validate_bearer_token( 73 db: &PgPool, 74 token: &str, 75) -> Result<AuthenticatedUser, TokenValidationError> { 76 validate_bearer_token_with_options_internal(db, None, token, false, false).await 77} 78 79pub async fn validate_bearer_token_allow_deactivated( 80 db: &PgPool, 81 token: &str, 82) -> Result<AuthenticatedUser, TokenValidationError> { 83 validate_bearer_token_with_options_internal(db, None, token, true, false).await 84} 85 86pub async fn validate_bearer_token_cached( 87 db: &PgPool, 88 cache: &Arc<dyn Cache>, 89 token: &str, 90) -> Result<AuthenticatedUser, TokenValidationError> { 91 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await 92} 93 94pub async fn validate_bearer_token_cached_allow_deactivated( 95 db: &PgPool, 96 cache: &Arc<dyn Cache>, 97 token: &str, 98) -> Result<AuthenticatedUser, TokenValidationError> { 99 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await 100} 101 102pub async fn validate_bearer_token_for_service_auth( 103 db: &PgPool, 104 token: &str, 105) -> Result<AuthenticatedUser, TokenValidationError> { 106 validate_bearer_token_with_options_internal(db, None, token, true, true).await 107} 108 109async fn validate_bearer_token_with_options_internal( 110 db: &PgPool, 111 cache: Option<&Arc<dyn Cache>>, 112 token: &str, 113 allow_deactivated: bool, 114 allow_takendown: bool, 115) -> Result<AuthenticatedUser, TokenValidationError> { 116 let did_from_token = get_did_from_token(token).ok(); 117 118 if let Some(ref did) = did_from_token { 119 let key_cache_key = format!("auth:key:{}", did); 120 let mut cached_key: Option<Vec<u8>> = None; 121 122 if let Some(c) = cache { 123 cached_key = c.get_bytes(&key_cache_key).await; 124 if cached_key.is_some() { 125 crate::metrics::record_auth_cache_hit("key"); 126 } else { 127 crate::metrics::record_auth_cache_miss("key"); 128 } 129 } 130 131 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 132 { 133 let user_status = sqlx::query!( 134 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 135 did 136 ) 137 .fetch_optional(db) 138 .await 139 .ok() 140 .flatten(); 141 142 match user_status { 143 Some(status) => ( 144 Some(key), 145 status.deactivated_at, 146 status.takedown_ref, 147 status.is_admin, 148 ), 149 None => (None, None, None, false), 150 } 151 } else if let Some(user) = sqlx::query!( 152 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin 153 FROM users u 154 JOIN user_keys k ON u.id = k.user_id 155 WHERE u.did = $1", 156 did 157 ) 158 .fetch_optional(db) 159 .await 160 .ok() 161 .flatten() 162 { 163 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 164 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 165 166 if let Some(c) = cache { 167 let _ = c 168 .set_bytes( 169 &key_cache_key, 170 &key, 171 Duration::from_secs(KEY_CACHE_TTL_SECS), 172 ) 173 .await; 174 } 175 176 ( 177 Some(key), 178 user.deactivated_at, 179 user.takedown_ref, 180 user.is_admin, 181 ) 182 } else { 183 (None, None, None, false) 184 }; 185 186 if let Some(decrypted_key) = decrypted_key { 187 if !allow_deactivated && deactivated_at.is_some() { 188 return Err(TokenValidationError::AccountDeactivated); 189 } 190 191 if !allow_takendown && takedown_ref.is_some() { 192 return Err(TokenValidationError::AccountTakedown); 193 } 194 195 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 196 let jti = &token_data.claims.jti; 197 let session_cache_key = format!("auth:session:{}:{}", did, jti); 198 let mut session_valid = false; 199 200 if let Some(c) = cache { 201 if let Some(cached_value) = c.get(&session_cache_key).await { 202 session_valid = cached_value == "1"; 203 crate::metrics::record_auth_cache_hit("session"); 204 } else { 205 crate::metrics::record_auth_cache_miss("session"); 206 } 207 } 208 209 if !session_valid { 210 let session_exists = sqlx::query_scalar!( 211 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", 212 did, 213 jti 214 ) 215 .fetch_optional(db) 216 .await 217 .ok() 218 .flatten(); 219 220 session_valid = session_exists.is_some(); 221 222 if session_valid && let Some(c) = cache { 223 let _ = c 224 .set( 225 &session_cache_key, 226 "1", 227 Duration::from_secs(SESSION_CACHE_TTL_SECS), 228 ) 229 .await; 230 } 231 } 232 233 if session_valid { 234 return Ok(AuthenticatedUser { 235 did: did.clone(), 236 key_bytes: Some(decrypted_key), 237 is_oauth: false, 238 is_admin, 239 scope: None, 240 }); 241 } 242 } 243 } 244 } 245 246 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) 247 && let Some(oauth_token) = sqlx::query!( 248 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin, 249 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 250 FROM oauth_token t 251 JOIN users u ON t.did = u.did 252 LEFT JOIN user_keys k ON u.id = k.user_id 253 WHERE t.token_id = $1"#, 254 oauth_info.token_id 255 ) 256 .fetch_optional(db) 257 .await 258 .ok() 259 .flatten() 260 { 261 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 262 return Err(TokenValidationError::AccountDeactivated); 263 } 264 265 if oauth_token.takedown_ref.is_some() { 266 return Err(TokenValidationError::AccountTakedown); 267 } 268 269 let now = chrono::Utc::now(); 270 if oauth_token.expires_at > now { 271 let key_bytes = if let (Some(kb), Some(ev)) = 272 (&oauth_token.key_bytes, oauth_token.encryption_version) 273 { 274 crate::config::decrypt_key(kb, Some(ev)).ok() 275 } else { 276 None 277 }; 278 return Ok(AuthenticatedUser { 279 did: oauth_token.did, 280 key_bytes, 281 is_oauth: true, 282 is_admin: oauth_token.is_admin, 283 scope: oauth_info.scope, 284 }); 285 } 286 } 287 288 Err(TokenValidationError::AuthenticationFailed) 289} 290 291pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 292 let key_cache_key = format!("auth:key:{}", did); 293 let _ = cache.delete(&key_cache_key).await; 294} 295 296pub async fn validate_token_with_dpop( 297 db: &PgPool, 298 token: &str, 299 is_dpop_token: bool, 300 dpop_proof: Option<&str>, 301 http_method: &str, 302 http_uri: &str, 303 allow_deactivated: bool, 304) -> Result<AuthenticatedUser, TokenValidationError> { 305 if !is_dpop_token { 306 if allow_deactivated { 307 return validate_bearer_token_allow_deactivated(db, token).await; 308 } else { 309 return validate_bearer_token(db, token).await; 310 } 311 } 312 match crate::oauth::verify::verify_oauth_access_token( 313 db, 314 token, 315 dpop_proof, 316 http_method, 317 http_uri, 318 ) 319 .await 320 { 321 Ok(result) => { 322 let user_info = sqlx::query!( 323 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin, 324 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 325 FROM users u 326 LEFT JOIN user_keys k ON u.id = k.user_id 327 WHERE u.did = $1"#, 328 result.did 329 ) 330 .fetch_optional(db) 331 .await 332 .ok() 333 .flatten(); 334 let Some(user_info) = user_info else { 335 return Err(TokenValidationError::AuthenticationFailed); 336 }; 337 if !allow_deactivated && user_info.deactivated_at.is_some() { 338 return Err(TokenValidationError::AccountDeactivated); 339 } 340 if user_info.takedown_ref.is_some() { 341 return Err(TokenValidationError::AccountTakedown); 342 } 343 let key_bytes = if let (Some(kb), Some(ev)) = 344 (&user_info.key_bytes, user_info.encryption_version) 345 { 346 crate::config::decrypt_key(kb, Some(ev)).ok() 347 } else { 348 None 349 }; 350 Ok(AuthenticatedUser { 351 did: result.did, 352 key_bytes, 353 is_oauth: true, 354 is_admin: user_info.is_admin, 355 scope: result.scope, 356 }) 357 } 358 Err(_) => Err(TokenValidationError::AuthenticationFailed), 359 } 360} 361 362#[derive(Debug, Serialize, Deserialize)] 363pub struct Claims { 364 pub iss: String, 365 pub sub: String, 366 pub aud: String, 367 pub exp: usize, 368 pub iat: usize, 369 #[serde(skip_serializing_if = "Option::is_none")] 370 pub scope: Option<String>, 371 #[serde(skip_serializing_if = "Option::is_none")] 372 pub lxm: Option<String>, 373 pub jti: String, 374} 375 376#[derive(Debug, Serialize, Deserialize)] 377pub struct Header { 378 pub alg: String, 379 pub typ: String, 380} 381 382#[derive(Debug, Serialize, Deserialize)] 383pub struct UnsafeClaims { 384 pub iss: String, 385 pub sub: Option<String>, 386} 387 388pub struct TokenData<T> { 389 pub claims: T, 390}