// This repository has no description.
1use aws_config::BehaviorVersion; 2use aws_sdk_s3::Client as S3Client; 3use aws_sdk_s3::config::Credentials; 4use chrono::Utc; 5use reqwest::{Client, StatusCode, header}; 6use serde_json::{Value, json}; 7use sqlx::postgres::PgPoolOptions; 8#[allow(unused_imports)] 9use std::collections::HashMap; 10use std::sync::OnceLock; 11#[allow(unused_imports)] 12use std::time::Duration; 13use tokio::net::TcpListener; 14use tranquil_pds::state::AppState; 15use wiremock::matchers::{method, path}; 16use wiremock::{Mock, MockServer, ResponseTemplate}; 17 18static SERVER_URL: OnceLock<String> = OnceLock::new(); 19static APP_PORT: OnceLock<u16> = OnceLock::new(); 20static MOCK_APPVIEW: OnceLock<MockServer> = OnceLock::new(); 21static TEST_DB_POOL: OnceLock<sqlx::PgPool> = OnceLock::new(); 22 23#[cfg(not(feature = "external-infra"))] 24use testcontainers::core::ContainerPort; 25#[cfg(not(feature = "external-infra"))] 26use testcontainers::{ContainerAsync, GenericImage, ImageExt, runners::AsyncRunner}; 27#[cfg(not(feature = "external-infra"))] 28use testcontainers_modules::postgres::Postgres; 29#[cfg(not(feature = "external-infra"))] 30static DB_CONTAINER: OnceLock<ContainerAsync<Postgres>> = OnceLock::new(); 31#[cfg(not(feature = "external-infra"))] 32static S3_CONTAINER: OnceLock<ContainerAsync<GenericImage>> = OnceLock::new(); 33 34#[allow(dead_code)] 35pub const AUTH_TOKEN: &str = "test-token"; 36#[allow(dead_code)] 37pub const BAD_AUTH_TOKEN: &str = "bad-token"; 38#[allow(dead_code)] 39pub const AUTH_DID: &str = "did:plc:fake"; 40#[allow(dead_code)] 41pub const TARGET_DID: &str = "did:plc:target"; 42 43fn has_external_infra() -> bool { 44 std::env::var("TRANQUIL_PDS_TEST_INFRA_READY").is_ok() 45 || (std::env::var("DATABASE_URL").is_ok() && std::env::var("S3_ENDPOINT").is_ok()) 46} 47#[cfg(test)] 48#[ctor::dtor] 49fn cleanup() { 50 if has_external_infra() { 51 return; 52 } 53 if std::env::var("XDG_RUNTIME_DIR").is_ok() { 54 let _ = std::process::Command::new("podman") 55 
.args(["rm", "-f", "--filter", "label=tranquil_pds_test=true"]) 56 .output(); 57 } 58 let _ = std::process::Command::new("docker") 59 .args([ 60 "container", 61 "prune", 62 "-f", 63 "--filter", 64 "label=tranquil_pds_test=true", 65 ]) 66 .output(); 67} 68 69#[allow(dead_code)] 70pub fn client() -> Client { 71 Client::new() 72} 73 74#[allow(dead_code)] 75pub fn app_port() -> u16 { 76 *APP_PORT.get().expect("APP_PORT not initialized") 77} 78 79pub async fn base_url() -> &'static str { 80 SERVER_URL.get_or_init(|| { 81 let (tx, rx) = std::sync::mpsc::channel(); 82 std::thread::spawn(move || { 83 unsafe { 84 std::env::set_var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS", "1"); 85 } 86 if std::env::var("DOCKER_HOST").is_err() 87 && let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") 88 { 89 let podman_sock = std::path::Path::new(&runtime_dir).join("podman/podman.sock"); 90 if podman_sock.exists() { 91 unsafe { 92 std::env::set_var( 93 "DOCKER_HOST", 94 format!("unix://{}", podman_sock.display()), 95 ); 96 } 97 } 98 } 99 let rt = tokio::runtime::Runtime::new().unwrap(); 100 rt.block_on(async move { 101 if has_external_infra() { 102 let url = setup_with_external_infra().await; 103 tx.send(url).unwrap(); 104 } else { 105 let url = setup_with_testcontainers().await; 106 tx.send(url).unwrap(); 107 } 108 std::future::pending::<()>().await; 109 }); 110 }); 111 rx.recv().expect("Failed to start test server") 112 }) 113} 114 115async fn setup_with_external_infra() -> String { 116 let database_url = 117 std::env::var("DATABASE_URL").expect("DATABASE_URL must be set when using external infra"); 118 let s3_endpoint = 119 std::env::var("S3_ENDPOINT").expect("S3_ENDPOINT must be set when using external infra"); 120 unsafe { 121 std::env::set_var( 122 "S3_BUCKET", 123 std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()), 124 ); 125 std::env::set_var( 126 "AWS_ACCESS_KEY_ID", 127 std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()), 128 ); 129 
std::env::set_var( 130 "AWS_SECRET_ACCESS_KEY", 131 std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()), 132 ); 133 std::env::set_var( 134 "AWS_REGION", 135 std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()), 136 ); 137 std::env::set_var("S3_ENDPOINT", &s3_endpoint); 138 std::env::set_var("MAX_IMPORT_SIZE", "100000000"); 139 std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); 140 } 141 let mock_server = MockServer::start().await; 142 setup_mock_appview(&mock_server).await; 143 let mock_uri = mock_server.uri(); 144 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri); 145 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A")); 146 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await; 147 MOCK_APPVIEW.set(mock_server).ok(); 148 spawn_app(database_url).await 149} 150 151#[cfg(not(feature = "external-infra"))] 152async fn setup_with_testcontainers() -> String { 153 let s3_container = GenericImage::new("minio/minio", "latest") 154 .with_exposed_port(ContainerPort::Tcp(9000)) 155 .with_env_var("MINIO_ROOT_USER", "minioadmin") 156 .with_env_var("MINIO_ROOT_PASSWORD", "minioadmin") 157 .with_cmd(vec!["server".to_string(), "/data".to_string()]) 158 .with_label("tranquil_pds_test", "true") 159 .start() 160 .await 161 .expect("Failed to start MinIO"); 162 let s3_port = s3_container 163 .get_host_port_ipv4(9000) 164 .await 165 .expect("Failed to get S3 port"); 166 let s3_endpoint = format!("http://127.0.0.1:{}", s3_port); 167 unsafe { 168 std::env::set_var("S3_BUCKET", "test-bucket"); 169 std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin"); 170 std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin"); 171 std::env::set_var("AWS_REGION", "us-east-1"); 172 std::env::set_var("S3_ENDPOINT", &s3_endpoint); 173 std::env::set_var("MAX_IMPORT_SIZE", "100000000"); 174 std::env::set_var("SKIP_IMPORT_VERIFICATION", "true"); 175 } 176 let sdk_config = 
aws_config::defaults(BehaviorVersion::latest()) 177 .region("us-east-1") 178 .endpoint_url(&s3_endpoint) 179 .credentials_provider(Credentials::new( 180 "minioadmin", 181 "minioadmin", 182 None, 183 None, 184 "test", 185 )) 186 .load() 187 .await; 188 let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config) 189 .force_path_style(true) 190 .build(); 191 let s3_client = S3Client::from_conf(s3_config); 192 let _ = s3_client.create_bucket().bucket("test-bucket").send().await; 193 let mock_server = MockServer::start().await; 194 setup_mock_appview(&mock_server).await; 195 let mock_uri = mock_server.uri(); 196 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri); 197 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A")); 198 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await; 199 MOCK_APPVIEW.set(mock_server).ok(); 200 S3_CONTAINER.set(s3_container).ok(); 201 let container = Postgres::default() 202 .with_tag("18-alpine") 203 .with_label("tranquil_pds_test", "true") 204 .start() 205 .await 206 .expect("Failed to start Postgres"); 207 let connection_string = format!( 208 "postgres://postgres:postgres@127.0.0.1:{}", 209 container 210 .get_host_port_ipv4(5432) 211 .await 212 .expect("Failed to get port") 213 ); 214 DB_CONTAINER.set(container).ok(); 215 spawn_app(connection_string).await 216} 217 218#[cfg(feature = "external-infra")] 219async fn setup_with_testcontainers() -> String { 220 panic!( 221 "Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT." 
222 ); 223} 224 225async fn setup_mock_did_document(mock_server: &MockServer, did: &str, service_endpoint: &str) { 226 Mock::given(method("GET")) 227 .and(path("/.well-known/did.json")) 228 .respond_with(ResponseTemplate::new(200).set_body_json(json!({ 229 "id": did, 230 "service": [{ 231 "id": "#atproto_appview", 232 "type": "AtprotoAppView", 233 "serviceEndpoint": service_endpoint 234 }] 235 }))) 236 .mount(mock_server) 237 .await; 238} 239 240async fn setup_mock_appview(_mock_server: &MockServer) {} 241 242async fn spawn_app(database_url: String) -> String { 243 use tranquil_pds::rate_limit::RateLimiters; 244 let pool = PgPoolOptions::new() 245 .max_connections(3) 246 .acquire_timeout(std::time::Duration::from_secs(30)) 247 .connect(&database_url) 248 .await 249 .expect("Failed to connect to Postgres. Make sure the database is running."); 250 sqlx::migrate!("./migrations") 251 .run(&pool) 252 .await 253 .expect("Failed to run migrations"); 254 let test_pool = PgPoolOptions::new() 255 .max_connections(5) 256 .acquire_timeout(std::time::Duration::from_secs(30)) 257 .connect(&database_url) 258 .await 259 .expect("Failed to create test pool"); 260 TEST_DB_POOL.set(test_pool).ok(); 261 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 262 let addr = listener.local_addr().unwrap(); 263 APP_PORT.set(addr.port()).ok(); 264 unsafe { 265 std::env::set_var("PDS_HOSTNAME", addr.to_string()); 266 } 267 let rate_limiters = RateLimiters::new() 268 .with_login_limit(10000) 269 .with_account_creation_limit(10000) 270 .with_password_reset_limit(10000) 271 .with_email_update_limit(10000) 272 .with_oauth_authorize_limit(10000) 273 .with_oauth_token_limit(10000); 274 let state = AppState::from_db(pool) 275 .await 276 .with_rate_limiters(rate_limiters); 277 tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await; 278 let app = tranquil_pds::app(state); 279 tokio::spawn(async move { 280 axum::serve(listener, app).await.unwrap(); 281 }); 282 
format!("http://{}", addr) 283} 284 285#[allow(dead_code)] 286pub async fn get_db_connection_string() -> String { 287 base_url().await; 288 if has_external_infra() { 289 std::env::var("DATABASE_URL").expect("DATABASE_URL not set") 290 } else { 291 #[cfg(not(feature = "external-infra"))] 292 { 293 let container = DB_CONTAINER.get().expect("DB container not initialized"); 294 let port = container 295 .get_host_port_ipv4(5432) 296 .await 297 .expect("Failed to get port"); 298 format!("postgres://postgres:postgres@127.0.0.1:{}/postgres", port) 299 } 300 #[cfg(feature = "external-infra")] 301 { 302 panic!("DATABASE_URL must be set with external-infra feature"); 303 } 304 } 305} 306 307#[allow(dead_code)] 308pub async fn get_test_db_pool() -> &'static sqlx::PgPool { 309 base_url().await; 310 TEST_DB_POOL.get().expect("TEST_DB_POOL not initialized") 311} 312 313#[allow(dead_code)] 314pub async fn verify_new_account(client: &Client, did: &str) -> String { 315 let pool = get_test_db_pool().await; 316 let body_text: String = sqlx::query_scalar!( 317 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1", 318 did 319 ) 320 .fetch_one(pool) 321 .await 322 .expect("Failed to get verification code"); 323 324 let lines: Vec<&str> = body_text.lines().collect(); 325 let verification_code = lines 326 .iter() 327 .enumerate() 328 .find(|(_, line)| line.contains("verification code is:") || line.contains("code is:")) 329 .and_then(|(i, _)| lines.get(i + 1).map(|s| s.trim().to_string())) 330 .or_else(|| { 331 body_text 332 .split_whitespace() 333 .find(|word| word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3) 334 .map(|s| s.to_string()) 335 }) 336 .unwrap_or_else(|| body_text.clone()); 337 338 let confirm_payload = json!({ 339 "did": did, 340 "verificationCode": verification_code 341 }); 342 let confirm_res = client 343 .post(format!( 344 
"{}/xrpc/com.atproto.server.confirmSignup", 345 base_url().await 346 )) 347 .json(&confirm_payload) 348 .send() 349 .await 350 .expect("confirmSignup request failed"); 351 assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed"); 352 let confirm_body: Value = confirm_res 353 .json() 354 .await 355 .expect("Invalid JSON from confirmSignup"); 356 confirm_body["accessJwt"] 357 .as_str() 358 .expect("No accessJwt in confirmSignup response") 359 .to_string() 360} 361 362#[allow(dead_code)] 363pub async fn upload_test_blob(client: &Client, data: &'static str, mime: &'static str) -> Value { 364 let res = client 365 .post(format!( 366 "{}/xrpc/com.atproto.repo.uploadBlob", 367 base_url().await 368 )) 369 .header(header::CONTENT_TYPE, mime) 370 .bearer_auth(AUTH_TOKEN) 371 .body(data) 372 .send() 373 .await 374 .expect("Failed to send uploadBlob request"); 375 assert_eq!(res.status(), StatusCode::OK, "Failed to upload blob"); 376 let body: Value = res.json().await.expect("Blob upload response was not JSON"); 377 body["blob"].clone() 378} 379 380#[allow(dead_code)] 381pub async fn create_test_post( 382 client: &Client, 383 text: &str, 384 reply_to: Option<Value>, 385) -> (String, String, String) { 386 let collection = "app.bsky.feed.post"; 387 let mut record = json!({ 388 "$type": collection, 389 "text": text, 390 "createdAt": Utc::now().to_rfc3339() 391 }); 392 if let Some(reply_obj) = reply_to { 393 record["reply"] = reply_obj; 394 } 395 let payload = json!({ 396 "repo": AUTH_DID, 397 "collection": collection, 398 "record": record 399 }); 400 let res = client 401 .post(format!( 402 "{}/xrpc/com.atproto.repo.createRecord", 403 base_url().await 404 )) 405 .bearer_auth(AUTH_TOKEN) 406 .json(&payload) 407 .send() 408 .await 409 .expect("Failed to send createRecord"); 410 assert_eq!(res.status(), StatusCode::OK, "Failed to create post record"); 411 let body: Value = res 412 .json() 413 .await 414 .expect("createRecord response was not JSON"); 415 let uri = 
body["uri"] 416 .as_str() 417 .expect("Response had no URI") 418 .to_string(); 419 let cid = body["cid"] 420 .as_str() 421 .expect("Response had no CID") 422 .to_string(); 423 let rkey = uri 424 .split('/') 425 .next_back() 426 .expect("URI was malformed") 427 .to_string(); 428 (uri, cid, rkey) 429} 430 431#[allow(dead_code)] 432pub async fn create_account_and_login(client: &Client) -> (String, String) { 433 create_account_and_login_internal(client, false).await 434} 435 436#[allow(dead_code)] 437pub async fn create_admin_account_and_login(client: &Client) -> (String, String) { 438 create_account_and_login_internal(client, true).await 439} 440 441async fn create_account_and_login_internal(client: &Client, make_admin: bool) -> (String, String) { 442 let mut last_error = String::new(); 443 for attempt in 0..3 { 444 if attempt > 0 { 445 tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await; 446 } 447 let handle = format!("u{}", &uuid::Uuid::new_v4().simple().to_string()[..12]); 448 let payload = json!({ 449 "handle": handle, 450 "email": format!("{}@example.com", handle), 451 "password": "Testpass123!" 
452 }); 453 let res = match client 454 .post(format!( 455 "{}/xrpc/com.atproto.server.createAccount", 456 base_url().await 457 )) 458 .json(&payload) 459 .send() 460 .await 461 { 462 Ok(r) => r, 463 Err(e) => { 464 last_error = format!("Request failed: {}", e); 465 continue; 466 } 467 }; 468 if res.status() == StatusCode::OK { 469 let body: Value = res.json().await.expect("Invalid JSON"); 470 let did = body["did"].as_str().expect("No did").to_string(); 471 let pool = get_test_db_pool().await; 472 if make_admin { 473 sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did) 474 .execute(pool) 475 .await 476 .expect("Failed to mark user as admin"); 477 } 478 let verification_required = body["verificationRequired"].as_bool().unwrap_or(true); 479 if let Some(access_jwt) = body["accessJwt"].as_str() 480 && !verification_required 481 { 482 return (access_jwt.to_string(), did); 483 } 484 let body_text: String = sqlx::query_scalar!( 485 "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1", 486 &did 487 ) 488 .fetch_one(pool) 489 .await 490 .expect("Failed to get verification from comms_queue"); 491 let lines: Vec<&str> = body_text.lines().collect(); 492 let verification_code = lines 493 .iter() 494 .enumerate() 495 .find(|(_, line): &(usize, &&str)| { 496 line.contains("verification code is:") || line.contains("code is:") 497 }) 498 .and_then(|(i, _)| lines.get(i + 1).map(|s: &&str| s.trim().to_string())) 499 .or_else(|| { 500 body_text 501 .split_whitespace() 502 .find(|word: &&str| { 503 word.contains('-') && word.chars().filter(|c| *c == '-').count() >= 3 504 }) 505 .map(|s: &str| s.to_string()) 506 }) 507 .unwrap_or_else(|| body_text.clone()); 508 509 let confirm_payload = json!({ 510 "did": did, 511 "verificationCode": verification_code 512 }); 513 let confirm_res = client 514 .post(format!( 515 "{}/xrpc/com.atproto.server.confirmSignup", 516 
base_url().await 517 )) 518 .json(&confirm_payload) 519 .send() 520 .await 521 .expect("confirmSignup request failed"); 522 if confirm_res.status() == StatusCode::OK { 523 let confirm_body: Value = confirm_res 524 .json() 525 .await 526 .expect("Invalid JSON from confirmSignup"); 527 let access_jwt = confirm_body["accessJwt"] 528 .as_str() 529 .expect("No accessJwt in confirmSignup response") 530 .to_string(); 531 return (access_jwt, did); 532 } 533 last_error = format!("confirmSignup failed: {:?}", confirm_res.text().await); 534 continue; 535 } 536 last_error = format!("Status {}: {:?}", res.status(), res.text().await); 537 } 538 panic!("Failed to create account after 3 attempts: {}", last_error); 539}