this repo has no description
1use axum::{
2 body::Body,
3 extract::ConnectInfo,
4 http::{HeaderMap, Request, StatusCode},
5 middleware::Next,
6 response::{IntoResponse, Response},
7 Json,
8};
9use governor::{
10 Quota, RateLimiter,
11 clock::DefaultClock,
12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13};
14use std::{
15 net::SocketAddr,
16 num::NonZeroU32,
17 sync::Arc,
18};
19
20pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
21pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
22
23#[derive(Clone)]
24pub struct RateLimiters {
25 pub login: Arc<KeyedRateLimiter>,
26 pub oauth_token: Arc<KeyedRateLimiter>,
27 pub password_reset: Arc<KeyedRateLimiter>,
28 pub account_creation: Arc<KeyedRateLimiter>,
29}
30
31impl Default for RateLimiters {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl RateLimiters {
38 pub fn new() -> Self {
39 Self {
40 login: Arc::new(RateLimiter::keyed(
41 Quota::per_minute(NonZeroU32::new(10).unwrap())
42 )),
43 oauth_token: Arc::new(RateLimiter::keyed(
44 Quota::per_minute(NonZeroU32::new(30).unwrap())
45 )),
46 password_reset: Arc::new(RateLimiter::keyed(
47 Quota::per_hour(NonZeroU32::new(5).unwrap())
48 )),
49 account_creation: Arc::new(RateLimiter::keyed(
50 Quota::per_hour(NonZeroU32::new(10).unwrap())
51 )),
52 }
53 }
54
55 pub fn with_login_limit(mut self, per_minute: u32) -> Self {
56 self.login = Arc::new(RateLimiter::keyed(
57 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
58 ));
59 self
60 }
61
62 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
63 self.oauth_token = Arc::new(RateLimiter::keyed(
64 Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
65 ));
66 self
67 }
68
69 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
70 self.password_reset = Arc::new(RateLimiter::keyed(
71 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
72 ));
73 self
74 }
75
76 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
77 self.account_creation = Arc::new(RateLimiter::keyed(
78 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
79 ));
80 self
81 }
82}
83
84fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
85 if let Some(forwarded) = headers.get("x-forwarded-for") {
86 if let Ok(value) = forwarded.to_str() {
87 if let Some(first_ip) = value.split(',').next() {
88 return first_ip.trim().to_string();
89 }
90 }
91 }
92
93 if let Some(real_ip) = headers.get("x-real-ip") {
94 if let Ok(value) = real_ip.to_str() {
95 return value.trim().to_string();
96 }
97 }
98
99 addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
100}
101
102fn rate_limit_response() -> Response {
103 (
104 StatusCode::TOO_MANY_REQUESTS,
105 Json(serde_json::json!({
106 "error": "RateLimitExceeded",
107 "message": "Too many requests. Please try again later."
108 })),
109 )
110 .into_response()
111}
112
113pub async fn login_rate_limit(
114 ConnectInfo(addr): ConnectInfo<SocketAddr>,
115 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
116 request: Request<Body>,
117 next: Next,
118) -> Response {
119 let client_ip = extract_client_ip(request.headers(), Some(addr));
120
121 if limiters.login.check_key(&client_ip).is_err() {
122 tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
123 return rate_limit_response();
124 }
125
126 next.run(request).await
127}
128
129pub async fn oauth_token_rate_limit(
130 ConnectInfo(addr): ConnectInfo<SocketAddr>,
131 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
132 request: Request<Body>,
133 next: Next,
134) -> Response {
135 let client_ip = extract_client_ip(request.headers(), Some(addr));
136
137 if limiters.oauth_token.check_key(&client_ip).is_err() {
138 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
139 return rate_limit_response();
140 }
141
142 next.run(request).await
143}
144
145pub async fn password_reset_rate_limit(
146 ConnectInfo(addr): ConnectInfo<SocketAddr>,
147 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
148 request: Request<Body>,
149 next: Next,
150) -> Response {
151 let client_ip = extract_client_ip(request.headers(), Some(addr));
152
153 if limiters.password_reset.check_key(&client_ip).is_err() {
154 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
155 return rate_limit_response();
156 }
157
158 next.run(request).await
159}
160
161pub async fn account_creation_rate_limit(
162 ConnectInfo(addr): ConnectInfo<SocketAddr>,
163 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
164 request: Request<Body>,
165 next: Next,
166) -> Response {
167 let client_ip = extract_client_ip(request.headers(), Some(addr));
168
169 if limiters.account_creation.check_key(&client_ip).is_err() {
170 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
171 return rate_limit_response();
172 }
173
174 next.run(request).await
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a standalone keyed limiter allowing `limit` requests per minute.
    fn per_minute_limiter(limit: u32) -> KeyedRateLimiter {
        RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(limit).unwrap()))
    }

    /// Shorthand for turning a literal into the owned key type.
    fn key(name: &str) -> String {
        name.to_string()
    }

    /// A freshly built default set accepts the first login request.
    #[test]
    fn test_rate_limiters_creation() {
        let limiters = RateLimiters::new();
        let first = limiters.login.check_key(&key("test"));
        assert!(first.is_ok());
    }

    /// With a quota of two, the third immediate request for a key is rejected.
    #[test]
    fn test_rate_limiter_exhaustion() {
        let limiter = per_minute_limiter(2);
        let client = key("test_ip");

        assert!(limiter.check_key(&client).is_ok());
        assert!(limiter.check_key(&client).is_ok());
        assert!(limiter.check_key(&client).is_err());
    }

    /// Exhausting one key leaves the quota of another key untouched.
    #[test]
    fn test_different_keys_have_separate_limits() {
        let limiter = per_minute_limiter(1);

        assert!(limiter.check_key(&key("ip1")).is_ok());
        assert!(limiter.check_key(&key("ip1")).is_err());
        assert!(limiter.check_key(&key("ip2")).is_ok());
    }

    /// Chained builder calls yield a usable limiter set.
    #[test]
    fn test_builder_pattern() {
        let limiters = RateLimiters::new()
            .with_login_limit(20)
            .with_oauth_token_limit(60)
            .with_password_reset_limit(3)
            .with_account_creation_limit(5);

        let first = limiters.login.check_key(&key("test"));
        assert!(first.is_ok());
    }
}
216}