this repo has no description
1use std::convert::Infallible;
2
3use crate::api::error::ApiError;
4use crate::api::proxy_client::proxy_client;
5use crate::state::AppState;
6use axum::{
7 body::Bytes,
8 extract::{RawQuery, Request, State},
9 handler::Handler,
10 http::{HeaderMap, Method, StatusCode},
11 response::{IntoResponse, Response},
12};
13use futures_util::future::Either;
14use tower::{Service, util::BoxCloneSyncService};
15use tracing::{error, info, warn};
16
/// XRPC methods (bare NSIDs) that this PDS always serves itself.
///
/// A request for any of these is handed to the local router even when the client
/// supplies an `atproto-proxy` header, and the proxy handler refuses to forward them.
const PROTECTED_METHODS: &[&str] = &[
    "app.bsky.actor.getPreferences",
    "app.bsky.actor.putPreferences",
    "com.atproto.admin.deleteAccount",
    "com.atproto.admin.disableAccountInvites",
    "com.atproto.admin.disableInviteCodes",
    "com.atproto.admin.enableAccountInvites",
    "com.atproto.admin.getAccountInfo",
    "com.atproto.admin.getAccountInfos",
    "com.atproto.admin.getInviteCodes",
    "com.atproto.admin.getSubjectStatus",
    "com.atproto.admin.searchAccounts",
    "com.atproto.admin.sendEmail",
    "com.atproto.admin.updateAccountEmail",
    "com.atproto.admin.updateAccountHandle",
    "com.atproto.admin.updateAccountPassword",
    "com.atproto.admin.updateSubjectStatus",
    "com.atproto.identity.getRecommendedDidCredentials",
    "com.atproto.identity.requestPlcOperationSignature",
    "com.atproto.identity.signPlcOperation",
    "com.atproto.identity.submitPlcOperation",
    "com.atproto.identity.updateHandle",
    "com.atproto.repo.applyWrites",
    "com.atproto.repo.createRecord",
    "com.atproto.repo.deleteRecord",
    "com.atproto.repo.importRepo",
    "com.atproto.repo.putRecord",
    "com.atproto.repo.uploadBlob",
    "com.atproto.server.activateAccount",
    "com.atproto.server.checkAccountStatus",
    "com.atproto.server.confirmEmail",
    "com.atproto.server.confirmSignup",
    "com.atproto.server.createAccount",
    "com.atproto.server.createAppPassword",
    "com.atproto.server.createInviteCode",
    "com.atproto.server.createInviteCodes",
    "com.atproto.server.createSession",
    "com.atproto.server.createTotpSecret",
    "com.atproto.server.deactivateAccount",
    "com.atproto.server.deleteAccount",
    "com.atproto.server.deletePasskey",
    "com.atproto.server.deleteSession",
    "com.atproto.server.describeServer",
    "com.atproto.server.disableTotp",
    "com.atproto.server.enableTotp",
    "com.atproto.server.finishPasskeyRegistration",
    "com.atproto.server.getAccountInviteCodes",
    "com.atproto.server.getServiceAuth",
    "com.atproto.server.getSession",
    "com.atproto.server.getTotpStatus",
    "com.atproto.server.listAppPasswords",
    "com.atproto.server.listPasskeys",
    "com.atproto.server.refreshSession",
    "com.atproto.server.regenerateBackupCodes",
    "com.atproto.server.requestAccountDelete",
    "com.atproto.server.requestEmailConfirmation",
    "com.atproto.server.requestEmailUpdate",
    "com.atproto.server.requestPasswordReset",
    "com.atproto.server.resendMigrationVerification",
    "com.atproto.server.resendVerification",
    "com.atproto.server.reserveSigningKey",
    "com.atproto.server.resetPassword",
    "com.atproto.server.revokeAppPassword",
    "com.atproto.server.startPasskeyRegistration",
    "com.atproto.server.updateEmail",
    "com.atproto.server.updatePasskey",
    "com.atproto.server.verifyMigrationEmail",
    "com.atproto.sync.getBlob",
    "com.atproto.sync.getBlocks",
    "com.atproto.sync.getCheckout",
    "com.atproto.sync.getHead",
    "com.atproto.sync.getLatestCommit",
    "com.atproto.sync.getRecord",
    "com.atproto.sync.getRepo",
    "com.atproto.sync.getRepoStatus",
    "com.atproto.sync.listBlobs",
    "com.atproto.sync.listRepos",
    "com.atproto.sync.notifyOfUpdate",
    "com.atproto.sync.requestCrawl",
    "com.atproto.sync.subscribeRepos",
    "com.atproto.temp.checkSignupQueue",
    "com.atproto.temp.dereferenceScope",
];

