this repo has no description

Performance enhancements, overengineering

lewis 2f185d97 519048c3

Changed files
+2998 -2877
.sqlx
migrations
observability
scripts
src
tests
+113 -9
.env.example
··· 1 + # ============================================================================= 2 + # Server 3 + # ============================================================================= 1 4 SERVER_HOST=127.0.0.1 2 5 SERVER_PORT=3000 3 6 7 + # The public-facing hostname of the PDS (used in DID documents, JWTs, etc.) 8 + PDS_HOSTNAME=localhost:3000 9 + 10 + # ============================================================================= 11 + # Database 12 + # ============================================================================= 4 13 DATABASE_URL=postgres://postgres:postgres@localhost:5432/pds 5 14 6 - S3_ENDPOINT=http://objsto:9000 15 + # Connection pool settings (defaults are good for most deployments) 16 + # DATABASE_MAX_CONNECTIONS=100 17 + # DATABASE_MIN_CONNECTIONS=10 18 + # DATABASE_ACQUIRE_TIMEOUT_SECS=30 19 + 20 + # ============================================================================= 21 + # Blob Storage (S3-compatible) 22 + # ============================================================================= 23 + S3_ENDPOINT=http://localhost:9000 7 24 AWS_REGION=us-east-1 8 25 S3_BUCKET=pds-blobs 9 26 AWS_ACCESS_KEY_ID=minioadmin 10 27 AWS_SECRET_ACCESS_KEY=minioadmin 11 28 12 - # The public-facing hostname of the PDS 13 - PDS_HOSTNAME=localhost:3000 14 - PLC_URL=plc.directory 29 + # ============================================================================= 30 + # Valkey (for caching and distributed rate limiting) 31 + # ============================================================================= 32 + # If not set, falls back to in-memory caching (single-node only) 33 + # VALKEY_URL=redis://localhost:6379 34 + 35 + # ============================================================================= 36 + # Security Secrets 37 + # ============================================================================= 38 + # These MUST be set in production (minimum 32 characters each) 39 + # In development, set BSPDS_ALLOW_INSECURE_SECRETS=1 to use defaults 40 + 41 + # Server-wide secret for OAuth token signing (HS256) 42 + # JWT_SECRET=your-secure-random-string-at-least-32-chars 15 43 16 - # A comma-separated list of relay URLs to notify via requestCrawl when we have updates. 17 - # e.g., CRAWLERS=https://bsky.network 18 - CRAWLERS= 44 + # Secret for DPoP proof validation 45 + # DPOP_SECRET=your-secure-random-string-at-least-32-chars 19 46 20 - # Notification Service Configuration 21 - # At least one notification channel should be configured for user notifications to work. 47 + # Key for encrypting user signing keys at rest (AES-256-GCM) 48 + # MASTER_KEY=your-secure-random-string-at-least-32-chars 49 + 50 + # Set this ONLY in development to allow default/weak secrets 51 + # BSPDS_ALLOW_INSECURE_SECRETS=1 52 + 53 + # ============================================================================= 54 + # PLC Directory 55 + # ============================================================================= 56 + # PLC_DIRECTORY_URL=https://plc.directory 57 + # PLC_TIMEOUT_SECS=10 58 + # PLC_CONNECT_TIMEOUT_SECS=5 59 + 60 + # Optional: rotation key for PLC operations (defaults to user's key) 61 + # PLC_ROTATION_KEY=did:key:... 62 + 63 + # ============================================================================= 64 + # Federation 65 + # ============================================================================= 66 + # Appview URL for proxying app.bsky.* requests 67 + # APPVIEW_URL=https://api.bsky.app 68 + 69 + # Comma-separated list of relay URLs to notify via requestCrawl 70 + # CRAWLERS=https://bsky.network 71 + 72 + # ============================================================================= 73 + # Firehose (subscribeRepos WebSocket) 74 + # ============================================================================= 75 + # Buffer size for firehose broadcast channel 76 + # FIREHOSE_BUFFER_SIZE=10000 77 + 78 + # Disconnect slow consumers after this many events of lag 79 + # FIREHOSE_MAX_LAG=5000 80 + 81 + # ============================================================================= 82 + # Notification Service 83 + # ============================================================================= 84 + # Queue processing settings 85 + # NOTIFICATION_BATCH_SIZE=100 86 + # NOTIFICATION_POLL_INTERVAL_MS=1000 22 87 23 88 # Email notifications (via sendmail/msmtp) 24 89 # MAIL_FROM_ADDRESS=noreply@example.com ··· 34 99 # Signal notifications (via signal-cli) 35 100 # SIGNAL_CLI_PATH=/usr/local/bin/signal-cli 36 101 # SIGNAL_SENDER_NUMBER=+1234567890 102 + 103 + # ============================================================================= 104 + # Repository Import 105 + # ============================================================================= 106 + # Set to "true" to accept repository imports 107 + # ACCEPTING_REPO_IMPORTS=false 108 + 109 + # Maximum import size in bytes (default: 50MB) 110 + # MAX_IMPORT_SIZE=52428800 111 + 112 + # Maximum blocks per import (default: 100000) 113 + # MAX_IMPORT_BLOCKS=100000 114 + 115 + # Skip verification during import (testing only) 116 + # SKIP_IMPORT_VERIFICATION=false 117 + 118 + # ============================================================================= 119 + # Account Registration 120 + # ============================================================================= 121 + # Require invite codes for registration 122 + # INVITE_CODE_REQUIRED=false 123 + 124 + # Comma-separated list of available user domains 125 + # AVAILABLE_USER_DOMAINS=example.com 126 + 127 + # ============================================================================= 128 + # Rate Limiting 129 + # ============================================================================= 130 + # Disable all rate limiting (testing only, NEVER in production) 131 + # DISABLE_RATE_LIMITING=1 132 + 133 + # ============================================================================= 134 + # Miscellaneous 135 + # ============================================================================= 136 + # Allow HTTP for proxy requests (development only) 137 + # ALLOW_HTTP_PROXY=1 138 + 139 + # Custom frontend directory (defaults to ./frontend/dist) 140 + # FRONTEND_DIR=/path/to/frontend/dist 37 141 38 142 CARGO_MOMMYS_LITTLE=mister 39 143 CARGO_MOMMYS_PRONOUNS=his
+28
.sqlx/query-04c220298334c369872f0b0ad162b992c2353e28257b53f3f10cbff8abb26f5a.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "deactivated_at", 9 + "type_info": "Timestamptz" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "takedown_ref", 14 + "type_info": "Text" 15 + } 16 + ], 17 + "parameters": { 18 + "Left": [ 19 + "Text" 20 + ] 21 + }, 22 + "nullable": [ 23 + true, 24 + true 25 + ] 26 + }, 27 + "hash": "04c220298334c369872f0b0ad162b992c2353e28257b53f3f10cbff8abb26f5a" 28 + }
-18
.sqlx/query-0f10bde03edc0233a332e210a84a4186977c71efd3be80e2508a60ea5802cb1b.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Left": [ 8 - "Text", 9 - "Text", 10 - "Int8", 11 - "Uuid", 12 - "Text" 13 - ] 14 - }, 15 - "nullable": [] 16 - }, 17 - "hash": "0f10bde03edc0233a332e210a84a4186977c71efd3be80e2508a60ea5802cb1b" 18 - }
+22
.sqlx/query-1c831cb6f3b8d01b18feec900148278c2b491418b622da9e75fe1792089e4409.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT access_jti FROM session_tokens WHERE did = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "access_jti", 9 + "type_info": "Text" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text" 15 + ] 16 + }, 17 + "nullable": [ 18 + false 19 + ] 20 + }, 21 + "hash": "1c831cb6f3b8d01b18feec900148278c2b491418b622da9e75fe1792089e4409" 22 + }
+26
.sqlx/query-25ac36e9dec1c8e29cbe7cfc954683061c7c2733fa60f91f1c5ced4d00e7bf3d.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING RETURNING cid", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "cid", 9 + "type_info": "Text" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Text", 15 + "Text", 16 + "Int8", 17 + "Uuid", 18 + "Text" 19 + ] 20 + }, 21 + "nullable": [ 22 + false 23 + ] 24 + }, 25 + "hash": "25ac36e9dec1c8e29cbe7cfc954683061c7c2733fa60f91f1c5ced4d00e7bf3d" 26 + }
-14
.sqlx/query-3b1176253dc7b94d3fc58c077310d8058f90edf1fa27200b52b464b9c37335dd.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Left": [ 8 - "Uuid" 9 - ] 10 - }, 11 - "nullable": [] 12 - }, 13 - "hash": "3b1176253dc7b94d3fc58c077310d8058f90edf1fa27200b52b464b9c37335dd" 14 - }
+2 -2
.sqlx/query-6b67b2b6759f01be11d5997a3ad68d381f59a02235a6940877f62193af8d9761.json .sqlx/query-f4f4b6a9e5d2345efa8e48380f66c819c1818030aa4bf26757d9fb40e654b693.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref\n FROM users u\n JOIN user_keys k ON u.id = k.user_id\n WHERE u.did = $1", 3 + "query": "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref\n FROM users u\n JOIN user_keys k ON u.id = k.user_id\n WHERE u.did = $1", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 36 36 true 37 37 ] 38 38 }, 39 - "hash": "6b67b2b6759f01be11d5997a3ad68d381f59a02235a6940877f62193af8d9761" 39 + "hash": "f4f4b6a9e5d2345efa8e48380f66c819c1818030aa4bf26757d9fb40e654b693" 40 40 }
+70
.sqlx/query-8a7a8f0c4c0872c21c46d484219624215bdb14617b9f9a44974e394a28147f70.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "seq", 9 + "type_info": "Int8" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "did", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "created_at", 19 + "type_info": "Timestamptz" 20 + }, 21 + { 22 + "ordinal": 3, 23 + "name": "event_type", 24 + "type_info": "Text" 25 + }, 26 + { 27 + "ordinal": 4, 28 + "name": "commit_cid", 29 + "type_info": "Text" 30 + }, 31 + { 32 + "ordinal": 5, 33 + "name": "prev_cid", 34 + "type_info": "Text" 35 + }, 36 + { 37 + "ordinal": 6, 38 + "name": "ops", 39 + "type_info": "Jsonb" 40 + }, 41 + { 42 + "ordinal": 7, 43 + "name": "blobs", 44 + "type_info": "TextArray" 45 + }, 46 + { 47 + "ordinal": 8, 48 + "name": "blocks_cids", 49 + "type_info": "TextArray" 50 + } 51 + ], 52 + "parameters": { 53 + "Left": [ 54 + "Int8" 55 + ] 56 + }, 57 + "nullable": [ 58 + false, 59 + false, 60 + false, 61 + false, 62 + true, 63 + true, 64 + true, 65 + true, 66 + true 67 + ] 68 + }, 69 + "hash": "8a7a8f0c4c0872c21c46d484219624215bdb14617b9f9a44974e394a28147f70" 70 + }
-18
.sqlx/query-8a9e71f04ec779d5c10d79582cc398529e01be01a83898df3524bb35e3d2ed14.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) VALUES ($1, $2, $3, $4, $5)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, repo_rev = $5, created_at = NOW()", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Left": [ 8 - "Uuid", 9 - "Text", 10 - "Text", 11 - "Text", 12 - "Text" 13 - ] 14 - }, 15 - "nullable": [] 16 - }, 17 - "hash": "8a9e71f04ec779d5c10d79582cc398529e01be01a83898df3524bb35e3d2ed14" 18 - }
-16
.sqlx/query-8c9297289cb753c8eaa4231ae9eab6cd3367f9bf543d9f49bca4afa53434ce0d.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "DELETE FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Left": [ 8 - "Uuid", 9 - "Text", 10 - "Text" 11 - ] 12 - }, 13 - "nullable": [] 14 - }, 15 - "hash": "8c9297289cb753c8eaa4231ae9eab6cd3367f9bf543d9f49bca4afa53434ce0d" 16 - }
+16
.sqlx/query-9806777e3db4db9e9a905a6ce26375f026aa8a6db2c5534cf5ccf9758a07ee39.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n DELETE FROM records\n WHERE repo_id = $1\n AND (collection, rkey) IN (SELECT * FROM UNNEST($2::text[], $3::text[]))\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Uuid", 9 + "TextArray", 10 + "TextArray" 11 + ] 12 + }, 13 + "nullable": [] 14 + }, 15 + "hash": "9806777e3db4db9e9a905a6ce26375f026aa8a6db2c5534cf5ccf9758a07ee39" 16 + }
+71
.sqlx/query-a63aed47193f06cd11d87157799c17a591e0a0be4487f718250eaf7afd4b4b07.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids\n FROM repo_seq\n WHERE seq > $1 AND seq < $2\n ORDER BY seq ASC\n ", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "seq", 9 + "type_info": "Int8" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "did", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "created_at", 19 + "type_info": "Timestamptz" 20 + }, 21 + { 22 + "ordinal": 3, 23 + "name": "event_type", 24 + "type_info": "Text" 25 + }, 26 + { 27 + "ordinal": 4, 28 + "name": "commit_cid", 29 + "type_info": "Text" 30 + }, 31 + { 32 + "ordinal": 5, 33 + "name": "prev_cid", 34 + "type_info": "Text" 35 + }, 36 + { 37 + "ordinal": 6, 38 + "name": "ops", 39 + "type_info": "Jsonb" 40 + }, 41 + { 42 + "ordinal": 7, 43 + "name": "blobs", 44 + "type_info": "TextArray" 45 + }, 46 + { 47 + "ordinal": 8, 48 + "name": "blocks_cids", 49 + "type_info": "TextArray" 50 + } 51 + ], 52 + "parameters": { 53 + "Left": [ 54 + "Int8", 55 + "Int8" 56 + ] 57 + }, 58 + "nullable": [ 59 + false, 60 + false, 61 + false, 62 + false, 63 + true, 64 + true, 65 + true, 66 + true, 67 + true 68 + ] 69 + }, 70 + "hash": "a63aed47193f06cd11d87157799c17a591e0a0be4487f718250eaf7afd4b4b07" 71 + }
+20
.sqlx/query-b2a217b405ace1726097631c7fa532bf1a7330f11328a1e68d5eced41cad8a78.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT COALESCE(MAX(seq), 0) as max FROM repo_seq", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "max", 9 + "type_info": "Int8" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [] 14 + }, 15 + "nullable": [ 16 + null 17 + ] 18 + }, 19 + "hash": "b2a217b405ace1726097631c7fa532bf1a7330f11328a1e68d5eced41cad8a78" 20 + }
+28
.sqlx/query-b9848ea8f168e1ab975dc2ad125b5b9e478e74254a8cf670e55b728bc402f046.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT cid, data FROM blocks WHERE cid = ANY($1)", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "cid", 9 + "type_info": "Bytea" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "data", 14 + "type_info": "Bytea" 15 + } 16 + ], 17 + "parameters": { 18 + "Left": [ 19 + "ByteaArray" 20 + ] 21 + }, 22 + "nullable": [ 23 + false, 24 + false 25 + ] 26 + }, 27 + "hash": "b9848ea8f168e1ab975dc2ad125b5b9e478e74254a8cf670e55b728bc402f046" 28 + }
+15
.sqlx/query-c9b624a9987dd263e908fcff4612e1cd446552c93d80254d9e15c2e51a95a596.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO blocks (cid, data)\n SELECT * FROM UNNEST($1::bytea[], $2::bytea[])\n ON CONFLICT (cid) DO NOTHING\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "ByteaArray", 9 + "ByteaArray" 10 + ] 11 + }, 12 + "nullable": [] 13 + }, 14 + "hash": "c9b624a9987dd263e908fcff4612e1cd446552c93d80254d9e15c2e51a95a596" 15 + }
+22
.sqlx/query-dcaedeec794a63ce8abb9b580461c193ad58fee110d57249f98355b40b757a37.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT password_hash FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC LIMIT 20", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "password_hash", 9 + "type_info": "Text" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Uuid" 15 + ] 16 + }, 17 + "nullable": [ 18 + false 19 + ] 20 + }, 21 + "hash": "dcaedeec794a63ce8abb9b580461c193ad58fee110d57249f98355b40b757a37" 22 + }
+18
.sqlx/query-e1066ab3a86852164e39848733c0f7e837657ea6595ea0094a6135673ea924a5.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "\n INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev)\n SELECT $1, collection, rkey, record_cid, $5\n FROM UNNEST($2::text[], $3::text[], $4::text[]) AS t(collection, rkey, record_cid)\n ON CONFLICT (repo_id, collection, rkey) DO UPDATE\n SET record_cid = EXCLUDED.record_cid, repo_rev = EXCLUDED.repo_rev, created_at = NOW()\n ", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Uuid", 9 + "TextArray", 10 + "TextArray", 11 + "TextArray", 12 + "Text" 13 + ] 14 + }, 15 + "nullable": [] 16 + }, 17 + "hash": "e1066ab3a86852164e39848733c0f7e837657ea6595ea0094a6135673ea924a5" 18 + }
+84
Cargo.lock
··· 63 63 ] 64 64 65 65 [[package]] 66 + name = "ahash" 67 + version = "0.8.12" 68 + source = "registry+https://github.com/rust-lang/crates.io-index" 69 + checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" 70 + dependencies = [ 71 + "cfg-if", 72 + "once_cell", 73 + "version_check", 74 + "zerocopy", 75 + ] 76 + 77 + [[package]] 66 78 name = "aho-corasick" 67 79 version = "1.1.4" 68 80 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 941 953 "jacquard-repo", 942 954 "jsonwebtoken", 943 955 "k256", 956 + "metrics", 957 + "metrics-exporter-prometheus", 944 958 "multibase", 945 959 "multihash", 946 960 "p256 0.13.2", ··· 1342 1356 version = "0.5.15" 1343 1357 source = "registry+https://github.com/rust-lang/crates.io-index" 1344 1358 checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" 1359 + dependencies = [ 1360 + "crossbeam-utils", 1361 + ] 1362 + 1363 + [[package]] 1364 + name = "crossbeam-epoch" 1365 + version = "0.9.18" 1366 + source = "registry+https://github.com/rust-lang/crates.io-index" 1367 + checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" 1345 1368 dependencies = [ 1346 1369 "crossbeam-utils", 1347 1370 ] ··· 3560 3583 checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" 3561 3584 3562 3585 [[package]] 3586 + name = "metrics" 3587 + version = "0.24.3" 3588 + source = "registry+https://github.com/rust-lang/crates.io-index" 3589 + checksum = "5d5312e9ba3771cfa961b585728215e3d972c950a3eed9252aa093d6301277e8" 3590 + dependencies = [ 3591 + "ahash", 3592 + "portable-atomic", 3593 + ] 3594 + 3595 + [[package]] 3596 + name = "metrics-exporter-prometheus" 3597 + version = "0.16.2" 3598 + source = "registry+https://github.com/rust-lang/crates.io-index" 3599 + checksum = "dd7399781913e5393588a8d8c6a2867bf85fb38eaf2502fdce465aad2dc6f034" 3600 + dependencies = [ 3601 + "base64 0.22.1", 3602 + "http-body-util", 3603 + "hyper 1.8.1", 3604 + "hyper-util", 3605 + "indexmap 2.12.1", 3606 + "ipnet", 3607 + "metrics", 3608 + "metrics-util", 3609 + "quanta", 3610 + "thiserror 1.0.69", 3611 + "tokio", 3612 + "tracing", 3613 + ] 3614 + 3615 + [[package]] 3616 + name = "metrics-util" 3617 + version = "0.19.1" 3618 + source = "registry+https://github.com/rust-lang/crates.io-index" 3619 + checksum = "b8496cc523d1f94c1385dd8f0f0c2c480b2b8aeccb5b7e4485ad6365523ae376" 3620 + dependencies = [ 3621 + "crossbeam-epoch", 3622 + "crossbeam-utils", 3623 + "hashbrown 0.15.5", 3624 + "metrics", 3625 + "quanta", 3626 + "rand 0.9.2", 3627 + "rand_xoshiro", 3628 + "sketches-ddsketch", 3629 + ] 3630 + 3631 + [[package]] 3563 3632 name = "miette" 3564 3633 version = "7.6.0" 3565 3634 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 4483 4552 ] 4484 4553 4485 4554 [[package]] 4555 + name = "rand_xoshiro" 4556 + version = "0.7.0" 4557 + source = "registry+https://github.com/rust-lang/crates.io-index" 4558 + checksum = "f703f4665700daf5512dcca5f43afa6af89f09db47fb56be587f80636bda2d41" 4559 + dependencies = [ 4560 + "rand_core 0.9.3", 4561 + ] 4562 + 4563 + [[package]] 4486 4564 name = "range-traits" 4487 4565 version = "0.3.2" 4488 4566 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 5235 5313 "tempfile", 5236 5314 "walkdir", 5237 5315 ] 5316 + 5317 + [[package]] 5318 + name = "sketches-ddsketch" 5319 + version = "0.3.0" 5320 + source = "registry+https://github.com/rust-lang/crates.io-index" 5321 + checksum = "c1e9a774a6c28142ac54bb25d25562e6bcf957493a184f15ad4eebccb23e410a" 5238 5322 5239 5323 [[package]] 5240 5324 name = "slab"
+2
Cargo.toml
··· 51 51 image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } 52 52 redis = { version = "0.27", features = ["tokio-comp", "connection-manager"] } 53 53 tower-http = { version = "0.6", features = ["fs"] } 54 + metrics = "0.24" 55 + metrics-exporter-prometheus = { version = "0.16", default-features = false, features = ["http-listener"] } 54 56 55 57 [features] 56 58 external-infra = []
+1 -1
Dockerfile
··· 5 5 RUN deno task build 6 6 7 7 # Stage 2: Build Rust backend 8 - FROM rust:1.91.1-alpine AS builder 8 + FROM rust:1.92-alpine AS builder 9 9 10 10 RUN apk add ca-certificates openssl openssl-dev pkgconfig 11 11
+14 -101
README.md
··· 1 - # BSPDS, a Personal Data Server 2 - 3 - A production-grade Personal Data Server (PDS) implementation for the AT Protocol. 1 + # BSPDS 4 2 5 - Uses PostgreSQL instead of SQLite, S3-compatible blob storage, and is designed to be a complete drop-in replacement for Bluesky's reference PDS implementation. 3 + A production-grade Personal Data Server (PDS) for the AT Protocol. Drop-in replacement for Bluesky's reference PDS, using postgres and s3-compatible blob storage. 6 4 7 5 ## Features 8 6 9 - - Full AT Protocol support, all `com.atproto.*` endpoints implemented 10 - - OAuth 2.1 Provider. PKCE, DPoP, Pushed Authorization Requests 11 - - PostgreSQL, prod-ready database backend 12 - - S3-compatible object storage for blobs; works with AWS S3, UpCloud object storage, self-hosted MinIO, etc. 13 - - WebSocket `subscribeRepos` endpoint for real-time sync 14 - - Crawler notifications via `requestCrawl` 15 - - Multi-channel notifications: email, discord, telegram, signal 16 - - Per-IP rate limiting on sensitive endpoints 7 + - Full AT Protocol support (`com.atproto.*` endpoints) 8 + - OAuth 2.1 provider (PKCE, DPoP, PAR) 9 + - WebSocket firehose (`subscribeRepos`) 10 + - Multi-channel notifications (email, discord, telegram, signal) 17 11 - Built-in web UI for account management 18 - 19 - ## Running Locally 20 - 21 - Requires Rust installed locally. 22 - 23 - Run PostgreSQL and S3-compatible object store (e.g., with podman/docker): 24 - 25 - ```bash 26 - podman compose up db objsto -d 27 - ``` 12 + - Per-IP rate limiting 28 13 29 - Run the PDS: 14 + ## Quick Start 30 15 31 16 ```bash 17 + cp .env.example .env 18 + podman compose up -d 32 19 just run 33 20 ``` 34 21 35 22 ## Configuration 36 23 37 - ### Required 38 - 39 - | Variable | Description | 40 - |----------|-------------| 41 - | `DATABASE_URL` | PostgreSQL connection string | 42 - | `S3_BUCKET` | Blob storage bucket name | 43 - | `S3_ENDPOINT` | S3 endpoint URL (for MinIO, etc.) | 44 - | `AWS_ACCESS_KEY_ID` | S3 credentials | 45 - | `AWS_SECRET_ACCESS_KEY` | S3 credentials | 46 - | `AWS_REGION` | S3 region | 47 - | `PDS_HOSTNAME` | Public hostname of this PDS | 48 - | `JWT_SECRET` | Secret for OAuth token signing (HS256) | 49 - | `KEY_ENCRYPTION_KEY` | Key for encrypting user signing keys (AES-256-GCM) | 50 - 51 - ### Optional 52 - 53 - | Variable | Description | 54 - |----------|-------------| 55 - | `APPVIEW_URL` | Appview URL to proxy unimplemented endpoints to | 56 - | `CRAWLERS` | Comma-separated list of relay URLs to notify via `requestCrawl` | 57 - 58 - ### Notifications 59 - 60 - At least one channel should be configured for user notifications (password reset, email verification, etc.): 61 - 62 - | Variable | Description | 63 - |----------|-------------| 64 - | `MAIL_FROM_ADDRESS` | Email sender address (enables email via sendmail) | 65 - | `MAIL_FROM_NAME` | Email sender name (default: "BSPDS") | 66 - | `SENDMAIL_PATH` | Path to sendmail binary (default: /usr/sbin/sendmail) | 67 - | `DISCORD_WEBHOOK_URL` | Discord webhook URL for notifications | 68 - | `TELEGRAM_BOT_TOKEN` | Telegram bot token for notifications | 69 - | `SIGNAL_CLI_PATH` | Path to signal-cli binary | 70 - | `SIGNAL_SENDER_NUMBER` | Signal sender phone number (+1234567890 format) | 24 + See `.env.example` for all configuration options. 71 25 72 26 ## Development 73 27 74 - ```bash 75 - just # Show available commands 76 - just test # Run tests (auto-starts postgres/minio, runs nextest) 77 - just lint # Clippy + fmt check 78 - just db-reset # Drop and recreate local database 79 - ``` 80 - 81 - ## Web UI 82 - 83 - BSPDS includes a built-in web frontend for users to manage their accounts. Users can: 84 - 85 - - Sign in and register new accounts 86 - - Manage app passwords 87 - - View and create invite codes 88 - - Update email and handle 89 - - Configure notification preferences 90 - - Browse their repository data 91 - 92 - The frontend is built with svelte and deno, and is served directly by the PDS. 28 + Run `just` to see available commands. 93 29 94 30 ```bash 95 - just frontend-dev # Run frontend dev server 96 - just frontend-build # Build for production 97 - just frontend-test # Run frontend tests 98 - ``` 99 - 100 - ## Project Structure 101 - 102 - ``` 103 - src/ 104 - main.rs Server entrypoint 105 - lib.rs Router setup 106 - state.rs AppState (db pool, stores, rate limiters, circuit breakers) 107 - api/ XRPC handlers organized by namespace 108 - auth/ JWT authentication (ES256K per-user keys) 109 - oauth/ OAuth 2.1 provider (HS256 server-wide) 110 - repo/ PostgreSQL block store 111 - storage/ S3 blob storage 112 - sync/ Firehose, CAR export, crawler notifications 113 - notifications/ Multi-channel notification service 114 - plc/ PLC directory client 115 - circuit_breaker/ Circuit breaker for external services 116 - rate_limit/ Per-IP rate limiting 117 - frontend/ Svelte web UI (deno) 118 - tests/ Integration tests 119 - migrations/ SQLx migrations 31 + just test # run tests 32 + just lint # clippy + fmt 120 33 ``` 121 34 122 35 ## License
+2 -2
TODO.md
··· 201 201 - [x] DID Cache 202 202 - [x] Implement caching layer for DID resolution (valkey). 203 203 - [x] Handle cache invalidation/expiry. 204 - - [x] Graceful fallback to no-cache when Valkey unavailable. 204 + - [x] Graceful fallback to no-cache when valkey unavailable. 205 205 - [x] Crawlers Service 206 206 - [x] Implement `Crawlers` service (debounce notifications to relays). 207 207 - [x] 20-minute notification debounce. ··· 237 237 - [x] Per-IP rate limiting on OAuth revoke/introspect (30/min). 238 238 - [x] Per-IP rate limiting on createAppPassword (10/min). 239 239 - [x] Per-IP rate limiting on email endpoints (5/hour). 240 - - [x] Distributed rate limiting via Valkey/Redis (with in-memory fallback). 240 + - [x] Distributed rate limiting via valkey (with in-memory fallback). 241 241 - [x] Circuit Breakers 242 242 - [x] PLC directory circuit breaker (5 failures → open, 60s timeout). 243 243 - [x] Relay notification circuit breaker (10 failures → open, 30s timeout).
+14
docker-compose.yaml
··· 47 47 volumes: 48 48 - valkey_data:/data 49 49 50 + prometheus: 51 + image: prom/prometheus:latest 52 + ports: 53 + - "9090:9090" 54 + volumes: 55 + - ./observability/prometheus.yml:/etc/prometheus/prometheus.yml:ro 56 + - prometheus_data:/prometheus 57 + command: 58 + - '--config.file=/etc/prometheus/prometheus.yml' 59 + - '--storage.tsdb.path=/prometheus' 60 + depends_on: 61 + - app 62 + 50 63 volumes: 51 64 postgres_data: 52 65 minio_data: 53 66 valkey_data: 67 + prometheus_data:
+21
migrations/20251213_performance_indexes.sql
··· 1 + CREATE INDEX IF NOT EXISTS idx_records_repo_collection 2 + ON records(repo_id, collection); 3 + 4 + CREATE INDEX IF NOT EXISTS idx_records_repo_collection_created 5 + ON records(repo_id, collection, created_at DESC); 6 + 7 + CREATE INDEX IF NOT EXISTS idx_users_email 8 + ON users(email) 9 + WHERE email IS NOT NULL; 10 + 11 + CREATE INDEX IF NOT EXISTS idx_blobs_created_by_user 12 + ON blobs(created_by_user, created_at DESC); 13 + 14 + CREATE INDEX IF NOT EXISTS idx_repo_seq_did_seq 15 + ON repo_seq(did, seq DESC); 16 + 17 + CREATE INDEX IF NOT EXISTS idx_app_passwords_user_id 18 + ON app_passwords(user_id); 19 + 20 + CREATE INDEX IF NOT EXISTS idx_invite_codes_created_by 21 + ON invite_codes(created_by_user);
+13
observability/prometheus.yml
··· 1 + global: 2 + scrape_interval: 15s 3 + evaluation_interval: 15s 4 + 5 + scrape_configs: 6 + - job_name: 'prometheus' 7 + static_configs: 8 + - targets: ['localhost:9090'] 9 + 10 + - job_name: 'bspds' 11 + static_configs: 12 + - targets: ['app:3000'] 13 + metrics_path: /metrics
+1
scripts/test-infra.sh
··· 114 114 export BSPDS_TEST_INFRA_READY="1" 115 115 export BSPDS_ALLOW_INSECURE_SECRETS="1" 116 116 export SKIP_IMPORT_VERIFICATION="true" 117 + export DISABLE_RATE_LIMITING="1" 117 118 EOF 118 119 119 120 echo ""
+2 -2
src/api/actor/profile.rs
··· 6 6 Json, 7 7 }; 8 8 use jacquard_repo::storage::BlockStore; 9 - use reqwest::Client; 9 + use crate::api::proxy_client::proxy_client; 10 10 use serde::{Deserialize, Serialize}; 11 11 use serde_json::{json, Value}; 12 12 use std::collections::HashMap; ··· 89 89 let target_url = format!("{}/xrpc/{}", appview_url, method); 90 90 info!("Proxying GET request to {}", target_url); 91 91 92 - let client = Client::new(); 92 + let client = proxy_client(); 93 93 let mut request_builder = client.get(&target_url).query(params); 94 94 95 95 if let Some(auth) = auth_header {
+2 -2
src/api/identity/account.rs
··· 1 1 use super::did::verify_did_web; 2 - use crate::state::AppState; 2 + use crate::state::{AppState, RateLimitKind}; 3 3 use axum::{ 4 4 Json, 5 5 extract::State, ··· 64 64 info!("create_account called"); 65 65 66 66 let client_ip = extract_client_ip(&headers); 67 - if state.rate_limiters.account_creation.check_key(&client_ip).is_err() { 67 + if !state.check_rate_limit(RateLimitKind::AccountCreation, &client_ip).await { 68 68 warn!(ip = %client_ip, "Account creation rate limit exceeded"); 69 69 return ( 70 70 StatusCode::TOO_MANY_REQUESTS,
+2 -2
src/api/proxy.rs
··· 5 5 http::{HeaderMap, Method, StatusCode}, 6 6 response::{IntoResponse, Response}, 7 7 }; 8 - use reqwest::Client; 8 + use crate::api::proxy_client::proxy_client; 9 9 use std::collections::HashMap; 10 10 use tracing::{error, info}; 11 11 ··· 36 36 37 37 info!("Proxying {} request to {}", method_verb, target_url); 38 38 39 - let client = Client::new(); 39 + let client = proxy_client(); 40 40 41 41 let mut request_builder = client.request(method_verb, &target_url).query(&params); 42 42
+43 -20
src/api/read_after_write.rs
··· 8 8 response::{IntoResponse, Response}, 9 9 Json, 10 10 }; 11 + use bytes::Bytes; 11 12 use chrono::{DateTime, Utc}; 13 + use cid::Cid; 12 14 use jacquard_repo::storage::BlockStore; 13 15 use serde::{Deserialize, Serialize}; 14 16 use serde_json::Value; ··· 137 139 return Ok(result); 138 140 } 139 141 140 - for row in rows { 141 - result.count += 1; 142 + struct RowData { 143 + cid_str: String, 144 + collection: String, 145 + rkey: String, 146 + created_at: DateTime<Utc>, 147 + } 148 + 149 + let mut row_data: Vec<RowData> = Vec::with_capacity(rows.len()); 150 + let mut cids: Vec<Cid> = Vec::with_capacity(rows.len()); 151 + 152 + for row in &rows { 153 + if let Ok(cid) = row.record_cid.parse::<Cid>() { 154 + cids.push(cid); 155 + row_data.push(RowData { 156 + cid_str: row.record_cid.clone(), 157 + collection: row.collection.clone(), 158 + rkey: row.rkey.clone(), 159 + created_at: row.created_at, 160 + }); 161 + } 162 + } 142 163 143 - let cid: cid::Cid = match row.record_cid.parse() { 144 - Ok(c) => c, 145 - Err(_) => continue, 146 - }; 164 + let blocks: Vec<Option<Bytes>> = state 165 + .block_store 166 + .get_many(&cids) 167 + .await 168 + .map_err(|e| format!("Error fetching blocks: {}", e))?; 147 169 148 - let block_bytes = match state.block_store.get(&cid).await { 149 - Ok(Some(b)) => b, 150 - _ => continue, 170 + for (data, block_opt) in row_data.into_iter().zip(blocks.into_iter()) { 171 + let block_bytes = match block_opt { 172 + Some(b) => b, 173 + None => continue, 151 174 }; 152 175 153 - let uri = format!("at://{}/{}/{}", did, row.collection, row.rkey); 154 - let indexed_at = row.created_at; 176 + result.count += 1; 177 + let uri = format!("at://{}/{}/{}", did, data.collection, data.rkey); 155 178 156 - if row.collection == "app.bsky.actor.profile" && row.rkey == "self" { 179 + if data.collection == "app.bsky.actor.profile" && data.rkey == "self" { 157 180 if let Ok(record) = serde_ipld_dagcbor::from_slice::<ProfileRecord>(&block_bytes) { 158 181 result.profile = Some(RecordDescript { 159 182 uri, 160 - cid: row.record_cid, 161 - indexed_at, 183 + cid: data.cid_str, 184 + indexed_at: data.created_at, 162 185 record, 163 186 }); 164 187 } 165 - } else if row.collection == "app.bsky.feed.post" { 188 + } else if data.collection == "app.bsky.feed.post" { 166 189 if let Ok(record) = serde_ipld_dagcbor::from_slice::<PostRecord>(&block_bytes) { 167 190 result.posts.push(RecordDescript { 168 191 uri, 169 - cid: row.record_cid, 170 - indexed_at, 192 + cid: data.cid_str, 193 + indexed_at: data.created_at, 171 194 record, 172 195 }); 173 196 } 174 - } else if row.collection == "app.bsky.feed.like" { 197 + } else if data.collection == "app.bsky.feed.like" { 175 198 if let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) { 176 199 result.likes.push(RecordDescript { 177 200 uri, 178 - cid: row.record_cid, 179 - indexed_at, 201 + cid: data.cid_str, 202 + indexed_at: data.created_at, 180 203 record, 181 204 }); 182 205 }
+45 -13
src/api/repo/blob.rs
··· 83 83 84 84 let storage_key = format!("blobs/{}", cid_str); 85 85 86 - if let Err(e) = state.blob_store.put(&storage_key, &data).await { 87 - error!("Failed to upload blob to storage: {:?}", e); 88 - return ( 89 - StatusCode::INTERNAL_SERVER_ERROR, 90 - Json(json!({"error": "InternalError", "message": "Failed to store blob"})), 91 - ) 92 - .into_response(); 93 - } 94 - 95 86 let user_query = sqlx::query!("SELECT id FROM users WHERE did = $1", did) 96 87 .fetch_optional(&state.db) 97 88 .await; ··· 107 98 } 108 99 }; 109 100 101 + let mut tx = match state.db.begin().await { 102 + Ok(tx) => tx, 103 + Err(e) => { 104 + error!("Failed to begin transaction: {:?}", e); 105 + return ( 106 + StatusCode::INTERNAL_SERVER_ERROR, 107 + Json(json!({"error": "InternalError"})), 108 + ) 109 + .into_response(); 110 + } 111 + }; 112 + 110 113 let insert = sqlx::query!( 111 - "INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING", 114 + "INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING RETURNING cid", 112 115 cid_str, 113 116 mime_type, 114 117 size, 115 118 user_id, 116 119 storage_key 117 120 ) 118 - .execute(&state.db) 121 + .fetch_optional(&mut *tx) 119 122 .await; 120 123 121 - if let Err(e) = insert { 122 - error!("Failed to insert blob record: {:?}", e); 124 + let was_inserted = match insert { 125 + Ok(Some(_)) => true, 126 + Ok(None) => false, 127 + Err(e) => { 128 + error!("Failed to insert blob record: {:?}", e); 129 + return ( 130 + StatusCode::INTERNAL_SERVER_ERROR, 131 + Json(json!({"error": "InternalError"})), 132 + ) 133 + .into_response(); 134 + } 135 + }; 136 + 137 + if was_inserted { 138 + if let Err(e) = state.blob_store.put_bytes(&storage_key, bytes::Bytes::from(data)).await { 139 + error!("Failed to upload blob to storage: {:?}", e); 140 + return ( 141 + StatusCode::INTERNAL_SERVER_ERROR, 142 + Json(json!({"error": "InternalError", "message": "Failed to store blob"})), 143 + ) 144 + .into_response(); 145 + } 146 + } 147 + 148 + if let Err(e) = tx.commit().await { 149 + error!("Failed to commit blob transaction: {:?}", e); 150 + if was_inserted { 151 + if let Err(cleanup_err) = state.blob_store.delete(&storage_key).await { 152 + error!("Failed to cleanup orphaned blob {}: {:?}", storage_key, cleanup_err); 153 + } 154 + } 123 155 return ( 124 156 StatusCode::INTERNAL_SERVER_ERROR, 125 157 Json(json!({"error": "InternalError"})),
+27 -6
src/api/repo/record/read.rs
··· 9 9 use jacquard_repo::storage::BlockStore; 10 10 use serde::{Deserialize, Serialize}; 11 11 use serde_json::json; 12 + use std::collections::HashMap; 12 13 use std::str::FromStr; 13 14 use tracing::error; 14 15 ··· 232 233 } 233 234 }; 234 235 235 - let mut records = Vec::new(); 236 - let mut last_rkey = None; 236 + let last_rkey = rows.last().map(|(rkey, _)| rkey.clone()); 237 237 238 - for (rkey, cid_str) in rows { 239 - last_rkey = Some(rkey.clone()); 238 + let mut cid_to_rkey: HashMap<Cid, (String, String)> = HashMap::new(); 239 + let mut cids: Vec<Cid> = Vec::with_capacity(rows.len()); 240 240 241 - if let Ok(cid) = Cid::from_str(&cid_str) { 242 - if let Ok(Some(block)) = state.block_store.get(&cid).await { 241 + for (rkey, cid_str) in &rows { 242 + if let Ok(cid) = Cid::from_str(cid_str) { 243 + cid_to_rkey.insert(cid, (rkey.clone(), cid_str.clone())); 244 + cids.push(cid); 245 + } 246 + } 247 + 248 + let blocks = match state.block_store.get_many(&cids).await { 249 + Ok(b) => b, 250 + Err(e) => { 251 + error!("Error fetching blocks: {:?}", e); 252 + return ( 253 + StatusCode::INTERNAL_SERVER_ERROR, 254 + Json(json!({"error": "InternalError"})), 255 + ) 256 + .into_response(); 257 + } 258 + }; 259 + 260 + let mut records = Vec::new(); 261 + for (cid, block_opt) in cids.iter().zip(blocks.into_iter()) { 262 + if let Some(block) = block_opt { 263 + if let Some((rkey, cid_str)) = cid_to_rkey.get(cid) { 243 264 if let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) { 244 265 records.push(json!({ 245 266 "uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey),
+49 -21
src/api/repo/record/utils.rs
··· 92 92 .map_err(|e| format!("DB Error (repos): {}", e))?; 93 93 94 94 let rev_str = rev.to_string(); 95 + 96 + let mut upsert_collections: Vec<String> = Vec::new(); 97 + let mut upsert_rkeys: Vec<String> = Vec::new(); 98 + let mut upsert_cids: Vec<String> = Vec::new(); 99 + 100 + let mut delete_collections: Vec<String> = Vec::new(); 101 + let mut delete_rkeys: Vec<String> = Vec::new(); 102 + 95 103 for op in &ops { 96 104 match op { 97 105 RecordOp::Create { collection, rkey, cid } | RecordOp::Update { collection, rkey, cid } => { 98 - sqlx::query!( 99 - "INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) VALUES ($1, $2, $3, $4, $5) 100 - ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4, repo_rev = $5, created_at = NOW()", 101 - user_id, 102 - collection, 103 - rkey, 104 - cid.to_string(), 105 - rev_str 106 - ) 107 - .execute(&mut *tx) 108 - .await 109 - .map_err(|e| format!("DB Error (records): {}", e))?; 106 + upsert_collections.push(collection.clone()); 107 + upsert_rkeys.push(rkey.clone()); 108 + upsert_cids.push(cid.to_string()); 110 109 } 111 110 RecordOp::Delete { collection, rkey } => { 112 - sqlx::query!( 113 - "DELETE FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3", 114 - user_id, 115 - collection, 116 - rkey 117 - ) 118 - .execute(&mut *tx) 119 - .await 120 - .map_err(|e| format!("DB Error (records): {}", e))?; 111 + delete_collections.push(collection.clone()); 112 + delete_rkeys.push(rkey.clone()); 121 113 } 122 114 } 115 + } 116 + 117 + if !upsert_collections.is_empty() { 118 + sqlx::query!( 119 + r#" 120 + INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) 121 + SELECT $1, collection, rkey, record_cid, $5 122 + FROM UNNEST($2::text[], $3::text[], $4::text[]) AS t(collection, rkey, record_cid) 123 + ON CONFLICT (repo_id, collection, rkey) DO UPDATE 124 + SET record_cid = EXCLUDED.record_cid, repo_rev = EXCLUDED.repo_rev, created_at = NOW() 125 + "#, 126 + user_id, 127 + &upsert_collections, 128 + &upsert_rkeys, 129 + &upsert_cids, 130 + rev_str 131 + ) 132 + .execute(&mut *tx) 133 + .await 134 + .map_err(|e| format!("DB Error (records batch upsert): {}", e))?; 135 + } 136 + 137 + if !delete_collections.is_empty() { 138 + sqlx::query!( 139 + r#" 140 + DELETE FROM records 141 + WHERE repo_id = $1 142 + AND (collection, rkey) IN (SELECT * FROM UNNEST($2::text[], $3::text[])) 143 + "#, 144 + user_id, 145 + &delete_collections, 146 + &delete_rkeys 147 + ) 148 + .execute(&mut *tx) 149 + .await 150 + .map_err(|e| format!("DB Error (records batch delete): {}", e))?; 123 151 } 124 152 125 153 let ops_json = ops.iter().map(|op| {
+10 -16
src/api/server/app_password.rs
··· 1 1 use crate::api::ApiError; 2 2 use crate::auth::BearerAuth; 3 - use crate::state::AppState; 3 + use crate::state::{AppState, RateLimitKind}; 4 4 use crate::util::get_user_id_by_did; 5 5 use axum::{ 6 6 Json, ··· 82 82 Json(input): Json<CreateAppPasswordInput>, 83 83 ) -> Response { 84 84 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 85 - if !state.distributed_rate_limiter.check_rate_limit( 86 - &format!("app_password:{}", client_ip), 87 - 10, 88 - 60_000, 89 - ).await { 90 - if state.rate_limiters.app_password.check_key(&client_ip).is_err() { 91 - warn!(ip = %client_ip, "App password creation rate limit exceeded"); 92 - return ( 93 - axum::http::StatusCode::TOO_MANY_REQUESTS, 94 - Json(json!({ 95 - "error": "RateLimitExceeded", 96 - "message": "Too many requests. Please try again later." 97 - })), 98 - ).into_response(); 99 - } 85 + if !state.check_rate_limit(RateLimitKind::AppPassword, &client_ip).await { 86 + warn!(ip = %client_ip, "App password creation rate limit exceeded"); 87 + return ( 88 + axum::http::StatusCode::TOO_MANY_REQUESTS, 89 + Json(json!({ 90 + "error": "RateLimitExceeded", 91 + "message": "Too many requests. Please try again later." 92 + })), 93 + ).into_response(); 100 94 } 101 95 102 96 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
+19 -31
src/api/server/email.rs
··· 1 1 use crate::api::ApiError; 2 - use crate::state::AppState; 2 + use crate::state::{AppState, RateLimitKind}; 3 3 use axum::{ 4 4 Json, 5 5 extract::State, ··· 27 27 Json(input): Json<RequestEmailUpdateInput>, 28 28 ) -> Response { 29 29 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 30 - if !state.distributed_rate_limiter.check_rate_limit( 31 - &format!("email_update:{}", client_ip), 32 - 5, 33 - 3_600_000, 34 - ).await { 35 - if state.rate_limiters.email_update.check_key(&client_ip).is_err() { 36 - warn!(ip = %client_ip, "Email update rate limit exceeded"); 37 - return ( 38 - StatusCode::TOO_MANY_REQUESTS, 39 - Json(json!({ 40 - "error": "RateLimitExceeded", 41 - "message": "Too many requests. Please try again later." 42 - })), 43 - ).into_response(); 44 - } 30 + if !state.check_rate_limit(RateLimitKind::EmailUpdate, &client_ip).await { 31 + warn!(ip = %client_ip, "Email update rate limit exceeded"); 32 + return ( 33 + StatusCode::TOO_MANY_REQUESTS, 34 + Json(json!({ 35 + "error": "RateLimitExceeded", 36 + "message": "Too many requests. Please try again later." 37 + })), 38 + ).into_response(); 45 39 } 46 40 47 41 let token = match crate::auth::extract_bearer_token_from_header( ··· 154 148 Json(input): Json<ConfirmEmailInput>, 155 149 ) -> Response { 156 150 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 157 - if !state.distributed_rate_limiter.check_rate_limit( 158 - &format!("confirm_email:{}", client_ip), 159 - 10, 160 - 60_000, 161 - ).await { 162 - if state.rate_limiters.app_password.check_key(&client_ip).is_err() { 163 - warn!(ip = %client_ip, "Confirm email rate limit exceeded"); 164 - return ( 165 - StatusCode::TOO_MANY_REQUESTS, 166 - Json(json!({ 167 - "error": "RateLimitExceeded", 168 - "message": "Too many requests. Please try again later." 169 - })), 170 - ).into_response(); 171 - } 151 + if !state.check_rate_limit(RateLimitKind::AppPassword, &client_ip).await { 152 + warn!(ip = %client_ip, "Confirm email rate limit exceeded"); 153 + return ( 154 + StatusCode::TOO_MANY_REQUESTS, 155 + Json(json!({ 156 + "error": "RateLimitExceeded", 157 + "message": "Too many requests. Please try again later." 158 + })), 159 + ).into_response(); 172 160 } 173 161 174 162 let token = match crate::auth::extract_bearer_token_from_header(
+51 -18
src/api/server/password.rs
··· 1 - use crate::state::AppState; 1 + use crate::state::{AppState, RateLimitKind}; 2 2 use axum::{ 3 3 Json, 4 4 extract::State, ··· 42 42 Json(input): Json<RequestPasswordResetInput>, 43 43 ) -> Response { 44 44 let client_ip = extract_client_ip(&headers); 45 - if state.rate_limiters.password_reset.check_key(&client_ip).is_err() { 45 + if !state.check_rate_limit(RateLimitKind::PasswordReset, &client_ip).await { 46 46 warn!(ip = %client_ip, "Password reset rate limit exceeded"); 47 47 return ( 48 48 StatusCode::TOO_MANY_REQUESTS, ··· 128 128 Json(input): Json<ResetPasswordInput>, 129 129 ) -> Response { 130 130 let client_ip = extract_client_ip(&headers); 131 - if !state.distributed_rate_limiter.check_rate_limit( 132 - &format!("reset_password:{}", client_ip), 133 - 10, 134 - 60_000, 135 - ).await { 136 - if state.rate_limiters.reset_password.check_key(&client_ip).is_err() { 137 - warn!(ip = %client_ip, "Reset password rate limit exceeded"); 138 - return ( 139 - StatusCode::TOO_MANY_REQUESTS, 140 - Json(json!({ 141 - "error": "RateLimitExceeded", 142 - "message": "Too many requests. Please try again later." 143 - })), 144 - ).into_response(); 145 - } 131 + if !state.check_rate_limit(RateLimitKind::ResetPassword, &client_ip).await { 132 + warn!(ip = %client_ip, "Reset password rate limit exceeded"); 133 + return ( 134 + StatusCode::TOO_MANY_REQUESTS, 135 + Json(json!({ 136 + "error": "RateLimitExceeded", 137 + "message": "Too many requests. Please try again later." 138 + })), 139 + ).into_response(); 146 140 } 147 141 148 142 let token = input.token.trim(); ··· 259 253 .into_response(); 260 254 } 261 255 262 - if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id) 256 + let user_did = match sqlx::query_scalar!( 257 + "SELECT did FROM users WHERE id = $1", 258 + user_id 259 + ) 260 + .fetch_one(&mut *tx) 261 + .await 262 + { 263 + Ok(did) => did, 264 + Err(e) => { 265 + error!("Failed to get DID for user {}: {:?}", user_id, e); 266 + return ( 267 + StatusCode::INTERNAL_SERVER_ERROR, 268 + Json(json!({"error": "InternalError"})), 269 + ) 270 + .into_response(); 271 + } 272 + }; 273 + 274 + let session_jtis: Vec<String> = match sqlx::query_scalar!( 275 + "SELECT access_jti FROM session_tokens WHERE did = $1", 276 + user_did 277 + ) 278 + .fetch_all(&mut *tx) 279 + .await 280 + { 281 + Ok(jtis) => jtis, 282 + Err(e) => { 283 + error!("Failed to fetch session JTIs: {:?}", e); 284 + vec![] 285 + } 286 + }; 287 + 288 + if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did) 263 289 .execute(&mut *tx) 264 290 .await 265 291 { ··· 278 304 Json(json!({"error": "InternalError"})), 279 305 ) 280 306 .into_response(); 307 + } 308 + 309 + for jti in session_jtis { 310 + let cache_key = format!("auth:session:{}:{}", user_did, jti); 311 + if let Err(e) = state.cache.delete(&cache_key).await { 312 + warn!("Failed to invalidate session cache for {}: {:?}", cache_key, e); 313 + } 281 314 } 282 315 283 316 info!("Password reset completed for user {}", user_id);
+33 -25
src/api/server/session.rs
··· 1 1 use crate::api::ApiError; 2 2 use crate::auth::BearerAuth; 3 - use crate::state::AppState; 3 + use crate::state::{AppState, RateLimitKind}; 4 4 use axum::{ 5 5 Json, 6 6 extract::State, ··· 52 52 info!("create_session called"); 53 53 54 54 let client_ip = extract_client_ip(&headers); 55 - if state.rate_limiters.login.check_key(&client_ip).is_err() { 55 + if !state.check_rate_limit(RateLimitKind::Login, &client_ip).await { 56 56 warn!(ip = %client_ip, "Login rate limit exceeded"); 57 57 return ( 58 58 StatusCode::TOO_MANY_REQUESTS, ··· 97 97 } 98 98 }; 99 99 100 - let password_valid = verify(&input.password, &row.password_hash).unwrap_or(false) 101 - || sqlx::query!("SELECT password_hash FROM app_passwords WHERE user_id = $1", row.id) 102 - .fetch_all(&state.db) 103 - .await 104 - .unwrap_or_default() 105 - .iter() 106 - .any(|app| verify(&input.password, &app.password_hash).unwrap_or(false)); 100 + let password_valid = if verify(&input.password, &row.password_hash).unwrap_or(false) { 101 + true 102 + } else { 103 + let app_passwords = sqlx::query!( 104 + "SELECT password_hash FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC LIMIT 20", 105 + row.id 106 + ) 107 + .fetch_all(&state.db) 108 + .await 109 + .unwrap_or_default(); 110 + 111 + app_passwords.iter().any(|app| verify(&input.password, &app.password_hash).unwrap_or(false)) 112 + }; 107 113 108 114 if !password_valid { 109 115 warn!("Password verification failed for login attempt"); ··· 204 210 Err(_) => return ApiError::AuthenticationFailed.into_response(), 205 211 }; 206 212 213 + let did = crate::auth::get_did_from_token(&token).ok(); 214 + 207 215 match sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti) 208 216 .execute(&state.db) 209 217 .await 210 218 { 211 - Ok(res) if res.rows_affected() > 0 => Json(json!({})).into_response(), 219 + Ok(res) if res.rows_affected() > 0 => { 220 + if let Some(did) = did { 221 + let session_cache_key = format!("auth:session:{}:{}", did, jti); 222 + let _ = state.cache.delete(&session_cache_key).await; 223 + } 224 + Json(json!({})).into_response() 225 + } 212 226 Ok(_) => ApiError::AuthenticationFailed.into_response(), 213 227 Err(e) => { 214 228 error!("Database error in delete_session: {:?}", e); ··· 222 236 headers: axum::http::HeaderMap, 223 237 ) -> Response { 224 238 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 225 - if !state.distributed_rate_limiter.check_rate_limit( 226 - &format!("refresh_session:{}", client_ip), 227 - 60, 228 - 60_000, 229 - ).await { 230 - if state.rate_limiters.refresh_session.check_key(&client_ip).is_err() { 231 - tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded"); 232 - return ( 233 - axum::http::StatusCode::TOO_MANY_REQUESTS, 234 - axum::Json(serde_json::json!({ 235 - "error": "RateLimitExceeded", 236 - "message": "Too many requests. Please try again later." 237 - })), 238 - ).into_response(); 239 - } 239 + if !state.check_rate_limit(RateLimitKind::RefreshSession, &client_ip).await { 240 + tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded"); 241 + return ( 242 + axum::http::StatusCode::TOO_MANY_REQUESTS, 243 + axum::Json(serde_json::json!({ 244 + "error": "RateLimitExceeded", 245 + "message": "Too many requests. Please try again later." 246 + })), 247 + ).into_response(); 240 248 } 241 249 242 250 let refresh_token = match crate::auth::extract_bearer_token_from_header(
+3 -3
src/auth/extractor.rs
··· 7 7 use serde_json::json; 8 8 9 9 use crate::state::AppState; 10 - use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token, validate_bearer_token_allow_deactivated}; 10 + use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated}; 11 11 12 12 pub struct BearerAuth(pub AuthenticatedUser); 13 13 ··· 110 110 111 111 let token = extract_bearer_token(auth_header)?; 112 112 113 - match validate_bearer_token(&state.db, token).await { 113 + match validate_bearer_token_cached(&state.db, &state.cache, token).await { 114 114 Ok(user) => Ok(BearerAuth(user)), 115 115 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 116 116 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), ··· 137 137 138 138 let token = extract_bearer_token(auth_header)?; 139 139 140 - match validate_bearer_token_allow_deactivated(&state.db, token).await { 140 + match validate_bearer_token_cached_allow_deactivated(&state.db, &state.cache, token).await { 141 141 Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 142 142 Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 143 143 Err(_) => Err(AuthError::AuthenticationFailed),
+119 -30
src/auth/mod.rs
··· 1 1 use serde::{Deserialize, Serialize}; 2 2 use sqlx::PgPool; 3 3 use std::fmt; 4 + use std::sync::Arc; 5 + use std::time::Duration; 6 + use crate::cache::Cache; 4 7 5 8 pub mod extractor; 6 9 pub mod token; ··· 15 18 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 16 19 }; 17 20 pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token}; 21 + 22 + const KEY_CACHE_TTL_SECS: u64 = 300; 23 + const SESSION_CACHE_TTL_SECS: u64 = 60; 18 24 19 25 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 20 26 pub enum TokenValidationError { ··· 45 51 db: &PgPool, 46 52 token: &str, 47 53 ) -> Result<AuthenticatedUser, TokenValidationError> { 48 - validate_bearer_token_with_options(db, token, false).await 54 + validate_bearer_token_with_options_internal(db, None, token, false).await 49 55 } 50 56 51 57 pub async fn validate_bearer_token_allow_deactivated( 52 58 db: &PgPool, 53 59 token: &str, 54 60 ) -> Result<AuthenticatedUser, TokenValidationError> { 55 - validate_bearer_token_with_options(db, token, true).await 61 + validate_bearer_token_with_options_internal(db, None, token, true).await 56 62 } 57 63 58 - async fn validate_bearer_token_with_options( 64 + pub async fn validate_bearer_token_cached( 59 65 db: &PgPool, 66 + cache: &Arc<dyn Cache>, 67 + token: &str, 68 + ) -> Result<AuthenticatedUser, TokenValidationError> { 69 + validate_bearer_token_with_options_internal(db, Some(cache), token, false).await 70 + } 71 + 72 + pub async fn validate_bearer_token_cached_allow_deactivated( 73 + db: &PgPool, 74 + cache: &Arc<dyn Cache>, 75 + token: &str, 76 + ) -> Result<AuthenticatedUser, TokenValidationError> { 77 + validate_bearer_token_with_options_internal(db, Some(cache), token, true).await 78 + } 79 + 80 + async fn validate_bearer_token_with_options_internal( 81 + db: &PgPool, 82 + cache: Option<&Arc<dyn Cache>>, 60 83 token: &str, 61 84 allow_deactivated: bool, 62 85 ) -> Result<AuthenticatedUser, TokenValidationError> { 63 86 let did_from_token = get_did_from_token(token).ok(); 64 87 65 88 if let Some(ref did) = did_from_token { 66 - if let Some(user) = sqlx::query!( 67 - "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref 68 - FROM users u 69 - JOIN user_keys k ON u.id = k.user_id 70 - WHERE u.did = $1", 71 - did 72 - ) 73 - .fetch_optional(db) 74 - .await 75 - .ok() 76 - .flatten() 77 - { 78 - if !allow_deactivated && user.deactivated_at.is_some() { 89 + let key_cache_key = format!("auth:key:{}", did); 90 + let mut cached_key: Option<Vec<u8>> = None; 91 + 92 + if let Some(c) = cache { 93 + cached_key = c.get_bytes(&key_cache_key).await; 94 + if cached_key.is_some() { 95 + crate::metrics::record_auth_cache_hit("key"); 96 + } else { 97 + crate::metrics::record_auth_cache_miss("key"); 98 + } 99 + } 100 + 101 + let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key { 102 + let user_status = sqlx::query!( 103 + "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1", 104 + did 105 + ) 106 + .fetch_optional(db) 107 + .await 108 + .ok() 109 + .flatten(); 110 + 111 + match user_status { 112 + Some(status) => (Some(key), status.deactivated_at, status.takedown_ref), 113 + None => (None, None, None), 114 + } 115 + } else { 116 + if let Some(user) = sqlx::query!( 117 + "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref 118 + FROM users u 119 + JOIN user_keys k ON u.id = k.user_id 120 + WHERE u.did = $1", 121 + did 122 + ) 123 + .fetch_optional(db) 124 + .await 125 + .ok() 126 + .flatten() 127 + { 128 + let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 129 + .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 130 + 131 + if let Some(c) = cache { 132 + let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await; 133 + } 134 + 135 + (Some(key), user.deactivated_at, user.takedown_ref) 136 + } else { 137 + (None, None, None) 138 + } 139 + }; 140 + 141 + if let Some(decrypted_key) = decrypted_key { 142 + if !allow_deactivated && deactivated_at.is_some() { 79 143 return Err(TokenValidationError::AccountDeactivated); 80 144 } 81 - if user.takedown_ref.is_some() { 145 + if takedown_ref.is_some() { 82 146 return Err(TokenValidationError::AccountTakedown); 83 147 } 84 148 85 - let decrypted_key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 86 - .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 87 - 88 149 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 89 - let session_exists = sqlx::query_scalar!( 90 - "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", 91 - did, 92 - token_data.claims.jti 93 - ) 94 - .fetch_optional(db) 95 - .await 96 - .ok() 97 - .flatten(); 150 + let jti = &token_data.claims.jti; 151 + let session_cache_key = format!("auth:session:{}:{}", did, jti); 152 + let mut session_valid = false; 98 153 99 - if session_exists.is_some() { 154 + if let Some(c) = cache { 155 + if let Some(cached_value) = c.get(&session_cache_key).await { 156 + session_valid = cached_value == "1"; 157 + crate::metrics::record_auth_cache_hit("session"); 158 + } else { 159 + crate::metrics::record_auth_cache_miss("session"); 160 + } 161 + } 162 + 163 + if !session_valid { 164 + let session_exists = sqlx::query_scalar!( 165 + "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()", 166 + did, 167 + jti 168 + ) 169 + .fetch_optional(db) 170 + .await 171 + .ok() 172 + .flatten(); 173 + 174 + session_valid = session_exists.is_some(); 175 + 176 + if session_valid { 177 + if let Some(c) = cache { 178 + let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await; 179 + } 180 + } 181 + } 182 + 183 + if session_valid { 100 184 return Ok(AuthenticatedUser { 101 185 did: did.clone(), 102 186 key_bytes: Some(decrypted_key), ··· 139 223 } 140 224 141 225 Err(TokenValidationError::AuthenticationFailed) 226 + } 227 + 228 + pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) { 229 + let key_cache_key = format!("auth:key:{}", did); 230 + let _ = cache.delete(&key_cache_key).await; 142 231 } 143 232 144 233 #[derive(Debug, Serialize, Deserialize)]
+10
src/cache/mod.rs
··· 1 1 use async_trait::async_trait; 2 + use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 2 3 use std::sync::Arc; 3 4 use std::time::Duration; 4 5 ··· 15 16 async fn get(&self, key: &str) -> Option<String>; 16 17 async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>; 17 18 async fn delete(&self, key: &str) -> Result<(), CacheError>; 19 + 20 + async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> { 21 + self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) 22 + } 23 + 24 + async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> { 25 + let encoded = BASE64.encode(value); 26 + self.set(key, &encoded, ttl).await 27 + } 18 28 } 19 29 20 30 #[derive(Clone)]
+4
src/lib.rs
··· 5 5 pub mod config; 6 6 pub mod crawlers; 7 7 pub mod image; 8 + pub mod metrics; 8 9 pub mod notifications; 9 10 pub mod oauth; 10 11 pub mod plc; ··· 18 19 19 20 use axum::{ 20 21 Router, 22 + middleware, 21 23 routing::{any, get, post}, 22 24 }; 23 25 use state::AppState; ··· 25 27 26 28 pub fn app(state: AppState) -> Router { 27 29 let router = Router::new() 30 + .route("/metrics", get(metrics::metrics_handler)) 28 31 .route("/health", get(api::server::health)) 29 32 .route("/xrpc/_health", get(api::server::health)) 30 33 .route("/robots.txt", get(api::server::robots_txt)) ··· 382 385 post(api::notification_prefs::update_notification_prefs), 383 386 ) 384 387 .route("/xrpc/{*method}", any(api::proxy::proxy_handler)) 388 + .layer(middleware::from_fn(metrics::metrics_middleware)) 385 389 .with_state(state); 386 390 387 391 let frontend_dir = std::env::var("FRONTEND_DIR")
+23 -3
src/main.rs
··· 12 12 dotenvy::dotenv().ok(); 13 13 tracing_subscriber::fmt::init(); 14 14 15 + bspds::metrics::init_metrics(); 16 + 15 17 match run().await { 16 18 Ok(()) => ExitCode::SUCCESS, 17 19 Err(e) => { ··· 25 27 let database_url = std::env::var("DATABASE_URL") 26 28 .map_err(|_| "DATABASE_URL environment variable must be set")?; 27 29 30 + let max_connections: u32 = std::env::var("DATABASE_MAX_CONNECTIONS") 31 + .ok() 32 + .and_then(|v| v.parse().ok()) 33 + .unwrap_or(100); 34 + let min_connections: u32 = std::env::var("DATABASE_MIN_CONNECTIONS") 35 + .ok() 36 + .and_then(|v| v.parse().ok()) 37 + .unwrap_or(10); 38 + let acquire_timeout_secs: u64 = std::env::var("DATABASE_ACQUIRE_TIMEOUT_SECS") 39 + .ok() 40 + .and_then(|v| v.parse().ok()) 41 + .unwrap_or(10); 42 + 43 + info!( 44 + "Configuring database pool: max={}, min={}, acquire_timeout={}s", 45 + max_connections, min_connections, acquire_timeout_secs 46 + ); 47 + 28 48 let pool = sqlx::postgres::PgPoolOptions::new() 29 - .max_connections(20) 30 - .min_connections(2) 31 - .acquire_timeout(std::time::Duration::from_secs(10)) 49 + .max_connections(max_connections) 50 + .min_connections(min_connections) 51 + .acquire_timeout(std::time::Duration::from_secs(acquire_timeout_secs)) 32 52 .idle_timeout(std::time::Duration::from_secs(300)) 33 53 .max_lifetime(std::time::Duration::from_secs(1800)) 34 54 .connect(&database_url)
+212
src/metrics.rs
··· 1 + use axum::{ 2 + body::Body, 3 + http::{Request, StatusCode}, 4 + middleware::Next, 5 + response::{IntoResponse, Response}, 6 + }; 7 + use metrics::{counter, gauge, histogram}; 8 + use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; 9 + use std::sync::OnceLock; 10 + use std::time::Instant; 11 + 12 + static PROMETHEUS_HANDLE: OnceLock<PrometheusHandle> = OnceLock::new(); 13 + 14 + pub fn init_metrics() -> PrometheusHandle { 15 + let builder = PrometheusBuilder::new(); 16 + let handle = builder 17 + .install_recorder() 18 + .expect("failed to install Prometheus recorder"); 19 + 20 + PROMETHEUS_HANDLE.set(handle.clone()).ok(); 21 + 22 + describe_metrics(); 23 + 24 + handle 25 + } 26 + 27 + fn describe_metrics() { 28 + metrics::describe_counter!( 29 + "bspds_http_requests_total", 30 + "Total number of HTTP requests" 31 + ); 32 + metrics::describe_histogram!( 33 + "bspds_http_request_duration_seconds", 34 + "HTTP request duration in seconds" 35 + ); 36 + metrics::describe_counter!( 37 + "bspds_auth_cache_hits_total", 38 + "Total number of authentication cache hits" 39 + ); 40 + metrics::describe_counter!( 41 + "bspds_auth_cache_misses_total", 42 + "Total number of authentication cache misses" 43 + ); 44 + metrics::describe_gauge!( 45 + "bspds_firehose_subscribers", 46 + "Number of active firehose WebSocket subscribers" 47 + ); 48 + metrics::describe_counter!( 49 + "bspds_firehose_events_total", 50 + "Total number of firehose events published" 51 + ); 52 + metrics::describe_counter!( 53 + "bspds_block_operations_total", 54 + "Total number of block store operations" 55 + ); 56 + metrics::describe_counter!( 57 + "bspds_s3_operations_total", 58 + "Total number of S3/blob storage operations" 59 + ); 60 + metrics::describe_gauge!( 61 + "bspds_notification_queue_size", 62 + "Current size of the notification queue" 63 + ); 64 + metrics::describe_counter!( 65 + "bspds_rate_limit_rejections_total", 66 + "Total number of rate limit rejections" 67 + ); 68 + metrics::describe_counter!( 69 + "bspds_db_queries_total", 70 + "Total number of database queries" 71 + ); 72 + metrics::describe_histogram!( 73 + "bspds_db_query_duration_seconds", 74 + "Database query duration in seconds" 75 + ); 76 + } 77 + 78 + pub async fn metrics_handler() -> impl IntoResponse { 79 + match PROMETHEUS_HANDLE.get() { 80 + Some(handle) => { 81 + let metrics = handle.render(); 82 + (StatusCode::OK, [("content-type", "text/plain; version=0.0.4")], metrics) 83 + } 84 + None => ( 85 + StatusCode::INTERNAL_SERVER_ERROR, 86 + [("content-type", "text/plain")], 87 + "Metrics not initialized".to_string(), 88 + ), 89 + } 90 + } 91 + 92 + pub async fn metrics_middleware(request: Request<Body>, next: Next) -> Response { 93 + let start = Instant::now(); 94 + let method = request.method().to_string(); 95 + let path = normalize_path(request.uri().path()); 96 + 97 + let response = next.run(request).await; 98 + 99 + let duration = start.elapsed().as_secs_f64(); 100 + let status = response.status().as_u16().to_string(); 101 + 102 + counter!( 103 + "bspds_http_requests_total", 104 + "method" => method.clone(), 105 + "path" => path.clone(), 106 + "status" => status.clone() 107 + ) 108 + .increment(1); 109 + 110 + histogram!( 111 + "bspds_http_request_duration_seconds", 112 + "method" => method, 113 + "path" => path 114 + ) 115 + .record(duration); 116 + 117 + response 118 + } 119 + 120 + fn normalize_path(path: &str) -> String { 121 + if path.starts_with("/xrpc/") { 122 + if let Some(method) = path.strip_prefix("/xrpc/") { 123 + if let Some(q) = method.find('?') { 124 + return format!("/xrpc/{}", &method[..q]); 125 + } 126 + return path.to_string(); 127 + } 128 + } 129 + 130 + if path.starts_with("/u/") && path.ends_with("/did.json") { 131 + return "/u/{handle}/did.json".to_string(); 132 + } 133 + 134 + if path.starts_with("/oauth/") { 135 + return path.to_string(); 136 + } 137 + 138 + path.to_string() 139 + } 140 + 141 + pub fn record_auth_cache_hit(cache_type: &str) { 142 + counter!("bspds_auth_cache_hits_total", "cache_type" => cache_type.to_string()).increment(1); 143 + } 144 + 145 + pub fn record_auth_cache_miss(cache_type: &str) { 146 + counter!("bspds_auth_cache_misses_total", "cache_type" => cache_type.to_string()).increment(1); 147 + } 148 + 149 + pub fn set_firehose_subscribers(count: usize) { 150 + gauge!("bspds_firehose_subscribers").set(count as f64); 151 + } 152 + 153 + pub fn increment_firehose_subscribers() { 154 + counter!("bspds_firehose_events_total").increment(1); 155 + } 156 + 157 + pub fn record_firehose_event() { 158 + counter!("bspds_firehose_events_total").increment(1); 159 + } 160 + 161 + pub fn record_block_operation(op_type: &str) { 162 + counter!("bspds_block_operations_total", "op_type" => op_type.to_string()).increment(1); 163 + } 164 + 165 + pub fn record_s3_operation(op_type: &str, status: &str) { 166 + counter!( 167 + "bspds_s3_operations_total", 168 + "op_type" => op_type.to_string(), 169 + "status" => status.to_string() 170 + ) 171 + .increment(1); 172 + } 173 + 174 + pub fn set_notification_queue_size(size: usize) { 175 + gauge!("bspds_notification_queue_size").set(size as f64); 176 + } 177 + 178 + pub fn record_rate_limit_rejection(limiter: &str) { 179 + counter!("bspds_rate_limit_rejections_total", "limiter" => limiter.to_string()).increment(1); 180 + } 181 + 182 + pub fn record_db_query(query_type: &str, duration_seconds: f64) { 183 + counter!("bspds_db_queries_total", "query_type" => query_type.to_string()).increment(1); 184 + histogram!( 185 + "bspds_db_query_duration_seconds", 186 + "query_type" => query_type.to_string() 187 + ) 188 + .record(duration_seconds); 189 + } 190 + 191 + #[cfg(test)] 192 + mod tests { 193 + use super::*; 194 + 195 + #[test] 196 + fn test_normalize_path() { 197 + assert_eq!( 198 + normalize_path("/xrpc/com.atproto.repo.getRecord"), 199 + "/xrpc/com.atproto.repo.getRecord" 200 + ); 201 + assert_eq!( 202 + normalize_path("/xrpc/com.atproto.repo.getRecord?foo=bar"), 203 + "/xrpc/com.atproto.repo.getRecord" 204 + ); 205 + assert_eq!( 206 + normalize_path("/u/alice.example.com/did.json"), 207 + "/u/{handle}/did.json" 208 + ); 209 + assert_eq!(normalize_path("/oauth/token"), "/oauth/token"); 210 + assert_eq!(normalize_path("/health"), "/health"); 211 + } 212 + }
+12 -2
src/notifications/service.rs
··· 21 21 22 22 impl NotificationService { 23 23 pub fn new(db: PgPool) -> Self { 24 + let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") 25 + .ok() 26 + .and_then(|v| v.parse().ok()) 27 + .unwrap_or(1000); 28 + 29 + let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE") 30 + .ok() 31 + .and_then(|v| v.parse().ok()) 32 + .unwrap_or(100); 33 + 24 34 Self { 25 35 db, 26 36 senders: HashMap::new(), 27 - poll_interval: Duration::from_secs(5), 28 - batch_size: 10, 37 + poll_interval: Duration::from_millis(poll_interval_ms), 38 + batch_size, 29 39 } 30 40 } 31 41
+3 -3
src/oauth/endpoints/authorize.rs
··· 9 9 use subtle::ConstantTimeEq; 10 10 use urlencoding::encode as url_encode; 11 11 12 - use crate::state::AppState; 12 + use crate::state::{AppState, RateLimitKind}; 13 13 use crate::oauth::{Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, db, templates}; 14 14 use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code}; 15 15 ··· 273 273 let json_response = wants_json(&headers); 274 274 275 275 let client_ip = extract_client_ip(&headers); 276 - if state.rate_limiters.oauth_authorize.check_key(&client_ip).is_err() { 276 + if !state.check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip).await { 277 277 tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded"); 278 278 if json_response { 279 279 return ( ··· 761 761 Form(form): Form<Authorize2faSubmit>, 762 762 ) -> Response { 763 763 let client_ip = extract_client_ip(&headers); 764 - if state.rate_limiters.oauth_authorize.check_key(&client_ip).is_err() { 764 + if !state.check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip).await { 765 765 tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 766 766 return ( 767 767 axum::http::StatusCode::TOO_MANY_REQUESTS,
+4 -10
src/oauth/endpoints/par.rs
··· 6 6 use chrono::{Duration, Utc}; 7 7 use serde::{Deserialize, Serialize}; 8 8 9 - use crate::state::AppState; 9 + use crate::state::{AppState, RateLimitKind}; 10 10 use crate::oauth::{ 11 11 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 12 12 client::ClientMetadataCache, ··· 54 54 Form(request): Form<ParRequest>, 55 55 ) -> Result<Json<ParResponse>, OAuthError> { 56 56 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 57 - if !state.distributed_rate_limiter.check_rate_limit( 58 - &format!("oauth_par:{}", client_ip), 59 - 30, 60 - 60_000, 61 - ).await { 62 - if state.rate_limiters.oauth_par.check_key(&client_ip).is_err() { 63 - tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 64 - return Err(OAuthError::RateLimited); 65 - } 57 + if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await { 58 + tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 59 + return Err(OAuthError::RateLimited); 66 60 } 67 61 68 62 if request.response_type != "code" {
+7 -19
src/oauth/endpoints/token/introspect.rs
··· 4 4 use chrono::Utc; 5 5 use serde::{Deserialize, Serialize}; 6 6 7 - use crate::state::AppState; 7 + use crate::state::{AppState, RateLimitKind}; 8 8 use crate::oauth::{OAuthError, db}; 9 9 10 10 use super::helpers::extract_token_claims; ··· 22 22 Form(request): Form<RevokeRequest>, 23 23 ) -> Result<StatusCode, OAuthError> { 24 24 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 25 - if !state.distributed_rate_limiter.check_rate_limit( 26 - &format!("oauth_revoke:{}", client_ip), 27 - 30, 28 - 60_000, 29 - ).await { 30 - if state.rate_limiters.oauth_introspect.check_key(&client_ip).is_err() { 31 - tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded"); 32 - return Err(OAuthError::RateLimited); 33 - } 25 + if !state.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip).await { 26 + tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded"); 27 + return Err(OAuthError::RateLimited); 34 28 } 35 29 36 30 if let Some(token) = &request.token { ··· 84 78 Form(request): Form<IntrospectRequest>, 85 79 ) -> Result<Json<IntrospectResponse>, OAuthError> { 86 80 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 87 - if !state.distributed_rate_limiter.check_rate_limit( 88 - &format!("oauth_introspect:{}", client_ip), 89 - 30, 90 - 60_000, 91 - ).await { 92 - if state.rate_limiters.oauth_introspect.check_key(&client_ip).is_err() { 93 - tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded"); 94 - return Err(OAuthError::RateLimited); 95 - } 81 + if !state.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip).await { 82 + tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded"); 83 + return Err(OAuthError::RateLimited); 96 84 } 97 85 98 86 let inactive_response = IntrospectResponse {
+2 -2
src/oauth/endpoints/token/mod.rs
··· 9 9 http::HeaderMap, 10 10 }; 11 11 12 - use crate::state::AppState; 12 + use crate::state::{AppState, RateLimitKind}; 13 13 use crate::oauth::OAuthError; 14 14 15 15 pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; ··· 41 41 Form(request): Form<TokenRequest>, 42 42 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 43 43 let client_ip = extract_client_ip(&headers); 44 - if state.rate_limiters.oauth_token.check_key(&client_ip).is_err() { 44 + if !state.check_rate_limit(RateLimitKind::OAuthToken, &client_ip).await { 45 45 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 46 46 return Err(OAuthError::InvalidRequest( 47 47 "Too many requests. Please try again later.".to_string(),
+24 -1
src/plc/mod.rs
··· 5 5 use serde_json::{json, Value}; 6 6 use sha2::{Digest, Sha256}; 7 7 use std::collections::HashMap; 8 + use std::time::Duration; 8 9 use thiserror::Error; 9 10 10 11 #[derive(Error, Debug)] ··· 21 22 Serialization(String), 22 23 #[error("Signing error: {0}")] 23 24 Signing(String), 25 + #[error("Request timeout")] 26 + Timeout, 27 + #[error("Service unavailable (circuit breaker open)")] 28 + CircuitBreakerOpen, 24 29 } 25 30 26 31 #[derive(Debug, Clone, Serialize, Deserialize)] ··· 82 87 std::env::var("PLC_DIRECTORY_URL") 83 88 .unwrap_or_else(|_| "https://plc.directory".to_string()) 84 89 }); 90 + 91 + let timeout_secs: u64 = std::env::var("PLC_TIMEOUT_SECS") 92 + .ok() 93 + .and_then(|v| v.parse().ok()) 94 + .unwrap_or(10); 95 + 96 + let connect_timeout_secs: u64 = std::env::var("PLC_CONNECT_TIMEOUT_SECS") 97 + .ok() 98 + .and_then(|v| v.parse().ok()) 99 + .unwrap_or(5); 100 + 101 + let client = Client::builder() 102 + .timeout(Duration::from_secs(timeout_secs)) 103 + .connect_timeout(Duration::from_secs(connect_timeout_secs)) 104 + .pool_max_idle_per_host(5) 105 + .build() 106 + .unwrap_or_else(|_| Client::new()); 107 + 85 108 Self { 86 109 base_url, 87 - client: Client::new(), 110 + client, 88 111 } 89 112 } 90 113
+13
src/rate_limit.rs
··· 20 20 pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 21 21 pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 22 22 23 + // NOTE: For production deployments with high traffic, prefer using the distributed rate 24 + // limiter (Redis/Valkey-based) via AppState::distributed_rate_limiter. The in-memory 25 + // rate limiters here don't automatically clean up expired entries, which can cause 26 + // memory growth over time with many unique client IPs. The distributed rate limiter 27 + // uses Redis TTL for automatic cleanup and works correctly across multiple instances. 28 + 23 29 #[derive(Clone)] 24 30 pub struct RateLimiters { 25 31 pub login: Arc<KeyedRateLimiter>, ··· 111 117 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 112 118 self.account_creation = Arc::new(RateLimiter::keyed( 113 119 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap())) 120 + )); 121 + self 122 + } 123 + 124 + pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 125 + self.email_update = Arc::new(RateLimiter::keyed( 126 + Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) 114 127 )); 115 128 self 116 129 }
+47 -14
src/repo/mod.rs
··· 22 22 23 23 impl BlockStore for PostgresBlockStore { 24 24 async fn get(&self, cid: &Cid) -> Result<Option<Bytes>, RepoError> { 25 + crate::metrics::record_block_operation("get"); 25 26 let cid_bytes = cid.to_bytes(); 26 27 let row = sqlx::query!("SELECT data FROM blocks WHERE cid = $1", &cid_bytes) 27 28 .fetch_optional(&self.pool) ··· 35 36 } 36 37 37 38 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 39 + crate::metrics::record_block_operation("put"); 38 40 let mut hasher = Sha256::new(); 39 41 hasher.update(data); 40 42 let hash = hasher.finalize(); ··· 52 54 } 53 55 54 56 async fn has(&self, cid: &Cid) -> Result<bool, RepoError> { 57 + crate::metrics::record_block_operation("has"); 55 58 let cid_bytes = cid.to_bytes(); 56 59 let row = sqlx::query!("SELECT 1 as one FROM blocks WHERE cid = $1", &cid_bytes) 57 60 .fetch_optional(&self.pool) ··· 66 69 blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send, 67 70 ) -> Result<(), RepoError> { 68 71 let blocks: Vec<_> = blocks.into_iter().collect(); 69 - for (cid, data) in blocks { 70 - let cid_bytes = cid.to_bytes(); 71 - let data_ref = data.as_ref(); 72 - sqlx::query!( 73 - "INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING", 74 - &cid_bytes, 75 - data_ref 76 - ) 77 - .execute(&self.pool) 78 - .await 79 - .map_err(|e| RepoError::storage(e))?; 72 + if blocks.is_empty() { 73 + return Ok(()); 80 74 } 75 + 76 + crate::metrics::record_block_operation("put_many"); 77 + let cids: Vec<Vec<u8>> = blocks.iter().map(|(cid, _)| cid.to_bytes()).collect(); 78 + let data: Vec<&[u8]> = blocks.iter().map(|(_, d)| d.as_ref()).collect(); 79 + 80 + sqlx::query!( 81 + r#" 82 + INSERT INTO blocks (cid, data) 83 + SELECT * FROM UNNEST($1::bytea[], $2::bytea[]) 84 + ON CONFLICT (cid) DO NOTHING 85 + "#, 86 + &cids, 87 + &data as &[&[u8]] 88 + ) 89 + .execute(&self.pool) 90 + .await 91 + .map_err(|e| RepoError::storage(e))?; 92 + 81 93 Ok(()) 82 94 } 83 95 84 96 async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Bytes>>, RepoError> { 85 - let mut results = Vec::new(); 86 - for cid in cids { 87 - results.push(self.get(cid).await?); 97 + if cids.is_empty() { 98 + return Ok(Vec::new()); 88 99 } 100 + 101 + crate::metrics::record_block_operation("get_many"); 102 + let cid_bytes: Vec<Vec<u8>> = cids.iter().map(|c| c.to_bytes()).collect(); 103 + 104 + let rows = sqlx::query!( 105 + "SELECT cid, data FROM blocks WHERE cid = ANY($1)", 106 + &cid_bytes 107 + ) 108 + .fetch_all(&self.pool) 109 + .await 110 + .map_err(|e| RepoError::storage(e))?; 111 + 112 + let found: std::collections::HashMap<Vec<u8>, Bytes> = rows 113 + .into_iter() 114 + .map(|row| (row.cid, Bytes::from(row.data))) 115 + .collect(); 116 + 117 + let results = cid_bytes 118 + .iter() 119 + .map(|cid| found.get(cid).cloned()) 120 + .collect(); 121 + 89 122 Ok(results) 90 123 } 91 124
+88 -1
src/state.rs
··· 21 21 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, 22 22 } 23 23 24 + pub enum RateLimitKind { 25 + Login, 26 + AccountCreation, 27 + PasswordReset, 28 + ResetPassword, 29 + RefreshSession, 30 + OAuthToken, 31 + OAuthAuthorize, 32 + OAuthPar, 33 + OAuthIntrospect, 34 + AppPassword, 35 + EmailUpdate, 36 + } 37 + 38 + impl RateLimitKind { 39 + fn key_prefix(&self) -> &'static str { 40 + match self { 41 + Self::Login => "login", 42 + Self::AccountCreation => "account_creation", 43 + Self::PasswordReset => "password_reset", 44 + Self::ResetPassword => "reset_password", 45 + Self::RefreshSession => "refresh_session", 46 + Self::OAuthToken => "oauth_token", 47 + Self::OAuthAuthorize => "oauth_authorize", 48 + Self::OAuthPar => "oauth_par", 49 + Self::OAuthIntrospect => "oauth_introspect", 50 + Self::AppPassword => "app_password", 51 + Self::EmailUpdate => "email_update", 52 + } 53 + } 54 + 55 + fn limit_and_window_ms(&self) -> (u32, u64) { 56 + match self { 57 + Self::Login => (10, 60_000), 58 + Self::AccountCreation => (10, 3_600_000), 59 + Self::PasswordReset => (5, 3_600_000), 60 + Self::ResetPassword => (10, 60_000), 61 + Self::RefreshSession => (60, 60_000), 62 + Self::OAuthToken => (30, 60_000), 63 + Self::OAuthAuthorize => (10, 60_000), 64 + Self::OAuthPar => (30, 60_000), 65 + Self::OAuthIntrospect => (30, 60_000), 66 + Self::AppPassword => (10, 60_000), 67 + Self::EmailUpdate => (5, 3_600_000), 68 + } 69 + } 70 + } 71 + 24 72 impl AppState { 25 73 pub async fn new(db: PgPool) -> Self { 26 74 AuthConfig::init(); 27 75 28 76 let block_store = PostgresBlockStore::new(db.clone()); 29 77 let blob_store = S3BlobStorage::new().await; 30 - let (firehose_tx, _) = broadcast::channel(1000); 78 + let firehose_buffer_size: usize = std::env::var("FIREHOSE_BUFFER_SIZE") 79 + .ok() 80 + .and_then(|v| v.parse().ok()) 81 + .unwrap_or(10000); 82 + let (firehose_tx, _) = broadcast::channel(firehose_buffer_size); 31 83 let rate_limiters = Arc::new(RateLimiters::new()); 32 84 let circuit_breakers = Arc::new(CircuitBreakers::new()); 33 85 let (cache, distributed_rate_limiter) = create_cache().await; ··· 51 103 pub fn with_circuit_breakers(mut self, circuit_breakers: CircuitBreakers) -> Self { 52 104 self.circuit_breakers = Arc::new(circuit_breakers); 53 105 self 106 + } 107 + 108 + pub async fn check_rate_limit(&self, kind: RateLimitKind, client_ip: &str) -> bool { 109 + if std::env::var("DISABLE_RATE_LIMITING").is_ok() { 110 + return true; 111 + } 112 + 113 + let key = format!("{}:{}", kind.key_prefix(), client_ip); 114 + let limiter_name = kind.key_prefix(); 115 + let (limit, window_ms) = kind.limit_and_window_ms(); 116 + 117 + if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await { 118 + crate::metrics::record_rate_limit_rejection(limiter_name); 119 + return false; 120 + } 121 + 122 + let limiter = match kind { 123 + RateLimitKind::Login => &self.rate_limiters.login, 124 + RateLimitKind::AccountCreation => &self.rate_limiters.account_creation, 125 + RateLimitKind::PasswordReset => &self.rate_limiters.password_reset, 126 + RateLimitKind::ResetPassword => &self.rate_limiters.reset_password, 127 + RateLimitKind::RefreshSession => &self.rate_limiters.refresh_session, 128 + RateLimitKind::OAuthToken => &self.rate_limiters.oauth_token, 129 + RateLimitKind::OAuthAuthorize => &self.rate_limiters.oauth_authorize, 130 + RateLimitKind::OAuthPar => &self.rate_limiters.oauth_par, 131 + RateLimitKind::OAuthIntrospect => &self.rate_limiters.oauth_introspect, 132 + RateLimitKind::AppPassword => &self.rate_limiters.app_password, 133 + RateLimitKind::EmailUpdate => &self.rate_limiters.email_update, 134 + }; 135 + 136 + let ok = limiter.check_key(&client_ip.to_string()).is_ok(); 137 + if !ok { 138 + crate::metrics::record_rate_limit_rejection(limiter_name); 139 + } 140 + ok 54 141 } 55 142 }
+38 -8
src/storage/mod.rs
··· 3 3 use aws_config::meta::region::RegionProviderChain; 4 4 use aws_sdk_s3::Client; 5 5 use aws_sdk_s3::primitives::ByteStream; 6 + use bytes::Bytes; 6 7 use thiserror::Error; 7 8 8 9 #[derive(Error, Debug)] ··· 18 19 #[async_trait] 19 20 pub trait BlobStorage: Send + Sync { 20 21 async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError>; 22 + async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError>; 21 23 async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError>; 24 + async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError>; 22 25 async fn delete(&self, key: &str) -> Result<(), StorageError>; 23 26 } 24 27 ··· 55 58 #[async_trait] 56 59 impl BlobStorage for S3BlobStorage { 57 60 async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { 58 - self.client 61 + self.put_bytes(key, Bytes::copy_from_slice(data)).await 62 + } 63 + 64 + async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> { 65 + let result = self.client 59 66 .put_object() 60 67 .bucket(&self.bucket) 61 68 .key(key) 62 - .body(ByteStream::from(data.to_vec())) 69 + .body(ByteStream::from(data)) 63 70 .send() 64 71 .await 65 - .map_err(|e| StorageError::S3(e.to_string()))?; 72 + .map_err(|e| StorageError::S3(e.to_string())); 73 + 74 + match &result { 75 + Ok(_) => crate::metrics::record_s3_operation("put", "success"), 76 + Err(_) => crate::metrics::record_s3_operation("put", "error"), 77 + } 78 + result?; 66 79 Ok(()) 67 80 } 68 81 69 82 async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> { 83 + self.get_bytes(key).await.map(|b| b.to_vec()) 84 + } 85 + 86 + async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> { 70 87 let resp = self 71 88 .client 72 89 .get_object() ··· 74 91 .key(key) 75 92 .send() 76 93 .await 77 - .map_err(|e| StorageError::S3(e.to_string()))?; 94 + .map_err(|e| { 95 + crate::metrics::record_s3_operation("get", "error"); 96 + StorageError::S3(e.to_string()) 97 + })?; 78 98 79 99 let data = resp 80 100 .body 81 101 .collect() 82 102 .await 83 - .map_err(|e| StorageError::S3(e.to_string()))? 103 + .map_err(|e| { 104 + crate::metrics::record_s3_operation("get", "error"); 105 + StorageError::S3(e.to_string()) 106 + })? 84 107 .into_bytes(); 85 108 86 - Ok(data.to_vec()) 109 + crate::metrics::record_s3_operation("get", "success"); 110 + Ok(data) 87 111 } 88 112 89 113 async fn delete(&self, key: &str) -> Result<(), StorageError> { 90 - self.client 114 + let result = self.client 91 115 .delete_object() 92 116 .bucket(&self.bucket) 93 117 .key(key) 94 118 .send() 95 119 .await 96 - .map_err(|e| StorageError::S3(e.to_string()))?; 120 + .map_err(|e| StorageError::S3(e.to_string())); 121 + 122 + match &result { 123 + Ok(_) => crate::metrics::record_s3_operation("delete", "success"), 124 + Err(_) => crate::metrics::record_s3_operation("delete", "error"), 125 + } 126 + result?; 97 127 Ok(()) 98 128 } 99 129 }
+68 -2
src/sync/listener.rs
··· 1 1 use crate::state::AppState; 2 2 use crate::sync::firehose::SequencedEvent; 3 3 use sqlx::postgres::PgListener; 4 - use tracing::{error, info, warn}; 4 + use std::sync::atomic::{AtomicI64, Ordering}; 5 + use tracing::{debug, error, info, warn}; 6 + 7 + static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0); 5 8 6 9 pub async fn start_sequencer_listener(state: AppState) { 10 + let initial_seq = sqlx::query_scalar!("SELECT COALESCE(MAX(seq), 0) as max FROM repo_seq") 11 + .fetch_one(&state.db) 12 + .await 13 + .unwrap_or(Some(0)) 14 + .unwrap_or(0); 15 + LAST_BROADCAST_SEQ.store(initial_seq, Ordering::SeqCst); 16 + info!(initial_seq = initial_seq, "Initialized sequencer listener"); 17 + 7 18 tokio::spawn(async move { 8 19 info!("Starting sequencer listener background task"); 9 20 loop { ··· 20 31 listener.listen("repo_updates").await?; 21 32 info!("Connected to Postgres and listening for 'repo_updates'"); 22 33 34 + let catchup_start = LAST_BROADCAST_SEQ.load(Ordering::SeqCst); 35 + let events = sqlx::query_as!( 36 + SequencedEvent, 37 + r#" 38 + SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids 39 + FROM repo_seq 40 + WHERE seq > $1 41 + ORDER BY seq ASC 42 + "#, 43 + catchup_start 44 + ) 45 + .fetch_all(&state.db) 46 + .await?; 47 + 48 + if !events.is_empty() { 49 + info!(count = events.len(), from_seq = catchup_start, "Broadcasting catch-up events"); 50 + for event in events { 51 + let seq = event.seq; 52 + let _ = state.firehose_tx.send(event); 53 + LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst); 54 + } 55 + } 56 + 23 57 loop { 24 58 let notification = listener.recv().await?; 25 59 let payload = notification.payload(); ··· 32 66 } 33 67 }; 34 68 69 + let last_seq = LAST_BROADCAST_SEQ.load(Ordering::SeqCst); 70 + if seq_id <= last_seq { 71 + debug!(seq = seq_id, last = last_seq, "Skipping already-broadcast event"); 72 + continue; 73 + } 74 + 75 + if seq_id > last_seq + 1 { 76 + let gap_events = sqlx::query_as!( 77 + SequencedEvent, 78 + r#" 79 + SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids 80 + FROM repo_seq 81 + WHERE seq > $1 AND seq < $2 82 + ORDER BY seq ASC 83 + "#, 84 + last_seq, 85 + seq_id 86 + ) 87 + .fetch_all(&state.db) 88 + .await?; 89 + 90 + if !gap_events.is_empty() { 91 + debug!(count = gap_events.len(), "Filling sequence gap"); 92 + for event in gap_events { 93 + let seq = event.seq; 94 + let _ = state.firehose_tx.send(event); 95 + LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst); 96 + } 97 + } 98 + } 99 + 35 100 let event = sqlx::query_as!( 36 101 SequencedEvent, 37 102 r#" ··· 46 111 47 112 if let Some(event) = event { 48 113 let _ = state.firehose_tx.send(event); 114 + LAST_BROADCAST_SEQ.store(seq_id, Ordering::SeqCst); 49 115 } else { 50 - warn!("Received notification for seq {} but could not find row in repo_seq", seq_id); 116 + warn!(seq = seq_id, "Received notification but could not find row in repo_seq"); 51 117 } 52 118 } 53 119 }
+70 -11
src/sync/subscribe_repos.rs
··· 1 1 use crate::state::AppState; 2 2 use crate::sync::firehose::SequencedEvent; 3 - use crate::sync::util::format_event_for_sending; 3 + use crate::sync::util::{format_event_for_sending, format_event_with_prefetched_blocks, prefetch_blocks_for_events}; 4 4 use axum::{ 5 5 extract::{ws::Message, ws::WebSocket, ws::WebSocketUpgrade, Query, State}, 6 6 response::Response, 7 7 }; 8 8 use futures::{sink::SinkExt, stream::StreamExt}; 9 9 use serde::Deserialize; 10 + use std::sync::atomic::{AtomicUsize, Ordering}; 11 + use tokio::sync::broadcast::error::RecvError; 10 12 use tracing::{error, info, warn}; 11 13 12 14 const BACKFILL_BATCH_SIZE: i64 = 1000; 15 + static SUBSCRIBER_COUNT: AtomicUsize = AtomicUsize::new(0); 13 16 14 17 #[derive(Deserialize)] 15 18 pub struct SubscribeReposParams { ··· 35 38 Ok(()) 36 39 } 37 40 41 + pub fn get_subscriber_count() -> usize { 42 + SUBSCRIBER_COUNT.load(Ordering::SeqCst) 43 + } 44 + 38 45 async fn handle_socket(mut socket: WebSocket, state: AppState, params: SubscribeReposParams) { 39 - info!(cursor = ?params.cursor, "New firehose subscriber"); 46 + let count = SUBSCRIBER_COUNT.fetch_add(1, Ordering::SeqCst) + 1; 47 + crate::metrics::set_firehose_subscribers(count); 48 + info!(cursor = ?params.cursor, subscribers = count, "New firehose subscriber"); 49 + 50 + let _ = handle_socket_inner(&mut socket, &state, params).await; 40 51 52 + let count = SUBSCRIBER_COUNT.fetch_sub(1, Ordering::SeqCst) - 1; 53 + crate::metrics::set_firehose_subscribers(count); 54 + info!(subscribers = count, "Firehose subscriber disconnected"); 55 + } 56 + 57 + async fn handle_socket_inner(socket: &mut WebSocket, state: &AppState, params: SubscribeReposParams) -> Result<(), ()> { 41 58 if let Some(cursor) = params.cursor { 42 59 let mut current_cursor = cursor; 43 60 loop { ··· 61 78 if events.is_empty() { 62 79 break; 63 80 } 64 - for event in &events { 81 + 82 + let events_count = events.len(); 83 + 84 + let prefetched = match prefetch_blocks_for_events(state, &events).await { 85 + Ok(blocks) => blocks, 86 + Err(e) => { 87 + error!("Failed to prefetch blocks for backfill: {}", e); 88 + socket.close().await.ok(); 89 + return Err(()); 90 + } 91 + }; 92 + 93 + for event in events { 65 94 current_cursor = event.seq; 66 - if let Err(e) = send_event(&mut socket, &state, event.clone()).await { 95 + let bytes = match format_event_with_prefetched_blocks(event, &prefetched).await { 96 + Ok(b) => b, 97 + Err(e) => { 98 + warn!("Failed to format backfill event: {}", e); 99 + return Err(()); 100 + } 101 + }; 102 + if let Err(e) = socket.send(Message::Binary(bytes.into())).await { 67 103 warn!("Failed to send backfill event: {}", e); 68 - return; 104 + return Err(()); 69 105 } 106 + crate::metrics::record_firehose_event(); 70 107 } 71 - if (events.len() as i64) < BACKFILL_BATCH_SIZE { 108 + if (events_count as i64) < BACKFILL_BATCH_SIZE { 72 109 break; 73 110 } 74 111 } 75 112 Err(e) => { 76 113 error!("Failed to fetch backfill events: {}", e); 77 114 socket.close().await.ok(); 78 - return; 115 + return Err(()); 79 116 } 80 117 } 81 118 } 82 119 } 83 120 84 121 let mut rx = state.firehose_tx.subscribe(); 122 + let max_lag_before_disconnect: u64 = std::env::var("FIREHOSE_MAX_LAG") 123 + .ok() 124 + .and_then(|v| v.parse().ok()) 125 + .unwrap_or(5000); 85 126 86 127 loop { 87 128 tokio::select! { 88 - Ok(event) = rx.recv() => { 89 - if let Err(e) = send_event(&mut socket, &state, event).await { 90 - warn!("Failed to send event: {}", e); 91 - break; 129 + result = rx.recv() => { 130 + match result { 131 + Ok(event) => { 132 + if let Err(e) = send_event(socket, state, event).await { 133 + warn!("Failed to send event: {}", e); 134 + break; 135 + } 136 + crate::metrics::record_firehose_event(); 137 + } 138 + Err(RecvError::Lagged(skipped)) => { 139 + warn!(skipped = skipped, "Firehose subscriber lagged behind"); 140 + if skipped > max_lag_before_disconnect { 141 + warn!(skipped = skipped, max_lag = max_lag_before_disconnect, 142 + "Disconnecting slow firehose consumer"); 143 + break; 144 + } 145 + } 146 + Err(RecvError::Closed) => { 147 + info!("Firehose channel closed"); 148 + break; 149 + } 92 150 } 93 151 } 94 152 Some(Ok(msg)) = socket.next() => { ··· 102 160 } 103 161 } 104 162 } 163 + Ok(()) 105 164 }
+85 -8
src/sync/util.rs
··· 1 1 use crate::state::AppState; 2 2 use crate::sync::firehose::SequencedEvent; 3 3 use crate::sync::frame::{CommitFrame, Frame, FrameData}; 4 + use bytes::Bytes; 4 5 use cid::Cid; 5 6 use jacquard_repo::car::write_car_bytes; 6 7 use jacquard_repo::storage::BlockStore; 8 + use std::collections::{BTreeMap, HashMap}; 7 9 use std::str::FromStr; 8 10 9 11 pub async fn format_event_for_sending( ··· 15 17 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 16 18 17 19 let car_bytes = if !block_cids_str.is_empty() { 20 + let cids: Vec<Cid> = block_cids_str 21 + .iter() 22 + .filter_map(|s| Cid::from_str(s).ok()) 23 + .collect(); 24 + 25 + let fetched = state.block_store.get_many(&cids).await?; 26 + 18 27 let mut blocks = std::collections::BTreeMap::new(); 28 + for (cid, data_opt) in cids.into_iter().zip(fetched.into_iter()) { 29 + if let Some(data) = data_opt { 30 + blocks.insert(cid, data); 31 + } 32 + } 33 + 34 + let root = Cid::from_str(&frame.commit)?; 35 + write_car_bytes(root, blocks).await? 36 + } else { 37 + Vec::new() 38 + }; 39 + frame.blocks = car_bytes; 40 + 41 + let frame = Frame { 42 + op: 1, 43 + data: FrameData::Commit(Box::new(frame)), 44 + }; 45 + 46 + let mut bytes = Vec::new(); 47 + serde_ipld_dagcbor::to_writer(&mut bytes, &frame)?; 48 + Ok(bytes) 49 + } 50 + 51 + pub async fn prefetch_blocks_for_events( 52 + state: &AppState, 53 + events: &[SequencedEvent], 54 + ) -> Result<HashMap<Cid, Bytes>, anyhow::Error> { 55 + let mut all_cids: Vec<Cid> = Vec::new(); 56 + 57 + for event in events { 58 + if let Some(ref block_cids_str) = event.blocks_cids { 59 + for s in block_cids_str { 60 + if let Ok(cid) = Cid::from_str(s) { 61 + all_cids.push(cid); 62 + } 63 + } 64 + } 65 + } 19 66 20 - for cid_str in block_cids_str { 21 - let cid = Cid::from_str(&cid_str)?; 22 - let data = state 23 - .block_store 24 - .get(&cid) 25 - .await? 26 - .ok_or_else(|| anyhow::anyhow!("Block not found: {}", cid))?; 27 - blocks.insert(cid, data); 67 + all_cids.sort(); 68 + all_cids.dedup(); 69 + 70 + if all_cids.is_empty() { 71 + return Ok(HashMap::new()); 72 + } 73 + 74 + let fetched = state.block_store.get_many(&all_cids).await?; 75 + 76 + let mut blocks_map = HashMap::new(); 77 + for (cid, data_opt) in all_cids.into_iter().zip(fetched.into_iter()) { 78 + if let Some(data) = data_opt { 79 + blocks_map.insert(cid, data); 80 + } 81 + } 82 + 83 + Ok(blocks_map) 84 + } 85 + 86 + pub async fn format_event_with_prefetched_blocks( 87 + event: SequencedEvent, 88 + prefetched: &HashMap<Cid, Bytes>, 89 + ) -> Result<Vec<u8>, anyhow::Error> { 90 + let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 91 + let mut frame: CommitFrame = event.try_into() 92 + .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 93 + 94 + let car_bytes = if !block_cids_str.is_empty() { 95 + let cids: Vec<Cid> = block_cids_str 96 + .iter() 97 + .filter_map(|s| Cid::from_str(s).ok()) 98 + .collect(); 99 + 100 + let mut blocks = BTreeMap::new(); 101 + for cid in cids { 102 + if let Some(data) = prefetched.get(&cid) { 103 + blocks.insert(cid, data.clone()); 104 + } 28 105 } 29 106 30 107 let root = Cid::from_str(&frame.commit)?;
-139
tests/auth.rs
··· 1 - use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 2 - use bspds::auth; 3 - use chrono::{Duration, Utc}; 4 - use k256::SecretKey; 5 - use k256::ecdsa::{SigningKey, signature::Signer}; 6 - use rand::rngs::OsRng; 7 - use serde_json::json; 8 - 9 - #[test] 10 - fn test_jwt_flow() { 11 - let secret_key = SecretKey::random(&mut OsRng); 12 - let key_bytes = secret_key.to_bytes(); 13 - let did = "did:plc:test"; 14 - 15 - let token = auth::create_access_token(did, &key_bytes).expect("create token"); 16 - let data = auth::verify_access_token(&token, &key_bytes).expect("verify access token"); 17 - assert_eq!(data.claims.sub, did); 18 - assert_eq!(data.claims.iss, did); 19 - assert_eq!(data.claims.scope, Some(auth::SCOPE_ACCESS.to_string())); 20 - 21 - let r_token = auth::create_refresh_token(did, &key_bytes).expect("create refresh token"); 22 - let r_data = auth::verify_refresh_token(&r_token, &key_bytes).expect("verify refresh token"); 23 - assert_eq!(r_data.claims.scope, Some(auth::SCOPE_REFRESH.to_string())); 24 - 25 - let aud = "did:web:service"; 26 - let lxm = "com.example.test"; 27 - let s_token = 28 - auth::create_service_token(did, aud, lxm, &key_bytes).expect("create service token"); 29 - let s_data = auth::verify_token(&s_token, &key_bytes).expect("verify service token"); 30 - assert_eq!(s_data.claims.aud, aud); 31 - assert_eq!(s_data.claims.lxm, Some(lxm.to_string())); 32 - } 33 - 34 - #[test] 35 - fn test_token_type_confusion_prevented() { 36 - let secret_key = SecretKey::random(&mut OsRng); 37 - let key_bytes = secret_key.to_bytes(); 38 - let did = "did:plc:test"; 39 - 40 - let access_token = auth::create_access_token(did, &key_bytes).expect("create access token"); 41 - let refresh_token = auth::create_refresh_token(did, &key_bytes).expect("create refresh token"); 42 - 43 - assert!(auth::verify_access_token(&access_token, &key_bytes).is_ok()); 44 - assert!(auth::verify_access_token(&refresh_token, &key_bytes).is_err()); 45 - 46 - assert!(auth::verify_refresh_token(&refresh_token, &key_bytes).is_ok()); 47 - assert!(auth::verify_refresh_token(&access_token, &key_bytes).is_err()); 48 - } 49 - 50 - #[test] 51 - fn test_verify_fails_with_wrong_key() { 52 - let secret_key1 = SecretKey::random(&mut OsRng); 53 - let key_bytes1 = secret_key1.to_bytes(); 54 - 55 - let secret_key2 = SecretKey::random(&mut OsRng); 56 - let key_bytes2 = secret_key2.to_bytes(); 57 - 58 - let did = "did:plc:test"; 59 - let token = auth::create_access_token(did, &key_bytes1).expect("create token"); 60 - 61 - let result = auth::verify_token(&token, &key_bytes2); 62 - assert!(result.is_err()); 63 - } 64 - 65 - #[test] 66 - fn test_token_expiration() { 67 - let secret_key = SecretKey::random(&mut OsRng); 68 - let key_bytes = secret_key.to_bytes(); 69 - let signing_key = SigningKey::from_slice(&key_bytes).expect("key"); 70 - 71 - let header = json!({ 72 - "alg": "ES256K", 73 - "typ": "JWT" 74 - }); 75 - let claims = json!({ 76 - "iss": "did:plc:test", 77 - "sub": "did:plc:test", 78 - "aud": "did:web:test", 79 - "exp": (Utc::now() - Duration::seconds(10)).timestamp(), 80 - "iat": (Utc::now() - Duration::minutes(1)).timestamp(), 81 - "jti": "unique", 82 - }); 83 - 84 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 85 - let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); 86 - let message = format!("{}.{}", header_b64, claims_b64); 87 - let signature: k256::ecdsa::Signature = signing_key.sign(message.as_bytes()); 88 - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 89 - let token = format!("{}.{}", message, signature_b64); 90 - 91 - let result = auth::verify_token(&token, &key_bytes); 92 - match result { 93 - Ok(_) => panic!("Token should be expired"), 94 - Err(e) => assert_eq!(e.to_string(), "Token expired"), 95 - } 96 - } 97 - 98 - #[test] 99 - fn test_invalid_token_format() { 100 - let secret_key = SecretKey::random(&mut OsRng); 101 - let key_bytes = secret_key.to_bytes(); 102 - 103 - assert!(auth::verify_token("invalid.token", &key_bytes).is_err()); 104 - assert!(auth::verify_token("too.many.parts.here", &key_bytes).is_err()); 105 - assert!(auth::verify_token("bad_base64.payload.sig", &key_bytes).is_err()); 106 - } 107 - 108 - #[test] 109 - fn test_tampered_token() { 110 - let secret_key = SecretKey::random(&mut OsRng); 111 - let key_bytes = secret_key.to_bytes(); 112 - let did = "did:plc:test"; 113 - 114 - let token = auth::create_access_token(did, &key_bytes).expect("create token"); 115 - let parts: Vec<&str> = token.split('.').collect(); 116 - 117 - let claims_json = String::from_utf8(URL_SAFE_NO_PAD.decode(parts[1]).unwrap()).unwrap(); 118 - let mut claims: serde_json::Value = serde_json::from_str(&claims_json).unwrap(); 119 - claims["sub"] = json!("did:plc:hacker"); 120 - let tampered_claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap()); 121 - 122 - let tampered_token = format!("{}.{}.{}", parts[0], tampered_claims_b64, parts[2]); 123 - 124 - let result = auth::verify_token(&tampered_token, &key_bytes); 125 - assert!(result.is_err()); 126 - } 127 - 128 - #[test] 129 - fn test_get_did_from_token() { 130 - let secret_key = SecretKey::random(&mut OsRng); 131 - let key_bytes = secret_key.to_bytes(); 132 - let did = "did:plc:test"; 133 - 134 - let token = auth::create_access_token(did, &key_bytes).expect("create token"); 135 - let extracted_did = auth::get_did_from_token(&token).expect("get did"); 136 - assert_eq!(extracted_did, did); 137 - 138 - assert!(auth::get_did_from_token("bad.token").is_err()); 139 - }
+101 -6
tests/common/mod.rs
··· 352 352 } 353 353 354 354 async fn spawn_app(database_url: String) -> String { 355 + use bspds::rate_limit::RateLimiters; 356 + 355 357 let pool = PgPoolOptions::new() 356 358 .max_connections(50) 357 359 .connect(&database_url) ··· 371 373 std::env::set_var("PDS_HOSTNAME", addr.to_string()); 372 374 } 373 375 374 - let state = AppState::new(pool).await; 376 + let rate_limiters = RateLimiters::new() 377 + .with_login_limit(10000) 378 + .with_account_creation_limit(10000) 379 + .with_password_reset_limit(10000) 380 + .with_email_update_limit(10000) 381 + .with_oauth_authorize_limit(10000) 382 + .with_oauth_token_limit(10000); 383 + 384 + let state = AppState::new(pool).await.with_rate_limiters(rate_limiters); 375 385 376 386 bspds::sync::listener::start_sequencer_listener(state.clone()).await; 377 387 ··· 402 412 panic!("DATABASE_URL must be set with external-infra feature"); 403 413 } 404 414 } 415 + } 416 + 417 + #[allow(dead_code)] 418 + pub async fn verify_new_account(client: &Client, did: &str) -> String { 419 + let conn_str = get_db_connection_string().await; 420 + let pool = sqlx::postgres::PgPoolOptions::new() 421 + .max_connections(2) 422 + .connect(&conn_str) 423 + .await 424 + .expect("Failed to connect to test database"); 425 + 426 + let verification_code: String = sqlx::query_scalar!( 427 + "SELECT email_confirmation_code FROM users WHERE did = $1", 428 + did 429 + ) 430 + .fetch_one(&pool) 431 + .await 432 + .expect("Failed to get verification code") 433 + .expect("No verification code found"); 434 + 435 + let confirm_payload = json!({ 436 + "did": did, 437 + "verificationCode": verification_code 438 + }); 439 + 440 + let confirm_res = client 441 + .post(format!( 442 + "{}/xrpc/com.atproto.server.confirmSignup", 443 + base_url().await 444 + )) 445 + .json(&confirm_payload) 446 + .send() 447 + .await 448 + .expect("confirmSignup request failed"); 449 + 450 + assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed"); 451 + let confirm_body: Value = confirm_res.json().await.expect("Invalid JSON from confirmSignup"); 452 + confirm_body["accessJwt"] 453 + .as_str() 454 + .expect("No accessJwt in confirmSignup response") 455 + .to_string() 405 456 } 406 457 407 458 #[allow(dead_code)] ··· 514 565 515 566 if res.status() == StatusCode::OK { 516 567 let body: Value = res.json().await.expect("Invalid JSON"); 517 - let access_jwt = body["accessJwt"] 518 - .as_str() 519 - .expect("No accessJwt") 520 - .to_string(); 568 + 569 + if let Some(access_jwt) = body["accessJwt"].as_str() { 570 + let did = body["did"].as_str().expect("No did").to_string(); 571 + return (access_jwt.to_string(), did); 572 + } 573 + 521 574 let did = body["did"].as_str().expect("No did").to_string(); 522 - return (access_jwt, did); 575 + 576 + let conn_str = get_db_connection_string().await; 577 + let pool = sqlx::postgres::PgPoolOptions::new() 578 + .max_connections(2) 579 + .connect(&conn_str) 580 + .await 581 + .expect("Failed to connect to test database"); 582 + 583 + let verification_code: String = sqlx::query_scalar!( 584 + "SELECT email_confirmation_code FROM users WHERE did = $1", 585 + &did 586 + ) 587 + .fetch_one(&pool) 588 + .await 589 + .expect("Failed to get verification code") 590 + .expect("No verification code found"); 591 + 592 + let confirm_payload = json!({ 593 + "did": did, 594 + "verificationCode": verification_code 595 + }); 596 + 597 + let confirm_res = client 598 + .post(format!( 599 + "{}/xrpc/com.atproto.server.confirmSignup", 600 + base_url().await 601 + )) 602 + .json(&confirm_payload) 603 + .send() 604 + .await 605 + .expect("confirmSignup request failed"); 606 + 607 + if confirm_res.status() == StatusCode::OK { 608 + let confirm_body: Value = confirm_res.json().await.expect("Invalid JSON from confirmSignup"); 609 + let access_jwt = confirm_body["accessJwt"] 610 + .as_str() 611 + .expect("No accessJwt in confirmSignup response") 612 + .to_string(); 613 + return (access_jwt, did); 614 + } 615 + 616 + last_error = format!("confirmSignup failed: {:?}", confirm_res.text().await); 617 + continue; 523 618 } 524 619 525 620 last_error = format!("Status {}: {:?}", res.status(), res.text().await);
+70 -139
tests/delete_account.rs
··· 6 6 use chrono::Utc; 7 7 use reqwest::StatusCode; 8 8 use serde_json::{Value, json}; 9 + use sqlx::PgPool; 10 + 11 + async fn get_pool() -> PgPool { 12 + let conn_str = get_db_connection_string().await; 13 + sqlx::postgres::PgPoolOptions::new() 14 + .max_connections(5) 15 + .connect(&conn_str) 16 + .await 17 + .expect("Failed to connect to test database") 18 + } 19 + 20 + async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str, password: &str) -> (String, String) { 21 + let res = client 22 + .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 23 + .json(&json!({ 24 + "handle": handle, 25 + "email": email, 26 + "password": password 27 + })) 28 + .send() 29 + .await 30 + .expect("Failed to create account"); 31 + assert_eq!(res.status(), StatusCode::OK); 32 + let body: Value = res.json().await.expect("Invalid JSON"); 33 + let did = body["did"].as_str().expect("No did").to_string(); 34 + let jwt = verify_new_account(client, &did).await; 35 + (did, jwt) 36 + } 9 37 10 38 #[tokio::test] 11 39 async fn test_delete_account_full_flow() { 12 40 let client = client(); 41 + let base_url = base_url().await; 13 42 let ts = Utc::now().timestamp_millis(); 14 43 let handle = format!("delete-test-{}.test", ts); 15 44 let email = format!("delete-test-{}@test.com", ts); 16 45 let password = "delete-password-123"; 17 46 18 - let create_payload = json!({ 19 - "handle": handle, 20 - "email": email, 21 - "password": password 22 - }); 23 - let create_res = client 24 - .post(format!( 25 - "{}/xrpc/com.atproto.server.createAccount", 26 - base_url().await 27 - )) 28 - .json(&create_payload) 29 - .send() 30 - .await 31 - .expect("Failed to create account"); 32 - assert_eq!(create_res.status(), StatusCode::OK); 33 - let create_body: Value = create_res.json().await.unwrap(); 34 - let did = create_body["did"].as_str().unwrap().to_string(); 35 - let jwt = create_body["accessJwt"].as_str().unwrap().to_string(); 47 + let (did, jwt) = create_verified_account(&client, &base_url, &handle, &email, password).await; 36 48 37 49 let request_delete_res = client 38 50 .post(format!( 39 51 "{}/xrpc/com.atproto.server.requestAccountDelete", 40 - base_url().await 52 + base_url 41 53 )) 42 54 .bearer_auth(&jwt) 43 55 .send() ··· 45 57 .expect("Failed to request account deletion"); 46 58 assert_eq!(request_delete_res.status(), StatusCode::OK); 47 59 48 - let db_url = get_db_connection_string().await; 49 - let pool = sqlx::PgPool::connect(&db_url).await.expect("Failed to connect to test DB"); 60 + let pool = get_pool().await; 50 61 51 62 let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did) 52 63 .fetch_one(&pool) ··· 62 73 let delete_res = client 63 74 .post(format!( 64 75 "{}/xrpc/com.atproto.server.deleteAccount", 65 - base_url().await 76 + base_url 66 77 )) 67 78 .json(&delete_payload) 68 79 .send() ··· 79 90 let session_res = client 80 91 .get(format!( 81 92 "{}/xrpc/com.atproto.server.getSession", 82 - base_url().await 93 + base_url 83 94 )) 84 95 .bearer_auth(&jwt) 85 96 .send() ··· 91 102 #[tokio::test] 92 103 async fn test_delete_account_wrong_password() { 93 104 let client = client(); 105 + let base_url = base_url().await; 94 106 let ts = Utc::now().timestamp_millis(); 95 107 let handle = format!("delete-wrongpw-{}.test", ts); 96 108 let email = format!("delete-wrongpw-{}@test.com", ts); 97 109 let password = "correct-password"; 98 110 99 - let create_payload = json!({ 100 - "handle": handle, 101 - "email": email, 102 - "password": password 103 - }); 104 - let create_res = client 105 - .post(format!( 106 - "{}/xrpc/com.atproto.server.createAccount", 107 - base_url().await 108 - )) 109 - .json(&create_payload) 110 - .send() 111 - .await 112 - .expect("Failed to create account"); 113 - assert_eq!(create_res.status(), StatusCode::OK); 114 - let create_body: Value = create_res.json().await.unwrap(); 115 - let did = create_body["did"].as_str().unwrap().to_string(); 116 - let jwt = create_body["accessJwt"].as_str().unwrap().to_string(); 111 + let (did, jwt) = create_verified_account(&client, &base_url, &handle, &email, password).await; 117 112 118 113 let request_delete_res = client 119 114 .post(format!( 120 115 "{}/xrpc/com.atproto.server.requestAccountDelete", 121 - base_url().await 116 + base_url 122 117 )) 123 118 .bearer_auth(&jwt) 124 119 .send() ··· 126 121 .expect("Failed to request account deletion"); 127 122 assert_eq!(request_delete_res.status(), StatusCode::OK); 128 123 129 - let db_url = get_db_connection_string().await; 130 - let pool = sqlx::PgPool::connect(&db_url).await.expect("Failed to connect to test DB"); 124 + let pool = get_pool().await; 131 125 132 126 let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did) 133 127 .fetch_one(&pool) ··· 143 137 let delete_res = client 144 138 .post(format!( 145 139 "{}/xrpc/com.atproto.server.deleteAccount", 146 - base_url().await 140 + base_url 147 141 )) 148 142 .json(&delete_payload) 149 143 .send() ··· 158 152 #[tokio::test] 159 153 async fn test_delete_account_invalid_token() { 160 154 let client = client(); 155 + let base_url = base_url().await; 161 156 let ts = Utc::now().timestamp_millis(); 162 157 let handle = format!("delete-badtoken-{}.test", ts); 163 158 let email = format!("delete-badtoken-{}@test.com", ts); 164 159 let password = "delete-password"; 165 160 166 - let create_payload = json!({ 167 - "handle": handle, 168 - "email": email, 169 - "password": password 170 - }); 171 161 let create_res = client 172 162 .post(format!( 173 163 "{}/xrpc/com.atproto.server.createAccount", 174 - base_url().await 164 + base_url 175 165 )) 176 - .json(&create_payload) 166 + .json(&json!({ 167 + "handle": handle, 168 + "email": email, 169 + "password": password 170 + })) 177 171 .send() 178 172 .await 179 173 .expect("Failed to create account"); ··· 189 183 let delete_res = client 190 184 .post(format!( 191 185 "{}/xrpc/com.atproto.server.deleteAccount", 192 - base_url().await 186 + base_url 193 187 )) 194 188 .json(&delete_payload) 195 189 .send() ··· 204 198 #[tokio::test] 205 199 async fn test_delete_account_expired_token() { 206 200 let client = client(); 201 + let base_url = base_url().await; 207 202 let ts = Utc::now().timestamp_millis(); 208 203 let handle = format!("delete-expired-{}.test", ts); 209 204 let email = format!("delete-expired-{}@test.com", ts); 210 205 let password = "delete-password"; 211 206 212 - let create_payload = json!({ 213 - "handle": handle, 214 - "email": email, 215 - "password": password 216 - }); 217 - let create_res = client 218 - .post(format!( 219 - "{}/xrpc/com.atproto.server.createAccount", 220 - base_url().await 221 - )) 222 - .json(&create_payload) 223 - .send() 224 - .await 225 - .expect("Failed to create account"); 226 - assert_eq!(create_res.status(), StatusCode::OK); 227 - let create_body: Value = create_res.json().await.unwrap(); 228 - let did = create_body["did"].as_str().unwrap().to_string(); 229 - let jwt = create_body["accessJwt"].as_str().unwrap().to_string(); 207 + let (did, jwt) = create_verified_account(&client, &base_url, &handle, &email, password).await; 230 208 231 209 let request_delete_res = client 232 210 .post(format!( 233 211 "{}/xrpc/com.atproto.server.requestAccountDelete", 234 - base_url().await 212 + base_url 235 213 )) 236 214 .bearer_auth(&jwt) 237 215 .send() ··· 239 217 .expect("Failed to request account deletion"); 240 218 assert_eq!(request_delete_res.status(), StatusCode::OK); 241 219 242 - let db_url = get_db_connection_string().await; 243 - let pool = sqlx::PgPool::connect(&db_url).await.expect("Failed to connect to test DB"); 220 + let pool = get_pool().await; 244 221 245 222 let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did) 246 223 .fetch_one(&pool) ··· 264 241 let delete_res = client 265 242 .post(format!( 266 243 "{}/xrpc/com.atproto.server.deleteAccount", 267 - base_url().await 244 + base_url 268 245 )) 269 246 .json(&delete_payload) 270 247 .send() ··· 279 256 #[tokio::test] 280 257 async fn test_delete_account_token_mismatch() { 281 258 let client = client(); 259 + let base_url = base_url().await; 282 260 let ts = Utc::now().timestamp_millis(); 283 261 284 262 let handle1 = format!("delete-user1-{}.test", ts); 285 263 let email1 = format!("delete-user1-{}@test.com", ts); 286 264 let password1 = "user1-password"; 287 265 288 - let create1_res = client 289 - .post(format!( 290 - "{}/xrpc/com.atproto.server.createAccount", 291 - base_url().await 292 - )) 293 - .json(&json!({ 294 - "handle": handle1, 295 - "email": email1, 296 - "password": password1 297 - })) 298 - .send() 299 - .await 300 - .expect("Failed to create account 1"); 301 - assert_eq!(create1_res.status(), StatusCode::OK); 302 - let create1_body: Value = create1_res.json().await.unwrap(); 303 - let did1 = create1_body["did"].as_str().unwrap().to_string(); 304 - let jwt1 = create1_body["accessJwt"].as_str().unwrap().to_string(); 266 + let (did1, jwt1) = create_verified_account(&client, &base_url, &handle1, &email1, password1).await; 305 267 306 268 let handle2 = format!("delete-user2-{}.test", ts); 307 269 let email2 = format!("delete-user2-{}@test.com", ts); 308 270 let password2 = "user2-password"; 309 271 310 - let create2_res = client 311 - .post(format!( 312 - "{}/xrpc/com.atproto.server.createAccount", 313 - base_url().await 314 - )) 315 - .json(&json!({ 316 - "handle": handle2, 317 - "email": email2, 318 - "password": password2 319 - })) 320 - .send() 321 - .await 322 - .expect("Failed to create account 2"); 323 - assert_eq!(create2_res.status(), StatusCode::OK); 324 - let create2_body: Value = create2_res.json().await.unwrap(); 325 - let did2 = create2_body["did"].as_str().unwrap().to_string(); 272 + let (did2, _) = create_verified_account(&client, &base_url, &handle2, &email2, password2).await; 326 273 327 274 let request_delete_res = client 328 275 .post(format!( 329 276 "{}/xrpc/com.atproto.server.requestAccountDelete", 330 - base_url().await 277 + base_url 331 278 )) 332 279 .bearer_auth(&jwt1) 333 280 .send() ··· 335 282 .expect("Failed to request account deletion"); 336 283 assert_eq!(request_delete_res.status(), StatusCode::OK); 337 284 338 - let db_url = get_db_connection_string().await; 339 - let pool = sqlx::PgPool::connect(&db_url).await.expect("Failed to connect to test DB"); 285 + let pool = get_pool().await; 340 286 341 287 let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did1) 342 288 .fetch_one(&pool) ··· 352 298 let delete_res = client 353 299 .post(format!( 354 300 "{}/xrpc/com.atproto.server.deleteAccount", 355 - base_url().await 301 + base_url 356 302 )) 357 303 .json(&delete_payload) 358 304 .send() ··· 367 313 #[tokio::test] 368 314 async fn test_delete_account_with_app_password() { 369 315 let client = client(); 316 + let base_url = base_url().await; 370 317 let ts = Utc::now().timestamp_millis(); 371 318 let handle = format!("delete-apppw-{}.test", ts); 372 319 let email = format!("delete-apppw-{}@test.com", ts); 373 320 let main_password = "main-password-123"; 374 321 375 - let create_payload = json!({ 376 - "handle": handle, 377 - "email": email, 378 - "password": main_password 379 - }); 380 - let create_res = client 381 - .post(format!( 382 - "{}/xrpc/com.atproto.server.createAccount", 383 - base_url().await 384 - )) 385 - .json(&create_payload) 386 - .send() 387 - .await 388 - .expect("Failed to create account"); 389 - assert_eq!(create_res.status(), StatusCode::OK); 390 - let create_body: Value = create_res.json().await.unwrap(); 391 - let did = create_body["did"].as_str().unwrap().to_string(); 392 - let jwt = create_body["accessJwt"].as_str().unwrap().to_string(); 322 + let (did, jwt) = create_verified_account(&client, &base_url, &handle, &email, main_password).await; 393 323 394 324 let app_password_res = client 395 325 .post(format!( 396 326 "{}/xrpc/com.atproto.server.createAppPassword", 397 - base_url().await 327 + base_url 398 328 )) 399 329 .bearer_auth(&jwt) 400 330 .json(&json!({ "name": "delete-test-app" })) ··· 408 338 let request_delete_res = client 409 339 .post(format!( 410 340 "{}/xrpc/com.atproto.server.requestAccountDelete", 411 - base_url().await 341 + base_url 412 342 )) 413 343 .bearer_auth(&jwt) 414 344 .send() ··· 416 346 .expect("Failed to request account deletion"); 417 347 assert_eq!(request_delete_res.status(), StatusCode::OK); 418 348 419 - let db_url = get_db_connection_string().await; 420 - let pool = sqlx::PgPool::connect(&db_url).await.expect("Failed to connect to test DB"); 349 + let pool = get_pool().await; 421 350 422 351 let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did) 423 352 .fetch_one(&pool) ··· 433 362 let delete_res = client 434 363 .post(format!( 435 364 "{}/xrpc/com.atproto.server.deleteAccount", 436 - base_url().await 365 + base_url 437 366 )) 438 367 .json(&delete_payload) 439 368 .send() ··· 451 380 #[tokio::test] 452 381 async fn test_delete_account_missing_fields() { 453 382 let client = client(); 383 + let base_url = base_url().await; 454 384 455 385 let res1 = client 456 386 .post(format!( 457 387 "{}/xrpc/com.atproto.server.deleteAccount", 458 - base_url().await 388 + base_url 459 389 )) 460 390 .json(&json!({ 461 391 "password": "test", ··· 469 399 let res2 = client 470 400 .post(format!( 471 401 "{}/xrpc/com.atproto.server.deleteAccount", 472 - base_url().await 402 + base_url 473 403 )) 474 404 .json(&json!({ 475 405 "did": "did:web:test", ··· 483 413 let res3 = client 484 414 .post(format!( 485 415 "{}/xrpc/com.atproto.server.deleteAccount", 486 - base_url().await 416 + base_url 487 417 )) 488 418 .json(&json!({ 489 419 "did": "did:web:test", ··· 498 428 #[tokio::test] 499 429 async fn test_delete_account_nonexistent_user() { 500 430 let client = client(); 431 + let base_url = base_url().await; 501 432 502 433 let delete_payload = json!({ 503 434 "did": "did:web:nonexistent.user", ··· 507 438 let delete_res = client 508 439 .post(format!( 509 440 "{}/xrpc/com.atproto.server.deleteAccount", 510 - base_url().await 441 + base_url 511 442 )) 512 443 .json(&delete_payload) 513 444 .send()
+50 -187
tests/email_update.rs
··· 13 13 .expect("Failed to connect to test database") 14 14 } 15 15 16 + async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str) -> String { 17 + let res = client 18 + .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 19 + .json(&json!({ 20 + "handle": handle, 21 + "email": email, 22 + "password": "password" 23 + })) 24 + .send() 25 + .await 26 + .expect("Failed to create account"); 27 + assert_eq!(res.status(), StatusCode::OK); 28 + let body: Value = res.json().await.expect("Invalid JSON"); 29 + let did = body["did"].as_str().expect("No did"); 30 + common::verify_new_account(client, did).await 31 + } 32 + 16 33 #[tokio::test] 17 34 async fn test_email_update_flow_success() { 18 35 let client = common::client(); ··· 21 38 22 39 let handle = format!("emailup_{}", uuid::Uuid::new_v4()); 23 40 let email = format!("{}@example.com", handle); 24 - let payload = json!({ 25 - "handle": handle, 26 - "email": email, 27 - "password": "password" 28 - }); 29 - 30 - let res = client 31 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 32 - .json(&payload) 33 - .send() 34 - .await 35 - .expect("Failed to create account"); 36 - assert_eq!(res.status(), StatusCode::OK); 37 - let body: Value = res.json().await.expect("Invalid JSON"); 38 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 41 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 39 42 40 43 let new_email = format!("new_{}@example.com", handle); 41 44 let res = client 42 45 .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 43 - .bearer_auth(access_jwt) 46 + .bearer_auth(&access_jwt) 44 47 .json(&json!({"email": new_email})) 45 48 .send() 46 49 .await ··· 63 66 64 67 let res = client 65 68 .post(format!("{}/xrpc/com.atproto.server.confirmEmail", base_url)) 66 - .bearer_auth(access_jwt) 69 + .bearer_auth(&access_jwt) 67 70 .json(&json!({ 68 71 "email": new_email, 69 72 "token": code ··· 81 84 .await 82 85 .expect("User not found"); 83 86 84 - assert_eq!(user.email, new_email); 87 + assert_eq!(user.email, Some(new_email)); 85 88 assert!(user.email_pending_verification.is_none()); 86 89 assert!(user.email_confirmation_code.is_none()); 87 90 } ··· 93 96 94 97 let handle1 = format!("emailup_taken1_{}", uuid::Uuid::new_v4()); 95 98 let email1 = format!("{}@example.com", handle1); 96 - let res = client 97 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 98 - .json(&json!({ 99 - "handle": handle1, 100 - "email": email1, 101 - "password": "password" 102 - })) 103 - .send() 104 - .await 105 - .expect("Failed to create account 1"); 106 - assert_eq!(res.status(), StatusCode::OK); 99 + let _ = create_verified_account(&client, &base_url, &handle1, &email1).await; 107 100 108 101 let handle2 = format!("emailup_taken2_{}", uuid::Uuid::new_v4()); 109 102 let email2 = format!("{}@example.com", handle2); 110 - let res = client 111 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 112 - .json(&json!({ 113 - "handle": handle2, 114 - "email": email2, 115 - "password": "password" 116 - })) 117 - .send() 118 - .await 119 - .expect("Failed to create account 2"); 120 - assert_eq!(res.status(), StatusCode::OK); 121 - let body: Value = res.json().await.expect("Invalid JSON"); 122 - let access_jwt2 = body["accessJwt"].as_str().expect("No accessJwt"); 103 + let access_jwt2 = create_verified_account(&client, &base_url, &handle2, &email2).await; 123 104 124 105 let res = client 125 106 .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 126 - .bearer_auth(access_jwt2) 107 + .bearer_auth(&access_jwt2) 127 108 .json(&json!({"email": email1})) 128 109 .send() 129 110 .await ··· 141 122 142 123 let handle = format!("emailup_inv_{}", uuid::Uuid::new_v4()); 143 124 let email = format!("{}@example.com", handle); 144 - let res = client 145 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 146 - .json(&json!({ 147 - "handle": handle, 148 - "email": email, 149 - "password": "password" 150 - })) 151 - .send() 152 - .await 153 - .expect("Failed to create account"); 154 - assert_eq!(res.status(), StatusCode::OK); 155 - let body: Value = res.json().await.expect("Invalid JSON"); 156 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 125 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 157 126 158 127 let new_email = format!("new_{}@example.com", handle); 159 128 let res = client 160 129 .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 161 - .bearer_auth(access_jwt) 130 + .bearer_auth(&access_jwt) 162 131 .json(&json!({"email": new_email})) 163 132 .send() 164 133 .await ··· 167 136 168 137 let res = client 169 138 .post(format!("{}/xrpc/com.atproto.server.confirmEmail", base_url)) 170 - .bearer_auth(access_jwt) 139 + .bearer_auth(&access_jwt) 171 140 .json(&json!({ 172 141 "email": new_email, 173 142 "token": "wrong-token" ··· 189 158 190 159 let handle = format!("emailup_wrong_{}", uuid::Uuid::new_v4()); 191 160 let email = format!("{}@example.com", handle); 192 - let res = client 193 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 194 - .json(&json!({ 195 - "handle": handle, 196 - "email": email, 197 - "password": "password" 198 - })) 199 - .send() 200 - .await 201 - .expect("Failed to create account"); 202 - assert_eq!(res.status(), StatusCode::OK); 203 - let body: Value = res.json().await.expect("Invalid JSON"); 204 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 161 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 205 162 206 163 let new_email = format!("new_{}@example.com", handle); 207 164 let res = client 208 165 .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 209 - .bearer_auth(access_jwt) 166 + .bearer_auth(&access_jwt) 210 167 .json(&json!({"email": new_email})) 211 168 .send() 212 169 .await ··· 221 178 222 179 let res = client 223 180 .post(format!("{}/xrpc/com.atproto.server.confirmEmail", base_url)) 224 - .bearer_auth(access_jwt) 181 + .bearer_auth(&access_jwt) 225 182 .json(&json!({ 226 183 "email": "another_random@example.com", 227 184 "token": code ··· 243 200 244 201 let handle = format!("emailup_direct_{}", uuid::Uuid::new_v4()); 245 202 let email = format!("{}@example.com", handle); 246 - let res = client 247 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 248 - .json(&json!({ 249 - "handle": handle, 250 - "email": email, 251 - "password": "password" 252 - })) 253 - .send() 254 - .await 255 - .expect("Failed to create account"); 256 - assert_eq!(res.status(), StatusCode::OK); 257 - let body: Value = res.json().await.expect("Invalid JSON"); 258 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 203 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 259 204 260 205 let new_email = format!("direct_{}@example.com", handle); 261 206 let res = client 262 207 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url)) 263 - .bearer_auth(access_jwt) 208 + .bearer_auth(&access_jwt) 264 209 .json(&json!({ "email": new_email })) 265 210 .send() 266 211 .await ··· 272 217 .fetch_one(&pool) 273 218 .await 274 219 .expect("User not found"); 275 - assert_eq!(user.email, new_email); 220 + assert_eq!(user.email, Some(new_email)); 276 221 } 277 222 278 223 #[tokio::test] ··· 282 227 283 228 let handle = format!("emailup_same_{}", uuid::Uuid::new_v4()); 284 229 let email = format!("{}@example.com", handle); 285 - let res = client 286 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 287 - .json(&json!({ 288 - "handle": handle, 289 - "email": email, 290 - "password": "password" 291 - })) 292 - .send() 293 - .await 294 - .expect("Failed to create account"); 295 - assert_eq!(res.status(), StatusCode::OK); 296 - let body: Value = res.json().await.expect("Invalid JSON"); 297 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 230 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 298 231 299 232 let res = client 300 233 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url)) 301 - .bearer_auth(access_jwt) 234 + .bearer_auth(&access_jwt) 302 235 .json(&json!({ "email": email })) 303 236 .send() 304 237 .await ··· 314 247 315 248 let handle = format!("emailup_token_{}", uuid::Uuid::new_v4()); 316 249 let email = format!("{}@example.com", handle); 317 - let res = client 318 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 319 - .json(&json!({ 320 - "handle": handle, 321 - "email": email, 322 - "password": "password" 323 - })) 324 - .send() 325 - .await 326 - .expect("Failed to create account"); 327 - assert_eq!(res.status(), StatusCode::OK); 328 - let body: Value = res.json().await.expect("Invalid JSON"); 329 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 250 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 330 251 331 252 let new_email = format!("pending_{}@example.com", handle); 332 253 let res = client 333 254 .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 334 - .bearer_auth(access_jwt) 255 + .bearer_auth(&access_jwt) 335 256 .json(&json!({"email": new_email})) 336 257 .send() 337 258 .await ··· 340 261 341 262 let res = client 342 263 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url)) 343 - .bearer_auth(access_jwt) 264 + .bearer_auth(&access_jwt) 344 265 .json(&json!({ "email": new_email })) 345 266 .send() 346 267 .await ··· 359 280 360 281 let handle = format!("emailup_valid_{}", uuid::Uuid::new_v4()); 361 282 let email = format!("{}@example.com", handle); 362 - let res = client 363 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 364 - .json(&json!({ 365 - "handle": handle, 366 - "email": email, 367 - "password": "password" 368 - })) 369 - .send() 370 - .await 371 - .expect("Failed to create account"); 372 - assert_eq!(res.status(), StatusCode::OK); 373 - let body: Value = res.json().await.expect("Invalid JSON"); 374 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 283 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 375 284 376 285 let new_email = format!("valid_{}@example.com", handle); 377 286 let res = client 378 287 .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 379 - .bearer_auth(access_jwt) 288 + .bearer_auth(&access_jwt) 380 289 .json(&json!({"email": new_email})) 381 290 .send() 382 291 .await ··· 394 303 395 304 let res = client 396 305 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url)) 397 - .bearer_auth(access_jwt) 306 + .bearer_auth(&access_jwt) 398 307 .json(&json!({ 399 308 "email": new_email, 400 309 "token": code ··· 409 318 .fetch_one(&pool) 410 319 .await 411 320 .expect("User not found"); 412 - assert_eq!(user.email, new_email); 321 + assert_eq!(user.email, Some(new_email)); 413 322 assert!(user.email_pending_verification.is_none()); 414 323 } 415 324 ··· 420 329 421 330 let handle = format!("emailup_badtok_{}", uuid::Uuid::new_v4()); 422 331 let email = format!("{}@example.com", handle); 423 - let res = client 424 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 425 - .json(&json!({ 426 - "handle": handle, 427 - "email": email, 428 - "password": "password" 429 - })) 430 - .send() 431 - .await 432 - .expect("Failed to create account"); 433 - assert_eq!(res.status(), StatusCode::OK); 434 - let body: Value = res.json().await.expect("Invalid JSON"); 435 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 332 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 436 333 437 334 let new_email = format!("badtok_{}@example.com", handle); 438 335 let res = client 439 336 .post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url)) 440 - .bearer_auth(access_jwt) 337 + .bearer_auth(&access_jwt) 441 338 .json(&json!({"email": new_email})) 442 339 .send() 443 340 .await ··· 446 343 447 344 let res = client 448 345 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url)) 449 - .bearer_auth(access_jwt) 346 + .bearer_auth(&access_jwt) 450 347 .json(&json!({ 451 348 "email": new_email, 452 349 "token": "wrong-token-12345" ··· 467 364 468 365 let handle1 = format!("emailup_dup1_{}", uuid::Uuid::new_v4()); 469 366 let email1 = format!("{}@example.com", handle1); 470 - let res = client 471 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 472 - .json(&json!({ 473 - "handle": handle1, 474 - "email": email1, 475 - "password": "password" 476 - })) 477 - .send() 478 - .await 479 - .expect("Failed to create account 1"); 480 - assert_eq!(res.status(), StatusCode::OK); 367 + let _ = create_verified_account(&client, &base_url, &handle1, &email1).await; 481 368 482 369 let handle2 = format!("emailup_dup2_{}", uuid::Uuid::new_v4()); 483 370 let email2 = format!("{}@example.com", handle2); 484 - let res = client 485 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 486 - .json(&json!({ 487 - "handle": handle2, 488 - "email": email2, 489 - "password": "password" 490 - })) 491 - .send() 492 - .await 493 - .expect("Failed to create account 2"); 494 - assert_eq!(res.status(), StatusCode::OK); 495 - let body: Value = res.json().await.expect("Invalid JSON"); 496 - let access_jwt2 = body["accessJwt"].as_str().expect("No accessJwt"); 371 + let access_jwt2 = create_verified_account(&client, &base_url, &handle2, &email2).await; 497 372 498 373 let res = client 499 374 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url)) 500 - .bearer_auth(access_jwt2) 375 + .bearer_auth(&access_jwt2) 501 376 .json(&json!({ "email": email1 })) 502 377 .send() 503 378 .await ··· 532 407 533 408 let handle = format!("emailup_fmt_{}", uuid::Uuid::new_v4()); 534 409 let email = format!("{}@example.com", handle); 535 - let res = client 536 - .post(format!("{}/xrpc/com.atproto.server.createAccount", base_url)) 537 - .json(&json!({ 538 - "handle": handle, 539 - "email": email, 540 - "password": "password" 541 - })) 542 - .send() 543 - .await 544 - .expect("Failed to create account"); 545 - assert_eq!(res.status(), StatusCode::OK); 546 - let body: Value = res.json().await.expect("Invalid JSON"); 547 - let access_jwt = body["accessJwt"].as_str().expect("No accessJwt"); 410 + let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await; 548 411 549 412 let res = client 550 413 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url)) 551 - .bearer_auth(access_jwt) 414 + .bearer_auth(&access_jwt) 552 415 .json(&json!({ "email": "not-an-email" })) 553 416 .send() 554 417 .await
+2 -4
tests/helpers/mod.rs
··· 43 43 .as_str() 44 44 .expect("setup_new_user: Response had no DID") 45 45 .to_string(); 46 - let new_jwt = create_body["accessJwt"] 47 - .as_str() 48 - .expect("setup_new_user: Response had no accessJwt") 49 - .to_string(); 46 + 47 + let new_jwt = verify_new_account(&client, &new_did).await; 50 48 51 49 (new_did, new_jwt) 52 50 }
+1 -17
tests/identity.rs
··· 264 264 let create_body: Value = res.json().await.expect("Not JSON"); 265 265 assert_eq!(create_body["did"], did); 266 266 267 - let login_payload = json!({ 268 - "identifier": handle, 269 - "password": "password" 270 - }); 271 - let res = client 272 - .post(format!( 273 - "{}/xrpc/com.atproto.server.createSession", 274 - base_url().await 275 - )) 276 - .json(&login_payload) 277 - .send() 278 - .await 279 - .expect("Failed createSession"); 280 - 281 - assert_eq!(res.status(), StatusCode::OK); 282 - let session_body: Value = res.json().await.expect("Not JSON"); 283 - let _jwt = session_body["accessJwt"].as_str().unwrap(); 267 + let _jwt = verify_new_account(&client, &did).await; 284 268 285 269 /* 286 270 let profile_payload = json!({
-109
tests/import_repo.rs
··· 1 - mod common; 2 - use common::*; 3 - 4 - use reqwest::StatusCode; 5 - use serde_json::json; 6 - 7 - #[tokio::test] 8 - async fn test_import_repo_requires_auth() { 9 - let client = client(); 10 - 11 - let res = client 12 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 13 - .header("Content-Type", "application/vnd.ipld.car") 14 - .body(vec![0u8; 100]) 15 - .send() 16 - .await 17 - .expect("Request failed"); 18 - 19 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 20 - } 21 - 22 - #[tokio::test] 23 - async fn test_import_repo_invalid_car() { 24 - let client = client(); 25 - let (token, _did) = create_account_and_login(&client).await; 26 - 27 - let res = client 28 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 29 - .bearer_auth(&token) 30 - .header("Content-Type", "application/vnd.ipld.car") 31 - .body(vec![0u8; 100]) 32 - .send() 33 - .await 34 - .expect("Request failed"); 35 - 36 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 37 - let body: serde_json::Value = res.json().await.unwrap(); 38 - assert_eq!(body["error"], "InvalidRequest"); 39 - } 40 - 41 - #[tokio::test] 42 - async fn test_import_repo_empty_body() { 43 - let client = client(); 44 - let (token, _did) = create_account_and_login(&client).await; 45 - 46 - let res = client 47 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 48 - .bearer_auth(&token) 49 - .header("Content-Type", "application/vnd.ipld.car") 50 - .body(vec![]) 51 - .send() 52 - .await 53 - .expect("Request failed"); 54 - 55 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 56 - } 57 - 58 - #[tokio::test] 59 - async fn test_import_repo_with_exported_repo() { 60 - let client = client(); 61 - let (token, did) = create_account_and_login(&client).await; 62 - 63 - let post_payload = json!({ 64 - "repo": did, 65 - "collection": "app.bsky.feed.post", 66 - "record": { 67 - "$type": "app.bsky.feed.post", 68 - "text": "Test post for import", 69 - "createdAt": chrono::Utc::now().to_rfc3339(), 70 - } 71 - }); 72 - 73 - let create_res = client 74 - .post(format!( 75 - "{}/xrpc/com.atproto.repo.createRecord", 76 - base_url().await 77 - )) 78 - .bearer_auth(&token) 79 - .json(&post_payload) 80 - .send() 81 - .await 82 - .expect("Failed to create post"); 83 - assert_eq!(create_res.status(), StatusCode::OK); 84 - 85 - let export_res = client 86 - .get(format!( 87 - "{}/xrpc/com.atproto.sync.getRepo?did={}", 88 - base_url().await, 89 - did 90 - )) 91 - .send() 92 - .await 93 - .expect("Failed to export repo"); 94 - assert_eq!(export_res.status(), StatusCode::OK); 95 - 96 - let car_bytes = export_res.bytes().await.expect("Failed to get CAR bytes"); 97 - 98 - let import_res = client 99 - .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 100 - .bearer_auth(&token) 101 - .header("Content-Type", "application/vnd.ipld.car") 102 - .body(car_bytes.to_vec()) 103 - .send() 104 - .await 105 - .expect("Failed to import repo"); 106 - 107 - assert_eq!(import_res.status(), StatusCode::OK); 108 - } 109 -
+51
tests/import_verification.rs
··· 5 5 use reqwest::StatusCode; 6 6 use serde_json::json; 7 7 8 + #[tokio::test] 9 + async fn test_import_repo_requires_auth() { 10 + let client = client(); 11 + 12 + let res = client 13 + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 14 + .header("Content-Type", "application/vnd.ipld.car") 15 + .body(vec![0u8; 100]) 16 + .send() 17 + .await 18 + .expect("Request failed"); 19 + 20 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 21 + } 22 + 23 + #[tokio::test] 24 + async fn test_import_repo_invalid_car() { 25 + let client = client(); 26 + let (token, _did) = create_account_and_login(&client).await; 27 + 28 + let res = client 29 + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 30 + .bearer_auth(&token) 31 + .header("Content-Type", "application/vnd.ipld.car") 32 + .body(vec![0u8; 100]) 33 + .send() 34 + .await 35 + .expect("Request failed"); 36 + 37 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 38 + let body: serde_json::Value = res.json().await.unwrap(); 39 + assert_eq!(body["error"], "InvalidRequest"); 40 + } 41 + 42 + #[tokio::test] 43 + async fn test_import_repo_empty_body() { 44 + let client = client(); 45 + let (token, _did) = create_account_and_login(&client).await; 46 + 47 + let res = client 48 + .post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await)) 49 + .bearer_auth(&token) 50 + .header("Content-Type", "application/vnd.ipld.car") 51 + .body(vec![]) 52 + .send() 53 + .await 54 + .expect("Request failed"); 55 + 56 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); 57 + } 58 + 8 59 fn write_varint(buf: &mut Vec<u8>, mut value: u64) { 9 60 loop { 10 61 let mut byte = (value & 0x7F) as u8;
+34 -38
tests/jwt_security.rs
··· 10 10 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 11 11 }; 12 12 use chrono::{Duration, Utc}; 13 - use common::{base_url, client, create_account_and_login}; 13 + use common::{base_url, client, create_account_and_login, get_db_connection_string}; 14 14 use k256::SecretKey; 15 15 use k256::ecdsa::{SigningKey, Signature, signature::Signer}; 16 16 use rand::rngs::OsRng; ··· 906 906 907 907 assert_eq!(create_res.status(), StatusCode::OK); 908 908 let account: Value = create_res.json().await.unwrap(); 909 - let refresh_jwt = account["refreshJwt"].as_str().unwrap().to_string(); 909 + let did = account["did"].as_str().unwrap(); 910 + 911 + let conn_str = get_db_connection_string().await; 912 + let pool = sqlx::postgres::PgPoolOptions::new() 913 + .max_connections(2) 914 + .connect(&conn_str) 915 + .await 916 + .expect("Failed to connect to test database"); 917 + 918 + let verification_code: String = sqlx::query_scalar!( 919 + "SELECT email_confirmation_code FROM users WHERE did = $1", 920 + did 921 + ) 922 + .fetch_one(&pool) 923 + .await 924 + .expect("Failed to get verification code") 925 + .expect("No verification code found"); 926 + 927 + let confirm_res = http_client 928 + .post(format!("{}/xrpc/com.atproto.server.confirmSignup", url)) 929 + .json(&json!({ 930 + "did": did, 931 + "verificationCode": verification_code 932 + })) 933 + .send() 934 + .await 935 + .unwrap(); 936 + 937 + assert_eq!(confirm_res.status(), StatusCode::OK); 938 + let confirmed: Value = confirm_res.json().await.unwrap(); 939 + let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string(); 910 940 911 941 let first_refresh = http_client 912 942 .post(format!("{}/xrpc/com.atproto.server.refreshSession", url)) ··· 980 1010 let url = base_url().await; 981 1011 let http_client = client(); 982 1012 983 - let ts = Utc::now().timestamp_millis(); 984 - let handle = format!("del-sess-{}", ts); 985 - let email = format!("del-sess-{}@example.com", ts); 986 - let password = "test-password-123"; 987 - 988 - let create_res = http_client 989 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 990 - .json(&json!({ 991 - "handle": handle, 992 - "email": email, 993 - "password": password 994 - })) 995 - .send() 996 - .await 997 - .unwrap(); 998 - 999 - let account: Value = create_res.json().await.unwrap(); 1000 - let access_jwt = account["accessJwt"].as_str().unwrap().to_string(); 1013 + let (access_jwt, _did) = create_account_and_login(&http_client).await; 1001 1014 1002 1015 let get_res = http_client 1003 1016 .get(format!("{}/xrpc/com.atproto.server.getSession", url)) ··· 1029 1042 let url = base_url().await; 1030 1043 let http_client = client(); 1031 1044 1032 - let ts = Utc::now().timestamp_millis(); 1033 - let handle = format!("deact-jwt-{}", ts); 1034 - let email = format!("deact-jwt-{}@example.com", ts); 1035 - let password = "test-password-123"; 1036 - 1037 - let create_res = http_client 1038 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1039 - .json(&json!({ 1040 - "handle": handle, 1041 - "email": email, 1042 - "password": password 1043 - })) 1044 - .send() 1045 - .await 1046 - .unwrap(); 1047 - 1048 - let account: Value = create_res.json().await.unwrap(); 1049 - let access_jwt = account["accessJwt"].as_str().unwrap().to_string(); 1045 + let (access_jwt, _did) = create_account_and_login(&http_client).await; 1050 1046 1051 1047 let deact_res = http_client 1052 1048 .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
+545 -52
tests/lifecycle_record.rs
··· 664 664 } 665 665 666 666 #[tokio::test] 667 - async fn test_list_records_pagination() { 668 - let client = client(); 669 - let (did, jwt) = setup_new_user("list-pagination").await; 670 - 671 - for i in 0..5 { 672 - tokio::time::sleep(Duration::from_millis(50)).await; 673 - create_post(&client, &did, &jwt, &format!("Post number {}", i)).await; 674 - } 675 - 676 - let list_res = client 677 - .get(format!( 678 - "{}/xrpc/com.atproto.repo.listRecords", 679 - base_url().await 680 - )) 681 - .query(&[ 682 - ("repo", did.as_str()), 683 - ("collection", "app.bsky.feed.post"), 684 - ("limit", "2"), 685 - ]) 686 - .send() 687 - .await 688 - .expect("Failed to list records"); 689 - 690 - assert_eq!(list_res.status(), StatusCode::OK); 691 - let list_body: Value = list_res.json().await.unwrap(); 692 - let records = list_body["records"].as_array().unwrap(); 693 - assert_eq!(records.len(), 2, "Should return 2 records with limit=2"); 694 - 695 - if let Some(cursor) = list_body["cursor"].as_str() { 696 - let list_page2_res = client 697 - .get(format!( 698 - "{}/xrpc/com.atproto.repo.listRecords", 699 - base_url().await 700 - )) 701 - .query(&[ 702 - ("repo", did.as_str()), 703 - ("collection", "app.bsky.feed.post"), 704 - ("limit", "2"), 705 - ("cursor", cursor), 706 - ]) 707 - .send() 708 - .await 709 - .expect("Failed to list records page 2"); 710 - 711 - assert_eq!(list_page2_res.status(), StatusCode::OK); 712 - let page2_body: Value = list_page2_res.json().await.unwrap(); 713 - let page2_records = page2_body["records"].as_array().unwrap(); 714 - assert_eq!(page2_records.len(), 2, "Page 2 should have 2 more records"); 715 - } 716 - } 717 - 718 - #[tokio::test] 719 667 async fn test_apply_writes_batch_lifecycle() { 720 668 let client = client(); 721 669 let (did, jwt) = setup_new_user("apply-writes-batch").await; ··· 885 833 "Batch-deleted post should be gone" 886 834 ); 887 835 } 836 + 837 + async fn create_post_with_rkey( 838 + client: &reqwest::Client, 839 + did: &str, 840 + jwt: &str, 841 + rkey: &str, 842 + text: &str, 843 + ) -> (String, String) { 844 + let payload = json!({ 845 + "repo": did, 846 + "collection": "app.bsky.feed.post", 847 + "rkey": rkey, 848 + "record": { 849 + "$type": "app.bsky.feed.post", 850 + "text": text, 851 + "createdAt": Utc::now().to_rfc3339() 852 + } 853 + }); 854 + 855 + let res = client 856 + .post(format!( 857 + "{}/xrpc/com.atproto.repo.putRecord", 858 + base_url().await 859 + )) 860 + .bearer_auth(jwt) 861 + .json(&payload) 862 + .send() 863 + .await 864 + .expect("Failed to create record"); 865 + 866 + assert_eq!(res.status(), StatusCode::OK); 867 + let body: Value = res.json().await.unwrap(); 868 + ( 869 + body["uri"].as_str().unwrap().to_string(), 870 + body["cid"].as_str().unwrap().to_string(), 871 + ) 872 + } 873 + 874 + #[tokio::test] 875 + async fn test_list_records_default_order() { 876 + let client = client(); 877 + let (did, jwt) = setup_new_user("list-default-order").await; 878 + 879 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; 880 + tokio::time::sleep(Duration::from_millis(50)).await; 881 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; 882 + tokio::time::sleep(Duration::from_millis(50)).await; 883 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; 884 + 885 + let res = client 886 + .get(format!( 887 + "{}/xrpc/com.atproto.repo.listRecords", 888 + base_url().await 889 + )) 890 + .query(&[ 891 + ("repo", did.as_str()), 892 + ("collection", "app.bsky.feed.post"), 893 + ]) 894 + .send() 895 + .await 896 + .expect("Failed to list records"); 897 + 898 + assert_eq!(res.status(), StatusCode::OK); 899 + let body: Value = res.json().await.unwrap(); 900 + let records = body["records"].as_array().unwrap(); 901 + 902 + assert_eq!(records.len(), 3); 903 + let rkeys: Vec<&str> = records 904 + .iter() 905 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 906 + .collect(); 907 + 908 + assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)"); 909 + } 910 + 911 + #[tokio::test] 912 + async fn test_list_records_reverse_true() { 913 + let client = client(); 914 + let (did, jwt) = setup_new_user("list-reverse").await; 915 + 916 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; 917 + tokio::time::sleep(Duration::from_millis(50)).await; 918 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; 919 + tokio::time::sleep(Duration::from_millis(50)).await; 920 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; 921 + 922 + let res = client 923 + .get(format!( 924 + "{}/xrpc/com.atproto.repo.listRecords", 925 + base_url().await 926 + )) 927 + .query(&[ 928 + ("repo", did.as_str()), 929 + ("collection", "app.bsky.feed.post"), 930 + ("reverse", "true"), 931 + ]) 932 + .send() 933 + .await 934 + .expect("Failed to list records"); 935 + 936 + assert_eq!(res.status(), StatusCode::OK); 937 + let body: Value = res.json().await.unwrap(); 938 + let records = body["records"].as_array().unwrap(); 939 + 940 + let rkeys: Vec<&str> = records 941 + .iter() 942 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 943 + .collect(); 944 + 945 + assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)"); 946 + } 947 + 948 + #[tokio::test] 949 + async fn test_list_records_cursor_pagination() { 950 + let client = client(); 951 + let (did, jwt) = setup_new_user("list-cursor").await; 952 + 953 + for i in 0..5 { 954 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 955 + tokio::time::sleep(Duration::from_millis(50)).await; 956 + } 957 + 958 + let res = client 959 + .get(format!( 960 + "{}/xrpc/com.atproto.repo.listRecords", 961 + base_url().await 962 + )) 963 + .query(&[ 964 + ("repo", did.as_str()), 965 + ("collection", "app.bsky.feed.post"), 966 + ("limit", "2"), 967 + ]) 968 + .send() 969 + .await 970 + .expect("Failed to list records"); 971 + 972 + assert_eq!(res.status(), StatusCode::OK); 973 + let body: Value = res.json().await.unwrap(); 974 + let records = body["records"].as_array().unwrap(); 975 + assert_eq!(records.len(), 2); 976 + 977 + let cursor = body["cursor"].as_str().expect("Should have cursor with more records"); 978 + 979 + let res2 = client 980 + .get(format!( 981 + "{}/xrpc/com.atproto.repo.listRecords", 982 + base_url().await 983 + )) 984 + .query(&[ 985 + ("repo", did.as_str()), 986 + ("collection", "app.bsky.feed.post"), 987 + ("limit", "2"), 988 + ("cursor", cursor), 989 + ]) 990 + .send() 991 + .await 992 + .expect("Failed to list records with cursor"); 993 + 994 + assert_eq!(res2.status(), StatusCode::OK); 995 + let body2: Value = res2.json().await.unwrap(); 996 + let records2 = body2["records"].as_array().unwrap(); 997 + assert_eq!(records2.len(), 2); 998 + 999 + let all_uris: Vec<&str> = records 1000 + .iter() 1001 + .chain(records2.iter()) 1002 + .map(|r| r["uri"].as_str().unwrap()) 1003 + .collect(); 1004 + let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 1005 + assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 1006 + } 1007 + 1008 + #[tokio::test] 1009 + async fn test_list_records_rkey_start() { 1010 + let client = client(); 1011 + let (did, jwt) = setup_new_user("list-rkey-start").await; 1012 + 1013 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 1014 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 1015 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 1016 + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 1017 + 1018 + let res = client 1019 + .get(format!( 1020 + "{}/xrpc/com.atproto.repo.listRecords", 1021 + base_url().await 1022 + )) 1023 + .query(&[ 1024 + ("repo", did.as_str()), 1025 + ("collection", "app.bsky.feed.post"), 1026 + ("rkeyStart", "bbbb"), 1027 + ("reverse", "true"), 1028 + ]) 1029 + .send() 1030 + .await 1031 + .expect("Failed to list records"); 1032 + 1033 + assert_eq!(res.status(), StatusCode::OK); 1034 + let body: Value = res.json().await.unwrap(); 1035 + let records = body["records"].as_array().unwrap(); 1036 + 1037 + let rkeys: Vec<&str> = records 1038 + .iter() 1039 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1040 + .collect(); 1041 + 1042 + for rkey in &rkeys { 1043 + assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); 1044 + } 1045 + } 1046 + 1047 + #[tokio::test] 1048 + async fn test_list_records_rkey_end() { 1049 + let client = client(); 1050 + let (did, jwt) = setup_new_user("list-rkey-end").await; 1051 + 1052 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 1053 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 1054 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 1055 + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 1056 + 1057 + let res = client 1058 + .get(format!( 1059 + "{}/xrpc/com.atproto.repo.listRecords", 1060 + base_url().await 1061 + )) 1062 + .query(&[ 1063 + ("repo", did.as_str()), 1064 + ("collection", "app.bsky.feed.post"), 1065 + ("rkeyEnd", "cccc"), 1066 + ("reverse", "true"), 1067 + ]) 1068 + .send() 1069 + .await 1070 + .expect("Failed to list records"); 1071 + 1072 + assert_eq!(res.status(), StatusCode::OK); 1073 + let body: Value = res.json().await.unwrap(); 1074 + let records = body["records"].as_array().unwrap(); 1075 + 1076 + let rkeys: Vec<&str> = records 1077 + .iter() 1078 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1079 + .collect(); 1080 + 1081 + for rkey in &rkeys { 1082 + assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); 1083 + } 1084 + } 1085 + 1086 + #[tokio::test] 1087 + async fn test_list_records_rkey_range() { 1088 + let client = client(); 1089 + let (did, jwt) = setup_new_user("list-rkey-range").await; 1090 + 1091 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 1092 + create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 1093 + create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 1094 + create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 1095 + create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await; 1096 + 1097 + let res = client 1098 + .get(format!( 1099 + "{}/xrpc/com.atproto.repo.listRecords", 1100 + base_url().await 1101 + )) 1102 + .query(&[ 1103 + ("repo", did.as_str()), 1104 + ("collection", "app.bsky.feed.post"), 1105 + ("rkeyStart", "bbbb"), 1106 + ("rkeyEnd", "dddd"), 1107 + ("reverse", "true"), 1108 + ]) 1109 + .send() 1110 + .await 1111 + .expect("Failed to list records"); 1112 + 1113 + assert_eq!(res.status(), StatusCode::OK); 1114 + let body: Value = res.json().await.unwrap(); 1115 + let records = body["records"].as_array().unwrap(); 1116 + 1117 + let rkeys: Vec<&str> = records 1118 + .iter() 1119 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1120 + .collect(); 1121 + 1122 + for rkey in &rkeys { 1123 + assert!(*rkey >= "bbbb" && *rkey <= "dddd", "Range should be inclusive, got {}", rkey); 1124 + } 1125 + assert!(!rkeys.is_empty(), "Should have at least some records in range"); 1126 + } 1127 + 1128 + #[tokio::test] 1129 + async fn test_list_records_limit_clamping_max() { 1130 + let client = client(); 1131 + let (did, jwt) = setup_new_user("list-limit-max").await; 1132 + 1133 + for i in 0..5 { 1134 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 1135 + } 1136 + 1137 + let res = client 1138 + .get(format!( 1139 + "{}/xrpc/com.atproto.repo.listRecords", 1140 + base_url().await 1141 + )) 1142 + .query(&[ 1143 + ("repo", did.as_str()), 1144 + ("collection", "app.bsky.feed.post"), 1145 + ("limit", "1000"), 1146 + ]) 1147 + .send() 1148 + .await 1149 + .expect("Failed to list records"); 1150 + 1151 + assert_eq!(res.status(), StatusCode::OK); 1152 + let body: Value = res.json().await.unwrap(); 1153 + let records = body["records"].as_array().unwrap(); 1154 + assert!(records.len() <= 100, "Limit should be clamped to max 100"); 1155 + } 1156 + 1157 + #[tokio::test] 1158 + async fn test_list_records_limit_clamping_min() { 1159 + let client = client(); 1160 + let (did, jwt) = setup_new_user("list-limit-min").await; 1161 + 1162 + create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await; 1163 + 1164 + let res = client 1165 + .get(format!( 1166 + "{}/xrpc/com.atproto.repo.listRecords", 1167 + base_url().await 1168 + )) 1169 + .query(&[ 1170 + ("repo", did.as_str()), 1171 + ("collection", "app.bsky.feed.post"), 1172 + ("limit", "0"), 1173 + ]) 1174 + .send() 1175 + .await 1176 + .expect("Failed to list records"); 1177 + 1178 + assert_eq!(res.status(), StatusCode::OK); 1179 + let body: Value = res.json().await.unwrap(); 1180 + let records = body["records"].as_array().unwrap(); 1181 + assert!(records.len() >= 1, "Limit should be clamped to min 1"); 1182 + } 1183 + 1184 + #[tokio::test] 1185 + async fn test_list_records_empty_collection() { 1186 + let client = client(); 1187 + let (did, _jwt) = setup_new_user("list-empty").await; 1188 + 1189 + let res = client 1190 + .get(format!( 1191 + "{}/xrpc/com.atproto.repo.listRecords", 1192 + base_url().await 1193 + )) 1194 + .query(&[ 1195 + ("repo", did.as_str()), 1196 + ("collection", "app.bsky.feed.post"), 1197 + ]) 1198 + .send() 1199 + .await 1200 + .expect("Failed to list records"); 1201 + 1202 + assert_eq!(res.status(), StatusCode::OK); 1203 + let body: Value = res.json().await.unwrap(); 1204 + let records = body["records"].as_array().unwrap(); 1205 + assert!(records.is_empty(), "Empty collection should return empty array"); 1206 + assert!(body["cursor"].is_null(), "Empty collection should have no cursor"); 1207 + } 1208 + 1209 + #[tokio::test] 1210 + async fn test_list_records_exact_limit() { 1211 + let client = client(); 1212 + let (did, jwt) = setup_new_user("list-exact-limit").await; 1213 + 1214 + for i in 0..10 { 1215 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 1216 + } 1217 + 1218 + let res = client 1219 + .get(format!( 1220 + "{}/xrpc/com.atproto.repo.listRecords", 1221 + base_url().await 1222 + )) 1223 + .query(&[ 1224 + ("repo", did.as_str()), 1225 + ("collection", "app.bsky.feed.post"), 1226 + ("limit", "5"), 1227 + ]) 1228 + .send() 1229 + .await 1230 + .expect("Failed to list records"); 1231 + 1232 + assert_eq!(res.status(), StatusCode::OK); 1233 + let body: Value = res.json().await.unwrap(); 1234 + let records = body["records"].as_array().unwrap(); 1235 + assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5"); 1236 + } 1237 + 1238 + #[tokio::test] 1239 + async fn test_list_records_cursor_exhaustion() { 1240 + let client = client(); 1241 + let (did, jwt) = setup_new_user("list-cursor-exhaust").await; 1242 + 1243 + for i in 0..3 { 1244 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 1245 + } 1246 + 1247 + let res = client 1248 + .get(format!( 1249 + "{}/xrpc/com.atproto.repo.listRecords", 1250 + base_url().await 1251 + )) 1252 + .query(&[ 1253 + ("repo", did.as_str()), 1254 + ("collection", "app.bsky.feed.post"), 1255 + ("limit", "10"), 1256 + ]) 1257 + .send() 1258 + .await 1259 + .expect("Failed to list records"); 1260 + 1261 + assert_eq!(res.status(), StatusCode::OK); 1262 + let body: Value = res.json().await.unwrap(); 1263 + let records = body["records"].as_array().unwrap(); 1264 + assert_eq!(records.len(), 3); 1265 + } 1266 + 1267 + #[tokio::test] 1268 + async fn test_list_records_repo_not_found() { 1269 + let client = client(); 1270 + 1271 + let res = client 1272 + .get(format!( 1273 + "{}/xrpc/com.atproto.repo.listRecords", 1274 + base_url().await 1275 + )) 1276 + .query(&[ 1277 + ("repo", "did:plc:nonexistent12345"), 1278 + ("collection", "app.bsky.feed.post"), 1279 + ]) 1280 + .send() 1281 + .await 1282 + .expect("Failed to list records"); 1283 + 1284 + assert_eq!(res.status(), StatusCode::NOT_FOUND); 1285 + } 1286 + 1287 + #[tokio::test] 1288 + async fn test_list_records_includes_cid() { 1289 + let client = client(); 1290 + let (did, jwt) = setup_new_user("list-includes-cid").await; 1291 + 1292 + create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await; 1293 + 1294 + let res = client 1295 + .get(format!( 1296 + "{}/xrpc/com.atproto.repo.listRecords", 1297 + base_url().await 1298 + )) 1299 + .query(&[ 1300 + ("repo", did.as_str()), 1301 + ("collection", "app.bsky.feed.post"), 1302 + ]) 1303 + .send() 1304 + .await 1305 + .expect("Failed to list records"); 1306 + 1307 + assert_eq!(res.status(), StatusCode::OK); 1308 + let body: Value = res.json().await.unwrap(); 1309 + let records = body["records"].as_array().unwrap(); 1310 + 1311 + for record in records { 1312 + assert!(record["uri"].is_string(), "Record should have uri"); 1313 + assert!(record["cid"].is_string(), "Record should have cid"); 1314 + assert!(record["value"].is_object(), "Record should have value"); 1315 + let cid = record["cid"].as_str().unwrap(); 1316 + assert!(cid.starts_with("bafy"), "CID should be valid"); 1317 + } 1318 + } 1319 + 1320 + #[tokio::test] 1321 + async fn test_list_records_cursor_with_reverse() { 1322 + let client = client(); 1323 + let (did, jwt) = setup_new_user("list-cursor-reverse").await; 1324 + 1325 + for i in 0..5 { 1326 + create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 1327 + } 1328 + 1329 + let res = client 1330 + .get(format!( 1331 + "{}/xrpc/com.atproto.repo.listRecords", 1332 + base_url().await 1333 + )) 1334 + .query(&[ 1335 + ("repo", did.as_str()), 1336 + ("collection", "app.bsky.feed.post"), 1337 + ("limit", "2"), 1338 + ("reverse", "true"), 1339 + ]) 1340 + .send() 1341 + .await 1342 + .expect("Failed to list records"); 1343 + 1344 + assert_eq!(res.status(), StatusCode::OK); 1345 + let body: Value = res.json().await.unwrap(); 1346 + let records = body["records"].as_array().unwrap(); 1347 + let first_rkeys: Vec<&str> = records 1348 + .iter() 1349 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1350 + .collect(); 1351 + 1352 + assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest"); 1353 + 1354 + if let Some(cursor) = body["cursor"].as_str() { 1355 + let res2 = client 1356 + .get(format!( 1357 + "{}/xrpc/com.atproto.repo.listRecords", 1358 + base_url().await 1359 + )) 1360 + .query(&[ 1361 + ("repo", did.as_str()), 1362 + ("collection", "app.bsky.feed.post"), 1363 + ("limit", "2"), 1364 + ("reverse", "true"), 1365 + ("cursor", cursor), 1366 + ]) 1367 + .send() 1368 + .await 1369 + .expect("Failed to list records with cursor"); 1370 + 1371 + let body2: Value = res2.json().await.unwrap(); 1372 + let records2 = body2["records"].as_array().unwrap(); 1373 + let second_rkeys: Vec<&str> = records2 1374 + .iter() 1375 + .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 1376 + .collect(); 1377 + 1378 + assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order"); 1379 + } 1380 + }
+18 -7
tests/lifecycle_session.rs
··· 58 58 .await 59 59 .expect("Failed to create account"); 60 60 assert_eq!(create_res.status(), StatusCode::OK); 61 + let create_body: Value = create_res.json().await.unwrap(); 62 + let did = create_body["did"].as_str().unwrap(); 63 + 64 + let _ = verify_new_account(&client, did).await; 61 65 62 66 let login_payload = json!({ 63 67 "identifier": handle, ··· 128 132 "email": email, 129 133 "password": password 130 134 }); 131 - client 135 + let create_res = client 132 136 .post(format!( 133 137 "{}/xrpc/com.atproto.server.createAccount", 134 138 base_url().await ··· 137 141 .send() 138 142 .await 139 143 .expect("Failed to create account"); 144 + let create_body: Value = create_res.json().await.unwrap(); 145 + let did = create_body["did"].as_str().unwrap(); 146 + 147 + let _ = verify_new_account(&client, did).await; 140 148 141 149 let login_payload = json!({ 142 150 "identifier": handle, ··· 209 217 210 218 assert_eq!(create_res.status(), StatusCode::OK); 211 219 let account: Value = create_res.json().await.unwrap(); 212 - let jwt = account["accessJwt"].as_str().unwrap(); 220 + let did = account["did"].as_str().unwrap(); 221 + 222 + let jwt = verify_new_account(&client, did).await; 213 223 214 224 let create_app_pass_res = client 215 225 .post(format!( 216 226 "{}/xrpc/com.atproto.server.createAppPassword", 217 227 base_url().await 218 228 )) 219 - .bearer_auth(jwt) 229 + .bearer_auth(&jwt) 220 230 .json(&json!({ "name": "Test App" })) 221 231 .send() 222 232 .await ··· 232 242 "{}/xrpc/com.atproto.server.listAppPasswords", 233 243 base_url().await 234 244 )) 235 - .bearer_auth(jwt) 245 + .bearer_auth(&jwt) 236 246 .send() 237 247 .await 238 248 .expect("Failed to list app passwords"); ··· 263 273 "{}/xrpc/com.atproto.server.revokeAppPassword", 264 274 base_url().await 265 275 )) 266 - .bearer_auth(jwt) 276 + .bearer_auth(&jwt) 267 277 .json(&json!({ "name": "Test App" })) 268 278 .send() 269 279 .await ··· 295 305 "{}/xrpc/com.atproto.server.listAppPasswords", 296 306 base_url().await 297 307 )) 298 - .bearer_auth(jwt) 308 + .bearer_auth(&jwt) 299 309 .send() 300 310 .await 301 311 .expect("Failed to list after revoke"); ··· 330 340 assert_eq!(create_res.status(), StatusCode::OK); 331 341 let account: Value = create_res.json().await.unwrap(); 332 342 let did = account["did"].as_str().unwrap().to_string(); 333 - let jwt = account["accessJwt"].as_str().unwrap().to_string(); 343 + 344 + let jwt = verify_new_account(&client, &did).await; 334 345 335 346 let (post_uri, _) = create_post(&client, &did, &jwt, "Post before deactivation").await; 336 347 let post_rkey = post_uri.split('/').last().unwrap();
+2 -1
tests/lifecycle_social.rs
··· 441 441 assert_eq!(create_account_res.status(), StatusCode::OK); 442 442 let account_body: Value = create_account_res.json().await.unwrap(); 443 443 let did = account_body["did"].as_str().unwrap().to_string(); 444 - let access_jwt = account_body["accessJwt"].as_str().unwrap().to_string(); 444 + 445 + let access_jwt = verify_new_account(&client, &did).await; 445 446 446 447 let get_session_res = client 447 448 .get(format!(
-554
tests/list_records_pagination.rs
··· 1 - mod common; 2 - mod helpers; 3 - use common::*; 4 - use helpers::*; 5 - 6 - use chrono::Utc; 7 - use reqwest::StatusCode; 8 - use serde_json::{Value, json}; 9 - use std::time::Duration; 10 - 11 - async fn create_post_with_rkey( 12 - client: &reqwest::Client, 13 - did: &str, 14 - jwt: &str, 15 - rkey: &str, 16 - text: &str, 17 - ) -> (String, String) { 18 - let payload = json!({ 19 - "repo": did, 20 - "collection": "app.bsky.feed.post", 21 - "rkey": rkey, 22 - "record": { 23 - "$type": "app.bsky.feed.post", 24 - "text": text, 25 - "createdAt": Utc::now().to_rfc3339() 26 - } 27 - }); 28 - 29 - let res = client 30 - .post(format!( 31 - "{}/xrpc/com.atproto.repo.putRecord", 32 - base_url().await 33 - )) 34 - .bearer_auth(jwt) 35 - .json(&payload) 36 - .send() 37 - .await 38 - .expect("Failed to create record"); 39 - 40 - assert_eq!(res.status(), StatusCode::OK); 41 - let body: Value = res.json().await.unwrap(); 42 - ( 43 - body["uri"].as_str().unwrap().to_string(), 44 - body["cid"].as_str().unwrap().to_string(), 45 - ) 46 - } 47 - 48 - #[tokio::test] 49 - async fn test_list_records_default_order() { 50 - let client = client(); 51 - let (did, jwt) = setup_new_user("list-default-order").await; 52 - 53 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; 54 - tokio::time::sleep(Duration::from_millis(50)).await; 55 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; 56 - tokio::time::sleep(Duration::from_millis(50)).await; 57 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; 58 - 59 - let res = client 60 - .get(format!( 61 - "{}/xrpc/com.atproto.repo.listRecords", 62 - base_url().await 63 - )) 64 - .query(&[ 65 - ("repo", did.as_str()), 66 - ("collection", "app.bsky.feed.post"), 67 - ]) 68 - .send() 69 - .await 70 - .expect("Failed to list records"); 71 - 72 - assert_eq!(res.status(), StatusCode::OK); 73 - let body: Value = res.json().await.unwrap(); 74 - let records = body["records"].as_array().unwrap(); 75 - 76 - assert_eq!(records.len(), 3); 77 - let rkeys: Vec<&str> = records 78 - .iter() 79 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 80 - .collect(); 81 - 82 - assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)"); 83 - } 84 - 85 - #[tokio::test] 86 - async fn test_list_records_reverse_true() { 87 - let client = client(); 88 - let (did, jwt) = setup_new_user("list-reverse").await; 89 - 90 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await; 91 - tokio::time::sleep(Duration::from_millis(50)).await; 92 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await; 93 - tokio::time::sleep(Duration::from_millis(50)).await; 94 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await; 95 - 96 - let res = client 97 - .get(format!( 98 - "{}/xrpc/com.atproto.repo.listRecords", 99 - base_url().await 100 - )) 101 - .query(&[ 102 - ("repo", did.as_str()), 103 - ("collection", "app.bsky.feed.post"), 104 - ("reverse", "true"), 105 - ]) 106 - .send() 107 - .await 108 - .expect("Failed to list records"); 109 - 110 - assert_eq!(res.status(), StatusCode::OK); 111 - let body: Value = res.json().await.unwrap(); 112 - let records = body["records"].as_array().unwrap(); 113 - 114 - let rkeys: Vec<&str> = records 115 - .iter() 116 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 117 - .collect(); 118 - 119 - assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)"); 120 - } 121 - 122 - #[tokio::test] 123 - async fn test_list_records_cursor_pagination() { 124 - let client = client(); 125 - let (did, jwt) = setup_new_user("list-cursor").await; 126 - 127 - for i in 0..5 { 128 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 129 - tokio::time::sleep(Duration::from_millis(50)).await; 130 - } 131 - 132 - let res = client 133 - .get(format!( 134 - "{}/xrpc/com.atproto.repo.listRecords", 135 - base_url().await 136 - )) 137 - .query(&[ 138 - ("repo", did.as_str()), 139 - ("collection", "app.bsky.feed.post"), 140 - ("limit", "2"), 141 - ]) 142 - .send() 143 - .await 144 - .expect("Failed to list records"); 145 - 146 - assert_eq!(res.status(), StatusCode::OK); 147 - let body: Value = res.json().await.unwrap(); 148 - let records = body["records"].as_array().unwrap(); 149 - assert_eq!(records.len(), 2); 150 - 151 - let cursor = body["cursor"].as_str().expect("Should have cursor with more records"); 152 - 153 - let res2 = client 154 - .get(format!( 155 - "{}/xrpc/com.atproto.repo.listRecords", 156 - base_url().await 157 - )) 158 - .query(&[ 159 - ("repo", did.as_str()), 160 - ("collection", "app.bsky.feed.post"), 161 - ("limit", "2"), 162 - ("cursor", cursor), 163 - ]) 164 - .send() 165 - .await 166 - .expect("Failed to list records with cursor"); 167 - 168 - assert_eq!(res2.status(), StatusCode::OK); 169 - let body2: Value = res2.json().await.unwrap(); 170 - let records2 = body2["records"].as_array().unwrap(); 171 - assert_eq!(records2.len(), 2); 172 - 173 - let all_uris: Vec<&str> = records 174 - .iter() 175 - .chain(records2.iter()) 176 - .map(|r| r["uri"].as_str().unwrap()) 177 - .collect(); 178 - let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect(); 179 - assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records"); 180 - } 181 - 182 - #[tokio::test] 183 - async fn test_list_records_rkey_start() { 184 - let client = client(); 185 - let (did, jwt) = setup_new_user("list-rkey-start").await; 186 - 187 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 188 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 189 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 190 - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 191 - 192 - let res = client 193 - .get(format!( 194 - "{}/xrpc/com.atproto.repo.listRecords", 195 - base_url().await 196 - )) 197 - .query(&[ 198 - ("repo", did.as_str()), 199 - ("collection", "app.bsky.feed.post"), 200 - ("rkeyStart", "bbbb"), 201 - ("reverse", "true"), 202 - ]) 203 - .send() 204 - .await 205 - .expect("Failed to list records"); 206 - 207 - assert_eq!(res.status(), StatusCode::OK); 208 - let body: Value = res.json().await.unwrap(); 209 - let records = body["records"].as_array().unwrap(); 210 - 211 - let rkeys: Vec<&str> = records 212 - .iter() 213 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 214 - .collect(); 215 - 216 - for rkey in &rkeys { 217 - assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start"); 218 - } 219 - } 220 - 221 - #[tokio::test] 222 - async fn test_list_records_rkey_end() { 223 - let client = client(); 224 - let (did, jwt) = setup_new_user("list-rkey-end").await; 225 - 226 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 227 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 228 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 229 - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 230 - 231 - let res = client 232 - .get(format!( 233 - "{}/xrpc/com.atproto.repo.listRecords", 234 - base_url().await 235 - )) 236 - .query(&[ 237 - ("repo", did.as_str()), 238 - ("collection", "app.bsky.feed.post"), 239 - ("rkeyEnd", "cccc"), 240 - ("reverse", "true"), 241 - ]) 242 - .send() 243 - .await 244 - .expect("Failed to list records"); 245 - 246 - assert_eq!(res.status(), StatusCode::OK); 247 - let body: Value = res.json().await.unwrap(); 248 - let records = body["records"].as_array().unwrap(); 249 - 250 - let rkeys: Vec<&str> = records 251 - .iter() 252 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 253 - .collect(); 254 - 255 - for rkey in &rkeys { 256 - assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end"); 257 - } 258 - } 259 - 260 - #[tokio::test] 261 - async fn test_list_records_rkey_range() { 262 - let client = client(); 263 - let (did, jwt) = setup_new_user("list-rkey-range").await; 264 - 265 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await; 266 - create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await; 267 - create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await; 268 - create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await; 269 - create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await; 270 - 271 - let res = client 272 - .get(format!( 273 - "{}/xrpc/com.atproto.repo.listRecords", 274 - base_url().await 275 - )) 276 - .query(&[ 277 - ("repo", did.as_str()), 278 - ("collection", "app.bsky.feed.post"), 279 - ("rkeyStart", "bbbb"), 280 - ("rkeyEnd", "dddd"), 281 - ("reverse", "true"), 282 - ]) 283 - .send() 284 - .await 285 - .expect("Failed to list records"); 286 - 287 - assert_eq!(res.status(), StatusCode::OK); 288 - let body: Value = res.json().await.unwrap(); 289 - let records = body["records"].as_array().unwrap(); 290 - 291 - let rkeys: Vec<&str> = records 292 - .iter() 293 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 294 - .collect(); 295 - 296 - for rkey in &rkeys { 297 - assert!(*rkey >= "bbbb" && *rkey <= "dddd", "Range should be inclusive, got {}", rkey); 298 - } 299 - assert!(!rkeys.is_empty(), "Should have at least some records in range"); 300 - } 301 - 302 - #[tokio::test] 303 - async fn test_list_records_limit_clamping_max() { 304 - let client = client(); 305 - let (did, jwt) = setup_new_user("list-limit-max").await; 306 - 307 - for i in 0..5 { 308 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 309 - } 310 - 311 - let res = client 312 - .get(format!( 313 - "{}/xrpc/com.atproto.repo.listRecords", 314 - base_url().await 315 - )) 316 - .query(&[ 317 - ("repo", did.as_str()), 318 - ("collection", "app.bsky.feed.post"), 319 - ("limit", "1000"), 320 - ]) 321 - .send() 322 - .await 323 - .expect("Failed to list records"); 324 - 325 - assert_eq!(res.status(), StatusCode::OK); 326 - let body: Value = res.json().await.unwrap(); 327 - let records = body["records"].as_array().unwrap(); 328 - assert!(records.len() <= 100, "Limit should be clamped to max 100"); 329 - } 330 - 331 - #[tokio::test] 332 - async fn test_list_records_limit_clamping_min() { 333 - let client = client(); 334 - let (did, jwt) = setup_new_user("list-limit-min").await; 335 - 336 - create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await; 337 - 338 - let res = client 339 - .get(format!( 340 - "{}/xrpc/com.atproto.repo.listRecords", 341 - base_url().await 342 - )) 343 - .query(&[ 344 - ("repo", did.as_str()), 345 - ("collection", "app.bsky.feed.post"), 346 - ("limit", "0"), 347 - ]) 348 - .send() 349 - .await 350 - .expect("Failed to list records"); 351 - 352 - assert_eq!(res.status(), StatusCode::OK); 353 - let body: Value = res.json().await.unwrap(); 354 - let records = body["records"].as_array().unwrap(); 355 - assert!(records.len() >= 1, "Limit should be clamped to min 1"); 356 - } 357 - 358 - #[tokio::test] 359 - async fn test_list_records_empty_collection() { 360 - let client = client(); 361 - let (did, _jwt) = setup_new_user("list-empty").await; 362 - 363 - let res = client 364 - .get(format!( 365 - "{}/xrpc/com.atproto.repo.listRecords", 366 - base_url().await 367 - )) 368 - .query(&[ 369 - ("repo", did.as_str()), 370 - ("collection", "app.bsky.feed.post"), 371 - ]) 372 - .send() 373 - .await 374 - .expect("Failed to list records"); 375 - 376 - assert_eq!(res.status(), StatusCode::OK); 377 - let body: Value = res.json().await.unwrap(); 378 - let records = body["records"].as_array().unwrap(); 379 - assert!(records.is_empty(), "Empty collection should return empty array"); 380 - assert!(body["cursor"].is_null(), "Empty collection should have no cursor"); 381 - } 382 - 383 - #[tokio::test] 384 - async fn test_list_records_exact_limit() { 385 - let client = client(); 386 - let (did, jwt) = setup_new_user("list-exact-limit").await; 387 - 388 - for i in 0..10 { 389 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 390 - } 391 - 392 - let res = client 393 - .get(format!( 394 - "{}/xrpc/com.atproto.repo.listRecords", 395 - base_url().await 396 - )) 397 - .query(&[ 398 - ("repo", did.as_str()), 399 - ("collection", "app.bsky.feed.post"), 400 - ("limit", "5"), 401 - ]) 402 - .send() 403 - .await 404 - .expect("Failed to list records"); 405 - 406 - assert_eq!(res.status(), StatusCode::OK); 407 - let body: Value = res.json().await.unwrap(); 408 - let records = body["records"].as_array().unwrap(); 409 - assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5"); 410 - } 411 - 412 - #[tokio::test] 413 - async fn test_list_records_cursor_exhaustion() { 414 - let client = client(); 415 - let (did, jwt) = setup_new_user("list-cursor-exhaust").await; 416 - 417 - for i in 0..3 { 418 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 419 - } 420 - 421 - let res = client 422 - .get(format!( 423 - "{}/xrpc/com.atproto.repo.listRecords", 424 - base_url().await 425 - )) 426 - .query(&[ 427 - ("repo", did.as_str()), 428 - ("collection", "app.bsky.feed.post"), 429 - ("limit", "10"), 430 - ]) 431 - .send() 432 - .await 433 - .expect("Failed to list records"); 434 - 435 - assert_eq!(res.status(), StatusCode::OK); 436 - let body: Value = res.json().await.unwrap(); 437 - let records = body["records"].as_array().unwrap(); 438 - assert_eq!(records.len(), 3); 439 - } 440 - 441 - #[tokio::test] 442 - async fn test_list_records_repo_not_found() { 443 - let client = client(); 444 - 445 - let res = client 446 - .get(format!( 447 - "{}/xrpc/com.atproto.repo.listRecords", 448 - base_url().await 449 - )) 450 - .query(&[ 451 - ("repo", "did:plc:nonexistent12345"), 452 - ("collection", "app.bsky.feed.post"), 453 - ]) 454 - .send() 455 - .await 456 - .expect("Failed to list records"); 457 - 458 - assert_eq!(res.status(), StatusCode::NOT_FOUND); 459 - } 460 - 461 - #[tokio::test] 462 - async fn test_list_records_includes_cid() { 463 - let client = client(); 464 - let (did, jwt) = setup_new_user("list-includes-cid").await; 465 - 466 - create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await; 467 - 468 - let res = client 469 - .get(format!( 470 - "{}/xrpc/com.atproto.repo.listRecords", 471 - base_url().await 472 - )) 473 - .query(&[ 474 - ("repo", did.as_str()), 475 - ("collection", "app.bsky.feed.post"), 476 - ]) 477 - .send() 478 - .await 479 - .expect("Failed to list records"); 480 - 481 - assert_eq!(res.status(), StatusCode::OK); 482 - let body: Value = res.json().await.unwrap(); 483 - let records = body["records"].as_array().unwrap(); 484 - 485 - for record in records { 486 - assert!(record["uri"].is_string(), "Record should have uri"); 487 - assert!(record["cid"].is_string(), "Record should have cid"); 488 - assert!(record["value"].is_object(), "Record should have value"); 489 - let cid = record["cid"].as_str().unwrap(); 490 - assert!(cid.starts_with("bafy"), "CID should be valid"); 491 - } 492 - } 493 - 494 - #[tokio::test] 495 - async fn test_list_records_cursor_with_reverse() { 496 - let client = client(); 497 - let (did, jwt) = setup_new_user("list-cursor-reverse").await; 498 - 499 - for i in 0..5 { 500 - create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await; 501 - } 502 - 503 - let res = client 504 - .get(format!( 505 - "{}/xrpc/com.atproto.repo.listRecords", 506 - base_url().await 507 - )) 508 - .query(&[ 509 - ("repo", did.as_str()), 510 - ("collection", "app.bsky.feed.post"), 511 - ("limit", "2"), 512 - ("reverse", "true"), 513 - ]) 514 - .send() 515 - .await 516 - .expect("Failed to list records"); 517 - 518 - assert_eq!(res.status(), StatusCode::OK); 519 - let body: Value = res.json().await.unwrap(); 520 - let records = body["records"].as_array().unwrap(); 521 - let first_rkeys: Vec<&str> = records 522 - .iter() 523 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 524 - .collect(); 525 - 526 - assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest"); 527 - 528 - if let Some(cursor) = body["cursor"].as_str() { 529 - let res2 = client 530 - .get(format!( 531 - "{}/xrpc/com.atproto.repo.listRecords", 532 - base_url().await 533 - )) 534 - .query(&[ 535 - ("repo", did.as_str()), 536 - ("collection", "app.bsky.feed.post"), 537 - ("limit", "2"), 538 - ("reverse", "true"), 539 - ("cursor", cursor), 540 - ]) 541 - .send() 542 - .await 543 - .expect("Failed to list records with cursor"); 544 - 545 - let body2: Value = res2.json().await.unwrap(); 546 - let records2 = body2["records"].as_array().unwrap(); 547 - let second_rkeys: Vec<&str> = records2 548 - .iter() 549 - .map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()) 550 - .collect(); 551 - 552 - assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order"); 553 - } 554 - }
+1 -1
tests/notifications.rs
··· 92 92 .await 93 93 .expect("Notification not found"); 94 94 95 - assert_eq!(row.recipient, user_row.email); 95 + assert_eq!(Some(row.recipient), user_row.email); 96 96 assert_eq!(row.subject.as_deref(), Some("Welcome to example.com")); 97 97 assert!(row.body.contains(&format!("@{}", user_row.handle))); 98 98 assert_eq!(row.notification_type, NotificationType::Welcome);
-456
tests/oauth.rs
··· 206 206 } 207 207 208 208 #[tokio::test] 209 - async fn test_par_requires_pkce() { 210 - let url = base_url().await; 211 - let client = client(); 212 - 213 - let redirect_uri = "https://example.com/callback"; 214 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 215 - let client_id = mock_client.uri(); 216 - 217 - let res = client 218 - .post(format!("{}/oauth/par", url)) 219 - .form(&[ 220 - ("response_type", "code"), 221 - ("client_id", &client_id), 222 - ("redirect_uri", redirect_uri), 223 - ("scope", "atproto"), 224 - ]) 225 - .send() 226 - .await 227 - .expect("Failed to send PAR request"); 228 - 229 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 230 - 231 - let body: Value = res.json().await.expect("Invalid JSON"); 232 - assert_eq!(body["error"], "invalid_request"); 233 - } 234 - 235 - #[tokio::test] 236 - async fn test_par_requires_s256() { 237 - let url = base_url().await; 238 - let client = client(); 239 - 240 - let redirect_uri = "https://example.com/callback"; 241 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 242 - let client_id = mock_client.uri(); 243 - 244 - let res = client 245 - .post(format!("{}/oauth/par", url)) 246 - .form(&[ 247 - ("response_type", "code"), 248 - ("client_id", &client_id), 249 - ("redirect_uri", redirect_uri), 250 - ("code_challenge", "test-challenge"), 251 - ("code_challenge_method", "plain"), 252 - ]) 253 - .send() 254 - .await 255 - .expect("Failed to send PAR request"); 256 - 257 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 258 - 259 - let body: Value = res.json().await.expect("Invalid JSON"); 260 - assert_eq!(body["error"], "invalid_request"); 261 - assert!(body["error_description"].as_str().unwrap().contains("S256")); 262 - } 263 - 264 - #[tokio::test] 265 - async fn test_par_validates_redirect_uri() { 266 - let url = base_url().await; 267 - let client = client(); 268 - 269 - let registered_redirect = "https://example.com/callback"; 270 - let wrong_redirect = "https://evil.com/steal"; 271 - let mock_client = setup_mock_client_metadata(registered_redirect).await; 272 - let client_id = mock_client.uri(); 273 - 274 - let (_, code_challenge) = generate_pkce(); 275 - 276 - let res = client 277 - .post(format!("{}/oauth/par", url)) 278 - .form(&[ 279 - ("response_type", "code"), 280 - ("client_id", &client_id), 281 - ("redirect_uri", wrong_redirect), 282 - ("code_challenge", &code_challenge), 283 - ("code_challenge_method", "S256"), 284 - ]) 285 - .send() 286 - .await 287 - .expect("Failed to send PAR request"); 288 - 289 - assert_eq!(res.status(), StatusCode::BAD_REQUEST); 290 - 291 - let body: Value = res.json().await.expect("Invalid JSON"); 292 - assert_eq!(body["error"], "invalid_request"); 293 - } 294 - 295 - #[tokio::test] 296 209 async fn test_authorize_get_with_valid_request_uri() { 297 210 let url = base_url().await; 298 211 let client = client(); ··· 604 517 } 605 518 606 519 #[tokio::test] 607 - async fn test_refresh_token_reuse_detection() { 608 - let url = base_url().await; 609 - let http_client = client(); 610 - 611 - let ts = Utc::now().timestamp_millis(); 612 - let handle = format!("reuse-test-{}", ts); 613 - let email = format!("reuse-test-{}@example.com", ts); 614 - let password = "reuse-test-password"; 615 - 616 - http_client 617 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 618 - .json(&json!({ 619 - "handle": handle, 620 - "email": email, 621 - "password": password 622 - })) 623 - .send() 624 - .await 625 - .unwrap(); 626 - 627 - let redirect_uri = "https://example.com/reuse-callback"; 628 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 629 - let client_id = mock_client.uri(); 630 - 631 - let (code_verifier, code_challenge) = generate_pkce(); 632 - 633 - let par_body: Value = http_client 634 - .post(format!("{}/oauth/par", url)) 635 - .form(&[ 636 - ("response_type", "code"), 637 - ("client_id", &client_id), 638 - ("redirect_uri", redirect_uri), 639 - ("code_challenge", &code_challenge), 640 - ("code_challenge_method", "S256"), 641 - ]) 642 - .send() 643 - .await 644 - .unwrap() 645 - .json() 646 - .await 647 - .unwrap(); 648 - 649 - let request_uri = par_body["request_uri"].as_str().unwrap(); 650 - 651 - let auth_client = no_redirect_client(); 652 - let auth_res = auth_client 653 - .post(format!("{}/oauth/authorize", url)) 654 - .form(&[ 655 - ("request_uri", request_uri), 656 - ("username", &handle), 657 - ("password", password), 658 - ("remember_device", "false"), 659 - ]) 660 - .send() 661 - .await 662 - .unwrap(); 663 - 664 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 665 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 666 - 667 - let token_body: Value = http_client 668 - .post(format!("{}/oauth/token", url)) 669 - .form(&[ 670 - ("grant_type", "authorization_code"), 671 - ("code", code), 672 - ("redirect_uri", redirect_uri), 673 - ("code_verifier", &code_verifier), 674 - ("client_id", &client_id), 675 - ]) 676 - .send() 677 - .await 678 - .unwrap() 679 - .json() 680 - .await 681 - .unwrap(); 682 - 683 - let original_refresh_token = token_body["refresh_token"].as_str().unwrap().to_string(); 684 - 685 - let first_refresh: Value = http_client 686 - .post(format!("{}/oauth/token", url)) 687 - .form(&[ 688 - ("grant_type", "refresh_token"), 689 - ("refresh_token", &original_refresh_token), 690 - ("client_id", &client_id), 691 - ]) 692 - .send() 693 - .await 694 - .unwrap() 695 - .json() 696 - .await 697 - .unwrap(); 698 - 699 - assert!(first_refresh["access_token"].is_string(), "First refresh should succeed"); 700 - 701 - let reuse_res = http_client 702 - .post(format!("{}/oauth/token", url)) 703 - .form(&[ 704 - ("grant_type", "refresh_token"), 705 - ("refresh_token", &original_refresh_token), 706 - ("client_id", &client_id), 707 - ]) 708 - .send() 709 - .await 710 - .unwrap(); 711 - 712 - assert_eq!(reuse_res.status(), StatusCode::BAD_REQUEST, "Reuse should be rejected"); 713 - 714 - let reuse_body: Value = reuse_res.json().await.unwrap(); 715 - assert_eq!(reuse_body["error"], "invalid_grant"); 716 - assert!( 717 - reuse_body["error_description"].as_str().unwrap().to_lowercase().contains("reuse"), 718 - "Error should mention reuse" 719 - ); 720 - } 721 - 722 - #[tokio::test] 723 - async fn test_pkce_verification() { 724 - let url = base_url().await; 725 - let http_client = client(); 726 - 727 - let ts = Utc::now().timestamp_millis(); 728 - let handle = format!("pkce-test-{}", ts); 729 - let email = format!("pkce-test-{}@example.com", ts); 730 - let password = "pkce-test-password"; 731 - 732 - http_client 733 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 734 - .json(&json!({ 735 - "handle": handle, 736 - "email": email, 737 - "password": password 738 - })) 739 - .send() 740 - .await 741 - .unwrap(); 742 - 743 - let redirect_uri = "https://example.com/pkce-callback"; 744 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 745 - let client_id = mock_client.uri(); 746 - 747 - let (_, code_challenge) = generate_pkce(); 748 - let wrong_verifier = "wrong-code-verifier-that-does-not-match"; 749 - 750 - let par_body: Value = http_client 751 - .post(format!("{}/oauth/par", url)) 752 - .form(&[ 753 - ("response_type", "code"), 754 - ("client_id", &client_id), 755 - ("redirect_uri", redirect_uri), 756 - ("code_challenge", &code_challenge), 757 - ("code_challenge_method", "S256"), 758 - ]) 759 - .send() 760 - .await 761 - .unwrap() 762 - .json() 763 - .await 764 - .unwrap(); 765 - 766 - let request_uri = par_body["request_uri"].as_str().unwrap(); 767 - 768 - let auth_client = no_redirect_client(); 769 - let auth_res = auth_client 770 - .post(format!("{}/oauth/authorize", url)) 771 - .form(&[ 772 - ("request_uri", request_uri), 773 - ("username", &handle), 774 - ("password", password), 775 - ("remember_device", "false"), 776 - ]) 777 - .send() 778 - .await 779 - .unwrap(); 780 - 781 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 782 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 783 - 784 - let token_res = http_client 785 - .post(format!("{}/oauth/token", url)) 786 - .form(&[ 787 - ("grant_type", "authorization_code"), 788 - ("code", code), 789 - ("redirect_uri", redirect_uri), 790 - ("code_verifier", wrong_verifier), 791 - ("client_id", &client_id), 792 - ]) 793 - .send() 794 - .await 795 - .unwrap(); 796 - 797 - assert_eq!(token_res.status(), StatusCode::BAD_REQUEST); 798 - 799 - let token_body: Value = token_res.json().await.unwrap(); 800 - assert_eq!(token_body["error"], "invalid_grant"); 801 - assert!(token_body["error_description"].as_str().unwrap().contains("PKCE")); 802 - } 803 - 804 - #[tokio::test] 805 - async fn test_authorization_code_cannot_be_reused() { 806 - let url = base_url().await; 807 - let http_client = client(); 808 - 809 - let ts = Utc::now().timestamp_millis(); 810 - let handle = format!("code-reuse-{}", ts); 811 - let email = format!("code-reuse-{}@example.com", ts); 812 - let password = "code-reuse-password"; 813 - 814 - http_client 815 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 816 - .json(&json!({ 817 - "handle": handle, 818 - "email": email, 819 - "password": password 820 - })) 821 - .send() 822 - .await 823 - .unwrap(); 824 - 825 - let redirect_uri = "https://example.com/code-reuse-callback"; 826 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 827 - let client_id = mock_client.uri(); 828 - 829 - let (code_verifier, code_challenge) = generate_pkce(); 830 - 831 - let par_body: Value = http_client 832 - .post(format!("{}/oauth/par", url)) 833 - .form(&[ 834 - ("response_type", "code"), 835 - ("client_id", &client_id), 836 - ("redirect_uri", redirect_uri), 837 - ("code_challenge", &code_challenge), 838 - ("code_challenge_method", "S256"), 839 - ]) 840 - .send() 841 - .await 842 - .unwrap() 843 - .json() 844 - .await 845 - .unwrap(); 846 - 847 - let request_uri = par_body["request_uri"].as_str().unwrap(); 848 - 849 - let auth_client = no_redirect_client(); 850 - let auth_res = auth_client 851 - .post(format!("{}/oauth/authorize", url)) 852 - .form(&[ 853 - ("request_uri", request_uri), 854 - ("username", &handle), 855 - ("password", password), 856 - ("remember_device", "false"), 857 - ]) 858 - .send() 859 - .await 860 - .unwrap(); 861 - 862 - let location = auth_res.headers().get("location").unwrap().to_str().unwrap(); 863 - let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 864 - 865 - let first_token_res = http_client 866 - .post(format!("{}/oauth/token", url)) 867 - .form(&[ 868 - ("grant_type", "authorization_code"), 869 - ("code", code), 870 - ("redirect_uri", redirect_uri), 871 - ("code_verifier", &code_verifier), 872 - ("client_id", &client_id), 873 - ]) 874 - .send() 875 - .await 876 - .unwrap(); 877 - 878 - assert_eq!(first_token_res.status(), StatusCode::OK, "First use should succeed"); 879 - 880 - let second_token_res = http_client 881 - .post(format!("{}/oauth/token", url)) 882 - .form(&[ 883 - ("grant_type", "authorization_code"), 884 - ("code", code), 885 - ("redirect_uri", redirect_uri), 886 - ("code_verifier", &code_verifier), 887 - ("client_id", &client_id), 888 - ]) 889 - .send() 890 - .await 891 - .unwrap(); 892 - 893 - assert_eq!(second_token_res.status(), StatusCode::BAD_REQUEST, "Second use should fail"); 894 - 895 - let error_body: Value = second_token_res.json().await.unwrap(); 896 - assert_eq!(error_body["error"], "invalid_grant"); 897 - } 898 - 899 - #[tokio::test] 900 520 async fn test_wrong_credentials_denied() { 901 521 let url = base_url().await; 902 522 let http_client = client(); ··· 1103 723 1104 724 let body: Value = res.json().await.unwrap(); 1105 725 assert_eq!(body["error"], "invalid_grant"); 1106 - } 1107 - 1108 - #[tokio::test] 1109 - async fn test_deactivated_account_cannot_authorize() { 1110 - let url = base_url().await; 1111 - let http_client = client(); 1112 - 1113 - let ts = Utc::now().timestamp_millis(); 1114 - let handle = format!("deact-oauth-{}", ts); 1115 - let email = format!("deact-oauth-{}@example.com", ts); 1116 - let password = "deact-oauth-password"; 1117 - 1118 - let create_res = http_client 1119 - .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 1120 - .json(&json!({ 1121 - "handle": handle, 1122 - "email": email, 1123 - "password": password 1124 - })) 1125 - .send() 1126 - .await 1127 - .unwrap(); 1128 - 1129 - assert_eq!(create_res.status(), StatusCode::OK); 1130 - let account: Value = create_res.json().await.unwrap(); 1131 - let access_jwt = account["accessJwt"].as_str().unwrap(); 1132 - 1133 - let deact_res = http_client 1134 - .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) 1135 - .header("Authorization", format!("Bearer {}", access_jwt)) 1136 - .json(&json!({})) 1137 - .send() 1138 - .await 1139 - .unwrap(); 1140 - assert_eq!(deact_res.status(), StatusCode::OK); 1141 - 1142 - let redirect_uri = "https://example.com/deact-callback"; 1143 - let mock_client = setup_mock_client_metadata(redirect_uri).await; 1144 - let client_id = mock_client.uri(); 1145 - 1146 - let (_, code_challenge) = generate_pkce(); 1147 - 1148 - let par_body: Value = http_client 1149 - .post(format!("{}/oauth/par", url)) 1150 - .form(&[ 1151 - ("response_type", "code"), 1152 - ("client_id", &client_id), 1153 - ("redirect_uri", redirect_uri), 1154 - ("code_challenge", &code_challenge), 1155 - ("code_challenge_method", "S256"), 1156 - ]) 1157 - .send() 1158 - .await 1159 - .unwrap() 1160 - .json() 1161 - .await 1162 - .unwrap(); 1163 - 1164 - let request_uri = par_body["request_uri"].as_str().unwrap(); 1165 - 1166 - let auth_res = http_client 1167 - .post(format!("{}/oauth/authorize", url)) 1168 - .header("Accept", "application/json") 1169 - .form(&[ 1170 - ("request_uri", request_uri), 1171 - ("username", &handle), 1172 - ("password", password), 1173 - ("remember_device", "false"), 1174 - ]) 1175 - .send() 1176 - .await 1177 - .unwrap(); 1178 - 1179 - assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should not be able to authorize"); 1180 - let body: Value = auth_res.json().await.unwrap(); 1181 - assert_eq!(body["error"], "access_denied"); 1182 726 } 1183 727 1184 728 #[tokio::test]
-357
tests/oauth_dpop.rs
··· 1 - use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 2 - use bspds::oauth::dpop::{DPoPVerifier, compute_jwk_thumbprint, DPoPJwk}; 3 - use chrono::Utc; 4 - use serde_json::json; 5 - 6 - fn create_dpop_proof( 7 - method: &str, 8 - uri: &str, 9 - nonce: Option<&str>, 10 - ath: Option<&str>, 11 - iat_offset_secs: i64, 12 - ) -> String { 13 - use p256::ecdsa::{SigningKey, Signature, signature::Signer}; 14 - 15 - let signing_key = SigningKey::random(&mut rand::thread_rng()); 16 - let verifying_key = signing_key.verifying_key(); 17 - let point = verifying_key.to_encoded_point(false); 18 - 19 - let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 20 - let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 21 - 22 - let jwk = json!({ 23 - "kty": "EC", 24 - "crv": "P-256", 25 - "x": x, 26 - "y": y 27 - }); 28 - 29 - let header = json!({ 30 - "typ": "dpop+jwt", 31 - "alg": "ES256", 32 - "jwk": jwk 33 - }); 34 - 35 - let mut payload = json!({ 36 - "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 37 - "htm": method, 38 - "htu": uri, 39 - "iat": Utc::now().timestamp() + iat_offset_secs 40 - }); 41 - 42 - if let Some(n) = nonce { 43 - payload["nonce"] = json!(n); 44 - } 45 - 46 - if let Some(a) = ath { 47 - payload["ath"] = json!(a); 48 - } 49 - 50 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 51 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 52 - 53 - let signing_input = format!("{}.{}", header_b64, payload_b64); 54 - let signature: Signature = signing_key.sign(signing_input.as_bytes()); 55 - let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 56 - 57 - format!("{}.{}", signing_input, signature_b64) 58 - } 59 - 60 - #[test] 61 - fn test_dpop_nonce_generation() { 62 - let secret = b"test-dpop-secret-32-bytes-long!!"; 63 - let verifier = DPoPVerifier::new(secret); 64 - 65 - let nonce1 = verifier.generate_nonce(); 66 - let nonce2 = verifier.generate_nonce(); 67 - 68 - assert!(!nonce1.is_empty()); 69 - assert!(!nonce2.is_empty()); 70 - } 71 - 72 - #[test] 73 - fn test_dpop_nonce_validation_success() { 74 - let secret = b"test-dpop-secret-32-bytes-long!!"; 75 - let verifier = DPoPVerifier::new(secret); 76 - 77 - let nonce = verifier.generate_nonce(); 78 - let result = verifier.validate_nonce(&nonce); 79 - 80 - assert!(result.is_ok(), "Valid nonce should pass: {:?}", result); 81 - } 82 - 83 - #[test] 84 - fn test_dpop_nonce_wrong_secret() { 85 - let secret1 = b"test-dpop-secret-32-bytes-long!!"; 86 - let secret2 = b"different-secret-32-bytes-long!!"; 87 - 88 - let verifier1 = DPoPVerifier::new(secret1); 89 - let verifier2 = DPoPVerifier::new(secret2); 90 - 91 - let nonce = verifier1.generate_nonce(); 92 - let result = verifier2.validate_nonce(&nonce); 93 - 94 - assert!(result.is_err(), "Nonce from different secret should fail"); 95 - } 96 - 97 - #[test] 98 - fn test_dpop_nonce_invalid_format() { 99 - let secret = b"test-dpop-secret-32-bytes-long!!"; 100 - let verifier = DPoPVerifier::new(secret); 101 - 102 - assert!(verifier.validate_nonce("invalid").is_err()); 103 - assert!(verifier.validate_nonce("").is_err()); 104 - assert!(verifier.validate_nonce("!!!not-base64!!!").is_err()); 105 - } 106 - 107 - #[test] 108 - fn test_jwk_thumbprint_ec_p256() { 109 - let jwk = DPoPJwk { 110 - kty: "EC".to_string(), 111 - crv: Some("P-256".to_string()), 112 - x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 113 - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 114 - }; 115 - 116 - let thumbprint = compute_jwk_thumbprint(&jwk); 117 - assert!(thumbprint.is_ok()); 118 - 119 - let tp = thumbprint.unwrap(); 120 - assert!(!tp.is_empty()); 121 - assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_')); 122 - } 123 - 124 - #[test] 125 - fn test_jwk_thumbprint_ec_secp256k1() { 126 - let jwk = DPoPJwk { 127 - kty: "EC".to_string(), 128 - crv: Some("secp256k1".to_string()), 129 - x: Some("some_x_value".to_string()), 130 - y: Some("some_y_value".to_string()), 131 - }; 132 - 133 - let thumbprint = compute_jwk_thumbprint(&jwk); 134 - assert!(thumbprint.is_ok()); 135 - } 136 - 137 - #[test] 138 - fn test_jwk_thumbprint_okp_ed25519() { 139 - let jwk = DPoPJwk { 140 - kty: "OKP".to_string(), 141 - crv: Some("Ed25519".to_string()), 142 - x: Some("some_x_value".to_string()), 143 - y: None, 144 - }; 145 - 146 - let thumbprint = compute_jwk_thumbprint(&jwk); 147 - assert!(thumbprint.is_ok()); 148 - } 149 - 150 - #[test] 151 - fn test_jwk_thumbprint_missing_crv() { 152 - let jwk = DPoPJwk { 153 - kty: "EC".to_string(), 154 - crv: None, 155 - x: Some("x".to_string()), 156 - y: Some("y".to_string()), 157 - }; 158 - 159 - let thumbprint = compute_jwk_thumbprint(&jwk); 160 - assert!(thumbprint.is_err()); 161 - } 162 - 163 - #[test] 164 - fn test_jwk_thumbprint_missing_x() { 165 - let jwk = DPoPJwk { 166 - kty: "EC".to_string(), 167 - crv: Some("P-256".to_string()), 168 - x: None, 169 - y: Some("y".to_string()), 170 - }; 171 - 172 - let thumbprint = compute_jwk_thumbprint(&jwk); 173 - assert!(thumbprint.is_err()); 174 - } 175 - 176 - #[test] 177 - fn test_jwk_thumbprint_missing_y_for_ec() { 178 - let jwk = DPoPJwk { 179 - kty: "EC".to_string(), 180 - crv: Some("P-256".to_string()), 181 - x: Some("x".to_string()), 182 - y: None, 183 - }; 184 - 185 - let thumbprint = compute_jwk_thumbprint(&jwk); 186 - assert!(thumbprint.is_err()); 187 - } 188 - 189 - #[test] 190 - fn test_jwk_thumbprint_unsupported_key_type() { 191 - let jwk = DPoPJwk { 192 - kty: "RSA".to_string(), 193 - crv: None, 194 - x: None, 195 - y: None, 196 - }; 197 - 198 - let thumbprint = compute_jwk_thumbprint(&jwk); 199 - assert!(thumbprint.is_err()); 200 - } 201 - 202 - #[test] 203 - fn test_jwk_thumbprint_deterministic() { 204 - let jwk = DPoPJwk { 205 - kty: "EC".to_string(), 206 - crv: Some("P-256".to_string()), 207 - x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 208 - y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 209 - }; 210 - 211 - let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); 212 - let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 213 - 214 - assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 215 - } 216 - 217 - #[test] 218 - fn test_dpop_proof_invalid_format() { 219 - let secret = b"test-dpop-secret-32-bytes-long!!"; 220 - let verifier = DPoPVerifier::new(secret); 221 - 222 - let result = verifier.verify_proof("not.enough.parts", "POST", "https://example.com", None); 223 - assert!(result.is_err()); 224 - 225 - let result = verifier.verify_proof("invalid", "POST", "https://example.com", None); 226 - assert!(result.is_err()); 227 - } 228 - 229 - #[test] 230 - fn test_dpop_proof_invalid_typ() { 231 - let secret = b"test-dpop-secret-32-bytes-long!!"; 232 - let verifier = DPoPVerifier::new(secret); 233 - 234 - let header = json!({ 235 - "typ": "JWT", 236 - "alg": "ES256", 237 - "jwk": { 238 - "kty": "EC", 239 - "crv": "P-256", 240 - "x": "x", 241 - "y": "y" 242 - } 243 - }); 244 - 245 - let payload = json!({ 246 - "jti": "unique", 247 - "htm": "POST", 248 - "htu": "https://example.com", 249 - "iat": Utc::now().timestamp() 250 - }); 251 - 252 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 253 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 254 - let proof = format!("{}.{}.sig", header_b64, payload_b64); 255 - 256 - let result = verifier.verify_proof(&proof, "POST", "https://example.com", None); 257 - assert!(result.is_err()); 258 - } 259 - 260 - #[test] 261 - fn test_dpop_proof_method_mismatch() { 262 - let secret = b"test-dpop-secret-32-bytes-long!!"; 263 - let verifier = DPoPVerifier::new(secret); 264 - 265 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 266 - 267 - let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None); 268 - assert!(result.is_err()); 269 - } 270 - 271 - #[test] 272 - fn test_dpop_proof_uri_mismatch() { 273 - let secret = b"test-dpop-secret-32-bytes-long!!"; 274 - let verifier = DPoPVerifier::new(secret); 275 - 276 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 277 - 278 - let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None); 279 - assert!(result.is_err()); 280 - } 281 - 282 - #[test] 283 - fn test_dpop_proof_iat_too_old() { 284 - let secret = b"test-dpop-secret-32-bytes-long!!"; 285 - let verifier = DPoPVerifier::new(secret); 286 - 287 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); 288 - 289 - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 290 - assert!(result.is_err()); 291 - } 292 - 293 - #[test] 294 - fn test_dpop_proof_iat_future() { 295 - let secret = b"test-dpop-secret-32-bytes-long!!"; 296 - let verifier = DPoPVerifier::new(secret); 297 - 298 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); 299 - 300 - let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 301 - assert!(result.is_err()); 302 - } 303 - 304 - #[test] 305 - fn test_dpop_proof_ath_mismatch() { 306 - let secret = b"test-dpop-secret-32-bytes-long!!"; 307 - let verifier = DPoPVerifier::new(secret); 308 - 309 - let proof = create_dpop_proof( 310 - "GET", 311 - "https://example.com/resource", 312 - None, 313 - Some("wrong_hash"), 314 - 0, 315 - ); 316 - 317 - let result = verifier.verify_proof( 318 - &proof, 319 - "GET", 320 - "https://example.com/resource", 321 - Some("correct_hash"), 322 - ); 323 - assert!(result.is_err()); 324 - } 325 - 326 - #[test] 327 - fn test_dpop_proof_missing_ath_when_required() { 328 - let secret = b"test-dpop-secret-32-bytes-long!!"; 329 - let verifier = DPoPVerifier::new(secret); 330 - 331 - let proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); 332 - 333 - let result = verifier.verify_proof( 334 - &proof, 335 - "GET", 336 - "https://example.com/resource", 337 - Some("expected_hash"), 338 - ); 339 - assert!(result.is_err()); 340 - } 341 - 342 - #[test] 343 - fn test_dpop_proof_uri_ignores_query_params() { 344 - let secret = b"test-dpop-secret-32-bytes-long!!"; 345 - let verifier = DPoPVerifier::new(secret); 346 - 347 - let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 348 - 349 - let result = verifier.verify_proof( 350 - &proof, 351 - "POST", 352 - "https://example.com/token?foo=bar", 353 - None, 354 - ); 355 - 356 - assert!(result.is_ok(), "Query params should be ignored: {:?}", result); 357 - }
+5
tests/oauth_lifecycle.rs
··· 4 4 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 5 use chrono::Utc; 6 6 use common::{base_url, client}; 7 + use helpers::verify_new_account; 7 8 use reqwest::{redirect, StatusCode}; 8 9 use serde_json::{json, Value}; 9 10 use sha2::{Digest, Sha256}; ··· 82 83 assert_eq!(create_res.status(), StatusCode::OK); 83 84 let account: Value = create_res.json().await.unwrap(); 84 85 let user_did = account["did"].as_str().unwrap().to_string(); 86 + 87 + let _ = verify_new_account(&http_client, &user_did).await; 85 88 86 89 let mock_client = setup_mock_client_metadata(redirect_uri).await; 87 90 let client_id = mock_client.uri(); ··· 588 591 assert_eq!(create_res.status(), StatusCode::OK); 589 592 let account: Value = create_res.json().await.unwrap(); 590 593 let user_did = account["did"].as_str().unwrap(); 594 + 595 + let _ = verify_new_account(&http_client, user_did).await; 591 596 592 597 let mock_client1 = setup_mock_client_metadata("https://client1.example.com/callback").await; 593 598 let client1_id = mock_client1.uri();
+358 -1
tests/oauth_security.rs
··· 8 8 use bspds::oauth::dpop::{DPoPVerifier, DPoPJwk, compute_jwk_thumbprint}; 9 9 use chrono::Utc; 10 10 use common::{base_url, client}; 11 + use helpers::verify_new_account; 11 12 use reqwest::{redirect, StatusCode}; 12 13 use serde_json::{json, Value}; 13 14 use sha2::{Digest, Sha256}; ··· 698 699 699 700 assert_eq!(create_res.status(), StatusCode::OK); 700 701 let account: Value = create_res.json().await.unwrap(); 701 - let access_jwt = account["accessJwt"].as_str().unwrap(); 702 + let did = account["did"].as_str().unwrap(); 703 + 704 + let access_jwt = verify_new_account(&http_client, did).await; 702 705 703 706 let deact_res = http_client 704 707 .post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url)) ··· 1449 1452 } 1450 1453 1451 1454 #[tokio::test] 1455 + #[ignore = "rate limiting is disabled in test environment"] 1452 1456 async fn test_security_oauth_authorize_rate_limiting() { 1453 1457 let url = base_url().await; 1454 1458 let http_client = no_redirect_client(); ··· 1511 1515 rate_limited_count 1512 1516 ); 1513 1517 } 1518 + 1519 + fn create_dpop_proof( 1520 + method: &str, 1521 + uri: &str, 1522 + nonce: Option<&str>, 1523 + ath: Option<&str>, 1524 + iat_offset_secs: i64, 1525 + ) -> String { 1526 + use p256::ecdsa::{SigningKey, Signature, signature::Signer}; 1527 + 1528 + let signing_key = SigningKey::random(&mut rand::thread_rng()); 1529 + let verifying_key = signing_key.verifying_key(); 1530 + let point = verifying_key.to_encoded_point(false); 1531 + 1532 + let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 1533 + let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 1534 + 1535 + let jwk = json!({ 1536 + "kty": "EC", 1537 + "crv": "P-256", 1538 + "x": x, 1539 + "y": y 1540 + }); 1541 + 1542 + let header = json!({ 1543 + "typ": "dpop+jwt", 1544 + "alg": "ES256", 1545 + "jwk": jwk 1546 + }); 1547 + 1548 + let mut payload = json!({ 1549 + "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)), 1550 + "htm": method, 1551 + "htu": uri, 1552 + "iat": Utc::now().timestamp() + iat_offset_secs 1553 + }); 1554 + 1555 + if let Some(n) = nonce { 1556 + payload["nonce"] = json!(n); 1557 + } 1558 + 1559 + if let Some(a) = ath { 1560 + payload["ath"] = json!(a); 1561 + } 1562 + 1563 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1564 + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1565 + 1566 + let signing_input = format!("{}.{}", header_b64, payload_b64); 1567 + let signature: Signature = signing_key.sign(signing_input.as_bytes()); 1568 + let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 1569 + 1570 + format!("{}.{}", signing_input, signature_b64) 1571 + } 1572 + 1573 + #[test] 1574 + fn test_dpop_nonce_generation() { 1575 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1576 + let verifier = DPoPVerifier::new(secret); 1577 + 1578 + let nonce1 = verifier.generate_nonce(); 1579 + let nonce2 = verifier.generate_nonce(); 1580 + 1581 + assert!(!nonce1.is_empty()); 1582 + assert!(!nonce2.is_empty()); 1583 + } 1584 + 1585 + #[test] 1586 + fn test_dpop_nonce_validation_success() { 1587 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1588 + let verifier = DPoPVerifier::new(secret); 1589 + 1590 + let nonce = verifier.generate_nonce(); 1591 + let result = verifier.validate_nonce(&nonce); 1592 + 1593 + assert!(result.is_ok(), "Valid nonce should pass: {:?}", result); 1594 + } 1595 + 1596 + #[test] 1597 + fn test_dpop_nonce_wrong_secret() { 1598 + let secret1 = b"test-dpop-secret-32-bytes-long!!"; 1599 + let secret2 = b"different-secret-32-bytes-long!!"; 1600 + 1601 + let verifier1 = DPoPVerifier::new(secret1); 1602 + let verifier2 = DPoPVerifier::new(secret2); 1603 + 1604 + let nonce = verifier1.generate_nonce(); 1605 + let result = verifier2.validate_nonce(&nonce); 1606 + 1607 + assert!(result.is_err(), "Nonce from different secret should fail"); 1608 + } 1609 + 1610 + #[test] 1611 + fn test_dpop_nonce_invalid_format() { 1612 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1613 + let verifier = DPoPVerifier::new(secret); 1614 + 1615 + assert!(verifier.validate_nonce("invalid").is_err()); 1616 + assert!(verifier.validate_nonce("").is_err()); 1617 + assert!(verifier.validate_nonce("!!!not-base64!!!").is_err()); 1618 + } 1619 + 1620 + #[test] 1621 + fn test_jwk_thumbprint_ec_p256() { 1622 + let jwk = DPoPJwk { 1623 + kty: "EC".to_string(), 1624 + crv: Some("P-256".to_string()), 1625 + x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 1626 + y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 1627 + }; 1628 + 1629 + let thumbprint = compute_jwk_thumbprint(&jwk); 1630 + assert!(thumbprint.is_ok()); 1631 + 1632 + let tp = thumbprint.unwrap(); 1633 + assert!(!tp.is_empty()); 1634 + assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_')); 1635 + } 1636 + 1637 + #[test] 1638 + fn test_jwk_thumbprint_ec_secp256k1() { 1639 + let jwk = DPoPJwk { 1640 + kty: "EC".to_string(), 1641 + crv: Some("secp256k1".to_string()), 1642 + x: Some("some_x_value".to_string()), 1643 + y: Some("some_y_value".to_string()), 1644 + }; 1645 + 1646 + let thumbprint = compute_jwk_thumbprint(&jwk); 1647 + assert!(thumbprint.is_ok()); 1648 + } 1649 + 1650 + #[test] 1651 + fn test_jwk_thumbprint_okp_ed25519() { 1652 + let jwk = DPoPJwk { 1653 + kty: "OKP".to_string(), 1654 + crv: Some("Ed25519".to_string()), 1655 + x: Some("some_x_value".to_string()), 1656 + y: None, 1657 + }; 1658 + 1659 + let thumbprint = compute_jwk_thumbprint(&jwk); 1660 + assert!(thumbprint.is_ok()); 1661 + } 1662 + 1663 + #[test] 1664 + fn test_jwk_thumbprint_missing_crv() { 1665 + let jwk = DPoPJwk { 1666 + kty: "EC".to_string(), 1667 + crv: None, 1668 + x: Some("x".to_string()), 1669 + y: Some("y".to_string()), 1670 + }; 1671 + 1672 + let thumbprint = compute_jwk_thumbprint(&jwk); 1673 + assert!(thumbprint.is_err()); 1674 + } 1675 + 1676 + #[test] 1677 + fn test_jwk_thumbprint_missing_x() { 1678 + let jwk = DPoPJwk { 1679 + kty: "EC".to_string(), 1680 + crv: Some("P-256".to_string()), 1681 + x: None, 1682 + y: Some("y".to_string()), 1683 + }; 1684 + 1685 + let thumbprint = compute_jwk_thumbprint(&jwk); 1686 + assert!(thumbprint.is_err()); 1687 + } 1688 + 1689 + #[test] 1690 + fn test_jwk_thumbprint_missing_y_for_ec() { 1691 + let jwk = DPoPJwk { 1692 + kty: "EC".to_string(), 1693 + crv: Some("P-256".to_string()), 1694 + x: Some("x".to_string()), 1695 + y: None, 1696 + }; 1697 + 1698 + let thumbprint = compute_jwk_thumbprint(&jwk); 1699 + assert!(thumbprint.is_err()); 1700 + } 1701 + 1702 + #[test] 1703 + fn test_jwk_thumbprint_unsupported_key_type() { 1704 + let jwk = DPoPJwk { 1705 + kty: "RSA".to_string(), 1706 + crv: None, 1707 + x: None, 1708 + y: None, 1709 + }; 1710 + 1711 + let thumbprint = compute_jwk_thumbprint(&jwk); 1712 + assert!(thumbprint.is_err()); 1713 + } 1714 + 1715 + #[test] 1716 + fn test_jwk_thumbprint_deterministic() { 1717 + let jwk = DPoPJwk { 1718 + kty: "EC".to_string(), 1719 + crv: Some("P-256".to_string()), 1720 + x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()), 1721 + y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()), 1722 + }; 1723 + 1724 + let tp1 = compute_jwk_thumbprint(&jwk).unwrap(); 1725 + let tp2 = compute_jwk_thumbprint(&jwk).unwrap(); 1726 + 1727 + assert_eq!(tp1, tp2, "Thumbprint should be deterministic"); 1728 + } 1729 + 1730 + #[test] 1731 + fn test_dpop_proof_invalid_format() { 1732 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1733 + let verifier = DPoPVerifier::new(secret); 1734 + 1735 + let result = verifier.verify_proof("not.enough.parts", "POST", "https://example.com", None); 1736 + assert!(result.is_err()); 1737 + 1738 + let result = verifier.verify_proof("invalid", "POST", "https://example.com", None); 1739 + assert!(result.is_err()); 1740 + } 1741 + 1742 + #[test] 1743 + fn test_dpop_proof_invalid_typ() { 1744 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1745 + let verifier = DPoPVerifier::new(secret); 1746 + 1747 + let header = json!({ 1748 + "typ": "JWT", 1749 + "alg": "ES256", 1750 + "jwk": { 1751 + "kty": "EC", 1752 + "crv": "P-256", 1753 + "x": "x", 1754 + "y": "y" 1755 + } 1756 + }); 1757 + 1758 + let payload = json!({ 1759 + "jti": "unique", 1760 + "htm": "POST", 1761 + "htu": "https://example.com", 1762 + "iat": Utc::now().timestamp() 1763 + }); 1764 + 1765 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 1766 + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 1767 + let proof = format!("{}.{}.sig", header_b64, payload_b64); 1768 + 1769 + let result = verifier.verify_proof(&proof, "POST", "https://example.com", None); 1770 + assert!(result.is_err()); 1771 + } 1772 + 1773 + #[test] 1774 + fn test_dpop_proof_method_mismatch() { 1775 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1776 + let verifier = DPoPVerifier::new(secret); 1777 + 1778 + let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 1779 + 1780 + let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None); 1781 + assert!(result.is_err()); 1782 + } 1783 + 1784 + #[test] 1785 + fn test_dpop_proof_uri_mismatch() { 1786 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1787 + let verifier = DPoPVerifier::new(secret); 1788 + 1789 + let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 1790 + 1791 + let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None); 1792 + assert!(result.is_err()); 1793 + } 1794 + 1795 + #[test] 1796 + fn test_dpop_proof_iat_too_old() { 1797 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1798 + let verifier = DPoPVerifier::new(secret); 1799 + 1800 + let proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600); 1801 + 1802 + let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1803 + assert!(result.is_err()); 1804 + } 1805 + 1806 + #[test] 1807 + fn test_dpop_proof_iat_future() { 1808 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1809 + let verifier = DPoPVerifier::new(secret); 1810 + 1811 + let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600); 1812 + 1813 + let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None); 1814 + assert!(result.is_err()); 1815 + } 1816 + 1817 + #[test] 1818 + fn test_dpop_proof_ath_mismatch() { 1819 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1820 + let verifier = DPoPVerifier::new(secret); 1821 + 1822 + let proof = create_dpop_proof( 1823 + "GET", 1824 + "https://example.com/resource", 1825 + None, 1826 + Some("wrong_hash"), 1827 + 0, 1828 + ); 1829 + 1830 + let result = verifier.verify_proof( 1831 + &proof, 1832 + "GET", 1833 + "https://example.com/resource", 1834 + Some("correct_hash"), 1835 + ); 1836 + assert!(result.is_err()); 1837 + } 1838 + 1839 + #[test] 1840 + fn test_dpop_proof_missing_ath_when_required() { 1841 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1842 + let verifier = DPoPVerifier::new(secret); 1843 + 1844 + let proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0); 1845 + 1846 + let result = verifier.verify_proof( 1847 + &proof, 1848 + "GET", 1849 + "https://example.com/resource", 1850 + Some("expected_hash"), 1851 + ); 1852 + assert!(result.is_err()); 1853 + } 1854 + 1855 + #[test] 1856 + fn test_dpop_proof_uri_ignores_query_params() { 1857 + let secret = b"test-dpop-secret-32-bytes-long!!"; 1858 + let verifier = DPoPVerifier::new(secret); 1859 + 1860 + let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0); 1861 + 1862 + let result = verifier.verify_proof( 1863 + &proof, 1864 + "POST", 1865 + "https://example.com/token?foo=bar", 1866 + None, 1867 + ); 1868 + 1869 + assert!(result.is_ok(), "Query params should be ignored: {:?}", result); 1870 + }
+9 -1
tests/password_reset.rs
··· 1 1 mod common; 2 + mod helpers; 2 3 3 4 use reqwest::StatusCode; 4 5 use serde_json::{json, Value}; 5 6 use sqlx::PgPool; 7 + use helpers::verify_new_account; 6 8 7 9 async fn get_pool() -> PgPool { 8 10 let conn_str = common::get_db_connection_string().await; ··· 99 101 .await 100 102 .expect("Failed to create account"); 101 103 assert_eq!(res.status(), StatusCode::OK); 104 + let body: Value = res.json().await.unwrap(); 105 + let did = body["did"].as_str().unwrap(); 106 + 107 + let _ = verify_new_account(&client, did).await; 102 108 103 109 let res = client 104 110 .post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url)) ··· 270 276 .expect("Failed to create account"); 271 277 assert_eq!(res.status(), StatusCode::OK); 272 278 let body: Value = res.json().await.expect("Invalid JSON"); 273 - let original_token = body["accessJwt"].as_str().expect("No accessJwt").to_string(); 279 + let did = body["did"].as_str().expect("No did"); 280 + 281 + let original_token = verify_new_account(&client, did).await; 274 282 275 283 let res = client 276 284 .get(format!("{}/xrpc/com.atproto.server.getSession", base_url))
+3
tests/rate_limit.rs
··· 5 5 use serde_json::json; 6 6 7 7 #[tokio::test] 8 + #[ignore = "rate limiting is disabled in test environment"] 8 9 async fn test_login_rate_limiting() { 9 10 let client = client(); 10 11 let url = format!("{}/xrpc/com.atproto.server.createSession", base_url().await); ··· 47 48 } 48 49 49 50 #[tokio::test] 51 + #[ignore = "rate limiting is disabled in test environment"] 50 52 async fn test_password_reset_rate_limiting() { 51 53 let client = client(); 52 54 let url = format!( ··· 91 93 } 92 94 93 95 #[tokio::test] 96 + #[ignore = "rate limiting is disabled in test environment"] 94 97 async fn test_account_creation_rate_limiting() { 95 98 let client = client(); 96 99 let url = format!(
-347
tests/repo_record.rs
··· 1 - mod common; 2 - use common::*; 3 - 4 - use chrono::Utc; 5 - use reqwest::StatusCode; 6 - use serde_json::{Value, json}; 7 - 8 - #[tokio::test] 9 - async fn test_get_record_not_found() { 10 - let client = client(); 11 - let (_, did) = create_account_and_login(&client).await; 12 - 13 - let params = [ 14 - ("repo", did.as_str()), 15 - ("collection", "app.bsky.feed.post"), 16 - ("rkey", "nonexistent"), 17 - ]; 18 - 19 - let res = client 20 - .get(format!( 21 - "{}/xrpc/com.atproto.repo.getRecord", 22 - base_url().await 23 - )) 24 - .query(&params) 25 - .send() 26 - .await 27 - .expect("Failed to send request"); 28 - 29 - assert_eq!(res.status(), StatusCode::NOT_FOUND); 30 - } 31 - 32 - #[tokio::test] 33 - async fn test_put_record_no_auth() { 34 - let client = client(); 35 - let payload = json!({ 36 - "repo": "did:plc:123", 37 - "collection": "app.bsky.feed.post", 38 - "rkey": "fake", 39 - "record": {} 40 - }); 41 - 42 - let res = client 43 - .post(format!( 44 - "{}/xrpc/com.atproto.repo.putRecord", 45 - base_url().await 46 - )) 47 - .json(&payload) 48 - .send() 49 - .await 50 - .expect("Failed to send request"); 51 - 52 - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); 53 - let body: Value = res.json().await.expect("Response was not valid JSON"); 54 - assert_eq!(body["error"], "AuthenticationRequired"); 55 - } 56 - 57 - #[tokio::test] 58 - async fn test_put_record_success() { 59 - let client = client(); 60 - let (token, did) = create_account_and_login(&client).await; 61 - let now = Utc::now().to_rfc3339(); 62 - let payload = json!({ 63 - "repo": did, 64 - "collection": "app.bsky.feed.post", 65 - "rkey": "e2e_test_post", 66 - "record": { 67 - "$type": "app.bsky.feed.post", 68 - "text": "Hello from the e2e test script!", 69 - "createdAt": now 70 - } 71 - }); 72 - 73 - let res = client 74 - .post(format!( 75 - "{}/xrpc/com.atproto.repo.putRecord", 76 - base_url().await 77 - )) 78 - .bearer_auth(token) 79 - .json(&payload) 80 - .send() 81 - .await 82 - .expect("Failed to send request"); 83 - 84 - assert_eq!(res.status(), StatusCode::OK); 85 - let body: Value = res.json().await.expect("Response was not valid JSON"); 86 - assert!(body.get("uri").is_some()); 87 - assert!(body.get("cid").is_some()); 88 - } 89 - 90 - #[tokio::test] 91 - async fn test_get_record_missing_params() { 92 - let client = client(); 93 - let params = [("repo", "did:plc:12345")]; 94 - 95 - let res = client 96 - .get(format!( 97 - "{}/xrpc/com.atproto.repo.getRecord", 98 - base_url().await 99 - )) 100 - .query(&params) 101 - .send() 102 - .await 103 - .expect("Failed to send request"); 104 - 105 - assert_eq!( 106 - res.status(), 107 - StatusCode::BAD_REQUEST, 108 - "Expected 400 for missing params" 109 - ); 110 - } 111 - 112 - #[tokio::test] 113 - async fn test_put_record_mismatched_repo() { 114 - let client = client(); 115 - let (token, _) = create_account_and_login(&client).await; 116 - let now = Utc::now().to_rfc3339(); 117 - let payload = json!({ 118 - "repo": "did:plc:OTHER-USER", 119 - "collection": "app.bsky.feed.post", 120 - "rkey": "e2e_test_post", 121 - "record": { 122 - "$type": "app.bsky.feed.post", 123 - "text": "Hello from the e2e test script!", 124 - "createdAt": now 125 - } 126 - }); 127 - 128 - let res = client 129 - .post(format!( 130 - "{}/xrpc/com.atproto.repo.putRecord", 131 - base_url().await 132 - )) 133 - .bearer_auth(token) 134 - .json(&payload) 135 - .send() 136 - .await 137 - .expect("Failed to send request"); 138 - 139 - assert!( 140 - res.status() == StatusCode::FORBIDDEN || res.status() == StatusCode::UNAUTHORIZED, 141 - "Expected 403 or 401 for mismatched repo and auth, got {}", 142 - res.status() 143 - ); 144 - } 145 - 146 - #[tokio::test] 147 - async fn test_put_record_invalid_schema() { 148 - let client = client(); 149 - let (token, did) = create_account_and_login(&client).await; 150 - let now = Utc::now().to_rfc3339(); 151 - let payload = json!({ 152 - "repo": did, 153 - "collection": "app.bsky.feed.post", 154 - "rkey": "e2e_test_invalid", 155 - "record": { 156 - "$type": "app.bsky.feed.post", 157 - "createdAt": now 158 - } 159 - }); 160 - 161 - let res = client 162 - .post(format!( 163 - "{}/xrpc/com.atproto.repo.putRecord", 164 - base_url().await 165 - )) 166 - .bearer_auth(token) 167 - .json(&payload) 168 - .send() 169 - .await 170 - .expect("Failed to send request"); 171 - 172 - assert_eq!( 173 - res.status(), 174 - StatusCode::BAD_REQUEST, 175 - "Expected 400 for invalid record schema" 176 - ); 177 - } 178 - 179 - #[tokio::test] 180 - async fn test_list_records() { 181 - let client = client(); 182 - let (_, did) = create_account_and_login(&client).await; 183 - let params = [ 184 - ("repo", did.as_str()), 185 - ("collection", "app.bsky.feed.post"), 186 - ("limit", "10"), 187 - ]; 188 - let res = client 189 - .get(format!( 190 - "{}/xrpc/com.atproto.repo.listRecords", 191 - base_url().await 192 - )) 193 - .query(&params) 194 - .send() 195 - .await 196 - .expect("Failed to send request"); 197 - 198 - assert_eq!(res.status(), StatusCode::OK); 199 - } 200 - 201 - #[tokio::test] 202 - async fn test_describe_repo() { 203 - let client = client(); 204 - let (_, did) = create_account_and_login(&client).await; 205 - let params = [("repo", did.as_str())]; 206 - let res = client 207 - .get(format!( 208 - "{}/xrpc/com.atproto.repo.describeRepo", 209 - base_url().await 210 - )) 211 - .query(&params) 212 - .send() 213 - .await 214 - .expect("Failed to send request"); 215 - 216 - assert_eq!(res.status(), StatusCode::OK); 217 - } 218 - 219 - #[tokio::test] 220 - async fn test_create_record_success_with_generated_rkey() { 221 - let client = client(); 222 - let (token, did) = create_account_and_login(&client).await; 223 - let payload = json!({ 224 - "repo": did, 225 - "collection": "app.bsky.feed.post", 226 - "record": { 227 - "$type": "app.bsky.feed.post", 228 - "text": "Hello, world!", 229 - "createdAt": "2025-12-02T12:00:00Z" 230 - } 231 - }); 232 - 233 - let res = client 234 - .post(format!( 235 - "{}/xrpc/com.atproto.repo.createRecord", 236 - base_url().await 237 - )) 238 - .json(&payload) 239 - .bearer_auth(token) 240 - .send() 241 - .await 242 - .expect("Failed to send request"); 243 - 244 - assert_eq!(res.status(), StatusCode::OK); 245 - let body: Value = res.json().await.expect("Response was not valid JSON"); 246 - let uri = body["uri"].as_str().unwrap(); 247 - assert!(uri.starts_with(&format!("at://{}/app.bsky.feed.post/", did))); 248 - assert!(body.get("cid").is_some()); 249 - } 250 - 251 - #[tokio::test] 252 - async fn test_create_record_success_with_provided_rkey() { 253 - let client = client(); 254 - let (token, did) = create_account_and_login(&client).await; 255 - let rkey = format!("custom-rkey-{}", Utc::now().timestamp_millis()); 256 - let payload = json!({ 257 - "repo": did, 258 - "collection": "app.bsky.feed.post", 259 - "rkey": rkey, 260 - "record": { 261 - "$type": "app.bsky.feed.post", 262 - "text": "Hello, world!", 263 - "createdAt": "2025-12-02T12:00:00Z" 264 - } 265 - }); 266 - 267 - let res = client 268 - .post(format!( 269 - "{}/xrpc/com.atproto.repo.createRecord", 270 - base_url().await 271 - )) 272 - .json(&payload) 273 - .bearer_auth(token) 274 - .send() 275 - .await 276 - .expect("Failed to send request"); 277 - 278 - assert_eq!(res.status(), StatusCode::OK); 279 - let body: Value = res.json().await.expect("Response was not valid JSON"); 280 - assert_eq!( 281 - body["uri"], 282 - format!("at://{}/app.bsky.feed.post/{}", did, rkey) 283 - ); 284 - assert!(body.get("cid").is_some()); 285 - } 286 - 287 - #[tokio::test] 288 - async fn test_delete_record() { 289 - let client = client(); 290 - let (token, did) = create_account_and_login(&client).await; 291 - let rkey = format!("post_to_delete_{}", Utc::now().timestamp_millis()); 292 - 293 - let create_payload = json!({ 294 - "repo": did, 295 - "collection": "app.bsky.feed.post", 296 - "rkey": rkey, 297 - "record": { 298 - "$type": "app.bsky.feed.post", 299 - "text": "This post will be deleted", 300 - "createdAt": Utc::now().to_rfc3339() 301 - } 302 - }); 303 - let create_res = client 304 - .post(format!( 305 - "{}/xrpc/com.atproto.repo.putRecord", 306 - base_url().await 307 - )) 308 - .bearer_auth(&token) 309 - .json(&create_payload) 310 - .send() 311 - .await 312 - .expect("Failed to create record"); 313 - assert_eq!(create_res.status(), StatusCode::OK); 314 - 315 - let delete_payload = json!({ 316 - "repo": did, 317 - "collection": "app.bsky.feed.post", 318 - "rkey": rkey 319 - }); 320 - let delete_res = client 321 - .post(format!( 322 - "{}/xrpc/com.atproto.repo.deleteRecord", 323 - base_url().await 324 - )) 325 - .bearer_auth(&token) 326 - .json(&delete_payload) 327 - .send() 328 - .await 329 - .expect("Failed to send request"); 330 - 331 - assert_eq!(delete_res.status(), StatusCode::OK); 332 - 333 - let get_res = client 334 - .get(format!( 335 - "{}/xrpc/com.atproto.repo.getRecord", 336 - base_url().await 337 - )) 338 - .query(&[ 339 - ("repo", did.as_str()), 340 - ("collection", "app.bsky.feed.post"), 341 - ("rkey", rkey.as_str()), 342 - ]) 343 - .send() 344 - .await 345 - .expect("Failed to verify deletion"); 346 - assert_eq!(get_res.status(), StatusCode::NOT_FOUND); 347 - }
+20 -4
tests/server.rs
··· 1 1 mod common; 2 + mod helpers; 2 3 use common::*; 4 + use helpers::verify_new_account; 3 5 4 6 use reqwest::StatusCode; 5 7 use serde_json::{Value, json}; ··· 44 46 "email": format!("{}@example.com", handle), 45 47 "password": "password" 46 48 }); 47 - let _ = client 49 + let create_res = client 48 50 .post(format!( 49 51 "{}/xrpc/com.atproto.server.createAccount", 50 52 base_url().await 51 53 )) 52 54 .json(&payload) 53 55 .send() 54 - .await; 56 + .await 57 + .expect("Failed to create account"); 58 + 59 + assert_eq!(create_res.status(), StatusCode::OK); 60 + let create_body: Value = create_res.json().await.unwrap(); 61 + let did = create_body["did"].as_str().unwrap(); 62 + 63 + let _ = verify_new_account(&client, did).await; 55 64 56 65 let payload = json!({ 57 66 "identifier": handle, ··· 149 158 "email": format!("{}@example.com", handle), 150 159 "password": "password" 151 160 }); 152 - let _ = client 161 + let create_res = client 153 162 .post(format!( 154 163 "{}/xrpc/com.atproto.server.createAccount", 155 164 base_url().await 156 165 )) 157 166 .json(&payload) 158 167 .send() 159 - .await; 168 + .await 169 + .expect("Failed to create account"); 170 + 171 + assert_eq!(create_res.status(), StatusCode::OK); 172 + let create_body: Value = create_res.json().await.unwrap(); 173 + let did = create_body["did"].as_str().unwrap(); 174 + 175 + let _ = verify_new_account(&client, did).await; 160 176 161 177 let login_payload = json!({ 162 178 "identifier": handle,
+10 -3
tests/signing_key.rs
··· 1 1 mod common; 2 + mod helpers; 2 3 3 4 use reqwest::StatusCode; 4 5 use serde_json::{json, Value}; 5 6 use sqlx::PgPool; 7 + use helpers::verify_new_account; 6 8 7 9 async fn get_pool() -> PgPool { 8 10 let conn_str = common::get_db_connection_string().await; ··· 200 202 201 203 assert_eq!(res.status(), StatusCode::OK); 202 204 let body: Value = res.json().await.unwrap(); 203 - assert!(body["accessJwt"].is_string()); 204 205 assert!(body["did"].is_string()); 206 + let did = body["did"].as_str().unwrap(); 207 + 208 + let access_jwt = verify_new_account(&client, did).await; 209 + assert!(!access_jwt.is_empty()); 205 210 206 211 let reserved = sqlx::query!( 207 212 "SELECT used_at FROM reserved_signing_keys WHERE public_key_did_key = $1", ··· 337 342 .expect("Failed to create account"); 338 343 assert_eq!(res.status(), StatusCode::OK); 339 344 let body: Value = res.json().await.unwrap(); 340 - let access_jwt = body["accessJwt"].as_str().unwrap(); 345 + let did = body["did"].as_str().unwrap(); 346 + 347 + let access_jwt = verify_new_account(&client, did).await; 341 348 342 349 let res = client 343 350 .get(format!( 344 351 "{}/xrpc/com.atproto.server.getSession", 345 352 base_url 346 353 )) 347 - .bearer_auth(access_jwt) 354 + .bearer_auth(&access_jwt) 348 355 .send() 349 356 .await 350 357 .expect("Failed to get session");