this repo has no description
1use aws_config::BehaviorVersion;
2use aws_sdk_s3::Client as S3Client;
3use aws_sdk_s3::config::Credentials;
4use chrono::Utc;
5use reqwest::{Client, StatusCode, header};
6use serde_json::{Value, json};
7use sqlx::postgres::PgPoolOptions;
8#[allow(unused_imports)]
9use std::collections::HashMap;
10use std::sync::OnceLock;
11#[allow(unused_imports)]
12use std::time::Duration;
13use tokio::net::TcpListener;
14use tranquil_pds::state::AppState;
15use wiremock::matchers::{method, path};
16use wiremock::{Mock, MockServer, ResponseTemplate};
17
/// Base URL (`http://<addr>`) of the shared test server; initialised once by `base_url`.
static SERVER_URL: OnceLock<String> = OnceLock::new();
/// Port the test server's TCP listener is bound to; set in `spawn_app`.
static APP_PORT: OnceLock<u16> = OnceLock::new();
/// Mock AppView server; stored here so it outlives the setup function that started it.
static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new();
/// Separate connection pool used by test helpers that query or modify the database directly.
static TEST_DB_POOL: OnceLock<sqlx::PgPool> = OnceLock::new();
22
23#[cfg(not(feature = "external-infra"))]
24use testcontainers::core::ContainerPort;
25#[cfg(not(feature = "external-infra"))]
26use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner};
27#[cfg(not(feature = "external-infra"))]
28use testcontainers_modules::postgres::Postgres;
/// Postgres container started by `setup_with_testcontainers`. It is kept in a static, and
/// statics are never dropped, so the handle lives for the rest of the process.
#[cfg(not(feature = "external-infra"))]
static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new();
/// MinIO (S3-compatible) container started by `setup_with_testcontainers`; held for the
/// lifetime of the process for the same reason as `DB_CONTAINER`.
#[cfg(not(feature = "external-infra"))]
static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new();
33
// `#[allow(dead_code)]` on the items below: presumably because this module is shared by several
// test binaries that do not all use every item — confirm.
/// Bearer token sent by helpers such as `upload_test_blob` and `create_test_post`.
#[allow(dead_code)]
pub const AUTH_TOKEN: &str = "test-token";
/// A bearer token distinct from `AUTH_TOKEN` (not used in this file; presumably for tests that
/// expect authentication to fail — confirm).
#[allow(dead_code)]
pub const BAD_AUTH_TOKEN: &str = "bad-token";
/// DID passed as the `repo` when `create_test_post` creates records.
#[allow(dead_code)]
pub const AUTH_DID: &str = "did:plc:fake";
/// A second placeholder DID (not used in this file; presumably the subject of tests that act on
/// another account — confirm).
#[allow(dead_code)]
pub const TARGET_DID: &str = "did:plc:target";
42
/// Reports whether the tests should use externally provisioned infrastructure instead of
/// starting their own containers.
///
/// Returns `true` when `TRANQUIL_PDS_TEST_INFRA_READY` is set, or when both `DATABASE_URL` and
/// `S3_ENDPOINT` are set. "Set" means `std::env::var` succeeds: present and valid Unicode.
fn has_external_infra() -> bool {
    fn is_set(name: &str) -> bool {
        std::env::var(name).is_ok()
    }

    let explicitly_ready = is_set("TRANQUIL_PDS_TEST_INFRA_READY");
    explicitly_ready || ["DATABASE_URL", "S3_ENDPOINT"].iter().all(|&name| is_set(name))
}
/// Best-effort removal of containers labelled `tranquil_pds_test=true`, registered with
/// `ctor::dtor` so it runs when the test process exits. Does nothing when external
/// infrastructure is in use. Command failures and output are ignored.
///
/// NOTE(review): `docker container prune` only removes *stopped* containers. `DB_CONTAINER` and
/// `S3_CONTAINER` are statics and are never dropped, so those containers are likely still
/// running here and may survive on the Docker path. The Podman path force-removes (`rm -f`)
/// matching containers. Do not simply switch Docker to `rm -f` on the label filter: that would
/// also remove containers belonging to other test processes running in parallel. Removing this
/// process's own container ids would be the safe fix.
#[cfg(test)]
#[ctor::dtor]
fn cleanup() {
    const TEST_LABEL_FILTER: &str = "label=tranquil_pds_test=true";

    if has_external_infra() {
        return;
    }

    let has_user_runtime_dir = std::env::var("XDG_RUNTIME_DIR").is_ok();
    if has_user_runtime_dir {
        let podman_rm = std::process::Command::new("podman")
            .args(["rm", "-f", "--filter", TEST_LABEL_FILTER])
            .output();
        drop(podman_rm);
    }

    let docker_prune = std::process::Command::new("docker")
        .args(["container", "prune", "-f", "--filter", TEST_LABEL_FILTER])
        .output();
    drop(docker_prune);
}
68
69#[allow(dead_code)]
70pub fn client() -> Client {
71 Client::new()
72}
73
74#[allow(dead_code)]
75pub fn app_port() -> u16 {
76 *APP_PORT.get().expect("APP_PORT not initialized")
77}
78
79pub async fn base_url() -> &'static str {
80 SERVER_URL.get_or_init(|| {
81 let (tx, rx) = std::sync::mpsc::channel();
82 std::thread::spawn(move || {
83 unsafe {
84 std::env::set_var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS", "1");
85 }
86 if std::env::var("DOCKER_HOST").is_err() {
87 if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
88 let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock");
89 if podman_sock.exists() {
90 unsafe {
91 std::env::set_var(
92 "DOCKER_HOST",
93 format!("unix://{}", podman_sock.display()),
94 );
95 }
96 }
97 }
98 }
99 let rt = tokio::runtime::Runtime::new().unwrap();
100 rt.block_on(async move {
101 if has_external_infra() {
102 let url = setup_with_external_infra().await;
103 tx.send(url).unwrap();
104 } else {
105 let url = setup_with_testcontainers().await;
106 tx.send(url).unwrap();
107 }
108 std::future::pending::<()>().await;
109 });
110 });
111 rx.recv().expect("Failed to start test server")
112 })
113}
114
115async fn setup_with_external_infra() -> String {
116 let database_url =
117 std::env::var("DATABASE_URL").expect("DATABASE_URL must be set when using external infra");
118 let s3_endpoint =
119 std::env::var("S3_ENDPOINT").expect("S3_ENDPOINT must be set when using external infra");
120 unsafe {
121 std::env::set_var(
122 "S3_BUCKET",
123 std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()),
124 );
125 std::env::set_var(
126 "AWS_ACCESS_KEY_ID",
127 std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()),
128 );
129 std::env::set_var(
130 "AWS_SECRET_ACCESS_KEY",
131 std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()),
132 );
133 std::env::set_var(
134 "AWS_REGION",
135 std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()),
136 );
137 std::env::set_var("S3_ENDPOINT", &s3_endpoint);
138 }
139 let mock_server = MockServer::start().await;
140 setup_mock_appview(&mock_server).await;
141 let mock_uri = mock_server.uri();
142 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
143 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
144 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
145 MOCK_APPVIEW.set(mock_server).ok();
146 spawn_app(database_url).await
147}
148
149#[cfg(not(feature = "external-infra"))]
150async fn setup_with_testcontainers() -> String {
151 let s3_container = GenericImage::new("minio/minio", "latest")
152 .with_exposed_port(ContainerPort::Tcp(9000))
153 .with_env_var("MINIO_ROOT_USER", "minioadmin")
154 .with_env_var("MINIO_ROOT_PASSWORD", "minioadmin")
155 .with_cmd(vec!["server".to_string(), "/data".to_string()])
156 .with_label("tranquil_pds_test", "true")
157 .start()
158 .await
159 .expect("Failed to start MinIO");
160 let s3_port = s3_container
161 .get_host_port_ipv4(9000)
162 .await
163 .expect("Failed to get S3 port");
164 let s3_endpoint = format!("http://127.0.0.1:{}", s3_port);
165 unsafe {
166 std::env::set_var("S3_BUCKET", "test-bucket");
167 std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
168 std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");
169 std::env::set_var("AWS_REGION", "us-east-1");
170 std::env::set_var("S3_ENDPOINT", &s3_endpoint);
171 }
172 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
173 .region("us-east-1")
174 .endpoint_url(&s3_endpoint)
175 .credentials_provider(Credentials::new(
176 "minioadmin",
177 "minioadmin",
178 None,
179 None,
180 "test",
181 ))
182 .load()
183 .await;
184 let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config)
185 .force_path_style(true)
186 .build();
187 let s3_client = S3Client::from_conf(s3_config);
188 let _ = s3_client.create_bucket().bucket("test-bucket").send().await;
189 let mock_server = MockServer::start().await;
190 setup_mock_appview(&mock_server).await;
191 let mock_uri = mock_server.uri();
192 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
193 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
194 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
195 MOCK_APPVIEW.set(mock_server).ok();
196 S3_CONTAINER.set(s3_container).ok();
197 let container = Postgres::default()
198 .with_tag("18-alpine")
199 .with_label("tranquil_pds_test", "true")
200 .start()
201 .await
202 .expect("Failed to start Postgres");
203 let connection_string = format!(
204 "postgres://postgres:postgres@127.0.0.1:{}",
205 container
206 .get_host_port_ipv4(5432)
207 .await
208 .expect("Failed to get port")
209 );
210 DB_CONTAINER.set(container).ok();
211 spawn_app(connection_string).await
212}
213
/// Stand-in compiled when the `external-infra` feature disables testcontainers support.
///
/// # Panics
/// Always. With this feature the tests must find external infrastructure through
/// `DATABASE_URL` and `S3_ENDPOINT` (see `has_external_infra`).
#[cfg(feature = "external-infra")]
async fn setup_with_testcontainers() -> String {
    panic!(
        "Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT."
    );
}
220
221async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_endpoint: &str) {
222 Mock::given(method("GET"))
223 .and(path("/.well-known/did.json"))
224 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
225 "id": did,
226 "service": [{
227 "id": "#atproto_appview",
228 "type": "AtprotoAppView",
229 "serviceEndpoint": service_endpoint
230 }]
231 })))
232 .mount(mock_server)
233 .await;
234}
235
/// Hook for registering AppView mock routes on the given server; currently registers nothing.
async fn setup_mock_appview(_mock_server: &MockServer) {}
237
238async fn spawn_app(database_url: String) -> String {
239 use tranquil_pds::rate_limit::RateLimiters;
240 let pool = PgPoolOptions::new()
241 .max_connections(3)
242 .acquire_timeout(std::time::Duration::from_secs(30))
243 .connect(&database_url)
244 .await
245 .expect("Failed to connect to Postgres. Make sure the database is running.");
246 sqlx::migrate!("./migrations")
247 .run(&pool)
248 .await
249 .expect("Failed to run migrations");
250 let test_pool = PgPoolOptions::new()
251 .max_connections(5)
252 .acquire_timeout(std::time::Duration::from_secs(30))
253 .connect(&database_url)
254 .await
255 .expect("Failed to create test pool");
256 TEST_DB_POOL.set(test_pool).ok();
257 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
258 let addr = listener.local_addr().unwrap();
259 APP_PORT.set(addr.port()).ok();
260 unsafe {
261 std::env::set_var("PDS_HOSTNAME", addr.to_string());
262 }
263 let rate_limiters = RateLimiters::new()
264 .with_login_limit(10000)
265 .with_account_creation_limit(10000)
266 .with_password_reset_limit(10000)
267 .with_email_update_limit(10000)
268 .with_oauth_authorize_limit(10000)
269 .with_oauth_token_limit(10000);
270 let state = AppState::from_db(pool)
271 .await
272 .with_rate_limiters(rate_limiters);
273 tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
274 let app = tranquil_pds::app(state);
275 tokio::spawn(async move {
276 axum::serve(listener, app).await.unwrap();
277 });
278 format!("http://{}", addr)
279}
280
281#[allow(dead_code)]
282pub async fn get_db_connection_string() -> String {
283 base_url().await;
284 if has_external_infra() {
285 std::env::var("DATABASE_URL").expect("DATABASE_URL not set")
286 } else {
287 #[cfg(not(feature = "external-infra"))]
288 {
289 let container = DB_CONTAINER.get().expect("DB container not initialized");
290 let port = container
291 .get_host_port_ipv4(5432)
292 .await
293 .expect("Failed to get port");
294 format!("postgres://postgres:postgres@127.0.0.1:{}/postgres", port)
295 }
296 #[cfg(feature = "external-infra")]
297 {
298 panic!("DATABASE_URL must be set with external-infra feature");
299 }
300 }
301}
302
303#[allow(dead_code)]
304pub async fn get_test_db_pool() -> &'static sqlx::PgPool {
305 base_url().await;
306 TEST_DB_POOL.get().expect("TEST_DB_POOL not initialized")
307}
308
309#[allow(dead_code)]
310pub async fn verify_new_account(client: &Client, did: &str) -> String {
311 let pool = get_test_db_pool().await;
312 let body_text: String = sqlx::query_scalar!(
313 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
314 did
315 )
316 .fetch_one(pool)
317 .await
318 .expect("Failed to get verification code");
319
320 let lines: Vec<&str> = body_text.lines().collect();
321 let verification_code = lines
322 .iter()
323 .enumerate()
324 .find(|(_, line)| line.contains("verification code is:") || line.contains("code is:"))
325 .and_then(|(i, _)| lines.get(i + 1).map(|s| s.trim().to_string()))
326 .or_else(|| {
327 body_text
328 .split_whitespace()
329 .find(|word| word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3)
330 .map(|s| s.to_string())
331 })
332 .unwrap_or_else(|| body_text.clone());
333
334 let confirm_payload = json!({
335 "did": did,
336 "verificationCode": verification_code
337 });
338 let confirm_res = client
339 .post(format!(
340 "{}/xrpc/com.atproto.server.confirmSignup",
341 base_url().await
342 ))
343 .json(&confirm_payload)
344 .send()
345 .await
346 .expect("confirmSignup request failed");
347 assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed");
348 let confirm_body: Value = confirm_res
349 .json()
350 .await
351 .expect("Invalid JSON from confirmSignup");
352 confirm_body["accessJwt"]
353 .as_str()
354 .expect("No accessJwt in confirmSignup response")
355 .to_string()
356}
357
358#[allow(dead_code)]
359pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value {
360 let res = client
361 .post(format!(
362 "{}/xrpc/com.atproto.repo.uploadBlob",
363 base_url().await
364 ))
365 .header(header::CONTENT_TYPE, mime)
366 .bearer_auth(AUTH_TOKEN)
367 .body(data)
368 .send()
369 .await
370 .expect("Failed to send uploadBlob request");
371 assert_eq!(res.status(), StatusCode::OK, "Failed to upload blob");
372 let body: Value = res.json().await.expect("Blob upload response was not JSON");
373 body["blob"].clone()
374}
375
376#[allow(dead_code)]
377pub async fn create_test_post(
378 client: &Client,
379 text: &str,
380 reply_to: Option<Value>,
381) -> (String, String, String) {
382 let collection = "app.bsky.feed.post";
383 let mut record = json!({
384 "$type": collection,
385 "text": text,
386 "createdAt": Utc::now().to_rfc3339()
387 });
388 if let Some(reply_obj) = reply_to {
389 record["reply"] = reply_obj;
390 }
391 let payload = json!({
392 "repo": AUTH_DID,
393 "collection": collection,
394 "record": record
395 });
396 let res = client
397 .post(format!(
398 "{}/xrpc/com.atproto.repo.createRecord",
399 base_url().await
400 ))
401 .bearer_auth(AUTH_TOKEN)
402 .json(&payload)
403 .send()
404 .await
405 .expect("Failed to send createRecord");
406 assert_eq!(res.status(), StatusCode::OK, "Failed to create post record");
407 let body: Value = res
408 .json()
409 .await
410 .expect("createRecord response was not JSON");
411 let uri = body["uri"]
412 .as_str()
413 .expect("Response had no URI")
414 .to_string();
415 let cid = body["cid"]
416 .as_str()
417 .expect("Response had no CID")
418 .to_string();
419 let rkey = uri
420 .split('/')
421 .last()
422 .expect("URI was malformed")
423 .to_string();
424 (uri, cid, rkey)
425}
426
427#[allow(dead_code)]
428pub async fn create_account_and_login(client: &Client) -> (String, String) {
429 create_account_and_login_internal(client, false).await
430}
431
432#[allow(dead_code)]
433pub async fn create_admin_account_and_login(client: &Client) -> (String, String) {
434 create_account_and_login_internal(client, true).await
435}
436
437async fn create_account_and_login_internal(client: &Client, make_admin: bool) -> (String, String) {
438 let mut last_error = String::new();
439 for attempt in 0..3 {
440 if attempt > 0 {
441 tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
442 }
443 let handle = format!("u{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
444 let payload = json!({
445 "handle": handle,
446 "email": format!("{}@example.com", handle),
447 "password": "Testpass123!"
448 });
449 let res = match client
450 .post(format!(
451 "{}/xrpc/com.atproto.server.createAccount",
452 base_url().await
453 ))
454 .json(&payload)
455 .send()
456 .await
457 {
458 Ok(r) => r,
459 Err(e) => {
460 last_error = format!("Request failed: {}", e);
461 continue;
462 }
463 };
464 if res.status() == StatusCode::OK {
465 let body: Value = res.json().await.expect("Invalid JSON");
466 let did = body["did"].as_str().expect("No did").to_string();
467 let pool = get_test_db_pool().await;
468 if make_admin {
469 sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
470 .execute(pool)
471 .await
472 .expect("Failed to mark user as admin");
473 }
474 let verification_required = body["verificationRequired"].as_bool().unwrap_or(true);
475 if let Some(access_jwt) = body["accessJwt"].as_str() {
476 if !verification_required {
477 return (access_jwt.to_string(), did);
478 }
479 }
480 let body_text: String = sqlx::query_scalar!(
481 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
482 &did
483 )
484 .fetch_one(pool)
485 .await
486 .expect("Failed to get verification from comms_queue");
487 let lines: Vec<&str> = body_text.lines().collect();
488 let verification_code = lines
489 .iter()
490 .enumerate()
491 .find(|(_, line)| {
492 line.contains("verification code is:") || line.contains("code is:")
493 })
494 .and_then(|(i, _)| lines.get(i + 1).map(|s| s.trim().to_string()))
495 .or_else(|| {
496 body_text
497 .split_whitespace()
498 .find(|word| {
499 word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3
500 })
501 .map(|s| s.to_string())
502 })
503 .unwrap_or_else(|| body_text.clone());
504
505 let confirm_payload = json!({
506 "did": did,
507 "verificationCode": verification_code
508 });
509 let confirm_res = client
510 .post(format!(
511 "{}/xrpc/com.atproto.server.confirmSignup",
512 base_url().await
513 ))
514 .json(&confirm_payload)
515 .send()
516 .await
517 .expect("confirmSignup request failed");
518 if confirm_res.status() == StatusCode::OK {
519 let confirm_body: Value = confirm_res
520 .json()
521 .await
522 .expect("Invalid JSON from confirmSignup");
523 let access_jwt = confirm_body["accessJwt"]
524 .as_str()
525 .expect("No accessJwt in confirmSignup response")
526 .to_string();
527 return (access_jwt, did);
528 }
529 last_error = format!("confirmSignup failed: {:?}", confirm_res.text().await);
530 continue;
531 }
532 last_error = format!("Status {}: {:?}", res.status(), res.text().await);
533 }
534 panic!("Failed to create account after 3 attempts: {}", last_error);
535}