this repo has no description
1use async_trait::async_trait; 2use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 3use std::sync::Arc; 4use std::time::Duration; 5 6#[derive(Debug, thiserror::Error)] 7pub enum CacheError { 8 #[error("Cache connection error: {0}")] 9 Connection(String), 10 #[error("Serialization error: {0}")] 11 Serialization(String), 12} 13 14#[async_trait] 15pub trait Cache: Send + Sync { 16 async fn get(&self, key: &str) -> Option<String>; 17 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>; 18 async fn delete(&self, key: &str) -> Result<(), CacheError>; 19 async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> { 20 self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) 21 } 22 async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> { 23 let encoded = BASE64.encode(value); 24 self.set(key, &encoded, ttl).await 25 } 26} 27 28#[derive(Clone)] 29pub struct ValkeyCache { 30 conn: redis::aio::ConnectionManager, 31} 32 33impl ValkeyCache { 34 pub async fn new(url: &str) -> Result<Self, CacheError> { 35 let client = redis::Client::open(url) 36 .map_err(|e| CacheError::Connection(e.to_string()))?; 37 let manager = client 38 .get_connection_manager() 39 .await 40 .map_err(|e| CacheError::Connection(e.to_string()))?; 41 Ok(Self { conn: manager }) 42 } 43 44 pub fn connection(&self) -> redis::aio::ConnectionManager { 45 self.conn.clone() 46 } 47} 48 49#[async_trait] 50impl Cache for ValkeyCache { 51 async fn get(&self, key: &str) -> Option<String> { 52 let mut conn = self.conn.clone(); 53 redis::cmd("GET") 54 .arg(key) 55 .query_async::<Option<String>>(&mut conn) 56 .await 57 .ok() 58 .flatten() 59 } 60 61 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 62 let mut conn = self.conn.clone(); 63 redis::cmd("SET") 64 .arg(key) 65 .arg(value) 66 .arg("EX") 67 .arg(ttl.as_secs() as i64) 68 .query_async::<()>(&mut conn) 69 .await 70 .map_err(|e| CacheError::Connection(e.to_string())) 71 } 72 73 async fn delete(&self, key: &str) -> Result<(), CacheError> { 74 let mut conn = self.conn.clone(); 75 redis::cmd("DEL") 76 .arg(key) 77 .query_async::<()>(&mut conn) 78 .await 79 .map_err(|e| CacheError::Connection(e.to_string())) 80 } 81} 82 83pub struct NoOpCache; 84 85#[async_trait] 86impl Cache for NoOpCache { 87 async fn get(&self, _key: &str) -> Option<String> { 88 None 89 } 90 91 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 92 Ok(()) 93 } 94 95 async fn delete(&self, _key: &str) -> Result<(), CacheError> { 96 Ok(()) 97 } 98} 99 100#[async_trait] 101pub trait DistributedRateLimiter: Send + Sync { 102 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 103} 104 105#[derive(Clone)] 106pub struct RedisRateLimiter { 107 conn: redis::aio::ConnectionManager, 108} 109 110impl RedisRateLimiter { 111 pub fn new(conn: redis::aio::ConnectionManager) -> Self { 112 Self { conn } 113 } 114} 115 116#[async_trait] 117impl DistributedRateLimiter for RedisRateLimiter { 118 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 119 let mut conn = self.conn.clone(); 120 let full_key = format!("rl:{}", key); 121 let window_secs = ((window_ms + 999) / 1000).max(1) as i64; 122 let count: Result<i64, _> = redis::cmd("INCR") 123 .arg(&full_key) 124 .query_async(&mut conn) 125 .await; 126 let count = match count { 127 Ok(c) => c, 128 Err(e) => { 129 tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e); 130 return true; 131 } 132 }; 133 if count == 1 { 134 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE") 135 .arg(&full_key) 136 .arg(window_secs) 137 .query_async(&mut conn) 138 .await; 139 } 140 count <= limit as i64 141 } 142} 143 144pub struct NoOpRateLimiter; 145 146#[async_trait] 147impl DistributedRateLimiter for NoOpRateLimiter { 148 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 149 true 150 } 151} 152 153pub enum CacheBackend { 154 Valkey(ValkeyCache), 155 NoOp, 156} 157 158impl CacheBackend { 159 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 160 match self { 161 CacheBackend::Valkey(cache) => { 162 Arc::new(RedisRateLimiter::new(cache.connection())) 163 } 164 CacheBackend::NoOp => Arc::new(NoOpRateLimiter), 165 } 166 } 167} 168 169#[async_trait] 170impl Cache for CacheBackend { 171 async fn get(&self, key: &str) -> Option<String> { 172 match self { 173 CacheBackend::Valkey(c) => c.get(key).await, 174 CacheBackend::NoOp => None, 175 } 176 } 177 178 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 179 match self { 180 CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 181 CacheBackend::NoOp => Ok(()), 182 } 183 } 184 185 async fn delete(&self, key: &str) -> Result<(), CacheError> { 186 match self { 187 CacheBackend::Valkey(c) => c.delete(key).await, 188 CacheBackend::NoOp => Ok(()), 189 } 190 } 191} 192 193pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 194 match std::env::var("VALKEY_URL") { 195 Ok(url) => match ValkeyCache::new(&url).await { 196 Ok(cache) => { 197 tracing::info!("Connected to Valkey cache at {}", url); 198 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection())); 199 (Arc::new(cache), rate_limiter) 200 } 201 Err(e) => { 202 tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e); 203 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 204 } 205 }, 206 Err(_) => { 207 tracing::info!("VALKEY_URL not set. Running without cache."); 208 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 209 } 210 } 211}