Server tools to backfill, tail, mirror, and verify PLC logs

experimental: forward writes to upstream

+259 -67
+31 -3
src/bin/mirror.rs
··· 1 - use allegedly::{Db, ListenConf, bin::GlobalArgs, bin_init, pages_to_pg, poll_upstream, serve}; 2 use clap::Parser; 3 use reqwest::Url; 4 use std::{net::SocketAddr, path::PathBuf}; ··· 39 #[arg(long, requires("acme_domain"), env = "ALLEGEDLY_ACME_DIRECTORY_URL")] 40 #[clap(default_value = "https://acme-v02.api.letsencrypt.org/directory")] 41 acme_directory_url: Url, 42 - /// listen for ipv6 43 #[arg(long, action, requires("acme_domain"), env = "ALLEGEDLY_ACME_IPV6")] 44 acme_ipv6: bool, 45 } 46 47 pub async fn run( ··· 55 acme_cache_path, 56 acme_directory_url, 57 acme_ipv6, 58 }: Args, 59 ) -> anyhow::Result<()> { 60 let db = Db::new(wrap_pg.as_str(), wrap_pg_cert).await?; ··· 79 } 80 (bind, true, None) => ListenConf::Bind(bind), 81 (_, _, _) => unreachable!(), 82 }; 83 84 let mut tasks = JoinSet::new(); ··· 90 91 tasks.spawn(poll_upstream(Some(latest), poll_url, send_page)); 92 tasks.spawn(pages_to_pg(db.clone(), recv_page)); 93 - tasks.spawn(serve(upstream, wrap, listen_conf, db.clone())); 94 95 while let Some(next) = tasks.join_next().await { 96 match next {
··· 1 + use allegedly::{ 2 + Db, ExperimentalConf, ListenConf, bin::GlobalArgs, bin_init, pages_to_pg, poll_upstream, serve, 3 + }; 4 use clap::Parser; 5 use reqwest::Url; 6 use std::{net::SocketAddr, path::PathBuf}; ··· 41 #[arg(long, requires("acme_domain"), env = "ALLEGEDLY_ACME_DIRECTORY_URL")] 42 #[clap(default_value = "https://acme-v02.api.letsencrypt.org/directory")] 43 acme_directory_url: Url, 44 + /// try to listen for ipv6 45 #[arg(long, action, requires("acme_domain"), env = "ALLEGEDLY_ACME_IPV6")] 46 acme_ipv6: bool, 47 + /// only accept experimental requests at this hostname 48 + /// 49 + /// a cert will be provisioned for it from letsencrypt. if you're not using 50 + /// acme (eg., behind a tls-terminating reverse proxy), open a feature request. 51 + #[arg( 52 + long, 53 + requires("acme_domain"), 54 + env = "ALLEGEDLY_EXPERIMENTAL_ACME_DOMAIN" 55 + )] 56 + experimental_acme_domain: Option<String>, 57 + /// accept writes! by forwarding them upstream 58 + #[arg(long, action, env = "ALLEGEDLY_EXPERIMENTAL_WRITE_UPSTREAM")] 59 + experimental_write_upstream: bool, 60 } 61 62 pub async fn run( ··· 70 acme_cache_path, 71 acme_directory_url, 72 acme_ipv6, 73 + experimental_acme_domain, 74 + experimental_write_upstream, 75 }: Args, 76 ) -> anyhow::Result<()> { 77 let db = Db::new(wrap_pg.as_str(), wrap_pg_cert).await?; ··· 96 } 97 (bind, true, None) => ListenConf::Bind(bind), 98 (_, _, _) => unreachable!(), 99 + }; 100 + 101 + let experimental_conf = ExperimentalConf { 102 + acme_domain: experimental_acme_domain, 103 + write_upstream: experimental_write_upstream, 104 }; 105 106 let mut tasks = JoinSet::new(); ··· 112 113 tasks.spawn(poll_upstream(Some(latest), poll_url, send_page)); 114 tasks.spawn(pages_to_pg(db.clone(), recv_page)); 115 + tasks.spawn(serve( 116 + upstream, 117 + wrap, 118 + listen_conf, 119 + experimental_conf, 120 + db.clone(), 121 + )); 122 123 while let Some(next) = tasks.join_next().await { 124 match next {
+2 -2
src/lib.rs
··· 15 pub use backfill::backfill; 16 pub use cached_value::{CachedValue, Fetcher}; 17 pub use client::{CLIENT, UA}; 18 - pub use mirror::{ListenConf, serve}; 19 pub use plc_pg::{Db, backfill_to_pg, pages_to_pg}; 20 pub use poll::{PageBoundaryState, get_page, poll_upstream}; 21 - pub use ratelimit::GovernorMiddleware; 22 pub use weekly::{BundleSource, FolderSource, HttpSource, Week, pages_to_weeks, week_to_pages}; 23 24 pub type Dt = chrono::DateTime<chrono::Utc>;
··· 15 pub use backfill::backfill; 16 pub use cached_value::{CachedValue, Fetcher}; 17 pub use client::{CLIENT, UA}; 18 + pub use mirror::{ExperimentalConf, ListenConf, serve}; 19 pub use plc_pg::{Db, backfill_to_pg, pages_to_pg}; 20 pub use poll::{PageBoundaryState, get_page, poll_upstream}; 21 + pub use ratelimit::{CreatePlcOpLimiter, GovernorMiddleware, IpLimiters}; 22 pub use weekly::{BundleSource, FolderSource, HttpSource, Week, pages_to_weeks, week_to_pages}; 23 24 pub type Dt = chrono::DateTime<chrono::Utc>;
+138 -29
src/mirror.rs
··· 1 - use crate::{CachedValue, Db, Dt, Fetcher, GovernorMiddleware, UA, logo}; 2 use futures::TryStreamExt; 3 use governor::Quota; 4 use poem::{ 5 - Endpoint, EndpointExt, Error, IntoResponse, Request, Response, Result, Route, Server, get, 6 - handler, 7 http::StatusCode, 8 listener::{Listener, TcpListener, acme::AutoCert}, 9 middleware::{AddData, CatchPanic, Compression, Cors, Tracing}, 10 - web::{Data, Json}, 11 }; 12 use reqwest::{Client, Url}; 13 use std::{net::SocketAddr, path::PathBuf, time::Duration}; ··· 19 upstream: Url, 20 latest_at: CachedValue<Dt, GetLatestAt>, 21 upstream_status: CachedValue<PlcStatus, CheckUpstream>, 22 } 23 24 #[handler] ··· 69 include_bytes!("../favicon.ico").with_content_type("image/x-icon") 70 } 71 72 - fn failed_to_reach_wrapped() -> String { 73 format!( 74 r#"{} 75 76 - Failed to reach the wrapped reference PLC server. Sorry. 77 "#, 78 logo("mirror 502 :( ") 79 ) 80 } 81 82 type PlcStatus = (bool, serde_json::Value); 83 84 async fn plc_status(url: &Url, client: &Client) -> PlcStatus { ··· 168 ) 169 } 170 171 #[handler] 172 - async fn proxy(req: &Request, Data(state): Data<&State>) -> Result<impl IntoResponse> { 173 let mut target = state.plc.clone(); 174 target.set_path(req.uri().path()); 175 - let upstream_res = state 176 .client 177 .get(target) 178 .timeout(Duration::from_secs(3)) // should be low latency to wrapped server ··· 181 .await 182 .map_err(|e| { 183 log::error!("upstream req fail: {e}"); 184 - Error::from_string(failed_to_reach_wrapped(), StatusCode::BAD_GATEWAY) 185 })?; 186 187 - let http_res: poem::http::Response<reqwest::Body> = upstream_res.into(); 188 - let (parts, reqw_body) = http_res.into_parts(); 189 190 - let parts = poem::ResponseParts { 191 - status: parts.status, 192 - version: parts.version, 193 - headers: parts.headers, 194 - extensions: parts.extensions, 195 - }; 196 197 - let body = http_body_util::BodyDataStream::new(reqw_body) 198 - .map_err(|e| std::io::Error::other(Box::new(e))); 199 200 - Ok(Response::from_parts( 201 - parts, 202 - poem::Body::from_bytes_stream(body), 203 - )) 204 } 205 206 #[handler] ··· 212 213 Sorry, this server does not accept POST requests. 214 215 - You may wish to try upstream: {upstream} 216 "#, 217 logo("mirror (nope)") 218 ), ··· 230 Bind(SocketAddr), 231 } 232 233 pub async fn serve( 234 upstream: Url, 235 plc: Url, 236 listen: ListenConf, 237 db: Db, 238 ) -> anyhow::Result<&'static str> { 239 log::info!("starting server..."); ··· 257 upstream: upstream.clone(), 258 latest_at, 259 upstream_status, 260 }; 261 262 - let app = Route::new() 263 .at("/", get(hello)) 264 .at("/favicon.ico", get(favicon)) 265 - .at("/_health", get(health)) 266 - .at("/:any", get(proxy).post(nope)) 267 .with(AddData::new(state)) 268 .with(Cors::new().allow_credentials(false)) 269 .with(Compression::new()) 270 - .with(GovernorMiddleware::new(Quota::per_minute( 271 3000.try_into().expect("ratelimit middleware to build"), 272 - ))) 273 .with(CatchPanic::new()) 274 .with(Tracing); 275 ··· 288 .directory_url(directory_url) 289 .cache_path(cache_path); 290 for domain in domains { 291 auto_cert = auto_cert.domain(domain); 292 } 293 let auto_cert = auto_cert.build().expect("acme config to build");
··· 1 + use crate::{ 2 + CachedValue, CreatePlcOpLimiter, Db, Dt, Fetcher, GovernorMiddleware, IpLimiters, UA, logo, 3 + }; 4 use futures::TryStreamExt; 5 use governor::Quota; 6 use poem::{ 7 + Body, Endpoint, EndpointExt, Error, IntoResponse, Request, Response, Result, Route, Server, 8 + get, handler, 9 http::StatusCode, 10 listener::{Listener, TcpListener, acme::AutoCert}, 11 middleware::{AddData, CatchPanic, Compression, Cors, Tracing}, 12 + web::{Data, Json, Path}, 13 }; 14 use reqwest::{Client, Url}; 15 use std::{net::SocketAddr, path::PathBuf, time::Duration}; ··· 21 upstream: Url, 22 latest_at: CachedValue<Dt, GetLatestAt>, 23 upstream_status: CachedValue<PlcStatus, CheckUpstream>, 24 + experimental: ExperimentalConf, 25 } 26 27 #[handler] ··· 72 include_bytes!("../favicon.ico").with_content_type("image/x-icon") 73 } 74 75 + fn failed_to_reach_named(name: &str) -> String { 76 format!( 77 r#"{} 78 79 + Failed to reach the {name} server. Sorry. 80 "#, 81 logo("mirror 502 :( ") 82 ) 83 } 84 85 + fn bad_create_op(reason: &str) -> Response { 86 + Response::builder() 87 + .status(StatusCode::BAD_REQUEST) 88 + .body(format!( 89 + r#"{} 90 + 91 + NooOOOooooo: {reason} 92 + "#, 93 + logo("mirror 400 >:( ") 94 + )) 95 + } 96 + 97 type PlcStatus = (bool, serde_json::Value); 98 99 async fn plc_status(url: &Url, client: &Client) -> PlcStatus { ··· 183 ) 184 } 185 186 + fn proxy_response(res: reqwest::Response) -> Response { 187 + let http_res: poem::http::Response<reqwest::Body> = res.into(); 188 + let (parts, reqw_body) = http_res.into_parts(); 189 + 190 + let parts = poem::ResponseParts { 191 + status: parts.status, 192 + version: parts.version, 193 + headers: parts.headers, 194 + extensions: parts.extensions, 195 + }; 196 + 197 + let body = http_body_util::BodyDataStream::new(reqw_body) 198 + .map_err(|e| std::io::Error::other(Box::new(e))); 199 + 200 + Response::from_parts(parts, poem::Body::from_bytes_stream(body)) 201 + } 202 + 203 #[handler] 204 + async fn proxy(req: &Request, Data(state): Data<&State>) -> Result<Response> { 205 let mut target = state.plc.clone(); 206 target.set_path(req.uri().path()); 207 + let wrapped_res = state 208 .client 209 .get(target) 210 .timeout(Duration::from_secs(3)) // should be low latency to wrapped server ··· 213 .await 214 .map_err(|e| { 215 log::error!("upstream req fail: {e}"); 216 + Error::from_string( 217 + failed_to_reach_named("wrapped reference PLC"), 218 + StatusCode::BAD_GATEWAY, 219 + ) 220 })?; 221 222 + Ok(proxy_response(wrapped_res)) 223 + } 224 225 + #[handler] 226 + async fn forward_create_op_upstream( 227 + Data(State { 228 + upstream, 229 + client, 230 + experimental, 231 + .. 232 + }): Data<&State>, 233 + Path(did): Path<String>, 234 + req: &Request, 235 + body: Body, 236 + ) -> Result<Response> { 237 + if let Some(expected_domain) = &experimental.acme_domain { 238 + let Some(found_host) = req.header("Host") else { 239 + return Ok(bad_create_op(&format!( 240 + "missing `Host` header, expected {expected_domain} for experimental requests." 241 + ))); 242 + }; 243 + if found_host != expected_domain { 244 + return Ok(bad_create_op(&format!( 245 + "experimental requests must be made to {expected_domain}, but this request's `Host` header was {found_host}" 246 + ))); 247 + } 248 + } 249 + 250 + // adjust proxied headers 251 + let mut headers: reqwest::header::HeaderMap = req.headers().clone(); 252 + log::trace!("original request headers: {headers:?}"); 253 + headers.insert("Host", upstream.host_str().unwrap().parse().unwrap()); 254 + let client_ua = headers 255 + .get("User-Agent") 256 + .map(|h| h.to_str().unwrap()) 257 + .unwrap_or("unknown"); 258 + headers.insert( 259 + "User-Agent", 260 + format!("{UA} (forwarding from {client_ua:?})") 261 + .parse() 262 + .unwrap(), 263 + ); 264 + log::trace!("adjusted request headers: {headers:?}"); 265 266 + let mut target = upstream.clone(); 267 + target.set_path(&did); 268 + let upstream_res = client 269 + .post(target) 270 + .timeout(Duration::from_secs(15)) // be a little generous 271 + .headers(headers) 272 + .body(reqwest::Body::wrap_stream(body.into_bytes_stream())) 273 + .send() 274 + .await 275 + .map_err(|e| { 276 + log::warn!("upstream write fail: {e}"); 277 + Error::from_string( 278 + failed_to_reach_named("upstream PLC"), 279 + StatusCode::BAD_GATEWAY, 280 + ) 281 + })?; 282 283 + Ok(proxy_response(upstream_res)) 284 } 285 286 #[handler] ··· 292 293 Sorry, this server does not accept POST requests. 294 295 + You may wish to try sending that to our upstream: {upstream}. 296 + 297 + If you operate this server, try running with `--experimental-write-upstream`. 298 "#, 299 logo("mirror (nope)") 300 ), ··· 312 Bind(SocketAddr), 313 } 314 315 + #[derive(Debug, Clone)] 316 + pub struct ExperimentalConf { 317 + pub acme_domain: Option<String>, 318 + pub write_upstream: bool, 319 + } 320 + 321 pub async fn serve( 322 upstream: Url, 323 plc: Url, 324 listen: ListenConf, 325 + experimental: ExperimentalConf, 326 db: Db, 327 ) -> anyhow::Result<&'static str> { 328 log::info!("starting server..."); ··· 346 upstream: upstream.clone(), 347 latest_at, 348 upstream_status, 349 + experimental: experimental.clone(), 350 }; 351 352 + let mut app = Route::new() 353 .at("/", get(hello)) 354 .at("/favicon.ico", get(favicon)) 355 + .at("/_health", get(health)); 356 + 357 + if experimental.write_upstream { 358 + log::info!("enabling experimental write forwarding to upstream"); 359 + 360 + let ip_limiter = IpLimiters::new(Quota::per_hour(10.try_into().unwrap())); 361 + let did_limiter = CreatePlcOpLimiter::new(Quota::per_hour(4.try_into().unwrap())); 362 + 363 + let upstream_proxier = forward_create_op_upstream 364 + .with(GovernorMiddleware::new(did_limiter)) 365 + .with(GovernorMiddleware::new(ip_limiter)); 366 + 367 + app = app.at("/:any", get(proxy).post(upstream_proxier)); 368 + } else { 369 + app = app.at("/:any", get(proxy).post(nope)); 370 + } 371 + 372 + let app = app 373 .with(AddData::new(state)) 374 .with(Cors::new().allow_credentials(false)) 375 .with(Compression::new()) 376 + .with(GovernorMiddleware::new(IpLimiters::new(Quota::per_minute( 377 3000.try_into().expect("ratelimit middleware to build"), 378 + )))) 379 .with(CatchPanic::new()) 380 .with(Tracing); 381 ··· 394 .directory_url(directory_url) 395 .cache_path(cache_path); 396 for domain in domains { 397 + auto_cert = auto_cert.domain(domain); 398 + } 399 + if let Some(domain) = experimental.acme_domain { 400 auto_cert = auto_cert.domain(domain); 401 } 402 let auto_cert = auto_cert.build().expect("acme config to build");
+88 -33
src/ratelimit.rs
··· 8 use poem::{Endpoint, Middleware, Request, Response, Result, http::StatusCode}; 9 use std::{ 10 convert::TryInto, 11 net::{IpAddr, Ipv6Addr}, 12 sync::{Arc, LazyLock}, 13 time::Duration, ··· 20 type IP6_56 = [u8; 7]; 21 type IP6_48 = [u8; 6]; 22 23 fn scale_quota(quota: Quota, factor: u32) -> Option<Quota> { 24 let period = quota.replenish_interval() / factor; 25 let burst = quota ··· 30 } 31 32 #[derive(Debug)] 33 - struct IpLimiters { 34 per_ip: RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>, 35 ip6_56: RateLimiter<IP6_56, DefaultKeyedStateStore<IP6_56>, DefaultClock>, 36 ip6_48: RateLimiter<IP6_48, DefaultKeyedStateStore<IP6_48>, DefaultClock>, ··· 44 ip6_48: RateLimiter::keyed(scale_quota(quota, 256).expect("to scale quota")), 45 } 46 } 47 - pub fn check_key(&self, ip: IpAddr) -> Result<(), Duration> { 48 let asdf = |n: NotUntil<_>| n.wait_time_from(CLOCK.now()); 49 match ip { 50 - addr @ IpAddr::V4(_) => self.per_ip.check_key(&addr).map_err(asdf), 51 IpAddr::V6(a) => { 52 // always check all limiters 53 let check_ip = self ··· 74 } 75 } 76 } 77 } 78 79 /// Once the rate limit has been reached, the middleware will respond with 80 /// status code 429 (too many requests) and a `Retry-After` header with the amount 81 /// of time that needs to pass before another request will be allowed. 82 - #[derive(Debug)] 83 - pub struct GovernorMiddleware { 84 #[allow(dead_code)] 85 stop_on_drop: oneshot::Sender<()>, 86 - limiters: Arc<IpLimiters>, 87 } 88 89 - impl GovernorMiddleware { 90 /// Limit request rates 91 /// 92 /// a little gross but this spawns a tokio task for housekeeping: 93 /// https://docs.rs/governor/latest/governor/struct.RateLimiter.html#keyed-rate-limiters---housekeeping 94 - pub fn new(quota: Quota) -> Self { 95 - let limiters = Arc::new(IpLimiters::new(quota)); 96 let (stop_on_drop, mut stopped) = oneshot::channel(); 97 tokio::task::spawn({ 98 let limiters = limiters.clone(); ··· 102 _ = &mut stopped => break, 103 _ = tokio::time::sleep(Duration::from_secs(60)) => {}, 104 }; 105 - log::debug!( 106 - "limiter sizes before housekeeping: {}/ip {}/v6_56 {}/v6_48", 107 - limiters.per_ip.len(), 108 - limiters.ip6_56.len(), 109 - limiters.ip6_48.len(), 110 - ); 111 - limiters.per_ip.retain_recent(); 112 - limiters.ip6_56.retain_recent(); 113 - limiters.ip6_48.retain_recent(); 114 } 115 } 116 }); ··· 121 } 122 } 123 124 - impl<E: Endpoint> Middleware<E> for GovernorMiddleware { 125 - type Output = GovernorMiddlewareImpl<E>; 126 fn transform(&self, ep: E) -> Self::Output { 127 GovernorMiddlewareImpl { 128 ep, ··· 131 } 132 } 133 134 - pub struct GovernorMiddlewareImpl<E> { 135 ep: E, 136 - limiters: Arc<IpLimiters>, 137 } 138 139 - impl<E: Endpoint> Endpoint for GovernorMiddlewareImpl<E> { 140 type Output = E::Output; 141 142 async fn call(&self, req: Request) -> Result<Self::Output> { 143 - let remote = req 144 - .remote_addr() 145 - .as_socket_addr() 146 - .expect("failed to get request's remote addr") // TODO 147 - .ip(); 148 149 - log::trace!("remote: {remote}"); 150 - 151 - match self.limiters.check_key(remote) { 152 Ok(_) => { 153 - log::debug!("allowing remote {remote}"); 154 self.ep.call(req).await 155 } 156 Err(d) => { 157 let wait_time = d.as_secs(); 158 159 - log::debug!("rate limit exceeded for {remote}, quota reset in {wait_time}s"); 160 161 let res = Response::builder() 162 .status(StatusCode::TOO_MANY_REQUESTS)
··· 8 use poem::{Endpoint, Middleware, Request, Response, Result, http::StatusCode}; 9 use std::{ 10 convert::TryInto, 11 + hash::Hash, 12 net::{IpAddr, Ipv6Addr}, 13 sync::{Arc, LazyLock}, 14 time::Duration, ··· 21 type IP6_56 = [u8; 7]; 22 type IP6_48 = [u8; 6]; 23 24 + pub trait Limiter<K: Hash + std::fmt::Debug>: Send + Sync + 'static { 25 + fn extract_key(&self, req: &Request) -> Result<K>; 26 + fn check_key(&self, ip: &K) -> Result<(), Duration>; 27 + fn housekeep(&self); 28 + } 29 + 30 fn scale_quota(quota: Quota, factor: u32) -> Option<Quota> { 31 let period = quota.replenish_interval() / factor; 32 let burst = quota ··· 37 } 38 39 #[derive(Debug)] 40 + pub struct CreatePlcOpLimiter { 41 + limiter: RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>, 42 + } 43 + 44 + impl CreatePlcOpLimiter { 45 + pub fn new(quota: Quota) -> Self { 46 + Self { 47 + limiter: RateLimiter::keyed(quota), 48 + } 49 + } 50 + } 51 + 52 + /// this must be used with an endpoint with a single path param for the did 53 + impl Limiter<String> for CreatePlcOpLimiter { 54 + fn extract_key(&self, req: &Request) -> Result<String> { 55 + let (did,) = req.path_params::<(String,)>()?; 56 + Ok(did) 57 + } 58 + fn check_key(&self, did: &String) -> Result<(), Duration> { 59 + self.limiter 60 + .check_key(did) 61 + .map_err(|e| e.wait_time_from(CLOCK.now())) 62 + } 63 + fn housekeep(&self) { 64 + log::debug!( 65 + "limiter size before housekeeping: {} dids", 66 + self.limiter.len() 67 + ); 68 + self.limiter.retain_recent(); 69 + } 70 + } 71 + 72 + #[derive(Debug)] 73 + pub struct IpLimiters { 74 per_ip: RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>, 75 ip6_56: RateLimiter<IP6_56, DefaultKeyedStateStore<IP6_56>, DefaultClock>, 76 ip6_48: RateLimiter<IP6_48, DefaultKeyedStateStore<IP6_48>, DefaultClock>, ··· 84 ip6_48: RateLimiter::keyed(scale_quota(quota, 256).expect("to scale quota")), 85 } 86 } 87 + } 88 + 89 + impl Limiter<IpAddr> for IpLimiters { 90 + fn extract_key(&self, req: &Request) -> Result<IpAddr> { 91 + Ok(req 92 + .remote_addr() 93 + .as_socket_addr() 94 + .expect("failed to get request's remote addr") // TODO 95 + .ip()) 96 + } 97 + fn check_key(&self, ip: &IpAddr) -> Result<(), Duration> { 98 let asdf = |n: NotUntil<_>| n.wait_time_from(CLOCK.now()); 99 match ip { 100 + addr @ IpAddr::V4(_) => self.per_ip.check_key(addr).map_err(asdf), 101 IpAddr::V6(a) => { 102 // always check all limiters 103 let check_ip = self ··· 124 } 125 } 126 } 127 + fn housekeep(&self) { 128 + log::debug!( 129 + "limiter sizes before housekeeping: {}/ip {}/v6_56 {}/v6_48", 130 + self.per_ip.len(), 131 + self.ip6_56.len(), 132 + self.ip6_48.len(), 133 + ); 134 + self.per_ip.retain_recent(); 135 + self.ip6_56.retain_recent(); 136 + self.ip6_48.retain_recent(); 137 + } 138 } 139 140 /// Once the rate limit has been reached, the middleware will respond with 141 /// status code 429 (too many requests) and a `Retry-After` header with the amount 142 /// of time that needs to pass before another request will be allowed. 143 + // #[derive(Debug)] 144 + pub struct GovernorMiddleware<K> { 145 #[allow(dead_code)] 146 stop_on_drop: oneshot::Sender<()>, 147 + limiters: Arc<dyn Limiter<K>>, 148 } 149 150 + impl<K: Hash + std::fmt::Debug> GovernorMiddleware<K> { 151 /// Limit request rates 152 /// 153 /// a little gross but this spawns a tokio task for housekeeping: 154 /// https://docs.rs/governor/latest/governor/struct.RateLimiter.html#keyed-rate-limiters---housekeeping 155 + pub fn new(limiters: impl Limiter<K>) -> Self { 156 + let limiters = Arc::new(limiters); 157 let (stop_on_drop, mut stopped) = oneshot::channel(); 158 tokio::task::spawn({ 159 let limiters = limiters.clone(); ··· 163 _ = &mut stopped => break, 164 _ = tokio::time::sleep(Duration::from_secs(60)) => {}, 165 }; 166 + limiters.housekeep(); 167 } 168 } 169 }); ··· 174 } 175 } 176 177 + impl<E, K> Middleware<E> for GovernorMiddleware<K> 178 + where 179 + E: Endpoint, 180 + K: Hash + std::fmt::Debug + Send + Sync + 'static, 181 + { 182 + type Output = GovernorMiddlewareImpl<E, K>; 183 fn transform(&self, ep: E) -> Self::Output { 184 GovernorMiddlewareImpl { 185 ep, ··· 188 } 189 } 190 191 + pub struct GovernorMiddlewareImpl<E, K> { 192 ep: E, 193 + limiters: Arc<dyn Limiter<K>>, 194 } 195 196 + impl<E, K> Endpoint for GovernorMiddlewareImpl<E, K> 197 + where 198 + E: Endpoint, 199 + K: Hash + std::fmt::Debug + Send + Sync + 'static, 200 + { 201 type Output = E::Output; 202 203 async fn call(&self, req: Request) -> Result<Self::Output> { 204 + let key = self.limiters.extract_key(&req)?; 205 206 + match self.limiters.check_key(&key) { 207 Ok(_) => { 208 + log::debug!("allowing key {key:?}"); 209 self.ep.call(req).await 210 } 211 Err(d) => { 212 let wait_time = d.as_secs(); 213 214 + log::debug!("rate limit exceeded for {key:?}, quota reset in {wait_time}s"); 215 216 let res = Response::builder() 217 .status(StatusCode::TOO_MANY_REQUESTS)