this repo has no description
1use std::convert::Infallible;
2
3use crate::api::error::ApiError;
4use crate::api::proxy_client::proxy_client;
5use crate::state::AppState;
6use axum::{
7 body::Bytes,
8 extract::{RawQuery, Request, State},
9 handler::Handler,
10 http::{HeaderMap, Method, StatusCode},
11 response::{IntoResponse, Response},
12};
13use futures_util::future::Either;
14use tower::{Service, util::BoxCloneSyncService};
15use tracing::{error, info, warn};
16
/// XRPC methods that the proxy handler refuses to forward, even when the client
/// supplies an `atproto-proxy` header. A proxied request for any of these is
/// answered with an `InvalidRequest` error instead of being sent upstream.
const PROTECTED_METHODS: &[&str] = &[
    "com.atproto.admin.sendEmail",
    "com.atproto.identity.requestPlcOperationSignature",
    "com.atproto.identity.signPlcOperation",
    "com.atproto.identity.updateHandle",
    "com.atproto.server.activateAccount",
    "com.atproto.server.confirmEmail",
    "com.atproto.server.createAppPassword",
    "com.atproto.server.deactivateAccount",
    "com.atproto.server.getAccountInviteCodes",
    "com.atproto.server.getSession",
    "com.atproto.server.listAppPasswords",
    "com.atproto.server.requestAccountDelete",
    "com.atproto.server.requestEmailConfirmation",
    "com.atproto.server.requestEmailUpdate",
    "com.atproto.server.revokeAppPassword",
    "com.atproto.server.updateEmail",
];

