This repository has no description.
1use serde::{Deserialize, Serialize}; 2use sqlx::PgPool; 3use std::fmt; 4use std::sync::Arc; 5use std::time::Duration; 6 7use crate::cache::Cache; 8use crate::oauth::scopes::ScopePermissions; 9 10pub mod extractor; 11pub mod scope_check; 12pub mod service; 13pub mod token; 14pub mod verify; 15 16pub use extractor::{ 17 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 18 extract_auth_token_from_header, extract_bearer_token_from_header, 19}; 20pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 21pub use token::{ 22 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 23 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, 24 create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata, 25 create_service_token, 26}; 27pub use verify::{ 28 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token, 29}; 30 31const KEY_CACHE_TTL_SECS: u64 = 300; 32const SESSION_CACHE_TTL_SECS: u64 = 60; 33 34#[derive(Debug, Clone, Copy, PartialEq, Eq)] 35pub enum TokenValidationError { 36 AccountDeactivated, 37 AccountTakedown, 38 KeyDecryptionFailed, 39 AuthenticationFailed, 40} 41 42impl fmt::Display for TokenValidationError { 43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 44 match self { 45 Self::AccountDeactivated => write!(f, "AccountDeactivated"), 46 Self::AccountTakedown => write!(f, "AccountTakedown"), 47 Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), 48 Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), 49 } 50 } 51} 52 53pub struct AuthenticatedUser { 54 pub did: String, 55 pub key_bytes: Option<Vec<u8>>, 56 pub is_oauth: bool, 57 pub is_admin: bool, 58 pub scope: Option<String>, 59} 60 61impl AuthenticatedUser { 62 pub fn permissions(&self) -> ScopePermissions { 63 if !self.is_oauth { 64 return 
ScopePermissions::from_scope_string(Some("atproto")); 65 } 66 ScopePermissions::from_scope_string(self.scope.as_deref()) 67 } 68} 69 70pub async fn validate_bearer_token( 71 db: &PgPool, 72 token: &str, 73) -> Result<AuthenticatedUser, TokenValidationError> { 74 validate_bearer_token_with_options_internal(db, None, token, false, false).await 75} 76 77pub async fn validate_bearer_token_allow_deactivated( 78 db: &PgPool, 79 token: &str, 80) -> Result<AuthenticatedUser, TokenValidationError> { 81 validate_bearer_token_with_options_internal(db, None, token, true, false).await 82} 83 84pub async fn validate_bearer_token_cached( 85 db: &PgPool, 86 cache: &Arc<dyn Cache>, 87 token: &str, 88) -> Result<AuthenticatedUser, TokenValidationError> { 89 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await 90} 91 92pub async fn validate_bearer_token_cached_allow_deactivated( 93 db: &PgPool, 94 cache: &Arc<dyn Cache>, 95 token: &str, 96) -> Result<AuthenticatedUser, TokenValidationError> { 97 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await 98} 99 100pub async fn validate_bearer_token_for_service_auth( 101 db: &PgPool, 102 token: &str, 103) -> Result<AuthenticatedUser, TokenValidationError> { 104 validate_bearer_token_with_options_internal(db, None, token, true, true).await 105} 106 107async fn validate_bearer_token_with_options_internal( 108 db: &PgPool, 109 cache: Option<&Arc<dyn Cache>>, 110 token: &str, 111 allow_deactivated: bool, 112 allow_takendown: bool, 113) -> Result<AuthenticatedUser, TokenValidationError> { 114 let did_from_token = get_did_from_token(token).ok(); 115 116 if let Some(ref did) = did_from_token { 117 let key_cache_key = format!("auth:key:{}", did); 118 let mut cached_key: Option<Vec<u8>> = None; 119 120 if let Some(c) = cache { 121 cached_key = c.get_bytes(&key_cache_key).await; 122 if cached_key.is_some() { 123 crate::metrics::record_auth_cache_hit("key"); 124 } else { 125 
crate::metrics::record_auth_cache_miss("key"); 126 } 127 } 128 129 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 130 { 131 let user_status = sqlx::query!( 132 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 133 did 134 ) 135 .fetch_optional(db) 136 .await 137 .ok() 138 .flatten(); 139 140 match user_status { 141 Some(status) => ( 142 Some(key), 143 status.deactivated_at, 144 status.takedown_ref, 145 status.is_admin, 146 ), 147 None => (None, None, None, false), 148 } 149 } else if let Some(user) = sqlx::query!( 150 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin 151 FROM users u 152 JOIN user_keys k ON u.id = k.user_id 153 WHERE u.did = $1", 154 did 155 ) 156 .fetch_optional(db) 157 .await 158 .ok() 159 .flatten() 160 { 161 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 162 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 163 164 if let Some(c) = cache { 165 let _ = c 166 .set_bytes( 167 &key_cache_key, 168 &key, 169 Duration::from_secs(KEY_CACHE_TTL_SECS), 170 ) 171 .await; 172 } 173 174 ( 175 Some(key), 176 user.deactivated_at, 177 user.takedown_ref, 178 user.is_admin, 179 ) 180 } else { 181 (None, None, None, false) 182 }; 183 184 if let Some(decrypted_key) = decrypted_key { 185 if !allow_deactivated && deactivated_at.is_some() { 186 return Err(TokenValidationError::AccountDeactivated); 187 } 188 189 if !allow_takendown && takedown_ref.is_some() { 190 return Err(TokenValidationError::AccountTakedown); 191 } 192 193 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 194 let jti = &token_data.claims.jti; 195 let session_cache_key = format!("auth:session:{}:{}", did, jti); 196 let mut session_valid = false; 197 198 if let Some(c) = cache { 199 if let Some(cached_value) = c.get(&session_cache_key).await { 200 session_valid = cached_value == "1"; 201 
crate::metrics::record_auth_cache_hit("session"); 202 } else { 203 crate::metrics::record_auth_cache_miss("session"); 204 } 205 } 206 207 if !session_valid { 208 let session_exists = sqlx::query_scalar!( 209 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", 210 did, 211 jti 212 ) 213 .fetch_optional(db) 214 .await 215 .ok() 216 .flatten(); 217 218 session_valid = session_exists.is_some(); 219 220 if session_valid && let Some(c) = cache { 221 let _ = c 222 .set( 223 &session_cache_key, 224 "1", 225 Duration::from_secs(SESSION_CACHE_TTL_SECS), 226 ) 227 .await; 228 } 229 } 230 231 if session_valid { 232 return Ok(AuthenticatedUser { 233 did: did.clone(), 234 key_bytes: Some(decrypted_key), 235 is_oauth: false, 236 is_admin, 237 scope: None, 238 }); 239 } 240 } 241 } 242 } 243 244 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) 245 && let Some(oauth_token) = sqlx::query!( 246 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin, 247 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 
248 FROM oauth_token t 249 JOIN users u ON t.did = u.did 250 LEFT JOIN user_keys k ON u.id = k.user_id 251 WHERE t.token_id = $1"#, 252 oauth_info.token_id 253 ) 254 .fetch_optional(db) 255 .await 256 .ok() 257 .flatten() 258 { 259 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 260 return Err(TokenValidationError::AccountDeactivated); 261 } 262 263 if oauth_token.takedown_ref.is_some() { 264 return Err(TokenValidationError::AccountTakedown); 265 } 266 267 let now = chrono::Utc::now(); 268 if oauth_token.expires_at > now { 269 let key_bytes = if let (Some(kb), Some(ev)) = 270 (&oauth_token.key_bytes, oauth_token.encryption_version) 271 { 272 crate::config::decrypt_key(kb, Some(ev)).ok() 273 } else { 274 None 275 }; 276 return Ok(AuthenticatedUser { 277 did: oauth_token.did, 278 key_bytes, 279 is_oauth: true, 280 is_admin: oauth_token.is_admin, 281 scope: oauth_info.scope, 282 }); 283 } 284 } 285 286 Err(TokenValidationError::AuthenticationFailed) 287} 288 289pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 290 let key_cache_key = format!("auth:key:{}", did); 291 let _ = cache.delete(&key_cache_key).await; 292} 293 294pub async fn validate_token_with_dpop( 295 db: &PgPool, 296 token: &str, 297 is_dpop_token: bool, 298 dpop_proof: Option<&str>, 299 http_method: &str, 300 http_uri: &str, 301 allow_deactivated: bool, 302) -> Result<AuthenticatedUser, TokenValidationError> { 303 if !is_dpop_token { 304 if allow_deactivated { 305 return validate_bearer_token_allow_deactivated(db, token).await; 306 } else { 307 return validate_bearer_token(db, token).await; 308 } 309 } 310 match crate::oauth::verify::verify_oauth_access_token( 311 db, 312 token, 313 dpop_proof, 314 http_method, 315 http_uri, 316 ) 317 .await 318 { 319 Ok(result) => { 320 let user_info = sqlx::query!( 321 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin, 322 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 
323 FROM users u 324 LEFT JOIN user_keys k ON u.id = k.user_id 325 WHERE u.did = $1"#, 326 result.did 327 ) 328 .fetch_optional(db) 329 .await 330 .ok() 331 .flatten(); 332 let Some(user_info) = user_info else { 333 return Err(TokenValidationError::AuthenticationFailed); 334 }; 335 if !allow_deactivated && user_info.deactivated_at.is_some() { 336 return Err(TokenValidationError::AccountDeactivated); 337 } 338 if user_info.takedown_ref.is_some() { 339 return Err(TokenValidationError::AccountTakedown); 340 } 341 let key_bytes = if let (Some(kb), Some(ev)) = 342 (&user_info.key_bytes, user_info.encryption_version) 343 { 344 crate::config::decrypt_key(kb, Some(ev)).ok() 345 } else { 346 None 347 }; 348 Ok(AuthenticatedUser { 349 did: result.did, 350 key_bytes, 351 is_oauth: true, 352 is_admin: user_info.is_admin, 353 scope: result.scope, 354 }) 355 } 356 Err(_) => Err(TokenValidationError::AuthenticationFailed), 357 } 358} 359 360#[derive(Debug, Serialize, Deserialize)] 361pub struct Claims { 362 pub iss: String, 363 pub sub: String, 364 pub aud: String, 365 pub exp: usize, 366 pub iat: usize, 367 #[serde(skip_serializing_if = "Option::is_none")] 368 pub scope: Option<String>, 369 #[serde(skip_serializing_if = "Option::is_none")] 370 pub lxm: Option<String>, 371 pub jti: String, 372} 373 374#[derive(Debug, Serialize, Deserialize)] 375pub struct Header { 376 pub alg: String, 377 pub typ: String, 378} 379 380#[derive(Debug, Serialize, Deserialize)] 381pub struct UnsafeClaims { 382 pub iss: String, 383 pub sub: Option<String>, 384} 385 386pub struct TokenData<T> { 387 pub claims: T, 388}