/// Reports whether `method` is listed in [`PROTECTED_METHODS`].
///
/// The comparison is exact: the caller must pass the bare NSID
/// (e.g. `com.atproto.repo.createRecord`), without any leading `/`.
fn is_protected_method(method: &str) -> bool {
    PROTECTED_METHODS
        .iter()
        .any(|&protected| protected == method)
}
104
105pub struct XrpcProxyLayer {
106 state: AppState,
107}
108
109impl XrpcProxyLayer {
110 pub fn new(state: AppState) -> Self {
111 XrpcProxyLayer { state }
112 }
113}
114
115impl<S> tower_layer::Layer<S> for XrpcProxyLayer {
116 type Service = XrpcProxyingService<S>;
117
118 fn layer(&self, inner: S) -> Self::Service {
119 XrpcProxyingService {
120 inner,
121 // TODO(nel): make our own service here instead of boxing a HandlerService
122 handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())),
123 }
124 }
125}
126
127#[derive(Clone)]
128pub struct XrpcProxyingService<S> {
129 inner: S,
130 handler: BoxCloneSyncService<Request, Response, Infallible>,
131}
132
133impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request>
134 for XrpcProxyingService<S>
135{
136 type Response = Response;
137
138 type Error = Infallible;
139
140 type Future = Either<
141 <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future,
142 S::Future,
143 >;
144
145 fn poll_ready(
146 &mut self,
147 cx: &mut std::task::Context<'_>,
148 ) -> std::task::Poll<Result<(), Self::Error>> {
149 self.inner.poll_ready(cx)
150 }
151
152 fn call(&mut self, req: Request) -> Self::Future {
153 if req
154 .headers()
155 .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy))
156 {
157 let path = req.uri().path();
158 let method = path.trim_start_matches("/");
159
160 if is_protected_method(method) {
161 return Either::Right(self.inner.call(req));
162 }
163
164 // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it
165 if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_ok()
166 && (path.ends_with("app.bsky.ageassurance.getState")
167 || path.ends_with("app.bsky.unspecced.getAgeAssuranceState"))
168 {
169 return Either::Right(self.inner.call(req));
170 }
171
172 Either::Left(self.handler.call(req))
173 } else {
174 Either::Right(self.inner.call(req))
175 }
176 }
177}
178
179async fn proxy_handler(
180 State(state): State<AppState>,
181 uri: http::Uri,
182 method_verb: Method,
183 headers: HeaderMap,
184 RawQuery(query): RawQuery,
185 body: Bytes,
186) -> Response {
187 // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the /
188 let method = uri.path().trim_start_matches("/");
189 if is_protected_method(method) {
190 warn!(method = %method, "Attempted to proxy protected method");
191 return ApiError::InvalidRequest(format!("Cannot proxy protected method: {}", method))
192 .into_response();
193 }
194
195 let Some(proxy_header) = headers
196 .get("atproto-proxy")
197 .and_then(|h| h.to_str().ok())
198 .map(String::from)
199 else {
200 return ApiError::InvalidRequest("Missing required atproto-proxy header".into())
201 .into_response();
202 };
203
204 let did = proxy_header.split('#').next().unwrap_or(&proxy_header);
205 let Some(resolved) = state.did_resolver.resolve_did(did).await else {
206 error!(did = %did, "Could not resolve service DID");
207 return ApiError::UpstreamFailure.into_response();
208 };
209
210 let target_url = match &query {
211 Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
212 None => format!("{}/xrpc/{}", resolved.url, method),
213 };
214 info!("Proxying {} request to {}", method_verb, target_url);
215
216 let client = proxy_client();
217 let mut request_builder = client.request(method_verb.clone(), &target_url);
218
219 let mut auth_header_val = headers.get("Authorization").cloned();
220 if let Some(extracted) = crate::auth::extract_auth_token_from_header(
221 headers.get("Authorization").and_then(|h| h.to_str().ok()),
222 ) {
223 let token = extracted.token;
224 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
225 let http_uri = uri.to_string();
226
227 match crate::auth::validate_token_with_dpop(
228 &state.db,
229 &token,
230 extracted.is_dpop,
231 dpop_proof,
232 method_verb.as_str(),
233 &http_uri,
234 false,
235 false,
236 )
237 .await
238 {
239 Ok(auth_user) => {
240 if let Err(e) = crate::auth::scope_check::check_rpc_scope(
241 auth_user.is_oauth,
242 auth_user.scope.as_deref(),
243 &resolved.did,
244 method,
245 ) {
246 return e;
247 }
248
249 if let Some(key_bytes) = auth_user.key_bytes {
250 match crate::auth::create_service_token(
251 &auth_user.did,
252 &resolved.did,
253 method,
254 &key_bytes,
255 ) {
256 Ok(new_token) => {
257 if let Ok(val) =
258 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
259 {
260 auth_header_val = Some(val);
261 }
262 }
263 Err(e) => {
264 warn!("Failed to create service token: {:?}", e);
265 }
266 }
267 }
268 }
269 Err(e) => {
270 warn!("Token validation failed: {:?}", e);
271 if matches!(e, crate::auth::TokenValidationError::TokenExpired)
272 && extracted.is_dpop
273 {
274 let www_auth =
275 "DPoP error=\"invalid_token\", error_description=\"Token has expired\"";
276 let mut response =
277 ApiError::ExpiredToken(Some("Token has expired".into())).into_response();
278 *response.status_mut() = axum::http::StatusCode::UNAUTHORIZED;
279 response
280 .headers_mut()
281 .insert("WWW-Authenticate", www_auth.parse().unwrap());
282 let nonce = crate::oauth::verify::generate_dpop_nonce();
283 response
284 .headers_mut()
285 .insert("DPoP-Nonce", nonce.parse().unwrap());
286 return response;
287 }
288 }
289 }
290 }
291
292 if let Some(val) = auth_header_val {
293 request_builder = request_builder.header("Authorization", val);
294 }
295 for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD {
296 if let Some(val) = headers.get(*header_name) {
297 request_builder = request_builder.header(*header_name, val);
298 }
299 }
300 if !body.is_empty() {
301 request_builder = request_builder.body(body);
302 }
303
304 match request_builder.send().await {
305 Ok(resp) => {
306 let status = resp.status();
307 let headers = resp.headers().clone();
308 let body = match resp.bytes().await {
309 Ok(b) => b,
310 Err(e) => {
311 error!("Error reading proxy response body: {:?}", e);
312 return (StatusCode::BAD_GATEWAY, "Error reading upstream response")
313 .into_response();
314 }
315 };
316 let mut response_builder = Response::builder().status(status);
317 for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD {
318 if let Some(val) = headers.get(*header_name) {
319 response_builder = response_builder.header(*header_name, val);
320 }
321 }
322 match response_builder.body(axum::body::Body::from(body)) {
323 Ok(r) => r,
324 Err(e) => {
325 error!("Error building proxy response: {:?}", e);
326 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
327 }
328 }
329 }
330 Err(e) => {
331 error!("Error sending proxy request: {:?}", e);
332 if e.is_timeout() {
333 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response()
334 } else {
335 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response()
336 }
337 }
338 }
339}