//! Cache and distributed rate-limiting abstractions backed by Valkey/Redis,
//! with no-op fallbacks for running without a cache server.
1use async_trait::async_trait;
2use std::sync::Arc;
3use std::time::Duration;
4
/// Errors surfaced by [`Cache`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// Failure talking to the backing store (connect, command send, network).
    #[error("Cache connection error: {0}")]
    Connection(String),
    /// Failure encoding or decoding a cached value.
    /// NOTE(review): nothing in this file constructs this variant — confirm it
    /// is used by callers elsewhere before removing.
    #[error("Serialization error: {0}")]
    Serialization(String),
}
12
/// A string key/value cache with per-entry TTL.
///
/// Implementations are expected to be cheap to share (`Send + Sync`) and to
/// treat lookup failures as cache misses rather than errors (see `get`).
#[async_trait]
pub trait Cache: Send + Sync {
    /// Fetch the value stored under `key`, or `None` on a miss.
    /// Backend errors are folded into `None` — callers cannot distinguish
    /// a miss from an outage.
    async fn get(&self, key: &str) -> Option<String>;
    /// Store `value` under `key`, expiring after `ttl`.
    async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>;
    /// Remove `key`; deleting an absent key is not an error.
    async fn delete(&self, key: &str) -> Result<(), CacheError>;
}
19
/// [`Cache`] implementation backed by a Valkey/Redis server.
///
/// Cloning is cheap: `ConnectionManager` is a shared, auto-reconnecting
/// multiplexed connection handle.
#[derive(Clone)]
pub struct ValkeyCache {
    // Managed connection; cloned per operation below.
    conn: redis::aio::ConnectionManager,
}
24
25impl ValkeyCache {
26 pub async fn new(url: &str) -> Result<Self, CacheError> {
27 let client = redis::Client::open(url)
28 .map_err(|e| CacheError::Connection(e.to_string()))?;
29 let manager = client
30 .get_connection_manager()
31 .await
32 .map_err(|e| CacheError::Connection(e.to_string()))?;
33 Ok(Self { conn: manager })
34 }
35
36 pub fn connection(&self) -> redis::aio::ConnectionManager {
37 self.conn.clone()
38 }
39}
40
41#[async_trait]
42impl Cache for ValkeyCache {
43 async fn get(&self, key: &str) -> Option<String> {
44 let mut conn = self.conn.clone();
45 redis::cmd("GET")
46 .arg(key)
47 .query_async::<Option<String>>(&mut conn)
48 .await
49 .ok()
50 .flatten()
51 }
52
53 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
54 let mut conn = self.conn.clone();
55 redis::cmd("SET")
56 .arg(key)
57 .arg(value)
58 .arg("EX")
59 .arg(ttl.as_secs() as i64)
60 .query_async::<()>(&mut conn)
61 .await
62 .map_err(|e| CacheError::Connection(e.to_string()))
63 }
64
65 async fn delete(&self, key: &str) -> Result<(), CacheError> {
66 let mut conn = self.conn.clone();
67 redis::cmd("DEL")
68 .arg(key)
69 .query_async::<()>(&mut conn)
70 .await
71 .map_err(|e| CacheError::Connection(e.to_string()))
72 }
73}
74
75pub struct NoOpCache;
76
77#[async_trait]
78impl Cache for NoOpCache {
79 async fn get(&self, _key: &str) -> Option<String> {
80 None
81 }
82
83 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
84 Ok(())
85 }
86
87 async fn delete(&self, _key: &str) -> Result<(), CacheError> {
88 Ok(())
89 }
90}
91
/// Rate limiting shared across process instances via an external store.
#[async_trait]
pub trait DistributedRateLimiter: Send + Sync {
    /// Record one hit for `key` and report whether the caller is still within
    /// `limit` hits per `window_ms` milliseconds. `true` = allow the request.
    async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
}
96
/// [`DistributedRateLimiter`] backed by Redis/Valkey counters
/// (fixed-window INCR + EXPIRE; see the trait impl below).
#[derive(Clone)]
pub struct RedisRateLimiter {
    // Shared multiplexed connection; cloned per check.
    conn: redis::aio::ConnectionManager,
}
101
impl RedisRateLimiter {
    /// Build a limiter over an existing connection manager
    /// (typically obtained from [`ValkeyCache::connection`]).
    pub fn new(conn: redis::aio::ConnectionManager) -> Self {
        Self { conn }
    }
}
107
108#[async_trait]
109impl DistributedRateLimiter for RedisRateLimiter {
110 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
111 let mut conn = self.conn.clone();
112 let full_key = format!("rl:{}", key);
113 let window_secs = ((window_ms + 999) / 1000).max(1) as i64;
114
115 let count: Result<i64, _> = redis::cmd("INCR")
116 .arg(&full_key)
117 .query_async(&mut conn)
118 .await;
119
120 let count = match count {
121 Ok(c) => c,
122 Err(e) => {
123 tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e);
124 return true;
125 }
126 };
127
128 if count == 1 {
129 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE")
130 .arg(&full_key)
131 .arg(window_secs)
132 .query_async(&mut conn)
133 .await;
134 }
135
136 count <= limit as i64
137 }
138}
139
140pub struct NoOpRateLimiter;
141
142#[async_trait]
143impl DistributedRateLimiter for NoOpRateLimiter {
144 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
145 true
146 }
147}
148
/// The configured caching backend: a live Valkey connection, or the no-op
/// fallback when no server is available.
pub enum CacheBackend {
    /// Connected Valkey/Redis backend.
    Valkey(ValkeyCache),
    /// No cache configured; operations are no-ops.
    NoOp,
}
153
154impl CacheBackend {
155 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
156 match self {
157 CacheBackend::Valkey(cache) => {
158 Arc::new(RedisRateLimiter::new(cache.connection()))
159 }
160 CacheBackend::NoOp => Arc::new(NoOpRateLimiter),
161 }
162 }
163}
164
165#[async_trait]
166impl Cache for CacheBackend {
167 async fn get(&self, key: &str) -> Option<String> {
168 match self {
169 CacheBackend::Valkey(c) => c.get(key).await,
170 CacheBackend::NoOp => None,
171 }
172 }
173
174 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
175 match self {
176 CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
177 CacheBackend::NoOp => Ok(()),
178 }
179 }
180
181 async fn delete(&self, key: &str) -> Result<(), CacheError> {
182 match self {
183 CacheBackend::Valkey(c) => c.delete(key).await,
184 CacheBackend::NoOp => Ok(()),
185 }
186 }
187}
188
189pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
190 match std::env::var("VALKEY_URL") {
191 Ok(url) => match ValkeyCache::new(&url).await {
192 Ok(cache) => {
193 tracing::info!("Connected to Valkey cache at {}", url);
194 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection()));
195 (Arc::new(cache), rate_limiter)
196 }
197 Err(e) => {
198 tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e);
199 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
200 }
201 },
202 Err(_) => {
203 tracing::info!("VALKEY_URL not set. Running without cache.");
204 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
205 }
206 }
207}