this repo has no description
1use std::convert::Infallible; 2 3use crate::api::error::ApiError; 4use crate::api::proxy_client::proxy_client; 5use crate::state::AppState; 6use axum::{ 7 body::Bytes, 8 extract::{RawQuery, Request, State}, 9 handler::Handler, 10 http::{HeaderMap, Method, StatusCode}, 11 response::{IntoResponse, Response}, 12}; 13use futures_util::future::Either; 14use tower::{Service, util::BoxCloneSyncService}; 15use tracing::{error, info, warn}; 16 17const PROTECTED_METHODS: &[&str] = &[ 18 "app.bsky.actor.getPreferences", 19 "app.bsky.actor.putPreferences", 20 "com.atproto.admin.deleteAccount", 21 "com.atproto.admin.disableAccountInvites", 22 "com.atproto.admin.disableInviteCodes", 23 "com.atproto.admin.enableAccountInvites", 24 "com.atproto.admin.getAccountInfo", 25 "com.atproto.admin.getAccountInfos", 26 "com.atproto.admin.getInviteCodes", 27 "com.atproto.admin.getSubjectStatus", 28 "com.atproto.admin.searchAccounts", 29 "com.atproto.admin.sendEmail", 30 "com.atproto.admin.updateAccountEmail", 31 "com.atproto.admin.updateAccountHandle", 32 "com.atproto.admin.updateAccountPassword", 33 "com.atproto.admin.updateSubjectStatus", 34 "com.atproto.identity.getRecommendedDidCredentials", 35 "com.atproto.identity.requestPlcOperationSignature", 36 "com.atproto.identity.signPlcOperation", 37 "com.atproto.identity.submitPlcOperation", 38 "com.atproto.identity.updateHandle", 39 "com.atproto.repo.applyWrites", 40 "com.atproto.repo.createRecord", 41 "com.atproto.repo.deleteRecord", 42 "com.atproto.repo.importRepo", 43 "com.atproto.repo.putRecord", 44 "com.atproto.repo.uploadBlob", 45 "com.atproto.server.activateAccount", 46 "com.atproto.server.checkAccountStatus", 47 "com.atproto.server.confirmEmail", 48 "com.atproto.server.confirmSignup", 49 "com.atproto.server.createAccount", 50 "com.atproto.server.createAppPassword", 51 "com.atproto.server.createInviteCode", 52 "com.atproto.server.createInviteCodes", 53 "com.atproto.server.createSession", 54 "com.atproto.server.createTotpSecret", 55 "com.atproto.server.deactivateAccount", 56 "com.atproto.server.deleteAccount", 57 "com.atproto.server.deletePasskey", 58 "com.atproto.server.deleteSession", 59 "com.atproto.server.describeServer", 60 "com.atproto.server.disableTotp", 61 "com.atproto.server.enableTotp", 62 "com.atproto.server.finishPasskeyRegistration", 63 "com.atproto.server.getAccountInviteCodes", 64 "com.atproto.server.getServiceAuth", 65 "com.atproto.server.getSession", 66 "com.atproto.server.getTotpStatus", 67 "com.atproto.server.listAppPasswords", 68 "com.atproto.server.listPasskeys", 69 "com.atproto.server.refreshSession", 70 "com.atproto.server.regenerateBackupCodes", 71 "com.atproto.server.requestAccountDelete", 72 "com.atproto.server.requestEmailConfirmation", 73 "com.atproto.server.requestEmailUpdate", 74 "com.atproto.server.requestPasswordReset", 75 "com.atproto.server.resendMigrationVerification", 76 "com.atproto.server.resendVerification", 77 "com.atproto.server.reserveSigningKey", 78 "com.atproto.server.resetPassword", 79 "com.atproto.server.revokeAppPassword", 80 "com.atproto.server.startPasskeyRegistration", 81 "com.atproto.server.updateEmail", 82 "com.atproto.server.updatePasskey", 83 "com.atproto.server.verifyMigrationEmail", 84 "com.atproto.sync.getBlob", 85 "com.atproto.sync.getBlocks", 86 "com.atproto.sync.getCheckout", 87 "com.atproto.sync.getHead", 88 "com.atproto.sync.getLatestCommit", 89 "com.atproto.sync.getRecord", 90 "com.atproto.sync.getRepo", 91 "com.atproto.sync.getRepoStatus", 92 "com.atproto.sync.listBlobs", 93 "com.atproto.sync.listRepos", 94 "com.atproto.sync.notifyOfUpdate", 95 "com.atproto.sync.requestCrawl", 96 "com.atproto.sync.subscribeRepos", 97 "com.atproto.temp.checkSignupQueue", 98 "com.atproto.temp.dereferenceScope", 99]; 100 101fn is_protected_method(method: &str) -> bool { 102 PROTECTED_METHODS.contains(&method) 103} 104 105pub struct XrpcProxyLayer { 106 state: AppState, 107} 108 109impl XrpcProxyLayer { 110 pub fn new(state: AppState) -> Self { 111 XrpcProxyLayer { state } 112 } 113} 114 115impl<S> tower_layer::Layer<S> for XrpcProxyLayer { 116 type Service = XrpcProxyingService<S>; 117 118 fn layer(&self, inner: S) -> Self::Service { 119 XrpcProxyingService { 120 inner, 121 // TODO(nel): make our own service here instead of boxing a HandlerService 122 handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())), 123 } 124 } 125} 126 127#[derive(Clone)] 128pub struct XrpcProxyingService<S> { 129 inner: S, 130 handler: BoxCloneSyncService<Request, Response, Infallible>, 131} 132 133impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request> 134 for XrpcProxyingService<S> 135{ 136 type Response = Response; 137 138 type Error = Infallible; 139 140 type Future = Either< 141 <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future, 142 S::Future, 143 >; 144 145 fn poll_ready( 146 &mut self, 147 cx: &mut std::task::Context<'_>, 148 ) -> std::task::Poll<Result<(), Self::Error>> { 149 self.inner.poll_ready(cx) 150 } 151 152 fn call(&mut self, req: Request) -> Self::Future { 153 if req 154 .headers() 155 .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy)) 156 { 157 let path = req.uri().path(); 158 let method = path.trim_start_matches("/"); 159 160 if is_protected_method(method) { 161 return Either::Right(self.inner.call(req)); 162 } 163 164 // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it 165 if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_ok() 166 && (path.ends_with("app.bsky.ageassurance.getState") 167 || path.ends_with("app.bsky.unspecced.getAgeAssuranceState")) 168 { 169 return Either::Right(self.inner.call(req)); 170 } 171 172 Either::Left(self.handler.call(req)) 173 } else { 174 Either::Right(self.inner.call(req)) 175 } 176 } 177} 178 179async fn proxy_handler( 180 State(state): State<AppState>, 181 uri: http::Uri, 182 method_verb: Method, 183 headers: HeaderMap, 184 RawQuery(query): RawQuery, 185 body: Bytes, 186) -> Response { 187 // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the / 188 let method = uri.path().trim_start_matches("/"); 189 if is_protected_method(method) { 190 warn!(method = %method, "Attempted to proxy protected method"); 191 return ApiError::InvalidRequest(format!("Cannot proxy protected method: {}", method)) 192 .into_response(); 193 } 194 195 let Some(proxy_header) = headers 196 .get("atproto-proxy") 197 .and_then(|h| h.to_str().ok()) 198 .map(String::from) 199 else { 200 return ApiError::InvalidRequest("Missing required atproto-proxy header".into()) 201 .into_response(); 202 }; 203 204 let did = proxy_header.split('#').next().unwrap_or(&proxy_header); 205 let Some(resolved) = state.did_resolver.resolve_did(did).await else { 206 error!(did = %did, "Could not resolve service DID"); 207 return ApiError::UpstreamFailure.into_response(); 208 }; 209 210 let target_url = match &query { 211 Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q), 212 None => format!("{}/xrpc/{}", resolved.url, method), 213 }; 214 info!("Proxying {} request to {}", method_verb, target_url); 215 216 let client = proxy_client(); 217 let mut request_builder = client.request(method_verb, &target_url); 218 219 let mut auth_header_val = headers.get("Authorization").cloned(); 220 if let Some(token) = crate::auth::extract_bearer_token_from_header( 221 headers.get("Authorization").and_then(|h| h.to_str().ok()), 222 ) { 223 match crate::auth::validate_bearer_token(&state.db, &token).await { 224 Ok(auth_user) => { 225 if let Err(e) = crate::auth::scope_check::check_rpc_scope( 226 auth_user.is_oauth, 227 auth_user.scope.as_deref(), 228 &resolved.did, 229 method, 230 ) { 231 return e; 232 } 233 234 if let Some(key_bytes) = auth_user.key_bytes { 235 match crate::auth::create_service_token( 236 &auth_user.did, 237 &resolved.did, 238 method, 239 &key_bytes, 240 ) { 241 Ok(new_token) => { 242 if let Ok(val) = 243 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) 244 { 245 auth_header_val = Some(val); 246 } 247 } 248 Err(e) => { 249 warn!("Failed to create service token: {:?}", e); 250 } 251 } 252 } 253 } 254 Err(e) => { 255 warn!("Token validation failed: {:?}", e); 256 if matches!(e, crate::auth::TokenValidationError::TokenExpired) { 257 let auth_header_str = headers 258 .get("Authorization") 259 .and_then(|h| h.to_str().ok()) 260 .unwrap_or(""); 261 let is_dpop = auth_header_str 262 .trim() 263 .get(..5) 264 .is_some_and(|s| s.eq_ignore_ascii_case("dpop ")); 265 let scheme = if is_dpop { "DPoP" } else { "Bearer" }; 266 let www_auth = format!( 267 "{} error=\"invalid_token\", error_description=\"Token has expired\"", 268 scheme 269 ); 270 let mut response = 271 ApiError::ExpiredToken(Some("Token has expired".into())).into_response(); 272 response 273 .headers_mut() 274 .insert("WWW-Authenticate", www_auth.parse().unwrap()); 275 if is_dpop { 276 let nonce = crate::oauth::verify::generate_dpop_nonce(); 277 response 278 .headers_mut() 279 .insert("DPoP-Nonce", nonce.parse().unwrap()); 280 } 281 return response; 282 } 283 } 284 } 285 } 286 287 if let Some(val) = auth_header_val { 288 request_builder = request_builder.header("Authorization", val); 289 } 290 for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD { 291 if let Some(val) = headers.get(*header_name) { 292 request_builder = request_builder.header(*header_name, val); 293 } 294 } 295 if !body.is_empty() { 296 request_builder = request_builder.body(body); 297 } 298 299 match request_builder.send().await { 300 Ok(resp) => { 301 let status = resp.status(); 302 let headers = resp.headers().clone(); 303 let body = match resp.bytes().await { 304 Ok(b) => b, 305 Err(e) => { 306 error!("Error reading proxy response body: {:?}", e); 307 return (StatusCode::BAD_GATEWAY, "Error reading upstream response") 308 .into_response(); 309 } 310 }; 311 let mut response_builder = Response::builder().status(status); 312 for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD { 313 if let Some(val) = headers.get(*header_name) { 314 response_builder = response_builder.header(*header_name, val); 315 } 316 } 317 match response_builder.body(axum::body::Body::from(body)) { 318 Ok(r) => r, 319 Err(e) => { 320 error!("Error building proxy response: {:?}", e); 321 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response() 322 } 323 } 324 } 325 Err(e) => { 326 error!("Error sending proxy request: {:?}", e); 327 if e.is_timeout() { 328 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response() 329 } else { 330 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response() 331 } 332 } 333 } 334}