An easy-to-host PDS on the ATProtocol, MacOS. Grandma-approved.

fix(relay): address PR review issues for refactor/dry-route-handlers

Critical:
- Run cargo fmt --all (formatting violations in auth.rs, create_account.rs)
- Add unit tests for require_admin_token() in auth.rs (6 tests covering all
branches including the non-UTF-8 Authorization header path)
- Add unit tests for generate_code() in code_gen.rs (4 tests: length, charset,
character set membership, non-constant output)

Important:
- Narrow pub mod auth to pub(crate) mod auth in routes/mod.rs
- Drop pub from CODE_LEN and CHARSET in code_gen.rs (no external consumers)
- Switch OR EXISTS queries from bool to i64 + CAST AS INTEGER to avoid
sqlx type-affinity ambiguity on untyped SQLite expressions
- Narrow auth.rs doc comment: presence/prefix checks are conventional
short-circuits; only the final comparison uses subtle::ct_eq
- Remove stale "handle_in_handles query coverage" comment from test
- Log constraint name in unique_violation_source default arm so unexpected
future constraints are visible in traces

Suggestions (high-value):
- Use bool::from(ct_eq(...)) instead of unwrap_u8() != 1 per subtle docs
- Upgrade non-UTF-8 Authorization header log from debug to warn

authored by malpercio.dev and committed by

Tangled 24cb2724 7081ac6c

