+2
-2
.sqlx/query-3889903e58405370152b9ded229d843c0114e71454ea7da2b212519e98d09817.json
.sqlx/query-e2e51654f146a3a336f5a28cbd47addbdd311aeaead530c00c1891c95bede0b8.json
+2
-2
.sqlx/query-3889903e58405370152b9ded229d843c0114e71454ea7da2b212519e98d09817.json
.sqlx/query-e2e51654f146a3a336f5a28cbd47addbdd311aeaead530c00c1891c95bede0b8.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT st.id, st.did, k.key_bytes, k.encryption_version\n FROM session_tokens st\n JOIN users u ON st.did = u.did\n JOIN user_keys k ON u.id = k.user_id\n WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()",
3
+
"query": "SELECT st.id, st.did, k.key_bytes, k.encryption_version\n FROM session_tokens st\n JOIN users u ON st.did = u.did\n JOIN user_keys k ON u.id = k.user_id\n WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()\n FOR UPDATE OF st",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
36
36
true
37
37
]
38
38
},
39
-
"hash": "3889903e58405370152b9ded229d843c0114e71454ea7da2b212519e98d09817"
39
+
"hash": "e2e51654f146a3a336f5a28cbd47addbdd311aeaead530c00c1891c95bede0b8"
40
40
}
+3
-2
.sqlx/query-51809819130908ef3600e5843f6098fb510afb4c827a41bc3a32ad78ea10184c.json
.sqlx/query-b1c54d3f3e2d3031c0d926ccb0d39a0250320d41d08df65d6d9dcc640451527d.json
+3
-2
.sqlx/query-51809819130908ef3600e5843f6098fb510afb4c827a41bc3a32ad78ea10184c.json
.sqlx/query-b1c54d3f3e2d3031c0d926ccb0d39a0250320d41d08df65d6d9dcc640451527d.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "\n SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n ",
3
+
"query": "\n SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2\n ",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
51
51
],
52
52
"parameters": {
53
53
"Left": [
54
+
"Int8",
54
55
"Int8"
55
56
]
56
57
},
···
66
67
true
67
68
]
68
69
},
69
-
"hash": "51809819130908ef3600e5843f6098fb510afb4c827a41bc3a32ad78ea10184c"
70
+
"hash": "b1c54d3f3e2d3031c0d926ccb0d39a0250320d41d08df65d6d9dcc640451527d"
70
71
}
+14
.sqlx/query-642b7199f2cbde74af72fc5b5b80f9e2b3efe901a3fdfc732f0d36d00db6326f.json
+14
.sqlx/query-642b7199f2cbde74af72fc5b5b80f9e2b3efe901a3fdfc732f0d36d00db6326f.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "DELETE FROM invite_codes WHERE created_by_user = $1",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Uuid"
9
+
]
10
+
},
11
+
"nullable": []
12
+
},
13
+
"hash": "642b7199f2cbde74af72fc5b5b80f9e2b3efe901a3fdfc732f0d36d00db6326f"
14
+
}
+14
.sqlx/query-6c71c4ac31f897e9d33a3637d89377c5977f76a117b042e1800b890b84a655ea.json
+14
.sqlx/query-6c71c4ac31f897e9d33a3637d89377c5977f76a117b042e1800b890b84a655ea.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "DELETE FROM invite_code_uses WHERE used_by_user = $1",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Uuid"
9
+
]
10
+
},
11
+
"nullable": []
12
+
},
13
+
"hash": "6c71c4ac31f897e9d33a3637d89377c5977f76a117b042e1800b890b84a655ea"
14
+
}
+2
-2
.sqlx/query-7b76e2fcd809a1536465306c79da7985354175e0f025b29c6004dffa310feebd.json
.sqlx/query-9a8b9c1cfecf02d1266b1544d5cb2dd8f1254b66b884ff22f983a2ba9dee0529.json
+2
-2
.sqlx/query-7b76e2fcd809a1536465306c79da7985354175e0f025b29c6004dffa310feebd.json
.sqlx/query-9a8b9c1cfecf02d1266b1544d5cb2dd8f1254b66b884ff22f983a2ba9dee0529.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2)",
3
+
"query": "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING",
4
4
"describe": {
5
5
"columns": [],
6
6
"parameters": {
···
11
11
},
12
12
"nullable": []
13
13
},
14
-
"hash": "7b76e2fcd809a1536465306c79da7985354175e0f025b29c6004dffa310feebd"
14
+
"hash": "9a8b9c1cfecf02d1266b1544d5cb2dd8f1254b66b884ff22f983a2ba9dee0529"
15
15
}
+34
.sqlx/query-9f435d95d7c270c82a164c59e9d0caa80ffd7107aff32c806709973fdc6b0020.json
+34
.sqlx/query-9f435d95d7c270c82a164c59e9d0caa80ffd7107aff32c806709973fdc6b0020.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "did",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "handle",
19
+
"type_info": "Text"
20
+
}
21
+
],
22
+
"parameters": {
23
+
"Left": [
24
+
"Text"
25
+
]
26
+
},
27
+
"nullable": [
28
+
false,
29
+
false,
30
+
false
31
+
]
32
+
},
33
+
"hash": "9f435d95d7c270c82a164c59e9d0caa80ffd7107aff32c806709973fdc6b0020"
34
+
}
+34
.sqlx/query-b22827038d6041ad1f3b7eae07d77433def15237391fe26004577b12cb7e95b3.json
+34
.sqlx/query-b22827038d6041ad1f3b7eae07d77433def15237391fe26004577b12cb7e95b3.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT id, did, handle FROM users WHERE did = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "did",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "handle",
19
+
"type_info": "Text"
20
+
}
21
+
],
22
+
"parameters": {
23
+
"Left": [
24
+
"Text"
25
+
]
26
+
},
27
+
"nullable": [
28
+
false,
29
+
false,
30
+
false
31
+
]
32
+
},
33
+
"hash": "b22827038d6041ad1f3b7eae07d77433def15237391fe26004577b12cb7e95b3"
34
+
}
+14
.sqlx/query-c583f0016bf5f61c17781f55d121698e81b2314465321a01916ee7902b17e813.json
+14
.sqlx/query-c583f0016bf5f61c17781f55d121698e81b2314465321a01916ee7902b17e813.json
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "DELETE FROM used_refresh_tokens WHERE session_id IN (SELECT id FROM session_tokens WHERE did = $1)",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Text"
9
+
]
10
+
},
11
+
"nullable": []
12
+
},
13
+
"hash": "c583f0016bf5f61c17781f55d121698e81b2314465321a01916ee7902b17e813"
14
+
}
+2
-2
.sqlx/query-fcd868a192d27fd4eccae92a884e881b8d6f09bf7ae08a9b431a44acbf2f91f3.json
.sqlx/query-b2e1736dbe2ab9114e373353bcc299176417f3c9220025f9521591ba62928bd7.json
+2
-2
.sqlx/query-fcd868a192d27fd4eccae92a884e881b8d6f09bf7ae08a9b431a44acbf2f91f3.json
.sqlx/query-b2e1736dbe2ab9114e373353bcc299176417f3c9220025f9521591ba62928bd7.json
···
1
1
{
2
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1",
3
+
"query": "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1 FOR UPDATE",
4
4
"describe": {
5
5
"columns": [
6
6
{
···
18
18
false
19
19
]
20
20
},
21
-
"hash": "fcd868a192d27fd4eccae92a884e881b8d6f09bf7ae08a9b431a44acbf2f91f3"
21
+
"hash": "b2e1736dbe2ab9114e373353bcc299176417f3c9220025f9521591ba62928bd7"
22
22
}
+4
-4
TODO.md
+4
-4
TODO.md
···
253
253
### Frontend Views
254
254
Uses existing ATProto endpoints where possible:
255
255
256
-
**User Dashboard**
256
+
User Dashboard
257
257
- [ ] Account overview (uses `com.atproto.server.getSession`, `com.atproto.admin.getAccountInfo`)
258
258
- [ ] Active sessions view (needs new endpoint or extend existing)
259
259
- [ ] App passwords (uses `com.atproto.server.listAppPasswords`, `createAppPassword`, `revokeAppPassword`)
260
260
- [ ] Invite codes (uses `com.atproto.server.getAccountInviteCodes`, `createInviteCode`)
261
261
262
-
**Notification Preferences**
262
+
Notification Preferences
263
263
- [ ] Channel selector (uses `com.bspds.account.*` endpoints above)
264
264
- [ ] Verification flows for Discord/Telegram/Signal
265
265
- [ ] Notification history view
266
266
267
-
**Account Settings**
267
+
Account Settings
268
268
- [ ] Email change (uses `com.atproto.server.requestEmailUpdate`, `updateEmail`)
269
269
- [ ] Password change (uses `com.atproto.server.requestPasswordReset`, `resetPassword`)
270
270
- [ ] Handle change (uses `com.atproto.identity.updateHandle`)
271
271
- [ ] Account deletion (uses `com.atproto.server.requestAccountDelete`, `deleteAccount`)
272
272
- [ ] Data export (uses `com.atproto.sync.getRepo`)
273
273
274
-
**Admin Dashboard** (privileged users only)
274
+
Admin Dashboard (privileged users only)
275
275
- [ ] User list (uses `com.atproto.admin.getAccountInfos` with pagination)
276
276
- [ ] User detail/actions (uses `com.atproto.admin.*` endpoints)
277
277
- [ ] Invite management (uses `com.atproto.admin.getInviteCodes`, `disableInviteCodes`)
+23
-1
src/api/actor/preferences.rs
+23
-1
src/api/actor/preferences.rs
···
9
9
use serde_json::{json, Value};
10
10
11
11
const APP_BSKY_NAMESPACE: &str = "app.bsky";
12
+
const MAX_PREFERENCES_COUNT: usize = 100;
13
+
const MAX_PREFERENCE_SIZE: usize = 10_000;
12
14
13
15
#[derive(Serialize)]
14
16
pub struct GetPreferencesOutput {
···
141
143
}
142
144
};
143
145
146
+
if input.preferences.len() > MAX_PREFERENCES_COUNT {
147
+
return (
148
+
StatusCode::BAD_REQUEST,
149
+
Json(json!({"error": "InvalidRequest", "message": format!("Too many preferences: {} exceeds limit of {}", input.preferences.len(), MAX_PREFERENCES_COUNT)})),
150
+
)
151
+
.into_response();
152
+
}
153
+
144
154
for pref in &input.preferences {
155
+
let pref_str = serde_json::to_string(pref).unwrap_or_default();
156
+
if pref_str.len() > MAX_PREFERENCE_SIZE {
157
+
return (
158
+
StatusCode::BAD_REQUEST,
159
+
Json(json!({"error": "InvalidRequest", "message": format!("Preference too large: {} bytes exceeds limit of {}", pref_str.len(), MAX_PREFERENCE_SIZE)})),
160
+
)
161
+
.into_response();
162
+
}
163
+
145
164
let pref_type = match pref.get("$type").and_then(|t| t.as_str()) {
146
165
Some(t) => t,
147
166
None => {
···
200
219
}
201
220
202
221
for pref in input.preferences {
203
-
let pref_type = pref.get("$type").and_then(|t| t.as_str()).unwrap();
222
+
let pref_type = match pref.get("$type").and_then(|t| t.as_str()) {
223
+
Some(t) => t,
224
+
None => continue,
225
+
};
204
226
205
227
let insert_result = sqlx::query!(
206
228
"INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, $2, $3)",
-564
src/api/admin/account.rs
-564
src/api/admin/account.rs
···
1
-
use crate::state::AppState;
2
-
use axum::{
3
-
Json,
4
-
extract::{Query, State},
5
-
http::StatusCode,
6
-
response::{IntoResponse, Response},
7
-
};
8
-
use serde::{Deserialize, Serialize};
9
-
use serde_json::json;
10
-
use tracing::{error, warn};
11
-
12
-
#[derive(Deserialize)]
13
-
pub struct GetAccountInfoParams {
14
-
pub did: String,
15
-
}
16
-
17
-
#[derive(Serialize)]
18
-
#[serde(rename_all = "camelCase")]
19
-
pub struct AccountInfo {
20
-
pub did: String,
21
-
pub handle: String,
22
-
pub email: Option<String>,
23
-
pub indexed_at: String,
24
-
pub invite_note: Option<String>,
25
-
pub invites_disabled: bool,
26
-
pub email_confirmed_at: Option<String>,
27
-
pub deactivated_at: Option<String>,
28
-
}
29
-
30
-
#[derive(Serialize)]
31
-
#[serde(rename_all = "camelCase")]
32
-
pub struct GetAccountInfosOutput {
33
-
pub infos: Vec<AccountInfo>,
34
-
}
35
-
36
-
pub async fn get_account_info(
37
-
State(state): State<AppState>,
38
-
headers: axum::http::HeaderMap,
39
-
Query(params): Query<GetAccountInfoParams>,
40
-
) -> Response {
41
-
let auth_header = headers.get("Authorization");
42
-
if auth_header.is_none() {
43
-
return (
44
-
StatusCode::UNAUTHORIZED,
45
-
Json(json!({"error": "AuthenticationRequired"})),
46
-
)
47
-
.into_response();
48
-
}
49
-
50
-
let did = params.did.trim();
51
-
if did.is_empty() {
52
-
return (
53
-
StatusCode::BAD_REQUEST,
54
-
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
55
-
)
56
-
.into_response();
57
-
}
58
-
59
-
let result = sqlx::query!(
60
-
r#"
61
-
SELECT did, handle, email, created_at
62
-
FROM users
63
-
WHERE did = $1
64
-
"#,
65
-
did
66
-
)
67
-
.fetch_optional(&state.db)
68
-
.await;
69
-
70
-
match result {
71
-
Ok(Some(row)) => {
72
-
(
73
-
StatusCode::OK,
74
-
Json(AccountInfo {
75
-
did: row.did,
76
-
handle: row.handle,
77
-
email: Some(row.email),
78
-
indexed_at: row.created_at.to_rfc3339(),
79
-
invite_note: None,
80
-
invites_disabled: false,
81
-
email_confirmed_at: None,
82
-
deactivated_at: None,
83
-
}),
84
-
)
85
-
.into_response()
86
-
}
87
-
Ok(None) => (
88
-
StatusCode::NOT_FOUND,
89
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
90
-
)
91
-
.into_response(),
92
-
Err(e) => {
93
-
error!("DB error in get_account_info: {:?}", e);
94
-
(
95
-
StatusCode::INTERNAL_SERVER_ERROR,
96
-
Json(json!({"error": "InternalError"})),
97
-
)
98
-
.into_response()
99
-
}
100
-
}
101
-
}
102
-
103
-
#[derive(Deserialize)]
104
-
pub struct GetAccountInfosParams {
105
-
pub dids: String,
106
-
}
107
-
108
-
pub async fn get_account_infos(
109
-
State(state): State<AppState>,
110
-
headers: axum::http::HeaderMap,
111
-
Query(params): Query<GetAccountInfosParams>,
112
-
) -> Response {
113
-
let auth_header = headers.get("Authorization");
114
-
if auth_header.is_none() {
115
-
return (
116
-
StatusCode::UNAUTHORIZED,
117
-
Json(json!({"error": "AuthenticationRequired"})),
118
-
)
119
-
.into_response();
120
-
}
121
-
122
-
let dids: Vec<&str> = params.dids.split(',').map(|s| s.trim()).collect();
123
-
if dids.is_empty() {
124
-
return (
125
-
StatusCode::BAD_REQUEST,
126
-
Json(json!({"error": "InvalidRequest", "message": "dids is required"})),
127
-
)
128
-
.into_response();
129
-
}
130
-
131
-
let mut infos = Vec::new();
132
-
133
-
for did in dids {
134
-
if did.is_empty() {
135
-
continue;
136
-
}
137
-
138
-
let result = sqlx::query!(
139
-
r#"
140
-
SELECT did, handle, email, created_at
141
-
FROM users
142
-
WHERE did = $1
143
-
"#,
144
-
did
145
-
)
146
-
.fetch_optional(&state.db)
147
-
.await;
148
-
149
-
if let Ok(Some(row)) = result {
150
-
infos.push(AccountInfo {
151
-
did: row.did,
152
-
handle: row.handle,
153
-
email: Some(row.email),
154
-
indexed_at: row.created_at.to_rfc3339(),
155
-
invite_note: None,
156
-
invites_disabled: false,
157
-
email_confirmed_at: None,
158
-
deactivated_at: None,
159
-
});
160
-
}
161
-
}
162
-
163
-
(StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response()
164
-
}
165
-
166
-
#[derive(Deserialize)]
167
-
pub struct DeleteAccountInput {
168
-
pub did: String,
169
-
}
170
-
171
-
pub async fn delete_account(
172
-
State(state): State<AppState>,
173
-
headers: axum::http::HeaderMap,
174
-
Json(input): Json<DeleteAccountInput>,
175
-
) -> Response {
176
-
let auth_header = headers.get("Authorization");
177
-
if auth_header.is_none() {
178
-
return (
179
-
StatusCode::UNAUTHORIZED,
180
-
Json(json!({"error": "AuthenticationRequired"})),
181
-
)
182
-
.into_response();
183
-
}
184
-
185
-
let did = input.did.trim();
186
-
if did.is_empty() {
187
-
return (
188
-
StatusCode::BAD_REQUEST,
189
-
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
190
-
)
191
-
.into_response();
192
-
}
193
-
194
-
let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
195
-
.fetch_optional(&state.db)
196
-
.await;
197
-
198
-
let user_id = match user {
199
-
Ok(Some(row)) => row.id,
200
-
Ok(None) => {
201
-
return (
202
-
StatusCode::NOT_FOUND,
203
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
204
-
)
205
-
.into_response();
206
-
}
207
-
Err(e) => {
208
-
error!("DB error in delete_account: {:?}", e);
209
-
return (
210
-
StatusCode::INTERNAL_SERVER_ERROR,
211
-
Json(json!({"error": "InternalError"})),
212
-
)
213
-
.into_response();
214
-
}
215
-
};
216
-
217
-
let _ = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did)
218
-
.execute(&state.db)
219
-
.await;
220
-
221
-
let _ = sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id)
222
-
.execute(&state.db)
223
-
.await;
224
-
225
-
let _ = sqlx::query!("DELETE FROM repos WHERE user_id = $1", user_id)
226
-
.execute(&state.db)
227
-
.await;
228
-
229
-
let _ = sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id)
230
-
.execute(&state.db)
231
-
.await;
232
-
233
-
let _ = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id)
234
-
.execute(&state.db)
235
-
.await;
236
-
237
-
let result = sqlx::query!("DELETE FROM users WHERE id = $1", user_id)
238
-
.execute(&state.db)
239
-
.await;
240
-
241
-
match result {
242
-
Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
243
-
Err(e) => {
244
-
error!("DB error deleting account: {:?}", e);
245
-
(
246
-
StatusCode::INTERNAL_SERVER_ERROR,
247
-
Json(json!({"error": "InternalError"})),
248
-
)
249
-
.into_response()
250
-
}
251
-
}
252
-
}
253
-
254
-
#[derive(Deserialize)]
255
-
pub struct UpdateAccountEmailInput {
256
-
pub account: String,
257
-
pub email: String,
258
-
}
259
-
260
-
pub async fn update_account_email(
261
-
State(state): State<AppState>,
262
-
headers: axum::http::HeaderMap,
263
-
Json(input): Json<UpdateAccountEmailInput>,
264
-
) -> Response {
265
-
let auth_header = headers.get("Authorization");
266
-
if auth_header.is_none() {
267
-
return (
268
-
StatusCode::UNAUTHORIZED,
269
-
Json(json!({"error": "AuthenticationRequired"})),
270
-
)
271
-
.into_response();
272
-
}
273
-
274
-
let account = input.account.trim();
275
-
let email = input.email.trim();
276
-
277
-
if account.is_empty() || email.is_empty() {
278
-
return (
279
-
StatusCode::BAD_REQUEST,
280
-
Json(json!({"error": "InvalidRequest", "message": "account and email are required"})),
281
-
)
282
-
.into_response();
283
-
}
284
-
285
-
let result = sqlx::query!("UPDATE users SET email = $1 WHERE did = $2", email, account)
286
-
.execute(&state.db)
287
-
.await;
288
-
289
-
match result {
290
-
Ok(r) => {
291
-
if r.rows_affected() == 0 {
292
-
return (
293
-
StatusCode::NOT_FOUND,
294
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
295
-
)
296
-
.into_response();
297
-
}
298
-
(StatusCode::OK, Json(json!({}))).into_response()
299
-
}
300
-
Err(e) => {
301
-
error!("DB error updating email: {:?}", e);
302
-
(
303
-
StatusCode::INTERNAL_SERVER_ERROR,
304
-
Json(json!({"error": "InternalError"})),
305
-
)
306
-
.into_response()
307
-
}
308
-
}
309
-
}
310
-
311
-
#[derive(Deserialize)]
312
-
pub struct UpdateAccountHandleInput {
313
-
pub did: String,
314
-
pub handle: String,
315
-
}
316
-
317
-
pub async fn update_account_handle(
318
-
State(state): State<AppState>,
319
-
headers: axum::http::HeaderMap,
320
-
Json(input): Json<UpdateAccountHandleInput>,
321
-
) -> Response {
322
-
let auth_header = headers.get("Authorization");
323
-
if auth_header.is_none() {
324
-
return (
325
-
StatusCode::UNAUTHORIZED,
326
-
Json(json!({"error": "AuthenticationRequired"})),
327
-
)
328
-
.into_response();
329
-
}
330
-
331
-
let did = input.did.trim();
332
-
let handle = input.handle.trim();
333
-
334
-
if did.is_empty() || handle.is_empty() {
335
-
return (
336
-
StatusCode::BAD_REQUEST,
337
-
Json(json!({"error": "InvalidRequest", "message": "did and handle are required"})),
338
-
)
339
-
.into_response();
340
-
}
341
-
342
-
if !handle
343
-
.chars()
344
-
.all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
345
-
{
346
-
return (
347
-
StatusCode::BAD_REQUEST,
348
-
Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
349
-
)
350
-
.into_response();
351
-
}
352
-
353
-
let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did)
354
-
.fetch_optional(&state.db)
355
-
.await;
356
-
357
-
if let Ok(Some(_)) = existing {
358
-
return (
359
-
StatusCode::BAD_REQUEST,
360
-
Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
361
-
)
362
-
.into_response();
363
-
}
364
-
365
-
let result = sqlx::query!("UPDATE users SET handle = $1 WHERE did = $2", handle, did)
366
-
.execute(&state.db)
367
-
.await;
368
-
369
-
match result {
370
-
Ok(r) => {
371
-
if r.rows_affected() == 0 {
372
-
return (
373
-
StatusCode::NOT_FOUND,
374
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
375
-
)
376
-
.into_response();
377
-
}
378
-
(StatusCode::OK, Json(json!({}))).into_response()
379
-
}
380
-
Err(e) => {
381
-
error!("DB error updating handle: {:?}", e);
382
-
(
383
-
StatusCode::INTERNAL_SERVER_ERROR,
384
-
Json(json!({"error": "InternalError"})),
385
-
)
386
-
.into_response()
387
-
}
388
-
}
389
-
}
390
-
391
-
#[derive(Deserialize)]
392
-
pub struct UpdateAccountPasswordInput {
393
-
pub did: String,
394
-
pub password: String,
395
-
}
396
-
397
-
pub async fn update_account_password(
398
-
State(state): State<AppState>,
399
-
headers: axum::http::HeaderMap,
400
-
Json(input): Json<UpdateAccountPasswordInput>,
401
-
) -> Response {
402
-
let auth_header = headers.get("Authorization");
403
-
if auth_header.is_none() {
404
-
return (
405
-
StatusCode::UNAUTHORIZED,
406
-
Json(json!({"error": "AuthenticationRequired"})),
407
-
)
408
-
.into_response();
409
-
}
410
-
411
-
let did = input.did.trim();
412
-
let password = input.password.trim();
413
-
414
-
if did.is_empty() || password.is_empty() {
415
-
return (
416
-
StatusCode::BAD_REQUEST,
417
-
Json(json!({"error": "InvalidRequest", "message": "did and password are required"})),
418
-
)
419
-
.into_response();
420
-
}
421
-
422
-
let password_hash = match bcrypt::hash(password, bcrypt::DEFAULT_COST) {
423
-
Ok(h) => h,
424
-
Err(e) => {
425
-
error!("Failed to hash password: {:?}", e);
426
-
return (
427
-
StatusCode::INTERNAL_SERVER_ERROR,
428
-
Json(json!({"error": "InternalError"})),
429
-
)
430
-
.into_response();
431
-
}
432
-
};
433
-
434
-
let result = sqlx::query!("UPDATE users SET password_hash = $1 WHERE did = $2", password_hash, did)
435
-
.execute(&state.db)
436
-
.await;
437
-
438
-
match result {
439
-
Ok(r) => {
440
-
if r.rows_affected() == 0 {
441
-
return (
442
-
StatusCode::NOT_FOUND,
443
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
444
-
)
445
-
.into_response();
446
-
}
447
-
(StatusCode::OK, Json(json!({}))).into_response()
448
-
}
449
-
Err(e) => {
450
-
error!("DB error updating password: {:?}", e);
451
-
(
452
-
StatusCode::INTERNAL_SERVER_ERROR,
453
-
Json(json!({"error": "InternalError"})),
454
-
)
455
-
.into_response()
456
-
}
457
-
}
458
-
}
459
-
460
-
#[derive(Deserialize)]
461
-
#[serde(rename_all = "camelCase")]
462
-
pub struct SendEmailInput {
463
-
pub recipient_did: String,
464
-
pub sender_did: String,
465
-
pub content: String,
466
-
pub subject: Option<String>,
467
-
pub comment: Option<String>,
468
-
}
469
-
470
-
#[derive(Serialize)]
471
-
pub struct SendEmailOutput {
472
-
pub sent: bool,
473
-
}
474
-
475
-
pub async fn send_email(
476
-
State(state): State<AppState>,
477
-
headers: axum::http::HeaderMap,
478
-
Json(input): Json<SendEmailInput>,
479
-
) -> Response {
480
-
let auth_header = headers.get("Authorization");
481
-
if auth_header.is_none() {
482
-
return (
483
-
StatusCode::UNAUTHORIZED,
484
-
Json(json!({"error": "AuthenticationRequired"})),
485
-
)
486
-
.into_response();
487
-
}
488
-
489
-
let recipient_did = input.recipient_did.trim();
490
-
let content = input.content.trim();
491
-
492
-
if recipient_did.is_empty() {
493
-
return (
494
-
StatusCode::BAD_REQUEST,
495
-
Json(json!({"error": "InvalidRequest", "message": "recipientDid is required"})),
496
-
)
497
-
.into_response();
498
-
}
499
-
500
-
if content.is_empty() {
501
-
return (
502
-
StatusCode::BAD_REQUEST,
503
-
Json(json!({"error": "InvalidRequest", "message": "content is required"})),
504
-
)
505
-
.into_response();
506
-
}
507
-
508
-
let user = sqlx::query!(
509
-
"SELECT id, email, handle FROM users WHERE did = $1",
510
-
recipient_did
511
-
)
512
-
.fetch_optional(&state.db)
513
-
.await;
514
-
515
-
let (user_id, email, handle) = match user {
516
-
Ok(Some(row)) => (row.id, row.email, row.handle),
517
-
Ok(None) => {
518
-
return (
519
-
StatusCode::NOT_FOUND,
520
-
Json(json!({"error": "AccountNotFound", "message": "Recipient account not found"})),
521
-
)
522
-
.into_response();
523
-
}
524
-
Err(e) => {
525
-
error!("DB error in send_email: {:?}", e);
526
-
return (
527
-
StatusCode::INTERNAL_SERVER_ERROR,
528
-
Json(json!({"error": "InternalError"})),
529
-
)
530
-
.into_response();
531
-
}
532
-
};
533
-
534
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
535
-
let subject = input
536
-
.subject
537
-
.clone()
538
-
.unwrap_or_else(|| format!("Message from {}", hostname));
539
-
540
-
let notification = crate::notifications::NewNotification::email(
541
-
user_id,
542
-
crate::notifications::NotificationType::AdminEmail,
543
-
email,
544
-
subject,
545
-
content.to_string(),
546
-
);
547
-
548
-
let result = crate::notifications::enqueue_notification(&state.db, notification).await;
549
-
550
-
match result {
551
-
Ok(_) => {
552
-
tracing::info!(
553
-
"Admin email queued for {} ({})",
554
-
handle,
555
-
recipient_did
556
-
);
557
-
(StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response()
558
-
}
559
-
Err(e) => {
560
-
warn!("Failed to enqueue admin email: {:?}", e);
561
-
(StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response()
562
-
}
563
-
}
564
-
}
+190
src/api/admin/account/delete.rs
+190
src/api/admin/account/delete.rs
···
1
+
use crate::state::AppState;
2
+
use axum::{
3
+
Json,
4
+
extract::State,
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
use serde::Deserialize;
9
+
use serde_json::json;
10
+
use tracing::error;
11
+
12
+
#[derive(Deserialize)]
13
+
pub struct DeleteAccountInput {
14
+
pub did: String,
15
+
}
16
+
17
+
pub async fn delete_account(
18
+
State(state): State<AppState>,
19
+
headers: axum::http::HeaderMap,
20
+
Json(input): Json<DeleteAccountInput>,
21
+
) -> Response {
22
+
let auth_header = headers.get("Authorization");
23
+
if auth_header.is_none() {
24
+
return (
25
+
StatusCode::UNAUTHORIZED,
26
+
Json(json!({"error": "AuthenticationRequired"})),
27
+
)
28
+
.into_response();
29
+
}
30
+
31
+
let did = input.did.trim();
32
+
if did.is_empty() {
33
+
return (
34
+
StatusCode::BAD_REQUEST,
35
+
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
36
+
)
37
+
.into_response();
38
+
}
39
+
40
+
let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
41
+
.fetch_optional(&state.db)
42
+
.await;
43
+
44
+
let user_id = match user {
45
+
Ok(Some(row)) => row.id,
46
+
Ok(None) => {
47
+
return (
48
+
StatusCode::NOT_FOUND,
49
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
50
+
)
51
+
.into_response();
52
+
}
53
+
Err(e) => {
54
+
error!("DB error in delete_account: {:?}", e);
55
+
return (
56
+
StatusCode::INTERNAL_SERVER_ERROR,
57
+
Json(json!({"error": "InternalError"})),
58
+
)
59
+
.into_response();
60
+
}
61
+
};
62
+
63
+
let mut tx = match state.db.begin().await {
64
+
Ok(tx) => tx,
65
+
Err(e) => {
66
+
error!("Failed to begin transaction for account deletion: {:?}", e);
67
+
return (
68
+
StatusCode::INTERNAL_SERVER_ERROR,
69
+
Json(json!({"error": "InternalError"})),
70
+
)
71
+
.into_response();
72
+
}
73
+
};
74
+
75
+
if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did)
76
+
.execute(&mut *tx)
77
+
.await
78
+
{
79
+
error!("Failed to delete session tokens for {}: {:?}", did, e);
80
+
return (
81
+
StatusCode::INTERNAL_SERVER_ERROR,
82
+
Json(json!({"error": "InternalError", "message": "Failed to delete session tokens"})),
83
+
)
84
+
.into_response();
85
+
}
86
+
87
+
if let Err(e) = sqlx::query!("DELETE FROM used_refresh_tokens WHERE session_id IN (SELECT id FROM session_tokens WHERE did = $1)", did)
88
+
.execute(&mut *tx)
89
+
.await
90
+
{
91
+
error!("Failed to delete used refresh tokens for {}: {:?}", did, e);
92
+
}
93
+
94
+
if let Err(e) = sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id)
95
+
.execute(&mut *tx)
96
+
.await
97
+
{
98
+
error!("Failed to delete records for user {}: {:?}", user_id, e);
99
+
return (
100
+
StatusCode::INTERNAL_SERVER_ERROR,
101
+
Json(json!({"error": "InternalError", "message": "Failed to delete records"})),
102
+
)
103
+
.into_response();
104
+
}
105
+
106
+
if let Err(e) = sqlx::query!("DELETE FROM repos WHERE user_id = $1", user_id)
107
+
.execute(&mut *tx)
108
+
.await
109
+
{
110
+
error!("Failed to delete repos for user {}: {:?}", user_id, e);
111
+
return (
112
+
StatusCode::INTERNAL_SERVER_ERROR,
113
+
Json(json!({"error": "InternalError", "message": "Failed to delete repos"})),
114
+
)
115
+
.into_response();
116
+
}
117
+
118
+
if let Err(e) = sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id)
119
+
.execute(&mut *tx)
120
+
.await
121
+
{
122
+
error!("Failed to delete blobs for user {}: {:?}", user_id, e);
123
+
return (
124
+
StatusCode::INTERNAL_SERVER_ERROR,
125
+
Json(json!({"error": "InternalError", "message": "Failed to delete blobs"})),
126
+
)
127
+
.into_response();
128
+
}
129
+
130
+
if let Err(e) = sqlx::query!("DELETE FROM app_passwords WHERE user_id = $1", user_id)
131
+
.execute(&mut *tx)
132
+
.await
133
+
{
134
+
error!("Failed to delete app passwords for user {}: {:?}", user_id, e);
135
+
return (
136
+
StatusCode::INTERNAL_SERVER_ERROR,
137
+
Json(json!({"error": "InternalError", "message": "Failed to delete app passwords"})),
138
+
)
139
+
.into_response();
140
+
}
141
+
142
+
if let Err(e) = sqlx::query!("DELETE FROM invite_code_uses WHERE used_by_user = $1", user_id)
143
+
.execute(&mut *tx)
144
+
.await
145
+
{
146
+
error!("Failed to delete invite code uses for user {}: {:?}", user_id, e);
147
+
}
148
+
149
+
if let Err(e) = sqlx::query!("DELETE FROM invite_codes WHERE created_by_user = $1", user_id)
150
+
.execute(&mut *tx)
151
+
.await
152
+
{
153
+
error!("Failed to delete invite codes for user {}: {:?}", user_id, e);
154
+
}
155
+
156
+
if let Err(e) = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id)
157
+
.execute(&mut *tx)
158
+
.await
159
+
{
160
+
error!("Failed to delete user keys for user {}: {:?}", user_id, e);
161
+
return (
162
+
StatusCode::INTERNAL_SERVER_ERROR,
163
+
Json(json!({"error": "InternalError", "message": "Failed to delete user keys"})),
164
+
)
165
+
.into_response();
166
+
}
167
+
168
+
if let Err(e) = sqlx::query!("DELETE FROM users WHERE id = $1", user_id)
169
+
.execute(&mut *tx)
170
+
.await
171
+
{
172
+
error!("Failed to delete user {}: {:?}", user_id, e);
173
+
return (
174
+
StatusCode::INTERNAL_SERVER_ERROR,
175
+
Json(json!({"error": "InternalError", "message": "Failed to delete user"})),
176
+
)
177
+
.into_response();
178
+
}
179
+
180
+
if let Err(e) = tx.commit().await {
181
+
error!("Failed to commit account deletion transaction: {:?}", e);
182
+
return (
183
+
StatusCode::INTERNAL_SERVER_ERROR,
184
+
Json(json!({"error": "InternalError", "message": "Failed to commit deletion"})),
185
+
)
186
+
.into_response();
187
+
}
188
+
189
+
(StatusCode::OK, Json(json!({}))).into_response()
190
+
}
+116
src/api/admin/account/email.rs
+116
src/api/admin/account/email.rs
···
1
+
use crate::state::AppState;
2
+
use axum::{
3
+
Json,
4
+
extract::State,
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
use serde::{Deserialize, Serialize};
9
+
use serde_json::json;
10
+
use tracing::{error, warn};
11
+
12
+
#[derive(Deserialize)]
13
+
#[serde(rename_all = "camelCase")]
14
+
pub struct SendEmailInput {
15
+
pub recipient_did: String,
16
+
pub sender_did: String,
17
+
pub content: String,
18
+
pub subject: Option<String>,
19
+
pub comment: Option<String>,
20
+
}
21
+
22
+
#[derive(Serialize)]
23
+
pub struct SendEmailOutput {
24
+
pub sent: bool,
25
+
}
26
+
27
+
pub async fn send_email(
28
+
State(state): State<AppState>,
29
+
headers: axum::http::HeaderMap,
30
+
Json(input): Json<SendEmailInput>,
31
+
) -> Response {
32
+
let auth_header = headers.get("Authorization");
33
+
if auth_header.is_none() {
34
+
return (
35
+
StatusCode::UNAUTHORIZED,
36
+
Json(json!({"error": "AuthenticationRequired"})),
37
+
)
38
+
.into_response();
39
+
}
40
+
41
+
let recipient_did = input.recipient_did.trim();
42
+
let content = input.content.trim();
43
+
44
+
if recipient_did.is_empty() {
45
+
return (
46
+
StatusCode::BAD_REQUEST,
47
+
Json(json!({"error": "InvalidRequest", "message": "recipientDid is required"})),
48
+
)
49
+
.into_response();
50
+
}
51
+
52
+
if content.is_empty() {
53
+
return (
54
+
StatusCode::BAD_REQUEST,
55
+
Json(json!({"error": "InvalidRequest", "message": "content is required"})),
56
+
)
57
+
.into_response();
58
+
}
59
+
60
+
let user = sqlx::query!(
61
+
"SELECT id, email, handle FROM users WHERE did = $1",
62
+
recipient_did
63
+
)
64
+
.fetch_optional(&state.db)
65
+
.await;
66
+
67
+
let (user_id, email, handle) = match user {
68
+
Ok(Some(row)) => (row.id, row.email, row.handle),
69
+
Ok(None) => {
70
+
return (
71
+
StatusCode::NOT_FOUND,
72
+
Json(json!({"error": "AccountNotFound", "message": "Recipient account not found"})),
73
+
)
74
+
.into_response();
75
+
}
76
+
Err(e) => {
77
+
error!("DB error in send_email: {:?}", e);
78
+
return (
79
+
StatusCode::INTERNAL_SERVER_ERROR,
80
+
Json(json!({"error": "InternalError"})),
81
+
)
82
+
.into_response();
83
+
}
84
+
};
85
+
86
+
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
87
+
let subject = input
88
+
.subject
89
+
.clone()
90
+
.unwrap_or_else(|| format!("Message from {}", hostname));
91
+
92
+
let notification = crate::notifications::NewNotification::email(
93
+
user_id,
94
+
crate::notifications::NotificationType::AdminEmail,
95
+
email,
96
+
subject,
97
+
content.to_string(),
98
+
);
99
+
100
+
let result = crate::notifications::enqueue_notification(&state.db, notification).await;
101
+
102
+
match result {
103
+
Ok(_) => {
104
+
tracing::info!(
105
+
"Admin email queued for {} ({})",
106
+
handle,
107
+
recipient_did
108
+
);
109
+
(StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response()
110
+
}
111
+
Err(e) => {
112
+
warn!("Failed to enqueue admin email: {:?}", e);
113
+
(StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response()
114
+
}
115
+
}
116
+
}
+164
src/api/admin/account/info.rs
+164
src/api/admin/account/info.rs
···
1
+
use crate::state::AppState;
2
+
use axum::{
3
+
Json,
4
+
extract::{Query, State},
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
use serde::{Deserialize, Serialize};
9
+
use serde_json::json;
10
+
use tracing::error;
11
+
12
+
#[derive(Deserialize)]
13
+
pub struct GetAccountInfoParams {
14
+
pub did: String,
15
+
}
16
+
17
+
#[derive(Serialize)]
18
+
#[serde(rename_all = "camelCase")]
19
+
pub struct AccountInfo {
20
+
pub did: String,
21
+
pub handle: String,
22
+
pub email: Option<String>,
23
+
pub indexed_at: String,
24
+
pub invite_note: Option<String>,
25
+
pub invites_disabled: bool,
26
+
pub email_confirmed_at: Option<String>,
27
+
pub deactivated_at: Option<String>,
28
+
}
29
+
30
+
#[derive(Serialize)]
31
+
#[serde(rename_all = "camelCase")]
32
+
pub struct GetAccountInfosOutput {
33
+
pub infos: Vec<AccountInfo>,
34
+
}
35
+
36
+
pub async fn get_account_info(
37
+
State(state): State<AppState>,
38
+
headers: axum::http::HeaderMap,
39
+
Query(params): Query<GetAccountInfoParams>,
40
+
) -> Response {
41
+
let auth_header = headers.get("Authorization");
42
+
if auth_header.is_none() {
43
+
return (
44
+
StatusCode::UNAUTHORIZED,
45
+
Json(json!({"error": "AuthenticationRequired"})),
46
+
)
47
+
.into_response();
48
+
}
49
+
50
+
let did = params.did.trim();
51
+
if did.is_empty() {
52
+
return (
53
+
StatusCode::BAD_REQUEST,
54
+
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
55
+
)
56
+
.into_response();
57
+
}
58
+
59
+
let result = sqlx::query!(
60
+
r#"
61
+
SELECT did, handle, email, created_at
62
+
FROM users
63
+
WHERE did = $1
64
+
"#,
65
+
did
66
+
)
67
+
.fetch_optional(&state.db)
68
+
.await;
69
+
70
+
match result {
71
+
Ok(Some(row)) => {
72
+
(
73
+
StatusCode::OK,
74
+
Json(AccountInfo {
75
+
did: row.did,
76
+
handle: row.handle,
77
+
email: Some(row.email),
78
+
indexed_at: row.created_at.to_rfc3339(),
79
+
invite_note: None,
80
+
invites_disabled: false,
81
+
email_confirmed_at: None,
82
+
deactivated_at: None,
83
+
}),
84
+
)
85
+
.into_response()
86
+
}
87
+
Ok(None) => (
88
+
StatusCode::NOT_FOUND,
89
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
90
+
)
91
+
.into_response(),
92
+
Err(e) => {
93
+
error!("DB error in get_account_info: {:?}", e);
94
+
(
95
+
StatusCode::INTERNAL_SERVER_ERROR,
96
+
Json(json!({"error": "InternalError"})),
97
+
)
98
+
.into_response()
99
+
}
100
+
}
101
+
}
102
+
103
+
#[derive(Deserialize)]
104
+
pub struct GetAccountInfosParams {
105
+
pub dids: String,
106
+
}
107
+
108
+
pub async fn get_account_infos(
109
+
State(state): State<AppState>,
110
+
headers: axum::http::HeaderMap,
111
+
Query(params): Query<GetAccountInfosParams>,
112
+
) -> Response {
113
+
let auth_header = headers.get("Authorization");
114
+
if auth_header.is_none() {
115
+
return (
116
+
StatusCode::UNAUTHORIZED,
117
+
Json(json!({"error": "AuthenticationRequired"})),
118
+
)
119
+
.into_response();
120
+
}
121
+
122
+
let dids: Vec<&str> = params.dids.split(',').map(|s| s.trim()).collect();
123
+
if dids.is_empty() {
124
+
return (
125
+
StatusCode::BAD_REQUEST,
126
+
Json(json!({"error": "InvalidRequest", "message": "dids is required"})),
127
+
)
128
+
.into_response();
129
+
}
130
+
131
+
let mut infos = Vec::new();
132
+
133
+
for did in dids {
134
+
if did.is_empty() {
135
+
continue;
136
+
}
137
+
138
+
let result = sqlx::query!(
139
+
r#"
140
+
SELECT did, handle, email, created_at
141
+
FROM users
142
+
WHERE did = $1
143
+
"#,
144
+
did
145
+
)
146
+
.fetch_optional(&state.db)
147
+
.await;
148
+
149
+
if let Ok(Some(row)) = result {
150
+
infos.push(AccountInfo {
151
+
did: row.did,
152
+
handle: row.handle,
153
+
email: Some(row.email),
154
+
indexed_at: row.created_at.to_rfc3339(),
155
+
invite_note: None,
156
+
invites_disabled: false,
157
+
email_confirmed_at: None,
158
+
deactivated_at: None,
159
+
});
160
+
}
161
+
}
162
+
163
+
(StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response()
164
+
}
+15
src/api/admin/account/mod.rs
+15
src/api/admin/account/mod.rs
···
1
+
mod delete;
2
+
mod email;
3
+
mod info;
4
+
mod update;
5
+
6
+
pub use delete::{delete_account, DeleteAccountInput};
7
+
pub use email::{send_email, SendEmailInput, SendEmailOutput};
8
+
pub use info::{
9
+
get_account_info, get_account_infos, AccountInfo, GetAccountInfoParams, GetAccountInfosOutput,
10
+
GetAccountInfosParams,
11
+
};
12
+
pub use update::{
13
+
update_account_email, update_account_handle, update_account_password, UpdateAccountEmailInput,
14
+
UpdateAccountHandleInput, UpdateAccountPasswordInput,
15
+
};
+216
src/api/admin/account/update.rs
+216
src/api/admin/account/update.rs
···
1
+
use crate::state::AppState;
2
+
use axum::{
3
+
Json,
4
+
extract::State,
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
use serde::Deserialize;
9
+
use serde_json::json;
10
+
use tracing::error;
11
+
12
+
#[derive(Deserialize)]
13
+
pub struct UpdateAccountEmailInput {
14
+
pub account: String,
15
+
pub email: String,
16
+
}
17
+
18
+
pub async fn update_account_email(
19
+
State(state): State<AppState>,
20
+
headers: axum::http::HeaderMap,
21
+
Json(input): Json<UpdateAccountEmailInput>,
22
+
) -> Response {
23
+
let auth_header = headers.get("Authorization");
24
+
if auth_header.is_none() {
25
+
return (
26
+
StatusCode::UNAUTHORIZED,
27
+
Json(json!({"error": "AuthenticationRequired"})),
28
+
)
29
+
.into_response();
30
+
}
31
+
32
+
let account = input.account.trim();
33
+
let email = input.email.trim();
34
+
35
+
if account.is_empty() || email.is_empty() {
36
+
return (
37
+
StatusCode::BAD_REQUEST,
38
+
Json(json!({"error": "InvalidRequest", "message": "account and email are required"})),
39
+
)
40
+
.into_response();
41
+
}
42
+
43
+
let result = sqlx::query!("UPDATE users SET email = $1 WHERE did = $2", email, account)
44
+
.execute(&state.db)
45
+
.await;
46
+
47
+
match result {
48
+
Ok(r) => {
49
+
if r.rows_affected() == 0 {
50
+
return (
51
+
StatusCode::NOT_FOUND,
52
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
53
+
)
54
+
.into_response();
55
+
}
56
+
(StatusCode::OK, Json(json!({}))).into_response()
57
+
}
58
+
Err(e) => {
59
+
error!("DB error updating email: {:?}", e);
60
+
(
61
+
StatusCode::INTERNAL_SERVER_ERROR,
62
+
Json(json!({"error": "InternalError"})),
63
+
)
64
+
.into_response()
65
+
}
66
+
}
67
+
}
68
+
69
+
#[derive(Deserialize)]
70
+
pub struct UpdateAccountHandleInput {
71
+
pub did: String,
72
+
pub handle: String,
73
+
}
74
+
75
+
pub async fn update_account_handle(
76
+
State(state): State<AppState>,
77
+
headers: axum::http::HeaderMap,
78
+
Json(input): Json<UpdateAccountHandleInput>,
79
+
) -> Response {
80
+
let auth_header = headers.get("Authorization");
81
+
if auth_header.is_none() {
82
+
return (
83
+
StatusCode::UNAUTHORIZED,
84
+
Json(json!({"error": "AuthenticationRequired"})),
85
+
)
86
+
.into_response();
87
+
}
88
+
89
+
let did = input.did.trim();
90
+
let handle = input.handle.trim();
91
+
92
+
if did.is_empty() || handle.is_empty() {
93
+
return (
94
+
StatusCode::BAD_REQUEST,
95
+
Json(json!({"error": "InvalidRequest", "message": "did and handle are required"})),
96
+
)
97
+
.into_response();
98
+
}
99
+
100
+
if !handle
101
+
.chars()
102
+
.all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
103
+
{
104
+
return (
105
+
StatusCode::BAD_REQUEST,
106
+
Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
107
+
)
108
+
.into_response();
109
+
}
110
+
111
+
let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did)
112
+
.fetch_optional(&state.db)
113
+
.await;
114
+
115
+
if let Ok(Some(_)) = existing {
116
+
return (
117
+
StatusCode::BAD_REQUEST,
118
+
Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
119
+
)
120
+
.into_response();
121
+
}
122
+
123
+
let result = sqlx::query!("UPDATE users SET handle = $1 WHERE did = $2", handle, did)
124
+
.execute(&state.db)
125
+
.await;
126
+
127
+
match result {
128
+
Ok(r) => {
129
+
if r.rows_affected() == 0 {
130
+
return (
131
+
StatusCode::NOT_FOUND,
132
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
133
+
)
134
+
.into_response();
135
+
}
136
+
(StatusCode::OK, Json(json!({}))).into_response()
137
+
}
138
+
Err(e) => {
139
+
error!("DB error updating handle: {:?}", e);
140
+
(
141
+
StatusCode::INTERNAL_SERVER_ERROR,
142
+
Json(json!({"error": "InternalError"})),
143
+
)
144
+
.into_response()
145
+
}
146
+
}
147
+
}
148
+
149
+
#[derive(Deserialize)]
150
+
pub struct UpdateAccountPasswordInput {
151
+
pub did: String,
152
+
pub password: String,
153
+
}
154
+
155
+
pub async fn update_account_password(
156
+
State(state): State<AppState>,
157
+
headers: axum::http::HeaderMap,
158
+
Json(input): Json<UpdateAccountPasswordInput>,
159
+
) -> Response {
160
+
let auth_header = headers.get("Authorization");
161
+
if auth_header.is_none() {
162
+
return (
163
+
StatusCode::UNAUTHORIZED,
164
+
Json(json!({"error": "AuthenticationRequired"})),
165
+
)
166
+
.into_response();
167
+
}
168
+
169
+
let did = input.did.trim();
170
+
let password = input.password.trim();
171
+
172
+
if did.is_empty() || password.is_empty() {
173
+
return (
174
+
StatusCode::BAD_REQUEST,
175
+
Json(json!({"error": "InvalidRequest", "message": "did and password are required"})),
176
+
)
177
+
.into_response();
178
+
}
179
+
180
+
let password_hash = match bcrypt::hash(password, bcrypt::DEFAULT_COST) {
181
+
Ok(h) => h,
182
+
Err(e) => {
183
+
error!("Failed to hash password: {:?}", e);
184
+
return (
185
+
StatusCode::INTERNAL_SERVER_ERROR,
186
+
Json(json!({"error": "InternalError"})),
187
+
)
188
+
.into_response();
189
+
}
190
+
};
191
+
192
+
let result = sqlx::query!("UPDATE users SET password_hash = $1 WHERE did = $2", password_hash, did)
193
+
.execute(&state.db)
194
+
.await;
195
+
196
+
match result {
197
+
Ok(r) => {
198
+
if r.rows_affected() == 0 {
199
+
return (
200
+
StatusCode::NOT_FOUND,
201
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
202
+
)
203
+
.into_response();
204
+
}
205
+
(StatusCode::OK, Json(json!({}))).into_response()
206
+
}
207
+
Err(e) => {
208
+
error!("DB error updating password: {:?}", e);
209
+
(
210
+
StatusCode::INTERNAL_SERVER_ERROR,
211
+
Json(json!({"error": "InternalError"})),
212
+
)
213
+
.into_response()
214
+
}
215
+
}
216
+
}
+1
-1
src/api/admin/invite.rs
+1
-1
src/api/admin/invite.rs
+68
-14
src/api/admin/status.rs
+68
-14
src/api/admin/status.rs
···
234
234
Some("com.atproto.admin.defs#repoRef") => {
235
235
let did = input.subject.get("did").and_then(|d| d.as_str());
236
236
if let Some(did) = did {
237
+
let mut tx = match state.db.begin().await {
238
+
Ok(tx) => tx,
239
+
Err(e) => {
240
+
error!("Failed to begin transaction: {:?}", e);
241
+
return (
242
+
StatusCode::INTERNAL_SERVER_ERROR,
243
+
Json(json!({"error": "InternalError"})),
244
+
)
245
+
.into_response();
246
+
}
247
+
};
248
+
237
249
if let Some(takedown) = &input.takedown {
238
250
let takedown_ref = if takedown.apply {
239
251
takedown.r#ref.clone()
240
252
} else {
241
253
None
242
254
};
243
-
let _ = sqlx::query!(
255
+
if let Err(e) = sqlx::query!(
244
256
"UPDATE users SET takedown_ref = $1 WHERE did = $2",
245
257
takedown_ref,
246
258
did
247
259
)
248
-
.execute(&state.db)
249
-
.await;
260
+
.execute(&mut *tx)
261
+
.await
262
+
{
263
+
error!("Failed to update user takedown status for {}: {:?}", did, e);
264
+
return (
265
+
StatusCode::INTERNAL_SERVER_ERROR,
266
+
Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})),
267
+
)
268
+
.into_response();
269
+
}
250
270
}
251
271
252
272
if let Some(deactivated) = &input.deactivated {
253
-
if deactivated.apply {
254
-
let _ = sqlx::query!(
273
+
let result = if deactivated.apply {
274
+
sqlx::query!(
255
275
"UPDATE users SET deactivated_at = NOW() WHERE did = $1",
256
276
did
257
277
)
258
-
.execute(&state.db)
259
-
.await;
278
+
.execute(&mut *tx)
279
+
.await
260
280
} else {
261
-
let _ = sqlx::query!(
281
+
sqlx::query!(
262
282
"UPDATE users SET deactivated_at = NULL WHERE did = $1",
263
283
did
264
284
)
265
-
.execute(&state.db)
266
-
.await;
285
+
.execute(&mut *tx)
286
+
.await
287
+
};
288
+
289
+
if let Err(e) = result {
290
+
error!("Failed to update user deactivation status for {}: {:?}", did, e);
291
+
return (
292
+
StatusCode::INTERNAL_SERVER_ERROR,
293
+
Json(json!({"error": "InternalError", "message": "Failed to update deactivation status"})),
294
+
)
295
+
.into_response();
267
296
}
297
+
}
298
+
299
+
if let Err(e) = tx.commit().await {
300
+
error!("Failed to commit transaction: {:?}", e);
301
+
return (
302
+
StatusCode::INTERNAL_SERVER_ERROR,
303
+
Json(json!({"error": "InternalError"})),
304
+
)
305
+
.into_response();
268
306
}
269
307
270
308
return (
···
292
330
} else {
293
331
None
294
332
};
295
-
let _ = sqlx::query!(
333
+
if let Err(e) = sqlx::query!(
296
334
"UPDATE records SET takedown_ref = $1 WHERE record_cid = $2",
297
335
takedown_ref,
298
336
uri
299
337
)
300
338
.execute(&state.db)
301
-
.await;
339
+
.await
340
+
{
341
+
error!("Failed to update record takedown status for {}: {:?}", uri, e);
342
+
return (
343
+
StatusCode::INTERNAL_SERVER_ERROR,
344
+
Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})),
345
+
)
346
+
.into_response();
347
+
}
302
348
}
303
349
304
350
return (
···
323
369
} else {
324
370
None
325
371
};
326
-
let _ = sqlx::query!(
372
+
if let Err(e) = sqlx::query!(
327
373
"UPDATE blobs SET takedown_ref = $1 WHERE cid = $2",
328
374
takedown_ref,
329
375
cid
330
376
)
331
377
.execute(&state.db)
332
-
.await;
378
+
.await
379
+
{
380
+
error!("Failed to update blob takedown status for {}: {:?}", cid, e);
381
+
return (
382
+
StatusCode::INTERNAL_SERVER_ERROR,
383
+
Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})),
384
+
)
385
+
.into_response();
386
+
}
333
387
}
334
388
335
389
return (
+163
src/api/error.rs
+163
src/api/error.rs
···
1
+
use axum::{
2
+
Json,
3
+
http::StatusCode,
4
+
response::{IntoResponse, Response},
5
+
};
6
+
use serde::Serialize;
7
+
8
+
#[derive(Debug, Serialize)]
9
+
struct ErrorBody {
10
+
error: &'static str,
11
+
#[serde(skip_serializing_if = "Option::is_none")]
12
+
message: Option<String>,
13
+
}
14
+
15
+
#[derive(Debug)]
16
+
pub enum ApiError {
17
+
InternalError,
18
+
AuthenticationRequired,
19
+
AuthenticationFailed,
20
+
AuthenticationFailedMsg(String),
21
+
InvalidRequest(String),
22
+
InvalidToken,
23
+
ExpiredToken,
24
+
ExpiredTokenMsg(String),
25
+
TokenRequired,
26
+
AccountDeactivated,
27
+
AccountTakedown,
28
+
AccountNotFound,
29
+
RepoNotFound,
30
+
RepoNotFoundMsg(String),
31
+
RecordNotFound,
32
+
BlobNotFound,
33
+
InvalidHandle,
34
+
HandleNotAvailable,
35
+
HandleTaken,
36
+
InvalidEmail,
37
+
EmailTaken,
38
+
InvalidInviteCode,
39
+
DuplicateCreate,
40
+
DuplicateAppPassword,
41
+
AppPasswordNotFound,
42
+
InvalidSwap,
43
+
Forbidden,
44
+
InvitesDisabled,
45
+
DatabaseError,
46
+
UpstreamFailure,
47
+
}
48
+
49
+
impl ApiError {
50
+
fn status_code(&self) -> StatusCode {
51
+
match self {
52
+
Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => {
53
+
StatusCode::INTERNAL_SERVER_ERROR
54
+
}
55
+
Self::AuthenticationRequired
56
+
| Self::AuthenticationFailed
57
+
| Self::AuthenticationFailedMsg(_)
58
+
| Self::InvalidToken
59
+
| Self::ExpiredToken
60
+
| Self::ExpiredTokenMsg(_)
61
+
| Self::TokenRequired
62
+
| Self::AccountDeactivated
63
+
| Self::AccountTakedown => StatusCode::UNAUTHORIZED,
64
+
Self::Forbidden | Self::InvitesDisabled => StatusCode::FORBIDDEN,
65
+
Self::AccountNotFound
66
+
| Self::RepoNotFound
67
+
| Self::RepoNotFoundMsg(_)
68
+
| Self::RecordNotFound
69
+
| Self::BlobNotFound
70
+
| Self::AppPasswordNotFound => StatusCode::NOT_FOUND,
71
+
Self::InvalidRequest(_)
72
+
| Self::InvalidHandle
73
+
| Self::HandleNotAvailable
74
+
| Self::HandleTaken
75
+
| Self::InvalidEmail
76
+
| Self::EmailTaken
77
+
| Self::InvalidInviteCode
78
+
| Self::DuplicateCreate
79
+
| Self::DuplicateAppPassword
80
+
| Self::InvalidSwap => StatusCode::BAD_REQUEST,
81
+
}
82
+
}
83
+
84
+
fn error_name(&self) -> &'static str {
85
+
match self {
86
+
Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => "InternalError",
87
+
Self::AuthenticationRequired => "AuthenticationRequired",
88
+
Self::AuthenticationFailed | Self::AuthenticationFailedMsg(_) => "AuthenticationFailed",
89
+
Self::InvalidToken => "InvalidToken",
90
+
Self::ExpiredToken | Self::ExpiredTokenMsg(_) => "ExpiredToken",
91
+
Self::TokenRequired => "TokenRequired",
92
+
Self::AccountDeactivated => "AccountDeactivated",
93
+
Self::AccountTakedown => "AccountTakedown",
94
+
Self::Forbidden => "Forbidden",
95
+
Self::InvitesDisabled => "InvitesDisabled",
96
+
Self::AccountNotFound => "AccountNotFound",
97
+
Self::RepoNotFound | Self::RepoNotFoundMsg(_) => "RepoNotFound",
98
+
Self::RecordNotFound => "RecordNotFound",
99
+
Self::BlobNotFound => "BlobNotFound",
100
+
Self::AppPasswordNotFound => "AppPasswordNotFound",
101
+
Self::InvalidRequest(_) => "InvalidRequest",
102
+
Self::InvalidHandle => "InvalidHandle",
103
+
Self::HandleNotAvailable => "HandleNotAvailable",
104
+
Self::HandleTaken => "HandleTaken",
105
+
Self::InvalidEmail => "InvalidEmail",
106
+
Self::EmailTaken => "EmailTaken",
107
+
Self::InvalidInviteCode => "InvalidInviteCode",
108
+
Self::DuplicateCreate => "DuplicateCreate",
109
+
Self::DuplicateAppPassword => "DuplicateAppPassword",
110
+
Self::InvalidSwap => "InvalidSwap",
111
+
}
112
+
}
113
+
114
+
fn message(&self) -> Option<String> {
115
+
match self {
116
+
Self::AuthenticationFailedMsg(msg)
117
+
| Self::ExpiredTokenMsg(msg)
118
+
| Self::InvalidRequest(msg)
119
+
| Self::RepoNotFoundMsg(msg) => Some(msg.clone()),
120
+
_ => None,
121
+
}
122
+
}
123
+
}
124
+
125
+
impl IntoResponse for ApiError {
126
+
fn into_response(self) -> Response {
127
+
let body = ErrorBody {
128
+
error: self.error_name(),
129
+
message: self.message(),
130
+
};
131
+
(self.status_code(), Json(body)).into_response()
132
+
}
133
+
}
134
+
135
+
impl From<sqlx::Error> for ApiError {
136
+
fn from(e: sqlx::Error) -> Self {
137
+
tracing::error!("Database error: {:?}", e);
138
+
Self::DatabaseError
139
+
}
140
+
}
141
+
142
+
impl From<crate::auth::TokenValidationError> for ApiError {
143
+
fn from(e: crate::auth::TokenValidationError) -> Self {
144
+
match e {
145
+
crate::auth::TokenValidationError::AccountDeactivated => Self::AccountDeactivated,
146
+
crate::auth::TokenValidationError::AccountTakedown => Self::AccountTakedown,
147
+
crate::auth::TokenValidationError::KeyDecryptionFailed => Self::InternalError,
148
+
crate::auth::TokenValidationError::AuthenticationFailed => Self::AuthenticationFailed,
149
+
}
150
+
}
151
+
}
152
+
153
+
impl From<crate::util::DbLookupError> for ApiError {
154
+
fn from(e: crate::util::DbLookupError) -> Self {
155
+
match e {
156
+
crate::util::DbLookupError::NotFound => Self::AccountNotFound,
157
+
crate::util::DbLookupError::DatabaseError(db_err) => {
158
+
tracing::error!("Database error: {:?}", db_err);
159
+
Self::DatabaseError
160
+
}
161
+
}
162
+
}
163
+
}
+9
-1
src/api/identity/account.rs
+9
-1
src/api/identity/account.rs
···
40
40
State(state): State<AppState>,
41
41
Json(input): Json<CreateAccountInput>,
42
42
) -> Response {
43
-
info!("create_account hit: {}", input.handle);
43
+
info!("create_account called");
44
44
if input.handle.contains('!') || input.handle.contains('@') {
45
45
return (
46
46
StatusCode::BAD_REQUEST,
47
47
Json(
48
48
json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
49
49
),
50
+
)
51
+
.into_response();
52
+
}
53
+
54
+
if !crate::api::validation::is_valid_email(&input.email) {
55
+
return (
56
+
StatusCode::BAD_REQUEST,
57
+
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
50
58
)
51
59
.into_response();
52
60
}
+36
-72
src/api/identity/did.rs
+36
-72
src/api/identity/did.rs
···
1
+
use crate::api::ApiError;
1
2
use crate::state::AppState;
2
3
use axum::{
3
4
Json,
···
56
57
}
57
58
}
58
59
59
-
pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
60
-
let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
60
+
pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
61
+
let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
61
62
let public_key = secret_key.public_key();
62
63
let encoded = public_key.to_encoded_point(false);
63
-
let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
64
-
let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
64
+
let x = encoded.x().ok_or("Missing x coordinate")?;
65
+
let y = encoded.y().ok_or("Missing y coordinate")?;
66
+
let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
67
+
let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
65
68
66
-
json!({
69
+
Ok(json!({
67
70
"kty": "EC",
68
71
"crv": "secp256k1",
69
-
"x": x,
70
-
"y": y
71
-
})
72
+
"x": x_b64,
73
+
"y": y_b64
74
+
}))
72
75
}
73
76
74
77
pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
···
147
150
}
148
151
};
149
152
150
-
let jwk = get_jwk(&key_bytes);
153
+
let jwk = match get_jwk(&key_bytes) {
154
+
Ok(j) => j,
155
+
Err(e) => {
156
+
tracing::error!("Failed to generate JWK: {}", e);
157
+
return (
158
+
StatusCode::INTERNAL_SERVER_ERROR,
159
+
Json(json!({"error": "InternalError"})),
160
+
)
161
+
.into_response();
162
+
}
163
+
};
151
164
152
165
Json(json!({
153
166
"@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
···
294
307
}
295
308
};
296
309
297
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
298
-
let did = match auth_result {
299
-
Ok(ref user) => user.did.clone(),
300
-
Err(e) => {
301
-
return (
302
-
StatusCode::UNAUTHORIZED,
303
-
Json(json!({"error": e})),
304
-
)
305
-
.into_response();
306
-
}
310
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
311
+
Ok(user) => user,
312
+
Err(e) => return ApiError::from(e).into_response(),
307
313
};
308
314
309
-
let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", did)
315
+
let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did)
310
316
.fetch_optional(&state.db)
311
317
.await
312
318
{
313
319
Ok(Some(row)) => row,
314
-
_ => {
315
-
return (
316
-
StatusCode::INTERNAL_SERVER_ERROR,
317
-
Json(json!({"error": "InternalError"})),
318
-
)
319
-
.into_response();
320
-
}
320
+
_ => return ApiError::InternalError.into_response(),
321
321
};
322
-
let handle = user.handle;
323
322
324
-
let key_bytes = match auth_result.ok().and_then(|u| u.key_bytes) {
323
+
let key_bytes = match auth_user.key_bytes {
325
324
Some(kb) => kb,
326
-
None => {
327
-
return (
328
-
StatusCode::UNAUTHORIZED,
329
-
Json(json!({"error": "AuthenticationFailed", "message": "OAuth tokens cannot get DID credentials"})),
330
-
)
331
-
.into_response();
332
-
}
325
+
None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(),
333
326
};
334
327
335
328
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
···
337
330
338
331
let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
339
332
Ok(k) => k,
340
-
Err(_) => {
341
-
return (
342
-
StatusCode::INTERNAL_SERVER_ERROR,
343
-
Json(json!({"error": "InternalError"})),
344
-
)
345
-
.into_response();
346
-
}
333
+
Err(_) => return ApiError::InternalError.into_response(),
347
334
};
348
335
349
336
let public_key = secret_key.public_key();
···
360
347
StatusCode::OK,
361
348
Json(GetRecommendedDidCredentialsOutput {
362
349
rotation_keys: vec![did_key.clone()],
363
-
also_known_as: vec![format!("at://{}", handle)],
350
+
also_known_as: vec![format!("at://{}", user.handle)],
364
351
verification_methods: VerificationMethods { atproto: did_key },
365
352
services: Services {
366
353
atproto_pds: AtprotoPds {
···
387
374
headers.get("Authorization").and_then(|h| h.to_str().ok())
388
375
) {
389
376
Some(t) => t,
390
-
None => {
391
-
return (
392
-
StatusCode::UNAUTHORIZED,
393
-
Json(json!({"error": "AuthenticationRequired"})),
394
-
)
395
-
.into_response();
396
-
}
377
+
None => return ApiError::AuthenticationRequired.into_response(),
397
378
};
398
379
399
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
400
-
let did = match auth_result {
380
+
let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
401
381
Ok(user) => user.did,
402
-
Err(e) => {
403
-
return (
404
-
StatusCode::UNAUTHORIZED,
405
-
Json(json!({"error": e})),
406
-
)
407
-
.into_response();
408
-
}
382
+
Err(e) => return ApiError::from(e).into_response(),
409
383
};
410
384
411
385
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
413
387
.await
414
388
{
415
389
Ok(Some(id)) => id,
416
-
_ => {
417
-
return (
418
-
StatusCode::INTERNAL_SERVER_ERROR,
419
-
Json(json!({"error": "InternalError"})),
420
-
)
421
-
.into_response();
422
-
}
390
+
_ => return ApiError::InternalError.into_response(),
423
391
};
424
392
425
393
let new_handle = input.handle.trim();
426
394
if new_handle.is_empty() {
427
-
return (
428
-
StatusCode::BAD_REQUEST,
429
-
Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
430
-
)
431
-
.into_response();
395
+
return ApiError::InvalidRequest("handle is required".into()).into_response();
432
396
}
433
397
434
398
if !new_handle
-618
src/api/identity/plc.rs
-618
src/api/identity/plc.rs
···
1
-
use crate::plc::{
2
-
create_update_op, sign_operation, signing_key_to_did_key, validate_plc_operation,
3
-
PlcClient, PlcError, PlcService,
4
-
};
5
-
use crate::state::AppState;
6
-
use axum::{
7
-
extract::State,
8
-
http::StatusCode,
9
-
response::{IntoResponse, Response},
10
-
Json,
11
-
};
12
-
use chrono::{Duration, Utc};
13
-
use k256::ecdsa::SigningKey;
14
-
use rand::Rng;
15
-
use serde::{Deserialize, Serialize};
16
-
use serde_json::{json, Value};
17
-
use std::collections::HashMap;
18
-
use tracing::{error, info, warn};
19
-
20
-
fn generate_plc_token() -> String {
21
-
let mut rng = rand::thread_rng();
22
-
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
23
-
let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
24
-
let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
25
-
format!("{}-{}", part1, part2)
26
-
}
27
-
28
-
pub async fn request_plc_operation_signature(
29
-
State(state): State<AppState>,
30
-
headers: axum::http::HeaderMap,
31
-
) -> Response {
32
-
let token = match crate::auth::extract_bearer_token_from_header(
33
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
34
-
) {
35
-
Some(t) => t,
36
-
None => {
37
-
return (
38
-
StatusCode::UNAUTHORIZED,
39
-
Json(json!({"error": "AuthenticationRequired"})),
40
-
)
41
-
.into_response();
42
-
}
43
-
};
44
-
45
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
46
-
Ok(user) => user,
47
-
Err(e) => {
48
-
return (
49
-
StatusCode::UNAUTHORIZED,
50
-
Json(json!({"error": "AuthenticationFailed", "message": e})),
51
-
)
52
-
.into_response();
53
-
}
54
-
};
55
-
56
-
let did = &auth_user.did;
57
-
58
-
let user = match sqlx::query!(
59
-
"SELECT id FROM users WHERE did = $1",
60
-
did
61
-
)
62
-
.fetch_optional(&state.db)
63
-
.await
64
-
{
65
-
Ok(Some(row)) => row,
66
-
Ok(None) => {
67
-
return (
68
-
StatusCode::NOT_FOUND,
69
-
Json(json!({"error": "AccountNotFound"})),
70
-
)
71
-
.into_response();
72
-
}
73
-
Err(e) => {
74
-
error!("DB error: {:?}", e);
75
-
return (
76
-
StatusCode::INTERNAL_SERVER_ERROR,
77
-
Json(json!({"error": "InternalError"})),
78
-
)
79
-
.into_response();
80
-
}
81
-
};
82
-
83
-
let _ = sqlx::query!(
84
-
"DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()",
85
-
user.id
86
-
)
87
-
.execute(&state.db)
88
-
.await;
89
-
90
-
let plc_token = generate_plc_token();
91
-
let expires_at = Utc::now() + Duration::minutes(10);
92
-
93
-
if let Err(e) = sqlx::query!(
94
-
r#"
95
-
INSERT INTO plc_operation_tokens (user_id, token, expires_at)
96
-
VALUES ($1, $2, $3)
97
-
"#,
98
-
user.id,
99
-
plc_token,
100
-
expires_at
101
-
)
102
-
.execute(&state.db)
103
-
.await
104
-
{
105
-
error!("Failed to create PLC token: {:?}", e);
106
-
return (
107
-
StatusCode::INTERNAL_SERVER_ERROR,
108
-
Json(json!({"error": "InternalError"})),
109
-
)
110
-
.into_response();
111
-
}
112
-
113
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
114
-
115
-
if let Err(e) = crate::notifications::enqueue_plc_operation(
116
-
&state.db,
117
-
user.id,
118
-
&plc_token,
119
-
&hostname,
120
-
)
121
-
.await
122
-
{
123
-
warn!("Failed to enqueue PLC operation notification: {:?}", e);
124
-
}
125
-
126
-
info!("PLC operation signature requested for user {}", did);
127
-
128
-
(StatusCode::OK, Json(json!({}))).into_response()
129
-
}
130
-
131
-
#[derive(Debug, Deserialize)]
132
-
#[serde(rename_all = "camelCase")]
133
-
pub struct SignPlcOperationInput {
134
-
pub token: Option<String>,
135
-
pub rotation_keys: Option<Vec<String>>,
136
-
pub also_known_as: Option<Vec<String>>,
137
-
pub verification_methods: Option<HashMap<String, String>>,
138
-
pub services: Option<HashMap<String, ServiceInput>>,
139
-
}
140
-
141
-
#[derive(Debug, Deserialize, Clone)]
142
-
pub struct ServiceInput {
143
-
#[serde(rename = "type")]
144
-
pub service_type: String,
145
-
pub endpoint: String,
146
-
}
147
-
148
-
#[derive(Debug, Serialize)]
149
-
pub struct SignPlcOperationOutput {
150
-
pub operation: Value,
151
-
}
152
-
153
-
pub async fn sign_plc_operation(
154
-
State(state): State<AppState>,
155
-
headers: axum::http::HeaderMap,
156
-
Json(input): Json<SignPlcOperationInput>,
157
-
) -> Response {
158
-
let bearer = match crate::auth::extract_bearer_token_from_header(
159
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
160
-
) {
161
-
Some(t) => t,
162
-
None => {
163
-
return (
164
-
StatusCode::UNAUTHORIZED,
165
-
Json(json!({"error": "AuthenticationRequired"})),
166
-
)
167
-
.into_response();
168
-
}
169
-
};
170
-
171
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
172
-
Ok(user) => user,
173
-
Err(e) => {
174
-
return (
175
-
StatusCode::UNAUTHORIZED,
176
-
Json(json!({"error": "AuthenticationFailed", "message": e})),
177
-
)
178
-
.into_response();
179
-
}
180
-
};
181
-
182
-
let did = &auth_user.did;
183
-
184
-
let token = match &input.token {
185
-
Some(t) => t,
186
-
None => {
187
-
return (
188
-
StatusCode::BAD_REQUEST,
189
-
Json(json!({
190
-
"error": "InvalidRequest",
191
-
"message": "Email confirmation token required to sign PLC operations"
192
-
})),
193
-
)
194
-
.into_response();
195
-
}
196
-
};
197
-
198
-
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did)
199
-
.fetch_optional(&state.db)
200
-
.await
201
-
{
202
-
Ok(Some(row)) => row,
203
-
_ => {
204
-
return (
205
-
StatusCode::NOT_FOUND,
206
-
Json(json!({"error": "AccountNotFound"})),
207
-
)
208
-
.into_response();
209
-
}
210
-
};
211
-
212
-
let token_row = match sqlx::query!(
213
-
"SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2",
214
-
user.id,
215
-
token
216
-
)
217
-
.fetch_optional(&state.db)
218
-
.await
219
-
{
220
-
Ok(Some(row)) => row,
221
-
Ok(None) => {
222
-
return (
223
-
StatusCode::BAD_REQUEST,
224
-
Json(json!({
225
-
"error": "InvalidToken",
226
-
"message": "Invalid or expired token"
227
-
})),
228
-
)
229
-
.into_response();
230
-
}
231
-
Err(e) => {
232
-
error!("DB error: {:?}", e);
233
-
return (
234
-
StatusCode::INTERNAL_SERVER_ERROR,
235
-
Json(json!({"error": "InternalError"})),
236
-
)
237
-
.into_response();
238
-
}
239
-
};
240
-
241
-
if Utc::now() > token_row.expires_at {
242
-
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
243
-
.execute(&state.db)
244
-
.await;
245
-
return (
246
-
StatusCode::BAD_REQUEST,
247
-
Json(json!({
248
-
"error": "ExpiredToken",
249
-
"message": "Token has expired"
250
-
})),
251
-
)
252
-
.into_response();
253
-
}
254
-
255
-
let key_row = match sqlx::query!(
256
-
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
257
-
user.id
258
-
)
259
-
.fetch_optional(&state.db)
260
-
.await
261
-
{
262
-
Ok(Some(row)) => row,
263
-
_ => {
264
-
return (
265
-
StatusCode::INTERNAL_SERVER_ERROR,
266
-
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
267
-
)
268
-
.into_response();
269
-
}
270
-
};
271
-
272
-
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
273
-
{
274
-
Ok(k) => k,
275
-
Err(e) => {
276
-
error!("Failed to decrypt user key: {}", e);
277
-
return (
278
-
StatusCode::INTERNAL_SERVER_ERROR,
279
-
Json(json!({"error": "InternalError"})),
280
-
)
281
-
.into_response();
282
-
}
283
-
};
284
-
285
-
let signing_key = match SigningKey::from_slice(&key_bytes) {
286
-
Ok(k) => k,
287
-
Err(e) => {
288
-
error!("Failed to create signing key: {:?}", e);
289
-
return (
290
-
StatusCode::INTERNAL_SERVER_ERROR,
291
-
Json(json!({"error": "InternalError"})),
292
-
)
293
-
.into_response();
294
-
}
295
-
};
296
-
297
-
let plc_client = PlcClient::new(None);
298
-
let last_op = match plc_client.get_last_op(did).await {
299
-
Ok(op) => op,
300
-
Err(PlcError::NotFound) => {
301
-
return (
302
-
StatusCode::NOT_FOUND,
303
-
Json(json!({
304
-
"error": "NotFound",
305
-
"message": "DID not found in PLC directory"
306
-
})),
307
-
)
308
-
.into_response();
309
-
}
310
-
Err(e) => {
311
-
error!("Failed to fetch PLC operation: {:?}", e);
312
-
return (
313
-
StatusCode::BAD_GATEWAY,
314
-
Json(json!({
315
-
"error": "UpstreamError",
316
-
"message": "Failed to communicate with PLC directory"
317
-
})),
318
-
)
319
-
.into_response();
320
-
}
321
-
};
322
-
323
-
if last_op.is_tombstone() {
324
-
return (
325
-
StatusCode::BAD_REQUEST,
326
-
Json(json!({
327
-
"error": "InvalidRequest",
328
-
"message": "DID is tombstoned"
329
-
})),
330
-
)
331
-
.into_response();
332
-
}
333
-
334
-
let services = input.services.map(|s| {
335
-
s.into_iter()
336
-
.map(|(k, v)| {
337
-
(
338
-
k,
339
-
PlcService {
340
-
service_type: v.service_type,
341
-
endpoint: v.endpoint,
342
-
},
343
-
)
344
-
})
345
-
.collect()
346
-
});
347
-
348
-
let unsigned_op = match create_update_op(
349
-
&last_op,
350
-
input.rotation_keys,
351
-
input.verification_methods,
352
-
input.also_known_as,
353
-
services,
354
-
) {
355
-
Ok(op) => op,
356
-
Err(PlcError::Tombstoned) => {
357
-
return (
358
-
StatusCode::BAD_REQUEST,
359
-
Json(json!({
360
-
"error": "InvalidRequest",
361
-
"message": "Cannot update tombstoned DID"
362
-
})),
363
-
)
364
-
.into_response();
365
-
}
366
-
Err(e) => {
367
-
error!("Failed to create PLC operation: {:?}", e);
368
-
return (
369
-
StatusCode::INTERNAL_SERVER_ERROR,
370
-
Json(json!({"error": "InternalError"})),
371
-
)
372
-
.into_response();
373
-
}
374
-
};
375
-
376
-
let signed_op = match sign_operation(&unsigned_op, &signing_key) {
377
-
Ok(op) => op,
378
-
Err(e) => {
379
-
error!("Failed to sign PLC operation: {:?}", e);
380
-
return (
381
-
StatusCode::INTERNAL_SERVER_ERROR,
382
-
Json(json!({"error": "InternalError"})),
383
-
)
384
-
.into_response();
385
-
}
386
-
};
387
-
388
-
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
389
-
.execute(&state.db)
390
-
.await;
391
-
392
-
info!("Signed PLC operation for user {}", did);
393
-
394
-
(
395
-
StatusCode::OK,
396
-
Json(SignPlcOperationOutput {
397
-
operation: signed_op,
398
-
}),
399
-
)
400
-
.into_response()
401
-
}
402
-
403
-
#[derive(Debug, Deserialize)]
404
-
pub struct SubmitPlcOperationInput {
405
-
pub operation: Value,
406
-
}
407
-
408
-
pub async fn submit_plc_operation(
409
-
State(state): State<AppState>,
410
-
headers: axum::http::HeaderMap,
411
-
Json(input): Json<SubmitPlcOperationInput>,
412
-
) -> Response {
413
-
let bearer = match crate::auth::extract_bearer_token_from_header(
414
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
415
-
) {
416
-
Some(t) => t,
417
-
None => {
418
-
return (
419
-
StatusCode::UNAUTHORIZED,
420
-
Json(json!({"error": "AuthenticationRequired"})),
421
-
)
422
-
.into_response();
423
-
}
424
-
};
425
-
426
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
427
-
Ok(user) => user,
428
-
Err(e) => {
429
-
return (
430
-
StatusCode::UNAUTHORIZED,
431
-
Json(json!({"error": "AuthenticationFailed", "message": e})),
432
-
)
433
-
.into_response();
434
-
}
435
-
};
436
-
437
-
let did = &auth_user.did;
438
-
439
-
if let Err(e) = validate_plc_operation(&input.operation) {
440
-
return (
441
-
StatusCode::BAD_REQUEST,
442
-
Json(json!({
443
-
"error": "InvalidRequest",
444
-
"message": format!("Invalid operation: {}", e)
445
-
})),
446
-
)
447
-
.into_response();
448
-
}
449
-
450
-
let op = &input.operation;
451
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
452
-
let public_url = format!("https://{}", hostname);
453
-
454
-
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
455
-
.fetch_optional(&state.db)
456
-
.await
457
-
{
458
-
Ok(Some(row)) => row,
459
-
_ => {
460
-
return (
461
-
StatusCode::NOT_FOUND,
462
-
Json(json!({"error": "AccountNotFound"})),
463
-
)
464
-
.into_response();
465
-
}
466
-
};
467
-
468
-
let key_row = match sqlx::query!(
469
-
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
470
-
user.id
471
-
)
472
-
.fetch_optional(&state.db)
473
-
.await
474
-
{
475
-
Ok(Some(row)) => row,
476
-
_ => {
477
-
return (
478
-
StatusCode::INTERNAL_SERVER_ERROR,
479
-
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
480
-
)
481
-
.into_response();
482
-
}
483
-
};
484
-
485
-
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
486
-
{
487
-
Ok(k) => k,
488
-
Err(e) => {
489
-
error!("Failed to decrypt user key: {}", e);
490
-
return (
491
-
StatusCode::INTERNAL_SERVER_ERROR,
492
-
Json(json!({"error": "InternalError"})),
493
-
)
494
-
.into_response();
495
-
}
496
-
};
497
-
498
-
let signing_key = match SigningKey::from_slice(&key_bytes) {
499
-
Ok(k) => k,
500
-
Err(e) => {
501
-
error!("Failed to create signing key: {:?}", e);
502
-
return (
503
-
StatusCode::INTERNAL_SERVER_ERROR,
504
-
Json(json!({"error": "InternalError"})),
505
-
)
506
-
.into_response();
507
-
}
508
-
};
509
-
510
-
let user_did_key = signing_key_to_did_key(&signing_key);
511
-
512
-
if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) {
513
-
let server_rotation_key =
514
-
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
515
-
516
-
let has_server_key = rotation_keys
517
-
.iter()
518
-
.any(|k| k.as_str() == Some(&server_rotation_key));
519
-
520
-
if !has_server_key {
521
-
return (
522
-
StatusCode::BAD_REQUEST,
523
-
Json(json!({
524
-
"error": "InvalidRequest",
525
-
"message": "Rotation keys do not include server's rotation key"
526
-
})),
527
-
)
528
-
.into_response();
529
-
}
530
-
}
531
-
532
-
if let Some(services) = op.get("services").and_then(|v| v.as_object()) {
533
-
if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) {
534
-
let service_type = pds.get("type").and_then(|v| v.as_str());
535
-
let endpoint = pds.get("endpoint").and_then(|v| v.as_str());
536
-
537
-
if service_type != Some("AtprotoPersonalDataServer") {
538
-
return (
539
-
StatusCode::BAD_REQUEST,
540
-
Json(json!({
541
-
"error": "InvalidRequest",
542
-
"message": "Incorrect type on atproto_pds service"
543
-
})),
544
-
)
545
-
.into_response();
546
-
}
547
-
548
-
if endpoint != Some(&public_url) {
549
-
return (
550
-
StatusCode::BAD_REQUEST,
551
-
Json(json!({
552
-
"error": "InvalidRequest",
553
-
"message": "Incorrect endpoint on atproto_pds service"
554
-
})),
555
-
)
556
-
.into_response();
557
-
}
558
-
}
559
-
}
560
-
561
-
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) {
562
-
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
563
-
if atproto_key != user_did_key {
564
-
return (
565
-
StatusCode::BAD_REQUEST,
566
-
Json(json!({
567
-
"error": "InvalidRequest",
568
-
"message": "Incorrect signing key in verificationMethods"
569
-
})),
570
-
)
571
-
.into_response();
572
-
}
573
-
}
574
-
}
575
-
576
-
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
577
-
let expected_handle = format!("at://{}", user.handle);
578
-
let first_aka = also_known_as.first().and_then(|v| v.as_str());
579
-
580
-
if first_aka != Some(&expected_handle) {
581
-
return (
582
-
StatusCode::BAD_REQUEST,
583
-
Json(json!({
584
-
"error": "InvalidRequest",
585
-
"message": "Incorrect handle in alsoKnownAs"
586
-
})),
587
-
)
588
-
.into_response();
589
-
}
590
-
}
591
-
592
-
let plc_client = PlcClient::new(None);
593
-
if let Err(e) = plc_client.send_operation(did, &input.operation).await {
594
-
error!("Failed to submit PLC operation: {:?}", e);
595
-
return (
596
-
StatusCode::BAD_GATEWAY,
597
-
Json(json!({
598
-
"error": "UpstreamError",
599
-
"message": format!("Failed to submit to PLC directory: {}", e)
600
-
})),
601
-
)
602
-
.into_response();
603
-
}
604
-
605
-
if let Err(e) = sqlx::query!(
606
-
"INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')",
607
-
did
608
-
)
609
-
.execute(&state.db)
610
-
.await
611
-
{
612
-
warn!("Failed to sequence identity event: {:?}", e);
613
-
}
614
-
615
-
info!("Submitted PLC operation for user {}", did);
616
-
617
-
(StatusCode::OK, Json(json!({}))).into_response()
618
-
}
+7
src/api/identity/plc/mod.rs
+7
src/api/identity/plc/mod.rs
+91
src/api/identity/plc/request.rs
+91
src/api/identity/plc/request.rs
···
1
+
use crate::api::ApiError;
2
+
use crate::state::AppState;
3
+
use axum::{
4
+
extract::State,
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
Json,
8
+
};
9
+
use chrono::{Duration, Utc};
10
+
use serde_json::json;
11
+
use tracing::{error, info, warn};
12
+
13
+
fn generate_plc_token() -> String {
14
+
crate::util::generate_token_code()
15
+
}
16
+
17
+
pub async fn request_plc_operation_signature(
18
+
State(state): State<AppState>,
19
+
headers: axum::http::HeaderMap,
20
+
) -> Response {
21
+
let token = match crate::auth::extract_bearer_token_from_header(
22
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
23
+
) {
24
+
Some(t) => t,
25
+
None => return ApiError::AuthenticationRequired.into_response(),
26
+
};
27
+
28
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
29
+
Ok(user) => user,
30
+
Err(e) => return ApiError::from(e).into_response(),
31
+
};
32
+
33
+
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did)
34
+
.fetch_optional(&state.db)
35
+
.await
36
+
{
37
+
Ok(Some(row)) => row,
38
+
Ok(None) => return ApiError::AccountNotFound.into_response(),
39
+
Err(e) => {
40
+
error!("DB error: {:?}", e);
41
+
return ApiError::InternalError.into_response();
42
+
}
43
+
};
44
+
45
+
let _ = sqlx::query!(
46
+
"DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()",
47
+
user.id
48
+
)
49
+
.execute(&state.db)
50
+
.await;
51
+
52
+
let plc_token = generate_plc_token();
53
+
let expires_at = Utc::now() + Duration::minutes(10);
54
+
55
+
if let Err(e) = sqlx::query!(
56
+
r#"
57
+
INSERT INTO plc_operation_tokens (user_id, token, expires_at)
58
+
VALUES ($1, $2, $3)
59
+
"#,
60
+
user.id,
61
+
plc_token,
62
+
expires_at
63
+
)
64
+
.execute(&state.db)
65
+
.await
66
+
{
67
+
error!("Failed to create PLC token: {:?}", e);
68
+
return (
69
+
StatusCode::INTERNAL_SERVER_ERROR,
70
+
Json(json!({"error": "InternalError"})),
71
+
)
72
+
.into_response();
73
+
}
74
+
75
+
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
76
+
77
+
if let Err(e) = crate::notifications::enqueue_plc_operation(
78
+
&state.db,
79
+
user.id,
80
+
&plc_token,
81
+
&hostname,
82
+
)
83
+
.await
84
+
{
85
+
warn!("Failed to enqueue PLC operation notification: {:?}", e);
86
+
}
87
+
88
+
info!("PLC operation signature requested for user {}", auth_user.did);
89
+
90
+
(StatusCode::OK, Json(json!({}))).into_response()
91
+
}
+272
src/api/identity/plc/sign.rs
+272
src/api/identity/plc/sign.rs
···
1
+
use crate::api::ApiError;
2
+
use crate::plc::{
3
+
create_update_op, sign_operation, PlcClient, PlcError, PlcService,
4
+
};
5
+
use crate::state::AppState;
6
+
use axum::{
7
+
extract::State,
8
+
http::StatusCode,
9
+
response::{IntoResponse, Response},
10
+
Json,
11
+
};
12
+
use chrono::Utc;
13
+
use k256::ecdsa::SigningKey;
14
+
use serde::{Deserialize, Serialize};
15
+
use serde_json::{json, Value};
16
+
use std::collections::HashMap;
17
+
use tracing::{error, info};
18
+
19
+
#[derive(Debug, Deserialize)]
20
+
#[serde(rename_all = "camelCase")]
21
+
pub struct SignPlcOperationInput {
22
+
pub token: Option<String>,
23
+
pub rotation_keys: Option<Vec<String>>,
24
+
pub also_known_as: Option<Vec<String>>,
25
+
pub verification_methods: Option<HashMap<String, String>>,
26
+
pub services: Option<HashMap<String, ServiceInput>>,
27
+
}
28
+
29
+
#[derive(Debug, Deserialize, Clone)]
30
+
pub struct ServiceInput {
31
+
#[serde(rename = "type")]
32
+
pub service_type: String,
33
+
pub endpoint: String,
34
+
}
35
+
36
+
#[derive(Debug, Serialize)]
37
+
pub struct SignPlcOperationOutput {
38
+
pub operation: Value,
39
+
}
40
+
41
+
pub async fn sign_plc_operation(
42
+
State(state): State<AppState>,
43
+
headers: axum::http::HeaderMap,
44
+
Json(input): Json<SignPlcOperationInput>,
45
+
) -> Response {
46
+
let bearer = match crate::auth::extract_bearer_token_from_header(
47
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
48
+
) {
49
+
Some(t) => t,
50
+
None => return ApiError::AuthenticationRequired.into_response(),
51
+
};
52
+
53
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
54
+
Ok(user) => user,
55
+
Err(e) => return ApiError::from(e).into_response(),
56
+
};
57
+
58
+
let did = &auth_user.did;
59
+
60
+
let token = match &input.token {
61
+
Some(t) => t,
62
+
None => {
63
+
return ApiError::InvalidRequest(
64
+
"Email confirmation token required to sign PLC operations".into()
65
+
).into_response();
66
+
}
67
+
};
68
+
69
+
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did)
70
+
.fetch_optional(&state.db)
71
+
.await
72
+
{
73
+
Ok(Some(row)) => row,
74
+
_ => {
75
+
return (
76
+
StatusCode::NOT_FOUND,
77
+
Json(json!({"error": "AccountNotFound"})),
78
+
)
79
+
.into_response();
80
+
}
81
+
};
82
+
83
+
let token_row = match sqlx::query!(
84
+
"SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2",
85
+
user.id,
86
+
token
87
+
)
88
+
.fetch_optional(&state.db)
89
+
.await
90
+
{
91
+
Ok(Some(row)) => row,
92
+
Ok(None) => {
93
+
return (
94
+
StatusCode::BAD_REQUEST,
95
+
Json(json!({
96
+
"error": "InvalidToken",
97
+
"message": "Invalid or expired token"
98
+
})),
99
+
)
100
+
.into_response();
101
+
}
102
+
Err(e) => {
103
+
error!("DB error: {:?}", e);
104
+
return (
105
+
StatusCode::INTERNAL_SERVER_ERROR,
106
+
Json(json!({"error": "InternalError"})),
107
+
)
108
+
.into_response();
109
+
}
110
+
};
111
+
112
+
if Utc::now() > token_row.expires_at {
113
+
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
114
+
.execute(&state.db)
115
+
.await;
116
+
return (
117
+
StatusCode::BAD_REQUEST,
118
+
Json(json!({
119
+
"error": "ExpiredToken",
120
+
"message": "Token has expired"
121
+
})),
122
+
)
123
+
.into_response();
124
+
}
125
+
126
+
let key_row = match sqlx::query!(
127
+
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
128
+
user.id
129
+
)
130
+
.fetch_optional(&state.db)
131
+
.await
132
+
{
133
+
Ok(Some(row)) => row,
134
+
_ => {
135
+
return (
136
+
StatusCode::INTERNAL_SERVER_ERROR,
137
+
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
138
+
)
139
+
.into_response();
140
+
}
141
+
};
142
+
143
+
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
144
+
{
145
+
Ok(k) => k,
146
+
Err(e) => {
147
+
error!("Failed to decrypt user key: {}", e);
148
+
return (
149
+
StatusCode::INTERNAL_SERVER_ERROR,
150
+
Json(json!({"error": "InternalError"})),
151
+
)
152
+
.into_response();
153
+
}
154
+
};
155
+
156
+
let signing_key = match SigningKey::from_slice(&key_bytes) {
157
+
Ok(k) => k,
158
+
Err(e) => {
159
+
error!("Failed to create signing key: {:?}", e);
160
+
return (
161
+
StatusCode::INTERNAL_SERVER_ERROR,
162
+
Json(json!({"error": "InternalError"})),
163
+
)
164
+
.into_response();
165
+
}
166
+
};
167
+
168
+
let plc_client = PlcClient::new(None);
169
+
let last_op = match plc_client.get_last_op(did).await {
170
+
Ok(op) => op,
171
+
Err(PlcError::NotFound) => {
172
+
return (
173
+
StatusCode::NOT_FOUND,
174
+
Json(json!({
175
+
"error": "NotFound",
176
+
"message": "DID not found in PLC directory"
177
+
})),
178
+
)
179
+
.into_response();
180
+
}
181
+
Err(e) => {
182
+
error!("Failed to fetch PLC operation: {:?}", e);
183
+
return (
184
+
StatusCode::BAD_GATEWAY,
185
+
Json(json!({
186
+
"error": "UpstreamError",
187
+
"message": "Failed to communicate with PLC directory"
188
+
})),
189
+
)
190
+
.into_response();
191
+
}
192
+
};
193
+
194
+
if last_op.is_tombstone() {
195
+
return (
196
+
StatusCode::BAD_REQUEST,
197
+
Json(json!({
198
+
"error": "InvalidRequest",
199
+
"message": "DID is tombstoned"
200
+
})),
201
+
)
202
+
.into_response();
203
+
}
204
+
205
+
let services = input.services.map(|s| {
206
+
s.into_iter()
207
+
.map(|(k, v)| {
208
+
(
209
+
k,
210
+
PlcService {
211
+
service_type: v.service_type,
212
+
endpoint: v.endpoint,
213
+
},
214
+
)
215
+
})
216
+
.collect()
217
+
});
218
+
219
+
let unsigned_op = match create_update_op(
220
+
&last_op,
221
+
input.rotation_keys,
222
+
input.verification_methods,
223
+
input.also_known_as,
224
+
services,
225
+
) {
226
+
Ok(op) => op,
227
+
Err(PlcError::Tombstoned) => {
228
+
return (
229
+
StatusCode::BAD_REQUEST,
230
+
Json(json!({
231
+
"error": "InvalidRequest",
232
+
"message": "Cannot update tombstoned DID"
233
+
})),
234
+
)
235
+
.into_response();
236
+
}
237
+
Err(e) => {
238
+
error!("Failed to create PLC operation: {:?}", e);
239
+
return (
240
+
StatusCode::INTERNAL_SERVER_ERROR,
241
+
Json(json!({"error": "InternalError"})),
242
+
)
243
+
.into_response();
244
+
}
245
+
};
246
+
247
+
let signed_op = match sign_operation(&unsigned_op, &signing_key) {
248
+
Ok(op) => op,
249
+
Err(e) => {
250
+
error!("Failed to sign PLC operation: {:?}", e);
251
+
return (
252
+
StatusCode::INTERNAL_SERVER_ERROR,
253
+
Json(json!({"error": "InternalError"})),
254
+
)
255
+
.into_response();
256
+
}
257
+
};
258
+
259
+
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
260
+
.execute(&state.db)
261
+
.await;
262
+
263
+
info!("Signed PLC operation for user {}", did);
264
+
265
+
(
266
+
StatusCode::OK,
267
+
Json(SignPlcOperationOutput {
268
+
operation: signed_op,
269
+
}),
270
+
)
271
+
.into_response()
272
+
}
+211
src/api/identity/plc/submit.rs
+211
src/api/identity/plc/submit.rs
···
1
+
use crate::api::ApiError;
2
+
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient};
3
+
use crate::state::AppState;
4
+
use axum::{
5
+
extract::State,
6
+
http::StatusCode,
7
+
response::{IntoResponse, Response},
8
+
Json,
9
+
};
10
+
use k256::ecdsa::SigningKey;
11
+
use serde::Deserialize;
12
+
use serde_json::{json, Value};
13
+
use tracing::{error, info, warn};
14
+
15
+
#[derive(Debug, Deserialize)]
16
+
pub struct SubmitPlcOperationInput {
17
+
pub operation: Value,
18
+
}
19
+
20
+
pub async fn submit_plc_operation(
21
+
State(state): State<AppState>,
22
+
headers: axum::http::HeaderMap,
23
+
Json(input): Json<SubmitPlcOperationInput>,
24
+
) -> Response {
25
+
let bearer = match crate::auth::extract_bearer_token_from_header(
26
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
27
+
) {
28
+
Some(t) => t,
29
+
None => return ApiError::AuthenticationRequired.into_response(),
30
+
};
31
+
32
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
33
+
Ok(user) => user,
34
+
Err(e) => return ApiError::from(e).into_response(),
35
+
};
36
+
37
+
let did = &auth_user.did;
38
+
39
+
if let Err(e) = validate_plc_operation(&input.operation) {
40
+
return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response();
41
+
}
42
+
43
+
let op = &input.operation;
44
+
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
45
+
let public_url = format!("https://{}", hostname);
46
+
47
+
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
48
+
.fetch_optional(&state.db)
49
+
.await
50
+
{
51
+
Ok(Some(row)) => row,
52
+
_ => {
53
+
return (
54
+
StatusCode::NOT_FOUND,
55
+
Json(json!({"error": "AccountNotFound"})),
56
+
)
57
+
.into_response();
58
+
}
59
+
};
60
+
61
+
let key_row = match sqlx::query!(
62
+
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
63
+
user.id
64
+
)
65
+
.fetch_optional(&state.db)
66
+
.await
67
+
{
68
+
Ok(Some(row)) => row,
69
+
_ => {
70
+
return (
71
+
StatusCode::INTERNAL_SERVER_ERROR,
72
+
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
73
+
)
74
+
.into_response();
75
+
}
76
+
};
77
+
78
+
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
79
+
{
80
+
Ok(k) => k,
81
+
Err(e) => {
82
+
error!("Failed to decrypt user key: {}", e);
83
+
return (
84
+
StatusCode::INTERNAL_SERVER_ERROR,
85
+
Json(json!({"error": "InternalError"})),
86
+
)
87
+
.into_response();
88
+
}
89
+
};
90
+
91
+
let signing_key = match SigningKey::from_slice(&key_bytes) {
92
+
Ok(k) => k,
93
+
Err(e) => {
94
+
error!("Failed to create signing key: {:?}", e);
95
+
return (
96
+
StatusCode::INTERNAL_SERVER_ERROR,
97
+
Json(json!({"error": "InternalError"})),
98
+
)
99
+
.into_response();
100
+
}
101
+
};
102
+
103
+
let user_did_key = signing_key_to_did_key(&signing_key);
104
+
105
+
if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) {
106
+
let server_rotation_key =
107
+
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
108
+
109
+
let has_server_key = rotation_keys
110
+
.iter()
111
+
.any(|k| k.as_str() == Some(&server_rotation_key));
112
+
113
+
if !has_server_key {
114
+
return (
115
+
StatusCode::BAD_REQUEST,
116
+
Json(json!({
117
+
"error": "InvalidRequest",
118
+
"message": "Rotation keys do not include server's rotation key"
119
+
})),
120
+
)
121
+
.into_response();
122
+
}
123
+
}
124
+
125
+
if let Some(services) = op.get("services").and_then(|v| v.as_object()) {
126
+
if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) {
127
+
let service_type = pds.get("type").and_then(|v| v.as_str());
128
+
let endpoint = pds.get("endpoint").and_then(|v| v.as_str());
129
+
130
+
if service_type != Some("AtprotoPersonalDataServer") {
131
+
return (
132
+
StatusCode::BAD_REQUEST,
133
+
Json(json!({
134
+
"error": "InvalidRequest",
135
+
"message": "Incorrect type on atproto_pds service"
136
+
})),
137
+
)
138
+
.into_response();
139
+
}
140
+
141
+
if endpoint != Some(&public_url) {
142
+
return (
143
+
StatusCode::BAD_REQUEST,
144
+
Json(json!({
145
+
"error": "InvalidRequest",
146
+
"message": "Incorrect endpoint on atproto_pds service"
147
+
})),
148
+
)
149
+
.into_response();
150
+
}
151
+
}
152
+
}
153
+
154
+
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) {
155
+
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
156
+
if atproto_key != user_did_key {
157
+
return (
158
+
StatusCode::BAD_REQUEST,
159
+
Json(json!({
160
+
"error": "InvalidRequest",
161
+
"message": "Incorrect signing key in verificationMethods"
162
+
})),
163
+
)
164
+
.into_response();
165
+
}
166
+
}
167
+
}
168
+
169
+
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
170
+
let expected_handle = format!("at://{}", user.handle);
171
+
let first_aka = also_known_as.first().and_then(|v| v.as_str());
172
+
173
+
if first_aka != Some(&expected_handle) {
174
+
return (
175
+
StatusCode::BAD_REQUEST,
176
+
Json(json!({
177
+
"error": "InvalidRequest",
178
+
"message": "Incorrect handle in alsoKnownAs"
179
+
})),
180
+
)
181
+
.into_response();
182
+
}
183
+
}
184
+
185
+
let plc_client = PlcClient::new(None);
186
+
if let Err(e) = plc_client.send_operation(did, &input.operation).await {
187
+
error!("Failed to submit PLC operation: {:?}", e);
188
+
return (
189
+
StatusCode::BAD_GATEWAY,
190
+
Json(json!({
191
+
"error": "UpstreamError",
192
+
"message": format!("Failed to submit to PLC directory: {}", e)
193
+
})),
194
+
)
195
+
.into_response();
196
+
}
197
+
198
+
if let Err(e) = sqlx::query!(
199
+
"INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')",
200
+
did
201
+
)
202
+
.execute(&state.db)
203
+
.await
204
+
{
205
+
warn!("Failed to sequence identity event: {:?}", e);
206
+
}
207
+
208
+
info!("Submitted PLC operation for user {}", did);
209
+
210
+
(StatusCode::OK, Json(json!({}))).into_response()
211
+
}
+4
src/api/mod.rs
+4
src/api/mod.rs
+4
-16
src/api/moderation/mod.rs
+4
-16
src/api/moderation/mod.rs
···
1
+
use crate::api::ApiError;
1
2
use crate::state::AppState;
2
3
use axum::{
3
4
Json,
···
37
38
headers.get("Authorization").and_then(|h| h.to_str().ok())
38
39
) {
39
40
Some(t) => t,
40
-
None => {
41
-
return (
42
-
StatusCode::UNAUTHORIZED,
43
-
Json(json!({"error": "AuthenticationRequired"})),
44
-
)
45
-
.into_response();
46
-
}
41
+
None => return ApiError::AuthenticationRequired.into_response(),
47
42
};
48
43
49
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
50
-
let did = match auth_result {
44
+
let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
51
45
Ok(user) => user.did,
52
-
Err(e) => {
53
-
return (
54
-
StatusCode::UNAUTHORIZED,
55
-
Json(json!({"error": e})),
56
-
)
57
-
.into_response();
58
-
}
46
+
Err(e) => return ApiError::from(e).into_response(),
59
47
};
60
48
61
49
let valid_reason_types = [
+22
-2
src/api/repo/blob.rs
+22
-2
src/api/repo/blob.rs
···
15
15
use std::str::FromStr;
16
16
use tracing::error;
17
17
18
+
const MAX_BLOB_SIZE: usize = 1_000_000;
19
+
18
20
pub async fn upload_blob(
19
21
State(state): State<AppState>,
20
22
headers: axum::http::HeaderMap,
21
23
body: Bytes,
22
24
) -> Response {
25
+
if body.len() > MAX_BLOB_SIZE {
26
+
return (
27
+
StatusCode::PAYLOAD_TOO_LARGE,
28
+
Json(json!({"error": "BlobTooLarge", "message": format!("Blob size {} exceeds maximum of {} bytes", body.len(), MAX_BLOB_SIZE)})),
29
+
)
30
+
.into_response();
31
+
}
32
+
23
33
let token = match crate::auth::extract_bearer_token_from_header(
24
34
headers.get("Authorization").and_then(|h| h.to_str().ok())
25
35
) {
···
57
67
let mut hasher = Sha256::new();
58
68
hasher.update(&data);
59
69
let hash = hasher.finalize();
60
-
let multihash = Multihash::wrap(0x12, &hash).unwrap();
70
+
let multihash = match Multihash::wrap(0x12, &hash) {
71
+
Ok(mh) => mh,
72
+
Err(e) => {
73
+
error!("Failed to create multihash for blob: {:?}", e);
74
+
return (
75
+
StatusCode::INTERNAL_SERVER_ERROR,
76
+
Json(json!({"error": "InternalError", "message": "Failed to hash blob"})),
77
+
)
78
+
.into_response();
79
+
}
80
+
};
61
81
let cid = Cid::new_v1(0x55, multihash);
62
82
let cid_str = cid.to_string();
63
83
···
207
227
}
208
228
};
209
229
210
-
let limit = params.limit.unwrap_or(500).min(1000);
230
+
let limit = params.limit.unwrap_or(500).clamp(1, 1000);
211
231
let cursor_str = params.cursor.unwrap_or_default();
212
232
let (cursor_collection, cursor_rkey) = if cursor_str.contains('|') {
213
233
let parts: Vec<&str> = cursor_str.split('|').collect();
+3
-14
src/api/repo/import.rs
+3
-14
src/api/repo/import.rs
···
1
+
use crate::api::ApiError;
1
2
use crate::state::AppState;
2
3
use crate::sync::import::{apply_import, parse_car, ImportError};
3
4
use crate::sync::verify::CarVerifier;
···
54
55
headers.get("Authorization").and_then(|h| h.to_str().ok()),
55
56
) {
56
57
Some(t) => t,
57
-
None => {
58
-
return (
59
-
StatusCode::UNAUTHORIZED,
60
-
Json(json!({"error": "AuthenticationRequired"})),
61
-
)
62
-
.into_response();
63
-
}
58
+
None => return ApiError::AuthenticationRequired.into_response(),
64
59
};
65
60
66
61
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
67
62
Ok(user) => user,
68
-
Err(e) => {
69
-
return (
70
-
StatusCode::UNAUTHORIZED,
71
-
Json(json!({"error": "AuthenticationFailed", "message": e})),
72
-
)
73
-
.into_response();
74
-
}
63
+
Err(e) => return ApiError::from(e).into_response(),
75
64
};
76
65
77
66
let did = &auth_user.did;
+49
-13
src/api/repo/record/batch.rs
+49
-13
src/api/repo/record/batch.rs
···
17
17
use std::sync::Arc;
18
18
use tracing::error;
19
19
20
+
const MAX_BATCH_WRITES: usize = 200;
21
+
20
22
#[derive(Deserialize)]
21
23
#[serde(tag = "$type")]
22
24
pub enum WriteOp {
···
115
117
.into_response();
116
118
}
117
119
118
-
if input.writes.len() > 200 {
120
+
if input.writes.len() > MAX_BATCH_WRITES {
119
121
return (
120
122
StatusCode::BAD_REQUEST,
121
-
Json(json!({"error": "InvalidRequest", "message": "Too many writes (max 200)"})),
123
+
Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
122
124
)
123
125
.into_response();
124
126
}
···
213
215
.clone()
214
216
.unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
215
217
let mut record_bytes = Vec::new();
216
-
serde_ipld_dagcbor::to_writer(&mut record_bytes, value).unwrap();
217
-
let record_cid = tracking_store.put(&record_bytes).await.unwrap();
218
+
if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
219
+
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
220
+
}
221
+
let record_cid = match tracking_store.put(&record_bytes).await {
222
+
Ok(c) => c,
223
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
224
+
};
218
225
219
-
let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
220
-
mst = mst.add(&key, record_cid).await.unwrap();
226
+
let collection_nsid = match collection.parse::<Nsid>() {
227
+
Ok(n) => n,
228
+
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
229
+
};
230
+
let key = format!("{}/{}", collection_nsid, rkey);
231
+
mst = match mst.add(&key, record_cid).await {
232
+
Ok(m) => m,
233
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
234
+
};
221
235
222
236
let uri = format!("at://{}/{}/{}", did, collection, rkey);
223
237
results.push(WriteResult::CreateResult {
···
236
250
value,
237
251
} => {
238
252
let mut record_bytes = Vec::new();
239
-
serde_ipld_dagcbor::to_writer(&mut record_bytes, value).unwrap();
240
-
let record_cid = tracking_store.put(&record_bytes).await.unwrap();
253
+
if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
254
+
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
255
+
}
256
+
let record_cid = match tracking_store.put(&record_bytes).await {
257
+
Ok(c) => c,
258
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
259
+
};
241
260
242
-
let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
243
-
mst = mst.update(&key, record_cid).await.unwrap();
261
+
let collection_nsid = match collection.parse::<Nsid>() {
262
+
Ok(n) => n,
263
+
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
264
+
};
265
+
let key = format!("{}/{}", collection_nsid, rkey);
266
+
mst = match mst.update(&key, record_cid).await {
267
+
Ok(m) => m,
268
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
269
+
};
244
270
245
271
let uri = format!("at://{}/{}/{}", did, collection, rkey);
246
272
results.push(WriteResult::UpdateResult {
···
254
280
});
255
281
}
256
282
WriteOp::Delete { collection, rkey } => {
257
-
let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
258
-
mst = mst.delete(&key).await.unwrap();
283
+
let collection_nsid = match collection.parse::<Nsid>() {
284
+
Ok(n) => n,
285
+
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
286
+
};
287
+
let key = format!("{}/{}", collection_nsid, rkey);
288
+
mst = match mst.delete(&key).await {
289
+
Ok(m) => m,
290
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
291
+
};
259
292
260
293
results.push(WriteResult::DeleteResult {});
261
294
ops.push(RecordOp::Delete {
···
266
299
}
267
300
}
268
301
269
-
let new_mst_root = mst.persist().await.unwrap();
302
+
let new_mst_root = match mst.persist().await {
303
+
Ok(c) => c,
304
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
305
+
};
270
306
let written_cids = tracking_store.get_written_cids();
271
307
let written_cids_str = written_cids
272
308
.iter()
+11
-5
src/api/repo/record/utils.rs
+11
-5
src/api/repo/record/utils.rs
···
55
55
let new_root_cid = state.block_store.put(&new_commit_bytes).await
56
56
.map_err(|e| format!("Failed to save commit block: {:?}", e))?;
57
57
58
+
let mut tx = state.db.begin().await
59
+
.map_err(|e| format!("Failed to begin transaction: {}", e))?;
60
+
58
61
sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id)
59
-
.execute(&state.db)
62
+
.execute(&mut *tx)
60
63
.await
61
64
.map_err(|e| format!("DB Error (repos): {}", e))?;
62
65
···
71
74
rkey,
72
75
cid.to_string()
73
76
)
74
-
.execute(&state.db)
77
+
.execute(&mut *tx)
75
78
.await
76
79
.map_err(|e| format!("DB Error (records): {}", e))?;
77
80
}
···
82
85
collection,
83
86
rkey
84
87
)
85
-
.execute(&state.db)
88
+
.execute(&mut *tx)
86
89
.await
87
90
.map_err(|e| format!("DB Error (records): {}", e))?;
88
91
}
···
126
129
&[] as &[String],
127
130
blocks_cids,
128
131
)
129
-
.fetch_one(&state.db)
132
+
.fetch_one(&mut *tx)
130
133
.await
131
134
.map_err(|e| format!("DB Error (repo_seq): {}", e))?;
132
135
133
136
sqlx::query(
134
137
&format!("NOTIFY repo_updates, '{}'", seq_row.seq)
135
138
)
136
-
.execute(&state.db)
139
+
.execute(&mut *tx)
137
140
.await
138
141
.map_err(|e| format!("DB Error (notify): {}", e))?;
142
+
143
+
tx.commit().await
144
+
.map_err(|e| format!("Failed to commit transaction: {}", e))?;
139
145
140
146
Ok(CommitResult {
141
147
commit_cid: new_root_cid,
+12
-3
src/api/repo/record/write.rs
+12
-3
src/api/repo/record/write.rs
···
294
294
};
295
295
296
296
let new_mst = if existing_cid.is_some() {
297
-
mst.update(&key, record_cid).await.unwrap()
297
+
match mst.update(&key, record_cid).await {
298
+
Ok(m) => m,
299
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
300
+
}
298
301
} else {
299
-
mst.add(&key, record_cid).await.unwrap()
302
+
match mst.add(&key, record_cid).await {
303
+
Ok(m) => m,
304
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
305
+
}
306
+
};
307
+
let new_mst_root = match new_mst.persist().await {
308
+
Ok(c) => c,
309
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
300
310
};
301
-
let new_mst_root = new_mst.persist().await.unwrap();
302
311
303
312
let op = if existing_cid.is_some() {
304
313
RecordOp::Update { collection: input.collection.clone(), rkey: input.rkey.clone(), cid: record_cid }
+13
-64
src/api/server/account_status.rs
+13
-64
src/api/server/account_status.rs
···
1
+
use crate::api::ApiError;
1
2
use crate::state::AppState;
2
3
use axum::{
3
4
Json,
···
34
35
headers.get("Authorization").and_then(|h| h.to_str().ok())
35
36
) {
36
37
Some(t) => t,
37
-
None => {
38
-
return (
39
-
StatusCode::UNAUTHORIZED,
40
-
Json(json!({"error": "AuthenticationRequired"})),
41
-
)
42
-
.into_response();
43
-
}
38
+
None => return ApiError::AuthenticationRequired.into_response(),
44
39
};
45
40
46
-
let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await;
47
-
let did = match auth_result {
41
+
let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
48
42
Ok(user) => user.did,
49
-
Err(e) => {
50
-
return (
51
-
StatusCode::UNAUTHORIZED,
52
-
Json(json!({"error": e})),
53
-
)
54
-
.into_response();
55
-
}
43
+
Err(e) => return ApiError::from(e).into_response(),
56
44
};
57
45
58
46
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
127
115
headers.get("Authorization").and_then(|h| h.to_str().ok())
128
116
) {
129
117
Some(t) => t,
130
-
None => {
131
-
return (
132
-
StatusCode::UNAUTHORIZED,
133
-
Json(json!({"error": "AuthenticationRequired"})),
134
-
)
135
-
.into_response();
136
-
}
118
+
None => return ApiError::AuthenticationRequired.into_response(),
137
119
};
138
120
139
-
let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await;
140
-
let did = match auth_result {
121
+
let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
141
122
Ok(user) => user.did,
142
-
Err(e) => {
143
-
return (
144
-
StatusCode::UNAUTHORIZED,
145
-
Json(json!({"error": e})),
146
-
)
147
-
.into_response();
148
-
}
123
+
Err(e) => return ApiError::from(e).into_response(),
149
124
};
150
125
151
126
let result = sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did)
···
180
155
headers.get("Authorization").and_then(|h| h.to_str().ok())
181
156
) {
182
157
Some(t) => t,
183
-
None => {
184
-
return (
185
-
StatusCode::UNAUTHORIZED,
186
-
Json(json!({"error": "AuthenticationRequired"})),
187
-
)
188
-
.into_response();
189
-
}
158
+
None => return ApiError::AuthenticationRequired.into_response(),
190
159
};
191
160
192
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
193
-
let did = match auth_result {
161
+
let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
194
162
Ok(user) => user.did,
195
-
Err(e) => {
196
-
return (
197
-
StatusCode::UNAUTHORIZED,
198
-
Json(json!({"error": e})),
199
-
)
200
-
.into_response();
201
-
}
163
+
Err(e) => return ApiError::from(e).into_response(),
202
164
};
203
165
204
166
let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did)
···
226
188
headers.get("Authorization").and_then(|h| h.to_str().ok())
227
189
) {
228
190
Some(t) => t,
229
-
None => {
230
-
return (
231
-
StatusCode::UNAUTHORIZED,
232
-
Json(json!({"error": "AuthenticationRequired"})),
233
-
)
234
-
.into_response();
235
-
}
191
+
None => return ApiError::AuthenticationRequired.into_response(),
236
192
};
237
193
238
-
let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await;
239
-
let did = match auth_result {
194
+
let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
240
195
Ok(user) => user.did,
241
-
Err(e) => {
242
-
return (
243
-
StatusCode::UNAUTHORIZED,
244
-
Json(json!({"error": e})),
245
-
)
246
-
.into_response();
247
-
}
196
+
Err(e) => return ApiError::from(e).into_response(),
248
197
};
249
198
250
199
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
+63
-190
src/api/server/app_password.rs
+63
-190
src/api/server/app_password.rs
···
1
+
use crate::api::ApiError;
2
+
use crate::auth::BearerAuth;
1
3
use crate::state::AppState;
4
+
use crate::util::get_user_id_by_did;
2
5
use axum::{
3
6
Json,
4
7
extract::State,
5
-
http::StatusCode,
6
8
response::{IntoResponse, Response},
7
9
};
8
10
use serde::{Deserialize, Serialize};
···
24
26
25
27
pub async fn list_app_passwords(
26
28
State(state): State<AppState>,
27
-
headers: axum::http::HeaderMap,
29
+
BearerAuth(auth_user): BearerAuth,
28
30
) -> Response {
29
-
let token = match crate::auth::extract_bearer_token_from_header(
30
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
31
-
) {
32
-
Some(t) => t,
33
-
None => {
34
-
return (
35
-
StatusCode::UNAUTHORIZED,
36
-
Json(json!({"error": "AuthenticationRequired"})),
37
-
)
38
-
.into_response();
39
-
}
40
-
};
41
-
42
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
43
-
let did = match auth_result {
44
-
Ok(user) => user.did,
45
-
Err(e) => {
46
-
return (
47
-
StatusCode::UNAUTHORIZED,
48
-
Json(json!({"error": e})),
49
-
)
50
-
.into_response();
51
-
}
31
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
32
+
Ok(id) => id,
33
+
Err(e) => return ApiError::from(e).into_response(),
52
34
};
53
35
54
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
55
-
.fetch_optional(&state.db)
56
-
.await
36
+
match sqlx::query!(
37
+
"SELECT name, created_at, privileged FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC",
38
+
user_id
39
+
)
40
+
.fetch_all(&state.db)
41
+
.await
57
42
{
58
-
Ok(Some(id)) => id,
59
-
_ => {
60
-
return (
61
-
StatusCode::INTERNAL_SERVER_ERROR,
62
-
Json(json!({"error": "InternalError"})),
63
-
)
64
-
.into_response();
65
-
}
66
-
};
67
-
68
-
let result = sqlx::query!("SELECT name, created_at, privileged FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC", user_id)
69
-
.fetch_all(&state.db)
70
-
.await;
71
-
72
-
match result {
73
43
Ok(rows) => {
74
44
let passwords: Vec<AppPassword> = rows
75
45
.iter()
76
-
.map(|row| {
77
-
AppPassword {
78
-
name: row.name.clone(),
79
-
created_at: row.created_at.to_rfc3339(),
80
-
privileged: row.privileged,
81
-
}
46
+
.map(|row| AppPassword {
47
+
name: row.name.clone(),
48
+
created_at: row.created_at.to_rfc3339(),
49
+
privileged: row.privileged,
82
50
})
83
51
.collect();
84
52
85
-
(StatusCode::OK, Json(ListAppPasswordsOutput { passwords })).into_response()
53
+
Json(ListAppPasswordsOutput { passwords }).into_response()
86
54
}
87
55
Err(e) => {
88
56
error!("DB error listing app passwords: {:?}", e);
89
-
(
90
-
StatusCode::INTERNAL_SERVER_ERROR,
91
-
Json(json!({"error": "InternalError"})),
92
-
)
93
-
.into_response()
57
+
ApiError::InternalError.into_response()
94
58
}
95
59
}
96
60
}
···
112
76
113
77
pub async fn create_app_password(
114
78
State(state): State<AppState>,
115
-
headers: axum::http::HeaderMap,
79
+
BearerAuth(auth_user): BearerAuth,
116
80
Json(input): Json<CreateAppPasswordInput>,
117
81
) -> Response {
118
-
let token = match crate::auth::extract_bearer_token_from_header(
119
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
120
-
) {
121
-
Some(t) => t,
122
-
None => {
123
-
return (
124
-
StatusCode::UNAUTHORIZED,
125
-
Json(json!({"error": "AuthenticationRequired"})),
126
-
)
127
-
.into_response();
128
-
}
129
-
};
130
-
131
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
132
-
let did = match auth_result {
133
-
Ok(user) => user.did,
134
-
Err(e) => {
135
-
return (
136
-
StatusCode::UNAUTHORIZED,
137
-
Json(json!({"error": e})),
138
-
)
139
-
.into_response();
140
-
}
141
-
};
142
-
143
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
144
-
.fetch_optional(&state.db)
145
-
.await
146
-
{
147
-
Ok(Some(id)) => id,
148
-
_ => {
149
-
return (
150
-
StatusCode::INTERNAL_SERVER_ERROR,
151
-
Json(json!({"error": "InternalError"})),
152
-
)
153
-
.into_response();
154
-
}
82
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
83
+
Ok(id) => id,
84
+
Err(e) => return ApiError::from(e).into_response(),
155
85
};
156
86
157
87
let name = input.name.trim();
158
88
if name.is_empty() {
159
-
return (
160
-
StatusCode::BAD_REQUEST,
161
-
Json(json!({"error": "InvalidRequest", "message": "name is required"})),
162
-
)
163
-
.into_response();
89
+
return ApiError::InvalidRequest("name is required".into()).into_response();
164
90
}
165
91
166
-
let existing = sqlx::query!("SELECT id FROM app_passwords WHERE user_id = $1 AND name = $2", user_id, name)
167
-
.fetch_optional(&state.db)
168
-
.await;
92
+
let existing = sqlx::query!(
93
+
"SELECT id FROM app_passwords WHERE user_id = $1 AND name = $2",
94
+
user_id,
95
+
name
96
+
)
97
+
.fetch_optional(&state.db)
98
+
.await;
169
99
170
100
if let Ok(Some(_)) = existing {
171
-
return (
172
-
StatusCode::BAD_REQUEST,
173
-
Json(json!({"error": "DuplicateAppPassword", "message": "App password with this name already exists"})),
174
-
)
175
-
.into_response();
101
+
return ApiError::DuplicateAppPassword.into_response();
176
102
}
177
103
178
104
let password: String = (0..4)
···
180
106
use rand::Rng;
181
107
let mut rng = rand::thread_rng();
182
108
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
183
-
(0..4).map(|_| chars[rng.gen_range(0..chars.len())]).collect::<String>()
109
+
(0..4)
110
+
.map(|_| chars[rng.gen_range(0..chars.len())])
111
+
.collect::<String>()
184
112
})
185
113
.collect::<Vec<String>>()
186
114
.join("-");
···
189
117
Ok(h) => h,
190
118
Err(e) => {
191
119
error!("Failed to hash password: {:?}", e);
192
-
return (
193
-
StatusCode::INTERNAL_SERVER_ERROR,
194
-
Json(json!({"error": "InternalError"})),
195
-
)
196
-
.into_response();
120
+
return ApiError::InternalError.into_response();
197
121
}
198
122
};
199
123
200
124
let privileged = input.privileged.unwrap_or(false);
201
125
let created_at = chrono::Utc::now();
202
126
203
-
let result = sqlx::query!(
127
+
match sqlx::query!(
204
128
"INSERT INTO app_passwords (user_id, name, password_hash, created_at, privileged) VALUES ($1, $2, $3, $4, $5)",
205
129
user_id,
206
130
name,
···
209
133
privileged
210
134
)
211
135
.execute(&state.db)
212
-
.await;
213
-
214
-
match result {
215
-
Ok(_) => (
216
-
StatusCode::OK,
217
-
Json(CreateAppPasswordOutput {
218
-
name: name.to_string(),
219
-
password,
220
-
created_at: created_at.to_rfc3339(),
221
-
privileged,
222
-
}),
223
-
)
224
-
.into_response(),
136
+
.await
137
+
{
138
+
Ok(_) => Json(CreateAppPasswordOutput {
139
+
name: name.to_string(),
140
+
password,
141
+
created_at: created_at.to_rfc3339(),
142
+
privileged,
143
+
})
144
+
.into_response(),
225
145
Err(e) => {
226
146
error!("DB error creating app password: {:?}", e);
227
-
(
228
-
StatusCode::INTERNAL_SERVER_ERROR,
229
-
Json(json!({"error": "InternalError"})),
230
-
)
231
-
.into_response()
147
+
ApiError::InternalError.into_response()
232
148
}
233
149
}
234
150
}
···
240
156
241
157
pub async fn revoke_app_password(
242
158
State(state): State<AppState>,
243
-
headers: axum::http::HeaderMap,
159
+
BearerAuth(auth_user): BearerAuth,
244
160
Json(input): Json<RevokeAppPasswordInput>,
245
161
) -> Response {
246
-
let token = match crate::auth::extract_bearer_token_from_header(
247
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
248
-
) {
249
-
Some(t) => t,
250
-
None => {
251
-
return (
252
-
StatusCode::UNAUTHORIZED,
253
-
Json(json!({"error": "AuthenticationRequired"})),
254
-
)
255
-
.into_response();
256
-
}
257
-
};
258
-
259
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
260
-
let did = match auth_result {
261
-
Ok(user) => user.did,
262
-
Err(e) => {
263
-
return (
264
-
StatusCode::UNAUTHORIZED,
265
-
Json(json!({"error": e})),
266
-
)
267
-
.into_response();
268
-
}
269
-
};
270
-
271
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
272
-
.fetch_optional(&state.db)
273
-
.await
274
-
{
275
-
Ok(Some(id)) => id,
276
-
_ => {
277
-
return (
278
-
StatusCode::INTERNAL_SERVER_ERROR,
279
-
Json(json!({"error": "InternalError"})),
280
-
)
281
-
.into_response();
282
-
}
162
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
163
+
Ok(id) => id,
164
+
Err(e) => return ApiError::from(e).into_response(),
283
165
};
284
166
285
167
let name = input.name.trim();
286
168
if name.is_empty() {
287
-
return (
288
-
StatusCode::BAD_REQUEST,
289
-
Json(json!({"error": "InvalidRequest", "message": "name is required"})),
290
-
)
291
-
.into_response();
169
+
return ApiError::InvalidRequest("name is required".into()).into_response();
292
170
}
293
171
294
-
let result = sqlx::query!("DELETE FROM app_passwords WHERE user_id = $1 AND name = $2", user_id, name)
295
-
.execute(&state.db)
296
-
.await;
297
-
298
-
match result {
172
+
match sqlx::query!(
173
+
"DELETE FROM app_passwords WHERE user_id = $1 AND name = $2",
174
+
user_id,
175
+
name
176
+
)
177
+
.execute(&state.db)
178
+
.await
179
+
{
299
180
Ok(r) => {
300
181
if r.rows_affected() == 0 {
301
-
return (
302
-
StatusCode::NOT_FOUND,
303
-
Json(json!({"error": "AppPasswordNotFound", "message": "App password not found"})),
304
-
)
305
-
.into_response();
182
+
return ApiError::AppPasswordNotFound.into_response();
306
183
}
307
-
(StatusCode::OK, Json(json!({}))).into_response()
184
+
Json(json!({})).into_response()
308
185
}
309
186
Err(e) => {
310
187
error!("DB error revoking app password: {:?}", e);
311
-
(
312
-
StatusCode::INTERNAL_SERVER_ERROR,
313
-
Json(json!({"error": "InternalError"})),
314
-
)
315
-
.into_response()
188
+
ApiError::InternalError.into_response()
316
189
}
317
190
}
318
191
}
+47
-54
src/api/server/email.rs
+47
-54
src/api/server/email.rs
···
1
+
use crate::api::ApiError;
1
2
use crate::state::AppState;
2
3
use axum::{
3
4
Json,
···
6
7
response::{IntoResponse, Response},
7
8
};
8
9
use chrono::{Duration, Utc};
9
-
use rand::Rng;
10
10
use serde::Deserialize;
11
11
use serde_json::json;
12
12
use tracing::{error, info, warn};
13
13
14
14
fn generate_confirmation_code() -> String {
15
-
let mut rng = rand::thread_rng();
16
-
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
17
-
let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
18
-
let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
19
-
format!("{}-{}", part1, part2)
15
+
crate::util::generate_token_code()
20
16
}
21
17
22
18
#[derive(Deserialize)]
···
46
42
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
47
43
let did = match auth_result {
48
44
Ok(user) => user.did,
49
-
Err(e) => {
50
-
return (
51
-
StatusCode::UNAUTHORIZED,
52
-
Json(json!({"error": e})),
53
-
)
54
-
.into_response();
55
-
}
45
+
Err(e) => return ApiError::from(e).into_response(),
56
46
};
57
47
58
48
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
···
72
62
let handle = user.handle;
73
63
74
64
let email = input.email.trim().to_lowercase();
75
-
if email.is_empty() {
65
+
if !crate::api::validation::is_valid_email(&email) {
76
66
return (
77
67
StatusCode::BAD_REQUEST,
78
-
Json(json!({"error": "InvalidRequest", "message": "email is required"})),
68
+
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
79
69
)
80
70
.into_response();
81
71
}
···
161
151
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
162
152
let did = match auth_result {
163
153
Ok(user) => user.did,
164
-
Err(e) => {
165
-
return (
166
-
StatusCode::UNAUTHORIZED,
167
-
Json(json!({"error": e})),
168
-
)
169
-
.into_response();
170
-
}
154
+
Err(e) => return ApiError::from(e).into_response(),
171
155
};
172
156
173
157
let user = match sqlx::query!(
···
194
178
let email = input.email.trim().to_lowercase();
195
179
let confirmation_code = input.token.trim();
196
180
197
-
if email_pending_verification.is_none() || stored_code.is_none() || expires_at.is_none() {
198
-
return (
199
-
StatusCode::BAD_REQUEST,
200
-
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
201
-
)
202
-
.into_response();
203
-
}
181
+
let (pending_email, saved_code, expiry) = match (email_pending_verification, stored_code, expires_at) {
182
+
(Some(p), Some(c), Some(e)) => (p, c, e),
183
+
_ => {
184
+
return (
185
+
StatusCode::BAD_REQUEST,
186
+
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
187
+
)
188
+
.into_response();
189
+
}
190
+
};
204
191
205
-
let email_pending_verification = email_pending_verification.unwrap();
206
-
if email_pending_verification != email {
192
+
if pending_email != email {
207
193
return (
208
194
StatusCode::BAD_REQUEST,
209
195
Json(json!({"error": "InvalidRequest", "message": "Email does not match pending update"})),
···
211
197
.into_response();
212
198
}
213
199
214
-
if stored_code.unwrap() != confirmation_code {
200
+
if saved_code != confirmation_code {
215
201
return (
216
202
StatusCode::BAD_REQUEST,
217
203
Json(json!({"error": "InvalidToken", "message": "Invalid token"})),
···
219
205
.into_response();
220
206
}
221
207
222
-
if Utc::now() > expires_at.unwrap() {
208
+
if Utc::now() > expiry {
223
209
return (
224
210
StatusCode::BAD_REQUEST,
225
211
Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
···
229
215
230
216
let update = sqlx::query!(
231
217
"UPDATE users SET email = $1, email_pending_verification = NULL, email_confirmation_code = NULL, email_confirmation_code_expires_at = NULL WHERE id = $2",
232
-
email_pending_verification,
218
+
pending_email,
233
219
user_id
234
220
)
235
221
.execute(&state.db)
···
287
273
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
288
274
let did = match auth_result {
289
275
Ok(user) => user.did,
290
-
Err(e) => {
291
-
return (
292
-
StatusCode::UNAUTHORIZED,
293
-
Json(json!({"error": e})),
294
-
)
295
-
.into_response();
296
-
}
276
+
Err(e) => return ApiError::from(e).into_response(),
297
277
};
298
278
299
279
let user = match sqlx::query!(
···
319
299
let email_pending_verification = user.email_pending_verification;
320
300
321
301
let new_email = input.email.trim().to_lowercase();
322
-
if new_email.is_empty() {
302
+
if !crate::api::validation::is_valid_email(&new_email) {
323
303
return (
324
304
StatusCode::BAD_REQUEST,
325
-
Json(json!({"error": "InvalidRequest", "message": "email is required"})),
326
-
)
327
-
.into_response();
328
-
}
329
-
330
-
if !new_email.contains('@') || !new_email.contains('.') {
331
-
return (
332
-
StatusCode::BAD_REQUEST,
333
-
Json(json!({"error": "InvalidRequest", "message": "Invalid email format"})),
305
+
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
334
306
)
335
307
.into_response();
336
308
}
···
353
325
}
354
326
};
355
327
356
-
let pending_email = email_pending_verification.unwrap();
328
+
let pending_email = match email_pending_verification {
329
+
Some(p) => p,
330
+
None => {
331
+
return (
332
+
StatusCode::BAD_REQUEST,
333
+
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
334
+
)
335
+
.into_response();
336
+
}
337
+
};
338
+
357
339
if pending_email.to_lowercase() != new_email {
358
340
return (
359
341
StatusCode::BAD_REQUEST,
···
362
344
.into_response();
363
345
}
364
346
365
-
if stored_code.unwrap() != confirmation_token {
347
+
let saved_code = match stored_code {
348
+
Some(c) => c,
349
+
None => {
350
+
return (
351
+
StatusCode::BAD_REQUEST,
352
+
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
353
+
)
354
+
.into_response();
355
+
}
356
+
};
357
+
358
+
if saved_code != confirmation_token {
366
359
return (
367
360
StatusCode::BAD_REQUEST,
368
361
Json(json!({"error": "InvalidToken", "message": "Invalid token"})),
···
415
408
416
409
match update {
417
410
Ok(_) => {
418
-
info!("Email updated to {} for user {}", new_email, user_id);
411
+
info!("Email updated for user {}", user_id);
419
412
(StatusCode::OK, Json(json!({}))).into_response()
420
413
}
421
414
Err(e) => {
+61
-209
src/api/server/invite.rs
+61
-209
src/api/server/invite.rs
···
1
+
use crate::api::ApiError;
2
+
use crate::auth::BearerAuth;
1
3
use crate::state::AppState;
4
+
use crate::util::get_user_id_by_did;
2
5
use axum::{
3
6
Json,
4
7
extract::State,
5
-
http::StatusCode,
6
8
response::{IntoResponse, Response},
7
9
};
8
10
use serde::{Deserialize, Serialize};
9
-
use serde_json::json;
10
11
use tracing::error;
11
12
use uuid::Uuid;
12
13
···
24
25
25
26
pub async fn create_invite_code(
26
27
State(state): State<AppState>,
27
-
headers: axum::http::HeaderMap,
28
+
BearerAuth(auth_user): BearerAuth,
28
29
Json(input): Json<CreateInviteCodeInput>,
29
30
) -> Response {
30
-
let token = match crate::auth::extract_bearer_token_from_header(
31
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
32
-
) {
33
-
Some(t) => t,
34
-
None => {
35
-
return (
36
-
StatusCode::UNAUTHORIZED,
37
-
Json(json!({"error": "AuthenticationRequired"})),
38
-
)
39
-
.into_response();
40
-
}
41
-
};
42
-
43
31
if input.use_count < 1 {
44
-
return (
45
-
StatusCode::BAD_REQUEST,
46
-
Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
47
-
)
48
-
.into_response();
32
+
return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
49
33
}
50
34
51
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
52
-
let did = match auth_result {
53
-
Ok(user) => user.did,
54
-
Err(e) => {
55
-
return (
56
-
StatusCode::UNAUTHORIZED,
57
-
Json(json!({"error": e})),
58
-
)
59
-
.into_response();
60
-
}
61
-
};
62
-
63
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
64
-
.fetch_optional(&state.db)
65
-
.await
66
-
{
67
-
Ok(Some(id)) => id,
68
-
_ => {
69
-
return (
70
-
StatusCode::INTERNAL_SERVER_ERROR,
71
-
Json(json!({"error": "InternalError"})),
72
-
)
73
-
.into_response();
74
-
}
35
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
36
+
Ok(id) => id,
37
+
Err(e) => return ApiError::from(e).into_response(),
75
38
};
76
39
77
40
let creator_user_id = if let Some(for_account) = &input.for_account {
78
-
let target = sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
41
+
match sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
79
42
.fetch_optional(&state.db)
80
-
.await;
81
-
82
-
match target {
43
+
.await
44
+
{
83
45
Ok(Some(row)) => row.id,
84
-
Ok(None) => {
85
-
return (
86
-
StatusCode::NOT_FOUND,
87
-
Json(json!({"error": "AccountNotFound", "message": "Target account not found"})),
88
-
)
89
-
.into_response();
90
-
}
46
+
Ok(None) => return ApiError::AccountNotFound.into_response(),
91
47
Err(e) => {
92
48
error!("DB error looking up target account: {:?}", e);
93
-
return (
94
-
StatusCode::INTERNAL_SERVER_ERROR,
95
-
Json(json!({"error": "InternalError"})),
96
-
)
97
-
.into_response();
49
+
return ApiError::InternalError.into_response();
98
50
}
99
51
}
100
52
} else {
···
103
55
104
56
let user_invites_disabled = sqlx::query_scalar!(
105
57
"SELECT invites_disabled FROM users WHERE did = $1",
106
-
did
58
+
auth_user.did
107
59
)
108
60
.fetch_optional(&state.db)
109
61
.await
62
+
.map_err(|e| {
63
+
error!("DB error checking invites_disabled: {:?}", e);
64
+
ApiError::InternalError
65
+
})
110
66
.ok()
111
67
.flatten()
112
68
.flatten()
113
69
.unwrap_or(false);
114
70
115
71
if user_invites_disabled {
116
-
return (
117
-
StatusCode::FORBIDDEN,
118
-
Json(json!({"error": "InvitesDisabled", "message": "Invites are disabled for this account"})),
119
-
)
120
-
.into_response();
72
+
return ApiError::InvitesDisabled.into_response();
121
73
}
122
74
123
75
let code = Uuid::new_v4().to_string();
124
76
125
-
let result = sqlx::query!(
77
+
match sqlx::query!(
126
78
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
127
79
code,
128
80
input.use_count,
129
81
creator_user_id
130
82
)
131
83
.execute(&state.db)
132
-
.await;
133
-
134
-
match result {
135
-
Ok(_) => (StatusCode::OK, Json(CreateInviteCodeOutput { code })).into_response(),
84
+
.await
85
+
{
86
+
Ok(_) => Json(CreateInviteCodeOutput { code }).into_response(),
136
87
Err(e) => {
137
88
error!("DB error creating invite code: {:?}", e);
138
-
(
139
-
StatusCode::INTERNAL_SERVER_ERROR,
140
-
Json(json!({"error": "InternalError"})),
141
-
)
142
-
.into_response()
89
+
ApiError::InternalError.into_response()
143
90
}
144
91
}
145
92
}
···
165
112
166
113
pub async fn create_invite_codes(
167
114
State(state): State<AppState>,
168
-
headers: axum::http::HeaderMap,
115
+
BearerAuth(auth_user): BearerAuth,
169
116
Json(input): Json<CreateInviteCodesInput>,
170
117
) -> Response {
171
-
let token = match crate::auth::extract_bearer_token_from_header(
172
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
173
-
) {
174
-
Some(t) => t,
175
-
None => {
176
-
return (
177
-
StatusCode::UNAUTHORIZED,
178
-
Json(json!({"error": "AuthenticationRequired"})),
179
-
)
180
-
.into_response();
181
-
}
182
-
};
183
-
184
118
if input.use_count < 1 {
185
-
return (
186
-
StatusCode::BAD_REQUEST,
187
-
Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
188
-
)
189
-
.into_response();
119
+
return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
190
120
}
191
121
192
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
193
-
let did = match auth_result {
194
-
Ok(user) => user.did,
195
-
Err(e) => {
196
-
return (
197
-
StatusCode::UNAUTHORIZED,
198
-
Json(json!({"error": e})),
199
-
)
200
-
.into_response();
201
-
}
202
-
};
203
-
204
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
205
-
.fetch_optional(&state.db)
206
-
.await
207
-
{
208
-
Ok(Some(id)) => id,
209
-
_ => {
210
-
return (
211
-
StatusCode::INTERNAL_SERVER_ERROR,
212
-
Json(json!({"error": "InternalError"})),
213
-
)
214
-
.into_response();
215
-
}
122
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
123
+
Ok(id) => id,
124
+
Err(e) => return ApiError::from(e).into_response(),
216
125
};
217
126
218
127
let code_count = input.code_count.unwrap_or(1).max(1);
···
225
134
for _ in 0..code_count {
226
135
let code = Uuid::new_v4().to_string();
227
136
228
-
let insert = sqlx::query!(
137
+
if let Err(e) = sqlx::query!(
229
138
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
230
139
code,
231
140
input.use_count,
232
141
user_id
233
142
)
234
143
.execute(&state.db)
235
-
.await;
236
-
237
-
if let Err(e) = insert {
144
+
.await
145
+
{
238
146
error!("DB error creating invite code: {:?}", e);
239
-
return (
240
-
StatusCode::INTERNAL_SERVER_ERROR,
241
-
Json(json!({"error": "InternalError"})),
242
-
)
243
-
.into_response();
147
+
return ApiError::InternalError.into_response();
244
148
}
245
149
246
150
codes.push(code);
···
252
156
});
253
157
} else {
254
158
for account_did in for_accounts {
255
-
let target = sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
159
+
let target_user_id = match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
256
160
.fetch_optional(&state.db)
257
-
.await;
258
-
259
-
let target_user_id = match target {
161
+
.await
162
+
{
260
163
Ok(Some(row)) => row.id,
261
-
Ok(None) => {
262
-
continue;
263
-
}
164
+
Ok(None) => continue,
264
165
Err(e) => {
265
166
error!("DB error looking up target account: {:?}", e);
266
-
return (
267
-
StatusCode::INTERNAL_SERVER_ERROR,
268
-
Json(json!({"error": "InternalError"})),
269
-
)
270
-
.into_response();
167
+
return ApiError::InternalError.into_response();
271
168
}
272
169
};
273
170
···
275
172
for _ in 0..code_count {
276
173
let code = Uuid::new_v4().to_string();
277
174
278
-
let insert = sqlx::query!(
175
+
if let Err(e) = sqlx::query!(
279
176
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
280
177
code,
281
178
input.use_count,
282
179
target_user_id
283
180
)
284
181
.execute(&state.db)
285
-
.await;
286
-
287
-
if let Err(e) = insert {
182
+
.await
183
+
{
288
184
error!("DB error creating invite code: {:?}", e);
289
-
return (
290
-
StatusCode::INTERNAL_SERVER_ERROR,
291
-
Json(json!({"error": "InternalError"})),
292
-
)
293
-
.into_response();
185
+
return ApiError::InternalError.into_response();
294
186
}
295
187
296
188
codes.push(code);
···
303
195
}
304
196
}
305
197
306
-
(StatusCode::OK, Json(CreateInviteCodesOutput { codes: result_codes })).into_response()
198
+
Json(CreateInviteCodesOutput { codes: result_codes }).into_response()
307
199
}
308
200
309
201
#[derive(Deserialize)]
···
339
231
340
232
pub async fn get_account_invite_codes(
341
233
State(state): State<AppState>,
342
-
headers: axum::http::HeaderMap,
234
+
BearerAuth(auth_user): BearerAuth,
343
235
axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
344
236
) -> Response {
345
-
let token = match crate::auth::extract_bearer_token_from_header(
346
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
347
-
) {
348
-
Some(t) => t,
349
-
None => {
350
-
return (
351
-
StatusCode::UNAUTHORIZED,
352
-
Json(json!({"error": "AuthenticationRequired"})),
353
-
)
354
-
.into_response();
355
-
}
356
-
};
357
-
358
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
359
-
let did = match auth_result {
360
-
Ok(user) => user.did,
361
-
Err(e) => {
362
-
return (
363
-
StatusCode::UNAUTHORIZED,
364
-
Json(json!({"error": e})),
365
-
)
366
-
.into_response();
367
-
}
368
-
};
369
-
370
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
371
-
.fetch_optional(&state.db)
372
-
.await
373
-
{
374
-
Ok(Some(id)) => id,
375
-
_ => {
376
-
return (
377
-
StatusCode::INTERNAL_SERVER_ERROR,
378
-
Json(json!({"error": "InternalError"})),
379
-
)
380
-
.into_response();
381
-
}
237
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
238
+
Ok(id) => id,
239
+
Err(e) => return ApiError::from(e).into_response(),
382
240
};
383
241
384
242
let include_used = params.include_used.unwrap_or(true);
385
243
386
-
let codes_result = sqlx::query!(
244
+
let codes_rows = match sqlx::query!(
387
245
r#"
388
246
SELECT code, available_uses, created_at, disabled
389
247
FROM invite_codes
···
393
251
user_id
394
252
)
395
253
.fetch_all(&state.db)
396
-
.await;
397
-
398
-
let codes_rows = match codes_result {
254
+
.await
255
+
{
399
256
Ok(rows) => {
400
257
if include_used {
401
258
rows
···
405
262
}
406
263
Err(e) => {
407
264
error!("DB error fetching invite codes: {:?}", e);
408
-
return (
409
-
StatusCode::INTERNAL_SERVER_ERROR,
410
-
Json(json!({"error": "InternalError"})),
411
-
)
412
-
.into_response();
265
+
return ApiError::InternalError.into_response();
413
266
}
414
267
};
415
268
416
269
let mut codes = Vec::new();
417
270
for row in codes_rows {
418
-
let uses_result = sqlx::query!(
271
+
let uses = sqlx::query!(
419
272
r#"
420
273
SELECT u.did, icu.used_at
421
274
FROM invite_code_uses icu
···
426
279
row.code
427
280
)
428
281
.fetch_all(&state.db)
429
-
.await;
430
-
431
-
let uses = match uses_result {
432
-
Ok(use_rows) => use_rows
282
+
.await
283
+
.map(|use_rows| {
284
+
use_rows
433
285
.iter()
434
286
.map(|u| InviteCodeUse {
435
287
used_by: u.did.clone(),
436
288
used_at: u.used_at.to_rfc3339(),
437
289
})
438
-
.collect(),
439
-
Err(_) => Vec::new(),
440
-
};
290
+
.collect()
291
+
})
292
+
.unwrap_or_default();
441
293
442
294
codes.push(InviteCode {
443
295
code: row.code,
444
296
available: row.available_uses,
445
297
disabled: row.disabled.unwrap_or(false),
446
-
for_account: did.clone(),
447
-
created_by: did.clone(),
298
+
for_account: auth_user.did.clone(),
299
+
created_by: auth_user.did.clone(),
448
300
created_at: row.created_at.to_rfc3339(),
449
301
uses,
450
302
});
451
303
}
452
304
453
-
(StatusCode::OK, Json(GetAccountInviteCodesOutput { codes })).into_response()
305
+
Json(GetAccountInviteCodesOutput { codes }).into_response()
454
306
}
+3
-3
src/api/server/mod.rs
+3
-3
src/api/server/mod.rs
···
4
4
pub mod invite;
5
5
pub mod meta;
6
6
pub mod password;
7
+
pub mod service_auth;
7
8
pub mod session;
8
9
pub mod signing_key;
9
10
···
16
17
pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes};
17
18
pub use meta::{describe_server, health};
18
19
pub use password::{request_password_reset, reset_password};
19
-
pub use session::{
20
-
create_session, delete_session, get_service_auth, get_session, refresh_session,
21
-
};
20
+
pub use service_auth::get_service_auth;
21
+
pub use session::{create_session, delete_session, get_session, refresh_session};
22
22
pub use signing_key::reserve_signing_key;
+43
-17
src/api/server/password.rs
+43
-17
src/api/server/password.rs
···
7
7
};
8
8
use bcrypt::{hash, DEFAULT_COST};
9
9
use chrono::{Duration, Utc};
10
-
use rand::Rng;
11
10
use serde::Deserialize;
12
11
use serde_json::json;
13
12
use tracing::{error, info, warn};
14
13
15
14
fn generate_reset_code() -> String {
16
-
let mut rng = rand::thread_rng();
17
-
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
18
-
let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
19
-
let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
20
-
format!("{}-{}", part1, part2)
15
+
crate::util::generate_token_code()
21
16
}
22
17
23
18
#[derive(Deserialize)]
···
45
40
let user_id = match user {
46
41
Ok(Some(row)) => row.id,
47
42
Ok(None) => {
48
-
info!("Password reset requested for unknown email: {}", email);
43
+
info!("Password reset requested for unknown email");
49
44
return (StatusCode::OK, Json(json!({}))).into_response();
50
45
}
51
46
Err(e) => {
···
151
146
152
147
if let Some(exp) = expires_at {
153
148
if Utc::now() > exp {
154
-
let _ = sqlx::query!(
149
+
if let Err(e) = sqlx::query!(
155
150
"UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
156
151
user_id
157
152
)
158
153
.execute(&state.db)
159
-
.await;
154
+
.await
155
+
{
156
+
error!("Failed to clear expired reset code: {:?}", e);
157
+
}
160
158
161
159
return (
162
160
StatusCode::BAD_REQUEST,
···
184
182
}
185
183
};
186
184
187
-
let update = sqlx::query!(
185
+
let mut tx = match state.db.begin().await {
186
+
Ok(tx) => tx,
187
+
Err(e) => {
188
+
error!("Failed to begin transaction: {:?}", e);
189
+
return (
190
+
StatusCode::INTERNAL_SERVER_ERROR,
191
+
Json(json!({"error": "InternalError"})),
192
+
)
193
+
.into_response();
194
+
}
195
+
};
196
+
197
+
if let Err(e) = sqlx::query!(
188
198
"UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
189
199
password_hash,
190
200
user_id
191
201
)
192
-
.execute(&state.db)
193
-
.await;
194
-
195
-
if let Err(e) = update {
202
+
.execute(&mut *tx)
203
+
.await
204
+
{
196
205
error!("DB error updating password: {:?}", e);
197
206
return (
198
207
StatusCode::INTERNAL_SERVER_ERROR,
···
201
210
.into_response();
202
211
}
203
212
204
-
let _ = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id)
205
-
.execute(&state.db)
206
-
.await;
213
+
if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id)
214
+
.execute(&mut *tx)
215
+
.await
216
+
{
217
+
error!("Failed to invalidate sessions after password reset: {:?}", e);
218
+
return (
219
+
StatusCode::INTERNAL_SERVER_ERROR,
220
+
Json(json!({"error": "InternalError"})),
221
+
)
222
+
.into_response();
223
+
}
224
+
225
+
if let Err(e) = tx.commit().await {
226
+
error!("Failed to commit password reset transaction: {:?}", e);
227
+
return (
228
+
StatusCode::INTERNAL_SERVER_ERROR,
229
+
Json(json!({"error": "InternalError"})),
230
+
)
231
+
.into_response();
232
+
}
207
233
208
234
info!("Password reset completed for user {}", user_id);
209
235
+63
src/api/server/service_auth.rs
+63
src/api/server/service_auth.rs
···
1
+
use crate::api::ApiError;
2
+
use crate::state::AppState;
3
+
use axum::{
4
+
Json,
5
+
extract::{Query, State},
6
+
http::StatusCode,
7
+
response::{IntoResponse, Response},
8
+
};
9
+
use serde::{Deserialize, Serialize};
10
+
use serde_json::json;
11
+
use tracing::error;
12
+
13
+
#[derive(Deserialize)]
14
+
pub struct GetServiceAuthParams {
15
+
pub aud: String,
16
+
pub lxm: Option<String>,
17
+
pub exp: Option<i64>,
18
+
}
19
+
20
+
#[derive(Serialize)]
21
+
pub struct GetServiceAuthOutput {
22
+
pub token: String,
23
+
}
24
+
25
+
pub async fn get_service_auth(
26
+
State(state): State<AppState>,
27
+
headers: axum::http::HeaderMap,
28
+
Query(params): Query<GetServiceAuthParams>,
29
+
) -> Response {
30
+
let token = match crate::auth::extract_bearer_token_from_header(
31
+
headers.get("Authorization").and_then(|h| h.to_str().ok())
32
+
) {
33
+
Some(t) => t,
34
+
None => return ApiError::AuthenticationRequired.into_response(),
35
+
};
36
+
37
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
38
+
Ok(user) => user,
39
+
Err(e) => return ApiError::from(e).into_response(),
40
+
};
41
+
42
+
let key_bytes = match auth_user.key_bytes {
43
+
Some(kb) => kb,
44
+
None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot create service auth".into()).into_response(),
45
+
};
46
+
47
+
let lxm = params.lxm.as_deref().unwrap_or("*");
48
+
49
+
let service_token = match crate::auth::create_service_token(&auth_user.did, ¶ms.aud, lxm, &key_bytes)
50
+
{
51
+
Ok(t) => t,
52
+
Err(e) => {
53
+
error!("Failed to create service token: {:?}", e);
54
+
return (
55
+
StatusCode::INTERNAL_SERVER_ERROR,
56
+
Json(json!({"error": "InternalError"})),
57
+
)
58
+
.into_response();
59
+
}
60
+
};
61
+
62
+
(StatusCode::OK, Json(GetServiceAuthOutput { token: service_token })).into_response()
63
+
}
+199
-512
src/api/server/session.rs
+199
-512
src/api/server/session.rs
···
1
+
use crate::api::ApiError;
2
+
use crate::auth::BearerAuth;
1
3
use crate::state::AppState;
2
4
use axum::{
3
5
Json,
4
-
extract::{Query, State},
5
-
http::StatusCode,
6
+
extract::State,
6
7
response::{IntoResponse, Response},
7
8
};
8
9
use bcrypt::verify;
···
11
12
use tracing::{error, info, warn};
12
13
13
14
#[derive(Deserialize)]
14
-
pub struct GetServiceAuthParams {
15
-
pub aud: String,
16
-
pub lxm: Option<String>,
17
-
pub exp: Option<i64>,
18
-
}
19
-
20
-
#[derive(Serialize)]
21
-
pub struct GetServiceAuthOutput {
22
-
pub token: String,
23
-
}
24
-
25
-
pub async fn get_service_auth(
26
-
State(state): State<AppState>,
27
-
headers: axum::http::HeaderMap,
28
-
Query(params): Query<GetServiceAuthParams>,
29
-
) -> Response {
30
-
let token = match crate::auth::extract_bearer_token_from_header(
31
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
32
-
) {
33
-
Some(t) => t,
34
-
None => {
35
-
return (
36
-
StatusCode::UNAUTHORIZED,
37
-
Json(json!({"error": "AuthenticationRequired"})),
38
-
)
39
-
.into_response();
40
-
}
41
-
};
42
-
43
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
44
-
let (did, key_bytes) = match auth_result {
45
-
Ok(user) => {
46
-
let kb = match user.key_bytes {
47
-
Some(kb) => kb,
48
-
None => {
49
-
return (
50
-
StatusCode::UNAUTHORIZED,
51
-
Json(json!({"error": "AuthenticationFailed", "message": "OAuth tokens cannot create service auth"})),
52
-
)
53
-
.into_response();
54
-
}
55
-
};
56
-
(user.did, kb)
57
-
}
58
-
Err(e) => {
59
-
return (
60
-
StatusCode::UNAUTHORIZED,
61
-
Json(json!({"error": e})),
62
-
)
63
-
.into_response();
64
-
}
65
-
};
66
-
67
-
let lxm = params.lxm.as_deref().unwrap_or("*");
68
-
69
-
let service_token = match crate::auth::create_service_token(&did, ¶ms.aud, lxm, &key_bytes)
70
-
{
71
-
Ok(t) => t,
72
-
Err(e) => {
73
-
error!("Failed to create service token: {:?}", e);
74
-
return (
75
-
StatusCode::INTERNAL_SERVER_ERROR,
76
-
Json(json!({"error": "InternalError"})),
77
-
)
78
-
.into_response();
79
-
}
80
-
};
81
-
82
-
(StatusCode::OK, Json(GetServiceAuthOutput { token: service_token })).into_response()
83
-
}
84
-
85
-
#[derive(Deserialize)]
86
15
pub struct CreateSessionInput {
87
16
pub identifier: String,
88
17
pub password: String,
···
101
30
State(state): State<AppState>,
102
31
Json(input): Json<CreateSessionInput>,
103
32
) -> Response {
104
-
info!("create_session: identifier='{}'", input.identifier);
33
+
info!("create_session called");
105
34
106
-
let user_row = sqlx::query!(
35
+
let row = match sqlx::query!(
107
36
"SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
108
37
input.identifier
109
38
)
110
-
.fetch_optional(&state.db)
111
-
.await;
112
-
113
-
match user_row {
114
-
Ok(Some(row)) => {
115
-
let user_id = row.id;
116
-
let stored_hash = &row.password_hash;
117
-
let did = &row.did;
118
-
let handle = &row.handle;
119
-
let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
120
-
Ok(k) => k,
121
-
Err(e) => {
122
-
error!("Failed to decrypt user key: {:?}", e);
123
-
return (
124
-
StatusCode::INTERNAL_SERVER_ERROR,
125
-
Json(json!({"error": "InternalError"})),
126
-
)
127
-
.into_response();
128
-
}
129
-
};
130
-
131
-
let password_valid = if verify(&input.password, stored_hash).unwrap_or(false) {
132
-
true
133
-
} else {
134
-
let app_pass_rows = sqlx::query!("SELECT password_hash FROM app_passwords WHERE user_id = $1", user_id)
135
-
.fetch_all(&state.db)
136
-
.await
137
-
.unwrap_or_default();
138
-
139
-
app_pass_rows.iter().any(|row| {
140
-
verify(&input.password, &row.password_hash).unwrap_or(false)
141
-
})
142
-
};
143
-
144
-
if password_valid {
145
-
let access_meta = match crate::auth::create_access_token_with_metadata(did, &key_bytes) {
146
-
Ok(m) => m,
147
-
Err(e) => {
148
-
error!("Failed to create access token: {:?}", e);
149
-
return (
150
-
StatusCode::INTERNAL_SERVER_ERROR,
151
-
Json(json!({"error": "InternalError"})),
152
-
)
153
-
.into_response();
154
-
}
155
-
};
156
-
157
-
let refresh_meta = match crate::auth::create_refresh_token_with_metadata(did, &key_bytes) {
158
-
Ok(m) => m,
159
-
Err(e) => {
160
-
error!("Failed to create refresh token: {:?}", e);
161
-
return (
162
-
StatusCode::INTERNAL_SERVER_ERROR,
163
-
Json(json!({"error": "InternalError"})),
164
-
)
165
-
.into_response();
166
-
}
167
-
};
168
-
169
-
let session_insert = sqlx::query!(
170
-
"INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)",
171
-
did,
172
-
access_meta.jti,
173
-
refresh_meta.jti,
174
-
access_meta.expires_at,
175
-
refresh_meta.expires_at
176
-
)
177
-
.execute(&state.db)
178
-
.await;
179
-
180
-
match session_insert {
181
-
Ok(_) => {
182
-
return (
183
-
StatusCode::OK,
184
-
Json(CreateSessionOutput {
185
-
access_jwt: access_meta.token,
186
-
refresh_jwt: refresh_meta.token,
187
-
handle: handle.clone(),
188
-
did: did.clone(),
189
-
}),
190
-
)
191
-
.into_response();
192
-
}
193
-
Err(e) => {
194
-
error!("Failed to insert session: {:?}", e);
195
-
return (
196
-
StatusCode::INTERNAL_SERVER_ERROR,
197
-
Json(json!({"error": "InternalError"})),
198
-
)
199
-
.into_response();
200
-
}
201
-
}
202
-
} else {
203
-
warn!(
204
-
"Password verification failed for identifier: {}",
205
-
input.identifier
206
-
);
207
-
}
208
-
}
39
+
.fetch_optional(&state.db)
40
+
.await
41
+
{
42
+
Ok(Some(row)) => row,
209
43
Ok(None) => {
210
-
warn!("User not found for identifier: {}", input.identifier);
44
+
warn!("User not found for login attempt");
45
+
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response();
211
46
}
212
47
Err(e) => {
213
48
error!("Database error fetching user: {:?}", e);
214
-
return (
215
-
StatusCode::INTERNAL_SERVER_ERROR,
216
-
Json(json!({"error": "InternalError"})),
217
-
)
218
-
.into_response();
49
+
return ApiError::InternalError.into_response();
50
+
}
51
+
};
52
+
53
+
let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
54
+
Ok(k) => k,
55
+
Err(e) => {
56
+
error!("Failed to decrypt user key: {:?}", e);
57
+
return ApiError::InternalError.into_response();
219
58
}
59
+
};
60
+
61
+
let password_valid = verify(&input.password, &row.password_hash).unwrap_or(false)
62
+
|| sqlx::query!("SELECT password_hash FROM app_passwords WHERE user_id = $1", row.id)
63
+
.fetch_all(&state.db)
64
+
.await
65
+
.unwrap_or_default()
66
+
.iter()
67
+
.any(|app| verify(&input.password, &app.password_hash).unwrap_or(false));
68
+
69
+
if !password_valid {
70
+
warn!("Password verification failed for login attempt");
71
+
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response();
220
72
}
221
73
222
-
(
223
-
StatusCode::UNAUTHORIZED,
224
-
Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"})),
225
-
)
226
-
.into_response()
227
-
}
228
-
229
-
pub async fn get_session(
230
-
State(state): State<AppState>,
231
-
headers: axum::http::HeaderMap,
232
-
) -> Response {
233
-
let token = match crate::auth::extract_bearer_token_from_header(
234
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
235
-
) {
236
-
Some(t) => t,
237
-
None => {
238
-
return (
239
-
StatusCode::UNAUTHORIZED,
240
-
Json(json!({"error": "AuthenticationRequired", "message": "Invalid Authorization header format"})),
241
-
)
242
-
.into_response();
74
+
let access_meta = match crate::auth::create_access_token_with_metadata(&row.did, &key_bytes) {
75
+
Ok(m) => m,
76
+
Err(e) => {
77
+
error!("Failed to create access token: {:?}", e);
78
+
return ApiError::InternalError.into_response();
243
79
}
244
80
};
245
81
246
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
247
-
let did = match auth_result {
248
-
Ok(user) => user.did,
82
+
let refresh_meta = match crate::auth::create_refresh_token_with_metadata(&row.did, &key_bytes) {
83
+
Ok(m) => m,
249
84
Err(e) => {
250
-
return (
251
-
StatusCode::UNAUTHORIZED,
252
-
Json(json!({"error": e})),
253
-
)
254
-
.into_response();
85
+
error!("Failed to create refresh token: {:?}", e);
86
+
return ApiError::InternalError.into_response();
255
87
}
256
88
};
257
89
258
-
let user = sqlx::query!(
259
-
"SELECT handle, email FROM users WHERE did = $1",
260
-
did
90
+
if let Err(e) = sqlx::query!(
91
+
"INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)",
92
+
row.did,
93
+
access_meta.jti,
94
+
refresh_meta.jti,
95
+
access_meta.expires_at,
96
+
refresh_meta.expires_at
261
97
)
262
-
.fetch_optional(&state.db)
263
-
.await;
98
+
.execute(&state.db)
99
+
.await
100
+
{
101
+
error!("Failed to insert session: {:?}", e);
102
+
return ApiError::InternalError.into_response();
103
+
}
264
104
265
-
match user {
266
-
Ok(Some(row)) => {
267
-
return (
268
-
StatusCode::OK,
269
-
Json(json!({
270
-
"handle": row.handle,
271
-
"did": did,
272
-
"email": row.email,
273
-
"didDoc": {}
274
-
})),
275
-
)
276
-
.into_response();
277
-
}
278
-
Ok(None) => {
279
-
return (
280
-
StatusCode::UNAUTHORIZED,
281
-
Json(json!({"error": "AuthenticationFailed"})),
282
-
)
283
-
.into_response();
284
-
}
105
+
Json(CreateSessionOutput {
106
+
access_jwt: access_meta.token,
107
+
refresh_jwt: refresh_meta.token,
108
+
handle: row.handle,
109
+
did: row.did,
110
+
}).into_response()
111
+
}
112
+
113
+
pub async fn get_session(
114
+
State(state): State<AppState>,
115
+
BearerAuth(auth_user): BearerAuth,
116
+
) -> Response {
117
+
match sqlx::query!("SELECT handle, email FROM users WHERE did = $1", auth_user.did)
118
+
.fetch_optional(&state.db)
119
+
.await
120
+
{
121
+
Ok(Some(row)) => Json(json!({
122
+
"handle": row.handle,
123
+
"did": auth_user.did,
124
+
"email": row.email,
125
+
"didDoc": {}
126
+
})).into_response(),
127
+
Ok(None) => ApiError::AuthenticationFailed.into_response(),
285
128
Err(e) => {
286
129
error!("Database error in get_session: {:?}", e);
287
-
return (
288
-
StatusCode::INTERNAL_SERVER_ERROR,
289
-
Json(json!({"error": "InternalError"})),
290
-
)
291
-
.into_response();
130
+
ApiError::InternalError.into_response()
292
131
}
293
132
}
294
133
}
···
301
140
headers.get("Authorization").and_then(|h| h.to_str().ok())
302
141
) {
303
142
Some(t) => t,
304
-
None => {
305
-
return (
306
-
StatusCode::UNAUTHORIZED,
307
-
Json(json!({"error": "AuthenticationRequired"})),
308
-
)
309
-
.into_response();
310
-
}
143
+
None => return ApiError::AuthenticationRequired.into_response(),
311
144
};
312
145
313
-
let jti = match crate::auth::get_did_from_token(&token) {
314
-
Ok(_) => {
315
-
let parts: Vec<&str> = token.split('.').collect();
316
-
if parts.len() != 3 {
317
-
return (
318
-
StatusCode::UNAUTHORIZED,
319
-
Json(json!({"error": "AuthenticationFailed"})),
320
-
)
321
-
.into_response();
322
-
}
323
-
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
324
-
let claims_json = match URL_SAFE_NO_PAD.decode(parts[1]) {
325
-
Ok(bytes) => bytes,
326
-
Err(_) => {
327
-
return (
328
-
StatusCode::UNAUTHORIZED,
329
-
Json(json!({"error": "AuthenticationFailed"})),
330
-
)
331
-
.into_response();
332
-
}
333
-
};
334
-
let claims: serde_json::Value = match serde_json::from_slice(&claims_json) {
335
-
Ok(c) => c,
336
-
Err(_) => {
337
-
return (
338
-
StatusCode::UNAUTHORIZED,
339
-
Json(json!({"error": "AuthenticationFailed"})),
340
-
)
341
-
.into_response();
342
-
}
343
-
};
344
-
match claims.get("jti").and_then(|j| j.as_str()) {
345
-
Some(jti) => jti.to_string(),
346
-
None => {
347
-
return (
348
-
StatusCode::UNAUTHORIZED,
349
-
Json(json!({"error": "AuthenticationFailed"})),
350
-
)
351
-
.into_response();
352
-
}
353
-
}
354
-
}
355
-
Err(_) => {
356
-
return (
357
-
StatusCode::UNAUTHORIZED,
358
-
Json(json!({"error": "AuthenticationFailed"})),
359
-
)
360
-
.into_response();
361
-
}
146
+
let jti = match crate::auth::get_jti_from_token(&token) {
147
+
Ok(jti) => jti,
148
+
Err(_) => return ApiError::AuthenticationFailed.into_response(),
362
149
};
363
150
364
-
let result = sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti)
151
+
match sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti)
365
152
.execute(&state.db)
366
-
.await;
367
-
368
-
match result {
369
-
Ok(res) => {
370
-
if res.rows_affected() > 0 {
371
-
return (StatusCode::OK, Json(json!({}))).into_response();
372
-
}
373
-
}
153
+
.await
154
+
{
155
+
Ok(res) if res.rows_affected() > 0 => Json(json!({})).into_response(),
156
+
Ok(_) => ApiError::AuthenticationFailed.into_response(),
374
157
Err(e) => {
375
158
error!("Database error in delete_session: {:?}", e);
159
+
ApiError::AuthenticationFailed.into_response()
376
160
}
377
161
}
378
-
379
-
(
380
-
StatusCode::UNAUTHORIZED,
381
-
Json(json!({"error": "AuthenticationFailed"})),
382
-
)
383
-
.into_response()
384
162
}
385
163
386
164
pub async fn refresh_session(
387
165
State(state): State<AppState>,
388
166
headers: axum::http::HeaderMap,
389
167
) -> Response {
390
-
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
391
-
392
168
let refresh_token = match crate::auth::extract_bearer_token_from_header(
393
169
headers.get("Authorization").and_then(|h| h.to_str().ok())
394
170
) {
395
171
Some(t) => t,
396
-
None => {
397
-
return (
398
-
StatusCode::UNAUTHORIZED,
399
-
Json(json!({"error": "AuthenticationRequired"})),
400
-
)
401
-
.into_response();
402
-
}
172
+
None => return ApiError::AuthenticationRequired.into_response(),
403
173
};
404
174
405
-
let refresh_jti = {
406
-
let parts: Vec<&str> = refresh_token.split('.').collect();
407
-
if parts.len() != 3 {
408
-
return (
409
-
StatusCode::UNAUTHORIZED,
410
-
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token format"})),
411
-
)
412
-
.into_response();
413
-
}
414
-
let claims_bytes = match URL_SAFE_NO_PAD.decode(parts[1]) {
415
-
Ok(b) => b,
416
-
Err(_) => {
417
-
return (
418
-
StatusCode::UNAUTHORIZED,
419
-
Json(json!({"error": "AuthenticationFailed"})),
420
-
)
421
-
.into_response();
422
-
}
423
-
};
424
-
let claims: serde_json::Value = match serde_json::from_slice(&claims_bytes) {
425
-
Ok(c) => c,
426
-
Err(_) => {
427
-
return (
428
-
StatusCode::UNAUTHORIZED,
429
-
Json(json!({"error": "AuthenticationFailed"})),
430
-
)
431
-
.into_response();
432
-
}
433
-
};
434
-
match claims.get("jti").and_then(|j| j.as_str()) {
435
-
Some(jti) => jti.to_string(),
436
-
None => {
437
-
return (
438
-
StatusCode::UNAUTHORIZED,
439
-
Json(json!({"error": "AuthenticationFailed"})),
440
-
)
441
-
.into_response();
442
-
}
175
+
let refresh_jti = match crate::auth::get_jti_from_token(&refresh_token) {
176
+
Ok(jti) => jti,
177
+
Err(_) => return ApiError::AuthenticationFailedMsg("Invalid token format".into()).into_response(),
178
+
};
179
+
180
+
let mut tx = match state.db.begin().await {
181
+
Ok(tx) => tx,
182
+
Err(e) => {
183
+
error!("Failed to begin transaction: {:?}", e);
184
+
return ApiError::InternalError.into_response();
443
185
}
444
186
};
445
187
446
-
let reuse_check = sqlx::query_scalar!(
447
-
"SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1",
188
+
if let Ok(Some(session_id)) = sqlx::query_scalar!(
189
+
"SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1 FOR UPDATE",
448
190
refresh_jti
449
191
)
450
-
.fetch_optional(&state.db)
451
-
.await;
452
-
453
-
if let Ok(Some(session_id)) = reuse_check {
192
+
.fetch_optional(&mut *tx)
193
+
.await
194
+
{
454
195
warn!("Refresh token reuse detected! Revoking token family for session_id: {}", session_id);
455
196
let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id)
456
-
.execute(&state.db)
197
+
.execute(&mut *tx)
457
198
.await;
458
-
return (
459
-
StatusCode::UNAUTHORIZED,
460
-
Json(json!({"error": "ExpiredToken", "message": "Refresh token has been revoked due to suspected compromise"})),
461
-
)
462
-
.into_response();
199
+
let _ = tx.commit().await;
200
+
return ApiError::ExpiredTokenMsg("Refresh token has been revoked due to suspected compromise".into()).into_response();
463
201
}
464
202
465
-
let session = sqlx::query!(
203
+
let session_row = match sqlx::query!(
466
204
r#"SELECT st.id, st.did, k.key_bytes, k.encryption_version
467
205
FROM session_tokens st
468
206
JOIN users u ON st.did = u.did
469
207
JOIN user_keys k ON u.id = k.user_id
470
-
WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()"#,
208
+
WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()
209
+
FOR UPDATE OF st"#,
471
210
refresh_jti
472
211
)
473
-
.fetch_optional(&state.db)
474
-
.await;
212
+
.fetch_optional(&mut *tx)
213
+
.await
214
+
{
215
+
Ok(Some(row)) => row,
216
+
Ok(None) => return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response(),
217
+
Err(e) => {
218
+
error!("Database error fetching session: {:?}", e);
219
+
return ApiError::InternalError.into_response();
220
+
}
221
+
};
475
222
476
-
match session {
477
-
Ok(Some(session_row)) => {
478
-
let session_id = session_row.id;
479
-
let did = &session_row.did;
480
-
let key_bytes = match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) {
481
-
Ok(k) => k,
482
-
Err(e) => {
483
-
error!("Failed to decrypt user key: {:?}", e);
484
-
return (
485
-
StatusCode::INTERNAL_SERVER_ERROR,
486
-
Json(json!({"error": "InternalError"})),
487
-
)
488
-
.into_response();
489
-
}
490
-
};
223
+
let key_bytes = match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) {
224
+
Ok(k) => k,
225
+
Err(e) => {
226
+
error!("Failed to decrypt user key: {:?}", e);
227
+
return ApiError::InternalError.into_response();
228
+
}
229
+
};
491
230
492
-
if let Err(_) = crate::auth::verify_refresh_token(&refresh_token, &key_bytes) {
493
-
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"}))).into_response();
494
-
}
231
+
if crate::auth::verify_refresh_token(&refresh_token, &key_bytes).is_err() {
232
+
return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response();
233
+
}
495
234
496
-
let new_access_meta = match crate::auth::create_access_token_with_metadata(did, &key_bytes) {
497
-
Ok(m) => m,
498
-
Err(e) => {
499
-
error!("Failed to create access token: {:?}", e);
500
-
return (
501
-
StatusCode::INTERNAL_SERVER_ERROR,
502
-
Json(json!({"error": "InternalError"})),
503
-
)
504
-
.into_response();
505
-
}
506
-
};
507
-
let new_refresh_meta = match crate::auth::create_refresh_token_with_metadata(did, &key_bytes) {
508
-
Ok(m) => m,
509
-
Err(e) => {
510
-
error!("Failed to create refresh token: {:?}", e);
511
-
return (
512
-
StatusCode::INTERNAL_SERVER_ERROR,
513
-
Json(json!({"error": "InternalError"})),
514
-
)
515
-
.into_response();
516
-
}
517
-
};
235
+
let new_access_meta = match crate::auth::create_access_token_with_metadata(&session_row.did, &key_bytes) {
236
+
Ok(m) => m,
237
+
Err(e) => {
238
+
error!("Failed to create access token: {:?}", e);
239
+
return ApiError::InternalError.into_response();
240
+
}
241
+
};
518
242
519
-
let mut tx = match state.db.begin().await {
520
-
Ok(tx) => tx,
521
-
Err(e) => {
522
-
error!("Failed to begin transaction: {:?}", e);
523
-
return (
524
-
StatusCode::INTERNAL_SERVER_ERROR,
525
-
Json(json!({"error": "InternalError"})),
526
-
)
527
-
.into_response();
528
-
}
529
-
};
243
+
let new_refresh_meta = match crate::auth::create_refresh_token_with_metadata(&session_row.did, &key_bytes) {
244
+
Ok(m) => m,
245
+
Err(e) => {
246
+
error!("Failed to create refresh token: {:?}", e);
247
+
return ApiError::InternalError.into_response();
248
+
}
249
+
};
530
250
531
-
if let Err(e) = sqlx::query!(
532
-
"INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2)",
533
-
refresh_jti,
534
-
session_id
535
-
)
536
-
.execute(&mut *tx)
537
-
.await
538
-
{
539
-
error!("Failed to record used refresh token: {:?}", e);
540
-
return (
541
-
StatusCode::INTERNAL_SERVER_ERROR,
542
-
Json(json!({"error": "InternalError"})),
543
-
)
544
-
.into_response();
545
-
}
251
+
match sqlx::query!(
252
+
"INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING",
253
+
refresh_jti,
254
+
session_row.id
255
+
)
256
+
.execute(&mut *tx)
257
+
.await
258
+
{
259
+
Ok(result) if result.rows_affected() == 0 => {
260
+
warn!("Concurrent refresh token reuse detected for session_id: {}", session_row.id);
261
+
let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_row.id)
262
+
.execute(&mut *tx)
263
+
.await;
264
+
let _ = tx.commit().await;
265
+
return ApiError::ExpiredTokenMsg("Refresh token has been revoked due to suspected compromise".into()).into_response();
266
+
}
267
+
Err(e) => {
268
+
error!("Failed to record used refresh token: {:?}", e);
269
+
return ApiError::InternalError.into_response();
270
+
}
271
+
Ok(_) => {}
272
+
}
546
273
547
-
if let Err(e) = sqlx::query!(
548
-
"UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5",
549
-
new_access_meta.jti,
550
-
new_refresh_meta.jti,
551
-
new_access_meta.expires_at,
552
-
new_refresh_meta.expires_at,
553
-
session_id
554
-
)
555
-
.execute(&mut *tx)
556
-
.await
557
-
{
558
-
error!("Database error updating session: {:?}", e);
559
-
return (
560
-
StatusCode::INTERNAL_SERVER_ERROR,
561
-
Json(json!({"error": "InternalError"})),
562
-
)
563
-
.into_response();
564
-
}
274
+
if let Err(e) = sqlx::query!(
275
+
"UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5",
276
+
new_access_meta.jti,
277
+
new_refresh_meta.jti,
278
+
new_access_meta.expires_at,
279
+
new_refresh_meta.expires_at,
280
+
session_row.id
281
+
)
282
+
.execute(&mut *tx)
283
+
.await
284
+
{
285
+
error!("Database error updating session: {:?}", e);
286
+
return ApiError::InternalError.into_response();
287
+
}
565
288
566
-
if let Err(e) = tx.commit().await {
567
-
error!("Failed to commit transaction: {:?}", e);
568
-
return (
569
-
StatusCode::INTERNAL_SERVER_ERROR,
570
-
Json(json!({"error": "InternalError"})),
571
-
)
572
-
.into_response();
573
-
}
289
+
if let Err(e) = tx.commit().await {
290
+
error!("Failed to commit transaction: {:?}", e);
291
+
return ApiError::InternalError.into_response();
292
+
}
574
293
575
-
let user = sqlx::query!("SELECT handle FROM users WHERE did = $1", did)
576
-
.fetch_optional(&state.db)
577
-
.await;
578
-
579
-
match user {
580
-
Ok(Some(u)) => {
581
-
return (
582
-
StatusCode::OK,
583
-
Json(json!({
584
-
"accessJwt": new_access_meta.token,
585
-
"refreshJwt": new_refresh_meta.token,
586
-
"handle": u.handle,
587
-
"did": did
588
-
})),
589
-
)
590
-
.into_response();
591
-
}
592
-
Ok(None) => {
593
-
error!("User not found for existing session: {}", did);
594
-
return (
595
-
StatusCode::INTERNAL_SERVER_ERROR,
596
-
Json(json!({"error": "InternalError"})),
597
-
)
598
-
.into_response();
599
-
}
600
-
Err(e) => {
601
-
error!("Database error fetching user: {:?}", e);
602
-
return (
603
-
StatusCode::INTERNAL_SERVER_ERROR,
604
-
Json(json!({"error": "InternalError"})),
605
-
)
606
-
.into_response();
607
-
}
608
-
}
609
-
}
294
+
match sqlx::query!("SELECT handle FROM users WHERE did = $1", session_row.did)
295
+
.fetch_optional(&state.db)
296
+
.await
297
+
{
298
+
Ok(Some(u)) => Json(json!({
299
+
"accessJwt": new_access_meta.token,
300
+
"refreshJwt": new_refresh_meta.token,
301
+
"handle": u.handle,
302
+
"did": session_row.did
303
+
})).into_response(),
610
304
Ok(None) => {
611
-
return (
612
-
StatusCode::UNAUTHORIZED,
613
-
Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"})),
614
-
)
615
-
.into_response();
305
+
error!("User not found for existing session: {}", session_row.did);
306
+
ApiError::InternalError.into_response()
616
307
}
617
308
Err(e) => {
618
-
error!("Database error fetching session: {:?}", e);
619
-
return (
620
-
StatusCode::INTERNAL_SERVER_ERROR,
621
-
Json(json!({"error": "InternalError"})),
622
-
)
623
-
.into_response();
309
+
error!("Database error fetching user: {:?}", e);
310
+
ApiError::InternalError.into_response()
624
311
}
625
312
}
626
313
}
+104
src/api/validation.rs
+104
src/api/validation.rs
···
1
+
pub const MAX_EMAIL_LENGTH: usize = 254;
2
+
pub const MAX_LOCAL_PART_LENGTH: usize = 64;
3
+
pub const MAX_DOMAIN_LENGTH: usize = 253;
4
+
pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63;
5
+
6
+
const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-";
7
+
8
+
pub fn is_valid_email(email: &str) -> bool {
9
+
let email = email.trim();
10
+
11
+
if email.is_empty() || email.len() > MAX_EMAIL_LENGTH {
12
+
return false;
13
+
}
14
+
15
+
let parts: Vec<&str> = email.rsplitn(2, '@').collect();
16
+
if parts.len() != 2 {
17
+
return false;
18
+
}
19
+
20
+
let domain = parts[0];
21
+
let local = parts[1];
22
+
23
+
if local.is_empty() || local.len() > MAX_LOCAL_PART_LENGTH {
24
+
return false;
25
+
}
26
+
27
+
if local.starts_with('.') || local.ends_with('.') {
28
+
return false;
29
+
}
30
+
31
+
if local.contains("..") {
32
+
return false;
33
+
}
34
+
35
+
for c in local.chars() {
36
+
if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) {
37
+
return false;
38
+
}
39
+
}
40
+
41
+
if domain.is_empty() || domain.len() > MAX_DOMAIN_LENGTH {
42
+
return false;
43
+
}
44
+
45
+
if !domain.contains('.') {
46
+
return false;
47
+
}
48
+
49
+
for label in domain.split('.') {
50
+
if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH {
51
+
return false;
52
+
}
53
+
54
+
if label.starts_with('-') || label.ends_with('-') {
55
+
return false;
56
+
}
57
+
58
+
for c in label.chars() {
59
+
if !c.is_ascii_alphanumeric() && c != '-' {
60
+
return false;
61
+
}
62
+
}
63
+
}
64
+
65
+
true
66
+
}
67
+
68
+
#[cfg(test)]
69
+
mod tests {
70
+
use super::*;
71
+
72
+
#[test]
73
+
fn test_valid_emails() {
74
+
assert!(is_valid_email("user@example.com"));
75
+
assert!(is_valid_email("user.name@example.com"));
76
+
assert!(is_valid_email("user+tag@example.com"));
77
+
assert!(is_valid_email("user@sub.example.com"));
78
+
assert!(is_valid_email("USER@EXAMPLE.COM"));
79
+
assert!(is_valid_email("user123@example123.com"));
80
+
assert!(is_valid_email("a@b.co"));
81
+
}
82
+
83
+
#[test]
84
+
fn test_invalid_emails() {
85
+
assert!(!is_valid_email(""));
86
+
assert!(!is_valid_email("user"));
87
+
assert!(!is_valid_email("user@"));
88
+
assert!(!is_valid_email("@example.com"));
89
+
assert!(!is_valid_email("user@example"));
90
+
assert!(!is_valid_email("user@@example.com"));
91
+
assert!(!is_valid_email("user@.example.com"));
92
+
assert!(!is_valid_email("user@example..com"));
93
+
assert!(!is_valid_email(".user@example.com"));
94
+
assert!(!is_valid_email("user.@example.com"));
95
+
assert!(!is_valid_email("user..name@example.com"));
96
+
assert!(!is_valid_email("user@-example.com"));
97
+
assert!(!is_valid_email("user@example-.com"));
98
+
}
99
+
100
+
#[test]
101
+
fn test_trimmed_whitespace() {
102
+
assert!(is_valid_email(" user@example.com "));
103
+
}
104
+
}
+29
-3
src/auth/extractor.rs
+29
-3
src/auth/extractor.rs
···
7
7
use serde_json::json;
8
8
9
9
use crate::state::AppState;
10
-
use super::{AuthenticatedUser, validate_bearer_token};
10
+
use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token, validate_bearer_token_allow_deactivated};
11
11
12
12
pub struct BearerAuth(pub AuthenticatedUser);
13
13
···
112
112
113
113
match validate_bearer_token(&state.db, token).await {
114
114
Ok(user) => Ok(BearerAuth(user)),
115
-
Err("AccountDeactivated") => Err(AuthError::AccountDeactivated),
116
-
Err("AccountTakedown") => Err(AuthError::AccountTakedown),
115
+
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
116
+
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
117
+
Err(_) => Err(AuthError::AuthenticationFailed),
118
+
}
119
+
}
120
+
}
121
+
122
+
pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
123
+
124
+
impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
125
+
type Rejection = AuthError;
126
+
127
+
async fn from_request_parts(
128
+
parts: &mut Parts,
129
+
state: &AppState,
130
+
) -> Result<Self, Self::Rejection> {
131
+
let auth_header = parts
132
+
.headers
133
+
.get(AUTHORIZATION)
134
+
.ok_or(AuthError::MissingToken)?
135
+
.to_str()
136
+
.map_err(|_| AuthError::InvalidFormat)?;
137
+
138
+
let token = extract_bearer_token(auth_header)?;
139
+
140
+
match validate_bearer_token_allow_deactivated(&state.db, token).await {
141
+
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
142
+
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
117
143
Err(_) => Err(AuthError::AuthenticationFailed),
118
144
}
119
145
}
+31
-13
src/auth/mod.rs
+31
-13
src/auth/mod.rs
···
1
1
use serde::{Deserialize, Serialize};
2
2
use sqlx::PgPool;
3
+
use std::fmt;
3
4
4
5
pub mod extractor;
5
6
pub mod token;
6
7
pub mod verify;
7
8
8
-
pub use extractor::{BearerAuth, AuthError, extract_bearer_token_from_header};
9
+
pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header};
9
10
pub use token::{
10
11
create_access_token, create_refresh_token, create_service_token,
11
12
create_access_token_with_metadata, create_refresh_token_with_metadata,
···
14
15
SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
15
16
};
16
17
pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
18
+
19
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
20
+
pub enum TokenValidationError {
21
+
AccountDeactivated,
22
+
AccountTakedown,
23
+
KeyDecryptionFailed,
24
+
AuthenticationFailed,
25
+
}
26
+
27
+
impl fmt::Display for TokenValidationError {
28
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29
+
match self {
30
+
Self::AccountDeactivated => write!(f, "AccountDeactivated"),
31
+
Self::AccountTakedown => write!(f, "AccountTakedown"),
32
+
Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"),
33
+
Self::AuthenticationFailed => write!(f, "AuthenticationFailed"),
34
+
}
35
+
}
36
+
}
17
37
18
38
pub struct AuthenticatedUser {
19
39
pub did: String,
···
24
44
pub async fn validate_bearer_token(
25
45
db: &PgPool,
26
46
token: &str,
27
-
) -> Result<AuthenticatedUser, &'static str> {
47
+
) -> Result<AuthenticatedUser, TokenValidationError> {
28
48
validate_bearer_token_with_options(db, token, false).await
29
49
}
30
50
31
51
pub async fn validate_bearer_token_allow_deactivated(
32
52
db: &PgPool,
33
53
token: &str,
34
-
) -> Result<AuthenticatedUser, &'static str> {
54
+
) -> Result<AuthenticatedUser, TokenValidationError> {
35
55
validate_bearer_token_with_options(db, token, true).await
36
56
}
37
57
···
39
59
db: &PgPool,
40
60
token: &str,
41
61
allow_deactivated: bool,
42
-
) -> Result<AuthenticatedUser, &'static str> {
62
+
) -> Result<AuthenticatedUser, TokenValidationError> {
43
63
let did_from_token = get_did_from_token(token).ok();
44
64
45
65
if let Some(ref did) = did_from_token {
···
56
76
.flatten()
57
77
{
58
78
if !allow_deactivated && user.deactivated_at.is_some() {
59
-
return Err("AccountDeactivated");
79
+
return Err(TokenValidationError::AccountDeactivated);
60
80
}
61
81
if user.takedown_ref.is_some() {
62
-
return Err("AccountTakedown");
82
+
return Err(TokenValidationError::AccountTakedown);
63
83
}
64
84
65
-
let decrypted_key = match crate::config::decrypt_key(&user.key_bytes, user.encryption_version) {
66
-
Ok(k) => k,
67
-
Err(_) => return Err("KeyDecryptionFailed"),
68
-
};
85
+
let decrypted_key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
86
+
.map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
69
87
70
88
if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
71
89
let session_exists = sqlx::query_scalar!(
···
103
121
.flatten()
104
122
{
105
123
if !allow_deactivated && oauth_token.deactivated_at.is_some() {
106
-
return Err("AccountDeactivated");
124
+
return Err(TokenValidationError::AccountDeactivated);
107
125
}
108
126
if oauth_token.takedown_ref.is_some() {
109
-
return Err("AccountTakedown");
127
+
return Err(TokenValidationError::AccountTakedown);
110
128
}
111
129
112
130
let now = chrono::Utc::now();
···
120
138
}
121
139
}
122
140
123
-
Err("AuthenticationFailed")
141
+
Err(TokenValidationError::AuthenticationFailed)
124
142
}
125
143
126
144
#[derive(Debug, Serialize, Deserialize)]
+7
-3
src/config.rs
+7
-3
src/config.rs
···
62
62
let seed = hasher.finalize();
63
63
64
64
let signing_key = SigningKey::from_slice(&seed)
65
-
.expect("Failed to create signing key from seed");
65
+
.unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e));
66
66
67
67
let verifying_key = signing_key.verifying_key();
68
68
let point = verifying_key.to_encoded_point(false);
69
69
70
-
let signing_key_x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
71
-
let signing_key_y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
70
+
let signing_key_x = URL_SAFE_NO_PAD.encode(
71
+
point.x().expect("EC point missing X coordinate - this should never happen")
72
+
);
73
+
let signing_key_y = URL_SAFE_NO_PAD.encode(
74
+
point.y().expect("EC point missing Y coordinate - this should never happen")
75
+
);
72
76
73
77
let mut kid_hasher = Sha256::new();
74
78
kid_hasher.update(signing_key_x.as_bytes());
+1
src/lib.rs
+1
src/lib.rs
+43
-15
src/main.rs
+43
-15
src/main.rs
···
1
1
use bspds::notifications::{EmailSender, NotificationService};
2
2
use bspds::state::AppState;
3
3
use std::net::SocketAddr;
4
+
use std::process::ExitCode;
4
5
use tokio::sync::watch;
5
-
use tracing::{info, warn};
6
+
use tracing::{error, info, warn};
6
7
7
8
#[tokio::main]
8
-
async fn main() {
9
+
async fn main() -> ExitCode {
9
10
dotenvy::dotenv().ok();
10
11
tracing_subscriber::fmt::init();
11
12
12
-
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
13
+
match run().await {
14
+
Ok(()) => ExitCode::SUCCESS,
15
+
Err(e) => {
16
+
error!("Fatal error: {}", e);
17
+
ExitCode::FAILURE
18
+
}
19
+
}
20
+
}
21
+
22
+
async fn run() -> Result<(), Box<dyn std::error::Error>> {
23
+
let database_url = std::env::var("DATABASE_URL")
24
+
.map_err(|_| "DATABASE_URL environment variable must be set")?;
13
25
14
26
let pool = sqlx::postgres::PgPoolOptions::new()
15
-
.max_connections(5)
27
+
.max_connections(20)
28
+
.min_connections(2)
29
+
.acquire_timeout(std::time::Duration::from_secs(10))
30
+
.idle_timeout(std::time::Duration::from_secs(300))
31
+
.max_lifetime(std::time::Duration::from_secs(1800))
16
32
.connect(&database_url)
17
33
.await
18
-
.expect("Failed to connect to Postgres");
34
+
.map_err(|e| format!("Failed to connect to Postgres: {}", e))?;
19
35
20
36
sqlx::migrate!("./migrations")
21
37
.run(&pool)
22
38
.await
23
-
.expect("Failed to run migrations");
39
+
.map_err(|e| format!("Failed to run migrations: {}", e))?;
24
40
25
41
let state = AppState::new(pool.clone()).await;
26
42
···
50
66
51
67
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
52
68
info!("listening on {}", addr);
53
-
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
69
+
let listener = tokio::net::TcpListener::bind(addr)
70
+
.await
71
+
.map_err(|e| format!("Failed to bind to {}: {}", addr, e))?;
54
72
55
73
let server_result = axum::serve(listener, app)
56
74
.with_graceful_shutdown(shutdown_signal(shutdown_tx))
···
59
77
notification_handle.await.ok();
60
78
61
79
if let Err(e) = server_result {
62
-
tracing::error!("Server error: {}", e);
80
+
return Err(format!("Server error: {}", e).into());
63
81
}
82
+
83
+
Ok(())
64
84
}
65
85
66
86
async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) {
67
87
let ctrl_c = async {
68
-
tokio::signal::ctrl_c()
69
-
.await
70
-
.expect("Failed to install Ctrl+C handler");
88
+
match tokio::signal::ctrl_c().await {
89
+
Ok(()) => {}
90
+
Err(e) => {
91
+
error!("Failed to install Ctrl+C handler: {}", e);
92
+
}
93
+
}
71
94
};
72
95
73
96
#[cfg(unix)]
74
97
let terminate = async {
75
-
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
76
-
.expect("Failed to install signal handler")
77
-
.recv()
78
-
.await;
98
+
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
99
+
Ok(mut signal) => {
100
+
signal.recv().await;
101
+
}
102
+
Err(e) => {
103
+
error!("Failed to install SIGTERM handler: {}", e);
104
+
std::future::pending::<()>().await;
105
+
}
106
+
}
79
107
};
80
108
81
109
#[cfg(not(unix))]
-641
src/oauth/db.rs
-641
src/oauth/db.rs
···
1
-
use chrono::{DateTime, Utc};
2
-
use serde::{de::DeserializeOwned, Serialize};
3
-
use sqlx::PgPool;
4
-
5
-
use super::{
6
-
AuthorizationRequestParameters, ClientAuth, DeviceData, OAuthError, RequestData, TokenData,
7
-
AuthorizedClientData,
8
-
};
9
-
10
-
fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> {
11
-
serde_json::to_value(value).map_err(|e| {
12
-
tracing::error!("JSON serialization error: {}", e);
13
-
OAuthError::ServerError("Internal serialization error".to_string())
14
-
})
15
-
}
16
-
17
-
fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> {
18
-
serde_json::from_value(value).map_err(|e| {
19
-
tracing::error!("JSON deserialization error: {}", e);
20
-
OAuthError::ServerError("Internal data corruption".to_string())
21
-
})
22
-
}
23
-
24
-
pub async fn create_device(
25
-
pool: &PgPool,
26
-
device_id: &str,
27
-
data: &DeviceData,
28
-
) -> Result<(), OAuthError> {
29
-
sqlx::query!(
30
-
r#"
31
-
INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at)
32
-
VALUES ($1, $2, $3, $4, $5)
33
-
"#,
34
-
device_id,
35
-
data.session_id,
36
-
data.user_agent,
37
-
data.ip_address,
38
-
data.last_seen_at,
39
-
)
40
-
.execute(pool)
41
-
.await?;
42
-
43
-
Ok(())
44
-
}
45
-
46
-
pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> {
47
-
let row = sqlx::query!(
48
-
r#"
49
-
SELECT session_id, user_agent, ip_address, last_seen_at
50
-
FROM oauth_device
51
-
WHERE id = $1
52
-
"#,
53
-
device_id
54
-
)
55
-
.fetch_optional(pool)
56
-
.await?;
57
-
58
-
Ok(row.map(|r| DeviceData {
59
-
session_id: r.session_id,
60
-
user_agent: r.user_agent,
61
-
ip_address: r.ip_address,
62
-
last_seen_at: r.last_seen_at,
63
-
}))
64
-
}
65
-
66
-
pub async fn update_device_last_seen(
67
-
pool: &PgPool,
68
-
device_id: &str,
69
-
) -> Result<(), OAuthError> {
70
-
sqlx::query!(
71
-
r#"
72
-
UPDATE oauth_device
73
-
SET last_seen_at = NOW()
74
-
WHERE id = $1
75
-
"#,
76
-
device_id
77
-
)
78
-
.execute(pool)
79
-
.await?;
80
-
81
-
Ok(())
82
-
}
83
-
84
-
pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
85
-
sqlx::query!(
86
-
r#"
87
-
DELETE FROM oauth_device WHERE id = $1
88
-
"#,
89
-
device_id
90
-
)
91
-
.execute(pool)
92
-
.await?;
93
-
94
-
Ok(())
95
-
}
96
-
97
-
pub async fn create_authorization_request(
98
-
pool: &PgPool,
99
-
request_id: &str,
100
-
data: &RequestData,
101
-
) -> Result<(), OAuthError> {
102
-
let client_auth_json = match &data.client_auth {
103
-
Some(ca) => Some(to_json(ca)?),
104
-
None => None,
105
-
};
106
-
let parameters_json = to_json(&data.parameters)?;
107
-
108
-
sqlx::query!(
109
-
r#"
110
-
INSERT INTO oauth_authorization_request
111
-
(id, did, device_id, client_id, client_auth, parameters, expires_at, code)
112
-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
113
-
"#,
114
-
request_id,
115
-
data.did,
116
-
data.device_id,
117
-
data.client_id,
118
-
client_auth_json,
119
-
parameters_json,
120
-
data.expires_at,
121
-
data.code,
122
-
)
123
-
.execute(pool)
124
-
.await?;
125
-
126
-
Ok(())
127
-
}
128
-
129
-
pub async fn get_authorization_request(
130
-
pool: &PgPool,
131
-
request_id: &str,
132
-
) -> Result<Option<RequestData>, OAuthError> {
133
-
let row = sqlx::query!(
134
-
r#"
135
-
SELECT did, device_id, client_id, client_auth, parameters, expires_at, code
136
-
FROM oauth_authorization_request
137
-
WHERE id = $1
138
-
"#,
139
-
request_id
140
-
)
141
-
.fetch_optional(pool)
142
-
.await?;
143
-
144
-
match row {
145
-
Some(r) => {
146
-
let client_auth: Option<ClientAuth> = match r.client_auth {
147
-
Some(v) => Some(from_json(v)?),
148
-
None => None,
149
-
};
150
-
let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
151
-
152
-
Ok(Some(RequestData {
153
-
client_id: r.client_id,
154
-
client_auth,
155
-
parameters,
156
-
expires_at: r.expires_at,
157
-
did: r.did,
158
-
device_id: r.device_id,
159
-
code: r.code,
160
-
}))
161
-
}
162
-
None => Ok(None),
163
-
}
164
-
}
165
-
166
-
pub async fn update_authorization_request(
167
-
pool: &PgPool,
168
-
request_id: &str,
169
-
did: &str,
170
-
device_id: Option<&str>,
171
-
code: &str,
172
-
) -> Result<(), OAuthError> {
173
-
sqlx::query!(
174
-
r#"
175
-
UPDATE oauth_authorization_request
176
-
SET did = $2, device_id = $3, code = $4
177
-
WHERE id = $1
178
-
"#,
179
-
request_id,
180
-
did,
181
-
device_id,
182
-
code
183
-
)
184
-
.execute(pool)
185
-
.await?;
186
-
187
-
Ok(())
188
-
}
189
-
190
-
pub async fn consume_authorization_request_by_code(
191
-
pool: &PgPool,
192
-
code: &str,
193
-
) -> Result<Option<RequestData>, OAuthError> {
194
-
let row = sqlx::query!(
195
-
r#"
196
-
DELETE FROM oauth_authorization_request
197
-
WHERE code = $1
198
-
RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code
199
-
"#,
200
-
code
201
-
)
202
-
.fetch_optional(pool)
203
-
.await?;
204
-
205
-
match row {
206
-
Some(r) => {
207
-
let client_auth: Option<ClientAuth> = match r.client_auth {
208
-
Some(v) => Some(from_json(v)?),
209
-
None => None,
210
-
};
211
-
let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
212
-
213
-
Ok(Some(RequestData {
214
-
client_id: r.client_id,
215
-
client_auth,
216
-
parameters,
217
-
expires_at: r.expires_at,
218
-
did: r.did,
219
-
device_id: r.device_id,
220
-
code: r.code,
221
-
}))
222
-
}
223
-
None => Ok(None),
224
-
}
225
-
}
226
-
227
-
pub async fn delete_authorization_request(
228
-
pool: &PgPool,
229
-
request_id: &str,
230
-
) -> Result<(), OAuthError> {
231
-
sqlx::query!(
232
-
r#"
233
-
DELETE FROM oauth_authorization_request WHERE id = $1
234
-
"#,
235
-
request_id
236
-
)
237
-
.execute(pool)
238
-
.await?;
239
-
240
-
Ok(())
241
-
}
242
-
243
-
pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> {
244
-
let result = sqlx::query!(
245
-
r#"
246
-
DELETE FROM oauth_authorization_request
247
-
WHERE expires_at < NOW()
248
-
"#
249
-
)
250
-
.execute(pool)
251
-
.await?;
252
-
253
-
Ok(result.rows_affected())
254
-
}
255
-
256
-
pub async fn create_token(
257
-
pool: &PgPool,
258
-
data: &TokenData,
259
-
) -> Result<i32, OAuthError> {
260
-
let client_auth_json = to_json(&data.client_auth)?;
261
-
let parameters_json = to_json(&data.parameters)?;
262
-
263
-
let row = sqlx::query!(
264
-
r#"
265
-
INSERT INTO oauth_token
266
-
(did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
267
-
device_id, parameters, details, code, current_refresh_token, scope)
268
-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
269
-
RETURNING id
270
-
"#,
271
-
data.did,
272
-
data.token_id,
273
-
data.created_at,
274
-
data.updated_at,
275
-
data.expires_at,
276
-
data.client_id,
277
-
client_auth_json,
278
-
data.device_id,
279
-
parameters_json,
280
-
data.details,
281
-
data.code,
282
-
data.current_refresh_token,
283
-
data.scope,
284
-
)
285
-
.fetch_one(pool)
286
-
.await?;
287
-
288
-
Ok(row.id)
289
-
}
290
-
291
-
pub async fn get_token_by_id(
292
-
pool: &PgPool,
293
-
token_id: &str,
294
-
) -> Result<Option<TokenData>, OAuthError> {
295
-
let row = sqlx::query!(
296
-
r#"
297
-
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
298
-
device_id, parameters, details, code, current_refresh_token, scope
299
-
FROM oauth_token
300
-
WHERE token_id = $1
301
-
"#,
302
-
token_id
303
-
)
304
-
.fetch_optional(pool)
305
-
.await?;
306
-
307
-
match row {
308
-
Some(r) => Ok(Some(TokenData {
309
-
did: r.did,
310
-
token_id: r.token_id,
311
-
created_at: r.created_at,
312
-
updated_at: r.updated_at,
313
-
expires_at: r.expires_at,
314
-
client_id: r.client_id,
315
-
client_auth: from_json(r.client_auth)?,
316
-
device_id: r.device_id,
317
-
parameters: from_json(r.parameters)?,
318
-
details: r.details,
319
-
code: r.code,
320
-
current_refresh_token: r.current_refresh_token,
321
-
scope: r.scope,
322
-
})),
323
-
None => Ok(None),
324
-
}
325
-
}
326
-
327
-
pub async fn get_token_by_refresh_token(
328
-
pool: &PgPool,
329
-
refresh_token: &str,
330
-
) -> Result<Option<(i32, TokenData)>, OAuthError> {
331
-
let row = sqlx::query!(
332
-
r#"
333
-
SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
334
-
device_id, parameters, details, code, current_refresh_token, scope
335
-
FROM oauth_token
336
-
WHERE current_refresh_token = $1
337
-
"#,
338
-
refresh_token
339
-
)
340
-
.fetch_optional(pool)
341
-
.await?;
342
-
343
-
match row {
344
-
Some(r) => Ok(Some((
345
-
r.id,
346
-
TokenData {
347
-
did: r.did,
348
-
token_id: r.token_id,
349
-
created_at: r.created_at,
350
-
updated_at: r.updated_at,
351
-
expires_at: r.expires_at,
352
-
client_id: r.client_id,
353
-
client_auth: from_json(r.client_auth)?,
354
-
device_id: r.device_id,
355
-
parameters: from_json(r.parameters)?,
356
-
details: r.details,
357
-
code: r.code,
358
-
current_refresh_token: r.current_refresh_token,
359
-
scope: r.scope,
360
-
},
361
-
))),
362
-
None => Ok(None),
363
-
}
364
-
}
365
-
366
-
pub async fn rotate_token(
367
-
pool: &PgPool,
368
-
old_db_id: i32,
369
-
new_token_id: &str,
370
-
new_refresh_token: &str,
371
-
new_expires_at: DateTime<Utc>,
372
-
) -> Result<(), OAuthError> {
373
-
let mut tx = pool.begin().await?;
374
-
375
-
let old_refresh = sqlx::query_scalar!(
376
-
r#"
377
-
SELECT current_refresh_token FROM oauth_token WHERE id = $1
378
-
"#,
379
-
old_db_id
380
-
)
381
-
.fetch_one(&mut *tx)
382
-
.await?;
383
-
384
-
if let Some(old_rt) = old_refresh {
385
-
sqlx::query!(
386
-
r#"
387
-
INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
388
-
VALUES ($1, $2)
389
-
"#,
390
-
old_rt,
391
-
old_db_id
392
-
)
393
-
.execute(&mut *tx)
394
-
.await?;
395
-
}
396
-
397
-
sqlx::query!(
398
-
r#"
399
-
UPDATE oauth_token
400
-
SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
401
-
WHERE id = $1
402
-
"#,
403
-
old_db_id,
404
-
new_token_id,
405
-
new_refresh_token,
406
-
new_expires_at
407
-
)
408
-
.execute(&mut *tx)
409
-
.await?;
410
-
411
-
tx.commit().await?;
412
-
Ok(())
413
-
}
414
-
415
-
pub async fn check_refresh_token_used(
416
-
pool: &PgPool,
417
-
refresh_token: &str,
418
-
) -> Result<Option<i32>, OAuthError> {
419
-
let row = sqlx::query_scalar!(
420
-
r#"
421
-
SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
422
-
"#,
423
-
refresh_token
424
-
)
425
-
.fetch_optional(pool)
426
-
.await?;
427
-
428
-
Ok(row)
429
-
}
430
-
431
-
pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
432
-
sqlx::query!(
433
-
r#"
434
-
DELETE FROM oauth_token WHERE token_id = $1
435
-
"#,
436
-
token_id
437
-
)
438
-
.execute(pool)
439
-
.await?;
440
-
441
-
Ok(())
442
-
}
443
-
444
-
pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
445
-
sqlx::query!(
446
-
r#"
447
-
DELETE FROM oauth_token WHERE id = $1
448
-
"#,
449
-
db_id
450
-
)
451
-
.execute(pool)
452
-
.await?;
453
-
454
-
Ok(())
455
-
}
456
-
457
-
pub async fn upsert_account_device(
458
-
pool: &PgPool,
459
-
did: &str,
460
-
device_id: &str,
461
-
) -> Result<(), OAuthError> {
462
-
sqlx::query!(
463
-
r#"
464
-
INSERT INTO oauth_account_device (did, device_id, created_at, updated_at)
465
-
VALUES ($1, $2, NOW(), NOW())
466
-
ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW()
467
-
"#,
468
-
did,
469
-
device_id
470
-
)
471
-
.execute(pool)
472
-
.await?;
473
-
474
-
Ok(())
475
-
}
476
-
477
-
pub async fn upsert_authorized_client(
478
-
pool: &PgPool,
479
-
did: &str,
480
-
client_id: &str,
481
-
data: &AuthorizedClientData,
482
-
) -> Result<(), OAuthError> {
483
-
let data_json = to_json(data)?;
484
-
485
-
sqlx::query!(
486
-
r#"
487
-
INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data)
488
-
VALUES ($1, $2, NOW(), NOW(), $3)
489
-
ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3
490
-
"#,
491
-
did,
492
-
client_id,
493
-
data_json
494
-
)
495
-
.execute(pool)
496
-
.await?;
497
-
498
-
Ok(())
499
-
}
500
-
501
-
pub async fn get_authorized_client(
502
-
pool: &PgPool,
503
-
did: &str,
504
-
client_id: &str,
505
-
) -> Result<Option<AuthorizedClientData>, OAuthError> {
506
-
let row = sqlx::query_scalar!(
507
-
r#"
508
-
SELECT data FROM oauth_authorized_client
509
-
WHERE did = $1 AND client_id = $2
510
-
"#,
511
-
did,
512
-
client_id
513
-
)
514
-
.fetch_optional(pool)
515
-
.await?;
516
-
517
-
match row {
518
-
Some(v) => Ok(Some(from_json(v)?)),
519
-
None => Ok(None),
520
-
}
521
-
}
522
-
523
-
pub async fn list_tokens_for_user(
524
-
pool: &PgPool,
525
-
did: &str,
526
-
) -> Result<Vec<TokenData>, OAuthError> {
527
-
let rows = sqlx::query!(
528
-
r#"
529
-
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
530
-
device_id, parameters, details, code, current_refresh_token, scope
531
-
FROM oauth_token
532
-
WHERE did = $1
533
-
"#,
534
-
did
535
-
)
536
-
.fetch_all(pool)
537
-
.await?;
538
-
539
-
let mut tokens = Vec::with_capacity(rows.len());
540
-
for r in rows {
541
-
tokens.push(TokenData {
542
-
did: r.did,
543
-
token_id: r.token_id,
544
-
created_at: r.created_at,
545
-
updated_at: r.updated_at,
546
-
expires_at: r.expires_at,
547
-
client_id: r.client_id,
548
-
client_auth: from_json(r.client_auth)?,
549
-
device_id: r.device_id,
550
-
parameters: from_json(r.parameters)?,
551
-
details: r.details,
552
-
code: r.code,
553
-
current_refresh_token: r.current_refresh_token,
554
-
scope: r.scope,
555
-
});
556
-
}
557
-
Ok(tokens)
558
-
}
559
-
560
-
pub async fn check_and_record_dpop_jti(
561
-
pool: &PgPool,
562
-
jti: &str,
563
-
) -> Result<bool, OAuthError> {
564
-
let result = sqlx::query!(
565
-
r#"
566
-
INSERT INTO oauth_dpop_jti (jti)
567
-
VALUES ($1)
568
-
ON CONFLICT (jti) DO NOTHING
569
-
"#,
570
-
jti
571
-
)
572
-
.execute(pool)
573
-
.await?;
574
-
575
-
Ok(result.rows_affected() > 0)
576
-
}
577
-
578
-
pub async fn cleanup_expired_dpop_jtis(
579
-
pool: &PgPool,
580
-
max_age_secs: i64,
581
-
) -> Result<u64, OAuthError> {
582
-
let result = sqlx::query!(
583
-
r#"
584
-
DELETE FROM oauth_dpop_jti
585
-
WHERE created_at < NOW() - INTERVAL '1 second' * $1
586
-
"#,
587
-
max_age_secs as f64
588
-
)
589
-
.execute(pool)
590
-
.await?;
591
-
592
-
Ok(result.rows_affected())
593
-
}
594
-
595
-
pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
596
-
let count = sqlx::query_scalar!(
597
-
r#"
598
-
SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
599
-
"#,
600
-
did
601
-
)
602
-
.fetch_one(pool)
603
-
.await?;
604
-
605
-
Ok(count)
606
-
}
607
-
608
-
pub async fn delete_oldest_tokens_for_user(
609
-
pool: &PgPool,
610
-
did: &str,
611
-
keep_count: i64,
612
-
) -> Result<u64, OAuthError> {
613
-
let result = sqlx::query!(
614
-
r#"
615
-
DELETE FROM oauth_token
616
-
WHERE id IN (
617
-
SELECT id FROM oauth_token
618
-
WHERE did = $1
619
-
ORDER BY updated_at ASC
620
-
OFFSET $2
621
-
)
622
-
"#,
623
-
did,
624
-
keep_count
625
-
)
626
-
.execute(pool)
627
-
.await?;
628
-
629
-
Ok(result.rows_affected())
630
-
}
631
-
632
-
const MAX_TOKENS_PER_USER: i64 = 100;
633
-
634
-
pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
635
-
let count = count_tokens_for_user(pool, did).await?;
636
-
if count > MAX_TOKENS_PER_USER {
637
-
let to_keep = MAX_TOKENS_PER_USER - 1;
638
-
delete_oldest_tokens_for_user(pool, did, to_keep).await?;
639
-
}
640
-
Ok(())
641
-
}
+50
src/oauth/db/client.rs
+50
src/oauth/db/client.rs
···
1
+
use sqlx::PgPool;
2
+
3
+
use super::super::{AuthorizedClientData, OAuthError};
4
+
use super::helpers::{from_json, to_json};
5
+
6
+
pub async fn upsert_authorized_client(
7
+
pool: &PgPool,
8
+
did: &str,
9
+
client_id: &str,
10
+
data: &AuthorizedClientData,
11
+
) -> Result<(), OAuthError> {
12
+
let data_json = to_json(data)?;
13
+
14
+
sqlx::query!(
15
+
r#"
16
+
INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data)
17
+
VALUES ($1, $2, NOW(), NOW(), $3)
18
+
ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3
19
+
"#,
20
+
did,
21
+
client_id,
22
+
data_json
23
+
)
24
+
.execute(pool)
25
+
.await?;
26
+
27
+
Ok(())
28
+
}
29
+
30
+
pub async fn get_authorized_client(
31
+
pool: &PgPool,
32
+
did: &str,
33
+
client_id: &str,
34
+
) -> Result<Option<AuthorizedClientData>, OAuthError> {
35
+
let row = sqlx::query_scalar!(
36
+
r#"
37
+
SELECT data FROM oauth_authorized_client
38
+
WHERE did = $1 AND client_id = $2
39
+
"#,
40
+
did,
41
+
client_id
42
+
)
43
+
.fetch_optional(pool)
44
+
.await?;
45
+
46
+
match row {
47
+
Some(v) => Ok(Some(from_json(v)?)),
48
+
None => Ok(None),
49
+
}
50
+
}
+96
src/oauth/db/device.rs
+96
src/oauth/db/device.rs
···
1
+
use sqlx::PgPool;
2
+
3
+
use super::super::{DeviceData, OAuthError};
4
+
5
+
pub async fn create_device(
6
+
pool: &PgPool,
7
+
device_id: &str,
8
+
data: &DeviceData,
9
+
) -> Result<(), OAuthError> {
10
+
sqlx::query!(
11
+
r#"
12
+
INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at)
13
+
VALUES ($1, $2, $3, $4, $5)
14
+
"#,
15
+
device_id,
16
+
data.session_id,
17
+
data.user_agent,
18
+
data.ip_address,
19
+
data.last_seen_at,
20
+
)
21
+
.execute(pool)
22
+
.await?;
23
+
24
+
Ok(())
25
+
}
26
+
27
+
pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> {
28
+
let row = sqlx::query!(
29
+
r#"
30
+
SELECT session_id, user_agent, ip_address, last_seen_at
31
+
FROM oauth_device
32
+
WHERE id = $1
33
+
"#,
34
+
device_id
35
+
)
36
+
.fetch_optional(pool)
37
+
.await?;
38
+
39
+
Ok(row.map(|r| DeviceData {
40
+
session_id: r.session_id,
41
+
user_agent: r.user_agent,
42
+
ip_address: r.ip_address,
43
+
last_seen_at: r.last_seen_at,
44
+
}))
45
+
}
46
+
47
+
pub async fn update_device_last_seen(
48
+
pool: &PgPool,
49
+
device_id: &str,
50
+
) -> Result<(), OAuthError> {
51
+
sqlx::query!(
52
+
r#"
53
+
UPDATE oauth_device
54
+
SET last_seen_at = NOW()
55
+
WHERE id = $1
56
+
"#,
57
+
device_id
58
+
)
59
+
.execute(pool)
60
+
.await?;
61
+
62
+
Ok(())
63
+
}
64
+
65
+
pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
66
+
sqlx::query!(
67
+
r#"
68
+
DELETE FROM oauth_device WHERE id = $1
69
+
"#,
70
+
device_id
71
+
)
72
+
.execute(pool)
73
+
.await?;
74
+
75
+
Ok(())
76
+
}
77
+
78
+
pub async fn upsert_account_device(
79
+
pool: &PgPool,
80
+
did: &str,
81
+
device_id: &str,
82
+
) -> Result<(), OAuthError> {
83
+
sqlx::query!(
84
+
r#"
85
+
INSERT INTO oauth_account_device (did, device_id, created_at, updated_at)
86
+
VALUES ($1, $2, NOW(), NOW())
87
+
ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW()
88
+
"#,
89
+
did,
90
+
device_id
91
+
)
92
+
.execute(pool)
93
+
.await?;
94
+
95
+
Ok(())
96
+
}
+38
src/oauth/db/dpop.rs
+38
src/oauth/db/dpop.rs
···
1
+
use sqlx::PgPool;
2
+
3
+
use super::super::OAuthError;
4
+
5
+
pub async fn check_and_record_dpop_jti(
6
+
pool: &PgPool,
7
+
jti: &str,
8
+
) -> Result<bool, OAuthError> {
9
+
let result = sqlx::query!(
10
+
r#"
11
+
INSERT INTO oauth_dpop_jti (jti)
12
+
VALUES ($1)
13
+
ON CONFLICT (jti) DO NOTHING
14
+
"#,
15
+
jti
16
+
)
17
+
.execute(pool)
18
+
.await?;
19
+
20
+
Ok(result.rows_affected() > 0)
21
+
}
22
+
23
+
pub async fn cleanup_expired_dpop_jtis(
24
+
pool: &PgPool,
25
+
max_age_secs: i64,
26
+
) -> Result<u64, OAuthError> {
27
+
let result = sqlx::query!(
28
+
r#"
29
+
DELETE FROM oauth_dpop_jti
30
+
WHERE created_at < NOW() - INTERVAL '1 second' * $1
31
+
"#,
32
+
max_age_secs as f64
33
+
)
34
+
.execute(pool)
35
+
.await?;
36
+
37
+
Ok(result.rows_affected())
38
+
}
+17
src/oauth/db/helpers.rs
+17
src/oauth/db/helpers.rs
···
1
+
use serde::{de::DeserializeOwned, Serialize};
2
+
3
+
use super::super::OAuthError;
4
+
5
+
pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> {
6
+
serde_json::to_value(value).map_err(|e| {
7
+
tracing::error!("JSON serialization error: {}", e);
8
+
OAuthError::ServerError("Internal serialization error".to_string())
9
+
})
10
+
}
11
+
12
+
pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> {
13
+
serde_json::from_value(value).map_err(|e| {
14
+
tracing::error!("JSON deserialization error: {}", e);
15
+
OAuthError::ServerError("Internal data corruption".to_string())
16
+
})
17
+
}
+22
src/oauth/db/mod.rs
+22
src/oauth/db/mod.rs
···
1
+
mod client;
2
+
mod device;
3
+
mod dpop;
4
+
mod helpers;
5
+
mod request;
6
+
mod token;
7
+
8
+
pub use client::{get_authorized_client, upsert_authorized_client};
9
+
pub use device::{
10
+
create_device, delete_device, get_device, update_device_last_seen, upsert_account_device,
11
+
};
12
+
pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis};
13
+
pub use request::{
14
+
consume_authorization_request_by_code, create_authorization_request,
15
+
delete_authorization_request, delete_expired_authorization_requests, get_authorization_request,
16
+
update_authorization_request,
17
+
};
18
+
pub use token::{
19
+
check_refresh_token_used, count_tokens_for_user, create_token, delete_oldest_tokens_for_user,
20
+
delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
21
+
get_token_by_refresh_token, list_tokens_for_user, rotate_token,
22
+
};
+163
src/oauth/db/request.rs
+163
src/oauth/db/request.rs
···
1
+
use sqlx::PgPool;
2
+
3
+
use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData};
4
+
use super::helpers::{from_json, to_json};
5
+
6
+
pub async fn create_authorization_request(
7
+
pool: &PgPool,
8
+
request_id: &str,
9
+
data: &RequestData,
10
+
) -> Result<(), OAuthError> {
11
+
let client_auth_json = match &data.client_auth {
12
+
Some(ca) => Some(to_json(ca)?),
13
+
None => None,
14
+
};
15
+
let parameters_json = to_json(&data.parameters)?;
16
+
17
+
sqlx::query!(
18
+
r#"
19
+
INSERT INTO oauth_authorization_request
20
+
(id, did, device_id, client_id, client_auth, parameters, expires_at, code)
21
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
22
+
"#,
23
+
request_id,
24
+
data.did,
25
+
data.device_id,
26
+
data.client_id,
27
+
client_auth_json,
28
+
parameters_json,
29
+
data.expires_at,
30
+
data.code,
31
+
)
32
+
.execute(pool)
33
+
.await?;
34
+
35
+
Ok(())
36
+
}
37
+
38
+
pub async fn get_authorization_request(
39
+
pool: &PgPool,
40
+
request_id: &str,
41
+
) -> Result<Option<RequestData>, OAuthError> {
42
+
let row = sqlx::query!(
43
+
r#"
44
+
SELECT did, device_id, client_id, client_auth, parameters, expires_at, code
45
+
FROM oauth_authorization_request
46
+
WHERE id = $1
47
+
"#,
48
+
request_id
49
+
)
50
+
.fetch_optional(pool)
51
+
.await?;
52
+
53
+
match row {
54
+
Some(r) => {
55
+
let client_auth: Option<ClientAuth> = match r.client_auth {
56
+
Some(v) => Some(from_json(v)?),
57
+
None => None,
58
+
};
59
+
let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
60
+
61
+
Ok(Some(RequestData {
62
+
client_id: r.client_id,
63
+
client_auth,
64
+
parameters,
65
+
expires_at: r.expires_at,
66
+
did: r.did,
67
+
device_id: r.device_id,
68
+
code: r.code,
69
+
}))
70
+
}
71
+
None => Ok(None),
72
+
}
73
+
}
74
+
75
+
pub async fn update_authorization_request(
76
+
pool: &PgPool,
77
+
request_id: &str,
78
+
did: &str,
79
+
device_id: Option<&str>,
80
+
code: &str,
81
+
) -> Result<(), OAuthError> {
82
+
sqlx::query!(
83
+
r#"
84
+
UPDATE oauth_authorization_request
85
+
SET did = $2, device_id = $3, code = $4
86
+
WHERE id = $1
87
+
"#,
88
+
request_id,
89
+
did,
90
+
device_id,
91
+
code
92
+
)
93
+
.execute(pool)
94
+
.await?;
95
+
96
+
Ok(())
97
+
}
98
+
99
+
pub async fn consume_authorization_request_by_code(
100
+
pool: &PgPool,
101
+
code: &str,
102
+
) -> Result<Option<RequestData>, OAuthError> {
103
+
let row = sqlx::query!(
104
+
r#"
105
+
DELETE FROM oauth_authorization_request
106
+
WHERE code = $1
107
+
RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code
108
+
"#,
109
+
code
110
+
)
111
+
.fetch_optional(pool)
112
+
.await?;
113
+
114
+
match row {
115
+
Some(r) => {
116
+
let client_auth: Option<ClientAuth> = match r.client_auth {
117
+
Some(v) => Some(from_json(v)?),
118
+
None => None,
119
+
};
120
+
let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
121
+
122
+
Ok(Some(RequestData {
123
+
client_id: r.client_id,
124
+
client_auth,
125
+
parameters,
126
+
expires_at: r.expires_at,
127
+
did: r.did,
128
+
device_id: r.device_id,
129
+
code: r.code,
130
+
}))
131
+
}
132
+
None => Ok(None),
133
+
}
134
+
}
135
+
136
+
pub async fn delete_authorization_request(
137
+
pool: &PgPool,
138
+
request_id: &str,
139
+
) -> Result<(), OAuthError> {
140
+
sqlx::query!(
141
+
r#"
142
+
DELETE FROM oauth_authorization_request WHERE id = $1
143
+
"#,
144
+
request_id
145
+
)
146
+
.execute(pool)
147
+
.await?;
148
+
149
+
Ok(())
150
+
}
151
+
152
+
pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> {
153
+
let result = sqlx::query!(
154
+
r#"
155
+
DELETE FROM oauth_authorization_request
156
+
WHERE expires_at < NOW()
157
+
"#
158
+
)
159
+
.execute(pool)
160
+
.await?;
161
+
162
+
Ok(result.rows_affected())
163
+
}
+291
src/oauth/db/token.rs
+291
src/oauth/db/token.rs
···
1
+
use chrono::{DateTime, Utc};
2
+
use sqlx::PgPool;
3
+
4
+
use super::super::{OAuthError, TokenData};
5
+
use super::helpers::{from_json, to_json};
6
+
7
+
pub async fn create_token(
8
+
pool: &PgPool,
9
+
data: &TokenData,
10
+
) -> Result<i32, OAuthError> {
11
+
let client_auth_json = to_json(&data.client_auth)?;
12
+
let parameters_json = to_json(&data.parameters)?;
13
+
14
+
let row = sqlx::query!(
15
+
r#"
16
+
INSERT INTO oauth_token
17
+
(did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
18
+
device_id, parameters, details, code, current_refresh_token, scope)
19
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
20
+
RETURNING id
21
+
"#,
22
+
data.did,
23
+
data.token_id,
24
+
data.created_at,
25
+
data.updated_at,
26
+
data.expires_at,
27
+
data.client_id,
28
+
client_auth_json,
29
+
data.device_id,
30
+
parameters_json,
31
+
data.details,
32
+
data.code,
33
+
data.current_refresh_token,
34
+
data.scope,
35
+
)
36
+
.fetch_one(pool)
37
+
.await?;
38
+
39
+
Ok(row.id)
40
+
}
41
+
42
+
pub async fn get_token_by_id(
43
+
pool: &PgPool,
44
+
token_id: &str,
45
+
) -> Result<Option<TokenData>, OAuthError> {
46
+
let row = sqlx::query!(
47
+
r#"
48
+
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
49
+
device_id, parameters, details, code, current_refresh_token, scope
50
+
FROM oauth_token
51
+
WHERE token_id = $1
52
+
"#,
53
+
token_id
54
+
)
55
+
.fetch_optional(pool)
56
+
.await?;
57
+
58
+
match row {
59
+
Some(r) => Ok(Some(TokenData {
60
+
did: r.did,
61
+
token_id: r.token_id,
62
+
created_at: r.created_at,
63
+
updated_at: r.updated_at,
64
+
expires_at: r.expires_at,
65
+
client_id: r.client_id,
66
+
client_auth: from_json(r.client_auth)?,
67
+
device_id: r.device_id,
68
+
parameters: from_json(r.parameters)?,
69
+
details: r.details,
70
+
code: r.code,
71
+
current_refresh_token: r.current_refresh_token,
72
+
scope: r.scope,
73
+
})),
74
+
None => Ok(None),
75
+
}
76
+
}
77
+
78
+
pub async fn get_token_by_refresh_token(
79
+
pool: &PgPool,
80
+
refresh_token: &str,
81
+
) -> Result<Option<(i32, TokenData)>, OAuthError> {
82
+
let row = sqlx::query!(
83
+
r#"
84
+
SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
85
+
device_id, parameters, details, code, current_refresh_token, scope
86
+
FROM oauth_token
87
+
WHERE current_refresh_token = $1
88
+
"#,
89
+
refresh_token
90
+
)
91
+
.fetch_optional(pool)
92
+
.await?;
93
+
94
+
match row {
95
+
Some(r) => Ok(Some((
96
+
r.id,
97
+
TokenData {
98
+
did: r.did,
99
+
token_id: r.token_id,
100
+
created_at: r.created_at,
101
+
updated_at: r.updated_at,
102
+
expires_at: r.expires_at,
103
+
client_id: r.client_id,
104
+
client_auth: from_json(r.client_auth)?,
105
+
device_id: r.device_id,
106
+
parameters: from_json(r.parameters)?,
107
+
details: r.details,
108
+
code: r.code,
109
+
current_refresh_token: r.current_refresh_token,
110
+
scope: r.scope,
111
+
},
112
+
))),
113
+
None => Ok(None),
114
+
}
115
+
}
116
+
117
+
pub async fn rotate_token(
118
+
pool: &PgPool,
119
+
old_db_id: i32,
120
+
new_token_id: &str,
121
+
new_refresh_token: &str,
122
+
new_expires_at: DateTime<Utc>,
123
+
) -> Result<(), OAuthError> {
124
+
let mut tx = pool.begin().await?;
125
+
126
+
let old_refresh = sqlx::query_scalar!(
127
+
r#"
128
+
SELECT current_refresh_token FROM oauth_token WHERE id = $1
129
+
"#,
130
+
old_db_id
131
+
)
132
+
.fetch_one(&mut *tx)
133
+
.await?;
134
+
135
+
if let Some(old_rt) = old_refresh {
136
+
sqlx::query!(
137
+
r#"
138
+
INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
139
+
VALUES ($1, $2)
140
+
"#,
141
+
old_rt,
142
+
old_db_id
143
+
)
144
+
.execute(&mut *tx)
145
+
.await?;
146
+
}
147
+
148
+
sqlx::query!(
149
+
r#"
150
+
UPDATE oauth_token
151
+
SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
152
+
WHERE id = $1
153
+
"#,
154
+
old_db_id,
155
+
new_token_id,
156
+
new_refresh_token,
157
+
new_expires_at
158
+
)
159
+
.execute(&mut *tx)
160
+
.await?;
161
+
162
+
tx.commit().await?;
163
+
Ok(())
164
+
}
165
+
166
+
pub async fn check_refresh_token_used(
167
+
pool: &PgPool,
168
+
refresh_token: &str,
169
+
) -> Result<Option<i32>, OAuthError> {
170
+
let row = sqlx::query_scalar!(
171
+
r#"
172
+
SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
173
+
"#,
174
+
refresh_token
175
+
)
176
+
.fetch_optional(pool)
177
+
.await?;
178
+
179
+
Ok(row)
180
+
}
181
+
182
+
pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
183
+
sqlx::query!(
184
+
r#"
185
+
DELETE FROM oauth_token WHERE token_id = $1
186
+
"#,
187
+
token_id
188
+
)
189
+
.execute(pool)
190
+
.await?;
191
+
192
+
Ok(())
193
+
}
194
+
195
+
pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
196
+
sqlx::query!(
197
+
r#"
198
+
DELETE FROM oauth_token WHERE id = $1
199
+
"#,
200
+
db_id
201
+
)
202
+
.execute(pool)
203
+
.await?;
204
+
205
+
Ok(())
206
+
}
207
+
208
+
pub async fn list_tokens_for_user(
209
+
pool: &PgPool,
210
+
did: &str,
211
+
) -> Result<Vec<TokenData>, OAuthError> {
212
+
let rows = sqlx::query!(
213
+
r#"
214
+
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
215
+
device_id, parameters, details, code, current_refresh_token, scope
216
+
FROM oauth_token
217
+
WHERE did = $1
218
+
"#,
219
+
did
220
+
)
221
+
.fetch_all(pool)
222
+
.await?;
223
+
224
+
let mut tokens = Vec::with_capacity(rows.len());
225
+
for r in rows {
226
+
tokens.push(TokenData {
227
+
did: r.did,
228
+
token_id: r.token_id,
229
+
created_at: r.created_at,
230
+
updated_at: r.updated_at,
231
+
expires_at: r.expires_at,
232
+
client_id: r.client_id,
233
+
client_auth: from_json(r.client_auth)?,
234
+
device_id: r.device_id,
235
+
parameters: from_json(r.parameters)?,
236
+
details: r.details,
237
+
code: r.code,
238
+
current_refresh_token: r.current_refresh_token,
239
+
scope: r.scope,
240
+
});
241
+
}
242
+
Ok(tokens)
243
+
}
244
+
245
+
pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
246
+
let count = sqlx::query_scalar!(
247
+
r#"
248
+
SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
249
+
"#,
250
+
did
251
+
)
252
+
.fetch_one(pool)
253
+
.await?;
254
+
255
+
Ok(count)
256
+
}
257
+
258
+
pub async fn delete_oldest_tokens_for_user(
259
+
pool: &PgPool,
260
+
did: &str,
261
+
keep_count: i64,
262
+
) -> Result<u64, OAuthError> {
263
+
let result = sqlx::query!(
264
+
r#"
265
+
DELETE FROM oauth_token
266
+
WHERE id IN (
267
+
SELECT id FROM oauth_token
268
+
WHERE did = $1
269
+
ORDER BY updated_at ASC
270
+
OFFSET $2
271
+
)
272
+
"#,
273
+
did,
274
+
keep_count
275
+
)
276
+
.execute(pool)
277
+
.await?;
278
+
279
+
Ok(result.rows_affected())
280
+
}
281
+
282
+
const MAX_TOKENS_PER_USER: i64 = 100;
283
+
284
+
pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
285
+
let count = count_tokens_for_user(pool, did).await?;
286
+
if count > MAX_TOKENS_PER_USER {
287
+
let to_keep = MAX_TOKENS_PER_USER - 1;
288
+
delete_oldest_tokens_for_user(pool, did, to_keep).await?;
289
+
}
290
+
Ok(())
291
+
}
+8
-10
src/oauth/dpop.rs
+8
-10
src/oauth/dpop.rs
···
237
237
false,
238
238
);
239
239
240
-
let affine = AffinePoint::from_encoded_point(&point);
241
-
if affine.is_none().into() {
242
-
return Err(OAuthError::InvalidDpopProof("Invalid EC point".to_string()));
243
-
}
240
+
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
241
+
let affine = affine_opt
242
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
244
243
245
-
let verifying_key = VerifyingKey::from_affine(affine.unwrap())
244
+
let verifying_key = VerifyingKey::from_affine(affine)
246
245
.map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?;
247
246
248
247
let sig = Signature::from_slice(signature)
···
287
286
false,
288
287
);
289
288
290
-
let affine = AffinePoint::from_encoded_point(&point);
291
-
if affine.is_none().into() {
292
-
return Err(OAuthError::InvalidDpopProof("Invalid EC point".to_string()));
293
-
}
289
+
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
290
+
let affine = affine_opt
291
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
294
292
295
-
let verifying_key = VerifyingKey::from_affine(affine.unwrap())
293
+
let verifying_key = VerifyingKey::from_affine(affine)
296
294
.map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?;
297
295
298
296
let sig = Signature::from_slice(signature)
-558
src/oauth/endpoints/token.rs
-558
src/oauth/endpoints/token.rs
···
1
-
use axum::{
2
-
Form, Json,
3
-
extract::State,
4
-
http::{HeaderMap, StatusCode},
5
-
};
6
-
use base64::Engine;
7
-
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
8
-
use chrono::{Duration, Utc};
9
-
use hmac::Mac;
10
-
use serde::{Deserialize, Serialize};
11
-
use sha2::{Digest, Sha256};
12
-
use subtle::ConstantTimeEq;
13
-
14
-
use crate::config::AuthConfig;
15
-
use crate::state::AppState;
16
-
use crate::oauth::{
17
-
ClientAuth, OAuthError, RefreshToken, TokenData, TokenId,
18
-
client::{ClientMetadataCache, verify_client_auth},
19
-
db,
20
-
dpop::DPoPVerifier,
21
-
};
22
-
23
-
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
24
-
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
25
-
26
-
#[derive(Debug, Deserialize)]
27
-
pub struct TokenRequest {
28
-
pub grant_type: String,
29
-
#[serde(default)]
30
-
pub code: Option<String>,
31
-
#[serde(default)]
32
-
pub redirect_uri: Option<String>,
33
-
#[serde(default)]
34
-
pub code_verifier: Option<String>,
35
-
#[serde(default)]
36
-
pub refresh_token: Option<String>,
37
-
#[serde(default)]
38
-
pub client_id: Option<String>,
39
-
#[serde(default)]
40
-
pub client_secret: Option<String>,
41
-
#[serde(default)]
42
-
pub client_assertion: Option<String>,
43
-
#[serde(default)]
44
-
pub client_assertion_type: Option<String>,
45
-
}
46
-
47
-
#[derive(Debug, Serialize)]
48
-
pub struct TokenResponse {
49
-
pub access_token: String,
50
-
pub token_type: String,
51
-
pub expires_in: u64,
52
-
#[serde(skip_serializing_if = "Option::is_none")]
53
-
pub refresh_token: Option<String>,
54
-
#[serde(skip_serializing_if = "Option::is_none")]
55
-
pub scope: Option<String>,
56
-
#[serde(skip_serializing_if = "Option::is_none")]
57
-
pub sub: Option<String>,
58
-
}
59
-
60
-
pub async fn token_endpoint(
61
-
State(state): State<AppState>,
62
-
headers: HeaderMap,
63
-
Form(request): Form<TokenRequest>,
64
-
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
65
-
let dpop_proof = headers
66
-
.get("DPoP")
67
-
.and_then(|v| v.to_str().ok())
68
-
.map(|s| s.to_string());
69
-
70
-
match request.grant_type.as_str() {
71
-
"authorization_code" => {
72
-
handle_authorization_code_grant(state, headers, request, dpop_proof).await
73
-
}
74
-
"refresh_token" => {
75
-
handle_refresh_token_grant(state, headers, request, dpop_proof).await
76
-
}
77
-
_ => Err(OAuthError::UnsupportedGrantType(format!(
78
-
"Unsupported grant_type: {}",
79
-
request.grant_type
80
-
))),
81
-
}
82
-
}
83
-
84
-
async fn handle_authorization_code_grant(
85
-
state: AppState,
86
-
_headers: HeaderMap,
87
-
request: TokenRequest,
88
-
dpop_proof: Option<String>,
89
-
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
90
-
let code = request
91
-
.code
92
-
.ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
93
-
94
-
let code_verifier = request
95
-
.code_verifier
96
-
.ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?;
97
-
98
-
let auth_request = db::consume_authorization_request_by_code(&state.db, &code)
99
-
.await?
100
-
.ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
101
-
102
-
if auth_request.expires_at < Utc::now() {
103
-
return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string()));
104
-
}
105
-
106
-
if let Some(request_client_id) = &request.client_id {
107
-
if request_client_id != &auth_request.client_id {
108
-
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
109
-
}
110
-
}
111
-
112
-
let did = auth_request
113
-
.did
114
-
.ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
115
-
116
-
let client_metadata_cache = ClientMetadataCache::new(3600);
117
-
let client_metadata = client_metadata_cache
118
-
.get(&auth_request.client_id)
119
-
.await?;
120
-
let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None);
121
-
verify_client_auth(&client_metadata, &client_auth)?;
122
-
123
-
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
124
-
125
-
if let Some(redirect_uri) = &request.redirect_uri {
126
-
if redirect_uri != &auth_request.parameters.redirect_uri {
127
-
return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string()));
128
-
}
129
-
}
130
-
131
-
let dpop_jkt = if let Some(proof) = &dpop_proof {
132
-
let config = AuthConfig::get();
133
-
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
134
-
135
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
136
-
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
137
-
138
-
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
139
-
140
-
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
141
-
return Err(OAuthError::InvalidDpopProof(
142
-
"DPoP proof has already been used".to_string(),
143
-
));
144
-
}
145
-
146
-
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt {
147
-
if &result.jkt != expected_jkt {
148
-
return Err(OAuthError::InvalidDpopProof(
149
-
"DPoP key binding mismatch".to_string(),
150
-
));
151
-
}
152
-
}
153
-
154
-
Some(result.jkt)
155
-
} else if auth_request.parameters.dpop_jkt.is_some() {
156
-
return Err(OAuthError::InvalidRequest(
157
-
"DPoP proof required for this authorization".to_string(),
158
-
));
159
-
} else {
160
-
None
161
-
};
162
-
163
-
let token_id = TokenId::generate();
164
-
let refresh_token = RefreshToken::generate();
165
-
let now = Utc::now();
166
-
167
-
let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?;
168
-
169
-
let token_data = TokenData {
170
-
did: did.clone(),
171
-
token_id: token_id.0.clone(),
172
-
created_at: now,
173
-
updated_at: now,
174
-
expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS),
175
-
client_id: auth_request.client_id.clone(),
176
-
client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None),
177
-
device_id: auth_request.device_id,
178
-
parameters: auth_request.parameters.clone(),
179
-
details: None,
180
-
code: None,
181
-
current_refresh_token: Some(refresh_token.0.clone()),
182
-
scope: auth_request.parameters.scope.clone(),
183
-
};
184
-
185
-
db::create_token(&state.db, &token_data).await?;
186
-
187
-
tokio::spawn({
188
-
let pool = state.db.clone();
189
-
let did_clone = did.clone();
190
-
async move {
191
-
if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await {
192
-
tracing::warn!("Failed to enforce token limit for user: {:?}", e);
193
-
}
194
-
}
195
-
});
196
-
197
-
let mut response_headers = HeaderMap::new();
198
-
let config = AuthConfig::get();
199
-
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
200
-
response_headers.insert(
201
-
"DPoP-Nonce",
202
-
verifier.generate_nonce().parse().unwrap(),
203
-
);
204
-
205
-
Ok((
206
-
response_headers,
207
-
Json(TokenResponse {
208
-
access_token,
209
-
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
210
-
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
211
-
refresh_token: Some(refresh_token.0),
212
-
scope: auth_request.parameters.scope,
213
-
sub: Some(did),
214
-
}),
215
-
))
216
-
}
217
-
218
-
async fn handle_refresh_token_grant(
219
-
state: AppState,
220
-
_headers: HeaderMap,
221
-
request: TokenRequest,
222
-
dpop_proof: Option<String>,
223
-
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
224
-
let refresh_token_str = request
225
-
.refresh_token
226
-
.ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
227
-
228
-
if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? {
229
-
db::delete_token_family(&state.db, token_id).await?;
230
-
return Err(OAuthError::InvalidGrant(
231
-
"Refresh token reuse detected, token family revoked".to_string(),
232
-
));
233
-
}
234
-
235
-
let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str)
236
-
.await?
237
-
.ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?;
238
-
239
-
if token_data.expires_at < Utc::now() {
240
-
db::delete_token_family(&state.db, db_id).await?;
241
-
return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string()));
242
-
}
243
-
244
-
let dpop_jkt = if let Some(proof) = &dpop_proof {
245
-
let config = AuthConfig::get();
246
-
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
247
-
248
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
249
-
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
250
-
251
-
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
252
-
253
-
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
254
-
return Err(OAuthError::InvalidDpopProof(
255
-
"DPoP proof has already been used".to_string(),
256
-
));
257
-
}
258
-
259
-
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
260
-
if &result.jkt != expected_jkt {
261
-
return Err(OAuthError::InvalidDpopProof(
262
-
"DPoP key binding mismatch".to_string(),
263
-
));
264
-
}
265
-
}
266
-
267
-
Some(result.jkt)
268
-
} else if token_data.parameters.dpop_jkt.is_some() {
269
-
return Err(OAuthError::InvalidRequest(
270
-
"DPoP proof required".to_string(),
271
-
));
272
-
} else {
273
-
None
274
-
};
275
-
276
-
let new_token_id = TokenId::generate();
277
-
let new_refresh_token = RefreshToken::generate();
278
-
let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS);
279
-
280
-
db::rotate_token(
281
-
&state.db,
282
-
db_id,
283
-
&new_token_id.0,
284
-
&new_refresh_token.0,
285
-
new_expires_at,
286
-
)
287
-
.await?;
288
-
289
-
let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?;
290
-
291
-
let mut response_headers = HeaderMap::new();
292
-
let config = AuthConfig::get();
293
-
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
294
-
response_headers.insert(
295
-
"DPoP-Nonce",
296
-
verifier.generate_nonce().parse().unwrap(),
297
-
);
298
-
299
-
Ok((
300
-
response_headers,
301
-
Json(TokenResponse {
302
-
access_token,
303
-
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
304
-
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
305
-
refresh_token: Some(new_refresh_token.0),
306
-
scope: token_data.scope,
307
-
sub: Some(token_data.did),
308
-
}),
309
-
))
310
-
}
311
-
312
-
fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> {
313
-
use subtle::ConstantTimeEq;
314
-
315
-
let mut hasher = Sha256::new();
316
-
hasher.update(code_verifier.as_bytes());
317
-
let hash = hasher.finalize();
318
-
let computed_challenge = URL_SAFE_NO_PAD.encode(&hash);
319
-
320
-
if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) {
321
-
return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string()));
322
-
}
323
-
324
-
Ok(())
325
-
}
326
-
327
-
fn create_access_token(
328
-
token_id: &str,
329
-
sub: &str,
330
-
dpop_jkt: Option<&str>,
331
-
) -> Result<String, OAuthError> {
332
-
use serde_json::json;
333
-
334
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
335
-
let issuer = format!("https://{}", pds_hostname);
336
-
337
-
let now = Utc::now().timestamp();
338
-
let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
339
-
340
-
let mut payload = json!({
341
-
"iss": issuer,
342
-
"sub": sub,
343
-
"aud": issuer,
344
-
"iat": now,
345
-
"exp": exp,
346
-
"jti": token_id,
347
-
"scope": "atproto"
348
-
});
349
-
350
-
if let Some(jkt) = dpop_jkt {
351
-
payload["cnf"] = json!({ "jkt": jkt });
352
-
}
353
-
354
-
let header = json!({
355
-
"alg": "HS256",
356
-
"typ": "at+jwt"
357
-
});
358
-
359
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
360
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
361
-
362
-
let signing_input = format!("{}.{}", header_b64, payload_b64);
363
-
364
-
let config = AuthConfig::get();
365
-
366
-
use sha2::Sha256 as HmacSha256;
367
-
use hmac::{Hmac, Mac};
368
-
type HmacSha256Type = Hmac<HmacSha256>;
369
-
370
-
let mut mac = HmacSha256Type::new_from_slice(config.jwt_secret().as_bytes())
371
-
.map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?;
372
-
mac.update(signing_input.as_bytes());
373
-
let signature = mac.finalize().into_bytes();
374
-
375
-
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
376
-
377
-
Ok(format!("{}.{}", signing_input, signature_b64))
378
-
}
379
-
380
-
pub async fn revoke_token(
381
-
State(state): State<AppState>,
382
-
Form(request): Form<RevokeRequest>,
383
-
) -> Result<StatusCode, OAuthError> {
384
-
if let Some(token) = &request.token {
385
-
if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
386
-
db::delete_token_family(&state.db, db_id).await?;
387
-
} else {
388
-
db::delete_token(&state.db, token).await?;
389
-
}
390
-
}
391
-
392
-
Ok(StatusCode::OK)
393
-
}
394
-
395
-
#[derive(Debug, Deserialize)]
396
-
pub struct RevokeRequest {
397
-
pub token: Option<String>,
398
-
#[serde(default)]
399
-
pub token_type_hint: Option<String>,
400
-
}
401
-
402
-
#[derive(Debug, Deserialize)]
403
-
pub struct IntrospectRequest {
404
-
pub token: String,
405
-
#[serde(default)]
406
-
pub token_type_hint: Option<String>,
407
-
}
408
-
409
-
#[derive(Debug, Serialize)]
410
-
pub struct IntrospectResponse {
411
-
pub active: bool,
412
-
#[serde(skip_serializing_if = "Option::is_none")]
413
-
pub scope: Option<String>,
414
-
#[serde(skip_serializing_if = "Option::is_none")]
415
-
pub client_id: Option<String>,
416
-
#[serde(skip_serializing_if = "Option::is_none")]
417
-
pub username: Option<String>,
418
-
#[serde(skip_serializing_if = "Option::is_none")]
419
-
pub token_type: Option<String>,
420
-
#[serde(skip_serializing_if = "Option::is_none")]
421
-
pub exp: Option<i64>,
422
-
#[serde(skip_serializing_if = "Option::is_none")]
423
-
pub iat: Option<i64>,
424
-
#[serde(skip_serializing_if = "Option::is_none")]
425
-
pub nbf: Option<i64>,
426
-
#[serde(skip_serializing_if = "Option::is_none")]
427
-
pub sub: Option<String>,
428
-
#[serde(skip_serializing_if = "Option::is_none")]
429
-
pub aud: Option<String>,
430
-
#[serde(skip_serializing_if = "Option::is_none")]
431
-
pub iss: Option<String>,
432
-
#[serde(skip_serializing_if = "Option::is_none")]
433
-
pub jti: Option<String>,
434
-
}
435
-
436
-
pub async fn introspect_token(
437
-
State(state): State<AppState>,
438
-
Form(request): Form<IntrospectRequest>,
439
-
) -> Json<IntrospectResponse> {
440
-
let inactive_response = IntrospectResponse {
441
-
active: false,
442
-
scope: None,
443
-
client_id: None,
444
-
username: None,
445
-
token_type: None,
446
-
exp: None,
447
-
iat: None,
448
-
nbf: None,
449
-
sub: None,
450
-
aud: None,
451
-
iss: None,
452
-
jti: None,
453
-
};
454
-
455
-
let token_info = match extract_token_claims(&request.token) {
456
-
Ok(info) => info,
457
-
Err(_) => return Json(inactive_response),
458
-
};
459
-
460
-
let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
461
-
Ok(Some(data)) => data,
462
-
_ => return Json(inactive_response),
463
-
};
464
-
465
-
if token_data.expires_at < Utc::now() {
466
-
return Json(inactive_response);
467
-
}
468
-
469
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
470
-
let issuer = format!("https://{}", pds_hostname);
471
-
472
-
Json(IntrospectResponse {
473
-
active: true,
474
-
scope: token_data.scope,
475
-
client_id: Some(token_data.client_id),
476
-
username: None,
477
-
token_type: if token_data.parameters.dpop_jkt.is_some() {
478
-
Some("DPoP".to_string())
479
-
} else {
480
-
Some("Bearer".to_string())
481
-
},
482
-
exp: Some(token_info.exp),
483
-
iat: Some(token_info.iat),
484
-
nbf: Some(token_info.iat),
485
-
sub: Some(token_data.did),
486
-
aud: Some(issuer.clone()),
487
-
iss: Some(issuer),
488
-
jti: Some(token_info.jti),
489
-
})
490
-
}
491
-
492
-
struct TokenClaims {
493
-
jti: String,
494
-
exp: i64,
495
-
iat: i64,
496
-
}
497
-
498
-
fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
499
-
let parts: Vec<&str> = token.split('.').collect();
500
-
if parts.len() != 3 {
501
-
return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
502
-
}
503
-
504
-
let header_bytes = URL_SAFE_NO_PAD
505
-
.decode(parts[0])
506
-
.map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
507
-
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
508
-
.map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
509
-
510
-
if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
511
-
return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string()));
512
-
}
513
-
if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
514
-
return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string()));
515
-
}
516
-
517
-
let config = AuthConfig::get();
518
-
let secret = config.jwt_secret();
519
-
520
-
let signing_input = format!("{}.{}", parts[0], parts[1]);
521
-
let provided_sig = URL_SAFE_NO_PAD
522
-
.decode(parts[2])
523
-
.map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
524
-
525
-
type HmacSha256 = hmac::Hmac<Sha256>;
526
-
let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
527
-
.map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
528
-
mac.update(signing_input.as_bytes());
529
-
let expected_sig = mac.finalize().into_bytes();
530
-
531
-
if !bool::from(expected_sig.ct_eq(&provided_sig)) {
532
-
return Err(OAuthError::InvalidToken("Invalid token signature".to_string()));
533
-
}
534
-
535
-
let payload_bytes = URL_SAFE_NO_PAD
536
-
.decode(parts[1])
537
-
.map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
538
-
let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
539
-
.map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
540
-
541
-
let jti = payload
542
-
.get("jti")
543
-
.and_then(|j| j.as_str())
544
-
.ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
545
-
.to_string();
546
-
547
-
let exp = payload
548
-
.get("exp")
549
-
.and_then(|e| e.as_i64())
550
-
.ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
551
-
552
-
let iat = payload
553
-
.get("iat")
554
-
.and_then(|i| i.as_i64())
555
-
.ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?;
556
-
557
-
Ok(TokenClaims { jti, exp, iat })
558
-
}
+246
src/oauth/endpoints/token/grants.rs
+246
src/oauth/endpoints/token/grants.rs
···
1
+
use axum::http::HeaderMap;
2
+
use axum::Json;
3
+
use chrono::{Duration, Utc};
4
+
5
+
use crate::config::AuthConfig;
6
+
use crate::state::AppState;
7
+
use crate::oauth::{
8
+
ClientAuth, OAuthError, RefreshToken, TokenData, TokenId,
9
+
client::{ClientMetadataCache, verify_client_auth},
10
+
db,
11
+
dpop::DPoPVerifier,
12
+
};
13
+
14
+
use super::types::{TokenRequest, TokenResponse};
15
+
use super::helpers::{create_access_token, verify_pkce};
16
+
17
+
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
18
+
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
19
+
20
+
pub async fn handle_authorization_code_grant(
21
+
state: AppState,
22
+
_headers: HeaderMap,
23
+
request: TokenRequest,
24
+
dpop_proof: Option<String>,
25
+
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
26
+
let code = request
27
+
.code
28
+
.ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
29
+
30
+
let code_verifier = request
31
+
.code_verifier
32
+
.ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?;
33
+
34
+
let auth_request = db::consume_authorization_request_by_code(&state.db, &code)
35
+
.await?
36
+
.ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
37
+
38
+
if auth_request.expires_at < Utc::now() {
39
+
return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string()));
40
+
}
41
+
42
+
if let Some(request_client_id) = &request.client_id {
43
+
if request_client_id != &auth_request.client_id {
44
+
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
45
+
}
46
+
}
47
+
48
+
let did = auth_request
49
+
.did
50
+
.ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
51
+
52
+
let client_metadata_cache = ClientMetadataCache::new(3600);
53
+
let client_metadata = client_metadata_cache
54
+
.get(&auth_request.client_id)
55
+
.await?;
56
+
let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None);
57
+
verify_client_auth(&client_metadata, &client_auth)?;
58
+
59
+
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
60
+
61
+
if let Some(redirect_uri) = &request.redirect_uri {
62
+
if redirect_uri != &auth_request.parameters.redirect_uri {
63
+
return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string()));
64
+
}
65
+
}
66
+
67
+
let dpop_jkt = if let Some(proof) = &dpop_proof {
68
+
let config = AuthConfig::get();
69
+
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
70
+
71
+
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
72
+
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
73
+
74
+
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
75
+
76
+
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
77
+
return Err(OAuthError::InvalidDpopProof(
78
+
"DPoP proof has already been used".to_string(),
79
+
));
80
+
}
81
+
82
+
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt {
83
+
if &result.jkt != expected_jkt {
84
+
return Err(OAuthError::InvalidDpopProof(
85
+
"DPoP key binding mismatch".to_string(),
86
+
));
87
+
}
88
+
}
89
+
90
+
Some(result.jkt)
91
+
} else if auth_request.parameters.dpop_jkt.is_some() {
92
+
return Err(OAuthError::InvalidRequest(
93
+
"DPoP proof required for this authorization".to_string(),
94
+
));
95
+
} else {
96
+
None
97
+
};
98
+
99
+
let token_id = TokenId::generate();
100
+
let refresh_token = RefreshToken::generate();
101
+
let now = Utc::now();
102
+
103
+
let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?;
104
+
105
+
let token_data = TokenData {
106
+
did: did.clone(),
107
+
token_id: token_id.0.clone(),
108
+
created_at: now,
109
+
updated_at: now,
110
+
expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS),
111
+
client_id: auth_request.client_id.clone(),
112
+
client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None),
113
+
device_id: auth_request.device_id,
114
+
parameters: auth_request.parameters.clone(),
115
+
details: None,
116
+
code: None,
117
+
current_refresh_token: Some(refresh_token.0.clone()),
118
+
scope: auth_request.parameters.scope.clone(),
119
+
};
120
+
121
+
db::create_token(&state.db, &token_data).await?;
122
+
123
+
tokio::spawn({
124
+
let pool = state.db.clone();
125
+
let did_clone = did.clone();
126
+
async move {
127
+
if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await {
128
+
tracing::warn!("Failed to enforce token limit for user: {:?}", e);
129
+
}
130
+
}
131
+
});
132
+
133
+
let mut response_headers = HeaderMap::new();
134
+
let config = AuthConfig::get();
135
+
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
136
+
response_headers.insert(
137
+
"DPoP-Nonce",
138
+
verifier.generate_nonce().parse().unwrap(),
139
+
);
140
+
141
+
Ok((
142
+
response_headers,
143
+
Json(TokenResponse {
144
+
access_token,
145
+
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
146
+
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
147
+
refresh_token: Some(refresh_token.0),
148
+
scope: auth_request.parameters.scope,
149
+
sub: Some(did),
150
+
}),
151
+
))
152
+
}
153
+
154
+
pub async fn handle_refresh_token_grant(
155
+
state: AppState,
156
+
_headers: HeaderMap,
157
+
request: TokenRequest,
158
+
dpop_proof: Option<String>,
159
+
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
160
+
let refresh_token_str = request
161
+
.refresh_token
162
+
.ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
163
+
164
+
if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? {
165
+
db::delete_token_family(&state.db, token_id).await?;
166
+
return Err(OAuthError::InvalidGrant(
167
+
"Refresh token reuse detected, token family revoked".to_string(),
168
+
));
169
+
}
170
+
171
+
let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str)
172
+
.await?
173
+
.ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?;
174
+
175
+
if token_data.expires_at < Utc::now() {
176
+
db::delete_token_family(&state.db, db_id).await?;
177
+
return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string()));
178
+
}
179
+
180
+
let dpop_jkt = if let Some(proof) = &dpop_proof {
181
+
let config = AuthConfig::get();
182
+
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
183
+
184
+
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
185
+
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
186
+
187
+
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
188
+
189
+
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
190
+
return Err(OAuthError::InvalidDpopProof(
191
+
"DPoP proof has already been used".to_string(),
192
+
));
193
+
}
194
+
195
+
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
196
+
if &result.jkt != expected_jkt {
197
+
return Err(OAuthError::InvalidDpopProof(
198
+
"DPoP key binding mismatch".to_string(),
199
+
));
200
+
}
201
+
}
202
+
203
+
Some(result.jkt)
204
+
} else if token_data.parameters.dpop_jkt.is_some() {
205
+
return Err(OAuthError::InvalidRequest(
206
+
"DPoP proof required".to_string(),
207
+
));
208
+
} else {
209
+
None
210
+
};
211
+
212
+
let new_token_id = TokenId::generate();
213
+
let new_refresh_token = RefreshToken::generate();
214
+
let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS);
215
+
216
+
db::rotate_token(
217
+
&state.db,
218
+
db_id,
219
+
&new_token_id.0,
220
+
&new_refresh_token.0,
221
+
new_expires_at,
222
+
)
223
+
.await?;
224
+
225
+
let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?;
226
+
227
+
let mut response_headers = HeaderMap::new();
228
+
let config = AuthConfig::get();
229
+
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
230
+
response_headers.insert(
231
+
"DPoP-Nonce",
232
+
verifier.generate_nonce().parse().unwrap(),
233
+
);
234
+
235
+
Ok((
236
+
response_headers,
237
+
Json(TokenResponse {
238
+
access_token,
239
+
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
240
+
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
241
+
refresh_token: Some(new_refresh_token.0),
242
+
scope: token_data.scope,
243
+
sub: Some(token_data.did),
244
+
}),
245
+
))
246
+
}
+143
src/oauth/endpoints/token/helpers.rs
+143
src/oauth/endpoints/token/helpers.rs
···
1
+
use base64::Engine;
2
+
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3
+
use chrono::Utc;
4
+
use hmac::Mac;
5
+
use sha2::{Digest, Sha256};
6
+
use subtle::ConstantTimeEq;
7
+
8
+
use crate::config::AuthConfig;
9
+
use crate::oauth::OAuthError;
10
+
11
+
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
12
+
13
+
pub struct TokenClaims {
14
+
pub jti: String,
15
+
pub exp: i64,
16
+
pub iat: i64,
17
+
}
18
+
19
+
pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> {
20
+
let mut hasher = Sha256::new();
21
+
hasher.update(code_verifier.as_bytes());
22
+
let hash = hasher.finalize();
23
+
let computed_challenge = URL_SAFE_NO_PAD.encode(&hash);
24
+
25
+
if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) {
26
+
return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string()));
27
+
}
28
+
29
+
Ok(())
30
+
}
31
+
32
+
pub fn create_access_token(
33
+
token_id: &str,
34
+
sub: &str,
35
+
dpop_jkt: Option<&str>,
36
+
) -> Result<String, OAuthError> {
37
+
use serde_json::json;
38
+
39
+
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
40
+
let issuer = format!("https://{}", pds_hostname);
41
+
42
+
let now = Utc::now().timestamp();
43
+
let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
44
+
45
+
let mut payload = json!({
46
+
"iss": issuer,
47
+
"sub": sub,
48
+
"aud": issuer,
49
+
"iat": now,
50
+
"exp": exp,
51
+
"jti": token_id,
52
+
"scope": "atproto"
53
+
});
54
+
55
+
if let Some(jkt) = dpop_jkt {
56
+
payload["cnf"] = json!({ "jkt": jkt });
57
+
}
58
+
59
+
let header = json!({
60
+
"alg": "HS256",
61
+
"typ": "at+jwt"
62
+
});
63
+
64
+
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
65
+
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
66
+
67
+
let signing_input = format!("{}.{}", header_b64, payload_b64);
68
+
69
+
let config = AuthConfig::get();
70
+
71
+
type HmacSha256 = hmac::Hmac<Sha256>;
72
+
73
+
let mut mac = HmacSha256::new_from_slice(config.jwt_secret().as_bytes())
74
+
.map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?;
75
+
mac.update(signing_input.as_bytes());
76
+
let signature = mac.finalize().into_bytes();
77
+
78
+
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
79
+
80
+
Ok(format!("{}.{}", signing_input, signature_b64))
81
+
}
82
+
83
+
pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
84
+
let parts: Vec<&str> = token.split('.').collect();
85
+
if parts.len() != 3 {
86
+
return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
87
+
}
88
+
89
+
let header_bytes = URL_SAFE_NO_PAD
90
+
.decode(parts[0])
91
+
.map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
92
+
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
93
+
.map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
94
+
95
+
if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
96
+
return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string()));
97
+
}
98
+
if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
99
+
return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string()));
100
+
}
101
+
102
+
let config = AuthConfig::get();
103
+
let secret = config.jwt_secret();
104
+
105
+
let signing_input = format!("{}.{}", parts[0], parts[1]);
106
+
let provided_sig = URL_SAFE_NO_PAD
107
+
.decode(parts[2])
108
+
.map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
109
+
110
+
type HmacSha256 = hmac::Hmac<Sha256>;
111
+
let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
112
+
.map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
113
+
mac.update(signing_input.as_bytes());
114
+
let expected_sig = mac.finalize().into_bytes();
115
+
116
+
if !bool::from(expected_sig.ct_eq(&provided_sig)) {
117
+
return Err(OAuthError::InvalidToken("Invalid token signature".to_string()));
118
+
}
119
+
120
+
let payload_bytes = URL_SAFE_NO_PAD
121
+
.decode(parts[1])
122
+
.map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
123
+
let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
124
+
.map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
125
+
126
+
let jti = payload
127
+
.get("jti")
128
+
.and_then(|j| j.as_str())
129
+
.ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
130
+
.to_string();
131
+
132
+
let exp = payload
133
+
.get("exp")
134
+
.and_then(|e| e.as_i64())
135
+
.ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
136
+
137
+
let iat = payload
138
+
.get("iat")
139
+
.and_then(|i| i.as_i64())
140
+
.ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?;
141
+
142
+
Ok(TokenClaims { jti, exp, iat })
143
+
}
+122
src/oauth/endpoints/token/introspect.rs
+122
src/oauth/endpoints/token/introspect.rs
···
1
+
use axum::{Form, Json};
2
+
use axum::extract::State;
3
+
use axum::http::StatusCode;
4
+
use chrono::Utc;
5
+
use serde::{Deserialize, Serialize};
6
+
7
+
use crate::state::AppState;
8
+
use crate::oauth::{OAuthError, db};
9
+
10
+
use super::helpers::extract_token_claims;
11
+
12
+
#[derive(Debug, Deserialize)]
13
+
pub struct RevokeRequest {
14
+
pub token: Option<String>,
15
+
#[serde(default)]
16
+
pub token_type_hint: Option<String>,
17
+
}
18
+
19
+
pub async fn revoke_token(
20
+
State(state): State<AppState>,
21
+
Form(request): Form<RevokeRequest>,
22
+
) -> Result<StatusCode, OAuthError> {
23
+
if let Some(token) = &request.token {
24
+
if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
25
+
db::delete_token_family(&state.db, db_id).await?;
26
+
} else {
27
+
db::delete_token(&state.db, token).await?;
28
+
}
29
+
}
30
+
31
+
Ok(StatusCode::OK)
32
+
}
33
+
34
+
#[derive(Debug, Deserialize)]
35
+
pub struct IntrospectRequest {
36
+
pub token: String,
37
+
#[serde(default)]
38
+
pub token_type_hint: Option<String>,
39
+
}
40
+
41
+
#[derive(Debug, Serialize)]
42
+
pub struct IntrospectResponse {
43
+
pub active: bool,
44
+
#[serde(skip_serializing_if = "Option::is_none")]
45
+
pub scope: Option<String>,
46
+
#[serde(skip_serializing_if = "Option::is_none")]
47
+
pub client_id: Option<String>,
48
+
#[serde(skip_serializing_if = "Option::is_none")]
49
+
pub username: Option<String>,
50
+
#[serde(skip_serializing_if = "Option::is_none")]
51
+
pub token_type: Option<String>,
52
+
#[serde(skip_serializing_if = "Option::is_none")]
53
+
pub exp: Option<i64>,
54
+
#[serde(skip_serializing_if = "Option::is_none")]
55
+
pub iat: Option<i64>,
56
+
#[serde(skip_serializing_if = "Option::is_none")]
57
+
pub nbf: Option<i64>,
58
+
#[serde(skip_serializing_if = "Option::is_none")]
59
+
pub sub: Option<String>,
60
+
#[serde(skip_serializing_if = "Option::is_none")]
61
+
pub aud: Option<String>,
62
+
#[serde(skip_serializing_if = "Option::is_none")]
63
+
pub iss: Option<String>,
64
+
#[serde(skip_serializing_if = "Option::is_none")]
65
+
pub jti: Option<String>,
66
+
}
67
+
68
+
pub async fn introspect_token(
69
+
State(state): State<AppState>,
70
+
Form(request): Form<IntrospectRequest>,
71
+
) -> Json<IntrospectResponse> {
72
+
let inactive_response = IntrospectResponse {
73
+
active: false,
74
+
scope: None,
75
+
client_id: None,
76
+
username: None,
77
+
token_type: None,
78
+
exp: None,
79
+
iat: None,
80
+
nbf: None,
81
+
sub: None,
82
+
aud: None,
83
+
iss: None,
84
+
jti: None,
85
+
};
86
+
87
+
let token_info = match extract_token_claims(&request.token) {
88
+
Ok(info) => info,
89
+
Err(_) => return Json(inactive_response),
90
+
};
91
+
92
+
let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
93
+
Ok(Some(data)) => data,
94
+
_ => return Json(inactive_response),
95
+
};
96
+
97
+
if token_data.expires_at < Utc::now() {
98
+
return Json(inactive_response);
99
+
}
100
+
101
+
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
102
+
let issuer = format!("https://{}", pds_hostname);
103
+
104
+
Json(IntrospectResponse {
105
+
active: true,
106
+
scope: token_data.scope,
107
+
client_id: Some(token_data.client_id),
108
+
username: None,
109
+
token_type: if token_data.parameters.dpop_jkt.is_some() {
110
+
Some("DPoP".to_string())
111
+
} else {
112
+
Some("Bearer".to_string())
113
+
},
114
+
exp: Some(token_info.exp),
115
+
iat: Some(token_info.iat),
116
+
nbf: Some(token_info.iat),
117
+
sub: Some(token_data.did),
118
+
aud: Some(issuer.clone()),
119
+
iss: Some(issuer),
120
+
jti: Some(token_info.jti),
121
+
})
122
+
}
+44
src/oauth/endpoints/token/mod.rs
+44
src/oauth/endpoints/token/mod.rs
···
1
+
mod grants;
2
+
mod helpers;
3
+
mod introspect;
4
+
mod types;
5
+
6
+
use axum::{
7
+
Form, Json,
8
+
extract::State,
9
+
http::HeaderMap,
10
+
};
11
+
12
+
use crate::state::AppState;
13
+
use crate::oauth::OAuthError;
14
+
15
+
pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant};
16
+
pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims};
17
+
pub use introspect::{
18
+
introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest,
19
+
};
20
+
pub use types::{TokenRequest, TokenResponse};
21
+
22
+
pub async fn token_endpoint(
23
+
State(state): State<AppState>,
24
+
headers: HeaderMap,
25
+
Form(request): Form<TokenRequest>,
26
+
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
27
+
let dpop_proof = headers
28
+
.get("DPoP")
29
+
.and_then(|v| v.to_str().ok())
30
+
.map(|s| s.to_string());
31
+
32
+
match request.grant_type.as_str() {
33
+
"authorization_code" => {
34
+
handle_authorization_code_grant(state, headers, request, dpop_proof).await
35
+
}
36
+
"refresh_token" => {
37
+
handle_refresh_token_grant(state, headers, request, dpop_proof).await
38
+
}
39
+
_ => Err(OAuthError::UnsupportedGrantType(format!(
40
+
"Unsupported grant_type: {}",
41
+
request.grant_type
42
+
))),
43
+
}
44
+
}
+35
src/oauth/endpoints/token/types.rs
+35
src/oauth/endpoints/token/types.rs
···
1
+
use serde::{Deserialize, Serialize};
2
+
3
+
#[derive(Debug, Deserialize)]
4
+
pub struct TokenRequest {
5
+
pub grant_type: String,
6
+
#[serde(default)]
7
+
pub code: Option<String>,
8
+
#[serde(default)]
9
+
pub redirect_uri: Option<String>,
10
+
#[serde(default)]
11
+
pub code_verifier: Option<String>,
12
+
#[serde(default)]
13
+
pub refresh_token: Option<String>,
14
+
#[serde(default)]
15
+
pub client_id: Option<String>,
16
+
#[serde(default)]
17
+
pub client_secret: Option<String>,
18
+
#[serde(default)]
19
+
pub client_assertion: Option<String>,
20
+
#[serde(default)]
21
+
pub client_assertion_type: Option<String>,
22
+
}
23
+
24
+
#[derive(Debug, Serialize)]
25
+
pub struct TokenResponse {
26
+
pub access_token: String,
27
+
pub token_type: String,
28
+
pub expires_in: u64,
29
+
#[serde(skip_serializing_if = "Option::is_none")]
30
+
pub refresh_token: Option<String>,
31
+
#[serde(skip_serializing_if = "Option::is_none")]
32
+
pub scope: Option<String>,
33
+
#[serde(skip_serializing_if = "Option::is_none")]
34
+
pub sub: Option<String>,
35
+
}
+2
-1
src/repo/mod.rs
+2
-1
src/repo/mod.rs
···
38
38
let mut hasher = Sha256::new();
39
39
hasher.update(data);
40
40
let hash = hasher.finalize();
41
-
let multihash = Multihash::wrap(0x12, &hash).unwrap();
41
+
let multihash = Multihash::wrap(0x12, &hash)
42
+
.map_err(|e| RepoError::storage(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to wrap multihash: {:?}", e))))?;
42
43
let cid = Cid::new_v1(0x71, multihash);
43
44
let cid_bytes = cid.to_bytes();
44
45
+12
-3
src/repo/tracking.rs
+12
-3
src/repo/tracking.rs
···
21
21
}
22
22
23
23
pub fn get_written_cids(&self) -> Vec<Cid> {
24
-
self.written_cids.lock().unwrap().clone()
24
+
match self.written_cids.lock() {
25
+
Ok(guard) => guard.clone(),
26
+
Err(poisoned) => poisoned.into_inner().clone(),
27
+
}
25
28
}
26
29
}
27
30
···
32
35
33
36
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
34
37
let cid = self.inner.put(data).await?;
35
-
self.written_cids.lock().unwrap().push(cid.clone());
38
+
match self.written_cids.lock() {
39
+
Ok(mut guard) => guard.push(cid.clone()),
40
+
Err(poisoned) => poisoned.into_inner().push(cid.clone()),
41
+
}
36
42
Ok(cid)
37
43
}
38
44
···
47
53
let blocks: Vec<_> = blocks.into_iter().collect();
48
54
let cids: Vec<Cid> = blocks.iter().map(|(cid, _)| cid.clone()).collect();
49
55
self.inner.put_many(blocks).await?;
50
-
self.written_cids.lock().unwrap().extend(cids);
56
+
match self.written_cids.lock() {
57
+
Ok(mut guard) => guard.extend(cids),
58
+
Err(poisoned) => poisoned.into_inner().extend(cids),
59
+
}
51
60
Ok(())
52
61
}
53
62
+1
-1
src/sync/blob.rs
+1
-1
src/sync/blob.rs
···
132
132
.into_response();
133
133
}
134
134
135
-
let limit = params.limit.unwrap_or(500).min(1000);
135
+
let limit = params.limit.unwrap_or(500).clamp(1, 1000);
136
136
let cursor_cid = params.cursor.as_deref().unwrap_or("");
137
137
138
138
let user_result = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
+5
-4
src/sync/car.rs
+5
-4
src/sync/car.rs
···
23
23
Ok(())
24
24
}
25
25
26
-
pub fn encode_car_header(root_cid: &Cid) -> Vec<u8> {
26
+
pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> {
27
27
let header = CarHeader::new_v1(vec![root_cid.clone()]);
28
-
let header_cbor = header.encode().unwrap_or_default();
28
+
let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
29
29
30
30
let mut result = Vec::new();
31
-
write_varint(&mut result, header_cbor.len() as u64).unwrap();
31
+
write_varint(&mut result, header_cbor.len() as u64)
32
+
.expect("Writing to Vec<u8> should never fail");
32
33
result.extend_from_slice(&header_cbor);
33
-
result
34
+
Ok(result)
34
35
}
+1
-1
src/sync/commit.rs
+1
-1
src/sync/commit.rs
···
98
98
State(state): State<AppState>,
99
99
Query(params): Query<ListReposParams>,
100
100
) -> Response {
101
-
let limit = params.limit.unwrap_or(50).min(1000);
101
+
let limit = params.limit.unwrap_or(50).clamp(1, 1000);
102
102
let cursor_did = params.cursor.as_deref().unwrap_or("");
103
103
104
104
let result = sqlx::query!(
+9
-5
src/sync/frame.rs
+9
-5
src/sync/frame.rs
···
38
38
pub cid: Option<String>,
39
39
}
40
40
41
-
impl From<SequencedEvent> for CommitFrame {
42
-
fn from(event: SequencedEvent) -> Self {
41
+
impl TryFrom<SequencedEvent> for CommitFrame {
42
+
type Error = &'static str;
43
+
44
+
fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> {
43
45
let ops = serde_json::from_value::<Vec<RepoOp>>(event.ops.unwrap_or_default())
44
46
.unwrap_or_else(|_| vec![]);
45
47
46
-
CommitFrame {
48
+
let commit_cid = event.commit_cid.ok_or("Missing commit_cid in event")?;
49
+
50
+
Ok(CommitFrame {
47
51
seq: event.seq,
48
52
rebase: false,
49
53
too_big: false,
50
54
repo: event.did,
51
-
commit: event.commit_cid.unwrap_or_default(),
55
+
commit: commit_cid,
52
56
prev: event.prev_cid,
53
57
blocks: Vec::new(),
54
58
ops,
55
59
blobs: event.blobs.unwrap_or_default(),
56
60
time: event.created_at.to_rfc3339(),
57
-
}
61
+
})
58
62
}
59
63
}
+1
-2
src/sync/relay_client.rs
+1
-2
src/sync/relay_client.rs
···
12
12
match connect_async(&url).await {
13
13
Ok((mut ws_stream, _)) => {
14
14
info!("Connected to firehose relay: {}", url);
15
+
let mut rx = state.firehose_tx.subscribe();
15
16
if let Some(tx) = ready_tx.as_ref() {
16
17
tx.send(()).await.ok();
17
18
}
18
-
19
-
let mut rx = state.firehose_tx.subscribe();
20
19
21
20
loop {
22
21
tokio::select! {
+48
-17
src/sync/repo.rs
+48
-17
src/sync/repo.rs
···
15
15
use std::str::FromStr;
16
16
use tracing::error;
17
17
18
+
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
19
+
18
20
#[derive(Deserialize)]
19
21
pub struct GetBlocksQuery {
20
22
pub did: String,
···
52
54
}
53
55
};
54
56
55
-
let root_cid = cids.first().cloned().unwrap_or_default();
56
-
57
57
if cids.is_empty() {
58
58
return (StatusCode::BAD_REQUEST, "No CIDs provided").into_response();
59
59
}
60
60
61
-
let header = encode_car_header(&root_cid);
61
+
let root_cid = cids[0];
62
+
63
+
let header = match encode_car_header(&root_cid) {
64
+
Ok(h) => h,
65
+
Err(e) => {
66
+
error!("Failed to encode CAR header: {}", e);
67
+
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to encode CAR").into_response();
68
+
}
69
+
};
62
70
63
71
let mut car_bytes = header;
64
72
···
69
77
let total_len = cid_bytes.len() + block.len();
70
78
71
79
let mut writer = Vec::new();
72
-
crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap();
73
-
writer.write_all(&cid_bytes).unwrap();
74
-
writer.write_all(&block).unwrap();
80
+
crate::sync::car::write_varint(&mut writer, total_len as u64)
81
+
.expect("Writing to Vec<u8> should never fail");
82
+
writer.write_all(&cid_bytes)
83
+
.expect("Writing to Vec<u8> should never fail");
84
+
writer.write_all(&block)
85
+
.expect("Writing to Vec<u8> should never fail");
75
86
76
87
car_bytes.extend_from_slice(&writer);
77
88
}
···
143
154
}
144
155
};
145
156
146
-
let mut car_bytes = encode_car_header(&head_cid);
157
+
let mut car_bytes = match encode_car_header(&head_cid) {
158
+
Ok(h) => h,
159
+
Err(e) => {
160
+
return (
161
+
StatusCode::INTERNAL_SERVER_ERROR,
162
+
Json(json!({"error": "InternalError", "message": format!("Failed to encode CAR header: {}", e)})),
163
+
)
164
+
.into_response();
165
+
}
166
+
};
147
167
148
168
let mut stack = vec![head_cid];
149
169
let mut visited = std::collections::HashSet::new();
150
-
let mut limit = 20000;
170
+
let mut remaining = MAX_REPO_BLOCKS_TRAVERSAL;
151
171
152
172
while let Some(cid) = stack.pop() {
153
173
if visited.contains(&cid) {
154
174
continue;
155
175
}
156
176
visited.insert(cid);
157
-
if limit == 0 { break; }
158
-
limit -= 1;
177
+
if remaining == 0 { break; }
178
+
remaining -= 1;
159
179
160
180
if let Ok(Some(block)) = state.block_store.get(&cid).await {
161
181
let cid_bytes = cid.to_bytes();
162
182
let total_len = cid_bytes.len() + block.len();
163
183
let mut writer = Vec::new();
164
-
crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap();
165
-
writer.write_all(&cid_bytes).unwrap();
166
-
writer.write_all(&block).unwrap();
184
+
crate::sync::car::write_varint(&mut writer, total_len as u64)
185
+
.expect("Writing to Vec<u8> should never fail");
186
+
writer.write_all(&cid_bytes)
187
+
.expect("Writing to Vec<u8> should never fail");
188
+
writer.write_all(&block)
189
+
.expect("Writing to Vec<u8> should never fail");
167
190
car_bytes.extend_from_slice(&writer);
168
191
169
192
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
···
258
281
_ => return (StatusCode::NOT_FOUND, "Block not found").into_response(),
259
282
};
260
283
261
-
let header = encode_car_header(&cid);
284
+
let header = match encode_car_header(&cid) {
285
+
Ok(h) => h,
286
+
Err(e) => {
287
+
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to encode CAR header: {}", e)).into_response();
288
+
}
289
+
};
262
290
let mut car_bytes = header;
263
291
264
292
let cid_bytes = cid.to_bytes();
265
293
let total_len = cid_bytes.len() + block.len();
266
294
let mut writer = Vec::new();
267
-
crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap();
268
-
writer.write_all(&cid_bytes).unwrap();
269
-
writer.write_all(&block).unwrap();
295
+
crate::sync::car::write_varint(&mut writer, total_len as u64)
296
+
.expect("Writing to Vec<u8> should never fail");
297
+
writer.write_all(&cid_bytes)
298
+
.expect("Writing to Vec<u8> should never fail");
299
+
writer.write_all(&block)
300
+
.expect("Writing to Vec<u8> should never fail");
270
301
car_bytes.extend_from_slice(&writer);
271
302
272
303
(
+37
-23
src/sync/subscribe_repos.rs
+37
-23
src/sync/subscribe_repos.rs
···
9
9
use serde::Deserialize;
10
10
use tracing::{error, info, warn};
11
11
12
+
const BACKFILL_BATCH_SIZE: i64 = 1000;
13
+
12
14
#[derive(Deserialize)]
13
15
pub struct SubscribeReposParams {
14
16
pub cursor: Option<i64>,
···
37
39
info!(cursor = ?params.cursor, "New firehose subscriber");
38
40
39
41
if let Some(cursor) = params.cursor {
40
-
let events = sqlx::query_as!(
41
-
SequencedEvent,
42
-
r#"
43
-
SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids
44
-
FROM repo_seq
45
-
WHERE seq > $1
46
-
ORDER BY seq ASC
47
-
"#,
48
-
cursor
49
-
)
50
-
.fetch_all(&state.db)
51
-
.await;
42
+
let mut current_cursor = cursor;
43
+
loop {
44
+
let events = sqlx::query_as!(
45
+
SequencedEvent,
46
+
r#"
47
+
SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids
48
+
FROM repo_seq
49
+
WHERE seq > $1
50
+
ORDER BY seq ASC
51
+
LIMIT $2
52
+
"#,
53
+
current_cursor,
54
+
BACKFILL_BATCH_SIZE
55
+
)
56
+
.fetch_all(&state.db)
57
+
.await;
52
58
53
-
match events {
54
-
Ok(events) => {
55
-
for event in events {
56
-
if let Err(e) = send_event(&mut socket, &state, event).await {
57
-
warn!("Failed to send backfill event: {}", e);
58
-
return;
59
+
match events {
60
+
Ok(events) => {
61
+
if events.is_empty() {
62
+
break;
63
+
}
64
+
for event in &events {
65
+
current_cursor = event.seq;
66
+
if let Err(e) = send_event(&mut socket, &state, event.clone()).await {
67
+
warn!("Failed to send backfill event: {}", e);
68
+
return;
69
+
}
70
+
}
71
+
if (events.len() as i64) < BACKFILL_BATCH_SIZE {
72
+
break;
59
73
}
60
74
}
61
-
}
62
-
Err(e) => {
63
-
error!("Failed to fetch backfill events: {}", e);
64
-
socket.close().await.ok();
65
-
return;
75
+
Err(e) => {
76
+
error!("Failed to fetch backfill events: {}", e);
77
+
socket.close().await.ok();
78
+
return;
79
+
}
66
80
}
67
81
}
68
82
}
+8
-15
src/sync/util.rs
+8
-15
src/sync/util.rs
···
2
2
use crate::sync::firehose::SequencedEvent;
3
3
use crate::sync::frame::{CommitFrame, Frame, FrameData};
4
4
use cid::Cid;
5
-
use jacquard_repo::car::write_car;
5
+
use jacquard_repo::car::write_car_bytes;
6
6
use jacquard_repo::storage::BlockStore;
7
-
use std::fs;
8
7
use std::str::FromStr;
9
-
use tokio::fs::File;
10
-
use tokio::io::AsyncReadExt;
11
-
use uuid::Uuid;
12
8
13
9
pub async fn format_event_for_sending(
14
10
state: &AppState,
15
11
event: SequencedEvent,
16
12
) -> Result<Vec<u8>, anyhow::Error> {
17
13
let block_cids_str = event.blocks_cids.clone().unwrap_or_default();
18
-
let mut frame: CommitFrame = event.into();
14
+
let mut frame: CommitFrame = event.try_into()
15
+
.map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?;
19
16
20
-
let mut car_bytes = Vec::new();
21
-
if !block_cids_str.is_empty() {
22
-
let temp_path = format!("/tmp/{}.car", Uuid::new_v4());
17
+
let car_bytes = if !block_cids_str.is_empty() {
23
18
let mut blocks = std::collections::BTreeMap::new();
24
19
25
20
for cid_str in block_cids_str {
···
33
28
}
34
29
35
30
let root = Cid::from_str(&frame.commit)?;
36
-
write_car(&temp_path, vec![root], blocks).await?;
37
-
38
-
let mut file = File::open(&temp_path).await?;
39
-
file.read_to_end(&mut car_bytes).await?;
40
-
fs::remove_file(&temp_path)?;
41
-
}
31
+
write_car_bytes(root, blocks).await?
32
+
} else {
33
+
Vec::new()
34
+
};
42
35
frame.blocks = car_bytes;
43
36
44
37
let frame = Frame {
+2
-342
src/sync/verify.rs
+2
-342
src/sync/verify.rs
···
302
302
}
303
303
304
304
#[cfg(test)]
305
-
mod tests {
306
-
use super::*;
307
-
use sha2::{Digest, Sha256};
308
-
309
-
fn make_cid(data: &[u8]) -> Cid {
310
-
let mut hasher = Sha256::new();
311
-
hasher.update(data);
312
-
let hash = hasher.finalize();
313
-
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
314
-
Cid::new_v1(0x71, multihash)
315
-
}
316
-
317
-
#[test]
318
-
fn test_verifier_creation() {
319
-
let _verifier = CarVerifier::new();
320
-
}
321
-
322
-
#[test]
323
-
fn test_verify_error_display() {
324
-
let err = VerifyError::DidMismatch {
325
-
commit_did: "did:plc:abc".to_string(),
326
-
expected_did: "did:plc:xyz".to_string(),
327
-
};
328
-
assert!(err.to_string().contains("did:plc:abc"));
329
-
assert!(err.to_string().contains("did:plc:xyz"));
330
-
331
-
let err = VerifyError::InvalidSignature;
332
-
assert!(err.to_string().contains("signature"));
333
-
334
-
let err = VerifyError::NoSigningKey;
335
-
assert!(err.to_string().contains("signing key"));
336
-
337
-
let err = VerifyError::MstValidationFailed("test error".to_string());
338
-
assert!(err.to_string().contains("test error"));
339
-
}
340
-
341
-
#[test]
342
-
fn test_mst_validation_missing_root_block() {
343
-
let verifier = CarVerifier::new();
344
-
let blocks: HashMap<Cid, Bytes> = HashMap::new();
345
-
346
-
let fake_cid = make_cid(b"fake data");
347
-
let result = verifier.verify_mst_structure(&fake_cid, &blocks);
348
-
349
-
assert!(result.is_err());
350
-
let err = result.unwrap_err();
351
-
assert!(matches!(err, VerifyError::BlockNotFound(_)));
352
-
}
353
-
354
-
#[test]
355
-
fn test_mst_validation_invalid_cbor() {
356
-
let verifier = CarVerifier::new();
357
-
358
-
let bad_cbor = Bytes::from(vec![0xFF, 0xFF, 0xFF]);
359
-
let cid = make_cid(&bad_cbor);
360
-
361
-
let mut blocks = HashMap::new();
362
-
blocks.insert(cid, bad_cbor);
363
-
364
-
let result = verifier.verify_mst_structure(&cid, &blocks);
365
-
366
-
assert!(result.is_err());
367
-
let err = result.unwrap_err();
368
-
assert!(matches!(err, VerifyError::InvalidCbor(_)));
369
-
}
370
-
371
-
#[test]
372
-
fn test_mst_validation_empty_node() {
373
-
let verifier = CarVerifier::new();
374
-
375
-
let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
376
-
"e": []
377
-
})).unwrap();
378
-
let cid = make_cid(&empty_node);
379
-
380
-
let mut blocks = HashMap::new();
381
-
blocks.insert(cid, Bytes::from(empty_node));
382
-
383
-
let result = verifier.verify_mst_structure(&cid, &blocks);
384
-
assert!(result.is_ok());
385
-
}
386
-
387
-
#[test]
388
-
fn test_mst_validation_missing_left_pointer() {
389
-
use ipld_core::ipld::Ipld;
390
-
391
-
let verifier = CarVerifier::new();
392
-
393
-
let missing_left_cid = make_cid(b"missing left");
394
-
let node = Ipld::Map(std::collections::BTreeMap::from([
395
-
("l".to_string(), Ipld::Link(missing_left_cid)),
396
-
("e".to_string(), Ipld::List(vec![])),
397
-
]));
398
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
399
-
let cid = make_cid(&node_bytes);
400
-
401
-
let mut blocks = HashMap::new();
402
-
blocks.insert(cid, Bytes::from(node_bytes));
403
-
404
-
let result = verifier.verify_mst_structure(&cid, &blocks);
405
-
406
-
assert!(result.is_err());
407
-
let err = result.unwrap_err();
408
-
assert!(matches!(err, VerifyError::BlockNotFound(_)));
409
-
assert!(err.to_string().contains("left pointer"));
410
-
}
411
-
412
-
#[test]
413
-
fn test_mst_validation_missing_subtree() {
414
-
use ipld_core::ipld::Ipld;
415
-
416
-
let verifier = CarVerifier::new();
417
-
418
-
let missing_subtree_cid = make_cid(b"missing subtree");
419
-
let record_cid = make_cid(b"record");
420
-
421
-
let entry = Ipld::Map(std::collections::BTreeMap::from([
422
-
("k".to_string(), Ipld::Bytes(b"key1".to_vec())),
423
-
("v".to_string(), Ipld::Link(record_cid)),
424
-
("p".to_string(), Ipld::Integer(0)),
425
-
("t".to_string(), Ipld::Link(missing_subtree_cid)),
426
-
]));
427
-
428
-
let node = Ipld::Map(std::collections::BTreeMap::from([
429
-
("e".to_string(), Ipld::List(vec![entry])),
430
-
]));
431
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
432
-
let cid = make_cid(&node_bytes);
433
-
434
-
let mut blocks = HashMap::new();
435
-
blocks.insert(cid, Bytes::from(node_bytes));
436
-
437
-
let result = verifier.verify_mst_structure(&cid, &blocks);
438
-
439
-
assert!(result.is_err());
440
-
let err = result.unwrap_err();
441
-
assert!(matches!(err, VerifyError::BlockNotFound(_)));
442
-
assert!(err.to_string().contains("subtree"));
443
-
}
444
-
445
-
#[test]
446
-
fn test_mst_validation_unsorted_keys() {
447
-
use ipld_core::ipld::Ipld;
448
-
449
-
let verifier = CarVerifier::new();
450
-
451
-
let record_cid = make_cid(b"record");
452
-
453
-
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
454
-
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
455
-
("v".to_string(), Ipld::Link(record_cid)),
456
-
("p".to_string(), Ipld::Integer(0)),
457
-
]));
458
-
459
-
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
460
-
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
461
-
("v".to_string(), Ipld::Link(record_cid)),
462
-
("p".to_string(), Ipld::Integer(0)),
463
-
]));
464
-
465
-
let node = Ipld::Map(std::collections::BTreeMap::from([
466
-
("e".to_string(), Ipld::List(vec![entry1, entry2])),
467
-
]));
468
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
469
-
let cid = make_cid(&node_bytes);
470
-
471
-
let mut blocks = HashMap::new();
472
-
blocks.insert(cid, Bytes::from(node_bytes));
473
-
474
-
let result = verifier.verify_mst_structure(&cid, &blocks);
475
-
476
-
assert!(result.is_err());
477
-
let err = result.unwrap_err();
478
-
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
479
-
assert!(err.to_string().contains("sorted"));
480
-
}
481
-
482
-
#[test]
483
-
fn test_mst_validation_sorted_keys_ok() {
484
-
use ipld_core::ipld::Ipld;
485
-
486
-
let verifier = CarVerifier::new();
487
-
488
-
let record_cid = make_cid(b"record");
489
-
490
-
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
491
-
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
492
-
("v".to_string(), Ipld::Link(record_cid)),
493
-
("p".to_string(), Ipld::Integer(0)),
494
-
]));
495
-
496
-
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
497
-
("k".to_string(), Ipld::Bytes(b"bbb".to_vec())),
498
-
("v".to_string(), Ipld::Link(record_cid)),
499
-
("p".to_string(), Ipld::Integer(0)),
500
-
]));
501
-
502
-
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
503
-
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
504
-
("v".to_string(), Ipld::Link(record_cid)),
505
-
("p".to_string(), Ipld::Integer(0)),
506
-
]));
507
-
508
-
let node = Ipld::Map(std::collections::BTreeMap::from([
509
-
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
510
-
]));
511
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
512
-
let cid = make_cid(&node_bytes);
513
-
514
-
let mut blocks = HashMap::new();
515
-
blocks.insert(cid, Bytes::from(node_bytes));
516
-
517
-
let result = verifier.verify_mst_structure(&cid, &blocks);
518
-
assert!(result.is_ok());
519
-
}
520
-
521
-
#[test]
522
-
fn test_mst_validation_with_valid_left_pointer() {
523
-
use ipld_core::ipld::Ipld;
524
-
525
-
let verifier = CarVerifier::new();
526
-
527
-
let left_node = Ipld::Map(std::collections::BTreeMap::from([
528
-
("e".to_string(), Ipld::List(vec![])),
529
-
]));
530
-
let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap();
531
-
let left_cid = make_cid(&left_node_bytes);
532
-
533
-
let root_node = Ipld::Map(std::collections::BTreeMap::from([
534
-
("l".to_string(), Ipld::Link(left_cid)),
535
-
("e".to_string(), Ipld::List(vec![])),
536
-
]));
537
-
let root_node_bytes = serde_ipld_dagcbor::to_vec(&root_node).unwrap();
538
-
let root_cid = make_cid(&root_node_bytes);
539
-
540
-
let mut blocks = HashMap::new();
541
-
blocks.insert(root_cid, Bytes::from(root_node_bytes));
542
-
blocks.insert(left_cid, Bytes::from(left_node_bytes));
543
-
544
-
let result = verifier.verify_mst_structure(&root_cid, &blocks);
545
-
assert!(result.is_ok());
546
-
}
547
-
548
-
#[test]
549
-
fn test_mst_validation_cycle_detection() {
550
-
let verifier = CarVerifier::new();
551
-
552
-
let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
553
-
"e": []
554
-
})).unwrap();
555
-
let cid = make_cid(&node);
556
-
557
-
let mut blocks = HashMap::new();
558
-
blocks.insert(cid, Bytes::from(node));
559
-
560
-
let result = verifier.verify_mst_structure(&cid, &blocks);
561
-
assert!(result.is_ok());
562
-
}
563
-
564
-
#[tokio::test]
565
-
async fn test_unsupported_did_method() {
566
-
let verifier = CarVerifier::new();
567
-
let result = verifier.resolve_did_document("did:unknown:test").await;
568
-
569
-
assert!(result.is_err());
570
-
let err = result.unwrap_err();
571
-
assert!(matches!(err, VerifyError::DidResolutionFailed(_)));
572
-
assert!(err.to_string().contains("Unsupported"));
573
-
}
574
-
575
-
#[test]
576
-
fn test_mst_validation_with_prefix_compression() {
577
-
use ipld_core::ipld::Ipld;
578
-
579
-
let verifier = CarVerifier::new();
580
-
let record_cid = make_cid(b"record");
581
-
582
-
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
583
-
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())),
584
-
("v".to_string(), Ipld::Link(record_cid)),
585
-
("p".to_string(), Ipld::Integer(0)),
586
-
]));
587
-
588
-
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
589
-
("k".to_string(), Ipld::Bytes(b"def".to_vec())),
590
-
("v".to_string(), Ipld::Link(record_cid)),
591
-
("p".to_string(), Ipld::Integer(19)),
592
-
]));
593
-
594
-
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
595
-
("k".to_string(), Ipld::Bytes(b"xyz".to_vec())),
596
-
("v".to_string(), Ipld::Link(record_cid)),
597
-
("p".to_string(), Ipld::Integer(19)),
598
-
]));
599
-
600
-
let node = Ipld::Map(std::collections::BTreeMap::from([
601
-
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
602
-
]));
603
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
604
-
let cid = make_cid(&node_bytes);
605
-
606
-
let mut blocks = HashMap::new();
607
-
blocks.insert(cid, Bytes::from(node_bytes));
608
-
609
-
let result = verifier.verify_mst_structure(&cid, &blocks);
610
-
assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly");
611
-
}
612
-
613
-
#[test]
614
-
fn test_mst_validation_prefix_compression_unsorted() {
615
-
use ipld_core::ipld::Ipld;
616
-
617
-
let verifier = CarVerifier::new();
618
-
let record_cid = make_cid(b"record");
619
-
620
-
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
621
-
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())),
622
-
("v".to_string(), Ipld::Link(record_cid)),
623
-
("p".to_string(), Ipld::Integer(0)),
624
-
]));
625
-
626
-
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
627
-
("k".to_string(), Ipld::Bytes(b"abc".to_vec())),
628
-
("v".to_string(), Ipld::Link(record_cid)),
629
-
("p".to_string(), Ipld::Integer(19)),
630
-
]));
631
-
632
-
let node = Ipld::Map(std::collections::BTreeMap::from([
633
-
("e".to_string(), Ipld::List(vec![entry1, entry2])),
634
-
]));
635
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
636
-
let cid = make_cid(&node_bytes);
637
-
638
-
let mut blocks = HashMap::new();
639
-
blocks.insert(cid, Bytes::from(node_bytes));
640
-
641
-
let result = verifier.verify_mst_structure(&cid, &blocks);
642
-
assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation");
643
-
let err = result.unwrap_err();
644
-
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
645
-
}
646
-
}
305
+
#[path = "verify_tests.rs"]
306
+
mod tests;
+346
src/sync/verify_tests.rs
+346
src/sync/verify_tests.rs
···
1
+
#[cfg(test)]
2
+
mod tests {
3
+
use crate::sync::verify::{CarVerifier, VerifyError};
4
+
use bytes::Bytes;
5
+
use cid::Cid;
6
+
use sha2::{Digest, Sha256};
7
+
use std::collections::HashMap;
8
+
9
+
fn make_cid(data: &[u8]) -> Cid {
10
+
let mut hasher = Sha256::new();
11
+
hasher.update(data);
12
+
let hash = hasher.finalize();
13
+
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
14
+
Cid::new_v1(0x71, multihash)
15
+
}
16
+
17
+
#[test]
18
+
fn test_verifier_creation() {
19
+
let _verifier = CarVerifier::new();
20
+
}
21
+
22
+
#[test]
23
+
fn test_verify_error_display() {
24
+
let err = VerifyError::DidMismatch {
25
+
commit_did: "did:plc:abc".to_string(),
26
+
expected_did: "did:plc:xyz".to_string(),
27
+
};
28
+
assert!(err.to_string().contains("did:plc:abc"));
29
+
assert!(err.to_string().contains("did:plc:xyz"));
30
+
31
+
let err = VerifyError::InvalidSignature;
32
+
assert!(err.to_string().contains("signature"));
33
+
34
+
let err = VerifyError::NoSigningKey;
35
+
assert!(err.to_string().contains("signing key"));
36
+
37
+
let err = VerifyError::MstValidationFailed("test error".to_string());
38
+
assert!(err.to_string().contains("test error"));
39
+
}
40
+
41
+
#[test]
42
+
fn test_mst_validation_missing_root_block() {
43
+
let verifier = CarVerifier::new();
44
+
let blocks: HashMap<Cid, Bytes> = HashMap::new();
45
+
46
+
let fake_cid = make_cid(b"fake data");
47
+
let result = verifier.verify_mst_structure(&fake_cid, &blocks);
48
+
49
+
assert!(result.is_err());
50
+
let err = result.unwrap_err();
51
+
assert!(matches!(err, VerifyError::BlockNotFound(_)));
52
+
}
53
+
54
+
#[test]
55
+
fn test_mst_validation_invalid_cbor() {
56
+
let verifier = CarVerifier::new();
57
+
58
+
let bad_cbor = Bytes::from(vec![0xFF, 0xFF, 0xFF]);
59
+
let cid = make_cid(&bad_cbor);
60
+
61
+
let mut blocks = HashMap::new();
62
+
blocks.insert(cid, bad_cbor);
63
+
64
+
let result = verifier.verify_mst_structure(&cid, &blocks);
65
+
66
+
assert!(result.is_err());
67
+
let err = result.unwrap_err();
68
+
assert!(matches!(err, VerifyError::InvalidCbor(_)));
69
+
}
70
+
71
+
#[test]
72
+
fn test_mst_validation_empty_node() {
73
+
let verifier = CarVerifier::new();
74
+
75
+
let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
76
+
"e": []
77
+
})).unwrap();
78
+
let cid = make_cid(&empty_node);
79
+
80
+
let mut blocks = HashMap::new();
81
+
blocks.insert(cid, Bytes::from(empty_node));
82
+
83
+
let result = verifier.verify_mst_structure(&cid, &blocks);
84
+
assert!(result.is_ok());
85
+
}
86
+
87
+
#[test]
88
+
fn test_mst_validation_missing_left_pointer() {
89
+
use ipld_core::ipld::Ipld;
90
+
91
+
let verifier = CarVerifier::new();
92
+
93
+
let missing_left_cid = make_cid(b"missing left");
94
+
let node = Ipld::Map(std::collections::BTreeMap::from([
95
+
("l".to_string(), Ipld::Link(missing_left_cid)),
96
+
("e".to_string(), Ipld::List(vec![])),
97
+
]));
98
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
99
+
let cid = make_cid(&node_bytes);
100
+
101
+
let mut blocks = HashMap::new();
102
+
blocks.insert(cid, Bytes::from(node_bytes));
103
+
104
+
let result = verifier.verify_mst_structure(&cid, &blocks);
105
+
106
+
assert!(result.is_err());
107
+
let err = result.unwrap_err();
108
+
assert!(matches!(err, VerifyError::BlockNotFound(_)));
109
+
assert!(err.to_string().contains("left pointer"));
110
+
}
111
+
112
+
#[test]
113
+
fn test_mst_validation_missing_subtree() {
114
+
use ipld_core::ipld::Ipld;
115
+
116
+
let verifier = CarVerifier::new();
117
+
118
+
let missing_subtree_cid = make_cid(b"missing subtree");
119
+
let record_cid = make_cid(b"record");
120
+
121
+
let entry = Ipld::Map(std::collections::BTreeMap::from([
122
+
("k".to_string(), Ipld::Bytes(b"key1".to_vec())),
123
+
("v".to_string(), Ipld::Link(record_cid)),
124
+
("p".to_string(), Ipld::Integer(0)),
125
+
("t".to_string(), Ipld::Link(missing_subtree_cid)),
126
+
]));
127
+
128
+
let node = Ipld::Map(std::collections::BTreeMap::from([
129
+
("e".to_string(), Ipld::List(vec![entry])),
130
+
]));
131
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
132
+
let cid = make_cid(&node_bytes);
133
+
134
+
let mut blocks = HashMap::new();
135
+
blocks.insert(cid, Bytes::from(node_bytes));
136
+
137
+
let result = verifier.verify_mst_structure(&cid, &blocks);
138
+
139
+
assert!(result.is_err());
140
+
let err = result.unwrap_err();
141
+
assert!(matches!(err, VerifyError::BlockNotFound(_)));
142
+
assert!(err.to_string().contains("subtree"));
143
+
}
144
+
145
+
#[test]
146
+
fn test_mst_validation_unsorted_keys() {
147
+
use ipld_core::ipld::Ipld;
148
+
149
+
let verifier = CarVerifier::new();
150
+
151
+
let record_cid = make_cid(b"record");
152
+
153
+
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
154
+
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
155
+
("v".to_string(), Ipld::Link(record_cid)),
156
+
("p".to_string(), Ipld::Integer(0)),
157
+
]));
158
+
159
+
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
160
+
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
161
+
("v".to_string(), Ipld::Link(record_cid)),
162
+
("p".to_string(), Ipld::Integer(0)),
163
+
]));
164
+
165
+
let node = Ipld::Map(std::collections::BTreeMap::from([
166
+
("e".to_string(), Ipld::List(vec![entry1, entry2])),
167
+
]));
168
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
169
+
let cid = make_cid(&node_bytes);
170
+
171
+
let mut blocks = HashMap::new();
172
+
blocks.insert(cid, Bytes::from(node_bytes));
173
+
174
+
let result = verifier.verify_mst_structure(&cid, &blocks);
175
+
176
+
assert!(result.is_err());
177
+
let err = result.unwrap_err();
178
+
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
179
+
assert!(err.to_string().contains("sorted"));
180
+
}
181
+
182
+
#[test]
183
+
fn test_mst_validation_sorted_keys_ok() {
184
+
use ipld_core::ipld::Ipld;
185
+
186
+
let verifier = CarVerifier::new();
187
+
188
+
let record_cid = make_cid(b"record");
189
+
190
+
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
191
+
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
192
+
("v".to_string(), Ipld::Link(record_cid)),
193
+
("p".to_string(), Ipld::Integer(0)),
194
+
]));
195
+
196
+
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
197
+
("k".to_string(), Ipld::Bytes(b"bbb".to_vec())),
198
+
("v".to_string(), Ipld::Link(record_cid)),
199
+
("p".to_string(), Ipld::Integer(0)),
200
+
]));
201
+
202
+
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
203
+
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
204
+
("v".to_string(), Ipld::Link(record_cid)),
205
+
("p".to_string(), Ipld::Integer(0)),
206
+
]));
207
+
208
+
let node = Ipld::Map(std::collections::BTreeMap::from([
209
+
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
210
+
]));
211
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
212
+
let cid = make_cid(&node_bytes);
213
+
214
+
let mut blocks = HashMap::new();
215
+
blocks.insert(cid, Bytes::from(node_bytes));
216
+
217
+
let result = verifier.verify_mst_structure(&cid, &blocks);
218
+
assert!(result.is_ok());
219
+
}
220
+
221
+
#[test]
222
+
fn test_mst_validation_with_valid_left_pointer() {
223
+
use ipld_core::ipld::Ipld;
224
+
225
+
let verifier = CarVerifier::new();
226
+
227
+
let left_node = Ipld::Map(std::collections::BTreeMap::from([
228
+
("e".to_string(), Ipld::List(vec![])),
229
+
]));
230
+
let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap();
231
+
let left_cid = make_cid(&left_node_bytes);
232
+
233
+
let root_node = Ipld::Map(std::collections::BTreeMap::from([
234
+
("l".to_string(), Ipld::Link(left_cid)),
235
+
("e".to_string(), Ipld::List(vec![])),
236
+
]));
237
+
let root_node_bytes = serde_ipld_dagcbor::to_vec(&root_node).unwrap();
238
+
let root_cid = make_cid(&root_node_bytes);
239
+
240
+
let mut blocks = HashMap::new();
241
+
blocks.insert(root_cid, Bytes::from(root_node_bytes));
242
+
blocks.insert(left_cid, Bytes::from(left_node_bytes));
243
+
244
+
let result = verifier.verify_mst_structure(&root_cid, &blocks);
245
+
assert!(result.is_ok());
246
+
}
247
+
248
+
#[test]
249
+
fn test_mst_validation_cycle_detection() {
250
+
let verifier = CarVerifier::new();
251
+
252
+
let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
253
+
"e": []
254
+
})).unwrap();
255
+
let cid = make_cid(&node);
256
+
257
+
let mut blocks = HashMap::new();
258
+
blocks.insert(cid, Bytes::from(node));
259
+
260
+
let result = verifier.verify_mst_structure(&cid, &blocks);
261
+
assert!(result.is_ok());
262
+
}
263
+
264
+
#[tokio::test]
265
+
async fn test_unsupported_did_method() {
266
+
let verifier = CarVerifier::new();
267
+
let result = verifier.resolve_did_document("did:unknown:test").await;
268
+
269
+
assert!(result.is_err());
270
+
let err = result.unwrap_err();
271
+
assert!(matches!(err, VerifyError::DidResolutionFailed(_)));
272
+
assert!(err.to_string().contains("Unsupported"));
273
+
}
274
+
275
+
#[test]
276
+
fn test_mst_validation_with_prefix_compression() {
277
+
use ipld_core::ipld::Ipld;
278
+
279
+
let verifier = CarVerifier::new();
280
+
let record_cid = make_cid(b"record");
281
+
282
+
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
283
+
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())),
284
+
("v".to_string(), Ipld::Link(record_cid)),
285
+
("p".to_string(), Ipld::Integer(0)),
286
+
]));
287
+
288
+
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
289
+
("k".to_string(), Ipld::Bytes(b"def".to_vec())),
290
+
("v".to_string(), Ipld::Link(record_cid)),
291
+
("p".to_string(), Ipld::Integer(19)),
292
+
]));
293
+
294
+
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
295
+
("k".to_string(), Ipld::Bytes(b"xyz".to_vec())),
296
+
("v".to_string(), Ipld::Link(record_cid)),
297
+
("p".to_string(), Ipld::Integer(19)),
298
+
]));
299
+
300
+
let node = Ipld::Map(std::collections::BTreeMap::from([
301
+
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
302
+
]));
303
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
304
+
let cid = make_cid(&node_bytes);
305
+
306
+
let mut blocks = HashMap::new();
307
+
blocks.insert(cid, Bytes::from(node_bytes));
308
+
309
+
let result = verifier.verify_mst_structure(&cid, &blocks);
310
+
assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly");
311
+
}
312
+
313
+
#[test]
314
+
fn test_mst_validation_prefix_compression_unsorted() {
315
+
use ipld_core::ipld::Ipld;
316
+
317
+
let verifier = CarVerifier::new();
318
+
let record_cid = make_cid(b"record");
319
+
320
+
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
321
+
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())),
322
+
("v".to_string(), Ipld::Link(record_cid)),
323
+
("p".to_string(), Ipld::Integer(0)),
324
+
]));
325
+
326
+
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
327
+
("k".to_string(), Ipld::Bytes(b"abc".to_vec())),
328
+
("v".to_string(), Ipld::Link(record_cid)),
329
+
("p".to_string(), Ipld::Integer(19)),
330
+
]));
331
+
332
+
let node = Ipld::Map(std::collections::BTreeMap::from([
333
+
("e".to_string(), Ipld::List(vec![entry1, entry2])),
334
+
]));
335
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
336
+
let cid = make_cid(&node_bytes);
337
+
338
+
let mut blocks = HashMap::new();
339
+
blocks.insert(cid, Bytes::from(node_bytes));
340
+
341
+
let result = verifier.verify_mst_structure(&cid, &blocks);
342
+
assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation");
343
+
let err = result.unwrap_err();
344
+
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
345
+
}
346
+
}
+103
src/util.rs
+103
src/util.rs
···
1
+
use rand::Rng;
2
+
use sqlx::PgPool;
3
+
use uuid::Uuid;
4
+
5
+
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
6
+
7
+
pub fn generate_token_code() -> String {
8
+
generate_token_code_parts(2, 5)
9
+
}
10
+
11
+
pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
12
+
let mut rng = rand::thread_rng();
13
+
let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
14
+
15
+
(0..parts)
16
+
.map(|_| {
17
+
(0..part_len)
18
+
.map(|_| chars[rng.gen_range(0..chars.len())])
19
+
.collect::<String>()
20
+
})
21
+
.collect::<Vec<_>>()
22
+
.join("-")
23
+
}
24
+
25
+
#[derive(Debug)]
26
+
pub enum DbLookupError {
27
+
NotFound,
28
+
DatabaseError(sqlx::Error),
29
+
}
30
+
31
+
impl From<sqlx::Error> for DbLookupError {
32
+
fn from(e: sqlx::Error) -> Self {
33
+
DbLookupError::DatabaseError(e)
34
+
}
35
+
}
36
+
37
+
pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
38
+
sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
39
+
.fetch_optional(db)
40
+
.await?
41
+
.ok_or(DbLookupError::NotFound)
42
+
}
43
+
44
+
pub struct UserInfo {
45
+
pub id: Uuid,
46
+
pub did: String,
47
+
pub handle: String,
48
+
}
49
+
50
+
pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
51
+
sqlx::query_as!(
52
+
UserInfo,
53
+
"SELECT id, did, handle FROM users WHERE did = $1",
54
+
did
55
+
)
56
+
.fetch_optional(db)
57
+
.await?
58
+
.ok_or(DbLookupError::NotFound)
59
+
}
60
+
61
+
pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> {
62
+
sqlx::query_as!(
63
+
UserInfo,
64
+
"SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
65
+
identifier
66
+
)
67
+
.fetch_optional(db)
68
+
.await?
69
+
.ok_or(DbLookupError::NotFound)
70
+
}
71
+
72
+
#[cfg(test)]
73
+
mod tests {
74
+
use super::*;
75
+
76
+
#[test]
77
+
fn test_generate_token_code() {
78
+
let code = generate_token_code();
79
+
assert_eq!(code.len(), 11);
80
+
assert!(code.contains('-'));
81
+
82
+
let parts: Vec<&str> = code.split('-').collect();
83
+
assert_eq!(parts.len(), 2);
84
+
assert_eq!(parts[0].len(), 5);
85
+
assert_eq!(parts[1].len(), 5);
86
+
87
+
for c in code.chars() {
88
+
if c != '-' {
89
+
assert!(BASE32_ALPHABET.contains(c));
90
+
}
91
+
}
92
+
}
93
+
94
+
#[test]
95
+
fn test_generate_token_code_parts() {
96
+
let code = generate_token_code_parts(3, 4);
97
+
let parts: Vec<&str> = code.split('-').collect();
98
+
assert_eq!(parts.len(), 3);
99
+
for part in parts {
100
+
assert_eq!(part.len(), 4);
101
+
}
102
+
}
103
+
}
+1
-1
tests/email_update.rs
+1
-1
tests/email_update.rs
+30
-12
tests/relay_client.rs
+30
-12
tests/relay_client.rs
···
13
13
async fn mock_relay_server(
14
14
listener: TcpListener,
15
15
event_tx: mpsc::Sender<Vec<u8>>,
16
-
ready_tx: mpsc::Sender<()>,
16
+
connected_tx: mpsc::Sender<()>,
17
17
) {
18
18
let handler = |ws: axum::extract::ws::WebSocketUpgrade| async {
19
19
ws.on_upgrade(move |mut socket| async move {
20
-
ready_tx.send(()).await.unwrap();
21
-
if let Some(Ok(Message::Binary(bytes))) = socket.recv().await {
22
-
event_tx.send(bytes.to_vec()).await.unwrap();
20
+
let _ = connected_tx.send(()).await;
21
+
while let Some(Ok(msg)) = socket.recv().await {
22
+
if let Message::Binary(bytes) = msg {
23
+
let _ = event_tx.send(bytes.to_vec()).await;
24
+
break;
25
+
}
23
26
}
24
27
})
25
28
};
···
35
38
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
36
39
let addr = listener.local_addr().unwrap();
37
40
let (event_tx, mut event_rx) = mpsc::channel(1);
38
-
let (ready_tx, ready_rx) = mpsc::channel(1);
39
-
tokio::spawn(mock_relay_server(listener, event_tx, ready_tx));
41
+
let (connected_tx, _connected_rx) = mpsc::channel::<()>(1);
42
+
tokio::spawn(mock_relay_server(listener, event_tx, connected_tx));
40
43
let relay_url = format!("ws://{}", addr);
41
44
42
45
let db_url = get_db_connection_string().await;
···
46
49
.unwrap();
47
50
let state = AppState::new(pool).await;
48
51
52
+
let (ready_tx, ready_rx) = mpsc::channel(1);
49
53
start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await;
50
54
51
-
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
55
+
tokio::time::timeout(
56
+
tokio::time::Duration::from_secs(5),
57
+
async {
58
+
ready_tx.closed().await;
59
+
}
60
+
)
61
+
.await
62
+
.expect("Timeout waiting for relay client to be ready");
52
63
53
64
let dummy_event = SequencedEvent {
54
65
seq: 1,
55
66
did: "did:plc:test".to_string(),
56
67
created_at: Utc::now(),
57
68
event_type: "commit".to_string(),
58
-
commit_cid: None,
69
+
commit_cid: Some("bafyreihffx5a4o3qbv7vp6qmxpxok5mx5xvlsq6z4x3xv3zqv7vqvc7mzy".to_string()),
59
70
prev_cid: None,
60
-
ops: None,
61
-
blobs: None,
62
-
blocks_cids: None,
71
+
ops: Some(serde_json::json!([])),
72
+
blobs: Some(vec![]),
73
+
blocks_cids: Some(vec![]),
63
74
};
64
75
state.firehose_tx.send(dummy_event).unwrap();
65
76
66
-
let received_bytes = event_rx.recv().await.expect("Did not receive event");
77
+
let received_bytes = tokio::time::timeout(
78
+
tokio::time::Duration::from_secs(5),
79
+
event_rx.recv()
80
+
)
81
+
.await
82
+
.expect("Timeout waiting for event")
83
+
.expect("Event channel closed");
84
+
67
85
assert!(!received_bytes.is_empty());
68
86
}