The smokesignal.events web application
1use async_trait::async_trait;
2use deadpool_redis::{Pool as RedisPool, redis::AsyncCommands};
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use crate::throttle::{Throttle, ThrottleError};
6
7/// Redis-backed rolling window throttle
8///
9/// Uses a rolling window strategy where time is divided into fixed-size windows
10/// and the throttle checks multiple consecutive windows to determine if an action
11/// should be allowed.
12///
13/// For example, with a 5-minute window size, 3 windows, and a limit of 15 actions:
14/// - Time is divided into 5-minute buckets
15/// - The throttle checks the current window and the previous 2 windows (15 minutes total)
16/// - If the total count across all 3 windows is >= 15, the action is throttled
17///
18/// Redis keys have the format: `{namespace}:{key}:{window_timestamp}`
19/// Each key stores a counter and expires after (window_size * num_windows) seconds
20pub struct RedisRollingWindowThrottle {
21 /// Redis connection pool
22 redis_pool: RedisPool,
23
24 /// Namespace prefix for Redis keys (e.g., "throttle:email")
25 namespace: String,
26
27 /// Size of each time window in seconds (e.g., 300 for 5 minutes)
28 window_size_seconds: u64,
29
30 /// Number of windows to check (e.g., 3 for checking current + 2 previous)
31 num_windows: usize,
32
33 /// Maximum number of actions allowed across all windows
34 max_actions: u64,
35}
36
37impl RedisRollingWindowThrottle {
38 /// Create a new Redis rolling window throttle
39 ///
40 /// # Arguments
41 /// * `redis_pool` - Redis connection pool
42 /// * `namespace` - Namespace for Redis keys (e.g., "throttle:email")
43 /// * `window_size_seconds` - Size of each window in seconds (e.g., 300 for 5 minutes)
44 /// * `num_windows` - Number of windows to check (e.g., 3)
45 /// * `max_actions` - Maximum actions allowed across all windows (e.g., 15)
46 ///
47 /// # Example
48 /// ```ignore
49 /// // Throttle to 15 actions per 15 minutes using 5-minute windows
50 /// let throttle = RedisRollingWindowThrottle::new(
51 /// redis_pool,
52 /// "throttle:email".to_string(),
53 /// 300, // 5 minutes
54 /// 3, // check 3 windows (15 minutes total)
55 /// 15, // max 15 actions
56 /// );
57 /// ```
58 pub fn new(
59 redis_pool: RedisPool,
60 namespace: String,
61 window_size_seconds: u64,
62 num_windows: usize,
63 max_actions: u64,
64 ) -> Self {
65 Self {
66 redis_pool,
67 namespace,
68 window_size_seconds,
69 num_windows,
70 max_actions,
71 }
72 }
73
74 /// Get the current window timestamp
75 fn current_window(&self) -> Result<u64, ThrottleError> {
76 let now = SystemTime::now()
77 .duration_since(UNIX_EPOCH)
78 .map_err(|e| ThrottleError::TimeError(e.to_string()))?
79 .as_secs();
80
81 Ok(now / self.window_size_seconds)
82 }
83
84 /// Generate Redis key for a window
85 fn window_key(&self, key: &str, window: u64) -> String {
86 format!("{}:{}:{}", self.namespace, key, window)
87 }
88
89 /// Get the sum of actions across all windows
90 async fn get_window_sum(&self, key: &str) -> Result<u64, ThrottleError> {
91 let current_window = self.current_window()?;
92 let mut conn = self
93 .redis_pool
94 .get()
95 .await
96 .map_err(|e| ThrottleError::RedisError(e.to_string()))?;
97
98 let mut total: u64 = 0;
99
100 // Check current window and previous (num_windows - 1) windows
101 for i in 0..self.num_windows {
102 let window = current_window.saturating_sub(i as u64);
103 let redis_key = self.window_key(key, window);
104
105 let count: Option<u64> = conn
106 .get(&redis_key)
107 .await
108 .map_err(|e| ThrottleError::RedisError(e.to_string()))?;
109
110 if let Some(count) = count {
111 total += count;
112 }
113 }
114
115 Ok(total)
116 }
117
118 /// Record an action in the current window
119 async fn record(&self, key: &str) -> Result<(), ThrottleError> {
120 let current_window = self.current_window()?;
121 let redis_key = self.window_key(key, current_window);
122 let ttl_seconds = self.window_size_seconds * self.num_windows as u64;
123
124 let mut conn = self
125 .redis_pool
126 .get()
127 .await
128 .map_err(|e| ThrottleError::RedisError(e.to_string()))?;
129
130 // Increment the counter for this window
131 let _: u64 = conn
132 .incr(&redis_key, 1)
133 .await
134 .map_err(|e| ThrottleError::RedisError(e.to_string()))?;
135
136 // Set expiration (idempotent - resets TTL each time)
137 let _: bool = conn
138 .expire(&redis_key, ttl_seconds as i64)
139 .await
140 .map_err(|e| ThrottleError::RedisError(e.to_string()))?;
141
142 Ok(())
143 }
144}
145
146#[async_trait]
147impl Throttle for RedisRollingWindowThrottle {
148 async fn check(&self, key: &str) -> Result<bool, ThrottleError> {
149 let total = self.get_window_sum(key).await?;
150 Ok(total < self.max_actions)
151 }
152
153 async fn check_and_record(&self, key: &str) -> Result<bool, ThrottleError> {
154 let total = self.get_window_sum(key).await?;
155
156 if total >= self.max_actions {
157 return Ok(false);
158 }
159
160 self.record(key).await?;
161 Ok(true)
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::cache::create_cache_pool;

    /// Builds a throttle under `namespace` with the shared test configuration:
    /// 300-second windows, 3 windows checked, at most 15 actions.
    ///
    /// The tests in this module only exercise key formatting and never issue Redis
    /// commands; the pool is needed solely because the constructor requires one.
    fn throttle_for(namespace: &str) -> RedisRollingWindowThrottle {
        let pool = create_cache_pool("redis://localhost:6379").unwrap();
        RedisRollingWindowThrottle::new(pool, namespace.to_string(), 300, 3, 15)
    }

    #[test]
    fn test_window_calculation() {
        // Despite its name, this test checks the key layout for a plain namespace.
        let throttle = throttle_for("test");
        assert_eq!(
            throttle.window_key("user@example.com", 12345),
            "test:user@example.com:12345"
        );
    }

    #[test]
    fn test_window_key_format() {
        // A namespace that itself contains ':' is used verbatim as the key prefix.
        let throttle = throttle_for("throttle:email");
        assert_eq!(
            throttle.window_key("test@example.com", 1000),
            "throttle:email:test@example.com:1000"
        );
    }
}