this repo has no description
1use serde::{Deserialize, Serialize}; 2use sqlx::PgPool; 3use std::fmt; 4use std::sync::Arc; 5use std::time::Duration; 6use crate::cache::Cache; 7 8pub mod extractor; 9pub mod token; 10pub mod verify; 11 12pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header}; 13pub use token::{ 14 create_access_token, create_refresh_token, create_service_token, 15 create_access_token_with_metadata, create_refresh_token_with_metadata, 16 TokenWithMetadata, 17 TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, 18 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 19}; 20pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token}; 21 22const KEY_CACHE_TTL_SECS: u64 = 300; 23const SESSION_CACHE_TTL_SECS: u64 = 60; 24 25#[derive(Debug, Clone, Copy, PartialEq, Eq)] 26pub enum TokenValidationError { 27 AccountDeactivated, 28 AccountTakedown, 29 KeyDecryptionFailed, 30 AuthenticationFailed, 31} 32 33impl fmt::Display for TokenValidationError { 34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 35 match self { 36 Self::AccountDeactivated => write!(f, "AccountDeactivated"), 37 Self::AccountTakedown => write!(f, "AccountTakedown"), 38 Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), 39 Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), 40 } 41 } 42} 43 44pub struct AuthenticatedUser { 45 pub did: String, 46 pub key_bytes: Option<Vec<u8>>, 47 pub is_oauth: bool, 48} 49 50pub async fn validate_bearer_token( 51 db: &PgPool, 52 token: &str, 53) -> Result<AuthenticatedUser, TokenValidationError> { 54 validate_bearer_token_with_options_internal(db, None, token, false).await 55} 56 57pub async fn validate_bearer_token_allow_deactivated( 58 db: &PgPool, 59 token: &str, 60) -> Result<AuthenticatedUser, TokenValidationError> { 61 validate_bearer_token_with_options_internal(db, None, token, true).await 62} 63 64pub async fn validate_bearer_token_cached( 65 db: &PgPool, 66 cache: &Arc<dyn Cache>, 67 token: &str, 68) -> Result<AuthenticatedUser, TokenValidationError> { 69 validate_bearer_token_with_options_internal(db, Some(cache), token, false).await 70} 71 72pub async fn validate_bearer_token_cached_allow_deactivated( 73 db: &PgPool, 74 cache: &Arc<dyn Cache>, 75 token: &str, 76) -> Result<AuthenticatedUser, TokenValidationError> { 77 validate_bearer_token_with_options_internal(db, Some(cache), token, true).await 78} 79 80async fn validate_bearer_token_with_options_internal( 81 db: &PgPool, 82 cache: Option<&Arc<dyn Cache>>, 83 token: &str, 84 allow_deactivated: bool, 85) -> Result<AuthenticatedUser, TokenValidationError> { 86 let did_from_token = get_did_from_token(token).ok(); 87 88 if let Some(ref did) = did_from_token { 89 let key_cache_key = format!("auth:key:{}", did); 90 let mut cached_key: Option<Vec<u8>> = None; 91 92 if let Some(c) = cache { 93 cached_key = c.get_bytes(&key_cache_key).await; 94 if cached_key.is_some() { 95 crate::metrics::record_auth_cache_hit("key"); 96 } else { 97 crate::metrics::record_auth_cache_miss("key"); 98 } 99 } 100 101 let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key { 102 let user_status = sqlx::query!( 103 "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1", 104 did 105 ) 106 .fetch_optional(db) 107 .await 108 .ok() 109 .flatten(); 110 111 match user_status { 112 Some(status) => (Some(key), status.deactivated_at, status.takedown_ref), 113 None => (None, None, None), 114 } 115 } else { 116 if let Some(user) = sqlx::query!( 117 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref 118 FROM users u 119 JOIN user_keys k ON u.id = k.user_id 120 WHERE u.did = $1", 121 did 122 ) 123 .fetch_optional(db) 124 .await 125 .ok() 126 .flatten() 127 { 128 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 129 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 130 131 if let Some(c) = cache { 132 let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await; 133 } 134 135 (Some(key), user.deactivated_at, user.takedown_ref) 136 } else { 137 (None, None, None) 138 } 139 }; 140 141 if let Some(decrypted_key) = decrypted_key { 142 if !allow_deactivated && deactivated_at.is_some() { 143 return Err(TokenValidationError::AccountDeactivated); 144 } 145 if takedown_ref.is_some() { 146 return Err(TokenValidationError::AccountTakedown); 147 } 148 149 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 150 let jti = &token_data.claims.jti; 151 let session_cache_key = format!("auth:session:{}:{}", did, jti); 152 let mut session_valid = false; 153 154 if let Some(c) = cache { 155 if let Some(cached_value) = c.get(&session_cache_key).await { 156 session_valid = cached_value == "1"; 157 crate::metrics::record_auth_cache_hit("session"); 158 } else { 159 crate::metrics::record_auth_cache_miss("session"); 160 } 161 } 162 163 if !session_valid { 164 let session_exists = sqlx::query_scalar!( 165 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", 166 did, 167 jti 168 ) 169 .fetch_optional(db) 170 .await 171 .ok() 172 .flatten(); 173 174 session_valid = session_exists.is_some(); 175 176 if session_valid { 177 if let Some(c) = cache { 178 let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await; 179 } 180 } 181 } 182 183 if session_valid { 184 return Ok(AuthenticatedUser { 185 did: did.clone(), 186 key_bytes: Some(decrypted_key), 187 is_oauth: false, 188 }); 189 } 190 } 191 } 192 } 193 194 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) { 195 if let Some(oauth_token) = sqlx::query!( 196 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref 197 FROM oauth_token t 198 JOIN users u ON t.did = u.did 199 WHERE t.token_id = $1"#, 200 oauth_info.token_id 201 ) 202 .fetch_optional(db) 203 .await 204 .ok() 205 .flatten() 206 { 207 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 208 return Err(TokenValidationError::AccountDeactivated); 209 } 210 if oauth_token.takedown_ref.is_some() { 211 return Err(TokenValidationError::AccountTakedown); 212 } 213 214 let now = chrono::Utc::now(); 215 if oauth_token.expires_at > now { 216 return Ok(AuthenticatedUser { 217 did: oauth_token.did, 218 key_bytes: None, 219 is_oauth: true, 220 }); 221 } 222 } 223 } 224 225 Err(TokenValidationError::AuthenticationFailed) 226} 227 228pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 229 let key_cache_key = format!("auth:key:{}", did); 230 let _ = cache.delete(&key_cache_key).await; 231} 232 233#[derive(Debug, Serialize, Deserialize)] 234pub struct Claims { 235 pub iss: String, 236 pub sub: String, 237 pub aud: String, 238 pub exp: usize, 239 pub iat: usize, 240 #[serde(skip_serializing_if = "Option::is_none")] 241 pub scope: Option<String>, 242 #[serde(skip_serializing_if = "Option::is_none")] 243 pub lxm: Option<String>, 244 pub jti: String, 245} 246 247#[derive(Debug, Serialize, Deserialize)] 248pub struct Header { 249 pub alg: String, 250 pub typ: String, 251} 252 253#[derive(Debug, Serialize, Deserialize)] 254pub struct UnsafeClaims { 255 pub iss: String, 256 pub sub: Option<String>, 257} 258 259// fancy boy TokenData equivalent for compatibility/structure 260pub struct TokenData<T> { 261 pub claims: T, 262}