this repo has no description
1use axum::{ 2 Json, 3 body::Body, 4 extract::ConnectInfo, 5 http::{HeaderMap, Request, StatusCode}, 6 middleware::Next, 7 response::{IntoResponse, Response}, 8}; 9use governor::{ 10 Quota, RateLimiter, 11 clock::DefaultClock, 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13}; 14use std::{net::SocketAddr, num::NonZeroU32, sync::Arc}; 15 16pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 17pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 18 19#[derive(Clone)] 20pub struct RateLimiters { 21 pub login: Arc<KeyedRateLimiter>, 22 pub oauth_token: Arc<KeyedRateLimiter>, 23 pub oauth_authorize: Arc<KeyedRateLimiter>, 24 pub password_reset: Arc<KeyedRateLimiter>, 25 pub account_creation: Arc<KeyedRateLimiter>, 26 pub refresh_session: Arc<KeyedRateLimiter>, 27 pub reset_password: Arc<KeyedRateLimiter>, 28 pub oauth_par: Arc<KeyedRateLimiter>, 29 pub oauth_introspect: Arc<KeyedRateLimiter>, 30 pub app_password: Arc<KeyedRateLimiter>, 31 pub email_update: Arc<KeyedRateLimiter>, 32 pub totp_verify: Arc<KeyedRateLimiter>, 33} 34 35impl Default for RateLimiters { 36 fn default() -> Self { 37 Self::new() 38 } 39} 40 41impl RateLimiters { 42 pub fn new() -> Self { 43 Self { 44 login: Arc::new(RateLimiter::keyed(Quota::per_minute( 45 NonZeroU32::new(10).unwrap(), 46 ))), 47 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute( 48 NonZeroU32::new(30).unwrap(), 49 ))), 50 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute( 51 NonZeroU32::new(10).unwrap(), 52 ))), 53 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour( 54 NonZeroU32::new(5).unwrap(), 55 ))), 56 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour( 57 NonZeroU32::new(10).unwrap(), 58 ))), 59 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute( 60 NonZeroU32::new(60).unwrap(), 61 ))), 62 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 63 NonZeroU32::new(10).unwrap(), 64 ))), 65 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute( 66 NonZeroU32::new(30).unwrap(), 67 ))), 68 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute( 69 NonZeroU32::new(30).unwrap(), 70 ))), 71 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 72 NonZeroU32::new(10).unwrap(), 73 ))), 74 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour( 75 NonZeroU32::new(5).unwrap(), 76 ))), 77 totp_verify: Arc::new(RateLimiter::keyed( 78 Quota::with_period(std::time::Duration::from_secs(60)) 79 .unwrap() 80 .allow_burst(NonZeroU32::new(5).unwrap()), 81 )), 82 } 83 } 84 85 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 86 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute( 87 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 88 ))); 89 self 90 } 91 92 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 93 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute( 94 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()), 95 ))); 96 self 97 } 98 99 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 100 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute( 101 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 102 ))); 103 self 104 } 105 106 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 107 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour( 108 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 109 ))); 110 self 111 } 112 113 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 114 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour( 115 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()), 116 ))); 117 self 118 } 119 120 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 121 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour( 122 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 123 ))); 124 self 125 } 126} 127 128pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 129 if let Some(forwarded) = headers.get("x-forwarded-for") 130 && let Ok(value) = forwarded.to_str() 131 && let Some(first_ip) = value.split(',').next() 132 { 133 return first_ip.trim().to_string(); 134 } 135 136 if let Some(real_ip) = headers.get("x-real-ip") 137 && let Ok(value) = real_ip.to_str() 138 { 139 return value.trim().to_string(); 140 } 141 142 addr.map(|a| a.ip().to_string()) 143 .unwrap_or_else(|| "unknown".to_string()) 144} 145 146fn rate_limit_response() -> Response { 147 ( 148 StatusCode::TOO_MANY_REQUESTS, 149 Json(serde_json::json!({ 150 "error": "RateLimitExceeded", 151 "message": "Too many requests. Please try again later." 152 })), 153 ) 154 .into_response() 155} 156 157pub async fn login_rate_limit( 158 ConnectInfo(addr): ConnectInfo<SocketAddr>, 159 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 160 request: Request<Body>, 161 next: Next, 162) -> Response { 163 let client_ip = extract_client_ip(request.headers(), Some(addr)); 164 165 if limiters.login.check_key(&client_ip).is_err() { 166 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 167 return rate_limit_response(); 168 } 169 170 next.run(request).await 171} 172 173pub async fn oauth_token_rate_limit( 174 ConnectInfo(addr): ConnectInfo<SocketAddr>, 175 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 176 request: Request<Body>, 177 next: Next, 178) -> Response { 179 let client_ip = extract_client_ip(request.headers(), Some(addr)); 180 181 if limiters.oauth_token.check_key(&client_ip).is_err() { 182 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 183 return rate_limit_response(); 184 } 185 186 next.run(request).await 187} 188 189pub async fn password_reset_rate_limit( 190 ConnectInfo(addr): ConnectInfo<SocketAddr>, 191 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 192 request: Request<Body>, 193 next: Next, 194) -> Response { 195 let client_ip = extract_client_ip(request.headers(), Some(addr)); 196 197 if limiters.password_reset.check_key(&client_ip).is_err() { 198 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 199 return rate_limit_response(); 200 } 201 202 next.run(request).await 203} 204 205pub async fn account_creation_rate_limit( 206 ConnectInfo(addr): ConnectInfo<SocketAddr>, 207 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 208 request: Request<Body>, 209 next: Next, 210) -> Response { 211 let client_ip = extract_client_ip(request.headers(), Some(addr)); 212 213 if limiters.account_creation.check_key(&client_ip).is_err() { 214 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 215 return rate_limit_response(); 216 } 217 218 next.run(request).await 219} 220 221#[cfg(test)] 222mod tests { 223 use super::*; 224 225 #[test] 226 fn test_rate_limiters_creation() { 227 let limiters = RateLimiters::new(); 228 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 229 } 230 231 #[test] 232 fn test_rate_limiter_exhaustion() { 233 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 234 let key = "test_ip".to_string(); 235 236 assert!(limiter.check_key(&key).is_ok()); 237 assert!(limiter.check_key(&key).is_ok()); 238 assert!(limiter.check_key(&key).is_err()); 239 } 240 241 #[test] 242 fn test_different_keys_have_separate_limits() { 243 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 244 245 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 246 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 247 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 248 } 249 250 #[test] 251 fn test_builder_pattern() { 252 let limiters = RateLimiters::new() 253 .with_login_limit(20) 254 .with_oauth_token_limit(60) 255 .with_password_reset_limit(3) 256 .with_account_creation_limit(5); 257 258 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 259 } 260}