This repository has no description.
1use axum::{ 2 Json, 3 body::Body, 4 extract::ConnectInfo, 5 http::{HeaderMap, Request, StatusCode}, 6 middleware::Next, 7 response::{IntoResponse, Response}, 8}; 9use governor::{ 10 Quota, RateLimiter, 11 clock::DefaultClock, 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13}; 14use std::{net::SocketAddr, num::NonZeroU32, sync::Arc}; 15 16pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 17pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 18 19#[derive(Clone)] 20pub struct RateLimiters { 21 pub login: Arc<KeyedRateLimiter>, 22 pub oauth_token: Arc<KeyedRateLimiter>, 23 pub oauth_authorize: Arc<KeyedRateLimiter>, 24 pub password_reset: Arc<KeyedRateLimiter>, 25 pub account_creation: Arc<KeyedRateLimiter>, 26 pub refresh_session: Arc<KeyedRateLimiter>, 27 pub reset_password: Arc<KeyedRateLimiter>, 28 pub oauth_par: Arc<KeyedRateLimiter>, 29 pub oauth_introspect: Arc<KeyedRateLimiter>, 30 pub app_password: Arc<KeyedRateLimiter>, 31 pub email_update: Arc<KeyedRateLimiter>, 32 pub totp_verify: Arc<KeyedRateLimiter>, 33 pub handle_update: Arc<KeyedRateLimiter>, 34 pub handle_update_daily: Arc<KeyedRateLimiter>, 35} 36 37impl Default for RateLimiters { 38 fn default() -> Self { 39 Self::new() 40 } 41} 42 43impl RateLimiters { 44 pub fn new() -> Self { 45 Self { 46 login: Arc::new(RateLimiter::keyed(Quota::per_minute( 47 NonZeroU32::new(10).unwrap(), 48 ))), 49 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute( 50 NonZeroU32::new(30).unwrap(), 51 ))), 52 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute( 53 NonZeroU32::new(10).unwrap(), 54 ))), 55 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour( 56 NonZeroU32::new(5).unwrap(), 57 ))), 58 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour( 59 NonZeroU32::new(10).unwrap(), 60 ))), 61 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute( 62 NonZeroU32::new(60).unwrap(), 63 
))), 64 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 65 NonZeroU32::new(10).unwrap(), 66 ))), 67 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute( 68 NonZeroU32::new(30).unwrap(), 69 ))), 70 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute( 71 NonZeroU32::new(30).unwrap(), 72 ))), 73 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 74 NonZeroU32::new(10).unwrap(), 75 ))), 76 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour( 77 NonZeroU32::new(5).unwrap(), 78 ))), 79 totp_verify: Arc::new(RateLimiter::keyed( 80 Quota::with_period(std::time::Duration::from_secs(60)) 81 .unwrap() 82 .allow_burst(NonZeroU32::new(5).unwrap()), 83 )), 84 handle_update: Arc::new(RateLimiter::keyed( 85 Quota::with_period(std::time::Duration::from_secs(30)) 86 .unwrap() 87 .allow_burst(NonZeroU32::new(10).unwrap()), 88 )), 89 handle_update_daily: Arc::new(RateLimiter::keyed( 90 Quota::with_period(std::time::Duration::from_secs(1728)) 91 .unwrap() 92 .allow_burst(NonZeroU32::new(50).unwrap()), 93 )), 94 } 95 } 96 97 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 98 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute( 99 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 100 ))); 101 self 102 } 103 104 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 105 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute( 106 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()), 107 ))); 108 self 109 } 110 111 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 112 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute( 113 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 114 ))); 115 self 116 } 117 118 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 119 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour( 120 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 121 
))); 122 self 123 } 124 125 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 126 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour( 127 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()), 128 ))); 129 self 130 } 131 132 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 133 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour( 134 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 135 ))); 136 self 137 } 138} 139 140pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 141 if let Some(forwarded) = headers.get("x-forwarded-for") 142 && let Ok(value) = forwarded.to_str() 143 && let Some(first_ip) = value.split(',').next() 144 { 145 return first_ip.trim().to_string(); 146 } 147 148 if let Some(real_ip) = headers.get("x-real-ip") 149 && let Ok(value) = real_ip.to_str() 150 { 151 return value.trim().to_string(); 152 } 153 154 addr.map(|a| a.ip().to_string()) 155 .unwrap_or_else(|| "unknown".to_string()) 156} 157 158fn rate_limit_response() -> Response { 159 ( 160 StatusCode::TOO_MANY_REQUESTS, 161 Json(serde_json::json!({ 162 "error": "RateLimitExceeded", 163 "message": "Too many requests. Please try again later." 
164 })), 165 ) 166 .into_response() 167} 168 169pub async fn login_rate_limit( 170 ConnectInfo(addr): ConnectInfo<SocketAddr>, 171 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 172 request: Request<Body>, 173 next: Next, 174) -> Response { 175 let client_ip = extract_client_ip(request.headers(), Some(addr)); 176 177 if limiters.login.check_key(&client_ip).is_err() { 178 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 179 return rate_limit_response(); 180 } 181 182 next.run(request).await 183} 184 185pub async fn oauth_token_rate_limit( 186 ConnectInfo(addr): ConnectInfo<SocketAddr>, 187 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 188 request: Request<Body>, 189 next: Next, 190) -> Response { 191 let client_ip = extract_client_ip(request.headers(), Some(addr)); 192 193 if limiters.oauth_token.check_key(&client_ip).is_err() { 194 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 195 return rate_limit_response(); 196 } 197 198 next.run(request).await 199} 200 201pub async fn password_reset_rate_limit( 202 ConnectInfo(addr): ConnectInfo<SocketAddr>, 203 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 204 request: Request<Body>, 205 next: Next, 206) -> Response { 207 let client_ip = extract_client_ip(request.headers(), Some(addr)); 208 209 if limiters.password_reset.check_key(&client_ip).is_err() { 210 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 211 return rate_limit_response(); 212 } 213 214 next.run(request).await 215} 216 217pub async fn account_creation_rate_limit( 218 ConnectInfo(addr): ConnectInfo<SocketAddr>, 219 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 220 request: Request<Body>, 221 next: Next, 222) -> Response { 223 let client_ip = extract_client_ip(request.headers(), Some(addr)); 224 225 if limiters.account_creation.check_key(&client_ip).is_err() { 226 tracing::warn!(ip = %client_ip, 
"Account creation rate limit exceeded"); 227 return rate_limit_response(); 228 } 229 230 next.run(request).await 231} 232 233#[cfg(test)] 234mod tests { 235 use super::*; 236 237 #[test] 238 fn test_rate_limiters_creation() { 239 let limiters = RateLimiters::new(); 240 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 241 } 242 243 #[test] 244 fn test_rate_limiter_exhaustion() { 245 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 246 let key = "test_ip".to_string(); 247 248 assert!(limiter.check_key(&key).is_ok()); 249 assert!(limiter.check_key(&key).is_ok()); 250 assert!(limiter.check_key(&key).is_err()); 251 } 252 253 #[test] 254 fn test_different_keys_have_separate_limits() { 255 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 256 257 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 258 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 259 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 260 } 261 262 #[test] 263 fn test_builder_pattern() { 264 let limiters = RateLimiters::new() 265 .with_login_limit(20) 266 .with_oauth_token_limit(60) 267 .with_password_reset_limit(3) 268 .with_account_creation_limit(5); 269 270 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 271 } 272}