+34
.sqlx/query-08c08b0644d79d5de72f3500dd7dbb8827af340e3c04fec9a5c28aeff46e0c97.json
+34
.sqlx/query-08c08b0644d79d5de72f3500dd7dbb8827af340e3c04fec9a5c28aeff46e0c97.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT id, password_hash, handle FROM users WHERE did = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "password_hash",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "handle",
19
+
"type_info": "Text"
20
+
}
21
+
],
22
+
"parameters": {
23
+
"Left": [
24
+
"Text"
25
+
]
26
+
},
27
+
"nullable": [
28
+
false,
29
+
false,
30
+
false
31
+
]
32
+
},
33
+
"hash": "08c08b0644d79d5de72f3500dd7dbb8827af340e3c04fec9a5c28aeff46e0c97"
34
+
}
-28
.sqlx/query-76c6ef1d5395105a0cdedb27ca321c9e3eae1ce87c223b706ed81ebf973875f3.json
-28
.sqlx/query-76c6ef1d5395105a0cdedb27ca321c9e3eae1ce87c223b706ed81ebf973875f3.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT id, password_hash FROM users WHERE did = $1",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "id",
9
-
"type_info": "Uuid"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "password_hash",
14
-
"type_info": "Text"
15
-
}
16
-
],
17
-
"parameters": {
18
-
"Left": [
19
-
"Text"
20
-
]
21
-
},
22
-
"nullable": [
23
-
false,
24
-
false
25
-
]
26
-
},
27
-
"hash": "76c6ef1d5395105a0cdedb27ca321c9e3eae1ce87c223b706ed81ebf973875f3"
28
-
}
+22
.sqlx/query-e223898d53602c1c8b23eb08a4b96cf20ac349d1fa4e91334b225d3069209dcf.json
+22
.sqlx/query-e223898d53602c1c8b23eb08a4b96cf20ac349d1fa4e91334b225d3069209dcf.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT handle FROM users WHERE id = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "handle",
9
+
"type_info": "Text"
10
+
}
11
+
],
12
+
"parameters": {
13
+
"Left": [
14
+
"Uuid"
15
+
]
16
+
},
17
+
"nullable": [
18
+
false
19
+
]
20
+
},
21
+
"hash": "e223898d53602c1c8b23eb08a4b96cf20ac349d1fa4e91334b225d3069209dcf"
22
+
}
+73
-2
Cargo.lock
+73
-2
Cargo.lock
···
99
99
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
100
100
101
101
[[package]]
102
+
name = "arc-swap"
103
+
version = "1.7.1"
104
+
source = "registry+https://github.com/rust-lang/crates.io-index"
105
+
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
106
+
107
+
[[package]]
102
108
name = "assert-json-diff"
103
109
version = "2.0.2"
104
110
source = "registry+https://github.com/rust-lang/crates.io-index"
···
689
695
]
690
696
691
697
[[package]]
698
+
name = "backon"
699
+
version = "1.6.0"
700
+
source = "registry+https://github.com/rust-lang/crates.io-index"
701
+
checksum = "cffb0e931875b666fc4fcb20fee52e9bbd1ef836fd9e9e04ec21555f9f85f7ef"
702
+
dependencies = [
703
+
"fastrand",
704
+
]
705
+
706
+
[[package]]
692
707
name = "base-x"
693
708
version = "0.2.11"
694
709
source = "registry+https://github.com/rust-lang/crates.io-index"
···
931
946
"p256 0.13.2",
932
947
"p384",
933
948
"rand 0.8.5",
949
+
"redis",
934
950
"reqwest",
935
951
"serde",
936
952
"serde_bytes",
···
1176
1192
version = "1.1.0"
1177
1193
source = "registry+https://github.com/rust-lang/crates.io-index"
1178
1194
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
1195
+
1196
+
[[package]]
1197
+
name = "combine"
1198
+
version = "4.6.7"
1199
+
source = "registry+https://github.com/rust-lang/crates.io-index"
1200
+
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
1201
+
dependencies = [
1202
+
"bytes",
1203
+
"futures-core",
1204
+
"memchr",
1205
+
"pin-project-lite",
1206
+
"tokio",
1207
+
"tokio-util",
1208
+
]
1179
1209
1180
1210
[[package]]
1181
1211
name = "compression-codecs"
···
2973
3003
2974
3004
[[package]]
2975
3005
name = "itertools"
3006
+
version = "0.13.0"
3007
+
source = "registry+https://github.com/rust-lang/crates.io-index"
3008
+
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
3009
+
dependencies = [
3010
+
"either",
3011
+
]
3012
+
3013
+
[[package]]
3014
+
name = "itertools"
2976
3015
version = "0.14.0"
2977
3016
source = "registry+https://github.com/rust-lang/crates.io-index"
2978
3017
checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
···
4241
4280
checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425"
4242
4281
dependencies = [
4243
4282
"anyhow",
4244
-
"itertools",
4283
+
"itertools 0.14.0",
4245
4284
"proc-macro2",
4246
4285
"quote",
4247
4286
"syn 2.0.111",
···
4442
4481
]
4443
4482
4444
4483
[[package]]
4484
+
name = "redis"
4485
+
version = "0.27.6"
4486
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4487
+
checksum = "09d8f99a4090c89cc489a94833c901ead69bfbf3877b4867d5482e321ee875bc"
4488
+
dependencies = [
4489
+
"arc-swap",
4490
+
"async-trait",
4491
+
"backon",
4492
+
"bytes",
4493
+
"combine",
4494
+
"futures",
4495
+
"futures-util",
4496
+
"itertools 0.13.0",
4497
+
"itoa",
4498
+
"num-bigint",
4499
+
"percent-encoding",
4500
+
"pin-project-lite",
4501
+
"ryu",
4502
+
"sha1_smol",
4503
+
"socket2 0.5.10",
4504
+
"tokio",
4505
+
"tokio-util",
4506
+
"url",
4507
+
]
4508
+
4509
+
[[package]]
4445
4510
name = "redox_syscall"
4446
4511
version = "0.5.18"
4447
4512
source = "registry+https://github.com/rust-lang/crates.io-index"
···
5055
5120
]
5056
5121
5057
5122
[[package]]
5123
+
name = "sha1_smol"
5124
+
version = "1.0.1"
5125
+
source = "registry+https://github.com/rust-lang/crates.io-index"
5126
+
checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d"
5127
+
5128
+
[[package]]
5058
5129
name = "sha2"
5059
5130
version = "0.10.9"
5060
5131
source = "registry+https://github.com/rust-lang/crates.io-index"
···
5646
5717
"etcetera 0.11.0",
5647
5718
"ferroid",
5648
5719
"futures",
5649
-
"itertools",
5720
+
"itertools 0.14.0",
5650
5721
"log",
5651
5722
"memchr",
5652
5723
"parse-display",
+1
Cargo.toml
+1
Cargo.toml
···
49
49
uuid = { version = "1.19.0", features = ["v4", "fast-rng"] }
50
50
iroh-car = "0.5.1"
51
51
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] }
52
+
redis = { version = "0.27", features = ["tokio-comp", "connection-manager"] }
52
53
53
54
[features]
54
55
external-infra = []
+17
-7
TODO.md
+17
-7
TODO.md
···
198
198
- [x] Implement Atomic Repo Transactions.
199
199
- [x] Ensure `blocks` write, `repo_root` update, `records` index update, and `sequencer` event are committed in a single transaction.
200
200
- [x] Implement concurrency control (row-level locking via FOR UPDATE).
201
-
- [ ] DID Cache
202
-
- [ ] Implement caching layer for DID resolution (Redis or in-memory).
203
-
- [ ] Handle cache invalidation/expiry.
201
+
- [x] DID Cache
202
+
- [x] Implement caching layer for DID resolution (valkey).
203
+
- [x] Handle cache invalidation/expiry.
204
+
- [x] Graceful fallback to no-cache when Valkey unavailable.
204
205
- [x] Crawlers Service
205
206
- [x] Implement `Crawlers` service (debounce notifications to relays).
206
207
- [x] 20-minute notification debounce.
···
229
230
- [x] Per-IP rate limiting on OAuth token endpoint (30/min).
230
231
- [x] Per-IP rate limiting on password reset (5/hour).
231
232
- [x] Per-IP rate limiting on account creation (10/hour).
233
+
- [x] Per-IP rate limiting on refreshSession (60/min).
234
+
- [x] Per-IP rate limiting on OAuth authorize POST (10/min).
235
+
- [x] Per-IP rate limiting on OAuth 2FA POST (10/min).
236
+
- [x] Per-IP rate limiting on OAuth PAR (30/min).
237
+
- [x] Per-IP rate limiting on OAuth revoke/introspect (30/min).
238
+
- [x] Per-IP rate limiting on createAppPassword (10/min).
239
+
- [x] Per-IP rate limiting on email endpoints (5/hour).
240
+
- [x] Distributed rate limiting via Valkey/Redis (with in-memory fallback).
232
241
- [x] Circuit Breakers
233
242
- [x] PLC directory circuit breaker (5 failures → open, 60s timeout).
234
243
- [x] Relay notification circuit breaker (10 failures → open, 30s timeout).
···
237
246
- [x] Signal command injection prevention (phone number validation).
238
247
- [x] Constant-time signature comparison.
239
248
- [x] SSRF protection for outbound requests.
249
+
- [x] Timing attack protection (dummy bcrypt on user-not-found prevents account enumeration).
240
250
241
251
## Lewis' fabulous mini-list of remaining TODOs
242
-
- [ ] The OAuth authorize POST endpoint has no rate limiting, allowing password brute-forcing. Fix this and audit all oauth and 2fa surface again.
243
-
- [ ] DID resolution caching (valkey).
244
-
- [ ] Record schema validation (generic validation framework).
245
-
- [ ] Fix any remaining TODOs in the code.
252
+
- [x] The OAuth authorize POST endpoint has no rate limiting, allowing password brute-forcing. Fix this and audit all oauth and 2fa surface again.
253
+
- [x] DID resolution caching (valkey).
254
+
- [x] Record schema validation (generic validation framework).
255
+
- [x] Fix any remaining TODOs in the code.
246
256
247
257
## Future: Web Management UI
248
258
A single-page web app for account management. The frontend (JS framework) calls existing ATProto XRPC endpoints - no server-side rendering or bespoke HTML form handlers.
+10
docker-compose.yaml
+10
docker-compose.yaml
···
11
11
environment:
12
12
DATABASE_URL: postgres://postgres:postgres@db:5432/pds
13
13
S3_ENDPOINT: http://objsto:9000
14
+
VALKEY_URL: redis://cache:6379
14
15
depends_on:
15
16
- db
16
17
- objsto
18
+
- cache
17
19
18
20
db:
19
21
image: postgres:latest
···
38
40
- minio_data:/data
39
41
command: server /data --console-address ":9001"
40
42
43
+
cache:
44
+
image: valkey/valkey:8-alpine
45
+
ports:
46
+
- "6379:6379"
47
+
volumes:
48
+
- valkey_data:/data
49
+
41
50
volumes:
42
51
postgres_data:
43
52
minio_data:
53
+
valkey_data:
+20
-3
scripts/test-infra.sh
+20
-3
scripts/test-infra.sh
···
38
38
rm -f "$INFRA_FILE"
39
39
fi
40
40
41
-
$CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" 2>/dev/null || true
41
+
$CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" "${CONTAINER_PREFIX}-valkey" 2>/dev/null || true
42
42
43
43
echo "Starting PostgreSQL..."
44
44
$CONTAINER_CMD run -d \
···
59
59
--label bspds_test=true \
60
60
minio/minio:latest server /data >/dev/null
61
61
62
+
echo "Starting Valkey..."
63
+
$CONTAINER_CMD run -d \
64
+
--name "${CONTAINER_PREFIX}-valkey" \
65
+
-P \
66
+
--label bspds_test=true \
67
+
valkey/valkey:8-alpine >/dev/null
68
+
62
69
echo "Waiting for services to be ready..."
63
70
sleep 2
64
71
65
72
PG_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-postgres" 5432 | head -1 | cut -d: -f2)
66
73
MINIO_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-minio" 9000 | head -1 | cut -d: -f2)
74
+
VALKEY_PORT=$($CONTAINER_CMD port "${CONTAINER_PREFIX}-valkey" 6379 | head -1 | cut -d: -f2)
67
75
68
76
for i in {1..30}; do
69
77
if $CONTAINER_CMD exec "${CONTAINER_PREFIX}-postgres" pg_isready -U postgres >/dev/null 2>&1; then
···
81
89
sleep 1
82
90
done
83
91
92
+
for i in {1..30}; do
93
+
if $CONTAINER_CMD exec "${CONTAINER_PREFIX}-valkey" valkey-cli ping 2>/dev/null | grep -q PONG; then
94
+
break
95
+
fi
96
+
echo "Waiting for Valkey... ($i/30)"
97
+
sleep 1
98
+
done
99
+
84
100
echo "Creating MinIO bucket..."
85
101
$CONTAINER_CMD run --rm --network host \
86
102
-e MC_HOST_minio="http://minioadmin:minioadmin@127.0.0.1:${MINIO_PORT}" \
···
94
110
export AWS_ACCESS_KEY_ID="minioadmin"
95
111
export AWS_SECRET_ACCESS_KEY="minioadmin"
96
112
export AWS_REGION="us-east-1"
113
+
export VALKEY_URL="redis://127.0.0.1:${VALKEY_PORT}"
97
114
export BSPDS_TEST_INFRA_READY="1"
98
115
export BSPDS_ALLOW_INSECURE_SECRETS="1"
99
116
export SKIP_IMPORT_VERIFICATION="true"
···
108
125
109
126
stop_infra() {
110
127
echo "Stopping test infrastructure..."
111
-
$CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" 2>/dev/null || true
128
+
$CONTAINER_CMD rm -f "${CONTAINER_PREFIX}-postgres" "${CONTAINER_PREFIX}-minio" "${CONTAINER_PREFIX}-valkey" 2>/dev/null || true
112
129
rm -f "$INFRA_FILE"
113
130
echo "Infrastructure stopped."
114
131
}
···
157
174
echo "Usage: $0 {start|stop|restart|status|env}"
158
175
echo ""
159
176
echo "Commands:"
160
-
echo " start - Start test infrastructure (Postgres, MinIO)"
177
+
echo " start - Start test infrastructure (Postgres, MinIO, Valkey)"
161
178
echo " stop - Stop and remove test containers"
162
179
echo " restart - Stop then start infrastructure"
163
180
echo " status - Show infrastructure status"
+5
-3
src/api/admin/account/delete.rs
+5
-3
src/api/admin/account/delete.rs
···
37
37
.into_response();
38
38
}
39
39
40
-
let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
40
+
let user = sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
41
41
.fetch_optional(&state.db)
42
42
.await;
43
43
44
-
let user_id = match user {
45
-
Ok(Some(row)) => row.id,
44
+
let (user_id, handle) = match user {
45
+
Ok(Some(row)) => (row.id, row.handle),
46
46
Ok(None) => {
47
47
return (
48
48
StatusCode::NOT_FOUND,
···
185
185
)
186
186
.into_response();
187
187
}
188
+
189
+
let _ = state.cache.delete(&format!("handle:{}", handle)).await;
188
190
189
191
(StatusCode::OK, Json(json!({}))).into_response()
190
192
}
+10
src/api/admin/account/update.rs
+10
src/api/admin/account/update.rs
···
108
108
.into_response();
109
109
}
110
110
111
+
let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
112
+
.fetch_optional(&state.db)
113
+
.await
114
+
.ok()
115
+
.flatten();
116
+
111
117
let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did)
112
118
.fetch_optional(&state.db)
113
119
.await;
···
133
139
)
134
140
.into_response();
135
141
}
142
+
if let Some(old) = old_handle {
143
+
let _ = state.cache.delete(&format!("handle:{}", old)).await;
144
+
}
145
+
let _ = state.cache.delete(&format!("handle:{}", handle)).await;
136
146
(StatusCode::OK, Json(json!({}))).into_response()
137
147
}
138
148
Err(e) => {
+7
src/api/admin/status.rs
+7
src/api/admin/status.rs
···
305
305
.into_response();
306
306
}
307
307
308
+
if let Ok(Some(handle)) = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
309
+
.fetch_optional(&state.db)
310
+
.await
311
+
{
312
+
let _ = state.cache.delete(&format!("handle:{}", handle)).await;
313
+
}
314
+
308
315
return (
309
316
StatusCode::OK,
310
317
Json(json!({
+19
-1
src/api/identity/did.rs
+19
-1
src/api/identity/did.rs
···
33
33
.into_response();
34
34
}
35
35
36
+
let cache_key = format!("handle:{}", handle);
37
+
if let Some(did) = state.cache.get(&cache_key).await {
38
+
return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
39
+
}
40
+
36
41
let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle)
37
42
.fetch_optional(&state.db)
38
43
.await;
39
44
40
45
match user {
41
46
Ok(Some(row)) => {
47
+
let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await;
42
48
(StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
43
49
}
44
50
Ok(None) => (
···
406
412
.into_response();
407
413
}
408
414
415
+
let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
416
+
.fetch_optional(&state.db)
417
+
.await
418
+
.ok()
419
+
.flatten();
420
+
409
421
let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
410
422
.fetch_optional(&state.db)
411
423
.await;
···
423
435
.await;
424
436
425
437
match result {
426
-
Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
438
+
Ok(_) => {
439
+
if let Some(old) = old_handle {
440
+
let _ = state.cache.delete(&format!("handle:{}", old)).await;
441
+
}
442
+
let _ = state.cache.delete(&format!("handle:{}", new_handle)).await;
443
+
(StatusCode::OK, Json(json!({}))).into_response()
444
+
}
427
445
Err(e) => {
428
446
error!("DB error updating handle: {:?}", e);
429
447
(
+11
src/api/repo/record/batch.rs
+11
src/api/repo/record/batch.rs
···
1
+
use super::validation::validate_record;
1
2
use crate::api::repo::record::utils::{commit_and_log, RecordOp};
2
3
use crate::repo::tracking::TrackingBlockStore;
3
4
use crate::state::AppState;
···
211
212
rkey,
212
213
value,
213
214
} => {
215
+
if input.validate.unwrap_or(true) {
216
+
if let Err(err_response) = validate_record(value, collection) {
217
+
return err_response;
218
+
}
219
+
}
214
220
let rkey = rkey
215
221
.clone()
216
222
.unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
···
249
255
rkey,
250
256
value,
251
257
} => {
258
+
if input.validate.unwrap_or(true) {
259
+
if let Err(err_response) = validate_record(value, collection) {
260
+
return err_response;
261
+
}
262
+
}
252
263
let mut record_bytes = Vec::new();
253
264
if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
254
265
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
+1
src/api/repo/record/mod.rs
+1
src/api/repo/record/mod.rs
+38
src/api/repo/record/validation.rs
+38
src/api/repo/record/validation.rs
···
1
+
use crate::validation::{RecordValidator, ValidationError};
2
+
use axum::{
3
+
http::StatusCode,
4
+
response::{IntoResponse, Response},
5
+
Json,
6
+
};
7
+
use serde_json::json;
8
+
9
+
pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Response> {
10
+
let validator = RecordValidator::new();
11
+
match validator.validate(record, collection) {
12
+
Ok(_) => Ok(()),
13
+
Err(ValidationError::MissingType) => Err((
14
+
StatusCode::BAD_REQUEST,
15
+
Json(json!({"error": "InvalidRecord", "message": "Record must have a $type field"})),
16
+
).into_response()),
17
+
Err(ValidationError::TypeMismatch { expected, actual }) => Err((
18
+
StatusCode::BAD_REQUEST,
19
+
Json(json!({"error": "InvalidRecord", "message": format!("Record $type '{}' does not match collection '{}'", actual, expected)})),
20
+
).into_response()),
21
+
Err(ValidationError::MissingField(field)) => Err((
22
+
StatusCode::BAD_REQUEST,
23
+
Json(json!({"error": "InvalidRecord", "message": format!("Missing required field: {}", field)})),
24
+
).into_response()),
25
+
Err(ValidationError::InvalidField { path, message }) => Err((
26
+
StatusCode::BAD_REQUEST,
27
+
Json(json!({"error": "InvalidRecord", "message": format!("Invalid field '{}': {}", path, message)})),
28
+
).into_response()),
29
+
Err(ValidationError::InvalidDatetime { path }) => Err((
30
+
StatusCode::BAD_REQUEST,
31
+
Json(json!({"error": "InvalidRecord", "message": format!("Invalid datetime format at '{}'", path)})),
32
+
).into_response()),
33
+
Err(e) => Err((
34
+
StatusCode::BAD_REQUEST,
35
+
Json(json!({"error": "InvalidRecord", "message": e.to_string()})),
36
+
).into_response()),
37
+
}
38
+
}
+5
-16
src/api/repo/record/write.rs
+5
-16
src/api/repo/record/write.rs
···
1
+
use super::validation::validate_record;
1
2
use crate::api::repo::record::utils::{commit_and_log, RecordOp};
2
3
use crate::repo::tracking::TrackingBlockStore;
3
4
use crate::state::AppState;
···
156
157
};
157
158
158
159
if input.validate.unwrap_or(true) {
159
-
if input.collection == "app.bsky.feed.post" {
160
-
if input.record.get("text").is_none() || input.record.get("createdAt").is_none() {
161
-
return (
162
-
StatusCode::BAD_REQUEST,
163
-
Json(json!({"error": "InvalidRecord", "message": "Record validation failed"})),
164
-
)
165
-
.into_response();
166
-
}
160
+
if let Err(err_response) = validate_record(&input.record, &input.collection) {
161
+
return err_response;
167
162
}
168
163
}
169
164
···
263
258
let key = format!("{}/{}", collection_nsid, input.rkey);
264
259
265
260
if input.validate.unwrap_or(true) {
266
-
if input.collection == "app.bsky.feed.post" {
267
-
if input.record.get("text").is_none() || input.record.get("createdAt").is_none() {
268
-
return (
269
-
StatusCode::BAD_REQUEST,
270
-
Json(json!({"error": "InvalidRecord", "message": "Record validation failed"})),
271
-
)
272
-
.into_response();
273
-
}
261
+
if let Err(err_response) = validate_record(&input.record, &input.collection) {
262
+
return err_response;
274
263
}
275
264
}
276
265
+28
-5
src/api/server/account_status.rs
+28
-5
src/api/server/account_status.rs
···
123
123
Err(e) => return ApiError::from(e).into_response(),
124
124
};
125
125
126
+
let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
127
+
.fetch_optional(&state.db)
128
+
.await
129
+
.ok()
130
+
.flatten();
131
+
126
132
let result = sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did)
127
133
.execute(&state.db)
128
134
.await;
129
135
130
136
match result {
131
-
Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
137
+
Ok(_) => {
138
+
if let Some(h) = handle {
139
+
let _ = state.cache.delete(&format!("handle:{}", h)).await;
140
+
}
141
+
(StatusCode::OK, Json(json!({}))).into_response()
142
+
}
132
143
Err(e) => {
133
144
error!("DB error activating account: {:?}", e);
134
145
(
···
163
174
Err(e) => return ApiError::from(e).into_response(),
164
175
};
165
176
177
+
let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
178
+
.fetch_optional(&state.db)
179
+
.await
180
+
.ok()
181
+
.flatten();
182
+
166
183
let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did)
167
184
.execute(&state.db)
168
185
.await;
169
186
170
187
match result {
171
-
Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
188
+
Ok(_) => {
189
+
if let Some(h) = handle {
190
+
let _ = state.cache.delete(&format!("handle:{}", h)).await;
191
+
}
192
+
(StatusCode::OK, Json(json!({}))).into_response()
193
+
}
172
194
Err(e) => {
173
195
error!("DB error deactivating account: {:?}", e);
174
196
(
···
283
305
}
284
306
285
307
let user = sqlx::query!(
286
-
"SELECT id, password_hash FROM users WHERE did = $1",
308
+
"SELECT id, password_hash, handle FROM users WHERE did = $1",
287
309
did
288
310
)
289
311
.fetch_optional(&state.db)
290
312
.await;
291
313
292
-
let (user_id, password_hash) = match user {
293
-
Ok(Some(row)) => (row.id, row.password_hash),
314
+
let (user_id, password_hash, handle) = match user {
315
+
Ok(Some(row)) => (row.id, row.password_hash, row.handle),
294
316
Ok(None) => {
295
317
return (
296
318
StatusCode::BAD_REQUEST,
···
437
459
)
438
460
.into_response();
439
461
}
462
+
let _ = state.cache.delete(&format!("handle:{}", handle)).await;
440
463
info!("Account {} deleted successfully", did);
441
464
(StatusCode::OK, Json(json!({}))).into_response()
442
465
}
+21
-1
src/api/server/app_password.rs
+21
-1
src/api/server/app_password.rs
···
5
5
use axum::{
6
6
Json,
7
7
extract::State,
8
+
http::HeaderMap,
8
9
response::{IntoResponse, Response},
9
10
};
10
11
use serde::{Deserialize, Serialize};
11
12
use serde_json::json;
12
-
use tracing::error;
13
+
use tracing::{error, warn};
13
14
14
15
#[derive(Serialize)]
15
16
#[serde(rename_all = "camelCase")]
···
76
77
77
78
pub async fn create_app_password(
78
79
State(state): State<AppState>,
80
+
headers: HeaderMap,
79
81
BearerAuth(auth_user): BearerAuth,
80
82
Json(input): Json<CreateAppPasswordInput>,
81
83
) -> Response {
84
+
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
85
+
if !state.distributed_rate_limiter.check_rate_limit(
86
+
&format!("app_password:{}", client_ip),
87
+
10,
88
+
60_000,
89
+
).await {
90
+
if state.rate_limiters.app_password.check_key(&client_ip).is_err() {
91
+
warn!(ip = %client_ip, "App password creation rate limit exceeded");
92
+
return (
93
+
axum::http::StatusCode::TOO_MANY_REQUESTS,
94
+
Json(json!({
95
+
"error": "RateLimitExceeded",
96
+
"message": "Too many requests. Please try again later."
97
+
})),
98
+
).into_response();
99
+
}
100
+
}
101
+
82
102
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
83
103
Ok(id) => id,
84
104
Err(e) => return ApiError::from(e).into_response(),
+36
src/api/server/email.rs
+36
src/api/server/email.rs
···
26
26
headers: axum::http::HeaderMap,
27
27
Json(input): Json<RequestEmailUpdateInput>,
28
28
) -> Response {
29
+
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
30
+
if !state.distributed_rate_limiter.check_rate_limit(
31
+
&format!("email_update:{}", client_ip),
32
+
5,
33
+
3_600_000,
34
+
).await {
35
+
if state.rate_limiters.email_update.check_key(&client_ip).is_err() {
36
+
warn!(ip = %client_ip, "Email update rate limit exceeded");
37
+
return (
38
+
StatusCode::TOO_MANY_REQUESTS,
39
+
Json(json!({
40
+
"error": "RateLimitExceeded",
41
+
"message": "Too many requests. Please try again later."
42
+
})),
43
+
).into_response();
44
+
}
45
+
}
46
+
29
47
let token = match crate::auth::extract_bearer_token_from_header(
30
48
headers.get("Authorization").and_then(|h| h.to_str().ok())
31
49
) {
···
135
153
headers: axum::http::HeaderMap,
136
154
Json(input): Json<ConfirmEmailInput>,
137
155
) -> Response {
156
+
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
157
+
if !state.distributed_rate_limiter.check_rate_limit(
158
+
&format!("confirm_email:{}", client_ip),
159
+
10,
160
+
60_000,
161
+
).await {
162
+
if state.rate_limiters.app_password.check_key(&client_ip).is_err() {
163
+
warn!(ip = %client_ip, "Confirm email rate limit exceeded");
164
+
return (
165
+
StatusCode::TOO_MANY_REQUESTS,
166
+
Json(json!({
167
+
"error": "RateLimitExceeded",
168
+
"message": "Too many requests. Please try again later."
169
+
})),
170
+
).into_response();
171
+
}
172
+
}
173
+
138
174
let token = match crate::auth::extract_bearer_token_from_header(
139
175
headers.get("Authorization").and_then(|h| h.to_str().ok())
140
176
) {
+19
src/api/server/password.rs
+19
src/api/server/password.rs
···
124
124
125
125
pub async fn reset_password(
126
126
State(state): State<AppState>,
127
+
headers: HeaderMap,
127
128
Json(input): Json<ResetPasswordInput>,
128
129
) -> Response {
130
+
let client_ip = extract_client_ip(&headers);
131
+
if !state.distributed_rate_limiter.check_rate_limit(
132
+
&format!("reset_password:{}", client_ip),
133
+
10,
134
+
60_000,
135
+
).await {
136
+
if state.rate_limiters.reset_password.check_key(&client_ip).is_err() {
137
+
warn!(ip = %client_ip, "Reset password rate limit exceeded");
138
+
return (
139
+
StatusCode::TOO_MANY_REQUESTS,
140
+
Json(json!({
141
+
"error": "RateLimitExceeded",
142
+
"message": "Too many requests. Please try again later."
143
+
})),
144
+
).into_response();
145
+
}
146
+
}
147
+
129
148
let token = input.token.trim();
130
149
let password = &input.password;
131
150
+19
src/api/server/session.rs
+19
src/api/server/session.rs
···
72
72
{
73
73
Ok(Some(row)) => row,
74
74
Ok(None) => {
75
+
let _ = verify(&input.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK");
75
76
warn!("User not found for login attempt");
76
77
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response();
77
78
}
···
196
197
State(state): State<AppState>,
197
198
headers: axum::http::HeaderMap,
198
199
) -> Response {
200
+
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
201
+
if !state.distributed_rate_limiter.check_rate_limit(
202
+
&format!("refresh_session:{}", client_ip),
203
+
60,
204
+
60_000,
205
+
).await {
206
+
if state.rate_limiters.refresh_session.check_key(&client_ip).is_err() {
207
+
tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded");
208
+
return (
209
+
axum::http::StatusCode::TOO_MANY_REQUESTS,
210
+
axum::Json(serde_json::json!({
211
+
"error": "RateLimitExceeded",
212
+
"message": "Too many requests. Please try again later."
213
+
})),
214
+
).into_response();
215
+
}
216
+
}
217
+
199
218
let refresh_token = match crate::auth::extract_bearer_token_from_header(
200
219
headers.get("Authorization").and_then(|h| h.to_str().ok())
201
220
) {
+207
src/cache/mod.rs
+207
src/cache/mod.rs
···
1
+
use async_trait::async_trait;
2
+
use std::sync::Arc;
3
+
use std::time::Duration;
4
+
5
+
#[derive(Debug, thiserror::Error)]
6
+
pub enum CacheError {
7
+
#[error("Cache connection error: {0}")]
8
+
Connection(String),
9
+
#[error("Serialization error: {0}")]
10
+
Serialization(String),
11
+
}
12
+
13
+
#[async_trait]
14
+
pub trait Cache: Send + Sync {
15
+
async fn get(&self, key: &str) -> Option<String>;
16
+
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError>;
17
+
async fn delete(&self, key: &str) -> Result<(), CacheError>;
18
+
}
19
+
20
+
#[derive(Clone)]
21
+
pub struct ValkeyCache {
22
+
conn: redis::aio::ConnectionManager,
23
+
}
24
+
25
+
impl ValkeyCache {
26
+
pub async fn new(url: &str) -> Result<Self, CacheError> {
27
+
let client = redis::Client::open(url)
28
+
.map_err(|e| CacheError::Connection(e.to_string()))?;
29
+
let manager = client
30
+
.get_connection_manager()
31
+
.await
32
+
.map_err(|e| CacheError::Connection(e.to_string()))?;
33
+
Ok(Self { conn: manager })
34
+
}
35
+
36
+
pub fn connection(&self) -> redis::aio::ConnectionManager {
37
+
self.conn.clone()
38
+
}
39
+
}
40
+
41
+
#[async_trait]
42
+
impl Cache for ValkeyCache {
43
+
async fn get(&self, key: &str) -> Option<String> {
44
+
let mut conn = self.conn.clone();
45
+
redis::cmd("GET")
46
+
.arg(key)
47
+
.query_async::<Option<String>>(&mut conn)
48
+
.await
49
+
.ok()
50
+
.flatten()
51
+
}
52
+
53
+
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
54
+
let mut conn = self.conn.clone();
55
+
redis::cmd("SET")
56
+
.arg(key)
57
+
.arg(value)
58
+
.arg("EX")
59
+
.arg(ttl.as_secs() as i64)
60
+
.query_async::<()>(&mut conn)
61
+
.await
62
+
.map_err(|e| CacheError::Connection(e.to_string()))
63
+
}
64
+
65
+
async fn delete(&self, key: &str) -> Result<(), CacheError> {
66
+
let mut conn = self.conn.clone();
67
+
redis::cmd("DEL")
68
+
.arg(key)
69
+
.query_async::<()>(&mut conn)
70
+
.await
71
+
.map_err(|e| CacheError::Connection(e.to_string()))
72
+
}
73
+
}
74
+
75
+
pub struct NoOpCache;
76
+
77
+
#[async_trait]
78
+
impl Cache for NoOpCache {
79
+
async fn get(&self, _key: &str) -> Option<String> {
80
+
None
81
+
}
82
+
83
+
async fn set(&self, _key: &str, _value: &str, _ttl: Duration) -> Result<(), CacheError> {
84
+
Ok(())
85
+
}
86
+
87
+
async fn delete(&self, _key: &str) -> Result<(), CacheError> {
88
+
Ok(())
89
+
}
90
+
}
91
+
92
+
#[async_trait]
93
+
pub trait DistributedRateLimiter: Send + Sync {
94
+
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool;
95
+
}
96
+
97
+
#[derive(Clone)]
98
+
pub struct RedisRateLimiter {
99
+
conn: redis::aio::ConnectionManager,
100
+
}
101
+
102
+
impl RedisRateLimiter {
103
+
pub fn new(conn: redis::aio::ConnectionManager) -> Self {
104
+
Self { conn }
105
+
}
106
+
}
107
+
108
+
#[async_trait]
109
+
impl DistributedRateLimiter for RedisRateLimiter {
110
+
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
111
+
let mut conn = self.conn.clone();
112
+
let full_key = format!("rl:{}", key);
113
+
let window_secs = ((window_ms + 999) / 1000).max(1) as i64;
114
+
115
+
let count: Result<i64, _> = redis::cmd("INCR")
116
+
.arg(&full_key)
117
+
.query_async(&mut conn)
118
+
.await;
119
+
120
+
let count = match count {
121
+
Ok(c) => c,
122
+
Err(e) => {
123
+
tracing::warn!("Redis rate limit INCR failed: {}. Allowing request.", e);
124
+
return true;
125
+
}
126
+
};
127
+
128
+
if count == 1 {
129
+
let _: Result<bool, redis::RedisError> = redis::cmd("EXPIRE")
130
+
.arg(&full_key)
131
+
.arg(window_secs)
132
+
.query_async(&mut conn)
133
+
.await;
134
+
}
135
+
136
+
count <= limit as i64
137
+
}
138
+
}
139
+
140
+
pub struct NoOpRateLimiter;
141
+
142
+
#[async_trait]
143
+
impl DistributedRateLimiter for NoOpRateLimiter {
144
+
async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
145
+
true
146
+
}
147
+
}
148
+
149
+
pub enum CacheBackend {
150
+
Valkey(ValkeyCache),
151
+
NoOp,
152
+
}
153
+
154
+
impl CacheBackend {
155
+
pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
156
+
match self {
157
+
CacheBackend::Valkey(cache) => {
158
+
Arc::new(RedisRateLimiter::new(cache.connection()))
159
+
}
160
+
CacheBackend::NoOp => Arc::new(NoOpRateLimiter),
161
+
}
162
+
}
163
+
}
164
+
165
+
#[async_trait]
166
+
impl Cache for CacheBackend {
167
+
async fn get(&self, key: &str) -> Option<String> {
168
+
match self {
169
+
CacheBackend::Valkey(c) => c.get(key).await,
170
+
CacheBackend::NoOp => None,
171
+
}
172
+
}
173
+
174
+
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
175
+
match self {
176
+
CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
177
+
CacheBackend::NoOp => Ok(()),
178
+
}
179
+
}
180
+
181
+
async fn delete(&self, key: &str) -> Result<(), CacheError> {
182
+
match self {
183
+
CacheBackend::Valkey(c) => c.delete(key).await,
184
+
CacheBackend::NoOp => Ok(()),
185
+
}
186
+
}
187
+
}
188
+
189
+
pub async fn create_cache() -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) {
190
+
match std::env::var("VALKEY_URL") {
191
+
Ok(url) => match ValkeyCache::new(&url).await {
192
+
Ok(cache) => {
193
+
tracing::info!("Connected to Valkey cache at {}", url);
194
+
let rate_limiter = Arc::new(RedisRateLimiter::new(cache.connection()));
195
+
(Arc::new(cache), rate_limiter)
196
+
}
197
+
Err(e) => {
198
+
tracing::warn!("Failed to connect to Valkey: {}. Running without cache.", e);
199
+
(Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
200
+
}
201
+
},
202
+
Err(_) => {
203
+
tracing::info!("VALKEY_URL not set. Running without cache.");
204
+
(Arc::new(NoOpCache), Arc::new(NoOpRateLimiter))
205
+
}
206
+
}
207
+
}
+1
src/lib.rs
+1
src/lib.rs
+14
src/oauth/endpoints/par.rs
+14
src/oauth/endpoints/par.rs
···
1
1
use axum::{
2
2
Form, Json,
3
3
extract::State,
4
+
http::HeaderMap,
4
5
};
5
6
use chrono::{Duration, Utc};
6
7
use serde::{Deserialize, Serialize};
···
49
50
50
51
pub async fn pushed_authorization_request(
51
52
State(state): State<AppState>,
53
+
headers: HeaderMap,
52
54
Form(request): Form<ParRequest>,
53
55
) -> Result<Json<ParResponse>, OAuthError> {
56
+
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
57
+
if !state.distributed_rate_limiter.check_rate_limit(
58
+
&format!("oauth_par:{}", client_ip),
59
+
30,
60
+
60_000,
61
+
).await {
62
+
if state.rate_limiters.oauth_par.check_key(&client_ip).is_err() {
63
+
tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
64
+
return Err(OAuthError::RateLimited);
65
+
}
66
+
}
67
+
54
68
if request.response_type != "code" {
55
69
return Err(OAuthError::InvalidRequest(
56
70
"response_type must be 'code'".to_string(),
+33
-7
src/oauth/endpoints/token/introspect.rs
+33
-7
src/oauth/endpoints/token/introspect.rs
···
1
1
use axum::{Form, Json};
2
2
use axum::extract::State;
3
-
use axum::http::StatusCode;
3
+
use axum::http::{HeaderMap, StatusCode};
4
4
use chrono::Utc;
5
5
use serde::{Deserialize, Serialize};
6
6
···
18
18
19
19
pub async fn revoke_token(
20
20
State(state): State<AppState>,
21
+
headers: HeaderMap,
21
22
Form(request): Form<RevokeRequest>,
22
23
) -> Result<StatusCode, OAuthError> {
24
+
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
25
+
if !state.distributed_rate_limiter.check_rate_limit(
26
+
&format!("oauth_revoke:{}", client_ip),
27
+
30,
28
+
60_000,
29
+
).await {
30
+
if state.rate_limiters.oauth_introspect.check_key(&client_ip).is_err() {
31
+
tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded");
32
+
return Err(OAuthError::RateLimited);
33
+
}
34
+
}
35
+
23
36
if let Some(token) = &request.token {
24
37
if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
25
38
db::delete_token_family(&state.db, db_id).await?;
···
67
80
68
81
pub async fn introspect_token(
69
82
State(state): State<AppState>,
83
+
headers: HeaderMap,
70
84
Form(request): Form<IntrospectRequest>,
71
-
) -> Json<IntrospectResponse> {
85
+
) -> Result<Json<IntrospectResponse>, OAuthError> {
86
+
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
87
+
if !state.distributed_rate_limiter.check_rate_limit(
88
+
&format!("oauth_introspect:{}", client_ip),
89
+
30,
90
+
60_000,
91
+
).await {
92
+
if state.rate_limiters.oauth_introspect.check_key(&client_ip).is_err() {
93
+
tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded");
94
+
return Err(OAuthError::RateLimited);
95
+
}
96
+
}
97
+
72
98
let inactive_response = IntrospectResponse {
73
99
active: false,
74
100
scope: None,
···
86
112
87
113
let token_info = match extract_token_claims(&request.token) {
88
114
Ok(info) => info,
89
-
Err(_) => return Json(inactive_response),
115
+
Err(_) => return Ok(Json(inactive_response)),
90
116
};
91
117
92
118
let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
93
119
Ok(Some(data)) => data,
94
-
_ => return Json(inactive_response),
120
+
_ => return Ok(Json(inactive_response)),
95
121
};
96
122
97
123
if token_data.expires_at < Utc::now() {
98
-
return Json(inactive_response);
124
+
return Ok(Json(inactive_response));
99
125
}
100
126
101
127
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
102
128
let issuer = format!("https://{}", pds_hostname);
103
129
104
-
Json(IntrospectResponse {
130
+
Ok(Json(IntrospectResponse {
105
131
active: true,
106
132
scope: token_data.scope,
107
133
client_id: Some(token_data.client_id),
···
118
144
aud: Some(issuer.clone()),
119
145
iss: Some(issuer),
120
146
jti: Some(token_info.jti),
121
-
})
147
+
}))
122
148
}
+4
src/oauth/error.rs
+4
src/oauth/error.rs
···
19
19
InvalidDpopProof(String),
20
20
ExpiredToken(String),
21
21
InvalidToken(String),
22
+
RateLimited,
22
23
}
23
24
24
25
#[derive(Serialize)]
···
73
74
}
74
75
OAuthError::InvalidToken(msg) => {
75
76
(StatusCode::UNAUTHORIZED, "invalid_token", Some(msg))
77
+
}
78
+
OAuthError::RateLimited => {
79
+
(StatusCode::TOO_MANY_REQUESTS, "rate_limited", Some("Too many requests. Please try again later.".to_string()))
76
80
}
77
81
};
78
82
+36
-1
src/rate_limit.rs
+36
-1
src/rate_limit.rs
···
24
24
pub struct RateLimiters {
25
25
pub login: Arc<KeyedRateLimiter>,
26
26
pub oauth_token: Arc<KeyedRateLimiter>,
27
+
pub oauth_authorize: Arc<KeyedRateLimiter>,
27
28
pub password_reset: Arc<KeyedRateLimiter>,
28
29
pub account_creation: Arc<KeyedRateLimiter>,
30
+
pub refresh_session: Arc<KeyedRateLimiter>,
31
+
pub reset_password: Arc<KeyedRateLimiter>,
32
+
pub oauth_par: Arc<KeyedRateLimiter>,
33
+
pub oauth_introspect: Arc<KeyedRateLimiter>,
34
+
pub app_password: Arc<KeyedRateLimiter>,
35
+
pub email_update: Arc<KeyedRateLimiter>,
29
36
}
30
37
31
38
impl Default for RateLimiters {
···
42
49
)),
43
50
oauth_token: Arc::new(RateLimiter::keyed(
44
51
Quota::per_minute(NonZeroU32::new(30).unwrap())
52
+
)),
53
+
oauth_authorize: Arc::new(RateLimiter::keyed(
54
+
Quota::per_minute(NonZeroU32::new(10).unwrap())
45
55
)),
46
56
password_reset: Arc::new(RateLimiter::keyed(
47
57
Quota::per_hour(NonZeroU32::new(5).unwrap())
···
49
59
account_creation: Arc::new(RateLimiter::keyed(
50
60
Quota::per_hour(NonZeroU32::new(10).unwrap())
51
61
)),
62
+
refresh_session: Arc::new(RateLimiter::keyed(
63
+
Quota::per_minute(NonZeroU32::new(60).unwrap())
64
+
)),
65
+
reset_password: Arc::new(RateLimiter::keyed(
66
+
Quota::per_minute(NonZeroU32::new(10).unwrap())
67
+
)),
68
+
oauth_par: Arc::new(RateLimiter::keyed(
69
+
Quota::per_minute(NonZeroU32::new(30).unwrap())
70
+
)),
71
+
oauth_introspect: Arc::new(RateLimiter::keyed(
72
+
Quota::per_minute(NonZeroU32::new(30).unwrap())
73
+
)),
74
+
app_password: Arc::new(RateLimiter::keyed(
75
+
Quota::per_minute(NonZeroU32::new(10).unwrap())
76
+
)),
77
+
email_update: Arc::new(RateLimiter::keyed(
78
+
Quota::per_hour(NonZeroU32::new(5).unwrap())
79
+
)),
52
80
}
53
81
}
54
82
···
66
94
self
67
95
}
68
96
97
+
pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
98
+
self.oauth_authorize = Arc::new(RateLimiter::keyed(
99
+
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
100
+
));
101
+
self
102
+
}
103
+
69
104
pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
70
105
self.password_reset = Arc::new(RateLimiter::keyed(
71
106
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
···
81
116
}
82
117
}
83
118
84
-
fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
119
+
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
85
120
if let Some(forwarded) = headers.get("x-forwarded-for") {
86
121
if let Ok(value) = forwarded.to_str() {
87
122
if let Some(first_ip) = value.split(',').next() {
+6
src/state.rs
+6
src/state.rs
···
1
+
use crate::cache::{Cache, DistributedRateLimiter, create_cache};
1
2
use crate::circuit_breaker::CircuitBreakers;
2
3
use crate::config::AuthConfig;
3
4
use crate::rate_limit::RateLimiters;
···
16
17
pub firehose_tx: broadcast::Sender<SequencedEvent>,
17
18
pub rate_limiters: Arc<RateLimiters>,
18
19
pub circuit_breakers: Arc<CircuitBreakers>,
20
+
pub cache: Arc<dyn Cache>,
21
+
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
19
22
}
20
23
21
24
impl AppState {
···
27
30
let (firehose_tx, _) = broadcast::channel(1000);
28
31
let rate_limiters = Arc::new(RateLimiters::new());
29
32
let circuit_breakers = Arc::new(CircuitBreakers::new());
33
+
let (cache, distributed_rate_limiter) = create_cache().await;
30
34
Self {
31
35
db,
32
36
block_store,
···
34
38
firehose_tx,
35
39
rate_limiters,
36
40
circuit_breakers,
41
+
cache,
42
+
distributed_rate_limiter,
37
43
}
38
44
}
39
45
+64
tests/oauth_security.rs
+64
tests/oauth_security.rs
···
1447
1447
let introspect_body: Value = introspect_res.json().await.unwrap();
1448
1448
assert_eq!(introspect_body["active"], false, "Revoked token should be inactive");
1449
1449
}
1450
+
1451
+
#[tokio::test]
1452
+
async fn test_security_oauth_authorize_rate_limiting() {
1453
+
let url = base_url().await;
1454
+
let http_client = no_redirect_client();
1455
+
1456
+
let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0);
1457
+
let unique_ip = format!("10.{}.{}.{}", (ts >> 16) & 0xFF, (ts >> 8) & 0xFF, ts & 0xFF);
1458
+
1459
+
let redirect_uri = "https://example.com/rate-limit-callback";
1460
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1461
+
let client_id = mock_client.uri();
1462
+
1463
+
let (_, code_challenge) = generate_pkce();
1464
+
1465
+
let client_for_par = client();
1466
+
let par_body: Value = client_for_par
1467
+
.post(format!("{}/oauth/par", url))
1468
+
.form(&[
1469
+
("response_type", "code"),
1470
+
("client_id", &client_id),
1471
+
("redirect_uri", redirect_uri),
1472
+
("code_challenge", &code_challenge),
1473
+
("code_challenge_method", "S256"),
1474
+
])
1475
+
.send()
1476
+
.await
1477
+
.unwrap()
1478
+
.json()
1479
+
.await
1480
+
.unwrap();
1481
+
1482
+
let request_uri = par_body["request_uri"].as_str().unwrap();
1483
+
1484
+
let mut rate_limited_count = 0;
1485
+
let mut other_count = 0;
1486
+
1487
+
for _ in 0..15 {
1488
+
let res = http_client
1489
+
.post(format!("{}/oauth/authorize", url))
1490
+
.header("X-Forwarded-For", &unique_ip)
1491
+
.form(&[
1492
+
("request_uri", request_uri),
1493
+
("username", "nonexistent_user"),
1494
+
("password", "wrong_password"),
1495
+
("remember_device", "false"),
1496
+
])
1497
+
.send()
1498
+
.await
1499
+
.unwrap();
1500
+
1501
+
match res.status() {
1502
+
StatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1,
1503
+
_ => other_count += 1,
1504
+
}
1505
+
}
1506
+
1507
+
assert!(
1508
+
rate_limited_count > 0,
1509
+
"Expected at least one rate-limited response after 15 OAuth authorize attempts. Got {} other and {} rate limited.",
1510
+
other_count,
1511
+
rate_limited_count
1512
+
);
1513
+
}
+228
tests/rate_limit.rs
+228
tests/rate_limit.rs
···
1
+
mod common;
2
+
3
+
use common::{base_url, client};
4
+
use reqwest::StatusCode;
5
+
use serde_json::json;
6
+
7
+
#[tokio::test]
8
+
async fn test_login_rate_limiting() {
9
+
let client = client();
10
+
let url = format!("{}/xrpc/com.atproto.server.createSession", base_url().await);
11
+
12
+
let payload = json!({
13
+
"identifier": "nonexistent_user_for_rate_limit_test",
14
+
"password": "wrongpassword"
15
+
});
16
+
17
+
let mut rate_limited_count = 0;
18
+
let mut auth_failed_count = 0;
19
+
20
+
for _ in 0..15 {
21
+
let res = client
22
+
.post(&url)
23
+
.json(&payload)
24
+
.send()
25
+
.await
26
+
.expect("Request failed");
27
+
28
+
match res.status() {
29
+
StatusCode::TOO_MANY_REQUESTS => {
30
+
rate_limited_count += 1;
31
+
}
32
+
StatusCode::UNAUTHORIZED => {
33
+
auth_failed_count += 1;
34
+
}
35
+
status => {
36
+
panic!("Unexpected status: {}", status);
37
+
}
38
+
}
39
+
}
40
+
41
+
assert!(
42
+
rate_limited_count > 0,
43
+
"Expected at least one rate-limited response after 15 login attempts. Got {} auth failures and {} rate limits.",
44
+
auth_failed_count,
45
+
rate_limited_count
46
+
);
47
+
}
48
+
49
+
#[tokio::test]
50
+
async fn test_password_reset_rate_limiting() {
51
+
let client = client();
52
+
let url = format!(
53
+
"{}/xrpc/com.atproto.server.requestPasswordReset",
54
+
base_url().await
55
+
);
56
+
57
+
let mut rate_limited_count = 0;
58
+
let mut success_count = 0;
59
+
60
+
for i in 0..8 {
61
+
let payload = json!({
62
+
"email": format!("ratelimit_test_{}@example.com", i)
63
+
});
64
+
65
+
let res = client
66
+
.post(&url)
67
+
.json(&payload)
68
+
.send()
69
+
.await
70
+
.expect("Request failed");
71
+
72
+
match res.status() {
73
+
StatusCode::TOO_MANY_REQUESTS => {
74
+
rate_limited_count += 1;
75
+
}
76
+
StatusCode::OK => {
77
+
success_count += 1;
78
+
}
79
+
status => {
80
+
panic!("Unexpected status: {} - {:?}", status, res.text().await);
81
+
}
82
+
}
83
+
}
84
+
85
+
assert!(
86
+
rate_limited_count > 0,
87
+
"Expected rate limiting after {} password reset requests. Got {} successes.",
88
+
success_count + rate_limited_count,
89
+
success_count
90
+
);
91
+
}
92
+
93
+
#[tokio::test]
94
+
async fn test_account_creation_rate_limiting() {
95
+
let client = client();
96
+
let url = format!(
97
+
"{}/xrpc/com.atproto.server.createAccount",
98
+
base_url().await
99
+
);
100
+
101
+
let mut rate_limited_count = 0;
102
+
let mut other_count = 0;
103
+
104
+
for i in 0..15 {
105
+
let unique_id = uuid::Uuid::new_v4();
106
+
let payload = json!({
107
+
"handle": format!("ratelimit_{}_{}", i, unique_id),
108
+
"email": format!("ratelimit_{}_{}@example.com", i, unique_id),
109
+
"password": "testpassword123"
110
+
});
111
+
112
+
let res = client
113
+
.post(&url)
114
+
.json(&payload)
115
+
.send()
116
+
.await
117
+
.expect("Request failed");
118
+
119
+
match res.status() {
120
+
StatusCode::TOO_MANY_REQUESTS => {
121
+
rate_limited_count += 1;
122
+
}
123
+
_ => {
124
+
other_count += 1;
125
+
}
126
+
}
127
+
}
128
+
129
+
assert!(
130
+
rate_limited_count > 0,
131
+
"Expected rate limiting after account creation attempts. Got {} other responses and {} rate limits.",
132
+
other_count,
133
+
rate_limited_count
134
+
);
135
+
}
136
+
137
+
#[tokio::test]
138
+
async fn test_valkey_connection() {
139
+
if std::env::var("VALKEY_URL").is_err() {
140
+
println!("VALKEY_URL not set, skipping Valkey connection test");
141
+
return;
142
+
}
143
+
144
+
let valkey_url = std::env::var("VALKEY_URL").unwrap();
145
+
let client = redis::Client::open(valkey_url.as_str()).expect("Failed to create Redis client");
146
+
let mut conn = client
147
+
.get_multiplexed_async_connection()
148
+
.await
149
+
.expect("Failed to connect to Valkey");
150
+
151
+
let pong: String = redis::cmd("PING")
152
+
.query_async(&mut conn)
153
+
.await
154
+
.expect("PING failed");
155
+
assert_eq!(pong, "PONG");
156
+
157
+
let _: () = redis::cmd("SET")
158
+
.arg("test_key")
159
+
.arg("test_value")
160
+
.arg("EX")
161
+
.arg(10)
162
+
.query_async(&mut conn)
163
+
.await
164
+
.expect("SET failed");
165
+
166
+
let value: String = redis::cmd("GET")
167
+
.arg("test_key")
168
+
.query_async(&mut conn)
169
+
.await
170
+
.expect("GET failed");
171
+
assert_eq!(value, "test_value");
172
+
173
+
let _: () = redis::cmd("DEL")
174
+
.arg("test_key")
175
+
.query_async(&mut conn)
176
+
.await
177
+
.expect("DEL failed");
178
+
}
179
+
180
+
#[tokio::test]
181
+
async fn test_distributed_rate_limiter_directly() {
182
+
if std::env::var("VALKEY_URL").is_err() {
183
+
println!("VALKEY_URL not set, skipping distributed rate limiter test");
184
+
return;
185
+
}
186
+
187
+
use bspds::cache::{DistributedRateLimiter, RedisRateLimiter};
188
+
189
+
let valkey_url = std::env::var("VALKEY_URL").unwrap();
190
+
let client = redis::Client::open(valkey_url.as_str()).expect("Failed to create Redis client");
191
+
let conn = client
192
+
.get_connection_manager()
193
+
.await
194
+
.expect("Failed to get connection manager");
195
+
196
+
let rate_limiter = RedisRateLimiter::new(conn);
197
+
198
+
let test_key = format!("test_rate_limit_{}", uuid::Uuid::new_v4());
199
+
let limit = 5;
200
+
let window_ms = 60_000;
201
+
202
+
for i in 0..limit {
203
+
let allowed = rate_limiter
204
+
.check_rate_limit(&test_key, limit, window_ms)
205
+
.await;
206
+
assert!(
207
+
allowed,
208
+
"Request {} should have been allowed (limit: {})",
209
+
i + 1,
210
+
limit
211
+
);
212
+
}
213
+
214
+
let allowed = rate_limiter
215
+
.check_rate_limit(&test_key, limit, window_ms)
216
+
.await;
217
+
assert!(
218
+
!allowed,
219
+
"Request {} should have been rate limited (limit: {})",
220
+
limit + 1,
221
+
limit
222
+
);
223
+
224
+
let allowed = rate_limiter
225
+
.check_rate_limit(&test_key, limit, window_ms)
226
+
.await;
227
+
assert!(!allowed, "Subsequent request should also be rate limited");
228
+
}