···1313pub mod storage;
1414pub mod task_identity_refresh;
1515pub mod task_oauth_requests_cleanup;
1616-pub mod task_refresh_tokens;
1716pub mod task_search_indexer;
1817pub mod task_search_indexer_errors;
1918pub mod task_webhooks;
-244
src/task_refresh_tokens.rs
···11-use anyhow::Result;
22-use atproto_identity::key::identify_key;
33-use atproto_oauth::workflow::{OAuthClient, oauth_refresh};
44-use chrono::{Duration, Utc};
55-use deadpool_redis::redis::{AsyncCommands, pipe};
66-use std::borrow::Cow;
77-use tokio::time::{Instant, sleep};
88-use tokio_util::sync::CancellationToken;
99-1010-use crate::{
1111- config::SigningKeys,
1212- refresh_tokens_errors::RefreshError,
1313- storage::{
1414- CachePool, StoragePool,
1515- cache::{OAUTH_REFRESH_HEARTBEATS, OAUTH_REFRESH_QUEUE, build_worker_queue},
1616- oauth::{oauth_session_delete, oauth_session_update, web_session_lookup},
1717- },
1818-};
1919-2020-pub struct RefreshTokensTaskConfig {
2121- pub sleep_interval: Duration,
2222- pub worker_id: String,
2323- pub external_url_base: String,
2424- pub signing_keys: SigningKeys,
2525-}
2626-2727-pub struct RefreshTokensTask {
2828- pub config: RefreshTokensTaskConfig,
2929- pub http_client: reqwest::Client,
3030- pub storage_pool: StoragePool,
3131- pub cache_pool: CachePool,
3232- pub document_storage: std::sync::Arc<dyn atproto_identity::storage::DidDocumentStorage>,
3333- pub cancellation_token: CancellationToken,
3434-}
3535-3636-impl RefreshTokensTask {
3737- #[must_use]
3838- pub fn new(
3939- config: RefreshTokensTaskConfig,
4040- http_client: reqwest::Client,
4141- storage_pool: StoragePool,
4242- cache_pool: CachePool,
4343- document_storage: std::sync::Arc<dyn atproto_identity::storage::DidDocumentStorage>,
4444- cancellation_token: CancellationToken,
4545- ) -> Self {
4646- Self {
4747- config,
4848- http_client,
4949- storage_pool,
5050- cache_pool,
5151- document_storage,
5252- cancellation_token,
5353- }
5454- }
5555-5656- /// Runs the refresh tokens task as a long-running process
5757- ///
5858- /// # Errors
5959- /// Returns an error if the sleep interval cannot be converted, or if there's a problem
6060- /// processing the work items
6161- pub async fn run(&self) -> Result<()> {
6262- tracing::debug!("RefreshTokensTask started");
6363-6464- let interval = self.config.sleep_interval.to_std()?;
6565-6666- let sleeper = sleep(interval);
6767- tokio::pin!(sleeper);
6868-6969- loop {
7070- tokio::select! {
7171- () = self.cancellation_token.cancelled() => {
7272- break;
7373- },
7474- () = &mut sleeper => {
7575- if let Err(err) = self.process_work().await {
7676- tracing::error!("RefreshTokensTask failed: {}", err);
7777- }
7878- sleeper.as_mut().reset(Instant::now() + interval);
7979- }
8080- }
8181- }
8282-8383- tracing::info!("RefreshTokensTask stopped");
8484-8585- Ok(())
8686- }
8787-8888- async fn process_work(&self) -> Result<i32> {
8989- let worker_queue = build_worker_queue(&self.config.worker_id);
9090-9191- let mut conn = self.cache_pool.get().await?;
9292-9393- let now = chrono::Utc::now();
9494- let epoch_millis = now.timestamp_millis();
9595-9696- let _: () = conn
9797- .hset(
9898- OAUTH_REFRESH_HEARTBEATS,
9999- &self.config.worker_id,
100100- now.to_string(),
101101- )
102102- .await?;
103103-104104- let global_queue_count: i32 = conn
105105- .zcount(OAUTH_REFRESH_QUEUE, 0, epoch_millis + 1)
106106- .await?;
107107- let worker_queue_count: i32 = conn.zcount(&worker_queue, 0, epoch_millis + 1).await?;
108108-109109- tracing::trace!(
110110- global_queue_count = global_queue_count,
111111- worker_queue_count = worker_queue_count,
112112- "queue counts"
113113- );
114114-115115- let mut process_work = worker_queue_count > 0;
116116-117117- if global_queue_count > 0 && worker_queue_count == 0 {
118118- let (moved, new_count): (i64, i64) = pipe()
119119- .atomic()
120120- // Take some work from the global queue and put it in the worker queue
121121- // ZRANGESTORE dst src min max [BYSCORE | BYLEX] [REV] [LIMIT offset count]
122122- .cmd("ZRANGESTORE")
123123- .arg(&worker_queue)
124124- .arg(OAUTH_REFRESH_QUEUE)
125125- .arg(0)
126126- .arg(epoch_millis)
127127- .arg("BYSCORE")
128128- .arg("LIMIT")
129129- .arg(0)
130130- .arg(5)
131131- // Update the global queue to remove the items that were moved
132132- .cmd("ZDIFFSTORE")
133133- .arg(OAUTH_REFRESH_QUEUE)
134134- .arg(2)
135135- .arg(OAUTH_REFRESH_QUEUE)
136136- .arg(&worker_queue)
137137- .query_async(&mut conn)
138138- .await?;
139139- process_work = true;
140140-141141- tracing::debug!(
142142- moved = moved,
143143- new_count = new_count,
144144- "moved work from global queue to worker queue"
145145- );
146146- }
147147-148148- if !process_work {
149149- return Ok(0);
150150- }
151151-152152- let count = 0;
153153- let results: Vec<(String, i64)> = conn
154154- .zrangebyscore_limit_withscores(&worker_queue, 0, epoch_millis, 0, 5)
155155- .await?;
156156-157157- for (session_group, deadline) in results {
158158- let _: () = conn.zrem(&worker_queue, &session_group).await?;
159159-160160- if let Err(err) = self
161161- .refresh_oauth_session(&mut conn, &session_group, deadline)
162162- .await
163163- {
164164- tracing::error!(session_group, deadline, err = ?err, "failed to refresh oauth session: {}", err);
165165-166166- if let Err(err) = oauth_session_delete(&self.storage_pool, &session_group).await {
167167- tracing::error!(session_group, err = ?err, "failed to delete oauth session: {}", err);
168168- }
169169- }
170170- }
171171-172172- Ok(count)
173173- }
174174-175175- async fn refresh_oauth_session(
176176- &self,
177177- conn: &mut deadpool_redis::Connection,
178178- session_group: &str,
179179- _deadline: i64,
180180- ) -> Result<()> {
181181- let (handle, oauth_session) =
182182- web_session_lookup(&self.storage_pool, session_group, None).await?;
183183-184184- let secret_signing_key_string = self
185185- .config
186186- .signing_keys
187187- .as_ref()
188188- .get(&oauth_session.secret_jwk_id)
189189- .cloned()
190190- .ok_or_else(|| anyhow::Error::from(RefreshError::SecretSigningKeyNotFound))?;
191191-192192- let private_signing_key_data = identify_key(&secret_signing_key_string)?;
193193-194194- let private_dpop_key_data = identify_key(&oauth_session.dpop_jwk)?;
195195-196196- let document = match self
197197- .document_storage
198198- .get_document_by_did(&handle.did)
199199- .await?
200200- {
201201- Some(doc) => doc,
202202- None => return Err(RefreshError::IdentityDocumentNotFound.into()),
203203- };
204204-205205- let oauth_client = OAuthClient {
206206- redirect_uri: format!("https://{}/oauth/callback", self.config.external_url_base),
207207- client_id: format!(
208208- "https://{}/oauth/client-metadata.json",
209209- self.config.external_url_base
210210- ),
211211- private_signing_key_data,
212212- };
213213-214214- let token_response = oauth_refresh(
215215- &self.http_client,
216216- &oauth_client,
217217- &private_dpop_key_data,
218218- &oauth_session.refresh_token,
219219- &document,
220220- )
221221- .await?;
222222-223223- let now = Utc::now();
224224-225225- oauth_session_update(
226226- &self.storage_pool,
227227- Cow::Borrowed(session_group),
228228- Cow::Borrowed(&token_response.access_token),
229229- Cow::Borrowed(&token_response.refresh_token),
230230- now + chrono::Duration::seconds(i64::from(token_response.expires_in)),
231231- )
232232- .await?;
233233-234234- let modified_expires_at = ((f64::from(token_response.expires_in)) * 0.8).round() as i64;
235235- let refresh_at = (now + chrono::Duration::seconds(modified_expires_at)).timestamp_millis();
236236-237237- let _: () = conn
238238- .zadd(OAUTH_REFRESH_QUEUE, session_group, refresh_at)
239239- .await
240240- .map_err(RefreshError::PlaceInRefreshQueueFailed)?;
241241-242242- Ok(())
243243- }
244244-}