// This repository has no description.
1use async_trait::async_trait; 2use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 3use std::sync::Arc; 4use std::time::Duration; 5 6#[derive(Debug, thiserror::Error)] 7pub enum CacheError { 8 #[error("Cache connection error: {0}")] 9 Connection(String), 10 #[error("Serialization error: {0}")] 11 Serialization(String), 12} 13 14#[async_trait] 15pub trait Cache: Send + Sync { 16 async fn get(&self, key: &str) -> Option<String>; 17 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>; 18 async fn delete(&self, key: &str) -> Result<(), CacheError>; 19 async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> { 20 self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) 21 } 22 async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> { 23 let encoded = BASE64.encode(value); 24 self.set(key, &encoded, ttl).await 25 } 26} 27 28#[derive(Clone)] 29pub struct ValkeyCache { 30 conn: redis::aio::ConnectionManager, 31} 32 33impl ValkeyCache { 34 pub async fn new(url: &str) -> Result<Self, CacheError> { 35 let client = redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?; 36 let manager = client 37 .get_connection_manager() 38 .await 39 .map_err(|e| CacheError::Connection(e.to_string()))?; 40 Ok(Self { conn: manager }) 41 } 42 43 pub fn connection(&self) -> redis::aio::ConnectionManager { 44 self.conn.clone() 45 } 46} 47 48#[async_trait] 49impl Cache for ValkeyCache { 50 async fn get(&self, key: &str) -> Option<String> { 51 let mut conn = self.conn.clone(); 52 redis::cmd("GET") 53 .arg(key) 54 .query_async::<Option<String>>(&mut conn) 55 .await 56 .ok() 57 .flatten() 58 } 59 60 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 61 let mut conn = self.conn.clone(); 62 redis::cmd("SET") 63 .arg(key) 64 .arg(value) 65 .arg("EX") 66 .arg(ttl.as_secs() as i64) 67 .query_async::<()>(&mut conn) 68 .await 69 .map_err(|e| 
CacheError::Connection(e.to_string())) 70 } 71 72 async fn delete(&self, key: &str) -> Result<(), CacheError> { 73 let mut conn = self.conn.clone(); 74 redis::cmd("DEL") 75 .arg(key) 76 .query_async::<()>(&mut conn) 77 .await 78 .map_err(|e| CacheError::Connection(e.to_string())) 79 } 80} 81 82pub struct NoOpCache; 83 84#[async_trait] 85impl Cache for NoOpCache { 86 async fn get(&self, _key: &str) -> Option<String> { 87 None 88 } 89 90 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 91 Ok(()) 92 } 93 94 async fn delete(&self, _key: &str) -> Result<(), CacheError> { 95 Ok(()) 96 } 97} 98 99#[async_trait] 100pub trait DistributedRateLimiter: Send + Sync { 101 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 102} 103 104#[derive(Clone)] 105pub struct RedisRateLimiter { 106 conn: redis::aio::ConnectionManager, 107} 108 109impl RedisRateLimiter { 110 pub fn new(conn: redis::aio::ConnectionManager) -> Self { 111 Self { conn } 112 } 113} 114 115#[async_trait] 116impl DistributedRateLimiter for RedisRateLimiter { 117 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 118 let mut conn = self.conn.clone(); 119 let full_key = format!("rl:{}", key); 120 let window_secs = window_ms.div_ceil(1000).max(1) as i64; 121 let count: Result<i64, _> = redis::cmd("INCR") 122 .arg(&full_key) 123 .query_async(&mut conn) 124 .await; 125 let count = match count { 126 Ok(c) => c, 127 Err(e) => { 128 tracing::warn!("Redis rate limit INCR failed: {}. 
Allowing request.", e); 129 return true; 130 } 131 }; 132 if count == 1 { 133 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE") 134 .arg(&full_key) 135 .arg(window_secs) 136 .query_async(&mut conn) 137 .await; 138 } 139 count <= limit as i64 140 } 141} 142 143pub struct NoOpRateLimiter; 144 145#[async_trait] 146impl DistributedRateLimiter for NoOpRateLimiter { 147 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 148 true 149 } 150} 151 152pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 153 match std::env::var("VALKEY_URL") { 154 Ok(url) => match ValkeyCache::new(&url).await { 155 Ok(cache) => { 156 tracing::info!("Connected to Valkey cache at {}", url); 157 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection())); 158 (Arc::new(cache), rate_limiter) 159 } 160 Err(e) => { 161 tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e); 162 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 163 } 164 }, 165 Err(_) => { 166 tracing::info!("VALKEY_URL not set. Running without cache."); 167 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 168 } 169 } 170}