Our Personal Data Server from scratch! tranquil.farm
oauth atproto pds rust postgresql objectstorage fun

DRAFT: Better code quality via type safety #5

merged opened by lewis.moe targeting main from fix/code-quality-in-general

Ensuring at compile-time that we're definitely handling possible early failures in functions

Labels

None yet.

assignee

None yet.

Participants 1
AT URI
at://did:plc:3fwecdnvtcscjnrx2p4n7alz/sh.tangled.repo.pull/3mdbo7zq5ae22
+4990 -3973
Diff #2
+27
.sqlx/query-03fc2ba947ee547e000b044fafb486e71b9b65a7dd923b5354c5a4dde98332eb.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "UPDATE users SET preferred_comms_channel = $1, updated_at = NOW() WHERE did = $2", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + { 9 + "Custom": { 10 + "name": "comms_channel", 11 + "kind": { 12 + "Enum": [ 13 + "email", 14 + "discord", 15 + "telegram", 16 + "signal" 17 + ] 18 + } 19 + } 20 + }, 21 + "Text" 22 + ] 23 + }, 24 + "nullable": [] 25 + }, 26 + "hash": "03fc2ba947ee547e000b044fafb486e71b9b65a7dd923b5354c5a4dde98332eb" 27 + }
+3 -3
.sqlx/query-805a344e73f2c19caaffe71de227ddd505599839033e83ae4be5b243d343d651.json .sqlx/query-0d32a592a97ad47c65aa37cf0d45417f2966fcbd688be7434626ae5f6971fa1f.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq = $1", 3 + "query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq = $1", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 20 20 }, 21 21 { 22 22 "ordinal": 3, 23 - "name": "event_type", 23 + "name": "event_type: RepoEventType", 24 24 "type_info": "Text" 25 25 }, 26 26 { ··· 96 96 true 97 97 ] 98 98 }, 99 - "hash": "805a344e73f2c19caaffe71de227ddd505599839033e83ae4be5b243d343d651" 99 + "hash": "0d32a592a97ad47c65aa37cf0d45417f2966fcbd688be7434626ae5f6971fa1f" 100 100 }
+28
.sqlx/query-200ecf153f1433ae8f6fbe81ab888a04ddd035ec9e88ef5f207e2487a02a1224.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT available_uses, COALESCE(disabled, false) as \"disabled!\" FROM invite_codes WHERE code = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "available_uses", 9 + "type_info": "Int4" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "disabled!", 14 + "type_info": "Bool" 15 + } 16 + ], 17 + "parameters": { 18 + "Left": [ 19 + "Text" 20 + ] 21 + }, 22 + "nullable": [ 23 + false, 24 + null 25 + ] 26 + }, 27 + "hash": "200ecf153f1433ae8f6fbe81ab888a04ddd035ec9e88ef5f207e2487a02a1224" 28 + }
+17 -5
.sqlx/query-426fedba6791c420fe7af6decc296c681d05a5c24a38b8cd7083c8dfa9178ded.json .sqlx/query-247470d26a90617e7dc9b5b3a2146ee3f54448e3c24943f7005e3a8e28820d43.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT\n email,\n preferred_comms_channel::text as \"preferred_channel!\",\n discord_id,\n discord_verified,\n telegram_username,\n telegram_verified,\n signal_number,\n signal_verified\n FROM users WHERE did = $1", 3 + "query": "SELECT\n email,\n preferred_comms_channel as \"preferred_channel!: CommsChannel\",\n discord_id,\n discord_verified,\n telegram_username,\n telegram_verified,\n signal_number,\n signal_verified\n FROM users WHERE did = $1", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 10 10 }, 11 11 { 12 12 "ordinal": 1, 13 - "name": "preferred_channel!", 14 - "type_info": "Text" 13 + "name": "preferred_channel!: CommsChannel", 14 + "type_info": { 15 + "Custom": { 16 + "name": "comms_channel", 17 + "kind": { 18 + "Enum": [ 19 + "email", 20 + "discord", 21 + "telegram", 22 + "signal" 23 + ] 24 + } 25 + } 26 + } 15 27 }, 16 28 { 17 29 "ordinal": 2, ··· 51 63 }, 52 64 "nullable": [ 53 65 true, 54 - null, 66 + false, 55 67 true, 56 68 false, 57 69 true, ··· 60 72 false 61 73 ] 62 74 }, 63 - "hash": "426fedba6791c420fe7af6decc296c681d05a5c24a38b8cd7083c8dfa9178ded" 75 + "hash": "247470d26a90617e7dc9b5b3a2146ee3f54448e3c24943f7005e3a8e28820d43" 64 76 }
+5 -5
.sqlx/query-9fea6394495b70ef5af2c2f5298e651d1ae78aa9ac6b03f952b6b0416023f671.json .sqlx/query-25309f4a08845a49557d694ad9b5b9a137be4dcce28e9293551c8c3fd40fdd86.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "\n SELECT\n created_at,\n channel as \"channel: String\",\n comms_type as \"comms_type: String\",\n status as \"status: String\",\n subject,\n body\n FROM comms_queue\n WHERE user_id = $1\n ORDER BY created_at DESC\n LIMIT $2\n ", 3 + "query": "\n SELECT\n created_at,\n channel as \"channel: CommsChannel\",\n comms_type as \"comms_type: CommsType\",\n status as \"status: CommsStatus\",\n subject,\n body\n FROM comms_queue\n WHERE user_id = $1\n ORDER BY created_at DESC\n LIMIT $2\n ", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 10 10 }, 11 11 { 12 12 "ordinal": 1, 13 - "name": "channel: String", 13 + "name": "channel: CommsChannel", 14 14 "type_info": { 15 15 "Custom": { 16 16 "name": "comms_channel", ··· 27 27 }, 28 28 { 29 29 "ordinal": 2, 30 - "name": "comms_type: String", 30 + "name": "comms_type: CommsType", 31 31 "type_info": { 32 32 "Custom": { 33 33 "name": "comms_type", ··· 52 52 }, 53 53 { 54 54 "ordinal": 3, 55 - "name": "status: String", 55 + "name": "status: CommsStatus", 56 56 "type_info": { 57 57 "Custom": { 58 58 "name": "comms_status", ··· 93 93 false 94 94 ] 95 95 }, 96 - "hash": "9fea6394495b70ef5af2c2f5298e651d1ae78aa9ac6b03f952b6b0416023f671" 96 + "hash": "25309f4a08845a49557d694ad9b5b9a137be4dcce28e9293551c8c3fd40fdd86" 97 97 }
-22
.sqlx/query-36441073d3fb87230f88ddce4e597c248fbf7360e510d703b9eec42efe9e049e.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT (available_uses > 0 AND NOT COALESCE(disabled, false)) as \"valid!\" FROM invite_codes WHERE code = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "valid!", 9 - "type_info": "Bool" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Text" 15 - ] 16 - }, 17 - "nullable": [ 18 - null 19 - ] 20 - }, 21 - "hash": "36441073d3fb87230f88ddce4e597c248fbf7360e510d703b9eec42efe9e049e" 22 - }
+15 -5
.sqlx/query-445c2ebb72f3833119f32284b9e721cf34c8ae581e6ae58a392fc93e77a7a015.json .sqlx/query-7061e8763ef7d91ff152ed0124f99e1820172fd06916d225ca6c5137a507b8fa.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "\n SELECT id, did, email, password_hash, password_required, two_factor_enabled,\n preferred_comms_channel as \"preferred_comms_channel!: CommsChannel\",\n deactivated_at, takedown_ref,\n email_verified, discord_verified, telegram_verified, signal_verified,\n account_type::text as \"account_type!\"\n FROM users\n WHERE handle = $1 OR email = $1\n ", 3 + "query": "\n SELECT id, did, email, password_hash, password_required, two_factor_enabled,\n preferred_comms_channel as \"preferred_comms_channel!: CommsChannel\",\n deactivated_at, takedown_ref,\n email_verified, discord_verified, telegram_verified, signal_verified,\n account_type as \"account_type!: AccountType\"\n FROM users\n WHERE handle = $1 OR email = $1\n ", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 82 82 }, 83 83 { 84 84 "ordinal": 13, 85 - "name": "account_type!", 86 - "type_info": "Text" 85 + "name": "account_type!: AccountType", 86 + "type_info": { 87 + "Custom": { 88 + "name": "account_type", 89 + "kind": { 90 + "Enum": [ 91 + "personal", 92 + "delegated" 93 + ] 94 + } 95 + } 96 + } 87 97 } 88 98 ], 89 99 "parameters": { ··· 105 115 false, 106 116 false, 107 117 false, 108 - null 118 + false 109 119 ] 110 120 }, 111 - "hash": "445c2ebb72f3833119f32284b9e721cf34c8ae581e6ae58a392fc93e77a7a015" 121 + "hash": "7061e8763ef7d91ff152ed0124f99e1820172fd06916d225ca6c5137a507b8fa" 112 122 }
+3 -3
.sqlx/query-605dc962cf86004de763aee65757a5a77da150b36aa8470c52fd5835e9b895fc.json .sqlx/query-b26bf97a27783eb7fb524a92dda3e68ef8470a9751fcaefe5fd2d7909dead54b.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1 AND seq < $2\n ORDER BY seq ASC", 3 + "query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 20 20 }, 21 21 { 22 22 "ordinal": 3, 23 - "name": "event_type", 23 + "name": "event_type: RepoEventType", 24 24 "type_info": "Text" 25 25 }, 26 26 { ··· 97 97 true 98 98 ] 99 99 }, 100 - "hash": "605dc962cf86004de763aee65757a5a77da150b36aa8470c52fd5835e9b895fc" 100 + "hash": "b26bf97a27783eb7fb524a92dda3e68ef8470a9751fcaefe5fd2d7909dead54b" 101 101 }
+3 -3
.sqlx/query-e2befe7fa07a1072a8b3f0ed6c1a54a39ffc8769aa65391ea282c78d2cd29f23.json .sqlx/query-b8101757a50075d20147014e450cb7deb7e58f84310690c7bde61e1834dc5903.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC", 3 + "query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 20 20 }, 21 21 { 22 22 "ordinal": 3, 23 - "name": "event_type", 23 + "name": "event_type: RepoEventType", 24 24 "type_info": "Text" 25 25 }, 26 26 { ··· 96 96 true 97 97 ] 98 98 }, 99 - "hash": "e2befe7fa07a1072a8b3f0ed6c1a54a39ffc8769aa65391ea282c78d2cd29f23" 99 + "hash": "b8101757a50075d20147014e450cb7deb7e58f84310690c7bde61e1834dc5903" 100 100 }
+3 -3
.sqlx/query-8f6a1e09351dc716eaadc9e30c5cfea45212901a139e98f0fccfacfbb3371dec.json .sqlx/query-d8524ad3f5dc03eb09ed60396a78df5003f804c43ad253d6476523eacdebf811.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2", 3 + "query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1 AND seq < $2\n ORDER BY seq ASC", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 20 20 }, 21 21 { 22 22 "ordinal": 3, 23 - "name": "event_type", 23 + "name": "event_type: RepoEventType", 24 24 "type_info": "Text" 25 25 }, 26 26 { ··· 97 97 true 98 98 ] 99 99 }, 100 - "hash": "8f6a1e09351dc716eaadc9e30c5cfea45212901a139e98f0fccfacfbb3371dec" 100 + "hash": "d8524ad3f5dc03eb09ed60396a78df5003f804c43ad253d6476523eacdebf811" 101 101 }
+52
.sqlx/query-d8fd97c8be3211b2509669dd859245b14e15f81a42d7e0c4c428b65f466af5ee.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT email, handle, preferred_comms_channel as \"preferred_channel!: CommsChannel\", preferred_locale\n FROM users WHERE id = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "email", 9 + "type_info": "Text" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "handle", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "preferred_channel!: CommsChannel", 19 + "type_info": { 20 + "Custom": { 21 + "name": "comms_channel", 22 + "kind": { 23 + "Enum": [ 24 + "email", 25 + "discord", 26 + "telegram", 27 + "signal" 28 + ] 29 + } 30 + } 31 + } 32 + }, 33 + { 34 + "ordinal": 3, 35 + "name": "preferred_locale", 36 + "type_info": "Varchar" 37 + } 38 + ], 39 + "parameters": { 40 + "Left": [ 41 + "Uuid" 42 + ] 43 + }, 44 + "nullable": [ 45 + true, 46 + false, 47 + false, 48 + true 49 + ] 50 + }, 51 + "hash": "d8fd97c8be3211b2509669dd859245b14e15f81a42d7e0c4c428b65f466af5ee" 52 + }
-40
.sqlx/query-e3aeec9a759b2b68cb11fa48b5d34ffc19430a6b16adb0c49307da0cacdf1ca3.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT email, handle, preferred_comms_channel::text as \"preferred_channel!\", preferred_locale\n FROM users WHERE id = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "email", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "handle", 14 - "type_info": "Text" 15 - }, 16 - { 17 - "ordinal": 2, 18 - "name": "preferred_channel!", 19 - "type_info": "Text" 20 - }, 21 - { 22 - "ordinal": 3, 23 - "name": "preferred_locale", 24 - "type_info": "Varchar" 25 - } 26 - ], 27 - "parameters": { 28 - "Left": [ 29 - "Uuid" 30 - ] 31 - }, 32 - "nullable": [ 33 - true, 34 - false, 35 - null, 36 - true 37 - ] 38 - }, 39 - "hash": "e3aeec9a759b2b68cb11fa48b5d34ffc19430a6b16adb0c49307da0cacdf1ca3" 40 - }
+3 -3
.sqlx/query-caffa68d10445a42878b66e6b0224dafb8527c8a4cc9806d6f733edff72bc9db.json .sqlx/query-e7aa1080be9eb3a8ddf1f050c93dc8afd10478f41e22307014784b4ee3740b4a.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2", 3 + "query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 20 20 }, 21 21 { 22 22 "ordinal": 3, 23 - "name": "event_type", 23 + "name": "event_type: RepoEventType", 24 24 "type_info": "Text" 25 25 }, 26 26 { ··· 97 97 true 98 98 ] 99 99 }, 100 - "hash": "caffa68d10445a42878b66e6b0224dafb8527c8a4cc9806d6f733edff72bc9db" 100 + "hash": "e7aa1080be9eb3a8ddf1f050c93dc8afd10478f41e22307014784b4ee3740b4a" 101 101 }
+3 -2
Cargo.lock
··· 5733 5733 5734 5734 [[package]] 5735 5735 name = "tokio-util" 5736 - version = "0.7.17" 5736 + version = "0.7.18" 5737 5737 source = "registry+https://github.com/rust-lang/crates.io-index" 5738 - checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" 5738 + checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" 5739 5739 dependencies = [ 5740 5740 "bytes", 5741 5741 "futures-core", ··· 6115 6115 "thiserror 2.0.17", 6116 6116 "tokio", 6117 6117 "tokio-tungstenite", 6118 + "tokio-util", 6118 6119 "tower", 6119 6120 "tower-http", 6120 6121 "tower-layer",
+1
Cargo.toml
··· 87 87 subtle = "2.5" 88 88 thiserror = "2.0" 89 89 tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "time", "signal", "process"] } 90 + tokio-util = "0.7.18" 90 91 tokio-tungstenite = { version = "0.28", features = ["native-tls"] } 91 92 totp-rs = { version = "5", features = ["qr"] } 92 93 tower = "0.5"
+51
crates/tranquil-db-traits/src/channel_verification.rs
··· 1 + use crate::CommsChannel; 2 + use serde::{Deserialize, Serialize}; 3 + 4 + #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] 5 + pub struct ChannelVerificationStatus { 6 + pub email: bool, 7 + pub discord: bool, 8 + pub telegram: bool, 9 + pub signal: bool, 10 + } 11 + 12 + impl ChannelVerificationStatus { 13 + pub fn new(email: bool, discord: bool, telegram: bool, signal: bool) -> Self { 14 + Self { 15 + email, 16 + discord, 17 + telegram, 18 + signal, 19 + } 20 + } 21 + 22 + pub fn has_any_verified(&self) -> bool { 23 + self.email || self.discord || self.telegram || self.signal 24 + } 25 + 26 + pub fn verified_channels(&self) -> Vec<CommsChannel> { 27 + let mut channels = Vec::with_capacity(4); 28 + if self.email { 29 + channels.push(CommsChannel::Email); 30 + } 31 + if self.discord { 32 + channels.push(CommsChannel::Discord); 33 + } 34 + if self.telegram { 35 + channels.push(CommsChannel::Telegram); 36 + } 37 + if self.signal { 38 + channels.push(CommsChannel::Signal); 39 + } 40 + channels 41 + } 42 + 43 + pub fn is_verified(&self, channel: CommsChannel) -> bool { 44 + match channel { 45 + CommsChannel::Email => self.email, 46 + CommsChannel::Discord => self.discord, 47 + CommsChannel::Telegram => self.telegram, 48 + CommsChannel::Signal => self.signal, 49 + } 50 + } 51 + }
+6 -5
crates/tranquil-db-traits/src/delegation.rs
··· 5 5 use uuid::Uuid; 6 6 7 7 use crate::DbError; 8 + use crate::scope::DbScope; 8 9 9 10 #[derive(Debug, Clone, Serialize, Deserialize)] 10 11 pub struct DelegationGrant { 11 12 pub id: Uuid, 12 13 pub delegated_did: Did, 13 14 pub controller_did: Did, 14 - pub granted_scopes: String, 15 + pub granted_scopes: DbScope, 15 16 pub granted_at: DateTime<Utc>, 16 17 pub granted_by: Did, 17 18 pub revoked_at: Option<DateTime<Utc>>, ··· 22 23 pub struct DelegatedAccountInfo { 23 24 pub did: Did, 24 25 pub handle: Handle, 25 - pub granted_scopes: String, 26 + pub granted_scopes: DbScope, 26 27 pub granted_at: DateTime<Utc>, 27 28 } 28 29 ··· 30 31 pub struct ControllerInfo { 31 32 pub did: Did, 32 33 pub handle: Handle, 33 - pub granted_scopes: String, 34 + pub granted_scopes: DbScope, 34 35 pub granted_at: DateTime<Utc>, 35 36 pub is_active: bool, 36 37 } ··· 67 68 &self, 68 69 delegated_did: &Did, 69 70 controller_did: &Did, 70 - granted_scopes: &str, 71 + granted_scopes: &DbScope, 71 72 granted_by: &Did, 72 73 ) -> Result<Uuid, DbError>; 73 74 ··· 82 83 &self, 83 84 delegated_did: &Did, 84 85 controller_did: &Did, 85 - new_scopes: &str, 86 + new_scopes: &DbScope, 86 87 ) -> Result<bool, DbError>; 87 88 88 89 async fn get_delegation(
+63 -7
crates/tranquil-db-traits/src/infra.rs
··· 5 5 use uuid::Uuid; 6 6 7 7 use crate::DbError; 8 + use crate::invite_code::{InviteCodeError, ValidatedInviteCode}; 8 9 9 10 #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] 10 11 pub enum InviteCodeSortOrder { ··· 13 14 Usage, 14 15 } 15 16 17 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)] 18 + pub enum InviteCodeState { 19 + #[default] 20 + Active, 21 + Disabled, 22 + } 23 + 24 + impl InviteCodeState { 25 + pub fn is_active(self) -> bool { 26 + matches!(self, Self::Active) 27 + } 28 + 29 + pub fn is_disabled(self) -> bool { 30 + matches!(self, Self::Disabled) 31 + } 32 + } 33 + 34 + impl From<bool> for InviteCodeState { 35 + fn from(disabled: bool) -> Self { 36 + if disabled { 37 + Self::Disabled 38 + } else { 39 + Self::Active 40 + } 41 + } 42 + } 43 + 44 + impl From<Option<bool>> for InviteCodeState { 45 + fn from(disabled: Option<bool>) -> Self { 46 + Self::from(disabled.unwrap_or(false)) 47 + } 48 + } 49 + 50 + impl From<InviteCodeState> for bool { 51 + fn from(state: InviteCodeState) -> Self { 52 + matches!(state, InviteCodeState::Disabled) 53 + } 54 + } 55 + 16 56 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 17 57 #[sqlx(type_name = "comms_channel", rename_all = "snake_case")] 18 58 pub enum CommsChannel { ··· 72 112 pub struct InviteCodeInfo { 73 113 pub code: String, 74 114 pub available_uses: i32, 75 - pub disabled: bool, 115 + pub state: InviteCodeState, 76 116 pub for_account: Option<Did>, 77 117 pub created_at: DateTime<Utc>, 78 118 pub created_by: Option<Did>, ··· 95 135 pub created_at: DateTime<Utc>, 96 136 } 97 137 138 + impl InviteCodeRow { 139 + pub fn state(&self) -> InviteCodeState { 140 + InviteCodeState::from(self.disabled) 141 + } 142 + } 143 + 98 144 #[derive(Debug, Clone)] 99 145 pub struct ReservedSigningKey { 100 146 pub id: Uuid, ··· 148 194 149 195 async fn get_invite_code_available_uses(&self, code: &str) -> Result<Option<i32>, DbError>; 150 196 151 - async fn is_invite_code_valid(&self, code: &str) -> Result<bool, DbError>; 197 + async fn validate_invite_code<'a>( 198 + &self, 199 + code: &'a str, 200 + ) -> Result<ValidatedInviteCode<'a>, InviteCodeError>; 152 201 153 - async fn decrement_invite_code_uses(&self, code: &str) -> Result<(), DbError>; 202 + async fn decrement_invite_code_uses( 203 + &self, 204 + code: &ValidatedInviteCode<'_>, 205 + ) -> Result<(), DbError>; 154 206 155 - async fn record_invite_code_use(&self, code: &str, used_by_user: Uuid) -> Result<(), DbError>; 207 + async fn record_invite_code_use( 208 + &self, 209 + code: &ValidatedInviteCode<'_>, 210 + used_by_user: Uuid, 211 + ) -> Result<(), DbError>; 156 212 157 213 async fn get_invite_codes_for_account( 158 214 &self, ··· 317 373 #[derive(Debug, Clone)] 318 374 pub struct NotificationHistoryRow { 319 375 pub created_at: DateTime<Utc>, 320 - pub channel: String, 321 - pub comms_type: String, 322 - pub status: String, 376 + pub channel: CommsChannel, 377 + pub comms_type: CommsType, 378 + pub status: CommsStatus, 323 379 pub subject: Option<String>, 324 380 pub body: String, 325 381 }
+56
crates/tranquil-db-traits/src/invite_code.rs
··· 1 + use std::marker::PhantomData; 2 + 3 + use crate::DbError; 4 + 5 + #[derive(Debug)] 6 + pub struct ValidatedInviteCode<'a> { 7 + code: &'a str, 8 + _marker: PhantomData<&'a ()>, 9 + } 10 + 11 + impl<'a> ValidatedInviteCode<'a> { 12 + pub fn new_validated(code: &'a str) -> Self { 13 + Self { 14 + code, 15 + _marker: PhantomData, 16 + } 17 + } 18 + 19 + pub fn code(&self) -> &str { 20 + self.code 21 + } 22 + } 23 + 24 + #[derive(Debug)] 25 + pub enum InviteCodeError { 26 + NotFound, 27 + ExhaustedUses, 28 + Disabled, 29 + DatabaseError(DbError), 30 + } 31 + 32 + impl std::fmt::Display for InviteCodeError { 33 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 34 + match self { 35 + Self::NotFound => write!(f, "Invite code not found"), 36 + Self::ExhaustedUses => write!(f, "Invite code has no remaining uses"), 37 + Self::Disabled => write!(f, "Invite code is disabled"), 38 + Self::DatabaseError(e) => write!(f, "Database error: {}", e), 39 + } 40 + } 41 + } 42 + 43 + impl std::error::Error for InviteCodeError { 44 + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 45 + match self { 46 + Self::DatabaseError(e) => Some(e), 47 + _ => None, 48 + } 49 + } 50 + } 51 + 52 + impl From<DbError> for InviteCodeError { 53 + fn from(e: DbError) -> Self { 54 + Self::DatabaseError(e) 55 + } 56 + }
+30 -19
crates/tranquil-db-traits/src/lib.rs
··· 1 1 mod backlink; 2 2 mod backup; 3 3 mod blob; 4 + mod channel_verification; 4 5 mod delegation; 5 6 mod error; 6 7 mod infra; 8 + mod invite_code; 7 9 mod oauth; 8 10 mod repo; 11 + mod scope; 12 + mod sequence; 9 13 mod session; 10 14 mod sso; 11 15 mod user; ··· 16 20 OldBackupInfo, UserBackupInfo, 17 21 }; 18 22 pub use blob::{BlobForExport, BlobMetadata, BlobRepository, BlobWithTakedown, MissingBlobInfo}; 23 + pub use channel_verification::ChannelVerificationStatus; 19 24 pub use delegation::{ 20 25 AuditLogEntry, ControllerInfo, DelegatedAccountInfo, DelegationActionType, DelegationGrant, 21 26 DelegationRepository, ··· 23 28 pub use error::DbError; 24 29 pub use infra::{ 25 30 AdminAccountInfo, CommsChannel, CommsStatus, CommsType, DeletionRequest, InfraRepository, 26 - InviteCodeInfo, InviteCodeRow, InviteCodeSortOrder, InviteCodeUse, NotificationHistoryRow, 27 - QueuedComms, ReservedSigningKey, 31 + InviteCodeInfo, InviteCodeRow, InviteCodeSortOrder, InviteCodeState, InviteCodeUse, 32 + NotificationHistoryRow, QueuedComms, ReservedSigningKey, 28 33 }; 34 + pub use invite_code::{InviteCodeError, ValidatedInviteCode}; 29 35 pub use oauth::{ 30 36 DeviceAccountRow, DeviceTrustInfo, OAuthRepository, OAuthSessionListItem, RefreshTokenLookup, 31 - ScopePreference, TrustedDeviceRow, TwoFactorChallenge, 37 + ScopePreference, TokenFamilyId, TrustedDeviceRow, TwoFactorChallenge, 32 38 }; 33 39 pub use repo::{ 34 - ApplyCommitError, ApplyCommitInput, ApplyCommitResult, BrokenGenesisCommit, CommitEventData, 35 - EventBlocksCids, FullRecordInfo, ImportBlock, ImportRecord, ImportRepoError, RecordDelete, 36 - RecordInfo, RecordUpsert, RecordWithTakedown, RepoAccountInfo, RepoEventNotifier, 37 - RepoEventReceiver, RepoInfo, RepoListItem, RepoRepository, RepoSeqEvent, RepoWithoutRev, 38 - SequencedEvent, UserNeedingRecordBlobsBackfill, UserWithoutBlocks, 40 + AccountStatus, ApplyCommitError, ApplyCommitInput, ApplyCommitResult, BrokenGenesisCommit, 41 + CommitEventData, EventBlocksCids, FullRecordInfo, ImportBlock, ImportRecord, ImportRepoError, 42 + RecordDelete, RecordInfo, RecordUpsert, RecordWithTakedown, RepoAccountInfo, RepoEventNotifier, 43 + RepoEventReceiver, RepoEventType, RepoInfo, RepoListItem, RepoRepository, RepoSeqEvent, 44 + RepoWithoutRev, SequencedEvent, UserNeedingRecordBlobsBackfill, UserWithoutBlocks, 39 45 }; 46 + pub use scope::{DbScope, InvalidScopeError}; 47 + pub use sequence::{SequenceNumber, deserialize_optional_sequence}; 40 48 pub use session::{ 41 - AppPasswordCreate, AppPasswordRecord, RefreshSessionResult, SessionForRefresh, SessionListItem, 42 - SessionMfaStatus, SessionRefreshData, SessionRepository, SessionToken, SessionTokenCreate, 49 + AppPasswordCreate, AppPasswordPrivilege, AppPasswordRecord, LoginType, RefreshSessionResult, 50 + SessionForRefresh, SessionId, SessionListItem, SessionMfaStatus, SessionRefreshData, 51 + SessionRepository, SessionToken, SessionTokenCreate, 43 52 }; 44 53 pub use sso::{ 45 - ExternalIdentity, SsoAuthState, SsoPendingRegistration, SsoProviderType, SsoRepository, 54 + ExternalEmail, ExternalIdentity, ExternalUserId, ExternalUsername, SsoAction, SsoAuthState, 55 + SsoPendingRegistration, SsoProviderType, SsoRepository, 46 56 }; 47 57 pub use user::{ 48 - AccountSearchResult, CompletePasskeySetupInput, CreateAccountError, 58 + AccountSearchResult, AccountType, CompletePasskeySetupInput, CreateAccountError, 49 59 CreateDelegatedAccountInput, CreatePasskeyAccountInput, CreatePasswordAccountInput, 50 60 CreatePasswordAccountResult, CreateSsoAccountInput, DidWebOverrides, 51 61 MigrationReactivationError, MigrationReactivationInput, NotificationPrefs, OAuthTokenWithUser, 52 62 PasswordResetResult, ReactivatedAccountInfo, RecoverPasskeyAccountInput, 53 63 RecoverPasskeyAccountResult, ScheduledDeletionAccount, StoredBackupCode, StoredPasskey, 54 - TotpRecord, User2faStatus, UserAuthInfo, UserCommsPrefs, UserConfirmSignup, UserDidWebInfo, 55 - UserEmailInfo, UserForDeletion, UserForDidDoc, UserForDidDocBuild, UserForPasskeyRecovery, 56 - UserForPasskeySetup, UserForRecovery, UserForVerification, UserIdAndHandle, 57 - UserIdAndPasswordHash, UserIdHandleEmail, UserInfoForAuth, UserKeyInfo, UserKeyWithId, 58 - UserLegacyLoginPref, UserLoginCheck, UserLoginFull, UserLoginInfo, UserPasswordInfo, 59 - UserRepository, UserResendVerification, UserResetCodeInfo, UserRow, UserSessionInfo, 60 - UserStatus, UserVerificationInfo, UserWithKey, 64 + TotpRecord, TotpRecordState, UnverifiedTotpRecord, User2faStatus, UserAuthInfo, UserCommsPrefs, 65 + UserConfirmSignup, UserDidWebInfo, UserEmailInfo, UserForDeletion, UserForDidDoc, 66 + UserForDidDocBuild, UserForPasskeyRecovery, UserForPasskeySetup, UserForRecovery, 67 + UserForVerification, UserIdAndHandle, UserIdAndPasswordHash, UserIdHandleEmail, 68 + UserInfoForAuth, UserKeyInfo, UserKeyWithId, UserLegacyLoginPref, UserLoginCheck, 69 + UserLoginFull, UserLoginInfo, UserPasswordInfo, UserRepository, UserResendVerification, 70 + UserResetCodeInfo, UserRow, UserSessionInfo, UserStatus, UserVerificationInfo, UserWithKey, 71 + VerifiedTotpRecord, 61 72 };
+47 -12
crates/tranquil-db-traits/src/oauth.rs
··· 10 10 11 11 use crate::DbError; 12 12 13 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 14 + pub struct TokenFamilyId(i32); 15 + 16 + impl TokenFamilyId { 17 + pub fn new(id: i32) -> Self { 18 + Self(id) 19 + } 20 + 21 + pub fn as_i32(self) -> i32 { 22 + self.0 23 + } 24 + } 25 + 26 + impl From<i32> for TokenFamilyId { 27 + fn from(id: i32) -> Self { 28 + Self(id) 29 + } 30 + } 31 + 32 + impl From<TokenFamilyId> for i32 { 33 + fn from(id: TokenFamilyId) -> Self { 34 + id.0 35 + } 36 + } 37 + 38 + impl std::fmt::Display for TokenFamilyId { 39 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 40 + write!(f, "{}", self.0) 41 + } 42 + } 43 + 13 44 #[derive(Debug, Clone, Serialize, Deserialize)] 14 45 pub struct ScopePreference { 15 46 pub scope: String, ··· 53 84 54 85 #[derive(Debug, Clone)] 55 86 pub struct OAuthSessionListItem { 56 - pub id: i32, 87 + pub id: TokenFamilyId, 57 88 pub token_id: TokenId, 58 89 pub created_at: DateTime<Utc>, 59 90 pub expires_at: DateTime<Utc>, ··· 62 93 63 94 pub enum RefreshTokenLookup { 64 95 Valid { 65 - db_id: i32, 96 + db_id: TokenFamilyId, 66 97 token_data: TokenData, 67 98 }, 68 99 InGracePeriod { 69 - db_id: i32, 100 + db_id: TokenFamilyId, 70 101 token_data: TokenData, 71 102 rotated_at: DateTime<Utc>, 72 103 }, 73 104 Used { 74 - original_token_id: i32, 105 + original_token_id: TokenFamilyId, 75 106 }, 76 107 Expired { 77 - db_id: i32, 108 + db_id: TokenFamilyId, 78 109 }, 79 110 NotFound, 80 111 } ··· 93 124 94 125 #[async_trait] 95 126 pub trait OAuthRepository: Send + Sync { 96 - async fn create_token(&self, data: &TokenData) -> Result<i32, DbError>; 127 + async fn create_token(&self, data: &TokenData) -> Result<TokenFamilyId, DbError>; 97 128 async fn get_token_by_id(&self, token_id: &TokenId) -> Result<Option<TokenData>, DbError>; 98 129 async fn get_token_by_refresh_token( 99 130 &self, 100 131 refresh_token: &RefreshToken, 101 - ) -> Result<Option<(i32, TokenData)>, DbError>; 132 + ) -> Result<Option<(TokenFamilyId, TokenData)>, DbError>; 102 133 async fn get_token_by_previous_refresh_token( 103 134 &self, 104 135 refresh_token: &RefreshToken, 105 - ) -> Result<Option<(i32, TokenData)>, DbError>; 136 + ) -> Result<Option<(TokenFamilyId, TokenData)>, DbError>; 106 137 async fn rotate_token( 107 138 &self, 108 - old_db_id: i32, 139 + old_db_id: TokenFamilyId, 109 140 new_refresh_token: &RefreshToken, 110 141 new_expires_at: DateTime<Utc>, 111 142 ) -> Result<(), DbError>; 112 143 async fn check_refresh_token_used( 113 144 &self, 114 145 refresh_token: &RefreshToken, 115 - ) -> Result<Option<i32>, DbError>; 146 + ) -> Result<Option<TokenFamilyId>, DbError>; 116 147 async fn delete_token(&self, token_id: &TokenId) -> Result<(), DbError>; 117 - async fn delete_token_family(&self, db_id: i32) -> Result<(), DbError>; 148 + async fn delete_token_family(&self, db_id: TokenFamilyId) -> Result<(), DbError>; 118 149 async fn list_tokens_for_user(&self, did: &Did) -> Result<Vec<TokenData>, DbError>; 119 150 async fn count_tokens_for_user(&self, did: &Did) -> Result<i64, DbError>; 120 151 async fn delete_oldest_tokens_for_user( ··· 274 305 ) -> Result<(), DbError>; 275 306 276 307 async fn list_sessions_by_did(&self, did: &Did) -> Result<Vec<OAuthSessionListItem>, DbError>; 277 - async fn delete_session_by_id(&self, session_id: i32, did: &Did) -> Result<u64, DbError>; 308 + async fn delete_session_by_id( 309 + &self, 310 + session_id: TokenFamilyId, 311 + did: &Did, 312 + ) -> Result<u64, DbError>; 278 313 async fn delete_sessions_by_did(&self, did: &Did) -> Result<u64, DbError>; 279 314 async fn delete_sessions_by_did_except( 280 315 &self,
+146 -24
crates/tranquil-db-traits/src/repo.rs
··· 5 5 use uuid::Uuid; 6 6 7 7 use crate::DbError; 8 + use crate::sequence::SequenceNumber; 9 + 10 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 11 + #[sqlx(type_name = "text", rename_all = "snake_case")] 12 + #[serde(rename_all = "snake_case")] 13 + pub enum RepoEventType { 14 + Commit, 15 + Identity, 16 + Account, 17 + Sync, 18 + } 19 + 20 + impl RepoEventType { 21 + pub fn as_str(&self) -> &'static str { 22 + match self { 23 + Self::Commit => "commit", 24 + Self::Identity => "identity", 25 + Self::Account => "account", 26 + Self::Sync => "sync", 27 + } 28 + } 29 + } 30 + 31 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 32 + #[sqlx(type_name = "text", rename_all = "lowercase")] 33 + #[serde(rename_all = "lowercase")] 34 + pub enum AccountStatus { 35 + Active, 36 + Takendown, 37 + Suspended, 38 + Deactivated, 39 + Deleted, 40 + } 41 + 42 + impl AccountStatus { 43 + pub fn as_str(&self) -> &'static str { 44 + match self { 45 + Self::Active => "active", 46 + Self::Takendown => "takendown", 47 + Self::Suspended => "suspended", 48 + Self::Deactivated => "deactivated", 49 + Self::Deleted => "deleted", 50 + } 51 + } 52 + 53 + pub fn for_firehose(&self) -> Option<&'static str> { 54 + match self { 55 + Self::Active => None, 56 + other => Some(other.as_str()), 57 + } 58 + } 59 + 60 + pub fn parse(s: &str) -> Option<Self> { 61 + match s.to_lowercase().as_str() { 62 + "active" => Some(Self::Active), 63 + "takendown" => Some(Self::Takendown), 64 + "suspended" => Some(Self::Suspended), 65 + "deactivated" => Some(Self::Deactivated), 66 + "deleted" => Some(Self::Deleted), 67 + _ => None, 68 + } 69 + } 70 + 71 + pub fn is_active(&self) -> bool { 72 + matches!(self, Self::Active) 73 + } 74 + 75 + pub fn is_takendown(&self) -> bool { 76 + matches!(self, Self::Takendown) 77 + } 78 + 79 + pub fn is_deactivated(&self) -> bool { 80 + matches!(self, Self::Deactivated) 81 + } 82 + 83 + pub fn is_suspended(&self) -> bool { 84 + matches!(self, Self::Suspended) 85 + } 86 + 87 + pub fn is_deleted(&self) -> bool { 88 + matches!(self, Self::Deleted) 89 + } 90 + 91 + pub fn allows_read(&self) -> bool { 92 + matches!(self, Self::Active | Self::Deactivated) 93 + } 94 + 95 + pub fn allows_write(&self) -> bool { 96 + matches!(self, Self::Active) 97 + } 98 + 99 + pub fn from_db_fields( 100 + takedown_ref: Option<&str>, 101 + deactivated_at: Option<DateTime<Utc>>, 102 + ) -> Self { 103 + if takedown_ref.is_some() { 104 + Self::Takendown 105 + } else if deactivated_at.is_some() { 106 + Self::Deactivated 107 + } else { 108 + Self::Active 109 + } 110 + } 111 + } 112 + 113 + impl std::fmt::Display for AccountStatus { 114 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 115 + f.write_str(self.as_str()) 116 + } 117 + } 8 118 9 119 #[derive(Debug, Clone, Serialize, Deserialize)] 10 120 pub struct RepoAccountInfo { ··· 49 159 50 160 #[derive(Debug, Clone)] 51 161 pub struct BrokenGenesisCommit { 52 - pub seq: i64, 162 + pub seq: SequenceNumber, 53 163 pub did: Did, 54 164 pub commit_cid: Option<CidLink>, 55 165 } ··· 69 179 70 180 #[derive(Debug, Clone, Serialize, Deserialize)] 71 181 pub struct RepoSeqEvent { 72 - pub seq: i64, 182 + pub seq: SequenceNumber, 73 183 } 74 184 75 185 #[derive(Debug, Clone, Serialize, Deserialize)] 76 186 pub struct SequencedEvent { 77 - pub seq: i64, 187 + pub seq: SequenceNumber, 78 188 pub did: Did, 79 189 pub created_at: DateTime<Utc>, 80 - pub event_type: String, 190 + pub event_type: RepoEventType, 81 191 pub commit_cid: Option<CidLink>, 82 192 pub prev_cid: Option<CidLink>, 83 193 pub prev_data_cid: Option<CidLink>, ··· 86 196 pub blocks_cids: Option<Vec<String>>, 87 197 pub handle: Option<Handle>, 88 198 pub active: Option<bool>, 89 - pub status: Option<String>, 199 + pub status: Option<AccountStatus>, 90 200 pub rev: Option<String>, 91 201 } 92 202 93 203 #[derive(Debug, Clone)] 94 204 pub struct CommitEventData { 95 205 pub did: Did, 96 - pub event_type: String, 206 + pub event_type: RepoEventType, 97 207 pub commit_cid: Option<CidLink>, 98 208 pub prev_cid: Option<CidLink>, 99 209 pub ops: Option<serde_json::Value>, ··· 283 393 284 394 async fn count_user_blocks(&self, user_id: Uuid) -> Result<i64, DbError>; 285 395 286 - async fn insert_commit_event(&self, data: &CommitEventData) -> Result<i64, DbError>; 396 + async fn insert_commit_event(&self, data: &CommitEventData) -> Result<SequenceNumber, DbError>; 287 397 288 398 async fn insert_identity_event( 289 399 &self, 290 400 did: &Did, 291 401 handle: Option<&Handle>, 292 - ) -> Result<i64, DbError>; 402 + ) -> Result<SequenceNumber, DbError>; 293 403 294 404 async fn insert_account_event( 295 405 &self, 296 406 did: &Did, 297 - active: bool, 298 - status: Option<&str>, 299 - ) -> Result<i64, DbError>; 407 + status: AccountStatus, 408 + ) -> Result<SequenceNumber, DbError>; 300 409 301 410 async fn insert_sync_event( 302 411 &self, 303 412 did: &Did, 304 413 commit_cid: &CidLink, 305 414 rev: Option<&str>, 306 - ) -> Result<i64, DbError>; 415 + ) -> Result<SequenceNumber, DbError>; 307 416 308 417 async fn insert_genesis_commit_event( 309 418 &self, ··· 311 420 commit_cid: &CidLink, 312 421 mst_root_cid: &CidLink, 313 422 rev: &str, 314 - ) -> Result<i64, DbError>; 423 + ) -> Result<SequenceNumber, DbError>; 315 424 316 - async fn update_seq_blocks_cids(&self, seq: i64, blocks_cids: &[String]) 317 - -> Result<(), DbError>; 425 + async fn update_seq_blocks_cids( 426 + &self, 427 + seq: SequenceNumber, 428 + blocks_cids: &[String], 429 + ) -> Result<(), DbError>; 318 430 319 - async fn delete_sequences_except(&self, did: &Did, keep_seq: i64) -> Result<(), DbError>; 431 + async fn delete_sequences_except( 432 + &self, 433 + did: &Did, 434 + keep_seq: SequenceNumber, 435 + ) -> Result<(), DbError>; 320 436 321 - async fn get_max_seq(&self) -> Result<i64, DbError>; 437 + async fn get_max_seq(&self) -> Result<SequenceNumber, DbError>; 322 438 323 - async fn get_min_seq_since(&self, since: DateTime<Utc>) -> Result<Option<i64>, DbError>; 439 + async fn get_min_seq_since( 440 + &self, 441 + since: DateTime<Utc>, 442 + ) -> Result<Option<SequenceNumber>, DbError>; 324 443 325 444 async fn get_account_with_repo(&self, did: &Did) -> Result<Option<RepoAccountInfo>, DbError>; 326 445 327 446 async fn get_events_since_seq( 328 447 &self, 329 - since_seq: i64, 448 + since_seq: SequenceNumber, 330 449 limit: Option<i64>, 331 450 ) -> Result<Vec<SequencedEvent>, DbError>; 332 451 333 452 async fn get_events_in_seq_range( 334 453 &self, 335 - start_seq: i64, 336 - end_seq: i64, 454 + start_seq: SequenceNumber, 455 + end_seq: SequenceNumber, 337 456 ) -> Result<Vec<SequencedEvent>, DbError>; 338 457 339 - async fn get_event_by_seq(&self, seq: i64) -> Result<Option<SequencedEvent>, DbError>; 458 + async fn get_event_by_seq( 459 + &self, 460 + seq: SequenceNumber, 461 + ) -> Result<Option<SequencedEvent>, DbError>; 340 462 341 463 async fn get_events_since_cursor( 342 464 &self, 343 - cursor: i64, 465 + cursor: SequenceNumber, 344 466 limit: i64, 345 467 ) -> Result<Vec<SequencedEvent>, DbError>; 346 468 ··· 359 481 async fn get_repo_root_cid_by_user_id(&self, user_id: Uuid) 360 482 -> Result<Option<CidLink>, DbError>; 361 483 362 - async fn notify_update(&self, seq: i64) -> Result<(), DbError>; 484 + async fn notify_update(&self, seq: SequenceNumber) -> Result<(), DbError>; 363 485 364 486 async fn import_repo_data( 365 487 &self,
+180
crates/tranquil-db-traits/src/scope.rs
··· 1 + use serde::{Deserialize, Deserializer, Serialize}; 2 + use std::fmt; 3 + 4 + #[derive(Debug, Clone, PartialEq, Eq)] 5 + pub struct DbScope(String); 6 + 7 + impl DbScope { 8 + pub fn new(scope: impl Into<String>) -> Result<Self, InvalidScopeError> { 9 + let scope = scope.into(); 10 + validate_scope_string(&scope)?; 11 + Ok(Self(scope)) 12 + } 13 + 14 + pub fn empty() -> Self { 15 + Self(String::new()) 16 + } 17 + 18 + pub fn from_db(scope: String) -> Self { 19 + match validate_scope_string(&scope) { 20 + Ok(()) => Self(scope), 21 + Err(e) => panic!("corrupted scope data from database: {}", e), 22 + } 23 + } 24 + 25 + pub fn as_str(&self) -> &str { 26 + &self.0 27 + } 28 + 29 + pub fn into_string(self) -> String { 30 + self.0 31 + } 32 + 33 + pub fn is_empty(&self) -> bool { 34 + self.0.is_empty() 35 + } 36 + } 37 + 38 + impl Default for DbScope { 39 + fn default() -> Self { 40 + Self::empty() 41 + } 42 + } 43 + 44 + impl fmt::Display for DbScope { 45 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 46 + write!(f, "{}", self.0) 47 + } 48 + } 49 + 50 + impl AsRef<str> for DbScope { 51 + fn as_ref(&self) -> &str { 52 + &self.0 53 + } 54 + } 55 + 56 + impl Serialize for DbScope { 57 + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 58 + where 59 + S: serde::Serializer, 60 + { 61 + self.0.serialize(serializer) 62 + } 63 + } 64 + 65 + impl<'de> Deserialize<'de> for DbScope { 66 + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> 67 + where 68 + D: Deserializer<'de>, 69 + { 70 + let s = String::deserialize(deserializer)?; 71 + Self::new(s).map_err(serde::de::Error::custom) 72 + } 73 + } 74 + 75 + #[derive(Debug, Clone)] 76 + pub struct InvalidScopeError { 77 + message: String, 78 + } 79 + 80 + impl InvalidScopeError { 81 + pub fn new(message: impl Into<String>) -> Self { 82 + Self { 83 + message: message.into(), 84 + } 85 + } 86 + 87 + pub fn message(&self) -> &str { 88 + &self.message 89 + } 90 + } 91 + 92 + impl fmt::Display for InvalidScopeError { 93 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 94 + write!(f, "{}", self.message) 95 + } 96 + } 97 + 98 + impl std::error::Error for InvalidScopeError {} 99 + 100 + fn validate_scope_string(scopes: &str) -> Result<(), InvalidScopeError> { 101 + if scopes.is_empty() { 102 + return Ok(()); 103 + } 104 + 105 + scopes.split_whitespace().try_for_each(|scope| { 106 + let base = scope.split_once('?').map_or(scope, |(b, _)| b); 107 + if is_valid_scope_prefix(base) { 108 + Ok(()) 109 + } else { 110 + Err(InvalidScopeError::new(format!("Invalid scope: {}", scope))) 111 + } 112 + }) 113 + } 114 + 115 + fn is_valid_scope_prefix(base: &str) -> bool { 116 + const VALID_PREFIXES: [&str; 8] = [ 117 + "atproto", 118 + "repo:", 119 + "blob:", 120 + "rpc:", 121 + "account:", 122 + "identity:", 123 + "transition:", 124 + "include:", 125 + ]; 126 + 127 + VALID_PREFIXES 128 + .iter() 129 + .any(|prefix| base == prefix.trim_end_matches(':') || base.starts_with(prefix)) 130 + } 131 + 132 + #[cfg(test)] 133 + mod tests { 134 + use super::*; 135 + 136 + #[test] 137 + fn test_valid_scopes() { 138 + assert!(DbScope::new("atproto").is_ok()); 139 + assert!(DbScope::new("repo:*").is_ok()); 140 + assert!(DbScope::new("blob:*/*").is_ok()); 141 + assert!(DbScope::new("repo:* blob:*/*").is_ok()); 142 + assert!(DbScope::new("").is_ok()); 143 + assert!(DbScope::new("account:email?action=read").is_ok()); 144 + assert!(DbScope::new("identity:handle").is_ok()); 145 + assert!(DbScope::new("transition:generic").is_ok()); 146 + assert!(DbScope::new("include:app.bsky.authFullApp").is_ok()); 147 + } 148 + 149 + #[test] 150 + fn test_invalid_scopes() { 151 + assert!(DbScope::new("invalid:scope").is_err()); 152 + assert!(DbScope::new("garbage").is_err()); 153 + assert!(DbScope::new("repo:* invalid:scope").is_err()); 154 + } 155 + 156 + #[test] 157 + fn test_empty_scope() { 158 + let scope = DbScope::empty(); 159 + assert!(scope.is_empty()); 160 + assert_eq!(scope.as_str(), ""); 161 + } 162 + 163 + #[test] 164 + fn test_display() { 165 + let scope = DbScope::new("repo:*").unwrap(); 166 + assert_eq!(format!("{}", scope), "repo:*"); 167 + } 168 + 169 + #[test] 170 + #[should_panic(expected = "corrupted scope data from database")] 171 + fn test_from_db_panics_on_corrupted_data() { 172 + DbScope::from_db("totally_invalid_garbage_scope".to_string()); 173 + } 174 + 175 + #[test] 176 + fn test_from_db_accepts_valid_data() { 177 + let scope = DbScope::from_db("repo:* blob:*/*".to_string()); 178 + assert_eq!(scope.as_str(), "repo:* blob:*/*"); 179 + } 180 + }
+73
crates/tranquil-db-traits/src/sequence.rs
··· 1 + use serde::{Deserialize, Deserializer, Serialize, Serializer}; 2 + use std::fmt; 3 + 4 + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] 5 + pub struct SequenceNumber(i64); 6 + 7 + impl SequenceNumber { 8 + pub const ZERO: Self = Self(0); 9 + pub const UNSET: Self = Self(-1); 10 + 11 + pub fn new(n: i64) -> Option<Self> { 12 + if n >= 0 { Some(Self(n)) } else { None } 13 + } 14 + 15 + pub fn from_raw(n: i64) -> Self { 16 + Self(n) 17 + } 18 + 19 + pub fn as_i64(&self) -> i64 { 20 + self.0 21 + } 22 + 23 + pub fn is_valid(&self) -> bool { 24 + self.0 >= 0 25 + } 26 + } 27 + 28 + impl fmt::Display for SequenceNumber { 29 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 30 + write!(f, "{}", self.0) 31 + } 32 + } 33 + 34 + impl From<i64> for SequenceNumber { 35 + fn from(n: i64) -> Self { 36 + Self(n) 37 + } 38 + } 39 + 40 + impl From<SequenceNumber> for i64 { 41 + fn from(seq: SequenceNumber) -> Self { 42 + seq.0 43 + } 44 + } 45 + 46 + impl Serialize for SequenceNumber { 47 + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 48 + where 49 + S: Serializer, 50 + { 51 + self.0.serialize(serializer) 52 + } 53 + } 54 + 55 + impl<'de> Deserialize<'de> for SequenceNumber { 56 + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> 57 + where 58 + D: Deserializer<'de>, 59 + { 60 + let n = i64::deserialize(deserializer)?; 61 + Ok(Self(n)) 62 + } 63 + } 64 + 65 + pub fn deserialize_optional_sequence<'de, D>( 66 + deserializer: D, 67 + ) -> Result<Option<SequenceNumber>, D::Error> 68 + where 69 + D: Deserializer<'de>, 70 + { 71 + let opt: Option<i64> = Option::deserialize(deserializer)?; 72 + Ok(opt.map(SequenceNumber::from_raw)) 73 + }
+107 -15
crates/tranquil-db-traits/src/session.rs
··· 5 5 6 6 use crate::DbError; 7 7 8 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] 9 + pub enum LoginType { 10 + #[default] 11 + Modern, 12 + Legacy, 13 + } 14 + 15 + impl LoginType { 16 + pub fn is_legacy(self) -> bool { 17 + matches!(self, Self::Legacy) 18 + } 19 + 20 + pub fn is_modern(self) -> bool { 21 + matches!(self, Self::Modern) 22 + } 23 + } 24 + 25 + impl From<bool> for LoginType { 26 + fn from(legacy: bool) -> Self { 27 + if legacy { Self::Legacy } else { Self::Modern } 28 + } 29 + } 30 + 31 + impl From<LoginType> for bool { 32 + fn from(lt: LoginType) -> Self { 33 + matches!(lt, LoginType::Legacy) 34 + } 35 + } 36 + 37 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] 38 + pub enum AppPasswordPrivilege { 39 + #[default] 40 + Standard, 41 + Privileged, 42 + } 43 + 44 + impl AppPasswordPrivilege { 45 + pub fn is_privileged(self) -> bool { 46 + matches!(self, Self::Privileged) 47 + } 48 + } 49 + 50 + impl From<bool> for AppPasswordPrivilege { 51 + fn from(privileged: bool) -> Self { 52 + if privileged { 53 + Self::Privileged 54 + } else { 55 + Self::Standard 56 + } 57 + } 58 + } 59 + 60 + impl From<AppPasswordPrivilege> for bool { 61 + fn from(p: AppPasswordPrivilege) -> Self { 62 + matches!(p, AppPasswordPrivilege::Privileged) 63 + } 64 + } 65 + 66 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 67 + pub struct SessionId(i32); 68 + 69 + impl SessionId { 70 + pub fn new(id: i32) -> Self { 71 + Self(id) 72 + } 73 + 74 + pub fn as_i32(self) -> i32 { 75 + self.0 76 + } 77 + } 78 + 79 + impl From<i32> for SessionId { 80 + fn from(id: i32) -> Self { 81 + Self(id) 82 + } 83 + } 84 + 85 + impl From<SessionId> for i32 { 86 + fn from(id: SessionId) -> Self { 87 + id.0 88 + } 89 + } 90 + 91 + impl std::fmt::Display for SessionId { 92 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 93 + write!(f, "{}", self.0) 94 + } 95 + } 96 + 8 97 #[derive(Debug, Clone)] 9 98 pub struct SessionToken { 10 - pub id: i32, 99 + pub id: SessionId, 11 100 pub did: Did, 12 101 pub access_jti: String, 13 102 pub refresh_jti: String, 14 103 pub access_expires_at: DateTime<Utc>, 15 104 pub refresh_expires_at: DateTime<Utc>, 16 - pub legacy_login: bool, 105 + pub login_type: LoginType, 17 106 pub mfa_verified: bool, 18 107 pub scope: Option<String>, 19 108 pub controller_did: Option<Did>, ··· 29 118 pub refresh_jti: String, 30 119 pub access_expires_at: DateTime<Utc>, 31 120 pub refresh_expires_at: DateTime<Utc>, 32 - pub legacy_login: bool, 121 + pub login_type: LoginType, 33 122 pub mfa_verified: bool, 34 123 pub scope: Option<String>, 35 124 pub controller_did: Option<Did>, ··· 38 127 39 128 #[derive(Debug, Clone)] 40 129 pub struct SessionForRefresh { 41 - pub id: i32, 130 + pub id: SessionId, 42 131 pub did: Did, 43 132 pub scope: Option<String>, 44 133 pub controller_did: Option<Did>, ··· 48 137 49 138 #[derive(Debug, Clone)] 50 139 pub struct SessionListItem { 51 - pub id: i32, 140 + pub id: SessionId, 52 141 pub access_jti: String, 53 142 pub created_at: DateTime<Utc>, 54 143 pub refresh_expires_at: DateTime<Utc>, ··· 61 150 pub name: String, 62 151 pub password_hash: String, 63 152 pub created_at: DateTime<Utc>, 64 - pub privileged: bool, 153 + pub privilege: AppPasswordPrivilege, 65 154 pub scopes: Option<String>, 66 155 pub created_by_controller_did: Option<Did>, 67 156 } ··· 71 160 pub user_id: Uuid, 72 161 pub name: String, 73 162 pub password_hash: String, 74 - pub privileged: bool, 163 + pub privilege: AppPasswordPrivilege, 75 164 pub scopes: Option<String>, 76 165 pub created_by_controller_did: Option<Did>, 77 166 } 78 167 79 168 #[derive(Debug, Clone)] 80 169 pub struct SessionMfaStatus { 81 - pub legacy_login: bool, 170 + pub login_type: LoginType, 82 171 pub mfa_verified: bool, 83 172 pub last_reauth_at: Option<DateTime<Utc>>, 84 173 } ··· 93 182 #[derive(Debug, Clone)] 94 183 pub struct SessionRefreshData { 95 184 pub old_refresh_jti: String, 96 - pub session_id: i32, 185 + pub session_id: SessionId, 97 186 pub new_access_jti: String, 98 187 pub new_refresh_jti: String, 99 188 pub new_access_expires_at: DateTime<Utc>, ··· 102 191 103 192 #[async_trait] 104 193 pub trait SessionRepository: Send + Sync { 105 - async fn create_session(&self, data: &SessionTokenCreate) -> Result<i32, DbError>; 194 + async fn create_session(&self, data: &SessionTokenCreate) -> Result<SessionId, DbError>; 106 195 107 196 async fn get_session_by_access_jti( 108 197 &self, ··· 116 205 117 206 async fn update_session_tokens( 118 207 &self, 119 - session_id: i32, 208 + session_id: SessionId, 120 209 new_access_jti: &str, 121 210 new_refresh_jti: &str, 122 211 new_access_expires_at: DateTime<Utc>, ··· 125 214 126 215 async fn delete_session_by_access_jti(&self, access_jti: &str) -> Result<u64, DbError>; 127 216 128 - async fn delete_session_by_id(&self, session_id: i32) -> Result<u64, DbError>; 217 + async fn delete_session_by_id(&self, session_id: SessionId) -> Result<u64, DbError>; 129 218 130 219 async fn delete_sessions_by_did(&self, did: &Did) -> Result<u64, DbError>; 131 220 ··· 139 228 140 229 async fn get_session_access_jti_by_id( 141 230 &self, 142 - session_id: i32, 231 + session_id: SessionId, 143 232 did: &Did, 144 233 ) -> Result<Option<String>, DbError>; 145 234 ··· 155 244 app_password_name: &str, 156 245 ) -> Result<Vec<String>, DbError>; 157 246 158 - async fn check_refresh_token_used(&self, refresh_jti: &str) -> Result<Option<i32>, DbError>; 247 + async fn check_refresh_token_used( 248 + &self, 249 + refresh_jti: &str, 250 + ) -> Result<Option<SessionId>, DbError>; 159 251 160 252 async fn mark_refresh_token_used( 161 253 &self, 162 254 refresh_jti: &str, 163 - session_id: i32, 255 + session_id: SessionId, 164 256 ) -> Result<bool, DbError>; 165 257 166 258 async fn list_app_passwords(&self, user_id: Uuid) -> Result<Vec<AppPasswordRecord>, DbError>;
+147 -8
crates/tranquil-db-traits/src/sso.rs
··· 6 6 7 7 use crate::DbError; 8 8 9 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] 10 + pub struct ExternalUserId(String); 11 + 12 + impl ExternalUserId { 13 + pub fn new(id: impl Into<String>) -> Self { 14 + Self(id.into()) 15 + } 16 + 17 + pub fn as_str(&self) -> &str { 18 + &self.0 19 + } 20 + 21 + pub fn into_inner(self) -> String { 22 + self.0 23 + } 24 + } 25 + 26 + impl std::fmt::Display for ExternalUserId { 27 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 28 + write!(f, "{}", self.0) 29 + } 30 + } 31 + 32 + impl From<String> for ExternalUserId { 33 + fn from(s: String) -> Self { 34 + Self(s) 35 + } 36 + } 37 + 38 + impl From<ExternalUserId> for String { 39 + fn from(id: ExternalUserId) -> Self { 40 + id.0 41 + } 42 + } 43 + 44 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] 45 + pub struct ExternalUsername(String); 46 + 47 + impl ExternalUsername { 48 + pub fn new(username: impl Into<String>) -> Self { 49 + Self(username.into()) 50 + } 51 + 52 + pub fn as_str(&self) -> &str { 53 + &self.0 54 + } 55 + 56 + pub fn into_inner(self) -> String { 57 + self.0 58 + } 59 + } 60 + 61 + impl std::fmt::Display for ExternalUsername { 62 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 63 + write!(f, "{}", self.0) 64 + } 65 + } 66 + 67 + impl From<String> for ExternalUsername { 68 + fn from(s: String) -> Self { 69 + Self(s) 70 + } 71 + } 72 + 73 + impl From<ExternalUsername> for String { 74 + fn from(username: ExternalUsername) -> Self { 75 + username.0 76 + } 77 + } 78 + 79 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] 80 + pub struct ExternalEmail(String); 81 + 82 + impl ExternalEmail { 83 + pub fn new(email: impl Into<String>) -> Self { 84 + Self(email.into()) 85 + } 86 + 87 + pub fn as_str(&self) -> &str { 88 + &self.0 89 + } 90 + 91 + pub fn into_inner(self) -> String { 92 + self.0 93 + } 94 + } 95 + 96 + impl std::fmt::Display for ExternalEmail { 97 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 98 + write!(f, "{}", self.0) 99 + } 100 + } 101 + 102 + impl From<String> for ExternalEmail { 103 + fn from(s: String) -> Self { 104 + Self(s) 105 + } 106 + } 107 + 108 + impl From<ExternalEmail> for String { 109 + fn from(email: ExternalEmail) -> Self { 110 + email.0 111 + } 112 + } 113 + 9 114 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 10 115 #[sqlx(type_name = "sso_provider_type", rename_all = "lowercase")] 11 116 pub enum SsoProviderType { ··· 17 122 Apple, 18 123 } 19 124 125 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 126 + #[sqlx(type_name = "text", rename_all = "lowercase")] 127 + #[serde(rename_all = "lowercase")] 128 + pub enum SsoAction { 129 + Login, 130 + Link, 131 + Register, 132 + } 133 + 134 + impl SsoAction { 135 + pub fn as_str(&self) -> &'static str { 136 + match self { 137 + Self::Login => "login", 138 + Self::Link => "link", 139 + Self::Register => "register", 140 + } 141 + } 142 + 143 + pub fn parse(s: &str) -> Option<Self> { 144 + match s.to_lowercase().as_str() { 145 + "login" => Some(Self::Login), 146 + "link" => Some(Self::Link), 147 + "register" => Some(Self::Register), 148 + _ => None, 149 + } 150 + } 151 + } 152 + 153 + impl std::fmt::Display for SsoAction { 154 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 155 + f.write_str(self.as_str()) 156 + } 157 + } 158 + 20 159 impl SsoProviderType { 21 160 pub fn as_str(&self) -> &'static str { 22 161 match self { ··· 69 208 pub id: Uuid, 70 209 pub did: Did, 71 210 pub provider: SsoProviderType, 72 - pub provider_user_id: String, 73 - pub provider_username: Option<String>, 74 - pub provider_email: Option<String>, 211 + pub provider_user_id: ExternalUserId, 212 + pub provider_username: Option<ExternalUsername>, 213 + pub provider_email: Option<ExternalEmail>, 75 214 pub created_at: DateTime<Utc>, 76 215 pub updated_at: DateTime<Utc>, 77 216 pub last_login_at: Option<DateTime<Utc>>, ··· 82 221 pub state: String, 83 222 pub request_uri: String, 84 223 pub provider: SsoProviderType, 85 - pub action: String, 224 + pub action: SsoAction, 86 225 pub nonce: Option<String>, 87 226 pub code_verifier: Option<String>, 88 227 pub did: Option<Did>, ··· 95 234 pub token: String, 96 235 pub request_uri: String, 97 236 pub provider: SsoProviderType, 98 - pub provider_user_id: String, 99 - pub provider_username: Option<String>, 100 - pub provider_email: Option<String>, 237 + pub provider_user_id: ExternalUserId, 238 + pub provider_username: Option<ExternalUsername>, 239 + pub provider_email: Option<ExternalEmail>, 101 240 pub provider_email_verified: bool, 102 241 pub created_at: DateTime<Utc>, 103 242 pub expires_at: DateTime<Utc>, ··· 140 279 state: &str, 141 280 request_uri: &str, 142 281 provider: SsoProviderType, 143 - action: &str, 282 + action: SsoAction, 144 283 nonce: Option<&str>, 145 284 code_verifier: Option<&str>, 146 285 did: Option<&Did>,
+100 -34
crates/tranquil-db-traits/src/user.rs
··· 1 1 use async_trait::async_trait; 2 2 use chrono::{DateTime, Utc}; 3 + use serde::{Deserialize, Serialize}; 3 4 use tranquil_types::{Did, Handle}; 4 5 use uuid::Uuid; 5 6 6 - use crate::{CommsChannel, DbError, SsoProviderType}; 7 + use crate::{ChannelVerificationStatus, CommsChannel, DbError, SsoProviderType}; 8 + 9 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 10 + #[sqlx(type_name = "account_type", rename_all = "snake_case")] 11 + pub enum AccountType { 12 + Personal, 13 + Delegated, 14 + } 15 + 16 + impl AccountType { 17 + pub fn is_delegated(&self) -> bool { 18 + matches!(self, Self::Delegated) 19 + } 20 + } 7 21 8 22 #[derive(Debug, Clone)] 9 23 pub struct UserRow { ··· 62 76 pub preferred_comms_channel: CommsChannel, 63 77 pub deactivated_at: Option<DateTime<Utc>>, 64 78 pub takedown_ref: Option<String>, 65 - pub email_verified: bool, 66 - pub discord_verified: bool, 67 - pub telegram_verified: bool, 68 - pub signal_verified: bool, 69 - pub account_type: String, 79 + pub channel_verification: ChannelVerificationStatus, 80 + pub account_type: AccountType, 70 81 } 71 82 72 83 #[derive(Debug, Clone)] ··· 74 85 pub id: Uuid, 75 86 pub two_factor_enabled: bool, 76 87 pub preferred_comms_channel: CommsChannel, 77 - pub email_verified: bool, 78 - pub discord_verified: bool, 79 - pub telegram_verified: bool, 80 - pub signal_verified: bool, 88 + pub channel_verification: ChannelVerificationStatus, 81 89 } 82 90 83 91 #[async_trait] ··· 202 210 did: &Did, 203 211 ) -> Result<Option<UserIdHandleEmail>, DbError>; 204 212 205 - async fn update_preferred_comms_channel(&self, did: &Did, channel: &str) 206 - -> Result<(), DbError>; 213 + async fn update_preferred_comms_channel( 214 + &self, 215 + did: &Did, 216 + channel: CommsChannel, 217 + ) -> Result<(), DbError>; 207 218 208 219 async fn clear_discord(&self, user_id: Uuid) -> Result<(), DbError>; 209 220 ··· 292 303 293 304 async fn get_totp_record(&self, did: &Did) -> Result<Option<TotpRecord>, DbError>; 294 305 306 + async fn get_totp_record_state(&self, did: &Did) -> Result<Option<TotpRecordState>, DbError>; 307 + 295 308 async fn upsert_totp_secret( 296 309 &self, 297 310 did: &Did, ··· 560 573 pub struct UserCommsPrefs { 561 574 pub email: Option<String>, 562 575 pub handle: Handle, 563 - pub preferred_channel: String, 576 + pub preferred_channel: CommsChannel, 564 577 pub preferred_locale: Option<String>, 565 578 } 566 579 ··· 611 624 pub password_hash: Option<String>, 612 625 pub deactivated_at: Option<DateTime<Utc>>, 613 626 pub takedown_ref: Option<String>, 614 - pub email_verified: bool, 615 - pub discord_verified: bool, 616 - pub telegram_verified: bool, 617 - pub signal_verified: bool, 627 + pub channel_verification: ChannelVerificationStatus, 618 628 } 619 629 620 630 #[derive(Debug, Clone)] 621 631 pub struct NotificationPrefs { 622 632 pub email: String, 623 - pub preferred_channel: String, 633 + pub preferred_channel: CommsChannel, 624 634 pub discord_id: Option<String>, 625 635 pub discord_verified: bool, 626 636 pub telegram_username: Option<String>, ··· 641 651 pub id: Uuid, 642 652 pub handle: Handle, 643 653 pub email: Option<String>, 644 - pub email_verified: bool, 645 - pub discord_verified: bool, 646 - pub telegram_verified: bool, 647 - pub signal_verified: bool, 654 + pub channel_verification: ChannelVerificationStatus, 648 655 } 649 656 650 657 #[derive(Debug, Clone)] ··· 675 682 pub verified: bool, 676 683 } 677 684 685 + #[derive(Debug, Clone)] 686 + pub struct VerifiedTotpRecord { 687 + pub secret_encrypted: Vec<u8>, 688 + pub encryption_version: i32, 689 + } 690 + 691 + #[derive(Debug, Clone)] 692 + pub struct UnverifiedTotpRecord { 693 + pub secret_encrypted: Vec<u8>, 694 + pub encryption_version: i32, 695 + } 696 + 697 + #[derive(Debug, Clone)] 698 + pub enum TotpRecordState { 699 + Verified(VerifiedTotpRecord), 700 + Unverified(UnverifiedTotpRecord), 701 + } 702 + 703 + impl TotpRecordState { 704 + pub fn is_verified(&self) -> bool { 705 + matches!(self, Self::Verified(_)) 706 + } 707 + 708 + pub fn as_verified(&self) -> Option<&VerifiedTotpRecord> { 709 + match self { 710 + Self::Verified(r) => Some(r), 711 + Self::Unverified(_) => None, 712 + } 713 + } 714 + 715 + pub fn as_unverified(&self) -> Option<&UnverifiedTotpRecord> { 716 + match self { 717 + Self::Unverified(r) => Some(r), 718 + Self::Verified(_) => None, 719 + } 720 + } 721 + 722 + pub fn into_verified(self) -> Option<VerifiedTotpRecord> { 723 + match self { 724 + Self::Verified(r) => Some(r), 725 + Self::Unverified(_) => None, 726 + } 727 + } 728 + 729 + pub fn into_unverified(self) -> Option<UnverifiedTotpRecord> { 730 + match self { 731 + Self::Unverified(r) => Some(r), 732 + Self::Verified(_) => None, 733 + } 734 + } 735 + } 736 + 737 + impl From<TotpRecord> for TotpRecordState { 738 + fn from(record: TotpRecord) -> Self { 739 + if record.verified { 740 + Self::Verified(VerifiedTotpRecord { 741 + secret_encrypted: record.secret_encrypted, 742 + encryption_version: record.encryption_version, 743 + }) 744 + } else { 745 + Self::Unverified(UnverifiedTotpRecord { 746 + secret_encrypted: record.secret_encrypted, 747 + encryption_version: record.encryption_version, 748 + }) 749 + } 750 + } 751 + } 752 + 678 753 #[derive(Debug, Clone)] 679 754 pub struct StoredBackupCode { 680 755 pub id: Uuid, ··· 685 760 pub struct UserSessionInfo { 686 761 pub handle: Handle, 687 762 pub email: Option<String>, 688 - pub email_verified: bool, 689 763 pub is_admin: bool, 690 764 pub deactivated_at: Option<DateTime<Utc>>, 691 765 pub takedown_ref: Option<String>, 692 766 pub preferred_locale: Option<String>, 693 767 pub preferred_comms_channel: CommsChannel, 694 - pub discord_verified: bool, 695 - pub telegram_verified: bool, 696 - pub signal_verified: bool, 768 + pub channel_verification: ChannelVerificationStatus, 697 769 pub migrated_to_pds: Option<String>, 698 770 pub migrated_at: Option<DateTime<Utc>>, 699 771 } ··· 713 785 pub email: Option<String>, 714 786 pub deactivated_at: Option<DateTime<Utc>>, 715 787 pub takedown_ref: Option<String>, 716 - pub email_verified: bool, 717 - pub discord_verified: bool, 718 - pub telegram_verified: bool, 719 - pub signal_verified: bool, 788 + pub channel_verification: ChannelVerificationStatus, 720 789 pub allow_legacy_login: bool, 721 790 pub migrated_to_pds: Option<String>, 722 791 pub preferred_comms_channel: CommsChannel, ··· 748 817 pub discord_id: Option<String>, 749 818 pub telegram_username: Option<String>, 750 819 pub signal_number: Option<String>, 751 - pub email_verified: bool, 752 - pub discord_verified: bool, 753 - pub telegram_verified: bool, 754 - pub signal_verified: bool, 820 + pub channel_verification: ChannelVerificationStatus, 755 821 } 756 822 757 823 #[derive(Debug, Clone)]
+9 -9
crates/tranquil-db/src/postgres/delegation.rs
··· 1 1 use async_trait::async_trait; 2 2 use sqlx::PgPool; 3 3 use tranquil_db_traits::{ 4 - AuditLogEntry, ControllerInfo, DbError, DelegatedAccountInfo, DelegationActionType, 4 + AuditLogEntry, ControllerInfo, DbError, DbScope, DelegatedAccountInfo, DelegationActionType, 5 5 DelegationGrant, DelegationRepository, 6 6 }; 7 7 use tranquil_types::Did; ··· 80 80 &self, 81 81 delegated_did: &Did, 82 82 controller_did: &Did, 83 - granted_scopes: &str, 83 + granted_scopes: &DbScope, 84 84 granted_by: &Did, 85 85 ) -> Result<Uuid, DbError> { 86 86 let id = sqlx::query_scalar!( ··· 91 91 "#, 92 92 delegated_did.as_str(), 93 93 controller_did.as_str(), 94 - granted_scopes, 94 + granted_scopes.as_str(), 95 95 granted_by.as_str() 96 96 ) 97 97 .fetch_one(&self.pool) ··· 128 128 &self, 129 129 delegated_did: &Did, 130 130 controller_did: &Did, 131 - new_scopes: &str, 131 + new_scopes: &DbScope, 132 132 ) -> Result<bool, DbError> { 133 133 let result = sqlx::query!( 134 134 r#" ··· 136 136 SET granted_scopes = $1 137 137 WHERE delegated_did = $2 AND controller_did = $3 AND revoked_at IS NULL 138 138 "#, 139 - new_scopes, 139 + new_scopes.as_str(), 140 140 delegated_did.as_str(), 141 141 controller_did.as_str() 142 142 ) ··· 170 170 id: r.id, 171 171 delegated_did: r.delegated_did.into(), 172 172 controller_did: r.controller_did.into(), 173 - granted_scopes: r.granted_scopes, 173 + granted_scopes: DbScope::from_db(r.granted_scopes), 174 174 granted_at: r.granted_at, 175 175 granted_by: r.granted_by.into(), 176 176 revoked_at: r.revoked_at, ··· 206 206 .map(|r| ControllerInfo { 207 207 did: r.did.into(), 208 208 handle: r.handle.into(), 209 - granted_scopes: r.granted_scopes, 209 + granted_scopes: DbScope::from_db(r.granted_scopes), 210 210 granted_at: r.granted_at, 211 211 is_active: r.is_active, 212 212 }) ··· 243 243 .map(|r| DelegatedAccountInfo { 244 244 did: r.did.into(), 245 245 handle: r.handle.into(), 246 - granted_scopes: r.granted_scopes, 246 + granted_scopes: DbScope::from_db(r.granted_scopes), 247 247 granted_at: r.granted_at, 248 248 }) 249 249 .collect()) ··· 280 280 .map(|r| ControllerInfo { 281 281 did: r.did.into(), 282 282 handle: r.handle.into(), 283 - granted_scopes: r.granted_scopes, 283 + granted_scopes: DbScope::from_db(r.granted_scopes), 284 284 granted_at: r.granted_at, 285 285 is_active: r.is_active, 286 286 })
+34 -18
crates/tranquil-db/src/postgres/infra.rs
··· 3 3 use sqlx::PgPool; 4 4 use tranquil_db_traits::{ 5 5 AdminAccountInfo, CommsChannel, CommsStatus, CommsType, DbError, DeletionRequest, 6 - InfraRepository, InviteCodeInfo, InviteCodeRow, InviteCodeSortOrder, InviteCodeUse, 7 - NotificationHistoryRow, QueuedComms, ReservedSigningKey, 6 + InfraRepository, InviteCodeError, InviteCodeInfo, InviteCodeRow, InviteCodeSortOrder, 7 + InviteCodeState, InviteCodeUse, NotificationHistoryRow, QueuedComms, ReservedSigningKey, 8 + ValidatedInviteCode, 8 9 }; 9 10 use tranquil_types::{CidLink, Did, Handle}; 10 11 use uuid::Uuid; ··· 182 183 Ok(result) 183 184 } 184 185 185 - async fn is_invite_code_valid(&self, code: &str) -> Result<bool, DbError> { 186 - let result = sqlx::query_scalar!( 187 - r#"SELECT (available_uses > 0 AND NOT COALESCE(disabled, false)) as "valid!" FROM invite_codes WHERE code = $1"#, 186 + async fn validate_invite_code<'a>( 187 + &self, 188 + code: &'a str, 189 + ) -> Result<ValidatedInviteCode<'a>, InviteCodeError> { 190 + let result = sqlx::query!( 191 + r#"SELECT available_uses, COALESCE(disabled, false) as "disabled!" FROM invite_codes WHERE code = $1"#, 188 192 code 189 193 ) 190 194 .fetch_optional(&self.pool) 191 195 .await 192 - .map_err(map_sqlx_error)?; 196 + .map_err(|e| InviteCodeError::DatabaseError(map_sqlx_error(e)))?; 193 197 194 - Ok(result.unwrap_or(false)) 198 + match result { 199 + None => Err(InviteCodeError::NotFound), 200 + Some(row) if row.disabled => Err(InviteCodeError::Disabled), 201 + Some(row) if row.available_uses <= 0 => Err(InviteCodeError::ExhaustedUses), 202 + Some(_) => Ok(ValidatedInviteCode::new_validated(code)), 203 + } 195 204 } 196 205 197 - async fn decrement_invite_code_uses(&self, code: &str) -> Result<(), DbError> { 206 + async fn decrement_invite_code_uses( 207 + &self, 208 + code: &ValidatedInviteCode<'_>, 209 + ) -> Result<(), DbError> { 198 210 sqlx::query!( 199 211 "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", 200 - code 212 + code.code() 201 213 ) 202 214 .execute(&self.pool) 203 215 .await ··· 206 218 Ok(()) 207 219 } 208 220 209 - async fn record_invite_code_use(&self, code: &str, used_by_user: Uuid) -> Result<(), DbError> { 221 + async fn record_invite_code_use( 222 + &self, 223 + code: &ValidatedInviteCode<'_>, 224 + used_by_user: Uuid, 225 + ) -> Result<(), DbError> { 210 226 sqlx::query!( 211 227 "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 212 - code, 228 + code.code(), 213 229 used_by_user 214 230 ) 215 231 .execute(&self.pool) ··· 245 261 .map(|r| InviteCodeInfo { 246 262 code: r.code, 247 263 available_uses: r.available_uses, 248 - disabled: r.disabled.unwrap_or(false), 264 + state: InviteCodeState::from(r.disabled), 249 265 for_account: Some(Did::from(r.for_account)), 250 266 created_at: r.created_at, 251 267 created_by: None, ··· 422 438 .map(|r| InviteCodeInfo { 423 439 code: r.code, 424 440 available_uses: r.available_uses, 425 - disabled: r.disabled.unwrap_or(false), 441 + state: InviteCodeState::from(r.disabled), 426 442 for_account: Some(Did::from(r.for_account)), 427 443 created_at: r.created_at, 428 444 created_by: Some(Did::from(r.created_by)), ··· 445 461 Ok(result.map(|r| InviteCodeInfo { 446 462 code: r.code, 447 463 available_uses: r.available_uses, 448 - disabled: r.disabled.unwrap_or(false), 464 + state: InviteCodeState::from(r.disabled), 449 465 for_account: Some(Did::from(r.for_account)), 450 466 created_at: r.created_at, 451 467 created_by: Some(Did::from(r.created_by)), ··· 476 492 InviteCodeInfo { 477 493 code: r.code, 478 494 available_uses: r.available_uses, 479 - disabled: r.disabled.unwrap_or(false), 495 + state: InviteCodeState::from(r.disabled), 480 496 for_account: Some(Did::from(r.for_account)), 481 497 created_at: r.created_at, 482 498 created_by: Some(Did::from(r.created_by)), ··· 841 857 r#" 842 858 SELECT 843 859 created_at, 844 - channel as "channel: String", 845 - comms_type as "comms_type: String", 846 - status as "status: String", 860 + channel as "channel: CommsChannel", 861 + comms_type as "comms_type: CommsType", 862 + status as "status: CommsStatus", 847 863 subject, 848 864 body 849 865 FROM comms_queue
+112 -63
crates/tranquil-db/src/postgres/oauth.rs
··· 4 4 use sqlx::PgPool; 5 5 use tranquil_db_traits::{ 6 6 DbError, DeviceAccountRow, DeviceTrustInfo, OAuthRepository, OAuthSessionListItem, 7 - ScopePreference, TrustedDeviceRow, TwoFactorChallenge, 7 + ScopePreference, TokenFamilyId, TrustedDeviceRow, TwoFactorChallenge, 8 8 }; 9 9 use tranquil_oauth::{ 10 - AuthorizationRequestParameters, AuthorizedClientData, ClientAuth, DeviceData, RequestData, 11 - TokenData, 10 + AuthorizationRequestParameters, AuthorizedClientData, ClientAuth, Code as OAuthCode, 11 + DeviceData, DeviceId as OAuthDeviceId, RefreshToken as OAuthRefreshToken, RequestData, 12 + SessionId as OAuthSessionId, TokenData, TokenId as OAuthTokenId, 12 13 }; 13 14 use tranquil_types::{ 14 15 AuthorizationCode, ClientId, DPoPProofId, DeviceId, Did, Handle, RefreshToken, RequestId, ··· 48 49 49 50 #[async_trait] 50 51 impl OAuthRepository for PostgresOAuthRepository { 51 - async fn create_token(&self, data: &TokenData) -> Result<i32, DbError> { 52 + async fn create_token(&self, data: &TokenData) -> Result<TokenFamilyId, DbError> { 52 53 let client_auth_json = to_json(&data.client_auth)?; 53 54 let parameters_json = to_json(&data.parameters)?; 54 55 let row = sqlx::query!( ··· 59 60 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) 60 61 RETURNING id 61 62 "#, 62 - data.did, 63 - data.token_id, 63 + data.did.as_str(), 64 + &data.token_id.0, 64 65 data.created_at, 65 66 data.updated_at, 66 67 data.expires_at, 67 68 data.client_id, 68 69 client_auth_json, 69 - data.device_id, 70 + data.device_id.as_ref().map(|d| d.0.as_str()), 70 71 parameters_json, 71 72 data.details, 72 - data.code, 73 - data.current_refresh_token, 73 + data.code.as_ref().map(|c| c.0.as_str()), 74 + data.current_refresh_token.as_ref().map(|r| r.0.as_str()), 74 75 data.scope, 75 - data.controller_did, 76 + data.controller_did.as_ref().map(|d| d.as_str()), 76 77 ) 77 78 .fetch_one(&self.pool) 78 79 .await 79 80 .map_err(map_sqlx_error)?; 80 - Ok(row.id) 81 + Ok(TokenFamilyId::new(row.id)) 81 82 } 82 83 83 84 async fn get_token_by_id(&self, token_id: &TokenId) -> Result<Option<TokenData>, DbError> { ··· 95 96 .map_err(map_sqlx_error)?; 96 97 match row { 97 98 Some(r) => Ok(Some(TokenData { 98 - did: r.did, 99 - token_id: r.token_id, 99 + did: r 100 + .did 101 + .parse() 102 + .map_err(|_| DbError::Other("Invalid DID in token".into()))?, 103 + token_id: OAuthTokenId(r.token_id), 100 104 created_at: r.created_at, 101 105 updated_at: r.updated_at, 102 106 expires_at: r.expires_at, 103 107 client_id: r.client_id, 104 108 client_auth: from_json(r.client_auth)?, 105 - device_id: r.device_id, 109 + device_id: r.device_id.map(OAuthDeviceId), 106 110 parameters: from_json(r.parameters)?, 107 111 details: r.details, 108 - code: r.code, 109 - current_refresh_token: r.current_refresh_token, 112 + code: r.code.map(OAuthCode), 113 + current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken), 110 114 scope: r.scope, 111 - controller_did: r.controller_did, 115 + controller_did: r 116 + .controller_did 117 + .map(|s| s.parse()) 118 + .transpose() 119 + .map_err(|_| DbError::Other("Invalid controller DID".into()))?, 112 120 })), 113 121 None => Ok(None), 114 122 } ··· 117 125 async fn get_token_by_refresh_token( 118 126 &self, 119 127 refresh_token: &RefreshToken, 120 - ) -> Result<Option<(i32, TokenData)>, DbError> { 128 + ) -> Result<Option<(TokenFamilyId, TokenData)>, DbError> { 121 129 let row = sqlx::query!( 122 130 r#" 123 131 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, ··· 132 140 .map_err(map_sqlx_error)?; 133 141 match row { 134 142 Some(r) => Ok(Some(( 135 - r.id, 143 + TokenFamilyId::new(r.id), 136 144 TokenData { 137 - did: r.did, 138 - token_id: r.token_id, 145 + did: r 146 + .did 147 + .parse() 148 + .map_err(|_| DbError::Other("Invalid DID in token".into()))?, 149 + token_id: OAuthTokenId(r.token_id), 139 150 created_at: r.created_at, 140 151 updated_at: r.updated_at, 141 152 expires_at: r.expires_at, 142 153 client_id: r.client_id, 143 154 client_auth: from_json(r.client_auth)?, 144 - device_id: r.device_id, 155 + device_id: r.device_id.map(OAuthDeviceId), 145 156 parameters: from_json(r.parameters)?, 146 157 details: r.details, 147 - code: r.code, 148 - current_refresh_token: r.current_refresh_token, 158 + code: r.code.map(OAuthCode), 159 + current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken), 149 160 scope: r.scope, 150 - controller_did: r.controller_did, 161 + controller_did: r 162 + .controller_did 163 + .map(|s| s.parse()) 164 + .transpose() 165 + .map_err(|_| DbError::Other("Invalid controller DID".into()))?, 151 166 }, 152 167 ))), 153 168 None => Ok(None), ··· 157 172 async fn get_token_by_previous_refresh_token( 158 173 &self, 159 174 refresh_token: &RefreshToken, 160 - ) -> Result<Option<(i32, TokenData)>, DbError> { 175 + ) -> Result<Option<(TokenFamilyId, TokenData)>, DbError> { 161 176 let grace_cutoff = Utc::now() - Duration::seconds(REFRESH_GRACE_PERIOD_SECS); 162 177 let row = sqlx::query!( 163 178 r#" ··· 174 189 .map_err(map_sqlx_error)?; 175 190 match row { 176 191 Some(r) => Ok(Some(( 177 - r.id, 192 + TokenFamilyId::new(r.id), 178 193 TokenData { 179 - did: r.did, 180 - token_id: r.token_id, 194 + did: r 195 + .did 196 + .parse() 197 + .map_err(|_| DbError::Other("Invalid DID in token".into()))?, 198 + token_id: OAuthTokenId(r.token_id), 181 199 created_at: r.created_at, 182 200 updated_at: r.updated_at, 183 201 expires_at: r.expires_at, 184 202 client_id: r.client_id, 185 203 client_auth: from_json(r.client_auth)?, 186 - device_id: r.device_id, 204 + device_id: r.device_id.map(OAuthDeviceId), 187 205 parameters: from_json(r.parameters)?, 188 206 details: r.details, 189 - code: r.code, 190 - current_refresh_token: r.current_refresh_token, 207 + code: r.code.map(OAuthCode), 208 + current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken), 191 209 scope: r.scope, 192 - controller_did: r.controller_did, 210 + controller_did: r 211 + .controller_did 212 + .map(|s| s.parse()) 213 + .transpose() 214 + .map_err(|_| DbError::Other("Invalid controller DID".into()))?, 193 215 }, 194 216 ))), 195 217 None => Ok(None), ··· 198 220 199 221 async fn rotate_token( 200 222 &self, 201 - old_db_id: i32, 223 + old_db_id: TokenFamilyId, 202 224 new_refresh_token: &RefreshToken, 203 225 new_expires_at: DateTime<Utc>, 204 226 ) -> Result<(), DbError> { ··· 207 229 r#" 208 230 SELECT current_refresh_token FROM oauth_token WHERE id = $1 209 231 "#, 210 - old_db_id 232 + old_db_id.as_i32() 211 233 ) 212 234 .fetch_one(&mut *tx) 213 235 .await ··· 219 241 VALUES ($1, $2) 220 242 "#, 221 243 old_rt, 222 - old_db_id 244 + old_db_id.as_i32() 223 245 ) 224 246 .execute(&mut *tx) 225 247 .await ··· 232 254 previous_refresh_token = $4, rotated_at = NOW() 233 255 WHERE id = $1 234 256 "#, 235 - old_db_id, 257 + old_db_id.as_i32(), 236 258 new_refresh_token.as_str(), 237 259 new_expires_at, 238 260 old_refresh ··· 247 269 async fn check_refresh_token_used( 248 270 &self, 249 271 refresh_token: &RefreshToken, 250 - ) -> Result<Option<i32>, DbError> { 272 + ) -> Result<Option<TokenFamilyId>, DbError> { 251 273 let row = sqlx::query_scalar!( 252 274 r#" 253 275 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 ··· 257 279 .fetch_optional(&self.pool) 258 280 .await 259 281 .map_err(map_sqlx_error)?; 260 - Ok(row) 282 + Ok(row.map(TokenFamilyId::new)) 261 283 } 262 284 263 285 async fn delete_token(&self, token_id: &TokenId) -> Result<(), DbError> { ··· 273 295 Ok(()) 274 296 } 275 297 276 - async fn delete_token_family(&self, db_id: i32) -> Result<(), DbError> { 298 + async fn delete_token_family(&self, db_id: TokenFamilyId) -> Result<(), DbError> { 277 299 sqlx::query!( 278 300 r#" 279 301 DELETE FROM oauth_token WHERE id = $1 280 302 "#, 281 - db_id 303 + db_id.as_i32() 282 304 ) 283 305 .execute(&self.pool) 284 306 .await ··· 302 324 rows.into_iter() 303 325 .map(|r| { 304 326 Ok(TokenData { 305 - did: r.did, 306 - token_id: r.token_id, 327 + did: r 328 + .did 329 + .parse() 330 + .map_err(|_| DbError::Other("Invalid DID in token".into()))?, 331 + token_id: OAuthTokenId(r.token_id), 307 332 created_at: r.created_at, 308 333 updated_at: r.updated_at, 309 334 expires_at: r.expires_at, 310 335 client_id: r.client_id, 311 336 client_auth: from_json(r.client_auth)?, 312 - device_id: r.device_id, 337 + device_id: r.device_id.map(OAuthDeviceId), 313 338 parameters: from_json(r.parameters)?, 314 339 details: r.details, 315 - code: r.code, 316 - current_refresh_token: r.current_refresh_token, 340 + code: r.code.map(OAuthCode), 341 + current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken), 317 342 scope: r.scope, 318 - controller_did: r.controller_did, 343 + controller_did: r 344 + .controller_did 345 + .map(|s| s.parse()) 346 + .transpose() 347 + .map_err(|_| DbError::Other("Invalid controller DID".into()))?, 319 348 }) 320 349 }) 321 350 .collect() ··· 407 436 VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 408 437 "#, 409 438 request_id.as_str(), 410 - data.did, 411 - data.device_id, 439 + data.did.as_ref().map(|d| d.as_str()), 440 + data.device_id.as_ref().map(|d| d.0.as_str()), 412 441 data.client_id, 413 442 client_auth_json, 414 443 parameters_json, 415 444 data.expires_at, 416 - data.code, 445 + data.code.as_ref().map(|c| c.0.as_str()), 417 446 ) 418 447 .execute(&self.pool) 419 448 .await ··· 448 477 client_auth, 449 478 parameters, 450 479 expires_at: r.expires_at, 451 - did: r.did, 452 - device_id: r.device_id, 453 - code: r.code, 454 - controller_did: r.controller_did, 480 + did: r 481 + .did 482 + .map(|s| s.parse()) 483 + .transpose() 484 + .map_err(|_| DbError::Other("Invalid DID in DB".into()))?, 485 + device_id: r.device_id.map(OAuthDeviceId), 486 + code: r.code.map(OAuthCode), 487 + controller_did: r 488 + .controller_did 489 + .map(|s| s.parse()) 490 + .transpose() 491 + .map_err(|_| DbError::Other("Invalid controller DID in DB".into()))?, 455 492 })) 456 493 } 457 494 None => Ok(None), ··· 534 571 client_auth, 535 572 parameters, 536 573 expires_at: r.expires_at, 537 - did: r.did, 538 - device_id: r.device_id, 539 - code: r.code, 540 - controller_did: r.controller_did, 574 + did: r 575 + .did 576 + .map(|s| s.parse()) 577 + .transpose() 578 + .map_err(|_| DbError::Other("Invalid DID in DB".into()))?, 579 + device_id: r.device_id.map(OAuthDeviceId), 580 + code: r.code.map(OAuthCode), 581 + controller_did: r 582 + .controller_did 583 + .map(|s| s.parse()) 584 + .transpose() 585 + .map_err(|_| DbError::Other("Invalid controller DID in DB".into()))?, 541 586 })) 542 587 } 543 588 None => Ok(None), ··· 655 700 VALUES ($1, $2, $3, $4, $5) 656 701 "#, 657 702 device_id.as_str(), 658 - data.session_id, 703 + &data.session_id.0, 659 704 data.user_agent, 660 705 data.ip_address, 661 706 data.last_seen_at, ··· 679 724 .await 680 725 .map_err(map_sqlx_error)?; 681 726 Ok(row.map(|r| DeviceData { 682 - session_id: r.session_id, 727 + session_id: OAuthSessionId(r.session_id), 683 728 user_agent: r.user_agent, 684 729 ip_address: r.ip_address, 685 730 last_seen_at: r.last_seen_at, ··· 1207 1252 Ok(rows 1208 1253 .into_iter() 1209 1254 .map(|r| OAuthSessionListItem { 1210 - id: r.id, 1255 + id: TokenFamilyId::new(r.id), 1211 1256 token_id: TokenId::from(r.token_id), 1212 1257 created_at: r.created_at, 1213 1258 expires_at: r.expires_at, ··· 1216 1261 .collect()) 1217 1262 } 1218 1263 1219 - async fn delete_session_by_id(&self, session_id: i32, did: &Did) -> Result<u64, DbError> { 1264 + async fn delete_session_by_id( 1265 + &self, 1266 + session_id: TokenFamilyId, 1267 + did: &Did, 1268 + ) -> Result<u64, DbError> { 1220 1269 let result = sqlx::query!( 1221 1270 "DELETE FROM oauth_token WHERE id = $1 AND did = $2", 1222 - session_id, 1271 + session_id.as_i32(), 1223 1272 did.as_str() 1224 1273 ) 1225 1274 .execute(&self.pool)
+146 -107
crates/tranquil-db/src/postgres/repo.rs
··· 2 2 use chrono::{DateTime, Utc}; 3 3 use sqlx::PgPool; 4 4 use tranquil_db_traits::{ 5 - BrokenGenesisCommit, CommitEventData, DbError, EventBlocksCids, FullRecordInfo, ImportBlock, 6 - ImportRecord, ImportRepoError, RecordInfo, RecordWithTakedown, RepoAccountInfo, RepoInfo, 7 - RepoListItem, RepoRepository, RepoWithoutRev, SequencedEvent, UserNeedingRecordBlobsBackfill, 8 - UserWithoutBlocks, 5 + AccountStatus, BrokenGenesisCommit, CommitEventData, DbError, EventBlocksCids, FullRecordInfo, 6 + ImportBlock, ImportRecord, ImportRepoError, RecordInfo, RecordWithTakedown, RepoAccountInfo, 7 + RepoEventType, RepoInfo, RepoListItem, RepoRepository, RepoWithoutRev, SequenceNumber, 8 + SequencedEvent, UserNeedingRecordBlobsBackfill, UserWithoutBlocks, 9 9 }; 10 10 use tranquil_types::{AtUri, CidLink, Did, Handle, Nsid, Rkey}; 11 11 use uuid::Uuid; ··· 21 21 seq: i64, 22 22 did: String, 23 23 created_at: DateTime<Utc>, 24 - event_type: String, 24 + event_type: RepoEventType, 25 25 commit_cid: Option<String>, 26 26 prev_cid: Option<String>, 27 27 prev_data_cid: Option<String>, ··· 627 627 Ok(rows.into_iter().map(|(cid,)| cid).collect()) 628 628 } 629 629 630 - async fn insert_commit_event(&self, data: &CommitEventData) -> Result<i64, DbError> { 630 + async fn insert_commit_event(&self, data: &CommitEventData) -> Result<SequenceNumber, DbError> { 631 631 let seq = sqlx::query_scalar!( 632 632 r#" 633 633 INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid, rev) ··· 635 635 RETURNING seq 636 636 "#, 637 637 data.did.as_str(), 638 - data.event_type, 638 + data.event_type.as_str(), 639 639 data.commit_cid.as_ref().map(|c| c.as_str()), 640 640 data.prev_cid.as_ref().map(|c| c.as_str()), 641 641 data.ops, ··· 648 648 .await 649 649 .map_err(map_sqlx_error)?; 650 650 651 - Ok(seq) 651 + Ok(seq.into()) 652 652 } 653 653 654 654 async fn insert_identity_event( 655 655 &self, 656 656 did: &Did, 657 657 handle: Option<&Handle>, 658 - ) -> Result<i64, DbError> { 658 + ) -> Result<SequenceNumber, DbError> { 659 659 let handle_str = handle.map(|h| h.as_str()); 660 660 let seq = sqlx::query_scalar!( 661 661 r#" ··· 675 675 .await 676 676 .map_err(map_sqlx_error)?; 677 677 678 - Ok(seq) 678 + Ok(seq.into()) 679 679 } 680 680 681 681 async fn insert_account_event( 682 682 &self, 683 683 did: &Did, 684 - active: bool, 685 - status: Option<&str>, 686 - ) -> Result<i64, DbError> { 684 + status: AccountStatus, 685 + ) -> Result<SequenceNumber, DbError> { 686 + let active = status.is_active(); 687 + let status_str = status.for_firehose(); 687 688 let seq = sqlx::query_scalar!( 688 689 r#" 689 690 INSERT INTO repo_seq (did, event_type, active, status) ··· 692 693 "#, 693 694 did.as_str(), 694 695 active, 695 - status 696 + status_str 696 697 ) 697 698 .fetch_one(&self.pool) 698 699 .await ··· 703 704 .await 704 705 .map_err(map_sqlx_error)?; 705 706 706 - Ok(seq) 707 + Ok(seq.into()) 707 708 } 708 709 709 710 async fn insert_sync_event( ··· 711 712 did: &Did, 712 713 commit_cid: &CidLink, 713 714 rev: Option<&str>, 714 - ) -> Result<i64, DbError> { 715 + ) -> Result<SequenceNumber, DbError> { 715 716 let seq = sqlx::query_scalar!( 716 717 r#" 717 718 INSERT INTO repo_seq (did, event_type, commit_cid, rev) ··· 731 732 .await 732 733 .map_err(map_sqlx_error)?; 733 734 734 - Ok(seq) 735 + Ok(seq.into()) 735 736 } 736 737 737 738 async fn insert_genesis_commit_event( ··· 740 741 commit_cid: &CidLink, 741 742 mst_root_cid: &CidLink, 742 743 rev: &str, 743 - ) -> Result<i64, DbError> { 744 + ) -> Result<SequenceNumber, DbError> { 744 745 let ops = serde_json::json!([]); 745 746 let blobs: Vec<String> = vec![]; 746 747 let blocks_cids: Vec<String> = vec![mst_root_cid.to_string(), commit_cid.to_string()]; ··· 769 770 .await 770 771 .map_err(map_sqlx_error)?; 771 772 772 - Ok(seq) 773 + Ok(seq.into()) 773 774 } 774 775 775 776 async fn update_seq_blocks_cids( 776 777 &self, 777 - seq: i64, 778 + seq: SequenceNumber, 778 779 blocks_cids: &[String], 779 780 ) -> Result<(), DbError> { 780 781 sqlx::query!( 781 782 "UPDATE repo_seq SET blocks_cids = $1 WHERE seq = $2", 782 783 blocks_cids, 783 - seq 784 + seq.as_i64() 784 785 ) 785 786 .execute(&self.pool) 786 787 .await ··· 789 790 Ok(()) 790 791 } 791 792 792 - async fn delete_sequences_except(&self, did: &Did, keep_seq: i64) -> Result<(), DbError> { 793 + async fn delete_sequences_except( 794 + &self, 795 + did: &Did, 796 + keep_seq: SequenceNumber, 797 + ) -> Result<(), DbError> { 793 798 sqlx::query!( 794 799 "DELETE FROM repo_seq WHERE did = $1 AND seq != $2", 795 800 did.as_str(), 796 - keep_seq 801 + keep_seq.as_i64() 797 802 ) 798 803 .execute(&self.pool) 799 804 .await ··· 802 807 Ok(()) 803 808 } 804 809 805 - async fn get_max_seq(&self) -> Result<i64, DbError> { 810 + async fn get_max_seq(&self) -> Result<SequenceNumber, DbError> { 806 811 let seq = sqlx::query_scalar!(r#"SELECT COALESCE(MAX(seq), 0) as "max!" FROM repo_seq"#) 807 812 .fetch_one(&self.pool) 808 813 .await 809 814 .map_err(map_sqlx_error)?; 810 815 811 - Ok(seq) 816 + Ok(seq.into()) 812 817 } 813 818 814 - async fn get_min_seq_since(&self, since: DateTime<Utc>) -> Result<Option<i64>, DbError> { 819 + async fn get_min_seq_since( 820 + &self, 821 + since: DateTime<Utc>, 822 + ) -> Result<Option<SequenceNumber>, DbError> { 815 823 let seq = sqlx::query_scalar!( 816 824 "SELECT MIN(seq) FROM repo_seq WHERE created_at >= $1", 817 825 since ··· 820 828 .await 821 829 .map_err(map_sqlx_error)?; 822 830 823 - Ok(seq) 831 + Ok(seq.map(SequenceNumber::from)) 824 832 } 825 833 826 834 async fn get_account_with_repo(&self, did: &Did) -> Result<Option<RepoAccountInfo>, DbError> { ··· 846 854 847 855 async fn get_events_since_seq( 848 856 &self, 849 - since_seq: i64, 857 + since_seq: SequenceNumber, 850 858 limit: Option<i64>, 851 859 ) -> Result<Vec<SequencedEvent>, DbError> { 852 - let map_row = |r: SequencedEventRow| SequencedEvent { 853 - seq: r.seq, 854 - did: Did::from(r.did), 855 - created_at: r.created_at, 856 - event_type: r.event_type, 857 - commit_cid: r.commit_cid.map(CidLink::from), 858 - prev_cid: r.prev_cid.map(CidLink::from), 859 - prev_data_cid: r.prev_data_cid.map(CidLink::from), 860 - ops: r.ops, 861 - blobs: r.blobs, 862 - blocks_cids: r.blocks_cids, 863 - handle: r.handle.map(Handle::from), 864 - active: r.active, 865 - status: r.status, 866 - rev: r.rev, 860 + let map_row = |r: SequencedEventRow| { 861 + let status = r 862 + .status 863 + .as_deref() 864 + .and_then(AccountStatus::parse) 865 + .or_else(|| r.active.filter(|a| *a).map(|_| AccountStatus::Active)); 866 + SequencedEvent { 867 + seq: r.seq.into(), 868 + did: Did::from(r.did), 869 + created_at: r.created_at, 870 + event_type: r.event_type, 871 + commit_cid: r.commit_cid.map(CidLink::from), 872 + prev_cid: r.prev_cid.map(CidLink::from), 873 + prev_data_cid: r.prev_data_cid.map(CidLink::from), 874 + ops: r.ops, 875 + blobs: r.blobs, 876 + blocks_cids: r.blocks_cids, 877 + handle: r.handle.map(Handle::from), 878 + active: r.active, 879 + status, 880 + rev: r.rev, 881 + } 867 882 }; 868 883 match limit { 869 884 Some(lim) => { 870 885 let rows = sqlx::query_as!( 871 886 SequencedEventRow, 872 - r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, 887 + r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid, 873 888 ops, blobs, blocks_cids, handle, active, status, rev 874 889 FROM repo_seq 875 890 WHERE seq > $1 876 891 ORDER BY seq ASC 877 892 LIMIT $2"#, 878 - since_seq, 893 + since_seq.as_i64(), 879 894 lim 880 895 ) 881 896 .fetch_all(&self.pool) ··· 886 901 None => { 887 902 let rows = sqlx::query_as!( 888 903 SequencedEventRow, 889 - r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, 904 + r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid, 890 905 ops, blobs, blocks_cids, handle, active, status, rev 891 906 FROM repo_seq 892 907 WHERE seq > $1 893 908 ORDER BY seq ASC"#, 894 - since_seq 909 + since_seq.as_i64() 895 910 ) 896 911 .fetch_all(&self.pool) 897 912 .await ··· 903 918 904 919 async fn get_events_in_seq_range( 905 920 &self, 906 - start_seq: i64, 907 - end_seq: i64, 921 + start_seq: SequenceNumber, 922 + end_seq: SequenceNumber, 908 923 ) -> Result<Vec<SequencedEvent>, DbError> { 909 924 let rows = sqlx::query!( 910 - r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, 925 + r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid, 911 926 ops, blobs, blocks_cids, handle, active, status, rev 912 927 FROM repo_seq 913 928 WHERE seq > $1 AND seq < $2 914 929 ORDER BY seq ASC"#, 915 - start_seq, 916 - end_seq 930 + start_seq.as_i64(), 931 + end_seq.as_i64() 917 932 ) 918 933 .fetch_all(&self.pool) 919 934 .await 920 935 .map_err(map_sqlx_error)?; 921 936 Ok(rows 922 937 .into_iter() 923 - .map(|r| SequencedEvent { 924 - seq: r.seq, 925 - did: Did::from(r.did), 926 - created_at: r.created_at, 927 - event_type: r.event_type, 928 - commit_cid: r.commit_cid.map(CidLink::from), 929 - prev_cid: r.prev_cid.map(CidLink::from), 930 - prev_data_cid: r.prev_data_cid.map(CidLink::from), 931 - ops: r.ops, 932 - blobs: r.blobs, 933 - blocks_cids: r.blocks_cids, 934 - handle: r.handle.map(Handle::from), 935 - active: r.active, 936 - status: r.status, 937 - rev: r.rev, 938 + .map(|r| { 939 + let status = r 940 + .status 941 + .as_deref() 942 + .and_then(AccountStatus::parse) 943 + .or_else(|| r.active.filter(|a| *a).map(|_| AccountStatus::Active)); 944 + SequencedEvent { 945 + seq: r.seq.into(), 946 + did: Did::from(r.did), 947 + created_at: r.created_at, 948 + event_type: r.event_type, 949 + commit_cid: r.commit_cid.map(CidLink::from), 950 + prev_cid: r.prev_cid.map(CidLink::from), 951 + prev_data_cid: r.prev_data_cid.map(CidLink::from), 952 + ops: r.ops, 953 + blobs: r.blobs, 954 + blocks_cids: r.blocks_cids, 955 + handle: r.handle.map(Handle::from), 956 + active: r.active, 957 + status, 958 + rev: r.rev, 959 + } 938 960 }) 939 961 .collect()) 940 962 } 941 963 942 - async fn get_event_by_seq(&self, seq: i64) -> Result<Option<SequencedEvent>, DbError> { 964 + async fn get_event_by_seq( 965 + &self, 966 + seq: SequenceNumber, 967 + ) -> Result<Option<SequencedEvent>, DbError> { 943 968 let row = sqlx::query!( 944 - r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, 969 + r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid, 945 970 ops, blobs, blocks_cids, handle, active, status, rev 946 971 FROM repo_seq 947 972 WHERE seq = $1"#, 948 - seq 973 + seq.as_i64() 949 974 ) 950 975 .fetch_optional(&self.pool) 951 976 .await 952 977 .map_err(map_sqlx_error)?; 953 - Ok(row.map(|r| SequencedEvent { 954 - seq: r.seq, 955 - did: Did::from(r.did), 956 - created_at: r.created_at, 957 - event_type: r.event_type, 958 - commit_cid: r.commit_cid.map(CidLink::from), 959 - prev_cid: r.prev_cid.map(CidLink::from), 960 - prev_data_cid: r.prev_data_cid.map(CidLink::from), 961 - ops: r.ops, 962 - blobs: r.blobs, 963 - blocks_cids: r.blocks_cids, 964 - handle: r.handle.map(Handle::from), 965 - active: r.active, 966 - status: r.status, 967 - rev: r.rev, 978 + Ok(row.map(|r| { 979 + let status = r 980 + .status 981 + .as_deref() 982 + .and_then(AccountStatus::parse) 983 + .or_else(|| r.active.filter(|a| *a).map(|_| AccountStatus::Active)); 984 + SequencedEvent { 985 + seq: r.seq.into(), 986 + did: Did::from(r.did), 987 + created_at: r.created_at, 988 + event_type: r.event_type, 989 + commit_cid: r.commit_cid.map(CidLink::from), 990 + prev_cid: r.prev_cid.map(CidLink::from), 991 + prev_data_cid: r.prev_data_cid.map(CidLink::from), 992 + ops: r.ops, 993 + blobs: r.blobs, 994 + blocks_cids: r.blocks_cids, 995 + handle: r.handle.map(Handle::from), 996 + active: r.active, 997 + status, 998 + rev: r.rev, 999 + } 968 1000 })) 969 1001 } 970 1002 971 1003 async fn get_events_since_cursor( 972 1004 &self, 973 - cursor: i64, 1005 + cursor: SequenceNumber, 974 1006 limit: i64, 975 1007 ) -> Result<Vec<SequencedEvent>, DbError> { 976 1008 let rows = sqlx::query!( 977 - r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, 1009 + r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid, 978 1010 ops, blobs, blocks_cids, handle, active, status, rev 979 1011 FROM repo_seq 980 1012 WHERE seq > $1 981 1013 ORDER BY seq ASC 982 1014 LIMIT $2"#, 983 - cursor, 1015 + cursor.as_i64(), 984 1016 limit 985 1017 ) 986 1018 .fetch_all(&self.pool) ··· 988 1020 .map_err(map_sqlx_error)?; 989 1021 Ok(rows 990 1022 .into_iter() 991 - .map(|r| SequencedEvent { 992 - seq: r.seq, 993 - did: Did::from(r.did), 994 - created_at: r.created_at, 995 - event_type: r.event_type, 996 - commit_cid: r.commit_cid.map(CidLink::from), 997 - prev_cid: r.prev_cid.map(CidLink::from), 998 - prev_data_cid: r.prev_data_cid.map(CidLink::from), 999 - ops: r.ops, 1000 - blobs: r.blobs, 1001 - blocks_cids: r.blocks_cids, 1002 - handle: r.handle.map(Handle::from), 1003 - active: r.active, 1004 - status: r.status, 1005 - rev: r.rev, 1023 + .map(|r| { 1024 + let status = r 1025 + .status 1026 + .as_deref() 1027 + .and_then(AccountStatus::parse) 1028 + .or_else(|| r.active.filter(|a| *a).map(|_| AccountStatus::Active)); 1029 + SequencedEvent { 1030 + seq: r.seq.into(), 1031 + did: Did::from(r.did), 1032 + created_at: r.created_at, 1033 + event_type: r.event_type, 1034 + commit_cid: r.commit_cid.map(CidLink::from), 1035 + prev_cid: r.prev_cid.map(CidLink::from), 1036 + prev_data_cid: r.prev_data_cid.map(CidLink::from), 1037 + ops: r.ops, 1038 + blobs: r.blobs, 1039 + blocks_cids: r.blocks_cids, 1040 + handle: r.handle.map(Handle::from), 1041 + active: r.active, 1042 + status, 1043 + rev: r.rev, 1044 + } 1006 1045 }) 1007 1046 .collect()) 1008 1047 } ··· 1079 1118 Ok(cid.map(CidLink::from)) 1080 1119 } 1081 1120 1082 - async fn notify_update(&self, seq: i64) -> Result<(), DbError> { 1083 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq)) 1121 + async fn notify_update(&self, seq: SequenceNumber) -> Result<(), DbError> { 1122 + sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq.as_i64())) 1084 1123 .execute(&self.pool) 1085 1124 .await 1086 1125 .map_err(map_sqlx_error)?; ··· 1329 1368 "#, 1330 1369 ) 1331 1370 .bind(&event.did) 1332 - .bind(&event.event_type) 1371 + .bind(event.event_type.as_str()) 1333 1372 .bind(&event.commit_cid) 1334 1373 .bind(&event.prev_cid) 1335 1374 .bind(&event.ops) ··· 1375 1414 Ok(rows 1376 1415 .into_iter() 1377 1416 .map(|r| BrokenGenesisCommit { 1378 - seq: r.seq, 1417 + seq: r.seq.into(), 1379 1418 did: Did::from(r.did), 1380 1419 commit_cid: r.commit_cid.map(CidLink::from), 1381 1420 })
+42 -33
crates/tranquil-db/src/postgres/session.rs
··· 2 2 use chrono::{DateTime, Utc}; 3 3 use sqlx::PgPool; 4 4 use tranquil_db_traits::{ 5 - AppPasswordCreate, AppPasswordRecord, DbError, RefreshSessionResult, SessionForRefresh, 6 - SessionListItem, SessionMfaStatus, SessionRefreshData, SessionRepository, SessionToken, 7 - SessionTokenCreate, 5 + AppPasswordCreate, AppPasswordPrivilege, AppPasswordRecord, DbError, LoginType, 6 + RefreshSessionResult, SessionForRefresh, SessionId, SessionListItem, SessionMfaStatus, 7 + SessionRefreshData, SessionRepository, SessionToken, SessionTokenCreate, 8 8 }; 9 9 use tranquil_types::Did; 10 10 use uuid::Uuid; ··· 23 23 24 24 #[async_trait] 25 25 impl SessionRepository for PostgresSessionRepository { 26 - async fn create_session(&self, data: &SessionTokenCreate) -> Result<i32, DbError> { 26 + async fn create_session(&self, data: &SessionTokenCreate) -> Result<SessionId, DbError> { 27 27 let row = sqlx::query!( 28 28 r#" 29 29 INSERT INTO session_tokens ··· 37 37 data.refresh_jti, 38 38 data.access_expires_at, 39 39 data.refresh_expires_at, 40 - data.legacy_login, 40 + bool::from(data.login_type), 41 41 data.mfa_verified, 42 42 data.scope, 43 43 data.controller_did.as_ref().map(|d| d.as_str()), ··· 47 47 .await 48 48 .map_err(map_sqlx_error)?; 49 49 50 - Ok(row.id) 50 + Ok(SessionId::new(row.id)) 51 51 } 52 52 53 53 async fn get_session_by_access_jti( ··· 69 69 .map_err(map_sqlx_error)?; 70 70 71 71 Ok(row.map(|r| SessionToken { 72 - id: r.id, 72 + id: SessionId::new(r.id), 73 73 did: Did::from(r.did), 74 74 access_jti: r.access_jti, 75 75 refresh_jti: r.refresh_jti, 76 76 access_expires_at: r.access_expires_at, 77 77 refresh_expires_at: r.refresh_expires_at, 78 - legacy_login: r.legacy_login, 78 + login_type: LoginType::from(r.legacy_login), 79 79 mfa_verified: r.mfa_verified, 80 80 scope: r.scope, 81 81 controller_did: r.controller_did.map(Did::from), ··· 104 104 .map_err(map_sqlx_error)?; 105 105 106 106 Ok(row.map(|r| SessionForRefresh { 107 - id: r.id, 107 + id: SessionId::new(r.id), 108 108 did: Did::from(r.did), 109 109 scope: r.scope, 110 110 controller_did: r.controller_did.map(Did::from), ··· 115 115 116 116 async fn update_session_tokens( 117 117 &self, 118 - session_id: i32, 118 + session_id: SessionId, 119 119 new_access_jti: &str, 120 120 new_refresh_jti: &str, 121 121 new_access_expires_at: DateTime<Utc>, ··· 132 132 new_refresh_jti, 133 133 new_access_expires_at, 134 134 new_refresh_expires_at, 135 - session_id 135 + session_id.as_i32() 136 136 ) 137 137 .execute(&self.pool) 138 138 .await ··· 153 153 Ok(result.rows_affected()) 154 154 } 155 155 156 - async fn delete_session_by_id(&self, session_id: i32) -> Result<u64, DbError> { 157 - let result = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id) 158 - .execute(&self.pool) 159 - .await 160 - .map_err(map_sqlx_error)?; 156 + async fn delete_session_by_id(&self, session_id: SessionId) -> Result<u64, DbError> { 157 + let result = sqlx::query!( 158 + "DELETE FROM session_tokens WHERE id = $1", 159 + session_id.as_i32() 160 + ) 161 + .execute(&self.pool) 162 + .await 163 + .map_err(map_sqlx_error)?; 161 164 162 165 Ok(result.rows_affected()) 163 166 } ··· 205 208 Ok(rows 206 209 .into_iter() 207 210 .map(|r| SessionListItem { 208 - id: r.id, 211 + id: SessionId::new(r.id), 209 212 access_jti: r.access_jti, 210 213 created_at: r.created_at, 211 214 refresh_expires_at: r.refresh_expires_at, ··· 215 218 216 219 async fn get_session_access_jti_by_id( 217 220 &self, 218 - session_id: i32, 221 + session_id: SessionId, 219 222 did: &Did, 220 223 ) -> Result<Option<String>, DbError> { 221 224 let row = sqlx::query_scalar!( 222 225 "SELECT access_jti FROM session_tokens WHERE id = $1 AND did = $2", 223 - session_id, 226 + session_id.as_i32(), 224 227 did.as_str() 225 228 ) 226 229 .fetch_optional(&self.pool) ··· 264 267 Ok(rows) 265 268 } 266 269 267 - async fn check_refresh_token_used(&self, refresh_jti: &str) -> Result<Option<i32>, DbError> { 270 + async fn check_refresh_token_used( 271 + &self, 272 + refresh_jti: &str, 273 + ) -> Result<Option<SessionId>, DbError> { 268 274 let row = sqlx::query_scalar!( 269 275 "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1", 270 276 refresh_jti ··· 273 279 .await 274 280 .map_err(map_sqlx_error)?; 275 281 276 - Ok(row) 282 + Ok(row.map(SessionId::new)) 277 283 } 278 284 279 285 async fn mark_refresh_token_used( 280 286 &self, 281 287 refresh_jti: &str, 282 - session_id: i32, 288 + session_id: SessionId, 283 289 ) -> Result<bool, DbError> { 284 290 let result = sqlx::query!( 285 291 r#" ··· 288 294 ON CONFLICT (refresh_jti) DO NOTHING 289 295 "#, 290 296 refresh_jti, 291 - session_id 297 + session_id.as_i32() 292 298 ) 293 299 .execute(&self.pool) 294 300 .await ··· 319 325 name: r.name, 320 326 password_hash: r.password_hash, 321 327 created_at: r.created_at, 322 - privileged: r.privileged, 328 + privilege: AppPasswordPrivilege::from(r.privileged), 323 329 scopes: r.scopes, 324 330 created_by_controller_did: r.created_by_controller_did.map(Did::from), 325 331 }) ··· 352 358 name: r.name, 353 359 password_hash: r.password_hash, 354 360 created_at: r.created_at, 355 - privileged: r.privileged, 361 + privilege: AppPasswordPrivilege::from(r.privileged), 356 362 scopes: r.scopes, 357 363 created_by_controller_did: r.created_by_controller_did.map(Did::from), 358 364 }) ··· 383 389 name: r.name, 384 390 password_hash: r.password_hash, 385 391 created_at: r.created_at, 386 - privileged: r.privileged, 392 + privilege: AppPasswordPrivilege::from(r.privileged), 387 393 scopes: r.scopes, 388 394 created_by_controller_did: r.created_by_controller_did.map(Did::from), 389 395 })) ··· 399 405 data.user_id, 400 406 data.name, 401 407 data.password_hash, 402 - data.privileged, 408 + bool::from(data.privilege), 403 409 data.scopes, 404 410 data.created_by_controller_did.as_ref().map(|d| d.as_str()) 405 411 ) ··· 480 486 .map_err(map_sqlx_error)?; 481 487 482 488 Ok(row.map(|r| SessionMfaStatus { 483 - legacy_login: r.legacy_login, 489 + login_type: LoginType::from(r.legacy_login), 484 490 mfa_verified: r.mfa_verified, 485 491 last_reauth_at: r.last_reauth_at, 486 492 })) ··· 535 541 let result = sqlx::query!( 536 542 "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING", 537 543 data.old_refresh_jti, 538 - data.session_id 544 + data.session_id.as_i32() 539 545 ) 540 546 .execute(&mut *tx) 541 547 .await 542 548 .map_err(map_sqlx_error)?; 543 549 544 550 if result.rows_affected() == 0 { 545 - let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", data.session_id) 546 - .execute(&mut *tx) 547 - .await; 551 + let _ = sqlx::query!( 552 + "DELETE FROM session_tokens WHERE id = $1", 553 + data.session_id.as_i32() 554 + ) 555 + .execute(&mut *tx) 556 + .await; 548 557 tx.commit().await.map_err(map_sqlx_error)?; 549 558 return Ok(RefreshSessionResult::ConcurrentRefresh); 550 559 } ··· 555 564 data.new_refresh_jti, 556 565 data.new_access_expires_at, 557 566 data.new_refresh_expires_at, 558 - data.session_id 567 + data.session_id.as_i32() 559 568 ) 560 569 .execute(&mut *tx) 561 570 .await
+33 -28
crates/tranquil-db/src/postgres/sso.rs
··· 2 2 use chrono::Utc; 3 3 use sqlx::PgPool; 4 4 use tranquil_db_traits::{ 5 - DbError, ExternalIdentity, SsoAuthState, SsoPendingRegistration, SsoProviderType, SsoRepository, 5 + DbError, ExternalEmail, ExternalIdentity, ExternalUserId, ExternalUsername, SsoAction, 6 + SsoAuthState, SsoPendingRegistration, SsoProviderType, SsoRepository, 6 7 }; 7 8 use tranquil_types::Did; 8 9 use uuid::Uuid; ··· 69 70 70 71 Ok(row.map(|r| ExternalIdentity { 71 72 id: r.id, 72 - did: Did::new_unchecked(&r.did), 73 + did: unsafe { Did::new_unchecked(&r.did) }, 73 74 provider: r.provider, 74 - provider_user_id: r.provider_user_id, 75 - provider_username: r.provider_username, 76 - provider_email: r.provider_email, 75 + provider_user_id: ExternalUserId::from(r.provider_user_id), 76 + provider_username: r.provider_username.map(ExternalUsername::from), 77 + provider_email: r.provider_email.map(ExternalEmail::from), 77 78 created_at: r.created_at, 78 79 updated_at: r.updated_at, 79 80 last_login_at: r.last_login_at, ··· 102 103 .into_iter() 103 104 .map(|r| ExternalIdentity { 104 105 id: r.id, 105 - did: Did::new_unchecked(&r.did), 106 + did: unsafe { Did::new_unchecked(&r.did) }, 106 107 provider: r.provider, 107 - provider_user_id: r.provider_user_id, 108 - provider_username: r.provider_username, 109 - provider_email: r.provider_email, 108 + provider_user_id: ExternalUserId::from(r.provider_user_id), 109 + provider_username: r.provider_username.map(ExternalUsername::from), 110 + provider_email: r.provider_email.map(ExternalEmail::from), 110 111 created_at: r.created_at, 111 112 updated_at: r.updated_at, 112 113 last_login_at: r.last_login_at, ··· 161 162 state: &str, 162 163 request_uri: &str, 163 164 provider: SsoProviderType, 164 - action: &str, 165 + action: SsoAction, 165 166 nonce: Option<&str>, 166 167 code_verifier: Option<&str>, 167 168 did: Option<&Did>, ··· 174 175 state, 175 176 request_uri, 176 177 provider as SsoProviderType, 177 - action, 178 + action.as_str(), 178 179 nonce, 179 180 code_verifier, 180 181 did.map(|d| d.as_str()), ··· 200 201 .await 201 202 .map_err(map_sqlx_error)?; 202 203 203 - Ok(row.map(|r| SsoAuthState { 204 - state: r.state, 205 - request_uri: r.request_uri, 206 - provider: r.provider, 207 - action: r.action, 208 - nonce: r.nonce, 209 - code_verifier: r.code_verifier, 210 - did: r.did.map(|d| Did::new_unchecked(&d)), 211 - created_at: r.created_at, 212 - expires_at: r.expires_at, 213 - })) 204 + row.map(|r| { 205 + let action = SsoAction::parse(&r.action).ok_or(DbError::NotFound)?; 206 + Ok(SsoAuthState { 207 + state: r.state, 208 + request_uri: r.request_uri, 209 + provider: r.provider, 210 + action, 211 + nonce: r.nonce, 212 + code_verifier: r.code_verifier, 213 + did: r.did.map(|d| unsafe { Did::new_unchecked(&d) }), 214 + created_at: r.created_at, 215 + expires_at: r.expires_at, 216 + }) 217 + }) 218 + .transpose() 214 219 } 215 220 216 221 async fn cleanup_expired_sso_auth_states(&self) -> Result<u64, DbError> { ··· 280 285 token: r.token, 281 286 request_uri: r.request_uri, 282 287 provider: r.provider, 283 - provider_user_id: r.provider_user_id, 284 - provider_username: r.provider_username, 285 - provider_email: r.provider_email, 288 + provider_user_id: ExternalUserId::from(r.provider_user_id), 289 + provider_username: r.provider_username.map(ExternalUsername::from), 290 + provider_email: r.provider_email.map(ExternalEmail::from), 286 291 provider_email_verified: r.provider_email_verified, 287 292 created_at: r.created_at, 288 293 expires_at: r.expires_at, ··· 311 316 token: r.token, 312 317 request_uri: r.request_uri, 313 318 provider: r.provider, 314 - provider_user_id: r.provider_user_id, 315 - provider_username: r.provider_username, 316 - provider_email: r.provider_email, 319 + provider_user_id: ExternalUserId::from(r.provider_user_id), 320 + provider_username: r.provider_username.map(ExternalUsername::from), 321 + provider_email: r.provider_email.map(ExternalEmail::from), 317 322 provider_email_verified: r.provider_email_verified, 318 323 created_at: r.created_at, 319 324 expires_at: r.expires_at,
+66 -45
crates/tranquil-db/src/postgres/user.rs
··· 5 5 use uuid::Uuid; 6 6 7 7 use tranquil_db_traits::{ 8 - AccountSearchResult, CommsChannel, DbError, DidWebOverrides, NotificationPrefs, 9 - OAuthTokenWithUser, PasswordResetResult, SsoProviderType, StoredBackupCode, StoredPasskey, 10 - TotpRecord, User2faStatus, UserAuthInfo, UserCommsPrefs, UserConfirmSignup, UserDidWebInfo, 11 - UserEmailInfo, UserForDeletion, UserForDidDoc, UserForDidDocBuild, UserForPasskeyRecovery, 12 - UserForPasskeySetup, UserForRecovery, UserForVerification, UserIdAndHandle, 13 - UserIdAndPasswordHash, UserIdHandleEmail, UserInfoForAuth, UserKeyInfo, UserKeyWithId, 14 - UserLegacyLoginPref, UserLoginCheck, UserLoginFull, UserLoginInfo, UserPasswordInfo, 15 - UserRepository, UserResendVerification, UserResetCodeInfo, UserRow, UserSessionInfo, 16 - UserStatus, UserVerificationInfo, UserWithKey, 8 + AccountSearchResult, AccountType, ChannelVerificationStatus, CommsChannel, DbError, 9 + DidWebOverrides, NotificationPrefs, OAuthTokenWithUser, PasswordResetResult, SsoProviderType, 10 + StoredBackupCode, StoredPasskey, TotpRecord, TotpRecordState, User2faStatus, UserAuthInfo, 11 + UserCommsPrefs, UserConfirmSignup, UserDidWebInfo, UserEmailInfo, UserForDeletion, 12 + UserForDidDoc, UserForDidDocBuild, UserForPasskeyRecovery, UserForPasskeySetup, 13 + UserForRecovery, UserForVerification, UserIdAndHandle, UserIdAndPasswordHash, 14 + UserIdHandleEmail, UserInfoForAuth, UserKeyInfo, UserKeyWithId, UserLegacyLoginPref, 15 + UserLoginCheck, UserLoginFull, UserLoginInfo, UserPasswordInfo, UserRepository, 16 + UserResendVerification, UserResetCodeInfo, UserRow, UserSessionInfo, UserStatus, 17 + UserVerificationInfo, UserWithKey, 17 18 }; 18 19 19 20 pub struct PostgresUserRepository { ··· 280 281 password_hash: r.password_hash, 281 282 deactivated_at: r.deactivated_at, 282 283 takedown_ref: r.takedown_ref, 283 - email_verified: r.email_verified, 284 - discord_verified: r.discord_verified, 285 - telegram_verified: r.telegram_verified, 286 - signal_verified: r.signal_verified, 284 + channel_verification: ChannelVerificationStatus::new( 285 + r.email_verified, 286 + r.discord_verified, 287 + r.telegram_verified, 288 + r.signal_verified, 289 + ), 287 290 })) 288 291 } 289 292 ··· 308 311 309 312 async fn get_comms_prefs(&self, user_id: Uuid) -> Result<Option<UserCommsPrefs>, DbError> { 310 313 let row = sqlx::query!( 311 - r#"SELECT email, handle, preferred_comms_channel::text as "preferred_channel!", preferred_locale 314 + r#"SELECT email, handle, preferred_comms_channel as "preferred_channel!: CommsChannel", preferred_locale 312 315 FROM users WHERE id = $1"#, 313 316 user_id 314 317 ) ··· 601 604 let row = sqlx::query!( 602 605 r#"SELECT 603 606 email, 604 - preferred_comms_channel::text as "preferred_channel!", 607 + preferred_comms_channel as "preferred_channel!: CommsChannel", 605 608 discord_id, 606 609 discord_verified, 607 610 telegram_username, ··· 647 650 async fn update_preferred_comms_channel( 648 651 &self, 649 652 did: &Did, 650 - channel: &str, 653 + channel: CommsChannel, 651 654 ) -> Result<(), DbError> { 652 - sqlx::query( 653 - "UPDATE users SET preferred_comms_channel = $1::comms_channel, updated_at = NOW() WHERE did = $2", 655 + sqlx::query!( 656 + "UPDATE users SET preferred_comms_channel = $1, updated_at = NOW() WHERE did = $2", 657 + channel as CommsChannel, 658 + did.as_str() 654 659 ) 655 - .bind(channel) 656 - .bind(did.as_str()) 657 660 .execute(&self.pool) 658 661 .await 659 662 .map_err(map_sqlx_error)?; ··· 709 712 id: r.id, 710 713 handle: Handle::from(r.handle), 711 714 email: r.email, 712 - email_verified: r.email_verified, 713 - discord_verified: r.discord_verified, 714 - telegram_verified: r.telegram_verified, 715 - signal_verified: r.signal_verified, 715 + channel_verification: ChannelVerificationStatus::new( 716 + r.email_verified, 717 + r.discord_verified, 718 + r.telegram_verified, 719 + r.signal_verified, 720 + ), 716 721 })) 717 722 } 718 723 ··· 1065 1070 })) 1066 1071 } 1067 1072 1073 + async fn get_totp_record_state(&self, did: &Did) -> Result<Option<TotpRecordState>, DbError> { 1074 + self.get_totp_record(did) 1075 + .await 1076 + .map(|opt| opt.map(TotpRecordState::from)) 1077 + } 1078 + 1068 1079 async fn upsert_totp_secret( 1069 1080 &self, 1070 1081 did: &Did, ··· 1300 1311 preferred_comms_channel as "preferred_comms_channel!: CommsChannel", 1301 1312 deactivated_at, takedown_ref, 1302 1313 email_verified, discord_verified, telegram_verified, signal_verified, 1303 - account_type::text as "account_type!" 1314 + account_type as "account_type!: AccountType" 1304 1315 FROM users 1305 1316 WHERE handle = $1 OR email = $1 1306 1317 "#, ··· 1320 1331 preferred_comms_channel: row.preferred_comms_channel, 1321 1332 deactivated_at: row.deactivated_at, 1322 1333 takedown_ref: row.takedown_ref, 1323 - email_verified: row.email_verified, 1324 - discord_verified: row.discord_verified, 1325 - telegram_verified: row.telegram_verified, 1326 - signal_verified: row.signal_verified, 1334 + channel_verification: ChannelVerificationStatus::new( 1335 + row.email_verified, 1336 + row.discord_verified, 1337 + row.telegram_verified, 1338 + row.signal_verified, 1339 + ), 1327 1340 account_type: row.account_type, 1328 1341 }) 1329 1342 }) ··· 1348 1361 id: row.id, 1349 1362 two_factor_enabled: row.two_factor_enabled, 1350 1363 preferred_comms_channel: row.preferred_comms_channel, 1351 - email_verified: row.email_verified, 1352 - discord_verified: row.discord_verified, 1353 - telegram_verified: row.telegram_verified, 1354 - signal_verified: row.signal_verified, 1364 + channel_verification: ChannelVerificationStatus::new( 1365 + row.email_verified, 1366 + row.discord_verified, 1367 + row.telegram_verified, 1368 + row.signal_verified, 1369 + ), 1355 1370 }) 1356 1371 }) 1357 1372 } ··· 1376 1391 opt.map(|row| UserSessionInfo { 1377 1392 handle: Handle::from(row.handle), 1378 1393 email: row.email, 1379 - email_verified: row.email_verified, 1380 1394 is_admin: row.is_admin, 1381 1395 deactivated_at: row.deactivated_at, 1382 1396 takedown_ref: row.takedown_ref, 1383 1397 preferred_locale: row.preferred_locale, 1384 1398 preferred_comms_channel: row.preferred_comms_channel, 1385 - discord_verified: row.discord_verified, 1386 - telegram_verified: row.telegram_verified, 1387 - signal_verified: row.signal_verified, 1399 + channel_verification: ChannelVerificationStatus::new( 1400 + row.email_verified, 1401 + row.discord_verified, 1402 + row.telegram_verified, 1403 + row.signal_verified, 1404 + ), 1388 1405 migrated_to_pds: row.migrated_to_pds, 1389 1406 migrated_at: row.migrated_at, 1390 1407 }) ··· 1469 1486 email: row.email, 1470 1487 deactivated_at: row.deactivated_at, 1471 1488 takedown_ref: row.takedown_ref, 1472 - email_verified: row.email_verified, 1473 - discord_verified: row.discord_verified, 1474 - telegram_verified: row.telegram_verified, 1475 - signal_verified: row.signal_verified, 1489 + channel_verification: ChannelVerificationStatus::new( 1490 + row.email_verified, 1491 + row.discord_verified, 1492 + row.telegram_verified, 1493 + row.signal_verified, 1494 + ), 1476 1495 allow_legacy_login: row.allow_legacy_login, 1477 1496 migrated_to_pds: row.migrated_to_pds, 1478 1497 preferred_comms_channel: row.preferred_comms_channel, ··· 1543 1562 discord_id: row.discord_id, 1544 1563 telegram_username: row.telegram_username, 1545 1564 signal_number: row.signal_number, 1546 - email_verified: row.email_verified, 1547 - discord_verified: row.discord_verified, 1548 - telegram_verified: row.telegram_verified, 1549 - signal_verified: row.signal_verified, 1565 + channel_verification: ChannelVerificationStatus::new( 1566 + row.email_verified, 1567 + row.discord_verified, 1568 + row.telegram_verified, 1569 + row.signal_verified, 1570 + ), 1550 1571 }) 1551 1572 }) 1552 1573 }
+6 -4
crates/tranquil-oauth/src/lib.rs
··· 10 10 }; 11 11 pub use error::OAuthError; 12 12 pub use types::{ 13 - AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata, 14 - AuthorizedClientData, ClientAuth, Code, DPoPClaims, DeviceData, DeviceId, JwkPublicKey, Jwks, 15 - OAuthClientMetadata, ParResponse, ProtectedResourceMetadata, RefreshToken, RefreshTokenState, 16 - RequestData, RequestId, SessionId, TokenData, TokenId, TokenRequest, TokenResponse, 13 + AuthFlow, AuthFlowWithUser, AuthorizationRequestParameters, AuthorizationServerMetadata, 14 + AuthorizedClientData, ClientAuth, Code, CodeChallengeMethod, DPoPClaims, DeviceData, DeviceId, 15 + FlowAuthenticated, FlowAuthorized, FlowExpired, FlowNotAuthenticated, FlowNotAuthorized, 16 + FlowPending, JwkPublicKey, Jwks, OAuthClientMetadata, ParResponse, Prompt, 17 + ProtectedResourceMetadata, RefreshToken, RefreshTokenState, RequestData, RequestId, 18 + ResponseMode, ResponseType, SessionId, TokenData, TokenId, TokenRequest, TokenResponse, 17 19 };
+247 -144
crates/tranquil-oauth/src/types.rs
··· 1 1 use chrono::{DateTime, Utc}; 2 2 use serde::{Deserialize, Serialize}; 3 3 use serde_json::Value as JsonValue; 4 + use tranquil_types::Did; 4 5 5 - #[derive(Debug, Clone, Serialize, Deserialize)] 6 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 7 + #[serde(transparent)] 8 + #[sqlx(transparent)] 6 9 pub struct RequestId(pub String); 7 10 8 - #[derive(Debug, Clone, Serialize, Deserialize)] 11 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 12 + #[serde(transparent)] 13 + #[sqlx(transparent)] 9 14 pub struct TokenId(pub String); 10 15 11 - #[derive(Debug, Clone, Serialize, Deserialize)] 16 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 17 + #[serde(transparent)] 18 + #[sqlx(transparent)] 12 19 pub struct DeviceId(pub String); 13 20 14 - #[derive(Debug, Clone, Serialize, Deserialize)] 21 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 22 + #[serde(transparent)] 23 + #[sqlx(transparent)] 15 24 pub struct SessionId(pub String); 16 25 17 - #[derive(Debug, Clone, Serialize, Deserialize)] 26 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 27 + #[serde(transparent)] 28 + #[sqlx(transparent)] 18 29 pub struct Code(pub String); 19 30 20 - #[derive(Debug, Clone, Serialize, Deserialize)] 31 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 32 + #[serde(transparent)] 33 + #[sqlx(transparent)] 21 34 pub struct RefreshToken(pub String); 22 35 23 36 impl RequestId { ··· 82 95 PrivateKeyJwt { client_assertion: String }, 83 96 } 84 97 98 + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] 99 + #[serde(rename_all = "snake_case")] 100 + pub enum ResponseType { 101 + #[default] 102 + Code, 103 + } 104 + 105 + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] 106 + pub enum CodeChallengeMethod { 107 + #[default] 108 + #[serde(rename = "S256")] 109 + S256, 110 + #[serde(rename = "plain")] 111 + Plain, 112 + } 113 + 114 + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] 115 + #[serde(rename_all = "snake_case")] 116 + pub enum ResponseMode { 117 + #[default] 118 + Query, 119 + Fragment, 120 + FormPost, 121 + } 122 + 123 + impl ResponseMode { 124 + pub fn as_str(&self) -> &'static str { 125 + match self { 126 + Self::Query => "query", 127 + Self::Fragment => "fragment", 128 + Self::FormPost => "form_post", 129 + } 130 + } 131 + } 132 + 133 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] 134 + #[serde(rename_all = "snake_case")] 135 + pub enum Prompt { 136 + None, 137 + Login, 138 + Consent, 139 + SelectAccount, 140 + Create, 141 + } 142 + 143 + impl Prompt { 144 + pub fn as_str(&self) -> &'static str { 145 + match self { 146 + Self::None => "none", 147 + Self::Login => "login", 148 + Self::Consent => "consent", 149 + Self::SelectAccount => "select_account", 150 + Self::Create => "create", 151 + } 152 + } 153 + } 154 + 85 155 #[derive(Debug, Clone, Serialize, Deserialize)] 86 156 pub struct AuthorizationRequestParameters { 87 - pub response_type: String, 157 + pub response_type: ResponseType, 88 158 pub client_id: String, 89 159 pub redirect_uri: String, 90 160 pub scope: Option<String>, 91 161 pub state: Option<String>, 92 162 pub code_challenge: String, 93 - pub code_challenge_method: String, 94 - pub response_mode: Option<String>, 163 + pub code_challenge_method: CodeChallengeMethod, 164 + pub response_mode: Option<ResponseMode>, 95 165 pub login_hint: Option<String>, 96 166 pub dpop_jkt: Option<String>, 97 - pub prompt: Option<String>, 167 + pub prompt: Option<Prompt>, 98 168 #[serde(flatten)] 99 169 pub extra: Option<JsonValue>, 100 170 } ··· 105 175 pub client_auth: Option<ClientAuth>, 106 176 pub parameters: AuthorizationRequestParameters, 107 177 pub expires_at: DateTime<Utc>, 108 - pub did: Option<String>, 109 - pub device_id: Option<String>, 110 - pub code: Option<String>, 111 - pub controller_did: Option<String>, 178 + pub did: Option<Did>, 179 + pub device_id: Option<DeviceId>, 180 + pub code: Option<Code>, 181 + pub controller_did: Option<Did>, 112 182 } 113 183 114 184 #[derive(Debug, Clone)] 115 185 pub struct DeviceData { 116 - pub session_id: String, 186 + pub session_id: SessionId, 117 187 pub user_agent: Option<String>, 118 188 pub ip_address: String, 119 189 pub last_seen_at: DateTime<Utc>, ··· 121 191 122 192 #[derive(Debug, Clone)] 123 193 pub struct TokenData { 124 - pub did: String, 125 - pub token_id: String, 194 + pub did: Did, 195 + pub token_id: TokenId, 126 196 pub created_at: DateTime<Utc>, 127 197 pub updated_at: DateTime<Utc>, 128 198 pub expires_at: DateTime<Utc>, 129 199 pub client_id: String, 130 200 pub client_auth: ClientAuth, 131 - pub device_id: Option<String>, 201 + pub device_id: Option<DeviceId>, 132 202 pub parameters: AuthorizationRequestParameters, 133 203 pub details: Option<JsonValue>, 134 - pub code: Option<String>, 135 - pub current_refresh_token: Option<String>, 204 + pub code: Option<Code>, 205 + pub current_refresh_token: Option<RefreshToken>, 136 206 pub scope: Option<String>, 137 - pub controller_did: Option<String>, 207 + pub controller_did: Option<Did>, 138 208 } 139 209 140 210 #[derive(Debug, Clone, Serialize, Deserialize)] ··· 247 317 pub keys: Vec<JwkPublicKey>, 248 318 } 249 319 250 - #[derive(Debug, Clone, PartialEq, Eq)] 251 - pub enum AuthFlowState { 252 - Pending, 253 - Authenticated { 254 - did: String, 255 - device_id: Option<String>, 256 - }, 257 - Authorized { 258 - did: String, 259 - device_id: Option<String>, 260 - code: String, 261 - }, 262 - Expired, 320 + #[derive(Debug, Clone)] 321 + pub struct FlowPending { 322 + pub parameters: AuthorizationRequestParameters, 323 + pub client_id: String, 324 + pub client_auth: Option<ClientAuth>, 325 + pub expires_at: DateTime<Utc>, 326 + pub controller_did: Option<Did>, 263 327 } 264 328 265 - impl AuthFlowState { 266 - pub fn from_request_data(data: &RequestData) -> Self { 267 - if data.expires_at < chrono::Utc::now() { 268 - return AuthFlowState::Expired; 269 - } 270 - match (&data.did, &data.code) { 271 - (Some(did), Some(code)) => AuthFlowState::Authorized { 272 - did: did.clone(), 273 - device_id: data.device_id.clone(), 274 - code: code.clone(), 275 - }, 276 - (Some(did), None) => AuthFlowState::Authenticated { 277 - did: did.clone(), 278 - device_id: data.device_id.clone(), 279 - }, 280 - (None, _) => AuthFlowState::Pending, 281 - } 282 - } 329 + #[derive(Debug, Clone)] 330 + pub struct FlowAuthenticated { 331 + pub parameters: AuthorizationRequestParameters, 332 + pub client_id: String, 333 + pub client_auth: Option<ClientAuth>, 334 + pub expires_at: DateTime<Utc>, 335 + pub did: Did, 336 + pub device_id: Option<DeviceId>, 337 + pub controller_did: Option<Did>, 338 + } 283 339 284 - pub fn is_pending(&self) -> bool { 285 - matches!(self, AuthFlowState::Pending) 286 - } 340 + #[derive(Debug, Clone)] 341 + pub struct FlowAuthorized { 342 + pub parameters: AuthorizationRequestParameters, 343 + pub client_id: String, 344 + pub client_auth: Option<ClientAuth>, 345 + pub expires_at: DateTime<Utc>, 346 + pub did: Did, 347 + pub device_id: Option<DeviceId>, 348 + pub code: Code, 349 + pub controller_did: Option<Did>, 350 + } 287 351 288 - pub fn is_authenticated(&self) -> bool { 289 - matches!(self, AuthFlowState::Authenticated { .. }) 290 - } 352 + #[derive(Debug)] 353 + pub struct FlowExpired; 291 354 292 - pub fn is_authorized(&self) -> bool { 293 - matches!(self, AuthFlowState::Authorized { .. }) 355 + #[derive(Debug)] 356 + pub struct FlowNotAuthenticated; 357 + 358 + #[derive(Debug)] 359 + pub struct FlowNotAuthorized; 360 + 361 + #[derive(Debug, Clone)] 362 + pub enum AuthFlow { 363 + Pending(FlowPending), 364 + Authenticated(FlowAuthenticated), 365 + Authorized(FlowAuthorized), 366 + } 367 + 368 + #[derive(Debug, Clone)] 369 + pub enum AuthFlowWithUser { 370 + Authenticated(FlowAuthenticated), 371 + Authorized(FlowAuthorized), 372 + } 373 + 374 + impl AuthFlow { 375 + pub fn from_request_data(data: RequestData) -> Result<Self, FlowExpired> { 376 + if data.expires_at < chrono::Utc::now() { 377 + return Err(FlowExpired); 378 + } 379 + match (data.did, data.code) { 380 + (None, _) => Ok(AuthFlow::Pending(FlowPending { 381 + parameters: data.parameters, 382 + client_id: data.client_id, 383 + client_auth: data.client_auth, 384 + expires_at: data.expires_at, 385 + controller_did: data.controller_did, 386 + })), 387 + (Some(did), None) => Ok(AuthFlow::Authenticated(FlowAuthenticated { 388 + parameters: data.parameters, 389 + client_id: data.client_id, 390 + client_auth: data.client_auth, 391 + expires_at: data.expires_at, 392 + did, 393 + device_id: data.device_id, 394 + controller_did: data.controller_did, 395 + })), 396 + (Some(did), Some(code)) => Ok(AuthFlow::Authorized(FlowAuthorized { 397 + parameters: data.parameters, 398 + client_id: data.client_id, 399 + client_auth: data.client_auth, 400 + expires_at: data.expires_at, 401 + did, 402 + device_id: data.device_id, 403 + code, 404 + controller_did: data.controller_did, 405 + })), 406 + } 294 407 } 295 408 296 - pub fn is_expired(&self) -> bool { 297 - matches!(self, AuthFlowState::Expired) 409 + pub fn require_user(self) -> Result<AuthFlowWithUser, FlowNotAuthenticated> { 410 + match self { 411 + AuthFlow::Pending(_) => Err(FlowNotAuthenticated), 412 + AuthFlow::Authenticated(a) => Ok(AuthFlowWithUser::Authenticated(a)), 413 + AuthFlow::Authorized(a) => Ok(AuthFlowWithUser::Authorized(a)), 414 + } 298 415 } 299 416 300 - pub fn can_authenticate(&self) -> bool { 301 - matches!(self, AuthFlowState::Pending) 417 + pub fn require_authorized(self) -> Result<FlowAuthorized, FlowNotAuthorized> { 418 + match self { 419 + AuthFlow::Authorized(a) => Ok(a), 420 + _ => Err(FlowNotAuthorized), 421 + } 302 422 } 423 + } 303 424 304 - pub fn can_authorize(&self) -> bool { 305 - matches!(self, AuthFlowState::Authenticated { .. }) 425 + impl AuthFlowWithUser { 426 + pub fn did(&self) -> &Did { 427 + match self { 428 + AuthFlowWithUser::Authenticated(a) => &a.did, 429 + AuthFlowWithUser::Authorized(a) => &a.did, 430 + } 306 431 } 307 432 308 - pub fn can_exchange(&self) -> bool { 309 - matches!(self, AuthFlowState::Authorized { .. }) 433 + pub fn device_id(&self) -> Option<&DeviceId> { 434 + match self { 435 + AuthFlowWithUser::Authenticated(a) => a.device_id.as_ref(), 436 + AuthFlowWithUser::Authorized(a) => a.device_id.as_ref(), 437 + } 310 438 } 311 439 312 - pub fn did(&self) -> Option<&str> { 440 + pub fn parameters(&self) -> &AuthorizationRequestParameters { 313 441 match self { 314 - AuthFlowState::Authenticated { did, .. } | AuthFlowState::Authorized { did, .. } => { 315 - Some(did) 316 - } 317 - _ => None, 442 + AuthFlowWithUser::Authenticated(a) => &a.parameters, 443 + AuthFlowWithUser::Authorized(a) => &a.parameters, 318 444 } 319 445 } 320 446 321 - pub fn code(&self) -> Option<&str> { 447 + pub fn client_id(&self) -> &str { 322 448 match self { 323 - AuthFlowState::Authorized { code, .. } => Some(code), 324 - _ => None, 449 + AuthFlowWithUser::Authenticated(a) => &a.client_id, 450 + AuthFlowWithUser::Authorized(a) => &a.client_id, 325 451 } 326 452 } 327 - } 328 453 329 - impl std::fmt::Display for AuthFlowState { 330 - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 454 + pub fn controller_did(&self) -> Option<&Did> { 331 455 match self { 332 - AuthFlowState::Pending => write!(f, "pending"), 333 - AuthFlowState::Authenticated { did, .. } => write!(f, "authenticated ({})", did), 334 - AuthFlowState::Authorized { did, code, .. } => { 335 - write!( 336 - f, 337 - "authorized ({}, code={}...)", 338 - did, 339 - &code[..8.min(code.len())] 340 - ) 341 - } 342 - AuthFlowState::Expired => write!(f, "expired"), 456 + AuthFlowWithUser::Authenticated(a) => a.controller_did.as_ref(), 457 + AuthFlowWithUser::Authorized(a) => a.controller_did.as_ref(), 343 458 } 344 459 } 345 460 } ··· 406 521 use chrono::{Duration, Utc}; 407 522 408 523 fn make_request_data( 409 - did: Option<String>, 410 - code: Option<String>, 524 + did: Option<Did>, 525 + code: Option<Code>, 411 526 expires_in: Duration, 412 527 ) -> RequestData { 413 528 RequestData { 414 529 client_id: "test-client".into(), 415 530 client_auth: None, 416 531 parameters: AuthorizationRequestParameters { 417 - response_type: "code".into(), 532 + response_type: ResponseType::Code, 418 533 client_id: "test-client".into(), 419 534 redirect_uri: "https://example.com/callback".into(), 420 535 scope: Some("atproto".into()), 421 536 state: None, 422 537 code_challenge: "test".into(), 423 - code_challenge_method: "S256".into(), 538 + code_challenge_method: CodeChallengeMethod::S256, 424 539 response_mode: None, 425 540 login_hint: None, 426 541 dpop_jkt: None, ··· 435 550 } 436 551 } 437 552 553 + fn test_did(s: &str) -> Did { 554 + s.parse().expect("valid test DID") 555 + } 556 + 557 + fn test_code(s: &str) -> Code { 558 + Code(s.to_string()) 559 + } 560 + 438 561 #[test] 439 - fn test_auth_flow_state_pending() { 562 + fn test_auth_flow_pending() { 440 563 let data = make_request_data(None, None, Duration::minutes(5)); 441 - let state = AuthFlowState::from_request_data(&data); 442 - assert!(state.is_pending()); 443 - assert!(!state.is_authenticated()); 444 - assert!(!state.is_authorized()); 445 - assert!(!state.is_expired()); 446 - assert!(state.can_authenticate()); 447 - assert!(!state.can_authorize()); 448 - assert!(!state.can_exchange()); 449 - assert!(state.did().is_none()); 450 - assert!(state.code().is_none()); 564 + let flow = AuthFlow::from_request_data(data).expect("should not be expired"); 565 + assert!(matches!(flow, AuthFlow::Pending(_))); 566 + assert!(flow.clone().require_user().is_err()); 567 + assert!(flow.require_authorized().is_err()); 451 568 } 452 569 453 570 #[test] 454 - fn test_auth_flow_state_authenticated() { 455 - let data = make_request_data(Some("did:plc:test".into()), None, Duration::minutes(5)); 456 - let state = AuthFlowState::from_request_data(&data); 457 - assert!(!state.is_pending()); 458 - assert!(state.is_authenticated()); 459 - assert!(!state.is_authorized()); 460 - assert!(!state.is_expired()); 461 - assert!(!state.can_authenticate()); 462 - assert!(state.can_authorize()); 463 - assert!(!state.can_exchange()); 464 - assert_eq!(state.did(), Some("did:plc:test")); 465 - assert!(state.code().is_none()); 571 + fn test_auth_flow_authenticated() { 572 + let did = test_did("did:plc:test"); 573 + let data = make_request_data(Some(did.clone()), None, Duration::minutes(5)); 574 + let flow = AuthFlow::from_request_data(data).expect("should not be expired"); 575 + assert!(matches!(flow, AuthFlow::Authenticated(_))); 576 + let with_user = flow.clone().require_user().expect("should have user"); 577 + assert_eq!(with_user.did(), &did); 578 + assert!(flow.require_authorized().is_err()); 466 579 } 467 580 468 581 #[test] 469 - fn test_auth_flow_state_authorized() { 470 - let data = make_request_data( 471 - Some("did:plc:test".into()), 472 - Some("auth-code-123".into()), 473 - Duration::minutes(5), 474 - ); 475 - let state = AuthFlowState::from_request_data(&data); 476 - assert!(!state.is_pending()); 477 - assert!(!state.is_authenticated()); 478 - assert!(state.is_authorized()); 479 - assert!(!state.is_expired()); 480 - assert!(!state.can_authenticate()); 481 - assert!(!state.can_authorize()); 482 - assert!(state.can_exchange()); 483 - assert_eq!(state.did(), Some("did:plc:test")); 484 - assert_eq!(state.code(), Some("auth-code-123")); 582 + fn test_auth_flow_authorized() { 583 + let did = test_did("did:plc:test"); 584 + let code = test_code("auth-code-123"); 585 + let data = make_request_data(Some(did.clone()), Some(code.clone()), Duration::minutes(5)); 586 + let flow = AuthFlow::from_request_data(data).expect("should not be expired"); 587 + assert!(matches!(flow, AuthFlow::Authorized(_))); 588 + let with_user = flow.clone().require_user().expect("should have user"); 589 + assert_eq!(with_user.did(), &did); 590 + let authorized = flow.require_authorized().expect("should be authorized"); 591 + assert_eq!(authorized.did, did); 592 + assert_eq!(authorized.code, code); 485 593 } 486 594 487 595 #[test] 488 - fn test_auth_flow_state_expired() { 489 - let data = make_request_data( 490 - Some("did:plc:test".into()), 491 - Some("code".into()), 492 - Duration::minutes(-1), 493 - ); 494 - let state = AuthFlowState::from_request_data(&data); 495 - assert!(state.is_expired()); 496 - assert!(!state.can_authenticate()); 497 - assert!(!state.can_authorize()); 498 - assert!(!state.can_exchange()); 596 + fn test_auth_flow_expired() { 597 + let did = test_did("did:plc:test"); 598 + let code = test_code("code"); 599 + let data = make_request_data(Some(did), Some(code), Duration::minutes(-1)); 600 + let result = AuthFlow::from_request_data(data); 601 + assert!(result.is_err()); 499 602 } 500 603 501 604 #[test]
+1
crates/tranquil-pds/Cargo.toml
··· 67 67 thiserror = { workspace = true } 68 68 tokio = { workspace = true } 69 69 tokio-tungstenite = { workspace = true } 70 + tokio-util = { workspace = true } 70 71 tower = { workspace = true } 71 72 tower-http = { workspace = true } 72 73 tower-layer = { workspace = true }
+10 -12
crates/tranquil-pds/src/api/admin/account/delete.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 5 use crate::types::Did; ··· 9 9 response::{IntoResponse, Response}, 10 10 }; 11 11 use serde::Deserialize; 12 - use tracing::{error, warn}; 12 + use tracing::warn; 13 13 14 14 #[derive(Deserialize)] 15 15 pub struct DeleteAccountInput { ··· 26 26 .user_repo 27 27 .get_id_and_handle_by_did(did) 28 28 .await 29 - .map_err(|e| { 30 - error!("DB error in delete_account: {:?}", e); 31 - ApiError::InternalError(None) 32 - })? 29 + .log_db_err("in delete_account")? 33 30 .ok_or(ApiError::AccountNotFound) 34 31 .map(|row| (row.id, row.handle))?; 35 32 ··· 37 34 .user_repo 38 35 .admin_delete_account_complete(user_id, did) 39 36 .await 40 - .map_err(|e| { 41 - error!("Failed to delete account {}: {:?}", did, e); 42 - ApiError::InternalError(Some("Failed to delete account".into())) 43 - })?; 37 + .log_db_err("deleting account")?; 44 38 45 - if let Err(e) = 46 - crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await 39 + if let Err(e) = crate::api::repo::record::sequence_account_event( 40 + &state, 41 + did, 42 + tranquil_db_traits::AccountStatus::Deleted, 43 + ) 44 + .await 47 45 { 48 46 warn!( 49 47 "Failed to sequence account deletion event for {}: {}",
+5 -7
crates/tranquil-pds/src/api/admin/account/email.rs
··· 1 - use crate::api::error::{ApiError, AtpJson}; 1 + use crate::api::error::{ApiError, AtpJson, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::Did; 5 + use crate::util::pds_hostname; 5 6 use axum::{ 6 7 Json, 7 8 extract::State, ··· 9 10 response::{IntoResponse, Response}, 10 11 }; 11 12 use serde::{Deserialize, Serialize}; 12 - use tracing::{error, warn}; 13 + use tracing::warn; 13 14 14 15 #[derive(Deserialize)] 15 16 #[serde(rename_all = "camelCase")] ··· 39 40 .user_repo 40 41 .get_by_did(&input.recipient_did) 41 42 .await 42 - .map_err(|e| { 43 - error!("DB error in send_email: {:?}", e); 44 - ApiError::InternalError(None) 45 - })? 43 + .log_db_err("in send_email")? 46 44 .ok_or(ApiError::AccountNotFound)?; 47 45 48 46 let email = user.email.ok_or(ApiError::NoEmail)?; 49 47 let (user_id, handle) = (user.id, user.handle); 50 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 48 + let hostname = pds_hostname(); 51 49 let subject = input 52 50 .subject 53 51 .clone()
+6 -13
crates/tranquil-pds/src/api/admin/account/info.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::{Did, Handle}; ··· 10 10 }; 11 11 use serde::{Deserialize, Serialize}; 12 12 use std::collections::HashMap; 13 - use tracing::error; 14 13 15 14 #[derive(Deserialize)] 16 15 pub struct GetAccountInfoParams { ··· 74 73 .infra_repo 75 74 .get_admin_account_info_by_did(&params.did) 76 75 .await 77 - .map_err(|e| { 78 - error!("DB error in get_account_info: {:?}", e); 79 - ApiError::InternalError(None) 80 - })? 76 + .log_db_err("in get_account_info")? 81 77 .ok_or(ApiError::AccountNotFound)?; 82 78 83 79 let invited_by = get_invited_by(&state, account.id).await; ··· 153 149 .map(|ic| InviteCodeInfo { 154 150 code: ic.code.clone(), 155 151 available: ic.available_uses, 156 - disabled: ic.disabled, 152 + disabled: ic.state.is_disabled(), 157 153 for_account: ic.for_account, 158 154 created_by: ic.created_by, 159 155 created_at: ic.created_at.to_rfc3339(), ··· 181 177 Some(InviteCodeInfo { 182 178 code: info.code, 183 179 available: info.available_uses, 184 - disabled: info.disabled, 180 + disabled: info.state.is_disabled(), 185 181 for_account: info.for_account, 186 182 created_by: info.created_by, 187 183 created_at: info.created_at.to_rfc3339(), ··· 214 210 .infra_repo 215 211 .get_admin_account_infos_by_dids(&dids_typed) 216 212 .await 217 - .map_err(|e| { 218 - error!("Failed to fetch account infos: {:?}", e); 219 - ApiError::InternalError(None) 220 - })?; 213 + .log_db_err("fetching account infos")?; 221 214 222 215 let user_ids: Vec<uuid::Uuid> = accounts.iter().map(|u| u.id).collect(); 223 216 ··· 272 265 let info = InviteCodeInfo { 273 266 code: ic.code.clone(), 274 267 available: ic.available_uses, 275 - disabled: ic.disabled, 268 + disabled: ic.state.is_disabled(), 276 269 for_account: ic.for_account, 277 270 created_by: ic.created_by, 278 271 created_at: ic.created_at.to_rfc3339(),
+2 -6
crates/tranquil-pds/src/api/admin/account/search.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::{Did, Handle}; ··· 9 9 response::{IntoResponse, Response}, 10 10 }; 11 11 use serde::{Deserialize, Serialize}; 12 - use tracing::error; 13 12 14 13 #[derive(Deserialize)] 15 14 pub struct SearchAccountsParams { ··· 66 65 limit + 1, 67 66 ) 68 67 .await 69 - .map_err(|e| { 70 - error!("DB error in search_accounts: {:?}", e); 71 - ApiError::InternalError(None) 72 - })?; 68 + .log_db_err("in search_accounts")?; 73 69 74 70 let has_more = rows.len() > limit as usize; 75 71 let accounts: Vec<AccountView> = rows
+3 -3
crates/tranquil-pds/src/api/admin/account/update.rs
··· 3 3 use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 5 use crate::types::{Did, Handle, PlainPassword}; 6 + use crate::util::pds_hostname_without_port; 6 7 use axum::{ 7 8 Json, 8 9 extract::State, ··· 69 70 { 70 71 return Err(ApiError::InvalidHandle(None)); 71 72 } 72 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 73 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 73 + let hostname_for_handles = pds_hostname_without_port(); 74 74 let handle = if !input_handle.contains('.') { 75 75 format!("{}.{}", input_handle, hostname_for_handles) 76 76 } else { ··· 84 84 .ok() 85 85 .flatten() 86 86 .ok_or(ApiError::AccountNotFound)?; 87 - let handle_for_check = Handle::new_unchecked(&handle); 87 + let handle_for_check = unsafe { Handle::new_unchecked(&handle) }; 88 88 if let Ok(true) = state 89 89 .user_repo 90 90 .check_handle_exists(&handle_for_check, user_id)
+14 -50
crates/tranquil-pds/src/api/admin/config.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use axum::{Json, extract::State}; ··· 56 56 .infra_repo 57 57 .get_server_configs(keys) 58 58 .await 59 - .map_err(|e| { 60 - error!("DB error fetching server config: {:?}", e); 61 - ApiError::InternalError(None) 62 - })?; 59 + .log_db_err("fetching server config")?; 63 60 64 61 let config_map: std::collections::HashMap<String, String> = rows.into_iter().collect(); 65 62 ··· 92 89 .infra_repo 93 90 .upsert_server_config("server_name", trimmed) 94 91 .await 95 - .map_err(|e| { 96 - error!("DB error upserting server_name: {:?}", e); 97 - ApiError::InternalError(None) 98 - })?; 92 + .log_db_err("upserting server_name")?; 99 93 } 100 94 101 95 if let Some(ref color) = req.primary_color { ··· 104 98 .infra_repo 105 99 .delete_server_config("primary_color") 106 100 .await 107 - .map_err(|e| { 108 - error!("DB error deleting primary_color: {:?}", e); 109 - ApiError::InternalError(None) 110 - })?; 101 + .log_db_err("deleting primary_color")?; 111 102 } else if is_valid_hex_color(color) { 112 103 state 113 104 .infra_repo 114 105 .upsert_server_config("primary_color", color) 115 106 .await 116 - .map_err(|e| { 117 - error!("DB error upserting primary_color: {:?}", e); 118 - ApiError::InternalError(None) 119 - })?; 107 + .log_db_err("upserting primary_color")?; 120 108 } else { 121 109 return Err(ApiError::InvalidRequest( 122 110 "Invalid primary color format (expected #RRGGBB)".into(), ··· 130 118 .infra_repo 131 119 .delete_server_config("primary_color_dark") 132 120 .await 133 - .map_err(|e| { 134 - error!("DB error deleting primary_color_dark: {:?}", e); 135 - ApiError::InternalError(None) 136 - })?; 121 + .log_db_err("deleting primary_color_dark")?; 137 122 } else if is_valid_hex_color(color) { 138 123 state 139 124 .infra_repo 140 125 .upsert_server_config("primary_color_dark", color) 141 126 .await 142 - .map_err(|e| { 143 - error!("DB error upserting primary_color_dark: {:?}", e); 144 - ApiError::InternalError(None) 145 - })?; 127 + .log_db_err("upserting primary_color_dark")?; 146 128 } else { 147 129 return Err(ApiError::InvalidRequest( 148 130 "Invalid primary dark color format (expected #RRGGBB)".into(), ··· 156 138 .infra_repo 157 139 .delete_server_config("secondary_color") 158 140 .await 159 - .map_err(|e| { 160 - error!("DB error deleting secondary_color: {:?}", e); 161 - ApiError::InternalError(None) 162 - })?; 141 + .log_db_err("deleting secondary_color")?; 163 142 } else if is_valid_hex_color(color) { 164 143 state 165 144 .infra_repo 166 145 .upsert_server_config("secondary_color", color) 167 146 .await 168 - .map_err(|e| { 169 - error!("DB error upserting secondary_color: {:?}", e); 170 - ApiError::InternalError(None) 171 - })?; 147 + .log_db_err("upserting secondary_color")?; 172 148 } else { 173 149 return Err(ApiError::InvalidRequest( 174 150 "Invalid secondary color format (expected #RRGGBB)".into(), ··· 182 158 .infra_repo 183 159 .delete_server_config("secondary_color_dark") 184 160 .await 185 - .map_err(|e| { 186 - error!("DB error deleting secondary_color_dark: {:?}", e); 187 - ApiError::InternalError(None) 188 - })?; 161 + .log_db_err("deleting secondary_color_dark")?; 189 162 } else if is_valid_hex_color(color) { 190 163 state 191 164 .infra_repo 192 165 .upsert_server_config("secondary_color_dark", color) 193 166 .await 194 - .map_err(|e| { 195 - error!("DB error upserting secondary_color_dark: {:?}", e); 196 - ApiError::InternalError(None) 197 - })?; 167 + .log_db_err("upserting secondary_color_dark")?; 198 168 } else { 199 169 return Err(ApiError::InvalidRequest( 200 170 "Invalid secondary dark color format (expected #RRGGBB)".into(), ··· 217 187 }; 218 188 219 189 if let Some(old_cid_str) = should_delete_old { 220 - let old_cid = CidLink::new_unchecked(old_cid_str); 190 + let old_cid = unsafe { CidLink::new_unchecked(old_cid_str) }; 221 191 if let Ok(Some(storage_key)) = 222 192 state.infra_repo.get_blob_storage_key_by_cid(&old_cid).await 223 193 { ··· 235 205 .infra_repo 236 206 .delete_server_config("logo_cid") 237 207 .await 238 - .map_err(|e| { 239 - error!("DB error deleting logo_cid: {:?}", e); 240 - ApiError::InternalError(None) 241 - })?; 208 + .log_db_err("deleting logo_cid")?; 242 209 } else { 243 210 state 244 211 .infra_repo 245 212 .upsert_server_config("logo_cid", logo_cid) 246 213 .await 247 - .map_err(|e| { 248 - error!("DB error upserting logo_cid: {:?}", e); 249 - ApiError::InternalError(None) 250 - })?; 214 + .log_db_err("upserting logo_cid")?; 251 215 } 252 216 } 253 217
+3 -6
crates/tranquil-pds/src/api/admin/invite.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 5 use axum::{ ··· 91 91 .infra_repo 92 92 .list_invite_codes(params.cursor.as_deref(), limit, sort_order) 93 93 .await 94 - .map_err(|e| { 95 - error!("DB error fetching invite codes: {:?}", e); 96 - ApiError::InternalError(None) 97 - })?; 94 + .log_db_err("fetching invite codes")?; 98 95 99 96 let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|r| r.created_by_user).collect(); 100 97 let code_strings: Vec<String> = codes_rows.iter().map(|r| r.code.clone()).collect(); ··· 138 135 InviteCodeInfo { 139 136 code: r.code.clone(), 140 137 available: r.available_uses, 141 - disabled: r.disabled.unwrap_or(false), 138 + disabled: r.state().is_disabled(), 142 139 for_account: creator_did.clone(), 143 140 created_by: creator_did, 144 141 created_at: r.created_at.to_rfc3339(),
+9 -19
crates/tranquil-pds/src/api/admin/status.rs
··· 175 175 Some("com.atproto.admin.defs#repoRef") => { 176 176 let did_str = input.subject.get("did").and_then(|d| d.as_str()); 177 177 if let Some(did_str) = did_str { 178 - let did = Did::new_unchecked(did_str); 178 + let did = unsafe { Did::new_unchecked(did_str) }; 179 179 if let Some(takedown) = &input.takedown { 180 180 let takedown_ref = if takedown.applied { 181 181 takedown.r#ref.as_deref() ··· 207 207 } 208 208 if let Some(takedown) = &input.takedown { 209 209 let status = if takedown.applied { 210 - Some("takendown") 210 + tranquil_db_traits::AccountStatus::Takendown 211 211 } else { 212 - None 212 + tranquil_db_traits::AccountStatus::Active 213 213 }; 214 - if let Err(e) = crate::api::repo::record::sequence_account_event( 215 - &state, 216 - &did, 217 - !takedown.applied, 218 - status, 219 - ) 220 - .await 214 + if let Err(e) = 215 + crate::api::repo::record::sequence_account_event(&state, &did, status).await 221 216 { 222 217 warn!("Failed to sequence account event for takedown: {}", e); 223 218 } 224 219 } 225 220 if let Some(deactivated) = &input.deactivated { 226 221 let status = if deactivated.applied { 227 - Some("deactivated") 222 + tranquil_db_traits::AccountStatus::Deactivated 228 223 } else { 229 - None 224 + tranquil_db_traits::AccountStatus::Active 230 225 }; 231 - if let Err(e) = crate::api::repo::record::sequence_account_event( 232 - &state, 233 - &did, 234 - !deactivated.applied, 235 - status, 236 - ) 237 - .await 226 + if let Err(e) = 227 + crate::api::repo::record::sequence_account_event(&state, &did, status).await 238 228 { 239 229 warn!("Failed to sequence account event for deactivation: {}", e); 240 230 }
+2 -2
crates/tranquil-pds/src/api/age_assurance.rs
··· 33 33 } 34 34 35 35 async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option<String> { 36 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 36 + let auth_header = crate::util::get_header_str(headers, "Authorization"); 37 37 tracing::debug!(?auth_header, "age assurance: extracting token"); 38 38 39 39 let extracted = extract_auth_token_from_header(auth_header)?; 40 40 tracing::debug!("age assurance: got token, validating"); 41 41 42 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 42 + let dpop_proof = crate::util::get_header_str(headers, "DPoP"); 43 43 let http_uri = "/"; 44 44 45 45 let auth_user = match validate_token_with_dpop(
+66 -116
crates/tranquil-pds/src/api/delegation.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::create_signed_commit; 3 3 use crate::auth::{Active, Auth}; 4 - use crate::delegation::{DelegationActionType, SCOPE_PRESETS, scopes}; 5 - use crate::state::{AppState, RateLimitKind}; 4 + use crate::delegation::{ 5 + DelegationActionType, SCOPE_PRESETS, ValidatedDelegationScope, verify_can_add_controllers, 6 + verify_can_be_controller, verify_can_control_accounts, 7 + }; 8 + use crate::rate_limit::{AccountCreationLimit, RateLimited}; 9 + use crate::state::AppState; 6 10 use crate::types::{Did, Handle, Nsid, Rkey}; 7 - use crate::util::extract_client_ip; 11 + use crate::util::{pds_hostname, pds_hostname_without_port}; 8 12 use axum::{ 9 13 Json, 10 14 extract::{Query, State}, 11 - http::{HeaderMap, StatusCode}, 15 + http::StatusCode, 12 16 response::{IntoResponse, Response}, 13 17 }; 14 18 use jacquard_common::types::{integer::LimitedU32, string::Tid}; ··· 57 61 .map(|c| ControllerInfo { 58 62 did: c.did, 59 63 handle: c.handle, 60 - granted_scopes: c.granted_scopes, 64 + granted_scopes: c.granted_scopes.into_string(), 61 65 granted_at: c.granted_at, 62 66 is_active: c.is_active, 63 67 }) ··· 69 73 #[derive(Debug, Deserialize)] 70 74 pub struct AddControllerInput { 71 75 pub controller_did: Did, 72 - pub granted_scopes: String, 76 + pub granted_scopes: ValidatedDelegationScope, 73 77 } 74 78 75 79 pub async fn add_controller( ··· 77 81 auth: Auth<Active>, 78 82 Json(input): Json<AddControllerInput>, 79 83 ) -> Result<Response, ApiError> { 80 - if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) { 81 - return Ok(ApiError::InvalidScopes(e).into_response()); 82 - } 83 - 84 84 let controller_exists = state 85 85 .user_repo 86 86 .get_by_did(&input.controller_did) ··· 93 93 return Ok(ApiError::ControllerNotFound.into_response()); 94 94 } 95 95 96 - match state.delegation_repo.controls_any_accounts(&auth.did).await { 97 - Ok(true) => { 98 - return Ok(ApiError::InvalidDelegation( 99 - "Cannot add controllers to an account that controls other accounts".into(), 100 - ) 101 - .into_response()); 102 - } 103 - Err(e) => { 104 - tracing::error!("Failed to check delegation status: {:?}", e); 105 - return Ok( 106 - ApiError::InternalError(Some("Failed to verify delegation status".into())) 107 - .into_response(), 108 - ); 109 - } 110 - Ok(false) => {} 111 - } 96 + let can_add = match verify_can_add_controllers(&state, &auth).await { 97 + Ok(proof) => proof, 98 + Err(response) => return Ok(response), 99 + }; 112 100 113 - match state 114 - .delegation_repo 115 - .has_any_controllers(&input.controller_did) 116 - .await 117 - { 118 - Ok(true) => { 119 - return Ok(ApiError::InvalidDelegation( 120 - "Cannot add a controlled account as a controller".into(), 121 - ) 122 - .into_response()); 123 - } 124 - Err(e) => { 125 - tracing::error!("Failed to check controller status: {:?}", e); 126 - return Ok( 127 - ApiError::InternalError(Some("Failed to verify controller status".into())) 128 - .into_response(), 129 - ); 130 - } 131 - Ok(false) => {} 132 - } 101 + let can_be_controller = match verify_can_be_controller(&state, &input.controller_did).await { 102 + Ok(proof) => proof, 103 + Err(response) => return Ok(response), 104 + }; 133 105 134 106 match state 135 107 .delegation_repo 136 108 .create_delegation( 137 - &auth.did, 138 - &input.controller_did, 109 + can_add.did(), 110 + can_be_controller.did(), 139 111 &input.granted_scopes, 140 - &auth.did, 112 + can_add.did(), 141 113 ) 142 114 .await 143 115 { ··· 145 117 let _ = state 146 118 .delegation_repo 147 119 .log_delegation_action( 148 - &auth.did, 149 - &auth.did, 150 - Some(&input.controller_did), 120 + can_add.did(), 121 + can_add.did(), 122 + Some(can_be_controller.did()), 151 123 DelegationActionType::GrantCreated, 152 124 Some(serde_json::json!({ 153 - "granted_scopes": input.granted_scopes 125 + "granted_scopes": input.granted_scopes.as_str() 154 126 })), 155 127 None, 156 128 None, ··· 235 207 #[derive(Debug, Deserialize)] 236 208 pub struct UpdateControllerScopesInput { 237 209 pub controller_did: Did, 238 - pub granted_scopes: String, 210 + pub granted_scopes: ValidatedDelegationScope, 239 211 } 240 212 241 213 pub async fn update_controller_scopes( ··· 243 215 auth: Auth<Active>, 244 216 Json(input): Json<UpdateControllerScopesInput>, 245 217 ) -> Result<Response, ApiError> { 246 - if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) { 247 - return Ok(ApiError::InvalidScopes(e).into_response()); 248 - } 249 - 250 218 match state 251 219 .delegation_repo 252 220 .update_delegation_scopes(&auth.did, &input.controller_did, &input.granted_scopes) ··· 261 229 Some(&input.controller_did), 262 230 DelegationActionType::ScopesModified, 263 231 Some(serde_json::json!({ 264 - "new_scopes": input.granted_scopes 232 + "new_scopes": input.granted_scopes.as_str() 265 233 })), 266 234 None, 267 235 None, ··· 326 294 .map(|a| DelegatedAccountInfo { 327 295 did: a.did, 328 296 handle: a.handle, 329 - granted_scopes: a.granted_scopes, 297 + granted_scopes: a.granted_scopes.into_string(), 330 298 granted_at: a.granted_at, 331 299 }) 332 300 .collect(), ··· 443 411 pub struct CreateDelegatedAccountInput { 444 412 pub handle: String, 445 413 pub email: Option<String>, 446 - pub controller_scopes: String, 414 + pub controller_scopes: ValidatedDelegationScope, 447 415 pub invite_code: Option<String>, 448 416 } 449 417 ··· 456 424 457 425 pub async fn create_delegated_account( 458 426 State(state): State<AppState>, 459 - headers: HeaderMap, 427 + _rate_limit: RateLimited<AccountCreationLimit>, 460 428 auth: Auth<Active>, 461 429 Json(input): Json<CreateDelegatedAccountInput>, 462 430 ) -> Result<Response, ApiError> { 463 - let client_ip = extract_client_ip(&headers); 464 - if !state 465 - .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 466 - .await 467 - { 468 - warn!(ip = %client_ip, "Delegated account creation rate limit exceeded"); 469 - return Ok(ApiError::RateLimitExceeded(Some( 470 - "Too many account creation attempts. Please try again later.".into(), 471 - )) 472 - .into_response()); 473 - } 474 - 475 - if let Err(e) = scopes::validate_delegation_scopes(&input.controller_scopes) { 476 - return Ok(ApiError::InvalidScopes(e).into_response()); 477 - } 478 - 479 - match state.delegation_repo.has_any_controllers(&auth.did).await { 480 - Ok(true) => { 481 - return Ok(ApiError::InvalidDelegation( 482 - "Cannot create delegated accounts from a controlled account".into(), 483 - ) 484 - .into_response()); 485 - } 486 - Err(e) => { 487 - tracing::error!("Failed to check controller status: {:?}", e); 488 - return Ok( 489 - ApiError::InternalError(Some("Failed to verify controller status".into())) 490 - .into_response(), 491 - ); 492 - } 493 - Ok(false) => {} 494 - } 431 + let can_control = match verify_can_control_accounts(&state, &auth).await { 432 + Ok(proof) => proof, 433 + Err(response) => return Ok(response), 434 + }; 495 435 496 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 497 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 436 + let hostname = pds_hostname(); 437 + let hostname_for_handles = pds_hostname_without_port(); 498 438 let pds_suffix = format!(".{}", hostname_for_handles); 499 439 500 440 let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) { ··· 527 467 return Ok(ApiError::InvalidEmail.into_response()); 528 468 } 529 469 530 - if let Some(ref code) = input.invite_code { 531 - let valid = state 532 - .infra_repo 533 - .is_invite_code_valid(code) 534 - .await 535 - .unwrap_or(false); 536 - 537 - if !valid { 538 - return Ok(ApiError::InvalidInviteCode.into_response()); 470 + let validated_invite_code = if let Some(ref code) = input.invite_code { 471 + match state.infra_repo.validate_invite_code(code).await { 472 + Ok(validated) => Some(validated), 473 + Err(_) => return Ok(ApiError::InvalidInviteCode.into_response()), 539 474 } 540 475 } else { 541 476 let invite_required = std::env::var("INVITE_CODE_REQUIRED") ··· 544 479 if invite_required { 545 480 return Ok(ApiError::InviteCodeRequired.into_response()); 546 481 } 547 - } 482 + None 483 + }; 548 484 549 485 use k256::ecdsa::SigningKey; 550 486 use rand::rngs::OsRng; ··· 593 529 .into_response()); 594 530 } 595 531 596 - let did = Did::new_unchecked(&genesis_result.did); 597 - let handle = Handle::new_unchecked(&handle); 598 - info!(did = %did, handle = %handle, controller = %&auth.did, "Created DID for delegated account"); 532 + let did = unsafe { Did::new_unchecked(&genesis_result.did) }; 533 + let handle = unsafe { Handle::new_unchecked(&handle) }; 534 + info!(did = %did, handle = %handle, controller = %can_control.did(), "Created DID for delegated account"); 599 535 600 536 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 601 537 Ok(bytes) => bytes, ··· 635 571 handle: handle.clone(), 636 572 email: email.clone(), 637 573 did: did.clone(), 638 - controller_did: auth.did.clone(), 639 - controller_scopes: input.controller_scopes.clone(), 574 + controller_did: can_control.did().clone(), 575 + controller_scopes: input.controller_scopes.as_str().to_string(), 640 576 encrypted_key_bytes, 641 577 encryption_version: crate::config::ENCRYPTION_VERSION, 642 578 commit_cid: commit_cid.to_string(), ··· 645 581 invite_code: input.invite_code.clone(), 646 582 }; 647 583 648 - let _user_id = match state 584 + let user_id = match state 649 585 .user_repo 650 586 .create_delegated_account(&create_input) 651 587 .await ··· 663 599 } 664 600 }; 665 601 602 + if let Some(validated) = validated_invite_code 603 + && let Err(e) = state 604 + .infra_repo 605 + .record_invite_code_use(&validated, user_id) 606 + .await 607 + { 608 + warn!("Failed to record invite code use for {}: {:?}", did, e); 609 + } 610 + 666 611 if let Err(e) = 667 612 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle)).await 668 613 { 669 614 warn!("Failed to sequence identity event for {}: {}", did, e); 670 615 } 671 - if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await 616 + if let Err(e) = crate::api::repo::record::sequence_account_event( 617 + &state, 618 + &did, 619 + tranquil_db_traits::AccountStatus::Active, 620 + ) 621 + .await 672 622 { 673 623 warn!("Failed to sequence account event for {}: {}", did, e); 674 624 } ··· 677 627 "$type": "app.bsky.actor.profile", 678 628 "displayName": handle 679 629 }); 680 - let profile_collection = Nsid::new_unchecked("app.bsky.actor.profile"); 681 - let profile_rkey = Rkey::new_unchecked("self"); 630 + let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") }; 631 + let profile_rkey = unsafe { Rkey::new_unchecked("self") }; 682 632 if let Err(e) = crate::api::repo::record::create_record_internal( 683 633 &state, 684 634 &did, ··· 700 650 DelegationActionType::GrantCreated, 701 651 Some(json!({ 702 652 "account_created": true, 703 - "granted_scopes": input.controller_scopes 653 + "granted_scopes": input.controller_scopes.as_str() 704 654 })), 705 655 None, 706 656 None,
+19
crates/tranquil-pds/src/api/error.rs
··· 694 694 } 695 695 } 696 696 697 + impl From<crate::rate_limit::UserRateLimitError> for ApiError { 698 + fn from(e: crate::rate_limit::UserRateLimitError) -> Self { 699 + Self::RateLimitExceeded(e.message) 700 + } 701 + } 702 + 697 703 #[allow(clippy::result_large_err)] 698 704 pub fn parse_did(s: &str) -> Result<tranquil_types::Did, Response> { 699 705 s.parse() ··· 756 762 _ => "Invalid request body".to_string(), 757 763 } 758 764 } 765 + 766 + pub trait DbResultExt<T> { 767 + fn log_db_err(self, ctx: &str) -> Result<T, ApiError>; 768 + } 769 + 770 + impl<T, E: std::fmt::Debug> DbResultExt<T> for Result<T, E> { 771 + fn log_db_err(self, ctx: &str) -> Result<T, ApiError> { 772 + self.map_err(|e| { 773 + tracing::error!("DB error {}: {:?}", ctx, e); 774 + ApiError::DatabaseError 775 + }) 776 + } 777 + }
+40 -65
crates/tranquil-pds/src/api/identity/account.rs
··· 3 3 use crate::api::repo::record::utils::create_signed_commit; 4 4 use crate::auth::{ServiceTokenVerifier, extract_auth_token_from_header, is_service_token}; 5 5 use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key}; 6 - use crate::state::{AppState, RateLimitKind}; 6 + use crate::rate_limit::{AccountCreationLimit, RateLimited}; 7 + use crate::state::AppState; 7 8 use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey}; 9 + use crate::util::{pds_hostname, pds_hostname_without_port}; 8 10 use crate::validation::validate_password; 9 11 use axum::{ 10 12 Json, ··· 22 24 use std::sync::Arc; 23 25 use tracing::{debug, error, info, warn}; 24 26 25 - fn extract_client_ip(headers: &HeaderMap) -> String { 26 - if let Some(forwarded) = headers.get("x-forwarded-for") 27 - && let Ok(value) = forwarded.to_str() 28 - && let Some(first_ip) = value.split(',').next() 29 - { 30 - return first_ip.trim().to_string(); 31 - } 32 - if let Some(real_ip) = headers.get("x-real-ip") 33 - && let Ok(value) = real_ip.to_str() 34 - { 35 - return value.trim().to_string(); 36 - } 37 - "unknown".to_string() 38 - } 39 - 40 27 #[derive(Deserialize)] 41 28 #[serde(rename_all = "camelCase")] 42 29 pub struct CreateAccountInput { ··· 68 55 69 56 pub async fn create_account( 70 57 State(state): State<AppState>, 58 + _rate_limit: RateLimited<AccountCreationLimit>, 71 59 headers: HeaderMap, 72 60 Json(input): Json<CreateAccountInput>, 73 61 ) -> Response { ··· 84 72 } else { 85 73 info!("create_account called"); 86 74 } 87 - let client_ip = extract_client_ip(&headers); 88 - if !state 89 - .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 90 - .await 91 - { 92 - warn!(ip = %client_ip, "Account creation rate limit exceeded"); 93 - return ApiError::RateLimitExceeded(Some( 94 - "Too many account creation attempts. Please try again later.".into(), 95 - )) 96 - .into_response(); 97 - } 98 75 99 76 let migration_auth = if let Some(extracted) = 100 - extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 77 + extract_auth_token_from_header(crate::util::get_header_str(&headers, "Authorization")) 101 78 { 102 79 let token = extracted.token; 103 80 if is_service_token(&token) { ··· 143 120 if (is_migration || is_did_web_byod) 144 121 && let (Some(provided_did), Some(auth_did)) = (input.did.as_ref(), migration_auth.as_ref()) 145 122 { 146 - if provided_did != auth_did { 123 + if provided_did != auth_did.as_str() { 147 124 info!( 148 125 "[MIGRATION] createAccount: Service token mismatch - token_did={} provided_did={}", 149 126 auth_did, provided_did ··· 164 141 } 165 142 } 166 143 167 - let hostname_for_validation = 168 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 144 + let hostname_for_validation = pds_hostname_without_port(); 169 145 let pds_suffix = format!(".{}", hostname_for_validation); 170 146 171 147 let validated_short_handle = if !input.handle.contains('.') ··· 242 218 _ => return ApiError::InvalidVerificationChannel.into_response(), 243 219 }) 244 220 }; 245 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 246 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 221 + let hostname = pds_hostname(); 222 + let hostname_for_handles = pds_hostname_without_port(); 247 223 let pds_endpoint = format!("https://{}", hostname); 248 224 let suffix = format!(".{}", hostname_for_handles); 249 225 let handle = if input.handle.ends_with(&suffix) { ··· 308 284 } 309 285 if !is_did_web_byod 310 286 && let Err(e) = 311 - verify_did_web(d, &hostname, &input.handle, input.signing_key.as_deref()).await 287 + verify_did_web(d, hostname, &input.handle, input.signing_key.as_deref()).await 312 288 { 313 289 return ApiError::InvalidDid(e).into_response(); 314 290 } ··· 322 298 d.clone() 323 299 } else if d.starts_with("did:web:") { 324 300 if !is_did_web_byod 325 - && let Err(e) = verify_did_web( 326 - d, 327 - &hostname, 328 - &input.handle, 329 - input.signing_key.as_deref(), 330 - ) 331 - .await 301 + && let Err(e) = 302 + verify_did_web(d, hostname, &input.handle, input.signing_key.as_deref()) 303 + .await 332 304 { 333 305 return ApiError::InvalidDid(e).into_response(); 334 306 } ··· 408 380 }; 409 381 if is_migration { 410 382 let reactivate_input = tranquil_db_traits::MigrationReactivationInput { 411 - did: Did::new_unchecked(&did), 412 - new_handle: Handle::new_unchecked(&handle), 383 + did: unsafe { Did::new_unchecked(&did) }, 384 + new_handle: unsafe { Handle::new_unchecked(&handle) }, 413 385 new_email: email.clone(), 414 386 }; 415 387 match state ··· 463 435 } 464 436 }; 465 437 let session_data = tranquil_db_traits::SessionTokenCreate { 466 - did: Did::new_unchecked(&did), 438 + did: unsafe { Did::new_unchecked(&did) }, 467 439 access_jti: access_meta.jti.clone(), 468 440 refresh_jti: refresh_meta.jti.clone(), 469 441 access_expires_at: access_meta.expires_at, 470 442 refresh_expires_at: refresh_meta.expires_at, 471 - legacy_login: false, 443 + login_type: tranquil_db_traits::LoginType::Modern, 472 444 mfa_verified: false, 473 445 scope: None, 474 446 controller_did: None, ··· 478 450 error!("Error creating session: {:?}", e); 479 451 return ApiError::InternalError(None).into_response(); 480 452 } 481 - let hostname = 482 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 453 + let hostname = pds_hostname(); 483 454 let verification_required = if let Some(ref user_email) = email { 484 455 let token = 485 456 crate::auth::verification_token::generate_migration_token(&did, user_email); ··· 491 462 reactivated.user_id, 492 463 user_email, 493 464 &formatted_token, 494 - &hostname, 465 + hostname, 495 466 ) 496 467 .await 497 468 { ··· 505 476 axum::http::StatusCode::OK, 506 477 Json(CreateAccountOutput { 507 478 handle: handle.clone().into(), 508 - did: Did::new_unchecked(&did), 479 + did: unsafe { Did::new_unchecked(&did) }, 509 480 did_doc: state.did_resolver.resolve_did_document(&did).await, 510 481 access_jwt: access_meta.token, 511 482 refresh_jwt: refresh_meta.token, ··· 529 500 } 530 501 } 531 502 532 - let handle_typed = Handle::new_unchecked(&handle); 503 + let handle_typed = unsafe { Handle::new_unchecked(&handle) }; 533 504 let handle_available = match state 534 505 .user_repo 535 506 .check_handle_available_for_new_account(&handle_typed) ··· 613 584 } 614 585 }; 615 586 let rev = Tid::now(LimitedU32::MIN); 616 - let did_for_commit = Did::new_unchecked(&did); 587 + let did_for_commit = unsafe { Did::new_unchecked(&did) }; 617 588 let (commit_bytes, _sig) = 618 589 match create_signed_commit(&did_for_commit, mst_root, rev.as_ref(), None, &signing_key) { 619 590 Ok(result) => result, ··· 649 620 }; 650 621 651 622 let create_input = tranquil_db_traits::CreatePasswordAccountInput { 652 - handle: Handle::new_unchecked(&handle), 623 + handle: unsafe { Handle::new_unchecked(&handle) }, 653 624 email: email.clone(), 654 - did: Did::new_unchecked(&did), 625 + did: unsafe { Did::new_unchecked(&did) }, 655 626 password_hash, 656 627 preferred_comms_channel, 657 628 discord_id: input ··· 701 672 }; 702 673 let user_id = create_result.user_id; 703 674 if !is_migration && !is_did_web_byod { 704 - let did_typed = Did::new_unchecked(&did); 705 - let handle_typed = Handle::new_unchecked(&handle); 675 + let did_typed = unsafe { Did::new_unchecked(&did) }; 676 + let handle_typed = unsafe { Handle::new_unchecked(&handle) }; 706 677 if let Err(e) = crate::api::repo::record::sequence_identity_event( 707 678 &state, 708 679 &did_typed, ··· 712 683 { 713 684 warn!("Failed to sequence identity event for {}: {}", did, e); 714 685 } 715 - if let Err(e) = 716 - crate::api::repo::record::sequence_account_event(&state, &did_typed, true, None).await 686 + if let Err(e) = crate::api::repo::record::sequence_account_event( 687 + &state, 688 + &did_typed, 689 + tranquil_db_traits::AccountStatus::Active, 690 + ) 691 + .await 717 692 { 718 693 warn!("Failed to sequence account event for {}: {}", did, e); 719 694 } ··· 742 717 "$type": "app.bsky.actor.profile", 743 718 "displayName": input.handle 744 719 }); 745 - let profile_collection = Nsid::new_unchecked("app.bsky.actor.profile"); 746 - let profile_rkey = Rkey::new_unchecked("self"); 720 + let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") }; 721 + let profile_rkey = unsafe { Rkey::new_unchecked("self") }; 747 722 if let Err(e) = crate::api::repo::record::create_record_internal( 748 723 &state, 749 724 &did_typed, ··· 756 731 warn!("Failed to create default profile for {}: {}", did, e); 757 732 } 758 733 } 759 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 734 + let hostname = pds_hostname(); 760 735 if !is_migration { 761 736 if let Some(ref recipient) = verification_recipient { 762 737 let verification_token = crate::auth::verification_token::generate_signup_token( ··· 772 747 verification_channel, 773 748 recipient, 774 749 &formatted_token, 775 - &hostname, 750 + hostname, 776 751 ) 777 752 .await 778 753 { ··· 791 766 user_id, 792 767 user_email, 793 768 &formatted_token, 794 - &hostname, 769 + hostname, 795 770 ) 796 771 .await 797 772 { ··· 816 791 } 817 792 }; 818 793 let session_data = tranquil_db_traits::SessionTokenCreate { 819 - did: Did::new_unchecked(&did), 794 + did: unsafe { Did::new_unchecked(&did) }, 820 795 access_jti: access_meta.jti.clone(), 821 796 refresh_jti: refresh_meta.jti.clone(), 822 797 access_expires_at: access_meta.expires_at, 823 798 refresh_expires_at: refresh_meta.expires_at, 824 - legacy_login: false, 799 + login_type: tranquil_db_traits::LoginType::Modern, 825 800 mfa_verified: false, 826 801 scope: None, 827 802 controller_did: None, ··· 845 820 StatusCode::OK, 846 821 Json(CreateAccountOutput { 847 822 handle: handle.clone().into(), 848 - did: Did::new_unchecked(&did), 823 + did: unsafe { Did::new_unchecked(&did) }, 849 824 did_doc, 850 825 access_jwt: access_meta.token, 851 826 refresh_jwt: refresh_meta.token,
+27 -31
crates/tranquil-pds/src/api/identity/did.rs
··· 1 1 use crate::api::{ApiError, DidResponse, EmptyResponse}; 2 2 use crate::auth::{Auth, NotTakendown}; 3 3 use crate::plc::signing_key_to_did_key; 4 + use crate::rate_limit::{ 5 + HandleUpdateDailyLimit, HandleUpdateLimit, check_user_rate_limit_with_message, 6 + }; 4 7 use crate::state::AppState; 5 8 use crate::types::Handle; 9 + use crate::util::{get_header_str, pds_hostname, pds_hostname_without_port}; 6 10 use axum::{ 7 11 Json, 8 12 extract::{Path, Query, State}, ··· 101 105 } 102 106 103 107 pub async fn well_known_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 104 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 105 - let host_header = headers 106 - .get("host") 107 - .and_then(|h| h.to_str().ok()) 108 - .unwrap_or(&hostname); 108 + let hostname = pds_hostname(); 109 + let hostname_without_port = pds_hostname_without_port(); 110 + let host_header = get_header_str(&headers, "host").unwrap_or(hostname); 109 111 let host_without_port = host_header.split(':').next().unwrap_or(host_header); 110 - let hostname_without_port = hostname.split(':').next().unwrap_or(&hostname); 111 112 if host_without_port != hostname_without_port 112 113 && host_without_port.ends_with(&format!(".{}", hostname_without_port)) 113 114 { 114 115 let handle = host_without_port 115 116 .strip_suffix(&format!(".{}", hostname_without_port)) 116 117 .unwrap_or(host_without_port); 117 - return serve_subdomain_did_doc(&state, handle, &hostname).await; 118 + return serve_subdomain_did_doc(&state, handle, hostname).await; 118 119 } 119 120 let did = if hostname.contains(':') { 120 121 format!("did:web:{}", hostname.replace(':', "%3A")) ··· 257 258 } 258 259 259 260 pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 260 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 261 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 261 + let hostname = pds_hostname(); 262 + let hostname_for_handles = pds_hostname_without_port(); 262 263 let current_handle = format!("{}.{}", handle, hostname_for_handles); 263 264 let current_handle_typed: Handle = match current_handle.parse() { 264 265 Ok(h) => h, ··· 531 532 ApiError::AuthenticationFailed(Some("OAuth tokens cannot get DID credentials".into())) 532 533 })?; 533 534 534 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 535 + let hostname = pds_hostname(); 535 536 let pds_endpoint = format!("https://{}", hostname); 536 537 let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes) 537 538 .map_err(|_| ApiError::InternalError(None))?; ··· 585 586 return Ok(e); 586 587 } 587 588 let did = auth.did.clone(); 588 - if !state 589 - .check_rate_limit(crate::state::RateLimitKind::HandleUpdate, &did) 590 - .await 591 - { 592 - return Err(ApiError::RateLimitExceeded(Some( 593 - "Too many handle updates. Try again later.".into(), 594 - ))); 595 - } 596 - if !state 597 - .check_rate_limit(crate::state::RateLimitKind::HandleUpdateDaily, &did) 598 - .await 599 - { 600 - return Err(ApiError::RateLimitExceeded(Some( 601 - "Daily handle update limit exceeded.".into(), 602 - ))); 603 - } 589 + let _rate_limit = check_user_rate_limit_with_message::<HandleUpdateLimit>( 590 + &state, 591 + &did, 592 + "Too many handle updates. Try again later.", 593 + ) 594 + .await?; 595 + let _daily_rate_limit = check_user_rate_limit_with_message::<HandleUpdateDailyLimit>( 596 + &state, 597 + &did, 598 + "Daily handle update limit exceeded.", 599 + ) 600 + .await?; 604 601 let user_row = state 605 602 .user_repo 606 603 .get_id_and_handle_by_did(&did) ··· 639 636 "Inappropriate language in handle".into(), 640 637 ))); 641 638 } 642 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 643 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 639 + let hostname_for_handles = pds_hostname_without_port(); 644 640 let suffix = format!(".{}", hostname_for_handles); 645 641 let is_service_domain = 646 642 crate::handle::is_service_domain_handle(&new_handle, hostname_for_handles); ··· 656 652 format!("{}.{}", new_handle, hostname_for_handles) 657 653 }; 658 654 if full_handle == current_handle { 659 - let handle_typed = Handle::new_unchecked(&full_handle); 655 + let handle_typed = unsafe { Handle::new_unchecked(&full_handle) }; 660 656 if let Err(e) = 661 657 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)) 662 658 .await ··· 679 675 full_handle 680 676 } else { 681 677 if new_handle == current_handle { 682 - let handle_typed = Handle::new_unchecked(&new_handle); 678 + let handle_typed = unsafe { Handle::new_unchecked(&new_handle) }; 683 679 if let Err(e) = 684 680 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)) 685 681 .await ··· 772 768 } 773 769 774 770 pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 775 - let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 771 + let host = match crate::util::get_header_str(&headers, "host") { 776 772 Some(h) => h, 777 773 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 778 774 };
+7 -12
crates/tranquil-pds/src/api/identity/plc/request.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::auth::{Auth, Permissive}; 4 4 use crate::state::AppState; 5 + use crate::util::pds_hostname; 5 6 use axum::{ 6 7 extract::State, 7 8 response::{IntoResponse, Response}, 8 9 }; 9 10 use chrono::{Duration, Utc}; 10 - use tracing::{error, info, warn}; 11 + use tracing::{info, warn}; 11 12 12 13 fn generate_plc_token() -> String { 13 14 crate::util::generate_token_code() ··· 28 29 .user_repo 29 30 .get_id_by_did(&auth.did) 30 31 .await 31 - .map_err(|e| { 32 - error!("DB error: {:?}", e); 33 - ApiError::InternalError(None) 34 - })? 32 + .log_db_err("fetching user id")? 35 33 .ok_or(ApiError::AccountNotFound)?; 36 34 37 35 let _ = state.infra_repo.delete_plc_tokens_for_user(user_id).await; ··· 41 39 .infra_repo 42 40 .insert_plc_token(user_id, &plc_token, expires_at) 43 41 .await 44 - .map_err(|e| { 45 - error!("Failed to create PLC token: {:?}", e); 46 - ApiError::InternalError(None) 47 - })?; 42 + .log_db_err("creating PLC token")?; 48 43 49 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 44 + let hostname = pds_hostname(); 50 45 if let Err(e) = crate::comms::comms_repo::enqueue_plc_operation( 51 46 state.user_repo.as_ref(), 52 47 state.infra_repo.as_ref(), 53 48 user_id, 54 49 &plc_token, 55 - &hostname, 50 + hostname, 56 51 ) 57 52 .await 58 53 {
+4 -12
crates/tranquil-pds/src/api/identity/plc/sign.rs
··· 1 1 use crate::api::ApiError; 2 + use crate::api::error::DbResultExt; 2 3 use crate::auth::{Auth, Permissive}; 3 4 use crate::circuit_breaker::with_circuit_breaker; 4 5 use crate::plc::{PlcClient, PlcError, PlcService, create_update_op, sign_operation}; ··· 64 65 .user_repo 65 66 .get_id_by_did(did) 66 67 .await 67 - .map_err(|e| { 68 - error!("DB error: {:?}", e); 69 - ApiError::InternalError(None) 70 - })? 68 + .log_db_err("fetching user id")? 71 69 .ok_or(ApiError::AccountNotFound)?; 72 70 73 71 let token_expiry = state 74 72 .infra_repo 75 73 .get_plc_token_expiry(user_id, token) 76 74 .await 77 - .map_err(|e| { 78 - error!("DB error: {:?}", e); 79 - ApiError::InternalError(None) 80 - })? 75 + .log_db_err("fetching PLC token expiry")? 81 76 .ok_or_else(|| ApiError::InvalidToken(Some("Invalid or expired token".into())))?; 82 77 83 78 if Utc::now() > token_expiry { ··· 88 83 .user_repo 89 84 .get_user_key_by_id(user_id) 90 85 .await 91 - .map_err(|e| { 92 - error!("DB error: {:?}", e); 93 - ApiError::InternalError(None) 94 - })? 86 + .log_db_err("fetching user key")? 95 87 .ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?; 96 88 97 89 let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
+5 -9
crates/tranquil-pds/src/api/identity/plc/submit.rs
··· 1 + use crate::api::error::DbResultExt; 1 2 use crate::api::{ApiError, EmptyResponse}; 2 3 use crate::auth::{Auth, Permissive}; 3 4 use crate::circuit_breaker::with_circuit_breaker; 4 5 use crate::plc::{PlcClient, signing_key_to_did_key, validate_plc_operation}; 5 6 use crate::state::AppState; 7 + use crate::util::pds_hostname; 6 8 use axum::{ 7 9 Json, 8 10 extract::State, ··· 40 42 .map_err(|e| ApiError::InvalidRequest(format!("Invalid operation: {}", e)))?; 41 43 42 44 let op = &input.operation; 43 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 45 + let hostname = pds_hostname(); 44 46 let public_url = format!("https://{}", hostname); 45 47 let user = state 46 48 .user_repo 47 49 .get_id_and_handle_by_did(did) 48 50 .await 49 - .map_err(|e| { 50 - error!("DB error: {:?}", e); 51 - ApiError::InternalError(None) 52 - })? 51 + .log_db_err("fetching user")? 53 52 .ok_or(ApiError::AccountNotFound)?; 54 53 55 54 let key_row = state 56 55 .user_repo 57 56 .get_user_key_by_id(user.id) 58 57 .await 59 - .map_err(|e| { 60 - error!("DB error: {:?}", e); 61 - ApiError::InternalError(None) 62 - })? 58 + .log_db_err("fetching user key")? 63 59 .ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?; 64 60 65 61 let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
+33 -27
crates/tranquil-pds/src/api/notification_prefs.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::auth::{Active, Auth}; 3 3 use crate::state::AppState; 4 + use crate::util::pds_hostname; 4 5 use axum::{ 5 6 Json, 6 7 extract::State, ··· 9 10 use serde::{Deserialize, Serialize}; 10 11 use serde_json::json; 11 12 use tracing::info; 13 + use tranquil_db_traits::{CommsChannel, CommsStatus, CommsType}; 12 14 13 15 #[derive(Serialize)] 14 16 #[serde(rename_all = "camelCase")] 15 17 pub struct NotificationPrefsResponse { 16 - pub preferred_channel: String, 18 + pub preferred_channel: CommsChannel, 17 19 pub email: String, 18 20 pub discord_id: Option<String>, 19 21 pub discord_verified: bool, ··· 50 52 #[serde(rename_all = "camelCase")] 51 53 pub struct NotificationHistoryEntry { 52 54 pub created_at: String, 53 - pub channel: String, 54 - pub comms_type: String, 55 - pub status: String, 55 + pub channel: CommsChannel, 56 + pub comms_type: CommsType, 57 + pub status: CommsStatus, 56 58 pub subject: Option<String>, 57 59 pub body: String, 58 60 } ··· 81 83 .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 82 84 83 85 let sensitive_types = [ 84 - "email_verification", 85 - "password_reset", 86 - "email_update", 87 - "two_factor_code", 88 - "passkey_recovery", 89 - "migration_verification", 90 - "plc_operation", 91 - "channel_verification", 92 - "signup_verification", 86 + CommsType::EmailVerification, 87 + CommsType::PasswordReset, 88 + CommsType::EmailUpdate, 89 + CommsType::TwoFactorCode, 90 + CommsType::PasskeyRecovery, 91 + CommsType::MigrationVerification, 92 + CommsType::PlcOperation, 93 + CommsType::ChannelVerification, 93 94 ]; 94 95 95 96 let notifications = rows 96 97 .iter() 97 98 .map(|row| { 98 - let body = if sensitive_types.contains(&row.comms_type.as_str()) { 99 + let body = if sensitive_types.contains(&row.comms_type) { 99 100 "[Code redacted for security]".to_string() 100 101 } else { 101 102 row.body.clone() 102 103 }; 103 104 NotificationHistoryEntry { 104 105 created_at: row.created_at.to_rfc3339(), 105 - channel: row.channel.clone(), 106 - comms_type: row.comms_type.clone(), 107 - status: row.status.clone(), 106 + channel: row.channel, 107 + comms_type: row.comms_type, 108 + status: row.status, 108 109 subject: row.subject.clone(), 109 110 body, 110 111 } ··· 145 146 let formatted_token = crate::auth::verification_token::format_token_for_display(&token); 146 147 147 148 if channel == "email" { 148 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 149 + let hostname = pds_hostname(); 149 150 let handle_str = handle.unwrap_or("user"); 150 151 crate::comms::comms_repo::enqueue_email_update( 151 152 state.infra_repo.as_ref(), ··· 153 154 identifier, 154 155 handle_str, 155 156 &formatted_token, 156 - &hostname, 157 + hostname, 157 158 ) 158 159 .await 159 160 .map_err(|e| format!("Failed to enqueue email notification: {}", e))?; ··· 200 201 201 202 let mut verification_required: Vec<String> = Vec::new(); 202 203 203 - if let Some(ref channel) = input.preferred_channel { 204 - let valid_channels = ["email", "discord", "telegram", "signal"]; 205 - if !valid_channels.contains(&channel.as_str()) { 206 - return Err(ApiError::InvalidRequest( 207 - "Invalid channel. Must be one of: email, discord, telegram, signal".into(), 208 - )); 209 - } 204 + if let Some(ref channel_str) = input.preferred_channel { 205 + let channel = match channel_str.as_str() { 206 + "email" => CommsChannel::Email, 207 + "discord" => CommsChannel::Discord, 208 + "telegram" => CommsChannel::Telegram, 209 + "signal" => CommsChannel::Signal, 210 + _ => { 211 + return Err(ApiError::InvalidRequest( 212 + "Invalid channel. Must be one of: email, discord, telegram, signal".into(), 213 + )); 214 + } 215 + }; 210 216 state 211 217 .user_repo 212 218 .update_preferred_comms_channel(&auth.did, channel) 213 219 .await 214 220 .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 215 - info!(did = %auth.did, channel = %channel, "Updated preferred notification channel"); 221 + info!(did = %auth.did, channel = ?channel, "Updated preferred notification channel"); 216 222 } 217 223 218 224 if let Some(ref new_email) = input.email {
+4 -7
crates/tranquil-pds/src/api/proxy.rs
··· 3 3 use crate::api::error::ApiError; 4 4 use crate::api::proxy_client::proxy_client; 5 5 use crate::state::AppState; 6 + use crate::util::get_header_str; 6 7 use axum::{ 7 8 body::Bytes, 8 9 extract::{RawQuery, Request, State}, ··· 191 192 .into_response(); 192 193 } 193 194 194 - let Some(proxy_header) = headers 195 - .get("atproto-proxy") 196 - .and_then(|h| h.to_str().ok()) 197 - .map(String::from) 198 - else { 195 + let Some(proxy_header) = get_header_str(&headers, "atproto-proxy").map(String::from) else { 199 196 return ApiError::InvalidRequest("Missing required atproto-proxy header".into()) 200 197 .into_response(); 201 198 }; ··· 217 214 218 215 let mut auth_header_val = headers.get("Authorization").cloned(); 219 216 if let Some(extracted) = crate::auth::extract_auth_token_from_header( 220 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 217 + crate::util::get_header_str(&headers, "Authorization"), 221 218 ) { 222 219 let token = extracted.token; 223 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 220 + let dpop_proof = crate::util::get_header_str(&headers, "DPoP"); 224 221 let http_uri = crate::util::build_full_url(&uri.to_string()); 225 222 226 223 match crate::auth::validate_token_with_dpop(
+18 -28
crates/tranquil-pds/src/api/repo/blob.rs
··· 1 - use crate::api::error::ApiError; 2 - use crate::auth::{Auth, AuthAny, NotTakendown, Permissive}; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 + use crate::auth::{Auth, AuthAny, NotTakendown, Permissive, VerifyScope}; 3 3 use crate::delegation::DelegationActionType; 4 4 use crate::state::AppState; 5 5 use crate::types::{CidLink, Did}; 6 - use crate::util::get_max_blob_size; 6 + use crate::util::{get_header_str, get_max_blob_size}; 7 7 use axum::body::Body; 8 8 use axum::{ 9 9 Json, ··· 56 56 if user.status.is_takendown() { 57 57 return Err(ApiError::AccountTakedown); 58 58 } 59 - let mime_type_for_check = headers 60 - .get("content-type") 61 - .and_then(|h| h.to_str().ok()) 62 - .unwrap_or("application/octet-stream"); 63 - if let Err(e) = crate::auth::scope_check::check_blob_scope( 64 - user.is_oauth(), 65 - user.scope.as_deref(), 66 - mime_type_for_check, 67 - ) { 68 - return Ok(e); 69 - } 70 - (user.did.clone(), user.controller_did.clone()) 59 + let mime_type_for_check = 60 + get_header_str(&headers, "content-type").unwrap_or("application/octet-stream"); 61 + let scope_proof = match user.verify_blob_upload(mime_type_for_check) { 62 + Ok(proof) => proof, 63 + Err(e) => return Ok(e.into_response()), 64 + }; 65 + ( 66 + scope_proof.principal_did().into_did(), 67 + scope_proof.controller_did().map(|c| c.into_did()), 68 + ) 71 69 } 72 70 }; 73 71 ··· 80 78 return Err(ApiError::Forbidden); 81 79 } 82 80 83 - let client_mime_hint = headers 84 - .get("content-type") 85 - .and_then(|h| h.to_str().ok()) 86 - .unwrap_or("application/octet-stream"); 81 + let client_mime_hint = 82 + get_header_str(&headers, "content-type").unwrap_or("application/octet-stream"); 87 83 88 84 let user_id = state 89 85 .user_repo ··· 140 136 }; 141 137 let cid = Cid::new_v1(0x55, multihash); 142 138 let cid_str = cid.to_string(); 143 - let cid_link: CidLink = CidLink::new_unchecked(&cid_str); 139 + let cid_link: CidLink = unsafe { CidLink::new_unchecked(&cid_str) }; 144 140 let storage_key = cid_str.clone(); 145 141 146 142 info!( ··· 232 228 .user_repo 233 229 .get_by_did(did) 234 230 .await 235 - .map_err(|e| { 236 - error!("DB error fetching user: {:?}", e); 237 - ApiError::InternalError(None) 238 - })? 231 + .log_db_err("fetching user")? 239 232 .ok_or(ApiError::InternalError(None))?; 240 233 241 234 let limit = params.limit.unwrap_or(500).clamp(1, 1000); ··· 244 237 .blob_repo 245 238 .list_missing_blobs(user.id, cursor, limit + 1) 246 239 .await 247 - .map_err(|e| { 248 - error!("DB error fetching missing blobs: {:?}", e); 249 - ApiError::InternalError(None) 250 - })?; 240 + .log_db_err("fetching missing blobs")?; 251 241 252 242 let has_more = missing.len() > limit as usize; 253 243 let blobs: Vec<RecordBlob> = missing
+8 -12
crates/tranquil-pds/src/api/repo/import.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::api::repo::record::create_signed_commit; 4 4 use crate::auth::{Auth, NotTakendown}; 5 5 use crate::state::AppState; ··· 49 49 .user_repo 50 50 .get_by_did(did) 51 51 .await 52 - .map_err(|e| { 53 - error!("DB error fetching user: {:?}", e); 54 - ApiError::InternalError(None) 55 - })? 52 + .log_db_err("fetching user")? 56 53 .ok_or(ApiError::AccountNotFound)?; 57 54 if user.takedown_ref.is_some() { 58 55 return Err(ApiError::AccountTakedown); ··· 207 204 let record_uri = 208 205 AtUri::from_parts(did.as_str(), &record.collection, &record.rkey); 209 206 record.blob_refs.iter().map(move |blob_ref| { 210 - ( 211 - record_uri.clone(), 212 - CidLink::new_unchecked(blob_ref.cid.clone()), 213 - ) 207 + (record_uri.clone(), unsafe { 208 + CidLink::new_unchecked(blob_ref.cid.clone()) 209 + }) 214 210 }) 215 211 }) 216 212 .collect(); ··· 275 271 error!("Failed to store new commit block: {:?}", e); 276 272 ApiError::InternalError(None) 277 273 })?; 278 - let new_root_cid_link = CidLink::new_unchecked(new_root_cid.to_string()); 274 + let new_root_cid_link = unsafe { CidLink::new_unchecked(new_root_cid.to_string()) }; 279 275 state 280 276 .repo_repo 281 277 .update_repo_root(user_id, &new_root_cid_link, &new_rev_str) ··· 368 364 ) -> Result<(), tranquil_db::DbError> { 369 365 let data = tranquil_db::CommitEventData { 370 366 did: did.clone(), 371 - event_type: "commit".to_string(), 372 - commit_cid: Some(CidLink::new_unchecked(commit_cid)), 367 + event_type: tranquil_db::RepoEventType::Commit, 368 + commit_cid: Some(unsafe { CidLink::new_unchecked(commit_cid) }), 373 369 prev_cid: None, 374 370 ops: Some(serde_json::json!([])), 375 371 blobs: Some(vec![]),
+2 -2
crates/tranquil-pds/src/api/repo/meta.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::state::AppState; 3 3 use crate::types::AtIdentifier; 4 + use crate::util::pds_hostname_without_port; 4 5 use axum::{ 5 6 Json, 6 7 extract::{Query, State}, ··· 18 19 State(state): State<AppState>, 19 20 Query(input): Query<DescribeRepoInput>, 20 21 ) -> Response { 21 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 22 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 22 + let hostname_for_handles = pds_hostname_without_port(); 23 23 let user_row = if input.repo.is_did() { 24 24 let did: crate::types::Did = match input.repo.as_str().parse() { 25 25 Ok(d) => d,
+70 -131
crates/tranquil-pds/src/api/repo/record/batch.rs
··· 1 1 use super::validation::validate_record_with_status; 2 + use super::validation_mode::{ValidationMode, deserialize_validation_mode}; 2 3 use crate::api::error::ApiError; 3 4 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 4 - use crate::auth::{Active, Auth}; 5 + use crate::auth::{ 6 + Active, Auth, WriteOpKind, require_not_migrated, require_verified_or_delegated, 7 + verify_batch_write_scopes, 8 + }; 9 + use crate::cid_types::CommitCid; 5 10 use crate::delegation::DelegationActionType; 6 11 use crate::repo::tracking::TrackingBlockStore; 7 12 use crate::state::AppState; ··· 34 39 write: &WriteOp, 35 40 acc: WriteAccumulator, 36 41 did: &Did, 37 - validate: Option<bool>, 42 + validate: ValidationMode, 38 43 tracking_store: &TrackingBlockStore, 39 44 ) -> Result<WriteAccumulator, Response> { 40 45 let WriteAccumulator { ··· 51 56 rkey, 52 57 value, 53 58 } => { 54 - let validation_status = match validate { 55 - Some(false) => None, 56 - _ => { 57 - let require_lexicon = validate == Some(true); 58 - match validate_record_with_status( 59 - value, 60 - collection, 61 - rkey.as_ref(), 62 - require_lexicon, 63 - ) { 64 - Ok(status) => Some(status), 65 - Err(err_response) => return Err(*err_response), 66 - } 59 + let validation_status = if validate.should_skip() { 60 + None 61 + } else { 62 + match validate_record_with_status( 63 + value, 64 + collection, 65 + rkey.as_ref(), 66 + validate.requires_lexicon(), 67 + ) { 68 + Ok(status) => Some(status), 69 + Err(err_response) => return Err(*err_response), 67 70 } 68 71 }; 69 72 all_blob_cids.extend(extract_blob_cids(value)); ··· 104 107 rkey, 105 108 value, 106 109 } => { 107 - let validation_status = match validate { 108 - Some(false) => None, 109 - _ => { 110 - let require_lexicon = validate == Some(true); 111 - match validate_record_with_status( 112 - value, 113 - collection, 114 - Some(rkey), 115 - require_lexicon, 116 - ) { 117 - Ok(status) => Some(status), 118 - Err(err_response) => return Err(*err_response), 119 - } 110 + let validation_status = if validate.should_skip() { 111 + None 112 + } else { 113 + match validate_record_with_status( 114 + value, 115 + collection, 116 + Some(rkey), 117 + validate.requires_lexicon(), 118 + ) { 119 + Ok(status) => Some(status), 120 + Err(err_response) => return Err(*err_response), 120 121 } 121 122 }; 122 123 all_blob_cids.extend(extract_blob_cids(value)); ··· 181 182 writes: &[WriteOp], 182 183 initial_mst: Mst<TrackingBlockStore>, 183 184 did: &Did, 184 - validate: Option<bool>, 185 + validate: ValidationMode, 185 186 tracking_store: &TrackingBlockStore, 186 187 ) -> Result<WriteAccumulator, Response> { 187 188 use futures::stream::{self, TryStreamExt}; ··· 222 223 #[serde(rename_all = "camelCase")] 223 224 pub struct ApplyWritesInput { 224 225 pub repo: AtIdentifier, 225 - pub validate: Option<bool>, 226 + #[serde(default, deserialize_with = "deserialize_validation_mode")] 227 + pub validate: ValidationMode, 226 228 pub writes: Vec<WriteOp>, 227 229 pub swap_commit: Option<String>, 228 230 } ··· 270 272 input.repo, 271 273 input.writes.len() 272 274 ); 273 - let did = auth.did.clone(); 274 - let is_oauth = auth.is_oauth(); 275 - let scope = auth.scope.clone(); 276 - let controller_did = auth.controller_did.clone(); 277 - if input.repo.as_str() != did { 278 - return Err(ApiError::InvalidRepo( 279 - "Repo does not match authenticated user".into(), 280 - )); 281 - } 282 - if state 283 - .user_repo 284 - .is_account_migrated(&did) 285 - .await 286 - .unwrap_or(false) 287 - { 288 - return Err(ApiError::AccountMigrated); 289 - } 290 - let is_verified = state 291 - .user_repo 292 - .has_verified_comms_channel(&did) 293 - .await 294 - .unwrap_or(false); 295 - let is_delegated = state 296 - .delegation_repo 297 - .is_delegated_account(&did) 298 - .await 299 - .unwrap_or(false); 300 - if !is_verified && !is_delegated { 301 - return Err(ApiError::AccountNotVerified); 302 - } 275 + 303 276 if input.writes.is_empty() { 304 277 return Err(ApiError::InvalidRequest("writes array is empty".into())); 305 278 } ··· 310 283 ))); 311 284 } 312 285 313 - let has_custom_scope = scope 314 - .as_ref() 315 - .map(|s| s != "com.atproto.access") 316 - .unwrap_or(false); 317 - if is_oauth || has_custom_scope { 318 - use std::collections::HashSet; 319 - let create_collections: HashSet<&Nsid> = input 320 - .writes 321 - .iter() 322 - .filter_map(|w| { 323 - if let WriteOp::Create { collection, .. } = w { 324 - Some(collection) 325 - } else { 326 - None 327 - } 328 - }) 329 - .collect(); 330 - let update_collections: HashSet<&Nsid> = input 331 - .writes 332 - .iter() 333 - .filter_map(|w| { 334 - if let WriteOp::Update { collection, .. } = w { 335 - Some(collection) 336 - } else { 337 - None 338 - } 339 - }) 340 - .collect(); 341 - let delete_collections: HashSet<&Nsid> = input 342 - .writes 343 - .iter() 344 - .filter_map(|w| { 345 - if let WriteOp::Delete { collection, .. } = w { 346 - Some(collection) 347 - } else { 348 - None 349 - } 350 - }) 351 - .collect(); 286 + let batch_proof = match verify_batch_write_scopes( 287 + &auth, 288 + &auth, 289 + &input.writes, 290 + |w| match w { 291 + WriteOp::Create { collection, .. } => collection.as_str(), 292 + WriteOp::Update { collection, .. } => collection.as_str(), 293 + WriteOp::Delete { collection, .. } => collection.as_str(), 294 + }, 295 + |w| match w { 296 + WriteOp::Create { .. } => WriteOpKind::Create, 297 + WriteOp::Update { .. } => WriteOpKind::Update, 298 + WriteOp::Delete { .. } => WriteOpKind::Delete, 299 + }, 300 + ) { 301 + Ok(proof) => proof, 302 + Err(e) => return Ok(e.into_response()), 303 + }; 352 304 353 - let scope_checks = create_collections 354 - .iter() 355 - .map(|c| (crate::oauth::RepoAction::Create, c)) 356 - .chain( 357 - update_collections 358 - .iter() 359 - .map(|c| (crate::oauth::RepoAction::Update, c)), 360 - ) 361 - .chain( 362 - delete_collections 363 - .iter() 364 - .map(|c| (crate::oauth::RepoAction::Delete, c)), 365 - ); 305 + let principal_did = batch_proof.principal_did(); 306 + let controller_did = batch_proof.controller_did().map(|c| c.into_did()); 366 307 367 - if let Some(err) = scope_checks 368 - .filter_map(|(action, collection)| { 369 - crate::auth::scope_check::check_repo_scope( 370 - is_oauth, 371 - scope.as_deref(), 372 - action, 373 - collection, 374 - ) 375 - .err() 376 - }) 377 - .next() 378 - { 379 - return Ok(err); 380 - } 308 + if input.repo.as_str() != principal_did.as_str() { 309 + return Err(ApiError::InvalidRepo( 310 + "Repo does not match authenticated user".into(), 311 + )); 312 + } 313 + 314 + let did = principal_did.into_did(); 315 + if let Err(e) = require_not_migrated(&state, &did).await { 316 + return Ok(e); 317 + } 318 + if let Err(e) = require_verified_or_delegated(&state, batch_proof.user()).await { 319 + return Ok(e); 381 320 } 382 321 383 322 let user_id: uuid::Uuid = state ··· 394 333 .ok() 395 334 .flatten() 396 335 .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?; 397 - let current_root_cid = Cid::from_str(&root_cid_str) 336 + let current_root_cid = CommitCid::from_str(&root_cid_str) 398 337 .map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))?; 399 338 if let Some(swap_commit) = &input.swap_commit 400 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 339 + && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) 401 340 { 402 341 return Err(ApiError::InvalidSwap(Some("Repo has been modified".into()))); 403 342 } 404 343 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 405 344 let commit_bytes = tracking_store 406 - .get(&current_root_cid) 345 + .get(current_root_cid.as_cid()) 407 346 .await 408 347 .ok() 409 348 .flatten() ··· 471 410 } => Some(*cid), 472 411 _ => None, 473 412 }); 474 - let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 413 + let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid()) 475 414 .chain( 476 415 old_mst_blocks 477 416 .keys() ··· 487 426 CommitParams { 488 427 did: &did, 489 428 user_id, 490 - current_root_cid: Some(current_root_cid), 429 + current_root_cid: Some(current_root_cid.into_cid()), 491 430 prev_data_cid: Some(commit.data), 492 431 new_mst_root, 493 432 ops,
+12 -15
crates/tranquil-pds/src/api/repo/record/delete.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 3 3 use crate::api::repo::record::write::{CommitInfo, prepare_repo_write}; 4 - use crate::auth::{Active, Auth}; 4 + use crate::auth::{Active, Auth, VerifyScope}; 5 + use crate::cid_types::CommitCid; 5 6 use crate::delegation::DelegationActionType; 6 7 use crate::repo::tracking::TrackingBlockStore; 7 8 use crate::state::AppState; ··· 43 44 auth: Auth<Active>, 44 45 Json(input): Json<DeleteRecordInput>, 45 46 ) -> Result<Response, crate::api::error::ApiError> { 46 - let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 47 + let scope_proof = match auth.verify_repo_delete(&input.collection) { 48 + Ok(proof) => proof, 49 + Err(e) => return Ok(e.into_response()), 50 + }; 51 + 52 + let repo_auth = match prepare_repo_write(&state, &scope_proof, &input.repo).await { 47 53 Ok(res) => res, 48 54 Err(err_res) => return Ok(err_res), 49 55 }; 50 56 51 - if let Err(e) = crate::auth::scope_check::check_repo_scope( 52 - repo_auth.is_oauth, 53 - repo_auth.scope.as_deref(), 54 - crate::oauth::RepoAction::Delete, 55 - &input.collection, 56 - ) { 57 - return Ok(e); 58 - } 59 - 60 57 let did = repo_auth.did; 61 58 let user_id = repo_auth.user_id; 62 59 let current_root_cid = repo_auth.current_root_cid; 63 60 let controller_did = repo_auth.controller_did; 64 61 65 62 if let Some(swap_commit) = &input.swap_commit 66 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 63 + && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) 67 64 { 68 65 return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 69 66 } 70 67 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 71 - let commit_bytes = match tracking_store.get(&current_root_cid).await { 68 + let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await { 72 69 Ok(Some(b)) => b, 73 70 _ => { 74 71 return Ok( ··· 159 156 .into_iter() 160 157 .collect(); 161 158 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 162 - let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 159 + let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid()) 163 160 .chain( 164 161 old_mst_blocks 165 162 .keys() ··· 173 170 CommitParams { 174 171 did: &did, 175 172 user_id, 176 - current_root_cid: Some(current_root_cid), 173 + current_root_cid: Some(current_root_cid.into_cid()), 177 174 prev_data_cid: Some(commit.data), 178 175 new_mst_root, 179 176 ops: vec![op],
+5
crates/tranquil-pds/src/api/repo/record/mod.rs
··· 1 1 pub mod batch; 2 2 pub mod delete; 3 + pub mod pagination; 3 4 pub mod read; 4 5 pub mod utils; 5 6 pub mod validation; 7 + pub mod validation_mode; 6 8 pub mod write; 7 9 10 + pub use pagination::PaginationDirection; 11 + pub use validation_mode::ValidationMode; 12 + 8 13 pub use batch::apply_writes; 9 14 pub use delete::{DeleteRecordInput, delete_record, delete_record_internal}; 10 15 pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records};
+31
crates/tranquil-pds/src/api/repo/record/pagination.rs
··· 1 + use serde::{Deserialize, Deserializer}; 2 + 3 + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] 4 + pub enum PaginationDirection { 5 + #[default] 6 + Forward, 7 + Backward, 8 + } 9 + 10 + impl PaginationDirection { 11 + pub fn from_optional_bool(value: Option<bool>) -> Self { 12 + match value { 13 + Some(true) => Self::Backward, 14 + Some(false) | None => Self::Forward, 15 + } 16 + } 17 + 18 + pub fn is_reverse(&self) -> bool { 19 + matches!(self, Self::Backward) 20 + } 21 + } 22 + 23 + pub fn deserialize_pagination_direction<'de, D>( 24 + deserializer: D, 25 + ) -> Result<PaginationDirection, D::Error> 26 + where 27 + D: Deserializer<'de>, 28 + { 29 + let opt: Option<bool> = Option::deserialize(deserializer)?; 30 + Ok(PaginationDirection::from_optional_bool(opt)) 31 + }
+7 -7
crates/tranquil-pds/src/api/repo/record/read.rs
··· 1 + use super::pagination::{PaginationDirection, deserialize_pagination_direction}; 1 2 use crate::api::error::ApiError; 2 3 use crate::state::AppState; 3 4 use crate::types::{AtIdentifier, Nsid, Rkey}; 5 + use crate::util::pds_hostname_without_port; 4 6 use axum::{ 5 7 Json, 6 8 extract::{Query, State}, ··· 58 60 _headers: HeaderMap, 59 61 Query(input): Query<GetRecordInput>, 60 62 ) -> Response { 61 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 62 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 63 + let hostname_for_handles = pds_hostname_without_port(); 63 64 let user_id_opt = if input.repo.is_did() { 64 65 let did: crate::types::Did = match input.repo.as_str().parse() { 65 66 Ok(d) => d, ··· 144 145 pub rkey_start: Option<Rkey>, 145 146 #[serde(rename = "rkeyEnd")] 146 147 pub rkey_end: Option<Rkey>, 147 - pub reverse: Option<bool>, 148 + #[serde(default, deserialize_with = "deserialize_pagination_direction")] 149 + pub reverse: PaginationDirection, 148 150 } 149 151 #[derive(Serialize)] 150 152 pub struct ListRecordsOutput { ··· 157 159 State(state): State<AppState>, 158 160 Query(input): Query<ListRecordsInput>, 159 161 ) -> Response { 160 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 161 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 162 + let hostname_for_handles = pds_hostname_without_port(); 162 163 let user_id_opt = if input.repo.is_did() { 163 164 let did: crate::types::Did = match input.repo.as_str().parse() { 164 165 Ok(d) => d, ··· 194 195 } 195 196 }; 196 197 let limit = input.limit.unwrap_or(50).clamp(1, 100); 197 - let reverse = input.reverse.unwrap_or(false); 198 198 let limit_i64 = limit as i64; 199 199 let cursor_rkey = input 200 200 .cursor ··· 207 207 &input.collection, 208 208 cursor_rkey.as_ref(), 209 209 limit_i64, 210 - reverse, 210 + input.reverse.is_reverse(), 211 211 input.rkey_start.as_ref(), 212 212 input.rkey_end.as_ref(), 213 213 )
+21 -19
crates/tranquil-pds/src/api/repo/record/utils.rs
··· 8 8 use k256::ecdsa::SigningKey; 9 9 use serde_json::{Value, json}; 10 10 use std::str::FromStr; 11 + use tranquil_db_traits::SequenceNumber; 11 12 use uuid::Uuid; 12 13 13 14 pub fn extract_blob_cids(record: &Value) -> Vec<String> { ··· 139 140 ) -> Result<CommitResult, String> { 140 141 use tranquil_db_traits::{ 141 142 ApplyCommitError, ApplyCommitInput, CommitEventData, RecordDelete, RecordUpsert, 143 + RepoEventType, 142 144 }; 143 145 144 146 let CommitParams { ··· 199 201 upserts.push(RecordUpsert { 200 202 collection: collection.clone(), 201 203 rkey: rkey.clone(), 202 - cid: crate::types::CidLink::new_unchecked(cid.to_string()), 204 + cid: unsafe { crate::types::CidLink::new_unchecked(cid.to_string()) }, 203 205 }); 204 206 } 205 207 RecordOp::Delete { ··· 263 265 264 266 let commit_event = CommitEventData { 265 267 did: did.clone(), 266 - event_type: "commit".to_string(), 267 - commit_cid: Some(crate::types::CidLink::new_unchecked( 268 - new_root_cid.to_string(), 269 - )), 270 - prev_cid: current_root_cid.map(|c| crate::types::CidLink::new_unchecked(c.to_string())), 268 + event_type: RepoEventType::Commit, 269 + commit_cid: Some(unsafe { crate::types::CidLink::new_unchecked(new_root_cid.to_string()) }), 270 + prev_cid: current_root_cid 271 + .map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }), 271 272 ops: Some(json!(ops_json)), 272 273 blobs: Some(blobs.to_vec()), 273 274 blocks_cids: Some(blocks_cids.to_vec()), 274 - prev_data_cid: prev_data_cid.map(|c| crate::types::CidLink::new_unchecked(c.to_string())), 275 + prev_data_cid: prev_data_cid 276 + .map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }), 275 277 rev: Some(rev_str.clone()), 276 278 }; 277 279 ··· 279 281 user_id, 280 282 did: did.clone(), 281 283 expected_root_cid: current_root_cid 282 - .map(|c| crate::types::CidLink::new_unchecked(c.to_string())), 283 - new_root_cid: crate::types::CidLink::new_unchecked(new_root_cid.to_string()), 284 + .map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }), 285 + new_root_cid: unsafe { crate::types::CidLink::new_unchecked(new_root_cid.to_string()) }, 284 286 new_rev: rev_str.clone(), 285 287 new_block_cids: all_block_cids, 286 288 obsolete_block_cids: obsolete_bytes, ··· 417 419 state: &AppState, 418 420 did: &Did, 419 421 handle: Option<&Handle>, 420 - ) -> Result<i64, String> { 422 + ) -> Result<SequenceNumber, String> { 421 423 state 422 424 .repo_repo 423 425 .insert_identity_event(did, handle) ··· 427 429 pub async fn sequence_account_event( 428 430 state: &AppState, 429 431 did: &Did, 430 - active: bool, 431 - status: Option<&str>, 432 - ) -> Result<i64, String> { 432 + status: tranquil_db_traits::AccountStatus, 433 + ) -> Result<SequenceNumber, String> { 433 434 state 434 435 .repo_repo 435 - .insert_account_event(did, active, status) 436 + .insert_account_event(did, status) 436 437 .await 437 438 .map_err(|e| format!("DB Error (account event): {}", e)) 438 439 } ··· 441 442 did: &Did, 442 443 commit_cid: &str, 443 444 rev: Option<&str>, 444 - ) -> Result<i64, String> { 445 - let cid_link = crate::types::CidLink::new_unchecked(commit_cid); 445 + ) -> Result<SequenceNumber, String> { 446 + let cid_link = unsafe { crate::types::CidLink::new_unchecked(commit_cid) }; 446 447 state 447 448 .repo_repo 448 449 .insert_sync_event(did, &cid_link, rev) ··· 456 457 commit_cid: &Cid, 457 458 mst_root_cid: &Cid, 458 459 rev: &str, 459 - ) -> Result<i64, String> { 460 - let commit_cid_link = crate::types::CidLink::new_unchecked(commit_cid.to_string()); 461 - let mst_root_cid_link = crate::types::CidLink::new_unchecked(mst_root_cid.to_string()); 460 + ) -> Result<SequenceNumber, String> { 461 + let commit_cid_link = unsafe { crate::types::CidLink::new_unchecked(commit_cid.to_string()) }; 462 + let mst_root_cid_link = 463 + unsafe { crate::types::CidLink::new_unchecked(mst_root_cid.to_string()) }; 462 464 state 463 465 .repo_repo 464 466 .insert_genesis_commit_event(did, &commit_cid_link, &mst_root_cid_link, rev)
+35
crates/tranquil-pds/src/api/repo/record/validation_mode.rs
··· 1 + use serde::{Deserialize, Deserializer}; 2 + 3 + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] 4 + pub enum ValidationMode { 5 + Skip, 6 + #[default] 7 + Infer, 8 + Strict, 9 + } 10 + 11 + impl ValidationMode { 12 + pub fn from_optional_bool(value: Option<bool>) -> Self { 13 + match value { 14 + Some(false) => Self::Skip, 15 + Some(true) => Self::Strict, 16 + None => Self::Infer, 17 + } 18 + } 19 + 20 + pub fn should_skip(&self) -> bool { 21 + matches!(self, Self::Skip) 22 + } 23 + 24 + pub fn requires_lexicon(&self) -> bool { 25 + matches!(self, Self::Strict) 26 + } 27 + } 28 + 29 + pub fn deserialize_validation_mode<'de, D>(deserializer: D) -> Result<ValidationMode, D::Error> 30 + where 31 + D: Deserializer<'de>, 32 + { 33 + let opt: Option<bool> = Option::deserialize(deserializer)?; 34 + Ok(ValidationMode::from_optional_bool(opt)) 35 + }
+51 -77
crates/tranquil-pds/src/api/repo/record/write.rs
··· 1 1 use super::validation::validate_record_with_status; 2 + use super::validation_mode::{ValidationMode, deserialize_validation_mode}; 2 3 use crate::api::error::ApiError; 3 4 use crate::api::repo::record::utils::{ 4 5 CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids, 5 6 }; 6 - use crate::auth::{Active, Auth}; 7 + use crate::auth::{ 8 + Active, Auth, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated, 9 + require_verified_or_delegated, 10 + }; 11 + use crate::cid_types::CommitCid; 7 12 use crate::delegation::DelegationActionType; 8 13 use crate::repo::tracking::TrackingBlockStore; 9 14 use crate::state::AppState; ··· 26 31 pub struct RepoWriteAuth { 27 32 pub did: Did, 28 33 pub user_id: Uuid, 29 - pub current_root_cid: Cid, 34 + pub current_root_cid: CommitCid, 30 35 pub is_oauth: bool, 31 36 pub scope: Option<String>, 32 37 pub controller_did: Option<Did>, 33 38 } 34 39 35 - pub async fn prepare_repo_write( 40 + pub async fn prepare_repo_write<A: RepoScopeAction>( 36 41 state: &AppState, 37 - auth_user: &crate::auth::AuthenticatedUser, 42 + scope_proof: &ScopeVerified<'_, A>, 38 43 repo: &AtIdentifier, 39 44 ) -> Result<RepoWriteAuth, Response> { 40 - if repo.as_str() != auth_user.did.as_str() { 45 + let user = scope_proof.user(); 46 + let principal_did = scope_proof.principal_did(); 47 + if repo.as_str() != principal_did.as_str() { 41 48 return Err( 42 49 ApiError::InvalidRepo("Repo does not match authenticated user".into()).into_response(), 43 50 ); 44 51 } 45 - if state 46 - .user_repo 47 - .is_account_migrated(&auth_user.did) 48 - .await 49 - .unwrap_or(false) 50 - { 51 - return Err(ApiError::AccountMigrated.into_response()); 52 - } 53 - let is_verified = state 54 - .user_repo 55 - .has_verified_comms_channel(&auth_user.did) 56 - .await 57 - .unwrap_or(false); 58 - let is_delegated = state 59 - .delegation_repo 60 - .is_delegated_account(&auth_user.did) 61 - .await 62 - .unwrap_or(false); 63 - if !is_verified && !is_delegated { 64 - return Err(ApiError::AccountNotVerified.into_response()); 65 - } 52 + 53 + require_not_migrated(state, principal_did.as_did()).await?; 54 + let _account_verified = require_verified_or_delegated(state, user).await?; 55 + 66 56 let user_id = state 67 57 .user_repo 68 - .get_id_by_did(&auth_user.did) 58 + .get_id_by_did(principal_did.as_did()) 69 59 .await 70 60 .map_err(|e| { 71 61 error!("DB error fetching user: {}", e); ··· 83 73 .ok_or_else(|| { 84 74 ApiError::InternalError(Some("Repo root not found".into())).into_response() 85 75 })?; 86 - let current_root_cid = Cid::from_str(&root_cid_str).map_err(|_| { 76 + let current_root_cid = CommitCid::from_str(&root_cid_str).map_err(|_| { 87 77 ApiError::InternalError(Some("Invalid repo root CID".into())).into_response() 88 78 })?; 89 79 Ok(RepoWriteAuth { 90 - did: auth_user.did.clone(), 80 + did: principal_did.into_did(), 91 81 user_id, 92 82 current_root_cid, 93 - is_oauth: auth_user.is_oauth(), 94 - scope: auth_user.scope.clone(), 95 - controller_did: auth_user.controller_did.clone(), 83 + is_oauth: user.is_oauth(), 84 + scope: user.scope.clone(), 85 + controller_did: scope_proof.controller_did().map(|c| c.into_did()), 96 86 }) 97 87 } 98 88 #[derive(Deserialize)] ··· 101 91 pub repo: AtIdentifier, 102 92 pub collection: Nsid, 103 93 pub rkey: Option<Rkey>, 104 - pub validate: Option<bool>, 94 + #[serde(default, deserialize_with = "deserialize_validation_mode")] 95 + pub validate: ValidationMode, 105 96 pub record: serde_json::Value, 106 97 #[serde(rename = "swapCommit")] 107 98 pub swap_commit: Option<String>, ··· 127 118 auth: Auth<Active>, 128 119 Json(input): Json<CreateRecordInput>, 129 120 ) -> Result<Response, crate::api::error::ApiError> { 130 - let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 121 + let scope_proof = match auth.verify_repo_create(&input.collection) { 122 + Ok(proof) => proof, 123 + Err(e) => return Ok(e.into_response()), 124 + }; 125 + 126 + let repo_auth = match prepare_repo_write(&state, &scope_proof, &input.repo).await { 131 127 Ok(res) => res, 132 128 Err(err_res) => return Ok(err_res), 133 129 }; 134 130 135 - if let Err(e) = crate::auth::scope_check::check_repo_scope( 136 - repo_auth.is_oauth, 137 - repo_auth.scope.as_deref(), 138 - crate::oauth::RepoAction::Create, 139 - &input.collection, 140 - ) { 141 - return Ok(e); 142 - } 143 - 144 131 let did = repo_auth.did; 145 132 let user_id = repo_auth.user_id; 146 133 let current_root_cid = repo_auth.current_root_cid; 147 134 let controller_did = repo_auth.controller_did; 148 135 149 136 if let Some(swap_commit) = &input.swap_commit 150 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 137 + && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) 151 138 { 152 139 return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 153 140 } 154 141 155 - let validation_status = if input.validate == Some(false) { 142 + let validation_status = if input.validate.should_skip() { 156 143 None 157 144 } else { 158 - let require_lexicon = input.validate == Some(true); 159 145 match validate_record_with_status( 160 146 &input.record, 161 147 &input.collection, 162 148 input.rkey.as_ref(), 163 - require_lexicon, 149 + input.validate.requires_lexicon(), 164 150 ) { 165 151 Ok(status) => Some(status), 166 152 Err(err_response) => return Ok(*err_response), ··· 169 155 let rkey = input.rkey.unwrap_or_else(Rkey::generate); 170 156 171 157 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 172 - let commit_bytes = match tracking_store.get(&current_root_cid).await { 158 + let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await { 173 159 Ok(Some(b)) => b, 174 160 _ => { 175 161 return Ok( ··· 192 178 let mut conflict_uris_to_cleanup: Vec<AtUri> = Vec::new(); 193 179 let mut all_old_mst_blocks = std::collections::BTreeMap::new(); 194 180 195 - if input.validate != Some(false) { 181 + if !input.validate.should_skip() { 196 182 let record_uri = AtUri::from_parts(&did, &input.collection, &rkey); 197 183 let backlinks = extract_backlinks(&record_uri, &input.record); 198 184 ··· 323 309 .collect(); 324 310 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 325 311 let blob_cids = extract_blob_cids(&input.record); 326 - let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 312 + let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid()) 327 313 .chain( 328 314 all_old_mst_blocks 329 315 .keys() ··· 337 323 CommitParams { 338 324 did: &did, 339 325 user_id, 340 - current_root_cid: Some(current_root_cid), 326 + current_root_cid: Some(current_root_cid.into_cid()), 341 327 prev_data_cid: Some(initial_mst_root), 342 328 new_mst_root, 343 329 ops, ··· 412 398 pub repo: AtIdentifier, 413 399 pub collection: Nsid, 414 400 pub rkey: Rkey, 415 - pub validate: Option<bool>, 401 + #[serde(default, deserialize_with = "deserialize_validation_mode")] 402 + pub validate: ValidationMode, 416 403 pub record: serde_json::Value, 417 404 #[serde(rename = "swapCommit")] 418 405 pub swap_commit: Option<String>, ··· 434 421 auth: Auth<Active>, 435 422 Json(input): Json<PutRecordInput>, 436 423 ) -> Result<Response, crate::api::error::ApiError> { 437 - let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 424 + let upsert_proof = match auth.verify_repo_upsert(&input.collection) { 425 + Ok(proof) => proof, 426 + Err(e) => return Ok(e.into_response()), 427 + }; 428 + 429 + let repo_auth = match prepare_repo_write(&state, &upsert_proof, &input.repo).await { 438 430 Ok(res) => res, 439 431 Err(err_res) => return Ok(err_res), 440 432 }; 441 433 442 - if let Err(e) = crate::auth::scope_check::check_repo_scope( 443 - repo_auth.is_oauth, 444 - repo_auth.scope.as_deref(), 445 - crate::oauth::RepoAction::Create, 446 - &input.collection, 447 - ) { 448 - return Ok(e); 449 - } 450 - if let Err(e) = crate::auth::scope_check::check_repo_scope( 451 - repo_auth.is_oauth, 452 - repo_auth.scope.as_deref(), 453 - crate::oauth::RepoAction::Update, 454 - &input.collection, 455 - ) { 456 - return Ok(e); 457 - } 458 - 459 434 let did = repo_auth.did; 460 435 let user_id = repo_auth.user_id; 461 436 let current_root_cid = repo_auth.current_root_cid; 462 437 let controller_did = repo_auth.controller_did; 463 438 464 439 if let Some(swap_commit) = &input.swap_commit 465 - && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 440 + && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) 466 441 { 467 442 return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 468 443 } 469 444 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 470 - let commit_bytes = match tracking_store.get(&current_root_cid).await { 445 + let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await { 471 446 Ok(Some(b)) => b, 472 447 _ => { 473 448 return Ok( ··· 485 460 }; 486 461 let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 487 462 let key = format!("{}/{}", input.collection, input.rkey); 488 - let validation_status = if input.validate == Some(false) { 463 + let validation_status = if input.validate.should_skip() { 489 464 None 490 465 } else { 491 - let require_lexicon = input.validate == Some(true); 492 466 match validate_record_with_status( 493 467 &input.record, 494 468 &input.collection, 495 469 Some(&input.rkey), 496 - require_lexicon, 470 + input.validate.requires_lexicon(), 497 471 ) { 498 472 Ok(status) => Some(status), 499 473 Err(err_response) => return Ok(*err_response), ··· 610 584 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 611 585 let is_update = existing_cid.is_some(); 612 586 let blob_cids = extract_blob_cids(&input.record); 613 - let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 587 + let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid()) 614 588 .chain( 615 589 old_mst_blocks 616 590 .keys() ··· 624 598 CommitParams { 625 599 did: &did, 626 600 user_id, 627 - current_root_cid: Some(current_root_cid), 601 + current_root_cid: Some(current_root_cid.into_cid()), 628 602 prev_data_cid: Some(commit.data), 629 603 new_mst_root, 630 604 ops: vec![op],
+31 -32
crates/tranquil-pds/src/api/server/account_status.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 3 - use crate::auth::{Auth, NotTakendown, Permissive}; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 + use crate::auth::{Auth, NotTakendown, Permissive, require_legacy_session_mfa}; 4 4 use crate::cache::Cache; 5 5 use crate::plc::PlcClient; 6 6 use crate::state::AppState; 7 7 use crate::types::PlainPassword; 8 + use crate::util::pds_hostname; 8 9 use axum::{ 9 10 Json, 10 11 extract::State, ··· 130 131 did: &crate::types::Did, 131 132 with_retry: bool, 132 133 ) -> Result<(), ApiError> { 133 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 134 + let hostname = pds_hostname(); 134 135 let expected_endpoint = format!("https://{}", hostname); 135 136 136 137 if did.as_str().starts_with("did:plc:") { ··· 219 220 .and_then(|v| v.get("atproto")) 220 221 .and_then(|k| k.as_str()); 221 222 222 - let user_key = user_repo.get_user_key_by_did(did).await.map_err(|e| { 223 - error!("Failed to fetch user key: {:?}", e); 224 - ApiError::InternalError(None) 225 - })?; 223 + let user_key = user_repo 224 + .get_user_key_by_did(did) 225 + .await 226 + .log_db_err("fetching user key")?; 226 227 227 228 if let Some(key_info) = user_key { 228 229 let key_bytes = ··· 379 380 "[MIGRATION] activateAccount: Sequencing account event (active=true) for did={}", 380 381 did 381 382 ); 382 - if let Err(e) = 383 - crate::api::repo::record::sequence_account_event(&state, &did, true, None).await 383 + if let Err(e) = crate::api::repo::record::sequence_account_event( 384 + &state, 385 + &did, 386 + tranquil_db_traits::AccountStatus::Active, 387 + ) 388 + .await 384 389 { 385 390 warn!( 386 391 "[MIGRATION] activateAccount: Failed to sequence account activation event: {}", ··· 502 507 if let Err(e) = crate::api::repo::record::sequence_account_event( 503 508 &state, 504 509 &did, 505 - false, 506 - Some("deactivated"), 510 + tranquil_db_traits::AccountStatus::Deactivated, 507 511 ) 508 512 .await 509 513 { ··· 523 527 State(state): State<AppState>, 524 528 auth: Auth<NotTakendown>, 525 529 ) -> Result<Response, ApiError> { 526 - let did = &auth.did; 527 - 528 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, did).await { 529 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 530 - &*state.user_repo, 531 - &*state.session_repo, 532 - did, 533 - ) 534 - .await); 535 - } 530 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 531 + Ok(proof) => proof, 532 + Err(response) => return Ok(response), 533 + }; 536 534 537 535 let user_id = state 538 536 .user_repo 539 - .get_id_by_did(did) 537 + .get_id_by_did(session_mfa.did()) 540 538 .await 541 539 .ok() 542 540 .flatten() ··· 545 543 let expires_at = Utc::now() + Duration::minutes(15); 546 544 state 547 545 .infra_repo 548 - .create_deletion_request(&confirmation_token, did, expires_at) 546 + .create_deletion_request(&confirmation_token, session_mfa.did(), expires_at) 549 547 .await 550 - .map_err(|e| { 551 - error!("DB error creating deletion token: {:?}", e); 552 - ApiError::InternalError(None) 553 - })?; 554 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 548 + .log_db_err("creating deletion token")?; 549 + let hostname = pds_hostname(); 555 550 if let Err(e) = crate::comms::comms_repo::enqueue_account_deletion( 556 551 state.user_repo.as_ref(), 557 552 state.infra_repo.as_ref(), 558 553 user_id, 559 554 &confirmation_token, 560 - &hostname, 555 + hostname, 561 556 ) 562 557 .await 563 558 { 564 559 warn!("Failed to enqueue account deletion notification: {:?}", e); 565 560 } 566 - info!("Account deletion requested for user {}", did); 561 + info!("Account deletion requested for user {}", session_mfa.did()); 567 562 Ok(EmptyResponse::ok().into_response()) 568 563 } 569 564 ··· 642 637 error!("DB error deleting account: {:?}", e); 643 638 return ApiError::InternalError(None).into_response(); 644 639 } 645 - let account_seq = 646 - crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await; 640 + let account_seq = crate::api::repo::record::sequence_account_event( 641 + &state, 642 + did, 643 + tranquil_db_traits::AccountStatus::Deleted, 644 + ) 645 + .await; 647 646 match account_seq { 648 647 Ok(seq) => { 649 648 if let Err(e) = state.repo_repo.delete_sequences_except(did, seq).await {
+19 -51
crates/tranquil-pds/src/api/server/app_password.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::auth::{Auth, NotTakendown, Permissive, generate_app_password}; 4 4 use crate::delegation::{DelegationActionType, intersect_scopes}; 5 - use crate::state::{AppState, RateLimitKind}; 5 + use crate::rate_limit::{AppPasswordLimit, RateLimited}; 6 + use crate::state::AppState; 6 7 use axum::{ 7 8 Json, 8 9 extract::State, 9 - http::HeaderMap, 10 10 response::{IntoResponse, Response}, 11 11 }; 12 12 use serde::{Deserialize, Serialize}; 13 13 use serde_json::json; 14 - use tracing::{error, warn}; 14 + use tracing::error; 15 15 use tranquil_db_traits::AppPasswordCreate; 16 16 17 17 #[derive(Serialize)] ··· 39 39 .user_repo 40 40 .get_by_did(&auth.did) 41 41 .await 42 - .map_err(|e| { 43 - error!("DB error getting user: {:?}", e); 44 - ApiError::InternalError(None) 45 - })? 42 + .log_db_err("getting user")? 46 43 .ok_or(ApiError::AccountNotFound)?; 47 44 48 45 let rows = state 49 46 .session_repo 50 47 .list_app_passwords(user.id) 51 48 .await 52 - .map_err(|e| { 53 - error!("DB error listing app passwords: {:?}", e); 54 - ApiError::InternalError(None) 55 - })?; 49 + .log_db_err("listing app passwords")?; 56 50 let passwords: Vec<AppPassword> = rows 57 51 .iter() 58 52 .map(|row| AppPassword { 59 53 name: row.name.clone(), 60 54 created_at: row.created_at.to_rfc3339(), 61 - privileged: row.privileged, 55 + privileged: row.privilege.is_privileged(), 62 56 scopes: row.scopes.clone(), 63 57 created_by_controller: row 64 58 .created_by_controller_did ··· 89 83 90 84 pub async fn create_app_password( 91 85 State(state): State<AppState>, 92 - headers: HeaderMap, 86 + _rate_limit: RateLimited<AppPasswordLimit>, 93 87 auth: Auth<NotTakendown>, 94 88 Json(input): Json<CreateAppPasswordInput>, 95 89 ) -> Result<Response, ApiError> { 96 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 97 - if !state 98 - .check_rate_limit(RateLimitKind::AppPassword, &client_ip) 99 - .await 100 - { 101 - warn!(ip = %client_ip, "App password creation rate limit exceeded"); 102 - return Err(ApiError::RateLimitExceeded(None)); 103 - } 104 - 105 90 let user = state 106 91 .user_repo 107 92 .get_by_did(&auth.did) 108 93 .await 109 - .map_err(|e| { 110 - error!("DB error getting user: {:?}", e); 111 - ApiError::InternalError(None) 112 - })? 94 + .log_db_err("getting user")? 113 95 .ok_or(ApiError::AccountNotFound)?; 114 96 115 97 let name = input.name.trim(); ··· 121 103 .session_repo 122 104 .get_app_password_by_name(user.id, name) 123 105 .await 124 - .map_err(|e| { 125 - error!("DB error checking app password: {:?}", e); 126 - ApiError::InternalError(None) 127 - })? 106 + .log_db_err("checking app password")? 128 107 .is_some() 129 108 { 130 109 return Err(ApiError::DuplicateAppPassword); ··· 140 119 let granted_scopes = grant.map(|g| g.granted_scopes).unwrap_or_default(); 141 120 142 121 let requested = input.scopes.as_deref().unwrap_or("atproto"); 143 - let intersected = intersect_scopes(requested, &granted_scopes); 122 + let intersected = intersect_scopes(requested, granted_scopes.as_str()); 144 123 145 124 if intersected.is_empty() && !granted_scopes.is_empty() { 146 125 return Err(ApiError::InsufficientScope(None)); ··· 171 150 ApiError::InternalError(None) 172 151 })?; 173 152 174 - let privileged = input.privileged.unwrap_or(false); 153 + let privilege = 154 + tranquil_db_traits::AppPasswordPrivilege::from(input.privileged.unwrap_or(false)); 175 155 let created_at = chrono::Utc::now(); 176 156 177 157 let create_data = AppPasswordCreate { 178 158 user_id: user.id, 179 159 name: name.to_string(), 180 160 password_hash, 181 - privileged, 161 + privilege, 182 162 scopes: final_scopes.clone(), 183 163 created_by_controller_did: controller_did.clone(), 184 164 }; ··· 187 167 .session_repo 188 168 .create_app_password(&create_data) 189 169 .await 190 - .map_err(|e| { 191 - error!("DB error creating app password: {:?}", e); 192 - ApiError::InternalError(None) 193 - })?; 170 + .log_db_err("creating app password")?; 194 171 195 172 if let Some(ref controller) = controller_did { 196 173 let _ = state ··· 214 191 name: name.to_string(), 215 192 password, 216 193 created_at: created_at.to_rfc3339(), 217 - privileged, 194 + privileged: privilege.is_privileged(), 218 195 scopes: final_scopes, 219 196 }) 220 197 .into_response()) ··· 234 211 .user_repo 235 212 .get_by_did(&auth.did) 236 213 .await 237 - .map_err(|e| { 238 - error!("DB error getting user: {:?}", e); 239 - ApiError::InternalError(None) 240 - })? 214 + .log_db_err("getting user")? 241 215 .ok_or(ApiError::AccountNotFound)?; 242 216 243 217 let name = input.name.trim(); ··· 255 229 .session_repo 256 230 .delete_sessions_by_app_password(&auth.did, name) 257 231 .await 258 - .map_err(|e| { 259 - error!("DB error revoking sessions for app password: {:?}", e); 260 - ApiError::InternalError(None) 261 - })?; 232 + .log_db_err("revoking sessions for app password")?; 262 233 263 234 futures::future::join_all(sessions_to_invalidate.iter().map(|jti| { 264 235 let cache_key = format!("auth:session:{}:{}", &auth.did, jti); ··· 273 244 .session_repo 274 245 .delete_app_password(user.id, name) 275 246 .await 276 - .map_err(|e| { 277 - error!("DB error revoking app password: {:?}", e); 278 - ApiError::InternalError(None) 279 - })?; 247 + .log_db_err("revoking app password")?; 280 248 281 249 Ok(EmptyResponse::ok().into_response()) 282 250 }
+21 -92
crates/tranquil-pds/src/api/server/email.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::api::{EmptyResponse, TokenRequiredResponse, VerifiedResponse}; 3 3 use crate::auth::{Auth, NotTakendown}; 4 - use crate::state::{AppState, RateLimitKind}; 4 + use crate::rate_limit::{EmailUpdateLimit, RateLimited, VerificationCheckLimit}; 5 + use crate::state::AppState; 6 + use crate::util::pds_hostname; 5 7 use axum::{ 6 8 Json, 7 9 extract::State, ··· 44 46 45 47 pub async fn request_email_update( 46 48 State(state): State<AppState>, 47 - headers: axum::http::HeaderMap, 49 + _rate_limit: RateLimited<EmailUpdateLimit>, 48 50 auth: Auth<NotTakendown>, 49 51 input: Option<Json<RequestEmailUpdateInput>>, 50 52 ) -> Result<Response, ApiError> { 51 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 52 - if !state 53 - .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 54 - .await 55 - { 56 - warn!(ip = %client_ip, "Email update rate limit exceeded"); 57 - return Err(ApiError::RateLimitExceeded(None)); 58 - } 59 - 60 53 if let Err(e) = crate::auth::scope_check::check_account_scope( 61 54 auth.is_oauth(), 62 55 auth.scope.as_deref(), ··· 70 63 .user_repo 71 64 .get_email_info_by_did(&auth.did) 72 65 .await 73 - .map_err(|e| { 74 - error!("DB error: {:?}", e); 75 - ApiError::InternalError(None) 76 - })? 66 + .log_db_err("getting email info")? 77 67 .ok_or(ApiError::AccountNotFound)?; 78 68 79 69 let Some(current_email) = user.email else { ··· 111 101 } 112 102 } 113 103 114 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 104 + let hostname = pds_hostname(); 115 105 if let Err(e) = crate::comms::comms_repo::enqueue_email_update_token( 116 106 state.user_repo.as_ref(), 117 107 state.infra_repo.as_ref(), 118 108 user.id, 119 109 &code, 120 110 &formatted_code, 121 - &hostname, 111 + hostname, 122 112 ) 123 113 .await 124 114 { ··· 139 129 140 130 pub async fn confirm_email( 141 131 State(state): State<AppState>, 142 - headers: axum::http::HeaderMap, 132 + _rate_limit: RateLimited<EmailUpdateLimit>, 143 133 auth: Auth<NotTakendown>, 144 134 Json(input): Json<ConfirmEmailInput>, 145 135 ) -> Result<Response, ApiError> { 146 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 147 - if !state 148 - .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 149 - .await 150 - { 151 - warn!(ip = %client_ip, "Confirm email rate limit exceeded"); 152 - return Err(ApiError::RateLimitExceeded(None)); 153 - } 154 - 155 136 if let Err(e) = crate::auth::scope_check::check_account_scope( 156 137 auth.is_oauth(), 157 138 auth.scope.as_deref(), ··· 166 147 .user_repo 167 148 .get_email_info_by_did(did) 168 149 .await 169 - .map_err(|e| { 170 - error!("DB error: {:?}", e); 171 - ApiError::InternalError(None) 172 - })? 150 + .log_db_err("getting email info")? 173 151 .ok_or(ApiError::AccountNotFound)?; 174 152 175 153 let Some(ref email) = user.email else { ··· 213 191 .user_repo 214 192 .set_email_verified(user.id, true) 215 193 .await 216 - .map_err(|e| { 217 - error!("DB error confirming email: {:?}", e); 218 - ApiError::InternalError(None) 219 - })?; 194 + .log_db_err("confirming email")?; 220 195 221 196 info!("Email confirmed for user {}", user.id); 222 197 Ok(EmptyResponse::ok().into_response()) ··· 250 225 .user_repo 251 226 .get_email_info_by_did(did) 252 227 .await 253 - .map_err(|e| { 254 - error!("DB error: {:?}", e); 255 - ApiError::InternalError(None) 256 - })? 228 + .log_db_err("getting email info")? 257 229 .ok_or(ApiError::AccountNotFound)?; 258 230 259 231 let user_id = user.id; ··· 325 297 .user_repo 326 298 .update_email(user_id, &new_email) 327 299 .await 328 - .map_err(|e| { 329 - error!("DB error updating email: {:?}", e); 330 - ApiError::InternalError(None) 331 - })?; 300 + .log_db_err("updating email")?; 332 301 333 302 let verification_token = 334 303 crate::auth::verification_token::generate_signup_token(did, "email", &new_email); 335 304 let formatted_token = 336 305 crate::auth::verification_token::format_token_for_display(&verification_token); 337 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 306 + let hostname = pds_hostname(); 338 307 if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification( 339 308 state.infra_repo.as_ref(), 340 309 user_id, 341 310 "email", 342 311 &new_email, 343 312 &formatted_token, 344 - &hostname, 313 + hostname, 345 314 ) 346 315 .await 347 316 { ··· 371 340 372 341 pub async fn check_email_verified( 373 342 State(state): State<AppState>, 374 - headers: axum::http::HeaderMap, 343 + _rate_limit: RateLimited<VerificationCheckLimit>, 375 344 Json(input): Json<CheckEmailVerifiedInput>, 376 345 ) -> Response { 377 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 378 - if !state 379 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 380 - .await 381 - { 382 - return ApiError::RateLimitExceeded(None).into_response(); 383 - } 384 - 385 346 match state 386 347 .user_repo 387 348 .check_email_verified_by_identifier(&input.identifier) ··· 403 364 404 365 pub async fn authorize_email_update( 405 366 State(state): State<AppState>, 406 - headers: axum::http::HeaderMap, 367 + _rate_limit: RateLimited<VerificationCheckLimit>, 407 368 axum::extract::Query(query): axum::extract::Query<AuthorizeEmailUpdateQuery>, 408 369 ) -> Response { 409 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 410 - if !state 411 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 412 - .await 413 - { 414 - return ApiError::RateLimitExceeded(None).into_response(); 415 - } 416 - 417 370 let verified = crate::auth::verification_token::verify_token_signature(&query.token); 418 371 419 372 let token_data = match verified { ··· 488 441 489 442 info!(did = %did, "Email update authorized via link click"); 490 443 491 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 444 + let hostname = pds_hostname(); 492 445 let redirect_url = format!( 493 446 "https://{}/app/verify?type=email-authorize-success", 494 447 hostname ··· 499 452 500 453 pub async fn check_email_update_status( 501 454 State(state): State<AppState>, 502 - headers: axum::http::HeaderMap, 455 + _rate_limit: RateLimited<VerificationCheckLimit>, 503 456 auth: Auth<NotTakendown>, 504 457 ) -> Result<Response, ApiError> { 505 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 506 - if !state 507 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 508 - .await 509 - { 510 - return Err(ApiError::RateLimitExceeded(None)); 511 - } 512 - 513 458 if let Err(e) = crate::auth::scope_check::check_account_scope( 514 459 auth.is_oauth(), 515 460 auth.scope.as_deref(), ··· 549 494 550 495 pub async fn check_email_in_use( 551 496 State(state): State<AppState>, 552 - headers: axum::http::HeaderMap, 497 + _rate_limit: RateLimited<VerificationCheckLimit>, 553 498 Json(input): Json<CheckEmailInUseInput>, 554 499 ) -> Response { 555 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 556 - if !state 557 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 558 - .await 559 - { 560 - return ApiError::RateLimitExceeded(None).into_response(); 561 - } 562 - 563 500 let email = input.email.trim().to_lowercase(); 564 501 if email.is_empty() { 565 502 return ApiError::InvalidRequest("email is required".into()).into_response(); ··· 587 524 588 525 pub async fn check_comms_channel_in_use( 589 526 State(state): State<AppState>, 590 - headers: axum::http::HeaderMap, 527 + _rate_limit: RateLimited<VerificationCheckLimit>, 591 528 Json(input): Json<CheckCommsChannelInUseInput>, 592 529 ) -> Response { 593 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 594 - if !state 595 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 596 - .await 597 - { 598 - return ApiError::RateLimitExceeded(None).into_response(); 599 - } 600 - 601 530 let channel = match input.channel.to_lowercase().as_str() { 602 531 "email" => CommsChannel::Email, 603 532 "discord" => CommsChannel::Discord,
+6 -10
crates/tranquil-pds/src/api/server/invite.rs
··· 1 1 use crate::api::ApiError; 2 + use crate::api::error::DbResultExt; 2 3 use crate::auth::{Admin, Auth, NotTakendown}; 3 4 use crate::state::AppState; 4 5 use crate::types::Did; 6 + use crate::util::pds_hostname; 5 7 use axum::{ 6 8 Json, 7 9 extract::State, ··· 24 26 } 25 27 26 28 fn gen_invite_code() -> String { 27 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 29 + let hostname = pds_hostname(); 28 30 let hostname_prefix = hostname.replace('.', "-"); 29 31 format!("{}-{}", hostname_prefix, gen_random_token()) 30 32 } ··· 121 123 .user_repo 122 124 .get_any_admin_user_id() 123 125 .await 124 - .map_err(|e| { 125 - error!("DB error looking up admin user: {:?}", e); 126 - ApiError::InternalError(None) 127 - })? 126 + .log_db_err("looking up admin user")? 128 127 .ok_or_else(|| { 129 128 error!("No admin user found to create invite codes"); 130 129 ApiError::InternalError(None) ··· 202 201 .infra_repo 203 202 .get_invite_codes_for_account(&auth.did) 204 203 .await 205 - .map_err(|e| { 206 - error!("DB error fetching invite codes: {:?}", e); 207 - ApiError::InternalError(None) 208 - })?; 204 + .log_db_err("fetching invite codes")?; 209 205 210 206 let filtered_codes: Vec<_> = codes_info 211 207 .into_iter() 212 - .filter(|info| !info.disabled) 208 + .filter(|info| info.state.is_active()) 213 209 .collect(); 214 210 215 211 let codes = futures::future::join_all(filtered_codes.into_iter().map(|info| {
+1 -1
crates/tranquil-pds/src/api/server/logo.rs
··· 21 21 Some(c) if !c.is_empty() => c, 22 22 _ => return StatusCode::NOT_FOUND.into_response(), 23 23 }; 24 - let cid = crate::types::CidLink::new_unchecked(&cid_str); 24 + let cid = unsafe { crate::types::CidLink::new_unchecked(&cid_str) }; 25 25 26 26 let metadata = match state.blob_repo.get_blob_metadata(&cid).await { 27 27 Ok(Some(m)) => m,
+3 -2
crates/tranquil-pds/src/api/server/meta.rs
··· 1 1 use crate::state::AppState; 2 + use crate::util::pds_hostname; 2 3 use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; 3 4 use serde_json::json; 4 5 ··· 30 31 } 31 32 32 33 pub async fn describe_server() -> impl IntoResponse { 33 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 34 + let pds_hostname = pds_hostname(); 34 35 let domains_str = 35 - std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.clone()); 36 + std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.to_string()); 36 37 let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect(); 37 38 let invite_code_required = std::env::var("INVITE_CODE_REQUIRED") 38 39 .map(|v| v == "true" || v == "1")
+6 -13
crates/tranquil-pds/src/api/server/migration.rs
··· 1 1 use crate::api::ApiError; 2 + use crate::api::error::DbResultExt; 2 3 use crate::auth::{Active, Auth}; 3 4 use crate::state::AppState; 5 + use crate::util::pds_hostname; 4 6 use axum::{ 5 7 Json, 6 8 extract::State, ··· 49 51 .user_repo 50 52 .get_user_for_did_doc(&auth.did) 51 53 .await 52 - .map_err(|e| { 53 - tracing::error!("DB error getting user: {:?}", e); 54 - ApiError::InternalError(None) 55 - })? 54 + .log_db_err("getting user")? 56 55 .ok_or(ApiError::AccountNotFound)?; 57 56 58 57 if let Some(ref methods) = input.verification_methods { ··· 107 106 .user_repo 108 107 .upsert_did_web_overrides(user.id, verification_methods_json, also_known_as) 109 108 .await 110 - .map_err(|e| { 111 - tracing::error!("DB error upserting did_web_overrides: {:?}", e); 112 - ApiError::InternalError(None) 113 - })?; 109 + .log_db_err("upserting did_web_overrides")?; 114 110 115 111 if let Some(ref endpoint) = input.service_endpoint { 116 112 let endpoint_clean = endpoint.trim().trim_end_matches('/'); ··· 118 114 .user_repo 119 115 .update_migrated_to_pds(&auth.did, endpoint_clean) 120 116 .await 121 - .map_err(|e| { 122 - tracing::error!("DB error updating service endpoint: {:?}", e); 123 - ApiError::InternalError(None) 124 - })?; 117 + .log_db_err("updating service endpoint")?; 125 118 } 126 119 127 120 let did_doc = build_did_document(&state, &auth.did).await; ··· 154 147 } 155 148 156 149 async fn build_did_document(state: &AppState, did: &crate::types::Did) -> serde_json::Value { 157 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 150 + let hostname = pds_hostname(); 158 151 159 152 let user = match state.user_repo.get_user_for_did_doc_build(did).await { 160 153 Ok(Some(row)) => row,
+38 -87
crates/tranquil-pds/src/api/server/passkey_account.rs
··· 1 1 use crate::api::SuccessResponse; 2 2 use crate::api::error::ApiError; 3 + use crate::auth::NormalizedLoginIdentifier; 3 4 use axum::{ 4 5 Json, 5 6 extract::State, ··· 19 20 20 21 use crate::api::repo::record::utils::create_signed_commit; 21 22 use crate::auth::{ServiceTokenVerifier, generate_app_password, is_service_token}; 22 - use crate::state::{AppState, RateLimitKind}; 23 + use crate::rate_limit::{AccountCreationLimit, PasswordResetLimit, RateLimited}; 24 + use crate::state::AppState; 23 25 use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey}; 26 + use crate::util::{pds_hostname, pds_hostname_without_port}; 24 27 use crate::validation::validate_password; 25 28 26 - fn extract_client_ip(headers: &HeaderMap) -> String { 27 - if let Some(forwarded) = headers.get("x-forwarded-for") 28 - && let Ok(value) = forwarded.to_str() 29 - && let Some(first_ip) = value.split(',').next() 30 - { 31 - return first_ip.trim().to_string(); 32 - } 33 - if let Some(real_ip) = headers.get("x-real-ip") 34 - && let Ok(value) = real_ip.to_str() 35 - { 36 - return value.trim().to_string(); 37 - } 38 - "unknown".to_string() 39 - } 40 - 41 29 fn generate_setup_token() -> String { 42 30 let mut rng = rand::thread_rng(); 43 31 (0..32) ··· 80 68 81 69 pub async fn create_passkey_account( 82 70 State(state): State<AppState>, 71 + _rate_limit: RateLimited<AccountCreationLimit>, 83 72 headers: HeaderMap, 84 73 Json(input): Json<CreatePasskeyAccountInput>, 85 74 ) -> Response { 86 - let client_ip = extract_client_ip(&headers); 87 - if !state 88 - .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 89 - .await 90 - { 91 - warn!(ip = %client_ip, "Account creation rate limit exceeded"); 92 - return ApiError::RateLimitExceeded(Some( 93 - "Too many account creation attempts. Please try again later.".into(), 94 - )) 95 - .into_response(); 96 - } 97 - 98 75 let byod_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header( 99 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 76 + crate::util::get_header_str(&headers, "Authorization"), 100 77 ) { 101 78 let token = extracted.token; 102 79 if is_service_token(&token) { ··· 135 112 .map(|d| d.starts_with("did:web:")) 136 113 .unwrap_or(false); 137 114 138 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 139 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 115 + let hostname = pds_hostname(); 116 + let hostname_for_handles = pds_hostname_without_port(); 140 117 let pds_suffix = format!(".{}", hostname_for_handles); 141 118 142 119 let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) { ··· 169 146 return ApiError::InvalidEmail.into_response(); 170 147 } 171 148 172 - if let Some(ref code) = input.invite_code { 173 - let valid = state 174 - .infra_repo 175 - .is_invite_code_valid(code) 176 - .await 177 - .unwrap_or(false); 178 - 179 - if !valid { 180 - return ApiError::InvalidInviteCode.into_response(); 149 + let _validated_invite_code = if let Some(ref code) = input.invite_code { 150 + match state.infra_repo.validate_invite_code(code).await { 151 + Ok(validated) => Some(validated), 152 + Err(_) => return ApiError::InvalidInviteCode.into_response(), 181 153 } 182 154 } else { 183 155 let invite_required = std::env::var("INVITE_CODE_REQUIRED") ··· 186 158 if invite_required { 187 159 return ApiError::InviteCodeRequired.into_response(); 188 160 } 189 - } 161 + None 162 + }; 190 163 191 164 let verification_channel = input.verification_channel.as_deref().unwrap_or("email"); 192 165 let verification_recipient = match verification_channel { ··· 268 241 } 269 242 if is_byod_did_web { 270 243 if let Some(ref auth_did) = byod_auth 271 - && d != auth_did 244 + && d != auth_did.as_str() 272 245 { 273 246 return ApiError::AuthorizationError(format!( 274 247 "Service token issuer {} does not match DID {}", ··· 280 253 } else { 281 254 if let Err(e) = crate::api::identity::did::verify_did_web( 282 255 d, 283 - &hostname, 256 + hostname, 284 257 &input.handle, 285 258 input.signing_key.as_deref(), 286 259 ) ··· 296 269 if let Some(ref auth_did) = byod_auth { 297 270 if let Some(ref provided_did) = input.did { 298 271 if provided_did.starts_with("did:plc:") { 299 - if provided_did != auth_did { 272 + if provided_did != auth_did.as_str() { 300 273 return ApiError::AuthorizationError(format!( 301 274 "Service token issuer {} does not match DID {}", 302 275 auth_did, provided_did ··· 389 362 } 390 363 }; 391 364 let rev = Tid::now(LimitedU32::MIN); 392 - let did_typed = Did::new_unchecked(&did); 365 + let did_typed = unsafe { Did::new_unchecked(&did) }; 393 366 let (commit_bytes, _sig) = 394 367 match create_signed_commit(&did_typed, mst_root, rev.as_ref(), None, &secret_key) { 395 368 Ok(result) => result, ··· 422 395 _ => tranquil_db_traits::CommsChannel::Email, 423 396 }; 424 397 425 - let handle_typed = Handle::new_unchecked(&handle); 398 + let handle_typed = unsafe { Handle::new_unchecked(&handle) }; 426 399 let create_input = tranquil_db_traits::CreatePasskeyAccountInput { 427 400 handle: handle_typed.clone(), 428 401 email: email.clone().unwrap_or_default(), ··· 484 457 { 485 458 warn!("Failed to sequence identity event for {}: {}", did, e); 486 459 } 487 - if let Err(e) = 488 - crate::api::repo::record::sequence_account_event(&state, &did_typed, true, None).await 460 + if let Err(e) = crate::api::repo::record::sequence_account_event( 461 + &state, 462 + &did_typed, 463 + tranquil_db_traits::AccountStatus::Active, 464 + ) 465 + .await 489 466 { 490 467 warn!("Failed to sequence account event for {}: {}", did, e); 491 468 } ··· 493 470 "$type": "app.bsky.actor.profile", 494 471 "displayName": handle 495 472 }); 496 - let profile_collection = Nsid::new_unchecked("app.bsky.actor.profile"); 497 - let profile_rkey = Rkey::new_unchecked("self"); 473 + let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") }; 474 + let profile_rkey = unsafe { Rkey::new_unchecked("self") }; 498 475 if let Err(e) = crate::api::repo::record::create_record_internal( 499 476 &state, 500 477 &did_typed, ··· 521 498 verification_channel, 522 499 &verification_recipient, 523 500 &formatted_token, 524 - &hostname, 501 + hostname, 525 502 ) 526 503 .await 527 504 { ··· 541 518 refresh_jti, 542 519 access_expires_at: token_meta.expires_at, 543 520 refresh_expires_at: refresh_expires, 544 - legacy_login: false, 521 + login_type: tranquil_db::LoginType::Modern, 545 522 mfa_verified: false, 546 523 scope: None, 547 524 controller_did: None, ··· 626 603 return ApiError::InvalidToken(None).into_response(); 627 604 } 628 605 629 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 630 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 631 - Ok(w) => w, 632 - Err(e) => { 633 - error!("Failed to create WebAuthn config: {:?}", e); 634 - return ApiError::InternalError(None).into_response(); 635 - } 636 - }; 606 + let webauthn = &state.webauthn_config; 637 607 638 608 let reg_state = match state 639 609 .user_repo ··· 768 738 return ApiError::InvalidToken(None).into_response(); 769 739 } 770 740 771 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 772 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 773 - Ok(w) => w, 774 - Err(e) => { 775 - error!("Failed to create WebAuthn config: {:?}", e); 776 - return ApiError::InternalError(None).into_response(); 777 - } 778 - }; 741 + let webauthn = &state.webauthn_config; 779 742 780 743 let existing_passkeys = state 781 744 .user_repo ··· 840 803 841 804 pub async fn request_passkey_recovery( 842 805 State(state): State<AppState>, 843 - headers: HeaderMap, 806 + _rate_limit: RateLimited<PasswordResetLimit>, 844 807 Json(input): Json<RequestPasskeyRecoveryInput>, 845 808 ) -> Response { 846 - let client_ip = extract_client_ip(&headers); 847 - if !state 848 - .check_rate_limit(RateLimitKind::PasswordReset, &client_ip) 849 - .await 850 - { 851 - return ApiError::RateLimitExceeded(None).into_response(); 852 - } 853 - 854 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 855 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 809 + let hostname_for_handles = pds_hostname_without_port(); 856 810 let identifier = input.email.trim().to_lowercase(); 857 811 let identifier = identifier.strip_prefix('@').unwrap_or(&identifier); 858 - let normalized_handle = if identifier.contains('@') || identifier.contains('.') { 859 - identifier.to_string() 860 - } else { 861 - format!("{}.{}", identifier, hostname_for_handles) 862 - }; 812 + let normalized_handle = 813 + NormalizedLoginIdentifier::normalize(&input.email, hostname_for_handles); 863 814 864 815 let user = match state 865 816 .user_repo 866 - .get_user_for_passkey_recovery(identifier, &normalized_handle) 817 + .get_user_for_passkey_recovery(identifier, normalized_handle.as_str()) 867 818 .await 868 819 { 869 820 Ok(Some(u)) if !u.password_required => u, ··· 890 841 return ApiError::InternalError(None).into_response(); 891 842 } 892 843 893 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 844 + let hostname = pds_hostname(); 894 845 let recovery_url = format!( 895 846 "https://{}/app/recover-passkey?did={}&token={}", 896 847 hostname, ··· 903 854 state.infra_repo.as_ref(), 904 855 user.id, 905 856 &recovery_url, 906 - &hostname, 857 + hostname, 907 858 ) 908 859 .await; 909 860
+20 -56
crates/tranquil-pds/src/api/server/passkeys.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 3 - use crate::auth::webauthn::WebAuthnConfig; 4 - use crate::auth::{Active, Auth}; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 + use crate::auth::{Active, Auth, require_legacy_session_mfa, require_reauth_window}; 5 4 use crate::state::AppState; 6 5 use axum::{ 7 6 Json, ··· 12 11 use tracing::{error, info, warn}; 13 12 use webauthn_rs::prelude::*; 14 13 15 - fn get_webauthn() -> Result<WebAuthnConfig, ApiError> { 16 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 17 - WebAuthnConfig::new(&hostname).map_err(|e| { 18 - error!("Failed to create WebAuthn config: {}", e); 19 - ApiError::InternalError(Some("WebAuthn configuration failed".into())) 20 - }) 21 - } 22 - 23 14 #[derive(Deserialize)] 24 15 #[serde(rename_all = "camelCase")] 25 16 pub struct StartRegistrationInput { ··· 37 28 auth: Auth<Active>, 38 29 Json(input): Json<StartRegistrationInput>, 39 30 ) -> Result<Response, ApiError> { 40 - let webauthn = get_webauthn()?; 31 + let webauthn = &state.webauthn_config; 41 32 42 33 let handle = state 43 34 .user_repo 44 35 .get_handle_by_did(&auth.did) 45 36 .await 46 - .map_err(|e| { 47 - error!("DB error fetching user: {:?}", e); 48 - ApiError::InternalError(None) 49 - })? 37 + .log_db_err("fetching user")? 50 38 .ok_or(ApiError::AccountNotFound)?; 51 39 52 40 let existing_passkeys = state 53 41 .user_repo 54 42 .get_passkeys_for_user(&auth.did) 55 43 .await 56 - .map_err(|e| { 57 - error!("DB error fetching existing passkeys: {:?}", e); 58 - ApiError::InternalError(None) 59 - })?; 44 + .log_db_err("fetching existing passkeys")?; 60 45 61 46 let exclude_credentials: Vec<CredentialID> = existing_passkeys 62 47 .iter() ··· 81 66 .user_repo 82 67 .save_webauthn_challenge(&auth.did, "registration", &state_json) 83 68 .await 84 - .map_err(|e| { 85 - error!("Failed to save registration state: {:?}", e); 86 - ApiError::InternalError(None) 87 - })?; 69 + .log_db_err("saving registration state")?; 88 70 89 71 let options = serde_json::to_value(&ccr).unwrap_or(serde_json::json!({})); 90 72 ··· 112 94 auth: Auth<Active>, 113 95 Json(input): Json<FinishRegistrationInput>, 114 96 ) -> Result<Response, ApiError> { 115 - let webauthn = get_webauthn()?; 97 + let webauthn = &state.webauthn_config; 116 98 117 99 let reg_state_json = state 118 100 .user_repo 119 101 .load_webauthn_challenge(&auth.did, "registration") 120 102 .await 121 - .map_err(|e| { 122 - error!("DB error loading registration state: {:?}", e); 123 - ApiError::InternalError(None) 124 - })? 103 + .log_db_err("loading registration state")? 125 104 .ok_or(ApiError::NoRegistrationInProgress)?; 126 105 127 106 let reg_state: SecurityKeyRegistration = ··· 157 136 input.friendly_name.as_deref(), 158 137 ) 159 138 .await 160 - .map_err(|e| { 161 - error!("Failed to save passkey: {:?}", e); 162 - ApiError::InternalError(None) 163 - })?; 139 + .log_db_err("saving passkey")?; 164 140 165 141 if let Err(e) = state 166 142 .user_repo ··· 208 184 .user_repo 209 185 .get_passkeys_for_user(&auth.did) 210 186 .await 211 - .map_err(|e| { 212 - error!("DB error fetching passkeys: {:?}", e); 213 - ApiError::InternalError(None) 214 - })?; 187 + .log_db_err("fetching passkeys")?; 215 188 216 189 let passkey_infos: Vec<PasskeyInfo> = passkeys 217 190 .into_iter() ··· 241 214 auth: Auth<Active>, 242 215 Json(input): Json<DeletePasskeyInput>, 243 216 ) -> Result<Response, ApiError> { 244 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 245 - { 246 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 247 - &*state.user_repo, 248 - &*state.session_repo, 249 - &auth.did, 250 - ) 251 - .await); 252 - } 217 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 218 + Ok(proof) => proof, 219 + Err(response) => return Ok(response), 220 + }; 253 221 254 - if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await { 255 - return Ok(crate::api::server::reauth::reauth_required_response( 256 - &*state.user_repo, 257 - &*state.session_repo, 258 - &auth.did, 259 - ) 260 - .await); 261 - } 222 + let reauth_mfa = match require_reauth_window(&state, &auth).await { 223 + Ok(proof) => proof, 224 + Err(response) => return Ok(response), 225 + }; 262 226 263 227 let id: uuid::Uuid = input.id.parse().map_err(|_| ApiError::InvalidId)?; 264 228 265 - match state.user_repo.delete_passkey(id, &auth.did).await { 229 + match state.user_repo.delete_passkey(id, reauth_mfa.did()).await { 266 230 Ok(true) => { 267 - info!(did = %auth.did, passkey_id = %id, "Passkey deleted"); 231 + info!(did = %session_mfa.did(), passkey_id = %id, "Passkey deleted"); 268 232 Ok(EmptyResponse::ok().into_response()) 269 233 } 270 234 Ok(false) => Err(ApiError::PasskeyNotFound),
+64 -170
crates/tranquil-pds/src/api/server/password.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::api::{EmptyResponse, HasPasswordResponse, SuccessResponse}; 3 - use crate::auth::{Active, Auth}; 4 - use crate::state::{AppState, RateLimitKind}; 3 + use crate::auth::{ 4 + Active, Auth, NormalizedLoginIdentifier, require_legacy_session_mfa, require_reauth_window, 5 + require_reauth_window_if_available, 6 + }; 7 + use crate::rate_limit::{PasswordResetLimit, RateLimited, ResetPasswordLimit}; 8 + use crate::state::AppState; 5 9 use crate::types::PlainPassword; 10 + use crate::util::{pds_hostname, pds_hostname_without_port}; 6 11 use crate::validation::validate_password; 7 12 use axum::{ 8 13 Json, 9 14 extract::State, 10 - http::HeaderMap, 11 15 response::{IntoResponse, Response}, 12 16 }; 13 - use bcrypt::{DEFAULT_COST, hash, verify}; 17 + use bcrypt::{DEFAULT_COST, hash}; 14 18 use chrono::{Duration, Utc}; 15 19 use serde::Deserialize; 16 20 use tracing::{error, info, warn}; ··· 18 22 fn generate_reset_code() -> String { 19 23 crate::util::generate_token_code() 20 24 } 21 - fn extract_client_ip(headers: &HeaderMap) -> String { 22 - if let Some(forwarded) = headers.get("x-forwarded-for") 23 - && let Ok(value) = forwarded.to_str() 24 - && let Some(first_ip) = value.split(',').next() 25 - { 26 - return first_ip.trim().to_string(); 27 - } 28 - if let Some(real_ip) = headers.get("x-real-ip") 29 - && let Ok(value) = real_ip.to_str() 30 - { 31 - return value.trim().to_string(); 32 - } 33 - "unknown".to_string() 34 - } 35 25 36 26 #[derive(Deserialize)] 37 27 pub struct RequestPasswordResetInput { ··· 41 31 42 32 pub async fn request_password_reset( 43 33 State(state): State<AppState>, 44 - headers: HeaderMap, 34 + _rate_limit: RateLimited<PasswordResetLimit>, 45 35 Json(input): Json<RequestPasswordResetInput>, 46 36 ) -> Response { 47 - let client_ip = extract_client_ip(&headers); 48 - if !state 49 - .check_rate_limit(RateLimitKind::PasswordReset, &client_ip) 50 - .await 51 - { 52 - warn!(ip = %client_ip, "Password reset rate limit exceeded"); 53 - return ApiError::RateLimitExceeded(None).into_response(); 54 - } 55 37 let identifier = input.email.trim(); 56 38 if identifier.is_empty() { 57 39 return ApiError::InvalidRequest("email or handle is required".into()).into_response(); 58 40 } 59 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 60 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 41 + let hostname_for_handles = pds_hostname_without_port(); 61 42 let normalized = identifier.to_lowercase(); 62 43 let normalized = normalized.strip_prefix('@').unwrap_or(&normalized); 63 44 let is_email_lookup = normalized.contains('@'); 64 - let normalized_handle = if normalized.contains('@') || normalized.contains('.') { 65 - normalized.to_string() 66 - } else { 67 - format!("{}.{}", normalized, hostname_for_handles) 68 - }; 45 + let normalized_handle = NormalizedLoginIdentifier::normalize(identifier, hostname_for_handles); 69 46 70 47 let multiple_accounts_warning = if is_email_lookup { 71 48 match state.user_repo.count_accounts_by_email(normalized).await { ··· 78 55 79 56 let user_id = match state 80 57 .user_repo 81 - .get_id_by_email_or_handle(normalized, &normalized_handle) 58 + .get_id_by_email_or_handle(normalized, normalized_handle.as_str()) 82 59 .await 83 60 { 84 61 Ok(Some(id)) => id, ··· 101 78 error!("DB error setting reset code: {:?}", e); 102 79 return ApiError::InternalError(None).into_response(); 103 80 } 104 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 81 + let hostname = pds_hostname(); 105 82 if let Err(e) = crate::comms::comms_repo::enqueue_password_reset( 106 83 state.user_repo.as_ref(), 107 84 state.infra_repo.as_ref(), 108 85 user_id, 109 86 &code, 110 - &hostname, 87 + hostname, 111 88 ) 112 89 .await 113 90 { ··· 135 112 136 113 pub async fn reset_password( 137 114 State(state): State<AppState>, 138 - headers: HeaderMap, 115 + _rate_limit: RateLimited<ResetPasswordLimit>, 139 116 Json(input): Json<ResetPasswordInput>, 140 117 ) -> Response { 141 - let client_ip = extract_client_ip(&headers); 142 - if !state 143 - .check_rate_limit(RateLimitKind::ResetPassword, &client_ip) 144 - .await 145 - { 146 - warn!(ip = %client_ip, "Reset password rate limit exceeded"); 147 - return ApiError::RateLimitExceeded(None).into_response(); 148 - } 149 118 let token = input.token.trim(); 150 119 let password = &input.password; 151 120 if token.is_empty() { ··· 230 199 auth: Auth<Active>, 231 200 Json(input): Json<ChangePasswordInput>, 232 201 ) -> Result<Response, ApiError> { 233 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 234 - { 235 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 236 - &*state.user_repo, 237 - &*state.session_repo, 238 - &auth.did, 239 - ) 240 - .await); 241 - } 202 + use crate::auth::verify_password_mfa; 242 203 243 - let current_password = &input.current_password; 244 - let new_password = &input.new_password; 245 - if current_password.is_empty() { 204 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 205 + Ok(proof) => proof, 206 + Err(response) => return Ok(response), 207 + }; 208 + 209 + if input.current_password.is_empty() { 246 210 return Err(ApiError::InvalidRequest( 247 211 "currentPassword is required".into(), 248 212 )); 249 213 } 250 - if new_password.is_empty() { 214 + if input.new_password.is_empty() { 251 215 return Err(ApiError::InvalidRequest("newPassword is required".into())); 252 216 } 253 - if let Err(e) = validate_password(new_password) { 217 + if let Err(e) = validate_password(&input.new_password) { 254 218 return Err(ApiError::InvalidRequest(e.to_string())); 255 219 } 220 + 221 + let password_mfa = verify_password_mfa(&state, &auth, &input.current_password).await?; 222 + 256 223 let user = state 257 224 .user_repo 258 - .get_id_and_password_hash_by_did(&auth.did) 225 + .get_id_and_password_hash_by_did(password_mfa.did()) 259 226 .await 260 - .map_err(|e| { 261 - error!("DB error in change_password: {:?}", e); 262 - ApiError::InternalError(None) 263 - })? 227 + .log_db_err("in change_password")? 264 228 .ok_or(ApiError::AccountNotFound)?; 265 229 266 - let (user_id, password_hash) = (user.id, user.password_hash); 267 - let valid = verify(current_password, &password_hash).map_err(|e| { 268 - error!("Password verification error: {:?}", e); 269 - ApiError::InternalError(None) 270 - })?; 271 - if !valid { 272 - return Err(ApiError::InvalidPassword( 273 - "Current password is incorrect".into(), 274 - )); 275 - } 276 - let new_password_clone = new_password.to_string(); 230 + let new_password_clone = input.new_password.to_string(); 277 231 let new_hash = tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)) 278 232 .await 279 233 .map_err(|e| { ··· 287 241 288 242 state 289 243 .user_repo 290 - .update_password_hash(user_id, &new_hash) 244 + .update_password_hash(user.id, &new_hash) 291 245 .await 292 - .map_err(|e| { 293 - error!("DB error updating password: {:?}", e); 294 - ApiError::InternalError(None) 295 - })?; 246 + .log_db_err("updating password")?; 296 247 297 - info!(did = %&auth.did, "Password changed successfully"); 248 + info!(did = %session_mfa.did(), "Password changed successfully"); 298 249 Ok(EmptyResponse::ok().into_response()) 299 250 } 300 251 ··· 302 253 State(state): State<AppState>, 303 254 auth: Auth<Active>, 304 255 ) -> Result<Response, ApiError> { 305 - match state.user_repo.has_password_by_did(&auth.did).await { 306 - Ok(Some(has)) => Ok(HasPasswordResponse::response(has).into_response()), 307 - Ok(None) => Err(ApiError::AccountNotFound), 308 - Err(e) => { 309 - error!("DB error: {:?}", e); 310 - Err(ApiError::InternalError(None)) 311 - } 312 - } 256 + let has = state 257 + .user_repo 258 + .has_password_by_did(&auth.did) 259 + .await 260 + .log_db_err("checking password status")? 261 + .ok_or(ApiError::AccountNotFound)?; 262 + Ok(HasPasswordResponse::response(has).into_response()) 313 263 } 314 264 315 265 pub async fn remove_password( 316 266 State(state): State<AppState>, 317 267 auth: Auth<Active>, 318 268 ) -> Result<Response, ApiError> { 319 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 320 - { 321 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 322 - &*state.user_repo, 323 - &*state.session_repo, 324 - &auth.did, 325 - ) 326 - .await); 327 - } 269 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 270 + Ok(proof) => proof, 271 + Err(response) => return Ok(response), 272 + }; 328 273 329 - if crate::api::server::reauth::check_reauth_required_cached( 330 - &*state.session_repo, 331 - &state.cache, 332 - &auth.did, 333 - ) 334 - .await 335 - { 336 - return Ok(crate::api::server::reauth::reauth_required_response( 337 - &*state.user_repo, 338 - &*state.session_repo, 339 - &auth.did, 340 - ) 341 - .await); 342 - } 274 + let reauth_mfa = match require_reauth_window(&state, &auth).await { 275 + Ok(proof) => proof, 276 + Err(response) => return Ok(response), 277 + }; 343 278 344 279 let has_passkeys = state 345 280 .user_repo 346 - .has_passkeys(&auth.did) 281 + .has_passkeys(reauth_mfa.did()) 347 282 .await 348 283 .unwrap_or(false); 349 284 if !has_passkeys { ··· 354 289 355 290 let user = state 356 291 .user_repo 357 - .get_password_info_by_did(&auth.did) 292 + .get_password_info_by_did(reauth_mfa.did()) 358 293 .await 359 - .map_err(|e| { 360 - error!("DB error: {:?}", e); 361 - ApiError::InternalError(None) 362 - })? 294 + .log_db_err("getting password info")? 363 295 .ok_or(ApiError::AccountNotFound)?; 364 296 365 297 if user.password_hash.is_none() { ··· 372 304 .user_repo 373 305 .remove_user_password(user.id) 374 306 .await 375 - .map_err(|e| { 376 - error!("DB error removing password: {:?}", e); 377 - ApiError::InternalError(None) 378 - })?; 307 + .log_db_err("removing password")?; 379 308 380 - info!(did = %&auth.did, "Password removed - account is now passkey-only"); 309 + info!(did = %session_mfa.did(), "Password removed - account is now passkey-only"); 381 310 Ok(SuccessResponse::ok().into_response()) 382 311 } 383 312 ··· 392 321 auth: Auth<Active>, 393 322 Json(input): Json<SetPasswordInput>, 394 323 ) -> Result<Response, ApiError> { 395 - let has_password = state 396 - .user_repo 397 - .has_password_by_did(&auth.did) 398 - .await 399 - .ok() 400 - .flatten() 401 - .unwrap_or(false); 402 - let has_passkeys = state 403 - .user_repo 404 - .has_passkeys(&auth.did) 405 - .await 406 - .unwrap_or(false); 407 - let has_totp = state 408 - .user_repo 409 - .has_totp_enabled(&auth.did) 410 - .await 411 - .unwrap_or(false); 412 - 413 - let has_any_reauth_method = has_password || has_passkeys || has_totp; 414 - 415 - if has_any_reauth_method 416 - && crate::api::server::reauth::check_reauth_required_cached( 417 - &*state.session_repo, 418 - &state.cache, 419 - &auth.did, 420 - ) 421 - .await 422 - { 423 - return Ok(crate::api::server::reauth::reauth_required_response( 424 - &*state.user_repo, 425 - &*state.session_repo, 426 - &auth.did, 427 - ) 428 - .await); 429 - } 324 + let reauth_mfa = match require_reauth_window_if_available(&state, &auth).await { 325 + Ok(proof) => proof, 326 + Err(response) => return Ok(response), 327 + }; 430 328 431 329 let new_password = &input.new_password; 432 330 if new_password.is_empty() { ··· 436 334 return Err(ApiError::InvalidRequest(e.to_string())); 437 335 } 438 336 337 + let did = reauth_mfa.as_ref().map(|m| m.did()).unwrap_or(&auth.did); 338 + 439 339 let user = state 440 340 .user_repo 441 - .get_password_info_by_did(&auth.did) 341 + .get_password_info_by_did(did) 442 342 .await 443 - .map_err(|e| { 444 - error!("DB error: {:?}", e); 445 - ApiError::InternalError(None) 446 - })? 343 + .log_db_err("getting password info")? 447 344 .ok_or(ApiError::AccountNotFound)?; 448 345 449 346 if user.password_hash.is_some() { ··· 468 365 .user_repo 469 366 .set_new_user_password(user.id, &new_hash) 470 367 .await 471 - .map_err(|e| { 472 - error!("DB error setting password: {:?}", e); 473 - ApiError::InternalError(None) 474 - })?; 368 + .log_db_err("setting password")?; 475 369 476 - info!(did = %&auth.did, "Password set for passkey-only account"); 370 + info!(did = %did, "Password set for passkey-only account"); 477 371 Ok(SuccessResponse::ok().into_response()) 478 372 }
+22 -59
crates/tranquil-pds/src/api/server/reauth.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use axum::{ 3 3 Json, 4 4 extract::State, ··· 11 11 use tranquil_db_traits::{SessionRepository, UserRepository}; 12 12 13 13 use crate::auth::{Active, Auth}; 14 - use crate::state::{AppState, RateLimitKind}; 14 + use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message}; 15 + use crate::state::AppState; 15 16 use crate::types::PlainPassword; 16 17 17 - const REAUTH_WINDOW_SECONDS: i64 = 300; 18 + pub const REAUTH_WINDOW_SECONDS: i64 = 300; 18 19 19 20 #[derive(Serialize)] 20 21 #[serde(rename_all = "camelCase")] ··· 32 33 .session_repo 33 34 .get_last_reauth_at(&auth.did) 34 35 .await 35 - .map_err(|e| { 36 - error!("DB error: {:?}", e); 37 - ApiError::InternalError(None) 38 - })?; 36 + .log_db_err("getting last reauth")?; 39 37 40 38 let reauth_required = is_reauth_required(last_reauth_at); 41 39 let available_methods = ··· 70 68 .user_repo 71 69 .get_password_hash_by_did(&auth.did) 72 70 .await 73 - .map_err(|e| { 74 - error!("DB error: {:?}", e); 75 - ApiError::InternalError(None) 76 - })? 71 + .log_db_err("fetching password hash")? 77 72 .ok_or(ApiError::AccountNotFound)?; 78 73 79 74 let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); ··· 97 92 98 93 let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 99 94 .await 100 - .map_err(|e| { 101 - error!("DB error updating reauth: {:?}", e); 102 - ApiError::InternalError(None) 103 - })?; 95 + .log_db_err("updating reauth")?; 104 96 105 97 info!(did = %&auth.did, "Re-auth successful via password"); 106 98 Ok(Json(ReauthResponse { reauthed_at }).into_response()) ··· 117 109 auth: Auth<Active>, 118 110 Json(input): Json<TotpReauthInput>, 119 111 ) -> Result<Response, ApiError> { 120 - if !state 121 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 122 - .await 123 - { 124 - warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 125 - return Err(ApiError::RateLimitExceeded(Some( 126 - "Too many verification attempts. Please try again in a few minutes.".into(), 127 - ))); 128 - } 112 + let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>( 113 + &state, 114 + &auth.did, 115 + "Too many verification attempts. Please try again in a few minutes.", 116 + ) 117 + .await?; 129 118 130 119 let valid = 131 120 crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.did, &input.code) ··· 140 129 141 130 let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 142 131 .await 143 - .map_err(|e| { 144 - error!("DB error updating reauth: {:?}", e); 145 - ApiError::InternalError(None) 146 - })?; 132 + .log_db_err("updating reauth")?; 147 133 148 134 info!(did = %&auth.did, "Re-auth successful via TOTP"); 149 135 Ok(Json(ReauthResponse { reauthed_at }).into_response()) ··· 159 145 State(state): State<AppState>, 160 146 auth: Auth<Active>, 161 147 ) -> Result<Response, ApiError> { 162 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 163 - 164 148 let stored_passkeys = state 165 149 .user_repo 166 150 .get_passkeys_for_user(&auth.did) 167 151 .await 168 - .map_err(|e| { 169 - error!("Failed to get passkeys: {:?}", e); 170 - ApiError::InternalError(None) 171 - })?; 152 + .log_db_err("getting passkeys")?; 172 153 173 154 if stored_passkeys.is_empty() { 174 155 return Err(ApiError::NoPasskeys); ··· 185 166 ))); 186 167 } 187 168 188 - let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| { 189 - error!("Failed to create WebAuthn config: {:?}", e); 190 - ApiError::InternalError(None) 191 - })?; 169 + let webauthn = &state.webauthn_config; 192 170 193 171 let (rcr, auth_state) = webauthn.start_authentication(passkeys).map_err(|e| { 194 172 error!("Failed to start passkey authentication: {:?}", e); ··· 204 182 .user_repo 205 183 .save_webauthn_challenge(&auth.did, "authentication", &state_json) 206 184 .await 207 - .map_err(|e| { 208 - error!("Failed to save authentication state: {:?}", e); 209 - ApiError::InternalError(None) 210 - })?; 185 + .log_db_err("saving authentication state")?; 211 186 212 187 let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({})); 213 188 Ok(Json(PasskeyReauthStartResponse { options }).into_response()) ··· 224 199 auth: Auth<Active>, 225 200 Json(input): Json<PasskeyReauthFinishInput>, 226 201 ) -> Result<Response, ApiError> { 227 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 228 - 229 202 let auth_state_json = state 230 203 .user_repo 231 204 .load_webauthn_challenge(&auth.did, "authentication") 232 205 .await 233 - .map_err(|e| { 234 - error!("Failed to load authentication state: {:?}", e); 235 - ApiError::InternalError(None) 236 - })? 206 + .log_db_err("loading authentication state")? 237 207 .ok_or(ApiError::NoChallengeInProgress)?; 238 208 239 209 let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication = ··· 248 218 ApiError::InvalidCredential 249 219 })?; 250 220 251 - let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| { 252 - error!("Failed to create WebAuthn config: {:?}", e); 253 - ApiError::InternalError(None) 254 - })?; 255 - 256 - let auth_result = webauthn 221 + let auth_result = state 222 + .webauthn_config 257 223 .finish_authentication(&credential, &auth_state) 258 224 .map_err(|e| { 259 225 warn!(did = %&auth.did, "Passkey re-auth failed: {:?}", e); ··· 287 253 288 254 let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 289 255 .await 290 - .map_err(|e| { 291 - error!("DB error updating reauth: {:?}", e); 292 - ApiError::InternalError(None) 293 - })?; 256 + .log_db_err("updating reauth")?; 294 257 295 258 info!(did = %&auth.did, "Re-auth successful via passkey"); 296 259 Ok(Json(ReauthResponse { reauthed_at }).into_response()) ··· 418 381 ) -> bool { 419 382 match session_repo.get_session_mfa_status(did).await { 420 383 Ok(Some(status)) => { 421 - if !status.legacy_login { 384 + if status.login_type.is_modern() { 422 385 return true; 423 386 } 424 387 if status.mfa_verified {
+3 -3
crates/tranquil-pds/src/api/server/service_auth.rs
··· 51 51 headers: axum::http::HeaderMap, 52 52 Query(params): Query<GetServiceAuthParams>, 53 53 ) -> Response { 54 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 55 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 54 + let auth_header = crate::util::get_header_str(&headers, "Authorization"); 55 + let dpop_proof = crate::util::get_header_str(&headers, "DPoP"); 56 56 info!( 57 57 has_auth_header = auth_header.is_some(), 58 58 has_dpop_proof = dpop_proof.is_some(), ··· 94 94 .await 95 95 { 96 96 Ok(result) => crate::auth::AuthenticatedUser { 97 - did: Did::new_unchecked(result.did), 97 + did: unsafe { Did::new_unchecked(result.did) }, 98 98 is_admin: false, 99 99 status: AccountStatus::Active, 100 100 scope: result.scope,
+88 -161
crates/tranquil-pds/src/api/server/session.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::api::{EmptyResponse, SuccessResponse}; 3 - use crate::auth::{Active, Auth, Permissive}; 4 - use crate::state::{AppState, RateLimitKind}; 3 + use crate::auth::{ 4 + Active, Auth, NormalizedLoginIdentifier, Permissive, require_legacy_session_mfa, 5 + require_reauth_window, 6 + }; 7 + use crate::rate_limit::{LoginLimit, RateLimited, RefreshSessionLimit}; 8 + use crate::state::AppState; 5 9 use crate::types::{AccountState, Did, Handle, PlainPassword}; 10 + use crate::util::{pds_hostname, pds_hostname_without_port}; 6 11 use axum::{ 7 12 Json, 8 13 extract::State, ··· 13 18 use serde::{Deserialize, Serialize}; 14 19 use serde_json::json; 15 20 use tracing::{error, info, warn}; 21 + use tranquil_db_traits::{SessionId, TokenFamilyId}; 16 22 use tranquil_types::TokenId; 17 23 18 - fn extract_client_ip(headers: &HeaderMap) -> String { 19 - if let Some(forwarded) = headers.get("x-forwarded-for") 20 - && let Ok(value) = forwarded.to_str() 21 - && let Some(first_ip) = value.split(',').next() 22 - { 23 - return first_ip.trim().to_string(); 24 - } 25 - if let Some(real_ip) = headers.get("x-real-ip") 26 - && let Ok(value) = real_ip.to_str() 27 - { 28 - return value.trim().to_string(); 29 - } 30 - "unknown".to_string() 31 - } 32 - 33 - fn normalize_handle(identifier: &str, pds_hostname: &str) -> String { 34 - let identifier = identifier.trim(); 35 - if identifier.contains('@') || identifier.starts_with("did:") { 36 - identifier.to_string() 37 - } else if !identifier.contains('.') { 38 - format!("{}.{}", identifier.to_lowercase(), pds_hostname) 39 - } else { 40 - identifier.to_lowercase() 41 - } 42 - } 43 - 44 24 fn full_handle(stored_handle: &str, _pds_hostname: &str) -> String { 45 25 stored_handle.to_string() 46 26 } ··· 75 55 76 56 pub async fn create_session( 77 57 State(state): State<AppState>, 78 - headers: HeaderMap, 58 + rate_limit: RateLimited<LoginLimit>, 79 59 Json(input): Json<CreateSessionInput>, 80 60 ) -> Response { 61 + let client_ip = rate_limit.client_ip(); 81 62 info!( 82 63 "create_session called with identifier: {}", 83 64 input.identifier 84 65 ); 85 - let client_ip = extract_client_ip(&headers); 86 - if !state 87 - .check_rate_limit(RateLimitKind::Login, &client_ip) 88 - .await 89 - { 90 - warn!(ip = %client_ip, "Login rate limit exceeded"); 91 - return ApiError::RateLimitExceeded(None).into_response(); 92 - } 93 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 94 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 95 - let normalized_identifier = normalize_handle(&input.identifier, hostname_for_handles); 66 + let pds_host = pds_hostname(); 67 + let hostname_for_handles = pds_hostname_without_port(); 68 + let normalized_identifier = 69 + NormalizedLoginIdentifier::normalize(&input.identifier, hostname_for_handles); 96 70 info!( 97 71 "Normalized identifier: {} -> {}", 98 72 input.identifier, normalized_identifier 99 73 ); 100 74 let row = match state 101 75 .user_repo 102 - .get_login_full_by_identifier(&normalized_identifier) 76 + .get_login_full_by_identifier(normalized_identifier.as_str()) 103 77 .await 104 78 { 105 79 Ok(Some(row)) => row, ··· 165 139 warn!("Login attempt for takendown account: {}", row.did); 166 140 return ApiError::AccountTakedown.into_response(); 167 141 } 168 - let is_verified = 169 - row.email_verified || row.discord_verified || row.telegram_verified || row.signal_verified; 142 + let is_verified = row.channel_verification.has_any_verified(); 170 143 let is_delegated = state 171 144 .delegation_repo 172 145 .is_delegated_account(&row.did) ··· 226 199 refresh_jti: refresh_meta.jti.clone(), 227 200 access_expires_at: access_meta.expires_at, 228 201 refresh_expires_at: refresh_meta.expires_at, 229 - legacy_login: is_legacy_login, 202 + login_type: tranquil_db_traits::LoginType::from(is_legacy_login), 230 203 mfa_verified: false, 231 204 scope: app_password_scopes.clone(), 232 205 controller_did: app_password_controller.clone(), ··· 246 219 ip = %client_ip, 247 220 "Legacy login on TOTP-enabled account - sending notification" 248 221 ); 249 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 222 + let hostname = pds_hostname(); 250 223 if let Err(e) = crate::comms::comms_repo::enqueue_legacy_login( 251 224 state.user_repo.as_ref(), 252 225 state.infra_repo.as_ref(), 253 226 row.id, 254 - &hostname, 255 - &client_ip, 227 + hostname, 228 + client_ip, 256 229 row.preferred_comms_channel, 257 230 ) 258 231 .await ··· 260 233 error!("Failed to queue legacy login notification: {:?}", e); 261 234 } 262 235 } 263 - let handle = full_handle(&row.handle, &pds_hostname); 236 + let handle = full_handle(&row.handle, pds_host); 264 237 let is_active = account_state.is_active(); 265 238 let status = account_state.status_for_session().map(String::from); 266 239 Json(CreateSessionOutput { ··· 270 243 did: row.did, 271 244 did_doc, 272 245 email: row.email, 273 - email_confirmed: Some(row.email_verified), 246 + email_confirmed: Some(row.channel_verification.email), 274 247 active: Some(is_active), 275 248 status, 276 249 }) ··· 292 265 ); 293 266 match db_result { 294 267 Ok(Some(row)) => { 295 - let (preferred_channel, preferred_channel_verified) = match row.preferred_comms_channel 296 - { 297 - tranquil_db_traits::CommsChannel::Email => ("email", row.email_verified), 298 - tranquil_db_traits::CommsChannel::Discord => ("discord", row.discord_verified), 299 - tranquil_db_traits::CommsChannel::Telegram => ("telegram", row.telegram_verified), 300 - tranquil_db_traits::CommsChannel::Signal => ("signal", row.signal_verified), 268 + let preferred_channel = match row.preferred_comms_channel { 269 + tranquil_db_traits::CommsChannel::Email => "email", 270 + tranquil_db_traits::CommsChannel::Discord => "discord", 271 + tranquil_db_traits::CommsChannel::Telegram => "telegram", 272 + tranquil_db_traits::CommsChannel::Signal => "signal", 301 273 }; 302 - let pds_hostname = 303 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 304 - let handle = full_handle(&row.handle, &pds_hostname); 274 + let preferred_channel_verified = row 275 + .channel_verification 276 + .is_verified(row.preferred_comms_channel); 277 + let pds_hostname = pds_hostname(); 278 + let handle = full_handle(&row.handle, pds_hostname); 305 279 let account_state = AccountState::from_db_fields( 306 280 row.deactivated_at, 307 281 row.takedown_ref.clone(), ··· 313 287 } else { 314 288 None 315 289 }; 316 - let email_confirmed_value = can_read_email && row.email_verified; 290 + let email_confirmed_value = can_read_email && row.channel_verification.email; 317 291 let mut response = json!({ 318 292 "handle": handle, 319 293 "did": &auth.did, ··· 352 326 headers: axum::http::HeaderMap, 353 327 _auth: Auth<Active>, 354 328 ) -> Result<Response, ApiError> { 355 - let extracted = crate::auth::extract_auth_token_from_header( 356 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 357 - ) 329 + let extracted = crate::auth::extract_auth_token_from_header(crate::util::get_header_str( 330 + &headers, 331 + "Authorization", 332 + )) 358 333 .ok_or(ApiError::AuthenticationRequired)?; 359 334 let jti = crate::auth::get_jti_from_token(&extracted.token) 360 335 .map_err(|_| ApiError::AuthenticationFailed(None))?; ··· 374 349 375 350 pub async fn refresh_session( 376 351 State(state): State<AppState>, 352 + _rate_limit: RateLimited<RefreshSessionLimit>, 377 353 headers: axum::http::HeaderMap, 378 354 ) -> Response { 379 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 380 - if !state 381 - .check_rate_limit(RateLimitKind::RefreshSession, &client_ip) 382 - .await 383 - { 384 - tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded"); 385 - return ApiError::RateLimitExceeded(None).into_response(); 386 - } 387 - let extracted = match crate::auth::extract_auth_token_from_header( 388 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 389 - ) { 355 + let extracted = match crate::auth::extract_auth_token_from_header(crate::util::get_header_str( 356 + &headers, 357 + "Authorization", 358 + )) { 390 359 Some(t) => t, 391 360 None => return ApiError::AuthenticationRequired.into_response(), 392 361 }; ··· 503 472 ); 504 473 match db_result { 505 474 Ok(Some(u)) => { 506 - let (preferred_channel, preferred_channel_verified) = match u.preferred_comms_channel { 507 - tranquil_db_traits::CommsChannel::Email => ("email", u.email_verified), 508 - tranquil_db_traits::CommsChannel::Discord => ("discord", u.discord_verified), 509 - tranquil_db_traits::CommsChannel::Telegram => ("telegram", u.telegram_verified), 510 - tranquil_db_traits::CommsChannel::Signal => ("signal", u.signal_verified), 475 + let preferred_channel = match u.preferred_comms_channel { 476 + tranquil_db_traits::CommsChannel::Email => "email", 477 + tranquil_db_traits::CommsChannel::Discord => "discord", 478 + tranquil_db_traits::CommsChannel::Telegram => "telegram", 479 + tranquil_db_traits::CommsChannel::Signal => "signal", 511 480 }; 512 - let pds_hostname = 513 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 514 - let handle = full_handle(&u.handle, &pds_hostname); 481 + let preferred_channel_verified = u 482 + .channel_verification 483 + .is_verified(u.preferred_comms_channel); 484 + let pds_hostname = pds_hostname(); 485 + let handle = full_handle(&u.handle, pds_hostname); 515 486 let account_state = 516 487 AccountState::from_db_fields(u.deactivated_at, u.takedown_ref.clone(), None, None); 517 488 let mut response = json!({ ··· 520 491 "handle": handle, 521 492 "did": session_row.did, 522 493 "email": u.email, 523 - "emailConfirmed": u.email_verified, 494 + "emailConfirmed": u.channel_verification.email, 524 495 "preferredChannel": preferred_channel, 525 496 "preferredChannelVerified": preferred_channel_verified, 526 497 "preferredLocale": u.preferred_locale, ··· 664 635 refresh_jti: refresh_meta.jti.clone(), 665 636 access_expires_at: access_meta.expires_at, 666 637 refresh_expires_at: refresh_meta.expires_at, 667 - legacy_login: false, 638 + login_type: tranquil_db_traits::LoginType::Modern, 668 639 mfa_verified: false, 669 640 scope: None, 670 641 controller_did: None, ··· 675 646 return ApiError::InternalError(None).into_response(); 676 647 } 677 648 678 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 649 + let hostname = pds_hostname(); 679 650 if let Err(e) = crate::comms::comms_repo::enqueue_welcome( 680 651 state.user_repo.as_ref(), 681 652 state.infra_repo.as_ref(), 682 653 row.id, 683 - &hostname, 654 + hostname, 684 655 ) 685 656 .await 686 657 { ··· 731 702 return ApiError::InternalError(None).into_response(); 732 703 } 733 704 }; 734 - let is_verified = 735 - row.email_verified || row.discord_verified || row.telegram_verified || row.signal_verified; 705 + let is_verified = row.channel_verification.has_any_verified(); 736 706 if is_verified { 737 707 return ApiError::InvalidRequest("Account is already verified".into()).into_response(); 738 708 } ··· 756 726 let formatted_token = 757 727 crate::auth::verification_token::format_token_for_display(&verification_token); 758 728 759 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 729 + let hostname = pds_hostname(); 760 730 if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification( 761 731 state.infra_repo.as_ref(), 762 732 row.id, 763 733 channel_str, 764 734 &recipient, 765 735 &formatted_token, 766 - &hostname, 736 + hostname, 767 737 ) 768 738 .await 769 739 { ··· 804 774 .session_repo 805 775 .list_sessions_by_did(&auth.did) 806 776 .await 807 - .map_err(|e| { 808 - error!("DB error fetching JWT sessions: {:?}", e); 809 - ApiError::InternalError(None) 810 - })?; 777 + .log_db_err("fetching JWT sessions")?; 811 778 812 779 let oauth_rows = state 813 780 .oauth_repo 814 781 .list_sessions_by_did(&auth.did) 815 782 .await 816 - .map_err(|e| { 817 - error!("DB error fetching OAuth sessions: {:?}", e); 818 - ApiError::InternalError(None) 819 - })?; 783 + .log_db_err("fetching OAuth sessions")?; 820 784 821 785 let jwt_sessions = jwt_rows.into_iter().map(|row| SessionInfo { 822 786 id: format!("jwt:{}", row.id), ··· 869 833 Json(input): Json<RevokeSessionInput>, 870 834 ) -> Result<Response, ApiError> { 871 835 if let Some(jwt_id) = input.session_id.strip_prefix("jwt:") { 872 - let session_id: i32 = jwt_id 873 - .parse() 836 + let session_id = jwt_id 837 + .parse::<i32>() 838 + .map(SessionId::new) 874 839 .map_err(|_| ApiError::InvalidRequest("Invalid session ID".into()))?; 875 840 let access_jti = state 876 841 .session_repo 877 842 .get_session_access_jti_by_id(session_id, &auth.did) 878 843 .await 879 - .map_err(|e| { 880 - error!("DB error in revoke_session: {:?}", e); 881 - ApiError::InternalError(None) 882 - })? 844 + .log_db_err("in revoke_session")? 883 845 .ok_or(ApiError::SessionNotFound)?; 884 846 state 885 847 .session_repo 886 848 .delete_session_by_id(session_id) 887 849 .await 888 - .map_err(|e| { 889 - error!("DB error deleting session: {:?}", e); 890 - ApiError::InternalError(None) 891 - })?; 850 + .log_db_err("deleting session")?; 892 851 let cache_key = format!("auth:session:{}:{}", &auth.did, access_jti); 893 852 if let Err(e) = state.cache.delete(&cache_key).await { 894 853 warn!("Failed to invalidate session cache: {:?}", e); 895 854 } 896 855 info!(did = %&auth.did, session_id = %session_id, "JWT session revoked"); 897 856 } else if let Some(oauth_id) = input.session_id.strip_prefix("oauth:") { 898 - let session_id: i32 = oauth_id 899 - .parse() 857 + let session_id = oauth_id 858 + .parse::<i32>() 859 + .map(TokenFamilyId::new) 900 860 .map_err(|_| ApiError::InvalidRequest("Invalid session ID".into()))?; 901 861 let deleted = state 902 862 .oauth_repo 903 863 .delete_session_by_id(session_id, &auth.did) 904 864 .await 905 - .map_err(|e| { 906 - error!("DB error deleting OAuth session: {:?}", e); 907 - ApiError::InternalError(None) 908 - })?; 865 + .log_db_err("deleting OAuth session")?; 909 866 if deleted == 0 { 910 867 return Err(ApiError::SessionNotFound); 911 868 } ··· 932 889 .session_repo 933 890 .delete_sessions_by_did(&auth.did) 934 891 .await 935 - .map_err(|e| { 936 - error!("DB error revoking JWT sessions: {:?}", e); 937 - ApiError::InternalError(None) 938 - })?; 892 + .log_db_err("revoking JWT sessions")?; 939 893 let jti_typed = TokenId::from(jti.clone()); 940 894 state 941 895 .oauth_repo 942 896 .delete_sessions_by_did_except(&auth.did, &jti_typed) 943 897 .await 944 - .map_err(|e| { 945 - error!("DB error revoking OAuth sessions: {:?}", e); 946 - ApiError::InternalError(None) 947 - })?; 898 + .log_db_err("revoking OAuth sessions")?; 948 899 } else { 949 900 state 950 901 .session_repo 951 902 .delete_sessions_by_did_except_jti(&auth.did, &jti) 952 903 .await 953 - .map_err(|e| { 954 - error!("DB error revoking JWT sessions: {:?}", e); 955 - ApiError::InternalError(None) 956 - })?; 904 + .log_db_err("revoking JWT sessions")?; 957 905 state 958 906 .oauth_repo 959 907 .delete_sessions_by_did(&auth.did) 960 908 .await 961 - .map_err(|e| { 962 - error!("DB error revoking OAuth sessions: {:?}", e); 963 - ApiError::InternalError(None) 964 - })?; 909 + .log_db_err("revoking OAuth sessions")?; 965 910 } 966 911 967 912 info!(did = %&auth.did, "All other sessions revoked"); ··· 983 928 .user_repo 984 929 .get_legacy_login_pref(&auth.did) 985 930 .await 986 - .map_err(|e| { 987 - error!("DB error: {:?}", e); 988 - ApiError::InternalError(None) 989 - })? 931 + .log_db_err("getting legacy login pref")? 990 932 .ok_or(ApiError::AccountNotFound)?; 991 933 Ok(Json(LegacyLoginPreferenceOutput { 992 934 allow_legacy_login: pref.allow_legacy_login, ··· 1006 948 auth: Auth<Active>, 1007 949 Json(input): Json<UpdateLegacyLoginInput>, 1008 950 ) -> Result<Response, ApiError> { 1009 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 1010 - { 1011 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 1012 - &*state.user_repo, 1013 - &*state.session_repo, 1014 - &auth.did, 1015 - ) 1016 - .await); 1017 - } 951 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 952 + Ok(proof) => proof, 953 + Err(response) => return Ok(response), 954 + }; 1018 955 1019 - if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await { 1020 - return Ok(crate::api::server::reauth::reauth_required_response( 1021 - &*state.user_repo, 1022 - &*state.session_repo, 1023 - &auth.did, 1024 - ) 1025 - .await); 1026 - } 956 + let reauth_mfa = match require_reauth_window(&state, &auth).await { 957 + Ok(proof) => proof, 958 + Err(response) => return Ok(response), 959 + }; 1027 960 1028 961 let updated = state 1029 962 .user_repo 1030 - .update_legacy_login(&auth.did, input.allow_legacy_login) 963 + .update_legacy_login(reauth_mfa.did(), input.allow_legacy_login) 1031 964 .await 1032 - .map_err(|e| { 1033 - error!("DB error: {:?}", e); 1034 - ApiError::InternalError(None) 1035 - })?; 965 + .log_db_err("updating legacy login")?; 1036 966 if !updated { 1037 967 return Err(ApiError::AccountNotFound); 1038 968 } 1039 969 info!( 1040 - did = %&auth.did, 970 + did = %session_mfa.did(), 1041 971 allow_legacy_login = input.allow_legacy_login, 1042 972 "Legacy login preference updated" 1043 973 ); ··· 1071 1001 .user_repo 1072 1002 .update_locale(&auth.did, &input.preferred_locale) 1073 1003 .await 1074 - .map_err(|e| { 1075 - error!("DB error updating locale: {:?}", e); 1076 - ApiError::InternalError(None) 1077 - })?; 1004 + .log_db_err("updating locale")?; 1078 1005 if !updated { 1079 1006 return Err(ApiError::AccountNotFound); 1080 1007 }
+68 -166
crates/tranquil-pds/src/api/server/totp.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 3 - use crate::auth::{Active, Auth}; 2 + use crate::api::error::{ApiError, DbResultExt}; 4 3 use crate::auth::{ 5 - decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, generate_qr_png_base64, 6 - generate_totp_secret, generate_totp_uri, hash_backup_code, is_backup_code_format, 7 - verify_backup_code, verify_totp_code, 4 + Active, Auth, decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, 5 + generate_qr_png_base64, generate_totp_secret, generate_totp_uri, hash_backup_code, 6 + is_backup_code_format, require_legacy_session_mfa, verify_backup_code, verify_password_mfa, 7 + verify_totp_code, verify_totp_mfa, 8 8 }; 9 - use crate::state::{AppState, RateLimitKind}; 9 + use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message}; 10 + use crate::state::AppState; 10 11 use crate::types::PlainPassword; 12 + use crate::util::pds_hostname; 11 13 use axum::{ 12 14 Json, 13 15 extract::State, ··· 30 32 State(state): State<AppState>, 31 33 auth: Auth<Active>, 32 34 ) -> Result<Response, ApiError> { 33 - match state.user_repo.get_totp_record(&auth.did).await { 34 - Ok(Some(record)) if record.verified => return Err(ApiError::TotpAlreadyEnabled), 35 - Ok(_) => {} 35 + use tranquil_db_traits::TotpRecordState; 36 + 37 + match state.user_repo.get_totp_record_state(&auth.did).await { 38 + Ok(Some(TotpRecordState::Verified(_))) => return Err(ApiError::TotpAlreadyEnabled), 39 + Ok(Some(TotpRecordState::Unverified(_))) | Ok(None) => {} 36 40 Err(e) => { 37 41 error!("DB error checking TOTP: {:?}", e); 38 42 return Err(ApiError::InternalError(None)); ··· 45 49 .user_repo 46 50 .get_handle_by_did(&auth.did) 47 51 .await 48 - .map_err(|e| { 49 - error!("DB error fetching handle: {:?}", e); 50 - ApiError::InternalError(None) 51 - })? 52 + .log_db_err("fetching handle")? 52 53 .ok_or(ApiError::AccountNotFound)?; 53 54 54 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 55 - let uri = generate_totp_uri(&secret, &handle, &hostname); 55 + let hostname = pds_hostname(); 56 + let uri = generate_totp_uri(&secret, &handle, hostname); 56 57 57 - let qr_code = generate_qr_png_base64(&secret, &handle, &hostname).map_err(|e| { 58 + let qr_code = generate_qr_png_base64(&secret, &handle, hostname).map_err(|e| { 58 59 error!("Failed to generate QR code: {:?}", e); 59 60 ApiError::InternalError(Some("Failed to generate QR code".into())) 60 61 })?; ··· 68 69 .user_repo 69 70 .upsert_totp_secret(&auth.did, &encrypted_secret, ENCRYPTION_VERSION) 70 71 .await 71 - .map_err(|e| { 72 - error!("Failed to store TOTP secret: {:?}", e); 73 - ApiError::InternalError(None) 74 - })?; 72 + .log_db_err("storing TOTP secret")?; 75 73 76 74 let secret_base32 = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &secret); 77 75 ··· 101 99 auth: Auth<Active>, 102 100 Json(input): Json<EnableTotpInput>, 103 101 ) -> Result<Response, ApiError> { 104 - if !state 105 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 106 - .await 107 - { 108 - warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 109 - return Err(ApiError::RateLimitExceeded(None)); 110 - } 102 + use tranquil_db_traits::TotpRecordState; 111 103 112 - let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 113 - Ok(Some(row)) => row, 104 + let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>( 105 + &state, 106 + &auth.did, 107 + "Too many verification attempts. Please try again in a few minutes.", 108 + ) 109 + .await?; 110 + 111 + let unverified_record = match state.user_repo.get_totp_record_state(&auth.did).await { 112 + Ok(Some(TotpRecordState::Unverified(record))) => record, 113 + Ok(Some(TotpRecordState::Verified(_))) => return Err(ApiError::TotpAlreadyEnabled), 114 114 Ok(None) => return Err(ApiError::TotpNotEnabled), 115 115 Err(e) => { 116 116 error!("DB error fetching TOTP: {:?}", e); ··· 118 118 } 119 119 }; 120 120 121 - if totp_record.verified { 122 - return Err(ApiError::TotpAlreadyEnabled); 123 - } 124 - 125 121 let secret = decrypt_totp_secret( 126 - &totp_record.secret_encrypted, 127 - totp_record.encryption_version, 122 + &unverified_record.secret_encrypted, 123 + unverified_record.encryption_version, 128 124 ) 129 125 .map_err(|e| { 130 126 error!("Failed to decrypt TOTP secret: {:?}", e); ··· 152 148 .user_repo 153 149 .enable_totp_with_backup_codes(&auth.did, &backup_hashes) 154 150 .await 155 - .map_err(|e| { 156 - error!("Failed to enable TOTP: {:?}", e); 157 - ApiError::InternalError(None) 158 - })?; 151 + .log_db_err("enabling TOTP")?; 159 152 160 153 info!(did = %&auth.did, "TOTP enabled with {} backup codes", backup_codes.len()); 161 154 ··· 173 166 auth: Auth<Active>, 174 167 Json(input): Json<DisableTotpInput>, 175 168 ) -> Result<Response, ApiError> { 176 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 177 - { 178 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 179 - &*state.user_repo, 180 - &*state.session_repo, 181 - &auth.did, 182 - ) 183 - .await); 184 - } 185 - 186 - if !state 187 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 188 - .await 189 - { 190 - warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 191 - return Err(ApiError::RateLimitExceeded(None)); 192 - } 193 - 194 - let password_hash = state 195 - .user_repo 196 - .get_password_hash_by_did(&auth.did) 197 - .await 198 - .map_err(|e| { 199 - error!("DB error fetching user: {:?}", e); 200 - ApiError::InternalError(None) 201 - })? 202 - .ok_or(ApiError::AccountNotFound)?; 203 - 204 - let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 205 - if !password_valid { 206 - return Err(ApiError::InvalidPassword("Password is incorrect".into())); 207 - } 208 - 209 - let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 210 - Ok(Some(row)) if row.verified => row, 211 - Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled), 212 - Err(e) => { 213 - error!("DB error fetching TOTP: {:?}", e); 214 - return Err(ApiError::InternalError(None)); 215 - } 169 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 170 + Ok(proof) => proof, 171 + Err(response) => return Ok(response), 216 172 }; 217 173 218 - let code = input.code.trim(); 219 - let code_valid = if is_backup_code_format(code) { 220 - verify_backup_code_for_user(&state, &auth.did, code).await 221 - } else { 222 - let secret = decrypt_totp_secret( 223 - &totp_record.secret_encrypted, 224 - totp_record.encryption_version, 225 - ) 226 - .map_err(|e| { 227 - error!("Failed to decrypt TOTP secret: {:?}", e); 228 - ApiError::InternalError(None) 229 - })?; 230 - verify_totp_code(&secret, code) 231 - }; 174 + let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>( 175 + &state, 176 + session_mfa.did(), 177 + "Too many verification attempts. Please try again in a few minutes.", 178 + ) 179 + .await?; 232 180 233 - if !code_valid { 234 - return Err(ApiError::InvalidCode(Some( 235 - "Invalid verification code".into(), 236 - ))); 237 - } 181 + let password_mfa = verify_password_mfa(&state, &auth, &input.password).await?; 182 + let totp_mfa = verify_totp_mfa(&state, &auth, &input.code).await?; 238 183 239 184 state 240 185 .user_repo 241 - .delete_totp_and_backup_codes(&auth.did) 186 + .delete_totp_and_backup_codes(totp_mfa.did()) 242 187 .await 243 - .map_err(|e| { 244 - error!("Failed to delete TOTP: {:?}", e); 245 - ApiError::InternalError(None) 246 - })?; 188 + .log_db_err("deleting TOTP")?; 247 189 248 - info!(did = %&auth.did, "TOTP disabled"); 190 + info!(did = %session_mfa.did(), "TOTP disabled (verified via {} and {})", password_mfa.method(), totp_mfa.method()); 249 191 250 192 Ok(EmptyResponse::ok().into_response()) 251 193 } ··· 262 204 State(state): State<AppState>, 263 205 auth: Auth<Active>, 264 206 ) -> Result<Response, ApiError> { 265 - let enabled = match state.user_repo.get_totp_record(&auth.did).await { 266 - Ok(Some(row)) => row.verified, 267 - Ok(None) => false, 207 + use tranquil_db_traits::TotpRecordState; 208 + 209 + let enabled = match state.user_repo.get_totp_record_state(&auth.did).await { 210 + Ok(Some(TotpRecordState::Verified(_))) => true, 211 + Ok(Some(TotpRecordState::Unverified(_))) | Ok(None) => false, 268 212 Err(e) => { 269 213 error!("DB error fetching TOTP status: {:?}", e); 270 214 return Err(ApiError::InternalError(None)); ··· 275 219 .user_repo 276 220 .count_unused_backup_codes(&auth.did) 277 221 .await 278 - .map_err(|e| { 279 - error!("DB error counting backup codes: {:?}", e); 280 - ApiError::InternalError(None) 281 - })?; 222 + .log_db_err("counting backup codes")?; 282 223 283 224 Ok(Json(GetTotpStatusResponse { 284 225 enabled, ··· 305 246 auth: Auth<Active>, 306 247 Json(input): Json<RegenerateBackupCodesInput>, 307 248 ) -> Result<Response, ApiError> { 308 - if !state 309 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 310 - .await 311 - { 312 - warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 313 - return Err(ApiError::RateLimitExceeded(None)); 314 - } 315 - 316 - let password_hash = state 317 - .user_repo 318 - .get_password_hash_by_did(&auth.did) 319 - .await 320 - .map_err(|e| { 321 - error!("DB error fetching user: {:?}", e); 322 - ApiError::InternalError(None) 323 - })? 324 - .ok_or(ApiError::AccountNotFound)?; 325 - 326 - let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 327 - if !password_valid { 328 - return Err(ApiError::InvalidPassword("Password is incorrect".into())); 329 - } 330 - 331 - let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 332 - Ok(Some(row)) if row.verified => row, 333 - Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled), 334 - Err(e) => { 335 - error!("DB error fetching TOTP: {:?}", e); 336 - return Err(ApiError::InternalError(None)); 337 - } 338 - }; 339 - 340 - let secret = decrypt_totp_secret( 341 - &totp_record.secret_encrypted, 342 - totp_record.encryption_version, 249 + let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>( 250 + &state, 251 + &auth.did, 252 + "Too many verification attempts. Please try again in a few minutes.", 343 253 ) 344 - .map_err(|e| { 345 - error!("Failed to decrypt TOTP secret: {:?}", e); 346 - ApiError::InternalError(None) 347 - })?; 254 + .await?; 348 255 349 - let code = input.code.trim(); 350 - if !verify_totp_code(&secret, code) { 351 - return Err(ApiError::InvalidCode(Some( 352 - "Invalid verification code".into(), 353 - ))); 354 - } 256 + let password_mfa = verify_password_mfa(&state, &auth, &input.password).await?; 257 + let totp_mfa = verify_totp_mfa(&state, &auth, &input.code).await?; 355 258 356 259 let backup_codes = generate_backup_codes(); 357 260 let backup_hashes: Vec<_> = backup_codes ··· 365 268 366 269 state 367 270 .user_repo 368 - .replace_backup_codes(&auth.did, &backup_hashes) 271 + .replace_backup_codes(totp_mfa.did(), &backup_hashes) 369 272 .await 370 - .map_err(|e| { 371 - error!("Failed to regenerate backup codes: {:?}", e); 372 - ApiError::InternalError(None) 373 - })?; 273 + .log_db_err("replacing backup codes")?; 374 274 375 - info!(did = %&auth.did, "Backup codes regenerated"); 275 + info!(did = %password_mfa.did(), "Backup codes regenerated (verified via {} and {})", password_mfa.method(), totp_mfa.method()); 376 276 377 277 Ok(Json(RegenerateBackupCodesResponse { backup_codes }).into_response()) 378 278 } ··· 410 310 did: &crate::types::Did, 411 311 code: &str, 412 312 ) -> bool { 313 + use tranquil_db_traits::TotpRecordState; 314 + 413 315 let code = code.trim(); 414 316 415 317 if is_backup_code_format(code) { 416 318 return verify_backup_code_for_user(state, did, code).await; 417 319 } 418 320 419 - let totp_record = match state.user_repo.get_totp_record(did).await { 420 - Ok(Some(row)) if row.verified => row, 321 + let verified_record = match state.user_repo.get_totp_record_state(did).await { 322 + Ok(Some(TotpRecordState::Verified(record))) => record, 421 323 _ => return false, 422 324 }; 423 325 424 326 let secret = match decrypt_totp_secret( 425 - &totp_record.secret_encrypted, 426 - totp_record.encryption_version, 327 + &verified_record.secret_encrypted, 328 + verified_record.encryption_version, 427 329 ) { 428 330 Ok(s) => s, 429 331 Err(_) => return false,
+4 -13
crates/tranquil-pds/src/api/server/trusted_devices.rs
··· 1 1 use crate::api::SuccessResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use axum::{ 4 4 Json, 5 5 extract::State, ··· 79 79 .oauth_repo 80 80 .list_trusted_devices(&auth.did) 81 81 .await 82 - .map_err(|e| { 83 - error!("DB error: {:?}", e); 84 - ApiError::InternalError(None) 85 - })?; 82 + .log_db_err("listing trusted devices")?; 86 83 87 84 let devices = rows 88 85 .into_iter() ··· 134 131 .oauth_repo 135 132 .revoke_device_trust(&device_id) 136 133 .await 137 - .map_err(|e| { 138 - error!("DB error: {:?}", e); 139 - ApiError::InternalError(None) 140 - })?; 134 + .log_db_err("revoking device trust")?; 141 135 142 136 info!(did = %&auth.did, device_id = %input.device_id, "Trusted device revoked"); 143 137 Ok(SuccessResponse::ok().into_response()) ··· 175 169 .oauth_repo 176 170 .update_device_friendly_name(&device_id, input.friendly_name.as_deref()) 177 171 .await 178 - .map_err(|e| { 179 - error!("DB error: {:?}", e); 180 - ApiError::InternalError(None) 181 - })?; 172 + .log_db_err("updating device friendly name")?; 182 173 183 174 info!(did = %auth.did, device_id = %input.device_id, "Trusted device updated"); 184 175 Ok(SuccessResponse::ok().into_response())
+3 -2
crates/tranquil-pds/src/api/server/verify_email.rs
··· 5 5 use tracing::{info, warn}; 6 6 7 7 use crate::state::AppState; 8 + use crate::util::pds_hostname; 8 9 9 10 #[derive(Deserialize)] 10 11 #[serde(rename_all = "camelCase")] ··· 70 71 return Ok(Json(ResendMigrationVerificationOutput { sent: true })); 71 72 } 72 73 73 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 74 + let hostname = pds_hostname(); 74 75 let token = crate::auth::verification_token::generate_migration_token(&user.did, &email); 75 76 let formatted_token = crate::auth::verification_token::format_token_for_display(&token); 76 77 ··· 80 81 user.id, 81 82 &email, 82 83 &formatted_token, 83 - &hostname, 84 + hostname, 84 85 ) 85 86 .await 86 87 {
+16 -52
crates/tranquil-pds/src/api/server/verify_token.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::types::Did; 3 3 use axum::{Json, extract::State}; 4 4 use serde::{Deserialize, Serialize}; 5 - use tracing::{error, info, warn}; 5 + use tracing::{info, warn}; 6 6 7 7 use crate::auth::verification_token::{ 8 8 VerificationPurpose, normalize_token_input, verify_token_signature, ··· 81 81 .user_repo 82 82 .get_verification_info(&did_typed) 83 83 .await 84 - .map_err(|e| { 85 - warn!(error = ?e, "Database error during migration verification"); 86 - ApiError::InternalError(None) 87 - })? 84 + .log_db_err("during migration verification")? 88 85 .ok_or(ApiError::AccountNotFound)?; 89 86 90 87 if user.email.as_ref().map(|e| e.to_lowercase()) != Some(identifier.to_string()) { 91 88 return Err(ApiError::IdentifierMismatch); 92 89 } 93 90 94 - if !user.email_verified { 91 + if !user.channel_verification.email { 95 92 state 96 93 .user_repo 97 94 .set_email_verified_flag(user.id) 98 95 .await 99 - .map_err(|e| { 100 - warn!(error = ?e, "Failed to update email_verified status"); 101 - ApiError::InternalError(None) 102 - })?; 96 + .log_db_err("updating email_verified status")?; 103 97 } 104 98 105 99 info!(did = %did, "Migration email verified successfully"); ··· 125 119 .user_repo 126 120 .get_id_by_did(&did_typed) 127 121 .await 128 - .map_err(|_| ApiError::InternalError(None))? 122 + .log_db_err("fetching user id")? 129 123 .ok_or(ApiError::AccountNotFound)?; 130 124 131 125 match channel { ··· 134 128 .user_repo 135 129 .verify_email_channel(user_id, identifier) 136 130 .await 137 - .map_err(|e| { 138 - error!("Failed to update email channel: {:?}", e); 139 - ApiError::InternalError(None) 140 - })?; 131 + .log_db_err("updating email channel")?; 141 132 if !success { 142 133 return Err(ApiError::EmailTaken); 143 134 } ··· 147 138 .user_repo 148 139 .verify_discord_channel(user_id, identifier) 149 140 .await 150 - .map_err(|e| { 151 - error!("Failed to update discord channel: {:?}", e); 152 - ApiError::InternalError(None) 153 - })?; 141 + .log_db_err("updating discord channel")?; 154 142 } 155 143 "telegram" => { 156 144 state 157 145 .user_repo 158 146 .verify_telegram_channel(user_id, identifier) 159 147 .await 160 - .map_err(|e| { 161 - error!("Failed to update telegram channel: {:?}", e); 162 - ApiError::InternalError(None) 163 - })?; 148 + .log_db_err("updating telegram channel")?; 164 149 } 165 150 "signal" => { 166 151 state 167 152 .user_repo 168 153 .verify_signal_channel(user_id, identifier) 169 154 .await 170 - .map_err(|e| { 171 - error!("Failed to update signal channel: {:?}", e); 172 - ApiError::InternalError(None) 173 - })?; 155 + .log_db_err("updating signal channel")?; 174 156 } 175 157 _ => { 176 158 return Err(ApiError::InvalidChannel); ··· 200 182 .user_repo 201 183 .get_verification_info(&did_typed) 202 184 .await 203 - .map_err(|e| { 204 - warn!(error = ?e, "Database error during signup verification"); 205 - ApiError::InternalError(None) 206 - })? 185 + .log_db_err("during signup verification")? 207 186 .ok_or(ApiError::AccountNotFound)?; 208 187 209 - let is_verified = user.email_verified 210 - || user.discord_verified 211 - || user.telegram_verified 212 - || user.signal_verified; 188 + let is_verified = user.channel_verification.has_any_verified(); 213 189 if is_verified { 214 190 info!(did = %did, "Account already verified"); 215 191 return Ok(Json(VerifyTokenOutput { ··· 226 202 .user_repo 227 203 .set_email_verified_flag(user.id) 228 204 .await 229 - .map_err(|e| { 230 - warn!(error = ?e, "Failed to update email verified status"); 231 - ApiError::InternalError(None) 232 - })?; 205 + .log_db_err("updating email verified status")?; 233 206 } 234 207 "discord" => { 235 208 state 236 209 .user_repo 237 210 .set_discord_verified_flag(user.id) 238 211 .await 239 - .map_err(|e| { 240 - warn!(error = ?e, "Failed to update discord verified status"); 241 - ApiError::InternalError(None) 242 - })?; 212 + .log_db_err("updating discord verified status")?; 243 213 } 244 214 "telegram" => { 245 215 state 246 216 .user_repo 247 217 .set_telegram_verified_flag(user.id) 248 218 .await 249 - .map_err(|e| { 250 - warn!(error = ?e, "Failed to update telegram verified status"); 251 - ApiError::InternalError(None) 252 - })?; 219 + .log_db_err("updating telegram verified status")?; 253 220 } 254 221 "signal" => { 255 222 state 256 223 .user_repo 257 224 .set_signal_verified_flag(user.id) 258 225 .await 259 - .map_err(|e| { 260 - warn!(error = ?e, "Failed to update signal verified status"); 261 - ApiError::InternalError(None) 262 - })?; 226 + .log_db_err("updating signal verified status")?; 263 227 } 264 228 _ => { 265 229 return Err(ApiError::InvalidChannel);
+61
crates/tranquil-pds/src/auth/account_verified.rs
··· 1 + use axum::response::{IntoResponse, Response}; 2 + 3 + use super::AuthenticatedUser; 4 + use crate::api::error::ApiError; 5 + use crate::state::AppState; 6 + use crate::types::Did; 7 + 8 + pub struct AccountVerified<'a> { 9 + user: &'a AuthenticatedUser, 10 + } 11 + 12 + impl<'a> AccountVerified<'a> { 13 + pub fn did(&self) -> &Did { 14 + &self.user.did 15 + } 16 + 17 + pub fn user(&self) -> &AuthenticatedUser { 18 + self.user 19 + } 20 + } 21 + 22 + pub async fn require_verified_or_delegated<'a>( 23 + state: &AppState, 24 + user: &'a AuthenticatedUser, 25 + ) -> Result<AccountVerified<'a>, Response> { 26 + let is_verified = state 27 + .user_repo 28 + .has_verified_comms_channel(&user.did) 29 + .await 30 + .unwrap_or(false); 31 + 32 + if is_verified { 33 + return Ok(AccountVerified { user }); 34 + } 35 + 36 + let is_delegated = state 37 + .delegation_repo 38 + .is_delegated_account(&user.did) 39 + .await 40 + .unwrap_or(false); 41 + 42 + if is_delegated { 43 + return Ok(AccountVerified { user }); 44 + } 45 + 46 + Err(ApiError::AccountNotVerified.into_response()) 47 + } 48 + 49 + pub async fn require_not_migrated(state: &AppState, did: &Did) -> Result<(), Response> { 50 + match state.user_repo.is_account_migrated(did).await { 51 + Ok(true) => Err(ApiError::AccountMigrated.into_response()), 52 + Ok(false) => Ok(()), 53 + Err(e) => { 54 + tracing::error!("Failed to check migration status: {:?}", e); 55 + Err( 56 + ApiError::InternalError(Some("Failed to verify migration status".into())) 57 + .into_response(), 58 + ) 59 + } 60 + } 61 + }
-547
crates/tranquil-pds/src/auth/auth_extractor.rs
··· 1 - mod common; 2 - mod helpers; 3 - 4 - use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 - use chrono::Utc; 6 - use common::{base_url, client, create_account_and_login, pds_endpoint}; 7 - use helpers::verify_new_account; 8 - use reqwest::StatusCode; 9 - use serde_json::{Value, json}; 10 - use sha2::{Digest, Sha256}; 11 - use wiremock::matchers::{method, path}; 12 - use wiremock::{Mock, MockServer, ResponseTemplate}; 13 - 14 - fn generate_pkce() -> (String, String) { 15 - let verifier_bytes: [u8; 32] = rand::random(); 16 - let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 17 - let mut hasher = Sha256::new(); 18 - hasher.update(code_verifier.as_bytes()); 19 - let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize()); 20 - (code_verifier, code_challenge) 21 - } 22 - 23 - async fn setup_mock_client_metadata(redirect_uri: &str, dpop_bound: bool) -> MockServer { 24 - let mock_server = MockServer::start().await; 25 - let metadata = json!({ 26 - "client_id": mock_server.uri(), 27 - "client_name": "Auth Extractor Test Client", 28 - "redirect_uris": [redirect_uri], 29 - "grant_types": ["authorization_code", "refresh_token"], 30 - "response_types": ["code"], 31 - "token_endpoint_auth_method": "none", 32 - "dpop_bound_access_tokens": dpop_bound 33 - }); 34 - Mock::given(method("GET")) 35 - .and(path("/")) 36 - .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 37 - .mount(&mock_server) 38 - .await; 39 - mock_server 40 - } 41 - 42 - async fn get_oauth_session( 43 - http_client: &reqwest::Client, 44 - url: &str, 45 - dpop_bound: bool, 46 - ) -> (String, String, String, String) { 47 - let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 48 - let handle = format!("ae{}", suffix); 49 - let password = "AuthExtract123!"; 50 - let create_res = http_client 51 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 52 - .json(&json!({ 53 - "handle": handle, 54 - "email": format!("{}@example.com", handle), 55 - "password": password 56 - })) 57 - .send() 58 - .await 59 - .unwrap(); 60 - assert_eq!(create_res.status(), StatusCode::OK); 61 - let account: Value = create_res.json().await.unwrap(); 62 - let did = account["did"].as_str().unwrap().to_string(); 63 - verify_new_account(http_client, &did).await; 64 - 65 - let redirect_uri = "https://example.com/auth-callback"; 66 - let mock_client = setup_mock_client_metadata(redirect_uri, dpop_bound).await; 67 - let client_id = mock_client.uri(); 68 - let (code_verifier, code_challenge) = generate_pkce(); 69 - 70 - let par_body: Value = http_client 71 - .post(format!("{}/oauth/par", url)) 72 - .form(&[ 73 - ("response_type", "code"), 74 - ("client_id", &client_id), 75 - ("redirect_uri", redirect_uri), 76 - ("code_challenge", &code_challenge), 77 - ("code_challenge_method", "S256"), 78 - ]) 79 - .send() 80 - .await 81 - .unwrap() 82 - .json() 83 - .await 84 - .unwrap(); 85 - let request_uri = par_body["request_uri"].as_str().unwrap(); 86 - 87 - let auth_res = http_client 88 - .post(format!("{}/oauth/authorize", url)) 89 - .header("Content-Type", "application/json") 90 - .header("Accept", "application/json") 91 - .json(&json!({ 92 - "request_uri": request_uri, 93 - "username": &handle, 94 - "password": password, 95 - "remember_device": false 96 - })) 97 - .send() 98 - .await 99 - .unwrap(); 100 - let auth_body: Value = auth_res.json().await.unwrap(); 101 - let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 102 - 103 - if location.contains("/oauth/consent") { 104 - let consent_res = http_client 105 - .post(format!("{}/oauth/authorize/consent", url)) 106 - .header("Content-Type", "application/json") 107 - .json(&json!({ 108 - "request_uri": request_uri, 109 - "approved_scopes": ["atproto"], 110 - "remember": false 111 - })) 112 - .send() 113 - .await 114 - .unwrap(); 115 - let consent_body: Value = consent_res.json().await.unwrap(); 116 - location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 117 - } 118 - 119 - let code = location 120 - .split("code=") 121 - .nth(1) 122 - .unwrap() 123 - .split('&') 124 - .next() 125 - .unwrap(); 126 - 127 - let token_body: Value = http_client 128 - .post(format!("{}/oauth/token", url)) 129 - .form(&[ 130 - ("grant_type", "authorization_code"), 131 - ("code", code), 132 - ("redirect_uri", redirect_uri), 133 - ("code_verifier", &code_verifier), 134 - ("client_id", &client_id), 135 - ]) 136 - .send() 137 - .await 138 - .unwrap() 139 - .json() 140 - .await 141 - .unwrap(); 142 - 143 - ( 144 - token_body["access_token"].as_str().unwrap().to_string(), 145 - token_body["refresh_token"].as_str().unwrap().to_string(), 146 - client_id, 147 - did, 148 - ) 149 - } 150 - 151 - #[tokio::test] 152 - async fn test_oauth_token_works_with_bearer_auth() { 153 - let url = base_url().await; 154 - let http_client = client(); 155 - let (access_token, _, _, did) = get_oauth_session(&http_client, url, false).await; 156 - 157 - let res = http_client 158 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 159 - .bearer_auth(&access_token) 160 - .send() 161 - .await 162 - .unwrap(); 163 - 164 - assert_eq!(res.status(), StatusCode::OK, "OAuth token should work with RequiredAuth extractor"); 165 - let body: Value = res.json().await.unwrap(); 166 - assert_eq!(body["did"].as_str().unwrap(), did); 167 - } 168 - 169 - #[tokio::test] 170 - async fn test_session_token_still_works() { 171 - let url = base_url().await; 172 - let http_client = client(); 173 - let (jwt, did) = create_account_and_login(&http_client).await; 174 - 175 - let res = http_client 176 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 177 - .bearer_auth(&jwt) 178 - .send() 179 - .await 180 - .unwrap(); 181 - 182 - assert_eq!(res.status(), StatusCode::OK, "Session token should still work"); 183 - let body: Value = res.json().await.unwrap(); 184 - assert_eq!(body["did"].as_str().unwrap(), did); 185 - } 186 - 187 - 188 - #[tokio::test] 189 - async fn test_oauth_admin_extractor_allows_oauth_tokens() { 190 - let url = base_url().await; 191 - let http_client = client(); 192 - 193 - let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 194 - let handle = format!("adm{}", suffix); 195 - let password = "AdminOAuth123!"; 196 - let create_res = http_client 197 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 198 - .json(&json!({ 199 - "handle": handle, 200 - "email": format!("{}@example.com", handle), 201 - "password": password 202 - })) 203 - .send() 204 - .await 205 - .unwrap(); 206 - assert_eq!(create_res.status(), StatusCode::OK); 207 - let account: Value = create_res.json().await.unwrap(); 208 - let did = account["did"].as_str().unwrap().to_string(); 209 - verify_new_account(&http_client, &did).await; 210 - 211 - let pool = common::get_test_db_pool().await; 212 - sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did) 213 - .execute(pool) 214 - .await 215 - .expect("Failed to mark user as admin"); 216 - 217 - let redirect_uri = "https://example.com/admin-callback"; 218 - let mock_client = setup_mock_client_metadata(redirect_uri, false).await; 219 - let client_id = mock_client.uri(); 220 - let (code_verifier, code_challenge) = generate_pkce(); 221 - 222 - let par_body: Value = http_client 223 - .post(format!("{}/oauth/par", url)) 224 - .form(&[ 225 - ("response_type", "code"), 226 - ("client_id", &client_id), 227 - ("redirect_uri", redirect_uri), 228 - ("code_challenge", &code_challenge), 229 - ("code_challenge_method", "S256"), 230 - ]) 231 - .send() 232 - .await 233 - .unwrap() 234 - .json() 235 - .await 236 - .unwrap(); 237 - let request_uri = par_body["request_uri"].as_str().unwrap(); 238 - 239 - let auth_res = http_client 240 - .post(format!("{}/oauth/authorize", url)) 241 - .header("Content-Type", "application/json") 242 - .header("Accept", "application/json") 243 - .json(&json!({ 244 - "request_uri": request_uri, 245 - "username": &handle, 246 - "password": password, 247 - "remember_device": false 248 - })) 249 - .send() 250 - .await 251 - .unwrap(); 252 - let auth_body: Value = auth_res.json().await.unwrap(); 253 - let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 254 - if location.contains("/oauth/consent") { 255 - let consent_res = http_client 256 - .post(format!("{}/oauth/authorize/consent", url)) 257 - .header("Content-Type", "application/json") 258 - .json(&json!({ 259 - "request_uri": request_uri, 260 - "approved_scopes": ["atproto"], 261 - "remember": false 262 - })) 263 - .send() 264 - .await 265 - .unwrap(); 266 - let consent_body: Value = consent_res.json().await.unwrap(); 267 - location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 268 - } 269 - 270 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 271 - let token_body: Value = http_client 272 - .post(format!("{}/oauth/token", url)) 273 - .form(&[ 274 - ("grant_type", "authorization_code"), 275 - ("code", code), 276 - ("redirect_uri", redirect_uri), 277 - ("code_verifier", &code_verifier), 278 - ("client_id", &client_id), 279 - ]) 280 - .send() 281 - .await 282 - .unwrap() 283 - .json() 284 - .await 285 - .unwrap(); 286 - let access_token = token_body["access_token"].as_str().unwrap(); 287 - 288 - let res = http_client 289 - .get(format!("{}/xrpc/com.atproto.admin.getAccountInfos?dids={}", url, did)) 290 - .bearer_auth(access_token) 291 - .send() 292 - .await 293 - .unwrap(); 294 - 295 - assert_eq!( 296 - res.status(), 297 - StatusCode::OK, 298 - "OAuth token for admin user should work with admin endpoint" 299 - ); 300 - } 301 - 302 - #[tokio::test] 303 - async fn test_expired_oauth_token_returns_proper_error() { 304 - let url = base_url().await; 305 - let http_client = client(); 306 - 307 - let now = Utc::now().timestamp(); 308 - let header = json!({"alg": "HS256", "typ": "at+jwt"}); 309 - let payload = json!({ 310 - "iss": url, 311 - "sub": "did:plc:test123", 312 - "aud": url, 313 - "iat": now - 7200, 314 - "exp": now - 3600, 315 - "jti": "expired-token", 316 - "sid": "expired-session", 317 - "scope": "atproto", 318 - "client_id": "https://example.com" 319 - }); 320 - let fake_token = format!( 321 - "{}.{}.{}", 322 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 323 - URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 324 - URL_SAFE_NO_PAD.encode([1u8; 32]) 325 - ); 326 - 327 - let res = http_client 328 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 329 - .bearer_auth(&fake_token) 330 - .send() 331 - .await 332 - .unwrap(); 333 - 334 - assert_eq!( 335 - res.status(), 336 - StatusCode::UNAUTHORIZED, 337 - "Expired token should be rejected" 338 - ); 339 - } 340 - 341 - #[tokio::test] 342 - async fn test_dpop_nonce_error_has_proper_headers() { 343 - let url = base_url().await; 344 - let pds_url = pds_endpoint(); 345 - let http_client = client(); 346 - 347 - let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 348 - let handle = format!("dpop{}", suffix); 349 - let create_res = http_client 350 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 351 - .json(&json!({ 352 - "handle": handle, 353 - "email": format!("{}@test.com", handle), 354 - "password": "DpopTest123!" 355 - })) 356 - .send() 357 - .await 358 - .unwrap(); 359 - assert_eq!(create_res.status(), StatusCode::OK); 360 - let account: Value = create_res.json().await.unwrap(); 361 - let did = account["did"].as_str().unwrap(); 362 - verify_new_account(&http_client, did).await; 363 - 364 - let redirect_uri = "https://example.com/dpop-callback"; 365 - let mock_server = MockServer::start().await; 366 - let client_id = mock_server.uri(); 367 - let metadata = json!({ 368 - "client_id": &client_id, 369 - "client_name": "DPoP Test Client", 370 - "redirect_uris": [redirect_uri], 371 - "grant_types": ["authorization_code", "refresh_token"], 372 - "response_types": ["code"], 373 - "token_endpoint_auth_method": "none", 374 - "dpop_bound_access_tokens": true 375 - }); 376 - Mock::given(method("GET")) 377 - .and(path("/")) 378 - .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 379 - .mount(&mock_server) 380 - .await; 381 - 382 - let (code_verifier, code_challenge) = generate_pkce(); 383 - let par_body: Value = http_client 384 - .post(format!("{}/oauth/par", url)) 385 - .form(&[ 386 - ("response_type", "code"), 387 - ("client_id", &client_id), 388 - ("redirect_uri", redirect_uri), 389 - ("code_challenge", &code_challenge), 390 - ("code_challenge_method", "S256"), 391 - ]) 392 - .send() 393 - .await 394 - .unwrap() 395 - .json() 396 - .await 397 - .unwrap(); 398 - 399 - let request_uri = par_body["request_uri"].as_str().unwrap(); 400 - let auth_res = http_client 401 - .post(format!("{}/oauth/authorize", url)) 402 - .header("Content-Type", "application/json") 403 - .header("Accept", "application/json") 404 - .json(&json!({ 405 - "request_uri": request_uri, 406 - "username": &handle, 407 - "password": "DpopTest123!", 408 - "remember_device": false 409 - })) 410 - .send() 411 - .await 412 - .unwrap(); 413 - let auth_body: Value = auth_res.json().await.unwrap(); 414 - let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 415 - if location.contains("/oauth/consent") { 416 - let consent_res = http_client 417 - .post(format!("{}/oauth/authorize/consent", url)) 418 - .header("Content-Type", "application/json") 419 - .json(&json!({ 420 - "request_uri": request_uri, 421 - "approved_scopes": ["atproto"], 422 - "remember": false 423 - })) 424 - .send() 425 - .await 426 - .unwrap(); 427 - let consent_body: Value = consent_res.json().await.unwrap(); 428 - location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 429 - } 430 - 431 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 432 - 433 - let token_endpoint = format!("{}/oauth/token", pds_url); 434 - let (_, dpop_proof) = generate_dpop_proof("POST", &token_endpoint, None); 435 - 436 - let token_res = http_client 437 - .post(format!("{}/oauth/token", url)) 438 - .header("DPoP", &dpop_proof) 439 - .form(&[ 440 - ("grant_type", "authorization_code"), 441 - ("code", code), 442 - ("redirect_uri", redirect_uri), 443 - ("code_verifier", &code_verifier), 444 - ("client_id", &client_id), 445 - ]) 446 - .send() 447 - .await 448 - .unwrap(); 449 - 450 - let token_status = token_res.status(); 451 - let token_nonce = token_res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap().to_string()); 452 - let token_body: Value = token_res.json().await.unwrap(); 453 - 454 - let access_token = if token_status == StatusCode::OK { 455 - token_body["access_token"].as_str().unwrap().to_string() 456 - } else if token_body.get("error").and_then(|e| e.as_str()) == Some("use_dpop_nonce") { 457 - let nonce = token_nonce.expect("Token endpoint should return DPoP-Nonce on use_dpop_nonce error"); 458 - let (_, dpop_proof_with_nonce) = generate_dpop_proof("POST", &token_endpoint, Some(&nonce)); 459 - 460 - let retry_res = http_client 461 - .post(format!("{}/oauth/token", url)) 462 - .header("DPoP", &dpop_proof_with_nonce) 463 - .form(&[ 464 - ("grant_type", "authorization_code"), 465 - ("code", code), 466 - ("redirect_uri", redirect_uri), 467 - ("code_verifier", &code_verifier), 468 - ("client_id", &client_id), 469 - ]) 470 - .send() 471 - .await 472 - .unwrap(); 473 - let retry_body: Value = retry_res.json().await.unwrap(); 474 - retry_body["access_token"].as_str().expect("Should get access_token after nonce retry").to_string() 475 - } else { 476 - panic!("Token exchange failed unexpectedly: {:?}", token_body); 477 - }; 478 - 479 - let res = http_client 480 - .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 481 - .header("Authorization", format!("DPoP {}", access_token)) 482 - .send() 483 - .await 484 - .unwrap(); 485 - 486 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DPoP token without proof should fail"); 487 - 488 - let www_auth = res.headers().get("www-authenticate").map(|h| h.to_str().unwrap()); 489 - assert!(www_auth.is_some(), "Should have WWW-Authenticate header"); 490 - assert!( 491 - www_auth.unwrap().contains("use_dpop_nonce"), 492 - "WWW-Authenticate should indicate dpop nonce required" 493 - ); 494 - 495 - let nonce = res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap()); 496 - assert!(nonce.is_some(), "Should return DPoP-Nonce header"); 497 - 498 - let body: Value = res.json().await.unwrap(); 499 - assert_eq!(body["error"].as_str().unwrap(), "use_dpop_nonce"); 500 - } 501 - 502 - fn generate_dpop_proof(method: &str, uri: &str, nonce: Option<&str>) -> (Value, String) { 503 - use p256::ecdsa::{SigningKey, signature::Signer}; 504 - use p256::elliptic_curve::rand_core::OsRng; 505 - 506 - let signing_key = SigningKey::random(&mut OsRng); 507 - let verifying_key = signing_key.verifying_key(); 508 - let point = verifying_key.to_encoded_point(false); 509 - let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 510 - let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 511 - 512 - let jwk = json!({ 513 - "kty": "EC", 514 - "crv": "P-256", 515 - "x": x, 516 - "y": y 517 - }); 518 - 519 - let header = { 520 - let h = json!({ 521 - "typ": "dpop+jwt", 522 - "alg": "ES256", 523 - "jwk": jwk.clone() 524 - }); 525 - h 526 - }; 527 - 528 - let mut payload = json!({ 529 - "jti": uuid::Uuid::new_v4().to_string(), 530 - "htm": method, 531 - "htu": uri, 532 - "iat": Utc::now().timestamp() 533 - }); 534 - if let Some(n) = nonce { 535 - payload["nonce"] = json!(n); 536 - } 537 - 538 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 539 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 540 - let signing_input = format!("{}.{}", header_b64, payload_b64); 541 - 542 - let signature: p256::ecdsa::Signature = signing_key.sign(signing_input.as_bytes()); 543 - let sig_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 544 - 545 - let proof = format!("{}.{}", signing_input, sig_b64); 546 - (jwk, proof) 547 - }
+22 -18
crates/tranquil-pds/src/auth/extractor.rs
··· 9 9 10 10 use super::{ 11 11 AccountStatus, AuthSource, AuthenticatedUser, ServiceTokenClaims, ServiceTokenVerifier, 12 - is_service_token, validate_bearer_token_for_service_auth, 12 + is_service_token, scope_verified::VerifyScope, validate_bearer_token_for_service_auth, 13 13 }; 14 14 use crate::api::error::ApiError; 15 15 use crate::oauth::scopes::{RepoAction, ScopePermissions}; ··· 293 293 return Ok(ExtractedAuth::Service(claims)); 294 294 } 295 295 296 - let dpop_proof = parts.headers.get("DPoP").and_then(|h| h.to_str().ok()); 296 + let dpop_proof = crate::util::get_header_str(&parts.headers, "DPoP"); 297 297 let method = parts.method.as_str(); 298 298 let uri = build_full_url(&parts.uri.to_string()); 299 299 ··· 358 358 } 359 359 } 360 360 361 + impl<P: AuthPolicy> AsRef<AuthenticatedUser> for Auth<P> { 362 + fn as_ref(&self) -> &AuthenticatedUser { 363 + &self.0 364 + } 365 + } 366 + 367 + impl<P: AuthPolicy> VerifyScope for Auth<P> { 368 + fn needs_scope_check(&self) -> bool { 369 + self.0.is_oauth() 370 + } 371 + 372 + fn permissions(&self) -> ScopePermissions { 373 + self.0.permissions() 374 + } 375 + } 376 + 361 377 impl<P: AuthPolicy> FromRequestParts<AppState> for Auth<P> { 362 378 type Rejection = AuthError; 363 379 ··· 418 434 ) -> Result<Self, Self::Rejection> { 419 435 match extract_auth_internal(parts, state).await? { 420 436 ExtractedAuth::Service(claims) => { 421 - let did: Did = claims 422 - .iss 423 - .parse() 424 - .map_err(|_| AuthError::AuthenticationFailed)?; 437 + let did = claims.iss.clone(); 425 438 Ok(ServiceAuth { did, claims }) 426 439 } 427 440 ExtractedAuth::User(_) => Err(AuthError::AuthenticationFailed), ··· 438 451 ) -> Result<Option<Self>, Self::Rejection> { 439 452 match extract_auth_internal(parts, state).await { 440 453 Ok(ExtractedAuth::Service(claims)) => { 441 - let did: Did = claims 442 - .iss 443 - .parse() 444 - .map_err(|_| AuthError::AuthenticationFailed)?; 454 + let did = claims.iss.clone(); 445 455 Ok(Some(ServiceAuth { did, claims })) 446 456 } 447 457 Ok(ExtractedAuth::User(_)) => Err(AuthError::AuthenticationFailed), ··· 503 513 Ok(AuthAny::User(Auth(user, PhantomData))) 504 514 } 505 515 ExtractedAuth::Service(claims) => { 506 - let did: Did = claims 507 - .iss 508 - .parse() 509 - .map_err(|_| AuthError::AuthenticationFailed)?; 516 + let did = claims.iss.clone(); 510 517 Ok(AuthAny::Service(ServiceAuth { did, claims })) 511 518 } 512 519 } ··· 526 533 Ok(Some(AuthAny::User(Auth(user, PhantomData)))) 527 534 } 528 535 Ok(ExtractedAuth::Service(claims)) => { 529 - let did: Did = claims 530 - .iss 531 - .parse() 532 - .map_err(|_| AuthError::AuthenticationFailed)?; 536 + let did = claims.iss.clone(); 533 537 Ok(Some(AuthAny::Service(ServiceAuth { did, claims }))) 534 538 } 535 539 Err(AuthError::MissingToken) => Ok(None),
+135
crates/tranquil-pds/src/auth/login_identifier.rs
··· 1 + use std::fmt; 2 + 3 + #[derive(Debug, Clone, PartialEq, Eq)] 4 + pub struct NormalizedLoginIdentifier(String); 5 + 6 + impl NormalizedLoginIdentifier { 7 + pub fn normalize(identifier: &str, pds_hostname: &str) -> Self { 8 + let trimmed = identifier.trim(); 9 + let stripped = trimmed.strip_prefix('@').unwrap_or(trimmed); 10 + 11 + let normalized = match () { 12 + _ if stripped.starts_with("did:") => stripped.to_string(), 13 + _ if stripped.contains('@') => stripped.to_string(), 14 + _ if !stripped.contains('.') => { 15 + format!("{}.{}", stripped.to_lowercase(), pds_hostname) 16 + } 17 + _ => stripped.to_lowercase(), 18 + }; 19 + 20 + Self(normalized) 21 + } 22 + 23 + pub fn as_str(&self) -> &str { 24 + &self.0 25 + } 26 + } 27 + 28 + impl AsRef<str> for NormalizedLoginIdentifier { 29 + fn as_ref(&self) -> &str { 30 + &self.0 31 + } 32 + } 33 + 34 + impl fmt::Display for NormalizedLoginIdentifier { 35 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 36 + write!(f, "{}", self.0) 37 + } 38 + } 39 + 40 + #[derive(Debug, Clone, PartialEq, Eq)] 41 + pub struct BareLoginIdentifier(String); 42 + 43 + impl BareLoginIdentifier { 44 + pub fn from_identifier(identifier: &str, pds_hostname: &str) -> Self { 45 + let trimmed = identifier.trim(); 46 + let stripped = trimmed.strip_prefix('@').unwrap_or(trimmed); 47 + let suffix = format!(".{}", pds_hostname); 48 + let bare = stripped.strip_suffix(&suffix).unwrap_or(stripped); 49 + Self(bare.to_string()) 50 + } 51 + 52 + pub fn as_str(&self) -> &str { 53 + &self.0 54 + } 55 + } 56 + 57 + impl AsRef<str> for BareLoginIdentifier { 58 + fn as_ref(&self) -> &str { 59 + &self.0 60 + } 61 + } 62 + 63 + impl fmt::Display for BareLoginIdentifier { 64 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 65 + write!(f, "{}", self.0) 66 + } 67 + } 68 + 69 + #[cfg(test)] 70 + mod tests { 71 + use super::*; 72 + 73 + #[test] 74 + fn normalized_identifier_handles_did() { 75 + let id = NormalizedLoginIdentifier::normalize("did:plc:abc123", "example.com"); 76 + assert_eq!(id.as_str(), "did:plc:abc123"); 77 + } 78 + 79 + #[test] 80 + fn normalized_identifier_handles_email() { 81 + let id = NormalizedLoginIdentifier::normalize("user@example.org", "pds.example.com"); 82 + assert_eq!(id.as_str(), "user@example.org"); 83 + } 84 + 85 + #[test] 86 + fn normalized_identifier_handles_bare_handle() { 87 + let id = NormalizedLoginIdentifier::normalize("alice", "pds.example.com"); 88 + assert_eq!(id.as_str(), "alice.pds.example.com"); 89 + } 90 + 91 + #[test] 92 + fn normalized_identifier_handles_bare_handle_with_at_prefix() { 93 + let id = NormalizedLoginIdentifier::normalize("@alice", "pds.example.com"); 94 + assert_eq!(id.as_str(), "alice.pds.example.com"); 95 + } 96 + 97 + #[test] 98 + fn normalized_identifier_handles_full_handle() { 99 + let id = NormalizedLoginIdentifier::normalize("alice.bsky.social", "pds.example.com"); 100 + assert_eq!(id.as_str(), "alice.bsky.social"); 101 + } 102 + 103 + #[test] 104 + fn normalized_identifier_handles_uppercase() { 105 + let id = NormalizedLoginIdentifier::normalize("ALICE", "pds.example.com"); 106 + assert_eq!(id.as_str(), "alice.pds.example.com"); 107 + 108 + let id2 = NormalizedLoginIdentifier::normalize("ALICE.BSKY.SOCIAL", "pds.example.com"); 109 + assert_eq!(id2.as_str(), "alice.bsky.social"); 110 + } 111 + 112 + #[test] 113 + fn normalized_identifier_trims_whitespace() { 114 + let id = NormalizedLoginIdentifier::normalize(" alice ", "pds.example.com"); 115 + assert_eq!(id.as_str(), "alice.pds.example.com"); 116 + } 117 + 118 + #[test] 119 + fn bare_identifier_strips_hostname_suffix() { 120 + let id = BareLoginIdentifier::from_identifier("alice.pds.example.com", "pds.example.com"); 121 + assert_eq!(id.as_str(), "alice"); 122 + } 123 + 124 + #[test] 125 + fn bare_identifier_preserves_non_matching() { 126 + let id = BareLoginIdentifier::from_identifier("alice.bsky.social", "pds.example.com"); 127 + assert_eq!(id.as_str(), "alice.bsky.social"); 128 + } 129 + 130 + #[test] 131 + fn bare_identifier_strips_at_prefix() { 132 + let id = BareLoginIdentifier::from_identifier("@alice.pds.example.com", "pds.example.com"); 133 + assert_eq!(id.as_str(), "alice"); 134 + } 135 + }
+234
crates/tranquil-pds/src/auth/mfa_verified.rs
··· 1 + use axum::response::Response; 2 + 3 + use super::AuthenticatedUser; 4 + use crate::state::AppState; 5 + use crate::types::Did; 6 + 7 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 8 + pub enum MfaMethod { 9 + Totp, 10 + Passkey, 11 + Password, 12 + RecoveryCode, 13 + SessionReauth, 14 + } 15 + 16 + impl MfaMethod { 17 + pub fn as_str(&self) -> &'static str { 18 + match self { 19 + Self::Totp => "totp", 20 + Self::Passkey => "passkey", 21 + Self::Password => "password", 22 + Self::RecoveryCode => "recovery_code", 23 + Self::SessionReauth => "session_reauth", 24 + } 25 + } 26 + } 27 + 28 + impl std::fmt::Display for MfaMethod { 29 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 30 + write!(f, "{}", self.as_str()) 31 + } 32 + } 33 + 34 + pub struct MfaVerified<'a> { 35 + user: &'a AuthenticatedUser, 36 + method: MfaMethod, 37 + } 38 + 39 + impl<'a> MfaVerified<'a> { 40 + fn new(user: &'a AuthenticatedUser, method: MfaMethod) -> Self { 41 + Self { user, method } 42 + } 43 + 44 + pub(crate) fn from_totp(user: &'a AuthenticatedUser) -> Self { 45 + Self::new(user, MfaMethod::Totp) 46 + } 47 + 48 + pub(crate) fn from_password(user: &'a AuthenticatedUser) -> Self { 49 + Self::new(user, MfaMethod::Password) 50 + } 51 + 52 + pub(crate) fn from_recovery_code(user: &'a AuthenticatedUser) -> Self { 53 + Self::new(user, MfaMethod::RecoveryCode) 54 + } 55 + 56 + pub(crate) fn from_session_reauth(user: &'a AuthenticatedUser) -> Self { 57 + Self::new(user, MfaMethod::SessionReauth) 58 + } 59 + 60 + pub fn user(&self) -> &AuthenticatedUser { 61 + self.user 62 + } 63 + 64 + pub fn did(&self) -> &Did { 65 + &self.user.did 66 + } 67 + 68 + pub fn method(&self) -> MfaMethod { 69 + self.method 70 + } 71 + } 72 + 73 + pub async fn require_legacy_session_mfa<'a>( 74 + state: &AppState, 75 + user: &'a AuthenticatedUser, 76 + ) -> Result<MfaVerified<'a>, Response> { 77 + use crate::api::server::reauth::{check_legacy_session_mfa, legacy_mfa_required_response}; 78 + 79 + if check_legacy_session_mfa(&*state.session_repo, &user.did).await { 80 + Ok(MfaVerified::from_session_reauth(user)) 81 + } else { 82 + Err(legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &user.did).await) 83 + } 84 + } 85 + 86 + pub async fn require_reauth_window<'a>( 87 + state: &AppState, 88 + user: &'a AuthenticatedUser, 89 + ) -> Result<MfaVerified<'a>, Response> { 90 + use crate::api::server::reauth::{REAUTH_WINDOW_SECONDS, reauth_required_response}; 91 + use chrono::Utc; 92 + 93 + let status = state 94 + .session_repo 95 + .get_session_mfa_status(&user.did) 96 + .await 97 + .ok() 98 + .flatten(); 99 + 100 + match status { 101 + Some(s) => { 102 + if let Some(last_reauth) = s.last_reauth_at { 103 + let elapsed = Utc::now().signed_duration_since(last_reauth); 104 + if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS { 105 + return Ok(MfaVerified::from_session_reauth(user)); 106 + } 107 + } 108 + Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await) 109 + } 110 + None => { 111 + Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await) 112 + } 113 + } 114 + } 115 + 116 + pub async fn require_reauth_window_if_available<'a>( 117 + state: &AppState, 118 + user: &'a AuthenticatedUser, 119 + ) -> Result<Option<MfaVerified<'a>>, Response> { 120 + use crate::api::server::reauth::{check_reauth_required_cached, reauth_required_response}; 121 + 122 + let has_password = state 123 + .user_repo 124 + .has_password_by_did(&user.did) 125 + .await 126 + .ok() 127 + .flatten() 128 + .unwrap_or(false); 129 + let has_passkeys = state 130 + .user_repo 131 + .has_passkeys(&user.did) 132 + .await 133 + .unwrap_or(false); 134 + let has_totp = state 135 + .user_repo 136 + .has_totp_enabled(&user.did) 137 + .await 138 + .unwrap_or(false); 139 + 140 + let has_any_reauth_method = has_password || has_passkeys || has_totp; 141 + 142 + if !has_any_reauth_method { 143 + return Ok(None); 144 + } 145 + 146 + if check_reauth_required_cached(&*state.session_repo, &state.cache, &user.did).await { 147 + Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await) 148 + } else { 149 + Ok(Some(MfaVerified::from_session_reauth(user))) 150 + } 151 + } 152 + 153 + pub async fn verify_password_mfa<'a>( 154 + state: &AppState, 155 + user: &'a AuthenticatedUser, 156 + password: &str, 157 + ) -> Result<MfaVerified<'a>, crate::api::error::ApiError> { 158 + let hash = state 159 + .user_repo 160 + .get_password_hash_by_did(&user.did) 161 + .await 162 + .ok() 163 + .flatten(); 164 + 165 + match hash { 166 + Some(h) => { 167 + if bcrypt::verify(password, &h).unwrap_or(false) { 168 + Ok(MfaVerified::from_password(user)) 169 + } else { 170 + Err(crate::api::error::ApiError::InvalidPassword( 171 + "Password is incorrect".into(), 172 + )) 173 + } 174 + } 175 + None => Err(crate::api::error::ApiError::AccountNotFound), 176 + } 177 + } 178 + 179 + pub async fn verify_totp_mfa<'a>( 180 + state: &AppState, 181 + user: &'a AuthenticatedUser, 182 + code: &str, 183 + ) -> Result<MfaVerified<'a>, crate::api::error::ApiError> { 184 + use crate::auth::{decrypt_totp_secret, is_backup_code_format, verify_totp_code}; 185 + use tranquil_db_traits::TotpRecordState; 186 + 187 + let code = code.trim(); 188 + 189 + if is_backup_code_format(code) { 190 + let backup_codes = state 191 + .user_repo 192 + .get_unused_backup_codes(&user.did) 193 + .await 194 + .ok() 195 + .unwrap_or_default(); 196 + let code_upper = code.to_uppercase(); 197 + 198 + let matched = backup_codes 199 + .iter() 200 + .find(|row| crate::auth::verify_backup_code(&code_upper, &row.code_hash)); 201 + 202 + return match matched { 203 + Some(row) => { 204 + let _ = state.user_repo.mark_backup_code_used(row.id).await; 205 + Ok(MfaVerified::from_recovery_code(user)) 206 + } 207 + None => Err(crate::api::error::ApiError::InvalidCode(Some( 208 + "Invalid backup code".into(), 209 + ))), 210 + }; 211 + } 212 + 213 + let verified_record = match state.user_repo.get_totp_record_state(&user.did).await { 214 + Ok(Some(TotpRecordState::Verified(record))) => record, 215 + _ => { 216 + return Err(crate::api::error::ApiError::TotpNotEnabled); 217 + } 218 + }; 219 + 220 + let secret = decrypt_totp_secret( 221 + &verified_record.secret_encrypted, 222 + verified_record.encryption_version, 223 + ) 224 + .map_err(|_| crate::api::error::ApiError::InternalError(None))?; 225 + 226 + if verify_totp_code(&secret, code) { 227 + let _ = state.user_repo.update_totp_last_used(&user.did).await; 228 + Ok(MfaVerified::from_totp(user)) 229 + } else { 230 + Err(crate::api::error::ApiError::InvalidCode(Some( 231 + "Invalid verification code".into(), 232 + ))) 233 + } 234 + }
+23 -4
crates/tranquil-pds/src/auth/mod.rs
··· 10 10 use tranquil_db::UserRepository; 11 11 use tranquil_db_traits::OAuthRepository; 12 12 13 + pub mod account_verified; 13 14 pub mod extractor; 15 + pub mod login_identifier; 16 + pub mod mfa_verified; 14 17 pub mod scope_check; 18 + pub mod scope_verified; 15 19 pub mod service; 16 20 pub mod verification_token; 17 21 pub mod webauthn; 18 22 23 + pub use login_identifier::{BareLoginIdentifier, NormalizedLoginIdentifier}; 24 + 25 + pub use account_verified::{AccountVerified, require_not_migrated, require_verified_or_delegated}; 19 26 pub use extractor::{ 20 27 Active, Admin, AnyUser, Auth, AuthAny, AuthError, AuthPolicy, ExtractedToken, NotTakendown, 21 28 Permissive, ServiceAuth, extract_auth_token_from_header, extract_bearer_token_from_header, 22 29 }; 30 + pub use mfa_verified::{ 31 + MfaMethod, MfaVerified, require_legacy_session_mfa, require_reauth_window, 32 + require_reauth_window_if_available, verify_password_mfa, verify_totp_mfa, 33 + }; 34 + pub use scope_verified::{ 35 + AccountManage, AccountRead, BatchWriteScopes, BlobScopeAction, BlobUpload, ControllerDid, 36 + IdentityAccess, PrincipalDid, RepoCreate, RepoDelete, RepoScopeAction, RepoUpdate, RepoUpsert, 37 + RpcCall, ScopeAction, ScopeVerificationError, ScopeVerified, VerifyScope, WriteOpKind, 38 + verify_batch_write_scopes, 39 + }; 23 40 pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 24 41 25 42 pub use tranquil_auth::{ ··· 409 426 .claims 410 427 .act 411 428 .as_ref() 412 - .map(|a| Did::new_unchecked(a.sub.clone())); 429 + .map(|a| unsafe { Did::new_unchecked(a.sub.clone()) }); 413 430 let status = 414 431 AccountStatus::from_db_fields(takedown_ref.as_deref(), deactivated_at); 415 432 return Ok(AuthenticatedUser { ··· 461 478 None 462 479 }; 463 480 return Ok(AuthenticatedUser { 464 - did: Did::new_unchecked(oauth_token.did), 481 + did: unsafe { Did::new_unchecked(oauth_token.did) }, 465 482 key_bytes, 466 483 is_admin: oauth_token.is_admin, 467 484 status, 468 485 scope: oauth_info.scope, 469 - controller_did: oauth_info.controller_did.map(Did::new_unchecked), 486 + controller_did: oauth_info 487 + .controller_did 488 + .map(|d| unsafe { Did::new_unchecked(d) }), 470 489 auth_source: AuthSource::OAuth, 471 490 }); 472 491 } else { ··· 545 564 None 546 565 }; 547 566 Ok(AuthenticatedUser { 548 - did: Did::new_unchecked(result.did), 567 + did: unsafe { Did::new_unchecked(result.did) }, 549 568 key_bytes, 550 569 is_admin: user_info.is_admin, 551 570 status,
+502
crates/tranquil-pds/src/auth/scope_verified.rs
··· 1 + use std::marker::PhantomData; 2 + use std::ops::Deref; 3 + 4 + use axum::response::{IntoResponse, Response}; 5 + 6 + use crate::api::error::ApiError; 7 + use crate::oauth::scopes::{ 8 + AccountAction, AccountAttr, IdentityAttr, RepoAction, ScopePermissions, 9 + }; 10 + use crate::types::Did; 11 + 12 + use super::AuthenticatedUser; 13 + 14 + #[derive(Debug, Clone)] 15 + pub struct PrincipalDid(Did); 16 + 17 + impl PrincipalDid { 18 + pub fn as_did(&self) -> &Did { 19 + &self.0 20 + } 21 + 22 + pub fn into_did(self) -> Did { 23 + self.0 24 + } 25 + 26 + pub fn as_str(&self) -> &str { 27 + self.0.as_str() 28 + } 29 + } 30 + 31 + impl Deref for PrincipalDid { 32 + type Target = Did; 33 + 34 + fn deref(&self) -> &Self::Target { 35 + &self.0 36 + } 37 + } 38 + 39 + impl AsRef<Did> for PrincipalDid { 40 + fn as_ref(&self) -> &Did { 41 + &self.0 42 + } 43 + } 44 + 45 + impl std::fmt::Display for PrincipalDid { 46 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 47 + self.0.fmt(f) 48 + } 49 + } 50 + 51 + #[derive(Debug, Clone)] 52 + pub struct ControllerDid(Did); 53 + 54 + impl ControllerDid { 55 + pub fn as_did(&self) -> &Did { 56 + &self.0 57 + } 58 + 59 + pub fn into_did(self) -> Did { 60 + self.0 61 + } 62 + 63 + pub fn as_str(&self) -> &str { 64 + self.0.as_str() 65 + } 66 + } 67 + 68 + impl Deref for ControllerDid { 69 + type Target = Did; 70 + 71 + fn deref(&self) -> &Self::Target { 72 + &self.0 73 + } 74 + } 75 + 76 + impl AsRef<Did> for ControllerDid { 77 + fn as_ref(&self) -> &Did { 78 + &self.0 79 + } 80 + } 81 + 82 + impl std::fmt::Display for ControllerDid { 83 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 84 + self.0.fmt(f) 85 + } 86 + } 87 + 88 + #[derive(Debug)] 89 + pub struct ScopeVerificationError { 90 + message: String, 91 + } 92 + 93 + impl ScopeVerificationError { 94 + pub fn new(message: impl Into<String>) -> Self { 95 + Self { 96 + message: message.into(), 97 + } 98 + } 99 + 100 + pub fn message(&self) -> &str { 101 + &self.message 102 + } 103 + } 104 + 105 + impl std::fmt::Display for ScopeVerificationError { 106 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 107 + write!(f, "{}", self.message) 108 + } 109 + } 110 + 111 + impl std::error::Error for ScopeVerificationError {} 112 + 113 + impl IntoResponse for ScopeVerificationError { 114 + fn into_response(self) -> Response { 115 + ApiError::InsufficientScope(Some(self.message)).into_response() 116 + } 117 + } 118 + 119 + mod private { 120 + pub trait Sealed {} 121 + pub trait RepoScopeSealed {} 122 + pub trait BlobScopeSealed {} 123 + } 124 + 125 + pub trait ScopeAction: private::Sealed {} 126 + 127 + pub trait RepoScopeAction: ScopeAction + private::RepoScopeSealed {} 128 + 129 + pub trait BlobScopeAction: ScopeAction + private::BlobScopeSealed {} 130 + 131 + pub struct RepoCreate; 132 + pub struct RepoUpdate; 133 + pub struct RepoDelete; 134 + pub struct RepoUpsert; 135 + pub struct BlobUpload; 136 + pub struct RpcCall; 137 + pub struct AccountRead; 138 + pub struct AccountManage; 139 + pub struct IdentityAccess; 140 + 141 + impl private::Sealed for RepoCreate {} 142 + impl private::Sealed for RepoUpdate {} 143 + impl private::Sealed for RepoDelete {} 144 + impl private::Sealed for RepoUpsert {} 145 + impl private::Sealed for BlobUpload {} 146 + impl private::Sealed for RpcCall {} 147 + impl private::Sealed for AccountRead {} 148 + impl private::Sealed for AccountManage {} 149 + impl private::Sealed for IdentityAccess {} 150 + 151 + impl private::RepoScopeSealed for RepoCreate {} 152 + impl private::RepoScopeSealed for RepoUpdate {} 153 + impl private::RepoScopeSealed for RepoDelete {} 154 + impl private::RepoScopeSealed for RepoUpsert {} 155 + 156 + impl private::BlobScopeSealed for BlobUpload {} 157 + 158 + impl ScopeAction for RepoCreate {} 159 + impl ScopeAction for RepoUpdate {} 160 + impl ScopeAction for RepoDelete {} 161 + impl ScopeAction for RepoUpsert {} 162 + impl ScopeAction for BlobUpload {} 163 + impl ScopeAction for RpcCall {} 164 + impl ScopeAction for AccountRead {} 165 + impl ScopeAction for AccountManage {} 166 + impl ScopeAction for IdentityAccess {} 167 + 168 + impl RepoScopeAction for RepoCreate {} 169 + impl RepoScopeAction for RepoUpdate {} 170 + impl RepoScopeAction for RepoDelete {} 171 + impl RepoScopeAction for RepoUpsert {} 172 + 173 + impl BlobScopeAction for BlobUpload {} 174 + 175 + pub struct ScopeVerified<'a, A: ScopeAction> { 176 + user: &'a AuthenticatedUser, 177 + _action: PhantomData<A>, 178 + } 179 + 180 + impl<'a, A: ScopeAction> ScopeVerified<'a, A> { 181 + pub fn user(&self) -> &AuthenticatedUser { 182 + self.user 183 + } 184 + 185 + pub fn principal_did(&self) -> PrincipalDid { 186 + PrincipalDid(self.user.did.clone()) 187 + } 188 + 189 + pub fn controller_did(&self) -> Option<ControllerDid> { 190 + self.user.controller_did.clone().map(ControllerDid) 191 + } 192 + 193 + pub fn is_admin(&self) -> bool { 194 + self.user.is_admin 195 + } 196 + } 197 + 198 + pub struct BatchWriteScopes<'a> { 199 + user: &'a AuthenticatedUser, 200 + has_creates: bool, 201 + has_updates: bool, 202 + has_deletes: bool, 203 + } 204 + 205 + impl<'a> BatchWriteScopes<'a> { 206 + pub fn principal_did(&self) -> PrincipalDid { 207 + PrincipalDid(self.user.did.clone()) 208 + } 209 + 210 + pub fn controller_did(&self) -> Option<ControllerDid> { 211 + self.user.controller_did.clone().map(ControllerDid) 212 + } 213 + 214 + pub fn user(&self) -> &AuthenticatedUser { 215 + self.user 216 + } 217 + 218 + pub fn has_creates(&self) -> bool { 219 + self.has_creates 220 + } 221 + 222 + pub fn has_updates(&self) -> bool { 223 + self.has_updates 224 + } 225 + 226 + pub fn has_deletes(&self) -> bool { 227 + self.has_deletes 228 + } 229 + } 230 + 231 + pub fn verify_batch_write_scopes<'a, T, C, F>( 232 + auth: &'a impl VerifyScope, 233 + user: &'a AuthenticatedUser, 234 + writes: &[T], 235 + get_collection: F, 236 + classify: C, 237 + ) -> Result<BatchWriteScopes<'a>, ScopeVerificationError> 238 + where 239 + F: Fn(&T) -> &str, 240 + C: Fn(&T) -> WriteOpKind, 241 + { 242 + use std::collections::HashSet; 243 + 244 + let create_collections: HashSet<&str> = writes 245 + .iter() 246 + .filter(|w| matches!(classify(w), WriteOpKind::Create)) 247 + .map(&get_collection) 248 + .collect(); 249 + 250 + let update_collections: HashSet<&str> = writes 251 + .iter() 252 + .filter(|w| matches!(classify(w), WriteOpKind::Update)) 253 + .map(&get_collection) 254 + .collect(); 255 + 256 + let delete_collections: HashSet<&str> = writes 257 + .iter() 258 + .filter(|w| matches!(classify(w), WriteOpKind::Delete)) 259 + .map(&get_collection) 260 + .collect(); 261 + 262 + if auth.needs_scope_check() { 263 + create_collections.iter().try_for_each(|c| { 264 + auth.permissions() 265 + .assert_repo(RepoAction::Create, c) 266 + .map_err(|e| ScopeVerificationError::new(e.to_string())) 267 + })?; 268 + 269 + update_collections.iter().try_for_each(|c| { 270 + auth.permissions() 271 + .assert_repo(RepoAction::Update, c) 272 + .map_err(|e| ScopeVerificationError::new(e.to_string())) 273 + })?; 274 + 275 + delete_collections.iter().try_for_each(|c| { 276 + auth.permissions() 277 + .assert_repo(RepoAction::Delete, c) 278 + .map_err(|e| ScopeVerificationError::new(e.to_string())) 279 + })?; 280 + } 281 + 282 + Ok(BatchWriteScopes { 283 + user, 284 + has_creates: !create_collections.is_empty(), 285 + has_updates: !update_collections.is_empty(), 286 + has_deletes: !delete_collections.is_empty(), 287 + }) 288 + } 289 + 290 + #[derive(Clone, Copy)] 291 + pub enum WriteOpKind { 292 + Create, 293 + Update, 294 + Delete, 295 + } 296 + 297 + pub trait VerifyScope { 298 + fn needs_scope_check(&self) -> bool; 299 + fn permissions(&self) -> ScopePermissions; 300 + 301 + fn verify_repo_create<'a>( 302 + &'a self, 303 + collection: &str, 304 + ) -> Result<ScopeVerified<'a, RepoCreate>, ScopeVerificationError> 305 + where 306 + Self: AsRef<AuthenticatedUser>, 307 + { 308 + if !self.needs_scope_check() { 309 + return Ok(ScopeVerified { 310 + user: self.as_ref(), 311 + _action: PhantomData, 312 + }); 313 + } 314 + self.permissions() 315 + .assert_repo(RepoAction::Create, collection) 316 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 317 + Ok(ScopeVerified { 318 + user: self.as_ref(), 319 + _action: PhantomData, 320 + }) 321 + } 322 + 323 + fn verify_repo_update<'a>( 324 + &'a self, 325 + collection: &str, 326 + ) -> Result<ScopeVerified<'a, RepoUpdate>, ScopeVerificationError> 327 + where 328 + Self: AsRef<AuthenticatedUser>, 329 + { 330 + if !self.needs_scope_check() { 331 + return Ok(ScopeVerified { 332 + user: self.as_ref(), 333 + _action: PhantomData, 334 + }); 335 + } 336 + self.permissions() 337 + .assert_repo(RepoAction::Update, collection) 338 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 339 + Ok(ScopeVerified { 340 + user: self.as_ref(), 341 + _action: PhantomData, 342 + }) 343 + } 344 + 345 + fn verify_repo_delete<'a>( 346 + &'a self, 347 + collection: &str, 348 + ) -> Result<ScopeVerified<'a, RepoDelete>, ScopeVerificationError> 349 + where 350 + Self: AsRef<AuthenticatedUser>, 351 + { 352 + if !self.needs_scope_check() { 353 + return Ok(ScopeVerified { 354 + user: self.as_ref(), 355 + _action: PhantomData, 356 + }); 357 + } 358 + self.permissions() 359 + .assert_repo(RepoAction::Delete, collection) 360 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 361 + Ok(ScopeVerified { 362 + user: self.as_ref(), 363 + _action: PhantomData, 364 + }) 365 + } 366 + 367 + fn verify_repo_upsert<'a>( 368 + &'a self, 369 + collection: &str, 370 + ) -> Result<ScopeVerified<'a, RepoUpsert>, ScopeVerificationError> 371 + where 372 + Self: AsRef<AuthenticatedUser>, 373 + { 374 + if !self.needs_scope_check() { 375 + return Ok(ScopeVerified { 376 + user: self.as_ref(), 377 + _action: PhantomData, 378 + }); 379 + } 380 + self.permissions() 381 + .assert_repo(RepoAction::Create, collection) 382 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 383 + self.permissions() 384 + .assert_repo(RepoAction::Update, collection) 385 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 386 + Ok(ScopeVerified { 387 + user: self.as_ref(), 388 + _action: PhantomData, 389 + }) 390 + } 391 + 392 + fn verify_blob_upload<'a>( 393 + &'a self, 394 + mime_type: &str, 395 + ) -> Result<ScopeVerified<'a, BlobUpload>, ScopeVerificationError> 396 + where 397 + Self: AsRef<AuthenticatedUser>, 398 + { 399 + if !self.needs_scope_check() { 400 + return Ok(ScopeVerified { 401 + user: self.as_ref(), 402 + _action: PhantomData, 403 + }); 404 + } 405 + self.permissions() 406 + .assert_blob(mime_type) 407 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 408 + Ok(ScopeVerified { 409 + user: self.as_ref(), 410 + _action: PhantomData, 411 + }) 412 + } 413 + 414 + fn verify_rpc<'a>( 415 + &'a self, 416 + aud: &str, 417 + lxm: &str, 418 + ) -> Result<ScopeVerified<'a, RpcCall>, ScopeVerificationError> 419 + where 420 + Self: AsRef<AuthenticatedUser>, 421 + { 422 + if !self.needs_scope_check() { 423 + return Ok(ScopeVerified { 424 + user: self.as_ref(), 425 + _action: PhantomData, 426 + }); 427 + } 428 + self.permissions() 429 + .assert_rpc(aud, lxm) 430 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 431 + Ok(ScopeVerified { 432 + user: self.as_ref(), 433 + _action: PhantomData, 434 + }) 435 + } 436 + 437 + fn verify_account_read<'a>( 438 + &'a self, 439 + attr: AccountAttr, 440 + ) -> Result<ScopeVerified<'a, AccountRead>, ScopeVerificationError> 441 + where 442 + Self: AsRef<AuthenticatedUser>, 443 + { 444 + if !self.needs_scope_check() { 445 + return Ok(ScopeVerified { 446 + user: self.as_ref(), 447 + _action: PhantomData, 448 + }); 449 + } 450 + self.permissions() 451 + .assert_account(attr, AccountAction::Read) 452 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 453 + Ok(ScopeVerified { 454 + user: self.as_ref(), 455 + _action: PhantomData, 456 + }) 457 + } 458 + 459 + fn verify_account_manage<'a>( 460 + &'a self, 461 + attr: AccountAttr, 462 + ) -> Result<ScopeVerified<'a, AccountManage>, ScopeVerificationError> 463 + where 464 + Self: AsRef<AuthenticatedUser>, 465 + { 466 + if !self.needs_scope_check() { 467 + return Ok(ScopeVerified { 468 + user: self.as_ref(), 469 + _action: PhantomData, 470 + }); 471 + } 472 + self.permissions() 473 + .assert_account(attr, AccountAction::Manage) 474 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 475 + Ok(ScopeVerified { 476 + user: self.as_ref(), 477 + _action: PhantomData, 478 + }) 479 + } 480 + 481 + fn verify_identity<'a>( 482 + &'a self, 483 + attr: IdentityAttr, 484 + ) -> Result<ScopeVerified<'a, IdentityAccess>, ScopeVerificationError> 485 + where 486 + Self: AsRef<AuthenticatedUser>, 487 + { 488 + if !self.needs_scope_check() { 489 + return Ok(ScopeVerified { 490 + user: self.as_ref(), 491 + _action: PhantomData, 492 + }); 493 + } 494 + self.permissions() 495 + .assert_identity(attr) 496 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 497 + Ok(ScopeVerified { 498 + user: self.as_ref(), 499 + _action: PhantomData, 500 + }) 501 + } 502 + }
+10 -9
crates/tranquil-pds/src/auth/service.rs
··· 1 + use crate::types::Did; 2 + use crate::util::pds_hostname; 1 3 use anyhow::{Result, anyhow}; 2 4 use base64::Engine as _; 3 5 use base64::engine::general_purpose::URL_SAFE_NO_PAD; ··· 42 44 43 45 #[derive(Debug, Clone, Serialize, Deserialize)] 44 46 pub struct ServiceTokenClaims { 45 - pub iss: String, 47 + pub iss: Did, 46 48 #[serde(default)] 47 - pub sub: Option<String>, 48 - pub aud: String, 49 + pub sub: Option<Did>, 50 + pub aud: Did, 49 51 pub exp: usize, 50 52 #[serde(default)] 51 53 pub iat: Option<usize>, ··· 56 58 } 57 59 58 60 impl ServiceTokenClaims { 59 - pub fn subject(&self) -> &str { 60 - self.sub.as_deref().unwrap_or(&self.iss) 61 + pub fn subject(&self) -> &Did { 62 + self.sub.as_ref().unwrap_or(&self.iss) 61 63 } 62 64 } 63 65 ··· 78 80 let plc_directory_url = std::env::var("PLC_DIRECTORY_URL") 79 81 .unwrap_or_else(|_| "https://plc.directory".to_string()); 80 82 81 - let pds_hostname = 82 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 83 + let pds_hostname = pds_hostname(); 83 84 let pds_did = format!("did:web:{}", pds_hostname); 84 85 85 86 let client = Client::builder() ··· 130 131 return Err(anyhow!("Token expired")); 131 132 } 132 133 133 - if claims.aud != self.pds_did { 134 + if claims.aud.as_str() != self.pds_did { 134 135 return Err(anyhow!( 135 136 "Invalid audience: expected {}, got {}", 136 137 self.pds_did, ··· 154 155 } 155 156 } 156 157 157 - let did = &claims.iss; 158 + let did = claims.iss.as_str(); 158 159 let public_key = self.resolve_signing_key(did).await?; 159 160 160 161 let signature_bytes = URL_SAFE_NO_PAD
+1 -1
crates/tranquil-pds/src/auth/webauthn.rs
··· 7 7 8 8 impl WebAuthnConfig { 9 9 pub fn new(hostname: &str) -> Result<Self, String> { 10 - let rp_id = hostname.to_string(); 10 + let rp_id = hostname.split(':').next().unwrap_or(hostname).to_string(); 11 11 let rp_origin = Url::parse(&format!("https://{}", hostname)) 12 12 .map_err(|e| format!("Invalid origin URL: {}", e))?; 13 13
+101
crates/tranquil-pds/src/cid_types.rs
··· 1 + use cid::Cid; 2 + use std::fmt; 3 + use std::str::FromStr; 4 + 5 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 6 + pub struct CommitCid(Cid); 7 + 8 + impl CommitCid { 9 + pub fn new(cid: Cid) -> Self { 10 + Self(cid) 11 + } 12 + 13 + pub fn as_cid(&self) -> &Cid { 14 + &self.0 15 + } 16 + 17 + pub fn into_cid(self) -> Cid { 18 + self.0 19 + } 20 + } 21 + 22 + impl From<Cid> for CommitCid { 23 + fn from(cid: Cid) -> Self { 24 + Self(cid) 25 + } 26 + } 27 + 28 + impl From<CommitCid> for Cid { 29 + fn from(commit_cid: CommitCid) -> Self { 30 + commit_cid.0 31 + } 32 + } 33 + 34 + impl FromStr for CommitCid { 35 + type Err = cid::Error; 36 + 37 + fn from_str(s: &str) -> Result<Self, Self::Err> { 38 + Cid::from_str(s).map(Self) 39 + } 40 + } 41 + 42 + impl fmt::Display for CommitCid { 43 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 44 + write!(f, "{}", self.0) 45 + } 46 + } 47 + 48 + impl AsRef<Cid> for CommitCid { 49 + fn as_ref(&self) -> &Cid { 50 + &self.0 51 + } 52 + } 53 + 54 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 55 + pub struct RecordCid(Cid); 56 + 57 + impl RecordCid { 58 + pub fn new(cid: Cid) -> Self { 59 + Self(cid) 60 + } 61 + 62 + pub fn as_cid(&self) -> &Cid { 63 + &self.0 64 + } 65 + 66 + pub fn into_cid(self) -> Cid { 67 + self.0 68 + } 69 + } 70 + 71 + impl From<Cid> for RecordCid { 72 + fn from(cid: Cid) -> Self { 73 + Self(cid) 74 + } 75 + } 76 + 77 + impl From<RecordCid> for Cid { 78 + fn from(record_cid: RecordCid) -> Self { 79 + record_cid.0 80 + } 81 + } 82 + 83 + impl FromStr for RecordCid { 84 + type Err = cid::Error; 85 + 86 + fn from_str(s: &str) -> Result<Self, Self::Err> { 87 + Cid::from_str(s).map(Self) 88 + } 89 + } 90 + 91 + impl fmt::Display for RecordCid { 92 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 93 + write!(f, "{}", self.0) 94 + } 95 + } 96 + 97 + impl AsRef<Cid> for RecordCid { 98 + fn as_ref(&self) -> &Cid { 99 + &self.0 100 + } 101 + }
+11 -13
crates/tranquil-pds/src/comms/service.rs
··· 3 3 use std::time::Duration; 4 4 5 5 use chrono::Utc; 6 - use tokio::sync::watch; 7 6 use tokio::time::interval; 7 + use tokio_util::sync::CancellationToken; 8 8 use tracing::{debug, error, info, warn}; 9 9 use tranquil_comms::{ 10 10 CommsChannel, CommsSender, CommsStatus, CommsType, NewComms, SendError, format_message, ··· 96 96 !self.senders.is_empty() 97 97 } 98 98 99 - pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 99 + pub async fn run(self, shutdown: CancellationToken) { 100 100 if self.senders.is_empty() { 101 101 warn!( 102 102 "Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured." ··· 116 116 error!(error = %e, "Failed to process comms batch"); 117 117 } 118 118 } 119 - _ = shutdown.changed() => { 120 - if *shutdown.borrow() { 121 - info!("Comms service shutting down"); 122 - break; 123 - } 119 + _ = shutdown.cancelled() => { 120 + info!("Comms service shutting down"); 121 + break; 124 122 } 125 123 } 126 124 } ··· 278 276 &[("hostname", hostname), ("handle", &prefs.handle)], 279 277 ); 280 278 let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]); 281 - let channel = channel_from_str(&prefs.preferred_channel); 279 + let channel = prefs.preferred_channel; 282 280 infra_repo 283 281 .enqueue_comms( 284 282 Some(user_id), ··· 309 307 &[("handle", &prefs.handle), ("code", code)], 310 308 ); 311 309 let subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]); 312 - let channel = channel_from_str(&prefs.preferred_channel); 310 + let channel = prefs.preferred_channel; 313 311 infra_repo 314 312 .enqueue_comms( 315 313 Some(user_id), ··· 422 420 &[("handle", &prefs.handle), ("code", code)], 423 421 ); 424 422 let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]); 425 - let channel = channel_from_str(&prefs.preferred_channel); 423 + let channel = prefs.preferred_channel; 426 424 infra_repo 427 425 .enqueue_comms( 428 426 Some(user_id), ··· 453 451 &[("handle", &prefs.handle), ("token", token)], 454 452 ); 455 453 let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]); 456 - let channel = channel_from_str(&prefs.preferred_channel); 454 + let channel = prefs.preferred_channel; 457 455 infra_repo 458 456 .enqueue_comms( 459 457 Some(user_id), ··· 484 482 &[("handle", &prefs.handle), ("url", recovery_url)], 485 483 ); 486 484 let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]); 487 - let channel = channel_from_str(&prefs.preferred_channel); 485 + let channel = prefs.preferred_channel; 488 486 infra_repo 489 487 .enqueue_comms( 490 488 Some(user_id), ··· 614 612 &[("handle", &prefs.handle), ("code", code)], 615 613 ); 616 614 let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]); 617 - let channel = channel_from_str(&prefs.preferred_channel); 615 + let channel = prefs.preferred_channel; 618 616 infra_repo 619 617 .enqueue_comms( 620 618 Some(user_id),
+14 -10
crates/tranquil-pds/src/crawlers.rs
··· 1 1 use crate::circuit_breaker::CircuitBreaker; 2 2 use crate::sync::firehose::SequencedEvent; 3 + use crate::util::pds_hostname; 3 4 use reqwest::Client; 4 5 use std::sync::Arc; 5 6 use std::sync::atomic::{AtomicU64, Ordering}; 6 7 use std::time::Duration; 7 - use tokio::sync::{broadcast, watch}; 8 + use tokio::sync::broadcast; 9 + use tokio_util::sync::CancellationToken; 8 10 use tracing::{debug, error, info, warn}; 11 + use tranquil_db_traits::RepoEventType; 9 12 10 13 const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60; 11 14 ··· 40 43 } 41 44 42 45 pub fn from_env() -> Option<Self> { 43 - let hostname = std::env::var("PDS_HOSTNAME").ok()?; 46 + let hostname = pds_hostname(); 47 + if hostname == "localhost" { 48 + return None; 49 + } 44 50 45 51 let crawler_urls: Vec<String> = std::env::var("CRAWLERS") 46 52 .unwrap_or_default() ··· 53 59 return None; 54 60 } 55 61 56 - Some(Self::new(hostname, crawler_urls)) 62 + Some(Self::new(hostname.to_string(), crawler_urls)) 57 63 } 58 64 59 65 fn should_notify(&self) -> bool { ··· 143 149 pub async fn start_crawlers_service( 144 150 crawlers: Arc<Crawlers>, 145 151 mut firehose_rx: broadcast::Receiver<SequencedEvent>, 146 - mut shutdown: watch::Receiver<bool>, 152 + shutdown: CancellationToken, 147 153 ) { 148 154 info!( 149 155 hostname = %crawlers.hostname, ··· 157 163 result = firehose_rx.recv() => { 158 164 match result { 159 165 Ok(event) => { 160 - if event.event_type == "commit" { 166 + if event.event_type == RepoEventType::Commit { 161 167 crawlers.notify_of_update().await; 162 168 } 163 169 } ··· 171 177 } 172 178 } 173 179 } 174 - _ = shutdown.changed() => { 175 - if *shutdown.borrow() { 176 - info!("Crawlers service shutting down"); 177 - break; 178 - } 180 + _ = shutdown.cancelled() => { 181 + info!("Crawlers service shutting down"); 182 + break; 179 183 } 180 184 } 181 185 }
+9 -1
crates/tranquil-pds/src/delegation/mod.rs
··· 1 + pub mod roles; 1 2 pub mod scopes; 2 3 3 - pub use scopes::{SCOPE_PRESETS, ScopePreset, intersect_scopes}; 4 + pub use roles::{ 5 + CanAddControllers, CanBeController, CanControlAccounts, verify_can_add_controllers, 6 + verify_can_be_controller, verify_can_control_accounts, 7 + }; 8 + pub use scopes::{ 9 + InvalidDelegationScopeError, SCOPE_PRESETS, ScopePreset, ValidatedDelegationScope, 10 + intersect_scopes, validate_delegation_scopes, 11 + }; 4 12 pub use tranquil_db_traits::DelegationActionType;
+108
crates/tranquil-pds/src/delegation/roles.rs
··· 1 + use axum::response::{IntoResponse, Response}; 2 + 3 + use crate::api::error::ApiError; 4 + use crate::auth::AuthenticatedUser; 5 + use crate::state::AppState; 6 + use crate::types::Did; 7 + 8 + pub struct CanAddControllers<'a> { 9 + user: &'a AuthenticatedUser, 10 + } 11 + 12 + pub struct CanControlAccounts<'a> { 13 + user: &'a AuthenticatedUser, 14 + } 15 + 16 + pub struct CanBeController<'a> { 17 + controller_did: &'a Did, 18 + } 19 + 20 + impl<'a> CanAddControllers<'a> { 21 + pub fn did(&self) -> &Did { 22 + &self.user.did 23 + } 24 + 25 + pub fn user(&self) -> &AuthenticatedUser { 26 + self.user 27 + } 28 + } 29 + 30 + impl<'a> CanControlAccounts<'a> { 31 + pub fn did(&self) -> &Did { 32 + &self.user.did 33 + } 34 + 35 + pub fn user(&self) -> &AuthenticatedUser { 36 + self.user 37 + } 38 + } 39 + 40 + impl<'a> CanBeController<'a> { 41 + pub fn did(&self) -> &Did { 42 + self.controller_did 43 + } 44 + } 45 + 46 + pub async fn verify_can_add_controllers<'a>( 47 + state: &AppState, 48 + user: &'a AuthenticatedUser, 49 + ) -> Result<CanAddControllers<'a>, Response> { 50 + match state.delegation_repo.controls_any_accounts(&user.did).await { 51 + Ok(true) => Err(ApiError::InvalidDelegation( 52 + "Cannot add controllers to an account that controls other accounts".into(), 53 + ) 54 + .into_response()), 55 + Ok(false) => Ok(CanAddControllers { user }), 56 + Err(e) => { 57 + tracing::error!("Failed to check delegation status: {:?}", e); 58 + Err( 59 + ApiError::InternalError(Some("Failed to verify delegation status".into())) 60 + .into_response(), 61 + ) 62 + } 63 + } 64 + } 65 + 66 + pub async fn verify_can_control_accounts<'a>( 67 + state: &AppState, 68 + user: &'a AuthenticatedUser, 69 + ) -> Result<CanControlAccounts<'a>, Response> { 70 + match state.delegation_repo.has_any_controllers(&user.did).await { 71 + Ok(true) => Err(ApiError::InvalidDelegation( 72 + "Cannot create delegated accounts from a controlled account".into(), 73 + ) 74 + .into_response()), 75 + Ok(false) => Ok(CanControlAccounts { user }), 76 + Err(e) => { 77 + tracing::error!("Failed to check controller status: {:?}", e); 78 + Err( 79 + ApiError::InternalError(Some("Failed to verify controller status".into())) 80 + .into_response(), 81 + ) 82 + } 83 + } 84 + } 85 + 86 + pub async fn verify_can_be_controller<'a>( 87 + state: &AppState, 88 + controller_did: &'a Did, 89 + ) -> Result<CanBeController<'a>, Response> { 90 + match state 91 + .delegation_repo 92 + .has_any_controllers(controller_did) 93 + .await 94 + { 95 + Ok(true) => Err(ApiError::InvalidDelegation( 96 + "Cannot add a controlled account as a controller".into(), 97 + ) 98 + .into_response()), 99 + Ok(false) => Ok(CanBeController { controller_did }), 100 + Err(e) => { 101 + tracing::error!("Failed to check controller status: {:?}", e); 102 + Err( 103 + ApiError::InternalError(Some("Failed to verify controller status".into())) 104 + .into_response(), 105 + ) 106 + } 107 + } 108 + }
+7 -29
crates/tranquil-pds/src/delegation/scopes.rs
··· 1 1 use std::collections::HashSet; 2 2 3 + pub use tranquil_db_traits::{ 4 + DbScope as ValidatedDelegationScope, InvalidScopeError as InvalidDelegationScopeError, 5 + }; 6 + 3 7 pub struct ScopePreset { 4 8 pub name: &'static str, 5 9 pub label: &'static str, ··· 107 111 } 108 112 } 109 113 110 - pub fn validate_delegation_scopes(scopes: &str) -> Result<(), String> { 111 - if scopes.is_empty() { 112 - return Ok(()); 113 - } 114 - 115 - scopes.split_whitespace().try_for_each(|scope| { 116 - let (base, _) = split_scope(scope); 117 - if is_valid_scope_prefix(base) { 118 - Ok(()) 119 - } else { 120 - Err(format!("Invalid scope: {}", scope)) 121 - } 122 - }) 123 - } 124 - 125 - fn is_valid_scope_prefix(base: &str) -> bool { 126 - const VALID_PREFIXES: [&str; 7] = [ 127 - "atproto", 128 - "repo:", 129 - "blob:", 130 - "rpc:", 131 - "account:", 132 - "identity:", 133 - "transition:", 134 - ]; 135 - 136 - VALID_PREFIXES 137 - .iter() 138 - .any(|prefix| base == prefix.trim_end_matches(':') || base.starts_with(prefix)) 114 + pub fn validate_delegation_scopes(scopes: &str) -> Result<(), InvalidDelegationScopeError> { 115 + ValidatedDelegationScope::new(scopes)?; 116 + Ok(()) 139 117 } 140 118 141 119 #[cfg(test)]
+2 -1
crates/tranquil-pds/src/lib.rs
··· 2 2 pub mod appview; 3 3 pub mod auth; 4 4 pub mod cache; 5 + pub mod cid_types; 5 6 pub mod circuit_breaker; 6 7 pub mod comms; 7 8 pub mod config; ··· 35 36 use http::StatusCode; 36 37 use serde_json::json; 37 38 use state::AppState; 38 - pub use sync::util::AccountStatus; 39 39 use tower::ServiceBuilder; 40 40 use tower_http::cors::{Any, CorsLayer}; 41 + pub use tranquil_db_traits::AccountStatus; 41 42 pub use types::{AccountState, AtIdentifier, AtUri, Did, Handle, Nsid, Rkey}; 42 43 43 44 pub fn app(state: AppState) -> Router {
+50 -37
crates/tranquil-pds/src/main.rs
··· 1 1 use std::net::SocketAddr; 2 2 use std::process::ExitCode; 3 3 use std::sync::Arc; 4 - use tokio::sync::watch; 4 + use tokio_util::sync::CancellationToken; 5 5 use tracing::{error, info, warn}; 6 6 use tranquil_pds::comms::{CommsService, DiscordSender, EmailSender, SignalSender, TelegramSender}; 7 7 ··· 34 34 } 35 35 36 36 async fn run() -> Result<(), Box<dyn std::error::Error>> { 37 - let state = AppState::new().await?; 38 - tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await; 37 + let shutdown = CancellationToken::new(); 38 + 39 + let shutdown_for_panic = shutdown.clone(); 40 + let default_panic_hook = std::panic::take_hook(); 41 + std::panic::set_hook(Box::new(move |info| { 42 + error!("PANIC: {}", info); 43 + shutdown_for_panic.cancel(); 44 + default_panic_hook(info); 45 + })); 39 46 40 - let (shutdown_tx, shutdown_rx) = watch::channel(false); 47 + spawn_signal_handler(shutdown.clone()); 48 + 49 + let state = AppState::new(shutdown.clone()).await?; 50 + tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await; 41 51 42 52 let backfill_repo_repo = state.repo_repo.clone(); 43 53 let backfill_block_store = state.block_store.clone(); ··· 77 87 comms_service = comms_service.register_sender(signal_sender); 78 88 } 79 89 80 - let comms_handle = tokio::spawn(comms_service.run(shutdown_rx.clone())); 90 + let comms_handle = tokio::spawn(comms_service.run(shutdown.clone())); 81 91 82 92 let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() { 83 93 let crawlers = Arc::new( ··· 88 98 Some(tokio::spawn(start_crawlers_service( 89 99 crawlers, 90 100 firehose_rx, 91 - shutdown_rx.clone(), 101 + shutdown.clone(), 92 102 ))) 93 103 } else { 94 104 warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)"); ··· 102 112 state.backup_repo.clone(), 103 113 state.block_store.clone(), 104 114 backup_storage, 105 - shutdown_rx.clone(), 115 + shutdown.clone(), 106 116 ))) 107 117 } else { 108 118 warn!("Backup service disabled (BACKUP_S3_BUCKET not set or BACKUP_ENABLED=false)"); ··· 114 124 state.blob_repo.clone(), 115 125 state.blob_store.clone(), 116 126 state.sso_repo.clone(), 117 - shutdown_rx, 127 + shutdown.clone(), 118 128 )); 119 129 120 130 let app = tranquil_pds::app(state); ··· 136 146 .map_err(|e| format!("Failed to bind to {}: {}", addr, e))?; 137 147 138 148 let server_result = axum::serve(listener, app) 139 - .with_graceful_shutdown(shutdown_signal(shutdown_tx)) 149 + .with_graceful_shutdown(shutdown.clone().cancelled_owned()) 140 150 .await; 141 151 142 152 comms_handle.await.ok(); ··· 158 168 Ok(()) 159 169 } 160 170 161 - async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) { 162 - let ctrl_c = async { 163 - match tokio::signal::ctrl_c().await { 164 - Ok(()) => {} 165 - Err(e) => { 166 - error!("Failed to install Ctrl+C handler: {}", e); 167 - } 168 - } 169 - }; 170 - 171 - #[cfg(unix)] 172 - let terminate = async { 173 - match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) { 174 - Ok(mut signal) => { 175 - signal.recv().await; 171 + fn spawn_signal_handler(shutdown: CancellationToken) { 172 + tokio::spawn(async move { 173 + let ctrl_c = async { 174 + match tokio::signal::ctrl_c().await { 175 + Ok(()) => {} 176 + Err(e) => { 177 + error!("Failed to install Ctrl+C handler: {}", e); 178 + std::future::pending::<()>().await; 179 + } 176 180 } 177 - Err(e) => { 178 - error!("Failed to install SIGTERM handler: {}", e); 179 - std::future::pending::<()>().await; 181 + }; 182 + 183 + #[cfg(unix)] 184 + let terminate = async { 185 + match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) { 186 + Ok(mut signal) => { 187 + signal.recv().await; 188 + } 189 + Err(e) => { 190 + error!("Failed to install SIGTERM handler: {}", e); 191 + std::future::pending::<()>().await; 192 + } 180 193 } 181 - } 182 - }; 194 + }; 183 195 184 - #[cfg(not(unix))] 185 - let terminate = std::future::pending::<()>(); 196 + #[cfg(not(unix))] 197 + let terminate = std::future::pending::<()>(); 186 198 187 - tokio::select! { 188 - _ = ctrl_c => {}, 189 - _ = terminate => {}, 190 - } 199 + tokio::select! { 200 + _ = ctrl_c => {}, 201 + _ = terminate => {}, 202 + } 191 203 192 - info!("Shutdown signal received, stopping services..."); 193 - shutdown_tx.send(true).ok(); 204 + info!("Shutdown signal received, stopping services..."); 205 + shutdown.cancel(); 206 + }); 194 207 }
+113 -331
crates/tranquil-pds/src/oauth/endpoints/authorize.rs
··· 1 + use crate::auth::{BareLoginIdentifier, NormalizedLoginIdentifier}; 1 2 use crate::comms::{channel_display_name, comms_repo::enqueue_2fa_code}; 2 3 use crate::oauth::{ 3 - AuthFlowState, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, SessionId, 4 + AuthFlow, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, Prompt, SessionId, 4 5 db::should_show_consent, scopes::expand_include_scopes, 5 6 }; 6 - use crate::state::{AppState, RateLimitKind}; 7 + use crate::rate_limit::{ 8 + OAuthAuthorizeLimit, OAuthRateLimited, OAuthRegisterCompleteLimit, TotpVerifyLimit, 9 + check_user_rate_limit, 10 + }; 11 + use crate::state::AppState; 7 12 use crate::types::{Did, Handle, PlainPassword}; 13 + use crate::util::{extract_client_ip, pds_hostname, pds_hostname_without_port}; 8 14 use axum::{ 9 15 Json, 10 16 extract::{Query, State}, ··· 79 85 || s.starts_with("include:") 80 86 } 81 87 82 - fn validate_auth_flow_state( 83 - flow_state: &AuthFlowState, 84 - require_authenticated: bool, 85 - ) -> Option<Response> { 86 - if flow_state.is_expired() { 87 - return Some(json_error( 88 - StatusCode::BAD_REQUEST, 89 - "invalid_request", 90 - "Authorization request has expired", 91 - )); 92 - } 93 - if require_authenticated && flow_state.is_pending() { 94 - return Some(json_error( 95 - StatusCode::FORBIDDEN, 96 - "access_denied", 97 - "Not authenticated", 98 - )); 99 - } 100 - None 101 - } 102 - 103 88 fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 104 89 headers 105 90 .get("cookie") ··· 113 98 }) 114 99 } 115 100 116 - fn extract_client_ip(headers: &HeaderMap) -> String { 117 - if let Some(forwarded) = headers.get("x-forwarded-for") 118 - && let Ok(value) = forwarded.to_str() 119 - && let Some(first_ip) = value.split(',').next() 120 - { 121 - return first_ip.trim().to_string(); 122 - } 123 - if let Some(real_ip) = headers.get("x-real-ip") 124 - && let Ok(value) = real_ip.to_str() 125 - { 126 - return value.trim().to_string(); 127 - } 128 - "0.0.0.0".to_string() 129 - } 130 - 131 101 fn extract_user_agent(headers: &HeaderMap) -> Option<String> { 132 102 headers 133 103 .get("user-agent") ··· 282 252 283 253 if let Some(ref login_hint) = request_data.parameters.login_hint { 284 254 tracing::info!(login_hint = %login_hint, "Checking login_hint for delegation"); 285 - let pds_hostname = 286 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 287 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 288 - let normalized = if login_hint.contains('@') || login_hint.starts_with("did:") { 289 - login_hint.clone() 290 - } else if !login_hint.contains('.') { 291 - format!("{}.{}", login_hint.to_lowercase(), hostname_for_handles) 292 - } else { 293 - login_hint.to_lowercase() 294 - }; 255 + let hostname_for_handles = pds_hostname_without_port(); 256 + let normalized = NormalizedLoginIdentifier::normalize(login_hint, hostname_for_handles); 295 257 tracing::info!(normalized = %normalized, "Normalized login_hint"); 296 258 297 259 match state 298 260 .user_repo 299 - .get_login_check_by_handle_or_email(&normalized) 261 + .get_login_check_by_handle_or_email(normalized.as_str()) 300 262 .await 301 263 { 302 264 Ok(Some(user)) => { ··· 340 302 tracing::info!("No login_hint in request"); 341 303 } 342 304 343 - if request_data.parameters.prompt.as_deref() == Some("create") { 305 + if request_data.parameters.prompt == Some(Prompt::Create) { 344 306 return redirect_see_other(&format!( 345 307 "/app/oauth/register?request_uri={}", 346 308 url_encode(&request_uri) ··· 485 447 486 448 pub async fn authorize_post( 487 449 State(state): State<AppState>, 450 + _rate_limit: OAuthRateLimited<OAuthAuthorizeLimit>, 488 451 headers: HeaderMap, 489 452 Json(form): Json<AuthorizeSubmit>, 490 453 ) -> Response { 491 454 let json_response = wants_json(&headers); 492 - let client_ip = extract_client_ip(&headers); 493 - if !state 494 - .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 495 - .await 496 - { 497 - tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded"); 498 - if json_response { 499 - return ( 500 - axum::http::StatusCode::TOO_MANY_REQUESTS, 501 - Json(serde_json::json!({ 502 - "error": "RateLimitExceeded", 503 - "error_description": "Too many login attempts. Please try again later." 504 - })), 505 - ) 506 - .into_response(); 507 - } 508 - return redirect_to_frontend_error( 509 - "RateLimitExceeded", 510 - "Too many login attempts. Please try again later.", 511 - ); 512 - } 513 455 let form_request_id = RequestId::from(form.request_uri.clone()); 514 456 let request_data = match state 515 457 .oauth_repo ··· 584 526 url_encode(error_msg) 585 527 )) 586 528 }; 587 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 588 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 589 - let normalized_username = form.username.trim(); 590 - let normalized_username = normalized_username 591 - .strip_prefix('@') 592 - .unwrap_or(normalized_username); 593 - let normalized_username = if normalized_username.contains('@') { 594 - normalized_username.to_string() 595 - } else if !normalized_username.contains('.') { 596 - format!("{}.{}", normalized_username, hostname_for_handles) 597 - } else { 598 - normalized_username.to_string() 599 - }; 529 + let hostname_for_handles = pds_hostname_without_port(); 530 + let normalized_username = 531 + NormalizedLoginIdentifier::normalize(&form.username, hostname_for_handles); 600 532 tracing::debug!( 601 533 original_username = %form.username, 602 534 normalized_username = %normalized_username, 603 - pds_hostname = %pds_hostname, 535 + pds_hostname = %pds_hostname(), 604 536 "Normalized username for lookup" 605 537 ); 606 538 let user = match state 607 539 .user_repo 608 - .get_login_info_by_handle_or_email(&normalized_username) 540 + .get_login_info_by_handle_or_email(normalized_username.as_str()) 609 541 .await 610 542 { 611 543 Ok(Some(u)) => u, ··· 624 556 if user.takedown_ref.is_some() { 625 557 return show_login_error("This account has been taken down.", json_response); 626 558 } 627 - let is_verified = user.email_verified 628 - || user.discord_verified 629 - || user.telegram_verified 630 - || user.signal_verified; 559 + let is_verified = user.channel_verification.has_any_verified(); 631 560 if !is_verified { 632 561 return show_login_error( 633 562 "Please verify your account before logging in.", ··· 635 564 ); 636 565 } 637 566 638 - if user.account_type == "delegated" { 567 + if user.account_type.is_delegated() { 639 568 if state 640 569 .oauth_repo 641 570 .set_authorization_did(&form_request_id, &user.did, None) ··· 748 677 .await 749 678 { 750 679 Ok(challenge) => { 751 - let hostname = 752 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 680 + let hostname = pds_hostname(); 753 681 if let Err(e) = enqueue_2fa_code( 754 682 state.user_repo.as_ref(), 755 683 state.infra_repo.as_ref(), 756 684 user.id, 757 685 &challenge.code, 758 - &hostname, 686 + hostname, 759 687 ) 760 688 .await 761 689 { ··· 792 720 } else { 793 721 let new_id = DeviceId::generate(); 794 722 let device_data = DeviceData { 795 - session_id: SessionId::generate().0, 723 + session_id: SessionId::generate(), 796 724 user_agent: extract_user_agent(&headers), 797 - ip_address: extract_client_ip(&headers), 725 + ip_address: extract_client_ip(&headers, None), 798 726 last_seen_at: Utc::now(), 799 727 }; 800 728 let new_device_id_typed = DeviceIdType::from(new_id.0.clone()); ··· 888 816 &request_data.parameters.redirect_uri, 889 817 &code.0, 890 818 request_data.parameters.state.as_deref(), 891 - request_data.parameters.response_mode.as_deref(), 819 + request_data.parameters.response_mode.map(|m| m.as_str()), 892 820 ); 893 821 if let Some(cookie) = new_cookie { 894 822 ( ··· 905 833 &request_data.parameters.redirect_uri, 906 834 &code.0, 907 835 request_data.parameters.state.as_deref(), 908 - request_data.parameters.response_mode.as_deref(), 836 + request_data.parameters.response_mode.map(|m| m.as_str()), 909 837 ); 910 838 if let Some(cookie) = new_cookie { 911 839 ( ··· 1026 954 ); 1027 955 } 1028 956 }; 1029 - let is_verified = user.email_verified 1030 - || user.discord_verified 1031 - || user.telegram_verified 1032 - || user.signal_verified; 957 + let is_verified = user.channel_verification.has_any_verified(); 1033 958 if !is_verified { 1034 959 return json_error( 1035 960 StatusCode::FORBIDDEN, ··· 1068 993 .await 1069 994 { 1070 995 Ok(challenge) => { 1071 - let hostname = 1072 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 996 + let hostname = pds_hostname(); 1073 997 if let Err(e) = enqueue_2fa_code( 1074 998 state.user_repo.as_ref(), 1075 999 state.infra_repo.as_ref(), 1076 1000 user.id, 1077 1001 &challenge.code, 1078 - &hostname, 1002 + hostname, 1079 1003 ) 1080 1004 .await 1081 1005 { ··· 1169 1093 &request_data.parameters.redirect_uri, 1170 1094 &code.0, 1171 1095 request_data.parameters.state.as_deref(), 1172 - request_data.parameters.response_mode.as_deref(), 1096 + request_data.parameters.response_mode.map(|m| m.as_str()), 1173 1097 ); 1174 1098 Json(serde_json::json!({ 1175 1099 "redirect_uri": redirect_url ··· 1193 1117 '?' 1194 1118 }; 1195 1119 redirect_url.push(separator); 1196 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1120 + let pds_host = pds_hostname(); 1197 1121 redirect_url.push_str(&format!( 1198 1122 "iss={}", 1199 - url_encode(&format!("https://{}", pds_hostname)) 1123 + url_encode(&format!("https://{}", pds_host)) 1200 1124 )); 1201 1125 if let Some(req_state) = state { 1202 1126 redirect_url.push_str(&format!("&state={}", url_encode(req_state))); ··· 1211 1135 state: Option<&str>, 1212 1136 response_mode: Option<&str>, 1213 1137 ) -> String { 1214 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1138 + let pds_host = pds_hostname(); 1215 1139 let mut url = format!( 1216 1140 "https://{}/oauth/authorize/redirect?redirect_uri={}&code={}", 1217 - pds_hostname, 1141 + pds_host, 1218 1142 url_encode(redirect_uri), 1219 1143 url_encode(code) 1220 1144 ); ··· 1459 1383 ); 1460 1384 } 1461 1385 }; 1462 - let flow_state = AuthFlowState::from_request_data(&request_data); 1463 - 1464 - if let Some(err_response) = validate_auth_flow_state(&flow_state, true) { 1465 - if flow_state.is_expired() { 1386 + let flow_with_user = match AuthFlow::from_request_data(request_data.clone()) { 1387 + Ok(flow) => match flow.require_user() { 1388 + Ok(u) => u, 1389 + Err(_) => { 1390 + return json_error(StatusCode::FORBIDDEN, "access_denied", "Not authenticated"); 1391 + } 1392 + }, 1393 + Err(_) => { 1466 1394 let _ = state 1467 1395 .oauth_repo 1468 1396 .delete_authorization_request(&consent_request_id) 1469 1397 .await; 1470 - } 1471 - return err_response; 1472 - } 1473 - 1474 - let did_str = flow_state.did().unwrap().to_string(); 1475 - let did: Did = match did_str.parse() { 1476 - Ok(d) => d, 1477 - Err(_) => { 1478 1398 return json_error( 1479 1399 StatusCode::BAD_REQUEST, 1480 1400 "invalid_request", 1481 - "Invalid DID format in request.", 1401 + "Authorization request has expired", 1482 1402 ); 1483 1403 } 1484 1404 }; 1405 + 1406 + let did = flow_with_user.did().clone(); 1485 1407 let client_cache = ClientMetadataCache::new(3600); 1486 1408 let client_metadata = client_cache 1487 1409 .get(&request_data.parameters.client_id) ··· 1510 1432 }; 1511 1433 1512 1434 let effective_scope_str = if let Some(ref grant) = delegation_grant { 1513 - crate::delegation::intersect_scopes(requested_scope_str, &grant.granted_scopes) 1435 + crate::delegation::intersect_scopes(requested_scope_str, grant.granted_scopes.as_str()) 1514 1436 } else { 1515 1437 requested_scope_str.to_string() 1516 1438 }; ··· 1609 1531 let level = if let Some(ref grant) = delegation_grant { 1610 1532 let preset = crate::delegation::SCOPE_PRESETS 1611 1533 .iter() 1612 - .find(|p| p.scopes == grant.granted_scopes); 1534 + .find(|p| p.scopes == grant.granted_scopes.as_str()); 1613 1535 preset 1614 1536 .map(|p| p.label.to_string()) 1615 1537 .unwrap_or_else(|| "Custom".to_string()) ··· 1635 1557 logo_uri: client_metadata.as_ref().and_then(|m| m.logo_uri.clone()), 1636 1558 scopes, 1637 1559 show_consent, 1638 - did: did_str, 1560 + did: did.to_string(), 1639 1561 handle: account_handle, 1640 1562 is_delegation, 1641 1563 controller_did: controller_did_resp, ··· 1676 1598 ); 1677 1599 } 1678 1600 }; 1679 - let flow_state = AuthFlowState::from_request_data(&request_data); 1680 - 1681 - if flow_state.is_expired() { 1682 - let _ = state 1683 - .oauth_repo 1684 - .delete_authorization_request(&consent_post_request_id) 1685 - .await; 1686 - return json_error( 1687 - StatusCode::BAD_REQUEST, 1688 - "invalid_request", 1689 - "Authorization request has expired", 1690 - ); 1691 - } 1692 - if flow_state.is_pending() { 1693 - return json_error(StatusCode::FORBIDDEN, "access_denied", "Not authenticated"); 1694 - } 1695 - 1696 - let did_str = flow_state.did().unwrap().to_string(); 1697 - let did: Did = match did_str.parse() { 1698 - Ok(d) => d, 1601 + let flow_with_user = match AuthFlow::from_request_data(request_data.clone()) { 1602 + Ok(flow) => match flow.require_user() { 1603 + Ok(u) => u, 1604 + Err(_) => { 1605 + return json_error(StatusCode::FORBIDDEN, "access_denied", "Not authenticated"); 1606 + } 1607 + }, 1699 1608 Err(_) => { 1609 + let _ = state 1610 + .oauth_repo 1611 + .delete_authorization_request(&consent_post_request_id) 1612 + .await; 1700 1613 return json_error( 1701 1614 StatusCode::BAD_REQUEST, 1702 1615 "invalid_request", 1703 - "Invalid DID format", 1616 + "Authorization request has expired", 1704 1617 ); 1705 1618 } 1706 1619 }; 1620 + 1621 + let did = flow_with_user.did().clone(); 1707 1622 let original_scope_str = request_data 1708 1623 .parameters 1709 1624 .scope ··· 1726 1641 }; 1727 1642 1728 1643 let effective_scope_str = if let Some(ref grant) = delegation_grant { 1729 - crate::delegation::intersect_scopes(original_scope_str, &grant.granted_scopes) 1644 + crate::delegation::intersect_scopes(original_scope_str, grant.granted_scopes.as_str()) 1730 1645 } else { 1731 1646 original_scope_str.to_string() 1732 1647 }; ··· 1799 1714 let consent_post_device_id = request_data 1800 1715 .device_id 1801 1716 .as_ref() 1802 - .map(|d| DeviceIdType::from(d.clone())); 1717 + .map(|d| DeviceIdType::from(d.0.clone())); 1803 1718 let consent_post_code = AuthorizationCode::from(code.0.clone()); 1804 1719 if state 1805 1720 .oauth_repo ··· 1823 1738 redirect_uri, 1824 1739 &code.0, 1825 1740 request_data.parameters.state.as_deref(), 1826 - request_data.parameters.response_mode.as_deref(), 1741 + request_data.parameters.response_mode.map(|m| m.as_str()), 1827 1742 ); 1828 1743 tracing::info!( 1829 1744 intermediate_url = %intermediate_url, ··· 1835 1750 1836 1751 pub async fn authorize_2fa_post( 1837 1752 State(state): State<AppState>, 1753 + _rate_limit: OAuthRateLimited<OAuthAuthorizeLimit>, 1838 1754 headers: HeaderMap, 1839 1755 Json(form): Json<Authorize2faSubmit>, 1840 1756 ) -> Response { ··· 1848 1764 ) 1849 1765 .into_response() 1850 1766 }; 1851 - let client_ip = extract_client_ip(&headers); 1852 - if !state 1853 - .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 1854 - .await 1855 - { 1856 - tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 1857 - return json_error( 1858 - StatusCode::TOO_MANY_REQUESTS, 1859 - "RateLimitExceeded", 1860 - "Too many attempts. Please try again later.", 1861 - ); 1862 - } 1863 1767 let twofa_post_request_id = RequestId::from(form.request_uri.clone()); 1864 1768 let request_data = match state 1865 1769 .oauth_repo ··· 1956 1860 &request_data.parameters.redirect_uri, 1957 1861 &code.0, 1958 1862 request_data.parameters.state.as_deref(), 1959 - request_data.parameters.response_mode.as_deref(), 1863 + request_data.parameters.response_mode.map(|m| m.as_str()), 1960 1864 ); 1961 1865 return Json(serde_json::json!({ 1962 1866 "redirect_uri": redirect_url ··· 1990 1894 "No 2FA challenge found. Please start over.", 1991 1895 ); 1992 1896 } 1993 - if !state 1994 - .check_rate_limit(RateLimitKind::TotpVerify, &did) 1995 - .await 1996 - { 1997 - tracing::warn!(did = %did, "TOTP verification rate limit exceeded"); 1998 - return json_error( 1999 - StatusCode::TOO_MANY_REQUESTS, 2000 - "RateLimitExceeded", 2001 - "Too many verification attempts. Please try again in a few minutes.", 2002 - ); 2003 - } 1897 + let _rate_proof = match check_user_rate_limit::<TotpVerifyLimit>(&state, &did).await { 1898 + Ok(proof) => proof, 1899 + Err(_) => { 1900 + return json_error( 1901 + StatusCode::TOO_MANY_REQUESTS, 1902 + "RateLimitExceeded", 1903 + "Too many verification attempts. Please try again in a few minutes.", 1904 + ); 1905 + } 1906 + }; 2004 1907 let totp_valid = 2005 1908 crate::api::server::verify_totp_or_backup_for_user(&state, &did, &form.code).await; 2006 1909 if !totp_valid { ··· 2065 1968 &request_data.parameters.redirect_uri, 2066 1969 &code.0, 2067 1970 request_data.parameters.state.as_deref(), 2068 - request_data.parameters.response_mode.as_deref(), 1971 + request_data.parameters.response_mode.map(|m| m.as_str()), 2069 1972 ); 2070 1973 Json(serde_json::json!({ 2071 1974 "redirect_uri": redirect_url ··· 2089 1992 State(state): State<AppState>, 2090 1993 Query(query): Query<CheckPasskeysQuery>, 2091 1994 ) -> Response { 2092 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2093 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 2094 - let normalized_identifier = query.identifier.trim(); 2095 - let normalized_identifier = normalized_identifier 2096 - .strip_prefix('@') 2097 - .unwrap_or(normalized_identifier); 2098 - let normalized_identifier = if let Some(bare_handle) = 2099 - normalized_identifier.strip_suffix(&format!(".{}", hostname_for_handles)) 2100 - { 2101 - bare_handle.to_string() 2102 - } else { 2103 - normalized_identifier.to_string() 2104 - }; 1995 + let hostname_for_handles = pds_hostname_without_port(); 1996 + let bare_identifier = 1997 + BareLoginIdentifier::from_identifier(&query.identifier, hostname_for_handles); 2105 1998 2106 1999 let user = state 2107 2000 .user_repo 2108 - .get_login_check_by_handle_or_email(&normalized_identifier) 2001 + .get_login_check_by_handle_or_email(bare_identifier.as_str()) 2109 2002 .await; 2110 2003 2111 2004 let has_passkeys = match user { ··· 2131 2024 State(state): State<AppState>, 2132 2025 Query(query): Query<CheckPasskeysQuery>, 2133 2026 ) -> Response { 2134 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2135 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 2136 - let identifier = query.identifier.trim(); 2137 - let identifier = identifier.strip_prefix('@').unwrap_or(identifier); 2138 - let normalized_identifier = if identifier.contains('@') || identifier.starts_with("did:") { 2139 - identifier.to_string() 2140 - } else if !identifier.contains('.') { 2141 - format!("{}.{}", identifier.to_lowercase(), hostname_for_handles) 2142 - } else { 2143 - identifier.to_lowercase() 2144 - }; 2027 + let hostname_for_handles = pds_hostname_without_port(); 2028 + let normalized_identifier = 2029 + NormalizedLoginIdentifier::normalize(&query.identifier, hostname_for_handles); 2145 2030 2146 2031 let user = state 2147 2032 .user_repo 2148 - .get_login_check_by_handle_or_email(&normalized_identifier) 2033 + .get_login_check_by_handle_or_email(normalized_identifier.as_str()) 2149 2034 .await; 2150 2035 2151 2036 let (has_passkeys, has_totp, has_password, is_delegated, did): ( ··· 2200 2085 2201 2086 pub async fn passkey_start( 2202 2087 State(state): State<AppState>, 2203 - headers: HeaderMap, 2088 + _rate_limit: OAuthRateLimited<OAuthAuthorizeLimit>, 2204 2089 Json(form): Json<PasskeyStartInput>, 2205 2090 ) -> Response { 2206 - let client_ip = extract_client_ip(&headers); 2207 - 2208 - if !state 2209 - .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 2210 - .await 2211 - { 2212 - tracing::warn!(ip = %client_ip, "OAuth passkey rate limit exceeded"); 2213 - return ( 2214 - StatusCode::TOO_MANY_REQUESTS, 2215 - Json(serde_json::json!({ 2216 - "error": "RateLimitExceeded", 2217 - "error_description": "Too many login attempts. Please try again later." 2218 - })), 2219 - ) 2220 - .into_response(); 2221 - } 2222 - 2223 2091 let passkey_start_request_id = RequestId::from(form.request_uri.clone()); 2224 2092 let request_data = match state 2225 2093 .oauth_repo ··· 2264 2132 .into_response(); 2265 2133 } 2266 2134 2267 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2268 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 2269 - let normalized_username = form.identifier.trim(); 2270 - let normalized_username = normalized_username 2271 - .strip_prefix('@') 2272 - .unwrap_or(normalized_username); 2273 - let normalized_username = if normalized_username.contains('@') { 2274 - normalized_username.to_string() 2275 - } else if !normalized_username.contains('.') { 2276 - format!("{}.{}", normalized_username, hostname_for_handles) 2277 - } else { 2278 - normalized_username.to_string() 2279 - }; 2135 + let hostname_for_handles = pds_hostname_without_port(); 2136 + let normalized_username = 2137 + NormalizedLoginIdentifier::normalize(&form.identifier, hostname_for_handles); 2280 2138 2281 2139 let user = match state 2282 2140 .user_repo 2283 - .get_login_info_by_handle_or_email(&normalized_username) 2141 + .get_login_info_by_handle_or_email(normalized_username.as_str()) 2284 2142 .await 2285 2143 { 2286 2144 Ok(Some(u)) => u, ··· 2328 2186 .into_response(); 2329 2187 } 2330 2188 2331 - let is_verified = user.email_verified 2332 - || user.discord_verified 2333 - || user.telegram_verified 2334 - || user.signal_verified; 2189 + let is_verified = user.channel_verification.has_any_verified(); 2335 2190 2336 2191 if !is_verified { 2337 2192 return ( ··· 2386 2241 .into_response(); 2387 2242 } 2388 2243 2389 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 2390 - Ok(w) => w, 2391 - Err(e) => { 2392 - tracing::error!(error = %e, "Failed to create WebAuthn config"); 2393 - return ( 2394 - StatusCode::INTERNAL_SERVER_ERROR, 2395 - Json(serde_json::json!({ 2396 - "error": "server_error", 2397 - "error_description": "WebAuthn configuration failed." 2398 - })), 2399 - ) 2400 - .into_response(); 2401 - } 2402 - }; 2403 - 2404 - let (rcr, auth_state) = match webauthn.start_authentication(passkeys) { 2244 + let (rcr, auth_state) = match state.webauthn_config.start_authentication(passkeys) { 2405 2245 Ok(result) => result, 2406 2246 Err(e) => { 2407 2247 tracing::error!(error = %e, "Failed to start passkey authentication"); ··· 2680 2520 } 2681 2521 }; 2682 2522 2683 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2684 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 2685 - Ok(w) => w, 2686 - Err(e) => { 2687 - tracing::error!(error = %e, "Failed to create WebAuthn config"); 2688 - return ( 2689 - StatusCode::INTERNAL_SERVER_ERROR, 2690 - Json(serde_json::json!({ 2691 - "error": "server_error", 2692 - "error_description": "WebAuthn configuration failed." 2693 - })), 2694 - ) 2695 - .into_response(); 2696 - } 2697 - }; 2698 - 2699 - let auth_result = match webauthn.finish_authentication(&credential, &auth_state) { 2523 + let auth_result = match state 2524 + .webauthn_config 2525 + .finish_authentication(&credential, &auth_state) 2526 + { 2700 2527 Ok(r) => r, 2701 2528 Err(e) => { 2702 2529 tracing::warn!(error = %e, did = %did, "Failed to verify passkey authentication"); ··· 2769 2596 .await 2770 2597 { 2771 2598 Ok(challenge) => { 2772 - let hostname = 2773 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2599 + let hostname = pds_hostname(); 2774 2600 if let Err(e) = enqueue_2fa_code( 2775 2601 state.user_repo.as_ref(), 2776 2602 state.infra_repo.as_ref(), 2777 2603 user.id, 2778 2604 &challenge.code, 2779 - &hostname, 2605 + hostname, 2780 2606 ) 2781 2607 .await 2782 2608 { ··· 2859 2685 &request_data.parameters.redirect_uri, 2860 2686 &code.0, 2861 2687 request_data.parameters.state.as_deref(), 2862 - request_data.parameters.response_mode.as_deref(), 2688 + request_data.parameters.response_mode.map(|m| m.as_str()), 2863 2689 ); 2864 2690 2865 2691 Json(serde_json::json!({ ··· 2884 2710 State(state): State<AppState>, 2885 2711 Query(query): Query<AuthorizePasskeyQuery>, 2886 2712 ) -> Response { 2887 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2888 - 2889 2713 let auth_passkey_start_request_id = RequestId::from(query.request_uri.clone()); 2890 2714 let request_data = match state 2891 2715 .oauth_repo ··· 2994 2818 .into_response(); 2995 2819 } 2996 2820 2997 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 2998 - Ok(w) => w, 2999 - Err(e) => { 3000 - tracing::error!("Failed to create WebAuthn config: {:?}", e); 3001 - return ( 3002 - StatusCode::INTERNAL_SERVER_ERROR, 3003 - Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 3004 - ) 3005 - .into_response(); 3006 - } 3007 - }; 3008 - 3009 - let (rcr, auth_state) = match webauthn.start_authentication(passkeys) { 2821 + let (rcr, auth_state) = match state.webauthn_config.start_authentication(passkeys) { 3010 2822 Ok(result) => result, 3011 2823 Err(e) => { 3012 2824 tracing::error!("Failed to start passkey authentication: {:?}", e); ··· 3063 2875 headers: HeaderMap, 3064 2876 Json(form): Json<AuthorizePasskeySubmit>, 3065 2877 ) -> Response { 3066 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2878 + let pds_hostname = pds_hostname(); 3067 2879 let passkey_finish_request_id = RequestId::from(form.request_uri.clone()); 3068 2880 3069 2881 let request_data = match state ··· 3193 3005 } 3194 3006 }; 3195 3007 3196 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 3197 - Ok(w) => w, 3198 - Err(e) => { 3199 - tracing::error!("Failed to create WebAuthn config: {:?}", e); 3200 - return ( 3201 - StatusCode::INTERNAL_SERVER_ERROR, 3202 - Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 3203 - ) 3204 - .into_response(); 3205 - } 3206 - }; 3207 - 3208 - let auth_result = match webauthn.finish_authentication(&credential, &auth_state) { 3008 + let auth_result = match state 3009 + .webauthn_config 3010 + .finish_authentication(&credential, &auth_state) 3011 + { 3209 3012 Ok(r) => r, 3210 3013 Err(e) => { 3211 3014 tracing::warn!("Passkey authentication failed: {:?}", e); ··· 3292 3095 state.infra_repo.as_ref(), 3293 3096 user.id, 3294 3097 &challenge.code, 3295 - &pds_hostname, 3098 + pds_hostname, 3296 3099 ) 3297 3100 .await 3298 3101 { ··· 3347 3150 3348 3151 pub async fn register_complete( 3349 3152 State(state): State<AppState>, 3350 - headers: HeaderMap, 3153 + _rate_limit: OAuthRateLimited<OAuthRegisterCompleteLimit>, 3351 3154 Json(form): Json<RegisterCompleteInput>, 3352 3155 ) -> Response { 3353 - let client_ip = extract_client_ip(&headers); 3354 - 3355 - if !state 3356 - .check_rate_limit(RateLimitKind::OAuthRegisterComplete, &client_ip) 3357 - .await 3358 - { 3359 - return ( 3360 - StatusCode::TOO_MANY_REQUESTS, 3361 - Json(serde_json::json!({ 3362 - "error": "RateLimitExceeded", 3363 - "error_description": "Too many attempts. Please try again later." 3364 - })), 3365 - ) 3366 - .into_response(); 3367 - } 3368 - 3369 3156 let did = Did::from(form.did.clone()); 3370 3157 3371 3158 let request_id = RequestId::from(form.request_uri.clone()); ··· 3417 3204 .into_response(); 3418 3205 } 3419 3206 3420 - if request_data.parameters.prompt.as_deref() != Some("create") { 3207 + if request_data.parameters.prompt != Some(Prompt::Create) { 3421 3208 tracing::warn!( 3422 3209 request_uri = %form.request_uri, 3423 3210 prompt = ?request_data.parameters.prompt, ··· 3506 3293 } 3507 3294 3508 3295 let is_verified = match state.user_repo.get_session_info_by_did(&did).await { 3509 - Ok(Some(info)) => { 3510 - info.email_verified 3511 - || info.discord_verified 3512 - || info.telegram_verified 3513 - || info.signal_verified 3514 - } 3296 + Ok(Some(info)) => info.channel_verification.has_any_verified(), 3515 3297 Ok(None) => { 3516 3298 return ( 3517 3299 StatusCode::FORBIDDEN, ··· 3636 3418 &request_data.parameters.redirect_uri, 3637 3419 &code.0, 3638 3420 request_data.parameters.state.as_deref(), 3639 - request_data.parameters.response_mode.as_deref(), 3421 + request_data.parameters.response_mode.map(|m| m.as_str()), 3640 3422 ); 3641 3423 Json(serde_json::json!({"redirect_uri": redirect_url})).into_response() 3642 3424 } ··· 3662 3444 None => { 3663 3445 let new_id = DeviceId::generate(); 3664 3446 let device_data = DeviceData { 3665 - session_id: SessionId::generate().0, 3447 + session_id: SessionId::generate(), 3666 3448 user_agent: extract_user_agent(&headers), 3667 - ip_address: extract_client_ip(&headers), 3449 + ip_address: extract_client_ip(&headers, None), 3668 3450 last_seen_at: Utc::now(), 3669 3451 }; 3670 3452 let device_typed = DeviceIdType::from(new_id.0.clone());
+32 -64
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
··· 1 1 use crate::auth::{Active, Auth}; 2 2 use crate::delegation::DelegationActionType; 3 - use crate::state::{AppState, RateLimitKind}; 3 + use crate::rate_limit::{LoginLimit, OAuthRateLimited, TotpVerifyLimit}; 4 + use crate::state::AppState; 4 5 use crate::types::PlainPassword; 5 6 use crate::util::extract_client_ip; 6 7 use axum::{ 7 8 Json, 8 9 extract::State, 9 - http::{HeaderMap, StatusCode}, 10 + http::HeaderMap, 10 11 response::{IntoResponse, Response}, 11 12 }; 12 13 use serde::{Deserialize, Serialize}; ··· 35 36 36 37 pub async fn delegation_auth( 37 38 State(state): State<AppState>, 39 + rate_limit: OAuthRateLimited<LoginLimit>, 38 40 headers: HeaderMap, 39 41 Json(form): Json<DelegationAuthSubmit>, 40 42 ) -> Response { 41 - let client_ip = extract_client_ip(&headers); 42 - if !state 43 - .check_rate_limit(RateLimitKind::Login, &client_ip) 44 - .await 45 - { 46 - return ( 47 - StatusCode::TOO_MANY_REQUESTS, 48 - Json(DelegationAuthResponse { 49 - success: false, 50 - needs_totp: None, 51 - redirect_uri: None, 52 - error: Some("Too many login attempts. Please try again later.".to_string()), 53 - }), 54 - ) 55 - .into_response(); 56 - } 57 - 43 + let client_ip = rate_limit.client_ip(); 58 44 let request_id = RequestId::from(form.request_uri.clone()); 59 45 let request = match state 60 46 .oauth_repo ··· 82 68 } 83 69 }; 84 70 85 - let delegated_did_str = match form.delegated_did.as_ref().or(request.did.as_ref()) { 86 - Some(did) => did.clone(), 87 - None => { 88 - return Json(DelegationAuthResponse { 89 - success: false, 90 - needs_totp: None, 91 - redirect_uri: None, 92 - error: Some("No delegated account selected".to_string()), 93 - }) 94 - .into_response(); 95 - } 96 - }; 97 - 98 - let delegated_did: Did = match delegated_did_str.parse() { 99 - Ok(d) => d, 100 - Err(_) => { 101 - return Json(DelegationAuthResponse { 102 - success: false, 103 - needs_totp: None, 104 - redirect_uri: None, 105 - error: Some("Invalid delegated DID".to_string()), 106 - }) 107 - .into_response(); 71 + let delegated_did: Did = if let Some(did_str) = form.delegated_did.as_ref() { 72 + match did_str.parse() { 73 + Ok(d) => d, 74 + Err(_) => { 75 + return Json(DelegationAuthResponse { 76 + success: false, 77 + needs_totp: None, 78 + redirect_uri: None, 79 + error: Some("Invalid delegated DID".to_string()), 80 + }) 81 + .into_response(); 82 + } 108 83 } 84 + } else if let Some(did) = request.did.as_ref() { 85 + did.clone() 86 + } else { 87 + return Json(DelegationAuthResponse { 88 + success: false, 89 + needs_totp: None, 90 + redirect_uri: None, 91 + error: Some("No delegated account selected".to_string()), 92 + }) 93 + .into_response(); 109 94 }; 110 95 111 96 let controller_did: Did = match form.controller_did.parse() { ··· 249 234 .into_response(); 250 235 } 251 236 252 - let ip = extract_client_ip(&headers); 253 237 let user_agent = headers 254 238 .get("user-agent") 255 239 .and_then(|v| v.to_str().ok()) ··· 266 250 "client_id": request.client_id, 267 251 "granted_scopes": grant.granted_scopes 268 252 })), 269 - Some(&ip), 253 + Some(client_ip), 270 254 user_agent.as_deref(), 271 255 ) 272 256 .await; ··· 291 275 292 276 pub async fn delegation_totp_verify( 293 277 State(state): State<AppState>, 278 + rate_limit: OAuthRateLimited<TotpVerifyLimit>, 294 279 headers: HeaderMap, 295 280 Json(form): Json<DelegationTotpSubmit>, 296 281 ) -> Response { 297 - let client_ip = extract_client_ip(&headers); 298 - if !state 299 - .check_rate_limit(RateLimitKind::TotpVerify, &client_ip) 300 - .await 301 - { 302 - return ( 303 - StatusCode::TOO_MANY_REQUESTS, 304 - Json(DelegationAuthResponse { 305 - success: false, 306 - needs_totp: None, 307 - redirect_uri: None, 308 - error: Some("Too many verification attempts. Please try again later.".to_string()), 309 - }), 310 - ) 311 - .into_response(); 312 - } 313 - 282 + let client_ip = rate_limit.client_ip(); 314 283 let totp_request_id = RequestId::from(form.request_uri.clone()); 315 284 let request = match state 316 285 .oauth_repo ··· 420 389 .into_response(); 421 390 } 422 391 423 - let ip = extract_client_ip(&headers); 424 392 let user_agent = headers 425 393 .get("user-agent") 426 394 .and_then(|v| v.to_str().ok()) ··· 437 405 "client_id": request.client_id, 438 406 "granted_scopes": grant.granted_scopes 439 407 })), 440 - Some(&ip), 408 + Some(client_ip), 441 409 user_agent.as_deref(), 442 410 ) 443 411 .await; ··· 564 532 .into_response(); 565 533 } 566 534 567 - let ip = extract_client_ip(&headers); 535 + let ip = extract_client_ip(&headers, None); 568 536 let user_agent = headers 569 537 .get("user-agent") 570 538 .and_then(|v| v.to_str().ok())
+3 -2
crates/tranquil-pds/src/oauth/endpoints/metadata.rs
··· 1 1 use crate::oauth::jwks::{JwkSet, create_jwk_set}; 2 2 use crate::state::AppState; 3 + use crate::util::pds_hostname; 3 4 use axum::{Json, extract::State}; 4 5 use serde::{Deserialize, Serialize}; 5 6 ··· 57 58 pub async fn oauth_protected_resource( 58 59 State(_state): State<AppState>, 59 60 ) -> Json<ProtectedResourceMetadata> { 60 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 61 + let pds_hostname = pds_hostname(); 61 62 let public_url = format!("https://{}", pds_hostname); 62 63 Json(ProtectedResourceMetadata { 63 64 resource: public_url.clone(), ··· 71 72 pub async fn oauth_authorization_server( 72 73 State(_state): State<AppState>, 73 74 ) -> Json<AuthorizationServerMetadata> { 74 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 75 + let pds_hostname = pds_hostname(); 75 76 let issuer = format!("https://{}", pds_hostname); 76 77 Json(AuthorizationServerMetadata { 77 78 issuer: issuer.clone(),
+58 -50
crates/tranquil-pds/src/oauth/endpoints/par.rs
··· 1 1 use crate::oauth::{ 2 - AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, OAuthError, RequestData, 3 - RequestId, 2 + AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, CodeChallengeMethod, 3 + OAuthError, Prompt, RequestData, RequestId, ResponseMode, ResponseType, 4 4 scopes::{ParsedScope, parse_scope}, 5 5 }; 6 - use crate::state::{AppState, RateLimitKind}; 6 + use crate::rate_limit::{OAuthParLimit, OAuthRateLimited}; 7 + use crate::state::AppState; 7 8 use axum::body::Bytes; 8 9 use axum::{Json, extract::State, http::HeaderMap}; 9 10 use chrono::{Duration, Utc}; ··· 49 50 50 51 pub async fn pushed_authorization_request( 51 52 State(state): State<AppState>, 53 + _rate_limit: OAuthRateLimited<OAuthParLimit>, 52 54 headers: HeaderMap, 53 55 body: Bytes, 54 56 ) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { ··· 70 72 .to_string(), 71 73 )); 72 74 }; 73 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 74 - if !state 75 - .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) 76 - .await 77 - { 78 - tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 79 - return Err(OAuthError::RateLimited); 80 - } 81 - if request.response_type != "code" { 82 - return Err(OAuthError::InvalidRequest( 83 - "response_type must be 'code'".to_string(), 84 - )); 85 - } 75 + let response_type = parse_response_type(&request.response_type)?; 86 76 let code_challenge = request 87 77 .code_challenge 88 78 .as_ref() 89 79 .filter(|s| !s.is_empty()) 90 80 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?; 91 - let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 92 - if code_challenge_method != "S256" { 93 - return Err(OAuthError::InvalidRequest( 94 - "code_challenge_method must be 'S256'".to_string(), 95 - )); 96 - } 81 + let code_challenge_method = 82 + parse_code_challenge_method(request.code_challenge_method.as_deref())?; 97 83 let client_cache = ClientMetadataCache::new(3600); 98 84 let client_metadata = client_cache.get(&request.client_id).await?; 99 85 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; ··· 101 87 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 102 88 let request_id = RequestId::generate(); 103 89 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 104 - let response_mode = match request.response_mode.as_deref() { 105 - Some("fragment") => Some("fragment".to_string()), 106 - Some("query") | None => None, 107 - Some(mode) => { 108 - return Err(OAuthError::InvalidRequest(format!( 109 - "Unsupported response_mode: {}", 110 - mode 111 - ))); 112 - } 113 - }; 114 - let prompt = validate_prompt(&request.prompt)?; 90 + let response_mode = parse_response_mode(request.response_mode.as_deref())?; 91 + let prompt = parse_prompt(request.prompt.as_deref())?; 115 92 let parameters = AuthorizationRequestParameters { 116 - response_type: request.response_type, 93 + response_type, 117 94 client_id: request.client_id.clone(), 118 95 redirect_uri: request.redirect_uri, 119 96 scope: validated_scope, 120 97 state: request.state, 121 98 code_challenge: code_challenge.clone(), 122 - code_challenge_method: code_challenge_method.to_string(), 99 + code_challenge_method, 123 100 response_mode, 124 101 login_hint: request.login_hint, 125 102 dpop_jkt: request.dpop_jkt, ··· 266 243 false 267 244 } 268 245 269 - fn validate_prompt(prompt: &Option<String>) -> Result<Option<String>, OAuthError> { 270 - const VALID_PROMPTS: &[&str] = &["none", "login", "consent", "select_account", "create"]; 246 + fn parse_response_type(value: &str) -> Result<ResponseType, OAuthError> { 247 + match value { 248 + "code" => Ok(ResponseType::Code), 249 + other => Err(OAuthError::InvalidRequest(format!( 250 + "response_type must be 'code', got '{}'", 251 + other 252 + ))), 253 + } 254 + } 271 255 272 - match prompt { 273 - None => Ok(None), 274 - Some(p) if p.is_empty() => Ok(None), 275 - Some(p) => { 276 - if VALID_PROMPTS.contains(&p.as_str()) { 277 - Ok(Some(p.clone())) 278 - } else { 279 - Err(OAuthError::InvalidRequest(format!( 280 - "Unsupported prompt value: {}", 281 - p 282 - ))) 283 - } 284 - } 256 + fn parse_code_challenge_method(value: Option<&str>) -> Result<CodeChallengeMethod, OAuthError> { 257 + match value { 258 + Some("S256") | None => Ok(CodeChallengeMethod::S256), 259 + Some("plain") => Err(OAuthError::InvalidRequest( 260 + "code_challenge_method 'plain' is not allowed, use 'S256'".to_string(), 261 + )), 262 + Some(other) => Err(OAuthError::InvalidRequest(format!( 263 + "Unsupported code_challenge_method: {}", 264 + other 265 + ))), 266 + } 267 + } 268 + 269 + fn parse_response_mode(value: Option<&str>) -> Result<Option<ResponseMode>, OAuthError> { 270 + match value { 271 + None | Some("query") => Ok(None), 272 + Some("fragment") => Ok(Some(ResponseMode::Fragment)), 273 + Some("form_post") => Ok(Some(ResponseMode::FormPost)), 274 + Some(other) => Err(OAuthError::InvalidRequest(format!( 275 + "Unsupported response_mode: {}", 276 + other 277 + ))), 278 + } 279 + } 280 + 281 + fn parse_prompt(value: Option<&str>) -> Result<Option<Prompt>, OAuthError> { 282 + match value { 283 + None | Some("") => Ok(None), 284 + Some("none") => Ok(Some(Prompt::None)), 285 + Some("login") => Ok(Some(Prompt::Login)), 286 + Some("consent") => Ok(Some(Prompt::Consent)), 287 + Some("select_account") => Ok(Some(Prompt::SelectAccount)), 288 + Some("create") => Ok(Some(Prompt::Create)), 289 + Some(other) => Err(OAuthError::InvalidRequest(format!( 290 + "Unsupported prompt value: {}", 291 + other 292 + ))), 285 293 } 286 294 }
+42 -49
crates/tranquil-pds/src/oauth/endpoints/token/grants.rs
··· 3 3 use crate::config::AuthConfig; 4 4 use crate::delegation::intersect_scopes; 5 5 use crate::oauth::{ 6 - AuthFlowState, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken, 7 - TokenData, TokenId, 6 + AuthFlow, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken, TokenData, 7 + TokenId, 8 8 db::{enforce_token_limit_for_user, lookup_refresh_token}, 9 9 scopes::expand_include_scopes, 10 10 verify_client_auth, 11 11 }; 12 12 use crate::state::AppState; 13 + use crate::util::pds_hostname; 13 14 use axum::Json; 14 15 use axum::http::HeaderMap; 15 16 use chrono::{Duration, Utc}; ··· 51 52 .map_err(crate::oauth::db_err_to_oauth)? 52 53 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?; 53 54 54 - let flow_state = AuthFlowState::from_request_data(&auth_request); 55 - if flow_state.is_expired() { 56 - return Err(OAuthError::InvalidGrant( 57 - "Authorization code has expired".to_string(), 58 - )); 59 - } 60 - if !flow_state.can_exchange() { 61 - return Err(OAuthError::InvalidGrant( 62 - "Authorization not completed".to_string(), 63 - )); 64 - } 55 + let flow = AuthFlow::from_request_data(auth_request) 56 + .map_err(|_| OAuthError::InvalidGrant("Authorization code has expired".to_string()))?; 57 + 58 + let authorized = flow 59 + .require_authorized() 60 + .map_err(|_| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; 65 61 66 62 if let Some(request_client_id) = &request.client_auth.client_id 67 - && request_client_id != &auth_request.client_id 63 + && request_client_id != &authorized.client_id 68 64 { 69 65 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 70 66 } 71 - let did = flow_state.did().unwrap().to_string(); 67 + let did = authorized.did.to_string(); 72 68 let client_metadata_cache = ClientMetadataCache::new(3600); 73 - let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?; 69 + let client_metadata = client_metadata_cache.get(&authorized.client_id).await?; 74 70 let client_auth = if let (Some(assertion), Some(assertion_type)) = ( 75 71 &request.client_auth.client_assertion, 76 72 &request.client_auth.client_assertion_type, ··· 91 87 ClientAuth::None 92 88 }; 93 89 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 94 - verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 90 + verify_pkce(&authorized.parameters.code_challenge, &code_verifier)?; 95 91 if let Some(req_redirect_uri) = &redirect_uri 96 - && req_redirect_uri != &auth_request.parameters.redirect_uri 92 + && req_redirect_uri != &authorized.parameters.redirect_uri 97 93 { 98 94 return Err(OAuthError::InvalidGrant( 99 95 "redirect_uri mismatch".to_string(), ··· 102 98 let dpop_jkt = if let Some(proof) = &dpop_proof { 103 99 let config = AuthConfig::get(); 104 100 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 105 - let pds_hostname = 106 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 101 + let pds_hostname = pds_hostname(); 107 102 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 108 103 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 109 104 if !state ··· 116 111 "DPoP proof has already been used".to_string(), 117 112 )); 118 113 } 119 - if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt 114 + if let Some(expected_jkt) = &authorized.parameters.dpop_jkt 120 115 && result.jkt.as_str() != expected_jkt 121 116 { 122 117 return Err(OAuthError::InvalidDpopProof( ··· 124 119 )); 125 120 } 126 121 Some(result.jkt.as_str().to_string()) 127 - } else if auth_request.parameters.dpop_jkt.is_some() || client_metadata.requires_dpop() { 122 + } else if authorized.parameters.dpop_jkt.is_some() || client_metadata.requires_dpop() { 128 123 return Err(OAuthError::UseDpopNonce( 129 124 DPoPVerifier::new(AuthConfig::get().dpop_secret().as_bytes()).generate_nonce(), 130 125 )); ··· 135 130 let refresh_token = RefreshToken::generate(); 136 131 let now = Utc::now(); 137 132 138 - let (raw_scope, controller_did) = if let Some(ref controller) = auth_request.controller_did { 133 + let (raw_scope, controller_did) = if let Some(ref controller) = authorized.controller_did { 139 134 let did_parsed: Did = did 140 135 .parse() 141 136 .map_err(|_| OAuthError::InvalidRequest("Invalid DID format".to_string()))?; ··· 149 144 .ok() 150 145 .flatten(); 151 146 let granted_scopes = grant.map(|g| g.granted_scopes).unwrap_or_default(); 152 - let requested = auth_request 153 - .parameters 154 - .scope 155 - .as_deref() 156 - .unwrap_or("atproto"); 157 - let intersected = intersect_scopes(requested, &granted_scopes); 147 + let requested = authorized.parameters.scope.as_deref().unwrap_or("atproto"); 148 + let intersected = intersect_scopes(requested, granted_scopes.as_str()); 158 149 (Some(intersected), Some(controller.clone())) 159 150 } else { 160 - (auth_request.parameters.scope.clone(), None) 151 + (authorized.parameters.scope.clone(), None) 161 152 }; 162 153 163 154 let final_scope = if let Some(ref scope) = raw_scope { ··· 177 168 final_scope.as_deref(), 178 169 controller_did.as_deref(), 179 170 )?; 180 - let stored_client_auth = auth_request.client_auth.unwrap_or(ClientAuth::None); 171 + let stored_client_auth = authorized.client_auth.unwrap_or(ClientAuth::None); 181 172 let refresh_expiry_days = if matches!(stored_client_auth, ClientAuth::None) { 182 173 REFRESH_TOKEN_EXPIRY_DAYS_PUBLIC 183 174 } else { 184 175 REFRESH_TOKEN_EXPIRY_DAYS_CONFIDENTIAL 185 176 }; 186 - let mut stored_parameters = auth_request.parameters.clone(); 177 + let mut stored_parameters = authorized.parameters.clone(); 187 178 stored_parameters.dpop_jkt = dpop_jkt.clone(); 179 + let did_typed: Did = did 180 + .parse() 181 + .map_err(|_| OAuthError::InvalidRequest("Invalid DID format".to_string()))?; 188 182 let token_data = TokenData { 189 - did: did.clone(), 190 - token_id: token_id.0.clone(), 183 + did: did_typed, 184 + token_id: token_id.clone(), 191 185 created_at: now, 192 186 updated_at: now, 193 187 expires_at: now + Duration::days(refresh_expiry_days), 194 - client_id: auth_request.client_id.clone(), 188 + client_id: authorized.client_id.clone(), 195 189 client_auth: stored_client_auth, 196 - device_id: auth_request.device_id, 190 + device_id: authorized.device_id.clone(), 197 191 parameters: stored_parameters, 198 192 details: None, 199 193 code: None, 200 - current_refresh_token: Some(refresh_token.0.clone()), 194 + current_refresh_token: Some(refresh_token.clone()), 201 195 scope: final_scope.clone(), 202 196 controller_did: controller_did.clone(), 203 197 }; ··· 209 203 tracing::info!( 210 204 did = %did, 211 205 token_id = %token_id.0, 212 - client_id = %auth_request.client_id, 206 + client_id = %authorized.client_id, 213 207 "Authorization code grant completed, token created" 214 208 ); 215 209 tokio::spawn({ ··· 280 274 ); 281 275 let dpop_jkt = token_data.parameters.dpop_jkt.as_deref(); 282 276 let access_token = create_access_token_with_delegation( 283 - &token_data.token_id, 284 - &token_data.did, 277 + &token_data.token_id.0, 278 + token_data.did.as_str(), 285 279 dpop_jkt, 286 280 token_data.scope.as_deref(), 287 - token_data.controller_did.as_deref(), 281 + token_data.controller_did.as_ref().map(|d| d.as_str()), 288 282 )?; 289 283 let mut response_headers = HeaderMap::new(); 290 284 let config = AuthConfig::get(); ··· 296 290 access_token, 297 291 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 298 292 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 299 - refresh_token: token_data.current_refresh_token, 293 + refresh_token: token_data.current_refresh_token.map(|r| r.0), 300 294 scope: token_data.scope, 301 - sub: Some(token_data.did), 295 + sub: Some(token_data.did.to_string()), 302 296 }), 303 297 )); 304 298 } ··· 337 331 let dpop_jkt = if let Some(proof) = &dpop_proof { 338 332 let config = AuthConfig::get(); 339 333 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 340 - let pds_hostname = 341 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 334 + let pds_hostname = pds_hostname(); 342 335 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 343 336 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 344 337 if !state ··· 385 378 "Refresh token rotated successfully" 386 379 ); 387 380 let access_token = create_access_token_with_delegation( 388 - &token_data.token_id, 389 - &token_data.did, 381 + &token_data.token_id.0, 382 + token_data.did.as_str(), 390 383 dpop_jkt.as_deref(), 391 384 token_data.scope.as_deref(), 392 - token_data.controller_did.as_deref(), 385 + token_data.controller_did.as_ref().map(|d| d.as_str()), 393 386 )?; 394 387 let mut response_headers = HeaderMap::new(); 395 388 let config = AuthConfig::get(); ··· 403 396 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 404 397 refresh_token: Some(new_refresh_token.0), 405 398 scope: token_data.scope, 406 - sub: Some(token_data.did), 399 + sub: Some(token_data.did.to_string()), 407 400 }), 408 401 )) 409 402 }
+2 -1
crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs
··· 1 1 use crate::config::AuthConfig; 2 2 use crate::oauth::OAuthError; 3 + use crate::util::pds_hostname; 3 4 use base64::Engine; 4 5 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 5 6 use chrono::Utc; ··· 51 52 ) -> Result<String, OAuthError> { 52 53 use serde_json::json; 53 54 let jti = uuid::Uuid::new_v4().to_string(); 54 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 55 + let pds_hostname = pds_hostname(); 55 56 let issuer = format!("https://{}", pds_hostname); 56 57 let now = Utc::now().timestamp(); 57 58 let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
+8 -22
crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs
··· 1 1 use super::helpers::extract_token_claims; 2 2 use crate::oauth::OAuthError; 3 - use crate::state::{AppState, RateLimitKind}; 3 + use crate::rate_limit::{OAuthIntrospectLimit, OAuthRateLimited}; 4 + use crate::state::AppState; 5 + use crate::util::pds_hostname; 4 6 use axum::extract::State; 5 - use axum::http::{HeaderMap, StatusCode}; 7 + use axum::http::StatusCode; 6 8 use axum::{Form, Json}; 7 9 use chrono::Utc; 8 10 use serde::{Deserialize, Serialize}; ··· 17 19 18 20 pub async fn revoke_token( 19 21 State(state): State<AppState>, 20 - headers: HeaderMap, 22 + _rate_limit: OAuthRateLimited<OAuthIntrospectLimit>, 21 23 Form(request): Form<RevokeRequest>, 22 24 ) -> Result<StatusCode, OAuthError> { 23 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 24 - if !state 25 - .check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip) 26 - .await 27 - { 28 - tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded"); 29 - return Err(OAuthError::RateLimited); 30 - } 31 25 if let Some(token) = &request.token { 32 26 let refresh_token = RefreshToken::from(token.clone()); 33 27 if let Some((db_id, _)) = state ··· 89 83 90 84 pub async fn introspect_token( 91 85 State(state): State<AppState>, 92 - headers: HeaderMap, 86 + _rate_limit: OAuthRateLimited<OAuthIntrospectLimit>, 93 87 Form(request): Form<IntrospectRequest>, 94 88 ) -> Result<Json<IntrospectResponse>, OAuthError> { 95 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 96 - if !state 97 - .check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip) 98 - .await 99 - { 100 - tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded"); 101 - return Err(OAuthError::RateLimited); 102 - } 103 89 let inactive_response = IntrospectResponse { 104 90 active: false, 105 91 scope: None, ··· 126 112 if token_data.expires_at < Utc::now() { 127 113 return Ok(Json(inactive_response)); 128 114 } 129 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 115 + let pds_hostname = pds_hostname(); 130 116 let issuer = format!("https://{}", pds_hostname); 131 117 Ok(Json(IntrospectResponse { 132 118 active: true, ··· 141 127 exp: Some(token_info.exp), 142 128 iat: Some(token_info.iat), 143 129 nbf: Some(token_info.iat), 144 - sub: Some(token_data.did), 130 + sub: Some(token_data.did.to_string()), 145 131 aud: Some(issuer.clone()), 146 132 iss: Some(issuer), 147 133 jti: Some(token_info.jti),
+3 -26
crates/tranquil-pds/src/oauth/endpoints/token/mod.rs
··· 4 4 mod types; 5 5 6 6 use crate::oauth::OAuthError; 7 - use crate::state::{AppState, RateLimitKind}; 7 + use crate::rate_limit::{OAuthRateLimited, OAuthTokenLimit}; 8 + use crate::state::AppState; 8 9 use axum::body::Bytes; 9 10 use axum::{Json, extract::State, http::HeaderMap}; 10 11 ··· 17 18 ClientAuthParams, GrantType, TokenGrant, TokenRequest, TokenResponse, ValidatedTokenRequest, 18 19 }; 19 20 20 - fn extract_client_ip(headers: &HeaderMap) -> String { 21 - if let Some(forwarded) = headers.get("x-forwarded-for") 22 - && let Ok(value) = forwarded.to_str() 23 - && let Some(first_ip) = value.split(',').next() 24 - { 25 - return first_ip.trim().to_string(); 26 - } 27 - if let Some(real_ip) = headers.get("x-real-ip") 28 - && let Ok(value) = real_ip.to_str() 29 - { 30 - return value.trim().to_string(); 31 - } 32 - "unknown".to_string() 33 - } 34 - 35 21 pub async fn token_endpoint( 36 22 State(state): State<AppState>, 23 + _rate_limit: OAuthRateLimited<OAuthTokenLimit>, 37 24 headers: HeaderMap, 38 25 body: Bytes, 39 26 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { ··· 53 40 .to_string(), 54 41 )); 55 42 }; 56 - let client_ip = extract_client_ip(&headers); 57 - if !state 58 - .check_rate_limit(RateLimitKind::OAuthToken, &client_ip) 59 - .await 60 - { 61 - tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 62 - return Err(OAuthError::InvalidRequest( 63 - "Too many requests. Please try again later.".to_string(), 64 - )); 65 - } 66 43 let dpop_proof = headers 67 44 .get("DPoP") 68 45 .and_then(|v| v.to_str().ok())
+9 -7
crates/tranquil-pds/src/oauth/mod.rs
··· 10 10 } 11 11 12 12 pub use tranquil_oauth::{ 13 - AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata, 14 - AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code, DPoPClaims, 15 - DPoPJwk, DPoPProofHeader, DPoPProofPayload, DPoPVerifier, DPoPVerifyResult, DeviceData, 16 - DeviceId, JwkPublicKey, Jwks, OAuthClientMetadata, OAuthError, ParResponse, 17 - ProtectedResourceMetadata, RefreshToken, RefreshTokenState, RequestData, RequestId, SessionId, 18 - TokenData, TokenId, TokenRequest, TokenResponse, compute_access_token_hash, 19 - compute_jwk_thumbprint, verify_client_auth, 13 + AuthFlow, AuthFlowWithUser, AuthorizationRequestParameters, AuthorizationServerMetadata, 14 + AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code, 15 + CodeChallengeMethod, DPoPClaims, DPoPJwk, DPoPProofHeader, DPoPProofPayload, DPoPVerifier, 16 + DPoPVerifyResult, DeviceData, DeviceId, FlowAuthenticated, FlowAuthorized, FlowExpired, 17 + FlowNotAuthenticated, FlowNotAuthorized, FlowPending, JwkPublicKey, Jwks, OAuthClientMetadata, 18 + OAuthError, ParResponse, Prompt, ProtectedResourceMetadata, RefreshToken, RefreshTokenState, 19 + RequestData, RequestId, ResponseMode, ResponseType, SessionId, TokenData, TokenId, 20 + TokenRequest, TokenResponse, compute_access_token_hash, compute_jwk_thumbprint, 21 + verify_client_auth, 20 22 }; 21 23 22 24 pub use scopes::{AccountAction, AccountAttr, RepoAction, ScopeError, ScopePermissions};
+20 -14
crates/tranquil-pds/src/oauth/verify.rs
··· 20 20 use crate::state::AppState; 21 21 22 22 pub struct OAuthTokenInfo { 23 - pub did: String, 24 - pub token_id: String, 25 - pub client_id: String, 23 + pub did: Did, 24 + pub token_id: TokenId, 25 + pub client_id: ClientId, 26 26 pub scope: Option<String>, 27 27 pub dpop_jkt: Option<String>, 28 - pub controller_did: Option<String>, 28 + pub controller_did: Option<Did>, 29 29 } 30 30 31 31 pub struct VerifyResult { ··· 48 48 has_dpop_proof = dpop_proof.is_some(), 49 49 "Verifying OAuth access token" 50 50 ); 51 - let token_id = TokenId::from(token_info.token_id.clone()); 51 + let token_id = token_info.token_id.clone(); 52 52 let token_data = oauth_repo 53 53 .get_token_by_id(&token_id) 54 54 .await ··· 154 154 if exp < now { 155 155 return Err(OAuthError::ExpiredToken("Token has expired".to_string())); 156 156 } 157 - let token_id = payload 157 + let token_id_str = payload 158 158 .get("sid") 159 159 .and_then(|j| j.as_str()) 160 - .ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))? 161 - .to_string(); 162 - let did = payload 160 + .ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?; 161 + let token_id = TokenId::new(token_id_str); 162 + let did_str = payload 163 163 .get("sub") 164 164 .and_then(|s| s.as_str()) 165 - .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))? 166 - .to_string(); 165 + .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?; 166 + let did: Did = did_str 167 + .parse() 168 + .map_err(|_| OAuthError::InvalidToken("Invalid sub claim (not a valid DID)".to_string()))?; 167 169 let scope = payload 168 170 .get("scope") 169 171 .and_then(|s| s.as_str()) ··· 173 175 .and_then(|c| c.get("jkt")) 174 176 .and_then(|j| j.as_str()) 175 177 .map(|s| s.to_string()); 176 - let client_id = payload 178 + let client_id_str = payload 177 179 .get("client_id") 178 180 .and_then(|c| c.as_str()) 179 - .map(|s| s.to_string()) 180 181 .unwrap_or_default(); 182 + let client_id = ClientId::new(client_id_str); 181 183 let controller_did = payload 182 184 .get("act") 183 185 .and_then(|a| a.get("sub")) 184 186 .and_then(|s| s.as_str()) 185 - .map(|s| s.to_string()); 187 + .map(|s| s.parse::<Did>()) 188 + .transpose() 189 + .map_err(|_| { 190 + OAuthError::InvalidToken("Invalid act.sub claim (not a valid DID)".to_string()) 191 + })?; 186 192 Ok(OAuthTokenInfo { 187 193 did, 188 194 token_id,
+272
crates/tranquil-pds/src/rate_limit/extractor.rs
··· 1 + use std::marker::PhantomData; 2 + 3 + use axum::{ 4 + extract::FromRequestParts, 5 + http::request::Parts, 6 + response::{IntoResponse, Response}, 7 + }; 8 + 9 + use crate::api::error::ApiError; 10 + use crate::oauth::OAuthError; 11 + use crate::state::{AppState, RateLimitKind}; 12 + use crate::util::extract_client_ip; 13 + 14 + pub trait RateLimitPolicy: Send + Sync + 'static { 15 + const KIND: RateLimitKind; 16 + } 17 + 18 + pub struct LoginLimit; 19 + impl RateLimitPolicy for LoginLimit { 20 + const KIND: RateLimitKind = RateLimitKind::Login; 21 + } 22 + 23 + pub struct AccountCreationLimit; 24 + impl RateLimitPolicy for AccountCreationLimit { 25 + const KIND: RateLimitKind = RateLimitKind::AccountCreation; 26 + } 27 + 28 + pub struct PasswordResetLimit; 29 + impl RateLimitPolicy for PasswordResetLimit { 30 + const KIND: RateLimitKind = RateLimitKind::PasswordReset; 31 + } 32 + 33 + pub struct ResetPasswordLimit; 34 + impl RateLimitPolicy for ResetPasswordLimit { 35 + const KIND: RateLimitKind = RateLimitKind::ResetPassword; 36 + } 37 + 38 + pub struct RefreshSessionLimit; 39 + impl RateLimitPolicy for RefreshSessionLimit { 40 + const KIND: RateLimitKind = RateLimitKind::RefreshSession; 41 + } 42 + 43 + pub struct OAuthTokenLimit; 44 + impl RateLimitPolicy for OAuthTokenLimit { 45 + const KIND: RateLimitKind = RateLimitKind::OAuthToken; 46 + } 47 + 48 + pub struct OAuthAuthorizeLimit; 49 + impl RateLimitPolicy for OAuthAuthorizeLimit { 50 + const KIND: RateLimitKind = RateLimitKind::OAuthAuthorize; 51 + } 52 + 53 + pub struct OAuthParLimit; 54 + impl RateLimitPolicy for OAuthParLimit { 55 + const KIND: RateLimitKind = RateLimitKind::OAuthPar; 56 + } 57 + 58 + pub struct OAuthIntrospectLimit; 59 + impl RateLimitPolicy for OAuthIntrospectLimit { 60 + const KIND: RateLimitKind = RateLimitKind::OAuthIntrospect; 61 + } 62 + 63 + pub struct AppPasswordLimit; 64 + impl RateLimitPolicy for AppPasswordLimit { 65 + const KIND: RateLimitKind = RateLimitKind::AppPassword; 66 + } 67 + 68 + pub struct EmailUpdateLimit; 69 + impl RateLimitPolicy for EmailUpdateLimit { 70 + const KIND: RateLimitKind = RateLimitKind::EmailUpdate; 71 + } 72 + 73 + pub struct TotpVerifyLimit; 74 + impl RateLimitPolicy for TotpVerifyLimit { 75 + const KIND: RateLimitKind = RateLimitKind::TotpVerify; 76 + } 77 + 78 + pub struct HandleUpdateLimit; 79 + impl RateLimitPolicy for HandleUpdateLimit { 80 + const KIND: RateLimitKind = RateLimitKind::HandleUpdate; 81 + } 82 + 83 + pub struct HandleUpdateDailyLimit; 84 + impl RateLimitPolicy for HandleUpdateDailyLimit { 85 + const KIND: RateLimitKind = RateLimitKind::HandleUpdateDaily; 86 + } 87 + 88 + pub struct VerificationCheckLimit; 89 + impl RateLimitPolicy for VerificationCheckLimit { 90 + const KIND: RateLimitKind = RateLimitKind::VerificationCheck; 91 + } 92 + 93 + pub struct SsoInitiateLimit; 94 + impl RateLimitPolicy for SsoInitiateLimit { 95 + const KIND: RateLimitKind = RateLimitKind::SsoInitiate; 96 + } 97 + 98 + pub struct SsoCallbackLimit; 99 + impl RateLimitPolicy for SsoCallbackLimit { 100 + const KIND: RateLimitKind = RateLimitKind::SsoCallback; 101 + } 102 + 103 + pub struct SsoUnlinkLimit; 104 + impl RateLimitPolicy for SsoUnlinkLimit { 105 + const KIND: RateLimitKind = RateLimitKind::SsoUnlink; 106 + } 107 + 108 + pub struct OAuthRegisterCompleteLimit; 109 + impl RateLimitPolicy for OAuthRegisterCompleteLimit { 110 + const KIND: RateLimitKind = RateLimitKind::OAuthRegisterComplete; 111 + } 112 + 113 + pub trait RateLimitRejection: IntoResponse + Send + 'static { 114 + fn new() -> Self; 115 + } 116 + 117 + pub struct ApiRateLimitRejection; 118 + 119 + impl RateLimitRejection for ApiRateLimitRejection { 120 + fn new() -> Self { 121 + Self 122 + } 123 + } 124 + 125 + impl IntoResponse for ApiRateLimitRejection { 126 + fn into_response(self) -> Response { 127 + ApiError::RateLimitExceeded(None).into_response() 128 + } 129 + } 130 + 131 + pub struct OAuthRateLimitRejection; 132 + 133 + impl RateLimitRejection for OAuthRateLimitRejection { 134 + fn new() -> Self { 135 + Self 136 + } 137 + } 138 + 139 + impl IntoResponse for OAuthRateLimitRejection { 140 + fn into_response(self) -> Response { 141 + OAuthError::RateLimited.into_response() 142 + } 143 + } 144 + 145 + impl From<OAuthRateLimitRejection> for OAuthError { 146 + fn from(_: OAuthRateLimitRejection) -> Self { 147 + OAuthError::RateLimited 148 + } 149 + } 150 + 151 + pub struct RateLimitedInner<P: RateLimitPolicy, R: RateLimitRejection> { 152 + client_ip: String, 153 + _marker: PhantomData<(P, R)>, 154 + } 155 + 156 + impl<P: RateLimitPolicy, R: RateLimitRejection> RateLimitedInner<P, R> { 157 + pub fn client_ip(&self) -> &str { 158 + &self.client_ip 159 + } 160 + } 161 + 162 + impl<P: RateLimitPolicy, R: RateLimitRejection> FromRequestParts<AppState> 163 + for RateLimitedInner<P, R> 164 + { 165 + type Rejection = R; 166 + 167 + async fn from_request_parts( 168 + parts: &mut Parts, 169 + state: &AppState, 170 + ) -> Result<Self, Self::Rejection> { 171 + let client_ip = extract_client_ip(&parts.headers, None); 172 + 173 + if !state.check_rate_limit(P::KIND, &client_ip).await { 174 + tracing::warn!( 175 + ip = %client_ip, 176 + kind = ?P::KIND, 177 + "Rate limit exceeded" 178 + ); 179 + return Err(R::new()); 180 + } 181 + 182 + Ok(RateLimitedInner { 183 + client_ip, 184 + _marker: PhantomData, 185 + }) 186 + } 187 + } 188 + 189 + pub type RateLimited<P> = RateLimitedInner<P, ApiRateLimitRejection>; 190 + pub type OAuthRateLimited<P> = RateLimitedInner<P, OAuthRateLimitRejection>; 191 + 192 + #[derive(Debug)] 193 + pub struct UserRateLimitError { 194 + pub kind: RateLimitKind, 195 + pub message: Option<String>, 196 + } 197 + 198 + impl UserRateLimitError { 199 + pub fn new(kind: RateLimitKind) -> Self { 200 + Self { 201 + kind, 202 + message: None, 203 + } 204 + } 205 + 206 + pub fn with_message(kind: RateLimitKind, message: impl Into<String>) -> Self { 207 + Self { 208 + kind, 209 + message: Some(message.into()), 210 + } 211 + } 212 + } 213 + 214 + impl std::fmt::Display for UserRateLimitError { 215 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 216 + match &self.message { 217 + Some(msg) => write!(f, "{}", msg), 218 + None => write!(f, "Rate limit exceeded for {:?}", self.kind), 219 + } 220 + } 221 + } 222 + 223 + impl std::error::Error for UserRateLimitError {} 224 + 225 + impl IntoResponse for UserRateLimitError { 226 + fn into_response(self) -> Response { 227 + ApiError::RateLimitExceeded(self.message).into_response() 228 + } 229 + } 230 + 231 + pub struct UserRateLimitProof<P: RateLimitPolicy> { 232 + _marker: PhantomData<P>, 233 + } 234 + 235 + impl<P: RateLimitPolicy> UserRateLimitProof<P> { 236 + fn new() -> Self { 237 + Self { 238 + _marker: PhantomData, 239 + } 240 + } 241 + } 242 + 243 + pub async fn check_user_rate_limit<P: RateLimitPolicy>( 244 + state: &AppState, 245 + user_key: &str, 246 + ) -> Result<UserRateLimitProof<P>, UserRateLimitError> { 247 + if !state.check_rate_limit(P::KIND, user_key).await { 248 + tracing::warn!( 249 + key = %user_key, 250 + kind = ?P::KIND, 251 + "User rate limit exceeded" 252 + ); 253 + return Err(UserRateLimitError::new(P::KIND)); 254 + } 255 + Ok(UserRateLimitProof::new()) 256 + } 257 + 258 + pub async fn check_user_rate_limit_with_message<P: RateLimitPolicy>( 259 + state: &AppState, 260 + user_key: &str, 261 + error_message: impl Into<String>, 262 + ) -> Result<UserRateLimitProof<P>, UserRateLimitError> { 263 + if !state.check_rate_limit(P::KIND, user_key).await { 264 + tracing::warn!( 265 + key = %user_key, 266 + kind = ?P::KIND, 267 + "User rate limit exceeded" 268 + ); 269 + return Err(UserRateLimitError::with_message(P::KIND, error_message)); 270 + } 271 + Ok(UserRateLimitProof::new()) 272 + }
+5 -102
crates/tranquil-pds/src/rate_limit.rs crates/tranquil-pds/src/rate_limit/mod.rs
··· 1 - use axum::{ 2 - Json, 3 - body::Body, 4 - extract::ConnectInfo, 5 - http::{HeaderMap, Request, StatusCode}, 6 - middleware::Next, 7 - response::{IntoResponse, Response}, 8 - }; 1 + mod extractor; 2 + 3 + pub use extractor::*; 4 + 9 5 use governor::{ 10 6 Quota, RateLimiter, 11 7 clock::DefaultClock, 12 8 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13 9 }; 14 - use std::{net::SocketAddr, num::NonZeroU32, sync::Arc}; 10 + use std::{num::NonZeroU32, sync::Arc}; 15 11 16 12 pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 17 13 pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; ··· 166 162 } 167 163 } 168 164 169 - pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 170 - if let Some(forwarded) = headers.get("x-forwarded-for") 171 - && let Ok(value) = forwarded.to_str() 172 - && let Some(first_ip) = value.split(',').next() 173 - { 174 - return first_ip.trim().to_string(); 175 - } 176 - 177 - if let Some(real_ip) = headers.get("x-real-ip") 178 - && let Ok(value) = real_ip.to_str() 179 - { 180 - return value.trim().to_string(); 181 - } 182 - 183 - addr.map(|a| a.ip().to_string()) 184 - .unwrap_or_else(|| "unknown".to_string()) 185 - } 186 - 187 - fn rate_limit_response() -> Response { 188 - ( 189 - StatusCode::TOO_MANY_REQUESTS, 190 - Json(serde_json::json!({ 191 - "error": "RateLimitExceeded", 192 - "message": "Too many requests. Please try again later." 193 - })), 194 - ) 195 - .into_response() 196 - } 197 - 198 - pub async fn login_rate_limit( 199 - ConnectInfo(addr): ConnectInfo<SocketAddr>, 200 - axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 201 - request: Request<Body>, 202 - next: Next, 203 - ) -> Response { 204 - let client_ip = extract_client_ip(request.headers(), Some(addr)); 205 - 206 - if limiters.login.check_key(&client_ip).is_err() { 207 - tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 208 - return rate_limit_response(); 209 - } 210 - 211 - next.run(request).await 212 - } 213 - 214 - pub async fn oauth_token_rate_limit( 215 - ConnectInfo(addr): ConnectInfo<SocketAddr>, 216 - axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 217 - request: Request<Body>, 218 - next: Next, 219 - ) -> Response { 220 - let client_ip = extract_client_ip(request.headers(), Some(addr)); 221 - 222 - if limiters.oauth_token.check_key(&client_ip).is_err() { 223 - tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 224 - return rate_limit_response(); 225 - } 226 - 227 - next.run(request).await 228 - } 229 - 230 - pub async fn password_reset_rate_limit( 231 - ConnectInfo(addr): ConnectInfo<SocketAddr>, 232 - axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 233 - request: Request<Body>, 234 - next: Next, 235 - ) -> Response { 236 - let client_ip = extract_client_ip(request.headers(), Some(addr)); 237 - 238 - if limiters.password_reset.check_key(&client_ip).is_err() { 239 - tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 240 - return rate_limit_response(); 241 - } 242 - 243 - next.run(request).await 244 - } 245 - 246 - pub async fn account_creation_rate_limit( 247 - ConnectInfo(addr): ConnectInfo<SocketAddr>, 248 - axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 249 - request: Request<Body>, 250 - next: Next, 251 - ) -> Response { 252 - let client_ip = extract_client_ip(request.headers(), Some(addr)); 253 - 254 - if limiters.account_creation.check_key(&client_ip).is_err() { 255 - tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 256 - return rate_limit_response(); 257 - } 258 - 259 - next.run(request).await 260 - } 261 - 262 165 #[cfg(test)] 263 166 mod tests { 264 167 use super::*;
+15 -19
crates/tranquil-pds/src/scheduled.rs
··· 5 5 use std::str::FromStr; 6 6 use std::sync::Arc; 7 7 use std::time::Duration; 8 - use tokio::sync::watch; 9 8 use tokio::time::interval; 9 + use tokio_util::sync::CancellationToken; 10 10 use tracing::{debug, error, info, warn}; 11 11 use tranquil_db_traits::{ 12 - BackupRepository, BlobRepository, BrokenGenesisCommit, RepoRepository, SsoRepository, 13 - UserRepository, 12 + BackupRepository, BlobRepository, BrokenGenesisCommit, RepoRepository, SequenceNumber, 13 + SsoRepository, UserRepository, 14 14 }; 15 15 use tranquil_types::{AtUri, CidLink, Did}; 16 16 ··· 22 22 repo_repo: &dyn RepoRepository, 23 23 block_store: &PostgresBlockStore, 24 24 row: BrokenGenesisCommit, 25 - ) -> Result<(Did, i64), (i64, &'static str)> { 25 + ) -> Result<(Did, SequenceNumber), (SequenceNumber, &'static str)> { 26 26 let commit_cid_str = row.commit_cid.ok_or((row.seq, "missing commit_cid"))?; 27 27 let commit_cid = Cid::from_str(&commit_cid_str).map_err(|_| (row.seq, "invalid CID"))?; 28 28 let block = block_store ··· 73 73 74 74 let (success, failed) = results.iter().fold((0, 0), |(s, f), r| match r { 75 75 Ok((did, seq)) => { 76 - info!(seq = seq, did = %did, "Fixed genesis commit blocks_cids"); 76 + info!(seq = seq.as_i64(), did = %did, "Fixed genesis commit blocks_cids"); 77 77 (s + 1, f) 78 78 } 79 79 Err((seq, reason)) => { 80 80 warn!( 81 - seq = seq, 81 + seq = seq.as_i64(), 82 82 reason = reason, 83 83 "Failed to process genesis commit" 84 84 ); ··· 314 314 record.collection.as_str(), 315 315 record.rkey.as_str(), 316 316 ); 317 - (record_uri, CidLink::new_unchecked(blob_ref.cid)) 317 + (record_uri, unsafe { CidLink::new_unchecked(blob_ref.cid) }) 318 318 }) 319 319 .collect::<Vec<_>>(), 320 320 ) ··· 392 392 blob_repo: Arc<dyn BlobRepository>, 393 393 blob_store: Arc<dyn BlobStorage>, 394 394 sso_repo: Arc<dyn SsoRepository>, 395 - mut shutdown_rx: watch::Receiver<bool>, 395 + shutdown: CancellationToken, 396 396 ) { 397 397 let check_interval = Duration::from_secs( 398 398 std::env::var("SCHEDULED_DELETE_CHECK_INTERVAL_SECS") ··· 411 411 412 412 loop { 413 413 tokio::select! { 414 - _ = shutdown_rx.changed() => { 415 - if *shutdown_rx.borrow() { 416 - info!("Scheduled tasks service shutting down"); 417 - break; 418 - } 414 + _ = shutdown.cancelled() => { 415 + info!("Scheduled tasks service shutting down"); 416 + break; 419 417 } 420 418 _ = ticker.tick() => { 421 419 if let Err(e) = process_scheduled_deletions( ··· 538 536 backup_repo: Arc<dyn BackupRepository>, 539 537 block_store: PostgresBlockStore, 540 538 backup_storage: Arc<dyn BackupStorage>, 541 - mut shutdown_rx: watch::Receiver<bool>, 539 + shutdown: CancellationToken, 542 540 ) { 543 541 let backup_interval = Duration::from_secs(backup_interval_secs()); 544 542 ··· 553 551 554 552 loop { 555 553 tokio::select! { 556 - _ = shutdown_rx.changed() => { 557 - if *shutdown_rx.borrow() { 558 - info!("Backup service shutting down"); 559 - break; 560 - } 554 + _ = shutdown.cancelled() => { 555 + info!("Backup service shutting down"); 556 + break; 561 557 } 562 558 _ = ticker.tick() => { 563 559 if let Err(e) = process_scheduled_backups(
+2 -1
crates/tranquil-pds/src/sso/config.rs
··· 1 + use crate::util::pds_hostname; 1 2 use std::sync::OnceLock; 2 3 use tranquil_db_traits::SsoProviderType; 3 4 ··· 50 51 }; 51 52 52 53 if config.is_any_enabled() { 53 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_default(); 54 + let hostname = pds_hostname(); 54 55 if hostname.is_empty() || hostname == "localhost" { 55 56 panic!( 56 57 "PDS_HOSTNAME must be set to a valid hostname when SSO is enabled. \
+98 -114
crates/tranquil-pds/src/sso/endpoints.rs
··· 6 6 }; 7 7 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 8 8 use serde::{Deserialize, Serialize}; 9 - use tranquil_db_traits::SsoProviderType; 9 + use tranquil_db_traits::{SsoAction, SsoProviderType}; 10 10 use tranquil_types::RequestId; 11 11 12 12 use super::config::SsoConfig; 13 13 use crate::api::error::ApiError; 14 14 use crate::auth::extractor::extract_bearer_token_from_header; 15 15 use crate::auth::{generate_app_password, validate_bearer_token_cached}; 16 - use crate::rate_limit::extract_client_ip; 17 - use crate::state::{AppState, RateLimitKind}; 16 + use crate::rate_limit::{ 17 + AccountCreationLimit, RateLimited, SsoCallbackLimit, SsoInitiateLimit, SsoUnlinkLimit, 18 + check_user_rate_limit_with_message, 19 + }; 20 + use crate::state::AppState; 21 + use crate::util::{pds_hostname, pds_hostname_without_port}; 18 22 19 23 fn generate_state() -> String { 20 24 use rand::RngCore; ··· 71 75 72 76 pub async fn sso_initiate( 73 77 State(state): State<AppState>, 78 + _rate_limit: RateLimited<SsoInitiateLimit>, 74 79 headers: HeaderMap, 75 80 Json(input): Json<SsoInitiateRequest>, 76 81 ) -> Result<Json<SsoInitiateResponse>, ApiError> { 77 - let client_ip = extract_client_ip(&headers, None); 78 - if !state 79 - .check_rate_limit(RateLimitKind::SsoInitiate, &client_ip) 80 - .await 81 - { 82 - tracing::warn!(ip = %client_ip, "SSO initiate rate limit exceeded"); 83 - return Err(ApiError::RateLimitExceeded(None)); 84 - } 85 - 86 82 if input.provider.len() > 20 { 87 83 return Err(ApiError::SsoProviderNotFound); 88 84 } ··· 105 101 .get_provider(provider_type) 106 102 .ok_or(ApiError::SsoProviderNotEnabled)?; 107 103 108 - let action = input.action.as_deref().unwrap_or("login"); 109 - if !["login", "link", "register"].contains(&action) { 110 - return Err(ApiError::SsoInvalidAction); 111 - } 104 + let action = input 105 + .action 106 + .as_deref() 107 + .map(SsoAction::parse) 108 + .unwrap_or(Some(SsoAction::Login)) 109 + .ok_or(ApiError::SsoInvalidAction)?; 112 110 113 - let is_standalone = action == "register" && input.request_uri.is_none(); 111 + let is_standalone = action == SsoAction::Register && input.request_uri.is_none(); 114 112 let request_uri = input 115 113 .request_uri 116 114 .clone() 117 115 .unwrap_or_else(|| "standalone".to_string()); 118 116 119 117 let auth_did = match action { 120 - "link" => { 118 + SsoAction::Link => { 121 119 let auth_header = headers 122 120 .get(axum::http::header::AUTHORIZATION) 123 121 .and_then(|v| v.to_str().ok()); ··· 132 130 .map_err(|_| ApiError::SsoNotAuthenticated)?; 133 131 Some(auth_user.did) 134 132 } 135 - "register" if is_standalone => None, 133 + SsoAction::Register if is_standalone => None, 136 134 _ => { 137 135 let request_id = RequestId::new(request_uri.clone()); 138 136 let _request_data = state ··· 217 215 218 216 pub async fn sso_callback( 219 217 State(state): State<AppState>, 220 - headers: HeaderMap, 218 + _rate_limit: RateLimited<SsoCallbackLimit>, 221 219 Query(query): Query<SsoCallbackQuery>, 222 220 ) -> Response { 221 + sso_callback_internal(&state, query).await 222 + } 223 + 224 + async fn sso_callback_internal(state: &AppState, query: SsoCallbackQuery) -> Response { 223 225 tracing::debug!( 224 226 has_code = query.code.is_some(), 225 227 has_state = query.state.is_some(), ··· 227 229 "SSO callback received" 228 230 ); 229 231 230 - let client_ip = extract_client_ip(&headers, None); 231 - if !state 232 - .check_rate_limit(RateLimitKind::SsoCallback, &client_ip) 233 - .await 234 - { 235 - tracing::warn!(ip = %client_ip, "SSO callback rate limit exceeded"); 236 - return redirect_to_error("Too many requests. Please try again later."); 237 - } 238 - 239 232 if let Some(ref error) = query.error { 240 233 tracing::warn!( 241 234 error = %error, ··· 326 319 } 327 320 }; 328 321 329 - match auth_state.action.as_str() { 330 - "login" => { 322 + match auth_state.action { 323 + SsoAction::Login => { 331 324 handle_sso_login( 332 - &state, 325 + state, 333 326 &auth_state.request_uri, 334 327 auth_state.provider, 335 328 &user_info, 336 329 ) 337 330 .await 338 331 } 339 - "link" => { 332 + SsoAction::Link => { 340 333 let did = match auth_state.did { 341 334 Some(d) => d, 342 335 None => return redirect_to_error("Not authenticated"), 343 336 }; 344 - handle_sso_link(&state, did, auth_state.provider, &user_info).await 337 + handle_sso_link(state, did, auth_state.provider, &user_info).await 345 338 } 346 - "register" => { 339 + SsoAction::Register => { 347 340 handle_sso_register( 348 - &state, 341 + state, 349 342 &auth_state.request_uri, 350 343 auth_state.provider, 351 344 &user_info, 352 345 ) 353 346 .await 354 347 } 355 - _ => redirect_to_error("Unknown SSO action"), 356 348 } 357 349 } 358 350 359 351 pub async fn sso_callback_post( 360 352 State(state): State<AppState>, 361 - headers: HeaderMap, 353 + _rate_limit: RateLimited<SsoCallbackLimit>, 362 354 Form(form): Form<SsoCallbackForm>, 363 355 ) -> Response { 364 356 tracing::debug!( ··· 376 368 error_description: form.error_description, 377 369 }; 378 370 379 - sso_callback(State(state), headers, Query(query)).await 371 + sso_callback_internal(&state, query).await 380 372 } 381 373 382 374 fn generate_registration_token() -> String { ··· 429 421 }; 430 422 431 423 let is_verified = match state.user_repo.get_session_info_by_did(&identity.did).await { 432 - Ok(Some(info)) => { 433 - info.email_verified 434 - || info.discord_verified 435 - || info.telegram_verified 436 - || info.signal_verified 437 - } 424 + Ok(Some(info)) => info.channel_verification.has_any_verified(), 438 425 Ok(None) => { 439 426 tracing::error!("User not found for SSO login: {}", identity.did); 440 427 return redirect_to_error("Account not found"); ··· 486 473 "SSO login successful" 487 474 ); 488 475 489 - let has_totp = match state.user_repo.get_totp_record(&identity.did).await { 490 - Ok(Some(record)) => record.verified, 491 - _ => false, 492 - }; 476 + let has_totp = matches!( 477 + state.user_repo.get_totp_record_state(&identity.did).await, 478 + Ok(Some(tranquil_db_traits::TotpRecordState::Verified(_))) 479 + ); 493 480 494 481 if has_totp { 495 482 return Redirect::to(&format!( ··· 657 644 id: id.id.to_string(), 658 645 provider: id.provider.as_str().to_string(), 659 646 provider_name: id.provider.display_name().to_string(), 660 - provider_username: id.provider_username, 661 - provider_email: id.provider_email, 647 + provider_username: id.provider_username.map(|u| u.into_inner()), 648 + provider_email: id.provider_email.map(|e| e.into_inner()), 662 649 created_at: id.created_at.to_rfc3339(), 663 650 last_login_at: id.last_login_at.map(|t| t.to_rfc3339()), 664 651 }) ··· 682 669 auth: crate::auth::Auth<crate::auth::Active>, 683 670 Json(input): Json<UnlinkAccountRequest>, 684 671 ) -> Result<Json<UnlinkAccountResponse>, ApiError> { 685 - if !state 686 - .check_rate_limit(RateLimitKind::SsoUnlink, auth.did.as_str()) 687 - .await 688 - { 689 - tracing::warn!(did = %auth.did, "SSO unlink rate limit exceeded"); 690 - return Err(ApiError::RateLimitExceeded(None)); 691 - } 672 + let _rate_limit = check_user_rate_limit_with_message::<SsoUnlinkLimit>( 673 + &state, 674 + auth.did.as_str(), 675 + "Too many unlink attempts. Please try again later.", 676 + ) 677 + .await?; 692 678 693 679 let id = uuid::Uuid::parse_str(&input.id).map_err(|_| ApiError::InvalidId)?; 694 680 ··· 746 732 747 733 pub async fn get_pending_registration( 748 734 State(state): State<AppState>, 749 - headers: HeaderMap, 735 + _rate_limit: RateLimited<SsoCallbackLimit>, 750 736 Query(query): Query<PendingRegistrationQuery>, 751 737 ) -> Result<Json<PendingRegistrationResponse>, ApiError> { 752 - let client_ip = extract_client_ip(&headers, None); 753 - if !state 754 - .check_rate_limit(RateLimitKind::SsoCallback, &client_ip) 755 - .await 756 - { 757 - tracing::warn!(ip = %client_ip, "SSO pending registration rate limit exceeded"); 758 - return Err(ApiError::RateLimitExceeded(None)); 759 - } 760 - 761 738 if query.token.len() > 100 { 762 739 return Err(ApiError::InvalidRequest("Invalid token".into())); 763 740 } ··· 771 748 Ok(Json(PendingRegistrationResponse { 772 749 request_uri: pending.request_uri, 773 750 provider: pending.provider.as_str().to_string(), 774 - provider_user_id: pending.provider_user_id, 775 - provider_username: pending.provider_username, 776 - provider_email: pending.provider_email, 751 + provider_user_id: pending.provider_user_id.into_inner(), 752 + provider_username: pending.provider_username.map(|u| u.into_inner()), 753 + provider_email: pending.provider_email.map(|e| e.into_inner()), 777 754 provider_email_verified: pending.provider_email_verified, 778 755 })) 779 756 } ··· 810 787 } 811 788 }; 812 789 813 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 814 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 790 + let hostname_for_handles = pds_hostname_without_port(); 815 791 let full_handle = format!("{}.{}", validated, hostname_for_handles); 816 - let handle_typed = crate::types::Handle::new_unchecked(&full_handle); 792 + let handle_typed = unsafe { crate::types::Handle::new_unchecked(&full_handle) }; 817 793 818 794 let db_available = state 819 795 .user_repo ··· 866 842 867 843 pub async fn complete_registration( 868 844 State(state): State<AppState>, 869 - headers: HeaderMap, 845 + rate_limit: RateLimited<AccountCreationLimit>, 870 846 Json(input): Json<CompleteRegistrationInput>, 871 847 ) -> Result<Json<CompleteRegistrationResponse>, ApiError> { 848 + let client_ip = rate_limit.client_ip(); 872 849 use jacquard_common::types::{integer::LimitedU32, string::Tid}; 873 850 use jacquard_repo::{mst::Mst, storage::BlockStore}; 874 851 use k256::ecdsa::SigningKey; ··· 876 853 use serde_json::json; 877 854 use std::sync::Arc; 878 855 879 - let client_ip = extract_client_ip(&headers, None); 880 - if !state 881 - .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 882 - .await 883 - { 884 - tracing::warn!(ip = %client_ip, "SSO registration rate limit exceeded"); 885 - return Err(ApiError::RateLimitExceeded(None)); 886 - } 887 - 888 856 if input.token.len() > 100 { 889 857 return Err(ApiError::InvalidRequest("Invalid token".into())); 890 858 } ··· 899 867 .await? 900 868 .ok_or(ApiError::SsoSessionExpired)?; 901 869 902 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 903 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 870 + let hostname = pds_hostname(); 871 + let hostname_for_handles = pds_hostname_without_port(); 904 872 905 873 let handle = match crate::api::validation::validate_short_handle(&input.handle) { 906 874 Ok(h) => format!("{}.{}", h, hostname_for_handles), ··· 913 881 let email = input 914 882 .email 915 883 .clone() 916 - .or_else(|| pending_preview.provider_email.clone()) 884 + .or_else(|| { 885 + pending_preview 886 + .provider_email 887 + .clone() 888 + .map(|e| e.into_inner()) 889 + }) 917 890 .map(|e| e.trim().to_string()) 918 891 .filter(|e| !e.is_empty()); 919 892 match email { ··· 939 912 let email = input 940 913 .email 941 914 .clone() 942 - .or_else(|| pending_preview.provider_email.clone()) 915 + .or_else(|| { 916 + pending_preview 917 + .provider_email 918 + .clone() 919 + .map(|e| e.into_inner()) 920 + }) 943 921 .map(|e| e.trim().to_string()) 944 922 .filter(|e| !e.is_empty()); 945 923 ··· 956 934 None => None, 957 935 }; 958 936 959 - if let Some(ref code) = input.invite_code { 960 - let valid = state 961 - .infra_repo 962 - .is_invite_code_valid(code) 963 - .await 964 - .unwrap_or(false); 965 - if !valid { 966 - return Err(ApiError::InvalidInviteCode); 937 + let _validated_invite_code = if let Some(ref code) = input.invite_code { 938 + match state.infra_repo.validate_invite_code(code).await { 939 + Ok(validated) => Some(validated), 940 + Err(_) => return Err(ApiError::InvalidInviteCode), 967 941 } 968 942 } else { 969 943 let invite_required = std::env::var("INVITE_CODE_REQUIRED") ··· 972 946 if invite_required { 973 947 return Err(ApiError::InviteCodeRequired); 974 948 } 975 - } 949 + None 950 + }; 976 951 977 - let handle_typed = crate::types::Handle::new_unchecked(&handle); 952 + let handle_typed = unsafe { crate::types::Handle::new_unchecked(&handle) }; 978 953 let reserved = state 979 954 .user_repo 980 - .reserve_handle(&handle_typed, &client_ip) 955 + .reserve_handle(&handle_typed, client_ip) 981 956 .await 982 957 .unwrap_or(false); 983 958 ··· 1076 1051 }; 1077 1052 1078 1053 let rev = Tid::now(LimitedU32::MIN); 1079 - let did_typed = crate::types::Did::new_unchecked(&did); 1054 + let did_typed = unsafe { crate::types::Did::new_unchecked(&did) }; 1080 1055 let (commit_bytes, _sig) = match crate::api::repo::record::utils::create_signed_commit( 1081 1056 &did_typed, 1082 1057 mst_root, ··· 1144 1119 invite_code: input.invite_code.clone(), 1145 1120 birthdate_pref, 1146 1121 sso_provider: pending_preview.provider, 1147 - sso_provider_user_id: pending_preview.provider_user_id.clone(), 1148 - sso_provider_username: pending_preview.provider_username.clone(), 1149 - sso_provider_email: pending_preview.provider_email.clone(), 1122 + sso_provider_user_id: pending_preview.provider_user_id.clone().into_inner(), 1123 + sso_provider_username: pending_preview 1124 + .provider_username 1125 + .clone() 1126 + .map(|u| u.into_inner()), 1127 + sso_provider_email: pending_preview 1128 + .provider_email 1129 + .clone() 1130 + .map(|e| e.into_inner()), 1150 1131 sso_provider_email_verified: pending_preview.provider_email_verified, 1151 1132 pending_registration_token: input.token.clone(), 1152 1133 }; ··· 1179 1160 { 1180 1161 tracing::warn!("Failed to sequence identity event for {}: {}", did, e); 1181 1162 } 1182 - if let Err(e) = 1183 - crate::api::repo::record::sequence_account_event(&state, &did_typed, true, None).await 1163 + if let Err(e) = crate::api::repo::record::sequence_account_event( 1164 + &state, 1165 + &did_typed, 1166 + tranquil_db_traits::AccountStatus::Active, 1167 + ) 1168 + .await 1184 1169 { 1185 1170 tracing::warn!("Failed to sequence account event for {}: {}", did, e); 1186 1171 } ··· 1189 1174 "$type": "app.bsky.actor.profile", 1190 1175 "displayName": handle_typed.as_str() 1191 1176 }); 1192 - let profile_collection = crate::types::Nsid::new_unchecked("app.bsky.actor.profile"); 1193 - let profile_rkey = crate::types::Rkey::new_unchecked("self"); 1177 + let profile_collection = unsafe { crate::types::Nsid::new_unchecked("app.bsky.actor.profile") }; 1178 + let profile_rkey = unsafe { crate::types::Rkey::new_unchecked("self") }; 1194 1179 if let Err(e) = crate::api::repo::record::create_record_internal( 1195 1180 &state, 1196 1181 &did_typed, ··· 1217 1202 user_id: create_result.user_id, 1218 1203 name: app_password_name.clone(), 1219 1204 password_hash: app_password_hash, 1220 - privileged: false, 1205 + privilege: tranquil_db_traits::AppPasswordPrivilege::Standard, 1221 1206 scopes: None, 1222 1207 created_by_controller_did: None, 1223 1208 }; ··· 1260 1245 1261 1246 let channel_auto_verified = verification_channel == "email" 1262 1247 && pending_preview.provider_email_verified 1263 - && pending_preview.provider_email.as_ref() == email.as_ref(); 1248 + && pending_preview.provider_email.as_ref().map(|e| e.as_str()) == email.as_deref(); 1264 1249 1265 1250 if channel_auto_verified { 1266 1251 let _ = state ··· 1304 1289 refresh_jti: refresh_meta.jti.clone(), 1305 1290 access_expires_at: access_meta.expires_at, 1306 1291 refresh_expires_at: refresh_meta.expires_at, 1307 - legacy_login: false, 1292 + login_type: tranquil_db_traits::LoginType::Modern, 1308 1293 mfa_verified: false, 1309 1294 scope: None, 1310 1295 controller_did: None, ··· 1315 1300 return Err(ApiError::InternalError(None)); 1316 1301 } 1317 1302 1318 - let hostname = 1319 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1303 + let hostname = pds_hostname(); 1320 1304 if let Err(e) = crate::comms::comms_repo::enqueue_welcome( 1321 1305 state.user_repo.as_ref(), 1322 1306 state.infra_repo.as_ref(), 1323 1307 user_id.unwrap_or(uuid::Uuid::nil()), 1324 - &hostname, 1308 + hostname, 1325 1309 ) 1326 1310 .await 1327 1311 { ··· 1367 1351 verification_channel, 1368 1352 &verification_recipient, 1369 1353 &formatted_token, 1370 - &hostname, 1354 + hostname, 1371 1355 ) 1372 1356 .await 1373 1357 {
+15 -3
crates/tranquil-pds/src/state.rs
··· 1 1 use crate::appview::DidResolver; 2 + use crate::auth::webauthn::WebAuthnConfig; 2 3 use crate::cache::{Cache, DistributedRateLimiter, create_cache}; 3 4 use crate::circuit_breaker::CircuitBreakers; 4 5 use crate::config::AuthConfig; ··· 7 8 use crate::sso::{SsoConfig, SsoManager}; 8 9 use crate::storage::{BackupStorage, BlobStorage, create_backup_storage, create_blob_storage}; 9 10 use crate::sync::firehose::SequencedEvent; 11 + use crate::util::pds_hostname; 10 12 use sqlx::PgPool; 11 13 use std::error::Error; 12 14 use std::sync::Arc; 13 15 use tokio::sync::broadcast; 16 + use tokio_util::sync::CancellationToken; 14 17 use tranquil_db::{ 15 18 BacklinkRepository, BackupRepository, BlobRepository, DelegationRepository, InfraRepository, 16 19 OAuthRepository, PostgresRepositories, RepoEventNotifier, RepoRepository, SessionRepository, ··· 41 44 pub did_resolver: Arc<DidResolver>, 42 45 pub sso_repo: Arc<dyn SsoRepository>, 43 46 pub sso_manager: SsoManager, 47 + pub webauthn_config: Arc<WebAuthnConfig>, 48 + pub shutdown: CancellationToken, 44 49 } 45 50 51 + #[derive(Debug, Clone, Copy)] 46 52 pub enum RateLimitKind { 47 53 Login, 48 54 AccountCreation, ··· 116 122 } 117 123 118 124 impl AppState { 119 - pub async fn new() -> Result<Self, Box<dyn Error>> { 125 + pub async fn new(shutdown: CancellationToken) -> Result<Self, Box<dyn Error>> { 120 126 let database_url = std::env::var("DATABASE_URL") 121 127 .map_err(|_| "DATABASE_URL environment variable must be set")?; 122 128 ··· 157 163 .await 158 164 .map_err(|e| format!("Failed to run migrations: {}", e))?; 159 165 160 - Ok(Self::from_db(db).await) 166 + Ok(Self::from_db(db, shutdown).await) 161 167 } 162 168 163 - pub async fn from_db(db: PgPool) -> Self { 169 + pub async fn from_db(db: PgPool, shutdown: CancellationToken) -> Self { 164 170 AuthConfig::init(); 165 171 166 172 let repos = Arc::new(PostgresRepositories::new(db.clone())); ··· 180 186 let did_resolver = Arc::new(DidResolver::new()); 181 187 let sso_config = SsoConfig::init(); 182 188 let sso_manager = SsoManager::from_config(sso_config); 189 + let webauthn_config = Arc::new( 190 + WebAuthnConfig::new(pds_hostname()) 191 + .expect("Failed to create WebAuthn config at startup"), 192 + ); 183 193 184 194 Self { 185 195 user_repo: repos.user.clone(), ··· 204 214 distributed_rate_limiter, 205 215 did_resolver, 206 216 sso_manager, 217 + webauthn_config, 218 + shutdown, 207 219 } 208 220 } 209 221
+4 -3
crates/tranquil-pds/src/sync/commit.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::state::AppState; 3 - use crate::sync::util::{AccountStatus, assert_repo_availability, get_account_with_status}; 3 + use crate::sync::util::{assert_repo_availability, get_account_with_status}; 4 4 use axum::{ 5 5 Json, 6 6 extract::{Query, State}, ··· 13 13 use serde::{Deserialize, Serialize}; 14 14 use std::str::FromStr; 15 15 use tracing::error; 16 + use tranquil_db_traits::AccountStatus; 16 17 use tranquil_types::Did; 17 18 18 19 async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> { ··· 130 131 head: cid_str, 131 132 rev, 132 133 active: status.is_active(), 133 - status: status.as_str().map(String::from), 134 + status: status.for_firehose().map(String::from), 134 135 }); 135 136 } 136 137 let next_cursor = if has_more { ··· 212 213 Json(GetRepoStatusOutput { 213 214 did: account.did, 214 215 active: account.status.is_active(), 215 - status: account.status.as_str().map(String::from), 216 + status: account.status.for_firehose().map(String::from), 216 217 rev, 217 218 }), 218 219 )
+5 -4
crates/tranquil-pds/src/sync/deprecated.rs
··· 19 19 const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 20 20 21 21 async fn check_admin_or_self(state: &AppState, headers: &HeaderMap, did: &str) -> bool { 22 - let extracted = match crate::auth::extract_auth_token_from_header( 23 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 24 - ) { 22 + let extracted = match crate::auth::extract_auth_token_from_header(crate::util::get_header_str( 23 + headers, 24 + "Authorization", 25 + )) { 25 26 Some(t) => t, 26 27 None => return false, 27 28 }; 28 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 29 + let dpop_proof = crate::util::get_header_str(headers, "DPoP"); 29 30 let http_uri = "/"; 30 31 match crate::auth::validate_token_with_dpop( 31 32 state.user_repo.as_ref(),
+10 -7
crates/tranquil-pds/src/sync/frame.rs
··· 2 2 use cid::Cid; 3 3 use serde::{Deserialize, Serialize}; 4 4 use std::str::FromStr; 5 + use tranquil_scopes::RepoAction; 5 6 6 7 #[derive(Debug, Serialize, Deserialize)] 7 8 pub struct FrameHeader { ··· 38 39 39 40 #[derive(Debug, Serialize, Deserialize)] 40 41 pub struct RepoOp { 41 - pub action: String, 42 + pub action: RepoAction, 42 43 pub path: String, 43 44 pub cid: Option<Cid>, 44 45 #[serde(skip_serializing_if = "Option::is_none")] ··· 159 160 serde_json::from_value(self.ops_json).unwrap_or_else(|_| vec![]); 160 161 let ops: Vec<RepoOp> = json_ops 161 162 .into_iter() 162 - .map(|op| RepoOp { 163 - action: op.action, 164 - path: op.path, 165 - cid: op.cid.and_then(|s| Cid::from_str(&s).ok()), 166 - prev: op.prev.and_then(|s| Cid::from_str(&s).ok()), 163 + .filter_map(|op| { 164 + Some(RepoOp { 165 + action: RepoAction::parse_str(&op.action)?, 166 + path: op.path, 167 + cid: op.cid.and_then(|s| Cid::from_str(&s).ok()), 168 + prev: op.prev.and_then(|s| Cid::from_str(&s).ok()), 169 + }) 167 170 }) 168 171 .collect(); 169 172 let rev = self.rev.unwrap_or_else(placeholder_rev); ··· 202 205 CommitFrameError::InvalidCommitCid("Missing commit_cid in event".to_string()) 203 206 })?; 204 207 let builder = CommitFrameBuilder::new( 205 - event.seq, 208 + event.seq.as_i64(), 206 209 event.did.to_string(), 207 210 commit_cid.as_str(), 208 211 event.prev_cid.as_ref().map(|c| c.as_str()),
+21 -10
crates/tranquil-pds/src/sync/listener.rs
··· 2 2 use crate::sync::firehose::SequencedEvent; 3 3 use std::sync::atomic::{AtomicI64, Ordering}; 4 4 use tracing::{debug, error, info, warn}; 5 + use tranquil_db_traits::SequenceNumber; 5 6 6 7 static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0); 7 8 8 9 pub async fn start_sequencer_listener(state: AppState) { 9 - let initial_seq = state.repo_repo.get_max_seq().await.unwrap_or(0); 10 - LAST_BROADCAST_SEQ.store(initial_seq, Ordering::SeqCst); 11 - info!(initial_seq = initial_seq, "Initialized sequencer listener"); 10 + let initial_seq = state 11 + .repo_repo 12 + .get_max_seq() 13 + .await 14 + .unwrap_or(SequenceNumber::ZERO); 15 + LAST_BROADCAST_SEQ.store(initial_seq.as_i64(), Ordering::SeqCst); 16 + info!( 17 + initial_seq = initial_seq.as_i64(), 18 + "Initialized sequencer listener" 19 + ); 12 20 tokio::spawn(async move { 13 21 info!("Starting sequencer listener background task"); 14 22 loop { ··· 27 35 .await 28 36 .map_err(|e| anyhow::anyhow!("Failed to subscribe to events: {:?}", e))?; 29 37 info!("Connected to database and listening for repo updates"); 30 - let catchup_start = LAST_BROADCAST_SEQ.load(Ordering::SeqCst); 38 + let catchup_start = SequenceNumber::from_raw(LAST_BROADCAST_SEQ.load(Ordering::SeqCst)); 31 39 let events = state 32 40 .repo_repo 33 41 .get_events_since_seq(catchup_start, None) ··· 36 44 if !events.is_empty() { 37 45 info!( 38 46 count = events.len(), 39 - from_seq = catchup_start, 47 + from_seq = catchup_start.as_i64(), 40 48 "Broadcasting catch-up events" 41 49 ); 42 50 events.into_iter().for_each(|event| { 43 51 let seq = event.seq; 44 52 let firehose_event = to_firehose_event(event); 45 53 let _ = state.firehose_tx.send(firehose_event); 46 - LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst); 54 + LAST_BROADCAST_SEQ.store(seq.as_i64(), Ordering::SeqCst); 47 55 }); 48 56 } 49 57 loop { ··· 63 71 if seq_id > last_seq + 1 { 64 72 let gap_events = state 65 73 .repo_repo 66 - .get_events_in_seq_range(last_seq, seq_id) 74 + .get_events_in_seq_range( 75 + SequenceNumber::from_raw(last_seq), 76 + SequenceNumber::from_raw(seq_id), 77 + ) 67 78 .await 68 79 .unwrap_or_default(); 69 80 if !gap_events.is_empty() { ··· 72 83 let seq = event.seq; 73 84 let firehose_event = to_firehose_event(event); 74 85 let _ = state.firehose_tx.send(firehose_event); 75 - LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst); 86 + LAST_BROADCAST_SEQ.store(seq.as_i64(), Ordering::SeqCst); 76 87 }); 77 88 } 78 89 } 79 90 let event = state 80 91 .repo_repo 81 - .get_event_by_seq(seq_id) 92 + .get_event_by_seq(SequenceNumber::from_raw(seq_id)) 82 93 .await 83 94 .ok() 84 95 .flatten(); ··· 97 108 warn!(seq = seq_id, error = %e, "Failed to broadcast event (no receivers?)"); 98 109 } 99 110 } 100 - LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst); 111 + LAST_BROADCAST_SEQ.store(seq.as_i64(), Ordering::SeqCst); 101 112 } else { 102 113 warn!( 103 114 seq = seq_id,
+2 -2
crates/tranquil-pds/src/sync/mod.rs
··· 18 18 pub use deprecated::{get_checkout, get_head}; 19 19 pub use repo::{get_blocks, get_record, get_repo}; 20 20 pub use subscribe_repos::subscribe_repos; 21 + pub use tranquil_db_traits::AccountStatus; 21 22 pub use util::{ 22 - AccountStatus, RepoAccount, RepoAvailabilityError, assert_repo_availability, 23 - get_account_with_status, 23 + RepoAccount, RepoAvailabilityError, assert_repo_availability, get_account_with_status, 24 24 }; 25 25 pub use verify::{CarVerifier, VerifiedCar, VerifyError};
+12 -6
crates/tranquil-pds/src/sync/subscribe_repos.rs
··· 13 13 use std::sync::atomic::{AtomicUsize, Ordering}; 14 14 use tokio::sync::broadcast::error::RecvError; 15 15 use tracing::{error, info, warn}; 16 + use tranquil_db_traits::SequenceNumber; 16 17 17 18 const BACKFILL_BATCH_SIZE: i64 = 1000; 18 19 ··· 69 70 params: SubscribeReposParams, 70 71 ) -> Result<(), ()> { 71 72 let mut rx = state.firehose_tx.subscribe(); 72 - let mut last_seen: i64 = -1; 73 + let mut last_seen = SequenceNumber::UNSET; 73 74 74 75 if let Some(cursor) = params.cursor { 75 - let current_seq = state.repo_repo.get_max_seq().await.unwrap_or(0); 76 + let cursor_seq = SequenceNumber::from_raw(cursor); 77 + let current_seq = state 78 + .repo_repo 79 + .get_max_seq() 80 + .await 81 + .unwrap_or(SequenceNumber::ZERO); 76 82 77 - if cursor > current_seq { 83 + if cursor_seq > current_seq { 78 84 if let Ok(error_bytes) = 79 85 format_error_frame("FutureCursor", Some("Cursor in the future.")) 80 86 { ··· 88 94 89 95 let first_event = state 90 96 .repo_repo 91 - .get_events_since_cursor(cursor, 1) 97 + .get_events_since_cursor(cursor_seq, 1) 92 98 .await 93 99 .ok() 94 100 .and_then(|events| events.into_iter().next()); 95 101 96 - let mut current_cursor = cursor; 102 + let mut current_cursor = cursor_seq; 97 103 98 104 if let Some(ref event) = first_event 99 105 && event.created_at < backfill_time ··· 113 119 .flatten(); 114 120 115 121 if let Some(earliest_seq) = earliest { 116 - current_cursor = earliest_seq - 1; 122 + current_cursor = SequenceNumber::from_raw(earliest_seq.as_i64() - 1); 117 123 } 118 124 } 119 125
+18 -102
crates/tranquil-pds/src/sync/util.rs
··· 11 11 use iroh_car::{CarHeader, CarWriter}; 12 12 use jacquard_repo::commit::Commit; 13 13 use jacquard_repo::storage::BlockStore; 14 - use serde::Serialize; 15 14 use std::collections::{BTreeMap, HashMap}; 16 15 use std::io::Cursor; 17 16 use std::str::FromStr; 18 17 use tokio::io::AsyncWriteExt; 19 - use tranquil_db_traits::RepoRepository; 18 + use tranquil_db_traits::{AccountStatus, RepoEventType, RepoRepository}; 20 19 use tranquil_types::Did; 21 20 22 - #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] 23 - #[serde(rename_all = "lowercase")] 24 - pub enum AccountStatus { 25 - Active, 26 - Takendown, 27 - Suspended, 28 - Deactivated, 29 - Deleted, 30 - } 31 - 32 - impl AccountStatus { 33 - pub fn as_str(&self) -> Option<&'static str> { 34 - match self { 35 - Self::Active => None, 36 - Self::Takendown => Some("takendown"), 37 - Self::Suspended => Some("suspended"), 38 - Self::Deactivated => Some("deactivated"), 39 - Self::Deleted => Some("deleted"), 40 - } 41 - } 42 - 43 - pub fn is_active(&self) -> bool { 44 - matches!(self, Self::Active) 45 - } 46 - 47 - pub fn is_takendown(&self) -> bool { 48 - matches!(self, Self::Takendown) 49 - } 50 - 51 - pub fn is_suspended(&self) -> bool { 52 - matches!(self, Self::Suspended) 53 - } 54 - 55 - pub fn is_deactivated(&self) -> bool { 56 - matches!(self, Self::Deactivated) 57 - } 58 - 59 - pub fn is_deleted(&self) -> bool { 60 - matches!(self, Self::Deleted) 61 - } 62 - 63 - pub fn allows_read(&self) -> bool { 64 - matches!(self, Self::Active | Self::Deactivated) 65 - } 66 - 67 - pub fn allows_write(&self) -> bool { 68 - matches!(self, Self::Active) 69 - } 70 - 71 - pub fn from_db_fields( 72 - takedown_ref: Option<&str>, 73 - deactivated_at: Option<chrono::DateTime<chrono::Utc>>, 74 - ) -> Self { 75 - if takedown_ref.is_some() { 76 - Self::Takendown 77 - } else if deactivated_at.is_some() { 78 - Self::Deactivated 79 - } else { 80 - Self::Active 81 - } 82 - } 83 - } 84 - 85 - impl From<crate::types::AccountState> for AccountStatus { 86 - fn from(state: crate::types::AccountState) -> Self { 87 - match state { 88 - crate::types::AccountState::Active => AccountStatus::Active, 89 - crate::types::AccountState::Deactivated { .. } => AccountStatus::Deactivated, 90 - crate::types::AccountState::TakenDown { .. } => AccountStatus::Takendown, 91 - crate::types::AccountState::Migrated { .. } => AccountStatus::Deactivated, 92 - } 93 - } 94 - } 95 - 96 - impl From<&crate::types::AccountState> for AccountStatus { 97 - fn from(state: &crate::types::AccountState) -> Self { 98 - match state { 99 - crate::types::AccountState::Active => AccountStatus::Active, 100 - crate::types::AccountState::Deactivated { .. } => AccountStatus::Deactivated, 101 - crate::types::AccountState::TakenDown { .. } => AccountStatus::Takendown, 102 - crate::types::AccountState::Migrated { .. } => AccountStatus::Deactivated, 103 - } 104 - } 105 - } 106 - 107 21 pub struct RepoAccount { 108 22 pub did: String, 109 23 pub user_id: uuid::Uuid, ··· 233 147 let frame = IdentityFrame { 234 148 did: event.did.to_string(), 235 149 handle: event.handle.as_ref().map(|h| h.to_string()), 236 - seq: event.seq, 150 + seq: event.seq.as_i64(), 237 151 time: format_atproto_time(event.created_at), 238 152 }; 239 153 let header = FrameHeader { ··· 250 164 let frame = AccountFrame { 251 165 did: event.did.to_string(), 252 166 active: event.active.unwrap_or(true), 253 - status: event.status.clone(), 254 - seq: event.seq, 167 + status: event 168 + .status 169 + .and_then(|s| s.for_firehose().map(String::from)), 170 + seq: event.seq.as_i64(), 255 171 time: format_atproto_time(event.created_at), 256 172 }; 257 173 let header = FrameHeader { ··· 298 214 did: event.did.to_string(), 299 215 rev, 300 216 blocks: car_bytes, 301 - seq: event.seq, 217 + seq: event.seq.as_i64(), 302 218 time: format_atproto_time(event.created_at), 303 219 }; 304 220 let header = FrameHeader { ··· 315 231 state: &AppState, 316 232 event: SequencedEvent, 317 233 ) -> Result<Vec<u8>, anyhow::Error> { 318 - match event.event_type.as_str() { 319 - "identity" => return format_identity_event(&event), 320 - "account" => return format_account_event(&event), 321 - "sync" => return format_sync_event(state, &event).await, 322 - _ => {} 234 + match event.event_type { 235 + RepoEventType::Identity => return format_identity_event(&event), 236 + RepoEventType::Account => return format_account_event(&event), 237 + RepoEventType::Sync => return format_sync_event(state, &event).await, 238 + RepoEventType::Commit => {} 323 239 } 324 240 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 325 241 let prev_cid_link = event.prev_cid.clone(); ··· 440 356 did: event.did.to_string(), 441 357 rev, 442 358 blocks: car_bytes, 443 - seq: event.seq, 359 + seq: event.seq.as_i64(), 444 360 time: format_atproto_time(event.created_at), 445 361 }; 446 362 let header = FrameHeader { ··· 457 373 event: SequencedEvent, 458 374 prefetched: &HashMap<Cid, Bytes>, 459 375 ) -> Result<Vec<u8>, anyhow::Error> { 460 - match event.event_type.as_str() { 461 - "identity" => return format_identity_event(&event), 462 - "account" => return format_account_event(&event), 463 - "sync" => return format_sync_event_with_prefetched(&event, prefetched), 464 - _ => {} 376 + match event.event_type { 377 + RepoEventType::Identity => return format_identity_event(&event), 378 + RepoEventType::Account => return format_account_event(&event), 379 + RepoEventType::Sync => return format_sync_event_with_prefetched(&event, prefetched), 380 + RepoEventType::Commit => {} 465 381 } 466 382 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 467 383 let prev_cid_link = event.prev_cid.clone();
+20 -4
crates/tranquil-pds/src/util.rs
··· 4 4 use rand::Rng; 5 5 use serde_json::Value as JsonValue; 6 6 use std::collections::BTreeMap; 7 + use std::net::SocketAddr; 7 8 use std::str::FromStr; 8 9 use std::sync::OnceLock; 9 10 ··· 11 12 const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; 12 13 13 14 static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new(); 15 + static PDS_HOSTNAME: OnceLock<String> = OnceLock::new(); 16 + static PDS_HOSTNAME_WITHOUT_PORT: OnceLock<String> = OnceLock::new(); 14 17 15 18 pub fn get_max_blob_size() -> usize { 16 19 *MAX_BLOB_SIZE.get_or_init(|| { ··· 69 72 .unwrap_or_default() 70 73 } 71 74 72 - pub fn extract_client_ip(headers: &HeaderMap) -> String { 75 + pub fn get_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { 76 + headers.get(name).and_then(|h| h.to_str().ok()) 77 + } 78 + 79 + pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 73 80 if let Some(forwarded) = headers.get("x-forwarded-for") 74 81 && let Ok(value) = forwarded.to_str() 75 82 && let Some(first_ip) = value.split(',').next() ··· 81 88 { 82 89 return value.trim().to_string(); 83 90 } 84 - "unknown".to_string() 91 + addr.map(|a| a.ip().to_string()) 92 + .unwrap_or_else(|| "unknown".to_string()) 85 93 } 86 94 87 - pub fn pds_hostname() -> String { 88 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 95 + pub fn pds_hostname() -> &'static str { 96 + PDS_HOSTNAME 97 + .get_or_init(|| std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())) 98 + } 99 + 100 + pub fn pds_hostname_without_port() -> &'static str { 101 + PDS_HOSTNAME_WITHOUT_PORT.get_or_init(|| { 102 + let hostname = pds_hostname(); 103 + hostname.split(':').next().unwrap_or(hostname).to_string() 104 + }) 89 105 } 90 106 91 107 pub fn pds_public_url() -> String {
+1 -1
crates/tranquil-pds/tests/commit_signing.rs
··· 99 99 use tranquil_pds::api::repo::record::utils::create_signed_commit; 100 100 101 101 let signing_key = SigningKey::random(&mut rand::thread_rng()); 102 - let did = Did::new_unchecked("did:plc:testuser123456789abcdef"); 102 + let did = unsafe { Did::new_unchecked("did:plc:testuser123456789abcdef") }; 103 103 let data_cid = 104 104 Cid::from_str("bafyreib2rxk3ryblouj3fxza5jvx6psmwewwessc4m6g6e7pqhhkwqomfi").unwrap(); 105 105 let rev = Tid::now(LimitedU32::MIN).to_string();
+2 -1
crates/tranquil-pds/tests/common/mod.rs
··· 14 14 #[allow(unused_imports)] 15 15 use std::time::Duration; 16 16 use tokio::net::TcpListener; 17 + use tokio_util::sync::CancellationToken; 17 18 use tranquil_pds::state::AppState; 18 19 use wiremock::matchers::{method, path}; 19 20 use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate}; ··· 546 547 .with_email_update_limit(10000) 547 548 .with_oauth_authorize_limit(10000) 548 549 .with_oauth_token_limit(10000); 549 - let state = AppState::from_db(pool) 550 + let state = AppState::from_db(pool, CancellationToken::new()) 550 551 .await 551 552 .with_rate_limiters(rate_limiters); 552 553 tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
+6 -10
crates/tranquil-pds/tests/firehose_validation.rs
··· 9 9 use serde_json::{Value, json}; 10 10 use std::io::Cursor; 11 11 use tokio_tungstenite::{connect_async, tungstenite}; 12 + use tranquil_scopes::RepoAction; 12 13 13 14 #[derive(Debug, Deserialize, Serialize)] 14 15 struct FrameHeader { ··· 39 40 40 41 #[derive(Debug, Deserialize)] 41 42 struct RepoOp { 42 - action: String, 43 + action: RepoAction, 43 44 path: String, 44 45 cid: Option<Cid>, 45 46 prev: Option<Cid>, ··· 292 293 println!("\nOps validation:"); 293 294 for (i, op) in frame.ops.iter().enumerate() { 294 295 println!(" Op {}:", i); 295 - println!(" action: {}", op.action); 296 + println!(" action: {:?}", op.action); 296 297 println!(" path: {}", op.path); 297 298 println!(" cid: {:?}", op.cid); 298 299 println!( ··· 300 301 op.prev 301 302 ); 302 303 303 - assert!( 304 - ["create", "update", "delete"].contains(&op.action.as_str()), 305 - "Invalid action: {}", 306 - op.action 307 - ); 308 304 assert!( 309 305 op.path.contains('/'), 310 306 "Path should contain collection/rkey: {}", 311 307 op.path 312 308 ); 313 309 314 - if op.action == "create" { 310 + if op.action == RepoAction::Create { 315 311 assert!(op.cid.is_some(), "Create op should have cid"); 316 312 } 317 313 } ··· 445 441 446 442 for op in &frame.ops { 447 443 println!( 448 - "Op: action={}, path={}, cid={:?}, prev={:?}", 444 + "Op: action={:?}, path={}, cid={:?}, prev={:?}", 449 445 op.action, op.path, op.cid, op.prev 450 446 ); 451 447 452 - if op.action == "update" && op.path.contains("app.bsky.actor.profile") { 448 + if op.action == RepoAction::Update && op.path.contains("app.bsky.actor.profile") { 453 449 assert!( 454 450 op.prev.is_some(), 455 451 "Update operation should have 'prev' field with old CID! Got: {:?}",
+71
crates/tranquil-pds/tests/shutdown_unit.rs
··· 1 + use std::sync::Arc; 2 + use std::sync::atomic::{AtomicBool, Ordering}; 3 + use tokio_util::sync::CancellationToken; 4 + 5 + #[test] 6 + fn test_panic_hook_cancels_shutdown_token() { 7 + let shutdown = CancellationToken::new(); 8 + let shutdown_clone = shutdown.clone(); 9 + 10 + let panic_occurred = Arc::new(AtomicBool::new(false)); 11 + let panic_occurred_clone = panic_occurred.clone(); 12 + 13 + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { 14 + let default_hook = std::panic::take_hook(); 15 + std::panic::set_hook(Box::new(move |info| { 16 + panic_occurred_clone.store(true, Ordering::SeqCst); 17 + shutdown_clone.cancel(); 18 + default_hook(info); 19 + })); 20 + 21 + panic!("simulated corrupted data panic"); 22 + })); 23 + 24 + assert!(result.is_err()); 25 + assert!(panic_occurred.load(Ordering::SeqCst)); 26 + assert!(shutdown.is_cancelled()); 27 + 28 + let _ = std::panic::take_hook(); 29 + } 30 + 31 + #[test] 32 + fn test_cancellation_token_propagates_to_clones() { 33 + let shutdown = CancellationToken::new(); 34 + let clone1 = shutdown.clone(); 35 + let clone2 = shutdown.clone(); 36 + 37 + assert!(!shutdown.is_cancelled()); 38 + assert!(!clone1.is_cancelled()); 39 + assert!(!clone2.is_cancelled()); 40 + 41 + shutdown.cancel(); 42 + 43 + assert!(shutdown.is_cancelled()); 44 + assert!(clone1.is_cancelled()); 45 + assert!(clone2.is_cancelled()); 46 + } 47 + 48 + #[tokio::test] 49 + async fn test_cancelled_future_completes_on_cancel() { 50 + let shutdown = CancellationToken::new(); 51 + let shutdown_clone = shutdown.clone(); 52 + 53 + let handle = tokio::spawn(async move { 54 + shutdown_clone.cancelled().await; 55 + true 56 + }); 57 + 58 + tokio::time::sleep(std::time::Duration::from_millis(10)).await; 59 + assert!(!handle.is_finished()); 60 + 61 + shutdown.cancel(); 62 + 63 + let result = tokio::time::timeout( 64 + std::time::Duration::from_millis(100), 65 + handle, 66 + ) 67 + .await; 68 + 69 + assert!(result.is_ok()); 70 + assert!(result.unwrap().unwrap()); 71 + }
+25 -17
crates/tranquil-pds/tests/sso.rs
··· 232 232 let _url = base_url().await; 233 233 let pool = get_test_db_pool().await; 234 234 235 - let did = Did::new_unchecked(format!( 236 - "did:plc:test{}", 237 - &uuid::Uuid::new_v4().simple().to_string()[..12] 238 - )); 235 + let did = unsafe { 236 + Did::new_unchecked(format!( 237 + "did:plc:test{}", 238 + &uuid::Uuid::new_v4().simple().to_string()[..12] 239 + )) 240 + }; 239 241 let provider = SsoProviderType::Github; 240 242 let provider_user_id = format!("github_user_{}", uuid::Uuid::new_v4().simple()); 241 243 ··· 350 352 let _url = base_url().await; 351 353 let pool = get_test_db_pool().await; 352 354 353 - let did1 = Did::new_unchecked(format!( 354 - "did:plc:uc1{}", 355 - &uuid::Uuid::new_v4().simple().to_string()[..10] 356 - )); 357 - let did2 = Did::new_unchecked(format!( 358 - "did:plc:uc2{}", 359 - &uuid::Uuid::new_v4().simple().to_string()[..10] 360 - )); 355 + let did1 = unsafe { 356 + Did::new_unchecked(format!( 357 + "did:plc:uc1{}", 358 + &uuid::Uuid::new_v4().simple().to_string()[..10] 359 + )) 360 + }; 361 + let did2 = unsafe { 362 + Did::new_unchecked(format!( 363 + "did:plc:uc2{}", 364 + &uuid::Uuid::new_v4().simple().to_string()[..10] 365 + )) 366 + }; 361 367 let provider_user_id = format!("unique_test_{}", uuid::Uuid::new_v4().simple()); 362 368 363 369 sqlx::query!( ··· 577 583 let _url = base_url().await; 578 584 let pool = get_test_db_pool().await; 579 585 580 - let did = Did::new_unchecked(format!( 581 - "did:plc:del{}", 582 - &uuid::Uuid::new_v4().simple().to_string()[..10] 583 - )); 584 - let wrong_did = Did::new_unchecked("did:plc:wrongdid12345"); 586 + let did = unsafe { 587 + Did::new_unchecked(format!( 588 + "did:plc:del{}", 589 + &uuid::Uuid::new_v4().simple().to_string()[..10] 590 + )) 591 + }; 592 + let wrong_did = unsafe { Did::new_unchecked("did:plc:wrongdid12345") }; 585 593 586 594 sqlx::query!( 587 595 "INSERT INTO users (did, handle, email, password_hash) VALUES ($1, $2, $3, 'hash')",
+3 -1
crates/tranquil-scopes/src/parser.rs
··· 1 + use serde::{Deserialize, Serialize}; 1 2 use std::collections::{HashMap, HashSet}; 2 3 3 4 #[derive(Debug, Clone, PartialEq, Eq)] ··· 27 28 pub actions: HashSet<RepoAction>, 28 29 } 29 30 30 - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 31 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] 32 + #[serde(rename_all = "lowercase")] 31 33 pub enum RepoAction { 32 34 Create, 33 35 Update,
+6 -3
crates/tranquil-types/src/lib.rs
··· 166 166 Ok(Self(s)) 167 167 } 168 168 169 - pub fn new_unchecked(s: impl Into<String>) -> Self { 169 + #[allow(unsafe_code, clippy::missing_safety_doc)] 170 + pub unsafe fn new_unchecked(s: impl Into<String>) -> Self { 170 171 Self(s.into()) 171 172 } 172 173 } ··· 228 229 Ok(Self(s)) 229 230 } 230 231 231 - pub fn new_unchecked(s: impl Into<String>) -> Self { 232 + #[allow(unsafe_code, clippy::missing_safety_doc)] 233 + pub unsafe fn new_unchecked(s: impl Into<String>) -> Self { 232 234 Self(s.into()) 233 235 } 234 236 } ··· 489 491 Ok(Self(s)) 490 492 } 491 493 492 - pub fn new_unchecked(s: impl Into<String>) -> Self { 494 + #[allow(unsafe_code, clippy::missing_safety_doc)] 495 + pub unsafe fn new_unchecked(s: impl Into<String>) -> Self { 493 496 Self(s.into()) 494 497 } 495 498

History

3 rounds 0 comments
sign up or login to add to the discussion
1 commit
expand
fix: better type-safety
expand 0 comments
pull request successfully merged
1 commit
expand
fix: better type-safety
expand 0 comments
lewis.moe submitted #0
1 commit
expand
fix: better type-safety
expand 0 comments