The smokesignal.events web application
at main 203 lines 6.4 kB view raw
1use async_trait::async_trait; 2use deadpool_redis::{Pool as RedisPool, redis::AsyncCommands}; 3use std::time::{SystemTime, UNIX_EPOCH}; 4 5use crate::throttle::{Throttle, ThrottleError}; 6 7/// Redis-backed rolling window throttle 8/// 9/// Uses a rolling window strategy where time is divided into fixed-size windows 10/// and the throttle checks multiple consecutive windows to determine if an action 11/// should be allowed. 12/// 13/// For example, with a 5-minute window size, 3 windows, and a limit of 15 actions: 14/// - Time is divided into 5-minute buckets 15/// - The throttle checks the current window and the previous 2 windows (15 minutes total) 16/// - If the total count across all 3 windows is >= 15, the action is throttled 17/// 18/// Redis keys have the format: `{namespace}:{key}:{window_timestamp}` 19/// Each key stores a counter and expires after (window_size * num_windows) seconds 20pub struct RedisRollingWindowThrottle { 21 /// Redis connection pool 22 redis_pool: RedisPool, 23 24 /// Namespace prefix for Redis keys (e.g., "throttle:email") 25 namespace: String, 26 27 /// Size of each time window in seconds (e.g., 300 for 5 minutes) 28 window_size_seconds: u64, 29 30 /// Number of windows to check (e.g., 3 for checking current + 2 previous) 31 num_windows: usize, 32 33 /// Maximum number of actions allowed across all windows 34 max_actions: u64, 35} 36 37impl RedisRollingWindowThrottle { 38 /// Create a new Redis rolling window throttle 39 /// 40 /// # Arguments 41 /// * `redis_pool` - Redis connection pool 42 /// * `namespace` - Namespace for Redis keys (e.g., "throttle:email") 43 /// * `window_size_seconds` - Size of each window in seconds (e.g., 300 for 5 minutes) 44 /// * `num_windows` - Number of windows to check (e.g., 3) 45 /// * `max_actions` - Maximum actions allowed across all windows (e.g., 15) 46 /// 47 /// # Example 48 /// ```ignore 49 /// // Throttle to 15 actions per 15 minutes using 5-minute windows 50 /// let throttle = RedisRollingWindowThrottle::new( 51 /// redis_pool, 52 /// "throttle:email".to_string(), 53 /// 300, // 5 minutes 54 /// 3, // check 3 windows (15 minutes total) 55 /// 15, // max 15 actions 56 /// ); 57 /// ``` 58 pub fn new( 59 redis_pool: RedisPool, 60 namespace: String, 61 window_size_seconds: u64, 62 num_windows: usize, 63 max_actions: u64, 64 ) -> Self { 65 Self { 66 redis_pool, 67 namespace, 68 window_size_seconds, 69 num_windows, 70 max_actions, 71 } 72 } 73 74 /// Get the current window timestamp 75 fn current_window(&self) -> Result<u64, ThrottleError> { 76 let now = SystemTime::now() 77 .duration_since(UNIX_EPOCH) 78 .map_err(|e| ThrottleError::TimeError(e.to_string()))? 79 .as_secs(); 80 81 Ok(now / self.window_size_seconds) 82 } 83 84 /// Generate Redis key for a window 85 fn window_key(&self, key: &str, window: u64) -> String { 86 format!("{}:{}:{}", self.namespace, key, window) 87 } 88 89 /// Get the sum of actions across all windows 90 async fn get_window_sum(&self, key: &str) -> Result<u64, ThrottleError> { 91 let current_window = self.current_window()?; 92 let mut conn = self 93 .redis_pool 94 .get() 95 .await 96 .map_err(|e| ThrottleError::RedisError(e.to_string()))?; 97 98 let mut total: u64 = 0; 99 100 // Check current window and previous (num_windows - 1) windows 101 for i in 0..self.num_windows { 102 let window = current_window.saturating_sub(i as u64); 103 let redis_key = self.window_key(key, window); 104 105 let count: Option<u64> = conn 106 .get(&redis_key) 107 .await 108 .map_err(|e| ThrottleError::RedisError(e.to_string()))?; 109 110 if let Some(count) = count { 111 total += count; 112 } 113 } 114 115 Ok(total) 116 } 117 118 /// Record an action in the current window 119 async fn record(&self, key: &str) -> Result<(), ThrottleError> { 120 let current_window = self.current_window()?; 121 let redis_key = self.window_key(key, current_window); 122 let ttl_seconds = self.window_size_seconds * self.num_windows as u64; 123 124 let mut conn = self 125 .redis_pool 126 .get() 127 .await 128 .map_err(|e| ThrottleError::RedisError(e.to_string()))?; 129 130 // Increment the counter for this window 131 let _: u64 = conn 132 .incr(&redis_key, 1) 133 .await 134 .map_err(|e| ThrottleError::RedisError(e.to_string()))?; 135 136 // Set expiration (idempotent - resets TTL each time) 137 let _: bool = conn 138 .expire(&redis_key, ttl_seconds as i64) 139 .await 140 .map_err(|e| ThrottleError::RedisError(e.to_string()))?; 141 142 Ok(()) 143 } 144} 145 146#[async_trait] 147impl Throttle for RedisRollingWindowThrottle { 148 async fn check(&self, key: &str) -> Result<bool, ThrottleError> { 149 let total = self.get_window_sum(key).await?; 150 Ok(total < self.max_actions) 151 } 152 153 async fn check_and_record(&self, key: &str) -> Result<bool, ThrottleError> { 154 let total = self.get_window_sum(key).await?; 155 156 if total >= self.max_actions { 157 return Ok(false); 158 } 159 160 self.record(key).await?; 161 Ok(true) 162 } 163} 164 165#[cfg(test)] 166mod tests { 167 use super::*; 168 169 fn create_test_pool() -> RedisPool { 170 // Create a test pool - in tests this won't be used for actual Redis operations 171 use crate::storage::cache::create_cache_pool; 172 create_cache_pool("redis://localhost:6379").unwrap() 173 } 174 175 #[test] 176 fn test_window_calculation() { 177 let throttle = RedisRollingWindowThrottle::new( 178 create_test_pool(), 179 "test".to_string(), 180 300, // 5 minutes 181 3, 182 15, 183 ); 184 185 // Test window key format 186 let key = throttle.window_key("user@example.com", 12345); 187 assert_eq!(key, "test:user@example.com:12345"); 188 } 189 190 #[test] 191 fn test_window_key_format() { 192 let throttle = RedisRollingWindowThrottle::new( 193 create_test_pool(), 194 "throttle:email".to_string(), 195 300, 196 3, 197 15, 198 ); 199 200 let key = throttle.window_key("test@example.com", 1000); 201 assert_eq!(key, "throttle:email:test@example.com:1000"); 202 } 203}