this repo has no description
1use axum::{ 2 Json, 3 body::Body, 4 extract::ConnectInfo, 5 http::{HeaderMap, Request, StatusCode}, 6 middleware::Next, 7 response::{IntoResponse, Response}, 8}; 9use governor::{ 10 Quota, RateLimiter, 11 clock::DefaultClock, 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13}; 14use std::{net::SocketAddr, num::NonZeroU32, sync::Arc}; 15 16pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 17pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 18 19#[derive(Clone)] 20pub struct RateLimiters { 21 pub login: Arc<KeyedRateLimiter>, 22 pub oauth_token: Arc<KeyedRateLimiter>, 23 pub oauth_authorize: Arc<KeyedRateLimiter>, 24 pub password_reset: Arc<KeyedRateLimiter>, 25 pub account_creation: Arc<KeyedRateLimiter>, 26 pub refresh_session: Arc<KeyedRateLimiter>, 27 pub reset_password: Arc<KeyedRateLimiter>, 28 pub oauth_par: Arc<KeyedRateLimiter>, 29 pub oauth_introspect: Arc<KeyedRateLimiter>, 30 pub app_password: Arc<KeyedRateLimiter>, 31 pub email_update: Arc<KeyedRateLimiter>, 32} 33 34impl Default for RateLimiters { 35 fn default() -> Self { 36 Self::new() 37 } 38} 39 40impl RateLimiters { 41 pub fn new() -> Self { 42 Self { 43 login: Arc::new(RateLimiter::keyed(Quota::per_minute( 44 NonZeroU32::new(10).unwrap(), 45 ))), 46 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute( 47 NonZeroU32::new(30).unwrap(), 48 ))), 49 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute( 50 NonZeroU32::new(10).unwrap(), 51 ))), 52 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour( 53 NonZeroU32::new(5).unwrap(), 54 ))), 55 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour( 56 NonZeroU32::new(10).unwrap(), 57 ))), 58 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute( 59 NonZeroU32::new(60).unwrap(), 60 ))), 61 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 62 NonZeroU32::new(10).unwrap(), 63 ))), 64 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute( 65 NonZeroU32::new(30).unwrap(), 66 ))), 67 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute( 68 NonZeroU32::new(30).unwrap(), 69 ))), 70 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 71 NonZeroU32::new(10).unwrap(), 72 ))), 73 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour( 74 NonZeroU32::new(5).unwrap(), 75 ))), 76 } 77 } 78 79 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 80 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute( 81 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 82 ))); 83 self 84 } 85 86 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 87 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute( 88 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()), 89 ))); 90 self 91 } 92 93 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 94 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute( 95 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 96 ))); 97 self 98 } 99 100 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 101 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour( 102 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 103 ))); 104 self 105 } 106 107 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 108 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour( 109 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()), 110 ))); 111 self 112 } 113 114 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 115 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour( 116 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 117 ))); 118 self 119 } 120} 121 122pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 123 if let Some(forwarded) = headers.get("x-forwarded-for") 124 && let Ok(value) = forwarded.to_str() 125 && let Some(first_ip) = value.split(',').next() 126 { 127 return first_ip.trim().to_string(); 128 } 129 130 if let Some(real_ip) = headers.get("x-real-ip") 131 && let Ok(value) = real_ip.to_str() 132 { 133 return value.trim().to_string(); 134 } 135 136 addr.map(|a| a.ip().to_string()) 137 .unwrap_or_else(|| "unknown".to_string()) 138} 139 140fn rate_limit_response() -> Response { 141 ( 142 StatusCode::TOO_MANY_REQUESTS, 143 Json(serde_json::json!({ 144 "error": "RateLimitExceeded", 145 "message": "Too many requests. Please try again later." 146 })), 147 ) 148 .into_response() 149} 150 151pub async fn login_rate_limit( 152 ConnectInfo(addr): ConnectInfo<SocketAddr>, 153 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 154 request: Request<Body>, 155 next: Next, 156) -> Response { 157 let client_ip = extract_client_ip(request.headers(), Some(addr)); 158 159 if limiters.login.check_key(&client_ip).is_err() { 160 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 161 return rate_limit_response(); 162 } 163 164 next.run(request).await 165} 166 167pub async fn oauth_token_rate_limit( 168 ConnectInfo(addr): ConnectInfo<SocketAddr>, 169 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 170 request: Request<Body>, 171 next: Next, 172) -> Response { 173 let client_ip = extract_client_ip(request.headers(), Some(addr)); 174 175 if limiters.oauth_token.check_key(&client_ip).is_err() { 176 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 177 return rate_limit_response(); 178 } 179 180 next.run(request).await 181} 182 183pub async fn password_reset_rate_limit( 184 ConnectInfo(addr): ConnectInfo<SocketAddr>, 185 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 186 request: Request<Body>, 187 next: Next, 188) -> Response { 189 let client_ip = extract_client_ip(request.headers(), Some(addr)); 190 191 if limiters.password_reset.check_key(&client_ip).is_err() { 192 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 193 return rate_limit_response(); 194 } 195 196 next.run(request).await 197} 198 199pub async fn account_creation_rate_limit( 200 ConnectInfo(addr): ConnectInfo<SocketAddr>, 201 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 202 request: Request<Body>, 203 next: Next, 204) -> Response { 205 let client_ip = extract_client_ip(request.headers(), Some(addr)); 206 207 if limiters.account_creation.check_key(&client_ip).is_err() { 208 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 209 return rate_limit_response(); 210 } 211 212 next.run(request).await 213} 214 215#[cfg(test)] 216mod tests { 217 use super::*; 218 219 #[test] 220 fn test_rate_limiters_creation() { 221 let limiters = RateLimiters::new(); 222 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 223 } 224 225 #[test] 226 fn test_rate_limiter_exhaustion() { 227 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 228 let key = "test_ip".to_string(); 229 230 assert!(limiter.check_key(&key).is_ok()); 231 assert!(limiter.check_key(&key).is_ok()); 232 assert!(limiter.check_key(&key).is_err()); 233 } 234 235 #[test] 236 fn test_different_keys_have_separate_limits() { 237 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 238 239 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 240 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 241 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 242 } 243 244 #[test] 245 fn test_builder_pattern() { 246 let limiters = RateLimiters::new() 247 .with_login_limit(20) 248 .with_oauth_token_limit(60) 249 .with_password_reset_limit(3) 250 .with_account_creation_limit(5); 251 252 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 253 } 254}