this repo has no description
1use axum::{
2 Json,
3 body::Body,
4 extract::ConnectInfo,
5 http::{HeaderMap, Request, StatusCode},
6 middleware::Next,
7 response::{IntoResponse, Response},
8};
9use governor::{
10 Quota, RateLimiter,
11 clock::DefaultClock,
12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13};
14use std::{net::SocketAddr, num::NonZeroU32, sync::Arc};
15
/// Rate limiter that keeps separate quota state per `String` key.
///
/// The middleware in this module key it by the client address returned from
/// [`extract_client_ip`].
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
/// Rate limiter without keys (`NotKeyed`): a single quota state shared by every caller.
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
18
/// Set of independent keyed rate limiters, one per protected operation.
///
/// Every field is an `Arc`, so cloning a `RateLimiters` shares limiter state
/// rather than resetting it. Default quotas are installed by [`RateLimiters::new`];
/// some can be overridden with the `with_*` builder methods.
#[derive(Clone)]
pub struct RateLimiters {
    /// Login attempts; default 10 per minute per key. Used by `login_rate_limit`.
    pub login: Arc<KeyedRateLimiter>,
    /// OAuth token requests; default 30 per minute per key. Used by `oauth_token_rate_limit`.
    pub oauth_token: Arc<KeyedRateLimiter>,
    /// OAuth authorize requests; default 10 per minute per key.
    pub oauth_authorize: Arc<KeyedRateLimiter>,
    /// Password-reset requests; default 5 per hour per key. Used by `password_reset_rate_limit`.
    pub password_reset: Arc<KeyedRateLimiter>,
    /// Account creation; default 10 per hour per key. Used by `account_creation_rate_limit`.
    pub account_creation: Arc<KeyedRateLimiter>,
    /// Session refresh; default 60 per minute per key.
    pub refresh_session: Arc<KeyedRateLimiter>,
    /// Password reset (separate from `password_reset`; NOTE(review): confirm which
    /// endpoint each one guards); default 10 per minute per key.
    pub reset_password: Arc<KeyedRateLimiter>,
    /// OAuth pushed authorization requests (presumably PAR); default 30 per minute per key.
    pub oauth_par: Arc<KeyedRateLimiter>,
    /// OAuth token introspection; default 30 per minute per key.
    pub oauth_introspect: Arc<KeyedRateLimiter>,
    /// App-password operations; default 10 per minute per key.
    pub app_password: Arc<KeyedRateLimiter>,
    /// Email updates; default 5 per hour per key.
    pub email_update: Arc<KeyedRateLimiter>,
    /// TOTP verification; default burst of 5, replenishing one attempt every 60 seconds.
    pub totp_verify: Arc<KeyedRateLimiter>,
}
34
35impl Default for RateLimiters {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl RateLimiters {
42 pub fn new() -> Self {
43 Self {
44 login: Arc::new(RateLimiter::keyed(Quota::per_minute(
45 NonZeroU32::new(10).unwrap(),
46 ))),
47 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute(
48 NonZeroU32::new(30).unwrap(),
49 ))),
50 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute(
51 NonZeroU32::new(10).unwrap(),
52 ))),
53 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour(
54 NonZeroU32::new(5).unwrap(),
55 ))),
56 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour(
57 NonZeroU32::new(10).unwrap(),
58 ))),
59 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute(
60 NonZeroU32::new(60).unwrap(),
61 ))),
62 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
63 NonZeroU32::new(10).unwrap(),
64 ))),
65 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute(
66 NonZeroU32::new(30).unwrap(),
67 ))),
68 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute(
69 NonZeroU32::new(30).unwrap(),
70 ))),
71 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
72 NonZeroU32::new(10).unwrap(),
73 ))),
74 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour(
75 NonZeroU32::new(5).unwrap(),
76 ))),
77 totp_verify: Arc::new(RateLimiter::keyed(Quota::with_period(std::time::Duration::from_secs(60))
78 .unwrap()
79 .allow_burst(NonZeroU32::new(5).unwrap()),
80 )),
81 }
82 }
83
84 pub fn with_login_limit(mut self, per_minute: u32) -> Self {
85 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute(
86 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
87 )));
88 self
89 }
90
91 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
92 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute(
93 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()),
94 )));
95 self
96 }
97
98 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
99 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute(
100 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
101 )));
102 self
103 }
104
105 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
106 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour(
107 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
108 )));
109 self
110 }
111
112 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
113 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour(
114 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()),
115 )));
116 self
117 }
118
119 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
120 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour(
121 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
122 )));
123 self
124 }
125}
126
127pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
128 if let Some(forwarded) = headers.get("x-forwarded-for")
129 && let Ok(value) = forwarded.to_str()
130 && let Some(first_ip) = value.split(',').next()
131 {
132 return first_ip.trim().to_string();
133 }
134
135 if let Some(real_ip) = headers.get("x-real-ip")
136 && let Ok(value) = real_ip.to_str()
137 {
138 return value.trim().to_string();
139 }
140
141 addr.map(|a| a.ip().to_string())
142 .unwrap_or_else(|| "unknown".to_string())
143}
144
145fn rate_limit_response() -> Response {
146 (
147 StatusCode::TOO_MANY_REQUESTS,
148 Json(serde_json::json!({
149 "error": "RateLimitExceeded",
150 "message": "Too many requests. Please try again later."
151 })),
152 )
153 .into_response()
154}
155
156pub async fn login_rate_limit(
157 ConnectInfo(addr): ConnectInfo<SocketAddr>,
158 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
159 request: Request<Body>,
160 next: Next,
161) -> Response {
162 let client_ip = extract_client_ip(request.headers(), Some(addr));
163
164 if limiters.login.check_key(&client_ip).is_err() {
165 tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
166 return rate_limit_response();
167 }
168
169 next.run(request).await
170}
171
172pub async fn oauth_token_rate_limit(
173 ConnectInfo(addr): ConnectInfo<SocketAddr>,
174 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
175 request: Request<Body>,
176 next: Next,
177) -> Response {
178 let client_ip = extract_client_ip(request.headers(), Some(addr));
179
180 if limiters.oauth_token.check_key(&client_ip).is_err() {
181 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
182 return rate_limit_response();
183 }
184
185 next.run(request).await
186}
187
188pub async fn password_reset_rate_limit(
189 ConnectInfo(addr): ConnectInfo<SocketAddr>,
190 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
191 request: Request<Body>,
192 next: Next,
193) -> Response {
194 let client_ip = extract_client_ip(request.headers(), Some(addr));
195
196 if limiters.password_reset.check_key(&client_ip).is_err() {
197 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
198 return rate_limit_response();
199 }
200
201 next.run(request).await
202}
203
204pub async fn account_creation_rate_limit(
205 ConnectInfo(addr): ConnectInfo<SocketAddr>,
206 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
207 request: Request<Body>,
208 next: Next,
209) -> Response {
210 let client_ip = extract_client_ip(request.headers(), Some(addr));
211
212 if limiters.account_creation.check_key(&client_ip).is_err() {
213 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
214 return rate_limit_response();
215 }
216
217 next.run(request).await
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map containing a single `name: value` header.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, value.parse().expect("valid header value"));
        headers
    }

    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let key = "test_ip".to_string();

        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());
    }

    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));

        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
        assert!(limiter.check_key(&"ip1".to_string()).is_err());
        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        let limiters = RateLimiters::new()
            .with_login_limit(20)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(5);
        let key = "test".to_string();

        // The configured login burst (20) must replace the default (10).
        for _ in 0..20 {
            assert!(limiters.login.check_key(&key).is_ok());
        }
        assert!(limiters.login.check_key(&key).is_err());

        for _ in 0..3 {
            assert!(limiters.password_reset.check_key(&key).is_ok());
        }
        assert!(limiters.password_reset.check_key(&key).is_err());
    }

    #[test]
    fn test_zero_limit_falls_back_to_default() {
        let limiters = RateLimiters::new().with_login_limit(0);
        let key = "test".to_string();

        for _ in 0..10 {
            assert!(limiters.login.check_key(&key).is_ok());
        }
        assert!(limiters.login.check_key(&key).is_err());
    }

    #[test]
    fn test_totp_verify_burst() {
        let limiters = RateLimiters::new();
        let key = "test".to_string();

        for _ in 0..5 {
            assert!(limiters.totp_verify.check_key(&key).is_ok());
        }
        assert!(limiters.totp_verify.check_key(&key).is_err());
    }

    #[test]
    fn test_extract_client_ip_prefers_first_forwarded_entry() {
        let mut headers = headers_with("x-forwarded-for", "203.0.113.7, 10.0.0.1");
        headers.insert("x-real-ip", "198.51.100.2".parse().unwrap());
        let peer: SocketAddr = "192.0.2.1:443".parse().unwrap();

        assert_eq!(extract_client_ip(&headers, Some(peer)), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_uses_real_ip_header() {
        let headers = headers_with("x-real-ip", " 198.51.100.2 ");
        assert_eq!(extract_client_ip(&headers, None), "198.51.100.2");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_peer_and_unknown() {
        let peer: SocketAddr = "192.0.2.1:443".parse().unwrap();
        assert_eq!(extract_client_ip(&HeaderMap::new(), Some(peer)), "192.0.2.1");
        assert_eq!(extract_client_ip(&HeaderMap::new(), None), "unknown");
    }
}
259}