this repo has no description
1use std::convert::Infallible;
2
3use crate::api::error::ApiError;
4use crate::api::proxy_client::proxy_client;
5use crate::state::AppState;
6use axum::{
7 body::Bytes,
8 extract::{RawQuery, Request, State},
9 handler::Handler,
10 http::{HeaderMap, Method, StatusCode},
11 response::{IntoResponse, Response},
12};
13use futures_util::future::Either;
14use tower::{Service, util::BoxCloneSyncService};
15use tracing::{error, info, warn};
16
/// XRPC method identifiers that must never be proxied to another service.
///
/// Requests for these methods are handed to the inner service even when the
/// client asks for proxying, and the proxy handler itself refuses them.
const PROTECTED_METHODS: &[&str] = &[
    "app.bsky.actor.getPreferences",
    "app.bsky.actor.putPreferences",
    "com.atproto.admin.deleteAccount",
    "com.atproto.admin.disableAccountInvites",
    "com.atproto.admin.disableInviteCodes",
    "com.atproto.admin.enableAccountInvites",
    "com.atproto.admin.getAccountInfo",
    "com.atproto.admin.getAccountInfos",
    "com.atproto.admin.getInviteCodes",
    "com.atproto.admin.getSubjectStatus",
    "com.atproto.admin.searchAccounts",
    "com.atproto.admin.sendEmail",
    "com.atproto.admin.updateAccountEmail",
    "com.atproto.admin.updateAccountHandle",
    "com.atproto.admin.updateAccountPassword",
    "com.atproto.admin.updateSubjectStatus",
    "com.atproto.identity.getRecommendedDidCredentials",
    "com.atproto.identity.requestPlcOperationSignature",
    "com.atproto.identity.signPlcOperation",
    "com.atproto.identity.submitPlcOperation",
    "com.atproto.identity.updateHandle",
    "com.atproto.repo.applyWrites",
    "com.atproto.repo.createRecord",
    "com.atproto.repo.deleteRecord",
    "com.atproto.repo.importRepo",
    "com.atproto.repo.putRecord",
    "com.atproto.repo.uploadBlob",
    "com.atproto.server.activateAccount",
    "com.atproto.server.checkAccountStatus",
    "com.atproto.server.confirmEmail",
    "com.atproto.server.confirmSignup",
    "com.atproto.server.createAccount",
    "com.atproto.server.createAppPassword",
    "com.atproto.server.createInviteCode",
    "com.atproto.server.createInviteCodes",
    "com.atproto.server.createSession",
    "com.atproto.server.createTotpSecret",
    "com.atproto.server.deactivateAccount",
    "com.atproto.server.deleteAccount",
    "com.atproto.server.deletePasskey",
    "com.atproto.server.deleteSession",
    "com.atproto.server.describeServer",
    "com.atproto.server.disableTotp",
    "com.atproto.server.enableTotp",
    "com.atproto.server.finishPasskeyRegistration",
    "com.atproto.server.getAccountInviteCodes",
    "com.atproto.server.getServiceAuth",
    "com.atproto.server.getSession",
    "com.atproto.server.getTotpStatus",
    "com.atproto.server.listAppPasswords",
    "com.atproto.server.listPasskeys",
    "com.atproto.server.refreshSession",
    "com.atproto.server.regenerateBackupCodes",
    "com.atproto.server.requestAccountDelete",
    "com.atproto.server.requestEmailConfirmation",
    "com.atproto.server.requestEmailUpdate",
    "com.atproto.server.requestPasswordReset",
    "com.atproto.server.resendMigrationVerification",
    "com.atproto.server.resendVerification",
    "com.atproto.server.reserveSigningKey",
    "com.atproto.server.resetPassword",
    "com.atproto.server.revokeAppPassword",
    "com.atproto.server.startPasskeyRegistration",
    "com.atproto.server.updateEmail",
    "com.atproto.server.updatePasskey",
    "com.atproto.server.verifyMigrationEmail",
    "com.atproto.sync.getBlob",
    "com.atproto.sync.getBlocks",
    "com.atproto.sync.getCheckout",
    "com.atproto.sync.getHead",
    "com.atproto.sync.getLatestCommit",
    "com.atproto.sync.getRecord",
    "com.atproto.sync.getRepo",
    "com.atproto.sync.getRepoStatus",
    "com.atproto.sync.listBlobs",
    "com.atproto.sync.listRepos",
    "com.atproto.sync.notifyOfUpdate",
    "com.atproto.sync.requestCrawl",
    "com.atproto.sync.subscribeRepos",
    "com.atproto.temp.checkSignupQueue",
    "com.atproto.temp.dereferenceScope",
];

