use anyhow::Result; use base64_url; use chrono::Duration; use deadpool_redis::redis::{pipe, AsyncCommands}; use p256::SecretKey; use tokio::time::{sleep, Instant}; use tokio_util::sync::CancellationToken; use rand::seq::SliceRandom; use crate::{ config::{SigningKeys, OAuthActiveKeys}, config_errors::ConfigError, jose::jwk::WrappedJsonWebKey, refresh_tokens_errors::RefreshError, storage::{ oauth::{ oauth_session_delete_by_id, oauth_session_update_tokens, web_session_lookup, }, types::{StoragePool, CachePool, OAUTH_REFRESH_QUEUE, OAUTH_REFRESH_HEARTBEATS, TokenSet}, }, }; pub struct RefreshTokensTaskConfig { pub sleep_interval: Duration, pub worker_id: String, pub external_url_base: String, pub signing_keys: SigningKeys, pub oauth_active_keys: OAuthActiveKeys, } impl RefreshTokensTaskConfig { pub fn select_oauth_signing_key(&self) -> Result<(String, SecretKey)> { let key_id = self .oauth_active_keys .as_ref() .choose(&mut rand::thread_rng()) .ok_or(ConfigError::SigningKeyNotFound)? .clone(); let signing_key = self .signing_keys .as_ref() .get(&key_id) .ok_or(ConfigError::SigningKeyNotFound)? .clone(); Ok((key_id, signing_key)) } } pub struct RefreshTokensTask { pub config: RefreshTokensTaskConfig, pub http_client: reqwest::Client, pub storage_pool: StoragePool, pub cache_pool: CachePool, pub cancellation_token: CancellationToken, } impl RefreshTokensTask { #[must_use] pub fn new( config: RefreshTokensTaskConfig, http_client: reqwest::Client, storage_pool: StoragePool, cache_pool: CachePool, cancellation_token: CancellationToken, ) -> Self { Self { config, http_client, storage_pool, cache_pool, cancellation_token, } } /// Runs the refresh tokens task as a long-running process /// /// # Errors /// Returns an error if the sleep interval cannot be converted, or if there's a problem /// processing the work items pub async fn run(&self) -> Result<()> { tracing::debug!("RefreshTokensTask started"); let interval = self.config.sleep_interval.to_std()?; let sleeper = sleep(interval); tokio::pin!(sleeper); loop { tokio::select! { () = self.cancellation_token.cancelled() => { break; }, () = &mut sleeper => { if let Err(err) = self.process_work().await { tracing::error!("RefreshTokensTask failed: {}", err); } sleeper.as_mut().reset(Instant::now() + interval); } } } tracing::info!("RefreshTokensTask stopped"); Ok(()) } async fn process_work(&self) -> Result { let worker_queue = build_worker_queue(&self.config.worker_id); let mut conn = self.cache_pool.get().await?; let now = chrono::Utc::now(); let epoch_millis = now.timestamp_millis(); let _: () = conn .hset( OAUTH_REFRESH_HEARTBEATS, &self.config.worker_id, now.to_string(), ) .await?; let global_queue_count: i32 = conn .zcount(OAUTH_REFRESH_QUEUE, 0, epoch_millis + 1) .await?; let worker_queue_count: i32 = conn.zcount(&worker_queue, 0, epoch_millis + 1).await?; tracing::trace!( global_queue_count = global_queue_count, worker_queue_count = worker_queue_count, "queue counts" ); let mut process_work = worker_queue_count > 0; if global_queue_count > 0 && worker_queue_count == 0 { let (moved, new_count): (i64, i64) = pipe() .atomic() // Take some work from the global queue and put it in the worker queue // ZRANGESTORE dst src min max [BYSCORE | BYLEX] [REV] [LIMIT offset count] .cmd("ZRANGESTORE") .arg(&worker_queue) .arg(OAUTH_REFRESH_QUEUE) .arg(0) .arg(epoch_millis) .arg("BYSCORE") .arg("LIMIT") .arg(0) .arg(5) // Update the global queue to remove the items that were moved .cmd("ZDIFFSTORE") .arg(OAUTH_REFRESH_QUEUE) .arg(2) .arg(OAUTH_REFRESH_QUEUE) .arg(&worker_queue) .query_async(&mut conn) .await?; process_work = true; tracing::debug!( moved = moved, new_count = new_count, "moved work from global queue to worker queue" ); } if !process_work { return Ok(0); } let count = 0; let results: Vec<(String, i64)> = conn .zrangebyscore_limit_withscores(&worker_queue, 0, epoch_millis, 0, 5) .await?; for (session_group, deadline) in results { tracing::info!(session_group, deadline, "processing work"); let _: () = conn.zrem(&worker_queue, &session_group).await?; if let Err(err) = self .refresh_oauth_session(&mut conn, &session_group, deadline) .await { tracing::error!(session_group, deadline, err = ?err, "failed to refresh oauth session: {}", err); if let Err(err) = oauth_session_delete_by_id(&self.storage_pool, &session_group).await { tracing::error!(session_group, err = ?err, "failed to delete oauth session: {}", err); } } } Ok(count) } async fn refresh_oauth_session( &self, conn: &mut deadpool_redis::Connection, session_group: &str, _deadline: i64, ) -> Result<()> { let (_handle, oauth_session) = web_session_lookup(&self.storage_pool, session_group, None).await?; // Use a signing key from the available OAuth signing keys let (_, _secret_signing_key) = self.config.select_oauth_signing_key()?; // Simplified DPoP key extraction - skip complex error handling for now let _dpop_secret_key_opt: Option = oauth_session .dpop_jwk .as_ref() .and_then(|sqlx_jwk: &sqlx::types::Json| { let jwk = &sqlx_jwk.0.jwk; jwk.d.as_ref().and_then(|d_b64u| { base64_url::decode(d_b64u).ok() .and_then(|d_bytes| SecretKey::from_slice(&d_bytes).ok()) }) }); // For now, create a simplified token response that matches TokenSet // This would need to be replaced with actual OAuth refresh call let token_response = crate::oauth::model::TokenResponse { access_token: oauth_session.access_token.clone(), token_type: "Bearer".to_string(), refresh_token: oauth_session.refresh_token.clone(), scope: "".to_string(), expires_in: 3600, sub: oauth_session.did.clone(), }; let token_set = TokenSet { access_token: token_response.access_token.clone(), refresh_token: token_response.refresh_token.clone(), expires_in: token_response.expires_in, token_type: token_response.token_type.clone(), scope: token_response.scope.clone(), sub: token_response.sub.clone(), }; // Update the session with the new tokens let now = chrono::Utc::now(); let expires_at = now + chrono::Duration::seconds(token_set.expires_in as i64); oauth_session_update_tokens( &self.storage_pool, session_group, &token_set.access_token, &token_set.refresh_token, expires_at, ) .await?; let modified_expires_at = ((token_set.expires_in as f64) * 0.8).round() as i64; let refresh_at = (now + chrono::Duration::seconds(modified_expires_at)).timestamp_millis(); let _: () = conn .zadd(OAUTH_REFRESH_QUEUE, session_group, refresh_at) .await .map_err(RefreshError::PlaceInRefreshQueueFailed)?; Ok(()) } } fn build_worker_queue(worker_id: &str) -> String { format!("{}:{}", OAUTH_REFRESH_QUEUE, worker_id) }