This repository has no description.
1use async_trait::async_trait;
2use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
3use std::sync::Arc;
4use std::time::Duration;
5
6#[derive(Debug, thiserror::Error)]
7pub enum CacheError {
8 #[error("Cache connection error: {0}")]
9 Connection(String),
10 #[error("Serialization error: {0}")]
11 Serialization(String),
12}
13
14#[async_trait]
15pub trait Cache: Send + Sync {
16 async fn get(&self, key: &str) -> Option<String>;
17 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>;
18 async fn delete(&self, key: &str) -> Result<(), CacheError>;
19
20 async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
21 self.get(key).await.and_then(|s| BASE64.decode(&s).ok())
22 }
23
24 async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> {
25 let encoded = BASE64.encode(value);
26 self.set(key, &encoded, ttl).await
27 }
28}
29
30#[derive(Clone)]
31pub struct ValkeyCache {
32 conn: redis::aio::ConnectionManager,
33}
34
35impl ValkeyCache {
36 pub async fn new(url: &str) -> Result<Self, CacheError> {
37 let client = redis::Client::open(url)
38 .map_err(|e| CacheError::Connection(e.to_string()))?;
39 let manager = client
40 .get_connection_manager()
41 .await
42 .map_err(|e| CacheError::Connection(e.to_string()))?;
43 Ok(Self { conn: manager })
44 }
45
46 pub fn connection(&self) -> redis::aio::ConnectionManager {
47 self.conn.clone()
48 }
49}
50
51#[async_trait]
52impl Cache for ValkeyCache {
53 async fn get(&self, key: &str) -> Option<String> {
54 let mut conn = self.conn.clone();
55 redis::cmd("GET")
56 .arg(key)
57 .query_async::<Option<String>>(&mut conn)
58 .await
59 .ok()
60 .flatten()
61 }
62
63 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
64 let mut conn = self.conn.clone();
65 redis::cmd("SET")
66 .arg(key)
67 .arg(value)
68 .arg("EX")
69 .arg(ttl.as_secs() as i64)
70 .query_async::<()>(&mut conn)
71 .await
72 .map_err(|e| CacheError::Connection(e.to_string()))
73 }
74
75 async fn delete(&self, key: &str) -> Result<(), CacheError> {
76 let mut conn = self.conn.clone();
77 redis::cmd("DEL")
78 .arg(key)
79 .query_async::<()>(&mut conn)
80 .await
81 .map_err(|e| CacheError::Connection(e.to_string()))
82 }
83}
84
/// [`Cache`] that stores nothing: lookups always miss and writes/deletes
/// succeed without effect. Used when no cache backend is configured.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpCache;
86
87#[async_trait]
88impl Cache for NoOpCache {
89 async fn get(&self, _key: &str) -> Option<String> {
90 None
91 }
92
93 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
94 Ok(())
95 }
96
97 async fn delete(&self, _key: &str) -> Result<(), CacheError> {
98 Ok(())
99 }
100}
101
102#[async_trait]
103pub trait DistributedRateLimiter: Send + Sync {
104 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
105}
106
107#[derive(Clone)]
108pub struct RedisRateLimiter {
109 conn: redis::aio::ConnectionManager,
110}
111
112impl RedisRateLimiter {
113 pub fn new(conn: redis::aio::ConnectionManager) -> Self {
114 Self { conn }
115 }
116}
117
118#[async_trait]
119impl DistributedRateLimiter for RedisRateLimiter {
120 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
121 let mut conn = self.conn.clone();
122 let full_key = format!("rl:{}", key);
123 let window_secs = ((window_ms + 999) / 1000).max(1) as i64;
124
125 let count: Result<i64, _> = redis::cmd("INCR")
126 .arg(&full_key)
127 .query_async(&mut conn)
128 .await;
129
130 let count = match count {
131 Ok(c) => c,
132 Err(e) => {
133 tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e);
134 return true;
135 }
136 };
137
138 if count == 1 {
139 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE")
140 .arg(&full_key)
141 .arg(window_secs)
142 .query_async(&mut conn)
143 .await;
144 }
145
146 count <= limit as i64
147 }
148}
149
/// [`DistributedRateLimiter`] that allows every request. Used when no
/// distributed store is configured.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpRateLimiter;
151
152#[async_trait]
153impl DistributedRateLimiter for NoOpRateLimiter {
154 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
155 true
156 }
157}
158
159pub enum CacheBackend {
160 Valkey(ValkeyCache),
161 NoOp,
162}
163
164impl CacheBackend {
165 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
166 match self {
167 CacheBackend::Valkey(cache) => {
168 Arc::new(RedisRateLimiter::new(cache.connection()))
169 }
170 CacheBackend::NoOp => Arc::new(NoOpRateLimiter),
171 }
172 }
173}
174
175#[async_trait]
176impl Cache for CacheBackend {
177 async fn get(&self, key: &str) -> Option<String> {
178 match self {
179 CacheBackend::Valkey(c) => c.get(key).await,
180 CacheBackend::NoOp => None,
181 }
182 }
183
184 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
185 match self {
186 CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
187 CacheBackend::NoOp => Ok(()),
188 }
189 }
190
191 async fn delete(&self, key: &str) -> Result<(), CacheError> {
192 match self {
193 CacheBackend::Valkey(c) => c.delete(key).await,
194 CacheBackend::NoOp => Ok(()),
195 }
196 }
197}
198
199pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
200 match std::env::var("VALKEY_URL") {
201 Ok(url) => match ValkeyCache::new(&url).await {
202 Ok(cache) => {
203 tracing::info!("Connected to Valkey cache at {}", url);
204 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection()));
205 (Arc::new(cache), rate_limiter)
206 }
207 Err(e) => {
208 tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e);
209 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
210 }
211 },
212 Err(_) => {
213 tracing::info!("VALKEY_URL not set. Running without cache.");
214 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
215 }
216 }
217}