// This repository has no description.
1use aws_config::BehaviorVersion;
2use aws_sdk_s3::Client as S3Client;
3use aws_sdk_s3::config::Credentials;
4use chrono::Utc;
5use reqwest::{Client, StatusCode, header};
6use serde_json::{Value, json};
7use sqlx::postgres::PgPoolOptions;
8#[allow(unused_imports)]
9use std::collections::HashMap;
10use std::sync::OnceLock;
11#[allow(unused_imports)]
12use std::time::Duration;
13use tokio::net::TcpListener;
14use tranquil_pds::state::AppState;
15use wiremock::matchers::{method, path};
16use wiremock::{Mock, MockServer, ResponseTemplate};
17
/// Base URL of the shared test server; initialized exactly once by `base_url`.
static SERVER_URL: OnceLock<String> = OnceLock::new();
/// Port the test server bound to; recorded by `spawn_app` and read by `app_port`.
static APP_PORT: OnceLock<u16> = OnceLock::new();
/// Mock AppView / DID-document server. Kept in a static so it is never dropped
/// (presumably dropping a `MockServer` shuts it down — confirm).
static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new();
21
22#[cfg(not(feature = "external-infra"))]
23use testcontainers::core::ContainerPort;
24#[cfg(not(feature = "external-infra"))]
25use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner};
26#[cfg(not(feature = "external-infra"))]
27use testcontainers_modules::postgres::Postgres;
/// Postgres testcontainer handle. It lives in a static, which is never dropped, so the handle
/// outlives every test; `get_db_connection_string` reads its mapped port from here.
#[cfg(not(feature = "external-infra"))]
static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
/// MinIO (S3-compatible) testcontainer handle, held in a static for the same reason.
#[cfg(not(feature = "external-infra"))]
static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new();
32
// NOTE(review): `#[allow(dead_code)]` presumably because not every test crate that includes this
// module uses every constant — confirm.
/// Bearer token sent by `upload_test_blob` and `create_test_post`.
#[allow(dead_code)]
pub const AUTH_TOKEN: &str = "test-token";
/// A second token value; not used by the helpers in this module (presumably for tests that
/// exercise rejected credentials — confirm).
#[allow(dead_code)]
pub const BAD_AUTH_TOKEN: &str = "bad-token";
/// Repository DID that `create_test_post` writes into.
#[allow(dead_code)]
pub const AUTH_DID: &str = "did:plc:fake";
/// Another DID value; not used by the helpers in this module (presumably the "other account" in
/// tests that target a second repo — confirm).
#[allow(dead_code)]
pub const TARGET_DID: &str = "did:plc:target";
41
/// Reports whether the tests should use pre-provisioned infrastructure instead of containers.
///
/// True when `TRANQUIL_PDS_TEST_INFRA_READY` is set, or when both `DATABASE_URL` and
/// `S3_ENDPOINT` are set. A variable only counts if `std::env::var` can read it (i.e. it is
/// present and valid Unicode).
fn has_external_infra() -> bool {
    let present = |name: &str| std::env::var(name).is_ok();
    if present("TRANQUIL_PDS_TEST_INFRA_READY") {
        return true;
    }
    present("DATABASE_URL") && present("S3_ENDPOINT")
}
/// Process-exit hook that removes the containers this harness started.
///
/// Does nothing when external infrastructure is in use. Otherwise it force-removes every
/// container, running or stopped, labelled `tranquil_pds_test=true`. It does this with `podman`
/// (only when `XDG_RUNTIME_DIR` is set) and then with `docker`.
///
/// Every failure is ignored: cleanup is best-effort, and a missing CLI just means there is
/// nothing to clean with it.
///
/// The container handles live in statics that are never dropped, so the containers are normally
/// still *running* here. `docker container prune` would only remove stopped containers and leave
/// these behind, so both CLIs list the labelled containers and `rm -f` them instead.
#[cfg(test)]
#[ctor::dtor]
fn cleanup() {
    // Label applied via `with_label("tranquil_pds_test", "true")` when the containers are started.
    const LABEL_FILTER: &str = "label=tranquil_pds_test=true";

    /// Lists labelled container IDs with `<cli> ps -aq` and force-removes them with `<cli> rm -f`.
    fn force_remove_labelled(cli: &str) {
        let Ok(listing) = std::process::Command::new(cli)
            .args(["ps", "-aq", "--filter", LABEL_FILTER])
            .output()
        else {
            // CLI not installed / not runnable: nothing to clean with it.
            return;
        };
        if !listing.status.success() {
            return;
        }
        let stdout = String::from_utf8_lossy(&listing.stdout);
        let ids: Vec<&str> = stdout.split_whitespace().collect();
        if ids.is_empty() {
            return;
        }
        let _ = std::process::Command::new(cli)
            .arg("rm")
            .arg("-f")
            .args(ids)
            .output();
    }

    if has_external_infra() {
        return;
    }
    if std::env::var("XDG_RUNTIME_DIR").is_ok() {
        force_remove_labelled("podman");
    }
    force_remove_labelled("docker");
}
66}
67
68#[allow(dead_code)]
69pub fn client() -> Client {
70 Client::new()
71}
72
73#[allow(dead_code)]
74pub fn app_port() -> u16 {
75 *APP_PORT.get().expect("APP_PORT not initialized")
76}
77
78pub async fn base_url() -> &'static str {
79 SERVER_URL.get_or_init(|| {
80 let (tx, rx) = std::sync::mpsc::channel();
81 std::thread::spawn(move || {
82 unsafe {
83 std::env::set_var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS", "1");
84 }
85 if std::env::var("DOCKER_HOST").is_err() {
86 if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
87 let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
88 if podman_sock.exists() {
89 unsafe {
90 std::env::set_var(
91 "DOCKER_HOST",
92 format!("unix://{}", podman_sock.display()),
93 );
94 }
95 }
96 }
97 }
98 let rt = tokio::runtime::Runtime::new().unwrap();
99 rt.block_on(async move {
100 if has_external_infra() {
101 let url = setup_with_external_infra().await;
102 tx.send(url).unwrap();
103 } else {
104 let url = setup_with_testcontainers().await;
105 tx.send(url).unwrap();
106 }
107 std::future::pending::<()>().await;
108 });
109 });
110 rx.recv().expect("Failed to start test server")
111 })
112}
113
114async fn setup_with_external_infra() -> String {
115 let database_url =
116 std::env::var("DATABASE_URL").expect("DATABASE_URL must be set when using external infra");
117 let s3_endpoint =
118 std::env::var("S3_ENDPOINT").expect("S3_ENDPOINT must be set when using external infra");
119 unsafe {
120 std::env::set_var(
121 "S3_BUCKET",
122 std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()),
123 );
124 std::env::set_var(
125 "AWS_ACCESS_KEY_ID",
126 std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()),
127 );
128 std::env::set_var(
129 "AWS_SECRET_ACCESS_KEY",
130 std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()),
131 );
132 std::env::set_var(
133 "AWS_REGION",
134 std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()),
135 );
136 std::env::set_var("S3_ENDPOINT", &s3_endpoint);
137 }
138 let mock_server = MockServer::start().await;
139 setup_mock_appview(&mock_server).await;
140 let mock_uri = mock_server.uri();
141 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
142 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
143 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
144 MOCK_APPVIEW.set(mock_server).ok();
145 spawn_app(database_url).await
146}
147
148#[cfg(not(feature = "external-infra"))]
149async fn setup_with_testcontainers() -> String {
150 let s3_container = GenericImage::new("minio/minio", "latest")
151 .with_exposed_port(ContainerPort::Tcp(9000))
152 .with_env_var("MINIO_ROOT_USER", "minioadmin")
153 .with_env_var("MINIO_ROOT_PASSWORD", "minioadmin")
154 .with_cmd(vec!["server".to_string(), "/data".to_string()])
155 .with_label("tranquil_pds_test", "true")
156 .start()
157 .await
158 .expect("Failed to start MinIO");
159 let s3_port = s3_container
160 .get_host_port_ipv4(9000)
161 .await
162 .expect("Failed to get S3 port");
163 let s3_endpoint = format!("http://127.0.0.1:{}", s3_port);
164 unsafe {
165 std::env::set_var("S3_BUCKET", "test-bucket");
166 std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
167 std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");
168 std::env::set_var("AWS_REGION", "us-east-1");
169 std::env::set_var("S3_ENDPOINT", &s3_endpoint);
170 }
171 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
172 .region("us-east-1")
173 .endpoint_url(&s3_endpoint)
174 .credentials_provider(Credentials::new(
175 "minioadmin",
176 "minioadmin",
177 None,
178 None,
179 "test",
180 ))
181 .load()
182 .await;
183 let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config)
184 .force_path_style(true)
185 .build();
186 let s3_client = S3Client::from_conf(s3_config);
187 let _ = s3_client.create_bucket().bucket("test-bucket").send().await;
188 let mock_server = MockServer::start().await;
189 setup_mock_appview(&mock_server).await;
190 let mock_uri = mock_server.uri();
191 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
192 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
193 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
194 MOCK_APPVIEW.set(mock_server).ok();
195 S3_CONTAINER.set(s3_container).ok();
196 let container = Postgres::default()
197 .with_tag("18-alpine")
198 .with_label("tranquil_pds_test", "true")
199 .start()
200 .await
201 .expect("Failed to start Postgres");
202 let connection_string = format!(
203 "postgres://postgres:postgres@127.0.0.1:{}",
204 container
205 .get_host_port_ipv4(5432)
206 .await
207 .expect("Failed to get port")
208 );
209 DB_CONTAINER.set(container).ok();
210 spawn_app(connection_string).await
211}
212
/// Stand-in used when the `external-infra` feature is enabled and container provisioning is
/// compiled out. `base_url` only calls this when `has_external_infra` is false, i.e. when the
/// external-infra environment variables were not detected.
///
/// # Panics
/// Always panics, with a hint about which variables to set.
#[cfg(feature = "external-infra")]
async fn setup_with_testcontainers() -> String {
    const HINT: &str =
        "Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT.";
    panic!("{}", HINT)
}
219
220async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_endpoint: &str) {
221 Mock::given(method("GET"))
222 .and(path("/.well-known/did.json"))
223 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
224 "id": did,
225 "service": [{
226 "id": "#atproto_appview",
227 "type": "AtprotoAppView",
228 "serviceEndpoint": service_endpoint
229 }]
230 })))
231 .mount(mock_server)
232 .await;
233}
234
/// Hook for registering AppView endpoint mocks. It is currently a no-op, so the mock server
/// answers only the routes registered elsewhere (such as the DID document from
/// `setup_mock_did_document`).
async fn setup_mock_appview(_mock_server: &MockServer) {}
236
/// Connects to Postgres at `database_url`, applies `./migrations`, and serves the PDS app from a
/// background Tokio task on an OS-assigned `127.0.0.1` port.
///
/// Side effects:
/// - records the chosen port in `APP_PORT`;
/// - sets the `PDS_HOSTNAME` environment variable to `127.0.0.1:<port>`;
/// - starts the sequencer listener.
///
/// Returns the server's base URL, `http://127.0.0.1:<port>`.
///
/// # Panics
/// Panics if connecting, migrating, or binding the listener fails. The spawned serve task panics
/// if `axum::serve` returns an error.
async fn spawn_app(database_url: String) -> String {
    use tranquil_pds::rate_limit::RateLimiters;
    let pool = PgPoolOptions::new()
        .max_connections(50)
        .connect(&database_url)
        .await
        .expect("Failed to connect to Postgres. Make sure the database is running.");
    sqlx::migrate!("./migrations")
        .run(&pool)
        .await
        .expect("Failed to run migrations");
    // Port 0: let the OS choose a free port.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    APP_PORT.set(addr.port()).ok();
    // NOTE(review): set before `AppState::new`, presumably so the app's configuration picks up the
    // real host:port — confirm.
    unsafe {
        std::env::set_var("PDS_HOSTNAME", addr.to_string());
    }
    // Very high limits, presumably so test traffic never trips rate limiting — confirm.
    let rate_limiters = RateLimiters::new()
        .with_login_limit(10000)
        .with_account_creation_limit(10000)
        .with_password_reset_limit(10000)
        .with_email_update_limit(10000)
        .with_oauth_authorize_limit(10000)
        .with_oauth_token_limit(10000);
    let state = AppState::new(pool).await.with_rate_limiters(rate_limiters);
    tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
    let app = tranquil_pds::app(state);
    tokio::spawn(async move {
        axum::serve(listener, app).await.unwrap();
    });
    format!("http://{}", addr)
}
269
270#[allow(dead_code)]
271pub async fn get_db_connection_string() -> String {
272 base_url().await;
273 if has_external_infra() {
274 std::env::var("DATABASE_URL").expect("DATABASE_URL not set")
275 } else {
276 #[cfg(not(feature = "external-infra"))]
277 {
278 let container = DB_CONTAINER.get().expect("DB container not initialized");
279 let port = container
280 .get_host_port_ipv4(5432)
281 .await
282 .expect("Failed to get port");
283 format!("postgres://postgres:postgres@127.0.0.1:{}/postgres", port)
284 }
285 #[cfg(feature = "external-infra")]
286 {
287 panic!("DATABASE_URL must be set with external-infra feature");
288 }
289 }
290}
291
292#[allow(dead_code)]
293pub async fn verify_new_account(client: &Client, did: &str) -> String {
294 let conn_str = get_db_connection_string().await;
295 let pool = sqlx::postgres::PgPoolOptions::new()
296 .max_connections(2)
297 .connect(&conn_str)
298 .await
299 .expect("Failed to connect to test database");
300 let body_text: String = sqlx::query_scalar!(
301 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
302 did
303 )
304 .fetch_one(&pool)
305 .await
306 .expect("Failed to get verification code");
307
308 let verification_code = body_text
309 .lines()
310 .find(|line| line.contains("verification code:") || line.contains("code is:"))
311 .and_then(|line| {
312 if line.contains("verification code:") {
313 line.split("verification code:")
314 .nth(1)
315 .map(|s| s.trim().to_string())
316 } else {
317 line.split("code is:").nth(1).map(|s| s.trim().to_string())
318 }
319 })
320 .unwrap_or_else(|| {
321 body_text
322 .lines()
323 .find(|line| line.trim().starts_with("MX") && line.contains('-'))
324 .map(|s| s.trim().to_string())
325 .unwrap_or_default()
326 });
327
328 let confirm_payload = json!({
329 "did": did,
330 "verificationCode": verification_code
331 });
332 let confirm_res = client
333 .post(format!(
334 "{}/xrpc/com.atproto.server.confirmSignup",
335 base_url().await
336 ))
337 .json(&confirm_payload)
338 .send()
339 .await
340 .expect("confirmSignup request failed");
341 assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed");
342 let confirm_body: Value = confirm_res
343 .json()
344 .await
345 .expect("Invalid JSON from confirmSignup");
346 confirm_body["accessJwt"]
347 .as_str()
348 .expect("No accessJwt in confirmSignup response")
349 .to_string()
350}
351
352#[allow(dead_code)]
353pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
354 let res = client
355 .post(format!(
356 "{}/xrpc/com.atproto.repo.uploadBlob",
357 base_url().await
358 ))
359 .header(header::CONTENT_TYPE, mime)
360 .bearer_auth(AUTH_TOKEN)
361 .body(data)
362 .send()
363 .await
364 .expect("Failed to send uploadBlob request");
365 assert_eq!(res.status(), StatusCode::OK, "Failed to upload blob");
366 let body: Value = res.json().await.expect("Blob upload response was not JSON");
367 body["blob"].clone()
368}
369
370#[allow(dead_code)]
371pub async fn create_test_post(
372 client: &Client,
373 text: &str,
374 reply_to: Option<Value>,
375) -> (String, String, String) {
376 let collection = "app.bsky.feed.post";
377 let mut record = json!({
378 "$type": collection,
379 "text": text,
380 "createdAt": Utc::now().to_rfc3339()
381 });
382 if let Some(reply_obj) = reply_to {
383 record["reply"] = reply_obj;
384 }
385 let payload = json!({
386 "repo": AUTH_DID,
387 "collection": collection,
388 "record": record
389 });
390 let res = client
391 .post(format!(
392 "{}/xrpc/com.atproto.repo.createRecord",
393 base_url().await
394 ))
395 .bearer_auth(AUTH_TOKEN)
396 .json(&payload)
397 .send()
398 .await
399 .expect("Failed to send createRecord");
400 assert_eq!(res.status(), StatusCode::OK, "Failed to create post record");
401 let body: Value = res
402 .json()
403 .await
404 .expect("createRecord response was not JSON");
405 let uri = body["uri"]
406 .as_str()
407 .expect("Response had no URI")
408 .to_string();
409 let cid = body["cid"]
410 .as_str()
411 .expect("Response had no CID")
412 .to_string();
413 let rkey = uri
414 .split('/')
415 .last()
416 .expect("URI was malformed")
417 .to_string();
418 (uri, cid, rkey)
419}
420
421#[allow(dead_code)]
422pub async fn create_account_and_login(client: &Client) -> (String, String) {
423 create_account_and_login_internal(client, false).await
424}
425
426#[allow(dead_code)]
427pub async fn create_admin_account_and_login(client: &Client) -> (String, String) {
428 create_account_and_login_internal(client, true).await
429}
430
431async fn create_account_and_login_internal(client: &Client, make_admin: bool) -> (String, String) {
432 let mut last_error = String::new();
433 for attempt in 0..3 {
434 if attempt > 0 {
435 tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
436 }
437 let handle = format!("user_{}", uuid::Uuid::new_v4());
438 let payload = json!({
439 "handle": handle,
440 "email": format!("{}@example.com", handle),
441 "password": "Testpass123!"
442 });
443 let res = match client
444 .post(format!(
445 "{}/xrpc/com.atproto.server.createAccount",
446 base_url().await
447 ))
448 .json(&payload)
449 .send()
450 .await
451 {
452 Ok(r) => r,
453 Err(e) => {
454 last_error = format!("Request failed: {}", e);
455 continue;
456 }
457 };
458 if res.status() == StatusCode::OK {
459 let body: Value = res.json().await.expect("Invalid JSON");
460 let did = body["did"].as_str().expect("No did").to_string();
461 let conn_str = get_db_connection_string().await;
462 let pool = sqlx::postgres::PgPoolOptions::new()
463 .max_connections(2)
464 .connect(&conn_str)
465 .await
466 .expect("Failed to connect to test database");
467 if make_admin {
468 sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
469 .execute(&pool)
470 .await
471 .expect("Failed to mark user as admin");
472 }
473 if let Some(access_jwt) = body["accessJwt"].as_str() {
474 return (access_jwt.to_string(), did);
475 }
476 let body_text: String = sqlx::query_scalar!(
477 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
478 &did
479 )
480 .fetch_one(&pool)
481 .await
482 .expect("Failed to get verification from comms_queue");
483 let verification_code = body_text
484 .lines()
485 .find(|line| line.contains("verification code:") || line.contains("code is:"))
486 .and_then(|line| {
487 if line.contains("verification code:") {
488 line.split("verification code:")
489 .nth(1)
490 .map(|s| s.trim().to_string())
491 } else if line.contains("code is:") {
492 line.split("code is:").nth(1).map(|s| s.trim().to_string())
493 } else {
494 None
495 }
496 })
497 .unwrap_or_else(|| {
498 body_text
499 .split_whitespace()
500 .find(|word| {
501 word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3
502 })
503 .unwrap_or(&body_text)
504 .to_string()
505 });
506
507 let confirm_payload = json!({
508 "did": did,
509 "verificationCode": verification_code
510 });
511 let confirm_res = client
512 .post(format!(
513 "{}/xrpc/com.atproto.server.confirmSignup",
514 base_url().await
515 ))
516 .json(&confirm_payload)
517 .send()
518 .await
519 .expect("confirmSignup request failed");
520 if confirm_res.status() == StatusCode::OK {
521 let confirm_body: Value = confirm_res
522 .json()
523 .await
524 .expect("Invalid JSON from confirmSignup");
525 let access_jwt = confirm_body["accessJwt"]
526 .as_str()
527 .expect("No accessJwt in confirmSignup response")
528 .to_string();
529 return (access_jwt, did);
530 }
531 last_error = format!("confirmSignup failed: {:?}", confirm_res.text().await);
532 continue;
533 }
534 last_error = format!("Status {}: {:?}", res.status(), res.text().await);
535 }
536 panic!("Failed to create account after 3 attempts: {}", last_error);
537}