This repository has no description.
1use crate::appview::DidResolver;
2use crate::cache::{Cache, DistributedRateLimiter, create_cache};
3use crate::circuit_breaker::CircuitBreakers;
4use crate::config::AuthConfig;
5use crate::rate_limit::RateLimiters;
6use crate::repo::PostgresBlockStore;
7use crate::storage::{BlobStorage, S3BlobStorage};
8use crate::sync::firehose::SequencedEvent;
9use sqlx::PgPool;
10use std::error::Error;
11use std::sync::Arc;
12use tokio::sync::broadcast;
13
14#[derive(Clone)]
15pub struct AppState {
16 pub db: PgPool,
17 pub block_store: PostgresBlockStore,
18 pub blob_store: Arc<dyn BlobStorage>,
19 pub firehose_tx: broadcast::Sender<SequencedEvent>,
20 pub rate_limiters: Arc<RateLimiters>,
21 pub circuit_breakers: Arc<CircuitBreakers>,
22 pub cache: Arc<dyn Cache>,
23 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
24 pub did_resolver: Arc<DidResolver>,
25}
26
/// Category of rate-limited operation.
///
/// Each variant maps to a key prefix and a `(limit, window_ms)` budget. The
/// prefix is also used by `AppState::check_rate_limit` as the metrics label
/// when a request is rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateLimitKind {
    Login,
    AccountCreation,
    PasswordReset,
    ResetPassword,
    RefreshSession,
    OAuthToken,
    OAuthAuthorize,
    OAuthPar,
    OAuthIntrospect,
    AppPassword,
    EmailUpdate,
    TotpVerify,
}

impl RateLimitKind {
    /// Single source of truth for each kind: `(key prefix, max requests, window in ms)`.
    ///
    /// Keeping prefix and budget in one table means adding a variant cannot
    /// update one of them and forget the other.
    ///
    /// NOTE(review): these budgets only feed the distributed limiter; the
    /// in-process limiters in `RateLimiters` are configured separately —
    /// confirm the two agree.
    fn spec(&self) -> (&'static str, u32, u64) {
        match self {
            Self::Login => ("login", 10, 60_000),
            Self::AccountCreation => ("account_creation", 10, 3_600_000),
            Self::PasswordReset => ("password_reset", 5, 3_600_000),
            Self::ResetPassword => ("reset_password", 10, 60_000),
            Self::RefreshSession => ("refresh_session", 60, 60_000),
            Self::OAuthToken => ("oauth_token", 30, 60_000),
            Self::OAuthAuthorize => ("oauth_authorize", 10, 60_000),
            Self::OAuthPar => ("oauth_par", 30, 60_000),
            Self::OAuthIntrospect => ("oauth_introspect", 30, 60_000),
            Self::AppPassword => ("app_password", 10, 60_000),
            Self::EmailUpdate => ("email_update", 5, 3_600_000),
            Self::TotpVerify => ("totp_verify", 5, 300_000),
        }
    }

