This repository has no description.
1use axum::{ 2 body::Body, 3 extract::ConnectInfo, 4 http::{HeaderMap, Request, StatusCode}, 5 middleware::Next, 6 response::{IntoResponse, Response}, 7 Json, 8}; 9use governor::{ 10 Quota, RateLimiter, 11 clock::DefaultClock, 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13}; 14use std::{ 15 net::SocketAddr, 16 num::NonZeroU32, 17 sync::Arc, 18}; 19 20pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 21pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 22 23#[derive(Clone)] 24pub struct RateLimiters { 25 pub login: Arc<KeyedRateLimiter>, 26 pub oauth_token: Arc<KeyedRateLimiter>, 27 pub password_reset: Arc<KeyedRateLimiter>, 28 pub account_creation: Arc<KeyedRateLimiter>, 29} 30 31impl Default for RateLimiters { 32 fn default() -> Self { 33 Self::new() 34 } 35} 36 37impl RateLimiters { 38 pub fn new() -> Self { 39 Self { 40 login: Arc::new(RateLimiter::keyed( 41 Quota::per_minute(NonZeroU32::new(10).unwrap()) 42 )), 43 oauth_token: Arc::new(RateLimiter::keyed( 44 Quota::per_minute(NonZeroU32::new(30).unwrap()) 45 )), 46 password_reset: Arc::new(RateLimiter::keyed( 47 Quota::per_hour(NonZeroU32::new(5).unwrap()) 48 )), 49 account_creation: Arc::new(RateLimiter::keyed( 50 Quota::per_hour(NonZeroU32::new(10).unwrap()) 51 )), 52 } 53 } 54 55 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 56 self.login = Arc::new(RateLimiter::keyed( 57 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 58 )); 59 self 60 } 61 62 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 63 self.oauth_token = Arc::new(RateLimiter::keyed( 64 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 65 )); 66 self 67 } 68 69 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 70 self.password_reset = Arc::new(RateLimiter::keyed( 71 
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 72 )); 73 self 74 } 75 76 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 77 self.account_creation = Arc::new(RateLimiter::keyed( 78 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 79 )); 80 self 81 } 82} 83 84fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 85 if let Some(forwarded) = headers.get("x-forwarded-for") { 86 if let Ok(value) = forwarded.to_str() { 87 if let Some(first_ip) = value.split(',').next() { 88 return first_ip.trim().to_string(); 89 } 90 } 91 } 92 93 if let Some(real_ip) = headers.get("x-real-ip") { 94 if let Ok(value) = real_ip.to_str() { 95 return value.trim().to_string(); 96 } 97 } 98 99 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 100} 101 102fn rate_limit_response() -> Response { 103 ( 104 StatusCode::TOO_MANY_REQUESTS, 105 Json(serde_json::json!({ 106 "error": "RateLimitExceeded", 107 "message": "Too many requests. Please try again later." 
108 })), 109 ) 110 .into_response() 111} 112 113pub async fn login_rate_limit( 114 ConnectInfo(addr): ConnectInfo<SocketAddr>, 115 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 116 request: Request<Body>, 117 next: Next, 118) -> Response { 119 let client_ip = extract_client_ip(request.headers(), Some(addr)); 120 121 if limiters.login.check_key(&client_ip).is_err() { 122 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 123 return rate_limit_response(); 124 } 125 126 next.run(request).await 127} 128 129pub async fn oauth_token_rate_limit( 130 ConnectInfo(addr): ConnectInfo<SocketAddr>, 131 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 132 request: Request<Body>, 133 next: Next, 134) -> Response { 135 let client_ip = extract_client_ip(request.headers(), Some(addr)); 136 137 if limiters.oauth_token.check_key(&client_ip).is_err() { 138 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 139 return rate_limit_response(); 140 } 141 142 next.run(request).await 143} 144 145pub async fn password_reset_rate_limit( 146 ConnectInfo(addr): ConnectInfo<SocketAddr>, 147 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 148 request: Request<Body>, 149 next: Next, 150) -> Response { 151 let client_ip = extract_client_ip(request.headers(), Some(addr)); 152 153 if limiters.password_reset.check_key(&client_ip).is_err() { 154 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 155 return rate_limit_response(); 156 } 157 158 next.run(request).await 159} 160 161pub async fn account_creation_rate_limit( 162 ConnectInfo(addr): ConnectInfo<SocketAddr>, 163 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 164 request: Request<Body>, 165 next: Next, 166) -> Response { 167 let client_ip = extract_client_ip(request.headers(), Some(addr)); 168 169 if limiters.account_creation.check_key(&client_ip).is_err() { 170 tracing::warn!(ip = %client_ip, 
"Account creation rate limit exceeded"); 171 return rate_limit_response(); 172 } 173 174 next.run(request).await 175} 176 177#[cfg(test)] 178mod tests { 179 use super::*; 180 181 #[test] 182 fn test_rate_limiters_creation() { 183 let limiters = RateLimiters::new(); 184 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 185 } 186 187 #[test] 188 fn test_rate_limiter_exhaustion() { 189 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 190 let key = "test_ip".to_string(); 191 192 assert!(limiter.check_key(&key).is_ok()); 193 assert!(limiter.check_key(&key).is_ok()); 194 assert!(limiter.check_key(&key).is_err()); 195 } 196 197 #[test] 198 fn test_different_keys_have_separate_limits() { 199 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 200 201 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 202 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 203 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 204 } 205 206 #[test] 207 fn test_builder_pattern() { 208 let limiters = RateLimiters::new() 209 .with_login_limit(20) 210 .with_oauth_token_limit(60) 211 .with_password_reset_limit(3) 212 .with_account_creation_limit(5); 213 214 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 215 } 216}