this repo has no description

extract proxying into a middleware instead of a fallback handler

authored by nelind and committed by tangled.org 9853d456 4fbd820c

Changed files
+104 -120
src
+4
Cargo.lock
··· 6315 6315 "dotenvy", 6316 6316 "ed25519-dalek", 6317 6317 "futures", 6318 + "futures-util", 6318 6319 "governor", 6319 6320 "hex", 6320 6321 "hickory-resolver", 6321 6322 "hkdf", 6322 6323 "hmac", 6324 + "http 1.4.0", 6323 6325 "image", 6324 6326 "ipld-core", 6325 6327 "iroh-car", ··· 6352 6354 "tokio", 6353 6355 "tokio-tungstenite", 6354 6356 "totp-rs", 6357 + "tower", 6355 6358 "tower-http", 6359 + "tower-layer", 6356 6360 "tracing", 6357 6361 "tracing-subscriber", 6358 6362 "urlencoding",
+4
Cargo.toml
··· 64 64 webauthn-rs = { version = "0.5.4", features = ["danger-allow-state-serialisation", "danger-user-presence-only-security-keys"] } 65 65 webauthn-rs-proto = "0.5.4" 66 66 zip = { version = "7.0.0", default-features = false, features = ["deflate"] } 67 + tower = "0.5.2" 68 + tower-layer = "0.3.3" 69 + futures-util = "0.3.31" 70 + http = "1.4.0" 67 71 [features] 68 72 external-infra = [] 69 73 [dev-dependencies]
+4 -55
src/api/age_assurance.rs
··· 2 2 use crate::state::AppState; 3 3 use axum::{ 4 4 Json, 5 - body::Bytes, 6 - extract::{Path, RawQuery, State}, 7 - http::{HeaderMap, Method, StatusCode}, 5 + extract::State, 6 + http::{HeaderMap, StatusCode}, 8 7 response::{IntoResponse, Response}, 9 8 }; 10 9 use serde_json::json; 11 10 12 - pub async fn get_state( 13 - State(state): State<AppState>, 14 - headers: HeaderMap, 15 - RawQuery(query): RawQuery, 16 - ) -> Response { 17 - if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() { 18 - return proxy_to_appview(state, headers, "app.bsky.ageassurance.getState", query).await; 19 - } 20 - 11 + pub async fn get_state(State(state): State<AppState>, headers: HeaderMap) -> Response { 21 12 let created_at = get_account_created_at(&state, &headers).await; 22 13 let now = chrono::Utc::now().to_rfc3339(); 23 14 ··· 37 28 .into_response() 38 29 } 39 30 40 - pub async fn get_age_assurance_state( 41 - State(state): State<AppState>, 42 - headers: HeaderMap, 43 - RawQuery(query): RawQuery, 44 - ) -> Response { 45 - if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() { 46 - return proxy_to_appview( 47 - state, 48 - headers, 49 - "app.bsky.unspecced.getAgeAssuranceState", 50 - query, 51 - ) 52 - .await; 53 - } 54 - 31 + pub async fn get_age_assurance_state() -> Response { 55 32 (StatusCode::OK, Json(json!({"status": "assured"}))).into_response() 56 33 } 57 34 ··· 89 66 90 67 row.map(|r| r.created_at.to_rfc3339()) 91 68 } 92 - 93 - async fn proxy_to_appview( 94 - state: AppState, 95 - headers: HeaderMap, 96 - method: &str, 97 - query: Option<String>, 98 - ) -> Response { 99 - if headers.get("atproto-proxy").is_none() { 100 - return ( 101 - StatusCode::BAD_REQUEST, 102 - Json(json!({ 103 - "error": "InvalidRequest", 104 - "message": "Missing required atproto-proxy header" 105 - })), 106 - ) 107 - .into_response(); 108 - } 109 - 110 - crate::api::proxy::proxy_handler( 111 - State(state), 112 - Path(method.to_string()), 113 - Method::GET, 114 - headers, 115 - RawQuery(query), 116 - Bytes::new(), 117 - ) 118 - .await 119 - }
+80 -3
src/api/proxy.rs
··· 1 + use std::convert::Infallible; 2 + 1 3 use crate::api::proxy_client::proxy_client; 2 4 use crate::state::AppState; 3 5 use axum::{ 4 6 Json, 5 7 body::Bytes, 6 - extract::{Path, RawQuery, State}, 8 + extract::{RawQuery, Request, State}, 9 + handler::Handler, 7 10 http::{HeaderMap, Method, StatusCode}, 8 11 response::{IntoResponse, Response}, 9 12 }; 13 + use futures_util::future::Either; 10 14 use serde_json::json; 15 + use tower::{Service, util::BoxCloneSyncService}; 11 16 use tracing::{error, info, warn}; 12 17 13 18 const PROTECTED_METHODS: &[&str] = &[ ··· 33 38 PROTECTED_METHODS.contains(&method) 34 39 } 35 40 36 - pub async fn proxy_handler( 41 + pub struct XrpcProxyLayer { 42 + state: AppState, 43 + } 44 + 45 + impl XrpcProxyLayer { 46 + pub fn new(state: AppState) -> Self { 47 + XrpcProxyLayer { state } 48 + } 49 + } 50 + 51 + impl<S> tower_layer::Layer<S> for XrpcProxyLayer { 52 + type Service = XrpcProxyingService<S>; 53 + 54 + fn layer(&self, inner: S) -> Self::Service { 55 + XrpcProxyingService { 56 + inner, 57 + // TODO(nel): make our own service here instead of boxing a HandlerService 58 + handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())), 59 + } 60 + } 61 + } 62 + 63 + #[derive(Clone)] 64 + pub struct XrpcProxyingService<S> { 65 + inner: S, 66 + handler: BoxCloneSyncService<Request, Response, Infallible>, 67 + } 68 + 69 + impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request> 70 + for XrpcProxyingService<S> 71 + { 72 + type Response = Response; 73 + 74 + type Error = Infallible; 75 + 76 + type Future = Either< 77 + <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future, 78 + S::Future, 79 + >; 80 + 81 + fn poll_ready( 82 + &mut self, 83 + cx: &mut std::task::Context<'_>, 84 + ) -> std::task::Poll<Result<(), Self::Error>> { 85 + self.inner.poll_ready(cx) 86 + } 87 + 88 + fn call(&mut self, req: Request) -> Self::Future { 89 + if req 90 + .headers() 91 + .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy)) 92 + { 93 + // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it. 94 + if !std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() 95 + && (req.uri().path().ends_with("app.bsky.ageassurance.getState") 96 + || req 97 + .uri() 98 + .path() 99 + .ends_with("app.bsky.unspecced.getAgeAssuranceState")) 100 + { 101 + return Either::Right(self.inner.call(req)); 102 + } 103 + 104 + Either::Left(self.handler.call(req)) 105 + } else { 106 + Either::Right(self.inner.call(req)) 107 + } 108 + } 109 + } 110 + 111 + async fn proxy_handler( 37 112 State(state): State<AppState>, 38 - Path(method): Path<String>, 113 + uri: http::Uri, 39 114 method_verb: Method, 40 115 headers: HeaderMap, 41 116 RawQuery(query): RawQuery, 42 117 body: Bytes, 43 118 ) -> Response { 119 + // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the / 120 + let method = uri.path().trim_start_matches("/"); 44 121 if is_protected_method(&method) { 45 122 warn!(method = %method, "Attempted to proxy protected method"); 46 123 return (
+1 -57
src/api/repo/record/read.rs
··· 1 - use crate::api::proxy_client::proxy_client; 2 1 use crate::state::AppState; 3 2 use axum::{ 4 3 Json, ··· 14 13 use serde_json::{Map, Value, json}; 15 14 use std::collections::HashMap; 16 15 use std::str::FromStr; 17 - use tracing::{error, info}; 16 + use tracing::error; 18 17 19 18 fn ipld_to_json(ipld: Ipld) -> Value { 20 19 match ipld { ··· 78 77 let user_id: uuid::Uuid = match user_id_opt { 79 78 Ok(Some(id)) => id, 80 79 Ok(None) => { 81 - if let Some(proxy_header) = headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) { 82 - let did = proxy_header.split('#').next().unwrap_or(proxy_header); 83 - if let Some(resolved) = state.did_resolver.resolve_did(did).await { 84 - let mut url = format!( 85 - "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection={}&rkey={}", 86 - resolved.url.trim_end_matches('/'), 87 - urlencoding::encode(&input.repo), 88 - urlencoding::encode(&input.collection), 89 - urlencoding::encode(&input.rkey) 90 - ); 91 - if let Some(cid) = &input.cid { 92 - url.push_str(&format!("&cid={}", urlencoding::encode(cid))); 93 - } 94 - info!("Proxying getRecord to {}: {}", did, url); 95 - match proxy_client().get(&url).send().await { 96 - Ok(resp) => { 97 - let status = resp.status(); 98 - let body = match resp.bytes().await { 99 - Ok(b) => b, 100 - Err(e) => { 101 - error!("Error reading proxy response: {:?}", e); 102 - return ( 103 - StatusCode::BAD_GATEWAY, 104 - Json(json!({"error": "UpstreamFailure", "message": "Error reading upstream response"})), 105 - ) 106 - .into_response(); 107 - } 108 - }; 109 - return Response::builder() 110 - .status(status) 111 - .header("content-type", "application/json") 112 - .body(axum::body::Body::from(body)) 113 - .unwrap_or_else(|_| { 114 - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error") 115 - .into_response() 116 - }); 117 - } 118 - Err(e) => { 119 - error!("Error proxying request: {:?}", e); 120 - return ( 121 - StatusCode::BAD_GATEWAY, 122 - Json(json!({"error": "UpstreamFailure", "message": "Failed to reach upstream service"})), 123 - ) 124 - .into_response(); 125 - } 126 - } 127 - } else { 128 - error!("Could not resolve DID from atproto-proxy header: {}", did); 129 - return ( 130 - StatusCode::BAD_GATEWAY, 131 - Json(json!({"error": "UpstreamFailure", "message": "Could not resolve proxy DID"})), 132 - ) 133 - .into_response(); 134 - } 135 - } 136 80 return ( 137 81 StatusCode::NOT_FOUND, 138 82 Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
+11 -5
src/lib.rs
··· 22 22 pub mod util; 23 23 pub mod validation; 24 24 25 + use api::proxy::XrpcProxyLayer; 25 26 use axum::{ 26 - Router, 27 + Json, Router, 27 28 extract::DefaultBodyLimit, 28 29 http::Method, 29 30 middleware, 30 - routing::{any, get, post}, 31 + routing::{get, post}, 31 32 }; 33 + use http::StatusCode; 34 + use serde_json::json; 32 35 use state::AppState; 36 + use tower::{Layer, ServiceBuilder}; 33 37 use tower_http::cors::{Any, CorsLayer}; 34 38 use tower_http::services::{ServeDir, ServeFile}; 35 39 ··· 494 498 .route( 495 499 "/app.bsky.unspecced.getAgeAssuranceState", 496 500 get(api::age_assurance::get_age_assurance_state), 497 - ) 498 - .route("/{*method}", any(api::proxy::proxy_handler)); 501 + ); 502 + let xrpc_service = ServiceBuilder::new() 503 + .layer(XrpcProxyLayer::new(state.clone())) 504 + .service(xrpc_router.with_state(state.clone())); 499 505 500 506 let oauth_router = Router::new() 501 507 .route("/jwks", get(oauth::endpoints::oauth_jwks)) ··· 559 565 ); 560 566 561 567 let router = Router::new() 562 - .nest("/xrpc", xrpc_router) 568 + .nest_service("/xrpc", xrpc_service) 563 569 .nest("/oauth", oauth_router) 564 570 .route("/metrics", get(metrics::metrics_handler)) 565 571 .route("/health", get(api::server::health))