this repo has no description
1use axum::{
2 body::Body,
3 extract::ConnectInfo,
4 http::{HeaderMap, Request, StatusCode},
5 middleware::Next,
6 response::{IntoResponse, Response},
7 Json,
8};
9use governor::{
10 Quota, RateLimiter,
11 clock::DefaultClock,
12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13};
14use std::{
15 net::SocketAddr,
16 num::NonZeroU32,
17 sync::Arc,
18};
19pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
20pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
21// NOTE: For production deployments with high traffic, prefer using the distributed rate
22// limiter (Redis/Valkey-based) via AppState::distributed_rate_limiter. The in-memory
23// rate limiters here don't automatically clean up expired entries, which can cause
24// memory growth over time with many unique client IPs. The distributed rate limiter
25// uses Redis TTL for automatic cleanup and works correctly across multiple instances.
26#[derive(Clone)]
27pub struct RateLimiters {
28 pub login: Arc<KeyedRateLimiter>,
29 pub oauth_token: Arc<KeyedRateLimiter>,
30 pub oauth_authorize: Arc<KeyedRateLimiter>,
31 pub password_reset: Arc<KeyedRateLimiter>,
32 pub account_creation: Arc<KeyedRateLimiter>,
33 pub refresh_session: Arc<KeyedRateLimiter>,
34 pub reset_password: Arc<KeyedRateLimiter>,
35 pub oauth_par: Arc<KeyedRateLimiter>,
36 pub oauth_introspect: Arc<KeyedRateLimiter>,
37 pub app_password: Arc<KeyedRateLimiter>,
38 pub email_update: Arc<KeyedRateLimiter>,
39}
40impl Default for RateLimiters {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45impl RateLimiters {
46 pub fn new() -> Self {
47 Self {
48 login: Arc::new(RateLimiter::keyed(
49 Quota::per_minute(NonZeroU32::new(10).unwrap())
50 )),
51 oauth_token: Arc::new(RateLimiter::keyed(
52 Quota::per_minute(NonZeroU32::new(30).unwrap())
53 )),
54 oauth_authorize: Arc::new(RateLimiter::keyed(
55 Quota::per_minute(NonZeroU32::new(10).unwrap())
56 )),
57 password_reset: Arc::new(RateLimiter::keyed(
58 Quota::per_hour(NonZeroU32::new(5).unwrap())
59 )),
60 account_creation: Arc::new(RateLimiter::keyed(
61 Quota::per_hour(NonZeroU32::new(10).unwrap())
62 )),
63 refresh_session: Arc::new(RateLimiter::keyed(
64 Quota::per_minute(NonZeroU32::new(60).unwrap())
65 )),
66 reset_password: Arc::new(RateLimiter::keyed(
67 Quota::per_minute(NonZeroU32::new(10).unwrap())
68 )),
69 oauth_par: Arc::new(RateLimiter::keyed(
70 Quota::per_minute(NonZeroU32::new(30).unwrap())
71 )),
72 oauth_introspect: Arc::new(RateLimiter::keyed(
73 Quota::per_minute(NonZeroU32::new(30).unwrap())
74 )),
75 app_password: Arc::new(RateLimiter::keyed(
76 Quota::per_minute(NonZeroU32::new(10).unwrap())
77 )),
78 email_update: Arc::new(RateLimiter::keyed(
79 Quota::per_hour(NonZeroU32::new(5).unwrap())
80 )),
81 }
82 }
83 pub fn with_login_limit(mut self, per_minute: u32) -> Self {
84 self.login = Arc::new(RateLimiter::keyed(
85 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
86 ));
87 self
88 }
89 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
90 self.oauth_token = Arc::new(RateLimiter::keyed(
91 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
92 ));
93 self
94 }
95 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
96 self.oauth_authorize = Arc::new(RateLimiter::keyed(
97 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
98 ));
99 self
100 }
101 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
102 self.password_reset = Arc::new(RateLimiter::keyed(
103 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
104 ));
105 self
106 }
107 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
108 self.account_creation = Arc::new(RateLimiter::keyed(
109 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
110 ));
111 self
112 }
113 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
114 self.email_update = Arc::new(RateLimiter::keyed(
115 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
116 ));
117 self
118 }
119}
120pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
121 if let Some(forwarded) = headers.get("x-forwarded-for") {
122 if let Ok(value) = forwarded.to_str() {
123 if let Some(first_ip) = value.split(',').next() {
124 return first_ip.trim().to_string();
125 }
126 }
127 }
128 if let Some(real_ip) = headers.get("x-real-ip") {
129 if let Ok(value) = real_ip.to_str() {
130 return value.trim().to_string();
131 }
132 }
133 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
134}
135fn rate_limit_response() -> Response {
136 (
137 StatusCode::TOO_MANY_REQUESTS,
138 Json(serde_json::json!({
139 "error": "RateLimitExceeded",
140 "message": "Too many requests. Please try again later."
141 })),
142 )
143 .into_response()
144}
145pub async fn login_rate_limit(
146 ConnectInfo(addr): ConnectInfo<SocketAddr>,
147 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
148 request: Request<Body>,
149 next: Next,
150) -> Response {
151 let client_ip = extract_client_ip(request.headers(), Some(addr));
152 if limiters.login.check_key(&client_ip).is_err() {
153 tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
154 return rate_limit_response();
155 }
156 next.run(request).await
157}
158pub async fn oauth_token_rate_limit(
159 ConnectInfo(addr): ConnectInfo<SocketAddr>,
160 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
161 request: Request<Body>,
162 next: Next,
163) -> Response {
164 let client_ip = extract_client_ip(request.headers(), Some(addr));
165 if limiters.oauth_token.check_key(&client_ip).is_err() {
166 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
167 return rate_limit_response();
168 }
169 next.run(request).await
170}
171pub async fn password_reset_rate_limit(
172 ConnectInfo(addr): ConnectInfo<SocketAddr>,
173 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
174 request: Request<Body>,
175 next: Next,
176) -> Response {
177 let client_ip = extract_client_ip(request.headers(), Some(addr));
178 if limiters.password_reset.check_key(&client_ip).is_err() {
179 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
180 return rate_limit_response();
181 }
182 next.run(request).await
183}
184pub async fn account_creation_rate_limit(
185 ConnectInfo(addr): ConnectInfo<SocketAddr>,
186 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
187 request: Request<Body>,
188 next: Next,
189) -> Response {
190 let client_ip = extract_client_ip(request.headers(), Some(addr));
191 if limiters.account_creation.check_key(&client_ip).is_err() {
192 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
193 return rate_limit_response();
194 }
195 next.run(request).await
196}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built limiter set admits the first request for a key.
    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    /// A quota of two admits exactly two immediate requests for one key.
    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let key = "test_ip".to_string();
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());
    }

    /// Exhausting one key leaves other keys unaffected.
    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));
        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
        assert!(limiter.check_key(&"ip1".to_string()).is_err());
        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
    }

    /// Builder methods chain and still produce a usable limiter set.
    #[test]
    fn test_builder_pattern() {
        let limiters = RateLimiters::new()
            .with_login_limit(20)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(5);
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    /// The configured limit is actually applied, not just accepted.
    #[test]
    fn test_builder_limit_is_applied() {
        let limiters = RateLimiters::new().with_login_limit(1);
        let key = "test".to_string();
        assert!(limiters.login.check_key(&key).is_ok());
        assert!(limiters.login.check_key(&key).is_err());
    }

    /// A zero limit must not panic; the builder falls back to a usable default.
    #[test]
    fn test_builder_zero_limit_falls_back_to_default() {
        let limiters = RateLimiters::new()
            .with_login_limit(0)
            .with_email_update_limit(0);
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
        assert!(limiters.email_update.check_key(&"test".to_string()).is_ok());
    }

    /// The left-most `x-forwarded-for` entry wins over every other source.
    #[test]
    fn test_extract_client_ip_prefers_first_forwarded_entry() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.7, 10.0.0.1".parse().unwrap());
        headers.insert("x-real-ip", "198.51.100.2".parse().unwrap());
        let peer: SocketAddr = "192.0.2.1:4000".parse().unwrap();
        assert_eq!(extract_client_ip(&headers, Some(peer)), "203.0.113.7");
    }

    /// `x-real-ip` is used when `x-forwarded-for` is absent.
    #[test]
    fn test_extract_client_ip_uses_real_ip_header() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", " 198.51.100.2 ".parse().unwrap());
        assert_eq!(extract_client_ip(&headers, None), "198.51.100.2");
    }

    /// Without proxy headers the socket address (minus port) is the key.
    #[test]
    fn test_extract_client_ip_falls_back_to_socket_addr() {
        let headers = HeaderMap::new();
        let peer: SocketAddr = "192.0.2.1:4000".parse().unwrap();
        assert_eq!(extract_client_ip(&headers, Some(peer)), "192.0.2.1");
    }

    /// With no information at all the key is the literal `"unknown"`.
    #[test]
    fn test_extract_client_ip_unknown_without_any_source() {
        assert_eq!(extract_client_ip(&HeaderMap::new(), None), "unknown");
    }
}