tangled
alpha
login
or
join now
quilling.dev
/
bluepds
6
fork
atom
Alternative ATProto PDS implementation
6
fork
atom
overview
issues
pulls
pipelines
prototype pipethrough
quilling.dev
9 months ago
9f2716e6
b010f179
+637
-35
3 changed files
expand all
collapse all
unified
split
src
apis
com
atproto
repo
get_record.rs
lib.rs
pipethrough.rs
+30
-35
src/apis/com/atproto/repo/get_record.rs
···
1
1
//! Get a single record from a repository. Does not require auth.
2
2
+
3
3
+
use crate::pipethrough::{ProxyRequest, pipethrough};
4
4
+
2
5
use super::*;
3
6
7
7
+
use rsky_pds::pipethrough::OverrideOpts;
8
8
+
4
9
async fn inner_get_record(
5
10
repo: String,
6
11
collection: String,
7
12
rkey: String,
8
13
cid: Option<String>,
9
9
-
// req: ProxyRequest<'_>,
14
14
+
req: ProxyRequest,
10
15
actor_pools: HashMap<String, ActorStorage>,
11
16
account_manager: Arc<RwLock<AccountManager>>,
12
17
) -> Result<GetRecordOutput> {
···
31
36
_ => bail!("Could not locate record: `{uri}`"),
32
37
}
33
38
} else {
34
34
-
// match req.cfg.bsky_app_view {
35
35
-
// None => bail!("Could not locate record"),
36
36
-
// Some(_) => match pipethrough(
37
37
-
// &req,
38
38
-
// None,
39
39
-
// OverrideOpts {
40
40
-
// aud: None,
41
41
-
// lxm: None,
42
42
-
// },
43
43
-
// )
44
44
-
// .await
45
45
-
// {
46
46
-
// Err(error) => {
47
47
-
// tracing::error!("@LOG: ERROR: {error}");
48
48
-
bail!("Could not locate record")
49
49
-
// }
50
50
-
// Ok(res) => {
51
51
-
// let output: GetRecordOutput = serde_json::from_slice(res.buffer.as_slice())?;
52
52
-
// Ok(output)
53
53
-
// }
54
54
-
// },
55
55
-
// }
39
39
+
match req.cfg.bsky_app_view {
40
40
+
None => bail!("Could not locate record"),
41
41
+
Some(_) => match pipethrough(
42
42
+
&req,
43
43
+
None,
44
44
+
OverrideOpts {
45
45
+
aud: None,
46
46
+
lxm: None,
47
47
+
},
48
48
+
)
49
49
+
.await
50
50
+
{
51
51
+
Err(error) => {
52
52
+
tracing::error!("@LOG: ERROR: {error}");
53
53
+
bail!("Could not locate record")
54
54
+
}
55
55
+
Ok(res) => {
56
56
+
let output: GetRecordOutput = serde_json::from_slice(res.buffer.as_slice())?;
57
57
+
Ok(output)
58
58
+
}
59
59
+
},
60
60
+
}
56
61
}
57
62
}
58
63
···
73
78
Query(input): Query<ParametersData>,
74
79
State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
75
80
State(account_manager): State<Arc<RwLock<AccountManager>>>,
81
81
+
req: ProxyRequest,
76
82
) -> Result<Json<GetRecordOutput>, ApiError> {
77
83
let repo = input.repo;
78
84
let collection = input.collection;
79
85
let rkey = input.rkey;
80
86
let cid = input.cid;
81
81
-
// let req: ProxyRequest = todo!(); // TODO: Implement service proxy
82
82
-
match inner_get_record(
83
83
-
repo,
84
84
-
collection,
85
85
-
rkey,
86
86
-
cid,
87
87
-
// req,
88
88
-
db_actors,
89
89
-
account_manager,
90
90
-
)
91
91
-
.await
92
92
-
{
87
87
+
match inner_get_record(repo, collection, rkey, cid, req, db_actors, account_manager).await {
93
88
Ok(res) => Ok(Json(res)),
94
89
Err(error) => {
95
90
tracing::error!("@LOG: ERROR: {error}");
+1
src/lib.rs
···
11
11
mod metrics;
12
12
mod models;
13
13
mod oauth;
14
14
+
mod pipethrough;
14
15
mod schema;
15
16
mod serve;
16
17
mod service_proxy;
+606
src/pipethrough.rs
···
1
1
+
//! Based on https://github.com/blacksky-algorithms/rsky/blob/main/rsky-pds/src/pipethrough.rs
2
2
+
//! blacksky-algorithms/rsky is licensed under the Apache License 2.0
3
3
+
//!
4
4
+
//! Modified for Axum instead of Rocket
5
5
+
6
6
+
use anyhow::{Result, bail};
7
7
+
use axum::extract::{FromRequestParts, State};
8
8
+
use rsky_identity::IdResolver;
9
9
+
use rsky_pds::apis::ApiError;
10
10
+
use rsky_pds::auth_verifier::{AccessOutput, AccessStandard};
11
11
+
use rsky_pds::config::{ServerConfig, ServiceConfig, env_to_cfg};
12
12
+
use rsky_pds::pipethrough::{OverrideOpts, ProxyHeader, UrlAndAud};
13
13
+
use rsky_pds::xrpc_server::types::{HandlerPipeThrough, InvalidRequestError, XRPCError};
14
14
+
use rsky_pds::{APP_USER_AGENT, SharedIdResolver, context};
15
15
+
// use lazy_static::lazy_static;
16
16
+
use reqwest::header::{CONTENT_TYPE, HeaderValue};
17
17
+
use reqwest::{Client, Method, RequestBuilder, Response};
18
18
+
// use rocket::data::ToByteUnit;
19
19
+
// use rocket::http::{Method, Status};
20
20
+
// use rocket::request::{FromRequest, Outcome, Request};
21
21
+
// use rocket::{Data, State};
22
22
+
use axum::{
23
23
+
body::Bytes,
24
24
+
http::{self, HeaderMap},
25
25
+
};
26
26
+
use rsky_common::{GetServiceEndpointOpts, get_service_endpoint};
27
27
+
use rsky_repo::types::Ids;
28
28
+
use serde::de::DeserializeOwned;
29
29
+
use serde_json::Value as JsonValue;
30
30
+
use std::collections::{BTreeMap, HashSet};
31
31
+
use std::str::FromStr;
32
32
+
use std::sync::Arc;
33
33
+
use std::time::Duration;
34
34
+
use ubyte::ToByteUnit as _;
35
35
+
use url::Url;
36
36
+
37
37
+
use crate::serve::AppState;
38
38
+
39
39
+
// pub struct OverrideOpts {
40
40
+
// pub aud: Option<String>,
41
41
+
// pub lxm: Option<String>,
42
42
+
// }
43
43
+
44
44
+
// pub struct UrlAndAud {
45
45
+
// pub url: Url,
46
46
+
// pub aud: String,
47
47
+
// pub lxm: String,
48
48
+
// }
49
49
+
50
50
+
// pub struct ProxyHeader {
51
51
+
// pub did: String,
52
52
+
// pub service_url: String,
53
53
+
// }
54
54
+
55
55
+
pub struct ProxyRequest {
56
56
+
pub headers: BTreeMap<String, String>,
57
57
+
pub query: Option<String>,
58
58
+
pub path: String,
59
59
+
pub method: Method,
60
60
+
pub id_resolver: Arc<tokio::sync::RwLock<rsky_identity::IdResolver>>,
61
61
+
pub cfg: ServerConfig,
62
62
+
}
63
63
+
impl FromRequestParts<AppState> for ProxyRequest {
64
64
+
// type Rejection = ApiError;
65
65
+
type Rejection = axum::response::Response;
66
66
+
67
67
+
async fn from_request_parts(
68
68
+
parts: &mut axum::http::request::Parts,
69
69
+
state: &AppState,
70
70
+
) -> Result<Self, Self::Rejection> {
71
71
+
let headers = parts
72
72
+
.headers
73
73
+
.iter()
74
74
+
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
75
75
+
.collect::<BTreeMap<String, String>>();
76
76
+
let query = parts.uri.query().map(|s| s.to_string());
77
77
+
let path = parts.uri.path().to_string();
78
78
+
let method = parts.method.clone();
79
79
+
let id_resolver = state.id_resolver.clone();
80
80
+
// let cfg = state.cfg.clone();
81
81
+
let cfg = env_to_cfg(); // TODO: use state.cfg.clone();
82
82
+
83
83
+
Ok(Self {
84
84
+
headers,
85
85
+
query,
86
86
+
path,
87
87
+
method,
88
88
+
id_resolver,
89
89
+
cfg,
90
90
+
})
91
91
+
}
92
92
+
}
93
93
+
94
94
+
// #[rocket::async_trait]
95
95
+
// impl<'r> FromRequest<'r> for HandlerPipeThrough {
96
96
+
// type Error = anyhow::Error;
97
97
+
98
98
+
// #[tracing::instrument(skip_all)]
99
99
+
// async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
100
100
+
// match AccessStandard::from_request(req).await {
101
101
+
// Outcome::Success(output) => {
102
102
+
// let AccessOutput { credentials, .. } = output.access;
103
103
+
// let requester: Option<String> = match credentials {
104
104
+
// None => None,
105
105
+
// Some(credentials) => credentials.did,
106
106
+
// };
107
107
+
// let headers = req.headers().clone().into_iter().fold(
108
108
+
// BTreeMap::new(),
109
109
+
// |mut acc: BTreeMap<String, String>, cur| {
110
110
+
// let _ = acc.insert(cur.name().to_string(), cur.value().to_string());
111
111
+
// acc
112
112
+
// },
113
113
+
// );
114
114
+
// let proxy_req = ProxyRequest {
115
115
+
// headers,
116
116
+
// query: match req.uri().query() {
117
117
+
// None => None,
118
118
+
// Some(query) => Some(query.to_string()),
119
119
+
// },
120
120
+
// path: req.uri().path().to_string(),
121
121
+
// method: req.method(),
122
122
+
// id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(),
123
123
+
// cfg: req.guard::<&State<ServerConfig>>().await.unwrap(),
124
124
+
// };
125
125
+
// match pipethrough(
126
126
+
// &proxy_req,
127
127
+
// requester,
128
128
+
// OverrideOpts {
129
129
+
// aud: None,
130
130
+
// lxm: None,
131
131
+
// },
132
132
+
// )
133
133
+
// .await
134
134
+
// {
135
135
+
// Ok(res) => Outcome::Success(res),
136
136
+
// Err(error) => match error.downcast_ref() {
137
137
+
// Some(InvalidRequestError::XRPCError(xrpc)) => {
138
138
+
// if let XRPCError::FailedResponse {
139
139
+
// status,
140
140
+
// error,
141
141
+
// message,
142
142
+
// headers,
143
143
+
// } = xrpc
144
144
+
// {
145
145
+
// tracing::error!(
146
146
+
// "@LOG: XRPC ERROR Status:{status}; Message: {message:?}; Error: {error:?}; Headers: {headers:?}"
147
147
+
// );
148
148
+
// }
149
149
+
// req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
150
150
+
// Outcome::Error((Status::BadRequest, error))
151
151
+
// }
152
152
+
// _ => {
153
153
+
// req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
154
154
+
// Outcome::Error((Status::BadRequest, error))
155
155
+
// }
156
156
+
// },
157
157
+
// }
158
158
+
// }
159
159
+
// Outcome::Error(err) => {
160
160
+
// req.local_cache(|| Some(ApiError::RuntimeError));
161
161
+
// Outcome::Error((
162
162
+
// Status::BadRequest,
163
163
+
// anyhow::Error::new(InvalidRequestError::AuthError(err.1)),
164
164
+
// ))
165
165
+
// }
166
166
+
// _ => panic!("Unexpected outcome during Pipethrough"),
167
167
+
// }
168
168
+
// }
169
169
+
// }
170
170
+
171
171
+
// #[rocket::async_trait]
172
172
+
// impl<'r> FromRequest<'r> for ProxyRequest<'r> {
173
173
+
// type Error = anyhow::Error;
174
174
+
175
175
+
// async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
176
176
+
// let headers = req.headers().clone().into_iter().fold(
177
177
+
// BTreeMap::new(),
178
178
+
// |mut acc: BTreeMap<String, String>, cur| {
179
179
+
// let _ = acc.insert(cur.name().to_string(), cur.value().to_string());
180
180
+
// acc
181
181
+
// },
182
182
+
// );
183
183
+
// Outcome::Success(Self {
184
184
+
// headers,
185
185
+
// query: match req.uri().query() {
186
186
+
// None => None,
187
187
+
// Some(query) => Some(query.to_string()),
188
188
+
// },
189
189
+
// path: req.uri().path().to_string(),
190
190
+
// method: req.method(),
191
191
+
// id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(),
192
192
+
// cfg: req.guard::<&State<ServerConfig>>().await.unwrap(),
193
193
+
// })
194
194
+
// }
195
195
+
// }
196
196
+
197
197
+
pub async fn pipethrough(
198
198
+
req: &ProxyRequest,
199
199
+
requester: Option<String>,
200
200
+
override_opts: OverrideOpts,
201
201
+
) -> Result<HandlerPipeThrough> {
202
202
+
let UrlAndAud {
203
203
+
url,
204
204
+
aud,
205
205
+
lxm: nsid,
206
206
+
} = format_url_and_aud(req, override_opts.aud).await?;
207
207
+
let lxm = override_opts.lxm.unwrap_or(nsid);
208
208
+
let headers = format_headers(req, aud, lxm, requester).await?;
209
209
+
let req_init = format_req_init(req, url, headers, None)?;
210
210
+
let res = make_request(req_init).await?;
211
211
+
parse_proxy_res(res).await
212
212
+
}
213
213
+
214
214
+
pub async fn pipethrough_procedure<T: serde::Serialize>(
215
215
+
req: &ProxyRequest,
216
216
+
requester: Option<String>,
217
217
+
body: Option<T>,
218
218
+
) -> Result<HandlerPipeThrough> {
219
219
+
let UrlAndAud {
220
220
+
url,
221
221
+
aud,
222
222
+
lxm: nsid,
223
223
+
} = format_url_and_aud(req, None).await?;
224
224
+
let headers = format_headers(req, aud, nsid, requester).await?;
225
225
+
let encoded_body: Option<Vec<u8>> = match body {
226
226
+
None => None,
227
227
+
Some(body) => Some(serde_json::to_string(&body)?.into_bytes()),
228
228
+
};
229
229
+
let req_init = format_req_init(req, url, headers, encoded_body)?;
230
230
+
let res = make_request(req_init).await?;
231
231
+
parse_proxy_res(res).await
232
232
+
}
233
233
+
234
234
+
#[tracing::instrument(skip_all)]
235
235
+
pub async fn pipethrough_procedure_post(
236
236
+
req: &ProxyRequest,
237
237
+
requester: Option<String>,
238
238
+
body: Option<Bytes>,
239
239
+
) -> Result<HandlerPipeThrough, ApiError> {
240
240
+
let UrlAndAud {
241
241
+
url,
242
242
+
aud,
243
243
+
lxm: nsid,
244
244
+
} = format_url_and_aud(req, None).await?;
245
245
+
let headers = format_headers(req, aud, nsid, requester).await?;
246
246
+
let encoded_body: Option<JsonValue>;
247
247
+
match body {
248
248
+
None => encoded_body = None,
249
249
+
Some(body) => {
250
250
+
// let res = match body.open(50.megabytes()).into_string().await {
251
251
+
// Ok(res1) => {
252
252
+
// tracing::info!(res1.value);
253
253
+
// res1.value
254
254
+
// }
255
255
+
// Err(error) => {
256
256
+
// tracing::error!("{error}");
257
257
+
// return Err(ApiError::RuntimeError);
258
258
+
// }
259
259
+
// };
260
260
+
let res = String::from_utf8(body.to_vec()).expect("Invalid UTF-8");
261
261
+
262
262
+
match serde_json::from_str(res.as_str()) {
263
263
+
Ok(res) => {
264
264
+
encoded_body = Some(res);
265
265
+
}
266
266
+
Err(error) => {
267
267
+
tracing::error!("{error}");
268
268
+
return Err(ApiError::RuntimeError);
269
269
+
}
270
270
+
}
271
271
+
}
272
272
+
};
273
273
+
let req_init = format_req_init_with_value(req, url, headers, encoded_body)?;
274
274
+
let res = make_request(req_init).await?;
275
275
+
Ok(parse_proxy_res(res).await?)
276
276
+
}
277
277
+
278
278
+
// Request setup/formatting
279
279
+
// -------------------
280
280
+
281
281
+
const REQ_HEADERS_TO_FORWARD: [&str; 4] = [
282
282
+
"accept-language",
283
283
+
"content-type",
284
284
+
"atproto-accept-labelers",
285
285
+
"x-bsky-topics",
286
286
+
];
287
287
+
288
288
+
#[tracing::instrument(skip_all)]
289
289
+
pub async fn format_url_and_aud(
290
290
+
req: &ProxyRequest,
291
291
+
aud_override: Option<String>,
292
292
+
) -> Result<UrlAndAud> {
293
293
+
let proxy_to = parse_proxy_header(req).await?;
294
294
+
let nsid = parse_req_nsid(req);
295
295
+
let default_proxy = default_service(req, &nsid).await;
296
296
+
let service_url = match proxy_to {
297
297
+
Some(ref proxy_to) => {
298
298
+
tracing::info!(
299
299
+
"@LOG: format_url_and_aud() proxy_to: {:?}",
300
300
+
proxy_to.service_url
301
301
+
);
302
302
+
Some(proxy_to.service_url.clone())
303
303
+
}
304
304
+
None => match default_proxy {
305
305
+
Some(ref default_proxy) => Some(default_proxy.url.clone()),
306
306
+
None => None,
307
307
+
},
308
308
+
};
309
309
+
let aud = match aud_override {
310
310
+
Some(_) => aud_override,
311
311
+
None => match proxy_to {
312
312
+
Some(proxy_to) => Some(proxy_to.did),
313
313
+
None => match default_proxy {
314
314
+
Some(default_proxy) => Some(default_proxy.did),
315
315
+
None => None,
316
316
+
},
317
317
+
},
318
318
+
};
319
319
+
match (service_url, aud) {
320
320
+
(Some(service_url), Some(aud)) => {
321
321
+
let mut url = Url::parse(format!("{0}{1}", service_url, req.path).as_str())?;
322
322
+
if let Some(ref params) = req.query {
323
323
+
url.set_query(Some(params.as_str()));
324
324
+
}
325
325
+
if !req.cfg.service.dev_mode && !is_safe_url(url.clone()) {
326
326
+
bail!(InvalidRequestError::InvalidServiceUrl(url.to_string()));
327
327
+
}
328
328
+
Ok(UrlAndAud {
329
329
+
url,
330
330
+
aud,
331
331
+
lxm: nsid,
332
332
+
})
333
333
+
}
334
334
+
_ => bail!(InvalidRequestError::NoServiceConfigured(req.path.clone())),
335
335
+
}
336
336
+
}
337
337
+
338
338
+
pub async fn format_headers(
339
339
+
req: &ProxyRequest,
340
340
+
aud: String,
341
341
+
lxm: String,
342
342
+
requester: Option<String>,
343
343
+
) -> Result<HeaderMap> {
344
344
+
let mut headers: HeaderMap = match requester {
345
345
+
Some(requester) => context::service_auth_headers(&requester, &aud, &lxm).await?,
346
346
+
None => HeaderMap::new(),
347
347
+
};
348
348
+
// forward select headers to upstream services
349
349
+
for header in REQ_HEADERS_TO_FORWARD {
350
350
+
let val = req.headers.get(header);
351
351
+
if let Some(val) = val {
352
352
+
headers.insert(header, HeaderValue::from_str(val)?);
353
353
+
}
354
354
+
}
355
355
+
Ok(headers)
356
356
+
}
357
357
+
358
358
+
pub fn format_req_init(
359
359
+
req: &ProxyRequest,
360
360
+
url: Url,
361
361
+
headers: HeaderMap,
362
362
+
body: Option<Vec<u8>>,
363
363
+
) -> Result<RequestBuilder> {
364
364
+
match req.method {
365
365
+
Method::GET => {
366
366
+
let client = Client::builder()
367
367
+
.user_agent(APP_USER_AGENT)
368
368
+
.http2_keep_alive_while_idle(true)
369
369
+
.http2_keep_alive_timeout(Duration::from_secs(5))
370
370
+
.default_headers(headers)
371
371
+
.build()?;
372
372
+
Ok(client.get(url))
373
373
+
}
374
374
+
Method::HEAD => {
375
375
+
let client = Client::builder()
376
376
+
.user_agent(APP_USER_AGENT)
377
377
+
.http2_keep_alive_while_idle(true)
378
378
+
.http2_keep_alive_timeout(Duration::from_secs(5))
379
379
+
.default_headers(headers)
380
380
+
.build()?;
381
381
+
Ok(client.head(url))
382
382
+
}
383
383
+
Method::POST => {
384
384
+
let client = Client::builder()
385
385
+
.user_agent(APP_USER_AGENT)
386
386
+
.http2_keep_alive_while_idle(true)
387
387
+
.http2_keep_alive_timeout(Duration::from_secs(5))
388
388
+
.default_headers(headers)
389
389
+
.build()?;
390
390
+
Ok(client.post(url).body(body.unwrap()))
391
391
+
}
392
392
+
_ => bail!(InvalidRequestError::MethodNotFound),
393
393
+
}
394
394
+
}
395
395
+
396
396
+
pub fn format_req_init_with_value(
397
397
+
req: &ProxyRequest,
398
398
+
url: Url,
399
399
+
headers: HeaderMap,
400
400
+
body: Option<JsonValue>,
401
401
+
) -> Result<RequestBuilder> {
402
402
+
match req.method {
403
403
+
Method::GET => {
404
404
+
let client = Client::builder()
405
405
+
.user_agent(APP_USER_AGENT)
406
406
+
.http2_keep_alive_while_idle(true)
407
407
+
.http2_keep_alive_timeout(Duration::from_secs(5))
408
408
+
.default_headers(headers)
409
409
+
.build()?;
410
410
+
Ok(client.get(url))
411
411
+
}
412
412
+
Method::HEAD => {
413
413
+
let client = Client::builder()
414
414
+
.user_agent(APP_USER_AGENT)
415
415
+
.http2_keep_alive_while_idle(true)
416
416
+
.http2_keep_alive_timeout(Duration::from_secs(5))
417
417
+
.default_headers(headers)
418
418
+
.build()?;
419
419
+
Ok(client.head(url))
420
420
+
}
421
421
+
Method::POST => {
422
422
+
let client = Client::builder()
423
423
+
.user_agent(APP_USER_AGENT)
424
424
+
.http2_keep_alive_while_idle(true)
425
425
+
.http2_keep_alive_timeout(Duration::from_secs(5))
426
426
+
.default_headers(headers)
427
427
+
.build()?;
428
428
+
Ok(client.post(url).json(&body.unwrap()))
429
429
+
}
430
430
+
_ => bail!(InvalidRequestError::MethodNotFound),
431
431
+
}
432
432
+
}
433
433
+
434
434
+
pub async fn parse_proxy_header(req: &ProxyRequest) -> Result<Option<ProxyHeader>> {
435
435
+
let headers = &req.headers;
436
436
+
let proxy_to: Option<&String> = headers.get("atproto-proxy");
437
437
+
match proxy_to {
438
438
+
None => Ok(None),
439
439
+
Some(proxy_to) => {
440
440
+
let parts: Vec<&str> = proxy_to.split("#").collect::<Vec<&str>>();
441
441
+
match (parts.get(0), parts.get(1), parts.get(2)) {
442
442
+
(Some(did), Some(service_id), None) => {
443
443
+
let did = did.to_string();
444
444
+
let mut lock = req.id_resolver.write().await;
445
445
+
match lock.did.resolve(did.clone(), None).await? {
446
446
+
None => bail!(InvalidRequestError::CannotResolveProxyDid),
447
447
+
Some(did_doc) => {
448
448
+
match get_service_endpoint(
449
449
+
did_doc,
450
450
+
GetServiceEndpointOpts {
451
451
+
id: format!("#{service_id}"),
452
452
+
r#type: None,
453
453
+
},
454
454
+
) {
455
455
+
None => bail!(InvalidRequestError::CannotResolveServiceUrl),
456
456
+
Some(service_url) => Ok(Some(ProxyHeader { did, service_url })),
457
457
+
}
458
458
+
}
459
459
+
}
460
460
+
}
461
461
+
(_, None, _) => bail!(InvalidRequestError::NoServiceId),
462
462
+
_ => bail!("error parsing atproto-proxy header"),
463
463
+
}
464
464
+
}
465
465
+
}
466
466
+
}
467
467
+
468
468
+
pub fn parse_req_nsid(req: &ProxyRequest) -> String {
469
469
+
let nsid = req.path.as_str().replace("/xrpc/", "");
470
470
+
match nsid.ends_with("/") {
471
471
+
false => nsid,
472
472
+
true => nsid
473
473
+
.trim_end_matches(|c| c == nsid.chars().last().unwrap())
474
474
+
.to_string(),
475
475
+
}
476
476
+
}
477
477
+
478
478
+
// Sending request
479
479
+
// -------------------
480
480
+
#[tracing::instrument(skip_all)]
481
481
+
pub async fn make_request(req_init: RequestBuilder) -> Result<Response> {
482
482
+
let res = req_init.send().await;
483
483
+
match res {
484
484
+
Err(e) => {
485
485
+
tracing::error!("@LOG WARN: pipethrough network error {}", e.to_string());
486
486
+
bail!(InvalidRequestError::XRPCError(XRPCError::UpstreamFailure))
487
487
+
}
488
488
+
Ok(res) => match res.error_for_status_ref() {
489
489
+
Ok(_) => Ok(res),
490
490
+
Err(_) => {
491
491
+
let status = res.status().to_string();
492
492
+
let headers = res.headers().clone();
493
493
+
let error_body = res.json::<JsonValue>().await?;
494
494
+
bail!(InvalidRequestError::XRPCError(XRPCError::FailedResponse {
495
495
+
status,
496
496
+
headers,
497
497
+
error: match error_body["error"].as_str() {
498
498
+
None => None,
499
499
+
Some(error_body_error) => Some(error_body_error.to_string()),
500
500
+
},
501
501
+
message: match error_body["message"].as_str() {
502
502
+
None => None,
503
503
+
Some(error_body_message) => Some(error_body_message.to_string()),
504
504
+
}
505
505
+
}))
506
506
+
}
507
507
+
},
508
508
+
}
509
509
+
}
510
510
+
511
511
+
// Response parsing/forwarding
512
512
+
// -------------------
513
513
+
514
514
+
const RES_HEADERS_TO_FORWARD: [&str; 4] = [
515
515
+
"content-type",
516
516
+
"content-language",
517
517
+
"atproto-repo-rev",
518
518
+
"atproto-content-labelers",
519
519
+
];
520
520
+
521
521
+
pub async fn parse_proxy_res(res: Response) -> Result<HandlerPipeThrough> {
522
522
+
let encoding = match res.headers().get(CONTENT_TYPE) {
523
523
+
Some(content_type) => content_type.to_str()?,
524
524
+
None => "application/json",
525
525
+
};
526
526
+
// Release borrow
527
527
+
let encoding = encoding.to_string();
528
528
+
let res_headers = RES_HEADERS_TO_FORWARD.into_iter().fold(
529
529
+
BTreeMap::new(),
530
530
+
|mut acc: BTreeMap<String, String>, cur| {
531
531
+
let _ = match res.headers().get(cur) {
532
532
+
Some(res_header_val) => acc.insert(
533
533
+
cur.to_string(),
534
534
+
res_header_val.clone().to_str().unwrap().to_string(),
535
535
+
),
536
536
+
None => None,
537
537
+
};
538
538
+
acc
539
539
+
},
540
540
+
);
541
541
+
let buffer = read_array_buffer_res(res).await?;
542
542
+
Ok(HandlerPipeThrough {
543
543
+
encoding,
544
544
+
buffer,
545
545
+
headers: Some(res_headers),
546
546
+
})
547
547
+
}
548
548
+
549
549
+
// Utils
550
550
+
// -------------------
551
551
+
552
552
+
pub async fn default_service(req: &ProxyRequest, nsid: &str) -> Option<ServiceConfig> {
553
553
+
let cfg = req.cfg.clone();
554
554
+
match Ids::from_str(nsid) {
555
555
+
Ok(Ids::ToolsOzoneTeamAddMember) => cfg.mod_service,
556
556
+
Ok(Ids::ToolsOzoneTeamDeleteMember) => cfg.mod_service,
557
557
+
Ok(Ids::ToolsOzoneTeamUpdateMember) => cfg.mod_service,
558
558
+
Ok(Ids::ToolsOzoneTeamListMembers) => cfg.mod_service,
559
559
+
Ok(Ids::ToolsOzoneCommunicationCreateTemplate) => cfg.mod_service,
560
560
+
Ok(Ids::ToolsOzoneCommunicationDeleteTemplate) => cfg.mod_service,
561
561
+
Ok(Ids::ToolsOzoneCommunicationUpdateTemplate) => cfg.mod_service,
562
562
+
Ok(Ids::ToolsOzoneCommunicationListTemplates) => cfg.mod_service,
563
563
+
Ok(Ids::ToolsOzoneModerationEmitEvent) => cfg.mod_service,
564
564
+
Ok(Ids::ToolsOzoneModerationGetEvent) => cfg.mod_service,
565
565
+
Ok(Ids::ToolsOzoneModerationGetRecord) => cfg.mod_service,
566
566
+
Ok(Ids::ToolsOzoneModerationGetRepo) => cfg.mod_service,
567
567
+
Ok(Ids::ToolsOzoneModerationQueryEvents) => cfg.mod_service,
568
568
+
Ok(Ids::ToolsOzoneModerationQueryStatuses) => cfg.mod_service,
569
569
+
Ok(Ids::ToolsOzoneModerationSearchRepos) => cfg.mod_service,
570
570
+
Ok(Ids::ComAtprotoModerationCreateReport) => cfg.report_service,
571
571
+
_ => cfg.bsky_app_view,
572
572
+
}
573
573
+
}
574
574
+
575
575
+
pub fn parse_res<T: DeserializeOwned>(_nsid: String, res: HandlerPipeThrough) -> Result<T> {
576
576
+
let buffer = res.buffer;
577
577
+
let record = serde_json::from_slice::<T>(buffer.as_slice())?;
578
578
+
Ok(record)
579
579
+
}
580
580
+
581
581
+
#[tracing::instrument(skip_all)]
582
582
+
pub async fn read_array_buffer_res(res: Response) -> Result<Vec<u8>> {
583
583
+
match res.bytes().await {
584
584
+
Ok(bytes) => Ok(bytes.to_vec()),
585
585
+
Err(err) => {
586
586
+
tracing::error!("@LOG WARN: pipethrough network error {}", err.to_string());
587
587
+
bail!("UpstreamFailure")
588
588
+
}
589
589
+
}
590
590
+
}
591
591
+
592
592
+
pub fn is_safe_url(url: Url) -> bool {
593
593
+
if url.scheme() != "https" {
594
594
+
return false;
595
595
+
}
596
596
+
match url.host_str() {
597
597
+
None => false,
598
598
+
Some(hostname) if hostname == "localhost" => false,
599
599
+
Some(hostname) => {
600
600
+
if std::net::IpAddr::from_str(hostname).is_ok() {
601
601
+
return false;
602
602
+
}
603
603
+
true
604
604
+
}
605
605
+
}
606
606
+
}