+197 -40
+91 -11
crates/relay/src/routes/auth.rs
··· 7 7 8 8 /// Validate the admin Bearer token from request headers. 9 9 /// 10 - /// Returns `Ok(())` when the token is present, has the `"Bearer "` prefix, and matches 11 - /// `Config.admin_token` in constant time. Returns `ApiError::Unauthorized` in all other 12 - /// cases, including when the server has no token configured. 10 + /// Returns `Ok(())` when the token is present, has the `"Bearer "` prefix, and the 11 + /// final byte comparison passes. The presence check and `"Bearer "` prefix strip are 12 + /// conventional short-circuits that do not expose the token value; only the final byte 13 + /// comparison uses `subtle::ct_eq` to avoid timing side-channels on the token itself. 14 + /// Returns `ApiError::Unauthorized` in all other cases, including when the server has 15 + /// no token configured. 13 16 /// 14 17 /// Call this at the top of any handler that requires admin access. 15 18 pub fn require_admin_token(headers: &HeaderMap, state: &AppState) -> Result<(), ApiError> { ··· 24 27 .and_then(|v| { 25 28 v.to_str() 26 29 .inspect_err(|_| { 27 - tracing::debug!( 30 + tracing::warn!( 28 31 "Authorization header contains non-UTF-8 bytes; treating as absent" 29 32 ); 30 33 }) ··· 39 42 ) 40 43 })?; 41 44 42 - if provided_token 43 - .as_bytes() 44 - .ct_eq(expected_token.as_bytes()) 45 - .unwrap_u8() 46 - != 1 47 - { 48 - return Err(ApiError::new(ErrorCode::Unauthorized, "invalid admin token")); 45 + if !bool::from(provided_token.as_bytes().ct_eq(expected_token.as_bytes())) { 46 + return Err(ApiError::new( 47 + ErrorCode::Unauthorized, 48 + "invalid admin token", 49 + )); 49 50 } 50 51 51 52 Ok(()) 52 53 } 54 + 55 + #[cfg(test)] 56 + mod tests { 57 + use super::*; 58 + use axum::http::{HeaderMap, HeaderValue}; 59 + use std::sync::Arc; 60 + 61 + use crate::app::test_state; 62 + 63 + async fn state_with_token(token: &str) -> AppState { 64 + let base = test_state().await; 65 + let mut config = (*base.config).clone(); 66 + config.admin_token = Some(token.to_string()); 67 + AppState { 68 + config: Arc::new(config), 69 + db: base.db, 70 + } 71 + } 72 + 73 + fn headers_with_bearer(token: &str) -> HeaderMap { 74 + let mut h = HeaderMap::new(); 75 + h.insert( 76 + axum::http::header::AUTHORIZATION, 77 + format!("Bearer {token}").parse().unwrap(), 78 + ); 79 + h 80 + } 81 + 82 + #[tokio::test] 83 + async fn no_token_configured_returns_401() { 84 + let state = test_state().await; // admin_token = None 85 + let headers = headers_with_bearer("anything"); 86 + let err = require_admin_token(&headers, &state).unwrap_err(); 87 + assert_eq!(err.status_code(), 401); 88 + } 89 + 90 + #[tokio::test] 91 + async fn missing_authorization_header_returns_401() { 92 + let state = state_with_token("secret").await; 93 + let err = require_admin_token(&HeaderMap::new(), &state).unwrap_err(); 94 + assert_eq!(err.status_code(), 401); 95 + } 96 + 97 + #[tokio::test] 98 + async fn bare_token_without_bearer_prefix_returns_401() { 99 + let state = state_with_token("secret").await; 100 + let mut headers = HeaderMap::new(); 101 + headers.insert(axum::http::header::AUTHORIZATION, "secret".parse().unwrap()); 102 + let err = require_admin_token(&headers, &state).unwrap_err(); 103 + assert_eq!(err.status_code(), 401); 104 + } 105 + 106 + #[tokio::test] 107 + async fn wrong_token_returns_401() { 108 + let state = state_with_token("correct").await; 109 + let err = require_admin_token(&headers_with_bearer("wrong"), &state).unwrap_err(); 110 + assert_eq!(err.status_code(), 401); 111 + } 112 + 113 + #[tokio::test] 114 + async fn correct_token_returns_ok() { 115 + let state = state_with_token("secret").await; 116 + assert!(require_admin_token(&headers_with_bearer("secret"), &state).is_ok()); 117 + } 118 + 119 + #[tokio::test] 120 + async fn non_utf8_authorization_header_returns_401() { 121 + // Exercises the inspect_err / treat-as-absent path. 122 + // HeaderValue::from_bytes accepts arbitrary bytes; to_str() will fail on \xff. 123 + let state = state_with_token("secret").await; 124 + let mut headers = HeaderMap::new(); 125 + headers.insert( 126 + axum::http::header::AUTHORIZATION, 127 + HeaderValue::from_bytes(b"Bearer \xff\xfe").unwrap(), 128 + ); 129 + let err = require_admin_token(&headers, &state).unwrap_err(); 130 + assert_eq!(err.status_code(), 401); 131 + } 132 + }
+48 -2
crates/relay/src/routes/code_gen.rs
··· 1 1 use rand_core::{OsRng, RngCore}; 2 2 3 - pub const CODE_LEN: usize = 6; 4 - pub const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; 3 + const CODE_LEN: usize = 6; 4 + const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; 5 5 6 6 /// Generate a single 6-character uppercase alphanumeric code. 7 7 pub fn generate_code() -> String { ··· 11 11 .map(|&b| CHARSET[(b as usize) % CHARSET.len()] as char) 12 12 .collect() 13 13 } 14 + 15 + #[cfg(test)] 16 + mod tests { 17 + use super::*; 18 + 19 + #[test] 20 + fn code_is_6_chars() { 21 + assert_eq!(generate_code().len(), CODE_LEN); 22 + } 23 + 24 + #[test] 25 + fn code_is_uppercase_alphanumeric() { 26 + for _ in 0..50 { 27 + let code = generate_code(); 28 + assert!( 29 + code.chars() 30 + .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit()), 31 + "code contained non-uppercase-alphanumeric char: {code}" 32 + ); 33 + } 34 + } 35 + 36 + #[test] 37 + fn codes_are_drawn_from_charset() { 38 + // Every character in a generated code must appear in CHARSET. 39 + for _ in 0..50 { 40 + let code = generate_code(); 41 + for ch in code.chars() { 42 + assert!( 43 + CHARSET.contains(&(ch as u8)), 44 + "char {ch:?} is not in CHARSET" 45 + ); 46 + } 47 + } 48 + } 49 + 50 + #[test] 51 + fn consecutive_codes_are_not_all_identical() { 52 + // With 36^6 ≈ 2.2 billion possible codes, the probability that 10 consecutive 53 + // calls all return the same code is negligibly small. This test catches a broken 54 + // RNG or constant-return implementation. 55 + let codes: Vec<String> = (0..10).map(|_| generate_code()).collect(); 56 + let unique: std::collections::HashSet<_> = codes.iter().collect(); 57 + assert!(unique.len() > 1, "all 10 generated codes were identical"); 58 + } 59 + }
+55 -24
crates/relay/src/routes/create_account.rs
··· 61 61 } 62 62 63 63 // --- Email uniqueness: check accounts and pending_accounts in one query --- 64 - // Fast-path optimization: reject before the INSERT to avoid touching the claim code retry 65 - // loop on a predictable failure. The unique indexes on pending_accounts.email and 66 - // accounts.email are the authoritative enforcement; this is an early return. 64 + // Fast-path: reject before the INSERT to avoid burning a claim_code slot on a 65 + // predictable conflict. The unique indexes are the authoritative enforcement; this 66 + // is an optimization that also provides an early error for fully-provisioned accounts. 67 67 // Note: pending_accounts has no cross-table FK to accounts, so both tables must be checked. 68 - let email_taken: bool = sqlx::query_scalar( 69 - "SELECT EXISTS(SELECT 1 FROM accounts WHERE email = ?) 70 - OR EXISTS(SELECT 1 FROM pending_accounts WHERE email = ?)", 68 + // CAST ensures sqlx maps the result as INTEGER regardless of SQLite's type affinity 69 + // rules on untyped OR expressions. 70 + let email_taken: i64 = sqlx::query_scalar( 71 + "SELECT CAST( 72 + (EXISTS(SELECT 1 FROM accounts WHERE email = ?) 73 + OR EXISTS(SELECT 1 FROM pending_accounts WHERE email = ?)) 74 + AS INTEGER)", 71 75 ) 72 76 .bind(&payload.email) 73 77 .bind(&payload.email) ··· 78 82 ApiError::new(ErrorCode::InternalError, "failed to create account") 79 83 })?; 80 84 81 - if email_taken { 85 + if email_taken != 0 { 82 86 return Err(ApiError::new( 83 87 ErrorCode::AccountExists, 84 88 "an account with this email already exists", ··· 86 90 } 87 91 88 92 // --- Handle uniqueness: check handles and pending_accounts in one query --- 89 - let handle_taken: bool = sqlx::query_scalar( 90 - "SELECT EXISTS(SELECT 1 FROM handles WHERE handle = ?) 91 - OR EXISTS(SELECT 1 FROM pending_accounts WHERE handle = ?)", 93 + // handles.handle is the PRIMARY KEY (uniqueness enforced by the PK, not a separate index). 94 + let handle_taken: i64 = sqlx::query_scalar( 95 + "SELECT CAST( 96 + (EXISTS(SELECT 1 FROM handles WHERE handle = ?) 97 + OR EXISTS(SELECT 1 FROM pending_accounts WHERE handle = ?)) 98 + AS INTEGER)", 92 99 ) 93 100 .bind(&payload.handle) 94 101 .bind(&payload.handle) ··· 99 106 ApiError::new(ErrorCode::InternalError, "failed to create account") 100 107 })?; 101 108 102 - if handle_taken { 109 + if handle_taken != 0 { 103 110 return Err(ApiError::new( 104 111 ErrorCode::HandleTaken, 105 112 "this handle is already claimed", ··· 164 171 } 165 172 166 173 tracing::error!("exhausted all claim code generation attempts"); 167 - Err(ApiError::new(ErrorCode::InternalError, "failed to create account")) 174 + Err(ApiError::new( 175 + ErrorCode::InternalError, 176 + "failed to create account", 177 + )) 168 178 } 169 179 170 180 /// Validate that a handle string passes basic format checks. ··· 264 274 if msg.contains("pending_accounts.handle") { 265 275 return Some(UniqueConflict::Handle); 266 276 } 277 + // Treat any other unique violation as a claim_codes.code collision. 278 + // Log the constraint name so unexpected constraints are visible in traces. 279 + tracing::debug!( 280 + constraint = msg, 281 + "unique violation on unknown constraint; treating as claim code collision" 282 + ); 267 283 return Some(UniqueConflict::ClaimCode); 268 284 } 269 285 } ··· 312 328 .await 313 329 .unwrap(); 314 330 let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 315 - assert!(json["accountId"].as_str().is_some(), "accountId must be present"); 331 + assert!( 332 + json["accountId"].as_str().is_some(), 333 + "accountId must be present" 334 + ); 316 335 assert_eq!(json["did"], serde_json::Value::Null, "did must be null"); 317 - assert!(json["claimCode"].as_str().is_some(), "claimCode must be present"); 336 + assert!( 337 + json["claimCode"].as_str().is_some(), 338 + "claimCode must be present" 339 + ); 318 340 assert_eq!(json["status"], "pending"); 319 341 } 320 342 ··· 336 358 let code = json["claimCode"].as_str().unwrap(); 337 359 assert_eq!(code.len(), 6, "claim code must be 6 chars"); 338 360 assert!( 339 - code.chars().all(|c| c.is_ascii_uppercase() || c.is_ascii_digit()), 361 + code.chars() 362 + .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit()), 340 363 "claim code must be uppercase alphanumeric, got: {code}" 341 364 ); 342 365 } ··· 386 409 .fetch_one(&db) 387 410 .await 388 411 .unwrap(); 389 - assert!(within_window, "claim code must expire approximately 24h from now"); 412 + assert!( 413 + within_window, 414 + "claim code must expire approximately 24h from now" 415 + ); 390 416 } 391 417 392 418 // ── Duplicate email ─────────────────────────────────────────────────────── ··· 415 441 .await 416 442 .unwrap(); 417 443 assert_eq!(second.status(), StatusCode::CONFLICT); 418 - let body = axum::body::to_bytes(second.into_body(), 4096).await.unwrap(); 444 + let body = axum::body::to_bytes(second.into_body(), 4096) 445 + .await 446 + .unwrap(); 419 447 let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 420 448 assert_eq!(json["error"]["code"], "ACCOUNT_EXISTS"); 421 449 } ··· 443 471 .unwrap(); 444 472 445 473 assert_eq!(response.status(), StatusCode::CONFLICT); 446 - let body = axum::body::to_bytes(response.into_body(), 4096).await.unwrap(); 474 + let body = axum::body::to_bytes(response.into_body(), 4096) 475 + .await 476 + .unwrap(); 447 477 let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 448 478 assert_eq!(json["error"]["code"], "ACCOUNT_EXISTS"); 449 479 } ··· 473 503 .await 474 504 .unwrap(); 475 505 assert_eq!(second.status(), StatusCode::CONFLICT); 476 - let body = axum::body::to_bytes(second.into_body(), 4096).await.unwrap(); 506 + let body = axum::body::to_bytes(second.into_body(), 4096) 507 + .await 508 + .unwrap(); 477 509 let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 478 510 assert_eq!(json["error"]["code"], "HANDLE_TAKEN"); 479 511 } 480 512 481 513 #[tokio::test] 482 514 async fn duplicate_handle_in_handles_returns_409() { 483 - // handle_in_handles query coverage 484 515 let state = test_state_with_admin_token().await; 485 516 486 517 // Seed a fully-provisioned account with an active handle. ··· 508 539 .unwrap(); 509 540 510 541 assert_eq!(response.status(), StatusCode::CONFLICT); 511 - let body = axum::body::to_bytes(response.into_body(), 4096).await.unwrap(); 542 + let body = axum::body::to_bytes(response.into_body(), 4096) 543 + .await 544 + .unwrap(); 512 545 let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 513 546 assert_eq!(json["error"]["code"], "HANDLE_TAKEN"); 514 547 } ··· 555 588 #[tokio::test] 556 589 async fn handle_exceeding_253_chars_returns_400() { 557 590 let long_handle = "a".repeat(254); 558 - let body = format!( 559 - r#"{{"email":"x@example.com","handle":"{long_handle}","tier":"free"}}"# 560 - ); 591 + let body = format!(r#"{{"email":"x@example.com","handle":"{long_handle}","tier":"free"}}"#); 561 592 let response = app(test_state_with_admin_token().await) 562 593 .oneshot(post_create_account(&body, Some("test-admin-token"))) 563 594 .await
+1 -1
crates/relay/src/routes/mod.rs
··· 1 - pub mod auth; 1 + pub(crate) mod auth; 2 2 pub mod claim_codes; 3 3 pub mod create_account; 4 4 pub mod create_signing_key;
+2 -2
crates/relay/src/routes/test_utils.rs
··· 5 5 /// Minimal test state with admin_token set to `"test-admin-token"`. 6 6 /// 7 7 /// Wraps `test_state()` and overrides the single config field that most 8 - /// admin-endpoint tests need. Defined once here to avoid copying the same 9 - /// 8-line block into every route test module. 8 + /// admin-endpoint tests need. Defined once here rather than duplicated in 9 + /// every route test module. 10 10 pub async fn test_state_with_admin_token() -> AppState { 11 11 let base = test_state().await; 12 12 let mut config = (*base.config).clone();