this repo has no description
1use std::convert::Infallible; 2 3use crate::api::error::ApiError; 4use crate::api::proxy_client::proxy_client; 5use crate::state::AppState; 6use axum::{ 7 body::Bytes, 8 extract::{RawQuery, Request, State}, 9 handler::Handler, 10 http::{HeaderMap, Method, StatusCode}, 11 response::{IntoResponse, Response}, 12}; 13use futures_util::future::Either; 14use tower::{Service, util::BoxCloneSyncService}; 15use tracing::{error, info, warn}; 16 17const PROTECTED_METHODS: &[&str] = &[ 18 "com.atproto.admin.sendEmail", 19 "com.atproto.identity.requestPlcOperationSignature", 20 "com.atproto.identity.signPlcOperation", 21 "com.atproto.identity.updateHandle", 22 "com.atproto.server.activateAccount", 23 "com.atproto.server.confirmEmail", 24 "com.atproto.server.createAppPassword", 25 "com.atproto.server.deactivateAccount", 26 "com.atproto.server.getAccountInviteCodes", 27 "com.atproto.server.getSession", 28 "com.atproto.server.listAppPasswords", 29 "com.atproto.server.requestAccountDelete", 30 "com.atproto.server.requestEmailConfirmation", 31 "com.atproto.server.requestEmailUpdate", 32 "com.atproto.server.revokeAppPassword", 33 "com.atproto.server.updateEmail", 34]; 35 36fn is_protected_method(method: &str) -> bool { 37 PROTECTED_METHODS.contains(&method) 38} 39 40pub struct XrpcProxyLayer { 41 state: AppState, 42} 43 44impl XrpcProxyLayer { 45 pub fn new(state: AppState) -> Self { 46 XrpcProxyLayer { state } 47 } 48} 49 50impl<S> tower_layer::Layer<S> for XrpcProxyLayer { 51 type Service = XrpcProxyingService<S>; 52 53 fn layer(&self, inner: S) -> Self::Service { 54 XrpcProxyingService { 55 inner, 56 // TODO(nel): make our own service here instead of boxing a HandlerService 57 handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())), 58 } 59 } 60} 61 62#[derive(Clone)] 63pub struct XrpcProxyingService<S> { 64 inner: S, 65 handler: BoxCloneSyncService<Request, Response, Infallible>, 66} 67 68impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request> 69 for XrpcProxyingService<S> 70{ 71 type Response = Response; 72 73 type Error = Infallible; 74 75 type Future = Either< 76 <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future, 77 S::Future, 78 >; 79 80 fn poll_ready( 81 &mut self, 82 cx: &mut std::task::Context<'_>, 83 ) -> std::task::Poll<Result<(), Self::Error>> { 84 self.inner.poll_ready(cx) 85 } 86 87 fn call(&mut self, req: Request) -> Self::Future { 88 if req 89 .headers() 90 .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy)) 91 { 92 // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it. 93 if !std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err() 94 && (req.uri().path().ends_with("app.bsky.ageassurance.getState") 95 || req 96 .uri() 97 .path() 98 .ends_with("app.bsky.unspecced.getAgeAssuranceState")) 99 { 100 return Either::Right(self.inner.call(req)); 101 } 102 103 Either::Left(self.handler.call(req)) 104 } else { 105 Either::Right(self.inner.call(req)) 106 } 107 } 108} 109 110async fn proxy_handler( 111 State(state): State<AppState>, 112 uri: http::Uri, 113 method_verb: Method, 114 headers: HeaderMap, 115 RawQuery(query): RawQuery, 116 body: Bytes, 117) -> Response { 118 // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the / 119 let method = uri.path().trim_start_matches("/"); 120 if is_protected_method(&method) { 121 warn!(method = %method, "Attempted to proxy protected method"); 122 return ApiError::InvalidRequest(format!("Cannot proxy protected method: {}", method)) 123 .into_response(); 124 } 125 126 let Some(proxy_header) = headers 127 .get("atproto-proxy") 128 .and_then(|h| h.to_str().ok()) 129 .map(String::from) 130 else { 131 return ApiError::InvalidRequest("Missing required atproto-proxy header".into()) 132 .into_response(); 133 }; 134 135 let did = proxy_header.split('#').next().unwrap_or(&proxy_header); 136 let Some(resolved) = state.did_resolver.resolve_did(did).await else { 137 error!(did = %did, "Could not resolve service DID"); 138 return ApiError::UpstreamFailure.into_response(); 139 }; 140 141 let target_url = match &query { 142 Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q), 143 None => format!("{}/xrpc/{}", resolved.url, method), 144 }; 145 info!("Proxying {} request to {}", method_verb, target_url); 146 147 let client = proxy_client(); 148 let mut request_builder = client.request(method_verb, &target_url); 149 150 let mut auth_header_val = headers.get("Authorization").cloned(); 151 if let Some(token) = crate::auth::extract_bearer_token_from_header( 152 headers.get("Authorization").and_then(|h| h.to_str().ok()), 153 ) { 154 match crate::auth::validate_bearer_token(&state.db, &token).await { 155 Ok(auth_user) => { 156 if let Err(e) = crate::auth::scope_check::check_rpc_scope( 157 auth_user.is_oauth, 158 auth_user.scope.as_deref(), 159 &resolved.did, 160 &method, 161 ) { 162 return e; 163 } 164 165 if let Some(key_bytes) = auth_user.key_bytes { 166 match crate::auth::create_service_token( 167 &auth_user.did, 168 &resolved.did, 169 &method, 170 &key_bytes, 171 ) { 172 Ok(new_token) => { 173 if let Ok(val) = 174 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) 175 { 176 auth_header_val = Some(val); 177 } 178 } 179 Err(e) => { 180 warn!("Failed to create service token: {:?}", e); 181 } 182 } 183 } 184 } 185 Err(e) => { 186 warn!("Token validation failed: {:?}", e); 187 if matches!(e, crate::auth::TokenValidationError::TokenExpired) { 188 let auth_header_str = headers 189 .get("Authorization") 190 .and_then(|h| h.to_str().ok()) 191 .unwrap_or(""); 192 let is_dpop = auth_header_str 193 .trim() 194 .get(..5) 195 .is_some_and(|s| s.eq_ignore_ascii_case("dpop ")); 196 let scheme = if is_dpop { "DPoP" } else { "Bearer" }; 197 let www_auth = format!( 198 "{} error=\"invalid_token\", error_description=\"Token has expired\"", 199 scheme 200 ); 201 let mut response = 202 ApiError::ExpiredToken(Some("Token has expired".into())).into_response(); 203 response 204 .headers_mut() 205 .insert("WWW-Authenticate", www_auth.parse().unwrap()); 206 if is_dpop { 207 let nonce = crate::oauth::verify::generate_dpop_nonce(); 208 response 209 .headers_mut() 210 .insert("DPoP-Nonce", nonce.parse().unwrap()); 211 } 212 return response; 213 } 214 } 215 } 216 } 217 218 if let Some(val) = auth_header_val { 219 request_builder = request_builder.header("Authorization", val); 220 } 221 for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD { 222 if let Some(val) = headers.get(*header_name) { 223 request_builder = request_builder.header(*header_name, val); 224 } 225 } 226 if !body.is_empty() { 227 request_builder = request_builder.body(body); 228 } 229 230 match request_builder.send().await { 231 Ok(resp) => { 232 let status = resp.status(); 233 let headers = resp.headers().clone(); 234 let body = match resp.bytes().await { 235 Ok(b) => b, 236 Err(e) => { 237 error!("Error reading proxy response body: {:?}", e); 238 return (StatusCode::BAD_GATEWAY, "Error reading upstream response") 239 .into_response(); 240 } 241 }; 242 let mut response_builder = Response::builder().status(status); 243 for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD { 244 if let Some(val) = headers.get(*header_name) { 245 response_builder = response_builder.header(*header_name, val); 246 } 247 } 248 match response_builder.body(axum::body::Body::from(body)) { 249 Ok(r) => r, 250 Err(e) => { 251 error!("Error building proxy response: {:?}", e); 252 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response() 253 } 254 } 255 } 256 Err(e) => { 257 error!("Error sending proxy request: {:?}", e); 258 if e.is_timeout() { 259 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response() 260 } else { 261 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response() 262 } 263 } 264 } 265}