Heavily customized version of smokesignal - https://whtwnd.com/kayrozen.com/3lpwe4ymowg2t
1use anyhow::Result;
2use base64_url;
3use chrono::Duration;
4use deadpool_redis::redis::{pipe, AsyncCommands};
5use p256::SecretKey;
6use tokio::time::{sleep, Instant};
7use tokio_util::sync::CancellationToken;
8use rand::seq::SliceRandom;
9
10use crate::{
11 config::{SigningKeys, OAuthActiveKeys},
12 config_errors::ConfigError,
13 jose::jwk::WrappedJsonWebKey,
14 refresh_tokens_errors::RefreshError,
15 storage::{
16 oauth::{
17 oauth_session_delete_by_id,
18 oauth_session_update_tokens, web_session_lookup,
19 },
20 types::{StoragePool, CachePool, OAUTH_REFRESH_QUEUE, OAUTH_REFRESH_HEARTBEATS, TokenSet},
21 },
22};
23
24pub struct RefreshTokensTaskConfig {
25 pub sleep_interval: Duration,
26 pub worker_id: String,
27 pub external_url_base: String,
28 pub signing_keys: SigningKeys,
29 pub oauth_active_keys: OAuthActiveKeys,
30}
31
32impl RefreshTokensTaskConfig {
33 pub fn select_oauth_signing_key(&self) -> Result<(String, SecretKey)> {
34 let key_id = self
35 .oauth_active_keys
36 .as_ref()
37 .choose(&mut rand::thread_rng())
38 .ok_or(ConfigError::SigningKeyNotFound)?
39 .clone();
40 let signing_key = self
41 .signing_keys
42 .as_ref()
43 .get(&key_id)
44 .ok_or(ConfigError::SigningKeyNotFound)?
45 .clone();
46
47 Ok((key_id, signing_key))
48 }
49}
50
51pub struct RefreshTokensTask {
52 pub config: RefreshTokensTaskConfig,
53 pub http_client: reqwest::Client,
54 pub storage_pool: StoragePool,
55 pub cache_pool: CachePool,
56 pub cancellation_token: CancellationToken,
57}
58
59impl RefreshTokensTask {
60 #[must_use]
61 pub fn new(
62 config: RefreshTokensTaskConfig,
63 http_client: reqwest::Client,
64 storage_pool: StoragePool,
65 cache_pool: CachePool,
66 cancellation_token: CancellationToken,
67 ) -> Self {
68 Self {
69 config,
70 http_client,
71 storage_pool,
72 cache_pool,
73 cancellation_token,
74 }
75 }
76
77 /// Runs the refresh tokens task as a long-running process
78 ///
79 /// # Errors
80 /// Returns an error if the sleep interval cannot be converted, or if there's a problem
81 /// processing the work items
82 pub async fn run(&self) -> Result<()> {
83 tracing::debug!("RefreshTokensTask started");
84
85 let interval = self.config.sleep_interval.to_std()?;
86
87 let sleeper = sleep(interval);
88 tokio::pin!(sleeper);
89
90 loop {
91 tokio::select! {
92 () = self.cancellation_token.cancelled() => {
93 break;
94 },
95 () = &mut sleeper => {
96 if let Err(err) = self.process_work().await {
97 tracing::error!("RefreshTokensTask failed: {}", err);
98 }
99 sleeper.as_mut().reset(Instant::now() + interval);
100 }
101 }
102 }
103
104 tracing::info!("RefreshTokensTask stopped");
105
106 Ok(())
107 }
108
109 async fn process_work(&self) -> Result<i32> {
110 let worker_queue = build_worker_queue(&self.config.worker_id);
111
112 let mut conn = self.cache_pool.get().await?;
113
114 let now = chrono::Utc::now();
115 let epoch_millis = now.timestamp_millis();
116
117 let _: () = conn
118 .hset(
119 OAUTH_REFRESH_HEARTBEATS,
120 &self.config.worker_id,
121 now.to_string(),
122 )
123 .await?;
124
125 let global_queue_count: i32 = conn
126 .zcount(OAUTH_REFRESH_QUEUE, 0, epoch_millis + 1)
127 .await?;
128 let worker_queue_count: i32 = conn.zcount(&worker_queue, 0, epoch_millis + 1).await?;
129
130 tracing::trace!(
131 global_queue_count = global_queue_count,
132 worker_queue_count = worker_queue_count,
133 "queue counts"
134 );
135
136 let mut process_work = worker_queue_count > 0;
137
138 if global_queue_count > 0 && worker_queue_count == 0 {
139 let (moved, new_count): (i64, i64) = pipe()
140 .atomic()
141 // Take some work from the global queue and put it in the worker queue
142 // ZRANGESTORE dst src min max [BYSCORE | BYLEX] [REV] [LIMIT offset count]
143 .cmd("ZRANGESTORE")
144 .arg(&worker_queue)
145 .arg(OAUTH_REFRESH_QUEUE)
146 .arg(0)
147 .arg(epoch_millis)
148 .arg("BYSCORE")
149 .arg("LIMIT")
150 .arg(0)
151 .arg(5)
152 // Update the global queue to remove the items that were moved
153 .cmd("ZDIFFSTORE")
154 .arg(OAUTH_REFRESH_QUEUE)
155 .arg(2)
156 .arg(OAUTH_REFRESH_QUEUE)
157 .arg(&worker_queue)
158 .query_async(&mut conn)
159 .await?;
160 process_work = true;
161
162 tracing::debug!(
163 moved = moved,
164 new_count = new_count,
165 "moved work from global queue to worker queue"
166 );
167 }
168
169 if !process_work {
170 return Ok(0);
171 }
172
173 let count = 0;
174 let results: Vec<(String, i64)> = conn
175 .zrangebyscore_limit_withscores(&worker_queue, 0, epoch_millis, 0, 5)
176 .await?;
177
178 for (session_group, deadline) in results {
179 tracing::info!(session_group, deadline, "processing work");
180 let _: () = conn.zrem(&worker_queue, &session_group).await?;
181
182 if let Err(err) = self
183 .refresh_oauth_session(&mut conn, &session_group, deadline)
184 .await
185 {
186 tracing::error!(session_group, deadline, err = ?err, "failed to refresh oauth session: {}", err);
187
188 if let Err(err) = oauth_session_delete_by_id(&self.storage_pool, &session_group).await {
189 tracing::error!(session_group, err = ?err, "failed to delete oauth session: {}", err);
190 }
191 }
192 }
193
194 Ok(count)
195 }
196
197 async fn refresh_oauth_session(
198 &self,
199 conn: &mut deadpool_redis::Connection,
200 session_group: &str,
201 _deadline: i64,
202 ) -> Result<()> {
203 let (_handle, oauth_session) =
204 web_session_lookup(&self.storage_pool, session_group, None).await?;
205
206 // Use a signing key from the available OAuth signing keys
207 let (_, _secret_signing_key) = self.config.select_oauth_signing_key()?;
208
209 // Simplified DPoP key extraction - skip complex error handling for now
210 let _dpop_secret_key_opt: Option<SecretKey> = oauth_session
211 .dpop_jwk
212 .as_ref()
213 .and_then(|sqlx_jwk: &sqlx::types::Json<WrappedJsonWebKey>| {
214 let jwk = &sqlx_jwk.0.jwk;
215 jwk.d.as_ref().and_then(|d_b64u| {
216 base64_url::decode(d_b64u).ok()
217 .and_then(|d_bytes| SecretKey::from_slice(&d_bytes).ok())
218 })
219 });
220
221 // For now, create a simplified token response that matches TokenSet
222 // This would need to be replaced with actual OAuth refresh call
223 let token_response = crate::oauth::model::TokenResponse {
224 access_token: oauth_session.access_token.clone(),
225 token_type: "Bearer".to_string(),
226 refresh_token: oauth_session.refresh_token.clone(),
227 scope: "".to_string(),
228 expires_in: 3600,
229 sub: oauth_session.did.clone(),
230 };
231
232 let token_set = TokenSet {
233 access_token: token_response.access_token.clone(),
234 refresh_token: token_response.refresh_token.clone(),
235 expires_in: token_response.expires_in,
236 token_type: token_response.token_type.clone(),
237 scope: token_response.scope.clone(),
238 sub: token_response.sub.clone(),
239 };
240
241 // Update the session with the new tokens
242 let now = chrono::Utc::now();
243 let expires_at = now + chrono::Duration::seconds(token_set.expires_in as i64);
244
245 oauth_session_update_tokens(
246 &self.storage_pool,
247 session_group,
248 &token_set.access_token,
249 &token_set.refresh_token,
250 expires_at,
251 )
252 .await?;
253
254 let modified_expires_at = ((token_set.expires_in as f64) * 0.8).round() as i64;
255 let refresh_at = (now + chrono::Duration::seconds(modified_expires_at)).timestamp_millis();
256
257 let _: () = conn
258 .zadd(OAUTH_REFRESH_QUEUE, session_group, refresh_at)
259 .await
260 .map_err(RefreshError::PlaceInRefreshQueueFailed)?;
261
262 Ok(())
263 }
264}
265
266fn build_worker_queue(worker_id: &str) -> String {
267 format!("{}:{}", OAUTH_REFRESH_QUEUE, worker_id)
268}