this repo has no description
1use crate::appview::DidResolver; 2use crate::cache::{Cache, DistributedRateLimiter, create_cache}; 3use crate::circuit_breaker::CircuitBreakers; 4use crate::config::AuthConfig; 5use crate::rate_limit::RateLimiters; 6use crate::repo::PostgresBlockStore; 7use crate::storage::{BlobStorage, S3BlobStorage}; 8use crate::sync::firehose::SequencedEvent; 9use sqlx::PgPool; 10use std::error::Error; 11use std::sync::Arc; 12use tokio::sync::broadcast; 13 14#[derive(Clone)] 15pub struct AppState { 16 pub db: PgPool, 17 pub block_store: PostgresBlockStore, 18 pub blob_store: Arc<dyn BlobStorage>, 19 pub firehose_tx: broadcast::Sender<SequencedEvent>, 20 pub rate_limiters: Arc<RateLimiters>, 21 pub circuit_breakers: Arc<CircuitBreakers>, 22 pub cache: Arc<dyn Cache>, 23 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, 24 pub did_resolver: Arc<DidResolver>, 25} 26 27pub enum RateLimitKind { 28 Login, 29 AccountCreation, 30 PasswordReset, 31 ResetPassword, 32 RefreshSession, 33 OAuthToken, 34 OAuthAuthorize, 35 OAuthPar, 36 OAuthIntrospect, 37 AppPassword, 38 EmailUpdate, 39 TotpVerify, 40} 41 42impl RateLimitKind { 43 fn key_prefix(&self) -> &'static str { 44 match self { 45 Self::Login => "login", 46 Self::AccountCreation => "account_creation", 47 Self::PasswordReset => "password_reset", 48 Self::ResetPassword => "reset_password", 49 Self::RefreshSession => "refresh_session", 50 Self::OAuthToken => "oauth_token", 51 Self::OAuthAuthorize => "oauth_authorize", 52 Self::OAuthPar => "oauth_par", 53 Self::OAuthIntrospect => "oauth_introspect", 54 Self::AppPassword => "app_password", 55 Self::EmailUpdate => "email_update", 56 Self::TotpVerify => "totp_verify", 57 } 58 } 59 60 fn limit_and_window_ms(&self) -> (u32, u64) { 61 match self { 62 Self::Login => (10, 60_000), 63 Self::AccountCreation => (10, 3_600_000), 64 Self::PasswordReset => (5, 3_600_000), 65 Self::ResetPassword => (10, 60_000), 66 Self::RefreshSession => (60, 60_000), 67 Self::OAuthToken => (30, 60_000), 68 Self::OAuthAuthorize => (10, 60_000), 69 Self::OAuthPar => (30, 60_000), 70 Self::OAuthIntrospect => (30, 60_000), 71 Self::AppPassword => (10, 60_000), 72 Self::EmailUpdate => (5, 3_600_000), 73 Self::TotpVerify => (5, 300_000), 74 } 75 } 76} 77 78impl AppState { 79 pub async fn new() -> Result<Self, Box<dyn Error>> { 80 let database_url = std::env::var("DATABASE_URL") 81 .map_err(|_| "DATABASE_URL environment variable must be set")?; 82 83 let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS") 84 .ok() 85 .and_then(|v| v.parse().ok()) 86 .unwrap_or(100); 87 88 let min_connections: u32 = std::env::var("DATABASE_MIN_CONNECTIONS") 89 .ok() 90 .and_then(|v| v.parse().ok()) 91 .unwrap_or(10); 92 93 let acquire_timeout_secs: u64 = std::env::var("DATABASE_ACQUIRE_TIMEOUT_SECS") 94 .ok() 95 .and_then(|v| v.parse().ok()) 96 .unwrap_or(10); 97 98 tracing::info!( 99 "Configuring database pool: max={}, min={}, acquire_timeout={}s", 100 max_connections, 101 min_connections, 102 acquire_timeout_secs 103 ); 104 105 let db = sqlx::postgres::PgPoolOptions::new() 106 .max_connections(max_connections) 107 .min_connections(min_connections) 108 .acquire_timeout(std::time::Duration::from_secs(acquire_timeout_secs)) 109 .idle_timeout(std::time::Duration::from_secs(300)) 110 .max_lifetime(std::time::Duration::from_secs(1800)) 111 .connect(&database_url) 112 .await 113 .map_err(|e| format!("Failed to connect to Postgres: {}", e))?; 114 115 sqlx::migrate!("./migrations") 116 .run(&db) 117 .await 118 .map_err(|e| format!("Failed to run migrations: {}", e))?; 119 120 Ok(Self::from_db(db).await) 121 } 122 123 pub async fn from_db(db: PgPool) -> Self { 124 AuthConfig::init(); 125 126 let block_store = PostgresBlockStore::new(db.clone()); 127 let blob_store = S3BlobStorage::new().await; 128 129 let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE") 130 .ok() 131 .and_then(|v| v.parse().ok()) 132 .unwrap_or(10000); 133 134 let (firehose_tx, _) = broadcast::channel(firehose_buffer_size); 135 let rate_limiters = Arc::new(RateLimiters::new()); 136 let circuit_breakers = Arc::new(CircuitBreakers::new()); 137 let (cache, distributed_rate_limiter) = create_cache().await; 138 let did_resolver = Arc::new(DidResolver::new()); 139 140 Self { 141 db, 142 block_store, 143 blob_store: Arc::new(blob_store), 144 firehose_tx, 145 rate_limiters, 146 circuit_breakers, 147 cache, 148 distributed_rate_limiter, 149 did_resolver, 150 } 151 } 152 153 pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self { 154 self.rate_limiters = Arc::new(rate_limiters); 155 self 156 } 157 158 pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self { 159 self.circuit_breakers = Arc::new(circuit_breakers); 160 self 161 } 162 163 pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool { 164 if std::env::var("DISABLE_RATE_LIMITING").is_ok() { 165 return true; 166 } 167 168 let key = format!("{}:{}", kind.key_prefix(), client_ip); 169 let limiter_name = kind.key_prefix(); 170 let (limit, window_ms) = kind.limit_and_window_ms(); 171 172 if !self 173 .distributed_rate_limiter 174 .check_rate_limit(&key, limit, window_ms) 175 .await 176 { 177 crate::metrics::record_rate_limit_rejection(limiter_name); 178 return false; 179 } 180 181 let limiter = match kind { 182 RateLimitKind::Login => &self.rate_limiters.login, 183 RateLimitKind::AccountCreation => &self.rate_limiters.account_creation, 184 RateLimitKind::PasswordReset => &self.rate_limiters.password_reset, 185 RateLimitKind::ResetPassword => &self.rate_limiters.reset_password, 186 RateLimitKind::RefreshSession => &self.rate_limiters.refresh_session, 187 RateLimitKind::OAuthToken => &self.rate_limiters.oauth_token, 188 RateLimitKind::OAuthAuthorize => &self.rate_limiters.oauth_authorize, 189 RateLimitKind::OAuthPar => &self.rate_limiters.oauth_par, 190 RateLimitKind::OAuthIntrospect => &self.rate_limiters.oauth_introspect, 191 RateLimitKind::AppPassword => &self.rate_limiters.app_password, 192 RateLimitKind::EmailUpdate => &self.rate_limiters.email_update, 193 RateLimitKind::TotpVerify => &self.rate_limiters.totp_verify, 194 }; 195 196 let ok = limiter.check_key(&client_ip.to_string()).is_ok(); 197 if !ok { 198 crate::metrics::record_rate_limit_rejection(limiter_name); 199 } 200 ok 201 } 202}