//! Caching layer: a `Cache` trait with a Valkey/Redis-backed implementation,
//! a no-op fallback, and a distributed fixed-window rate limiter.
1use async_trait::async_trait;
2use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
3use std::sync::Arc;
4use std::time::Duration;
5
/// Errors produced by fallible [`Cache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// The backend could not be reached or a command failed.
    #[error("Cache connection error: {0}")]
    Connection(String),
    /// A value could not be encoded or decoded for storage.
    /// NOTE(review): not constructed anywhere in this file — presumably used
    /// by callers elsewhere; confirm before removing.
    #[error("Serialization error: {0}")]
    Serialization(String),
}
13
14#[async_trait]
15pub trait Cache: Send + Sync {
16 async fn get(&self, key: &str) -> Option<String>;
17 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>;
18 async fn delete(&self, key: &str) -> Result<(), CacheError>;
19 async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
20 self.get(key).await.and_then(|s| BASE64.decode(&s).ok())
21 }
22 async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> {
23 let encoded = BASE64.encode(value);
24 self.set(key, &encoded, ttl).await
25 }
26}
27
/// [`Cache`] implementation backed by a Valkey/Redis server.
#[derive(Clone)]
pub struct ValkeyCache {
    // Async connection handle; cloned per operation in the `Cache` impl below.
    conn: redis::aio::ConnectionManager,
}
32
33impl ValkeyCache {
34 pub async fn new(url: &str) -> Result<Self, CacheError> {
35 let client = redis::Client::open(url)
36 .map_err(|e| CacheError::Connection(e.to_string()))?;
37 let manager = client
38 .get_connection_manager()
39 .await
40 .map_err(|e| CacheError::Connection(e.to_string()))?;
41 Ok(Self { conn: manager })
42 }
43
44 pub fn connection(&self) -> redis::aio::ConnectionManager {
45 self.conn.clone()
46 }
47}
48
49#[async_trait]
50impl Cache for ValkeyCache {
51 async fn get(&self, key: &str) -> Option<String> {
52 let mut conn = self.conn.clone();
53 redis::cmd("GET")
54 .arg(key)
55 .query_async::<Option<String>>(&mut conn)
56 .await
57 .ok()
58 .flatten()
59 }
60
61 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
62 let mut conn = self.conn.clone();
63 redis::cmd("SET")
64 .arg(key)
65 .arg(value)
66 .arg("EX")
67 .arg(ttl.as_secs() as i64)
68 .query_async::<()>(&mut conn)
69 .await
70 .map_err(|e| CacheError::Connection(e.to_string()))
71 }
72
73 async fn delete(&self, key: &str) -> Result<(), CacheError> {
74 let mut conn = self.conn.clone();
75 redis::cmd("DEL")
76 .arg(key)
77 .query_async::<()>(&mut conn)
78 .await
79 .map_err(|e| CacheError::Connection(e.to_string()))
80 }
81}
82
/// A cache that stores nothing: every read misses, every write succeeds.
///
/// Used as the fallback when no cache backend is configured, so callers
/// never need to special-case "cache absent".
pub struct NoOpCache;

#[async_trait]
impl Cache for NoOpCache {
    async fn get(&self, _key: &str) -> Option<String> {
        None
    }

    async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
        Ok(())
    }

    async fn delete(&self, _key: &str) -> Result<(), CacheError> {
        Ok(())
    }
}
99
/// Rate limiting shared across process instances via a common backend.
#[async_trait]
pub trait DistributedRateLimiter: Send + Sync {
    /// Returns `true` if the request identified by `key` is allowed within
    /// a window of `window_ms` milliseconds permitting `limit` requests.
    async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
}
104
/// [`DistributedRateLimiter`] backed by Redis/Valkey counters.
#[derive(Clone)]
pub struct RedisRateLimiter {
    conn: redis::aio::ConnectionManager,
}

impl RedisRateLimiter {
    /// Wrap an existing connection manager (typically shared with the cache).
    pub fn new(conn: redis::aio::ConnectionManager) -> Self {
        Self { conn }
    }
}
115
116#[async_trait]
117impl DistributedRateLimiter for RedisRateLimiter {
118 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
119 let mut conn = self.conn.clone();
120 let full_key = format!("rl:{}", key);
121 let window_secs = ((window_ms + 999) / 1000).max(1) as i64;
122 let count: Result<i64, _> = redis::cmd("INCR")
123 .arg(&full_key)
124 .query_async(&mut conn)
125 .await;
126 let count = match count {
127 Ok(c) => c,
128 Err(e) => {
129 tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e);
130 return true;
131 }
132 };
133 if count == 1 {
134 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE")
135 .arg(&full_key)
136 .arg(window_secs)
137 .query_async(&mut conn)
138 .await;
139 }
140 count <= limit as i64
141 }
142}
143
/// A rate limiter that never limits: every request is allowed.
///
/// Fallback for deployments without a Redis/Valkey backend.
pub struct NoOpRateLimiter;

#[async_trait]
impl DistributedRateLimiter for NoOpRateLimiter {
    async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
        true
    }
}
152
/// Concrete cache selection: a live Valkey connection or the no-op fallback.
pub enum CacheBackend {
    Valkey(ValkeyCache),
    NoOp,
}
157
158impl CacheBackend {
159 pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
160 match self {
161 CacheBackend::Valkey(cache) => {
162 Arc::new(RedisRateLimiter::new(cache.connection()))
163 }
164 CacheBackend::NoOp => Arc::new(NoOpRateLimiter),
165 }
166 }
167}
168
169#[async_trait]
170impl Cache for CacheBackend {
171 async fn get(&self, key: &str) -> Option<String> {
172 match self {
173 CacheBackend::Valkey(c) => c.get(key).await,
174 CacheBackend::NoOp => None,
175 }
176 }
177
178 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
179 match self {
180 CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
181 CacheBackend::NoOp => Ok(()),
182 }
183 }
184
185 async fn delete(&self, key: &str) -> Result<(), CacheError> {
186 match self {
187 CacheBackend::Valkey(c) => c.delete(key).await,
188 CacheBackend::NoOp => Ok(()),
189 }
190 }
191}
192
193pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
194 match std::env::var("VALKEY_URL") {
195 Ok(url) => match ValkeyCache::new(&url).await {
196 Ok(cache) => {
197 tracing::info!("Connected to Valkey cache at {}", url);
198 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection()));
199 (Arc::new(cache), rate_limiter)
200 }
201 Err(e) => {
202 tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e);
203 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
204 }
205 },
206 Err(_) => {
207 tracing::info!("VALKEY_URL not set. Running without cache.");
208 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
209 }
210 }
211}