/// Returns `true` when `method` exactly matches (case-sensitively) one of the
/// entries in [`PROTECTED_METHODS`].
fn is_protected_method(method: &str) -> bool {
    PROTECTED_METHODS
        .iter()
        .any(|&protected| protected == method)
}
104
105pub struct XrpcProxyLayer {
106 state: AppState,
107}
108
109impl XrpcProxyLayer {
110 pub fn new(state: AppState) -> Self {
111 XrpcProxyLayer { state }
112 }
113}
114
115impl<S> tower_layer::Layer<S> for XrpcProxyLayer {
116 type Service = XrpcProxyingService<S>;
117
118 fn layer(&self, inner: S) -> Self::Service {
119 XrpcProxyingService {
120 inner,
121 // TODO(nel): make our own service here instead of boxing a HandlerService
122 handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())),
123 }
124 }
125}
126
127#[derive(Clone)]
128pub struct XrpcProxyingService<S> {
129 inner: S,
130 handler: BoxCloneSyncService<Request, Response, Infallible>,
131}
132
133impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request>
134 for XrpcProxyingService<S>
135{
136 type Response = Response;
137
138 type Error = Infallible;
139
140 type Future = Either<
141 <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future,
142 S::Future,
143 >;
144
145 fn poll_ready(
146 &mut self,
147 cx: &mut std::task::Context<'_>,
148 ) -> std::task::Poll<Result<(), Self::Error>> {
149 self.inner.poll_ready(cx)
150 }
151
152 fn call(&mut self, req: Request) -> Self::Future {
153 if req
154 .headers()
155 .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy))
156 {
157 let path = req.uri().path();
158 let method = path.trim_start_matches("/");
159
160 if is_protected_method(method) {
161 return Either::Right(self.inner.call(req));
162 }
163
164 // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it
165 if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_ok()
166 && (path.ends_with("app.bsky.ageassurance.getState")
167 || path.ends_with("app.bsky.unspecced.getAgeAssuranceState"))
168 {
169 return Either::Right(self.inner.call(req));
170 }
171
172 Either::Left(self.handler.call(req))
173 } else {
174 Either::Right(self.inner.call(req))
175 }
176 }
177}
178
179async fn proxy_handler(
180 State(state): State<AppState>,
181 uri: http::Uri,
182 method_verb: Method,
183 headers: HeaderMap,
184 RawQuery(query): RawQuery,
185 body: Bytes,
186) -> Response {
187 // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the /
188 let method = uri.path().trim_start_matches("/");
189 if is_protected_method(&method) {
190 warn!(method = %method, "Attempted to proxy protected method");
191 return ApiError::InvalidRequest(format!("Cannot proxy protected method: {}", method))
192 .into_response();
193 }
194
195 let Some(proxy_header) = headers
196 .get("atproto-proxy")
197 .and_then(|h| h.to_str().ok())
198 .map(String::from)
199 else {
200 return ApiError::InvalidRequest("Missing required atproto-proxy header".into())
201 .into_response();
202 };
203
204 let did = proxy_header.split('#').next().unwrap_or(&proxy_header);
205 let Some(resolved) = state.did_resolver.resolve_did(did).await else {
206 error!(did = %did, "Could not resolve service DID");
207 return ApiError::UpstreamFailure.into_response();
208 };
209
210 let target_url = match &query {
211 Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
212 None => format!("{}/xrpc/{}", resolved.url, method),
213 };
214 info!("Proxying {} request to {}", method_verb, target_url);
215
216 let client = proxy_client();
217 let mut request_builder = client.request(method_verb, &target_url);
218
219 let mut auth_header_val = headers.get("Authorization").cloned();
220 if let Some(token) = crate::auth::extract_bearer_token_from_header(
221 headers.get("Authorization").and_then(|h| h.to_str().ok()),
222 ) {
223 match crate::auth::validate_bearer_token(&state.db, &token).await {
224 Ok(auth_user) => {
225 if let Err(e) = crate::auth::scope_check::check_rpc_scope(
226 auth_user.is_oauth,
227 auth_user.scope.as_deref(),
228 &resolved.did,
229 &method,
230 ) {
231 return e;
232 }
233
234 if let Some(key_bytes) = auth_user.key_bytes {
235 match crate::auth::create_service_token(
236 &auth_user.did,
237 &resolved.did,
238 &method,
239 &key_bytes,
240 ) {
241 Ok(new_token) => {
242 if let Ok(val) =
243 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
244 {
245 auth_header_val = Some(val);
246 }
247 }
248 Err(e) => {
249 warn!("Failed to create service token: {:?}", e);
250 }
251 }
252 }
253 }
254 Err(e) => {
255 warn!("Token validation failed: {:?}", e);
256 if matches!(e, crate::auth::TokenValidationError::TokenExpired) {
257 let auth_header_str = headers
258 .get("Authorization")
259 .and_then(|h| h.to_str().ok())
260 .unwrap_or("");
261 let is_dpop = auth_header_str
262 .trim()
263 .get(..5)
264 .is_some_and(|s| s.eq_ignore_ascii_case("dpop "));
265 let scheme = if is_dpop { "DPoP" } else { "Bearer" };
266 let www_auth = format!(
267 "{} error=\"invalid_token\", error_description=\"Token has expired\"",
268 scheme
269 );
270 let mut response =
271 ApiError::ExpiredToken(Some("Token has expired".into())).into_response();
272 response
273 .headers_mut()
274 .insert("WWW-Authenticate", www_auth.parse().unwrap());
275 if is_dpop {
276 let nonce = crate::oauth::verify::generate_dpop_nonce();
277 response
278 .headers_mut()
279 .insert("DPoP-Nonce", nonce.parse().unwrap());
280 }
281 return response;
282 }
283 }
284 }
285 }
286
287 if let Some(val) = auth_header_val {
288 request_builder = request_builder.header("Authorization", val);
289 }
290 for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD {
291 if let Some(val) = headers.get(*header_name) {
292 request_builder = request_builder.header(*header_name, val);
293 }
294 }
295 if !body.is_empty() {
296 request_builder = request_builder.body(body);
297 }
298
299 match request_builder.send().await {
300 Ok(resp) => {
301 let status = resp.status();
302 let headers = resp.headers().clone();
303 let body = match resp.bytes().await {
304 Ok(b) => b,
305 Err(e) => {
306 error!("Error reading proxy response body: {:?}", e);
307 return (StatusCode::BAD_GATEWAY, "Error reading upstream response")
308 .into_response();
309 }
310 };
311 let mut response_builder = Response::builder().status(status);
312 for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD {
313 if let Some(val) = headers.get(*header_name) {
314 response_builder = response_builder.header(*header_name, val);
315 }
316 }
317 match response_builder.body(axum::body::Body::from(body)) {
318 Ok(r) => r,
319 Err(e) => {
320 error!("Error building proxy response: {:?}", e);
321 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
322 }
323 }
324 }
325 Err(e) => {
326 error!("Error sending proxy request: {:?}", e);
327 if e.is_timeout() {
328 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response()
329 } else {
330 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response()
331 }
332 }
333 }
334}