
My misc TODOs

lewis 4a767dea 8a10a7a6

+34
.sqlx/query-08c08b0644d79d5de72f3500dd7dbb8827af340e3c04fec9a5c28aeff46e0c97.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT id, password_hash, handle FROM users WHERE did = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "id", 9 + "type_info": "Uuid" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "password_hash", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "handle", 19 + "type_info": "Text" 20 + } 21 + ], 22 + "parameters": { 23 + "Left": [ 24 + "Text" 25 + ] 26 + }, 27 + "nullable": [ 28 + false, 29 + false, 30 + false 31 + ] 32 + }, 33 + "hash": "08c08b0644d79d5de72f3500dd7dbb8827af340e3c04fec9a5c28aeff46e0c97" 34 + }
-28
.sqlx/query-76c6ef1d5395105a0cdedb27ca321c9e3eae1ce87c223b706ed81ebf973875f3.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT id, password_hash FROM users WHERE did = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "id", 9 - "type_info": "Uuid" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "password_hash", 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Left": [ 19 - "Text" 20 - ] 21 - }, 22 - "nullable": [ 23 - false, 24 - false 25 - ] 26 - }, 27 - "hash": "76c6ef1d5395105a0cdedb27ca321c9e3eae1ce87c223b706ed81ebf973875f3" 28 - }
···
+22
.sqlx/query-e223898d53602c1c8b23eb08a4b96cf20ac349d1fa4e91334b225d3069209dcf.json
···
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT handle FROM users WHERE id = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "handle", 9 + "type_info": "Text" 10 + } 11 + ], 12 + "parameters": { 13 + "Left": [ 14 + "Uuid" 15 + ] 16 + }, 17 + "nullable": [ 18 + false 19 + ] 20 + }, 21 + "hash": "e223898d53602c1c8b23eb08a4b96cf20ac349d1fa4e91334b225d3069209dcf" 22 + }
+73 -2
Cargo.lock
··· 99 checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" 100 101 [[package]] 102 name = "assert-json-diff" 103 version = "2.0.2" 104 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 689 ] 690 691 [[package]] 692 name = "base-x" 693 version = "0.2.11" 694 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 931 "p256 0.13.2", 932 "p384", 933 "rand 0.8.5", 934 "reqwest", 935 "serde", 936 "serde_bytes", ··· 1176 version = "1.1.0" 1177 source = "registry+https://github.com/rust-lang/crates.io-index" 1178 checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" 1179 1180 [[package]] 1181 name = "compression-codecs" ··· 2973 2974 [[package]] 2975 name = "itertools" 2976 version = "0.14.0" 2977 source = "registry+https://github.com/rust-lang/crates.io-index" 2978 checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" ··· 4241 checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" 4242 dependencies = [ 4243 "anyhow", 4244 - "itertools", 4245 "proc-macro2", 4246 "quote", 4247 "syn 2.0.111", ··· 4442 ] 4443 4444 [[package]] 4445 name = "redox_syscall" 4446 version = "0.5.18" 4447 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 5055 ] 5056 5057 [[package]] 5058 name = "sha2" 5059 version = "0.10.9" 5060 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 5646 "etcetera 0.11.0", 5647 "ferroid", 5648 "futures", 5649 - "itertools", 5650 "log", 5651 "memchr", 5652 "parse-display",
··· 99 checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" 100 101 [[package]] 102 + name = "arc-swap" 103 + version = "1.7.1" 104 + source = "registry+https://github.com/rust-lang/crates.io-index" 105 + checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" 106 + 107 + [[package]] 108 name = "assert-json-diff" 109 version = "2.0.2" 110 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 695 ] 696 697 [[package]] 698 + name = "backon" 699 + version = "1.6.0" 700 + source = "registry+https://github.com/rust-lang/crates.io-index" 701 + checksum = "cffb0e931875b666fc4fcb20fee52e9bbd1ef836fd9e9e04ec21555f9f85f7ef" 702 + dependencies = [ 703 + "fastrand", 704 + ] 705 + 706 + [[package]] 707 name = "base-x" 708 version = "0.2.11" 709 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 946 "p256 0.13.2", 947 "p384", 948 "rand 0.8.5", 949 + "redis", 950 "reqwest", 951 "serde", 952 "serde_bytes", ··· 1192 version = "1.1.0" 1193 source = "registry+https://github.com/rust-lang/crates.io-index" 1194 checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" 1195 + 1196 + [[package]] 1197 + name = "combine" 1198 + version = "4.6.7" 1199 + source = "registry+https://github.com/rust-lang/crates.io-index" 1200 + checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" 1201 + dependencies = [ 1202 + "bytes", 1203 + "futures-core", 1204 + "memchr", 1205 + "pin-project-lite", 1206 + "tokio", 1207 + "tokio-util", 1208 + ] 1209 1210 [[package]] 1211 name = "compression-codecs" ··· 3003 3004 [[package]] 3005 name = "itertools" 3006 + version = "0.13.0" 3007 + source = "registry+https://github.com/rust-lang/crates.io-index" 3008 + checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" 3009 + dependencies = [ 3010 + "either", 3011 + ] 3012 + 3013 + [[package]] 3014 + name = "itertools" 3015 version = "0.14.0" 3016 source = "registry+https://github.com/rust-lang/crates.io-index" 3017 checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" ··· 4280 checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" 4281 dependencies = [ 4282 "anyhow", 4283 + "itertools 0.14.0", 4284 "proc-macro2", 4285 "quote", 4286 "syn 2.0.111", ··· 4481 ] 4482 4483 [[package]] 4484 + name = "redis" 4485 + version = "0.27.6" 4486 + source = "registry+https://github.com/rust-lang/crates.io-index" 4487 + checksum = "09d8f99a4090c89cc489a94833c901ead69bfbf3877b4867d5482e321ee875bc" 4488 + dependencies = [ 4489 + "arc-swap", 4490 + "async-trait", 4491 + "backon", 4492 + "bytes", 4493 + "combine", 4494 + "futures", 4495 + "futures-util", 4496 + "itertools 0.13.0", 4497 + "itoa", 4498 + "num-bigint", 4499 + "percent-encoding", 4500 + "pin-project-lite", 4501 + "ryu", 4502 + "sha1_smol", 4503 + "socket2 0.5.10", 4504 + "tokio", 4505 + "tokio-util", 4506 + "url", 4507 + ] 4508 + 4509 + [[package]] 4510 name = "redox_syscall" 4511 version = "0.5.18" 4512 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 5120 ] 5121 5122 [[package]] 5123 + name = "sha1_smol" 5124 + version = "1.0.1" 5125 + source = "registry+https://github.com/rust-lang/crates.io-index" 5126 + checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" 5127 + 5128 + [[package]] 5129 name = "sha2" 5130 version = "0.10.9" 5131 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 5717 "etcetera 0.11.0", 5718 "ferroid", 5719 "futures", 5720 + 
"itertools 0.14.0", 5721 "log", 5722 "memchr", 5723 "parse-display",
+1
Cargo.toml
··· 49 uuid = { version = "1.19.0", features = ["v4", "fast-rng"] } 50 iroh-car = "0.5.1" 51 image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } 52 53 [features] 54 external-infra = []
··· 49 uuid = { version = "1.19.0", features = ["v4", "fast-rng"] } 50 iroh-car = "0.5.1" 51 image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } 52 + redis = { version = "0.27", features = ["tokio-comp", "connection-manager"] } 53 54 [features] 55 external-infra = []
+17 -7
TODO.md
··· 198 - [x] Implement Atomic Repo Transactions. 199 - [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction. 200 - [x] Implement concurrency control (row-level locking via FOR UPDATE). 201 - - [ ] DID Cache 202 - - [ ] Implement caching layer for DID resolution (Redis or in-memory). 203 - - [ ] Handle cache invalidation/expiry. 204 - [x] Crawlers Service 205 - [x] Implement `Crawlers` service (debounce notifications to relays). 206 - [x] 20-minute notification debounce. ··· 229 - [x] Per-IP rate limiting on OAuth token endpoint (30/min). 230 - [x] Per-IP rate limiting on password reset (5/hour). 231 - [x] Per-IP rate limiting on account creation (10/hour). 232 - [x] Circuit Breakers 233 - [x] PLC directory circuit breaker (5 failures → open, 60s timeout). 234 - [x] Relay notification circuit breaker (10 failures → open, 30s timeout). ··· 237 - [x] Signal command injection prevention (phone number validation). 238 - [x] Constant-time signature comparison. 239 - [x] SSRF protection for outbound requests. 240 241 ## Lewis' fabulous mini-list of remaining TODOs 242 - - [ ] The OAuth authorize POST endpoint has no rate limiting, allowing password brute-forcing. Fix this and audit all oauth and 2fa surface again. 243 - - [ ] DID resolution caching (valkey). 244 - - [ ] Record schema validation (generic validation framework). 245 - - [ ] Fix any remaining TODOs in the code. 246 247 ## Future: Web Management UI 248 A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers.
··· 198 - [x] Implement Atomic Repo Transactions. 199 - [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction. 200 - [x] Implement concurrency control (row-level locking via FOR UPDATE). 201 + - [x] DID Cache 202 + - [x] Implement caching layer for DID resolution (valkey). 203 + - [x] Handle cache invalidation/expiry. 204 + - [x] Graceful fallback to no-cache when Valkey unavailable. 205 - [x] Crawlers Service 206 - [x] Implement `Crawlers` service (debounce notifications to relays). 207 - [x] 20-minute notification debounce. ··· 230 - [x] Per-IP rate limiting on OAuth token endpoint (30/min). 231 - [x] Per-IP rate limiting on password reset (5/hour). 232 - [x] Per-IP rate limiting on account creation (10/hour). 233 + - [x] Per-IP rate limiting on refreshSession (60/min). 234 + - [x] Per-IP rate limiting on OAuth authorize POST (10/min). 235 + - [x] Per-IP rate limiting on OAuth 2FA POST (10/min). 236 + - [x] Per-IP rate limiting on OAuth PAR (30/min). 237 + - [x] Per-IP rate limiting on OAuth revoke/introspect (30/min). 238 + - [x] Per-IP rate limiting on createAppPassword (10/min). 239 + - [x] Per-IP rate limiting on email endpoints (5/hour). 240 + - [x] Distributed rate limiting via Valkey/Redis (with in-memory fallback). 241 - [x] Circuit Breakers 242 - [x] PLC directory circuit breaker (5 failures → open, 60s timeout). 243 - [x] Relay notification circuit breaker (10 failures → open, 30s timeout). ··· 246 - [x] Signal command injection prevention (phone number validation). 247 - [x] Constant-time signature comparison. 248 - [x] SSRF protection for outbound requests. 249 + - [x] Timing attack protection (dummy bcrypt on user-not-found prevents account enumeration). 250 251 ## Lewis' fabulous mini-list of remaining TODOs 252 + - [x] The OAuth authorize POST endpoint has no rate limiting, allowing password brute-forcing. Fix this and audit all oauth and 2fa surface again. 253 + - [x] DID resolution caching (valkey). 254 + - [x] Record schema validation (generic validation framework). 255 + - [x] Fix any remaining TODOs in the code. 256 257 ## Future: Web Management UI 258 A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers.
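The DID-cache items checked off above amount to a cache-aside lookup keyed on the handle, plus explicit invalidation wherever a handle changes or an account is removed. A condensed sketch of that flow, mirroring the `src/api/identity/did.rs` diff further down (the helper name is illustrative, and the sketch assumes it lives inside the crate):

```rust
use crate::state::AppState;
use std::time::Duration;

// Illustrative condensation of the handle -> DID resolution path in this commit.
async fn resolve_handle(state: &AppState, handle: &str) -> Option<String> {
    let cache_key = format!("handle:{}", handle);

    // 1. Try Valkey first; the NoOp backend always misses, so this degrades cleanly.
    if let Some(did) = state.cache.get(&cache_key).await {
        return Some(did);
    }

    // 2. Fall back to Postgres.
    let did = sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", handle)
        .fetch_optional(&state.db)
        .await
        .ok()
        .flatten()?;

    // 3. Best-effort write-back with a 5-minute TTL. Handle updates and account
    //    deletion delete the "handle:{handle}" key explicitly (see the diffs below).
    let _ = state.cache.set(&cache_key, &did, Duration::from_secs(300)).await;
    Some(did)
}
```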
+10
docker-compose.yaml
··· 11 environment: 12 DATABASE_URL: postgres://postgres:postgres@db:5432/pds 13 S3_ENDPOINT: http://objsto:9000 14 depends_on: 15 - db 16 - objsto 17 18 db: 19 image: postgres:latest ··· 38 - minio_data:/data 39 command: server /data --console-address ":9001" 40 41 volumes: 42 postgres_data: 43 minio_data:
··· 11 environment: 12 DATABASE_URL: postgres://postgres:postgres@db:5432/pds 13 S3_ENDPOINT: http://objsto:9000 14 + VALKEY_URL: redis://cache:6379 15 depends_on: 16 - db 17 - objsto 18 + - cache 19 20 db: 21 image: postgres:latest ··· 40 - minio_data:/data 41 command: server /data --console-address ":9001" 42 43 + cache: 44 + image: valkey/valkey:8-alpine 45 + ports: 46 + - "6379:6379" 47 + volumes: 48 + - valkey_data:/data 49 + 50 volumes: 51 postgres_data: 52 minio_data: 53 + valkey_data:
+20 -3
scripts/test-infra.sh
··· 38 rm -f "$INFRA_FILE" 39 fi 40 41 - $CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" 2>/dev/null || true 42 43 echo "Starting PostgreSQL..." 44 $CONTAINER_CMD run -d \ ··· 59 --label bspds_test=true \ 60 minio/minio:latest server /data >/dev/null 61 62 echo "Waiting for services to be ready..." 63 sleep 2 64 65 PG_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-postgres" 5432 | head -1 | cut -d: -f2) 66 MINIO_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-minio" 9000 | head -1 | cut -d: -f2) 67 68 for i in {1..30}; do 69 if $CONTAINER_CMD exec "${CONTAINER_PREFIX}-postgres" pg_isready -U postgres >/dev/null 2>&1; then ··· 81 sleep 1 82 done 83 84 echo "Creating MinIO bucket..." 85 $CONTAINER_CMD run --rm --network host \ 86 -e MC_HOST_minio="http://minioadmin:minioadmin@127.0.0.1:${MINIO_PORT}" \ ··· 94 export AWS_ACCESS_KEY_ID="minioadmin" 95 export AWS_SECRET_ACCESS_KEY="minioadmin" 96 export AWS_REGION="us-east-1" 97 export BSPDS_TEST_INFRA_READY="1" 98 export BSPDS_ALLOW_INSECURE_SECRETS="1" 99 export SKIP_IMPORT_VERIFICATION="true" ··· 108 109 stop_infra() { 110 echo "Stopping test infrastructure..." 111 - $CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" 2>/dev/null || true 112 rm -f "$INFRA_FILE" 113 echo "Infrastructure stopped." 114 } ··· 157 echo "Usage: $0 {start|stop|restart|status|env}" 158 echo "" 159 echo "Commands:" 160 - echo " start - Start test infrastructure (Postgres, MinIO)" 161 echo " stop - Stop and remove test containers" 162 echo " restart - Stop then start infrastructure" 163 echo " status - Show infrastructure status"
··· 38 rm -f "$INFRA_FILE" 39 fi 40 41 + $CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" "${CONTAINER_PREFIX}-valkey" 2>/dev/null || true 42 43 echo "Starting PostgreSQL..." 44 $CONTAINER_CMD run -d \ ··· 59 --label bspds_test=true \ 60 minio/minio:latest server /data >/dev/null 61 62 + echo "Starting Valkey..." 63 + $CONTAINER_CMD run -d \ 64 + --name "${CONTAINER_PREFIX}-valkey" \ 65 + -P \ 66 + --label bspds_test=true \ 67 + valkey/valkey:8-alpine >/dev/null 68 + 69 echo "Waiting for services to be ready..." 70 sleep 2 71 72 PG_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-postgres" 5432 | head -1 | cut -d: -f2) 73 MINIO_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-minio" 9000 | head -1 | cut -d: -f2) 74 + VALKEY_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-valkey" 6379 | head -1 | cut -d: -f2) 75 76 for i in {1..30}; do 77 if $CONTAINER_CMD exec "${CONTAINER_PREFIX}-postgres" pg_isready -U postgres >/dev/null 2>&1; then ··· 89 sleep 1 90 done 91 92 + for i in {1..30}; do 93 + if $CONTAINER_CMD exec "${CONTAINER_PREFIX}-valkey" valkey-cli ping 2>/dev/null | grep -q PONG; then 94 + break 95 + fi 96 + echo "Waiting for Valkey... ($i/30)" 97 + sleep 1 98 + done 99 + 100 echo "Creating MinIO bucket..." 101 $CONTAINER_CMD run --rm --network host \ 102 -e MC_HOST_minio="http://minioadmin:minioadmin@127.0.0.1:${MINIO_PORT}" \ ··· 110 export AWS_ACCESS_KEY_ID="minioadmin" 111 export AWS_SECRET_ACCESS_KEY="minioadmin" 112 export AWS_REGION="us-east-1" 113 + export VALKEY_URL="redis://127.0.0.1:${VALKEY_PORT}" 114 export BSPDS_TEST_INFRA_READY="1" 115 export BSPDS_ALLOW_INSECURE_SECRETS="1" 116 export SKIP_IMPORT_VERIFICATION="true" ··· 125 126 stop_infra() { 127 echo "Stopping test infrastructure..." 128 + $CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" "${CONTAINER_PREFIX}-valkey" 2>/dev/null || true 129 rm -f "$INFRA_FILE" 130 echo "Infrastructure stopped." 131 } ··· 174 echo "Usage: $0 {start|stop|restart|status|env}" 175 echo "" 176 echo "Commands:" 177 + echo " start - Start test infrastructure (Postgres, MinIO, Valkey)" 178 echo " stop - Stop and remove test containers" 179 echo " restart - Stop then start infrastructure" 180 echo " status - Show infrastructure status"
+5 -3
src/api/admin/account/delete.rs
··· 37 .into_response(); 38 } 39 40 - let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did) 41 .fetch_optional(&state.db) 42 .await; 43 44 - let user_id = match user { 45 - Ok(Some(row)) => row.id, 46 Ok(None) => { 47 return ( 48 StatusCode::NOT_FOUND, ··· 185 ) 186 .into_response(); 187 } 188 189 (StatusCode::OK, Json(json!({}))).into_response() 190 }
··· 37 .into_response(); 38 } 39 40 + let user = sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did) 41 .fetch_optional(&state.db) 42 .await; 43 44 + let (user_id, handle) = match user { 45 + Ok(Some(row)) => (row.id, row.handle), 46 Ok(None) => { 47 return ( 48 StatusCode::NOT_FOUND, ··· 185 ) 186 .into_response(); 187 } 188 + 189 + let _ = state.cache.delete(&format!("handle:{}", handle)).await; 190 191 (StatusCode::OK, Json(json!({}))).into_response() 192 }
+10
src/api/admin/account/update.rs
··· 108 .into_response(); 109 } 110 111 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did) 112 .fetch_optional(&state.db) 113 .await; ··· 133 ) 134 .into_response(); 135 } 136 (StatusCode::OK, Json(json!({}))).into_response() 137 } 138 Err(e) => {
··· 108 .into_response(); 109 } 110 111 + let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 112 + .fetch_optional(&state.db) 113 + .await 114 + .ok() 115 + .flatten(); 116 + 117 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did) 118 .fetch_optional(&state.db) 119 .await; ··· 139 ) 140 .into_response(); 141 } 142 + if let Some(old) = old_handle { 143 + let _ = state.cache.delete(&format!("handle:{}", old)).await; 144 + } 145 + let _ = state.cache.delete(&format!("handle:{}", handle)).await; 146 (StatusCode::OK, Json(json!({}))).into_response() 147 } 148 Err(e) => {
+7
src/api/admin/status.rs
··· 305 .into_response(); 306 } 307 308 return ( 309 StatusCode::OK, 310 Json(json!({
··· 305 .into_response(); 306 } 307 308 + if let Ok(Some(handle)) = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 309 + .fetch_optional(&state.db) 310 + .await 311 + { 312 + let _ = state.cache.delete(&format!("handle:{}", handle)).await; 313 + } 314 + 315 return ( 316 StatusCode::OK, 317 Json(json!({
+19 -1
src/api/identity/did.rs
··· 33 .into_response(); 34 } 35 36 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle) 37 .fetch_optional(&state.db) 38 .await; 39 40 match user { 41 Ok(Some(row)) => { 42 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 43 } 44 Ok(None) => ( ··· 406 .into_response(); 407 } 408 409 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 410 .fetch_optional(&state.db) 411 .await; ··· 423 .await; 424 425 match result { 426 - Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(), 427 Err(e) => { 428 error!("DB error updating handle: {:?}", e); 429 (
··· 33 .into_response(); 34 } 35 36 + let cache_key = format!("handle:{}", handle); 37 + if let Some(did) = state.cache.get(&cache_key).await { 38 + return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 39 + } 40 + 41 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle) 42 .fetch_optional(&state.db) 43 .await; 44 45 match user { 46 Ok(Some(row)) => { 47 + let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await; 48 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 49 } 50 Ok(None) => ( ··· 412 .into_response(); 413 } 414 415 + let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 416 + .fetch_optional(&state.db) 417 + .await 418 + .ok() 419 + .flatten(); 420 + 421 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 422 .fetch_optional(&state.db) 423 .await; ··· 435 .await; 436 437 match result { 438 + Ok(_) => { 439 + if let Some(old) = old_handle { 440 + let _ = state.cache.delete(&format!("handle:{}", old)).await; 441 + } 442 + let _ = state.cache.delete(&format!("handle:{}", new_handle)).await; 443 + (StatusCode::OK, Json(json!({}))).into_response() 444 + } 445 Err(e) => { 446 error!("DB error updating handle: {:?}", e); 447 (
+11
src/api/repo/record/batch.rs
··· 1 use crate::api::repo::record::utils::{commit_and_log, RecordOp}; 2 use crate::repo::tracking::TrackingBlockStore; 3 use crate::state::AppState; ··· 211 rkey, 212 value, 213 } => { 214 let rkey = rkey 215 .clone() 216 .unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string()); ··· 249 rkey, 250 value, 251 } => { 252 let mut record_bytes = Vec::new(); 253 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 254 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
··· 1 + use super::validation::validate_record; 2 use crate::api::repo::record::utils::{commit_and_log, RecordOp}; 3 use crate::repo::tracking::TrackingBlockStore; 4 use crate::state::AppState; ··· 212 rkey, 213 value, 214 } => { 215 + if input.validate.unwrap_or(true) { 216 + if let Err(err_response) = validate_record(value, collection) { 217 + return err_response; 218 + } 219 + } 220 let rkey = rkey 221 .clone() 222 .unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string()); ··· 255 rkey, 256 value, 257 } => { 258 + if input.validate.unwrap_or(true) { 259 + if let Err(err_response) = validate_record(value, collection) { 260 + return err_response; 261 + } 262 + } 263 let mut record_bytes = Vec::new(); 264 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 265 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
+1
src/api/repo/record/mod.rs
··· 2 pub mod delete; 3 pub mod read; 4 pub mod utils; 5 pub mod write; 6 7 pub use batch::apply_writes;
··· 2 pub mod delete; 3 pub mod read; 4 pub mod utils; 5 + pub mod validation; 6 pub mod write; 7 8 pub use batch::apply_writes;
+38
src/api/repo/record/validation.rs
···
··· 1 + use crate::validation::{RecordValidator, ValidationError}; 2 + use axum::{ 3 + http::StatusCode, 4 + response::{IntoResponse, Response}, 5 + Json, 6 + }; 7 + use serde_json::json; 8 + 9 + pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Response> { 10 + let validator = RecordValidator::new(); 11 + match validator.validate(record, collection) { 12 + Ok(_) => Ok(()), 13 + Err(ValidationError::MissingType) => Err(( 14 + StatusCode::BAD_REQUEST, 15 + Json(json!({"error": "InvalidRecord", "message": "Record must have a $type field"})), 16 + ).into_response()), 17 + Err(ValidationError::TypeMismatch { expected, actual }) => Err(( 18 + StatusCode::BAD_REQUEST, 19 + Json(json!({"error": "InvalidRecord", "message": format!("Record $type '{}' does not match collection '{}'", actual, expected)})), 20 + ).into_response()), 21 + Err(ValidationError::MissingField(field)) => Err(( 22 + StatusCode::BAD_REQUEST, 23 + Json(json!({"error": "InvalidRecord", "message": format!("Missing required field: {}", field)})), 24 + ).into_response()), 25 + Err(ValidationError::InvalidField { path, message }) => Err(( 26 + StatusCode::BAD_REQUEST, 27 + Json(json!({"error": "InvalidRecord", "message": format!("Invalid field '{}': {}", path, message)})), 28 + ).into_response()), 29 + Err(ValidationError::InvalidDatetime { path }) => Err(( 30 + StatusCode::BAD_REQUEST, 31 + Json(json!({"error": "InvalidRecord", "message": format!("Invalid datetime format at '{}'", path)})), 32 + ).into_response()), 33 + Err(e) => Err(( 34 + StatusCode::BAD_REQUEST, 35 + Json(json!({"error": "InvalidRecord", "message": e.to_string()})), 36 + ).into_response()), 37 + } 38 + }
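A sketch of how this helper might be exercised in a unit test. The records are made up, and it assumes `RecordValidator` enforces at least the `$type`/collection agreement implied by the `TypeMismatch` variant handled above:

```rust
#[cfg(test)]
mod tests {
    use super::validate_record;
    use serde_json::json;

    #[test]
    fn post_with_matching_type_passes() {
        // Hypothetical well-formed feed post.
        let record = json!({
            "$type": "app.bsky.feed.post",
            "text": "hello world",
            "createdAt": "2024-01-01T00:00:00Z"
        });
        assert!(validate_record(&record, "app.bsky.feed.post").is_ok());
    }

    #[test]
    fn mismatched_type_is_rejected() {
        // $type disagrees with the collection, so this should come back as an
        // Err carrying a 400 InvalidRecord response.
        let record = json!({ "$type": "app.bsky.feed.like" });
        assert!(validate_record(&record, "app.bsky.feed.post").is_err());
    }
}
```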
+5 -16
src/api/repo/record/write.rs
··· 1 use crate::api::repo::record::utils::{commit_and_log, RecordOp}; 2 use crate::repo::tracking::TrackingBlockStore; 3 use crate::state::AppState; ··· 156 }; 157 158 if input.validate.unwrap_or(true) { 159 - if input.collection == "app.bsky.feed.post" { 160 - if input.record.get("text").is_none() || input.record.get("createdAt").is_none() { 161 - return ( 162 - StatusCode::BAD_REQUEST, 163 - Json(json!({"error": "InvalidRecord", "message": "Record validation failed"})), 164 - ) 165 - .into_response(); 166 - } 167 } 168 } 169 ··· 263 let key = format!("{}/{}", collection_nsid, input.rkey); 264 265 if input.validate.unwrap_or(true) { 266 - if input.collection == "app.bsky.feed.post" { 267 - if input.record.get("text").is_none() || input.record.get("createdAt").is_none() { 268 - return ( 269 - StatusCode::BAD_REQUEST, 270 - Json(json!({"error": "InvalidRecord", "message": "Record validation failed"})), 271 - ) 272 - .into_response(); 273 - } 274 } 275 } 276
··· 1 + use super::validation::validate_record; 2 use crate::api::repo::record::utils::{commit_and_log, RecordOp}; 3 use crate::repo::tracking::TrackingBlockStore; 4 use crate::state::AppState; ··· 157 }; 158 159 if input.validate.unwrap_or(true) { 160 + if let Err(err_response) = validate_record(&input.record, &input.collection) { 161 + return err_response; 162 } 163 } 164 ··· 258 let key = format!("{}/{}", collection_nsid, input.rkey); 259 260 if input.validate.unwrap_or(true) { 261 + if let Err(err_response) = validate_record(&input.record, &input.collection) { 262 + return err_response; 263 } 264 } 265
+28 -5
src/api/server/account_status.rs
··· 123 Err(e) => return ApiError::from(e).into_response(), 124 }; 125 126 let result = sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did) 127 .execute(&state.db) 128 .await; 129 130 match result { 131 - Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(), 132 Err(e) => { 133 error!("DB error activating account: {:?}", e); 134 ( ··· 163 Err(e) => return ApiError::from(e).into_response(), 164 }; 165 166 let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did) 167 .execute(&state.db) 168 .await; 169 170 match result { 171 - Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(), 172 Err(e) => { 173 error!("DB error deactivating account: {:?}", e); 174 ( ··· 283 } 284 285 let user = sqlx::query!( 286 - "SELECT id, password_hash FROM users WHERE did = $1", 287 did 288 ) 289 .fetch_optional(&state.db) 290 .await; 291 292 - let (user_id, password_hash) = match user { 293 - Ok(Some(row)) => (row.id, row.password_hash), 294 Ok(None) => { 295 return ( 296 StatusCode::BAD_REQUEST, ··· 437 ) 438 .into_response(); 439 } 440 info!("Account {} deleted successfully", did); 441 (StatusCode::OK, Json(json!({}))).into_response() 442 }
··· 123 Err(e) => return ApiError::from(e).into_response(), 124 }; 125 126 + let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 127 + .fetch_optional(&state.db) 128 + .await 129 + .ok() 130 + .flatten(); 131 + 132 let result = sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did) 133 .execute(&state.db) 134 .await; 135 136 match result { 137 + Ok(_) => { 138 + if let Some(h) = handle { 139 + let _ = state.cache.delete(&format!("handle:{}", h)).await; 140 + } 141 + (StatusCode::OK, Json(json!({}))).into_response() 142 + } 143 Err(e) => { 144 error!("DB error activating account: {:?}", e); 145 ( ··· 174 Err(e) => return ApiError::from(e).into_response(), 175 }; 176 177 + let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did) 178 + .fetch_optional(&state.db) 179 + .await 180 + .ok() 181 + .flatten(); 182 + 183 let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did) 184 .execute(&state.db) 185 .await; 186 187 match result { 188 + Ok(_) => { 189 + if let Some(h) = handle { 190 + let _ = state.cache.delete(&format!("handle:{}", h)).await; 191 + } 192 + (StatusCode::OK, Json(json!({}))).into_response() 193 + } 194 Err(e) => { 195 error!("DB error deactivating account: {:?}", e); 196 ( ··· 305 } 306 307 let user = sqlx::query!( 308 + "SELECT id, password_hash, handle FROM users WHERE did = $1", 309 did 310 ) 311 .fetch_optional(&state.db) 312 .await; 313 314 + let (user_id, password_hash, handle) = match user { 315 + Ok(Some(row)) => (row.id, row.password_hash, row.handle), 316 Ok(None) => { 317 return ( 318 StatusCode::BAD_REQUEST, ··· 459 ) 460 .into_response(); 461 } 462 + let _ = state.cache.delete(&format!("handle:{}", handle)).await; 463 info!("Account {} deleted successfully", did); 464 (StatusCode::OK, Json(json!({}))).into_response() 465 }
+21 -1
src/api/server/app_password.rs
··· 5 use axum::{ 6 Json, 7 extract::State, 8 response::{IntoResponse, Response}, 9 }; 10 use serde::{Deserialize, Serialize}; 11 use serde_json::json; 12 - use tracing::error; 13 14 #[derive(Serialize)] 15 #[serde(rename_all = "camelCase")] ··· 76 77 pub async fn create_app_password( 78 State(state): State<AppState>, 79 BearerAuth(auth_user): BearerAuth, 80 Json(input): Json<CreateAppPasswordInput>, 81 ) -> Response { 82 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 83 Ok(id) => id, 84 Err(e) => return ApiError::from(e).into_response(),
··· 5 use axum::{ 6 Json, 7 extract::State, 8 + http::HeaderMap, 9 response::{IntoResponse, Response}, 10 }; 11 use serde::{Deserialize, Serialize}; 12 use serde_json::json; 13 + use tracing::{error, warn}; 14 15 #[derive(Serialize)] 16 #[serde(rename_all = "camelCase")] ··· 77 78 pub async fn create_app_password( 79 State(state): State<AppState>, 80 + headers: HeaderMap, 81 BearerAuth(auth_user): BearerAuth, 82 Json(input): Json<CreateAppPasswordInput>, 83 ) -> Response { 84 + let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 85 + if !state.distributed_rate_limiter.check_rate_limit( 86 + &format!("app_password:{}", client_ip), 87 + 10, 88 + 60_000, 89 + ).await { 90 + if state.rate_limiters.app_password.check_key(&client_ip).is_err() { 91 + warn!(ip = %client_ip, "App password creation rate limit exceeded"); 92 + return ( 93 + axum::http::StatusCode::TOO_MANY_REQUESTS, 94 + Json(json!({ 95 + "error": "RateLimitExceeded", 96 + "message": "Too many requests. Please try again later." 97 + })), 98 + ).into_response(); 99 + } 100 + } 101 + 102 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 103 Ok(id) => id, 104 Err(e) => return ApiError::from(e).into_response(),
+36
src/api/server/email.rs
··· 26 headers: axum::http::HeaderMap, 27 Json(input): Json<RequestEmailUpdateInput>, 28 ) -> Response { 29 let token = match crate::auth::extract_bearer_token_from_header( 30 headers.get("Authorization").and_then(|h| h.to_str().ok()) 31 ) { ··· 135 headers: axum::http::HeaderMap, 136 Json(input): Json<ConfirmEmailInput>, 137 ) -> Response { 138 let token = match crate::auth::extract_bearer_token_from_header( 139 headers.get("Authorization").and_then(|h| h.to_str().ok()) 140 ) {
··· 26 headers: axum::http::HeaderMap, 27 Json(input): Json<RequestEmailUpdateInput>, 28 ) -> Response { 29 + let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 30 + if !state.distributed_rate_limiter.check_rate_limit( 31 + &format!("email_update:{}", client_ip), 32 + 5, 33 + 3_600_000, 34 + ).await { 35 + if state.rate_limiters.email_update.check_key(&client_ip).is_err() { 36 + warn!(ip = %client_ip, "Email update rate limit exceeded"); 37 + return ( 38 + StatusCode::TOO_MANY_REQUESTS, 39 + Json(json!({ 40 + "error": "RateLimitExceeded", 41 + "message": "Too many requests. Please try again later." 42 + })), 43 + ).into_response(); 44 + } 45 + } 46 + 47 let token = match crate::auth::extract_bearer_token_from_header( 48 headers.get("Authorization").and_then(|h| h.to_str().ok()) 49 ) { ··· 153 headers: axum::http::HeaderMap, 154 Json(input): Json<ConfirmEmailInput>, 155 ) -> Response { 156 + let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 157 + if !state.distributed_rate_limiter.check_rate_limit( 158 + &format!("confirm_email:{}", client_ip), 159 + 10, 160 + 60_000, 161 + ).await { 162 + if state.rate_limiters.app_password.check_key(&client_ip).is_err() { 163 + warn!(ip = %client_ip, "Confirm email rate limit exceeded"); 164 + return ( 165 + StatusCode::TOO_MANY_REQUESTS, 166 + Json(json!({ 167 + "error": "RateLimitExceeded", 168 + "message": "Too many requests. Please try again later." 169 + })), 170 + ).into_response(); 171 + } 172 + } 173 + 174 let token = match crate::auth::extract_bearer_token_from_header( 175 headers.get("Authorization").and_then(|h| h.to_str().ok()) 176 ) {
+19
src/api/server/password.rs
··· 124 125 pub async fn reset_password( 126 State(state): State<AppState>, 127 Json(input): Json<ResetPasswordInput>, 128 ) -> Response { 129 let token = input.token.trim(); 130 let password = &input.password; 131
··· 124 125 pub async fn reset_password( 126 State(state): State<AppState>, 127 + headers: HeaderMap, 128 Json(input): Json<ResetPasswordInput>, 129 ) -> Response { 130 + let client_ip = extract_client_ip(&headers); 131 + if !state.distributed_rate_limiter.check_rate_limit( 132 + &format!("reset_password:{}", client_ip), 133 + 10, 134 + 60_000, 135 + ).await { 136 + if state.rate_limiters.reset_password.check_key(&client_ip).is_err() { 137 + warn!(ip = %client_ip, "Reset password rate limit exceeded"); 138 + return ( 139 + StatusCode::TOO_MANY_REQUESTS, 140 + Json(json!({ 141 + "error": "RateLimitExceeded", 142 + "message": "Too many requests. Please try again later." 143 + })), 144 + ).into_response(); 145 + } 146 + } 147 + 148 let token = input.token.trim(); 149 let password = &input.password; 150
+19
src/api/server/session.rs
··· 72 { 73 Ok(Some(row)) => row, 74 Ok(None) => { 75 warn!("User not found for login attempt"); 76 return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response(); 77 } ··· 196 State(state): State<AppState>, 197 headers: axum::http::HeaderMap, 198 ) -> Response { 199 let refresh_token = match crate::auth::extract_bearer_token_from_header( 200 headers.get("Authorization").and_then(|h| h.to_str().ok()) 201 ) {
··· 72 { 73 Ok(Some(row)) => row, 74 Ok(None) => { 75 + let _ = verify(&input.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK"); 76 warn!("User not found for login attempt"); 77 return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response(); 78 } ··· 197 State(state): State<AppState>, 198 headers: axum::http::HeaderMap, 199 ) -> Response { 200 + let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 201 + if !state.distributed_rate_limiter.check_rate_limit( 202 + &format!("refresh_session:{}", client_ip), 203 + 60, 204 + 60_000, 205 + ).await { 206 + if state.rate_limiters.refresh_session.check_key(&client_ip).is_err() { 207 + tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded"); 208 + return ( 209 + axum::http::StatusCode::TOO_MANY_REQUESTS, 210 + axum::Json(serde_json::json!({ 211 + "error": "RateLimitExceeded", 212 + "message": "Too many requests. Please try again later." 213 + })), 214 + ).into_response(); 215 + } 216 + } 217 + 218 let refresh_token = match crate::auth::extract_bearer_token_from_header( 219 headers.get("Authorization").and_then(|h| h.to_str().ok()) 220 ) {
+207
src/cache/mod.rs
···
··· 1 + use async_trait::async_trait; 2 + use std::sync::Arc; 3 + use std::time::Duration; 4 + 5 + #[derive(Debug, thiserror::Error)] 6 + pub enum CacheError { 7 + #[error("Cache connection error: {0}")] 8 + Connection(String), 9 + #[error("Serialization error: {0}")] 10 + Serialization(String), 11 + } 12 + 13 + #[async_trait] 14 + pub trait Cache: Send + Sync { 15 + async fn get(&self, key: &str) -> Option<String>; 16 + async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>; 17 + async fn delete(&self, key: &str) -> Result<(), CacheError>; 18 + } 19 + 20 + #[derive(Clone)] 21 + pub struct ValkeyCache { 22 + conn: redis::aio::ConnectionManager, 23 + } 24 + 25 + impl ValkeyCache { 26 + pub async fn new(url: &str) -> Result<Self, CacheError> { 27 + let client = redis::Client::open(url) 28 + .map_err(|e| CacheError::Connection(e.to_string()))?; 29 + let manager = client 30 + .get_connection_manager() 31 + .await 32 + .map_err(|e| CacheError::Connection(e.to_string()))?; 33 + Ok(Self { conn: manager }) 34 + } 35 + 36 + pub fn connection(&self) -> redis::aio::ConnectionManager { 37 + self.conn.clone() 38 + } 39 + } 40 + 41 + #[async_trait] 42 + impl Cache for ValkeyCache { 43 + async fn get(&self, key: &str) -> Option<String> { 44 + let mut conn = self.conn.clone(); 45 + redis::cmd("GET") 46 + .arg(key) 47 + .query_async::<Option<String>>(&mut conn) 48 + .await 49 + .ok() 50 + .flatten() 51 + } 52 + 53 + async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 54 + let mut conn = self.conn.clone(); 55 + redis::cmd("SET") 56 + .arg(key) 57 + .arg(value) 58 + .arg("EX") 59 + .arg(ttl.as_secs() as i64) 60 + .query_async::<()>(&mut conn) 61 + .await 62 + .map_err(|e| CacheError::Connection(e.to_string())) 63 + } 64 + 65 + async fn delete(&self, key: &str) -> Result<(), CacheError> { 66 + let mut conn = self.conn.clone(); 67 + redis::cmd("DEL") 68 + .arg(key) 69 + .query_async::<()>(&mut conn) 70 + .await 71 + .map_err(|e| CacheError::Connection(e.to_string())) 72 + } 73 + } 74 + 75 + pub struct NoOpCache; 76 + 77 + #[async_trait] 78 + impl Cache for NoOpCache { 79 + async fn get(&self, _key: &str) -> Option<String> { 80 + None 81 + } 82 + 83 + async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> { 84 + Ok(()) 85 + } 86 + 87 + async fn delete(&self, _key: &str) -> Result<(), CacheError> { 88 + Ok(()) 89 + } 90 + } 91 + 92 + #[async_trait] 93 + pub trait DistributedRateLimiter: Send + Sync { 94 + async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool; 95 + } 96 + 97 + #[derive(Clone)] 98 + pub struct RedisRateLimiter { 99 + conn: redis::aio::ConnectionManager, 100 + } 101 + 102 + impl RedisRateLimiter { 103 + pub fn new(conn: redis::aio::ConnectionManager) -> Self { 104 + Self { conn } 105 + } 106 + } 107 + 108 + #[async_trait] 109 + impl DistributedRateLimiter for RedisRateLimiter { 110 + async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 111 + let mut conn = self.conn.clone(); 112 + let full_key = format!("rl:{}", key); 113 + let window_secs = ((window_ms + 999) / 1000).max(1) as i64; 114 + 115 + let count: Result<i64, _> = redis::cmd("INCR") 116 + .arg(&full_key) 117 + .query_async(&mut conn) 118 + .await; 119 + 120 + let count = match count { 121 + Ok(c) => c, 122 + Err(e) => { 123 + tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e); 124 + return true; 125 + } 126 + }; 127 + 128 + if count == 1 { 129 + let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE") 130 + .arg(&full_key) 131 + .arg(window_secs) 132 + .query_async(&mut conn) 133 + .await; 134 + } 135 + 136 + count <= limit as i64 137 + } 138 + } 139 + 140 + pub struct NoOpRateLimiter; 141 + 142 + #[async_trait] 143 + impl DistributedRateLimiter for NoOpRateLimiter { 144 + async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool { 145 + true 146 + } 147 + } 148 + 149 + pub enum CacheBackend { 150 + Valkey(ValkeyCache), 151 + NoOp, 152 + } 153 + 154 + impl CacheBackend { 155 + pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> { 156 + match self { 157 + CacheBackend::Valkey(cache) => { 158 + Arc::new(RedisRateLimiter::new(cache.connection())) 159 + } 160 + CacheBackend::NoOp => Arc::new(NoOpRateLimiter), 161 + } 162 + } 163 + } 164 + 165 + #[async_trait] 166 + impl Cache for CacheBackend { 167 + async fn get(&self, key: &str) -> Option<String> { 168 + match self { 169 + CacheBackend::Valkey(c) => c.get(key).await, 170 + CacheBackend::NoOp => None, 171 + } 172 + } 173 + 174 + async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 175 + match self { 176 + CacheBackend::Valkey(c) => c.set(key, value, ttl).await, 177 + CacheBackend::NoOp => Ok(()), 178 + } 179 + } 180 + 181 + async fn delete(&self, key: &str) -> Result<(), CacheError> { 182 + match self { 183 + CacheBackend::Valkey(c) => c.delete(key).await, 184 + CacheBackend::NoOp => Ok(()), 185 + } 186 + } 187 + } 188 + 189 + pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 190 + match std::env::var("VALKEY_URL") { 191 + Ok(url) => match ValkeyCache::new(&url).await { 192 + Ok(cache) => { 193 + tracing::info!("Connected to Valkey cache at {}", url); 194 + let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection())); 195 + (Arc::new(cache), rate_limiter) 196 + } 197 + Err(e) => { 198 + tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e); 199 + (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 200 + } 201 + }, 202 + Err(_) => { 203 + tracing::info!("VALKEY_URL not set. Running without cache."); 204 + (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter)) 205 + } 206 + } 207 + }
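A minimal wiring sketch for the new module, assuming the crate name `bspds` used by the integration tests below; the key, value, and IP are placeholders:

```rust
use std::time::Duration;

#[tokio::main]
async fn main() {
    // create_cache() picks ValkeyCache/RedisRateLimiter when VALKEY_URL is set and
    // reachable, otherwise the NoOp implementations (every get misses, every check allows).
    let (cache, limiter) = bspds::cache::create_cache().await;

    // Cache-aside style usage with a 5-minute TTL (placeholder key and value).
    let _ = cache
        .set("handle:alice.example.com", "did:plc:example123", Duration::from_secs(300))
        .await;
    println!("cached: {:?}", cache.get("handle:alice.example.com").await);

    // Fixed-window counter (INCR + EXPIRE): at most 30 hits per 60 s window for this
    // key; it fails open if Valkey errors mid-request.
    let allowed = limiter.check_rate_limit("oauth_par:203.0.113.7", 30, 60_000).await;
    println!("allowed: {allowed}");
}
```

Because `check_rate_limit` fails open and `create_cache` falls back to the no-op backends, a Valkey outage degrades to the in-memory `RateLimiters` rather than blocking traffic.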
+1
src/lib.rs
··· 1 pub mod api; 2 pub mod auth; 3 pub mod circuit_breaker; 4 pub mod config; 5 pub mod crawlers;
··· 1 pub mod api; 2 pub mod auth; 3 + pub mod cache; 4 pub mod circuit_breaker; 5 pub mod config; 6 pub mod crawlers;
+37 -1
src/oauth/endpoints/authorize.rs
··· 272 ) -> Response { 273 let json_response = wants_json(&headers); 274 275 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 276 Ok(Some(data)) => data, 277 Ok(None) => { ··· 357 .await 358 { 359 Ok(Some(u)) => u, 360 - Ok(None) => return show_login_error("Invalid handle/email or password.", json_response), 361 Err(_) => return show_login_error("An error occurred. Please try again.", json_response), 362 }; 363 ··· 736 headers: HeaderMap, 737 Form(form): Form<Authorize2faSubmit>, 738 ) -> Response { 739 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await { 740 Ok(Some(c)) => c, 741 Ok(None) => {
··· 272 ) -> Response { 273 let json_response = wants_json(&headers); 274 275 + let client_ip = extract_client_ip(&headers); 276 + if state.rate_limiters.oauth_authorize.check_key(&client_ip).is_err() { 277 + tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded"); 278 + if json_response { 279 + return ( 280 + axum::http::StatusCode::TOO_MANY_REQUESTS, 281 + Json(serde_json::json!({ 282 + "error": "RateLimitExceeded", 283 + "error_description": "Too many login attempts. Please try again later." 284 + })), 285 + ).into_response(); 286 + } 287 + return ( 288 + axum::http::StatusCode::TOO_MANY_REQUESTS, 289 + Html(templates::error_page( 290 + "RateLimitExceeded", 291 + Some("Too many login attempts. Please try again later."), 292 + )), 293 + ).into_response(); 294 + } 295 + 296 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 297 Ok(Some(data)) => data, 298 Ok(None) => { ··· 378 .await 379 { 380 Ok(Some(u)) => u, 381 + Ok(None) => { 382 + let _ = bcrypt::verify(&form.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK"); 383 + return show_login_error("Invalid handle/email or password.", json_response); 384 + } 385 Err(_) => return show_login_error("An error occurred. Please try again.", json_response), 386 }; 387 ··· 760 headers: HeaderMap, 761 Form(form): Form<Authorize2faSubmit>, 762 ) -> Response { 763 + let client_ip = extract_client_ip(&headers); 764 + if state.rate_limiters.oauth_authorize.check_key(&client_ip).is_err() { 765 + tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 766 + return ( 767 + axum::http::StatusCode::TOO_MANY_REQUESTS, 768 + Html(templates::error_page( 769 + "RateLimitExceeded", 770 + Some("Too many attempts. Please try again later."), 771 + )), 772 + ).into_response(); 773 + } 774 + 775 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await { 776 Ok(Some(c)) => c, 777 Ok(None) => {
+14
src/oauth/endpoints/par.rs
··· 1 use axum::{ 2 Form, Json, 3 extract::State, 4 }; 5 use chrono::{Duration, Utc}; 6 use serde::{Deserialize, Serialize}; ··· 49 50 pub async fn pushed_authorization_request( 51 State(state): State<AppState>, 52 Form(request): Form<ParRequest>, 53 ) -> Result<Json<ParResponse>, OAuthError> { 54 if request.response_type != "code" { 55 return Err(OAuthError::InvalidRequest( 56 "response_type must be 'code'".to_string(),
··· 1 use axum::{ 2 Form, Json, 3 extract::State, 4 + http::HeaderMap, 5 }; 6 use chrono::{Duration, Utc}; 7 use serde::{Deserialize, Serialize}; ··· 50 51 pub async fn pushed_authorization_request( 52 State(state): State<AppState>, 53 + headers: HeaderMap, 54 Form(request): Form<ParRequest>, 55 ) -> Result<Json<ParResponse>, OAuthError> { 56 + let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 57 + if !state.distributed_rate_limiter.check_rate_limit( 58 + &format!("oauth_par:{}", client_ip), 59 + 30, 60 + 60_000, 61 + ).await { 62 + if state.rate_limiters.oauth_par.check_key(&client_ip).is_err() { 63 + tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 64 + return Err(OAuthError::RateLimited); 65 + } 66 + } 67 + 68 if request.response_type != "code" { 69 return Err(OAuthError::InvalidRequest( 70 "response_type must be 'code'".to_string(),
+33 -7
src/oauth/endpoints/token/introspect.rs
··· 1 use axum::{Form, Json}; 2 use axum::extract::State; 3 - use axum::http::StatusCode; 4 use chrono::Utc; 5 use serde::{Deserialize, Serialize}; 6 ··· 18 19 pub async fn revoke_token( 20 State(state): State<AppState>, 21 Form(request): Form<RevokeRequest>, 22 ) -> Result<StatusCode, OAuthError> { 23 if let Some(token) = &request.token { 24 if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? { 25 db::delete_token_family(&state.db, db_id).await?; ··· 67 68 pub async fn introspect_token( 69 State(state): State<AppState>, 70 Form(request): Form<IntrospectRequest>, 71 - ) -> Json<IntrospectResponse> { 72 let inactive_response = IntrospectResponse { 73 active: false, 74 scope: None, ··· 86 87 let token_info = match extract_token_claims(&request.token) { 88 Ok(info) => info, 89 - Err(_) => return Json(inactive_response), 90 }; 91 92 let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await { 93 Ok(Some(data)) => data, 94 - _ => return Json(inactive_response), 95 }; 96 97 if token_data.expires_at < Utc::now() { 98 - return Json(inactive_response); 99 } 100 101 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 102 let issuer = format!("https://{}", pds_hostname); 103 104 - Json(IntrospectResponse { 105 active: true, 106 scope: token_data.scope, 107 client_id: Some(token_data.client_id), ··· 118 aud: Some(issuer.clone()), 119 iss: Some(issuer), 120 jti: Some(token_info.jti), 121 - }) 122 }
··· 1 use axum::{Form, Json}; 2 use axum::extract::State; 3 + use axum::http::{HeaderMap, StatusCode}; 4 use chrono::Utc; 5 use serde::{Deserialize, Serialize}; 6 ··· 18 19 pub async fn revoke_token( 20 State(state): State<AppState>, 21 + headers: HeaderMap, 22 Form(request): Form<RevokeRequest>, 23 ) -> Result<StatusCode, OAuthError> { 24 + let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 25 + if !state.distributed_rate_limiter.check_rate_limit( 26 + &format!("oauth_revoke:{}", client_ip), 27 + 30, 28 + 60_000, 29 + ).await { 30 + if state.rate_limiters.oauth_introspect.check_key(&client_ip).is_err() { 31 + tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded"); 32 + return Err(OAuthError::RateLimited); 33 + } 34 + } 35 + 36 if let Some(token) = &request.token { 37 if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? { 38 db::delete_token_family(&state.db, db_id).await?; ··· 80 81 pub async fn introspect_token( 82 State(state): State<AppState>, 83 + headers: HeaderMap, 84 Form(request): Form<IntrospectRequest>, 85 + ) -> Result<Json<IntrospectResponse>, OAuthError> { 86 + let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 87 + if !state.distributed_rate_limiter.check_rate_limit( 88 + &format!("oauth_introspect:{}", client_ip), 89 + 30, 90 + 60_000, 91 + ).await { 92 + if state.rate_limiters.oauth_introspect.check_key(&client_ip).is_err() { 93 + tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded"); 94 + return Err(OAuthError::RateLimited); 95 + } 96 + } 97 + 98 let inactive_response = IntrospectResponse { 99 active: false, 100 scope: None, ··· 112 113 let token_info = match extract_token_claims(&request.token) { 114 Ok(info) => info, 115 + Err(_) => return Ok(Json(inactive_response)), 116 }; 117 118 let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await { 119 Ok(Some(data)) => data, 120 + _ => return Ok(Json(inactive_response)), 121 }; 122 123 if token_data.expires_at < Utc::now() { 124 + return Ok(Json(inactive_response)); 125 } 126 127 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 128 let issuer = format!("https://{}", pds_hostname); 129 130 + Ok(Json(IntrospectResponse { 131 active: true, 132 scope: token_data.scope, 133 client_id: Some(token_data.client_id), ··· 144 aud: Some(issuer.clone()), 145 iss: Some(issuer), 146 jti: Some(token_info.jti), 147 + })) 148 }
+4
src/oauth/error.rs
··· 19 InvalidDpopProof(String), 20 ExpiredToken(String), 21 InvalidToken(String), 22 } 23 24 #[derive(Serialize)] ··· 73 } 74 OAuthError::InvalidToken(msg) => { 75 (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg)) 76 } 77 }; 78
··· 19 InvalidDpopProof(String), 20 ExpiredToken(String), 21 InvalidToken(String), 22 + RateLimited, 23 } 24 25 #[derive(Serialize)] ··· 74 } 75 OAuthError::InvalidToken(msg) => { 76 (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg)) 77 + } 78 + OAuthError::RateLimited => { 79 + (StatusCode::TOO_MANY_REQUESTS, "rate_limited", Some("Too many requests. Please try again later.".to_string())) 80 } 81 }; 82
+36 -1
src/rate_limit.rs
··· 24 pub struct RateLimiters { 25 pub login: Arc<KeyedRateLimiter>, 26 pub oauth_token: Arc<KeyedRateLimiter>, 27 pub password_reset: Arc<KeyedRateLimiter>, 28 pub account_creation: Arc<KeyedRateLimiter>, 29 } 30 31 impl Default for RateLimiters { ··· 42 )), 43 oauth_token: Arc::new(RateLimiter::keyed( 44 Quota::per_minute(NonZeroU32::new(30).unwrap()) 45 )), 46 password_reset: Arc::new(RateLimiter::keyed( 47 Quota::per_hour(NonZeroU32::new(5).unwrap()) ··· 49 account_creation: Arc::new(RateLimiter::keyed( 50 Quota::per_hour(NonZeroU32::new(10).unwrap()) 51 )), 52 } 53 } 54 ··· 66 self 67 } 68 69 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 70 self.password_reset = Arc::new(RateLimiter::keyed( 71 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) ··· 81 } 82 } 83 84 - fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 85 if let Some(forwarded) = headers.get("x-forwarded-for") { 86 if let Ok(value) = forwarded.to_str() { 87 if let Some(first_ip) = value.split(',').next() {
··· 24 pub struct RateLimiters { 25 pub login: Arc<KeyedRateLimiter>, 26 pub oauth_token: Arc<KeyedRateLimiter>, 27 + pub oauth_authorize: Arc<KeyedRateLimiter>, 28 pub password_reset: Arc<KeyedRateLimiter>, 29 pub account_creation: Arc<KeyedRateLimiter>, 30 + pub refresh_session: Arc<KeyedRateLimiter>, 31 + pub reset_password: Arc<KeyedRateLimiter>, 32 + pub oauth_par: Arc<KeyedRateLimiter>, 33 + pub oauth_introspect: Arc<KeyedRateLimiter>, 34 + pub app_password: Arc<KeyedRateLimiter>, 35 + pub email_update: Arc<KeyedRateLimiter>, 36 } 37 38 impl Default for RateLimiters { ··· 49 )), 50 oauth_token: Arc::new(RateLimiter::keyed( 51 Quota::per_minute(NonZeroU32::new(30).unwrap()) 52 + )), 53 + oauth_authorize: Arc::new(RateLimiter::keyed( 54 + Quota::per_minute(NonZeroU32::new(10).unwrap()) 55 )), 56 password_reset: Arc::new(RateLimiter::keyed( 57 Quota::per_hour(NonZeroU32::new(5).unwrap()) ··· 59 account_creation: Arc::new(RateLimiter::keyed( 60 Quota::per_hour(NonZeroU32::new(10).unwrap()) 61 )), 62 + refresh_session: Arc::new(RateLimiter::keyed( 63 + Quota::per_minute(NonZeroU32::new(60).unwrap()) 64 + )), 65 + reset_password: Arc::new(RateLimiter::keyed( 66 + Quota::per_minute(NonZeroU32::new(10).unwrap()) 67 + )), 68 + oauth_par: Arc::new(RateLimiter::keyed( 69 + Quota::per_minute(NonZeroU32::new(30).unwrap()) 70 + )), 71 + oauth_introspect: Arc::new(RateLimiter::keyed( 72 + Quota::per_minute(NonZeroU32::new(30).unwrap()) 73 + )), 74 + app_password: Arc::new(RateLimiter::keyed( 75 + Quota::per_minute(NonZeroU32::new(10).unwrap()) 76 + )), 77 + email_update: Arc::new(RateLimiter::keyed( 78 + Quota::per_hour(NonZeroU32::new(5).unwrap()) 79 + )), 80 } 81 } 82 ··· 94 self 95 } 96 97 + pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 98 + self.oauth_authorize = Arc::new(RateLimiter::keyed( 99 + Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap())) 100 + )); 101 + self 102 + } 103 + 104 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 105 self.password_reset = Arc::new(RateLimiter::keyed( 106 Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap())) ··· 116 } 117 } 118 119 + pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 120 if let Some(forwarded) = headers.get("x-forwarded-for") { 121 if let Ok(value) = forwarded.to_str() { 122 if let Some(first_ip) = value.split(',').next() {
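The handlers in this commit repeat the same two-step check: consult the Valkey-backed limiter first, and return 429 only when the in-memory keyed limiter also denies the request, so a Valkey outage (where `check_rate_limit` fails open) falls back to purely local limiting. A hypothetical helper capturing that pattern for the refreshSession limits shown above; this is a sketch, not part of the commit:

```rust
use crate::state::AppState;

// Hypothetical consolidation of the inline pattern used in session.rs, par.rs,
// introspect.rs, and friends. Returns true when the request should get a 429.
async fn refresh_session_over_limit(state: &AppState, client_ip: &str) -> bool {
    let denied_by_valkey = !state
        .distributed_rate_limiter
        .check_rate_limit(&format!("refresh_session:{}", client_ip), 60, 60_000)
        .await;

    // Only reject when the local keyed limiter agrees; this keeps behaviour sane
    // when the distributed counter is unavailable or was just reset.
    denied_by_valkey
        && state
            .rate_limiters
            .refresh_session
            .check_key(&client_ip.to_string())
            .is_err()
}
```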
+6
src/state.rs
··· 1 use crate::circuit_breaker::CircuitBreakers; 2 use crate::config::AuthConfig; 3 use crate::rate_limit::RateLimiters; ··· 16 pub firehose_tx: broadcast::Sender<SequencedEvent>, 17 pub rate_limiters: Arc<RateLimiters>, 18 pub circuit_breakers: Arc<CircuitBreakers>, 19 } 20 21 impl AppState { ··· 27 let (firehose_tx, _) = broadcast::channel(1000); 28 let rate_limiters = Arc::new(RateLimiters::new()); 29 let circuit_breakers = Arc::new(CircuitBreakers::new()); 30 Self { 31 db, 32 block_store, ··· 34 firehose_tx, 35 rate_limiters, 36 circuit_breakers, 37 } 38 } 39
··· 1 + use crate::cache::{Cache, DistributedRateLimiter, create_cache}; 2 use crate::circuit_breaker::CircuitBreakers; 3 use crate::config::AuthConfig; 4 use crate::rate_limit::RateLimiters; ··· 17 pub firehose_tx: broadcast::Sender<SequencedEvent>, 18 pub rate_limiters: Arc<RateLimiters>, 19 pub circuit_breakers: Arc<CircuitBreakers>, 20 + pub cache: Arc<dyn Cache>, 21 + pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, 22 } 23 24 impl AppState { ··· 30 let (firehose_tx, _) = broadcast::channel(1000); 31 let rate_limiters = Arc::new(RateLimiters::new()); 32 let circuit_breakers = Arc::new(CircuitBreakers::new()); 33 + let (cache, distributed_rate_limiter) = create_cache().await; 34 Self { 35 db, 36 block_store, ··· 38 firehose_tx, 39 rate_limiters, 40 circuit_breakers, 41 + cache, 42 + distributed_rate_limiter, 43 } 44 } 45
+64
tests/oauth_security.rs
··· 1447 let introspect_body: Value = introspect_res.json().await.unwrap(); 1448 assert_eq!(introspect_body["active"], false, "Revoked token should be inactive"); 1449 }
··· 1447 let introspect_body: Value = introspect_res.json().await.unwrap(); 1448 assert_eq!(introspect_body["active"], false, "Revoked token should be inactive"); 1449 } 1450 + 1451 + #[tokio::test] 1452 + async fn test_security_oauth_authorize_rate_limiting() { 1453 + let url = base_url().await; 1454 + let http_client = no_redirect_client(); 1455 + 1456 + let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0); 1457 + let unique_ip = format!("10.{}.{}.{}", (ts >> 16) & 0xFF, (ts >> 8) & 0xFF, ts & 0xFF); 1458 + 1459 + let redirect_uri = "https://example.com/rate-limit-callback"; 1460 + let mock_client = setup_mock_client_metadata(redirect_uri).await; 1461 + let client_id = mock_client.uri(); 1462 + 1463 + let (_, code_challenge) = generate_pkce(); 1464 + 1465 + let client_for_par = client(); 1466 + let par_body: Value = client_for_par 1467 + .post(format!("{}/oauth/par", url)) 1468 + .form(&[ 1469 + ("response_type", "code"), 1470 + ("client_id", &client_id), 1471 + ("redirect_uri", redirect_uri), 1472 + ("code_challenge", &code_challenge), 1473 + ("code_challenge_method", "S256"), 1474 + ]) 1475 + .send() 1476 + .await 1477 + .unwrap() 1478 + .json() 1479 + .await 1480 + .unwrap(); 1481 + 1482 + let request_uri = par_body["request_uri"].as_str().unwrap(); 1483 + 1484 + let mut rate_limited_count = 0; 1485 + let mut other_count = 0; 1486 + 1487 + for _ in 0..15 { 1488 + let res = http_client 1489 + .post(format!("{}/oauth/authorize", url)) 1490 + .header("X-Forwarded-For", &unique_ip) 1491 + .form(&[ 1492 + ("request_uri", request_uri), 1493 + ("username", "nonexistent_user"), 1494 + ("password", "wrong_password"), 1495 + ("remember_device", "false"), 1496 + ]) 1497 + .send() 1498 + .await 1499 + .unwrap(); 1500 + 1501 + match res.status() { 1502 + StatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1, 1503 + _ => other_count += 1, 1504 + } 1505 + } 1506 + 1507 + assert!( 1508 + rate_limited_count > 0, 1509 + "Expected at least one rate-limited response after 15 OAuth authorize attempts. Got {} other and {} rate limited.", 1510 + other_count, 1511 + rate_limited_count 1512 + ); 1513 + }
+228
tests/rate_limit.rs
···
··· 1 + mod common; 2 + 3 + use common::{base_url, client}; 4 + use reqwest::StatusCode; 5 + use serde_json::json; 6 + 7 + #[tokio::test] 8 + async fn test_login_rate_limiting() { 9 + let client = client(); 10 + let url = format!("{}/xrpc/com.atproto.server.createSession", base_url().await); 11 + 12 + let payload = json!({ 13 + "identifier": "nonexistent_user_for_rate_limit_test", 14 + "password": "wrongpassword" 15 + }); 16 + 17 + let mut rate_limited_count = 0; 18 + let mut auth_failed_count = 0; 19 + 20 + for _ in 0..15 { 21 + let res = client 22 + .post(&url) 23 + .json(&payload) 24 + .send() 25 + .await 26 + .expect("Request failed"); 27 + 28 + match res.status() { 29 + StatusCode::TOO_MANY_REQUESTS => { 30 + rate_limited_count += 1; 31 + } 32 + StatusCode::UNAUTHORIZED => { 33 + auth_failed_count += 1; 34 + } 35 + status => { 36 + panic!("Unexpected status: {}", status); 37 + } 38 + } 39 + } 40 + 41 + assert!( 42 + rate_limited_count > 0, 43 + "Expected at least one rate-limited response after 15 login attempts. Got {} auth failures and {} rate limits.", 44 + auth_failed_count, 45 + rate_limited_count 46 + ); 47 + } 48 + 49 + #[tokio::test] 50 + async fn test_password_reset_rate_limiting() { 51 + let client = client(); 52 + let url = format!( 53 + "{}/xrpc/com.atproto.server.requestPasswordReset", 54 + base_url().await 55 + ); 56 + 57 + let mut rate_limited_count = 0; 58 + let mut success_count = 0; 59 + 60 + for i in 0..8 { 61 + let payload = json!({ 62 + "email": format!("ratelimit_test_{}@example.com", i) 63 + }); 64 + 65 + let res = client 66 + .post(&url) 67 + .json(&payload) 68 + .send() 69 + .await 70 + .expect("Request failed"); 71 + 72 + match res.status() { 73 + StatusCode::TOO_MANY_REQUESTS => { 74 + rate_limited_count += 1; 75 + } 76 + StatusCode::OK => { 77 + success_count += 1; 78 + } 79 + status => { 80 + panic!("Unexpected status: {} - {:?}", status, res.text().await); 81 + } 82 + } 83 + } 84 + 85 + assert!( 86 + rate_limited_count > 0, 87 + "Expected rate limiting after {} password reset requests. Got {} successes.", 88 + success_count + rate_limited_count, 89 + success_count 90 + ); 91 + } 92 + 93 + #[tokio::test] 94 + async fn test_account_creation_rate_limiting() { 95 + let client = client(); 96 + let url = format!( 97 + "{}/xrpc/com.atproto.server.createAccount", 98 + base_url().await 99 + ); 100 + 101 + let mut rate_limited_count = 0; 102 + let mut other_count = 0; 103 + 104 + for i in 0..15 { 105 + let unique_id = uuid::Uuid::new_v4(); 106 + let payload = json!({ 107 + "handle": format!("ratelimit_{}_{}", i, unique_id), 108 + "email": format!("ratelimit_{}_{}@example.com", i, unique_id), 109 + "password": "testpassword123" 110 + }); 111 + 112 + let res = client 113 + .post(&url) 114 + .json(&payload) 115 + .send() 116 + .await 117 + .expect("Request failed"); 118 + 119 + match res.status() { 120 + StatusCode::TOO_MANY_REQUESTS => { 121 + rate_limited_count += 1; 122 + } 123 + _ => { 124 + other_count += 1; 125 + } 126 + } 127 + } 128 + 129 + assert!( 130 + rate_limited_count > 0, 131 + "Expected rate limiting after account creation attempts. Got {} other responses and {} rate limits.", 132 + other_count, 133 + rate_limited_count 134 + ); 135 + } 136 + 137 + #[tokio::test] 138 + async fn test_valkey_connection() { 139 + if std::env::var("VALKEY_URL").is_err() { 140 + println!("VALKEY_URL not set, skipping Valkey connection test"); 141 + return; 142 + } 143 + 144 + let valkey_url = std::env::var("VALKEY_URL").unwrap(); 145 + let client = redis::Client::open(valkey_url.as_str()).expect("Failed to create Redis client"); 146 + let mut conn = client 147 + .get_multiplexed_async_connection() 148 + .await 149 + .expect("Failed to connect to Valkey"); 150 + 151 + let pong: String = redis::cmd("PING") 152 + .query_async(&mut conn) 153 + .await 154 + .expect("PING failed"); 155 + assert_eq!(pong, "PONG"); 156 + 157 + let _: () = redis::cmd("SET") 158 + .arg("test_key") 159 + .arg("test_value") 160 + .arg("EX") 161 + .arg(10) 162 + .query_async(&mut conn) 163 + .await 164 + .expect("SET failed"); 165 + 166 + let value: String = redis::cmd("GET") 167 + .arg("test_key") 168 + .query_async(&mut conn) 169 + .await 170 + .expect("GET failed"); 171 + assert_eq!(value, "test_value"); 172 + 173 + let _: () = redis::cmd("DEL") 174 + .arg("test_key") 175 + .query_async(&mut conn) 176 + .await 177 + .expect("DEL failed"); 178 + } 179 + 180 + #[tokio::test] 181 + async fn test_distributed_rate_limiter_directly() { 182 + if std::env::var("VALKEY_URL").is_err() { 183 + println!("VALKEY_URL not set, skipping distributed rate limiter test"); 184 + return; 185 + } 186 + 187 + use bspds::cache::{DistributedRateLimiter, RedisRateLimiter}; 188 + 189 + let valkey_url = std::env::var("VALKEY_URL").unwrap(); 190 + let client = redis::Client::open(valkey_url.as_str()).expect("Failed to create Redis client"); 191 + let conn = client 192 + .get_connection_manager() 193 + .await 194 + .expect("Failed to get connection manager"); 195 + 196 + let rate_limiter = RedisRateLimiter::new(conn); 197 + 198 + let test_key = format!("test_rate_limit_{}", uuid::Uuid::new_v4()); 199 + let limit = 5; 200 + let window_ms = 60_000; 201 + 202 + for i in 0..limit { 203 + let allowed = rate_limiter 204 + .check_rate_limit(&test_key, limit, window_ms) 205 + .await; 206 + assert!( 207 + allowed, 208 + "Request {} should have been allowed (limit: {})", 209 + i + 1, 210 + limit 211 + ); 212 + } 213 + 214 + let allowed = rate_limiter 215 + .check_rate_limit(&test_key, limit, window_ms) 216 + .await; 217 + assert!( 218 + !allowed, 219 + "Request {} should have been rate limited (limit: {})", 220 + limit + 1, 221 + limit 222 + ); 223 + 224 + let allowed = rate_limiter 225 + .check_rate_limit(&test_key, limit, window_ms) 226 + .await; 227 + assert!(!allowed, "Subsequent request should also be rate limited"); 228 + }