this repo has no description
1use async_trait::async_trait; 2use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 3use std::sync::Arc; 4use std::time::Duration; 5#[derive(Debug, thiserror::Error)] 6pub enum CacheError { 7 #[error("Cache connection error: {0}")] 8 Connection(String), 9 #[error("Serialization error: {0}")] 10 Serialization(String), 11} 12#[async_trait] 13pub trait Cache: Send + Sync { 14 async fn get(&self, key: &str) -> Option<String>; 15 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>; 16 async fn delete(&self, key: &str) -> Result<(), CacheError>; 17 async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> { 18 self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) 19 } 20 async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> { 21 let encoded = BASE64.encode(value); 22 self.set(key, &encoded, ttl).await 23 } 24} 25#[derive(Clone)] 26pub struct ValkeyCache { 27 conn: redis::aio::ConnectionManager, 28} 29impl ValkeyCache { 30 pub async fn new(url: &str) -> Result<Self, CacheError> { 31 let client = redis::Client::open(url) 32 .map_err(|e| CacheError::Connection(e.to_string()))?; 33 let manager = client 34 .get_connection_manager() 35 .await 36 .map_err(|e| CacheError::Connection(e.to_string()))?; 37 Ok(Self { conn: manager }) 38 } 39 pub fn connection(&self) -> redis::aio::ConnectionManager { 40 self.conn.clone() 41 } 42} 43#[async_trait] 44impl Cache for ValkeyCache { 45 async fn get(&self, key: &str) -> Option<String> { 46 let mut conn = self.conn.clone(); 47 redis::cmd("GET") 48 .arg(key) 49 .query_async::<Option<String>>(&mut conn) 50 .await 51 .ok() 52 .flatten() 53 } 54 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 55 let mut conn = self.conn.clone(); 56 redis::cmd("SET") 57 .arg(key) 58 .arg(value) 59 .arg("EX") 60 .arg(ttl.as_secs() as i64) 61 .query_async::<()>(&mut conn) 62 .await 63 .map_err(|e| CacheError::Connection(e.to_string())) 64 } 65 async fn delete(&self, key: &str) -> Result<(), CacheError> { 66 let mut conn = self.conn.clone(); 67 redis::cmd("DEL") 68 .arg(key) 69 .query_async::<()>(&mut conn) 70 .await 71 .map_err(|e| CacheError::Connection(e.to_string())) 72 } 73} 74pub struct NoOpCache; 75#[async_trait] 76impl Cache for NoOpCache { 77 async fn get(&self, _key: &str) -> Option<String> { 78 None 79 } 80 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 81 Ok(()) 82 } 83 async fn delete(&self, _key: &str) -> Result<(), CacheError> { 84 Ok(()) 85 } 86} 87#[async_trait] 88pub trait DistributedRateLimiter: Send + Sync { 89 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 90} 91#[derive(Clone)] 92pub struct RedisRateLimiter { 93 conn: redis::aio::ConnectionManager, 94} 95impl RedisRateLimiter { 96 pub fn new(conn: redis::aio::ConnectionManager) -> Self { 97 Self { conn } 98 } 99} 100#[async_trait] 101impl DistributedRateLimiter for RedisRateLimiter { 102 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 103 let mut conn = self.conn.clone(); 104 let full_key = format!("rl:{}", key); 105 let window_secs = ((window_ms + 999) / 1000).max(1) as i64; 106 let count: Result<i64, _> = redis::cmd("INCR") 107 .arg(&full_key) 108 .query_async(&mut conn) 109 .await; 110 let count = match count { 111 Ok(c) => c, 112 Err(e) => { 113 tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e); 114 return true; 115 } 116 }; 117 if count == 1 { 118 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE") 119 .arg(&full_key) 120 .arg(window_secs) 121 .query_async(&mut conn) 122 .await; 123 } 124 count <= limit as i64 125 } 126} 127pub struct NoOpRateLimiter; 128#[async_trait] 129impl DistributedRateLimiter for NoOpRateLimiter { 130 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 131 true 132 } 133} 134pub enum CacheBackend { 135 Valkey(ValkeyCache), 136 NoOp, 137} 138impl CacheBackend { 139 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 140 match self { 141 CacheBackend::Valkey(cache) => { 142 Arc::new(RedisRateLimiter::new(cache.connection())) 143 } 144 CacheBackend::NoOp => Arc::new(NoOpRateLimiter), 145 } 146 } 147} 148#[async_trait] 149impl Cache for CacheBackend { 150 async fn get(&self, key: &str) -> Option<String> { 151 match self { 152 CacheBackend::Valkey(c) => c.get(key).await, 153 CacheBackend::NoOp => None, 154 } 155 } 156 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 157 match self { 158 CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 159 CacheBackend::NoOp => Ok(()), 160 } 161 } 162 async fn delete(&self, key: &str) -> Result<(), CacheError> { 163 match self { 164 CacheBackend::Valkey(c) => c.delete(key).await, 165 CacheBackend::NoOp => Ok(()), 166 } 167 } 168} 169pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 170 match std::env::var("VALKEY_URL") { 171 Ok(url) => match ValkeyCache::new(&url).await { 172 Ok(cache) => { 173 tracing::info!("Connected to Valkey cache at {}", url); 174 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection())); 175 (Arc::new(cache), rate_limiter) 176 } 177 Err(e) => { 178 tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e); 179 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 180 } 181 }, 182 Err(_) => { 183 tracing::info!("VALKEY_URL not set. Running without cache."); 184 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 185 } 186 } 187}