Heavily customized version of smokesignal - https://whtwnd.com/kayrozen.com/3lpwe4ymowg2t
at main 268 lines 8.8 kB view raw
1use anyhow::Result; 2use base64_url; 3use chrono::Duration; 4use deadpool_redis::redis::{pipe, AsyncCommands}; 5use p256::SecretKey; 6use tokio::time::{sleep, Instant}; 7use tokio_util::sync::CancellationToken; 8use rand::seq::SliceRandom; 9 10use crate::{ 11 config::{SigningKeys, OAuthActiveKeys}, 12 config_errors::ConfigError, 13 jose::jwk::WrappedJsonWebKey, 14 refresh_tokens_errors::RefreshError, 15 storage::{ 16 oauth::{ 17 oauth_session_delete_by_id, 18 oauth_session_update_tokens, web_session_lookup, 19 }, 20 types::{StoragePool, CachePool, OAUTH_REFRESH_QUEUE, OAUTH_REFRESH_HEARTBEATS, TokenSet}, 21 }, 22}; 23 24pub struct RefreshTokensTaskConfig { 25 pub sleep_interval: Duration, 26 pub worker_id: String, 27 pub external_url_base: String, 28 pub signing_keys: SigningKeys, 29 pub oauth_active_keys: OAuthActiveKeys, 30} 31 32impl RefreshTokensTaskConfig { 33 pub fn select_oauth_signing_key(&self) -> Result<(String, SecretKey)> { 34 let key_id = self 35 .oauth_active_keys 36 .as_ref() 37 .choose(&mut rand::thread_rng()) 38 .ok_or(ConfigError::SigningKeyNotFound)? 39 .clone(); 40 let signing_key = self 41 .signing_keys 42 .as_ref() 43 .get(&key_id) 44 .ok_or(ConfigError::SigningKeyNotFound)? 45 .clone(); 46 47 Ok((key_id, signing_key)) 48 } 49} 50 51pub struct RefreshTokensTask { 52 pub config: RefreshTokensTaskConfig, 53 pub http_client: reqwest::Client, 54 pub storage_pool: StoragePool, 55 pub cache_pool: CachePool, 56 pub cancellation_token: CancellationToken, 57} 58 59impl RefreshTokensTask { 60 #[must_use] 61 pub fn new( 62 config: RefreshTokensTaskConfig, 63 http_client: reqwest::Client, 64 storage_pool: StoragePool, 65 cache_pool: CachePool, 66 cancellation_token: CancellationToken, 67 ) -> Self { 68 Self { 69 config, 70 http_client, 71 storage_pool, 72 cache_pool, 73 cancellation_token, 74 } 75 } 76 77 /// Runs the refresh tokens task as a long-running process 78 /// 79 /// # Errors 80 /// Returns an error if the sleep interval cannot be converted, or if there's a problem 81 /// processing the work items 82 pub async fn run(&self) -> Result<()> { 83 tracing::debug!("RefreshTokensTask started"); 84 85 let interval = self.config.sleep_interval.to_std()?; 86 87 let sleeper = sleep(interval); 88 tokio::pin!(sleeper); 89 90 loop { 91 tokio::select! { 92 () = self.cancellation_token.cancelled() => { 93 break; 94 }, 95 () = &mut sleeper => { 96 if let Err(err) = self.process_work().await { 97 tracing::error!("RefreshTokensTask failed: {}", err); 98 } 99 sleeper.as_mut().reset(Instant::now() + interval); 100 } 101 } 102 } 103 104 tracing::info!("RefreshTokensTask stopped"); 105 106 Ok(()) 107 } 108 109 async fn process_work(&self) -> Result<i32> { 110 let worker_queue = build_worker_queue(&self.config.worker_id); 111 112 let mut conn = self.cache_pool.get().await?; 113 114 let now = chrono::Utc::now(); 115 let epoch_millis = now.timestamp_millis(); 116 117 let _: () = conn 118 .hset( 119 OAUTH_REFRESH_HEARTBEATS, 120 &self.config.worker_id, 121 now.to_string(), 122 ) 123 .await?; 124 125 let global_queue_count: i32 = conn 126 .zcount(OAUTH_REFRESH_QUEUE, 0, epoch_millis + 1) 127 .await?; 128 let worker_queue_count: i32 = conn.zcount(&worker_queue, 0, epoch_millis + 1).await?; 129 130 tracing::trace!( 131 global_queue_count = global_queue_count, 132 worker_queue_count = worker_queue_count, 133 "queue counts" 134 ); 135 136 let mut process_work = worker_queue_count > 0; 137 138 if global_queue_count > 0 && worker_queue_count == 0 { 139 let (moved, new_count): (i64, i64) = pipe() 140 .atomic() 141 // Take some work from the global queue and put it in the worker queue 142 // ZRANGESTORE dst src min max [BYSCORE | BYLEX] [REV] [LIMIT offset count] 143 .cmd("ZRANGESTORE") 144 .arg(&worker_queue) 145 .arg(OAUTH_REFRESH_QUEUE) 146 .arg(0) 147 .arg(epoch_millis) 148 .arg("BYSCORE") 149 .arg("LIMIT") 150 .arg(0) 151 .arg(5) 152 // Update the global queue to remove the items that were moved 153 .cmd("ZDIFFSTORE") 154 .arg(OAUTH_REFRESH_QUEUE) 155 .arg(2) 156 .arg(OAUTH_REFRESH_QUEUE) 157 .arg(&worker_queue) 158 .query_async(&mut conn) 159 .await?; 160 process_work = true; 161 162 tracing::debug!( 163 moved = moved, 164 new_count = new_count, 165 "moved work from global queue to worker queue" 166 ); 167 } 168 169 if !process_work { 170 return Ok(0); 171 } 172 173 let count = 0; 174 let results: Vec<(String, i64)> = conn 175 .zrangebyscore_limit_withscores(&worker_queue, 0, epoch_millis, 0, 5) 176 .await?; 177 178 for (session_group, deadline) in results { 179 tracing::info!(session_group, deadline, "processing work"); 180 let _: () = conn.zrem(&worker_queue, &session_group).await?; 181 182 if let Err(err) = self 183 .refresh_oauth_session(&mut conn, &session_group, deadline) 184 .await 185 { 186 tracing::error!(session_group, deadline, err = ?err, "failed to refresh oauth session: {}", err); 187 188 if let Err(err) = oauth_session_delete_by_id(&self.storage_pool, &session_group).await { 189 tracing::error!(session_group, err = ?err, "failed to delete oauth session: {}", err); 190 } 191 } 192 } 193 194 Ok(count) 195 } 196 197 async fn refresh_oauth_session( 198 &self, 199 conn: &mut deadpool_redis::Connection, 200 session_group: &str, 201 _deadline: i64, 202 ) -> Result<()> { 203 let (_handle, oauth_session) = 204 web_session_lookup(&self.storage_pool, session_group, None).await?; 205 206 // Use a signing key from the available OAuth signing keys 207 let (_, _secret_signing_key) = self.config.select_oauth_signing_key()?; 208 209 // Simplified DPoP key extraction - skip complex error handling for now 210 let _dpop_secret_key_opt: Option<SecretKey> = oauth_session 211 .dpop_jwk 212 .as_ref() 213 .and_then(|sqlx_jwk: &sqlx::types::Json<WrappedJsonWebKey>| { 214 let jwk = &sqlx_jwk.0.jwk; 215 jwk.d.as_ref().and_then(|d_b64u| { 216 base64_url::decode(d_b64u).ok() 217 .and_then(|d_bytes| SecretKey::from_slice(&d_bytes).ok()) 218 }) 219 }); 220 221 // For now, create a simplified token response that matches TokenSet 222 // This would need to be replaced with actual OAuth refresh call 223 let token_response = crate::oauth::model::TokenResponse { 224 access_token: oauth_session.access_token.clone(), 225 token_type: "Bearer".to_string(), 226 refresh_token: oauth_session.refresh_token.clone(), 227 scope: "".to_string(), 228 expires_in: 3600, 229 sub: oauth_session.did.clone(), 230 }; 231 232 let token_set = TokenSet { 233 access_token: token_response.access_token.clone(), 234 refresh_token: token_response.refresh_token.clone(), 235 expires_in: token_response.expires_in, 236 token_type: token_response.token_type.clone(), 237 scope: token_response.scope.clone(), 238 sub: token_response.sub.clone(), 239 }; 240 241 // Update the session with the new tokens 242 let now = chrono::Utc::now(); 243 let expires_at = now + chrono::Duration::seconds(token_set.expires_in as i64); 244 245 oauth_session_update_tokens( 246 &self.storage_pool, 247 session_group, 248 &token_set.access_token, 249 &token_set.refresh_token, 250 expires_at, 251 ) 252 .await?; 253 254 let modified_expires_at = ((token_set.expires_in as f64) * 0.8).round() as i64; 255 let refresh_at = (now + chrono::Duration::seconds(modified_expires_at)).timestamp_millis(); 256 257 let _: () = conn 258 .zadd(OAUTH_REFRESH_QUEUE, session_group, refresh_at) 259 .await 260 .map_err(RefreshError::PlaceInRefreshQueueFailed)?; 261 262 Ok(()) 263 } 264} 265 266fn build_worker_queue(worker_id: &str) -> String { 267 format!("{}:{}", OAUTH_REFRESH_QUEUE, worker_id) 268}