this repo has no description
1use std::convert::Infallible;
2
3use crate::api::proxy_client::proxy_client;
4use crate::state::AppState;
5use axum::{
6 Json,
7 body::Bytes,
8 extract::{RawQuery, Request, State},
9 handler::Handler,
10 http::{HeaderMap, Method, StatusCode},
11 response::{IntoResponse, Response},
12};
13use futures_util::future::Either;
14use serde_json::json;
15use tower::{Service, util::BoxCloneSyncService};
16use tracing::{error, info, warn};
17
/// XRPC method names that `proxy_handler` refuses to forward to another service.
///
/// Entries are compared against the request's method name exactly
/// (case-sensitive, no leading slash).
const PROTECTED_METHODS: &[&str] = &[
    "com.atproto.admin.sendEmail",
    "com.atproto.identity.requestPlcOperationSignature",
    "com.atproto.identity.signPlcOperation",
    "com.atproto.identity.updateHandle",
    "com.atproto.server.activateAccount",
    "com.atproto.server.confirmEmail",
    "com.atproto.server.createAppPassword",
    "com.atproto.server.deactivateAccount",
    "com.atproto.server.getAccountInviteCodes",
    "com.atproto.server.getSession",
    "com.atproto.server.listAppPasswords",
    "com.atproto.server.requestAccountDelete",
    "com.atproto.server.requestEmailConfirmation",
    "com.atproto.server.requestEmailUpdate",
    "com.atproto.server.revokeAppPassword",
    "com.atproto.server.updateEmail",
];

