Our Personal Data Server from scratch! tranquil.farm
oauth atproto pds rust postgresql objectstorage fun

fix: bulk type safety improvements, added a couple of tests

+5080 -2754
+8
.config/nextest.toml
··· 33 33 filter = "test(/two_node_stress_concurrent_load/)" 34 34 test-group = "heavy-load-tests" 35 35 36 + [[profile.default.overrides]] 37 + filter = "binary(repo_lifecycle)" 38 + test-group = "heavy-load-tests" 39 + 36 40 [[profile.ci.overrides]] 37 41 filter = "test(/import_with_verification/) | test(/plc_migration/)" 38 42 test-group = "serial-env-tests" ··· 48 52 [[profile.ci.overrides]] 49 53 filter = "test(/two_node_stress_concurrent_load/)" 50 54 test-group = "heavy-load-tests" 55 + 56 + [[profile.ci.overrides]] 57 + filter = "binary(repo_lifecycle)" 58 + test-group = "heavy-load-tests"
+15 -14
Cargo.lock
··· 5893 5893 5894 5894 [[package]] 5895 5895 name = "tranquil-auth" 5896 - version = "0.2.0" 5896 + version = "0.2.1" 5897 5897 dependencies = [ 5898 5898 "anyhow", 5899 5899 "base32", ··· 5915 5915 5916 5916 [[package]] 5917 5917 name = "tranquil-cache" 5918 - version = "0.2.0" 5918 + version = "0.2.1" 5919 5919 dependencies = [ 5920 5920 "async-trait", 5921 5921 "base64 0.22.1", ··· 5928 5928 5929 5929 [[package]] 5930 5930 name = "tranquil-comms" 5931 - version = "0.2.0" 5931 + version = "0.2.1" 5932 5932 dependencies = [ 5933 5933 "async-trait", 5934 5934 "base64 0.22.1", ··· 5942 5942 5943 5943 [[package]] 5944 5944 name = "tranquil-crypto" 5945 - version = "0.2.0" 5945 + version = "0.2.1" 5946 5946 dependencies = [ 5947 5947 "aes-gcm", 5948 5948 "base64 0.22.1", ··· 5958 5958 5959 5959 [[package]] 5960 5960 name = "tranquil-db" 5961 - version = "0.2.0" 5961 + version = "0.2.1" 5962 5962 dependencies = [ 5963 5963 "async-trait", 5964 5964 "chrono", ··· 5975 5975 5976 5976 [[package]] 5977 5977 name = "tranquil-db-traits" 5978 - version = "0.2.0" 5978 + version = "0.2.1" 5979 5979 dependencies = [ 5980 5980 "async-trait", 5981 5981 "base64 0.22.1", ··· 5991 5991 5992 5992 [[package]] 5993 5993 name = "tranquil-infra" 5994 - version = "0.2.0" 5994 + version = "0.2.1" 5995 5995 dependencies = [ 5996 5996 "async-trait", 5997 5997 "bytes", ··· 6001 6001 6002 6002 [[package]] 6003 6003 name = "tranquil-oauth" 6004 - version = "0.2.0" 6004 + version = "0.2.1" 6005 6005 dependencies = [ 6006 6006 "anyhow", 6007 6007 "axum", ··· 6024 6024 6025 6025 [[package]] 6026 6026 name = "tranquil-pds" 6027 - version = "0.2.0" 6027 + version = "0.2.1" 6028 6028 dependencies = [ 6029 6029 "aes-gcm", 6030 6030 "anyhow", ··· 6109 6109 6110 6110 [[package]] 6111 6111 name = "tranquil-repo" 6112 - version = "0.2.0" 6112 + version = "0.2.1" 6113 6113 dependencies = [ 6114 6114 "bytes", 6115 6115 "cid", ··· 6121 6121 6122 6122 [[package]] 6123 6123 name = "tranquil-ripple" 6124 - version = "0.2.0" 6124 + version = "0.2.1" 6125 6125 dependencies = [ 6126 6126 "async-trait", 6127 6127 "backon", ··· 6145 6145 6146 6146 [[package]] 6147 6147 name = "tranquil-scopes" 6148 - version = "0.2.0" 6148 + version = "0.2.1" 6149 6149 dependencies = [ 6150 6150 "axum", 6151 6151 "futures", ··· 6153 6153 "reqwest", 6154 6154 "serde", 6155 6155 "serde_json", 6156 + "thiserror 2.0.17", 6156 6157 "tokio", 6157 6158 "tracing", 6158 6159 "urlencoding", ··· 6160 6161 6161 6162 [[package]] 6162 6163 name = "tranquil-storage" 6163 - version = "0.2.0" 6164 + version = "0.2.1" 6164 6165 dependencies = [ 6165 6166 "async-trait", 6166 6167 "aws-config", ··· 6176 6177 6177 6178 [[package]] 6178 6179 name = "tranquil-types" 6179 - version = "0.2.0" 6180 + version = "0.2.1" 6180 6181 dependencies = [ 6181 6182 "chrono", 6182 6183 "cid",
+1 -1
Cargo.toml
··· 18 18 ] 19 19 20 20 [workspace.package] 21 - version = "0.2.0" 21 + version = "0.2.1" 22 22 edition = "2024" 23 23 license = "AGPL-3.0-or-later" 24 24
+2 -34
TODO.md
··· 5 5 ### Storage backend abstraction 6 6 Make storage layers swappable via traits. 7 7 8 - filesystem blob storage 9 - - [ ] FilesystemBlobStorage implementation 10 - - [ ] directory structure (content-addressed like blobs/{cid} already used in objsto) 11 - - [ ] atomic writes (write to temp, rename) 12 - - [ ] config option to choose backend (env var or config flag) 13 - - [ ] also traitify BackupStorage (currently hardcoded to objsto) 14 - 15 8 sqlite database backend 16 9 - [ ] abstract db layer behind trait (queries, transactions, migrations) 17 10 - [ ] sqlite implementation matching postgres behavior ··· 20 13 - [ ] testing: run full test suite against both backends 21 14 - [ ] config option to choose backend (postgres vs sqlite) 22 15 - [ ] document tradeoffs (sqlite for single-user/small, postgres for multi-user/scale) 16 + 17 + - [ ] skip sqlite and just straight-up do our own db?! 23 18 24 19 ### Plugin system 25 20 WASM component model plugins. Compile to wasm32-wasip2, sandboxed via wasmtime, capability-gated. Based on zed's extensions. ··· 131 126 - [ ] on_access_grant_request for custom authorization 132 127 - [ ] on_key_rotation to notify interested parties 133 128 134 - --- 135 - 136 - ## Completed 137 - 138 - Core ATProto: Health, describeServer, all session endpoints, full repo CRUD, applyWrites, blob upload, importRepo, firehose with cursor replay, CAR export, blob sync, crawler notifications, handle resolution, PLC operations, full admin API, moderation reports. 139 - 140 - did:web support: Self-hosted did:web (subdomain format `did:web:handle.pds.com`), external/BYOD did:web, DID document serving via `/.well-known/did.json`, clear registration warnings about did:web trade-offs vs did:plc. 141 - 142 - OAuth 2.1: Authorization server metadata, JWKS, PAR, authorize endpoint with login UI, token endpoint (auth code + refresh), revocation, introspection, DPoP, PKCE S256, client metadata validation, private_key_jwt verification. 143 - 144 - OAuth Scope Enforcement: Full granular scope system with consent UI, human-readable scope descriptions, per-client scope preferences, scope parsing (repo/blob/rpc/account/identity), endpoint-level scope checks, DPoP token support in auth extractors, token revocation on re-authorization, response_mode support (query/fragment). 145 - 146 - App endpoints: getPreferences, putPreferences, getProfile, getProfiles, getTimeline, getAuthorFeed, getActorLikes, getPostThread, getFeed, registerPush (all with local-first + proxy fallback). 147 - 148 - Infrastructure: Sequencer with cursor replay, postgres repo storage with atomic transactions, valkey DID cache, debounced crawler notifications with circuit breakers, multi-channel notifications (email/Discord/Telegram/Signal), image processing, distributed rate limiting, security hardening. 149 - 150 - Web UI: OAuth login, registration, email verification, password reset, multi-account selector, dashboard, sessions, app passwords, invites, notification preferences, repo browser, CAR export, admin panel, OAuth consent screen with scope selection. 151 - 152 - Auth: ES256K + HS256 dual support, JTI-only token storage, refresh token family tracking, encrypted signing keys (AES-256-GCM), DPoP replay protection, constant-time comparisons. 153 - 154 - Passkeys and 2FA: WebAuthn/FIDO2 passkey registration and authentication, TOTP with QR setup, backup codes (hashed, one-time use), passkey-only account creation, trusted devices (remember this browser), re-auth for sensitive actions, rate-limited 2FA attempts, settings UI for managing all auth methods. 155 - 156 - App password scopes: Granular permissions for app passwords using the same scope system as OAuth. Preset buttons for common use cases (full access, read-only, post-only), scope stored in session and preserved across token refresh, explicit RPC/repo/blob scope enforcement for restricted passwords. 157 - 158 - Account Delegation: Delegated accounts controlled by other accounts instead of passwords. OAuth delegation flow (authenticate as controller), scope-based permissions (owner/admin/editor/viewer presets), scope intersection (tokens limited to granted permissions), `act` claim for delegation tracking, creating delegated account flow, controller management UI, "act as" account switcher, comprehensive audit logging with actor/controller tracking, delegation-aware OAuth consent with permission limitation notices. 159 - 160 - Migration: OAuth-based inbound migration wizard with PLC token flow, offline restore from CAR file + rotation key for disaster recovery, scheduled automatic backups, standalone repo/blob export, did:web DID document editor for self-service identity management, handle preservation (keep existing external handle via DNS/HTTP verification or create new PDS-subdomain handle).
+10 -10
crates/tranquil-auth/src/lib.rs
··· 4 4 mod verify; 5 5 6 6 pub use token::{ 7 - SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 8 - TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token, create_access_token_hs256, 9 - create_access_token_hs256_with_metadata, create_access_token_with_delegation, 10 - create_access_token_with_metadata, create_access_token_with_scope_metadata, 11 - create_refresh_token, create_refresh_token_hs256, create_refresh_token_hs256_with_metadata, 12 - create_refresh_token_with_metadata, create_service_token, create_service_token_hs256, 7 + create_access_token, create_access_token_hs256, create_access_token_hs256_with_metadata, 8 + create_access_token_with_delegation, create_access_token_with_metadata, 9 + create_access_token_with_scope_metadata, create_refresh_token, create_refresh_token_hs256, 10 + create_refresh_token_hs256_with_metadata, create_refresh_token_with_metadata, 11 + create_service_token, create_service_token_hs256, 13 12 }; 14 13 15 14 pub use totp::{ 16 - decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, generate_qr_png_base64, 17 - generate_totp_secret, generate_totp_uri, hash_backup_code, is_backup_code_format, 18 - verify_backup_code, verify_totp_code, 15 + TotpError, decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, 16 + generate_qr_png_base64, generate_totp_secret, generate_totp_uri, hash_backup_code, 17 + is_backup_code_format, verify_backup_code, verify_totp_code, 19 18 }; 20 19 21 20 pub use types::{ 22 - ActClaim, Claims, Header, TokenData, TokenVerifyError, TokenWithMetadata, UnsafeClaims, 21 + ActClaim, Claims, Header, SigningAlgorithm, TokenData, TokenDecodeError, TokenScope, TokenType, 22 + TokenVerifyError, TokenWithMetadata, UnsafeClaims, 23 23 }; 24 24 25 25 pub use verify::{
+32 -38
crates/tranquil-auth/src/token.rs
··· 1 - use super::types::{ActClaim, Claims, Header, TokenWithMetadata}; 1 + use super::types::{ 2 + ActClaim, Claims, Header, SigningAlgorithm, TokenScope, TokenType, TokenWithMetadata, 3 + }; 2 4 use anyhow::Result; 3 5 use base64::Engine as _; 4 6 use base64::engine::general_purpose::URL_SAFE_NO_PAD; ··· 9 11 10 12 type HmacSha256 = Hmac<Sha256>; 11 13 12 - pub const TOKEN_TYPE_ACCESS: &str = "at+jwt"; 13 - pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt"; 14 - pub const TOKEN_TYPE_SERVICE: &str = "jwt"; 15 - pub const SCOPE_ACCESS: &str = "com.atproto.access"; 16 - pub const SCOPE_REFRESH: &str = "com.atproto.refresh"; 17 - pub const SCOPE_APP_PASS: &str = "com.atproto.appPass"; 18 - pub const SCOPE_APP_PASS_PRIVILEGED: &str = "com.atproto.appPassPrivileged"; 19 - 20 14 pub fn create_access_token(did: &str, key_bytes: &[u8]) -> Result<String> { 21 15 Ok(create_access_token_with_metadata(did, key_bytes)?.token) 22 16 } ··· 35 29 scopes: Option<&str>, 36 30 hostname: Option<&str>, 37 31 ) -> Result<TokenWithMetadata> { 38 - let scope = scopes.unwrap_or(SCOPE_ACCESS); 32 + let scope = scopes.unwrap_or(TokenScope::Access.as_str()); 39 33 create_signed_token_with_metadata( 40 34 did, 41 35 scope, 42 - TOKEN_TYPE_ACCESS, 36 + TokenType::Access, 43 37 key_bytes, 44 38 Duration::minutes(15), 45 39 hostname, ··· 53 47 controller_did: Option<&str>, 54 48 hostname: Option<&str>, 55 49 ) -> Result<TokenWithMetadata> { 56 - let scope = scopes.unwrap_or(SCOPE_ACCESS); 50 + let scope = scopes.unwrap_or(TokenScope::Access.as_str()); 57 51 let act = controller_did.map(|c| ActClaim { sub: c.to_string() }); 58 52 create_signed_token_with_act( 59 53 did, 60 54 scope, 61 - TOKEN_TYPE_ACCESS, 55 + TokenType::Access, 62 56 key_bytes, 63 57 Duration::minutes(15), 64 58 act, ··· 72 66 ) -> Result<TokenWithMetadata> { 73 67 create_signed_token_with_metadata( 74 68 did, 75 - SCOPE_REFRESH, 76 - TOKEN_TYPE_REFRESH, 69 + TokenScope::Refresh.as_str(), 70 + TokenType::Refresh, 77 71 key_bytes, 78 72 Duration::days(14), 79 73 None, ··· 92 86 iss: did.to_owned(), 93 87 sub: did.to_owned(), 94 88 aud: aud.to_owned(), 95 - exp: expiration as usize, 96 - iat: Utc::now().timestamp() as usize, 89 + exp: expiration, 90 + iat: Utc::now().timestamp(), 97 91 scope: None, 98 92 lxm: Some(lxm.to_string()), 99 93 jti: uuid::Uuid::new_v4().to_string(), ··· 106 100 fn create_signed_token_with_metadata( 107 101 did: &str, 108 102 scope: &str, 109 - typ: &str, 103 + typ: TokenType, 110 104 key_bytes: &[u8], 111 105 duration: Duration, 112 106 hostname: Option<&str>, ··· 117 111 fn create_signed_token_with_act( 118 112 did: &str, 119 113 scope: &str, 120 - typ: &str, 114 + typ: TokenType, 121 115 key_bytes: &[u8], 122 116 duration: Duration, 123 117 act: Option<ActClaim>, ··· 140 134 iss: did.to_owned(), 141 135 sub: did.to_owned(), 142 136 aud: format!("did:web:{}", aud_hostname), 143 - exp: expiration as usize, 144 - iat: Utc::now().timestamp() as usize, 137 + exp: expiration, 138 + iat: Utc::now().timestamp(), 145 139 scope: Some(scope.to_string()), 146 140 lxm: None, 147 141 jti: jti.clone(), ··· 158 152 } 159 153 160 154 fn sign_claims(claims: Claims, key: &SigningKey) -> Result<String> { 161 - sign_claims_with_type(claims, key, TOKEN_TYPE_SERVICE) 155 + sign_claims_with_type(claims, key, TokenType::Service) 162 156 } 163 157 164 - fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: &str) -> Result<String> { 158 + fn sign_claims_with_type(claims: Claims, key: &SigningKey, typ: TokenType) -> Result<String> { 165 159 let header = Header { 166 - alg: "ES256K".to_string(), 167 - typ: typ.to_string(), 160 + alg: SigningAlgorithm::ES256K, 161 + typ, 168 162 }; 169 163 170 164 let header_json = serde_json::to_string(&header)?; ··· 194 188 ) -> Result<TokenWithMetadata> { 195 189 create_hs256_token_with_metadata( 196 190 did, 197 - SCOPE_ACCESS, 198 - TOKEN_TYPE_ACCESS, 191 + TokenScope::Access.as_str(), 192 + TokenType::Access, 199 193 secret, 200 194 Duration::minutes(15), 201 195 ) ··· 207 201 ) -> Result<TokenWithMetadata> { 208 202 create_hs256_token_with_metadata( 209 203 did, 210 - SCOPE_REFRESH, 211 - TOKEN_TYPE_REFRESH, 204 + TokenScope::Refresh.as_str(), 205 + TokenType::Refresh, 212 206 secret, 213 207 Duration::days(14), 214 208 ) ··· 229 223 iss: did.to_owned(), 230 224 sub: did.to_owned(), 231 225 aud: aud.to_owned(), 232 - exp: expiration as usize, 233 - iat: Utc::now().timestamp() as usize, 226 + exp: expiration, 227 + iat: Utc::now().timestamp(), 234 228 scope: None, 235 229 lxm: Some(lxm.to_string()), 236 230 jti: uuid::Uuid::new_v4().to_string(), 237 231 act: None, 238 232 }; 239 233 240 - sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret) 234 + sign_claims_hs256(claims, TokenType::Service, secret) 241 235 } 242 236 243 237 fn create_hs256_token_with_metadata( 244 238 did: &str, 245 239 scope: &str, 246 - typ: &str, 240 + typ: TokenType, 247 241 secret: &[u8], 248 242 duration: Duration, 249 243 ) -> Result<TokenWithMetadata> { ··· 261 255 "did:web:{}", 262 256 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 263 257 ), 264 - exp: expiration as usize, 265 - iat: Utc::now().timestamp() as usize, 258 + exp: expiration, 259 + iat: Utc::now().timestamp(), 266 260 scope: Some(scope.to_string()), 267 261 lxm: None, 268 262 jti: jti.clone(), ··· 278 272 }) 279 273 } 280 274 281 - fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> { 275 + fn sign_claims_hs256(claims: Claims, typ: TokenType, secret: &[u8]) -> Result<String> { 282 276 let header = Header { 283 - alg: "HS256".to_string(), 284 - typ: typ.to_string(), 277 + alg: SigningAlgorithm::HS256, 278 + typ, 285 279 }; 286 280 287 281 let header_json = serde_json::to_string(&header)?;
+29 -9
crates/tranquil-auth/src/totp.rs
··· 1 1 use base32::Alphabet; 2 - use rand::RngCore; 2 + use rand::{Rng, RngCore}; 3 3 use subtle::ConstantTimeEq; 4 4 use totp_rs::{Algorithm, TOTP}; 5 5 6 6 const TOTP_DIGITS: usize = 6; 7 7 const TOTP_STEP: u64 = 30; 8 + const TOTP_STEP_SIGNED: i64 = TOTP_STEP as i64; 8 9 const TOTP_SECRET_LENGTH: usize = 20; 9 10 11 + #[derive(Debug)] 12 + pub enum TotpError { 13 + CreationFailed(String), 14 + QrGenerationFailed(String), 15 + HashFailed(String), 16 + } 17 + 18 + impl std::fmt::Display for TotpError { 19 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 20 + match self { 21 + Self::CreationFailed(e) => write!(f, "TOTP creation failed: {}", e), 22 + Self::QrGenerationFailed(e) => write!(f, "QR generation failed: {}", e), 23 + Self::HashFailed(e) => write!(f, "Hash failed: {}", e), 24 + } 25 + } 26 + } 27 + 28 + impl std::error::Error for TotpError {} 29 + 10 30 pub fn generate_totp_secret() -> Vec<u8> { 11 31 let mut secret = vec![0u8; TOTP_SECRET_LENGTH]; 12 32 rand::thread_rng().fill_bytes(&mut secret); ··· 31 51 secret: Vec<u8>, 32 52 issuer: Option<String>, 33 53 account_name: String, 34 - ) -> Result<TOTP, String> { 54 + ) -> Result<TOTP, TotpError> { 35 55 TOTP::new( 36 56 Algorithm::SHA1, 37 57 TOTP_DIGITS, ··· 41 61 issuer, 42 62 account_name, 43 63 ) 44 - .map_err(|e| format!("Failed to create TOTP: {}", e)) 64 + .map_err(|e| TotpError::CreationFailed(e.to_string())) 45 65 } 46 66 47 67 pub fn verify_totp_code(secret: &[u8], code: &str) -> bool { ··· 60 80 .unwrap_or(0); 61 81 62 82 [-1i64, 0, 1].iter().any(|&offset| { 63 - let time = (now as i64 + offset * TOTP_STEP as i64) as u64; 83 + let time = now.wrapping_add_signed(offset * TOTP_STEP_SIGNED); 64 84 let expected = totp.generate(time); 65 85 let is_valid: bool = code.as_bytes().ct_eq(expected.as_bytes()).into(); 66 86 is_valid ··· 84 104 secret: &[u8], 85 105 account_name: &str, 86 106 issuer: &str, 87 - ) -> Result<String, String> { 107 + ) -> Result<String, TotpError> { 88 108 use base64::{Engine, engine::general_purpose::STANDARD}; 89 109 90 110 let totp = create_totp( ··· 95 115 96 116 let qr_png = totp 97 117 .get_qr_png() 98 - .map_err(|e| format!("Failed to generate QR code: {}", e))?; 118 + .map_err(|e| TotpError::QrGenerationFailed(e.to_string()))?; 99 119 100 120 Ok(STANDARD.encode(qr_png)) 101 121 } ··· 112 132 (0..BACKUP_CODE_COUNT).for_each(|_| { 113 133 let code: String = (0..BACKUP_CODE_LENGTH) 114 134 .map(|_| { 115 - let idx = (rng.next_u32() as usize) % BACKUP_CODE_ALPHABET.len(); 135 + let idx = rng.gen_range(0..BACKUP_CODE_ALPHABET.len()); 116 136 BACKUP_CODE_ALPHABET[idx] as char 117 137 }) 118 138 .collect(); ··· 122 142 codes 123 143 } 124 144 125 - pub fn hash_backup_code(code: &str) -> Result<String, String> { 126 - bcrypt::hash(code, BACKUP_CODE_BCRYPT_COST).map_err(|e| format!("Failed to hash code: {}", e)) 145 + pub fn hash_backup_code(code: &str) -> Result<String, TotpError> { 146 + bcrypt::hash(code, BACKUP_CODE_BCRYPT_COST).map_err(|e| TotpError::HashFailed(e.to_string())) 127 147 } 128 148 129 149 pub fn verify_backup_code(code: &str, hash: &str) -> bool {
+202 -5
crates/tranquil-auth/src/types.rs
··· 1 1 use chrono::{DateTime, Utc}; 2 - use serde::{Deserialize, Serialize}; 2 + use serde::{Deserialize, Serialize, de, ser}; 3 3 use std::fmt; 4 + use std::str::FromStr; 5 + 6 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 7 + pub enum TokenType { 8 + Access, 9 + Refresh, 10 + Service, 11 + } 12 + 13 + impl TokenType { 14 + pub fn as_str(&self) -> &'static str { 15 + match self { 16 + Self::Access => "at+jwt", 17 + Self::Refresh => "refresh+jwt", 18 + Self::Service => "jwt", 19 + } 20 + } 21 + } 22 + 23 + impl fmt::Display for TokenType { 24 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 25 + f.write_str(self.as_str()) 26 + } 27 + } 28 + 29 + impl FromStr for TokenType { 30 + type Err = TokenTypeParseError; 31 + 32 + fn from_str(s: &str) -> Result<Self, Self::Err> { 33 + match s { 34 + "at+jwt" => Ok(Self::Access), 35 + "refresh+jwt" => Ok(Self::Refresh), 36 + "jwt" => Ok(Self::Service), 37 + _ => Err(TokenTypeParseError(s.to_string())), 38 + } 39 + } 40 + } 41 + 42 + #[derive(Debug, Clone)] 43 + pub struct TokenTypeParseError(pub String); 44 + 45 + impl fmt::Display for TokenTypeParseError { 46 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 47 + write!(f, "unknown token type: {}", self.0) 48 + } 49 + } 50 + 51 + impl std::error::Error for TokenTypeParseError {} 52 + 53 + impl Serialize for TokenType { 54 + fn serialize<S: ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { 55 + serializer.serialize_str(self.as_str()) 56 + } 57 + } 58 + 59 + impl<'de> Deserialize<'de> for TokenType { 60 + fn deserialize<D: de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { 61 + let s = String::deserialize(deserializer)?; 62 + Self::from_str(&s).map_err(de::Error::custom) 63 + } 64 + } 65 + 66 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 67 + pub enum SigningAlgorithm { 68 + ES256K, 69 + HS256, 70 + } 71 + 72 + impl SigningAlgorithm { 73 + pub fn as_str(&self) -> &'static str { 74 + match self { 75 + Self::ES256K => "ES256K", 76 + Self::HS256 => "HS256", 77 + } 78 + } 79 + } 80 + 81 + impl fmt::Display for SigningAlgorithm { 82 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 83 + f.write_str(self.as_str()) 84 + } 85 + } 86 + 87 + impl FromStr for SigningAlgorithm { 88 + type Err = SigningAlgorithmParseError; 89 + 90 + fn from_str(s: &str) -> Result<Self, Self::Err> { 91 + match s { 92 + "ES256K" => Ok(Self::ES256K), 93 + "HS256" => Ok(Self::HS256), 94 + _ => Err(SigningAlgorithmParseError(s.to_string())), 95 + } 96 + } 97 + } 98 + 99 + #[derive(Debug, Clone)] 100 + pub struct SigningAlgorithmParseError(pub String); 101 + 102 + impl fmt::Display for SigningAlgorithmParseError { 103 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 104 + write!(f, "unknown signing algorithm: {}", self.0) 105 + } 106 + } 107 + 108 + impl std::error::Error for SigningAlgorithmParseError {} 109 + 110 + impl Serialize for SigningAlgorithm { 111 + fn serialize<S: ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { 112 + serializer.serialize_str(self.as_str()) 113 + } 114 + } 115 + 116 + impl<'de> Deserialize<'de> for SigningAlgorithm { 117 + fn deserialize<D: de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { 118 + let s = String::deserialize(deserializer)?; 119 + Self::from_str(&s).map_err(de::Error::custom) 120 + } 121 + } 122 + 123 + #[derive(Debug, Clone, PartialEq, Eq)] 124 + pub enum TokenScope { 125 + Access, 126 + Refresh, 127 + AppPass, 128 + AppPassPrivileged, 129 + Custom(String), 130 + } 131 + 132 + impl TokenScope { 133 + pub fn as_str(&self) -> &str { 134 + match self { 135 + Self::Access => "com.atproto.access", 136 + Self::Refresh => "com.atproto.refresh", 137 + Self::AppPass => "com.atproto.appPass", 138 + Self::AppPassPrivileged => "com.atproto.appPassPrivileged", 139 + Self::Custom(s) => s, 140 + } 141 + } 142 + 143 + pub fn is_access_like(&self) -> bool { 144 + matches!(self, Self::Access | Self::AppPass | Self::AppPassPrivileged) 145 + } 146 + } 147 + 148 + impl fmt::Display for TokenScope { 149 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 150 + f.write_str(self.as_str()) 151 + } 152 + } 153 + 154 + impl FromStr for TokenScope { 155 + type Err = std::convert::Infallible; 156 + 157 + fn from_str(s: &str) -> Result<Self, Self::Err> { 158 + Ok(match s { 159 + "com.atproto.access" => Self::Access, 160 + "com.atproto.refresh" => Self::Refresh, 161 + "com.atproto.appPass" => Self::AppPass, 162 + "com.atproto.appPassPrivileged" => Self::AppPassPrivileged, 163 + other => Self::Custom(other.to_string()), 164 + }) 165 + } 166 + } 167 + 168 + impl Serialize for TokenScope { 169 + fn serialize<S: ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { 170 + serializer.serialize_str(self.as_str()) 171 + } 172 + } 173 + 174 + impl<'de> Deserialize<'de> for TokenScope { 175 + fn deserialize<D: de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { 176 + let s = String::deserialize(deserializer)?; 177 + Ok(Self::from_str(&s).unwrap_or_else(|e| match e {})) 178 + } 179 + } 180 + 181 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 182 + pub enum TokenDecodeError { 183 + InvalidFormat, 184 + Base64DecodeFailed, 185 + JsonDecodeFailed, 186 + MissingClaim, 187 + } 188 + 189 + impl fmt::Display for TokenDecodeError { 190 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 191 + match self { 192 + Self::InvalidFormat => write!(f, "Invalid token format"), 193 + Self::Base64DecodeFailed => write!(f, "Base64 decode failed"), 194 + Self::JsonDecodeFailed => write!(f, "JSON decode failed"), 195 + Self::MissingClaim => write!(f, "Missing required claim"), 196 + } 197 + } 198 + } 199 + 200 + impl std::error::Error for TokenDecodeError {} 4 201 5 202 #[derive(Debug, Clone, Serialize, Deserialize)] 6 203 pub struct ActClaim { ··· 12 209 pub iss: String, 13 210 pub sub: String, 14 211 pub aud: String, 15 - pub exp: usize, 16 - pub iat: usize, 212 + pub exp: i64, 213 + pub iat: i64, 17 214 #[serde(skip_serializing_if = "Option::is_none")] 18 215 pub scope: Option<String>, 19 216 #[serde(skip_serializing_if = "Option::is_none")] ··· 25 222 26 223 #[derive(Debug, Serialize, Deserialize)] 27 224 pub struct Header { 28 - pub alg: String, 29 - pub typ: String, 225 + pub alg: SigningAlgorithm, 226 + pub typ: TokenType, 30 227 } 31 228 32 229 #[derive(Debug, Serialize, Deserialize)]
+61 -39
crates/tranquil-auth/src/verify.rs
··· 1 - use super::token::{ 2 - SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS, 3 - TOKEN_TYPE_REFRESH, 1 + use super::types::{ 2 + Claims, Header, SigningAlgorithm, TokenData, TokenDecodeError, TokenScope, TokenType, 3 + TokenVerifyError, UnsafeClaims, 4 4 }; 5 - use super::types::{Claims, Header, TokenData, TokenVerifyError, UnsafeClaims}; 6 5 use anyhow::{Context, Result, anyhow}; 7 6 use base64::Engine as _; 8 7 use base64::engine::general_purpose::URL_SAFE_NO_PAD; ··· 14 13 15 14 type HmacSha256 = Hmac<Sha256>; 16 15 17 - pub fn get_did_from_token(token: &str) -> Result<String, String> { 16 + pub fn get_did_from_token(token: &str) -> Result<String, TokenDecodeError> { 18 17 let parts: Vec<&str> = token.split('.').collect(); 19 18 if parts.len() != 3 { 20 - return Err("Invalid token format".to_string()); 19 + return Err(TokenDecodeError::InvalidFormat); 21 20 } 22 21 23 22 let payload_bytes = URL_SAFE_NO_PAD 24 23 .decode(parts[1]) 25 - .map_err(|e| format!("Base64 decode failed: {}", e))?; 24 + .map_err(|_| TokenDecodeError::Base64DecodeFailed)?; 26 25 27 26 let claims: UnsafeClaims = 28 - serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 27 + serde_json::from_slice(&payload_bytes).map_err(|_| TokenDecodeError::JsonDecodeFailed)?; 29 28 30 29 Ok(claims.sub.unwrap_or(claims.iss)) 31 30 } 32 31 33 - pub fn get_jti_from_token(token: &str) -> Result<String, String> { 32 + pub fn get_jti_from_token(token: &str) -> Result<String, TokenDecodeError> { 34 33 let parts: Vec<&str> = token.split('.').collect(); 35 34 if parts.len() != 3 { 36 - return Err("Invalid token format".to_string()); 35 + return Err(TokenDecodeError::InvalidFormat); 37 36 } 38 37 39 38 let payload_bytes = URL_SAFE_NO_PAD 40 39 .decode(parts[1]) 41 - .map_err(|e| format!("Base64 decode failed: {}", e))?; 40 + .map_err(|_| TokenDecodeError::Base64DecodeFailed)?; 42 41 43 42 let claims: serde_json::Value = 44 - serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 43 + serde_json::from_slice(&payload_bytes).map_err(|_| TokenDecodeError::JsonDecodeFailed)?; 45 44 46 45 claims 47 46 .get("jti") 48 47 .and_then(|j| j.as_str()) 49 48 .map(|s| s.to_string()) 50 - .ok_or_else(|| "No jti claim in token".to_string()) 49 + .ok_or(TokenDecodeError::MissingClaim) 51 50 } 52 51 53 - pub fn get_algorithm_from_token(token: &str) -> Result<String, String> { 52 + pub fn get_algorithm_from_token(token: &str) -> Result<SigningAlgorithm, TokenDecodeError> { 54 53 let parts: Vec<&str> = token.split('.').collect(); 55 54 if parts.len() != 3 { 56 - return Err("Invalid token format".to_string()); 55 + return Err(TokenDecodeError::InvalidFormat); 57 56 } 58 57 59 58 let header_bytes = URL_SAFE_NO_PAD 60 59 .decode(parts[0]) 61 - .map_err(|e| format!("Base64 decode failed: {}", e))?; 60 + .map_err(|_| TokenDecodeError::Base64DecodeFailed)?; 62 61 63 62 let header: Header = 64 - serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 63 + serde_json::from_slice(&header_bytes).map_err(|_| TokenDecodeError::JsonDecodeFailed)?; 65 64 66 65 Ok(header.alg) 67 66 } ··· 74 73 verify_token_internal( 75 74 token, 76 75 key_bytes, 77 - Some(TOKEN_TYPE_ACCESS), 78 - Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 76 + Some(TokenType::Access), 77 + Some(&[ 78 + TokenScope::Access, 79 + TokenScope::AppPass, 80 + TokenScope::AppPassPrivileged, 81 + ]), 79 82 ) 80 83 } 81 84 ··· 83 86 verify_token_internal( 84 87 token, 85 88 key_bytes, 86 - Some(TOKEN_TYPE_REFRESH), 87 - Some(&[SCOPE_REFRESH]), 89 + Some(TokenType::Refresh), 90 + Some(&[TokenScope::Refresh]), 88 91 ) 89 92 } 90 93 ··· 92 95 verify_token_hs256_internal( 93 96 token, 94 97 secret, 95 - Some(TOKEN_TYPE_ACCESS), 96 - Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 98 + Some(TokenType::Access), 99 + Some(&[ 100 + TokenScope::Access, 101 + TokenScope::AppPass, 102 + TokenScope::AppPassPrivileged, 103 + ]), 97 104 ) 98 105 } 99 106 ··· 101 108 verify_token_hs256_internal( 102 109 token, 103 110 secret, 104 - Some(TOKEN_TYPE_REFRESH), 105 - Some(&[SCOPE_REFRESH]), 111 + Some(TokenType::Refresh), 112 + Some(&[TokenScope::Refresh]), 106 113 ) 107 114 } 108 115 109 116 fn verify_token_internal( 110 117 token: &str, 111 118 key_bytes: &[u8], 112 - expected_typ: Option<&str>, 113 - allowed_scopes: Option<&[&str]>, 119 + expected_typ: Option<TokenType>, 120 + allowed_scopes: Option<&[TokenScope]>, 114 121 ) -> Result<TokenData<Claims>> { 115 122 let parts: Vec<&str> = token.split('.').collect(); 116 123 if parts.len() != 3 { ··· 160 167 let claims: Claims = 161 168 serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 162 169 163 - let now = Utc::now().timestamp() as usize; 170 + let now = Utc::now().timestamp(); 164 171 if claims.exp < now { 165 172 return Err(anyhow!("Token expired")); 166 173 } 167 174 168 175 if let Some(scopes) = allowed_scopes { 169 - let token_scope = claims.scope.as_deref().unwrap_or(""); 176 + let token_scope: TokenScope = claims 177 + .scope 178 + .as_deref() 179 + .unwrap_or("") 180 + .parse() 181 + .unwrap_or_else(|e| match e {}); 170 182 if !scopes.contains(&token_scope) { 171 183 return Err(anyhow!("Invalid token scope: {}", token_scope)); 172 184 } ··· 178 190 fn verify_token_hs256_internal( 179 191 token: &str, 180 192 secret: &[u8], 181 - expected_typ: Option<&str>, 182 - allowed_scopes: Option<&[&str]>, 193 + expected_typ: Option<TokenType>, 194 + allowed_scopes: Option<&[TokenScope]>, 183 195 ) -> Result<TokenData<Claims>> { 184 196 let parts: Vec<&str> = token.split('.').collect(); 185 197 if parts.len() != 3 { ··· 197 209 let header: Header = 198 210 serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 199 211 200 - if header.alg != "HS256" { 212 + if header.alg != SigningAlgorithm::HS256 { 201 213 return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg)); 202 214 } 203 215 ··· 235 247 let claims: Claims = 236 248 serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 237 249 238 - let now = Utc::now().timestamp() as usize; 250 + let now = Utc::now().timestamp(); 239 251 if claims.exp < now { 240 252 return Err(anyhow!("Token expired")); 241 253 } 242 254 243 255 if let Some(scopes) = allowed_scopes { 244 - let token_scope = claims.scope.as_deref().unwrap_or(""); 256 + let token_scope: TokenScope = claims 257 + .scope 258 + .as_deref() 259 + .unwrap_or("") 260 + .parse() 261 + .unwrap_or_else(|e| match e {}); 245 262 if !scopes.contains(&token_scope) { 246 263 return Err(anyhow!("Invalid token scope: {}", token_scope)); 247 264 } ··· 254 271 token: &str, 255 272 key_bytes: &[u8], 256 273 ) -> Result<TokenData<Claims>, TokenVerifyError> { 257 - verify_token_typed_internal(token, key_bytes, Some(TOKEN_TYPE_ACCESS), None) 274 + verify_token_typed_internal(token, key_bytes, Some(TokenType::Access), None) 258 275 } 259 276 260 277 fn verify_token_typed_internal( 261 278 token: &str, 262 279 key_bytes: &[u8], 263 - expected_typ: Option<&str>, 264 - allowed_scopes: Option<&[&str]>, 280 + expected_typ: Option<TokenType>, 281 + allowed_scopes: Option<&[TokenScope]>, 265 282 ) -> Result<TokenData<Claims>, TokenVerifyError> { 266 283 let parts: Vec<&str> = token.split('.').collect(); 267 284 if parts.len() != 3 { ··· 315 332 return Err(TokenVerifyError::Invalid); 316 333 }; 317 334 318 - let now = Utc::now().timestamp() as usize; 335 + let now = Utc::now().timestamp(); 319 336 if claims.exp < now { 320 337 return Err(TokenVerifyError::Expired); 321 338 } 322 339 323 340 if let Some(scopes) = allowed_scopes { 324 - let token_scope = claims.scope.as_deref().unwrap_or(""); 341 + let token_scope: TokenScope = claims 342 + .scope 343 + .as_deref() 344 + .unwrap_or("") 345 + .parse() 346 + .unwrap_or_else(|e| match e {}); 325 347 if !scopes.contains(&token_scope) { 326 348 return Err(TokenVerifyError::Invalid); 327 349 }
+3 -3
crates/tranquil-cache/src/lib.rs
··· 48 48 .arg(key) 49 49 .arg(value) 50 50 .arg("PX") 51 - .arg(ttl.as_millis().min(i64::MAX as u128) as i64) 51 + .arg(i64::try_from(ttl.as_millis()).unwrap_or(i64::MAX)) 52 52 .query_async::<()>(&mut conn) 53 53 .await 54 54 .map_err(|e| CacheError::Connection(e.to_string())) ··· 94 94 async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 95 95 let mut conn = self.conn.clone(); 96 96 let full_key = format!("rl:{}", key); 97 - let window_secs = window_ms.div_ceil(1000).max(1) as i64; 97 + let window_secs = i64::try_from(window_ms.div_ceil(1000).max(1)).unwrap_or(i64::MAX); 98 98 let result: Result<i64, _> = redis::Script::new( 99 99 r"local c = redis.call('INCR', KEYS[1]) 100 100 if c == 1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end ··· 106 106 .invoke_async(&mut conn) 107 107 .await; 108 108 match result { 109 - Ok(count) => count <= limit as i64, 109 + Ok(count) => count <= i64::from(limit), 110 110 Err(e) => { 111 111 tracing::warn!(error = %e, "redis rate limit script failed, allowing request"); 112 112 true
+48 -1
crates/tranquil-db-traits/src/backlink.rs
··· 1 + use std::fmt; 2 + use std::str::FromStr; 3 + 1 4 use async_trait::async_trait; 2 5 use tranquil_types::{AtUri, Nsid}; 3 6 use uuid::Uuid; 4 7 5 8 use crate::DbError; 6 9 10 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 11 + pub enum BacklinkPath { 12 + Subject, 13 + SubjectUri, 14 + } 15 + 16 + impl BacklinkPath { 17 + pub fn as_str(&self) -> &'static str { 18 + match self { 19 + Self::Subject => "subject", 20 + Self::SubjectUri => "subject.uri", 21 + } 22 + } 23 + } 24 + 25 + impl fmt::Display for BacklinkPath { 26 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 27 + f.write_str(self.as_str()) 28 + } 29 + } 30 + 31 + #[derive(Debug, Clone)] 32 + pub struct BacklinkPathParseError(String); 33 + 34 + impl fmt::Display for BacklinkPathParseError { 35 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 36 + write!(f, "unknown backlink path: {}", self.0) 37 + } 38 + } 39 + 40 + impl std::error::Error for BacklinkPathParseError {} 41 + 42 + impl FromStr for BacklinkPath { 43 + type Err = BacklinkPathParseError; 44 + 45 + fn from_str(s: &str) -> Result<Self, Self::Err> { 46 + match s { 47 + "subject" => Ok(Self::Subject), 48 + "subject.uri" => Ok(Self::SubjectUri), 49 + _ => Err(BacklinkPathParseError(s.to_owned())), 50 + } 51 + } 52 + } 53 + 7 54 #[derive(Debug, Clone)] 8 55 pub struct Backlink { 9 56 pub uri: AtUri, 10 - pub path: String, 57 + pub path: BacklinkPath, 11 58 pub link_to: String, 12 59 } 13 60
+10 -1
crates/tranquil-db-traits/src/channel_verification.rs
··· 10 10 } 11 11 12 12 impl ChannelVerificationStatus { 13 - pub fn new(email: bool, discord: bool, telegram: bool, signal: bool) -> Self { 13 + pub fn from_db_row(email: bool, discord: bool, telegram: bool, signal: bool) -> Self { 14 14 Self { 15 15 email, 16 16 discord, 17 17 telegram, 18 18 signal, 19 + } 20 + } 21 + 22 + pub fn from_verified_channels(channels: &[CommsChannel]) -> Self { 23 + Self { 24 + email: channels.contains(&CommsChannel::Email), 25 + discord: channels.contains(&CommsChannel::Discord), 26 + telegram: channels.contains(&CommsChannel::Telegram), 27 + signal: channels.contains(&CommsChannel::Signal), 19 28 } 20 29 } 21 30
+3
crates/tranquil-db-traits/src/error.rs
··· 26 26 #[error("Resource busy, try again")] 27 27 LockContention, 28 28 29 + #[error("Corrupt data in column: {0}")] 30 + CorruptData(&'static str), 31 + 29 32 #[error("Other database error: {0}")] 30 33 Other(String), 31 34 }
+53 -17
crates/tranquil-db-traits/src/infra.rs
··· 31 31 } 32 32 } 33 33 34 - impl From<bool> for InviteCodeState { 35 - fn from(disabled: bool) -> Self { 36 - if disabled { 37 - Self::Disabled 38 - } else { 39 - Self::Active 34 + impl InviteCodeState { 35 + pub fn from_disabled_flag(disabled: bool) -> Self { 36 + match disabled { 37 + true => Self::Disabled, 38 + false => Self::Active, 40 39 } 41 40 } 42 - } 43 41 44 - impl From<Option<bool>> for InviteCodeState { 45 - fn from(disabled: Option<bool>) -> Self { 46 - Self::from(disabled.unwrap_or(false)) 47 - } 48 - } 49 - 50 - impl From<InviteCodeState> for bool { 51 - fn from(state: InviteCodeState) -> Self { 52 - matches!(state, InviteCodeState::Disabled) 42 + pub fn from_optional_disabled_flag(disabled: Option<bool>) -> Self { 43 + Self::from_disabled_flag(disabled.unwrap_or(false)) 53 44 } 54 45 } 55 46 ··· 62 53 Telegram, 63 54 Signal, 64 55 } 56 + 57 + impl CommsChannel { 58 + pub fn as_str(self) -> &'static str { 59 + match self { 60 + Self::Email => "email", 61 + Self::Discord => "discord", 62 + Self::Telegram => "telegram", 63 + Self::Signal => "signal", 64 + } 65 + } 66 + 67 + pub fn display_name(self) -> &'static str { 68 + match self { 69 + Self::Email => "email", 70 + Self::Discord => "Discord", 71 + Self::Telegram => "Telegram", 72 + Self::Signal => "Signal", 73 + } 74 + } 75 + } 76 + 77 + impl std::str::FromStr for CommsChannel { 78 + type Err = InvalidCommsChannel; 79 + 80 + fn from_str(s: &str) -> Result<Self, Self::Err> { 81 + match s { 82 + "email" => Ok(Self::Email), 83 + "discord" => Ok(Self::Discord), 84 + "telegram" => Ok(Self::Telegram), 85 + "signal" => Ok(Self::Signal), 86 + _ => Err(InvalidCommsChannel), 87 + } 88 + } 89 + } 90 + 91 + #[derive(Debug, Clone)] 92 + pub struct InvalidCommsChannel; 93 + 94 + impl std::fmt::Display for InvalidCommsChannel { 95 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 96 + f.write_str("invalid comms channel") 97 + } 98 + } 99 + 100 + impl std::error::Error for InvalidCommsChannel {} 65 101 66 102 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)] 67 103 #[sqlx(type_name = "comms_type", rename_all = "snake_case")] ··· 139 175 140 176 impl InviteCodeRow { 141 177 pub fn state(&self) -> InviteCodeState { 142 - InviteCodeState::from(self.disabled) 178 + InviteCodeState::from_optional_disabled_flag(self.disabled) 143 179 } 144 180 } 145 181
+2 -2
crates/tranquil-db-traits/src/lib.rs
··· 14 14 mod sso; 15 15 mod user; 16 16 17 - pub use backlink::{Backlink, BacklinkRepository}; 17 + pub use backlink::{Backlink, BacklinkPath, BacklinkRepository}; 18 18 pub use backup::{ 19 19 BackupForDeletion, BackupRepository, BackupRow, BackupStorageInfo, BlobExportInfo, 20 20 OldBackupInfo, UserBackupInfo, ··· 68 68 UserInfoForAuth, UserKeyInfo, UserKeyWithId, UserLegacyLoginPref, UserLoginCheck, 69 69 UserLoginFull, UserLoginInfo, UserPasswordInfo, UserRepository, UserResendVerification, 70 70 UserResetCodeInfo, UserRow, UserSessionInfo, UserStatus, UserVerificationInfo, UserWithKey, 71 - VerifiedTotpRecord, 71 + VerifiedTotpRecord, WebauthnChallengeType, 72 72 };
+7
crates/tranquil-db-traits/src/repo.rs
··· 57 57 } 58 58 } 59 59 60 + pub fn for_firehose_typed(&self) -> Option<Self> { 61 + match self { 62 + Self::Active => None, 63 + other => Some(*other), 64 + } 65 + } 66 + 60 67 pub fn parse(s: &str) -> Option<Self> { 61 68 match s.to_lowercase().as_str() { 62 69 "active" => Some(Self::Active),
+9 -23
crates/tranquil-db-traits/src/session.rs
··· 20 20 pub fn is_modern(self) -> bool { 21 21 matches!(self, Self::Modern) 22 22 } 23 - } 24 23 25 - impl From<bool> for LoginType { 26 - fn from(legacy: bool) -> Self { 27 - if legacy { Self::Legacy } else { Self::Modern } 28 - } 29 - } 30 - 31 - impl From<LoginType> for bool { 32 - fn from(lt: LoginType) -> Self { 33 - matches!(lt, LoginType::Legacy) 24 + pub fn from_legacy_flag(legacy: bool) -> Self { 25 + match legacy { 26 + true => Self::Legacy, 27 + false => Self::Modern, 28 + } 34 29 } 35 30 } 36 31 ··· 45 40 pub fn is_privileged(self) -> bool { 46 41 matches!(self, Self::Privileged) 47 42 } 48 - } 49 43 50 - impl From<bool> for AppPasswordPrivilege { 51 - fn from(privileged: bool) -> Self { 52 - if privileged { 53 - Self::Privileged 54 - } else { 55 - Self::Standard 44 + pub fn from_privileged_flag(privileged: bool) -> Self { 45 + match privileged { 46 + true => Self::Privileged, 47 + false => Self::Standard, 56 48 } 57 - } 58 - } 59 - 60 - impl From<AppPasswordPrivilege> for bool { 61 - fn from(p: AppPasswordPrivilege) -> Self { 62 - matches!(p, AppPasswordPrivilege::Privileged) 63 49 } 64 50 } 65 51
+41 -2
crates/tranquil-db-traits/src/sso.rs
··· 113 113 114 114 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 115 115 #[sqlx(type_name = "sso_provider_type", rename_all = "lowercase")] 116 + #[serde(rename_all = "lowercase")] 116 117 pub enum SsoProviderType { 117 118 Github, 118 119 Discord, ··· 141 142 } 142 143 143 144 pub fn parse(s: &str) -> Option<Self> { 144 - match s.to_lowercase().as_str() { 145 + match s { 145 146 "login" => Some(Self::Login), 146 147 "link" => Some(Self::Link), 147 148 "register" => Some(Self::Register), ··· 156 157 } 157 158 } 158 159 160 + impl std::str::FromStr for SsoProviderType { 161 + type Err = InvalidSsoProviderType; 162 + 163 + fn from_str(s: &str) -> Result<Self, Self::Err> { 164 + Self::parse(s).ok_or(InvalidSsoProviderType) 165 + } 166 + } 167 + 168 + #[derive(Debug, Clone)] 169 + pub struct InvalidSsoProviderType; 170 + 171 + impl std::fmt::Display for InvalidSsoProviderType { 172 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 173 + f.write_str("invalid SSO provider type") 174 + } 175 + } 176 + 177 + impl std::error::Error for InvalidSsoProviderType {} 178 + 179 + impl std::str::FromStr for SsoAction { 180 + type Err = InvalidSsoAction; 181 + 182 + fn from_str(s: &str) -> Result<Self, Self::Err> { 183 + Self::parse(s).ok_or(InvalidSsoAction) 184 + } 185 + } 186 + 187 + #[derive(Debug, Clone)] 188 + pub struct InvalidSsoAction; 189 + 190 + impl std::fmt::Display for InvalidSsoAction { 191 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 192 + f.write_str("invalid SSO action") 193 + } 194 + } 195 + 196 + impl std::error::Error for InvalidSsoAction {} 197 + 159 198 impl SsoProviderType { 160 199 pub fn as_str(&self) -> &'static str { 161 200 match self { ··· 169 208 } 170 209 171 210 pub fn parse(s: &str) -> Option<Self> { 172 - match s.to_lowercase().as_str() { 211 + match s { 173 212 "github" => Some(Self::Github), 174 213 "discord" => Some(Self::Discord), 175 214 "google" => Some(Self::Google),
+18 -3
crates/tranquil-db-traits/src/user.rs
··· 6 6 7 7 use crate::{ChannelVerificationStatus, CommsChannel, DbError, SsoProviderType}; 8 8 9 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 10 + pub enum WebauthnChallengeType { 11 + Registration, 12 + Authentication, 13 + } 14 + 15 + impl WebauthnChallengeType { 16 + pub fn as_str(self) -> &'static str { 17 + match self { 18 + Self::Registration => "registration", 19 + Self::Authentication => "authentication", 20 + } 21 + } 22 + } 23 + 9 24 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 10 25 #[sqlx(type_name = "account_type", rename_all = "snake_case")] 11 26 pub enum AccountType { ··· 325 340 async fn save_webauthn_challenge( 326 341 &self, 327 342 did: &Did, 328 - challenge_type: &str, 343 + challenge_type: WebauthnChallengeType, 329 344 state_json: &str, 330 345 ) -> Result<Uuid, DbError>; 331 346 332 347 async fn load_webauthn_challenge( 333 348 &self, 334 349 did: &Did, 335 - challenge_type: &str, 350 + challenge_type: WebauthnChallengeType, 336 351 ) -> Result<Option<String>, DbError>; 337 352 338 353 async fn delete_webauthn_challenge( 339 354 &self, 340 355 did: &Did, 341 - challenge_type: &str, 356 + challenge_type: WebauthnChallengeType, 342 357 ) -> Result<(), DbError>; 343 358 344 359 async fn get_totp_record(&self, did: &Did) -> Result<Option<TotpRecord>, DbError>;
+4 -4
crates/tranquil-db/src/postgres/infra.rs
··· 261 261 .map(|r| InviteCodeInfo { 262 262 code: r.code, 263 263 available_uses: r.available_uses, 264 - state: InviteCodeState::from(r.disabled), 264 + state: InviteCodeState::from_optional_disabled_flag(r.disabled), 265 265 for_account: Some(Did::from(r.for_account)), 266 266 created_at: r.created_at, 267 267 created_by: None, ··· 438 438 .map(|r| InviteCodeInfo { 439 439 code: r.code, 440 440 available_uses: r.available_uses, 441 - state: InviteCodeState::from(r.disabled), 441 + state: InviteCodeState::from_optional_disabled_flag(r.disabled), 442 442 for_account: Some(Did::from(r.for_account)), 443 443 created_at: r.created_at, 444 444 created_by: Some(Did::from(r.created_by)), ··· 461 461 Ok(result.map(|r| InviteCodeInfo { 462 462 code: r.code, 463 463 available_uses: r.available_uses, 464 - state: InviteCodeState::from(r.disabled), 464 + state: InviteCodeState::from_optional_disabled_flag(r.disabled), 465 465 for_account: Some(Did::from(r.for_account)), 466 466 created_at: r.created_at, 467 467 created_by: Some(Did::from(r.created_by)), ··· 492 492 InviteCodeInfo { 493 493 code: r.code, 494 494 available_uses: r.available_uses, 495 - state: InviteCodeState::from(r.disabled), 495 + state: InviteCodeState::from_optional_disabled_flag(r.disabled), 496 496 for_account: Some(Did::from(r.for_account)), 497 497 created_at: r.created_at, 498 498 created_by: Some(Did::from(r.created_by)),
+7 -7
crates/tranquil-db/src/postgres/session.rs
··· 37 37 data.refresh_jti, 38 38 data.access_expires_at, 39 39 data.refresh_expires_at, 40 - bool::from(data.login_type), 40 + data.login_type.is_legacy(), 41 41 data.mfa_verified, 42 42 data.scope, 43 43 data.controller_did.as_ref().map(|d| d.as_str()), ··· 75 75 refresh_jti: r.refresh_jti, 76 76 access_expires_at: r.access_expires_at, 77 77 refresh_expires_at: r.refresh_expires_at, 78 - login_type: LoginType::from(r.legacy_login), 78 + login_type: LoginType::from_legacy_flag(r.legacy_login), 79 79 mfa_verified: r.mfa_verified, 80 80 scope: r.scope, 81 81 controller_did: r.controller_did.map(Did::from), ··· 325 325 name: r.name, 326 326 password_hash: r.password_hash, 327 327 created_at: r.created_at, 328 - privilege: AppPasswordPrivilege::from(r.privileged), 328 + privilege: AppPasswordPrivilege::from_privileged_flag(r.privileged), 329 329 scopes: r.scopes, 330 330 created_by_controller_did: r.created_by_controller_did.map(Did::from), 331 331 }) ··· 358 358 name: r.name, 359 359 password_hash: r.password_hash, 360 360 created_at: r.created_at, 361 - privilege: AppPasswordPrivilege::from(r.privileged), 361 + privilege: AppPasswordPrivilege::from_privileged_flag(r.privileged), 362 362 scopes: r.scopes, 363 363 created_by_controller_did: r.created_by_controller_did.map(Did::from), 364 364 }) ··· 389 389 name: r.name, 390 390 password_hash: r.password_hash, 391 391 created_at: r.created_at, 392 - privilege: AppPasswordPrivilege::from(r.privileged), 392 + privilege: AppPasswordPrivilege::from_privileged_flag(r.privileged), 393 393 scopes: r.scopes, 394 394 created_by_controller_did: r.created_by_controller_did.map(Did::from), 395 395 })) ··· 405 405 data.user_id, 406 406 data.name, 407 407 data.password_hash, 408 - bool::from(data.privilege), 408 + data.privilege.is_privileged(), 409 409 data.scopes, 410 410 data.created_by_controller_did.as_ref().map(|d| d.as_str()) 411 411 ) ··· 486 486 .map_err(map_sqlx_error)?; 487 487 488 488 Ok(row.map(|r| SessionMfaStatus { 489 - login_type: LoginType::from(r.legacy_login), 489 + login_type: LoginType::from_legacy_flag(r.legacy_login), 490 490 mfa_verified: r.mfa_verified, 491 491 last_reauth_at: r.last_reauth_at, 492 492 }))
+37 -26
crates/tranquil-db/src/postgres/sso.rs
··· 68 68 .await 69 69 .map_err(map_sqlx_error)?; 70 70 71 - Ok(row.map(|r| ExternalIdentity { 72 - id: r.id, 73 - did: unsafe { Did::new_unchecked(&r.did) }, 74 - provider: r.provider, 75 - provider_user_id: ExternalUserId::from(r.provider_user_id), 76 - provider_username: r.provider_username.map(ExternalUsername::from), 77 - provider_email: r.provider_email.map(ExternalEmail::from), 78 - created_at: r.created_at, 79 - updated_at: r.updated_at, 80 - last_login_at: r.last_login_at, 81 - })) 71 + row.map(|r| { 72 + Ok(ExternalIdentity { 73 + id: r.id, 74 + did: r.did.parse().map_err(|_| DbError::CorruptData("DID"))?, 75 + provider: r.provider, 76 + provider_user_id: ExternalUserId::from(r.provider_user_id), 77 + provider_username: r.provider_username.map(ExternalUsername::from), 78 + provider_email: r.provider_email.map(ExternalEmail::from), 79 + created_at: r.created_at, 80 + updated_at: r.updated_at, 81 + last_login_at: r.last_login_at, 82 + }) 83 + }) 84 + .transpose() 82 85 } 83 86 84 87 async fn get_external_identities_by_did( ··· 99 102 .await 100 103 .map_err(map_sqlx_error)?; 101 104 102 - Ok(rows 103 - .into_iter() 104 - .map(|r| ExternalIdentity { 105 - id: r.id, 106 - did: unsafe { Did::new_unchecked(&r.did) }, 107 - provider: r.provider, 108 - provider_user_id: ExternalUserId::from(r.provider_user_id), 109 - provider_username: r.provider_username.map(ExternalUsername::from), 110 - provider_email: r.provider_email.map(ExternalEmail::from), 111 - created_at: r.created_at, 112 - updated_at: r.updated_at, 113 - last_login_at: r.last_login_at, 105 + rows.into_iter() 106 + .map(|r| { 107 + Ok(ExternalIdentity { 108 + id: r.id, 109 + did: r.did.parse().map_err(|_| DbError::CorruptData("DID"))?, 110 + provider: r.provider, 111 + provider_user_id: ExternalUserId::from(r.provider_user_id), 112 + provider_username: r.provider_username.map(ExternalUsername::from), 113 + provider_email: r.provider_email.map(ExternalEmail::from), 114 + created_at: r.created_at, 115 + updated_at: r.updated_at, 116 + last_login_at: r.last_login_at, 117 + }) 114 118 }) 115 - .collect()) 119 + .collect() 116 120 } 117 121 118 122 async fn update_external_identity_login( ··· 202 206 .map_err(map_sqlx_error)?; 203 207 204 208 row.map(|r| { 205 - let action = SsoAction::parse(&r.action).ok_or(DbError::NotFound)?; 209 + let action: SsoAction = r 210 + .action 211 + .parse() 212 + .map_err(|_| DbError::CorruptData("sso_action"))?; 206 213 Ok(SsoAuthState { 207 214 state: r.state, 208 215 request_uri: r.request_uri, ··· 210 217 action, 211 218 nonce: r.nonce, 212 219 code_verifier: r.code_verifier, 213 - did: r.did.map(|d| unsafe { Did::new_unchecked(&d) }), 220 + did: r 221 + .did 222 + .map(|d| d.parse::<Did>()) 223 + .transpose() 224 + .map_err(|_| DbError::CorruptData("DID"))?, 214 225 created_at: r.created_at, 215 226 expires_at: r.expires_at, 216 227 })
+14 -14
crates/tranquil-db/src/postgres/user.rs
··· 14 14 UserIdHandleEmail, UserInfoForAuth, UserKeyInfo, UserKeyWithId, UserLegacyLoginPref, 15 15 UserLoginCheck, UserLoginFull, UserLoginInfo, UserPasswordInfo, UserRepository, 16 16 UserResendVerification, UserResetCodeInfo, UserRow, UserSessionInfo, UserStatus, 17 - UserVerificationInfo, UserWithKey, 17 + UserVerificationInfo, UserWithKey, WebauthnChallengeType, 18 18 }; 19 19 20 20 pub struct PostgresUserRepository { ··· 281 281 password_hash: r.password_hash, 282 282 deactivated_at: r.deactivated_at, 283 283 takedown_ref: r.takedown_ref, 284 - channel_verification: ChannelVerificationStatus::new( 284 + channel_verification: ChannelVerificationStatus::from_db_row( 285 285 r.email_verified, 286 286 r.discord_verified, 287 287 r.telegram_verified, ··· 746 746 id: r.id, 747 747 handle: Handle::from(r.handle), 748 748 email: r.email, 749 - channel_verification: ChannelVerificationStatus::new( 749 + channel_verification: ChannelVerificationStatus::from_db_row( 750 750 r.email_verified, 751 751 r.discord_verified, 752 752 r.telegram_verified, ··· 1029 1029 async fn save_webauthn_challenge( 1030 1030 &self, 1031 1031 did: &Did, 1032 - challenge_type: &str, 1032 + challenge_type: WebauthnChallengeType, 1033 1033 state_json: &str, 1034 1034 ) -> Result<Uuid, DbError> { 1035 1035 let id = Uuid::new_v4(); ··· 1041 1041 id, 1042 1042 did.as_str(), 1043 1043 challenge, 1044 - challenge_type, 1044 + challenge_type.as_str(), 1045 1045 state_json, 1046 1046 expires_at, 1047 1047 ) ··· 1055 1055 async fn load_webauthn_challenge( 1056 1056 &self, 1057 1057 did: &Did, 1058 - challenge_type: &str, 1058 + challenge_type: WebauthnChallengeType, 1059 1059 ) -> Result<Option<String>, DbError> { 1060 1060 let row = sqlx::query_scalar!( 1061 1061 r#"SELECT state_json FROM webauthn_challenges 1062 1062 WHERE did = $1 AND challenge_type = $2 AND expires_at > NOW() 1063 1063 ORDER BY created_at DESC LIMIT 1"#, 1064 1064 did.as_str(), 1065 - challenge_type 1065 + challenge_type.as_str() 1066 1066 ) 1067 1067 .fetch_optional(&self.pool) 1068 1068 .await ··· 1074 1074 async fn delete_webauthn_challenge( 1075 1075 &self, 1076 1076 did: &Did, 1077 - challenge_type: &str, 1077 + challenge_type: WebauthnChallengeType, 1078 1078 ) -> Result<(), DbError> { 1079 1079 sqlx::query!( 1080 1080 "DELETE FROM webauthn_challenges WHERE did = $1 AND challenge_type = $2", 1081 1081 did.as_str(), 1082 - challenge_type 1082 + challenge_type.as_str() 1083 1083 ) 1084 1084 .execute(&self.pool) 1085 1085 .await ··· 1365 1365 preferred_comms_channel: row.preferred_comms_channel, 1366 1366 deactivated_at: row.deactivated_at, 1367 1367 takedown_ref: row.takedown_ref, 1368 - channel_verification: ChannelVerificationStatus::new( 1368 + channel_verification: ChannelVerificationStatus::from_db_row( 1369 1369 row.email_verified, 1370 1370 row.discord_verified, 1371 1371 row.telegram_verified, ··· 1395 1395 id: row.id, 1396 1396 two_factor_enabled: row.two_factor_enabled, 1397 1397 preferred_comms_channel: row.preferred_comms_channel, 1398 - channel_verification: ChannelVerificationStatus::new( 1398 + channel_verification: ChannelVerificationStatus::from_db_row( 1399 1399 row.email_verified, 1400 1400 row.discord_verified, 1401 1401 row.telegram_verified, ··· 1432 1432 takedown_ref: row.takedown_ref, 1433 1433 preferred_locale: row.preferred_locale, 1434 1434 preferred_comms_channel: row.preferred_comms_channel, 1435 - channel_verification: ChannelVerificationStatus::new( 1435 + channel_verification: ChannelVerificationStatus::from_db_row( 1436 1436 row.email_verified, 1437 1437 row.discord_verified, 1438 1438 row.telegram_verified, ··· 1525 1525 email: row.email, 1526 1526 deactivated_at: row.deactivated_at, 1527 1527 takedown_ref: row.takedown_ref, 1528 - channel_verification: ChannelVerificationStatus::new( 1528 + channel_verification: ChannelVerificationStatus::from_db_row( 1529 1529 row.email_verified, 1530 1530 row.discord_verified, 1531 1531 row.telegram_verified, ··· 1602 1602 discord_username: row.discord_username, 1603 1603 telegram_username: row.telegram_username, 1604 1604 signal_username: row.signal_username, 1605 - channel_verification: ChannelVerificationStatus::new( 1605 + channel_verification: ChannelVerificationStatus::from_db_row( 1606 1606 row.email_verified, 1607 1607 row.discord_verified, 1608 1608 row.telegram_verified,
+1 -1
crates/tranquil-pds/src/api/actor/preferences.rs
··· 21 21 let bday = NaiveDate::parse_from_str(birth_date, "%Y-%m-%d").ok()?; 22 22 let today = Utc::now().date_naive(); 23 23 let mut age = today.year() - bday.year(); 24 - let m = today.month() as i32 - bday.month() as i32; 24 + let m = i32::try_from(today.month()).unwrap_or(0) - i32::try_from(bday.month()).unwrap_or(0); 25 25 if m < 0 || (m == 0 && today.day() < bday.day()) { 26 26 age -= 1; 27 27 }
+4 -1
crates/tranquil-pds/src/api/admin/account/delete.rs
··· 48 48 did, e 49 49 ); 50 50 } 51 - let _ = state.cache.delete(&format!("handle:{}", handle)).await; 51 + let _ = state 52 + .cache 53 + .delete(&crate::cache_keys::handle_key(&handle)) 54 + .await; 52 55 Ok(EmptyResponse::ok().into_response()) 53 56 }
+2 -2
crates/tranquil-pds/src/api/admin/account/email.rs
··· 1 - use crate::api::error::{ApiError, AtpJson, DbResultExt}; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::Did; ··· 30 30 pub async fn send_email( 31 31 State(state): State<AppState>, 32 32 _auth: Auth<Admin>, 33 - AtpJson(input): AtpJson<SendEmailInput>, 33 + Json(input): Json<SendEmailInput>, 34 34 ) -> Result<Response, ApiError> { 35 35 let content = input.content.trim(); 36 36 if content.is_empty() {
+3 -2
crates/tranquil-pds/src/api/admin/account/search.rs
··· 67 67 .await 68 68 .log_db_err("in search_accounts")?; 69 69 70 - let has_more = rows.len() > limit as usize; 70 + let limit_usize = usize::try_from(limit).unwrap_or(0); 71 + let has_more = rows.len() > limit_usize; 71 72 let accounts: Vec<AccountView> = rows 72 73 .into_iter() 73 - .take(limit as usize) 74 + .take(limit_usize) 74 75 .map(|row| AccountView { 75 76 did: row.did.clone(), 76 77 handle: row.handle,
+9 -3
crates/tranquil-pds/src/api/admin/account/update.rs
··· 84 84 .ok() 85 85 .flatten() 86 86 .ok_or(ApiError::AccountNotFound)?; 87 - let handle_for_check = unsafe { Handle::new_unchecked(&handle) }; 87 + let handle_for_check: Handle = handle.parse().map_err(|_| ApiError::InvalidHandle(None))?; 88 88 if let Ok(true) = state 89 89 .user_repo 90 90 .check_handle_exists(&handle_for_check, user_id) ··· 100 100 Ok(0) => Err(ApiError::AccountNotFound), 101 101 Ok(_) => { 102 102 if let Some(old) = old_handle { 103 - let _ = state.cache.delete(&format!("handle:{}", old)).await; 103 + let _ = state 104 + .cache 105 + .delete(&crate::cache_keys::handle_key(&old)) 106 + .await; 104 107 } 105 - let _ = state.cache.delete(&format!("handle:{}", handle)).await; 108 + let _ = state 109 + .cache 110 + .delete(&crate::cache_keys::handle_key(&handle)) 111 + .await; 106 112 if let Err(e) = crate::api::repo::record::sequence_identity_event( 107 113 &state, 108 114 did,
+18 -9
crates/tranquil-pds/src/api/admin/config.rs
··· 3 3 use crate::state::AppState; 4 4 use axum::{Json, extract::State}; 5 5 use serde::{Deserialize, Serialize}; 6 - use tracing::error; 6 + use tracing::{error, warn}; 7 7 use tranquil_types::CidLink; 8 8 9 9 #[derive(Serialize)] ··· 187 187 }; 188 188 189 189 if let Some(old_cid_str) = should_delete_old { 190 - let old_cid = unsafe { CidLink::new_unchecked(old_cid_str) }; 191 - if let Ok(Some(storage_key)) = 192 - state.infra_repo.get_blob_storage_key_by_cid(&old_cid).await 193 - { 194 - if let Err(e) = state.blob_store.delete(&storage_key).await { 195 - error!("Failed to delete old logo blob from storage: {:?}", e); 190 + match CidLink::new(old_cid_str) { 191 + Ok(old_cid) => { 192 + if let Ok(Some(storage_key)) = 193 + state.infra_repo.get_blob_storage_key_by_cid(&old_cid).await 194 + { 195 + if let Err(e) = state.blob_store.delete(&storage_key).await { 196 + error!("Failed to delete old logo blob from storage: {:?}", e); 197 + } 198 + if let Err(e) = state.infra_repo.delete_blob_by_cid(&old_cid).await { 199 + error!("Failed to delete old logo blob record: {:?}", e); 200 + } 201 + } 196 202 } 197 - if let Err(e) = state.infra_repo.delete_blob_by_cid(&old_cid).await { 198 - error!("Failed to delete old logo blob record: {:?}", e); 203 + Err(e) => { 204 + warn!( 205 + "Old logo CID in database is invalid, skipping cleanup: {:?}", 206 + e 207 + ); 199 208 } 200 209 } 201 210 }
+1 -1
crates/tranquil-pds/src/api/admin/invite.rs
··· 144 144 }) 145 145 .collect(); 146 146 147 - let next_cursor = if codes_rows.len() == limit as usize { 147 + let next_cursor = if codes_rows.len() == usize::try_from(limit).unwrap_or(0) { 148 148 codes_rows.last().map(|r| r.code.clone()) 149 149 } else { 150 150 None
+8 -2
crates/tranquil-pds/src/api/admin/status.rs
··· 175 175 Some("com.atproto.admin.defs#repoRef") => { 176 176 let did_str = input.subject.get("did").and_then(|d| d.as_str()); 177 177 if let Some(did_str) = did_str { 178 - let did = unsafe { Did::new_unchecked(did_str) }; 178 + let did: Did = match did_str.parse() { 179 + Ok(d) => d, 180 + Err(_) => return Err(ApiError::InvalidDid("Invalid DID format".into())), 181 + }; 179 182 if let Some(takedown) = &input.takedown { 180 183 let takedown_ref = if takedown.applied { 181 184 takedown.r#ref.as_deref() ··· 230 233 } 231 234 } 232 235 if let Ok(Some(handle)) = state.user_repo.get_handle_by_did(&did).await { 233 - let _ = state.cache.delete(&format!("handle:{}", handle)).await; 236 + let _ = state 237 + .cache 238 + .delete(&crate::cache_keys::handle_key(&handle)) 239 + .await; 234 240 } 235 241 return Ok(( 236 242 StatusCode::OK,
+7 -8
crates/tranquil-pds/src/api/age_assurance.rs
··· 1 - use crate::auth::{extract_auth_token_from_header, validate_token_with_dpop}; 1 + use crate::auth::{AccountRequirement, extract_auth_token_from_header, validate_token_with_dpop}; 2 2 use crate::state::AppState; 3 3 use axum::{ 4 4 Json, 5 5 extract::State, 6 - http::{HeaderMap, StatusCode}, 6 + http::{HeaderMap, Method, StatusCode}, 7 7 response::{IntoResponse, Response}, 8 8 }; 9 9 use serde_json::json; ··· 33 33 } 34 34 35 35 async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option<String> { 36 - let auth_header = crate::util::get_header_str(headers, "Authorization"); 36 + let auth_header = crate::util::get_header_str(headers, http::header::AUTHORIZATION); 37 37 tracing::debug!(?auth_header, "age assurance: extracting token"); 38 38 39 39 let extracted = extract_auth_token_from_header(auth_header)?; 40 40 tracing::debug!("age assurance: got token, validating"); 41 41 42 - let dpop_proof = crate::util::get_header_str(headers, "DPoP"); 42 + let dpop_proof = crate::util::get_header_str(headers, crate::util::HEADER_DPOP); 43 43 let http_uri = "/"; 44 44 45 45 let auth_user = match validate_token_with_dpop( 46 46 state.user_repo.as_ref(), 47 47 state.oauth_repo.as_ref(), 48 48 &extracted.token, 49 - extracted.is_dpop, 49 + extracted.scheme, 50 50 dpop_proof, 51 - "GET", 51 + Method::GET.as_str(), 52 52 http_uri, 53 - false, 54 - false, 53 + AccountRequirement::Active, 55 54 ) 56 55 .await 57 56 {
+6 -5
crates/tranquil-pds/src/api/backup.rs
··· 4 4 use crate::scheduled::generate_full_backup; 5 5 use crate::state::AppState; 6 6 use crate::storage::{BackupStorage, backup_retention_count}; 7 + use anyhow::Context; 7 8 use axum::{ 8 9 Json, 9 10 extract::{Query, State}, ··· 213 214 }; 214 215 215 216 let block_count = crate::scheduled::count_car_blocks(&car_bytes); 216 - let size_bytes = car_bytes.len() as i64; 217 + let size_bytes = i64::try_from(car_bytes.len()).unwrap_or(i64::MAX); 217 218 218 219 let storage_key = match backup_storage 219 220 .put_backup(&user.did, &repo_rev, &car_bytes) ··· 292 293 backup_storage: &dyn BackupStorage, 293 294 user_id: uuid::Uuid, 294 295 retention_count: u32, 295 - ) -> Result<(), String> { 296 + ) -> anyhow::Result<()> { 296 297 let old_backups: Vec<OldBackupInfo> = backup_repo 297 - .get_old_backups(user_id, retention_count as i64) 298 + .get_old_backups(user_id, i64::from(retention_count)) 298 299 .await 299 - .map_err(|e| format!("DB error fetching old backups: {}", e))?; 300 + .context("DB error fetching old backups")?; 300 301 301 302 for backup in old_backups { 302 303 if let Err(e) = backup_storage.delete_backup(&backup.storage_key).await { ··· 311 312 backup_repo 312 313 .delete_backup(backup.id) 313 314 .await 314 - .map_err(|e| format!("Failed to delete old backup record: {}", e))?; 315 + .context("Failed to delete old backup record")?; 315 316 } 316 317 317 318 Ok(())
+12 -11
crates/tranquil-pds/src/api/delegation.rs
··· 7 7 }; 8 8 use crate::rate_limit::{AccountCreationLimit, RateLimited}; 9 9 use crate::state::AppState; 10 - use crate::types::{Did, Handle, Nsid, Rkey}; 10 + use crate::types::{Did, Handle}; 11 11 use crate::util::{pds_hostname, pds_hostname_without_port}; 12 12 use axum::{ 13 13 Json, ··· 164 164 .session_repo 165 165 .delete_app_passwords_by_controller(&auth.did, &input.controller_did) 166 166 .await 167 - .unwrap_or(0) as usize; 167 + .unwrap_or(0) 168 + .try_into() 169 + .unwrap_or(0usize); 168 170 169 171 let revoked_oauth_tokens = state 170 172 .oauth_repo ··· 473 475 Err(_) => return Ok(ApiError::InvalidInviteCode.into_response()), 474 476 } 475 477 } else { 476 - let invite_required = std::env::var("INVITE_CODE_REQUIRED") 477 - .map(|v| v == "true" || v == "1") 478 - .unwrap_or(false); 478 + let invite_required = crate::util::parse_env_bool("INVITE_CODE_REQUIRED"); 479 479 if invite_required { 480 480 return Ok(ApiError::InviteCodeRequired.into_response()); 481 481 } ··· 529 529 .into_response()); 530 530 } 531 531 532 - let did = unsafe { Did::new_unchecked(&genesis_result.did) }; 533 - let handle = unsafe { Handle::new_unchecked(&handle) }; 532 + let did: Did = genesis_result 533 + .did 534 + .parse() 535 + .map_err(|_| ApiError::InternalError(Some("PLC genesis returned invalid DID".into())))?; 536 + let handle: Handle = handle.parse().map_err(|_| ApiError::InvalidHandle(None))?; 534 537 info!(did = %did, handle = %handle, controller = %can_control.did(), "Created DID for delegated account"); 535 538 536 539 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { ··· 627 630 "$type": "app.bsky.actor.profile", 628 631 "displayName": handle 629 632 }); 630 - let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") }; 631 - let profile_rkey = unsafe { Rkey::new_unchecked("self") }; 632 633 if let Err(e) = crate::api::repo::record::create_record_internal( 633 634 &state, 634 635 &did, 635 - &profile_collection, 636 - &profile_rkey, 636 + &crate::types::PROFILE_COLLECTION, 637 + &crate::types::PROFILE_RKEY, 637 638 &profile_record, 638 639 ) 639 640 .await
+1 -1
crates/tranquil-pds/src/api/discord_webhook.rs
··· 183 183 state.user_repo.as_ref(), 184 184 state.infra_repo.as_ref(), 185 185 user_id, 186 - "discord", 186 + tranquil_db_traits::CommsChannel::Discord, 187 187 &discord_user_id, 188 188 pds_hostname(), 189 189 )
+13 -68
crates/tranquil-pds/src/api/error.rs
··· 1 1 use axum::{ 2 2 Json, 3 - extract::{FromRequest, Request, rejection::JsonRejection}, 4 - http::StatusCode, 3 + http::{HeaderValue, StatusCode}, 5 4 response::{IntoResponse, Response}, 6 5 }; 7 - use serde::{Serialize, de::DeserializeOwned}; 6 + use serde::Serialize; 8 7 use std::borrow::Cow; 9 8 10 9 #[derive(Debug, Serialize)] ··· 103 102 UpstreamTimeout, 104 103 UpstreamUnavailable(String), 105 104 UpstreamError { 106 - status: u16, 105 + status: StatusCode, 107 106 error: Option<String>, 108 107 message: Option<String>, 109 108 }, ··· 127 126 } 128 127 Self::ServiceUnavailable(_) | Self::BackupsDisabled => StatusCode::SERVICE_UNAVAILABLE, 129 128 Self::UpstreamTimeout => StatusCode::GATEWAY_TIMEOUT, 130 - Self::UpstreamError { status, .. } => { 131 - StatusCode::from_u16(*status).unwrap_or(StatusCode::BAD_GATEWAY) 132 - } 129 + Self::UpstreamError { status, .. } => *status, 133 130 Self::AuthenticationRequired 134 131 | Self::AuthenticationFailed(_) 135 132 | Self::AccountDeactivated ··· 451 448 _ => None, 452 449 } 453 450 } 454 - pub fn from_upstream_response(status: u16, body: &[u8]) -> Self { 451 + pub fn from_upstream_response(status: StatusCode, body: &[u8]) -> Self { 455 452 if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(body) { 456 453 let error = parsed 457 454 .get("error") ··· 485 482 match &self { 486 483 Self::ExpiredToken(_) => { 487 484 response.headers_mut().insert( 488 - "WWW-Authenticate", 489 - "Bearer error=\"invalid_token\", error_description=\"Token has expired\"" 490 - .parse() 491 - .unwrap(), 485 + http::header::WWW_AUTHENTICATE, 486 + HeaderValue::from_static( 487 + "Bearer error=\"invalid_token\", error_description=\"Token has expired\"", 488 + ), 492 489 ); 493 490 } 494 491 Self::OAuthExpiredToken(_) => { 495 492 response.headers_mut().insert( 496 - "WWW-Authenticate", 497 - "DPoP error=\"invalid_token\", error_description=\"Token has expired\"" 498 - .parse() 499 - .unwrap(), 493 + http::header::WWW_AUTHENTICATE, 494 + HeaderValue::from_static( 495 + "DPoP error=\"invalid_token\", error_description=\"Token has expired\"", 496 + ), 500 497 ); 501 498 } 502 499 _ => {} ··· 721 718 #[allow(clippy::result_large_err)] 722 719 pub fn parse_did_option(s: Option<&str>) -> Result<Option<tranquil_types::Did>, Response> { 723 720 s.map(parse_did).transpose() 724 - } 725 - 726 - pub struct AtpJson<T>(pub T); 727 - 728 - impl<T, S> FromRequest<S> for AtpJson<T> 729 - where 730 - T: DeserializeOwned, 731 - S: Send + Sync, 732 - { 733 - type Rejection = (StatusCode, Json<serde_json::Value>); 734 - 735 - async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> { 736 - match Json::<T>::from_request(req, state).await { 737 - Ok(Json(value)) => Ok(AtpJson(value)), 738 - Err(rejection) => { 739 - let message = extract_json_error_message(&rejection); 740 - Err(( 741 - StatusCode::BAD_REQUEST, 742 - Json(serde_json::json!({ 743 - "error": "InvalidRequest", 744 - "message": message 745 - })), 746 - )) 747 - } 748 - } 749 - } 750 - } 751 - 752 - fn extract_json_error_message(rejection: &JsonRejection) -> String { 753 - match rejection { 754 - JsonRejection::JsonDataError(e) => { 755 - let inner = e.body_text(); 756 - if inner.contains("missing field") { 757 - let field = inner 758 - .split("missing field `") 759 - .nth(1) 760 - .and_then(|s| s.split('`').next()) 761 - .unwrap_or("unknown"); 762 - format!("Missing required field: {}", field) 763 - } else if inner.contains("invalid type") { 764 - format!("Invalid field type: {}", inner) 765 - } else { 766 - inner 767 - } 768 - } 769 - JsonRejection::JsonSyntaxError(_) => "Invalid JSON syntax".to_string(), 770 - JsonRejection::MissingJsonContentType(_) => { 771 - "Content-Type must be application/json".to_string() 772 - } 773 - JsonRejection::BytesRejection(_) => "Failed to read request body".to_string(), 774 - _ => "Invalid request body".to_string(), 775 - } 776 721 } 777 722 778 723 pub trait DbResultExt<T> {
+56 -55
crates/tranquil-pds/src/api/identity/account.rs
··· 5 5 use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key}; 6 6 use crate::rate_limit::{AccountCreationLimit, RateLimited}; 7 7 use crate::state::AppState; 8 - use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey}; 8 + use crate::types::{Did, Handle, PlainPassword}; 9 9 use crate::util::{pds_hostname, pds_hostname_without_port}; 10 10 use crate::validation::validate_password; 11 11 use axum::{ ··· 34 34 pub did: Option<String>, 35 35 pub did_type: Option<String>, 36 36 pub signing_key: Option<String>, 37 - pub verification_channel: Option<String>, 37 + pub verification_channel: Option<tranquil_db_traits::CommsChannel>, 38 38 pub discord_username: Option<String>, 39 39 pub telegram_username: Option<String>, 40 40 pub signal_username: Option<String>, ··· 50 50 pub access_jwt: String, 51 51 pub refresh_jwt: String, 52 52 pub verification_required: bool, 53 - pub verification_channel: String, 53 + pub verification_channel: tranquil_db_traits::CommsChannel, 54 54 } 55 55 56 56 pub async fn create_account( ··· 73 73 info!("create_account called"); 74 74 } 75 75 76 - let migration_auth = if let Some(extracted) = 77 - extract_auth_token_from_header(crate::util::get_header_str(&headers, "Authorization")) 78 - { 76 + let migration_auth = if let Some(extracted) = extract_auth_token_from_header( 77 + crate::util::get_header_str(&headers, http::header::AUTHORIZATION), 78 + ) { 79 79 let token = extracted.token; 80 80 if is_service_token(&token) { 81 81 let verifier = ServiceTokenVerifier::new(); ··· 190 190 { 191 191 return ApiError::InvalidEmail.into_response(); 192 192 } 193 - let verification_channel = input.verification_channel.as_deref().unwrap_or("email"); 194 - let valid_channels = ["email", "discord", "telegram", "signal"]; 195 - if !valid_channels.contains(&verification_channel) && !is_migration { 196 - return ApiError::InvalidVerificationChannel.into_response(); 197 - } 193 + let verification_channel = input 194 + .verification_channel 195 + .unwrap_or(tranquil_db_traits::CommsChannel::Email); 198 196 let verification_recipient = if is_migration { 199 197 None 200 198 } else { 201 199 Some(match verification_channel { 202 - "email" => match &input.email { 200 + tranquil_db_traits::CommsChannel::Email => match &input.email { 203 201 Some(email) if !email.trim().is_empty() => email.trim().to_string(), 204 202 _ => return ApiError::MissingEmail.into_response(), 205 203 }, 206 - "discord" => match &input.discord_username { 204 + tranquil_db_traits::CommsChannel::Discord => match &input.discord_username { 207 205 Some(username) if !username.trim().is_empty() => { 208 206 let clean = username.trim().to_lowercase(); 209 207 if !crate::api::validation::is_valid_discord_username(&clean) { ··· 215 213 } 216 214 _ => return ApiError::MissingDiscordId.into_response(), 217 215 }, 218 - "telegram" => match &input.telegram_username { 216 + tranquil_db_traits::CommsChannel::Telegram => match &input.telegram_username { 219 217 Some(username) if !username.trim().is_empty() => { 220 218 let clean = username.trim().trim_start_matches('@'); 221 219 if !crate::api::validation::is_valid_telegram_username(clean) { ··· 227 225 } 228 226 _ => return ApiError::MissingTelegramUsername.into_response(), 229 227 }, 230 - "signal" => match &input.signal_username { 228 + tranquil_db_traits::CommsChannel::Signal => match &input.signal_username { 231 229 Some(username) if !username.trim().is_empty() => { 232 230 username.trim().trim_start_matches('@').to_lowercase() 233 231 } 234 232 _ => return ApiError::MissingSignalNumber.into_response(), 235 233 }, 236 - _ => return ApiError::InvalidVerificationChannel.into_response(), 237 234 }) 238 235 }; 239 236 let hostname = pds_hostname(); ··· 304 301 && let Err(e) = 305 302 verify_did_web(d, hostname, &input.handle, input.signing_key.as_deref()).await 306 303 { 307 - return ApiError::InvalidDid(e).into_response(); 304 + return ApiError::InvalidDid(e.to_string()).into_response(); 308 305 } 309 306 info!(did = %d, "Creating external did:web account"); 310 307 d.clone() ··· 320 317 verify_did_web(d, hostname, &input.handle, input.signing_key.as_deref()) 321 318 .await 322 319 { 323 - return ApiError::InvalidDid(e).into_response(); 320 + return ApiError::InvalidDid(e.to_string()).into_response(); 324 321 } 325 322 d.clone() 326 323 } else if !d.trim().is_empty() { ··· 397 394 } 398 395 }; 399 396 if is_migration { 397 + let did_typed: Did = match did.parse() { 398 + Ok(d) => d, 399 + Err(_) => return ApiError::InternalError(Some("Invalid DID".into())).into_response(), 400 + }; 401 + let handle_typed: Handle = match handle.parse() { 402 + Ok(h) => h, 403 + Err(_) => return ApiError::InvalidHandle(None).into_response(), 404 + }; 400 405 let reactivate_input = tranquil_db_traits::MigrationReactivationInput { 401 - did: unsafe { Did::new_unchecked(&did) }, 402 - new_handle: unsafe { Handle::new_unchecked(&handle) }, 406 + did: did_typed.clone(), 407 + new_handle: handle_typed.clone(), 403 408 new_email: email.clone(), 404 409 }; 405 410 match state ··· 453 458 } 454 459 }; 455 460 let session_data = tranquil_db_traits::SessionTokenCreate { 456 - did: unsafe { Did::new_unchecked(&did) }, 461 + did: did_typed.clone(), 457 462 access_jti: access_meta.jti.clone(), 458 463 refresh_jti: refresh_meta.jti.clone(), 459 464 access_expires_at: access_meta.expires_at, ··· 470 475 } 471 476 let hostname = pds_hostname(); 472 477 let verification_required = if let Some(ref user_email) = email { 473 - let token = 474 - crate::auth::verification_token::generate_migration_token(&did, user_email); 478 + let token = crate::auth::verification_token::generate_migration_token( 479 + &did_typed, user_email, 480 + ); 475 481 let formatted_token = 476 482 crate::auth::verification_token::format_token_for_display(&token); 477 483 if let Err(e) = crate::comms::comms_repo::enqueue_migration_verification( ··· 494 500 axum::http::StatusCode::OK, 495 501 Json(CreateAccountOutput { 496 502 handle: handle.clone().into(), 497 - did: unsafe { Did::new_unchecked(&did) }, 503 + did: did_typed.clone(), 498 504 did_doc: state.did_resolver.resolve_did_document(&did).await, 499 505 access_jwt: access_meta.token, 500 506 refresh_jwt: refresh_meta.token, 501 507 verification_required, 502 - verification_channel: "email".to_string(), 508 + verification_channel: tranquil_db_traits::CommsChannel::Email, 503 509 }), 504 510 ) 505 511 .into_response(); ··· 518 524 } 519 525 } 520 526 521 - let handle_typed = unsafe { Handle::new_unchecked(&handle) }; 527 + let handle_typed: Handle = match handle.parse() { 528 + Ok(h) => h, 529 + Err(_) => return ApiError::InvalidHandle(None).into_response(), 530 + }; 522 531 let handle_available = match state 523 532 .user_repo 524 533 .check_handle_available_for_new_account(&handle_typed) ··· 534 543 return ApiError::HandleTaken.into_response(); 535 544 } 536 545 537 - let invite_code_required = std::env::var("INVITE_CODE_REQUIRED") 538 - .map(|v| v == "true" || v == "1") 539 - .unwrap_or(false); 546 + let invite_code_required = crate::util::parse_env_bool("INVITE_CODE_REQUIRED"); 540 547 if invite_code_required 541 548 && input 542 549 .invite_code ··· 602 609 } 603 610 }; 604 611 let rev = Tid::now(LimitedU32::MIN); 605 - let did_for_commit = unsafe { Did::new_unchecked(&did) }; 612 + let did_for_commit: Did = match did.parse() { 613 + Ok(d) => d, 614 + Err(_) => return ApiError::InternalError(Some("Invalid DID".into())).into_response(), 615 + }; 606 616 let (commit_bytes, _sig) = 607 617 match create_signed_commit(&did_for_commit, mst_root, rev.as_ref(), None, &signing_key) { 608 618 Ok(result) => result, ··· 629 639 }) 630 640 }); 631 641 632 - let preferred_comms_channel = match verification_channel { 633 - "email" => tranquil_db_traits::CommsChannel::Email, 634 - "discord" => tranquil_db_traits::CommsChannel::Discord, 635 - "telegram" => tranquil_db_traits::CommsChannel::Telegram, 636 - "signal" => tranquil_db_traits::CommsChannel::Signal, 637 - _ => tranquil_db_traits::CommsChannel::Email, 638 - }; 642 + let preferred_comms_channel = verification_channel; 639 643 640 644 let create_input = tranquil_db_traits::CreatePasswordAccountInput { 641 - handle: unsafe { Handle::new_unchecked(&handle) }, 645 + handle: handle_typed.clone(), 642 646 email: email.clone(), 643 - did: unsafe { Did::new_unchecked(&did) }, 647 + did: did_for_commit.clone(), 644 648 password_hash, 645 649 preferred_comms_channel, 646 650 discord_username: input ··· 689 693 }; 690 694 let user_id = create_result.user_id; 691 695 if !is_migration && !is_did_web_byod { 692 - let did_typed = unsafe { Did::new_unchecked(&did) }; 693 - let handle_typed = unsafe { Handle::new_unchecked(&handle) }; 694 696 if let Err(e) = crate::api::repo::record::sequence_identity_event( 695 697 &state, 696 - &did_typed, 698 + &did_for_commit, 697 699 Some(&handle_typed), 698 700 ) 699 701 .await ··· 702 704 } 703 705 if let Err(e) = crate::api::repo::record::sequence_account_event( 704 706 &state, 705 - &did_typed, 707 + &did_for_commit, 706 708 tranquil_db_traits::AccountStatus::Active, 707 709 ) 708 710 .await ··· 711 713 } 712 714 if let Err(e) = crate::api::repo::record::sequence_genesis_commit( 713 715 &state, 714 - &did_typed, 716 + &did_for_commit, 715 717 &commit_cid, 716 718 &mst_root, 717 719 &rev_str, ··· 722 724 } 723 725 if let Err(e) = crate::api::repo::record::sequence_sync_event( 724 726 &state, 725 - &did_typed, 727 + &did_for_commit, 726 728 &commit_cid_str, 727 729 Some(rev.as_ref()), 728 730 ) ··· 734 736 "$type": "app.bsky.actor.profile", 735 737 "displayName": input.handle 736 738 }); 737 - let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") }; 738 - let profile_rkey = unsafe { Rkey::new_unchecked("self") }; 739 739 if let Err(e) = crate::api::repo::record::create_record_internal( 740 740 &state, 741 - &did_typed, 742 - &profile_collection, 743 - &profile_rkey, 741 + &did_for_commit, 742 + &crate::types::PROFILE_COLLECTION, 743 + &crate::types::PROFILE_RKEY, 744 744 &profile_record, 745 745 ) 746 746 .await ··· 752 752 if !is_migration { 753 753 if let Some(ref recipient) = verification_recipient { 754 754 let verification_token = crate::auth::verification_token::generate_signup_token( 755 - &did, 755 + &did_for_commit, 756 756 verification_channel, 757 757 recipient, 758 758 ); ··· 776 776 } 777 777 } 778 778 } else if let Some(ref user_email) = email { 779 - let token = crate::auth::verification_token::generate_migration_token(&did, user_email); 779 + let token = 780 + crate::auth::verification_token::generate_migration_token(&did_for_commit, user_email); 780 781 let formatted_token = crate::auth::verification_token::format_token_for_display(&token); 781 782 if let Err(e) = crate::comms::comms_repo::enqueue_migration_verification( 782 783 state.user_repo.as_ref(), ··· 809 810 } 810 811 }; 811 812 let session_data = tranquil_db_traits::SessionTokenCreate { 812 - did: unsafe { Did::new_unchecked(&did) }, 813 + did: did_for_commit.clone(), 813 814 access_jti: access_meta.jti.clone(), 814 815 refresh_jti: refresh_meta.jti.clone(), 815 816 access_expires_at: access_meta.expires_at, ··· 838 839 StatusCode::OK, 839 840 Json(CreateAccountOutput { 840 841 handle: handle.clone().into(), 841 - did: unsafe { Did::new_unchecked(&did) }, 842 + did: did_for_commit, 842 843 did_doc, 843 844 access_jwt: access_meta.token, 844 845 refresh_jwt: refresh_meta.token, 845 846 verification_required: !is_migration, 846 - verification_channel: verification_channel.to_string(), 847 + verification_channel, 847 848 }), 848 849 ) 849 850 .into_response()
+100 -52
crates/tranquil-pds/src/api/identity/did.rs
··· 42 42 if handle_str.is_empty() { 43 43 return ApiError::InvalidRequest("handle is required".into()).into_response(); 44 44 } 45 - let cache_key = format!("handle:{}", handle_str); 45 + let cache_key = crate::cache_keys::handle_key(handle_str); 46 46 if let Some(did) = state.cache.get(&cache_key).await { 47 47 return DidResponse::response(did).into_response(); 48 48 } ··· 78 78 } 79 79 } 80 80 81 - pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 82 - let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 81 + #[derive(Debug)] 82 + pub enum KeyError { 83 + InvalidKeyLength, 84 + MissingCoordinate, 85 + } 86 + 87 + impl std::fmt::Display for KeyError { 88 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 89 + match self { 90 + Self::InvalidKeyLength => write!(f, "invalid key length"), 91 + Self::MissingCoordinate => write!(f, "missing elliptic curve coordinate"), 92 + } 93 + } 94 + } 95 + 96 + impl std::error::Error for KeyError {} 97 + 98 + pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, KeyError> { 99 + let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| KeyError::InvalidKeyLength)?; 83 100 let public_key = secret_key.public_key(); 84 101 let encoded = public_key.to_encoded_point(false); 85 - let x = encoded.x().ok_or("Missing x coordinate")?; 86 - let y = encoded.y().ok_or("Missing y coordinate")?; 102 + let x = encoded.x().ok_or(KeyError::MissingCoordinate)?; 103 + let y = encoded.y().ok_or(KeyError::MissingCoordinate)?; 87 104 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 88 105 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 89 106 Ok(json!({ ··· 94 111 })) 95 112 } 96 113 97 - pub fn get_public_key_multibase(key_bytes: &[u8]) -> Result<String, &'static str> { 98 - let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 114 + pub fn get_public_key_multibase(key_bytes: &[u8]) -> Result<String, KeyError> { 115 + let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| KeyError::InvalidKeyLength)?; 99 116 let public_key = secret_key.public_key(); 100 117 let compressed = public_key.to_encoded_point(true); 101 118 let compressed_bytes = compressed.as_bytes(); ··· 107 124 pub async fn well_known_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 108 125 let hostname = pds_hostname(); 109 126 let hostname_without_port = pds_hostname_without_port(); 110 - let host_header = get_header_str(&headers, "host").unwrap_or(hostname); 127 + let host_header = get_header_str(&headers, http::header::HOST).unwrap_or(hostname); 111 128 let host_without_port = host_header.split(':').next().unwrap_or(host_header); 112 129 if host_without_port != hostname_without_port 113 130 && host_without_port.ends_with(&format!(".{}", hostname_without_port)) ··· 127 144 "id": did, 128 145 "service": [{ 129 146 "id": "#atproto_pds", 130 - "type": "AtprotoPersonalDataServer", 147 + "type": crate::plc::ServiceType::Pds.as_str(), 131 148 "serviceEndpoint": format!("https://{}", hostname) 132 149 }] 133 150 })) ··· 197 214 })).collect::<Vec<_>>(), 198 215 "service": [{ 199 216 "id": "#atproto_pds", 200 - "type": "AtprotoPersonalDataServer", 217 + "type": crate::plc::ServiceType::Pds.as_str(), 201 218 "serviceEndpoint": service_endpoint 202 219 }] 203 220 })) ··· 250 267 }], 251 268 "service": [{ 252 269 "id": "#atproto_pds", 253 - "type": "AtprotoPersonalDataServer", 270 + "type": crate::plc::ServiceType::Pds.as_str(), 254 271 "serviceEndpoint": service_endpoint 255 272 }] 256 273 })) ··· 332 349 })).collect::<Vec<_>>(), 333 350 "service": [{ 334 351 "id": "#atproto_pds", 335 - "type": "AtprotoPersonalDataServer", 352 + "type": crate::plc::ServiceType::Pds.as_str(), 336 353 "serviceEndpoint": service_endpoint 337 354 }] 338 355 })) ··· 385 402 }], 386 403 "service": [{ 387 404 "id": "#atproto_pds", 388 - "type": "AtprotoPersonalDataServer", 405 + "type": crate::plc::ServiceType::Pds.as_str(), 389 406 "serviceEndpoint": service_endpoint 390 407 }] 391 408 })) 392 409 .into_response() 393 410 } 394 411 412 + #[derive(Debug, thiserror::Error)] 413 + pub enum DidWebVerifyError { 414 + #[error("Invalid did:web format")] 415 + InvalidFormat, 416 + #[error("Invalid DID path for this PDS. Expected {0}")] 417 + InvalidPath(String), 418 + #[error( 419 + "External did:web requires a pre-reserved signing key. Call com.atproto.server.reserveSigningKey first, configure your DID document with the returned key, then provide the signingKey in createAccount." 420 + )] 421 + MissingSigningKey, 422 + #[error("Failed to fetch DID doc: {0}")] 423 + FetchFailed(String), 424 + #[error("Invalid DID document: {0}")] 425 + InvalidDocument(String), 426 + #[error("DID document does not list this PDS ({0}) as AtprotoPersonalDataServer")] 427 + PdsNotListed(String), 428 + #[error( 429 + "DID document verification key does not match reserved signing key. Expected publicKeyMultibase: {0}" 430 + )] 431 + KeyMismatch(String), 432 + #[error("Invalid signing key format")] 433 + InvalidSigningKey, 434 + } 435 + 395 436 pub async fn verify_did_web( 396 437 did: &str, 397 438 hostname: &str, 398 439 handle: &str, 399 440 expected_signing_key: Option<&str>, 400 - ) -> Result<(), String> { 441 + ) -> Result<(), DidWebVerifyError> { 401 442 let hostname_for_handles = hostname.split(':').next().unwrap_or(hostname); 402 443 let subdomain_host = format!("{}.{}", handle, hostname_for_handles); 403 444 let encoded_subdomain = subdomain_host.replace(':', "%3A"); ··· 413 454 if did.starts_with(&expected_prefix) { 414 455 let suffix = &did[expected_prefix.len()..]; 415 456 let expected_suffix = format!(":u:{}", handle); 416 - if suffix == expected_suffix { 417 - return Ok(()); 457 + return if suffix == expected_suffix { 458 + Ok(()) 418 459 } else { 419 - return Err(format!( 420 - "Invalid DID path for this PDS. Expected {}", 421 - expected_suffix 422 - )); 423 - } 460 + Err(DidWebVerifyError::InvalidPath(expected_suffix)) 461 + }; 424 462 } 425 - let expected_signing_key = expected_signing_key.ok_or_else(|| { 426 - "External did:web requires a pre-reserved signing key. Call com.atproto.server.reserveSigningKey first, configure your DID document with the returned key, then provide the signingKey in createAccount.".to_string() 427 - })?; 463 + let expected_signing_key = expected_signing_key.ok_or(DidWebVerifyError::MissingSigningKey)?; 428 464 let parts: Vec<&str> = did.split(':').collect(); 429 465 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 430 - return Err("Invalid did:web format".into()); 466 + return Err(DidWebVerifyError::InvalidFormat); 431 467 } 432 468 let domain_segment = parts[2]; 433 469 let domain = domain_segment.replace("%3A", ":"); ··· 447 483 .get(&url) 448 484 .send() 449 485 .await 450 - .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 486 + .map_err(|e| DidWebVerifyError::FetchFailed(e.to_string()))?; 451 487 if !resp.status().is_success() { 452 - return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 488 + return Err(DidWebVerifyError::FetchFailed(format!( 489 + "HTTP {}", 490 + resp.status() 491 + ))); 453 492 } 454 493 let doc: serde_json::Value = resp 455 494 .json() 456 495 .await 457 - .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 496 + .map_err(|e| DidWebVerifyError::InvalidDocument(e.to_string()))?; 458 497 let services = doc["service"] 459 498 .as_array() 460 - .ok_or("No services found in DID doc")?; 499 + .ok_or(DidWebVerifyError::InvalidDocument( 500 + "No services found".to_string(), 501 + ))?; 461 502 let pds_endpoint = format!("https://{}", hostname); 462 - let has_valid_service = services 463 - .iter() 464 - .any(|s| s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint); 503 + let has_valid_service = services.iter().any(|s| { 504 + s["type"] == crate::plc::ServiceType::Pds.as_str() && s["serviceEndpoint"] == pds_endpoint 505 + }); 465 506 if !has_valid_service { 466 - return Err(format!( 467 - "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 468 - pds_endpoint 469 - )); 507 + return Err(DidWebVerifyError::PdsNotListed(pds_endpoint)); 470 508 } 471 - let verification_methods = doc["verificationMethod"] 472 - .as_array() 473 - .ok_or("No verificationMethod found in DID doc")?; 509 + let verification_methods = 510 + doc["verificationMethod"] 511 + .as_array() 512 + .ok_or(DidWebVerifyError::InvalidDocument( 513 + "No verificationMethod found".to_string(), 514 + ))?; 474 515 let expected_multibase = expected_signing_key 475 516 .strip_prefix("did:key:") 476 - .ok_or("Invalid signing key format")?; 517 + .ok_or(DidWebVerifyError::InvalidSigningKey)?; 477 518 let has_matching_key = verification_methods.iter().any(|vm| { 478 519 vm["publicKeyMultibase"] 479 520 .as_str() 480 - .map(|pk| pk == expected_multibase) 481 - .unwrap_or(false) 521 + .is_some_and(|pk| pk == expected_multibase) 482 522 }); 483 523 if !has_matching_key { 484 - return Err(format!( 485 - "DID document verification key does not match reserved signing key. Expected publicKeyMultibase: {}", 486 - expected_multibase 524 + return Err(DidWebVerifyError::KeyMismatch( 525 + expected_multibase.to_string(), 487 526 )); 488 527 } 489 528 Ok(()) ··· 559 598 verification_methods: VerificationMethods { atproto: did_key }, 560 599 services: Services { 561 600 atproto_pds: AtprotoPds { 562 - service_type: "AtprotoPersonalDataServer".to_string(), 601 + service_type: crate::plc::ServiceType::Pds.as_str().to_string(), 563 602 endpoint: pds_endpoint, 564 603 }, 565 604 }, ··· 579 618 Json(input): Json<UpdateHandleInput>, 580 619 ) -> Result<Response, ApiError> { 581 620 if let Err(e) = crate::auth::scope_check::check_identity_scope( 582 - auth.is_oauth(), 621 + &auth.auth_source, 583 622 auth.scope.as_deref(), 584 623 crate::oauth::scopes::IdentityAttr::Handle, 585 624 ) { ··· 652 691 format!("{}.{}", new_handle, hostname_for_handles) 653 692 }; 654 693 if full_handle == current_handle { 655 - let handle_typed = unsafe { Handle::new_unchecked(&full_handle) }; 694 + let handle_typed: Handle = match full_handle.parse() { 695 + Ok(h) => h, 696 + Err(_) => return Err(ApiError::InvalidHandle(None)), 697 + }; 656 698 if let Err(e) = 657 699 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)) 658 700 .await ··· 675 717 full_handle 676 718 } else { 677 719 if new_handle == current_handle { 678 - let handle_typed = unsafe { Handle::new_unchecked(&new_handle) }; 720 + let handle_typed: Handle = match new_handle.parse() { 721 + Ok(h) => h, 722 + Err(_) => return Err(ApiError::InvalidHandle(None)), 723 + }; 679 724 if let Err(e) = 680 725 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)) 681 726 .await ··· 728 773 if !current_handle.is_empty() { 729 774 let _ = state 730 775 .cache 731 - .delete(&format!("handle:{}", current_handle)) 776 + .delete(&crate::cache_keys::handle_key(&current_handle)) 732 777 .await; 733 778 } 734 - let _ = state.cache.delete(&format!("handle:{}", handle)).await; 779 + let _ = state 780 + .cache 781 + .delete(&crate::cache_keys::handle_key(&handle)) 782 + .await; 735 783 if let Err(e) = 736 784 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)).await 737 785 { ··· 768 816 } 769 817 770 818 pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 771 - let host = match crate::util::get_header_str(&headers, "host") { 819 + let host = match crate::util::get_header_str(&headers, http::header::HOST) { 772 820 Some(h) => h, 773 821 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 774 822 };
+2 -10
crates/tranquil-pds/src/api/identity/handle.rs
··· 1 - use crate::api::error::ApiError; 2 1 use crate::rate_limit::{HandleVerificationLimit, RateLimited}; 3 2 use crate::types::{Did, Handle}; 4 3 use axum::{ ··· 9 8 10 9 #[derive(Deserialize)] 11 10 pub struct VerifyHandleOwnershipInput { 12 - pub handle: String, 11 + pub handle: Handle, 13 12 pub did: Did, 14 13 } 15 14 ··· 27 26 _rate_limit: RateLimited<HandleVerificationLimit>, 28 27 Json(input): Json<VerifyHandleOwnershipInput>, 29 28 ) -> Response { 30 - let handle: Handle = match input.handle.parse() { 31 - Ok(h) => h, 32 - Err(_) => { 33 - return ApiError::InvalidHandle(Some("Invalid handle format".into())).into_response(); 34 - } 35 - }; 36 - 37 - let handle_str = handle.as_str(); 29 + let handle_str = input.handle.as_str(); 38 30 let did_str = input.did.as_str(); 39 31 40 32 let dns_mismatch = match crate::handle::resolve_handle_dns(handle_str).await {
+1 -1
crates/tranquil-pds/src/api/identity/plc/request.rs
··· 19 19 auth: Auth<Permissive>, 20 20 ) -> Result<Response, ApiError> { 21 21 if let Err(e) = crate::auth::scope_check::check_identity_scope( 22 - auth.is_oauth(), 22 + &auth.auth_source, 23 23 auth.scope.as_deref(), 24 24 crate::oauth::scopes::IdentityAttr::Wildcard, 25 25 ) {
+3 -3
crates/tranquil-pds/src/api/identity/plc/sign.rs
··· 2 2 use crate::api::error::DbResultExt; 3 3 use crate::auth::{Auth, Permissive}; 4 4 use crate::circuit_breaker::with_circuit_breaker; 5 - use crate::plc::{PlcClient, PlcError, PlcService, create_update_op, sign_operation}; 5 + use crate::plc::{PlcClient, PlcError, PlcService, ServiceType, create_update_op, sign_operation}; 6 6 use crate::state::AppState; 7 7 use axum::{ 8 8 Json, ··· 30 30 #[derive(Debug, Deserialize, Clone)] 31 31 pub struct ServiceInput { 32 32 #[serde(rename = "type")] 33 - pub service_type: String, 33 + pub service_type: ServiceType, 34 34 pub endpoint: String, 35 35 } 36 36 ··· 45 45 Json(input): Json<SignPlcOperationInput>, 46 46 ) -> Result<Response, ApiError> { 47 47 if let Err(e) = crate::auth::scope_check::check_identity_scope( 48 - auth.is_oauth(), 48 + &auth.auth_source, 49 49 auth.scope.as_deref(), 50 50 crate::oauth::scopes::IdentityAttr::Wildcard, 51 51 ) {
+14 -5
crates/tranquil-pds/src/api/identity/plc/submit.rs
··· 26 26 Json(input): Json<SubmitPlcOperationInput>, 27 27 ) -> Result<Response, ApiError> { 28 28 if let Err(e) = crate::auth::scope_check::check_identity_scope( 29 - auth.is_oauth(), 29 + &auth.auth_source, 30 30 auth.scope.as_deref(), 31 31 crate::oauth::scopes::IdentityAttr::Wildcard, 32 32 ) { ··· 87 87 { 88 88 let service_type = pds.get("type").and_then(|v| v.as_str()); 89 89 let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 90 - if service_type != Some("AtprotoPersonalDataServer") { 90 + if service_type != Some(crate::plc::ServiceType::Pds.as_str()) { 91 91 return Err(ApiError::InvalidRequest( 92 92 "Incorrect type on atproto_pds service".into(), 93 93 )); ··· 143 143 warn!("Failed to sequence identity event: {:?}", e); 144 144 } 145 145 } 146 - let _ = state.cache.delete(&format!("handle:{}", user.handle)).await; 147 - let _ = state.cache.delete(&format!("plc:doc:{}", did)).await; 148 - let _ = state.cache.delete(&format!("plc:data:{}", did)).await; 146 + let _ = state 147 + .cache 148 + .delete(&crate::cache_keys::handle_key(&user.handle)) 149 + .await; 150 + let _ = state 151 + .cache 152 + .delete(&crate::cache_keys::plc_doc_key(did)) 153 + .await; 154 + let _ = state 155 + .cache 156 + .delete(&crate::cache_keys::plc_data_key(did)) 157 + .await; 149 158 if state.did_resolver.refresh_did(did).await.is_none() { 150 159 warn!(did = %did, "Failed to refresh DID cache after PLC update"); 151 160 }
+1 -1
crates/tranquil-pds/src/api/mod.rs
··· 19 19 pub mod verification; 20 20 21 21 pub use error::ApiError; 22 - pub use proxy_client::{AtUriParts, proxy_client, validate_at_uri, validate_did, validate_limit}; 22 + pub use proxy_client::{AtUriParts, proxy_client, validate_at_uri, validate_limit}; 23 23 pub use responses::{ 24 24 DidResponse, EmptyResponse, EnabledResponse, HasPasswordResponse, OptionsResponse, 25 25 StatusResponse, SuccessResponse, TokenRequiredResponse, VerifiedResponse,
+48 -26
crates/tranquil-pds/src/api/moderation/mod.rs
··· 12 12 use serde_json::{Value, json}; 13 13 use tracing::{error, info}; 14 14 15 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] 16 + pub enum ReportReasonType { 17 + #[serde(rename = "com.atproto.moderation.defs#reasonSpam")] 18 + Spam, 19 + #[serde(rename = "com.atproto.moderation.defs#reasonViolation")] 20 + Violation, 21 + #[serde(rename = "com.atproto.moderation.defs#reasonMisleading")] 22 + Misleading, 23 + #[serde(rename = "com.atproto.moderation.defs#reasonSexual")] 24 + Sexual, 25 + #[serde(rename = "com.atproto.moderation.defs#reasonRude")] 26 + Rude, 27 + #[serde(rename = "com.atproto.moderation.defs#reasonOther")] 28 + Other, 29 + #[serde(rename = "com.atproto.moderation.defs#reasonAppeal")] 30 + Appeal, 31 + } 32 + 33 + impl ReportReasonType { 34 + pub fn as_str(self) -> &'static str { 35 + match self { 36 + Self::Spam => "com.atproto.moderation.defs#reasonSpam", 37 + Self::Violation => "com.atproto.moderation.defs#reasonViolation", 38 + Self::Misleading => "com.atproto.moderation.defs#reasonMisleading", 39 + Self::Sexual => "com.atproto.moderation.defs#reasonSexual", 40 + Self::Rude => "com.atproto.moderation.defs#reasonRude", 41 + Self::Other => "com.atproto.moderation.defs#reasonOther", 42 + Self::Appeal => "com.atproto.moderation.defs#reasonAppeal", 43 + } 44 + } 45 + } 46 + 15 47 #[derive(Deserialize)] 16 48 #[serde(rename_all = "camelCase")] 17 49 pub struct CreateReportInput { 18 - pub reason_type: String, 50 + pub reason_type: ReportReasonType, 19 51 pub reason: Option<String>, 20 52 pub subject: Value, 21 53 } ··· 24 56 #[serde(rename_all = "camelCase")] 25 57 pub struct CreateReportOutput { 26 58 pub id: i64, 27 - pub reason_type: String, 59 + pub reason_type: ReportReasonType, 28 60 pub reason: Option<String>, 29 61 pub subject: Value, 30 62 pub reported_by: String, 31 63 pub created_at: String, 32 64 } 33 65 34 - fn get_report_service_config() -> Option<(String, String)> { 66 + struct ReportServiceConfig { 67 + url: String, 68 + did: String, 69 + } 70 + 71 + fn get_report_service_config() -> Option<ReportServiceConfig> { 35 72 let url = std::env::var("REPORT_SERVICE_URL").ok()?; 36 73 let did = std::env::var("REPORT_SERVICE_DID").ok()?; 37 74 if url.is_empty() || did.is_empty() { 38 75 return None; 39 76 } 40 - Some((url, did)) 77 + Some(ReportServiceConfig { url, did }) 41 78 } 42 79 43 80 pub async fn create_report( ··· 47 84 ) -> Response { 48 85 let did = &auth.did; 49 86 50 - if let Some((service_url, service_did)) = get_report_service_config() { 51 - return proxy_to_report_service(&state, &auth, &service_url, &service_did, &input).await; 87 + if let Some(config) = get_report_service_config() { 88 + return proxy_to_report_service(&state, &auth, &config.url, &config.did, &input).await; 52 89 } 53 90 54 91 create_report_locally(&state, did, auth.status.is_takendown(), input).await ··· 177 214 is_takendown: bool, 178 215 input: CreateReportInput, 179 216 ) -> Response { 180 - const REASON_APPEAL: &str = "com.atproto.moderation.defs#reasonAppeal"; 181 - 182 - if is_takendown && input.reason_type != REASON_APPEAL { 217 + if is_takendown && input.reason_type != ReportReasonType::Appeal { 183 218 return ApiError::InvalidRequest("Report not accepted from takendown account".into()) 184 219 .into_response(); 185 220 } 186 221 187 - let valid_reason_types = [ 188 - "com.atproto.moderation.defs#reasonSpam", 189 - "com.atproto.moderation.defs#reasonViolation", 190 - "com.atproto.moderation.defs#reasonMisleading", 191 - "com.atproto.moderation.defs#reasonSexual", 192 - "com.atproto.moderation.defs#reasonRude", 193 - "com.atproto.moderation.defs#reasonOther", 194 - REASON_APPEAL, 195 - ]; 196 - 197 - if !valid_reason_types.contains(&input.reason_type.as_str()) { 198 - return ApiError::InvalidRequest("Invalid reasonType".into()).into_response(); 199 - } 200 - 201 222 let created_at = chrono::Utc::now(); 202 - let report_id = (uuid::Uuid::now_v7().as_u128() & 0x7FFF_FFFF_FFFF_FFFF) as i64; 223 + let report_id = i64::try_from(uuid::Uuid::now_v7().as_u128() & 0x7FFF_FFFF_FFFF_FFFF) 224 + .expect("masked to 63 bits, always fits i64"); 203 225 let subject_json = json!(input.subject); 204 226 205 227 if let Err(e) = state 206 228 .infra_repo 207 229 .insert_report( 208 230 report_id, 209 - &input.reason_type, 231 + input.reason_type.as_str(), 210 232 input.reason.as_deref(), 211 233 subject_json, 212 234 did, ··· 221 243 info!( 222 244 report_id = %report_id, 223 245 reported_by = %did, 224 - reason_type = %input.reason_type, 246 + reason_type = input.reason_type.as_str(), 225 247 "Report created locally (no report service configured)" 226 248 ); 227 249
+100 -103
crates/tranquil-pds/src/api/notification_prefs.rs
··· 11 11 use serde_json::json; 12 12 use tracing::info; 13 13 use tranquil_db_traits::{CommsChannel, CommsStatus, CommsType}; 14 + use tranquil_types::Did; 14 15 15 16 #[derive(Serialize)] 16 17 #[serde(rename_all = "camelCase")] ··· 130 131 pub struct UpdateNotificationPrefsResponse { 131 132 pub success: bool, 132 133 #[serde(skip_serializing_if = "Vec::is_empty")] 133 - pub verification_required: Vec<String>, 134 + pub verification_required: Vec<CommsChannel>, 134 135 } 135 136 136 137 pub async fn request_channel_verification( 137 138 state: &AppState, 138 139 user_id: uuid::Uuid, 139 - did: &str, 140 - channel: &str, 140 + did: &Did, 141 + channel: CommsChannel, 141 142 identifier: &str, 142 143 handle: Option<&str>, 143 - ) -> Result<String, String> { 144 + ) -> Result<String, ApiError> { 144 145 let token = 145 146 crate::auth::verification_token::generate_channel_update_token(did, channel, identifier); 146 147 let formatted_token = crate::auth::verification_token::format_token_for_display(&token); 147 148 148 - if channel == "email" { 149 - let hostname = pds_hostname(); 150 - let handle_str = handle.unwrap_or("user"); 151 - crate::comms::comms_repo::enqueue_email_update( 152 - state.infra_repo.as_ref(), 153 - user_id, 154 - identifier, 155 - handle_str, 156 - &formatted_token, 157 - hostname, 158 - ) 159 - .await 160 - .map_err(|e| format!("Failed to enqueue email notification: {}", e))?; 161 - } else { 162 - let comms_channel = match channel { 163 - "discord" => tranquil_db_traits::CommsChannel::Discord, 164 - "telegram" => tranquil_db_traits::CommsChannel::Telegram, 165 - "signal" => tranquil_db_traits::CommsChannel::Signal, 166 - _ => return Err("Invalid channel".to_string()), 167 - }; 168 - let hostname = pds_hostname(); 169 - let encoded_token = urlencoding::encode(&formatted_token); 170 - let encoded_identifier = urlencoding::encode(identifier); 171 - let verify_link = format!( 172 - "https://{}/app/verify?token={}&identifier={}", 173 - hostname, encoded_token, encoded_identifier 174 - ); 175 - let prefs = state 176 - .user_repo 177 - .get_comms_prefs(user_id) 149 + match channel { 150 + CommsChannel::Email => { 151 + let hostname = pds_hostname(); 152 + let handle_str = handle.unwrap_or("user"); 153 + crate::comms::comms_repo::enqueue_email_update( 154 + state.infra_repo.as_ref(), 155 + user_id, 156 + identifier, 157 + handle_str, 158 + &formatted_token, 159 + hostname, 160 + ) 178 161 .await 179 - .ok() 180 - .flatten(); 181 - let locale = prefs 182 - .as_ref() 183 - .and_then(|p| p.preferred_locale.as_deref()) 184 - .unwrap_or("en"); 185 - let strings = crate::comms::get_strings(locale); 186 - let body = crate::comms::format_message( 187 - strings.channel_verification_body, 188 - &[("code", &formatted_token), ("verify_link", &verify_link)], 189 - ); 190 - let subject = crate::comms::format_message( 191 - strings.channel_verification_subject, 192 - &[("hostname", hostname)], 193 - ); 194 - let recipient = match comms_channel { 195 - tranquil_db_traits::CommsChannel::Telegram => state 162 + .map_err(|e| { 163 + ApiError::InternalError(Some(format!( 164 + "Failed to enqueue email notification: {}", 165 + e 166 + ))) 167 + })?; 168 + } 169 + _ => { 170 + let hostname = pds_hostname(); 171 + let encoded_token = urlencoding::encode(&formatted_token); 172 + let encoded_identifier = urlencoding::encode(identifier); 173 + let verify_link = format!( 174 + "https://{}/app/verify?token={}&identifier={}", 175 + hostname, encoded_token, encoded_identifier 176 + ); 177 + let prefs = state 196 178 .user_repo 197 - .get_telegram_chat_id(user_id) 179 + .get_comms_prefs(user_id) 198 180 .await 199 181 .ok() 200 - .flatten() 201 - .map(|id| id.to_string()) 202 - .unwrap_or_else(|| identifier.to_string()), 203 - _ => identifier.to_string(), 204 - }; 205 - state 206 - .infra_repo 207 - .enqueue_comms( 208 - Some(user_id), 209 - comms_channel, 210 - tranquil_db_traits::CommsType::ChannelVerification, 211 - &recipient, 212 - Some(&subject), 213 - &body, 214 - Some(json!({"code": formatted_token})), 215 - ) 216 - .await 217 - .map_err(|e| format!("Failed to enqueue notification: {}", e))?; 182 + .flatten(); 183 + let locale = prefs 184 + .as_ref() 185 + .and_then(|p| p.preferred_locale.as_deref()) 186 + .unwrap_or("en"); 187 + let strings = crate::comms::get_strings(locale); 188 + let body = crate::comms::format_message( 189 + strings.channel_verification_body, 190 + &[("code", &formatted_token), ("verify_link", &verify_link)], 191 + ); 192 + let subject = crate::comms::format_message( 193 + strings.channel_verification_subject, 194 + &[("hostname", hostname)], 195 + ); 196 + let recipient = match channel { 197 + CommsChannel::Telegram => state 198 + .user_repo 199 + .get_telegram_chat_id(user_id) 200 + .await 201 + .ok() 202 + .flatten() 203 + .map(|id| id.to_string()) 204 + .unwrap_or_else(|| identifier.to_string()), 205 + _ => identifier.to_string(), 206 + }; 207 + state 208 + .infra_repo 209 + .enqueue_comms( 210 + Some(user_id), 211 + channel, 212 + tranquil_db_traits::CommsType::ChannelVerification, 213 + &recipient, 214 + Some(&subject), 215 + &body, 216 + Some(json!({"code": formatted_token})), 217 + ) 218 + .await 219 + .map_err(|e| { 220 + ApiError::InternalError(Some(format!("Failed to enqueue notification: {}", e))) 221 + })?; 222 + } 218 223 } 219 224 220 225 Ok(token) ··· 246 251 let effective_channel = input 247 252 .preferred_channel 248 253 .as_deref() 249 - .map(|ch| match ch { 250 - "email" => Ok(CommsChannel::Email), 251 - "discord" => Ok(CommsChannel::Discord), 252 - "telegram" => Ok(CommsChannel::Telegram), 253 - "signal" => Ok(CommsChannel::Signal), 254 - _ => Err(ApiError::InvalidRequest( 255 - "Invalid channel. Must be one of: email, discord, telegram, signal".into(), 256 - )), 254 + .map(|ch| { 255 + ch.parse::<CommsChannel>().map_err(|_| { 256 + ApiError::InvalidRequest( 257 + "Invalid channel. Must be one of: email, discord, telegram, signal".into(), 258 + ) 259 + }) 257 260 }) 258 261 .transpose()? 259 262 .unwrap_or(current_prefs.preferred_channel); 260 263 261 - let mut verification_required: Vec<String> = Vec::new(); 264 + let mut verification_required: Vec<CommsChannel> = Vec::new(); 262 265 263 - if let Some(ref channel_str) = input.preferred_channel { 264 - let channel = match channel_str.as_str() { 265 - "email" => CommsChannel::Email, 266 - "discord" => CommsChannel::Discord, 267 - "telegram" => CommsChannel::Telegram, 268 - "signal" => CommsChannel::Signal, 269 - _ => { 270 - return Err(ApiError::InvalidRequest( 271 - "Invalid channel. Must be one of: email, discord, telegram, signal".into(), 272 - )); 273 - } 274 - }; 266 + if input.preferred_channel.is_some() { 275 267 state 276 268 .user_repo 277 - .update_preferred_comms_channel(&auth.did, channel) 269 + .update_preferred_comms_channel(&auth.did, effective_channel) 278 270 .await 279 271 .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 280 - info!(did = %auth.did, channel = ?channel, "Updated preferred notification channel"); 272 + info!(did = %auth.did, channel = ?effective_channel, "Updated preferred notification channel"); 281 273 } 282 274 283 275 if let Some(ref new_email) = input.email { ··· 295 287 &state, 296 288 user_id, 297 289 &auth.did, 298 - "email", 290 + CommsChannel::Email, 299 291 &email_clean, 300 292 Some(&handle), 301 293 ) 302 - .await 303 - .map_err(|e| ApiError::InternalError(Some(e)))?; 304 - verification_required.push("email".to_string()); 294 + .await?; 295 + verification_required.push(CommsChannel::Email); 305 296 info!(did = %auth.did, "Requested email verification"); 306 297 } 307 298 } ··· 331 322 .set_unverified_discord(user_id, &discord_clean) 332 323 .await 333 324 .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 334 - verification_required.push("discord".to_string()); 325 + verification_required.push(CommsChannel::Discord); 335 326 info!(did = %auth.did, discord_username = %discord_clean, "Stored unverified Discord username"); 336 327 } 337 328 } ··· 361 352 .set_unverified_telegram(user_id, telegram_clean) 362 353 .await 363 354 .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 364 - verification_required.push("telegram".to_string()); 355 + verification_required.push(CommsChannel::Telegram); 365 356 info!(did = %auth.did, telegram_username = %telegram_clean, "Stored unverified Telegram username"); 366 357 } 367 358 } ··· 391 382 .set_unverified_signal(user_id, &signal_clean) 392 383 .await 393 384 .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 394 - request_channel_verification(&state, user_id, &auth.did, "signal", &signal_clean, None) 395 - .await 396 - .map_err(|e| ApiError::InternalError(Some(e)))?; 397 - verification_required.push("signal".to_string()); 385 + request_channel_verification( 386 + &state, 387 + user_id, 388 + &auth.did, 389 + CommsChannel::Signal, 390 + &signal_clean, 391 + None, 392 + ) 393 + .await?; 394 + verification_required.push(CommsChannel::Signal); 398 395 info!(did = %auth.did, signal_username = %signal_clean, "Stored unverified Signal username"); 399 396 } 400 397 }
+106 -97
crates/tranquil-pds/src/api/proxy.rs
··· 1 + use std::collections::HashSet; 1 2 use std::convert::Infallible; 3 + use std::sync::LazyLock; 2 4 3 5 use crate::api::error::ApiError; 4 6 use crate::api::proxy_client::proxy_client; ··· 15 17 use tower::{Service, util::BoxCloneSyncService}; 16 18 use tracing::{error, info, warn}; 17 19 18 - const PROTECTED_METHODS: &[&str] = &[ 19 - "app.bsky.actor.getPreferences", 20 - "app.bsky.actor.putPreferences", 21 - "com.atproto.admin.deleteAccount", 22 - "com.atproto.admin.disableAccountInvites", 23 - "com.atproto.admin.disableInviteCodes", 24 - "com.atproto.admin.enableAccountInvites", 25 - "com.atproto.admin.getAccountInfo", 26 - "com.atproto.admin.getAccountInfos", 27 - "com.atproto.admin.getInviteCodes", 28 - "com.atproto.admin.getSubjectStatus", 29 - "com.atproto.admin.searchAccounts", 30 - "com.atproto.admin.sendEmail", 31 - "com.atproto.admin.updateAccountEmail", 32 - "com.atproto.admin.updateAccountHandle", 33 - "com.atproto.admin.updateAccountPassword", 34 - "com.atproto.admin.updateSubjectStatus", 35 - "com.atproto.identity.getRecommendedDidCredentials", 36 - "com.atproto.identity.requestPlcOperationSignature", 37 - "com.atproto.identity.signPlcOperation", 38 - "com.atproto.identity.submitPlcOperation", 39 - "com.atproto.identity.updateHandle", 40 - "com.atproto.repo.applyWrites", 41 - "com.atproto.repo.createRecord", 42 - "com.atproto.repo.deleteRecord", 43 - "com.atproto.repo.importRepo", 44 - "com.atproto.repo.putRecord", 45 - "com.atproto.repo.uploadBlob", 46 - "com.atproto.server.activateAccount", 47 - "com.atproto.server.checkAccountStatus", 48 - "com.atproto.server.confirmEmail", 49 - "com.atproto.server.confirmSignup", 50 - "com.atproto.server.createAccount", 51 - "com.atproto.server.createAppPassword", 52 - "com.atproto.server.createInviteCode", 53 - "com.atproto.server.createInviteCodes", 54 - "com.atproto.server.createSession", 55 - "com.atproto.server.createTotpSecret", 56 - "com.atproto.server.deactivateAccount", 57 - "com.atproto.server.deleteAccount", 58 - "com.atproto.server.deletePasskey", 59 - "com.atproto.server.deleteSession", 60 - "com.atproto.server.describeServer", 61 - "com.atproto.server.disableTotp", 62 - "com.atproto.server.enableTotp", 63 - "com.atproto.server.finishPasskeyRegistration", 64 - "com.atproto.server.getAccountInviteCodes", 65 - "com.atproto.server.getServiceAuth", 66 - "com.atproto.server.getSession", 67 - "com.atproto.server.getTotpStatus", 68 - "com.atproto.server.listAppPasswords", 69 - "com.atproto.server.listPasskeys", 70 - "com.atproto.server.refreshSession", 71 - "com.atproto.server.regenerateBackupCodes", 72 - "com.atproto.server.requestAccountDelete", 73 - "com.atproto.server.requestEmailConfirmation", 74 - "com.atproto.server.requestEmailUpdate", 75 - "com.atproto.server.requestPasswordReset", 76 - "com.atproto.server.resendMigrationVerification", 77 - "com.atproto.server.resendVerification", 78 - "com.atproto.server.reserveSigningKey", 79 - "com.atproto.server.resetPassword", 80 - "com.atproto.server.revokeAppPassword", 81 - "com.atproto.server.startPasskeyRegistration", 82 - "com.atproto.server.updateEmail", 83 - "com.atproto.server.updatePasskey", 84 - "com.atproto.server.verifyMigrationEmail", 85 - "com.atproto.sync.getBlob", 86 - "com.atproto.sync.getBlocks", 87 - "com.atproto.sync.getCheckout", 88 - "com.atproto.sync.getHead", 89 - "com.atproto.sync.getLatestCommit", 90 - "com.atproto.sync.getRecord", 91 - "com.atproto.sync.getRepo", 92 - "com.atproto.sync.getRepoStatus", 93 - "com.atproto.sync.listBlobs", 94 - "com.atproto.sync.listRepos", 95 - "com.atproto.sync.notifyOfUpdate", 96 - "com.atproto.sync.requestCrawl", 97 - "com.atproto.sync.subscribeRepos", 98 - "com.atproto.temp.checkSignupQueue", 99 - "com.atproto.temp.dereferenceScope", 100 - ]; 20 + static PROTECTED_METHODS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| { 21 + [ 22 + "app.bsky.actor.getPreferences", 23 + "app.bsky.actor.putPreferences", 24 + "com.atproto.admin.deleteAccount", 25 + "com.atproto.admin.disableAccountInvites", 26 + "com.atproto.admin.disableInviteCodes", 27 + "com.atproto.admin.enableAccountInvites", 28 + "com.atproto.admin.getAccountInfo", 29 + "com.atproto.admin.getAccountInfos", 30 + "com.atproto.admin.getInviteCodes", 31 + "com.atproto.admin.getSubjectStatus", 32 + "com.atproto.admin.searchAccounts", 33 + "com.atproto.admin.sendEmail", 34 + "com.atproto.admin.updateAccountEmail", 35 + "com.atproto.admin.updateAccountHandle", 36 + "com.atproto.admin.updateAccountPassword", 37 + "com.atproto.admin.updateSubjectStatus", 38 + "com.atproto.identity.getRecommendedDidCredentials", 39 + "com.atproto.identity.requestPlcOperationSignature", 40 + "com.atproto.identity.signPlcOperation", 41 + "com.atproto.identity.submitPlcOperation", 42 + "com.atproto.identity.updateHandle", 43 + "com.atproto.repo.applyWrites", 44 + "com.atproto.repo.createRecord", 45 + "com.atproto.repo.deleteRecord", 46 + "com.atproto.repo.importRepo", 47 + "com.atproto.repo.putRecord", 48 + "com.atproto.repo.uploadBlob", 49 + "com.atproto.server.activateAccount", 50 + "com.atproto.server.checkAccountStatus", 51 + "com.atproto.server.confirmEmail", 52 + "com.atproto.server.confirmSignup", 53 + "com.atproto.server.createAccount", 54 + "com.atproto.server.createAppPassword", 55 + "com.atproto.server.createInviteCode", 56 + "com.atproto.server.createInviteCodes", 57 + "com.atproto.server.createSession", 58 + "com.atproto.server.createTotpSecret", 59 + "com.atproto.server.deactivateAccount", 60 + "com.atproto.server.deleteAccount", 61 + "com.atproto.server.deletePasskey", 62 + "com.atproto.server.deleteSession", 63 + "com.atproto.server.describeServer", 64 + "com.atproto.server.disableTotp", 65 + "com.atproto.server.enableTotp", 66 + "com.atproto.server.finishPasskeyRegistration", 67 + "com.atproto.server.getAccountInviteCodes", 68 + "com.atproto.server.getServiceAuth", 69 + "com.atproto.server.getSession", 70 + "com.atproto.server.getTotpStatus", 71 + "com.atproto.server.listAppPasswords", 72 + "com.atproto.server.listPasskeys", 73 + "com.atproto.server.refreshSession", 74 + "com.atproto.server.regenerateBackupCodes", 75 + "com.atproto.server.requestAccountDelete", 76 + "com.atproto.server.requestEmailConfirmation", 77 + "com.atproto.server.requestEmailUpdate", 78 + "com.atproto.server.requestPasswordReset", 79 + "com.atproto.server.resendMigrationVerification", 80 + "com.atproto.server.resendVerification", 81 + "com.atproto.server.reserveSigningKey", 82 + "com.atproto.server.resetPassword", 83 + "com.atproto.server.revokeAppPassword", 84 + "com.atproto.server.startPasskeyRegistration", 85 + "com.atproto.server.updateEmail", 86 + "com.atproto.server.updatePasskey", 87 + "com.atproto.server.verifyMigrationEmail", 88 + "com.atproto.sync.getBlob", 89 + "com.atproto.sync.getBlocks", 90 + "com.atproto.sync.getCheckout", 91 + "com.atproto.sync.getHead", 92 + "com.atproto.sync.getLatestCommit", 93 + "com.atproto.sync.getRecord", 94 + "com.atproto.sync.getRepo", 95 + "com.atproto.sync.getRepoStatus", 96 + "com.atproto.sync.listBlobs", 97 + "com.atproto.sync.listRepos", 98 + "com.atproto.sync.notifyOfUpdate", 99 + "com.atproto.sync.requestCrawl", 100 + "com.atproto.sync.subscribeRepos", 101 + "com.atproto.temp.checkSignupQueue", 102 + "com.atproto.temp.dereferenceScope", 103 + ] 104 + .into_iter() 105 + .collect() 106 + }); 101 107 102 108 fn is_protected_method(method: &str) -> bool { 103 - PROTECTED_METHODS.contains(&method) 109 + PROTECTED_METHODS.contains(method) 104 110 } 105 111 106 112 pub struct XrpcProxyLayer { ··· 192 198 .into_response(); 193 199 } 194 200 195 - let Some(proxy_header) = get_header_str(&headers, "atproto-proxy").map(String::from) else { 201 + let Some(proxy_header) = 202 + get_header_str(&headers, crate::util::HEADER_ATPROTO_PROXY).map(String::from) 203 + else { 196 204 return ApiError::InvalidRequest("Missing required atproto-proxy header".into()) 197 205 .into_response(); 198 206 }; ··· 212 220 let client = proxy_client(); 213 221 let mut request_builder = client.request(method_verb.clone(), &target_url); 214 222 215 - let mut auth_header_val = headers.get("Authorization").cloned(); 223 + let mut auth_header_val = headers.get(http::header::AUTHORIZATION).cloned(); 216 224 if let Some(extracted) = crate::auth::extract_auth_token_from_header( 217 - crate::util::get_header_str(&headers, "Authorization"), 225 + crate::util::get_header_str(&headers, http::header::AUTHORIZATION), 218 226 ) { 219 227 let token = extracted.token; 220 - let dpop_proof = crate::util::get_header_str(&headers, "DPoP"); 228 + let dpop_proof = crate::util::get_header_str(&headers, crate::util::HEADER_DPOP); 221 229 let http_uri = crate::util::build_full_url(&format!("/xrpc{}", uri)); 222 230 223 231 match crate::auth::validate_token_with_dpop( 224 232 state.user_repo.as_ref(), 225 233 state.oauth_repo.as_ref(), 226 234 &token, 227 - extracted.is_dpop, 235 + extracted.scheme, 228 236 dpop_proof, 229 237 method_verb.as_str(), 230 238 &http_uri, 231 - false, 232 - false, 239 + crate::auth::AccountRequirement::Active, 233 240 ) 234 241 .await 235 242 { 236 243 Ok(auth_user) => { 237 244 if let Err(e) = crate::auth::scope_check::check_rpc_scope( 238 - auth_user.is_oauth(), 245 + &auth_user.auth_source, 239 246 auth_user.scope.as_deref(), 240 247 &resolved.did, 241 248 method, ··· 298 305 info!(error = ?e, "Proxy token validation failed, returning error to client"); 299 306 let mut response = ApiError::from(e).into_response(); 300 307 if let Ok(nonce_val) = crate::oauth::verify::generate_dpop_nonce().parse() { 301 - response.headers_mut().insert("DPoP-Nonce", nonce_val); 308 + response 309 + .headers_mut() 310 + .insert(crate::util::HEADER_DPOP_NONCE, nonce_val); 302 311 } 303 312 return response; 304 313 } ··· 306 315 } 307 316 308 317 if let Some(val) = auth_header_val { 309 - request_builder = request_builder.header("Authorization", val); 318 + request_builder = request_builder.header(http::header::AUTHORIZATION, val); 310 319 } 311 320 request_builder = crate::api::proxy_client::HEADERS_TO_FORWARD 312 321 .iter() 313 - .filter_map(|name| headers.get(*name).map(|val| (*name, val))) 322 + .filter_map(|name| headers.get(name).map(|val| (name, val))) 314 323 .fold(request_builder, |builder, (name, val)| { 315 - builder.header(name, val) 324 + builder.header(name.as_str(), val) 316 325 }); 317 326 if !body.is_empty() { 318 327 request_builder = request_builder.body(body); ··· 333 342 let mut response_builder = Response::builder().status(status); 334 343 response_builder = crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD 335 344 .iter() 336 - .filter_map(|name| headers.get(*name).map(|val| (*name, val))) 345 + .filter_map(|name| headers.get(name).map(|val| (name, val))) 337 346 .fold(response_builder, |builder, (name, val)| { 338 347 builder.header(name, val) 339 348 });
+44 -56
crates/tranquil-pds/src/api/proxy_client.rs
··· 1 + use axum::http::HeaderName; 1 2 use reqwest::{Client, ClientBuilder, Url}; 2 3 use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; 3 - use std::sync::OnceLock; 4 + use std::sync::{LazyLock, OnceLock}; 4 5 use std::time::Duration; 5 6 use tracing::warn; 7 + use tranquil_types::{Did, Nsid, Rkey}; 6 8 7 9 pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10); 8 10 pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30); ··· 146 148 147 149 impl std::error::Error for SsrfError {} 148 150 149 - pub const HEADERS_TO_FORWARD: &[&str] = &[ 150 - "accept-language", 151 - "atproto-accept-labelers", 152 - "x-bsky-topics", 153 - "content-type", 154 - ]; 155 - pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[ 156 - "atproto-repo-rev", 157 - "atproto-content-labelers", 158 - "retry-after", 159 - "content-type", 160 - "cache-control", 161 - "etag", 162 - ]; 151 + pub static HEADERS_TO_FORWARD: LazyLock<[HeaderName; 4]> = LazyLock::new(|| { 152 + [ 153 + HeaderName::from_static("accept-language"), 154 + crate::util::HEADER_ATPROTO_ACCEPT_LABELERS, 155 + crate::util::HEADER_X_BSKY_TOPICS, 156 + http::header::CONTENT_TYPE, 157 + ] 158 + }); 159 + pub static RESPONSE_HEADERS_TO_FORWARD: LazyLock<[HeaderName; 6]> = LazyLock::new(|| { 160 + [ 161 + crate::util::HEADER_ATPROTO_REPO_REV, 162 + crate::util::HEADER_ATPROTO_CONTENT_LABELERS, 163 + HeaderName::from_static("retry-after"), 164 + http::header::CONTENT_TYPE, 165 + http::header::CACHE_CONTROL, 166 + http::header::ETAG, 167 + ] 168 + }); 163 169 164 170 pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> { 165 171 if !uri.starts_with("at://") { ··· 170 176 if parts.is_empty() { 171 177 return Err("URI missing DID"); 172 178 } 173 - let did = parts[0]; 174 - if !did.starts_with("did:") { 175 - return Err("Invalid DID in URI"); 176 - } 177 - if parts.len() > 1 { 178 - let collection = parts[1]; 179 - if collection.is_empty() || !collection.contains('.') { 180 - return Err("Invalid collection NSID"); 181 - } 182 - } 179 + let did: Did = parts[0].parse().map_err(|_| "Invalid DID in URI")?; 180 + let collection = parts 181 + .get(1) 182 + .map(|s| s.parse::<Nsid>()) 183 + .transpose() 184 + .map_err(|_| "Invalid collection NSID")?; 185 + let rkey = parts 186 + .get(2) 187 + .map(|s| s.parse::<Rkey>()) 188 + .transpose() 189 + .map_err(|_| "Invalid rkey")?; 183 190 Ok(AtUriParts { 184 - did: did.to_string(), 185 - collection: parts.get(1).map(|s| s.to_string()), 186 - rkey: parts.get(2).map(|s| s.to_string()), 191 + did, 192 + collection, 193 + rkey, 187 194 }) 188 195 } 189 196 190 197 #[derive(Debug, Clone)] 191 198 pub struct AtUriParts { 192 - pub did: String, 193 - pub collection: Option<String>, 194 - pub rkey: Option<String>, 199 + pub did: Did, 200 + pub collection: Option<Nsid>, 201 + pub rkey: Option<Rkey>, 195 202 } 196 203 197 204 pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 { ··· 203 210 } 204 211 } 205 212 206 - pub fn validate_did(did: &str) -> Result<(), &'static str> { 207 - if !did.starts_with("did:") { 208 - return Err("Invalid DID format"); 209 - } 210 - let parts: Vec<&str> = did.split(':').collect(); 211 - if parts.len() < 3 { 212 - return Err("DID must have at least method and identifier"); 213 - } 214 - let method = parts[1]; 215 - if method != "plc" && method != "web" { 216 - return Err("Unsupported DID method"); 217 - } 218 - Ok(()) 219 - } 220 - 221 213 #[cfg(test)] 222 214 mod tests { 223 215 use super::*; ··· 243 235 let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123"); 244 236 assert!(result.is_ok()); 245 237 let parts = result.unwrap(); 246 - assert_eq!(parts.did, "did:plc:test"); 247 - assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string())); 248 - assert_eq!(parts.rkey, Some("abc123".to_string())); 238 + assert_eq!(parts.did, "did:plc:test".parse::<Did>().unwrap()); 239 + assert_eq!( 240 + parts.collection, 241 + Some("app.bsky.feed.post".parse::<Nsid>().unwrap()) 242 + ); 243 + assert_eq!(parts.rkey, Some("abc123".parse::<Rkey>().unwrap())); 249 244 } 250 245 #[test] 251 246 fn test_validate_at_uri_invalid() { ··· 258 253 assert_eq!(validate_limit(Some(0), 50, 100), 50); 259 254 assert_eq!(validate_limit(Some(200), 50, 100), 100); 260 255 assert_eq!(validate_limit(Some(75), 50, 100), 75); 261 - } 262 - #[test] 263 - fn test_validate_did() { 264 - assert!(validate_did("did:plc:abc123").is_ok()); 265 - assert!(validate_did("did:web:example.com").is_ok()); 266 - assert!(validate_did("notadid").is_err()); 267 - assert!(validate_did("did:unknown:test").is_err()); 268 256 } 269 257 }
+14 -7
crates/tranquil-pds/src/api/repo/blob.rs
··· 56 56 if user.status.is_takendown() { 57 57 return Err(ApiError::AccountTakedown); 58 58 } 59 - let mime_type_for_check = 60 - get_header_str(&headers, "content-type").unwrap_or("application/octet-stream"); 59 + let mime_type_for_check = get_header_str(&headers, http::header::CONTENT_TYPE) 60 + .unwrap_or("application/octet-stream"); 61 61 let scope_proof = match user.verify_blob_upload(mime_type_for_check) { 62 62 Ok(proof) => proof, 63 63 Err(e) => return Ok(e.into_response()), ··· 79 79 } 80 80 81 81 let client_mime_hint = 82 - get_header_str(&headers, "content-type").unwrap_or("application/octet-stream"); 82 + get_header_str(&headers, http::header::CONTENT_TYPE).unwrap_or("application/octet-stream"); 83 83 84 84 let user_id = state 85 85 .user_repo ··· 89 89 .ok_or(ApiError::InternalError(None))?; 90 90 91 91 let temp_key = format!("temp/{}", uuid::Uuid::new_v4()); 92 - let max_size = get_max_blob_size() as u64; 92 + let max_size = u64::try_from(get_max_blob_size()).unwrap_or(u64::MAX); 93 93 94 94 let body_stream = body.into_data_stream(); 95 95 let mapped_stream = ··· 148 148 149 149 match state 150 150 .blob_repo 151 - .insert_blob(&cid_link, &mime_type, size as i64, user_id, &storage_key) 151 + .insert_blob( 152 + &cid_link, 153 + &mime_type, 154 + i64::try_from(size).unwrap_or(i64::MAX), 155 + user_id, 156 + &storage_key, 157 + ) 152 158 .await 153 159 { 154 160 Ok(_) => {} ··· 248 254 .await 249 255 .log_db_err("fetching missing blobs")?; 250 256 251 - let has_more = missing.len() > limit as usize; 257 + let limit_usize = usize::try_from(limit).unwrap_or(0); 258 + let has_more = missing.len() > limit_usize; 252 259 let blobs: Vec<RecordBlob> = missing 253 260 .into_iter() 254 - .take(limit as usize) 261 + .take(limit_usize) 255 262 .map(|m| RecordBlob { 256 263 cid: m.blob_cid.to_string(), 257 264 record_uri: m.record_uri.to_string(),
+20 -13
crates/tranquil-pds/src/api/repo/import.rs
··· 92 92 "Root block not found in CAR file".into(), 93 93 )); 94 94 }; 95 - let commit_did = match jacquard_repo::commit::Commit::from_cbor(root_block) { 96 - Ok(commit) => commit.did().to_string(), 95 + let commit_did: Did = match jacquard_repo::commit::Commit::from_cbor(root_block) { 96 + Ok(commit) => commit 97 + .did() 98 + .as_str() 99 + .parse() 100 + .map_err(|_| ApiError::InvalidRequest("Commit contains invalid DID".into()))?, 97 101 Err(e) => { 98 102 return Err(ApiError::InvalidRequest(format!("Invalid commit: {}", e))); 99 103 } ··· 104 108 commit_did, did 105 109 ))); 106 110 } 107 - let skip_verification = std::env::var("SKIP_IMPORT_VERIFICATION") 108 - .map(|v| v == "true" || v == "1") 109 - .unwrap_or(false); 111 + let skip_verification = crate::util::parse_env_bool("SKIP_IMPORT_VERIFICATION"); 110 112 let is_migration = user.deactivated_at.is_some(); 111 113 if skip_verification { 112 114 warn!("Skipping all CAR verification for import (SKIP_IMPORT_VERIFICATION=true)"); ··· 221 223 .flat_map(|record| { 222 224 let record_uri = 223 225 AtUri::from_parts(did.as_str(), &record.collection, &record.rkey); 224 - record.blob_refs.iter().map(move |blob_ref| { 225 - (record_uri.clone(), unsafe { 226 - CidLink::new_unchecked(blob_ref.cid.clone()) 227 - }) 226 + record.blob_refs.iter().filter_map(move |blob_ref| { 227 + match CidLink::new(&blob_ref.cid) { 228 + Ok(cid_link) => Some((record_uri.clone(), cid_link)), 229 + Err(_) => { 230 + tracing::warn!(cid = %blob_ref.cid, "skipping unparseable blob CID reference during import"); 231 + None 232 + } 233 + } 228 234 }) 229 235 }) 230 236 .collect(); ··· 289 295 error!("Failed to store new commit block: {:?}", e); 290 296 ApiError::InternalError(None) 291 297 })?; 292 - let new_root_cid_link = unsafe { CidLink::new_unchecked(new_root_cid.to_string()) }; 298 + let new_root_cid_link = CidLink::from(&new_root_cid); 293 299 state 294 300 .repo_repo 295 301 .update_repo_root(user_id, &new_root_cid_link, &new_rev_str) ··· 313 319 "Created new commit for imported repo: cid={}, rev={}", 314 320 new_root_str, new_rev_str 315 321 ); 316 - if !is_migration && let Err(e) = sequence_import_event(&state, did, &new_root_str).await 322 + if !is_migration 323 + && let Err(e) = sequence_import_event(&state, did, &new_root_cid_link).await 317 324 { 318 325 warn!("Failed to sequence import event: {:?}", e); 319 326 } ··· 378 385 async fn sequence_import_event( 379 386 state: &AppState, 380 387 did: &Did, 381 - commit_cid: &str, 388 + commit_cid: &CidLink, 382 389 ) -> Result<(), tranquil_db::DbError> { 383 390 let data = tranquil_db::CommitEventData { 384 391 did: did.clone(), 385 392 event_type: tranquil_db::RepoEventType::Commit, 386 - commit_cid: Some(unsafe { CidLink::new_unchecked(commit_cid) }), 393 + commit_cid: Some(commit_cid.clone()), 387 394 prev_cid: None, 388 395 ops: Some(serde_json::json!([])), 389 396 blobs: Some(vec![]),
+7 -14
crates/tranquil-pds/src/api/repo/record/batch.rs
··· 11 11 use crate::repo::tracking::TrackingBlockStore; 12 12 use crate::state::AppState; 13 13 use crate::types::{AtIdentifier, AtUri, Did, Nsid, Rkey}; 14 + use crate::validation::ValidationStatus; 14 15 use axum::{ 15 16 Json, 16 17 extract::State, ··· 23 24 use serde_json::json; 24 25 use std::str::FromStr; 25 26 use std::sync::Arc; 26 - use tracing::{error, info}; 27 + use tracing::info; 27 28 28 29 const MAX_BATCH_WRITES: usize = 200; 29 30 ··· 87 88 results.push(WriteResult::CreateResult { 88 89 uri, 89 90 cid: record_cid.to_string(), 90 - validation_status: validation_status.map(|s| s.to_string()), 91 + validation_status, 91 92 }); 92 93 ops.push(RecordOp::Create { 93 94 collection: collection.clone(), ··· 138 139 results.push(WriteResult::UpdateResult { 139 140 uri, 140 141 cid: record_cid.to_string(), 141 - validation_status: validation_status.map(|s| s.to_string()), 142 + validation_status, 142 143 }); 143 144 ops.push(RecordOp::Update { 144 145 collection: collection.clone(), ··· 237 238 uri: AtUri, 238 239 cid: String, 239 240 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")] 240 - validation_status: Option<String>, 241 + validation_status: Option<ValidationStatus>, 241 242 }, 242 243 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")] 243 244 UpdateResult { 244 245 uri: AtUri, 245 246 cid: String, 246 247 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")] 247 - validation_status: Option<String>, 248 + validation_status: Option<ValidationStatus>, 248 249 }, 249 250 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 250 251 DeleteResult {}, ··· 441 442 .await 442 443 { 443 444 Ok(res) => res, 444 - Err(e) if e.contains("ConcurrentModification") => { 445 - return Err(ApiError::InvalidSwap(Some("Repo has been modified".into()))); 446 - } 447 - Err(e) => { 448 - error!("Commit failed: {}", e); 449 - return Err(ApiError::InternalError(Some( 450 - "Failed to commit changes".into(), 451 - ))); 452 - } 445 + Err(e) => return Err(ApiError::from(e)), 453 446 }; 454 447 455 448 if let Some(ref controller) = controller_did {
+17 -18
crates/tranquil-pds/src/api/repo/record/delete.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::{ 3 - CommitParams, RecordOp, commit_and_log, get_current_root_cid, 3 + CommitError, CommitParams, RecordOp, commit_and_log, get_current_root_cid, 4 4 }; 5 5 use crate::api::repo::record::write::{CommitInfo, prepare_repo_write}; 6 6 use crate::auth::{Active, Auth, VerifyScope}; ··· 186 186 .await 187 187 { 188 188 Ok(res) => res, 189 - Err(e) if e.contains("ConcurrentModification") => { 190 - return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 191 - } 192 - Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 189 + Err(e) => return Ok(ApiError::from(e).into_response()), 193 190 }; 194 191 195 192 if let Some(ref controller) = controller_did { ··· 241 238 user_id: Uuid, 242 239 collection: &Nsid, 243 240 rkey: &Rkey, 244 - ) -> Result<(), String> { 241 + ) -> Result<(), CommitError> { 245 242 let _write_lock = state.repo_write_locks.lock(user_id).await; 246 243 247 244 let root_cid_str = state 248 245 .repo_repo 249 246 .get_repo_root_cid_by_user_id(user_id) 250 247 .await 251 - .map_err(|e| format!("DB error: {}", e))? 252 - .ok_or_else(|| "Repo root not found".to_string())?; 248 + .map_err(|e| CommitError::DatabaseError(e.to_string()))? 249 + .ok_or(CommitError::RepoNotFound)?; 253 250 254 251 let current_root_cid = 255 - Cid::from_str(root_cid_str.as_str()).map_err(|_| "Invalid repo root CID".to_string())?; 252 + Cid::from_str(root_cid_str.as_str()).map_err(|e| CommitError::InvalidCid(e.to_string()))?; 256 253 257 254 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 258 255 let commit_bytes = tracking_store 259 256 .get(&current_root_cid) 260 257 .await 261 - .map_err(|e| format!("Failed to fetch commit: {:?}", e))? 262 - .ok_or_else(|| "Commit block not found".to_string())?; 258 + .map_err(|e| CommitError::BlockStoreFailed(format!("{:?}", e)))? 259 + .ok_or(CommitError::BlockStoreFailed( 260 + "Commit block not found".into(), 261 + ))?; 263 262 264 - let commit = 265 - Commit::from_cbor(&commit_bytes).map_err(|e| format!("Failed to parse commit: {:?}", e))?; 263 + let commit = Commit::from_cbor(&commit_bytes) 264 + .map_err(|e| CommitError::CommitParseFailed(format!("{:?}", e)))?; 266 265 267 266 let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 268 267 let key = format!("{}/{}", collection, rkey); ··· 270 269 let prev_record_cid = mst 271 270 .get(&key) 272 271 .await 273 - .map_err(|e| format!("MST get error: {:?}", e))?; 272 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 274 273 275 274 let Some(prev_cid) = prev_record_cid else { 276 275 return Ok(()); ··· 279 278 let new_mst = mst 280 279 .delete(&key) 281 280 .await 282 - .map_err(|e| format!("Failed to delete from MST: {:?}", e))?; 281 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 283 282 284 283 let new_mst_root = new_mst 285 284 .persist() 286 285 .await 287 - .map_err(|e| format!("Failed to persist MST: {:?}", e))?; 286 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 288 287 289 288 let op = RecordOp::Delete { 290 289 collection: collection.clone(), ··· 298 297 new_mst 299 298 .blocks_for_path(&key, &mut new_mst_blocks) 300 299 .await 301 - .map_err(|e| format!("Failed to get new MST blocks: {:?}", e))?; 300 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 302 301 303 302 mst.blocks_for_path(&key, &mut old_mst_blocks) 304 303 .await 305 - .map_err(|e| format!("Failed to get old MST blocks: {:?}", e))?; 304 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 306 305 307 306 let mut relevant_blocks = new_mst_blocks.clone(); 308 307 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone())));
+1 -1
crates/tranquil-pds/src/api/repo/record/read.rs
··· 195 195 } 196 196 }; 197 197 let limit = input.limit.unwrap_or(50).clamp(1, 100); 198 - let limit_i64 = limit as i64; 198 + let limit_i64 = i64::from(limit); 199 199 let cursor_rkey = input 200 200 .cursor 201 201 .as_ref()
+119 -56
crates/tranquil-pds/src/api/repo/record/utils.rs
··· 14 14 use tranquil_db_traits::SequenceNumber; 15 15 use uuid::Uuid; 16 16 17 + #[derive(Debug)] 18 + pub enum CommitError { 19 + InvalidDid(String), 20 + InvalidTid(String), 21 + SigningFailed(String), 22 + SerializationFailed(String), 23 + KeyNotFound, 24 + KeyDecryptionFailed(String), 25 + InvalidKey(String), 26 + BlockStoreFailed(String), 27 + RepoNotFound, 28 + ConcurrentModification, 29 + DatabaseError(String), 30 + UserNotFound, 31 + CommitParseFailed(String), 32 + MstOperationFailed(String), 33 + RecordSerializationFailed(String), 34 + InvalidCid(String), 35 + } 36 + 37 + impl std::fmt::Display for CommitError { 38 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 39 + match self { 40 + Self::InvalidDid(e) => write!(f, "Invalid DID: {}", e), 41 + Self::InvalidTid(e) => write!(f, "Invalid TID: {}", e), 42 + Self::SigningFailed(e) => write!(f, "Failed to sign commit: {}", e), 43 + Self::SerializationFailed(e) => write!(f, "Failed to serialize signed commit: {}", e), 44 + Self::KeyNotFound => write!(f, "Signing key not found"), 45 + Self::KeyDecryptionFailed(e) => write!(f, "Failed to decrypt signing key: {}", e), 46 + Self::InvalidKey(e) => write!(f, "Invalid signing key: {}", e), 47 + Self::BlockStoreFailed(e) => write!(f, "Block store operation failed: {}", e), 48 + Self::RepoNotFound => write!(f, "Repo not found"), 49 + Self::ConcurrentModification => { 50 + write!(f, "Repo has been modified since last read") 51 + } 52 + Self::DatabaseError(e) => write!(f, "Database error: {}", e), 53 + Self::UserNotFound => write!(f, "User not found"), 54 + Self::CommitParseFailed(e) => write!(f, "Failed to parse commit: {}", e), 55 + Self::MstOperationFailed(e) => write!(f, "MST operation failed: {}", e), 56 + Self::RecordSerializationFailed(e) => { 57 + write!(f, "Failed to serialize record: {}", e) 58 + } 59 + Self::InvalidCid(e) => write!(f, "Invalid CID: {}", e), 60 + } 61 + } 62 + } 63 + 64 + impl std::error::Error for CommitError {} 65 + 66 + impl From<CommitError> for ApiError { 67 + fn from(err: CommitError) -> Self { 68 + match err { 69 + CommitError::ConcurrentModification => { 70 + ApiError::InvalidSwap(Some("Repo has been modified".into())) 71 + } 72 + CommitError::RepoNotFound => ApiError::RepoNotFound(None), 73 + CommitError::UserNotFound => ApiError::RepoNotFound(Some("User not found".into())), 74 + other => { 75 + error!("Commit failed: {}", other); 76 + ApiError::InternalError(Some("Failed to commit changes".into())) 77 + } 78 + } 79 + } 80 + } 81 + 17 82 pub async fn get_current_root_cid(state: &AppState, user_id: Uuid) -> Result<CommitCid, ApiError> { 18 83 let root_cid_str = state 19 84 .repo_repo ··· 55 120 } 56 121 57 122 use crate::types::AtUri; 58 - use tranquil_db_traits::Backlink; 123 + use tranquil_db_traits::{Backlink, BacklinkPath}; 59 124 60 125 pub fn extract_backlinks(uri: &AtUri, record: &Value) -> Vec<Backlink> { 61 126 let record_type = record ··· 71 136 .map(|subject| { 72 137 vec![Backlink { 73 138 uri: uri.clone(), 74 - path: "subject".to_string(), 139 + path: BacklinkPath::Subject, 75 140 link_to: subject.to_string(), 76 141 }] 77 142 }) ··· 84 149 .map(|subject_uri| { 85 150 vec![Backlink { 86 151 uri: uri.clone(), 87 - path: "subject.uri".to_string(), 152 + path: BacklinkPath::SubjectUri, 88 153 link_to: subject_uri.to_string(), 89 154 }] 90 155 }) ··· 99 164 rev: &str, 100 165 prev: Option<Cid>, 101 166 signing_key: &SigningKey, 102 - ) -> Result<(Vec<u8>, Bytes), String> { 167 + ) -> Result<(Vec<u8>, Bytes), CommitError> { 103 168 let did = jacquard_common::types::string::Did::new(did.as_str()) 104 - .map_err(|e| format!("Invalid DID: {:?}", e))?; 169 + .map_err(|e| CommitError::InvalidDid(format!("{:?}", e)))?; 105 170 let rev = jacquard_common::types::string::Tid::from_str(rev) 106 - .map_err(|e| format!("Invalid TID: {:?}", e))?; 171 + .map_err(|e| CommitError::InvalidTid(format!("{:?}", e)))?; 107 172 let unsigned = Commit::new_unsigned(did, data, rev, prev); 108 173 let signed = unsigned 109 174 .sign(signing_key) 110 - .map_err(|e| format!("Failed to sign commit: {:?}", e))?; 175 + .map_err(|e| CommitError::SigningFailed(format!("{:?}", e)))?; 111 176 let sig_bytes = signed.sig().clone(); 112 177 let signed_bytes = signed 113 178 .to_cbor() 114 - .map_err(|e| format!("Failed to serialize signed commit: {:?}", e))?; 179 + .map_err(|e| CommitError::SerializationFailed(format!("{:?}", e)))?; 115 180 Ok((signed_bytes, sig_bytes)) 116 181 } 117 182 ··· 154 219 pub async fn commit_and_log( 155 220 state: &AppState, 156 221 params: CommitParams<'_>, 157 - ) -> Result<CommitResult, String> { 222 + ) -> Result<CommitResult, CommitError> { 158 223 use tranquil_db_traits::{ 159 224 ApplyCommitError, ApplyCommitInput, CommitEventData, RecordDelete, RecordUpsert, 160 225 RepoEventType, ··· 175 240 .user_repo 176 241 .get_user_key_by_id(user_id) 177 242 .await 178 - .map_err(|e| format!("Failed to fetch signing key: {}", e))? 179 - .ok_or_else(|| "Signing key not found".to_string())?; 243 + .map_err(|e| CommitError::DatabaseError(format!("Failed to fetch signing key: {}", e)))? 244 + .ok_or(CommitError::KeyNotFound)?; 180 245 let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 181 - .map_err(|e| format!("Failed to decrypt signing key: {}", e))?; 246 + .map_err(|e| CommitError::KeyDecryptionFailed(e.to_string()))?; 182 247 let signing_key = 183 - SigningKey::from_slice(&key_bytes).map_err(|e| format!("Invalid signing key: {}", e))?; 248 + SigningKey::from_slice(&key_bytes).map_err(|e| CommitError::InvalidKey(e.to_string()))?; 184 249 let rev = Tid::now(LimitedU32::MIN); 185 250 let rev_str = rev.to_string(); 186 251 let (new_commit_bytes, _sig) = ··· 189 254 .block_store 190 255 .put(&new_commit_bytes) 191 256 .await 192 - .map_err(|e| format!("Failed to save commit block: {:?}", e))?; 257 + .map_err(|e| CommitError::BlockStoreFailed(format!("{:?}", e)))?; 193 258 194 259 let mut all_block_cids: Vec<Vec<u8>> = blocks_cids 195 260 .iter() ··· 218 283 upserts.push(RecordUpsert { 219 284 collection: collection.clone(), 220 285 rkey: rkey.clone(), 221 - cid: unsafe { crate::types::CidLink::new_unchecked(cid.to_string()) }, 286 + cid: crate::types::CidLink::from(cid), 222 287 }); 223 288 } 224 289 RecordOp::Delete { ··· 283 348 let commit_event = CommitEventData { 284 349 did: did.clone(), 285 350 event_type: RepoEventType::Commit, 286 - commit_cid: Some(unsafe { crate::types::CidLink::new_unchecked(new_root_cid.to_string()) }), 287 - prev_cid: current_root_cid 288 - .map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }), 351 + commit_cid: Some(crate::types::CidLink::from(new_root_cid)), 352 + prev_cid: current_root_cid.map(crate::types::CidLink::from), 289 353 ops: Some(json!(ops_json)), 290 354 blobs: Some(blobs.to_vec()), 291 355 blocks_cids: Some(blocks_cids.to_vec()), 292 - prev_data_cid: prev_data_cid 293 - .map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }), 356 + prev_data_cid: prev_data_cid.map(crate::types::CidLink::from), 294 357 rev: Some(rev_str.clone()), 295 358 }; 296 359 297 360 let input = ApplyCommitInput { 298 361 user_id, 299 362 did: did.clone(), 300 - expected_root_cid: current_root_cid 301 - .map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }), 302 - new_root_cid: unsafe { crate::types::CidLink::new_unchecked(new_root_cid.to_string()) }, 363 + expected_root_cid: current_root_cid.map(crate::types::CidLink::from), 364 + new_root_cid: crate::types::CidLink::from(new_root_cid), 303 365 new_rev: rev_str.clone(), 304 366 new_block_cids: all_block_cids, 305 367 obsolete_block_cids: obsolete_bytes, ··· 313 375 .apply_commit(input) 314 376 .await 315 377 .map_err(|e| match e { 316 - ApplyCommitError::RepoNotFound => "Repo not found".to_string(), 317 - ApplyCommitError::ConcurrentModification => { 318 - "ConcurrentModification: Repo has been modified since last read".to_string() 319 - } 320 - ApplyCommitError::Database(msg) => format!("DB Error: {}", msg), 378 + ApplyCommitError::RepoNotFound => CommitError::RepoNotFound, 379 + ApplyCommitError::ConcurrentModification => CommitError::ConcurrentModification, 380 + ApplyCommitError::Database(msg) => CommitError::DatabaseError(msg), 321 381 })?; 322 382 323 383 if result.is_account_active { ··· 335 395 collection: &Nsid, 336 396 rkey: &Rkey, 337 397 record: &serde_json::Value, 338 - ) -> Result<(String, Cid), String> { 398 + ) -> Result<(String, Cid), CommitError> { 339 399 use crate::repo::tracking::TrackingBlockStore; 340 400 use jacquard_repo::mst::Mst; 341 401 use std::sync::Arc; ··· 343 403 .user_repo 344 404 .get_id_by_did(did) 345 405 .await 346 - .map_err(|e| format!("DB error: {}", e))? 347 - .ok_or_else(|| "User not found".to_string())?; 406 + .map_err(|e| CommitError::DatabaseError(e.to_string()))? 407 + .ok_or(CommitError::UserNotFound)?; 348 408 349 409 let _write_lock = state.repo_write_locks.lock(user_id).await; 350 410 ··· 352 412 .repo_repo 353 413 .get_repo_root_cid_by_user_id(user_id) 354 414 .await 355 - .map_err(|e| format!("DB error: {}", e))? 356 - .ok_or_else(|| "Repo not found".to_string())?; 357 - let current_root_cid = 358 - Cid::from_str(root_cid_link.as_str()).map_err(|_| "Invalid repo root CID".to_string())?; 415 + .map_err(|e| CommitError::DatabaseError(e.to_string()))? 416 + .ok_or(CommitError::RepoNotFound)?; 417 + let current_root_cid = Cid::from_str(root_cid_link.as_str()) 418 + .map_err(|e| CommitError::InvalidCid(e.to_string()))?; 359 419 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 360 420 let commit_bytes = tracking_store 361 421 .get(&current_root_cid) 362 422 .await 363 - .map_err(|e| format!("Failed to fetch commit: {:?}", e))? 364 - .ok_or_else(|| "Commit block not found".to_string())?; 423 + .map_err(|e| CommitError::BlockStoreFailed(format!("{:?}", e)))? 424 + .ok_or(CommitError::BlockStoreFailed( 425 + "Commit block not found".into(), 426 + ))?; 365 427 let commit = jacquard_repo::commit::Commit::from_cbor(&commit_bytes) 366 - .map_err(|e| format!("Failed to parse commit: {:?}", e))?; 428 + .map_err(|e| CommitError::CommitParseFailed(format!("{:?}", e)))?; 367 429 let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 368 430 let record_ipld = crate::util::json_to_ipld(record); 369 431 let mut record_bytes = Vec::new(); 370 432 serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld) 371 - .map_err(|e| format!("Failed to serialize record: {:?}", e))?; 433 + .map_err(|e| CommitError::RecordSerializationFailed(format!("{:?}", e)))?; 372 434 let record_cid = tracking_store 373 435 .put(&record_bytes) 374 436 .await 375 - .map_err(|e| format!("Failed to save record block: {:?}", e))?; 437 + .map_err(|e| CommitError::BlockStoreFailed(format!("{:?}", e)))?; 376 438 let key = format!("{}/{}", collection, rkey); 377 439 let new_mst = mst 378 440 .add(&key, record_cid) 379 441 .await 380 - .map_err(|e| format!("Failed to add to MST: {:?}", e))?; 442 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 381 443 let new_mst_root = new_mst 382 444 .persist() 383 445 .await 384 - .map_err(|e| format!("Failed to persist MST: {:?}", e))?; 446 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 385 447 let op = RecordOp::Create { 386 448 collection: collection.clone(), 387 449 rkey: rkey.clone(), ··· 392 454 new_mst 393 455 .blocks_for_path(&key, &mut new_mst_blocks) 394 456 .await 395 - .map_err(|e| format!("Failed to get new MST blocks for path: {:?}", e))?; 457 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 396 458 mst.blocks_for_path(&key, &mut old_mst_blocks) 397 459 .await 398 - .map_err(|e| format!("Failed to get old MST blocks for path: {:?}", e))?; 460 + .map_err(|e| CommitError::MstOperationFailed(format!("{:?}", e)))?; 399 461 let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 400 462 .chain( 401 463 old_mst_blocks ··· 439 501 state: &AppState, 440 502 did: &Did, 441 503 handle: Option<&Handle>, 442 - ) -> Result<SequenceNumber, String> { 504 + ) -> Result<SequenceNumber, CommitError> { 443 505 state 444 506 .repo_repo 445 507 .insert_identity_event(did, handle) 446 508 .await 447 - .map_err(|e| format!("DB Error (identity event): {}", e)) 509 + .map_err(|e| CommitError::DatabaseError(format!("identity event: {}", e))) 448 510 } 449 511 pub async fn sequence_account_event( 450 512 state: &AppState, 451 513 did: &Did, 452 514 status: tranquil_db_traits::AccountStatus, 453 - ) -> Result<SequenceNumber, String> { 515 + ) -> Result<SequenceNumber, CommitError> { 454 516 state 455 517 .repo_repo 456 518 .insert_account_event(did, status) 457 519 .await 458 - .map_err(|e| format!("DB Error (account event): {}", e)) 520 + .map_err(|e| CommitError::DatabaseError(format!("account event: {}", e))) 459 521 } 460 522 pub async fn sequence_sync_event( 461 523 state: &AppState, 462 524 did: &Did, 463 525 commit_cid: &str, 464 526 rev: Option<&str>, 465 - ) -> Result<SequenceNumber, String> { 466 - let cid_link = unsafe { crate::types::CidLink::new_unchecked(commit_cid) }; 527 + ) -> Result<SequenceNumber, CommitError> { 528 + let cid_link: crate::types::CidLink = commit_cid 529 + .parse() 530 + .map_err(|_| CommitError::InvalidCid(commit_cid.to_string()))?; 467 531 state 468 532 .repo_repo 469 533 .insert_sync_event(did, &cid_link, rev) 470 534 .await 471 - .map_err(|e| format!("DB Error (sync event): {}", e)) 535 + .map_err(|e| CommitError::DatabaseError(format!("sync event: {}", e))) 472 536 } 473 537 474 538 pub async fn sequence_genesis_commit( ··· 477 541 commit_cid: &Cid, 478 542 mst_root_cid: &Cid, 479 543 rev: &str, 480 - ) -> Result<SequenceNumber, String> { 481 - let commit_cid_link = unsafe { crate::types::CidLink::new_unchecked(commit_cid.to_string()) }; 482 - let mst_root_cid_link = 483 - unsafe { crate::types::CidLink::new_unchecked(mst_root_cid.to_string()) }; 544 + ) -> Result<SequenceNumber, CommitError> { 545 + let commit_cid_link = crate::types::CidLink::from(commit_cid); 546 + let mst_root_cid_link = crate::types::CidLink::from(mst_root_cid); 484 547 state 485 548 .repo_repo 486 549 .insert_genesis_commit_event(did, &commit_cid_link, &mst_root_cid_link, rev) 487 550 .await 488 - .map_err(|e| format!("DB Error (genesis commit event): {}", e)) 551 + .map_err(|e| CommitError::DatabaseError(format!("genesis commit event: {}", e))) 489 552 }
+11 -16
crates/tranquil-pds/src/api/repo/record/write.rs
··· 6 6 get_current_root_cid, 7 7 }; 8 8 use crate::auth::{ 9 - Active, Auth, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated, 9 + Active, Auth, AuthSource, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated, 10 10 require_verified_or_delegated, 11 11 }; 12 12 use crate::cid_types::CommitCid; ··· 14 14 use crate::repo::tracking::TrackingBlockStore; 15 15 use crate::state::AppState; 16 16 use crate::types::{AtIdentifier, AtUri, Did, Nsid, Rkey}; 17 + use crate::validation::ValidationStatus; 17 18 use axum::{ 18 19 Json, 19 20 extract::State, ··· 32 33 pub struct RepoWriteAuth { 33 34 pub did: Did, 34 35 pub user_id: Uuid, 35 - pub is_oauth: bool, 36 + pub auth_source: AuthSource, 36 37 pub scope: Option<String>, 37 38 pub controller_did: Option<Did>, 38 39 } ··· 66 67 Ok(RepoWriteAuth { 67 68 did: principal_did.into_did(), 68 69 user_id, 69 - is_oauth: user.is_oauth(), 70 + auth_source: user.auth_source.clone(), 70 71 scope: user.scope.clone(), 71 72 controller_did: scope_proof.controller_did().map(|c| c.into_did()), 72 73 }) ··· 97 98 pub cid: String, 98 99 pub commit: CommitInfo, 99 100 #[serde(skip_serializing_if = "Option::is_none")] 100 - pub validation_status: Option<String>, 101 + pub validation_status: Option<ValidationStatus>, 101 102 } 102 103 pub async fn create_record( 103 104 State(state): State<AppState>, ··· 323 324 .await 324 325 { 325 326 Ok(res) => res, 326 - Err(e) if e.contains("ConcurrentModification") => { 327 - return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 328 - } 329 - Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 327 + Err(e) => return Ok(ApiError::from(e).into_response()), 330 328 }; 331 329 332 330 for conflict_uri in conflict_uris_to_cleanup { ··· 375 373 cid: commit_result.commit_cid.to_string(), 376 374 rev: commit_result.rev, 377 375 }, 378 - validation_status: validation_status.map(|s| s.to_string()), 376 + validation_status, 379 377 }), 380 378 ) 381 379 .into_response()) ··· 402 400 #[serde(skip_serializing_if = "Option::is_none")] 403 401 pub commit: Option<CommitInfo>, 404 402 #[serde(skip_serializing_if = "Option::is_none")] 405 - pub validation_status: Option<String>, 403 + pub validation_status: Option<ValidationStatus>, 406 404 } 407 405 pub async fn put_record( 408 406 State(state): State<AppState>, ··· 494 492 uri: AtUri::from_parts(&did, &input.collection, &input.rkey), 495 493 cid: record_cid.to_string(), 496 494 commit: None, 497 - validation_status: validation_status.map(|s| s.to_string()), 495 + validation_status, 498 496 }), 499 497 ) 500 498 .into_response()); ··· 600 598 .await 601 599 { 602 600 Ok(res) => res, 603 - Err(e) if e.contains("ConcurrentModification") => { 604 - return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 605 - } 606 - Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 601 + Err(e) => return Ok(ApiError::from(e).into_response()), 607 602 }; 608 603 609 604 if let Some(ref controller) = controller_did { ··· 634 629 cid: commit_result.commit_cid.to_string(), 635 630 rev: commit_result.rev, 636 631 }), 637 - validation_status: validation_status.map(|s| s.to_string()), 632 + validation_status, 638 633 }), 639 634 ) 640 635 .into_response())
+17 -8
crates/tranquil-pds/src/api/server/account_status.rs
··· 285 285 arr.iter().find(|svc| { 286 286 svc.get("id").and_then(|id| id.as_str()) == Some("#atproto_pds") 287 287 || svc.get("type").and_then(|t| t.as_str()) 288 - == Some("AtprotoPersonalDataServer") 288 + == Some(crate::plc::ServiceType::Pds.as_str()) 289 289 }) 290 290 }) 291 291 .and_then(|svc| svc.get("serviceEndpoint")) ··· 316 316 ); 317 317 318 318 if let Err(e) = crate::auth::scope_check::check_account_scope( 319 - auth.is_oauth(), 319 + &auth.auth_source, 320 320 auth.scope.as_deref(), 321 321 crate::oauth::scopes::AccountAttr::Repo, 322 322 crate::oauth::scopes::AccountAction::Manage, ··· 366 366 did 367 367 ); 368 368 if let Some(ref h) = handle { 369 - let _ = state.cache.delete(&format!("handle:{}", h)).await; 369 + let _ = state.cache.delete(&crate::cache_keys::handle_key(h)).await; 370 370 } 371 - let _ = state.cache.delete(&format!("plc:doc:{}", did)).await; 372 - let _ = state.cache.delete(&format!("plc:data:{}", did)).await; 371 + let _ = state 372 + .cache 373 + .delete(&crate::cache_keys::plc_doc_key(&did)) 374 + .await; 375 + let _ = state 376 + .cache 377 + .delete(&crate::cache_keys::plc_data_key(&did)) 378 + .await; 373 379 if state.did_resolver.refresh_did(did.as_str()).await.is_none() { 374 380 warn!( 375 381 "[MIGRATION] activateAccount: Failed to refresh DID cache for {}", ··· 479 485 Json(input): Json<DeactivateAccountInput>, 480 486 ) -> Result<Response, ApiError> { 481 487 if let Err(e) = crate::auth::scope_check::check_account_scope( 482 - auth.is_oauth(), 488 + &auth.auth_source, 483 489 auth.scope.as_deref(), 484 490 crate::oauth::scopes::AccountAttr::Repo, 485 491 crate::oauth::scopes::AccountAction::Manage, ··· 502 508 match result { 503 509 Ok(true) => { 504 510 if let Some(ref h) = handle { 505 - let _ = state.cache.delete(&format!("handle:{}", h)).await; 511 + let _ = state.cache.delete(&crate::cache_keys::handle_key(h)).await; 506 512 } 507 513 if let Err(e) = crate::api::repo::record::sequence_account_event( 508 514 &state, ··· 659 665 ); 660 666 } 661 667 } 662 - let _ = state.cache.delete(&format!("handle:{}", handle)).await; 668 + let _ = state 669 + .cache 670 + .delete(&crate::cache_keys::handle_key(&handle)) 671 + .await; 663 672 info!("Account {} deleted successfully", did); 664 673 EmptyResponse::ok().into_response() 665 674 }
+4 -3
crates/tranquil-pds/src/api/server/app_password.rs
··· 150 150 ApiError::InternalError(None) 151 151 })?; 152 152 153 - let privilege = 154 - tranquil_db_traits::AppPasswordPrivilege::from(input.privileged.unwrap_or(false)); 153 + let privilege = tranquil_db_traits::AppPasswordPrivilege::from_privileged_flag( 154 + input.privileged.unwrap_or(false), 155 + ); 155 156 let created_at = chrono::Utc::now(); 156 157 157 158 let create_data = AppPasswordCreate { ··· 232 233 .log_db_err("revoking sessions for app password")?; 233 234 234 235 futures::future::join_all(sessions_to_invalidate.iter().map(|jti| { 235 - let cache_key = format!("auth:session:{}:{}", &auth.did, jti); 236 + let cache_key = crate::cache_keys::session_key(&auth.did, jti); 236 237 let cache = state.cache.clone(); 237 238 async move { 238 239 let _ = cache.delete(&cache_key).await;
+22 -44
crates/tranquil-pds/src/api/server/email.rs
··· 21 21 const EMAIL_UPDATE_TTL: Duration = Duration::from_secs(30 * 60); 22 22 23 23 fn email_update_cache_key(did: &str) -> String { 24 - format!("email_update:{}", did) 24 + crate::cache_keys::email_update_key(did) 25 25 } 26 26 27 27 fn hash_token(token: &str) -> String { ··· 51 51 input: Option<Json<RequestEmailUpdateInput>>, 52 52 ) -> Result<Response, ApiError> { 53 53 if let Err(e) = crate::auth::scope_check::check_account_scope( 54 - auth.is_oauth(), 54 + &auth.auth_source, 55 55 auth.scope.as_deref(), 56 56 crate::oauth::scopes::AccountAttr::Email, 57 57 crate::oauth::scopes::AccountAction::Manage, ··· 111 111 state.infra_repo.as_ref(), 112 112 user.id, 113 113 &token, 114 - "email_update", 115 114 hostname, 116 115 ) 117 116 .await ··· 138 137 Json(input): Json<ConfirmEmailInput>, 139 138 ) -> Result<Response, ApiError> { 140 139 if let Err(e) = crate::auth::scope_check::check_account_scope( 141 - auth.is_oauth(), 140 + &auth.auth_source, 142 141 auth.scope.as_deref(), 143 142 crate::oauth::scopes::AccountAttr::Email, 144 143 crate::oauth::scopes::AccountAction::Manage, ··· 173 172 174 173 let verified = crate::auth::verification_token::verify_signup_token( 175 174 &confirmation_code, 176 - "email", 175 + CommsChannel::Email, 177 176 &provided_email, 178 177 ); 179 178 180 179 match verified { 181 180 Ok(token_data) => { 182 - if token_data.did != did.as_str() { 181 + if token_data.did != *did { 183 182 return Err(ApiError::InvalidToken(None)); 184 183 } 185 184 } ··· 216 215 Json(input): Json<UpdateEmailInput>, 217 216 ) -> Result<Response, ApiError> { 218 217 if let Err(e) = crate::auth::scope_check::check_account_scope( 219 - auth.is_oauth(), 218 + &auth.auth_source, 220 219 auth.scope.as_deref(), 221 220 crate::oauth::scopes::AccountAttr::Email, 222 221 crate::oauth::scopes::AccountAction::Manage, ··· 324 323 325 324 let verified = crate::auth::verification_token::verify_channel_update_token( 326 325 &confirmation_token, 327 - "email_update", 326 + CommsChannel::Email, 328 327 &current_email_lower, 329 328 ); 330 329 331 330 match verified { 332 331 Ok(token_data) => { 333 - if token_data.did != did.as_str() { 332 + if token_data.did != *did { 334 333 return Err(ApiError::InvalidToken(None)); 335 334 } 336 335 } ··· 361 360 .await 362 361 .log_db_err("updating email")?; 363 362 364 - let verification_token = 365 - crate::auth::verification_token::generate_signup_token(did, "email", &new_email); 363 + let verification_token = crate::auth::verification_token::generate_signup_token( 364 + did, 365 + CommsChannel::Email, 366 + &new_email, 367 + ); 366 368 let formatted_token = 367 369 crate::auth::verification_token::format_token_for_display(&verification_token); 368 370 let hostname = pds_hostname(); ··· 370 372 state.user_repo.as_ref(), 371 373 state.infra_repo.as_ref(), 372 374 user_id, 373 - "email", 375 + tranquil_db_traits::CommsChannel::Email, 374 376 &new_email, 375 377 &formatted_token, 376 378 hostname, ··· 422 424 423 425 #[derive(Deserialize)] 424 426 pub struct CheckChannelVerifiedInput { 425 - pub did: String, 426 - pub channel: String, 427 + pub did: crate::types::Did, 428 + pub channel: CommsChannel, 427 429 } 428 430 429 431 pub async fn check_channel_verified( ··· 431 433 _rate_limit: RateLimited<VerificationCheckLimit>, 432 434 Json(input): Json<CheckChannelVerifiedInput>, 433 435 ) -> Response { 434 - let channel = match input.channel.to_lowercase().as_str() { 435 - "email" => CommsChannel::Email, 436 - "discord" => CommsChannel::Discord, 437 - "telegram" => CommsChannel::Telegram, 438 - "signal" => CommsChannel::Signal, 439 - _ => { 440 - return ApiError::InvalidRequest("invalid channel".into()).into_response(); 441 - } 442 - }; 443 - 444 - let did = match crate::Did::new(input.did) { 445 - Ok(d) => d, 446 - Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 447 - }; 448 436 match state 449 437 .user_repo 450 - .check_channel_verified_by_did(&did, channel) 438 + .check_channel_verified_by_did(&input.did, input.channel) 451 439 .await 452 440 { 453 441 Ok(Some(verified)) => VerifiedResponse::response(verified).into_response(), ··· 490 478 ); 491 479 return ApiError::InvalidToken(None).into_response(); 492 480 } 493 - if token_data.channel != "email_update" { 481 + if token_data.channel != CommsChannel::Email { 494 482 warn!( 495 - "authorize_email_update: wrong channel: {}", 483 + "authorize_email_update: wrong channel: {:?}", 496 484 token_data.channel 497 485 ); 498 486 return ApiError::InvalidToken(None).into_response(); ··· 558 546 auth: Auth<NotTakendown>, 559 547 ) -> Result<Response, ApiError> { 560 548 if let Err(e) = crate::auth::scope_check::check_account_scope( 561 - auth.is_oauth(), 549 + &auth.auth_source, 562 550 auth.scope.as_deref(), 563 551 crate::oauth::scopes::AccountAttr::Email, 564 552 crate::oauth::scopes::AccountAction::Read, ··· 620 608 621 609 #[derive(Deserialize)] 622 610 pub struct CheckCommsChannelInUseInput { 623 - pub channel: String, 611 + pub channel: CommsChannel, 624 612 pub identifier: String, 625 613 } 626 614 ··· 629 617 _rate_limit: RateLimited<VerificationCheckLimit>, 630 618 Json(input): Json<CheckCommsChannelInUseInput>, 631 619 ) -> Response { 632 - let channel = match input.channel.to_lowercase().as_str() { 633 - "email" => CommsChannel::Email, 634 - "discord" => CommsChannel::Discord, 635 - "telegram" => CommsChannel::Telegram, 636 - "signal" => CommsChannel::Signal, 637 - _ => { 638 - return ApiError::InvalidRequest("invalid channel".into()).into_response(); 639 - } 640 - }; 641 - 642 620 let identifier = input.identifier.trim(); 643 621 if identifier.is_empty() { 644 622 return ApiError::InvalidRequest("identifier is required".into()).into_response(); ··· 646 624 647 625 let count = match state 648 626 .user_repo 649 - .count_accounts_by_comms_identifier(channel, identifier) 627 + .count_accounts_by_comms_identifier(input.channel, identifier) 650 628 .await 651 629 { 652 630 Ok(c) => c,
+1 -1
crates/tranquil-pds/src/api/server/invite.rs
··· 226 226 }) 227 227 .unwrap_or_default(); 228 228 229 - let use_count = uses.len() as i32; 229 + let use_count = i32::try_from(uses.len()).unwrap_or(i32::MAX); 230 230 if !include_used && use_count >= info.available_uses { 231 231 return None; 232 232 }
+5 -2
crates/tranquil-pds/src/api/server/logo.rs
··· 21 21 Some(c) if !c.is_empty() => c, 22 22 _ => return StatusCode::NOT_FOUND.into_response(), 23 23 }; 24 - let cid = unsafe { crate::types::CidLink::new_unchecked(&cid_str) }; 24 + let cid = match crate::types::CidLink::new(&cid_str) { 25 + Ok(c) => c, 26 + Err(_) => return StatusCode::NOT_FOUND.into_response(), 27 + }; 25 28 26 29 let metadata = match state.blob_repo.get_blob_metadata(&cid).await { 27 30 Ok(Some(m)) => m, ··· 38 41 .header(header::CONTENT_TYPE, &metadata.mime_type) 39 42 .header(header::CACHE_CONTROL, "public, max-age=3600") 40 43 .body(Body::from(data)) 41 - .unwrap(), 44 + .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response()), 42 45 Err(e) => { 43 46 error!("Failed to fetch logo from storage: {:?}", e); 44 47 StatusCode::NOT_FOUND.into_response()
+7 -8
crates/tranquil-pds/src/api/server/meta.rs
··· 3 3 use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; 4 4 use serde_json::json; 5 5 6 - fn get_available_comms_channels() -> Vec<&'static str> { 7 - let mut channels = vec!["email"]; 6 + fn get_available_comms_channels() -> Vec<tranquil_db_traits::CommsChannel> { 7 + use tranquil_db_traits::CommsChannel; 8 + let mut channels = vec![CommsChannel::Email]; 8 9 if std::env::var("DISCORD_BOT_TOKEN").is_ok() { 9 - channels.push("discord"); 10 + channels.push(CommsChannel::Discord); 10 11 } 11 12 if std::env::var("TELEGRAM_BOT_TOKEN").is_ok() { 12 - channels.push("telegram"); 13 + channels.push(CommsChannel::Telegram); 13 14 } 14 15 if std::env::var("SIGNAL_CLI_PATH").is_ok() && std::env::var("SIGNAL_SENDER_NUMBER").is_ok() { 15 - channels.push("signal"); 16 + channels.push(CommsChannel::Signal); 16 17 } 17 18 channels 18 19 } ··· 35 36 let domains_str = 36 37 std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.to_string()); 37 38 let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect(); 38 - let invite_code_required = std::env::var("INVITE_CODE_REQUIRED") 39 - .map(|v| v == "true" || v == "1") 40 - .unwrap_or(false); 39 + let invite_code_required = crate::util::parse_env_bool("INVITE_CODE_REQUIRED"); 41 40 let privacy_policy = std::env::var("PRIVACY_POLICY_URL").ok(); 42 41 let terms_of_service = std::env::var("TERMS_OF_SERVICE_URL").ok(); 43 42 let contact_email = std::env::var("CONTACT_EMAIL").ok();
+2 -2
crates/tranquil-pds/src/api/server/migration.rs
··· 196 196 })).collect::<Vec<_>>(), 197 197 "service": [{ 198 198 "id": "#atproto_pds", 199 - "type": "AtprotoPersonalDataServer", 199 + "type": crate::plc::ServiceType::Pds.as_str(), 200 200 "serviceEndpoint": service_endpoint 201 201 }] 202 202 }); ··· 244 244 }], 245 245 "service": [{ 246 246 "id": "#atproto_pds", 247 - "type": "AtprotoPersonalDataServer", 247 + "type": crate::plc::ServiceType::Pds.as_str(), 248 248 "serviceEndpoint": service_endpoint 249 249 }] 250 250 })
+27 -31
crates/tranquil-pds/src/api/server/passkey_account.rs
··· 16 16 use serde_json::json; 17 17 use std::sync::Arc; 18 18 use tracing::{debug, error, info, warn}; 19 + use tranquil_db_traits::WebauthnChallengeType; 19 20 use uuid::Uuid; 20 21 21 22 use crate::api::repo::record::utils::create_signed_commit; 22 23 use crate::auth::{ServiceTokenVerifier, generate_app_password, is_service_token}; 23 24 use crate::rate_limit::{AccountCreationLimit, PasswordResetLimit, RateLimited}; 24 25 use crate::state::AppState; 25 - use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey}; 26 + use crate::types::{Did, Handle, PlainPassword}; 26 27 use crate::util::{pds_hostname, pds_hostname_without_port}; 27 28 use crate::validation::validate_password; 28 29 ··· 49 50 pub did: Option<String>, 50 51 pub did_type: Option<String>, 51 52 pub signing_key: Option<String>, 52 - pub verification_channel: Option<String>, 53 + pub verification_channel: Option<tranquil_db_traits::CommsChannel>, 53 54 pub discord_username: Option<String>, 54 55 pub telegram_username: Option<String>, 55 56 pub signal_username: Option<String>, ··· 73 74 Json(input): Json<CreatePasskeyAccountInput>, 74 75 ) -> Response { 75 76 let byod_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header( 76 - crate::util::get_header_str(&headers, "Authorization"), 77 + crate::util::get_header_str(&headers, http::header::AUTHORIZATION), 77 78 ) { 78 79 let token = extracted.token; 79 80 if is_service_token(&token) { ··· 152 153 Err(_) => return ApiError::InvalidInviteCode.into_response(), 153 154 } 154 155 } else { 155 - let invite_required = std::env::var("INVITE_CODE_REQUIRED") 156 - .map(|v| v == "true" || v == "1") 157 - .unwrap_or(false); 156 + let invite_required = crate::util::parse_env_bool("INVITE_CODE_REQUIRED"); 158 157 if invite_required { 159 158 return ApiError::InviteCodeRequired.into_response(); 160 159 } 161 160 None 162 161 }; 163 162 164 - let verification_channel = input.verification_channel.as_deref().unwrap_or("email"); 163 + let verification_channel = input 164 + .verification_channel 165 + .unwrap_or(tranquil_db_traits::CommsChannel::Email); 165 166 let verification_recipient = match verification_channel { 166 - "email" => match &email { 167 + tranquil_db_traits::CommsChannel::Email => match &email { 167 168 Some(e) if !e.is_empty() => e.clone(), 168 169 _ => return ApiError::MissingEmail.into_response(), 169 170 }, 170 - "discord" => match &input.discord_username { 171 + tranquil_db_traits::CommsChannel::Discord => match &input.discord_username { 171 172 Some(username) if !username.trim().is_empty() => { 172 173 let clean = username.trim().to_lowercase(); 173 174 if !crate::api::validation::is_valid_discord_username(&clean) { ··· 179 180 } 180 181 _ => return ApiError::MissingDiscordId.into_response(), 181 182 }, 182 - "telegram" => match &input.telegram_username { 183 + tranquil_db_traits::CommsChannel::Telegram => match &input.telegram_username { 183 184 Some(username) if !username.trim().is_empty() => { 184 185 let clean = username.trim().trim_start_matches('@'); 185 186 if !crate::api::validation::is_valid_telegram_username(clean) { ··· 191 192 } 192 193 _ => return ApiError::MissingTelegramUsername.into_response(), 193 194 }, 194 - "signal" => match &input.signal_username { 195 + tranquil_db_traits::CommsChannel::Signal => match &input.signal_username { 195 196 Some(username) if !username.trim().is_empty() => { 196 197 username.trim().trim_start_matches('@').to_lowercase() 197 198 } 198 199 _ => return ApiError::MissingSignalNumber.into_response(), 199 200 }, 200 - _ => return ApiError::InvalidVerificationChannel.into_response(), 201 201 }; 202 202 203 203 use k256::ecdsa::SigningKey; ··· 277 277 ) 278 278 .await 279 279 { 280 - return ApiError::InvalidDid(e).into_response(); 280 + return ApiError::InvalidDid(e.to_string()).into_response(); 281 281 } 282 282 info!(did = %d, "Creating external did:web passkey account (reserved key)"); 283 283 } ··· 380 380 } 381 381 }; 382 382 let rev = Tid::now(LimitedU32::MIN); 383 - let did_typed = unsafe { Did::new_unchecked(&did) }; 383 + let did_typed: Did = match did.parse() { 384 + Ok(d) => d, 385 + Err(_) => return ApiError::InternalError(Some("Invalid DID".into())).into_response(), 386 + }; 384 387 let (commit_bytes, _sig) = 385 388 match create_signed_commit(&did_typed, mst_root, rev.as_ref(), None, &secret_key) { 386 389 Ok(result) => result, ··· 405 408 }) 406 409 }); 407 410 408 - let preferred_comms_channel = match verification_channel { 409 - "email" => tranquil_db_traits::CommsChannel::Email, 410 - "discord" => tranquil_db_traits::CommsChannel::Discord, 411 - "telegram" => tranquil_db_traits::CommsChannel::Telegram, 412 - "signal" => tranquil_db_traits::CommsChannel::Signal, 413 - _ => tranquil_db_traits::CommsChannel::Email, 411 + let handle_typed: Handle = match handle.parse() { 412 + Ok(h) => h, 413 + Err(_) => return ApiError::InvalidHandle(None).into_response(), 414 414 }; 415 - 416 - let handle_typed = unsafe { Handle::new_unchecked(&handle) }; 417 415 let create_input = tranquil_db_traits::CreatePasskeyAccountInput { 418 416 handle: handle_typed.clone(), 419 417 email: email.clone().unwrap_or_default(), 420 418 did: did_typed.clone(), 421 - preferred_comms_channel, 419 + preferred_comms_channel: verification_channel, 422 420 discord_username: input 423 421 .discord_username 424 422 .as_deref() ··· 487 485 "$type": "app.bsky.actor.profile", 488 486 "displayName": handle 489 487 }); 490 - let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") }; 491 - let profile_rkey = unsafe { Rkey::new_unchecked("self") }; 492 488 if let Err(e) = crate::api::repo::record::create_record_internal( 493 489 &state, 494 490 &did_typed, 495 - &profile_collection, 496 - &profile_rkey, 491 + &crate::types::PROFILE_COLLECTION, 492 + &crate::types::PROFILE_RKEY, 497 493 &profile_record, 498 494 ) 499 495 .await ··· 503 499 } 504 500 505 501 let verification_token = crate::auth::verification_token::generate_signup_token( 506 - &did, 502 + &did_typed, 507 503 verification_channel, 508 504 &verification_recipient, 509 505 ); ··· 625 621 626 622 let reg_state = match state 627 623 .user_repo 628 - .load_webauthn_challenge(&input.did, "registration") 624 + .load_webauthn_challenge(&input.did, WebauthnChallengeType::Registration) 629 625 .await 630 626 { 631 627 Ok(Some(json)) => match serde_json::from_str(&json) { ··· 706 702 707 703 let _ = state 708 704 .user_repo 709 - .delete_webauthn_challenge(&input.did, "registration") 705 + .delete_webauthn_challenge(&input.did, WebauthnChallengeType::Registration) 710 706 .await; 711 707 712 708 info!(did = %input.did, "Passkey-only account setup completed"); ··· 793 789 }; 794 790 if let Err(e) = state 795 791 .user_repo 796 - .save_webauthn_challenge(&input.did, "registration", &state_json) 792 + .save_webauthn_challenge(&input.did, WebauthnChallengeType::Registration, &state_json) 797 793 .await 798 794 { 799 795 error!("Failed to save registration state: {:?}", e);
+4 -3
crates/tranquil-pds/src/api/server/passkeys.rs
··· 9 9 }; 10 10 use serde::{Deserialize, Serialize}; 11 11 use tracing::{error, info, warn}; 12 + use tranquil_db_traits::WebauthnChallengeType; 12 13 use webauthn_rs::prelude::*; 13 14 14 15 #[derive(Deserialize)] ··· 64 65 65 66 state 66 67 .user_repo 67 - .save_webauthn_challenge(&auth.did, "registration", &state_json) 68 + .save_webauthn_challenge(&auth.did, WebauthnChallengeType::Registration, &state_json) 68 69 .await 69 70 .log_db_err("saving registration state")?; 70 71 ··· 98 99 99 100 let reg_state_json = state 100 101 .user_repo 101 - .load_webauthn_challenge(&auth.did, "registration") 102 + .load_webauthn_challenge(&auth.did, WebauthnChallengeType::Registration) 102 103 .await 103 104 .log_db_err("loading registration state")? 104 105 .ok_or(ApiError::NoRegistrationInProgress)?; ··· 140 141 141 142 if let Err(e) = state 142 143 .user_repo 143 - .delete_webauthn_challenge(&auth.did, "registration") 144 + .delete_webauthn_challenge(&auth.did, WebauthnChallengeType::Registration) 144 145 .await 145 146 { 146 147 warn!("Failed to delete registration state: {:?}", e);
+1 -1
crates/tranquil-pds/src/api/server/password.rs
··· 171 171 } 172 172 }; 173 173 futures::future::join_all(result.session_jtis.iter().map(|jti| { 174 - let cache_key = format!("auth:session:{}:{}", result.did, jti); 174 + let cache_key = crate::cache_keys::session_key(&result.did, jti); 175 175 let cache = state.cache.clone(); 176 176 async move { 177 177 if let Err(e) = cache.delete(&cache_key).await {
+31 -16
crates/tranquil-pds/src/api/server/reauth.rs
··· 8 8 use chrono::{DateTime, Utc}; 9 9 use serde::{Deserialize, Serialize}; 10 10 use tracing::{error, info, warn}; 11 - use tranquil_db_traits::{SessionRepository, UserRepository}; 11 + use tranquil_db_traits::{SessionRepository, UserRepository, WebauthnChallengeType}; 12 12 13 13 use crate::auth::{Active, Auth}; 14 14 use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message}; ··· 17 17 18 18 pub const REAUTH_WINDOW_SECONDS: i64 = 300; 19 19 20 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] 21 + #[serde(rename_all = "lowercase")] 22 + pub enum ReauthMethod { 23 + Password, 24 + Totp, 25 + Passkey, 26 + } 27 + 20 28 #[derive(Serialize)] 21 29 #[serde(rename_all = "camelCase")] 22 30 pub struct ReauthStatusResponse { 23 31 pub last_reauth_at: Option<DateTime<Utc>>, 24 32 pub reauth_required: bool, 25 - pub available_methods: Vec<String>, 33 + pub available_methods: Vec<ReauthMethod>, 26 34 } 27 35 28 36 pub async fn get_reauth_status( ··· 180 188 181 189 state 182 190 .user_repo 183 - .save_webauthn_challenge(&auth.did, "authentication", &state_json) 191 + .save_webauthn_challenge( 192 + &auth.did, 193 + WebauthnChallengeType::Authentication, 194 + &state_json, 195 + ) 184 196 .await 185 197 .log_db_err("saving authentication state")?; 186 198 ··· 201 213 ) -> Result<Response, ApiError> { 202 214 let auth_state_json = state 203 215 .user_repo 204 - .load_webauthn_challenge(&auth.did, "authentication") 216 + .load_webauthn_challenge(&auth.did, WebauthnChallengeType::Authentication) 205 217 .await 206 218 .log_db_err("loading authentication state")? 207 219 .ok_or(ApiError::NoChallengeInProgress)?; ··· 229 241 let cred_id_bytes = auth_result.cred_id().as_ref(); 230 242 match state 231 243 .user_repo 232 - .update_passkey_counter(cred_id_bytes, auth_result.counter() as i32) 244 + .update_passkey_counter( 245 + cred_id_bytes, 246 + i32::try_from(auth_result.counter()).unwrap_or(i32::MAX), 247 + ) 233 248 .await 234 249 { 235 250 Ok(false) => { 236 251 warn!(did = %&auth.did, "Passkey counter anomaly detected - possible cloned key"); 237 252 let _ = state 238 253 .user_repo 239 - .delete_webauthn_challenge(&auth.did, "authentication") 254 + .delete_webauthn_challenge(&auth.did, WebauthnChallengeType::Authentication) 240 255 .await; 241 256 return Err(ApiError::PasskeyCounterAnomaly); 242 257 } ··· 248 263 249 264 let _ = state 250 265 .user_repo 251 - .delete_webauthn_challenge(&auth.did, "authentication") 266 + .delete_webauthn_challenge(&auth.did, WebauthnChallengeType::Authentication) 252 267 .await; 253 268 254 269 let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) ··· 265 280 did: &crate::types::Did, 266 281 ) -> Result<DateTime<Utc>, tranquil_db_traits::DbError> { 267 282 let now = session_repo.update_last_reauth(did).await?; 268 - let cache_key = format!("reauth:{}", did); 283 + let cache_key = crate::cache_keys::reauth_key(did); 269 284 let _ = cache 270 285 .set( 271 286 &cache_key, 272 287 &now.timestamp().to_string(), 273 - std::time::Duration::from_secs(REAUTH_WINDOW_SECONDS as u64), 288 + std::time::Duration::from_secs(u64::try_from(REAUTH_WINDOW_SECONDS).unwrap_or(300)), 274 289 ) 275 290 .await; 276 291 Ok(now) ··· 290 305 user_repo: &dyn UserRepository, 291 306 _session_repo: &dyn SessionRepository, 292 307 did: &crate::types::Did, 293 - ) -> Vec<String> { 308 + ) -> Vec<ReauthMethod> { 294 309 let mut methods = Vec::new(); 295 310 296 311 let has_password = user_repo ··· 301 316 .is_some(); 302 317 303 318 if has_password { 304 - methods.push("password".to_string()); 319 + methods.push(ReauthMethod::Password); 305 320 } 306 321 307 322 let has_totp = user_repo.has_totp_enabled(did).await.unwrap_or(false); 308 323 if has_totp { 309 - methods.push("totp".to_string()); 324 + methods.push(ReauthMethod::Totp); 310 325 } 311 326 312 327 let has_passkeys = user_repo.has_passkeys(did).await.unwrap_or(false); 313 328 if has_passkeys { 314 - methods.push("passkey".to_string()); 329 + methods.push(ReauthMethod::Passkey); 315 330 } 316 331 317 332 methods ··· 332 347 cache: &std::sync::Arc<dyn crate::cache::Cache>, 333 348 did: &crate::types::Did, 334 349 ) -> bool { 335 - let cache_key = format!("reauth:{}", did); 350 + let cache_key = crate::cache_keys::reauth_key(did); 336 351 if let Some(timestamp_str) = cache.get(&cache_key).await 337 352 && let Ok(timestamp) = timestamp_str.parse::<i64>() 338 353 { ··· 355 370 pub struct ReauthRequiredError { 356 371 pub error: String, 357 372 pub message: String, 358 - pub reauth_methods: Vec<String>, 373 + pub reauth_methods: Vec<ReauthMethod>, 359 374 } 360 375 361 376 pub async fn reauth_required_response( ··· 428 443 pub struct MfaVerificationRequiredError { 429 444 pub error: String, 430 445 pub message: String, 431 - pub reauth_methods: Vec<String>, 446 + pub reauth_methods: Vec<ReauthMethod>, 432 447 }
+69 -51
crates/tranquil-pds/src/api/server/service_auth.rs
··· 2 2 use crate::api::error::ApiError; 3 3 use crate::state::AppState; 4 4 use crate::types::Did; 5 + use axum::http::Method; 5 6 use axum::{ 6 7 Json, 7 8 extract::{Query, State}, ··· 10 11 }; 11 12 use serde::{Deserialize, Serialize}; 12 13 use serde_json::json; 14 + use std::collections::HashSet; 15 + use std::sync::LazyLock; 13 16 use tracing::{error, info, warn}; 17 + use tranquil_types::Nsid; 18 + 19 + static CREATE_ACCOUNT_NSID: LazyLock<Nsid> = 20 + LazyLock::new(|| "com.atproto.server.createAccount".parse().unwrap()); 14 21 15 22 const HOUR_SECS: i64 = 3600; 16 23 const MINUTE_SECS: i64 = 60; 17 24 18 - const PROTECTED_METHODS: &[&str] = &[ 19 - "com.atproto.admin.sendEmail", 20 - "com.atproto.identity.requestPlcOperationSignature", 21 - "com.atproto.identity.signPlcOperation", 22 - "com.atproto.identity.updateHandle", 23 - "com.atproto.server.activateAccount", 24 - "com.atproto.server.confirmEmail", 25 - "com.atproto.server.createAppPassword", 26 - "com.atproto.server.deactivateAccount", 27 - "com.atproto.server.getAccountInviteCodes", 28 - "com.atproto.server.getSession", 29 - "com.atproto.server.listAppPasswords", 30 - "com.atproto.server.requestAccountDelete", 31 - "com.atproto.server.requestEmailConfirmation", 32 - "com.atproto.server.requestEmailUpdate", 33 - "com.atproto.server.revokeAppPassword", 34 - "com.atproto.server.updateEmail", 35 - ]; 25 + static PROTECTED_METHODS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| { 26 + [ 27 + "com.atproto.admin.sendEmail", 28 + "com.atproto.identity.requestPlcOperationSignature", 29 + "com.atproto.identity.signPlcOperation", 30 + "com.atproto.identity.updateHandle", 31 + "com.atproto.server.activateAccount", 32 + "com.atproto.server.confirmEmail", 33 + "com.atproto.server.createAppPassword", 34 + "com.atproto.server.deactivateAccount", 35 + "com.atproto.server.getAccountInviteCodes", 36 + "com.atproto.server.getSession", 37 + "com.atproto.server.listAppPasswords", 38 + "com.atproto.server.requestAccountDelete", 39 + "com.atproto.server.requestEmailConfirmation", 40 + "com.atproto.server.requestEmailUpdate", 41 + "com.atproto.server.revokeAppPassword", 42 + "com.atproto.server.updateEmail", 43 + ] 44 + .into_iter() 45 + .collect() 46 + }); 36 47 37 48 #[derive(Deserialize)] 38 49 pub struct GetServiceAuthParams { 39 - pub aud: String, 40 - pub lxm: Option<String>, 50 + pub aud: Did, 51 + pub lxm: Option<Nsid>, 41 52 pub exp: Option<i64>, 42 53 } 43 54 ··· 51 62 headers: axum::http::HeaderMap, 52 63 Query(params): Query<GetServiceAuthParams>, 53 64 ) -> Response { 54 - let auth_header = crate::util::get_header_str(&headers, "Authorization"); 55 - let dpop_proof = crate::util::get_header_str(&headers, "DPoP"); 65 + let auth_header = crate::util::get_header_str(&headers, axum::http::header::AUTHORIZATION); 66 + let dpop_proof = crate::util::get_header_str(&headers, crate::util::HEADER_DPOP); 56 67 info!( 57 68 has_auth_header = auth_header.is_some(), 58 69 has_dpop_proof = dpop_proof.is_some(), ··· 68 79 } 69 80 }; 70 81 71 - let (token, is_dpop) = if auth_header.len() >= 7 72 - && auth_header[..7].eq_ignore_ascii_case("bearer ") 73 - { 74 - (auth_header[7..].trim().to_string(), false) 75 - } else if auth_header.len() >= 5 && auth_header[..5].eq_ignore_ascii_case("dpop ") { 76 - (auth_header[5..].trim().to_string(), true) 77 - } else { 78 - warn!(auth_scheme = ?auth_header.split_whitespace().next(), "getServiceAuth: invalid auth scheme"); 79 - return ApiError::AuthenticationRequired.into_response(); 82 + let extracted = match crate::auth::extract_auth_token_from_header(Some(auth_header)) { 83 + Some(e) => e, 84 + None => { 85 + warn!(auth_scheme = ?auth_header.split_whitespace().next(), "getServiceAuth: invalid auth scheme"); 86 + return ApiError::AuthenticationRequired.into_response(); 87 + } 80 88 }; 89 + let token = extracted.token; 81 90 82 - let auth_user = if is_dpop { 91 + let auth_user = if extracted.scheme.is_dpop() { 83 92 match crate::oauth::verify::verify_oauth_access_token( 84 93 state.oauth_repo.as_ref(), 85 94 &token, 86 95 dpop_proof, 87 - "GET", 96 + Method::GET.as_str(), 88 97 &crate::util::build_full_url(&format!( 89 98 "/xrpc/com.atproto.server.getServiceAuth?aud={}&lxm={}", 90 99 params.aud, 91 - params.lxm.as_deref().unwrap_or("") 100 + params.lxm.as_ref().map_or("", |n| n.as_str()) 92 101 )), 93 102 ) 94 103 .await 95 104 { 96 - Ok(result) => crate::auth::AuthenticatedUser { 97 - did: unsafe { Did::new_unchecked(result.did) }, 98 - is_admin: false, 99 - status: AccountStatus::Active, 100 - scope: result.scope, 101 - key_bytes: None, 102 - controller_did: None, 103 - auth_source: crate::auth::AuthSource::OAuth, 104 - }, 105 + Ok(result) => { 106 + let did: Did = match result.did.parse() { 107 + Ok(d) => d, 108 + Err(_) => { 109 + return ApiError::InternalError(Some("Invalid DID in token".into())) 110 + .into_response(); 111 + } 112 + }; 113 + crate::auth::AuthenticatedUser { 114 + did, 115 + is_admin: false, 116 + status: AccountStatus::Active, 117 + scope: result.scope, 118 + key_bytes: None, 119 + controller_did: None, 120 + auth_source: crate::auth::AuthSource::OAuth, 121 + } 122 + } 105 123 Err(crate::oauth::OAuthError::UseDpopNonce(nonce)) => { 106 124 return ( 107 125 StatusCode::UNAUTHORIZED, ··· 179 197 } 180 198 }; 181 199 182 - let lxm = params.lxm.as_deref(); 183 - let lxm_for_token = lxm.unwrap_or("*"); 200 + let lxm = params.lxm.as_ref(); 201 + let lxm_for_token = lxm.map_or("*", |n| n.as_str()); 184 202 185 203 if let Some(method) = lxm { 186 204 if let Err(e) = crate::auth::scope_check::check_rpc_scope( 187 - auth_user.is_oauth(), 205 + &auth_user.auth_source, 188 206 auth_user.scope.as_deref(), 189 - &params.aud, 190 - method, 207 + params.aud.as_str(), 208 + method.as_str(), 191 209 ) { 192 210 return e; 193 211 } ··· 209 227 .flatten() 210 228 .is_some_and(|s| s.takedown_ref.is_some()); 211 229 212 - if is_takendown && lxm != Some("com.atproto.server.createAccount") { 230 + if is_takendown && lxm != Some(&*CREATE_ACCOUNT_NSID) { 213 231 return ApiError::InvalidToken(Some("Bad token scope".into())).into_response(); 214 232 } 215 233 216 234 if let Some(method) = lxm 217 - && PROTECTED_METHODS.contains(&method) 235 + && PROTECTED_METHODS.contains(&method.as_str()) 218 236 { 219 237 return ApiError::InvalidRequest(format!( 220 238 "cannot request a service auth token for the following protected method: {}", ··· 248 266 249 267 let service_token = match crate::auth::create_service_token( 250 268 &auth_user.did, 251 - &params.aud, 269 + params.aud.as_str(), 252 270 lxm_for_token, 253 271 &key_bytes, 254 272 ) {
+36 -54
crates/tranquil-pds/src/api/server/session.rs
··· 266 266 refresh_jti: refresh_meta.jti.clone(), 267 267 access_expires_at: access_meta.expires_at, 268 268 refresh_expires_at: refresh_meta.expires_at, 269 - login_type: tranquil_db_traits::LoginType::from(is_legacy_login), 269 + login_type: tranquil_db_traits::LoginType::from_legacy_flag(is_legacy_login), 270 270 mfa_verified: false, 271 271 scope: app_password_scopes.clone(), 272 272 controller_did: app_password_controller.clone(), ··· 338 338 ); 339 339 match db_result { 340 340 Ok(Some(row)) => { 341 - let preferred_channel = match row.preferred_comms_channel { 342 - tranquil_db_traits::CommsChannel::Email => "email", 343 - tranquil_db_traits::CommsChannel::Discord => "discord", 344 - tranquil_db_traits::CommsChannel::Telegram => "telegram", 345 - tranquil_db_traits::CommsChannel::Signal => "signal", 346 - }; 347 341 let preferred_channel_verified = row 348 342 .channel_verification 349 343 .is_verified(row.preferred_comms_channel); ··· 365 359 "handle": handle, 366 360 "did": &auth.did, 367 361 "active": account_state.is_active(), 368 - "preferredChannel": preferred_channel, 362 + "preferredChannel": row.preferred_comms_channel.as_str(), 369 363 "preferredChannelVerified": preferred_channel_verified, 370 364 "preferredLocale": row.preferred_locale, 371 365 "isAdmin": row.is_admin ··· 404 398 ) -> Result<Response, ApiError> { 405 399 let extracted = crate::auth::extract_auth_token_from_header(crate::util::get_header_str( 406 400 &headers, 407 - "Authorization", 401 + http::header::AUTHORIZATION, 408 402 )) 409 403 .ok_or(ApiError::AuthenticationRequired)?; 410 404 let jti = crate::auth::get_jti_from_token(&extracted.token) ··· 413 407 match state.session_repo.delete_session_by_access_jti(&jti).await { 414 408 Ok(rows) if rows > 0 => { 415 409 if let Some(did) = did { 416 - let session_cache_key = format!("auth:session:{}:{}", did, jti); 410 + let session_cache_key = crate::cache_keys::session_key(&did, &jti); 417 411 let _ = state.cache.delete(&session_cache_key).await; 418 412 } 419 413 Ok(EmptyResponse::ok().into_response()) ··· 430 424 ) -> Response { 431 425 let extracted = match crate::auth::extract_auth_token_from_header(crate::util::get_header_str( 432 426 &headers, 433 - "Authorization", 427 + http::header::AUTHORIZATION, 434 428 )) { 435 429 Some(t) => t, 436 430 None => return ApiError::AuthenticationRequired.into_response(), ··· 548 542 ); 549 543 match db_result { 550 544 Ok(Some(u)) => { 551 - let preferred_channel = match u.preferred_comms_channel { 552 - tranquil_db_traits::CommsChannel::Email => "email", 553 - tranquil_db_traits::CommsChannel::Discord => "discord", 554 - tranquil_db_traits::CommsChannel::Telegram => "telegram", 555 - tranquil_db_traits::CommsChannel::Signal => "signal", 556 - }; 557 545 let preferred_channel_verified = u 558 546 .channel_verification 559 547 .is_verified(u.preferred_comms_channel); ··· 568 556 "did": session_row.did, 569 557 "email": u.email, 570 558 "emailConfirmed": u.channel_verification.email, 571 - "preferredChannel": preferred_channel, 559 + "preferredChannel": u.preferred_comms_channel.as_str(), 572 560 "preferredChannelVerified": preferred_channel_verified, 573 561 "preferredLocale": u.preferred_locale, 574 562 "isAdmin": u.is_admin, ··· 609 597 pub did: Did, 610 598 pub email: Option<String>, 611 599 pub email_verified: bool, 612 - pub preferred_channel: String, 600 + pub preferred_channel: tranquil_db_traits::CommsChannel, 613 601 pub preferred_channel_verified: bool, 614 602 } 615 603 ··· 631 619 } 632 620 }; 633 621 634 - let (channel_str, identifier) = match row.channel { 635 - tranquil_db_traits::CommsChannel::Email => ("email", row.email.clone().unwrap_or_default()), 622 + let identifier = match row.channel { 623 + tranquil_db_traits::CommsChannel::Email => row.email.clone().unwrap_or_default(), 636 624 tranquil_db_traits::CommsChannel::Discord => { 637 - ("discord", row.discord_username.clone().unwrap_or_default()) 625 + row.discord_username.clone().unwrap_or_default() 638 626 } 639 - tranquil_db_traits::CommsChannel::Telegram => ( 640 - "telegram", 641 - row.telegram_username.clone().unwrap_or_default(), 642 - ), 643 - tranquil_db_traits::CommsChannel::Signal => { 644 - ("signal", row.signal_username.clone().unwrap_or_default()) 627 + tranquil_db_traits::CommsChannel::Telegram => { 628 + row.telegram_username.clone().unwrap_or_default() 645 629 } 630 + tranquil_db_traits::CommsChannel::Signal => row.signal_username.clone().unwrap_or_default(), 646 631 }; 647 632 648 633 let normalized_token = 649 634 crate::auth::verification_token::normalize_token_input(&input.verification_code); 650 635 match crate::auth::verification_token::verify_signup_token( 651 636 &normalized_token, 652 - channel_str, 637 + row.channel, 653 638 &identifier, 654 639 ) { 655 640 Ok(token_data) => { 656 - if token_data.did != input.did.as_str() { 641 + if token_data.did != input.did { 657 642 warn!( 658 643 "Token DID mismatch for confirm_signup: expected {}, got {}", 659 644 input.did, token_data.did ··· 733 718 { 734 719 warn!("Failed to enqueue welcome notification: {:?}", e); 735 720 } 736 - let email_verified = matches!(row.channel, tranquil_db_traits::CommsChannel::Email); 737 - let preferred_channel = match row.channel { 738 - tranquil_db_traits::CommsChannel::Email => "email", 739 - tranquil_db_traits::CommsChannel::Discord => "discord", 740 - tranquil_db_traits::CommsChannel::Telegram => "telegram", 741 - tranquil_db_traits::CommsChannel::Signal => "signal", 742 - }; 743 721 Json(ConfirmSignupOutput { 744 722 access_jwt: access_meta.token, 745 723 refresh_jwt: refresh_meta.token, 746 724 handle: row.handle, 747 725 did: row.did, 748 726 email: row.email, 749 - email_verified, 750 - preferred_channel: preferred_channel.to_string(), 727 + email_verified: matches!(row.channel, tranquil_db_traits::CommsChannel::Email), 728 + preferred_channel: row.channel, 751 729 preferred_channel_verified: true, 752 730 }) 753 731 .into_response() ··· 783 761 return ApiError::InvalidRequest("Account is already verified".into()).into_response(); 784 762 } 785 763 786 - let (channel_str, recipient) = match row.channel { 787 - tranquil_db_traits::CommsChannel::Email => ("email", row.email.clone().unwrap_or_default()), 764 + let recipient = match row.channel { 765 + tranquil_db_traits::CommsChannel::Email => row.email.clone().unwrap_or_default(), 788 766 tranquil_db_traits::CommsChannel::Discord => { 789 - ("discord", row.discord_username.clone().unwrap_or_default()) 767 + row.discord_username.clone().unwrap_or_default() 790 768 } 791 - tranquil_db_traits::CommsChannel::Telegram => ( 792 - "telegram", 793 - row.telegram_username.clone().unwrap_or_default(), 794 - ), 795 - tranquil_db_traits::CommsChannel::Signal => { 796 - ("signal", row.signal_username.clone().unwrap_or_default()) 769 + tranquil_db_traits::CommsChannel::Telegram => { 770 + row.telegram_username.clone().unwrap_or_default() 797 771 } 772 + tranquil_db_traits::CommsChannel::Signal => row.signal_username.clone().unwrap_or_default(), 798 773 }; 799 774 800 775 let verification_token = 801 - crate::auth::verification_token::generate_signup_token(&input.did, channel_str, &recipient); 776 + crate::auth::verification_token::generate_signup_token(&input.did, row.channel, &recipient); 802 777 let formatted_token = 803 778 crate::auth::verification_token::format_token_for_display(&verification_token); 804 779 ··· 807 782 state.user_repo.as_ref(), 808 783 state.infra_repo.as_ref(), 809 784 row.id, 810 - channel_str, 785 + row.channel, 811 786 &recipient, 812 787 &formatted_token, 813 788 hostname, ··· 820 795 } 821 796 822 797 #[derive(Serialize)] 798 + #[serde(rename_all = "lowercase")] 799 + pub enum SessionType { 800 + Legacy, 801 + OAuth, 802 + } 803 + 804 + #[derive(Serialize)] 823 805 #[serde(rename_all = "camelCase")] 824 806 pub struct SessionInfo { 825 807 pub id: String, 826 - pub session_type: String, 808 + pub session_type: SessionType, 827 809 pub client_name: Option<String>, 828 810 pub created_at: String, 829 811 pub expires_at: String, ··· 861 843 862 844 let jwt_sessions = jwt_rows.into_iter().map(|row| SessionInfo { 863 845 id: format!("jwt:{}", row.id), 864 - session_type: "legacy".to_string(), 846 + session_type: SessionType::Legacy, 865 847 client_name: None, 866 848 created_at: row.created_at.to_rfc3339(), 867 849 expires_at: row.refresh_expires_at.to_rfc3339(), ··· 874 856 let is_current_oauth = is_oauth && current_jti.as_deref() == Some(row.token_id.as_str()); 875 857 SessionInfo { 876 858 id: format!("oauth:{}", row.id), 877 - session_type: "oauth".to_string(), 859 + session_type: SessionType::OAuth, 878 860 client_name: Some(client_name), 879 861 created_at: row.created_at.to_rfc3339(), 880 862 expires_at: row.expires_at.to_rfc3339(), ··· 925 907 .delete_session_by_id(session_id) 926 908 .await 927 909 .log_db_err("deleting session")?; 928 - let cache_key = format!("auth:session:{}:{}", &auth.did, access_jti); 910 + let cache_key = crate::cache_keys::session_key(&auth.did, &access_jti); 929 911 if let Err(e) = state.cache.delete(&cache_key).await { 930 912 warn!("Failed to invalidate session cache: {:?}", e); 931 913 }
+7 -9
crates/tranquil-pds/src/api/server/signing_key.rs
··· 25 25 26 26 #[derive(Deserialize)] 27 27 pub struct ReserveSigningKeyInput { 28 - pub did: Option<String>, 28 + pub did: Option<crate::types::Did>, 29 29 } 30 30 31 31 #[derive(Serialize)] ··· 38 38 State(state): State<AppState>, 39 39 Json(input): Json<ReserveSigningKeyInput>, 40 40 ) -> Response { 41 - let did: Option<crate::types::Did> = match input.did { 42 - Some(ref d) => match d.parse() { 43 - Ok(parsed) => Some(parsed), 44 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 45 - }, 46 - None => None, 47 - }; 48 41 let signing_key = SigningKey::random(&mut rand::thread_rng()); 49 42 let private_key_bytes = signing_key.to_bytes(); 50 43 let public_key_did_key = public_key_to_did_key(&signing_key); ··· 52 45 let private_bytes: &[u8] = &private_key_bytes; 53 46 match state 54 47 .infra_repo 55 - .reserve_signing_key(did.as_ref(), &public_key_did_key, private_bytes, expires_at) 48 + .reserve_signing_key( 49 + input.did.as_ref(), 50 + &public_key_did_key, 51 + private_bytes, 52 + expires_at, 53 + ) 56 54 .await 57 55 { 58 56 Ok(key_id) => {
+13 -23
crates/tranquil-pds/src/api/server/trusted_devices.rs
··· 103 103 #[derive(Deserialize)] 104 104 #[serde(rename_all = "camelCase")] 105 105 pub struct RevokeTrustedDeviceInput { 106 - pub device_id: String, 106 + pub device_id: DeviceId, 107 107 } 108 108 109 109 pub async fn revoke_trusted_device( ··· 111 111 auth: Auth<Active>, 112 112 Json(input): Json<RevokeTrustedDeviceInput>, 113 113 ) -> Result<Response, ApiError> { 114 - let device_id = DeviceId::from(input.device_id.clone()); 115 114 match state 116 115 .oauth_repo 117 - .device_belongs_to_user(&device_id, &auth.did) 116 + .device_belongs_to_user(&input.device_id, &auth.did) 118 117 .await 119 118 { 120 119 Ok(true) => {} ··· 129 128 130 129 state 131 130 .oauth_repo 132 - .revoke_device_trust(&device_id) 131 + .revoke_device_trust(&input.device_id) 133 132 .await 134 133 .log_db_err("revoking device trust")?; 135 134 ··· 140 139 #[derive(Deserialize)] 141 140 #[serde(rename_all = "camelCase")] 142 141 pub struct UpdateTrustedDeviceInput { 143 - pub device_id: String, 142 + pub device_id: DeviceId, 144 143 pub friendly_name: Option<String>, 145 144 } 146 145 ··· 149 148 auth: Auth<Active>, 150 149 Json(input): Json<UpdateTrustedDeviceInput>, 151 150 ) -> Result<Response, ApiError> { 152 - let device_id = DeviceId::from(input.device_id.clone()); 153 151 match state 154 152 .oauth_repo 155 - .device_belongs_to_user(&device_id, &auth.did) 153 + .device_belongs_to_user(&input.device_id, &auth.did) 156 154 .await 157 155 { 158 156 Ok(true) => {} ··· 167 165 168 166 state 169 167 .oauth_repo 170 - .update_device_friendly_name(&device_id, input.friendly_name.as_deref()) 168 + .update_device_friendly_name(&input.device_id, input.friendly_name.as_deref()) 171 169 .await 172 170 .log_db_err("updating device friendly name")?; 173 171 ··· 177 175 178 176 pub async fn get_device_trust_state( 179 177 oauth_repo: &dyn OAuthRepository, 180 - device_id: &str, 178 + device_id: &DeviceId, 181 179 did: &tranquil_types::Did, 182 180 ) -> DeviceTrustState { 183 - let device_id_typed = DeviceId::from(device_id.to_string()); 184 - match oauth_repo 185 - .get_device_trust_info(&device_id_typed, did) 186 - .await 187 - { 181 + match oauth_repo.get_device_trust_info(device_id, did).await { 188 182 Ok(Some(info)) => DeviceTrustState::from_timestamps(info.trusted_at, info.trusted_until), 189 183 _ => DeviceTrustState::Untrusted, 190 184 } ··· 192 186 193 187 pub async fn is_device_trusted( 194 188 oauth_repo: &dyn OAuthRepository, 195 - device_id: &str, 189 + device_id: &DeviceId, 196 190 did: &tranquil_types::Did, 197 191 ) -> bool { 198 192 get_device_trust_state(oauth_repo, device_id, did) ··· 202 196 203 197 pub async fn trust_device( 204 198 oauth_repo: &dyn OAuthRepository, 205 - device_id: &str, 199 + device_id: &DeviceId, 206 200 ) -> Result<(), tranquil_db_traits::DbError> { 207 201 let now = Utc::now(); 208 202 let trusted_until = now + Duration::days(TRUST_DURATION_DAYS); 209 - let device_id_typed = DeviceId::from(device_id.to_string()); 210 - oauth_repo 211 - .trust_device(&device_id_typed, now, trusted_until) 212 - .await 203 + oauth_repo.trust_device(device_id, now, trusted_until).await 213 204 } 214 205 215 206 pub async fn extend_device_trust( 216 207 oauth_repo: &dyn OAuthRepository, 217 - device_id: &str, 208 + device_id: &DeviceId, 218 209 ) -> Result<(), tranquil_db_traits::DbError> { 219 210 let trusted_until = Utc::now() + Duration::days(TRUST_DURATION_DAYS); 220 - let device_id_typed = DeviceId::from(device_id.to_string()); 221 211 oauth_repo 222 - .extend_device_trust(&device_id_typed, trusted_until) 212 + .extend_device_trust(device_id, trusted_until) 223 213 .await 224 214 }
+40 -54
crates/tranquil-pds/src/api/server/verify_token.rs
··· 10 10 VerificationPurpose, normalize_token_input, verify_token_signature, 11 11 }; 12 12 use crate::state::AppState; 13 + use tranquil_db_traits::CommsChannel; 13 14 14 15 #[derive(Deserialize, Clone)] 15 16 #[serde(rename_all = "camelCase")] ··· 23 24 pub struct VerifyTokenOutput { 24 25 pub success: bool, 25 26 pub did: Did, 26 - pub purpose: String, 27 - pub channel: String, 27 + pub purpose: VerificationPurpose, 28 + pub channel: CommsChannel, 28 29 } 29 30 30 31 pub async fn verify_token( ··· 53 54 54 55 match token_data.purpose { 55 56 VerificationPurpose::Migration => { 56 - handle_migration_verification(state, &token_data.did, &token_data.channel, &identifier) 57 + handle_migration_verification(state, &token_data.did, token_data.channel, &identifier) 57 58 .await 58 59 } 59 60 VerificationPurpose::ChannelUpdate => { 60 - handle_channel_update(state, &token_data.did, &token_data.channel, &identifier).await 61 + handle_channel_update(state, &token_data.did, token_data.channel, &identifier).await 61 62 } 62 63 VerificationPurpose::Signup => { 63 - handle_signup_verification(state, &token_data.did, &token_data.channel, &identifier) 64 + handle_signup_verification(state, &token_data.did, token_data.channel, &identifier) 64 65 .await 65 66 } 66 67 } ··· 68 69 69 70 async fn handle_migration_verification( 70 71 state: &AppState, 71 - did: &str, 72 - channel: &str, 72 + did: &Did, 73 + channel: CommsChannel, 73 74 identifier: &str, 74 75 ) -> Result<Json<VerifyTokenOutput>, ApiError> { 75 - if channel != "email" { 76 + if channel != CommsChannel::Email { 76 77 return Err(ApiError::InvalidChannel); 77 78 } 78 79 79 - let did_typed: Did = did 80 - .parse() 81 - .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 82 80 let user = state 83 81 .user_repo 84 - .get_verification_info(&did_typed) 82 + .get_verification_info(did) 85 83 .await 86 84 .log_db_err("during migration verification")? 87 85 .ok_or(ApiError::AccountNotFound)?; ··· 102 100 103 101 Ok(Json(VerifyTokenOutput { 104 102 success: true, 105 - did: did.to_string().into(), 106 - purpose: "migration".to_string(), 107 - channel: channel.to_string(), 103 + did: did.clone(), 104 + purpose: VerificationPurpose::Migration, 105 + channel, 108 106 })) 109 107 } 110 108 111 109 async fn handle_channel_update( 112 110 state: &AppState, 113 - did: &str, 114 - channel: &str, 111 + did: &Did, 112 + channel: CommsChannel, 115 113 identifier: &str, 116 114 ) -> Result<Json<VerifyTokenOutput>, ApiError> { 117 - let did_typed: Did = did 118 - .parse() 119 - .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 120 115 let user_id = state 121 116 .user_repo 122 - .get_id_by_did(&did_typed) 117 + .get_id_by_did(did) 123 118 .await 124 119 .log_db_err("fetching user id")? 125 120 .ok_or(ApiError::AccountNotFound)?; 126 121 127 122 match channel { 128 - "email" => { 123 + CommsChannel::Email => { 129 124 let success = state 130 125 .user_repo 131 126 .verify_email_channel(user_id, identifier) ··· 135 130 return Err(ApiError::EmailTaken); 136 131 } 137 132 } 138 - "discord" => { 133 + CommsChannel::Discord => { 139 134 state 140 135 .user_repo 141 136 .verify_discord_channel(user_id, identifier) 142 137 .await 143 138 .log_db_err("updating discord channel")?; 144 139 } 145 - "telegram" => { 140 + CommsChannel::Telegram => { 146 141 state 147 142 .user_repo 148 143 .verify_telegram_channel(user_id, identifier) 149 144 .await 150 145 .log_db_err("updating telegram channel")?; 151 146 } 152 - "signal" => { 147 + CommsChannel::Signal => { 153 148 state 154 149 .user_repo 155 150 .verify_signal_channel(user_id, identifier) 156 151 .await 157 152 .log_db_err("updating signal channel")?; 158 153 } 159 - _ => { 160 - return Err(ApiError::InvalidChannel); 161 - } 162 154 }; 163 155 164 - info!(did = %did, channel = %channel, "Channel verified successfully"); 156 + info!(did = %did, channel = ?channel, "Channel verified successfully"); 165 157 166 158 let recipient = resolve_verified_recipient(state, user_id, channel, identifier).await; 167 159 if let Err(e) = comms_repo::enqueue_channel_verified( ··· 179 171 180 172 Ok(Json(VerifyTokenOutput { 181 173 success: true, 182 - did: did.to_string().into(), 183 - purpose: "channel_update".to_string(), 184 - channel: channel.to_string(), 174 + did: did.clone(), 175 + purpose: VerificationPurpose::ChannelUpdate, 176 + channel, 185 177 })) 186 178 } 187 179 188 180 async fn resolve_verified_recipient( 189 181 state: &AppState, 190 182 user_id: uuid::Uuid, 191 - channel: &str, 183 + channel: tranquil_db_traits::CommsChannel, 192 184 identifier: &str, 193 185 ) -> String { 194 186 match channel { 195 - "telegram" => state 187 + tranquil_db_traits::CommsChannel::Telegram => state 196 188 .user_repo 197 189 .get_telegram_chat_id(user_id) 198 190 .await ··· 206 198 207 199 async fn handle_signup_verification( 208 200 state: &AppState, 209 - did: &str, 210 - channel: &str, 201 + did: &Did, 202 + channel: CommsChannel, 211 203 identifier: &str, 212 204 ) -> Result<Json<VerifyTokenOutput>, ApiError> { 213 - let did_typed: Did = did 214 - .parse() 215 - .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 216 205 let user = state 217 206 .user_repo 218 - .get_verification_info(&did_typed) 207 + .get_verification_info(did) 219 208 .await 220 209 .log_db_err("during signup verification")? 221 210 .ok_or(ApiError::AccountNotFound)?; ··· 225 214 info!(did = %did, "Account already verified"); 226 215 return Ok(Json(VerifyTokenOutput { 227 216 success: true, 228 - did: did.to_string().into(), 229 - purpose: "signup".to_string(), 230 - channel: channel.to_string(), 217 + did: did.clone(), 218 + purpose: VerificationPurpose::Signup, 219 + channel, 231 220 })); 232 221 } 233 222 234 223 match channel { 235 - "email" => { 224 + CommsChannel::Email => { 236 225 state 237 226 .user_repo 238 227 .set_email_verified_flag(user.id) 239 228 .await 240 229 .log_db_err("updating email verified status")?; 241 230 } 242 - "discord" => { 231 + CommsChannel::Discord => { 243 232 state 244 233 .user_repo 245 234 .set_discord_verified_flag(user.id) 246 235 .await 247 236 .log_db_err("updating discord verified status")?; 248 237 } 249 - "telegram" => { 238 + CommsChannel::Telegram => { 250 239 state 251 240 .user_repo 252 241 .set_telegram_verified_flag(user.id) 253 242 .await 254 243 .log_db_err("updating telegram verified status")?; 255 244 } 256 - "signal" => { 245 + CommsChannel::Signal => { 257 246 state 258 247 .user_repo 259 248 .set_signal_verified_flag(user.id) 260 249 .await 261 250 .log_db_err("updating signal verified status")?; 262 251 } 263 - _ => { 264 - return Err(ApiError::InvalidChannel); 265 - } 266 252 }; 267 253 268 - info!(did = %did, channel = %channel, "Signup verified successfully"); 254 + info!(did = %did, channel = ?channel, "Signup verified successfully"); 269 255 270 256 let recipient = resolve_verified_recipient(state, user.id, channel, identifier).await; 271 257 if let Err(e) = comms_repo::enqueue_channel_verified( ··· 283 269 284 270 Ok(Json(VerifyTokenOutput { 285 271 success: true, 286 - did: did.to_string().into(), 287 - purpose: "signup".to_string(), 288 - channel: channel.to_string(), 272 + did: did.clone(), 273 + purpose: VerificationPurpose::Signup, 274 + channel, 289 275 })) 290 276 }
+1 -1
crates/tranquil-pds/src/api/telegram_webhook.rs
··· 86 86 state.user_repo.as_ref(), 87 87 state.infra_repo.as_ref(), 88 88 user_id, 89 - "telegram", 89 + tranquil_db_traits::CommsChannel::Telegram, 90 90 &from.id.to_string(), 91 91 pds_hostname(), 92 92 )
+1 -1
crates/tranquil-pds/src/api/temp.rs
··· 57 57 58 58 for part in scope_parts { 59 59 if let Some(cid_str) = part.strip_prefix("ref:") { 60 - let cache_key = format!("scope_ref:{}", cid_str); 60 + let cache_key = crate::cache_keys::scope_ref_key(cid_str); 61 61 if let Some(cached) = state.cache.get(&cache_key).await { 62 62 for s in cached.split_whitespace() { 63 63 if !resolved_scopes.contains(&s.to_string()) {
+18 -7
crates/tranquil-pds/src/api/validation.rs
··· 23 23 } 24 24 25 25 pub fn new_allow_reserved(handle: impl AsRef<str>) -> Result<Self, HandleValidationError> { 26 - let validated = validate_service_handle(handle.as_ref(), true)?; 26 + let validated = validate_service_handle(handle.as_ref(), ReservedHandlePolicy::Allow)?; 27 27 Ok(Self(validated)) 28 28 } 29 29 ··· 252 252 253 253 impl std::error::Error for HandleValidationError {} 254 254 255 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 256 + pub enum ReservedHandlePolicy { 257 + Allow, 258 + Reject, 259 + } 260 + 255 261 pub fn validate_short_handle(handle: &str) -> Result<String, HandleValidationError> { 256 - validate_service_handle(handle, false) 262 + validate_service_handle(handle, ReservedHandlePolicy::Reject) 257 263 } 258 264 259 265 pub fn validate_service_handle( 260 266 handle: &str, 261 - allow_reserved: bool, 267 + reserved_policy: ReservedHandlePolicy, 262 268 ) -> Result<String, HandleValidationError> { 263 269 let handle = handle.trim(); 264 270 ··· 301 307 return Err(HandleValidationError::BannedWord); 302 308 } 303 309 304 - if !allow_reserved && crate::handle::reserved::is_reserved_subdomain(handle) { 310 + if reserved_policy == ReservedHandlePolicy::Reject 311 + && crate::handle::reserved::is_reserved_subdomain(handle) 312 + { 305 313 return Err(HandleValidationError::Reserved); 306 314 } 307 315 ··· 501 509 #[test] 502 510 fn test_allow_reserved() { 503 511 assert_eq!( 504 - validate_service_handle("admin", true), 512 + validate_service_handle("admin", ReservedHandlePolicy::Allow), 505 513 Ok("admin".to_string()) 506 514 ); 507 - assert_eq!(validate_service_handle("api", true), Ok("api".to_string())); 515 + assert_eq!( 516 + validate_service_handle("api", ReservedHandlePolicy::Allow), 517 + Ok("api".to_string()) 518 + ); 508 519 assert_eq!( 509 - validate_service_handle("admin", false), 520 + validate_service_handle("admin", ReservedHandlePolicy::Reject), 510 521 Err(HandleValidationError::Reserved) 511 522 ); 512 523 }
+1 -1
crates/tranquil-pds/src/api/verification.rs
··· 10 10 #[derive(Deserialize)] 11 11 #[serde(rename_all = "camelCase")] 12 12 pub struct ConfirmChannelVerificationInput { 13 - pub channel: String, 13 + pub channel: tranquil_db_traits::CommsChannel, 14 14 pub identifier: String, 15 15 pub code: String, 16 16 }
+51 -21
crates/tranquil-pds/src/appview/mod.rs
··· 6 6 use tokio::sync::RwLock; 7 7 use tracing::{debug, error, info, warn}; 8 8 9 + #[derive(Debug, thiserror::Error)] 10 + pub enum DidResolutionError { 11 + #[error("Invalid did:web format")] 12 + InvalidDidWeb, 13 + #[error("HTTP request failed: {0}")] 14 + HttpFailed(String), 15 + #[error("Invalid DID document: {0}")] 16 + InvalidDocument(String), 17 + #[error("DID not found")] 18 + NotFound, 19 + } 20 + 9 21 #[derive(Debug, Clone, Serialize, Deserialize)] 10 22 pub struct DidDocument { 11 23 pub id: String, ··· 78 90 } 79 91 } 80 92 81 - fn build_did_web_url(did: &str) -> Result<String, String> { 93 + fn build_did_web_url(did: &str) -> Result<String, DidResolutionError> { 82 94 let host = did 83 95 .strip_prefix("did:web:") 84 - .ok_or("Invalid did:web format")?; 96 + .ok_or(DidResolutionError::InvalidDidWeb)?; 85 97 86 98 let (host, path) = if host.contains(':') { 87 99 let decoded = host.replace("%3A", ":"); ··· 184 196 self.extract_service_endpoint(&doc) 185 197 } 186 198 187 - async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, String> { 199 + async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, DidResolutionError> { 188 200 let url = Self::build_did_web_url(did)?; 189 201 190 202 debug!("Resolving did:web {} via {}", did, url); ··· 194 206 .get(&url) 195 207 .send() 196 208 .await 197 - .map_err(|e| format!("HTTP request failed: {}", e))?; 209 + .map_err(|e| DidResolutionError::HttpFailed(e.to_string()))?; 198 210 199 211 if !resp.status().is_success() { 200 - return Err(format!("HTTP {}", resp.status())); 212 + return Err(DidResolutionError::HttpFailed(format!( 213 + "HTTP {}", 214 + resp.status() 215 + ))); 201 216 } 202 217 203 218 resp.json::<DidDocument>() 204 219 .await 205 - .map_err(|e| format!("Failed to parse DID document: {}", e)) 220 + .map_err(|e| DidResolutionError::InvalidDocument(e.to_string())) 206 221 } 207 222 208 - async fn resolve_did_plc(&self, did: &str) -> Result<DidDocument, String> { 223 + async fn resolve_did_plc(&self, did: &str) -> Result<DidDocument, DidResolutionError> { 209 224 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did)); 210 225 211 226 debug!("Resolving did:plc {} via {}", did, url); ··· 215 230 .get(&url) 216 231 .send() 217 232 .await 218 - .map_err(|e| format!("HTTP request failed: {}", e))?; 233 + .map_err(|e| DidResolutionError::HttpFailed(e.to_string()))?; 219 234 220 235 if resp.status() == reqwest::StatusCode::NOT_FOUND { 221 - return Err("DID not found".to_string()); 236 + return Err(DidResolutionError::NotFound); 222 237 } 223 238 224 239 if !resp.status().is_success() { 225 - return Err(format!("HTTP {}", resp.status())); 240 + return Err(DidResolutionError::HttpFailed(format!( 241 + "HTTP {}", 242 + resp.status() 243 + ))); 226 244 } 227 245 228 246 resp.json::<DidDocument>() 229 247 .await 230 - .map_err(|e| format!("Failed to parse DID document: {}", e)) 248 + .map_err(|e| DidResolutionError::InvalidDocument(e.to_string())) 231 249 } 232 250 233 251 fn extract_service_endpoint(&self, doc: &DidDocument) -> Option<ResolvedService> { 234 252 if let Some(service) = doc.service.iter().find(|s| { 235 - s.service_type == "AtprotoAppView" 253 + s.service_type == crate::plc::ServiceType::AppView.as_str() 236 254 || s.id.contains("atproto_appview") 237 255 || s.id.ends_with("#bsky_appview") 238 256 }) { ··· 329 347 } 330 348 } 331 349 332 - async fn fetch_did_document_web(&self, did: &str) -> Result<serde_json::Value, String> { 350 + async fn fetch_did_document_web( 351 + &self, 352 + did: &str, 353 + ) -> Result<serde_json::Value, DidResolutionError> { 333 354 let url = Self::build_did_web_url(did)?; 334 355 335 356 let resp = self ··· 337 358 .get(&url) 338 359 .send() 339 360 .await 340 - .map_err(|e| format!("HTTP request failed: {}", e))?; 361 + .map_err(|e| DidResolutionError::HttpFailed(e.to_string()))?; 341 362 342 363 if !resp.status().is_success() { 343 - return Err(format!("HTTP {}", resp.status())); 364 + return Err(DidResolutionError::HttpFailed(format!( 365 + "HTTP {}", 366 + resp.status() 367 + ))); 344 368 } 345 369 346 370 resp.json::<serde_json::Value>() 347 371 .await 348 - .map_err(|e| format!("Failed to parse DID document: {}", e)) 372 + .map_err(|e| DidResolutionError::InvalidDocument(e.to_string())) 349 373 } 350 374 351 - async fn fetch_did_document_plc(&self, did: &str) -> Result<serde_json::Value, String> { 375 + async fn fetch_did_document_plc( 376 + &self, 377 + did: &str, 378 + ) -> Result<serde_json::Value, DidResolutionError> { 352 379 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did)); 353 380 354 381 let resp = self ··· 356 383 .get(&url) 357 384 .send() 358 385 .await 359 - .map_err(|e| format!("HTTP request failed: {}", e))?; 386 + .map_err(|e| DidResolutionError::HttpFailed(e.to_string()))?; 360 387 361 388 if resp.status() == reqwest::StatusCode::NOT_FOUND { 362 - return Err("DID not found".to_string()); 389 + return Err(DidResolutionError::NotFound); 363 390 } 364 391 365 392 if !resp.status().is_success() { 366 - return Err(format!("HTTP {}", resp.status())); 393 + return Err(DidResolutionError::HttpFailed(format!( 394 + "HTTP {}", 395 + resp.status() 396 + ))); 367 397 } 368 398 369 399 resp.json::<serde_json::Value>() 370 400 .await 371 - .map_err(|e| format!("Failed to parse DID document: {}", e)) 401 + .map_err(|e| DidResolutionError::InvalidDocument(e.to_string())) 372 402 } 373 403 374 404 pub async fn invalidate_cache(&self, did: &str) {
+1 -1
crates/tranquil-pds/src/auth/email_token.rs
··· 55 55 } 56 56 57 57 fn current_timestamp() -> u64 { 58 - chrono::Utc::now().timestamp().max(0) as u64 58 + u64::try_from(chrono::Utc::now().timestamp()).unwrap_or(0) 59 59 } 60 60 61 61 pub async fn create_email_token(
+19 -7
crates/tranquil-pds/src/auth/extractor.rs
··· 64 64 } 65 65 } 66 66 67 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 68 + pub enum AuthScheme { 69 + Bearer, 70 + DPoP, 71 + } 72 + 73 + impl AuthScheme { 74 + pub fn is_dpop(self) -> bool { 75 + matches!(self, Self::DPoP) 76 + } 77 + } 78 + 67 79 pub struct ExtractedToken { 68 80 pub token: String, 69 - pub is_dpop: bool, 81 + pub scheme: AuthScheme, 70 82 } 71 83 72 84 pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { ··· 100 112 } 101 113 return Some(ExtractedToken { 102 114 token: token.to_string(), 103 - is_dpop: false, 115 + scheme: AuthScheme::Bearer, 104 116 }); 105 117 } 106 118 ··· 111 123 } 112 124 return Some(ExtractedToken { 113 125 token: token.to_string(), 114 - is_dpop: true, 126 + scheme: AuthScheme::DPoP, 115 127 }); 116 128 } 117 129 ··· 255 267 } 256 268 } 257 269 258 - async fn verify_service_token(token: &str) -> Result<ServiceTokenClaims, AuthError> { 270 + async fn verify_service_token_claims(token: &str) -> Result<ServiceTokenClaims, AuthError> { 259 271 let verifier = ServiceTokenVerifier::new(); 260 272 let claims = verifier 261 273 .verify_service_token(token, None) ··· 289 301 extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 290 302 291 303 if is_service_token(&extracted.token) { 292 - let claims = verify_service_token(&extracted.token).await?; 304 + let claims = verify_service_token_claims(&extracted.token).await?; 293 305 return Ok(ExtractedAuth::Service(claims)); 294 306 } 295 307 296 - let dpop_proof = crate::util::get_header_str(&parts.headers, "DPoP"); 308 + let dpop_proof = crate::util::get_header_str(&parts.headers, crate::util::HEADER_DPOP); 297 309 let method = parts.method.as_str(); 298 310 let original_uri = parts 299 311 .extensions ··· 418 430 impl ServiceAuth { 419 431 pub fn require_lxm(&self, expected_lxm: &str) -> Result<(), ApiError> { 420 432 match &self.claims.lxm { 421 - Some(lxm) if lxm == "*" || lxm == expected_lxm => Ok(()), 433 + Some(lxm) if crate::auth::lxm_permits(lxm, expected_lxm) => Ok(()), 422 434 Some(lxm) => Err(ApiError::AuthorizationError(format!( 423 435 "Token lxm '{}' does not permit '{}'", 424 436 lxm, expected_lxm
+1 -1
crates/tranquil-pds/src/auth/legacy_2fa.rs
··· 135 135 } 136 136 137 137 fn current_timestamp() -> u64 { 138 - Utc::now().timestamp().max(0) as u64 138 + u64::try_from(Utc::now().timestamp()).unwrap_or(0) 139 139 } 140 140 141 141 fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
+100 -38
crates/tranquil-pds/src/auth/mod.rs
··· 26 26 27 27 pub use account_verified::{AccountVerified, require_not_migrated, require_verified_or_delegated}; 28 28 pub use extractor::{ 29 - Active, Admin, AnyUser, Auth, AuthAny, AuthError, AuthPolicy, ExtractedToken, NotTakendown, 30 - Permissive, ServiceAuth, extract_auth_token_from_header, extract_bearer_token_from_header, 29 + Active, Admin, AnyUser, Auth, AuthAny, AuthError, AuthPolicy, AuthScheme, ExtractedToken, 30 + NotTakendown, Permissive, ServiceAuth, extract_auth_token_from_header, 31 + extract_bearer_token_from_header, 31 32 }; 32 33 pub use mfa_verified::{ 33 34 MfaMethod, MfaVerified, require_legacy_session_mfa, require_reauth_window, ··· 39 40 RpcCall, ScopeAction, ScopeVerificationError, ScopeVerified, VerifyScope, WriteOpKind, 40 41 verify_batch_write_scopes, 41 42 }; 42 - pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 43 + pub use service::{ServiceTokenClaims, ServiceTokenError, ServiceTokenVerifier, is_service_token}; 43 44 44 45 pub use tranquil_auth::{ 45 - ActClaim, Claims, Header, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 46 - SCOPE_REFRESH, TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenData, 47 - TokenVerifyError, TokenWithMetadata, UnsafeClaims, create_access_token, 46 + ActClaim, Claims, Header, SigningAlgorithm, TokenData, TokenDecodeError, TokenScope, TokenType, 47 + TokenVerifyError, TokenWithMetadata, TotpError, UnsafeClaims, create_access_token, 48 48 create_access_token_hs256, create_access_token_hs256_with_metadata, 49 49 create_access_token_with_delegation, create_access_token_with_metadata, 50 50 create_access_token_with_scope_metadata, create_refresh_token, create_refresh_token_hs256, ··· 56 56 verify_refresh_token, verify_refresh_token_hs256, verify_token, verify_totp_code, 57 57 }; 58 58 59 - pub fn encrypt_totp_secret(secret: &[u8]) -> Result<Vec<u8>, String> { 59 + pub fn lxm_permits(lxm: &str, expected: &str) -> bool { 60 + lxm == "*" || lxm == expected 61 + } 62 + 63 + pub fn encrypt_totp_secret(secret: &[u8]) -> Result<Vec<u8>, crate::config::CryptoError> { 60 64 crate::config::encrypt_key(secret) 61 65 } 62 66 63 - pub fn decrypt_totp_secret(encrypted: &[u8], version: i32) -> Result<Vec<u8>, String> { 67 + pub fn decrypt_totp_secret( 68 + encrypted: &[u8], 69 + version: i32, 70 + ) -> Result<Vec<u8>, crate::config::CryptoError> { 64 71 crate::config::decrypt_key(encrypted, Some(version)) 65 72 } 66 73 ··· 113 120 } 114 121 } 115 122 123 + #[derive(Debug, Clone)] 116 124 pub enum AuthSource { 117 125 Session, 118 126 OAuth, ··· 162 170 pub fn require_lxm(&self, expected_lxm: &str) -> Result<(), ApiError> { 163 171 match self.auth_source.service_claims() { 164 172 Some(claims) => match &claims.lxm { 165 - Some(lxm) if lxm == "*" || lxm == expected_lxm => Ok(()), 173 + Some(lxm) if lxm_permits(lxm, expected_lxm) => Ok(()), 166 174 Some(lxm) => Err(ApiError::AuthorizationError(format!( 167 175 "Token lxm '{}' does not permit '{}'", 168 176 lxm, expected_lxm ··· 192 200 impl AuthenticatedUser { 193 201 pub fn permissions(&self) -> ScopePermissions { 194 202 if let Some(ref scope) = self.scope 195 - && scope != SCOPE_ACCESS 203 + && scope != TokenScope::Access.as_str() 196 204 { 197 205 return ScopePermissions::from_scope_string(Some(scope)); 198 206 } ··· 265 273 Ok(d) => d, 266 274 Err(_) => return Err(TokenValidationError::InvalidToken), 267 275 }; 268 - let key_cache_key = format!("auth:key:{}", did_str); 276 + let key_cache_key = crate::cache_keys::signing_key_key(did_str); 269 277 let mut cached_key: Option<Vec<u8>> = None; 270 278 271 279 if let Some(c) = cache { ··· 279 287 280 288 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 281 289 { 282 - let status_cache_key = format!("auth:status:{}", did_str); 290 + let status_cache_key = crate::cache_keys::user_status_key(did_str); 283 291 let cached_status: Option<CachedUserStatus> = if let Some(c) = cache { 284 292 c.get(&status_cache_key) 285 293 .await ··· 347 355 ) 348 356 .await; 349 357 350 - let status_cache_key = format!("auth:status:{}", did); 358 + let status_cache_key = crate::cache_keys::user_status_key(&did.to_string()); 351 359 let cached = CachedUserStatus { 352 360 deactivated: user.deactivated_at.is_some(), 353 361 takendown: user.takedown_ref.is_some(), ··· 386 394 match verify_access_token_typed(token, &decrypted_key) { 387 395 Ok(token_data) => { 388 396 let jti = &token_data.claims.jti; 389 - let session_cache_key = format!("auth:session:{}:{}", did, jti); 397 + let session_cache_key = crate::cache_keys::session_key(&did, &jti); 390 398 let mut session_valid = false; 391 399 392 400 if let Some(c) = cache { ··· 424 432 } 425 433 426 434 if session_valid { 427 - let controller_did = token_data 428 - .claims 429 - .act 430 - .as_ref() 431 - .map(|a| unsafe { Did::new_unchecked(a.sub.clone()) }); 435 + let controller_did: Option<Did> = match &token_data.claims.act { 436 + Some(act) => Some( 437 + act.sub 438 + .parse() 439 + .map_err(|_| TokenValidationError::InvalidToken)?, 440 + ), 441 + None => None, 442 + }; 432 443 let status = 433 444 AccountStatus::from_db_fields(takedown_ref.as_deref(), deactivated_at); 434 445 return Ok(AuthenticatedUser { ··· 479 490 } else { 480 491 None 481 492 }; 493 + let did: Did = oauth_token 494 + .did 495 + .parse() 496 + .map_err(|_| TokenValidationError::InvalidToken)?; 497 + let controller_did: Option<Did> = oauth_info 498 + .controller_did 499 + .map(|d| d.parse()) 500 + .transpose() 501 + .map_err(|_| TokenValidationError::InvalidToken)?; 482 502 return Ok(AuthenticatedUser { 483 - did: unsafe { Did::new_unchecked(oauth_token.did) }, 503 + did, 484 504 key_bytes, 485 505 is_admin: oauth_token.is_admin, 486 506 status, 487 507 scope: oauth_info.scope, 488 - controller_did: oauth_info 489 - .controller_did 490 - .map(|d| unsafe { Did::new_unchecked(d) }), 508 + controller_did, 491 509 auth_source: AuthSource::OAuth, 492 510 }); 493 511 } else { ··· 499 517 } 500 518 501 519 pub async fn invalidate_auth_cache(cache: &dyn Cache, did: &str) { 502 - let key_cache_key = format!("auth:key:{}", did); 503 - let status_cache_key = format!("auth:status:{}", did); 520 + let key_cache_key = crate::cache_keys::signing_key_key(did); 521 + let status_cache_key = crate::cache_keys::user_status_key(did); 504 522 let _ = cache.delete(&key_cache_key).await; 505 523 let _ = cache.delete(&status_cache_key).await; 506 524 } 507 525 508 - #[allow(clippy::too_many_arguments)] 526 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 527 + pub enum AccountRequirement { 528 + Active, 529 + NotTakendown, 530 + AnyStatus, 531 + } 532 + 509 533 pub async fn validate_token_with_dpop( 510 534 user_repo: &dyn UserRepository, 511 535 oauth_repo: &dyn OAuthRepository, 512 536 token: &str, 513 - is_dpop_token: bool, 537 + scheme: AuthScheme, 514 538 dpop_proof: Option<&str>, 515 539 http_method: &str, 516 540 http_uri: &str, 517 - allow_deactivated: bool, 518 - allow_takendown: bool, 541 + requirement: AccountRequirement, 519 542 ) -> Result<AuthenticatedUser, TokenValidationError> { 520 - if !is_dpop_token { 521 - if allow_takendown { 522 - return validate_bearer_token_allow_takendown(user_repo, token).await; 523 - } else if allow_deactivated { 524 - return validate_bearer_token_allow_deactivated(user_repo, token).await; 525 - } else { 526 - return validate_bearer_token(user_repo, token).await; 527 - } 543 + if !scheme.is_dpop() { 544 + return match requirement { 545 + AccountRequirement::AnyStatus => { 546 + validate_bearer_token_allow_takendown(user_repo, token).await 547 + } 548 + AccountRequirement::NotTakendown => { 549 + validate_bearer_token_allow_deactivated(user_repo, token).await 550 + } 551 + AccountRequirement::Active => validate_bearer_token(user_repo, token).await, 552 + }; 528 553 } 554 + let (allow_deactivated, allow_takendown) = match requirement { 555 + AccountRequirement::Active => (false, false), 556 + AccountRequirement::NotTakendown => (true, false), 557 + AccountRequirement::AnyStatus => (true, true), 558 + }; 529 559 match crate::oauth::verify::verify_oauth_access_token( 530 560 oauth_repo, 531 561 token, ··· 566 596 None 567 597 }; 568 598 Ok(AuthenticatedUser { 569 - did: unsafe { Did::new_unchecked(result.did) }, 599 + did: result_did, 570 600 key_bytes, 571 601 is_admin: user_info.is_admin, 572 602 status, ··· 581 611 Err(_) => Err(TokenValidationError::AuthenticationFailed), 582 612 } 583 613 } 614 + 615 + #[cfg(test)] 616 + mod tests { 617 + use super::*; 618 + 619 + #[test] 620 + fn test_lxm_permits_exact_match() { 621 + assert!(lxm_permits( 622 + "com.atproto.repo.uploadBlob", 623 + "com.atproto.repo.uploadBlob" 624 + )); 625 + } 626 + 627 + #[test] 628 + fn test_lxm_permits_wildcard() { 629 + assert!(lxm_permits("*", "com.atproto.repo.uploadBlob")); 630 + assert!(lxm_permits("*", "anything.at.all")); 631 + } 632 + 633 + #[test] 634 + fn test_lxm_permits_mismatch() { 635 + assert!(!lxm_permits( 636 + "com.atproto.repo.uploadBlob", 637 + "com.atproto.repo.createRecord" 638 + )); 639 + } 640 + 641 + #[test] 642 + fn test_lxm_permits_partial_not_wildcard() { 643 + assert!(!lxm_permits("com.atproto.*", "com.atproto.repo.uploadBlob")); 644 + } 645 + }
+22 -15
crates/tranquil-pds/src/auth/scope_check.rs
··· 7 7 AccountAction, AccountAttr, IdentityAttr, RepoAction, ScopePermissions, 8 8 }; 9 9 10 - use super::SCOPE_ACCESS; 10 + use super::{AuthSource, TokenScope}; 11 11 12 - fn has_custom_scope(scope: Option<&str>) -> bool { 13 - match scope { 14 - None => false, 15 - Some(s) => s != SCOPE_ACCESS, 12 + fn requires_scope_check(auth_source: &AuthSource, scope: Option<&str>) -> bool { 13 + match auth_source { 14 + AuthSource::OAuth => true, 15 + _ => match scope { 16 + None => false, 17 + Some(s) => s != TokenScope::Access.as_str(), 18 + }, 16 19 } 17 20 } 18 21 19 22 pub fn check_repo_scope( 20 - is_oauth: bool, 23 + auth_source: &AuthSource, 21 24 scope: Option<&str>, 22 25 action: RepoAction, 23 26 collection: &str, 24 27 ) -> Result<(), Response> { 25 - if !is_oauth && !has_custom_scope(scope) { 28 + if !requires_scope_check(auth_source, scope) { 26 29 return Ok(()); 27 30 } 28 31 ··· 32 35 .map_err(|e| ApiError::InsufficientScope(Some(e.to_string())).into_response()) 33 36 } 34 37 35 - pub fn check_blob_scope(is_oauth: bool, scope: Option<&str>, mime: &str) -> Result<(), Response> { 36 - if !is_oauth && !has_custom_scope(scope) { 38 + pub fn check_blob_scope( 39 + auth_source: &AuthSource, 40 + scope: Option<&str>, 41 + mime: &str, 42 + ) -> Result<(), Response> { 43 + if !requires_scope_check(auth_source, scope) { 37 44 return Ok(()); 38 45 } 39 46 ··· 44 51 } 45 52 46 53 pub fn check_rpc_scope( 47 - is_oauth: bool, 54 + auth_source: &AuthSource, 48 55 scope: Option<&str>, 49 56 aud: &str, 50 57 lxm: &str, 51 58 ) -> Result<(), Response> { 52 - if !is_oauth && !has_custom_scope(scope) { 59 + if !requires_scope_check(auth_source, scope) { 53 60 return Ok(()); 54 61 } 55 62 ··· 60 67 } 61 68 62 69 pub fn check_account_scope( 63 - is_oauth: bool, 70 + auth_source: &AuthSource, 64 71 scope: Option<&str>, 65 72 attr: AccountAttr, 66 73 action: AccountAction, 67 74 ) -> Result<(), Response> { 68 - if !is_oauth && !has_custom_scope(scope) { 75 + if !requires_scope_check(auth_source, scope) { 69 76 return Ok(()); 70 77 } 71 78 ··· 76 83 } 77 84 78 85 pub fn check_identity_scope( 79 - is_oauth: bool, 86 + auth_source: &AuthSource, 80 87 scope: Option<&str>, 81 88 attr: IdentityAttr, 82 89 ) -> Result<(), Response> { 83 - if !is_oauth && !has_custom_scope(scope) { 90 + if !requires_scope_check(auth_source, scope) { 84 91 return Ok(()); 85 92 } 86 93
+180 -92
crates/tranquil-pds/src/auth/service.rs
··· 1 - use crate::types::Did; 2 1 use crate::util::pds_hostname; 3 - use anyhow::{Result, anyhow}; 4 2 use base64::Engine as _; 5 3 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 6 4 use chrono::Utc; ··· 9 7 use serde::{Deserialize, Serialize}; 10 8 use std::time::Duration; 11 9 use tracing::debug; 10 + use tranquil_types::Did; 11 + 12 + #[derive(Debug, thiserror::Error)] 13 + pub enum ServiceTokenError { 14 + #[error("Invalid token format")] 15 + InvalidFormat, 16 + #[error("Base64 decode failed")] 17 + Base64Decode(#[source] base64::DecodeError), 18 + #[error("JSON decode failed")] 19 + JsonDecode(#[source] serde_json::Error), 20 + #[error("Unsupported algorithm: {0}")] 21 + UnsupportedAlgorithm(super::SigningAlgorithm), 22 + #[error("Token expired")] 23 + Expired, 24 + #[error("Invalid audience: expected {expected}, got {actual}")] 25 + InvalidAudience { expected: Did, actual: Did }, 26 + #[error("Token lxm '{token_lxm}' does not permit '{required}'")] 27 + LxmMismatch { token_lxm: String, required: String }, 28 + #[error("Token missing lxm claim")] 29 + MissingLxm, 30 + #[error("Invalid signature format")] 31 + InvalidSignature(#[source] k256::ecdsa::Error), 32 + #[error("Signature verification failed")] 33 + SignatureVerificationFailed(#[source] k256::ecdsa::Error), 34 + #[error("No atproto verification method found")] 35 + NoVerificationMethod, 36 + #[error("Verification method missing publicKeyMultibase")] 37 + MissingPublicKey, 38 + #[error("Unsupported DID method")] 39 + UnsupportedDidMethod, 40 + #[error("DID not found: {0}")] 41 + DidNotFound(String), 42 + #[error("HTTP request failed")] 43 + HttpFailed(#[source] reqwest::Error), 44 + #[error("Failed to parse DID document")] 45 + InvalidDidDocument(#[source] reqwest::Error), 46 + #[error("HTTP {0}")] 47 + HttpStatus(reqwest::StatusCode), 48 + #[error("Invalid multibase encoding")] 49 + InvalidMultibase(#[source] multibase::Error), 50 + #[error("Invalid multicodec data")] 51 + InvalidMulticodec, 52 + #[error("Unsupported key type: expected secp256k1")] 53 + UnsupportedKeyType, 54 + #[error("Invalid public key")] 55 + InvalidPublicKey(#[source] k256::ecdsa::Error), 56 + } 57 + 58 + struct JwtParts<'a> { 59 + header: &'a str, 60 + claims: &'a str, 61 + signature: &'a str, 62 + } 63 + 64 + impl<'a> JwtParts<'a> { 65 + fn parse(token: &'a str) -> Result<Self, ServiceTokenError> { 66 + let mut parts = token.splitn(4, '.'); 67 + match (parts.next(), parts.next(), parts.next(), parts.next()) { 68 + (Some(header), Some(claims), Some(signature), None) => Ok(Self { 69 + header, 70 + claims, 71 + signature, 72 + }), 73 + _ => Err(ServiceTokenError::InvalidFormat), 74 + } 75 + } 76 + 77 + fn signing_input(&self) -> String { 78 + format!("{}.{}", self.header, self.claims) 79 + } 80 + } 12 81 13 82 #[derive(Debug, Clone, Serialize, Deserialize)] 14 83 #[serde(rename_all = "camelCase")] ··· 48 117 #[serde(default)] 49 118 pub sub: Option<Did>, 50 119 pub aud: Did, 51 - pub exp: usize, 120 + pub exp: i64, 52 121 #[serde(default)] 53 - pub iat: Option<usize>, 122 + pub iat: Option<i64>, 54 123 #[serde(skip_serializing_if = "Option::is_none")] 55 124 pub lxm: Option<String>, 56 125 #[serde(default)] ··· 65 134 66 135 #[derive(Debug, Clone, Serialize, Deserialize)] 67 136 struct TokenHeader { 68 - pub alg: String, 69 - pub typ: String, 137 + pub alg: super::SigningAlgorithm, 138 + pub typ: super::TokenType, 70 139 } 71 140 72 141 pub struct ServiceTokenVerifier { 73 142 client: Client, 74 143 plc_directory_url: String, 75 - pds_did: String, 144 + pds_did: Did, 76 145 } 77 146 78 147 impl ServiceTokenVerifier { ··· 81 150 .unwrap_or_else(|_| "https://plc.directory".to_string()); 82 151 83 152 let pds_hostname = pds_hostname(); 84 - let pds_did = format!("did:web:{}", pds_hostname); 153 + let pds_did: Did = format!("did:web:{}", pds_hostname) 154 + .parse() 155 + .expect("PDS hostname produces a valid DID"); 85 156 86 157 let client = Client::builder() 87 158 .timeout(Duration::from_secs(10)) ··· 102 173 &self, 103 174 token: &str, 104 175 required_lxm: Option<&str>, 105 - ) -> Result<ServiceTokenClaims> { 106 - let parts: Vec<&str> = token.split('.').collect(); 107 - if parts.len() != 3 { 108 - return Err(anyhow!("Invalid token format")); 109 - } 176 + ) -> Result<ServiceTokenClaims, ServiceTokenError> { 177 + let jwt = JwtParts::parse(token)?; 110 178 111 179 let header_bytes = URL_SAFE_NO_PAD 112 - .decode(parts[0]) 113 - .map_err(|e| anyhow!("Base64 decode of header failed: {}", e))?; 180 + .decode(jwt.header) 181 + .map_err(ServiceTokenError::Base64Decode)?; 114 182 115 - let header: TokenHeader = serde_json::from_slice(&header_bytes) 116 - .map_err(|e| anyhow!("JSON decode of header failed: {}", e))?; 183 + let header: TokenHeader = 184 + serde_json::from_slice(&header_bytes).map_err(ServiceTokenError::JsonDecode)?; 117 185 118 - if header.alg != "ES256K" { 119 - return Err(anyhow!("Unsupported algorithm: {}", header.alg)); 186 + if header.alg != super::SigningAlgorithm::ES256K { 187 + return Err(ServiceTokenError::UnsupportedAlgorithm(header.alg)); 120 188 } 121 189 122 190 let claims_bytes = URL_SAFE_NO_PAD 123 - .decode(parts[1]) 124 - .map_err(|e| anyhow!("Base64 decode of claims failed: {}", e))?; 191 + .decode(jwt.claims) 192 + .map_err(ServiceTokenError::Base64Decode)?; 125 193 126 - let claims: ServiceTokenClaims = serde_json::from_slice(&claims_bytes) 127 - .map_err(|e| anyhow!("JSON decode of claims failed: {}", e))?; 194 + let claims: ServiceTokenClaims = 195 + serde_json::from_slice(&claims_bytes).map_err(ServiceTokenError::JsonDecode)?; 128 196 129 - let now = Utc::now().timestamp() as usize; 197 + let now = Utc::now().timestamp(); 130 198 if claims.exp < now { 131 - return Err(anyhow!("Token expired")); 199 + return Err(ServiceTokenError::Expired); 132 200 } 133 201 134 - if claims.aud.as_str() != self.pds_did { 135 - return Err(anyhow!( 136 - "Invalid audience: expected {}, got {}", 137 - self.pds_did, 138 - claims.aud 139 - )); 202 + if claims.aud != self.pds_did { 203 + return Err(ServiceTokenError::InvalidAudience { 204 + expected: self.pds_did.clone(), 205 + actual: claims.aud.clone(), 206 + }); 140 207 } 141 208 142 209 if let Some(required) = required_lxm { 143 210 match &claims.lxm { 144 - Some(lxm) if lxm == "*" || lxm == required => {} 211 + Some(lxm) if crate::auth::lxm_permits(lxm, required) => {} 145 212 Some(lxm) => { 146 - return Err(anyhow!( 147 - "Token lxm '{}' does not permit '{}'", 148 - lxm, 149 - required 150 - )); 213 + return Err(ServiceTokenError::LxmMismatch { 214 + token_lxm: lxm.clone(), 215 + required: required.to_string(), 216 + }); 151 217 } 152 218 None => { 153 - return Err(anyhow!("Token missing lxm claim")); 219 + return Err(ServiceTokenError::MissingLxm); 154 220 } 155 221 } 156 222 } ··· 159 225 let public_key = self.resolve_signing_key(did).await?; 160 226 161 227 let signature_bytes = URL_SAFE_NO_PAD 162 - .decode(parts[2]) 163 - .map_err(|e| anyhow!("Base64 decode of signature failed: {}", e))?; 228 + .decode(jwt.signature) 229 + .map_err(ServiceTokenError::Base64Decode)?; 164 230 165 - let signature = Signature::from_slice(&signature_bytes) 166 - .map_err(|e| anyhow!("Invalid signature format: {}", e))?; 231 + let signature = 232 + Signature::from_slice(&signature_bytes).map_err(ServiceTokenError::InvalidSignature)?; 167 233 168 - let message = format!("{}.{}", parts[0], parts[1]); 234 + let message = jwt.signing_input(); 169 235 170 236 public_key 171 237 .verify(message.as_bytes(), &signature) 172 - .map_err(|e| anyhow!("Signature verification failed: {}", e))?; 238 + .map_err(ServiceTokenError::SignatureVerificationFailed)?; 173 239 174 240 debug!("Service token verified for DID: {}", did); 175 241 176 242 Ok(claims) 177 243 } 178 244 179 - async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey> { 245 + async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey, ServiceTokenError> { 180 246 let did_doc = self.resolve_did_document(did).await?; 181 247 182 248 let atproto_key = did_doc 183 249 .verification_method 184 250 .iter() 185 251 .find(|vm| vm.id.ends_with("#atproto") || vm.id == format!("{}#atproto", did)) 186 - .ok_or_else(|| anyhow!("No atproto verification method found in DID document"))?; 252 + .ok_or(ServiceTokenError::NoVerificationMethod)?; 187 253 188 254 let multibase = atproto_key 189 255 .public_key_multibase 190 256 .as_ref() 191 - .ok_or_else(|| anyhow!("Verification method missing publicKeyMultibase"))?; 257 + .ok_or(ServiceTokenError::MissingPublicKey)?; 192 258 193 259 parse_did_key_multibase(multibase) 194 260 } 195 261 196 - async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument> { 262 + async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument, ServiceTokenError> { 197 263 if did.starts_with("did:plc:") { 198 264 self.resolve_did_plc(did).await 199 265 } else if did.starts_with("did:web:") { 200 266 self.resolve_did_web(did).await 201 267 } else { 202 - Err(anyhow!("Unsupported DID method: {}", did)) 268 + Err(ServiceTokenError::UnsupportedDidMethod) 203 269 } 204 270 } 205 271 206 - async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument> { 272 + async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument, ServiceTokenError> { 207 273 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did)); 208 274 debug!("Resolving did:plc {} via {}", did, url); 209 275 ··· 212 278 .get(&url) 213 279 .send() 214 280 .await 215 - .map_err(|e| anyhow!("HTTP request failed: {}", e))?; 281 + .map_err(ServiceTokenError::HttpFailed)?; 216 282 217 283 if resp.status() == reqwest::StatusCode::NOT_FOUND { 218 - return Err(anyhow!("DID not found: {}", did)); 284 + return Err(ServiceTokenError::DidNotFound(did.to_string())); 219 285 } 220 286 221 287 if !resp.status().is_success() { 222 - return Err(anyhow!("HTTP {}", resp.status())); 288 + return Err(ServiceTokenError::HttpStatus(resp.status())); 223 289 } 224 290 225 291 resp.json::<FullDidDocument>() 226 292 .await 227 - .map_err(|e| anyhow!("Failed to parse DID document: {}", e)) 293 + .map_err(ServiceTokenError::InvalidDidDocument) 228 294 } 229 295 230 - async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument> { 296 + async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument, ServiceTokenError> { 231 297 let host = did 232 298 .strip_prefix("did:web:") 233 - .ok_or_else(|| anyhow!("Invalid did:web format"))?; 299 + .ok_or(ServiceTokenError::InvalidFormat)?; 234 300 235 - let parts: Vec<&str> = host.split(':').collect(); 236 - if parts.is_empty() { 237 - return Err(anyhow!("Invalid did:web format - no host")); 238 - } 239 - 240 - let host_part = parts[0].replace("%3A", ":"); 301 + let mut host_parts = host.splitn(2, ':'); 302 + let host_part = host_parts 303 + .next() 304 + .ok_or(ServiceTokenError::InvalidFormat)? 305 + .replace("%3A", ":"); 306 + let path_part = host_parts.next(); 241 307 242 308 let scheme = if host_part.starts_with("localhost") 243 309 || host_part.starts_with("127.0.0.1") ··· 248 314 "https" 249 315 }; 250 316 251 - let url = if parts.len() == 1 { 252 - format!("{}://{}/.well-known/did.json", scheme, host_part) 253 - } else { 254 - let path = parts[1..].join("/"); 255 - format!("{}://{}/{}/did.json", scheme, host_part, path) 317 + let url = match path_part { 318 + None => format!("{}://{}/.well-known/did.json", scheme, host_part), 319 + Some(path) => { 320 + let resolved_path = path.replace(':', "/"); 321 + format!("{}://{}/{}/did.json", scheme, host_part, resolved_path) 322 + } 256 323 }; 257 324 258 325 debug!("Resolving did:web {} via {}", did, url); ··· 262 329 .get(&url) 263 330 .send() 264 331 .await 265 - .map_err(|e| anyhow!("HTTP request failed: {}", e))?; 332 + .map_err(ServiceTokenError::HttpFailed)?; 266 333 267 334 if !resp.status().is_success() { 268 - return Err(anyhow!("HTTP {}", resp.status())); 335 + return Err(ServiceTokenError::HttpStatus(resp.status())); 269 336 } 270 337 271 338 resp.json::<FullDidDocument>() 272 339 .await 273 - .map_err(|e| anyhow!("Failed to parse DID document: {}", e)) 340 + .map_err(ServiceTokenError::InvalidDidDocument) 274 341 } 275 342 } 276 343 ··· 280 347 } 281 348 } 282 349 283 - fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> { 350 + fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey, ServiceTokenError> { 284 351 if !multibase.starts_with('z') { 285 - return Err(anyhow!( 286 - "Expected base58btc multibase encoding (starts with 'z')" 352 + let base_char = multibase.chars().next().unwrap_or('?'); 353 + return Err(ServiceTokenError::InvalidMultibase( 354 + multibase::Error::UnknownBase(base_char), 287 355 )); 288 356 } 289 357 290 - let (_, decoded) = 291 - multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?; 358 + let (_, decoded) = multibase::decode(multibase).map_err(ServiceTokenError::InvalidMultibase)?; 292 359 293 360 if decoded.len() < 2 { 294 - return Err(anyhow!("Invalid multicodec data")); 361 + return Err(ServiceTokenError::InvalidMulticodec); 295 362 } 296 363 297 - let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 { 298 - (0xe701u16, &decoded[2..]) 364 + let key_bytes = if decoded.starts_with(&crate::plc::SECP256K1_MULTICODEC_PREFIX) { 365 + &decoded[crate::plc::SECP256K1_MULTICODEC_PREFIX.len()..] 299 366 } else { 300 - return Err(anyhow!( 301 - "Unsupported key type. Expected secp256k1 (0xe701), got {:02x}{:02x}", 302 - decoded[0], 303 - decoded[1] 304 - )); 367 + return Err(ServiceTokenError::UnsupportedKeyType); 305 368 }; 306 369 307 - if codec != 0xe701 { 308 - return Err(anyhow!("Only secp256k1 keys are supported")); 309 - } 310 - 311 - VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e)) 370 + VerifyingKey::from_sec1_bytes(key_bytes).map_err(ServiceTokenError::InvalidPublicKey) 312 371 } 313 372 314 373 pub fn is_service_token(token: &str) -> bool { 315 - let parts: Vec<&str> = token.split('.').collect(); 316 - if parts.len() != 3 { 374 + let Ok(jwt) = JwtParts::parse(token) else { 317 375 return false; 318 - } 376 + }; 319 377 320 - let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(parts[1]) else { 378 + let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(jwt.claims) else { 321 379 return false; 322 380 }; 323 381 ··· 376 434 let test_key = "zQ3shcXtVCEBjUvAhzTW3r12DkpFdR2KmA3rHmuEMFx4GMBDB"; 377 435 let result = parse_did_key_multibase(test_key); 378 436 assert!(result.is_ok(), "Failed to parse valid multibase key"); 437 + } 438 + 439 + #[test] 440 + fn test_jwt_parts_parse_valid() { 441 + let jwt = JwtParts::parse("a.b.c").unwrap(); 442 + assert_eq!(jwt.header, "a"); 443 + assert_eq!(jwt.claims, "b"); 444 + assert_eq!(jwt.signature, "c"); 445 + } 446 + 447 + #[test] 448 + fn test_jwt_parts_parse_too_few() { 449 + assert!(matches!( 450 + JwtParts::parse("a.b"), 451 + Err(ServiceTokenError::InvalidFormat) 452 + )); 453 + } 454 + 455 + #[test] 456 + fn test_jwt_parts_parse_too_many() { 457 + assert!(matches!( 458 + JwtParts::parse("a.b.c.d"), 459 + Err(ServiceTokenError::InvalidFormat) 460 + )); 461 + } 462 + 463 + #[test] 464 + fn test_jwt_parts_signing_input() { 465 + let jwt = JwtParts::parse("header.claims.sig").unwrap(); 466 + assert_eq!(jwt.signing_input(), "header.claims"); 379 467 } 380 468 }
+76 -56
crates/tranquil-pds/src/auth/verification_token.rs
··· 1 1 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 2 2 use hmac::Mac; 3 3 use sha2::{Digest, Sha256}; 4 + use tranquil_db_traits::CommsChannel; 5 + use tranquil_types::Did; 4 6 5 7 type HmacSha256 = hmac::Hmac<Sha256>; 6 8 ··· 9 11 const DEFAULT_MIGRATION_EXPIRY_HOURS: u64 = 48; 10 12 const DEFAULT_CHANNEL_UPDATE_EXPIRY_MINUTES: u64 = 10; 11 13 12 - #[derive(Debug, Clone, Copy, PartialEq, Eq)] 14 + #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)] 15 + #[serde(rename_all = "snake_case")] 13 16 pub enum VerificationPurpose { 14 17 Signup, 15 18 Migration, ··· 25 28 } 26 29 } 27 30 28 - fn from_str(s: &str) -> Option<Self> { 29 - match s { 30 - "signup" => Some(Self::Signup), 31 - "migration" => Some(Self::Migration), 32 - "channel_update" => Some(Self::ChannelUpdate), 33 - _ => None, 34 - } 35 - } 36 - 37 31 fn default_expiry_seconds(&self) -> u64 { 38 32 match self { 39 33 Self::Signup => DEFAULT_SIGNUP_EXPIRY_MINUTES * 60, ··· 43 37 } 44 38 } 45 39 40 + impl std::str::FromStr for VerificationPurpose { 41 + type Err = (); 42 + 43 + fn from_str(s: &str) -> Result<Self, Self::Err> { 44 + match s { 45 + "signup" => Ok(Self::Signup), 46 + "migration" => Ok(Self::Migration), 47 + "channel_update" => Ok(Self::ChannelUpdate), 48 + _ => Err(()), 49 + } 50 + } 51 + } 52 + 46 53 #[derive(Debug)] 47 54 pub struct VerificationToken { 48 - pub did: String, 55 + pub did: Did, 49 56 pub purpose: VerificationPurpose, 50 - pub channel: String, 57 + pub channel: CommsChannel, 51 58 pub identifier_hash: String, 52 59 pub expires_at: u64, 53 60 } ··· 75 82 URL_SAFE_NO_PAD.encode(&result[..16]) 76 83 } 77 84 78 - pub fn generate_signup_token(did: &str, channel: &str, identifier: &str) -> String { 85 + pub fn generate_signup_token(did: &Did, channel: CommsChannel, identifier: &str) -> String { 79 86 generate_token(did, VerificationPurpose::Signup, channel, identifier) 80 87 } 81 88 82 - pub fn generate_migration_token(did: &str, email: &str) -> String { 83 - generate_token(did, VerificationPurpose::Migration, "email", email) 89 + pub fn generate_migration_token(did: &Did, email: &str) -> String { 90 + generate_token( 91 + did, 92 + VerificationPurpose::Migration, 93 + CommsChannel::Email, 94 + email, 95 + ) 84 96 } 85 97 86 - pub fn generate_channel_update_token(did: &str, channel: &str, identifier: &str) -> String { 98 + pub fn generate_channel_update_token(did: &Did, channel: CommsChannel, identifier: &str) -> String { 87 99 generate_token(did, VerificationPurpose::ChannelUpdate, channel, identifier) 88 100 } 89 101 90 102 pub fn generate_token( 91 - did: &str, 103 + did: &Did, 92 104 purpose: VerificationPurpose, 93 - channel: &str, 105 + channel: CommsChannel, 94 106 identifier: &str, 95 107 ) -> String { 96 108 generate_token_with_expiry( ··· 103 115 } 104 116 105 117 pub fn generate_token_with_expiry( 106 - did: &str, 118 + did: &Did, 107 119 purpose: VerificationPurpose, 108 - channel: &str, 120 + channel: CommsChannel, 109 121 identifier: &str, 110 122 expiry_seconds: u64, 111 123 ) -> String { 112 124 let key = derive_verification_key(); 113 125 let identifier_hash = hash_identifier(identifier); 126 + let channel_str = channel.as_str(); 114 127 let expires_at = std::time::SystemTime::now() 115 128 .duration_since(std::time::UNIX_EPOCH) 116 129 .unwrap_or_default() ··· 121 134 "{}|{}|{}|{}|{}", 122 135 did, 123 136 purpose.as_str(), 124 - channel, 137 + channel_str, 125 138 identifier_hash, 126 139 expires_at 127 140 ); ··· 135 148 TOKEN_VERSION, 136 149 did, 137 150 purpose.as_str(), 138 - channel, 151 + channel_str, 139 152 identifier_hash, 140 153 expires_at, 141 154 signature ··· 170 183 171 184 pub fn verify_signup_token( 172 185 token: &str, 173 - expected_channel: &str, 186 + expected_channel: CommsChannel, 174 187 expected_identifier: &str, 175 188 ) -> Result<VerificationToken, VerifyError> { 176 189 let parsed = verify_token_signature(token)?; ··· 195 208 if parsed.purpose != VerificationPurpose::Migration { 196 209 return Err(VerifyError::PurposeMismatch); 197 210 } 198 - if parsed.channel != "email" { 211 + if parsed.channel != CommsChannel::Email { 199 212 return Err(VerifyError::ChannelMismatch); 200 213 } 201 214 let expected_hash = hash_identifier(expected_email); ··· 207 220 208 221 pub fn verify_channel_update_token( 209 222 token: &str, 210 - expected_channel: &str, 223 + expected_channel: CommsChannel, 211 224 expected_identifier: &str, 212 225 ) -> Result<VerificationToken, VerifyError> { 213 226 let parsed = verify_token_signature(token)?; ··· 226 239 227 240 pub fn verify_token_for_did( 228 241 token: &str, 229 - expected_did: &str, 242 + expected_did: &Did, 230 243 ) -> Result<VerificationToken, VerifyError> { 231 244 let parsed = verify_token_signature(token)?; 232 - if parsed.did != expected_did { 245 + if parsed.did != *expected_did { 233 246 return Err(VerifyError::IdentifierMismatch); 234 247 } 235 248 Ok(parsed) ··· 253 266 254 267 let did = parts[1]; 255 268 let purpose_str = parts[2]; 256 - let channel = parts[3]; 269 + let channel_str = parts[3]; 257 270 let identifier_hash = parts[4]; 258 271 let expires_at: u64 = parts[5].parse().map_err(|_| VerifyError::InvalidFormat)?; 259 272 let provided_signature = parts[6]; 260 273 261 - let purpose = VerificationPurpose::from_str(purpose_str).ok_or(VerifyError::InvalidFormat)?; 274 + let purpose: VerificationPurpose = purpose_str 275 + .parse() 276 + .map_err(|_| VerifyError::InvalidFormat)?; 277 + let channel: CommsChannel = channel_str 278 + .parse() 279 + .map_err(|_| VerifyError::InvalidFormat)?; 262 280 263 281 let now = std::time::SystemTime::now() 264 282 .duration_since(std::time::UNIX_EPOCH) ··· 271 289 let key = derive_verification_key(); 272 290 let payload = format!( 273 291 "{}|{}|{}|{}|{}", 274 - did, purpose_str, channel, identifier_hash, expires_at 292 + did, purpose_str, channel_str, identifier_hash, expires_at 275 293 ); 276 294 let mut mac = <HmacSha256 as Mac>::new_from_slice(&key).expect("HMAC key size is valid"); 277 295 mac.update(payload.as_bytes()); ··· 286 304 return Err(VerifyError::InvalidSignature); 287 305 } 288 306 307 + let parsed_did: Did = did.parse().map_err(|_| VerifyError::InvalidFormat)?; 308 + 289 309 Ok(VerificationToken { 290 - did: did.to_string(), 310 + did: parsed_did, 291 311 purpose, 292 - channel: channel.to_string(), 312 + channel, 293 313 identifier_hash: identifier_hash.to_string(), 294 314 expires_at, 295 315 }) ··· 309 329 310 330 #[test] 311 331 fn test_signup_token() { 312 - let did = "did:plc:test123"; 313 - let channel = "email"; 332 + let did: Did = "did:plc:test123".parse().unwrap(); 333 + let channel = CommsChannel::Email; 314 334 let identifier = "test@example.com"; 315 - let token = generate_signup_token(did, channel, identifier); 335 + let token = generate_signup_token(&did, channel, identifier); 316 336 let result = verify_signup_token(&token, channel, identifier); 317 337 assert!(result.is_ok(), "Expected Ok, got {:?}", result); 318 338 let parsed = result.unwrap(); ··· 323 343 324 344 #[test] 325 345 fn test_migration_token() { 326 - let did = "did:plc:test123"; 346 + let did: Did = "did:plc:test123".parse().unwrap(); 327 347 let email = "test@example.com"; 328 - let token = generate_migration_token(did, email); 348 + let token = generate_migration_token(&did, email); 329 349 let result = verify_migration_token(&token, email); 330 350 assert!(result.is_ok(), "Expected Ok, got {:?}", result); 331 351 let parsed = result.unwrap(); ··· 335 355 336 356 #[test] 337 357 fn test_token_case_insensitive() { 338 - let did = "did:plc:test123"; 339 - let token = generate_signup_token(did, "email", "Test@Example.COM"); 340 - let result = verify_signup_token(&token, "email", "test@example.com"); 358 + let did: Did = "did:plc:test123".parse().unwrap(); 359 + let token = generate_signup_token(&did, CommsChannel::Email, "Test@Example.COM"); 360 + let result = verify_signup_token(&token, CommsChannel::Email, "test@example.com"); 341 361 assert!(result.is_ok()); 342 362 } 343 363 344 364 #[test] 345 365 fn test_token_wrong_identifier() { 346 - let did = "did:plc:test123"; 347 - let token = generate_signup_token(did, "email", "test@example.com"); 348 - let result = verify_signup_token(&token, "email", "other@example.com"); 366 + let did: Did = "did:plc:test123".parse().unwrap(); 367 + let token = generate_signup_token(&did, CommsChannel::Email, "test@example.com"); 368 + let result = verify_signup_token(&token, CommsChannel::Email, "other@example.com"); 349 369 assert!(matches!(result, Err(VerifyError::IdentifierMismatch))); 350 370 } 351 371 352 372 #[test] 353 373 fn test_token_wrong_channel() { 354 - let did = "did:plc:test123"; 355 - let token = generate_signup_token(did, "email", "test@example.com"); 356 - let result = verify_signup_token(&token, "discord", "test@example.com"); 374 + let did: Did = "did:plc:test123".parse().unwrap(); 375 + let token = generate_signup_token(&did, CommsChannel::Email, "test@example.com"); 376 + let result = verify_signup_token(&token, CommsChannel::Discord, "test@example.com"); 357 377 assert!(matches!(result, Err(VerifyError::ChannelMismatch))); 358 378 } 359 379 360 380 #[test] 361 381 fn test_expired_token() { 362 - let did = "did:plc:test123"; 382 + let did: Did = "did:plc:test123".parse().unwrap(); 363 383 let token = generate_token_with_expiry( 364 - did, 384 + &did, 365 385 VerificationPurpose::Signup, 366 - "email", 386 + CommsChannel::Email, 367 387 "test@example.com", 368 388 0, 369 389 ); 370 390 std::thread::sleep(std::time::Duration::from_millis(1100)); 371 - let result = verify_signup_token(&token, "email", "test@example.com"); 391 + let result = verify_signup_token(&token, CommsChannel::Email, "test@example.com"); 372 392 assert!(matches!(result, Err(VerifyError::Expired))); 373 393 } 374 394 375 395 #[test] 376 396 fn test_invalid_token() { 377 - let result = verify_signup_token("invalid-token", "email", "test@example.com"); 397 + let result = verify_signup_token("invalid-token", CommsChannel::Email, "test@example.com"); 378 398 assert!(matches!(result, Err(VerifyError::InvalidFormat))); 379 399 } 380 400 381 401 #[test] 382 402 fn test_purpose_mismatch() { 383 - let did = "did:plc:test123"; 403 + let did: Did = "did:plc:test123".parse().unwrap(); 384 404 let email = "test@example.com"; 385 - let signup_token = generate_signup_token(did, "email", email); 405 + let signup_token = generate_signup_token(&did, CommsChannel::Email, email); 386 406 let result = verify_migration_token(&signup_token, email); 387 407 assert!(matches!(result, Err(VerifyError::PurposeMismatch))); 388 408 } 389 409 390 410 #[test] 391 411 fn test_discord_channel() { 392 - let did = "did:plc:test123"; 412 + let did: Did = "did:plc:test123".parse().unwrap(); 393 413 let discord_id = "123456789012345678"; 394 - let token = generate_signup_token(did, "discord", discord_id); 395 - let result = verify_signup_token(&token, "discord", discord_id); 414 + let token = generate_signup_token(&did, CommsChannel::Discord, discord_id); 415 + let result = verify_signup_token(&token, CommsChannel::Discord, discord_id); 396 416 assert!(result.is_ok()); 397 417 } 398 418
+24 -12
crates/tranquil-pds/src/auth/webauthn.rs
··· 1 1 use uuid::Uuid; 2 2 use webauthn_rs::prelude::*; 3 3 4 + #[derive(Debug, thiserror::Error)] 5 + pub enum WebauthnError { 6 + #[error("Invalid origin URL: {0}")] 7 + InvalidOrigin(String), 8 + #[error("Failed to create WebAuthn builder: {0}")] 9 + BuilderFailed(String), 10 + #[error("Registration failed: {0}")] 11 + RegistrationFailed(String), 12 + #[error("Authentication failed: {0}")] 13 + AuthenticationFailed(String), 14 + } 15 + 4 16 pub struct WebAuthnConfig { 5 17 webauthn: Webauthn, 6 18 } 7 19 8 20 impl WebAuthnConfig { 9 - pub fn new(hostname: &str) -> Result<Self, String> { 21 + pub fn new(hostname: &str) -> Result<Self, WebauthnError> { 10 22 let rp_id = hostname.split(':').next().unwrap_or(hostname).to_string(); 11 23 let rp_origin = Url::parse(&format!("https://{}", hostname)) 12 - .map_err(|e| format!("Invalid origin URL: {}", e))?; 24 + .map_err(|e| WebauthnError::InvalidOrigin(e.to_string()))?; 13 25 14 26 let builder = WebauthnBuilder::new(&rp_id, &rp_origin) 15 - .map_err(|e| format!("Failed to create WebAuthn builder: {}", e))? 27 + .map_err(|e| WebauthnError::BuilderFailed(e.to_string()))? 16 28 .rp_name("Tranquil PDS") 17 29 .danger_set_user_presence_only_security_keys(true); 18 30 19 31 let webauthn = builder 20 32 .build() 21 - .map_err(|e| format!("Failed to build WebAuthn: {}", e))?; 33 + .map_err(|e| WebauthnError::BuilderFailed(e.to_string()))?; 22 34 23 35 Ok(Self { webauthn }) 24 36 } ··· 29 41 username: &str, 30 42 display_name: &str, 31 43 exclude_credentials: Vec<CredentialID>, 32 - ) -> Result<(CreationChallengeResponse, SecurityKeyRegistration), String> { 44 + ) -> Result<(CreationChallengeResponse, SecurityKeyRegistration), WebauthnError> { 33 45 let user_unique_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, user_id.as_bytes()); 34 46 35 47 self.webauthn ··· 45 57 None, 46 58 None, 47 59 ) 48 - .map_err(|e| format!("Failed to start registration: {}", e)) 60 + .map_err(|e| WebauthnError::RegistrationFailed(e.to_string())) 49 61 } 50 62 51 63 pub fn finish_registration( 52 64 &self, 53 65 reg: &RegisterPublicKeyCredential, 54 66 state: &SecurityKeyRegistration, 55 - ) -> Result<SecurityKey, String> { 67 + ) -> Result<SecurityKey, WebauthnError> { 56 68 self.webauthn 57 69 .finish_securitykey_registration(reg, state) 58 - .map_err(|e| format!("Failed to finish registration: {}", e)) 70 + .map_err(|e| WebauthnError::RegistrationFailed(e.to_string())) 59 71 } 60 72 61 73 pub fn start_authentication( 62 74 &self, 63 75 credentials: Vec<SecurityKey>, 64 - ) -> Result<(RequestChallengeResponse, SecurityKeyAuthentication), String> { 76 + ) -> Result<(RequestChallengeResponse, SecurityKeyAuthentication), WebauthnError> { 65 77 self.webauthn 66 78 .start_securitykey_authentication(&credentials) 67 - .map_err(|e| format!("Failed to start authentication: {}", e)) 79 + .map_err(|e| WebauthnError::AuthenticationFailed(e.to_string())) 68 80 } 69 81 70 82 pub fn finish_authentication( 71 83 &self, 72 84 auth: &PublicKeyCredential, 73 85 state: &SecurityKeyAuthentication, 74 - ) -> Result<AuthenticationResult, String> { 86 + ) -> Result<AuthenticationResult, WebauthnError> { 75 87 self.webauthn 76 88 .finish_securitykey_authentication(auth, state) 77 - .map_err(|e| format!("Failed to finish authentication: {}", e)) 89 + .map_err(|e| WebauthnError::AuthenticationFailed(e.to_string())) 78 90 } 79 91 }
+35
crates/tranquil-pds/src/cache_keys.rs
··· 1 + pub fn session_key(did: &str, jti: &str) -> String { 2 + format!("auth:session:{}:{}", did, jti) 3 + } 4 + 5 + pub fn signing_key_key(did: &str) -> String { 6 + format!("auth:key:{}", did) 7 + } 8 + 9 + pub fn user_status_key(did: &str) -> String { 10 + format!("auth:status:{}", did) 11 + } 12 + 13 + pub fn handle_key(handle: &str) -> String { 14 + format!("handle:{}", handle) 15 + } 16 + 17 + pub fn reauth_key(did: &str) -> String { 18 + format!("reauth:{}", did) 19 + } 20 + 21 + pub fn plc_doc_key(did: &str) -> String { 22 + format!("plc:doc:{}", did) 23 + } 24 + 25 + pub fn plc_data_key(did: &str) -> String { 26 + format!("plc:data:{}", did) 27 + } 28 + 29 + pub fn email_update_key(did: &str) -> String { 30 + format!("email_update:{}", did) 31 + } 32 + 33 + pub fn scope_ref_key(cid: &str) -> String { 34 + format!("scope_ref:{}", cid) 35 + }
+35 -19
crates/tranquil-pds/src/circuit_breaker.rs
··· 1 + use std::num::{NonZeroU32, NonZeroU64}; 1 2 use std::sync::Arc; 2 3 use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 3 4 use std::time::Duration; ··· 24 25 impl CircuitBreaker { 25 26 pub fn new( 26 27 name: &str, 27 - failure_threshold: u32, 28 - success_threshold: u32, 29 - timeout_secs: u64, 28 + failure_threshold: NonZeroU32, 29 + success_threshold: NonZeroU32, 30 + timeout_secs: NonZeroU64, 30 31 ) -> Self { 31 32 Self { 32 33 name: name.to_string(), 33 - failure_threshold, 34 - success_threshold, 35 - timeout: Duration::from_secs(timeout_secs), 34 + failure_threshold: failure_threshold.get(), 35 + success_threshold: success_threshold.get(), 36 + timeout: Duration::from_secs(timeout_secs.get()), 36 37 state: Arc::new(RwLock::new(CircuitState::Closed)), 37 38 failure_count: AtomicU32::new(0), 38 39 success_count: AtomicU32::new(0), ··· 49 50 let last_failure = self.last_failure_time.load(Ordering::SeqCst); 50 51 let now = std::time::SystemTime::now() 51 52 .duration_since(std::time::UNIX_EPOCH) 52 - .unwrap() 53 + .unwrap_or_default() 53 54 .as_secs(); 54 55 55 - if now - last_failure >= self.timeout.as_secs() { 56 + if now.saturating_sub(last_failure) >= self.timeout.as_secs() { 56 57 drop(state); 57 58 let mut state = self.state.write().await; 58 59 if *state == CircuitState::Open { ··· 100 101 *state = CircuitState::Open; 101 102 let now = std::time::SystemTime::now() 102 103 .duration_since(std::time::UNIX_EPOCH) 103 - .unwrap() 104 + .unwrap_or_default() 104 105 .as_secs(); 105 106 self.last_failure_time.store(now, Ordering::SeqCst); 106 107 tracing::warn!( ··· 150 151 impl CircuitBreakers { 151 152 pub fn new() -> Self { 152 153 Self { 153 - plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)), 154 - relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)), 154 + plc_directory: Arc::new(CircuitBreaker::new( 155 + "plc_directory", 156 + const { NonZeroU32::new(5).unwrap() }, 157 + const { NonZeroU32::new(3).unwrap() }, 158 + const { NonZeroU64::new(60).unwrap() }, 159 + )), 160 + relay_notification: Arc::new(CircuitBreaker::new( 161 + "relay_notification", 162 + const { NonZeroU32::new(10).unwrap() }, 163 + const { NonZeroU32::new(5).unwrap() }, 164 + const { NonZeroU64::new(30).unwrap() }, 165 + )), 155 166 } 156 167 } 157 168 } ··· 223 234 mod tests { 224 235 use super::*; 225 236 237 + const TEST_FAILURE: NonZeroU32 = const { NonZeroU32::new(3).unwrap() }; 238 + const TEST_SUCCESS: NonZeroU32 = const { NonZeroU32::new(2).unwrap() }; 239 + const TEST_TIMEOUT: NonZeroU64 = const { NonZeroU64::new(10).unwrap() }; 240 + const TEST_ZERO_TIMEOUT: NonZeroU64 = const { NonZeroU64::new(1).unwrap() }; 241 + 226 242 #[tokio::test] 227 243 async fn test_circuit_breaker_starts_closed() { 228 - let cb = CircuitBreaker::new("test", 3, 2, 10); 244 + let cb = CircuitBreaker::new("test", TEST_FAILURE, TEST_SUCCESS, TEST_TIMEOUT); 229 245 assert_eq!(cb.state().await, CircuitState::Closed); 230 246 assert!(cb.can_execute().await); 231 247 } 232 248 233 249 #[tokio::test] 234 250 async fn test_circuit_breaker_opens_after_failures() { 235 - let cb = CircuitBreaker::new("test", 3, 2, 10); 251 + let cb = CircuitBreaker::new("test", TEST_FAILURE, TEST_SUCCESS, TEST_TIMEOUT); 236 252 237 253 cb.record_failure().await; 238 254 assert_eq!(cb.state().await, CircuitState::Closed); ··· 247 263 248 264 #[tokio::test] 249 265 async fn test_circuit_breaker_success_resets_failures() { 250 - let cb = CircuitBreaker::new("test", 3, 2, 10); 266 + let cb = CircuitBreaker::new("test", TEST_FAILURE, TEST_SUCCESS, TEST_TIMEOUT); 251 267 252 268 cb.record_failure().await; 253 269 cb.record_failure().await; ··· 263 279 264 280 #[tokio::test] 265 281 async fn test_circuit_breaker_half_open_closes_after_successes() { 266 - let cb = CircuitBreaker::new("test", 3, 2, 0); 282 + let cb = CircuitBreaker::new("test", TEST_FAILURE, TEST_SUCCESS, TEST_ZERO_TIMEOUT); 267 283 268 284 futures::future::join_all((0..3).map(|_| cb.record_failure())).await; 269 285 assert_eq!(cb.state().await, CircuitState::Open); 270 286 271 - tokio::time::sleep(Duration::from_millis(100)).await; 287 + tokio::time::sleep(Duration::from_millis(1100)).await; 272 288 assert!(cb.can_execute().await); 273 289 assert_eq!(cb.state().await, CircuitState::HalfOpen); 274 290 ··· 281 297 282 298 #[tokio::test] 283 299 async fn test_circuit_breaker_half_open_reopens_on_failure() { 284 - let cb = CircuitBreaker::new("test", 3, 2, 0); 300 + let cb = CircuitBreaker::new("test", TEST_FAILURE, TEST_SUCCESS, TEST_ZERO_TIMEOUT); 285 301 286 302 futures::future::join_all((0..3).map(|_| cb.record_failure())).await; 287 303 288 - tokio::time::sleep(Duration::from_millis(100)).await; 304 + tokio::time::sleep(Duration::from_millis(1100)).await; 289 305 cb.can_execute().await; 290 306 291 307 cb.record_failure().await; ··· 294 310 295 311 #[tokio::test] 296 312 async fn test_with_circuit_breaker_helper() { 297 - let cb = CircuitBreaker::new("test", 3, 2, 10); 313 + let cb = CircuitBreaker::new("test", TEST_FAILURE, TEST_SUCCESS, TEST_TIMEOUT); 298 314 299 315 let result: Result<i32, CircuitBreakerError<std::io::Error>> = 300 316 with_circuit_breaker(&cb, || async { Ok(42) }).await;
+1 -1
crates/tranquil-pds/src/comms/mod.rs
··· 7 7 mime_encode_header, sanitize_header_value, validate_locale, 8 8 }; 9 9 10 - pub use service::{CommsService, channel_display_name, repo as comms_repo}; 10 + pub use service::{CommsService, repo as comms_repo};
+22 -123
crates/tranquil-pds/src/comms/service.rs
··· 7 7 use tokio_util::sync::CancellationToken; 8 8 use tracing::{debug, error, info, warn}; 9 9 use tranquil_comms::{ 10 - CommsChannel, CommsSender, CommsStatus, CommsType, NewComms, SendError, format_message, 11 - get_strings, 10 + CommsChannel, CommsSender, CommsType, NewComms, SendError, format_message, get_strings, 12 11 }; 13 12 use tranquil_db_traits::{InfraRepository, QueuedComms, UserCommsPrefs, UserRepository}; 14 13 use uuid::Uuid; ··· 54 53 } 55 54 56 55 pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, tranquil_db_traits::DbError> { 57 - let channel = match item.channel { 58 - CommsChannel::Email => tranquil_db_traits::CommsChannel::Email, 59 - CommsChannel::Discord => tranquil_db_traits::CommsChannel::Discord, 60 - CommsChannel::Telegram => tranquil_db_traits::CommsChannel::Telegram, 61 - CommsChannel::Signal => tranquil_db_traits::CommsChannel::Signal, 62 - }; 63 - let comms_type = match item.comms_type { 64 - CommsType::Welcome => tranquil_db_traits::CommsType::Welcome, 65 - CommsType::EmailVerification => tranquil_db_traits::CommsType::EmailVerification, 66 - CommsType::PasswordReset => tranquil_db_traits::CommsType::PasswordReset, 67 - CommsType::EmailUpdate => tranquil_db_traits::CommsType::EmailUpdate, 68 - CommsType::AccountDeletion => tranquil_db_traits::CommsType::AccountDeletion, 69 - CommsType::AdminEmail => tranquil_db_traits::CommsType::AdminEmail, 70 - CommsType::PlcOperation => tranquil_db_traits::CommsType::PlcOperation, 71 - CommsType::TwoFactorCode => tranquil_db_traits::CommsType::TwoFactorCode, 72 - CommsType::PasskeyRecovery => tranquil_db_traits::CommsType::PasskeyRecovery, 73 - CommsType::LegacyLoginAlert => tranquil_db_traits::CommsType::LegacyLoginAlert, 74 - CommsType::MigrationVerification => { 75 - tranquil_db_traits::CommsType::MigrationVerification 76 - } 77 - CommsType::ChannelVerification => tranquil_db_traits::CommsType::ChannelVerification, 78 - CommsType::ChannelVerified => tranquil_db_traits::CommsType::ChannelVerified, 79 - }; 80 56 let id = self 81 57 .infra_repo 82 58 .enqueue_comms( 83 59 Some(item.user_id), 84 - channel, 85 - comms_type, 60 + item.channel, 61 + item.comms_type, 86 62 &item.recipient, 87 63 item.subject.as_deref(), 88 64 &item.body, ··· 144 120 145 121 async fn process_item(&self, item: QueuedComms) { 146 122 let comms_id = item.id; 147 - let channel = match item.channel { 148 - tranquil_db_traits::CommsChannel::Email => CommsChannel::Email, 149 - tranquil_db_traits::CommsChannel::Discord => CommsChannel::Discord, 150 - tranquil_db_traits::CommsChannel::Telegram => CommsChannel::Telegram, 151 - tranquil_db_traits::CommsChannel::Signal => CommsChannel::Signal, 152 - }; 153 - let comms_item = tranquil_comms::QueuedComms { 154 - id: item.id, 155 - user_id: item.user_id, 156 - channel, 157 - comms_type: match item.comms_type { 158 - tranquil_db_traits::CommsType::Welcome => CommsType::Welcome, 159 - tranquil_db_traits::CommsType::EmailVerification => CommsType::EmailVerification, 160 - tranquil_db_traits::CommsType::PasswordReset => CommsType::PasswordReset, 161 - tranquil_db_traits::CommsType::EmailUpdate => CommsType::EmailUpdate, 162 - tranquil_db_traits::CommsType::AccountDeletion => CommsType::AccountDeletion, 163 - tranquil_db_traits::CommsType::AdminEmail => CommsType::AdminEmail, 164 - tranquil_db_traits::CommsType::PlcOperation => CommsType::PlcOperation, 165 - tranquil_db_traits::CommsType::TwoFactorCode => CommsType::TwoFactorCode, 166 - tranquil_db_traits::CommsType::PasskeyRecovery => CommsType::PasskeyRecovery, 167 - tranquil_db_traits::CommsType::LegacyLoginAlert => CommsType::LegacyLoginAlert, 168 - tranquil_db_traits::CommsType::MigrationVerification => { 169 - CommsType::MigrationVerification 170 - } 171 - tranquil_db_traits::CommsType::ChannelVerification => { 172 - CommsType::ChannelVerification 173 - } 174 - tranquil_db_traits::CommsType::ChannelVerified => CommsType::ChannelVerified, 175 - }, 176 - status: match item.status { 177 - tranquil_db_traits::CommsStatus::Pending => CommsStatus::Pending, 178 - tranquil_db_traits::CommsStatus::Processing => CommsStatus::Processing, 179 - tranquil_db_traits::CommsStatus::Sent => CommsStatus::Sent, 180 - tranquil_db_traits::CommsStatus::Failed => CommsStatus::Failed, 181 - }, 182 - recipient: item.recipient, 183 - subject: item.subject, 184 - body: item.body, 185 - metadata: item.metadata, 186 - attempts: item.attempts, 187 - max_attempts: item.max_attempts, 188 - last_error: item.last_error, 189 - created_at: item.created_at, 190 - updated_at: item.updated_at, 191 - scheduled_for: item.scheduled_for, 192 - processed_at: item.processed_at, 193 - }; 194 - let result = match self.senders.get(&channel) { 195 - Some(sender) => sender.send(&comms_item).await, 123 + let result = match self.senders.get(&item.channel) { 124 + Some(sender) => sender.send(&item).await, 196 125 None => { 197 126 warn!( 198 127 comms_id = %comms_id, 199 - channel = ?channel, 128 + channel = ?item.channel, 200 129 "No sender registered for channel" 201 130 ); 202 - Err(SendError::NotConfigured(channel)) 131 + Err(SendError::NotConfigured(item.channel)) 203 132 } 204 133 }; 205 134 match result { ··· 240 169 } 241 170 } 242 171 243 - pub fn channel_display_name(channel: CommsChannel) -> &'static str { 244 - match channel { 245 - CommsChannel::Email => "email", 246 - CommsChannel::Discord => "Discord", 247 - CommsChannel::Telegram => "Telegram", 248 - CommsChannel::Signal => "Signal", 249 - } 250 - } 251 - 252 172 struct ResolvedRecipient { 253 173 channel: tranquil_db_traits::CommsChannel, 254 174 recipient: String, ··· 289 209 recipient: n.clone(), 290 210 }) 291 211 .unwrap_or_else(email_fallback), 292 - } 293 - } 294 - 295 - fn channel_from_str(s: &str) -> tranquil_db_traits::CommsChannel { 296 - match s { 297 - "discord" => tranquil_db_traits::CommsChannel::Discord, 298 - "telegram" => tranquil_db_traits::CommsChannel::Telegram, 299 - "signal" => tranquil_db_traits::CommsChannel::Signal, 300 - _ => tranquil_db_traits::CommsChannel::Email, 301 212 } 302 213 } 303 214 ··· 453 364 infra_repo: &dyn InfraRepository, 454 365 user_id: Uuid, 455 366 token: &str, 456 - purpose: &str, 457 367 hostname: &str, 458 368 ) -> Result<Uuid, DbError> { 459 369 let prefs = user_repo ··· 463 373 let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 464 374 let current_email = prefs.email.clone().unwrap_or_default(); 465 375 466 - let (subject_template, body_template, comms_type) = match purpose { 467 - "email_update" => ( 468 - strings.email_update_subject, 469 - strings.short_token_body, 470 - CommsType::EmailUpdate, 471 - ), 472 - _ => ( 473 - strings.email_update_subject, 474 - strings.short_token_body, 475 - CommsType::EmailUpdate, 476 - ), 477 - }; 376 + let subject_template = strings.email_update_subject; 377 + let body_template = strings.short_token_body; 378 + let comms_type = CommsType::EmailUpdate; 478 379 479 380 let verify_page = format!("https://{}/app/settings", hostname); 480 381 let body = format_message( ··· 642 543 user_repo: &dyn UserRepository, 643 544 infra_repo: &dyn InfraRepository, 644 545 user_id: Uuid, 645 - channel: &str, 546 + channel: tranquil_db_traits::CommsChannel, 646 547 recipient: &str, 647 548 code: &str, 648 549 hostname: &str, 649 550 ) -> Result<Uuid, DbError> { 650 - let comms_channel = channel_from_str(channel); 651 - let prefs = user_repo.get_comms_prefs(user_id).await.ok().flatten(); 551 + let comms_channel = channel; 552 + let prefs = match user_repo.get_comms_prefs(user_id).await { 553 + Ok(p) => p, 554 + Err(e) => { 555 + tracing::warn!(user_id = %user_id, error = %e, "failed to fetch comms preferences, using defaults"); 556 + None 557 + } 558 + }; 652 559 let locale = prefs 653 560 .as_ref() 654 561 .and_then(|p| p.preferred_locale.as_deref()) ··· 762 669 user_repo: &dyn UserRepository, 763 670 infra_repo: &dyn InfraRepository, 764 671 user_id: Uuid, 765 - channel_name: &str, 672 + channel: tranquil_db_traits::CommsChannel, 766 673 recipient: &str, 767 674 hostname: &str, 768 675 ) -> Result<Uuid, DbError> { ··· 771 678 .await? 772 679 .ok_or(DbError::NotFound)?; 773 680 let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 774 - let display_name = match channel_name { 775 - "email" => "Email", 776 - "discord" => "Discord", 777 - "telegram" => "Telegram", 778 - "signal" => "Signal", 779 - other => other, 780 - }; 781 681 let body = format_message( 782 682 strings.channel_verified_body, 783 683 &[ 784 684 ("handle", &prefs.handle), 785 - ("channel", display_name), 685 + ("channel", channel.display_name()), 786 686 ("hostname", hostname), 787 687 ], 788 688 ); 789 689 let subject = format_message(strings.channel_verified_subject, &[("hostname", hostname)]); 790 - let comms_channel = channel_from_str(channel_name); 791 690 infra_repo 792 691 .enqueue_comms( 793 692 Some(user_id), 794 - comms_channel, 693 + channel, 795 694 CommsType::ChannelVerified, 796 695 recipient, 797 696 Some(&subject),
+55 -13
crates/tranquil-pds/src/config.rs
··· 6 6 use sha2::{Digest, Sha256}; 7 7 use std::sync::OnceLock; 8 8 9 + #[derive(Debug)] 10 + pub enum CryptoError { 11 + CipherCreationFailed(String), 12 + EncryptionFailed(String), 13 + DecryptionFailed(String), 14 + DataTooShort, 15 + UnknownEncryptionVersion(i32), 16 + } 17 + 18 + impl std::fmt::Display for CryptoError { 19 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 20 + match self { 21 + Self::CipherCreationFailed(e) => write!(f, "Failed to create cipher: {}", e), 22 + Self::EncryptionFailed(e) => write!(f, "Encryption failed: {}", e), 23 + Self::DecryptionFailed(e) => write!(f, "Decryption failed: {}", e), 24 + Self::DataTooShort => write!(f, "Encrypted data too short"), 25 + Self::UnknownEncryptionVersion(v) => write!(f, "Unknown encryption version: {}", v), 26 + } 27 + } 28 + } 29 + 30 + impl std::error::Error for CryptoError {} 31 + 9 32 static CONFIG: OnceLock<AuthConfig> = OnceLock::new(); 10 33 11 34 pub const ENCRYPTION_VERSION: i32 = 1; ··· 208 231 } 209 232 } 210 233 211 - pub fn encrypt_user_key(&self, plaintext: &[u8]) -> Result<Vec<u8>, String> { 234 + pub fn encrypt_user_key(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> { 212 235 use rand::RngCore; 213 236 214 237 let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key) 215 - .map_err(|e| format!("Failed to create cipher: {}", e))?; 238 + .map_err(|e| CryptoError::CipherCreationFailed(e.to_string()))?; 216 239 217 240 let mut nonce_bytes = [0u8; 12]; 218 241 rand::thread_rng().fill_bytes(&mut nonce_bytes); ··· 222 245 223 246 let ciphertext = cipher 224 247 .encrypt(nonce, plaintext) 225 - .map_err(|e| format!("Encryption failed: {}", e))?; 248 + .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?; 226 249 227 250 let mut result = Vec::with_capacity(12 + ciphertext.len()); 228 251 result.extend_from_slice(&nonce_bytes); ··· 231 254 Ok(result) 232 255 } 233 256 234 - pub fn decrypt_user_key(&self, encrypted: &[u8]) -> Result<Vec<u8>, String> { 257 + pub fn decrypt_user_key(&self, encrypted: &[u8]) -> Result<Vec<u8>, CryptoError> { 235 258 if encrypted.len() < 12 { 236 - return Err("Encrypted data too short".to_string()); 259 + return Err(CryptoError::DataTooShort); 237 260 } 238 261 239 262 let cipher = Aes256Gcm::new_from_slice(&self.key_encryption_key) 240 - .map_err(|e| format!("Failed to create cipher: {}", e))?; 263 + .map_err(|e| CryptoError::CipherCreationFailed(e.to_string()))?; 241 264 242 265 #[allow(deprecated)] 243 266 let nonce = Nonce::from_slice(&encrypted[..12]); ··· 245 268 246 269 cipher 247 270 .decrypt(nonce, ciphertext) 248 - .map_err(|e| format!("Decryption failed: {}", e)) 271 + .map_err(|e| CryptoError::DecryptionFailed(e.to_string())) 249 272 } 250 273 } 251 274 252 - pub fn encrypt_key(plaintext: &[u8]) -> Result<Vec<u8>, String> { 275 + pub fn encrypt_key(plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> { 253 276 AuthConfig::get().encrypt_user_key(plaintext) 254 277 } 255 278 256 - pub fn decrypt_key(encrypted: &[u8], version: Option<i32>) -> Result<Vec<u8>, String> { 257 - match version.unwrap_or(0) { 258 - 0 => Ok(encrypted.to_vec()), 259 - 1 => AuthConfig::get().decrypt_user_key(encrypted), 260 - v => Err(format!("Unknown encryption version: {}", v)), 279 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 280 + pub enum EncryptionVersion { 281 + Unencrypted, 282 + AesGcm, 283 + } 284 + 285 + impl EncryptionVersion { 286 + pub fn from_db(version: Option<i32>) -> Result<Self, CryptoError> { 287 + match version.unwrap_or(0) { 288 + 0 => Ok(Self::Unencrypted), 289 + 1 => Ok(Self::AesGcm), 290 + v => Err(CryptoError::UnknownEncryptionVersion(v)), 291 + } 292 + } 293 + 294 + pub fn from_db_required(version: i32) -> Result<Self, CryptoError> { 295 + Self::from_db(Some(version)) 296 + } 297 + } 298 + 299 + pub fn decrypt_key(encrypted: &[u8], version: Option<i32>) -> Result<Vec<u8>, CryptoError> { 300 + match EncryptionVersion::from_db(version)? { 301 + EncryptionVersion::Unencrypted => Ok(encrypted.to_vec()), 302 + EncryptionVersion::AesGcm => AuthConfig::get().decrypt_user_key(encrypted), 261 303 } 262 304 }
+12
crates/tranquil-pds/src/delegation/scopes.rs
··· 163 163 fn test_validate_scopes_invalid() { 164 164 assert!(validate_delegation_scopes("invalid:scope").is_err()); 165 165 } 166 + 167 + #[test] 168 + fn test_scope_presets_parse() { 169 + SCOPE_PRESETS.iter().for_each(|p| { 170 + validate_delegation_scopes(p.scopes).unwrap_or_else(|e| { 171 + panic!( 172 + "preset '{}' has invalid scopes '{}': {}", 173 + p.name, p.scopes, e 174 + ) 175 + }); 176 + }); 177 + } 166 178 }
+8 -4
crates/tranquil-pds/src/image/mod.rs
··· 197 197 max_size: u32, 198 198 ) -> Result<ProcessedImage, ImageError> { 199 199 let (orig_width, orig_height) = (img.width(), img.height()); 200 + let safe_f64_to_u32 = 201 + |v: f64| -> u32 { u32::try_from(v.round() as u64).unwrap_or(u32::MAX) }; 200 202 let (new_width, new_height) = if orig_width > orig_height { 201 - let ratio = max_size as f64 / orig_width as f64; 202 - (max_size, (orig_height as f64 * ratio) as u32) 203 + let ratio = f64::from(max_size) / f64::from(orig_width); 204 + let scaled = safe_f64_to_u32((f64::from(orig_height) * ratio).max(1.0)); 205 + (max_size, scaled.min(max_size)) 203 206 } else { 204 - let ratio = max_size as f64 / orig_height as f64; 205 - ((orig_width as f64 * ratio) as u32, max_size) 207 + let ratio = f64::from(max_size) / f64::from(orig_height); 208 + let scaled = safe_f64_to_u32((f64::from(orig_width) * ratio).max(1.0)); 209 + (scaled.min(max_size), max_size) 206 210 }; 207 211 let thumb = img.resize(new_width, new_height, FilterType::Lanczos3); 208 212 self.encode_image(&thumb)
+77 -12
crates/tranquil-pds/src/lib.rs
··· 2 2 pub mod appview; 3 3 pub mod auth; 4 4 pub mod cache; 5 + pub mod cache_keys; 5 6 pub mod cid_types; 6 7 pub mod circuit_breaker; 7 8 pub mod comms; ··· 658 659 .layer(DefaultBodyLimit::max(64 * 1024)), 659 660 ) 660 661 .layer(DefaultBodyLimit::max(util::get_max_blob_size())) 662 + .layer(axum::middleware::map_response(rewrite_422_to_400)) 661 663 .layer(middleware::from_fn(metrics::metrics_middleware)) 662 664 .layer( 663 665 CorsLayer::new() 664 666 .allow_origin(Any) 665 667 .allow_methods([Method::GET, Method::POST, Method::OPTIONS]) 666 668 .allow_headers([ 667 - "Authorization".parse().unwrap(), 668 - "Content-Type".parse().unwrap(), 669 - "Content-Encoding".parse().unwrap(), 670 - "Accept-Encoding".parse().unwrap(), 671 - "DPoP".parse().unwrap(), 672 - "atproto-proxy".parse().unwrap(), 673 - "atproto-accept-labelers".parse().unwrap(), 674 - "x-bsky-topics".parse().unwrap(), 669 + http::header::AUTHORIZATION, 670 + http::header::CONTENT_TYPE, 671 + http::header::CONTENT_ENCODING, 672 + http::header::ACCEPT_ENCODING, 673 + util::HEADER_DPOP, 674 + util::HEADER_ATPROTO_PROXY, 675 + util::HEADER_ATPROTO_ACCEPT_LABELERS, 676 + util::HEADER_X_BSKY_TOPICS, 675 677 ]) 676 678 .expose_headers([ 677 - "WWW-Authenticate".parse().unwrap(), 678 - "DPoP-Nonce".parse().unwrap(), 679 - "atproto-repo-rev".parse().unwrap(), 680 - "atproto-content-labelers".parse().unwrap(), 679 + http::header::WWW_AUTHENTICATE, 680 + util::HEADER_DPOP_NONCE, 681 + util::HEADER_ATPROTO_REPO_REV, 682 + util::HEADER_ATPROTO_CONTENT_LABELERS, 681 683 ]), 682 684 ) 683 685 .with_state(state) 684 686 } 687 + 688 + async fn rewrite_422_to_400(response: axum::response::Response) -> axum::response::Response { 689 + if response.status() != StatusCode::UNPROCESSABLE_ENTITY { 690 + return response; 691 + } 692 + let (mut parts, body) = response.into_parts(); 693 + let bytes = match axum::body::to_bytes(body, 64 * 1024).await { 694 + Ok(b) => b, 695 + Err(_) => { 696 + parts.status = StatusCode::BAD_REQUEST; 697 + parts.headers.remove(http::header::CONTENT_LENGTH); 698 + let fallback = json!({"error": "InvalidRequest", "message": "Invalid request body"}); 699 + return axum::response::Response::from_parts( 700 + parts, 701 + axum::body::Body::from(serde_json::to_vec(&fallback).unwrap_or_default()), 702 + ); 703 + } 704 + }; 705 + let raw = serde_json::from_slice::<serde_json::Value>(&bytes) 706 + .ok() 707 + .and_then(|v| v.get("message").and_then(|m| m.as_str()).map(String::from)) 708 + .unwrap_or_else(|| { 709 + String::from_utf8(bytes.to_vec()).unwrap_or_else(|_| "Invalid request body".into()) 710 + }); 711 + let message = humanize_json_error(&raw); 712 + 713 + parts.status = StatusCode::BAD_REQUEST; 714 + parts.headers.remove(http::header::CONTENT_LENGTH); 715 + let error_name = classify_deserialization_error(&raw); 716 + let new_body = json!({ 717 + "error": error_name, 718 + "message": message 719 + }); 720 + axum::response::Response::from_parts( 721 + parts, 722 + axum::body::Body::from(serde_json::to_vec(&new_body).unwrap_or_default()), 723 + ) 724 + } 725 + 726 + fn humanize_json_error(raw: &str) -> String { 727 + if raw.contains("missing field") { 728 + raw.split("missing field `") 729 + .nth(1) 730 + .and_then(|s| s.split('`').next()) 731 + .map(|field| format!("Missing required field: {}", field)) 732 + .unwrap_or_else(|| raw.to_string()) 733 + } else if raw.contains("invalid type") { 734 + format!("Invalid field type: {}", raw) 735 + } else if raw.contains("Invalid JSON") || raw.contains("syntax") { 736 + "Invalid JSON syntax".to_string() 737 + } else if raw.contains("Content-Type") || raw.contains("content type") { 738 + "Content-Type must be application/json".to_string() 739 + } else { 740 + raw.to_string() 741 + } 742 + } 743 + 744 + fn classify_deserialization_error(raw: &str) -> &'static str { 745 + match raw { 746 + s if s.contains("invalid handle") => "InvalidHandle", 747 + _ => "InvalidRequest", 748 + } 749 + }
+49 -43
crates/tranquil-pds/src/oauth/endpoints/authorize.rs
··· 1 1 use crate::auth::{BareLoginIdentifier, NormalizedLoginIdentifier}; 2 - use crate::comms::{channel_display_name, comms_repo::enqueue_2fa_code}; 2 + use crate::comms::comms_repo::enqueue_2fa_code; 3 3 use crate::oauth::{ 4 4 AuthFlow, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, Prompt, SessionId, 5 5 db::should_show_consent, scopes::expand_include_scopes, ··· 23 23 use chrono::Utc; 24 24 use serde::{Deserialize, Serialize}; 25 25 use subtle::ConstantTimeEq; 26 - use tranquil_db_traits::ScopePreference; 26 + use tranquil_db_traits::{ScopePreference, WebauthnChallengeType}; 27 27 use tranquil_types::{AuthorizationCode, ClientId, DeviceId as DeviceIdType, RequestId}; 28 28 use urlencoding::encode as url_encode; 29 29 ··· 85 85 || s.starts_with("include:") 86 86 } 87 87 88 - fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 88 + fn extract_device_cookie(headers: &HeaderMap) -> Option<tranquil_types::DeviceId> { 89 89 headers 90 90 .get("cookie") 91 91 .and_then(|v| v.to_str().ok()) ··· 94 94 cookie 95 95 .strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) 96 96 .and_then(|value| crate::config::AuthConfig::get().verify_device_cookie(value)) 97 + .map(tranquil_types::DeviceId::new) 97 98 }) 98 99 }) 99 100 } ··· 105 106 .map(|s| s.to_string()) 106 107 } 107 108 108 - fn make_device_cookie(device_id: &str) -> String { 109 - let signed_value = crate::config::AuthConfig::get().sign_device_cookie(device_id); 109 + fn make_device_cookie(device_id: &tranquil_types::DeviceId) -> String { 110 + let signed_value = crate::config::AuthConfig::get().sign_device_cookie(device_id.as_str()); 110 111 format!( 111 112 "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000", 112 113 DEVICE_COOKIE_NAME, signed_value ··· 313 314 && let Some(device_id) = extract_device_cookie(&headers) 314 315 && let Ok(accounts) = state 315 316 .oauth_repo 316 - .get_device_accounts(&DeviceIdType::from(device_id.clone())) 317 + .get_device_accounts(&device_id.clone()) 317 318 .await 318 319 && !accounts.is_empty() 319 320 { ··· 419 420 .into_response(); 420 421 } 421 422 }; 422 - let device_id_typed = DeviceIdType::from(device_id.clone()); 423 - let accounts = match state.oauth_repo.get_device_accounts(&device_id_typed).await { 423 + let accounts = match state.oauth_repo.get_device_accounts(&device_id).await { 424 424 Ok(accts) => accts, 425 425 Err(_) => { 426 426 return Json(AccountsResponse { ··· 693 693 "Failed to enqueue 2FA notification" 694 694 ); 695 695 } 696 - let channel_name = channel_display_name(user.preferred_comms_channel); 696 + let channel_name = user.preferred_comms_channel.display_name(); 697 697 if json_response { 698 698 return Json(serde_json::json!({ 699 699 "needs_2fa": true, ··· 712 712 } 713 713 } 714 714 } 715 - let mut device_id: Option<String> = extract_device_cookie(&headers); 715 + let mut device_id: Option<DeviceIdType> = extract_device_cookie(&headers); 716 716 let mut new_cookie: Option<String> = None; 717 717 if form.remember_device { 718 718 let final_device_id = if let Some(existing_id) = &device_id { 719 719 existing_id.clone() 720 720 } else { 721 721 let new_id = DeviceId::generate(); 722 + let new_device_id_typed = DeviceIdType::new(new_id.0.clone()); 722 723 let device_data = DeviceData { 723 724 session_id: SessionId::generate(), 724 725 user_agent: extract_user_agent(&headers), 725 726 ip_address: extract_client_ip(&headers, None), 726 727 last_seen_at: Utc::now(), 727 728 }; 728 - let new_device_id_typed = DeviceIdType::from(new_id.0.clone()); 729 729 if state 730 730 .oauth_repo 731 731 .create_device(&new_device_id_typed, &device_data) 732 732 .await 733 733 .is_ok() 734 734 { 735 - new_cookie = Some(make_device_cookie(&new_id.0)); 736 - device_id = Some(new_id.0.clone()); 735 + new_cookie = Some(make_device_cookie(&new_device_id_typed)); 736 + device_id = Some(new_device_id_typed.clone()); 737 737 } 738 - new_id.0 738 + new_device_id_typed 739 739 }; 740 - let final_device_typed = DeviceIdType::from(final_device_id.clone()); 741 740 let _ = state 742 741 .oauth_repo 743 - .upsert_account_device(&user.did, &final_device_typed) 742 + .upsert_account_device(&user.did, &final_device_id) 744 743 .await; 745 744 } 746 - let set_auth_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 745 + let set_auth_device_id = device_id.clone(); 747 746 if state 748 747 .oauth_repo 749 748 .set_authorization_did(&form_request_id, &user.did, set_auth_device_id.as_ref()) ··· 796 795 return redirect_see_other(&consent_url); 797 796 } 798 797 let code = Code::generate(); 799 - let auth_post_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 798 + let auth_post_device_id = device_id.clone(); 800 799 let auth_post_code = AuthorizationCode::from(code.0.clone()); 801 800 if state 802 801 .oauth_repo ··· 915 914 ); 916 915 } 917 916 }; 918 - let verify_device_id = DeviceIdType::from(device_id.clone()); 917 + let verify_device_id = device_id.clone(); 919 918 let account_valid = match state 920 919 .oauth_repo 921 920 .verify_account_on_device(&verify_device_id, &did) ··· 963 962 ); 964 963 } 965 964 let has_totp = crate::api::server::has_totp_enabled(&state, &did).await; 966 - let select_early_device_typed = DeviceIdType::from(device_id.clone()); 965 + let select_early_device_typed = device_id.clone(); 967 966 if has_totp { 968 967 if state 969 968 .oauth_repo ··· 1009 1008 "Failed to enqueue 2FA notification" 1010 1009 ); 1011 1010 } 1012 - let channel_name = channel_display_name(user.preferred_comms_channel); 1011 + let channel_name = user.preferred_comms_channel.display_name(); 1013 1012 return Json(serde_json::json!({ 1014 1013 "needs_2fa": true, 1015 1014 "channel": channel_name ··· 1025 1024 } 1026 1025 } 1027 1026 } 1028 - let select_device_typed = DeviceIdType::from(device_id.clone()); 1027 + let select_device_typed = device_id.clone(); 1029 1028 let _ = state 1030 1029 .oauth_repo 1031 1030 .upsert_account_device(&did, &select_device_typed) ··· 1714 1713 let consent_post_device_id = request_data 1715 1714 .device_id 1716 1715 .as_ref() 1717 - .map(|d| DeviceIdType::from(d.0.clone())); 1716 + .map(|d| DeviceIdType::new(d.0.clone())); 1718 1717 let consent_post_code = AuthorizationCode::from(code.0.clone()); 1719 1718 if state 1720 1719 .oauth_repo ··· 1837 1836 let _ = state.oauth_repo.delete_2fa_challenge(challenge.id).await; 1838 1837 let code = Code::generate(); 1839 1838 let device_id = extract_device_cookie(&headers); 1840 - let twofa_totp_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 1839 + let twofa_totp_device_id = device_id.clone(); 1841 1840 let twofa_totp_code = AuthorizationCode::from(code.0.clone()); 1842 1841 if state 1843 1842 .oauth_repo ··· 1945 1944 return Json(serde_json::json!({"redirect_uri": consent_url})).into_response(); 1946 1945 } 1947 1946 let code = Code::generate(); 1948 - let twofa_final_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 1947 + let twofa_final_device_id = device_id.clone(); 1949 1948 let twofa_final_code = AuthorizationCode::from(code.0.clone()); 1950 1949 if state 1951 1950 .oauth_repo ··· 2273 2272 2274 2273 if let Err(e) = state 2275 2274 .user_repo 2276 - .save_webauthn_challenge(&user.did, "authentication", &state_json) 2275 + .save_webauthn_challenge( 2276 + &user.did, 2277 + WebauthnChallengeType::Authentication, 2278 + &state_json, 2279 + ) 2277 2280 .await 2278 2281 { 2279 2282 tracing::error!(error = %e, "Failed to save authentication state"); ··· 2461 2464 2462 2465 let auth_state_json = match state 2463 2466 .user_repo 2464 - .load_webauthn_challenge(passkey_owner_did, "authentication") 2467 + .load_webauthn_challenge(passkey_owner_did, WebauthnChallengeType::Authentication) 2465 2468 .await 2466 2469 { 2467 2470 Ok(Some(s)) => s, ··· 2540 2543 2541 2544 if let Err(e) = state 2542 2545 .user_repo 2543 - .delete_webauthn_challenge(passkey_owner_did, "authentication") 2546 + .delete_webauthn_challenge(passkey_owner_did, WebauthnChallengeType::Authentication) 2544 2547 .await 2545 2548 { 2546 2549 tracing::warn!(error = %e, "Failed to delete authentication state"); ··· 2550 2553 let cred_id_bytes = auth_result.cred_id().as_slice(); 2551 2554 match state 2552 2555 .user_repo 2553 - .update_passkey_counter(cred_id_bytes, auth_result.counter() as i32) 2556 + .update_passkey_counter( 2557 + cred_id_bytes, 2558 + i32::try_from(auth_result.counter()).unwrap_or(i32::MAX), 2559 + ) 2554 2560 .await 2555 2561 { 2556 2562 Ok(false) => { ··· 2608 2614 { 2609 2615 tracing::warn!(did = %did, error = %e, "Failed to enqueue 2FA notification"); 2610 2616 } 2611 - let channel_name = channel_display_name(user.preferred_comms_channel); 2617 + let channel_name = user.preferred_comms_channel.display_name(); 2612 2618 return Json(serde_json::json!({ 2613 2619 "needs_2fa": true, 2614 2620 "channel": channel_name ··· 2658 2664 } 2659 2665 2660 2666 let code = Code::generate(); 2661 - let passkey_final_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 2667 + let passkey_final_device_id = device_id.clone(); 2662 2668 let passkey_final_code = AuthorizationCode::from(code.0.clone()); 2663 2669 if state 2664 2670 .oauth_repo ··· 2844 2850 2845 2851 if let Err(e) = state 2846 2852 .user_repo 2847 - .save_webauthn_challenge(&did, "authentication", &state_json) 2853 + .save_webauthn_challenge(&did, WebauthnChallengeType::Authentication, &state_json) 2848 2854 .await 2849 2855 { 2850 2856 tracing::error!("Failed to save authentication state: {:?}", e); ··· 2951 2957 2952 2958 let auth_state_json = match state 2953 2959 .user_repo 2954 - .load_webauthn_challenge(&did, "authentication") 2960 + .load_webauthn_challenge(&did, WebauthnChallengeType::Authentication) 2955 2961 .await 2956 2962 { 2957 2963 Ok(Some(s)) => s, ··· 3025 3031 3026 3032 let _ = state 3027 3033 .user_repo 3028 - .delete_webauthn_challenge(&did, "authentication") 3034 + .delete_webauthn_challenge(&did, WebauthnChallengeType::Authentication) 3029 3035 .await; 3030 3036 3031 3037 match state 3032 3038 .user_repo 3033 - .update_passkey_counter(credential.id.as_ref(), auth_result.counter() as i32) 3039 + .update_passkey_counter( 3040 + credential.id.as_ref(), 3041 + i32::try_from(auth_result.counter()).unwrap_or(i32::MAX), 3042 + ) 3034 3043 .await 3035 3044 { 3036 3045 Ok(false) => { ··· 3101 3110 { 3102 3111 tracing::warn!(did = %did, error = %e, "Failed to enqueue 2FA notification"); 3103 3112 } 3104 - let channel_name = channel_display_name(user.preferred_comms_channel); 3113 + let channel_name = user.preferred_comms_channel.display_name(); 3105 3114 let redirect_url = format!( 3106 3115 "/app/oauth/2fa?request_uri={}&channel={}", 3107 3116 url_encode(&form.request_uri), ··· 3440 3449 3441 3450 let (device_id, new_cookie) = match existing_device { 3442 3451 Some(id) => { 3443 - let device_typed = DeviceIdType::from(id.clone()); 3444 - let _ = state 3445 - .oauth_repo 3446 - .upsert_account_device(did, &device_typed) 3447 - .await; 3452 + let _ = state.oauth_repo.upsert_account_device(did, &id).await; 3448 3453 (id, None) 3449 3454 } 3450 3455 None => { 3451 3456 let new_id = DeviceId::generate(); 3457 + let device_typed = DeviceIdType::new(new_id.0.clone()); 3452 3458 let device_data = DeviceData { 3453 3459 session_id: SessionId::generate(), 3454 3460 user_agent: extract_user_agent(&headers), 3455 3461 ip_address: extract_client_ip(&headers, None), 3456 3462 last_seen_at: Utc::now(), 3457 3463 }; 3458 - let device_typed = DeviceIdType::from(new_id.0.clone()); 3459 3464 3460 3465 if let Err(e) = state 3461 3466 .oauth_repo ··· 3489 3494 .into_response(); 3490 3495 } 3491 3496 3492 - (new_id.0.clone(), Some(make_device_cookie(&new_id.0))) 3497 + let cookie = make_device_cookie(&device_typed); 3498 + (device_typed, Some(cookie)) 3493 3499 } 3494 3500 }; 3495 3501
+1 -1
crates/tranquil-pds/src/oauth/endpoints/par.rs
··· 131 131 axum::http::StatusCode::CREATED, 132 132 Json(ParResponse { 133 133 request_uri: request_id.0, 134 - expires_in: PAR_EXPIRY_SECONDS as u64, 134 + expires_in: u64::try_from(PAR_EXPIRY_SECONDS).unwrap_or(600), 135 135 }), 136 136 )) 137 137 }
+58 -34
crates/tranquil-pds/src/oauth/endpoints/token/grants.rs
··· 1 1 use super::helpers::{create_access_token_with_delegation, verify_pkce}; 2 - use super::types::{TokenGrant, TokenResponse, ValidatedTokenRequest}; 2 + use super::types::{ 3 + RequestClientAuth, TokenGrant, TokenResponse, TokenType, ValidatedTokenRequest, 4 + }; 3 5 use crate::config::AuthConfig; 4 6 use crate::delegation::intersect_scopes; 5 7 use crate::oauth::{ ··· 12 14 use crate::state::AppState; 13 15 use crate::util::pds_hostname; 14 16 use axum::Json; 15 - use axum::http::HeaderMap; 17 + use axum::http::{HeaderMap, Method}; 16 18 use chrono::{Duration, Utc}; 17 19 use tranquil_db_traits::RefreshTokenLookup; 18 20 use tranquil_types::{AuthorizationCode, Did, RefreshToken as RefreshTokenType}; 19 21 20 - const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 300; 22 + const ACCESS_TOKEN_EXPIRY_SECONDS: u64 = 300; 21 23 const REFRESH_TOKEN_EXPIRY_DAYS_CONFIDENTIAL: i64 = 60; 22 24 const REFRESH_TOKEN_EXPIRY_DAYS_PUBLIC: i64 = 14; 23 25 ··· 29 31 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 30 32 tracing::info!( 31 33 has_dpop = dpop_proof.is_some(), 32 - client_id = ?request.client_auth.client_id, 34 + client_id = ?request.client_auth.client_id(), 33 35 "Authorization code grant requested" 34 36 ); 35 37 let (code, code_verifier, redirect_uri) = match request.grant { ··· 59 61 .require_authorized() 60 62 .map_err(|_| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; 61 63 62 - if let Some(request_client_id) = &request.client_auth.client_id 63 - && request_client_id != &authorized.client_id 64 + if let Some(request_client_id) = request.client_auth.client_id() 65 + && request_client_id != authorized.client_id 64 66 { 65 67 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 66 68 } 67 69 let did = authorized.did.to_string(); 68 70 let client_metadata_cache = ClientMetadataCache::new(3600); 69 71 let client_metadata = client_metadata_cache.get(&authorized.client_id).await?; 70 - let client_auth = if let (Some(assertion), Some(assertion_type)) = ( 71 - &request.client_auth.client_assertion, 72 - &request.client_auth.client_assertion_type, 73 - ) { 74 - if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 75 - return Err(OAuthError::InvalidClient( 76 - "Unsupported client_assertion_type".to_string(), 77 - )); 78 - } 79 - ClientAuth::PrivateKeyJwt { 80 - client_assertion: assertion.clone(), 72 + let client_auth = match &request.client_auth { 73 + RequestClientAuth::PrivateKeyJwt { 74 + assertion, 75 + assertion_type, 76 + .. 77 + } => { 78 + if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 79 + return Err(OAuthError::InvalidClient( 80 + "Unsupported client_assertion_type".to_string(), 81 + )); 82 + } 83 + ClientAuth::PrivateKeyJwt { 84 + client_assertion: assertion.clone(), 85 + } 81 86 } 82 - } else if let Some(secret) = &request.client_auth.client_secret { 83 - ClientAuth::SecretPost { 84 - client_secret: secret.clone(), 85 - } 86 - } else { 87 - ClientAuth::None 87 + RequestClientAuth::SecretPost { client_secret, .. } => ClientAuth::SecretPost { 88 + client_secret: client_secret.clone(), 89 + }, 90 + RequestClientAuth::None { .. } => ClientAuth::None, 88 91 }; 89 92 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 90 93 verify_pkce(&authorized.parameters.code_challenge, &code_verifier)?; ··· 100 103 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 101 104 let pds_hostname = pds_hostname(); 102 105 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 103 - let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 106 + let result = verifier.verify_proof(proof, Method::POST.as_str(), &token_endpoint, None)?; 104 107 if !state 105 108 .oauth_repo 106 109 .check_and_record_dpop_jti(&result.jti) ··· 220 223 let mut response_headers = HeaderMap::new(); 221 224 let config = AuthConfig::get(); 222 225 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 223 - response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap()); 226 + let nonce = verifier.generate_nonce(); 227 + let nonce_header = nonce.parse().map_err(|_| { 228 + OAuthError::ServerError("Failed to encode DPoP nonce as header value".to_string()) 229 + })?; 230 + response_headers.insert("DPoP-Nonce", nonce_header); 224 231 Ok(( 225 232 response_headers, 226 233 Json(TokenResponse { 227 234 access_token, 228 - token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 229 - expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 235 + token_type: match dpop_jkt { 236 + Some(_) => TokenType::DPoP, 237 + None => TokenType::Bearer, 238 + }, 239 + expires_in: ACCESS_TOKEN_EXPIRY_SECONDS, 230 240 refresh_token: Some(refresh_token.0), 231 241 scope: final_scope, 232 242 sub: Some(did), ··· 283 293 let mut response_headers = HeaderMap::new(); 284 294 let config = AuthConfig::get(); 285 295 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 286 - response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap()); 296 + let nonce = verifier.generate_nonce(); 297 + let nonce_header = nonce.parse().map_err(|_| { 298 + OAuthError::ServerError("Failed to encode DPoP nonce as header value".to_string()) 299 + })?; 300 + response_headers.insert("DPoP-Nonce", nonce_header); 287 301 return Ok(( 288 302 response_headers, 289 303 Json(TokenResponse { 290 304 access_token, 291 - token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 292 - expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 305 + token_type: match dpop_jkt { 306 + Some(_) => TokenType::DPoP, 307 + None => TokenType::Bearer, 308 + }, 309 + expires_in: ACCESS_TOKEN_EXPIRY_SECONDS, 293 310 refresh_token: token_data.current_refresh_token.map(|r| r.0), 294 311 scope: token_data.scope, 295 312 sub: Some(token_data.did.to_string()), ··· 333 350 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 334 351 let pds_hostname = pds_hostname(); 335 352 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 336 - let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 353 + let result = verifier.verify_proof(proof, Method::POST.as_str(), &token_endpoint, None)?; 337 354 if !state 338 355 .oauth_repo 339 356 .check_and_record_dpop_jti(&result.jti) ··· 387 404 let mut response_headers = HeaderMap::new(); 388 405 let config = AuthConfig::get(); 389 406 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 390 - response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap()); 407 + let nonce = verifier.generate_nonce(); 408 + let nonce_header = nonce.parse().map_err(|_| { 409 + OAuthError::ServerError("Failed to encode DPoP nonce as header value".to_string()) 410 + })?; 411 + response_headers.insert("DPoP-Nonce", nonce_header); 391 412 Ok(( 392 413 response_headers, 393 414 Json(TokenResponse { 394 415 access_token, 395 - token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 396 - expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 416 + token_type: match dpop_jkt { 417 + Some(_) => TokenType::DPoP, 418 + None => TokenType::Bearer, 419 + }, 420 + expires_in: ACCESS_TOKEN_EXPIRY_SECONDS, 397 421 refresh_token: Some(new_refresh_token.0), 398 422 scope: token_data.scope, 399 423 sub: Some(token_data.did.to_string()),
+8 -2
crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs
··· 77 77 "alg": "HS256", 78 78 "typ": "at+jwt" 79 79 }); 80 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 81 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 80 + let header_b64 = 81 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).map_err(|_| { 82 + OAuthError::ServerError("token header serialization failed".to_string()) 83 + })?); 84 + let payload_b64 = 85 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).map_err(|_| { 86 + OAuthError::ServerError("token payload serialization failed".to_string()) 87 + })?); 82 88 let signing_input = format!("{}.{}", header_b64, payload_b64); 83 89 let config = AuthConfig::get(); 84 90 type HmacSha256 = hmac::Hmac<Sha256>;
+2 -2
crates/tranquil-pds/src/oauth/endpoints/token/mod.rs
··· 15 15 IntrospectRequest, IntrospectResponse, RevokeRequest, introspect_token, revoke_token, 16 16 }; 17 17 pub use types::{ 18 - ClientAuthParams, GrantType, TokenGrant, TokenRequest, TokenResponse, ValidatedTokenRequest, 18 + GrantType, RequestClientAuth, TokenGrant, TokenRequest, TokenResponse, ValidatedTokenRequest, 19 19 }; 20 20 21 21 pub async fn token_endpoint( ··· 41 41 )); 42 42 }; 43 43 let dpop_proof = headers 44 - .get("DPoP") 44 + .get(crate::util::HEADER_DPOP) 45 45 .and_then(|v| v.to_str().ok()) 46 46 .map(|s| s.to_string()); 47 47 let validated = request.validate()?;
+70 -43
crates/tranquil-pds/src/oauth/endpoints/token/types.rs
··· 5 5 pub enum GrantType { 6 6 AuthorizationCode, 7 7 RefreshToken, 8 - Unsupported(String), 9 8 } 10 9 11 10 impl GrantType { ··· 13 12 match self { 14 13 Self::AuthorizationCode => "authorization_code", 15 14 Self::RefreshToken => "refresh_token", 16 - Self::Unsupported(s) => s, 17 15 } 18 16 } 19 17 } 20 18 21 - impl std::str::FromStr for GrantType { 22 - type Err = std::convert::Infallible; 19 + #[derive(Debug, Clone)] 20 + pub struct UnsupportedGrantType(pub String); 23 21 24 - fn from_str(s: &str) -> Result<Self, Self::Err> { 25 - Ok(match s { 26 - "authorization_code" => Self::AuthorizationCode, 27 - "refresh_token" => Self::RefreshToken, 28 - other => Self::Unsupported(other.to_string()), 29 - }) 22 + impl std::fmt::Display for UnsupportedGrantType { 23 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 24 + write!(f, "unsupported grant type: {}", self.0) 30 25 } 31 26 } 32 27 33 - impl<'de> Deserialize<'de> for GrantType { 34 - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> 35 - where 36 - D: serde::Deserializer<'de>, 37 - { 38 - let s = String::deserialize(deserializer)?; 39 - Ok(s.parse().unwrap()) 40 - } 41 - } 28 + impl std::error::Error for UnsupportedGrantType {} 29 + 30 + impl std::str::FromStr for GrantType { 31 + type Err = UnsupportedGrantType; 42 32 43 - impl Serialize for GrantType { 44 - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 45 - where 46 - S: serde::Serializer, 47 - { 48 - serializer.serialize_str(self.as_str()) 33 + fn from_str(s: &str) -> Result<Self, Self::Err> { 34 + match s { 35 + "authorization_code" => Ok(Self::AuthorizationCode), 36 + "refresh_token" => Ok(Self::RefreshToken), 37 + other => Err(UnsupportedGrantType(other.to_string())), 38 + } 49 39 } 50 40 } 51 41 52 42 #[derive(Debug, Deserialize)] 53 43 pub struct TokenRequest { 54 - pub grant_type: GrantType, 44 + pub grant_type: String, 55 45 #[serde(default)] 56 46 pub code: Option<String>, 57 47 #[serde(default)] ··· 82 72 }, 83 73 } 84 74 85 - #[derive(Debug, Clone, Default)] 86 - pub struct ClientAuthParams { 87 - pub client_id: Option<String>, 88 - pub client_secret: Option<String>, 89 - pub client_assertion: Option<String>, 90 - pub client_assertion_type: Option<String>, 75 + #[derive(Debug, Clone)] 76 + pub enum RequestClientAuth { 77 + None { 78 + client_id: Option<String>, 79 + }, 80 + SecretPost { 81 + client_id: Option<String>, 82 + client_secret: String, 83 + }, 84 + PrivateKeyJwt { 85 + client_id: Option<String>, 86 + assertion: String, 87 + assertion_type: String, 88 + }, 89 + } 90 + 91 + impl RequestClientAuth { 92 + pub fn client_id(&self) -> Option<&str> { 93 + match self { 94 + Self::None { client_id } 95 + | Self::SecretPost { client_id, .. } 96 + | Self::PrivateKeyJwt { client_id, .. } => client_id.as_deref(), 97 + } 98 + } 91 99 } 92 100 93 101 #[derive(Debug, Clone)] 94 102 pub struct ValidatedTokenRequest { 95 103 pub grant: TokenGrant, 96 - pub client_auth: ClientAuthParams, 104 + pub client_auth: RequestClientAuth, 97 105 } 98 106 99 107 impl TokenRequest { 100 108 pub fn validate(self) -> Result<ValidatedTokenRequest, OAuthError> { 101 - let grant = match self.grant_type { 109 + let grant_type: GrantType = self 110 + .grant_type 111 + .parse() 112 + .map_err(|e: UnsupportedGrantType| OAuthError::UnsupportedGrantType(e.0))?; 113 + let grant = match grant_type { 102 114 GrantType::AuthorizationCode => { 103 115 let code = self.code.ok_or_else(|| { 104 116 OAuthError::InvalidRequest( ··· 124 136 })?; 125 137 TokenGrant::RefreshToken { refresh_token } 126 138 } 127 - GrantType::Unsupported(grant_type) => { 128 - return Err(OAuthError::UnsupportedGrantType(grant_type)); 129 - } 130 139 }; 131 140 132 - let client_auth = ClientAuthParams { 133 - client_id: self.client_id, 134 - client_secret: self.client_secret, 135 - client_assertion: self.client_assertion, 136 - client_assertion_type: self.client_assertion_type, 141 + let client_auth = match (self.client_assertion, self.client_assertion_type) { 142 + (Some(assertion), Some(assertion_type)) => RequestClientAuth::PrivateKeyJwt { 143 + client_id: self.client_id, 144 + assertion, 145 + assertion_type, 146 + }, 147 + _ => match self.client_secret { 148 + Some(secret) => RequestClientAuth::SecretPost { 149 + client_id: self.client_id, 150 + client_secret: secret, 151 + }, 152 + None => RequestClientAuth::None { 153 + client_id: self.client_id, 154 + }, 155 + }, 137 156 }; 138 157 139 158 Ok(ValidatedTokenRequest { grant, client_auth }) 140 159 } 141 160 } 142 161 162 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] 163 + #[serde(rename_all = "PascalCase")] 164 + pub enum TokenType { 165 + Bearer, 166 + #[serde(rename = "DPoP")] 167 + DPoP, 168 + } 169 + 143 170 #[derive(Debug, Serialize)] 144 171 pub struct TokenResponse { 145 172 pub access_token: String, 146 - pub token_type: String, 173 + pub token_type: TokenType, 147 174 pub expires_in: u64, 148 175 #[serde(skip_serializing_if = "Option::is_none")] 149 176 pub refresh_token: Option<String>,
+22 -10
crates/tranquil-pds/src/oauth/verify.rs
··· 12 12 use tranquil_db_traits::{OAuthRepository, UserRepository}; 13 13 use tranquil_types::{ClientId, TokenId}; 14 14 15 + use crate::auth::AuthSource; 15 16 use crate::types::Did; 16 17 17 18 use super::scopes::ScopePermissions; ··· 217 218 pub did: Did, 218 219 pub client_id: Option<ClientId>, 219 220 pub scope: Option<String>, 220 - pub is_oauth: bool, 221 + pub auth_source: AuthSource, 221 222 pub permissions: ScopePermissions, 222 223 } 223 224 ··· 240 241 ) 241 242 .into_response(); 242 243 if let Some(nonce) = self.dpop_nonce { 243 - response 244 - .headers_mut() 245 - .insert("DPoP-Nonce", nonce.parse().unwrap()); 244 + match nonce.parse() { 245 + Ok(val) => { 246 + response.headers_mut().insert("DPoP-Nonce", val); 247 + } 248 + Err(e) => tracing::warn!(error = %e, "DPoP-Nonce header value failed to encode"), 249 + } 246 250 } 247 251 if let Some(www_auth) = self.www_authenticate { 248 - response 249 - .headers_mut() 250 - .insert("WWW-Authenticate", www_auth.parse().unwrap()); 252 + match www_auth.parse() { 253 + Ok(val) => { 254 + response.headers_mut().insert("WWW-Authenticate", val); 255 + } 256 + Err(e) => { 257 + tracing::warn!(error = %e, "WWW-Authenticate header value failed to encode") 258 + } 259 + } 251 260 } 252 261 response 253 262 } ··· 289 298 www_authenticate: None, 290 299 }); 291 300 }; 292 - let dpop_proof = parts.headers.get("DPoP").and_then(|v| v.to_str().ok()); 301 + let dpop_proof = parts 302 + .headers 303 + .get(crate::util::HEADER_DPOP) 304 + .and_then(|v| v.to_str().ok()); 293 305 if let Ok(result) = try_legacy_auth(state.user_repo.as_ref(), token).await { 294 306 return Ok(OAuthUser { 295 307 did: result.did, 296 308 client_id: None, 297 309 scope: None, 298 - is_oauth: false, 310 + auth_source: AuthSource::Session, 299 311 permissions: ScopePermissions::default(), 300 312 }); 301 313 } ··· 316 328 did: result.did, 317 329 client_id: Some(result.client_id), 318 330 scope: result.scope, 319 - is_oauth: true, 331 + auth_source: AuthSource::OAuth, 320 332 permissions, 321 333 }) 322 334 }
+71 -49
crates/tranquil-pds/src/plc/mod.rs
··· 31 31 CircuitBreakerOpen, 32 32 } 33 33 34 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] 35 + pub enum PlcOpType { 36 + #[serde(rename = "plc_operation")] 37 + Operation, 38 + #[serde(rename = "plc_tombstone")] 39 + Tombstone, 40 + } 41 + 42 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] 43 + pub enum ServiceType { 44 + #[serde(rename = "AtprotoPersonalDataServer")] 45 + Pds, 46 + #[serde(rename = "AtprotoAppView")] 47 + AppView, 48 + } 49 + 50 + impl ServiceType { 51 + pub fn as_str(self) -> &'static str { 52 + match self { 53 + Self::Pds => "AtprotoPersonalDataServer", 54 + Self::AppView => "AtprotoAppView", 55 + } 56 + } 57 + 58 + pub fn is_pds(self) -> bool { 59 + matches!(self, Self::Pds) 60 + } 61 + } 62 + 63 + pub const SECP256K1_MULTICODEC_PREFIX: [u8; 2] = [0xe7, 0x01]; 64 + 34 65 #[derive(Debug, Clone, Serialize, Deserialize)] 35 66 pub struct PlcOperation { 36 67 #[serde(rename = "type")] 37 - pub op_type: String, 68 + pub op_type: PlcOpType, 38 69 #[serde(rename = "rotationKeys")] 39 70 pub rotation_keys: Vec<String>, 40 71 #[serde(rename = "verificationMethods")] ··· 50 81 #[derive(Debug, Clone, Serialize, Deserialize)] 51 82 pub struct PlcService { 52 83 #[serde(rename = "type")] 53 - pub service_type: String, 84 + pub service_type: ServiceType, 54 85 pub endpoint: String, 55 86 } 56 87 57 88 #[derive(Debug, Clone, Serialize, Deserialize)] 58 89 pub struct PlcTombstone { 59 90 #[serde(rename = "type")] 60 - pub op_type: String, 91 + pub op_type: PlcOpType, 61 92 pub prev: String, 62 93 #[serde(skip_serializing_if = "Option::is_none")] 63 94 pub sig: Option<String>, ··· 74 105 pub fn is_tombstone(&self) -> bool { 75 106 match self { 76 107 PlcOpOrTombstone::Tombstone(_) => true, 77 - PlcOpOrTombstone::Operation(op) => op.op_type == "plc_tombstone", 108 + PlcOpOrTombstone::Operation(op) => op.op_type == PlcOpType::Tombstone, 78 109 } 79 110 } 80 111 } ··· 124 155 } 125 156 126 157 pub async fn get_document(&self, did: &str) -> Result<Value, PlcError> { 127 - let cache_key = format!("plc:doc:{}", did); 158 + let cache_key = crate::cache_keys::plc_doc_key(did); 128 159 if let Some(ref cache) = self.cache 129 160 && let Some(cached) = cache.get(&cache_key).await 130 161 && let Ok(value) = serde_json::from_str(&cached) ··· 163 194 } 164 195 165 196 pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> { 166 - let cache_key = format!("plc:data:{}", did); 197 + let cache_key = crate::cache_keys::plc_data_key(did); 167 198 if let Some(ref cache) = self.cache 168 199 && let Some(cached) = cache.get(&cache_key).await 169 200 && let Ok(value) = serde_json::from_str(&cached) ··· 313 344 } 314 345 }; 315 346 let new_op = PlcOperation { 316 - op_type: "plc_operation".to_string(), 347 + op_type: PlcOpType::Operation, 317 348 rotation_keys: rotation_keys.unwrap_or(base_rotation_keys), 318 349 verification_methods: verification_methods.unwrap_or(base_verification_methods), 319 350 also_known_as: also_known_as.unwrap_or(base_also_known_as), ··· 328 359 let verifying_key = signing_key.verifying_key(); 329 360 let point = verifying_key.to_encoded_point(true); 330 361 let compressed_bytes = point.as_bytes(); 331 - let mut prefixed = vec![0xe7, 0x01]; 362 + let mut prefixed = Vec::from(SECP256K1_MULTICODEC_PREFIX); 332 363 prefixed.extend_from_slice(compressed_bytes); 333 364 let encoded = multibase::encode(multibase::Base::Base58Btc, &prefixed); 334 365 format!("did:key:{}", encoded) ··· 352 383 services.insert( 353 384 "atproto_pds".to_string(), 354 385 PlcService { 355 - service_type: "AtprotoPersonalDataServer".to_string(), 386 + service_type: ServiceType::Pds, 356 387 endpoint: pds_endpoint.to_string(), 357 388 }, 358 389 ); 359 390 let genesis_op = PlcOperation { 360 - op_type: "plc_operation".to_string(), 391 + op_type: PlcOpType::Operation, 361 392 rotation_keys: vec![rotation_key.to_string()], 362 393 verification_methods, 363 394 also_known_as: vec![format!("at://{}", handle)], ··· 386 417 Ok(format!("did:plc:{}", truncated)) 387 418 } 388 419 389 - pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> { 420 + pub fn validate_plc_operation(op: &Value) -> Result<PlcOpType, PlcError> { 390 421 let obj = op 391 422 .as_object() 392 423 .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; 393 - let op_type = obj 424 + let op_type_str = obj 394 425 .get("type") 395 - .and_then(|v| v.as_str()) 396 426 .ok_or_else(|| PlcError::InvalidResponse("Missing type field".to_string()))?; 397 - if op_type != "plc_operation" && op_type != "plc_tombstone" { 398 - return Err(PlcError::InvalidResponse(format!( 427 + let op_type: PlcOpType = serde_json::from_value(op_type_str.clone()).map_err(|_| { 428 + PlcError::InvalidResponse(format!( 399 429 "Invalid type: {}", 400 - op_type 401 - ))); 402 - } 403 - if op_type == "plc_operation" { 404 - if obj.get("rotationKeys").is_none() { 405 - return Err(PlcError::InvalidResponse( 406 - "Missing rotationKeys".to_string(), 407 - )); 408 - } 409 - if obj.get("verificationMethods").is_none() { 410 - return Err(PlcError::InvalidResponse( 411 - "Missing verificationMethods".to_string(), 412 - )); 413 - } 414 - if obj.get("alsoKnownAs").is_none() { 415 - return Err(PlcError::InvalidResponse("Missing alsoKnownAs".to_string())); 416 - } 417 - if obj.get("services").is_none() { 418 - return Err(PlcError::InvalidResponse("Missing services".to_string())); 430 + op_type_str.as_str().unwrap_or("<non-string>") 431 + )) 432 + })?; 433 + match op_type { 434 + PlcOpType::Operation => { 435 + let required_fields = [ 436 + "rotationKeys", 437 + "verificationMethods", 438 + "alsoKnownAs", 439 + "services", 440 + ]; 441 + required_fields.iter().try_for_each(|field| { 442 + obj.get(*field) 443 + .map(|_| ()) 444 + .ok_or_else(|| PlcError::InvalidResponse(format!("Missing {}", field))) 445 + })?; 419 446 } 447 + PlcOpType::Tombstone => {} 420 448 } 421 449 if obj.get("sig").is_none() { 422 450 return Err(PlcError::InvalidResponse("Missing sig".to_string())); 423 451 } 424 - Ok(()) 452 + Ok(op_type) 425 453 } 426 454 427 455 pub struct PlcValidationContext { ··· 435 463 op: &Value, 436 464 ctx: &PlcValidationContext, 437 465 ) -> Result<(), PlcError> { 438 - validate_plc_operation(op)?; 466 + let op_type = validate_plc_operation(op)?; 467 + if op_type != PlcOpType::Operation { 468 + return Ok(()); 469 + } 439 470 let obj = op 440 471 .as_object() 441 472 .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; 442 - let op_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or(""); 443 - if op_type != "plc_operation" { 444 - return Ok(()); 445 - } 446 473 let rotation_keys = obj 447 474 .get("rotationKeys") 448 475 .and_then(|v| v.as_array()) ··· 489 516 .get("type") 490 517 .and_then(|v| v.as_str()) 491 518 .unwrap_or(""); 492 - if service_type != "AtprotoPersonalDataServer" { 519 + if service_type != ServiceType::Pds.as_str() { 493 520 return Err(PlcError::InvalidResponse( 494 521 "Incorrect type on atproto_pds service".to_string(), 495 522 )); ··· 551 578 "Invalid did:key data".to_string(), 552 579 )); 553 580 } 554 - let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 { 555 - (0xe701u16, &decoded[2..]) 581 + let key_bytes = if decoded.starts_with(&SECP256K1_MULTICODEC_PREFIX) { 582 + &decoded[SECP256K1_MULTICODEC_PREFIX.len()..] 556 583 } else { 557 584 return Err(PlcError::InvalidResponse( 558 - "Unsupported key type in did:key".to_string(), 585 + "Unsupported key type in did:key (expected secp256k1)".to_string(), 559 586 )); 560 587 }; 561 - if codec != 0xe701 { 562 - return Err(PlcError::InvalidResponse( 563 - "Only secp256k1 keys are supported".to_string(), 564 - )); 565 - } 566 588 let verifying_key = VerifyingKey::from_sec1_bytes(key_bytes) 567 589 .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?; 568 590 Ok(verifying_key.verify(message, signature).is_ok())
+29 -29
crates/tranquil-pds/src/rate_limit/mod.rs
··· 46 46 pub fn new() -> Self { 47 47 Self { 48 48 login: Arc::new(RateLimiter::keyed(Quota::per_minute( 49 - NonZeroU32::new(10).unwrap(), 49 + const { NonZeroU32::new(10).unwrap() }, 50 50 ))), 51 51 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute( 52 - NonZeroU32::new(300).unwrap(), 52 + const { NonZeroU32::new(300).unwrap() }, 53 53 ))), 54 54 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute( 55 - NonZeroU32::new(10).unwrap(), 55 + const { NonZeroU32::new(10).unwrap() }, 56 56 ))), 57 57 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour( 58 - NonZeroU32::new(5).unwrap(), 58 + const { NonZeroU32::new(5).unwrap() }, 59 59 ))), 60 60 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour( 61 - NonZeroU32::new(10).unwrap(), 61 + const { NonZeroU32::new(10).unwrap() }, 62 62 ))), 63 63 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute( 64 - NonZeroU32::new(60).unwrap(), 64 + const { NonZeroU32::new(60).unwrap() }, 65 65 ))), 66 66 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 67 - NonZeroU32::new(10).unwrap(), 67 + const { NonZeroU32::new(10).unwrap() }, 68 68 ))), 69 69 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute( 70 - NonZeroU32::new(30).unwrap(), 70 + const { NonZeroU32::new(30).unwrap() }, 71 71 ))), 72 72 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute( 73 - NonZeroU32::new(30).unwrap(), 73 + const { NonZeroU32::new(30).unwrap() }, 74 74 ))), 75 75 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 76 - NonZeroU32::new(10).unwrap(), 76 + const { NonZeroU32::new(10).unwrap() }, 77 77 ))), 78 78 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour( 79 - NonZeroU32::new(5).unwrap(), 79 + const { NonZeroU32::new(5).unwrap() }, 80 80 ))), 81 81 totp_verify: Arc::new(RateLimiter::keyed( 82 82 Quota::with_period(std::time::Duration::from_secs(60)) 83 83 .unwrap() 84 - .allow_burst(NonZeroU32::new(5).unwrap()), 84 + .allow_burst(const { NonZeroU32::new(5).unwrap() }), 85 85 )), 86 86 handle_update: Arc::new(RateLimiter::keyed( 87 87 Quota::with_period(std::time::Duration::from_secs(30)) 88 88 .unwrap() 89 - .allow_burst(NonZeroU32::new(10).unwrap()), 89 + .allow_burst(const { NonZeroU32::new(10).unwrap() }), 90 90 )), 91 91 handle_update_daily: Arc::new(RateLimiter::keyed( 92 92 Quota::with_period(std::time::Duration::from_secs(1728)) 93 93 .unwrap() 94 - .allow_burst(NonZeroU32::new(50).unwrap()), 94 + .allow_burst(const { NonZeroU32::new(50).unwrap() }), 95 95 )), 96 96 verification_check: Arc::new(RateLimiter::keyed(Quota::per_minute( 97 - NonZeroU32::new(60).unwrap(), 97 + const { NonZeroU32::new(60).unwrap() }, 98 98 ))), 99 99 sso_initiate: Arc::new(RateLimiter::keyed(Quota::per_minute( 100 - NonZeroU32::new(10).unwrap(), 100 + const { NonZeroU32::new(10).unwrap() }, 101 101 ))), 102 102 sso_callback: Arc::new(RateLimiter::keyed(Quota::per_minute( 103 - NonZeroU32::new(30).unwrap(), 103 + const { NonZeroU32::new(30).unwrap() }, 104 104 ))), 105 105 sso_unlink: Arc::new(RateLimiter::keyed(Quota::per_minute( 106 - NonZeroU32::new(10).unwrap(), 106 + const { NonZeroU32::new(10).unwrap() }, 107 107 ))), 108 108 oauth_register_complete: Arc::new(RateLimiter::keyed( 109 109 Quota::with_period(std::time::Duration::from_secs(60)) 110 110 .unwrap() 111 - .allow_burst(NonZeroU32::new(5).unwrap()), 111 + .allow_burst(const { NonZeroU32::new(5).unwrap() }), 112 112 )), 113 113 handle_verification: Arc::new(RateLimiter::keyed(Quota::per_minute( 114 - NonZeroU32::new(10).unwrap(), 114 + const { NonZeroU32::new(10).unwrap() }, 115 115 ))), 116 116 } 117 117 } 118 118 119 119 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 120 120 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute( 121 - NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 121 + NonZeroU32::new(per_minute).unwrap_or(const { NonZeroU32::new(10).unwrap() }), 122 122 ))); 123 123 self 124 124 } 125 125 126 126 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 127 127 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute( 128 - NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()), 128 + NonZeroU32::new(per_minute).unwrap_or(const { NonZeroU32::new(30).unwrap() }), 129 129 ))); 130 130 self 131 131 } 132 132 133 133 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 134 134 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute( 135 - NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 135 + NonZeroU32::new(per_minute).unwrap_or(const { NonZeroU32::new(10).unwrap() }), 136 136 ))); 137 137 self 138 138 } 139 139 140 140 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 141 141 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour( 142 - NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 142 + NonZeroU32::new(per_hour).unwrap_or(const { NonZeroU32::new(5).unwrap() }), 143 143 ))); 144 144 self 145 145 } 146 146 147 147 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 148 148 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour( 149 - NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()), 149 + NonZeroU32::new(per_hour).unwrap_or(const { NonZeroU32::new(10).unwrap() }), 150 150 ))); 151 151 self 152 152 } 153 153 154 154 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 155 155 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour( 156 - NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 156 + NonZeroU32::new(per_hour).unwrap_or(const { NonZeroU32::new(5).unwrap() }), 157 157 ))); 158 158 self 159 159 } 160 160 161 161 pub fn with_sso_initiate_limit(mut self, per_minute: u32) -> Self { 162 162 self.sso_initiate = Arc::new(RateLimiter::keyed(Quota::per_minute( 163 - NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 163 + NonZeroU32::new(per_minute).unwrap_or(const { NonZeroU32::new(10).unwrap() }), 164 164 ))); 165 165 self 166 166 } ··· 178 178 179 179 #[test] 180 180 fn test_rate_limiter_exhaustion() { 181 - let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 181 + let limiter = RateLimiter::keyed(Quota::per_minute(const { NonZeroU32::new(2).unwrap() })); 182 182 let key = "test_ip".to_string(); 183 183 184 184 assert!(limiter.check_key(&key).is_ok()); ··· 188 188 189 189 #[test] 190 190 fn test_different_keys_have_separate_limits() { 191 - let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 191 + let limiter = RateLimiter::keyed(Quota::per_minute(const { NonZeroU32::new(1).unwrap() })); 192 192 193 193 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 194 194 assert!(limiter.check_key(&"ip1".to_string()).is_err());
+121 -55
crates/tranquil-pds/src/scheduled.rs
··· 1 + use anyhow::Context; 1 2 use cid::Cid; 2 3 use ipld_core::ipld::Ipld; 3 4 use jacquard_repo::commit::Commit; ··· 18 19 use crate::storage::{BackupStorage, BlobStorage, backup_interval_secs, backup_retention_count}; 19 20 use crate::sync::car::encode_car_header; 20 21 22 + #[derive(Debug)] 23 + enum GenesisBackfillError { 24 + MissingCommitCid, 25 + InvalidCid, 26 + BlockFetchFailed, 27 + BlockNotFound, 28 + CommitParseFailed, 29 + UpdateFailed, 30 + } 31 + 32 + impl std::fmt::Display for GenesisBackfillError { 33 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 34 + match self { 35 + Self::MissingCommitCid => f.write_str("missing commit_cid"), 36 + Self::InvalidCid => f.write_str("invalid CID"), 37 + Self::BlockFetchFailed => f.write_str("failed to fetch block"), 38 + Self::BlockNotFound => f.write_str("block not found"), 39 + Self::CommitParseFailed => f.write_str("failed to parse commit"), 40 + Self::UpdateFailed => f.write_str("failed to update"), 41 + } 42 + } 43 + } 44 + 21 45 async fn process_genesis_commit( 22 46 repo_repo: &dyn RepoRepository, 23 47 block_store: &PostgresBlockStore, 24 48 row: BrokenGenesisCommit, 25 - ) -> Result<(Did, SequenceNumber), (SequenceNumber, &'static str)> { 26 - let commit_cid_str = row.commit_cid.ok_or((row.seq, "missing commit_cid"))?; 27 - let commit_cid = Cid::from_str(&commit_cid_str).map_err(|_| (row.seq, "invalid CID"))?; 49 + ) -> Result<(Did, SequenceNumber), (SequenceNumber, GenesisBackfillError)> { 50 + let commit_cid_str = row 51 + .commit_cid 52 + .ok_or((row.seq, GenesisBackfillError::MissingCommitCid))?; 53 + let commit_cid = 54 + Cid::from_str(&commit_cid_str).map_err(|_| (row.seq, GenesisBackfillError::InvalidCid))?; 28 55 let block = block_store 29 56 .get(&commit_cid) 30 57 .await 31 - .map_err(|_| (row.seq, "failed to fetch block"))? 32 - .ok_or((row.seq, "block not found"))?; 33 - let commit = Commit::from_cbor(&block).map_err(|_| (row.seq, "failed to parse commit"))?; 58 + .map_err(|_| (row.seq, GenesisBackfillError::BlockFetchFailed))? 59 + .ok_or((row.seq, GenesisBackfillError::BlockNotFound))?; 60 + let commit = Commit::from_cbor(&block) 61 + .map_err(|_| (row.seq, GenesisBackfillError::CommitParseFailed))?; 34 62 let blocks_cids = vec![commit.data.to_string(), commit_cid.to_string()]; 35 63 repo_repo 36 64 .update_seq_blocks_cids(row.seq, &blocks_cids) 37 65 .await 38 - .map_err(|_| (row.seq, "failed to update"))?; 66 + .map_err(|_| (row.seq, GenesisBackfillError::UpdateFailed))?; 39 67 Ok((row.did, row.seq)) 40 68 } 41 69 ··· 79 107 Err((seq, reason)) => { 80 108 warn!( 81 109 seq = seq.as_i64(), 82 - reason = reason, 110 + reason = %reason, 83 111 "Failed to process genesis commit" 84 112 ); 85 113 (s, f + 1) ··· 99 127 repo_root_cid: String, 100 128 ) -> Result<uuid::Uuid, uuid::Uuid> { 101 129 let cid = Cid::from_str(&repo_root_cid).map_err(|_| user_id)?; 102 - let block = block_store.get(&cid).await.ok().flatten().ok_or(user_id)?; 130 + let block = match block_store.get(&cid).await { 131 + Ok(Some(b)) => b, 132 + Ok(None) => { 133 + tracing::warn!(user_id = %user_id, cid = %cid, "block not found for repo rev backfill"); 134 + return Err(user_id); 135 + } 136 + Err(e) => { 137 + tracing::warn!(user_id = %user_id, cid = %cid, error = %e, "block store error during repo rev backfill"); 138 + return Err(user_id); 139 + } 140 + }; 103 141 let commit = Commit::from_cbor(&block).map_err(|_| user_id)?; 104 142 let rev = commit.rev().to_string(); 105 143 repo_repo ··· 235 273 pub async fn collect_current_repo_blocks( 236 274 block_store: &PostgresBlockStore, 237 275 head_cid: &Cid, 238 - ) -> Result<Vec<Vec<u8>>, String> { 276 + ) -> anyhow::Result<Vec<Vec<u8>>> { 239 277 let mut block_cids: Vec<Vec<u8>> = Vec::new(); 240 278 let mut to_visit = vec![*head_cid]; 241 279 let mut visited = std::collections::HashSet::new(); ··· 250 288 let block = match block_store.get(&cid).await { 251 289 Ok(Some(b)) => b, 252 290 Ok(None) => continue, 253 - Err(e) => return Err(format!("Failed to get block {}: {:?}", cid, e)), 291 + Err(e) => anyhow::bail!("Failed to get block {}: {:?}", cid, e), 254 292 }; 255 293 256 294 if let Ok(commit) = Commit::from_cbor(&block) { ··· 308 346 Some( 309 347 blob_refs 310 348 .into_iter() 311 - .map(|blob_ref| { 349 + .filter_map(|blob_ref| { 312 350 let record_uri = AtUri::from_parts( 313 351 did.as_str(), 314 352 record.collection.as_str(), 315 353 record.rkey.as_str(), 316 354 ); 317 - (record_uri, unsafe { CidLink::new_unchecked(blob_ref.cid) }) 355 + match CidLink::new(&blob_ref.cid) { 356 + Ok(cid_link) => Some((record_uri, cid_link)), 357 + Err(_) => { 358 + tracing::warn!(cid = %blob_ref.cid, "skipping unparseable blob CID in record blob backfill"); 359 + None 360 + } 361 + } 318 362 }) 319 363 .collect::<Vec<_>>(), 320 364 ) ··· 462 506 user_repo: &dyn UserRepository, 463 507 blob_repo: &dyn BlobRepository, 464 508 blob_store: &dyn BlobStorage, 465 - ) -> Result<(), String> { 509 + ) -> anyhow::Result<()> { 466 510 let accounts_to_delete = user_repo 467 511 .get_accounts_scheduled_for_deletion(100) 468 512 .await 469 - .map_err(|e| format!("DB error fetching accounts to delete: {:?}", e))?; 513 + .context("DB error fetching accounts to delete")?; 470 514 471 515 if accounts_to_delete.is_empty() { 472 516 debug!("No accounts scheduled for deletion"); ··· 501 545 blob_store: &dyn BlobStorage, 502 546 user_id: uuid::Uuid, 503 547 did: &Did, 504 - ) -> Result<(), String> { 548 + ) -> anyhow::Result<()> { 505 549 let blob_storage_keys = blob_repo 506 550 .get_blob_storage_keys_by_user(user_id) 507 551 .await 508 - .map_err(|e| format!("DB error fetching blob keys: {:?}", e))?; 552 + .context("DB error fetching blob keys")?; 509 553 510 554 futures::future::join_all(blob_storage_keys.iter().map(|storage_key| async move { 511 555 (storage_key, blob_store.delete(storage_key).await) ··· 520 564 let _account_seq = user_repo 521 565 .delete_account_with_firehose(user_id, did) 522 566 .await 523 - .map_err(|e| format!("Failed to delete account: {:?}", e))?; 567 + .context("Failed to delete account")?; 524 568 525 569 info!( 526 570 did = %did, ··· 570 614 } 571 615 572 616 struct BackupResult { 573 - did: String, 617 + did: Did, 574 618 repo_rev: String, 575 619 size_bytes: i64, 576 620 block_count: i32, ··· 579 623 580 624 enum BackupOutcome { 581 625 Success(BackupResult), 582 - Skipped(String, &'static str), 583 - Failed(String, String), 626 + Skipped(Did, &'static str), 627 + Failed(Did, String), 584 628 } 585 629 586 630 #[allow(clippy::too_many_arguments)] ··· 590 634 block_store: &PostgresBlockStore, 591 635 backup_storage: &dyn BackupStorage, 592 636 user_id: uuid::Uuid, 593 - did: String, 637 + did: Did, 594 638 repo_root_cid: String, 595 639 repo_rev: Option<String>, 596 640 ) -> BackupOutcome { ··· 610 654 }; 611 655 612 656 let block_count = count_car_blocks(&car_bytes); 613 - let size_bytes = car_bytes.len() as i64; 657 + let size_bytes = i64::try_from(car_bytes.len()).unwrap_or(i64::MAX); 614 658 615 - let storage_key = match backup_storage.put_backup(&did, &repo_rev, &car_bytes).await { 659 + let storage_key = match backup_storage 660 + .put_backup(did.as_str(), &repo_rev, &car_bytes) 661 + .await 662 + { 616 663 Ok(key) => key, 617 664 Err(e) => return BackupOutcome::Failed(did, format!("S3 upload: {}", e)), 618 665 }; ··· 653 700 backup_repo: &dyn BackupRepository, 654 701 block_store: &PostgresBlockStore, 655 702 backup_storage: &dyn BackupStorage, 656 - ) -> Result<(), String> { 657 - let interval_secs = backup_interval_secs() as i64; 703 + ) -> anyhow::Result<()> { 704 + let interval_secs = i64::try_from(backup_interval_secs()).unwrap_or(i64::MAX); 658 705 let retention = backup_retention_count(); 659 706 660 707 let users_needing_backup = backup_repo 661 708 .get_users_needing_backup(interval_secs, 50) 662 709 .await 663 - .map_err(|e| format!("DB error fetching users for backup: {:?}", e))?; 710 + .context("DB error fetching users for backup")?; 664 711 665 712 if users_needing_backup.is_empty() { 666 713 debug!("No accounts need backup"); ··· 679 726 block_store, 680 727 backup_storage, 681 728 user.id, 682 - user.did.to_string(), 729 + user.did, 683 730 user.repo_root_cid.to_string(), 684 731 user.repo_rev, 685 732 ) ··· 719 766 pub async fn generate_repo_car( 720 767 block_store: &PostgresBlockStore, 721 768 head_cid: &Cid, 722 - ) -> Result<Vec<u8>, String> { 769 + ) -> anyhow::Result<Vec<u8>> { 723 770 use jacquard_repo::storage::BlockStore; 724 771 725 772 let block_cids_bytes = collect_current_repo_blocks(block_store, head_cid).await?; 726 773 let block_cids: Vec<Cid> = block_cids_bytes 727 774 .iter() 728 - .filter_map(|b| Cid::try_from(b.as_slice()).ok()) 775 + .filter_map(|b| match Cid::try_from(b.as_slice()) { 776 + Ok(cid) => Some(cid), 777 + Err(e) => { 778 + tracing::warn!(error = %e, "skipping unparseable CID in backup generation"); 779 + None 780 + } 781 + }) 729 782 .collect(); 730 783 731 - let car_bytes = 732 - encode_car_header(head_cid).map_err(|e| format!("Failed to encode CAR header: {}", e))?; 784 + let car_bytes = encode_car_header(head_cid).context("Failed to encode CAR header")?; 733 785 734 786 let blocks = block_store 735 787 .get_many(&block_cids) 736 788 .await 737 - .map_err(|e| format!("Failed to fetch blocks: {:?}", e))?; 789 + .context("Failed to fetch blocks")?; 738 790 739 791 let car_bytes = block_cids 740 792 .iter() ··· 753 805 let cid_bytes = cid.to_bytes(); 754 806 let total_len = cid_bytes.len() + block.len(); 755 807 let mut writer = Vec::new(); 756 - crate::sync::car::write_varint(&mut writer, total_len as u64) 808 + crate::sync::car::write_varint(&mut writer, u64::try_from(total_len).expect("len fits u64")) 757 809 .expect("Writing to Vec<u8> should never fail"); 758 810 writer 759 811 .write_all(&cid_bytes) ··· 769 821 block_store: &PostgresBlockStore, 770 822 user_id: uuid::Uuid, 771 823 _head_cid: &Cid, 772 - ) -> Result<Vec<u8>, String> { 824 + ) -> anyhow::Result<Vec<u8>> { 773 825 use std::str::FromStr; 774 826 775 827 let repo_root_cid_str: String = repo_repo 776 828 .get_repo_root_cid_by_user_id(user_id) 777 829 .await 778 - .map_err(|e| format!("Failed to fetch repo: {:?}", e))? 779 - .ok_or_else(|| "Repository not found".to_string())? 830 + .context("Failed to fetch repo")? 831 + .ok_or_else(|| anyhow::anyhow!("Repository not found"))? 780 832 .to_string(); 781 833 782 - let actual_head_cid = 783 - Cid::from_str(&repo_root_cid_str).map_err(|e| format!("Invalid repo_root_cid: {}", e))?; 834 + let actual_head_cid = Cid::from_str(&repo_root_cid_str).context("Invalid repo_root_cid")?; 784 835 785 836 generate_repo_car(block_store, &actual_head_cid).await 786 837 } ··· 790 841 block_store: &PostgresBlockStore, 791 842 user_id: uuid::Uuid, 792 843 head_cid: &Cid, 793 - ) -> Result<Vec<u8>, String> { 844 + ) -> anyhow::Result<Vec<u8>> { 794 845 generate_repo_car_from_user_blocks(repo_repo, block_store, user_id, head_cid).await 795 846 } 796 847 797 848 pub fn count_car_blocks(car_bytes: &[u8]) -> i32 { 798 - let mut count = 0; 799 - let mut pos = 0; 849 + let mut count: i32 = 0; 850 + let mut pos: usize = 0; 800 851 801 852 if let Some((header_len, header_varint_len)) = read_varint(&car_bytes[pos..]) { 802 - pos += header_varint_len + header_len as usize; 853 + let Some(header_size) = usize::try_from(header_len).ok() else { 854 + return 0; 855 + }; 856 + let Some(next_pos) = header_varint_len 857 + .checked_add(header_size) 858 + .and_then(|skip| pos.checked_add(skip)) 859 + else { 860 + return 0; 861 + }; 862 + pos = next_pos; 803 863 } else { 804 864 return 0; 805 865 } 806 866 807 867 while pos < car_bytes.len() { 808 868 if let Some((block_len, varint_len)) = read_varint(&car_bytes[pos..]) { 809 - pos += varint_len + block_len as usize; 810 - count += 1; 869 + let Some(block_size) = usize::try_from(block_len).ok() else { 870 + break; 871 + }; 872 + let Some(next_pos) = varint_len 873 + .checked_add(block_size) 874 + .and_then(|skip| pos.checked_add(skip)) 875 + else { 876 + break; 877 + }; 878 + pos = next_pos; 879 + count = count.saturating_add(1); 811 880 } else { 812 881 break; 813 882 } ··· 839 908 backup_storage: &dyn BackupStorage, 840 909 user_id: uuid::Uuid, 841 910 retention_count: u32, 842 - ) -> Result<(), String> { 911 + ) -> anyhow::Result<()> { 843 912 let old_backups = backup_repo 844 - .get_old_backups(user_id, retention_count as i64) 913 + .get_old_backups(user_id, i64::from(retention_count)) 845 914 .await 846 - .map_err(|e| format!("DB error fetching old backups: {:?}", e))?; 915 + .context("DB error fetching old backups")?; 847 916 848 917 let results = futures::future::join_all(old_backups.into_iter().map(|backup| async move { 849 918 match backup_storage.delete_backup(&backup.storage_key).await { 850 - Ok(()) => match backup_repo.delete_backup(backup.id).await { 851 - Ok(()) => Ok(()), 852 - Err(e) => Err(format!( 853 - "DB delete failed for {}: {:?}", 854 - backup.storage_key, e 855 - )), 856 - }, 919 + Ok(()) => backup_repo 920 + .delete_backup(backup.id) 921 + .await 922 + .with_context(|| format!("DB delete failed for {}", backup.storage_key)), 857 923 Err(e) => { 858 924 warn!( 859 925 storage_key = %backup.storage_key,
+18 -32
crates/tranquil-pds/src/sso/config.rs
··· 80 80 } 81 81 82 82 fn load_provider(name: &str, needs_issuer: bool) -> Option<ProviderConfig> { 83 - let enabled = std::env::var(format!("SSO_{}_ENABLED", name)) 84 - .map(|v| v == "true" || v == "1") 85 - .unwrap_or(false); 83 + let enabled = crate::util::parse_env_bool(&format!("SSO_{}_ENABLED", name)); 86 84 87 85 if !enabled { 88 86 return None; ··· 121 119 } 122 120 123 121 fn load_apple_provider() -> Option<AppleProviderConfig> { 124 - let enabled = std::env::var("SSO_APPLE_ENABLED") 125 - .map(|v| v == "true" || v == "1") 126 - .unwrap_or(false); 122 + let enabled = crate::util::parse_env_bool("SSO_APPLE_ENABLED"); 127 123 128 124 if !enabled { 129 125 return None; ··· 178 174 self.apple.as_ref() 179 175 } 180 176 177 + fn provider_configs(&self) -> [(SsoProviderType, bool); 6] { 178 + [ 179 + (SsoProviderType::Github, self.github.is_some()), 180 + (SsoProviderType::Discord, self.discord.is_some()), 181 + (SsoProviderType::Google, self.google.is_some()), 182 + (SsoProviderType::Gitlab, self.gitlab.is_some()), 183 + (SsoProviderType::Oidc, self.oidc.is_some()), 184 + (SsoProviderType::Apple, self.apple.is_some()), 185 + ] 186 + } 187 + 181 188 pub fn enabled_providers(&self) -> Vec<SsoProviderType> { 182 - let mut providers = Vec::new(); 183 - if self.github.is_some() { 184 - providers.push(SsoProviderType::Github); 185 - } 186 - if self.discord.is_some() { 187 - providers.push(SsoProviderType::Discord); 188 - } 189 - if self.google.is_some() { 190 - providers.push(SsoProviderType::Google); 191 - } 192 - if self.gitlab.is_some() { 193 - providers.push(SsoProviderType::Gitlab); 194 - } 195 - if self.oidc.is_some() { 196 - providers.push(SsoProviderType::Oidc); 197 - } 198 - if self.apple.is_some() { 199 - providers.push(SsoProviderType::Apple); 200 - } 201 - providers 189 + self.provider_configs() 190 + .into_iter() 191 + .filter_map(|(p, enabled)| enabled.then_some(p)) 192 + .collect() 202 193 } 203 194 204 195 pub fn is_any_enabled(&self) -> bool { 205 - self.github.is_some() 206 - || self.discord.is_some() 207 - || self.google.is_some() 208 - || self.gitlab.is_some() 209 - || self.oidc.is_some() 210 - || self.apple.is_some() 196 + self.provider_configs().into_iter().any(|(_, e)| e) 211 197 } 212 198 }
+33 -52
crates/tranquil-pds/src/sso/endpoints.rs
··· 36 36 37 37 #[derive(Debug, Serialize)] 38 38 pub struct SsoProviderInfo { 39 - pub provider: String, 39 + pub provider: SsoProviderType, 40 40 pub name: String, 41 41 pub icon: String, 42 42 } ··· 52 52 .enabled_providers() 53 53 .iter() 54 54 .map(|(t, name, icon)| SsoProviderInfo { 55 - provider: t.as_str().to_string(), 55 + provider: *t, 56 56 name: name.to_string(), 57 57 icon: icon.to_string(), 58 58 }) ··· 63 63 64 64 #[derive(Debug, Deserialize)] 65 65 pub struct SsoInitiateRequest { 66 - pub provider: String, 66 + pub provider: SsoProviderType, 67 67 pub request_uri: Option<String>, 68 - pub action: Option<String>, 68 + pub action: Option<SsoAction>, 69 69 } 70 70 71 71 #[derive(Debug, Serialize)] ··· 79 79 headers: HeaderMap, 80 80 Json(input): Json<SsoInitiateRequest>, 81 81 ) -> Result<Json<SsoInitiateResponse>, ApiError> { 82 - if input.provider.len() > 20 { 83 - return Err(ApiError::SsoProviderNotFound); 84 - } 85 82 if let Some(ref uri) = input.request_uri 86 83 && uri.len() > 500 87 84 { 88 85 return Err(ApiError::InvalidRequest("Request URI too long".into())); 89 86 } 90 - if let Some(ref action) = input.action 91 - && action.len() > 20 92 - { 93 - return Err(ApiError::SsoInvalidAction); 94 - } 95 87 96 - let provider_type = 97 - SsoProviderType::parse(&input.provider).ok_or(ApiError::SsoProviderNotFound)?; 88 + let provider_type = input.provider; 98 89 99 90 let provider = state 100 91 .sso_manager 101 92 .get_provider(provider_type) 102 93 .ok_or(ApiError::SsoProviderNotEnabled)?; 103 94 104 - let action = input 105 - .action 106 - .as_deref() 107 - .map(SsoAction::parse) 108 - .unwrap_or(Some(SsoAction::Login)) 109 - .ok_or(ApiError::SsoInvalidAction)?; 95 + let action = input.action.unwrap_or(SsoAction::Login); 110 96 111 97 let is_standalone = action == SsoAction::Register && input.request_uri.is_none(); 112 98 let request_uri = input ··· 616 602 #[derive(Debug, Serialize)] 617 603 pub struct LinkedAccountInfo { 618 604 pub id: String, 619 - pub provider: String, 605 + pub provider: SsoProviderType, 620 606 pub provider_name: String, 621 607 pub provider_username: Option<String>, 622 608 pub provider_email: Option<String>, ··· 642 628 .into_iter() 643 629 .map(|id| LinkedAccountInfo { 644 630 id: id.id.to_string(), 645 - provider: id.provider.as_str().to_string(), 631 + provider: id.provider, 646 632 provider_name: id.provider.display_name().to_string(), 647 633 provider_username: id.provider_username.map(|u| u.into_inner()), 648 634 provider_email: id.provider_email.map(|e| e.into_inner()), ··· 723 709 #[derive(Debug, Serialize)] 724 710 pub struct PendingRegistrationResponse { 725 711 pub request_uri: String, 726 - pub provider: String, 712 + pub provider: SsoProviderType, 727 713 pub provider_user_id: String, 728 714 pub provider_username: Option<String>, 729 715 pub provider_email: Option<String>, ··· 747 733 748 734 Ok(Json(PendingRegistrationResponse { 749 735 request_uri: pending.request_uri, 750 - provider: pending.provider.as_str().to_string(), 736 + provider: pending.provider, 751 737 provider_user_id: pending.provider_user_id.into_inner(), 752 738 provider_username: pending.provider_username.map(|u| u.into_inner()), 753 739 provider_email: pending.provider_email.map(|e| e.into_inner()), ··· 789 775 790 776 let hostname_for_handles = pds_hostname_without_port(); 791 777 let full_handle = format!("{}.{}", validated, hostname_for_handles); 792 - let handle_typed = unsafe { crate::types::Handle::new_unchecked(&full_handle) }; 778 + let handle_typed: crate::types::Handle = match full_handle.parse() { 779 + Ok(h) => h, 780 + Err(_) => return Err(ApiError::InvalidHandle(None)), 781 + }; 793 782 794 783 let db_available = state 795 784 .user_repo ··· 816 805 pub handle: String, 817 806 pub email: Option<String>, 818 807 pub invite_code: Option<String>, 819 - pub verification_channel: Option<String>, 808 + pub verification_channel: Option<tranquil_db_traits::CommsChannel>, 820 809 pub discord_username: Option<String>, 821 810 pub telegram_username: Option<String>, 822 811 pub signal_username: Option<String>, ··· 875 864 Err(_) => return Err(ApiError::InvalidHandle(None)), 876 865 }; 877 866 878 - let verification_channel = input.verification_channel.as_deref().unwrap_or("email"); 867 + let verification_channel = input 868 + .verification_channel 869 + .unwrap_or(tranquil_db_traits::CommsChannel::Email); 879 870 let verification_recipient = match verification_channel { 880 - "email" => { 871 + tranquil_db_traits::CommsChannel::Email => { 881 872 let email = input 882 873 .email 883 874 .clone() ··· 894 885 _ => return Err(ApiError::MissingEmail), 895 886 } 896 887 } 897 - "discord" => match &input.discord_username { 888 + tranquil_db_traits::CommsChannel::Discord => match &input.discord_username { 898 889 Some(username) if !username.trim().is_empty() => { 899 890 let clean = username.trim().to_lowercase(); 900 891 if !crate::api::validation::is_valid_discord_username(&clean) { ··· 906 897 } 907 898 _ => return Err(ApiError::MissingDiscordId), 908 899 }, 909 - "telegram" => match &input.telegram_username { 900 + tranquil_db_traits::CommsChannel::Telegram => match &input.telegram_username { 910 901 Some(username) if !username.trim().is_empty() => { 911 902 let clean = username.trim().trim_start_matches('@'); 912 903 if !crate::api::validation::is_valid_telegram_username(clean) { ··· 918 909 } 919 910 _ => return Err(ApiError::MissingTelegramUsername), 920 911 }, 921 - "signal" => match &input.signal_username { 912 + tranquil_db_traits::CommsChannel::Signal => match &input.signal_username { 922 913 Some(username) if !username.trim().is_empty() => { 923 914 username.trim().trim_start_matches('@').to_lowercase() 924 915 } 925 916 _ => return Err(ApiError::MissingSignalNumber), 926 917 }, 927 - _ => return Err(ApiError::InvalidVerificationChannel), 928 918 }; 929 919 930 920 let email = input ··· 958 948 Err(_) => return Err(ApiError::InvalidInviteCode), 959 949 } 960 950 } else { 961 - let invite_required = std::env::var("INVITE_CODE_REQUIRED") 962 - .map(|v| v == "true" || v == "1") 963 - .unwrap_or(false); 951 + let invite_required = crate::util::parse_env_bool("INVITE_CODE_REQUIRED"); 964 952 if invite_required { 965 953 return Err(ApiError::InviteCodeRequired); 966 954 } 967 955 None 968 956 }; 969 957 970 - let handle_typed = unsafe { crate::types::Handle::new_unchecked(&handle) }; 958 + let handle_typed: crate::types::Handle = 959 + handle.parse().map_err(|_| ApiError::InvalidHandle(None))?; 971 960 let reserved = state 972 961 .user_repo 973 962 .reserve_handle(&handle_typed, client_ip) ··· 1069 1058 }; 1070 1059 1071 1060 let rev = Tid::now(LimitedU32::MIN); 1072 - let did_typed = unsafe { crate::types::Did::new_unchecked(&did) }; 1061 + let did_typed: crate::types::Did = did 1062 + .parse() 1063 + .map_err(|_| ApiError::InternalError(Some("Invalid DID".into())))?; 1073 1064 let (commit_bytes, _sig) = match crate::api::repo::record::utils::create_signed_commit( 1074 1065 &did_typed, 1075 1066 mst_root, ··· 1101 1092 }) 1102 1093 }); 1103 1094 1104 - let preferred_comms_channel = match verification_channel { 1105 - "email" => tranquil_db_traits::CommsChannel::Email, 1106 - "discord" => tranquil_db_traits::CommsChannel::Discord, 1107 - "telegram" => tranquil_db_traits::CommsChannel::Telegram, 1108 - "signal" => tranquil_db_traits::CommsChannel::Signal, 1109 - _ => tranquil_db_traits::CommsChannel::Email, 1110 - }; 1111 - 1112 1095 let create_input = tranquil_db_traits::CreateSsoAccountInput { 1113 1096 handle: handle_typed.clone(), 1114 1097 email: email.clone(), 1115 1098 did: did_typed.clone(), 1116 - preferred_comms_channel, 1099 + preferred_comms_channel: verification_channel, 1117 1100 discord_username: input 1118 1101 .discord_username 1119 1102 .clone() ··· 1192 1175 "$type": "app.bsky.actor.profile", 1193 1176 "displayName": handle_typed.as_str() 1194 1177 }); 1195 - let profile_collection = unsafe { crate::types::Nsid::new_unchecked("app.bsky.actor.profile") }; 1196 - let profile_rkey = unsafe { crate::types::Rkey::new_unchecked("self") }; 1197 1178 if let Err(e) = crate::api::repo::record::create_record_internal( 1198 1179 &state, 1199 1180 &did_typed, 1200 - &profile_collection, 1201 - &profile_rkey, 1181 + &crate::types::PROFILE_COLLECTION, 1182 + &crate::types::PROFILE_RKEY, 1202 1183 &profile_record, 1203 1184 ) 1204 1185 .await ··· 1261 1242 .await 1262 1243 .unwrap_or(None); 1263 1244 1264 - let channel_auto_verified = verification_channel == "email" 1245 + let channel_auto_verified = verification_channel == tranquil_db_traits::CommsChannel::Email 1265 1246 && pending_preview.provider_email_verified 1266 1247 && pending_preview.provider_email.as_ref().map(|e| e.as_str()) == email.as_deref(); 1267 1248 ··· 1357 1338 1358 1339 if let Some(uid) = user_id { 1359 1340 let verification_token = crate::auth::verification_token::generate_signup_token( 1360 - &did, 1341 + &did_typed, 1361 1342 verification_channel, 1362 1343 &verification_recipient, 1363 1344 );
+32 -16
crates/tranquil-pds/src/sso/providers.rs
··· 14 14 15 15 const SSO_HTTP_TIMEOUT: Duration = Duration::from_secs(15); 16 16 17 + struct PkceChallenge { 18 + code_verifier: String, 19 + code_challenge: String, 20 + } 21 + 22 + struct ClientSecretWithExpiry { 23 + secret: String, 24 + expires_at: u64, 25 + } 26 + 17 27 fn create_http_client() -> Client { 18 28 Client::builder() 19 29 .timeout(SSO_HTTP_TIMEOUT) ··· 473 483 .await 474 484 } 475 485 476 - fn generate_pkce() -> (String, String) { 486 + fn generate_pkce() -> PkceChallenge { 477 487 use rand::RngCore; 478 488 let mut verifier_bytes = [0u8; 32]; 479 489 rand::thread_rng().fill_bytes(&mut verifier_bytes); 480 - let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 490 + let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 481 491 482 492 use sha2::{Digest, Sha256}; 483 - let challenge_bytes = Sha256::digest(verifier.as_bytes()); 484 - let challenge = URL_SAFE_NO_PAD.encode(challenge_bytes); 493 + let challenge_bytes = Sha256::digest(code_verifier.as_bytes()); 494 + let code_challenge = URL_SAFE_NO_PAD.encode(challenge_bytes); 485 495 486 - (verifier, challenge) 496 + PkceChallenge { 497 + code_verifier, 498 + code_challenge, 499 + } 487 500 } 488 501 489 502 fn validate_id_token( ··· 585 598 redirect_uri: &str, 586 599 nonce: Option<&str>, 587 600 ) -> Result<AuthUrlResult, SsoError> { 588 - let (verifier, challenge) = Self::generate_pkce(); 601 + let pkce = Self::generate_pkce(); 589 602 590 603 let auth_endpoint = match self.provider_type { 591 604 SsoProviderType::Google => "https://accounts.google.com/o/oauth2/v2/auth".to_string(), ··· 604 617 urlencoding::encode(&self.client_id), 605 618 urlencoding::encode(redirect_uri), 606 619 urlencoding::encode(state), 607 - urlencoding::encode(&challenge), 620 + urlencoding::encode(&pkce.code_challenge), 608 621 ); 609 622 610 623 if let Some(n) = nonce { ··· 613 626 614 627 Ok(AuthUrlResult { 615 628 url, 616 - code_verifier: Some(verifier), 629 + code_verifier: Some(pkce.code_verifier), 617 630 }) 618 631 } 619 632 ··· 785 798 }) 786 799 } 787 800 788 - fn generate_client_secret(&self) -> Result<(String, u64), SsoError> { 801 + fn generate_client_secret(&self) -> Result<ClientSecretWithExpiry, SsoError> { 789 802 let now = SystemTime::now() 790 803 .duration_since(UNIX_EPOCH) 791 - .unwrap() 804 + .unwrap_or_default() 792 805 .as_secs(); 793 806 let exp = now + (150 * 24 * 60 * 60); 794 807 ··· 821 834 SsoError::Provider(format!("Failed to generate Apple client secret: {}", e)) 822 835 })?; 823 836 824 - Ok((token, exp)) 837 + Ok(ClientSecretWithExpiry { 838 + secret: token, 839 + expires_at: exp, 840 + }) 825 841 } 826 842 827 843 async fn get_client_secret(&self) -> Result<String, SsoError> { 828 844 let now = SystemTime::now() 829 845 .duration_since(UNIX_EPOCH) 830 - .unwrap() 846 + .unwrap_or_default() 831 847 .as_secs(); 832 848 833 849 { ··· 839 855 } 840 856 } 841 857 842 - let (secret, expires_at) = self.generate_client_secret()?; 858 + let generated = self.generate_client_secret()?; 843 859 844 860 { 845 861 let mut cache = self.client_secret_cache.write().await; 846 862 *cache = Some(CachedClientSecret { 847 - secret: secret.clone(), 848 - expires_at, 863 + secret: generated.secret.clone(), 864 + expires_at: generated.expires_at, 849 865 }); 850 866 } 851 867 852 - Ok(secret) 868 + Ok(generated.secret) 853 869 } 854 870 855 871 async fn get_jwks(&self) -> Result<&JwkSet, SsoError> {
+90 -24
crates/tranquil-pds/src/state.rs
··· 62 62 } 63 63 64 64 #[derive(Debug, Clone, Copy)] 65 + pub struct RateLimitParams { 66 + pub limit: u32, 67 + pub window_ms: u64, 68 + } 69 + 70 + #[derive(Debug, Clone, Copy)] 65 71 pub enum RateLimitKind { 66 72 Login, 67 73 AccountCreation, ··· 86 92 } 87 93 88 94 impl RateLimitKind { 89 - fn key_prefix(&self) -> &'static str { 95 + const fn key_prefix(&self) -> &'static str { 90 96 match self { 91 97 Self::Login => "login", 92 98 Self::AccountCreation => "account_creation", ··· 111 117 } 112 118 } 113 119 114 - fn limit_and_window_ms(&self) -> (u32, u64) { 120 + const fn params(&self) -> RateLimitParams { 115 121 match self { 116 - Self::Login => (10, 60_000), 117 - Self::AccountCreation => (10, 3_600_000), 118 - Self::PasswordReset => (5, 3_600_000), 119 - Self::ResetPassword => (10, 60_000), 120 - Self::RefreshSession => (60, 60_000), 121 - Self::OAuthToken => (300, 60_000), 122 - Self::OAuthAuthorize => (10, 60_000), 123 - Self::OAuthPar => (30, 60_000), 124 - Self::OAuthIntrospect => (30, 60_000), 125 - Self::AppPassword => (10, 60_000), 126 - Self::EmailUpdate => (5, 3_600_000), 127 - Self::TotpVerify => (5, 300_000), 128 - Self::HandleUpdate => (10, 300_000), 129 - Self::HandleUpdateDaily => (50, 86_400_000), 130 - Self::VerificationCheck => (60, 60_000), 131 - Self::SsoInitiate => (10, 60_000), 132 - Self::SsoCallback => (30, 60_000), 133 - Self::SsoUnlink => (10, 60_000), 134 - Self::OAuthRegisterComplete => (5, 300_000), 135 - Self::HandleVerification => (10, 60_000), 122 + Self::Login => RateLimitParams { 123 + limit: 10, 124 + window_ms: 60_000, 125 + }, 126 + Self::AccountCreation => RateLimitParams { 127 + limit: 10, 128 + window_ms: 3_600_000, 129 + }, 130 + Self::PasswordReset => RateLimitParams { 131 + limit: 5, 132 + window_ms: 3_600_000, 133 + }, 134 + Self::ResetPassword => RateLimitParams { 135 + limit: 10, 136 + window_ms: 60_000, 137 + }, 138 + Self::RefreshSession => RateLimitParams { 139 + limit: 60, 140 + window_ms: 60_000, 141 + }, 142 + Self::OAuthToken => RateLimitParams { 143 + limit: 300, 144 + window_ms: 60_000, 145 + }, 146 + Self::OAuthAuthorize => RateLimitParams { 147 + limit: 10, 148 + window_ms: 60_000, 149 + }, 150 + Self::OAuthPar => RateLimitParams { 151 + limit: 30, 152 + window_ms: 60_000, 153 + }, 154 + Self::OAuthIntrospect => RateLimitParams { 155 + limit: 30, 156 + window_ms: 60_000, 157 + }, 158 + Self::AppPassword => RateLimitParams { 159 + limit: 10, 160 + window_ms: 60_000, 161 + }, 162 + Self::EmailUpdate => RateLimitParams { 163 + limit: 5, 164 + window_ms: 3_600_000, 165 + }, 166 + Self::TotpVerify => RateLimitParams { 167 + limit: 5, 168 + window_ms: 300_000, 169 + }, 170 + Self::HandleUpdate => RateLimitParams { 171 + limit: 10, 172 + window_ms: 300_000, 173 + }, 174 + Self::HandleUpdateDaily => RateLimitParams { 175 + limit: 50, 176 + window_ms: 86_400_000, 177 + }, 178 + Self::VerificationCheck => RateLimitParams { 179 + limit: 60, 180 + window_ms: 60_000, 181 + }, 182 + Self::SsoInitiate => RateLimitParams { 183 + limit: 10, 184 + window_ms: 60_000, 185 + }, 186 + Self::SsoCallback => RateLimitParams { 187 + limit: 30, 188 + window_ms: 60_000, 189 + }, 190 + Self::SsoUnlink => RateLimitParams { 191 + limit: 10, 192 + window_ms: 60_000, 193 + }, 194 + Self::OAuthRegisterComplete => RateLimitParams { 195 + limit: 5, 196 + window_ms: 300_000, 197 + }, 198 + Self::HandleVerification => RateLimitParams { 199 + limit: 10, 200 + window_ms: 60_000, 201 + }, 136 202 } 137 203 } 138 204 } ··· 294 360 } 295 361 296 362 let key = format!("{}:{}", kind.key_prefix(), client_ip); 297 - let (limit, window_ms) = kind.limit_and_window_ms(); 363 + let params = kind.params(); 298 364 299 365 if !self 300 366 .distributed_rate_limiter 301 - .check_rate_limit(&key, limit, window_ms) 367 + .check_rate_limit(&key, params.limit, params.window_ms) 302 368 .await 303 369 { 304 370 crate::metrics::record_rate_limit_rejection(limiter_name);
+26 -40
crates/tranquil-pds/src/sync/blob.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::state::AppState; 3 - use crate::sync::util::assert_repo_availability; 3 + use crate::sync::util::{RepoAccessLevel, assert_repo_availability}; 4 4 use axum::{ 5 5 Json, 6 6 body::Body, ··· 15 15 16 16 #[derive(Deserialize)] 17 17 pub struct GetBlobParams { 18 - pub did: String, 19 - pub cid: String, 18 + pub did: Did, 19 + pub cid: CidLink, 20 20 } 21 21 22 22 pub async fn get_blob( 23 23 State(state): State<AppState>, 24 24 Query(params): Query<GetBlobParams>, 25 25 ) -> Response { 26 - let did_str = params.did.trim(); 27 - let cid_str = params.cid.trim(); 28 - if did_str.is_empty() { 29 - return ApiError::InvalidRequest("did is required".into()).into_response(); 30 - } 31 - if cid_str.is_empty() { 32 - return ApiError::InvalidRequest("cid is required".into()).into_response(); 33 - } 34 - let did: Did = match did_str.parse() { 35 - Ok(d) => d, 36 - Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 37 - }; 38 - let cid: CidLink = match cid_str.parse() { 39 - Ok(c) => c, 40 - Err(_) => return ApiError::InvalidRequest("invalid cid".into()).into_response(), 41 - }; 26 + let did = params.did; 27 + let cid = params.cid; 42 28 43 - let _account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 44 - Ok(a) => a, 45 - Err(e) => return e.into_response(), 46 - }; 29 + let _account = 30 + match assert_repo_availability(state.repo_repo.as_ref(), &did, RepoAccessLevel::Public) 31 + .await 32 + { 33 + Ok(a) => a, 34 + Err(e) => return e.into_response(), 35 + }; 47 36 48 37 let blob_result = state.blob_repo.get_blob_metadata(&cid).await; 49 38 match blob_result { ··· 55 44 .header("x-content-type-options", "nosniff") 56 45 .header("content-security-policy", "default-src 'none'; sandbox") 57 46 .body(Body::from(data)) 58 - .unwrap(), 47 + .unwrap_or_else(|_| ApiError::InternalError(None).into_response()), 59 48 Err(e) => { 60 49 error!("Failed to fetch blob from storage: {:?}", e); 61 50 ApiError::BlobNotFound(Some("Blob not found in storage".into())).into_response() ··· 71 60 72 61 #[derive(Deserialize)] 73 62 pub struct ListBlobsParams { 74 - pub did: String, 63 + pub did: Did, 75 64 pub since: Option<String>, 76 65 pub limit: Option<i64>, 77 66 pub cursor: Option<String>, ··· 88 77 State(state): State<AppState>, 89 78 Query(params): Query<ListBlobsParams>, 90 79 ) -> Response { 91 - let did_str = params.did.trim(); 92 - if did_str.is_empty() { 93 - return ApiError::InvalidRequest("did is required".into()).into_response(); 94 - } 95 - let did: Did = match did_str.parse() { 96 - Ok(d) => d, 97 - Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 98 - }; 80 + let did = params.did; 99 81 100 - let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 101 - Ok(a) => a, 102 - Err(e) => return e.into_response(), 103 - }; 82 + let account = 83 + match assert_repo_availability(state.repo_repo.as_ref(), &did, RepoAccessLevel::Public) 84 + .await 85 + { 86 + Ok(a) => a, 87 + Err(e) => return e.into_response(), 88 + }; 104 89 105 90 let limit = params.limit.unwrap_or(500).clamp(1, 1000); 106 91 let cursor_cid = params.cursor.as_deref().unwrap_or(""); ··· 117 102 cid_strs 118 103 .into_iter() 119 104 .filter(|c| c.as_str() > cursor_cid) 120 - .take((limit + 1) as usize) 105 + .take(usize::try_from(limit + 1).unwrap_or(0)) 121 106 .collect() 122 107 }) 123 108 } else { ··· 129 114 }; 130 115 match cids_result { 131 116 Ok(cids) => { 132 - let has_more = cids.len() as i64 > limit; 133 - let cids: Vec<String> = cids.into_iter().take(limit as usize).collect(); 117 + let limit_usize = usize::try_from(limit).unwrap_or(0); 118 + let has_more = cids.len() > limit_usize; 119 + let cids: Vec<String> = cids.into_iter().take(limit_usize).collect(); 134 120 let next_cursor = if has_more { cids.last().cloned() } else { None }; 135 121 ( 136 122 StatusCode::OK,
+45 -16
crates/tranquil-pds/src/sync/car.rs
··· 2 2 use iroh_car::CarHeader; 3 3 use std::io::Write; 4 4 5 + #[derive(Debug)] 6 + pub enum CarEncodeError { 7 + CborEncodeFailed(String), 8 + } 9 + 10 + impl std::fmt::Display for CarEncodeError { 11 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 12 + match self { 13 + Self::CborEncodeFailed(e) => write!(f, "Failed to encode CAR header: {}", e), 14 + } 15 + } 16 + } 17 + 18 + impl std::error::Error for CarEncodeError {} 19 + 5 20 pub fn write_varint<W: Write>(mut writer: W, mut value: u64) -> std::io::Result<()> { 6 21 loop { 7 22 let mut byte = (value & 0x7F) as u8; ··· 18 33 } 19 34 20 35 pub fn ld_write<W: Write>(mut writer: W, data: &[u8]) -> std::io::Result<()> { 21 - write_varint(&mut writer, data.len() as u64)?; 36 + write_varint( 37 + &mut writer, 38 + u64::try_from(data.len()).expect("len fits u64"), 39 + )?; 22 40 writer.write_all(data)?; 23 41 Ok(()) 24 42 } 25 43 26 - pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> { 27 - let header = CarHeader::new_v1(vec![*root_cid]); 44 + pub fn encode_car_header_with_root(root_cid: Option<&Cid>) -> Result<Vec<u8>, CarEncodeError> { 45 + let roots = root_cid.map_or_else(Vec::new, |cid| vec![*cid]); 46 + let header = CarHeader::new_v1(roots); 28 47 let header_cbor = header 29 48 .encode() 30 - .map_err(|e| format!("Failed to encode CAR header: {:?}", e))?; 49 + .map_err(|e| CarEncodeError::CborEncodeFailed(format!("{:?}", e)))?; 31 50 let mut result = Vec::new(); 32 - write_varint(&mut result, header_cbor.len() as u64) 33 - .expect("Writing to Vec<u8> should never fail"); 51 + write_varint( 52 + &mut result, 53 + u64::try_from(header_cbor.len()).expect("len fits u64"), 54 + ) 55 + .expect("Writing to Vec<u8> should never fail"); 34 56 result.extend_from_slice(&header_cbor); 35 57 Ok(result) 36 58 } 37 59 38 - pub fn encode_car_header_null_root() -> Result<Vec<u8>, String> { 39 - let header = CarHeader::new_v1(vec![]); 40 - let header_cbor = header 41 - .encode() 42 - .map_err(|e| format!("Failed to encode CAR header: {:?}", e))?; 43 - let mut result = Vec::new(); 44 - write_varint(&mut result, header_cbor.len() as u64) 45 - .expect("Writing to Vec<u8> should never fail"); 46 - result.extend_from_slice(&header_cbor); 47 - Ok(result) 60 + pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, CarEncodeError> { 61 + encode_car_header_with_root(Some(root_cid)) 62 + } 63 + 64 + pub fn encode_car_header_null_root() -> Result<Vec<u8>, CarEncodeError> { 65 + encode_car_header_with_root(None) 66 + } 67 + 68 + pub fn encode_car_block(cid: &Cid, block: &[u8]) -> Vec<u8> { 69 + let cid_bytes = cid.to_bytes(); 70 + let total_len = cid_bytes.len() + block.len(); 71 + let mut buf = Vec::with_capacity(10 + total_len); 72 + write_varint(&mut buf, u64::try_from(total_len).unwrap_or(u64::MAX)) 73 + .unwrap_or_else(|_| unreachable!()); 74 + buf.extend_from_slice(&cid_bytes); 75 + buf.extend_from_slice(block); 76 + buf 48 77 }
+26 -39
crates/tranquil-pds/src/sync/commit.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::state::AppState; 3 - use crate::sync::util::{assert_repo_availability, get_account_with_status}; 3 + use crate::sync::util::{RepoAccessLevel, assert_repo_availability, get_account_with_status}; 4 4 use axum::{ 5 5 Json, 6 6 extract::{Query, State}, ··· 25 25 26 26 #[derive(Deserialize)] 27 27 pub struct GetLatestCommitParams { 28 - pub did: String, 28 + pub did: Did, 29 29 } 30 30 31 31 #[derive(Serialize)] ··· 38 38 State(state): State<AppState>, 39 39 Query(params): Query<GetLatestCommitParams>, 40 40 ) -> Response { 41 - let did_str = params.did.trim(); 42 - if did_str.is_empty() { 43 - return ApiError::InvalidRequest("did is required".into()).into_response(); 44 - } 45 - let did: Did = match did_str.parse() { 46 - Ok(d) => d, 47 - Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 48 - }; 41 + let did = params.did; 49 42 50 - let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 51 - Ok(a) => a, 52 - Err(e) => return e.into_response(), 53 - }; 43 + let account = 44 + match assert_repo_availability(state.repo_repo.as_ref(), &did, RepoAccessLevel::Public) 45 + .await 46 + { 47 + Ok(a) => a, 48 + Err(e) => return e.into_response(), 49 + }; 54 50 55 51 let Some(repo_root_cid) = account.repo_root_cid else { 56 52 return ApiError::RepoNotFound(Some("Repo not initialized".into())).into_response(); ··· 59 55 let Some(rev) = get_rev_from_commit(&state, &repo_root_cid).await else { 60 56 error!( 61 57 "Failed to parse commit for DID {}: CID {}", 62 - did_str, repo_root_cid 58 + did, repo_root_cid 63 59 ); 64 60 return ApiError::InternalError(Some("Failed to read repo commit".into())).into_response(); 65 61 }; ··· 83 79 #[derive(Serialize)] 84 80 #[serde(rename_all = "camelCase")] 85 81 pub struct RepoInfo { 86 - pub did: String, 82 + pub did: Did, 87 83 pub head: String, 88 84 pub rev: String, 89 85 pub active: bool, 90 86 #[serde(skip_serializing_if = "Option::is_none")] 91 - pub status: Option<String>, 87 + pub status: Option<AccountStatus>, 92 88 } 93 89 94 90 #[derive(Serialize)] ··· 111 107 .await; 112 108 match result { 113 109 Ok(rows) => { 114 - let has_more = rows.len() as i64 > limit; 110 + let limit_usize = usize::try_from(limit).unwrap_or(0); 111 + let has_more = rows.len() > limit_usize; 115 112 let mut repos: Vec<RepoInfo> = Vec::new(); 116 - for row in rows.iter().take(limit as usize) { 113 + for row in rows.iter().take(limit_usize) { 117 114 let cid_str = row.repo_root_cid.to_string(); 118 115 let rev = get_rev_from_commit(&state, &cid_str) 119 116 .await ··· 127 124 AccountStatus::Active 128 125 }; 129 126 repos.push(RepoInfo { 130 - did: row.did.to_string(), 127 + did: row.did.clone(), 131 128 head: cid_str, 132 129 rev, 133 130 active: status.is_active(), 134 - status: status.for_firehose().map(String::from), 131 + status: status.for_firehose_typed(), 135 132 }); 136 133 } 137 134 let next_cursor = if has_more { 138 - repos.last().map(|r| r.did.clone()) 135 + repos.last().map(|r| r.did.to_string()) 139 136 } else { 140 137 None 141 138 }; ··· 157 154 158 155 #[derive(Deserialize)] 159 156 pub struct GetRepoStatusParams { 160 - pub did: String, 157 + pub did: Did, 161 158 } 162 159 163 160 #[derive(Serialize)] 164 161 pub struct GetRepoStatusOutput { 165 - pub did: String, 162 + pub did: Did, 166 163 pub active: bool, 167 164 #[serde(skip_serializing_if = "Option::is_none")] 168 - pub status: Option<String>, 165 + pub status: Option<AccountStatus>, 169 166 #[serde(skip_serializing_if = "Option::is_none")] 170 167 pub rev: Option<String>, 171 168 } ··· 174 171 State(state): State<AppState>, 175 172 Query(params): Query<GetRepoStatusParams>, 176 173 ) -> Response { 177 - let did_str = params.did.trim(); 178 - if did_str.is_empty() { 179 - return ApiError::InvalidRequest("did is required".into()).into_response(); 180 - } 181 - let did: Did = match did_str.parse() { 182 - Ok(d) => d, 183 - Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 184 - }; 174 + let did = params.did; 185 175 186 176 let account = match get_account_with_status(state.repo_repo.as_ref(), &did).await { 187 177 Ok(Some(a)) => a, 188 178 Ok(None) => { 189 - return ApiError::RepoNotFound(Some(format!( 190 - "Could not find repo for DID: {}", 191 - did_str 192 - ))) 193 - .into_response(); 179 + return ApiError::RepoNotFound(Some(format!("Could not find repo for DID: {}", did))) 180 + .into_response(); 194 181 } 195 182 Err(e) => { 196 183 error!("DB error in get_repo_status: {:?}", e); ··· 213 200 Json(GetRepoStatusOutput { 214 201 did: account.did, 215 202 active: account.status.is_active(), 216 - status: account.status.for_firehose().map(String::from), 203 + status: account.status.for_firehose_typed(), 217 204 rev, 218 205 }), 219 206 )
+42 -37
crates/tranquil-pds/src/sync/deprecated.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::state::AppState; 3 - use crate::sync::car::encode_car_header; 4 - use crate::sync::util::assert_repo_availability; 3 + use crate::sync::car::{encode_car_block, encode_car_header}; 4 + use crate::sync::util::{RepoAccessLevel, assert_repo_availability}; 5 5 use axum::{ 6 6 Json, 7 7 extract::{Query, State}, 8 - http::{HeaderMap, StatusCode}, 8 + http::{HeaderMap, Method, StatusCode}, 9 9 response::{IntoResponse, Response}, 10 10 }; 11 11 use cid::Cid; 12 12 use ipld_core::ipld::Ipld; 13 13 use jacquard_repo::storage::BlockStore; 14 14 use serde::{Deserialize, Serialize}; 15 - use std::io::Write; 16 15 use std::str::FromStr; 17 16 use tranquil_types::Did; 18 17 19 18 const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 20 19 21 - async fn check_admin_or_self(state: &AppState, headers: &HeaderMap, did: &str) -> bool { 20 + async fn check_admin_or_self(state: &AppState, headers: &HeaderMap, did: &Did) -> bool { 22 21 let extracted = match crate::auth::extract_auth_token_from_header(crate::util::get_header_str( 23 22 headers, 24 - "Authorization", 23 + axum::http::header::AUTHORIZATION, 25 24 )) { 26 25 Some(t) => t, 27 26 None => return false, 28 27 }; 29 - let dpop_proof = crate::util::get_header_str(headers, "DPoP"); 28 + let dpop_proof = crate::util::get_header_str(headers, crate::util::HEADER_DPOP); 30 29 let http_uri = "/"; 31 30 match crate::auth::validate_token_with_dpop( 32 31 state.user_repo.as_ref(), 33 32 state.oauth_repo.as_ref(), 34 33 &extracted.token, 35 - extracted.is_dpop, 34 + extracted.scheme, 36 35 dpop_proof, 37 - "GET", 36 + Method::GET.as_str(), 38 37 http_uri, 39 - false, 40 - true, 38 + crate::auth::AccountRequirement::AnyStatus, 41 39 ) 42 40 .await 43 41 { 44 - Ok(auth_user) => auth_user.is_admin || auth_user.did == did, 42 + Ok(auth_user) => auth_user.is_admin || auth_user.did == *did, 45 43 Err(_) => false, 46 44 } 47 45 } ··· 69 67 Ok(d) => d, 70 68 Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 71 69 }; 72 - let is_admin_or_self = check_admin_or_self(&state, &headers, did_str).await; 73 - let account = 74 - match assert_repo_availability(state.repo_repo.as_ref(), &did, is_admin_or_self).await { 75 - Ok(a) => a, 76 - Err(e) => return e.into_response(), 77 - }; 70 + let is_admin_or_self = check_admin_or_self(&state, &headers, &did).await; 71 + let account = match assert_repo_availability( 72 + state.repo_repo.as_ref(), 73 + &did, 74 + if is_admin_or_self { 75 + RepoAccessLevel::Privileged 76 + } else { 77 + RepoAccessLevel::Public 78 + }, 79 + ) 80 + .await 81 + { 82 + Ok(a) => a, 83 + Err(e) => return e.into_response(), 84 + }; 78 85 match account.repo_root_cid { 79 86 Some(root) => (StatusCode::OK, Json(GetHeadOutput { root })).into_response(), 80 - None => ApiError::RepoNotFound(Some(format!("Could not find root for DID: {}", did_str))) 87 + None => ApiError::RepoNotFound(Some(format!("Could not find root for DID: {}", did))) 81 88 .into_response(), 82 89 } 83 90 } ··· 100 107 Ok(d) => d, 101 108 Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 102 109 }; 103 - let is_admin_or_self = check_admin_or_self(&state, &headers, did_str).await; 104 - let account = 105 - match assert_repo_availability(state.repo_repo.as_ref(), &did, is_admin_or_self).await { 106 - Ok(a) => a, 107 - Err(e) => return e.into_response(), 108 - }; 110 + let is_admin_or_self = check_admin_or_self(&state, &headers, &did).await; 111 + let account = match assert_repo_availability( 112 + state.repo_repo.as_ref(), 113 + &did, 114 + if is_admin_or_self { 115 + RepoAccessLevel::Privileged 116 + } else { 117 + RepoAccessLevel::Public 118 + }, 119 + ) 120 + .await 121 + { 122 + Ok(a) => a, 123 + Err(e) => return e.into_response(), 124 + }; 109 125 let Some(head_str) = account.repo_root_cid else { 110 126 return ApiError::RepoNotFound(Some("Repo not initialized".into())).into_response(); 111 127 }; ··· 128 144 } 129 145 remaining -= 1; 130 146 if let Ok(Some(block)) = state.block_store.get(&cid).await { 131 - let cid_bytes = cid.to_bytes(); 132 - let total_len = cid_bytes.len() + block.len(); 133 - let mut writer = Vec::new(); 134 - crate::sync::car::write_varint(&mut writer, total_len as u64) 135 - .expect("Writing to Vec<u8> should never fail"); 136 - writer 137 - .write_all(&cid_bytes) 138 - .expect("Writing to Vec<u8> should never fail"); 139 - writer 140 - .write_all(&block) 141 - .expect("Writing to Vec<u8> should never fail"); 142 - car_bytes.extend_from_slice(&writer); 147 + car_bytes.extend_from_slice(&encode_car_block(&cid, &block)); 143 148 if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) { 144 149 extract_links_ipld(&value, &mut stack); 145 150 }
+38 -11
crates/tranquil-pds/src/sync/frame.rs
··· 3 3 use serde::{Deserialize, Serialize}; 4 4 use std::str::FromStr; 5 5 use tranquil_scopes::RepoAction; 6 + use tranquil_types::Did; 7 + 8 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] 9 + pub enum FrameType { 10 + #[serde(rename = "#commit")] 11 + Commit, 12 + #[serde(rename = "#identity")] 13 + Identity, 14 + #[serde(rename = "#account")] 15 + Account, 16 + #[serde(rename = "#sync")] 17 + Sync, 18 + #[serde(rename = "#info")] 19 + Info, 20 + } 6 21 7 22 #[derive(Debug, Serialize, Deserialize)] 8 23 pub struct FrameHeader { 9 24 pub op: i64, 10 - pub t: String, 25 + pub t: FrameType, 11 26 } 12 27 13 28 #[derive(Debug, Serialize, Deserialize)] ··· 16 31 pub rebase: bool, 17 32 #[serde(rename = "tooBig")] 18 33 pub too_big: bool, 19 - pub repo: String, 34 + pub repo: Did, 20 35 pub commit: Cid, 21 36 pub rev: String, 22 37 pub since: Option<String>, ··· 48 63 49 64 #[derive(Debug, Serialize, Deserialize)] 50 65 pub struct IdentityFrame { 51 - pub did: String, 66 + pub did: Did, 52 67 #[serde(skip_serializing_if = "Option::is_none")] 53 68 pub handle: Option<String>, 54 69 pub seq: i64, ··· 57 72 58 73 #[derive(Debug, Serialize, Deserialize)] 59 74 pub struct AccountFrame { 60 - pub did: String, 75 + pub did: Did, 61 76 pub active: bool, 62 77 #[serde(skip_serializing_if = "Option::is_none")] 63 - pub status: Option<String>, 78 + pub status: Option<tranquil_db_traits::AccountStatus>, 64 79 pub seq: i64, 65 80 pub time: String, 66 81 } 67 82 68 83 #[derive(Debug, Serialize, Deserialize)] 69 84 pub struct SyncFrame { 70 - pub did: String, 85 + pub did: Did, 71 86 pub rev: String, 72 87 #[serde(with = "serde_bytes")] 73 88 pub blocks: Vec<u8>, ··· 75 90 pub time: String, 76 91 } 77 92 93 + #[derive(Debug, Clone, Copy, Serialize, Deserialize)] 94 + pub enum InfoFrameName { 95 + #[serde(rename = "OutdatedCursor")] 96 + OutdatedCursor, 97 + } 98 + 99 + #[derive(Debug, Clone, Copy, Serialize, Deserialize)] 100 + pub enum ErrorFrameName { 101 + #[serde(rename = "FutureCursor")] 102 + FutureCursor, 103 + } 104 + 78 105 #[derive(Debug, Serialize, Deserialize)] 79 106 pub struct InfoFrame { 80 - pub name: String, 107 + pub name: InfoFrameName, 81 108 #[serde(skip_serializing_if = "Option::is_none")] 82 109 pub message: Option<String>, 83 110 } ··· 89 116 90 117 #[derive(Debug, Serialize, Deserialize)] 91 118 pub struct ErrorFrameBody { 92 - pub error: String, 119 + pub error: ErrorFrameName, 93 120 #[serde(skip_serializing_if = "Option::is_none")] 94 121 pub message: Option<String>, 95 122 } ··· 113 140 114 141 pub struct CommitFrameBuilder { 115 142 seq: i64, 116 - did: String, 143 + did: Did, 117 144 commit_cid: Cid, 118 145 prev_cid: Option<Cid>, 119 146 ops_json: serde_json::Value, ··· 126 153 #[allow(clippy::too_many_arguments)] 127 154 pub fn new( 128 155 seq: i64, 129 - did: String, 156 + did: Did, 130 157 commit_cid_str: &str, 131 158 prev_cid_str: Option<&str>, 132 159 ops_json: serde_json::Value, ··· 206 233 })?; 207 234 let builder = CommitFrameBuilder::new( 208 235 event.seq.as_i64(), 209 - event.did.to_string(), 236 + event.did.clone(), 210 237 commit_cid.as_str(), 211 238 event.prev_cid.as_ref().map(|c| c.as_str()), 212 239 event.ops.unwrap_or_default(),
+3 -6
crates/tranquil-pds/src/sync/import.rs
··· 210 210 if let Ipld::Map(entry_obj) = entry { 211 211 let prefix_len = entry_obj 212 212 .get("p") 213 - .and_then(|p| { 214 - if let Ipld::Integer(n) = p { 215 - Some(*n as usize) 216 - } else { 217 - None 218 - } 213 + .and_then(|p| match p { 214 + Ipld::Integer(n) => usize::try_from(*n).ok(), 215 + _ => None, 219 216 }) 220 217 .unwrap_or(0); 221 218
+2 -1
crates/tranquil-pds/src/sync/mod.rs
··· 20 20 pub use subscribe_repos::subscribe_repos; 21 21 pub use tranquil_db_traits::AccountStatus; 22 22 pub use util::{ 23 - RepoAccount, RepoAvailabilityError, assert_repo_availability, get_account_with_status, 23 + RepoAccessLevel, RepoAccount, RepoAvailabilityError, assert_repo_availability, 24 + get_account_with_status, 24 25 }; 25 26 pub use verify::{CarVerifier, VerifiedCar, VerifyError};
+57 -86
crates/tranquil-pds/src/sync/repo.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::scheduled::generate_repo_car_from_user_blocks; 3 3 use crate::state::AppState; 4 - use crate::sync::car::encode_car_header; 5 - use crate::sync::util::assert_repo_availability; 4 + use crate::sync::car::{encode_car_block, encode_car_header}; 5 + use crate::sync::util::{RepoAccessLevel, assert_repo_availability}; 6 6 use axum::{ 7 7 extract::{Query, RawQuery, State}, 8 8 http::StatusCode, ··· 11 11 use cid::Cid; 12 12 use jacquard_repo::storage::BlockStore; 13 13 use serde::Deserialize; 14 - use std::io::Write; 15 14 use std::str::FromStr; 16 15 use tracing::error; 17 16 use tranquil_types::Did; 18 17 19 - fn parse_get_blocks_query(query_string: &str) -> Result<(String, Vec<String>), String> { 20 - let did = crate::util::parse_repeated_query_param(Some(query_string), "did") 18 + struct GetBlocksParams { 19 + did: Did, 20 + cids: Vec<String>, 21 + } 22 + 23 + fn parse_get_blocks_query(query_string: &str) -> Result<GetBlocksParams, ApiError> { 24 + let did_str = crate::util::parse_repeated_query_param(Some(query_string), "did") 21 25 .into_iter() 22 26 .next() 23 - .ok_or("Missing required parameter: did")?; 27 + .ok_or_else(|| ApiError::InvalidRequest("Missing required parameter: did".into()))?; 28 + let did: Did = did_str 29 + .parse() 30 + .map_err(|_| ApiError::InvalidRequest("invalid did".into()))?; 24 31 let cids = crate::util::parse_repeated_query_param(Some(query_string), "cids"); 25 - Ok((did, cids)) 32 + Ok(GetBlocksParams { did, cids }) 26 33 } 27 34 28 35 pub async fn get_blocks(State(state): State<AppState>, RawQuery(query): RawQuery) -> Response { ··· 30 37 return ApiError::InvalidRequest("Missing query parameters".into()).into_response(); 31 38 }; 32 39 33 - let (did_str, cid_strings) = match parse_get_blocks_query(&query_string) { 40 + let GetBlocksParams { 41 + did, 42 + cids: cid_strings, 43 + } = match parse_get_blocks_query(&query_string) { 34 44 Ok(parsed) => parsed, 35 - Err(msg) => return ApiError::InvalidRequest(msg).into_response(), 36 - }; 37 - let did: Did = match did_str.parse() { 38 - Ok(d) => d, 39 - Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 40 - }; 41 - 42 - let _account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 43 - Ok(a) => a, 44 45 Err(e) => return e.into_response(), 45 46 }; 46 47 48 + let _account = 49 + match assert_repo_availability(state.repo_repo.as_ref(), &did, RepoAccessLevel::Public) 50 + .await 51 + { 52 + Ok(a) => a, 53 + Err(e) => return e.into_response(), 54 + }; 55 + 47 56 let cids: Vec<Cid> = match cid_strings 48 57 .iter() 49 58 .map(|s| Cid::from_str(s).map_err(|_| s.clone())) ··· 89 98 } 90 99 }; 91 100 let mut car_bytes = header; 92 - for (i, block_opt) in blocks.into_iter().enumerate() { 93 - if let Some(block) = block_opt { 94 - let cid = cids[i]; 95 - let cid_bytes = cid.to_bytes(); 96 - let total_len = cid_bytes.len() + block.len(); 97 - let mut writer = Vec::new(); 98 - crate::sync::car::write_varint(&mut writer, total_len as u64) 99 - .expect("Writing to Vec<u8> should never fail"); 100 - writer 101 - .write_all(&cid_bytes) 102 - .expect("Writing to Vec<u8> should never fail"); 103 - writer 104 - .write_all(&block) 105 - .expect("Writing to Vec<u8> should never fail"); 106 - car_bytes.extend_from_slice(&writer); 107 - } 108 - } 101 + blocks 102 + .into_iter() 103 + .enumerate() 104 + .filter_map(|(i, block_opt)| block_opt.map(|block| (cids[i], block))) 105 + .for_each(|(cid, block)| car_bytes.extend_from_slice(&encode_car_block(&cid, &block))); 109 106 ( 110 107 StatusCode::OK, 111 108 [(axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car")], ··· 116 113 117 114 #[derive(Deserialize)] 118 115 pub struct GetRepoQuery { 119 - pub did: String, 116 + pub did: Did, 120 117 pub since: Option<String>, 121 118 } 122 119 ··· 124 121 State(state): State<AppState>, 125 122 Query(query): Query<GetRepoQuery>, 126 123 ) -> Response { 127 - let did: Did = match query.did.parse() { 128 - Ok(d) => d, 129 - Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 130 - }; 131 - let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 132 - Ok(a) => a, 133 - Err(e) => return e.into_response(), 134 - }; 124 + let did = query.did; 125 + let account = 126 + match assert_repo_availability(state.repo_repo.as_ref(), &did, RepoAccessLevel::Public) 127 + .await 128 + { 129 + Ok(a) => a, 130 + Err(e) => return e.into_response(), 131 + }; 135 132 136 133 let Some(head_str) = account.repo_root_cid else { 137 134 return ApiError::RepoNotFound(Some("Repo not initialized".into())).into_response(); ··· 223 220 } 224 221 }; 225 222 226 - for (i, block_opt) in blocks.into_iter().enumerate() { 227 - if let Some(block) = block_opt { 228 - let cid = block_cids[i]; 229 - let cid_bytes = cid.to_bytes(); 230 - let total_len = cid_bytes.len() + block.len(); 231 - let mut writer = Vec::new(); 232 - crate::sync::car::write_varint(&mut writer, total_len as u64) 233 - .expect("Writing to Vec<u8> should never fail"); 234 - writer 235 - .write_all(&cid_bytes) 236 - .expect("Writing to Vec<u8> should never fail"); 237 - writer 238 - .write_all(&block) 239 - .expect("Writing to Vec<u8> should never fail"); 240 - car_bytes.extend_from_slice(&writer); 241 - } 242 - } 223 + blocks 224 + .into_iter() 225 + .enumerate() 226 + .filter_map(|(i, block_opt)| block_opt.map(|block| (block_cids[i], block))) 227 + .for_each(|(cid, block)| car_bytes.extend_from_slice(&encode_car_block(&cid, &block))); 243 228 244 229 ( 245 230 StatusCode::OK, ··· 251 236 252 237 #[derive(Deserialize)] 253 238 pub struct GetRecordQuery { 254 - pub did: String, 239 + pub did: Did, 255 240 pub collection: String, 256 241 pub rkey: String, 257 242 } ··· 265 250 use std::collections::BTreeMap; 266 251 use std::sync::Arc; 267 252 268 - let did: Did = match query.did.parse() { 269 - Ok(d) => d, 270 - Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 271 - }; 272 - let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 273 - Ok(a) => a, 274 - Err(e) => return e.into_response(), 275 - }; 253 + let did = query.did; 254 + let account = 255 + match assert_repo_availability(state.repo_repo.as_ref(), &did, RepoAccessLevel::Public) 256 + .await 257 + { 258 + Ok(a) => a, 259 + Err(e) => return e.into_response(), 260 + }; 276 261 277 262 let commit_cid_str = match account.repo_root_cid { 278 263 Some(cid) => cid, ··· 321 306 } 322 307 }; 323 308 let mut car_bytes = header; 324 - let write_block = |car: &mut Vec<u8>, cid: &Cid, data: &[u8]| { 325 - let cid_bytes = cid.to_bytes(); 326 - let total_len = cid_bytes.len() + data.len(); 327 - let mut writer = Vec::new(); 328 - crate::sync::car::write_varint(&mut writer, total_len as u64) 329 - .expect("Writing to Vec<u8> should never fail"); 330 - writer 331 - .write_all(&cid_bytes) 332 - .expect("Writing to Vec<u8> should never fail"); 333 - writer 334 - .write_all(data) 335 - .expect("Writing to Vec<u8> should never fail"); 336 - car.extend_from_slice(&writer); 337 - }; 338 - write_block(&mut car_bytes, &commit_cid, &commit_bytes); 309 + car_bytes.extend_from_slice(&encode_car_block(&commit_cid, &commit_bytes)); 339 310 proof_blocks 340 311 .iter() 341 - .for_each(|(cid, data)| write_block(&mut car_bytes, cid, data)); 342 - write_block(&mut car_bytes, &record_cid, &record_block); 312 + .for_each(|(cid, data)| car_bytes.extend_from_slice(&encode_car_block(cid, data))); 313 + car_bytes.extend_from_slice(&encode_car_block(&record_cid, &record_block)); 343 314 ( 344 315 StatusCode::OK, 345 316 [(axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car")],
+4 -3
crates/tranquil-pds/src/sync/subscribe_repos.rs
··· 1 1 use crate::state::AppState; 2 2 use crate::sync::firehose::SequencedEvent; 3 + use crate::sync::frame::{ErrorFrameName, InfoFrameName}; 3 4 use crate::sync::util::{ 4 5 format_error_frame, format_event_for_sending, format_event_with_prefetched_blocks, 5 6 format_info_frame, prefetch_blocks_for_events, ··· 82 83 83 84 if cursor_seq > current_seq { 84 85 if let Ok(error_bytes) = 85 - format_error_frame("FutureCursor", Some("Cursor in the future.")) 86 + format_error_frame(ErrorFrameName::FutureCursor, Some("Cursor in the future.")) 86 87 { 87 88 let _ = socket.send(Message::Binary(error_bytes.into())).await; 88 89 } ··· 105 106 && event.created_at < backfill_time 106 107 { 107 108 if let Ok(info_bytes) = format_info_frame( 108 - "OutdatedCursor", 109 + InfoFrameName::OutdatedCursor, 109 110 Some("Requested cursor exceeded limit. Possibly missing events"), 110 111 ) { 111 112 let _ = socket.send(Message::Binary(info_bytes.into())).await; ··· 161 162 } 162 163 crate::metrics::record_firehose_event(); 163 164 } 164 - if (events_count as i64) < BACKFILL_BATCH_SIZE { 165 + if i64::try_from(events_count).unwrap_or(i64::MAX) < BACKFILL_BATCH_SIZE { 165 166 break; 166 167 } 167 168 }
+133 -63
crates/tranquil-pds/src/sync/util.rs
··· 2 2 use crate::state::AppState; 3 3 use crate::sync::firehose::SequencedEvent; 4 4 use crate::sync::frame::{ 5 - AccountFrame, CommitFrame, ErrorFrameBody, ErrorFrameHeader, FrameHeader, IdentityFrame, 6 - InfoFrame, SyncFrame, 5 + AccountFrame, CommitFrame, ErrorFrameBody, ErrorFrameHeader, ErrorFrameName, FrameHeader, 6 + FrameType, IdentityFrame, InfoFrame, InfoFrameName, SyncFrame, 7 7 }; 8 8 use axum::response::{IntoResponse, Response}; 9 9 use bytes::Bytes; ··· 18 18 use tranquil_db_traits::{AccountStatus, RepoEventType, RepoRepository}; 19 19 use tranquil_types::Did; 20 20 21 + #[derive(Debug)] 22 + pub enum SyncFrameError { 23 + CarWrite(iroh_car::Error), 24 + CarFinalize(iroh_car::Error), 25 + IoFlush(std::io::Error), 26 + CborSerialize(String), 27 + MissingCommitCid, 28 + CommitBlockNotFound, 29 + RevExtraction, 30 + InvalidEvent(String), 31 + BlockStore(tranquil_db_traits::DbError), 32 + CidParse(cid::Error), 33 + } 34 + 35 + impl std::fmt::Display for SyncFrameError { 36 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 37 + match self { 38 + Self::CarWrite(e) => write!(f, "CAR block write failed: {}", e), 39 + Self::CarFinalize(e) => write!(f, "CAR finalize failed: {}", e), 40 + Self::IoFlush(e) => write!(f, "CAR buffer flush failed: {}", e), 41 + Self::CborSerialize(e) => write!(f, "CBOR serialization failed: {}", e), 42 + Self::MissingCommitCid => write!(f, "missing commit_cid"), 43 + Self::CommitBlockNotFound => write!(f, "commit block not found"), 44 + Self::RevExtraction => write!(f, "could not extract rev from commit"), 45 + Self::InvalidEvent(msg) => write!(f, "invalid event: {}", msg), 46 + Self::BlockStore(e) => write!(f, "block store error: {}", e), 47 + Self::CidParse(e) => write!(f, "CID parse failed: {}", e), 48 + } 49 + } 50 + } 51 + 52 + impl std::error::Error for SyncFrameError {} 53 + 54 + impl From<serde_ipld_dagcbor::EncodeError<std::collections::TryReserveError>> for SyncFrameError { 55 + fn from(e: serde_ipld_dagcbor::EncodeError<std::collections::TryReserveError>) -> Self { 56 + Self::CborSerialize(e.to_string()) 57 + } 58 + } 59 + 60 + impl From<serde_ipld_dagcbor::EncodeError<std::io::Error>> for SyncFrameError { 61 + fn from(e: serde_ipld_dagcbor::EncodeError<std::io::Error>) -> Self { 62 + Self::CborSerialize(e.to_string()) 63 + } 64 + } 65 + 66 + impl From<cid::Error> for SyncFrameError { 67 + fn from(e: cid::Error) -> Self { 68 + Self::CidParse(e) 69 + } 70 + } 71 + 72 + impl From<tranquil_db_traits::DbError> for SyncFrameError { 73 + fn from(e: tranquil_db_traits::DbError) -> Self { 74 + Self::BlockStore(e) 75 + } 76 + } 77 + 78 + impl From<jacquard_repo::error::RepoError> for SyncFrameError { 79 + fn from(e: jacquard_repo::error::RepoError) -> Self { 80 + Self::BlockStore(tranquil_db_traits::DbError::from_query_error(e.to_string())) 81 + } 82 + } 83 + 21 84 pub struct RepoAccount { 22 - pub did: String, 85 + pub did: Did, 23 86 pub user_id: uuid::Uuid, 24 87 pub status: AccountStatus, 25 88 pub repo_root_cid: Option<String>, 26 89 } 27 90 91 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 92 + pub enum RepoAccessLevel { 93 + Public, 94 + Privileged, 95 + } 96 + 28 97 pub enum RepoAvailabilityError { 29 - NotFound(String), 30 - Takendown(String), 31 - Deactivated(String), 98 + NotFound(Did), 99 + Takendown(Did), 100 + Deactivated(Did), 32 101 Internal(String), 33 102 } 34 103 ··· 64 133 }; 65 134 66 135 RepoAccount { 67 - did: r.did.to_string(), 136 + did: r.did, 68 137 user_id: r.user_id, 69 138 status, 70 139 repo_root_cid: r.repo_root_cid.map(|c| c.to_string()), ··· 75 144 pub async fn assert_repo_availability( 76 145 repo_repo: &dyn RepoRepository, 77 146 did: &Did, 78 - is_admin_or_self: bool, 147 + access_level: RepoAccessLevel, 79 148 ) -> Result<RepoAccount, RepoAvailabilityError> { 80 149 let account = get_account_with_status(repo_repo, did) 81 150 .await 82 151 .map_err(|e| RepoAvailabilityError::Internal(e.to_string()))?; 83 152 84 - let did_str = did.to_string(); 85 153 let account = match account { 86 154 Some(a) => a, 87 - None => return Err(RepoAvailabilityError::NotFound(did_str)), 155 + None => return Err(RepoAvailabilityError::NotFound(did.clone())), 88 156 }; 89 157 90 - if is_admin_or_self { 158 + if access_level == RepoAccessLevel::Privileged { 91 159 return Ok(account); 92 160 } 93 161 94 162 match account.status { 95 - AccountStatus::Takendown => return Err(RepoAvailabilityError::Takendown(did_str)), 163 + AccountStatus::Takendown => return Err(RepoAvailabilityError::Takendown(did.clone())), 96 164 AccountStatus::Deactivated => { 97 - return Err(RepoAvailabilityError::Deactivated(did_str)); 165 + return Err(RepoAvailabilityError::Deactivated(did.clone())); 98 166 } 99 167 _ => {} 100 168 } ··· 112 180 commit_cid: Cid, 113 181 commit_bytes: Option<Bytes>, 114 182 other_blocks: BTreeMap<Cid, Bytes>, 115 - ) -> Result<Vec<u8>, anyhow::Error> { 183 + ) -> Result<Vec<u8>, SyncFrameError> { 116 184 let mut buffer = Cursor::new(Vec::new()); 117 185 let header = CarHeader::new_v1(vec![commit_cid]); 118 186 let mut writer = CarWriter::new(header, &mut buffer); ··· 120 188 writer 121 189 .write(*cid, data.as_ref()) 122 190 .await 123 - .map_err(|e| anyhow::anyhow!("writing block {}: {}", cid, e))?; 191 + .map_err(SyncFrameError::CarWrite)?; 124 192 } 125 193 if let Some(data) = commit_bytes { 126 194 writer 127 195 .write(commit_cid, data.as_ref()) 128 196 .await 129 - .map_err(|e| anyhow::anyhow!("writing commit block: {}", e))?; 197 + .map_err(SyncFrameError::CarWrite)?; 130 198 } 131 - writer 132 - .finish() 133 - .await 134 - .map_err(|e| anyhow::anyhow!("finalizing CAR: {}", e))?; 135 - buffer 136 - .flush() 137 - .await 138 - .map_err(|e| anyhow::anyhow!("flushing CAR buffer: {}", e))?; 199 + writer.finish().await.map_err(SyncFrameError::CarFinalize)?; 200 + buffer.flush().await.map_err(SyncFrameError::IoFlush)?; 139 201 Ok(buffer.into_inner()) 140 202 } 141 203 ··· 143 205 dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string() 144 206 } 145 207 146 - fn format_identity_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 208 + fn format_identity_event(event: &SequencedEvent) -> Result<Vec<u8>, SyncFrameError> { 147 209 let frame = IdentityFrame { 148 - did: event.did.to_string(), 210 + did: event.did.clone(), 149 211 handle: event.handle.as_ref().map(|h| h.to_string()), 150 212 seq: event.seq.as_i64(), 151 213 time: format_atproto_time(event.created_at), 152 214 }; 153 215 let header = FrameHeader { 154 216 op: 1, 155 - t: "#identity".to_string(), 217 + t: FrameType::Identity, 156 218 }; 157 219 let mut bytes = Vec::with_capacity(256); 158 220 serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; ··· 160 222 Ok(bytes) 161 223 } 162 224 163 - fn format_account_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 225 + fn format_account_event(event: &SequencedEvent) -> Result<Vec<u8>, SyncFrameError> { 164 226 let frame = AccountFrame { 165 - did: event.did.to_string(), 227 + did: event.did.clone(), 166 228 active: event.active.unwrap_or(true), 167 - status: event 168 - .status 169 - .and_then(|s| s.for_firehose().map(String::from)), 229 + status: event.status.filter(|s| !s.is_active()), 170 230 seq: event.seq.as_i64(), 171 231 time: format_atproto_time(event.created_at), 172 232 }; 173 233 let header = FrameHeader { 174 234 op: 1, 175 - t: "#account".to_string(), 235 + t: FrameType::Account, 176 236 }; 177 237 let mut bytes = Vec::with_capacity(256); 178 238 serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; ··· 192 252 async fn format_sync_event( 193 253 state: &AppState, 194 254 event: &SequencedEvent, 195 - ) -> Result<Vec<u8>, anyhow::Error> { 255 + ) -> Result<Vec<u8>, SyncFrameError> { 196 256 let commit_cid_str = event 197 257 .commit_cid 198 258 .as_ref() 199 - .ok_or_else(|| anyhow::anyhow!("Sync event missing commit_cid"))?; 259 + .ok_or(SyncFrameError::MissingCommitCid)?; 200 260 let commit_cid = Cid::from_str(commit_cid_str)?; 201 261 let commit_bytes = state 202 262 .block_store 203 263 .get(&commit_cid) 204 264 .await? 205 - .ok_or_else(|| anyhow::anyhow!("Commit block not found"))?; 265 + .ok_or(SyncFrameError::CommitBlockNotFound)?; 206 266 let rev = if let Some(ref stored_rev) = event.rev { 207 267 stored_rev.clone() 208 268 } else { 209 - extract_rev_from_commit_bytes(&commit_bytes) 210 - .ok_or_else(|| anyhow::anyhow!("Could not extract rev from commit"))? 269 + extract_rev_from_commit_bytes(&commit_bytes).ok_or(SyncFrameError::RevExtraction)? 211 270 }; 212 271 let car_bytes = write_car_blocks(commit_cid, Some(commit_bytes), BTreeMap::new()).await?; 213 272 let frame = SyncFrame { 214 - did: event.did.to_string(), 273 + did: event.did.clone(), 215 274 rev, 216 275 blocks: car_bytes, 217 276 seq: event.seq.as_i64(), ··· 219 278 }; 220 279 let header = FrameHeader { 221 280 op: 1, 222 - t: "#sync".to_string(), 281 + t: FrameType::Sync, 223 282 }; 224 283 let mut bytes = Vec::with_capacity(512); 225 284 serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; ··· 230 289 pub async fn format_event_for_sending( 231 290 state: &AppState, 232 291 event: SequencedEvent, 233 - ) -> Result<Vec<u8>, anyhow::Error> { 292 + ) -> Result<Vec<u8>, SyncFrameError> { 234 293 match event.event_type { 235 294 RepoEventType::Identity => return format_identity_event(&event), 236 295 RepoEventType::Account => return format_account_event(&event), ··· 240 299 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 241 300 let prev_cid_link = event.prev_cid.clone(); 242 301 let prev_data_cid_link = event.prev_data_cid.clone(); 243 - let mut frame: CommitFrame = event 244 - .try_into() 245 - .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 302 + let mut frame: CommitFrame = 303 + event 304 + .try_into() 305 + .map_err(|e: crate::sync::frame::CommitFrameError| { 306 + SyncFrameError::InvalidEvent(e.to_string()) 307 + })?; 246 308 if let Some(ref pdc) = prev_data_cid_link 247 309 && let Ok(cid) = Cid::from_str(pdc.as_str()) 248 310 { ··· 287 349 frame.blocks = car_bytes; 288 350 let header = FrameHeader { 289 351 op: 1, 290 - t: "#commit".to_string(), 352 + t: FrameType::Commit, 291 353 }; 292 354 let mut bytes = Vec::with_capacity(frame.blocks.len() + 512); 293 355 serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; ··· 298 360 pub async fn prefetch_blocks_for_events( 299 361 state: &AppState, 300 362 events: &[SequencedEvent], 301 - ) -> Result<HashMap<Cid, Bytes>, anyhow::Error> { 363 + ) -> Result<HashMap<Cid, Bytes>, SyncFrameError> { 302 364 let mut all_cids: Vec<Cid> = events 303 365 .iter() 304 366 .flat_map(|event| { ··· 332 394 fn format_sync_event_with_prefetched( 333 395 event: &SequencedEvent, 334 396 prefetched: &HashMap<Cid, Bytes>, 335 - ) -> Result<Vec<u8>, anyhow::Error> { 397 + ) -> Result<Vec<u8>, SyncFrameError> { 336 398 let commit_cid_str = event 337 399 .commit_cid 338 400 .as_ref() 339 - .ok_or_else(|| anyhow::anyhow!("Sync event missing commit_cid"))?; 401 + .ok_or(SyncFrameError::MissingCommitCid)?; 340 402 let commit_cid = Cid::from_str(commit_cid_str)?; 341 403 let commit_bytes = prefetched 342 404 .get(&commit_cid) 343 - .ok_or_else(|| anyhow::anyhow!("Commit block not found in prefetched"))?; 405 + .ok_or(SyncFrameError::CommitBlockNotFound)?; 344 406 let rev = if let Some(ref stored_rev) = event.rev { 345 407 stored_rev.clone() 346 408 } else { 347 - extract_rev_from_commit_bytes(commit_bytes) 348 - .ok_or_else(|| anyhow::anyhow!("Could not extract rev from commit"))? 409 + extract_rev_from_commit_bytes(commit_bytes).ok_or(SyncFrameError::RevExtraction)? 349 410 }; 350 411 let car_bytes = futures::executor::block_on(write_car_blocks( 351 412 commit_cid, ··· 353 414 BTreeMap::new(), 354 415 ))?; 355 416 let frame = SyncFrame { 356 - did: event.did.to_string(), 417 + did: event.did.clone(), 357 418 rev, 358 419 blocks: car_bytes, 359 420 seq: event.seq.as_i64(), ··· 361 422 }; 362 423 let header = FrameHeader { 363 424 op: 1, 364 - t: "#sync".to_string(), 425 + t: FrameType::Sync, 365 426 }; 366 427 let mut bytes = Vec::new(); 367 428 serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; ··· 372 433 pub async fn format_event_with_prefetched_blocks( 373 434 event: SequencedEvent, 374 435 prefetched: &HashMap<Cid, Bytes>, 375 - ) -> Result<Vec<u8>, anyhow::Error> { 436 + ) -> Result<Vec<u8>, SyncFrameError> { 376 437 match event.event_type { 377 438 RepoEventType::Identity => return format_identity_event(&event), 378 439 RepoEventType::Account => return format_account_event(&event), ··· 382 443 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 383 444 let prev_cid_link = event.prev_cid.clone(); 384 445 let prev_data_cid_link = event.prev_data_cid.clone(); 385 - let mut frame: CommitFrame = event 386 - .try_into() 387 - .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 446 + let mut frame: CommitFrame = 447 + event 448 + .try_into() 449 + .map_err(|e: crate::sync::frame::CommitFrameError| { 450 + SyncFrameError::InvalidEvent(e.to_string()) 451 + })?; 388 452 if let Some(ref pdc) = prev_data_cid_link 389 453 && let Ok(cid) = Cid::from_str(pdc.as_str()) 390 454 { ··· 427 491 frame.blocks = car_bytes; 428 492 let header = FrameHeader { 429 493 op: 1, 430 - t: "#commit".to_string(), 494 + t: FrameType::Commit, 431 495 }; 432 496 let mut bytes = Vec::with_capacity(frame.blocks.len() + 512); 433 497 serde_ipld_dagcbor::to_writer(&mut bytes, &header)?; ··· 435 499 Ok(bytes) 436 500 } 437 501 438 - pub fn format_info_frame(name: &str, message: Option<&str>) -> Result<Vec<u8>, anyhow::Error> { 502 + pub fn format_info_frame( 503 + name: InfoFrameName, 504 + message: Option<&str>, 505 + ) -> Result<Vec<u8>, SyncFrameError> { 439 506 let header = FrameHeader { 440 507 op: 1, 441 - t: "#info".to_string(), 508 + t: FrameType::Info, 442 509 }; 443 510 let frame = InfoFrame { 444 - name: name.to_string(), 511 + name, 445 512 message: message.map(String::from), 446 513 }; 447 514 let mut bytes = Vec::with_capacity(128); ··· 450 517 Ok(bytes) 451 518 } 452 519 453 - pub fn format_error_frame(error: &str, message: Option<&str>) -> Result<Vec<u8>, anyhow::Error> { 520 + pub fn format_error_frame( 521 + error: ErrorFrameName, 522 + message: Option<&str>, 523 + ) -> Result<Vec<u8>, SyncFrameError> { 454 524 let header = ErrorFrameHeader { op: -1 }; 455 525 let frame = ErrorFrameBody { 456 - error: error.to_string(), 526 + error, 457 527 message: message.map(String::from), 458 528 }; 459 529 let mut bytes = Vec::with_capacity(128);
+21 -19
crates/tranquil-pds/src/sync/verify.rs
··· 8 8 use std::collections::HashMap; 9 9 use thiserror::Error; 10 10 use tracing::{debug, warn}; 11 + use tranquil_types::Did; 11 12 12 13 #[derive(Error, Debug)] 13 14 pub enum VerifyError { ··· 57 58 58 59 pub async fn verify_car( 59 60 &self, 60 - expected_did: &str, 61 + expected_did: &Did, 61 62 root_cid: &Cid, 62 63 blocks: &HashMap<Cid, Bytes>, 63 64 ) -> Result<VerifiedCar, VerifyError> { ··· 67 68 let commit = 68 69 Commit::from_cbor(root_block).map_err(|e| VerifyError::InvalidCommit(e.to_string()))?; 69 70 let commit_did = commit.did().as_str(); 70 - if commit_did != expected_did { 71 + if commit_did != expected_did.as_str() { 71 72 return Err(VerifyError::DidMismatch { 72 73 commit_did: commit_did.to_string(), 73 74 expected_did: expected_did.to_string(), 74 75 }); 75 76 } 76 - let pubkey = self.resolve_did_signing_key(commit_did).await?; 77 + let pubkey = self.resolve_did_signing_key(expected_did).await?; 77 78 commit 78 79 .verify(&pubkey) 79 80 .map_err(|_| VerifyError::InvalidSignature)?; 80 - debug!("Commit signature verified for DID {}", commit_did); 81 + debug!("Commit signature verified for DID {}", expected_did); 81 82 let data_cid = commit.data(); 82 83 self.verify_mst_structure(data_cid, blocks)?; 83 - debug!("MST structure verified for DID {}", commit_did); 84 + debug!("MST structure verified for DID {}", expected_did); 84 85 Ok(VerifiedCar { 85 - did: commit_did.to_string(), 86 + did: expected_did.clone(), 86 87 rev: commit.rev().to_string(), 87 88 data_cid: *data_cid, 88 89 prev: commit.prev().cloned(), ··· 91 92 92 93 pub fn verify_car_structure_only( 93 94 &self, 94 - expected_did: &str, 95 + expected_did: &Did, 95 96 root_cid: &Cid, 96 97 blocks: &HashMap<Cid, Bytes>, 97 98 ) -> Result<VerifiedCar, VerifyError> { ··· 101 102 let commit = 102 103 Commit::from_cbor(root_block).map_err(|e| VerifyError::InvalidCommit(e.to_string()))?; 103 104 let commit_did = commit.did().as_str(); 104 - if commit_did != expected_did { 105 + if commit_did != expected_did.as_str() { 105 106 return Err(VerifyError::DidMismatch { 106 107 commit_did: commit_did.to_string(), 107 108 expected_did: expected_did.to_string(), ··· 111 112 self.verify_mst_structure(data_cid, blocks)?; 112 113 debug!( 113 114 "MST structure verified for DID {} (signature verification skipped for migration)", 114 - commit_did 115 + expected_did 115 116 ); 116 117 Ok(VerifiedCar { 117 - did: commit_did.to_string(), 118 + did: expected_did.clone(), 118 119 rev: commit.rev().to_string(), 119 120 data_cid: *data_cid, 120 121 prev: commit.prev().cloned(), 121 122 }) 122 123 } 123 124 124 - async fn resolve_did_signing_key(&self, did: &str) -> Result<PublicKey<'static>, VerifyError> { 125 + async fn resolve_did_signing_key(&self, did: &Did) -> Result<PublicKey<'static>, VerifyError> { 125 126 let did_doc = self.resolve_did_document(did).await?; 126 127 did_doc 127 128 .atproto_public_key() ··· 129 130 .ok_or(VerifyError::NoSigningKey) 130 131 } 131 132 132 - async fn resolve_did_document(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> { 133 - if did.starts_with("did:plc:") { 134 - self.resolve_plc_did(did).await 135 - } else if did.starts_with("did:web:") { 136 - self.resolve_web_did(did).await 133 + async fn resolve_did_document(&self, did: &Did) -> Result<DidDocument<'static>, VerifyError> { 134 + let did_str = did.as_str(); 135 + if did_str.starts_with("did:plc:") { 136 + self.resolve_plc_did(did_str).await 137 + } else if did_str.starts_with("did:web:") { 138 + self.resolve_web_did(did_str).await 137 139 } else { 138 140 Err(VerifyError::DidResolutionFailed(format!( 139 141 "Unsupported DID method: {}", 140 - did 142 + did_str 141 143 ))) 142 144 } 143 145 } ··· 239 241 let prefix_len = entry_obj 240 242 .get("p") 241 243 .and_then(|p| match p { 242 - Ipld::Integer(i) => Some(*i as usize), 244 + Ipld::Integer(i) => usize::try_from(*i).ok(), 243 245 _ => None, 244 246 }) 245 247 .unwrap_or(0); ··· 294 296 295 297 #[derive(Debug, Clone)] 296 298 pub struct VerifiedCar { 297 - pub did: String, 299 + pub did: Did, 298 300 pub rev: String, 299 301 pub data_cid: Cid, 300 302 pub prev: Option<Cid>,
+2 -1
crates/tranquil-pds/src/sync/verify_tests.rs
··· 225 225 #[tokio::test] 226 226 async fn test_unsupported_did_method() { 227 227 let verifier = CarVerifier::new(); 228 - let result = verifier.resolve_did_document("did:unknown:test").await; 228 + let did: tranquil_types::Did = "did:unknown:test".parse().expect("valid DID format"); 229 + let result = verifier.resolve_did_document(&did).await; 229 230 assert!(result.is_err()); 230 231 let err = result.unwrap_err(); 231 232 assert!(matches!(err, VerifyError::DidResolutionFailed(_)));
+6
crates/tranquil-pds/src/types.rs
··· 1 1 pub use tranquil_types::*; 2 + 3 + use std::sync::LazyLock; 4 + 5 + pub static PROFILE_COLLECTION: LazyLock<Nsid> = 6 + LazyLock::new(|| "app.bsky.actor.profile".parse().unwrap()); 7 + pub static PROFILE_RKEY: LazyLock<Rkey> = LazyLock::new(|| "self".parse().unwrap());
+46 -3
crates/tranquil-pds/src/util.rs
··· 1 - use axum::http::HeaderMap; 1 + use axum::http::{HeaderMap, HeaderName}; 2 2 use cid::Cid; 3 3 use ipld_core::ipld::Ipld; 4 4 use rand::Rng; ··· 76 76 .unwrap_or_default() 77 77 } 78 78 79 - pub fn get_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { 79 + pub const HEADER_DPOP: HeaderName = HeaderName::from_static("dpop"); 80 + pub const HEADER_DPOP_NONCE: HeaderName = HeaderName::from_static("dpop-nonce"); 81 + pub const HEADER_ATPROTO_PROXY: HeaderName = HeaderName::from_static("atproto-proxy"); 82 + pub const HEADER_ATPROTO_ACCEPT_LABELERS: HeaderName = 83 + HeaderName::from_static("atproto-accept-labelers"); 84 + pub const HEADER_ATPROTO_REPO_REV: HeaderName = HeaderName::from_static("atproto-repo-rev"); 85 + pub const HEADER_ATPROTO_CONTENT_LABELERS: HeaderName = 86 + HeaderName::from_static("atproto-content-labelers"); 87 + pub const HEADER_X_BSKY_TOPICS: HeaderName = HeaderName::from_static("x-bsky-topics"); 88 + 89 + pub fn get_header_str( 90 + headers: &HeaderMap, 91 + name: impl axum::http::header::AsHeaderName, 92 + ) -> Option<&str> { 80 93 headers.get(name).and_then(|h| h.to_str().ok()) 81 94 } 82 95 ··· 140 153 TELEGRAM_BOT_USERNAME.get().map(|s| s.as_str()) 141 154 } 142 155 156 + pub fn parse_env_bool(key: &str) -> bool { 157 + std::env::var(key) 158 + .map(|v| v == "true" || v == "1") 159 + .unwrap_or(false) 160 + } 161 + 143 162 pub fn pds_public_url() -> String { 144 163 format!("https://{}", pds_hostname()) 145 164 } ··· 163 182 JsonValue::Bool(b) => Ipld::Bool(*b), 164 183 JsonValue::Number(n) => { 165 184 if let Some(i) = n.as_i64() { 166 - Ipld::Integer(i as i128) 185 + Ipld::Integer(i128::from(i)) 167 186 } else if let Some(f) = n.as_f64() { 168 187 Ipld::Float(f) 169 188 } else { ··· 352 371 return; 353 372 } 354 373 panic!("Failed to find CID link in parsed CBOR"); 374 + } 375 + 376 + #[test] 377 + fn test_parse_env_bool_true_values() { 378 + unsafe { std::env::set_var("TEST_PARSE_ENV_BOOL_1", "true") }; 379 + assert!(parse_env_bool("TEST_PARSE_ENV_BOOL_1")); 380 + unsafe { std::env::set_var("TEST_PARSE_ENV_BOOL_1", "1") }; 381 + assert!(parse_env_bool("TEST_PARSE_ENV_BOOL_1")); 382 + } 383 + 384 + #[test] 385 + fn test_parse_env_bool_false_values() { 386 + unsafe { std::env::set_var("TEST_PARSE_ENV_BOOL_2", "false") }; 387 + assert!(!parse_env_bool("TEST_PARSE_ENV_BOOL_2")); 388 + unsafe { std::env::set_var("TEST_PARSE_ENV_BOOL_2", "0") }; 389 + assert!(!parse_env_bool("TEST_PARSE_ENV_BOOL_2")); 390 + unsafe { std::env::set_var("TEST_PARSE_ENV_BOOL_2", "yes") }; 391 + assert!(!parse_env_bool("TEST_PARSE_ENV_BOOL_2")); 392 + } 393 + 394 + #[test] 395 + fn test_parse_env_bool_unset() { 396 + unsafe { std::env::remove_var("TEST_PARSE_ENV_BOOL_UNSET_KEY") }; 397 + assert!(!parse_env_bool("TEST_PARSE_ENV_BOOL_UNSET_KEY")); 355 398 } 356 399 357 400 #[test]
+2 -1
crates/tranquil-pds/src/validation/mod.rs
··· 21 21 BannedContent { path: String }, 22 22 } 23 23 24 - #[derive(Debug, Clone, Copy, PartialEq, Eq)] 24 + #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)] 25 + #[serde(rename_all = "lowercase")] 25 26 pub enum ValidationStatus { 26 27 Valid, 27 28 Unknown,
+2 -10
crates/tranquil-pds/tests/account_notifications.rs
··· 112 112 .send() 113 113 .await 114 114 .unwrap(); 115 - assert!( 116 - resp.status() == 400 || resp.status() == 422, 117 - "Expected 400 or 422, got {}", 118 - resp.status() 119 - ); 115 + assert_eq!(resp.status(), 400); 120 116 } 121 117 122 118 #[tokio::test] ··· 137 133 .send() 138 134 .await 139 135 .unwrap(); 140 - assert!( 141 - resp.status() == 400 || resp.status() == 422, 142 - "Expected 400 or 422, got {}", 143 - resp.status() 144 - ); 136 + assert_eq!(resp.status(), 400); 145 137 } 146 138 147 139 #[tokio::test]
+3 -1
crates/tranquil-pds/tests/commit_signing.rs
··· 99 99 use tranquil_pds::api::repo::record::utils::create_signed_commit; 100 100 101 101 let signing_key = SigningKey::random(&mut rand::thread_rng()); 102 - let did = unsafe { Did::new_unchecked("did:plc:testuser123456789abcdef") }; 102 + let did: Did = "did:plc:testuser123456789abcdef" 103 + .parse() 104 + .expect("valid test DID"); 103 105 let data_cid = 104 106 Cid::from_str("bafyreib2rxk3ryblouj3fxza5jvx6psmwewwessc4m6g6e7pqhhkwqomfi").unwrap(); 105 107 let rev = Tid::now(LimitedU32::MIN).to_string();
+4 -1
crates/tranquil-pds/tests/common/mod.rs
··· 98 98 99 99 #[allow(dead_code)] 100 100 pub fn client() -> Client { 101 - Client::new() 101 + Client::builder() 102 + .timeout(Duration::from_secs(120)) 103 + .build() 104 + .expect("Failed to build HTTP client") 102 105 } 103 106 104 107 #[allow(dead_code)]
+3 -3
crates/tranquil-pds/tests/delete_account.rs
··· 365 365 .send() 366 366 .await 367 367 .expect("Failed to send request"); 368 - assert_eq!(res1.status(), StatusCode::UNPROCESSABLE_ENTITY); 368 + assert_eq!(res1.status(), StatusCode::BAD_REQUEST); 369 369 let res2 = client 370 370 .post(format!( 371 371 "{}/xrpc/com.atproto.server.deleteAccount", ··· 378 378 .send() 379 379 .await 380 380 .expect("Failed to send request"); 381 - assert_eq!(res2.status(), StatusCode::UNPROCESSABLE_ENTITY); 381 + assert_eq!(res2.status(), StatusCode::BAD_REQUEST); 382 382 let res3 = client 383 383 .post(format!( 384 384 "{}/xrpc/com.atproto.server.deleteAccount", ··· 391 391 .send() 392 392 .await 393 393 .expect("Failed to send request"); 394 - assert_eq!(res3.status(), StatusCode::UNPROCESSABLE_ENTITY); 394 + assert_eq!(res3.status(), StatusCode::BAD_REQUEST); 395 395 } 396 396 397 397 #[tokio::test]
+439
crates/tranquil-pds/tests/firehose/mod.rs
··· 1 + use cid::Cid; 2 + use futures::stream::StreamExt; 3 + use serde::Deserialize; 4 + use std::sync::{Arc, Mutex}; 5 + use tokio::task::JoinHandle; 6 + use tokio_tungstenite::{connect_async, tungstenite}; 7 + use tokio_util::sync::CancellationToken; 8 + use tranquil_scopes::RepoAction; 9 + 10 + #[derive(Debug)] 11 + pub enum FirehoseFrame { 12 + Commit(Box<ParsedCommitFrame>), 13 + Identity(IdentityData), 14 + Account(AccountData), 15 + Info(InfoData), 16 + Error(ErrorData), 17 + Unknown(Vec<u8>), 18 + } 19 + 20 + #[allow(dead_code)] 21 + #[derive(Debug, Clone)] 22 + pub struct ParsedCommitFrame { 23 + pub seq: i64, 24 + pub repo: String, 25 + pub commit: Cid, 26 + pub rev: String, 27 + pub since: Option<String>, 28 + pub blocks: Vec<u8>, 29 + pub ops: Vec<ParsedRepoOp>, 30 + pub blobs: Vec<Cid>, 31 + pub time: String, 32 + pub prev_data: Option<Cid>, 33 + } 34 + 35 + #[derive(Debug, Clone)] 36 + pub struct ParsedRepoOp { 37 + pub action: RepoAction, 38 + pub path: String, 39 + pub cid: Option<Cid>, 40 + pub prev: Option<Cid>, 41 + } 42 + 43 + #[allow(dead_code)] 44 + #[derive(Debug, Clone)] 45 + pub struct IdentityData { 46 + pub did: String, 47 + pub seq: i64, 48 + } 49 + 50 + #[allow(dead_code)] 51 + #[derive(Debug, Clone)] 52 + pub struct AccountData { 53 + pub did: String, 54 + pub seq: i64, 55 + pub active: bool, 56 + } 57 + 58 + #[allow(dead_code)] 59 + #[derive(Debug, Clone)] 60 + pub struct InfoData { 61 + pub name: String, 62 + pub message: Option<String>, 63 + } 64 + 65 + #[allow(dead_code)] 66 + #[derive(Debug, Clone)] 67 + pub struct ErrorData { 68 + pub error: String, 69 + pub message: Option<String>, 70 + } 71 + 72 + #[derive(Debug, Deserialize)] 73 + struct RawFrameHeader { 74 + op: i64, 75 + #[serde(default)] 76 + t: Option<String>, 77 + } 78 + 79 + #[derive(Debug, Deserialize)] 80 + struct RawCommitBody { 81 + seq: i64, 82 + repo: String, 83 + commit: Cid, 84 + rev: String, 85 + since: Option<String>, 86 + #[serde(with = "serde_bytes")] 87 + blocks: Vec<u8>, 88 + ops: Vec<RawOp>, 89 + #[serde(default)] 90 + blobs: Vec<Cid>, 91 + time: String, 92 + #[serde(rename = "prevData")] 93 + prev_data: Option<Cid>, 94 + } 95 + 96 + #[derive(Debug, Deserialize)] 97 + struct RawOp { 98 + action: RepoAction, 99 + path: String, 100 + cid: Option<Cid>, 101 + prev: Option<Cid>, 102 + } 103 + 104 + #[derive(Debug, Deserialize)] 105 + struct RawIdentityBody { 106 + did: String, 107 + seq: i64, 108 + } 109 + 110 + #[derive(Debug, Deserialize)] 111 + struct RawAccountBody { 112 + did: String, 113 + seq: i64, 114 + active: bool, 115 + } 116 + 117 + #[derive(Debug, Deserialize)] 118 + struct RawInfoBody { 119 + name: String, 120 + message: Option<String>, 121 + } 122 + 123 + #[derive(Debug, Deserialize)] 124 + struct RawErrorBody { 125 + error: String, 126 + message: Option<String>, 127 + } 128 + 129 + pub struct FirehoseConsumer { 130 + frames: Arc<Mutex<Vec<FirehoseFrame>>>, 131 + cancel: CancellationToken, 132 + handle: JoinHandle<()>, 133 + } 134 + 135 + impl FirehoseConsumer { 136 + pub async fn connect(port: u16) -> Self { 137 + Self::connect_inner(port, None).await 138 + } 139 + 140 + pub async fn connect_with_cursor(port: u16, cursor: i64) -> Self { 141 + Self::connect_inner(port, Some(cursor)).await 142 + } 143 + 144 + async fn connect_inner(port: u16, cursor: Option<i64>) -> Self { 145 + let url = match cursor { 146 + Some(c) => format!( 147 + "ws://127.0.0.1:{}/xrpc/com.atproto.sync.subscribeRepos?cursor={}", 148 + port, c 149 + ), 150 + None => format!( 151 + "ws://127.0.0.1:{}/xrpc/com.atproto.sync.subscribeRepos", 152 + port 153 + ), 154 + }; 155 + let (ws_stream, _) = connect_async(&url) 156 + .await 157 + .expect("Failed to connect to firehose"); 158 + let frames: Arc<Mutex<Vec<FirehoseFrame>>> = Arc::new(Mutex::new(Vec::new())); 159 + let cancel = CancellationToken::new(); 160 + 161 + let frames_clone = frames.clone(); 162 + let cancel_clone = cancel.clone(); 163 + 164 + let handle = tokio::spawn(async move { 165 + let (_, mut read) = ws_stream.split(); 166 + loop { 167 + tokio::select! { 168 + _ = cancel_clone.cancelled() => break, 169 + msg = read.next() => { 170 + match msg { 171 + Some(Ok(tungstenite::Message::Binary(bin))) => { 172 + let frame = parse_frame_bytes(&bin); 173 + frames_clone.lock().unwrap().push(frame); 174 + } 175 + Some(Ok(tungstenite::Message::Close(_))) | None => break, 176 + _ => {} 177 + } 178 + } 179 + } 180 + } 181 + }); 182 + 183 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 184 + 185 + Self { 186 + frames, 187 + cancel, 188 + handle, 189 + } 190 + } 191 + 192 + pub async fn wait_for_commits( 193 + &self, 194 + did: &str, 195 + count: usize, 196 + timeout: std::time::Duration, 197 + ) -> Vec<ParsedCommitFrame> { 198 + let deadline = tokio::time::Instant::now() + timeout; 199 + loop { 200 + let matching: Vec<ParsedCommitFrame> = self 201 + .frames 202 + .lock() 203 + .unwrap() 204 + .iter() 205 + .filter_map(|f| match f { 206 + FirehoseFrame::Commit(c) if c.repo == did => Some(ParsedCommitFrame::clone(c)), 207 + _ => None, 208 + }) 209 + .collect(); 210 + if matching.len() >= count { 211 + return matching; 212 + } 213 + if tokio::time::Instant::now() >= deadline { 214 + panic!( 215 + "Timed out waiting for {} commits for DID {}, got {}", 216 + count, 217 + did, 218 + matching.len() 219 + ); 220 + } 221 + tokio::time::sleep(std::time::Duration::from_millis(50)).await; 222 + } 223 + } 224 + 225 + #[allow(dead_code)] 226 + pub fn all_frames(&self) -> Vec<FirehoseFrame> { 227 + self.frames.lock().unwrap().drain(..).collect() 228 + } 229 + 230 + pub fn all_commits(&self) -> Vec<ParsedCommitFrame> { 231 + self.frames 232 + .lock() 233 + .unwrap() 234 + .iter() 235 + .filter_map(|f| match f { 236 + FirehoseFrame::Commit(c) => Some(ParsedCommitFrame::clone(c)), 237 + _ => None, 238 + }) 239 + .collect() 240 + } 241 + } 242 + 243 + impl Drop for FirehoseConsumer { 244 + fn drop(&mut self) { 245 + self.cancel.cancel(); 246 + self.handle.abort(); 247 + } 248 + } 249 + 250 + impl Clone for FirehoseFrame { 251 + fn clone(&self) -> Self { 252 + match self { 253 + Self::Commit(c) => Self::Commit(c.clone()), 254 + Self::Identity(i) => Self::Identity(i.clone()), 255 + Self::Account(a) => Self::Account(a.clone()), 256 + Self::Info(i) => Self::Info(i.clone()), 257 + Self::Error(e) => Self::Error(e.clone()), 258 + Self::Unknown(b) => Self::Unknown(b.clone()), 259 + } 260 + } 261 + } 262 + 263 + fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> { 264 + let mut pos = 0; 265 + 266 + fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> { 267 + match additional { 268 + 0..=23 => Ok(additional as u64), 269 + 24 => { 270 + if *pos >= bytes.len() { 271 + return Err("Unexpected end".into()); 272 + } 273 + let val = bytes[*pos] as u64; 274 + *pos += 1; 275 + Ok(val) 276 + } 277 + 25 => { 278 + if *pos + 2 > bytes.len() { 279 + return Err("Unexpected end".into()); 280 + } 281 + let val = u16::from_be_bytes([bytes[*pos], bytes[*pos + 1]]) as u64; 282 + *pos += 2; 283 + Ok(val) 284 + } 285 + 26 => { 286 + if *pos + 4 > bytes.len() { 287 + return Err("Unexpected end".into()); 288 + } 289 + let val = u32::from_be_bytes([ 290 + bytes[*pos], 291 + bytes[*pos + 1], 292 + bytes[*pos + 2], 293 + bytes[*pos + 3], 294 + ]) as u64; 295 + *pos += 4; 296 + Ok(val) 297 + } 298 + 27 => { 299 + if *pos + 8 > bytes.len() { 300 + return Err("Unexpected end".into()); 301 + } 302 + let val = u64::from_be_bytes([ 303 + bytes[*pos], 304 + bytes[*pos + 1], 305 + bytes[*pos + 2], 306 + bytes[*pos + 3], 307 + bytes[*pos + 4], 308 + bytes[*pos + 5], 309 + bytes[*pos + 6], 310 + bytes[*pos + 7], 311 + ]); 312 + *pos += 8; 313 + Ok(val) 314 + } 315 + _ => Err(format!("Invalid additional info: {}", additional)), 316 + } 317 + } 318 + 319 + fn skip_value(bytes: &[u8], pos: &mut usize) -> Result<(), String> { 320 + if *pos >= bytes.len() { 321 + return Err("Unexpected end".into()); 322 + } 323 + let initial = bytes[*pos]; 324 + *pos += 1; 325 + let major = initial >> 5; 326 + let additional = initial & 0x1f; 327 + 328 + match major { 329 + 0 | 1 => { 330 + read_uint(bytes, pos, additional)?; 331 + Ok(()) 332 + } 333 + 2 | 3 => { 334 + let len = read_uint(bytes, pos, additional)? as usize; 335 + *pos += len; 336 + Ok(()) 337 + } 338 + 4 => { 339 + let len = read_uint(bytes, pos, additional)?; 340 + (0..len).try_for_each(|_| skip_value(bytes, pos)) 341 + } 342 + 5 => { 343 + let len = read_uint(bytes, pos, additional)?; 344 + (0..len).try_for_each(|_| { 345 + skip_value(bytes, pos)?; 346 + skip_value(bytes, pos) 347 + }) 348 + } 349 + 6 => { 350 + read_uint(bytes, pos, additional)?; 351 + skip_value(bytes, pos) 352 + } 353 + 7 => Ok(()), 354 + _ => Err(format!("Unknown major type: {}", major)), 355 + } 356 + } 357 + 358 + skip_value(bytes, &mut pos)?; 359 + Ok(pos) 360 + } 361 + 362 + pub fn parse_frame_bytes(raw: &[u8]) -> FirehoseFrame { 363 + let header_end = match find_cbor_map_end(raw) { 364 + Ok(e) => e, 365 + Err(_) => return FirehoseFrame::Unknown(raw.to_vec()), 366 + }; 367 + 368 + let header: RawFrameHeader = match serde_ipld_dagcbor::from_slice(&raw[..header_end]) { 369 + Ok(h) => h, 370 + Err(_) => return FirehoseFrame::Unknown(raw.to_vec()), 371 + }; 372 + 373 + let body = &raw[header_end..]; 374 + 375 + if header.op == -1 { 376 + return serde_ipld_dagcbor::from_slice::<RawErrorBody>(body) 377 + .map(|b| { 378 + FirehoseFrame::Error(ErrorData { 379 + error: b.error, 380 + message: b.message, 381 + }) 382 + }) 383 + .unwrap_or_else(|_| FirehoseFrame::Unknown(raw.to_vec())); 384 + } 385 + 386 + match header.t.as_deref() { 387 + Some("#commit") => serde_ipld_dagcbor::from_slice::<RawCommitBody>(body) 388 + .map(|b| { 389 + FirehoseFrame::Commit(Box::new(ParsedCommitFrame { 390 + seq: b.seq, 391 + repo: b.repo, 392 + commit: b.commit, 393 + rev: b.rev, 394 + since: b.since, 395 + blocks: b.blocks, 396 + ops: b 397 + .ops 398 + .into_iter() 399 + .map(|op| ParsedRepoOp { 400 + action: op.action, 401 + path: op.path, 402 + cid: op.cid, 403 + prev: op.prev, 404 + }) 405 + .collect(), 406 + blobs: b.blobs, 407 + time: b.time, 408 + prev_data: b.prev_data, 409 + })) 410 + }) 411 + .unwrap_or_else(|_| FirehoseFrame::Unknown(raw.to_vec())), 412 + Some("#identity") => serde_ipld_dagcbor::from_slice::<RawIdentityBody>(body) 413 + .map(|b| { 414 + FirehoseFrame::Identity(IdentityData { 415 + did: b.did, 416 + seq: b.seq, 417 + }) 418 + }) 419 + .unwrap_or_else(|_| FirehoseFrame::Unknown(raw.to_vec())), 420 + Some("#account") => serde_ipld_dagcbor::from_slice::<RawAccountBody>(body) 421 + .map(|b| { 422 + FirehoseFrame::Account(AccountData { 423 + did: b.did, 424 + seq: b.seq, 425 + active: b.active, 426 + }) 427 + }) 428 + .unwrap_or_else(|_| FirehoseFrame::Unknown(raw.to_vec())), 429 + Some("#info") => serde_ipld_dagcbor::from_slice::<RawInfoBody>(body) 430 + .map(|b| { 431 + FirehoseFrame::Info(InfoData { 432 + name: b.name, 433 + message: b.message, 434 + }) 435 + }) 436 + .unwrap_or_else(|_| FirehoseFrame::Unknown(raw.to_vec())), 437 + _ => FirehoseFrame::Unknown(raw.to_vec()), 438 + } 439 + }
+2 -2
crates/tranquil-pds/tests/identity.rs
··· 547 547 .send() 548 548 .await 549 549 .expect("Failed to send request"); 550 - assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); 550 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 551 551 } 552 552 553 553 #[tokio::test] ··· 582 582 .send() 583 583 .await 584 584 .expect("Failed to send request"); 585 - assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); 585 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 586 586 } 587 587 588 588 #[tokio::test]
+28 -25
crates/tranquil-pds/tests/jwt_security.rs
··· 10 10 use serde_json::{Value, json}; 11 11 use sha2::{Digest, Sha256}; 12 12 use tranquil_pds::auth::{ 13 - self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, 14 - TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token, 15 - create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token, 16 - verify_access_token, verify_refresh_token, verify_token, 13 + self, TokenScope, TokenType, create_access_token, create_refresh_token, create_service_token, 14 + get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, 15 + verify_token, 17 16 }; 18 17 19 18 fn generate_user_key() -> Vec<u8> { ··· 100 99 let key_bytes = generate_user_key(); 101 100 let did = "did:plc:test"; 102 101 103 - let none_header = json!({ "alg": "none", "typ": TOKEN_TYPE_ACCESS }); 102 + let none_header = json!({ "alg": "none", "typ": TokenType::Access.as_str() }); 104 103 let claims = json!({ 105 104 "iss": did, "sub": did, "aud": "did:web:test.pds", 106 105 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 107 - "jti": "attack-token", "scope": SCOPE_ACCESS 106 + "jti": "attack-token", "scope": TokenScope::Access.as_str() 108 107 }); 109 108 let none_token = create_unsigned_jwt(&none_header, &claims); 110 109 assert!( ··· 112 111 "Algorithm 'none' must be rejected" 113 112 ); 114 113 115 - let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS }); 114 + let hs256_header = json!({ "alg": "HS256", "typ": TokenType::Access.as_str() }); 116 115 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap()); 117 116 let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); 118 117 use hmac::{Hmac, Mac}; ··· 128 127 ); 129 128 130 129 for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] { 131 - let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS }); 130 + let header = json!({ "alg": alg, "typ": TokenType::Access.as_str() }); 132 131 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 133 132 let fake_sig = URL_SAFE_NO_PAD.encode(vec![1u8; sig_len]); 134 133 let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig); ··· 179 178 fn test_scope_validation() { 180 179 let key_bytes = generate_user_key(); 181 180 let did = "did:plc:test"; 182 - let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); 181 + let header = json!({ "alg": "ES256K", "typ": TokenType::Access.as_str() }); 183 182 184 183 let invalid_scope = json!({ 185 184 "iss": did, "sub": did, "aud": "did:web:test.pds", ··· 225 224 .is_err() 226 225 ); 227 226 228 - for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] { 227 + for scope in [ 228 + TokenScope::Access.as_str(), 229 + TokenScope::AppPass.as_str(), 230 + TokenScope::AppPassPrivileged.as_str(), 231 + ] { 229 232 let claims = json!({ 230 233 "iss": did, "sub": did, "aud": "did:web:test.pds", 231 234 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, ··· 240 243 let refresh_scope = json!({ 241 244 "iss": did, "sub": did, "aud": "did:web:test.pds", 242 245 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 243 - "jti": "test", "scope": SCOPE_REFRESH 246 + "jti": "test", "scope": TokenScope::Refresh.as_str() 244 247 }); 245 248 assert!( 246 249 verify_access_token( ··· 255 258 fn test_expiration_and_timing() { 256 259 let key_bytes = generate_user_key(); 257 260 let did = "did:plc:test"; 258 - let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); 261 + let header = json!({ "alg": "ES256K", "typ": TokenType::Access.as_str() }); 259 262 let now = Utc::now().timestamp(); 260 263 261 264 let expired = json!({ 262 265 "iss": did, "sub": did, "aud": "did:web:test.pds", 263 - "iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS 266 + "iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": TokenScope::Access.as_str() 264 267 }); 265 268 let result = verify_access_token( 266 269 &create_custom_jwt(&header, &expired, &key_bytes), ··· 270 273 271 274 let future_iat = json!({ 272 275 "iss": did, "sub": did, "aud": "did:web:test.pds", 273 - "iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS 276 + "iat": now + 60, "exp": now + 7200, "jti": "test", "scope": TokenScope::Access.as_str() 274 277 }); 275 278 assert!( 276 279 verify_access_token( ··· 282 285 283 286 let just_expired = json!({ 284 287 "iss": did, "sub": did, "aud": "did:web:test.pds", 285 - "iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS 288 + "iat": now - 10, "exp": now - 1, "jti": "test", "scope": TokenScope::Access.as_str() 286 289 }); 287 290 assert!( 288 291 verify_access_token( ··· 294 297 295 298 let far_future = json!({ 296 299 "iss": did, "sub": did, "aud": "did:web:test.pds", 297 - "iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS 300 + "iat": now, "exp": i64::MAX, "jti": "test", "scope": TokenScope::Access.as_str() 298 301 }); 299 302 let _ = verify_access_token( 300 303 &create_custom_jwt(&header, &far_future, &key_bytes), ··· 303 306 304 307 let negative_iat = json!({ 305 308 "iss": did, "sub": did, "aud": "did:web:test.pds", 306 - "iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS 309 + "iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": TokenScope::Access.as_str() 307 310 }); 308 311 let _ = verify_access_token( 309 312 &create_custom_jwt(&header, &negative_iat, &key_bytes), ··· 359 362 fn test_claim_validation() { 360 363 let key_bytes = generate_user_key(); 361 364 let did = "did:plc:test"; 362 - let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS }); 365 + let header = json!({ "alg": "ES256K", "typ": TokenType::Access.as_str() }); 363 366 364 367 let missing_exp = json!({ 365 368 "iss": did, "sub": did, "aud": "did:web:test", 366 - "iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS 369 + "iat": Utc::now().timestamp(), "scope": TokenScope::Access.as_str() 367 370 }); 368 371 assert!( 369 372 verify_access_token( ··· 375 378 376 379 let missing_iat = json!({ 377 380 "iss": did, "sub": did, "aud": "did:web:test", 378 - "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 381 + "exp": Utc::now().timestamp() + 3600, "scope": TokenScope::Access.as_str() 379 382 }); 380 383 assert!( 381 384 verify_access_token( ··· 387 390 388 391 let missing_sub = json!({ 389 392 "iss": did, "aud": "did:web:test", 390 - "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS 393 + "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": TokenScope::Access.as_str() 391 394 }); 392 395 assert!( 393 396 verify_access_token( ··· 399 402 400 403 let wrong_types = json!({ 401 404 "iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"}, 402 - "iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS 405 + "iat": "not a number", "exp": "also not a number", "jti": null, "scope": TokenScope::Access.as_str() 403 406 }); 404 407 assert!( 405 408 verify_access_token( ··· 412 415 let unicode_injection = json!({ 413 416 "iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta", 414 417 "aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 415 - "jti": "test", "scope": SCOPE_ACCESS 418 + "jti": "test", "scope": TokenScope::Access.as_str() 416 419 }); 417 420 if let Ok(data) = verify_access_token( 418 421 &create_custom_jwt(&header, &unicode_injection, &key_bytes), ··· 453 456 let did = "did:plc:test"; 454 457 455 458 let header = json!({ 456 - "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS, 459 + "alg": "ES256K", "typ": TokenType::Access.as_str(), 457 460 "kid": "../../../../../../etc/passwd", "jku": "https://attacker.com/keys" 458 461 }); 459 462 let claims = json!({ 460 463 "iss": did, "sub": did, "aud": "did:web:test.pds", 461 464 "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, 462 - "jti": "test", "scope": SCOPE_ACCESS 465 + "jti": "test", "scope": TokenScope::Access.as_str() 463 466 }); 464 467 assert!( 465 468 verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok()
+5 -5
crates/tranquil-pds/tests/plc_validation.rs
··· 2 2 use serde_json::json; 3 3 use std::collections::HashMap; 4 4 use tranquil_pds::plc::{ 5 - PlcError, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, sign_operation, 6 - signing_key_to_did_key, validate_plc_operation, validate_plc_operation_for_submission, 7 - verify_operation_signature, 5 + PlcError, PlcOpType, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, 6 + sign_operation, signing_key_to_did_key, validate_plc_operation, 7 + validate_plc_operation_for_submission, verify_operation_signature, 8 8 }; 9 9 10 10 fn create_valid_operation() -> serde_json::Value { ··· 264 264 services.insert( 265 265 "atproto_pds".to_string(), 266 266 PlcService { 267 - service_type: "AtprotoPersonalDataServer".to_string(), 267 + service_type: tranquil_pds::plc::ServiceType::Pds, 268 268 endpoint: "https://pds.example.com".to_string(), 269 269 }, 270 270 ); 271 271 let mut verification_methods = HashMap::new(); 272 272 verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); 273 273 let op = PlcOperation { 274 - op_type: "plc_operation".to_string(), 274 + op_type: PlcOpType::Operation, 275 275 rotation_keys: vec!["did:key:zTest123".to_string()], 276 276 verification_methods, 277 277 also_known_as: vec!["at://test.handle".to_string()],
+581
crates/tranquil-pds/tests/repo_lifecycle.rs
··· 1 + mod common; 2 + mod firehose; 3 + 4 + use cid::Cid; 5 + use common::*; 6 + use firehose::{FirehoseConsumer, ParsedCommitFrame}; 7 + use iroh_car::CarReader; 8 + use jacquard_repo::commit::Commit; 9 + use reqwest::StatusCode; 10 + use serde_json::{Value, json}; 11 + use std::io::Cursor; 12 + use std::str::FromStr; 13 + use tranquil_scopes::RepoAction; 14 + 15 + mod helpers; 16 + 17 + async fn create_post_record(client: &reqwest::Client, token: &str, did: &str, text: &str) -> Value { 18 + let payload = json!({ 19 + "repo": did, 20 + "collection": "app.bsky.feed.post", 21 + "record": { 22 + "$type": "app.bsky.feed.post", 23 + "text": text, 24 + "createdAt": chrono::Utc::now().to_rfc3339(), 25 + } 26 + }); 27 + let res = client 28 + .post(format!( 29 + "{}/xrpc/com.atproto.repo.createRecord", 30 + base_url().await 31 + )) 32 + .bearer_auth(token) 33 + .json(&payload) 34 + .send() 35 + .await 36 + .expect("Failed to create post"); 37 + assert_eq!(res.status(), StatusCode::OK); 38 + res.json().await.expect("Invalid JSON from createRecord") 39 + } 40 + 41 + async fn get_latest_commit(client: &reqwest::Client, token: &str, did: &str) -> Value { 42 + let res = client 43 + .get(format!( 44 + "{}/xrpc/com.atproto.sync.getLatestCommit?did={}", 45 + base_url().await, 46 + did 47 + )) 48 + .bearer_auth(token) 49 + .send() 50 + .await 51 + .expect("Failed to get latest commit"); 52 + assert_eq!(res.status(), StatusCode::OK); 53 + res.json().await.expect("Invalid JSON from getLatestCommit") 54 + } 55 + 56 + #[tokio::test] 57 + async fn test_create_record_cid_matches_firehose() { 58 + let client = client(); 59 + let (token, did) = create_account_and_login(&client).await; 60 + 61 + let pool = get_test_db_pool().await; 62 + let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq") 63 + .fetch_one(pool) 64 + .await 65 + .unwrap(); 66 + let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await; 67 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 68 + 69 + let api_response = create_post_record(&client, &token, &did, "CID match test").await; 70 + let api_commit_cid = api_response["commit"]["cid"].as_str().unwrap(); 71 + let api_commit_rev = api_response["commit"]["rev"].as_str().unwrap(); 72 + let api_record_cid = api_response["cid"].as_str().unwrap(); 73 + 74 + let frames = consumer 75 + .wait_for_commits(&did, 1, std::time::Duration::from_secs(10)) 76 + .await; 77 + let frame = &frames[0]; 78 + 79 + assert_eq!( 80 + api_commit_cid, 81 + frame.commit.to_string(), 82 + "API commit CID must match firehose commit CID" 83 + ); 84 + assert_eq!( 85 + api_commit_rev, frame.rev, 86 + "API commit rev must match firehose rev" 87 + ); 88 + assert_eq!(frame.ops.len(), 1, "Expected exactly 1 op"); 89 + assert_eq!( 90 + api_record_cid, 91 + frame.ops[0].cid.unwrap().to_string(), 92 + "API record CID must match firehose op CID" 93 + ); 94 + assert_eq!(frame.ops[0].action, RepoAction::Create); 95 + assert!(frame.ops[0].prev.is_none(), "Create op must have no prev"); 96 + 97 + let latest = get_latest_commit(&client, &token, &did).await; 98 + assert_eq!( 99 + latest["cid"].as_str().unwrap(), 100 + api_commit_cid, 101 + "getLatestCommit CID must match" 102 + ); 103 + assert_eq!( 104 + latest["rev"].as_str().unwrap(), 105 + api_commit_rev, 106 + "getLatestCommit rev must match" 107 + ); 108 + } 109 + 110 + #[tokio::test] 111 + async fn test_update_record_prev_matches_old_cid() { 112 + let client = client(); 113 + let (token, did) = create_account_and_login(&client).await; 114 + 115 + let v1_payload = json!({ 116 + "repo": did, 117 + "collection": "app.bsky.actor.profile", 118 + "rkey": "self", 119 + "record": { 120 + "$type": "app.bsky.actor.profile", 121 + "displayName": "Profile v1", 122 + } 123 + }); 124 + let v1_res = client 125 + .post(format!( 126 + "{}/xrpc/com.atproto.repo.putRecord", 127 + base_url().await 128 + )) 129 + .bearer_auth(&token) 130 + .json(&v1_payload) 131 + .send() 132 + .await 133 + .expect("Failed to create profile v1"); 134 + assert_eq!(v1_res.status(), StatusCode::OK); 135 + let v1_body: Value = v1_res.json().await.unwrap(); 136 + let v1_cid_str = v1_body["cid"].as_str().unwrap(); 137 + let v1_cid = Cid::from_str(v1_cid_str).unwrap(); 138 + 139 + let pool = get_test_db_pool().await; 140 + let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq") 141 + .fetch_one(pool) 142 + .await 143 + .unwrap(); 144 + let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await; 145 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 146 + 147 + let v2_payload = json!({ 148 + "repo": did, 149 + "collection": "app.bsky.actor.profile", 150 + "rkey": "self", 151 + "record": { 152 + "$type": "app.bsky.actor.profile", 153 + "displayName": "Profile v2", 154 + } 155 + }); 156 + let v2_res = client 157 + .post(format!( 158 + "{}/xrpc/com.atproto.repo.putRecord", 159 + base_url().await 160 + )) 161 + .bearer_auth(&token) 162 + .json(&v2_payload) 163 + .send() 164 + .await 165 + .expect("Failed to update profile v2"); 166 + assert_eq!(v2_res.status(), StatusCode::OK); 167 + let v2_body: Value = v2_res.json().await.unwrap(); 168 + let v2_cid_str = v2_body["cid"].as_str().unwrap(); 169 + let v2_cid = Cid::from_str(v2_cid_str).unwrap(); 170 + 171 + let frames = consumer 172 + .wait_for_commits(&did, 1, std::time::Duration::from_secs(10)) 173 + .await; 174 + let frame = &frames[0]; 175 + 176 + let profile_op = frame 177 + .ops 178 + .iter() 179 + .find(|op| op.path.contains("app.bsky.actor.profile")) 180 + .expect("No profile op found"); 181 + 182 + assert_eq!(profile_op.action, RepoAction::Update); 183 + assert_eq!( 184 + profile_op.prev, 185 + Some(v1_cid), 186 + "Update op.prev must be the old CID" 187 + ); 188 + assert_eq!( 189 + profile_op.cid, 190 + Some(v2_cid), 191 + "Update op.cid must be the new CID" 192 + ); 193 + assert!( 194 + frame.prev_data.is_some(), 195 + "Update commit must have prevData" 196 + ); 197 + } 198 + 199 + #[tokio::test] 200 + async fn test_delete_record_prev_set_cid_none() { 201 + let client = client(); 202 + let (token, did) = create_account_and_login(&client).await; 203 + 204 + let create_body = create_post_record(&client, &token, &did, "To be deleted").await; 205 + let record_cid = Cid::from_str(create_body["cid"].as_str().unwrap()).unwrap(); 206 + let uri = create_body["uri"].as_str().unwrap(); 207 + let parts: Vec<&str> = uri.split('/').collect(); 208 + let collection = parts[parts.len() - 2]; 209 + let rkey = parts[parts.len() - 1]; 210 + 211 + let pool = get_test_db_pool().await; 212 + let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq") 213 + .fetch_one(pool) 214 + .await 215 + .unwrap(); 216 + let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await; 217 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 218 + 219 + let delete_payload = json!({ 220 + "repo": did, 221 + "collection": collection, 222 + "rkey": rkey, 223 + }); 224 + let del_res = client 225 + .post(format!( 226 + "{}/xrpc/com.atproto.repo.deleteRecord", 227 + base_url().await 228 + )) 229 + .bearer_auth(&token) 230 + .json(&delete_payload) 231 + .send() 232 + .await 233 + .expect("Failed to delete record"); 234 + assert_eq!(del_res.status(), StatusCode::OK); 235 + 236 + let frames = consumer 237 + .wait_for_commits(&did, 1, std::time::Duration::from_secs(10)) 238 + .await; 239 + let frame = &frames[0]; 240 + 241 + assert_eq!(frame.ops.len(), 1, "Expected exactly 1 delete op"); 242 + let op = &frame.ops[0]; 243 + assert_eq!(op.action, RepoAction::Delete); 244 + assert!(op.cid.is_none(), "Delete op.cid must be None"); 245 + assert_eq!( 246 + op.prev, 247 + Some(record_cid), 248 + "Delete op.prev must be the original CID" 249 + ); 250 + } 251 + 252 + #[tokio::test] 253 + async fn test_five_record_commit_chain_integrity() { 254 + let client = client(); 255 + let (token, did) = create_account_and_login(&client).await; 256 + 257 + let pool = get_test_db_pool().await; 258 + let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq") 259 + .fetch_one(pool) 260 + .await 261 + .unwrap(); 262 + let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await; 263 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 264 + 265 + let texts = [ 266 + "Chain post 0", 267 + "Chain post 1", 268 + "Chain post 2", 269 + "Chain post 3", 270 + "Chain post 4", 271 + ]; 272 + for text in &texts { 273 + create_post_record(&client, &token, &did, text).await; 274 + } 275 + 276 + let mut frames = consumer 277 + .wait_for_commits(&did, 5, std::time::Duration::from_secs(15)) 278 + .await; 279 + frames.sort_by_key(|f| f.seq); 280 + 281 + let revs: Vec<&str> = frames.iter().map(|f| f.rev.as_str()).collect(); 282 + let unique_revs: std::collections::HashSet<&&str> = revs.iter().collect(); 283 + assert_eq!( 284 + unique_revs.len(), 285 + 5, 286 + "All rev values must be distinct, got: {:?}", 287 + revs 288 + ); 289 + 290 + let seqs: Vec<i64> = frames.iter().map(|f| f.seq).collect(); 291 + seqs.windows(2).for_each(|pair| { 292 + assert!( 293 + pair[1] > pair[0], 294 + "Seq values must be strictly monotonically increasing: {} <= {}", 295 + pair[1], 296 + pair[0] 297 + ); 298 + }); 299 + 300 + frames.iter().enumerate().skip(1).for_each(|(i, frame)| { 301 + assert_eq!( 302 + frame.since.as_deref(), 303 + Some(frames[i - 1].rev.as_str()), 304 + "Frame {} since must equal frame {} rev", 305 + i, 306 + i - 1 307 + ); 308 + }); 309 + 310 + let latest = get_latest_commit(&client, &token, &did).await; 311 + let final_frame = frames.last().unwrap(); 312 + assert_eq!( 313 + latest["cid"].as_str().unwrap(), 314 + final_frame.commit.to_string(), 315 + "getLatestCommit CID must match final frame" 316 + ); 317 + assert_eq!( 318 + latest["rev"].as_str().unwrap(), 319 + final_frame.rev, 320 + "getLatestCommit rev must match final frame" 321 + ); 322 + } 323 + 324 + #[tokio::test] 325 + async fn test_apply_writes_single_commit_multiple_ops() { 326 + let client = client(); 327 + let (token, did) = create_account_and_login(&client).await; 328 + 329 + let pool = get_test_db_pool().await; 330 + let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq") 331 + .fetch_one(pool) 332 + .await 333 + .unwrap(); 334 + let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await; 335 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 336 + 337 + let now = chrono::Utc::now().to_rfc3339(); 338 + let writes: Vec<Value> = (0..3) 339 + .map(|i| { 340 + json!({ 341 + "$type": "com.atproto.repo.applyWrites#create", 342 + "collection": "app.bsky.feed.post", 343 + "value": { 344 + "$type": "app.bsky.feed.post", 345 + "text": format!("Batch post {}", i), 346 + "createdAt": now, 347 + } 348 + }) 349 + }) 350 + .collect(); 351 + 352 + let payload = json!({ 353 + "repo": did, 354 + "writes": writes, 355 + }); 356 + let res = client 357 + .post(format!( 358 + "{}/xrpc/com.atproto.repo.applyWrites", 359 + base_url().await 360 + )) 361 + .bearer_auth(&token) 362 + .json(&payload) 363 + .send() 364 + .await 365 + .expect("Failed to applyWrites"); 366 + assert_eq!(res.status(), StatusCode::OK); 367 + let api_body: Value = res.json().await.unwrap(); 368 + let api_results = api_body["results"].as_array().expect("No results array"); 369 + assert_eq!(api_results.len(), 3, "Expected 3 results from applyWrites"); 370 + 371 + let frames = consumer 372 + .wait_for_commits(&did, 1, std::time::Duration::from_secs(10)) 373 + .await; 374 + assert_eq!( 375 + frames.len(), 376 + 1, 377 + "applyWrites should produce exactly 1 commit" 378 + ); 379 + let frame = &frames[0]; 380 + assert_eq!(frame.ops.len(), 3, "Commit should contain 3 ops"); 381 + 382 + frame.ops.iter().for_each(|op| { 383 + assert_eq!(op.action, RepoAction::Create, "All ops should be Create"); 384 + }); 385 + 386 + api_results.iter().enumerate().for_each(|(i, result)| { 387 + let api_cid = result["cid"].as_str().expect("No cid in result"); 388 + let frame_cid = frame.ops[i].cid.expect("No cid in op").to_string(); 389 + assert_eq!( 390 + api_cid, frame_cid, 391 + "API result[{}] CID must match firehose op[{}] CID", 392 + i, i 393 + ); 394 + }); 395 + } 396 + 397 + #[tokio::test] 398 + async fn test_firehose_commit_signature_verification() { 399 + let client = client(); 400 + let (token, did) = create_account_and_login(&client).await; 401 + 402 + let key_bytes = helpers::get_user_signing_key(&did) 403 + .await 404 + .expect("Failed to get signing key"); 405 + let signing_key = 406 + k256::ecdsa::SigningKey::from_slice(&key_bytes).expect("Invalid signing key bytes"); 407 + let pubkey_bytes = signing_key.verifying_key().to_encoded_point(true); 408 + let pubkey = jacquard_common::types::crypto::PublicKey { 409 + codec: jacquard_common::types::crypto::KeyCodec::Secp256k1, 410 + bytes: std::borrow::Cow::Owned(pubkey_bytes.as_bytes().to_vec()), 411 + }; 412 + 413 + let pool = get_test_db_pool().await; 414 + let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq") 415 + .fetch_one(pool) 416 + .await 417 + .unwrap(); 418 + let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await; 419 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 420 + 421 + let _api_response = 422 + create_post_record(&client, &token, &did, "Signature verification test").await; 423 + 424 + let frames = consumer 425 + .wait_for_commits(&did, 1, std::time::Duration::from_secs(10)) 426 + .await; 427 + let frame = &frames[0]; 428 + 429 + let mut car_reader = CarReader::new(Cursor::new(&frame.blocks)) 430 + .await 431 + .expect("Failed to parse CAR"); 432 + let mut blocks = std::collections::HashMap::new(); 433 + while let Ok(Some((cid, data))) = car_reader.next_block().await { 434 + blocks.insert(cid, data); 435 + } 436 + 437 + let commit_block = blocks 438 + .get(&frame.commit) 439 + .expect("Commit block not found in CAR"); 440 + 441 + let commit = Commit::from_cbor(commit_block).expect("Failed to parse commit from CBOR"); 442 + 443 + commit 444 + .verify(&pubkey) 445 + .expect("Commit signature verification failed"); 446 + 447 + assert_eq!( 448 + commit.rev().to_string(), 449 + frame.rev, 450 + "Commit rev must match frame rev" 451 + ); 452 + assert_eq!( 453 + commit.did().as_str(), 454 + did, 455 + "Commit DID must match account DID" 456 + ); 457 + } 458 + 459 + #[tokio::test] 460 + async fn test_cursor_backfill_completeness() { 461 + let client = client(); 462 + let (token, did) = create_account_and_login(&client).await; 463 + 464 + let pool = get_test_db_pool().await; 465 + let baseline_seq: i64 = 466 + sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq") 467 + .fetch_one(pool) 468 + .await 469 + .unwrap(); 470 + 471 + let mut expected_cids: Vec<String> = Vec::with_capacity(5); 472 + let texts = [ 473 + "Backfill 0", 474 + "Backfill 1", 475 + "Backfill 2", 476 + "Backfill 3", 477 + "Backfill 4", 478 + ]; 479 + for text in &texts { 480 + let body = create_post_record(&client, &token, &did, text).await; 481 + expected_cids.push(body["commit"]["cid"].as_str().unwrap().to_string()); 482 + } 483 + 484 + tokio::time::sleep(std::time::Duration::from_millis(200)).await; 485 + 486 + let consumer = FirehoseConsumer::connect_with_cursor(app_port(), baseline_seq).await; 487 + 488 + let frames = consumer 489 + .wait_for_commits(&did, 5, std::time::Duration::from_secs(15)) 490 + .await; 491 + 492 + let mut sorted_frames: Vec<ParsedCommitFrame> = frames; 493 + sorted_frames.sort_by_key(|f| f.seq); 494 + 495 + let received_cids: Vec<String> = sorted_frames.iter().map(|f| f.commit.to_string()).collect(); 496 + 497 + expected_cids.iter().for_each(|expected| { 498 + assert!( 499 + received_cids.contains(expected), 500 + "Missing commit {} in backfill", 501 + expected 502 + ); 503 + }); 504 + 505 + let seqs: Vec<i64> = sorted_frames.iter().map(|f| f.seq).collect(); 506 + let unique_seqs: std::collections::HashSet<&i64> = seqs.iter().collect(); 507 + assert_eq!( 508 + unique_seqs.len(), 509 + seqs.len(), 510 + "No duplicate seq values allowed in backfill" 511 + ); 512 + } 513 + 514 + #[tokio::test] 515 + async fn test_multi_account_seq_interleaving() { 516 + let client = client(); 517 + let (alice_token, alice_did) = create_account_and_login(&client).await; 518 + let (bob_token, bob_did) = create_account_and_login(&client).await; 519 + 520 + let pool = get_test_db_pool().await; 521 + let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq") 522 + .fetch_one(pool) 523 + .await 524 + .unwrap(); 525 + let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await; 526 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 527 + 528 + let _a1 = create_post_record(&client, &alice_token, &alice_did, "Alice post 1").await; 529 + tokio::time::sleep(std::time::Duration::from_millis(50)).await; 530 + let _b1 = create_post_record(&client, &bob_token, &bob_did, "Bob post 1").await; 531 + tokio::time::sleep(std::time::Duration::from_millis(50)).await; 532 + let _a2 = create_post_record(&client, &alice_token, &alice_did, "Alice post 2").await; 533 + tokio::time::sleep(std::time::Duration::from_millis(50)).await; 534 + let _b2 = create_post_record(&client, &bob_token, &bob_did, "Bob post 2").await; 535 + 536 + let alice_frames = consumer 537 + .wait_for_commits(&alice_did, 2, std::time::Duration::from_secs(10)) 538 + .await; 539 + let bob_frames = consumer 540 + .wait_for_commits(&bob_did, 2, std::time::Duration::from_secs(10)) 541 + .await; 542 + 543 + let mut all_commits = consumer.all_commits(); 544 + all_commits.sort_by_key(|f| f.seq); 545 + 546 + let global_seqs: Vec<i64> = all_commits.iter().map(|f| f.seq).collect(); 547 + global_seqs.windows(2).for_each(|pair| { 548 + assert!( 549 + pair[1] > pair[0], 550 + "Global seq must be strictly monotonically increasing: {} <= {}", 551 + pair[1], 552 + pair[0] 553 + ); 554 + }); 555 + 556 + let mut alice_sorted: Vec<ParsedCommitFrame> = alice_frames; 557 + alice_sorted.sort_by_key(|f| f.seq); 558 + assert_eq!(alice_sorted.len(), 2); 559 + assert!( 560 + alice_sorted[1].since.is_some(), 561 + "Alice's second commit must have since" 562 + ); 563 + assert_eq!( 564 + alice_sorted[1].since.as_deref(), 565 + Some(alice_sorted[0].rev.as_str()), 566 + "Alice's since chain must be self-consistent" 567 + ); 568 + 569 + let mut bob_sorted: Vec<ParsedCommitFrame> = bob_frames; 570 + bob_sorted.sort_by_key(|f| f.seq); 571 + assert_eq!(bob_sorted.len(), 2); 572 + assert!( 573 + bob_sorted[1].since.is_some(), 574 + "Bob's second commit must have since" 575 + ); 576 + assert_eq!( 577 + bob_sorted[1].since.as_deref(), 578 + Some(bob_sorted[0].rev.as_str()), 579 + "Bob's since chain must be self-consistent" 580 + ); 581 + }
+13
crates/tranquil-pds/tests/ripple_cluster.rs
··· 677 677 let nodes = common::cluster().await; 678 678 let client = common::client(); 679 679 680 + let now_ms = u64::try_from( 681 + std::time::SystemTime::now() 682 + .duration_since(std::time::UNIX_EPOCH) 683 + .unwrap() 684 + .as_millis(), 685 + ) 686 + .unwrap_or(u64::MAX); 687 + let login_window_ms: u64 = 60_000; 688 + let remaining = login_window_ms - (now_ms % login_window_ms); 689 + if remaining < 35_000 { 690 + tokio::time::sleep(Duration::from_millis(remaining + 100)).await; 691 + } 692 + 680 693 let uuid_bytes = uuid::Uuid::new_v4(); 681 694 let b = uuid_bytes.as_bytes(); 682 695 let unique_ip = format!("10.{}.{}.{}", b[0], b[1], b[2]);
+1 -4
crates/tranquil-pds/tests/server.rs
··· 65 65 .send() 66 66 .await 67 67 .unwrap(); 68 - assert!( 69 - missing_id.status() == StatusCode::BAD_REQUEST 70 - || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY 71 - ); 68 + assert_eq!(missing_id.status(), StatusCode::BAD_REQUEST); 72 69 let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base)) 73 70 .json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "Testpass123!" })) 74 71 .send().await.unwrap();
+27 -31
crates/tranquil-pds/tests/sso.rs
··· 40 40 41 41 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 42 42 let body: Value = res.json().await.unwrap(); 43 - assert_eq!(body["error"], "SsoProviderNotFound"); 43 + assert_eq!(body["error"], "InvalidRequest"); 44 44 } 45 45 46 46 #[tokio::test] ··· 61 61 62 62 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 63 63 let body: Value = res.json().await.unwrap(); 64 - assert!( 65 - body["error"] == "SsoInvalidAction" || body["error"] == "SsoProviderNotEnabled", 66 - "Expected SsoInvalidAction or SsoProviderNotEnabled, got: {}", 67 - body["error"] 68 - ); 64 + assert_eq!(body["error"], "InvalidRequest"); 69 65 } 70 66 71 67 #[tokio::test] ··· 232 228 let _url = base_url().await; 233 229 let pool = get_test_db_pool().await; 234 230 235 - let did = unsafe { 236 - Did::new_unchecked(format!( 237 - "did:plc:test{}", 238 - &uuid::Uuid::new_v4().simple().to_string()[..12] 239 - )) 240 - }; 231 + let did: Did = format!( 232 + "did:plc:test{}", 233 + &uuid::Uuid::new_v4().simple().to_string()[..12] 234 + ) 235 + .parse() 236 + .expect("valid test DID"); 241 237 let provider = SsoProviderType::Github; 242 238 let provider_user_id = format!("github_user_{}", uuid::Uuid::new_v4().simple()); 243 239 ··· 352 348 let _url = base_url().await; 353 349 let pool = get_test_db_pool().await; 354 350 355 - let did1 = unsafe { 356 - Did::new_unchecked(format!( 357 - "did:plc:uc1{}", 358 - &uuid::Uuid::new_v4().simple().to_string()[..10] 359 - )) 360 - }; 361 - let did2 = unsafe { 362 - Did::new_unchecked(format!( 363 - "did:plc:uc2{}", 364 - &uuid::Uuid::new_v4().simple().to_string()[..10] 365 - )) 366 - }; 351 + let did1: Did = format!( 352 + "did:plc:uc1{}", 353 + &uuid::Uuid::new_v4().simple().to_string()[..10] 354 + ) 355 + .parse() 356 + .expect("valid test DID"); 357 + let did2: Did = format!( 358 + "did:plc:uc2{}", 359 + &uuid::Uuid::new_v4().simple().to_string()[..10] 360 + ) 361 + .parse() 362 + .expect("valid test DID"); 367 363 let provider_user_id = format!("unique_test_{}", uuid::Uuid::new_v4().simple()); 368 364 369 365 sqlx::query!( ··· 583 579 let _url = base_url().await; 584 580 let pool = get_test_db_pool().await; 585 581 586 - let did = unsafe { 587 - Did::new_unchecked(format!( 588 - "did:plc:del{}", 589 - &uuid::Uuid::new_v4().simple().to_string()[..10] 590 - )) 591 - }; 592 - let wrong_did = unsafe { Did::new_unchecked("did:plc:wrongdid12345") }; 582 + let did: Did = format!( 583 + "did:plc:del{}", 584 + &uuid::Uuid::new_v4().simple().to_string()[..10] 585 + ) 586 + .parse() 587 + .expect("valid test DID"); 588 + let wrong_did: Did = "did:plc:wrongdid12345".parse().expect("valid test DID"); 593 589 594 590 sqlx::query!( 595 591 "INSERT INTO users (did, handle, email, password_hash) VALUES ($1, $2, $3, 'hash')",
+1 -2
crates/tranquil-ripple/src/transport.rs
··· 373 373 if let Err(e) = sock_ref.set_tcp_nodelay(true) { 374 374 tracing::warn!(error = %e, "failed to set TCP_NODELAY"); 375 375 } 376 - let keepalive = socket2::TcpKeepalive::new() 377 - .with_time(Duration::from_secs(30)); 376 + let keepalive = socket2::TcpKeepalive::new().with_time(Duration::from_secs(30)); 378 377 #[cfg(any(target_os = "linux", target_os = "macos", target_os = "windows"))] 379 378 let keepalive = keepalive.with_interval(Duration::from_secs(10)); 380 379 let params = keepalive;
+1
crates/tranquil-scopes/Cargo.toml
··· 11 11 reqwest = { workspace = true } 12 12 serde = { workspace = true } 13 13 serde_json = { workspace = true } 14 + thiserror = { workspace = true } 14 15 tokio = { workspace = true } 15 16 tracing = { workspace = true } 16 17 urlencoding = "2"
+65 -29
crates/tranquil-scopes/src/permission_set.rs
··· 6 6 use tokio::sync::RwLock; 7 7 use tracing::{debug, warn}; 8 8 9 + #[derive(Debug, thiserror::Error)] 10 + pub enum ScopeExpansionError { 11 + #[error("Invalid NSID format: {0}")] 12 + InvalidNsid(String), 13 + #[error("Missing definition: {0}")] 14 + MissingDefinition(String), 15 + #[error("Unexpected lexicon type: {0}")] 16 + UnexpectedType(String), 17 + #[error("DNS resolution failed: {0}")] 18 + DnsResolution(String), 19 + #[error("HTTP request failed: {0}")] 20 + HttpFailed(String), 21 + #[error("DID resolution failed: {0}")] 22 + DidResolution(String), 23 + #[error("No valid permissions found in permission-set")] 24 + EmptyPermissions, 25 + } 26 + 9 27 static LEXICON_CACHE: LazyLock<RwLock<HashMap<String, CachedLexicon>>> = 10 28 LazyLock::new(|| RwLock::new(HashMap::new())); 11 29 ··· 86 104 .unwrap_or((rest, None)) 87 105 } 88 106 89 - async fn expand_permission_set(nsid: &str, aud: Option<&str>) -> Result<String, String> { 107 + async fn expand_permission_set( 108 + nsid: &str, 109 + aud: Option<&str>, 110 + ) -> Result<String, ScopeExpansionError> { 90 111 let cache_key = match aud { 91 112 Some(a) => format!("{}?aud={}", nsid, a), 92 113 None => nsid.to_string(), ··· 107 128 let main_def = lexicon 108 129 .defs 109 130 .get("main") 110 - .ok_or("Missing 'main' definition in lexicon")?; 131 + .ok_or(ScopeExpansionError::MissingDefinition("main".to_string()))?; 111 132 112 133 if main_def.def_type != "permission-set" { 113 - return Err(format!( 114 - "Expected permission-set type, got: {}", 115 - main_def.def_type 134 + return Err(ScopeExpansionError::UnexpectedType( 135 + main_def.def_type.clone(), 116 136 )); 117 137 } 118 138 119 - let permissions = main_def 120 - .permissions 121 - .as_ref() 122 - .ok_or("Missing permissions in permission-set")?; 139 + let permissions = 140 + main_def 141 + .permissions 142 + .as_ref() 143 + .ok_or(ScopeExpansionError::MissingDefinition( 144 + "permissions".to_string(), 145 + ))?; 123 146 124 147 let namespace_authority = extract_namespace_authority(nsid); 125 148 let expanded = build_expanded_scopes(permissions, aud, &namespace_authority); 126 149 127 150 if expanded.is_empty() { 128 - return Err("No valid permissions found in permission-set".to_string()); 151 + return Err(ScopeExpansionError::EmptyPermissions); 129 152 } 130 153 131 154 { ··· 143 166 Ok(expanded) 144 167 } 145 168 146 - async fn fetch_lexicon_via_atproto(nsid: &str) -> Result<LexiconDoc, String> { 169 + async fn fetch_lexicon_via_atproto(nsid: &str) -> Result<LexiconDoc, ScopeExpansionError> { 147 170 let parts: Vec<&str> = nsid.split('.').collect(); 148 171 if parts.len() < 3 { 149 - return Err(format!("Invalid NSID format: {}", nsid)); 172 + return Err(ScopeExpansionError::InvalidNsid(nsid.to_string())); 150 173 } 151 174 152 175 let authority = parts[..2] ··· 166 189 let client = Client::builder() 167 190 .timeout(std::time::Duration::from_secs(10)) 168 191 .build() 169 - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; 192 + .map_err(|e| ScopeExpansionError::HttpFailed(e.to_string()))?; 170 193 171 194 let url = format!( 172 195 "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection=com.atproto.lexicon.schema&rkey={}", ··· 181 204 .header("Accept", "application/json") 182 205 .send() 183 206 .await 184 - .map_err(|e| format!("Failed to fetch lexicon: {}", e))?; 207 + .map_err(|e| ScopeExpansionError::HttpFailed(e.to_string()))?; 185 208 186 209 if !response.status().is_success() { 187 - return Err(format!( 188 - "Failed to fetch lexicon: HTTP {}", 210 + return Err(ScopeExpansionError::HttpFailed(format!( 211 + "HTTP {}", 189 212 response.status() 190 - )); 213 + ))); 191 214 } 192 215 193 216 let record: GetRecordResponse = response 194 217 .json() 195 218 .await 196 - .map_err(|e| format!("Failed to parse lexicon response: {}", e))?; 219 + .map_err(|e| ScopeExpansionError::HttpFailed(e.to_string()))?; 197 220 198 221 Ok(record.value) 199 222 } 200 223 201 - async fn resolve_lexicon_did_authority(authority: &str) -> Result<String, String> { 224 + async fn resolve_lexicon_did_authority(authority: &str) -> Result<String, ScopeExpansionError> { 202 225 let resolver = TokioAsyncResolver::tokio_from_system_conf() 203 - .map_err(|e| format!("Failed to create DNS resolver: {}", e))?; 226 + .map_err(|e| ScopeExpansionError::DnsResolution(e.to_string()))?; 204 227 205 228 let dns_name = format!("_lexicon.{}", authority); 206 229 debug!(dns_name = %dns_name, "Looking up DNS TXT record"); ··· 208 231 let txt_records = resolver 209 232 .txt_lookup(&dns_name) 210 233 .await 211 - .map_err(|e| format!("DNS lookup failed for {}: {}", dns_name, e))?; 234 + .map_err(|e| ScopeExpansionError::DnsResolution(format!("{}: {}", dns_name, e)))?; 212 235 213 236 txt_records 214 237 .iter() ··· 217 240 let txt = String::from_utf8_lossy(data); 218 241 txt.strip_prefix("did=").map(|did| did.to_string()) 219 242 }) 220 - .ok_or_else(|| format!("No valid did= TXT record found at {}", dns_name)) 243 + .ok_or_else(|| { 244 + ScopeExpansionError::DnsResolution(format!( 245 + "No valid did= TXT record found at {}", 246 + dns_name 247 + )) 248 + }) 221 249 } 222 250 223 - async fn resolve_did_to_pds(did: &str) -> Result<String, String> { 251 + async fn resolve_did_to_pds(did: &str) -> Result<String, ScopeExpansionError> { 224 252 let client = Client::builder() 225 253 .timeout(std::time::Duration::from_secs(10)) 226 254 .build() 227 - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; 255 + .map_err(|e| ScopeExpansionError::HttpFailed(e.to_string()))?; 228 256 229 257 let url = if did.starts_with("did:plc:") { 230 258 format!("https://plc.directory/{}", did) ··· 232 260 let domain = did.strip_prefix("did:web:").unwrap(); 233 261 format!("https://{}/.well-known/did.json", domain) 234 262 } else { 235 - return Err(format!("Unsupported DID method: {}", did)); 263 + return Err(ScopeExpansionError::DidResolution(format!( 264 + "Unsupported DID method: {}", 265 + did 266 + ))); 236 267 }; 237 268 238 269 let response = client ··· 240 271 .header("Accept", "application/json") 241 272 .send() 242 273 .await 243 - .map_err(|e| format!("Failed to resolve DID: {}", e))?; 274 + .map_err(|e| ScopeExpansionError::DidResolution(e.to_string()))?; 244 275 245 276 if !response.status().is_success() { 246 - return Err(format!("Failed to resolve DID: HTTP {}", response.status())); 277 + return Err(ScopeExpansionError::DidResolution(format!( 278 + "HTTP {}", 279 + response.status() 280 + ))); 247 281 } 248 282 249 283 let doc: PlcDocument = response 250 284 .json() 251 285 .await 252 - .map_err(|e| format!("Failed to parse DID document: {}", e))?; 286 + .map_err(|e| ScopeExpansionError::DidResolution(e.to_string()))?; 253 287 254 288 doc.service 255 289 .iter() 256 290 .find(|s| s.id == "#atproto_pds") 257 291 .map(|s| s.service_endpoint.clone()) 258 - .ok_or_else(|| "No #atproto_pds service found in DID document".to_string()) 292 + .ok_or(ScopeExpansionError::DidResolution( 293 + "No #atproto_pds service found in DID document".to_string(), 294 + )) 259 295 } 260 296 261 297 fn extract_namespace_authority(nsid: &str) -> String {
+34 -24
crates/tranquil-types/src/lib.rs
··· 101 101 pub fn new(s: impl Into<String>) -> Self { 102 102 Self(s.into()) 103 103 } 104 - 105 - pub fn new_unchecked(s: impl Into<String>) -> Self { 106 - Self(s.into()) 107 - } 108 104 } 109 105 110 106 impl_string_common!($name); ··· 123 119 124 120 impl $name { 125 121 pub fn new(s: impl Into<String>) -> Self { 126 - Self(s.into()) 127 - } 128 - 129 - pub fn new_unchecked(s: impl Into<String>) -> Self { 130 122 Self(s.into()) 131 123 } 132 124 } ··· 140 132 $(#[$meta:meta])* 141 133 $vis:vis struct $name:ident; 142 134 error = $error:ident; 135 + label = $label:expr; 143 136 validator = $validator:expr; 144 137 ) => { 145 138 $(#[$meta])* ··· 165 158 validator(&s).map_err(|_| $error::Invalid(s.clone()))?; 166 159 Ok(Self(s)) 167 160 } 168 - 169 - #[allow(unsafe_code, clippy::missing_safety_doc)] 170 - pub unsafe fn new_unchecked(s: impl Into<String>) -> Self { 171 - Self(s.into()) 172 - } 173 161 } 174 162 175 163 impl FromStr for $name { ··· 182 170 183 171 impl_string_common!($name); 184 172 185 - #[derive(Debug, Clone, thiserror::Error)] 173 + #[derive(Debug, Clone)] 186 174 pub enum $error { 187 - #[error("invalid: {0}")] 188 175 Invalid(String), 189 176 } 177 + 178 + impl std::fmt::Display for $error { 179 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 180 + match self { 181 + Self::Invalid(s) => write!(f, concat!("invalid ", $label, ": {}"), s), 182 + } 183 + } 184 + } 185 + 186 + impl std::error::Error for $error {} 190 187 }; 191 188 } 192 189 193 190 validated_string_newtype! { 194 191 pub struct Did; 195 192 error = DidError; 193 + label = "DID"; 196 194 validator = |s| jacquard_common::types::string::Did::new(s).map(|_| ()).map_err(|_| ()); 197 195 } 198 196 ··· 217 215 D: serde::Deserializer<'de>, 218 216 { 219 217 let s = String::deserialize(deserializer)?; 220 - Ok(Handle(s)) 218 + Handle::new(&s).map_err(|e| serde::de::Error::custom(e.to_string())) 221 219 } 222 220 } 223 221 ··· 227 225 jacquard_common::types::string::Handle::new(&s) 228 226 .map_err(|_| HandleError::Invalid(s.clone()))?; 229 227 Ok(Self(s)) 230 - } 231 - 232 - #[allow(unsafe_code, clippy::missing_safety_doc)] 233 - pub unsafe fn new_unchecked(s: impl Into<String>) -> Self { 234 - Self(s.into()) 235 228 } 236 229 } 237 230 ··· 368 361 validated_string_newtype! { 369 362 pub struct Rkey; 370 363 error = RkeyError; 364 + label = "rkey"; 371 365 validator = |s| jacquard_common::types::string::Rkey::new(s).map(|_| ()).map_err(|_| ()); 372 366 } 373 367 ··· 389 383 validated_string_newtype! { 390 384 pub struct Nsid; 391 385 error = NsidError; 386 + label = "NSID"; 392 387 validator = |s| jacquard_common::types::string::Nsid::new(s).map(|_| ()).map_err(|_| ()); 393 388 } 394 389 ··· 405 400 validated_string_newtype! { 406 401 pub struct AtUri; 407 402 error = AtUriError; 403 + label = "AT URI"; 408 404 validator = |s| jacquard_common::types::string::AtUri::new(s).map(|_| ()).map_err(|_| ()); 409 405 } 410 406 ··· 435 431 validated_string_newtype! { 436 432 pub struct Tid; 437 433 error = TidError; 434 + label = "TID"; 438 435 validator = |s| jacquard_common::types::string::Tid::from_str(s).map(|_| ()).map_err(|_| ()); 439 436 } 440 437 ··· 448 445 validated_string_newtype! { 449 446 pub struct Datetime; 450 447 error = DatetimeError; 448 + label = "datetime"; 451 449 validator = |s| jacquard_common::types::string::Datetime::from_str(s).map(|_| ()).map_err(|_| ()); 452 450 } 453 451 ··· 466 464 validated_string_newtype! { 467 465 pub struct Language; 468 466 error = LanguageError; 467 + label = "language"; 469 468 validator = |s| jacquard_common::types::string::Language::from_str(s).map(|_| ()).map_err(|_| ()); 470 469 } 471 470 ··· 491 490 Ok(Self(s)) 492 491 } 493 492 494 - #[allow(unsafe_code, clippy::missing_safety_doc)] 495 - pub unsafe fn new_unchecked(s: impl Into<String>) -> Self { 496 - Self(s.into()) 493 + pub fn from_cid(cid: &cid::Cid) -> Self { 494 + Self(cid.to_string()) 497 495 } 498 496 499 497 pub fn to_cid(&self) -> Option<cid::Cid> { 500 498 cid::Cid::from_str(&self.0).ok() 499 + } 500 + } 501 + 502 + impl From<cid::Cid> for CidLink { 503 + fn from(cid: cid::Cid) -> Self { 504 + Self(cid.to_string()) 505 + } 506 + } 507 + 508 + impl From<&cid::Cid> for CidLink { 509 + fn from(cid: &cid::Cid) -> Self { 510 + Self(cid.to_string()) 501 511 } 502 512 } 503 513
-162
frontend/src/components/dashboard/AdminContent.svelte
··· 35 35 invitesDisabled?: boolean 36 36 } 37 37 38 - interface Invite { 39 - code: string 40 - available: number 41 - disabled: boolean 42 - forAccount: string 43 - createdBy: string 44 - createdAt: string 45 - uses: Array<{ usedBy: string; usedAt: string }> 46 - } 47 - 48 38 let stats = $state<ServerStats | null>(null) 49 39 let users = $state<User[]>([]) 50 40 let loading = $state(true) 51 41 let usersLoading = $state(false) 52 42 let searchQuery = $state('') 53 43 let usersCursor = $state<string | undefined>(undefined) 54 - 55 - let invites = $state<Invite[]>([]) 56 - let invitesLoading = $state(false) 57 - let invitesCursor = $state<string | undefined>(undefined) 58 - let showInvites = $state(false) 59 44 60 45 let selectedUser = $state<User | null>(null) 61 46 let userActionLoading = $state(false) ··· 219 204 logoChanged 220 205 } 221 206 222 - async function loadInvites(reset = false) { 223 - invitesLoading = true 224 - if (reset) { 225 - invites = [] 226 - invitesCursor = undefined 227 - } 228 - try { 229 - const result = await api.getInviteCodes(session.accessJwt, { 230 - cursor: reset ? undefined : invitesCursor, 231 - limit: 25, 232 - }) 233 - invites = reset ? result.codes : [...invites, ...result.codes] 234 - invitesCursor = result.cursor 235 - showInvites = true 236 - } catch { 237 - toast.error($_('admin.failedToLoadInvites')) 238 - } finally { 239 - invitesLoading = false 240 - } 241 - } 242 - 243 - async function disableInvite(code: string) { 244 - if (!confirm($_('admin.disableInviteConfirm', { values: { code } }))) return 245 - try { 246 - await api.disableInviteCodes(session.accessJwt, [code]) 247 - invites = invites.map(i => i.code === code ? { ...i, disabled: true } : i) 248 - toast.success($_('admin.inviteDisabled')) 249 - } catch (e) { 250 - toast.error(e instanceof ApiError ? e.message : $_('admin.failedToDisableInvite')) 251 - } 252 - } 253 - 254 207 async function showUserDetail(user: User) { 255 208 selectedUser = user 256 209 userDetailLoading = true ··· 467 420 {/if} 468 421 </section> 469 422 470 - <section class="invites-section"> 471 - <h3>{$_('admin.inviteCodes')}</h3> 472 - <div class="section-actions"> 473 - <button onclick={() => loadInvites(true)} disabled={invitesLoading}> 474 - {invitesLoading ? $_('common.loading') : showInvites ? $_('admin.refresh') : $_('admin.loadInviteCodes')} 475 - </button> 476 - </div> 477 - {#if showInvites} 478 - {#if invites.length === 0} 479 - <p class="empty">{$_('admin.noInvites')}</p> 480 - {:else} 481 - <ul class="invite-list"> 482 - {#each invites as invite} 483 - <li class="invite-item" class:disabled-row={invite.disabled}> 484 - <div class="invite-info"> 485 - <code class="invite-code">{invite.code}</code> 486 - <span class="invite-meta"> 487 - {$_('admin.available')}: {invite.available} - {$_('admin.uses')}: {invite.uses.length} - {$_('admin.created')}: {formatDate(invite.createdAt)} 488 - </span> 489 - </div> 490 - <div class="invite-status"> 491 - {#if invite.disabled} 492 - <span class="badge deactivated">{$_('admin.disabled')}</span> 493 - {:else if invite.available === 0} 494 - <span class="badge unverified">{$_('admin.exhausted')}</span> 495 - {:else} 496 - <span class="badge verified">{$_('admin.active')}</span> 497 - {/if} 498 - </div> 499 - <div class="invite-actions"> 500 - {#if !invite.disabled} 501 - <button class="action-btn danger" onclick={() => disableInvite(invite.code)}> 502 - {$_('admin.disable')} 503 - </button> 504 - {/if} 505 - </div> 506 - </li> 507 - {/each} 508 - </ul> 509 - {#if invitesCursor} 510 - <button type="button" class="load-more" onclick={() => loadInvites(false)} disabled={invitesLoading}> 511 - {invitesLoading ? $_('common.loading') : $_('admin.loadMore')} 512 - </button> 513 - {/if} 514 - {/if} 515 - {/if} 516 - </section> 517 423 </div> 518 424 519 425 {#if selectedUser} ··· 857 763 .badge.unverified { 858 764 background: var(--bg-tertiary); 859 765 color: var(--text-secondary); 860 - } 861 - 862 - .section-actions { 863 - margin-bottom: var(--space-4); 864 - } 865 - 866 - .invite-list { 867 - list-style: none; 868 - padding: 0; 869 - margin: 0; 870 - display: flex; 871 - flex-direction: column; 872 - gap: var(--space-2); 873 - } 874 - 875 - .invite-item { 876 - display: flex; 877 - align-items: center; 878 - padding: var(--space-3); 879 - background: var(--bg-card); 880 - border: 1px solid var(--border-color); 881 - border-radius: var(--radius-md); 882 - gap: var(--space-3); 883 - } 884 - 885 - .invite-item.disabled-row { 886 - opacity: 0.6; 887 - } 888 - 889 - .invite-info { 890 - flex: 1; 891 - min-width: 0; 892 - } 893 - 894 - .invite-code { 895 - display: block; 896 - font-family: var(--font-mono); 897 - font-size: var(--text-sm); 898 - } 899 - 900 - .invite-meta { 901 - font-size: var(--text-xs); 902 - color: var(--text-secondary); 903 - } 904 - 905 - .invite-status { 906 - flex-shrink: 0; 907 - } 908 - 909 - .invite-actions { 910 - flex-shrink: 0; 911 - } 912 - 913 - .action-btn { 914 - padding: var(--space-2) var(--space-3); 915 - font-size: var(--text-sm); 916 - border-radius: var(--radius-md); 917 - cursor: pointer; 918 - } 919 - 920 - .action-btn.danger { 921 - background: transparent; 922 - border: 1px solid var(--error-border); 923 - color: var(--error-text); 924 - } 925 - 926 - .action-btn.danger:hover { 927 - background: var(--error-bg); 928 766 } 929 767 930 768 .modal-overlay {
+38 -2
frontend/src/components/dashboard/InviteCodesContent.svelte
··· 5 5 import { toast } from '../../lib/toast.svelte' 6 6 import { formatDate } from '../../lib/date' 7 7 import type { Session } from '../../lib/types/api' 8 + import Skeleton from '../Skeleton.svelte' 8 9 9 10 interface Props { 10 11 session: Session ··· 15 16 let codes = $state<InviteCode[]>([]) 16 17 let loading = $state(true) 17 18 let creating = $state(false) 19 + let disablingCode = $state<string | null>(null) 18 20 let createdCode = $state<string | null>(null) 19 21 let createdCodeCopied = $state(false) 20 22 let copiedCode = $state<string | null>(null) ··· 60 62 } 61 63 } 62 64 65 + async function disableCode(code: string) { 66 + if (!confirm($_('inviteCodes.disableConfirm', { values: { code } }))) return 67 + disablingCode = code 68 + try { 69 + await api.disableInviteCodes(session.accessJwt, [code]) 70 + codes = codes.map(c => c.code === code ? { ...c, disabled: true } : c) 71 + toast.success($_('inviteCodes.disableSuccess')) 72 + } catch (e) { 73 + toast.error(e instanceof ApiError ? e.message : $_('inviteCodes.disableFailed')) 74 + } finally { 75 + disablingCode = null 76 + } 77 + } 78 + 63 79 function copyCode(code: string) { 64 80 navigator.clipboard.writeText(code) 65 81 copiedCode = code ··· 96 112 <section class="list-section"> 97 113 <h2>{$_('inviteCodes.yourCodes')}</h2> 98 114 {#if loading} 99 - <div class="loading">{$_('common.loading')}</div> 115 + <ul class="code-list"> 116 + {#each Array(3) as _} 117 + <li class="code-item skeleton-item"> 118 + <div class="code-main"> 119 + <Skeleton variant="line" size="medium" /> 120 + </div> 121 + <div class="code-meta"> 122 + <Skeleton variant="line" size="short" /> 123 + <Skeleton variant="line" size="tiny" /> 124 + </div> 125 + </li> 126 + {/each} 127 + </ul> 100 128 {:else if codes.length === 0} 101 129 <p class="empty">{$_('inviteCodes.noCodes')}</p> 102 130 {:else} ··· 174 202 margin: 0 0 var(--space-4) 0; 175 203 } 176 204 177 - .loading, 178 205 .empty { 179 206 color: var(--text-secondary); 180 207 padding: var(--space-6); ··· 195 222 background: var(--bg-secondary); 196 223 border: 1px solid var(--border-color); 197 224 border-radius: var(--radius-lg); 225 + } 226 + 227 + .skeleton-item { 228 + pointer-events: none; 198 229 } 199 230 200 231 .code-item.disabled { ··· 221 252 } 222 253 223 254 .copy-btn { 255 + flex-shrink: 0; 256 + } 257 + 258 + .danger-text { 259 + color: var(--error-text); 224 260 flex-shrink: 0; 225 261 } 226 262
+146 -30
frontend/src/lib/migration/plc-ops.ts
··· 7 7 import { 8 8 P256PrivateKey, 9 9 parsePrivateMultikey, 10 + parsePublicMultikey, 10 11 Secp256k1PrivateKey, 11 12 Secp256k1PrivateKeyExportable, 12 13 } from "@atcute/crypto"; 13 14 import * as CBOR from "@atcute/cbor"; 14 - import { fromBase16, toBase64Url } from "@atcute/multibase"; 15 + import { 16 + fromBase16, 17 + fromBase58Btc, 18 + fromBase64Url, 19 + toBase64Url, 20 + } from "@atcute/multibase"; 15 21 16 22 export type PrivateKey = P256PrivateKey | Secp256k1PrivateKey; 17 23 ··· 36 42 sig?: string; 37 43 } 38 44 45 + type KeyCurve = "secp256k1" | "p256"; 46 + 47 + const HEX_PRIVATE_KEY_REGEX = /^[0-9a-f]{64}$/i; 48 + const BASE58BTC_CHARSET_REGEX = /^[a-km-zA-HJ-NP-Z1-9]+$/; 49 + 50 + const importRawBytes = ( 51 + bytes: Uint8Array, 52 + curve: KeyCurve, 53 + ): Promise<PrivateKey> => 54 + curve === "p256" 55 + ? P256PrivateKey.importRaw(bytes) 56 + : Secp256k1PrivateKey.importRaw(bytes); 57 + 58 + const importFromMultikeyMatch = ( 59 + match: ReturnType<typeof parsePrivateMultikey>, 60 + ): Promise<PrivateKey> => 61 + match.type === "p256" 62 + ? P256PrivateKey.importRaw(match.privateKeyBytes) 63 + : Secp256k1PrivateKey.importRaw(match.privateKeyBytes); 64 + 65 + const importJwk = async ( 66 + json: string, 67 + _curve: KeyCurve, 68 + ): Promise<PrivateKey> => { 69 + const parsed: unknown = JSON.parse(json); 70 + if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) { 71 + throw new Error("Invalid JWK: expected a JSON object"); 72 + } 73 + const jwk = parsed as Record<string, unknown>; 74 + 75 + if (jwk.kty !== "EC") { 76 + throw new Error( 77 + `Unsupported JWK key type: ${ 78 + String(jwk.kty) 79 + }. Only EC keys are supported`, 80 + ); 81 + } 82 + 83 + if (typeof jwk.d !== "string") { 84 + throw new Error( 85 + "This JWK is a public key (missing 'd' parameter). The private key JWK is required", 86 + ); 87 + } 88 + 89 + const detectedCurve: KeyCurve = (() => { 90 + switch (jwk.crv) { 91 + case "secp256k1": 92 + return "secp256k1"; 93 + case "P-256": 94 + return "p256"; 95 + default: 96 + throw new Error( 97 + `Unsupported JWK curve: ${ 98 + String(jwk.crv) 99 + }. Expected secp256k1 or P-256`, 100 + ); 101 + } 102 + })(); 103 + 104 + const privateKeyBytes = fromBase64Url(jwk.d); 105 + return importRawBytes(privateKeyBytes, detectedCurve); 106 + }; 107 + 108 + const importMultikeyOrBase58 = ( 109 + input: string, 110 + curve: KeyCurve, 111 + ): Promise<PrivateKey> => { 112 + try { 113 + const match = parsePrivateMultikey(input); 114 + return importFromMultikeyMatch(match); 115 + } catch { 116 + try { 117 + parsePublicMultikey(input); 118 + throw new Error( 119 + "This is a public multikey. The private key multikey is required", 120 + ); 121 + } catch (publicErr) { 122 + if ( 123 + publicErr instanceof Error && 124 + publicErr.message.includes("public multikey") 125 + ) { 126 + throw publicErr; 127 + } 128 + } 129 + 130 + try { 131 + return importBase58Raw(input, curve); 132 + } catch { 133 + return importBase58Raw(input.slice(1), curve); 134 + } 135 + } 136 + }; 137 + 138 + const importBase58Raw = ( 139 + input: string, 140 + curve: KeyCurve, 141 + ): Promise<PrivateKey> => { 142 + const bytes = fromBase58Btc(input); 143 + if (bytes.length !== 32) { 144 + throw new Error( 145 + `Invalid base58 key: decoded to ${bytes.length} bytes, expected 32`, 146 + ); 147 + } 148 + return importRawBytes(bytes, curve); 149 + }; 150 + 151 + const detectAndImportPrivateKey = ( 152 + input: string, 153 + curve: KeyCurve, 154 + ): Promise<PrivateKey> => { 155 + if (input.startsWith("{")) { 156 + return importJwk(input, curve); 157 + } 158 + 159 + if (HEX_PRIVATE_KEY_REGEX.test(input)) { 160 + return importRawBytes(fromBase16(input.toLowerCase()), curve); 161 + } 162 + 163 + if (input.startsWith("z")) { 164 + return importMultikeyOrBase58(input, curve); 165 + } 166 + 167 + if (BASE58BTC_CHARSET_REGEX.test(input)) { 168 + return importBase58Raw(input, curve); 169 + } 170 + 171 + throw new Error( 172 + "Unrecognized key format. Expected hex, base58, multikey, or JWK", 173 + ); 174 + }; 175 + 39 176 const jsonToB64Url = (obj: unknown): string => { 40 177 const enc = new TextEncoder(); 41 178 const json = JSON.stringify(obj); ··· 88 225 89 226 async getKeyPair( 90 227 privateKeyString: string, 91 - type: "secp256k1" | "p256" = "secp256k1", 228 + type: KeyCurve = "secp256k1", 92 229 ): Promise<KeypairInfo> { 93 - const HEX_REGEX = /^[0-9a-f]+$/i; 94 - const MULTIKEY_REGEX = /^z[a-km-zA-HJ-NP-Z1-9]+$/; 95 - let keypair: PrivateKey | undefined; 230 + const trimmed = privateKeyString.trim(); 96 231 97 - const trimmed = privateKeyString.trim(); 232 + if (trimmed.length === 0) { 233 + throw new Error("Private key is required"); 234 + } 98 235 99 - if (HEX_REGEX.test(trimmed) && trimmed.length === 64) { 100 - const privateKeyBytes = fromBase16(trimmed); 101 - if (type === "p256") { 102 - keypair = await P256PrivateKey.importRaw(privateKeyBytes); 103 - } else { 104 - keypair = await Secp256k1PrivateKey.importRaw(privateKeyBytes); 105 - } 106 - } else if (MULTIKEY_REGEX.test(trimmed)) { 107 - const match = parsePrivateMultikey(trimmed); 108 - const privateKeyBytes = match.privateKeyBytes; 109 - if (match.type === "p256") { 110 - keypair = await P256PrivateKey.importRaw(privateKeyBytes); 111 - } else if (match.type === "secp256k1") { 112 - keypair = await Secp256k1PrivateKey.importRaw(privateKeyBytes); 113 - } else { 114 - throw new Error( 115 - `Unsupported key type: ${(match as { type: string }).type}`, 116 - ); 117 - } 118 - } else { 236 + if (trimmed.startsWith("did:key:")) { 119 237 throw new Error( 120 - "Invalid key format. Expected 64-char hex or multikey format.", 238 + "This is a did:key public key identifier. The private key is required", 121 239 ); 122 240 } 123 241 124 - if (!keypair) { 125 - throw new Error("Failed to parse private key"); 126 - } 242 + const keypair = await detectAndImportPrivateKey(trimmed, type); 127 243 128 244 return { 129 245 type: "private_key",
+4 -11
frontend/src/locales/en.json
··· 323 323 "disabled": "Disabled", 324 324 "created": "Invite Code Created", 325 325 "copy": "Copy", 326 + "disable": "Disable", 327 + "disableConfirm": "Disable invite code {code}?", 328 + "disableSuccess": "Invite code disabled", 329 + "disableFailed": "Failed to disable invite code", 326 330 "createdOn": "Created {date}" 327 331 }, 328 332 "security": { ··· 504 508 "status": "Status", 505 509 "created": "Created", 506 510 "loadMore": "Load More", 507 - "inviteCodes": "Invite Codes", 508 - "loadInviteCodes": "Load Invite Codes", 509 - "refresh": "Refresh", 510 - "noInvites": "No invite codes found", 511 - "available": "Available", 512 - "uses": "Uses", 513 - "disable": "Disable", 514 - "disableInviteConfirm": "Disable invite code {code}?", 515 - "active": "Active", 516 - "exhausted": "Exhausted", 517 511 "disabled": "Disabled", 518 512 "userDetails": "User Details", 519 513 "did": "DID", ··· 526 520 "verified": "Verified", 527 521 "unverified": "Unverified", 528 522 "deactivated": "Deactivated", 529 - "inviteDisabled": "Invite code disabled", 530 523 "invitesEnabled": "User invites enabled", 531 524 "invitesDisabled": "User invites disabled", 532 525 "userDeleted": "User account deleted",
+4 -11
frontend/src/locales/fi.json
··· 321 321 "disabled": "Poistettu käytöstä", 322 322 "created": "Kutsukoodi luotu", 323 323 "copy": "Kopioi", 324 + "disable": "Poista käytöstä", 325 + "disableConfirm": "Poista kutsukoodi {code} käytöstä?", 326 + "disableSuccess": "Kutsukoodi poistettu käytöstä", 327 + "disableFailed": "Kutsukoodin poistaminen käytöstä epäonnistui", 324 328 "createdOn": "Luotu {date}", 325 329 "loadFailed": "Kutsukoodien lataus epäonnistui", 326 330 "createFailed": "Kutsukoodin luonti epäonnistui" ··· 498 502 "status": "Tila", 499 503 "created": "Luotu", 500 504 "loadMore": "Lataa lisää", 501 - "inviteCodes": "Kutsukoodit", 502 - "loadInviteCodes": "Lataa kutsukoodit", 503 - "refresh": "Päivitä", 504 - "noInvites": "Kutsukoodeja ei löytynyt", 505 - "available": "Saatavilla", 506 - "uses": "Käyttökerrat", 507 - "disable": "Poista käytöstä", 508 - "disableInviteConfirm": "Poista kutsukoodi {code} käytöstä?", 509 - "active": "Aktiivinen", 510 - "exhausted": "Käytetty loppuun", 511 505 "disabled": "Poistettu käytöstä", 512 506 "userDetails": "Käyttäjän tiedot", 513 507 "did": "DID", ··· 526 520 "failedToLoadUsers": "Käyttäjien lataus epäonnistui", 527 521 "searchToSeeUsers": "Hae nähdäksesi käyttäjät", 528 522 "search": "Hae", 529 - "inviteDisabled": "Kutsukoodi poistettu käytöstä", 530 523 "invitesEnabled": "Käyttäjäkutsut käytössä", 531 524 "invitesDisabled": "Käyttäjäkutsut pois käytöstä", 532 525 "userDeleted": "Käyttäjätili poistettu",
+4 -11
frontend/src/locales/ja.json
··· 321 321 "disabled": "無効", 322 322 "created": "招待コードを作成しました", 323 323 "copy": "コピー", 324 + "disable": "無効化", 325 + "disableConfirm": "招待コード {code} を無効にしますか?", 326 + "disableSuccess": "招待コードを無効にしました", 327 + "disableFailed": "招待コードの無効化に失敗しました", 324 328 "createdOn": "{date} に作成", 325 329 "loadFailed": "招待コードの読み込みに失敗しました", 326 330 "createFailed": "招待コードの作成に失敗しました" ··· 498 502 "status": "ステータス", 499 503 "created": "作成日時", 500 504 "loadMore": "さらに読み込む", 501 - "inviteCodes": "招待コード", 502 - "loadInviteCodes": "招待コードを読み込む", 503 - "refresh": "更新", 504 - "noInvites": "招待コードが見つかりません", 505 - "available": "利用可能", 506 - "uses": "使用回数", 507 - "disable": "無効化", 508 - "disableInviteConfirm": "招待コード {code} を無効にしますか?", 509 - "active": "アクティブ", 510 - "exhausted": "使用済み", 511 505 "disabled": "無効", 512 506 "userDetails": "ユーザー詳細", 513 507 "did": "DID", ··· 526 520 "failedToLoadUsers": "ユーザーの読み込みに失敗しました", 527 521 "searchToSeeUsers": "検索してユーザーを表示", 528 522 "search": "検索", 529 - "inviteDisabled": "招待コードを無効にしました", 530 523 "invitesEnabled": "ユーザー招待を有効にしました", 531 524 "invitesDisabled": "ユーザー招待を無効にしました", 532 525 "userDeleted": "ユーザーアカウントを削除しました",
+4 -11
frontend/src/locales/ko.json
··· 321 321 "disabled": "비활성화됨", 322 322 "created": "초대 코드가 생성되었습니다", 323 323 "copy": "복사", 324 + "disable": "비활성화", 325 + "disableConfirm": "초대 코드 {code}을(를) 비활성화하시겠습니까?", 326 + "disableSuccess": "초대 코드가 비활성화되었습니다", 327 + "disableFailed": "초대 코드 비활성화 실패", 324 328 "createdOn": "{date}에 생성됨", 325 329 "loadFailed": "초대 코드 로딩 실패", 326 330 "createFailed": "초대 코드 생성 실패" ··· 498 502 "status": "상태", 499 503 "created": "생성일", 500 504 "loadMore": "더 불러오기", 501 - "inviteCodes": "초대 코드", 502 - "loadInviteCodes": "초대 코드 불러오기", 503 - "refresh": "새로고침", 504 - "noInvites": "초대 코드가 없습니다", 505 - "available": "사용 가능", 506 - "uses": "사용 횟수", 507 - "disable": "비활성화", 508 - "disableInviteConfirm": "초대 코드 {code}을(를) 비활성화하시겠습니까?", 509 - "active": "활성", 510 - "exhausted": "소진됨", 511 505 "disabled": "비활성화됨", 512 506 "userDetails": "사용자 세부 정보", 513 507 "did": "DID", ··· 526 520 "failedToLoadUsers": "사용자 로딩 실패", 527 521 "searchToSeeUsers": "검색하여 사용자 보기", 528 522 "search": "검색", 529 - "inviteDisabled": "초대 코드가 비활성화되었습니다", 530 523 "invitesEnabled": "사용자 초대가 활성화되었습니다", 531 524 "invitesDisabled": "사용자 초대가 비활성화되었습니다", 532 525 "userDeleted": "사용자 계정이 삭제되었습니다",
+4 -11
frontend/src/locales/sv.json
··· 320 320 "disabled": "Inaktiverad", 321 321 "created": "Inbjudningskod skapad", 322 322 "copy": "Kopiera", 323 + "disable": "Inaktivera", 324 + "disableConfirm": "Inaktivera inbjudningskod {code}?", 325 + "disableSuccess": "Inbjudningskod inaktiverad", 326 + "disableFailed": "Kunde inte inaktivera inbjudningskod", 323 327 "createdOn": "Skapad {date}", 324 328 "spent": "Förbrukad", 325 329 "loadFailed": "Kunde inte ladda inbjudningskoder", ··· 498 502 "status": "Status", 499 503 "created": "Skapad", 500 504 "loadMore": "Ladda fler", 501 - "inviteCodes": "Inbjudningskoder", 502 - "loadInviteCodes": "Ladda inbjudningskoder", 503 - "refresh": "Uppdatera", 504 - "noInvites": "Inga inbjudningskoder hittades", 505 - "available": "Tillgänglig", 506 - "uses": "Användningar", 507 - "disable": "Inaktivera", 508 - "disableInviteConfirm": "Inaktivera inbjudningskod {code}?", 509 - "active": "Aktiv", 510 - "exhausted": "Förbrukad", 511 505 "disabled": "Inaktiverad", 512 506 "userDetails": "Användardetaljer", 513 507 "did": "DID", ··· 526 520 "failedToLoadUsers": "Kunde inte ladda användare", 527 521 "searchToSeeUsers": "Sök för att visa användare", 528 522 "search": "Sök", 529 - "inviteDisabled": "Inbjudningskod inaktiverad", 530 523 "invitesEnabled": "Användarinbjudningar aktiverade", 531 524 "invitesDisabled": "Användarinbjudningar inaktiverade", 532 525 "userDeleted": "Användarkonto raderat",
+4 -11
frontend/src/locales/zh.json
··· 320 320 "disabled": "已禁用", 321 321 "created": "邀请码已创建", 322 322 "copy": "复制", 323 + "disable": "禁用", 324 + "disableConfirm": "禁用邀请码 {code}?", 325 + "disableSuccess": "邀请码已禁用", 326 + "disableFailed": "禁用邀请码失败", 323 327 "createdOn": "创建于 {date}", 324 328 "spent": "已使用", 325 329 "loadFailed": "加载邀请码失败", ··· 500 504 "status": "状态", 501 505 "created": "创建时间", 502 506 "loadMore": "加载更多", 503 - "inviteCodes": "邀请码", 504 - "loadInviteCodes": "加载邀请码", 505 - "refresh": "刷新", 506 - "noInvites": "暂无邀请码", 507 - "available": "可用", 508 - "uses": "使用次数", 509 - "disable": "禁用", 510 - "disableInviteConfirm": "禁用邀请码 {code}?", 511 - "active": "活跃", 512 - "exhausted": "已用完", 513 507 "disabled": "已禁用", 514 508 "userDetails": "用户详情", 515 509 "did": "DID", ··· 526 520 "failedToLoadUsers": "加载用户失败", 527 521 "searchToSeeUsers": "搜索以查看用户", 528 522 "search": "搜索", 529 - "inviteDisabled": "邀请码已禁用", 530 523 "invitesEnabled": "用户邀请已启用", 531 524 "invitesDisabled": "用户邀请已禁用", 532 525 "userDeleted": "用户账户已删除",
+215 -3
frontend/src/tests/migration/plc-ops.test.ts
··· 1 1 import { beforeEach, describe, expect, it, vi } from "vitest"; 2 2 import { PlcOps, plcOps } from "../../lib/migration/plc-ops.ts"; 3 + import { 4 + P256PrivateKeyExportable, 5 + Secp256k1PrivateKeyExportable, 6 + } from "@atcute/crypto"; 7 + import { fromBase58Btc, toBase58Btc } from "@atcute/multibase"; 3 8 4 9 describe("migration/plc-ops", () => { 5 10 beforeEach(() => { ··· 89 94 90 95 it("throws for invalid key format", async () => { 91 96 await expect(plcOps.getKeyPair("not-a-valid-key")).rejects.toThrow( 92 - "Invalid key format", 97 + "Unrecognized key format", 93 98 ); 94 99 }); 95 100 96 101 it("throws for hex key with wrong length", async () => { 97 - await expect(plcOps.getKeyPair("abc123")).rejects.toThrow( 98 - "Invalid key format", 102 + await expect(plcOps.getKeyPair("abc123")).rejects.toThrow(); 103 + }); 104 + }); 105 + 106 + describe("getKeyPair - multikey round-trip", () => { 107 + it("round-trips from createNewSecp256k1Keypair", async () => { 108 + const { privateKey, publicKey } = await plcOps 109 + .createNewSecp256k1Keypair(); 110 + 111 + const result = await plcOps.getKeyPair(privateKey); 112 + 113 + expect(result.didPublicKey).toBe(publicKey); 114 + }); 115 + 116 + it("produces correct multikey structure (z prefix, codec bytes)", async () => { 117 + const { privateKey } = await plcOps.createNewSecp256k1Keypair(); 118 + 119 + expect(privateKey.startsWith("z")).toBe(true); 120 + const decoded = fromBase58Btc(privateKey.slice(1)); 121 + expect(decoded[0]).toBe(0x81); 122 + expect(decoded[1]).toBe(0x26); 123 + expect(decoded.length).toBe(34); 124 + }); 125 + 126 + it("multikey import matches hex import of same raw bytes", async () => { 127 + const keypair = await Secp256k1PrivateKeyExportable.createKeypair(); 128 + const multikey = await keypair.exportPrivateKey("multikey"); 129 + const rawHex = await keypair.exportPrivateKey("rawHex"); 130 + 131 + const fromMultikey = await plcOps.getKeyPair(multikey); 132 + const fromHex = await plcOps.getKeyPair(rawHex); 133 + 134 + expect(fromMultikey.didPublicKey).toBe(fromHex.didPublicKey); 135 + }); 136 + }); 137 + 138 + describe("getKeyPair - hex format", () => { 139 + it("accepts uppercase hex", async () => { 140 + const result = await plcOps.getKeyPair("A".repeat(64)); 141 + 142 + expect(result.type).toBe("private_key"); 143 + expect(result.didPublicKey.startsWith("did:key:")).toBe(true); 144 + }); 145 + }); 146 + 147 + describe("getKeyPair - JWK format", () => { 148 + it("imports secp256k1 JWK with d parameter", async () => { 149 + const keypair = await Secp256k1PrivateKeyExportable.createKeypair(); 150 + const jwk = await keypair.exportPrivateKey("jwk"); 151 + const expectedDid = await keypair.exportPublicKey("did"); 152 + 153 + const result = await plcOps.getKeyPair(JSON.stringify(jwk)); 154 + 155 + expect(result.didPublicKey).toBe(expectedDid); 156 + }); 157 + 158 + it("imports P-256 JWK with d parameter", async () => { 159 + const keypair = await P256PrivateKeyExportable.createKeypair(); 160 + const jwk = await keypair.exportPrivateKey("jwk"); 161 + const expectedDid = await keypair.exportPublicKey("did"); 162 + 163 + const result = await plcOps.getKeyPair(JSON.stringify(jwk)); 164 + 165 + expect(result.didPublicKey).toBe(expectedDid); 166 + }); 167 + 168 + it("rejects JWK without d (public key)", async () => { 169 + const keypair = await Secp256k1PrivateKeyExportable.createKeypair(); 170 + const jwk = await keypair.exportPublicKey("jwk"); 171 + 172 + await expect( 173 + plcOps.getKeyPair(JSON.stringify(jwk)), 174 + ).rejects.toThrow("public key"); 175 + }); 176 + 177 + it("rejects unsupported kty", async () => { 178 + const jwk = { kty: "RSA", n: "abc", e: "AQAB" }; 179 + 180 + await expect(plcOps.getKeyPair(JSON.stringify(jwk))).rejects.toThrow( 181 + "Unsupported JWK key type", 99 182 ); 183 + }); 184 + 185 + it("rejects unsupported crv", async () => { 186 + const jwk = { kty: "EC", crv: "P-384", d: "AAAA", x: "BBBB", y: "CCCC" }; 187 + 188 + await expect(plcOps.getKeyPair(JSON.stringify(jwk))).rejects.toThrow( 189 + "Unsupported JWK curve", 190 + ); 191 + }); 192 + 193 + it("rejects malformed JSON", async () => { 194 + await expect(plcOps.getKeyPair("{not valid json")).rejects.toThrow(); 195 + }); 196 + 197 + it("produces same public key as hex import of same raw bytes", async () => { 198 + const keypair = await Secp256k1PrivateKeyExportable.createKeypair(); 199 + const jwk = await keypair.exportPrivateKey("jwk"); 200 + const rawHex = await keypair.exportPrivateKey("rawHex"); 201 + 202 + const fromJwk = await plcOps.getKeyPair(JSON.stringify(jwk)); 203 + const fromHex = await plcOps.getKeyPair(rawHex); 204 + 205 + expect(fromJwk.didPublicKey).toBe(fromHex.didPublicKey); 206 + }); 207 + }); 208 + 209 + describe("getKeyPair - plain base58 format", () => { 210 + it("imports base58-encoded 32-byte raw key", async () => { 211 + const keypair = await Secp256k1PrivateKeyExportable.createKeypair(); 212 + const rawBytes = await keypair.exportPrivateKey("raw"); 213 + const base58 = toBase58Btc(rawBytes); 214 + const expectedDid = await keypair.exportPublicKey("did"); 215 + 216 + const result = await plcOps.getKeyPair(base58); 217 + 218 + expect(result.didPublicKey).toBe(expectedDid); 219 + }); 220 + 221 + it("produces same public key as hex import of same raw bytes", async () => { 222 + const keypair = await Secp256k1PrivateKeyExportable.createKeypair(); 223 + const rawBytes = await keypair.exportPrivateKey("raw"); 224 + const rawHex = await keypair.exportPrivateKey("rawHex"); 225 + const base58 = toBase58Btc(rawBytes); 226 + 227 + const fromBase58 = await plcOps.getKeyPair(base58); 228 + const fromHex = await plcOps.getKeyPair(rawHex); 229 + 230 + expect(fromBase58.didPublicKey).toBe(fromHex.didPublicKey); 231 + }); 232 + 233 + it("rejects wrong decoded length", async () => { 234 + const shortBytes = new Uint8Array(16); 235 + crypto.getRandomValues(shortBytes); 236 + const base58Short = toBase58Btc(shortBytes); 237 + 238 + await expect(plcOps.getKeyPair(base58Short)).rejects.toThrow( 239 + "expected 32", 240 + ); 241 + }); 242 + }); 243 + 244 + describe("getKeyPair - cross-format consistency", () => { 245 + it("hex, multikey, and JWK all produce identical did:key", async () => { 246 + const keypair = await Secp256k1PrivateKeyExportable.createKeypair(); 247 + const rawHex = await keypair.exportPrivateKey("rawHex"); 248 + const multikey = await keypair.exportPrivateKey("multikey"); 249 + const jwk = await keypair.exportPrivateKey("jwk"); 250 + 251 + const [fromHex, fromMultikey, fromJwk] = await Promise.all([ 252 + plcOps.getKeyPair(rawHex), 253 + plcOps.getKeyPair(multikey), 254 + plcOps.getKeyPair(JSON.stringify(jwk)), 255 + ]); 256 + 257 + expect(fromHex.didPublicKey).toBe(fromMultikey.didPublicKey); 258 + expect(fromHex.didPublicKey).toBe(fromJwk.didPublicKey); 259 + }); 260 + 261 + it("hex, multikey, JWK, and base58 all match for P-256", async () => { 262 + const keypair = await P256PrivateKeyExportable.createKeypair(); 263 + const rawHex = await keypair.exportPrivateKey("rawHex"); 264 + const multikey = await keypair.exportPrivateKey("multikey"); 265 + const jwk = await keypair.exportPrivateKey("jwk"); 266 + const rawBytes = await keypair.exportPrivateKey("raw"); 267 + const base58 = toBase58Btc(rawBytes); 268 + 269 + const [fromHex, fromMultikey, fromJwk, fromBase58] = await Promise.all([ 270 + plcOps.getKeyPair(rawHex, "p256"), 271 + plcOps.getKeyPair(multikey), 272 + plcOps.getKeyPair(JSON.stringify(jwk)), 273 + plcOps.getKeyPair(base58, "p256"), 274 + ]); 275 + 276 + expect(fromHex.didPublicKey).toBe(fromMultikey.didPublicKey); 277 + expect(fromHex.didPublicKey).toBe(fromJwk.didPublicKey); 278 + expect(fromHex.didPublicKey).toBe(fromBase58.didPublicKey); 279 + }); 280 + }); 281 + 282 + describe("getKeyPair - error cases", () => { 283 + it("rejects empty string", async () => { 284 + await expect(plcOps.getKeyPair("")).rejects.toThrow( 285 + "Private key is required", 286 + ); 287 + }); 288 + 289 + it("rejects whitespace-only", async () => { 290 + await expect(plcOps.getKeyPair(" ")).rejects.toThrow( 291 + "Private key is required", 292 + ); 293 + }); 294 + 295 + it("rejects did:key: prefix with helpful error", async () => { 296 + await expect( 297 + plcOps.getKeyPair( 298 + "did:key:zQ3shunBKoL5VRgSEX7RQGQEG3TTo6MPVWvT7tcVjjwZCWMEE", 299 + ), 300 + ).rejects.toThrow("public key"); 301 + }); 302 + 303 + it("rejects unrecognized garbage", async () => { 304 + await expect(plcOps.getKeyPair("!!!invalid!!!")).rejects.toThrow( 305 + "Unrecognized key format", 306 + ); 307 + }); 308 + 309 + it("rejects hex with non-hex chars in 64-char string", async () => { 310 + const almostHex = "g".repeat(64); 311 + await expect(plcOps.getKeyPair(almostHex)).rejects.toThrow(); 100 312 }); 101 313 }); 102 314