    /// Key prefix for this kind (e.g. `"login"`), used in limiter keys and metric labels.
    fn key_prefix(&self) -> &'static str {
        self.spec().0
    }

    /// Maximum number of requests allowed per window, and the window length in milliseconds.
    fn limit_and_window_ms(&self) -> (u32, u64) {
        let (_, limit, window_ms) = self.spec();
        (limit, window_ms)
    }
}
77
78impl AppState {
79 pub async fn new() -> Result<Self, Box<dyn Error>> {
80 let database_url = std::env::var("DATABASE_URL")
81 .map_err(|_| "DATABASE_URL environment variable must be set")?;
82
83 let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS")
84 .ok()
85 .and_then(|v| v.parse().ok())
86 .unwrap_or(100);
87
88 let min_connections: u32 = std::env::var("DATABASE_MIN_CONNECTIONS")
89 .ok()
90 .and_then(|v| v.parse().ok())
91 .unwrap_or(10);
92
93 let acquire_timeout_secs: u64 = std::env::var("DATABASE_ACQUIRE_TIMEOUT_SECS")
94 .ok()
95 .and_then(|v| v.parse().ok())
96 .unwrap_or(10);
97
98 tracing::info!(
99 "Configuring database pool: max={}, min={}, acquire_timeout={}s",
100 max_connections,
101 min_connections,
102 acquire_timeout_secs
103 );
104
105 let db = sqlx::postgres::PgPoolOptions::new()
106 .max_connections(max_connections)
107 .min_connections(min_connections)
108 .acquire_timeout(std::time::Duration::from_secs(acquire_timeout_secs))
109 .idle_timeout(std::time::Duration::from_secs(300))
110 .max_lifetime(std::time::Duration::from_secs(1800))
111 .connect(&database_url)
112 .await
113 .map_err(|e| format!("Failed to connect to Postgres: {}", e))?;
114
115 sqlx::migrate!("./migrations")
116 .run(&db)
117 .await
118 .map_err(|e| format!("Failed to run migrations: {}", e))?;
119
120 Ok(Self::from_db(db).await)
121 }
122
123 pub async fn from_db(db: PgPool) -> Self {
124 AuthConfig::init();
125
126 let block_store = PostgresBlockStore::new(db.clone());
127 let blob_store = S3BlobStorage::new().await;
128
129 let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE")
130 .ok()
131 .and_then(|v| v.parse().ok())
132 .unwrap_or(10000);
133
134 let (firehose_tx, _) = broadcast::channel(firehose_buffer_size);
135 let rate_limiters = Arc::new(RateLimiters::new());
136 let circuit_breakers = Arc::new(CircuitBreakers::new());
137 let (cache, distributed_rate_limiter) = create_cache().await;
138 let did_resolver = Arc::new(DidResolver::new());
139
140 Self {
141 db,
142 block_store,
143 blob_store: Arc::new(blob_store),
144 firehose_tx,
145 rate_limiters,
146 circuit_breakers,
147 cache,
148 distributed_rate_limiter,
149 did_resolver,
150 }
151 }
152
153 pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self {
154 self.rate_limiters = Arc::new(rate_limiters);
155 self
156 }
157
158 pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self {
159 self.circuit_breakers = Arc::new(circuit_breakers);
160 self
161 }
162
163 pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool {
164 if std::env::var("DISABLE_RATE_LIMITING").is_ok() {
165 return true;
166 }
167
168 let key = format!("{}:{}", kind.key_prefix(), client_ip);
169 let limiter_name = kind.key_prefix();
170 let (limit, window_ms) = kind.limit_and_window_ms();
171
172 if !self
173 .distributed_rate_limiter
174 .check_rate_limit(&key, limit, window_ms)
175 .await
176 {
177 crate::metrics::record_rate_limit_rejection(limiter_name);
178 return false;
179 }
180
181 let limiter = match kind {
182 RateLimitKind::Login => &self.rate_limiters.login,
183 RateLimitKind::AccountCreation => &self.rate_limiters.account_creation,
184 RateLimitKind::PasswordReset => &self.rate_limiters.password_reset,
185 RateLimitKind::ResetPassword => &self.rate_limiters.reset_password,
186 RateLimitKind::RefreshSession => &self.rate_limiters.refresh_session,
187 RateLimitKind::OAuthToken => &self.rate_limiters.oauth_token,
188 RateLimitKind::OAuthAuthorize => &self.rate_limiters.oauth_authorize,
189 RateLimitKind::OAuthPar => &self.rate_limiters.oauth_par,
190 RateLimitKind::OAuthIntrospect => &self.rate_limiters.oauth_introspect,
191 RateLimitKind::AppPassword => &self.rate_limiters.app_password,
192 RateLimitKind::EmailUpdate => &self.rate_limiters.email_update,
193 RateLimitKind::TotpVerify => &self.rate_limiters.totp_verify,
194 };
195
196 let ok = limiter.check_key(&client_ip.to_string()).is_ok();
197 if !ok {
198 crate::metrics::record_rate_limit_rejection(limiter_name);
199 }
200 ok
201 }
202}