this repo has no description
1use axum::{
2 Json,
3 body::Body,
4 extract::ConnectInfo,
5 http::{HeaderMap, Request, StatusCode},
6 middleware::Next,
7 response::{IntoResponse, Response},
8};
9use governor::{
10 Quota, RateLimiter,
11 clock::DefaultClock,
12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13};
14use std::{net::SocketAddr, num::NonZeroU32, sync::Arc};
15
16pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
17pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
18
19#[derive(Clone)]
20pub struct RateLimiters {
21 pub login: Arc<KeyedRateLimiter>,
22 pub oauth_token: Arc<KeyedRateLimiter>,
23 pub oauth_authorize: Arc<KeyedRateLimiter>,
24 pub password_reset: Arc<KeyedRateLimiter>,
25 pub account_creation: Arc<KeyedRateLimiter>,
26 pub refresh_session: Arc<KeyedRateLimiter>,
27 pub reset_password: Arc<KeyedRateLimiter>,
28 pub oauth_par: Arc<KeyedRateLimiter>,
29 pub oauth_introspect: Arc<KeyedRateLimiter>,
30 pub app_password: Arc<KeyedRateLimiter>,
31 pub email_update: Arc<KeyedRateLimiter>,
32}
33
34impl Default for RateLimiters {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl RateLimiters {
41 pub fn new() -> Self {
42 Self {
43 login: Arc::new(RateLimiter::keyed(Quota::per_minute(
44 NonZeroU32::new(10).unwrap(),
45 ))),
46 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute(
47 NonZeroU32::new(30).unwrap(),
48 ))),
49 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute(
50 NonZeroU32::new(10).unwrap(),
51 ))),
52 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour(
53 NonZeroU32::new(5).unwrap(),
54 ))),
55 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour(
56 NonZeroU32::new(10).unwrap(),
57 ))),
58 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute(
59 NonZeroU32::new(60).unwrap(),
60 ))),
61 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
62 NonZeroU32::new(10).unwrap(),
63 ))),
64 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute(
65 NonZeroU32::new(30).unwrap(),
66 ))),
67 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute(
68 NonZeroU32::new(30).unwrap(),
69 ))),
70 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
71 NonZeroU32::new(10).unwrap(),
72 ))),
73 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour(
74 NonZeroU32::new(5).unwrap(),
75 ))),
76 }
77 }
78
79 pub fn with_login_limit(mut self, per_minute: u32) -> Self {
80 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute(
81 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
82 )));
83 self
84 }
85
86 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
87 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute(
88 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()),
89 )));
90 self
91 }
92
93 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
94 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute(
95 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
96 )));
97 self
98 }
99
100 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
101 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour(
102 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
103 )));
104 self
105 }
106
107 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
108 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour(
109 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()),
110 )));
111 self
112 }
113
114 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
115 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour(
116 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
117 )));
118 self
119 }
120}
121
122pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
123 if let Some(forwarded) = headers.get("x-forwarded-for")
124 && let Ok(value) = forwarded.to_str()
125 && let Some(first_ip) = value.split(',').next()
126 {
127 return first_ip.trim().to_string();
128 }
129
130 if let Some(real_ip) = headers.get("x-real-ip")
131 && let Ok(value) = real_ip.to_str()
132 {
133 return value.trim().to_string();
134 }
135
136 addr.map(|a| a.ip().to_string())
137 .unwrap_or_else(|| "unknown".to_string())
138}
139
140fn rate_limit_response() -> Response {
141 (
142 StatusCode::TOO_MANY_REQUESTS,
143 Json(serde_json::json!({
144 "error": "RateLimitExceeded",
145 "message": "Too many requests. Please try again later."
146 })),
147 )
148 .into_response()
149}
150
151pub async fn login_rate_limit(
152 ConnectInfo(addr): ConnectInfo<SocketAddr>,
153 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
154 request: Request<Body>,
155 next: Next,
156) -> Response {
157 let client_ip = extract_client_ip(request.headers(), Some(addr));
158
159 if limiters.login.check_key(&client_ip).is_err() {
160 tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
161 return rate_limit_response();
162 }
163
164 next.run(request).await
165}
166
167pub async fn oauth_token_rate_limit(
168 ConnectInfo(addr): ConnectInfo<SocketAddr>,
169 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
170 request: Request<Body>,
171 next: Next,
172) -> Response {
173 let client_ip = extract_client_ip(request.headers(), Some(addr));
174
175 if limiters.oauth_token.check_key(&client_ip).is_err() {
176 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
177 return rate_limit_response();
178 }
179
180 next.run(request).await
181}
182
183pub async fn password_reset_rate_limit(
184 ConnectInfo(addr): ConnectInfo<SocketAddr>,
185 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
186 request: Request<Body>,
187 next: Next,
188) -> Response {
189 let client_ip = extract_client_ip(request.headers(), Some(addr));
190
191 if limiters.password_reset.check_key(&client_ip).is_err() {
192 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
193 return rate_limit_response();
194 }
195
196 next.run(request).await
197}
198
199pub async fn account_creation_rate_limit(
200 ConnectInfo(addr): ConnectInfo<SocketAddr>,
201 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
202 request: Request<Body>,
203 next: Next,
204) -> Response {
205 let client_ip = extract_client_ip(request.headers(), Some(addr));
206
207 if limiters.account_creation.check_key(&client_ip).is_err() {
208 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
209 return rate_limit_response();
210 }
211
212 next.run(request).await
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// A fixed peer address used where a connection address is needed.
    fn peer() -> SocketAddr {
        "192.0.2.10:4433".parse().expect("valid socket address literal")
    }

    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let key = "test_ip".to_string();

        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());
    }

    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));

        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
        assert!(limiter.check_key(&"ip1".to_string()).is_err());
        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        // Use small quotas so the tests can observe that the overrides actually apply.
        let limiters = RateLimiters::new()
            .with_login_limit(2)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(1);
        let key = "test".to_string();

        assert!(limiters.login.check_key(&key).is_ok());
        assert!(limiters.login.check_key(&key).is_ok());
        assert!(limiters.login.check_key(&key).is_err());

        assert!(limiters.account_creation.check_key(&key).is_ok());
        assert!(limiters.account_creation.check_key(&key).is_err());

        for _ in 0..3 {
            assert!(limiters.password_reset.check_key(&key).is_ok());
        }
        assert!(limiters.password_reset.check_key(&key).is_err());
    }

    #[test]
    fn test_zero_limit_falls_back_to_default() {
        // The default login quota is 10 per minute.
        let limiters = RateLimiters::new().with_login_limit(0);
        let key = "test".to_string();

        for _ in 0..10 {
            assert!(limiters.login.check_key(&key).is_ok());
        }
        assert!(limiters.login.check_key(&key).is_err());
    }

    #[test]
    fn test_extract_client_ip_prefers_first_forwarded_for_entry() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.7, 10.0.0.1"),
        );
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));

        assert_eq!(extract_client_ip(&headers, Some(peer())), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_uses_real_ip_without_forwarded_for() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", HeaderValue::from_static(" 198.51.100.2 "));

        assert_eq!(extract_client_ip(&headers, Some(peer())), "198.51.100.2");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_peer_then_unknown() {
        let headers = HeaderMap::new();

        assert_eq!(extract_client_ip(&headers, Some(peer())), "192.0.2.10");
        assert_eq!(extract_client_ip(&headers, None), "unknown");
    }
}
254}