// This repository has no description.
1use async_trait::async_trait;
2use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
3use std::sync::Arc;
4use std::time::Duration;
/// Errors surfaced by cache implementations.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// Failure talking to the backing store (connect, command, or protocol errors).
    #[error("Cache connection error: {0}")]
    Connection(String),
    /// Failure encoding or decoding a cached value.
    #[error("Serialization error: {0}")]
    Serialization(String),
}
12#[async_trait]
13pub trait Cache: Send + Sync {
14 async fn get(&self, key: &str) -> Option<String>;
15 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>;
16 async fn delete(&self, key: &str) -> Result<(), CacheError>;
17 async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
18 self.get(key).await.and_then(|s| BASE64.decode(&s).ok())
19 }
20 async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> {
21 let encoded = BASE64.encode(value);
22 self.set(key, &encoded, ttl).await
23 }
24}
/// String cache backed by a Valkey/Redis server.
///
/// Cloning is cheap: `ConnectionManager` is itself a cloneable handle to a
/// shared multiplexed connection.
#[derive(Clone)]
pub struct ValkeyCache {
    conn: redis::aio::ConnectionManager,
}
29impl ValkeyCache {
30 pub async fn new(url: &str) -> Result<Self, CacheError> {
31 let client = redis::Client::open(url)
32 .map_err(|e| CacheError::Connection(e.to_string()))?;
33 let manager = client
34 .get_connection_manager()
35 .await
36 .map_err(|e| CacheError::Connection(e.to_string()))?;
37 Ok(Self { conn: manager })
38 }
39 pub fn connection(&self) -> redis::aio::ConnectionManager {
40 self.conn.clone()
41 }
42}
43#[async_trait]
44impl Cache for ValkeyCache {
45 async fn get(&self, key: &str) -> Option<String> {
46 let mut conn = self.conn.clone();
47 redis::cmd("GET")
48 .arg(key)
49 .query_async::<Option<String>>(&mut conn)
50 .await
51 .ok()
52 .flatten()
53 }
54 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
55 let mut conn = self.conn.clone();
56 redis::cmd("SET")
57 .arg(key)
58 .arg(value)
59 .arg("EX")
60 .arg(ttl.as_secs() as i64)
61 .query_async::<()>(&mut conn)
62 .await
63 .map_err(|e| CacheError::Connection(e.to_string()))
64 }
65 async fn delete(&self, key: &str) -> Result<(), CacheError> {
66 let mut conn = self.conn.clone();
67 redis::cmd("DEL")
68 .arg(key)
69 .query_async::<()>(&mut conn)
70 .await
71 .map_err(|e| CacheError::Connection(e.to_string()))
72 }
73}
/// Cache stub used when no Valkey server is configured: every read misses and
/// every write/delete succeeds without doing anything.
pub struct NoOpCache;
#[async_trait]
impl Cache for NoOpCache {
    async fn get(&self, _key: &str) -> Option<String> {
        None
    }
    async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
        Ok(())
    }
    async fn delete(&self, _key: &str) -> Result<(), CacheError> {
        Ok(())
    }
}
/// Rate limiting shared across process instances (e.g. backed by Redis/Valkey).
#[async_trait]
pub trait DistributedRateLimiter: Send + Sync {
    /// Record one hit for `key` and report whether it is still within `limit`
    /// hits per `window_ms` milliseconds. `true` means "allow the request".
    async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
}
/// Fixed-window rate limiter backed by Redis/Valkey counters.
#[derive(Clone)]
pub struct RedisRateLimiter {
    conn: redis::aio::ConnectionManager,
}
impl RedisRateLimiter {
    /// Wrap an existing connection manager (see [`ValkeyCache::connection`]).
    pub fn new(conn: redis::aio::ConnectionManager) -> Self {
        Self { conn }
    }
}
100#[async_trait]
101impl DistributedRateLimiter for RedisRateLimiter {
102 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
103 let mut conn = self.conn.clone();
104 let full_key = format!("rl:{}", key);
105 let window_secs = ((window_ms + 999) / 1000).max(1) as i64;
106 let count: Result<i64, _> = redis::cmd("INCR")
107 .arg(&full_key)
108 .query_async(&mut conn)
109 .await;
110 let count = match count {
111 Ok(c) => c,
112 Err(e) => {
113 tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e);
114 return true;
115 }
116 };
117 if count == 1 {
118 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE")
119 .arg(&full_key)
120 .arg(window_secs)
121 .query_async(&mut conn)
122 .await;
123 }
124 count <= limit as i64
125 }
126}
/// Rate-limiter stub that allows every request; used when Redis is unavailable.
pub struct NoOpRateLimiter;
#[async_trait]
impl DistributedRateLimiter for NoOpRateLimiter {
    async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
        true
    }
}
/// Runtime selection between a live Valkey cache and the no-op fallback.
pub enum CacheBackend {
    Valkey(ValkeyCache),
    NoOp,
}
138impl CacheBackend {
139 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
140 match self {
141 CacheBackend::Valkey(cache) => {
142 Arc::new(RedisRateLimiter::new(cache.connection()))
143 }
144 CacheBackend::NoOp => Arc::new(NoOpRateLimiter),
145 }
146 }
147}
148#[async_trait]
149impl Cache for CacheBackend {
150 async fn get(&self, key: &str) -> Option<String> {
151 match self {
152 CacheBackend::Valkey(c) => c.get(key).await,
153 CacheBackend::NoOp => None,
154 }
155 }
156 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
157 match self {
158 CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
159 CacheBackend::NoOp => Ok(()),
160 }
161 }
162 async fn delete(&self, key: &str) -> Result<(), CacheError> {
163 match self {
164 CacheBackend::Valkey(c) => c.delete(key).await,
165 CacheBackend::NoOp => Ok(()),
166 }
167 }
168}
169pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
170 match std::env::var("VALKEY_URL") {
171 Ok(url) => match ValkeyCache::new(&url).await {
172 Ok(cache) => {
173 tracing::info!("Connected to Valkey cache at {}", url);
174 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection()));
175 (Arc::new(cache), rate_limiter)
176 }
177 Err(e) => {
178 tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e);
179 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
180 }
181 },
182 Err(_) => {
183 tracing::info!("VALKEY_URL not set. Running without cache.");
184 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
185 }
186 }
187}