
My misc TODOs

lewis 4a767dea 8a10a7a6

+34
.sqlx/query-08c08b0644d79d5de72f3500dd7dbb8827af340e3c04fec9a5c28aeff46e0c97.json
···
+{
+  "db_name": "PostgreSQL",
+  "query": "SELECT id, password_hash, handle FROM users WHERE did = $1",
+  "describe": {
+    "columns": [
+      {
+        "ordinal": 0,
+        "name": "id",
+        "type_info": "Uuid"
+      },
+      {
+        "ordinal": 1,
+        "name": "password_hash",
+        "type_info": "Text"
+      },
+      {
+        "ordinal": 2,
+        "name": "handle",
+        "type_info": "Text"
+      }
+    ],
+    "parameters": {
+      "Left": [
+        "Text"
+      ]
+    },
+    "nullable": [
+      false,
+      false,
+      false
+    ]
+  },
+  "hash": "08c08b0644d79d5de72f3500dd7dbb8827af340e3c04fec9a5c28aeff46e0c97"
+}
-28
.sqlx/query-76c6ef1d5395105a0cdedb27ca321c9e3eae1ce87c223b706ed81ebf973875f3.json
···
-{
-  "db_name": "PostgreSQL",
-  "query": "SELECT id, password_hash FROM users WHERE did = $1",
-  "describe": {
-    "columns": [
-      {
-        "ordinal": 0,
-        "name": "id",
-        "type_info": "Uuid"
-      },
-      {
-        "ordinal": 1,
-        "name": "password_hash",
-        "type_info": "Text"
-      }
-    ],
-    "parameters": {
-      "Left": [
-        "Text"
-      ]
-    },
-    "nullable": [
-      false,
-      false
-    ]
-  },
-  "hash": "76c6ef1d5395105a0cdedb27ca321c9e3eae1ce87c223b706ed81ebf973875f3"
-}
+22
.sqlx/query-e223898d53602c1c8b23eb08a4b96cf20ac349d1fa4e91334b225d3069209dcf.json
···
+{
+  "db_name": "PostgreSQL",
+  "query": "SELECT handle FROM users WHERE id = $1",
+  "describe": {
+    "columns": [
+      {
+        "ordinal": 0,
+        "name": "handle",
+        "type_info": "Text"
+      }
+    ],
+    "parameters": {
+      "Left": [
+        "Uuid"
+      ]
+    },
+    "nullable": [
+      false
+    ]
+  },
+  "hash": "e223898d53602c1c8b23eb08a4b96cf20ac349d1fa4e91334b225d3069209dcf"
+}
+73 -2
Cargo.lock
···
 checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"

 [[package]]
+name = "arc-swap"
+version = "1.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
+
+[[package]]
 name = "assert-json-diff"
 version = "2.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
···
 ]

 [[package]]
+name = "backon"
+version = "1.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cffb0e931875b666fc4fcb20fee52e9bbd1ef836fd9e9e04ec21555f9f85f7ef"
+dependencies = [
+ "fastrand",
+]
+
+[[package]]
 name = "base-x"
 version = "0.2.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
···
  "p256 0.13.2",
  "p384",
  "rand 0.8.5",
+ "redis",
  "reqwest",
  "serde",
  "serde_bytes",
···
 version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
+
+[[package]]
+name = "combine"
+version = "4.6.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
+dependencies = [
+ "bytes",
+ "futures-core",
+ "memchr",
+ "pin-project-lite",
+ "tokio",
+ "tokio-util",
+]

 [[package]]
 name = "compression-codecs"
···

 [[package]]
 name = "itertools"
+version = "0.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
+dependencies = [
+ "either",
+]
+
+[[package]]
+name = "itertools"
 version = "0.14.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
···
 checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425"
 dependencies = [
  "anyhow",
- "itertools",
+ "itertools 0.14.0",
  "proc-macro2",
  "quote",
  "syn 2.0.111",
···
 ]

 [[package]]
+name = "redis"
+version = "0.27.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09d8f99a4090c89cc489a94833c901ead69bfbf3877b4867d5482e321ee875bc"
+dependencies = [
+ "arc-swap",
+ "async-trait",
+ "backon",
+ "bytes",
+ "combine",
+ "futures",
+ "futures-util",
+ "itertools 0.13.0",
+ "itoa",
+ "num-bigint",
+ "percent-encoding",
+ "pin-project-lite",
+ "ryu",
+ "sha1_smol",
+ "socket2 0.5.10",
+ "tokio",
+ "tokio-util",
+ "url",
+]
+
+[[package]]
 name = "redox_syscall"
 version = "0.5.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
···
 ]

 [[package]]
+name = "sha1_smol"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d"
+
+[[package]]
 name = "sha2"
 version = "0.10.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
···
  "etcetera 0.11.0",
  "ferroid",
  "futures",
- "itertools",
+ "itertools 0.14.0",
  "log",
  "memchr",
  "parse-display",
+1
Cargo.toml
···
 uuid = { version = "1.19.0", features = ["v4", "fast-rng"] }
 iroh-car = "0.5.1"
 image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
+redis = { version = "0.27", features = ["tokio-comp", "connection-manager"] }

 [features]
 external-infra = []
+17 -7
TODO.md
···
 - [x] Implement Atomic Repo Transactions.
 - [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction.
 - [x] Implement concurrency control (row-level locking via FOR UPDATE).
-- [ ] DID Cache
-- [ ] Implement caching layer for DID resolution (Redis or in-memory).
-- [ ] Handle cache invalidation/expiry.
+- [x] DID Cache
+- [x] Implement caching layer for DID resolution (valkey).
+- [x] Handle cache invalidation/expiry.
+- [x] Graceful fallback to no-cache when Valkey unavailable.
 - [x] Crawlers Service
 - [x] Implement `Crawlers` service (debounce notifications to relays).
 - [x] 20-minute notification debounce.
···
 - [x] Per-IP rate limiting on OAuth token endpoint (30/min).
 - [x] Per-IP rate limiting on password reset (5/hour).
 - [x] Per-IP rate limiting on account creation (10/hour).
+- [x] Per-IP rate limiting on refreshSession (60/min).
+- [x] Per-IP rate limiting on OAuth authorize POST (10/min).
+- [x] Per-IP rate limiting on OAuth 2FA POST (10/min).
+- [x] Per-IP rate limiting on OAuth PAR (30/min).
+- [x] Per-IP rate limiting on OAuth revoke/introspect (30/min).
+- [x] Per-IP rate limiting on createAppPassword (10/min).
+- [x] Per-IP rate limiting on email endpoints (5/hour).
+- [x] Distributed rate limiting via Valkey/Redis (with in-memory fallback).
 - [x] Circuit Breakers
 - [x] PLC directory circuit breaker (5 failures → open, 60s timeout).
 - [x] Relay notification circuit breaker (10 failures → open, 30s timeout).
···
 - [x] Signal command injection prevention (phone number validation).
 - [x] Constant-time signature comparison.
 - [x] SSRF protection for outbound requests.
+- [x] Timing attack protection (dummy bcrypt on user-not-found prevents account enumeration).

 ## Lewis' fabulous mini-list of remaining TODOs
-- [ ] The OAuth authorize POST endpoint has no rate limiting, allowing password brute-forcing. Fix this and audit all oauth and 2fa surface again.
-- [ ] DID resolution caching (valkey).
-- [ ] Record schema validation (generic validation framework).
-- [ ] Fix any remaining TODOs in the code.
+- [x] The OAuth authorize POST endpoint has no rate limiting, allowing password brute-forcing. Fix this and audit all oauth and 2fa surface again.
+- [x] DID resolution caching (valkey).
+- [x] Record schema validation (generic validation framework).
+- [x] Fix any remaining TODOs in the code.

 ## Future: Web Management UI
 A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers.
+10
docker-compose.yaml
···
     environment:
       DATABASE_URL: postgres://postgres:postgres@db:5432/pds
       S3_ENDPOINT: http://objsto:9000
+      VALKEY_URL: redis://cache:6379
     depends_on:
       - db
       - objsto
+      - cache

   db:
     image: postgres:latest
···
       - minio_data:/data
     command: server /data --console-address ":9001"

+  cache:
+    image: valkey/valkey:8-alpine
+    ports:
+      - "6379:6379"
+    volumes:
+      - valkey_data:/data
+
 volumes:
   postgres_data:
   minio_data:
+  valkey_data:
+20 -3
scripts/test-infra.sh
···
   rm -f "$INFRA_FILE"
 fi

-$CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" 2>/dev/null || true
+$CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" "${CONTAINER_PREFIX}-valkey" 2>/dev/null || true

 echo "Starting PostgreSQL..."
 $CONTAINER_CMD run -d \
···
   --label bspds_test=true \
   minio/minio:latest server /data >/dev/null

+echo "Starting Valkey..."
+$CONTAINER_CMD run -d \
+  --name "${CONTAINER_PREFIX}-valkey" \
+  -P \
+  --label bspds_test=true \
+  valkey/valkey:8-alpine >/dev/null
+
 echo "Waiting for services to be ready..."
 sleep 2

 PG_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-postgres" 5432 | head -1 | cut -d: -f2)
 MINIO_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-minio" 9000 | head -1 | cut -d: -f2)
+VALKEY_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-valkey" 6379 | head -1 | cut -d: -f2)

 for i in {1..30}; do
   if $CONTAINER_CMD exec "${CONTAINER_PREFIX}-postgres" pg_isready -U postgres >/dev/null 2>&1; then
···
   sleep 1
 done

+for i in {1..30}; do
+  if $CONTAINER_CMD exec "${CONTAINER_PREFIX}-valkey" valkey-cli ping 2>/dev/null | grep -q PONG; then
+    break
+  fi
+  echo "Waiting for Valkey... ($i/30)"
+  sleep 1
+done
+
 echo "Creating MinIO bucket..."
 $CONTAINER_CMD run --rm --network host \
   -e MC_HOST_minio="http://minioadmin:minioadmin@127.0.0.1:${MINIO_PORT}" \
···
 export AWS_ACCESS_KEY_ID="minioadmin"
 export AWS_SECRET_ACCESS_KEY="minioadmin"
 export AWS_REGION="us-east-1"
+export VALKEY_URL="redis://127.0.0.1:${VALKEY_PORT}"
 export BSPDS_TEST_INFRA_READY="1"
 export BSPDS_ALLOW_INSECURE_SECRETS="1"
 export SKIP_IMPORT_VERIFICATION="true"
···

 stop_infra() {
   echo "Stopping test infrastructure..."
-  $CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" 2>/dev/null || true
+  $CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" "${CONTAINER_PREFIX}-valkey" 2>/dev/null || true
   rm -f "$INFRA_FILE"
   echo "Infrastructure stopped."
 }
···
 echo "Usage: $0 {start|stop|restart|status|env}"
 echo ""
 echo "Commands:"
-echo "  start   - Start test infrastructure (Postgres, MinIO)"
+echo "  start   - Start test infrastructure (Postgres, MinIO, Valkey)"
 echo "  stop    - Stop and remove test containers"
 echo "  restart - Stop then start infrastructure"
 echo "  status  - Show infrastructure status"
+5 -3
src/api/admin/account/delete.rs
···
         .into_response();
     }

-    let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
+    let user = sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
         .fetch_optional(&state.db)
         .await;

-    let user_id = match user {
-        Ok(Some(row)) => row.id,
+    let (user_id, handle) = match user {
+        Ok(Some(row)) => (row.id, row.handle),
         Ok(None) => {
             return (
                 StatusCode::NOT_FOUND,
···
         )
         .into_response();
     }
+
+    let _ = state.cache.delete(&format!("handle:{}", handle)).await;

     (StatusCode::OK, Json(json!({}))).into_response()
 }
+10
src/api/admin/account/update.rs
···
         .into_response();
     }

+    let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
+        .fetch_optional(&state.db)
+        .await
+        .ok()
+        .flatten();
+
     let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did)
         .fetch_optional(&state.db)
         .await;
···
                 )
                 .into_response();
             }
+            if let Some(old) = old_handle {
+                let _ = state.cache.delete(&format!("handle:{}", old)).await;
+            }
+            let _ = state.cache.delete(&format!("handle:{}", handle)).await;
             (StatusCode::OK, Json(json!({}))).into_response()
         }
         Err(e) => {
+7
src/api/admin/status.rs
···
         .into_response();
     }

+    if let Ok(Some(handle)) = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
+        .fetch_optional(&state.db)
+        .await
+    {
+        let _ = state.cache.delete(&format!("handle:{}", handle)).await;
+    }
+
     return (
         StatusCode::OK,
         Json(json!({
+19 -1
src/api/identity/did.rs
···
         .into_response();
     }

+    let cache_key = format!("handle:{}", handle);
+    if let Some(did) = state.cache.get(&cache_key).await {
+        return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
+    }
+
     let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle)
         .fetch_optional(&state.db)
         .await;

     match user {
         Ok(Some(row)) => {
+            let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await;
             (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
         }
         Ok(None) => (
···
         .into_response();
     }

+    let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
+        .fetch_optional(&state.db)
+        .await
+        .ok()
+        .flatten();
+
     let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
         .fetch_optional(&state.db)
         .await;
···
         .await;

     match result {
-        Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
+        Ok(_) => {
+            if let Some(old) = old_handle {
+                let _ = state.cache.delete(&format!("handle:{}", old)).await;
+            }
+            let _ = state.cache.delete(&format!("handle:{}", new_handle)).await;
+            (StatusCode::OK, Json(json!({}))).into_response()
+        }
         Err(e) => {
             error!("DB error updating handle: {:?}", e);
             (
+11
src/api/repo/record/batch.rs
···
+use super::validation::validate_record;
 use crate::api::repo::record::utils::{commit_and_log, RecordOp};
 use crate::repo::tracking::TrackingBlockStore;
 use crate::state::AppState;
···
             rkey,
             value,
         } => {
+            if input.validate.unwrap_or(true) {
+                if let Err(err_response) = validate_record(value, collection) {
+                    return err_response;
+                }
+            }
             let rkey = rkey
                 .clone()
                 .unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
···
             rkey,
             value,
         } => {
+            if input.validate.unwrap_or(true) {
+                if let Err(err_response) = validate_record(value, collection) {
+                    return err_response;
+                }
+            }
             let mut record_bytes = Vec::new();
             if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
                 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
+1
src/api/repo/record/mod.rs
···
 pub mod delete;
 pub mod read;
 pub mod utils;
+pub mod validation;
 pub mod write;

 pub use batch::apply_writes;
+38
src/api/repo/record/validation.rs
···
+use crate::validation::{RecordValidator, ValidationError};
+use axum::{
+    http::StatusCode,
+    response::{IntoResponse, Response},
+    Json,
+};
+use serde_json::json;
+
+pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Response> {
+    let validator = RecordValidator::new();
+    match validator.validate(record, collection) {
+        Ok(_) => Ok(()),
+        Err(ValidationError::MissingType) => Err((
+            StatusCode::BAD_REQUEST,
+            Json(json!({"error": "InvalidRecord", "message": "Record must have a $type field"})),
+        ).into_response()),
+        Err(ValidationError::TypeMismatch { expected, actual }) => Err((
+            StatusCode::BAD_REQUEST,
+            Json(json!({"error": "InvalidRecord", "message": format!("Record $type '{}' does not match collection '{}'", actual, expected)})),
+        ).into_response()),
+        Err(ValidationError::MissingField(field)) => Err((
+            StatusCode::BAD_REQUEST,
+            Json(json!({"error": "InvalidRecord", "message": format!("Missing required field: {}", field)})),
+        ).into_response()),
+        Err(ValidationError::InvalidField { path, message }) => Err((
+            StatusCode::BAD_REQUEST,
+            Json(json!({"error": "InvalidRecord", "message": format!("Invalid field '{}': {}", path, message)})),
+        ).into_response()),
+        Err(ValidationError::InvalidDatetime { path }) => Err((
+            StatusCode::BAD_REQUEST,
+            Json(json!({"error": "InvalidRecord", "message": format!("Invalid datetime format at '{}'", path)})),
+        ).into_response()),
+        Err(e) => Err((
+            StatusCode::BAD_REQUEST,
+            Json(json!({"error": "InvalidRecord", "message": e.to_string()})),
+        ).into_response()),
+    }
+}
+5 -16
src/api/repo/record/write.rs
···
+use super::validation::validate_record;
 use crate::api::repo::record::utils::{commit_and_log, RecordOp};
 use crate::repo::tracking::TrackingBlockStore;
 use crate::state::AppState;
···
     };

     if input.validate.unwrap_or(true) {
-        if input.collection == "app.bsky.feed.post" {
-            if input.record.get("text").is_none() || input.record.get("createdAt").is_none() {
-                return (
-                    StatusCode::BAD_REQUEST,
-                    Json(json!({"error": "InvalidRecord", "message": "Record validation failed"})),
-                )
-                .into_response();
-            }
+        if let Err(err_response) = validate_record(&input.record, &input.collection) {
+            return err_response;
         }
     }
···
     let key = format!("{}/{}", collection_nsid, input.rkey);

     if input.validate.unwrap_or(true) {
-        if input.collection == "app.bsky.feed.post" {
-            if input.record.get("text").is_none() || input.record.get("createdAt").is_none() {
-                return (
-                    StatusCode::BAD_REQUEST,
-                    Json(json!({"error": "InvalidRecord", "message": "Record validation failed"})),
-                )
-                .into_response();
-            }
+        if let Err(err_response) = validate_record(&input.record, &input.collection) {
+            return err_response;
         }
     }
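For reference, a minimal sketch (not part of the commit) of the record shape the replaced inline check accepted for `app.bsky.feed.post`, which the new `validate_record` path is expected to accept as well; the exact rules live in `crate::validation::RecordValidator` and are not shown in this diff.

```rust
use serde_json::json;

fn main() {
    // Hypothetical example record: `$type` matches the collection, and the fields
    // the old ad-hoc check looked for (`text`, `createdAt`) are present.
    let record = json!({
        "$type": "app.bsky.feed.post",
        "text": "hello world",
        "createdAt": "2024-01-01T00:00:00.000Z"
    });
    assert!(record.get("text").is_some() && record.get("createdAt").is_some());
}
```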
+28 -5
src/api/server/account_status.rs
···
         Err(e) => return ApiError::from(e).into_response(),
     };

+    let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
+        .fetch_optional(&state.db)
+        .await
+        .ok()
+        .flatten();
+
     let result = sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did)
         .execute(&state.db)
         .await;

     match result {
-        Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
+        Ok(_) => {
+            if let Some(h) = handle {
+                let _ = state.cache.delete(&format!("handle:{}", h)).await;
+            }
+            (StatusCode::OK, Json(json!({}))).into_response()
+        }
         Err(e) => {
             error!("DB error activating account: {:?}", e);
             (
···
         Err(e) => return ApiError::from(e).into_response(),
     };

+    let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
+        .fetch_optional(&state.db)
+        .await
+        .ok()
+        .flatten();
+
     let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did)
         .execute(&state.db)
         .await;

     match result {
-        Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
+        Ok(_) => {
+            if let Some(h) = handle {
+                let _ = state.cache.delete(&format!("handle:{}", h)).await;
+            }
+            (StatusCode::OK, Json(json!({}))).into_response()
+        }
         Err(e) => {
             error!("DB error deactivating account: {:?}", e);
             (
···
     }

     let user = sqlx::query!(
-        "SELECT id, password_hash FROM users WHERE did = $1",
+        "SELECT id, password_hash, handle FROM users WHERE did = $1",
         did
     )
     .fetch_optional(&state.db)
     .await;

-    let (user_id, password_hash) = match user {
-        Ok(Some(row)) => (row.id, row.password_hash),
+    let (user_id, password_hash, handle) = match user {
+        Ok(Some(row)) => (row.id, row.password_hash, row.handle),
         Ok(None) => {
             return (
                 StatusCode::BAD_REQUEST,
···
         )
         .into_response();
     }
+    let _ = state.cache.delete(&format!("handle:{}", handle)).await;
     info!("Account {} deleted successfully", did);
     (StatusCode::OK, Json(json!({}))).into_response()
 }
+21 -1
src/api/server/app_password.rs
···
 use axum::{
     Json,
     extract::State,
+    http::HeaderMap,
     response::{IntoResponse, Response},
 };
 use serde::{Deserialize, Serialize};
 use serde_json::json;
-use tracing::error;
+use tracing::{error, warn};

 #[derive(Serialize)]
 #[serde(rename_all = "camelCase")]
···

 pub async fn create_app_password(
     State(state): State<AppState>,
+    headers: HeaderMap,
     BearerAuth(auth_user): BearerAuth,
     Json(input): Json<CreateAppPasswordInput>,
 ) -> Response {
+    let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
+    if !state.distributed_rate_limiter.check_rate_limit(
+        &format!("app_password:{}", client_ip),
+        10,
+        60_000,
+    ).await {
+        if state.rate_limiters.app_password.check_key(&client_ip).is_err() {
+            warn!(ip = %client_ip, "App password creation rate limit exceeded");
+            return (
+                axum::http::StatusCode::TOO_MANY_REQUESTS,
+                Json(json!({
+                    "error": "RateLimitExceeded",
+                    "message": "Too many requests. Please try again later."
+                })),
+            ).into_response();
+        }
+    }
+
     let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
         Ok(id) => id,
         Err(e) => return ApiError::from(e).into_response(),
+36
src/api/server/email.rs
···
     headers: axum::http::HeaderMap,
     Json(input): Json<RequestEmailUpdateInput>,
 ) -> Response {
+    let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
+    if !state.distributed_rate_limiter.check_rate_limit(
+        &format!("email_update:{}", client_ip),
+        5,
+        3_600_000,
+    ).await {
+        if state.rate_limiters.email_update.check_key(&client_ip).is_err() {
+            warn!(ip = %client_ip, "Email update rate limit exceeded");
+            return (
+                StatusCode::TOO_MANY_REQUESTS,
+                Json(json!({
+                    "error": "RateLimitExceeded",
+                    "message": "Too many requests. Please try again later."
+                })),
+            ).into_response();
+        }
+    }
+
     let token = match crate::auth::extract_bearer_token_from_header(
         headers.get("Authorization").and_then(|h| h.to_str().ok())
     ) {
···
     headers: axum::http::HeaderMap,
     Json(input): Json<ConfirmEmailInput>,
 ) -> Response {
+    let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
+    if !state.distributed_rate_limiter.check_rate_limit(
+        &format!("confirm_email:{}", client_ip),
+        10,
+        60_000,
+    ).await {
+        if state.rate_limiters.app_password.check_key(&client_ip).is_err() {
+            warn!(ip = %client_ip, "Confirm email rate limit exceeded");
+            return (
+                StatusCode::TOO_MANY_REQUESTS,
+                Json(json!({
+                    "error": "RateLimitExceeded",
+                    "message": "Too many requests. Please try again later."
+                })),
+            ).into_response();
+        }
+    }
+
     let token = match crate::auth::extract_bearer_token_from_header(
         headers.get("Authorization").and_then(|h| h.to_str().ok())
     ) {
+19
src/api/server/password.rs
···

 pub async fn reset_password(
     State(state): State<AppState>,
+    headers: HeaderMap,
     Json(input): Json<ResetPasswordInput>,
 ) -> Response {
+    let client_ip = extract_client_ip(&headers);
+    if !state.distributed_rate_limiter.check_rate_limit(
+        &format!("reset_password:{}", client_ip),
+        10,
+        60_000,
+    ).await {
+        if state.rate_limiters.reset_password.check_key(&client_ip).is_err() {
+            warn!(ip = %client_ip, "Reset password rate limit exceeded");
+            return (
+                StatusCode::TOO_MANY_REQUESTS,
+                Json(json!({
+                    "error": "RateLimitExceeded",
+                    "message": "Too many requests. Please try again later."
+                })),
+            ).into_response();
+        }
+    }
+
     let token = input.token.trim();
     let password = &input.password;
+19
src/api/server/session.rs
···
     {
         Ok(Some(row)) => row,
         Ok(None) => {
+            let _ = verify(&input.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK");
             warn!("User not found for login attempt");
             return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response();
         }
···
     State(state): State<AppState>,
     headers: axum::http::HeaderMap,
 ) -> Response {
+    let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
+    if !state.distributed_rate_limiter.check_rate_limit(
+        &format!("refresh_session:{}", client_ip),
+        60,
+        60_000,
+    ).await {
+        if state.rate_limiters.refresh_session.check_key(&client_ip).is_err() {
+            tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded");
+            return (
+                axum::http::StatusCode::TOO_MANY_REQUESTS,
+                axum::Json(serde_json::json!({
+                    "error": "RateLimitExceeded",
+                    "message": "Too many requests. Please try again later."
+                })),
+            ).into_response();
+        }
+    }
+
     let refresh_token = match crate::auth::extract_bearer_token_from_header(
         headers.get("Authorization").and_then(|h| h.to_str().ok())
     ) {
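The hard-coded hash in the `Ok(None)` arm above is a throwaway value: verifying against it makes the "unknown user" path cost roughly one bcrypt verification, the same as a real login, so response timing does not reveal whether an account exists. A minimal sketch of the idea, assuming the `bcrypt` crate already used by this handler (the constant below is the same placeholder hash as in the diff; any valid bcrypt hash of comparable cost would do):

```rust
use bcrypt::verify;

// Placeholder hash; its only job is to make the "unknown user" path
// perform a bcrypt verification of comparable cost to a real check.
const DUMMY_HASH: &str = "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK";

fn check_password(stored_hash: Option<&str>, candidate: &str) -> bool {
    match stored_hash {
        Some(hash) => verify(candidate, hash).unwrap_or(false),
        None => {
            // Burn roughly the same time as a real verification, then fail.
            let _ = verify(candidate, DUMMY_HASH);
            false
        }
    }
}
```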
+207
src/cache/mod.rs
···
+use async_trait::async_trait;
+use std::sync::Arc;
+use std::time::Duration;
+
+#[derive(Debug, thiserror::Error)]
+pub enum CacheError {
+    #[error("Cache connection error: {0}")]
+    Connection(String),
+    #[error("Serialization error: {0}")]
+    Serialization(String),
+}
+
+#[async_trait]
+pub trait Cache: Send + Sync {
+    async fn get(&self, key: &str) -> Option<String>;
+    async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>;
+    async fn delete(&self, key: &str) -> Result<(), CacheError>;
+}
+
+#[derive(Clone)]
+pub struct ValkeyCache {
+    conn: redis::aio::ConnectionManager,
+}
+
+impl ValkeyCache {
+    pub async fn new(url: &str) -> Result<Self, CacheError> {
+        let client = redis::Client::open(url)
+            .map_err(|e| CacheError::Connection(e.to_string()))?;
+        let manager = client
+            .get_connection_manager()
+            .await
+            .map_err(|e| CacheError::Connection(e.to_string()))?;
+        Ok(Self { conn: manager })
+    }
+
+    pub fn connection(&self) -> redis::aio::ConnectionManager {
+        self.conn.clone()
+    }
+}
+
+#[async_trait]
+impl Cache for ValkeyCache {
+    async fn get(&self, key: &str) -> Option<String> {
+        let mut conn = self.conn.clone();
+        redis::cmd("GET")
+            .arg(key)
+            .query_async::<Option<String>>(&mut conn)
+            .await
+            .ok()
+            .flatten()
+    }
+
+    async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
+        let mut conn = self.conn.clone();
+        redis::cmd("SET")
+            .arg(key)
+            .arg(value)
+            .arg("EX")
+            .arg(ttl.as_secs() as i64)
+            .query_async::<()>(&mut conn)
+            .await
+            .map_err(|e| CacheError::Connection(e.to_string()))
+    }
+
+    async fn delete(&self, key: &str) -> Result<(), CacheError> {
+        let mut conn = self.conn.clone();
+        redis::cmd("DEL")
+            .arg(key)
+            .query_async::<()>(&mut conn)
+            .await
+            .map_err(|e| CacheError::Connection(e.to_string()))
+    }
+}
+
+pub struct NoOpCache;
+
+#[async_trait]
+impl Cache for NoOpCache {
+    async fn get(&self, _key: &str) -> Option<String> {
+        None
+    }
+
+    async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
+        Ok(())
+    }
+
+    async fn delete(&self, _key: &str) -> Result<(), CacheError> {
+        Ok(())
+    }
+}
+
+#[async_trait]
+pub trait DistributedRateLimiter: Send + Sync {
+    async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
+}
+
+#[derive(Clone)]
+pub struct RedisRateLimiter {
+    conn: redis::aio::ConnectionManager,
+}
+
+impl RedisRateLimiter {
+    pub fn new(conn: redis::aio::ConnectionManager) -> Self {
+        Self { conn }
+    }
+}
+
+#[async_trait]
+impl DistributedRateLimiter for RedisRateLimiter {
+    async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
+        let mut conn = self.conn.clone();
+        let full_key = format!("rl:{}", key);
+        let window_secs = ((window_ms + 999) / 1000).max(1) as i64;
+
+        let count: Result<i64, _> = redis::cmd("INCR")
+            .arg(&full_key)
+            .query_async(&mut conn)
+            .await;
+
+        let count = match count {
+            Ok(c) => c,
+            Err(e) => {
+                tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e);
+                return true;
+            }
+        };
+
+        if count == 1 {
+            let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE")
+                .arg(&full_key)
+                .arg(window_secs)
+                .query_async(&mut conn)
+                .await;
+        }
+
+        count <= limit as i64
+    }
+}
+
+pub struct NoOpRateLimiter;
+
+#[async_trait]
+impl DistributedRateLimiter for NoOpRateLimiter {
+    async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
+        true
+    }
+}
+
+pub enum CacheBackend {
+    Valkey(ValkeyCache),
+    NoOp,
+}
+
+impl CacheBackend {
+    pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
+        match self {
+            CacheBackend::Valkey(cache) => {
+                Arc::new(RedisRateLimiter::new(cache.connection()))
+            }
+            CacheBackend::NoOp => Arc::new(NoOpRateLimiter),
+        }
+    }
+}
+
+#[async_trait]
+impl Cache for CacheBackend {
+    async fn get(&self, key: &str) -> Option<String> {
+        match self {
+            CacheBackend::Valkey(c) => c.get(key).await,
+            CacheBackend::NoOp => None,
+        }
+    }
+
+    async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
+        match self {
+            CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
+            CacheBackend::NoOp => Ok(()),
+        }
+    }
+
+    async fn delete(&self, key: &str) -> Result<(), CacheError> {
+        match self {
+            CacheBackend::Valkey(c) => c.delete(key).await,
+            CacheBackend::NoOp => Ok(()),
+        }
+    }
+}
+
+pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
+    match std::env::var("VALKEY_URL") {
+        Ok(url) => match ValkeyCache::new(&url).await {
+            Ok(cache) => {
+                tracing::info!("Connected to Valkey cache at {}", url);
+                let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection()));
+                (Arc::new(cache), rate_limiter)
+            }
+            Err(e) => {
+                tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e);
+                (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
+            }
+        },
+        Err(_) => {
+            tracing::info!("VALKEY_URL not set. Running without cache.");
+            (Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
+        }
+    }
+}
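A minimal usage sketch (not part of the commit) of the cache-aside pattern the handler changes above follow. It assumes the `bspds` crate name used by the integration tests and a Tokio runtime with the `macros` feature; `create_cache()` returns the no-op implementations when `VALKEY_URL` is unset or Valkey is unreachable, so callers never branch on whether caching is enabled.

```rust
use bspds::cache::create_cache;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Falls back to NoOpCache / NoOpRateLimiter when Valkey is unavailable.
    let (cache, _limiter) = create_cache().await;

    let key = "handle:alice.example.com";
    match cache.get(key).await {
        Some(did) => println!("cache hit: {did}"),
        None => {
            // ...resolve the handle in Postgres here, then cache it for 5 minutes,
            // matching the TTL used in the resolveHandle handler above...
            let _ = cache.set(key, "did:plc:example", Duration::from_secs(300)).await;
        }
    }
}
```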
+1
src/lib.rs
···
 pub mod api;
 pub mod auth;
+pub mod cache;
 pub mod circuit_breaker;
 pub mod config;
 pub mod crawlers;
+37 -1
src/oauth/endpoints/authorize.rs
···
 ) -> Response {
     let json_response = wants_json(&headers);

+    let client_ip = extract_client_ip(&headers);
+    if state.rate_limiters.oauth_authorize.check_key(&client_ip).is_err() {
+        tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded");
+        if json_response {
+            return (
+                axum::http::StatusCode::TOO_MANY_REQUESTS,
+                Json(serde_json::json!({
+                    "error": "RateLimitExceeded",
+                    "error_description": "Too many login attempts. Please try again later."
+                })),
+            ).into_response();
+        }
+        return (
+            axum::http::StatusCode::TOO_MANY_REQUESTS,
+            Html(templates::error_page(
+                "RateLimitExceeded",
+                Some("Too many login attempts. Please try again later."),
+            )),
+        ).into_response();
+    }
+
     let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
         Ok(Some(data)) => data,
         Ok(None) => {
···
         .await
     {
         Ok(Some(u)) => u,
-        Ok(None) => return show_login_error("Invalid handle/email or password.", json_response),
+        Ok(None) => {
+            let _ = bcrypt::verify(&form.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK");
+            return show_login_error("Invalid handle/email or password.", json_response);
+        }
         Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
     };
···
     headers: HeaderMap,
     Form(form): Form<Authorize2faSubmit>,
 ) -> Response {
+    let client_ip = extract_client_ip(&headers);
+    if state.rate_limiters.oauth_authorize.check_key(&client_ip).is_err() {
+        tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded");
+        return (
+            axum::http::StatusCode::TOO_MANY_REQUESTS,
+            Html(templates::error_page(
+                "RateLimitExceeded",
+                Some("Too many attempts. Please try again later."),
+            )),
+        ).into_response();
+    }
+
     let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await {
         Ok(Some(c)) => c,
         Ok(None) => {
+14
src/oauth/endpoints/par.rs
···
 use axum::{
     Form, Json,
     extract::State,
+    http::HeaderMap,
 };
 use chrono::{Duration, Utc};
 use serde::{Deserialize, Serialize};
···

 pub async fn pushed_authorization_request(
     State(state): State<AppState>,
+    headers: HeaderMap,
     Form(request): Form<ParRequest>,
 ) -> Result<Json<ParResponse>, OAuthError> {
+    let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
+    if !state.distributed_rate_limiter.check_rate_limit(
+        &format!("oauth_par:{}", client_ip),
+        30,
+        60_000,
+    ).await {
+        if state.rate_limiters.oauth_par.check_key(&client_ip).is_err() {
+            tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
+            return Err(OAuthError::RateLimited);
+        }
+    }
+
     if request.response_type != "code" {
         return Err(OAuthError::InvalidRequest(
             "response_type must be 'code'".to_string(),
+33 -7
src/oauth/endpoints/token/introspect.rs
···
 use axum::{Form, Json};
 use axum::extract::State;
-use axum::http::StatusCode;
+use axum::http::{HeaderMap, StatusCode};
 use chrono::Utc;
 use serde::{Deserialize, Serialize};
···

 pub async fn revoke_token(
     State(state): State<AppState>,
+    headers: HeaderMap,
     Form(request): Form<RevokeRequest>,
 ) -> Result<StatusCode, OAuthError> {
+    let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
+    if !state.distributed_rate_limiter.check_rate_limit(
+        &format!("oauth_revoke:{}", client_ip),
+        30,
+        60_000,
+    ).await {
+        if state.rate_limiters.oauth_introspect.check_key(&client_ip).is_err() {
+            tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded");
+            return Err(OAuthError::RateLimited);
+        }
+    }
+
     if let Some(token) = &request.token {
         if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
             db::delete_token_family(&state.db, db_id).await?;
···

 pub async fn introspect_token(
     State(state): State<AppState>,
+    headers: HeaderMap,
     Form(request): Form<IntrospectRequest>,
-) -> Json<IntrospectResponse> {
+) -> Result<Json<IntrospectResponse>, OAuthError> {
+    let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
+    if !state.distributed_rate_limiter.check_rate_limit(
+        &format!("oauth_introspect:{}", client_ip),
+        30,
+        60_000,
+    ).await {
+        if state.rate_limiters.oauth_introspect.check_key(&client_ip).is_err() {
+            tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded");
+            return Err(OAuthError::RateLimited);
+        }
+    }
+
     let inactive_response = IntrospectResponse {
         active: false,
         scope: None,
···

     let token_info = match extract_token_claims(&request.token) {
         Ok(info) => info,
-        Err(_) => return Json(inactive_response),
+        Err(_) => return Ok(Json(inactive_response)),
     };

     let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
         Ok(Some(data)) => data,
-        _ => return Json(inactive_response),
+        _ => return Ok(Json(inactive_response)),
     };

     if token_data.expires_at < Utc::now() {
-        return Json(inactive_response);
+        return Ok(Json(inactive_response));
     }

     let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
     let issuer = format!("https://{}", pds_hostname);

-    Json(IntrospectResponse {
+    Ok(Json(IntrospectResponse {
         active: true,
         scope: token_data.scope,
         client_id: Some(token_data.client_id),
···
         aud: Some(issuer.clone()),
         iss: Some(issuer),
         jti: Some(token_info.jti),
-    })
+    }))
 }
+4
src/oauth/error.rs
···
     InvalidDpopProof(String),
     ExpiredToken(String),
     InvalidToken(String),
+    RateLimited,
 }

 #[derive(Serialize)]
···
         }
         OAuthError::InvalidToken(msg) => {
             (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg))
+        }
+        OAuthError::RateLimited => {
+            (StatusCode::TOO_MANY_REQUESTS, "rate_limited", Some("Too many requests. Please try again later.".to_string()))
         }
     };
+36 -1
src/rate_limit.rs
···
 pub struct RateLimiters {
     pub login: Arc<KeyedRateLimiter>,
     pub oauth_token: Arc<KeyedRateLimiter>,
+    pub oauth_authorize: Arc<KeyedRateLimiter>,
     pub password_reset: Arc<KeyedRateLimiter>,
     pub account_creation: Arc<KeyedRateLimiter>,
+    pub refresh_session: Arc<KeyedRateLimiter>,
+    pub reset_password: Arc<KeyedRateLimiter>,
+    pub oauth_par: Arc<KeyedRateLimiter>,
+    pub oauth_introspect: Arc<KeyedRateLimiter>,
+    pub app_password: Arc<KeyedRateLimiter>,
+    pub email_update: Arc<KeyedRateLimiter>,
 }

 impl Default for RateLimiters {
···
             )),
             oauth_token: Arc::new(RateLimiter::keyed(
                 Quota::per_minute(NonZeroU32::new(30).unwrap())
+            )),
+            oauth_authorize: Arc::new(RateLimiter::keyed(
+                Quota::per_minute(NonZeroU32::new(10).unwrap())
             )),
             password_reset: Arc::new(RateLimiter::keyed(
                 Quota::per_hour(NonZeroU32::new(5).unwrap())
             )),
             account_creation: Arc::new(RateLimiter::keyed(
                 Quota::per_hour(NonZeroU32::new(10).unwrap())
             )),
+            refresh_session: Arc::new(RateLimiter::keyed(
+                Quota::per_minute(NonZeroU32::new(60).unwrap())
+            )),
+            reset_password: Arc::new(RateLimiter::keyed(
+                Quota::per_minute(NonZeroU32::new(10).unwrap())
+            )),
+            oauth_par: Arc::new(RateLimiter::keyed(
+                Quota::per_minute(NonZeroU32::new(30).unwrap())
+            )),
+            oauth_introspect: Arc::new(RateLimiter::keyed(
+                Quota::per_minute(NonZeroU32::new(30).unwrap())
+            )),
+            app_password: Arc::new(RateLimiter::keyed(
+                Quota::per_minute(NonZeroU32::new(10).unwrap())
+            )),
+            email_update: Arc::new(RateLimiter::keyed(
+                Quota::per_hour(NonZeroU32::new(5).unwrap())
+            )),
         }
     }
···
         self
     }

+    pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
+        self.oauth_authorize = Arc::new(RateLimiter::keyed(
+            Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
+        ));
+        self
+    }
+
     pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
         self.password_reset = Arc::new(RateLimiter::keyed(
             Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
···
     }
 }

-fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
+pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
     if let Some(forwarded) = headers.get("x-forwarded-for") {
         if let Ok(value) = forwarded.to_str() {
             if let Some(first_ip) = value.split(',').next() {
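For context, a condensed sketch of the check the rate-limited handlers above repeat inline: the Valkey counter is consulted first, and the in-memory governor limiter is used as a second opinion before a 429 is returned. This helper is hypothetical (it is not in the commit), and it assumes `KeyedRateLimiter` is exported from `bspds::rate_limit` and keyed by `String`.

```rust
use bspds::cache::DistributedRateLimiter;
use bspds::rate_limit::KeyedRateLimiter;

// Hypothetical helper mirroring the inline pattern in the handlers:
// a request is rejected only when the distributed window is exhausted
// and the local governor limiter agrees.
async fn over_limit(
    distributed: &dyn DistributedRateLimiter,
    local: &KeyedRateLimiter,
    prefix: &str,
    client_ip: &str,
    limit: u32,
    window_ms: u64,
) -> bool {
    !distributed
        .check_rate_limit(&format!("{prefix}:{client_ip}"), limit, window_ms)
        .await
        && local.check_key(&client_ip.to_string()).is_err()
}
```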
+6
src/state.rs
···
+use crate::cache::{Cache, DistributedRateLimiter, create_cache};
 use crate::circuit_breaker::CircuitBreakers;
 use crate::config::AuthConfig;
 use crate::rate_limit::RateLimiters;
···
     pub firehose_tx: broadcast::Sender<SequencedEvent>,
     pub rate_limiters: Arc<RateLimiters>,
     pub circuit_breakers: Arc<CircuitBreakers>,
+    pub cache: Arc<dyn Cache>,
+    pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
 }

 impl AppState {
···
         let (firehose_tx, _) = broadcast::channel(1000);
         let rate_limiters = Arc::new(RateLimiters::new());
         let circuit_breakers = Arc::new(CircuitBreakers::new());
+        let (cache, distributed_rate_limiter) = create_cache().await;
         Self {
             db,
             block_store,
···
             firehose_tx,
             rate_limiters,
             circuit_breakers,
+            cache,
+            distributed_rate_limiter,
         }
     }
+64
tests/oauth_security.rs
···
     let introspect_body: Value = introspect_res.json().await.unwrap();
     assert_eq!(introspect_body["active"], false, "Revoked token should be inactive");
 }
+
+#[tokio::test]
+async fn test_security_oauth_authorize_rate_limiting() {
+    let url = base_url().await;
+    let http_client = no_redirect_client();
+
+    let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0);
+    let unique_ip = format!("10.{}.{}.{}", (ts >> 16) & 0xFF, (ts >> 8) & 0xFF, ts & 0xFF);
+
+    let redirect_uri = "https://example.com/rate-limit-callback";
+    let mock_client = setup_mock_client_metadata(redirect_uri).await;
+    let client_id = mock_client.uri();
+
+    let (_, code_challenge) = generate_pkce();
+
+    let client_for_par = client();
+    let par_body: Value = client_for_par
+        .post(format!("{}/oauth/par", url))
+        .form(&[
+            ("response_type", "code"),
+            ("client_id", &client_id),
+            ("redirect_uri", redirect_uri),
+            ("code_challenge", &code_challenge),
+            ("code_challenge_method", "S256"),
+        ])
+        .send()
+        .await
+        .unwrap()
+        .json()
+        .await
+        .unwrap();
+
+    let request_uri = par_body["request_uri"].as_str().unwrap();
+
+    let mut rate_limited_count = 0;
+    let mut other_count = 0;
+
+    for _ in 0..15 {
+        let res = http_client
+            .post(format!("{}/oauth/authorize", url))
+            .header("X-Forwarded-For", &unique_ip)
+            .form(&[
+                ("request_uri", request_uri),
+                ("username", "nonexistent_user"),
+                ("password", "wrong_password"),
+                ("remember_device", "false"),
+            ])
+            .send()
+            .await
+            .unwrap();
+
+        match res.status() {
+            StatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1,
+            _ => other_count += 1,
+        }
+    }
+
+    assert!(
+        rate_limited_count > 0,
+        "Expected at least one rate-limited response after 15 OAuth authorize attempts. Got {} other and {} rate limited.",
+        other_count,
+        rate_limited_count
+    );
+}
+228
tests/rate_limit.rs
···
+mod common;
+
+use common::{base_url, client};
+use reqwest::StatusCode;
+use serde_json::json;
+
+#[tokio::test]
+async fn test_login_rate_limiting() {
+    let client = client();
+    let url = format!("{}/xrpc/com.atproto.server.createSession", base_url().await);
+
+    let payload = json!({
+        "identifier": "nonexistent_user_for_rate_limit_test",
+        "password": "wrongpassword"
+    });
+
+    let mut rate_limited_count = 0;
+    let mut auth_failed_count = 0;
+
+    for _ in 0..15 {
+        let res = client
+            .post(&url)
+            .json(&payload)
+            .send()
+            .await
+            .expect("Request failed");
+
+        match res.status() {
+            StatusCode::TOO_MANY_REQUESTS => {
+                rate_limited_count += 1;
+            }
+            StatusCode::UNAUTHORIZED => {
+                auth_failed_count += 1;
+            }
+            status => {
+                panic!("Unexpected status: {}", status);
+            }
+        }
+    }
+
+    assert!(
+        rate_limited_count > 0,
+        "Expected at least one rate-limited response after 15 login attempts. Got {} auth failures and {} rate limits.",
+        auth_failed_count,
+        rate_limited_count
+    );
+}
+
+#[tokio::test]
+async fn test_password_reset_rate_limiting() {
+    let client = client();
+    let url = format!(
+        "{}/xrpc/com.atproto.server.requestPasswordReset",
+        base_url().await
+    );
+
+    let mut rate_limited_count = 0;
+    let mut success_count = 0;
+
+    for i in 0..8 {
+        let payload = json!({
+            "email": format!("ratelimit_test_{}@example.com", i)
+        });
+
+        let res = client
+            .post(&url)
+            .json(&payload)
+            .send()
+            .await
+            .expect("Request failed");
+
+        match res.status() {
+            StatusCode::TOO_MANY_REQUESTS => {
+                rate_limited_count += 1;
+            }
+            StatusCode::OK => {
+                success_count += 1;
+            }
+            status => {
+                panic!("Unexpected status: {} - {:?}", status, res.text().await);
+            }
+        }
+    }
+
+    assert!(
+        rate_limited_count > 0,
+        "Expected rate limiting after {} password reset requests. Got {} successes.",
+        success_count + rate_limited_count,
+        success_count
+    );
+}
+
+#[tokio::test]
+async fn test_account_creation_rate_limiting() {
+    let client = client();
+    let url = format!(
+        "{}/xrpc/com.atproto.server.createAccount",
+        base_url().await
+    );
+
+    let mut rate_limited_count = 0;
+    let mut other_count = 0;
+
+    for i in 0..15 {
+        let unique_id = uuid::Uuid::new_v4();
+        let payload = json!({
+            "handle": format!("ratelimit_{}_{}", i, unique_id),
+            "email": format!("ratelimit_{}_{}@example.com", i, unique_id),
+            "password": "testpassword123"
+        });
+
+        let res = client
+            .post(&url)
+            .json(&payload)
+            .send()
+            .await
+            .expect("Request failed");
+
+        match res.status() {
+            StatusCode::TOO_MANY_REQUESTS => {
+                rate_limited_count += 1;
+            }
+            _ => {
+                other_count += 1;
+            }
+        }
+    }
+
+    assert!(
+        rate_limited_count > 0,
+        "Expected rate limiting after account creation attempts. Got {} other responses and {} rate limits.",
+        other_count,
+        rate_limited_count
+    );
+}
+
+#[tokio::test]
+async fn test_valkey_connection() {
+    if std::env::var("VALKEY_URL").is_err() {
+        println!("VALKEY_URL not set, skipping Valkey connection test");
+        return;
+    }
+
+    let valkey_url = std::env::var("VALKEY_URL").unwrap();
+    let client = redis::Client::open(valkey_url.as_str()).expect("Failed to create Redis client");
+    let mut conn = client
+        .get_multiplexed_async_connection()
+        .await
+        .expect("Failed to connect to Valkey");
+
+    let pong: String = redis::cmd("PING")
+        .query_async(&mut conn)
+        .await
+        .expect("PING failed");
+    assert_eq!(pong, "PONG");
+
+    let _: () = redis::cmd("SET")
+        .arg("test_key")
+        .arg("test_value")
+        .arg("EX")
+        .arg(10)
+        .query_async(&mut conn)
+        .await
+        .expect("SET failed");
+
+    let value: String = redis::cmd("GET")
+        .arg("test_key")
+        .query_async(&mut conn)
+        .await
+        .expect("GET failed");
+    assert_eq!(value, "test_value");
+
+    let _: () = redis::cmd("DEL")
+        .arg("test_key")
+        .query_async(&mut conn)
+        .await
+        .expect("DEL failed");
+}
+
+#[tokio::test]
+async fn test_distributed_rate_limiter_directly() {
+    if std::env::var("VALKEY_URL").is_err() {
+        println!("VALKEY_URL not set, skipping distributed rate limiter test");
+        return;
+    }
+
+    use bspds::cache::{DistributedRateLimiter, RedisRateLimiter};
+
+    let valkey_url = std::env::var("VALKEY_URL").unwrap();
+    let client = redis::Client::open(valkey_url.as_str()).expect("Failed to create Redis client");
+    let conn = client
+        .get_connection_manager()
+        .await
+        .expect("Failed to get connection manager");
+
+    let rate_limiter = RedisRateLimiter::new(conn);
+
+    let test_key = format!("test_rate_limit_{}", uuid::Uuid::new_v4());
+    let limit = 5;
+    let window_ms = 60_000;
+
+    for i in 0..limit {
+        let allowed = rate_limiter
+            .check_rate_limit(&test_key, limit, window_ms)
+            .await;
+        assert!(
+            allowed,
+            "Request {} should have been allowed (limit: {})",
+            i + 1,
+            limit
+        );
+    }
+
+    let allowed = rate_limiter
+        .check_rate_limit(&test_key, limit, window_ms)
+        .await;
+    assert!(
+        !allowed,
+        "Request {} should have been rate limited (limit: {})",
+        limit + 1,
+        limit
+    );
+
+    let allowed = rate_limiter
+        .check_rate_limit(&test_key, limit, window_ms)
+        .await;
+    assert!(!allowed, "Subsequent request should also be rate limited");
+}