Ensuring at compile-time that we're definitely handling possible early failures in functions
+27
.sqlx/query-03fc2ba947ee547e000b044fafb486e71b9b65a7dd923b5354c5a4dde98332eb.json
+27
.sqlx/query-03fc2ba947ee547e000b044fafb486e71b9b65a7dd923b5354c5a4dde98332eb.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "UPDATE users SET preferred_comms_channel = $1, updated_at = NOW() WHERE did = $2",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
{
9
+
"Custom": {
10
+
"name": "comms_channel",
11
+
"kind": {
12
+
"Enum": [
13
+
"email",
14
+
"discord",
15
+
"telegram",
16
+
"signal"
17
+
]
18
+
}
19
+
}
20
+
},
21
+
"Text"
22
+
]
23
+
},
24
+
"nullable": []
25
+
},
26
+
"hash": "03fc2ba947ee547e000b044fafb486e71b9b65a7dd923b5354c5a4dde98332eb"
27
+
}
+3
-3
.sqlx/query-805a344e73f2c19caaffe71de227ddd505599839033e83ae4be5b243d343d651.json
.sqlx/query-0d32a592a97ad47c65aa37cf0d45417f2966fcbd688be7434626ae5f6971fa1f.json
+3
-3
.sqlx/query-805a344e73f2c19caaffe71de227ddd505599839033e83ae4be5b243d343d651.json
.sqlx/query-0d32a592a97ad47c65aa37cf0d45417f2966fcbd688be7434626ae5f6971fa1f.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq = $1",
3
+
"query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq = $1",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
20
20
},
21
21
{
22
22
"ordinal": 3,
23
-
"name": "event_type",
23
+
"name": "event_type: RepoEventType",
24
24
"type_info": "Text"
25
25
},
26
26
{
···
96
96
true
97
97
]
98
98
},
99
-
"hash": "805a344e73f2c19caaffe71de227ddd505599839033e83ae4be5b243d343d651"
99
+
"hash": "0d32a592a97ad47c65aa37cf0d45417f2966fcbd688be7434626ae5f6971fa1f"
100
100
}
+28
.sqlx/query-200ecf153f1433ae8f6fbe81ab888a04ddd035ec9e88ef5f207e2487a02a1224.json
+28
.sqlx/query-200ecf153f1433ae8f6fbe81ab888a04ddd035ec9e88ef5f207e2487a02a1224.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT available_uses, COALESCE(disabled, false) as \"disabled!\" FROM invite_codes WHERE code = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "available_uses",
9
+
"type_info": "Int4"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "disabled!",
14
+
"type_info": "Bool"
15
+
}
16
+
],
17
+
"parameters": {
18
+
"Left": [
19
+
"Text"
20
+
]
21
+
},
22
+
"nullable": [
23
+
false,
24
+
null
25
+
]
26
+
},
27
+
"hash": "200ecf153f1433ae8f6fbe81ab888a04ddd035ec9e88ef5f207e2487a02a1224"
28
+
}
+17
-5
.sqlx/query-426fedba6791c420fe7af6decc296c681d05a5c24a38b8cd7083c8dfa9178ded.json
.sqlx/query-247470d26a90617e7dc9b5b3a2146ee3f54448e3c24943f7005e3a8e28820d43.json
+17
-5
.sqlx/query-426fedba6791c420fe7af6decc296c681d05a5c24a38b8cd7083c8dfa9178ded.json
.sqlx/query-247470d26a90617e7dc9b5b3a2146ee3f54448e3c24943f7005e3a8e28820d43.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT\n email,\n preferred_comms_channel::text as \"preferred_channel!\",\n discord_id,\n discord_verified,\n telegram_username,\n telegram_verified,\n signal_number,\n signal_verified\n FROM users WHERE did = $1",
3
+
"query": "SELECT\n email,\n preferred_comms_channel as \"preferred_channel!: CommsChannel\",\n discord_id,\n discord_verified,\n telegram_username,\n telegram_verified,\n signal_number,\n signal_verified\n FROM users WHERE did = $1",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
10
10
},
11
11
{
12
12
"ordinal": 1,
13
-
"name": "preferred_channel!",
14
-
"type_info": "Text"
13
+
"name": "preferred_channel!: CommsChannel",
14
+
"type_info": {
15
+
"Custom": {
16
+
"name": "comms_channel",
17
+
"kind": {
18
+
"Enum": [
19
+
"email",
20
+
"discord",
21
+
"telegram",
22
+
"signal"
23
+
]
24
+
}
25
+
}
26
+
}
15
27
},
16
28
{
17
29
"ordinal": 2,
···
51
63
},
52
64
"nullable": [
53
65
true,
54
-
null,
66
+
false,
55
67
true,
56
68
false,
57
69
true,
···
60
72
false
61
73
]
62
74
},
63
-
"hash": "426fedba6791c420fe7af6decc296c681d05a5c24a38b8cd7083c8dfa9178ded"
75
+
"hash": "247470d26a90617e7dc9b5b3a2146ee3f54448e3c24943f7005e3a8e28820d43"
64
76
}
+5
-5
.sqlx/query-9fea6394495b70ef5af2c2f5298e651d1ae78aa9ac6b03f952b6b0416023f671.json
.sqlx/query-25309f4a08845a49557d694ad9b5b9a137be4dcce28e9293551c8c3fd40fdd86.json
+5
-5
.sqlx/query-9fea6394495b70ef5af2c2f5298e651d1ae78aa9ac6b03f952b6b0416023f671.json
.sqlx/query-25309f4a08845a49557d694ad9b5b9a137be4dcce28e9293551c8c3fd40fdd86.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "\n SELECT\n created_at,\n channel as \"channel: String\",\n comms_type as \"comms_type: String\",\n status as \"status: String\",\n subject,\n body\n FROM comms_queue\n WHERE user_id = $1\n ORDER BY created_at DESC\n LIMIT $2\n ",
3
+
"query": "\n SELECT\n created_at,\n channel as \"channel: CommsChannel\",\n comms_type as \"comms_type: CommsType\",\n status as \"status: CommsStatus\",\n subject,\n body\n FROM comms_queue\n WHERE user_id = $1\n ORDER BY created_at DESC\n LIMIT $2\n ",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
10
10
},
11
11
{
12
12
"ordinal": 1,
13
-
"name": "channel: String",
13
+
"name": "channel: CommsChannel",
14
14
"type_info": {
15
15
"Custom": {
16
16
"name": "comms_channel",
···
27
27
},
28
28
{
29
29
"ordinal": 2,
30
-
"name": "comms_type: String",
30
+
"name": "comms_type: CommsType",
31
31
"type_info": {
32
32
"Custom": {
33
33
"name": "comms_type",
···
52
52
},
53
53
{
54
54
"ordinal": 3,
55
-
"name": "status: String",
55
+
"name": "status: CommsStatus",
56
56
"type_info": {
57
57
"Custom": {
58
58
"name": "comms_status",
···
93
93
false
94
94
]
95
95
},
96
-
"hash": "9fea6394495b70ef5af2c2f5298e651d1ae78aa9ac6b03f952b6b0416023f671"
96
+
"hash": "25309f4a08845a49557d694ad9b5b9a137be4dcce28e9293551c8c3fd40fdd86"
97
97
}
-22
.sqlx/query-36441073d3fb87230f88ddce4e597c248fbf7360e510d703b9eec42efe9e049e.json
-22
.sqlx/query-36441073d3fb87230f88ddce4e597c248fbf7360e510d703b9eec42efe9e049e.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT (available_uses > 0 AND NOT COALESCE(disabled, false)) as \"valid!\" FROM invite_codes WHERE code = $1",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "valid!",
9
-
"type_info": "Bool"
10
-
}
11
-
],
12
-
"parameters": {
13
-
"Left": [
14
-
"Text"
15
-
]
16
-
},
17
-
"nullable": [
18
-
null
19
-
]
20
-
},
21
-
"hash": "36441073d3fb87230f88ddce4e597c248fbf7360e510d703b9eec42efe9e049e"
22
-
}
+15
-5
.sqlx/query-445c2ebb72f3833119f32284b9e721cf34c8ae581e6ae58a392fc93e77a7a015.json
.sqlx/query-7061e8763ef7d91ff152ed0124f99e1820172fd06916d225ca6c5137a507b8fa.json
+15
-5
.sqlx/query-445c2ebb72f3833119f32284b9e721cf34c8ae581e6ae58a392fc93e77a7a015.json
.sqlx/query-7061e8763ef7d91ff152ed0124f99e1820172fd06916d225ca6c5137a507b8fa.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "\n SELECT id, did, email, password_hash, password_required, two_factor_enabled,\n preferred_comms_channel as \"preferred_comms_channel!: CommsChannel\",\n deactivated_at, takedown_ref,\n email_verified, discord_verified, telegram_verified, signal_verified,\n account_type::text as \"account_type!\"\n FROM users\n WHERE handle = $1 OR email = $1\n ",
3
+
"query": "\n SELECT id, did, email, password_hash, password_required, two_factor_enabled,\n preferred_comms_channel as \"preferred_comms_channel!: CommsChannel\",\n deactivated_at, takedown_ref,\n email_verified, discord_verified, telegram_verified, signal_verified,\n account_type as \"account_type!: AccountType\"\n FROM users\n WHERE handle = $1 OR email = $1\n ",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
82
82
},
83
83
{
84
84
"ordinal": 13,
85
-
"name": "account_type!",
86
-
"type_info": "Text"
85
+
"name": "account_type!: AccountType",
86
+
"type_info": {
87
+
"Custom": {
88
+
"name": "account_type",
89
+
"kind": {
90
+
"Enum": [
91
+
"personal",
92
+
"delegated"
93
+
]
94
+
}
95
+
}
96
+
}
87
97
}
88
98
],
89
99
"parameters": {
···
105
115
false,
106
116
false,
107
117
false,
108
-
null
118
+
false
109
119
]
110
120
},
111
-
"hash": "445c2ebb72f3833119f32284b9e721cf34c8ae581e6ae58a392fc93e77a7a015"
121
+
"hash": "7061e8763ef7d91ff152ed0124f99e1820172fd06916d225ca6c5137a507b8fa"
112
122
}
+3
-3
.sqlx/query-605dc962cf86004de763aee65757a5a77da150b36aa8470c52fd5835e9b895fc.json
.sqlx/query-b26bf97a27783eb7fb524a92dda3e68ef8470a9751fcaefe5fd2d7909dead54b.json
+3
-3
.sqlx/query-605dc962cf86004de763aee65757a5a77da150b36aa8470c52fd5835e9b895fc.json
.sqlx/query-b26bf97a27783eb7fb524a92dda3e68ef8470a9751fcaefe5fd2d7909dead54b.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1 AND seq < $2\n ORDER BY seq ASC",
3
+
"query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
20
20
},
21
21
{
22
22
"ordinal": 3,
23
-
"name": "event_type",
23
+
"name": "event_type: RepoEventType",
24
24
"type_info": "Text"
25
25
},
26
26
{
···
97
97
true
98
98
]
99
99
},
100
-
"hash": "605dc962cf86004de763aee65757a5a77da150b36aa8470c52fd5835e9b895fc"
100
+
"hash": "b26bf97a27783eb7fb524a92dda3e68ef8470a9751fcaefe5fd2d7909dead54b"
101
101
}
+3
-3
.sqlx/query-e2befe7fa07a1072a8b3f0ed6c1a54a39ffc8769aa65391ea282c78d2cd29f23.json
.sqlx/query-b8101757a50075d20147014e450cb7deb7e58f84310690c7bde61e1834dc5903.json
+3
-3
.sqlx/query-e2befe7fa07a1072a8b3f0ed6c1a54a39ffc8769aa65391ea282c78d2cd29f23.json
.sqlx/query-b8101757a50075d20147014e450cb7deb7e58f84310690c7bde61e1834dc5903.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC",
3
+
"query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
20
20
},
21
21
{
22
22
"ordinal": 3,
23
-
"name": "event_type",
23
+
"name": "event_type: RepoEventType",
24
24
"type_info": "Text"
25
25
},
26
26
{
···
96
96
true
97
97
]
98
98
},
99
-
"hash": "e2befe7fa07a1072a8b3f0ed6c1a54a39ffc8769aa65391ea282c78d2cd29f23"
99
+
"hash": "b8101757a50075d20147014e450cb7deb7e58f84310690c7bde61e1834dc5903"
100
100
}
+3
-3
.sqlx/query-8f6a1e09351dc716eaadc9e30c5cfea45212901a139e98f0fccfacfbb3371dec.json
.sqlx/query-d8524ad3f5dc03eb09ed60396a78df5003f804c43ad253d6476523eacdebf811.json
+3
-3
.sqlx/query-8f6a1e09351dc716eaadc9e30c5cfea45212901a139e98f0fccfacfbb3371dec.json
.sqlx/query-d8524ad3f5dc03eb09ed60396a78df5003f804c43ad253d6476523eacdebf811.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2",
3
+
"query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1 AND seq < $2\n ORDER BY seq ASC",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
20
20
},
21
21
{
22
22
"ordinal": 3,
23
-
"name": "event_type",
23
+
"name": "event_type: RepoEventType",
24
24
"type_info": "Text"
25
25
},
26
26
{
···
97
97
true
98
98
]
99
99
},
100
-
"hash": "8f6a1e09351dc716eaadc9e30c5cfea45212901a139e98f0fccfacfbb3371dec"
100
+
"hash": "d8524ad3f5dc03eb09ed60396a78df5003f804c43ad253d6476523eacdebf811"
101
101
}
+52
.sqlx/query-d8fd97c8be3211b2509669dd859245b14e15f81a42d7e0c4c428b65f466af5ee.json
+52
.sqlx/query-d8fd97c8be3211b2509669dd859245b14e15f81a42d7e0c4c428b65f466af5ee.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT email, handle, preferred_comms_channel as \"preferred_channel!: CommsChannel\", preferred_locale\n FROM users WHERE id = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "email",
9
+
"type_info": "Text"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "handle",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "preferred_channel!: CommsChannel",
19
+
"type_info": {
20
+
"Custom": {
21
+
"name": "comms_channel",
22
+
"kind": {
23
+
"Enum": [
24
+
"email",
25
+
"discord",
26
+
"telegram",
27
+
"signal"
28
+
]
29
+
}
30
+
}
31
+
}
32
+
},
33
+
{
34
+
"ordinal": 3,
35
+
"name": "preferred_locale",
36
+
"type_info": "Varchar"
37
+
}
38
+
],
39
+
"parameters": {
40
+
"Left": [
41
+
"Uuid"
42
+
]
43
+
},
44
+
"nullable": [
45
+
true,
46
+
false,
47
+
false,
48
+
true
49
+
]
50
+
},
51
+
"hash": "d8fd97c8be3211b2509669dd859245b14e15f81a42d7e0c4c428b65f466af5ee"
52
+
}
-40
.sqlx/query-e3aeec9a759b2b68cb11fa48b5d34ffc19430a6b16adb0c49307da0cacdf1ca3.json
-40
.sqlx/query-e3aeec9a759b2b68cb11fa48b5d34ffc19430a6b16adb0c49307da0cacdf1ca3.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT email, handle, preferred_comms_channel::text as \"preferred_channel!\", preferred_locale\n FROM users WHERE id = $1",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "email",
9
-
"type_info": "Text"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "handle",
14
-
"type_info": "Text"
15
-
},
16
-
{
17
-
"ordinal": 2,
18
-
"name": "preferred_channel!",
19
-
"type_info": "Text"
20
-
},
21
-
{
22
-
"ordinal": 3,
23
-
"name": "preferred_locale",
24
-
"type_info": "Varchar"
25
-
}
26
-
],
27
-
"parameters": {
28
-
"Left": [
29
-
"Uuid"
30
-
]
31
-
},
32
-
"nullable": [
33
-
true,
34
-
false,
35
-
null,
36
-
true
37
-
]
38
-
},
39
-
"hash": "e3aeec9a759b2b68cb11fa48b5d34ffc19430a6b16adb0c49307da0cacdf1ca3"
40
-
}
+3
-3
.sqlx/query-caffa68d10445a42878b66e6b0224dafb8527c8a4cc9806d6f733edff72bc9db.json
.sqlx/query-e7aa1080be9eb3a8ddf1f050c93dc8afd10478f41e22307014784b4ee3740b4a.json
+3
-3
.sqlx/query-caffa68d10445a42878b66e6b0224dafb8527c8a4cc9806d6f733edff72bc9db.json
.sqlx/query-e7aa1080be9eb3a8ddf1f050c93dc8afd10478f41e22307014784b4ee3740b4a.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2",
3
+
"query": "SELECT seq, did, created_at, event_type as \"event_type: RepoEventType\", commit_cid, prev_cid, prev_data_cid,\n ops, blobs, blocks_cids, handle, active, status, rev\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
20
20
},
21
21
{
22
22
"ordinal": 3,
23
-
"name": "event_type",
23
+
"name": "event_type: RepoEventType",
24
24
"type_info": "Text"
25
25
},
26
26
{
···
97
97
true
98
98
]
99
99
},
100
-
"hash": "caffa68d10445a42878b66e6b0224dafb8527c8a4cc9806d6f733edff72bc9db"
100
+
"hash": "e7aa1080be9eb3a8ddf1f050c93dc8afd10478f41e22307014784b4ee3740b4a"
101
101
}
+3
-2
Cargo.lock
+3
-2
Cargo.lock
···
5733
5733
5734
5734
[[package]]
5735
5735
name = "tokio-util"
5736
-
version = "0.7.17"
5736
+
version = "0.7.18"
5737
5737
source = "registry+https://github.com/rust-lang/crates.io-index"
5738
-
checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594"
5738
+
checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098"
5739
5739
dependencies = [
5740
5740
"bytes",
5741
5741
"futures-core",
···
6115
6115
"thiserror 2.0.17",
6116
6116
"tokio",
6117
6117
"tokio-tungstenite",
6118
+
"tokio-util",
6118
6119
"tower",
6119
6120
"tower-http",
6120
6121
"tower-layer",
+1
Cargo.toml
+1
Cargo.toml
···
87
87
subtle = "2.5"
88
88
thiserror = "2.0"
89
89
tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "time", "signal", "process"] }
90
+
tokio-util = "0.7.18"
90
91
tokio-tungstenite = { version = "0.28", features = ["native-tls"] }
91
92
totp-rs = { version = "5", features = ["qr"] }
92
93
tower = "0.5"
+51
crates/tranquil-db-traits/src/channel_verification.rs
+51
crates/tranquil-db-traits/src/channel_verification.rs
···
1
+
use crate::CommsChannel;
2
+
use serde::{Deserialize, Serialize};
3
+
4
+
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
5
+
pub struct ChannelVerificationStatus {
6
+
pub email: bool,
7
+
pub discord: bool,
8
+
pub telegram: bool,
9
+
pub signal: bool,
10
+
}
11
+
12
+
impl ChannelVerificationStatus {
13
+
pub fn new(email: bool, discord: bool, telegram: bool, signal: bool) -> Self {
14
+
Self {
15
+
email,
16
+
discord,
17
+
telegram,
18
+
signal,
19
+
}
20
+
}
21
+
22
+
pub fn has_any_verified(&self) -> bool {
23
+
self.email || self.discord || self.telegram || self.signal
24
+
}
25
+
26
+
pub fn verified_channels(&self) -> Vec<CommsChannel> {
27
+
let mut channels = Vec::with_capacity(4);
28
+
if self.email {
29
+
channels.push(CommsChannel::Email);
30
+
}
31
+
if self.discord {
32
+
channels.push(CommsChannel::Discord);
33
+
}
34
+
if self.telegram {
35
+
channels.push(CommsChannel::Telegram);
36
+
}
37
+
if self.signal {
38
+
channels.push(CommsChannel::Signal);
39
+
}
40
+
channels
41
+
}
42
+
43
+
pub fn is_verified(&self, channel: CommsChannel) -> bool {
44
+
match channel {
45
+
CommsChannel::Email => self.email,
46
+
CommsChannel::Discord => self.discord,
47
+
CommsChannel::Telegram => self.telegram,
48
+
CommsChannel::Signal => self.signal,
49
+
}
50
+
}
51
+
}
+6
-5
crates/tranquil-db-traits/src/delegation.rs
+6
-5
crates/tranquil-db-traits/src/delegation.rs
···
5
5
use uuid::Uuid;
6
6
7
7
use crate::DbError;
8
+
use crate::scope::DbScope;
8
9
9
10
#[derive(Debug, Clone, Serialize, Deserialize)]
10
11
pub struct DelegationGrant {
11
12
pub id: Uuid,
12
13
pub delegated_did: Did,
13
14
pub controller_did: Did,
14
-
pub granted_scopes: String,
15
+
pub granted_scopes: DbScope,
15
16
pub granted_at: DateTime<Utc>,
16
17
pub granted_by: Did,
17
18
pub revoked_at: Option<DateTime<Utc>>,
···
22
23
pub struct DelegatedAccountInfo {
23
24
pub did: Did,
24
25
pub handle: Handle,
25
-
pub granted_scopes: String,
26
+
pub granted_scopes: DbScope,
26
27
pub granted_at: DateTime<Utc>,
27
28
}
28
29
···
30
31
pub struct ControllerInfo {
31
32
pub did: Did,
32
33
pub handle: Handle,
33
-
pub granted_scopes: String,
34
+
pub granted_scopes: DbScope,
34
35
pub granted_at: DateTime<Utc>,
35
36
pub is_active: bool,
36
37
}
···
67
68
&self,
68
69
delegated_did: &Did,
69
70
controller_did: &Did,
70
-
granted_scopes: &str,
71
+
granted_scopes: &DbScope,
71
72
granted_by: &Did,
72
73
) -> Result<Uuid, DbError>;
73
74
···
82
83
&self,
83
84
delegated_did: &Did,
84
85
controller_did: &Did,
85
-
new_scopes: &str,
86
+
new_scopes: &DbScope,
86
87
) -> Result<bool, DbError>;
87
88
88
89
async fn get_delegation(
+63
-7
crates/tranquil-db-traits/src/infra.rs
+63
-7
crates/tranquil-db-traits/src/infra.rs
···
5
5
use uuid::Uuid;
6
6
7
7
use crate::DbError;
8
+
use crate::invite_code::{InviteCodeError, ValidatedInviteCode};
8
9
9
10
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
10
11
pub enum InviteCodeSortOrder {
···
13
14
Usage,
14
15
}
15
16
17
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
18
+
pub enum InviteCodeState {
19
+
#[default]
20
+
Active,
21
+
Disabled,
22
+
}
23
+
24
+
impl InviteCodeState {
25
+
pub fn is_active(self) -> bool {
26
+
matches!(self, Self::Active)
27
+
}
28
+
29
+
pub fn is_disabled(self) -> bool {
30
+
matches!(self, Self::Disabled)
31
+
}
32
+
}
33
+
34
+
impl From<bool> for InviteCodeState {
35
+
fn from(disabled: bool) -> Self {
36
+
if disabled {
37
+
Self::Disabled
38
+
} else {
39
+
Self::Active
40
+
}
41
+
}
42
+
}
43
+
44
+
impl From<Option<bool>> for InviteCodeState {
45
+
fn from(disabled: Option<bool>) -> Self {
46
+
Self::from(disabled.unwrap_or(false))
47
+
}
48
+
}
49
+
50
+
impl From<InviteCodeState> for bool {
51
+
fn from(state: InviteCodeState) -> Self {
52
+
matches!(state, InviteCodeState::Disabled)
53
+
}
54
+
}
55
+
16
56
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
17
57
#[sqlx(type_name = "comms_channel", rename_all = "snake_case")]
18
58
pub enum CommsChannel {
···
72
112
pub struct InviteCodeInfo {
73
113
pub code: String,
74
114
pub available_uses: i32,
75
-
pub disabled: bool,
115
+
pub state: InviteCodeState,
76
116
pub for_account: Option<Did>,
77
117
pub created_at: DateTime<Utc>,
78
118
pub created_by: Option<Did>,
···
95
135
pub created_at: DateTime<Utc>,
96
136
}
97
137
138
+
impl InviteCodeRow {
139
+
pub fn state(&self) -> InviteCodeState {
140
+
InviteCodeState::from(self.disabled)
141
+
}
142
+
}
143
+
98
144
#[derive(Debug, Clone)]
99
145
pub struct ReservedSigningKey {
100
146
pub id: Uuid,
···
148
194
149
195
async fn get_invite_code_available_uses(&self, code: &str) -> Result<Option<i32>, DbError>;
150
196
151
-
async fn is_invite_code_valid(&self, code: &str) -> Result<bool, DbError>;
197
+
async fn validate_invite_code<'a>(
198
+
&self,
199
+
code: &'a str,
200
+
) -> Result<ValidatedInviteCode<'a>, InviteCodeError>;
152
201
153
-
async fn decrement_invite_code_uses(&self, code: &str) -> Result<(), DbError>;
202
+
async fn decrement_invite_code_uses(
203
+
&self,
204
+
code: &ValidatedInviteCode<'_>,
205
+
) -> Result<(), DbError>;
154
206
155
-
async fn record_invite_code_use(&self, code: &str, used_by_user: Uuid) -> Result<(), DbError>;
207
+
async fn record_invite_code_use(
208
+
&self,
209
+
code: &ValidatedInviteCode<'_>,
210
+
used_by_user: Uuid,
211
+
) -> Result<(), DbError>;
156
212
157
213
async fn get_invite_codes_for_account(
158
214
&self,
···
317
373
#[derive(Debug, Clone)]
318
374
pub struct NotificationHistoryRow {
319
375
pub created_at: DateTime<Utc>,
320
-
pub channel: String,
321
-
pub comms_type: String,
322
-
pub status: String,
376
+
pub channel: CommsChannel,
377
+
pub comms_type: CommsType,
378
+
pub status: CommsStatus,
323
379
pub subject: Option<String>,
324
380
pub body: String,
325
381
}
+56
crates/tranquil-db-traits/src/invite_code.rs
+56
crates/tranquil-db-traits/src/invite_code.rs
···
1
+
use std::marker::PhantomData;
2
+
3
+
use crate::DbError;
4
+
5
+
#[derive(Debug)]
6
+
pub struct ValidatedInviteCode<'a> {
7
+
code: &'a str,
8
+
_marker: PhantomData<&'a ()>,
9
+
}
10
+
11
+
impl<'a> ValidatedInviteCode<'a> {
12
+
pub fn new_validated(code: &'a str) -> Self {
13
+
Self {
14
+
code,
15
+
_marker: PhantomData,
16
+
}
17
+
}
18
+
19
+
pub fn code(&self) -> &str {
20
+
self.code
21
+
}
22
+
}
23
+
24
+
#[derive(Debug)]
25
+
pub enum InviteCodeError {
26
+
NotFound,
27
+
ExhaustedUses,
28
+
Disabled,
29
+
DatabaseError(DbError),
30
+
}
31
+
32
+
impl std::fmt::Display for InviteCodeError {
33
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34
+
match self {
35
+
Self::NotFound => write!(f, "Invite code not found"),
36
+
Self::ExhaustedUses => write!(f, "Invite code has no remaining uses"),
37
+
Self::Disabled => write!(f, "Invite code is disabled"),
38
+
Self::DatabaseError(e) => write!(f, "Database error: {}", e),
39
+
}
40
+
}
41
+
}
42
+
43
+
impl std::error::Error for InviteCodeError {
44
+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
45
+
match self {
46
+
Self::DatabaseError(e) => Some(e),
47
+
_ => None,
48
+
}
49
+
}
50
+
}
51
+
52
+
impl From<DbError> for InviteCodeError {
53
+
fn from(e: DbError) -> Self {
54
+
Self::DatabaseError(e)
55
+
}
56
+
}
+30
-19
crates/tranquil-db-traits/src/lib.rs
+30
-19
crates/tranquil-db-traits/src/lib.rs
···
1
1
mod backlink;
2
2
mod backup;
3
3
mod blob;
4
+
mod channel_verification;
4
5
mod delegation;
5
6
mod error;
6
7
mod infra;
8
+
mod invite_code;
7
9
mod oauth;
8
10
mod repo;
11
+
mod scope;
12
+
mod sequence;
9
13
mod session;
10
14
mod sso;
11
15
mod user;
···
16
20
OldBackupInfo, UserBackupInfo,
17
21
};
18
22
pub use blob::{BlobForExport, BlobMetadata, BlobRepository, BlobWithTakedown, MissingBlobInfo};
23
+
pub use channel_verification::ChannelVerificationStatus;
19
24
pub use delegation::{
20
25
AuditLogEntry, ControllerInfo, DelegatedAccountInfo, DelegationActionType, DelegationGrant,
21
26
DelegationRepository,
···
23
28
pub use error::DbError;
24
29
pub use infra::{
25
30
AdminAccountInfo, CommsChannel, CommsStatus, CommsType, DeletionRequest, InfraRepository,
26
-
InviteCodeInfo, InviteCodeRow, InviteCodeSortOrder, InviteCodeUse, NotificationHistoryRow,
27
-
QueuedComms, ReservedSigningKey,
31
+
InviteCodeInfo, InviteCodeRow, InviteCodeSortOrder, InviteCodeState, InviteCodeUse,
32
+
NotificationHistoryRow, QueuedComms, ReservedSigningKey,
28
33
};
34
+
pub use invite_code::{InviteCodeError, ValidatedInviteCode};
29
35
pub use oauth::{
30
36
DeviceAccountRow, DeviceTrustInfo, OAuthRepository, OAuthSessionListItem, RefreshTokenLookup,
31
-
ScopePreference, TrustedDeviceRow, TwoFactorChallenge,
37
+
ScopePreference, TokenFamilyId, TrustedDeviceRow, TwoFactorChallenge,
32
38
};
33
39
pub use repo::{
34
-
ApplyCommitError, ApplyCommitInput, ApplyCommitResult, BrokenGenesisCommit, CommitEventData,
35
-
EventBlocksCids, FullRecordInfo, ImportBlock, ImportRecord, ImportRepoError, RecordDelete,
36
-
RecordInfo, RecordUpsert, RecordWithTakedown, RepoAccountInfo, RepoEventNotifier,
37
-
RepoEventReceiver, RepoInfo, RepoListItem, RepoRepository, RepoSeqEvent, RepoWithoutRev,
38
-
SequencedEvent, UserNeedingRecordBlobsBackfill, UserWithoutBlocks,
40
+
AccountStatus, ApplyCommitError, ApplyCommitInput, ApplyCommitResult, BrokenGenesisCommit,
41
+
CommitEventData, EventBlocksCids, FullRecordInfo, ImportBlock, ImportRecord, ImportRepoError,
42
+
RecordDelete, RecordInfo, RecordUpsert, RecordWithTakedown, RepoAccountInfo, RepoEventNotifier,
43
+
RepoEventReceiver, RepoEventType, RepoInfo, RepoListItem, RepoRepository, RepoSeqEvent,
44
+
RepoWithoutRev, SequencedEvent, UserNeedingRecordBlobsBackfill, UserWithoutBlocks,
39
45
};
46
+
pub use scope::{DbScope, InvalidScopeError};
47
+
pub use sequence::{SequenceNumber, deserialize_optional_sequence};
40
48
pub use session::{
41
-
AppPasswordCreate, AppPasswordRecord, RefreshSessionResult, SessionForRefresh, SessionListItem,
42
-
SessionMfaStatus, SessionRefreshData, SessionRepository, SessionToken, SessionTokenCreate,
49
+
AppPasswordCreate, AppPasswordPrivilege, AppPasswordRecord, LoginType, RefreshSessionResult,
50
+
SessionForRefresh, SessionId, SessionListItem, SessionMfaStatus, SessionRefreshData,
51
+
SessionRepository, SessionToken, SessionTokenCreate,
43
52
};
44
53
pub use sso::{
45
-
ExternalIdentity, SsoAuthState, SsoPendingRegistration, SsoProviderType, SsoRepository,
54
+
ExternalEmail, ExternalIdentity, ExternalUserId, ExternalUsername, SsoAction, SsoAuthState,
55
+
SsoPendingRegistration, SsoProviderType, SsoRepository,
46
56
};
47
57
pub use user::{
48
-
AccountSearchResult, CompletePasskeySetupInput, CreateAccountError,
58
+
AccountSearchResult, AccountType, CompletePasskeySetupInput, CreateAccountError,
49
59
CreateDelegatedAccountInput, CreatePasskeyAccountInput, CreatePasswordAccountInput,
50
60
CreatePasswordAccountResult, CreateSsoAccountInput, DidWebOverrides,
51
61
MigrationReactivationError, MigrationReactivationInput, NotificationPrefs, OAuthTokenWithUser,
52
62
PasswordResetResult, ReactivatedAccountInfo, RecoverPasskeyAccountInput,
53
63
RecoverPasskeyAccountResult, ScheduledDeletionAccount, StoredBackupCode, StoredPasskey,
54
-
TotpRecord, User2faStatus, UserAuthInfo, UserCommsPrefs, UserConfirmSignup, UserDidWebInfo,
55
-
UserEmailInfo, UserForDeletion, UserForDidDoc, UserForDidDocBuild, UserForPasskeyRecovery,
56
-
UserForPasskeySetup, UserForRecovery, UserForVerification, UserIdAndHandle,
57
-
UserIdAndPasswordHash, UserIdHandleEmail, UserInfoForAuth, UserKeyInfo, UserKeyWithId,
58
-
UserLegacyLoginPref, UserLoginCheck, UserLoginFull, UserLoginInfo, UserPasswordInfo,
59
-
UserRepository, UserResendVerification, UserResetCodeInfo, UserRow, UserSessionInfo,
60
-
UserStatus, UserVerificationInfo, UserWithKey,
64
+
TotpRecord, TotpRecordState, UnverifiedTotpRecord, User2faStatus, UserAuthInfo, UserCommsPrefs,
65
+
UserConfirmSignup, UserDidWebInfo, UserEmailInfo, UserForDeletion, UserForDidDoc,
66
+
UserForDidDocBuild, UserForPasskeyRecovery, UserForPasskeySetup, UserForRecovery,
67
+
UserForVerification, UserIdAndHandle, UserIdAndPasswordHash, UserIdHandleEmail,
68
+
UserInfoForAuth, UserKeyInfo, UserKeyWithId, UserLegacyLoginPref, UserLoginCheck,
69
+
UserLoginFull, UserLoginInfo, UserPasswordInfo, UserRepository, UserResendVerification,
70
+
UserResetCodeInfo, UserRow, UserSessionInfo, UserStatus, UserVerificationInfo, UserWithKey,
71
+
VerifiedTotpRecord,
61
72
};
+47
-12
crates/tranquil-db-traits/src/oauth.rs
+47
-12
crates/tranquil-db-traits/src/oauth.rs
···
10
10
11
11
use crate::DbError;
12
12
13
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
14
+
pub struct TokenFamilyId(i32);
15
+
16
+
impl TokenFamilyId {
17
+
pub fn new(id: i32) -> Self {
18
+
Self(id)
19
+
}
20
+
21
+
pub fn as_i32(self) -> i32 {
22
+
self.0
23
+
}
24
+
}
25
+
26
+
impl From<i32> for TokenFamilyId {
27
+
fn from(id: i32) -> Self {
28
+
Self(id)
29
+
}
30
+
}
31
+
32
+
impl From<TokenFamilyId> for i32 {
33
+
fn from(id: TokenFamilyId) -> Self {
34
+
id.0
35
+
}
36
+
}
37
+
38
+
impl std::fmt::Display for TokenFamilyId {
39
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40
+
write!(f, "{}", self.0)
41
+
}
42
+
}
43
+
13
44
#[derive(Debug, Clone, Serialize, Deserialize)]
14
45
pub struct ScopePreference {
15
46
pub scope: String,
···
53
84
54
85
#[derive(Debug, Clone)]
55
86
pub struct OAuthSessionListItem {
56
-
pub id: i32,
87
+
pub id: TokenFamilyId,
57
88
pub token_id: TokenId,
58
89
pub created_at: DateTime<Utc>,
59
90
pub expires_at: DateTime<Utc>,
···
62
93
63
94
pub enum RefreshTokenLookup {
64
95
Valid {
65
-
db_id: i32,
96
+
db_id: TokenFamilyId,
66
97
token_data: TokenData,
67
98
},
68
99
InGracePeriod {
69
-
db_id: i32,
100
+
db_id: TokenFamilyId,
70
101
token_data: TokenData,
71
102
rotated_at: DateTime<Utc>,
72
103
},
73
104
Used {
74
-
original_token_id: i32,
105
+
original_token_id: TokenFamilyId,
75
106
},
76
107
Expired {
77
-
db_id: i32,
108
+
db_id: TokenFamilyId,
78
109
},
79
110
NotFound,
80
111
}
···
93
124
94
125
#[async_trait]
95
126
pub trait OAuthRepository: Send + Sync {
96
-
async fn create_token(&self, data: &TokenData) -> Result<i32, DbError>;
127
+
async fn create_token(&self, data: &TokenData) -> Result<TokenFamilyId, DbError>;
97
128
async fn get_token_by_id(&self, token_id: &TokenId) -> Result<Option<TokenData>, DbError>;
98
129
async fn get_token_by_refresh_token(
99
130
&self,
100
131
refresh_token: &RefreshToken,
101
-
) -> Result<Option<(i32, TokenData)>, DbError>;
132
+
) -> Result<Option<(TokenFamilyId, TokenData)>, DbError>;
102
133
async fn get_token_by_previous_refresh_token(
103
134
&self,
104
135
refresh_token: &RefreshToken,
105
-
) -> Result<Option<(i32, TokenData)>, DbError>;
136
+
) -> Result<Option<(TokenFamilyId, TokenData)>, DbError>;
106
137
async fn rotate_token(
107
138
&self,
108
-
old_db_id: i32,
139
+
old_db_id: TokenFamilyId,
109
140
new_refresh_token: &RefreshToken,
110
141
new_expires_at: DateTime<Utc>,
111
142
) -> Result<(), DbError>;
112
143
async fn check_refresh_token_used(
113
144
&self,
114
145
refresh_token: &RefreshToken,
115
-
) -> Result<Option<i32>, DbError>;
146
+
) -> Result<Option<TokenFamilyId>, DbError>;
116
147
async fn delete_token(&self, token_id: &TokenId) -> Result<(), DbError>;
117
-
async fn delete_token_family(&self, db_id: i32) -> Result<(), DbError>;
148
+
async fn delete_token_family(&self, db_id: TokenFamilyId) -> Result<(), DbError>;
118
149
async fn list_tokens_for_user(&self, did: &Did) -> Result<Vec<TokenData>, DbError>;
119
150
async fn count_tokens_for_user(&self, did: &Did) -> Result<i64, DbError>;
120
151
async fn delete_oldest_tokens_for_user(
···
274
305
) -> Result<(), DbError>;
275
306
276
307
async fn list_sessions_by_did(&self, did: &Did) -> Result<Vec<OAuthSessionListItem>, DbError>;
277
-
async fn delete_session_by_id(&self, session_id: i32, did: &Did) -> Result<u64, DbError>;
308
+
async fn delete_session_by_id(
309
+
&self,
310
+
session_id: TokenFamilyId,
311
+
did: &Did,
312
+
) -> Result<u64, DbError>;
278
313
async fn delete_sessions_by_did(&self, did: &Did) -> Result<u64, DbError>;
279
314
async fn delete_sessions_by_did_except(
280
315
&self,
+146
-24
crates/tranquil-db-traits/src/repo.rs
+146
-24
crates/tranquil-db-traits/src/repo.rs
···
5
5
use uuid::Uuid;
6
6
7
7
use crate::DbError;
8
+
use crate::sequence::SequenceNumber;
9
+
10
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
11
+
#[sqlx(type_name = "text", rename_all = "snake_case")]
12
+
#[serde(rename_all = "snake_case")]
13
+
pub enum RepoEventType {
14
+
Commit,
15
+
Identity,
16
+
Account,
17
+
Sync,
18
+
}
19
+
20
+
impl RepoEventType {
21
+
pub fn as_str(&self) -> &'static str {
22
+
match self {
23
+
Self::Commit => "commit",
24
+
Self::Identity => "identity",
25
+
Self::Account => "account",
26
+
Self::Sync => "sync",
27
+
}
28
+
}
29
+
}
30
+
31
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
32
+
#[sqlx(type_name = "text", rename_all = "lowercase")]
33
+
#[serde(rename_all = "lowercase")]
34
+
pub enum AccountStatus {
35
+
Active,
36
+
Takendown,
37
+
Suspended,
38
+
Deactivated,
39
+
Deleted,
40
+
}
41
+
42
+
impl AccountStatus {
43
+
pub fn as_str(&self) -> &'static str {
44
+
match self {
45
+
Self::Active => "active",
46
+
Self::Takendown => "takendown",
47
+
Self::Suspended => "suspended",
48
+
Self::Deactivated => "deactivated",
49
+
Self::Deleted => "deleted",
50
+
}
51
+
}
52
+
53
+
pub fn for_firehose(&self) -> Option<&'static str> {
54
+
match self {
55
+
Self::Active => None,
56
+
other => Some(other.as_str()),
57
+
}
58
+
}
59
+
60
+
pub fn parse(s: &str) -> Option<Self> {
61
+
match s.to_lowercase().as_str() {
62
+
"active" => Some(Self::Active),
63
+
"takendown" => Some(Self::Takendown),
64
+
"suspended" => Some(Self::Suspended),
65
+
"deactivated" => Some(Self::Deactivated),
66
+
"deleted" => Some(Self::Deleted),
67
+
_ => None,
68
+
}
69
+
}
70
+
71
+
pub fn is_active(&self) -> bool {
72
+
matches!(self, Self::Active)
73
+
}
74
+
75
+
pub fn is_takendown(&self) -> bool {
76
+
matches!(self, Self::Takendown)
77
+
}
78
+
79
+
pub fn is_deactivated(&self) -> bool {
80
+
matches!(self, Self::Deactivated)
81
+
}
82
+
83
+
pub fn is_suspended(&self) -> bool {
84
+
matches!(self, Self::Suspended)
85
+
}
86
+
87
+
pub fn is_deleted(&self) -> bool {
88
+
matches!(self, Self::Deleted)
89
+
}
90
+
91
+
pub fn allows_read(&self) -> bool {
92
+
matches!(self, Self::Active | Self::Deactivated)
93
+
}
94
+
95
+
pub fn allows_write(&self) -> bool {
96
+
matches!(self, Self::Active)
97
+
}
98
+
99
+
pub fn from_db_fields(
100
+
takedown_ref: Option<&str>,
101
+
deactivated_at: Option<DateTime<Utc>>,
102
+
) -> Self {
103
+
if takedown_ref.is_some() {
104
+
Self::Takendown
105
+
} else if deactivated_at.is_some() {
106
+
Self::Deactivated
107
+
} else {
108
+
Self::Active
109
+
}
110
+
}
111
+
}
112
+
113
+
impl std::fmt::Display for AccountStatus {
114
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115
+
f.write_str(self.as_str())
116
+
}
117
+
}
8
118
9
119
#[derive(Debug, Clone, Serialize, Deserialize)]
10
120
pub struct RepoAccountInfo {
···
49
159
50
160
#[derive(Debug, Clone)]
51
161
pub struct BrokenGenesisCommit {
52
-
pub seq: i64,
162
+
pub seq: SequenceNumber,
53
163
pub did: Did,
54
164
pub commit_cid: Option<CidLink>,
55
165
}
···
69
179
70
180
#[derive(Debug, Clone, Serialize, Deserialize)]
71
181
pub struct RepoSeqEvent {
72
-
pub seq: i64,
182
+
pub seq: SequenceNumber,
73
183
}
74
184
75
185
#[derive(Debug, Clone, Serialize, Deserialize)]
76
186
pub struct SequencedEvent {
77
-
pub seq: i64,
187
+
pub seq: SequenceNumber,
78
188
pub did: Did,
79
189
pub created_at: DateTime<Utc>,
80
-
pub event_type: String,
190
+
pub event_type: RepoEventType,
81
191
pub commit_cid: Option<CidLink>,
82
192
pub prev_cid: Option<CidLink>,
83
193
pub prev_data_cid: Option<CidLink>,
···
86
196
pub blocks_cids: Option<Vec<String>>,
87
197
pub handle: Option<Handle>,
88
198
pub active: Option<bool>,
89
-
pub status: Option<String>,
199
+
pub status: Option<AccountStatus>,
90
200
pub rev: Option<String>,
91
201
}
92
202
93
203
#[derive(Debug, Clone)]
94
204
pub struct CommitEventData {
95
205
pub did: Did,
96
-
pub event_type: String,
206
+
pub event_type: RepoEventType,
97
207
pub commit_cid: Option<CidLink>,
98
208
pub prev_cid: Option<CidLink>,
99
209
pub ops: Option<serde_json::Value>,
···
283
393
284
394
async fn count_user_blocks(&self, user_id: Uuid) -> Result<i64, DbError>;
285
395
286
-
async fn insert_commit_event(&self, data: &CommitEventData) -> Result<i64, DbError>;
396
+
async fn insert_commit_event(&self, data: &CommitEventData) -> Result<SequenceNumber, DbError>;
287
397
288
398
async fn insert_identity_event(
289
399
&self,
290
400
did: &Did,
291
401
handle: Option<&Handle>,
292
-
) -> Result<i64, DbError>;
402
+
) -> Result<SequenceNumber, DbError>;
293
403
294
404
async fn insert_account_event(
295
405
&self,
296
406
did: &Did,
297
-
active: bool,
298
-
status: Option<&str>,
299
-
) -> Result<i64, DbError>;
407
+
status: AccountStatus,
408
+
) -> Result<SequenceNumber, DbError>;
300
409
301
410
async fn insert_sync_event(
302
411
&self,
303
412
did: &Did,
304
413
commit_cid: &CidLink,
305
414
rev: Option<&str>,
306
-
) -> Result<i64, DbError>;
415
+
) -> Result<SequenceNumber, DbError>;
307
416
308
417
async fn insert_genesis_commit_event(
309
418
&self,
···
311
420
commit_cid: &CidLink,
312
421
mst_root_cid: &CidLink,
313
422
rev: &str,
314
-
) -> Result<i64, DbError>;
423
+
) -> Result<SequenceNumber, DbError>;
315
424
316
-
async fn update_seq_blocks_cids(&self, seq: i64, blocks_cids: &[String])
317
-
-> Result<(), DbError>;
425
+
async fn update_seq_blocks_cids(
426
+
&self,
427
+
seq: SequenceNumber,
428
+
blocks_cids: &[String],
429
+
) -> Result<(), DbError>;
318
430
319
-
async fn delete_sequences_except(&self, did: &Did, keep_seq: i64) -> Result<(), DbError>;
431
+
async fn delete_sequences_except(
432
+
&self,
433
+
did: &Did,
434
+
keep_seq: SequenceNumber,
435
+
) -> Result<(), DbError>;
320
436
321
-
async fn get_max_seq(&self) -> Result<i64, DbError>;
437
+
async fn get_max_seq(&self) -> Result<SequenceNumber, DbError>;
322
438
323
-
async fn get_min_seq_since(&self, since: DateTime<Utc>) -> Result<Option<i64>, DbError>;
439
+
async fn get_min_seq_since(
440
+
&self,
441
+
since: DateTime<Utc>,
442
+
) -> Result<Option<SequenceNumber>, DbError>;
324
443
325
444
async fn get_account_with_repo(&self, did: &Did) -> Result<Option<RepoAccountInfo>, DbError>;
326
445
327
446
async fn get_events_since_seq(
328
447
&self,
329
-
since_seq: i64,
448
+
since_seq: SequenceNumber,
330
449
limit: Option<i64>,
331
450
) -> Result<Vec<SequencedEvent>, DbError>;
332
451
333
452
async fn get_events_in_seq_range(
334
453
&self,
335
-
start_seq: i64,
336
-
end_seq: i64,
454
+
start_seq: SequenceNumber,
455
+
end_seq: SequenceNumber,
337
456
) -> Result<Vec<SequencedEvent>, DbError>;
338
457
339
-
async fn get_event_by_seq(&self, seq: i64) -> Result<Option<SequencedEvent>, DbError>;
458
+
async fn get_event_by_seq(
459
+
&self,
460
+
seq: SequenceNumber,
461
+
) -> Result<Option<SequencedEvent>, DbError>;
340
462
341
463
async fn get_events_since_cursor(
342
464
&self,
343
-
cursor: i64,
465
+
cursor: SequenceNumber,
344
466
limit: i64,
345
467
) -> Result<Vec<SequencedEvent>, DbError>;
346
468
···
359
481
async fn get_repo_root_cid_by_user_id(&self, user_id: Uuid)
360
482
-> Result<Option<CidLink>, DbError>;
361
483
362
-
async fn notify_update(&self, seq: i64) -> Result<(), DbError>;
484
+
async fn notify_update(&self, seq: SequenceNumber) -> Result<(), DbError>;
363
485
364
486
async fn import_repo_data(
365
487
&self,
+180
crates/tranquil-db-traits/src/scope.rs
+180
crates/tranquil-db-traits/src/scope.rs
···
1
+
use serde::{Deserialize, Deserializer, Serialize};
2
+
use std::fmt;
3
+
4
+
#[derive(Debug, Clone, PartialEq, Eq)]
5
+
pub struct DbScope(String);
6
+
7
+
impl DbScope {
8
+
pub fn new(scope: impl Into<String>) -> Result<Self, InvalidScopeError> {
9
+
let scope = scope.into();
10
+
validate_scope_string(&scope)?;
11
+
Ok(Self(scope))
12
+
}
13
+
14
+
pub fn empty() -> Self {
15
+
Self(String::new())
16
+
}
17
+
18
+
pub fn from_db(scope: String) -> Self {
19
+
match validate_scope_string(&scope) {
20
+
Ok(()) => Self(scope),
21
+
Err(e) => panic!("corrupted scope data from database: {}", e),
22
+
}
23
+
}
24
+
25
+
pub fn as_str(&self) -> &str {
26
+
&self.0
27
+
}
28
+
29
+
pub fn into_string(self) -> String {
30
+
self.0
31
+
}
32
+
33
+
pub fn is_empty(&self) -> bool {
34
+
self.0.is_empty()
35
+
}
36
+
}
37
+
38
+
impl Default for DbScope {
39
+
fn default() -> Self {
40
+
Self::empty()
41
+
}
42
+
}
43
+
44
+
impl fmt::Display for DbScope {
45
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46
+
write!(f, "{}", self.0)
47
+
}
48
+
}
49
+
50
+
impl AsRef<str> for DbScope {
51
+
fn as_ref(&self) -> &str {
52
+
&self.0
53
+
}
54
+
}
55
+
56
+
impl Serialize for DbScope {
57
+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
58
+
where
59
+
S: serde::Serializer,
60
+
{
61
+
self.0.serialize(serializer)
62
+
}
63
+
}
64
+
65
+
impl<'de> Deserialize<'de> for DbScope {
66
+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
67
+
where
68
+
D: Deserializer<'de>,
69
+
{
70
+
let s = String::deserialize(deserializer)?;
71
+
Self::new(s).map_err(serde::de::Error::custom)
72
+
}
73
+
}
74
+
75
+
#[derive(Debug, Clone)]
76
+
pub struct InvalidScopeError {
77
+
message: String,
78
+
}
79
+
80
+
impl InvalidScopeError {
81
+
pub fn new(message: impl Into<String>) -> Self {
82
+
Self {
83
+
message: message.into(),
84
+
}
85
+
}
86
+
87
+
pub fn message(&self) -> &str {
88
+
&self.message
89
+
}
90
+
}
91
+
92
+
impl fmt::Display for InvalidScopeError {
93
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94
+
write!(f, "{}", self.message)
95
+
}
96
+
}
97
+
98
+
impl std::error::Error for InvalidScopeError {}
99
+
100
+
fn validate_scope_string(scopes: &str) -> Result<(), InvalidScopeError> {
101
+
if scopes.is_empty() {
102
+
return Ok(());
103
+
}
104
+
105
+
scopes.split_whitespace().try_for_each(|scope| {
106
+
let base = scope.split_once('?').map_or(scope, |(b, _)| b);
107
+
if is_valid_scope_prefix(base) {
108
+
Ok(())
109
+
} else {
110
+
Err(InvalidScopeError::new(format!("Invalid scope: {}", scope)))
111
+
}
112
+
})
113
+
}
114
+
115
+
fn is_valid_scope_prefix(base: &str) -> bool {
116
+
const VALID_PREFIXES: [&str; 8] = [
117
+
"atproto",
118
+
"repo:",
119
+
"blob:",
120
+
"rpc:",
121
+
"account:",
122
+
"identity:",
123
+
"transition:",
124
+
"include:",
125
+
];
126
+
127
+
VALID_PREFIXES
128
+
.iter()
129
+
.any(|prefix| base == prefix.trim_end_matches(':') || base.starts_with(prefix))
130
+
}
131
+
132
+
#[cfg(test)]
133
+
mod tests {
134
+
use super::*;
135
+
136
+
#[test]
137
+
fn test_valid_scopes() {
138
+
assert!(DbScope::new("atproto").is_ok());
139
+
assert!(DbScope::new("repo:*").is_ok());
140
+
assert!(DbScope::new("blob:*/*").is_ok());
141
+
assert!(DbScope::new("repo:* blob:*/*").is_ok());
142
+
assert!(DbScope::new("").is_ok());
143
+
assert!(DbScope::new("account:email?action=read").is_ok());
144
+
assert!(DbScope::new("identity:handle").is_ok());
145
+
assert!(DbScope::new("transition:generic").is_ok());
146
+
assert!(DbScope::new("include:app.bsky.authFullApp").is_ok());
147
+
}
148
+
149
+
#[test]
150
+
fn test_invalid_scopes() {
151
+
assert!(DbScope::new("invalid:scope").is_err());
152
+
assert!(DbScope::new("garbage").is_err());
153
+
assert!(DbScope::new("repo:* invalid:scope").is_err());
154
+
}
155
+
156
+
#[test]
157
+
fn test_empty_scope() {
158
+
let scope = DbScope::empty();
159
+
assert!(scope.is_empty());
160
+
assert_eq!(scope.as_str(), "");
161
+
}
162
+
163
+
#[test]
164
+
fn test_display() {
165
+
let scope = DbScope::new("repo:*").unwrap();
166
+
assert_eq!(format!("{}", scope), "repo:*");
167
+
}
168
+
169
+
#[test]
170
+
#[should_panic(expected = "corrupted scope data from database")]
171
+
fn test_from_db_panics_on_corrupted_data() {
172
+
DbScope::from_db("totally_invalid_garbage_scope".to_string());
173
+
}
174
+
175
+
#[test]
176
+
fn test_from_db_accepts_valid_data() {
177
+
let scope = DbScope::from_db("repo:* blob:*/*".to_string());
178
+
assert_eq!(scope.as_str(), "repo:* blob:*/*");
179
+
}
180
+
}
+73
crates/tranquil-db-traits/src/sequence.rs
+73
crates/tranquil-db-traits/src/sequence.rs
···
1
+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
2
+
use std::fmt;
3
+
4
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
5
+
pub struct SequenceNumber(i64);
6
+
7
+
impl SequenceNumber {
8
+
pub const ZERO: Self = Self(0);
9
+
pub const UNSET: Self = Self(-1);
10
+
11
+
pub fn new(n: i64) -> Option<Self> {
12
+
if n >= 0 { Some(Self(n)) } else { None }
13
+
}
14
+
15
+
pub fn from_raw(n: i64) -> Self {
16
+
Self(n)
17
+
}
18
+
19
+
pub fn as_i64(&self) -> i64 {
20
+
self.0
21
+
}
22
+
23
+
pub fn is_valid(&self) -> bool {
24
+
self.0 >= 0
25
+
}
26
+
}
27
+
28
+
impl fmt::Display for SequenceNumber {
29
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30
+
write!(f, "{}", self.0)
31
+
}
32
+
}
33
+
34
+
impl From<i64> for SequenceNumber {
35
+
fn from(n: i64) -> Self {
36
+
Self(n)
37
+
}
38
+
}
39
+
40
+
impl From<SequenceNumber> for i64 {
41
+
fn from(seq: SequenceNumber) -> Self {
42
+
seq.0
43
+
}
44
+
}
45
+
46
+
impl Serialize for SequenceNumber {
47
+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
48
+
where
49
+
S: Serializer,
50
+
{
51
+
self.0.serialize(serializer)
52
+
}
53
+
}
54
+
55
+
impl<'de> Deserialize<'de> for SequenceNumber {
56
+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
57
+
where
58
+
D: Deserializer<'de>,
59
+
{
60
+
let n = i64::deserialize(deserializer)?;
61
+
Ok(Self(n))
62
+
}
63
+
}
64
+
65
+
pub fn deserialize_optional_sequence<'de, D>(
66
+
deserializer: D,
67
+
) -> Result<Option<SequenceNumber>, D::Error>
68
+
where
69
+
D: Deserializer<'de>,
70
+
{
71
+
let opt: Option<i64> = Option::deserialize(deserializer)?;
72
+
Ok(opt.map(SequenceNumber::from_raw))
73
+
}
+107
-15
crates/tranquil-db-traits/src/session.rs
+107
-15
crates/tranquil-db-traits/src/session.rs
···
5
5
6
6
use crate::DbError;
7
7
8
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
9
+
pub enum LoginType {
10
+
#[default]
11
+
Modern,
12
+
Legacy,
13
+
}
14
+
15
+
impl LoginType {
16
+
pub fn is_legacy(self) -> bool {
17
+
matches!(self, Self::Legacy)
18
+
}
19
+
20
+
pub fn is_modern(self) -> bool {
21
+
matches!(self, Self::Modern)
22
+
}
23
+
}
24
+
25
+
impl From<bool> for LoginType {
26
+
fn from(legacy: bool) -> Self {
27
+
if legacy { Self::Legacy } else { Self::Modern }
28
+
}
29
+
}
30
+
31
+
impl From<LoginType> for bool {
32
+
fn from(lt: LoginType) -> Self {
33
+
matches!(lt, LoginType::Legacy)
34
+
}
35
+
}
36
+
37
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
38
+
pub enum AppPasswordPrivilege {
39
+
#[default]
40
+
Standard,
41
+
Privileged,
42
+
}
43
+
44
+
impl AppPasswordPrivilege {
45
+
pub fn is_privileged(self) -> bool {
46
+
matches!(self, Self::Privileged)
47
+
}
48
+
}
49
+
50
+
impl From<bool> for AppPasswordPrivilege {
51
+
fn from(privileged: bool) -> Self {
52
+
if privileged {
53
+
Self::Privileged
54
+
} else {
55
+
Self::Standard
56
+
}
57
+
}
58
+
}
59
+
60
+
impl From<AppPasswordPrivilege> for bool {
61
+
fn from(p: AppPasswordPrivilege) -> Self {
62
+
matches!(p, AppPasswordPrivilege::Privileged)
63
+
}
64
+
}
65
+
66
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
67
+
pub struct SessionId(i32);
68
+
69
+
impl SessionId {
70
+
pub fn new(id: i32) -> Self {
71
+
Self(id)
72
+
}
73
+
74
+
pub fn as_i32(self) -> i32 {
75
+
self.0
76
+
}
77
+
}
78
+
79
+
impl From<i32> for SessionId {
80
+
fn from(id: i32) -> Self {
81
+
Self(id)
82
+
}
83
+
}
84
+
85
+
impl From<SessionId> for i32 {
86
+
fn from(id: SessionId) -> Self {
87
+
id.0
88
+
}
89
+
}
90
+
91
+
impl std::fmt::Display for SessionId {
92
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93
+
write!(f, "{}", self.0)
94
+
}
95
+
}
96
+
8
97
#[derive(Debug, Clone)]
9
98
pub struct SessionToken {
10
-
pub id: i32,
99
+
pub id: SessionId,
11
100
pub did: Did,
12
101
pub access_jti: String,
13
102
pub refresh_jti: String,
14
103
pub access_expires_at: DateTime<Utc>,
15
104
pub refresh_expires_at: DateTime<Utc>,
16
-
pub legacy_login: bool,
105
+
pub login_type: LoginType,
17
106
pub mfa_verified: bool,
18
107
pub scope: Option<String>,
19
108
pub controller_did: Option<Did>,
···
29
118
pub refresh_jti: String,
30
119
pub access_expires_at: DateTime<Utc>,
31
120
pub refresh_expires_at: DateTime<Utc>,
32
-
pub legacy_login: bool,
121
+
pub login_type: LoginType,
33
122
pub mfa_verified: bool,
34
123
pub scope: Option<String>,
35
124
pub controller_did: Option<Did>,
···
38
127
39
128
#[derive(Debug, Clone)]
40
129
pub struct SessionForRefresh {
41
-
pub id: i32,
130
+
pub id: SessionId,
42
131
pub did: Did,
43
132
pub scope: Option<String>,
44
133
pub controller_did: Option<Did>,
···
48
137
49
138
#[derive(Debug, Clone)]
50
139
pub struct SessionListItem {
51
-
pub id: i32,
140
+
pub id: SessionId,
52
141
pub access_jti: String,
53
142
pub created_at: DateTime<Utc>,
54
143
pub refresh_expires_at: DateTime<Utc>,
···
61
150
pub name: String,
62
151
pub password_hash: String,
63
152
pub created_at: DateTime<Utc>,
64
-
pub privileged: bool,
153
+
pub privilege: AppPasswordPrivilege,
65
154
pub scopes: Option<String>,
66
155
pub created_by_controller_did: Option<Did>,
67
156
}
···
71
160
pub user_id: Uuid,
72
161
pub name: String,
73
162
pub password_hash: String,
74
-
pub privileged: bool,
163
+
pub privilege: AppPasswordPrivilege,
75
164
pub scopes: Option<String>,
76
165
pub created_by_controller_did: Option<Did>,
77
166
}
78
167
79
168
#[derive(Debug, Clone)]
80
169
pub struct SessionMfaStatus {
81
-
pub legacy_login: bool,
170
+
pub login_type: LoginType,
82
171
pub mfa_verified: bool,
83
172
pub last_reauth_at: Option<DateTime<Utc>>,
84
173
}
···
93
182
#[derive(Debug, Clone)]
94
183
pub struct SessionRefreshData {
95
184
pub old_refresh_jti: String,
96
-
pub session_id: i32,
185
+
pub session_id: SessionId,
97
186
pub new_access_jti: String,
98
187
pub new_refresh_jti: String,
99
188
pub new_access_expires_at: DateTime<Utc>,
···
102
191
103
192
#[async_trait]
104
193
pub trait SessionRepository: Send + Sync {
105
-
async fn create_session(&self, data: &SessionTokenCreate) -> Result<i32, DbError>;
194
+
async fn create_session(&self, data: &SessionTokenCreate) -> Result<SessionId, DbError>;
106
195
107
196
async fn get_session_by_access_jti(
108
197
&self,
···
116
205
117
206
async fn update_session_tokens(
118
207
&self,
119
-
session_id: i32,
208
+
session_id: SessionId,
120
209
new_access_jti: &str,
121
210
new_refresh_jti: &str,
122
211
new_access_expires_at: DateTime<Utc>,
···
125
214
126
215
async fn delete_session_by_access_jti(&self, access_jti: &str) -> Result<u64, DbError>;
127
216
128
-
async fn delete_session_by_id(&self, session_id: i32) -> Result<u64, DbError>;
217
+
async fn delete_session_by_id(&self, session_id: SessionId) -> Result<u64, DbError>;
129
218
130
219
async fn delete_sessions_by_did(&self, did: &Did) -> Result<u64, DbError>;
131
220
···
139
228
140
229
async fn get_session_access_jti_by_id(
141
230
&self,
142
-
session_id: i32,
231
+
session_id: SessionId,
143
232
did: &Did,
144
233
) -> Result<Option<String>, DbError>;
145
234
···
155
244
app_password_name: &str,
156
245
) -> Result<Vec<String>, DbError>;
157
246
158
-
async fn check_refresh_token_used(&self, refresh_jti: &str) -> Result<Option<i32>, DbError>;
247
+
async fn check_refresh_token_used(
248
+
&self,
249
+
refresh_jti: &str,
250
+
) -> Result<Option<SessionId>, DbError>;
159
251
160
252
async fn mark_refresh_token_used(
161
253
&self,
162
254
refresh_jti: &str,
163
-
session_id: i32,
255
+
session_id: SessionId,
164
256
) -> Result<bool, DbError>;
165
257
166
258
async fn list_app_passwords(&self, user_id: Uuid) -> Result<Vec<AppPasswordRecord>, DbError>;
+147
-8
crates/tranquil-db-traits/src/sso.rs
+147
-8
crates/tranquil-db-traits/src/sso.rs
···
6
6
7
7
use crate::DbError;
8
8
9
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
10
+
pub struct ExternalUserId(String);
11
+
12
+
impl ExternalUserId {
13
+
pub fn new(id: impl Into<String>) -> Self {
14
+
Self(id.into())
15
+
}
16
+
17
+
pub fn as_str(&self) -> &str {
18
+
&self.0
19
+
}
20
+
21
+
pub fn into_inner(self) -> String {
22
+
self.0
23
+
}
24
+
}
25
+
26
+
impl std::fmt::Display for ExternalUserId {
27
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28
+
write!(f, "{}", self.0)
29
+
}
30
+
}
31
+
32
+
impl From<String> for ExternalUserId {
33
+
fn from(s: String) -> Self {
34
+
Self(s)
35
+
}
36
+
}
37
+
38
+
impl From<ExternalUserId> for String {
39
+
fn from(id: ExternalUserId) -> Self {
40
+
id.0
41
+
}
42
+
}
43
+
44
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
45
+
pub struct ExternalUsername(String);
46
+
47
+
impl ExternalUsername {
48
+
pub fn new(username: impl Into<String>) -> Self {
49
+
Self(username.into())
50
+
}
51
+
52
+
pub fn as_str(&self) -> &str {
53
+
&self.0
54
+
}
55
+
56
+
pub fn into_inner(self) -> String {
57
+
self.0
58
+
}
59
+
}
60
+
61
+
impl std::fmt::Display for ExternalUsername {
62
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63
+
write!(f, "{}", self.0)
64
+
}
65
+
}
66
+
67
+
impl From<String> for ExternalUsername {
68
+
fn from(s: String) -> Self {
69
+
Self(s)
70
+
}
71
+
}
72
+
73
+
impl From<ExternalUsername> for String {
74
+
fn from(username: ExternalUsername) -> Self {
75
+
username.0
76
+
}
77
+
}
78
+
79
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
80
+
pub struct ExternalEmail(String);
81
+
82
+
impl ExternalEmail {
83
+
pub fn new(email: impl Into<String>) -> Self {
84
+
Self(email.into())
85
+
}
86
+
87
+
pub fn as_str(&self) -> &str {
88
+
&self.0
89
+
}
90
+
91
+
pub fn into_inner(self) -> String {
92
+
self.0
93
+
}
94
+
}
95
+
96
+
impl std::fmt::Display for ExternalEmail {
97
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98
+
write!(f, "{}", self.0)
99
+
}
100
+
}
101
+
102
+
impl From<String> for ExternalEmail {
103
+
fn from(s: String) -> Self {
104
+
Self(s)
105
+
}
106
+
}
107
+
108
+
impl From<ExternalEmail> for String {
109
+
fn from(email: ExternalEmail) -> Self {
110
+
email.0
111
+
}
112
+
}
113
+
9
114
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
10
115
#[sqlx(type_name = "sso_provider_type", rename_all = "lowercase")]
11
116
pub enum SsoProviderType {
···
17
122
Apple,
18
123
}
19
124
125
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
126
+
#[sqlx(type_name = "text", rename_all = "lowercase")]
127
+
#[serde(rename_all = "lowercase")]
128
+
pub enum SsoAction {
129
+
Login,
130
+
Link,
131
+
Register,
132
+
}
133
+
134
+
impl SsoAction {
135
+
pub fn as_str(&self) -> &'static str {
136
+
match self {
137
+
Self::Login => "login",
138
+
Self::Link => "link",
139
+
Self::Register => "register",
140
+
}
141
+
}
142
+
143
+
pub fn parse(s: &str) -> Option<Self> {
144
+
match s.to_lowercase().as_str() {
145
+
"login" => Some(Self::Login),
146
+
"link" => Some(Self::Link),
147
+
"register" => Some(Self::Register),
148
+
_ => None,
149
+
}
150
+
}
151
+
}
152
+
153
+
impl std::fmt::Display for SsoAction {
154
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155
+
f.write_str(self.as_str())
156
+
}
157
+
}
158
+
20
159
impl SsoProviderType {
21
160
pub fn as_str(&self) -> &'static str {
22
161
match self {
···
69
208
pub id: Uuid,
70
209
pub did: Did,
71
210
pub provider: SsoProviderType,
72
-
pub provider_user_id: String,
73
-
pub provider_username: Option<String>,
74
-
pub provider_email: Option<String>,
211
+
pub provider_user_id: ExternalUserId,
212
+
pub provider_username: Option<ExternalUsername>,
213
+
pub provider_email: Option<ExternalEmail>,
75
214
pub created_at: DateTime<Utc>,
76
215
pub updated_at: DateTime<Utc>,
77
216
pub last_login_at: Option<DateTime<Utc>>,
···
82
221
pub state: String,
83
222
pub request_uri: String,
84
223
pub provider: SsoProviderType,
85
-
pub action: String,
224
+
pub action: SsoAction,
86
225
pub nonce: Option<String>,
87
226
pub code_verifier: Option<String>,
88
227
pub did: Option<Did>,
···
95
234
pub token: String,
96
235
pub request_uri: String,
97
236
pub provider: SsoProviderType,
98
-
pub provider_user_id: String,
99
-
pub provider_username: Option<String>,
100
-
pub provider_email: Option<String>,
237
+
pub provider_user_id: ExternalUserId,
238
+
pub provider_username: Option<ExternalUsername>,
239
+
pub provider_email: Option<ExternalEmail>,
101
240
pub provider_email_verified: bool,
102
241
pub created_at: DateTime<Utc>,
103
242
pub expires_at: DateTime<Utc>,
···
140
279
state: &str,
141
280
request_uri: &str,
142
281
provider: SsoProviderType,
143
-
action: &str,
282
+
action: SsoAction,
144
283
nonce: Option<&str>,
145
284
code_verifier: Option<&str>,
146
285
did: Option<&Did>,
+100
-34
crates/tranquil-db-traits/src/user.rs
+100
-34
crates/tranquil-db-traits/src/user.rs
···
1
1
use async_trait::async_trait;
2
2
use chrono::{DateTime, Utc};
3
+
use serde::{Deserialize, Serialize};
3
4
use tranquil_types::{Did, Handle};
4
5
use uuid::Uuid;
5
6
6
-
use crate::{CommsChannel, DbError, SsoProviderType};
7
+
use crate::{ChannelVerificationStatus, CommsChannel, DbError, SsoProviderType};
8
+
9
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
10
+
#[sqlx(type_name = "account_type", rename_all = "snake_case")]
11
+
pub enum AccountType {
12
+
Personal,
13
+
Delegated,
14
+
}
15
+
16
+
impl AccountType {
17
+
pub fn is_delegated(&self) -> bool {
18
+
matches!(self, Self::Delegated)
19
+
}
20
+
}
7
21
8
22
#[derive(Debug, Clone)]
9
23
pub struct UserRow {
···
62
76
pub preferred_comms_channel: CommsChannel,
63
77
pub deactivated_at: Option<DateTime<Utc>>,
64
78
pub takedown_ref: Option<String>,
65
-
pub email_verified: bool,
66
-
pub discord_verified: bool,
67
-
pub telegram_verified: bool,
68
-
pub signal_verified: bool,
69
-
pub account_type: String,
79
+
pub channel_verification: ChannelVerificationStatus,
80
+
pub account_type: AccountType,
70
81
}
71
82
72
83
#[derive(Debug, Clone)]
···
74
85
pub id: Uuid,
75
86
pub two_factor_enabled: bool,
76
87
pub preferred_comms_channel: CommsChannel,
77
-
pub email_verified: bool,
78
-
pub discord_verified: bool,
79
-
pub telegram_verified: bool,
80
-
pub signal_verified: bool,
88
+
pub channel_verification: ChannelVerificationStatus,
81
89
}
82
90
83
91
#[async_trait]
···
202
210
did: &Did,
203
211
) -> Result<Option<UserIdHandleEmail>, DbError>;
204
212
205
-
async fn update_preferred_comms_channel(&self, did: &Did, channel: &str)
206
-
-> Result<(), DbError>;
213
+
async fn update_preferred_comms_channel(
214
+
&self,
215
+
did: &Did,
216
+
channel: CommsChannel,
217
+
) -> Result<(), DbError>;
207
218
208
219
async fn clear_discord(&self, user_id: Uuid) -> Result<(), DbError>;
209
220
···
292
303
293
304
async fn get_totp_record(&self, did: &Did) -> Result<Option<TotpRecord>, DbError>;
294
305
306
+
async fn get_totp_record_state(&self, did: &Did) -> Result<Option<TotpRecordState>, DbError>;
307
+
295
308
async fn upsert_totp_secret(
296
309
&self,
297
310
did: &Did,
···
560
573
pub struct UserCommsPrefs {
561
574
pub email: Option<String>,
562
575
pub handle: Handle,
563
-
pub preferred_channel: String,
576
+
pub preferred_channel: CommsChannel,
564
577
pub preferred_locale: Option<String>,
565
578
}
566
579
···
611
624
pub password_hash: Option<String>,
612
625
pub deactivated_at: Option<DateTime<Utc>>,
613
626
pub takedown_ref: Option<String>,
614
-
pub email_verified: bool,
615
-
pub discord_verified: bool,
616
-
pub telegram_verified: bool,
617
-
pub signal_verified: bool,
627
+
pub channel_verification: ChannelVerificationStatus,
618
628
}
619
629
620
630
#[derive(Debug, Clone)]
621
631
pub struct NotificationPrefs {
622
632
pub email: String,
623
-
pub preferred_channel: String,
633
+
pub preferred_channel: CommsChannel,
624
634
pub discord_id: Option<String>,
625
635
pub discord_verified: bool,
626
636
pub telegram_username: Option<String>,
···
641
651
pub id: Uuid,
642
652
pub handle: Handle,
643
653
pub email: Option<String>,
644
-
pub email_verified: bool,
645
-
pub discord_verified: bool,
646
-
pub telegram_verified: bool,
647
-
pub signal_verified: bool,
654
+
pub channel_verification: ChannelVerificationStatus,
648
655
}
649
656
650
657
#[derive(Debug, Clone)]
···
675
682
pub verified: bool,
676
683
}
677
684
685
+
#[derive(Debug, Clone)]
686
+
pub struct VerifiedTotpRecord {
687
+
pub secret_encrypted: Vec<u8>,
688
+
pub encryption_version: i32,
689
+
}
690
+
691
+
#[derive(Debug, Clone)]
692
+
pub struct UnverifiedTotpRecord {
693
+
pub secret_encrypted: Vec<u8>,
694
+
pub encryption_version: i32,
695
+
}
696
+
697
+
#[derive(Debug, Clone)]
698
+
pub enum TotpRecordState {
699
+
Verified(VerifiedTotpRecord),
700
+
Unverified(UnverifiedTotpRecord),
701
+
}
702
+
703
+
impl TotpRecordState {
704
+
pub fn is_verified(&self) -> bool {
705
+
matches!(self, Self::Verified(_))
706
+
}
707
+
708
+
pub fn as_verified(&self) -> Option<&VerifiedTotpRecord> {
709
+
match self {
710
+
Self::Verified(r) => Some(r),
711
+
Self::Unverified(_) => None,
712
+
}
713
+
}
714
+
715
+
pub fn as_unverified(&self) -> Option<&UnverifiedTotpRecord> {
716
+
match self {
717
+
Self::Unverified(r) => Some(r),
718
+
Self::Verified(_) => None,
719
+
}
720
+
}
721
+
722
+
pub fn into_verified(self) -> Option<VerifiedTotpRecord> {
723
+
match self {
724
+
Self::Verified(r) => Some(r),
725
+
Self::Unverified(_) => None,
726
+
}
727
+
}
728
+
729
+
pub fn into_unverified(self) -> Option<UnverifiedTotpRecord> {
730
+
match self {
731
+
Self::Unverified(r) => Some(r),
732
+
Self::Verified(_) => None,
733
+
}
734
+
}
735
+
}
736
+
737
+
impl From<TotpRecord> for TotpRecordState {
738
+
fn from(record: TotpRecord) -> Self {
739
+
if record.verified {
740
+
Self::Verified(VerifiedTotpRecord {
741
+
secret_encrypted: record.secret_encrypted,
742
+
encryption_version: record.encryption_version,
743
+
})
744
+
} else {
745
+
Self::Unverified(UnverifiedTotpRecord {
746
+
secret_encrypted: record.secret_encrypted,
747
+
encryption_version: record.encryption_version,
748
+
})
749
+
}
750
+
}
751
+
}
752
+
678
753
#[derive(Debug, Clone)]
679
754
pub struct StoredBackupCode {
680
755
pub id: Uuid,
···
685
760
pub struct UserSessionInfo {
686
761
pub handle: Handle,
687
762
pub email: Option<String>,
688
-
pub email_verified: bool,
689
763
pub is_admin: bool,
690
764
pub deactivated_at: Option<DateTime<Utc>>,
691
765
pub takedown_ref: Option<String>,
692
766
pub preferred_locale: Option<String>,
693
767
pub preferred_comms_channel: CommsChannel,
694
-
pub discord_verified: bool,
695
-
pub telegram_verified: bool,
696
-
pub signal_verified: bool,
768
+
pub channel_verification: ChannelVerificationStatus,
697
769
pub migrated_to_pds: Option<String>,
698
770
pub migrated_at: Option<DateTime<Utc>>,
699
771
}
···
713
785
pub email: Option<String>,
714
786
pub deactivated_at: Option<DateTime<Utc>>,
715
787
pub takedown_ref: Option<String>,
716
-
pub email_verified: bool,
717
-
pub discord_verified: bool,
718
-
pub telegram_verified: bool,
719
-
pub signal_verified: bool,
788
+
pub channel_verification: ChannelVerificationStatus,
720
789
pub allow_legacy_login: bool,
721
790
pub migrated_to_pds: Option<String>,
722
791
pub preferred_comms_channel: CommsChannel,
···
748
817
pub discord_id: Option<String>,
749
818
pub telegram_username: Option<String>,
750
819
pub signal_number: Option<String>,
751
-
pub email_verified: bool,
752
-
pub discord_verified: bool,
753
-
pub telegram_verified: bool,
754
-
pub signal_verified: bool,
820
+
pub channel_verification: ChannelVerificationStatus,
755
821
}
756
822
757
823
#[derive(Debug, Clone)]
+9
-9
crates/tranquil-db/src/postgres/delegation.rs
+9
-9
crates/tranquil-db/src/postgres/delegation.rs
···
1
1
use async_trait::async_trait;
2
2
use sqlx::PgPool;
3
3
use tranquil_db_traits::{
4
-
AuditLogEntry, ControllerInfo, DbError, DelegatedAccountInfo, DelegationActionType,
4
+
AuditLogEntry, ControllerInfo, DbError, DbScope, DelegatedAccountInfo, DelegationActionType,
5
5
DelegationGrant, DelegationRepository,
6
6
};
7
7
use tranquil_types::Did;
···
80
80
&self,
81
81
delegated_did: &Did,
82
82
controller_did: &Did,
83
-
granted_scopes: &str,
83
+
granted_scopes: &DbScope,
84
84
granted_by: &Did,
85
85
) -> Result<Uuid, DbError> {
86
86
let id = sqlx::query_scalar!(
···
91
91
"#,
92
92
delegated_did.as_str(),
93
93
controller_did.as_str(),
94
-
granted_scopes,
94
+
granted_scopes.as_str(),
95
95
granted_by.as_str()
96
96
)
97
97
.fetch_one(&self.pool)
···
128
128
&self,
129
129
delegated_did: &Did,
130
130
controller_did: &Did,
131
-
new_scopes: &str,
131
+
new_scopes: &DbScope,
132
132
) -> Result<bool, DbError> {
133
133
let result = sqlx::query!(
134
134
r#"
···
136
136
SET granted_scopes = $1
137
137
WHERE delegated_did = $2 AND controller_did = $3 AND revoked_at IS NULL
138
138
"#,
139
-
new_scopes,
139
+
new_scopes.as_str(),
140
140
delegated_did.as_str(),
141
141
controller_did.as_str()
142
142
)
···
170
170
id: r.id,
171
171
delegated_did: r.delegated_did.into(),
172
172
controller_did: r.controller_did.into(),
173
-
granted_scopes: r.granted_scopes,
173
+
granted_scopes: DbScope::from_db(r.granted_scopes),
174
174
granted_at: r.granted_at,
175
175
granted_by: r.granted_by.into(),
176
176
revoked_at: r.revoked_at,
···
206
206
.map(|r| ControllerInfo {
207
207
did: r.did.into(),
208
208
handle: r.handle.into(),
209
-
granted_scopes: r.granted_scopes,
209
+
granted_scopes: DbScope::from_db(r.granted_scopes),
210
210
granted_at: r.granted_at,
211
211
is_active: r.is_active,
212
212
})
···
243
243
.map(|r| DelegatedAccountInfo {
244
244
did: r.did.into(),
245
245
handle: r.handle.into(),
246
-
granted_scopes: r.granted_scopes,
246
+
granted_scopes: DbScope::from_db(r.granted_scopes),
247
247
granted_at: r.granted_at,
248
248
})
249
249
.collect())
···
280
280
.map(|r| ControllerInfo {
281
281
did: r.did.into(),
282
282
handle: r.handle.into(),
283
-
granted_scopes: r.granted_scopes,
283
+
granted_scopes: DbScope::from_db(r.granted_scopes),
284
284
granted_at: r.granted_at,
285
285
is_active: r.is_active,
286
286
})
+34
-18
crates/tranquil-db/src/postgres/infra.rs
+34
-18
crates/tranquil-db/src/postgres/infra.rs
···
3
3
use sqlx::PgPool;
4
4
use tranquil_db_traits::{
5
5
AdminAccountInfo, CommsChannel, CommsStatus, CommsType, DbError, DeletionRequest,
6
-
InfraRepository, InviteCodeInfo, InviteCodeRow, InviteCodeSortOrder, InviteCodeUse,
7
-
NotificationHistoryRow, QueuedComms, ReservedSigningKey,
6
+
InfraRepository, InviteCodeError, InviteCodeInfo, InviteCodeRow, InviteCodeSortOrder,
7
+
InviteCodeState, InviteCodeUse, NotificationHistoryRow, QueuedComms, ReservedSigningKey,
8
+
ValidatedInviteCode,
8
9
};
9
10
use tranquil_types::{CidLink, Did, Handle};
10
11
use uuid::Uuid;
···
182
183
Ok(result)
183
184
}
184
185
185
-
async fn is_invite_code_valid(&self, code: &str) -> Result<bool, DbError> {
186
-
let result = sqlx::query_scalar!(
187
-
r#"SELECT (available_uses > 0 AND NOT COALESCE(disabled, false)) as "valid!" FROM invite_codes WHERE code = $1"#,
186
+
async fn validate_invite_code<'a>(
187
+
&self,
188
+
code: &'a str,
189
+
) -> Result<ValidatedInviteCode<'a>, InviteCodeError> {
190
+
let result = sqlx::query!(
191
+
r#"SELECT available_uses, COALESCE(disabled, false) as "disabled!" FROM invite_codes WHERE code = $1"#,
188
192
code
189
193
)
190
194
.fetch_optional(&self.pool)
191
195
.await
192
-
.map_err(map_sqlx_error)?;
196
+
.map_err(|e| InviteCodeError::DatabaseError(map_sqlx_error(e)))?;
193
197
194
-
Ok(result.unwrap_or(false))
198
+
match result {
199
+
None => Err(InviteCodeError::NotFound),
200
+
Some(row) if row.disabled => Err(InviteCodeError::Disabled),
201
+
Some(row) if row.available_uses <= 0 => Err(InviteCodeError::ExhaustedUses),
202
+
Some(_) => Ok(ValidatedInviteCode::new_validated(code)),
203
+
}
195
204
}
196
205
197
-
async fn decrement_invite_code_uses(&self, code: &str) -> Result<(), DbError> {
206
+
async fn decrement_invite_code_uses(
207
+
&self,
208
+
code: &ValidatedInviteCode<'_>,
209
+
) -> Result<(), DbError> {
198
210
sqlx::query!(
199
211
"UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1",
200
-
code
212
+
code.code()
201
213
)
202
214
.execute(&self.pool)
203
215
.await
···
206
218
Ok(())
207
219
}
208
220
209
-
async fn record_invite_code_use(&self, code: &str, used_by_user: Uuid) -> Result<(), DbError> {
221
+
async fn record_invite_code_use(
222
+
&self,
223
+
code: &ValidatedInviteCode<'_>,
224
+
used_by_user: Uuid,
225
+
) -> Result<(), DbError> {
210
226
sqlx::query!(
211
227
"INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)",
212
-
code,
228
+
code.code(),
213
229
used_by_user
214
230
)
215
231
.execute(&self.pool)
···
245
261
.map(|r| InviteCodeInfo {
246
262
code: r.code,
247
263
available_uses: r.available_uses,
248
-
disabled: r.disabled.unwrap_or(false),
264
+
state: InviteCodeState::from(r.disabled),
249
265
for_account: Some(Did::from(r.for_account)),
250
266
created_at: r.created_at,
251
267
created_by: None,
···
422
438
.map(|r| InviteCodeInfo {
423
439
code: r.code,
424
440
available_uses: r.available_uses,
425
-
disabled: r.disabled.unwrap_or(false),
441
+
state: InviteCodeState::from(r.disabled),
426
442
for_account: Some(Did::from(r.for_account)),
427
443
created_at: r.created_at,
428
444
created_by: Some(Did::from(r.created_by)),
···
445
461
Ok(result.map(|r| InviteCodeInfo {
446
462
code: r.code,
447
463
available_uses: r.available_uses,
448
-
disabled: r.disabled.unwrap_or(false),
464
+
state: InviteCodeState::from(r.disabled),
449
465
for_account: Some(Did::from(r.for_account)),
450
466
created_at: r.created_at,
451
467
created_by: Some(Did::from(r.created_by)),
···
476
492
InviteCodeInfo {
477
493
code: r.code,
478
494
available_uses: r.available_uses,
479
-
disabled: r.disabled.unwrap_or(false),
495
+
state: InviteCodeState::from(r.disabled),
480
496
for_account: Some(Did::from(r.for_account)),
481
497
created_at: r.created_at,
482
498
created_by: Some(Did::from(r.created_by)),
···
841
857
r#"
842
858
SELECT
843
859
created_at,
844
-
channel as "channel: String",
845
-
comms_type as "comms_type: String",
846
-
status as "status: String",
860
+
channel as "channel: CommsChannel",
861
+
comms_type as "comms_type: CommsType",
862
+
status as "status: CommsStatus",
847
863
subject,
848
864
body
849
865
FROM comms_queue
+112
-63
crates/tranquil-db/src/postgres/oauth.rs
+112
-63
crates/tranquil-db/src/postgres/oauth.rs
···
4
4
use sqlx::PgPool;
5
5
use tranquil_db_traits::{
6
6
DbError, DeviceAccountRow, DeviceTrustInfo, OAuthRepository, OAuthSessionListItem,
7
-
ScopePreference, TrustedDeviceRow, TwoFactorChallenge,
7
+
ScopePreference, TokenFamilyId, TrustedDeviceRow, TwoFactorChallenge,
8
8
};
9
9
use tranquil_oauth::{
10
-
AuthorizationRequestParameters, AuthorizedClientData, ClientAuth, DeviceData, RequestData,
11
-
TokenData,
10
+
AuthorizationRequestParameters, AuthorizedClientData, ClientAuth, Code as OAuthCode,
11
+
DeviceData, DeviceId as OAuthDeviceId, RefreshToken as OAuthRefreshToken, RequestData,
12
+
SessionId as OAuthSessionId, TokenData, TokenId as OAuthTokenId,
12
13
};
13
14
use tranquil_types::{
14
15
AuthorizationCode, ClientId, DPoPProofId, DeviceId, Did, Handle, RefreshToken, RequestId,
···
48
49
49
50
#[async_trait]
50
51
impl OAuthRepository for PostgresOAuthRepository {
51
-
async fn create_token(&self, data: &TokenData) -> Result<i32, DbError> {
52
+
async fn create_token(&self, data: &TokenData) -> Result<TokenFamilyId, DbError> {
52
53
let client_auth_json = to_json(&data.client_auth)?;
53
54
let parameters_json = to_json(&data.parameters)?;
54
55
let row = sqlx::query!(
···
59
60
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
60
61
RETURNING id
61
62
"#,
62
-
data.did,
63
-
data.token_id,
63
+
data.did.as_str(),
64
+
&data.token_id.0,
64
65
data.created_at,
65
66
data.updated_at,
66
67
data.expires_at,
67
68
data.client_id,
68
69
client_auth_json,
69
-
data.device_id,
70
+
data.device_id.as_ref().map(|d| d.0.as_str()),
70
71
parameters_json,
71
72
data.details,
72
-
data.code,
73
-
data.current_refresh_token,
73
+
data.code.as_ref().map(|c| c.0.as_str()),
74
+
data.current_refresh_token.as_ref().map(|r| r.0.as_str()),
74
75
data.scope,
75
-
data.controller_did,
76
+
data.controller_did.as_ref().map(|d| d.as_str()),
76
77
)
77
78
.fetch_one(&self.pool)
78
79
.await
79
80
.map_err(map_sqlx_error)?;
80
-
Ok(row.id)
81
+
Ok(TokenFamilyId::new(row.id))
81
82
}
82
83
83
84
async fn get_token_by_id(&self, token_id: &TokenId) -> Result<Option<TokenData>, DbError> {
···
95
96
.map_err(map_sqlx_error)?;
96
97
match row {
97
98
Some(r) => Ok(Some(TokenData {
98
-
did: r.did,
99
-
token_id: r.token_id,
99
+
did: r
100
+
.did
101
+
.parse()
102
+
.map_err(|_| DbError::Other("Invalid DID in token".into()))?,
103
+
token_id: OAuthTokenId(r.token_id),
100
104
created_at: r.created_at,
101
105
updated_at: r.updated_at,
102
106
expires_at: r.expires_at,
103
107
client_id: r.client_id,
104
108
client_auth: from_json(r.client_auth)?,
105
-
device_id: r.device_id,
109
+
device_id: r.device_id.map(OAuthDeviceId),
106
110
parameters: from_json(r.parameters)?,
107
111
details: r.details,
108
-
code: r.code,
109
-
current_refresh_token: r.current_refresh_token,
112
+
code: r.code.map(OAuthCode),
113
+
current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken),
110
114
scope: r.scope,
111
-
controller_did: r.controller_did,
115
+
controller_did: r
116
+
.controller_did
117
+
.map(|s| s.parse())
118
+
.transpose()
119
+
.map_err(|_| DbError::Other("Invalid controller DID".into()))?,
112
120
})),
113
121
None => Ok(None),
114
122
}
···
117
125
async fn get_token_by_refresh_token(
118
126
&self,
119
127
refresh_token: &RefreshToken,
120
-
) -> Result<Option<(i32, TokenData)>, DbError> {
128
+
) -> Result<Option<(TokenFamilyId, TokenData)>, DbError> {
121
129
let row = sqlx::query!(
122
130
r#"
123
131
SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
···
132
140
.map_err(map_sqlx_error)?;
133
141
match row {
134
142
Some(r) => Ok(Some((
135
-
r.id,
143
+
TokenFamilyId::new(r.id),
136
144
TokenData {
137
-
did: r.did,
138
-
token_id: r.token_id,
145
+
did: r
146
+
.did
147
+
.parse()
148
+
.map_err(|_| DbError::Other("Invalid DID in token".into()))?,
149
+
token_id: OAuthTokenId(r.token_id),
139
150
created_at: r.created_at,
140
151
updated_at: r.updated_at,
141
152
expires_at: r.expires_at,
142
153
client_id: r.client_id,
143
154
client_auth: from_json(r.client_auth)?,
144
-
device_id: r.device_id,
155
+
device_id: r.device_id.map(OAuthDeviceId),
145
156
parameters: from_json(r.parameters)?,
146
157
details: r.details,
147
-
code: r.code,
148
-
current_refresh_token: r.current_refresh_token,
158
+
code: r.code.map(OAuthCode),
159
+
current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken),
149
160
scope: r.scope,
150
-
controller_did: r.controller_did,
161
+
controller_did: r
162
+
.controller_did
163
+
.map(|s| s.parse())
164
+
.transpose()
165
+
.map_err(|_| DbError::Other("Invalid controller DID".into()))?,
151
166
},
152
167
))),
153
168
None => Ok(None),
···
157
172
async fn get_token_by_previous_refresh_token(
158
173
&self,
159
174
refresh_token: &RefreshToken,
160
-
) -> Result<Option<(i32, TokenData)>, DbError> {
175
+
) -> Result<Option<(TokenFamilyId, TokenData)>, DbError> {
161
176
let grace_cutoff = Utc::now() - Duration::seconds(REFRESH_GRACE_PERIOD_SECS);
162
177
let row = sqlx::query!(
163
178
r#"
···
174
189
.map_err(map_sqlx_error)?;
175
190
match row {
176
191
Some(r) => Ok(Some((
177
-
r.id,
192
+
TokenFamilyId::new(r.id),
178
193
TokenData {
179
-
did: r.did,
180
-
token_id: r.token_id,
194
+
did: r
195
+
.did
196
+
.parse()
197
+
.map_err(|_| DbError::Other("Invalid DID in token".into()))?,
198
+
token_id: OAuthTokenId(r.token_id),
181
199
created_at: r.created_at,
182
200
updated_at: r.updated_at,
183
201
expires_at: r.expires_at,
184
202
client_id: r.client_id,
185
203
client_auth: from_json(r.client_auth)?,
186
-
device_id: r.device_id,
204
+
device_id: r.device_id.map(OAuthDeviceId),
187
205
parameters: from_json(r.parameters)?,
188
206
details: r.details,
189
-
code: r.code,
190
-
current_refresh_token: r.current_refresh_token,
207
+
code: r.code.map(OAuthCode),
208
+
current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken),
191
209
scope: r.scope,
192
-
controller_did: r.controller_did,
210
+
controller_did: r
211
+
.controller_did
212
+
.map(|s| s.parse())
213
+
.transpose()
214
+
.map_err(|_| DbError::Other("Invalid controller DID".into()))?,
193
215
},
194
216
))),
195
217
None => Ok(None),
···
198
220
199
221
async fn rotate_token(
200
222
&self,
201
-
old_db_id: i32,
223
+
old_db_id: TokenFamilyId,
202
224
new_refresh_token: &RefreshToken,
203
225
new_expires_at: DateTime<Utc>,
204
226
) -> Result<(), DbError> {
···
207
229
r#"
208
230
SELECT current_refresh_token FROM oauth_token WHERE id = $1
209
231
"#,
210
-
old_db_id
232
+
old_db_id.as_i32()
211
233
)
212
234
.fetch_one(&mut *tx)
213
235
.await
···
219
241
VALUES ($1, $2)
220
242
"#,
221
243
old_rt,
222
-
old_db_id
244
+
old_db_id.as_i32()
223
245
)
224
246
.execute(&mut *tx)
225
247
.await
···
232
254
previous_refresh_token = $4, rotated_at = NOW()
233
255
WHERE id = $1
234
256
"#,
235
-
old_db_id,
257
+
old_db_id.as_i32(),
236
258
new_refresh_token.as_str(),
237
259
new_expires_at,
238
260
old_refresh
···
247
269
async fn check_refresh_token_used(
248
270
&self,
249
271
refresh_token: &RefreshToken,
250
-
) -> Result<Option<i32>, DbError> {
272
+
) -> Result<Option<TokenFamilyId>, DbError> {
251
273
let row = sqlx::query_scalar!(
252
274
r#"
253
275
SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
···
257
279
.fetch_optional(&self.pool)
258
280
.await
259
281
.map_err(map_sqlx_error)?;
260
-
Ok(row)
282
+
Ok(row.map(TokenFamilyId::new))
261
283
}
262
284
263
285
async fn delete_token(&self, token_id: &TokenId) -> Result<(), DbError> {
···
273
295
Ok(())
274
296
}
275
297
276
-
async fn delete_token_family(&self, db_id: i32) -> Result<(), DbError> {
298
+
async fn delete_token_family(&self, db_id: TokenFamilyId) -> Result<(), DbError> {
277
299
sqlx::query!(
278
300
r#"
279
301
DELETE FROM oauth_token WHERE id = $1
280
302
"#,
281
-
db_id
303
+
db_id.as_i32()
282
304
)
283
305
.execute(&self.pool)
284
306
.await
···
302
324
rows.into_iter()
303
325
.map(|r| {
304
326
Ok(TokenData {
305
-
did: r.did,
306
-
token_id: r.token_id,
327
+
did: r
328
+
.did
329
+
.parse()
330
+
.map_err(|_| DbError::Other("Invalid DID in token".into()))?,
331
+
token_id: OAuthTokenId(r.token_id),
307
332
created_at: r.created_at,
308
333
updated_at: r.updated_at,
309
334
expires_at: r.expires_at,
310
335
client_id: r.client_id,
311
336
client_auth: from_json(r.client_auth)?,
312
-
device_id: r.device_id,
337
+
device_id: r.device_id.map(OAuthDeviceId),
313
338
parameters: from_json(r.parameters)?,
314
339
details: r.details,
315
-
code: r.code,
316
-
current_refresh_token: r.current_refresh_token,
340
+
code: r.code.map(OAuthCode),
341
+
current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken),
317
342
scope: r.scope,
318
-
controller_did: r.controller_did,
343
+
controller_did: r
344
+
.controller_did
345
+
.map(|s| s.parse())
346
+
.transpose()
347
+
.map_err(|_| DbError::Other("Invalid controller DID".into()))?,
319
348
})
320
349
})
321
350
.collect()
···
407
436
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
408
437
"#,
409
438
request_id.as_str(),
410
-
data.did,
411
-
data.device_id,
439
+
data.did.as_ref().map(|d| d.as_str()),
440
+
data.device_id.as_ref().map(|d| d.0.as_str()),
412
441
data.client_id,
413
442
client_auth_json,
414
443
parameters_json,
415
444
data.expires_at,
416
-
data.code,
445
+
data.code.as_ref().map(|c| c.0.as_str()),
417
446
)
418
447
.execute(&self.pool)
419
448
.await
···
448
477
client_auth,
449
478
parameters,
450
479
expires_at: r.expires_at,
451
-
did: r.did,
452
-
device_id: r.device_id,
453
-
code: r.code,
454
-
controller_did: r.controller_did,
480
+
did: r
481
+
.did
482
+
.map(|s| s.parse())
483
+
.transpose()
484
+
.map_err(|_| DbError::Other("Invalid DID in DB".into()))?,
485
+
device_id: r.device_id.map(OAuthDeviceId),
486
+
code: r.code.map(OAuthCode),
487
+
controller_did: r
488
+
.controller_did
489
+
.map(|s| s.parse())
490
+
.transpose()
491
+
.map_err(|_| DbError::Other("Invalid controller DID in DB".into()))?,
455
492
}))
456
493
}
457
494
None => Ok(None),
···
534
571
client_auth,
535
572
parameters,
536
573
expires_at: r.expires_at,
537
-
did: r.did,
538
-
device_id: r.device_id,
539
-
code: r.code,
540
-
controller_did: r.controller_did,
574
+
did: r
575
+
.did
576
+
.map(|s| s.parse())
577
+
.transpose()
578
+
.map_err(|_| DbError::Other("Invalid DID in DB".into()))?,
579
+
device_id: r.device_id.map(OAuthDeviceId),
580
+
code: r.code.map(OAuthCode),
581
+
controller_did: r
582
+
.controller_did
583
+
.map(|s| s.parse())
584
+
.transpose()
585
+
.map_err(|_| DbError::Other("Invalid controller DID in DB".into()))?,
541
586
}))
542
587
}
543
588
None => Ok(None),
···
655
700
VALUES ($1, $2, $3, $4, $5)
656
701
"#,
657
702
device_id.as_str(),
658
-
data.session_id,
703
+
&data.session_id.0,
659
704
data.user_agent,
660
705
data.ip_address,
661
706
data.last_seen_at,
···
679
724
.await
680
725
.map_err(map_sqlx_error)?;
681
726
Ok(row.map(|r| DeviceData {
682
-
session_id: r.session_id,
727
+
session_id: OAuthSessionId(r.session_id),
683
728
user_agent: r.user_agent,
684
729
ip_address: r.ip_address,
685
730
last_seen_at: r.last_seen_at,
···
1207
1252
Ok(rows
1208
1253
.into_iter()
1209
1254
.map(|r| OAuthSessionListItem {
1210
-
id: r.id,
1255
+
id: TokenFamilyId::new(r.id),
1211
1256
token_id: TokenId::from(r.token_id),
1212
1257
created_at: r.created_at,
1213
1258
expires_at: r.expires_at,
···
1216
1261
.collect())
1217
1262
}
1218
1263
1219
-
async fn delete_session_by_id(&self, session_id: i32, did: &Did) -> Result<u64, DbError> {
1264
+
async fn delete_session_by_id(
1265
+
&self,
1266
+
session_id: TokenFamilyId,
1267
+
did: &Did,
1268
+
) -> Result<u64, DbError> {
1220
1269
let result = sqlx::query!(
1221
1270
"DELETE FROM oauth_token WHERE id = $1 AND did = $2",
1222
-
session_id,
1271
+
session_id.as_i32(),
1223
1272
did.as_str()
1224
1273
)
1225
1274
.execute(&self.pool)
+146
-107
crates/tranquil-db/src/postgres/repo.rs
+146
-107
crates/tranquil-db/src/postgres/repo.rs
···
2
2
use chrono::{DateTime, Utc};
3
3
use sqlx::PgPool;
4
4
use tranquil_db_traits::{
5
-
BrokenGenesisCommit, CommitEventData, DbError, EventBlocksCids, FullRecordInfo, ImportBlock,
6
-
ImportRecord, ImportRepoError, RecordInfo, RecordWithTakedown, RepoAccountInfo, RepoInfo,
7
-
RepoListItem, RepoRepository, RepoWithoutRev, SequencedEvent, UserNeedingRecordBlobsBackfill,
8
-
UserWithoutBlocks,
5
+
AccountStatus, BrokenGenesisCommit, CommitEventData, DbError, EventBlocksCids, FullRecordInfo,
6
+
ImportBlock, ImportRecord, ImportRepoError, RecordInfo, RecordWithTakedown, RepoAccountInfo,
7
+
RepoEventType, RepoInfo, RepoListItem, RepoRepository, RepoWithoutRev, SequenceNumber,
8
+
SequencedEvent, UserNeedingRecordBlobsBackfill, UserWithoutBlocks,
9
9
};
10
10
use tranquil_types::{AtUri, CidLink, Did, Handle, Nsid, Rkey};
11
11
use uuid::Uuid;
···
21
21
seq: i64,
22
22
did: String,
23
23
created_at: DateTime<Utc>,
24
-
event_type: String,
24
+
event_type: RepoEventType,
25
25
commit_cid: Option<String>,
26
26
prev_cid: Option<String>,
27
27
prev_data_cid: Option<String>,
···
627
627
Ok(rows.into_iter().map(|(cid,)| cid).collect())
628
628
}
629
629
630
-
async fn insert_commit_event(&self, data: &CommitEventData) -> Result<i64, DbError> {
630
+
async fn insert_commit_event(&self, data: &CommitEventData) -> Result<SequenceNumber, DbError> {
631
631
let seq = sqlx::query_scalar!(
632
632
r#"
633
633
INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid, rev)
···
635
635
RETURNING seq
636
636
"#,
637
637
data.did.as_str(),
638
-
data.event_type,
638
+
data.event_type.as_str(),
639
639
data.commit_cid.as_ref().map(|c| c.as_str()),
640
640
data.prev_cid.as_ref().map(|c| c.as_str()),
641
641
data.ops,
···
648
648
.await
649
649
.map_err(map_sqlx_error)?;
650
650
651
-
Ok(seq)
651
+
Ok(seq.into())
652
652
}
653
653
654
654
async fn insert_identity_event(
655
655
&self,
656
656
did: &Did,
657
657
handle: Option<&Handle>,
658
-
) -> Result<i64, DbError> {
658
+
) -> Result<SequenceNumber, DbError> {
659
659
let handle_str = handle.map(|h| h.as_str());
660
660
let seq = sqlx::query_scalar!(
661
661
r#"
···
675
675
.await
676
676
.map_err(map_sqlx_error)?;
677
677
678
-
Ok(seq)
678
+
Ok(seq.into())
679
679
}
680
680
681
681
async fn insert_account_event(
682
682
&self,
683
683
did: &Did,
684
-
active: bool,
685
-
status: Option<&str>,
686
-
) -> Result<i64, DbError> {
684
+
status: AccountStatus,
685
+
) -> Result<SequenceNumber, DbError> {
686
+
let active = status.is_active();
687
+
let status_str = status.for_firehose();
687
688
let seq = sqlx::query_scalar!(
688
689
r#"
689
690
INSERT INTO repo_seq (did, event_type, active, status)
···
692
693
"#,
693
694
did.as_str(),
694
695
active,
695
-
status
696
+
status_str
696
697
)
697
698
.fetch_one(&self.pool)
698
699
.await
···
703
704
.await
704
705
.map_err(map_sqlx_error)?;
705
706
706
-
Ok(seq)
707
+
Ok(seq.into())
707
708
}
708
709
709
710
async fn insert_sync_event(
···
711
712
did: &Did,
712
713
commit_cid: &CidLink,
713
714
rev: Option<&str>,
714
-
) -> Result<i64, DbError> {
715
+
) -> Result<SequenceNumber, DbError> {
715
716
let seq = sqlx::query_scalar!(
716
717
r#"
717
718
INSERT INTO repo_seq (did, event_type, commit_cid, rev)
···
731
732
.await
732
733
.map_err(map_sqlx_error)?;
733
734
734
-
Ok(seq)
735
+
Ok(seq.into())
735
736
}
736
737
737
738
async fn insert_genesis_commit_event(
···
740
741
commit_cid: &CidLink,
741
742
mst_root_cid: &CidLink,
742
743
rev: &str,
743
-
) -> Result<i64, DbError> {
744
+
) -> Result<SequenceNumber, DbError> {
744
745
let ops = serde_json::json!([]);
745
746
let blobs: Vec<String> = vec![];
746
747
let blocks_cids: Vec<String> = vec![mst_root_cid.to_string(), commit_cid.to_string()];
···
769
770
.await
770
771
.map_err(map_sqlx_error)?;
771
772
772
-
Ok(seq)
773
+
Ok(seq.into())
773
774
}
774
775
775
776
async fn update_seq_blocks_cids(
776
777
&self,
777
-
seq: i64,
778
+
seq: SequenceNumber,
778
779
blocks_cids: &[String],
779
780
) -> Result<(), DbError> {
780
781
sqlx::query!(
781
782
"UPDATE repo_seq SET blocks_cids = $1 WHERE seq = $2",
782
783
blocks_cids,
783
-
seq
784
+
seq.as_i64()
784
785
)
785
786
.execute(&self.pool)
786
787
.await
···
789
790
Ok(())
790
791
}
791
792
792
-
async fn delete_sequences_except(&self, did: &Did, keep_seq: i64) -> Result<(), DbError> {
793
+
async fn delete_sequences_except(
794
+
&self,
795
+
did: &Did,
796
+
keep_seq: SequenceNumber,
797
+
) -> Result<(), DbError> {
793
798
sqlx::query!(
794
799
"DELETE FROM repo_seq WHERE did = $1 AND seq != $2",
795
800
did.as_str(),
796
-
keep_seq
801
+
keep_seq.as_i64()
797
802
)
798
803
.execute(&self.pool)
799
804
.await
···
802
807
Ok(())
803
808
}
804
809
805
-
async fn get_max_seq(&self) -> Result<i64, DbError> {
810
+
async fn get_max_seq(&self) -> Result<SequenceNumber, DbError> {
806
811
let seq = sqlx::query_scalar!(r#"SELECT COALESCE(MAX(seq), 0) as "max!" FROM repo_seq"#)
807
812
.fetch_one(&self.pool)
808
813
.await
809
814
.map_err(map_sqlx_error)?;
810
815
811
-
Ok(seq)
816
+
Ok(seq.into())
812
817
}
813
818
814
-
async fn get_min_seq_since(&self, since: DateTime<Utc>) -> Result<Option<i64>, DbError> {
819
+
async fn get_min_seq_since(
820
+
&self,
821
+
since: DateTime<Utc>,
822
+
) -> Result<Option<SequenceNumber>, DbError> {
815
823
let seq = sqlx::query_scalar!(
816
824
"SELECT MIN(seq) FROM repo_seq WHERE created_at >= $1",
817
825
since
···
820
828
.await
821
829
.map_err(map_sqlx_error)?;
822
830
823
-
Ok(seq)
831
+
Ok(seq.map(SequenceNumber::from))
824
832
}
825
833
826
834
async fn get_account_with_repo(&self, did: &Did) -> Result<Option<RepoAccountInfo>, DbError> {
···
846
854
847
855
async fn get_events_since_seq(
848
856
&self,
849
-
since_seq: i64,
857
+
since_seq: SequenceNumber,
850
858
limit: Option<i64>,
851
859
) -> Result<Vec<SequencedEvent>, DbError> {
852
-
let map_row = |r: SequencedEventRow| SequencedEvent {
853
-
seq: r.seq,
854
-
did: Did::from(r.did),
855
-
created_at: r.created_at,
856
-
event_type: r.event_type,
857
-
commit_cid: r.commit_cid.map(CidLink::from),
858
-
prev_cid: r.prev_cid.map(CidLink::from),
859
-
prev_data_cid: r.prev_data_cid.map(CidLink::from),
860
-
ops: r.ops,
861
-
blobs: r.blobs,
862
-
blocks_cids: r.blocks_cids,
863
-
handle: r.handle.map(Handle::from),
864
-
active: r.active,
865
-
status: r.status,
866
-
rev: r.rev,
860
+
let map_row = |r: SequencedEventRow| {
861
+
let status = r
862
+
.status
863
+
.as_deref()
864
+
.and_then(AccountStatus::parse)
865
+
.or_else(|| r.active.filter(|a| *a).map(|_| AccountStatus::Active));
866
+
SequencedEvent {
867
+
seq: r.seq.into(),
868
+
did: Did::from(r.did),
869
+
created_at: r.created_at,
870
+
event_type: r.event_type,
871
+
commit_cid: r.commit_cid.map(CidLink::from),
872
+
prev_cid: r.prev_cid.map(CidLink::from),
873
+
prev_data_cid: r.prev_data_cid.map(CidLink::from),
874
+
ops: r.ops,
875
+
blobs: r.blobs,
876
+
blocks_cids: r.blocks_cids,
877
+
handle: r.handle.map(Handle::from),
878
+
active: r.active,
879
+
status,
880
+
rev: r.rev,
881
+
}
867
882
};
868
883
match limit {
869
884
Some(lim) => {
870
885
let rows = sqlx::query_as!(
871
886
SequencedEventRow,
872
-
r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,
887
+
r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid,
873
888
ops, blobs, blocks_cids, handle, active, status, rev
874
889
FROM repo_seq
875
890
WHERE seq > $1
876
891
ORDER BY seq ASC
877
892
LIMIT $2"#,
878
-
since_seq,
893
+
since_seq.as_i64(),
879
894
lim
880
895
)
881
896
.fetch_all(&self.pool)
···
886
901
None => {
887
902
let rows = sqlx::query_as!(
888
903
SequencedEventRow,
889
-
r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,
904
+
r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid,
890
905
ops, blobs, blocks_cids, handle, active, status, rev
891
906
FROM repo_seq
892
907
WHERE seq > $1
893
908
ORDER BY seq ASC"#,
894
-
since_seq
909
+
since_seq.as_i64()
895
910
)
896
911
.fetch_all(&self.pool)
897
912
.await
···
903
918
904
919
async fn get_events_in_seq_range(
905
920
&self,
906
-
start_seq: i64,
907
-
end_seq: i64,
921
+
start_seq: SequenceNumber,
922
+
end_seq: SequenceNumber,
908
923
) -> Result<Vec<SequencedEvent>, DbError> {
909
924
let rows = sqlx::query!(
910
-
r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,
925
+
r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid,
911
926
ops, blobs, blocks_cids, handle, active, status, rev
912
927
FROM repo_seq
913
928
WHERE seq > $1 AND seq < $2
914
929
ORDER BY seq ASC"#,
915
-
start_seq,
916
-
end_seq
930
+
start_seq.as_i64(),
931
+
end_seq.as_i64()
917
932
)
918
933
.fetch_all(&self.pool)
919
934
.await
920
935
.map_err(map_sqlx_error)?;
921
936
Ok(rows
922
937
.into_iter()
923
-
.map(|r| SequencedEvent {
924
-
seq: r.seq,
925
-
did: Did::from(r.did),
926
-
created_at: r.created_at,
927
-
event_type: r.event_type,
928
-
commit_cid: r.commit_cid.map(CidLink::from),
929
-
prev_cid: r.prev_cid.map(CidLink::from),
930
-
prev_data_cid: r.prev_data_cid.map(CidLink::from),
931
-
ops: r.ops,
932
-
blobs: r.blobs,
933
-
blocks_cids: r.blocks_cids,
934
-
handle: r.handle.map(Handle::from),
935
-
active: r.active,
936
-
status: r.status,
937
-
rev: r.rev,
938
+
.map(|r| {
939
+
let status = r
940
+
.status
941
+
.as_deref()
942
+
.and_then(AccountStatus::parse)
943
+
.or_else(|| r.active.filter(|a| *a).map(|_| AccountStatus::Active));
944
+
SequencedEvent {
945
+
seq: r.seq.into(),
946
+
did: Did::from(r.did),
947
+
created_at: r.created_at,
948
+
event_type: r.event_type,
949
+
commit_cid: r.commit_cid.map(CidLink::from),
950
+
prev_cid: r.prev_cid.map(CidLink::from),
951
+
prev_data_cid: r.prev_data_cid.map(CidLink::from),
952
+
ops: r.ops,
953
+
blobs: r.blobs,
954
+
blocks_cids: r.blocks_cids,
955
+
handle: r.handle.map(Handle::from),
956
+
active: r.active,
957
+
status,
958
+
rev: r.rev,
959
+
}
938
960
})
939
961
.collect())
940
962
}
941
963
942
-
async fn get_event_by_seq(&self, seq: i64) -> Result<Option<SequencedEvent>, DbError> {
964
+
async fn get_event_by_seq(
965
+
&self,
966
+
seq: SequenceNumber,
967
+
) -> Result<Option<SequencedEvent>, DbError> {
943
968
let row = sqlx::query!(
944
-
r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,
969
+
r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid,
945
970
ops, blobs, blocks_cids, handle, active, status, rev
946
971
FROM repo_seq
947
972
WHERE seq = $1"#,
948
-
seq
973
+
seq.as_i64()
949
974
)
950
975
.fetch_optional(&self.pool)
951
976
.await
952
977
.map_err(map_sqlx_error)?;
953
-
Ok(row.map(|r| SequencedEvent {
954
-
seq: r.seq,
955
-
did: Did::from(r.did),
956
-
created_at: r.created_at,
957
-
event_type: r.event_type,
958
-
commit_cid: r.commit_cid.map(CidLink::from),
959
-
prev_cid: r.prev_cid.map(CidLink::from),
960
-
prev_data_cid: r.prev_data_cid.map(CidLink::from),
961
-
ops: r.ops,
962
-
blobs: r.blobs,
963
-
blocks_cids: r.blocks_cids,
964
-
handle: r.handle.map(Handle::from),
965
-
active: r.active,
966
-
status: r.status,
967
-
rev: r.rev,
978
+
Ok(row.map(|r| {
979
+
let status = r
980
+
.status
981
+
.as_deref()
982
+
.and_then(AccountStatus::parse)
983
+
.or_else(|| r.active.filter(|a| *a).map(|_| AccountStatus::Active));
984
+
SequencedEvent {
985
+
seq: r.seq.into(),
986
+
did: Did::from(r.did),
987
+
created_at: r.created_at,
988
+
event_type: r.event_type,
989
+
commit_cid: r.commit_cid.map(CidLink::from),
990
+
prev_cid: r.prev_cid.map(CidLink::from),
991
+
prev_data_cid: r.prev_data_cid.map(CidLink::from),
992
+
ops: r.ops,
993
+
blobs: r.blobs,
994
+
blocks_cids: r.blocks_cids,
995
+
handle: r.handle.map(Handle::from),
996
+
active: r.active,
997
+
status,
998
+
rev: r.rev,
999
+
}
968
1000
}))
969
1001
}
970
1002
971
1003
async fn get_events_since_cursor(
972
1004
&self,
973
-
cursor: i64,
1005
+
cursor: SequenceNumber,
974
1006
limit: i64,
975
1007
) -> Result<Vec<SequencedEvent>, DbError> {
976
1008
let rows = sqlx::query!(
977
-
r#"SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid,
1009
+
r#"SELECT seq, did, created_at, event_type as "event_type: RepoEventType", commit_cid, prev_cid, prev_data_cid,
978
1010
ops, blobs, blocks_cids, handle, active, status, rev
979
1011
FROM repo_seq
980
1012
WHERE seq > $1
981
1013
ORDER BY seq ASC
982
1014
LIMIT $2"#,
983
-
cursor,
1015
+
cursor.as_i64(),
984
1016
limit
985
1017
)
986
1018
.fetch_all(&self.pool)
···
988
1020
.map_err(map_sqlx_error)?;
989
1021
Ok(rows
990
1022
.into_iter()
991
-
.map(|r| SequencedEvent {
992
-
seq: r.seq,
993
-
did: Did::from(r.did),
994
-
created_at: r.created_at,
995
-
event_type: r.event_type,
996
-
commit_cid: r.commit_cid.map(CidLink::from),
997
-
prev_cid: r.prev_cid.map(CidLink::from),
998
-
prev_data_cid: r.prev_data_cid.map(CidLink::from),
999
-
ops: r.ops,
1000
-
blobs: r.blobs,
1001
-
blocks_cids: r.blocks_cids,
1002
-
handle: r.handle.map(Handle::from),
1003
-
active: r.active,
1004
-
status: r.status,
1005
-
rev: r.rev,
1023
+
.map(|r| {
1024
+
let status = r
1025
+
.status
1026
+
.as_deref()
1027
+
.and_then(AccountStatus::parse)
1028
+
.or_else(|| r.active.filter(|a| *a).map(|_| AccountStatus::Active));
1029
+
SequencedEvent {
1030
+
seq: r.seq.into(),
1031
+
did: Did::from(r.did),
1032
+
created_at: r.created_at,
1033
+
event_type: r.event_type,
1034
+
commit_cid: r.commit_cid.map(CidLink::from),
1035
+
prev_cid: r.prev_cid.map(CidLink::from),
1036
+
prev_data_cid: r.prev_data_cid.map(CidLink::from),
1037
+
ops: r.ops,
1038
+
blobs: r.blobs,
1039
+
blocks_cids: r.blocks_cids,
1040
+
handle: r.handle.map(Handle::from),
1041
+
active: r.active,
1042
+
status,
1043
+
rev: r.rev,
1044
+
}
1006
1045
})
1007
1046
.collect())
1008
1047
}
···
1079
1118
Ok(cid.map(CidLink::from))
1080
1119
}
1081
1120
1082
-
async fn notify_update(&self, seq: i64) -> Result<(), DbError> {
1083
-
sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq))
1121
+
async fn notify_update(&self, seq: SequenceNumber) -> Result<(), DbError> {
1122
+
sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq.as_i64()))
1084
1123
.execute(&self.pool)
1085
1124
.await
1086
1125
.map_err(map_sqlx_error)?;
···
1329
1368
"#,
1330
1369
)
1331
1370
.bind(&event.did)
1332
-
.bind(&event.event_type)
1371
+
.bind(event.event_type.as_str())
1333
1372
.bind(&event.commit_cid)
1334
1373
.bind(&event.prev_cid)
1335
1374
.bind(&event.ops)
···
1375
1414
Ok(rows
1376
1415
.into_iter()
1377
1416
.map(|r| BrokenGenesisCommit {
1378
-
seq: r.seq,
1417
+
seq: r.seq.into(),
1379
1418
did: Did::from(r.did),
1380
1419
commit_cid: r.commit_cid.map(CidLink::from),
1381
1420
})
+42
-33
crates/tranquil-db/src/postgres/session.rs
+42
-33
crates/tranquil-db/src/postgres/session.rs
···
2
2
use chrono::{DateTime, Utc};
3
3
use sqlx::PgPool;
4
4
use tranquil_db_traits::{
5
-
AppPasswordCreate, AppPasswordRecord, DbError, RefreshSessionResult, SessionForRefresh,
6
-
SessionListItem, SessionMfaStatus, SessionRefreshData, SessionRepository, SessionToken,
7
-
SessionTokenCreate,
5
+
AppPasswordCreate, AppPasswordPrivilege, AppPasswordRecord, DbError, LoginType,
6
+
RefreshSessionResult, SessionForRefresh, SessionId, SessionListItem, SessionMfaStatus,
7
+
SessionRefreshData, SessionRepository, SessionToken, SessionTokenCreate,
8
8
};
9
9
use tranquil_types::Did;
10
10
use uuid::Uuid;
···
23
23
24
24
#[async_trait]
25
25
impl SessionRepository for PostgresSessionRepository {
26
-
async fn create_session(&self, data: &SessionTokenCreate) -> Result<i32, DbError> {
26
+
async fn create_session(&self, data: &SessionTokenCreate) -> Result<SessionId, DbError> {
27
27
let row = sqlx::query!(
28
28
r#"
29
29
INSERT INTO session_tokens
···
37
37
data.refresh_jti,
38
38
data.access_expires_at,
39
39
data.refresh_expires_at,
40
-
data.legacy_login,
40
+
bool::from(data.login_type),
41
41
data.mfa_verified,
42
42
data.scope,
43
43
data.controller_did.as_ref().map(|d| d.as_str()),
···
47
47
.await
48
48
.map_err(map_sqlx_error)?;
49
49
50
-
Ok(row.id)
50
+
Ok(SessionId::new(row.id))
51
51
}
52
52
53
53
async fn get_session_by_access_jti(
···
69
69
.map_err(map_sqlx_error)?;
70
70
71
71
Ok(row.map(|r| SessionToken {
72
-
id: r.id,
72
+
id: SessionId::new(r.id),
73
73
did: Did::from(r.did),
74
74
access_jti: r.access_jti,
75
75
refresh_jti: r.refresh_jti,
76
76
access_expires_at: r.access_expires_at,
77
77
refresh_expires_at: r.refresh_expires_at,
78
-
legacy_login: r.legacy_login,
78
+
login_type: LoginType::from(r.legacy_login),
79
79
mfa_verified: r.mfa_verified,
80
80
scope: r.scope,
81
81
controller_did: r.controller_did.map(Did::from),
···
104
104
.map_err(map_sqlx_error)?;
105
105
106
106
Ok(row.map(|r| SessionForRefresh {
107
-
id: r.id,
107
+
id: SessionId::new(r.id),
108
108
did: Did::from(r.did),
109
109
scope: r.scope,
110
110
controller_did: r.controller_did.map(Did::from),
···
115
115
116
116
async fn update_session_tokens(
117
117
&self,
118
-
session_id: i32,
118
+
session_id: SessionId,
119
119
new_access_jti: &str,
120
120
new_refresh_jti: &str,
121
121
new_access_expires_at: DateTime<Utc>,
···
132
132
new_refresh_jti,
133
133
new_access_expires_at,
134
134
new_refresh_expires_at,
135
-
session_id
135
+
session_id.as_i32()
136
136
)
137
137
.execute(&self.pool)
138
138
.await
···
153
153
Ok(result.rows_affected())
154
154
}
155
155
156
-
async fn delete_session_by_id(&self, session_id: i32) -> Result<u64, DbError> {
157
-
let result = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id)
158
-
.execute(&self.pool)
159
-
.await
160
-
.map_err(map_sqlx_error)?;
156
+
async fn delete_session_by_id(&self, session_id: SessionId) -> Result<u64, DbError> {
157
+
let result = sqlx::query!(
158
+
"DELETE FROM session_tokens WHERE id = $1",
159
+
session_id.as_i32()
160
+
)
161
+
.execute(&self.pool)
162
+
.await
163
+
.map_err(map_sqlx_error)?;
161
164
162
165
Ok(result.rows_affected())
163
166
}
···
205
208
Ok(rows
206
209
.into_iter()
207
210
.map(|r| SessionListItem {
208
-
id: r.id,
211
+
id: SessionId::new(r.id),
209
212
access_jti: r.access_jti,
210
213
created_at: r.created_at,
211
214
refresh_expires_at: r.refresh_expires_at,
···
215
218
216
219
async fn get_session_access_jti_by_id(
217
220
&self,
218
-
session_id: i32,
221
+
session_id: SessionId,
219
222
did: &Did,
220
223
) -> Result<Option<String>, DbError> {
221
224
let row = sqlx::query_scalar!(
222
225
"SELECT access_jti FROM session_tokens WHERE id = $1 AND did = $2",
223
-
session_id,
226
+
session_id.as_i32(),
224
227
did.as_str()
225
228
)
226
229
.fetch_optional(&self.pool)
···
264
267
Ok(rows)
265
268
}
266
269
267
-
async fn check_refresh_token_used(&self, refresh_jti: &str) -> Result<Option<i32>, DbError> {
270
+
async fn check_refresh_token_used(
271
+
&self,
272
+
refresh_jti: &str,
273
+
) -> Result<Option<SessionId>, DbError> {
268
274
let row = sqlx::query_scalar!(
269
275
"SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1",
270
276
refresh_jti
···
273
279
.await
274
280
.map_err(map_sqlx_error)?;
275
281
276
-
Ok(row)
282
+
Ok(row.map(SessionId::new))
277
283
}
278
284
279
285
async fn mark_refresh_token_used(
280
286
&self,
281
287
refresh_jti: &str,
282
-
session_id: i32,
288
+
session_id: SessionId,
283
289
) -> Result<bool, DbError> {
284
290
let result = sqlx::query!(
285
291
r#"
···
288
294
ON CONFLICT (refresh_jti) DO NOTHING
289
295
"#,
290
296
refresh_jti,
291
-
session_id
297
+
session_id.as_i32()
292
298
)
293
299
.execute(&self.pool)
294
300
.await
···
319
325
name: r.name,
320
326
password_hash: r.password_hash,
321
327
created_at: r.created_at,
322
-
privileged: r.privileged,
328
+
privilege: AppPasswordPrivilege::from(r.privileged),
323
329
scopes: r.scopes,
324
330
created_by_controller_did: r.created_by_controller_did.map(Did::from),
325
331
})
···
352
358
name: r.name,
353
359
password_hash: r.password_hash,
354
360
created_at: r.created_at,
355
-
privileged: r.privileged,
361
+
privilege: AppPasswordPrivilege::from(r.privileged),
356
362
scopes: r.scopes,
357
363
created_by_controller_did: r.created_by_controller_did.map(Did::from),
358
364
})
···
383
389
name: r.name,
384
390
password_hash: r.password_hash,
385
391
created_at: r.created_at,
386
-
privileged: r.privileged,
392
+
privilege: AppPasswordPrivilege::from(r.privileged),
387
393
scopes: r.scopes,
388
394
created_by_controller_did: r.created_by_controller_did.map(Did::from),
389
395
}))
···
399
405
data.user_id,
400
406
data.name,
401
407
data.password_hash,
402
-
data.privileged,
408
+
bool::from(data.privilege),
403
409
data.scopes,
404
410
data.created_by_controller_did.as_ref().map(|d| d.as_str())
405
411
)
···
480
486
.map_err(map_sqlx_error)?;
481
487
482
488
Ok(row.map(|r| SessionMfaStatus {
483
-
legacy_login: r.legacy_login,
489
+
login_type: LoginType::from(r.legacy_login),
484
490
mfa_verified: r.mfa_verified,
485
491
last_reauth_at: r.last_reauth_at,
486
492
}))
···
535
541
let result = sqlx::query!(
536
542
"INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING",
537
543
data.old_refresh_jti,
538
-
data.session_id
544
+
data.session_id.as_i32()
539
545
)
540
546
.execute(&mut *tx)
541
547
.await
542
548
.map_err(map_sqlx_error)?;
543
549
544
550
if result.rows_affected() == 0 {
545
-
let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", data.session_id)
546
-
.execute(&mut *tx)
547
-
.await;
551
+
let _ = sqlx::query!(
552
+
"DELETE FROM session_tokens WHERE id = $1",
553
+
data.session_id.as_i32()
554
+
)
555
+
.execute(&mut *tx)
556
+
.await;
548
557
tx.commit().await.map_err(map_sqlx_error)?;
549
558
return Ok(RefreshSessionResult::ConcurrentRefresh);
550
559
}
···
555
564
data.new_refresh_jti,
556
565
data.new_access_expires_at,
557
566
data.new_refresh_expires_at,
558
-
data.session_id
567
+
data.session_id.as_i32()
559
568
)
560
569
.execute(&mut *tx)
561
570
.await
+33
-28
crates/tranquil-db/src/postgres/sso.rs
+33
-28
crates/tranquil-db/src/postgres/sso.rs
···
2
2
use chrono::Utc;
3
3
use sqlx::PgPool;
4
4
use tranquil_db_traits::{
5
-
DbError, ExternalIdentity, SsoAuthState, SsoPendingRegistration, SsoProviderType, SsoRepository,
5
+
DbError, ExternalEmail, ExternalIdentity, ExternalUserId, ExternalUsername, SsoAction,
6
+
SsoAuthState, SsoPendingRegistration, SsoProviderType, SsoRepository,
6
7
};
7
8
use tranquil_types::Did;
8
9
use uuid::Uuid;
···
69
70
70
71
Ok(row.map(|r| ExternalIdentity {
71
72
id: r.id,
72
-
did: Did::new_unchecked(&r.did),
73
+
did: unsafe { Did::new_unchecked(&r.did) },
73
74
provider: r.provider,
74
-
provider_user_id: r.provider_user_id,
75
-
provider_username: r.provider_username,
76
-
provider_email: r.provider_email,
75
+
provider_user_id: ExternalUserId::from(r.provider_user_id),
76
+
provider_username: r.provider_username.map(ExternalUsername::from),
77
+
provider_email: r.provider_email.map(ExternalEmail::from),
77
78
created_at: r.created_at,
78
79
updated_at: r.updated_at,
79
80
last_login_at: r.last_login_at,
···
102
103
.into_iter()
103
104
.map(|r| ExternalIdentity {
104
105
id: r.id,
105
-
did: Did::new_unchecked(&r.did),
106
+
did: unsafe { Did::new_unchecked(&r.did) },
106
107
provider: r.provider,
107
-
provider_user_id: r.provider_user_id,
108
-
provider_username: r.provider_username,
109
-
provider_email: r.provider_email,
108
+
provider_user_id: ExternalUserId::from(r.provider_user_id),
109
+
provider_username: r.provider_username.map(ExternalUsername::from),
110
+
provider_email: r.provider_email.map(ExternalEmail::from),
110
111
created_at: r.created_at,
111
112
updated_at: r.updated_at,
112
113
last_login_at: r.last_login_at,
···
161
162
state: &str,
162
163
request_uri: &str,
163
164
provider: SsoProviderType,
164
-
action: &str,
165
+
action: SsoAction,
165
166
nonce: Option<&str>,
166
167
code_verifier: Option<&str>,
167
168
did: Option<&Did>,
···
174
175
state,
175
176
request_uri,
176
177
provider as SsoProviderType,
177
-
action,
178
+
action.as_str(),
178
179
nonce,
179
180
code_verifier,
180
181
did.map(|d| d.as_str()),
···
200
201
.await
201
202
.map_err(map_sqlx_error)?;
202
203
203
-
Ok(row.map(|r| SsoAuthState {
204
-
state: r.state,
205
-
request_uri: r.request_uri,
206
-
provider: r.provider,
207
-
action: r.action,
208
-
nonce: r.nonce,
209
-
code_verifier: r.code_verifier,
210
-
did: r.did.map(|d| Did::new_unchecked(&d)),
211
-
created_at: r.created_at,
212
-
expires_at: r.expires_at,
213
-
}))
204
+
row.map(|r| {
205
+
let action = SsoAction::parse(&r.action).ok_or(DbError::NotFound)?;
206
+
Ok(SsoAuthState {
207
+
state: r.state,
208
+
request_uri: r.request_uri,
209
+
provider: r.provider,
210
+
action,
211
+
nonce: r.nonce,
212
+
code_verifier: r.code_verifier,
213
+
did: r.did.map(|d| unsafe { Did::new_unchecked(&d) }),
214
+
created_at: r.created_at,
215
+
expires_at: r.expires_at,
216
+
})
217
+
})
218
+
.transpose()
214
219
}
215
220
216
221
async fn cleanup_expired_sso_auth_states(&self) -> Result<u64, DbError> {
···
280
285
token: r.token,
281
286
request_uri: r.request_uri,
282
287
provider: r.provider,
283
-
provider_user_id: r.provider_user_id,
284
-
provider_username: r.provider_username,
285
-
provider_email: r.provider_email,
288
+
provider_user_id: ExternalUserId::from(r.provider_user_id),
289
+
provider_username: r.provider_username.map(ExternalUsername::from),
290
+
provider_email: r.provider_email.map(ExternalEmail::from),
286
291
provider_email_verified: r.provider_email_verified,
287
292
created_at: r.created_at,
288
293
expires_at: r.expires_at,
···
311
316
token: r.token,
312
317
request_uri: r.request_uri,
313
318
provider: r.provider,
314
-
provider_user_id: r.provider_user_id,
315
-
provider_username: r.provider_username,
316
-
provider_email: r.provider_email,
319
+
provider_user_id: ExternalUserId::from(r.provider_user_id),
320
+
provider_username: r.provider_username.map(ExternalUsername::from),
321
+
provider_email: r.provider_email.map(ExternalEmail::from),
317
322
provider_email_verified: r.provider_email_verified,
318
323
created_at: r.created_at,
319
324
expires_at: r.expires_at,
+66
-45
crates/tranquil-db/src/postgres/user.rs
+66
-45
crates/tranquil-db/src/postgres/user.rs
···
5
5
use uuid::Uuid;
6
6
7
7
use tranquil_db_traits::{
8
-
AccountSearchResult, CommsChannel, DbError, DidWebOverrides, NotificationPrefs,
9
-
OAuthTokenWithUser, PasswordResetResult, SsoProviderType, StoredBackupCode, StoredPasskey,
10
-
TotpRecord, User2faStatus, UserAuthInfo, UserCommsPrefs, UserConfirmSignup, UserDidWebInfo,
11
-
UserEmailInfo, UserForDeletion, UserForDidDoc, UserForDidDocBuild, UserForPasskeyRecovery,
12
-
UserForPasskeySetup, UserForRecovery, UserForVerification, UserIdAndHandle,
13
-
UserIdAndPasswordHash, UserIdHandleEmail, UserInfoForAuth, UserKeyInfo, UserKeyWithId,
14
-
UserLegacyLoginPref, UserLoginCheck, UserLoginFull, UserLoginInfo, UserPasswordInfo,
15
-
UserRepository, UserResendVerification, UserResetCodeInfo, UserRow, UserSessionInfo,
16
-
UserStatus, UserVerificationInfo, UserWithKey,
8
+
AccountSearchResult, AccountType, ChannelVerificationStatus, CommsChannel, DbError,
9
+
DidWebOverrides, NotificationPrefs, OAuthTokenWithUser, PasswordResetResult, SsoProviderType,
10
+
StoredBackupCode, StoredPasskey, TotpRecord, TotpRecordState, User2faStatus, UserAuthInfo,
11
+
UserCommsPrefs, UserConfirmSignup, UserDidWebInfo, UserEmailInfo, UserForDeletion,
12
+
UserForDidDoc, UserForDidDocBuild, UserForPasskeyRecovery, UserForPasskeySetup,
13
+
UserForRecovery, UserForVerification, UserIdAndHandle, UserIdAndPasswordHash,
14
+
UserIdHandleEmail, UserInfoForAuth, UserKeyInfo, UserKeyWithId, UserLegacyLoginPref,
15
+
UserLoginCheck, UserLoginFull, UserLoginInfo, UserPasswordInfo, UserRepository,
16
+
UserResendVerification, UserResetCodeInfo, UserRow, UserSessionInfo, UserStatus,
17
+
UserVerificationInfo, UserWithKey,
17
18
};
18
19
19
20
pub struct PostgresUserRepository {
···
280
281
password_hash: r.password_hash,
281
282
deactivated_at: r.deactivated_at,
282
283
takedown_ref: r.takedown_ref,
283
-
email_verified: r.email_verified,
284
-
discord_verified: r.discord_verified,
285
-
telegram_verified: r.telegram_verified,
286
-
signal_verified: r.signal_verified,
284
+
channel_verification: ChannelVerificationStatus::new(
285
+
r.email_verified,
286
+
r.discord_verified,
287
+
r.telegram_verified,
288
+
r.signal_verified,
289
+
),
287
290
}))
288
291
}
289
292
···
308
311
309
312
async fn get_comms_prefs(&self, user_id: Uuid) -> Result<Option<UserCommsPrefs>, DbError> {
310
313
let row = sqlx::query!(
311
-
r#"SELECT email, handle, preferred_comms_channel::text as "preferred_channel!", preferred_locale
314
+
r#"SELECT email, handle, preferred_comms_channel as "preferred_channel!: CommsChannel", preferred_locale
312
315
FROM users WHERE id = $1"#,
313
316
user_id
314
317
)
···
601
604
let row = sqlx::query!(
602
605
r#"SELECT
603
606
email,
604
-
preferred_comms_channel::text as "preferred_channel!",
607
+
preferred_comms_channel as "preferred_channel!: CommsChannel",
605
608
discord_id,
606
609
discord_verified,
607
610
telegram_username,
···
647
650
async fn update_preferred_comms_channel(
648
651
&self,
649
652
did: &Did,
650
-
channel: &str,
653
+
channel: CommsChannel,
651
654
) -> Result<(), DbError> {
652
-
sqlx::query(
653
-
"UPDATE users SET preferred_comms_channel = $1::comms_channel, updated_at = NOW() WHERE did = $2",
655
+
sqlx::query!(
656
+
"UPDATE users SET preferred_comms_channel = $1, updated_at = NOW() WHERE did = $2",
657
+
channel as CommsChannel,
658
+
did.as_str()
654
659
)
655
-
.bind(channel)
656
-
.bind(did.as_str())
657
660
.execute(&self.pool)
658
661
.await
659
662
.map_err(map_sqlx_error)?;
···
709
712
id: r.id,
710
713
handle: Handle::from(r.handle),
711
714
email: r.email,
712
-
email_verified: r.email_verified,
713
-
discord_verified: r.discord_verified,
714
-
telegram_verified: r.telegram_verified,
715
-
signal_verified: r.signal_verified,
715
+
channel_verification: ChannelVerificationStatus::new(
716
+
r.email_verified,
717
+
r.discord_verified,
718
+
r.telegram_verified,
719
+
r.signal_verified,
720
+
),
716
721
}))
717
722
}
718
723
···
1065
1070
}))
1066
1071
}
1067
1072
1073
+
async fn get_totp_record_state(&self, did: &Did) -> Result<Option<TotpRecordState>, DbError> {
1074
+
self.get_totp_record(did)
1075
+
.await
1076
+
.map(|opt| opt.map(TotpRecordState::from))
1077
+
}
1078
+
1068
1079
async fn upsert_totp_secret(
1069
1080
&self,
1070
1081
did: &Did,
···
1300
1311
preferred_comms_channel as "preferred_comms_channel!: CommsChannel",
1301
1312
deactivated_at, takedown_ref,
1302
1313
email_verified, discord_verified, telegram_verified, signal_verified,
1303
-
account_type::text as "account_type!"
1314
+
account_type as "account_type!: AccountType"
1304
1315
FROM users
1305
1316
WHERE handle = $1 OR email = $1
1306
1317
"#,
···
1320
1331
preferred_comms_channel: row.preferred_comms_channel,
1321
1332
deactivated_at: row.deactivated_at,
1322
1333
takedown_ref: row.takedown_ref,
1323
-
email_verified: row.email_verified,
1324
-
discord_verified: row.discord_verified,
1325
-
telegram_verified: row.telegram_verified,
1326
-
signal_verified: row.signal_verified,
1334
+
channel_verification: ChannelVerificationStatus::new(
1335
+
row.email_verified,
1336
+
row.discord_verified,
1337
+
row.telegram_verified,
1338
+
row.signal_verified,
1339
+
),
1327
1340
account_type: row.account_type,
1328
1341
})
1329
1342
})
···
1348
1361
id: row.id,
1349
1362
two_factor_enabled: row.two_factor_enabled,
1350
1363
preferred_comms_channel: row.preferred_comms_channel,
1351
-
email_verified: row.email_verified,
1352
-
discord_verified: row.discord_verified,
1353
-
telegram_verified: row.telegram_verified,
1354
-
signal_verified: row.signal_verified,
1364
+
channel_verification: ChannelVerificationStatus::new(
1365
+
row.email_verified,
1366
+
row.discord_verified,
1367
+
row.telegram_verified,
1368
+
row.signal_verified,
1369
+
),
1355
1370
})
1356
1371
})
1357
1372
}
···
1376
1391
opt.map(|row| UserSessionInfo {
1377
1392
handle: Handle::from(row.handle),
1378
1393
email: row.email,
1379
-
email_verified: row.email_verified,
1380
1394
is_admin: row.is_admin,
1381
1395
deactivated_at: row.deactivated_at,
1382
1396
takedown_ref: row.takedown_ref,
1383
1397
preferred_locale: row.preferred_locale,
1384
1398
preferred_comms_channel: row.preferred_comms_channel,
1385
-
discord_verified: row.discord_verified,
1386
-
telegram_verified: row.telegram_verified,
1387
-
signal_verified: row.signal_verified,
1399
+
channel_verification: ChannelVerificationStatus::new(
1400
+
row.email_verified,
1401
+
row.discord_verified,
1402
+
row.telegram_verified,
1403
+
row.signal_verified,
1404
+
),
1388
1405
migrated_to_pds: row.migrated_to_pds,
1389
1406
migrated_at: row.migrated_at,
1390
1407
})
···
1469
1486
email: row.email,
1470
1487
deactivated_at: row.deactivated_at,
1471
1488
takedown_ref: row.takedown_ref,
1472
-
email_verified: row.email_verified,
1473
-
discord_verified: row.discord_verified,
1474
-
telegram_verified: row.telegram_verified,
1475
-
signal_verified: row.signal_verified,
1489
+
channel_verification: ChannelVerificationStatus::new(
1490
+
row.email_verified,
1491
+
row.discord_verified,
1492
+
row.telegram_verified,
1493
+
row.signal_verified,
1494
+
),
1476
1495
allow_legacy_login: row.allow_legacy_login,
1477
1496
migrated_to_pds: row.migrated_to_pds,
1478
1497
preferred_comms_channel: row.preferred_comms_channel,
···
1543
1562
discord_id: row.discord_id,
1544
1563
telegram_username: row.telegram_username,
1545
1564
signal_number: row.signal_number,
1546
-
email_verified: row.email_verified,
1547
-
discord_verified: row.discord_verified,
1548
-
telegram_verified: row.telegram_verified,
1549
-
signal_verified: row.signal_verified,
1565
+
channel_verification: ChannelVerificationStatus::new(
1566
+
row.email_verified,
1567
+
row.discord_verified,
1568
+
row.telegram_verified,
1569
+
row.signal_verified,
1570
+
),
1550
1571
})
1551
1572
})
1552
1573
}
+6
-4
crates/tranquil-oauth/src/lib.rs
+6
-4
crates/tranquil-oauth/src/lib.rs
···
10
10
};
11
11
pub use error::OAuthError;
12
12
pub use types::{
13
-
AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata,
14
-
AuthorizedClientData, ClientAuth, Code, DPoPClaims, DeviceData, DeviceId, JwkPublicKey, Jwks,
15
-
OAuthClientMetadata, ParResponse, ProtectedResourceMetadata, RefreshToken, RefreshTokenState,
16
-
RequestData, RequestId, SessionId, TokenData, TokenId, TokenRequest, TokenResponse,
13
+
AuthFlow, AuthFlowWithUser, AuthorizationRequestParameters, AuthorizationServerMetadata,
14
+
AuthorizedClientData, ClientAuth, Code, CodeChallengeMethod, DPoPClaims, DeviceData, DeviceId,
15
+
FlowAuthenticated, FlowAuthorized, FlowExpired, FlowNotAuthenticated, FlowNotAuthorized,
16
+
FlowPending, JwkPublicKey, Jwks, OAuthClientMetadata, ParResponse, Prompt,
17
+
ProtectedResourceMetadata, RefreshToken, RefreshTokenState, RequestData, RequestId,
18
+
ResponseMode, ResponseType, SessionId, TokenData, TokenId, TokenRequest, TokenResponse,
17
19
};
+247
-144
crates/tranquil-oauth/src/types.rs
+247
-144
crates/tranquil-oauth/src/types.rs
···
1
1
use chrono::{DateTime, Utc};
2
2
use serde::{Deserialize, Serialize};
3
3
use serde_json::Value as JsonValue;
4
+
use tranquil_types::Did;
4
5
5
-
#[derive(Debug, Clone, Serialize, Deserialize)]
6
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
7
+
#[serde(transparent)]
8
+
#[sqlx(transparent)]
6
9
pub struct RequestId(pub String);
7
10
8
-
#[derive(Debug, Clone, Serialize, Deserialize)]
11
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
12
+
#[serde(transparent)]
13
+
#[sqlx(transparent)]
9
14
pub struct TokenId(pub String);
10
15
11
-
#[derive(Debug, Clone, Serialize, Deserialize)]
16
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
17
+
#[serde(transparent)]
18
+
#[sqlx(transparent)]
12
19
pub struct DeviceId(pub String);
13
20
14
-
#[derive(Debug, Clone, Serialize, Deserialize)]
21
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
22
+
#[serde(transparent)]
23
+
#[sqlx(transparent)]
15
24
pub struct SessionId(pub String);
16
25
17
-
#[derive(Debug, Clone, Serialize, Deserialize)]
26
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
27
+
#[serde(transparent)]
28
+
#[sqlx(transparent)]
18
29
pub struct Code(pub String);
19
30
20
-
#[derive(Debug, Clone, Serialize, Deserialize)]
31
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
32
+
#[serde(transparent)]
33
+
#[sqlx(transparent)]
21
34
pub struct RefreshToken(pub String);
22
35
23
36
impl RequestId {
···
82
95
PrivateKeyJwt { client_assertion: String },
83
96
}
84
97
98
+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
99
+
#[serde(rename_all = "snake_case")]
100
+
pub enum ResponseType {
101
+
#[default]
102
+
Code,
103
+
}
104
+
105
+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
106
+
pub enum CodeChallengeMethod {
107
+
#[default]
108
+
#[serde(rename = "S256")]
109
+
S256,
110
+
#[serde(rename = "plain")]
111
+
Plain,
112
+
}
113
+
114
+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
115
+
#[serde(rename_all = "snake_case")]
116
+
pub enum ResponseMode {
117
+
#[default]
118
+
Query,
119
+
Fragment,
120
+
FormPost,
121
+
}
122
+
123
+
impl ResponseMode {
124
+
pub fn as_str(&self) -> &'static str {
125
+
match self {
126
+
Self::Query => "query",
127
+
Self::Fragment => "fragment",
128
+
Self::FormPost => "form_post",
129
+
}
130
+
}
131
+
}
132
+
133
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
134
+
#[serde(rename_all = "snake_case")]
135
+
pub enum Prompt {
136
+
None,
137
+
Login,
138
+
Consent,
139
+
SelectAccount,
140
+
Create,
141
+
}
142
+
143
+
impl Prompt {
144
+
pub fn as_str(&self) -> &'static str {
145
+
match self {
146
+
Self::None => "none",
147
+
Self::Login => "login",
148
+
Self::Consent => "consent",
149
+
Self::SelectAccount => "select_account",
150
+
Self::Create => "create",
151
+
}
152
+
}
153
+
}
154
+
85
155
#[derive(Debug, Clone, Serialize, Deserialize)]
86
156
pub struct AuthorizationRequestParameters {
87
-
pub response_type: String,
157
+
pub response_type: ResponseType,
88
158
pub client_id: String,
89
159
pub redirect_uri: String,
90
160
pub scope: Option<String>,
91
161
pub state: Option<String>,
92
162
pub code_challenge: String,
93
-
pub code_challenge_method: String,
94
-
pub response_mode: Option<String>,
163
+
pub code_challenge_method: CodeChallengeMethod,
164
+
pub response_mode: Option<ResponseMode>,
95
165
pub login_hint: Option<String>,
96
166
pub dpop_jkt: Option<String>,
97
-
pub prompt: Option<String>,
167
+
pub prompt: Option<Prompt>,
98
168
#[serde(flatten)]
99
169
pub extra: Option<JsonValue>,
100
170
}
···
105
175
pub client_auth: Option<ClientAuth>,
106
176
pub parameters: AuthorizationRequestParameters,
107
177
pub expires_at: DateTime<Utc>,
108
-
pub did: Option<String>,
109
-
pub device_id: Option<String>,
110
-
pub code: Option<String>,
111
-
pub controller_did: Option<String>,
178
+
pub did: Option<Did>,
179
+
pub device_id: Option<DeviceId>,
180
+
pub code: Option<Code>,
181
+
pub controller_did: Option<Did>,
112
182
}
113
183
114
184
#[derive(Debug, Clone)]
115
185
pub struct DeviceData {
116
-
pub session_id: String,
186
+
pub session_id: SessionId,
117
187
pub user_agent: Option<String>,
118
188
pub ip_address: String,
119
189
pub last_seen_at: DateTime<Utc>,
···
121
191
122
192
#[derive(Debug, Clone)]
123
193
pub struct TokenData {
124
-
pub did: String,
125
-
pub token_id: String,
194
+
pub did: Did,
195
+
pub token_id: TokenId,
126
196
pub created_at: DateTime<Utc>,
127
197
pub updated_at: DateTime<Utc>,
128
198
pub expires_at: DateTime<Utc>,
129
199
pub client_id: String,
130
200
pub client_auth: ClientAuth,
131
-
pub device_id: Option<String>,
201
+
pub device_id: Option<DeviceId>,
132
202
pub parameters: AuthorizationRequestParameters,
133
203
pub details: Option<JsonValue>,
134
-
pub code: Option<String>,
135
-
pub current_refresh_token: Option<String>,
204
+
pub code: Option<Code>,
205
+
pub current_refresh_token: Option<RefreshToken>,
136
206
pub scope: Option<String>,
137
-
pub controller_did: Option<String>,
207
+
pub controller_did: Option<Did>,
138
208
}
139
209
140
210
#[derive(Debug, Clone, Serialize, Deserialize)]
···
247
317
pub keys: Vec<JwkPublicKey>,
248
318
}
249
319
250
-
#[derive(Debug, Clone, PartialEq, Eq)]
251
-
pub enum AuthFlowState {
252
-
Pending,
253
-
Authenticated {
254
-
did: String,
255
-
device_id: Option<String>,
256
-
},
257
-
Authorized {
258
-
did: String,
259
-
device_id: Option<String>,
260
-
code: String,
261
-
},
262
-
Expired,
320
+
#[derive(Debug, Clone)]
321
+
pub struct FlowPending {
322
+
pub parameters: AuthorizationRequestParameters,
323
+
pub client_id: String,
324
+
pub client_auth: Option<ClientAuth>,
325
+
pub expires_at: DateTime<Utc>,
326
+
pub controller_did: Option<Did>,
263
327
}
264
328
265
-
impl AuthFlowState {
266
-
pub fn from_request_data(data: &RequestData) -> Self {
267
-
if data.expires_at < chrono::Utc::now() {
268
-
return AuthFlowState::Expired;
269
-
}
270
-
match (&data.did, &data.code) {
271
-
(Some(did), Some(code)) => AuthFlowState::Authorized {
272
-
did: did.clone(),
273
-
device_id: data.device_id.clone(),
274
-
code: code.clone(),
275
-
},
276
-
(Some(did), None) => AuthFlowState::Authenticated {
277
-
did: did.clone(),
278
-
device_id: data.device_id.clone(),
279
-
},
280
-
(None, _) => AuthFlowState::Pending,
281
-
}
282
-
}
329
+
#[derive(Debug, Clone)]
330
+
pub struct FlowAuthenticated {
331
+
pub parameters: AuthorizationRequestParameters,
332
+
pub client_id: String,
333
+
pub client_auth: Option<ClientAuth>,
334
+
pub expires_at: DateTime<Utc>,
335
+
pub did: Did,
336
+
pub device_id: Option<DeviceId>,
337
+
pub controller_did: Option<Did>,
338
+
}
283
339
284
-
pub fn is_pending(&self) -> bool {
285
-
matches!(self, AuthFlowState::Pending)
286
-
}
340
+
#[derive(Debug, Clone)]
341
+
pub struct FlowAuthorized {
342
+
pub parameters: AuthorizationRequestParameters,
343
+
pub client_id: String,
344
+
pub client_auth: Option<ClientAuth>,
345
+
pub expires_at: DateTime<Utc>,
346
+
pub did: Did,
347
+
pub device_id: Option<DeviceId>,
348
+
pub code: Code,
349
+
pub controller_did: Option<Did>,
350
+
}
287
351
288
-
pub fn is_authenticated(&self) -> bool {
289
-
matches!(self, AuthFlowState::Authenticated { .. })
290
-
}
352
+
#[derive(Debug)]
353
+
pub struct FlowExpired;
291
354
292
-
pub fn is_authorized(&self) -> bool {
293
-
matches!(self, AuthFlowState::Authorized { .. })
355
+
#[derive(Debug)]
356
+
pub struct FlowNotAuthenticated;
357
+
358
+
#[derive(Debug)]
359
+
pub struct FlowNotAuthorized;
360
+
361
+
#[derive(Debug, Clone)]
362
+
pub enum AuthFlow {
363
+
Pending(FlowPending),
364
+
Authenticated(FlowAuthenticated),
365
+
Authorized(FlowAuthorized),
366
+
}
367
+
368
+
#[derive(Debug, Clone)]
369
+
pub enum AuthFlowWithUser {
370
+
Authenticated(FlowAuthenticated),
371
+
Authorized(FlowAuthorized),
372
+
}
373
+
374
+
impl AuthFlow {
375
+
pub fn from_request_data(data: RequestData) -> Result<Self, FlowExpired> {
376
+
if data.expires_at < chrono::Utc::now() {
377
+
return Err(FlowExpired);
378
+
}
379
+
match (data.did, data.code) {
380
+
(None, _) => Ok(AuthFlow::Pending(FlowPending {
381
+
parameters: data.parameters,
382
+
client_id: data.client_id,
383
+
client_auth: data.client_auth,
384
+
expires_at: data.expires_at,
385
+
controller_did: data.controller_did,
386
+
})),
387
+
(Some(did), None) => Ok(AuthFlow::Authenticated(FlowAuthenticated {
388
+
parameters: data.parameters,
389
+
client_id: data.client_id,
390
+
client_auth: data.client_auth,
391
+
expires_at: data.expires_at,
392
+
did,
393
+
device_id: data.device_id,
394
+
controller_did: data.controller_did,
395
+
})),
396
+
(Some(did), Some(code)) => Ok(AuthFlow::Authorized(FlowAuthorized {
397
+
parameters: data.parameters,
398
+
client_id: data.client_id,
399
+
client_auth: data.client_auth,
400
+
expires_at: data.expires_at,
401
+
did,
402
+
device_id: data.device_id,
403
+
code,
404
+
controller_did: data.controller_did,
405
+
})),
406
+
}
294
407
}
295
408
296
-
pub fn is_expired(&self) -> bool {
297
-
matches!(self, AuthFlowState::Expired)
409
+
pub fn require_user(self) -> Result<AuthFlowWithUser, FlowNotAuthenticated> {
410
+
match self {
411
+
AuthFlow::Pending(_) => Err(FlowNotAuthenticated),
412
+
AuthFlow::Authenticated(a) => Ok(AuthFlowWithUser::Authenticated(a)),
413
+
AuthFlow::Authorized(a) => Ok(AuthFlowWithUser::Authorized(a)),
414
+
}
298
415
}
299
416
300
-
pub fn can_authenticate(&self) -> bool {
301
-
matches!(self, AuthFlowState::Pending)
417
+
pub fn require_authorized(self) -> Result<FlowAuthorized, FlowNotAuthorized> {
418
+
match self {
419
+
AuthFlow::Authorized(a) => Ok(a),
420
+
_ => Err(FlowNotAuthorized),
421
+
}
302
422
}
423
+
}
303
424
304
-
pub fn can_authorize(&self) -> bool {
305
-
matches!(self, AuthFlowState::Authenticated { .. })
425
+
impl AuthFlowWithUser {
426
+
pub fn did(&self) -> &Did {
427
+
match self {
428
+
AuthFlowWithUser::Authenticated(a) => &a.did,
429
+
AuthFlowWithUser::Authorized(a) => &a.did,
430
+
}
306
431
}
307
432
308
-
pub fn can_exchange(&self) -> bool {
309
-
matches!(self, AuthFlowState::Authorized { .. })
433
+
pub fn device_id(&self) -> Option<&DeviceId> {
434
+
match self {
435
+
AuthFlowWithUser::Authenticated(a) => a.device_id.as_ref(),
436
+
AuthFlowWithUser::Authorized(a) => a.device_id.as_ref(),
437
+
}
310
438
}
311
439
312
-
pub fn did(&self) -> Option<&str> {
440
+
pub fn parameters(&self) -> &AuthorizationRequestParameters {
313
441
match self {
314
-
AuthFlowState::Authenticated { did, .. } | AuthFlowState::Authorized { did, .. } => {
315
-
Some(did)
316
-
}
317
-
_ => None,
442
+
AuthFlowWithUser::Authenticated(a) => &a.parameters,
443
+
AuthFlowWithUser::Authorized(a) => &a.parameters,
318
444
}
319
445
}
320
446
321
-
pub fn code(&self) -> Option<&str> {
447
+
pub fn client_id(&self) -> &str {
322
448
match self {
323
-
AuthFlowState::Authorized { code, .. } => Some(code),
324
-
_ => None,
449
+
AuthFlowWithUser::Authenticated(a) => &a.client_id,
450
+
AuthFlowWithUser::Authorized(a) => &a.client_id,
325
451
}
326
452
}
327
-
}
328
453
329
-
impl std::fmt::Display for AuthFlowState {
330
-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454
+
pub fn controller_did(&self) -> Option<&Did> {
331
455
match self {
332
-
AuthFlowState::Pending => write!(f, "pending"),
333
-
AuthFlowState::Authenticated { did, .. } => write!(f, "authenticated ({})", did),
334
-
AuthFlowState::Authorized { did, code, .. } => {
335
-
write!(
336
-
f,
337
-
"authorized ({}, code={}...)",
338
-
did,
339
-
&code[..8.min(code.len())]
340
-
)
341
-
}
342
-
AuthFlowState::Expired => write!(f, "expired"),
456
+
AuthFlowWithUser::Authenticated(a) => a.controller_did.as_ref(),
457
+
AuthFlowWithUser::Authorized(a) => a.controller_did.as_ref(),
343
458
}
344
459
}
345
460
}
···
406
521
use chrono::{Duration, Utc};
407
522
408
523
fn make_request_data(
409
-
did: Option<String>,
410
-
code: Option<String>,
524
+
did: Option<Did>,
525
+
code: Option<Code>,
411
526
expires_in: Duration,
412
527
) -> RequestData {
413
528
RequestData {
414
529
client_id: "test-client".into(),
415
530
client_auth: None,
416
531
parameters: AuthorizationRequestParameters {
417
-
response_type: "code".into(),
532
+
response_type: ResponseType::Code,
418
533
client_id: "test-client".into(),
419
534
redirect_uri: "https://example.com/callback".into(),
420
535
scope: Some("atproto".into()),
421
536
state: None,
422
537
code_challenge: "test".into(),
423
-
code_challenge_method: "S256".into(),
538
+
code_challenge_method: CodeChallengeMethod::S256,
424
539
response_mode: None,
425
540
login_hint: None,
426
541
dpop_jkt: None,
···
435
550
}
436
551
}
437
552
553
+
fn test_did(s: &str) -> Did {
554
+
s.parse().expect("valid test DID")
555
+
}
556
+
557
+
fn test_code(s: &str) -> Code {
558
+
Code(s.to_string())
559
+
}
560
+
438
561
#[test]
439
-
fn test_auth_flow_state_pending() {
562
+
fn test_auth_flow_pending() {
440
563
let data = make_request_data(None, None, Duration::minutes(5));
441
-
let state = AuthFlowState::from_request_data(&data);
442
-
assert!(state.is_pending());
443
-
assert!(!state.is_authenticated());
444
-
assert!(!state.is_authorized());
445
-
assert!(!state.is_expired());
446
-
assert!(state.can_authenticate());
447
-
assert!(!state.can_authorize());
448
-
assert!(!state.can_exchange());
449
-
assert!(state.did().is_none());
450
-
assert!(state.code().is_none());
564
+
let flow = AuthFlow::from_request_data(data).expect("should not be expired");
565
+
assert!(matches!(flow, AuthFlow::Pending(_)));
566
+
assert!(flow.clone().require_user().is_err());
567
+
assert!(flow.require_authorized().is_err());
451
568
}
452
569
453
570
#[test]
454
-
fn test_auth_flow_state_authenticated() {
455
-
let data = make_request_data(Some("did:plc:test".into()), None, Duration::minutes(5));
456
-
let state = AuthFlowState::from_request_data(&data);
457
-
assert!(!state.is_pending());
458
-
assert!(state.is_authenticated());
459
-
assert!(!state.is_authorized());
460
-
assert!(!state.is_expired());
461
-
assert!(!state.can_authenticate());
462
-
assert!(state.can_authorize());
463
-
assert!(!state.can_exchange());
464
-
assert_eq!(state.did(), Some("did:plc:test"));
465
-
assert!(state.code().is_none());
571
+
fn test_auth_flow_authenticated() {
572
+
let did = test_did("did:plc:test");
573
+
let data = make_request_data(Some(did.clone()), None, Duration::minutes(5));
574
+
let flow = AuthFlow::from_request_data(data).expect("should not be expired");
575
+
assert!(matches!(flow, AuthFlow::Authenticated(_)));
576
+
let with_user = flow.clone().require_user().expect("should have user");
577
+
assert_eq!(with_user.did(), &did);
578
+
assert!(flow.require_authorized().is_err());
466
579
}
467
580
468
581
#[test]
469
-
fn test_auth_flow_state_authorized() {
470
-
let data = make_request_data(
471
-
Some("did:plc:test".into()),
472
-
Some("auth-code-123".into()),
473
-
Duration::minutes(5),
474
-
);
475
-
let state = AuthFlowState::from_request_data(&data);
476
-
assert!(!state.is_pending());
477
-
assert!(!state.is_authenticated());
478
-
assert!(state.is_authorized());
479
-
assert!(!state.is_expired());
480
-
assert!(!state.can_authenticate());
481
-
assert!(!state.can_authorize());
482
-
assert!(state.can_exchange());
483
-
assert_eq!(state.did(), Some("did:plc:test"));
484
-
assert_eq!(state.code(), Some("auth-code-123"));
582
+
fn test_auth_flow_authorized() {
583
+
let did = test_did("did:plc:test");
584
+
let code = test_code("auth-code-123");
585
+
let data = make_request_data(Some(did.clone()), Some(code.clone()), Duration::minutes(5));
586
+
let flow = AuthFlow::from_request_data(data).expect("should not be expired");
587
+
assert!(matches!(flow, AuthFlow::Authorized(_)));
588
+
let with_user = flow.clone().require_user().expect("should have user");
589
+
assert_eq!(with_user.did(), &did);
590
+
let authorized = flow.require_authorized().expect("should be authorized");
591
+
assert_eq!(authorized.did, did);
592
+
assert_eq!(authorized.code, code);
485
593
}
486
594
487
595
#[test]
488
-
fn test_auth_flow_state_expired() {
489
-
let data = make_request_data(
490
-
Some("did:plc:test".into()),
491
-
Some("code".into()),
492
-
Duration::minutes(-1),
493
-
);
494
-
let state = AuthFlowState::from_request_data(&data);
495
-
assert!(state.is_expired());
496
-
assert!(!state.can_authenticate());
497
-
assert!(!state.can_authorize());
498
-
assert!(!state.can_exchange());
596
+
fn test_auth_flow_expired() {
597
+
let did = test_did("did:plc:test");
598
+
let code = test_code("code");
599
+
let data = make_request_data(Some(did), Some(code), Duration::minutes(-1));
600
+
let result = AuthFlow::from_request_data(data);
601
+
assert!(result.is_err());
499
602
}
500
603
501
604
#[test]
+1
crates/tranquil-pds/Cargo.toml
+1
crates/tranquil-pds/Cargo.toml
+10
-12
crates/tranquil-pds/src/api/admin/account/delete.rs
+10
-12
crates/tranquil-pds/src/api/admin/account/delete.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::auth::{Admin, Auth};
4
4
use crate::state::AppState;
5
5
use crate::types::Did;
···
9
9
response::{IntoResponse, Response},
10
10
};
11
11
use serde::Deserialize;
12
-
use tracing::{error, warn};
12
+
use tracing::warn;
13
13
14
14
#[derive(Deserialize)]
15
15
pub struct DeleteAccountInput {
···
26
26
.user_repo
27
27
.get_id_and_handle_by_did(did)
28
28
.await
29
-
.map_err(|e| {
30
-
error!("DB error in delete_account: {:?}", e);
31
-
ApiError::InternalError(None)
32
-
})?
29
+
.log_db_err("in delete_account")?
33
30
.ok_or(ApiError::AccountNotFound)
34
31
.map(|row| (row.id, row.handle))?;
35
32
···
37
34
.user_repo
38
35
.admin_delete_account_complete(user_id, did)
39
36
.await
40
-
.map_err(|e| {
41
-
error!("Failed to delete account {}: {:?}", did, e);
42
-
ApiError::InternalError(Some("Failed to delete account".into()))
43
-
})?;
37
+
.log_db_err("deleting account")?;
44
38
45
-
if let Err(e) =
46
-
crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await
39
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
40
+
&state,
41
+
did,
42
+
tranquil_db_traits::AccountStatus::Deleted,
43
+
)
44
+
.await
47
45
{
48
46
warn!(
49
47
"Failed to sequence account deletion event for {}: {}",
+5
-7
crates/tranquil-pds/src/api/admin/account/email.rs
+5
-7
crates/tranquil-pds/src/api/admin/account/email.rs
···
1
-
use crate::api::error::{ApiError, AtpJson};
1
+
use crate::api::error::{ApiError, AtpJson, DbResultExt};
2
2
use crate::auth::{Admin, Auth};
3
3
use crate::state::AppState;
4
4
use crate::types::Did;
5
+
use crate::util::pds_hostname;
5
6
use axum::{
6
7
Json,
7
8
extract::State,
···
9
10
response::{IntoResponse, Response},
10
11
};
11
12
use serde::{Deserialize, Serialize};
12
-
use tracing::{error, warn};
13
+
use tracing::warn;
13
14
14
15
#[derive(Deserialize)]
15
16
#[serde(rename_all = "camelCase")]
···
39
40
.user_repo
40
41
.get_by_did(&input.recipient_did)
41
42
.await
42
-
.map_err(|e| {
43
-
error!("DB error in send_email: {:?}", e);
44
-
ApiError::InternalError(None)
45
-
})?
43
+
.log_db_err("in send_email")?
46
44
.ok_or(ApiError::AccountNotFound)?;
47
45
48
46
let email = user.email.ok_or(ApiError::NoEmail)?;
49
47
let (user_id, handle) = (user.id, user.handle);
50
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
48
+
let hostname = pds_hostname();
51
49
let subject = input
52
50
.subject
53
51
.clone()
+6
-13
crates/tranquil-pds/src/api/admin/account/info.rs
+6
-13
crates/tranquil-pds/src/api/admin/account/info.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::auth::{Admin, Auth};
3
3
use crate::state::AppState;
4
4
use crate::types::{Did, Handle};
···
10
10
};
11
11
use serde::{Deserialize, Serialize};
12
12
use std::collections::HashMap;
13
-
use tracing::error;
14
13
15
14
#[derive(Deserialize)]
16
15
pub struct GetAccountInfoParams {
···
74
73
.infra_repo
75
74
.get_admin_account_info_by_did(¶ms.did)
76
75
.await
77
-
.map_err(|e| {
78
-
error!("DB error in get_account_info: {:?}", e);
79
-
ApiError::InternalError(None)
80
-
})?
76
+
.log_db_err("in get_account_info")?
81
77
.ok_or(ApiError::AccountNotFound)?;
82
78
83
79
let invited_by = get_invited_by(&state, account.id).await;
···
153
149
.map(|ic| InviteCodeInfo {
154
150
code: ic.code.clone(),
155
151
available: ic.available_uses,
156
-
disabled: ic.disabled,
152
+
disabled: ic.state.is_disabled(),
157
153
for_account: ic.for_account,
158
154
created_by: ic.created_by,
159
155
created_at: ic.created_at.to_rfc3339(),
···
181
177
Some(InviteCodeInfo {
182
178
code: info.code,
183
179
available: info.available_uses,
184
-
disabled: info.disabled,
180
+
disabled: info.state.is_disabled(),
185
181
for_account: info.for_account,
186
182
created_by: info.created_by,
187
183
created_at: info.created_at.to_rfc3339(),
···
214
210
.infra_repo
215
211
.get_admin_account_infos_by_dids(&dids_typed)
216
212
.await
217
-
.map_err(|e| {
218
-
error!("Failed to fetch account infos: {:?}", e);
219
-
ApiError::InternalError(None)
220
-
})?;
213
+
.log_db_err("fetching account infos")?;
221
214
222
215
let user_ids: Vec<uuid::Uuid> = accounts.iter().map(|u| u.id).collect();
223
216
···
272
265
let info = InviteCodeInfo {
273
266
code: ic.code.clone(),
274
267
available: ic.available_uses,
275
-
disabled: ic.disabled,
268
+
disabled: ic.state.is_disabled(),
276
269
for_account: ic.for_account,
277
270
created_by: ic.created_by,
278
271
created_at: ic.created_at.to_rfc3339(),
+2
-6
crates/tranquil-pds/src/api/admin/account/search.rs
+2
-6
crates/tranquil-pds/src/api/admin/account/search.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::auth::{Admin, Auth};
3
3
use crate::state::AppState;
4
4
use crate::types::{Did, Handle};
···
9
9
response::{IntoResponse, Response},
10
10
};
11
11
use serde::{Deserialize, Serialize};
12
-
use tracing::error;
13
12
14
13
#[derive(Deserialize)]
15
14
pub struct SearchAccountsParams {
···
66
65
limit + 1,
67
66
)
68
67
.await
69
-
.map_err(|e| {
70
-
error!("DB error in search_accounts: {:?}", e);
71
-
ApiError::InternalError(None)
72
-
})?;
68
+
.log_db_err("in search_accounts")?;
73
69
74
70
let has_more = rows.len() > limit as usize;
75
71
let accounts: Vec<AccountView> = rows
+3
-3
crates/tranquil-pds/src/api/admin/account/update.rs
+3
-3
crates/tranquil-pds/src/api/admin/account/update.rs
···
3
3
use crate::auth::{Admin, Auth};
4
4
use crate::state::AppState;
5
5
use crate::types::{Did, Handle, PlainPassword};
6
+
use crate::util::pds_hostname_without_port;
6
7
use axum::{
7
8
Json,
8
9
extract::State,
···
69
70
{
70
71
return Err(ApiError::InvalidHandle(None));
71
72
}
72
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
73
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
73
+
let hostname_for_handles = pds_hostname_without_port();
74
74
let handle = if !input_handle.contains('.') {
75
75
format!("{}.{}", input_handle, hostname_for_handles)
76
76
} else {
···
84
84
.ok()
85
85
.flatten()
86
86
.ok_or(ApiError::AccountNotFound)?;
87
-
let handle_for_check = Handle::new_unchecked(&handle);
87
+
let handle_for_check = unsafe { Handle::new_unchecked(&handle) };
88
88
if let Ok(true) = state
89
89
.user_repo
90
90
.check_handle_exists(&handle_for_check, user_id)
+14
-50
crates/tranquil-pds/src/api/admin/config.rs
+14
-50
crates/tranquil-pds/src/api/admin/config.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::auth::{Admin, Auth};
3
3
use crate::state::AppState;
4
4
use axum::{Json, extract::State};
···
56
56
.infra_repo
57
57
.get_server_configs(keys)
58
58
.await
59
-
.map_err(|e| {
60
-
error!("DB error fetching server config: {:?}", e);
61
-
ApiError::InternalError(None)
62
-
})?;
59
+
.log_db_err("fetching server config")?;
63
60
64
61
let config_map: std::collections::HashMap<String, String> = rows.into_iter().collect();
65
62
···
92
89
.infra_repo
93
90
.upsert_server_config("server_name", trimmed)
94
91
.await
95
-
.map_err(|e| {
96
-
error!("DB error upserting server_name: {:?}", e);
97
-
ApiError::InternalError(None)
98
-
})?;
92
+
.log_db_err("upserting server_name")?;
99
93
}
100
94
101
95
if let Some(ref color) = req.primary_color {
···
104
98
.infra_repo
105
99
.delete_server_config("primary_color")
106
100
.await
107
-
.map_err(|e| {
108
-
error!("DB error deleting primary_color: {:?}", e);
109
-
ApiError::InternalError(None)
110
-
})?;
101
+
.log_db_err("deleting primary_color")?;
111
102
} else if is_valid_hex_color(color) {
112
103
state
113
104
.infra_repo
114
105
.upsert_server_config("primary_color", color)
115
106
.await
116
-
.map_err(|e| {
117
-
error!("DB error upserting primary_color: {:?}", e);
118
-
ApiError::InternalError(None)
119
-
})?;
107
+
.log_db_err("upserting primary_color")?;
120
108
} else {
121
109
return Err(ApiError::InvalidRequest(
122
110
"Invalid primary color format (expected #RRGGBB)".into(),
···
130
118
.infra_repo
131
119
.delete_server_config("primary_color_dark")
132
120
.await
133
-
.map_err(|e| {
134
-
error!("DB error deleting primary_color_dark: {:?}", e);
135
-
ApiError::InternalError(None)
136
-
})?;
121
+
.log_db_err("deleting primary_color_dark")?;
137
122
} else if is_valid_hex_color(color) {
138
123
state
139
124
.infra_repo
140
125
.upsert_server_config("primary_color_dark", color)
141
126
.await
142
-
.map_err(|e| {
143
-
error!("DB error upserting primary_color_dark: {:?}", e);
144
-
ApiError::InternalError(None)
145
-
})?;
127
+
.log_db_err("upserting primary_color_dark")?;
146
128
} else {
147
129
return Err(ApiError::InvalidRequest(
148
130
"Invalid primary dark color format (expected #RRGGBB)".into(),
···
156
138
.infra_repo
157
139
.delete_server_config("secondary_color")
158
140
.await
159
-
.map_err(|e| {
160
-
error!("DB error deleting secondary_color: {:?}", e);
161
-
ApiError::InternalError(None)
162
-
})?;
141
+
.log_db_err("deleting secondary_color")?;
163
142
} else if is_valid_hex_color(color) {
164
143
state
165
144
.infra_repo
166
145
.upsert_server_config("secondary_color", color)
167
146
.await
168
-
.map_err(|e| {
169
-
error!("DB error upserting secondary_color: {:?}", e);
170
-
ApiError::InternalError(None)
171
-
})?;
147
+
.log_db_err("upserting secondary_color")?;
172
148
} else {
173
149
return Err(ApiError::InvalidRequest(
174
150
"Invalid secondary color format (expected #RRGGBB)".into(),
···
182
158
.infra_repo
183
159
.delete_server_config("secondary_color_dark")
184
160
.await
185
-
.map_err(|e| {
186
-
error!("DB error deleting secondary_color_dark: {:?}", e);
187
-
ApiError::InternalError(None)
188
-
})?;
161
+
.log_db_err("deleting secondary_color_dark")?;
189
162
} else if is_valid_hex_color(color) {
190
163
state
191
164
.infra_repo
192
165
.upsert_server_config("secondary_color_dark", color)
193
166
.await
194
-
.map_err(|e| {
195
-
error!("DB error upserting secondary_color_dark: {:?}", e);
196
-
ApiError::InternalError(None)
197
-
})?;
167
+
.log_db_err("upserting secondary_color_dark")?;
198
168
} else {
199
169
return Err(ApiError::InvalidRequest(
200
170
"Invalid secondary dark color format (expected #RRGGBB)".into(),
···
217
187
};
218
188
219
189
if let Some(old_cid_str) = should_delete_old {
220
-
let old_cid = CidLink::new_unchecked(old_cid_str);
190
+
let old_cid = unsafe { CidLink::new_unchecked(old_cid_str) };
221
191
if let Ok(Some(storage_key)) =
222
192
state.infra_repo.get_blob_storage_key_by_cid(&old_cid).await
223
193
{
···
235
205
.infra_repo
236
206
.delete_server_config("logo_cid")
237
207
.await
238
-
.map_err(|e| {
239
-
error!("DB error deleting logo_cid: {:?}", e);
240
-
ApiError::InternalError(None)
241
-
})?;
208
+
.log_db_err("deleting logo_cid")?;
242
209
} else {
243
210
state
244
211
.infra_repo
245
212
.upsert_server_config("logo_cid", logo_cid)
246
213
.await
247
-
.map_err(|e| {
248
-
error!("DB error upserting logo_cid: {:?}", e);
249
-
ApiError::InternalError(None)
250
-
})?;
214
+
.log_db_err("upserting logo_cid")?;
251
215
}
252
216
}
253
217
+3
-6
crates/tranquil-pds/src/api/admin/invite.rs
+3
-6
crates/tranquil-pds/src/api/admin/invite.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::auth::{Admin, Auth};
4
4
use crate::state::AppState;
5
5
use axum::{
···
91
91
.infra_repo
92
92
.list_invite_codes(params.cursor.as_deref(), limit, sort_order)
93
93
.await
94
-
.map_err(|e| {
95
-
error!("DB error fetching invite codes: {:?}", e);
96
-
ApiError::InternalError(None)
97
-
})?;
94
+
.log_db_err("fetching invite codes")?;
98
95
99
96
let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|r| r.created_by_user).collect();
100
97
let code_strings: Vec<String> = codes_rows.iter().map(|r| r.code.clone()).collect();
···
138
135
InviteCodeInfo {
139
136
code: r.code.clone(),
140
137
available: r.available_uses,
141
-
disabled: r.disabled.unwrap_or(false),
138
+
disabled: r.state().is_disabled(),
142
139
for_account: creator_did.clone(),
143
140
created_by: creator_did,
144
141
created_at: r.created_at.to_rfc3339(),
+9
-19
crates/tranquil-pds/src/api/admin/status.rs
+9
-19
crates/tranquil-pds/src/api/admin/status.rs
···
175
175
Some("com.atproto.admin.defs#repoRef") => {
176
176
let did_str = input.subject.get("did").and_then(|d| d.as_str());
177
177
if let Some(did_str) = did_str {
178
-
let did = Did::new_unchecked(did_str);
178
+
let did = unsafe { Did::new_unchecked(did_str) };
179
179
if let Some(takedown) = &input.takedown {
180
180
let takedown_ref = if takedown.applied {
181
181
takedown.r#ref.as_deref()
···
207
207
}
208
208
if let Some(takedown) = &input.takedown {
209
209
let status = if takedown.applied {
210
-
Some("takendown")
210
+
tranquil_db_traits::AccountStatus::Takendown
211
211
} else {
212
-
None
212
+
tranquil_db_traits::AccountStatus::Active
213
213
};
214
-
if let Err(e) = crate::api::repo::record::sequence_account_event(
215
-
&state,
216
-
&did,
217
-
!takedown.applied,
218
-
status,
219
-
)
220
-
.await
214
+
if let Err(e) =
215
+
crate::api::repo::record::sequence_account_event(&state, &did, status).await
221
216
{
222
217
warn!("Failed to sequence account event for takedown: {}", e);
223
218
}
224
219
}
225
220
if let Some(deactivated) = &input.deactivated {
226
221
let status = if deactivated.applied {
227
-
Some("deactivated")
222
+
tranquil_db_traits::AccountStatus::Deactivated
228
223
} else {
229
-
None
224
+
tranquil_db_traits::AccountStatus::Active
230
225
};
231
-
if let Err(e) = crate::api::repo::record::sequence_account_event(
232
-
&state,
233
-
&did,
234
-
!deactivated.applied,
235
-
status,
236
-
)
237
-
.await
226
+
if let Err(e) =
227
+
crate::api::repo::record::sequence_account_event(&state, &did, status).await
238
228
{
239
229
warn!("Failed to sequence account event for deactivation: {}", e);
240
230
}
+2
-2
crates/tranquil-pds/src/api/age_assurance.rs
+2
-2
crates/tranquil-pds/src/api/age_assurance.rs
···
33
33
}
34
34
35
35
async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option<String> {
36
-
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
36
+
let auth_header = crate::util::get_header_str(headers, "Authorization");
37
37
tracing::debug!(?auth_header, "age assurance: extracting token");
38
38
39
39
let extracted = extract_auth_token_from_header(auth_header)?;
40
40
tracing::debug!("age assurance: got token, validating");
41
41
42
-
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
42
+
let dpop_proof = crate::util::get_header_str(headers, "DPoP");
43
43
let http_uri = "/";
44
44
45
45
let auth_user = match validate_token_with_dpop(
+66
-116
crates/tranquil-pds/src/api/delegation.rs
+66
-116
crates/tranquil-pds/src/api/delegation.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::api::repo::record::utils::create_signed_commit;
3
3
use crate::auth::{Active, Auth};
4
-
use crate::delegation::{DelegationActionType, SCOPE_PRESETS, scopes};
5
-
use crate::state::{AppState, RateLimitKind};
4
+
use crate::delegation::{
5
+
DelegationActionType, SCOPE_PRESETS, ValidatedDelegationScope, verify_can_add_controllers,
6
+
verify_can_be_controller, verify_can_control_accounts,
7
+
};
8
+
use crate::rate_limit::{AccountCreationLimit, RateLimited};
9
+
use crate::state::AppState;
6
10
use crate::types::{Did, Handle, Nsid, Rkey};
7
-
use crate::util::extract_client_ip;
11
+
use crate::util::{pds_hostname, pds_hostname_without_port};
8
12
use axum::{
9
13
Json,
10
14
extract::{Query, State},
11
-
http::{HeaderMap, StatusCode},
15
+
http::StatusCode,
12
16
response::{IntoResponse, Response},
13
17
};
14
18
use jacquard_common::types::{integer::LimitedU32, string::Tid};
···
57
61
.map(|c| ControllerInfo {
58
62
did: c.did,
59
63
handle: c.handle,
60
-
granted_scopes: c.granted_scopes,
64
+
granted_scopes: c.granted_scopes.into_string(),
61
65
granted_at: c.granted_at,
62
66
is_active: c.is_active,
63
67
})
···
69
73
#[derive(Debug, Deserialize)]
70
74
pub struct AddControllerInput {
71
75
pub controller_did: Did,
72
-
pub granted_scopes: String,
76
+
pub granted_scopes: ValidatedDelegationScope,
73
77
}
74
78
75
79
pub async fn add_controller(
···
77
81
auth: Auth<Active>,
78
82
Json(input): Json<AddControllerInput>,
79
83
) -> Result<Response, ApiError> {
80
-
if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) {
81
-
return Ok(ApiError::InvalidScopes(e).into_response());
82
-
}
83
-
84
84
let controller_exists = state
85
85
.user_repo
86
86
.get_by_did(&input.controller_did)
···
93
93
return Ok(ApiError::ControllerNotFound.into_response());
94
94
}
95
95
96
-
match state.delegation_repo.controls_any_accounts(&auth.did).await {
97
-
Ok(true) => {
98
-
return Ok(ApiError::InvalidDelegation(
99
-
"Cannot add controllers to an account that controls other accounts".into(),
100
-
)
101
-
.into_response());
102
-
}
103
-
Err(e) => {
104
-
tracing::error!("Failed to check delegation status: {:?}", e);
105
-
return Ok(
106
-
ApiError::InternalError(Some("Failed to verify delegation status".into()))
107
-
.into_response(),
108
-
);
109
-
}
110
-
Ok(false) => {}
111
-
}
96
+
let can_add = match verify_can_add_controllers(&state, &auth).await {
97
+
Ok(proof) => proof,
98
+
Err(response) => return Ok(response),
99
+
};
112
100
113
-
match state
114
-
.delegation_repo
115
-
.has_any_controllers(&input.controller_did)
116
-
.await
117
-
{
118
-
Ok(true) => {
119
-
return Ok(ApiError::InvalidDelegation(
120
-
"Cannot add a controlled account as a controller".into(),
121
-
)
122
-
.into_response());
123
-
}
124
-
Err(e) => {
125
-
tracing::error!("Failed to check controller status: {:?}", e);
126
-
return Ok(
127
-
ApiError::InternalError(Some("Failed to verify controller status".into()))
128
-
.into_response(),
129
-
);
130
-
}
131
-
Ok(false) => {}
132
-
}
101
+
let can_be_controller = match verify_can_be_controller(&state, &input.controller_did).await {
102
+
Ok(proof) => proof,
103
+
Err(response) => return Ok(response),
104
+
};
133
105
134
106
match state
135
107
.delegation_repo
136
108
.create_delegation(
137
-
&auth.did,
138
-
&input.controller_did,
109
+
can_add.did(),
110
+
can_be_controller.did(),
139
111
&input.granted_scopes,
140
-
&auth.did,
112
+
can_add.did(),
141
113
)
142
114
.await
143
115
{
···
145
117
let _ = state
146
118
.delegation_repo
147
119
.log_delegation_action(
148
-
&auth.did,
149
-
&auth.did,
150
-
Some(&input.controller_did),
120
+
can_add.did(),
121
+
can_add.did(),
122
+
Some(can_be_controller.did()),
151
123
DelegationActionType::GrantCreated,
152
124
Some(serde_json::json!({
153
-
"granted_scopes": input.granted_scopes
125
+
"granted_scopes": input.granted_scopes.as_str()
154
126
})),
155
127
None,
156
128
None,
···
235
207
#[derive(Debug, Deserialize)]
236
208
pub struct UpdateControllerScopesInput {
237
209
pub controller_did: Did,
238
-
pub granted_scopes: String,
210
+
pub granted_scopes: ValidatedDelegationScope,
239
211
}
240
212
241
213
pub async fn update_controller_scopes(
···
243
215
auth: Auth<Active>,
244
216
Json(input): Json<UpdateControllerScopesInput>,
245
217
) -> Result<Response, ApiError> {
246
-
if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) {
247
-
return Ok(ApiError::InvalidScopes(e).into_response());
248
-
}
249
-
250
218
match state
251
219
.delegation_repo
252
220
.update_delegation_scopes(&auth.did, &input.controller_did, &input.granted_scopes)
···
261
229
Some(&input.controller_did),
262
230
DelegationActionType::ScopesModified,
263
231
Some(serde_json::json!({
264
-
"new_scopes": input.granted_scopes
232
+
"new_scopes": input.granted_scopes.as_str()
265
233
})),
266
234
None,
267
235
None,
···
326
294
.map(|a| DelegatedAccountInfo {
327
295
did: a.did,
328
296
handle: a.handle,
329
-
granted_scopes: a.granted_scopes,
297
+
granted_scopes: a.granted_scopes.into_string(),
330
298
granted_at: a.granted_at,
331
299
})
332
300
.collect(),
···
443
411
pub struct CreateDelegatedAccountInput {
444
412
pub handle: String,
445
413
pub email: Option<String>,
446
-
pub controller_scopes: String,
414
+
pub controller_scopes: ValidatedDelegationScope,
447
415
pub invite_code: Option<String>,
448
416
}
449
417
···
456
424
457
425
pub async fn create_delegated_account(
458
426
State(state): State<AppState>,
459
-
headers: HeaderMap,
427
+
_rate_limit: RateLimited<AccountCreationLimit>,
460
428
auth: Auth<Active>,
461
429
Json(input): Json<CreateDelegatedAccountInput>,
462
430
) -> Result<Response, ApiError> {
463
-
let client_ip = extract_client_ip(&headers);
464
-
if !state
465
-
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
466
-
.await
467
-
{
468
-
warn!(ip = %client_ip, "Delegated account creation rate limit exceeded");
469
-
return Ok(ApiError::RateLimitExceeded(Some(
470
-
"Too many account creation attempts. Please try again later.".into(),
471
-
))
472
-
.into_response());
473
-
}
474
-
475
-
if let Err(e) = scopes::validate_delegation_scopes(&input.controller_scopes) {
476
-
return Ok(ApiError::InvalidScopes(e).into_response());
477
-
}
478
-
479
-
match state.delegation_repo.has_any_controllers(&auth.did).await {
480
-
Ok(true) => {
481
-
return Ok(ApiError::InvalidDelegation(
482
-
"Cannot create delegated accounts from a controlled account".into(),
483
-
)
484
-
.into_response());
485
-
}
486
-
Err(e) => {
487
-
tracing::error!("Failed to check controller status: {:?}", e);
488
-
return Ok(
489
-
ApiError::InternalError(Some("Failed to verify controller status".into()))
490
-
.into_response(),
491
-
);
492
-
}
493
-
Ok(false) => {}
494
-
}
431
+
let can_control = match verify_can_control_accounts(&state, &auth).await {
432
+
Ok(proof) => proof,
433
+
Err(response) => return Ok(response),
434
+
};
495
435
496
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
497
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
436
+
let hostname = pds_hostname();
437
+
let hostname_for_handles = pds_hostname_without_port();
498
438
let pds_suffix = format!(".{}", hostname_for_handles);
499
439
500
440
let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) {
···
527
467
return Ok(ApiError::InvalidEmail.into_response());
528
468
}
529
469
530
-
if let Some(ref code) = input.invite_code {
531
-
let valid = state
532
-
.infra_repo
533
-
.is_invite_code_valid(code)
534
-
.await
535
-
.unwrap_or(false);
536
-
537
-
if !valid {
538
-
return Ok(ApiError::InvalidInviteCode.into_response());
470
+
let validated_invite_code = if let Some(ref code) = input.invite_code {
471
+
match state.infra_repo.validate_invite_code(code).await {
472
+
Ok(validated) => Some(validated),
473
+
Err(_) => return Ok(ApiError::InvalidInviteCode.into_response()),
539
474
}
540
475
} else {
541
476
let invite_required = std::env::var("INVITE_CODE_REQUIRED")
···
544
479
if invite_required {
545
480
return Ok(ApiError::InviteCodeRequired.into_response());
546
481
}
547
-
}
482
+
None
483
+
};
548
484
549
485
use k256::ecdsa::SigningKey;
550
486
use rand::rngs::OsRng;
···
593
529
.into_response());
594
530
}
595
531
596
-
let did = Did::new_unchecked(&genesis_result.did);
597
-
let handle = Handle::new_unchecked(&handle);
598
-
info!(did = %did, handle = %handle, controller = %&auth.did, "Created DID for delegated account");
532
+
let did = unsafe { Did::new_unchecked(&genesis_result.did) };
533
+
let handle = unsafe { Handle::new_unchecked(&handle) };
534
+
info!(did = %did, handle = %handle, controller = %can_control.did(), "Created DID for delegated account");
599
535
600
536
let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) {
601
537
Ok(bytes) => bytes,
···
635
571
handle: handle.clone(),
636
572
email: email.clone(),
637
573
did: did.clone(),
638
-
controller_did: auth.did.clone(),
639
-
controller_scopes: input.controller_scopes.clone(),
574
+
controller_did: can_control.did().clone(),
575
+
controller_scopes: input.controller_scopes.as_str().to_string(),
640
576
encrypted_key_bytes,
641
577
encryption_version: crate::config::ENCRYPTION_VERSION,
642
578
commit_cid: commit_cid.to_string(),
···
645
581
invite_code: input.invite_code.clone(),
646
582
};
647
583
648
-
let _user_id = match state
584
+
let user_id = match state
649
585
.user_repo
650
586
.create_delegated_account(&create_input)
651
587
.await
···
663
599
}
664
600
};
665
601
602
+
if let Some(validated) = validated_invite_code
603
+
&& let Err(e) = state
604
+
.infra_repo
605
+
.record_invite_code_use(&validated, user_id)
606
+
.await
607
+
{
608
+
warn!("Failed to record invite code use for {}: {:?}", did, e);
609
+
}
610
+
666
611
if let Err(e) =
667
612
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle)).await
668
613
{
669
614
warn!("Failed to sequence identity event for {}: {}", did, e);
670
615
}
671
-
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await
616
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
617
+
&state,
618
+
&did,
619
+
tranquil_db_traits::AccountStatus::Active,
620
+
)
621
+
.await
672
622
{
673
623
warn!("Failed to sequence account event for {}: {}", did, e);
674
624
}
···
677
627
"$type": "app.bsky.actor.profile",
678
628
"displayName": handle
679
629
});
680
-
let profile_collection = Nsid::new_unchecked("app.bsky.actor.profile");
681
-
let profile_rkey = Rkey::new_unchecked("self");
630
+
let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") };
631
+
let profile_rkey = unsafe { Rkey::new_unchecked("self") };
682
632
if let Err(e) = crate::api::repo::record::create_record_internal(
683
633
&state,
684
634
&did,
···
700
650
DelegationActionType::GrantCreated,
701
651
Some(json!({
702
652
"account_created": true,
703
-
"granted_scopes": input.controller_scopes
653
+
"granted_scopes": input.controller_scopes.as_str()
704
654
})),
705
655
None,
706
656
None,
+19
crates/tranquil-pds/src/api/error.rs
+19
crates/tranquil-pds/src/api/error.rs
···
694
694
}
695
695
}
696
696
697
+
impl From<crate::rate_limit::UserRateLimitError> for ApiError {
698
+
fn from(e: crate::rate_limit::UserRateLimitError) -> Self {
699
+
Self::RateLimitExceeded(e.message)
700
+
}
701
+
}
702
+
697
703
#[allow(clippy::result_large_err)]
698
704
pub fn parse_did(s: &str) -> Result<tranquil_types::Did, Response> {
699
705
s.parse()
···
756
762
_ => "Invalid request body".to_string(),
757
763
}
758
764
}
765
+
766
+
pub trait DbResultExt<T> {
767
+
fn log_db_err(self, ctx: &str) -> Result<T, ApiError>;
768
+
}
769
+
770
+
impl<T, E: std::fmt::Debug> DbResultExt<T> for Result<T, E> {
771
+
fn log_db_err(self, ctx: &str) -> Result<T, ApiError> {
772
+
self.map_err(|e| {
773
+
tracing::error!("DB error {}: {:?}", ctx, e);
774
+
ApiError::DatabaseError
775
+
})
776
+
}
777
+
}
+40
-65
crates/tranquil-pds/src/api/identity/account.rs
+40
-65
crates/tranquil-pds/src/api/identity/account.rs
···
3
3
use crate::api::repo::record::utils::create_signed_commit;
4
4
use crate::auth::{ServiceTokenVerifier, extract_auth_token_from_header, is_service_token};
5
5
use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key};
6
-
use crate::state::{AppState, RateLimitKind};
6
+
use crate::rate_limit::{AccountCreationLimit, RateLimited};
7
+
use crate::state::AppState;
7
8
use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey};
9
+
use crate::util::{pds_hostname, pds_hostname_without_port};
8
10
use crate::validation::validate_password;
9
11
use axum::{
10
12
Json,
···
22
24
use std::sync::Arc;
23
25
use tracing::{debug, error, info, warn};
24
26
25
-
fn extract_client_ip(headers: &HeaderMap) -> String {
26
-
if let Some(forwarded) = headers.get("x-forwarded-for")
27
-
&& let Ok(value) = forwarded.to_str()
28
-
&& let Some(first_ip) = value.split(',').next()
29
-
{
30
-
return first_ip.trim().to_string();
31
-
}
32
-
if let Some(real_ip) = headers.get("x-real-ip")
33
-
&& let Ok(value) = real_ip.to_str()
34
-
{
35
-
return value.trim().to_string();
36
-
}
37
-
"unknown".to_string()
38
-
}
39
-
40
27
#[derive(Deserialize)]
41
28
#[serde(rename_all = "camelCase")]
42
29
pub struct CreateAccountInput {
···
68
55
69
56
pub async fn create_account(
70
57
State(state): State<AppState>,
58
+
_rate_limit: RateLimited<AccountCreationLimit>,
71
59
headers: HeaderMap,
72
60
Json(input): Json<CreateAccountInput>,
73
61
) -> Response {
···
84
72
} else {
85
73
info!("create_account called");
86
74
}
87
-
let client_ip = extract_client_ip(&headers);
88
-
if !state
89
-
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
90
-
.await
91
-
{
92
-
warn!(ip = %client_ip, "Account creation rate limit exceeded");
93
-
return ApiError::RateLimitExceeded(Some(
94
-
"Too many account creation attempts. Please try again later.".into(),
95
-
))
96
-
.into_response();
97
-
}
98
75
99
76
let migration_auth = if let Some(extracted) =
100
-
extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok()))
77
+
extract_auth_token_from_header(crate::util::get_header_str(&headers, "Authorization"))
101
78
{
102
79
let token = extracted.token;
103
80
if is_service_token(&token) {
···
143
120
if (is_migration || is_did_web_byod)
144
121
&& let (Some(provided_did), Some(auth_did)) = (input.did.as_ref(), migration_auth.as_ref())
145
122
{
146
-
if provided_did != auth_did {
123
+
if provided_did != auth_did.as_str() {
147
124
info!(
148
125
"[MIGRATION] createAccount: Service token mismatch - token_did={} provided_did={}",
149
126
auth_did, provided_did
···
164
141
}
165
142
}
166
143
167
-
let hostname_for_validation =
168
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
144
+
let hostname_for_validation = pds_hostname_without_port();
169
145
let pds_suffix = format!(".{}", hostname_for_validation);
170
146
171
147
let validated_short_handle = if !input.handle.contains('.')
···
242
218
_ => return ApiError::InvalidVerificationChannel.into_response(),
243
219
})
244
220
};
245
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
246
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
221
+
let hostname = pds_hostname();
222
+
let hostname_for_handles = pds_hostname_without_port();
247
223
let pds_endpoint = format!("https://{}", hostname);
248
224
let suffix = format!(".{}", hostname_for_handles);
249
225
let handle = if input.handle.ends_with(&suffix) {
···
308
284
}
309
285
if !is_did_web_byod
310
286
&& let Err(e) =
311
-
verify_did_web(d, &hostname, &input.handle, input.signing_key.as_deref()).await
287
+
verify_did_web(d, hostname, &input.handle, input.signing_key.as_deref()).await
312
288
{
313
289
return ApiError::InvalidDid(e).into_response();
314
290
}
···
322
298
d.clone()
323
299
} else if d.starts_with("did:web:") {
324
300
if !is_did_web_byod
325
-
&& let Err(e) = verify_did_web(
326
-
d,
327
-
&hostname,
328
-
&input.handle,
329
-
input.signing_key.as_deref(),
330
-
)
331
-
.await
301
+
&& let Err(e) =
302
+
verify_did_web(d, hostname, &input.handle, input.signing_key.as_deref())
303
+
.await
332
304
{
333
305
return ApiError::InvalidDid(e).into_response();
334
306
}
···
408
380
};
409
381
if is_migration {
410
382
let reactivate_input = tranquil_db_traits::MigrationReactivationInput {
411
-
did: Did::new_unchecked(&did),
412
-
new_handle: Handle::new_unchecked(&handle),
383
+
did: unsafe { Did::new_unchecked(&did) },
384
+
new_handle: unsafe { Handle::new_unchecked(&handle) },
413
385
new_email: email.clone(),
414
386
};
415
387
match state
···
463
435
}
464
436
};
465
437
let session_data = tranquil_db_traits::SessionTokenCreate {
466
-
did: Did::new_unchecked(&did),
438
+
did: unsafe { Did::new_unchecked(&did) },
467
439
access_jti: access_meta.jti.clone(),
468
440
refresh_jti: refresh_meta.jti.clone(),
469
441
access_expires_at: access_meta.expires_at,
470
442
refresh_expires_at: refresh_meta.expires_at,
471
-
legacy_login: false,
443
+
login_type: tranquil_db_traits::LoginType::Modern,
472
444
mfa_verified: false,
473
445
scope: None,
474
446
controller_did: None,
···
478
450
error!("Error creating session: {:?}", e);
479
451
return ApiError::InternalError(None).into_response();
480
452
}
481
-
let hostname =
482
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
453
+
let hostname = pds_hostname();
483
454
let verification_required = if let Some(ref user_email) = email {
484
455
let token =
485
456
crate::auth::verification_token::generate_migration_token(&did, user_email);
···
491
462
reactivated.user_id,
492
463
user_email,
493
464
&formatted_token,
494
-
&hostname,
465
+
hostname,
495
466
)
496
467
.await
497
468
{
···
505
476
axum::http::StatusCode::OK,
506
477
Json(CreateAccountOutput {
507
478
handle: handle.clone().into(),
508
-
did: Did::new_unchecked(&did),
479
+
did: unsafe { Did::new_unchecked(&did) },
509
480
did_doc: state.did_resolver.resolve_did_document(&did).await,
510
481
access_jwt: access_meta.token,
511
482
refresh_jwt: refresh_meta.token,
···
529
500
}
530
501
}
531
502
532
-
let handle_typed = Handle::new_unchecked(&handle);
503
+
let handle_typed = unsafe { Handle::new_unchecked(&handle) };
533
504
let handle_available = match state
534
505
.user_repo
535
506
.check_handle_available_for_new_account(&handle_typed)
···
613
584
}
614
585
};
615
586
let rev = Tid::now(LimitedU32::MIN);
616
-
let did_for_commit = Did::new_unchecked(&did);
587
+
let did_for_commit = unsafe { Did::new_unchecked(&did) };
617
588
let (commit_bytes, _sig) =
618
589
match create_signed_commit(&did_for_commit, mst_root, rev.as_ref(), None, &signing_key) {
619
590
Ok(result) => result,
···
649
620
};
650
621
651
622
let create_input = tranquil_db_traits::CreatePasswordAccountInput {
652
-
handle: Handle::new_unchecked(&handle),
623
+
handle: unsafe { Handle::new_unchecked(&handle) },
653
624
email: email.clone(),
654
-
did: Did::new_unchecked(&did),
625
+
did: unsafe { Did::new_unchecked(&did) },
655
626
password_hash,
656
627
preferred_comms_channel,
657
628
discord_id: input
···
701
672
};
702
673
let user_id = create_result.user_id;
703
674
if !is_migration && !is_did_web_byod {
704
-
let did_typed = Did::new_unchecked(&did);
705
-
let handle_typed = Handle::new_unchecked(&handle);
675
+
let did_typed = unsafe { Did::new_unchecked(&did) };
676
+
let handle_typed = unsafe { Handle::new_unchecked(&handle) };
706
677
if let Err(e) = crate::api::repo::record::sequence_identity_event(
707
678
&state,
708
679
&did_typed,
···
712
683
{
713
684
warn!("Failed to sequence identity event for {}: {}", did, e);
714
685
}
715
-
if let Err(e) =
716
-
crate::api::repo::record::sequence_account_event(&state, &did_typed, true, None).await
686
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
687
+
&state,
688
+
&did_typed,
689
+
tranquil_db_traits::AccountStatus::Active,
690
+
)
691
+
.await
717
692
{
718
693
warn!("Failed to sequence account event for {}: {}", did, e);
719
694
}
···
742
717
"$type": "app.bsky.actor.profile",
743
718
"displayName": input.handle
744
719
});
745
-
let profile_collection = Nsid::new_unchecked("app.bsky.actor.profile");
746
-
let profile_rkey = Rkey::new_unchecked("self");
720
+
let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") };
721
+
let profile_rkey = unsafe { Rkey::new_unchecked("self") };
747
722
if let Err(e) = crate::api::repo::record::create_record_internal(
748
723
&state,
749
724
&did_typed,
···
756
731
warn!("Failed to create default profile for {}: {}", did, e);
757
732
}
758
733
}
759
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
734
+
let hostname = pds_hostname();
760
735
if !is_migration {
761
736
if let Some(ref recipient) = verification_recipient {
762
737
let verification_token = crate::auth::verification_token::generate_signup_token(
···
772
747
verification_channel,
773
748
recipient,
774
749
&formatted_token,
775
-
&hostname,
750
+
hostname,
776
751
)
777
752
.await
778
753
{
···
791
766
user_id,
792
767
user_email,
793
768
&formatted_token,
794
-
&hostname,
769
+
hostname,
795
770
)
796
771
.await
797
772
{
···
816
791
}
817
792
};
818
793
let session_data = tranquil_db_traits::SessionTokenCreate {
819
-
did: Did::new_unchecked(&did),
794
+
did: unsafe { Did::new_unchecked(&did) },
820
795
access_jti: access_meta.jti.clone(),
821
796
refresh_jti: refresh_meta.jti.clone(),
822
797
access_expires_at: access_meta.expires_at,
823
798
refresh_expires_at: refresh_meta.expires_at,
824
-
legacy_login: false,
799
+
login_type: tranquil_db_traits::LoginType::Modern,
825
800
mfa_verified: false,
826
801
scope: None,
827
802
controller_did: None,
···
845
820
StatusCode::OK,
846
821
Json(CreateAccountOutput {
847
822
handle: handle.clone().into(),
848
-
did: Did::new_unchecked(&did),
823
+
did: unsafe { Did::new_unchecked(&did) },
849
824
did_doc,
850
825
access_jwt: access_meta.token,
851
826
refresh_jwt: refresh_meta.token,
+27
-31
crates/tranquil-pds/src/api/identity/did.rs
+27
-31
crates/tranquil-pds/src/api/identity/did.rs
···
1
1
use crate::api::{ApiError, DidResponse, EmptyResponse};
2
2
use crate::auth::{Auth, NotTakendown};
3
3
use crate::plc::signing_key_to_did_key;
4
+
use crate::rate_limit::{
5
+
HandleUpdateDailyLimit, HandleUpdateLimit, check_user_rate_limit_with_message,
6
+
};
4
7
use crate::state::AppState;
5
8
use crate::types::Handle;
9
+
use crate::util::{get_header_str, pds_hostname, pds_hostname_without_port};
6
10
use axum::{
7
11
Json,
8
12
extract::{Path, Query, State},
···
101
105
}
102
106
103
107
pub async fn well_known_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
104
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
105
-
let host_header = headers
106
-
.get("host")
107
-
.and_then(|h| h.to_str().ok())
108
-
.unwrap_or(&hostname);
108
+
let hostname = pds_hostname();
109
+
let hostname_without_port = pds_hostname_without_port();
110
+
let host_header = get_header_str(&headers, "host").unwrap_or(hostname);
109
111
let host_without_port = host_header.split(':').next().unwrap_or(host_header);
110
-
let hostname_without_port = hostname.split(':').next().unwrap_or(&hostname);
111
112
if host_without_port != hostname_without_port
112
113
&& host_without_port.ends_with(&format!(".{}", hostname_without_port))
113
114
{
114
115
let handle = host_without_port
115
116
.strip_suffix(&format!(".{}", hostname_without_port))
116
117
.unwrap_or(host_without_port);
117
-
return serve_subdomain_did_doc(&state, handle, &hostname).await;
118
+
return serve_subdomain_did_doc(&state, handle, hostname).await;
118
119
}
119
120
let did = if hostname.contains(':') {
120
121
format!("did:web:{}", hostname.replace(':', "%3A"))
···
257
258
}
258
259
259
260
pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
260
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
261
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
261
+
let hostname = pds_hostname();
262
+
let hostname_for_handles = pds_hostname_without_port();
262
263
let current_handle = format!("{}.{}", handle, hostname_for_handles);
263
264
let current_handle_typed: Handle = match current_handle.parse() {
264
265
Ok(h) => h,
···
531
532
ApiError::AuthenticationFailed(Some("OAuth tokens cannot get DID credentials".into()))
532
533
})?;
533
534
534
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
535
+
let hostname = pds_hostname();
535
536
let pds_endpoint = format!("https://{}", hostname);
536
537
let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes)
537
538
.map_err(|_| ApiError::InternalError(None))?;
···
585
586
return Ok(e);
586
587
}
587
588
let did = auth.did.clone();
588
-
if !state
589
-
.check_rate_limit(crate::state::RateLimitKind::HandleUpdate, &did)
590
-
.await
591
-
{
592
-
return Err(ApiError::RateLimitExceeded(Some(
593
-
"Too many handle updates. Try again later.".into(),
594
-
)));
595
-
}
596
-
if !state
597
-
.check_rate_limit(crate::state::RateLimitKind::HandleUpdateDaily, &did)
598
-
.await
599
-
{
600
-
return Err(ApiError::RateLimitExceeded(Some(
601
-
"Daily handle update limit exceeded.".into(),
602
-
)));
603
-
}
589
+
let _rate_limit = check_user_rate_limit_with_message::<HandleUpdateLimit>(
590
+
&state,
591
+
&did,
592
+
"Too many handle updates. Try again later.",
593
+
)
594
+
.await?;
595
+
let _daily_rate_limit = check_user_rate_limit_with_message::<HandleUpdateDailyLimit>(
596
+
&state,
597
+
&did,
598
+
"Daily handle update limit exceeded.",
599
+
)
600
+
.await?;
604
601
let user_row = state
605
602
.user_repo
606
603
.get_id_and_handle_by_did(&did)
···
639
636
"Inappropriate language in handle".into(),
640
637
)));
641
638
}
642
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
643
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
639
+
let hostname_for_handles = pds_hostname_without_port();
644
640
let suffix = format!(".{}", hostname_for_handles);
645
641
let is_service_domain =
646
642
crate::handle::is_service_domain_handle(&new_handle, hostname_for_handles);
···
656
652
format!("{}.{}", new_handle, hostname_for_handles)
657
653
};
658
654
if full_handle == current_handle {
659
-
let handle_typed = Handle::new_unchecked(&full_handle);
655
+
let handle_typed = unsafe { Handle::new_unchecked(&full_handle) };
660
656
if let Err(e) =
661
657
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed))
662
658
.await
···
679
675
full_handle
680
676
} else {
681
677
if new_handle == current_handle {
682
-
let handle_typed = Handle::new_unchecked(&new_handle);
678
+
let handle_typed = unsafe { Handle::new_unchecked(&new_handle) };
683
679
if let Err(e) =
684
680
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed))
685
681
.await
···
772
768
}
773
769
774
770
pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
775
-
let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
771
+
let host = match crate::util::get_header_str(&headers, "host") {
776
772
Some(h) => h,
777
773
None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
778
774
};
+7
-12
crates/tranquil-pds/src/api/identity/plc/request.rs
+7
-12
crates/tranquil-pds/src/api/identity/plc/request.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::auth::{Auth, Permissive};
4
4
use crate::state::AppState;
5
+
use crate::util::pds_hostname;
5
6
use axum::{
6
7
extract::State,
7
8
response::{IntoResponse, Response},
8
9
};
9
10
use chrono::{Duration, Utc};
10
-
use tracing::{error, info, warn};
11
+
use tracing::{info, warn};
11
12
12
13
fn generate_plc_token() -> String {
13
14
crate::util::generate_token_code()
···
28
29
.user_repo
29
30
.get_id_by_did(&auth.did)
30
31
.await
31
-
.map_err(|e| {
32
-
error!("DB error: {:?}", e);
33
-
ApiError::InternalError(None)
34
-
})?
32
+
.log_db_err("fetching user id")?
35
33
.ok_or(ApiError::AccountNotFound)?;
36
34
37
35
let _ = state.infra_repo.delete_plc_tokens_for_user(user_id).await;
···
41
39
.infra_repo
42
40
.insert_plc_token(user_id, &plc_token, expires_at)
43
41
.await
44
-
.map_err(|e| {
45
-
error!("Failed to create PLC token: {:?}", e);
46
-
ApiError::InternalError(None)
47
-
})?;
42
+
.log_db_err("creating PLC token")?;
48
43
49
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
44
+
let hostname = pds_hostname();
50
45
if let Err(e) = crate::comms::comms_repo::enqueue_plc_operation(
51
46
state.user_repo.as_ref(),
52
47
state.infra_repo.as_ref(),
53
48
user_id,
54
49
&plc_token,
55
-
&hostname,
50
+
hostname,
56
51
)
57
52
.await
58
53
{
+4
-12
crates/tranquil-pds/src/api/identity/plc/sign.rs
+4
-12
crates/tranquil-pds/src/api/identity/plc/sign.rs
···
1
1
use crate::api::ApiError;
2
+
use crate::api::error::DbResultExt;
2
3
use crate::auth::{Auth, Permissive};
3
4
use crate::circuit_breaker::with_circuit_breaker;
4
5
use crate::plc::{PlcClient, PlcError, PlcService, create_update_op, sign_operation};
···
64
65
.user_repo
65
66
.get_id_by_did(did)
66
67
.await
67
-
.map_err(|e| {
68
-
error!("DB error: {:?}", e);
69
-
ApiError::InternalError(None)
70
-
})?
68
+
.log_db_err("fetching user id")?
71
69
.ok_or(ApiError::AccountNotFound)?;
72
70
73
71
let token_expiry = state
74
72
.infra_repo
75
73
.get_plc_token_expiry(user_id, token)
76
74
.await
77
-
.map_err(|e| {
78
-
error!("DB error: {:?}", e);
79
-
ApiError::InternalError(None)
80
-
})?
75
+
.log_db_err("fetching PLC token expiry")?
81
76
.ok_or_else(|| ApiError::InvalidToken(Some("Invalid or expired token".into())))?;
82
77
83
78
if Utc::now() > token_expiry {
···
88
83
.user_repo
89
84
.get_user_key_by_id(user_id)
90
85
.await
91
-
.map_err(|e| {
92
-
error!("DB error: {:?}", e);
93
-
ApiError::InternalError(None)
94
-
})?
86
+
.log_db_err("fetching user key")?
95
87
.ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?;
96
88
97
89
let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
+5
-9
crates/tranquil-pds/src/api/identity/plc/submit.rs
+5
-9
crates/tranquil-pds/src/api/identity/plc/submit.rs
···
1
+
use crate::api::error::DbResultExt;
1
2
use crate::api::{ApiError, EmptyResponse};
2
3
use crate::auth::{Auth, Permissive};
3
4
use crate::circuit_breaker::with_circuit_breaker;
4
5
use crate::plc::{PlcClient, signing_key_to_did_key, validate_plc_operation};
5
6
use crate::state::AppState;
7
+
use crate::util::pds_hostname;
6
8
use axum::{
7
9
Json,
8
10
extract::State,
···
40
42
.map_err(|e| ApiError::InvalidRequest(format!("Invalid operation: {}", e)))?;
41
43
42
44
let op = &input.operation;
43
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
45
+
let hostname = pds_hostname();
44
46
let public_url = format!("https://{}", hostname);
45
47
let user = state
46
48
.user_repo
47
49
.get_id_and_handle_by_did(did)
48
50
.await
49
-
.map_err(|e| {
50
-
error!("DB error: {:?}", e);
51
-
ApiError::InternalError(None)
52
-
})?
51
+
.log_db_err("fetching user")?
53
52
.ok_or(ApiError::AccountNotFound)?;
54
53
55
54
let key_row = state
56
55
.user_repo
57
56
.get_user_key_by_id(user.id)
58
57
.await
59
-
.map_err(|e| {
60
-
error!("DB error: {:?}", e);
61
-
ApiError::InternalError(None)
62
-
})?
58
+
.log_db_err("fetching user key")?
63
59
.ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?;
64
60
65
61
let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
+33
-27
crates/tranquil-pds/src/api/notification_prefs.rs
+33
-27
crates/tranquil-pds/src/api/notification_prefs.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::auth::{Active, Auth};
3
3
use crate::state::AppState;
4
+
use crate::util::pds_hostname;
4
5
use axum::{
5
6
Json,
6
7
extract::State,
···
9
10
use serde::{Deserialize, Serialize};
10
11
use serde_json::json;
11
12
use tracing::info;
13
+
use tranquil_db_traits::{CommsChannel, CommsStatus, CommsType};
12
14
13
15
#[derive(Serialize)]
14
16
#[serde(rename_all = "camelCase")]
15
17
pub struct NotificationPrefsResponse {
16
-
pub preferred_channel: String,
18
+
pub preferred_channel: CommsChannel,
17
19
pub email: String,
18
20
pub discord_id: Option<String>,
19
21
pub discord_verified: bool,
···
50
52
#[serde(rename_all = "camelCase")]
51
53
pub struct NotificationHistoryEntry {
52
54
pub created_at: String,
53
-
pub channel: String,
54
-
pub comms_type: String,
55
-
pub status: String,
55
+
pub channel: CommsChannel,
56
+
pub comms_type: CommsType,
57
+
pub status: CommsStatus,
56
58
pub subject: Option<String>,
57
59
pub body: String,
58
60
}
···
81
83
.map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?;
82
84
83
85
let sensitive_types = [
84
-
"email_verification",
85
-
"password_reset",
86
-
"email_update",
87
-
"two_factor_code",
88
-
"passkey_recovery",
89
-
"migration_verification",
90
-
"plc_operation",
91
-
"channel_verification",
92
-
"signup_verification",
86
+
CommsType::EmailVerification,
87
+
CommsType::PasswordReset,
88
+
CommsType::EmailUpdate,
89
+
CommsType::TwoFactorCode,
90
+
CommsType::PasskeyRecovery,
91
+
CommsType::MigrationVerification,
92
+
CommsType::PlcOperation,
93
+
CommsType::ChannelVerification,
93
94
];
94
95
95
96
let notifications = rows
96
97
.iter()
97
98
.map(|row| {
98
-
let body = if sensitive_types.contains(&row.comms_type.as_str()) {
99
+
let body = if sensitive_types.contains(&row.comms_type) {
99
100
"[Code redacted for security]".to_string()
100
101
} else {
101
102
row.body.clone()
102
103
};
103
104
NotificationHistoryEntry {
104
105
created_at: row.created_at.to_rfc3339(),
105
-
channel: row.channel.clone(),
106
-
comms_type: row.comms_type.clone(),
107
-
status: row.status.clone(),
106
+
channel: row.channel,
107
+
comms_type: row.comms_type,
108
+
status: row.status,
108
109
subject: row.subject.clone(),
109
110
body,
110
111
}
···
145
146
let formatted_token = crate::auth::verification_token::format_token_for_display(&token);
146
147
147
148
if channel == "email" {
148
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
149
+
let hostname = pds_hostname();
149
150
let handle_str = handle.unwrap_or("user");
150
151
crate::comms::comms_repo::enqueue_email_update(
151
152
state.infra_repo.as_ref(),
···
153
154
identifier,
154
155
handle_str,
155
156
&formatted_token,
156
-
&hostname,
157
+
hostname,
157
158
)
158
159
.await
159
160
.map_err(|e| format!("Failed to enqueue email notification: {}", e))?;
···
200
201
201
202
let mut verification_required: Vec<String> = Vec::new();
202
203
203
-
if let Some(ref channel) = input.preferred_channel {
204
-
let valid_channels = ["email", "discord", "telegram", "signal"];
205
-
if !valid_channels.contains(&channel.as_str()) {
206
-
return Err(ApiError::InvalidRequest(
207
-
"Invalid channel. Must be one of: email, discord, telegram, signal".into(),
208
-
));
209
-
}
204
+
if let Some(ref channel_str) = input.preferred_channel {
205
+
let channel = match channel_str.as_str() {
206
+
"email" => CommsChannel::Email,
207
+
"discord" => CommsChannel::Discord,
208
+
"telegram" => CommsChannel::Telegram,
209
+
"signal" => CommsChannel::Signal,
210
+
_ => {
211
+
return Err(ApiError::InvalidRequest(
212
+
"Invalid channel. Must be one of: email, discord, telegram, signal".into(),
213
+
));
214
+
}
215
+
};
210
216
state
211
217
.user_repo
212
218
.update_preferred_comms_channel(&auth.did, channel)
213
219
.await
214
220
.map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?;
215
-
info!(did = %auth.did, channel = %channel, "Updated preferred notification channel");
221
+
info!(did = %auth.did, channel = ?channel, "Updated preferred notification channel");
216
222
}
217
223
218
224
if let Some(ref new_email) = input.email {
+4
-7
crates/tranquil-pds/src/api/proxy.rs
+4
-7
crates/tranquil-pds/src/api/proxy.rs
···
3
3
use crate::api::error::ApiError;
4
4
use crate::api::proxy_client::proxy_client;
5
5
use crate::state::AppState;
6
+
use crate::util::get_header_str;
6
7
use axum::{
7
8
body::Bytes,
8
9
extract::{RawQuery, Request, State},
···
191
192
.into_response();
192
193
}
193
194
194
-
let Some(proxy_header) = headers
195
-
.get("atproto-proxy")
196
-
.and_then(|h| h.to_str().ok())
197
-
.map(String::from)
198
-
else {
195
+
let Some(proxy_header) = get_header_str(&headers, "atproto-proxy").map(String::from) else {
199
196
return ApiError::InvalidRequest("Missing required atproto-proxy header".into())
200
197
.into_response();
201
198
};
···
217
214
218
215
let mut auth_header_val = headers.get("Authorization").cloned();
219
216
if let Some(extracted) = crate::auth::extract_auth_token_from_header(
220
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
217
+
crate::util::get_header_str(&headers, "Authorization"),
221
218
) {
222
219
let token = extracted.token;
223
-
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
220
+
let dpop_proof = crate::util::get_header_str(&headers, "DPoP");
224
221
let http_uri = crate::util::build_full_url(&uri.to_string());
225
222
226
223
match crate::auth::validate_token_with_dpop(
+18
-28
crates/tranquil-pds/src/api/repo/blob.rs
+18
-28
crates/tranquil-pds/src/api/repo/blob.rs
···
1
-
use crate::api::error::ApiError;
2
-
use crate::auth::{Auth, AuthAny, NotTakendown, Permissive};
1
+
use crate::api::error::{ApiError, DbResultExt};
2
+
use crate::auth::{Auth, AuthAny, NotTakendown, Permissive, VerifyScope};
3
3
use crate::delegation::DelegationActionType;
4
4
use crate::state::AppState;
5
5
use crate::types::{CidLink, Did};
6
-
use crate::util::get_max_blob_size;
6
+
use crate::util::{get_header_str, get_max_blob_size};
7
7
use axum::body::Body;
8
8
use axum::{
9
9
Json,
···
56
56
if user.status.is_takendown() {
57
57
return Err(ApiError::AccountTakedown);
58
58
}
59
-
let mime_type_for_check = headers
60
-
.get("content-type")
61
-
.and_then(|h| h.to_str().ok())
62
-
.unwrap_or("application/octet-stream");
63
-
if let Err(e) = crate::auth::scope_check::check_blob_scope(
64
-
user.is_oauth(),
65
-
user.scope.as_deref(),
66
-
mime_type_for_check,
67
-
) {
68
-
return Ok(e);
69
-
}
70
-
(user.did.clone(), user.controller_did.clone())
59
+
let mime_type_for_check =
60
+
get_header_str(&headers, "content-type").unwrap_or("application/octet-stream");
61
+
let scope_proof = match user.verify_blob_upload(mime_type_for_check) {
62
+
Ok(proof) => proof,
63
+
Err(e) => return Ok(e.into_response()),
64
+
};
65
+
(
66
+
scope_proof.principal_did().into_did(),
67
+
scope_proof.controller_did().map(|c| c.into_did()),
68
+
)
71
69
}
72
70
};
73
71
···
80
78
return Err(ApiError::Forbidden);
81
79
}
82
80
83
-
let client_mime_hint = headers
84
-
.get("content-type")
85
-
.and_then(|h| h.to_str().ok())
86
-
.unwrap_or("application/octet-stream");
81
+
let client_mime_hint =
82
+
get_header_str(&headers, "content-type").unwrap_or("application/octet-stream");
87
83
88
84
let user_id = state
89
85
.user_repo
···
140
136
};
141
137
let cid = Cid::new_v1(0x55, multihash);
142
138
let cid_str = cid.to_string();
143
-
let cid_link: CidLink = CidLink::new_unchecked(&cid_str);
139
+
let cid_link: CidLink = unsafe { CidLink::new_unchecked(&cid_str) };
144
140
let storage_key = cid_str.clone();
145
141
146
142
info!(
···
232
228
.user_repo
233
229
.get_by_did(did)
234
230
.await
235
-
.map_err(|e| {
236
-
error!("DB error fetching user: {:?}", e);
237
-
ApiError::InternalError(None)
238
-
})?
231
+
.log_db_err("fetching user")?
239
232
.ok_or(ApiError::InternalError(None))?;
240
233
241
234
let limit = params.limit.unwrap_or(500).clamp(1, 1000);
···
244
237
.blob_repo
245
238
.list_missing_blobs(user.id, cursor, limit + 1)
246
239
.await
247
-
.map_err(|e| {
248
-
error!("DB error fetching missing blobs: {:?}", e);
249
-
ApiError::InternalError(None)
250
-
})?;
240
+
.log_db_err("fetching missing blobs")?;
251
241
252
242
let has_more = missing.len() > limit as usize;
253
243
let blobs: Vec<RecordBlob> = missing
+8
-12
crates/tranquil-pds/src/api/repo/import.rs
+8
-12
crates/tranquil-pds/src/api/repo/import.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::api::repo::record::create_signed_commit;
4
4
use crate::auth::{Auth, NotTakendown};
5
5
use crate::state::AppState;
···
49
49
.user_repo
50
50
.get_by_did(did)
51
51
.await
52
-
.map_err(|e| {
53
-
error!("DB error fetching user: {:?}", e);
54
-
ApiError::InternalError(None)
55
-
})?
52
+
.log_db_err("fetching user")?
56
53
.ok_or(ApiError::AccountNotFound)?;
57
54
if user.takedown_ref.is_some() {
58
55
return Err(ApiError::AccountTakedown);
···
207
204
let record_uri =
208
205
AtUri::from_parts(did.as_str(), &record.collection, &record.rkey);
209
206
record.blob_refs.iter().map(move |blob_ref| {
210
-
(
211
-
record_uri.clone(),
212
-
CidLink::new_unchecked(blob_ref.cid.clone()),
213
-
)
207
+
(record_uri.clone(), unsafe {
208
+
CidLink::new_unchecked(blob_ref.cid.clone())
209
+
})
214
210
})
215
211
})
216
212
.collect();
···
275
271
error!("Failed to store new commit block: {:?}", e);
276
272
ApiError::InternalError(None)
277
273
})?;
278
-
let new_root_cid_link = CidLink::new_unchecked(new_root_cid.to_string());
274
+
let new_root_cid_link = unsafe { CidLink::new_unchecked(new_root_cid.to_string()) };
279
275
state
280
276
.repo_repo
281
277
.update_repo_root(user_id, &new_root_cid_link, &new_rev_str)
···
368
364
) -> Result<(), tranquil_db::DbError> {
369
365
let data = tranquil_db::CommitEventData {
370
366
did: did.clone(),
371
-
event_type: "commit".to_string(),
372
-
commit_cid: Some(CidLink::new_unchecked(commit_cid)),
367
+
event_type: tranquil_db::RepoEventType::Commit,
368
+
commit_cid: Some(unsafe { CidLink::new_unchecked(commit_cid) }),
373
369
prev_cid: None,
374
370
ops: Some(serde_json::json!([])),
375
371
blobs: Some(vec![]),
+2
-2
crates/tranquil-pds/src/api/repo/meta.rs
+2
-2
crates/tranquil-pds/src/api/repo/meta.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::state::AppState;
3
3
use crate::types::AtIdentifier;
4
+
use crate::util::pds_hostname_without_port;
4
5
use axum::{
5
6
Json,
6
7
extract::{Query, State},
···
18
19
State(state): State<AppState>,
19
20
Query(input): Query<DescribeRepoInput>,
20
21
) -> Response {
21
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
22
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
22
+
let hostname_for_handles = pds_hostname_without_port();
23
23
let user_row = if input.repo.is_did() {
24
24
let did: crate::types::Did = match input.repo.as_str().parse() {
25
25
Ok(d) => d,
+70
-131
crates/tranquil-pds/src/api/repo/record/batch.rs
+70
-131
crates/tranquil-pds/src/api/repo/record/batch.rs
···
1
1
use super::validation::validate_record_with_status;
2
+
use super::validation_mode::{ValidationMode, deserialize_validation_mode};
2
3
use crate::api::error::ApiError;
3
4
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids};
4
-
use crate::auth::{Active, Auth};
5
+
use crate::auth::{
6
+
Active, Auth, WriteOpKind, require_not_migrated, require_verified_or_delegated,
7
+
verify_batch_write_scopes,
8
+
};
9
+
use crate::cid_types::CommitCid;
5
10
use crate::delegation::DelegationActionType;
6
11
use crate::repo::tracking::TrackingBlockStore;
7
12
use crate::state::AppState;
···
34
39
write: &WriteOp,
35
40
acc: WriteAccumulator,
36
41
did: &Did,
37
-
validate: Option<bool>,
42
+
validate: ValidationMode,
38
43
tracking_store: &TrackingBlockStore,
39
44
) -> Result<WriteAccumulator, Response> {
40
45
let WriteAccumulator {
···
51
56
rkey,
52
57
value,
53
58
} => {
54
-
let validation_status = match validate {
55
-
Some(false) => None,
56
-
_ => {
57
-
let require_lexicon = validate == Some(true);
58
-
match validate_record_with_status(
59
-
value,
60
-
collection,
61
-
rkey.as_ref(),
62
-
require_lexicon,
63
-
) {
64
-
Ok(status) => Some(status),
65
-
Err(err_response) => return Err(*err_response),
66
-
}
59
+
let validation_status = if validate.should_skip() {
60
+
None
61
+
} else {
62
+
match validate_record_with_status(
63
+
value,
64
+
collection,
65
+
rkey.as_ref(),
66
+
validate.requires_lexicon(),
67
+
) {
68
+
Ok(status) => Some(status),
69
+
Err(err_response) => return Err(*err_response),
67
70
}
68
71
};
69
72
all_blob_cids.extend(extract_blob_cids(value));
···
104
107
rkey,
105
108
value,
106
109
} => {
107
-
let validation_status = match validate {
108
-
Some(false) => None,
109
-
_ => {
110
-
let require_lexicon = validate == Some(true);
111
-
match validate_record_with_status(
112
-
value,
113
-
collection,
114
-
Some(rkey),
115
-
require_lexicon,
116
-
) {
117
-
Ok(status) => Some(status),
118
-
Err(err_response) => return Err(*err_response),
119
-
}
110
+
let validation_status = if validate.should_skip() {
111
+
None
112
+
} else {
113
+
match validate_record_with_status(
114
+
value,
115
+
collection,
116
+
Some(rkey),
117
+
validate.requires_lexicon(),
118
+
) {
119
+
Ok(status) => Some(status),
120
+
Err(err_response) => return Err(*err_response),
120
121
}
121
122
};
122
123
all_blob_cids.extend(extract_blob_cids(value));
···
181
182
writes: &[WriteOp],
182
183
initial_mst: Mst<TrackingBlockStore>,
183
184
did: &Did,
184
-
validate: Option<bool>,
185
+
validate: ValidationMode,
185
186
tracking_store: &TrackingBlockStore,
186
187
) -> Result<WriteAccumulator, Response> {
187
188
use futures::stream::{self, TryStreamExt};
···
222
223
#[serde(rename_all = "camelCase")]
223
224
pub struct ApplyWritesInput {
224
225
pub repo: AtIdentifier,
225
-
pub validate: Option<bool>,
226
+
#[serde(default, deserialize_with = "deserialize_validation_mode")]
227
+
pub validate: ValidationMode,
226
228
pub writes: Vec<WriteOp>,
227
229
pub swap_commit: Option<String>,
228
230
}
···
270
272
input.repo,
271
273
input.writes.len()
272
274
);
273
-
let did = auth.did.clone();
274
-
let is_oauth = auth.is_oauth();
275
-
let scope = auth.scope.clone();
276
-
let controller_did = auth.controller_did.clone();
277
-
if input.repo.as_str() != did {
278
-
return Err(ApiError::InvalidRepo(
279
-
"Repo does not match authenticated user".into(),
280
-
));
281
-
}
282
-
if state
283
-
.user_repo
284
-
.is_account_migrated(&did)
285
-
.await
286
-
.unwrap_or(false)
287
-
{
288
-
return Err(ApiError::AccountMigrated);
289
-
}
290
-
let is_verified = state
291
-
.user_repo
292
-
.has_verified_comms_channel(&did)
293
-
.await
294
-
.unwrap_or(false);
295
-
let is_delegated = state
296
-
.delegation_repo
297
-
.is_delegated_account(&did)
298
-
.await
299
-
.unwrap_or(false);
300
-
if !is_verified && !is_delegated {
301
-
return Err(ApiError::AccountNotVerified);
302
-
}
275
+
303
276
if input.writes.is_empty() {
304
277
return Err(ApiError::InvalidRequest("writes array is empty".into()));
305
278
}
···
310
283
)));
311
284
}
312
285
313
-
let has_custom_scope = scope
314
-
.as_ref()
315
-
.map(|s| s != "com.atproto.access")
316
-
.unwrap_or(false);
317
-
if is_oauth || has_custom_scope {
318
-
use std::collections::HashSet;
319
-
let create_collections: HashSet<&Nsid> = input
320
-
.writes
321
-
.iter()
322
-
.filter_map(|w| {
323
-
if let WriteOp::Create { collection, .. } = w {
324
-
Some(collection)
325
-
} else {
326
-
None
327
-
}
328
-
})
329
-
.collect();
330
-
let update_collections: HashSet<&Nsid> = input
331
-
.writes
332
-
.iter()
333
-
.filter_map(|w| {
334
-
if let WriteOp::Update { collection, .. } = w {
335
-
Some(collection)
336
-
} else {
337
-
None
338
-
}
339
-
})
340
-
.collect();
341
-
let delete_collections: HashSet<&Nsid> = input
342
-
.writes
343
-
.iter()
344
-
.filter_map(|w| {
345
-
if let WriteOp::Delete { collection, .. } = w {
346
-
Some(collection)
347
-
} else {
348
-
None
349
-
}
350
-
})
351
-
.collect();
286
+
let batch_proof = match verify_batch_write_scopes(
287
+
&auth,
288
+
&auth,
289
+
&input.writes,
290
+
|w| match w {
291
+
WriteOp::Create { collection, .. } => collection.as_str(),
292
+
WriteOp::Update { collection, .. } => collection.as_str(),
293
+
WriteOp::Delete { collection, .. } => collection.as_str(),
294
+
},
295
+
|w| match w {
296
+
WriteOp::Create { .. } => WriteOpKind::Create,
297
+
WriteOp::Update { .. } => WriteOpKind::Update,
298
+
WriteOp::Delete { .. } => WriteOpKind::Delete,
299
+
},
300
+
) {
301
+
Ok(proof) => proof,
302
+
Err(e) => return Ok(e.into_response()),
303
+
};
352
304
353
-
let scope_checks = create_collections
354
-
.iter()
355
-
.map(|c| (crate::oauth::RepoAction::Create, c))
356
-
.chain(
357
-
update_collections
358
-
.iter()
359
-
.map(|c| (crate::oauth::RepoAction::Update, c)),
360
-
)
361
-
.chain(
362
-
delete_collections
363
-
.iter()
364
-
.map(|c| (crate::oauth::RepoAction::Delete, c)),
365
-
);
305
+
let principal_did = batch_proof.principal_did();
306
+
let controller_did = batch_proof.controller_did().map(|c| c.into_did());
366
307
367
-
if let Some(err) = scope_checks
368
-
.filter_map(|(action, collection)| {
369
-
crate::auth::scope_check::check_repo_scope(
370
-
is_oauth,
371
-
scope.as_deref(),
372
-
action,
373
-
collection,
374
-
)
375
-
.err()
376
-
})
377
-
.next()
378
-
{
379
-
return Ok(err);
380
-
}
308
+
if input.repo.as_str() != principal_did.as_str() {
309
+
return Err(ApiError::InvalidRepo(
310
+
"Repo does not match authenticated user".into(),
311
+
));
312
+
}
313
+
314
+
let did = principal_did.into_did();
315
+
if let Err(e) = require_not_migrated(&state, &did).await {
316
+
return Ok(e);
317
+
}
318
+
if let Err(e) = require_verified_or_delegated(&state, batch_proof.user()).await {
319
+
return Ok(e);
381
320
}
382
321
383
322
let user_id: uuid::Uuid = state
···
394
333
.ok()
395
334
.flatten()
396
335
.ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?;
397
-
let current_root_cid = Cid::from_str(&root_cid_str)
336
+
let current_root_cid = CommitCid::from_str(&root_cid_str)
398
337
.map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))?;
399
338
if let Some(swap_commit) = &input.swap_commit
400
-
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid)
339
+
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(¤t_root_cid)
401
340
{
402
341
return Err(ApiError::InvalidSwap(Some("Repo has been modified".into())));
403
342
}
404
343
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
405
344
let commit_bytes = tracking_store
406
-
.get(¤t_root_cid)
345
+
.get(current_root_cid.as_cid())
407
346
.await
408
347
.ok()
409
348
.flatten()
···
471
410
} => Some(*cid),
472
411
_ => None,
473
412
});
474
-
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid)
413
+
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid())
475
414
.chain(
476
415
old_mst_blocks
477
416
.keys()
···
487
426
CommitParams {
488
427
did: &did,
489
428
user_id,
490
-
current_root_cid: Some(current_root_cid),
429
+
current_root_cid: Some(current_root_cid.into_cid()),
491
430
prev_data_cid: Some(commit.data),
492
431
new_mst_root,
493
432
ops,
+12
-15
crates/tranquil-pds/src/api/repo/record/delete.rs
+12
-15
crates/tranquil-pds/src/api/repo/record/delete.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
3
3
use crate::api::repo::record::write::{CommitInfo, prepare_repo_write};
4
-
use crate::auth::{Active, Auth};
4
+
use crate::auth::{Active, Auth, VerifyScope};
5
+
use crate::cid_types::CommitCid;
5
6
use crate::delegation::DelegationActionType;
6
7
use crate::repo::tracking::TrackingBlockStore;
7
8
use crate::state::AppState;
···
43
44
auth: Auth<Active>,
44
45
Json(input): Json<DeleteRecordInput>,
45
46
) -> Result<Response, crate::api::error::ApiError> {
46
-
let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await {
47
+
let scope_proof = match auth.verify_repo_delete(&input.collection) {
48
+
Ok(proof) => proof,
49
+
Err(e) => return Ok(e.into_response()),
50
+
};
51
+
52
+
let repo_auth = match prepare_repo_write(&state, &scope_proof, &input.repo).await {
47
53
Ok(res) => res,
48
54
Err(err_res) => return Ok(err_res),
49
55
};
50
56
51
-
if let Err(e) = crate::auth::scope_check::check_repo_scope(
52
-
repo_auth.is_oauth,
53
-
repo_auth.scope.as_deref(),
54
-
crate::oauth::RepoAction::Delete,
55
-
&input.collection,
56
-
) {
57
-
return Ok(e);
58
-
}
59
-
60
57
let did = repo_auth.did;
61
58
let user_id = repo_auth.user_id;
62
59
let current_root_cid = repo_auth.current_root_cid;
63
60
let controller_did = repo_auth.controller_did;
64
61
65
62
if let Some(swap_commit) = &input.swap_commit
66
-
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid)
63
+
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(¤t_root_cid)
67
64
{
68
65
return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response());
69
66
}
70
67
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
71
-
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
68
+
let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await {
72
69
Ok(Some(b)) => b,
73
70
_ => {
74
71
return Ok(
···
159
156
.into_iter()
160
157
.collect();
161
158
let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
162
-
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid)
159
+
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid())
163
160
.chain(
164
161
old_mst_blocks
165
162
.keys()
···
173
170
CommitParams {
174
171
did: &did,
175
172
user_id,
176
-
current_root_cid: Some(current_root_cid),
173
+
current_root_cid: Some(current_root_cid.into_cid()),
177
174
prev_data_cid: Some(commit.data),
178
175
new_mst_root,
179
176
ops: vec![op],
+5
crates/tranquil-pds/src/api/repo/record/mod.rs
+5
crates/tranquil-pds/src/api/repo/record/mod.rs
···
1
1
pub mod batch;
2
2
pub mod delete;
3
+
pub mod pagination;
3
4
pub mod read;
4
5
pub mod utils;
5
6
pub mod validation;
7
+
pub mod validation_mode;
6
8
pub mod write;
7
9
10
+
pub use pagination::PaginationDirection;
11
+
pub use validation_mode::ValidationMode;
12
+
8
13
pub use batch::apply_writes;
9
14
pub use delete::{DeleteRecordInput, delete_record, delete_record_internal};
10
15
pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records};
+31
crates/tranquil-pds/src/api/repo/record/pagination.rs
+31
crates/tranquil-pds/src/api/repo/record/pagination.rs
···
1
+
use serde::{Deserialize, Deserializer};
2
+
3
+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
4
+
pub enum PaginationDirection {
5
+
#[default]
6
+
Forward,
7
+
Backward,
8
+
}
9
+
10
+
impl PaginationDirection {
11
+
pub fn from_optional_bool(value: Option<bool>) -> Self {
12
+
match value {
13
+
Some(true) => Self::Backward,
14
+
Some(false) | None => Self::Forward,
15
+
}
16
+
}
17
+
18
+
pub fn is_reverse(&self) -> bool {
19
+
matches!(self, Self::Backward)
20
+
}
21
+
}
22
+
23
+
pub fn deserialize_pagination_direction<'de, D>(
24
+
deserializer: D,
25
+
) -> Result<PaginationDirection, D::Error>
26
+
where
27
+
D: Deserializer<'de>,
28
+
{
29
+
let opt: Option<bool> = Option::deserialize(deserializer)?;
30
+
Ok(PaginationDirection::from_optional_bool(opt))
31
+
}
+7
-7
crates/tranquil-pds/src/api/repo/record/read.rs
+7
-7
crates/tranquil-pds/src/api/repo/record/read.rs
···
1
+
use super::pagination::{PaginationDirection, deserialize_pagination_direction};
1
2
use crate::api::error::ApiError;
2
3
use crate::state::AppState;
3
4
use crate::types::{AtIdentifier, Nsid, Rkey};
5
+
use crate::util::pds_hostname_without_port;
4
6
use axum::{
5
7
Json,
6
8
extract::{Query, State},
···
58
60
_headers: HeaderMap,
59
61
Query(input): Query<GetRecordInput>,
60
62
) -> Response {
61
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
62
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
63
+
let hostname_for_handles = pds_hostname_without_port();
63
64
let user_id_opt = if input.repo.is_did() {
64
65
let did: crate::types::Did = match input.repo.as_str().parse() {
65
66
Ok(d) => d,
···
144
145
pub rkey_start: Option<Rkey>,
145
146
#[serde(rename = "rkeyEnd")]
146
147
pub rkey_end: Option<Rkey>,
147
-
pub reverse: Option<bool>,
148
+
#[serde(default, deserialize_with = "deserialize_pagination_direction")]
149
+
pub reverse: PaginationDirection,
148
150
}
149
151
#[derive(Serialize)]
150
152
pub struct ListRecordsOutput {
···
157
159
State(state): State<AppState>,
158
160
Query(input): Query<ListRecordsInput>,
159
161
) -> Response {
160
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
161
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
162
+
let hostname_for_handles = pds_hostname_without_port();
162
163
let user_id_opt = if input.repo.is_did() {
163
164
let did: crate::types::Did = match input.repo.as_str().parse() {
164
165
Ok(d) => d,
···
194
195
}
195
196
};
196
197
let limit = input.limit.unwrap_or(50).clamp(1, 100);
197
-
let reverse = input.reverse.unwrap_or(false);
198
198
let limit_i64 = limit as i64;
199
199
let cursor_rkey = input
200
200
.cursor
···
207
207
&input.collection,
208
208
cursor_rkey.as_ref(),
209
209
limit_i64,
210
-
reverse,
210
+
input.reverse.is_reverse(),
211
211
input.rkey_start.as_ref(),
212
212
input.rkey_end.as_ref(),
213
213
)
+21
-19
crates/tranquil-pds/src/api/repo/record/utils.rs
+21
-19
crates/tranquil-pds/src/api/repo/record/utils.rs
···
8
8
use k256::ecdsa::SigningKey;
9
9
use serde_json::{Value, json};
10
10
use std::str::FromStr;
11
+
use tranquil_db_traits::SequenceNumber;
11
12
use uuid::Uuid;
12
13
13
14
pub fn extract_blob_cids(record: &Value) -> Vec<String> {
···
139
140
) -> Result<CommitResult, String> {
140
141
use tranquil_db_traits::{
141
142
ApplyCommitError, ApplyCommitInput, CommitEventData, RecordDelete, RecordUpsert,
143
+
RepoEventType,
142
144
};
143
145
144
146
let CommitParams {
···
199
201
upserts.push(RecordUpsert {
200
202
collection: collection.clone(),
201
203
rkey: rkey.clone(),
202
-
cid: crate::types::CidLink::new_unchecked(cid.to_string()),
204
+
cid: unsafe { crate::types::CidLink::new_unchecked(cid.to_string()) },
203
205
});
204
206
}
205
207
RecordOp::Delete {
···
263
265
264
266
let commit_event = CommitEventData {
265
267
did: did.clone(),
266
-
event_type: "commit".to_string(),
267
-
commit_cid: Some(crate::types::CidLink::new_unchecked(
268
-
new_root_cid.to_string(),
269
-
)),
270
-
prev_cid: current_root_cid.map(|c| crate::types::CidLink::new_unchecked(c.to_string())),
268
+
event_type: RepoEventType::Commit,
269
+
commit_cid: Some(unsafe { crate::types::CidLink::new_unchecked(new_root_cid.to_string()) }),
270
+
prev_cid: current_root_cid
271
+
.map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }),
271
272
ops: Some(json!(ops_json)),
272
273
blobs: Some(blobs.to_vec()),
273
274
blocks_cids: Some(blocks_cids.to_vec()),
274
-
prev_data_cid: prev_data_cid.map(|c| crate::types::CidLink::new_unchecked(c.to_string())),
275
+
prev_data_cid: prev_data_cid
276
+
.map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }),
275
277
rev: Some(rev_str.clone()),
276
278
};
277
279
···
279
281
user_id,
280
282
did: did.clone(),
281
283
expected_root_cid: current_root_cid
282
-
.map(|c| crate::types::CidLink::new_unchecked(c.to_string())),
283
-
new_root_cid: crate::types::CidLink::new_unchecked(new_root_cid.to_string()),
284
+
.map(|c| unsafe { crate::types::CidLink::new_unchecked(c.to_string()) }),
285
+
new_root_cid: unsafe { crate::types::CidLink::new_unchecked(new_root_cid.to_string()) },
284
286
new_rev: rev_str.clone(),
285
287
new_block_cids: all_block_cids,
286
288
obsolete_block_cids: obsolete_bytes,
···
417
419
state: &AppState,
418
420
did: &Did,
419
421
handle: Option<&Handle>,
420
-
) -> Result<i64, String> {
422
+
) -> Result<SequenceNumber, String> {
421
423
state
422
424
.repo_repo
423
425
.insert_identity_event(did, handle)
···
427
429
pub async fn sequence_account_event(
428
430
state: &AppState,
429
431
did: &Did,
430
-
active: bool,
431
-
status: Option<&str>,
432
-
) -> Result<i64, String> {
432
+
status: tranquil_db_traits::AccountStatus,
433
+
) -> Result<SequenceNumber, String> {
433
434
state
434
435
.repo_repo
435
-
.insert_account_event(did, active, status)
436
+
.insert_account_event(did, status)
436
437
.await
437
438
.map_err(|e| format!("DB Error (account event): {}", e))
438
439
}
···
441
442
did: &Did,
442
443
commit_cid: &str,
443
444
rev: Option<&str>,
444
-
) -> Result<i64, String> {
445
-
let cid_link = crate::types::CidLink::new_unchecked(commit_cid);
445
+
) -> Result<SequenceNumber, String> {
446
+
let cid_link = unsafe { crate::types::CidLink::new_unchecked(commit_cid) };
446
447
state
447
448
.repo_repo
448
449
.insert_sync_event(did, &cid_link, rev)
···
456
457
commit_cid: &Cid,
457
458
mst_root_cid: &Cid,
458
459
rev: &str,
459
-
) -> Result<i64, String> {
460
-
let commit_cid_link = crate::types::CidLink::new_unchecked(commit_cid.to_string());
461
-
let mst_root_cid_link = crate::types::CidLink::new_unchecked(mst_root_cid.to_string());
460
+
) -> Result<SequenceNumber, String> {
461
+
let commit_cid_link = unsafe { crate::types::CidLink::new_unchecked(commit_cid.to_string()) };
462
+
let mst_root_cid_link =
463
+
unsafe { crate::types::CidLink::new_unchecked(mst_root_cid.to_string()) };
462
464
state
463
465
.repo_repo
464
466
.insert_genesis_commit_event(did, &commit_cid_link, &mst_root_cid_link, rev)
+35
crates/tranquil-pds/src/api/repo/record/validation_mode.rs
+35
crates/tranquil-pds/src/api/repo/record/validation_mode.rs
···
1
+
use serde::{Deserialize, Deserializer};
2
+
3
+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
4
+
pub enum ValidationMode {
5
+
Skip,
6
+
#[default]
7
+
Infer,
8
+
Strict,
9
+
}
10
+
11
+
impl ValidationMode {
12
+
pub fn from_optional_bool(value: Option<bool>) -> Self {
13
+
match value {
14
+
Some(false) => Self::Skip,
15
+
Some(true) => Self::Strict,
16
+
None => Self::Infer,
17
+
}
18
+
}
19
+
20
+
pub fn should_skip(&self) -> bool {
21
+
matches!(self, Self::Skip)
22
+
}
23
+
24
+
pub fn requires_lexicon(&self) -> bool {
25
+
matches!(self, Self::Strict)
26
+
}
27
+
}
28
+
29
+
pub fn deserialize_validation_mode<'de, D>(deserializer: D) -> Result<ValidationMode, D::Error>
30
+
where
31
+
D: Deserializer<'de>,
32
+
{
33
+
let opt: Option<bool> = Option::deserialize(deserializer)?;
34
+
Ok(ValidationMode::from_optional_bool(opt))
35
+
}
+51
-77
crates/tranquil-pds/src/api/repo/record/write.rs
+51
-77
crates/tranquil-pds/src/api/repo/record/write.rs
···
1
1
use super::validation::validate_record_with_status;
2
+
use super::validation_mode::{ValidationMode, deserialize_validation_mode};
2
3
use crate::api::error::ApiError;
3
4
use crate::api::repo::record::utils::{
4
5
CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids,
5
6
};
6
-
use crate::auth::{Active, Auth};
7
+
use crate::auth::{
8
+
Active, Auth, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated,
9
+
require_verified_or_delegated,
10
+
};
11
+
use crate::cid_types::CommitCid;
7
12
use crate::delegation::DelegationActionType;
8
13
use crate::repo::tracking::TrackingBlockStore;
9
14
use crate::state::AppState;
···
26
31
pub struct RepoWriteAuth {
27
32
pub did: Did,
28
33
pub user_id: Uuid,
29
-
pub current_root_cid: Cid,
34
+
pub current_root_cid: CommitCid,
30
35
pub is_oauth: bool,
31
36
pub scope: Option<String>,
32
37
pub controller_did: Option<Did>,
33
38
}
34
39
35
-
pub async fn prepare_repo_write(
40
+
pub async fn prepare_repo_write<A: RepoScopeAction>(
36
41
state: &AppState,
37
-
auth_user: &crate::auth::AuthenticatedUser,
42
+
scope_proof: &ScopeVerified<'_, A>,
38
43
repo: &AtIdentifier,
39
44
) -> Result<RepoWriteAuth, Response> {
40
-
if repo.as_str() != auth_user.did.as_str() {
45
+
let user = scope_proof.user();
46
+
let principal_did = scope_proof.principal_did();
47
+
if repo.as_str() != principal_did.as_str() {
41
48
return Err(
42
49
ApiError::InvalidRepo("Repo does not match authenticated user".into()).into_response(),
43
50
);
44
51
}
45
-
if state
46
-
.user_repo
47
-
.is_account_migrated(&auth_user.did)
48
-
.await
49
-
.unwrap_or(false)
50
-
{
51
-
return Err(ApiError::AccountMigrated.into_response());
52
-
}
53
-
let is_verified = state
54
-
.user_repo
55
-
.has_verified_comms_channel(&auth_user.did)
56
-
.await
57
-
.unwrap_or(false);
58
-
let is_delegated = state
59
-
.delegation_repo
60
-
.is_delegated_account(&auth_user.did)
61
-
.await
62
-
.unwrap_or(false);
63
-
if !is_verified && !is_delegated {
64
-
return Err(ApiError::AccountNotVerified.into_response());
65
-
}
52
+
53
+
require_not_migrated(state, principal_did.as_did()).await?;
54
+
let _account_verified = require_verified_or_delegated(state, user).await?;
55
+
66
56
let user_id = state
67
57
.user_repo
68
-
.get_id_by_did(&auth_user.did)
58
+
.get_id_by_did(principal_did.as_did())
69
59
.await
70
60
.map_err(|e| {
71
61
error!("DB error fetching user: {}", e);
···
83
73
.ok_or_else(|| {
84
74
ApiError::InternalError(Some("Repo root not found".into())).into_response()
85
75
})?;
86
-
let current_root_cid = Cid::from_str(&root_cid_str).map_err(|_| {
76
+
let current_root_cid = CommitCid::from_str(&root_cid_str).map_err(|_| {
87
77
ApiError::InternalError(Some("Invalid repo root CID".into())).into_response()
88
78
})?;
89
79
Ok(RepoWriteAuth {
90
-
did: auth_user.did.clone(),
80
+
did: principal_did.into_did(),
91
81
user_id,
92
82
current_root_cid,
93
-
is_oauth: auth_user.is_oauth(),
94
-
scope: auth_user.scope.clone(),
95
-
controller_did: auth_user.controller_did.clone(),
83
+
is_oauth: user.is_oauth(),
84
+
scope: user.scope.clone(),
85
+
controller_did: scope_proof.controller_did().map(|c| c.into_did()),
96
86
})
97
87
}
98
88
#[derive(Deserialize)]
···
101
91
pub repo: AtIdentifier,
102
92
pub collection: Nsid,
103
93
pub rkey: Option<Rkey>,
104
-
pub validate: Option<bool>,
94
+
#[serde(default, deserialize_with = "deserialize_validation_mode")]
95
+
pub validate: ValidationMode,
105
96
pub record: serde_json::Value,
106
97
#[serde(rename = "swapCommit")]
107
98
pub swap_commit: Option<String>,
···
127
118
auth: Auth<Active>,
128
119
Json(input): Json<CreateRecordInput>,
129
120
) -> Result<Response, crate::api::error::ApiError> {
130
-
let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await {
121
+
let scope_proof = match auth.verify_repo_create(&input.collection) {
122
+
Ok(proof) => proof,
123
+
Err(e) => return Ok(e.into_response()),
124
+
};
125
+
126
+
let repo_auth = match prepare_repo_write(&state, &scope_proof, &input.repo).await {
131
127
Ok(res) => res,
132
128
Err(err_res) => return Ok(err_res),
133
129
};
134
130
135
-
if let Err(e) = crate::auth::scope_check::check_repo_scope(
136
-
repo_auth.is_oauth,
137
-
repo_auth.scope.as_deref(),
138
-
crate::oauth::RepoAction::Create,
139
-
&input.collection,
140
-
) {
141
-
return Ok(e);
142
-
}
143
-
144
131
let did = repo_auth.did;
145
132
let user_id = repo_auth.user_id;
146
133
let current_root_cid = repo_auth.current_root_cid;
147
134
let controller_did = repo_auth.controller_did;
148
135
149
136
if let Some(swap_commit) = &input.swap_commit
150
-
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid)
137
+
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(¤t_root_cid)
151
138
{
152
139
return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response());
153
140
}
154
141
155
-
let validation_status = if input.validate == Some(false) {
142
+
let validation_status = if input.validate.should_skip() {
156
143
None
157
144
} else {
158
-
let require_lexicon = input.validate == Some(true);
159
145
match validate_record_with_status(
160
146
&input.record,
161
147
&input.collection,
162
148
input.rkey.as_ref(),
163
-
require_lexicon,
149
+
input.validate.requires_lexicon(),
164
150
) {
165
151
Ok(status) => Some(status),
166
152
Err(err_response) => return Ok(*err_response),
···
169
155
let rkey = input.rkey.unwrap_or_else(Rkey::generate);
170
156
171
157
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
172
-
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
158
+
let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await {
173
159
Ok(Some(b)) => b,
174
160
_ => {
175
161
return Ok(
···
192
178
let mut conflict_uris_to_cleanup: Vec<AtUri> = Vec::new();
193
179
let mut all_old_mst_blocks = std::collections::BTreeMap::new();
194
180
195
-
if input.validate != Some(false) {
181
+
if !input.validate.should_skip() {
196
182
let record_uri = AtUri::from_parts(&did, &input.collection, &rkey);
197
183
let backlinks = extract_backlinks(&record_uri, &input.record);
198
184
···
323
309
.collect();
324
310
let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
325
311
let blob_cids = extract_blob_cids(&input.record);
326
-
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid)
312
+
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid())
327
313
.chain(
328
314
all_old_mst_blocks
329
315
.keys()
···
337
323
CommitParams {
338
324
did: &did,
339
325
user_id,
340
-
current_root_cid: Some(current_root_cid),
326
+
current_root_cid: Some(current_root_cid.into_cid()),
341
327
prev_data_cid: Some(initial_mst_root),
342
328
new_mst_root,
343
329
ops,
···
412
398
pub repo: AtIdentifier,
413
399
pub collection: Nsid,
414
400
pub rkey: Rkey,
415
-
pub validate: Option<bool>,
401
+
#[serde(default, deserialize_with = "deserialize_validation_mode")]
402
+
pub validate: ValidationMode,
416
403
pub record: serde_json::Value,
417
404
#[serde(rename = "swapCommit")]
418
405
pub swap_commit: Option<String>,
···
434
421
auth: Auth<Active>,
435
422
Json(input): Json<PutRecordInput>,
436
423
) -> Result<Response, crate::api::error::ApiError> {
437
-
let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await {
424
+
let upsert_proof = match auth.verify_repo_upsert(&input.collection) {
425
+
Ok(proof) => proof,
426
+
Err(e) => return Ok(e.into_response()),
427
+
};
428
+
429
+
let repo_auth = match prepare_repo_write(&state, &upsert_proof, &input.repo).await {
438
430
Ok(res) => res,
439
431
Err(err_res) => return Ok(err_res),
440
432
};
441
433
442
-
if let Err(e) = crate::auth::scope_check::check_repo_scope(
443
-
repo_auth.is_oauth,
444
-
repo_auth.scope.as_deref(),
445
-
crate::oauth::RepoAction::Create,
446
-
&input.collection,
447
-
) {
448
-
return Ok(e);
449
-
}
450
-
if let Err(e) = crate::auth::scope_check::check_repo_scope(
451
-
repo_auth.is_oauth,
452
-
repo_auth.scope.as_deref(),
453
-
crate::oauth::RepoAction::Update,
454
-
&input.collection,
455
-
) {
456
-
return Ok(e);
457
-
}
458
-
459
434
let did = repo_auth.did;
460
435
let user_id = repo_auth.user_id;
461
436
let current_root_cid = repo_auth.current_root_cid;
462
437
let controller_did = repo_auth.controller_did;
463
438
464
439
if let Some(swap_commit) = &input.swap_commit
465
-
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid)
440
+
&& CommitCid::from_str(swap_commit).ok().as_ref() != Some(¤t_root_cid)
466
441
{
467
442
return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response());
468
443
}
469
444
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
470
-
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
445
+
let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await {
471
446
Ok(Some(b)) => b,
472
447
_ => {
473
448
return Ok(
···
485
460
};
486
461
let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
487
462
let key = format!("{}/{}", input.collection, input.rkey);
488
-
let validation_status = if input.validate == Some(false) {
463
+
let validation_status = if input.validate.should_skip() {
489
464
None
490
465
} else {
491
-
let require_lexicon = input.validate == Some(true);
492
466
match validate_record_with_status(
493
467
&input.record,
494
468
&input.collection,
495
469
Some(&input.rkey),
496
-
require_lexicon,
470
+
input.validate.requires_lexicon(),
497
471
) {
498
472
Ok(status) => Some(status),
499
473
Err(err_response) => return Ok(*err_response),
···
610
584
let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
611
585
let is_update = existing_cid.is_some();
612
586
let blob_cids = extract_blob_cids(&input.record);
613
-
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid)
587
+
let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid())
614
588
.chain(
615
589
old_mst_blocks
616
590
.keys()
···
624
598
CommitParams {
625
599
did: &did,
626
600
user_id,
627
-
current_root_cid: Some(current_root_cid),
601
+
current_root_cid: Some(current_root_cid.into_cid()),
628
602
prev_data_cid: Some(commit.data),
629
603
new_mst_root,
630
604
ops: vec![op],
+31
-32
crates/tranquil-pds/src/api/server/account_status.rs
+31
-32
crates/tranquil-pds/src/api/server/account_status.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
3
-
use crate::auth::{Auth, NotTakendown, Permissive};
2
+
use crate::api::error::{ApiError, DbResultExt};
3
+
use crate::auth::{Auth, NotTakendown, Permissive, require_legacy_session_mfa};
4
4
use crate::cache::Cache;
5
5
use crate::plc::PlcClient;
6
6
use crate::state::AppState;
7
7
use crate::types::PlainPassword;
8
+
use crate::util::pds_hostname;
8
9
use axum::{
9
10
Json,
10
11
extract::State,
···
130
131
did: &crate::types::Did,
131
132
with_retry: bool,
132
133
) -> Result<(), ApiError> {
133
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
134
+
let hostname = pds_hostname();
134
135
let expected_endpoint = format!("https://{}", hostname);
135
136
136
137
if did.as_str().starts_with("did:plc:") {
···
219
220
.and_then(|v| v.get("atproto"))
220
221
.and_then(|k| k.as_str());
221
222
222
-
let user_key = user_repo.get_user_key_by_did(did).await.map_err(|e| {
223
-
error!("Failed to fetch user key: {:?}", e);
224
-
ApiError::InternalError(None)
225
-
})?;
223
+
let user_key = user_repo
224
+
.get_user_key_by_did(did)
225
+
.await
226
+
.log_db_err("fetching user key")?;
226
227
227
228
if let Some(key_info) = user_key {
228
229
let key_bytes =
···
379
380
"[MIGRATION] activateAccount: Sequencing account event (active=true) for did={}",
380
381
did
381
382
);
382
-
if let Err(e) =
383
-
crate::api::repo::record::sequence_account_event(&state, &did, true, None).await
383
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
384
+
&state,
385
+
&did,
386
+
tranquil_db_traits::AccountStatus::Active,
387
+
)
388
+
.await
384
389
{
385
390
warn!(
386
391
"[MIGRATION] activateAccount: Failed to sequence account activation event: {}",
···
502
507
if let Err(e) = crate::api::repo::record::sequence_account_event(
503
508
&state,
504
509
&did,
505
-
false,
506
-
Some("deactivated"),
510
+
tranquil_db_traits::AccountStatus::Deactivated,
507
511
)
508
512
.await
509
513
{
···
523
527
State(state): State<AppState>,
524
528
auth: Auth<NotTakendown>,
525
529
) -> Result<Response, ApiError> {
526
-
let did = &auth.did;
527
-
528
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, did).await {
529
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
530
-
&*state.user_repo,
531
-
&*state.session_repo,
532
-
did,
533
-
)
534
-
.await);
535
-
}
530
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
531
+
Ok(proof) => proof,
532
+
Err(response) => return Ok(response),
533
+
};
536
534
537
535
let user_id = state
538
536
.user_repo
539
-
.get_id_by_did(did)
537
+
.get_id_by_did(session_mfa.did())
540
538
.await
541
539
.ok()
542
540
.flatten()
···
545
543
let expires_at = Utc::now() + Duration::minutes(15);
546
544
state
547
545
.infra_repo
548
-
.create_deletion_request(&confirmation_token, did, expires_at)
546
+
.create_deletion_request(&confirmation_token, session_mfa.did(), expires_at)
549
547
.await
550
-
.map_err(|e| {
551
-
error!("DB error creating deletion token: {:?}", e);
552
-
ApiError::InternalError(None)
553
-
})?;
554
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
548
+
.log_db_err("creating deletion token")?;
549
+
let hostname = pds_hostname();
555
550
if let Err(e) = crate::comms::comms_repo::enqueue_account_deletion(
556
551
state.user_repo.as_ref(),
557
552
state.infra_repo.as_ref(),
558
553
user_id,
559
554
&confirmation_token,
560
-
&hostname,
555
+
hostname,
561
556
)
562
557
.await
563
558
{
564
559
warn!("Failed to enqueue account deletion notification: {:?}", e);
565
560
}
566
-
info!("Account deletion requested for user {}", did);
561
+
info!("Account deletion requested for user {}", session_mfa.did());
567
562
Ok(EmptyResponse::ok().into_response())
568
563
}
569
564
···
642
637
error!("DB error deleting account: {:?}", e);
643
638
return ApiError::InternalError(None).into_response();
644
639
}
645
-
let account_seq =
646
-
crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await;
640
+
let account_seq = crate::api::repo::record::sequence_account_event(
641
+
&state,
642
+
did,
643
+
tranquil_db_traits::AccountStatus::Deleted,
644
+
)
645
+
.await;
647
646
match account_seq {
648
647
Ok(seq) => {
649
648
if let Err(e) = state.repo_repo.delete_sequences_except(did, seq).await {
+19
-51
crates/tranquil-pds/src/api/server/app_password.rs
+19
-51
crates/tranquil-pds/src/api/server/app_password.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::auth::{Auth, NotTakendown, Permissive, generate_app_password};
4
4
use crate::delegation::{DelegationActionType, intersect_scopes};
5
-
use crate::state::{AppState, RateLimitKind};
5
+
use crate::rate_limit::{AppPasswordLimit, RateLimited};
6
+
use crate::state::AppState;
6
7
use axum::{
7
8
Json,
8
9
extract::State,
9
-
http::HeaderMap,
10
10
response::{IntoResponse, Response},
11
11
};
12
12
use serde::{Deserialize, Serialize};
13
13
use serde_json::json;
14
-
use tracing::{error, warn};
14
+
use tracing::error;
15
15
use tranquil_db_traits::AppPasswordCreate;
16
16
17
17
#[derive(Serialize)]
···
39
39
.user_repo
40
40
.get_by_did(&auth.did)
41
41
.await
42
-
.map_err(|e| {
43
-
error!("DB error getting user: {:?}", e);
44
-
ApiError::InternalError(None)
45
-
})?
42
+
.log_db_err("getting user")?
46
43
.ok_or(ApiError::AccountNotFound)?;
47
44
48
45
let rows = state
49
46
.session_repo
50
47
.list_app_passwords(user.id)
51
48
.await
52
-
.map_err(|e| {
53
-
error!("DB error listing app passwords: {:?}", e);
54
-
ApiError::InternalError(None)
55
-
})?;
49
+
.log_db_err("listing app passwords")?;
56
50
let passwords: Vec<AppPassword> = rows
57
51
.iter()
58
52
.map(|row| AppPassword {
59
53
name: row.name.clone(),
60
54
created_at: row.created_at.to_rfc3339(),
61
-
privileged: row.privileged,
55
+
privileged: row.privilege.is_privileged(),
62
56
scopes: row.scopes.clone(),
63
57
created_by_controller: row
64
58
.created_by_controller_did
···
89
83
90
84
pub async fn create_app_password(
91
85
State(state): State<AppState>,
92
-
headers: HeaderMap,
86
+
_rate_limit: RateLimited<AppPasswordLimit>,
93
87
auth: Auth<NotTakendown>,
94
88
Json(input): Json<CreateAppPasswordInput>,
95
89
) -> Result<Response, ApiError> {
96
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
97
-
if !state
98
-
.check_rate_limit(RateLimitKind::AppPassword, &client_ip)
99
-
.await
100
-
{
101
-
warn!(ip = %client_ip, "App password creation rate limit exceeded");
102
-
return Err(ApiError::RateLimitExceeded(None));
103
-
}
104
-
105
90
let user = state
106
91
.user_repo
107
92
.get_by_did(&auth.did)
108
93
.await
109
-
.map_err(|e| {
110
-
error!("DB error getting user: {:?}", e);
111
-
ApiError::InternalError(None)
112
-
})?
94
+
.log_db_err("getting user")?
113
95
.ok_or(ApiError::AccountNotFound)?;
114
96
115
97
let name = input.name.trim();
···
121
103
.session_repo
122
104
.get_app_password_by_name(user.id, name)
123
105
.await
124
-
.map_err(|e| {
125
-
error!("DB error checking app password: {:?}", e);
126
-
ApiError::InternalError(None)
127
-
})?
106
+
.log_db_err("checking app password")?
128
107
.is_some()
129
108
{
130
109
return Err(ApiError::DuplicateAppPassword);
···
140
119
let granted_scopes = grant.map(|g| g.granted_scopes).unwrap_or_default();
141
120
142
121
let requested = input.scopes.as_deref().unwrap_or("atproto");
143
-
let intersected = intersect_scopes(requested, &granted_scopes);
122
+
let intersected = intersect_scopes(requested, granted_scopes.as_str());
144
123
145
124
if intersected.is_empty() && !granted_scopes.is_empty() {
146
125
return Err(ApiError::InsufficientScope(None));
···
171
150
ApiError::InternalError(None)
172
151
})?;
173
152
174
-
let privileged = input.privileged.unwrap_or(false);
153
+
let privilege =
154
+
tranquil_db_traits::AppPasswordPrivilege::from(input.privileged.unwrap_or(false));
175
155
let created_at = chrono::Utc::now();
176
156
177
157
let create_data = AppPasswordCreate {
178
158
user_id: user.id,
179
159
name: name.to_string(),
180
160
password_hash,
181
-
privileged,
161
+
privilege,
182
162
scopes: final_scopes.clone(),
183
163
created_by_controller_did: controller_did.clone(),
184
164
};
···
187
167
.session_repo
188
168
.create_app_password(&create_data)
189
169
.await
190
-
.map_err(|e| {
191
-
error!("DB error creating app password: {:?}", e);
192
-
ApiError::InternalError(None)
193
-
})?;
170
+
.log_db_err("creating app password")?;
194
171
195
172
if let Some(ref controller) = controller_did {
196
173
let _ = state
···
214
191
name: name.to_string(),
215
192
password,
216
193
created_at: created_at.to_rfc3339(),
217
-
privileged,
194
+
privileged: privilege.is_privileged(),
218
195
scopes: final_scopes,
219
196
})
220
197
.into_response())
···
234
211
.user_repo
235
212
.get_by_did(&auth.did)
236
213
.await
237
-
.map_err(|e| {
238
-
error!("DB error getting user: {:?}", e);
239
-
ApiError::InternalError(None)
240
-
})?
214
+
.log_db_err("getting user")?
241
215
.ok_or(ApiError::AccountNotFound)?;
242
216
243
217
let name = input.name.trim();
···
255
229
.session_repo
256
230
.delete_sessions_by_app_password(&auth.did, name)
257
231
.await
258
-
.map_err(|e| {
259
-
error!("DB error revoking sessions for app password: {:?}", e);
260
-
ApiError::InternalError(None)
261
-
})?;
232
+
.log_db_err("revoking sessions for app password")?;
262
233
263
234
futures::future::join_all(sessions_to_invalidate.iter().map(|jti| {
264
235
let cache_key = format!("auth:session:{}:{}", &auth.did, jti);
···
273
244
.session_repo
274
245
.delete_app_password(user.id, name)
275
246
.await
276
-
.map_err(|e| {
277
-
error!("DB error revoking app password: {:?}", e);
278
-
ApiError::InternalError(None)
279
-
})?;
247
+
.log_db_err("revoking app password")?;
280
248
281
249
Ok(EmptyResponse::ok().into_response())
282
250
}
+21
-92
crates/tranquil-pds/src/api/server/email.rs
+21
-92
crates/tranquil-pds/src/api/server/email.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::api::{EmptyResponse, TokenRequiredResponse, VerifiedResponse};
3
3
use crate::auth::{Auth, NotTakendown};
4
-
use crate::state::{AppState, RateLimitKind};
4
+
use crate::rate_limit::{EmailUpdateLimit, RateLimited, VerificationCheckLimit};
5
+
use crate::state::AppState;
6
+
use crate::util::pds_hostname;
5
7
use axum::{
6
8
Json,
7
9
extract::State,
···
44
46
45
47
pub async fn request_email_update(
46
48
State(state): State<AppState>,
47
-
headers: axum::http::HeaderMap,
49
+
_rate_limit: RateLimited<EmailUpdateLimit>,
48
50
auth: Auth<NotTakendown>,
49
51
input: Option<Json<RequestEmailUpdateInput>>,
50
52
) -> Result<Response, ApiError> {
51
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
52
-
if !state
53
-
.check_rate_limit(RateLimitKind::EmailUpdate, &client_ip)
54
-
.await
55
-
{
56
-
warn!(ip = %client_ip, "Email update rate limit exceeded");
57
-
return Err(ApiError::RateLimitExceeded(None));
58
-
}
59
-
60
53
if let Err(e) = crate::auth::scope_check::check_account_scope(
61
54
auth.is_oauth(),
62
55
auth.scope.as_deref(),
···
70
63
.user_repo
71
64
.get_email_info_by_did(&auth.did)
72
65
.await
73
-
.map_err(|e| {
74
-
error!("DB error: {:?}", e);
75
-
ApiError::InternalError(None)
76
-
})?
66
+
.log_db_err("getting email info")?
77
67
.ok_or(ApiError::AccountNotFound)?;
78
68
79
69
let Some(current_email) = user.email else {
···
111
101
}
112
102
}
113
103
114
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
104
+
let hostname = pds_hostname();
115
105
if let Err(e) = crate::comms::comms_repo::enqueue_email_update_token(
116
106
state.user_repo.as_ref(),
117
107
state.infra_repo.as_ref(),
118
108
user.id,
119
109
&code,
120
110
&formatted_code,
121
-
&hostname,
111
+
hostname,
122
112
)
123
113
.await
124
114
{
···
139
129
140
130
pub async fn confirm_email(
141
131
State(state): State<AppState>,
142
-
headers: axum::http::HeaderMap,
132
+
_rate_limit: RateLimited<EmailUpdateLimit>,
143
133
auth: Auth<NotTakendown>,
144
134
Json(input): Json<ConfirmEmailInput>,
145
135
) -> Result<Response, ApiError> {
146
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
147
-
if !state
148
-
.check_rate_limit(RateLimitKind::EmailUpdate, &client_ip)
149
-
.await
150
-
{
151
-
warn!(ip = %client_ip, "Confirm email rate limit exceeded");
152
-
return Err(ApiError::RateLimitExceeded(None));
153
-
}
154
-
155
136
if let Err(e) = crate::auth::scope_check::check_account_scope(
156
137
auth.is_oauth(),
157
138
auth.scope.as_deref(),
···
166
147
.user_repo
167
148
.get_email_info_by_did(did)
168
149
.await
169
-
.map_err(|e| {
170
-
error!("DB error: {:?}", e);
171
-
ApiError::InternalError(None)
172
-
})?
150
+
.log_db_err("getting email info")?
173
151
.ok_or(ApiError::AccountNotFound)?;
174
152
175
153
let Some(ref email) = user.email else {
···
213
191
.user_repo
214
192
.set_email_verified(user.id, true)
215
193
.await
216
-
.map_err(|e| {
217
-
error!("DB error confirming email: {:?}", e);
218
-
ApiError::InternalError(None)
219
-
})?;
194
+
.log_db_err("confirming email")?;
220
195
221
196
info!("Email confirmed for user {}", user.id);
222
197
Ok(EmptyResponse::ok().into_response())
···
250
225
.user_repo
251
226
.get_email_info_by_did(did)
252
227
.await
253
-
.map_err(|e| {
254
-
error!("DB error: {:?}", e);
255
-
ApiError::InternalError(None)
256
-
})?
228
+
.log_db_err("getting email info")?
257
229
.ok_or(ApiError::AccountNotFound)?;
258
230
259
231
let user_id = user.id;
···
325
297
.user_repo
326
298
.update_email(user_id, &new_email)
327
299
.await
328
-
.map_err(|e| {
329
-
error!("DB error updating email: {:?}", e);
330
-
ApiError::InternalError(None)
331
-
})?;
300
+
.log_db_err("updating email")?;
332
301
333
302
let verification_token =
334
303
crate::auth::verification_token::generate_signup_token(did, "email", &new_email);
335
304
let formatted_token =
336
305
crate::auth::verification_token::format_token_for_display(&verification_token);
337
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
306
+
let hostname = pds_hostname();
338
307
if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification(
339
308
state.infra_repo.as_ref(),
340
309
user_id,
341
310
"email",
342
311
&new_email,
343
312
&formatted_token,
344
-
&hostname,
313
+
hostname,
345
314
)
346
315
.await
347
316
{
···
371
340
372
341
pub async fn check_email_verified(
373
342
State(state): State<AppState>,
374
-
headers: axum::http::HeaderMap,
343
+
_rate_limit: RateLimited<VerificationCheckLimit>,
375
344
Json(input): Json<CheckEmailVerifiedInput>,
376
345
) -> Response {
377
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
378
-
if !state
379
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
380
-
.await
381
-
{
382
-
return ApiError::RateLimitExceeded(None).into_response();
383
-
}
384
-
385
346
match state
386
347
.user_repo
387
348
.check_email_verified_by_identifier(&input.identifier)
···
403
364
404
365
pub async fn authorize_email_update(
405
366
State(state): State<AppState>,
406
-
headers: axum::http::HeaderMap,
367
+
_rate_limit: RateLimited<VerificationCheckLimit>,
407
368
axum::extract::Query(query): axum::extract::Query<AuthorizeEmailUpdateQuery>,
408
369
) -> Response {
409
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
410
-
if !state
411
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
412
-
.await
413
-
{
414
-
return ApiError::RateLimitExceeded(None).into_response();
415
-
}
416
-
417
370
let verified = crate::auth::verification_token::verify_token_signature(&query.token);
418
371
419
372
let token_data = match verified {
···
488
441
489
442
info!(did = %did, "Email update authorized via link click");
490
443
491
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
444
+
let hostname = pds_hostname();
492
445
let redirect_url = format!(
493
446
"https://{}/app/verify?type=email-authorize-success",
494
447
hostname
···
499
452
500
453
pub async fn check_email_update_status(
501
454
State(state): State<AppState>,
502
-
headers: axum::http::HeaderMap,
455
+
_rate_limit: RateLimited<VerificationCheckLimit>,
503
456
auth: Auth<NotTakendown>,
504
457
) -> Result<Response, ApiError> {
505
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
506
-
if !state
507
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
508
-
.await
509
-
{
510
-
return Err(ApiError::RateLimitExceeded(None));
511
-
}
512
-
513
458
if let Err(e) = crate::auth::scope_check::check_account_scope(
514
459
auth.is_oauth(),
515
460
auth.scope.as_deref(),
···
549
494
550
495
pub async fn check_email_in_use(
551
496
State(state): State<AppState>,
552
-
headers: axum::http::HeaderMap,
497
+
_rate_limit: RateLimited<VerificationCheckLimit>,
553
498
Json(input): Json<CheckEmailInUseInput>,
554
499
) -> Response {
555
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
556
-
if !state
557
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
558
-
.await
559
-
{
560
-
return ApiError::RateLimitExceeded(None).into_response();
561
-
}
562
-
563
500
let email = input.email.trim().to_lowercase();
564
501
if email.is_empty() {
565
502
return ApiError::InvalidRequest("email is required".into()).into_response();
···
587
524
588
525
pub async fn check_comms_channel_in_use(
589
526
State(state): State<AppState>,
590
-
headers: axum::http::HeaderMap,
527
+
_rate_limit: RateLimited<VerificationCheckLimit>,
591
528
Json(input): Json<CheckCommsChannelInUseInput>,
592
529
) -> Response {
593
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
594
-
if !state
595
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
596
-
.await
597
-
{
598
-
return ApiError::RateLimitExceeded(None).into_response();
599
-
}
600
-
601
530
let channel = match input.channel.to_lowercase().as_str() {
602
531
"email" => CommsChannel::Email,
603
532
"discord" => CommsChannel::Discord,
+6
-10
crates/tranquil-pds/src/api/server/invite.rs
+6
-10
crates/tranquil-pds/src/api/server/invite.rs
···
1
1
use crate::api::ApiError;
2
+
use crate::api::error::DbResultExt;
2
3
use crate::auth::{Admin, Auth, NotTakendown};
3
4
use crate::state::AppState;
4
5
use crate::types::Did;
6
+
use crate::util::pds_hostname;
5
7
use axum::{
6
8
Json,
7
9
extract::State,
···
24
26
}
25
27
26
28
fn gen_invite_code() -> String {
27
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
29
+
let hostname = pds_hostname();
28
30
let hostname_prefix = hostname.replace('.', "-");
29
31
format!("{}-{}", hostname_prefix, gen_random_token())
30
32
}
···
121
123
.user_repo
122
124
.get_any_admin_user_id()
123
125
.await
124
-
.map_err(|e| {
125
-
error!("DB error looking up admin user: {:?}", e);
126
-
ApiError::InternalError(None)
127
-
})?
126
+
.log_db_err("looking up admin user")?
128
127
.ok_or_else(|| {
129
128
error!("No admin user found to create invite codes");
130
129
ApiError::InternalError(None)
···
202
201
.infra_repo
203
202
.get_invite_codes_for_account(&auth.did)
204
203
.await
205
-
.map_err(|e| {
206
-
error!("DB error fetching invite codes: {:?}", e);
207
-
ApiError::InternalError(None)
208
-
})?;
204
+
.log_db_err("fetching invite codes")?;
209
205
210
206
let filtered_codes: Vec<_> = codes_info
211
207
.into_iter()
212
-
.filter(|info| !info.disabled)
208
+
.filter(|info| info.state.is_active())
213
209
.collect();
214
210
215
211
let codes = futures::future::join_all(filtered_codes.into_iter().map(|info| {
+1
-1
crates/tranquil-pds/src/api/server/logo.rs
+1
-1
crates/tranquil-pds/src/api/server/logo.rs
···
21
21
Some(c) if !c.is_empty() => c,
22
22
_ => return StatusCode::NOT_FOUND.into_response(),
23
23
};
24
-
let cid = crate::types::CidLink::new_unchecked(&cid_str);
24
+
let cid = unsafe { crate::types::CidLink::new_unchecked(&cid_str) };
25
25
26
26
let metadata = match state.blob_repo.get_blob_metadata(&cid).await {
27
27
Ok(Some(m)) => m,
+3
-2
crates/tranquil-pds/src/api/server/meta.rs
+3
-2
crates/tranquil-pds/src/api/server/meta.rs
···
1
1
use crate::state::AppState;
2
+
use crate::util::pds_hostname;
2
3
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};
3
4
use serde_json::json;
4
5
···
30
31
}
31
32
32
33
pub async fn describe_server() -> impl IntoResponse {
33
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
34
+
let pds_hostname = pds_hostname();
34
35
let domains_str =
35
-
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.clone());
36
+
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.to_string());
36
37
let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect();
37
38
let invite_code_required = std::env::var("INVITE_CODE_REQUIRED")
38
39
.map(|v| v == "true" || v == "1")
+6
-13
crates/tranquil-pds/src/api/server/migration.rs
+6
-13
crates/tranquil-pds/src/api/server/migration.rs
···
1
1
use crate::api::ApiError;
2
+
use crate::api::error::DbResultExt;
2
3
use crate::auth::{Active, Auth};
3
4
use crate::state::AppState;
5
+
use crate::util::pds_hostname;
4
6
use axum::{
5
7
Json,
6
8
extract::State,
···
49
51
.user_repo
50
52
.get_user_for_did_doc(&auth.did)
51
53
.await
52
-
.map_err(|e| {
53
-
tracing::error!("DB error getting user: {:?}", e);
54
-
ApiError::InternalError(None)
55
-
})?
54
+
.log_db_err("getting user")?
56
55
.ok_or(ApiError::AccountNotFound)?;
57
56
58
57
if let Some(ref methods) = input.verification_methods {
···
107
106
.user_repo
108
107
.upsert_did_web_overrides(user.id, verification_methods_json, also_known_as)
109
108
.await
110
-
.map_err(|e| {
111
-
tracing::error!("DB error upserting did_web_overrides: {:?}", e);
112
-
ApiError::InternalError(None)
113
-
})?;
109
+
.log_db_err("upserting did_web_overrides")?;
114
110
115
111
if let Some(ref endpoint) = input.service_endpoint {
116
112
let endpoint_clean = endpoint.trim().trim_end_matches('/');
···
118
114
.user_repo
119
115
.update_migrated_to_pds(&auth.did, endpoint_clean)
120
116
.await
121
-
.map_err(|e| {
122
-
tracing::error!("DB error updating service endpoint: {:?}", e);
123
-
ApiError::InternalError(None)
124
-
})?;
117
+
.log_db_err("updating service endpoint")?;
125
118
}
126
119
127
120
let did_doc = build_did_document(&state, &auth.did).await;
···
154
147
}
155
148
156
149
async fn build_did_document(state: &AppState, did: &crate::types::Did) -> serde_json::Value {
157
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
150
+
let hostname = pds_hostname();
158
151
159
152
let user = match state.user_repo.get_user_for_did_doc_build(did).await {
160
153
Ok(Some(row)) => row,
+38
-87
crates/tranquil-pds/src/api/server/passkey_account.rs
+38
-87
crates/tranquil-pds/src/api/server/passkey_account.rs
···
1
1
use crate::api::SuccessResponse;
2
2
use crate::api::error::ApiError;
3
+
use crate::auth::NormalizedLoginIdentifier;
3
4
use axum::{
4
5
Json,
5
6
extract::State,
···
19
20
20
21
use crate::api::repo::record::utils::create_signed_commit;
21
22
use crate::auth::{ServiceTokenVerifier, generate_app_password, is_service_token};
22
-
use crate::state::{AppState, RateLimitKind};
23
+
use crate::rate_limit::{AccountCreationLimit, PasswordResetLimit, RateLimited};
24
+
use crate::state::AppState;
23
25
use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey};
26
+
use crate::util::{pds_hostname, pds_hostname_without_port};
24
27
use crate::validation::validate_password;
25
28
26
-
fn extract_client_ip(headers: &HeaderMap) -> String {
27
-
if let Some(forwarded) = headers.get("x-forwarded-for")
28
-
&& let Ok(value) = forwarded.to_str()
29
-
&& let Some(first_ip) = value.split(',').next()
30
-
{
31
-
return first_ip.trim().to_string();
32
-
}
33
-
if let Some(real_ip) = headers.get("x-real-ip")
34
-
&& let Ok(value) = real_ip.to_str()
35
-
{
36
-
return value.trim().to_string();
37
-
}
38
-
"unknown".to_string()
39
-
}
40
-
41
29
fn generate_setup_token() -> String {
42
30
let mut rng = rand::thread_rng();
43
31
(0..32)
···
80
68
81
69
pub async fn create_passkey_account(
82
70
State(state): State<AppState>,
71
+
_rate_limit: RateLimited<AccountCreationLimit>,
83
72
headers: HeaderMap,
84
73
Json(input): Json<CreatePasskeyAccountInput>,
85
74
) -> Response {
86
-
let client_ip = extract_client_ip(&headers);
87
-
if !state
88
-
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
89
-
.await
90
-
{
91
-
warn!(ip = %client_ip, "Account creation rate limit exceeded");
92
-
return ApiError::RateLimitExceeded(Some(
93
-
"Too many account creation attempts. Please try again later.".into(),
94
-
))
95
-
.into_response();
96
-
}
97
-
98
75
let byod_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header(
99
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
76
+
crate::util::get_header_str(&headers, "Authorization"),
100
77
) {
101
78
let token = extracted.token;
102
79
if is_service_token(&token) {
···
135
112
.map(|d| d.starts_with("did:web:"))
136
113
.unwrap_or(false);
137
114
138
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
139
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
115
+
let hostname = pds_hostname();
116
+
let hostname_for_handles = pds_hostname_without_port();
140
117
let pds_suffix = format!(".{}", hostname_for_handles);
141
118
142
119
let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) {
···
169
146
return ApiError::InvalidEmail.into_response();
170
147
}
171
148
172
-
if let Some(ref code) = input.invite_code {
173
-
let valid = state
174
-
.infra_repo
175
-
.is_invite_code_valid(code)
176
-
.await
177
-
.unwrap_or(false);
178
-
179
-
if !valid {
180
-
return ApiError::InvalidInviteCode.into_response();
149
+
let _validated_invite_code = if let Some(ref code) = input.invite_code {
150
+
match state.infra_repo.validate_invite_code(code).await {
151
+
Ok(validated) => Some(validated),
152
+
Err(_) => return ApiError::InvalidInviteCode.into_response(),
181
153
}
182
154
} else {
183
155
let invite_required = std::env::var("INVITE_CODE_REQUIRED")
···
186
158
if invite_required {
187
159
return ApiError::InviteCodeRequired.into_response();
188
160
}
189
-
}
161
+
None
162
+
};
190
163
191
164
let verification_channel = input.verification_channel.as_deref().unwrap_or("email");
192
165
let verification_recipient = match verification_channel {
···
268
241
}
269
242
if is_byod_did_web {
270
243
if let Some(ref auth_did) = byod_auth
271
-
&& d != auth_did
244
+
&& d != auth_did.as_str()
272
245
{
273
246
return ApiError::AuthorizationError(format!(
274
247
"Service token issuer {} does not match DID {}",
···
280
253
} else {
281
254
if let Err(e) = crate::api::identity::did::verify_did_web(
282
255
d,
283
-
&hostname,
256
+
hostname,
284
257
&input.handle,
285
258
input.signing_key.as_deref(),
286
259
)
···
296
269
if let Some(ref auth_did) = byod_auth {
297
270
if let Some(ref provided_did) = input.did {
298
271
if provided_did.starts_with("did:plc:") {
299
-
if provided_did != auth_did {
272
+
if provided_did != auth_did.as_str() {
300
273
return ApiError::AuthorizationError(format!(
301
274
"Service token issuer {} does not match DID {}",
302
275
auth_did, provided_did
···
389
362
}
390
363
};
391
364
let rev = Tid::now(LimitedU32::MIN);
392
-
let did_typed = Did::new_unchecked(&did);
365
+
let did_typed = unsafe { Did::new_unchecked(&did) };
393
366
let (commit_bytes, _sig) =
394
367
match create_signed_commit(&did_typed, mst_root, rev.as_ref(), None, &secret_key) {
395
368
Ok(result) => result,
···
422
395
_ => tranquil_db_traits::CommsChannel::Email,
423
396
};
424
397
425
-
let handle_typed = Handle::new_unchecked(&handle);
398
+
let handle_typed = unsafe { Handle::new_unchecked(&handle) };
426
399
let create_input = tranquil_db_traits::CreatePasskeyAccountInput {
427
400
handle: handle_typed.clone(),
428
401
email: email.clone().unwrap_or_default(),
···
484
457
{
485
458
warn!("Failed to sequence identity event for {}: {}", did, e);
486
459
}
487
-
if let Err(e) =
488
-
crate::api::repo::record::sequence_account_event(&state, &did_typed, true, None).await
460
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
461
+
&state,
462
+
&did_typed,
463
+
tranquil_db_traits::AccountStatus::Active,
464
+
)
465
+
.await
489
466
{
490
467
warn!("Failed to sequence account event for {}: {}", did, e);
491
468
}
···
493
470
"$type": "app.bsky.actor.profile",
494
471
"displayName": handle
495
472
});
496
-
let profile_collection = Nsid::new_unchecked("app.bsky.actor.profile");
497
-
let profile_rkey = Rkey::new_unchecked("self");
473
+
let profile_collection = unsafe { Nsid::new_unchecked("app.bsky.actor.profile") };
474
+
let profile_rkey = unsafe { Rkey::new_unchecked("self") };
498
475
if let Err(e) = crate::api::repo::record::create_record_internal(
499
476
&state,
500
477
&did_typed,
···
521
498
verification_channel,
522
499
&verification_recipient,
523
500
&formatted_token,
524
-
&hostname,
501
+
hostname,
525
502
)
526
503
.await
527
504
{
···
541
518
refresh_jti,
542
519
access_expires_at: token_meta.expires_at,
543
520
refresh_expires_at: refresh_expires,
544
-
legacy_login: false,
521
+
login_type: tranquil_db::LoginType::Modern,
545
522
mfa_verified: false,
546
523
scope: None,
547
524
controller_did: None,
···
626
603
return ApiError::InvalidToken(None).into_response();
627
604
}
628
605
629
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
630
-
let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) {
631
-
Ok(w) => w,
632
-
Err(e) => {
633
-
error!("Failed to create WebAuthn config: {:?}", e);
634
-
return ApiError::InternalError(None).into_response();
635
-
}
636
-
};
606
+
let webauthn = &state.webauthn_config;
637
607
638
608
let reg_state = match state
639
609
.user_repo
···
768
738
return ApiError::InvalidToken(None).into_response();
769
739
}
770
740
771
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
772
-
let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) {
773
-
Ok(w) => w,
774
-
Err(e) => {
775
-
error!("Failed to create WebAuthn config: {:?}", e);
776
-
return ApiError::InternalError(None).into_response();
777
-
}
778
-
};
741
+
let webauthn = &state.webauthn_config;
779
742
780
743
let existing_passkeys = state
781
744
.user_repo
···
840
803
841
804
pub async fn request_passkey_recovery(
842
805
State(state): State<AppState>,
843
-
headers: HeaderMap,
806
+
_rate_limit: RateLimited<PasswordResetLimit>,
844
807
Json(input): Json<RequestPasskeyRecoveryInput>,
845
808
) -> Response {
846
-
let client_ip = extract_client_ip(&headers);
847
-
if !state
848
-
.check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
849
-
.await
850
-
{
851
-
return ApiError::RateLimitExceeded(None).into_response();
852
-
}
853
-
854
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
855
-
let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname);
809
+
let hostname_for_handles = pds_hostname_without_port();
856
810
let identifier = input.email.trim().to_lowercase();
857
811
let identifier = identifier.strip_prefix('@').unwrap_or(&identifier);
858
-
let normalized_handle = if identifier.contains('@') || identifier.contains('.') {
859
-
identifier.to_string()
860
-
} else {
861
-
format!("{}.{}", identifier, hostname_for_handles)
862
-
};
812
+
let normalized_handle =
813
+
NormalizedLoginIdentifier::normalize(&input.email, hostname_for_handles);
863
814
864
815
let user = match state
865
816
.user_repo
866
-
.get_user_for_passkey_recovery(identifier, &normalized_handle)
817
+
.get_user_for_passkey_recovery(identifier, normalized_handle.as_str())
867
818
.await
868
819
{
869
820
Ok(Some(u)) if !u.password_required => u,
···
890
841
return ApiError::InternalError(None).into_response();
891
842
}
892
843
893
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
844
+
let hostname = pds_hostname();
894
845
let recovery_url = format!(
895
846
"https://{}/app/recover-passkey?did={}&token={}",
896
847
hostname,
···
903
854
state.infra_repo.as_ref(),
904
855
user.id,
905
856
&recovery_url,
906
-
&hostname,
857
+
hostname,
907
858
)
908
859
.await;
909
860
+20
-56
crates/tranquil-pds/src/api/server/passkeys.rs
+20
-56
crates/tranquil-pds/src/api/server/passkeys.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
3
-
use crate::auth::webauthn::WebAuthnConfig;
4
-
use crate::auth::{Active, Auth};
2
+
use crate::api::error::{ApiError, DbResultExt};
3
+
use crate::auth::{Active, Auth, require_legacy_session_mfa, require_reauth_window};
5
4
use crate::state::AppState;
6
5
use axum::{
7
6
Json,
···
12
11
use tracing::{error, info, warn};
13
12
use webauthn_rs::prelude::*;
14
13
15
-
fn get_webauthn() -> Result<WebAuthnConfig, ApiError> {
16
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
17
-
WebAuthnConfig::new(&hostname).map_err(|e| {
18
-
error!("Failed to create WebAuthn config: {}", e);
19
-
ApiError::InternalError(Some("WebAuthn configuration failed".into()))
20
-
})
21
-
}
22
-
23
14
#[derive(Deserialize)]
24
15
#[serde(rename_all = "camelCase")]
25
16
pub struct StartRegistrationInput {
···
37
28
auth: Auth<Active>,
38
29
Json(input): Json<StartRegistrationInput>,
39
30
) -> Result<Response, ApiError> {
40
-
let webauthn = get_webauthn()?;
31
+
let webauthn = &state.webauthn_config;
41
32
42
33
let handle = state
43
34
.user_repo
44
35
.get_handle_by_did(&auth.did)
45
36
.await
46
-
.map_err(|e| {
47
-
error!("DB error fetching user: {:?}", e);
48
-
ApiError::InternalError(None)
49
-
})?
37
+
.log_db_err("fetching user")?
50
38
.ok_or(ApiError::AccountNotFound)?;
51
39
52
40
let existing_passkeys = state
53
41
.user_repo
54
42
.get_passkeys_for_user(&auth.did)
55
43
.await
56
-
.map_err(|e| {
57
-
error!("DB error fetching existing passkeys: {:?}", e);
58
-
ApiError::InternalError(None)
59
-
})?;
44
+
.log_db_err("fetching existing passkeys")?;
60
45
61
46
let exclude_credentials: Vec<CredentialID> = existing_passkeys
62
47
.iter()
···
81
66
.user_repo
82
67
.save_webauthn_challenge(&auth.did, "registration", &state_json)
83
68
.await
84
-
.map_err(|e| {
85
-
error!("Failed to save registration state: {:?}", e);
86
-
ApiError::InternalError(None)
87
-
})?;
69
+
.log_db_err("saving registration state")?;
88
70
89
71
let options = serde_json::to_value(&ccr).unwrap_or(serde_json::json!({}));
90
72
···
112
94
auth: Auth<Active>,
113
95
Json(input): Json<FinishRegistrationInput>,
114
96
) -> Result<Response, ApiError> {
115
-
let webauthn = get_webauthn()?;
97
+
let webauthn = &state.webauthn_config;
116
98
117
99
let reg_state_json = state
118
100
.user_repo
119
101
.load_webauthn_challenge(&auth.did, "registration")
120
102
.await
121
-
.map_err(|e| {
122
-
error!("DB error loading registration state: {:?}", e);
123
-
ApiError::InternalError(None)
124
-
})?
103
+
.log_db_err("loading registration state")?
125
104
.ok_or(ApiError::NoRegistrationInProgress)?;
126
105
127
106
let reg_state: SecurityKeyRegistration =
···
157
136
input.friendly_name.as_deref(),
158
137
)
159
138
.await
160
-
.map_err(|e| {
161
-
error!("Failed to save passkey: {:?}", e);
162
-
ApiError::InternalError(None)
163
-
})?;
139
+
.log_db_err("saving passkey")?;
164
140
165
141
if let Err(e) = state
166
142
.user_repo
···
208
184
.user_repo
209
185
.get_passkeys_for_user(&auth.did)
210
186
.await
211
-
.map_err(|e| {
212
-
error!("DB error fetching passkeys: {:?}", e);
213
-
ApiError::InternalError(None)
214
-
})?;
187
+
.log_db_err("fetching passkeys")?;
215
188
216
189
let passkey_infos: Vec<PasskeyInfo> = passkeys
217
190
.into_iter()
···
241
214
auth: Auth<Active>,
242
215
Json(input): Json<DeletePasskeyInput>,
243
216
) -> Result<Response, ApiError> {
244
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
245
-
{
246
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
247
-
&*state.user_repo,
248
-
&*state.session_repo,
249
-
&auth.did,
250
-
)
251
-
.await);
252
-
}
217
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
218
+
Ok(proof) => proof,
219
+
Err(response) => return Ok(response),
220
+
};
253
221
254
-
if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await {
255
-
return Ok(crate::api::server::reauth::reauth_required_response(
256
-
&*state.user_repo,
257
-
&*state.session_repo,
258
-
&auth.did,
259
-
)
260
-
.await);
261
-
}
222
+
let reauth_mfa = match require_reauth_window(&state, &auth).await {
223
+
Ok(proof) => proof,
224
+
Err(response) => return Ok(response),
225
+
};
262
226
263
227
let id: uuid::Uuid = input.id.parse().map_err(|_| ApiError::InvalidId)?;
264
228
265
-
match state.user_repo.delete_passkey(id, &auth.did).await {
229
+
match state.user_repo.delete_passkey(id, reauth_mfa.did()).await {
266
230
Ok(true) => {
267
-
info!(did = %auth.did, passkey_id = %id, "Passkey deleted");
231
+
info!(did = %session_mfa.did(), passkey_id = %id, "Passkey deleted");
268
232
Ok(EmptyResponse::ok().into_response())
269
233
}
270
234
Ok(false) => Err(ApiError::PasskeyNotFound),
+64
-170
crates/tranquil-pds/src/api/server/password.rs
+64
-170
crates/tranquil-pds/src/api/server/password.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::api::{EmptyResponse, HasPasswordResponse, SuccessResponse};
3
-
use crate::auth::{Active, Auth};
4
-
use crate::state::{AppState, RateLimitKind};
3
+
use crate::auth::{
4
+
Active, Auth, NormalizedLoginIdentifier, require_legacy_session_mfa, require_reauth_window,
5
+
require_reauth_window_if_available,
6
+
};
7
+
use crate::rate_limit::{PasswordResetLimit, RateLimited, ResetPasswordLimit};
8
+
use crate::state::AppState;
5
9
use crate::types::PlainPassword;
10
+
use crate::util::{pds_hostname, pds_hostname_without_port};
6
11
use crate::validation::validate_password;
7
12
use axum::{
8
13
Json,
9
14
extract::State,
10
-
http::HeaderMap,
11
15
response::{IntoResponse, Response},
12
16
};
13
-
use bcrypt::{DEFAULT_COST, hash, verify};
17
+
use bcrypt::{DEFAULT_COST, hash};
14
18
use chrono::{Duration, Utc};
15
19
use serde::Deserialize;
16
20
use tracing::{error, info, warn};
···
18
22
fn generate_reset_code() -> String {
19
23
crate::util::generate_token_code()
20
24
}
21
-
fn extract_client_ip(headers: &HeaderMap) -> String {
22
-
if let Some(forwarded) = headers.get("x-forwarded-for")
23
-
&& let Ok(value) = forwarded.to_str()
24
-
&& let Some(first_ip) = value.split(',').next()
25
-
{
26
-
return first_ip.trim().to_string();
27
-
}
28
-
if let Some(real_ip) = headers.get("x-real-ip")
29
-
&& let Ok(value) = real_ip.to_str()
30
-
{
31
-
return value.trim().to_string();
32
-
}
33
-
"unknown".to_string()
34
-
}
35
25
36
26
#[derive(Deserialize)]
37
27
pub struct RequestPasswordResetInput {
···
41
31
42
32
pub async fn request_password_reset(
43
33
State(state): State<AppState>,
44
-
headers: HeaderMap,
34
+
_rate_limit: RateLimited<PasswordResetLimit>,
45
35
Json(input): Json<RequestPasswordResetInput>,
46
36
) -> Response {
47
-
let client_ip = extract_client_ip(&headers);
48
-
if !state
49
-
.check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
50
-
.await
51
-
{
52
-
warn!(ip = %client_ip, "Password reset rate limit exceeded");
53
-
return ApiError::RateLimitExceeded(None).into_response();
54
-
}
55
37
let identifier = input.email.trim();
56
38
if identifier.is_empty() {
57
39
return ApiError::InvalidRequest("email or handle is required".into()).into_response();
58
40
}
59
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
60
-
let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname);
41
+
let hostname_for_handles = pds_hostname_without_port();
61
42
let normalized = identifier.to_lowercase();
62
43
let normalized = normalized.strip_prefix('@').unwrap_or(&normalized);
63
44
let is_email_lookup = normalized.contains('@');
64
-
let normalized_handle = if normalized.contains('@') || normalized.contains('.') {
65
-
normalized.to_string()
66
-
} else {
67
-
format!("{}.{}", normalized, hostname_for_handles)
68
-
};
45
+
let normalized_handle = NormalizedLoginIdentifier::normalize(identifier, hostname_for_handles);
69
46
70
47
let multiple_accounts_warning = if is_email_lookup {
71
48
match state.user_repo.count_accounts_by_email(normalized).await {
···
78
55
79
56
let user_id = match state
80
57
.user_repo
81
-
.get_id_by_email_or_handle(normalized, &normalized_handle)
58
+
.get_id_by_email_or_handle(normalized, normalized_handle.as_str())
82
59
.await
83
60
{
84
61
Ok(Some(id)) => id,
···
101
78
error!("DB error setting reset code: {:?}", e);
102
79
return ApiError::InternalError(None).into_response();
103
80
}
104
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
81
+
let hostname = pds_hostname();
105
82
if let Err(e) = crate::comms::comms_repo::enqueue_password_reset(
106
83
state.user_repo.as_ref(),
107
84
state.infra_repo.as_ref(),
108
85
user_id,
109
86
&code,
110
-
&hostname,
87
+
hostname,
111
88
)
112
89
.await
113
90
{
···
135
112
136
113
pub async fn reset_password(
137
114
State(state): State<AppState>,
138
-
headers: HeaderMap,
115
+
_rate_limit: RateLimited<ResetPasswordLimit>,
139
116
Json(input): Json<ResetPasswordInput>,
140
117
) -> Response {
141
-
let client_ip = extract_client_ip(&headers);
142
-
if !state
143
-
.check_rate_limit(RateLimitKind::ResetPassword, &client_ip)
144
-
.await
145
-
{
146
-
warn!(ip = %client_ip, "Reset password rate limit exceeded");
147
-
return ApiError::RateLimitExceeded(None).into_response();
148
-
}
149
118
let token = input.token.trim();
150
119
let password = &input.password;
151
120
if token.is_empty() {
···
230
199
auth: Auth<Active>,
231
200
Json(input): Json<ChangePasswordInput>,
232
201
) -> Result<Response, ApiError> {
233
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
234
-
{
235
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
236
-
&*state.user_repo,
237
-
&*state.session_repo,
238
-
&auth.did,
239
-
)
240
-
.await);
241
-
}
202
+
use crate::auth::verify_password_mfa;
242
203
243
-
let current_password = &input.current_password;
244
-
let new_password = &input.new_password;
245
-
if current_password.is_empty() {
204
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
205
+
Ok(proof) => proof,
206
+
Err(response) => return Ok(response),
207
+
};
208
+
209
+
if input.current_password.is_empty() {
246
210
return Err(ApiError::InvalidRequest(
247
211
"currentPassword is required".into(),
248
212
));
249
213
}
250
-
if new_password.is_empty() {
214
+
if input.new_password.is_empty() {
251
215
return Err(ApiError::InvalidRequest("newPassword is required".into()));
252
216
}
253
-
if let Err(e) = validate_password(new_password) {
217
+
if let Err(e) = validate_password(&input.new_password) {
254
218
return Err(ApiError::InvalidRequest(e.to_string()));
255
219
}
220
+
221
+
let password_mfa = verify_password_mfa(&state, &auth, &input.current_password).await?;
222
+
256
223
let user = state
257
224
.user_repo
258
-
.get_id_and_password_hash_by_did(&auth.did)
225
+
.get_id_and_password_hash_by_did(password_mfa.did())
259
226
.await
260
-
.map_err(|e| {
261
-
error!("DB error in change_password: {:?}", e);
262
-
ApiError::InternalError(None)
263
-
})?
227
+
.log_db_err("in change_password")?
264
228
.ok_or(ApiError::AccountNotFound)?;
265
229
266
-
let (user_id, password_hash) = (user.id, user.password_hash);
267
-
let valid = verify(current_password, &password_hash).map_err(|e| {
268
-
error!("Password verification error: {:?}", e);
269
-
ApiError::InternalError(None)
270
-
})?;
271
-
if !valid {
272
-
return Err(ApiError::InvalidPassword(
273
-
"Current password is incorrect".into(),
274
-
));
275
-
}
276
-
let new_password_clone = new_password.to_string();
230
+
let new_password_clone = input.new_password.to_string();
277
231
let new_hash = tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST))
278
232
.await
279
233
.map_err(|e| {
···
287
241
288
242
state
289
243
.user_repo
290
-
.update_password_hash(user_id, &new_hash)
244
+
.update_password_hash(user.id, &new_hash)
291
245
.await
292
-
.map_err(|e| {
293
-
error!("DB error updating password: {:?}", e);
294
-
ApiError::InternalError(None)
295
-
})?;
246
+
.log_db_err("updating password")?;
296
247
297
-
info!(did = %&auth.did, "Password changed successfully");
248
+
info!(did = %session_mfa.did(), "Password changed successfully");
298
249
Ok(EmptyResponse::ok().into_response())
299
250
}
300
251
···
302
253
State(state): State<AppState>,
303
254
auth: Auth<Active>,
304
255
) -> Result<Response, ApiError> {
305
-
match state.user_repo.has_password_by_did(&auth.did).await {
306
-
Ok(Some(has)) => Ok(HasPasswordResponse::response(has).into_response()),
307
-
Ok(None) => Err(ApiError::AccountNotFound),
308
-
Err(e) => {
309
-
error!("DB error: {:?}", e);
310
-
Err(ApiError::InternalError(None))
311
-
}
312
-
}
256
+
let has = state
257
+
.user_repo
258
+
.has_password_by_did(&auth.did)
259
+
.await
260
+
.log_db_err("checking password status")?
261
+
.ok_or(ApiError::AccountNotFound)?;
262
+
Ok(HasPasswordResponse::response(has).into_response())
313
263
}
314
264
315
265
pub async fn remove_password(
316
266
State(state): State<AppState>,
317
267
auth: Auth<Active>,
318
268
) -> Result<Response, ApiError> {
319
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
320
-
{
321
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
322
-
&*state.user_repo,
323
-
&*state.session_repo,
324
-
&auth.did,
325
-
)
326
-
.await);
327
-
}
269
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
270
+
Ok(proof) => proof,
271
+
Err(response) => return Ok(response),
272
+
};
328
273
329
-
if crate::api::server::reauth::check_reauth_required_cached(
330
-
&*state.session_repo,
331
-
&state.cache,
332
-
&auth.did,
333
-
)
334
-
.await
335
-
{
336
-
return Ok(crate::api::server::reauth::reauth_required_response(
337
-
&*state.user_repo,
338
-
&*state.session_repo,
339
-
&auth.did,
340
-
)
341
-
.await);
342
-
}
274
+
let reauth_mfa = match require_reauth_window(&state, &auth).await {
275
+
Ok(proof) => proof,
276
+
Err(response) => return Ok(response),
277
+
};
343
278
344
279
let has_passkeys = state
345
280
.user_repo
346
-
.has_passkeys(&auth.did)
281
+
.has_passkeys(reauth_mfa.did())
347
282
.await
348
283
.unwrap_or(false);
349
284
if !has_passkeys {
···
354
289
355
290
let user = state
356
291
.user_repo
357
-
.get_password_info_by_did(&auth.did)
292
+
.get_password_info_by_did(reauth_mfa.did())
358
293
.await
359
-
.map_err(|e| {
360
-
error!("DB error: {:?}", e);
361
-
ApiError::InternalError(None)
362
-
})?
294
+
.log_db_err("getting password info")?
363
295
.ok_or(ApiError::AccountNotFound)?;
364
296
365
297
if user.password_hash.is_none() {
···
372
304
.user_repo
373
305
.remove_user_password(user.id)
374
306
.await
375
-
.map_err(|e| {
376
-
error!("DB error removing password: {:?}", e);
377
-
ApiError::InternalError(None)
378
-
})?;
307
+
.log_db_err("removing password")?;
379
308
380
-
info!(did = %&auth.did, "Password removed - account is now passkey-only");
309
+
info!(did = %session_mfa.did(), "Password removed - account is now passkey-only");
381
310
Ok(SuccessResponse::ok().into_response())
382
311
}
383
312
···
392
321
auth: Auth<Active>,
393
322
Json(input): Json<SetPasswordInput>,
394
323
) -> Result<Response, ApiError> {
395
-
let has_password = state
396
-
.user_repo
397
-
.has_password_by_did(&auth.did)
398
-
.await
399
-
.ok()
400
-
.flatten()
401
-
.unwrap_or(false);
402
-
let has_passkeys = state
403
-
.user_repo
404
-
.has_passkeys(&auth.did)
405
-
.await
406
-
.unwrap_or(false);
407
-
let has_totp = state
408
-
.user_repo
409
-
.has_totp_enabled(&auth.did)
410
-
.await
411
-
.unwrap_or(false);
412
-
413
-
let has_any_reauth_method = has_password || has_passkeys || has_totp;
414
-
415
-
if has_any_reauth_method
416
-
&& crate::api::server::reauth::check_reauth_required_cached(
417
-
&*state.session_repo,
418
-
&state.cache,
419
-
&auth.did,
420
-
)
421
-
.await
422
-
{
423
-
return Ok(crate::api::server::reauth::reauth_required_response(
424
-
&*state.user_repo,
425
-
&*state.session_repo,
426
-
&auth.did,
427
-
)
428
-
.await);
429
-
}
324
+
let reauth_mfa = match require_reauth_window_if_available(&state, &auth).await {
325
+
Ok(proof) => proof,
326
+
Err(response) => return Ok(response),
327
+
};
430
328
431
329
let new_password = &input.new_password;
432
330
if new_password.is_empty() {
···
436
334
return Err(ApiError::InvalidRequest(e.to_string()));
437
335
}
438
336
337
+
let did = reauth_mfa.as_ref().map(|m| m.did()).unwrap_or(&auth.did);
338
+
439
339
let user = state
440
340
.user_repo
441
-
.get_password_info_by_did(&auth.did)
341
+
.get_password_info_by_did(did)
442
342
.await
443
-
.map_err(|e| {
444
-
error!("DB error: {:?}", e);
445
-
ApiError::InternalError(None)
446
-
})?
343
+
.log_db_err("getting password info")?
447
344
.ok_or(ApiError::AccountNotFound)?;
448
345
449
346
if user.password_hash.is_some() {
···
468
365
.user_repo
469
366
.set_new_user_password(user.id, &new_hash)
470
367
.await
471
-
.map_err(|e| {
472
-
error!("DB error setting password: {:?}", e);
473
-
ApiError::InternalError(None)
474
-
})?;
368
+
.log_db_err("setting password")?;
475
369
476
-
info!(did = %&auth.did, "Password set for passkey-only account");
370
+
info!(did = %did, "Password set for passkey-only account");
477
371
Ok(SuccessResponse::ok().into_response())
478
372
}
+22
-59
crates/tranquil-pds/src/api/server/reauth.rs
+22
-59
crates/tranquil-pds/src/api/server/reauth.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use axum::{
3
3
Json,
4
4
extract::State,
···
11
11
use tranquil_db_traits::{SessionRepository, UserRepository};
12
12
13
13
use crate::auth::{Active, Auth};
14
-
use crate::state::{AppState, RateLimitKind};
14
+
use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message};
15
+
use crate::state::AppState;
15
16
use crate::types::PlainPassword;
16
17
17
-
const REAUTH_WINDOW_SECONDS: i64 = 300;
18
+
pub const REAUTH_WINDOW_SECONDS: i64 = 300;
18
19
19
20
#[derive(Serialize)]
20
21
#[serde(rename_all = "camelCase")]
···
32
33
.session_repo
33
34
.get_last_reauth_at(&auth.did)
34
35
.await
35
-
.map_err(|e| {
36
-
error!("DB error: {:?}", e);
37
-
ApiError::InternalError(None)
38
-
})?;
36
+
.log_db_err("getting last reauth")?;
39
37
40
38
let reauth_required = is_reauth_required(last_reauth_at);
41
39
let available_methods =
···
70
68
.user_repo
71
69
.get_password_hash_by_did(&auth.did)
72
70
.await
73
-
.map_err(|e| {
74
-
error!("DB error: {:?}", e);
75
-
ApiError::InternalError(None)
76
-
})?
71
+
.log_db_err("fetching password hash")?
77
72
.ok_or(ApiError::AccountNotFound)?;
78
73
79
74
let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false);
···
97
92
98
93
let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did)
99
94
.await
100
-
.map_err(|e| {
101
-
error!("DB error updating reauth: {:?}", e);
102
-
ApiError::InternalError(None)
103
-
})?;
95
+
.log_db_err("updating reauth")?;
104
96
105
97
info!(did = %&auth.did, "Re-auth successful via password");
106
98
Ok(Json(ReauthResponse { reauthed_at }).into_response())
···
117
109
auth: Auth<Active>,
118
110
Json(input): Json<TotpReauthInput>,
119
111
) -> Result<Response, ApiError> {
120
-
if !state
121
-
.check_rate_limit(RateLimitKind::TotpVerify, &auth.did)
122
-
.await
123
-
{
124
-
warn!(did = %&auth.did, "TOTP verification rate limit exceeded");
125
-
return Err(ApiError::RateLimitExceeded(Some(
126
-
"Too many verification attempts. Please try again in a few minutes.".into(),
127
-
)));
128
-
}
112
+
let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>(
113
+
&state,
114
+
&auth.did,
115
+
"Too many verification attempts. Please try again in a few minutes.",
116
+
)
117
+
.await?;
129
118
130
119
let valid =
131
120
crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.did, &input.code)
···
140
129
141
130
let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did)
142
131
.await
143
-
.map_err(|e| {
144
-
error!("DB error updating reauth: {:?}", e);
145
-
ApiError::InternalError(None)
146
-
})?;
132
+
.log_db_err("updating reauth")?;
147
133
148
134
info!(did = %&auth.did, "Re-auth successful via TOTP");
149
135
Ok(Json(ReauthResponse { reauthed_at }).into_response())
···
159
145
State(state): State<AppState>,
160
146
auth: Auth<Active>,
161
147
) -> Result<Response, ApiError> {
162
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
163
-
164
148
let stored_passkeys = state
165
149
.user_repo
166
150
.get_passkeys_for_user(&auth.did)
167
151
.await
168
-
.map_err(|e| {
169
-
error!("Failed to get passkeys: {:?}", e);
170
-
ApiError::InternalError(None)
171
-
})?;
152
+
.log_db_err("getting passkeys")?;
172
153
173
154
if stored_passkeys.is_empty() {
174
155
return Err(ApiError::NoPasskeys);
···
185
166
)));
186
167
}
187
168
188
-
let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| {
189
-
error!("Failed to create WebAuthn config: {:?}", e);
190
-
ApiError::InternalError(None)
191
-
})?;
169
+
let webauthn = &state.webauthn_config;
192
170
193
171
let (rcr, auth_state) = webauthn.start_authentication(passkeys).map_err(|e| {
194
172
error!("Failed to start passkey authentication: {:?}", e);
···
204
182
.user_repo
205
183
.save_webauthn_challenge(&auth.did, "authentication", &state_json)
206
184
.await
207
-
.map_err(|e| {
208
-
error!("Failed to save authentication state: {:?}", e);
209
-
ApiError::InternalError(None)
210
-
})?;
185
+
.log_db_err("saving authentication state")?;
211
186
212
187
let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({}));
213
188
Ok(Json(PasskeyReauthStartResponse { options }).into_response())
···
224
199
auth: Auth<Active>,
225
200
Json(input): Json<PasskeyReauthFinishInput>,
226
201
) -> Result<Response, ApiError> {
227
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
228
-
229
202
let auth_state_json = state
230
203
.user_repo
231
204
.load_webauthn_challenge(&auth.did, "authentication")
232
205
.await
233
-
.map_err(|e| {
234
-
error!("Failed to load authentication state: {:?}", e);
235
-
ApiError::InternalError(None)
236
-
})?
206
+
.log_db_err("loading authentication state")?
237
207
.ok_or(ApiError::NoChallengeInProgress)?;
238
208
239
209
let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication =
···
248
218
ApiError::InvalidCredential
249
219
})?;
250
220
251
-
let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| {
252
-
error!("Failed to create WebAuthn config: {:?}", e);
253
-
ApiError::InternalError(None)
254
-
})?;
255
-
256
-
let auth_result = webauthn
221
+
let auth_result = state
222
+
.webauthn_config
257
223
.finish_authentication(&credential, &auth_state)
258
224
.map_err(|e| {
259
225
warn!(did = %&auth.did, "Passkey re-auth failed: {:?}", e);
···
287
253
288
254
let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did)
289
255
.await
290
-
.map_err(|e| {
291
-
error!("DB error updating reauth: {:?}", e);
292
-
ApiError::InternalError(None)
293
-
})?;
256
+
.log_db_err("updating reauth")?;
294
257
295
258
info!(did = %&auth.did, "Re-auth successful via passkey");
296
259
Ok(Json(ReauthResponse { reauthed_at }).into_response())
···
418
381
) -> bool {
419
382
match session_repo.get_session_mfa_status(did).await {
420
383
Ok(Some(status)) => {
421
-
if !status.legacy_login {
384
+
if status.login_type.is_modern() {
422
385
return true;
423
386
}
424
387
if status.mfa_verified {
+3
-3
crates/tranquil-pds/src/api/server/service_auth.rs
+3
-3
crates/tranquil-pds/src/api/server/service_auth.rs
···
51
51
headers: axum::http::HeaderMap,
52
52
Query(params): Query<GetServiceAuthParams>,
53
53
) -> Response {
54
-
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
55
-
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
54
+
let auth_header = crate::util::get_header_str(&headers, "Authorization");
55
+
let dpop_proof = crate::util::get_header_str(&headers, "DPoP");
56
56
info!(
57
57
has_auth_header = auth_header.is_some(),
58
58
has_dpop_proof = dpop_proof.is_some(),
···
94
94
.await
95
95
{
96
96
Ok(result) => crate::auth::AuthenticatedUser {
97
-
did: Did::new_unchecked(result.did),
97
+
did: unsafe { Did::new_unchecked(result.did) },
98
98
is_admin: false,
99
99
status: AccountStatus::Active,
100
100
scope: result.scope,
+88
-161
crates/tranquil-pds/src/api/server/session.rs
+88
-161
crates/tranquil-pds/src/api/server/session.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::api::{EmptyResponse, SuccessResponse};
3
-
use crate::auth::{Active, Auth, Permissive};
4
-
use crate::state::{AppState, RateLimitKind};
3
+
use crate::auth::{
4
+
Active, Auth, NormalizedLoginIdentifier, Permissive, require_legacy_session_mfa,
5
+
require_reauth_window,
6
+
};
7
+
use crate::rate_limit::{LoginLimit, RateLimited, RefreshSessionLimit};
8
+
use crate::state::AppState;
5
9
use crate::types::{AccountState, Did, Handle, PlainPassword};
10
+
use crate::util::{pds_hostname, pds_hostname_without_port};
6
11
use axum::{
7
12
Json,
8
13
extract::State,
···
13
18
use serde::{Deserialize, Serialize};
14
19
use serde_json::json;
15
20
use tracing::{error, info, warn};
21
+
use tranquil_db_traits::{SessionId, TokenFamilyId};
16
22
use tranquil_types::TokenId;
17
23
18
-
fn extract_client_ip(headers: &HeaderMap) -> String {
19
-
if let Some(forwarded) = headers.get("x-forwarded-for")
20
-
&& let Ok(value) = forwarded.to_str()
21
-
&& let Some(first_ip) = value.split(',').next()
22
-
{
23
-
return first_ip.trim().to_string();
24
-
}
25
-
if let Some(real_ip) = headers.get("x-real-ip")
26
-
&& let Ok(value) = real_ip.to_str()
27
-
{
28
-
return value.trim().to_string();
29
-
}
30
-
"unknown".to_string()
31
-
}
32
-
33
-
fn normalize_handle(identifier: &str, pds_hostname: &str) -> String {
34
-
let identifier = identifier.trim();
35
-
if identifier.contains('@') || identifier.starts_with("did:") {
36
-
identifier.to_string()
37
-
} else if !identifier.contains('.') {
38
-
format!("{}.{}", identifier.to_lowercase(), pds_hostname)
39
-
} else {
40
-
identifier.to_lowercase()
41
-
}
42
-
}
43
-
44
24
fn full_handle(stored_handle: &str, _pds_hostname: &str) -> String {
45
25
stored_handle.to_string()
46
26
}
···
75
55
76
56
pub async fn create_session(
77
57
State(state): State<AppState>,
78
-
headers: HeaderMap,
58
+
rate_limit: RateLimited<LoginLimit>,
79
59
Json(input): Json<CreateSessionInput>,
80
60
) -> Response {
61
+
let client_ip = rate_limit.client_ip();
81
62
info!(
82
63
"create_session called with identifier: {}",
83
64
input.identifier
84
65
);
85
-
let client_ip = extract_client_ip(&headers);
86
-
if !state
87
-
.check_rate_limit(RateLimitKind::Login, &client_ip)
88
-
.await
89
-
{
90
-
warn!(ip = %client_ip, "Login rate limit exceeded");
91
-
return ApiError::RateLimitExceeded(None).into_response();
92
-
}
93
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
94
-
let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname);
95
-
let normalized_identifier = normalize_handle(&input.identifier, hostname_for_handles);
66
+
let pds_host = pds_hostname();
67
+
let hostname_for_handles = pds_hostname_without_port();
68
+
let normalized_identifier =
69
+
NormalizedLoginIdentifier::normalize(&input.identifier, hostname_for_handles);
96
70
info!(
97
71
"Normalized identifier: {} -> {}",
98
72
input.identifier, normalized_identifier
99
73
);
100
74
let row = match state
101
75
.user_repo
102
-
.get_login_full_by_identifier(&normalized_identifier)
76
+
.get_login_full_by_identifier(normalized_identifier.as_str())
103
77
.await
104
78
{
105
79
Ok(Some(row)) => row,
···
165
139
warn!("Login attempt for takendown account: {}", row.did);
166
140
return ApiError::AccountTakedown.into_response();
167
141
}
168
-
let is_verified =
169
-
row.email_verified || row.discord_verified || row.telegram_verified || row.signal_verified;
142
+
let is_verified = row.channel_verification.has_any_verified();
170
143
let is_delegated = state
171
144
.delegation_repo
172
145
.is_delegated_account(&row.did)
···
226
199
refresh_jti: refresh_meta.jti.clone(),
227
200
access_expires_at: access_meta.expires_at,
228
201
refresh_expires_at: refresh_meta.expires_at,
229
-
legacy_login: is_legacy_login,
202
+
login_type: tranquil_db_traits::LoginType::from(is_legacy_login),
230
203
mfa_verified: false,
231
204
scope: app_password_scopes.clone(),
232
205
controller_did: app_password_controller.clone(),
···
246
219
ip = %client_ip,
247
220
"Legacy login on TOTP-enabled account - sending notification"
248
221
);
249
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
222
+
let hostname = pds_hostname();
250
223
if let Err(e) = crate::comms::comms_repo::enqueue_legacy_login(
251
224
state.user_repo.as_ref(),
252
225
state.infra_repo.as_ref(),
253
226
row.id,
254
-
&hostname,
255
-
&client_ip,
227
+
hostname,
228
+
client_ip,
256
229
row.preferred_comms_channel,
257
230
)
258
231
.await
···
260
233
error!("Failed to queue legacy login notification: {:?}", e);
261
234
}
262
235
}
263
-
let handle = full_handle(&row.handle, &pds_hostname);
236
+
let handle = full_handle(&row.handle, pds_host);
264
237
let is_active = account_state.is_active();
265
238
let status = account_state.status_for_session().map(String::from);
266
239
Json(CreateSessionOutput {
···
270
243
did: row.did,
271
244
did_doc,
272
245
email: row.email,
273
-
email_confirmed: Some(row.email_verified),
246
+
email_confirmed: Some(row.channel_verification.email),
274
247
active: Some(is_active),
275
248
status,
276
249
})
···
292
265
);
293
266
match db_result {
294
267
Ok(Some(row)) => {
295
-
let (preferred_channel, preferred_channel_verified) = match row.preferred_comms_channel
296
-
{
297
-
tranquil_db_traits::CommsChannel::Email => ("email", row.email_verified),
298
-
tranquil_db_traits::CommsChannel::Discord => ("discord", row.discord_verified),
299
-
tranquil_db_traits::CommsChannel::Telegram => ("telegram", row.telegram_verified),
300
-
tranquil_db_traits::CommsChannel::Signal => ("signal", row.signal_verified),
268
+
let preferred_channel = match row.preferred_comms_channel {
269
+
tranquil_db_traits::CommsChannel::Email => "email",
270
+
tranquil_db_traits::CommsChannel::Discord => "discord",
271
+
tranquil_db_traits::CommsChannel::Telegram => "telegram",
272
+
tranquil_db_traits::CommsChannel::Signal => "signal",
301
273
};
302
-
let pds_hostname =
303
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
304
-
let handle = full_handle(&row.handle, &pds_hostname);
274
+
let preferred_channel_verified = row
275
+
.channel_verification
276
+
.is_verified(row.preferred_comms_channel);
277
+
let pds_hostname = pds_hostname();
278
+
let handle = full_handle(&row.handle, pds_hostname);
305
279
let account_state = AccountState::from_db_fields(
306
280
row.deactivated_at,
307
281
row.takedown_ref.clone(),
···
313
287
} else {
314
288
None
315
289
};
316
-
let email_confirmed_value = can_read_email && row.email_verified;
290
+
let email_confirmed_value = can_read_email && row.channel_verification.email;
317
291
let mut response = json!({
318
292
"handle": handle,
319
293
"did": &auth.did,
···
352
326
headers: axum::http::HeaderMap,
353
327
_auth: Auth<Active>,
354
328
) -> Result<Response, ApiError> {
355
-
let extracted = crate::auth::extract_auth_token_from_header(
356
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
357
-
)
329
+
let extracted = crate::auth::extract_auth_token_from_header(crate::util::get_header_str(
330
+
&headers,
331
+
"Authorization",
332
+
))
358
333
.ok_or(ApiError::AuthenticationRequired)?;
359
334
let jti = crate::auth::get_jti_from_token(&extracted.token)
360
335
.map_err(|_| ApiError::AuthenticationFailed(None))?;
···
374
349
375
350
pub async fn refresh_session(
376
351
State(state): State<AppState>,
352
+
_rate_limit: RateLimited<RefreshSessionLimit>,
377
353
headers: axum::http::HeaderMap,
378
354
) -> Response {
379
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
380
-
if !state
381
-
.check_rate_limit(RateLimitKind::RefreshSession, &client_ip)
382
-
.await
383
-
{
384
-
tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded");
385
-
return ApiError::RateLimitExceeded(None).into_response();
386
-
}
387
-
let extracted = match crate::auth::extract_auth_token_from_header(
388
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
389
-
) {
355
+
let extracted = match crate::auth::extract_auth_token_from_header(crate::util::get_header_str(
356
+
&headers,
357
+
"Authorization",
358
+
)) {
390
359
Some(t) => t,
391
360
None => return ApiError::AuthenticationRequired.into_response(),
392
361
};
···
503
472
);
504
473
match db_result {
505
474
Ok(Some(u)) => {
506
-
let (preferred_channel, preferred_channel_verified) = match u.preferred_comms_channel {
507
-
tranquil_db_traits::CommsChannel::Email => ("email", u.email_verified),
508
-
tranquil_db_traits::CommsChannel::Discord => ("discord", u.discord_verified),
509
-
tranquil_db_traits::CommsChannel::Telegram => ("telegram", u.telegram_verified),
510
-
tranquil_db_traits::CommsChannel::Signal => ("signal", u.signal_verified),
475
+
let preferred_channel = match u.preferred_comms_channel {
476
+
tranquil_db_traits::CommsChannel::Email => "email",
477
+
tranquil_db_traits::CommsChannel::Discord => "discord",
478
+
tranquil_db_traits::CommsChannel::Telegram => "telegram",
479
+
tranquil_db_traits::CommsChannel::Signal => "signal",
511
480
};
512
-
let pds_hostname =
513
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
514
-
let handle = full_handle(&u.handle, &pds_hostname);
481
+
let preferred_channel_verified = u
482
+
.channel_verification
483
+
.is_verified(u.preferred_comms_channel);
484
+
let pds_hostname = pds_hostname();
485
+
let handle = full_handle(&u.handle, pds_hostname);
515
486
let account_state =
516
487
AccountState::from_db_fields(u.deactivated_at, u.takedown_ref.clone(), None, None);
517
488
let mut response = json!({
···
520
491
"handle": handle,
521
492
"did": session_row.did,
522
493
"email": u.email,
523
-
"emailConfirmed": u.email_verified,
494
+
"emailConfirmed": u.channel_verification.email,
524
495
"preferredChannel": preferred_channel,
525
496
"preferredChannelVerified": preferred_channel_verified,
526
497
"preferredLocale": u.preferred_locale,
···
664
635
refresh_jti: refresh_meta.jti.clone(),
665
636
access_expires_at: access_meta.expires_at,
666
637
refresh_expires_at: refresh_meta.expires_at,
667
-
legacy_login: false,
638
+
login_type: tranquil_db_traits::LoginType::Modern,
668
639
mfa_verified: false,
669
640
scope: None,
670
641
controller_did: None,
···
675
646
return ApiError::InternalError(None).into_response();
676
647
}
677
648
678
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
649
+
let hostname = pds_hostname();
679
650
if let Err(e) = crate::comms::comms_repo::enqueue_welcome(
680
651
state.user_repo.as_ref(),
681
652
state.infra_repo.as_ref(),
682
653
row.id,
683
-
&hostname,
654
+
hostname,
684
655
)
685
656
.await
686
657
{
···
731
702
return ApiError::InternalError(None).into_response();
732
703
}
733
704
};
734
-
let is_verified =
735
-
row.email_verified || row.discord_verified || row.telegram_verified || row.signal_verified;
705
+
let is_verified = row.channel_verification.has_any_verified();
736
706
if is_verified {
737
707
return ApiError::InvalidRequest("Account is already verified".into()).into_response();
738
708
}
···
756
726
let formatted_token =
757
727
crate::auth::verification_token::format_token_for_display(&verification_token);
758
728
759
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
729
+
let hostname = pds_hostname();
760
730
if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification(
761
731
state.infra_repo.as_ref(),
762
732
row.id,
763
733
channel_str,
764
734
&recipient,
765
735
&formatted_token,
766
-
&hostname,
736
+
hostname,
767
737
)
768
738
.await
769
739
{
···
804
774
.session_repo
805
775
.list_sessions_by_did(&auth.did)
806
776
.await
807
-
.map_err(|e| {
808
-
error!("DB error fetching JWT sessions: {:?}", e);
809
-
ApiError::InternalError(None)
810
-
})?;
777
+
.log_db_err("fetching JWT sessions")?;
811
778
812
779
let oauth_rows = state
813
780
.oauth_repo
814
781
.list_sessions_by_did(&auth.did)
815
782
.await
816
-
.map_err(|e| {
817
-
error!("DB error fetching OAuth sessions: {:?}", e);
818
-
ApiError::InternalError(None)
819
-
})?;
783
+
.log_db_err("fetching OAuth sessions")?;
820
784
821
785
let jwt_sessions = jwt_rows.into_iter().map(|row| SessionInfo {
822
786
id: format!("jwt:{}", row.id),
···
869
833
Json(input): Json<RevokeSessionInput>,
870
834
) -> Result<Response, ApiError> {
871
835
if let Some(jwt_id) = input.session_id.strip_prefix("jwt:") {
872
-
let session_id: i32 = jwt_id
873
-
.parse()
836
+
let session_id = jwt_id
837
+
.parse::<i32>()
838
+
.map(SessionId::new)
874
839
.map_err(|_| ApiError::InvalidRequest("Invalid session ID".into()))?;
875
840
let access_jti = state
876
841
.session_repo
877
842
.get_session_access_jti_by_id(session_id, &auth.did)
878
843
.await
879
-
.map_err(|e| {
880
-
error!("DB error in revoke_session: {:?}", e);
881
-
ApiError::InternalError(None)
882
-
})?
844
+
.log_db_err("in revoke_session")?
883
845
.ok_or(ApiError::SessionNotFound)?;
884
846
state
885
847
.session_repo
886
848
.delete_session_by_id(session_id)
887
849
.await
888
-
.map_err(|e| {
889
-
error!("DB error deleting session: {:?}", e);
890
-
ApiError::InternalError(None)
891
-
})?;
850
+
.log_db_err("deleting session")?;
892
851
let cache_key = format!("auth:session:{}:{}", &auth.did, access_jti);
893
852
if let Err(e) = state.cache.delete(&cache_key).await {
894
853
warn!("Failed to invalidate session cache: {:?}", e);
895
854
}
896
855
info!(did = %&auth.did, session_id = %session_id, "JWT session revoked");
897
856
} else if let Some(oauth_id) = input.session_id.strip_prefix("oauth:") {
898
-
let session_id: i32 = oauth_id
899
-
.parse()
857
+
let session_id = oauth_id
858
+
.parse::<i32>()
859
+
.map(TokenFamilyId::new)
900
860
.map_err(|_| ApiError::InvalidRequest("Invalid session ID".into()))?;
901
861
let deleted = state
902
862
.oauth_repo
903
863
.delete_session_by_id(session_id, &auth.did)
904
864
.await
905
-
.map_err(|e| {
906
-
error!("DB error deleting OAuth session: {:?}", e);
907
-
ApiError::InternalError(None)
908
-
})?;
865
+
.log_db_err("deleting OAuth session")?;
909
866
if deleted == 0 {
910
867
return Err(ApiError::SessionNotFound);
911
868
}
···
932
889
.session_repo
933
890
.delete_sessions_by_did(&auth.did)
934
891
.await
935
-
.map_err(|e| {
936
-
error!("DB error revoking JWT sessions: {:?}", e);
937
-
ApiError::InternalError(None)
938
-
})?;
892
+
.log_db_err("revoking JWT sessions")?;
939
893
let jti_typed = TokenId::from(jti.clone());
940
894
state
941
895
.oauth_repo
942
896
.delete_sessions_by_did_except(&auth.did, &jti_typed)
943
897
.await
944
-
.map_err(|e| {
945
-
error!("DB error revoking OAuth sessions: {:?}", e);
946
-
ApiError::InternalError(None)
947
-
})?;
898
+
.log_db_err("revoking OAuth sessions")?;
948
899
} else {
949
900
state
950
901
.session_repo
951
902
.delete_sessions_by_did_except_jti(&auth.did, &jti)
952
903
.await
953
-
.map_err(|e| {
954
-
error!("DB error revoking JWT sessions: {:?}", e);
955
-
ApiError::InternalError(None)
956
-
})?;
904
+
.log_db_err("revoking JWT sessions")?;
957
905
state
958
906
.oauth_repo
959
907
.delete_sessions_by_did(&auth.did)
960
908
.await
961
-
.map_err(|e| {
962
-
error!("DB error revoking OAuth sessions: {:?}", e);
963
-
ApiError::InternalError(None)
964
-
})?;
909
+
.log_db_err("revoking OAuth sessions")?;
965
910
}
966
911
967
912
info!(did = %&auth.did, "All other sessions revoked");
···
983
928
.user_repo
984
929
.get_legacy_login_pref(&auth.did)
985
930
.await
986
-
.map_err(|e| {
987
-
error!("DB error: {:?}", e);
988
-
ApiError::InternalError(None)
989
-
})?
931
+
.log_db_err("getting legacy login pref")?
990
932
.ok_or(ApiError::AccountNotFound)?;
991
933
Ok(Json(LegacyLoginPreferenceOutput {
992
934
allow_legacy_login: pref.allow_legacy_login,
···
1006
948
auth: Auth<Active>,
1007
949
Json(input): Json<UpdateLegacyLoginInput>,
1008
950
) -> Result<Response, ApiError> {
1009
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
1010
-
{
1011
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
1012
-
&*state.user_repo,
1013
-
&*state.session_repo,
1014
-
&auth.did,
1015
-
)
1016
-
.await);
1017
-
}
951
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
952
+
Ok(proof) => proof,
953
+
Err(response) => return Ok(response),
954
+
};
1018
955
1019
-
if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await {
1020
-
return Ok(crate::api::server::reauth::reauth_required_response(
1021
-
&*state.user_repo,
1022
-
&*state.session_repo,
1023
-
&auth.did,
1024
-
)
1025
-
.await);
1026
-
}
956
+
let reauth_mfa = match require_reauth_window(&state, &auth).await {
957
+
Ok(proof) => proof,
958
+
Err(response) => return Ok(response),
959
+
};
1027
960
1028
961
let updated = state
1029
962
.user_repo
1030
-
.update_legacy_login(&auth.did, input.allow_legacy_login)
963
+
.update_legacy_login(reauth_mfa.did(), input.allow_legacy_login)
1031
964
.await
1032
-
.map_err(|e| {
1033
-
error!("DB error: {:?}", e);
1034
-
ApiError::InternalError(None)
1035
-
})?;
965
+
.log_db_err("updating legacy login")?;
1036
966
if !updated {
1037
967
return Err(ApiError::AccountNotFound);
1038
968
}
1039
969
info!(
1040
-
did = %&auth.did,
970
+
did = %session_mfa.did(),
1041
971
allow_legacy_login = input.allow_legacy_login,
1042
972
"Legacy login preference updated"
1043
973
);
···
1071
1001
.user_repo
1072
1002
.update_locale(&auth.did, &input.preferred_locale)
1073
1003
.await
1074
-
.map_err(|e| {
1075
-
error!("DB error updating locale: {:?}", e);
1076
-
ApiError::InternalError(None)
1077
-
})?;
1004
+
.log_db_err("updating locale")?;
1078
1005
if !updated {
1079
1006
return Err(ApiError::AccountNotFound);
1080
1007
}
+68
-166
crates/tranquil-pds/src/api/server/totp.rs
+68
-166
crates/tranquil-pds/src/api/server/totp.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
3
-
use crate::auth::{Active, Auth};
2
+
use crate::api::error::{ApiError, DbResultExt};
4
3
use crate::auth::{
5
-
decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, generate_qr_png_base64,
6
-
generate_totp_secret, generate_totp_uri, hash_backup_code, is_backup_code_format,
7
-
verify_backup_code, verify_totp_code,
4
+
Active, Auth, decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes,
5
+
generate_qr_png_base64, generate_totp_secret, generate_totp_uri, hash_backup_code,
6
+
is_backup_code_format, require_legacy_session_mfa, verify_backup_code, verify_password_mfa,
7
+
verify_totp_code, verify_totp_mfa,
8
8
};
9
-
use crate::state::{AppState, RateLimitKind};
9
+
use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message};
10
+
use crate::state::AppState;
10
11
use crate::types::PlainPassword;
12
+
use crate::util::pds_hostname;
11
13
use axum::{
12
14
Json,
13
15
extract::State,
···
30
32
State(state): State<AppState>,
31
33
auth: Auth<Active>,
32
34
) -> Result<Response, ApiError> {
33
-
match state.user_repo.get_totp_record(&auth.did).await {
34
-
Ok(Some(record)) if record.verified => return Err(ApiError::TotpAlreadyEnabled),
35
-
Ok(_) => {}
35
+
use tranquil_db_traits::TotpRecordState;
36
+
37
+
match state.user_repo.get_totp_record_state(&auth.did).await {
38
+
Ok(Some(TotpRecordState::Verified(_))) => return Err(ApiError::TotpAlreadyEnabled),
39
+
Ok(Some(TotpRecordState::Unverified(_))) | Ok(None) => {}
36
40
Err(e) => {
37
41
error!("DB error checking TOTP: {:?}", e);
38
42
return Err(ApiError::InternalError(None));
···
45
49
.user_repo
46
50
.get_handle_by_did(&auth.did)
47
51
.await
48
-
.map_err(|e| {
49
-
error!("DB error fetching handle: {:?}", e);
50
-
ApiError::InternalError(None)
51
-
})?
52
+
.log_db_err("fetching handle")?
52
53
.ok_or(ApiError::AccountNotFound)?;
53
54
54
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
55
-
let uri = generate_totp_uri(&secret, &handle, &hostname);
55
+
let hostname = pds_hostname();
56
+
let uri = generate_totp_uri(&secret, &handle, hostname);
56
57
57
-
let qr_code = generate_qr_png_base64(&secret, &handle, &hostname).map_err(|e| {
58
+
let qr_code = generate_qr_png_base64(&secret, &handle, hostname).map_err(|e| {
58
59
error!("Failed to generate QR code: {:?}", e);
59
60
ApiError::InternalError(Some("Failed to generate QR code".into()))
60
61
})?;
···
68
69
.user_repo
69
70
.upsert_totp_secret(&auth.did, &encrypted_secret, ENCRYPTION_VERSION)
70
71
.await
71
-
.map_err(|e| {
72
-
error!("Failed to store TOTP secret: {:?}", e);
73
-
ApiError::InternalError(None)
74
-
})?;
72
+
.log_db_err("storing TOTP secret")?;
75
73
76
74
let secret_base32 = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &secret);
77
75
···
101
99
auth: Auth<Active>,
102
100
Json(input): Json<EnableTotpInput>,
103
101
) -> Result<Response, ApiError> {
104
-
if !state
105
-
.check_rate_limit(RateLimitKind::TotpVerify, &auth.did)
106
-
.await
107
-
{
108
-
warn!(did = %&auth.did, "TOTP verification rate limit exceeded");
109
-
return Err(ApiError::RateLimitExceeded(None));
110
-
}
102
+
use tranquil_db_traits::TotpRecordState;
111
103
112
-
let totp_record = match state.user_repo.get_totp_record(&auth.did).await {
113
-
Ok(Some(row)) => row,
104
+
let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>(
105
+
&state,
106
+
&auth.did,
107
+
"Too many verification attempts. Please try again in a few minutes.",
108
+
)
109
+
.await?;
110
+
111
+
let unverified_record = match state.user_repo.get_totp_record_state(&auth.did).await {
112
+
Ok(Some(TotpRecordState::Unverified(record))) => record,
113
+
Ok(Some(TotpRecordState::Verified(_))) => return Err(ApiError::TotpAlreadyEnabled),
114
114
Ok(None) => return Err(ApiError::TotpNotEnabled),
115
115
Err(e) => {
116
116
error!("DB error fetching TOTP: {:?}", e);
···
118
118
}
119
119
};
120
120
121
-
if totp_record.verified {
122
-
return Err(ApiError::TotpAlreadyEnabled);
123
-
}
124
-
125
121
let secret = decrypt_totp_secret(
126
-
&totp_record.secret_encrypted,
127
-
totp_record.encryption_version,
122
+
&unverified_record.secret_encrypted,
123
+
unverified_record.encryption_version,
128
124
)
129
125
.map_err(|e| {
130
126
error!("Failed to decrypt TOTP secret: {:?}", e);
···
152
148
.user_repo
153
149
.enable_totp_with_backup_codes(&auth.did, &backup_hashes)
154
150
.await
155
-
.map_err(|e| {
156
-
error!("Failed to enable TOTP: {:?}", e);
157
-
ApiError::InternalError(None)
158
-
})?;
151
+
.log_db_err("enabling TOTP")?;
159
152
160
153
info!(did = %&auth.did, "TOTP enabled with {} backup codes", backup_codes.len());
161
154
···
173
166
auth: Auth<Active>,
174
167
Json(input): Json<DisableTotpInput>,
175
168
) -> Result<Response, ApiError> {
176
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
177
-
{
178
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
179
-
&*state.user_repo,
180
-
&*state.session_repo,
181
-
&auth.did,
182
-
)
183
-
.await);
184
-
}
185
-
186
-
if !state
187
-
.check_rate_limit(RateLimitKind::TotpVerify, &auth.did)
188
-
.await
189
-
{
190
-
warn!(did = %&auth.did, "TOTP verification rate limit exceeded");
191
-
return Err(ApiError::RateLimitExceeded(None));
192
-
}
193
-
194
-
let password_hash = state
195
-
.user_repo
196
-
.get_password_hash_by_did(&auth.did)
197
-
.await
198
-
.map_err(|e| {
199
-
error!("DB error fetching user: {:?}", e);
200
-
ApiError::InternalError(None)
201
-
})?
202
-
.ok_or(ApiError::AccountNotFound)?;
203
-
204
-
let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false);
205
-
if !password_valid {
206
-
return Err(ApiError::InvalidPassword("Password is incorrect".into()));
207
-
}
208
-
209
-
let totp_record = match state.user_repo.get_totp_record(&auth.did).await {
210
-
Ok(Some(row)) if row.verified => row,
211
-
Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled),
212
-
Err(e) => {
213
-
error!("DB error fetching TOTP: {:?}", e);
214
-
return Err(ApiError::InternalError(None));
215
-
}
169
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
170
+
Ok(proof) => proof,
171
+
Err(response) => return Ok(response),
216
172
};
217
173
218
-
let code = input.code.trim();
219
-
let code_valid = if is_backup_code_format(code) {
220
-
verify_backup_code_for_user(&state, &auth.did, code).await
221
-
} else {
222
-
let secret = decrypt_totp_secret(
223
-
&totp_record.secret_encrypted,
224
-
totp_record.encryption_version,
225
-
)
226
-
.map_err(|e| {
227
-
error!("Failed to decrypt TOTP secret: {:?}", e);
228
-
ApiError::InternalError(None)
229
-
})?;
230
-
verify_totp_code(&secret, code)
231
-
};
174
+
let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>(
175
+
&state,
176
+
session_mfa.did(),
177
+
"Too many verification attempts. Please try again in a few minutes.",
178
+
)
179
+
.await?;
232
180
233
-
if !code_valid {
234
-
return Err(ApiError::InvalidCode(Some(
235
-
"Invalid verification code".into(),
236
-
)));
237
-
}
181
+
let password_mfa = verify_password_mfa(&state, &auth, &input.password).await?;
182
+
let totp_mfa = verify_totp_mfa(&state, &auth, &input.code).await?;
238
183
239
184
state
240
185
.user_repo
241
-
.delete_totp_and_backup_codes(&auth.did)
186
+
.delete_totp_and_backup_codes(totp_mfa.did())
242
187
.await
243
-
.map_err(|e| {
244
-
error!("Failed to delete TOTP: {:?}", e);
245
-
ApiError::InternalError(None)
246
-
})?;
188
+
.log_db_err("deleting TOTP")?;
247
189
248
-
info!(did = %&auth.did, "TOTP disabled");
190
+
info!(did = %session_mfa.did(), "TOTP disabled (verified via {} and {})", password_mfa.method(), totp_mfa.method());
249
191
250
192
Ok(EmptyResponse::ok().into_response())
251
193
}
···
262
204
State(state): State<AppState>,
263
205
auth: Auth<Active>,
264
206
) -> Result<Response, ApiError> {
265
-
let enabled = match state.user_repo.get_totp_record(&auth.did).await {
266
-
Ok(Some(row)) => row.verified,
267
-
Ok(None) => false,
207
+
use tranquil_db_traits::TotpRecordState;
208
+
209
+
let enabled = match state.user_repo.get_totp_record_state(&auth.did).await {
210
+
Ok(Some(TotpRecordState::Verified(_))) => true,
211
+
Ok(Some(TotpRecordState::Unverified(_))) | Ok(None) => false,
268
212
Err(e) => {
269
213
error!("DB error fetching TOTP status: {:?}", e);
270
214
return Err(ApiError::InternalError(None));
···
275
219
.user_repo
276
220
.count_unused_backup_codes(&auth.did)
277
221
.await
278
-
.map_err(|e| {
279
-
error!("DB error counting backup codes: {:?}", e);
280
-
ApiError::InternalError(None)
281
-
})?;
222
+
.log_db_err("counting backup codes")?;
282
223
283
224
Ok(Json(GetTotpStatusResponse {
284
225
enabled,
···
305
246
auth: Auth<Active>,
306
247
Json(input): Json<RegenerateBackupCodesInput>,
307
248
) -> Result<Response, ApiError> {
308
-
if !state
309
-
.check_rate_limit(RateLimitKind::TotpVerify, &auth.did)
310
-
.await
311
-
{
312
-
warn!(did = %&auth.did, "TOTP verification rate limit exceeded");
313
-
return Err(ApiError::RateLimitExceeded(None));
314
-
}
315
-
316
-
let password_hash = state
317
-
.user_repo
318
-
.get_password_hash_by_did(&auth.did)
319
-
.await
320
-
.map_err(|e| {
321
-
error!("DB error fetching user: {:?}", e);
322
-
ApiError::InternalError(None)
323
-
})?
324
-
.ok_or(ApiError::AccountNotFound)?;
325
-
326
-
let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false);
327
-
if !password_valid {
328
-
return Err(ApiError::InvalidPassword("Password is incorrect".into()));
329
-
}
330
-
331
-
let totp_record = match state.user_repo.get_totp_record(&auth.did).await {
332
-
Ok(Some(row)) if row.verified => row,
333
-
Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled),
334
-
Err(e) => {
335
-
error!("DB error fetching TOTP: {:?}", e);
336
-
return Err(ApiError::InternalError(None));
337
-
}
338
-
};
339
-
340
-
let secret = decrypt_totp_secret(
341
-
&totp_record.secret_encrypted,
342
-
totp_record.encryption_version,
249
+
let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>(
250
+
&state,
251
+
&auth.did,
252
+
"Too many verification attempts. Please try again in a few minutes.",
343
253
)
344
-
.map_err(|e| {
345
-
error!("Failed to decrypt TOTP secret: {:?}", e);
346
-
ApiError::InternalError(None)
347
-
})?;
254
+
.await?;
348
255
349
-
let code = input.code.trim();
350
-
if !verify_totp_code(&secret, code) {
351
-
return Err(ApiError::InvalidCode(Some(
352
-
"Invalid verification code".into(),
353
-
)));
354
-
}
256
+
let password_mfa = verify_password_mfa(&state, &auth, &input.password).await?;
257
+
let totp_mfa = verify_totp_mfa(&state, &auth, &input.code).await?;
355
258
356
259
let backup_codes = generate_backup_codes();
357
260
let backup_hashes: Vec<_> = backup_codes
···
365
268
366
269
state
367
270
.user_repo
368
-
.replace_backup_codes(&auth.did, &backup_hashes)
271
+
.replace_backup_codes(totp_mfa.did(), &backup_hashes)
369
272
.await
370
-
.map_err(|e| {
371
-
error!("Failed to regenerate backup codes: {:?}", e);
372
-
ApiError::InternalError(None)
373
-
})?;
273
+
.log_db_err("replacing backup codes")?;
374
274
375
-
info!(did = %&auth.did, "Backup codes regenerated");
275
+
info!(did = %password_mfa.did(), "Backup codes regenerated (verified via {} and {})", password_mfa.method(), totp_mfa.method());
376
276
377
277
Ok(Json(RegenerateBackupCodesResponse { backup_codes }).into_response())
378
278
}
···
410
310
did: &crate::types::Did,
411
311
code: &str,
412
312
) -> bool {
313
+
use tranquil_db_traits::TotpRecordState;
314
+
413
315
let code = code.trim();
414
316
415
317
if is_backup_code_format(code) {
416
318
return verify_backup_code_for_user(state, did, code).await;
417
319
}
418
320
419
-
let totp_record = match state.user_repo.get_totp_record(did).await {
420
-
Ok(Some(row)) if row.verified => row,
321
+
let verified_record = match state.user_repo.get_totp_record_state(did).await {
322
+
Ok(Some(TotpRecordState::Verified(record))) => record,
421
323
_ => return false,
422
324
};
423
325
424
326
let secret = match decrypt_totp_secret(
425
-
&totp_record.secret_encrypted,
426
-
totp_record.encryption_version,
327
+
&verified_record.secret_encrypted,
328
+
verified_record.encryption_version,
427
329
) {
428
330
Ok(s) => s,
429
331
Err(_) => return false,
+4
-13
crates/tranquil-pds/src/api/server/trusted_devices.rs
+4
-13
crates/tranquil-pds/src/api/server/trusted_devices.rs
···
1
1
use crate::api::SuccessResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use axum::{
4
4
Json,
5
5
extract::State,
···
79
79
.oauth_repo
80
80
.list_trusted_devices(&auth.did)
81
81
.await
82
-
.map_err(|e| {
83
-
error!("DB error: {:?}", e);
84
-
ApiError::InternalError(None)
85
-
})?;
82
+
.log_db_err("listing trusted devices")?;
86
83
87
84
let devices = rows
88
85
.into_iter()
···
134
131
.oauth_repo
135
132
.revoke_device_trust(&device_id)
136
133
.await
137
-
.map_err(|e| {
138
-
error!("DB error: {:?}", e);
139
-
ApiError::InternalError(None)
140
-
})?;
134
+
.log_db_err("revoking device trust")?;
141
135
142
136
info!(did = %&auth.did, device_id = %input.device_id, "Trusted device revoked");
143
137
Ok(SuccessResponse::ok().into_response())
···
175
169
.oauth_repo
176
170
.update_device_friendly_name(&device_id, input.friendly_name.as_deref())
177
171
.await
178
-
.map_err(|e| {
179
-
error!("DB error: {:?}", e);
180
-
ApiError::InternalError(None)
181
-
})?;
172
+
.log_db_err("updating device friendly name")?;
182
173
183
174
info!(did = %auth.did, device_id = %input.device_id, "Trusted device updated");
184
175
Ok(SuccessResponse::ok().into_response())
+3
-2
crates/tranquil-pds/src/api/server/verify_email.rs
+3
-2
crates/tranquil-pds/src/api/server/verify_email.rs
···
5
5
use tracing::{info, warn};
6
6
7
7
use crate::state::AppState;
8
+
use crate::util::pds_hostname;
8
9
9
10
#[derive(Deserialize)]
10
11
#[serde(rename_all = "camelCase")]
···
70
71
return Ok(Json(ResendMigrationVerificationOutput { sent: true }));
71
72
}
72
73
73
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
74
+
let hostname = pds_hostname();
74
75
let token = crate::auth::verification_token::generate_migration_token(&user.did, &email);
75
76
let formatted_token = crate::auth::verification_token::format_token_for_display(&token);
76
77
···
80
81
user.id,
81
82
&email,
82
83
&formatted_token,
83
-
&hostname,
84
+
hostname,
84
85
)
85
86
.await
86
87
{
+16
-52
crates/tranquil-pds/src/api/server/verify_token.rs
+16
-52
crates/tranquil-pds/src/api/server/verify_token.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::types::Did;
3
3
use axum::{Json, extract::State};
4
4
use serde::{Deserialize, Serialize};
5
-
use tracing::{error, info, warn};
5
+
use tracing::{info, warn};
6
6
7
7
use crate::auth::verification_token::{
8
8
VerificationPurpose, normalize_token_input, verify_token_signature,
···
81
81
.user_repo
82
82
.get_verification_info(&did_typed)
83
83
.await
84
-
.map_err(|e| {
85
-
warn!(error = ?e, "Database error during migration verification");
86
-
ApiError::InternalError(None)
87
-
})?
84
+
.log_db_err("during migration verification")?
88
85
.ok_or(ApiError::AccountNotFound)?;
89
86
90
87
if user.email.as_ref().map(|e| e.to_lowercase()) != Some(identifier.to_string()) {
91
88
return Err(ApiError::IdentifierMismatch);
92
89
}
93
90
94
-
if !user.email_verified {
91
+
if !user.channel_verification.email {
95
92
state
96
93
.user_repo
97
94
.set_email_verified_flag(user.id)
98
95
.await
99
-
.map_err(|e| {
100
-
warn!(error = ?e, "Failed to update email_verified status");
101
-
ApiError::InternalError(None)
102
-
})?;
96
+
.log_db_err("updating email_verified status")?;
103
97
}
104
98
105
99
info!(did = %did, "Migration email verified successfully");
···
125
119
.user_repo
126
120
.get_id_by_did(&did_typed)
127
121
.await
128
-
.map_err(|_| ApiError::InternalError(None))?
122
+
.log_db_err("fetching user id")?
129
123
.ok_or(ApiError::AccountNotFound)?;
130
124
131
125
match channel {
···
134
128
.user_repo
135
129
.verify_email_channel(user_id, identifier)
136
130
.await
137
-
.map_err(|e| {
138
-
error!("Failed to update email channel: {:?}", e);
139
-
ApiError::InternalError(None)
140
-
})?;
131
+
.log_db_err("updating email channel")?;
141
132
if !success {
142
133
return Err(ApiError::EmailTaken);
143
134
}
···
147
138
.user_repo
148
139
.verify_discord_channel(user_id, identifier)
149
140
.await
150
-
.map_err(|e| {
151
-
error!("Failed to update discord channel: {:?}", e);
152
-
ApiError::InternalError(None)
153
-
})?;
141
+
.log_db_err("updating discord channel")?;
154
142
}
155
143
"telegram" => {
156
144
state
157
145
.user_repo
158
146
.verify_telegram_channel(user_id, identifier)
159
147
.await
160
-
.map_err(|e| {
161
-
error!("Failed to update telegram channel: {:?}", e);
162
-
ApiError::InternalError(None)
163
-
})?;
148
+
.log_db_err("updating telegram channel")?;
164
149
}
165
150
"signal" => {
166
151
state
167
152
.user_repo
168
153
.verify_signal_channel(user_id, identifier)
169
154
.await
170
-
.map_err(|e| {
171
-
error!("Failed to update signal channel: {:?}", e);
172
-
ApiError::InternalError(None)
173
-
})?;
155
+
.log_db_err("updating signal channel")?;
174
156
}
175
157
_ => {
176
158
return Err(ApiError::InvalidChannel);
···
200
182
.user_repo
201
183
.get_verification_info(&did_typed)
202
184
.await
203
-
.map_err(|e| {
204
-
warn!(error = ?e, "Database error during signup verification");
205
-
ApiError::InternalError(None)
206
-
})?
185
+
.log_db_err("during signup verification")?
207
186
.ok_or(ApiError::AccountNotFound)?;
208
187
209
-
let is_verified = user.email_verified
210
-
|| user.discord_verified
211
-
|| user.telegram_verified
212
-
|| user.signal_verified;
188
+
let is_verified = user.channel_verification.has_any_verified();
213
189
if is_verified {
214
190
info!(did = %did, "Account already verified");
215
191
return Ok(Json(VerifyTokenOutput {
···
226
202
.user_repo
227
203
.set_email_verified_flag(user.id)
228
204
.await
229
-
.map_err(|e| {
230
-
warn!(error = ?e, "Failed to update email verified status");
231
-
ApiError::InternalError(None)
232
-
})?;
205
+
.log_db_err("updating email verified status")?;
233
206
}
234
207
"discord" => {
235
208
state
236
209
.user_repo
237
210
.set_discord_verified_flag(user.id)
238
211
.await
239
-
.map_err(|e| {
240
-
warn!(error = ?e, "Failed to update discord verified status");
241
-
ApiError::InternalError(None)
242
-
})?;
212
+
.log_db_err("updating discord verified status")?;
243
213
}
244
214
"telegram" => {
245
215
state
246
216
.user_repo
247
217
.set_telegram_verified_flag(user.id)
248
218
.await
249
-
.map_err(|e| {
250
-
warn!(error = ?e, "Failed to update telegram verified status");
251
-
ApiError::InternalError(None)
252
-
})?;
219
+
.log_db_err("updating telegram verified status")?;
253
220
}
254
221
"signal" => {
255
222
state
256
223
.user_repo
257
224
.set_signal_verified_flag(user.id)
258
225
.await
259
-
.map_err(|e| {
260
-
warn!(error = ?e, "Failed to update signal verified status");
261
-
ApiError::InternalError(None)
262
-
})?;
226
+
.log_db_err("updating signal verified status")?;
263
227
}
264
228
_ => {
265
229
return Err(ApiError::InvalidChannel);
+61
crates/tranquil-pds/src/auth/account_verified.rs
+61
crates/tranquil-pds/src/auth/account_verified.rs
···
1
+
use axum::response::{IntoResponse, Response};
2
+
3
+
use super::AuthenticatedUser;
4
+
use crate::api::error::ApiError;
5
+
use crate::state::AppState;
6
+
use crate::types::Did;
7
+
8
+
pub struct AccountVerified<'a> {
9
+
user: &'a AuthenticatedUser,
10
+
}
11
+
12
+
impl<'a> AccountVerified<'a> {
13
+
pub fn did(&self) -> &Did {
14
+
&self.user.did
15
+
}
16
+
17
+
pub fn user(&self) -> &AuthenticatedUser {
18
+
self.user
19
+
}
20
+
}
21
+
22
+
pub async fn require_verified_or_delegated<'a>(
23
+
state: &AppState,
24
+
user: &'a AuthenticatedUser,
25
+
) -> Result<AccountVerified<'a>, Response> {
26
+
let is_verified = state
27
+
.user_repo
28
+
.has_verified_comms_channel(&user.did)
29
+
.await
30
+
.unwrap_or(false);
31
+
32
+
if is_verified {
33
+
return Ok(AccountVerified { user });
34
+
}
35
+
36
+
let is_delegated = state
37
+
.delegation_repo
38
+
.is_delegated_account(&user.did)
39
+
.await
40
+
.unwrap_or(false);
41
+
42
+
if is_delegated {
43
+
return Ok(AccountVerified { user });
44
+
}
45
+
46
+
Err(ApiError::AccountNotVerified.into_response())
47
+
}
48
+
49
+
pub async fn require_not_migrated(state: &AppState, did: &Did) -> Result<(), Response> {
50
+
match state.user_repo.is_account_migrated(did).await {
51
+
Ok(true) => Err(ApiError::AccountMigrated.into_response()),
52
+
Ok(false) => Ok(()),
53
+
Err(e) => {
54
+
tracing::error!("Failed to check migration status: {:?}", e);
55
+
Err(
56
+
ApiError::InternalError(Some("Failed to verify migration status".into()))
57
+
.into_response(),
58
+
)
59
+
}
60
+
}
61
+
}
-547
crates/tranquil-pds/src/auth/auth_extractor.rs
-547
crates/tranquil-pds/src/auth/auth_extractor.rs
···
1
-
mod common;
2
-
mod helpers;
3
-
4
-
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
5
-
use chrono::Utc;
6
-
use common::{base_url, client, create_account_and_login, pds_endpoint};
7
-
use helpers::verify_new_account;
8
-
use reqwest::StatusCode;
9
-
use serde_json::{Value, json};
10
-
use sha2::{Digest, Sha256};
11
-
use wiremock::matchers::{method, path};
12
-
use wiremock::{Mock, MockServer, ResponseTemplate};
13
-
14
-
fn generate_pkce() -> (String, String) {
15
-
let verifier_bytes: [u8; 32] = rand::random();
16
-
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
17
-
let mut hasher = Sha256::new();
18
-
hasher.update(code_verifier.as_bytes());
19
-
let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
20
-
(code_verifier, code_challenge)
21
-
}
22
-
23
-
async fn setup_mock_client_metadata(redirect_uri: &str, dpop_bound: bool) -> MockServer {
24
-
let mock_server = MockServer::start().await;
25
-
let metadata = json!({
26
-
"client_id": mock_server.uri(),
27
-
"client_name": "Auth Extractor Test Client",
28
-
"redirect_uris": [redirect_uri],
29
-
"grant_types": ["authorization_code", "refresh_token"],
30
-
"response_types": ["code"],
31
-
"token_endpoint_auth_method": "none",
32
-
"dpop_bound_access_tokens": dpop_bound
33
-
});
34
-
Mock::given(method("GET"))
35
-
.and(path("/"))
36
-
.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
37
-
.mount(&mock_server)
38
-
.await;
39
-
mock_server
40
-
}
41
-
42
-
async fn get_oauth_session(
43
-
http_client: &reqwest::Client,
44
-
url: &str,
45
-
dpop_bound: bool,
46
-
) -> (String, String, String, String) {
47
-
let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8];
48
-
let handle = format!("ae{}", suffix);
49
-
let password = "AuthExtract123!";
50
-
let create_res = http_client
51
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
52
-
.json(&json!({
53
-
"handle": handle,
54
-
"email": format!("{}@example.com", handle),
55
-
"password": password
56
-
}))
57
-
.send()
58
-
.await
59
-
.unwrap();
60
-
assert_eq!(create_res.status(), StatusCode::OK);
61
-
let account: Value = create_res.json().await.unwrap();
62
-
let did = account["did"].as_str().unwrap().to_string();
63
-
verify_new_account(http_client, &did).await;
64
-
65
-
let redirect_uri = "https://example.com/auth-callback";
66
-
let mock_client = setup_mock_client_metadata(redirect_uri, dpop_bound).await;
67
-
let client_id = mock_client.uri();
68
-
let (code_verifier, code_challenge) = generate_pkce();
69
-
70
-
let par_body: Value = http_client
71
-
.post(format!("{}/oauth/par", url))
72
-
.form(&[
73
-
("response_type", "code"),
74
-
("client_id", &client_id),
75
-
("redirect_uri", redirect_uri),
76
-
("code_challenge", &code_challenge),
77
-
("code_challenge_method", "S256"),
78
-
])
79
-
.send()
80
-
.await
81
-
.unwrap()
82
-
.json()
83
-
.await
84
-
.unwrap();
85
-
let request_uri = par_body["request_uri"].as_str().unwrap();
86
-
87
-
let auth_res = http_client
88
-
.post(format!("{}/oauth/authorize", url))
89
-
.header("Content-Type", "application/json")
90
-
.header("Accept", "application/json")
91
-
.json(&json!({
92
-
"request_uri": request_uri,
93
-
"username": &handle,
94
-
"password": password,
95
-
"remember_device": false
96
-
}))
97
-
.send()
98
-
.await
99
-
.unwrap();
100
-
let auth_body: Value = auth_res.json().await.unwrap();
101
-
let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string();
102
-
103
-
if location.contains("/oauth/consent") {
104
-
let consent_res = http_client
105
-
.post(format!("{}/oauth/authorize/consent", url))
106
-
.header("Content-Type", "application/json")
107
-
.json(&json!({
108
-
"request_uri": request_uri,
109
-
"approved_scopes": ["atproto"],
110
-
"remember": false
111
-
}))
112
-
.send()
113
-
.await
114
-
.unwrap();
115
-
let consent_body: Value = consent_res.json().await.unwrap();
116
-
location = consent_body["redirect_uri"].as_str().unwrap().to_string();
117
-
}
118
-
119
-
let code = location
120
-
.split("code=")
121
-
.nth(1)
122
-
.unwrap()
123
-
.split('&')
124
-
.next()
125
-
.unwrap();
126
-
127
-
let token_body: Value = http_client
128
-
.post(format!("{}/oauth/token", url))
129
-
.form(&[
130
-
("grant_type", "authorization_code"),
131
-
("code", code),
132
-
("redirect_uri", redirect_uri),
133
-
("code_verifier", &code_verifier),
134
-
("client_id", &client_id),
135
-
])
136
-
.send()
137
-
.await
138
-
.unwrap()
139
-
.json()
140
-
.await
141
-
.unwrap();
142
-
143
-
(
144
-
token_body["access_token"].as_str().unwrap().to_string(),
145
-
token_body["refresh_token"].as_str().unwrap().to_string(),
146
-
client_id,
147
-
did,
148
-
)
149
-
}
150
-
151
-
#[tokio::test]
152
-
async fn test_oauth_token_works_with_bearer_auth() {
153
-
let url = base_url().await;
154
-
let http_client = client();
155
-
let (access_token, _, _, did) = get_oauth_session(&http_client, url, false).await;
156
-
157
-
let res = http_client
158
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
159
-
.bearer_auth(&access_token)
160
-
.send()
161
-
.await
162
-
.unwrap();
163
-
164
-
assert_eq!(res.status(), StatusCode::OK, "OAuth token should work with RequiredAuth extractor");
165
-
let body: Value = res.json().await.unwrap();
166
-
assert_eq!(body["did"].as_str().unwrap(), did);
167
-
}
168
-
169
-
#[tokio::test]
170
-
async fn test_session_token_still_works() {
171
-
let url = base_url().await;
172
-
let http_client = client();
173
-
let (jwt, did) = create_account_and_login(&http_client).await;
174
-
175
-
let res = http_client
176
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
177
-
.bearer_auth(&jwt)
178
-
.send()
179
-
.await
180
-
.unwrap();
181
-
182
-
assert_eq!(res.status(), StatusCode::OK, "Session token should still work");
183
-
let body: Value = res.json().await.unwrap();
184
-
assert_eq!(body["did"].as_str().unwrap(), did);
185
-
}
186
-
187
-
188
-
#[tokio::test]
189
-
async fn test_oauth_admin_extractor_allows_oauth_tokens() {
190
-
let url = base_url().await;
191
-
let http_client = client();
192
-
193
-
let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8];
194
-
let handle = format!("adm{}", suffix);
195
-
let password = "AdminOAuth123!";
196
-
let create_res = http_client
197
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
198
-
.json(&json!({
199
-
"handle": handle,
200
-
"email": format!("{}@example.com", handle),
201
-
"password": password
202
-
}))
203
-
.send()
204
-
.await
205
-
.unwrap();
206
-
assert_eq!(create_res.status(), StatusCode::OK);
207
-
let account: Value = create_res.json().await.unwrap();
208
-
let did = account["did"].as_str().unwrap().to_string();
209
-
verify_new_account(&http_client, &did).await;
210
-
211
-
let pool = common::get_test_db_pool().await;
212
-
sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
213
-
.execute(pool)
214
-
.await
215
-
.expect("Failed to mark user as admin");
216
-
217
-
let redirect_uri = "https://example.com/admin-callback";
218
-
let mock_client = setup_mock_client_metadata(redirect_uri, false).await;
219
-
let client_id = mock_client.uri();
220
-
let (code_verifier, code_challenge) = generate_pkce();
221
-
222
-
let par_body: Value = http_client
223
-
.post(format!("{}/oauth/par", url))
224
-
.form(&[
225
-
("response_type", "code"),
226
-
("client_id", &client_id),
227
-
("redirect_uri", redirect_uri),
228
-
("code_challenge", &code_challenge),
229
-
("code_challenge_method", "S256"),
230
-
])
231
-
.send()
232
-
.await
233
-
.unwrap()
234
-
.json()
235
-
.await
236
-
.unwrap();
237
-
let request_uri = par_body["request_uri"].as_str().unwrap();
238
-
239
-
let auth_res = http_client
240
-
.post(format!("{}/oauth/authorize", url))
241
-
.header("Content-Type", "application/json")
242
-
.header("Accept", "application/json")
243
-
.json(&json!({
244
-
"request_uri": request_uri,
245
-
"username": &handle,
246
-
"password": password,
247
-
"remember_device": false
248
-
}))
249
-
.send()
250
-
.await
251
-
.unwrap();
252
-
let auth_body: Value = auth_res.json().await.unwrap();
253
-
let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string();
254
-
if location.contains("/oauth/consent") {
255
-
let consent_res = http_client
256
-
.post(format!("{}/oauth/authorize/consent", url))
257
-
.header("Content-Type", "application/json")
258
-
.json(&json!({
259
-
"request_uri": request_uri,
260
-
"approved_scopes": ["atproto"],
261
-
"remember": false
262
-
}))
263
-
.send()
264
-
.await
265
-
.unwrap();
266
-
let consent_body: Value = consent_res.json().await.unwrap();
267
-
location = consent_body["redirect_uri"].as_str().unwrap().to_string();
268
-
}
269
-
270
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
271
-
let token_body: Value = http_client
272
-
.post(format!("{}/oauth/token", url))
273
-
.form(&[
274
-
("grant_type", "authorization_code"),
275
-
("code", code),
276
-
("redirect_uri", redirect_uri),
277
-
("code_verifier", &code_verifier),
278
-
("client_id", &client_id),
279
-
])
280
-
.send()
281
-
.await
282
-
.unwrap()
283
-
.json()
284
-
.await
285
-
.unwrap();
286
-
let access_token = token_body["access_token"].as_str().unwrap();
287
-
288
-
let res = http_client
289
-
.get(format!("{}/xrpc/com.atproto.admin.getAccountInfos?dids={}", url, did))
290
-
.bearer_auth(access_token)
291
-
.send()
292
-
.await
293
-
.unwrap();
294
-
295
-
assert_eq!(
296
-
res.status(),
297
-
StatusCode::OK,
298
-
"OAuth token for admin user should work with admin endpoint"
299
-
);
300
-
}
301
-
302
-
#[tokio::test]
303
-
async fn test_expired_oauth_token_returns_proper_error() {
304
-
let url = base_url().await;
305
-
let http_client = client();
306
-
307
-
let now = Utc::now().timestamp();
308
-
let header = json!({"alg": "HS256", "typ": "at+jwt"});
309
-
let payload = json!({
310
-
"iss": url,
311
-
"sub": "did:plc:test123",
312
-
"aud": url,
313
-
"iat": now - 7200,
314
-
"exp": now - 3600,
315
-
"jti": "expired-token",
316
-
"sid": "expired-session",
317
-
"scope": "atproto",
318
-
"client_id": "https://example.com"
319
-
});
320
-
let fake_token = format!(
321
-
"{}.{}.{}",
322
-
URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()),
323
-
URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()),
324
-
URL_SAFE_NO_PAD.encode([1u8; 32])
325
-
);
326
-
327
-
let res = http_client
328
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
329
-
.bearer_auth(&fake_token)
330
-
.send()
331
-
.await
332
-
.unwrap();
333
-
334
-
assert_eq!(
335
-
res.status(),
336
-
StatusCode::UNAUTHORIZED,
337
-
"Expired token should be rejected"
338
-
);
339
-
}
340
-
341
-
#[tokio::test]
342
-
async fn test_dpop_nonce_error_has_proper_headers() {
343
-
let url = base_url().await;
344
-
let pds_url = pds_endpoint();
345
-
let http_client = client();
346
-
347
-
let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8];
348
-
let handle = format!("dpop{}", suffix);
349
-
let create_res = http_client
350
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
351
-
.json(&json!({
352
-
"handle": handle,
353
-
"email": format!("{}@test.com", handle),
354
-
"password": "DpopTest123!"
355
-
}))
356
-
.send()
357
-
.await
358
-
.unwrap();
359
-
assert_eq!(create_res.status(), StatusCode::OK);
360
-
let account: Value = create_res.json().await.unwrap();
361
-
let did = account["did"].as_str().unwrap();
362
-
verify_new_account(&http_client, did).await;
363
-
364
-
let redirect_uri = "https://example.com/dpop-callback";
365
-
let mock_server = MockServer::start().await;
366
-
let client_id = mock_server.uri();
367
-
let metadata = json!({
368
-
"client_id": &client_id,
369
-
"client_name": "DPoP Test Client",
370
-
"redirect_uris": [redirect_uri],
371
-
"grant_types": ["authorization_code", "refresh_token"],
372
-
"response_types": ["code"],
373
-
"token_endpoint_auth_method": "none",
374
-
"dpop_bound_access_tokens": true
375
-
});
376
-
Mock::given(method("GET"))
377
-
.and(path("/"))
378
-
.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
379
-
.mount(&mock_server)
380
-
.await;
381
-
382
-
let (code_verifier, code_challenge) = generate_pkce();
383
-
let par_body: Value = http_client
384
-
.post(format!("{}/oauth/par", url))
385
-
.form(&[
386
-
("response_type", "code"),
387
-
("client_id", &client_id),
388
-
("redirect_uri", redirect_uri),
389
-
("code_challenge", &code_challenge),
390
-
("code_challenge_method", "S256"),
391
-
])
392
-
.send()
393
-
.await
394
-
.unwrap()
395
-
.json()
396
-
.await
397
-
.unwrap();
398
-
399
-
let request_uri = par_body["request_uri"].as_str().unwrap();
400
-
let auth_res = http_client
401
-
.post(format!("{}/oauth/authorize", url))
402
-
.header("Content-Type", "application/json")
403
-
.header("Accept", "application/json")
404
-
.json(&json!({
405
-
"request_uri": request_uri,
406
-
"username": &handle,
407
-
"password": "DpopTest123!",
408
-
"remember_device": false
409
-
}))
410
-
.send()
411
-
.await
412
-
.unwrap();
413
-
let auth_body: Value = auth_res.json().await.unwrap();
414
-
let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string();
415
-
if location.contains("/oauth/consent") {
416
-
let consent_res = http_client
417
-
.post(format!("{}/oauth/authorize/consent", url))
418
-
.header("Content-Type", "application/json")
419
-
.json(&json!({
420
-
"request_uri": request_uri,
421
-
"approved_scopes": ["atproto"],
422
-
"remember": false
423
-
}))
424
-
.send()
425
-
.await
426
-
.unwrap();
427
-
let consent_body: Value = consent_res.json().await.unwrap();
428
-
location = consent_body["redirect_uri"].as_str().unwrap().to_string();
429
-
}
430
-
431
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
432
-
433
-
let token_endpoint = format!("{}/oauth/token", pds_url);
434
-
let (_, dpop_proof) = generate_dpop_proof("POST", &token_endpoint, None);
435
-
436
-
let token_res = http_client
437
-
.post(format!("{}/oauth/token", url))
438
-
.header("DPoP", &dpop_proof)
439
-
.form(&[
440
-
("grant_type", "authorization_code"),
441
-
("code", code),
442
-
("redirect_uri", redirect_uri),
443
-
("code_verifier", &code_verifier),
444
-
("client_id", &client_id),
445
-
])
446
-
.send()
447
-
.await
448
-
.unwrap();
449
-
450
-
let token_status = token_res.status();
451
-
let token_nonce = token_res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap().to_string());
452
-
let token_body: Value = token_res.json().await.unwrap();
453
-
454
-
let access_token = if token_status == StatusCode::OK {
455
-
token_body["access_token"].as_str().unwrap().to_string()
456
-
} else if token_body.get("error").and_then(|e| e.as_str()) == Some("use_dpop_nonce") {
457
-
let nonce = token_nonce.expect("Token endpoint should return DPoP-Nonce on use_dpop_nonce error");
458
-
let (_, dpop_proof_with_nonce) = generate_dpop_proof("POST", &token_endpoint, Some(&nonce));
459
-
460
-
let retry_res = http_client
461
-
.post(format!("{}/oauth/token", url))
462
-
.header("DPoP", &dpop_proof_with_nonce)
463
-
.form(&[
464
-
("grant_type", "authorization_code"),
465
-
("code", code),
466
-
("redirect_uri", redirect_uri),
467
-
("code_verifier", &code_verifier),
468
-
("client_id", &client_id),
469
-
])
470
-
.send()
471
-
.await
472
-
.unwrap();
473
-
let retry_body: Value = retry_res.json().await.unwrap();
474
-
retry_body["access_token"].as_str().expect("Should get access_token after nonce retry").to_string()
475
-
} else {
476
-
panic!("Token exchange failed unexpectedly: {:?}", token_body);
477
-
};
478
-
479
-
let res = http_client
480
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
481
-
.header("Authorization", format!("DPoP {}", access_token))
482
-
.send()
483
-
.await
484
-
.unwrap();
485
-
486
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DPoP token without proof should fail");
487
-
488
-
let www_auth = res.headers().get("www-authenticate").map(|h| h.to_str().unwrap());
489
-
assert!(www_auth.is_some(), "Should have WWW-Authenticate header");
490
-
assert!(
491
-
www_auth.unwrap().contains("use_dpop_nonce"),
492
-
"WWW-Authenticate should indicate dpop nonce required"
493
-
);
494
-
495
-
let nonce = res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap());
496
-
assert!(nonce.is_some(), "Should return DPoP-Nonce header");
497
-
498
-
let body: Value = res.json().await.unwrap();
499
-
assert_eq!(body["error"].as_str().unwrap(), "use_dpop_nonce");
500
-
}
501
-
502
-
fn generate_dpop_proof(method: &str, uri: &str, nonce: Option<&str>) -> (Value, String) {
503
-
use p256::ecdsa::{SigningKey, signature::Signer};
504
-
use p256::elliptic_curve::rand_core::OsRng;
505
-
506
-
let signing_key = SigningKey::random(&mut OsRng);
507
-
let verifying_key = signing_key.verifying_key();
508
-
let point = verifying_key.to_encoded_point(false);
509
-
let x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
510
-
let y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
511
-
512
-
let jwk = json!({
513
-
"kty": "EC",
514
-
"crv": "P-256",
515
-
"x": x,
516
-
"y": y
517
-
});
518
-
519
-
let header = {
520
-
let h = json!({
521
-
"typ": "dpop+jwt",
522
-
"alg": "ES256",
523
-
"jwk": jwk.clone()
524
-
});
525
-
h
526
-
};
527
-
528
-
let mut payload = json!({
529
-
"jti": uuid::Uuid::new_v4().to_string(),
530
-
"htm": method,
531
-
"htu": uri,
532
-
"iat": Utc::now().timestamp()
533
-
});
534
-
if let Some(n) = nonce {
535
-
payload["nonce"] = json!(n);
536
-
}
537
-
538
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
539
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
540
-
let signing_input = format!("{}.{}", header_b64, payload_b64);
541
-
542
-
let signature: p256::ecdsa::Signature = signing_key.sign(signing_input.as_bytes());
543
-
let sig_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
544
-
545
-
let proof = format!("{}.{}", signing_input, sig_b64);
546
-
(jwk, proof)
547
-
}
+22
-18
crates/tranquil-pds/src/auth/extractor.rs
+22
-18
crates/tranquil-pds/src/auth/extractor.rs
···
9
9
10
10
use super::{
11
11
AccountStatus, AuthSource, AuthenticatedUser, ServiceTokenClaims, ServiceTokenVerifier,
12
-
is_service_token, validate_bearer_token_for_service_auth,
12
+
is_service_token, scope_verified::VerifyScope, validate_bearer_token_for_service_auth,
13
13
};
14
14
use crate::api::error::ApiError;
15
15
use crate::oauth::scopes::{RepoAction, ScopePermissions};
···
293
293
return Ok(ExtractedAuth::Service(claims));
294
294
}
295
295
296
-
let dpop_proof = parts.headers.get("DPoP").and_then(|h| h.to_str().ok());
296
+
let dpop_proof = crate::util::get_header_str(&parts.headers, "DPoP");
297
297
let method = parts.method.as_str();
298
298
let uri = build_full_url(&parts.uri.to_string());
299
299
···
358
358
}
359
359
}
360
360
361
+
impl<P: AuthPolicy> AsRef<AuthenticatedUser> for Auth<P> {
362
+
fn as_ref(&self) -> &AuthenticatedUser {
363
+
&self.0
364
+
}
365
+
}
366
+
367
+
impl<P: AuthPolicy> VerifyScope for Auth<P> {
368
+
fn needs_scope_check(&self) -> bool {
369
+
self.0.is_oauth()
370
+
}
371
+
372
+
fn permissions(&self) -> ScopePermissions {
373
+
self.0.permissions()
374
+
}
375
+
}
376
+
361
377
impl<P: AuthPolicy> FromRequestParts<AppState> for Auth<P> {
362
378
type Rejection = AuthError;
363
379
···
418
434
) -> Result<Self, Self::Rejection> {
419
435
match extract_auth_internal(parts, state).await? {
420
436
ExtractedAuth::Service(claims) => {
421
-
let did: Did = claims
422
-
.iss
423
-
.parse()
424
-
.map_err(|_| AuthError::AuthenticationFailed)?;
437
+
let did = claims.iss.clone();
425
438
Ok(ServiceAuth { did, claims })
426
439
}
427
440
ExtractedAuth::User(_) => Err(AuthError::AuthenticationFailed),
···
438
451
) -> Result<Option<Self>, Self::Rejection> {
439
452
match extract_auth_internal(parts, state).await {
440
453
Ok(ExtractedAuth::Service(claims)) => {
441
-
let did: Did = claims
442
-
.iss
443
-
.parse()
444
-
.map_err(|_| AuthError::AuthenticationFailed)?;
454
+
let did = claims.iss.clone();
445
455
Ok(Some(ServiceAuth { did, claims }))
446
456
}
447
457
Ok(ExtractedAuth::User(_)) => Err(AuthError::AuthenticationFailed),
···
503
513
Ok(AuthAny::User(Auth(user, PhantomData)))
504
514
}
505
515
ExtractedAuth::Service(claims) => {
506
-
let did: Did = claims
507
-
.iss
508
-
.parse()
509
-
.map_err(|_| AuthError::AuthenticationFailed)?;
516
+
let did = claims.iss.clone();
510
517
Ok(AuthAny::Service(ServiceAuth { did, claims }))
511
518
}
512
519
}
···
526
533
Ok(Some(AuthAny::User(Auth(user, PhantomData))))
527
534
}
528
535
Ok(ExtractedAuth::Service(claims)) => {
529
-
let did: Did = claims
530
-
.iss
531
-
.parse()
532
-
.map_err(|_| AuthError::AuthenticationFailed)?;
536
+
let did = claims.iss.clone();
533
537
Ok(Some(AuthAny::Service(ServiceAuth { did, claims })))
534
538
}
535
539
Err(AuthError::MissingToken) => Ok(None),
+135
crates/tranquil-pds/src/auth/login_identifier.rs
+135
crates/tranquil-pds/src/auth/login_identifier.rs
···
1
+
use std::fmt;
2
+
3
+
#[derive(Debug, Clone, PartialEq, Eq)]
4
+
pub struct NormalizedLoginIdentifier(String);
5
+
6
+
impl NormalizedLoginIdentifier {
7
+
pub fn normalize(identifier: &str, pds_hostname: &str) -> Self {
8
+
let trimmed = identifier.trim();
9
+
let stripped = trimmed.strip_prefix('@').unwrap_or(trimmed);
10
+
11
+
let normalized = match () {
12
+
_ if stripped.starts_with("did:") => stripped.to_string(),
13
+
_ if stripped.contains('@') => stripped.to_string(),
14
+
_ if !stripped.contains('.') => {
15
+
format!("{}.{}", stripped.to_lowercase(), pds_hostname)
16
+
}
17
+
_ => stripped.to_lowercase(),
18
+
};
19
+
20
+
Self(normalized)
21
+
}
22
+
23
+
pub fn as_str(&self) -> &str {
24
+
&self.0
25
+
}
26
+
}
27
+
28
+
impl AsRef<str> for NormalizedLoginIdentifier {
29
+
fn as_ref(&self) -> &str {
30
+
&self.0
31
+
}
32
+
}
33
+
34
+
impl fmt::Display for NormalizedLoginIdentifier {
35
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36
+
write!(f, "{}", self.0)
37
+
}
38
+
}
39
+
40
+
#[derive(Debug, Clone, PartialEq, Eq)]
41
+
pub struct BareLoginIdentifier(String);
42
+
43
+
impl BareLoginIdentifier {
44
+
pub fn from_identifier(identifier: &str, pds_hostname: &str) -> Self {
45
+
let trimmed = identifier.trim();
46
+
let stripped = trimmed.strip_prefix('@').unwrap_or(trimmed);
47
+
let suffix = format!(".{}", pds_hostname);
48
+
let bare = stripped.strip_suffix(&suffix).unwrap_or(stripped);
49
+
Self(bare.to_string())
50
+
}
51
+
52
+
pub fn as_str(&self) -> &str {
53
+
&self.0
54
+
}
55
+
}
56
+
57
+
impl AsRef<str> for BareLoginIdentifier {
58
+
fn as_ref(&self) -> &str {
59
+
&self.0
60
+
}
61
+
}
62
+
63
+
impl fmt::Display for BareLoginIdentifier {
64
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65
+
write!(f, "{}", self.0)
66
+
}
67
+
}
68
+
69
+
#[cfg(test)]
70
+
mod tests {
71
+
use super::*;
72
+
73
+
#[test]
74
+
fn normalized_identifier_handles_did() {
75
+
let id = NormalizedLoginIdentifier::normalize("did:plc:abc123", "example.com");
76
+
assert_eq!(id.as_str(), "did:plc:abc123");
77
+
}
78
+
79
+
#[test]
80
+
fn normalized_identifier_handles_email() {
81
+
let id = NormalizedLoginIdentifier::normalize("user@example.org", "pds.example.com");
82
+
assert_eq!(id.as_str(), "user@example.org");
83
+
}
84
+
85
+
#[test]
86
+
fn normalized_identifier_handles_bare_handle() {
87
+
let id = NormalizedLoginIdentifier::normalize("alice", "pds.example.com");
88
+
assert_eq!(id.as_str(), "alice.pds.example.com");
89
+
}
90
+
91
+
#[test]
92
+
fn normalized_identifier_handles_bare_handle_with_at_prefix() {
93
+
let id = NormalizedLoginIdentifier::normalize("@alice", "pds.example.com");
94
+
assert_eq!(id.as_str(), "alice.pds.example.com");
95
+
}
96
+
97
+
#[test]
98
+
fn normalized_identifier_handles_full_handle() {
99
+
let id = NormalizedLoginIdentifier::normalize("alice.bsky.social", "pds.example.com");
100
+
assert_eq!(id.as_str(), "alice.bsky.social");
101
+
}
102
+
103
+
#[test]
104
+
fn normalized_identifier_handles_uppercase() {
105
+
let id = NormalizedLoginIdentifier::normalize("ALICE", "pds.example.com");
106
+
assert_eq!(id.as_str(), "alice.pds.example.com");
107
+
108
+
let id2 = NormalizedLoginIdentifier::normalize("ALICE.BSKY.SOCIAL", "pds.example.com");
109
+
assert_eq!(id2.as_str(), "alice.bsky.social");
110
+
}
111
+
112
+
#[test]
113
+
fn normalized_identifier_trims_whitespace() {
114
+
let id = NormalizedLoginIdentifier::normalize(" alice ", "pds.example.com");
115
+
assert_eq!(id.as_str(), "alice.pds.example.com");
116
+
}
117
+
118
+
#[test]
119
+
fn bare_identifier_strips_hostname_suffix() {
120
+
let id = BareLoginIdentifier::from_identifier("alice.pds.example.com", "pds.example.com");
121
+
assert_eq!(id.as_str(), "alice");
122
+
}
123
+
124
+
#[test]
125
+
fn bare_identifier_preserves_non_matching() {
126
+
let id = BareLoginIdentifier::from_identifier("alice.bsky.social", "pds.example.com");
127
+
assert_eq!(id.as_str(), "alice.bsky.social");
128
+
}
129
+
130
+
#[test]
131
+
fn bare_identifier_strips_at_prefix() {
132
+
let id = BareLoginIdentifier::from_identifier("@alice.pds.example.com", "pds.example.com");
133
+
assert_eq!(id.as_str(), "alice");
134
+
}
135
+
}
+234
crates/tranquil-pds/src/auth/mfa_verified.rs
+234
crates/tranquil-pds/src/auth/mfa_verified.rs
···
1
+
use axum::response::Response;
2
+
3
+
use super::AuthenticatedUser;
4
+
use crate::state::AppState;
5
+
use crate::types::Did;
6
+
7
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
8
+
pub enum MfaMethod {
9
+
Totp,
10
+
Passkey,
11
+
Password,
12
+
RecoveryCode,
13
+
SessionReauth,
14
+
}
15
+
16
+
impl MfaMethod {
17
+
pub fn as_str(&self) -> &'static str {
18
+
match self {
19
+
Self::Totp => "totp",
20
+
Self::Passkey => "passkey",
21
+
Self::Password => "password",
22
+
Self::RecoveryCode => "recovery_code",
23
+
Self::SessionReauth => "session_reauth",
24
+
}
25
+
}
26
+
}
27
+
28
+
impl std::fmt::Display for MfaMethod {
29
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30
+
write!(f, "{}", self.as_str())
31
+
}
32
+
}
33
+
34
+
pub struct MfaVerified<'a> {
35
+
user: &'a AuthenticatedUser,
36
+
method: MfaMethod,
37
+
}
38
+
39
+
impl<'a> MfaVerified<'a> {
40
+
fn new(user: &'a AuthenticatedUser, method: MfaMethod) -> Self {
41
+
Self { user, method }
42
+
}
43
+
44
+
pub(crate) fn from_totp(user: &'a AuthenticatedUser) -> Self {
45
+
Self::new(user, MfaMethod::Totp)
46
+
}
47
+
48
+
pub(crate) fn from_password(user: &'a AuthenticatedUser) -> Self {
49
+
Self::new(user, MfaMethod::Password)
50
+
}
51
+
52
+
pub(crate) fn from_recovery_code(user: &'a AuthenticatedUser) -> Self {
53
+
Self::new(user, MfaMethod::RecoveryCode)
54
+
}
55
+
56
+
pub(crate) fn from_session_reauth(user: &'a AuthenticatedUser) -> Self {
57
+
Self::new(user, MfaMethod::SessionReauth)
58
+
}
59
+
60
+
pub fn user(&self) -> &AuthenticatedUser {
61
+
self.user
62
+
}
63
+
64
+
pub fn did(&self) -> &Did {
65
+
&self.user.did
66
+
}
67
+
68
+
pub fn method(&self) -> MfaMethod {
69
+
self.method
70
+
}
71
+
}
72
+
73
+
pub async fn require_legacy_session_mfa<'a>(
74
+
state: &AppState,
75
+
user: &'a AuthenticatedUser,
76
+
) -> Result<MfaVerified<'a>, Response> {
77
+
use crate::api::server::reauth::{check_legacy_session_mfa, legacy_mfa_required_response};
78
+
79
+
if check_legacy_session_mfa(&*state.session_repo, &user.did).await {
80
+
Ok(MfaVerified::from_session_reauth(user))
81
+
} else {
82
+
Err(legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &user.did).await)
83
+
}
84
+
}
85
+
86
+
pub async fn require_reauth_window<'a>(
87
+
state: &AppState,
88
+
user: &'a AuthenticatedUser,
89
+
) -> Result<MfaVerified<'a>, Response> {
90
+
use crate::api::server::reauth::{REAUTH_WINDOW_SECONDS, reauth_required_response};
91
+
use chrono::Utc;
92
+
93
+
let status = state
94
+
.session_repo
95
+
.get_session_mfa_status(&user.did)
96
+
.await
97
+
.ok()
98
+
.flatten();
99
+
100
+
match status {
101
+
Some(s) => {
102
+
if let Some(last_reauth) = s.last_reauth_at {
103
+
let elapsed = Utc::now().signed_duration_since(last_reauth);
104
+
if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS {
105
+
return Ok(MfaVerified::from_session_reauth(user));
106
+
}
107
+
}
108
+
Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await)
109
+
}
110
+
None => {
111
+
Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await)
112
+
}
113
+
}
114
+
}
115
+
116
+
pub async fn require_reauth_window_if_available<'a>(
117
+
state: &AppState,
118
+
user: &'a AuthenticatedUser,
119
+
) -> Result<Option<MfaVerified<'a>>, Response> {
120
+
use crate::api::server::reauth::{check_reauth_required_cached, reauth_required_response};
121
+
122
+
let has_password = state
123
+
.user_repo
124
+
.has_password_by_did(&user.did)
125
+
.await
126
+
.ok()
127
+
.flatten()
128
+
.unwrap_or(false);
129
+
let has_passkeys = state
130
+
.user_repo
131
+
.has_passkeys(&user.did)
132
+
.await
133
+
.unwrap_or(false);
134
+
let has_totp = state
135
+
.user_repo
136
+
.has_totp_enabled(&user.did)
137
+
.await
138
+
.unwrap_or(false);
139
+
140
+
let has_any_reauth_method = has_password || has_passkeys || has_totp;
141
+
142
+
if !has_any_reauth_method {
143
+
return Ok(None);
144
+
}
145
+
146
+
if check_reauth_required_cached(&*state.session_repo, &state.cache, &user.did).await {
147
+
Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await)
148
+
} else {
149
+
Ok(Some(MfaVerified::from_session_reauth(user)))
150
+
}
151
+
}
152
+
153
+
pub async fn verify_password_mfa<'a>(
154
+
state: &AppState,
155
+
user: &'a AuthenticatedUser,
156
+
password: &str,
157
+
) -> Result<MfaVerified<'a>, crate::api::error::ApiError> {
158
+
let hash = state
159
+
.user_repo
160
+
.get_password_hash_by_did(&user.did)
161
+
.await
162
+
.ok()
163
+
.flatten();
164
+
165
+
match hash {
166
+
Some(h) => {
167
+
if bcrypt::verify(password, &h).unwrap_or(false) {
168
+
Ok(MfaVerified::from_password(user))
169
+
} else {
170
+
Err(crate::api::error::ApiError::InvalidPassword(
171
+
"Password is incorrect".into(),
172
+
))
173
+
}
174
+
}
175
+
None => Err(crate::api::error::ApiError::AccountNotFound),
176
+
}
177
+
}
178
+
179
+
pub async fn verify_totp_mfa<'a>(
180
+
state: &AppState,
181
+
user: &'a AuthenticatedUser,
182
+
code: &str,
183
+
) -> Result<MfaVerified<'a>, crate::api::error::ApiError> {
184
+
use crate::auth::{decrypt_totp_secret, is_backup_code_format, verify_totp_code};
185
+
use tranquil_db_traits::TotpRecordState;
186
+
187
+
let code = code.trim();
188
+
189
+
if is_backup_code_format(code) {
190
+
let backup_codes = state
191
+
.user_repo
192
+
.get_unused_backup_codes(&user.did)
193
+
.await
194
+
.ok()
195
+
.unwrap_or_default();
196
+
let code_upper = code.to_uppercase();
197
+
198
+
let matched = backup_codes
199
+
.iter()
200
+
.find(|row| crate::auth::verify_backup_code(&code_upper, &row.code_hash));
201
+
202
+
return match matched {
203
+
Some(row) => {
204
+
let _ = state.user_repo.mark_backup_code_used(row.id).await;
205
+
Ok(MfaVerified::from_recovery_code(user))
206
+
}
207
+
None => Err(crate::api::error::ApiError::InvalidCode(Some(
208
+
"Invalid backup code".into(),
209
+
))),
210
+
};
211
+
}
212
+
213
+
let verified_record = match state.user_repo.get_totp_record_state(&user.did).await {
214
+
Ok(Some(TotpRecordState::Verified(record))) => record,
215
+
_ => {
216
+
return Err(crate::api::error::ApiError::TotpNotEnabled);
217
+
}
218
+
};
219
+
220
+
let secret = decrypt_totp_secret(
221
+
&verified_record.secret_encrypted,
222
+
verified_record.encryption_version,
223
+
)
224
+
.map_err(|_| crate::api::error::ApiError::InternalError(None))?;
225
+
226
+
if verify_totp_code(&secret, code) {
227
+
let _ = state.user_repo.update_totp_last_used(&user.did).await;
228
+
Ok(MfaVerified::from_totp(user))
229
+
} else {
230
+
Err(crate::api::error::ApiError::InvalidCode(Some(
231
+
"Invalid verification code".into(),
232
+
)))
233
+
}
234
+
}
+23
-4
crates/tranquil-pds/src/auth/mod.rs
+23
-4
crates/tranquil-pds/src/auth/mod.rs
···
10
10
use tranquil_db::UserRepository;
11
11
use tranquil_db_traits::OAuthRepository;
12
12
13
+
pub mod account_verified;
13
14
pub mod extractor;
15
+
pub mod login_identifier;
16
+
pub mod mfa_verified;
14
17
pub mod scope_check;
18
+
pub mod scope_verified;
15
19
pub mod service;
16
20
pub mod verification_token;
17
21
pub mod webauthn;
18
22
23
+
pub use login_identifier::{BareLoginIdentifier, NormalizedLoginIdentifier};
24
+
25
+
pub use account_verified::{AccountVerified, require_not_migrated, require_verified_or_delegated};
19
26
pub use extractor::{
20
27
Active, Admin, AnyUser, Auth, AuthAny, AuthError, AuthPolicy, ExtractedToken, NotTakendown,
21
28
Permissive, ServiceAuth, extract_auth_token_from_header, extract_bearer_token_from_header,
22
29
};
30
+
pub use mfa_verified::{
31
+
MfaMethod, MfaVerified, require_legacy_session_mfa, require_reauth_window,
32
+
require_reauth_window_if_available, verify_password_mfa, verify_totp_mfa,
33
+
};
34
+
pub use scope_verified::{
35
+
AccountManage, AccountRead, BatchWriteScopes, BlobScopeAction, BlobUpload, ControllerDid,
36
+
IdentityAccess, PrincipalDid, RepoCreate, RepoDelete, RepoScopeAction, RepoUpdate, RepoUpsert,
37
+
RpcCall, ScopeAction, ScopeVerificationError, ScopeVerified, VerifyScope, WriteOpKind,
38
+
verify_batch_write_scopes,
39
+
};
23
40
pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
24
41
25
42
pub use tranquil_auth::{
···
409
426
.claims
410
427
.act
411
428
.as_ref()
412
-
.map(|a| Did::new_unchecked(a.sub.clone()));
429
+
.map(|a| unsafe { Did::new_unchecked(a.sub.clone()) });
413
430
let status =
414
431
AccountStatus::from_db_fields(takedown_ref.as_deref(), deactivated_at);
415
432
return Ok(AuthenticatedUser {
···
461
478
None
462
479
};
463
480
return Ok(AuthenticatedUser {
464
-
did: Did::new_unchecked(oauth_token.did),
481
+
did: unsafe { Did::new_unchecked(oauth_token.did) },
465
482
key_bytes,
466
483
is_admin: oauth_token.is_admin,
467
484
status,
468
485
scope: oauth_info.scope,
469
-
controller_did: oauth_info.controller_did.map(Did::new_unchecked),
486
+
controller_did: oauth_info
487
+
.controller_did
488
+
.map(|d| unsafe { Did::new_unchecked(d) }),
470
489
auth_source: AuthSource::OAuth,
471
490
});
472
491
} else {
···
545
564
None
546
565
};
547
566
Ok(AuthenticatedUser {
548
-
did: Did::new_unchecked(result.did),
567
+
did: unsafe { Did::new_unchecked(result.did) },
549
568
key_bytes,
550
569
is_admin: user_info.is_admin,
551
570
status,
+502
crates/tranquil-pds/src/auth/scope_verified.rs
+502
crates/tranquil-pds/src/auth/scope_verified.rs
···
1
+
use std::marker::PhantomData;
2
+
use std::ops::Deref;
3
+
4
+
use axum::response::{IntoResponse, Response};
5
+
6
+
use crate::api::error::ApiError;
7
+
use crate::oauth::scopes::{
8
+
AccountAction, AccountAttr, IdentityAttr, RepoAction, ScopePermissions,
9
+
};
10
+
use crate::types::Did;
11
+
12
+
use super::AuthenticatedUser;
13
+
14
+
#[derive(Debug, Clone)]
15
+
pub struct PrincipalDid(Did);
16
+
17
+
impl PrincipalDid {
18
+
pub fn as_did(&self) -> &Did {
19
+
&self.0
20
+
}
21
+
22
+
pub fn into_did(self) -> Did {
23
+
self.0
24
+
}
25
+
26
+
pub fn as_str(&self) -> &str {
27
+
self.0.as_str()
28
+
}
29
+
}
30
+
31
+
impl Deref for PrincipalDid {
32
+
type Target = Did;
33
+
34
+
fn deref(&self) -> &Self::Target {
35
+
&self.0
36
+
}
37
+
}
38
+
39
+
impl AsRef<Did> for PrincipalDid {
40
+
fn as_ref(&self) -> &Did {
41
+
&self.0
42
+
}
43
+
}
44
+
45
+
impl std::fmt::Display for PrincipalDid {
46
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47
+
self.0.fmt(f)
48
+
}
49
+
}
50
+
51
+
#[derive(Debug, Clone)]
52
+
pub struct ControllerDid(Did);
53
+
54
+
impl ControllerDid {
55
+
pub fn as_did(&self) -> &Did {
56
+
&self.0
57
+
}
58
+
59
+
pub fn into_did(self) -> Did {
60
+
self.0
61
+
}
62
+
63
+
pub fn as_str(&self) -> &str {
64
+
self.0.as_str()
65
+
}
66
+
}
67
+
68
+
impl Deref for ControllerDid {
69
+
type Target = Did;
70
+
71
+
fn deref(&self) -> &Self::Target {
72
+
&self.0
73
+
}
74
+
}
75
+
76
+
impl AsRef<Did> for ControllerDid {
77
+
fn as_ref(&self) -> &Did {
78
+
&self.0
79
+
}
80
+
}
81
+
82
+
impl std::fmt::Display for ControllerDid {
83
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84
+
self.0.fmt(f)
85
+
}
86
+
}
87
+
88
+
#[derive(Debug)]
89
+
pub struct ScopeVerificationError {
90
+
message: String,
91
+
}
92
+
93
+
impl ScopeVerificationError {
94
+
pub fn new(message: impl Into<String>) -> Self {
95
+
Self {
96
+
message: message.into(),
97
+
}
98
+
}
99
+
100
+
pub fn message(&self) -> &str {
101
+
&self.message
102
+
}
103
+
}
104
+
105
+
impl std::fmt::Display for ScopeVerificationError {
106
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107
+
write!(f, "{}", self.message)
108
+
}
109
+
}
110
+
111
+
impl std::error::Error for ScopeVerificationError {}
112
+
113
+
impl IntoResponse for ScopeVerificationError {
114
+
fn into_response(self) -> Response {
115
+
ApiError::InsufficientScope(Some(self.message)).into_response()
116
+
}
117
+
}
118
+
119
+
mod private {
120
+
pub trait Sealed {}
121
+
pub trait RepoScopeSealed {}
122
+
pub trait BlobScopeSealed {}
123
+
}
124
+
125
+
pub trait ScopeAction: private::Sealed {}
126
+
127
+
pub trait RepoScopeAction: ScopeAction + private::RepoScopeSealed {}
128
+
129
+
pub trait BlobScopeAction: ScopeAction + private::BlobScopeSealed {}
130
+
131
+
pub struct RepoCreate;
132
+
pub struct RepoUpdate;
133
+
pub struct RepoDelete;
134
+
pub struct RepoUpsert;
135
+
pub struct BlobUpload;
136
+
pub struct RpcCall;
137
+
pub struct AccountRead;
138
+
pub struct AccountManage;
139
+
pub struct IdentityAccess;
140
+
141
+
impl private::Sealed for RepoCreate {}
142
+
impl private::Sealed for RepoUpdate {}
143
+
impl private::Sealed for RepoDelete {}
144
+
impl private::Sealed for RepoUpsert {}
145
+
impl private::Sealed for BlobUpload {}
146
+
impl private::Sealed for RpcCall {}
147
+
impl private::Sealed for AccountRead {}
148
+
impl private::Sealed for AccountManage {}
149
+
impl private::Sealed for IdentityAccess {}
150
+
151
+
impl private::RepoScopeSealed for RepoCreate {}
152
+
impl private::RepoScopeSealed for RepoUpdate {}
153
+
impl private::RepoScopeSealed for RepoDelete {}
154
+
impl private::RepoScopeSealed for RepoUpsert {}
155
+
156
+
impl private::BlobScopeSealed for BlobUpload {}
157
+
158
+
impl ScopeAction for RepoCreate {}
159
+
impl ScopeAction for RepoUpdate {}
160
+
impl ScopeAction for RepoDelete {}
161
+
impl ScopeAction for RepoUpsert {}
162
+
impl ScopeAction for BlobUpload {}
163
+
impl ScopeAction for RpcCall {}
164
+
impl ScopeAction for AccountRead {}
165
+
impl ScopeAction for AccountManage {}
166
+
impl ScopeAction for IdentityAccess {}
167
+
168
+
impl RepoScopeAction for RepoCreate {}
169
+
impl RepoScopeAction for RepoUpdate {}
170
+
impl RepoScopeAction for RepoDelete {}
171
+
impl RepoScopeAction for RepoUpsert {}
172
+
173
+
impl BlobScopeAction for BlobUpload {}
174
+
175
+
pub struct ScopeVerified<'a, A: ScopeAction> {
176
+
user: &'a AuthenticatedUser,
177
+
_action: PhantomData<A>,
178
+
}
179
+
180
+
impl<'a, A: ScopeAction> ScopeVerified<'a, A> {
181
+
pub fn user(&self) -> &AuthenticatedUser {
182
+
self.user
183
+
}
184
+
185
+
pub fn principal_did(&self) -> PrincipalDid {
186
+
PrincipalDid(self.user.did.clone())
187
+
}
188
+
189
+
pub fn controller_did(&self) -> Option<ControllerDid> {
190
+
self.user.controller_did.clone().map(ControllerDid)
191
+
}
192
+
193
+
pub fn is_admin(&self) -> bool {
194
+
self.user.is_admin
195
+
}
196
+
}
197
+
198
+
pub struct BatchWriteScopes<'a> {
199
+
user: &'a AuthenticatedUser,
200
+
has_creates: bool,
201
+
has_updates: bool,
202
+
has_deletes: bool,
203
+
}
204
+
205
+
impl<'a> BatchWriteScopes<'a> {
206
+
pub fn principal_did(&self) -> PrincipalDid {
207
+
PrincipalDid(self.user.did.clone())
208
+
}
209
+
210
+
pub fn controller_did(&self) -> Option<ControllerDid> {
211
+
self.user.controller_did.clone().map(ControllerDid)
212
+
}
213
+
214
+
pub fn user(&self) -> &AuthenticatedUser {
215
+
self.user
216
+
}
217
+
218
+
pub fn has_creates(&self) -> bool {
219
+
self.has_creates
220
+
}
221
+
222
+
pub fn has_updates(&self) -> bool {
223
+
self.has_updates
224
+
}
225
+
226
+
pub fn has_deletes(&self) -> bool {
227
+
self.has_deletes
228
+
}
229
+
}
230
+
231
+
pub fn verify_batch_write_scopes<'a, T, C, F>(
232
+
auth: &'a impl VerifyScope,
233
+
user: &'a AuthenticatedUser,
234
+
writes: &[T],
235
+
get_collection: F,
236
+
classify: C,
237
+
) -> Result<BatchWriteScopes<'a>, ScopeVerificationError>
238
+
where
239
+
F: Fn(&T) -> &str,
240
+
C: Fn(&T) -> WriteOpKind,
241
+
{
242
+
use std::collections::HashSet;
243
+
244
+
let create_collections: HashSet<&str> = writes
245
+
.iter()
246
+
.filter(|w| matches!(classify(w), WriteOpKind::Create))
247
+
.map(&get_collection)
248
+
.collect();
249
+
250
+
let update_collections: HashSet<&str> = writes
251
+
.iter()
252
+
.filter(|w| matches!(classify(w), WriteOpKind::Update))
253
+
.map(&get_collection)
254
+
.collect();
255
+
256
+
let delete_collections: HashSet<&str> = writes
257
+
.iter()
258
+
.filter(|w| matches!(classify(w), WriteOpKind::Delete))
259
+
.map(&get_collection)
260
+
.collect();
261
+
262
+
if auth.needs_scope_check() {
263
+
create_collections.iter().try_for_each(|c| {
264
+
auth.permissions()
265
+
.assert_repo(RepoAction::Create, c)
266
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))
267
+
})?;
268
+
269
+
update_collections.iter().try_for_each(|c| {
270
+
auth.permissions()
271
+
.assert_repo(RepoAction::Update, c)
272
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))
273
+
})?;
274
+
275
+
delete_collections.iter().try_for_each(|c| {
276
+
auth.permissions()
277
+
.assert_repo(RepoAction::Delete, c)
278
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))
279
+
})?;
280
+
}
281
+
282
+
Ok(BatchWriteScopes {
283
+
user,
284
+
has_creates: !create_collections.is_empty(),
285
+
has_updates: !update_collections.is_empty(),
286
+
has_deletes: !delete_collections.is_empty(),
287
+
})
288
+
}
289
+
290
+
#[derive(Clone, Copy)]
291
+
pub enum WriteOpKind {
292
+
Create,
293
+
Update,
294
+
Delete,
295
+
}
296
+
297
+
pub trait VerifyScope {
298
+
fn needs_scope_check(&self) -> bool;
299
+
fn permissions(&self) -> ScopePermissions;
300
+
301
+
fn verify_repo_create<'a>(
302
+
&'a self,
303
+
collection: &str,
304
+
) -> Result<ScopeVerified<'a, RepoCreate>, ScopeVerificationError>
305
+
where
306
+
Self: AsRef<AuthenticatedUser>,
307
+
{
308
+
if !self.needs_scope_check() {
309
+
return Ok(ScopeVerified {
310
+
user: self.as_ref(),
311
+
_action: PhantomData,
312
+
});
313
+
}
314
+
self.permissions()
315
+
.assert_repo(RepoAction::Create, collection)
316
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
317
+
Ok(ScopeVerified {
318
+
user: self.as_ref(),
319
+
_action: PhantomData,
320
+
})
321
+
}
322
+
323
+
fn verify_repo_update<'a>(
324
+
&'a self,
325
+
collection: &str,
326
+
) -> Result<ScopeVerified<'a, RepoUpdate>, ScopeVerificationError>
327
+
where
328
+
Self: AsRef<AuthenticatedUser>,
329
+
{
330
+
if !self.needs_scope_check() {
331
+
return Ok(ScopeVerified {
332
+
user: self.as_ref(),
333
+
_action: PhantomData,
334
+
});
335
+
}
336
+
self.permissions()
337
+
.assert_repo(RepoAction::Update, collection)
338
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
339
+
Ok(ScopeVerified {
340
+
user: self.as_ref(),
341
+
_action: PhantomData,
342
+
})
343
+
}
344
+
345
+
fn verify_repo_delete<'a>(
346
+
&'a self,
347
+
collection: &str,
348
+
) -> Result<ScopeVerified<'a, RepoDelete>, ScopeVerificationError>
349
+
where
350
+
Self: AsRef<AuthenticatedUser>,
351
+
{
352
+
if !self.needs_scope_check() {
353
+
return Ok(ScopeVerified {
354
+
user: self.as_ref(),
355
+
_action: PhantomData,
356
+
});
357
+
}
358
+
self.permissions()
359
+
.assert_repo(RepoAction::Delete, collection)
360
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
361
+
Ok(ScopeVerified {
362
+
user: self.as_ref(),
363
+
_action: PhantomData,
364
+
})
365
+
}
366
+
367
+
fn verify_repo_upsert<'a>(
368
+
&'a self,
369
+
collection: &str,
370
+
) -> Result<ScopeVerified<'a, RepoUpsert>, ScopeVerificationError>
371
+
where
372
+
Self: AsRef<AuthenticatedUser>,
373
+
{
374
+
if !self.needs_scope_check() {
375
+
return Ok(ScopeVerified {
376
+
user: self.as_ref(),
377
+
_action: PhantomData,
378
+
});
379
+
}
380
+
self.permissions()
381
+
.assert_repo(RepoAction::Create, collection)
382
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
383
+
self.permissions()
384
+
.assert_repo(RepoAction::Update, collection)
385
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
386
+
Ok(ScopeVerified {
387
+
user: self.as_ref(),
388
+
_action: PhantomData,
389
+
})
390
+
}
391
+
392
+
fn verify_blob_upload<'a>(
393
+
&'a self,
394
+
mime_type: &str,
395
+
) -> Result<ScopeVerified<'a, BlobUpload>, ScopeVerificationError>
396
+
where
397
+
Self: AsRef<AuthenticatedUser>,
398
+
{
399
+
if !self.needs_scope_check() {
400
+
return Ok(ScopeVerified {
401
+
user: self.as_ref(),
402
+
_action: PhantomData,
403
+
});
404
+
}
405
+
self.permissions()
406
+
.assert_blob(mime_type)
407
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
408
+
Ok(ScopeVerified {
409
+
user: self.as_ref(),
410
+
_action: PhantomData,
411
+
})
412
+
}
413
+
414
+
fn verify_rpc<'a>(
415
+
&'a self,
416
+
aud: &str,
417
+
lxm: &str,
418
+
) -> Result<ScopeVerified<'a, RpcCall>, ScopeVerificationError>
419
+
where
420
+
Self: AsRef<AuthenticatedUser>,
421
+
{
422
+
if !self.needs_scope_check() {
423
+
return Ok(ScopeVerified {
424
+
user: self.as_ref(),
425
+
_action: PhantomData,
426
+
});
427
+
}
428
+
self.permissions()
429
+
.assert_rpc(aud, lxm)
430
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
431
+
Ok(ScopeVerified {
432
+
user: self.as_ref(),
433
+
_action: PhantomData,
434
+
})
435
+
}
436
+
437
+
fn verify_account_read<'a>(
438
+
&'a self,
439
+
attr: AccountAttr,
440
+
) -> Result<ScopeVerified<'a, AccountRead>, ScopeVerificationError>
441
+
where
442
+
Self: AsRef<AuthenticatedUser>,
443
+
{
444
+
if !self.needs_scope_check() {
445
+
return Ok(ScopeVerified {
446
+
user: self.as_ref(),
447
+
_action: PhantomData,
448
+
});
449
+
}
450
+
self.permissions()
451
+
.assert_account(attr, AccountAction::Read)
452
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
453
+
Ok(ScopeVerified {
454
+
user: self.as_ref(),
455
+
_action: PhantomData,
456
+
})
457
+
}
458
+
459
+
fn verify_account_manage<'a>(
460
+
&'a self,
461
+
attr: AccountAttr,
462
+
) -> Result<ScopeVerified<'a, AccountManage>, ScopeVerificationError>
463
+
where
464
+
Self: AsRef<AuthenticatedUser>,
465
+
{
466
+
if !self.needs_scope_check() {
467
+
return Ok(ScopeVerified {
468
+
user: self.as_ref(),
469
+
_action: PhantomData,
470
+
});
471
+
}
472
+
self.permissions()
473
+
.assert_account(attr, AccountAction::Manage)
474
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
475
+
Ok(ScopeVerified {
476
+
user: self.as_ref(),
477
+
_action: PhantomData,
478
+
})
479
+
}
480
+
481
+
fn verify_identity<'a>(
482
+
&'a self,
483
+
attr: IdentityAttr,
484
+
) -> Result<ScopeVerified<'a, IdentityAccess>, ScopeVerificationError>
485
+
where
486
+
Self: AsRef<AuthenticatedUser>,
487
+
{
488
+
if !self.needs_scope_check() {
489
+
return Ok(ScopeVerified {
490
+
user: self.as_ref(),
491
+
_action: PhantomData,
492
+
});
493
+
}
494
+
self.permissions()
495
+
.assert_identity(attr)
496
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
497
+
Ok(ScopeVerified {
498
+
user: self.as_ref(),
499
+
_action: PhantomData,
500
+
})
501
+
}
502
+
}
+10
-9
crates/tranquil-pds/src/auth/service.rs
+10
-9
crates/tranquil-pds/src/auth/service.rs
···
1
+
use crate::types::Did;
2
+
use crate::util::pds_hostname;
1
3
use anyhow::{Result, anyhow};
2
4
use base64::Engine as _;
3
5
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
···
42
44
43
45
#[derive(Debug, Clone, Serialize, Deserialize)]
44
46
pub struct ServiceTokenClaims {
45
-
pub iss: String,
47
+
pub iss: Did,
46
48
#[serde(default)]
47
-
pub sub: Option<String>,
48
-
pub aud: String,
49
+
pub sub: Option<Did>,
50
+
pub aud: Did,
49
51
pub exp: usize,
50
52
#[serde(default)]
51
53
pub iat: Option<usize>,
···
56
58
}
57
59
58
60
impl ServiceTokenClaims {
59
-
pub fn subject(&self) -> &str {
60
-
self.sub.as_deref().unwrap_or(&self.iss)
61
+
pub fn subject(&self) -> &Did {
62
+
self.sub.as_ref().unwrap_or(&self.iss)
61
63
}
62
64
}
63
65
···
78
80
let plc_directory_url = std::env::var("PLC_DIRECTORY_URL")
79
81
.unwrap_or_else(|_| "https://plc.directory".to_string());
80
82
81
-
let pds_hostname =
82
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
83
+
let pds_hostname = pds_hostname();
83
84
let pds_did = format!("did:web:{}", pds_hostname);
84
85
85
86
let client = Client::builder()
···
130
131
return Err(anyhow!("Token expired"));
131
132
}
132
133
133
-
if claims.aud != self.pds_did {
134
+
if claims.aud.as_str() != self.pds_did {
134
135
return Err(anyhow!(
135
136
"Invalid audience: expected {}, got {}",
136
137
self.pds_did,
···
154
155
}
155
156
}
156
157
157
-
let did = &claims.iss;
158
+
let did = claims.iss.as_str();
158
159
let public_key = self.resolve_signing_key(did).await?;
159
160
160
161
let signature_bytes = URL_SAFE_NO_PAD
+1
-1
crates/tranquil-pds/src/auth/webauthn.rs
+1
-1
crates/tranquil-pds/src/auth/webauthn.rs
···
7
7
8
8
impl WebAuthnConfig {
9
9
pub fn new(hostname: &str) -> Result<Self, String> {
10
-
let rp_id = hostname.to_string();
10
+
let rp_id = hostname.split(':').next().unwrap_or(hostname).to_string();
11
11
let rp_origin = Url::parse(&format!("https://{}", hostname))
12
12
.map_err(|e| format!("Invalid origin URL: {}", e))?;
13
13
+101
crates/tranquil-pds/src/cid_types.rs
+101
crates/tranquil-pds/src/cid_types.rs
···
1
+
use cid::Cid;
2
+
use std::fmt;
3
+
use std::str::FromStr;
4
+
5
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
6
+
pub struct CommitCid(Cid);
7
+
8
+
impl CommitCid {
9
+
pub fn new(cid: Cid) -> Self {
10
+
Self(cid)
11
+
}
12
+
13
+
pub fn as_cid(&self) -> &Cid {
14
+
&self.0
15
+
}
16
+
17
+
pub fn into_cid(self) -> Cid {
18
+
self.0
19
+
}
20
+
}
21
+
22
+
impl From<Cid> for CommitCid {
23
+
fn from(cid: Cid) -> Self {
24
+
Self(cid)
25
+
}
26
+
}
27
+
28
+
impl From<CommitCid> for Cid {
29
+
fn from(commit_cid: CommitCid) -> Self {
30
+
commit_cid.0
31
+
}
32
+
}
33
+
34
+
impl FromStr for CommitCid {
35
+
type Err = cid::Error;
36
+
37
+
fn from_str(s: &str) -> Result<Self, Self::Err> {
38
+
Cid::from_str(s).map(Self)
39
+
}
40
+
}
41
+
42
+
impl fmt::Display for CommitCid {
43
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44
+
write!(f, "{}", self.0)
45
+
}
46
+
}
47
+
48
+
impl AsRef<Cid> for CommitCid {
49
+
fn as_ref(&self) -> &Cid {
50
+
&self.0
51
+
}
52
+
}
53
+
54
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
55
+
pub struct RecordCid(Cid);
56
+
57
+
impl RecordCid {
58
+
pub fn new(cid: Cid) -> Self {
59
+
Self(cid)
60
+
}
61
+
62
+
pub fn as_cid(&self) -> &Cid {
63
+
&self.0
64
+
}
65
+
66
+
pub fn into_cid(self) -> Cid {
67
+
self.0
68
+
}
69
+
}
70
+
71
+
impl From<Cid> for RecordCid {
72
+
fn from(cid: Cid) -> Self {
73
+
Self(cid)
74
+
}
75
+
}
76
+
77
+
impl From<RecordCid> for Cid {
78
+
fn from(record_cid: RecordCid) -> Self {
79
+
record_cid.0
80
+
}
81
+
}
82
+
83
+
impl FromStr for RecordCid {
84
+
type Err = cid::Error;
85
+
86
+
fn from_str(s: &str) -> Result<Self, Self::Err> {
87
+
Cid::from_str(s).map(Self)
88
+
}
89
+
}
90
+
91
+
impl fmt::Display for RecordCid {
92
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93
+
write!(f, "{}", self.0)
94
+
}
95
+
}
96
+
97
+
impl AsRef<Cid> for RecordCid {
98
+
fn as_ref(&self) -> &Cid {
99
+
&self.0
100
+
}
101
+
}
+11
-13
crates/tranquil-pds/src/comms/service.rs
+11
-13
crates/tranquil-pds/src/comms/service.rs
···
3
3
use std::time::Duration;
4
4
5
5
use chrono::Utc;
6
-
use tokio::sync::watch;
7
6
use tokio::time::interval;
7
+
use tokio_util::sync::CancellationToken;
8
8
use tracing::{debug, error, info, warn};
9
9
use tranquil_comms::{
10
10
CommsChannel, CommsSender, CommsStatus, CommsType, NewComms, SendError, format_message,
···
96
96
!self.senders.is_empty()
97
97
}
98
98
99
-
pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
99
+
pub async fn run(self, shutdown: CancellationToken) {
100
100
if self.senders.is_empty() {
101
101
warn!(
102
102
"Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured."
···
116
116
error!(error = %e, "Failed to process comms batch");
117
117
}
118
118
}
119
-
_ = shutdown.changed() => {
120
-
if *shutdown.borrow() {
121
-
info!("Comms service shutting down");
122
-
break;
123
-
}
119
+
_ = shutdown.cancelled() => {
120
+
info!("Comms service shutting down");
121
+
break;
124
122
}
125
123
}
126
124
}
···
278
276
&[("hostname", hostname), ("handle", &prefs.handle)],
279
277
);
280
278
let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]);
281
-
let channel = channel_from_str(&prefs.preferred_channel);
279
+
let channel = prefs.preferred_channel;
282
280
infra_repo
283
281
.enqueue_comms(
284
282
Some(user_id),
···
309
307
&[("handle", &prefs.handle), ("code", code)],
310
308
);
311
309
let subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]);
312
-
let channel = channel_from_str(&prefs.preferred_channel);
310
+
let channel = prefs.preferred_channel;
313
311
infra_repo
314
312
.enqueue_comms(
315
313
Some(user_id),
···
422
420
&[("handle", &prefs.handle), ("code", code)],
423
421
);
424
422
let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]);
425
-
let channel = channel_from_str(&prefs.preferred_channel);
423
+
let channel = prefs.preferred_channel;
426
424
infra_repo
427
425
.enqueue_comms(
428
426
Some(user_id),
···
453
451
&[("handle", &prefs.handle), ("token", token)],
454
452
);
455
453
let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]);
456
-
let channel = channel_from_str(&prefs.preferred_channel);
454
+
let channel = prefs.preferred_channel;
457
455
infra_repo
458
456
.enqueue_comms(
459
457
Some(user_id),
···
484
482
&[("handle", &prefs.handle), ("url", recovery_url)],
485
483
);
486
484
let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]);
487
-
let channel = channel_from_str(&prefs.preferred_channel);
485
+
let channel = prefs.preferred_channel;
488
486
infra_repo
489
487
.enqueue_comms(
490
488
Some(user_id),
···
614
612
&[("handle", &prefs.handle), ("code", code)],
615
613
);
616
614
let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]);
617
-
let channel = channel_from_str(&prefs.preferred_channel);
615
+
let channel = prefs.preferred_channel;
618
616
infra_repo
619
617
.enqueue_comms(
620
618
Some(user_id),
+14
-10
crates/tranquil-pds/src/crawlers.rs
+14
-10
crates/tranquil-pds/src/crawlers.rs
···
1
1
use crate::circuit_breaker::CircuitBreaker;
2
2
use crate::sync::firehose::SequencedEvent;
3
+
use crate::util::pds_hostname;
3
4
use reqwest::Client;
4
5
use std::sync::Arc;
5
6
use std::sync::atomic::{AtomicU64, Ordering};
6
7
use std::time::Duration;
7
-
use tokio::sync::{broadcast, watch};
8
+
use tokio::sync::broadcast;
9
+
use tokio_util::sync::CancellationToken;
8
10
use tracing::{debug, error, info, warn};
11
+
use tranquil_db_traits::RepoEventType;
9
12
10
13
const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60;
11
14
···
40
43
}
41
44
42
45
pub fn from_env() -> Option<Self> {
43
-
let hostname = std::env::var("PDS_HOSTNAME").ok()?;
46
+
let hostname = pds_hostname();
47
+
if hostname == "localhost" {
48
+
return None;
49
+
}
44
50
45
51
let crawler_urls: Vec<String> = std::env::var("CRAWLERS")
46
52
.unwrap_or_default()
···
53
59
return None;
54
60
}
55
61
56
-
Some(Self::new(hostname, crawler_urls))
62
+
Some(Self::new(hostname.to_string(), crawler_urls))
57
63
}
58
64
59
65
fn should_notify(&self) -> bool {
···
143
149
pub async fn start_crawlers_service(
144
150
crawlers: Arc<Crawlers>,
145
151
mut firehose_rx: broadcast::Receiver<SequencedEvent>,
146
-
mut shutdown: watch::Receiver<bool>,
152
+
shutdown: CancellationToken,
147
153
) {
148
154
info!(
149
155
hostname = %crawlers.hostname,
···
157
163
result = firehose_rx.recv() => {
158
164
match result {
159
165
Ok(event) => {
160
-
if event.event_type == "commit" {
166
+
if event.event_type == RepoEventType::Commit {
161
167
crawlers.notify_of_update().await;
162
168
}
163
169
}
···
171
177
}
172
178
}
173
179
}
174
-
_ = shutdown.changed() => {
175
-
if *shutdown.borrow() {
176
-
info!("Crawlers service shutting down");
177
-
break;
178
-
}
180
+
_ = shutdown.cancelled() => {
181
+
info!("Crawlers service shutting down");
182
+
break;
179
183
}
180
184
}
181
185
}
+9
-1
crates/tranquil-pds/src/delegation/mod.rs
+9
-1
crates/tranquil-pds/src/delegation/mod.rs
···
1
+
pub mod roles;
1
2
pub mod scopes;
2
3
3
-
pub use scopes::{SCOPE_PRESETS, ScopePreset, intersect_scopes};
4
+
pub use roles::{
5
+
CanAddControllers, CanBeController, CanControlAccounts, verify_can_add_controllers,
6
+
verify_can_be_controller, verify_can_control_accounts,
7
+
};
8
+
pub use scopes::{
9
+
InvalidDelegationScopeError, SCOPE_PRESETS, ScopePreset, ValidatedDelegationScope,
10
+
intersect_scopes, validate_delegation_scopes,
11
+
};
4
12
pub use tranquil_db_traits::DelegationActionType;
+108
crates/tranquil-pds/src/delegation/roles.rs
+108
crates/tranquil-pds/src/delegation/roles.rs
···
1
+
use axum::response::{IntoResponse, Response};
2
+
3
+
use crate::api::error::ApiError;
4
+
use crate::auth::AuthenticatedUser;
5
+
use crate::state::AppState;
6
+
use crate::types::Did;
7
+
8
+
pub struct CanAddControllers<'a> {
9
+
user: &'a AuthenticatedUser,
10
+
}
11
+
12
+
pub struct CanControlAccounts<'a> {
13
+
user: &'a AuthenticatedUser,
14
+
}
15
+
16
+
pub struct CanBeController<'a> {
17
+
controller_did: &'a Did,
18
+
}
19
+
20
+
impl<'a> CanAddControllers<'a> {
21
+
pub fn did(&self) -> &Did {
22
+
&self.user.did
23
+
}
24
+
25
+
pub fn user(&self) -> &AuthenticatedUser {
26
+
self.user
27
+
}
28
+
}
29
+
30
+
impl<'a> CanControlAccounts<'a> {
31
+
pub fn did(&self) -> &Did {
32
+
&self.user.did
33
+
}
34
+
35
+
pub fn user(&self) -> &AuthenticatedUser {
36
+
self.user
37
+
}
38
+
}
39
+
40
+
impl<'a> CanBeController<'a> {
41
+
pub fn did(&self) -> &Did {
42
+
self.controller_did
43
+
}
44
+
}
45
+
46
+
pub async fn verify_can_add_controllers<'a>(
47
+
state: &AppState,
48
+
user: &'a AuthenticatedUser,
49
+
) -> Result<CanAddControllers<'a>, Response> {
50
+
match state.delegation_repo.controls_any_accounts(&user.did).await {
51
+
Ok(true) => Err(ApiError::InvalidDelegation(
52
+
"Cannot add controllers to an account that controls other accounts".into(),
53
+
)
54
+
.into_response()),
55
+
Ok(false) => Ok(CanAddControllers { user }),
56
+
Err(e) => {
57
+
tracing::error!("Failed to check delegation status: {:?}", e);
58
+
Err(
59
+
ApiError::InternalError(Some("Failed to verify delegation status".into()))
60
+
.into_response(),
61
+
)
62
+
}
63
+
}
64
+
}
65
+
66
+
pub async fn verify_can_control_accounts<'a>(
67
+
state: &AppState,
68
+
user: &'a AuthenticatedUser,
69
+
) -> Result<CanControlAccounts<'a>, Response> {
70
+
match state.delegation_repo.has_any_controllers(&user.did).await {
71
+
Ok(true) => Err(ApiError::InvalidDelegation(
72
+
"Cannot create delegated accounts from a controlled account".into(),
73
+
)
74
+
.into_response()),
75
+
Ok(false) => Ok(CanControlAccounts { user }),
76
+
Err(e) => {
77
+
tracing::error!("Failed to check controller status: {:?}", e);
78
+
Err(
79
+
ApiError::InternalError(Some("Failed to verify controller status".into()))
80
+
.into_response(),
81
+
)
82
+
}
83
+
}
84
+
}
85
+
86
+
pub async fn verify_can_be_controller<'a>(
87
+
state: &AppState,
88
+
controller_did: &'a Did,
89
+
) -> Result<CanBeController<'a>, Response> {
90
+
match state
91
+
.delegation_repo
92
+
.has_any_controllers(controller_did)
93
+
.await
94
+
{
95
+
Ok(true) => Err(ApiError::InvalidDelegation(
96
+
"Cannot add a controlled account as a controller".into(),
97
+
)
98
+
.into_response()),
99
+
Ok(false) => Ok(CanBeController { controller_did }),
100
+
Err(e) => {
101
+
tracing::error!("Failed to check controller status: {:?}", e);
102
+
Err(
103
+
ApiError::InternalError(Some("Failed to verify controller status".into()))
104
+
.into_response(),
105
+
)
106
+
}
107
+
}
108
+
}
+7
-29
crates/tranquil-pds/src/delegation/scopes.rs
+7
-29
crates/tranquil-pds/src/delegation/scopes.rs
···
1
1
use std::collections::HashSet;
2
2
3
+
pub use tranquil_db_traits::{
4
+
DbScope as ValidatedDelegationScope, InvalidScopeError as InvalidDelegationScopeError,
5
+
};
6
+
3
7
pub struct ScopePreset {
4
8
pub name: &'static str,
5
9
pub label: &'static str,
···
107
111
}
108
112
}
109
113
110
-
pub fn validate_delegation_scopes(scopes: &str) -> Result<(), String> {
111
-
if scopes.is_empty() {
112
-
return Ok(());
113
-
}
114
-
115
-
scopes.split_whitespace().try_for_each(|scope| {
116
-
let (base, _) = split_scope(scope);
117
-
if is_valid_scope_prefix(base) {
118
-
Ok(())
119
-
} else {
120
-
Err(format!("Invalid scope: {}", scope))
121
-
}
122
-
})
123
-
}
124
-
125
-
fn is_valid_scope_prefix(base: &str) -> bool {
126
-
const VALID_PREFIXES: [&str; 7] = [
127
-
"atproto",
128
-
"repo:",
129
-
"blob:",
130
-
"rpc:",
131
-
"account:",
132
-
"identity:",
133
-
"transition:",
134
-
];
135
-
136
-
VALID_PREFIXES
137
-
.iter()
138
-
.any(|prefix| base == prefix.trim_end_matches(':') || base.starts_with(prefix))
114
+
pub fn validate_delegation_scopes(scopes: &str) -> Result<(), InvalidDelegationScopeError> {
115
+
ValidatedDelegationScope::new(scopes)?;
116
+
Ok(())
139
117
}
140
118
141
119
#[cfg(test)]
+2
-1
crates/tranquil-pds/src/lib.rs
+2
-1
crates/tranquil-pds/src/lib.rs
···
2
2
pub mod appview;
3
3
pub mod auth;
4
4
pub mod cache;
5
+
pub mod cid_types;
5
6
pub mod circuit_breaker;
6
7
pub mod comms;
7
8
pub mod config;
···
35
36
use http::StatusCode;
36
37
use serde_json::json;
37
38
use state::AppState;
38
-
pub use sync::util::AccountStatus;
39
39
use tower::ServiceBuilder;
40
40
use tower_http::cors::{Any, CorsLayer};
41
+
pub use tranquil_db_traits::AccountStatus;
41
42
pub use types::{AccountState, AtIdentifier, AtUri, Did, Handle, Nsid, Rkey};
42
43
43
44
pub fn app(state: AppState) -> Router {
+50
-37
crates/tranquil-pds/src/main.rs
+50
-37
crates/tranquil-pds/src/main.rs
···
1
1
use std::net::SocketAddr;
2
2
use std::process::ExitCode;
3
3
use std::sync::Arc;
4
-
use tokio::sync::watch;
4
+
use tokio_util::sync::CancellationToken;
5
5
use tracing::{error, info, warn};
6
6
use tranquil_pds::comms::{CommsService, DiscordSender, EmailSender, SignalSender, TelegramSender};
7
7
···
34
34
}
35
35
36
36
async fn run() -> Result<(), Box<dyn std::error::Error>> {
37
-
let state = AppState::new().await?;
38
-
tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
37
+
let shutdown = CancellationToken::new();
38
+
39
+
let shutdown_for_panic = shutdown.clone();
40
+
let default_panic_hook = std::panic::take_hook();
41
+
std::panic::set_hook(Box::new(move |info| {
42
+
error!("PANIC: {}", info);
43
+
shutdown_for_panic.cancel();
44
+
default_panic_hook(info);
45
+
}));
39
46
40
-
let (shutdown_tx, shutdown_rx) = watch::channel(false);
47
+
spawn_signal_handler(shutdown.clone());
48
+
49
+
let state = AppState::new(shutdown.clone()).await?;
50
+
tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
41
51
42
52
let backfill_repo_repo = state.repo_repo.clone();
43
53
let backfill_block_store = state.block_store.clone();
···
77
87
comms_service = comms_service.register_sender(signal_sender);
78
88
}
79
89
80
-
let comms_handle = tokio::spawn(comms_service.run(shutdown_rx.clone()));
90
+
let comms_handle = tokio::spawn(comms_service.run(shutdown.clone()));
81
91
82
92
let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() {
83
93
let crawlers = Arc::new(
···
88
98
Some(tokio::spawn(start_crawlers_service(
89
99
crawlers,
90
100
firehose_rx,
91
-
shutdown_rx.clone(),
101
+
shutdown.clone(),
92
102
)))
93
103
} else {
94
104
warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)");
···
102
112
state.backup_repo.clone(),
103
113
state.block_store.clone(),
104
114
backup_storage,
105
-
shutdown_rx.clone(),
115
+
shutdown.clone(),
106
116
)))
107
117
} else {
108
118
warn!("Backup service disabled (BACKUP_S3_BUCKET not set or BACKUP_ENABLED=false)");
···
114
124
state.blob_repo.clone(),
115
125
state.blob_store.clone(),
116
126
state.sso_repo.clone(),
117
-
shutdown_rx,
127
+
shutdown.clone(),
118
128
));
119
129
120
130
let app = tranquil_pds::app(state);
···
136
146
.map_err(|e| format!("Failed to bind to {}: {}", addr, e))?;
137
147
138
148
let server_result = axum::serve(listener, app)
139
-
.with_graceful_shutdown(shutdown_signal(shutdown_tx))
149
+
.with_graceful_shutdown(shutdown.clone().cancelled_owned())
140
150
.await;
141
151
142
152
comms_handle.await.ok();
···
158
168
Ok(())
159
169
}
160
170
161
-
async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) {
162
-
let ctrl_c = async {
163
-
match tokio::signal::ctrl_c().await {
164
-
Ok(()) => {}
165
-
Err(e) => {
166
-
error!("Failed to install Ctrl+C handler: {}", e);
167
-
}
168
-
}
169
-
};
170
-
171
-
#[cfg(unix)]
172
-
let terminate = async {
173
-
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
174
-
Ok(mut signal) => {
175
-
signal.recv().await;
171
+
fn spawn_signal_handler(shutdown: CancellationToken) {
172
+
tokio::spawn(async move {
173
+
let ctrl_c = async {
174
+
match tokio::signal::ctrl_c().await {
175
+
Ok(()) => {}
176
+
Err(e) => {
177
+
error!("Failed to install Ctrl+C handler: {}", e);
178
+
std::future::pending::<()>().await;
179
+
}
176
180
}
177
-
Err(e) => {
178
-
error!("Failed to install SIGTERM handler: {}", e);
179
-
std::future::pending::<()>().await;
181
+
};
182
+
183
+
#[cfg(unix)]
184
+
let terminate = async {
185
+
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
186
+
Ok(mut signal) => {
187
+
signal.recv().await;
188
+
}
189
+
Err(e) => {
190
+
error!("Failed to install SIGTERM handler: {}", e);
191
+
std::future::pending::<()>().await;
192
+
}
180
193
}
181
-
}
182
-
};
194
+
};
183
195
184
-
#[cfg(not(unix))]
185
-
let terminate = std::future::pending::<()>();
196
+
#[cfg(not(unix))]
197
+
let terminate = std::future::pending::<()>();
186
198
187
-
tokio::select! {
188
-
_ = ctrl_c => {},
189
-
_ = terminate => {},
190
-
}
199
+
tokio::select! {
200
+
_ = ctrl_c => {},
201
+
_ = terminate => {},
202
+
}
191
203
192
-
info!("Shutdown signal received, stopping services...");
193
-
shutdown_tx.send(true).ok();
204
+
info!("Shutdown signal received, stopping services...");
205
+
shutdown.cancel();
206
+
});
194
207
}
+32
-64
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
+32
-64
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
···
1
1
use crate::auth::{Active, Auth};
2
2
use crate::delegation::DelegationActionType;
3
-
use crate::state::{AppState, RateLimitKind};
3
+
use crate::rate_limit::{LoginLimit, OAuthRateLimited, TotpVerifyLimit};
4
+
use crate::state::AppState;
4
5
use crate::types::PlainPassword;
5
6
use crate::util::extract_client_ip;
6
7
use axum::{
7
8
Json,
8
9
extract::State,
9
-
http::{HeaderMap, StatusCode},
10
+
http::HeaderMap,
10
11
response::{IntoResponse, Response},
11
12
};
12
13
use serde::{Deserialize, Serialize};
···
35
36
36
37
pub async fn delegation_auth(
37
38
State(state): State<AppState>,
39
+
rate_limit: OAuthRateLimited<LoginLimit>,
38
40
headers: HeaderMap,
39
41
Json(form): Json<DelegationAuthSubmit>,
40
42
) -> Response {
41
-
let client_ip = extract_client_ip(&headers);
42
-
if !state
43
-
.check_rate_limit(RateLimitKind::Login, &client_ip)
44
-
.await
45
-
{
46
-
return (
47
-
StatusCode::TOO_MANY_REQUESTS,
48
-
Json(DelegationAuthResponse {
49
-
success: false,
50
-
needs_totp: None,
51
-
redirect_uri: None,
52
-
error: Some("Too many login attempts. Please try again later.".to_string()),
53
-
}),
54
-
)
55
-
.into_response();
56
-
}
57
-
43
+
let client_ip = rate_limit.client_ip();
58
44
let request_id = RequestId::from(form.request_uri.clone());
59
45
let request = match state
60
46
.oauth_repo
···
82
68
}
83
69
};
84
70
85
-
let delegated_did_str = match form.delegated_did.as_ref().or(request.did.as_ref()) {
86
-
Some(did) => did.clone(),
87
-
None => {
88
-
return Json(DelegationAuthResponse {
89
-
success: false,
90
-
needs_totp: None,
91
-
redirect_uri: None,
92
-
error: Some("No delegated account selected".to_string()),
93
-
})
94
-
.into_response();
95
-
}
96
-
};
97
-
98
-
let delegated_did: Did = match delegated_did_str.parse() {
99
-
Ok(d) => d,
100
-
Err(_) => {
101
-
return Json(DelegationAuthResponse {
102
-
success: false,
103
-
needs_totp: None,
104
-
redirect_uri: None,
105
-
error: Some("Invalid delegated DID".to_string()),
106
-
})
107
-
.into_response();
71
+
let delegated_did: Did = if let Some(did_str) = form.delegated_did.as_ref() {
72
+
match did_str.parse() {
73
+
Ok(d) => d,
74
+
Err(_) => {
75
+
return Json(DelegationAuthResponse {
76
+
success: false,
77
+
needs_totp: None,
78
+
redirect_uri: None,
79
+
error: Some("Invalid delegated DID".to_string()),
80
+
})
81
+
.into_response();
82
+
}
108
83
}
84
+
} else if let Some(did) = request.did.as_ref() {
85
+
did.clone()
86
+
} else {
87
+
return Json(DelegationAuthResponse {
88
+
success: false,
89
+
needs_totp: None,
90
+
redirect_uri: None,
91
+
error: Some("No delegated account selected".to_string()),
92
+
})
93
+
.into_response();
109
94
};
110
95
111
96
let controller_did: Did = match form.controller_did.parse() {
···
249
234
.into_response();
250
235
}
251
236
252
-
let ip = extract_client_ip(&headers);
253
237
let user_agent = headers
254
238
.get("user-agent")
255
239
.and_then(|v| v.to_str().ok())
···
266
250
"client_id": request.client_id,
267
251
"granted_scopes": grant.granted_scopes
268
252
})),
269
-
Some(&ip),
253
+
Some(client_ip),
270
254
user_agent.as_deref(),
271
255
)
272
256
.await;
···
291
275
292
276
pub async fn delegation_totp_verify(
293
277
State(state): State<AppState>,
278
+
rate_limit: OAuthRateLimited<TotpVerifyLimit>,
294
279
headers: HeaderMap,
295
280
Json(form): Json<DelegationTotpSubmit>,
296
281
) -> Response {
297
-
let client_ip = extract_client_ip(&headers);
298
-
if !state
299
-
.check_rate_limit(RateLimitKind::TotpVerify, &client_ip)
300
-
.await
301
-
{
302
-
return (
303
-
StatusCode::TOO_MANY_REQUESTS,
304
-
Json(DelegationAuthResponse {
305
-
success: false,
306
-
needs_totp: None,
307
-
redirect_uri: None,
308
-
error: Some("Too many verification attempts. Please try again later.".to_string()),
309
-
}),
310
-
)
311
-
.into_response();
312
-
}
313
-
282
+
let client_ip = rate_limit.client_ip();
314
283
let totp_request_id = RequestId::from(form.request_uri.clone());
315
284
let request = match state
316
285
.oauth_repo
···
420
389
.into_response();
421
390
}
422
391
423
-
let ip = extract_client_ip(&headers);
424
392
let user_agent = headers
425
393
.get("user-agent")
426
394
.and_then(|v| v.to_str().ok())
···
437
405
"client_id": request.client_id,
438
406
"granted_scopes": grant.granted_scopes
439
407
})),
440
-
Some(&ip),
408
+
Some(client_ip),
441
409
user_agent.as_deref(),
442
410
)
443
411
.await;
···
564
532
.into_response();
565
533
}
566
534
567
-
let ip = extract_client_ip(&headers);
535
+
let ip = extract_client_ip(&headers, None);
568
536
let user_agent = headers
569
537
.get("user-agent")
570
538
.and_then(|v| v.to_str().ok())
+3
-2
crates/tranquil-pds/src/oauth/endpoints/metadata.rs
+3
-2
crates/tranquil-pds/src/oauth/endpoints/metadata.rs
···
1
1
use crate::oauth::jwks::{JwkSet, create_jwk_set};
2
2
use crate::state::AppState;
3
+
use crate::util::pds_hostname;
3
4
use axum::{Json, extract::State};
4
5
use serde::{Deserialize, Serialize};
5
6
···
57
58
pub async fn oauth_protected_resource(
58
59
State(_state): State<AppState>,
59
60
) -> Json<ProtectedResourceMetadata> {
60
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
61
+
let pds_hostname = pds_hostname();
61
62
let public_url = format!("https://{}", pds_hostname);
62
63
Json(ProtectedResourceMetadata {
63
64
resource: public_url.clone(),
···
71
72
pub async fn oauth_authorization_server(
72
73
State(_state): State<AppState>,
73
74
) -> Json<AuthorizationServerMetadata> {
74
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
75
+
let pds_hostname = pds_hostname();
75
76
let issuer = format!("https://{}", pds_hostname);
76
77
Json(AuthorizationServerMetadata {
77
78
issuer: issuer.clone(),
+58
-50
crates/tranquil-pds/src/oauth/endpoints/par.rs
+58
-50
crates/tranquil-pds/src/oauth/endpoints/par.rs
···
1
1
use crate::oauth::{
2
-
AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, OAuthError, RequestData,
3
-
RequestId,
2
+
AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, CodeChallengeMethod,
3
+
OAuthError, Prompt, RequestData, RequestId, ResponseMode, ResponseType,
4
4
scopes::{ParsedScope, parse_scope},
5
5
};
6
-
use crate::state::{AppState, RateLimitKind};
6
+
use crate::rate_limit::{OAuthParLimit, OAuthRateLimited};
7
+
use crate::state::AppState;
7
8
use axum::body::Bytes;
8
9
use axum::{Json, extract::State, http::HeaderMap};
9
10
use chrono::{Duration, Utc};
···
49
50
50
51
pub async fn pushed_authorization_request(
51
52
State(state): State<AppState>,
53
+
_rate_limit: OAuthRateLimited<OAuthParLimit>,
52
54
headers: HeaderMap,
53
55
body: Bytes,
54
56
) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
···
70
72
.to_string(),
71
73
));
72
74
};
73
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
74
-
if !state
75
-
.check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
76
-
.await
77
-
{
78
-
tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
79
-
return Err(OAuthError::RateLimited);
80
-
}
81
-
if request.response_type != "code" {
82
-
return Err(OAuthError::InvalidRequest(
83
-
"response_type must be 'code'".to_string(),
84
-
));
85
-
}
75
+
let response_type = parse_response_type(&request.response_type)?;
86
76
let code_challenge = request
87
77
.code_challenge
88
78
.as_ref()
89
79
.filter(|s| !s.is_empty())
90
80
.ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?;
91
-
let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
92
-
if code_challenge_method != "S256" {
93
-
return Err(OAuthError::InvalidRequest(
94
-
"code_challenge_method must be 'S256'".to_string(),
95
-
));
96
-
}
81
+
let code_challenge_method =
82
+
parse_code_challenge_method(request.code_challenge_method.as_deref())?;
97
83
let client_cache = ClientMetadataCache::new(3600);
98
84
let client_metadata = client_cache.get(&request.client_id).await?;
99
85
client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
···
101
87
let validated_scope = validate_scope(&request.scope, &client_metadata)?;
102
88
let request_id = RequestId::generate();
103
89
let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
104
-
let response_mode = match request.response_mode.as_deref() {
105
-
Some("fragment") => Some("fragment".to_string()),
106
-
Some("query") | None => None,
107
-
Some(mode) => {
108
-
return Err(OAuthError::InvalidRequest(format!(
109
-
"Unsupported response_mode: {}",
110
-
mode
111
-
)));
112
-
}
113
-
};
114
-
let prompt = validate_prompt(&request.prompt)?;
90
+
let response_mode = parse_response_mode(request.response_mode.as_deref())?;
91
+
let prompt = parse_prompt(request.prompt.as_deref())?;
115
92
let parameters = AuthorizationRequestParameters {
116
-
response_type: request.response_type,
93
+
response_type,
117
94
client_id: request.client_id.clone(),
118
95
redirect_uri: request.redirect_uri,
119
96
scope: validated_scope,
120
97
state: request.state,
121
98
code_challenge: code_challenge.clone(),
122
-
code_challenge_method: code_challenge_method.to_string(),
99
+
code_challenge_method,
123
100
response_mode,
124
101
login_hint: request.login_hint,
125
102
dpop_jkt: request.dpop_jkt,
···
266
243
false
267
244
}
268
245
269
-
fn validate_prompt(prompt: &Option<String>) -> Result<Option<String>, OAuthError> {
270
-
const VALID_PROMPTS: &[&str] = &["none", "login", "consent", "select_account", "create"];
246
+
fn parse_response_type(value: &str) -> Result<ResponseType, OAuthError> {
247
+
match value {
248
+
"code" => Ok(ResponseType::Code),
249
+
other => Err(OAuthError::InvalidRequest(format!(
250
+
"response_type must be 'code', got '{}'",
251
+
other
252
+
))),
253
+
}
254
+
}
271
255
272
-
match prompt {
273
-
None => Ok(None),
274
-
Some(p) if p.is_empty() => Ok(None),
275
-
Some(p) => {
276
-
if VALID_PROMPTS.contains(&p.as_str()) {
277
-
Ok(Some(p.clone()))
278
-
} else {
279
-
Err(OAuthError::InvalidRequest(format!(
280
-
"Unsupported prompt value: {}",
281
-
p
282
-
)))
283
-
}
284
-
}
256
+
fn parse_code_challenge_method(value: Option<&str>) -> Result<CodeChallengeMethod, OAuthError> {
257
+
match value {
258
+
Some("S256") | None => Ok(CodeChallengeMethod::S256),
259
+
Some("plain") => Err(OAuthError::InvalidRequest(
260
+
"code_challenge_method 'plain' is not allowed, use 'S256'".to_string(),
261
+
)),
262
+
Some(other) => Err(OAuthError::InvalidRequest(format!(
263
+
"Unsupported code_challenge_method: {}",
264
+
other
265
+
))),
266
+
}
267
+
}
268
+
269
+
fn parse_response_mode(value: Option<&str>) -> Result<Option<ResponseMode>, OAuthError> {
270
+
match value {
271
+
None | Some("query") => Ok(None),
272
+
Some("fragment") => Ok(Some(ResponseMode::Fragment)),
273
+
Some("form_post") => Ok(Some(ResponseMode::FormPost)),
274
+
Some(other) => Err(OAuthError::InvalidRequest(format!(
275
+
"Unsupported response_mode: {}",
276
+
other
277
+
))),
278
+
}
279
+
}
280
+
281
+
fn parse_prompt(value: Option<&str>) -> Result<Option<Prompt>, OAuthError> {
282
+
match value {
283
+
None | Some("") => Ok(None),
284
+
Some("none") => Ok(Some(Prompt::None)),
285
+
Some("login") => Ok(Some(Prompt::Login)),
286
+
Some("consent") => Ok(Some(Prompt::Consent)),
287
+
Some("select_account") => Ok(Some(Prompt::SelectAccount)),
288
+
Some("create") => Ok(Some(Prompt::Create)),
289
+
Some(other) => Err(OAuthError::InvalidRequest(format!(
290
+
"Unsupported prompt value: {}",
291
+
other
292
+
))),
285
293
}
286
294
}
+42
-49
crates/tranquil-pds/src/oauth/endpoints/token/grants.rs
+42
-49
crates/tranquil-pds/src/oauth/endpoints/token/grants.rs
···
3
3
use crate::config::AuthConfig;
4
4
use crate::delegation::intersect_scopes;
5
5
use crate::oauth::{
6
-
AuthFlowState, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken,
7
-
TokenData, TokenId,
6
+
AuthFlow, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken, TokenData,
7
+
TokenId,
8
8
db::{enforce_token_limit_for_user, lookup_refresh_token},
9
9
scopes::expand_include_scopes,
10
10
verify_client_auth,
11
11
};
12
12
use crate::state::AppState;
13
+
use crate::util::pds_hostname;
13
14
use axum::Json;
14
15
use axum::http::HeaderMap;
15
16
use chrono::{Duration, Utc};
···
51
52
.map_err(crate::oauth::db_err_to_oauth)?
52
53
.ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
53
54
54
-
let flow_state = AuthFlowState::from_request_data(&auth_request);
55
-
if flow_state.is_expired() {
56
-
return Err(OAuthError::InvalidGrant(
57
-
"Authorization code has expired".to_string(),
58
-
));
59
-
}
60
-
if !flow_state.can_exchange() {
61
-
return Err(OAuthError::InvalidGrant(
62
-
"Authorization not completed".to_string(),
63
-
));
64
-
}
55
+
let flow = AuthFlow::from_request_data(auth_request)
56
+
.map_err(|_| OAuthError::InvalidGrant("Authorization code has expired".to_string()))?;
57
+
58
+
let authorized = flow
59
+
.require_authorized()
60
+
.map_err(|_| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
65
61
66
62
if let Some(request_client_id) = &request.client_auth.client_id
67
-
&& request_client_id != &auth_request.client_id
63
+
&& request_client_id != &authorized.client_id
68
64
{
69
65
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
70
66
}
71
-
let did = flow_state.did().unwrap().to_string();
67
+
let did = authorized.did.to_string();
72
68
let client_metadata_cache = ClientMetadataCache::new(3600);
73
-
let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?;
69
+
let client_metadata = client_metadata_cache.get(&authorized.client_id).await?;
74
70
let client_auth = if let (Some(assertion), Some(assertion_type)) = (
75
71
&request.client_auth.client_assertion,
76
72
&request.client_auth.client_assertion_type,
···
91
87
ClientAuth::None
92
88
};
93
89
verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
94
-
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
90
+
verify_pkce(&authorized.parameters.code_challenge, &code_verifier)?;
95
91
if let Some(req_redirect_uri) = &redirect_uri
96
-
&& req_redirect_uri != &auth_request.parameters.redirect_uri
92
+
&& req_redirect_uri != &authorized.parameters.redirect_uri
97
93
{
98
94
return Err(OAuthError::InvalidGrant(
99
95
"redirect_uri mismatch".to_string(),
···
102
98
let dpop_jkt = if let Some(proof) = &dpop_proof {
103
99
let config = AuthConfig::get();
104
100
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
105
-
let pds_hostname =
106
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
101
+
let pds_hostname = pds_hostname();
107
102
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
108
103
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
109
104
if !state
···
116
111
"DPoP proof has already been used".to_string(),
117
112
));
118
113
}
119
-
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt
114
+
if let Some(expected_jkt) = &authorized.parameters.dpop_jkt
120
115
&& result.jkt.as_str() != expected_jkt
121
116
{
122
117
return Err(OAuthError::InvalidDpopProof(
···
124
119
));
125
120
}
126
121
Some(result.jkt.as_str().to_string())
127
-
} else if auth_request.parameters.dpop_jkt.is_some() || client_metadata.requires_dpop() {
122
+
} else if authorized.parameters.dpop_jkt.is_some() || client_metadata.requires_dpop() {
128
123
return Err(OAuthError::UseDpopNonce(
129
124
DPoPVerifier::new(AuthConfig::get().dpop_secret().as_bytes()).generate_nonce(),
130
125
));
···
135
130
let refresh_token = RefreshToken::generate();
136
131
let now = Utc::now();
137
132
138
-
let (raw_scope, controller_did) = if let Some(ref controller) = auth_request.controller_did {
133
+
let (raw_scope, controller_did) = if let Some(ref controller) = authorized.controller_did {
139
134
let did_parsed: Did = did
140
135
.parse()
141
136
.map_err(|_| OAuthError::InvalidRequest("Invalid DID format".to_string()))?;
···
149
144
.ok()
150
145
.flatten();
151
146
let granted_scopes = grant.map(|g| g.granted_scopes).unwrap_or_default();
152
-
let requested = auth_request
153
-
.parameters
154
-
.scope
155
-
.as_deref()
156
-
.unwrap_or("atproto");
157
-
let intersected = intersect_scopes(requested, &granted_scopes);
147
+
let requested = authorized.parameters.scope.as_deref().unwrap_or("atproto");
148
+
let intersected = intersect_scopes(requested, granted_scopes.as_str());
158
149
(Some(intersected), Some(controller.clone()))
159
150
} else {
160
-
(auth_request.parameters.scope.clone(), None)
151
+
(authorized.parameters.scope.clone(), None)
161
152
};
162
153
163
154
let final_scope = if let Some(ref scope) = raw_scope {
···
177
168
final_scope.as_deref(),
178
169
controller_did.as_deref(),
179
170
)?;
180
-
let stored_client_auth = auth_request.client_auth.unwrap_or(ClientAuth::None);
171
+
let stored_client_auth = authorized.client_auth.unwrap_or(ClientAuth::None);
181
172
let refresh_expiry_days = if matches!(stored_client_auth, ClientAuth::None) {
182
173
REFRESH_TOKEN_EXPIRY_DAYS_PUBLIC
183
174
} else {
184
175
REFRESH_TOKEN_EXPIRY_DAYS_CONFIDENTIAL
185
176
};
186
-
let mut stored_parameters = auth_request.parameters.clone();
177
+
let mut stored_parameters = authorized.parameters.clone();
187
178
stored_parameters.dpop_jkt = dpop_jkt.clone();
179
+
let did_typed: Did = did
180
+
.parse()
181
+
.map_err(|_| OAuthError::InvalidRequest("Invalid DID format".to_string()))?;
188
182
let token_data = TokenData {
189
-
did: did.clone(),
190
-
token_id: token_id.0.clone(),
183
+
did: did_typed,
184
+
token_id: token_id.clone(),
191
185
created_at: now,
192
186
updated_at: now,
193
187
expires_at: now + Duration::days(refresh_expiry_days),
194
-
client_id: auth_request.client_id.clone(),
188
+
client_id: authorized.client_id.clone(),
195
189
client_auth: stored_client_auth,
196
-
device_id: auth_request.device_id,
190
+
device_id: authorized.device_id.clone(),
197
191
parameters: stored_parameters,
198
192
details: None,
199
193
code: None,
200
-
current_refresh_token: Some(refresh_token.0.clone()),
194
+
current_refresh_token: Some(refresh_token.clone()),
201
195
scope: final_scope.clone(),
202
196
controller_did: controller_did.clone(),
203
197
};
···
209
203
tracing::info!(
210
204
did = %did,
211
205
token_id = %token_id.0,
212
-
client_id = %auth_request.client_id,
206
+
client_id = %authorized.client_id,
213
207
"Authorization code grant completed, token created"
214
208
);
215
209
tokio::spawn({
···
280
274
);
281
275
let dpop_jkt = token_data.parameters.dpop_jkt.as_deref();
282
276
let access_token = create_access_token_with_delegation(
283
-
&token_data.token_id,
284
-
&token_data.did,
277
+
&token_data.token_id.0,
278
+
token_data.did.as_str(),
285
279
dpop_jkt,
286
280
token_data.scope.as_deref(),
287
-
token_data.controller_did.as_deref(),
281
+
token_data.controller_did.as_ref().map(|d| d.as_str()),
288
282
)?;
289
283
let mut response_headers = HeaderMap::new();
290
284
let config = AuthConfig::get();
···
296
290
access_token,
297
291
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
298
292
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
299
-
refresh_token: token_data.current_refresh_token,
293
+
refresh_token: token_data.current_refresh_token.map(|r| r.0),
300
294
scope: token_data.scope,
301
-
sub: Some(token_data.did),
295
+
sub: Some(token_data.did.to_string()),
302
296
}),
303
297
));
304
298
}
···
337
331
let dpop_jkt = if let Some(proof) = &dpop_proof {
338
332
let config = AuthConfig::get();
339
333
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
340
-
let pds_hostname =
341
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
334
+
let pds_hostname = pds_hostname();
342
335
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
343
336
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
344
337
if !state
···
385
378
"Refresh token rotated successfully"
386
379
);
387
380
let access_token = create_access_token_with_delegation(
388
-
&token_data.token_id,
389
-
&token_data.did,
381
+
&token_data.token_id.0,
382
+
token_data.did.as_str(),
390
383
dpop_jkt.as_deref(),
391
384
token_data.scope.as_deref(),
392
-
token_data.controller_did.as_deref(),
385
+
token_data.controller_did.as_ref().map(|d| d.as_str()),
393
386
)?;
394
387
let mut response_headers = HeaderMap::new();
395
388
let config = AuthConfig::get();
···
403
396
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
404
397
refresh_token: Some(new_refresh_token.0),
405
398
scope: token_data.scope,
406
-
sub: Some(token_data.did),
399
+
sub: Some(token_data.did.to_string()),
407
400
}),
408
401
))
409
402
}
+2
-1
crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs
+2
-1
crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs
···
1
1
use crate::config::AuthConfig;
2
2
use crate::oauth::OAuthError;
3
+
use crate::util::pds_hostname;
3
4
use base64::Engine;
4
5
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5
6
use chrono::Utc;
···
51
52
) -> Result<String, OAuthError> {
52
53
use serde_json::json;
53
54
let jti = uuid::Uuid::new_v4().to_string();
54
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
55
+
let pds_hostname = pds_hostname();
55
56
let issuer = format!("https://{}", pds_hostname);
56
57
let now = Utc::now().timestamp();
57
58
let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
+8
-22
crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs
+8
-22
crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs
···
1
1
use super::helpers::extract_token_claims;
2
2
use crate::oauth::OAuthError;
3
-
use crate::state::{AppState, RateLimitKind};
3
+
use crate::rate_limit::{OAuthIntrospectLimit, OAuthRateLimited};
4
+
use crate::state::AppState;
5
+
use crate::util::pds_hostname;
4
6
use axum::extract::State;
5
-
use axum::http::{HeaderMap, StatusCode};
7
+
use axum::http::StatusCode;
6
8
use axum::{Form, Json};
7
9
use chrono::Utc;
8
10
use serde::{Deserialize, Serialize};
···
17
19
18
20
pub async fn revoke_token(
19
21
State(state): State<AppState>,
20
-
headers: HeaderMap,
22
+
_rate_limit: OAuthRateLimited<OAuthIntrospectLimit>,
21
23
Form(request): Form<RevokeRequest>,
22
24
) -> Result<StatusCode, OAuthError> {
23
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
24
-
if !state
25
-
.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip)
26
-
.await
27
-
{
28
-
tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded");
29
-
return Err(OAuthError::RateLimited);
30
-
}
31
25
if let Some(token) = &request.token {
32
26
let refresh_token = RefreshToken::from(token.clone());
33
27
if let Some((db_id, _)) = state
···
89
83
90
84
pub async fn introspect_token(
91
85
State(state): State<AppState>,
92
-
headers: HeaderMap,
86
+
_rate_limit: OAuthRateLimited<OAuthIntrospectLimit>,
93
87
Form(request): Form<IntrospectRequest>,
94
88
) -> Result<Json<IntrospectResponse>, OAuthError> {
95
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
96
-
if !state
97
-
.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip)
98
-
.await
99
-
{
100
-
tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded");
101
-
return Err(OAuthError::RateLimited);
102
-
}
103
89
let inactive_response = IntrospectResponse {
104
90
active: false,
105
91
scope: None,
···
126
112
if token_data.expires_at < Utc::now() {
127
113
return Ok(Json(inactive_response));
128
114
}
129
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
115
+
let pds_hostname = pds_hostname();
130
116
let issuer = format!("https://{}", pds_hostname);
131
117
Ok(Json(IntrospectResponse {
132
118
active: true,
···
141
127
exp: Some(token_info.exp),
142
128
iat: Some(token_info.iat),
143
129
nbf: Some(token_info.iat),
144
-
sub: Some(token_data.did),
130
+
sub: Some(token_data.did.to_string()),
145
131
aud: Some(issuer.clone()),
146
132
iss: Some(issuer),
147
133
jti: Some(token_info.jti),
+3
-26
crates/tranquil-pds/src/oauth/endpoints/token/mod.rs
+3
-26
crates/tranquil-pds/src/oauth/endpoints/token/mod.rs
···
4
4
mod types;
5
5
6
6
use crate::oauth::OAuthError;
7
-
use crate::state::{AppState, RateLimitKind};
7
+
use crate::rate_limit::{OAuthRateLimited, OAuthTokenLimit};
8
+
use crate::state::AppState;
8
9
use axum::body::Bytes;
9
10
use axum::{Json, extract::State, http::HeaderMap};
10
11
···
17
18
ClientAuthParams, GrantType, TokenGrant, TokenRequest, TokenResponse, ValidatedTokenRequest,
18
19
};
19
20
20
-
fn extract_client_ip(headers: &HeaderMap) -> String {
21
-
if let Some(forwarded) = headers.get("x-forwarded-for")
22
-
&& let Ok(value) = forwarded.to_str()
23
-
&& let Some(first_ip) = value.split(',').next()
24
-
{
25
-
return first_ip.trim().to_string();
26
-
}
27
-
if let Some(real_ip) = headers.get("x-real-ip")
28
-
&& let Ok(value) = real_ip.to_str()
29
-
{
30
-
return value.trim().to_string();
31
-
}
32
-
"unknown".to_string()
33
-
}
34
-
35
21
pub async fn token_endpoint(
36
22
State(state): State<AppState>,
23
+
_rate_limit: OAuthRateLimited<OAuthTokenLimit>,
37
24
headers: HeaderMap,
38
25
body: Bytes,
39
26
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
···
53
40
.to_string(),
54
41
));
55
42
};
56
-
let client_ip = extract_client_ip(&headers);
57
-
if !state
58
-
.check_rate_limit(RateLimitKind::OAuthToken, &client_ip)
59
-
.await
60
-
{
61
-
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
62
-
return Err(OAuthError::InvalidRequest(
63
-
"Too many requests. Please try again later.".to_string(),
64
-
));
65
-
}
66
43
let dpop_proof = headers
67
44
.get("DPoP")
68
45
.and_then(|v| v.to_str().ok())
+9
-7
crates/tranquil-pds/src/oauth/mod.rs
+9
-7
crates/tranquil-pds/src/oauth/mod.rs
···
10
10
}
11
11
12
12
pub use tranquil_oauth::{
13
-
AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata,
14
-
AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code, DPoPClaims,
15
-
DPoPJwk, DPoPProofHeader, DPoPProofPayload, DPoPVerifier, DPoPVerifyResult, DeviceData,
16
-
DeviceId, JwkPublicKey, Jwks, OAuthClientMetadata, OAuthError, ParResponse,
17
-
ProtectedResourceMetadata, RefreshToken, RefreshTokenState, RequestData, RequestId, SessionId,
18
-
TokenData, TokenId, TokenRequest, TokenResponse, compute_access_token_hash,
19
-
compute_jwk_thumbprint, verify_client_auth,
13
+
AuthFlow, AuthFlowWithUser, AuthorizationRequestParameters, AuthorizationServerMetadata,
14
+
AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code,
15
+
CodeChallengeMethod, DPoPClaims, DPoPJwk, DPoPProofHeader, DPoPProofPayload, DPoPVerifier,
16
+
DPoPVerifyResult, DeviceData, DeviceId, FlowAuthenticated, FlowAuthorized, FlowExpired,
17
+
FlowNotAuthenticated, FlowNotAuthorized, FlowPending, JwkPublicKey, Jwks, OAuthClientMetadata,
18
+
OAuthError, ParResponse, Prompt, ProtectedResourceMetadata, RefreshToken, RefreshTokenState,
19
+
RequestData, RequestId, ResponseMode, ResponseType, SessionId, TokenData, TokenId,
20
+
TokenRequest, TokenResponse, compute_access_token_hash, compute_jwk_thumbprint,
21
+
verify_client_auth,
20
22
};
21
23
22
24
pub use scopes::{AccountAction, AccountAttr, RepoAction, ScopeError, ScopePermissions};
+20
-14
crates/tranquil-pds/src/oauth/verify.rs
+20
-14
crates/tranquil-pds/src/oauth/verify.rs
···
20
20
use crate::state::AppState;
21
21
22
22
pub struct OAuthTokenInfo {
23
-
pub did: String,
24
-
pub token_id: String,
25
-
pub client_id: String,
23
+
pub did: Did,
24
+
pub token_id: TokenId,
25
+
pub client_id: ClientId,
26
26
pub scope: Option<String>,
27
27
pub dpop_jkt: Option<String>,
28
-
pub controller_did: Option<String>,
28
+
pub controller_did: Option<Did>,
29
29
}
30
30
31
31
pub struct VerifyResult {
···
48
48
has_dpop_proof = dpop_proof.is_some(),
49
49
"Verifying OAuth access token"
50
50
);
51
-
let token_id = TokenId::from(token_info.token_id.clone());
51
+
let token_id = token_info.token_id.clone();
52
52
let token_data = oauth_repo
53
53
.get_token_by_id(&token_id)
54
54
.await
···
154
154
if exp < now {
155
155
return Err(OAuthError::ExpiredToken("Token has expired".to_string()));
156
156
}
157
-
let token_id = payload
157
+
let token_id_str = payload
158
158
.get("sid")
159
159
.and_then(|j| j.as_str())
160
-
.ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?
161
-
.to_string();
162
-
let did = payload
160
+
.ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?;
161
+
let token_id = TokenId::new(token_id_str);
162
+
let did_str = payload
163
163
.get("sub")
164
164
.and_then(|s| s.as_str())
165
-
.ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?
166
-
.to_string();
165
+
.ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?;
166
+
let did: Did = did_str
167
+
.parse()
168
+
.map_err(|_| OAuthError::InvalidToken("Invalid sub claim (not a valid DID)".to_string()))?;
167
169
let scope = payload
168
170
.get("scope")
169
171
.and_then(|s| s.as_str())
···
173
175
.and_then(|c| c.get("jkt"))
174
176
.and_then(|j| j.as_str())
175
177
.map(|s| s.to_string());
176
-
let client_id = payload
178
+
let client_id_str = payload
177
179
.get("client_id")
178
180
.and_then(|c| c.as_str())
179
-
.map(|s| s.to_string())
180
181
.unwrap_or_default();
182
+
let client_id = ClientId::new(client_id_str);
181
183
let controller_did = payload
182
184
.get("act")
183
185
.and_then(|a| a.get("sub"))
184
186
.and_then(|s| s.as_str())
185
-
.map(|s| s.to_string());
187
+
.map(|s| s.parse::<Did>())
188
+
.transpose()
189
+
.map_err(|_| {
190
+
OAuthError::InvalidToken("Invalid act.sub claim (not a valid DID)".to_string())
191
+
})?;
186
192
Ok(OAuthTokenInfo {
187
193
did,
188
194
token_id,
+272
crates/tranquil-pds/src/rate_limit/extractor.rs
+272
crates/tranquil-pds/src/rate_limit/extractor.rs
···
1
+
use std::marker::PhantomData;
2
+
3
+
use axum::{
4
+
extract::FromRequestParts,
5
+
http::request::Parts,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
9
+
use crate::api::error::ApiError;
10
+
use crate::oauth::OAuthError;
11
+
use crate::state::{AppState, RateLimitKind};
12
+
use crate::util::extract_client_ip;
13
+
14
+
pub trait RateLimitPolicy: Send + Sync + 'static {
15
+
const KIND: RateLimitKind;
16
+
}
17
+
18
+
pub struct LoginLimit;
19
+
impl RateLimitPolicy for LoginLimit {
20
+
const KIND: RateLimitKind = RateLimitKind::Login;
21
+
}
22
+
23
+
pub struct AccountCreationLimit;
24
+
impl RateLimitPolicy for AccountCreationLimit {
25
+
const KIND: RateLimitKind = RateLimitKind::AccountCreation;
26
+
}
27
+
28
+
pub struct PasswordResetLimit;
29
+
impl RateLimitPolicy for PasswordResetLimit {
30
+
const KIND: RateLimitKind = RateLimitKind::PasswordReset;
31
+
}
32
+
33
+
pub struct ResetPasswordLimit;
34
+
impl RateLimitPolicy for ResetPasswordLimit {
35
+
const KIND: RateLimitKind = RateLimitKind::ResetPassword;
36
+
}
37
+
38
+
pub struct RefreshSessionLimit;
39
+
impl RateLimitPolicy for RefreshSessionLimit {
40
+
const KIND: RateLimitKind = RateLimitKind::RefreshSession;
41
+
}
42
+
43
+
pub struct OAuthTokenLimit;
44
+
impl RateLimitPolicy for OAuthTokenLimit {
45
+
const KIND: RateLimitKind = RateLimitKind::OAuthToken;
46
+
}
47
+
48
+
pub struct OAuthAuthorizeLimit;
49
+
impl RateLimitPolicy for OAuthAuthorizeLimit {
50
+
const KIND: RateLimitKind = RateLimitKind::OAuthAuthorize;
51
+
}
52
+
53
+
pub struct OAuthParLimit;
54
+
impl RateLimitPolicy for OAuthParLimit {
55
+
const KIND: RateLimitKind = RateLimitKind::OAuthPar;
56
+
}
57
+
58
+
pub struct OAuthIntrospectLimit;
59
+
impl RateLimitPolicy for OAuthIntrospectLimit {
60
+
const KIND: RateLimitKind = RateLimitKind::OAuthIntrospect;
61
+
}
62
+
63
+
pub struct AppPasswordLimit;
64
+
impl RateLimitPolicy for AppPasswordLimit {
65
+
const KIND: RateLimitKind = RateLimitKind::AppPassword;
66
+
}
67
+
68
+
pub struct EmailUpdateLimit;
69
+
impl RateLimitPolicy for EmailUpdateLimit {
70
+
const KIND: RateLimitKind = RateLimitKind::EmailUpdate;
71
+
}
72
+
73
+
pub struct TotpVerifyLimit;
74
+
impl RateLimitPolicy for TotpVerifyLimit {
75
+
const KIND: RateLimitKind = RateLimitKind::TotpVerify;
76
+
}
77
+
78
+
pub struct HandleUpdateLimit;
79
+
impl RateLimitPolicy for HandleUpdateLimit {
80
+
const KIND: RateLimitKind = RateLimitKind::HandleUpdate;
81
+
}
82
+
83
+
pub struct HandleUpdateDailyLimit;
84
+
impl RateLimitPolicy for HandleUpdateDailyLimit {
85
+
const KIND: RateLimitKind = RateLimitKind::HandleUpdateDaily;
86
+
}
87
+
88
+
pub struct VerificationCheckLimit;
89
+
impl RateLimitPolicy for VerificationCheckLimit {
90
+
const KIND: RateLimitKind = RateLimitKind::VerificationCheck;
91
+
}
92
+
93
+
pub struct SsoInitiateLimit;
94
+
impl RateLimitPolicy for SsoInitiateLimit {
95
+
const KIND: RateLimitKind = RateLimitKind::SsoInitiate;
96
+
}
97
+
98
+
pub struct SsoCallbackLimit;
99
+
impl RateLimitPolicy for SsoCallbackLimit {
100
+
const KIND: RateLimitKind = RateLimitKind::SsoCallback;
101
+
}
102
+
103
+
pub struct SsoUnlinkLimit;
104
+
impl RateLimitPolicy for SsoUnlinkLimit {
105
+
const KIND: RateLimitKind = RateLimitKind::SsoUnlink;
106
+
}
107
+
108
+
pub struct OAuthRegisterCompleteLimit;
109
+
impl RateLimitPolicy for OAuthRegisterCompleteLimit {
110
+
const KIND: RateLimitKind = RateLimitKind::OAuthRegisterComplete;
111
+
}
112
+
113
+
pub trait RateLimitRejection: IntoResponse + Send + 'static {
114
+
fn new() -> Self;
115
+
}
116
+
117
+
pub struct ApiRateLimitRejection;
118
+
119
+
impl RateLimitRejection for ApiRateLimitRejection {
120
+
fn new() -> Self {
121
+
Self
122
+
}
123
+
}
124
+
125
+
impl IntoResponse for ApiRateLimitRejection {
126
+
fn into_response(self) -> Response {
127
+
ApiError::RateLimitExceeded(None).into_response()
128
+
}
129
+
}
130
+
131
+
pub struct OAuthRateLimitRejection;
132
+
133
+
impl RateLimitRejection for OAuthRateLimitRejection {
134
+
fn new() -> Self {
135
+
Self
136
+
}
137
+
}
138
+
139
+
impl IntoResponse for OAuthRateLimitRejection {
140
+
fn into_response(self) -> Response {
141
+
OAuthError::RateLimited.into_response()
142
+
}
143
+
}
144
+
145
+
impl From<OAuthRateLimitRejection> for OAuthError {
146
+
fn from(_: OAuthRateLimitRejection) -> Self {
147
+
OAuthError::RateLimited
148
+
}
149
+
}
150
+
151
+
pub struct RateLimitedInner<P: RateLimitPolicy, R: RateLimitRejection> {
152
+
client_ip: String,
153
+
_marker: PhantomData<(P, R)>,
154
+
}
155
+
156
+
impl<P: RateLimitPolicy, R: RateLimitRejection> RateLimitedInner<P, R> {
157
+
pub fn client_ip(&self) -> &str {
158
+
&self.client_ip
159
+
}
160
+
}
161
+
162
+
impl<P: RateLimitPolicy, R: RateLimitRejection> FromRequestParts<AppState>
163
+
for RateLimitedInner<P, R>
164
+
{
165
+
type Rejection = R;
166
+
167
+
async fn from_request_parts(
168
+
parts: &mut Parts,
169
+
state: &AppState,
170
+
) -> Result<Self, Self::Rejection> {
171
+
let client_ip = extract_client_ip(&parts.headers, None);
172
+
173
+
if !state.check_rate_limit(P::KIND, &client_ip).await {
174
+
tracing::warn!(
175
+
ip = %client_ip,
176
+
kind = ?P::KIND,
177
+
"Rate limit exceeded"
178
+
);
179
+
return Err(R::new());
180
+
}
181
+
182
+
Ok(RateLimitedInner {
183
+
client_ip,
184
+
_marker: PhantomData,
185
+
})
186
+
}
187
+
}
188
+
189
+
pub type RateLimited<P> = RateLimitedInner<P, ApiRateLimitRejection>;
190
+
pub type OAuthRateLimited<P> = RateLimitedInner<P, OAuthRateLimitRejection>;
191
+
192
+
#[derive(Debug)]
193
+
pub struct UserRateLimitError {
194
+
pub kind: RateLimitKind,
195
+
pub message: Option<String>,
196
+
}
197
+
198
+
impl UserRateLimitError {
199
+
pub fn new(kind: RateLimitKind) -> Self {
200
+
Self {
201
+
kind,
202
+
message: None,
203
+
}
204
+
}
205
+
206
+
pub fn with_message(kind: RateLimitKind, message: impl Into<String>) -> Self {
207
+
Self {
208
+
kind,
209
+
message: Some(message.into()),
210
+
}
211
+
}
212
+
}
213
+
214
+
impl std::fmt::Display for UserRateLimitError {
215
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216
+
match &self.message {
217
+
Some(msg) => write!(f, "{}", msg),
218
+
None => write!(f, "Rate limit exceeded for {:?}", self.kind),
219
+
}
220
+
}
221
+
}
222
+
223
+
impl std::error::Error for UserRateLimitError {}
224
+
225
+
impl IntoResponse for UserRateLimitError {
226
+
fn into_response(self) -> Response {
227
+
ApiError::RateLimitExceeded(self.message).into_response()
228
+
}
229
+
}
230
+
231
+
pub struct UserRateLimitProof<P: RateLimitPolicy> {
232
+
_marker: PhantomData<P>,
233
+
}
234
+
235
+
impl<P: RateLimitPolicy> UserRateLimitProof<P> {
236
+
fn new() -> Self {
237
+
Self {
238
+
_marker: PhantomData,
239
+
}
240
+
}
241
+
}
242
+
243
+
pub async fn check_user_rate_limit<P: RateLimitPolicy>(
244
+
state: &AppState,
245
+
user_key: &str,
246
+
) -> Result<UserRateLimitProof<P>, UserRateLimitError> {
247
+
if !state.check_rate_limit(P::KIND, user_key).await {
248
+
tracing::warn!(
249
+
key = %user_key,
250
+
kind = ?P::KIND,
251
+
"User rate limit exceeded"
252
+
);
253
+
return Err(UserRateLimitError::new(P::KIND));
254
+
}
255
+
Ok(UserRateLimitProof::new())
256
+
}
257
+
258
+
pub async fn check_user_rate_limit_with_message<P: RateLimitPolicy>(
259
+
state: &AppState,
260
+
user_key: &str,
261
+
error_message: impl Into<String>,
262
+
) -> Result<UserRateLimitProof<P>, UserRateLimitError> {
263
+
if !state.check_rate_limit(P::KIND, user_key).await {
264
+
tracing::warn!(
265
+
key = %user_key,
266
+
kind = ?P::KIND,
267
+
"User rate limit exceeded"
268
+
);
269
+
return Err(UserRateLimitError::with_message(P::KIND, error_message));
270
+
}
271
+
Ok(UserRateLimitProof::new())
272
+
}
+5
-102
crates/tranquil-pds/src/rate_limit.rs
crates/tranquil-pds/src/rate_limit/mod.rs
+5
-102
crates/tranquil-pds/src/rate_limit.rs
crates/tranquil-pds/src/rate_limit/mod.rs
···
1
-
use axum::{
2
-
Json,
3
-
body::Body,
4
-
extract::ConnectInfo,
5
-
http::{HeaderMap, Request, StatusCode},
6
-
middleware::Next,
7
-
response::{IntoResponse, Response},
8
-
};
1
+
mod extractor;
2
+
3
+
pub use extractor::*;
4
+
9
5
use governor::{
10
6
Quota, RateLimiter,
11
7
clock::DefaultClock,
12
8
state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13
9
};
14
-
use std::{net::SocketAddr, num::NonZeroU32, sync::Arc};
10
+
use std::{num::NonZeroU32, sync::Arc};
15
11
16
12
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
17
13
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
···
166
162
}
167
163
}
168
164
169
-
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
170
-
if let Some(forwarded) = headers.get("x-forwarded-for")
171
-
&& let Ok(value) = forwarded.to_str()
172
-
&& let Some(first_ip) = value.split(',').next()
173
-
{
174
-
return first_ip.trim().to_string();
175
-
}
176
-
177
-
if let Some(real_ip) = headers.get("x-real-ip")
178
-
&& let Ok(value) = real_ip.to_str()
179
-
{
180
-
return value.trim().to_string();
181
-
}
182
-
183
-
addr.map(|a| a.ip().to_string())
184
-
.unwrap_or_else(|| "unknown".to_string())
185
-
}
186
-
187
-
fn rate_limit_response() -> Response {
188
-
(
189
-
StatusCode::TOO_MANY_REQUESTS,
190
-
Json(serde_json::json!({
191
-
"error": "RateLimitExceeded",
192
-
"message": "Too many requests. Please try again later."
193
-
})),
194
-
)
195
-
.into_response()
196
-
}
197
-
198
-
pub async fn login_rate_limit(
199
-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
200
-
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
201
-
request: Request<Body>,
202
-
next: Next,
203
-
) -> Response {
204
-
let client_ip = extract_client_ip(request.headers(), Some(addr));
205
-
206
-
if limiters.login.check_key(&client_ip).is_err() {
207
-
tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
208
-
return rate_limit_response();
209
-
}
210
-
211
-
next.run(request).await
212
-
}
213
-
214
-
pub async fn oauth_token_rate_limit(
215
-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
216
-
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
217
-
request: Request<Body>,
218
-
next: Next,
219
-
) -> Response {
220
-
let client_ip = extract_client_ip(request.headers(), Some(addr));
221
-
222
-
if limiters.oauth_token.check_key(&client_ip).is_err() {
223
-
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
224
-
return rate_limit_response();
225
-
}
226
-
227
-
next.run(request).await
228
-
}
229
-
230
-
pub async fn password_reset_rate_limit(
231
-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
232
-
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
233
-
request: Request<Body>,
234
-
next: Next,
235
-
) -> Response {
236
-
let client_ip = extract_client_ip(request.headers(), Some(addr));
237
-
238
-
if limiters.password_reset.check_key(&client_ip).is_err() {
239
-
tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
240
-
return rate_limit_response();
241
-
}
242
-
243
-
next.run(request).await
244
-
}
245
-
246
-
pub async fn account_creation_rate_limit(
247
-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
248
-
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
249
-
request: Request<Body>,
250
-
next: Next,
251
-
) -> Response {
252
-
let client_ip = extract_client_ip(request.headers(), Some(addr));
253
-
254
-
if limiters.account_creation.check_key(&client_ip).is_err() {
255
-
tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
256
-
return rate_limit_response();
257
-
}
258
-
259
-
next.run(request).await
260
-
}
261
-
262
165
#[cfg(test)]
263
166
mod tests {
264
167
use super::*;
+15
-19
crates/tranquil-pds/src/scheduled.rs
+15
-19
crates/tranquil-pds/src/scheduled.rs
···
5
5
use std::str::FromStr;
6
6
use std::sync::Arc;
7
7
use std::time::Duration;
8
-
use tokio::sync::watch;
9
8
use tokio::time::interval;
9
+
use tokio_util::sync::CancellationToken;
10
10
use tracing::{debug, error, info, warn};
11
11
use tranquil_db_traits::{
12
-
BackupRepository, BlobRepository, BrokenGenesisCommit, RepoRepository, SsoRepository,
13
-
UserRepository,
12
+
BackupRepository, BlobRepository, BrokenGenesisCommit, RepoRepository, SequenceNumber,
13
+
SsoRepository, UserRepository,
14
14
};
15
15
use tranquil_types::{AtUri, CidLink, Did};
16
16
···
22
22
repo_repo: &dyn RepoRepository,
23
23
block_store: &PostgresBlockStore,
24
24
row: BrokenGenesisCommit,
25
-
) -> Result<(Did, i64), (i64, &'static str)> {
25
+
) -> Result<(Did, SequenceNumber), (SequenceNumber, &'static str)> {
26
26
let commit_cid_str = row.commit_cid.ok_or((row.seq, "missing commit_cid"))?;
27
27
let commit_cid = Cid::from_str(&commit_cid_str).map_err(|_| (row.seq, "invalid CID"))?;
28
28
let block = block_store
···
73
73
74
74
let (success, failed) = results.iter().fold((0, 0), |(s, f), r| match r {
75
75
Ok((did, seq)) => {
76
-
info!(seq = seq, did = %did, "Fixed genesis commit blocks_cids");
76
+
info!(seq = seq.as_i64(), did = %did, "Fixed genesis commit blocks_cids");
77
77
(s + 1, f)
78
78
}
79
79
Err((seq, reason)) => {
80
80
warn!(
81
-
seq = seq,
81
+
seq = seq.as_i64(),
82
82
reason = reason,
83
83
"Failed to process genesis commit"
84
84
);
···
314
314
record.collection.as_str(),
315
315
record.rkey.as_str(),
316
316
);
317
-
(record_uri, CidLink::new_unchecked(blob_ref.cid))
317
+
(record_uri, unsafe { CidLink::new_unchecked(blob_ref.cid) })
318
318
})
319
319
.collect::<Vec<_>>(),
320
320
)
···
392
392
blob_repo: Arc<dyn BlobRepository>,
393
393
blob_store: Arc<dyn BlobStorage>,
394
394
sso_repo: Arc<dyn SsoRepository>,
395
-
mut shutdown_rx: watch::Receiver<bool>,
395
+
shutdown: CancellationToken,
396
396
) {
397
397
let check_interval = Duration::from_secs(
398
398
std::env::var("SCHEDULED_DELETE_CHECK_INTERVAL_SECS")
···
411
411
412
412
loop {
413
413
tokio::select! {
414
-
_ = shutdown_rx.changed() => {
415
-
if *shutdown_rx.borrow() {
416
-
info!("Scheduled tasks service shutting down");
417
-
break;
418
-
}
414
+
_ = shutdown.cancelled() => {
415
+
info!("Scheduled tasks service shutting down");
416
+
break;
419
417
}
420
418
_ = ticker.tick() => {
421
419
if let Err(e) = process_scheduled_deletions(
···
538
536
backup_repo: Arc<dyn BackupRepository>,
539
537
block_store: PostgresBlockStore,
540
538
backup_storage: Arc<dyn BackupStorage>,
541
-
mut shutdown_rx: watch::Receiver<bool>,
539
+
shutdown: CancellationToken,
542
540
) {
543
541
let backup_interval = Duration::from_secs(backup_interval_secs());
544
542
···
553
551
554
552
loop {
555
553
tokio::select! {
556
-
_ = shutdown_rx.changed() => {
557
-
if *shutdown_rx.borrow() {
558
-
info!("Backup service shutting down");
559
-
break;
560
-
}
554
+
_ = shutdown.cancelled() => {
555
+
info!("Backup service shutting down");
556
+
break;
561
557
}
562
558
_ = ticker.tick() => {
563
559
if let Err(e) = process_scheduled_backups(
+2
-1
crates/tranquil-pds/src/sso/config.rs
+2
-1
crates/tranquil-pds/src/sso/config.rs
···
1
+
use crate::util::pds_hostname;
1
2
use std::sync::OnceLock;
2
3
use tranquil_db_traits::SsoProviderType;
3
4
···
50
51
};
51
52
52
53
if config.is_any_enabled() {
53
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_default();
54
+
let hostname = pds_hostname();
54
55
if hostname.is_empty() || hostname == "localhost" {
55
56
panic!(
56
57
"PDS_HOSTNAME must be set to a valid hostname when SSO is enabled. \
+98
-114
crates/tranquil-pds/src/sso/endpoints.rs
+98
-114
crates/tranquil-pds/src/sso/endpoints.rs
···
6
6
};
7
7
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
8
8
use serde::{Deserialize, Serialize};
9
-
use tranquil_db_traits::SsoProviderType;
9
+
use tranquil_db_traits::{SsoAction, SsoProviderType};
10
10
use tranquil_types::RequestId;
11
11
12
12
use super::config::SsoConfig;
13
13
use crate::api::error::ApiError;
14
14
use crate::auth::extractor::extract_bearer_token_from_header;
15
15
use crate::auth::{generate_app_password, validate_bearer_token_cached};
16
-
use crate::rate_limit::extract_client_ip;
17
-
use crate::state::{AppState, RateLimitKind};
16
+
use crate::rate_limit::{
17
+
AccountCreationLimit, RateLimited, SsoCallbackLimit, SsoInitiateLimit, SsoUnlinkLimit,
18
+
check_user_rate_limit_with_message,
19
+
};
20
+
use crate::state::AppState;
21
+
use crate::util::{pds_hostname, pds_hostname_without_port};
18
22
19
23
fn generate_state() -> String {
20
24
use rand::RngCore;
···
71
75
72
76
pub async fn sso_initiate(
73
77
State(state): State<AppState>,
78
+
_rate_limit: RateLimited<SsoInitiateLimit>,
74
79
headers: HeaderMap,
75
80
Json(input): Json<SsoInitiateRequest>,
76
81
) -> Result<Json<SsoInitiateResponse>, ApiError> {
77
-
let client_ip = extract_client_ip(&headers, None);
78
-
if !state
79
-
.check_rate_limit(RateLimitKind::SsoInitiate, &client_ip)
80
-
.await
81
-
{
82
-
tracing::warn!(ip = %client_ip, "SSO initiate rate limit exceeded");
83
-
return Err(ApiError::RateLimitExceeded(None));
84
-
}
85
-
86
82
if input.provider.len() > 20 {
87
83
return Err(ApiError::SsoProviderNotFound);
88
84
}
···
105
101
.get_provider(provider_type)
106
102
.ok_or(ApiError::SsoProviderNotEnabled)?;
107
103
108
-
let action = input.action.as_deref().unwrap_or("login");
109
-
if !["login", "link", "register"].contains(&action) {
110
-
return Err(ApiError::SsoInvalidAction);
111
-
}
104
+
let action = input
105
+
.action
106
+
.as_deref()
107
+
.map(SsoAction::parse)
108
+
.unwrap_or(Some(SsoAction::Login))
109
+
.ok_or(ApiError::SsoInvalidAction)?;
112
110
113
-
let is_standalone = action == "register" && input.request_uri.is_none();
111
+
let is_standalone = action == SsoAction::Register && input.request_uri.is_none();
114
112
let request_uri = input
115
113
.request_uri
116
114
.clone()
117
115
.unwrap_or_else(|| "standalone".to_string());
118
116
119
117
let auth_did = match action {
120
-
"link" => {
118
+
SsoAction::Link => {
121
119
let auth_header = headers
122
120
.get(axum::http::header::AUTHORIZATION)
123
121
.and_then(|v| v.to_str().ok());
···
132
130
.map_err(|_| ApiError::SsoNotAuthenticated)?;
133
131
Some(auth_user.did)
134
132
}
135
-
"register" if is_standalone => None,
133
+
SsoAction::Register if is_standalone => None,
136
134
_ => {
137
135
let request_id = RequestId::new(request_uri.clone());
138
136
let _request_data = state
···
217
215
218
216
pub async fn sso_callback(
219
217
State(state): State<AppState>,
220
-
headers: HeaderMap,
218
+
_rate_limit: RateLimited<SsoCallbackLimit>,
221
219
Query(query): Query<SsoCallbackQuery>,
222
220
) -> Response {
221
+
sso_callback_internal(&state, query).await
222
+
}
223
+
224
+
async fn sso_callback_internal(state: &AppState, query: SsoCallbackQuery) -> Response {
223
225
tracing::debug!(
224
226
has_code = query.code.is_some(),
225
227
has_state = query.state.is_some(),
···
227
229
"SSO callback received"
228
230
);
229
231
230
-
let client_ip = extract_client_ip(&headers, None);
231
-
if !state
232
-
.check_rate_limit(RateLimitKind::SsoCallback, &client_ip)
233
-
.await
234
-
{
235
-
tracing::warn!(ip = %client_ip, "SSO callback rate limit exceeded");
236
-
return redirect_to_error("Too many requests. Please try again later.");
237
-
}
238
-
239
232
if let Some(ref error) = query.error {
240
233
tracing::warn!(
241
234
error = %error,
···
326
319
}
327
320
};
328
321
329
-
match auth_state.action.as_str() {
330
-
"login" => {
322
+
match auth_state.action {
323
+
SsoAction::Login => {
331
324
handle_sso_login(
332
-
&state,
325
+
state,
333
326
&auth_state.request_uri,
334
327
auth_state.provider,
335
328
&user_info,
336
329
)
337
330
.await
338
331
}
339
-
"link" => {
332
+
SsoAction::Link => {
340
333
let did = match auth_state.did {
341
334
Some(d) => d,
342
335
None => return redirect_to_error("Not authenticated"),
343
336
};
344
-
handle_sso_link(&state, did, auth_state.provider, &user_info).await
337
+
handle_sso_link(state, did, auth_state.provider, &user_info).await
345
338
}
346
-
"register" => {
339
+
SsoAction::Register => {
347
340
handle_sso_register(
348
-
&state,
341
+
state,
349
342
&auth_state.request_uri,
350
343
auth_state.provider,
351
344
&user_info,
352
345
)
353
346
.await
354
347
}
355
-
_ => redirect_to_error("Unknown SSO action"),
356
348
}
357
349
}
358
350
359
351
pub async fn sso_callback_post(
360
352
State(state): State<AppState>,
361
-
headers: HeaderMap,
353
+
_rate_limit: RateLimited<SsoCallbackLimit>,
362
354
Form(form): Form<SsoCallbackForm>,
363
355
) -> Response {
364
356
tracing::debug!(
···
376
368
error_description: form.error_description,
377
369
};
378
370
379
-
sso_callback(State(state), headers, Query(query)).await
371
+
sso_callback_internal(&state, query).await
380
372
}
381
373
382
374
fn generate_registration_token() -> String {
···
429
421
};
430
422
431
423
let is_verified = match state.user_repo.get_session_info_by_did(&identity.did).await {
432
-
Ok(Some(info)) => {
433
-
info.email_verified
434
-
|| info.discord_verified
435
-
|| info.telegram_verified
436
-
|| info.signal_verified
437
-
}
424
+
Ok(Some(info)) => info.channel_verification.has_any_verified(),
438
425
Ok(None) => {
439
426
tracing::error!("User not found for SSO login: {}", identity.did);
440
427
return redirect_to_error("Account not found");
···
486
473
"SSO login successful"
487
474
);
488
475
489
-
let has_totp = match state.user_repo.get_totp_record(&identity.did).await {
490
-
Ok(Some(record)) => record.verified,
491
-
_ => false,
492
-
};
476
+
let has_totp = matches!(
477
+
state.user_repo.get_totp_record_state(&identity.did).await,
478
+
Ok(Some(tranquil_db_traits::TotpRecordState::Verified(_)))
479
+
);
493
480
494
481
if has_totp {
495
482
return Redirect::to(&format!(
···
657
644
id: id.id.to_string(),
658
645
provider: id.provider.as_str().to_string(),
659
646
provider_name: id.provider.display_name().to_string(),
660
-
provider_username: id.provider_username,
661
-
provider_email: id.provider_email,
647
+
provider_username: id.provider_username.map(|u| u.into_inner()),
648
+
provider_email: id.provider_email.map(|e| e.into_inner()),
662
649
created_at: id.created_at.to_rfc3339(),
663
650
last_login_at: id.last_login_at.map(|t| t.to_rfc3339()),
664
651
})
···
682
669
auth: crate::auth::Auth<crate::auth::Active>,
683
670
Json(input): Json<UnlinkAccountRequest>,
684
671
) -> Result<Json<UnlinkAccountResponse>, ApiError> {
685
-
if !state
686
-
.check_rate_limit(RateLimitKind::SsoUnlink, auth.did.as_str())
687
-
.await
688
-
{
689
-
tracing::warn!(did = %auth.did, "SSO unlink rate limit exceeded");
690
-
return Err(ApiError::RateLimitExceeded(None));
691
-
}
672
+
let _rate_limit = check_user_rate_limit_with_message::<SsoUnlinkLimit>(
673
+
&state,
674
+
auth.did.as_str(),
675
+
"Too many unlink attempts. Please try again later.",
676
+
)
677
+
.await?;
692
678
693
679
let id = uuid::Uuid::parse_str(&input.id).map_err(|_| ApiError::InvalidId)?;
694
680
···
746
732
747
733
pub async fn get_pending_registration(
748
734
State(state): State<AppState>,
749
-
headers: HeaderMap,
735
+
_rate_limit: RateLimited<SsoCallbackLimit>,
750
736
Query(query): Query<PendingRegistrationQuery>,
751
737
) -> Result<Json<PendingRegistrationResponse>, ApiError> {
752
-
let client_ip = extract_client_ip(&headers, None);
753
-
if !state
754
-
.check_rate_limit(RateLimitKind::SsoCallback, &client_ip)
755
-
.await
756
-
{
757
-
tracing::warn!(ip = %client_ip, "SSO pending registration rate limit exceeded");
758
-
return Err(ApiError::RateLimitExceeded(None));
759
-
}
760
-
761
738
if query.token.len() > 100 {
762
739
return Err(ApiError::InvalidRequest("Invalid token".into()));
763
740
}
···
771
748
Ok(Json(PendingRegistrationResponse {
772
749
request_uri: pending.request_uri,
773
750
provider: pending.provider.as_str().to_string(),
774
-
provider_user_id: pending.provider_user_id,
775
-
provider_username: pending.provider_username,
776
-
provider_email: pending.provider_email,
751
+
provider_user_id: pending.provider_user_id.into_inner(),
752
+
provider_username: pending.provider_username.map(|u| u.into_inner()),
753
+
provider_email: pending.provider_email.map(|e| e.into_inner()),
777
754
provider_email_verified: pending.provider_email_verified,
778
755
}))
779
756
}
···
810
787
}
811
788
};
812
789
813
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
814
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
790
+
let hostname_for_handles = pds_hostname_without_port();
815
791
let full_handle = format!("{}.{}", validated, hostname_for_handles);
816
-
let handle_typed = crate::types::Handle::new_unchecked(&full_handle);
792
+
let handle_typed = unsafe { crate::types::Handle::new_unchecked(&full_handle) };
817
793
818
794
let db_available = state
819
795
.user_repo
···
866
842
867
843
pub async fn complete_registration(
868
844
State(state): State<AppState>,
869
-
headers: HeaderMap,
845
+
rate_limit: RateLimited<AccountCreationLimit>,
870
846
Json(input): Json<CompleteRegistrationInput>,
871
847
) -> Result<Json<CompleteRegistrationResponse>, ApiError> {
848
+
let client_ip = rate_limit.client_ip();
872
849
use jacquard_common::types::{integer::LimitedU32, string::Tid};
873
850
use jacquard_repo::{mst::Mst, storage::BlockStore};
874
851
use k256::ecdsa::SigningKey;
···
876
853
use serde_json::json;
877
854
use std::sync::Arc;
878
855
879
-
let client_ip = extract_client_ip(&headers, None);
880
-
if !state
881
-
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
882
-
.await
883
-
{
884
-
tracing::warn!(ip = %client_ip, "SSO registration rate limit exceeded");
885
-
return Err(ApiError::RateLimitExceeded(None));
886
-
}
887
-
888
856
if input.token.len() > 100 {
889
857
return Err(ApiError::InvalidRequest("Invalid token".into()));
890
858
}
···
899
867
.await?
900
868
.ok_or(ApiError::SsoSessionExpired)?;
901
869
902
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
903
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
870
+
let hostname = pds_hostname();
871
+
let hostname_for_handles = pds_hostname_without_port();
904
872
905
873
let handle = match crate::api::validation::validate_short_handle(&input.handle) {
906
874
Ok(h) => format!("{}.{}", h, hostname_for_handles),
···
913
881
let email = input
914
882
.email
915
883
.clone()
916
-
.or_else(|| pending_preview.provider_email.clone())
884
+
.or_else(|| {
885
+
pending_preview
886
+
.provider_email
887
+
.clone()
888
+
.map(|e| e.into_inner())
889
+
})
917
890
.map(|e| e.trim().to_string())
918
891
.filter(|e| !e.is_empty());
919
892
match email {
···
939
912
let email = input
940
913
.email
941
914
.clone()
942
-
.or_else(|| pending_preview.provider_email.clone())
915
+
.or_else(|| {
916
+
pending_preview
917
+
.provider_email
918
+
.clone()
919
+
.map(|e| e.into_inner())
920
+
})
943
921
.map(|e| e.trim().to_string())
944
922
.filter(|e| !e.is_empty());
945
923
···
956
934
None => None,
957
935
};
958
936
959
-
if let Some(ref code) = input.invite_code {
960
-
let valid = state
961
-
.infra_repo
962
-
.is_invite_code_valid(code)
963
-
.await
964
-
.unwrap_or(false);
965
-
if !valid {
966
-
return Err(ApiError::InvalidInviteCode);
937
+
let _validated_invite_code = if let Some(ref code) = input.invite_code {
938
+
match state.infra_repo.validate_invite_code(code).await {
939
+
Ok(validated) => Some(validated),
940
+
Err(_) => return Err(ApiError::InvalidInviteCode),
967
941
}
968
942
} else {
969
943
let invite_required = std::env::var("INVITE_CODE_REQUIRED")
···
972
946
if invite_required {
973
947
return Err(ApiError::InviteCodeRequired);
974
948
}
975
-
}
949
+
None
950
+
};
976
951
977
-
let handle_typed = crate::types::Handle::new_unchecked(&handle);
952
+
let handle_typed = unsafe { crate::types::Handle::new_unchecked(&handle) };
978
953
let reserved = state
979
954
.user_repo
980
-
.reserve_handle(&handle_typed, &client_ip)
955
+
.reserve_handle(&handle_typed, client_ip)
981
956
.await
982
957
.unwrap_or(false);
983
958
···
1076
1051
};
1077
1052
1078
1053
let rev = Tid::now(LimitedU32::MIN);
1079
-
let did_typed = crate::types::Did::new_unchecked(&did);
1054
+
let did_typed = unsafe { crate::types::Did::new_unchecked(&did) };
1080
1055
let (commit_bytes, _sig) = match crate::api::repo::record::utils::create_signed_commit(
1081
1056
&did_typed,
1082
1057
mst_root,
···
1144
1119
invite_code: input.invite_code.clone(),
1145
1120
birthdate_pref,
1146
1121
sso_provider: pending_preview.provider,
1147
-
sso_provider_user_id: pending_preview.provider_user_id.clone(),
1148
-
sso_provider_username: pending_preview.provider_username.clone(),
1149
-
sso_provider_email: pending_preview.provider_email.clone(),
1122
+
sso_provider_user_id: pending_preview.provider_user_id.clone().into_inner(),
1123
+
sso_provider_username: pending_preview
1124
+
.provider_username
1125
+
.clone()
1126
+
.map(|u| u.into_inner()),
1127
+
sso_provider_email: pending_preview
1128
+
.provider_email
1129
+
.clone()
1130
+
.map(|e| e.into_inner()),
1150
1131
sso_provider_email_verified: pending_preview.provider_email_verified,
1151
1132
pending_registration_token: input.token.clone(),
1152
1133
};
···
1179
1160
{
1180
1161
tracing::warn!("Failed to sequence identity event for {}: {}", did, e);
1181
1162
}
1182
-
if let Err(e) =
1183
-
crate::api::repo::record::sequence_account_event(&state, &did_typed, true, None).await
1163
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
1164
+
&state,
1165
+
&did_typed,
1166
+
tranquil_db_traits::AccountStatus::Active,
1167
+
)
1168
+
.await
1184
1169
{
1185
1170
tracing::warn!("Failed to sequence account event for {}: {}", did, e);
1186
1171
}
···
1189
1174
"$type": "app.bsky.actor.profile",
1190
1175
"displayName": handle_typed.as_str()
1191
1176
});
1192
-
let profile_collection = crate::types::Nsid::new_unchecked("app.bsky.actor.profile");
1193
-
let profile_rkey = crate::types::Rkey::new_unchecked("self");
1177
+
let profile_collection = unsafe { crate::types::Nsid::new_unchecked("app.bsky.actor.profile") };
1178
+
let profile_rkey = unsafe { crate::types::Rkey::new_unchecked("self") };
1194
1179
if let Err(e) = crate::api::repo::record::create_record_internal(
1195
1180
&state,
1196
1181
&did_typed,
···
1217
1202
user_id: create_result.user_id,
1218
1203
name: app_password_name.clone(),
1219
1204
password_hash: app_password_hash,
1220
-
privileged: false,
1205
+
privilege: tranquil_db_traits::AppPasswordPrivilege::Standard,
1221
1206
scopes: None,
1222
1207
created_by_controller_did: None,
1223
1208
};
···
1260
1245
1261
1246
let channel_auto_verified = verification_channel == "email"
1262
1247
&& pending_preview.provider_email_verified
1263
-
&& pending_preview.provider_email.as_ref() == email.as_ref();
1248
+
&& pending_preview.provider_email.as_ref().map(|e| e.as_str()) == email.as_deref();
1264
1249
1265
1250
if channel_auto_verified {
1266
1251
let _ = state
···
1304
1289
refresh_jti: refresh_meta.jti.clone(),
1305
1290
access_expires_at: access_meta.expires_at,
1306
1291
refresh_expires_at: refresh_meta.expires_at,
1307
-
legacy_login: false,
1292
+
login_type: tranquil_db_traits::LoginType::Modern,
1308
1293
mfa_verified: false,
1309
1294
scope: None,
1310
1295
controller_did: None,
···
1315
1300
return Err(ApiError::InternalError(None));
1316
1301
}
1317
1302
1318
-
let hostname =
1319
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
1303
+
let hostname = pds_hostname();
1320
1304
if let Err(e) = crate::comms::comms_repo::enqueue_welcome(
1321
1305
state.user_repo.as_ref(),
1322
1306
state.infra_repo.as_ref(),
1323
1307
user_id.unwrap_or(uuid::Uuid::nil()),
1324
-
&hostname,
1308
+
hostname,
1325
1309
)
1326
1310
.await
1327
1311
{
···
1367
1351
verification_channel,
1368
1352
&verification_recipient,
1369
1353
&formatted_token,
1370
-
&hostname,
1354
+
hostname,
1371
1355
)
1372
1356
.await
1373
1357
{
+15
-3
crates/tranquil-pds/src/state.rs
+15
-3
crates/tranquil-pds/src/state.rs
···
1
1
use crate::appview::DidResolver;
2
+
use crate::auth::webauthn::WebAuthnConfig;
2
3
use crate::cache::{Cache, DistributedRateLimiter, create_cache};
3
4
use crate::circuit_breaker::CircuitBreakers;
4
5
use crate::config::AuthConfig;
···
7
8
use crate::sso::{SsoConfig, SsoManager};
8
9
use crate::storage::{BackupStorage, BlobStorage, create_backup_storage, create_blob_storage};
9
10
use crate::sync::firehose::SequencedEvent;
11
+
use crate::util::pds_hostname;
10
12
use sqlx::PgPool;
11
13
use std::error::Error;
12
14
use std::sync::Arc;
13
15
use tokio::sync::broadcast;
16
+
use tokio_util::sync::CancellationToken;
14
17
use tranquil_db::{
15
18
BacklinkRepository, BackupRepository, BlobRepository, DelegationRepository, InfraRepository,
16
19
OAuthRepository, PostgresRepositories, RepoEventNotifier, RepoRepository, SessionRepository,
···
41
44
pub did_resolver: Arc<DidResolver>,
42
45
pub sso_repo: Arc<dyn SsoRepository>,
43
46
pub sso_manager: SsoManager,
47
+
pub webauthn_config: Arc<WebAuthnConfig>,
48
+
pub shutdown: CancellationToken,
44
49
}
45
50
51
+
#[derive(Debug, Clone, Copy)]
46
52
pub enum RateLimitKind {
47
53
Login,
48
54
AccountCreation,
···
116
122
}
117
123
118
124
impl AppState {
119
-
pub async fn new() -> Result<Self, Box<dyn Error>> {
125
+
pub async fn new(shutdown: CancellationToken) -> Result<Self, Box<dyn Error>> {
120
126
let database_url = std::env::var("DATABASE_URL")
121
127
.map_err(|_| "DATABASE_URL environment variable must be set")?;
122
128
···
157
163
.await
158
164
.map_err(|e| format!("Failed to run migrations: {}", e))?;
159
165
160
-
Ok(Self::from_db(db).await)
166
+
Ok(Self::from_db(db, shutdown).await)
161
167
}
162
168
163
-
pub async fn from_db(db: PgPool) -> Self {
169
+
pub async fn from_db(db: PgPool, shutdown: CancellationToken) -> Self {
164
170
AuthConfig::init();
165
171
166
172
let repos = Arc::new(PostgresRepositories::new(db.clone()));
···
180
186
let did_resolver = Arc::new(DidResolver::new());
181
187
let sso_config = SsoConfig::init();
182
188
let sso_manager = SsoManager::from_config(sso_config);
189
+
let webauthn_config = Arc::new(
190
+
WebAuthnConfig::new(pds_hostname())
191
+
.expect("Failed to create WebAuthn config at startup"),
192
+
);
183
193
184
194
Self {
185
195
user_repo: repos.user.clone(),
···
204
214
distributed_rate_limiter,
205
215
did_resolver,
206
216
sso_manager,
217
+
webauthn_config,
218
+
shutdown,
207
219
}
208
220
}
209
221
+4
-3
crates/tranquil-pds/src/sync/commit.rs
+4
-3
crates/tranquil-pds/src/sync/commit.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::state::AppState;
3
-
use crate::sync::util::{AccountStatus, assert_repo_availability, get_account_with_status};
3
+
use crate::sync::util::{assert_repo_availability, get_account_with_status};
4
4
use axum::{
5
5
Json,
6
6
extract::{Query, State},
···
13
13
use serde::{Deserialize, Serialize};
14
14
use std::str::FromStr;
15
15
use tracing::error;
16
+
use tranquil_db_traits::AccountStatus;
16
17
use tranquil_types::Did;
17
18
18
19
async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> {
···
130
131
head: cid_str,
131
132
rev,
132
133
active: status.is_active(),
133
-
status: status.as_str().map(String::from),
134
+
status: status.for_firehose().map(String::from),
134
135
});
135
136
}
136
137
let next_cursor = if has_more {
···
212
213
Json(GetRepoStatusOutput {
213
214
did: account.did,
214
215
active: account.status.is_active(),
215
-
status: account.status.as_str().map(String::from),
216
+
status: account.status.for_firehose().map(String::from),
216
217
rev,
217
218
}),
218
219
)
+5
-4
crates/tranquil-pds/src/sync/deprecated.rs
+5
-4
crates/tranquil-pds/src/sync/deprecated.rs
···
19
19
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
20
20
21
21
async fn check_admin_or_self(state: &AppState, headers: &HeaderMap, did: &str) -> bool {
22
-
let extracted = match crate::auth::extract_auth_token_from_header(
23
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
24
-
) {
22
+
let extracted = match crate::auth::extract_auth_token_from_header(crate::util::get_header_str(
23
+
headers,
24
+
"Authorization",
25
+
)) {
25
26
Some(t) => t,
26
27
None => return false,
27
28
};
28
-
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
29
+
let dpop_proof = crate::util::get_header_str(headers, "DPoP");
29
30
let http_uri = "/";
30
31
match crate::auth::validate_token_with_dpop(
31
32
state.user_repo.as_ref(),
+10
-7
crates/tranquil-pds/src/sync/frame.rs
+10
-7
crates/tranquil-pds/src/sync/frame.rs
···
2
2
use cid::Cid;
3
3
use serde::{Deserialize, Serialize};
4
4
use std::str::FromStr;
5
+
use tranquil_scopes::RepoAction;
5
6
6
7
#[derive(Debug, Serialize, Deserialize)]
7
8
pub struct FrameHeader {
···
38
39
39
40
#[derive(Debug, Serialize, Deserialize)]
40
41
pub struct RepoOp {
41
-
pub action: String,
42
+
pub action: RepoAction,
42
43
pub path: String,
43
44
pub cid: Option<Cid>,
44
45
#[serde(skip_serializing_if = "Option::is_none")]
···
159
160
serde_json::from_value(self.ops_json).unwrap_or_else(|_| vec![]);
160
161
let ops: Vec<RepoOp> = json_ops
161
162
.into_iter()
162
-
.map(|op| RepoOp {
163
-
action: op.action,
164
-
path: op.path,
165
-
cid: op.cid.and_then(|s| Cid::from_str(&s).ok()),
166
-
prev: op.prev.and_then(|s| Cid::from_str(&s).ok()),
163
+
.filter_map(|op| {
164
+
Some(RepoOp {
165
+
action: RepoAction::parse_str(&op.action)?,
166
+
path: op.path,
167
+
cid: op.cid.and_then(|s| Cid::from_str(&s).ok()),
168
+
prev: op.prev.and_then(|s| Cid::from_str(&s).ok()),
169
+
})
167
170
})
168
171
.collect();
169
172
let rev = self.rev.unwrap_or_else(placeholder_rev);
···
202
205
CommitFrameError::InvalidCommitCid("Missing commit_cid in event".to_string())
203
206
})?;
204
207
let builder = CommitFrameBuilder::new(
205
-
event.seq,
208
+
event.seq.as_i64(),
206
209
event.did.to_string(),
207
210
commit_cid.as_str(),
208
211
event.prev_cid.as_ref().map(|c| c.as_str()),
+21
-10
crates/tranquil-pds/src/sync/listener.rs
+21
-10
crates/tranquil-pds/src/sync/listener.rs
···
2
2
use crate::sync::firehose::SequencedEvent;
3
3
use std::sync::atomic::{AtomicI64, Ordering};
4
4
use tracing::{debug, error, info, warn};
5
+
use tranquil_db_traits::SequenceNumber;
5
6
6
7
static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0);
7
8
8
9
pub async fn start_sequencer_listener(state: AppState) {
9
-
let initial_seq = state.repo_repo.get_max_seq().await.unwrap_or(0);
10
-
LAST_BROADCAST_SEQ.store(initial_seq, Ordering::SeqCst);
11
-
info!(initial_seq = initial_seq, "Initialized sequencer listener");
10
+
let initial_seq = state
11
+
.repo_repo
12
+
.get_max_seq()
13
+
.await
14
+
.unwrap_or(SequenceNumber::ZERO);
15
+
LAST_BROADCAST_SEQ.store(initial_seq.as_i64(), Ordering::SeqCst);
16
+
info!(
17
+
initial_seq = initial_seq.as_i64(),
18
+
"Initialized sequencer listener"
19
+
);
12
20
tokio::spawn(async move {
13
21
info!("Starting sequencer listener background task");
14
22
loop {
···
27
35
.await
28
36
.map_err(|e| anyhow::anyhow!("Failed to subscribe to events: {:?}", e))?;
29
37
info!("Connected to database and listening for repo updates");
30
-
let catchup_start = LAST_BROADCAST_SEQ.load(Ordering::SeqCst);
38
+
let catchup_start = SequenceNumber::from_raw(LAST_BROADCAST_SEQ.load(Ordering::SeqCst));
31
39
let events = state
32
40
.repo_repo
33
41
.get_events_since_seq(catchup_start, None)
···
36
44
if !events.is_empty() {
37
45
info!(
38
46
count = events.len(),
39
-
from_seq = catchup_start,
47
+
from_seq = catchup_start.as_i64(),
40
48
"Broadcasting catch-up events"
41
49
);
42
50
events.into_iter().for_each(|event| {
43
51
let seq = event.seq;
44
52
let firehose_event = to_firehose_event(event);
45
53
let _ = state.firehose_tx.send(firehose_event);
46
-
LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst);
54
+
LAST_BROADCAST_SEQ.store(seq.as_i64(), Ordering::SeqCst);
47
55
});
48
56
}
49
57
loop {
···
63
71
if seq_id > last_seq + 1 {
64
72
let gap_events = state
65
73
.repo_repo
66
-
.get_events_in_seq_range(last_seq, seq_id)
74
+
.get_events_in_seq_range(
75
+
SequenceNumber::from_raw(last_seq),
76
+
SequenceNumber::from_raw(seq_id),
77
+
)
67
78
.await
68
79
.unwrap_or_default();
69
80
if !gap_events.is_empty() {
···
72
83
let seq = event.seq;
73
84
let firehose_event = to_firehose_event(event);
74
85
let _ = state.firehose_tx.send(firehose_event);
75
-
LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst);
86
+
LAST_BROADCAST_SEQ.store(seq.as_i64(), Ordering::SeqCst);
76
87
});
77
88
}
78
89
}
79
90
let event = state
80
91
.repo_repo
81
-
.get_event_by_seq(seq_id)
92
+
.get_event_by_seq(SequenceNumber::from_raw(seq_id))
82
93
.await
83
94
.ok()
84
95
.flatten();
···
97
108
warn!(seq = seq_id, error = %e, "Failed to broadcast event (no receivers?)");
98
109
}
99
110
}
100
-
LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst);
111
+
LAST_BROADCAST_SEQ.store(seq.as_i64(), Ordering::SeqCst);
101
112
} else {
102
113
warn!(
103
114
seq = seq_id,
+2
-2
crates/tranquil-pds/src/sync/mod.rs
+2
-2
crates/tranquil-pds/src/sync/mod.rs
···
18
18
pub use deprecated::{get_checkout, get_head};
19
19
pub use repo::{get_blocks, get_record, get_repo};
20
20
pub use subscribe_repos::subscribe_repos;
21
+
pub use tranquil_db_traits::AccountStatus;
21
22
pub use util::{
22
-
AccountStatus, RepoAccount, RepoAvailabilityError, assert_repo_availability,
23
-
get_account_with_status,
23
+
RepoAccount, RepoAvailabilityError, assert_repo_availability, get_account_with_status,
24
24
};
25
25
pub use verify::{CarVerifier, VerifiedCar, VerifyError};
+12
-6
crates/tranquil-pds/src/sync/subscribe_repos.rs
+12
-6
crates/tranquil-pds/src/sync/subscribe_repos.rs
···
13
13
use std::sync::atomic::{AtomicUsize, Ordering};
14
14
use tokio::sync::broadcast::error::RecvError;
15
15
use tracing::{error, info, warn};
16
+
use tranquil_db_traits::SequenceNumber;
16
17
17
18
const BACKFILL_BATCH_SIZE: i64 = 1000;
18
19
···
69
70
params: SubscribeReposParams,
70
71
) -> Result<(), ()> {
71
72
let mut rx = state.firehose_tx.subscribe();
72
-
let mut last_seen: i64 = -1;
73
+
let mut last_seen = SequenceNumber::UNSET;
73
74
74
75
if let Some(cursor) = params.cursor {
75
-
let current_seq = state.repo_repo.get_max_seq().await.unwrap_or(0);
76
+
let cursor_seq = SequenceNumber::from_raw(cursor);
77
+
let current_seq = state
78
+
.repo_repo
79
+
.get_max_seq()
80
+
.await
81
+
.unwrap_or(SequenceNumber::ZERO);
76
82
77
-
if cursor > current_seq {
83
+
if cursor_seq > current_seq {
78
84
if let Ok(error_bytes) =
79
85
format_error_frame("FutureCursor", Some("Cursor in the future."))
80
86
{
···
88
94
89
95
let first_event = state
90
96
.repo_repo
91
-
.get_events_since_cursor(cursor, 1)
97
+
.get_events_since_cursor(cursor_seq, 1)
92
98
.await
93
99
.ok()
94
100
.and_then(|events| events.into_iter().next());
95
101
96
-
let mut current_cursor = cursor;
102
+
let mut current_cursor = cursor_seq;
97
103
98
104
if let Some(ref event) = first_event
99
105
&& event.created_at < backfill_time
···
113
119
.flatten();
114
120
115
121
if let Some(earliest_seq) = earliest {
116
-
current_cursor = earliest_seq - 1;
122
+
current_cursor = SequenceNumber::from_raw(earliest_seq.as_i64() - 1);
117
123
}
118
124
}
119
125
+18
-102
crates/tranquil-pds/src/sync/util.rs
+18
-102
crates/tranquil-pds/src/sync/util.rs
···
11
11
use iroh_car::{CarHeader, CarWriter};
12
12
use jacquard_repo::commit::Commit;
13
13
use jacquard_repo::storage::BlockStore;
14
-
use serde::Serialize;
15
14
use std::collections::{BTreeMap, HashMap};
16
15
use std::io::Cursor;
17
16
use std::str::FromStr;
18
17
use tokio::io::AsyncWriteExt;
19
-
use tranquil_db_traits::RepoRepository;
18
+
use tranquil_db_traits::{AccountStatus, RepoEventType, RepoRepository};
20
19
use tranquil_types::Did;
21
20
22
-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
23
-
#[serde(rename_all = "lowercase")]
24
-
pub enum AccountStatus {
25
-
Active,
26
-
Takendown,
27
-
Suspended,
28
-
Deactivated,
29
-
Deleted,
30
-
}
31
-
32
-
impl AccountStatus {
33
-
pub fn as_str(&self) -> Option<&'static str> {
34
-
match self {
35
-
Self::Active => None,
36
-
Self::Takendown => Some("takendown"),
37
-
Self::Suspended => Some("suspended"),
38
-
Self::Deactivated => Some("deactivated"),
39
-
Self::Deleted => Some("deleted"),
40
-
}
41
-
}
42
-
43
-
pub fn is_active(&self) -> bool {
44
-
matches!(self, Self::Active)
45
-
}
46
-
47
-
pub fn is_takendown(&self) -> bool {
48
-
matches!(self, Self::Takendown)
49
-
}
50
-
51
-
pub fn is_suspended(&self) -> bool {
52
-
matches!(self, Self::Suspended)
53
-
}
54
-
55
-
pub fn is_deactivated(&self) -> bool {
56
-
matches!(self, Self::Deactivated)
57
-
}
58
-
59
-
pub fn is_deleted(&self) -> bool {
60
-
matches!(self, Self::Deleted)
61
-
}
62
-
63
-
pub fn allows_read(&self) -> bool {
64
-
matches!(self, Self::Active | Self::Deactivated)
65
-
}
66
-
67
-
pub fn allows_write(&self) -> bool {
68
-
matches!(self, Self::Active)
69
-
}
70
-
71
-
pub fn from_db_fields(
72
-
takedown_ref: Option<&str>,
73
-
deactivated_at: Option<chrono::DateTime<chrono::Utc>>,
74
-
) -> Self {
75
-
if takedown_ref.is_some() {
76
-
Self::Takendown
77
-
} else if deactivated_at.is_some() {
78
-
Self::Deactivated
79
-
} else {
80
-
Self::Active
81
-
}
82
-
}
83
-
}
84
-
85
-
impl From<crate::types::AccountState> for AccountStatus {
86
-
fn from(state: crate::types::AccountState) -> Self {
87
-
match state {
88
-
crate::types::AccountState::Active => AccountStatus::Active,
89
-
crate::types::AccountState::Deactivated { .. } => AccountStatus::Deactivated,
90
-
crate::types::AccountState::TakenDown { .. } => AccountStatus::Takendown,
91
-
crate::types::AccountState::Migrated { .. } => AccountStatus::Deactivated,
92
-
}
93
-
}
94
-
}
95
-
96
-
impl From<&crate::types::AccountState> for AccountStatus {
97
-
fn from(state: &crate::types::AccountState) -> Self {
98
-
match state {
99
-
crate::types::AccountState::Active => AccountStatus::Active,
100
-
crate::types::AccountState::Deactivated { .. } => AccountStatus::Deactivated,
101
-
crate::types::AccountState::TakenDown { .. } => AccountStatus::Takendown,
102
-
crate::types::AccountState::Migrated { .. } => AccountStatus::Deactivated,
103
-
}
104
-
}
105
-
}
106
-
107
21
pub struct RepoAccount {
108
22
pub did: String,
109
23
pub user_id: uuid::Uuid,
···
233
147
let frame = IdentityFrame {
234
148
did: event.did.to_string(),
235
149
handle: event.handle.as_ref().map(|h| h.to_string()),
236
-
seq: event.seq,
150
+
seq: event.seq.as_i64(),
237
151
time: format_atproto_time(event.created_at),
238
152
};
239
153
let header = FrameHeader {
···
250
164
let frame = AccountFrame {
251
165
did: event.did.to_string(),
252
166
active: event.active.unwrap_or(true),
253
-
status: event.status.clone(),
254
-
seq: event.seq,
167
+
status: event
168
+
.status
169
+
.and_then(|s| s.for_firehose().map(String::from)),
170
+
seq: event.seq.as_i64(),
255
171
time: format_atproto_time(event.created_at),
256
172
};
257
173
let header = FrameHeader {
···
298
214
did: event.did.to_string(),
299
215
rev,
300
216
blocks: car_bytes,
301
-
seq: event.seq,
217
+
seq: event.seq.as_i64(),
302
218
time: format_atproto_time(event.created_at),
303
219
};
304
220
let header = FrameHeader {
···
315
231
state: &AppState,
316
232
event: SequencedEvent,
317
233
) -> Result<Vec<u8>, anyhow::Error> {
318
-
match event.event_type.as_str() {
319
-
"identity" => return format_identity_event(&event),
320
-
"account" => return format_account_event(&event),
321
-
"sync" => return format_sync_event(state, &event).await,
322
-
_ => {}
234
+
match event.event_type {
235
+
RepoEventType::Identity => return format_identity_event(&event),
236
+
RepoEventType::Account => return format_account_event(&event),
237
+
RepoEventType::Sync => return format_sync_event(state, &event).await,
238
+
RepoEventType::Commit => {}
323
239
}
324
240
let block_cids_str = event.blocks_cids.clone().unwrap_or_default();
325
241
let prev_cid_link = event.prev_cid.clone();
···
440
356
did: event.did.to_string(),
441
357
rev,
442
358
blocks: car_bytes,
443
-
seq: event.seq,
359
+
seq: event.seq.as_i64(),
444
360
time: format_atproto_time(event.created_at),
445
361
};
446
362
let header = FrameHeader {
···
457
373
event: SequencedEvent,
458
374
prefetched: &HashMap<Cid, Bytes>,
459
375
) -> Result<Vec<u8>, anyhow::Error> {
460
-
match event.event_type.as_str() {
461
-
"identity" => return format_identity_event(&event),
462
-
"account" => return format_account_event(&event),
463
-
"sync" => return format_sync_event_with_prefetched(&event, prefetched),
464
-
_ => {}
376
+
match event.event_type {
377
+
RepoEventType::Identity => return format_identity_event(&event),
378
+
RepoEventType::Account => return format_account_event(&event),
379
+
RepoEventType::Sync => return format_sync_event_with_prefetched(&event, prefetched),
380
+
RepoEventType::Commit => {}
465
381
}
466
382
let block_cids_str = event.blocks_cids.clone().unwrap_or_default();
467
383
let prev_cid_link = event.prev_cid.clone();
+20
-4
crates/tranquil-pds/src/util.rs
+20
-4
crates/tranquil-pds/src/util.rs
···
4
4
use rand::Rng;
5
5
use serde_json::Value as JsonValue;
6
6
use std::collections::BTreeMap;
7
+
use std::net::SocketAddr;
7
8
use std::str::FromStr;
8
9
use std::sync::OnceLock;
9
10
···
11
12
const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024;
12
13
13
14
static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new();
15
+
static PDS_HOSTNAME: OnceLock<String> = OnceLock::new();
16
+
static PDS_HOSTNAME_WITHOUT_PORT: OnceLock<String> = OnceLock::new();
14
17
15
18
pub fn get_max_blob_size() -> usize {
16
19
*MAX_BLOB_SIZE.get_or_init(|| {
···
69
72
.unwrap_or_default()
70
73
}
71
74
72
-
pub fn extract_client_ip(headers: &HeaderMap) -> String {
75
+
pub fn get_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
76
+
headers.get(name).and_then(|h| h.to_str().ok())
77
+
}
78
+
79
+
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
73
80
if let Some(forwarded) = headers.get("x-forwarded-for")
74
81
&& let Ok(value) = forwarded.to_str()
75
82
&& let Some(first_ip) = value.split(',').next()
···
81
88
{
82
89
return value.trim().to_string();
83
90
}
84
-
"unknown".to_string()
91
+
addr.map(|a| a.ip().to_string())
92
+
.unwrap_or_else(|| "unknown".to_string())
85
93
}
86
94
87
-
pub fn pds_hostname() -> String {
88
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
95
+
pub fn pds_hostname() -> &'static str {
96
+
PDS_HOSTNAME
97
+
.get_or_init(|| std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()))
98
+
}
99
+
100
+
pub fn pds_hostname_without_port() -> &'static str {
101
+
PDS_HOSTNAME_WITHOUT_PORT.get_or_init(|| {
102
+
let hostname = pds_hostname();
103
+
hostname.split(':').next().unwrap_or(hostname).to_string()
104
+
})
89
105
}
90
106
91
107
pub fn pds_public_url() -> String {
+1
-1
crates/tranquil-pds/tests/commit_signing.rs
+1
-1
crates/tranquil-pds/tests/commit_signing.rs
···
99
99
use tranquil_pds::api::repo::record::utils::create_signed_commit;
100
100
101
101
let signing_key = SigningKey::random(&mut rand::thread_rng());
102
-
let did = Did::new_unchecked("did:plc:testuser123456789abcdef");
102
+
let did = unsafe { Did::new_unchecked("did:plc:testuser123456789abcdef") };
103
103
let data_cid =
104
104
Cid::from_str("bafyreib2rxk3ryblouj3fxza5jvx6psmwewwessc4m6g6e7pqhhkwqomfi").unwrap();
105
105
let rev = Tid::now(LimitedU32::MIN).to_string();
+2
-1
crates/tranquil-pds/tests/common/mod.rs
+2
-1
crates/tranquil-pds/tests/common/mod.rs
···
14
14
#[allow(unused_imports)]
15
15
use std::time::Duration;
16
16
use tokio::net::TcpListener;
17
+
use tokio_util::sync::CancellationToken;
17
18
use tranquil_pds::state::AppState;
18
19
use wiremock::matchers::{method, path};
19
20
use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate};
···
546
547
.with_email_update_limit(10000)
547
548
.with_oauth_authorize_limit(10000)
548
549
.with_oauth_token_limit(10000);
549
-
let state = AppState::from_db(pool)
550
+
let state = AppState::from_db(pool, CancellationToken::new())
550
551
.await
551
552
.with_rate_limiters(rate_limiters);
552
553
tranquil_pds::sync::listener::start_sequencer_listener(state.clone()).await;
+6
-10
crates/tranquil-pds/tests/firehose_validation.rs
+6
-10
crates/tranquil-pds/tests/firehose_validation.rs
···
9
9
use serde_json::{Value, json};
10
10
use std::io::Cursor;
11
11
use tokio_tungstenite::{connect_async, tungstenite};
12
+
use tranquil_scopes::RepoAction;
12
13
13
14
#[derive(Debug, Deserialize, Serialize)]
14
15
struct FrameHeader {
···
39
40
40
41
#[derive(Debug, Deserialize)]
41
42
struct RepoOp {
42
-
action: String,
43
+
action: RepoAction,
43
44
path: String,
44
45
cid: Option<Cid>,
45
46
prev: Option<Cid>,
···
292
293
println!("\nOps validation:");
293
294
for (i, op) in frame.ops.iter().enumerate() {
294
295
println!(" Op {}:", i);
295
-
println!(" action: {}", op.action);
296
+
println!(" action: {:?}", op.action);
296
297
println!(" path: {}", op.path);
297
298
println!(" cid: {:?}", op.cid);
298
299
println!(
···
300
301
op.prev
301
302
);
302
303
303
-
assert!(
304
-
["create", "update", "delete"].contains(&op.action.as_str()),
305
-
"Invalid action: {}",
306
-
op.action
307
-
);
308
304
assert!(
309
305
op.path.contains('/'),
310
306
"Path should contain collection/rkey: {}",
311
307
op.path
312
308
);
313
309
314
-
if op.action == "create" {
310
+
if op.action == RepoAction::Create {
315
311
assert!(op.cid.is_some(), "Create op should have cid");
316
312
}
317
313
}
···
445
441
446
442
for op in &frame.ops {
447
443
println!(
448
-
"Op: action={}, path={}, cid={:?}, prev={:?}",
444
+
"Op: action={:?}, path={}, cid={:?}, prev={:?}",
449
445
op.action, op.path, op.cid, op.prev
450
446
);
451
447
452
-
if op.action == "update" && op.path.contains("app.bsky.actor.profile") {
448
+
if op.action == RepoAction::Update && op.path.contains("app.bsky.actor.profile") {
453
449
assert!(
454
450
op.prev.is_some(),
455
451
"Update operation should have 'prev' field with old CID! Got: {:?}",
+71
crates/tranquil-pds/tests/shutdown_unit.rs
+71
crates/tranquil-pds/tests/shutdown_unit.rs
···
1
+
use std::sync::Arc;
2
+
use std::sync::atomic::{AtomicBool, Ordering};
3
+
use tokio_util::sync::CancellationToken;
4
+
5
+
#[test]
6
+
fn test_panic_hook_cancels_shutdown_token() {
7
+
let shutdown = CancellationToken::new();
8
+
let shutdown_clone = shutdown.clone();
9
+
10
+
let panic_occurred = Arc::new(AtomicBool::new(false));
11
+
let panic_occurred_clone = panic_occurred.clone();
12
+
13
+
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
14
+
let default_hook = std::panic::take_hook();
15
+
std::panic::set_hook(Box::new(move |info| {
16
+
panic_occurred_clone.store(true, Ordering::SeqCst);
17
+
shutdown_clone.cancel();
18
+
default_hook(info);
19
+
}));
20
+
21
+
panic!("simulated corrupted data panic");
22
+
}));
23
+
24
+
assert!(result.is_err());
25
+
assert!(panic_occurred.load(Ordering::SeqCst));
26
+
assert!(shutdown.is_cancelled());
27
+
28
+
let _ = std::panic::take_hook();
29
+
}
30
+
31
+
#[test]
32
+
fn test_cancellation_token_propagates_to_clones() {
33
+
let shutdown = CancellationToken::new();
34
+
let clone1 = shutdown.clone();
35
+
let clone2 = shutdown.clone();
36
+
37
+
assert!(!shutdown.is_cancelled());
38
+
assert!(!clone1.is_cancelled());
39
+
assert!(!clone2.is_cancelled());
40
+
41
+
shutdown.cancel();
42
+
43
+
assert!(shutdown.is_cancelled());
44
+
assert!(clone1.is_cancelled());
45
+
assert!(clone2.is_cancelled());
46
+
}
47
+
48
+
#[tokio::test]
49
+
async fn test_cancelled_future_completes_on_cancel() {
50
+
let shutdown = CancellationToken::new();
51
+
let shutdown_clone = shutdown.clone();
52
+
53
+
let handle = tokio::spawn(async move {
54
+
shutdown_clone.cancelled().await;
55
+
true
56
+
});
57
+
58
+
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
59
+
assert!(!handle.is_finished());
60
+
61
+
shutdown.cancel();
62
+
63
+
let result = tokio::time::timeout(
64
+
std::time::Duration::from_millis(100),
65
+
handle,
66
+
)
67
+
.await;
68
+
69
+
assert!(result.is_ok());
70
+
assert!(result.unwrap().unwrap());
71
+
}
+25
-17
crates/tranquil-pds/tests/sso.rs
+25
-17
crates/tranquil-pds/tests/sso.rs
···
232
232
let _url = base_url().await;
233
233
let pool = get_test_db_pool().await;
234
234
235
-
let did = Did::new_unchecked(format!(
236
-
"did:plc:test{}",
237
-
&uuid::Uuid::new_v4().simple().to_string()[..12]
238
-
));
235
+
let did = unsafe {
236
+
Did::new_unchecked(format!(
237
+
"did:plc:test{}",
238
+
&uuid::Uuid::new_v4().simple().to_string()[..12]
239
+
))
240
+
};
239
241
let provider = SsoProviderType::Github;
240
242
let provider_user_id = format!("github_user_{}", uuid::Uuid::new_v4().simple());
241
243
···
350
352
let _url = base_url().await;
351
353
let pool = get_test_db_pool().await;
352
354
353
-
let did1 = Did::new_unchecked(format!(
354
-
"did:plc:uc1{}",
355
-
&uuid::Uuid::new_v4().simple().to_string()[..10]
356
-
));
357
-
let did2 = Did::new_unchecked(format!(
358
-
"did:plc:uc2{}",
359
-
&uuid::Uuid::new_v4().simple().to_string()[..10]
360
-
));
355
+
let did1 = unsafe {
356
+
Did::new_unchecked(format!(
357
+
"did:plc:uc1{}",
358
+
&uuid::Uuid::new_v4().simple().to_string()[..10]
359
+
))
360
+
};
361
+
let did2 = unsafe {
362
+
Did::new_unchecked(format!(
363
+
"did:plc:uc2{}",
364
+
&uuid::Uuid::new_v4().simple().to_string()[..10]
365
+
))
366
+
};
361
367
let provider_user_id = format!("unique_test_{}", uuid::Uuid::new_v4().simple());
362
368
363
369
sqlx::query!(
···
577
583
let _url = base_url().await;
578
584
let pool = get_test_db_pool().await;
579
585
580
-
let did = Did::new_unchecked(format!(
581
-
"did:plc:del{}",
582
-
&uuid::Uuid::new_v4().simple().to_string()[..10]
583
-
));
584
-
let wrong_did = Did::new_unchecked("did:plc:wrongdid12345");
586
+
let did = unsafe {
587
+
Did::new_unchecked(format!(
588
+
"did:plc:del{}",
589
+
&uuid::Uuid::new_v4().simple().to_string()[..10]
590
+
))
591
+
};
592
+
let wrong_did = unsafe { Did::new_unchecked("did:plc:wrongdid12345") };
585
593
586
594
sqlx::query!(
587
595
"INSERT INTO users (did, handle, email, password_hash) VALUES ($1, $2, $3, 'hash')",
+3
-1
crates/tranquil-scopes/src/parser.rs
+3
-1
crates/tranquil-scopes/src/parser.rs
···
1
+
use serde::{Deserialize, Serialize};
1
2
use std::collections::{HashMap, HashSet};
2
3
3
4
#[derive(Debug, Clone, PartialEq, Eq)]
···
27
28
pub actions: HashSet<RepoAction>,
28
29
}
29
30
30
-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
31
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
32
+
#[serde(rename_all = "lowercase")]
31
33
pub enum RepoAction {
32
34
Create,
33
35
Update,
+6
-3
crates/tranquil-types/src/lib.rs
+6
-3
crates/tranquil-types/src/lib.rs
···
166
166
Ok(Self(s))
167
167
}
168
168
169
-
pub fn new_unchecked(s: impl Into<String>) -> Self {
169
+
#[allow(unsafe_code, clippy::missing_safety_doc)]
170
+
pub unsafe fn new_unchecked(s: impl Into<String>) -> Self {
170
171
Self(s.into())
171
172
}
172
173
}
···
228
229
Ok(Self(s))
229
230
}
230
231
231
-
pub fn new_unchecked(s: impl Into<String>) -> Self {
232
+
#[allow(unsafe_code, clippy::missing_safety_doc)]
233
+
pub unsafe fn new_unchecked(s: impl Into<String>) -> Self {
232
234
Self(s.into())
233
235
}
234
236
}
···
489
491
Ok(Self(s))
490
492
}
491
493
492
-
pub fn new_unchecked(s: impl Into<String>) -> Self {
494
+
#[allow(unsafe_code, clippy::missing_safety_doc)]
495
+
pub unsafe fn new_unchecked(s: impl Into<String>) -> Self {
493
496
Self(s.into())
494
497
}
495
498
History
3 rounds
0 comments
expand 0 comments
pull request successfully merged