//! Cache abstraction: a Valkey/Redis-backed implementation with a no-op
//! fallback, plus a distributed fixed-window rate limiter built on the
//! same connection.
1use async_trait::async_trait; 2use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 3use std::sync::Arc; 4use std::time::Duration; 5 6#[derive(Debug, thiserror::Error)] 7pub enum CacheError { 8 #[error("Cache connection error: {0}")] 9 Connection(String), 10 #[error("Serialization error: {0}")] 11 Serialization(String), 12} 13 14#[async_trait] 15pub trait Cache: Send + Sync { 16 async fn get(&self, key: &str) -> Option<String>; 17 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>; 18 async fn delete(&self, key: &str) -> Result<(), CacheError>; 19 20 async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> { 21 self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) 22 } 23 24 async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> { 25 let encoded = BASE64.encode(value); 26 self.set(key, &encoded, ttl).await 27 } 28} 29 30#[derive(Clone)] 31pub struct ValkeyCache { 32 conn: redis::aio::ConnectionManager, 33} 34 35impl ValkeyCache { 36 pub async fn new(url: &str) -> Result<Self, CacheError> { 37 let client = redis::Client::open(url) 38 .map_err(|e| CacheError::Connection(e.to_string()))?; 39 let manager = client 40 .get_connection_manager() 41 .await 42 .map_err(|e| CacheError::Connection(e.to_string()))?; 43 Ok(Self { conn: manager }) 44 } 45 46 pub fn connection(&self) -> redis::aio::ConnectionManager { 47 self.conn.clone() 48 } 49} 50 51#[async_trait] 52impl Cache for ValkeyCache { 53 async fn get(&self, key: &str) -> Option<String> { 54 let mut conn = self.conn.clone(); 55 redis::cmd("GET") 56 .arg(key) 57 .query_async::<Option<String>>(&mut conn) 58 .await 59 .ok() 60 .flatten() 61 } 62 63 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 64 let mut conn = self.conn.clone(); 65 redis::cmd("SET") 66 .arg(key) 67 .arg(value) 68 .arg("EX") 69 .arg(ttl.as_secs() as i64) 70 .query_async::<()>(&mut conn) 71 .await 72 .map_err(|e| 
CacheError::Connection(e.to_string())) 73 } 74 75 async fn delete(&self, key: &str) -> Result<(), CacheError> { 76 let mut conn = self.conn.clone(); 77 redis::cmd("DEL") 78 .arg(key) 79 .query_async::<()>(&mut conn) 80 .await 81 .map_err(|e| CacheError::Connection(e.to_string())) 82 } 83} 84 85pub struct NoOpCache; 86 87#[async_trait] 88impl Cache for NoOpCache { 89 async fn get(&self, _key: &str) -> Option<String> { 90 None 91 } 92 93 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 94 Ok(()) 95 } 96 97 async fn delete(&self, _key: &str) -> Result<(), CacheError> { 98 Ok(()) 99 } 100} 101 102#[async_trait] 103pub trait DistributedRateLimiter: Send + Sync { 104 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 105} 106 107#[derive(Clone)] 108pub struct RedisRateLimiter { 109 conn: redis::aio::ConnectionManager, 110} 111 112impl RedisRateLimiter { 113 pub fn new(conn: redis::aio::ConnectionManager) -> Self { 114 Self { conn } 115 } 116} 117 118#[async_trait] 119impl DistributedRateLimiter for RedisRateLimiter { 120 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 121 let mut conn = self.conn.clone(); 122 let full_key = format!("rl:{}", key); 123 let window_secs = ((window_ms + 999) / 1000).max(1) as i64; 124 125 let count: Result<i64, _> = redis::cmd("INCR") 126 .arg(&full_key) 127 .query_async(&mut conn) 128 .await; 129 130 let count = match count { 131 Ok(c) => c, 132 Err(e) => { 133 tracing::warn!("Redis rate limit INCR failed: {}. 
Allowing request.", e); 134 return true; 135 } 136 }; 137 138 if count == 1 { 139 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE") 140 .arg(&full_key) 141 .arg(window_secs) 142 .query_async(&mut conn) 143 .await; 144 } 145 146 count <= limit as i64 147 } 148} 149 150pub struct NoOpRateLimiter; 151 152#[async_trait] 153impl DistributedRateLimiter for NoOpRateLimiter { 154 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 155 true 156 } 157} 158 159pub enum CacheBackend { 160 Valkey(ValkeyCache), 161 NoOp, 162} 163 164impl CacheBackend { 165 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 166 match self { 167 CacheBackend::Valkey(cache) => { 168 Arc::new(RedisRateLimiter::new(cache.connection())) 169 } 170 CacheBackend::NoOp => Arc::new(NoOpRateLimiter), 171 } 172 } 173} 174 175#[async_trait] 176impl Cache for CacheBackend { 177 async fn get(&self, key: &str) -> Option<String> { 178 match self { 179 CacheBackend::Valkey(c) => c.get(key).await, 180 CacheBackend::NoOp => None, 181 } 182 } 183 184 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 185 match self { 186 CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 187 CacheBackend::NoOp => Ok(()), 188 } 189 } 190 191 async fn delete(&self, key: &str) -> Result<(), CacheError> { 192 match self { 193 CacheBackend::Valkey(c) => c.delete(key).await, 194 CacheBackend::NoOp => Ok(()), 195 } 196 } 197} 198 199pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 200 match std::env::var("VALKEY_URL") { 201 Ok(url) => match ValkeyCache::new(&url).await { 202 Ok(cache) => { 203 tracing::info!("Connected to Valkey cache at {}", url); 204 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection())); 205 (Arc::new(cache), rate_limiter) 206 } 207 Err(e) => { 208 tracing::warn!("Failed to connect to Valkey: {}. 
Running without cache.", e); 209 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 210 } 211 }, 212 Err(_) => { 213 tracing::info!("VALKEY_URL not set. Running without cache."); 214 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 215 } 216 } 217}