use async_trait::async_trait; use deadpool_redis::{Pool as RedisPool, redis::AsyncCommands}; use std::time::{SystemTime, UNIX_EPOCH}; use crate::throttle::{Throttle, ThrottleError}; /// Redis-backed rolling window throttle /// /// Uses a rolling window strategy where time is divided into fixed-size windows /// and the throttle checks multiple consecutive windows to determine if an action /// should be allowed. /// /// For example, with a 5-minute window size, 3 windows, and a limit of 15 actions: /// - Time is divided into 5-minute buckets /// - The throttle checks the current window and the previous 2 windows (15 minutes total) /// - If the total count across all 3 windows is >= 15, the action is throttled /// /// Redis keys have the format: `{namespace}:{key}:{window_timestamp}` /// Each key stores a counter and expires after (window_size * num_windows) seconds pub struct RedisRollingWindowThrottle { /// Redis connection pool redis_pool: RedisPool, /// Namespace prefix for Redis keys (e.g., "throttle:email") namespace: String, /// Size of each time window in seconds (e.g., 300 for 5 minutes) window_size_seconds: u64, /// Number of windows to check (e.g., 3 for checking current + 2 previous) num_windows: usize, /// Maximum number of actions allowed across all windows max_actions: u64, } impl RedisRollingWindowThrottle { /// Create a new Redis rolling window throttle /// /// # Arguments /// * `redis_pool` - Redis connection pool /// * `namespace` - Namespace for Redis keys (e.g., "throttle:email") /// * `window_size_seconds` - Size of each window in seconds (e.g., 300 for 5 minutes) /// * `num_windows` - Number of windows to check (e.g., 3) /// * `max_actions` - Maximum actions allowed across all windows (e.g., 15) /// /// # Example /// ```ignore /// // Throttle to 15 actions per 15 minutes using 5-minute windows /// let throttle = RedisRollingWindowThrottle::new( /// redis_pool, /// "throttle:email".to_string(), /// 300, // 5 minutes /// 3, // check 3 windows (15 minutes total) /// 15, // max 15 actions /// ); /// ``` pub fn new( redis_pool: RedisPool, namespace: String, window_size_seconds: u64, num_windows: usize, max_actions: u64, ) -> Self { Self { redis_pool, namespace, window_size_seconds, num_windows, max_actions, } } /// Get the current window timestamp fn current_window(&self) -> Result { let now = SystemTime::now() .duration_since(UNIX_EPOCH) .map_err(|e| ThrottleError::TimeError(e.to_string()))? .as_secs(); Ok(now / self.window_size_seconds) } /// Generate Redis key for a window fn window_key(&self, key: &str, window: u64) -> String { format!("{}:{}:{}", self.namespace, key, window) } /// Get the sum of actions across all windows async fn get_window_sum(&self, key: &str) -> Result { let current_window = self.current_window()?; let mut conn = self .redis_pool .get() .await .map_err(|e| ThrottleError::RedisError(e.to_string()))?; let mut total: u64 = 0; // Check current window and previous (num_windows - 1) windows for i in 0..self.num_windows { let window = current_window.saturating_sub(i as u64); let redis_key = self.window_key(key, window); let count: Option = conn .get(&redis_key) .await .map_err(|e| ThrottleError::RedisError(e.to_string()))?; if let Some(count) = count { total += count; } } Ok(total) } /// Record an action in the current window async fn record(&self, key: &str) -> Result<(), ThrottleError> { let current_window = self.current_window()?; let redis_key = self.window_key(key, current_window); let ttl_seconds = self.window_size_seconds * self.num_windows as u64; let mut conn = self .redis_pool .get() .await .map_err(|e| ThrottleError::RedisError(e.to_string()))?; // Increment the counter for this window let _: u64 = conn .incr(&redis_key, 1) .await .map_err(|e| ThrottleError::RedisError(e.to_string()))?; // Set expiration (idempotent - resets TTL each time) let _: bool = conn .expire(&redis_key, ttl_seconds as i64) .await .map_err(|e| ThrottleError::RedisError(e.to_string()))?; Ok(()) } } #[async_trait] impl Throttle for RedisRollingWindowThrottle { async fn check(&self, key: &str) -> Result { let total = self.get_window_sum(key).await?; Ok(total < self.max_actions) } async fn check_and_record(&self, key: &str) -> Result { let total = self.get_window_sum(key).await?; if total >= self.max_actions { return Ok(false); } self.record(key).await?; Ok(true) } } #[cfg(test)] mod tests { use super::*; fn create_test_pool() -> RedisPool { // Create a test pool - in tests this won't be used for actual Redis operations use crate::storage::cache::create_cache_pool; create_cache_pool("redis://localhost:6379").unwrap() } #[test] fn test_window_calculation() { let throttle = RedisRollingWindowThrottle::new( create_test_pool(), "test".to_string(), 300, // 5 minutes 3, 15, ); // Test window key format let key = throttle.window_key("user@example.com", 12345); assert_eq!(key, "test:user@example.com:12345"); } #[test] fn test_window_key_format() { let throttle = RedisRollingWindowThrottle::new( create_test_pool(), "throttle:email".to_string(), 300, 3, 15, ); let key = throttle.window_key("test@example.com", 1000); assert_eq!(key, "throttle:email:test@example.com:1000"); } }