this repo has no description
1use axum::{
2 body::Body,
3 extract::ConnectInfo,
4 http::{HeaderMap, Request, StatusCode},
5 middleware::Next,
6 response::{IntoResponse, Response},
7 Json,
8};
9use governor::{
10 Quota, RateLimiter,
11 clock::DefaultClock,
12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13};
14use std::{
15 net::SocketAddr,
16 num::NonZeroU32,
17 sync::Arc,
18};
19
20pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
21pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
22
23#[derive(Clone)]
24pub struct RateLimiters {
25 pub login: Arc<KeyedRateLimiter>,
26 pub oauth_token: Arc<KeyedRateLimiter>,
27 pub oauth_authorize: Arc<KeyedRateLimiter>,
28 pub password_reset: Arc<KeyedRateLimiter>,
29 pub account_creation: Arc<KeyedRateLimiter>,
30 pub refresh_session: Arc<KeyedRateLimiter>,
31 pub reset_password: Arc<KeyedRateLimiter>,
32 pub oauth_par: Arc<KeyedRateLimiter>,
33 pub oauth_introspect: Arc<KeyedRateLimiter>,
34 pub app_password: Arc<KeyedRateLimiter>,
35 pub email_update: Arc<KeyedRateLimiter>,
36}
37
38impl Default for RateLimiters {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl RateLimiters {
45 pub fn new() -> Self {
46 Self {
47 login: Arc::new(RateLimiter::keyed(
48 Quota::per_minute(NonZeroU32::new(10).unwrap())
49 )),
50 oauth_token: Arc::new(RateLimiter::keyed(
51 Quota::per_minute(NonZeroU32::new(30).unwrap())
52 )),
53 oauth_authorize: Arc::new(RateLimiter::keyed(
54 Quota::per_minute(NonZeroU32::new(10).unwrap())
55 )),
56 password_reset: Arc::new(RateLimiter::keyed(
57 Quota::per_hour(NonZeroU32::new(5).unwrap())
58 )),
59 account_creation: Arc::new(RateLimiter::keyed(
60 Quota::per_hour(NonZeroU32::new(10).unwrap())
61 )),
62 refresh_session: Arc::new(RateLimiter::keyed(
63 Quota::per_minute(NonZeroU32::new(60).unwrap())
64 )),
65 reset_password: Arc::new(RateLimiter::keyed(
66 Quota::per_minute(NonZeroU32::new(10).unwrap())
67 )),
68 oauth_par: Arc::new(RateLimiter::keyed(
69 Quota::per_minute(NonZeroU32::new(30).unwrap())
70 )),
71 oauth_introspect: Arc::new(RateLimiter::keyed(
72 Quota::per_minute(NonZeroU32::new(30).unwrap())
73 )),
74 app_password: Arc::new(RateLimiter::keyed(
75 Quota::per_minute(NonZeroU32::new(10).unwrap())
76 )),
77 email_update: Arc::new(RateLimiter::keyed(
78 Quota::per_hour(NonZeroU32::new(5).unwrap())
79 )),
80 }
81 }
82
83 pub fn with_login_limit(mut self, per_minute: u32) -> Self {
84 self.login = Arc::new(RateLimiter::keyed(
85 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
86 ));
87 self
88 }
89
90 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
91 self.oauth_token = Arc::new(RateLimiter::keyed(
92 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
93 ));
94 self
95 }
96
97 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
98 self.oauth_authorize = Arc::new(RateLimiter::keyed(
99 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
100 ));
101 self
102 }
103
104 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
105 self.password_reset = Arc::new(RateLimiter::keyed(
106 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
107 ));
108 self
109 }
110
111 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
112 self.account_creation = Arc::new(RateLimiter::keyed(
113 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
114 ));
115 self
116 }
117
118 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
119 self.email_update = Arc::new(RateLimiter::keyed(
120 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
121 ));
122 self
123 }
124}
125
126pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
127 if let Some(forwarded) = headers.get("x-forwarded-for") {
128 if let Ok(value) = forwarded.to_str() {
129 if let Some(first_ip) = value.split(',').next() {
130 return first_ip.trim().to_string();
131 }
132 }
133 }
134
135 if let Some(real_ip) = headers.get("x-real-ip") {
136 if let Ok(value) = real_ip.to_str() {
137 return value.trim().to_string();
138 }
139 }
140
141 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
142}
143
144fn rate_limit_response() -> Response {
145 (
146 StatusCode::TOO_MANY_REQUESTS,
147 Json(serde_json::json!({
148 "error": "RateLimitExceeded",
149 "message": "Too many requests. Please try again later."
150 })),
151 )
152 .into_response()
153}
154
155pub async fn login_rate_limit(
156 ConnectInfo(addr): ConnectInfo<SocketAddr>,
157 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
158 request: Request<Body>,
159 next: Next,
160) -> Response {
161 let client_ip = extract_client_ip(request.headers(), Some(addr));
162
163 if limiters.login.check_key(&client_ip).is_err() {
164 tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
165 return rate_limit_response();
166 }
167
168 next.run(request).await
169}
170
171pub async fn oauth_token_rate_limit(
172 ConnectInfo(addr): ConnectInfo<SocketAddr>,
173 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
174 request: Request<Body>,
175 next: Next,
176) -> Response {
177 let client_ip = extract_client_ip(request.headers(), Some(addr));
178
179 if limiters.oauth_token.check_key(&client_ip).is_err() {
180 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
181 return rate_limit_response();
182 }
183
184 next.run(request).await
185}
186
187pub async fn password_reset_rate_limit(
188 ConnectInfo(addr): ConnectInfo<SocketAddr>,
189 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
190 request: Request<Body>,
191 next: Next,
192) -> Response {
193 let client_ip = extract_client_ip(request.headers(), Some(addr));
194
195 if limiters.password_reset.check_key(&client_ip).is_err() {
196 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
197 return rate_limit_response();
198 }
199
200 next.run(request).await
201}
202
203pub async fn account_creation_rate_limit(
204 ConnectInfo(addr): ConnectInfo<SocketAddr>,
205 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
206 request: Request<Body>,
207 next: Next,
208) -> Response {
209 let client_ip = extract_client_ip(request.headers(), Some(addr));
210
211 if limiters.account_creation.check_key(&client_ip).is_err() {
212 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
213 return rate_limit_response();
214 }
215
216 next.run(request).await
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Peer address used when exercising `extract_client_ip`.
    fn peer() -> SocketAddr {
        SocketAddr::from(([192, 0, 2, 1], 4000))
    }

    /// Asserts that `limiter` admits exactly `burst` back-to-back requests for `key`
    /// and rejects the next one.
    fn assert_burst(limiter: &KeyedRateLimiter, key: &str, burst: u32) {
        let key = key.to_string();
        for attempt in 0..burst {
            assert!(
                limiter.check_key(&key).is_ok(),
                "attempt {} of {} should be allowed",
                attempt + 1,
                burst
            );
        }
        assert!(
            limiter.check_key(&key).is_err(),
            "attempt {} should be rejected",
            burst + 1
        );
    }

    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        assert!(limiters.login.check_key(&"test".to_string()).is_ok());
    }

    #[test]
    fn test_default_login_quota_is_ten_per_minute() {
        let limiters = RateLimiters::default();
        assert_burst(&limiters.login, "default", 10);
    }

    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
        let key = "test_ip".to_string();

        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());
    }

    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));

        assert!(limiter.check_key(&"ip1".to_string()).is_ok());
        assert!(limiter.check_key(&"ip1".to_string()).is_err());
        assert!(limiter.check_key(&"ip2".to_string()).is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        let limiters = RateLimiters::new()
            .with_login_limit(2)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(5);

        // Each configured quota must actually take effect, not just construct.
        assert_burst(&limiters.login, "test", 2);
        assert_burst(&limiters.oauth_token, "test", 60);
        assert_burst(&limiters.password_reset, "test", 3);
        assert_burst(&limiters.account_creation, "test", 5);
    }

    #[test]
    fn test_zero_limit_falls_back_to_default() {
        let limiters = RateLimiters::new().with_login_limit(0);
        assert_burst(&limiters.login, "zero", 10);
    }

    #[test]
    fn test_extract_client_ip_prefers_first_forwarded_for_entry() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.7, 10.0.0.1"),
        );
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));

        assert_eq!(extract_client_ip(&headers, Some(peer())), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_real_ip() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));

        assert_eq!(extract_client_ip(&headers, Some(peer())), "198.51.100.2");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_peer_address() {
        assert_eq!(extract_client_ip(&HeaderMap::new(), Some(peer())), "192.0.2.1");
    }

    #[test]
    fn test_extract_client_ip_unknown_without_any_source() {
        assert_eq!(extract_client_ip(&HeaderMap::new(), None), "unknown");
    }
}