this repo has no description
1use async_trait::async_trait;
2use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
3use std::sync::Arc;
4use std::time::Duration;
5
6#[derive(Debug, thiserror::Error)]
7pub enum CacheError {
8 #[error("Cache connection error: {0}")]
9 Connection(String),
10 #[error("Serialization error: {0}")]
11 Serialization(String),
12}
13
14#[async_trait]
15pub trait Cache: Send + Sync {
16 async fn get(&self, key: &str) -> Option<String>;
17 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>;
18 async fn delete(&self, key: &str) -> Result<(), CacheError>;
19 async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> {
20 self.get(key).await.and_then(|s| BASE64.decode(&s).ok())
21 }
22 async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> {
23 let encoded = BASE64.encode(value);
24 self.set(key, &encoded, ttl).await
25 }
26}
27
28#[derive(Clone)]
29pub struct ValkeyCache {
30 conn: redis::aio::ConnectionManager,
31}
32
33impl ValkeyCache {
34 pub async fn new(url: &str) -> Result<Self, CacheError> {
35 let client = redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?;
36 let manager = client
37 .get_connection_manager()
38 .await
39 .map_err(|e| CacheError::Connection(e.to_string()))?;
40 Ok(Self { conn: manager })
41 }
42
43 pub fn connection(&self) -> redis::aio::ConnectionManager {
44 self.conn.clone()
45 }
46}
47
48#[async_trait]
49impl Cache for ValkeyCache {
50 async fn get(&self, key: &str) -> Option<String> {
51 let mut conn = self.conn.clone();
52 redis::cmd("GET")
53 .arg(key)
54 .query_async::<Option<String>>(&mut conn)
55 .await
56 .ok()
57 .flatten()
58 }
59
60 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
61 let mut conn = self.conn.clone();
62 redis::cmd("SET")
63 .arg(key)
64 .arg(value)
65 .arg("EX")
66 .arg(ttl.as_secs() as i64)
67 .query_async::<()>(&mut conn)
68 .await
69 .map_err(|e| CacheError::Connection(e.to_string()))
70 }
71
72 async fn delete(&self, key: &str) -> Result<(), CacheError> {
73 let mut conn = self.conn.clone();
74 redis::cmd("DEL")
75 .arg(key)
76 .query_async::<()>(&mut conn)
77 .await
78 .map_err(|e| CacheError::Connection(e.to_string()))
79 }
80}
81
82pub struct NoOpCache;
83
84#[async_trait]
85impl Cache for NoOpCache {
86 async fn get(&self, _key: &str) -> Option<String> {
87 None
88 }
89
90 async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
91 Ok(())
92 }
93
94 async fn delete(&self, _key: &str) -> Result<(), CacheError> {
95 Ok(())
96 }
97}
98
99#[async_trait]
100pub trait DistributedRateLimiter: Send + Sync {
101 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
102}
103
104#[derive(Clone)]
105pub struct RedisRateLimiter {
106 conn: redis::aio::ConnectionManager,
107}
108
109impl RedisRateLimiter {
110 pub fn new(conn: redis::aio::ConnectionManager) -> Self {
111 Self { conn }
112 }
113}
114
115#[async_trait]
116impl DistributedRateLimiter for RedisRateLimiter {
117 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
118 let mut conn = self.conn.clone();
119 let full_key = format!("rl:{}", key);
120 let window_secs = window_ms.div_ceil(1000).max(1) as i64;
121 let count: Result<i64, _> = redis::cmd("INCR")
122 .arg(&full_key)
123 .query_async(&mut conn)
124 .await;
125 let count = match count {
126 Ok(c) => c,
127 Err(e) => {
128 tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e);
129 return true;
130 }
131 };
132 if count == 1 {
133 let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE")
134 .arg(&full_key)
135 .arg(window_secs)
136 .query_async(&mut conn)
137 .await;
138 }
139 count <= limit as i64
140 }
141}
142
143pub struct NoOpRateLimiter;
144
145#[async_trait]
146impl DistributedRateLimiter for NoOpRateLimiter {
147 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
148 true
149 }
150}
151
152pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
153 match std::env::var("VALKEY_URL") {
154 Ok(url) => match ValkeyCache::new(&url).await {
155 Ok(cache) => {
156 tracing::info!("Connected to Valkey cache at {}", url);
157 let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection()));
158 (Arc::new(cache), rate_limiter)
159 }
160 Err(e) => {
161 tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e);
162 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
163 }
164 },
165 Err(_) => {
166 tracing::info!("VALKEY_URL not set. Running without cache.");
167 (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
168 }
169 }
170}