// This repo has no description.
1use axum::{ 2 body::Body, 3 extract::ConnectInfo, 4 http::{HeaderMap, Request, StatusCode}, 5 middleware::Next, 6 response::{IntoResponse, Response}, 7 Json, 8}; 9use governor::{ 10 Quota, RateLimiter, 11 clock::DefaultClock, 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13}; 14use std::{ 15 net::SocketAddr, 16 num::NonZeroU32, 17 sync::Arc, 18}; 19pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 20pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 21// NOTE: For production deployments with high traffic, prefer using the distributed rate 22// limiter (Redis/Valkey-based) via AppState::distributed_rate_limiter. The in-memory 23// rate limiters here don't automatically clean up expired entries, which can cause 24// memory growth over time with many unique client IPs. The distributed rate limiter 25// uses Redis TTL for automatic cleanup and works correctly across multiple instances. 26#[derive(Clone)] 27pub struct RateLimiters { 28 pub login: Arc<KeyedRateLimiter>, 29 pub oauth_token: Arc<KeyedRateLimiter>, 30 pub oauth_authorize: Arc<KeyedRateLimiter>, 31 pub password_reset: Arc<KeyedRateLimiter>, 32 pub account_creation: Arc<KeyedRateLimiter>, 33 pub refresh_session: Arc<KeyedRateLimiter>, 34 pub reset_password: Arc<KeyedRateLimiter>, 35 pub oauth_par: Arc<KeyedRateLimiter>, 36 pub oauth_introspect: Arc<KeyedRateLimiter>, 37 pub app_password: Arc<KeyedRateLimiter>, 38 pub email_update: Arc<KeyedRateLimiter>, 39} 40impl Default for RateLimiters { 41 fn default() -> Self { 42 Self::new() 43 } 44} 45impl RateLimiters { 46 pub fn new() -> Self { 47 Self { 48 login: Arc::new(RateLimiter::keyed( 49 Quota::per_minute(NonZeroU32::new(10).unwrap()) 50 )), 51 oauth_token: Arc::new(RateLimiter::keyed( 52 Quota::per_minute(NonZeroU32::new(30).unwrap()) 53 )), 54 oauth_authorize: Arc::new(RateLimiter::keyed( 55 Quota::per_minute(NonZeroU32::new(10).unwrap()) 56 )), 57 
password_reset: Arc::new(RateLimiter::keyed( 58 Quota::per_hour(NonZeroU32::new(5).unwrap()) 59 )), 60 account_creation: Arc::new(RateLimiter::keyed( 61 Quota::per_hour(NonZeroU32::new(10).unwrap()) 62 )), 63 refresh_session: Arc::new(RateLimiter::keyed( 64 Quota::per_minute(NonZeroU32::new(60).unwrap()) 65 )), 66 reset_password: Arc::new(RateLimiter::keyed( 67 Quota::per_minute(NonZeroU32::new(10).unwrap()) 68 )), 69 oauth_par: Arc::new(RateLimiter::keyed( 70 Quota::per_minute(NonZeroU32::new(30).unwrap()) 71 )), 72 oauth_introspect: Arc::new(RateLimiter::keyed( 73 Quota::per_minute(NonZeroU32::new(30).unwrap()) 74 )), 75 app_password: Arc::new(RateLimiter::keyed( 76 Quota::per_minute(NonZeroU32::new(10).unwrap()) 77 )), 78 email_update: Arc::new(RateLimiter::keyed( 79 Quota::per_hour(NonZeroU32::new(5).unwrap()) 80 )), 81 } 82 } 83 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 84 self.login = Arc::new(RateLimiter::keyed( 85 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 86 )); 87 self 88 } 89 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 90 self.oauth_token = Arc::new(RateLimiter::keyed( 91 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 92 )); 93 self 94 } 95 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 96 self.oauth_authorize = Arc::new(RateLimiter::keyed( 97 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 98 )); 99 self 100 } 101 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 102 self.password_reset = Arc::new(RateLimiter::keyed( 103 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 104 )); 105 self 106 } 107 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 108 self.account_creation = Arc::new(RateLimiter::keyed( 109 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 110 )); 111 
self 112 } 113 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 114 self.email_update = Arc::new(RateLimiter::keyed( 115 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 116 )); 117 self 118 } 119} 120pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 121 if let Some(forwarded) = headers.get("x-forwarded-for") { 122 if let Ok(value) = forwarded.to_str() { 123 if let Some(first_ip) = value.split(',').next() { 124 return first_ip.trim().to_string(); 125 } 126 } 127 } 128 if let Some(real_ip) = headers.get("x-real-ip") { 129 if let Ok(value) = real_ip.to_str() { 130 return value.trim().to_string(); 131 } 132 } 133 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 134} 135fn rate_limit_response() -> Response { 136 ( 137 StatusCode::TOO_MANY_REQUESTS, 138 Json(serde_json::json!({ 139 "error": "RateLimitExceeded", 140 "message": "Too many requests. Please try again later." 141 })), 142 ) 143 .into_response() 144} 145pub async fn login_rate_limit( 146 ConnectInfo(addr): ConnectInfo<SocketAddr>, 147 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 148 request: Request<Body>, 149 next: Next, 150) -> Response { 151 let client_ip = extract_client_ip(request.headers(), Some(addr)); 152 if limiters.login.check_key(&client_ip).is_err() { 153 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 154 return rate_limit_response(); 155 } 156 next.run(request).await 157} 158pub async fn oauth_token_rate_limit( 159 ConnectInfo(addr): ConnectInfo<SocketAddr>, 160 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 161 request: Request<Body>, 162 next: Next, 163) -> Response { 164 let client_ip = extract_client_ip(request.headers(), Some(addr)); 165 if limiters.oauth_token.check_key(&client_ip).is_err() { 166 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 167 return rate_limit_response(); 168 } 169 
next.run(request).await 170} 171pub async fn password_reset_rate_limit( 172 ConnectInfo(addr): ConnectInfo<SocketAddr>, 173 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 174 request: Request<Body>, 175 next: Next, 176) -> Response { 177 let client_ip = extract_client_ip(request.headers(), Some(addr)); 178 if limiters.password_reset.check_key(&client_ip).is_err() { 179 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 180 return rate_limit_response(); 181 } 182 next.run(request).await 183} 184pub async fn account_creation_rate_limit( 185 ConnectInfo(addr): ConnectInfo<SocketAddr>, 186 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 187 request: Request<Body>, 188 next: Next, 189) -> Response { 190 let client_ip = extract_client_ip(request.headers(), Some(addr)); 191 if limiters.account_creation.check_key(&client_ip).is_err() { 192 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 193 return rate_limit_response(); 194 } 195 next.run(request).await 196} 197#[cfg(test)] 198mod tests { 199 use super::*; 200 #[test] 201 fn test_rate_limiters_creation() { 202 let limiters = RateLimiters::new(); 203 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 204 } 205 #[test] 206 fn test_rate_limiter_exhaustion() { 207 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 208 let key = "test_ip".to_string(); 209 assert!(limiter.check_key(&key).is_ok()); 210 assert!(limiter.check_key(&key).is_ok()); 211 assert!(limiter.check_key(&key).is_err()); 212 } 213 #[test] 214 fn test_different_keys_have_separate_limits() { 215 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 216 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 217 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 218 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 219 } 220 #[test] 221 fn test_builder_pattern() { 222 let limiters = 
RateLimiters::new() 223 .with_login_limit(20) 224 .with_oauth_token_limit(60) 225 .with_password_reset_limit(3) 226 .with_account_creation_limit(5); 227 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 228 } 229}