/// Reports whether `method` appears in [`PROTECTED_METHODS`].
///
/// The comparison is an exact, case-sensitive string match; no normalization
/// (trimming, case folding, percent-decoding) is applied to `method`.
fn is_protected_method(method: &str) -> bool {
    PROTECTED_METHODS
        .iter()
        .any(|&protected| protected == method)
}
39
40pub struct XrpcProxyLayer {
41 state: AppState,
42}
43
44impl XrpcProxyLayer {
45 pub fn new(state: AppState) -> Self {
46 XrpcProxyLayer { state }
47 }
48}
49
50impl<S> tower_layer::Layer<S> for XrpcProxyLayer {
51 type Service = XrpcProxyingService<S>;
52
53 fn layer(&self, inner: S) -> Self::Service {
54 XrpcProxyingService {
55 inner,
56 // TODO(nel): make our own service here instead of boxing a HandlerService
57 handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())),
58 }
59 }
60}
61
62#[derive(Clone)]
63pub struct XrpcProxyingService<S> {
64 inner: S,
65 handler: BoxCloneSyncService<Request, Response, Infallible>,
66}
67
68impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request>
69 for XrpcProxyingService<S>
70{
71 type Response = Response;
72
73 type Error = Infallible;
74
75 type Future = Either<
76 <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future,
77 S::Future,
78 >;
79
80 fn poll_ready(
81 &mut self,
82 cx: &mut std::task::Context<'_>,
83 ) -> std::task::Poll<Result<(), Self::Error>> {
84 self.inner.poll_ready(cx)
85 }
86
87 fn call(&mut self, req: Request) -> Self::Future {
88 if req
89 .headers()
90 .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy))
91 {
92 // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it.
93 if !std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_err()
94 && (req.uri().path().ends_with("app.bsky.ageassurance.getState")
95 || req
96 .uri()
97 .path()
98 .ends_with("app.bsky.unspecced.getAgeAssuranceState"))
99 {
100 return Either::Right(self.inner.call(req));
101 }
102
103 Either::Left(self.handler.call(req))
104 } else {
105 Either::Right(self.inner.call(req))
106 }
107 }
108}
109
110async fn proxy_handler(
111 State(state): State<AppState>,
112 uri: http::Uri,
113 method_verb: Method,
114 headers: HeaderMap,
115 RawQuery(query): RawQuery,
116 body: Bytes,
117) -> Response {
118 // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the /
119 let method = uri.path().trim_start_matches("/");
120 if is_protected_method(&method) {
121 warn!(method = %method, "Attempted to proxy protected method");
122 return ApiError::InvalidRequest(format!("Cannot proxy protected method: {}", method))
123 .into_response();
124 }
125
126 let Some(proxy_header) = headers
127 .get("atproto-proxy")
128 .and_then(|h| h.to_str().ok())
129 .map(String::from)
130 else {
131 return ApiError::InvalidRequest("Missing required atproto-proxy header".into())
132 .into_response();
133 };
134
135 let did = proxy_header.split('#').next().unwrap_or(&proxy_header);
136 let Some(resolved) = state.did_resolver.resolve_did(did).await else {
137 error!(did = %did, "Could not resolve service DID");
138 return ApiError::UpstreamFailure.into_response();
139 };
140
141 let target_url = match &query {
142 Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
143 None => format!("{}/xrpc/{}", resolved.url, method),
144 };
145 info!("Proxying {} request to {}", method_verb, target_url);
146
147 let client = proxy_client();
148 let mut request_builder = client.request(method_verb, &target_url);
149
150 let mut auth_header_val = headers.get("Authorization").cloned();
151 if let Some(token) = crate::auth::extract_bearer_token_from_header(
152 headers.get("Authorization").and_then(|h| h.to_str().ok()),
153 ) {
154 match crate::auth::validate_bearer_token(&state.db, &token).await {
155 Ok(auth_user) => {
156 if let Err(e) = crate::auth::scope_check::check_rpc_scope(
157 auth_user.is_oauth,
158 auth_user.scope.as_deref(),
159 &resolved.did,
160 &method,
161 ) {
162 return e;
163 }
164
165 if let Some(key_bytes) = auth_user.key_bytes {
166 match crate::auth::create_service_token(
167 &auth_user.did,
168 &resolved.did,
169 &method,
170 &key_bytes,
171 ) {
172 Ok(new_token) => {
173 if let Ok(val) =
174 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
175 {
176 auth_header_val = Some(val);
177 }
178 }
179 Err(e) => {
180 warn!("Failed to create service token: {:?}", e);
181 }
182 }
183 }
184 }
185 Err(e) => {
186 warn!("Token validation failed: {:?}", e);
187 if matches!(e, crate::auth::TokenValidationError::TokenExpired) {
188 let auth_header_str = headers
189 .get("Authorization")
190 .and_then(|h| h.to_str().ok())
191 .unwrap_or("");
192 let is_dpop = auth_header_str
193 .trim()
194 .get(..5)
195 .is_some_and(|s| s.eq_ignore_ascii_case("dpop "));
196 let scheme = if is_dpop { "DPoP" } else { "Bearer" };
197 let www_auth = format!(
198 "{} error=\"invalid_token\", error_description=\"Token has expired\"",
199 scheme
200 );
201 let mut response =
202 ApiError::ExpiredToken(Some("Token has expired".into())).into_response();
203 response
204 .headers_mut()
205 .insert("WWW-Authenticate", www_auth.parse().unwrap());
206 if is_dpop {
207 let nonce = crate::oauth::verify::generate_dpop_nonce();
208 response
209 .headers_mut()
210 .insert("DPoP-Nonce", nonce.parse().unwrap());
211 }
212 return response;
213 }
214 }
215 }
216 }
217
218 if let Some(val) = auth_header_val {
219 request_builder = request_builder.header("Authorization", val);
220 }
221 for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD {
222 if let Some(val) = headers.get(*header_name) {
223 request_builder = request_builder.header(*header_name, val);
224 }
225 }
226 if !body.is_empty() {
227 request_builder = request_builder.body(body);
228 }
229
230 match request_builder.send().await {
231 Ok(resp) => {
232 let status = resp.status();
233 let headers = resp.headers().clone();
234 let body = match resp.bytes().await {
235 Ok(b) => b,
236 Err(e) => {
237 error!("Error reading proxy response body: {:?}", e);
238 return (StatusCode::BAD_GATEWAY, "Error reading upstream response")
239 .into_response();
240 }
241 };
242 let mut response_builder = Response::builder().status(status);
243 for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD {
244 if let Some(val) = headers.get(*header_name) {
245 response_builder = response_builder.header(*header_name, val);
246 }
247 }
248 match response_builder.body(axum::body::Body::from(body)) {
249 Ok(r) => r,
250 Err(e) => {
251 error!("Error building proxy response: {:?}", e);
252 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
253 }
254 }
255 }
256 Err(e) => {
257 error!("Error sending proxy request: {:?}", e);
258 if e.is_timeout() {
259 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response()
260 } else {
261 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response()
262 }
263 }
264 }
265}