this repo has no description
1use axum::{ 2 body::Body, 3 extract::ConnectInfo, 4 http::{HeaderMap, Request, StatusCode}, 5 middleware::Next, 6 response::{IntoResponse, Response}, 7 Json, 8}; 9use governor::{ 10 Quota, RateLimiter, 11 clock::DefaultClock, 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13}; 14use std::{ 15 net::SocketAddr, 16 num::NonZeroU32, 17 sync::Arc, 18}; 19 20pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 21pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 22 23#[derive(Clone)] 24pub struct RateLimiters { 25 pub login: Arc<KeyedRateLimiter>, 26 pub oauth_token: Arc<KeyedRateLimiter>, 27 pub oauth_authorize: Arc<KeyedRateLimiter>, 28 pub password_reset: Arc<KeyedRateLimiter>, 29 pub account_creation: Arc<KeyedRateLimiter>, 30 pub refresh_session: Arc<KeyedRateLimiter>, 31 pub reset_password: Arc<KeyedRateLimiter>, 32 pub oauth_par: Arc<KeyedRateLimiter>, 33 pub oauth_introspect: Arc<KeyedRateLimiter>, 34 pub app_password: Arc<KeyedRateLimiter>, 35 pub email_update: Arc<KeyedRateLimiter>, 36} 37 38impl Default for RateLimiters { 39 fn default() -> Self { 40 Self::new() 41 } 42} 43 44impl RateLimiters { 45 pub fn new() -> Self { 46 Self { 47 login: Arc::new(RateLimiter::keyed( 48 Quota::per_minute(NonZeroU32::new(10).unwrap()) 49 )), 50 oauth_token: Arc::new(RateLimiter::keyed( 51 Quota::per_minute(NonZeroU32::new(30).unwrap()) 52 )), 53 oauth_authorize: Arc::new(RateLimiter::keyed( 54 Quota::per_minute(NonZeroU32::new(10).unwrap()) 55 )), 56 password_reset: Arc::new(RateLimiter::keyed( 57 Quota::per_hour(NonZeroU32::new(5).unwrap()) 58 )), 59 account_creation: Arc::new(RateLimiter::keyed( 60 Quota::per_hour(NonZeroU32::new(10).unwrap()) 61 )), 62 refresh_session: Arc::new(RateLimiter::keyed( 63 Quota::per_minute(NonZeroU32::new(60).unwrap()) 64 )), 65 reset_password: Arc::new(RateLimiter::keyed( 66 Quota::per_minute(NonZeroU32::new(10).unwrap()) 67 )), 68 oauth_par: Arc::new(RateLimiter::keyed( 69 Quota::per_minute(NonZeroU32::new(30).unwrap()) 70 )), 71 oauth_introspect: Arc::new(RateLimiter::keyed( 72 Quota::per_minute(NonZeroU32::new(30).unwrap()) 73 )), 74 app_password: Arc::new(RateLimiter::keyed( 75 Quota::per_minute(NonZeroU32::new(10).unwrap()) 76 )), 77 email_update: Arc::new(RateLimiter::keyed( 78 Quota::per_hour(NonZeroU32::new(5).unwrap()) 79 )), 80 } 81 } 82 83 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 84 self.login = Arc::new(RateLimiter::keyed( 85 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 86 )); 87 self 88 } 89 90 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 91 self.oauth_token = Arc::new(RateLimiter::keyed( 92 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 93 )); 94 self 95 } 96 97 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 98 self.oauth_authorize = Arc::new(RateLimiter::keyed( 99 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 100 )); 101 self 102 } 103 104 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 105 self.password_reset = Arc::new(RateLimiter::keyed( 106 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 107 )); 108 self 109 } 110 111 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 112 self.account_creation = Arc::new(RateLimiter::keyed( 113 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 114 )); 115 self 116 } 117 118 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 119 self.email_update = Arc::new(RateLimiter::keyed( 120 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 121 )); 122 self 123 } 124} 125 126pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 127 if let Some(forwarded) = headers.get("x-forwarded-for") { 128 if let Ok(value) = forwarded.to_str() { 129 if let Some(first_ip) = value.split(',').next() { 130 return first_ip.trim().to_string(); 131 } 132 } 133 } 134 135 if let Some(real_ip) = headers.get("x-real-ip") { 136 if let Ok(value) = real_ip.to_str() { 137 return value.trim().to_string(); 138 } 139 } 140 141 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 142} 143 144fn rate_limit_response() -> Response { 145 ( 146 StatusCode::TOO_MANY_REQUESTS, 147 Json(serde_json::json!({ 148 "error": "RateLimitExceeded", 149 "message": "Too many requests. Please try again later." 150 })), 151 ) 152 .into_response() 153} 154 155pub async fn login_rate_limit( 156 ConnectInfo(addr): ConnectInfo<SocketAddr>, 157 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 158 request: Request<Body>, 159 next: Next, 160) -> Response { 161 let client_ip = extract_client_ip(request.headers(), Some(addr)); 162 163 if limiters.login.check_key(&client_ip).is_err() { 164 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 165 return rate_limit_response(); 166 } 167 168 next.run(request).await 169} 170 171pub async fn oauth_token_rate_limit( 172 ConnectInfo(addr): ConnectInfo<SocketAddr>, 173 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 174 request: Request<Body>, 175 next: Next, 176) -> Response { 177 let client_ip = extract_client_ip(request.headers(), Some(addr)); 178 179 if limiters.oauth_token.check_key(&client_ip).is_err() { 180 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 181 return rate_limit_response(); 182 } 183 184 next.run(request).await 185} 186 187pub async fn password_reset_rate_limit( 188 ConnectInfo(addr): ConnectInfo<SocketAddr>, 189 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 190 request: Request<Body>, 191 next: Next, 192) -> Response { 193 let client_ip = extract_client_ip(request.headers(), Some(addr)); 194 195 if limiters.password_reset.check_key(&client_ip).is_err() { 196 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 197 return rate_limit_response(); 198 } 199 200 next.run(request).await 201} 202 203pub async fn account_creation_rate_limit( 204 ConnectInfo(addr): ConnectInfo<SocketAddr>, 205 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 206 request: Request<Body>, 207 next: Next, 208) -> Response { 209 let client_ip = extract_client_ip(request.headers(), Some(addr)); 210 211 if limiters.account_creation.check_key(&client_ip).is_err() { 212 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 213 return rate_limit_response(); 214 } 215 216 next.run(request).await 217} 218 219#[cfg(test)] 220mod tests { 221 use super::*; 222 223 #[test] 224 fn test_rate_limiters_creation() { 225 let limiters = RateLimiters::new(); 226 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 227 } 228 229 #[test] 230 fn test_rate_limiter_exhaustion() { 231 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 232 let key = "test_ip".to_string(); 233 234 assert!(limiter.check_key(&key).is_ok()); 235 assert!(limiter.check_key(&key).is_ok()); 236 assert!(limiter.check_key(&key).is_err()); 237 } 238 239 #[test] 240 fn test_different_keys_have_separate_limits() { 241 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 242 243 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 244 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 245 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 246 } 247 248 #[test] 249 fn test_builder_pattern() { 250 let limiters = RateLimiters::new() 251 .with_login_limit(20) 252 .with_oauth_token_limit(60) 253 .with_password_reset_limit(3) 254 .with_account_creation_limit(5); 255 256 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 257 } 258}