This repository has no description.
1use async_trait::async_trait; 2use std::sync::Arc; 3use std::time::Duration; 4 5#[derive(Debug, thiserror::Error)] 6pub enum CacheError { 7 #[error("Cache connection error: {0}")] 8 Connection(String), 9 #[error("Serialization error: {0}")] 10 Serialization(String), 11} 12 13#[async_trait] 14pub trait Cache: Send + Sync { 15 async fn get(&self, key: &str) -> Option<String>; 16 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>; 17 async fn delete(&self, key: &str) -> Result<(), CacheError>; 18} 19 20#[derive(Clone)] 21pub struct ValkeyCache { 22 conn: redis::aio::ConnectionManager, 23} 24 25impl ValkeyCache { 26 pub async fn new(url: &str) -> Result<Self, CacheError> { 27 let client = redis::Client::open(url) 28 .map_err(|e| CacheError::Connection(e.to_string()))?; 29 let manager = client 30 .get_connection_manager() 31 .await 32 .map_err(|e| CacheError::Connection(e.to_string()))?; 33 Ok(Self { conn: manager }) 34 } 35 36 pub fn connection(&self) -> redis::aio::ConnectionManager { 37 self.conn.clone() 38 } 39} 40 41#[async_trait] 42impl Cache for ValkeyCache { 43 async fn get(&self, key: &str) -> Option<String> { 44 let mut conn = self.conn.clone(); 45 redis::cmd("GET") 46 .arg(key) 47 .query_async::<Option<String>>(&mut conn) 48 .await 49 .ok() 50 .flatten() 51 } 52 53 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 54 let mut conn = self.conn.clone(); 55 redis::cmd("SET") 56 .arg(key) 57 .arg(value) 58 .arg("EX") 59 .arg(ttl.as_secs() as i64) 60 .query_async::<()>(&mut conn) 61 .await 62 .map_err(|e| CacheError::Connection(e.to_string())) 63 } 64 65 async fn delete(&self, key: &str) -> Result<(), CacheError> { 66 let mut conn = self.conn.clone(); 67 redis::cmd("DEL") 68 .arg(key) 69 .query_async::<()>(&mut conn) 70 .await 71 .map_err(|e| CacheError::Connection(e.to_string())) 72 } 73} 74 75pub struct NoOpCache; 76 77#[async_trait] 78impl Cache for NoOpCache { 79 async fn get(&self, _key: 
&str) -> Option<String> { 80 None 81 } 82 83 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 84 Ok(()) 85 } 86 87 async fn delete(&self, _key: &str) -> Result<(), CacheError> { 88 Ok(()) 89 } 90} 91 92#[async_trait] 93pub trait DistributedRateLimiter: Send + Sync { 94 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 95} 96 97#[derive(Clone)] 98pub struct RedisRateLimiter { 99 conn: redis::aio::ConnectionManager, 100} 101 102impl RedisRateLimiter { 103 pub fn new(conn: redis::aio::ConnectionManager) -> Self { 104 Self { conn } 105 } 106} 107 108#[async_trait] 109impl DistributedRateLimiter for RedisRateLimiter { 110 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 111 let mut conn = self.conn.clone(); 112 let full_key = format!("rl:{}", key); 113 let window_secs = ((window_ms + 999) / 1000).max(1) as i64; 114 115 let count: Result<i64, _> = redis::cmd("INCR") 116 .arg(&full_key) 117 .query_async(&mut conn) 118 .await; 119 120 let count = match count { 121 Ok(c) => c, 122 Err(e) => { 123 tracing::warn!("Redis rate limit INCR failed: {}. 
Allowing request.", e); 124 return true; 125 } 126 }; 127 128 if count == 1 { 129 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE") 130 .arg(&full_key) 131 .arg(window_secs) 132 .query_async(&mut conn) 133 .await; 134 } 135 136 count <= limit as i64 137 } 138} 139 140pub struct NoOpRateLimiter; 141 142#[async_trait] 143impl DistributedRateLimiter for NoOpRateLimiter { 144 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 145 true 146 } 147} 148 149pub enum CacheBackend { 150 Valkey(ValkeyCache), 151 NoOp, 152} 153 154impl CacheBackend { 155 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 156 match self { 157 CacheBackend::Valkey(cache) => { 158 Arc::new(RedisRateLimiter::new(cache.connection())) 159 } 160 CacheBackend::NoOp => Arc::new(NoOpRateLimiter), 161 } 162 } 163} 164 165#[async_trait] 166impl Cache for CacheBackend { 167 async fn get(&self, key: &str) -> Option<String> { 168 match self { 169 CacheBackend::Valkey(c) => c.get(key).await, 170 CacheBackend::NoOp => None, 171 } 172 } 173 174 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 175 match self { 176 CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 177 CacheBackend::NoOp => Ok(()), 178 } 179 } 180 181 async fn delete(&self, key: &str) -> Result<(), CacheError> { 182 match self { 183 CacheBackend::Valkey(c) => c.delete(key).await, 184 CacheBackend::NoOp => Ok(()), 185 } 186 } 187} 188 189pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 190 match std::env::var("VALKEY_URL") { 191 Ok(url) => match ValkeyCache::new(&url).await { 192 Ok(cache) => { 193 tracing::info!("Connected to Valkey cache at {}", url); 194 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection())); 195 (Arc::new(cache), rate_limiter) 196 } 197 Err(e) => { 198 tracing::warn!("Failed to connect to Valkey: {}. 
Running without cache.", e); 199 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 200 } 201 }, 202 Err(_) => { 203 tracing::info!("VALKEY_URL not set. Running without cache."); 204 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 205 } 206 } 207}