this repo has no description
1use axum::{
2 body::Body,
3 extract::ConnectInfo,
4 http::{HeaderMap, Request, StatusCode},
5 middleware::Next,
6 response::{IntoResponse, Response},
7 Json,
8};
9use governor::{
10 Quota, RateLimiter,
11 clock::DefaultClock,
12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13};
14use std::{
15 net::SocketAddr,
16 num::NonZeroU32,
17 sync::Arc,
18};
19
/// Rate limiter holding an independent quota per `String` key (the middleware in this module
/// keys it by client IP).
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
/// Rate limiter with one shared quota and no per-key state.
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
22
// NOTE: For production deployments with high traffic, prefer the distributed
// (Redis/Valkey-based) rate limiter exposed as `AppState::distributed_rate_limiter`.
// Nothing in this module prunes expired entries from the in-memory keyed limiters, so their
// memory grows with every distinct client key seen, and each process keeps its own
// independent counts. The distributed limiter relies on Redis TTLs for cleanup and shares
// state correctly across multiple instances.
28
29#[derive(Clone)]
30pub struct RateLimiters {
31 pub login: Arc<KeyedRateLimiter>,
32 pub oauth_token: Arc<KeyedRateLimiter>,
33 pub oauth_authorize: Arc<KeyedRateLimiter>,
34 pub password_reset: Arc<KeyedRateLimiter>,
35 pub account_creation: Arc<KeyedRateLimiter>,
36 pub refresh_session: Arc<KeyedRateLimiter>,
37 pub reset_password: Arc<KeyedRateLimiter>,
38 pub oauth_par: Arc<KeyedRateLimiter>,
39 pub oauth_introspect: Arc<KeyedRateLimiter>,
40 pub app_password: Arc<KeyedRateLimiter>,
41 pub email_update: Arc<KeyedRateLimiter>,
42}
43
44impl Default for RateLimiters {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl RateLimiters {
51 pub fn new() -> Self {
52 Self {
53 login: Arc::new(RateLimiter::keyed(
54 Quota::per_minute(NonZeroU32::new(10).unwrap())
55 )),
56 oauth_token: Arc::new(RateLimiter::keyed(
57 Quota::per_minute(NonZeroU32::new(30).unwrap())
58 )),
59 oauth_authorize: Arc::new(RateLimiter::keyed(
60 Quota::per_minute(NonZeroU32::new(10).unwrap())
61 )),
62 password_reset: Arc::new(RateLimiter::keyed(
63 Quota::per_hour(NonZeroU32::new(5).unwrap())
64 )),
65 account_creation: Arc::new(RateLimiter::keyed(
66 Quota::per_hour(NonZeroU32::new(10).unwrap())
67 )),
68 refresh_session: Arc::new(RateLimiter::keyed(
69 Quota::per_minute(NonZeroU32::new(60).unwrap())
70 )),
71 reset_password: Arc::new(RateLimiter::keyed(
72 Quota::per_minute(NonZeroU32::new(10).unwrap())
73 )),
74 oauth_par: Arc::new(RateLimiter::keyed(
75 Quota::per_minute(NonZeroU32::new(30).unwrap())
76 )),
77 oauth_introspect: Arc::new(RateLimiter::keyed(
78 Quota::per_minute(NonZeroU32::new(30).unwrap())
79 )),
80 app_password: Arc::new(RateLimiter::keyed(
81 Quota::per_minute(NonZeroU32::new(10).unwrap())
82 )),
83 email_update: Arc::new(RateLimiter::keyed(
84 Quota::per_hour(NonZeroU32::new(5).unwrap())
85 )),
86 }
87 }
88
89 pub fn with_login_limit(mut self, per_minute: u32) -> Self {
90 self.login = Arc::new(RateLimiter::keyed(
91 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
92 ));
93 self
94 }
95
96 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
97 self.oauth_token = Arc::new(RateLimiter::keyed(
98 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
99 ));
100 self
101 }
102
103 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
104 self.oauth_authorize = Arc::new(RateLimiter::keyed(
105 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
106 ));
107 self
108 }
109
110 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
111 self.password_reset = Arc::new(RateLimiter::keyed(
112 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
113 ));
114 self
115 }
116
117 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
118 self.account_creation = Arc::new(RateLimiter::keyed(
119 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
120 ));
121 self
122 }
123
124 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
125 self.email_update = Arc::new(RateLimiter::keyed(
126 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
127 ));
128 self
129 }
130}
131
132pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
133 if let Some(forwarded) = headers.get("x-forwarded-for") {
134 if let Ok(value) = forwarded.to_str() {
135 if let Some(first_ip) = value.split(',').next() {
136 return first_ip.trim().to_string();
137 }
138 }
139 }
140
141 if let Some(real_ip) = headers.get("x-real-ip") {
142 if let Ok(value) = real_ip.to_str() {
143 return value.trim().to_string();
144 }
145 }
146
147 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
148}
149
150fn rate_limit_response() -> Response {
151 (
152 StatusCode::TOO_MANY_REQUESTS,
153 Json(serde_json::json!({
154 "error": "RateLimitExceeded",
155 "message": "Too many requests. Please try again later."
156 })),
157 )
158 .into_response()
159}
160
161pub async fn login_rate_limit(
162 ConnectInfo(addr): ConnectInfo<SocketAddr>,
163 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
164 request: Request<Body>,
165 next: Next,
166) -> Response {
167 let client_ip = extract_client_ip(request.headers(), Some(addr));
168
169 if limiters.login.check_key(&client_ip).is_err() {
170 tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
171 return rate_limit_response();
172 }
173
174 next.run(request).await
175}
176
177pub async fn oauth_token_rate_limit(
178 ConnectInfo(addr): ConnectInfo<SocketAddr>,
179 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
180 request: Request<Body>,
181 next: Next,
182) -> Response {
183 let client_ip = extract_client_ip(request.headers(), Some(addr));
184
185 if limiters.oauth_token.check_key(&client_ip).is_err() {
186 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
187 return rate_limit_response();
188 }
189
190 next.run(request).await
191}
192
193pub async fn password_reset_rate_limit(
194 ConnectInfo(addr): ConnectInfo<SocketAddr>,
195 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
196 request: Request<Body>,
197 next: Next,
198) -> Response {
199 let client_ip = extract_client_ip(request.headers(), Some(addr));
200
201 if limiters.password_reset.check_key(&client_ip).is_err() {
202 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
203 return rate_limit_response();
204 }
205
206 next.run(request).await
207}
208
209pub async fn account_creation_rate_limit(
210 ConnectInfo(addr): ConnectInfo<SocketAddr>,
211 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
212 request: Request<Body>,
213 next: Next,
214) -> Response {
215 let client_ip = extract_client_ip(request.headers(), Some(addr));
216
217 if limiters.account_creation.check_key(&client_ip).is_err() {
218 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
219 return rate_limit_response();
220 }
221
222 next.run(request).await
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts consecutive requests for `key` admitted before the first rejection.
    /// The loop finishes in microseconds, far below any replenish interval used here.
    fn admitted_before_rejection(limiter: &KeyedRateLimiter, key: &str) -> u32 {
        let key = key.to_string();
        let mut admitted = 0;
        while limiter.check_key(&key).is_ok() {
            admitted += 1;
            assert!(admitted <= 10_000, "limiter never rejected a request");
        }
        admitted
    }

    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    #[test]
    fn test_default_quotas() {
        let limiters = RateLimiters::new();
        assert_eq!(admitted_before_rejection(&limiters.login, "ip"), 10);
        assert_eq!(admitted_before_rejection(&limiters.refresh_session, "ip"), 60);
        assert_eq!(admitted_before_rejection(&limiters.email_update, "ip"), 5);
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let key = "test_ip".to_string();

        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());
    }

    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));

        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
        assert!(limiter.check_key(&"ip1".to_string()).is_err());
        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        let limiters = RateLimiters::new()
            .with_login_limit(20)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(5);

        assert_eq!(admitted_before_rejection(&limiters.login, "ip"), 20);
        assert_eq!(admitted_before_rejection(&limiters.oauth_token, "ip"), 60);
        assert_eq!(admitted_before_rejection(&limiters.password_reset, "ip"), 3);
        assert_eq!(admitted_before_rejection(&limiters.account_creation, "ip"), 5);
    }

    #[test]
    fn test_zero_limit_falls_back_to_default() {
        let limiters = RateLimiters::new()
            .with_login_limit(0)
            .with_password_reset_limit(0);

        assert_eq!(admitted_before_rejection(&limiters.login, "ip"), 10);
        assert_eq!(admitted_before_rejection(&limiters.password_reset, "ip"), 5);
    }

    #[test]
    fn test_extract_client_ip_prefers_first_forwarded_entry() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.7, 10.0.0.1".parse().unwrap());
        headers.insert("x-real-ip", "198.51.100.2".parse().unwrap());
        let peer: SocketAddr = "192.0.2.1:4000".parse().unwrap();

        assert_eq!(extract_client_ip(&headers, Some(peer)), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_uses_real_ip_header() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", "198.51.100.2".parse().unwrap());
        let peer: SocketAddr = "192.0.2.1:4000".parse().unwrap();

        assert_eq!(extract_client_ip(&headers, Some(peer)), "198.51.100.2");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_peer_then_unknown() {
        let headers = HeaderMap::new();
        let peer: SocketAddr = "192.0.2.1:4000".parse().unwrap();

        assert_eq!(extract_client_ip(&headers, Some(peer)), "192.0.2.1");
        assert_eq!(extract_client_ip(&headers, None), "unknown");
    }
}