this repo has no description
1use axum::{
2 Json,
3 body::Body,
4 extract::ConnectInfo,
5 http::{HeaderMap, Request, StatusCode},
6 middleware::Next,
7 response::{IntoResponse, Response},
8};
9use governor::{
10 Quota, RateLimiter,
11 clock::DefaultClock,
12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13};
14use std::{net::SocketAddr, num::NonZeroU32, sync::Arc};
15
16pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
17pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
18
19#[derive(Clone)]
20pub struct RateLimiters {
21 pub login: Arc<KeyedRateLimiter>,
22 pub oauth_token: Arc<KeyedRateLimiter>,
23 pub oauth_authorize: Arc<KeyedRateLimiter>,
24 pub password_reset: Arc<KeyedRateLimiter>,
25 pub account_creation: Arc<KeyedRateLimiter>,
26 pub refresh_session: Arc<KeyedRateLimiter>,
27 pub reset_password: Arc<KeyedRateLimiter>,
28 pub oauth_par: Arc<KeyedRateLimiter>,
29 pub oauth_introspect: Arc<KeyedRateLimiter>,
30 pub app_password: Arc<KeyedRateLimiter>,
31 pub email_update: Arc<KeyedRateLimiter>,
32 pub totp_verify: Arc<KeyedRateLimiter>,
33 pub handle_update: Arc<KeyedRateLimiter>,
34 pub handle_update_daily: Arc<KeyedRateLimiter>,
35 pub verification_check: Arc<KeyedRateLimiter>,
36}
37
38impl Default for RateLimiters {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl RateLimiters {
45 pub fn new() -> Self {
46 Self {
47 login: Arc::new(RateLimiter::keyed(Quota::per_minute(
48 NonZeroU32::new(10).unwrap(),
49 ))),
50 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute(
51 NonZeroU32::new(30).unwrap(),
52 ))),
53 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute(
54 NonZeroU32::new(10).unwrap(),
55 ))),
56 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour(
57 NonZeroU32::new(5).unwrap(),
58 ))),
59 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour(
60 NonZeroU32::new(10).unwrap(),
61 ))),
62 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute(
63 NonZeroU32::new(60).unwrap(),
64 ))),
65 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
66 NonZeroU32::new(10).unwrap(),
67 ))),
68 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute(
69 NonZeroU32::new(30).unwrap(),
70 ))),
71 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute(
72 NonZeroU32::new(30).unwrap(),
73 ))),
74 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
75 NonZeroU32::new(10).unwrap(),
76 ))),
77 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour(
78 NonZeroU32::new(5).unwrap(),
79 ))),
80 totp_verify: Arc::new(RateLimiter::keyed(
81 Quota::with_period(std::time::Duration::from_secs(60))
82 .unwrap()
83 .allow_burst(NonZeroU32::new(5).unwrap()),
84 )),
85 handle_update: Arc::new(RateLimiter::keyed(
86 Quota::with_period(std::time::Duration::from_secs(30))
87 .unwrap()
88 .allow_burst(NonZeroU32::new(10).unwrap()),
89 )),
90 handle_update_daily: Arc::new(RateLimiter::keyed(
91 Quota::with_period(std::time::Duration::from_secs(1728))
92 .unwrap()
93 .allow_burst(NonZeroU32::new(50).unwrap()),
94 )),
95 verification_check: Arc::new(RateLimiter::keyed(Quota::per_minute(
96 NonZeroU32::new(60).unwrap(),
97 ))),
98 }
99 }
100
101 pub fn with_login_limit(mut self, per_minute: u32) -> Self {
102 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute(
103 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
104 )));
105 self
106 }
107
108 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
109 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute(
110 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()),
111 )));
112 self
113 }
114
115 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
116 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute(
117 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
118 )));
119 self
120 }
121
122 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
123 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour(
124 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
125 )));
126 self
127 }
128
129 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
130 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour(
131 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()),
132 )));
133 self
134 }
135
136 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
137 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour(
138 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
139 )));
140 self
141 }
142}
143
144pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
145 if let Some(forwarded) = headers.get("x-forwarded-for")
146 && let Ok(value) = forwarded.to_str()
147 && let Some(first_ip) = value.split(',').next()
148 {
149 return first_ip.trim().to_string();
150 }
151
152 if let Some(real_ip) = headers.get("x-real-ip")
153 && let Ok(value) = real_ip.to_str()
154 {
155 return value.trim().to_string();
156 }
157
158 addr.map(|a| a.ip().to_string())
159 .unwrap_or_else(|| "unknown".to_string())
160}
161
162fn rate_limit_response() -> Response {
163 (
164 StatusCode::TOO_MANY_REQUESTS,
165 Json(serde_json::json!({
166 "error": "RateLimitExceeded",
167 "message": "Too many requests. Please try again later."
168 })),
169 )
170 .into_response()
171}
172
173pub async fn login_rate_limit(
174 ConnectInfo(addr): ConnectInfo<SocketAddr>,
175 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
176 request: Request<Body>,
177 next: Next,
178) -> Response {
179 let client_ip = extract_client_ip(request.headers(), Some(addr));
180
181 if limiters.login.check_key(&client_ip).is_err() {
182 tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
183 return rate_limit_response();
184 }
185
186 next.run(request).await
187}
188
189pub async fn oauth_token_rate_limit(
190 ConnectInfo(addr): ConnectInfo<SocketAddr>,
191 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
192 request: Request<Body>,
193 next: Next,
194) -> Response {
195 let client_ip = extract_client_ip(request.headers(), Some(addr));
196
197 if limiters.oauth_token.check_key(&client_ip).is_err() {
198 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
199 return rate_limit_response();
200 }
201
202 next.run(request).await
203}
204
205pub async fn password_reset_rate_limit(
206 ConnectInfo(addr): ConnectInfo<SocketAddr>,
207 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
208 request: Request<Body>,
209 next: Next,
210) -> Response {
211 let client_ip = extract_client_ip(request.headers(), Some(addr));
212
213 if limiters.password_reset.check_key(&client_ip).is_err() {
214 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
215 return rate_limit_response();
216 }
217
218 next.run(request).await
219}
220
221pub async fn account_creation_rate_limit(
222 ConnectInfo(addr): ConnectInfo<SocketAddr>,
223 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
224 request: Request<Body>,
225 next: Next,
226) -> Response {
227 let client_ip = extract_client_ip(request.headers(), Some(addr));
228
229 if limiters.account_creation.check_key(&client_ip).is_err() {
230 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
231 return rate_limit_response();
232 }
233
234 next.run(request).await
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Counts consecutive checks `limiter` admits for `key` before the first
    /// rejection. The count is capped so a misconfigured quota cannot spin forever.
    fn burst_size(limiter: &KeyedRateLimiter, key: &str) -> u32 {
        let key = key.to_string();
        let mut admitted = 0;
        while admitted < 1_000 && limiter.check_key(&key).is_ok() {
            admitted += 1;
        }
        admitted
    }

    /// Builds a header map from static `(name, value)` pairs.
    fn headers(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    #[test]
    fn test_default_quotas() {
        let limiters = RateLimiters::default();
        assert_eq!(burst_size(&limiters.login, "k"), 10);
        assert_eq!(burst_size(&limiters.oauth_token, "k"), 30);
        assert_eq!(burst_size(&limiters.oauth_authorize, "k"), 10);
        assert_eq!(burst_size(&limiters.password_reset, "k"), 5);
        assert_eq!(burst_size(&limiters.account_creation, "k"), 10);
        assert_eq!(burst_size(&limiters.email_update, "k"), 5);
        assert_eq!(burst_size(&limiters.totp_verify, "k"), 5);
        assert_eq!(burst_size(&limiters.handle_update, "k"), 10);
        assert_eq!(burst_size(&limiters.handle_update_daily, "k"), 50);
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let key = "test_ip".to_string();

        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());
    }

    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));

        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
        assert!(limiter.check_key(&"ip1".to_string()).is_err());
        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        let limiters = RateLimiters::new()
            .with_login_limit(20)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(5);

        // The builders must actually install the requested quotas.
        assert_eq!(burst_size(&limiters.login, "test"), 20);
        assert_eq!(burst_size(&limiters.oauth_token, "test"), 60);
        assert_eq!(burst_size(&limiters.password_reset, "test"), 3);
        assert_eq!(burst_size(&limiters.account_creation, "test"), 5);
    }

    #[test]
    fn test_builder_zero_keeps_default() {
        let limiters = RateLimiters::new()
            .with_login_limit(0)
            .with_email_update_limit(0);

        assert_eq!(burst_size(&limiters.login, "k"), 10);
        assert_eq!(burst_size(&limiters.email_update, "k"), 5);
    }

    #[test]
    fn test_extract_client_ip_prefers_forwarded_for() {
        let peer: SocketAddr = "10.0.0.1:443".parse().unwrap();
        let map = headers(&[
            ("x-forwarded-for", "203.0.113.7, 10.0.0.1"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        assert_eq!(extract_client_ip(&map, Some(peer)), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_fallbacks() {
        let peer: SocketAddr = "10.0.0.1:443".parse().unwrap();

        let real_ip = headers(&[("x-real-ip", " 198.51.100.2 ")]);
        assert_eq!(extract_client_ip(&real_ip, Some(peer)), "198.51.100.2");

        let empty = HeaderMap::new();
        assert_eq!(extract_client_ip(&empty, Some(peer)), "10.0.0.1");
        assert_eq!(extract_client_ip(&empty, None), "unknown");
    }
}
276}