this repo has no description
1use aws_config::BehaviorVersion;
2use aws_sdk_s3::Client as S3Client;
3use aws_sdk_s3::config::Credentials;
4use chrono::Utc;
5use reqwest::{Client, StatusCode, header};
6use serde_json::{Value, json};
7use sqlx::postgres::PgPoolOptions;
8#[allow(unused_imports)]
9use std::collections::HashMap;
10use std::sync::OnceLock;
11#[allow(unused_imports)]
12use std::time::Duration;
13use tokio::net::TcpListener;
14use tranquil_pds::state::AppState;
15use wiremock::matchers::{method, path};
16use wiremock::{Mock, MockServer, ResponseTemplate};
17
18static SERVER_URL: OnceLock<String> = OnceLock::new();
19static APP_PORT: OnceLock<u16> = OnceLock::new();
20static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new();
21static TEST_DB_POOL: OnceLock<sqlx::PgPool> = OnceLock::new();
22
23#[cfg(not(feature = "external-infra"))]
24use testcontainers::core::ContainerPort;
25#[cfg(not(feature = "external-infra"))]
26use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner};
27#[cfg(not(feature = "external-infra"))]
28use testcontainers_modules::postgres::Postgres;
29#[cfg(not(feature = "external-infra"))]
30static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
31#[cfg(not(feature = "external-infra"))]
32static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new();
33
// About the `allow(dead_code)` attributes below: presumably this module is shared by
// several test crates, and not every crate uses every item.

