This repository has no description.
1use std::convert::Infallible; 2 3use crate::api::proxy_client::proxy_client; 4use crate::state::AppState; 5use axum::{ 6 Json, 7 body::Bytes, 8 extract::{RawQuery, Request, State}, 9 handler::Handler, 10 http::{HeaderMap, Method, StatusCode}, 11 response::{IntoResponse, Response}, 12}; 13use futures_util::future::Either; 14use serde_json::json; 15use tower::{Service, util::BoxCloneSyncService}; 16use tracing::{error, info, warn}; 17 18const PROTECTED_METHODS: &[&str] = &[ 19 "com.atproto.admin.sendEmail", 20 "com.atproto.identity.requestPlcOperationSignature", 21 "com.atproto.identity.signPlcOperation", 22 "com.atproto.identity.updateHandle", 23 "com.atproto.server.activateAccount", 24 "com.atproto.server.confirmEmail", 25 "com.atproto.server.createAppPassword", 26 "com.atproto.server.deactivateAccount", 27 "com.atproto.server.getAccountInviteCodes", 28 "com.atproto.server.getSession", 29 "com.atproto.server.listAppPasswords", 30 "com.atproto.server.requestAccountDelete", 31 "com.atproto.server.requestEmailConfirmation", 32 "com.atproto.server.requestEmailUpdate", 33 "com.atproto.server.revokeAppPassword", 34 "com.atproto.server.updateEmail", 35]; 36 37fn is_protected_method(method: &str) -> bool { 38 PROTECTED_METHODS.contains(&method) 39} 40 41pub struct XrpcProxyLayer { 42 state: AppState, 43} 44 45impl XrpcProxyLayer { 46 pub fn new(state: AppState) -> Self { 47 XrpcProxyLayer { state } 48 } 49} 50 51impl<S> tower_layer::Layer<S> for XrpcProxyLayer { 52 type Service = XrpcProxyingService<S>; 53 54 fn layer(&self, inner: S) -> Self::Service { 55 XrpcProxyingService { 56 inner, 57 // TODO(nel): make our own service here instead of boxing a HandlerService 58 handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())), 59 } 60 } 61} 62 63#[derive(Clone)] 64pub struct XrpcProxyingService<S> { 65 inner: S, 66 handler: BoxCloneSyncService<Request, Response, Infallible>, 67} 68 69impl<S: Service<Request, Response = Response, Error = Infallible>> 
Service<Request> 70 for XrpcProxyingService<S> 71{ 72 type Response = Response; 73 74 type Error = Infallible; 75 76 type Future = Either< 77 <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future, 78 S::Future, 79 >; 80 81 fn poll_ready( 82 &mut self, 83 cx: &mut std::task::Context<'_>, 84 ) -> std::task::Poll<Result<(), Self::Error>> { 85 self.inner.poll_ready(cx) 86 } 87 88 fn call(&mut self, req: Request) -> Self::Future { 89 if req 90 .headers() 91 .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy)) 92 { 93 // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it. 94 if !std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() 95 && (req.uri().path().ends_with("app.bsky.ageassurance.getState") 96 || req 97 .uri() 98 .path() 99 .ends_with("app.bsky.unspecced.getAgeAssuranceState")) 100 { 101 return Either::Right(self.inner.call(req)); 102 } 103 104 Either::Left(self.handler.call(req)) 105 } else { 106 Either::Right(self.inner.call(req)) 107 } 108 } 109} 110 111async fn proxy_handler( 112 State(state): State<AppState>, 113 uri: http::Uri, 114 method_verb: Method, 115 headers: HeaderMap, 116 RawQuery(query): RawQuery, 117 body: Bytes, 118) -> Response { 119 // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the / 120 let method = uri.path().trim_start_matches("/"); 121 if is_protected_method(&method) { 122 warn!(method = %method, "Attempted to proxy protected method"); 123 return ( 124 StatusCode::BAD_REQUEST, 125 Json(json!({ 126 "error": "InvalidRequest", 127 "message": format!("Cannot proxy protected method: {}", method) 128 })), 129 ) 130 .into_response(); 131 } 132 133 let proxy_header = match headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) { 134 Some(h) => h.to_string(), 135 None => { 136 return ( 137 StatusCode::BAD_REQUEST, 138 Json(json!({ 139 
"error": "InvalidRequest", 140 "message": "Missing required atproto-proxy header" 141 })), 142 ) 143 .into_response(); 144 } 145 }; 146 147 let did = proxy_header.split('#').next().unwrap_or(&proxy_header); 148 let resolved = match state.did_resolver.resolve_did(did).await { 149 Some(r) => r, 150 None => { 151 error!(did = %did, "Could not resolve service DID"); 152 return ( 153 StatusCode::BAD_GATEWAY, 154 Json(json!({ 155 "error": "UpstreamFailure", 156 "message": "Could not resolve service DID" 157 })), 158 ) 159 .into_response(); 160 } 161 }; 162 163 let target_url = match &query { 164 Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q), 165 None => format!("{}/xrpc/{}", resolved.url, method), 166 }; 167 info!("Proxying {} request to {}", method_verb, target_url); 168 169 let client = proxy_client(); 170 let mut request_builder = client.request(method_verb, &target_url); 171 172 let mut auth_header_val = headers.get("Authorization").cloned(); 173 if let Some(token) = crate::auth::extract_bearer_token_from_header( 174 headers.get("Authorization").and_then(|h| h.to_str().ok()), 175 ) { 176 match crate::auth::validate_bearer_token(&state.db, &token).await { 177 Ok(auth_user) => { 178 if let Err(e) = crate::auth::scope_check::check_rpc_scope( 179 auth_user.is_oauth, 180 auth_user.scope.as_deref(), 181 &resolved.did, 182 &method, 183 ) { 184 return e; 185 } 186 187 if let Some(key_bytes) = auth_user.key_bytes { 188 match crate::auth::create_service_token( 189 &auth_user.did, 190 &resolved.did, 191 &method, 192 &key_bytes, 193 ) { 194 Ok(new_token) => { 195 if let Ok(val) = 196 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) 197 { 198 auth_header_val = Some(val); 199 } 200 } 201 Err(e) => { 202 warn!("Failed to create service token: {:?}", e); 203 } 204 } 205 } 206 } 207 Err(e) => { 208 warn!("Token validation failed: {:?}", e); 209 if matches!(e, crate::auth::TokenValidationError::TokenExpired) { 210 let auth_header_str = headers 211 
.get("Authorization") 212 .and_then(|h| h.to_str().ok()) 213 .unwrap_or(""); 214 let is_dpop = auth_header_str 215 .trim() 216 .get(..5) 217 .is_some_and(|s| s.eq_ignore_ascii_case("dpop ")); 218 let scheme = if is_dpop { "DPoP" } else { "Bearer" }; 219 let www_auth = format!( 220 "{} error=\"invalid_token\", error_description=\"Token has expired\"", 221 scheme 222 ); 223 let mut response = ( 224 StatusCode::UNAUTHORIZED, 225 Json(json!({ 226 "error": "ExpiredToken", 227 "message": "Token has expired" 228 })), 229 ) 230 .into_response(); 231 response 232 .headers_mut() 233 .insert("WWW-Authenticate", www_auth.parse().unwrap()); 234 if is_dpop { 235 let nonce = crate::oauth::verify::generate_dpop_nonce(); 236 response 237 .headers_mut() 238 .insert("DPoP-Nonce", nonce.parse().unwrap()); 239 } 240 return response; 241 } 242 } 243 } 244 } 245 246 if let Some(val) = auth_header_val { 247 request_builder = request_builder.header("Authorization", val); 248 } 249 for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD { 250 if let Some(val) = headers.get(*header_name) { 251 request_builder = request_builder.header(*header_name, val); 252 } 253 } 254 if !body.is_empty() { 255 request_builder = request_builder.body(body); 256 } 257 258 match request_builder.send().await { 259 Ok(resp) => { 260 let status = resp.status(); 261 let headers = resp.headers().clone(); 262 let body = match resp.bytes().await { 263 Ok(b) => b, 264 Err(e) => { 265 error!("Error reading proxy response body: {:?}", e); 266 return (StatusCode::BAD_GATEWAY, "Error reading upstream response") 267 .into_response(); 268 } 269 }; 270 let mut response_builder = Response::builder().status(status); 271 for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD { 272 if let Some(val) = headers.get(*header_name) { 273 response_builder = response_builder.header(*header_name, val); 274 } 275 } 276 match response_builder.body(axum::body::Body::from(body)) { 277 Ok(r) => r, 278 Err(e) => { 279 
error!("Error building proxy response: {:?}", e); 280 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response() 281 } 282 } 283 } 284 Err(e) => { 285 error!("Error sending proxy request: {:?}", e); 286 if e.is_timeout() { 287 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response() 288 } else { 289 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response() 290 } 291 } 292 } 293}