this repo has no description
1use axum::{
2 Json,
3 body::Body,
4 extract::ConnectInfo,
5 http::{HeaderMap, Request, StatusCode},
6 middleware::Next,
7 response::{IntoResponse, Response},
8};
9use governor::{
10 Quota, RateLimiter,
11 clock::DefaultClock,
12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13};
14use std::{net::SocketAddr, num::NonZeroU32, sync::Arc};
15
16pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
17pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
18
19#[derive(Clone)]
20pub struct RateLimiters {
21 pub login: Arc<KeyedRateLimiter>,
22 pub oauth_token: Arc<KeyedRateLimiter>,
23 pub oauth_authorize: Arc<KeyedRateLimiter>,
24 pub password_reset: Arc<KeyedRateLimiter>,
25 pub account_creation: Arc<KeyedRateLimiter>,
26 pub refresh_session: Arc<KeyedRateLimiter>,
27 pub reset_password: Arc<KeyedRateLimiter>,
28 pub oauth_par: Arc<KeyedRateLimiter>,
29 pub oauth_introspect: Arc<KeyedRateLimiter>,
30 pub app_password: Arc<KeyedRateLimiter>,
31 pub email_update: Arc<KeyedRateLimiter>,
32 pub totp_verify: Arc<KeyedRateLimiter>,
33 pub handle_update: Arc<KeyedRateLimiter>,
34 pub handle_update_daily: Arc<KeyedRateLimiter>,
35}
36
37impl Default for RateLimiters {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl RateLimiters {
44 pub fn new() -> Self {
45 Self {
46 login: Arc::new(RateLimiter::keyed(Quota::per_minute(
47 NonZeroU32::new(10).unwrap(),
48 ))),
49 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute(
50 NonZeroU32::new(30).unwrap(),
51 ))),
52 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute(
53 NonZeroU32::new(10).unwrap(),
54 ))),
55 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour(
56 NonZeroU32::new(5).unwrap(),
57 ))),
58 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour(
59 NonZeroU32::new(10).unwrap(),
60 ))),
61 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute(
62 NonZeroU32::new(60).unwrap(),
63 ))),
64 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
65 NonZeroU32::new(10).unwrap(),
66 ))),
67 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute(
68 NonZeroU32::new(30).unwrap(),
69 ))),
70 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute(
71 NonZeroU32::new(30).unwrap(),
72 ))),
73 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
74 NonZeroU32::new(10).unwrap(),
75 ))),
76 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour(
77 NonZeroU32::new(5).unwrap(),
78 ))),
79 totp_verify: Arc::new(RateLimiter::keyed(
80 Quota::with_period(std::time::Duration::from_secs(60))
81 .unwrap()
82 .allow_burst(NonZeroU32::new(5).unwrap()),
83 )),
84 handle_update: Arc::new(RateLimiter::keyed(
85 Quota::with_period(std::time::Duration::from_secs(30))
86 .unwrap()
87 .allow_burst(NonZeroU32::new(10).unwrap()),
88 )),
89 handle_update_daily: Arc::new(RateLimiter::keyed(
90 Quota::with_period(std::time::Duration::from_secs(1728))
91 .unwrap()
92 .allow_burst(NonZeroU32::new(50).unwrap()),
93 )),
94 }
95 }
96
97 pub fn with_login_limit(mut self, per_minute: u32) -> Self {
98 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute(
99 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
100 )));
101 self
102 }
103
104 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
105 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute(
106 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()),
107 )));
108 self
109 }
110
111 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
112 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute(
113 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
114 )));
115 self
116 }
117
118 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
119 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour(
120 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
121 )));
122 self
123 }
124
125 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
126 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour(
127 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()),
128 )));
129 self
130 }
131
132 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
133 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour(
134 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
135 )));
136 self
137 }
138}
139
140pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
141 if let Some(forwarded) = headers.get("x-forwarded-for")
142 && let Ok(value) = forwarded.to_str()
143 && let Some(first_ip) = value.split(',').next()
144 {
145 return first_ip.trim().to_string();
146 }
147
148 if let Some(real_ip) = headers.get("x-real-ip")
149 && let Ok(value) = real_ip.to_str()
150 {
151 return value.trim().to_string();
152 }
153
154 addr.map(|a| a.ip().to_string())
155 .unwrap_or_else(|| "unknown".to_string())
156}
157
158fn rate_limit_response() -> Response {
159 (
160 StatusCode::TOO_MANY_REQUESTS,
161 Json(serde_json::json!({
162 "error": "RateLimitExceeded",
163 "message": "Too many requests. Please try again later."
164 })),
165 )
166 .into_response()
167}
168
169pub async fn login_rate_limit(
170 ConnectInfo(addr): ConnectInfo<SocketAddr>,
171 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
172 request: Request<Body>,
173 next: Next,
174) -> Response {
175 let client_ip = extract_client_ip(request.headers(), Some(addr));
176
177 if limiters.login.check_key(&client_ip).is_err() {
178 tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
179 return rate_limit_response();
180 }
181
182 next.run(request).await
183}
184
185pub async fn oauth_token_rate_limit(
186 ConnectInfo(addr): ConnectInfo<SocketAddr>,
187 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
188 request: Request<Body>,
189 next: Next,
190) -> Response {
191 let client_ip = extract_client_ip(request.headers(), Some(addr));
192
193 if limiters.oauth_token.check_key(&client_ip).is_err() {
194 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
195 return rate_limit_response();
196 }
197
198 next.run(request).await
199}
200
201pub async fn password_reset_rate_limit(
202 ConnectInfo(addr): ConnectInfo<SocketAddr>,
203 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
204 request: Request<Body>,
205 next: Next,
206) -> Response {
207 let client_ip = extract_client_ip(request.headers(), Some(addr));
208
209 if limiters.password_reset.check_key(&client_ip).is_err() {
210 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
211 return rate_limit_response();
212 }
213
214 next.run(request).await
215}
216
217pub async fn account_creation_rate_limit(
218 ConnectInfo(addr): ConnectInfo<SocketAddr>,
219 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
220 request: Request<Body>,
221 next: Next,
222) -> Response {
223 let client_ip = extract_client_ip(request.headers(), Some(addr));
224
225 if limiters.account_creation.check_key(&client_ip).is_err() {
226 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
227 return rate_limit_response();
228 }
229
230 next.run(request).await
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Asserts that `limiter` admits exactly `burst` back-to-back requests for
    /// a fresh key and rejects the next one.
    fn assert_burst(limiter: &KeyedRateLimiter, burst: u32) {
        let key = "burst-probe".to_string();
        for i in 0..burst {
            assert!(limiter.check_key(&key).is_ok(), "request {i} should be admitted");
        }
        assert!(limiter.check_key(&key).is_err(), "request {burst} should be rejected");
    }

    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let key = "test_ip".to_string();

        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());
    }

    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));

        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
        assert!(limiter.check_key(&"ip1".to_string()).is_err());
        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        let limiters = RateLimiters::new()
            .with_login_limit(20)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(5);

        // The configured values must actually take effect, not just "some" limit.
        assert_burst(&limiters.login, 20);
        assert_burst(&limiters.oauth_token, 60);
        assert_burst(&limiters.password_reset, 3);
        assert_burst(&limiters.account_creation, 5);
    }

    #[test]
    fn test_zero_limit_falls_back_to_default() {
        let limiters = RateLimiters::new()
            .with_login_limit(0)
            .with_email_update_limit(0);

        assert_burst(&limiters.login, 10);
        assert_burst(&limiters.email_update, 5);
    }

    #[test]
    fn test_extract_client_ip_prefers_first_forwarded_entry() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", HeaderValue::from_static("203.0.113.7, 10.0.0.1"));
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));
        let peer: SocketAddr = "192.0.2.1:443".parse().unwrap();

        assert_eq!(extract_client_ip(&headers, Some(peer)), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_uses_real_ip_header() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));

        assert_eq!(extract_client_ip(&headers, None), "198.51.100.2");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_peer_then_unknown() {
        let peer: SocketAddr = "192.0.2.1:443".parse().unwrap();

        assert_eq!(extract_client_ip(&HeaderMap::new(), Some(peer)), "192.0.2.1");
        assert_eq!(extract_client_ip(&HeaderMap::new(), None), "unknown");
    }
}