/// Bearer token sent by helpers such as `upload_test_blob` and `create_test_post`.
#[allow(dead_code)]
pub const AUTH_TOKEN: &str = "test-token";
/// A second, distinct token value; presumably for tests that expect authentication to be
/// rejected (TODO confirm against callers).
#[allow(dead_code)]
pub const BAD_AUTH_TOKEN: &str = "bad-token";
/// DID that `create_test_post` uses as the `repo` of new records.
#[allow(dead_code)]
pub const AUTH_DID: &str = "did:plc:fake";
/// A second DID that is not referenced in this module; presumably the "other party" in
/// tests.
#[allow(dead_code)]
pub const TARGET_DID: &str = "did:plc:target";
42
/// Reports whether tests should use externally provisioned infrastructure instead of
/// starting their own containers.
///
/// Returns true when `TRANQUIL_PDS_TEST_INFRA_READY` is set, or when both `DATABASE_URL`
/// and `S3_ENDPOINT` are set. A variable counts as set only if `std::env::var` returns
/// `Ok`, so values that are not valid Unicode count as unset.
fn has_external_infra() -> bool {
    let is_set = |name: &str| std::env::var(name).is_ok();
    if is_set("TRANQUIL_PDS_TEST_INFRA_READY") {
        return true;
    }
    is_set("DATABASE_URL") && is_set("S3_ENDPOINT")
}
/// Best-effort removal of containers labelled `tranquil_pds_test=true`, registered as a
/// process destructor (`ctor::dtor`) so it runs when the test binary exits.
///
/// Does nothing when external infrastructure is in use. When `XDG_RUNTIME_DIR` is set, it
/// first runs `podman rm -f` with the label filter. It then always runs
/// `docker container prune -f` with the same filter.
///
/// NOTE(review): `docker container prune` only removes *stopped* containers. Failures to
/// spawn either command, and their exit statuses, are ignored.
#[cfg(test)]
#[ctor::dtor]
fn cleanup() {
    const TEST_LABEL_FILTER: &str = "label=tranquil_pds_test=true";

    /// Runs `program` with `args` to completion, discarding its output and any error.
    fn run_ignoring_errors(program: &str, args: &[&str]) {
        let _ = std::process::Command::new(program).args(args).output();
    }

    if has_external_infra() {
        return;
    }
    if std::env::var("XDG_RUNTIME_DIR").is_ok() {
        run_ignoring_errors("podman", &["rm", "-f", "--filter", TEST_LABEL_FILTER]);
    }
    run_ignoring_errors(
        "docker",
        &["container", "prune", "-f", "--filter", TEST_LABEL_FILTER],
    );
}
68
69#[allow(dead_code)]
70pub fn client() -> Client {
71 Client::new()
72}
73
74#[allow(dead_code)]
75pub fn app_port() -> u16 {
76 *APP_PORT.get().expect("APP_PORT not initialized")
77}
78
79pub async fn base_url() -> &'static str {
80 SERVER_URL.get_or_init(|| {
81 let (tx, rx) = std::sync::mpsc::channel();
82 std::thread::spawn(move || {
83 unsafe {
84 std::env::set_var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS", "1");
85 }
86 if std::env::var("DOCKER_HOST").is_err()
87 && let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR")
88 {
89 let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
90 if podman_sock.exists() {
91 unsafe {
92 std::env::set_var(
93 "DOCKER_HOST",
94 format!("unix://{}", podman_sock.display()),
95 );
96 }
97 }
98 }
99 let rt = tokio::runtime::Runtime::new().unwrap();
100 rt.block_on(async move {
101 if has_external_infra() {
102 let url = setup_with_external_infra().await;
103 tx.send(url).unwrap();
104 } else {
105 let url = setup_with_testcontainers().await;
106 tx.send(url).unwrap();
107 }
108 std::future::pending::<()>().await;
109 });
110 });
111 rx.recv().expect("Failed to start test server")
112 })
113}
114
115async fn setup_with_external_infra() -> String {
116 let database_url =
117 std::env::var("DATABASE_URL").expect("DATABASE_URL must be set when using external infra");
118 let s3_endpoint =
119 std::env::var("S3_ENDPOINT").expect("S3_ENDPOINT must be set when using external infra");
120 unsafe {
121 std::env::set_var(
122 "S3_BUCKET",
123 std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()),
124 );
125 std::env::set_var(
126 "AWS_ACCESS_KEY_ID",
127 std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()),
128 );
129 std::env::set_var(
130 "AWS_SECRET_ACCESS_KEY",
131 std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()),
132 );
133 std::env::set_var(
134 "AWS_REGION",
135 std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()),
136 );
137 std::env::set_var("S3_ENDPOINT", &s3_endpoint);
138 std::env::set_var("MAX_IMPORT_SIZE", "100000000");
139 std::env::set_var("SKIP_IMPORT_VERIFICATION", "true");
140 }
141 let mock_server = MockServer::start().await;
142 setup_mock_appview(&mock_server).await;
143 let mock_uri = mock_server.uri();
144 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
145 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
146 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
147 MOCK_APPVIEW.set(mock_server).ok();
148 spawn_app(database_url).await
149}
150
151#[cfg(not(feature = "external-infra"))]
152async fn setup_with_testcontainers() -> String {
153 let s3_container = GenericImage::new("minio/minio", "latest")
154 .with_exposed_port(ContainerPort::Tcp(9000))
155 .with_env_var("MINIO_ROOT_USER", "minioadmin")
156 .with_env_var("MINIO_ROOT_PASSWORD", "minioadmin")
157 .with_cmd(vec!["server".to_string(), "/data".to_string()])
158 .with_label("tranquil_pds_test", "true")
159 .start()
160 .await
161 .expect("Failed to start MinIO");
162 let s3_port = s3_container
163 .get_host_port_ipv4(9000)
164 .await
165 .expect("Failed to get S3 port");
166 let s3_endpoint = format!("http://127.0.0.1:{}", s3_port);
167 unsafe {
168 std::env::set_var("S3_BUCKET", "test-bucket");
169 std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
170 std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");
171 std::env::set_var("AWS_REGION", "us-east-1");
172 std::env::set_var("S3_ENDPOINT", &s3_endpoint);
173 std::env::set_var("MAX_IMPORT_SIZE", "100000000");
174 std::env::set_var("SKIP_IMPORT_VERIFICATION", "true");
175 }
176 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
177 .region("us-east-1")
178 .endpoint_url(&s3_endpoint)
179 .credentials_provider(Credentials::new(
180 "minioadmin",
181 "minioadmin",
182 None,
183 None,
184 "test",
185 ))
186 .load()
187 .await;
188 let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config)
189 .force_path_style(true)
190 .build();
191 let s3_client = S3Client::from_conf(s3_config);
192 let _ = s3_client.create_bucket().bucket("test-bucket").send().await;
193 let mock_server = MockServer::start().await;
194 setup_mock_appview(&mock_server).await;
195 let mock_uri = mock_server.uri();
196 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
197 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
198 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
199 MOCK_APPVIEW.set(mock_server).ok();
200 S3_CONTAINER.set(s3_container).ok();
201 let container = Postgres::default()
202 .with_tag("18-alpine")
203 .with_label("tranquil_pds_test", "true")
204 .start()
205 .await
206 .expect("Failed to start Postgres");
207 let connection_string = format!(
208 "postgres://postgres:postgres@127.0.0.1:{}",
209 container
210 .get_host_port_ipv4(5432)
211 .await
212 .expect("Failed to get port")
213 );
214 DB_CONTAINER.set(container).ok();
215 spawn_app(connection_string).await
216}
217
/// Stand-in compiled when the `external-infra` feature removes testcontainers support.
///
/// # Panics
/// Always panics. With this feature enabled, tests must be pointed at existing
/// infrastructure via `DATABASE_URL` and `S3_ENDPOINT`.
#[cfg(feature = "external-infra")]
async fn setup_with_testcontainers() -> String {
    const MESSAGE: &str =
        "Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT.";
    panic!("{}", MESSAGE)
}
224
225async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_endpoint: &str) {
226 Mock::given(method("GET"))
227 .and(path("/.well-known/did.json"))
228 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
229 "id": did,
230 "service": [{
231 "id": "#atproto_appview",
232 "type": "AtprotoAppView",
233 "serviceEndpoint": service_endpoint
234 }]
235 })))
236 .mount(mock_server)
237 .await;
238}
239
240async fn setup_mock_appview(_mock_server: &MockServer) {}
241
242async fn spawn_app(database_url: String) -> String {
243 use tranquil_pds::rate_limit::RateLimiters;
244 let pool = PgPoolOptions::new()
245 .max_connections(3)
246 .acquire_timeout(std::time::Duration::from_secs(30))
247 .connect(&database_url)
248 .await
249 .expect("Failed to connect to Postgres. Make sure the database is running.");
250 sqlx::migrate!("./migrations")
251 .run(&pool)
252 .await
253 .expect("Failed to run migrations");
254 let test_pool = PgPoolOptions::new()
255 .max_connections(5)
256 .acquire_timeout(std::time::Duration::from_secs(30))
257 .connect(&database_url)
258 .await
259 .expect("Failed to create test pool");
260 TEST_DB_POOL.set(test_pool).ok();
261 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
262 let addr = listener.local_addr().unwrap();
263 APP_PORT.set(addr.port()).ok();
264 unsafe {
265 std::env::set_var("PDS_HOSTNAME", addr.to_string());
266 }
267 let rate_limiters = RateLimiters::new()
268 .with_login_limit(10000)
269 .with_account_creation_limit(10000)
270 .with_password_reset_limit(10000)
271 .with_email_update_limit(10000)
272 .with_oauth_authorize_limit(10000)
273 .with_oauth_token_limit(10000);
274 let state = AppState::from_db(pool)
275 .await
276 .with_rate_limiters(rate_limiters);
277 tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
278 let app = tranquil_pds::app(state);
279 tokio::spawn(async move {
280 axum::serve(listener, app).await.unwrap();
281 });
282 format!("http://{}", addr)
283}
284
285#[allow(dead_code)]
286pub async fn get_db_connection_string() -> String {
287 base_url().await;
288 if has_external_infra() {
289 std::env::var("DATABASE_URL").expect("DATABASE_URL not set")
290 } else {
291 #[cfg(not(feature = "external-infra"))]
292 {
293 let container = DB_CONTAINER.get().expect("DB container not initialized");
294 let port = container
295 .get_host_port_ipv4(5432)
296 .await
297 .expect("Failed to get port");
298 format!("postgres://postgres:postgres@127.0.0.1:{}/postgres", port)
299 }
300 #[cfg(feature = "external-infra")]
301 {
302 panic!("DATABASE_URL must be set with external-infra feature");
303 }
304 }
305}
306
307#[allow(dead_code)]
308pub async fn get_test_db_pool() -> &'static sqlx::PgPool {
309 base_url().await;
310 TEST_DB_POOL.get().expect("TEST_DB_POOL not initialized")
311}
312
313#[allow(dead_code)]
314pub async fn verify_new_account(client: &Client, did: &str) -> String {
315 let pool = get_test_db_pool().await;
316 let body_text: String = sqlx::query_scalar!(
317 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
318 did
319 )
320 .fetch_one(pool)
321 .await
322 .expect("Failed to get verification code");
323
324 let lines: Vec<&str> = body_text.lines().collect();
325 let verification_code = lines
326 .iter()
327 .enumerate()
328 .find(|(_, line)| line.contains("verification code is:") || line.contains("code is:"))
329 .and_then(|(i, _)| lines.get(i + 1).map(|s| s.trim().to_string()))
330 .or_else(|| {
331 body_text
332 .split_whitespace()
333 .find(|word| word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3)
334 .map(|s| s.to_string())
335 })
336 .unwrap_or_else(|| body_text.clone());
337
338 let confirm_payload = json!({
339 "did": did,
340 "verificationCode": verification_code
341 });
342 let confirm_res = client
343 .post(format!(
344 "{}/xrpc/com.atproto.server.confirmSignup",
345 base_url().await
346 ))
347 .json(&confirm_payload)
348 .send()
349 .await
350 .expect("confirmSignup request failed");
351 assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed");
352 let confirm_body: Value = confirm_res
353 .json()
354 .await
355 .expect("Invalid JSON from confirmSignup");
356 confirm_body["accessJwt"]
357 .as_str()
358 .expect("No accessJwt in confirmSignup response")
359 .to_string()
360}
361
362#[allow(dead_code)]
363pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
364 let res = client
365 .post(format!(
366 "{}/xrpc/com.atproto.repo.uploadBlob",
367 base_url().await
368 ))
369 .header(header::CONTENT_TYPE, mime)
370 .bearer_auth(AUTH_TOKEN)
371 .body(data)
372 .send()
373 .await
374 .expect("Failed to send uploadBlob request");
375 assert_eq!(res.status(), StatusCode::OK, "Failed to upload blob");
376 let body: Value = res.json().await.expect("Blob upload response was not JSON");
377 body["blob"].clone()
378}
379
380#[allow(dead_code)]
381pub async fn create_test_post(
382 client: &Client,
383 text: &str,
384 reply_to: Option<Value>,
385) -> (String, String, String) {
386 let collection = "app.bsky.feed.post";
387 let mut record = json!({
388 "$type": collection,
389 "text": text,
390 "createdAt": Utc::now().to_rfc3339()
391 });
392 if let Some(reply_obj) = reply_to {
393 record["reply"] = reply_obj;
394 }
395 let payload = json!({
396 "repo": AUTH_DID,
397 "collection": collection,
398 "record": record
399 });
400 let res = client
401 .post(format!(
402 "{}/xrpc/com.atproto.repo.createRecord",
403 base_url().await
404 ))
405 .bearer_auth(AUTH_TOKEN)
406 .json(&payload)
407 .send()
408 .await
409 .expect("Failed to send createRecord");
410 assert_eq!(res.status(), StatusCode::OK, "Failed to create post record");
411 let body: Value = res
412 .json()
413 .await
414 .expect("createRecord response was not JSON");
415 let uri = body["uri"]
416 .as_str()
417 .expect("Response had no URI")
418 .to_string();
419 let cid = body["cid"]
420 .as_str()
421 .expect("Response had no CID")
422 .to_string();
423 let rkey = uri
424 .split('/')
425 .next_back()
426 .expect("URI was malformed")
427 .to_string();
428 (uri, cid, rkey)
429}
430
431#[allow(dead_code)]
432pub async fn create_account_and_login(client: &Client) -> (String, String) {
433 create_account_and_login_internal(client, false).await
434}
435
436#[allow(dead_code)]
437pub async fn create_admin_account_and_login(client: &Client) -> (String, String) {
438 create_account_and_login_internal(client, true).await
439}
440
441async fn create_account_and_login_internal(client: &Client, make_admin: bool) -> (String, String) {
442 let mut last_error = String::new();
443 for attempt in 0..3 {
444 if attempt > 0 {
445 tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
446 }
447 let handle = format!("u{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
448 let payload = json!({
449 "handle": handle,
450 "email": format!("{}@example.com", handle),
451 "password": "Testpass123!"
452 });
453 let res = match client
454 .post(format!(
455 "{}/xrpc/com.atproto.server.createAccount",
456 base_url().await
457 ))
458 .json(&payload)
459 .send()
460 .await
461 {
462 Ok(r) => r,
463 Err(e) => {
464 last_error = format!("Request failed: {}", e);
465 continue;
466 }
467 };
468 if res.status() == StatusCode::OK {
469 let body: Value = res.json().await.expect("Invalid JSON");
470 let did = body["did"].as_str().expect("No did").to_string();
471 let pool = get_test_db_pool().await;
472 if make_admin {
473 sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
474 .execute(pool)
475 .await
476 .expect("Failed to mark user as admin");
477 }
478 let verification_required = body["verificationRequired"].as_bool().unwrap_or(true);
479 if let Some(access_jwt) = body["accessJwt"].as_str()
480 && !verification_required
481 {
482 return (access_jwt.to_string(), did);
483 }
484 let body_text: String = sqlx::query_scalar!(
485 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
486 &did
487 )
488 .fetch_one(pool)
489 .await
490 .expect("Failed to get verification from comms_queue");
491 let lines: Vec<&str> = body_text.lines().collect();
492 let verification_code = lines
493 .iter()
494 .enumerate()
495 .find(|(_, line): &(usize, &&str)| {
496 line.contains("verification code is:") || line.contains("code is:")
497 })
498 .and_then(|(i, _)| lines.get(i + 1).map(|s: &&str| s.trim().to_string()))
499 .or_else(|| {
500 body_text
501 .split_whitespace()
502 .find(|word: &&str| {
503 word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3
504 })
505 .map(|s: &str| s.to_string())
506 })
507 .unwrap_or_else(|| body_text.clone());
508
509 let confirm_payload = json!({
510 "did": did,
511 "verificationCode": verification_code
512 });
513 let confirm_res = client
514 .post(format!(
515 "{}/xrpc/com.atproto.server.confirmSignup",
516 base_url().await
517 ))
518 .json(&confirm_payload)
519 .send()
520 .await
521 .expect("confirmSignup request failed");
522 if confirm_res.status() == StatusCode::OK {
523 let confirm_body: Value = confirm_res
524 .json()
525 .await
526 .expect("Invalid JSON from confirmSignup");
527 let access_jwt = confirm_body["accessJwt"]
528 .as_str()
529 .expect("No accessJwt in confirmSignup response")
530 .to_string();
531 return (access_jwt, did);
532 }
533 last_error = format!("confirmSignup failed: {:?}", confirm_res.text().await);
534 continue;
535 }
536 last_error = format!("Status {}: {:?}", res.status(), res.text().await);
537 }
538 panic!("Failed to create account after 3 attempts: {}", last_error);
539}