this repo has no description
1use aws_config::BehaviorVersion;
2use aws_sdk_s3::Client as S3Client;
3use aws_sdk_s3::config::Credentials;
4use chrono::Utc;
5use reqwest::{Client, StatusCode, header};
6use serde_json::{Value, json};
7use sqlx::postgres::PgPoolOptions;
8#[allow(unused_imports)]
9use std::collections::HashMap;
10use std::sync::OnceLock;
11#[allow(unused_imports)]
12use std::time::Duration;
13use tokio::net::TcpListener;
14use tranquil_pds::state::AppState;
15use wiremock::matchers::{method, path};
16use wiremock::{Mock, MockServer, ResponseTemplate};
17
/// Base URL (`http://<addr>`) of the in-process test server; filled in once by `base_url()`.
static SERVER_URL: OnceLock<String> = OnceLock::new();
/// Port of the listener the test server is bound to; set in `spawn_app` and read by `app_port()`.
static APP_PORT: OnceLock<u16> = OnceLock::new();
/// Wiremock server created during setup and parked here by the `setup_with_*` functions
/// (presumably so it is never dropped while tests run — confirm).
static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new();
/// Second Postgres pool opened in `spawn_app`, handed to tests through `get_test_db_pool()`.
static TEST_DB_POOL: OnceLock<sqlx::PgPool> = OnceLock::new();
22
23#[cfg(not(feature = "external-infra"))]
24use testcontainers::core::ContainerPort;
25#[cfg(not(feature = "external-infra"))]
26use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner};
27#[cfg(not(feature = "external-infra"))]
28use testcontainers_modules::postgres::Postgres;
/// Handle of the Postgres test container started by `setup_with_testcontainers`;
/// `get_db_connection_string()` reads its mapped port.
#[cfg(not(feature = "external-infra"))]
static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
/// Handle of the MinIO test container started by `setup_with_testcontainers`
/// (kept in a static, presumably so the container outlives the setup call — confirm).
#[cfg(not(feature = "external-infra"))]
static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new();
33
/// Bearer token sent by `upload_test_blob` and `create_test_post`.
#[allow(dead_code)]
pub const AUTH_TOKEN: &str = "test-token";
/// Token for negative authentication tests (presumably rejected by the server — confirm
/// against the tests that use it).
#[allow(dead_code)]
pub const BAD_AUTH_TOKEN: &str = "bad-token";
/// DID used as the `repo` field by `create_test_post`.
#[allow(dead_code)]
pub const AUTH_DID: &str = "did:plc:fake";
/// A second DID for tests that need a target account (not referenced in this file —
/// see callers for its exact role).
#[allow(dead_code)]
pub const TARGET_DID: &str = "did:plc:target";
42
/// Reports whether the tests should use externally provisioned infrastructure.
///
/// Returns `true` when `TRANQUIL_PDS_TEST_INFRA_READY` is set, or when both
/// `DATABASE_URL` and `S3_ENDPOINT` are set. A variable counts as "set" only if
/// `std::env::var` can read it (i.e. it holds valid Unicode).
fn has_external_infra() -> bool {
    let is_set = |name: &str| std::env::var(name).is_ok();

    if is_set("TRANQUIL_PDS_TEST_INFRA_READY") {
        return true;
    }
    is_set("DATABASE_URL") && is_set("S3_ENDPOINT")
}
/// Best-effort removal of labelled test containers, registered through `ctor::dtor`.
///
/// Does nothing when external infrastructure is in use. Otherwise it shells out to
/// `podman` (only when `XDG_RUNTIME_DIR` is set) and then to `docker`, ignoring every
/// failure, including the binary not being installed.
#[cfg(test)]
#[ctor::dtor]
fn cleanup() {
    const LABEL_FILTER: &str = "label=tranquil_pds_test=true";

    if has_external_infra() {
        return;
    }

    // Runs a command and deliberately discards both its output and any spawn error.
    let run_ignoring_errors = |program: &str, args: &[&str]| {
        let _ = std::process::Command::new(program).args(args).output();
    };

    if std::env::var("XDG_RUNTIME_DIR").is_ok() {
        run_ignoring_errors("podman", &["rm", "-f", "--filter", LABEL_FILTER]);
    }
    // NOTE(review): `docker container prune` only removes *stopped* containers, so
    // containers that are still running at this point are likely left behind — confirm
    // whether that is intended.
    run_ignoring_errors(
        "docker",
        &["container", "prune", "-f", "--filter", LABEL_FILTER],
    );
}
68
69#[allow(dead_code)]
70pub fn client() -> Client {
71 Client::new()
72}
73
74#[allow(dead_code)]
75pub fn app_port() -> u16 {
76 *APP_PORT.get().expect("APP_PORT not initialized")
77}
78
79pub async fn base_url() -> &'static str {
80 SERVER_URL.get_or_init(|| {
81 let (tx, rx) = std::sync::mpsc::channel();
82 std::thread::spawn(move || {
83 unsafe {
84 std::env::set_var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS", "1");
85 }
86 if std::env::var("DOCKER_HOST").is_err()
87 && let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR")
88 {
89 let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
90 if podman_sock.exists() {
91 unsafe {
92 std::env::set_var(
93 "DOCKER_HOST",
94 format!("unix://{}", podman_sock.display()),
95 );
96 }
97 }
98 }
99 let rt = tokio::runtime::Runtime::new().unwrap();
100 rt.block_on(async move {
101 if has_external_infra() {
102 let url = setup_with_external_infra().await;
103 tx.send(url).unwrap();
104 } else {
105 let url = setup_with_testcontainers().await;
106 tx.send(url).unwrap();
107 }
108 std::future::pending::<()>().await;
109 });
110 });
111 rx.recv().expect("Failed to start test server")
112 })
113}
114
115async fn setup_with_external_infra() -> String {
116 let database_url =
117 std::env::var("DATABASE_URL").expect("DATABASE_URL must be set when using external infra");
118 let s3_endpoint =
119 std::env::var("S3_ENDPOINT").expect("S3_ENDPOINT must be set when using external infra");
120 unsafe {
121 std::env::set_var(
122 "S3_BUCKET",
123 std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()),
124 );
125 std::env::set_var(
126 "AWS_ACCESS_KEY_ID",
127 std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()),
128 );
129 std::env::set_var(
130 "AWS_SECRET_ACCESS_KEY",
131 std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()),
132 );
133 std::env::set_var(
134 "AWS_REGION",
135 std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()),
136 );
137 std::env::set_var("S3_ENDPOINT", &s3_endpoint);
138 std::env::set_var("MAX_IMPORT_SIZE", "100000000");
139 }
140 let mock_server = MockServer::start().await;
141 setup_mock_appview(&mock_server).await;
142 let mock_uri = mock_server.uri();
143 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
144 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
145 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
146 MOCK_APPVIEW.set(mock_server).ok();
147 spawn_app(database_url).await
148}
149
150#[cfg(not(feature = "external-infra"))]
151async fn setup_with_testcontainers() -> String {
152 let s3_container = GenericImage::new("minio/minio", "latest")
153 .with_exposed_port(ContainerPort::Tcp(9000))
154 .with_env_var("MINIO_ROOT_USER", "minioadmin")
155 .with_env_var("MINIO_ROOT_PASSWORD", "minioadmin")
156 .with_cmd(vec!["server".to_string(), "/data".to_string()])
157 .with_label("tranquil_pds_test", "true")
158 .start()
159 .await
160 .expect("Failed to start MinIO");
161 let s3_port = s3_container
162 .get_host_port_ipv4(9000)
163 .await
164 .expect("Failed to get S3 port");
165 let s3_endpoint = format!("http://127.0.0.1:{}", s3_port);
166 unsafe {
167 std::env::set_var("S3_BUCKET", "test-bucket");
168 std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
169 std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");
170 std::env::set_var("AWS_REGION", "us-east-1");
171 std::env::set_var("S3_ENDPOINT", &s3_endpoint);
172 std::env::set_var("MAX_IMPORT_SIZE", "100000000");
173 }
174 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
175 .region("us-east-1")
176 .endpoint_url(&s3_endpoint)
177 .credentials_provider(Credentials::new(
178 "minioadmin",
179 "minioadmin",
180 None,
181 None,
182 "test",
183 ))
184 .load()
185 .await;
186 let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config)
187 .force_path_style(true)
188 .build();
189 let s3_client = S3Client::from_conf(s3_config);
190 let _ = s3_client.create_bucket().bucket("test-bucket").send().await;
191 let mock_server = MockServer::start().await;
192 setup_mock_appview(&mock_server).await;
193 let mock_uri = mock_server.uri();
194 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
195 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
196 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
197 MOCK_APPVIEW.set(mock_server).ok();
198 S3_CONTAINER.set(s3_container).ok();
199 let container = Postgres::default()
200 .with_tag("18-alpine")
201 .with_label("tranquil_pds_test", "true")
202 .start()
203 .await
204 .expect("Failed to start Postgres");
205 let connection_string = format!(
206 "postgres://postgres:postgres@127.0.0.1:{}",
207 container
208 .get_host_port_ipv4(5432)
209 .await
210 .expect("Failed to get port")
211 );
212 DB_CONTAINER.set(container).ok();
213 spawn_app(connection_string).await
214}
215
/// Stand-in used when the `external-infra` feature is enabled: container-based setup is
/// compiled out, so reaching this function is a configuration error.
///
/// # Panics
///
/// Always panics.
#[cfg(feature = "external-infra")]
async fn setup_with_testcontainers() -> String {
    const MESSAGE: &str =
        "Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT.";
    panic!("{}", MESSAGE);
}
222
223async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_endpoint: &str) {
224 Mock::given(method("GET"))
225 .and(path("/.well-known/did.json"))
226 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
227 "id": did,
228 "service": [{
229 "id": "#atproto_appview",
230 "type": "AtprotoAppView",
231 "serviceEndpoint": service_endpoint
232 }]
233 })))
234 .mount(mock_server)
235 .await;
236}
237
238async fn setup_mock_appview(_mock_server: &MockServer) {}
239
240async fn spawn_app(database_url: String) -> String {
241 use tranquil_pds::rate_limit::RateLimiters;
242 let pool = PgPoolOptions::new()
243 .max_connections(3)
244 .acquire_timeout(std::time::Duration::from_secs(30))
245 .connect(&database_url)
246 .await
247 .expect("Failed to connect to Postgres. Make sure the database is running.");
248 sqlx::migrate!("./migrations")
249 .run(&pool)
250 .await
251 .expect("Failed to run migrations");
252 let test_pool = PgPoolOptions::new()
253 .max_connections(5)
254 .acquire_timeout(std::time::Duration::from_secs(30))
255 .connect(&database_url)
256 .await
257 .expect("Failed to create test pool");
258 TEST_DB_POOL.set(test_pool).ok();
259 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
260 let addr = listener.local_addr().unwrap();
261 APP_PORT.set(addr.port()).ok();
262 unsafe {
263 std::env::set_var("PDS_HOSTNAME", addr.to_string());
264 }
265 let rate_limiters = RateLimiters::new()
266 .with_login_limit(10000)
267 .with_account_creation_limit(10000)
268 .with_password_reset_limit(10000)
269 .with_email_update_limit(10000)
270 .with_oauth_authorize_limit(10000)
271 .with_oauth_token_limit(10000);
272 let state = AppState::from_db(pool)
273 .await
274 .with_rate_limiters(rate_limiters);
275 tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
276 let app = tranquil_pds::app(state);
277 tokio::spawn(async move {
278 axum::serve(listener, app).await.unwrap();
279 });
280 format!("http://{}", addr)
281}
282
283#[allow(dead_code)]
284pub async fn get_db_connection_string() -> String {
285 base_url().await;
286 if has_external_infra() {
287 std::env::var("DATABASE_URL").expect("DATABASE_URL not set")
288 } else {
289 #[cfg(not(feature = "external-infra"))]
290 {
291 let container = DB_CONTAINER.get().expect("DB container not initialized");
292 let port = container
293 .get_host_port_ipv4(5432)
294 .await
295 .expect("Failed to get port");
296 format!("postgres://postgres:postgres@127.0.0.1:{}/postgres", port)
297 }
298 #[cfg(feature = "external-infra")]
299 {
300 panic!("DATABASE_URL must be set with external-infra feature");
301 }
302 }
303}
304
305#[allow(dead_code)]
306pub async fn get_test_db_pool() -> &'static sqlx::PgPool {
307 base_url().await;
308 TEST_DB_POOL.get().expect("TEST_DB_POOL not initialized")
309}
310
311#[allow(dead_code)]
312pub async fn verify_new_account(client: &Client, did: &str) -> String {
313 let pool = get_test_db_pool().await;
314 let body_text: String = sqlx::query_scalar!(
315 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
316 did
317 )
318 .fetch_one(pool)
319 .await
320 .expect("Failed to get verification code");
321
322 let lines: Vec<&str> = body_text.lines().collect();
323 let verification_code = lines
324 .iter()
325 .enumerate()
326 .find(|(_, line)| line.contains("verification code is:") || line.contains("code is:"))
327 .and_then(|(i, _)| lines.get(i + 1).map(|s| s.trim().to_string()))
328 .or_else(|| {
329 body_text
330 .split_whitespace()
331 .find(|word| word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3)
332 .map(|s| s.to_string())
333 })
334 .unwrap_or_else(|| body_text.clone());
335
336 let confirm_payload = json!({
337 "did": did,
338 "verificationCode": verification_code
339 });
340 let confirm_res = client
341 .post(format!(
342 "{}/xrpc/com.atproto.server.confirmSignup",
343 base_url().await
344 ))
345 .json(&confirm_payload)
346 .send()
347 .await
348 .expect("confirmSignup request failed");
349 assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed");
350 let confirm_body: Value = confirm_res
351 .json()
352 .await
353 .expect("Invalid JSON from confirmSignup");
354 confirm_body["accessJwt"]
355 .as_str()
356 .expect("No accessJwt in confirmSignup response")
357 .to_string()
358}
359
360#[allow(dead_code)]
361pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
362 let res = client
363 .post(format!(
364 "{}/xrpc/com.atproto.repo.uploadBlob",
365 base_url().await
366 ))
367 .header(header::CONTENT_TYPE, mime)
368 .bearer_auth(AUTH_TOKEN)
369 .body(data)
370 .send()
371 .await
372 .expect("Failed to send uploadBlob request");
373 assert_eq!(res.status(), StatusCode::OK, "Failed to upload blob");
374 let body: Value = res.json().await.expect("Blob upload response was not JSON");
375 body["blob"].clone()
376}
377
378#[allow(dead_code)]
379pub async fn create_test_post(
380 client: &Client,
381 text: &str,
382 reply_to: Option<Value>,
383) -> (String, String, String) {
384 let collection = "app.bsky.feed.post";
385 let mut record = json!({
386 "$type": collection,
387 "text": text,
388 "createdAt": Utc::now().to_rfc3339()
389 });
390 if let Some(reply_obj) = reply_to {
391 record["reply"] = reply_obj;
392 }
393 let payload = json!({
394 "repo": AUTH_DID,
395 "collection": collection,
396 "record": record
397 });
398 let res = client
399 .post(format!(
400 "{}/xrpc/com.atproto.repo.createRecord",
401 base_url().await
402 ))
403 .bearer_auth(AUTH_TOKEN)
404 .json(&payload)
405 .send()
406 .await
407 .expect("Failed to send createRecord");
408 assert_eq!(res.status(), StatusCode::OK, "Failed to create post record");
409 let body: Value = res
410 .json()
411 .await
412 .expect("createRecord response was not JSON");
413 let uri = body["uri"]
414 .as_str()
415 .expect("Response had no URI")
416 .to_string();
417 let cid = body["cid"]
418 .as_str()
419 .expect("Response had no CID")
420 .to_string();
421 let rkey = uri
422 .split('/')
423 .next_back()
424 .expect("URI was malformed")
425 .to_string();
426 (uri, cid, rkey)
427}
428
429#[allow(dead_code)]
430pub async fn create_account_and_login(client: &Client) -> (String, String) {
431 create_account_and_login_internal(client, false).await
432}
433
434#[allow(dead_code)]
435pub async fn create_admin_account_and_login(client: &Client) -> (String, String) {
436 create_account_and_login_internal(client, true).await
437}
438
439async fn create_account_and_login_internal(client: &Client, make_admin: bool) -> (String, String) {
440 let mut last_error = String::new();
441 for attempt in 0..3 {
442 if attempt > 0 {
443 tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
444 }
445 let handle = format!("u{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
446 let payload = json!({
447 "handle": handle,
448 "email": format!("{}@example.com", handle),
449 "password": "Testpass123!"
450 });
451 let res = match client
452 .post(format!(
453 "{}/xrpc/com.atproto.server.createAccount",
454 base_url().await
455 ))
456 .json(&payload)
457 .send()
458 .await
459 {
460 Ok(r) => r,
461 Err(e) => {
462 last_error = format!("Request failed: {}", e);
463 continue;
464 }
465 };
466 if res.status() == StatusCode::OK {
467 let body: Value = res.json().await.expect("Invalid JSON");
468 let did = body["did"].as_str().expect("No did").to_string();
469 let pool = get_test_db_pool().await;
470 if make_admin {
471 sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
472 .execute(pool)
473 .await
474 .expect("Failed to mark user as admin");
475 }
476 let verification_required = body["verificationRequired"].as_bool().unwrap_or(true);
477 if let Some(access_jwt) = body["accessJwt"].as_str()
478 && !verification_required
479 {
480 return (access_jwt.to_string(), did);
481 }
482 let body_text: String = sqlx::query_scalar!(
483 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
484 &did
485 )
486 .fetch_one(pool)
487 .await
488 .expect("Failed to get verification from comms_queue");
489 let lines: Vec<&str> = body_text.lines().collect();
490 let verification_code = lines
491 .iter()
492 .enumerate()
493 .find(|(_, line): &(usize, &&str)| {
494 line.contains("verification code is:") || line.contains("code is:")
495 })
496 .and_then(|(i, _)| lines.get(i + 1).map(|s: &&str| s.trim().to_string()))
497 .or_else(|| {
498 body_text
499 .split_whitespace()
500 .find(|word: &&str| {
501 word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3
502 })
503 .map(|s: &str| s.to_string())
504 })
505 .unwrap_or_else(|| body_text.clone());
506
507 let confirm_payload = json!({
508 "did": did,
509 "verificationCode": verification_code
510 });
511 let confirm_res = client
512 .post(format!(
513 "{}/xrpc/com.atproto.server.confirmSignup",
514 base_url().await
515 ))
516 .json(&confirm_payload)
517 .send()
518 .await
519 .expect("confirmSignup request failed");
520 if confirm_res.status() == StatusCode::OK {
521 let confirm_body: Value = confirm_res
522 .json()
523 .await
524 .expect("Invalid JSON from confirmSignup");
525 let access_jwt = confirm_body["accessJwt"]
526 .as_str()
527 .expect("No accessJwt in confirmSignup response")
528 .to_string();
529 return (access_jwt, did);
530 }
531 last_error = format!("confirmSignup failed: {:?}", confirm_res.text().await);
532 continue;
533 }
534 last_error = format!("Status {}: {:?}", res.status(), res.text().await);
535 }
536 panic!("Failed to create account after 3 attempts: {}", last_error);
537}