// This repository has no description.
1use axum::{ 2 body::Body, 3 extract::ConnectInfo, 4 http::{HeaderMap, Request, StatusCode}, 5 middleware::Next, 6 response::{IntoResponse, Response}, 7 Json, 8}; 9use governor::{ 10 Quota, RateLimiter, 11 clock::DefaultClock, 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13}; 14use std::{ 15 net::SocketAddr, 16 num::NonZeroU32, 17 sync::Arc, 18}; 19 20pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 21pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 22 23// NOTE: For production deployments with high traffic, prefer using the distributed rate 24// limiter (Redis/Valkey-based) via AppState::distributed_rate_limiter. The in-memory 25// rate limiters here don't automatically clean up expired entries, which can cause 26// memory growth over time with many unique client IPs. The distributed rate limiter 27// uses Redis TTL for automatic cleanup and works correctly across multiple instances. 28 29#[derive(Clone)] 30pub struct RateLimiters { 31 pub login: Arc<KeyedRateLimiter>, 32 pub oauth_token: Arc<KeyedRateLimiter>, 33 pub oauth_authorize: Arc<KeyedRateLimiter>, 34 pub password_reset: Arc<KeyedRateLimiter>, 35 pub account_creation: Arc<KeyedRateLimiter>, 36 pub refresh_session: Arc<KeyedRateLimiter>, 37 pub reset_password: Arc<KeyedRateLimiter>, 38 pub oauth_par: Arc<KeyedRateLimiter>, 39 pub oauth_introspect: Arc<KeyedRateLimiter>, 40 pub app_password: Arc<KeyedRateLimiter>, 41 pub email_update: Arc<KeyedRateLimiter>, 42} 43 44impl Default for RateLimiters { 45 fn default() -> Self { 46 Self::new() 47 } 48} 49 50impl RateLimiters { 51 pub fn new() -> Self { 52 Self { 53 login: Arc::new(RateLimiter::keyed( 54 Quota::per_minute(NonZeroU32::new(10).unwrap()) 55 )), 56 oauth_token: Arc::new(RateLimiter::keyed( 57 Quota::per_minute(NonZeroU32::new(30).unwrap()) 58 )), 59 oauth_authorize: Arc::new(RateLimiter::keyed( 60 Quota::per_minute(NonZeroU32::new(10).unwrap()) 61 )), 
62 password_reset: Arc::new(RateLimiter::keyed( 63 Quota::per_hour(NonZeroU32::new(5).unwrap()) 64 )), 65 account_creation: Arc::new(RateLimiter::keyed( 66 Quota::per_hour(NonZeroU32::new(10).unwrap()) 67 )), 68 refresh_session: Arc::new(RateLimiter::keyed( 69 Quota::per_minute(NonZeroU32::new(60).unwrap()) 70 )), 71 reset_password: Arc::new(RateLimiter::keyed( 72 Quota::per_minute(NonZeroU32::new(10).unwrap()) 73 )), 74 oauth_par: Arc::new(RateLimiter::keyed( 75 Quota::per_minute(NonZeroU32::new(30).unwrap()) 76 )), 77 oauth_introspect: Arc::new(RateLimiter::keyed( 78 Quota::per_minute(NonZeroU32::new(30).unwrap()) 79 )), 80 app_password: Arc::new(RateLimiter::keyed( 81 Quota::per_minute(NonZeroU32::new(10).unwrap()) 82 )), 83 email_update: Arc::new(RateLimiter::keyed( 84 Quota::per_hour(NonZeroU32::new(5).unwrap()) 85 )), 86 } 87 } 88 89 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 90 self.login = Arc::new(RateLimiter::keyed( 91 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 92 )); 93 self 94 } 95 96 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 97 self.oauth_token = Arc::new(RateLimiter::keyed( 98 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 99 )); 100 self 101 } 102 103 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 104 self.oauth_authorize = Arc::new(RateLimiter::keyed( 105 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 106 )); 107 self 108 } 109 110 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 111 self.password_reset = Arc::new(RateLimiter::keyed( 112 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 113 )); 114 self 115 } 116 117 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 118 self.account_creation = Arc::new(RateLimiter::keyed( 119 
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 120 )); 121 self 122 } 123 124 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 125 self.email_update = Arc::new(RateLimiter::keyed( 126 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 127 )); 128 self 129 } 130} 131 132pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 133 if let Some(forwarded) = headers.get("x-forwarded-for") { 134 if let Ok(value) = forwarded.to_str() { 135 if let Some(first_ip) = value.split(',').next() { 136 return first_ip.trim().to_string(); 137 } 138 } 139 } 140 141 if let Some(real_ip) = headers.get("x-real-ip") { 142 if let Ok(value) = real_ip.to_str() { 143 return value.trim().to_string(); 144 } 145 } 146 147 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 148} 149 150fn rate_limit_response() -> Response { 151 ( 152 StatusCode::TOO_MANY_REQUESTS, 153 Json(serde_json::json!({ 154 "error": "RateLimitExceeded", 155 "message": "Too many requests. Please try again later." 
156 })), 157 ) 158 .into_response() 159} 160 161pub async fn login_rate_limit( 162 ConnectInfo(addr): ConnectInfo<SocketAddr>, 163 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 164 request: Request<Body>, 165 next: Next, 166) -> Response { 167 let client_ip = extract_client_ip(request.headers(), Some(addr)); 168 169 if limiters.login.check_key(&client_ip).is_err() { 170 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 171 return rate_limit_response(); 172 } 173 174 next.run(request).await 175} 176 177pub async fn oauth_token_rate_limit( 178 ConnectInfo(addr): ConnectInfo<SocketAddr>, 179 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 180 request: Request<Body>, 181 next: Next, 182) -> Response { 183 let client_ip = extract_client_ip(request.headers(), Some(addr)); 184 185 if limiters.oauth_token.check_key(&client_ip).is_err() { 186 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 187 return rate_limit_response(); 188 } 189 190 next.run(request).await 191} 192 193pub async fn password_reset_rate_limit( 194 ConnectInfo(addr): ConnectInfo<SocketAddr>, 195 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 196 request: Request<Body>, 197 next: Next, 198) -> Response { 199 let client_ip = extract_client_ip(request.headers(), Some(addr)); 200 201 if limiters.password_reset.check_key(&client_ip).is_err() { 202 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 203 return rate_limit_response(); 204 } 205 206 next.run(request).await 207} 208 209pub async fn account_creation_rate_limit( 210 ConnectInfo(addr): ConnectInfo<SocketAddr>, 211 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 212 request: Request<Body>, 213 next: Next, 214) -> Response { 215 let client_ip = extract_client_ip(request.headers(), Some(addr)); 216 217 if limiters.account_creation.check_key(&client_ip).is_err() { 218 tracing::warn!(ip = %client_ip, 
"Account creation rate limit exceeded"); 219 return rate_limit_response(); 220 } 221 222 next.run(request).await 223} 224 225#[cfg(test)] 226mod tests { 227 use super::*; 228 229 #[test] 230 fn test_rate_limiters_creation() { 231 let limiters = RateLimiters::new(); 232 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 233 } 234 235 #[test] 236 fn test_rate_limiter_exhaustion() { 237 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 238 let key = "test_ip".to_string(); 239 240 assert!(limiter.check_key(&key).is_ok()); 241 assert!(limiter.check_key(&key).is_ok()); 242 assert!(limiter.check_key(&key).is_err()); 243 } 244 245 #[test] 246 fn test_different_keys_have_separate_limits() { 247 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 248 249 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 250 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 251 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 252 } 253 254 #[test] 255 fn test_builder_pattern() { 256 let limiters = RateLimiters::new() 257 .with_login_limit(20) 258 .with_oauth_token_limit(60) 259 .with_password_reset_limit(3) 260 .with_account_creation_limit(5); 261 262 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 263 } 264}