use std::time::Duration;
use rand::RngExt;
use reqwest::StatusCode;
use serde::{Deserialize, Deserializer, Serializer};

/// outcome of [`RetryWithBackoff::retry`] when the operation does not succeed.
pub enum RetryOutcome<E> {
    /// ratelimited after exhausting all retries
    Ratelimited,
    /// non-ratelimit failure, carrying the last error
    Failed(E),
}

/// extension trait that adds `.retry()` to async `FnMut` closures.
///
/// `on_ratelimit` receives the error and current attempt number.
/// returning `Some(duration)` signals a transient failure and provides the backoff;
/// returning `None` signals a terminal failure.
pub trait RetryWithBackoff<T, E, Fut>: FnMut() -> Fut
where
    Fut: Future<Output = Result<T, E>>,
{
    /// calls the closure until it succeeds, fails terminally, or runs out of retries.
    ///
    /// the attempt number passed to `on_ratelimit` starts at `0`. when every failure is
    /// transient the closure is invoked at most `max_retries + 1` times, sleeping for a
    /// jittered duration in `[backoff / 2, backoff]` between invocations.
    ///
    /// # Errors
    ///
    /// - [`RetryOutcome::Ratelimited`] when a transient failure occurs and `max_retries`
    ///   retries were already performed; the underlying error is dropped.
    /// - [`RetryOutcome::Failed`] with the error as soon as `on_ratelimit` returns `None`.
    // silences the lint about `async fn` in public traits
    #[allow(async_fn_in_trait)]
    async fn retry(
        &mut self,
        max_retries: u32,
        on_ratelimit: impl Fn(&E, u32) -> Option<Duration>,
    ) -> Result<T, RetryOutcome<E>> {
        let mut attempt = 0u32;
        loop {
            let err = match self().await {
                Ok(val) => return Ok(val),
                Err(err) => err,
            };

            match on_ratelimit(&err, attempt) {
                // terminal failure: hand the error back to the caller
                None => return Err(RetryOutcome::Failed(err)),
                // transient failure, but the retry budget is spent
                Some(_) if attempt >= max_retries => return Err(RetryOutcome::Ratelimited),
                Some(backoff) => {
                    // jitter the backoff. the range is inclusive so that a zero backoff
                    // still gives a non-empty range; sampling an empty range panics.
                    let delay = rand::rng().random_range((backoff / 2)..=backoff);
                    tokio::time::sleep(delay).await;
                    attempt += 1;
                }
            }
        }
    }
}

impl<T, E, F, Fut> RetryWithBackoff<T, E, Fut> for F
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
}

/// extension trait that adds `.error_for_status()` to futures returning a reqwest `Response`.
pub trait ErrorForStatus: Future<Output = Result<reqwest::Response, reqwest::Error>> {
    /// maps the future's output through `Response::error_for_status`, leaving an
    /// already-failed result untouched.
    fn error_for_status(self) -> impl Future<Output = Result<reqwest::Response, reqwest::Error>>
    where
        Self: Sized,
    {
        futures::FutureExt::map(self, |res| res.and_then(|resp| resp.error_for_status()))
    }
}

impl<F: Future<Output = Result<reqwest::Response, reqwest::Error>>> ErrorForStatus for F {}

/// extracts a retry delay in seconds from rate limit response headers.
/// /// checks in priority order: /// - `retry-after: ` (relative) /// - `ratelimit-reset: ` (absolute) (ref pds sends this) pub fn parse_retry_after(resp: &reqwest::Response) -> Option { let headers = resp.headers(); let retry_after = headers .get(reqwest::header::RETRY_AFTER) .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()); let rate_limit_reset = headers .get("ratelimit-reset") .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()) .map(|ts| { let now = chrono::Utc::now().timestamp(); (ts - now).max(1) as u64 }); retry_after.or(rate_limit_reset) } // cloudflare-specific status codes pub const CONNECTION_TIMEOUT: StatusCode = unsafe { match StatusCode::from_u16(522) { Ok(s) => s, _ => std::hint::unreachable_unchecked(), } }; pub const SITE_FROZEN: StatusCode = unsafe { match StatusCode::from_u16(530) { Ok(s) => s, _ => std::hint::unreachable_unchecked(), } }; pub fn ser_status_code(s: &Option, ser: S) -> Result { match s { Some(code) => ser.serialize_some(&code.as_u16()), None => ser.serialize_none(), } } pub fn deser_status_code<'de, D: Deserializer<'de>>( deser: D, ) -> Result, D::Error> { Option::::deserialize(deser)? .map(StatusCode::from_u16) .transpose() .map_err(serde::de::Error::custom) }