/// Returns `true` when `method` is exactly one of the [`PROTECTED_METHODS`].
fn is_protected_method(method: &str) -> bool {
    PROTECTED_METHODS
        .iter()
        .any(|protected| *protected == method)
}
40
41pub struct XrpcProxyLayer {
42 state: AppState,
43}
44
45impl XrpcProxyLayer {
46 pub fn new(state: AppState) -> Self {
47 XrpcProxyLayer { state }
48 }
49}
50
51impl<S> tower_layer::Layer<S> for XrpcProxyLayer {
52 type Service = XrpcProxyingService<S>;
53
54 fn layer(&self, inner: S) -> Self::Service {
55 XrpcProxyingService {
56 inner,
57 // TODO(nel): make our own service here instead of boxing a HandlerService
58 handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())),
59 }
60 }
61}
62
63#[derive(Clone)]
64pub struct XrpcProxyingService<S> {
65 inner: S,
66 handler: BoxCloneSyncService<Request, Response, Infallible>,
67}
68
69impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request>
70 for XrpcProxyingService<S>
71{
72 type Response = Response;
73
74 type Error = Infallible;
75
76 type Future = Either<
77 <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future,
78 S::Future,
79 >;
80
81 fn poll_ready(
82 &mut self,
83 cx: &mut std::task::Context<'_>,
84 ) -> std::task::Poll<Result<(), Self::Error>> {
85 self.inner.poll_ready(cx)
86 }
87
88 fn call(&mut self, req: Request) -> Self::Future {
89 if req
90 .headers()
91 .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy))
92 {
93 // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it.
94 if !std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err()
95 && (req.uri().path().ends_with("app.bsky.ageassurance.getState")
96 || req
97 .uri()
98 .path()
99 .ends_with("app.bsky.unspecced.getAgeAssuranceState"))
100 {
101 return Either::Right(self.inner.call(req));
102 }
103
104 Either::Left(self.handler.call(req))
105 } else {
106 Either::Right(self.inner.call(req))
107 }
108 }
109}
110
111async fn proxy_handler(
112 State(state): State<AppState>,
113 uri: http::Uri,
114 method_verb: Method,
115 headers: HeaderMap,
116 RawQuery(query): RawQuery,
117 body: Bytes,
118) -> Response {
119 // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the /
120 let method = uri.path().trim_start_matches("/");
121 if is_protected_method(&method) {
122 warn!(method = %method, "Attempted to proxy protected method");
123 return (
124 StatusCode::BAD_REQUEST,
125 Json(json!({
126 "error": "InvalidRequest",
127 "message": format!("Cannot proxy protected method: {}", method)
128 })),
129 )
130 .into_response();
131 }
132
133 let proxy_header = match headers.get("atproto-proxy").and_then(|h| h.to_str().ok()) {
134 Some(h) => h.to_string(),
135 None => {
136 return (
137 StatusCode::BAD_REQUEST,
138 Json(json!({
139 "error": "InvalidRequest",
140 "message": "Missing required atproto-proxy header"
141 })),
142 )
143 .into_response();
144 }
145 };
146
147 let did = proxy_header.split('#').next().unwrap_or(&proxy_header);
148 let resolved = match state.did_resolver.resolve_did(did).await {
149 Some(r) => r,
150 None => {
151 error!(did = %did, "Could not resolve service DID");
152 return (
153 StatusCode::BAD_GATEWAY,
154 Json(json!({
155 "error": "UpstreamFailure",
156 "message": "Could not resolve service DID"
157 })),
158 )
159 .into_response();
160 }
161 };
162
163 let target_url = match &query {
164 Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
165 None => format!("{}/xrpc/{}", resolved.url, method),
166 };
167 info!("Proxying {} request to {}", method_verb, target_url);
168
169 let client = proxy_client();
170 let mut request_builder = client.request(method_verb, &target_url);
171
172 let mut auth_header_val = headers.get("Authorization").cloned();
173 if let Some(token) = crate::auth::extract_bearer_token_from_header(
174 headers.get("Authorization").and_then(|h| h.to_str().ok()),
175 ) {
176 match crate::auth::validate_bearer_token(&state.db, &token).await {
177 Ok(auth_user) => {
178 if let Err(e) = crate::auth::scope_check::check_rpc_scope(
179 auth_user.is_oauth,
180 auth_user.scope.as_deref(),
181 &resolved.did,
182 &method,
183 ) {
184 return e;
185 }
186
187 if let Some(key_bytes) = auth_user.key_bytes {
188 match crate::auth::create_service_token(
189 &auth_user.did,
190 &resolved.did,
191 &method,
192 &key_bytes,
193 ) {
194 Ok(new_token) => {
195 if let Ok(val) =
196 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
197 {
198 auth_header_val = Some(val);
199 }
200 }
201 Err(e) => {
202 warn!("Failed to create service token: {:?}", e);
203 }
204 }
205 }
206 }
207 Err(e) => {
208 warn!("Token validation failed: {:?}", e);
209 if matches!(e, crate::auth::TokenValidationError::TokenExpired) {
210 let auth_header_str = headers
211 .get("Authorization")
212 .and_then(|h| h.to_str().ok())
213 .unwrap_or("");
214 let is_dpop = auth_header_str
215 .trim()
216 .get(..5)
217 .is_some_and(|s| s.eq_ignore_ascii_case("dpop "));
218 let scheme = if is_dpop { "DPoP" } else { "Bearer" };
219 let www_auth = format!(
220 "{} error=\"invalid_token\", error_description=\"Token has expired\"",
221 scheme
222 );
223 let mut response = (
224 StatusCode::UNAUTHORIZED,
225 Json(json!({
226 "error": "ExpiredToken",
227 "message": "Token has expired"
228 })),
229 )
230 .into_response();
231 response
232 .headers_mut()
233 .insert("WWW-Authenticate", www_auth.parse().unwrap());
234 if is_dpop {
235 let nonce = crate::oauth::verify::generate_dpop_nonce();
236 response
237 .headers_mut()
238 .insert("DPoP-Nonce", nonce.parse().unwrap());
239 }
240 return response;
241 }
242 }
243 }
244 }
245
246 if let Some(val) = auth_header_val {
247 request_builder = request_builder.header("Authorization", val);
248 }
249 for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD {
250 if let Some(val) = headers.get(*header_name) {
251 request_builder = request_builder.header(*header_name, val);
252 }
253 }
254 if !body.is_empty() {
255 request_builder = request_builder.body(body);
256 }
257
258 match request_builder.send().await {
259 Ok(resp) => {
260 let status = resp.status();
261 let headers = resp.headers().clone();
262 let body = match resp.bytes().await {
263 Ok(b) => b,
264 Err(e) => {
265 error!("Error reading proxy response body: {:?}", e);
266 return (StatusCode::BAD_GATEWAY, "Error reading upstream response")
267 .into_response();
268 }
269 };
270 let mut response_builder = Response::builder().status(status);
271 for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD {
272 if let Some(val) = headers.get(*header_name) {
273 response_builder = response_builder.header(*header_name, val);
274 }
275 }
276 match response_builder.body(axum::body::Body::from(body)) {
277 Ok(r) => r,
278 Err(e) => {
279 error!("Error building proxy response: {:?}", e);
280 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
281 }
282 }
283 }
284 Err(e) => {
285 error!("Error sending proxy request: {:?}", e);
286 if e.is_timeout() {
287 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response()
288 } else {
289 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response()
290 }
291 }
292 }
293}