+2
-2
.sqlx/query-3889903e58405370152b9ded229d843c0114e71454ea7da2b212519e98d09817.json
.sqlx/query-e2e51654f146a3a336f5a28cbd47addbdd311aeaead530c00c1891c95bede0b8.json
+2
-2
.sqlx/query-3889903e58405370152b9ded229d843c0114e71454ea7da2b212519e98d09817.json
.sqlx/query-e2e51654f146a3a336f5a28cbd47addbdd311aeaead530c00c1891c95bede0b8.json
···
1
{
2
"db_name": "PostgreSQL",
3
-
"query": "SELECT st.id, st.did, k.key_bytes, k.encryption_version\n FROM session_tokens st\n JOIN users u ON st.did = u.did\n JOIN user_keys k ON u.id = k.user_id\n WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()",
4
"describe": {
5
"columns": [
6
{
···
36
true
37
]
38
},
39
-
"hash": "3889903e58405370152b9ded229d843c0114e71454ea7da2b212519e98d09817"
40
}
···
1
{
2
"db_name": "PostgreSQL",
3
+
"query": "SELECT st.id, st.did, k.key_bytes, k.encryption_version\n FROM session_tokens st\n JOIN users u ON st.did = u.did\n JOIN user_keys k ON u.id = k.user_id\n WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()\n FOR UPDATE OF st",
4
"describe": {
5
"columns": [
6
{
···
36
true
37
]
38
},
39
+
"hash": "e2e51654f146a3a336f5a28cbd47addbdd311aeaead530c00c1891c95bede0b8"
40
}
+3
-2
.sqlx/query-51809819130908ef3600e5843f6098fb510afb4c827a41bc3a32ad78ea10184c.json
.sqlx/query-b1c54d3f3e2d3031c0d926ccb0d39a0250320d41d08df65d6d9dcc640451527d.json
+3
-2
.sqlx/query-51809819130908ef3600e5843f6098fb510afb4c827a41bc3a32ad78ea10184c.json
.sqlx/query-b1c54d3f3e2d3031c0d926ccb0d39a0250320d41d08df65d6d9dcc640451527d.json
···
1
{
2
"db_name": "PostgreSQL",
3
-
"query": "\n SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n ",
4
"describe": {
5
"columns": [
6
{
···
51
],
52
"parameters": {
53
"Left": [
54
"Int8"
55
]
56
},
···
66
true
67
]
68
},
69
-
"hash": "51809819130908ef3600e5843f6098fb510afb4c827a41bc3a32ad78ea10184c"
70
}
···
1
{
2
"db_name": "PostgreSQL",
3
+
"query": "\n SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2\n ",
4
"describe": {
5
"columns": [
6
{
···
51
],
52
"parameters": {
53
"Left": [
54
+
"Int8",
55
"Int8"
56
]
57
},
···
67
true
68
]
69
},
70
+
"hash": "b1c54d3f3e2d3031c0d926ccb0d39a0250320d41d08df65d6d9dcc640451527d"
71
}
+14
.sqlx/query-642b7199f2cbde74af72fc5b5b80f9e2b3efe901a3fdfc732f0d36d00db6326f.json
+14
.sqlx/query-642b7199f2cbde74af72fc5b5b80f9e2b3efe901a3fdfc732f0d36d00db6326f.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "DELETE FROM invite_codes WHERE created_by_user = $1",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Uuid"
9
+
]
10
+
},
11
+
"nullable": []
12
+
},
13
+
"hash": "642b7199f2cbde74af72fc5b5b80f9e2b3efe901a3fdfc732f0d36d00db6326f"
14
+
}
+14
.sqlx/query-6c71c4ac31f897e9d33a3637d89377c5977f76a117b042e1800b890b84a655ea.json
+14
.sqlx/query-6c71c4ac31f897e9d33a3637d89377c5977f76a117b042e1800b890b84a655ea.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "DELETE FROM invite_code_uses WHERE used_by_user = $1",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Uuid"
9
+
]
10
+
},
11
+
"nullable": []
12
+
},
13
+
"hash": "6c71c4ac31f897e9d33a3637d89377c5977f76a117b042e1800b890b84a655ea"
14
+
}
+2
-2
.sqlx/query-7b76e2fcd809a1536465306c79da7985354175e0f025b29c6004dffa310feebd.json
.sqlx/query-9a8b9c1cfecf02d1266b1544d5cb2dd8f1254b66b884ff22f983a2ba9dee0529.json
+2
-2
.sqlx/query-7b76e2fcd809a1536465306c79da7985354175e0f025b29c6004dffa310feebd.json
.sqlx/query-9a8b9c1cfecf02d1266b1544d5cb2dd8f1254b66b884ff22f983a2ba9dee0529.json
···
1
{
2
"db_name": "PostgreSQL",
3
+
"query": "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING",
4
"describe": {
5
"columns": [],
6
"parameters": {
···
11
},
12
"nullable": []
13
},
14
+
"hash": "9a8b9c1cfecf02d1266b1544d5cb2dd8f1254b66b884ff22f983a2ba9dee0529"
15
}
+34
.sqlx/query-9f435d95d7c270c82a164c59e9d0caa80ffd7107aff32c806709973fdc6b0020.json
+34
.sqlx/query-9f435d95d7c270c82a164c59e9d0caa80ffd7107aff32c806709973fdc6b0020.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "did",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "handle",
19
+
"type_info": "Text"
20
+
}
21
+
],
22
+
"parameters": {
23
+
"Left": [
24
+
"Text"
25
+
]
26
+
},
27
+
"nullable": [
28
+
false,
29
+
false,
30
+
false
31
+
]
32
+
},
33
+
"hash": "9f435d95d7c270c82a164c59e9d0caa80ffd7107aff32c806709973fdc6b0020"
34
+
}
+34
.sqlx/query-b22827038d6041ad1f3b7eae07d77433def15237391fe26004577b12cb7e95b3.json
+34
.sqlx/query-b22827038d6041ad1f3b7eae07d77433def15237391fe26004577b12cb7e95b3.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "SELECT id, did, handle FROM users WHERE did = $1",
4
+
"describe": {
5
+
"columns": [
6
+
{
7
+
"ordinal": 0,
8
+
"name": "id",
9
+
"type_info": "Uuid"
10
+
},
11
+
{
12
+
"ordinal": 1,
13
+
"name": "did",
14
+
"type_info": "Text"
15
+
},
16
+
{
17
+
"ordinal": 2,
18
+
"name": "handle",
19
+
"type_info": "Text"
20
+
}
21
+
],
22
+
"parameters": {
23
+
"Left": [
24
+
"Text"
25
+
]
26
+
},
27
+
"nullable": [
28
+
false,
29
+
false,
30
+
false
31
+
]
32
+
},
33
+
"hash": "b22827038d6041ad1f3b7eae07d77433def15237391fe26004577b12cb7e95b3"
34
+
}
+14
.sqlx/query-c583f0016bf5f61c17781f55d121698e81b2314465321a01916ee7902b17e813.json
+14
.sqlx/query-c583f0016bf5f61c17781f55d121698e81b2314465321a01916ee7902b17e813.json
···
···
1
+
{
2
+
"db_name": "PostgreSQL",
3
+
"query": "DELETE FROM used_refresh_tokens WHERE session_id IN (SELECT id FROM session_tokens WHERE did = $1)",
4
+
"describe": {
5
+
"columns": [],
6
+
"parameters": {
7
+
"Left": [
8
+
"Text"
9
+
]
10
+
},
11
+
"nullable": []
12
+
},
13
+
"hash": "c583f0016bf5f61c17781f55d121698e81b2314465321a01916ee7902b17e813"
14
+
}
+2
-2
.sqlx/query-fcd868a192d27fd4eccae92a884e881b8d6f09bf7ae08a9b431a44acbf2f91f3.json
.sqlx/query-b2e1736dbe2ab9114e373353bcc299176417f3c9220025f9521591ba62928bd7.json
+2
-2
.sqlx/query-fcd868a192d27fd4eccae92a884e881b8d6f09bf7ae08a9b431a44acbf2f91f3.json
.sqlx/query-b2e1736dbe2ab9114e373353bcc299176417f3c9220025f9521591ba62928bd7.json
+4
-4
TODO.md
+4
-4
TODO.md
···
253
### Frontend Views
254
Uses existing ATProto endpoints where possible:
255
256
-
**User Dashboard**
257
- [ ] Account overview (uses `com.atproto.server.getSession`, `com.atproto.admin.getAccountInfo`)
258
- [ ] Active sessions view (needs new endpoint or extend existing)
259
- [ ] App passwords (uses `com.atproto.server.listAppPasswords`, `createAppPassword`, `revokeAppPassword`)
260
- [ ] Invite codes (uses `com.atproto.server.getAccountInviteCodes`, `createInviteCode`)
261
262
-
**Notification Preferences**
263
- [ ] Channel selector (uses `com.bspds.account.*` endpoints above)
264
- [ ] Verification flows for Discord/Telegram/Signal
265
- [ ] Notification history view
266
267
-
**Account Settings**
268
- [ ] Email change (uses `com.atproto.server.requestEmailUpdate`, `updateEmail`)
269
- [ ] Password change (uses `com.atproto.server.requestPasswordReset`, `resetPassword`)
270
- [ ] Handle change (uses `com.atproto.identity.updateHandle`)
271
- [ ] Account deletion (uses `com.atproto.server.requestAccountDelete`, `deleteAccount`)
272
- [ ] Data export (uses `com.atproto.sync.getRepo`)
273
274
-
**Admin Dashboard** (privileged users only)
275
- [ ] User list (uses `com.atproto.admin.getAccountInfos` with pagination)
276
- [ ] User detail/actions (uses `com.atproto.admin.*` endpoints)
277
- [ ] Invite management (uses `com.atproto.admin.getInviteCodes`, `disableInviteCodes`)
···
253
### Frontend Views
254
Uses existing ATProto endpoints where possible:
255
256
+
User Dashboard
257
- [ ] Account overview (uses `com.atproto.server.getSession`, `com.atproto.admin.getAccountInfo`)
258
- [ ] Active sessions view (needs new endpoint or extend existing)
259
- [ ] App passwords (uses `com.atproto.server.listAppPasswords`, `createAppPassword`, `revokeAppPassword`)
260
- [ ] Invite codes (uses `com.atproto.server.getAccountInviteCodes`, `createInviteCode`)
261
262
+
Notification Preferences
263
- [ ] Channel selector (uses `com.bspds.account.*` endpoints above)
264
- [ ] Verification flows for Discord/Telegram/Signal
265
- [ ] Notification history view
266
267
+
Account Settings
268
- [ ] Email change (uses `com.atproto.server.requestEmailUpdate`, `updateEmail`)
269
- [ ] Password change (uses `com.atproto.server.requestPasswordReset`, `resetPassword`)
270
- [ ] Handle change (uses `com.atproto.identity.updateHandle`)
271
- [ ] Account deletion (uses `com.atproto.server.requestAccountDelete`, `deleteAccount`)
272
- [ ] Data export (uses `com.atproto.sync.getRepo`)
273
274
+
Admin Dashboard (privileged users only)
275
- [ ] User list (uses `com.atproto.admin.getAccountInfos` with pagination)
276
- [ ] User detail/actions (uses `com.atproto.admin.*` endpoints)
277
- [ ] Invite management (uses `com.atproto.admin.getInviteCodes`, `disableInviteCodes`)
+23
-1
src/api/actor/preferences.rs
+23
-1
src/api/actor/preferences.rs
···
9
use serde_json::{json, Value};
10
11
const APP_BSKY_NAMESPACE: &str = "app.bsky";
12
13
#[derive(Serialize)]
14
pub struct GetPreferencesOutput {
···
141
}
142
};
143
144
for pref in &input.preferences {
145
let pref_type = match pref.get("$type").and_then(|t| t.as_str()) {
146
Some(t) => t,
147
None => {
···
200
}
201
202
for pref in input.preferences {
203
-
let pref_type = pref.get("$type").and_then(|t| t.as_str()).unwrap();
204
205
let insert_result = sqlx::query!(
206
"INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, $2, $3)",
···
9
use serde_json::{json, Value};
10
11
const APP_BSKY_NAMESPACE: &str = "app.bsky";
12
+
const MAX_PREFERENCES_COUNT: usize = 100;
13
+
const MAX_PREFERENCE_SIZE: usize = 10_000;
14
15
#[derive(Serialize)]
16
pub struct GetPreferencesOutput {
···
143
}
144
};
145
146
+
if input.preferences.len() > MAX_PREFERENCES_COUNT {
147
+
return (
148
+
StatusCode::BAD_REQUEST,
149
+
Json(json!({"error": "InvalidRequest", "message": format!("Too many preferences: {} exceeds limit of {}", input.preferences.len(), MAX_PREFERENCES_COUNT)})),
150
+
)
151
+
.into_response();
152
+
}
153
+
154
for pref in &input.preferences {
155
+
let pref_str = serde_json::to_string(pref).unwrap_or_default();
156
+
if pref_str.len() > MAX_PREFERENCE_SIZE {
157
+
return (
158
+
StatusCode::BAD_REQUEST,
159
+
Json(json!({"error": "InvalidRequest", "message": format!("Preference too large: {} bytes exceeds limit of {}", pref_str.len(), MAX_PREFERENCE_SIZE)})),
160
+
)
161
+
.into_response();
162
+
}
163
+
164
let pref_type = match pref.get("$type").and_then(|t| t.as_str()) {
165
Some(t) => t,
166
None => {
···
219
}
220
221
for pref in input.preferences {
222
+
let pref_type = match pref.get("$type").and_then(|t| t.as_str()) {
223
+
Some(t) => t,
224
+
None => continue,
225
+
};
226
227
let insert_result = sqlx::query!(
228
"INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, $2, $3)",
-564
src/api/admin/account.rs
-564
src/api/admin/account.rs
···
1
-
use crate::state::AppState;
2
-
use axum::{
3
-
Json,
4
-
extract::{Query, State},
5
-
http::StatusCode,
6
-
response::{IntoResponse, Response},
7
-
};
8
-
use serde::{Deserialize, Serialize};
9
-
use serde_json::json;
10
-
use tracing::{error, warn};
11
-
12
-
#[derive(Deserialize)]
13
-
pub struct GetAccountInfoParams {
14
-
pub did: String,
15
-
}
16
-
17
-
#[derive(Serialize)]
18
-
#[serde(rename_all = "camelCase")]
19
-
pub struct AccountInfo {
20
-
pub did: String,
21
-
pub handle: String,
22
-
pub email: Option<String>,
23
-
pub indexed_at: String,
24
-
pub invite_note: Option<String>,
25
-
pub invites_disabled: bool,
26
-
pub email_confirmed_at: Option<String>,
27
-
pub deactivated_at: Option<String>,
28
-
}
29
-
30
-
#[derive(Serialize)]
31
-
#[serde(rename_all = "camelCase")]
32
-
pub struct GetAccountInfosOutput {
33
-
pub infos: Vec<AccountInfo>,
34
-
}
35
-
36
-
pub async fn get_account_info(
37
-
State(state): State<AppState>,
38
-
headers: axum::http::HeaderMap,
39
-
Query(params): Query<GetAccountInfoParams>,
40
-
) -> Response {
41
-
let auth_header = headers.get("Authorization");
42
-
if auth_header.is_none() {
43
-
return (
44
-
StatusCode::UNAUTHORIZED,
45
-
Json(json!({"error": "AuthenticationRequired"})),
46
-
)
47
-
.into_response();
48
-
}
49
-
50
-
let did = params.did.trim();
51
-
if did.is_empty() {
52
-
return (
53
-
StatusCode::BAD_REQUEST,
54
-
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
55
-
)
56
-
.into_response();
57
-
}
58
-
59
-
let result = sqlx::query!(
60
-
r#"
61
-
SELECT did, handle, email, created_at
62
-
FROM users
63
-
WHERE did = $1
64
-
"#,
65
-
did
66
-
)
67
-
.fetch_optional(&state.db)
68
-
.await;
69
-
70
-
match result {
71
-
Ok(Some(row)) => {
72
-
(
73
-
StatusCode::OK,
74
-
Json(AccountInfo {
75
-
did: row.did,
76
-
handle: row.handle,
77
-
email: Some(row.email),
78
-
indexed_at: row.created_at.to_rfc3339(),
79
-
invite_note: None,
80
-
invites_disabled: false,
81
-
email_confirmed_at: None,
82
-
deactivated_at: None,
83
-
}),
84
-
)
85
-
.into_response()
86
-
}
87
-
Ok(None) => (
88
-
StatusCode::NOT_FOUND,
89
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
90
-
)
91
-
.into_response(),
92
-
Err(e) => {
93
-
error!("DB error in get_account_info: {:?}", e);
94
-
(
95
-
StatusCode::INTERNAL_SERVER_ERROR,
96
-
Json(json!({"error": "InternalError"})),
97
-
)
98
-
.into_response()
99
-
}
100
-
}
101
-
}
102
-
103
-
#[derive(Deserialize)]
104
-
pub struct GetAccountInfosParams {
105
-
pub dids: String,
106
-
}
107
-
108
-
pub async fn get_account_infos(
109
-
State(state): State<AppState>,
110
-
headers: axum::http::HeaderMap,
111
-
Query(params): Query<GetAccountInfosParams>,
112
-
) -> Response {
113
-
let auth_header = headers.get("Authorization");
114
-
if auth_header.is_none() {
115
-
return (
116
-
StatusCode::UNAUTHORIZED,
117
-
Json(json!({"error": "AuthenticationRequired"})),
118
-
)
119
-
.into_response();
120
-
}
121
-
122
-
let dids: Vec<&str> = params.dids.split(',').map(|s| s.trim()).collect();
123
-
if dids.is_empty() {
124
-
return (
125
-
StatusCode::BAD_REQUEST,
126
-
Json(json!({"error": "InvalidRequest", "message": "dids is required"})),
127
-
)
128
-
.into_response();
129
-
}
130
-
131
-
let mut infos = Vec::new();
132
-
133
-
for did in dids {
134
-
if did.is_empty() {
135
-
continue;
136
-
}
137
-
138
-
let result = sqlx::query!(
139
-
r#"
140
-
SELECT did, handle, email, created_at
141
-
FROM users
142
-
WHERE did = $1
143
-
"#,
144
-
did
145
-
)
146
-
.fetch_optional(&state.db)
147
-
.await;
148
-
149
-
if let Ok(Some(row)) = result {
150
-
infos.push(AccountInfo {
151
-
did: row.did,
152
-
handle: row.handle,
153
-
email: Some(row.email),
154
-
indexed_at: row.created_at.to_rfc3339(),
155
-
invite_note: None,
156
-
invites_disabled: false,
157
-
email_confirmed_at: None,
158
-
deactivated_at: None,
159
-
});
160
-
}
161
-
}
162
-
163
-
(StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response()
164
-
}
165
-
166
-
#[derive(Deserialize)]
167
-
pub struct DeleteAccountInput {
168
-
pub did: String,
169
-
}
170
-
171
-
pub async fn delete_account(
172
-
State(state): State<AppState>,
173
-
headers: axum::http::HeaderMap,
174
-
Json(input): Json<DeleteAccountInput>,
175
-
) -> Response {
176
-
let auth_header = headers.get("Authorization");
177
-
if auth_header.is_none() {
178
-
return (
179
-
StatusCode::UNAUTHORIZED,
180
-
Json(json!({"error": "AuthenticationRequired"})),
181
-
)
182
-
.into_response();
183
-
}
184
-
185
-
let did = input.did.trim();
186
-
if did.is_empty() {
187
-
return (
188
-
StatusCode::BAD_REQUEST,
189
-
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
190
-
)
191
-
.into_response();
192
-
}
193
-
194
-
let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
195
-
.fetch_optional(&state.db)
196
-
.await;
197
-
198
-
let user_id = match user {
199
-
Ok(Some(row)) => row.id,
200
-
Ok(None) => {
201
-
return (
202
-
StatusCode::NOT_FOUND,
203
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
204
-
)
205
-
.into_response();
206
-
}
207
-
Err(e) => {
208
-
error!("DB error in delete_account: {:?}", e);
209
-
return (
210
-
StatusCode::INTERNAL_SERVER_ERROR,
211
-
Json(json!({"error": "InternalError"})),
212
-
)
213
-
.into_response();
214
-
}
215
-
};
216
-
217
-
let _ = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did)
218
-
.execute(&state.db)
219
-
.await;
220
-
221
-
let _ = sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id)
222
-
.execute(&state.db)
223
-
.await;
224
-
225
-
let _ = sqlx::query!("DELETE FROM repos WHERE user_id = $1", user_id)
226
-
.execute(&state.db)
227
-
.await;
228
-
229
-
let _ = sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id)
230
-
.execute(&state.db)
231
-
.await;
232
-
233
-
let _ = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id)
234
-
.execute(&state.db)
235
-
.await;
236
-
237
-
let result = sqlx::query!("DELETE FROM users WHERE id = $1", user_id)
238
-
.execute(&state.db)
239
-
.await;
240
-
241
-
match result {
242
-
Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
243
-
Err(e) => {
244
-
error!("DB error deleting account: {:?}", e);
245
-
(
246
-
StatusCode::INTERNAL_SERVER_ERROR,
247
-
Json(json!({"error": "InternalError"})),
248
-
)
249
-
.into_response()
250
-
}
251
-
}
252
-
}
253
-
254
-
#[derive(Deserialize)]
255
-
pub struct UpdateAccountEmailInput {
256
-
pub account: String,
257
-
pub email: String,
258
-
}
259
-
260
-
pub async fn update_account_email(
261
-
State(state): State<AppState>,
262
-
headers: axum::http::HeaderMap,
263
-
Json(input): Json<UpdateAccountEmailInput>,
264
-
) -> Response {
265
-
let auth_header = headers.get("Authorization");
266
-
if auth_header.is_none() {
267
-
return (
268
-
StatusCode::UNAUTHORIZED,
269
-
Json(json!({"error": "AuthenticationRequired"})),
270
-
)
271
-
.into_response();
272
-
}
273
-
274
-
let account = input.account.trim();
275
-
let email = input.email.trim();
276
-
277
-
if account.is_empty() || email.is_empty() {
278
-
return (
279
-
StatusCode::BAD_REQUEST,
280
-
Json(json!({"error": "InvalidRequest", "message": "account and email are required"})),
281
-
)
282
-
.into_response();
283
-
}
284
-
285
-
let result = sqlx::query!("UPDATE users SET email = $1 WHERE did = $2", email, account)
286
-
.execute(&state.db)
287
-
.await;
288
-
289
-
match result {
290
-
Ok(r) => {
291
-
if r.rows_affected() == 0 {
292
-
return (
293
-
StatusCode::NOT_FOUND,
294
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
295
-
)
296
-
.into_response();
297
-
}
298
-
(StatusCode::OK, Json(json!({}))).into_response()
299
-
}
300
-
Err(e) => {
301
-
error!("DB error updating email: {:?}", e);
302
-
(
303
-
StatusCode::INTERNAL_SERVER_ERROR,
304
-
Json(json!({"error": "InternalError"})),
305
-
)
306
-
.into_response()
307
-
}
308
-
}
309
-
}
310
-
311
-
#[derive(Deserialize)]
312
-
pub struct UpdateAccountHandleInput {
313
-
pub did: String,
314
-
pub handle: String,
315
-
}
316
-
317
-
pub async fn update_account_handle(
318
-
State(state): State<AppState>,
319
-
headers: axum::http::HeaderMap,
320
-
Json(input): Json<UpdateAccountHandleInput>,
321
-
) -> Response {
322
-
let auth_header = headers.get("Authorization");
323
-
if auth_header.is_none() {
324
-
return (
325
-
StatusCode::UNAUTHORIZED,
326
-
Json(json!({"error": "AuthenticationRequired"})),
327
-
)
328
-
.into_response();
329
-
}
330
-
331
-
let did = input.did.trim();
332
-
let handle = input.handle.trim();
333
-
334
-
if did.is_empty() || handle.is_empty() {
335
-
return (
336
-
StatusCode::BAD_REQUEST,
337
-
Json(json!({"error": "InvalidRequest", "message": "did and handle are required"})),
338
-
)
339
-
.into_response();
340
-
}
341
-
342
-
if !handle
343
-
.chars()
344
-
.all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
345
-
{
346
-
return (
347
-
StatusCode::BAD_REQUEST,
348
-
Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
349
-
)
350
-
.into_response();
351
-
}
352
-
353
-
let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did)
354
-
.fetch_optional(&state.db)
355
-
.await;
356
-
357
-
if let Ok(Some(_)) = existing {
358
-
return (
359
-
StatusCode::BAD_REQUEST,
360
-
Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
361
-
)
362
-
.into_response();
363
-
}
364
-
365
-
let result = sqlx::query!("UPDATE users SET handle = $1 WHERE did = $2", handle, did)
366
-
.execute(&state.db)
367
-
.await;
368
-
369
-
match result {
370
-
Ok(r) => {
371
-
if r.rows_affected() == 0 {
372
-
return (
373
-
StatusCode::NOT_FOUND,
374
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
375
-
)
376
-
.into_response();
377
-
}
378
-
(StatusCode::OK, Json(json!({}))).into_response()
379
-
}
380
-
Err(e) => {
381
-
error!("DB error updating handle: {:?}", e);
382
-
(
383
-
StatusCode::INTERNAL_SERVER_ERROR,
384
-
Json(json!({"error": "InternalError"})),
385
-
)
386
-
.into_response()
387
-
}
388
-
}
389
-
}
390
-
391
-
#[derive(Deserialize)]
392
-
pub struct UpdateAccountPasswordInput {
393
-
pub did: String,
394
-
pub password: String,
395
-
}
396
-
397
-
pub async fn update_account_password(
398
-
State(state): State<AppState>,
399
-
headers: axum::http::HeaderMap,
400
-
Json(input): Json<UpdateAccountPasswordInput>,
401
-
) -> Response {
402
-
let auth_header = headers.get("Authorization");
403
-
if auth_header.is_none() {
404
-
return (
405
-
StatusCode::UNAUTHORIZED,
406
-
Json(json!({"error": "AuthenticationRequired"})),
407
-
)
408
-
.into_response();
409
-
}
410
-
411
-
let did = input.did.trim();
412
-
let password = input.password.trim();
413
-
414
-
if did.is_empty() || password.is_empty() {
415
-
return (
416
-
StatusCode::BAD_REQUEST,
417
-
Json(json!({"error": "InvalidRequest", "message": "did and password are required"})),
418
-
)
419
-
.into_response();
420
-
}
421
-
422
-
let password_hash = match bcrypt::hash(password, bcrypt::DEFAULT_COST) {
423
-
Ok(h) => h,
424
-
Err(e) => {
425
-
error!("Failed to hash password: {:?}", e);
426
-
return (
427
-
StatusCode::INTERNAL_SERVER_ERROR,
428
-
Json(json!({"error": "InternalError"})),
429
-
)
430
-
.into_response();
431
-
}
432
-
};
433
-
434
-
let result = sqlx::query!("UPDATE users SET password_hash = $1 WHERE did = $2", password_hash, did)
435
-
.execute(&state.db)
436
-
.await;
437
-
438
-
match result {
439
-
Ok(r) => {
440
-
if r.rows_affected() == 0 {
441
-
return (
442
-
StatusCode::NOT_FOUND,
443
-
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
444
-
)
445
-
.into_response();
446
-
}
447
-
(StatusCode::OK, Json(json!({}))).into_response()
448
-
}
449
-
Err(e) => {
450
-
error!("DB error updating password: {:?}", e);
451
-
(
452
-
StatusCode::INTERNAL_SERVER_ERROR,
453
-
Json(json!({"error": "InternalError"})),
454
-
)
455
-
.into_response()
456
-
}
457
-
}
458
-
}
459
-
460
-
#[derive(Deserialize)]
461
-
#[serde(rename_all = "camelCase")]
462
-
pub struct SendEmailInput {
463
-
pub recipient_did: String,
464
-
pub sender_did: String,
465
-
pub content: String,
466
-
pub subject: Option<String>,
467
-
pub comment: Option<String>,
468
-
}
469
-
470
-
#[derive(Serialize)]
471
-
pub struct SendEmailOutput {
472
-
pub sent: bool,
473
-
}
474
-
475
-
pub async fn send_email(
476
-
State(state): State<AppState>,
477
-
headers: axum::http::HeaderMap,
478
-
Json(input): Json<SendEmailInput>,
479
-
) -> Response {
480
-
let auth_header = headers.get("Authorization");
481
-
if auth_header.is_none() {
482
-
return (
483
-
StatusCode::UNAUTHORIZED,
484
-
Json(json!({"error": "AuthenticationRequired"})),
485
-
)
486
-
.into_response();
487
-
}
488
-
489
-
let recipient_did = input.recipient_did.trim();
490
-
let content = input.content.trim();
491
-
492
-
if recipient_did.is_empty() {
493
-
return (
494
-
StatusCode::BAD_REQUEST,
495
-
Json(json!({"error": "InvalidRequest", "message": "recipientDid is required"})),
496
-
)
497
-
.into_response();
498
-
}
499
-
500
-
if content.is_empty() {
501
-
return (
502
-
StatusCode::BAD_REQUEST,
503
-
Json(json!({"error": "InvalidRequest", "message": "content is required"})),
504
-
)
505
-
.into_response();
506
-
}
507
-
508
-
let user = sqlx::query!(
509
-
"SELECT id, email, handle FROM users WHERE did = $1",
510
-
recipient_did
511
-
)
512
-
.fetch_optional(&state.db)
513
-
.await;
514
-
515
-
let (user_id, email, handle) = match user {
516
-
Ok(Some(row)) => (row.id, row.email, row.handle),
517
-
Ok(None) => {
518
-
return (
519
-
StatusCode::NOT_FOUND,
520
-
Json(json!({"error": "AccountNotFound", "message": "Recipient account not found"})),
521
-
)
522
-
.into_response();
523
-
}
524
-
Err(e) => {
525
-
error!("DB error in send_email: {:?}", e);
526
-
return (
527
-
StatusCode::INTERNAL_SERVER_ERROR,
528
-
Json(json!({"error": "InternalError"})),
529
-
)
530
-
.into_response();
531
-
}
532
-
};
533
-
534
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
535
-
let subject = input
536
-
.subject
537
-
.clone()
538
-
.unwrap_or_else(|| format!("Message from {}", hostname));
539
-
540
-
let notification = crate::notifications::NewNotification::email(
541
-
user_id,
542
-
crate::notifications::NotificationType::AdminEmail,
543
-
email,
544
-
subject,
545
-
content.to_string(),
546
-
);
547
-
548
-
let result = crate::notifications::enqueue_notification(&state.db, notification).await;
549
-
550
-
match result {
551
-
Ok(_) => {
552
-
tracing::info!(
553
-
"Admin email queued for {} ({})",
554
-
handle,
555
-
recipient_did
556
-
);
557
-
(StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response()
558
-
}
559
-
Err(e) => {
560
-
warn!("Failed to enqueue admin email: {:?}", e);
561
-
(StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response()
562
-
}
563
-
}
564
-
}
···
+190
src/api/admin/account/delete.rs
+190
src/api/admin/account/delete.rs
···
···
1
+
use crate::state::AppState;
2
+
use axum::{
3
+
Json,
4
+
extract::State,
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
use serde::Deserialize;
9
+
use serde_json::json;
10
+
use tracing::error;
11
+
12
+
#[derive(Deserialize)]
13
+
pub struct DeleteAccountInput {
14
+
pub did: String,
15
+
}
16
+
17
+
pub async fn delete_account(
18
+
State(state): State<AppState>,
19
+
headers: axum::http::HeaderMap,
20
+
Json(input): Json<DeleteAccountInput>,
21
+
) -> Response {
22
+
let auth_header = headers.get("Authorization");
23
+
if auth_header.is_none() {
24
+
return (
25
+
StatusCode::UNAUTHORIZED,
26
+
Json(json!({"error": "AuthenticationRequired"})),
27
+
)
28
+
.into_response();
29
+
}
30
+
31
+
let did = input.did.trim();
32
+
if did.is_empty() {
33
+
return (
34
+
StatusCode::BAD_REQUEST,
35
+
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
36
+
)
37
+
.into_response();
38
+
}
39
+
40
+
let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
41
+
.fetch_optional(&state.db)
42
+
.await;
43
+
44
+
let user_id = match user {
45
+
Ok(Some(row)) => row.id,
46
+
Ok(None) => {
47
+
return (
48
+
StatusCode::NOT_FOUND,
49
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
50
+
)
51
+
.into_response();
52
+
}
53
+
Err(e) => {
54
+
error!("DB error in delete_account: {:?}", e);
55
+
return (
56
+
StatusCode::INTERNAL_SERVER_ERROR,
57
+
Json(json!({"error": "InternalError"})),
58
+
)
59
+
.into_response();
60
+
}
61
+
};
62
+
63
+
let mut tx = match state.db.begin().await {
64
+
Ok(tx) => tx,
65
+
Err(e) => {
66
+
error!("Failed to begin transaction for account deletion: {:?}", e);
67
+
return (
68
+
StatusCode::INTERNAL_SERVER_ERROR,
69
+
Json(json!({"error": "InternalError"})),
70
+
)
71
+
.into_response();
72
+
}
73
+
};
74
+
75
+
if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did)
76
+
.execute(&mut *tx)
77
+
.await
78
+
{
79
+
error!("Failed to delete session tokens for {}: {:?}", did, e);
80
+
return (
81
+
StatusCode::INTERNAL_SERVER_ERROR,
82
+
Json(json!({"error": "InternalError", "message": "Failed to delete session tokens"})),
83
+
)
84
+
.into_response();
85
+
}
86
+
87
+
if let Err(e) = sqlx::query!("DELETE FROM used_refresh_tokens WHERE session_id IN (SELECT id FROM session_tokens WHERE did = $1)", did)
88
+
.execute(&mut *tx)
89
+
.await
90
+
{
91
+
error!("Failed to delete used refresh tokens for {}: {:?}", did, e);
92
+
}
93
+
94
+
if let Err(e) = sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id)
95
+
.execute(&mut *tx)
96
+
.await
97
+
{
98
+
error!("Failed to delete records for user {}: {:?}", user_id, e);
99
+
return (
100
+
StatusCode::INTERNAL_SERVER_ERROR,
101
+
Json(json!({"error": "InternalError", "message": "Failed to delete records"})),
102
+
)
103
+
.into_response();
104
+
}
105
+
106
+
if let Err(e) = sqlx::query!("DELETE FROM repos WHERE user_id = $1", user_id)
107
+
.execute(&mut *tx)
108
+
.await
109
+
{
110
+
error!("Failed to delete repos for user {}: {:?}", user_id, e);
111
+
return (
112
+
StatusCode::INTERNAL_SERVER_ERROR,
113
+
Json(json!({"error": "InternalError", "message": "Failed to delete repos"})),
114
+
)
115
+
.into_response();
116
+
}
117
+
118
+
if let Err(e) = sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id)
119
+
.execute(&mut *tx)
120
+
.await
121
+
{
122
+
error!("Failed to delete blobs for user {}: {:?}", user_id, e);
123
+
return (
124
+
StatusCode::INTERNAL_SERVER_ERROR,
125
+
Json(json!({"error": "InternalError", "message": "Failed to delete blobs"})),
126
+
)
127
+
.into_response();
128
+
}
129
+
130
+
if let Err(e) = sqlx::query!("DELETE FROM app_passwords WHERE user_id = $1", user_id)
131
+
.execute(&mut *tx)
132
+
.await
133
+
{
134
+
error!("Failed to delete app passwords for user {}: {:?}", user_id, e);
135
+
return (
136
+
StatusCode::INTERNAL_SERVER_ERROR,
137
+
Json(json!({"error": "InternalError", "message": "Failed to delete app passwords"})),
138
+
)
139
+
.into_response();
140
+
}
141
+
142
+
if let Err(e) = sqlx::query!("DELETE FROM invite_code_uses WHERE used_by_user = $1", user_id)
143
+
.execute(&mut *tx)
144
+
.await
145
+
{
146
+
error!("Failed to delete invite code uses for user {}: {:?}", user_id, e);
147
+
}
148
+
149
+
if let Err(e) = sqlx::query!("DELETE FROM invite_codes WHERE created_by_user = $1", user_id)
150
+
.execute(&mut *tx)
151
+
.await
152
+
{
153
+
error!("Failed to delete invite codes for user {}: {:?}", user_id, e);
154
+
}
155
+
156
+
if let Err(e) = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id)
157
+
.execute(&mut *tx)
158
+
.await
159
+
{
160
+
error!("Failed to delete user keys for user {}: {:?}", user_id, e);
161
+
return (
162
+
StatusCode::INTERNAL_SERVER_ERROR,
163
+
Json(json!({"error": "InternalError", "message": "Failed to delete user keys"})),
164
+
)
165
+
.into_response();
166
+
}
167
+
168
+
if let Err(e) = sqlx::query!("DELETE FROM users WHERE id = $1", user_id)
169
+
.execute(&mut *tx)
170
+
.await
171
+
{
172
+
error!("Failed to delete user {}: {:?}", user_id, e);
173
+
return (
174
+
StatusCode::INTERNAL_SERVER_ERROR,
175
+
Json(json!({"error": "InternalError", "message": "Failed to delete user"})),
176
+
)
177
+
.into_response();
178
+
}
179
+
180
+
if let Err(e) = tx.commit().await {
181
+
error!("Failed to commit account deletion transaction: {:?}", e);
182
+
return (
183
+
StatusCode::INTERNAL_SERVER_ERROR,
184
+
Json(json!({"error": "InternalError", "message": "Failed to commit deletion"})),
185
+
)
186
+
.into_response();
187
+
}
188
+
189
+
(StatusCode::OK, Json(json!({}))).into_response()
190
+
}
+116
src/api/admin/account/email.rs
+116
src/api/admin/account/email.rs
···
···
1
+
use crate::state::AppState;
2
+
use axum::{
3
+
Json,
4
+
extract::State,
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
use serde::{Deserialize, Serialize};
9
+
use serde_json::json;
10
+
use tracing::{error, warn};
11
+
12
+
#[derive(Deserialize)]
13
+
#[serde(rename_all = "camelCase")]
14
+
pub struct SendEmailInput {
15
+
pub recipient_did: String,
16
+
pub sender_did: String,
17
+
pub content: String,
18
+
pub subject: Option<String>,
19
+
pub comment: Option<String>,
20
+
}
21
+
22
+
#[derive(Serialize)]
23
+
pub struct SendEmailOutput {
24
+
pub sent: bool,
25
+
}
26
+
27
+
pub async fn send_email(
28
+
State(state): State<AppState>,
29
+
headers: axum::http::HeaderMap,
30
+
Json(input): Json<SendEmailInput>,
31
+
) -> Response {
32
+
let auth_header = headers.get("Authorization");
33
+
if auth_header.is_none() {
34
+
return (
35
+
StatusCode::UNAUTHORIZED,
36
+
Json(json!({"error": "AuthenticationRequired"})),
37
+
)
38
+
.into_response();
39
+
}
40
+
41
+
let recipient_did = input.recipient_did.trim();
42
+
let content = input.content.trim();
43
+
44
+
if recipient_did.is_empty() {
45
+
return (
46
+
StatusCode::BAD_REQUEST,
47
+
Json(json!({"error": "InvalidRequest", "message": "recipientDid is required"})),
48
+
)
49
+
.into_response();
50
+
}
51
+
52
+
if content.is_empty() {
53
+
return (
54
+
StatusCode::BAD_REQUEST,
55
+
Json(json!({"error": "InvalidRequest", "message": "content is required"})),
56
+
)
57
+
.into_response();
58
+
}
59
+
60
+
let user = sqlx::query!(
61
+
"SELECT id, email, handle FROM users WHERE did = $1",
62
+
recipient_did
63
+
)
64
+
.fetch_optional(&state.db)
65
+
.await;
66
+
67
+
let (user_id, email, handle) = match user {
68
+
Ok(Some(row)) => (row.id, row.email, row.handle),
69
+
Ok(None) => {
70
+
return (
71
+
StatusCode::NOT_FOUND,
72
+
Json(json!({"error": "AccountNotFound", "message": "Recipient account not found"})),
73
+
)
74
+
.into_response();
75
+
}
76
+
Err(e) => {
77
+
error!("DB error in send_email: {:?}", e);
78
+
return (
79
+
StatusCode::INTERNAL_SERVER_ERROR,
80
+
Json(json!({"error": "InternalError"})),
81
+
)
82
+
.into_response();
83
+
}
84
+
};
85
+
86
+
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
87
+
let subject = input
88
+
.subject
89
+
.clone()
90
+
.unwrap_or_else(|| format!("Message from {}", hostname));
91
+
92
+
let notification = crate::notifications::NewNotification::email(
93
+
user_id,
94
+
crate::notifications::NotificationType::AdminEmail,
95
+
email,
96
+
subject,
97
+
content.to_string(),
98
+
);
99
+
100
+
let result = crate::notifications::enqueue_notification(&state.db, notification).await;
101
+
102
+
match result {
103
+
Ok(_) => {
104
+
tracing::info!(
105
+
"Admin email queued for {} ({})",
106
+
handle,
107
+
recipient_did
108
+
);
109
+
(StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response()
110
+
}
111
+
Err(e) => {
112
+
warn!("Failed to enqueue admin email: {:?}", e);
113
+
(StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response()
114
+
}
115
+
}
116
+
}
+164
src/api/admin/account/info.rs
+164
src/api/admin/account/info.rs
···
···
1
+
use crate::state::AppState;
2
+
use axum::{
3
+
Json,
4
+
extract::{Query, State},
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
use serde::{Deserialize, Serialize};
9
+
use serde_json::json;
10
+
use tracing::error;
11
+
12
+
#[derive(Deserialize)]
13
+
pub struct GetAccountInfoParams {
14
+
pub did: String,
15
+
}
16
+
17
+
#[derive(Serialize)]
18
+
#[serde(rename_all = "camelCase")]
19
+
pub struct AccountInfo {
20
+
pub did: String,
21
+
pub handle: String,
22
+
pub email: Option<String>,
23
+
pub indexed_at: String,
24
+
pub invite_note: Option<String>,
25
+
pub invites_disabled: bool,
26
+
pub email_confirmed_at: Option<String>,
27
+
pub deactivated_at: Option<String>,
28
+
}
29
+
30
+
#[derive(Serialize)]
31
+
#[serde(rename_all = "camelCase")]
32
+
pub struct GetAccountInfosOutput {
33
+
pub infos: Vec<AccountInfo>,
34
+
}
35
+
36
+
pub async fn get_account_info(
37
+
State(state): State<AppState>,
38
+
headers: axum::http::HeaderMap,
39
+
Query(params): Query<GetAccountInfoParams>,
40
+
) -> Response {
41
+
let auth_header = headers.get("Authorization");
42
+
if auth_header.is_none() {
43
+
return (
44
+
StatusCode::UNAUTHORIZED,
45
+
Json(json!({"error": "AuthenticationRequired"})),
46
+
)
47
+
.into_response();
48
+
}
49
+
50
+
let did = params.did.trim();
51
+
if did.is_empty() {
52
+
return (
53
+
StatusCode::BAD_REQUEST,
54
+
Json(json!({"error": "InvalidRequest", "message": "did is required"})),
55
+
)
56
+
.into_response();
57
+
}
58
+
59
+
let result = sqlx::query!(
60
+
r#"
61
+
SELECT did, handle, email, created_at
62
+
FROM users
63
+
WHERE did = $1
64
+
"#,
65
+
did
66
+
)
67
+
.fetch_optional(&state.db)
68
+
.await;
69
+
70
+
match result {
71
+
Ok(Some(row)) => {
72
+
(
73
+
StatusCode::OK,
74
+
Json(AccountInfo {
75
+
did: row.did,
76
+
handle: row.handle,
77
+
email: Some(row.email),
78
+
indexed_at: row.created_at.to_rfc3339(),
79
+
invite_note: None,
80
+
invites_disabled: false,
81
+
email_confirmed_at: None,
82
+
deactivated_at: None,
83
+
}),
84
+
)
85
+
.into_response()
86
+
}
87
+
Ok(None) => (
88
+
StatusCode::NOT_FOUND,
89
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
90
+
)
91
+
.into_response(),
92
+
Err(e) => {
93
+
error!("DB error in get_account_info: {:?}", e);
94
+
(
95
+
StatusCode::INTERNAL_SERVER_ERROR,
96
+
Json(json!({"error": "InternalError"})),
97
+
)
98
+
.into_response()
99
+
}
100
+
}
101
+
}
102
+
103
+
#[derive(Deserialize)]
104
+
pub struct GetAccountInfosParams {
105
+
pub dids: String,
106
+
}
107
+
108
+
pub async fn get_account_infos(
109
+
State(state): State<AppState>,
110
+
headers: axum::http::HeaderMap,
111
+
Query(params): Query<GetAccountInfosParams>,
112
+
) -> Response {
113
+
let auth_header = headers.get("Authorization");
114
+
if auth_header.is_none() {
115
+
return (
116
+
StatusCode::UNAUTHORIZED,
117
+
Json(json!({"error": "AuthenticationRequired"})),
118
+
)
119
+
.into_response();
120
+
}
121
+
122
+
let dids: Vec<&str> = params.dids.split(',').map(|s| s.trim()).collect();
123
+
if dids.is_empty() {
124
+
return (
125
+
StatusCode::BAD_REQUEST,
126
+
Json(json!({"error": "InvalidRequest", "message": "dids is required"})),
127
+
)
128
+
.into_response();
129
+
}
130
+
131
+
let mut infos = Vec::new();
132
+
133
+
for did in dids {
134
+
if did.is_empty() {
135
+
continue;
136
+
}
137
+
138
+
let result = sqlx::query!(
139
+
r#"
140
+
SELECT did, handle, email, created_at
141
+
FROM users
142
+
WHERE did = $1
143
+
"#,
144
+
did
145
+
)
146
+
.fetch_optional(&state.db)
147
+
.await;
148
+
149
+
if let Ok(Some(row)) = result {
150
+
infos.push(AccountInfo {
151
+
did: row.did,
152
+
handle: row.handle,
153
+
email: Some(row.email),
154
+
indexed_at: row.created_at.to_rfc3339(),
155
+
invite_note: None,
156
+
invites_disabled: false,
157
+
email_confirmed_at: None,
158
+
deactivated_at: None,
159
+
});
160
+
}
161
+
}
162
+
163
+
(StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response()
164
+
}
+15
src/api/admin/account/mod.rs
+15
src/api/admin/account/mod.rs
···
···
1
+
mod delete;
2
+
mod email;
3
+
mod info;
4
+
mod update;
5
+
6
+
pub use delete::{delete_account, DeleteAccountInput};
7
+
pub use email::{send_email, SendEmailInput, SendEmailOutput};
8
+
pub use info::{
9
+
get_account_info, get_account_infos, AccountInfo, GetAccountInfoParams, GetAccountInfosOutput,
10
+
GetAccountInfosParams,
11
+
};
12
+
pub use update::{
13
+
update_account_email, update_account_handle, update_account_password, UpdateAccountEmailInput,
14
+
UpdateAccountHandleInput, UpdateAccountPasswordInput,
15
+
};
+216
src/api/admin/account/update.rs
+216
src/api/admin/account/update.rs
···
···
1
+
use crate::state::AppState;
2
+
use axum::{
3
+
Json,
4
+
extract::State,
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
use serde::Deserialize;
9
+
use serde_json::json;
10
+
use tracing::error;
11
+
12
+
#[derive(Deserialize)]
13
+
pub struct UpdateAccountEmailInput {
14
+
pub account: String,
15
+
pub email: String,
16
+
}
17
+
18
+
pub async fn update_account_email(
19
+
State(state): State<AppState>,
20
+
headers: axum::http::HeaderMap,
21
+
Json(input): Json<UpdateAccountEmailInput>,
22
+
) -> Response {
23
+
let auth_header = headers.get("Authorization");
24
+
if auth_header.is_none() {
25
+
return (
26
+
StatusCode::UNAUTHORIZED,
27
+
Json(json!({"error": "AuthenticationRequired"})),
28
+
)
29
+
.into_response();
30
+
}
31
+
32
+
let account = input.account.trim();
33
+
let email = input.email.trim();
34
+
35
+
if account.is_empty() || email.is_empty() {
36
+
return (
37
+
StatusCode::BAD_REQUEST,
38
+
Json(json!({"error": "InvalidRequest", "message": "account and email are required"})),
39
+
)
40
+
.into_response();
41
+
}
42
+
43
+
let result = sqlx::query!("UPDATE users SET email = $1 WHERE did = $2", email, account)
44
+
.execute(&state.db)
45
+
.await;
46
+
47
+
match result {
48
+
Ok(r) => {
49
+
if r.rows_affected() == 0 {
50
+
return (
51
+
StatusCode::NOT_FOUND,
52
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
53
+
)
54
+
.into_response();
55
+
}
56
+
(StatusCode::OK, Json(json!({}))).into_response()
57
+
}
58
+
Err(e) => {
59
+
error!("DB error updating email: {:?}", e);
60
+
(
61
+
StatusCode::INTERNAL_SERVER_ERROR,
62
+
Json(json!({"error": "InternalError"})),
63
+
)
64
+
.into_response()
65
+
}
66
+
}
67
+
}
68
+
69
+
#[derive(Deserialize)]
70
+
pub struct UpdateAccountHandleInput {
71
+
pub did: String,
72
+
pub handle: String,
73
+
}
74
+
75
+
pub async fn update_account_handle(
76
+
State(state): State<AppState>,
77
+
headers: axum::http::HeaderMap,
78
+
Json(input): Json<UpdateAccountHandleInput>,
79
+
) -> Response {
80
+
let auth_header = headers.get("Authorization");
81
+
if auth_header.is_none() {
82
+
return (
83
+
StatusCode::UNAUTHORIZED,
84
+
Json(json!({"error": "AuthenticationRequired"})),
85
+
)
86
+
.into_response();
87
+
}
88
+
89
+
let did = input.did.trim();
90
+
let handle = input.handle.trim();
91
+
92
+
if did.is_empty() || handle.is_empty() {
93
+
return (
94
+
StatusCode::BAD_REQUEST,
95
+
Json(json!({"error": "InvalidRequest", "message": "did and handle are required"})),
96
+
)
97
+
.into_response();
98
+
}
99
+
100
+
if !handle
101
+
.chars()
102
+
.all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
103
+
{
104
+
return (
105
+
StatusCode::BAD_REQUEST,
106
+
Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
107
+
)
108
+
.into_response();
109
+
}
110
+
111
+
let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did)
112
+
.fetch_optional(&state.db)
113
+
.await;
114
+
115
+
if let Ok(Some(_)) = existing {
116
+
return (
117
+
StatusCode::BAD_REQUEST,
118
+
Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
119
+
)
120
+
.into_response();
121
+
}
122
+
123
+
let result = sqlx::query!("UPDATE users SET handle = $1 WHERE did = $2", handle, did)
124
+
.execute(&state.db)
125
+
.await;
126
+
127
+
match result {
128
+
Ok(r) => {
129
+
if r.rows_affected() == 0 {
130
+
return (
131
+
StatusCode::NOT_FOUND,
132
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
133
+
)
134
+
.into_response();
135
+
}
136
+
(StatusCode::OK, Json(json!({}))).into_response()
137
+
}
138
+
Err(e) => {
139
+
error!("DB error updating handle: {:?}", e);
140
+
(
141
+
StatusCode::INTERNAL_SERVER_ERROR,
142
+
Json(json!({"error": "InternalError"})),
143
+
)
144
+
.into_response()
145
+
}
146
+
}
147
+
}
148
+
149
+
#[derive(Deserialize)]
150
+
pub struct UpdateAccountPasswordInput {
151
+
pub did: String,
152
+
pub password: String,
153
+
}
154
+
155
+
pub async fn update_account_password(
156
+
State(state): State<AppState>,
157
+
headers: axum::http::HeaderMap,
158
+
Json(input): Json<UpdateAccountPasswordInput>,
159
+
) -> Response {
160
+
let auth_header = headers.get("Authorization");
161
+
if auth_header.is_none() {
162
+
return (
163
+
StatusCode::UNAUTHORIZED,
164
+
Json(json!({"error": "AuthenticationRequired"})),
165
+
)
166
+
.into_response();
167
+
}
168
+
169
+
let did = input.did.trim();
170
+
let password = input.password.trim();
171
+
172
+
if did.is_empty() || password.is_empty() {
173
+
return (
174
+
StatusCode::BAD_REQUEST,
175
+
Json(json!({"error": "InvalidRequest", "message": "did and password are required"})),
176
+
)
177
+
.into_response();
178
+
}
179
+
180
+
let password_hash = match bcrypt::hash(password, bcrypt::DEFAULT_COST) {
181
+
Ok(h) => h,
182
+
Err(e) => {
183
+
error!("Failed to hash password: {:?}", e);
184
+
return (
185
+
StatusCode::INTERNAL_SERVER_ERROR,
186
+
Json(json!({"error": "InternalError"})),
187
+
)
188
+
.into_response();
189
+
}
190
+
};
191
+
192
+
let result = sqlx::query!("UPDATE users SET password_hash = $1 WHERE did = $2", password_hash, did)
193
+
.execute(&state.db)
194
+
.await;
195
+
196
+
match result {
197
+
Ok(r) => {
198
+
if r.rows_affected() == 0 {
199
+
return (
200
+
StatusCode::NOT_FOUND,
201
+
Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
202
+
)
203
+
.into_response();
204
+
}
205
+
(StatusCode::OK, Json(json!({}))).into_response()
206
+
}
207
+
Err(e) => {
208
+
error!("DB error updating password: {:?}", e);
209
+
(
210
+
StatusCode::INTERNAL_SERVER_ERROR,
211
+
Json(json!({"error": "InternalError"})),
212
+
)
213
+
.into_response()
214
+
}
215
+
}
216
+
}
+1
-1
src/api/admin/invite.rs
+1
-1
src/api/admin/invite.rs
+68
-14
src/api/admin/status.rs
+68
-14
src/api/admin/status.rs
···
234
Some("com.atproto.admin.defs#repoRef") => {
235
let did = input.subject.get("did").and_then(|d| d.as_str());
236
if let Some(did) = did {
237
if let Some(takedown) = &input.takedown {
238
let takedown_ref = if takedown.apply {
239
takedown.r#ref.clone()
240
} else {
241
None
242
};
243
-
let _ = sqlx::query!(
244
"UPDATE users SET takedown_ref = $1 WHERE did = $2",
245
takedown_ref,
246
did
247
)
248
-
.execute(&state.db)
249
-
.await;
250
}
251
252
if let Some(deactivated) = &input.deactivated {
253
-
if deactivated.apply {
254
-
let _ = sqlx::query!(
255
"UPDATE users SET deactivated_at = NOW() WHERE did = $1",
256
did
257
)
258
-
.execute(&state.db)
259
-
.await;
260
} else {
261
-
let _ = sqlx::query!(
262
"UPDATE users SET deactivated_at = NULL WHERE did = $1",
263
did
264
)
265
-
.execute(&state.db)
266
-
.await;
267
}
268
}
269
270
return (
···
292
} else {
293
None
294
};
295
-
let _ = sqlx::query!(
296
"UPDATE records SET takedown_ref = $1 WHERE record_cid = $2",
297
takedown_ref,
298
uri
299
)
300
.execute(&state.db)
301
-
.await;
302
}
303
304
return (
···
323
} else {
324
None
325
};
326
-
let _ = sqlx::query!(
327
"UPDATE blobs SET takedown_ref = $1 WHERE cid = $2",
328
takedown_ref,
329
cid
330
)
331
.execute(&state.db)
332
-
.await;
333
}
334
335
return (
···
234
Some("com.atproto.admin.defs#repoRef") => {
235
let did = input.subject.get("did").and_then(|d| d.as_str());
236
if let Some(did) = did {
237
+
let mut tx = match state.db.begin().await {
238
+
Ok(tx) => tx,
239
+
Err(e) => {
240
+
error!("Failed to begin transaction: {:?}", e);
241
+
return (
242
+
StatusCode::INTERNAL_SERVER_ERROR,
243
+
Json(json!({"error": "InternalError"})),
244
+
)
245
+
.into_response();
246
+
}
247
+
};
248
+
249
if let Some(takedown) = &input.takedown {
250
let takedown_ref = if takedown.apply {
251
takedown.r#ref.clone()
252
} else {
253
None
254
};
255
+
if let Err(e) = sqlx::query!(
256
"UPDATE users SET takedown_ref = $1 WHERE did = $2",
257
takedown_ref,
258
did
259
)
260
+
.execute(&mut *tx)
261
+
.await
262
+
{
263
+
error!("Failed to update user takedown status for {}: {:?}", did, e);
264
+
return (
265
+
StatusCode::INTERNAL_SERVER_ERROR,
266
+
Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})),
267
+
)
268
+
.into_response();
269
+
}
270
}
271
272
if let Some(deactivated) = &input.deactivated {
273
+
let result = if deactivated.apply {
274
+
sqlx::query!(
275
"UPDATE users SET deactivated_at = NOW() WHERE did = $1",
276
did
277
)
278
+
.execute(&mut *tx)
279
+
.await
280
} else {
281
+
sqlx::query!(
282
"UPDATE users SET deactivated_at = NULL WHERE did = $1",
283
did
284
)
285
+
.execute(&mut *tx)
286
+
.await
287
+
};
288
+
289
+
if let Err(e) = result {
290
+
error!("Failed to update user deactivation status for {}: {:?}", did, e);
291
+
return (
292
+
StatusCode::INTERNAL_SERVER_ERROR,
293
+
Json(json!({"error": "InternalError", "message": "Failed to update deactivation status"})),
294
+
)
295
+
.into_response();
296
}
297
+
}
298
+
299
+
if let Err(e) = tx.commit().await {
300
+
error!("Failed to commit transaction: {:?}", e);
301
+
return (
302
+
StatusCode::INTERNAL_SERVER_ERROR,
303
+
Json(json!({"error": "InternalError"})),
304
+
)
305
+
.into_response();
306
}
307
308
return (
···
330
} else {
331
None
332
};
333
+
if let Err(e) = sqlx::query!(
334
"UPDATE records SET takedown_ref = $1 WHERE record_cid = $2",
335
takedown_ref,
336
uri
337
)
338
.execute(&state.db)
339
+
.await
340
+
{
341
+
error!("Failed to update record takedown status for {}: {:?}", uri, e);
342
+
return (
343
+
StatusCode::INTERNAL_SERVER_ERROR,
344
+
Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})),
345
+
)
346
+
.into_response();
347
+
}
348
}
349
350
return (
···
369
} else {
370
None
371
};
372
+
if let Err(e) = sqlx::query!(
373
"UPDATE blobs SET takedown_ref = $1 WHERE cid = $2",
374
takedown_ref,
375
cid
376
)
377
.execute(&state.db)
378
+
.await
379
+
{
380
+
error!("Failed to update blob takedown status for {}: {:?}", cid, e);
381
+
return (
382
+
StatusCode::INTERNAL_SERVER_ERROR,
383
+
Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})),
384
+
)
385
+
.into_response();
386
+
}
387
}
388
389
return (
+163
src/api/error.rs
+163
src/api/error.rs
···
···
1
+
use axum::{
2
+
Json,
3
+
http::StatusCode,
4
+
response::{IntoResponse, Response},
5
+
};
6
+
use serde::Serialize;
7
+
8
+
#[derive(Debug, Serialize)]
9
+
struct ErrorBody {
10
+
error: &'static str,
11
+
#[serde(skip_serializing_if = "Option::is_none")]
12
+
message: Option<String>,
13
+
}
14
+
15
+
#[derive(Debug)]
16
+
pub enum ApiError {
17
+
InternalError,
18
+
AuthenticationRequired,
19
+
AuthenticationFailed,
20
+
AuthenticationFailedMsg(String),
21
+
InvalidRequest(String),
22
+
InvalidToken,
23
+
ExpiredToken,
24
+
ExpiredTokenMsg(String),
25
+
TokenRequired,
26
+
AccountDeactivated,
27
+
AccountTakedown,
28
+
AccountNotFound,
29
+
RepoNotFound,
30
+
RepoNotFoundMsg(String),
31
+
RecordNotFound,
32
+
BlobNotFound,
33
+
InvalidHandle,
34
+
HandleNotAvailable,
35
+
HandleTaken,
36
+
InvalidEmail,
37
+
EmailTaken,
38
+
InvalidInviteCode,
39
+
DuplicateCreate,
40
+
DuplicateAppPassword,
41
+
AppPasswordNotFound,
42
+
InvalidSwap,
43
+
Forbidden,
44
+
InvitesDisabled,
45
+
DatabaseError,
46
+
UpstreamFailure,
47
+
}
48
+
49
+
impl ApiError {
50
+
fn status_code(&self) -> StatusCode {
51
+
match self {
52
+
Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => {
53
+
StatusCode::INTERNAL_SERVER_ERROR
54
+
}
55
+
Self::AuthenticationRequired
56
+
| Self::AuthenticationFailed
57
+
| Self::AuthenticationFailedMsg(_)
58
+
| Self::InvalidToken
59
+
| Self::ExpiredToken
60
+
| Self::ExpiredTokenMsg(_)
61
+
| Self::TokenRequired
62
+
| Self::AccountDeactivated
63
+
| Self::AccountTakedown => StatusCode::UNAUTHORIZED,
64
+
Self::Forbidden | Self::InvitesDisabled => StatusCode::FORBIDDEN,
65
+
Self::AccountNotFound
66
+
| Self::RepoNotFound
67
+
| Self::RepoNotFoundMsg(_)
68
+
| Self::RecordNotFound
69
+
| Self::BlobNotFound
70
+
| Self::AppPasswordNotFound => StatusCode::NOT_FOUND,
71
+
Self::InvalidRequest(_)
72
+
| Self::InvalidHandle
73
+
| Self::HandleNotAvailable
74
+
| Self::HandleTaken
75
+
| Self::InvalidEmail
76
+
| Self::EmailTaken
77
+
| Self::InvalidInviteCode
78
+
| Self::DuplicateCreate
79
+
| Self::DuplicateAppPassword
80
+
| Self::InvalidSwap => StatusCode::BAD_REQUEST,
81
+
}
82
+
}
83
+
84
+
fn error_name(&self) -> &'static str {
85
+
match self {
86
+
Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => "InternalError",
87
+
Self::AuthenticationRequired => "AuthenticationRequired",
88
+
Self::AuthenticationFailed | Self::AuthenticationFailedMsg(_) => "AuthenticationFailed",
89
+
Self::InvalidToken => "InvalidToken",
90
+
Self::ExpiredToken | Self::ExpiredTokenMsg(_) => "ExpiredToken",
91
+
Self::TokenRequired => "TokenRequired",
92
+
Self::AccountDeactivated => "AccountDeactivated",
93
+
Self::AccountTakedown => "AccountTakedown",
94
+
Self::Forbidden => "Forbidden",
95
+
Self::InvitesDisabled => "InvitesDisabled",
96
+
Self::AccountNotFound => "AccountNotFound",
97
+
Self::RepoNotFound | Self::RepoNotFoundMsg(_) => "RepoNotFound",
98
+
Self::RecordNotFound => "RecordNotFound",
99
+
Self::BlobNotFound => "BlobNotFound",
100
+
Self::AppPasswordNotFound => "AppPasswordNotFound",
101
+
Self::InvalidRequest(_) => "InvalidRequest",
102
+
Self::InvalidHandle => "InvalidHandle",
103
+
Self::HandleNotAvailable => "HandleNotAvailable",
104
+
Self::HandleTaken => "HandleTaken",
105
+
Self::InvalidEmail => "InvalidEmail",
106
+
Self::EmailTaken => "EmailTaken",
107
+
Self::InvalidInviteCode => "InvalidInviteCode",
108
+
Self::DuplicateCreate => "DuplicateCreate",
109
+
Self::DuplicateAppPassword => "DuplicateAppPassword",
110
+
Self::InvalidSwap => "InvalidSwap",
111
+
}
112
+
}
113
+
114
+
fn message(&self) -> Option<String> {
115
+
match self {
116
+
Self::AuthenticationFailedMsg(msg)
117
+
| Self::ExpiredTokenMsg(msg)
118
+
| Self::InvalidRequest(msg)
119
+
| Self::RepoNotFoundMsg(msg) => Some(msg.clone()),
120
+
_ => None,
121
+
}
122
+
}
123
+
}
124
+
125
+
impl IntoResponse for ApiError {
126
+
fn into_response(self) -> Response {
127
+
let body = ErrorBody {
128
+
error: self.error_name(),
129
+
message: self.message(),
130
+
};
131
+
(self.status_code(), Json(body)).into_response()
132
+
}
133
+
}
134
+
135
+
impl From<sqlx::Error> for ApiError {
136
+
fn from(e: sqlx::Error) -> Self {
137
+
tracing::error!("Database error: {:?}", e);
138
+
Self::DatabaseError
139
+
}
140
+
}
141
+
142
+
impl From<crate::auth::TokenValidationError> for ApiError {
143
+
fn from(e: crate::auth::TokenValidationError) -> Self {
144
+
match e {
145
+
crate::auth::TokenValidationError::AccountDeactivated => Self::AccountDeactivated,
146
+
crate::auth::TokenValidationError::AccountTakedown => Self::AccountTakedown,
147
+
crate::auth::TokenValidationError::KeyDecryptionFailed => Self::InternalError,
148
+
crate::auth::TokenValidationError::AuthenticationFailed => Self::AuthenticationFailed,
149
+
}
150
+
}
151
+
}
152
+
153
+
impl From<crate::util::DbLookupError> for ApiError {
154
+
fn from(e: crate::util::DbLookupError) -> Self {
155
+
match e {
156
+
crate::util::DbLookupError::NotFound => Self::AccountNotFound,
157
+
crate::util::DbLookupError::DatabaseError(db_err) => {
158
+
tracing::error!("Database error: {:?}", db_err);
159
+
Self::DatabaseError
160
+
}
161
+
}
162
+
}
163
+
}
+9
-1
src/api/identity/account.rs
+9
-1
src/api/identity/account.rs
···
40
State(state): State<AppState>,
41
Json(input): Json<CreateAccountInput>,
42
) -> Response {
43
-
info!("create_account hit: {}", input.handle);
44
if input.handle.contains('!') || input.handle.contains('@') {
45
return (
46
StatusCode::BAD_REQUEST,
47
Json(
48
json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
49
),
50
)
51
.into_response();
52
}
···
40
State(state): State<AppState>,
41
Json(input): Json<CreateAccountInput>,
42
) -> Response {
43
+
info!("create_account called");
44
if input.handle.contains('!') || input.handle.contains('@') {
45
return (
46
StatusCode::BAD_REQUEST,
47
Json(
48
json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
49
),
50
+
)
51
+
.into_response();
52
+
}
53
+
54
+
if !crate::api::validation::is_valid_email(&input.email) {
55
+
return (
56
+
StatusCode::BAD_REQUEST,
57
+
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
58
)
59
.into_response();
60
}
+36
-72
src/api/identity/did.rs
+36
-72
src/api/identity/did.rs
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
···
56
}
57
}
58
59
-
pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
60
-
let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
61
let public_key = secret_key.public_key();
62
let encoded = public_key.to_encoded_point(false);
63
-
let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
64
-
let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
65
66
-
json!({
67
"kty": "EC",
68
"crv": "secp256k1",
69
-
"x": x,
70
-
"y": y
71
-
})
72
}
73
74
pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
···
147
}
148
};
149
150
-
let jwk = get_jwk(&key_bytes);
151
152
Json(json!({
153
"@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
···
294
}
295
};
296
297
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
298
-
let did = match auth_result {
299
-
Ok(ref user) => user.did.clone(),
300
-
Err(e) => {
301
-
return (
302
-
StatusCode::UNAUTHORIZED,
303
-
Json(json!({"error": e})),
304
-
)
305
-
.into_response();
306
-
}
307
};
308
309
-
let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", did)
310
.fetch_optional(&state.db)
311
.await
312
{
313
Ok(Some(row)) => row,
314
-
_ => {
315
-
return (
316
-
StatusCode::INTERNAL_SERVER_ERROR,
317
-
Json(json!({"error": "InternalError"})),
318
-
)
319
-
.into_response();
320
-
}
321
};
322
-
let handle = user.handle;
323
324
-
let key_bytes = match auth_result.ok().and_then(|u| u.key_bytes) {
325
Some(kb) => kb,
326
-
None => {
327
-
return (
328
-
StatusCode::UNAUTHORIZED,
329
-
Json(json!({"error": "AuthenticationFailed", "message": "OAuth tokens cannot get DID credentials"})),
330
-
)
331
-
.into_response();
332
-
}
333
};
334
335
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
···
337
338
let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
339
Ok(k) => k,
340
-
Err(_) => {
341
-
return (
342
-
StatusCode::INTERNAL_SERVER_ERROR,
343
-
Json(json!({"error": "InternalError"})),
344
-
)
345
-
.into_response();
346
-
}
347
};
348
349
let public_key = secret_key.public_key();
···
360
StatusCode::OK,
361
Json(GetRecommendedDidCredentialsOutput {
362
rotation_keys: vec![did_key.clone()],
363
-
also_known_as: vec![format!("at://{}", handle)],
364
verification_methods: VerificationMethods { atproto: did_key },
365
services: Services {
366
atproto_pds: AtprotoPds {
···
387
headers.get("Authorization").and_then(|h| h.to_str().ok())
388
) {
389
Some(t) => t,
390
-
None => {
391
-
return (
392
-
StatusCode::UNAUTHORIZED,
393
-
Json(json!({"error": "AuthenticationRequired"})),
394
-
)
395
-
.into_response();
396
-
}
397
};
398
399
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
400
-
let did = match auth_result {
401
Ok(user) => user.did,
402
-
Err(e) => {
403
-
return (
404
-
StatusCode::UNAUTHORIZED,
405
-
Json(json!({"error": e})),
406
-
)
407
-
.into_response();
408
-
}
409
};
410
411
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
413
.await
414
{
415
Ok(Some(id)) => id,
416
-
_ => {
417
-
return (
418
-
StatusCode::INTERNAL_SERVER_ERROR,
419
-
Json(json!({"error": "InternalError"})),
420
-
)
421
-
.into_response();
422
-
}
423
};
424
425
let new_handle = input.handle.trim();
426
if new_handle.is_empty() {
427
-
return (
428
-
StatusCode::BAD_REQUEST,
429
-
Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
430
-
)
431
-
.into_response();
432
}
433
434
if !new_handle
···
1
+
use crate::api::ApiError;
2
use crate::state::AppState;
3
use axum::{
4
Json,
···
57
}
58
}
59
60
+
pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
61
+
let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
62
let public_key = secret_key.public_key();
63
let encoded = public_key.to_encoded_point(false);
64
+
let x = encoded.x().ok_or("Missing x coordinate")?;
65
+
let y = encoded.y().ok_or("Missing y coordinate")?;
66
+
let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
67
+
let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
68
69
+
Ok(json!({
70
"kty": "EC",
71
"crv": "secp256k1",
72
+
"x": x_b64,
73
+
"y": y_b64
74
+
}))
75
}
76
77
pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
···
150
}
151
};
152
153
+
let jwk = match get_jwk(&key_bytes) {
154
+
Ok(j) => j,
155
+
Err(e) => {
156
+
tracing::error!("Failed to generate JWK: {}", e);
157
+
return (
158
+
StatusCode::INTERNAL_SERVER_ERROR,
159
+
Json(json!({"error": "InternalError"})),
160
+
)
161
+
.into_response();
162
+
}
163
+
};
164
165
Json(json!({
166
"@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
···
307
}
308
};
309
310
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
311
+
Ok(user) => user,
312
+
Err(e) => return ApiError::from(e).into_response(),
313
};
314
315
+
let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did)
316
.fetch_optional(&state.db)
317
.await
318
{
319
Ok(Some(row)) => row,
320
+
_ => return ApiError::InternalError.into_response(),
321
};
322
323
+
let key_bytes = match auth_user.key_bytes {
324
Some(kb) => kb,
325
+
None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(),
326
};
327
328
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
···
330
331
let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
332
Ok(k) => k,
333
+
Err(_) => return ApiError::InternalError.into_response(),
334
};
335
336
let public_key = secret_key.public_key();
···
347
StatusCode::OK,
348
Json(GetRecommendedDidCredentialsOutput {
349
rotation_keys: vec![did_key.clone()],
350
+
also_known_as: vec![format!("at://{}", user.handle)],
351
verification_methods: VerificationMethods { atproto: did_key },
352
services: Services {
353
atproto_pds: AtprotoPds {
···
374
headers.get("Authorization").and_then(|h| h.to_str().ok())
375
) {
376
Some(t) => t,
377
+
None => return ApiError::AuthenticationRequired.into_response(),
378
};
379
380
+
let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
381
Ok(user) => user.did,
382
+
Err(e) => return ApiError::from(e).into_response(),
383
};
384
385
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
387
.await
388
{
389
Ok(Some(id)) => id,
390
+
_ => return ApiError::InternalError.into_response(),
391
};
392
393
let new_handle = input.handle.trim();
394
if new_handle.is_empty() {
395
+
return ApiError::InvalidRequest("handle is required".into()).into_response();
396
}
397
398
if !new_handle
-618
src/api/identity/plc.rs
-618
src/api/identity/plc.rs
···
1
-
use crate::plc::{
2
-
create_update_op, sign_operation, signing_key_to_did_key, validate_plc_operation,
3
-
PlcClient, PlcError, PlcService,
4
-
};
5
-
use crate::state::AppState;
6
-
use axum::{
7
-
extract::State,
8
-
http::StatusCode,
9
-
response::{IntoResponse, Response},
10
-
Json,
11
-
};
12
-
use chrono::{Duration, Utc};
13
-
use k256::ecdsa::SigningKey;
14
-
use rand::Rng;
15
-
use serde::{Deserialize, Serialize};
16
-
use serde_json::{json, Value};
17
-
use std::collections::HashMap;
18
-
use tracing::{error, info, warn};
19
-
20
-
fn generate_plc_token() -> String {
21
-
let mut rng = rand::thread_rng();
22
-
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
23
-
let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
24
-
let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
25
-
format!("{}-{}", part1, part2)
26
-
}
27
-
28
-
pub async fn request_plc_operation_signature(
29
-
State(state): State<AppState>,
30
-
headers: axum::http::HeaderMap,
31
-
) -> Response {
32
-
let token = match crate::auth::extract_bearer_token_from_header(
33
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
34
-
) {
35
-
Some(t) => t,
36
-
None => {
37
-
return (
38
-
StatusCode::UNAUTHORIZED,
39
-
Json(json!({"error": "AuthenticationRequired"})),
40
-
)
41
-
.into_response();
42
-
}
43
-
};
44
-
45
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
46
-
Ok(user) => user,
47
-
Err(e) => {
48
-
return (
49
-
StatusCode::UNAUTHORIZED,
50
-
Json(json!({"error": "AuthenticationFailed", "message": e})),
51
-
)
52
-
.into_response();
53
-
}
54
-
};
55
-
56
-
let did = &auth_user.did;
57
-
58
-
let user = match sqlx::query!(
59
-
"SELECT id FROM users WHERE did = $1",
60
-
did
61
-
)
62
-
.fetch_optional(&state.db)
63
-
.await
64
-
{
65
-
Ok(Some(row)) => row,
66
-
Ok(None) => {
67
-
return (
68
-
StatusCode::NOT_FOUND,
69
-
Json(json!({"error": "AccountNotFound"})),
70
-
)
71
-
.into_response();
72
-
}
73
-
Err(e) => {
74
-
error!("DB error: {:?}", e);
75
-
return (
76
-
StatusCode::INTERNAL_SERVER_ERROR,
77
-
Json(json!({"error": "InternalError"})),
78
-
)
79
-
.into_response();
80
-
}
81
-
};
82
-
83
-
let _ = sqlx::query!(
84
-
"DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()",
85
-
user.id
86
-
)
87
-
.execute(&state.db)
88
-
.await;
89
-
90
-
let plc_token = generate_plc_token();
91
-
let expires_at = Utc::now() + Duration::minutes(10);
92
-
93
-
if let Err(e) = sqlx::query!(
94
-
r#"
95
-
INSERT INTO plc_operation_tokens (user_id, token, expires_at)
96
-
VALUES ($1, $2, $3)
97
-
"#,
98
-
user.id,
99
-
plc_token,
100
-
expires_at
101
-
)
102
-
.execute(&state.db)
103
-
.await
104
-
{
105
-
error!("Failed to create PLC token: {:?}", e);
106
-
return (
107
-
StatusCode::INTERNAL_SERVER_ERROR,
108
-
Json(json!({"error": "InternalError"})),
109
-
)
110
-
.into_response();
111
-
}
112
-
113
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
114
-
115
-
if let Err(e) = crate::notifications::enqueue_plc_operation(
116
-
&state.db,
117
-
user.id,
118
-
&plc_token,
119
-
&hostname,
120
-
)
121
-
.await
122
-
{
123
-
warn!("Failed to enqueue PLC operation notification: {:?}", e);
124
-
}
125
-
126
-
info!("PLC operation signature requested for user {}", did);
127
-
128
-
(StatusCode::OK, Json(json!({}))).into_response()
129
-
}
130
-
131
-
#[derive(Debug, Deserialize)]
132
-
#[serde(rename_all = "camelCase")]
133
-
pub struct SignPlcOperationInput {
134
-
pub token: Option<String>,
135
-
pub rotation_keys: Option<Vec<String>>,
136
-
pub also_known_as: Option<Vec<String>>,
137
-
pub verification_methods: Option<HashMap<String, String>>,
138
-
pub services: Option<HashMap<String, ServiceInput>>,
139
-
}
140
-
141
-
#[derive(Debug, Deserialize, Clone)]
142
-
pub struct ServiceInput {
143
-
#[serde(rename = "type")]
144
-
pub service_type: String,
145
-
pub endpoint: String,
146
-
}
147
-
148
-
#[derive(Debug, Serialize)]
149
-
pub struct SignPlcOperationOutput {
150
-
pub operation: Value,
151
-
}
152
-
153
-
pub async fn sign_plc_operation(
154
-
State(state): State<AppState>,
155
-
headers: axum::http::HeaderMap,
156
-
Json(input): Json<SignPlcOperationInput>,
157
-
) -> Response {
158
-
let bearer = match crate::auth::extract_bearer_token_from_header(
159
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
160
-
) {
161
-
Some(t) => t,
162
-
None => {
163
-
return (
164
-
StatusCode::UNAUTHORIZED,
165
-
Json(json!({"error": "AuthenticationRequired"})),
166
-
)
167
-
.into_response();
168
-
}
169
-
};
170
-
171
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
172
-
Ok(user) => user,
173
-
Err(e) => {
174
-
return (
175
-
StatusCode::UNAUTHORIZED,
176
-
Json(json!({"error": "AuthenticationFailed", "message": e})),
177
-
)
178
-
.into_response();
179
-
}
180
-
};
181
-
182
-
let did = &auth_user.did;
183
-
184
-
let token = match &input.token {
185
-
Some(t) => t,
186
-
None => {
187
-
return (
188
-
StatusCode::BAD_REQUEST,
189
-
Json(json!({
190
-
"error": "InvalidRequest",
191
-
"message": "Email confirmation token required to sign PLC operations"
192
-
})),
193
-
)
194
-
.into_response();
195
-
}
196
-
};
197
-
198
-
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did)
199
-
.fetch_optional(&state.db)
200
-
.await
201
-
{
202
-
Ok(Some(row)) => row,
203
-
_ => {
204
-
return (
205
-
StatusCode::NOT_FOUND,
206
-
Json(json!({"error": "AccountNotFound"})),
207
-
)
208
-
.into_response();
209
-
}
210
-
};
211
-
212
-
let token_row = match sqlx::query!(
213
-
"SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2",
214
-
user.id,
215
-
token
216
-
)
217
-
.fetch_optional(&state.db)
218
-
.await
219
-
{
220
-
Ok(Some(row)) => row,
221
-
Ok(None) => {
222
-
return (
223
-
StatusCode::BAD_REQUEST,
224
-
Json(json!({
225
-
"error": "InvalidToken",
226
-
"message": "Invalid or expired token"
227
-
})),
228
-
)
229
-
.into_response();
230
-
}
231
-
Err(e) => {
232
-
error!("DB error: {:?}", e);
233
-
return (
234
-
StatusCode::INTERNAL_SERVER_ERROR,
235
-
Json(json!({"error": "InternalError"})),
236
-
)
237
-
.into_response();
238
-
}
239
-
};
240
-
241
-
if Utc::now() > token_row.expires_at {
242
-
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
243
-
.execute(&state.db)
244
-
.await;
245
-
return (
246
-
StatusCode::BAD_REQUEST,
247
-
Json(json!({
248
-
"error": "ExpiredToken",
249
-
"message": "Token has expired"
250
-
})),
251
-
)
252
-
.into_response();
253
-
}
254
-
255
-
let key_row = match sqlx::query!(
256
-
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
257
-
user.id
258
-
)
259
-
.fetch_optional(&state.db)
260
-
.await
261
-
{
262
-
Ok(Some(row)) => row,
263
-
_ => {
264
-
return (
265
-
StatusCode::INTERNAL_SERVER_ERROR,
266
-
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
267
-
)
268
-
.into_response();
269
-
}
270
-
};
271
-
272
-
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
273
-
{
274
-
Ok(k) => k,
275
-
Err(e) => {
276
-
error!("Failed to decrypt user key: {}", e);
277
-
return (
278
-
StatusCode::INTERNAL_SERVER_ERROR,
279
-
Json(json!({"error": "InternalError"})),
280
-
)
281
-
.into_response();
282
-
}
283
-
};
284
-
285
-
let signing_key = match SigningKey::from_slice(&key_bytes) {
286
-
Ok(k) => k,
287
-
Err(e) => {
288
-
error!("Failed to create signing key: {:?}", e);
289
-
return (
290
-
StatusCode::INTERNAL_SERVER_ERROR,
291
-
Json(json!({"error": "InternalError"})),
292
-
)
293
-
.into_response();
294
-
}
295
-
};
296
-
297
-
let plc_client = PlcClient::new(None);
298
-
let last_op = match plc_client.get_last_op(did).await {
299
-
Ok(op) => op,
300
-
Err(PlcError::NotFound) => {
301
-
return (
302
-
StatusCode::NOT_FOUND,
303
-
Json(json!({
304
-
"error": "NotFound",
305
-
"message": "DID not found in PLC directory"
306
-
})),
307
-
)
308
-
.into_response();
309
-
}
310
-
Err(e) => {
311
-
error!("Failed to fetch PLC operation: {:?}", e);
312
-
return (
313
-
StatusCode::BAD_GATEWAY,
314
-
Json(json!({
315
-
"error": "UpstreamError",
316
-
"message": "Failed to communicate with PLC directory"
317
-
})),
318
-
)
319
-
.into_response();
320
-
}
321
-
};
322
-
323
-
if last_op.is_tombstone() {
324
-
return (
325
-
StatusCode::BAD_REQUEST,
326
-
Json(json!({
327
-
"error": "InvalidRequest",
328
-
"message": "DID is tombstoned"
329
-
})),
330
-
)
331
-
.into_response();
332
-
}
333
-
334
-
let services = input.services.map(|s| {
335
-
s.into_iter()
336
-
.map(|(k, v)| {
337
-
(
338
-
k,
339
-
PlcService {
340
-
service_type: v.service_type,
341
-
endpoint: v.endpoint,
342
-
},
343
-
)
344
-
})
345
-
.collect()
346
-
});
347
-
348
-
let unsigned_op = match create_update_op(
349
-
&last_op,
350
-
input.rotation_keys,
351
-
input.verification_methods,
352
-
input.also_known_as,
353
-
services,
354
-
) {
355
-
Ok(op) => op,
356
-
Err(PlcError::Tombstoned) => {
357
-
return (
358
-
StatusCode::BAD_REQUEST,
359
-
Json(json!({
360
-
"error": "InvalidRequest",
361
-
"message": "Cannot update tombstoned DID"
362
-
})),
363
-
)
364
-
.into_response();
365
-
}
366
-
Err(e) => {
367
-
error!("Failed to create PLC operation: {:?}", e);
368
-
return (
369
-
StatusCode::INTERNAL_SERVER_ERROR,
370
-
Json(json!({"error": "InternalError"})),
371
-
)
372
-
.into_response();
373
-
}
374
-
};
375
-
376
-
let signed_op = match sign_operation(&unsigned_op, &signing_key) {
377
-
Ok(op) => op,
378
-
Err(e) => {
379
-
error!("Failed to sign PLC operation: {:?}", e);
380
-
return (
381
-
StatusCode::INTERNAL_SERVER_ERROR,
382
-
Json(json!({"error": "InternalError"})),
383
-
)
384
-
.into_response();
385
-
}
386
-
};
387
-
388
-
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
389
-
.execute(&state.db)
390
-
.await;
391
-
392
-
info!("Signed PLC operation for user {}", did);
393
-
394
-
(
395
-
StatusCode::OK,
396
-
Json(SignPlcOperationOutput {
397
-
operation: signed_op,
398
-
}),
399
-
)
400
-
.into_response()
401
-
}
402
-
403
-
#[derive(Debug, Deserialize)]
404
-
pub struct SubmitPlcOperationInput {
405
-
pub operation: Value,
406
-
}
407
-
408
-
pub async fn submit_plc_operation(
409
-
State(state): State<AppState>,
410
-
headers: axum::http::HeaderMap,
411
-
Json(input): Json<SubmitPlcOperationInput>,
412
-
) -> Response {
413
-
let bearer = match crate::auth::extract_bearer_token_from_header(
414
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
415
-
) {
416
-
Some(t) => t,
417
-
None => {
418
-
return (
419
-
StatusCode::UNAUTHORIZED,
420
-
Json(json!({"error": "AuthenticationRequired"})),
421
-
)
422
-
.into_response();
423
-
}
424
-
};
425
-
426
-
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
427
-
Ok(user) => user,
428
-
Err(e) => {
429
-
return (
430
-
StatusCode::UNAUTHORIZED,
431
-
Json(json!({"error": "AuthenticationFailed", "message": e})),
432
-
)
433
-
.into_response();
434
-
}
435
-
};
436
-
437
-
let did = &auth_user.did;
438
-
439
-
if let Err(e) = validate_plc_operation(&input.operation) {
440
-
return (
441
-
StatusCode::BAD_REQUEST,
442
-
Json(json!({
443
-
"error": "InvalidRequest",
444
-
"message": format!("Invalid operation: {}", e)
445
-
})),
446
-
)
447
-
.into_response();
448
-
}
449
-
450
-
let op = &input.operation;
451
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
452
-
let public_url = format!("https://{}", hostname);
453
-
454
-
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
455
-
.fetch_optional(&state.db)
456
-
.await
457
-
{
458
-
Ok(Some(row)) => row,
459
-
_ => {
460
-
return (
461
-
StatusCode::NOT_FOUND,
462
-
Json(json!({"error": "AccountNotFound"})),
463
-
)
464
-
.into_response();
465
-
}
466
-
};
467
-
468
-
let key_row = match sqlx::query!(
469
-
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
470
-
user.id
471
-
)
472
-
.fetch_optional(&state.db)
473
-
.await
474
-
{
475
-
Ok(Some(row)) => row,
476
-
_ => {
477
-
return (
478
-
StatusCode::INTERNAL_SERVER_ERROR,
479
-
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
480
-
)
481
-
.into_response();
482
-
}
483
-
};
484
-
485
-
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
486
-
{
487
-
Ok(k) => k,
488
-
Err(e) => {
489
-
error!("Failed to decrypt user key: {}", e);
490
-
return (
491
-
StatusCode::INTERNAL_SERVER_ERROR,
492
-
Json(json!({"error": "InternalError"})),
493
-
)
494
-
.into_response();
495
-
}
496
-
};
497
-
498
-
let signing_key = match SigningKey::from_slice(&key_bytes) {
499
-
Ok(k) => k,
500
-
Err(e) => {
501
-
error!("Failed to create signing key: {:?}", e);
502
-
return (
503
-
StatusCode::INTERNAL_SERVER_ERROR,
504
-
Json(json!({"error": "InternalError"})),
505
-
)
506
-
.into_response();
507
-
}
508
-
};
509
-
510
-
let user_did_key = signing_key_to_did_key(&signing_key);
511
-
512
-
if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) {
513
-
let server_rotation_key =
514
-
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
515
-
516
-
let has_server_key = rotation_keys
517
-
.iter()
518
-
.any(|k| k.as_str() == Some(&server_rotation_key));
519
-
520
-
if !has_server_key {
521
-
return (
522
-
StatusCode::BAD_REQUEST,
523
-
Json(json!({
524
-
"error": "InvalidRequest",
525
-
"message": "Rotation keys do not include server's rotation key"
526
-
})),
527
-
)
528
-
.into_response();
529
-
}
530
-
}
531
-
532
-
if let Some(services) = op.get("services").and_then(|v| v.as_object()) {
533
-
if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) {
534
-
let service_type = pds.get("type").and_then(|v| v.as_str());
535
-
let endpoint = pds.get("endpoint").and_then(|v| v.as_str());
536
-
537
-
if service_type != Some("AtprotoPersonalDataServer") {
538
-
return (
539
-
StatusCode::BAD_REQUEST,
540
-
Json(json!({
541
-
"error": "InvalidRequest",
542
-
"message": "Incorrect type on atproto_pds service"
543
-
})),
544
-
)
545
-
.into_response();
546
-
}
547
-
548
-
if endpoint != Some(&public_url) {
549
-
return (
550
-
StatusCode::BAD_REQUEST,
551
-
Json(json!({
552
-
"error": "InvalidRequest",
553
-
"message": "Incorrect endpoint on atproto_pds service"
554
-
})),
555
-
)
556
-
.into_response();
557
-
}
558
-
}
559
-
}
560
-
561
-
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) {
562
-
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
563
-
if atproto_key != user_did_key {
564
-
return (
565
-
StatusCode::BAD_REQUEST,
566
-
Json(json!({
567
-
"error": "InvalidRequest",
568
-
"message": "Incorrect signing key in verificationMethods"
569
-
})),
570
-
)
571
-
.into_response();
572
-
}
573
-
}
574
-
}
575
-
576
-
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
577
-
let expected_handle = format!("at://{}", user.handle);
578
-
let first_aka = also_known_as.first().and_then(|v| v.as_str());
579
-
580
-
if first_aka != Some(&expected_handle) {
581
-
return (
582
-
StatusCode::BAD_REQUEST,
583
-
Json(json!({
584
-
"error": "InvalidRequest",
585
-
"message": "Incorrect handle in alsoKnownAs"
586
-
})),
587
-
)
588
-
.into_response();
589
-
}
590
-
}
591
-
592
-
let plc_client = PlcClient::new(None);
593
-
if let Err(e) = plc_client.send_operation(did, &input.operation).await {
594
-
error!("Failed to submit PLC operation: {:?}", e);
595
-
return (
596
-
StatusCode::BAD_GATEWAY,
597
-
Json(json!({
598
-
"error": "UpstreamError",
599
-
"message": format!("Failed to submit to PLC directory: {}", e)
600
-
})),
601
-
)
602
-
.into_response();
603
-
}
604
-
605
-
if let Err(e) = sqlx::query!(
606
-
"INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')",
607
-
did
608
-
)
609
-
.execute(&state.db)
610
-
.await
611
-
{
612
-
warn!("Failed to sequence identity event: {:?}", e);
613
-
}
614
-
615
-
info!("Submitted PLC operation for user {}", did);
616
-
617
-
(StatusCode::OK, Json(json!({}))).into_response()
618
-
}
···
+7
src/api/identity/plc/mod.rs
+7
src/api/identity/plc/mod.rs
+91
src/api/identity/plc/request.rs
+91
src/api/identity/plc/request.rs
···
···
1
+
use crate::api::ApiError;
2
+
use crate::state::AppState;
3
+
use axum::{
4
+
extract::State,
5
+
http::StatusCode,
6
+
response::{IntoResponse, Response},
7
+
Json,
8
+
};
9
+
use chrono::{Duration, Utc};
10
+
use serde_json::json;
11
+
use tracing::{error, info, warn};
12
+
13
+
fn generate_plc_token() -> String {
14
+
crate::util::generate_token_code()
15
+
}
16
+
17
+
pub async fn request_plc_operation_signature(
18
+
State(state): State<AppState>,
19
+
headers: axum::http::HeaderMap,
20
+
) -> Response {
21
+
let token = match crate::auth::extract_bearer_token_from_header(
22
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
23
+
) {
24
+
Some(t) => t,
25
+
None => return ApiError::AuthenticationRequired.into_response(),
26
+
};
27
+
28
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
29
+
Ok(user) => user,
30
+
Err(e) => return ApiError::from(e).into_response(),
31
+
};
32
+
33
+
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did)
34
+
.fetch_optional(&state.db)
35
+
.await
36
+
{
37
+
Ok(Some(row)) => row,
38
+
Ok(None) => return ApiError::AccountNotFound.into_response(),
39
+
Err(e) => {
40
+
error!("DB error: {:?}", e);
41
+
return ApiError::InternalError.into_response();
42
+
}
43
+
};
44
+
45
+
let _ = sqlx::query!(
46
+
"DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()",
47
+
user.id
48
+
)
49
+
.execute(&state.db)
50
+
.await;
51
+
52
+
let plc_token = generate_plc_token();
53
+
let expires_at = Utc::now() + Duration::minutes(10);
54
+
55
+
if let Err(e) = sqlx::query!(
56
+
r#"
57
+
INSERT INTO plc_operation_tokens (user_id, token, expires_at)
58
+
VALUES ($1, $2, $3)
59
+
"#,
60
+
user.id,
61
+
plc_token,
62
+
expires_at
63
+
)
64
+
.execute(&state.db)
65
+
.await
66
+
{
67
+
error!("Failed to create PLC token: {:?}", e);
68
+
return (
69
+
StatusCode::INTERNAL_SERVER_ERROR,
70
+
Json(json!({"error": "InternalError"})),
71
+
)
72
+
.into_response();
73
+
}
74
+
75
+
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
76
+
77
+
if let Err(e) = crate::notifications::enqueue_plc_operation(
78
+
&state.db,
79
+
user.id,
80
+
&plc_token,
81
+
&hostname,
82
+
)
83
+
.await
84
+
{
85
+
warn!("Failed to enqueue PLC operation notification: {:?}", e);
86
+
}
87
+
88
+
info!("PLC operation signature requested for user {}", auth_user.did);
89
+
90
+
(StatusCode::OK, Json(json!({}))).into_response()
91
+
}
+272
src/api/identity/plc/sign.rs
+272
src/api/identity/plc/sign.rs
···
···
1
+
use crate::api::ApiError;
2
+
use crate::plc::{
3
+
create_update_op, sign_operation, PlcClient, PlcError, PlcService,
4
+
};
5
+
use crate::state::AppState;
6
+
use axum::{
7
+
extract::State,
8
+
http::StatusCode,
9
+
response::{IntoResponse, Response},
10
+
Json,
11
+
};
12
+
use chrono::Utc;
13
+
use k256::ecdsa::SigningKey;
14
+
use serde::{Deserialize, Serialize};
15
+
use serde_json::{json, Value};
16
+
use std::collections::HashMap;
17
+
use tracing::{error, info};
18
+
19
+
#[derive(Debug, Deserialize)]
20
+
#[serde(rename_all = "camelCase")]
21
+
pub struct SignPlcOperationInput {
22
+
pub token: Option<String>,
23
+
pub rotation_keys: Option<Vec<String>>,
24
+
pub also_known_as: Option<Vec<String>>,
25
+
pub verification_methods: Option<HashMap<String, String>>,
26
+
pub services: Option<HashMap<String, ServiceInput>>,
27
+
}
28
+
29
+
#[derive(Debug, Deserialize, Clone)]
30
+
pub struct ServiceInput {
31
+
#[serde(rename = "type")]
32
+
pub service_type: String,
33
+
pub endpoint: String,
34
+
}
35
+
36
+
#[derive(Debug, Serialize)]
37
+
pub struct SignPlcOperationOutput {
38
+
pub operation: Value,
39
+
}
40
+
41
+
pub async fn sign_plc_operation(
42
+
State(state): State<AppState>,
43
+
headers: axum::http::HeaderMap,
44
+
Json(input): Json<SignPlcOperationInput>,
45
+
) -> Response {
46
+
let bearer = match crate::auth::extract_bearer_token_from_header(
47
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
48
+
) {
49
+
Some(t) => t,
50
+
None => return ApiError::AuthenticationRequired.into_response(),
51
+
};
52
+
53
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
54
+
Ok(user) => user,
55
+
Err(e) => return ApiError::from(e).into_response(),
56
+
};
57
+
58
+
let did = &auth_user.did;
59
+
60
+
let token = match &input.token {
61
+
Some(t) => t,
62
+
None => {
63
+
return ApiError::InvalidRequest(
64
+
"Email confirmation token required to sign PLC operations".into()
65
+
).into_response();
66
+
}
67
+
};
68
+
69
+
let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did)
70
+
.fetch_optional(&state.db)
71
+
.await
72
+
{
73
+
Ok(Some(row)) => row,
74
+
_ => {
75
+
return (
76
+
StatusCode::NOT_FOUND,
77
+
Json(json!({"error": "AccountNotFound"})),
78
+
)
79
+
.into_response();
80
+
}
81
+
};
82
+
83
+
let token_row = match sqlx::query!(
84
+
"SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2",
85
+
user.id,
86
+
token
87
+
)
88
+
.fetch_optional(&state.db)
89
+
.await
90
+
{
91
+
Ok(Some(row)) => row,
92
+
Ok(None) => {
93
+
return (
94
+
StatusCode::BAD_REQUEST,
95
+
Json(json!({
96
+
"error": "InvalidToken",
97
+
"message": "Invalid or expired token"
98
+
})),
99
+
)
100
+
.into_response();
101
+
}
102
+
Err(e) => {
103
+
error!("DB error: {:?}", e);
104
+
return (
105
+
StatusCode::INTERNAL_SERVER_ERROR,
106
+
Json(json!({"error": "InternalError"})),
107
+
)
108
+
.into_response();
109
+
}
110
+
};
111
+
112
+
if Utc::now() > token_row.expires_at {
113
+
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
114
+
.execute(&state.db)
115
+
.await;
116
+
return (
117
+
StatusCode::BAD_REQUEST,
118
+
Json(json!({
119
+
"error": "ExpiredToken",
120
+
"message": "Token has expired"
121
+
})),
122
+
)
123
+
.into_response();
124
+
}
125
+
126
+
let key_row = match sqlx::query!(
127
+
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
128
+
user.id
129
+
)
130
+
.fetch_optional(&state.db)
131
+
.await
132
+
{
133
+
Ok(Some(row)) => row,
134
+
_ => {
135
+
return (
136
+
StatusCode::INTERNAL_SERVER_ERROR,
137
+
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
138
+
)
139
+
.into_response();
140
+
}
141
+
};
142
+
143
+
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
144
+
{
145
+
Ok(k) => k,
146
+
Err(e) => {
147
+
error!("Failed to decrypt user key: {}", e);
148
+
return (
149
+
StatusCode::INTERNAL_SERVER_ERROR,
150
+
Json(json!({"error": "InternalError"})),
151
+
)
152
+
.into_response();
153
+
}
154
+
};
155
+
156
+
let signing_key = match SigningKey::from_slice(&key_bytes) {
157
+
Ok(k) => k,
158
+
Err(e) => {
159
+
error!("Failed to create signing key: {:?}", e);
160
+
return (
161
+
StatusCode::INTERNAL_SERVER_ERROR,
162
+
Json(json!({"error": "InternalError"})),
163
+
)
164
+
.into_response();
165
+
}
166
+
};
167
+
168
+
let plc_client = PlcClient::new(None);
169
+
let last_op = match plc_client.get_last_op(did).await {
170
+
Ok(op) => op,
171
+
Err(PlcError::NotFound) => {
172
+
return (
173
+
StatusCode::NOT_FOUND,
174
+
Json(json!({
175
+
"error": "NotFound",
176
+
"message": "DID not found in PLC directory"
177
+
})),
178
+
)
179
+
.into_response();
180
+
}
181
+
Err(e) => {
182
+
error!("Failed to fetch PLC operation: {:?}", e);
183
+
return (
184
+
StatusCode::BAD_GATEWAY,
185
+
Json(json!({
186
+
"error": "UpstreamError",
187
+
"message": "Failed to communicate with PLC directory"
188
+
})),
189
+
)
190
+
.into_response();
191
+
}
192
+
};
193
+
194
+
if last_op.is_tombstone() {
195
+
return (
196
+
StatusCode::BAD_REQUEST,
197
+
Json(json!({
198
+
"error": "InvalidRequest",
199
+
"message": "DID is tombstoned"
200
+
})),
201
+
)
202
+
.into_response();
203
+
}
204
+
205
+
let services = input.services.map(|s| {
206
+
s.into_iter()
207
+
.map(|(k, v)| {
208
+
(
209
+
k,
210
+
PlcService {
211
+
service_type: v.service_type,
212
+
endpoint: v.endpoint,
213
+
},
214
+
)
215
+
})
216
+
.collect()
217
+
});
218
+
219
+
let unsigned_op = match create_update_op(
220
+
&last_op,
221
+
input.rotation_keys,
222
+
input.verification_methods,
223
+
input.also_known_as,
224
+
services,
225
+
) {
226
+
Ok(op) => op,
227
+
Err(PlcError::Tombstoned) => {
228
+
return (
229
+
StatusCode::BAD_REQUEST,
230
+
Json(json!({
231
+
"error": "InvalidRequest",
232
+
"message": "Cannot update tombstoned DID"
233
+
})),
234
+
)
235
+
.into_response();
236
+
}
237
+
Err(e) => {
238
+
error!("Failed to create PLC operation: {:?}", e);
239
+
return (
240
+
StatusCode::INTERNAL_SERVER_ERROR,
241
+
Json(json!({"error": "InternalError"})),
242
+
)
243
+
.into_response();
244
+
}
245
+
};
246
+
247
+
let signed_op = match sign_operation(&unsigned_op, &signing_key) {
248
+
Ok(op) => op,
249
+
Err(e) => {
250
+
error!("Failed to sign PLC operation: {:?}", e);
251
+
return (
252
+
StatusCode::INTERNAL_SERVER_ERROR,
253
+
Json(json!({"error": "InternalError"})),
254
+
)
255
+
.into_response();
256
+
}
257
+
};
258
+
259
+
let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id)
260
+
.execute(&state.db)
261
+
.await;
262
+
263
+
info!("Signed PLC operation for user {}", did);
264
+
265
+
(
266
+
StatusCode::OK,
267
+
Json(SignPlcOperationOutput {
268
+
operation: signed_op,
269
+
}),
270
+
)
271
+
.into_response()
272
+
}
+211
src/api/identity/plc/submit.rs
+211
src/api/identity/plc/submit.rs
···
···
1
+
use crate::api::ApiError;
2
+
use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient};
3
+
use crate::state::AppState;
4
+
use axum::{
5
+
extract::State,
6
+
http::StatusCode,
7
+
response::{IntoResponse, Response},
8
+
Json,
9
+
};
10
+
use k256::ecdsa::SigningKey;
11
+
use serde::Deserialize;
12
+
use serde_json::{json, Value};
13
+
use tracing::{error, info, warn};
14
+
15
+
#[derive(Debug, Deserialize)]
16
+
pub struct SubmitPlcOperationInput {
17
+
pub operation: Value,
18
+
}
19
+
20
+
pub async fn submit_plc_operation(
21
+
State(state): State<AppState>,
22
+
headers: axum::http::HeaderMap,
23
+
Json(input): Json<SubmitPlcOperationInput>,
24
+
) -> Response {
25
+
let bearer = match crate::auth::extract_bearer_token_from_header(
26
+
headers.get("Authorization").and_then(|h| h.to_str().ok()),
27
+
) {
28
+
Some(t) => t,
29
+
None => return ApiError::AuthenticationRequired.into_response(),
30
+
};
31
+
32
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await {
33
+
Ok(user) => user,
34
+
Err(e) => return ApiError::from(e).into_response(),
35
+
};
36
+
37
+
let did = &auth_user.did;
38
+
39
+
if let Err(e) = validate_plc_operation(&input.operation) {
40
+
return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response();
41
+
}
42
+
43
+
let op = &input.operation;
44
+
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
45
+
let public_url = format!("https://{}", hostname);
46
+
47
+
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
48
+
.fetch_optional(&state.db)
49
+
.await
50
+
{
51
+
Ok(Some(row)) => row,
52
+
_ => {
53
+
return (
54
+
StatusCode::NOT_FOUND,
55
+
Json(json!({"error": "AccountNotFound"})),
56
+
)
57
+
.into_response();
58
+
}
59
+
};
60
+
61
+
let key_row = match sqlx::query!(
62
+
"SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
63
+
user.id
64
+
)
65
+
.fetch_optional(&state.db)
66
+
.await
67
+
{
68
+
Ok(Some(row)) => row,
69
+
_ => {
70
+
return (
71
+
StatusCode::INTERNAL_SERVER_ERROR,
72
+
Json(json!({"error": "InternalError", "message": "User signing key not found"})),
73
+
)
74
+
.into_response();
75
+
}
76
+
};
77
+
78
+
let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
79
+
{
80
+
Ok(k) => k,
81
+
Err(e) => {
82
+
error!("Failed to decrypt user key: {}", e);
83
+
return (
84
+
StatusCode::INTERNAL_SERVER_ERROR,
85
+
Json(json!({"error": "InternalError"})),
86
+
)
87
+
.into_response();
88
+
}
89
+
};
90
+
91
+
let signing_key = match SigningKey::from_slice(&key_bytes) {
92
+
Ok(k) => k,
93
+
Err(e) => {
94
+
error!("Failed to create signing key: {:?}", e);
95
+
return (
96
+
StatusCode::INTERNAL_SERVER_ERROR,
97
+
Json(json!({"error": "InternalError"})),
98
+
)
99
+
.into_response();
100
+
}
101
+
};
102
+
103
+
let user_did_key = signing_key_to_did_key(&signing_key);
104
+
105
+
if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) {
106
+
let server_rotation_key =
107
+
std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone());
108
+
109
+
let has_server_key = rotation_keys
110
+
.iter()
111
+
.any(|k| k.as_str() == Some(&server_rotation_key));
112
+
113
+
if !has_server_key {
114
+
return (
115
+
StatusCode::BAD_REQUEST,
116
+
Json(json!({
117
+
"error": "InvalidRequest",
118
+
"message": "Rotation keys do not include server's rotation key"
119
+
})),
120
+
)
121
+
.into_response();
122
+
}
123
+
}
124
+
125
+
if let Some(services) = op.get("services").and_then(|v| v.as_object()) {
126
+
if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) {
127
+
let service_type = pds.get("type").and_then(|v| v.as_str());
128
+
let endpoint = pds.get("endpoint").and_then(|v| v.as_str());
129
+
130
+
if service_type != Some("AtprotoPersonalDataServer") {
131
+
return (
132
+
StatusCode::BAD_REQUEST,
133
+
Json(json!({
134
+
"error": "InvalidRequest",
135
+
"message": "Incorrect type on atproto_pds service"
136
+
})),
137
+
)
138
+
.into_response();
139
+
}
140
+
141
+
if endpoint != Some(&public_url) {
142
+
return (
143
+
StatusCode::BAD_REQUEST,
144
+
Json(json!({
145
+
"error": "InvalidRequest",
146
+
"message": "Incorrect endpoint on atproto_pds service"
147
+
})),
148
+
)
149
+
.into_response();
150
+
}
151
+
}
152
+
}
153
+
154
+
if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) {
155
+
if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) {
156
+
if atproto_key != user_did_key {
157
+
return (
158
+
StatusCode::BAD_REQUEST,
159
+
Json(json!({
160
+
"error": "InvalidRequest",
161
+
"message": "Incorrect signing key in verificationMethods"
162
+
})),
163
+
)
164
+
.into_response();
165
+
}
166
+
}
167
+
}
168
+
169
+
if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) {
170
+
let expected_handle = format!("at://{}", user.handle);
171
+
let first_aka = also_known_as.first().and_then(|v| v.as_str());
172
+
173
+
if first_aka != Some(&expected_handle) {
174
+
return (
175
+
StatusCode::BAD_REQUEST,
176
+
Json(json!({
177
+
"error": "InvalidRequest",
178
+
"message": "Incorrect handle in alsoKnownAs"
179
+
})),
180
+
)
181
+
.into_response();
182
+
}
183
+
}
184
+
185
+
let plc_client = PlcClient::new(None);
186
+
if let Err(e) = plc_client.send_operation(did, &input.operation).await {
187
+
error!("Failed to submit PLC operation: {:?}", e);
188
+
return (
189
+
StatusCode::BAD_GATEWAY,
190
+
Json(json!({
191
+
"error": "UpstreamError",
192
+
"message": format!("Failed to submit to PLC directory: {}", e)
193
+
})),
194
+
)
195
+
.into_response();
196
+
}
197
+
198
+
if let Err(e) = sqlx::query!(
199
+
"INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')",
200
+
did
201
+
)
202
+
.execute(&state.db)
203
+
.await
204
+
{
205
+
warn!("Failed to sequence identity event: {:?}", e);
206
+
}
207
+
208
+
info!("Submitted PLC operation for user {}", did);
209
+
210
+
(StatusCode::OK, Json(json!({}))).into_response()
211
+
}
+4
src/api/mod.rs
+4
src/api/mod.rs
+4
-16
src/api/moderation/mod.rs
+4
-16
src/api/moderation/mod.rs
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
···
37
headers.get("Authorization").and_then(|h| h.to_str().ok())
38
) {
39
Some(t) => t,
40
-
None => {
41
-
return (
42
-
StatusCode::UNAUTHORIZED,
43
-
Json(json!({"error": "AuthenticationRequired"})),
44
-
)
45
-
.into_response();
46
-
}
47
};
48
49
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
50
-
let did = match auth_result {
51
Ok(user) => user.did,
52
-
Err(e) => {
53
-
return (
54
-
StatusCode::UNAUTHORIZED,
55
-
Json(json!({"error": e})),
56
-
)
57
-
.into_response();
58
-
}
59
};
60
61
let valid_reason_types = [
···
1
+
use crate::api::ApiError;
2
use crate::state::AppState;
3
use axum::{
4
Json,
···
38
headers.get("Authorization").and_then(|h| h.to_str().ok())
39
) {
40
Some(t) => t,
41
+
None => return ApiError::AuthenticationRequired.into_response(),
42
};
43
44
+
let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
45
Ok(user) => user.did,
46
+
Err(e) => return ApiError::from(e).into_response(),
47
};
48
49
let valid_reason_types = [
+22
-2
src/api/repo/blob.rs
+22
-2
src/api/repo/blob.rs
···
15
use std::str::FromStr;
16
use tracing::error;
17
18
pub async fn upload_blob(
19
State(state): State<AppState>,
20
headers: axum::http::HeaderMap,
21
body: Bytes,
22
) -> Response {
23
let token = match crate::auth::extract_bearer_token_from_header(
24
headers.get("Authorization").and_then(|h| h.to_str().ok())
25
) {
···
57
let mut hasher = Sha256::new();
58
hasher.update(&data);
59
let hash = hasher.finalize();
60
-
let multihash = Multihash::wrap(0x12, &hash).unwrap();
61
let cid = Cid::new_v1(0x55, multihash);
62
let cid_str = cid.to_string();
63
···
207
}
208
};
209
210
-
let limit = params.limit.unwrap_or(500).min(1000);
211
let cursor_str = params.cursor.unwrap_or_default();
212
let (cursor_collection, cursor_rkey) = if cursor_str.contains('|') {
213
let parts: Vec<&str> = cursor_str.split('|').collect();
···
15
use std::str::FromStr;
16
use tracing::error;
17
18
+
const MAX_BLOB_SIZE: usize = 1_000_000;
19
+
20
pub async fn upload_blob(
21
State(state): State<AppState>,
22
headers: axum::http::HeaderMap,
23
body: Bytes,
24
) -> Response {
25
+
if body.len() > MAX_BLOB_SIZE {
26
+
return (
27
+
StatusCode::PAYLOAD_TOO_LARGE,
28
+
Json(json!({"error": "BlobTooLarge", "message": format!("Blob size {} exceeds maximum of {} bytes", body.len(), MAX_BLOB_SIZE)})),
29
+
)
30
+
.into_response();
31
+
}
32
+
33
let token = match crate::auth::extract_bearer_token_from_header(
34
headers.get("Authorization").and_then(|h| h.to_str().ok())
35
) {
···
67
let mut hasher = Sha256::new();
68
hasher.update(&data);
69
let hash = hasher.finalize();
70
+
let multihash = match Multihash::wrap(0x12, &hash) {
71
+
Ok(mh) => mh,
72
+
Err(e) => {
73
+
error!("Failed to create multihash for blob: {:?}", e);
74
+
return (
75
+
StatusCode::INTERNAL_SERVER_ERROR,
76
+
Json(json!({"error": "InternalError", "message": "Failed to hash blob"})),
77
+
)
78
+
.into_response();
79
+
}
80
+
};
81
let cid = Cid::new_v1(0x55, multihash);
82
let cid_str = cid.to_string();
83
···
227
}
228
};
229
230
+
let limit = params.limit.unwrap_or(500).clamp(1, 1000);
231
let cursor_str = params.cursor.unwrap_or_default();
232
let (cursor_collection, cursor_rkey) = if cursor_str.contains('|') {
233
let parts: Vec<&str> = cursor_str.split('|').collect();
+3
-14
src/api/repo/import.rs
+3
-14
src/api/repo/import.rs
···
1
use crate::state::AppState;
2
use crate::sync::import::{apply_import, parse_car, ImportError};
3
use crate::sync::verify::CarVerifier;
···
54
headers.get("Authorization").and_then(|h| h.to_str().ok()),
55
) {
56
Some(t) => t,
57
-
None => {
58
-
return (
59
-
StatusCode::UNAUTHORIZED,
60
-
Json(json!({"error": "AuthenticationRequired"})),
61
-
)
62
-
.into_response();
63
-
}
64
};
65
66
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
67
Ok(user) => user,
68
-
Err(e) => {
69
-
return (
70
-
StatusCode::UNAUTHORIZED,
71
-
Json(json!({"error": "AuthenticationFailed", "message": e})),
72
-
)
73
-
.into_response();
74
-
}
75
};
76
77
let did = &auth_user.did;
···
1
+
use crate::api::ApiError;
2
use crate::state::AppState;
3
use crate::sync::import::{apply_import, parse_car, ImportError};
4
use crate::sync::verify::CarVerifier;
···
55
headers.get("Authorization").and_then(|h| h.to_str().ok()),
56
) {
57
Some(t) => t,
58
+
None => return ApiError::AuthenticationRequired.into_response(),
59
};
60
61
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
62
Ok(user) => user,
63
+
Err(e) => return ApiError::from(e).into_response(),
64
};
65
66
let did = &auth_user.did;
+49
-13
src/api/repo/record/batch.rs
+49
-13
src/api/repo/record/batch.rs
···
17
use std::sync::Arc;
18
use tracing::error;
19
20
#[derive(Deserialize)]
21
#[serde(tag = "$type")]
22
pub enum WriteOp {
···
115
.into_response();
116
}
117
118
-
if input.writes.len() > 200 {
119
return (
120
StatusCode::BAD_REQUEST,
121
-
Json(json!({"error": "InvalidRequest", "message": "Too many writes (max 200)"})),
122
)
123
.into_response();
124
}
···
213
.clone()
214
.unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
215
let mut record_bytes = Vec::new();
216
-
serde_ipld_dagcbor::to_writer(&mut record_bytes, value).unwrap();
217
-
let record_cid = tracking_store.put(&record_bytes).await.unwrap();
218
219
-
let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
220
-
mst = mst.add(&key, record_cid).await.unwrap();
221
222
let uri = format!("at://{}/{}/{}", did, collection, rkey);
223
results.push(WriteResult::CreateResult {
···
236
value,
237
} => {
238
let mut record_bytes = Vec::new();
239
-
serde_ipld_dagcbor::to_writer(&mut record_bytes, value).unwrap();
240
-
let record_cid = tracking_store.put(&record_bytes).await.unwrap();
241
242
-
let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
243
-
mst = mst.update(&key, record_cid).await.unwrap();
244
245
let uri = format!("at://{}/{}/{}", did, collection, rkey);
246
results.push(WriteResult::UpdateResult {
···
254
});
255
}
256
WriteOp::Delete { collection, rkey } => {
257
-
let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
258
-
mst = mst.delete(&key).await.unwrap();
259
260
results.push(WriteResult::DeleteResult {});
261
ops.push(RecordOp::Delete {
···
266
}
267
}
268
269
-
let new_mst_root = mst.persist().await.unwrap();
270
let written_cids = tracking_store.get_written_cids();
271
let written_cids_str = written_cids
272
.iter()
···
17
use std::sync::Arc;
18
use tracing::error;
19
20
+
const MAX_BATCH_WRITES: usize = 200;
21
+
22
#[derive(Deserialize)]
23
#[serde(tag = "$type")]
24
pub enum WriteOp {
···
117
.into_response();
118
}
119
120
+
if input.writes.len() > MAX_BATCH_WRITES {
121
return (
122
StatusCode::BAD_REQUEST,
123
+
Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
124
)
125
.into_response();
126
}
···
215
.clone()
216
.unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
217
let mut record_bytes = Vec::new();
218
+
if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
219
+
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
220
+
}
221
+
let record_cid = match tracking_store.put(&record_bytes).await {
222
+
Ok(c) => c,
223
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
224
+
};
225
226
+
let collection_nsid = match collection.parse::<Nsid>() {
227
+
Ok(n) => n,
228
+
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
229
+
};
230
+
let key = format!("{}/{}", collection_nsid, rkey);
231
+
mst = match mst.add(&key, record_cid).await {
232
+
Ok(m) => m,
233
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
234
+
};
235
236
let uri = format!("at://{}/{}/{}", did, collection, rkey);
237
results.push(WriteResult::CreateResult {
···
250
value,
251
} => {
252
let mut record_bytes = Vec::new();
253
+
if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
254
+
return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
255
+
}
256
+
let record_cid = match tracking_store.put(&record_bytes).await {
257
+
Ok(c) => c,
258
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
259
+
};
260
261
+
let collection_nsid = match collection.parse::<Nsid>() {
262
+
Ok(n) => n,
263
+
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
264
+
};
265
+
let key = format!("{}/{}", collection_nsid, rkey);
266
+
mst = match mst.update(&key, record_cid).await {
267
+
Ok(m) => m,
268
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
269
+
};
270
271
let uri = format!("at://{}/{}/{}", did, collection, rkey);
272
results.push(WriteResult::UpdateResult {
···
280
});
281
}
282
WriteOp::Delete { collection, rkey } => {
283
+
let collection_nsid = match collection.parse::<Nsid>() {
284
+
Ok(n) => n,
285
+
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
286
+
};
287
+
let key = format!("{}/{}", collection_nsid, rkey);
288
+
mst = match mst.delete(&key).await {
289
+
Ok(m) => m,
290
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
291
+
};
292
293
results.push(WriteResult::DeleteResult {});
294
ops.push(RecordOp::Delete {
···
299
}
300
}
301
302
+
let new_mst_root = match mst.persist().await {
303
+
Ok(c) => c,
304
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
305
+
};
306
let written_cids = tracking_store.get_written_cids();
307
let written_cids_str = written_cids
308
.iter()
+11
-5
src/api/repo/record/utils.rs
+11
-5
src/api/repo/record/utils.rs
···
55
let new_root_cid = state.block_store.put(&new_commit_bytes).await
56
.map_err(|e| format!("Failed to save commit block: {:?}", e))?;
57
58
sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id)
59
-
.execute(&state.db)
60
.await
61
.map_err(|e| format!("DB Error (repos): {}", e))?;
62
···
71
rkey,
72
cid.to_string()
73
)
74
-
.execute(&state.db)
75
.await
76
.map_err(|e| format!("DB Error (records): {}", e))?;
77
}
···
82
collection,
83
rkey
84
)
85
-
.execute(&state.db)
86
.await
87
.map_err(|e| format!("DB Error (records): {}", e))?;
88
}
···
126
&[] as &[String],
127
blocks_cids,
128
)
129
-
.fetch_one(&state.db)
130
.await
131
.map_err(|e| format!("DB Error (repo_seq): {}", e))?;
132
133
sqlx::query(
134
&format!("NOTIFY repo_updates, '{}'", seq_row.seq)
135
)
136
-
.execute(&state.db)
137
.await
138
.map_err(|e| format!("DB Error (notify): {}", e))?;
139
140
Ok(CommitResult {
141
commit_cid: new_root_cid,
···
55
let new_root_cid = state.block_store.put(&new_commit_bytes).await
56
.map_err(|e| format!("Failed to save commit block: {:?}", e))?;
57
58
+
let mut tx = state.db.begin().await
59
+
.map_err(|e| format!("Failed to begin transaction: {}", e))?;
60
+
61
sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id)
62
+
.execute(&mut *tx)
63
.await
64
.map_err(|e| format!("DB Error (repos): {}", e))?;
65
···
74
rkey,
75
cid.to_string()
76
)
77
+
.execute(&mut *tx)
78
.await
79
.map_err(|e| format!("DB Error (records): {}", e))?;
80
}
···
85
collection,
86
rkey
87
)
88
+
.execute(&mut *tx)
89
.await
90
.map_err(|e| format!("DB Error (records): {}", e))?;
91
}
···
129
&[] as &[String],
130
blocks_cids,
131
)
132
+
.fetch_one(&mut *tx)
133
.await
134
.map_err(|e| format!("DB Error (repo_seq): {}", e))?;
135
136
sqlx::query(
137
&format!("NOTIFY repo_updates, '{}'", seq_row.seq)
138
)
139
+
.execute(&mut *tx)
140
.await
141
.map_err(|e| format!("DB Error (notify): {}", e))?;
142
+
143
+
tx.commit().await
144
+
.map_err(|e| format!("Failed to commit transaction: {}", e))?;
145
146
Ok(CommitResult {
147
commit_cid: new_root_cid,
+12
-3
src/api/repo/record/write.rs
+12
-3
src/api/repo/record/write.rs
···
294
};
295
296
let new_mst = if existing_cid.is_some() {
297
-
mst.update(&key, record_cid).await.unwrap()
298
} else {
299
-
mst.add(&key, record_cid).await.unwrap()
300
};
301
-
let new_mst_root = new_mst.persist().await.unwrap();
302
303
let op = if existing_cid.is_some() {
304
RecordOp::Update { collection: input.collection.clone(), rkey: input.rkey.clone(), cid: record_cid }
···
294
};
295
296
let new_mst = if existing_cid.is_some() {
297
+
match mst.update(&key, record_cid).await {
298
+
Ok(m) => m,
299
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
300
+
}
301
} else {
302
+
match mst.add(&key, record_cid).await {
303
+
Ok(m) => m,
304
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
305
+
}
306
+
};
307
+
let new_mst_root = match new_mst.persist().await {
308
+
Ok(c) => c,
309
+
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
310
};
311
312
let op = if existing_cid.is_some() {
313
RecordOp::Update { collection: input.collection.clone(), rkey: input.rkey.clone(), cid: record_cid }
+13
-64
src/api/server/account_status.rs
+13
-64
src/api/server/account_status.rs
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
···
34
headers.get("Authorization").and_then(|h| h.to_str().ok())
35
) {
36
Some(t) => t,
37
-
None => {
38
-
return (
39
-
StatusCode::UNAUTHORIZED,
40
-
Json(json!({"error": "AuthenticationRequired"})),
41
-
)
42
-
.into_response();
43
-
}
44
};
45
46
-
let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await;
47
-
let did = match auth_result {
48
Ok(user) => user.did,
49
-
Err(e) => {
50
-
return (
51
-
StatusCode::UNAUTHORIZED,
52
-
Json(json!({"error": e})),
53
-
)
54
-
.into_response();
55
-
}
56
};
57
58
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
127
headers.get("Authorization").and_then(|h| h.to_str().ok())
128
) {
129
Some(t) => t,
130
-
None => {
131
-
return (
132
-
StatusCode::UNAUTHORIZED,
133
-
Json(json!({"error": "AuthenticationRequired"})),
134
-
)
135
-
.into_response();
136
-
}
137
};
138
139
-
let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await;
140
-
let did = match auth_result {
141
Ok(user) => user.did,
142
-
Err(e) => {
143
-
return (
144
-
StatusCode::UNAUTHORIZED,
145
-
Json(json!({"error": e})),
146
-
)
147
-
.into_response();
148
-
}
149
};
150
151
let result = sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did)
···
180
headers.get("Authorization").and_then(|h| h.to_str().ok())
181
) {
182
Some(t) => t,
183
-
None => {
184
-
return (
185
-
StatusCode::UNAUTHORIZED,
186
-
Json(json!({"error": "AuthenticationRequired"})),
187
-
)
188
-
.into_response();
189
-
}
190
};
191
192
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
193
-
let did = match auth_result {
194
Ok(user) => user.did,
195
-
Err(e) => {
196
-
return (
197
-
StatusCode::UNAUTHORIZED,
198
-
Json(json!({"error": e})),
199
-
)
200
-
.into_response();
201
-
}
202
};
203
204
let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did)
···
226
headers.get("Authorization").and_then(|h| h.to_str().ok())
227
) {
228
Some(t) => t,
229
-
None => {
230
-
return (
231
-
StatusCode::UNAUTHORIZED,
232
-
Json(json!({"error": "AuthenticationRequired"})),
233
-
)
234
-
.into_response();
235
-
}
236
};
237
238
-
let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await;
239
-
let did = match auth_result {
240
Ok(user) => user.did,
241
-
Err(e) => {
242
-
return (
243
-
StatusCode::UNAUTHORIZED,
244
-
Json(json!({"error": e})),
245
-
)
246
-
.into_response();
247
-
}
248
};
249
250
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
1
+
use crate::api::ApiError;
2
use crate::state::AppState;
3
use axum::{
4
Json,
···
35
headers.get("Authorization").and_then(|h| h.to_str().ok())
36
) {
37
Some(t) => t,
38
+
None => return ApiError::AuthenticationRequired.into_response(),
39
};
40
41
+
let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
42
Ok(user) => user.did,
43
+
Err(e) => return ApiError::from(e).into_response(),
44
};
45
46
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
···
115
headers.get("Authorization").and_then(|h| h.to_str().ok())
116
) {
117
Some(t) => t,
118
+
None => return ApiError::AuthenticationRequired.into_response(),
119
};
120
121
+
let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
122
Ok(user) => user.did,
123
+
Err(e) => return ApiError::from(e).into_response(),
124
};
125
126
let result = sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did)
···
155
headers.get("Authorization").and_then(|h| h.to_str().ok())
156
) {
157
Some(t) => t,
158
+
None => return ApiError::AuthenticationRequired.into_response(),
159
};
160
161
+
let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
162
Ok(user) => user.did,
163
+
Err(e) => return ApiError::from(e).into_response(),
164
};
165
166
let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did)
···
188
headers.get("Authorization").and_then(|h| h.to_str().ok())
189
) {
190
Some(t) => t,
191
+
None => return ApiError::AuthenticationRequired.into_response(),
192
};
193
194
+
let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
195
Ok(user) => user.did,
196
+
Err(e) => return ApiError::from(e).into_response(),
197
};
198
199
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
+63
-190
src/api/server/app_password.rs
+63
-190
src/api/server/app_password.rs
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
4
extract::State,
5
-
http::StatusCode,
6
response::{IntoResponse, Response},
7
};
8
use serde::{Deserialize, Serialize};
···
24
25
pub async fn list_app_passwords(
26
State(state): State<AppState>,
27
-
headers: axum::http::HeaderMap,
28
) -> Response {
29
-
let token = match crate::auth::extract_bearer_token_from_header(
30
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
31
-
) {
32
-
Some(t) => t,
33
-
None => {
34
-
return (
35
-
StatusCode::UNAUTHORIZED,
36
-
Json(json!({"error": "AuthenticationRequired"})),
37
-
)
38
-
.into_response();
39
-
}
40
-
};
41
-
42
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
43
-
let did = match auth_result {
44
-
Ok(user) => user.did,
45
-
Err(e) => {
46
-
return (
47
-
StatusCode::UNAUTHORIZED,
48
-
Json(json!({"error": e})),
49
-
)
50
-
.into_response();
51
-
}
52
};
53
54
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
55
-
.fetch_optional(&state.db)
56
-
.await
57
{
58
-
Ok(Some(id)) => id,
59
-
_ => {
60
-
return (
61
-
StatusCode::INTERNAL_SERVER_ERROR,
62
-
Json(json!({"error": "InternalError"})),
63
-
)
64
-
.into_response();
65
-
}
66
-
};
67
-
68
-
let result = sqlx::query!("SELECT name, created_at, privileged FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC", user_id)
69
-
.fetch_all(&state.db)
70
-
.await;
71
-
72
-
match result {
73
Ok(rows) => {
74
let passwords: Vec<AppPassword> = rows
75
.iter()
76
-
.map(|row| {
77
-
AppPassword {
78
-
name: row.name.clone(),
79
-
created_at: row.created_at.to_rfc3339(),
80
-
privileged: row.privileged,
81
-
}
82
})
83
.collect();
84
85
-
(StatusCode::OK, Json(ListAppPasswordsOutput { passwords })).into_response()
86
}
87
Err(e) => {
88
error!("DB error listing app passwords: {:?}", e);
89
-
(
90
-
StatusCode::INTERNAL_SERVER_ERROR,
91
-
Json(json!({"error": "InternalError"})),
92
-
)
93
-
.into_response()
94
}
95
}
96
}
···
112
113
pub async fn create_app_password(
114
State(state): State<AppState>,
115
-
headers: axum::http::HeaderMap,
116
Json(input): Json<CreateAppPasswordInput>,
117
) -> Response {
118
-
let token = match crate::auth::extract_bearer_token_from_header(
119
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
120
-
) {
121
-
Some(t) => t,
122
-
None => {
123
-
return (
124
-
StatusCode::UNAUTHORIZED,
125
-
Json(json!({"error": "AuthenticationRequired"})),
126
-
)
127
-
.into_response();
128
-
}
129
-
};
130
-
131
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
132
-
let did = match auth_result {
133
-
Ok(user) => user.did,
134
-
Err(e) => {
135
-
return (
136
-
StatusCode::UNAUTHORIZED,
137
-
Json(json!({"error": e})),
138
-
)
139
-
.into_response();
140
-
}
141
-
};
142
-
143
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
144
-
.fetch_optional(&state.db)
145
-
.await
146
-
{
147
-
Ok(Some(id)) => id,
148
-
_ => {
149
-
return (
150
-
StatusCode::INTERNAL_SERVER_ERROR,
151
-
Json(json!({"error": "InternalError"})),
152
-
)
153
-
.into_response();
154
-
}
155
};
156
157
let name = input.name.trim();
158
if name.is_empty() {
159
-
return (
160
-
StatusCode::BAD_REQUEST,
161
-
Json(json!({"error": "InvalidRequest", "message": "name is required"})),
162
-
)
163
-
.into_response();
164
}
165
166
-
let existing = sqlx::query!("SELECT id FROM app_passwords WHERE user_id = $1 AND name = $2", user_id, name)
167
-
.fetch_optional(&state.db)
168
-
.await;
169
170
if let Ok(Some(_)) = existing {
171
-
return (
172
-
StatusCode::BAD_REQUEST,
173
-
Json(json!({"error": "DuplicateAppPassword", "message": "App password with this name already exists"})),
174
-
)
175
-
.into_response();
176
}
177
178
let password: String = (0..4)
···
180
use rand::Rng;
181
let mut rng = rand::thread_rng();
182
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
183
-
(0..4).map(|_| chars[rng.gen_range(0..chars.len())]).collect::<String>()
184
})
185
.collect::<Vec<String>>()
186
.join("-");
···
189
Ok(h) => h,
190
Err(e) => {
191
error!("Failed to hash password: {:?}", e);
192
-
return (
193
-
StatusCode::INTERNAL_SERVER_ERROR,
194
-
Json(json!({"error": "InternalError"})),
195
-
)
196
-
.into_response();
197
}
198
};
199
200
let privileged = input.privileged.unwrap_or(false);
201
let created_at = chrono::Utc::now();
202
203
-
let result = sqlx::query!(
204
"INSERT INTO app_passwords (user_id, name, password_hash, created_at, privileged) VALUES ($1, $2, $3, $4, $5)",
205
user_id,
206
name,
···
209
privileged
210
)
211
.execute(&state.db)
212
-
.await;
213
-
214
-
match result {
215
-
Ok(_) => (
216
-
StatusCode::OK,
217
-
Json(CreateAppPasswordOutput {
218
-
name: name.to_string(),
219
-
password,
220
-
created_at: created_at.to_rfc3339(),
221
-
privileged,
222
-
}),
223
-
)
224
-
.into_response(),
225
Err(e) => {
226
error!("DB error creating app password: {:?}", e);
227
-
(
228
-
StatusCode::INTERNAL_SERVER_ERROR,
229
-
Json(json!({"error": "InternalError"})),
230
-
)
231
-
.into_response()
232
}
233
}
234
}
···
240
241
pub async fn revoke_app_password(
242
State(state): State<AppState>,
243
-
headers: axum::http::HeaderMap,
244
Json(input): Json<RevokeAppPasswordInput>,
245
) -> Response {
246
-
let token = match crate::auth::extract_bearer_token_from_header(
247
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
248
-
) {
249
-
Some(t) => t,
250
-
None => {
251
-
return (
252
-
StatusCode::UNAUTHORIZED,
253
-
Json(json!({"error": "AuthenticationRequired"})),
254
-
)
255
-
.into_response();
256
-
}
257
-
};
258
-
259
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
260
-
let did = match auth_result {
261
-
Ok(user) => user.did,
262
-
Err(e) => {
263
-
return (
264
-
StatusCode::UNAUTHORIZED,
265
-
Json(json!({"error": e})),
266
-
)
267
-
.into_response();
268
-
}
269
-
};
270
-
271
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
272
-
.fetch_optional(&state.db)
273
-
.await
274
-
{
275
-
Ok(Some(id)) => id,
276
-
_ => {
277
-
return (
278
-
StatusCode::INTERNAL_SERVER_ERROR,
279
-
Json(json!({"error": "InternalError"})),
280
-
)
281
-
.into_response();
282
-
}
283
};
284
285
let name = input.name.trim();
286
if name.is_empty() {
287
-
return (
288
-
StatusCode::BAD_REQUEST,
289
-
Json(json!({"error": "InvalidRequest", "message": "name is required"})),
290
-
)
291
-
.into_response();
292
}
293
294
-
let result = sqlx::query!("DELETE FROM app_passwords WHERE user_id = $1 AND name = $2", user_id, name)
295
-
.execute(&state.db)
296
-
.await;
297
-
298
-
match result {
299
Ok(r) => {
300
if r.rows_affected() == 0 {
301
-
return (
302
-
StatusCode::NOT_FOUND,
303
-
Json(json!({"error": "AppPasswordNotFound", "message": "App password not found"})),
304
-
)
305
-
.into_response();
306
}
307
-
(StatusCode::OK, Json(json!({}))).into_response()
308
}
309
Err(e) => {
310
error!("DB error revoking app password: {:?}", e);
311
-
(
312
-
StatusCode::INTERNAL_SERVER_ERROR,
313
-
Json(json!({"error": "InternalError"})),
314
-
)
315
-
.into_response()
316
}
317
}
318
}
···
1
+
use crate::api::ApiError;
2
+
use crate::auth::BearerAuth;
3
use crate::state::AppState;
4
+
use crate::util::get_user_id_by_did;
5
use axum::{
6
Json,
7
extract::State,
8
response::{IntoResponse, Response},
9
};
10
use serde::{Deserialize, Serialize};
···
26
27
pub async fn list_app_passwords(
28
State(state): State<AppState>,
29
+
BearerAuth(auth_user): BearerAuth,
30
) -> Response {
31
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
32
+
Ok(id) => id,
33
+
Err(e) => return ApiError::from(e).into_response(),
34
};
35
36
+
match sqlx::query!(
37
+
"SELECT name, created_at, privileged FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC",
38
+
user_id
39
+
)
40
+
.fetch_all(&state.db)
41
+
.await
42
{
43
Ok(rows) => {
44
let passwords: Vec<AppPassword> = rows
45
.iter()
46
+
.map(|row| AppPassword {
47
+
name: row.name.clone(),
48
+
created_at: row.created_at.to_rfc3339(),
49
+
privileged: row.privileged,
50
})
51
.collect();
52
53
+
Json(ListAppPasswordsOutput { passwords }).into_response()
54
}
55
Err(e) => {
56
error!("DB error listing app passwords: {:?}", e);
57
+
ApiError::InternalError.into_response()
58
}
59
}
60
}
···
76
77
pub async fn create_app_password(
78
State(state): State<AppState>,
79
+
BearerAuth(auth_user): BearerAuth,
80
Json(input): Json<CreateAppPasswordInput>,
81
) -> Response {
82
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
83
+
Ok(id) => id,
84
+
Err(e) => return ApiError::from(e).into_response(),
85
};
86
87
let name = input.name.trim();
88
if name.is_empty() {
89
+
return ApiError::InvalidRequest("name is required".into()).into_response();
90
}
91
92
+
let existing = sqlx::query!(
93
+
"SELECT id FROM app_passwords WHERE user_id = $1 AND name = $2",
94
+
user_id,
95
+
name
96
+
)
97
+
.fetch_optional(&state.db)
98
+
.await;
99
100
if let Ok(Some(_)) = existing {
101
+
return ApiError::DuplicateAppPassword.into_response();
102
}
103
104
let password: String = (0..4)
···
106
use rand::Rng;
107
let mut rng = rand::thread_rng();
108
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
109
+
(0..4)
110
+
.map(|_| chars[rng.gen_range(0..chars.len())])
111
+
.collect::<String>()
112
})
113
.collect::<Vec<String>>()
114
.join("-");
···
117
Ok(h) => h,
118
Err(e) => {
119
error!("Failed to hash password: {:?}", e);
120
+
return ApiError::InternalError.into_response();
121
}
122
};
123
124
let privileged = input.privileged.unwrap_or(false);
125
let created_at = chrono::Utc::now();
126
127
+
match sqlx::query!(
128
"INSERT INTO app_passwords (user_id, name, password_hash, created_at, privileged) VALUES ($1, $2, $3, $4, $5)",
129
user_id,
130
name,
···
133
privileged
134
)
135
.execute(&state.db)
136
+
.await
137
+
{
138
+
Ok(_) => Json(CreateAppPasswordOutput {
139
+
name: name.to_string(),
140
+
password,
141
+
created_at: created_at.to_rfc3339(),
142
+
privileged,
143
+
})
144
+
.into_response(),
145
Err(e) => {
146
error!("DB error creating app password: {:?}", e);
147
+
ApiError::InternalError.into_response()
148
}
149
}
150
}
···
156
157
pub async fn revoke_app_password(
158
State(state): State<AppState>,
159
+
BearerAuth(auth_user): BearerAuth,
160
Json(input): Json<RevokeAppPasswordInput>,
161
) -> Response {
162
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
163
+
Ok(id) => id,
164
+
Err(e) => return ApiError::from(e).into_response(),
165
};
166
167
let name = input.name.trim();
168
if name.is_empty() {
169
+
return ApiError::InvalidRequest("name is required".into()).into_response();
170
}
171
172
+
match sqlx::query!(
173
+
"DELETE FROM app_passwords WHERE user_id = $1 AND name = $2",
174
+
user_id,
175
+
name
176
+
)
177
+
.execute(&state.db)
178
+
.await
179
+
{
180
Ok(r) => {
181
if r.rows_affected() == 0 {
182
+
return ApiError::AppPasswordNotFound.into_response();
183
}
184
+
Json(json!({})).into_response()
185
}
186
Err(e) => {
187
error!("DB error revoking app password: {:?}", e);
188
+
ApiError::InternalError.into_response()
189
}
190
}
191
}
+47
-54
src/api/server/email.rs
+47
-54
src/api/server/email.rs
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
···
6
response::{IntoResponse, Response},
7
};
8
use chrono::{Duration, Utc};
9
-
use rand::Rng;
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info, warn};
13
14
fn generate_confirmation_code() -> String {
15
-
let mut rng = rand::thread_rng();
16
-
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
17
-
let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
18
-
let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
19
-
format!("{}-{}", part1, part2)
20
}
21
22
#[derive(Deserialize)]
···
46
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
47
let did = match auth_result {
48
Ok(user) => user.did,
49
-
Err(e) => {
50
-
return (
51
-
StatusCode::UNAUTHORIZED,
52
-
Json(json!({"error": e})),
53
-
)
54
-
.into_response();
55
-
}
56
};
57
58
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
···
72
let handle = user.handle;
73
74
let email = input.email.trim().to_lowercase();
75
-
if email.is_empty() {
76
return (
77
StatusCode::BAD_REQUEST,
78
-
Json(json!({"error": "InvalidRequest", "message": "email is required"})),
79
)
80
.into_response();
81
}
···
161
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
162
let did = match auth_result {
163
Ok(user) => user.did,
164
-
Err(e) => {
165
-
return (
166
-
StatusCode::UNAUTHORIZED,
167
-
Json(json!({"error": e})),
168
-
)
169
-
.into_response();
170
-
}
171
};
172
173
let user = match sqlx::query!(
···
194
let email = input.email.trim().to_lowercase();
195
let confirmation_code = input.token.trim();
196
197
-
if email_pending_verification.is_none() || stored_code.is_none() || expires_at.is_none() {
198
-
return (
199
-
StatusCode::BAD_REQUEST,
200
-
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
201
-
)
202
-
.into_response();
203
-
}
204
205
-
let email_pending_verification = email_pending_verification.unwrap();
206
-
if email_pending_verification != email {
207
return (
208
StatusCode::BAD_REQUEST,
209
Json(json!({"error": "InvalidRequest", "message": "Email does not match pending update"})),
···
211
.into_response();
212
}
213
214
-
if stored_code.unwrap() != confirmation_code {
215
return (
216
StatusCode::BAD_REQUEST,
217
Json(json!({"error": "InvalidToken", "message": "Invalid token"})),
···
219
.into_response();
220
}
221
222
-
if Utc::now() > expires_at.unwrap() {
223
return (
224
StatusCode::BAD_REQUEST,
225
Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
···
229
230
let update = sqlx::query!(
231
"UPDATE users SET email = $1, email_pending_verification = NULL, email_confirmation_code = NULL, email_confirmation_code_expires_at = NULL WHERE id = $2",
232
-
email_pending_verification,
233
user_id
234
)
235
.execute(&state.db)
···
287
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
288
let did = match auth_result {
289
Ok(user) => user.did,
290
-
Err(e) => {
291
-
return (
292
-
StatusCode::UNAUTHORIZED,
293
-
Json(json!({"error": e})),
294
-
)
295
-
.into_response();
296
-
}
297
};
298
299
let user = match sqlx::query!(
···
319
let email_pending_verification = user.email_pending_verification;
320
321
let new_email = input.email.trim().to_lowercase();
322
-
if new_email.is_empty() {
323
return (
324
StatusCode::BAD_REQUEST,
325
-
Json(json!({"error": "InvalidRequest", "message": "email is required"})),
326
-
)
327
-
.into_response();
328
-
}
329
-
330
-
if !new_email.contains('@') || !new_email.contains('.') {
331
-
return (
332
-
StatusCode::BAD_REQUEST,
333
-
Json(json!({"error": "InvalidRequest", "message": "Invalid email format"})),
334
)
335
.into_response();
336
}
···
353
}
354
};
355
356
-
let pending_email = email_pending_verification.unwrap();
357
if pending_email.to_lowercase() != new_email {
358
return (
359
StatusCode::BAD_REQUEST,
···
362
.into_response();
363
}
364
365
-
if stored_code.unwrap() != confirmation_token {
366
return (
367
StatusCode::BAD_REQUEST,
368
Json(json!({"error": "InvalidToken", "message": "Invalid token"})),
···
415
416
match update {
417
Ok(_) => {
418
-
info!("Email updated to {} for user {}", new_email, user_id);
419
(StatusCode::OK, Json(json!({}))).into_response()
420
}
421
Err(e) => {
···
1
+
use crate::api::ApiError;
2
use crate::state::AppState;
3
use axum::{
4
Json,
···
7
response::{IntoResponse, Response},
8
};
9
use chrono::{Duration, Utc};
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info, warn};
13
14
fn generate_confirmation_code() -> String {
15
+
crate::util::generate_token_code()
16
}
17
18
#[derive(Deserialize)]
···
42
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
43
let did = match auth_result {
44
Ok(user) => user.did,
45
+
Err(e) => return ApiError::from(e).into_response(),
46
};
47
48
let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did)
···
62
let handle = user.handle;
63
64
let email = input.email.trim().to_lowercase();
65
+
if !crate::api::validation::is_valid_email(&email) {
66
return (
67
StatusCode::BAD_REQUEST,
68
+
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
69
)
70
.into_response();
71
}
···
151
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
152
let did = match auth_result {
153
Ok(user) => user.did,
154
+
Err(e) => return ApiError::from(e).into_response(),
155
};
156
157
let user = match sqlx::query!(
···
178
let email = input.email.trim().to_lowercase();
179
let confirmation_code = input.token.trim();
180
181
+
let (pending_email, saved_code, expiry) = match (email_pending_verification, stored_code, expires_at) {
182
+
(Some(p), Some(c), Some(e)) => (p, c, e),
183
+
_ => {
184
+
return (
185
+
StatusCode::BAD_REQUEST,
186
+
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
187
+
)
188
+
.into_response();
189
+
}
190
+
};
191
192
+
if pending_email != email {
193
return (
194
StatusCode::BAD_REQUEST,
195
Json(json!({"error": "InvalidRequest", "message": "Email does not match pending update"})),
···
197
.into_response();
198
}
199
200
+
if saved_code != confirmation_code {
201
return (
202
StatusCode::BAD_REQUEST,
203
Json(json!({"error": "InvalidToken", "message": "Invalid token"})),
···
205
.into_response();
206
}
207
208
+
if Utc::now() > expiry {
209
return (
210
StatusCode::BAD_REQUEST,
211
Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
···
215
216
let update = sqlx::query!(
217
"UPDATE users SET email = $1, email_pending_verification = NULL, email_confirmation_code = NULL, email_confirmation_code_expires_at = NULL WHERE id = $2",
218
+
pending_email,
219
user_id
220
)
221
.execute(&state.db)
···
273
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
274
let did = match auth_result {
275
Ok(user) => user.did,
276
+
Err(e) => return ApiError::from(e).into_response(),
277
};
278
279
let user = match sqlx::query!(
···
299
let email_pending_verification = user.email_pending_verification;
300
301
let new_email = input.email.trim().to_lowercase();
302
+
if !crate::api::validation::is_valid_email(&new_email) {
303
return (
304
StatusCode::BAD_REQUEST,
305
+
Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})),
306
)
307
.into_response();
308
}
···
325
}
326
};
327
328
+
let pending_email = match email_pending_verification {
329
+
Some(p) => p,
330
+
None => {
331
+
return (
332
+
StatusCode::BAD_REQUEST,
333
+
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
334
+
)
335
+
.into_response();
336
+
}
337
+
};
338
+
339
if pending_email.to_lowercase() != new_email {
340
return (
341
StatusCode::BAD_REQUEST,
···
344
.into_response();
345
}
346
347
+
let saved_code = match stored_code {
348
+
Some(c) => c,
349
+
None => {
350
+
return (
351
+
StatusCode::BAD_REQUEST,
352
+
Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})),
353
+
)
354
+
.into_response();
355
+
}
356
+
};
357
+
358
+
if saved_code != confirmation_token {
359
return (
360
StatusCode::BAD_REQUEST,
361
Json(json!({"error": "InvalidToken", "message": "Invalid token"})),
···
408
409
match update {
410
Ok(_) => {
411
+
info!("Email updated for user {}", user_id);
412
(StatusCode::OK, Json(json!({}))).into_response()
413
}
414
Err(e) => {
+61
-209
src/api/server/invite.rs
+61
-209
src/api/server/invite.rs
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
4
extract::State,
5
-
http::StatusCode,
6
response::{IntoResponse, Response},
7
};
8
use serde::{Deserialize, Serialize};
9
-
use serde_json::json;
10
use tracing::error;
11
use uuid::Uuid;
12
···
24
25
pub async fn create_invite_code(
26
State(state): State<AppState>,
27
-
headers: axum::http::HeaderMap,
28
Json(input): Json<CreateInviteCodeInput>,
29
) -> Response {
30
-
let token = match crate::auth::extract_bearer_token_from_header(
31
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
32
-
) {
33
-
Some(t) => t,
34
-
None => {
35
-
return (
36
-
StatusCode::UNAUTHORIZED,
37
-
Json(json!({"error": "AuthenticationRequired"})),
38
-
)
39
-
.into_response();
40
-
}
41
-
};
42
-
43
if input.use_count < 1 {
44
-
return (
45
-
StatusCode::BAD_REQUEST,
46
-
Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
47
-
)
48
-
.into_response();
49
}
50
51
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
52
-
let did = match auth_result {
53
-
Ok(user) => user.did,
54
-
Err(e) => {
55
-
return (
56
-
StatusCode::UNAUTHORIZED,
57
-
Json(json!({"error": e})),
58
-
)
59
-
.into_response();
60
-
}
61
-
};
62
-
63
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
64
-
.fetch_optional(&state.db)
65
-
.await
66
-
{
67
-
Ok(Some(id)) => id,
68
-
_ => {
69
-
return (
70
-
StatusCode::INTERNAL_SERVER_ERROR,
71
-
Json(json!({"error": "InternalError"})),
72
-
)
73
-
.into_response();
74
-
}
75
};
76
77
let creator_user_id = if let Some(for_account) = &input.for_account {
78
-
let target = sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
79
.fetch_optional(&state.db)
80
-
.await;
81
-
82
-
match target {
83
Ok(Some(row)) => row.id,
84
-
Ok(None) => {
85
-
return (
86
-
StatusCode::NOT_FOUND,
87
-
Json(json!({"error": "AccountNotFound", "message": "Target account not found"})),
88
-
)
89
-
.into_response();
90
-
}
91
Err(e) => {
92
error!("DB error looking up target account: {:?}", e);
93
-
return (
94
-
StatusCode::INTERNAL_SERVER_ERROR,
95
-
Json(json!({"error": "InternalError"})),
96
-
)
97
-
.into_response();
98
}
99
}
100
} else {
···
103
104
let user_invites_disabled = sqlx::query_scalar!(
105
"SELECT invites_disabled FROM users WHERE did = $1",
106
-
did
107
)
108
.fetch_optional(&state.db)
109
.await
110
.ok()
111
.flatten()
112
.flatten()
113
.unwrap_or(false);
114
115
if user_invites_disabled {
116
-
return (
117
-
StatusCode::FORBIDDEN,
118
-
Json(json!({"error": "InvitesDisabled", "message": "Invites are disabled for this account"})),
119
-
)
120
-
.into_response();
121
}
122
123
let code = Uuid::new_v4().to_string();
124
125
-
let result = sqlx::query!(
126
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
127
code,
128
input.use_count,
129
creator_user_id
130
)
131
.execute(&state.db)
132
-
.await;
133
-
134
-
match result {
135
-
Ok(_) => (StatusCode::OK, Json(CreateInviteCodeOutput { code })).into_response(),
136
Err(e) => {
137
error!("DB error creating invite code: {:?}", e);
138
-
(
139
-
StatusCode::INTERNAL_SERVER_ERROR,
140
-
Json(json!({"error": "InternalError"})),
141
-
)
142
-
.into_response()
143
}
144
}
145
}
···
165
166
pub async fn create_invite_codes(
167
State(state): State<AppState>,
168
-
headers: axum::http::HeaderMap,
169
Json(input): Json<CreateInviteCodesInput>,
170
) -> Response {
171
-
let token = match crate::auth::extract_bearer_token_from_header(
172
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
173
-
) {
174
-
Some(t) => t,
175
-
None => {
176
-
return (
177
-
StatusCode::UNAUTHORIZED,
178
-
Json(json!({"error": "AuthenticationRequired"})),
179
-
)
180
-
.into_response();
181
-
}
182
-
};
183
-
184
if input.use_count < 1 {
185
-
return (
186
-
StatusCode::BAD_REQUEST,
187
-
Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
188
-
)
189
-
.into_response();
190
}
191
192
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
193
-
let did = match auth_result {
194
-
Ok(user) => user.did,
195
-
Err(e) => {
196
-
return (
197
-
StatusCode::UNAUTHORIZED,
198
-
Json(json!({"error": e})),
199
-
)
200
-
.into_response();
201
-
}
202
-
};
203
-
204
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
205
-
.fetch_optional(&state.db)
206
-
.await
207
-
{
208
-
Ok(Some(id)) => id,
209
-
_ => {
210
-
return (
211
-
StatusCode::INTERNAL_SERVER_ERROR,
212
-
Json(json!({"error": "InternalError"})),
213
-
)
214
-
.into_response();
215
-
}
216
};
217
218
let code_count = input.code_count.unwrap_or(1).max(1);
···
225
for _ in 0..code_count {
226
let code = Uuid::new_v4().to_string();
227
228
-
let insert = sqlx::query!(
229
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
230
code,
231
input.use_count,
232
user_id
233
)
234
.execute(&state.db)
235
-
.await;
236
-
237
-
if let Err(e) = insert {
238
error!("DB error creating invite code: {:?}", e);
239
-
return (
240
-
StatusCode::INTERNAL_SERVER_ERROR,
241
-
Json(json!({"error": "InternalError"})),
242
-
)
243
-
.into_response();
244
}
245
246
codes.push(code);
···
252
});
253
} else {
254
for account_did in for_accounts {
255
-
let target = sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
256
.fetch_optional(&state.db)
257
-
.await;
258
-
259
-
let target_user_id = match target {
260
Ok(Some(row)) => row.id,
261
-
Ok(None) => {
262
-
continue;
263
-
}
264
Err(e) => {
265
error!("DB error looking up target account: {:?}", e);
266
-
return (
267
-
StatusCode::INTERNAL_SERVER_ERROR,
268
-
Json(json!({"error": "InternalError"})),
269
-
)
270
-
.into_response();
271
}
272
};
273
···
275
for _ in 0..code_count {
276
let code = Uuid::new_v4().to_string();
277
278
-
let insert = sqlx::query!(
279
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
280
code,
281
input.use_count,
282
target_user_id
283
)
284
.execute(&state.db)
285
-
.await;
286
-
287
-
if let Err(e) = insert {
288
error!("DB error creating invite code: {:?}", e);
289
-
return (
290
-
StatusCode::INTERNAL_SERVER_ERROR,
291
-
Json(json!({"error": "InternalError"})),
292
-
)
293
-
.into_response();
294
}
295
296
codes.push(code);
···
303
}
304
}
305
306
-
(StatusCode::OK, Json(CreateInviteCodesOutput { codes: result_codes })).into_response()
307
}
308
309
#[derive(Deserialize)]
···
339
340
pub async fn get_account_invite_codes(
341
State(state): State<AppState>,
342
-
headers: axum::http::HeaderMap,
343
axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
344
) -> Response {
345
-
let token = match crate::auth::extract_bearer_token_from_header(
346
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
347
-
) {
348
-
Some(t) => t,
349
-
None => {
350
-
return (
351
-
StatusCode::UNAUTHORIZED,
352
-
Json(json!({"error": "AuthenticationRequired"})),
353
-
)
354
-
.into_response();
355
-
}
356
-
};
357
-
358
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
359
-
let did = match auth_result {
360
-
Ok(user) => user.did,
361
-
Err(e) => {
362
-
return (
363
-
StatusCode::UNAUTHORIZED,
364
-
Json(json!({"error": e})),
365
-
)
366
-
.into_response();
367
-
}
368
-
};
369
-
370
-
let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
371
-
.fetch_optional(&state.db)
372
-
.await
373
-
{
374
-
Ok(Some(id)) => id,
375
-
_ => {
376
-
return (
377
-
StatusCode::INTERNAL_SERVER_ERROR,
378
-
Json(json!({"error": "InternalError"})),
379
-
)
380
-
.into_response();
381
-
}
382
};
383
384
let include_used = params.include_used.unwrap_or(true);
385
386
-
let codes_result = sqlx::query!(
387
r#"
388
SELECT code, available_uses, created_at, disabled
389
FROM invite_codes
···
393
user_id
394
)
395
.fetch_all(&state.db)
396
-
.await;
397
-
398
-
let codes_rows = match codes_result {
399
Ok(rows) => {
400
if include_used {
401
rows
···
405
}
406
Err(e) => {
407
error!("DB error fetching invite codes: {:?}", e);
408
-
return (
409
-
StatusCode::INTERNAL_SERVER_ERROR,
410
-
Json(json!({"error": "InternalError"})),
411
-
)
412
-
.into_response();
413
}
414
};
415
416
let mut codes = Vec::new();
417
for row in codes_rows {
418
-
let uses_result = sqlx::query!(
419
r#"
420
SELECT u.did, icu.used_at
421
FROM invite_code_uses icu
···
426
row.code
427
)
428
.fetch_all(&state.db)
429
-
.await;
430
-
431
-
let uses = match uses_result {
432
-
Ok(use_rows) => use_rows
433
.iter()
434
.map(|u| InviteCodeUse {
435
used_by: u.did.clone(),
436
used_at: u.used_at.to_rfc3339(),
437
})
438
-
.collect(),
439
-
Err(_) => Vec::new(),
440
-
};
441
442
codes.push(InviteCode {
443
code: row.code,
444
available: row.available_uses,
445
disabled: row.disabled.unwrap_or(false),
446
-
for_account: did.clone(),
447
-
created_by: did.clone(),
448
created_at: row.created_at.to_rfc3339(),
449
uses,
450
});
451
}
452
453
-
(StatusCode::OK, Json(GetAccountInviteCodesOutput { codes })).into_response()
454
}
···
1
+
use crate::api::ApiError;
2
+
use crate::auth::BearerAuth;
3
use crate::state::AppState;
4
+
use crate::util::get_user_id_by_did;
5
use axum::{
6
Json,
7
extract::State,
8
response::{IntoResponse, Response},
9
};
10
use serde::{Deserialize, Serialize};
11
use tracing::error;
12
use uuid::Uuid;
13
···
25
26
pub async fn create_invite_code(
27
State(state): State<AppState>,
28
+
BearerAuth(auth_user): BearerAuth,
29
Json(input): Json<CreateInviteCodeInput>,
30
) -> Response {
31
if input.use_count < 1 {
32
+
return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
33
}
34
35
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
36
+
Ok(id) => id,
37
+
Err(e) => return ApiError::from(e).into_response(),
38
};
39
40
let creator_user_id = if let Some(for_account) = &input.for_account {
41
+
match sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
42
.fetch_optional(&state.db)
43
+
.await
44
+
{
45
Ok(Some(row)) => row.id,
46
+
Ok(None) => return ApiError::AccountNotFound.into_response(),
47
Err(e) => {
48
error!("DB error looking up target account: {:?}", e);
49
+
return ApiError::InternalError.into_response();
50
}
51
}
52
} else {
···
55
56
let user_invites_disabled = sqlx::query_scalar!(
57
"SELECT invites_disabled FROM users WHERE did = $1",
58
+
auth_user.did
59
)
60
.fetch_optional(&state.db)
61
.await
62
+
.map_err(|e| {
63
+
error!("DB error checking invites_disabled: {:?}", e);
64
+
ApiError::InternalError
65
+
})
66
.ok()
67
.flatten()
68
.flatten()
69
.unwrap_or(false);
70
71
if user_invites_disabled {
72
+
return ApiError::InvitesDisabled.into_response();
73
}
74
75
let code = Uuid::new_v4().to_string();
76
77
+
match sqlx::query!(
78
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
79
code,
80
input.use_count,
81
creator_user_id
82
)
83
.execute(&state.db)
84
+
.await
85
+
{
86
+
Ok(_) => Json(CreateInviteCodeOutput { code }).into_response(),
87
Err(e) => {
88
error!("DB error creating invite code: {:?}", e);
89
+
ApiError::InternalError.into_response()
90
}
91
}
92
}
···
112
113
pub async fn create_invite_codes(
114
State(state): State<AppState>,
115
+
BearerAuth(auth_user): BearerAuth,
116
Json(input): Json<CreateInviteCodesInput>,
117
) -> Response {
118
if input.use_count < 1 {
119
+
return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
120
}
121
122
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
123
+
Ok(id) => id,
124
+
Err(e) => return ApiError::from(e).into_response(),
125
};
126
127
let code_count = input.code_count.unwrap_or(1).max(1);
···
134
for _ in 0..code_count {
135
let code = Uuid::new_v4().to_string();
136
137
+
if let Err(e) = sqlx::query!(
138
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
139
code,
140
input.use_count,
141
user_id
142
)
143
.execute(&state.db)
144
+
.await
145
+
{
146
error!("DB error creating invite code: {:?}", e);
147
+
return ApiError::InternalError.into_response();
148
}
149
150
codes.push(code);
···
156
});
157
} else {
158
for account_did in for_accounts {
159
+
let target_user_id = match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
160
.fetch_optional(&state.db)
161
+
.await
162
+
{
163
Ok(Some(row)) => row.id,
164
+
Ok(None) => continue,
165
Err(e) => {
166
error!("DB error looking up target account: {:?}", e);
167
+
return ApiError::InternalError.into_response();
168
}
169
};
170
···
172
for _ in 0..code_count {
173
let code = Uuid::new_v4().to_string();
174
175
+
if let Err(e) = sqlx::query!(
176
"INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
177
code,
178
input.use_count,
179
target_user_id
180
)
181
.execute(&state.db)
182
+
.await
183
+
{
184
error!("DB error creating invite code: {:?}", e);
185
+
return ApiError::InternalError.into_response();
186
}
187
188
codes.push(code);
···
195
}
196
}
197
198
+
Json(CreateInviteCodesOutput { codes: result_codes }).into_response()
199
}
200
201
#[derive(Deserialize)]
···
231
232
pub async fn get_account_invite_codes(
233
State(state): State<AppState>,
234
+
BearerAuth(auth_user): BearerAuth,
235
axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
236
) -> Response {
237
+
let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
238
+
Ok(id) => id,
239
+
Err(e) => return ApiError::from(e).into_response(),
240
};
241
242
let include_used = params.include_used.unwrap_or(true);
243
244
+
let codes_rows = match sqlx::query!(
245
r#"
246
SELECT code, available_uses, created_at, disabled
247
FROM invite_codes
···
251
user_id
252
)
253
.fetch_all(&state.db)
254
+
.await
255
+
{
256
Ok(rows) => {
257
if include_used {
258
rows
···
262
}
263
Err(e) => {
264
error!("DB error fetching invite codes: {:?}", e);
265
+
return ApiError::InternalError.into_response();
266
}
267
};
268
269
let mut codes = Vec::new();
270
for row in codes_rows {
271
+
let uses = sqlx::query!(
272
r#"
273
SELECT u.did, icu.used_at
274
FROM invite_code_uses icu
···
279
row.code
280
)
281
.fetch_all(&state.db)
282
+
.await
283
+
.map(|use_rows| {
284
+
use_rows
285
.iter()
286
.map(|u| InviteCodeUse {
287
used_by: u.did.clone(),
288
used_at: u.used_at.to_rfc3339(),
289
})
290
+
.collect()
291
+
})
292
+
.unwrap_or_default();
293
294
codes.push(InviteCode {
295
code: row.code,
296
available: row.available_uses,
297
disabled: row.disabled.unwrap_or(false),
298
+
for_account: auth_user.did.clone(),
299
+
created_by: auth_user.did.clone(),
300
created_at: row.created_at.to_rfc3339(),
301
uses,
302
});
303
}
304
305
+
Json(GetAccountInviteCodesOutput { codes }).into_response()
306
}
+3
-3
src/api/server/mod.rs
+3
-3
src/api/server/mod.rs
···
4
pub mod invite;
5
pub mod meta;
6
pub mod password;
7
pub mod session;
8
pub mod signing_key;
9
···
16
pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes};
17
pub use meta::{describe_server, health};
18
pub use password::{request_password_reset, reset_password};
19
-
pub use session::{
20
-
create_session, delete_session, get_service_auth, get_session, refresh_session,
21
-
};
22
pub use signing_key::reserve_signing_key;
···
4
pub mod invite;
5
pub mod meta;
6
pub mod password;
7
+
pub mod service_auth;
8
pub mod session;
9
pub mod signing_key;
10
···
17
pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes};
18
pub use meta::{describe_server, health};
19
pub use password::{request_password_reset, reset_password};
20
+
pub use service_auth::get_service_auth;
21
+
pub use session::{create_session, delete_session, get_session, refresh_session};
22
pub use signing_key::reserve_signing_key;
+43
-17
src/api/server/password.rs
+43
-17
src/api/server/password.rs
···
7
};
8
use bcrypt::{hash, DEFAULT_COST};
9
use chrono::{Duration, Utc};
10
-
use rand::Rng;
11
use serde::Deserialize;
12
use serde_json::json;
13
use tracing::{error, info, warn};
14
15
fn generate_reset_code() -> String {
16
-
let mut rng = rand::thread_rng();
17
-
let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
18
-
let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
19
-
let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
20
-
format!("{}-{}", part1, part2)
21
}
22
23
#[derive(Deserialize)]
···
45
let user_id = match user {
46
Ok(Some(row)) => row.id,
47
Ok(None) => {
48
-
info!("Password reset requested for unknown email: {}", email);
49
return (StatusCode::OK, Json(json!({}))).into_response();
50
}
51
Err(e) => {
···
151
152
if let Some(exp) = expires_at {
153
if Utc::now() > exp {
154
-
let _ = sqlx::query!(
155
"UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
156
user_id
157
)
158
.execute(&state.db)
159
-
.await;
160
161
return (
162
StatusCode::BAD_REQUEST,
···
184
}
185
};
186
187
-
let update = sqlx::query!(
188
"UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
189
password_hash,
190
user_id
191
)
192
-
.execute(&state.db)
193
-
.await;
194
-
195
-
if let Err(e) = update {
196
error!("DB error updating password: {:?}", e);
197
return (
198
StatusCode::INTERNAL_SERVER_ERROR,
···
201
.into_response();
202
}
203
204
-
let _ = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id)
205
-
.execute(&state.db)
206
-
.await;
207
208
info!("Password reset completed for user {}", user_id);
209
···
7
};
8
use bcrypt::{hash, DEFAULT_COST};
9
use chrono::{Duration, Utc};
10
use serde::Deserialize;
11
use serde_json::json;
12
use tracing::{error, info, warn};
13
14
fn generate_reset_code() -> String {
15
+
crate::util::generate_token_code()
16
}
17
18
#[derive(Deserialize)]
···
40
let user_id = match user {
41
Ok(Some(row)) => row.id,
42
Ok(None) => {
43
+
info!("Password reset requested for unknown email");
44
return (StatusCode::OK, Json(json!({}))).into_response();
45
}
46
Err(e) => {
···
146
147
if let Some(exp) = expires_at {
148
if Utc::now() > exp {
149
+
if let Err(e) = sqlx::query!(
150
"UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
151
user_id
152
)
153
.execute(&state.db)
154
+
.await
155
+
{
156
+
error!("Failed to clear expired reset code: {:?}", e);
157
+
}
158
159
return (
160
StatusCode::BAD_REQUEST,
···
182
}
183
};
184
185
+
let mut tx = match state.db.begin().await {
186
+
Ok(tx) => tx,
187
+
Err(e) => {
188
+
error!("Failed to begin transaction: {:?}", e);
189
+
return (
190
+
StatusCode::INTERNAL_SERVER_ERROR,
191
+
Json(json!({"error": "InternalError"})),
192
+
)
193
+
.into_response();
194
+
}
195
+
};
196
+
197
+
if let Err(e) = sqlx::query!(
198
"UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
199
password_hash,
200
user_id
201
)
202
+
.execute(&mut *tx)
203
+
.await
204
+
{
205
error!("DB error updating password: {:?}", e);
206
return (
207
StatusCode::INTERNAL_SERVER_ERROR,
···
210
.into_response();
211
}
212
213
+
if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id)
214
+
.execute(&mut *tx)
215
+
.await
216
+
{
217
+
error!("Failed to invalidate sessions after password reset: {:?}", e);
218
+
return (
219
+
StatusCode::INTERNAL_SERVER_ERROR,
220
+
Json(json!({"error": "InternalError"})),
221
+
)
222
+
.into_response();
223
+
}
224
+
225
+
if let Err(e) = tx.commit().await {
226
+
error!("Failed to commit password reset transaction: {:?}", e);
227
+
return (
228
+
StatusCode::INTERNAL_SERVER_ERROR,
229
+
Json(json!({"error": "InternalError"})),
230
+
)
231
+
.into_response();
232
+
}
233
234
info!("Password reset completed for user {}", user_id);
235
+63
src/api/server/service_auth.rs
+63
src/api/server/service_auth.rs
···
···
1
+
use crate::api::ApiError;
2
+
use crate::state::AppState;
3
+
use axum::{
4
+
Json,
5
+
extract::{Query, State},
6
+
http::StatusCode,
7
+
response::{IntoResponse, Response},
8
+
};
9
+
use serde::{Deserialize, Serialize};
10
+
use serde_json::json;
11
+
use tracing::error;
12
+
13
+
#[derive(Deserialize)]
14
+
pub struct GetServiceAuthParams {
15
+
pub aud: String,
16
+
pub lxm: Option<String>,
17
+
pub exp: Option<i64>,
18
+
}
19
+
20
+
#[derive(Serialize)]
21
+
pub struct GetServiceAuthOutput {
22
+
pub token: String,
23
+
}
24
+
25
+
pub async fn get_service_auth(
26
+
State(state): State<AppState>,
27
+
headers: axum::http::HeaderMap,
28
+
Query(params): Query<GetServiceAuthParams>,
29
+
) -> Response {
30
+
let token = match crate::auth::extract_bearer_token_from_header(
31
+
headers.get("Authorization").and_then(|h| h.to_str().ok())
32
+
) {
33
+
Some(t) => t,
34
+
None => return ApiError::AuthenticationRequired.into_response(),
35
+
};
36
+
37
+
let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
38
+
Ok(user) => user,
39
+
Err(e) => return ApiError::from(e).into_response(),
40
+
};
41
+
42
+
let key_bytes = match auth_user.key_bytes {
43
+
Some(kb) => kb,
44
+
None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot create service auth".into()).into_response(),
45
+
};
46
+
47
+
let lxm = params.lxm.as_deref().unwrap_or("*");
48
+
49
+
let service_token = match crate::auth::create_service_token(&auth_user.did, ¶ms.aud, lxm, &key_bytes)
50
+
{
51
+
Ok(t) => t,
52
+
Err(e) => {
53
+
error!("Failed to create service token: {:?}", e);
54
+
return (
55
+
StatusCode::INTERNAL_SERVER_ERROR,
56
+
Json(json!({"error": "InternalError"})),
57
+
)
58
+
.into_response();
59
+
}
60
+
};
61
+
62
+
(StatusCode::OK, Json(GetServiceAuthOutput { token: service_token })).into_response()
63
+
}
+199
-512
src/api/server/session.rs
+199
-512
src/api/server/session.rs
···
1
use crate::state::AppState;
2
use axum::{
3
Json,
4
-
extract::{Query, State},
5
-
http::StatusCode,
6
response::{IntoResponse, Response},
7
};
8
use bcrypt::verify;
···
11
use tracing::{error, info, warn};
12
13
#[derive(Deserialize)]
14
-
pub struct GetServiceAuthParams {
15
-
pub aud: String,
16
-
pub lxm: Option<String>,
17
-
pub exp: Option<i64>,
18
-
}
19
-
20
-
#[derive(Serialize)]
21
-
pub struct GetServiceAuthOutput {
22
-
pub token: String,
23
-
}
24
-
25
-
pub async fn get_service_auth(
26
-
State(state): State<AppState>,
27
-
headers: axum::http::HeaderMap,
28
-
Query(params): Query<GetServiceAuthParams>,
29
-
) -> Response {
30
-
let token = match crate::auth::extract_bearer_token_from_header(
31
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
32
-
) {
33
-
Some(t) => t,
34
-
None => {
35
-
return (
36
-
StatusCode::UNAUTHORIZED,
37
-
Json(json!({"error": "AuthenticationRequired"})),
38
-
)
39
-
.into_response();
40
-
}
41
-
};
42
-
43
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
44
-
let (did, key_bytes) = match auth_result {
45
-
Ok(user) => {
46
-
let kb = match user.key_bytes {
47
-
Some(kb) => kb,
48
-
None => {
49
-
return (
50
-
StatusCode::UNAUTHORIZED,
51
-
Json(json!({"error": "AuthenticationFailed", "message": "OAuth tokens cannot create service auth"})),
52
-
)
53
-
.into_response();
54
-
}
55
-
};
56
-
(user.did, kb)
57
-
}
58
-
Err(e) => {
59
-
return (
60
-
StatusCode::UNAUTHORIZED,
61
-
Json(json!({"error": e})),
62
-
)
63
-
.into_response();
64
-
}
65
-
};
66
-
67
-
let lxm = params.lxm.as_deref().unwrap_or("*");
68
-
69
-
let service_token = match crate::auth::create_service_token(&did, ¶ms.aud, lxm, &key_bytes)
70
-
{
71
-
Ok(t) => t,
72
-
Err(e) => {
73
-
error!("Failed to create service token: {:?}", e);
74
-
return (
75
-
StatusCode::INTERNAL_SERVER_ERROR,
76
-
Json(json!({"error": "InternalError"})),
77
-
)
78
-
.into_response();
79
-
}
80
-
};
81
-
82
-
(StatusCode::OK, Json(GetServiceAuthOutput { token: service_token })).into_response()
83
-
}
84
-
85
-
#[derive(Deserialize)]
86
pub struct CreateSessionInput {
87
pub identifier: String,
88
pub password: String,
···
101
State(state): State<AppState>,
102
Json(input): Json<CreateSessionInput>,
103
) -> Response {
104
-
info!("create_session: identifier='{}'", input.identifier);
105
106
-
let user_row = sqlx::query!(
107
"SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
108
input.identifier
109
)
110
-
.fetch_optional(&state.db)
111
-
.await;
112
-
113
-
match user_row {
114
-
Ok(Some(row)) => {
115
-
let user_id = row.id;
116
-
let stored_hash = &row.password_hash;
117
-
let did = &row.did;
118
-
let handle = &row.handle;
119
-
let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
120
-
Ok(k) => k,
121
-
Err(e) => {
122
-
error!("Failed to decrypt user key: {:?}", e);
123
-
return (
124
-
StatusCode::INTERNAL_SERVER_ERROR,
125
-
Json(json!({"error": "InternalError"})),
126
-
)
127
-
.into_response();
128
-
}
129
-
};
130
-
131
-
let password_valid = if verify(&input.password, stored_hash).unwrap_or(false) {
132
-
true
133
-
} else {
134
-
let app_pass_rows = sqlx::query!("SELECT password_hash FROM app_passwords WHERE user_id = $1", user_id)
135
-
.fetch_all(&state.db)
136
-
.await
137
-
.unwrap_or_default();
138
-
139
-
app_pass_rows.iter().any(|row| {
140
-
verify(&input.password, &row.password_hash).unwrap_or(false)
141
-
})
142
-
};
143
-
144
-
if password_valid {
145
-
let access_meta = match crate::auth::create_access_token_with_metadata(did, &key_bytes) {
146
-
Ok(m) => m,
147
-
Err(e) => {
148
-
error!("Failed to create access token: {:?}", e);
149
-
return (
150
-
StatusCode::INTERNAL_SERVER_ERROR,
151
-
Json(json!({"error": "InternalError"})),
152
-
)
153
-
.into_response();
154
-
}
155
-
};
156
-
157
-
let refresh_meta = match crate::auth::create_refresh_token_with_metadata(did, &key_bytes) {
158
-
Ok(m) => m,
159
-
Err(e) => {
160
-
error!("Failed to create refresh token: {:?}", e);
161
-
return (
162
-
StatusCode::INTERNAL_SERVER_ERROR,
163
-
Json(json!({"error": "InternalError"})),
164
-
)
165
-
.into_response();
166
-
}
167
-
};
168
-
169
-
let session_insert = sqlx::query!(
170
-
"INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)",
171
-
did,
172
-
access_meta.jti,
173
-
refresh_meta.jti,
174
-
access_meta.expires_at,
175
-
refresh_meta.expires_at
176
-
)
177
-
.execute(&state.db)
178
-
.await;
179
-
180
-
match session_insert {
181
-
Ok(_) => {
182
-
return (
183
-
StatusCode::OK,
184
-
Json(CreateSessionOutput {
185
-
access_jwt: access_meta.token,
186
-
refresh_jwt: refresh_meta.token,
187
-
handle: handle.clone(),
188
-
did: did.clone(),
189
-
}),
190
-
)
191
-
.into_response();
192
-
}
193
-
Err(e) => {
194
-
error!("Failed to insert session: {:?}", e);
195
-
return (
196
-
StatusCode::INTERNAL_SERVER_ERROR,
197
-
Json(json!({"error": "InternalError"})),
198
-
)
199
-
.into_response();
200
-
}
201
-
}
202
-
} else {
203
-
warn!(
204
-
"Password verification failed for identifier: {}",
205
-
input.identifier
206
-
);
207
-
}
208
-
}
209
Ok(None) => {
210
-
warn!("User not found for identifier: {}", input.identifier);
211
}
212
Err(e) => {
213
error!("Database error fetching user: {:?}", e);
214
-
return (
215
-
StatusCode::INTERNAL_SERVER_ERROR,
216
-
Json(json!({"error": "InternalError"})),
217
-
)
218
-
.into_response();
219
}
220
}
221
222
-
(
223
-
StatusCode::UNAUTHORIZED,
224
-
Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"})),
225
-
)
226
-
.into_response()
227
-
}
228
-
229
-
pub async fn get_session(
230
-
State(state): State<AppState>,
231
-
headers: axum::http::HeaderMap,
232
-
) -> Response {
233
-
let token = match crate::auth::extract_bearer_token_from_header(
234
-
headers.get("Authorization").and_then(|h| h.to_str().ok())
235
-
) {
236
-
Some(t) => t,
237
-
None => {
238
-
return (
239
-
StatusCode::UNAUTHORIZED,
240
-
Json(json!({"error": "AuthenticationRequired", "message": "Invalid Authorization header format"})),
241
-
)
242
-
.into_response();
243
}
244
};
245
246
-
let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
247
-
let did = match auth_result {
248
-
Ok(user) => user.did,
249
Err(e) => {
250
-
return (
251
-
StatusCode::UNAUTHORIZED,
252
-
Json(json!({"error": e})),
253
-
)
254
-
.into_response();
255
}
256
};
257
258
-
let user = sqlx::query!(
259
-
"SELECT handle, email FROM users WHERE did = $1",
260
-
did
261
)
262
-
.fetch_optional(&state.db)
263
-
.await;
264
265
-
match user {
266
-
Ok(Some(row)) => {
267
-
return (
268
-
StatusCode::OK,
269
-
Json(json!({
270
-
"handle": row.handle,
271
-
"did": did,
272
-
"email": row.email,
273
-
"didDoc": {}
274
-
})),
275
-
)
276
-
.into_response();
277
-
}
278
-
Ok(None) => {
279
-
return (
280
-
StatusCode::UNAUTHORIZED,
281
-
Json(json!({"error": "AuthenticationFailed"})),
282
-
)
283
-
.into_response();
284
-
}
285
Err(e) => {
286
error!("Database error in get_session: {:?}", e);
287
-
return (
288
-
StatusCode::INTERNAL_SERVER_ERROR,
289
-
Json(json!({"error": "InternalError"})),
290
-
)
291
-
.into_response();
292
}
293
}
294
}
···
301
headers.get("Authorization").and_then(|h| h.to_str().ok())
302
) {
303
Some(t) => t,
304
-
None => {
305
-
return (
306
-
StatusCode::UNAUTHORIZED,
307
-
Json(json!({"error": "AuthenticationRequired"})),
308
-
)
309
-
.into_response();
310
-
}
311
};
312
313
-
let jti = match crate::auth::get_did_from_token(&token) {
314
-
Ok(_) => {
315
-
let parts: Vec<&str> = token.split('.').collect();
316
-
if parts.len() != 3 {
317
-
return (
318
-
StatusCode::UNAUTHORIZED,
319
-
Json(json!({"error": "AuthenticationFailed"})),
320
-
)
321
-
.into_response();
322
-
}
323
-
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
324
-
let claims_json = match URL_SAFE_NO_PAD.decode(parts[1]) {
325
-
Ok(bytes) => bytes,
326
-
Err(_) => {
327
-
return (
328
-
StatusCode::UNAUTHORIZED,
329
-
Json(json!({"error": "AuthenticationFailed"})),
330
-
)
331
-
.into_response();
332
-
}
333
-
};
334
-
let claims: serde_json::Value = match serde_json::from_slice(&claims_json) {
335
-
Ok(c) => c,
336
-
Err(_) => {
337
-
return (
338
-
StatusCode::UNAUTHORIZED,
339
-
Json(json!({"error": "AuthenticationFailed"})),
340
-
)
341
-
.into_response();
342
-
}
343
-
};
344
-
match claims.get("jti").and_then(|j| j.as_str()) {
345
-
Some(jti) => jti.to_string(),
346
-
None => {
347
-
return (
348
-
StatusCode::UNAUTHORIZED,
349
-
Json(json!({"error": "AuthenticationFailed"})),
350
-
)
351
-
.into_response();
352
-
}
353
-
}
354
-
}
355
-
Err(_) => {
356
-
return (
357
-
StatusCode::UNAUTHORIZED,
358
-
Json(json!({"error": "AuthenticationFailed"})),
359
-
)
360
-
.into_response();
361
-
}
362
};
363
364
-
let result = sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti)
365
.execute(&state.db)
366
-
.await;
367
-
368
-
match result {
369
-
Ok(res) => {
370
-
if res.rows_affected() > 0 {
371
-
return (StatusCode::OK, Json(json!({}))).into_response();
372
-
}
373
-
}
374
Err(e) => {
375
error!("Database error in delete_session: {:?}", e);
376
}
377
}
378
-
379
-
(
380
-
StatusCode::UNAUTHORIZED,
381
-
Json(json!({"error": "AuthenticationFailed"})),
382
-
)
383
-
.into_response()
384
}
385
386
pub async fn refresh_session(
387
State(state): State<AppState>,
388
headers: axum::http::HeaderMap,
389
) -> Response {
390
-
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
391
-
392
let refresh_token = match crate::auth::extract_bearer_token_from_header(
393
headers.get("Authorization").and_then(|h| h.to_str().ok())
394
) {
395
Some(t) => t,
396
-
None => {
397
-
return (
398
-
StatusCode::UNAUTHORIZED,
399
-
Json(json!({"error": "AuthenticationRequired"})),
400
-
)
401
-
.into_response();
402
-
}
403
};
404
405
-
let refresh_jti = {
406
-
let parts: Vec<&str> = refresh_token.split('.').collect();
407
-
if parts.len() != 3 {
408
-
return (
409
-
StatusCode::UNAUTHORIZED,
410
-
Json(json!({"error": "AuthenticationFailed", "message": "Invalid token format"})),
411
-
)
412
-
.into_response();
413
-
}
414
-
let claims_bytes = match URL_SAFE_NO_PAD.decode(parts[1]) {
415
-
Ok(b) => b,
416
-
Err(_) => {
417
-
return (
418
-
StatusCode::UNAUTHORIZED,
419
-
Json(json!({"error": "AuthenticationFailed"})),
420
-
)
421
-
.into_response();
422
-
}
423
-
};
424
-
let claims: serde_json::Value = match serde_json::from_slice(&claims_bytes) {
425
-
Ok(c) => c,
426
-
Err(_) => {
427
-
return (
428
-
StatusCode::UNAUTHORIZED,
429
-
Json(json!({"error": "AuthenticationFailed"})),
430
-
)
431
-
.into_response();
432
-
}
433
-
};
434
-
match claims.get("jti").and_then(|j| j.as_str()) {
435
-
Some(jti) => jti.to_string(),
436
-
None => {
437
-
return (
438
-
StatusCode::UNAUTHORIZED,
439
-
Json(json!({"error": "AuthenticationFailed"})),
440
-
)
441
-
.into_response();
442
-
}
443
}
444
};
445
446
-
let reuse_check = sqlx::query_scalar!(
447
-
"SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1",
448
refresh_jti
449
)
450
-
.fetch_optional(&state.db)
451
-
.await;
452
-
453
-
if let Ok(Some(session_id)) = reuse_check {
454
warn!("Refresh token reuse detected! Revoking token family for session_id: {}", session_id);
455
let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id)
456
-
.execute(&state.db)
457
.await;
458
-
return (
459
-
StatusCode::UNAUTHORIZED,
460
-
Json(json!({"error": "ExpiredToken", "message": "Refresh token has been revoked due to suspected compromise"})),
461
-
)
462
-
.into_response();
463
}
464
465
-
let session = sqlx::query!(
466
r#"SELECT st.id, st.did, k.key_bytes, k.encryption_version
467
FROM session_tokens st
468
JOIN users u ON st.did = u.did
469
JOIN user_keys k ON u.id = k.user_id
470
-
WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()"#,
471
refresh_jti
472
)
473
-
.fetch_optional(&state.db)
474
-
.await;
475
476
-
match session {
477
-
Ok(Some(session_row)) => {
478
-
let session_id = session_row.id;
479
-
let did = &session_row.did;
480
-
let key_bytes = match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) {
481
-
Ok(k) => k,
482
-
Err(e) => {
483
-
error!("Failed to decrypt user key: {:?}", e);
484
-
return (
485
-
StatusCode::INTERNAL_SERVER_ERROR,
486
-
Json(json!({"error": "InternalError"})),
487
-
)
488
-
.into_response();
489
-
}
490
-
};
491
492
-
if let Err(_) = crate::auth::verify_refresh_token(&refresh_token, &key_bytes) {
493
-
return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"}))).into_response();
494
-
}
495
496
-
let new_access_meta = match crate::auth::create_access_token_with_metadata(did, &key_bytes) {
497
-
Ok(m) => m,
498
-
Err(e) => {
499
-
error!("Failed to create access token: {:?}", e);
500
-
return (
501
-
StatusCode::INTERNAL_SERVER_ERROR,
502
-
Json(json!({"error": "InternalError"})),
503
-
)
504
-
.into_response();
505
-
}
506
-
};
507
-
let new_refresh_meta = match crate::auth::create_refresh_token_with_metadata(did, &key_bytes) {
508
-
Ok(m) => m,
509
-
Err(e) => {
510
-
error!("Failed to create refresh token: {:?}", e);
511
-
return (
512
-
StatusCode::INTERNAL_SERVER_ERROR,
513
-
Json(json!({"error": "InternalError"})),
514
-
)
515
-
.into_response();
516
-
}
517
-
};
518
519
-
let mut tx = match state.db.begin().await {
520
-
Ok(tx) => tx,
521
-
Err(e) => {
522
-
error!("Failed to begin transaction: {:?}", e);
523
-
return (
524
-
StatusCode::INTERNAL_SERVER_ERROR,
525
-
Json(json!({"error": "InternalError"})),
526
-
)
527
-
.into_response();
528
-
}
529
-
};
530
531
-
if let Err(e) = sqlx::query!(
532
-
"INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2)",
533
-
refresh_jti,
534
-
session_id
535
-
)
536
-
.execute(&mut *tx)
537
-
.await
538
-
{
539
-
error!("Failed to record used refresh token: {:?}", e);
540
-
return (
541
-
StatusCode::INTERNAL_SERVER_ERROR,
542
-
Json(json!({"error": "InternalError"})),
543
-
)
544
-
.into_response();
545
-
}
546
547
-
if let Err(e) = sqlx::query!(
548
-
"UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5",
549
-
new_access_meta.jti,
550
-
new_refresh_meta.jti,
551
-
new_access_meta.expires_at,
552
-
new_refresh_meta.expires_at,
553
-
session_id
554
-
)
555
-
.execute(&mut *tx)
556
-
.await
557
-
{
558
-
error!("Database error updating session: {:?}", e);
559
-
return (
560
-
StatusCode::INTERNAL_SERVER_ERROR,
561
-
Json(json!({"error": "InternalError"})),
562
-
)
563
-
.into_response();
564
-
}
565
566
-
if let Err(e) = tx.commit().await {
567
-
error!("Failed to commit transaction: {:?}", e);
568
-
return (
569
-
StatusCode::INTERNAL_SERVER_ERROR,
570
-
Json(json!({"error": "InternalError"})),
571
-
)
572
-
.into_response();
573
-
}
574
575
-
let user = sqlx::query!("SELECT handle FROM users WHERE did = $1", did)
576
-
.fetch_optional(&state.db)
577
-
.await;
578
-
579
-
match user {
580
-
Ok(Some(u)) => {
581
-
return (
582
-
StatusCode::OK,
583
-
Json(json!({
584
-
"accessJwt": new_access_meta.token,
585
-
"refreshJwt": new_refresh_meta.token,
586
-
"handle": u.handle,
587
-
"did": did
588
-
})),
589
-
)
590
-
.into_response();
591
-
}
592
-
Ok(None) => {
593
-
error!("User not found for existing session: {}", did);
594
-
return (
595
-
StatusCode::INTERNAL_SERVER_ERROR,
596
-
Json(json!({"error": "InternalError"})),
597
-
)
598
-
.into_response();
599
-
}
600
-
Err(e) => {
601
-
error!("Database error fetching user: {:?}", e);
602
-
return (
603
-
StatusCode::INTERNAL_SERVER_ERROR,
604
-
Json(json!({"error": "InternalError"})),
605
-
)
606
-
.into_response();
607
-
}
608
-
}
609
-
}
610
Ok(None) => {
611
-
return (
612
-
StatusCode::UNAUTHORIZED,
613
-
Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"})),
614
-
)
615
-
.into_response();
616
}
617
Err(e) => {
618
-
error!("Database error fetching session: {:?}", e);
619
-
return (
620
-
StatusCode::INTERNAL_SERVER_ERROR,
621
-
Json(json!({"error": "InternalError"})),
622
-
)
623
-
.into_response();
624
}
625
}
626
}
···
1
+
use crate::api::ApiError;
2
+
use crate::auth::BearerAuth;
3
use crate::state::AppState;
4
use axum::{
5
Json,
6
+
extract::State,
7
response::{IntoResponse, Response},
8
};
9
use bcrypt::verify;
···
12
use tracing::{error, info, warn};
13
14
#[derive(Deserialize)]
15
pub struct CreateSessionInput {
16
pub identifier: String,
17
pub password: String,
···
30
State(state): State<AppState>,
31
Json(input): Json<CreateSessionInput>,
32
) -> Response {
33
+
info!("create_session called");
34
35
+
let row = match sqlx::query!(
36
"SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1",
37
input.identifier
38
)
39
+
.fetch_optional(&state.db)
40
+
.await
41
+
{
42
+
Ok(Some(row)) => row,
43
Ok(None) => {
44
+
warn!("User not found for login attempt");
45
+
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response();
46
}
47
Err(e) => {
48
error!("Database error fetching user: {:?}", e);
49
+
return ApiError::InternalError.into_response();
50
+
}
51
+
};
52
+
53
+
let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
54
+
Ok(k) => k,
55
+
Err(e) => {
56
+
error!("Failed to decrypt user key: {:?}", e);
57
+
return ApiError::InternalError.into_response();
58
}
59
+
};
60
+
61
+
let password_valid = verify(&input.password, &row.password_hash).unwrap_or(false)
62
+
|| sqlx::query!("SELECT password_hash FROM app_passwords WHERE user_id = $1", row.id)
63
+
.fetch_all(&state.db)
64
+
.await
65
+
.unwrap_or_default()
66
+
.iter()
67
+
.any(|app| verify(&input.password, &app.password_hash).unwrap_or(false));
68
+
69
+
if !password_valid {
70
+
warn!("Password verification failed for login attempt");
71
+
return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response();
72
}
73
74
+
let access_meta = match crate::auth::create_access_token_with_metadata(&row.did, &key_bytes) {
75
+
Ok(m) => m,
76
+
Err(e) => {
77
+
error!("Failed to create access token: {:?}", e);
78
+
return ApiError::InternalError.into_response();
79
}
80
};
81
82
+
let refresh_meta = match crate::auth::create_refresh_token_with_metadata(&row.did, &key_bytes) {
83
+
Ok(m) => m,
84
Err(e) => {
85
+
error!("Failed to create refresh token: {:?}", e);
86
+
return ApiError::InternalError.into_response();
87
}
88
};
89
90
+
if let Err(e) = sqlx::query!(
91
+
"INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)",
92
+
row.did,
93
+
access_meta.jti,
94
+
refresh_meta.jti,
95
+
access_meta.expires_at,
96
+
refresh_meta.expires_at
97
)
98
+
.execute(&state.db)
99
+
.await
100
+
{
101
+
error!("Failed to insert session: {:?}", e);
102
+
return ApiError::InternalError.into_response();
103
+
}
104
105
+
Json(CreateSessionOutput {
106
+
access_jwt: access_meta.token,
107
+
refresh_jwt: refresh_meta.token,
108
+
handle: row.handle,
109
+
did: row.did,
110
+
}).into_response()
111
+
}
112
+
113
+
pub async fn get_session(
114
+
State(state): State<AppState>,
115
+
BearerAuth(auth_user): BearerAuth,
116
+
) -> Response {
117
+
match sqlx::query!("SELECT handle, email FROM users WHERE did = $1", auth_user.did)
118
+
.fetch_optional(&state.db)
119
+
.await
120
+
{
121
+
Ok(Some(row)) => Json(json!({
122
+
"handle": row.handle,
123
+
"did": auth_user.did,
124
+
"email": row.email,
125
+
"didDoc": {}
126
+
})).into_response(),
127
+
Ok(None) => ApiError::AuthenticationFailed.into_response(),
128
Err(e) => {
129
error!("Database error in get_session: {:?}", e);
130
+
ApiError::InternalError.into_response()
131
}
132
}
133
}
···
140
headers.get("Authorization").and_then(|h| h.to_str().ok())
141
) {
142
Some(t) => t,
143
+
None => return ApiError::AuthenticationRequired.into_response(),
144
};
145
146
+
let jti = match crate::auth::get_jti_from_token(&token) {
147
+
Ok(jti) => jti,
148
+
Err(_) => return ApiError::AuthenticationFailed.into_response(),
149
};
150
151
+
match sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti)
152
.execute(&state.db)
153
+
.await
154
+
{
155
+
Ok(res) if res.rows_affected() > 0 => Json(json!({})).into_response(),
156
+
Ok(_) => ApiError::AuthenticationFailed.into_response(),
157
Err(e) => {
158
error!("Database error in delete_session: {:?}", e);
159
+
ApiError::AuthenticationFailed.into_response()
160
}
161
}
162
}
163
164
pub async fn refresh_session(
165
State(state): State<AppState>,
166
headers: axum::http::HeaderMap,
167
) -> Response {
168
let refresh_token = match crate::auth::extract_bearer_token_from_header(
169
headers.get("Authorization").and_then(|h| h.to_str().ok())
170
) {
171
Some(t) => t,
172
+
None => return ApiError::AuthenticationRequired.into_response(),
173
};
174
175
+
let refresh_jti = match crate::auth::get_jti_from_token(&refresh_token) {
176
+
Ok(jti) => jti,
177
+
Err(_) => return ApiError::AuthenticationFailedMsg("Invalid token format".into()).into_response(),
178
+
};
179
+
180
+
let mut tx = match state.db.begin().await {
181
+
Ok(tx) => tx,
182
+
Err(e) => {
183
+
error!("Failed to begin transaction: {:?}", e);
184
+
return ApiError::InternalError.into_response();
185
}
186
};
187
188
+
if let Ok(Some(session_id)) = sqlx::query_scalar!(
189
+
"SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1 FOR UPDATE",
190
refresh_jti
191
)
192
+
.fetch_optional(&mut *tx)
193
+
.await
194
+
{
195
warn!("Refresh token reuse detected! Revoking token family for session_id: {}", session_id);
196
let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id)
197
+
.execute(&mut *tx)
198
.await;
199
+
let _ = tx.commit().await;
200
+
return ApiError::ExpiredTokenMsg("Refresh token has been revoked due to suspected compromise".into()).into_response();
201
}
202
203
+
let session_row = match sqlx::query!(
204
r#"SELECT st.id, st.did, k.key_bytes, k.encryption_version
205
FROM session_tokens st
206
JOIN users u ON st.did = u.did
207
JOIN user_keys k ON u.id = k.user_id
208
+
WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()
209
+
FOR UPDATE OF st"#,
210
refresh_jti
211
)
212
+
.fetch_optional(&mut *tx)
213
+
.await
214
+
{
215
+
Ok(Some(row)) => row,
216
+
Ok(None) => return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response(),
217
+
Err(e) => {
218
+
error!("Database error fetching session: {:?}", e);
219
+
return ApiError::InternalError.into_response();
220
+
}
221
+
};
222
223
+
let key_bytes = match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) {
224
+
Ok(k) => k,
225
+
Err(e) => {
226
+
error!("Failed to decrypt user key: {:?}", e);
227
+
return ApiError::InternalError.into_response();
228
+
}
229
+
};
230
231
+
if crate::auth::verify_refresh_token(&refresh_token, &key_bytes).is_err() {
232
+
return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response();
233
+
}
234
235
+
let new_access_meta = match crate::auth::create_access_token_with_metadata(&session_row.did, &key_bytes) {
236
+
Ok(m) => m,
237
+
Err(e) => {
238
+
error!("Failed to create access token: {:?}", e);
239
+
return ApiError::InternalError.into_response();
240
+
}
241
+
};
242
243
+
let new_refresh_meta = match crate::auth::create_refresh_token_with_metadata(&session_row.did, &key_bytes) {
244
+
Ok(m) => m,
245
+
Err(e) => {
246
+
error!("Failed to create refresh token: {:?}", e);
247
+
return ApiError::InternalError.into_response();
248
+
}
249
+
};
250
251
+
match sqlx::query!(
252
+
"INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING",
253
+
refresh_jti,
254
+
session_row.id
255
+
)
256
+
.execute(&mut *tx)
257
+
.await
258
+
{
259
+
Ok(result) if result.rows_affected() == 0 => {
260
+
warn!("Concurrent refresh token reuse detected for session_id: {}", session_row.id);
261
+
let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_row.id)
262
+
.execute(&mut *tx)
263
+
.await;
264
+
let _ = tx.commit().await;
265
+
return ApiError::ExpiredTokenMsg("Refresh token has been revoked due to suspected compromise".into()).into_response();
266
+
}
267
+
Err(e) => {
268
+
error!("Failed to record used refresh token: {:?}", e);
269
+
return ApiError::InternalError.into_response();
270
+
}
271
+
Ok(_) => {}
272
+
}
273
274
+
if let Err(e) = sqlx::query!(
275
+
"UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5",
276
+
new_access_meta.jti,
277
+
new_refresh_meta.jti,
278
+
new_access_meta.expires_at,
279
+
new_refresh_meta.expires_at,
280
+
session_row.id
281
+
)
282
+
.execute(&mut *tx)
283
+
.await
284
+
{
285
+
error!("Database error updating session: {:?}", e);
286
+
return ApiError::InternalError.into_response();
287
+
}
288
289
+
if let Err(e) = tx.commit().await {
290
+
error!("Failed to commit transaction: {:?}", e);
291
+
return ApiError::InternalError.into_response();
292
+
}
293
294
+
match sqlx::query!("SELECT handle FROM users WHERE did = $1", session_row.did)
295
+
.fetch_optional(&state.db)
296
+
.await
297
+
{
298
+
Ok(Some(u)) => Json(json!({
299
+
"accessJwt": new_access_meta.token,
300
+
"refreshJwt": new_refresh_meta.token,
301
+
"handle": u.handle,
302
+
"did": session_row.did
303
+
})).into_response(),
304
Ok(None) => {
305
+
error!("User not found for existing session: {}", session_row.did);
306
+
ApiError::InternalError.into_response()
307
}
308
Err(e) => {
309
+
error!("Database error fetching user: {:?}", e);
310
+
ApiError::InternalError.into_response()
311
}
312
}
313
}
+104
src/api/validation.rs
+104
src/api/validation.rs
···
···
1
+
pub const MAX_EMAIL_LENGTH: usize = 254;
2
+
pub const MAX_LOCAL_PART_LENGTH: usize = 64;
3
+
pub const MAX_DOMAIN_LENGTH: usize = 253;
4
+
pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63;
5
+
6
+
const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-";
7
+
8
+
pub fn is_valid_email(email: &str) -> bool {
9
+
let email = email.trim();
10
+
11
+
if email.is_empty() || email.len() > MAX_EMAIL_LENGTH {
12
+
return false;
13
+
}
14
+
15
+
let parts: Vec<&str> = email.rsplitn(2, '@').collect();
16
+
if parts.len() != 2 {
17
+
return false;
18
+
}
19
+
20
+
let domain = parts[0];
21
+
let local = parts[1];
22
+
23
+
if local.is_empty() || local.len() > MAX_LOCAL_PART_LENGTH {
24
+
return false;
25
+
}
26
+
27
+
if local.starts_with('.') || local.ends_with('.') {
28
+
return false;
29
+
}
30
+
31
+
if local.contains("..") {
32
+
return false;
33
+
}
34
+
35
+
for c in local.chars() {
36
+
if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) {
37
+
return false;
38
+
}
39
+
}
40
+
41
+
if domain.is_empty() || domain.len() > MAX_DOMAIN_LENGTH {
42
+
return false;
43
+
}
44
+
45
+
if !domain.contains('.') {
46
+
return false;
47
+
}
48
+
49
+
for label in domain.split('.') {
50
+
if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH {
51
+
return false;
52
+
}
53
+
54
+
if label.starts_with('-') || label.ends_with('-') {
55
+
return false;
56
+
}
57
+
58
+
for c in label.chars() {
59
+
if !c.is_ascii_alphanumeric() && c != '-' {
60
+
return false;
61
+
}
62
+
}
63
+
}
64
+
65
+
true
66
+
}
67
+
68
+
#[cfg(test)]
69
+
mod tests {
70
+
use super::*;
71
+
72
+
#[test]
73
+
fn test_valid_emails() {
74
+
assert!(is_valid_email("user@example.com"));
75
+
assert!(is_valid_email("user.name@example.com"));
76
+
assert!(is_valid_email("user+tag@example.com"));
77
+
assert!(is_valid_email("user@sub.example.com"));
78
+
assert!(is_valid_email("USER@EXAMPLE.COM"));
79
+
assert!(is_valid_email("user123@example123.com"));
80
+
assert!(is_valid_email("a@b.co"));
81
+
}
82
+
83
+
#[test]
84
+
fn test_invalid_emails() {
85
+
assert!(!is_valid_email(""));
86
+
assert!(!is_valid_email("user"));
87
+
assert!(!is_valid_email("user@"));
88
+
assert!(!is_valid_email("@example.com"));
89
+
assert!(!is_valid_email("user@example"));
90
+
assert!(!is_valid_email("user@@example.com"));
91
+
assert!(!is_valid_email("user@.example.com"));
92
+
assert!(!is_valid_email("user@example..com"));
93
+
assert!(!is_valid_email(".user@example.com"));
94
+
assert!(!is_valid_email("user.@example.com"));
95
+
assert!(!is_valid_email("user..name@example.com"));
96
+
assert!(!is_valid_email("user@-example.com"));
97
+
assert!(!is_valid_email("user@example-.com"));
98
+
}
99
+
100
+
#[test]
101
+
fn test_trimmed_whitespace() {
102
+
assert!(is_valid_email(" user@example.com "));
103
+
}
104
+
}
+29
-3
src/auth/extractor.rs
+29
-3
src/auth/extractor.rs
···
7
use serde_json::json;
8
9
use crate::state::AppState;
10
-
use super::{AuthenticatedUser, validate_bearer_token};
11
12
pub struct BearerAuth(pub AuthenticatedUser);
13
···
112
113
match validate_bearer_token(&state.db, token).await {
114
Ok(user) => Ok(BearerAuth(user)),
115
-
Err("AccountDeactivated") => Err(AuthError::AccountDeactivated),
116
-
Err("AccountTakedown") => Err(AuthError::AccountTakedown),
117
Err(_) => Err(AuthError::AuthenticationFailed),
118
}
119
}
···
7
use serde_json::json;
8
9
use crate::state::AppState;
10
+
use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token, validate_bearer_token_allow_deactivated};
11
12
pub struct BearerAuth(pub AuthenticatedUser);
13
···
112
113
match validate_bearer_token(&state.db, token).await {
114
Ok(user) => Ok(BearerAuth(user)),
115
+
Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated),
116
+
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
117
+
Err(_) => Err(AuthError::AuthenticationFailed),
118
+
}
119
+
}
120
+
}
121
+
122
+
pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser);
123
+
124
+
impl FromRequestParts<AppState> for BearerAuthAllowDeactivated {
125
+
type Rejection = AuthError;
126
+
127
+
async fn from_request_parts(
128
+
parts: &mut Parts,
129
+
state: &AppState,
130
+
) -> Result<Self, Self::Rejection> {
131
+
let auth_header = parts
132
+
.headers
133
+
.get(AUTHORIZATION)
134
+
.ok_or(AuthError::MissingToken)?
135
+
.to_str()
136
+
.map_err(|_| AuthError::InvalidFormat)?;
137
+
138
+
let token = extract_bearer_token(auth_header)?;
139
+
140
+
match validate_bearer_token_allow_deactivated(&state.db, token).await {
141
+
Ok(user) => Ok(BearerAuthAllowDeactivated(user)),
142
+
Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown),
143
Err(_) => Err(AuthError::AuthenticationFailed),
144
}
145
}
+31
-13
src/auth/mod.rs
+31
-13
src/auth/mod.rs
···
1
use serde::{Deserialize, Serialize};
2
use sqlx::PgPool;
3
4
pub mod extractor;
5
pub mod token;
6
pub mod verify;
7
8
-
pub use extractor::{BearerAuth, AuthError, extract_bearer_token_from_header};
9
pub use token::{
10
create_access_token, create_refresh_token, create_service_token,
11
create_access_token_with_metadata, create_refresh_token_with_metadata,
···
14
SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
15
};
16
pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
17
18
pub struct AuthenticatedUser {
19
pub did: String,
···
24
pub async fn validate_bearer_token(
25
db: &PgPool,
26
token: &str,
27
-
) -> Result<AuthenticatedUser, &'static str> {
28
validate_bearer_token_with_options(db, token, false).await
29
}
30
31
pub async fn validate_bearer_token_allow_deactivated(
32
db: &PgPool,
33
token: &str,
34
-
) -> Result<AuthenticatedUser, &'static str> {
35
validate_bearer_token_with_options(db, token, true).await
36
}
37
···
39
db: &PgPool,
40
token: &str,
41
allow_deactivated: bool,
42
-
) -> Result<AuthenticatedUser, &'static str> {
43
let did_from_token = get_did_from_token(token).ok();
44
45
if let Some(ref did) = did_from_token {
···
56
.flatten()
57
{
58
if !allow_deactivated && user.deactivated_at.is_some() {
59
-
return Err("AccountDeactivated");
60
}
61
if user.takedown_ref.is_some() {
62
-
return Err("AccountTakedown");
63
}
64
65
-
let decrypted_key = match crate::config::decrypt_key(&user.key_bytes, user.encryption_version) {
66
-
Ok(k) => k,
67
-
Err(_) => return Err("KeyDecryptionFailed"),
68
-
};
69
70
if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
71
let session_exists = sqlx::query_scalar!(
···
103
.flatten()
104
{
105
if !allow_deactivated && oauth_token.deactivated_at.is_some() {
106
-
return Err("AccountDeactivated");
107
}
108
if oauth_token.takedown_ref.is_some() {
109
-
return Err("AccountTakedown");
110
}
111
112
let now = chrono::Utc::now();
···
120
}
121
}
122
123
-
Err("AuthenticationFailed")
124
}
125
126
#[derive(Debug, Serialize, Deserialize)]
···
1
use serde::{Deserialize, Serialize};
2
use sqlx::PgPool;
3
+
use std::fmt;
4
5
pub mod extractor;
6
pub mod token;
7
pub mod verify;
8
9
+
pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header};
10
pub use token::{
11
create_access_token, create_refresh_token, create_service_token,
12
create_access_token_with_metadata, create_refresh_token_with_metadata,
···
15
SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
16
};
17
pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
18
+
19
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
20
+
pub enum TokenValidationError {
21
+
AccountDeactivated,
22
+
AccountTakedown,
23
+
KeyDecryptionFailed,
24
+
AuthenticationFailed,
25
+
}
26
+
27
+
impl fmt::Display for TokenValidationError {
28
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29
+
match self {
30
+
Self::AccountDeactivated => write!(f, "AccountDeactivated"),
31
+
Self::AccountTakedown => write!(f, "AccountTakedown"),
32
+
Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"),
33
+
Self::AuthenticationFailed => write!(f, "AuthenticationFailed"),
34
+
}
35
+
}
36
+
}
37
38
pub struct AuthenticatedUser {
39
pub did: String,
···
44
pub async fn validate_bearer_token(
45
db: &PgPool,
46
token: &str,
47
+
) -> Result<AuthenticatedUser, TokenValidationError> {
48
validate_bearer_token_with_options(db, token, false).await
49
}
50
51
pub async fn validate_bearer_token_allow_deactivated(
52
db: &PgPool,
53
token: &str,
54
+
) -> Result<AuthenticatedUser, TokenValidationError> {
55
validate_bearer_token_with_options(db, token, true).await
56
}
57
···
59
db: &PgPool,
60
token: &str,
61
allow_deactivated: bool,
62
+
) -> Result<AuthenticatedUser, TokenValidationError> {
63
let did_from_token = get_did_from_token(token).ok();
64
65
if let Some(ref did) = did_from_token {
···
76
.flatten()
77
{
78
if !allow_deactivated && user.deactivated_at.is_some() {
79
+
return Err(TokenValidationError::AccountDeactivated);
80
}
81
if user.takedown_ref.is_some() {
82
+
return Err(TokenValidationError::AccountTakedown);
83
}
84
85
+
let decrypted_key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
86
+
.map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
87
88
if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
89
let session_exists = sqlx::query_scalar!(
···
121
.flatten()
122
{
123
if !allow_deactivated && oauth_token.deactivated_at.is_some() {
124
+
return Err(TokenValidationError::AccountDeactivated);
125
}
126
if oauth_token.takedown_ref.is_some() {
127
+
return Err(TokenValidationError::AccountTakedown);
128
}
129
130
let now = chrono::Utc::now();
···
138
}
139
}
140
141
+
Err(TokenValidationError::AuthenticationFailed)
142
}
143
144
#[derive(Debug, Serialize, Deserialize)]
+7
-3
src/config.rs
+7
-3
src/config.rs
···
62
let seed = hasher.finalize();
63
64
let signing_key = SigningKey::from_slice(&seed)
65
-
.expect("Failed to create signing key from seed");
66
67
let verifying_key = signing_key.verifying_key();
68
let point = verifying_key.to_encoded_point(false);
69
70
-
let signing_key_x = URL_SAFE_NO_PAD.encode(point.x().unwrap());
71
-
let signing_key_y = URL_SAFE_NO_PAD.encode(point.y().unwrap());
72
73
let mut kid_hasher = Sha256::new();
74
kid_hasher.update(signing_key_x.as_bytes());
···
62
let seed = hasher.finalize();
63
64
let signing_key = SigningKey::from_slice(&seed)
65
+
.unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e));
66
67
let verifying_key = signing_key.verifying_key();
68
let point = verifying_key.to_encoded_point(false);
69
70
+
let signing_key_x = URL_SAFE_NO_PAD.encode(
71
+
point.x().expect("EC point missing X coordinate - this should never happen")
72
+
);
73
+
let signing_key_y = URL_SAFE_NO_PAD.encode(
74
+
point.y().expect("EC point missing Y coordinate - this should never happen")
75
+
);
76
77
let mut kid_hasher = Sha256::new();
78
kid_hasher.update(signing_key_x.as_bytes());
+1
src/lib.rs
+1
src/lib.rs
+43
-15
src/main.rs
+43
-15
src/main.rs
···
1
use bspds::notifications::{EmailSender, NotificationService};
2
use bspds::state::AppState;
3
use std::net::SocketAddr;
4
use tokio::sync::watch;
5
-
use tracing::{info, warn};
6
7
#[tokio::main]
8
-
async fn main() {
9
dotenvy::dotenv().ok();
10
tracing_subscriber::fmt::init();
11
12
-
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
13
14
let pool = sqlx::postgres::PgPoolOptions::new()
15
-
.max_connections(5)
16
.connect(&database_url)
17
.await
18
-
.expect("Failed to connect to Postgres");
19
20
sqlx::migrate!("./migrations")
21
.run(&pool)
22
.await
23
-
.expect("Failed to run migrations");
24
25
let state = AppState::new(pool.clone()).await;
26
···
50
51
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
52
info!("listening on {}", addr);
53
-
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
54
55
let server_result = axum::serve(listener, app)
56
.with_graceful_shutdown(shutdown_signal(shutdown_tx))
···
59
notification_handle.await.ok();
60
61
if let Err(e) = server_result {
62
-
tracing::error!("Server error: {}", e);
63
}
64
}
65
66
async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) {
67
let ctrl_c = async {
68
-
tokio::signal::ctrl_c()
69
-
.await
70
-
.expect("Failed to install Ctrl+C handler");
71
};
72
73
#[cfg(unix)]
74
let terminate = async {
75
-
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
76
-
.expect("Failed to install signal handler")
77
-
.recv()
78
-
.await;
79
};
80
81
#[cfg(not(unix))]
···
1
use bspds::notifications::{EmailSender, NotificationService};
2
use bspds::state::AppState;
3
use std::net::SocketAddr;
4
+
use std::process::ExitCode;
5
use tokio::sync::watch;
6
+
use tracing::{error, info, warn};
7
8
#[tokio::main]
9
+
async fn main() -> ExitCode {
10
dotenvy::dotenv().ok();
11
tracing_subscriber::fmt::init();
12
13
+
match run().await {
14
+
Ok(()) => ExitCode::SUCCESS,
15
+
Err(e) => {
16
+
error!("Fatal error: {}", e);
17
+
ExitCode::FAILURE
18
+
}
19
+
}
20
+
}
21
+
22
+
async fn run() -> Result<(), Box<dyn std::error::Error>> {
23
+
let database_url = std::env::var("DATABASE_URL")
24
+
.map_err(|_| "DATABASE_URL environment variable must be set")?;
25
26
let pool = sqlx::postgres::PgPoolOptions::new()
27
+
.max_connections(20)
28
+
.min_connections(2)
29
+
.acquire_timeout(std::time::Duration::from_secs(10))
30
+
.idle_timeout(std::time::Duration::from_secs(300))
31
+
.max_lifetime(std::time::Duration::from_secs(1800))
32
.connect(&database_url)
33
.await
34
+
.map_err(|e| format!("Failed to connect to Postgres: {}", e))?;
35
36
sqlx::migrate!("./migrations")
37
.run(&pool)
38
.await
39
+
.map_err(|e| format!("Failed to run migrations: {}", e))?;
40
41
let state = AppState::new(pool.clone()).await;
42
···
66
67
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
68
info!("listening on {}", addr);
69
+
let listener = tokio::net::TcpListener::bind(addr)
70
+
.await
71
+
.map_err(|e| format!("Failed to bind to {}: {}", addr, e))?;
72
73
let server_result = axum::serve(listener, app)
74
.with_graceful_shutdown(shutdown_signal(shutdown_tx))
···
77
notification_handle.await.ok();
78
79
if let Err(e) = server_result {
80
+
return Err(format!("Server error: {}", e).into());
81
}
82
+
83
+
Ok(())
84
}
85
86
async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) {
87
let ctrl_c = async {
88
+
match tokio::signal::ctrl_c().await {
89
+
Ok(()) => {}
90
+
Err(e) => {
91
+
error!("Failed to install Ctrl+C handler: {}", e);
92
+
}
93
+
}
94
};
95
96
#[cfg(unix)]
97
let terminate = async {
98
+
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
99
+
Ok(mut signal) => {
100
+
signal.recv().await;
101
+
}
102
+
Err(e) => {
103
+
error!("Failed to install SIGTERM handler: {}", e);
104
+
std::future::pending::<()>().await;
105
+
}
106
+
}
107
};
108
109
#[cfg(not(unix))]
-641
src/oauth/db.rs
-641
src/oauth/db.rs
···
1
-
use chrono::{DateTime, Utc};
2
-
use serde::{de::DeserializeOwned, Serialize};
3
-
use sqlx::PgPool;
4
-
5
-
use super::{
6
-
AuthorizationRequestParameters, ClientAuth, DeviceData, OAuthError, RequestData, TokenData,
7
-
AuthorizedClientData,
8
-
};
9
-
10
-
fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> {
11
-
serde_json::to_value(value).map_err(|e| {
12
-
tracing::error!("JSON serialization error: {}", e);
13
-
OAuthError::ServerError("Internal serialization error".to_string())
14
-
})
15
-
}
16
-
17
-
fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> {
18
-
serde_json::from_value(value).map_err(|e| {
19
-
tracing::error!("JSON deserialization error: {}", e);
20
-
OAuthError::ServerError("Internal data corruption".to_string())
21
-
})
22
-
}
23
-
24
-
pub async fn create_device(
25
-
pool: &PgPool,
26
-
device_id: &str,
27
-
data: &DeviceData,
28
-
) -> Result<(), OAuthError> {
29
-
sqlx::query!(
30
-
r#"
31
-
INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at)
32
-
VALUES ($1, $2, $3, $4, $5)
33
-
"#,
34
-
device_id,
35
-
data.session_id,
36
-
data.user_agent,
37
-
data.ip_address,
38
-
data.last_seen_at,
39
-
)
40
-
.execute(pool)
41
-
.await?;
42
-
43
-
Ok(())
44
-
}
45
-
46
-
pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> {
47
-
let row = sqlx::query!(
48
-
r#"
49
-
SELECT session_id, user_agent, ip_address, last_seen_at
50
-
FROM oauth_device
51
-
WHERE id = $1
52
-
"#,
53
-
device_id
54
-
)
55
-
.fetch_optional(pool)
56
-
.await?;
57
-
58
-
Ok(row.map(|r| DeviceData {
59
-
session_id: r.session_id,
60
-
user_agent: r.user_agent,
61
-
ip_address: r.ip_address,
62
-
last_seen_at: r.last_seen_at,
63
-
}))
64
-
}
65
-
66
-
pub async fn update_device_last_seen(
67
-
pool: &PgPool,
68
-
device_id: &str,
69
-
) -> Result<(), OAuthError> {
70
-
sqlx::query!(
71
-
r#"
72
-
UPDATE oauth_device
73
-
SET last_seen_at = NOW()
74
-
WHERE id = $1
75
-
"#,
76
-
device_id
77
-
)
78
-
.execute(pool)
79
-
.await?;
80
-
81
-
Ok(())
82
-
}
83
-
84
-
pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
85
-
sqlx::query!(
86
-
r#"
87
-
DELETE FROM oauth_device WHERE id = $1
88
-
"#,
89
-
device_id
90
-
)
91
-
.execute(pool)
92
-
.await?;
93
-
94
-
Ok(())
95
-
}
96
-
97
-
pub async fn create_authorization_request(
98
-
pool: &PgPool,
99
-
request_id: &str,
100
-
data: &RequestData,
101
-
) -> Result<(), OAuthError> {
102
-
let client_auth_json = match &data.client_auth {
103
-
Some(ca) => Some(to_json(ca)?),
104
-
None => None,
105
-
};
106
-
let parameters_json = to_json(&data.parameters)?;
107
-
108
-
sqlx::query!(
109
-
r#"
110
-
INSERT INTO oauth_authorization_request
111
-
(id, did, device_id, client_id, client_auth, parameters, expires_at, code)
112
-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
113
-
"#,
114
-
request_id,
115
-
data.did,
116
-
data.device_id,
117
-
data.client_id,
118
-
client_auth_json,
119
-
parameters_json,
120
-
data.expires_at,
121
-
data.code,
122
-
)
123
-
.execute(pool)
124
-
.await?;
125
-
126
-
Ok(())
127
-
}
128
-
129
-
pub async fn get_authorization_request(
130
-
pool: &PgPool,
131
-
request_id: &str,
132
-
) -> Result<Option<RequestData>, OAuthError> {
133
-
let row = sqlx::query!(
134
-
r#"
135
-
SELECT did, device_id, client_id, client_auth, parameters, expires_at, code
136
-
FROM oauth_authorization_request
137
-
WHERE id = $1
138
-
"#,
139
-
request_id
140
-
)
141
-
.fetch_optional(pool)
142
-
.await?;
143
-
144
-
match row {
145
-
Some(r) => {
146
-
let client_auth: Option<ClientAuth> = match r.client_auth {
147
-
Some(v) => Some(from_json(v)?),
148
-
None => None,
149
-
};
150
-
let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
151
-
152
-
Ok(Some(RequestData {
153
-
client_id: r.client_id,
154
-
client_auth,
155
-
parameters,
156
-
expires_at: r.expires_at,
157
-
did: r.did,
158
-
device_id: r.device_id,
159
-
code: r.code,
160
-
}))
161
-
}
162
-
None => Ok(None),
163
-
}
164
-
}
165
-
166
-
pub async fn update_authorization_request(
167
-
pool: &PgPool,
168
-
request_id: &str,
169
-
did: &str,
170
-
device_id: Option<&str>,
171
-
code: &str,
172
-
) -> Result<(), OAuthError> {
173
-
sqlx::query!(
174
-
r#"
175
-
UPDATE oauth_authorization_request
176
-
SET did = $2, device_id = $3, code = $4
177
-
WHERE id = $1
178
-
"#,
179
-
request_id,
180
-
did,
181
-
device_id,
182
-
code
183
-
)
184
-
.execute(pool)
185
-
.await?;
186
-
187
-
Ok(())
188
-
}
189
-
190
-
pub async fn consume_authorization_request_by_code(
191
-
pool: &PgPool,
192
-
code: &str,
193
-
) -> Result<Option<RequestData>, OAuthError> {
194
-
let row = sqlx::query!(
195
-
r#"
196
-
DELETE FROM oauth_authorization_request
197
-
WHERE code = $1
198
-
RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code
199
-
"#,
200
-
code
201
-
)
202
-
.fetch_optional(pool)
203
-
.await?;
204
-
205
-
match row {
206
-
Some(r) => {
207
-
let client_auth: Option<ClientAuth> = match r.client_auth {
208
-
Some(v) => Some(from_json(v)?),
209
-
None => None,
210
-
};
211
-
let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
212
-
213
-
Ok(Some(RequestData {
214
-
client_id: r.client_id,
215
-
client_auth,
216
-
parameters,
217
-
expires_at: r.expires_at,
218
-
did: r.did,
219
-
device_id: r.device_id,
220
-
code: r.code,
221
-
}))
222
-
}
223
-
None => Ok(None),
224
-
}
225
-
}
226
-
227
-
pub async fn delete_authorization_request(
228
-
pool: &PgPool,
229
-
request_id: &str,
230
-
) -> Result<(), OAuthError> {
231
-
sqlx::query!(
232
-
r#"
233
-
DELETE FROM oauth_authorization_request WHERE id = $1
234
-
"#,
235
-
request_id
236
-
)
237
-
.execute(pool)
238
-
.await?;
239
-
240
-
Ok(())
241
-
}
242
-
243
-
pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> {
244
-
let result = sqlx::query!(
245
-
r#"
246
-
DELETE FROM oauth_authorization_request
247
-
WHERE expires_at < NOW()
248
-
"#
249
-
)
250
-
.execute(pool)
251
-
.await?;
252
-
253
-
Ok(result.rows_affected())
254
-
}
255
-
256
-
pub async fn create_token(
257
-
pool: &PgPool,
258
-
data: &TokenData,
259
-
) -> Result<i32, OAuthError> {
260
-
let client_auth_json = to_json(&data.client_auth)?;
261
-
let parameters_json = to_json(&data.parameters)?;
262
-
263
-
let row = sqlx::query!(
264
-
r#"
265
-
INSERT INTO oauth_token
266
-
(did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
267
-
device_id, parameters, details, code, current_refresh_token, scope)
268
-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
269
-
RETURNING id
270
-
"#,
271
-
data.did,
272
-
data.token_id,
273
-
data.created_at,
274
-
data.updated_at,
275
-
data.expires_at,
276
-
data.client_id,
277
-
client_auth_json,
278
-
data.device_id,
279
-
parameters_json,
280
-
data.details,
281
-
data.code,
282
-
data.current_refresh_token,
283
-
data.scope,
284
-
)
285
-
.fetch_one(pool)
286
-
.await?;
287
-
288
-
Ok(row.id)
289
-
}
290
-
291
-
pub async fn get_token_by_id(
292
-
pool: &PgPool,
293
-
token_id: &str,
294
-
) -> Result<Option<TokenData>, OAuthError> {
295
-
let row = sqlx::query!(
296
-
r#"
297
-
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
298
-
device_id, parameters, details, code, current_refresh_token, scope
299
-
FROM oauth_token
300
-
WHERE token_id = $1
301
-
"#,
302
-
token_id
303
-
)
304
-
.fetch_optional(pool)
305
-
.await?;
306
-
307
-
match row {
308
-
Some(r) => Ok(Some(TokenData {
309
-
did: r.did,
310
-
token_id: r.token_id,
311
-
created_at: r.created_at,
312
-
updated_at: r.updated_at,
313
-
expires_at: r.expires_at,
314
-
client_id: r.client_id,
315
-
client_auth: from_json(r.client_auth)?,
316
-
device_id: r.device_id,
317
-
parameters: from_json(r.parameters)?,
318
-
details: r.details,
319
-
code: r.code,
320
-
current_refresh_token: r.current_refresh_token,
321
-
scope: r.scope,
322
-
})),
323
-
None => Ok(None),
324
-
}
325
-
}
326
-
327
-
pub async fn get_token_by_refresh_token(
328
-
pool: &PgPool,
329
-
refresh_token: &str,
330
-
) -> Result<Option<(i32, TokenData)>, OAuthError> {
331
-
let row = sqlx::query!(
332
-
r#"
333
-
SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
334
-
device_id, parameters, details, code, current_refresh_token, scope
335
-
FROM oauth_token
336
-
WHERE current_refresh_token = $1
337
-
"#,
338
-
refresh_token
339
-
)
340
-
.fetch_optional(pool)
341
-
.await?;
342
-
343
-
match row {
344
-
Some(r) => Ok(Some((
345
-
r.id,
346
-
TokenData {
347
-
did: r.did,
348
-
token_id: r.token_id,
349
-
created_at: r.created_at,
350
-
updated_at: r.updated_at,
351
-
expires_at: r.expires_at,
352
-
client_id: r.client_id,
353
-
client_auth: from_json(r.client_auth)?,
354
-
device_id: r.device_id,
355
-
parameters: from_json(r.parameters)?,
356
-
details: r.details,
357
-
code: r.code,
358
-
current_refresh_token: r.current_refresh_token,
359
-
scope: r.scope,
360
-
},
361
-
))),
362
-
None => Ok(None),
363
-
}
364
-
}
365
-
366
-
pub async fn rotate_token(
367
-
pool: &PgPool,
368
-
old_db_id: i32,
369
-
new_token_id: &str,
370
-
new_refresh_token: &str,
371
-
new_expires_at: DateTime<Utc>,
372
-
) -> Result<(), OAuthError> {
373
-
let mut tx = pool.begin().await?;
374
-
375
-
let old_refresh = sqlx::query_scalar!(
376
-
r#"
377
-
SELECT current_refresh_token FROM oauth_token WHERE id = $1
378
-
"#,
379
-
old_db_id
380
-
)
381
-
.fetch_one(&mut *tx)
382
-
.await?;
383
-
384
-
if let Some(old_rt) = old_refresh {
385
-
sqlx::query!(
386
-
r#"
387
-
INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
388
-
VALUES ($1, $2)
389
-
"#,
390
-
old_rt,
391
-
old_db_id
392
-
)
393
-
.execute(&mut *tx)
394
-
.await?;
395
-
}
396
-
397
-
sqlx::query!(
398
-
r#"
399
-
UPDATE oauth_token
400
-
SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
401
-
WHERE id = $1
402
-
"#,
403
-
old_db_id,
404
-
new_token_id,
405
-
new_refresh_token,
406
-
new_expires_at
407
-
)
408
-
.execute(&mut *tx)
409
-
.await?;
410
-
411
-
tx.commit().await?;
412
-
Ok(())
413
-
}
414
-
415
-
pub async fn check_refresh_token_used(
416
-
pool: &PgPool,
417
-
refresh_token: &str,
418
-
) -> Result<Option<i32>, OAuthError> {
419
-
let row = sqlx::query_scalar!(
420
-
r#"
421
-
SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
422
-
"#,
423
-
refresh_token
424
-
)
425
-
.fetch_optional(pool)
426
-
.await?;
427
-
428
-
Ok(row)
429
-
}
430
-
431
-
pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
432
-
sqlx::query!(
433
-
r#"
434
-
DELETE FROM oauth_token WHERE token_id = $1
435
-
"#,
436
-
token_id
437
-
)
438
-
.execute(pool)
439
-
.await?;
440
-
441
-
Ok(())
442
-
}
443
-
444
-
pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
445
-
sqlx::query!(
446
-
r#"
447
-
DELETE FROM oauth_token WHERE id = $1
448
-
"#,
449
-
db_id
450
-
)
451
-
.execute(pool)
452
-
.await?;
453
-
454
-
Ok(())
455
-
}
456
-
457
-
pub async fn upsert_account_device(
458
-
pool: &PgPool,
459
-
did: &str,
460
-
device_id: &str,
461
-
) -> Result<(), OAuthError> {
462
-
sqlx::query!(
463
-
r#"
464
-
INSERT INTO oauth_account_device (did, device_id, created_at, updated_at)
465
-
VALUES ($1, $2, NOW(), NOW())
466
-
ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW()
467
-
"#,
468
-
did,
469
-
device_id
470
-
)
471
-
.execute(pool)
472
-
.await?;
473
-
474
-
Ok(())
475
-
}
476
-
477
-
pub async fn upsert_authorized_client(
478
-
pool: &PgPool,
479
-
did: &str,
480
-
client_id: &str,
481
-
data: &AuthorizedClientData,
482
-
) -> Result<(), OAuthError> {
483
-
let data_json = to_json(data)?;
484
-
485
-
sqlx::query!(
486
-
r#"
487
-
INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data)
488
-
VALUES ($1, $2, NOW(), NOW(), $3)
489
-
ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3
490
-
"#,
491
-
did,
492
-
client_id,
493
-
data_json
494
-
)
495
-
.execute(pool)
496
-
.await?;
497
-
498
-
Ok(())
499
-
}
500
-
501
-
pub async fn get_authorized_client(
502
-
pool: &PgPool,
503
-
did: &str,
504
-
client_id: &str,
505
-
) -> Result<Option<AuthorizedClientData>, OAuthError> {
506
-
let row = sqlx::query_scalar!(
507
-
r#"
508
-
SELECT data FROM oauth_authorized_client
509
-
WHERE did = $1 AND client_id = $2
510
-
"#,
511
-
did,
512
-
client_id
513
-
)
514
-
.fetch_optional(pool)
515
-
.await?;
516
-
517
-
match row {
518
-
Some(v) => Ok(Some(from_json(v)?)),
519
-
None => Ok(None),
520
-
}
521
-
}
522
-
523
-
pub async fn list_tokens_for_user(
524
-
pool: &PgPool,
525
-
did: &str,
526
-
) -> Result<Vec<TokenData>, OAuthError> {
527
-
let rows = sqlx::query!(
528
-
r#"
529
-
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
530
-
device_id, parameters, details, code, current_refresh_token, scope
531
-
FROM oauth_token
532
-
WHERE did = $1
533
-
"#,
534
-
did
535
-
)
536
-
.fetch_all(pool)
537
-
.await?;
538
-
539
-
let mut tokens = Vec::with_capacity(rows.len());
540
-
for r in rows {
541
-
tokens.push(TokenData {
542
-
did: r.did,
543
-
token_id: r.token_id,
544
-
created_at: r.created_at,
545
-
updated_at: r.updated_at,
546
-
expires_at: r.expires_at,
547
-
client_id: r.client_id,
548
-
client_auth: from_json(r.client_auth)?,
549
-
device_id: r.device_id,
550
-
parameters: from_json(r.parameters)?,
551
-
details: r.details,
552
-
code: r.code,
553
-
current_refresh_token: r.current_refresh_token,
554
-
scope: r.scope,
555
-
});
556
-
}
557
-
Ok(tokens)
558
-
}
559
-
560
-
pub async fn check_and_record_dpop_jti(
561
-
pool: &PgPool,
562
-
jti: &str,
563
-
) -> Result<bool, OAuthError> {
564
-
let result = sqlx::query!(
565
-
r#"
566
-
INSERT INTO oauth_dpop_jti (jti)
567
-
VALUES ($1)
568
-
ON CONFLICT (jti) DO NOTHING
569
-
"#,
570
-
jti
571
-
)
572
-
.execute(pool)
573
-
.await?;
574
-
575
-
Ok(result.rows_affected() > 0)
576
-
}
577
-
578
-
pub async fn cleanup_expired_dpop_jtis(
579
-
pool: &PgPool,
580
-
max_age_secs: i64,
581
-
) -> Result<u64, OAuthError> {
582
-
let result = sqlx::query!(
583
-
r#"
584
-
DELETE FROM oauth_dpop_jti
585
-
WHERE created_at < NOW() - INTERVAL '1 second' * $1
586
-
"#,
587
-
max_age_secs as f64
588
-
)
589
-
.execute(pool)
590
-
.await?;
591
-
592
-
Ok(result.rows_affected())
593
-
}
594
-
595
-
pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
596
-
let count = sqlx::query_scalar!(
597
-
r#"
598
-
SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
599
-
"#,
600
-
did
601
-
)
602
-
.fetch_one(pool)
603
-
.await?;
604
-
605
-
Ok(count)
606
-
}
607
-
608
-
pub async fn delete_oldest_tokens_for_user(
609
-
pool: &PgPool,
610
-
did: &str,
611
-
keep_count: i64,
612
-
) -> Result<u64, OAuthError> {
613
-
let result = sqlx::query!(
614
-
r#"
615
-
DELETE FROM oauth_token
616
-
WHERE id IN (
617
-
SELECT id FROM oauth_token
618
-
WHERE did = $1
619
-
ORDER BY updated_at ASC
620
-
OFFSET $2
621
-
)
622
-
"#,
623
-
did,
624
-
keep_count
625
-
)
626
-
.execute(pool)
627
-
.await?;
628
-
629
-
Ok(result.rows_affected())
630
-
}
631
-
632
-
const MAX_TOKENS_PER_USER: i64 = 100;
633
-
634
-
pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
635
-
let count = count_tokens_for_user(pool, did).await?;
636
-
if count > MAX_TOKENS_PER_USER {
637
-
let to_keep = MAX_TOKENS_PER_USER - 1;
638
-
delete_oldest_tokens_for_user(pool, did, to_keep).await?;
639
-
}
640
-
Ok(())
641
-
}
···
+50
src/oauth/db/client.rs
+50
src/oauth/db/client.rs
···
···
1
+
use sqlx::PgPool;
2
+
3
+
use super::super::{AuthorizedClientData, OAuthError};
4
+
use super::helpers::{from_json, to_json};
5
+
6
+
pub async fn upsert_authorized_client(
7
+
pool: &PgPool,
8
+
did: &str,
9
+
client_id: &str,
10
+
data: &AuthorizedClientData,
11
+
) -> Result<(), OAuthError> {
12
+
let data_json = to_json(data)?;
13
+
14
+
sqlx::query!(
15
+
r#"
16
+
INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data)
17
+
VALUES ($1, $2, NOW(), NOW(), $3)
18
+
ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3
19
+
"#,
20
+
did,
21
+
client_id,
22
+
data_json
23
+
)
24
+
.execute(pool)
25
+
.await?;
26
+
27
+
Ok(())
28
+
}
29
+
30
+
pub async fn get_authorized_client(
31
+
pool: &PgPool,
32
+
did: &str,
33
+
client_id: &str,
34
+
) -> Result<Option<AuthorizedClientData>, OAuthError> {
35
+
let row = sqlx::query_scalar!(
36
+
r#"
37
+
SELECT data FROM oauth_authorized_client
38
+
WHERE did = $1 AND client_id = $2
39
+
"#,
40
+
did,
41
+
client_id
42
+
)
43
+
.fetch_optional(pool)
44
+
.await?;
45
+
46
+
match row {
47
+
Some(v) => Ok(Some(from_json(v)?)),
48
+
None => Ok(None),
49
+
}
50
+
}
+96
src/oauth/db/device.rs
+96
src/oauth/db/device.rs
···
···
1
+
use sqlx::PgPool;
2
+
3
+
use super::super::{DeviceData, OAuthError};
4
+
5
+
pub async fn create_device(
6
+
pool: &PgPool,
7
+
device_id: &str,
8
+
data: &DeviceData,
9
+
) -> Result<(), OAuthError> {
10
+
sqlx::query!(
11
+
r#"
12
+
INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at)
13
+
VALUES ($1, $2, $3, $4, $5)
14
+
"#,
15
+
device_id,
16
+
data.session_id,
17
+
data.user_agent,
18
+
data.ip_address,
19
+
data.last_seen_at,
20
+
)
21
+
.execute(pool)
22
+
.await?;
23
+
24
+
Ok(())
25
+
}
26
+
27
+
pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> {
28
+
let row = sqlx::query!(
29
+
r#"
30
+
SELECT session_id, user_agent, ip_address, last_seen_at
31
+
FROM oauth_device
32
+
WHERE id = $1
33
+
"#,
34
+
device_id
35
+
)
36
+
.fetch_optional(pool)
37
+
.await?;
38
+
39
+
Ok(row.map(|r| DeviceData {
40
+
session_id: r.session_id,
41
+
user_agent: r.user_agent,
42
+
ip_address: r.ip_address,
43
+
last_seen_at: r.last_seen_at,
44
+
}))
45
+
}
46
+
47
+
pub async fn update_device_last_seen(
48
+
pool: &PgPool,
49
+
device_id: &str,
50
+
) -> Result<(), OAuthError> {
51
+
sqlx::query!(
52
+
r#"
53
+
UPDATE oauth_device
54
+
SET last_seen_at = NOW()
55
+
WHERE id = $1
56
+
"#,
57
+
device_id
58
+
)
59
+
.execute(pool)
60
+
.await?;
61
+
62
+
Ok(())
63
+
}
64
+
65
+
pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
66
+
sqlx::query!(
67
+
r#"
68
+
DELETE FROM oauth_device WHERE id = $1
69
+
"#,
70
+
device_id
71
+
)
72
+
.execute(pool)
73
+
.await?;
74
+
75
+
Ok(())
76
+
}
77
+
78
+
pub async fn upsert_account_device(
79
+
pool: &PgPool,
80
+
did: &str,
81
+
device_id: &str,
82
+
) -> Result<(), OAuthError> {
83
+
sqlx::query!(
84
+
r#"
85
+
INSERT INTO oauth_account_device (did, device_id, created_at, updated_at)
86
+
VALUES ($1, $2, NOW(), NOW())
87
+
ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW()
88
+
"#,
89
+
did,
90
+
device_id
91
+
)
92
+
.execute(pool)
93
+
.await?;
94
+
95
+
Ok(())
96
+
}
+38
src/oauth/db/dpop.rs
+38
src/oauth/db/dpop.rs
···
···
1
+
use sqlx::PgPool;
2
+
3
+
use super::super::OAuthError;
4
+
5
+
pub async fn check_and_record_dpop_jti(
6
+
pool: &PgPool,
7
+
jti: &str,
8
+
) -> Result<bool, OAuthError> {
9
+
let result = sqlx::query!(
10
+
r#"
11
+
INSERT INTO oauth_dpop_jti (jti)
12
+
VALUES ($1)
13
+
ON CONFLICT (jti) DO NOTHING
14
+
"#,
15
+
jti
16
+
)
17
+
.execute(pool)
18
+
.await?;
19
+
20
+
Ok(result.rows_affected() > 0)
21
+
}
22
+
23
+
pub async fn cleanup_expired_dpop_jtis(
24
+
pool: &PgPool,
25
+
max_age_secs: i64,
26
+
) -> Result<u64, OAuthError> {
27
+
let result = sqlx::query!(
28
+
r#"
29
+
DELETE FROM oauth_dpop_jti
30
+
WHERE created_at < NOW() - INTERVAL '1 second' * $1
31
+
"#,
32
+
max_age_secs as f64
33
+
)
34
+
.execute(pool)
35
+
.await?;
36
+
37
+
Ok(result.rows_affected())
38
+
}
+17
src/oauth/db/helpers.rs
+17
src/oauth/db/helpers.rs
···
···
1
+
use serde::{de::DeserializeOwned, Serialize};
2
+
3
+
use super::super::OAuthError;
4
+
5
+
pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> {
6
+
serde_json::to_value(value).map_err(|e| {
7
+
tracing::error!("JSON serialization error: {}", e);
8
+
OAuthError::ServerError("Internal serialization error".to_string())
9
+
})
10
+
}
11
+
12
+
pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> {
13
+
serde_json::from_value(value).map_err(|e| {
14
+
tracing::error!("JSON deserialization error: {}", e);
15
+
OAuthError::ServerError("Internal data corruption".to_string())
16
+
})
17
+
}
+22
src/oauth/db/mod.rs
+22
src/oauth/db/mod.rs
···
···
1
+
mod client;
2
+
mod device;
3
+
mod dpop;
4
+
mod helpers;
5
+
mod request;
6
+
mod token;
7
+
8
+
pub use client::{get_authorized_client, upsert_authorized_client};
9
+
pub use device::{
10
+
create_device, delete_device, get_device, update_device_last_seen, upsert_account_device,
11
+
};
12
+
pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis};
13
+
pub use request::{
14
+
consume_authorization_request_by_code, create_authorization_request,
15
+
delete_authorization_request, delete_expired_authorization_requests, get_authorization_request,
16
+
update_authorization_request,
17
+
};
18
+
pub use token::{
19
+
check_refresh_token_used, count_tokens_for_user, create_token, delete_oldest_tokens_for_user,
20
+
delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id,
21
+
get_token_by_refresh_token, list_tokens_for_user, rotate_token,
22
+
};
+163
src/oauth/db/request.rs
+163
src/oauth/db/request.rs
···
···
1
+
use sqlx::PgPool;
2
+
3
+
use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData};
4
+
use super::helpers::{from_json, to_json};
5
+
6
+
pub async fn create_authorization_request(
7
+
pool: &PgPool,
8
+
request_id: &str,
9
+
data: &RequestData,
10
+
) -> Result<(), OAuthError> {
11
+
let client_auth_json = match &data.client_auth {
12
+
Some(ca) => Some(to_json(ca)?),
13
+
None => None,
14
+
};
15
+
let parameters_json = to_json(&data.parameters)?;
16
+
17
+
sqlx::query!(
18
+
r#"
19
+
INSERT INTO oauth_authorization_request
20
+
(id, did, device_id, client_id, client_auth, parameters, expires_at, code)
21
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
22
+
"#,
23
+
request_id,
24
+
data.did,
25
+
data.device_id,
26
+
data.client_id,
27
+
client_auth_json,
28
+
parameters_json,
29
+
data.expires_at,
30
+
data.code,
31
+
)
32
+
.execute(pool)
33
+
.await?;
34
+
35
+
Ok(())
36
+
}
37
+
38
+
pub async fn get_authorization_request(
39
+
pool: &PgPool,
40
+
request_id: &str,
41
+
) -> Result<Option<RequestData>, OAuthError> {
42
+
let row = sqlx::query!(
43
+
r#"
44
+
SELECT did, device_id, client_id, client_auth, parameters, expires_at, code
45
+
FROM oauth_authorization_request
46
+
WHERE id = $1
47
+
"#,
48
+
request_id
49
+
)
50
+
.fetch_optional(pool)
51
+
.await?;
52
+
53
+
match row {
54
+
Some(r) => {
55
+
let client_auth: Option<ClientAuth> = match r.client_auth {
56
+
Some(v) => Some(from_json(v)?),
57
+
None => None,
58
+
};
59
+
let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
60
+
61
+
Ok(Some(RequestData {
62
+
client_id: r.client_id,
63
+
client_auth,
64
+
parameters,
65
+
expires_at: r.expires_at,
66
+
did: r.did,
67
+
device_id: r.device_id,
68
+
code: r.code,
69
+
}))
70
+
}
71
+
None => Ok(None),
72
+
}
73
+
}
74
+
75
+
pub async fn update_authorization_request(
76
+
pool: &PgPool,
77
+
request_id: &str,
78
+
did: &str,
79
+
device_id: Option<&str>,
80
+
code: &str,
81
+
) -> Result<(), OAuthError> {
82
+
sqlx::query!(
83
+
r#"
84
+
UPDATE oauth_authorization_request
85
+
SET did = $2, device_id = $3, code = $4
86
+
WHERE id = $1
87
+
"#,
88
+
request_id,
89
+
did,
90
+
device_id,
91
+
code
92
+
)
93
+
.execute(pool)
94
+
.await?;
95
+
96
+
Ok(())
97
+
}
98
+
99
+
pub async fn consume_authorization_request_by_code(
100
+
pool: &PgPool,
101
+
code: &str,
102
+
) -> Result<Option<RequestData>, OAuthError> {
103
+
let row = sqlx::query!(
104
+
r#"
105
+
DELETE FROM oauth_authorization_request
106
+
WHERE code = $1
107
+
RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code
108
+
"#,
109
+
code
110
+
)
111
+
.fetch_optional(pool)
112
+
.await?;
113
+
114
+
match row {
115
+
Some(r) => {
116
+
let client_auth: Option<ClientAuth> = match r.client_auth {
117
+
Some(v) => Some(from_json(v)?),
118
+
None => None,
119
+
};
120
+
let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
121
+
122
+
Ok(Some(RequestData {
123
+
client_id: r.client_id,
124
+
client_auth,
125
+
parameters,
126
+
expires_at: r.expires_at,
127
+
did: r.did,
128
+
device_id: r.device_id,
129
+
code: r.code,
130
+
}))
131
+
}
132
+
None => Ok(None),
133
+
}
134
+
}
135
+
136
+
pub async fn delete_authorization_request(
137
+
pool: &PgPool,
138
+
request_id: &str,
139
+
) -> Result<(), OAuthError> {
140
+
sqlx::query!(
141
+
r#"
142
+
DELETE FROM oauth_authorization_request WHERE id = $1
143
+
"#,
144
+
request_id
145
+
)
146
+
.execute(pool)
147
+
.await?;
148
+
149
+
Ok(())
150
+
}
151
+
152
+
pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> {
153
+
let result = sqlx::query!(
154
+
r#"
155
+
DELETE FROM oauth_authorization_request
156
+
WHERE expires_at < NOW()
157
+
"#
158
+
)
159
+
.execute(pool)
160
+
.await?;
161
+
162
+
Ok(result.rows_affected())
163
+
}
+291
src/oauth/db/token.rs
+291
src/oauth/db/token.rs
···
···
1
+
use chrono::{DateTime, Utc};
2
+
use sqlx::PgPool;
3
+
4
+
use super::super::{OAuthError, TokenData};
5
+
use super::helpers::{from_json, to_json};
6
+
7
+
pub async fn create_token(
8
+
pool: &PgPool,
9
+
data: &TokenData,
10
+
) -> Result<i32, OAuthError> {
11
+
let client_auth_json = to_json(&data.client_auth)?;
12
+
let parameters_json = to_json(&data.parameters)?;
13
+
14
+
let row = sqlx::query!(
15
+
r#"
16
+
INSERT INTO oauth_token
17
+
(did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
18
+
device_id, parameters, details, code, current_refresh_token, scope)
19
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
20
+
RETURNING id
21
+
"#,
22
+
data.did,
23
+
data.token_id,
24
+
data.created_at,
25
+
data.updated_at,
26
+
data.expires_at,
27
+
data.client_id,
28
+
client_auth_json,
29
+
data.device_id,
30
+
parameters_json,
31
+
data.details,
32
+
data.code,
33
+
data.current_refresh_token,
34
+
data.scope,
35
+
)
36
+
.fetch_one(pool)
37
+
.await?;
38
+
39
+
Ok(row.id)
40
+
}
41
+
42
+
pub async fn get_token_by_id(
43
+
pool: &PgPool,
44
+
token_id: &str,
45
+
) -> Result<Option<TokenData>, OAuthError> {
46
+
let row = sqlx::query!(
47
+
r#"
48
+
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
49
+
device_id, parameters, details, code, current_refresh_token, scope
50
+
FROM oauth_token
51
+
WHERE token_id = $1
52
+
"#,
53
+
token_id
54
+
)
55
+
.fetch_optional(pool)
56
+
.await?;
57
+
58
+
match row {
59
+
Some(r) => Ok(Some(TokenData {
60
+
did: r.did,
61
+
token_id: r.token_id,
62
+
created_at: r.created_at,
63
+
updated_at: r.updated_at,
64
+
expires_at: r.expires_at,
65
+
client_id: r.client_id,
66
+
client_auth: from_json(r.client_auth)?,
67
+
device_id: r.device_id,
68
+
parameters: from_json(r.parameters)?,
69
+
details: r.details,
70
+
code: r.code,
71
+
current_refresh_token: r.current_refresh_token,
72
+
scope: r.scope,
73
+
})),
74
+
None => Ok(None),
75
+
}
76
+
}
77
+
78
+
pub async fn get_token_by_refresh_token(
79
+
pool: &PgPool,
80
+
refresh_token: &str,
81
+
) -> Result<Option<(i32, TokenData)>, OAuthError> {
82
+
let row = sqlx::query!(
83
+
r#"
84
+
SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
85
+
device_id, parameters, details, code, current_refresh_token, scope
86
+
FROM oauth_token
87
+
WHERE current_refresh_token = $1
88
+
"#,
89
+
refresh_token
90
+
)
91
+
.fetch_optional(pool)
92
+
.await?;
93
+
94
+
match row {
95
+
Some(r) => Ok(Some((
96
+
r.id,
97
+
TokenData {
98
+
did: r.did,
99
+
token_id: r.token_id,
100
+
created_at: r.created_at,
101
+
updated_at: r.updated_at,
102
+
expires_at: r.expires_at,
103
+
client_id: r.client_id,
104
+
client_auth: from_json(r.client_auth)?,
105
+
device_id: r.device_id,
106
+
parameters: from_json(r.parameters)?,
107
+
details: r.details,
108
+
code: r.code,
109
+
current_refresh_token: r.current_refresh_token,
110
+
scope: r.scope,
111
+
},
112
+
))),
113
+
None => Ok(None),
114
+
}
115
+
}
116
+
117
+
pub async fn rotate_token(
118
+
pool: &PgPool,
119
+
old_db_id: i32,
120
+
new_token_id: &str,
121
+
new_refresh_token: &str,
122
+
new_expires_at: DateTime<Utc>,
123
+
) -> Result<(), OAuthError> {
124
+
let mut tx = pool.begin().await?;
125
+
126
+
let old_refresh = sqlx::query_scalar!(
127
+
r#"
128
+
SELECT current_refresh_token FROM oauth_token WHERE id = $1
129
+
"#,
130
+
old_db_id
131
+
)
132
+
.fetch_one(&mut *tx)
133
+
.await?;
134
+
135
+
if let Some(old_rt) = old_refresh {
136
+
sqlx::query!(
137
+
r#"
138
+
INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
139
+
VALUES ($1, $2)
140
+
"#,
141
+
old_rt,
142
+
old_db_id
143
+
)
144
+
.execute(&mut *tx)
145
+
.await?;
146
+
}
147
+
148
+
sqlx::query!(
149
+
r#"
150
+
UPDATE oauth_token
151
+
SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
152
+
WHERE id = $1
153
+
"#,
154
+
old_db_id,
155
+
new_token_id,
156
+
new_refresh_token,
157
+
new_expires_at
158
+
)
159
+
.execute(&mut *tx)
160
+
.await?;
161
+
162
+
tx.commit().await?;
163
+
Ok(())
164
+
}
165
+
166
+
pub async fn check_refresh_token_used(
167
+
pool: &PgPool,
168
+
refresh_token: &str,
169
+
) -> Result<Option<i32>, OAuthError> {
170
+
let row = sqlx::query_scalar!(
171
+
r#"
172
+
SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
173
+
"#,
174
+
refresh_token
175
+
)
176
+
.fetch_optional(pool)
177
+
.await?;
178
+
179
+
Ok(row)
180
+
}
181
+
182
+
pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
183
+
sqlx::query!(
184
+
r#"
185
+
DELETE FROM oauth_token WHERE token_id = $1
186
+
"#,
187
+
token_id
188
+
)
189
+
.execute(pool)
190
+
.await?;
191
+
192
+
Ok(())
193
+
}
194
+
195
+
pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
196
+
sqlx::query!(
197
+
r#"
198
+
DELETE FROM oauth_token WHERE id = $1
199
+
"#,
200
+
db_id
201
+
)
202
+
.execute(pool)
203
+
.await?;
204
+
205
+
Ok(())
206
+
}
207
+
208
+
pub async fn list_tokens_for_user(
209
+
pool: &PgPool,
210
+
did: &str,
211
+
) -> Result<Vec<TokenData>, OAuthError> {
212
+
let rows = sqlx::query!(
213
+
r#"
214
+
SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
215
+
device_id, parameters, details, code, current_refresh_token, scope
216
+
FROM oauth_token
217
+
WHERE did = $1
218
+
"#,
219
+
did
220
+
)
221
+
.fetch_all(pool)
222
+
.await?;
223
+
224
+
let mut tokens = Vec::with_capacity(rows.len());
225
+
for r in rows {
226
+
tokens.push(TokenData {
227
+
did: r.did,
228
+
token_id: r.token_id,
229
+
created_at: r.created_at,
230
+
updated_at: r.updated_at,
231
+
expires_at: r.expires_at,
232
+
client_id: r.client_id,
233
+
client_auth: from_json(r.client_auth)?,
234
+
device_id: r.device_id,
235
+
parameters: from_json(r.parameters)?,
236
+
details: r.details,
237
+
code: r.code,
238
+
current_refresh_token: r.current_refresh_token,
239
+
scope: r.scope,
240
+
});
241
+
}
242
+
Ok(tokens)
243
+
}
244
+
245
+
pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
246
+
let count = sqlx::query_scalar!(
247
+
r#"
248
+
SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
249
+
"#,
250
+
did
251
+
)
252
+
.fetch_one(pool)
253
+
.await?;
254
+
255
+
Ok(count)
256
+
}
257
+
258
+
pub async fn delete_oldest_tokens_for_user(
259
+
pool: &PgPool,
260
+
did: &str,
261
+
keep_count: i64,
262
+
) -> Result<u64, OAuthError> {
263
+
let result = sqlx::query!(
264
+
r#"
265
+
DELETE FROM oauth_token
266
+
WHERE id IN (
267
+
SELECT id FROM oauth_token
268
+
WHERE did = $1
269
+
ORDER BY updated_at ASC
270
+
OFFSET $2
271
+
)
272
+
"#,
273
+
did,
274
+
keep_count
275
+
)
276
+
.execute(pool)
277
+
.await?;
278
+
279
+
Ok(result.rows_affected())
280
+
}
281
+
282
+
const MAX_TOKENS_PER_USER: i64 = 100;
283
+
284
+
pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
285
+
let count = count_tokens_for_user(pool, did).await?;
286
+
if count > MAX_TOKENS_PER_USER {
287
+
let to_keep = MAX_TOKENS_PER_USER - 1;
288
+
delete_oldest_tokens_for_user(pool, did, to_keep).await?;
289
+
}
290
+
Ok(())
291
+
}
+8
-10
src/oauth/dpop.rs
+8
-10
src/oauth/dpop.rs
···
237
false,
238
);
239
240
-
let affine = AffinePoint::from_encoded_point(&point);
241
-
if affine.is_none().into() {
242
-
return Err(OAuthError::InvalidDpopProof("Invalid EC point".to_string()));
243
-
}
244
245
-
let verifying_key = VerifyingKey::from_affine(affine.unwrap())
246
.map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?;
247
248
let sig = Signature::from_slice(signature)
···
287
false,
288
);
289
290
-
let affine = AffinePoint::from_encoded_point(&point);
291
-
if affine.is_none().into() {
292
-
return Err(OAuthError::InvalidDpopProof("Invalid EC point".to_string()));
293
-
}
294
295
-
let verifying_key = VerifyingKey::from_affine(affine.unwrap())
296
.map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?;
297
298
let sig = Signature::from_slice(signature)
···
237
false,
238
);
239
240
+
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
241
+
let affine = affine_opt
242
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
243
244
+
let verifying_key = VerifyingKey::from_affine(affine)
245
.map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?;
246
247
let sig = Signature::from_slice(signature)
···
286
false,
287
);
288
289
+
let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into();
290
+
let affine = affine_opt
291
+
.ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?;
292
293
+
let verifying_key = VerifyingKey::from_affine(affine)
294
.map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?;
295
296
let sig = Signature::from_slice(signature)
-558
src/oauth/endpoints/token.rs
-558
src/oauth/endpoints/token.rs
···
1
-
use axum::{
2
-
Form, Json,
3
-
extract::State,
4
-
http::{HeaderMap, StatusCode},
5
-
};
6
-
use base64::Engine;
7
-
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
8
-
use chrono::{Duration, Utc};
9
-
use hmac::Mac;
10
-
use serde::{Deserialize, Serialize};
11
-
use sha2::{Digest, Sha256};
12
-
use subtle::ConstantTimeEq;
13
-
14
-
use crate::config::AuthConfig;
15
-
use crate::state::AppState;
16
-
use crate::oauth::{
17
-
ClientAuth, OAuthError, RefreshToken, TokenData, TokenId,
18
-
client::{ClientMetadataCache, verify_client_auth},
19
-
db,
20
-
dpop::DPoPVerifier,
21
-
};
22
-
23
-
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
24
-
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
25
-
26
-
#[derive(Debug, Deserialize)]
27
-
pub struct TokenRequest {
28
-
pub grant_type: String,
29
-
#[serde(default)]
30
-
pub code: Option<String>,
31
-
#[serde(default)]
32
-
pub redirect_uri: Option<String>,
33
-
#[serde(default)]
34
-
pub code_verifier: Option<String>,
35
-
#[serde(default)]
36
-
pub refresh_token: Option<String>,
37
-
#[serde(default)]
38
-
pub client_id: Option<String>,
39
-
#[serde(default)]
40
-
pub client_secret: Option<String>,
41
-
#[serde(default)]
42
-
pub client_assertion: Option<String>,
43
-
#[serde(default)]
44
-
pub client_assertion_type: Option<String>,
45
-
}
46
-
47
-
#[derive(Debug, Serialize)]
48
-
pub struct TokenResponse {
49
-
pub access_token: String,
50
-
pub token_type: String,
51
-
pub expires_in: u64,
52
-
#[serde(skip_serializing_if = "Option::is_none")]
53
-
pub refresh_token: Option<String>,
54
-
#[serde(skip_serializing_if = "Option::is_none")]
55
-
pub scope: Option<String>,
56
-
#[serde(skip_serializing_if = "Option::is_none")]
57
-
pub sub: Option<String>,
58
-
}
59
-
60
-
pub async fn token_endpoint(
61
-
State(state): State<AppState>,
62
-
headers: HeaderMap,
63
-
Form(request): Form<TokenRequest>,
64
-
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
65
-
let dpop_proof = headers
66
-
.get("DPoP")
67
-
.and_then(|v| v.to_str().ok())
68
-
.map(|s| s.to_string());
69
-
70
-
match request.grant_type.as_str() {
71
-
"authorization_code" => {
72
-
handle_authorization_code_grant(state, headers, request, dpop_proof).await
73
-
}
74
-
"refresh_token" => {
75
-
handle_refresh_token_grant(state, headers, request, dpop_proof).await
76
-
}
77
-
_ => Err(OAuthError::UnsupportedGrantType(format!(
78
-
"Unsupported grant_type: {}",
79
-
request.grant_type
80
-
))),
81
-
}
82
-
}
83
-
84
-
async fn handle_authorization_code_grant(
85
-
state: AppState,
86
-
_headers: HeaderMap,
87
-
request: TokenRequest,
88
-
dpop_proof: Option<String>,
89
-
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
90
-
let code = request
91
-
.code
92
-
.ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
93
-
94
-
let code_verifier = request
95
-
.code_verifier
96
-
.ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?;
97
-
98
-
let auth_request = db::consume_authorization_request_by_code(&state.db, &code)
99
-
.await?
100
-
.ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
101
-
102
-
if auth_request.expires_at < Utc::now() {
103
-
return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string()));
104
-
}
105
-
106
-
if let Some(request_client_id) = &request.client_id {
107
-
if request_client_id != &auth_request.client_id {
108
-
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
109
-
}
110
-
}
111
-
112
-
let did = auth_request
113
-
.did
114
-
.ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
115
-
116
-
let client_metadata_cache = ClientMetadataCache::new(3600);
117
-
let client_metadata = client_metadata_cache
118
-
.get(&auth_request.client_id)
119
-
.await?;
120
-
let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None);
121
-
verify_client_auth(&client_metadata, &client_auth)?;
122
-
123
-
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
124
-
125
-
if let Some(redirect_uri) = &request.redirect_uri {
126
-
if redirect_uri != &auth_request.parameters.redirect_uri {
127
-
return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string()));
128
-
}
129
-
}
130
-
131
-
let dpop_jkt = if let Some(proof) = &dpop_proof {
132
-
let config = AuthConfig::get();
133
-
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
134
-
135
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
136
-
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
137
-
138
-
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
139
-
140
-
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
141
-
return Err(OAuthError::InvalidDpopProof(
142
-
"DPoP proof has already been used".to_string(),
143
-
));
144
-
}
145
-
146
-
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt {
147
-
if &result.jkt != expected_jkt {
148
-
return Err(OAuthError::InvalidDpopProof(
149
-
"DPoP key binding mismatch".to_string(),
150
-
));
151
-
}
152
-
}
153
-
154
-
Some(result.jkt)
155
-
} else if auth_request.parameters.dpop_jkt.is_some() {
156
-
return Err(OAuthError::InvalidRequest(
157
-
"DPoP proof required for this authorization".to_string(),
158
-
));
159
-
} else {
160
-
None
161
-
};
162
-
163
-
let token_id = TokenId::generate();
164
-
let refresh_token = RefreshToken::generate();
165
-
let now = Utc::now();
166
-
167
-
let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?;
168
-
169
-
let token_data = TokenData {
170
-
did: did.clone(),
171
-
token_id: token_id.0.clone(),
172
-
created_at: now,
173
-
updated_at: now,
174
-
expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS),
175
-
client_id: auth_request.client_id.clone(),
176
-
client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None),
177
-
device_id: auth_request.device_id,
178
-
parameters: auth_request.parameters.clone(),
179
-
details: None,
180
-
code: None,
181
-
current_refresh_token: Some(refresh_token.0.clone()),
182
-
scope: auth_request.parameters.scope.clone(),
183
-
};
184
-
185
-
db::create_token(&state.db, &token_data).await?;
186
-
187
-
tokio::spawn({
188
-
let pool = state.db.clone();
189
-
let did_clone = did.clone();
190
-
async move {
191
-
if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await {
192
-
tracing::warn!("Failed to enforce token limit for user: {:?}", e);
193
-
}
194
-
}
195
-
});
196
-
197
-
let mut response_headers = HeaderMap::new();
198
-
let config = AuthConfig::get();
199
-
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
200
-
response_headers.insert(
201
-
"DPoP-Nonce",
202
-
verifier.generate_nonce().parse().unwrap(),
203
-
);
204
-
205
-
Ok((
206
-
response_headers,
207
-
Json(TokenResponse {
208
-
access_token,
209
-
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
210
-
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
211
-
refresh_token: Some(refresh_token.0),
212
-
scope: auth_request.parameters.scope,
213
-
sub: Some(did),
214
-
}),
215
-
))
216
-
}
217
-
218
-
async fn handle_refresh_token_grant(
219
-
state: AppState,
220
-
_headers: HeaderMap,
221
-
request: TokenRequest,
222
-
dpop_proof: Option<String>,
223
-
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
224
-
let refresh_token_str = request
225
-
.refresh_token
226
-
.ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
227
-
228
-
if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? {
229
-
db::delete_token_family(&state.db, token_id).await?;
230
-
return Err(OAuthError::InvalidGrant(
231
-
"Refresh token reuse detected, token family revoked".to_string(),
232
-
));
233
-
}
234
-
235
-
let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str)
236
-
.await?
237
-
.ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?;
238
-
239
-
if token_data.expires_at < Utc::now() {
240
-
db::delete_token_family(&state.db, db_id).await?;
241
-
return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string()));
242
-
}
243
-
244
-
let dpop_jkt = if let Some(proof) = &dpop_proof {
245
-
let config = AuthConfig::get();
246
-
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
247
-
248
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
249
-
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
250
-
251
-
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
252
-
253
-
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
254
-
return Err(OAuthError::InvalidDpopProof(
255
-
"DPoP proof has already been used".to_string(),
256
-
));
257
-
}
258
-
259
-
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
260
-
if &result.jkt != expected_jkt {
261
-
return Err(OAuthError::InvalidDpopProof(
262
-
"DPoP key binding mismatch".to_string(),
263
-
));
264
-
}
265
-
}
266
-
267
-
Some(result.jkt)
268
-
} else if token_data.parameters.dpop_jkt.is_some() {
269
-
return Err(OAuthError::InvalidRequest(
270
-
"DPoP proof required".to_string(),
271
-
));
272
-
} else {
273
-
None
274
-
};
275
-
276
-
let new_token_id = TokenId::generate();
277
-
let new_refresh_token = RefreshToken::generate();
278
-
let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS);
279
-
280
-
db::rotate_token(
281
-
&state.db,
282
-
db_id,
283
-
&new_token_id.0,
284
-
&new_refresh_token.0,
285
-
new_expires_at,
286
-
)
287
-
.await?;
288
-
289
-
let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?;
290
-
291
-
let mut response_headers = HeaderMap::new();
292
-
let config = AuthConfig::get();
293
-
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
294
-
response_headers.insert(
295
-
"DPoP-Nonce",
296
-
verifier.generate_nonce().parse().unwrap(),
297
-
);
298
-
299
-
Ok((
300
-
response_headers,
301
-
Json(TokenResponse {
302
-
access_token,
303
-
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
304
-
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
305
-
refresh_token: Some(new_refresh_token.0),
306
-
scope: token_data.scope,
307
-
sub: Some(token_data.did),
308
-
}),
309
-
))
310
-
}
311
-
312
-
fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> {
313
-
use subtle::ConstantTimeEq;
314
-
315
-
let mut hasher = Sha256::new();
316
-
hasher.update(code_verifier.as_bytes());
317
-
let hash = hasher.finalize();
318
-
let computed_challenge = URL_SAFE_NO_PAD.encode(&hash);
319
-
320
-
if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) {
321
-
return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string()));
322
-
}
323
-
324
-
Ok(())
325
-
}
326
-
327
-
fn create_access_token(
328
-
token_id: &str,
329
-
sub: &str,
330
-
dpop_jkt: Option<&str>,
331
-
) -> Result<String, OAuthError> {
332
-
use serde_json::json;
333
-
334
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
335
-
let issuer = format!("https://{}", pds_hostname);
336
-
337
-
let now = Utc::now().timestamp();
338
-
let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
339
-
340
-
let mut payload = json!({
341
-
"iss": issuer,
342
-
"sub": sub,
343
-
"aud": issuer,
344
-
"iat": now,
345
-
"exp": exp,
346
-
"jti": token_id,
347
-
"scope": "atproto"
348
-
});
349
-
350
-
if let Some(jkt) = dpop_jkt {
351
-
payload["cnf"] = json!({ "jkt": jkt });
352
-
}
353
-
354
-
let header = json!({
355
-
"alg": "HS256",
356
-
"typ": "at+jwt"
357
-
});
358
-
359
-
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
360
-
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
361
-
362
-
let signing_input = format!("{}.{}", header_b64, payload_b64);
363
-
364
-
let config = AuthConfig::get();
365
-
366
-
use sha2::Sha256 as HmacSha256;
367
-
use hmac::{Hmac, Mac};
368
-
type HmacSha256Type = Hmac<HmacSha256>;
369
-
370
-
let mut mac = HmacSha256Type::new_from_slice(config.jwt_secret().as_bytes())
371
-
.map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?;
372
-
mac.update(signing_input.as_bytes());
373
-
let signature = mac.finalize().into_bytes();
374
-
375
-
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
376
-
377
-
Ok(format!("{}.{}", signing_input, signature_b64))
378
-
}
379
-
380
-
pub async fn revoke_token(
381
-
State(state): State<AppState>,
382
-
Form(request): Form<RevokeRequest>,
383
-
) -> Result<StatusCode, OAuthError> {
384
-
if let Some(token) = &request.token {
385
-
if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
386
-
db::delete_token_family(&state.db, db_id).await?;
387
-
} else {
388
-
db::delete_token(&state.db, token).await?;
389
-
}
390
-
}
391
-
392
-
Ok(StatusCode::OK)
393
-
}
394
-
395
-
#[derive(Debug, Deserialize)]
396
-
pub struct RevokeRequest {
397
-
pub token: Option<String>,
398
-
#[serde(default)]
399
-
pub token_type_hint: Option<String>,
400
-
}
401
-
402
-
#[derive(Debug, Deserialize)]
403
-
pub struct IntrospectRequest {
404
-
pub token: String,
405
-
#[serde(default)]
406
-
pub token_type_hint: Option<String>,
407
-
}
408
-
409
-
#[derive(Debug, Serialize)]
410
-
pub struct IntrospectResponse {
411
-
pub active: bool,
412
-
#[serde(skip_serializing_if = "Option::is_none")]
413
-
pub scope: Option<String>,
414
-
#[serde(skip_serializing_if = "Option::is_none")]
415
-
pub client_id: Option<String>,
416
-
#[serde(skip_serializing_if = "Option::is_none")]
417
-
pub username: Option<String>,
418
-
#[serde(skip_serializing_if = "Option::is_none")]
419
-
pub token_type: Option<String>,
420
-
#[serde(skip_serializing_if = "Option::is_none")]
421
-
pub exp: Option<i64>,
422
-
#[serde(skip_serializing_if = "Option::is_none")]
423
-
pub iat: Option<i64>,
424
-
#[serde(skip_serializing_if = "Option::is_none")]
425
-
pub nbf: Option<i64>,
426
-
#[serde(skip_serializing_if = "Option::is_none")]
427
-
pub sub: Option<String>,
428
-
#[serde(skip_serializing_if = "Option::is_none")]
429
-
pub aud: Option<String>,
430
-
#[serde(skip_serializing_if = "Option::is_none")]
431
-
pub iss: Option<String>,
432
-
#[serde(skip_serializing_if = "Option::is_none")]
433
-
pub jti: Option<String>,
434
-
}
435
-
436
-
pub async fn introspect_token(
437
-
State(state): State<AppState>,
438
-
Form(request): Form<IntrospectRequest>,
439
-
) -> Json<IntrospectResponse> {
440
-
let inactive_response = IntrospectResponse {
441
-
active: false,
442
-
scope: None,
443
-
client_id: None,
444
-
username: None,
445
-
token_type: None,
446
-
exp: None,
447
-
iat: None,
448
-
nbf: None,
449
-
sub: None,
450
-
aud: None,
451
-
iss: None,
452
-
jti: None,
453
-
};
454
-
455
-
let token_info = match extract_token_claims(&request.token) {
456
-
Ok(info) => info,
457
-
Err(_) => return Json(inactive_response),
458
-
};
459
-
460
-
let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
461
-
Ok(Some(data)) => data,
462
-
_ => return Json(inactive_response),
463
-
};
464
-
465
-
if token_data.expires_at < Utc::now() {
466
-
return Json(inactive_response);
467
-
}
468
-
469
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
470
-
let issuer = format!("https://{}", pds_hostname);
471
-
472
-
Json(IntrospectResponse {
473
-
active: true,
474
-
scope: token_data.scope,
475
-
client_id: Some(token_data.client_id),
476
-
username: None,
477
-
token_type: if token_data.parameters.dpop_jkt.is_some() {
478
-
Some("DPoP".to_string())
479
-
} else {
480
-
Some("Bearer".to_string())
481
-
},
482
-
exp: Some(token_info.exp),
483
-
iat: Some(token_info.iat),
484
-
nbf: Some(token_info.iat),
485
-
sub: Some(token_data.did),
486
-
aud: Some(issuer.clone()),
487
-
iss: Some(issuer),
488
-
jti: Some(token_info.jti),
489
-
})
490
-
}
491
-
492
-
struct TokenClaims {
493
-
jti: String,
494
-
exp: i64,
495
-
iat: i64,
496
-
}
497
-
498
-
fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
499
-
let parts: Vec<&str> = token.split('.').collect();
500
-
if parts.len() != 3 {
501
-
return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
502
-
}
503
-
504
-
let header_bytes = URL_SAFE_NO_PAD
505
-
.decode(parts[0])
506
-
.map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
507
-
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
508
-
.map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
509
-
510
-
if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
511
-
return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string()));
512
-
}
513
-
if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
514
-
return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string()));
515
-
}
516
-
517
-
let config = AuthConfig::get();
518
-
let secret = config.jwt_secret();
519
-
520
-
let signing_input = format!("{}.{}", parts[0], parts[1]);
521
-
let provided_sig = URL_SAFE_NO_PAD
522
-
.decode(parts[2])
523
-
.map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
524
-
525
-
type HmacSha256 = hmac::Hmac<Sha256>;
526
-
let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
527
-
.map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
528
-
mac.update(signing_input.as_bytes());
529
-
let expected_sig = mac.finalize().into_bytes();
530
-
531
-
if !bool::from(expected_sig.ct_eq(&provided_sig)) {
532
-
return Err(OAuthError::InvalidToken("Invalid token signature".to_string()));
533
-
}
534
-
535
-
let payload_bytes = URL_SAFE_NO_PAD
536
-
.decode(parts[1])
537
-
.map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
538
-
let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
539
-
.map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
540
-
541
-
let jti = payload
542
-
.get("jti")
543
-
.and_then(|j| j.as_str())
544
-
.ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
545
-
.to_string();
546
-
547
-
let exp = payload
548
-
.get("exp")
549
-
.and_then(|e| e.as_i64())
550
-
.ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
551
-
552
-
let iat = payload
553
-
.get("iat")
554
-
.and_then(|i| i.as_i64())
555
-
.ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?;
556
-
557
-
Ok(TokenClaims { jti, exp, iat })
558
-
}
···
+246
src/oauth/endpoints/token/grants.rs
+246
src/oauth/endpoints/token/grants.rs
···
···
1
+
use axum::http::HeaderMap;
2
+
use axum::Json;
3
+
use chrono::{Duration, Utc};
4
+
5
+
use crate::config::AuthConfig;
6
+
use crate::state::AppState;
7
+
use crate::oauth::{
8
+
ClientAuth, OAuthError, RefreshToken, TokenData, TokenId,
9
+
client::{ClientMetadataCache, verify_client_auth},
10
+
db,
11
+
dpop::DPoPVerifier,
12
+
};
13
+
14
+
use super::types::{TokenRequest, TokenResponse};
15
+
use super::helpers::{create_access_token, verify_pkce};
16
+
17
+
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
18
+
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
19
+
20
+
pub async fn handle_authorization_code_grant(
21
+
state: AppState,
22
+
_headers: HeaderMap,
23
+
request: TokenRequest,
24
+
dpop_proof: Option<String>,
25
+
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
26
+
let code = request
27
+
.code
28
+
.ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
29
+
30
+
let code_verifier = request
31
+
.code_verifier
32
+
.ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?;
33
+
34
+
let auth_request = db::consume_authorization_request_by_code(&state.db, &code)
35
+
.await?
36
+
.ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
37
+
38
+
if auth_request.expires_at < Utc::now() {
39
+
return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string()));
40
+
}
41
+
42
+
if let Some(request_client_id) = &request.client_id {
43
+
if request_client_id != &auth_request.client_id {
44
+
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
45
+
}
46
+
}
47
+
48
+
let did = auth_request
49
+
.did
50
+
.ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
51
+
52
+
let client_metadata_cache = ClientMetadataCache::new(3600);
53
+
let client_metadata = client_metadata_cache
54
+
.get(&auth_request.client_id)
55
+
.await?;
56
+
let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None);
57
+
verify_client_auth(&client_metadata, &client_auth)?;
58
+
59
+
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
60
+
61
+
if let Some(redirect_uri) = &request.redirect_uri {
62
+
if redirect_uri != &auth_request.parameters.redirect_uri {
63
+
return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string()));
64
+
}
65
+
}
66
+
67
+
let dpop_jkt = if let Some(proof) = &dpop_proof {
68
+
let config = AuthConfig::get();
69
+
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
70
+
71
+
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
72
+
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
73
+
74
+
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
75
+
76
+
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
77
+
return Err(OAuthError::InvalidDpopProof(
78
+
"DPoP proof has already been used".to_string(),
79
+
));
80
+
}
81
+
82
+
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt {
83
+
if &result.jkt != expected_jkt {
84
+
return Err(OAuthError::InvalidDpopProof(
85
+
"DPoP key binding mismatch".to_string(),
86
+
));
87
+
}
88
+
}
89
+
90
+
Some(result.jkt)
91
+
} else if auth_request.parameters.dpop_jkt.is_some() {
92
+
return Err(OAuthError::InvalidRequest(
93
+
"DPoP proof required for this authorization".to_string(),
94
+
));
95
+
} else {
96
+
None
97
+
};
98
+
99
+
let token_id = TokenId::generate();
100
+
let refresh_token = RefreshToken::generate();
101
+
let now = Utc::now();
102
+
103
+
let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?;
104
+
105
+
let token_data = TokenData {
106
+
did: did.clone(),
107
+
token_id: token_id.0.clone(),
108
+
created_at: now,
109
+
updated_at: now,
110
+
expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS),
111
+
client_id: auth_request.client_id.clone(),
112
+
client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None),
113
+
device_id: auth_request.device_id,
114
+
parameters: auth_request.parameters.clone(),
115
+
details: None,
116
+
code: None,
117
+
current_refresh_token: Some(refresh_token.0.clone()),
118
+
scope: auth_request.parameters.scope.clone(),
119
+
};
120
+
121
+
db::create_token(&state.db, &token_data).await?;
122
+
123
+
tokio::spawn({
124
+
let pool = state.db.clone();
125
+
let did_clone = did.clone();
126
+
async move {
127
+
if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await {
128
+
tracing::warn!("Failed to enforce token limit for user: {:?}", e);
129
+
}
130
+
}
131
+
});
132
+
133
+
let mut response_headers = HeaderMap::new();
134
+
let config = AuthConfig::get();
135
+
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
136
+
response_headers.insert(
137
+
"DPoP-Nonce",
138
+
verifier.generate_nonce().parse().unwrap(),
139
+
);
140
+
141
+
Ok((
142
+
response_headers,
143
+
Json(TokenResponse {
144
+
access_token,
145
+
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
146
+
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
147
+
refresh_token: Some(refresh_token.0),
148
+
scope: auth_request.parameters.scope,
149
+
sub: Some(did),
150
+
}),
151
+
))
152
+
}
153
+
154
+
pub async fn handle_refresh_token_grant(
155
+
state: AppState,
156
+
_headers: HeaderMap,
157
+
request: TokenRequest,
158
+
dpop_proof: Option<String>,
159
+
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
160
+
let refresh_token_str = request
161
+
.refresh_token
162
+
.ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
163
+
164
+
if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? {
165
+
db::delete_token_family(&state.db, token_id).await?;
166
+
return Err(OAuthError::InvalidGrant(
167
+
"Refresh token reuse detected, token family revoked".to_string(),
168
+
));
169
+
}
170
+
171
+
let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str)
172
+
.await?
173
+
.ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?;
174
+
175
+
if token_data.expires_at < Utc::now() {
176
+
db::delete_token_family(&state.db, db_id).await?;
177
+
return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string()));
178
+
}
179
+
180
+
let dpop_jkt = if let Some(proof) = &dpop_proof {
181
+
let config = AuthConfig::get();
182
+
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
183
+
184
+
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
185
+
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
186
+
187
+
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
188
+
189
+
if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
190
+
return Err(OAuthError::InvalidDpopProof(
191
+
"DPoP proof has already been used".to_string(),
192
+
));
193
+
}
194
+
195
+
if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
196
+
if &result.jkt != expected_jkt {
197
+
return Err(OAuthError::InvalidDpopProof(
198
+
"DPoP key binding mismatch".to_string(),
199
+
));
200
+
}
201
+
}
202
+
203
+
Some(result.jkt)
204
+
} else if token_data.parameters.dpop_jkt.is_some() {
205
+
return Err(OAuthError::InvalidRequest(
206
+
"DPoP proof required".to_string(),
207
+
));
208
+
} else {
209
+
None
210
+
};
211
+
212
+
let new_token_id = TokenId::generate();
213
+
let new_refresh_token = RefreshToken::generate();
214
+
let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS);
215
+
216
+
db::rotate_token(
217
+
&state.db,
218
+
db_id,
219
+
&new_token_id.0,
220
+
&new_refresh_token.0,
221
+
new_expires_at,
222
+
)
223
+
.await?;
224
+
225
+
let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?;
226
+
227
+
let mut response_headers = HeaderMap::new();
228
+
let config = AuthConfig::get();
229
+
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
230
+
response_headers.insert(
231
+
"DPoP-Nonce",
232
+
verifier.generate_nonce().parse().unwrap(),
233
+
);
234
+
235
+
Ok((
236
+
response_headers,
237
+
Json(TokenResponse {
238
+
access_token,
239
+
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
240
+
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
241
+
refresh_token: Some(new_refresh_token.0),
242
+
scope: token_data.scope,
243
+
sub: Some(token_data.did),
244
+
}),
245
+
))
246
+
}
+143
src/oauth/endpoints/token/helpers.rs
+143
src/oauth/endpoints/token/helpers.rs
···
···
1
+
use base64::Engine;
2
+
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3
+
use chrono::Utc;
4
+
use hmac::Mac;
5
+
use sha2::{Digest, Sha256};
6
+
use subtle::ConstantTimeEq;
7
+
8
+
use crate::config::AuthConfig;
9
+
use crate::oauth::OAuthError;
10
+
11
+
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600;
12
+
13
+
pub struct TokenClaims {
14
+
pub jti: String,
15
+
pub exp: i64,
16
+
pub iat: i64,
17
+
}
18
+
19
+
pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> {
20
+
let mut hasher = Sha256::new();
21
+
hasher.update(code_verifier.as_bytes());
22
+
let hash = hasher.finalize();
23
+
let computed_challenge = URL_SAFE_NO_PAD.encode(&hash);
24
+
25
+
if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) {
26
+
return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string()));
27
+
}
28
+
29
+
Ok(())
30
+
}
31
+
32
+
pub fn create_access_token(
33
+
token_id: &str,
34
+
sub: &str,
35
+
dpop_jkt: Option<&str>,
36
+
) -> Result<String, OAuthError> {
37
+
use serde_json::json;
38
+
39
+
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
40
+
let issuer = format!("https://{}", pds_hostname);
41
+
42
+
let now = Utc::now().timestamp();
43
+
let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
44
+
45
+
let mut payload = json!({
46
+
"iss": issuer,
47
+
"sub": sub,
48
+
"aud": issuer,
49
+
"iat": now,
50
+
"exp": exp,
51
+
"jti": token_id,
52
+
"scope": "atproto"
53
+
});
54
+
55
+
if let Some(jkt) = dpop_jkt {
56
+
payload["cnf"] = json!({ "jkt": jkt });
57
+
}
58
+
59
+
let header = json!({
60
+
"alg": "HS256",
61
+
"typ": "at+jwt"
62
+
});
63
+
64
+
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
65
+
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
66
+
67
+
let signing_input = format!("{}.{}", header_b64, payload_b64);
68
+
69
+
let config = AuthConfig::get();
70
+
71
+
type HmacSha256 = hmac::Hmac<Sha256>;
72
+
73
+
let mut mac = HmacSha256::new_from_slice(config.jwt_secret().as_bytes())
74
+
.map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?;
75
+
mac.update(signing_input.as_bytes());
76
+
let signature = mac.finalize().into_bytes();
77
+
78
+
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
79
+
80
+
Ok(format!("{}.{}", signing_input, signature_b64))
81
+
}
82
+
83
+
pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
84
+
let parts: Vec<&str> = token.split('.').collect();
85
+
if parts.len() != 3 {
86
+
return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
87
+
}
88
+
89
+
let header_bytes = URL_SAFE_NO_PAD
90
+
.decode(parts[0])
91
+
.map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
92
+
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
93
+
.map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
94
+
95
+
if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
96
+
return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string()));
97
+
}
98
+
if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
99
+
return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string()));
100
+
}
101
+
102
+
let config = AuthConfig::get();
103
+
let secret = config.jwt_secret();
104
+
105
+
let signing_input = format!("{}.{}", parts[0], parts[1]);
106
+
let provided_sig = URL_SAFE_NO_PAD
107
+
.decode(parts[2])
108
+
.map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
109
+
110
+
type HmacSha256 = hmac::Hmac<Sha256>;
111
+
let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
112
+
.map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
113
+
mac.update(signing_input.as_bytes());
114
+
let expected_sig = mac.finalize().into_bytes();
115
+
116
+
if !bool::from(expected_sig.ct_eq(&provided_sig)) {
117
+
return Err(OAuthError::InvalidToken("Invalid token signature".to_string()));
118
+
}
119
+
120
+
let payload_bytes = URL_SAFE_NO_PAD
121
+
.decode(parts[1])
122
+
.map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
123
+
let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
124
+
.map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
125
+
126
+
let jti = payload
127
+
.get("jti")
128
+
.and_then(|j| j.as_str())
129
+
.ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
130
+
.to_string();
131
+
132
+
let exp = payload
133
+
.get("exp")
134
+
.and_then(|e| e.as_i64())
135
+
.ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
136
+
137
+
let iat = payload
138
+
.get("iat")
139
+
.and_then(|i| i.as_i64())
140
+
.ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?;
141
+
142
+
Ok(TokenClaims { jti, exp, iat })
143
+
}
+122
src/oauth/endpoints/token/introspect.rs
+122
src/oauth/endpoints/token/introspect.rs
···
···
1
+
use axum::{Form, Json};
2
+
use axum::extract::State;
3
+
use axum::http::StatusCode;
4
+
use chrono::Utc;
5
+
use serde::{Deserialize, Serialize};
6
+
7
+
use crate::state::AppState;
8
+
use crate::oauth::{OAuthError, db};
9
+
10
+
use super::helpers::extract_token_claims;
11
+
12
+
#[derive(Debug, Deserialize)]
13
+
pub struct RevokeRequest {
14
+
pub token: Option<String>,
15
+
#[serde(default)]
16
+
pub token_type_hint: Option<String>,
17
+
}
18
+
19
+
pub async fn revoke_token(
20
+
State(state): State<AppState>,
21
+
Form(request): Form<RevokeRequest>,
22
+
) -> Result<StatusCode, OAuthError> {
23
+
if let Some(token) = &request.token {
24
+
if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
25
+
db::delete_token_family(&state.db, db_id).await?;
26
+
} else {
27
+
db::delete_token(&state.db, token).await?;
28
+
}
29
+
}
30
+
31
+
Ok(StatusCode::OK)
32
+
}
33
+
34
+
#[derive(Debug, Deserialize)]
35
+
pub struct IntrospectRequest {
36
+
pub token: String,
37
+
#[serde(default)]
38
+
pub token_type_hint: Option<String>,
39
+
}
40
+
41
+
#[derive(Debug, Serialize)]
42
+
pub struct IntrospectResponse {
43
+
pub active: bool,
44
+
#[serde(skip_serializing_if = "Option::is_none")]
45
+
pub scope: Option<String>,
46
+
#[serde(skip_serializing_if = "Option::is_none")]
47
+
pub client_id: Option<String>,
48
+
#[serde(skip_serializing_if = "Option::is_none")]
49
+
pub username: Option<String>,
50
+
#[serde(skip_serializing_if = "Option::is_none")]
51
+
pub token_type: Option<String>,
52
+
#[serde(skip_serializing_if = "Option::is_none")]
53
+
pub exp: Option<i64>,
54
+
#[serde(skip_serializing_if = "Option::is_none")]
55
+
pub iat: Option<i64>,
56
+
#[serde(skip_serializing_if = "Option::is_none")]
57
+
pub nbf: Option<i64>,
58
+
#[serde(skip_serializing_if = "Option::is_none")]
59
+
pub sub: Option<String>,
60
+
#[serde(skip_serializing_if = "Option::is_none")]
61
+
pub aud: Option<String>,
62
+
#[serde(skip_serializing_if = "Option::is_none")]
63
+
pub iss: Option<String>,
64
+
#[serde(skip_serializing_if = "Option::is_none")]
65
+
pub jti: Option<String>,
66
+
}
67
+
68
+
pub async fn introspect_token(
69
+
State(state): State<AppState>,
70
+
Form(request): Form<IntrospectRequest>,
71
+
) -> Json<IntrospectResponse> {
72
+
let inactive_response = IntrospectResponse {
73
+
active: false,
74
+
scope: None,
75
+
client_id: None,
76
+
username: None,
77
+
token_type: None,
78
+
exp: None,
79
+
iat: None,
80
+
nbf: None,
81
+
sub: None,
82
+
aud: None,
83
+
iss: None,
84
+
jti: None,
85
+
};
86
+
87
+
let token_info = match extract_token_claims(&request.token) {
88
+
Ok(info) => info,
89
+
Err(_) => return Json(inactive_response),
90
+
};
91
+
92
+
let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
93
+
Ok(Some(data)) => data,
94
+
_ => return Json(inactive_response),
95
+
};
96
+
97
+
if token_data.expires_at < Utc::now() {
98
+
return Json(inactive_response);
99
+
}
100
+
101
+
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
102
+
let issuer = format!("https://{}", pds_hostname);
103
+
104
+
Json(IntrospectResponse {
105
+
active: true,
106
+
scope: token_data.scope,
107
+
client_id: Some(token_data.client_id),
108
+
username: None,
109
+
token_type: if token_data.parameters.dpop_jkt.is_some() {
110
+
Some("DPoP".to_string())
111
+
} else {
112
+
Some("Bearer".to_string())
113
+
},
114
+
exp: Some(token_info.exp),
115
+
iat: Some(token_info.iat),
116
+
nbf: Some(token_info.iat),
117
+
sub: Some(token_data.did),
118
+
aud: Some(issuer.clone()),
119
+
iss: Some(issuer),
120
+
jti: Some(token_info.jti),
121
+
})
122
+
}
+44
src/oauth/endpoints/token/mod.rs
+44
src/oauth/endpoints/token/mod.rs
···
···
1
+
mod grants;
2
+
mod helpers;
3
+
mod introspect;
4
+
mod types;
5
+
6
+
use axum::{
7
+
Form, Json,
8
+
extract::State,
9
+
http::HeaderMap,
10
+
};
11
+
12
+
use crate::state::AppState;
13
+
use crate::oauth::OAuthError;
14
+
15
+
pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant};
16
+
pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims};
17
+
pub use introspect::{
18
+
introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest,
19
+
};
20
+
pub use types::{TokenRequest, TokenResponse};
21
+
22
+
pub async fn token_endpoint(
23
+
State(state): State<AppState>,
24
+
headers: HeaderMap,
25
+
Form(request): Form<TokenRequest>,
26
+
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
27
+
let dpop_proof = headers
28
+
.get("DPoP")
29
+
.and_then(|v| v.to_str().ok())
30
+
.map(|s| s.to_string());
31
+
32
+
match request.grant_type.as_str() {
33
+
"authorization_code" => {
34
+
handle_authorization_code_grant(state, headers, request, dpop_proof).await
35
+
}
36
+
"refresh_token" => {
37
+
handle_refresh_token_grant(state, headers, request, dpop_proof).await
38
+
}
39
+
_ => Err(OAuthError::UnsupportedGrantType(format!(
40
+
"Unsupported grant_type: {}",
41
+
request.grant_type
42
+
))),
43
+
}
44
+
}
+35
src/oauth/endpoints/token/types.rs
+35
src/oauth/endpoints/token/types.rs
···
···
1
+
use serde::{Deserialize, Serialize};
2
+
3
+
#[derive(Debug, Deserialize)]
4
+
pub struct TokenRequest {
5
+
pub grant_type: String,
6
+
#[serde(default)]
7
+
pub code: Option<String>,
8
+
#[serde(default)]
9
+
pub redirect_uri: Option<String>,
10
+
#[serde(default)]
11
+
pub code_verifier: Option<String>,
12
+
#[serde(default)]
13
+
pub refresh_token: Option<String>,
14
+
#[serde(default)]
15
+
pub client_id: Option<String>,
16
+
#[serde(default)]
17
+
pub client_secret: Option<String>,
18
+
#[serde(default)]
19
+
pub client_assertion: Option<String>,
20
+
#[serde(default)]
21
+
pub client_assertion_type: Option<String>,
22
+
}
23
+
24
+
#[derive(Debug, Serialize)]
25
+
pub struct TokenResponse {
26
+
pub access_token: String,
27
+
pub token_type: String,
28
+
pub expires_in: u64,
29
+
#[serde(skip_serializing_if = "Option::is_none")]
30
+
pub refresh_token: Option<String>,
31
+
#[serde(skip_serializing_if = "Option::is_none")]
32
+
pub scope: Option<String>,
33
+
#[serde(skip_serializing_if = "Option::is_none")]
34
+
pub sub: Option<String>,
35
+
}
+2
-1
src/repo/mod.rs
+2
-1
src/repo/mod.rs
···
38
let mut hasher = Sha256::new();
39
hasher.update(data);
40
let hash = hasher.finalize();
41
+
let multihash = Multihash::wrap(0x12, &hash)
42
+
.map_err(|e| RepoError::storage(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to wrap multihash: {:?}", e))))?;
43
let cid = Cid::new_v1(0x71, multihash);
44
let cid_bytes = cid.to_bytes();
45
+12
-3
src/repo/tracking.rs
+12
-3
src/repo/tracking.rs
···
21
}
22
23
pub fn get_written_cids(&self) -> Vec<Cid> {
24
-
self.written_cids.lock().unwrap().clone()
25
}
26
}
27
···
32
33
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
34
let cid = self.inner.put(data).await?;
35
-
self.written_cids.lock().unwrap().push(cid.clone());
36
Ok(cid)
37
}
38
···
47
let blocks: Vec<_> = blocks.into_iter().collect();
48
let cids: Vec<Cid> = blocks.iter().map(|(cid, _)| cid.clone()).collect();
49
self.inner.put_many(blocks).await?;
50
-
self.written_cids.lock().unwrap().extend(cids);
51
Ok(())
52
}
53
···
21
}
22
23
pub fn get_written_cids(&self) -> Vec<Cid> {
24
+
match self.written_cids.lock() {
25
+
Ok(guard) => guard.clone(),
26
+
Err(poisoned) => poisoned.into_inner().clone(),
27
+
}
28
}
29
}
30
···
35
36
async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> {
37
let cid = self.inner.put(data).await?;
38
+
match self.written_cids.lock() {
39
+
Ok(mut guard) => guard.push(cid.clone()),
40
+
Err(poisoned) => poisoned.into_inner().push(cid.clone()),
41
+
}
42
Ok(cid)
43
}
44
···
53
let blocks: Vec<_> = blocks.into_iter().collect();
54
let cids: Vec<Cid> = blocks.iter().map(|(cid, _)| cid.clone()).collect();
55
self.inner.put_many(blocks).await?;
56
+
match self.written_cids.lock() {
57
+
Ok(mut guard) => guard.extend(cids),
58
+
Err(poisoned) => poisoned.into_inner().extend(cids),
59
+
}
60
Ok(())
61
}
62
+1
-1
src/sync/blob.rs
+1
-1
src/sync/blob.rs
+5
-4
src/sync/car.rs
+5
-4
src/sync/car.rs
···
23
Ok(())
24
}
25
26
-
pub fn encode_car_header(root_cid: &Cid) -> Vec<u8> {
27
let header = CarHeader::new_v1(vec![root_cid.clone()]);
28
-
let header_cbor = header.encode().unwrap_or_default();
29
30
let mut result = Vec::new();
31
-
write_varint(&mut result, header_cbor.len() as u64).unwrap();
32
result.extend_from_slice(&header_cbor);
33
-
result
34
}
···
23
Ok(())
24
}
25
26
+
pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> {
27
let header = CarHeader::new_v1(vec![root_cid.clone()]);
28
+
let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?;
29
30
let mut result = Vec::new();
31
+
write_varint(&mut result, header_cbor.len() as u64)
32
+
.expect("Writing to Vec<u8> should never fail");
33
result.extend_from_slice(&header_cbor);
34
+
Ok(result)
35
}
+1
-1
src/sync/commit.rs
+1
-1
src/sync/commit.rs
+9
-5
src/sync/frame.rs
+9
-5
src/sync/frame.rs
···
38
pub cid: Option<String>,
39
}
40
41
-
impl From<SequencedEvent> for CommitFrame {
42
-
fn from(event: SequencedEvent) -> Self {
43
let ops = serde_json::from_value::<Vec<RepoOp>>(event.ops.unwrap_or_default())
44
.unwrap_or_else(|_| vec![]);
45
46
-
CommitFrame {
47
seq: event.seq,
48
rebase: false,
49
too_big: false,
50
repo: event.did,
51
-
commit: event.commit_cid.unwrap_or_default(),
52
prev: event.prev_cid,
53
blocks: Vec::new(),
54
ops,
55
blobs: event.blobs.unwrap_or_default(),
56
time: event.created_at.to_rfc3339(),
57
-
}
58
}
59
}
···
38
pub cid: Option<String>,
39
}
40
41
+
impl TryFrom<SequencedEvent> for CommitFrame {
42
+
type Error = &'static str;
43
+
44
+
fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> {
45
let ops = serde_json::from_value::<Vec<RepoOp>>(event.ops.unwrap_or_default())
46
.unwrap_or_else(|_| vec![]);
47
48
+
let commit_cid = event.commit_cid.ok_or("Missing commit_cid in event")?;
49
+
50
+
Ok(CommitFrame {
51
seq: event.seq,
52
rebase: false,
53
too_big: false,
54
repo: event.did,
55
+
commit: commit_cid,
56
prev: event.prev_cid,
57
blocks: Vec::new(),
58
ops,
59
blobs: event.blobs.unwrap_or_default(),
60
time: event.created_at.to_rfc3339(),
61
+
})
62
}
63
}
+1
-2
src/sync/relay_client.rs
+1
-2
src/sync/relay_client.rs
+48
-17
src/sync/repo.rs
+48
-17
src/sync/repo.rs
···
15
use std::str::FromStr;
16
use tracing::error;
17
18
#[derive(Deserialize)]
19
pub struct GetBlocksQuery {
20
pub did: String,
···
52
}
53
};
54
55
-
let root_cid = cids.first().cloned().unwrap_or_default();
56
-
57
if cids.is_empty() {
58
return (StatusCode::BAD_REQUEST, "No CIDs provided").into_response();
59
}
60
61
-
let header = encode_car_header(&root_cid);
62
63
let mut car_bytes = header;
64
···
69
let total_len = cid_bytes.len() + block.len();
70
71
let mut writer = Vec::new();
72
-
crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap();
73
-
writer.write_all(&cid_bytes).unwrap();
74
-
writer.write_all(&block).unwrap();
75
76
car_bytes.extend_from_slice(&writer);
77
}
···
143
}
144
};
145
146
-
let mut car_bytes = encode_car_header(&head_cid);
147
148
let mut stack = vec![head_cid];
149
let mut visited = std::collections::HashSet::new();
150
-
let mut limit = 20000;
151
152
while let Some(cid) = stack.pop() {
153
if visited.contains(&cid) {
154
continue;
155
}
156
visited.insert(cid);
157
-
if limit == 0 { break; }
158
-
limit -= 1;
159
160
if let Ok(Some(block)) = state.block_store.get(&cid).await {
161
let cid_bytes = cid.to_bytes();
162
let total_len = cid_bytes.len() + block.len();
163
let mut writer = Vec::new();
164
-
crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap();
165
-
writer.write_all(&cid_bytes).unwrap();
166
-
writer.write_all(&block).unwrap();
167
car_bytes.extend_from_slice(&writer);
168
169
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
···
258
_ => return (StatusCode::NOT_FOUND, "Block not found").into_response(),
259
};
260
261
-
let header = encode_car_header(&cid);
262
let mut car_bytes = header;
263
264
let cid_bytes = cid.to_bytes();
265
let total_len = cid_bytes.len() + block.len();
266
let mut writer = Vec::new();
267
-
crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap();
268
-
writer.write_all(&cid_bytes).unwrap();
269
-
writer.write_all(&block).unwrap();
270
car_bytes.extend_from_slice(&writer);
271
272
(
···
15
use std::str::FromStr;
16
use tracing::error;
17
18
+
const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000;
19
+
20
#[derive(Deserialize)]
21
pub struct GetBlocksQuery {
22
pub did: String,
···
54
}
55
};
56
57
if cids.is_empty() {
58
return (StatusCode::BAD_REQUEST, "No CIDs provided").into_response();
59
}
60
61
+
let root_cid = cids[0];
62
+
63
+
let header = match encode_car_header(&root_cid) {
64
+
Ok(h) => h,
65
+
Err(e) => {
66
+
error!("Failed to encode CAR header: {}", e);
67
+
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to encode CAR").into_response();
68
+
}
69
+
};
70
71
let mut car_bytes = header;
72
···
77
let total_len = cid_bytes.len() + block.len();
78
79
let mut writer = Vec::new();
80
+
crate::sync::car::write_varint(&mut writer, total_len as u64)
81
+
.expect("Writing to Vec<u8> should never fail");
82
+
writer.write_all(&cid_bytes)
83
+
.expect("Writing to Vec<u8> should never fail");
84
+
writer.write_all(&block)
85
+
.expect("Writing to Vec<u8> should never fail");
86
87
car_bytes.extend_from_slice(&writer);
88
}
···
154
}
155
};
156
157
+
let mut car_bytes = match encode_car_header(&head_cid) {
158
+
Ok(h) => h,
159
+
Err(e) => {
160
+
return (
161
+
StatusCode::INTERNAL_SERVER_ERROR,
162
+
Json(json!({"error": "InternalError", "message": format!("Failed to encode CAR header: {}", e)})),
163
+
)
164
+
.into_response();
165
+
}
166
+
};
167
168
let mut stack = vec![head_cid];
169
let mut visited = std::collections::HashSet::new();
170
+
let mut remaining = MAX_REPO_BLOCKS_TRAVERSAL;
171
172
while let Some(cid) = stack.pop() {
173
if visited.contains(&cid) {
174
continue;
175
}
176
visited.insert(cid);
177
+
if remaining == 0 { break; }
178
+
remaining -= 1;
179
180
if let Ok(Some(block)) = state.block_store.get(&cid).await {
181
let cid_bytes = cid.to_bytes();
182
let total_len = cid_bytes.len() + block.len();
183
let mut writer = Vec::new();
184
+
crate::sync::car::write_varint(&mut writer, total_len as u64)
185
+
.expect("Writing to Vec<u8> should never fail");
186
+
writer.write_all(&cid_bytes)
187
+
.expect("Writing to Vec<u8> should never fail");
188
+
writer.write_all(&block)
189
+
.expect("Writing to Vec<u8> should never fail");
190
car_bytes.extend_from_slice(&writer);
191
192
if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) {
···
281
_ => return (StatusCode::NOT_FOUND, "Block not found").into_response(),
282
};
283
284
+
let header = match encode_car_header(&cid) {
285
+
Ok(h) => h,
286
+
Err(e) => {
287
+
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to encode CAR header: {}", e)).into_response();
288
+
}
289
+
};
290
let mut car_bytes = header;
291
292
let cid_bytes = cid.to_bytes();
293
let total_len = cid_bytes.len() + block.len();
294
let mut writer = Vec::new();
295
+
crate::sync::car::write_varint(&mut writer, total_len as u64)
296
+
.expect("Writing to Vec<u8> should never fail");
297
+
writer.write_all(&cid_bytes)
298
+
.expect("Writing to Vec<u8> should never fail");
299
+
writer.write_all(&block)
300
+
.expect("Writing to Vec<u8> should never fail");
301
car_bytes.extend_from_slice(&writer);
302
303
(
+37
-23
src/sync/subscribe_repos.rs
+37
-23
src/sync/subscribe_repos.rs
···
9
use serde::Deserialize;
10
use tracing::{error, info, warn};
11
12
#[derive(Deserialize)]
13
pub struct SubscribeReposParams {
14
pub cursor: Option<i64>,
···
37
info!(cursor = ?params.cursor, "New firehose subscriber");
38
39
if let Some(cursor) = params.cursor {
40
-
let events = sqlx::query_as!(
41
-
SequencedEvent,
42
-
r#"
43
-
SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids
44
-
FROM repo_seq
45
-
WHERE seq > $1
46
-
ORDER BY seq ASC
47
-
"#,
48
-
cursor
49
-
)
50
-
.fetch_all(&state.db)
51
-
.await;
52
53
-
match events {
54
-
Ok(events) => {
55
-
for event in events {
56
-
if let Err(e) = send_event(&mut socket, &state, event).await {
57
-
warn!("Failed to send backfill event: {}", e);
58
-
return;
59
}
60
}
61
-
}
62
-
Err(e) => {
63
-
error!("Failed to fetch backfill events: {}", e);
64
-
socket.close().await.ok();
65
-
return;
66
}
67
}
68
}
···
9
use serde::Deserialize;
10
use tracing::{error, info, warn};
11
12
+
const BACKFILL_BATCH_SIZE: i64 = 1000;
13
+
14
#[derive(Deserialize)]
15
pub struct SubscribeReposParams {
16
pub cursor: Option<i64>,
···
39
info!(cursor = ?params.cursor, "New firehose subscriber");
40
41
if let Some(cursor) = params.cursor {
42
+
let mut current_cursor = cursor;
43
+
loop {
44
+
let events = sqlx::query_as!(
45
+
SequencedEvent,
46
+
r#"
47
+
SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids
48
+
FROM repo_seq
49
+
WHERE seq > $1
50
+
ORDER BY seq ASC
51
+
LIMIT $2
52
+
"#,
53
+
current_cursor,
54
+
BACKFILL_BATCH_SIZE
55
+
)
56
+
.fetch_all(&state.db)
57
+
.await;
58
59
+
match events {
60
+
Ok(events) => {
61
+
if events.is_empty() {
62
+
break;
63
+
}
64
+
for event in &events {
65
+
current_cursor = event.seq;
66
+
if let Err(e) = send_event(&mut socket, &state, event.clone()).await {
67
+
warn!("Failed to send backfill event: {}", e);
68
+
return;
69
+
}
70
+
}
71
+
if (events.len() as i64) < BACKFILL_BATCH_SIZE {
72
+
break;
73
}
74
}
75
+
Err(e) => {
76
+
error!("Failed to fetch backfill events: {}", e);
77
+
socket.close().await.ok();
78
+
return;
79
+
}
80
}
81
}
82
}
+8
-15
src/sync/util.rs
+8
-15
src/sync/util.rs
···
2
use crate::sync::firehose::SequencedEvent;
3
use crate::sync::frame::{CommitFrame, Frame, FrameData};
4
use cid::Cid;
5
-
use jacquard_repo::car::write_car;
6
use jacquard_repo::storage::BlockStore;
7
-
use std::fs;
8
use std::str::FromStr;
9
-
use tokio::fs::File;
10
-
use tokio::io::AsyncReadExt;
11
-
use uuid::Uuid;
12
13
pub async fn format_event_for_sending(
14
state: &AppState,
15
event: SequencedEvent,
16
) -> Result<Vec<u8>, anyhow::Error> {
17
let block_cids_str = event.blocks_cids.clone().unwrap_or_default();
18
-
let mut frame: CommitFrame = event.into();
19
20
-
let mut car_bytes = Vec::new();
21
-
if !block_cids_str.is_empty() {
22
-
let temp_path = format!("/tmp/{}.car", Uuid::new_v4());
23
let mut blocks = std::collections::BTreeMap::new();
24
25
for cid_str in block_cids_str {
···
33
}
34
35
let root = Cid::from_str(&frame.commit)?;
36
-
write_car(&temp_path, vec![root], blocks).await?;
37
-
38
-
let mut file = File::open(&temp_path).await?;
39
-
file.read_to_end(&mut car_bytes).await?;
40
-
fs::remove_file(&temp_path)?;
41
-
}
42
frame.blocks = car_bytes;
43
44
let frame = Frame {
···
2
use crate::sync::firehose::SequencedEvent;
3
use crate::sync::frame::{CommitFrame, Frame, FrameData};
4
use cid::Cid;
5
+
use jacquard_repo::car::write_car_bytes;
6
use jacquard_repo::storage::BlockStore;
7
use std::str::FromStr;
8
9
pub async fn format_event_for_sending(
10
state: &AppState,
11
event: SequencedEvent,
12
) -> Result<Vec<u8>, anyhow::Error> {
13
let block_cids_str = event.blocks_cids.clone().unwrap_or_default();
14
+
let mut frame: CommitFrame = event.try_into()
15
+
.map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?;
16
17
+
let car_bytes = if !block_cids_str.is_empty() {
18
let mut blocks = std::collections::BTreeMap::new();
19
20
for cid_str in block_cids_str {
···
28
}
29
30
let root = Cid::from_str(&frame.commit)?;
31
+
write_car_bytes(root, blocks).await?
32
+
} else {
33
+
Vec::new()
34
+
};
35
frame.blocks = car_bytes;
36
37
let frame = Frame {
+2
-342
src/sync/verify.rs
+2
-342
src/sync/verify.rs
···
302
}
303
304
#[cfg(test)]
305
-
mod tests {
306
-
use super::*;
307
-
use sha2::{Digest, Sha256};
308
-
309
-
fn make_cid(data: &[u8]) -> Cid {
310
-
let mut hasher = Sha256::new();
311
-
hasher.update(data);
312
-
let hash = hasher.finalize();
313
-
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
314
-
Cid::new_v1(0x71, multihash)
315
-
}
316
-
317
-
#[test]
318
-
fn test_verifier_creation() {
319
-
let _verifier = CarVerifier::new();
320
-
}
321
-
322
-
#[test]
323
-
fn test_verify_error_display() {
324
-
let err = VerifyError::DidMismatch {
325
-
commit_did: "did:plc:abc".to_string(),
326
-
expected_did: "did:plc:xyz".to_string(),
327
-
};
328
-
assert!(err.to_string().contains("did:plc:abc"));
329
-
assert!(err.to_string().contains("did:plc:xyz"));
330
-
331
-
let err = VerifyError::InvalidSignature;
332
-
assert!(err.to_string().contains("signature"));
333
-
334
-
let err = VerifyError::NoSigningKey;
335
-
assert!(err.to_string().contains("signing key"));
336
-
337
-
let err = VerifyError::MstValidationFailed("test error".to_string());
338
-
assert!(err.to_string().contains("test error"));
339
-
}
340
-
341
-
#[test]
342
-
fn test_mst_validation_missing_root_block() {
343
-
let verifier = CarVerifier::new();
344
-
let blocks: HashMap<Cid, Bytes> = HashMap::new();
345
-
346
-
let fake_cid = make_cid(b"fake data");
347
-
let result = verifier.verify_mst_structure(&fake_cid, &blocks);
348
-
349
-
assert!(result.is_err());
350
-
let err = result.unwrap_err();
351
-
assert!(matches!(err, VerifyError::BlockNotFound(_)));
352
-
}
353
-
354
-
#[test]
355
-
fn test_mst_validation_invalid_cbor() {
356
-
let verifier = CarVerifier::new();
357
-
358
-
let bad_cbor = Bytes::from(vec![0xFF, 0xFF, 0xFF]);
359
-
let cid = make_cid(&bad_cbor);
360
-
361
-
let mut blocks = HashMap::new();
362
-
blocks.insert(cid, bad_cbor);
363
-
364
-
let result = verifier.verify_mst_structure(&cid, &blocks);
365
-
366
-
assert!(result.is_err());
367
-
let err = result.unwrap_err();
368
-
assert!(matches!(err, VerifyError::InvalidCbor(_)));
369
-
}
370
-
371
-
#[test]
372
-
fn test_mst_validation_empty_node() {
373
-
let verifier = CarVerifier::new();
374
-
375
-
let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
376
-
"e": []
377
-
})).unwrap();
378
-
let cid = make_cid(&empty_node);
379
-
380
-
let mut blocks = HashMap::new();
381
-
blocks.insert(cid, Bytes::from(empty_node));
382
-
383
-
let result = verifier.verify_mst_structure(&cid, &blocks);
384
-
assert!(result.is_ok());
385
-
}
386
-
387
-
#[test]
388
-
fn test_mst_validation_missing_left_pointer() {
389
-
use ipld_core::ipld::Ipld;
390
-
391
-
let verifier = CarVerifier::new();
392
-
393
-
let missing_left_cid = make_cid(b"missing left");
394
-
let node = Ipld::Map(std::collections::BTreeMap::from([
395
-
("l".to_string(), Ipld::Link(missing_left_cid)),
396
-
("e".to_string(), Ipld::List(vec![])),
397
-
]));
398
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
399
-
let cid = make_cid(&node_bytes);
400
-
401
-
let mut blocks = HashMap::new();
402
-
blocks.insert(cid, Bytes::from(node_bytes));
403
-
404
-
let result = verifier.verify_mst_structure(&cid, &blocks);
405
-
406
-
assert!(result.is_err());
407
-
let err = result.unwrap_err();
408
-
assert!(matches!(err, VerifyError::BlockNotFound(_)));
409
-
assert!(err.to_string().contains("left pointer"));
410
-
}
411
-
412
-
#[test]
413
-
fn test_mst_validation_missing_subtree() {
414
-
use ipld_core::ipld::Ipld;
415
-
416
-
let verifier = CarVerifier::new();
417
-
418
-
let missing_subtree_cid = make_cid(b"missing subtree");
419
-
let record_cid = make_cid(b"record");
420
-
421
-
let entry = Ipld::Map(std::collections::BTreeMap::from([
422
-
("k".to_string(), Ipld::Bytes(b"key1".to_vec())),
423
-
("v".to_string(), Ipld::Link(record_cid)),
424
-
("p".to_string(), Ipld::Integer(0)),
425
-
("t".to_string(), Ipld::Link(missing_subtree_cid)),
426
-
]));
427
-
428
-
let node = Ipld::Map(std::collections::BTreeMap::from([
429
-
("e".to_string(), Ipld::List(vec![entry])),
430
-
]));
431
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
432
-
let cid = make_cid(&node_bytes);
433
-
434
-
let mut blocks = HashMap::new();
435
-
blocks.insert(cid, Bytes::from(node_bytes));
436
-
437
-
let result = verifier.verify_mst_structure(&cid, &blocks);
438
-
439
-
assert!(result.is_err());
440
-
let err = result.unwrap_err();
441
-
assert!(matches!(err, VerifyError::BlockNotFound(_)));
442
-
assert!(err.to_string().contains("subtree"));
443
-
}
444
-
445
-
#[test]
446
-
fn test_mst_validation_unsorted_keys() {
447
-
use ipld_core::ipld::Ipld;
448
-
449
-
let verifier = CarVerifier::new();
450
-
451
-
let record_cid = make_cid(b"record");
452
-
453
-
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
454
-
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
455
-
("v".to_string(), Ipld::Link(record_cid)),
456
-
("p".to_string(), Ipld::Integer(0)),
457
-
]));
458
-
459
-
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
460
-
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
461
-
("v".to_string(), Ipld::Link(record_cid)),
462
-
("p".to_string(), Ipld::Integer(0)),
463
-
]));
464
-
465
-
let node = Ipld::Map(std::collections::BTreeMap::from([
466
-
("e".to_string(), Ipld::List(vec![entry1, entry2])),
467
-
]));
468
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
469
-
let cid = make_cid(&node_bytes);
470
-
471
-
let mut blocks = HashMap::new();
472
-
blocks.insert(cid, Bytes::from(node_bytes));
473
-
474
-
let result = verifier.verify_mst_structure(&cid, &blocks);
475
-
476
-
assert!(result.is_err());
477
-
let err = result.unwrap_err();
478
-
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
479
-
assert!(err.to_string().contains("sorted"));
480
-
}
481
-
482
-
#[test]
483
-
fn test_mst_validation_sorted_keys_ok() {
484
-
use ipld_core::ipld::Ipld;
485
-
486
-
let verifier = CarVerifier::new();
487
-
488
-
let record_cid = make_cid(b"record");
489
-
490
-
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
491
-
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
492
-
("v".to_string(), Ipld::Link(record_cid)),
493
-
("p".to_string(), Ipld::Integer(0)),
494
-
]));
495
-
496
-
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
497
-
("k".to_string(), Ipld::Bytes(b"bbb".to_vec())),
498
-
("v".to_string(), Ipld::Link(record_cid)),
499
-
("p".to_string(), Ipld::Integer(0)),
500
-
]));
501
-
502
-
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
503
-
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
504
-
("v".to_string(), Ipld::Link(record_cid)),
505
-
("p".to_string(), Ipld::Integer(0)),
506
-
]));
507
-
508
-
let node = Ipld::Map(std::collections::BTreeMap::from([
509
-
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
510
-
]));
511
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
512
-
let cid = make_cid(&node_bytes);
513
-
514
-
let mut blocks = HashMap::new();
515
-
blocks.insert(cid, Bytes::from(node_bytes));
516
-
517
-
let result = verifier.verify_mst_structure(&cid, &blocks);
518
-
assert!(result.is_ok());
519
-
}
520
-
521
-
#[test]
522
-
fn test_mst_validation_with_valid_left_pointer() {
523
-
use ipld_core::ipld::Ipld;
524
-
525
-
let verifier = CarVerifier::new();
526
-
527
-
let left_node = Ipld::Map(std::collections::BTreeMap::from([
528
-
("e".to_string(), Ipld::List(vec![])),
529
-
]));
530
-
let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap();
531
-
let left_cid = make_cid(&left_node_bytes);
532
-
533
-
let root_node = Ipld::Map(std::collections::BTreeMap::from([
534
-
("l".to_string(), Ipld::Link(left_cid)),
535
-
("e".to_string(), Ipld::List(vec![])),
536
-
]));
537
-
let root_node_bytes = serde_ipld_dagcbor::to_vec(&root_node).unwrap();
538
-
let root_cid = make_cid(&root_node_bytes);
539
-
540
-
let mut blocks = HashMap::new();
541
-
blocks.insert(root_cid, Bytes::from(root_node_bytes));
542
-
blocks.insert(left_cid, Bytes::from(left_node_bytes));
543
-
544
-
let result = verifier.verify_mst_structure(&root_cid, &blocks);
545
-
assert!(result.is_ok());
546
-
}
547
-
548
-
#[test]
549
-
fn test_mst_validation_cycle_detection() {
550
-
let verifier = CarVerifier::new();
551
-
552
-
let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
553
-
"e": []
554
-
})).unwrap();
555
-
let cid = make_cid(&node);
556
-
557
-
let mut blocks = HashMap::new();
558
-
blocks.insert(cid, Bytes::from(node));
559
-
560
-
let result = verifier.verify_mst_structure(&cid, &blocks);
561
-
assert!(result.is_ok());
562
-
}
563
-
564
-
#[tokio::test]
565
-
async fn test_unsupported_did_method() {
566
-
let verifier = CarVerifier::new();
567
-
let result = verifier.resolve_did_document("did:unknown:test").await;
568
-
569
-
assert!(result.is_err());
570
-
let err = result.unwrap_err();
571
-
assert!(matches!(err, VerifyError::DidResolutionFailed(_)));
572
-
assert!(err.to_string().contains("Unsupported"));
573
-
}
574
-
575
-
#[test]
576
-
fn test_mst_validation_with_prefix_compression() {
577
-
use ipld_core::ipld::Ipld;
578
-
579
-
let verifier = CarVerifier::new();
580
-
let record_cid = make_cid(b"record");
581
-
582
-
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
583
-
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())),
584
-
("v".to_string(), Ipld::Link(record_cid)),
585
-
("p".to_string(), Ipld::Integer(0)),
586
-
]));
587
-
588
-
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
589
-
("k".to_string(), Ipld::Bytes(b"def".to_vec())),
590
-
("v".to_string(), Ipld::Link(record_cid)),
591
-
("p".to_string(), Ipld::Integer(19)),
592
-
]));
593
-
594
-
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
595
-
("k".to_string(), Ipld::Bytes(b"xyz".to_vec())),
596
-
("v".to_string(), Ipld::Link(record_cid)),
597
-
("p".to_string(), Ipld::Integer(19)),
598
-
]));
599
-
600
-
let node = Ipld::Map(std::collections::BTreeMap::from([
601
-
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
602
-
]));
603
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
604
-
let cid = make_cid(&node_bytes);
605
-
606
-
let mut blocks = HashMap::new();
607
-
blocks.insert(cid, Bytes::from(node_bytes));
608
-
609
-
let result = verifier.verify_mst_structure(&cid, &blocks);
610
-
assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly");
611
-
}
612
-
613
-
#[test]
614
-
fn test_mst_validation_prefix_compression_unsorted() {
615
-
use ipld_core::ipld::Ipld;
616
-
617
-
let verifier = CarVerifier::new();
618
-
let record_cid = make_cid(b"record");
619
-
620
-
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
621
-
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())),
622
-
("v".to_string(), Ipld::Link(record_cid)),
623
-
("p".to_string(), Ipld::Integer(0)),
624
-
]));
625
-
626
-
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
627
-
("k".to_string(), Ipld::Bytes(b"abc".to_vec())),
628
-
("v".to_string(), Ipld::Link(record_cid)),
629
-
("p".to_string(), Ipld::Integer(19)),
630
-
]));
631
-
632
-
let node = Ipld::Map(std::collections::BTreeMap::from([
633
-
("e".to_string(), Ipld::List(vec![entry1, entry2])),
634
-
]));
635
-
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
636
-
let cid = make_cid(&node_bytes);
637
-
638
-
let mut blocks = HashMap::new();
639
-
blocks.insert(cid, Bytes::from(node_bytes));
640
-
641
-
let result = verifier.verify_mst_structure(&cid, &blocks);
642
-
assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation");
643
-
let err = result.unwrap_err();
644
-
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
645
-
}
646
-
}
+346
src/sync/verify_tests.rs
+346
src/sync/verify_tests.rs
···
···
1
+
#[cfg(test)]
2
+
mod tests {
3
+
use crate::sync::verify::{CarVerifier, VerifyError};
4
+
use bytes::Bytes;
5
+
use cid::Cid;
6
+
use sha2::{Digest, Sha256};
7
+
use std::collections::HashMap;
8
+
9
+
fn make_cid(data: &[u8]) -> Cid {
10
+
let mut hasher = Sha256::new();
11
+
hasher.update(data);
12
+
let hash = hasher.finalize();
13
+
let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap();
14
+
Cid::new_v1(0x71, multihash)
15
+
}
16
+
17
+
#[test]
18
+
fn test_verifier_creation() {
19
+
let _verifier = CarVerifier::new();
20
+
}
21
+
22
+
#[test]
23
+
fn test_verify_error_display() {
24
+
let err = VerifyError::DidMismatch {
25
+
commit_did: "did:plc:abc".to_string(),
26
+
expected_did: "did:plc:xyz".to_string(),
27
+
};
28
+
assert!(err.to_string().contains("did:plc:abc"));
29
+
assert!(err.to_string().contains("did:plc:xyz"));
30
+
31
+
let err = VerifyError::InvalidSignature;
32
+
assert!(err.to_string().contains("signature"));
33
+
34
+
let err = VerifyError::NoSigningKey;
35
+
assert!(err.to_string().contains("signing key"));
36
+
37
+
let err = VerifyError::MstValidationFailed("test error".to_string());
38
+
assert!(err.to_string().contains("test error"));
39
+
}
40
+
41
+
#[test]
42
+
fn test_mst_validation_missing_root_block() {
43
+
let verifier = CarVerifier::new();
44
+
let blocks: HashMap<Cid, Bytes> = HashMap::new();
45
+
46
+
let fake_cid = make_cid(b"fake data");
47
+
let result = verifier.verify_mst_structure(&fake_cid, &blocks);
48
+
49
+
assert!(result.is_err());
50
+
let err = result.unwrap_err();
51
+
assert!(matches!(err, VerifyError::BlockNotFound(_)));
52
+
}
53
+
54
+
#[test]
55
+
fn test_mst_validation_invalid_cbor() {
56
+
let verifier = CarVerifier::new();
57
+
58
+
let bad_cbor = Bytes::from(vec![0xFF, 0xFF, 0xFF]);
59
+
let cid = make_cid(&bad_cbor);
60
+
61
+
let mut blocks = HashMap::new();
62
+
blocks.insert(cid, bad_cbor);
63
+
64
+
let result = verifier.verify_mst_structure(&cid, &blocks);
65
+
66
+
assert!(result.is_err());
67
+
let err = result.unwrap_err();
68
+
assert!(matches!(err, VerifyError::InvalidCbor(_)));
69
+
}
70
+
71
+
#[test]
72
+
fn test_mst_validation_empty_node() {
73
+
let verifier = CarVerifier::new();
74
+
75
+
let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
76
+
"e": []
77
+
})).unwrap();
78
+
let cid = make_cid(&empty_node);
79
+
80
+
let mut blocks = HashMap::new();
81
+
blocks.insert(cid, Bytes::from(empty_node));
82
+
83
+
let result = verifier.verify_mst_structure(&cid, &blocks);
84
+
assert!(result.is_ok());
85
+
}
86
+
87
+
#[test]
88
+
fn test_mst_validation_missing_left_pointer() {
89
+
use ipld_core::ipld::Ipld;
90
+
91
+
let verifier = CarVerifier::new();
92
+
93
+
let missing_left_cid = make_cid(b"missing left");
94
+
let node = Ipld::Map(std::collections::BTreeMap::from([
95
+
("l".to_string(), Ipld::Link(missing_left_cid)),
96
+
("e".to_string(), Ipld::List(vec![])),
97
+
]));
98
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
99
+
let cid = make_cid(&node_bytes);
100
+
101
+
let mut blocks = HashMap::new();
102
+
blocks.insert(cid, Bytes::from(node_bytes));
103
+
104
+
let result = verifier.verify_mst_structure(&cid, &blocks);
105
+
106
+
assert!(result.is_err());
107
+
let err = result.unwrap_err();
108
+
assert!(matches!(err, VerifyError::BlockNotFound(_)));
109
+
assert!(err.to_string().contains("left pointer"));
110
+
}
111
+
112
+
#[test]
113
+
fn test_mst_validation_missing_subtree() {
114
+
use ipld_core::ipld::Ipld;
115
+
116
+
let verifier = CarVerifier::new();
117
+
118
+
let missing_subtree_cid = make_cid(b"missing subtree");
119
+
let record_cid = make_cid(b"record");
120
+
121
+
let entry = Ipld::Map(std::collections::BTreeMap::from([
122
+
("k".to_string(), Ipld::Bytes(b"key1".to_vec())),
123
+
("v".to_string(), Ipld::Link(record_cid)),
124
+
("p".to_string(), Ipld::Integer(0)),
125
+
("t".to_string(), Ipld::Link(missing_subtree_cid)),
126
+
]));
127
+
128
+
let node = Ipld::Map(std::collections::BTreeMap::from([
129
+
("e".to_string(), Ipld::List(vec![entry])),
130
+
]));
131
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
132
+
let cid = make_cid(&node_bytes);
133
+
134
+
let mut blocks = HashMap::new();
135
+
blocks.insert(cid, Bytes::from(node_bytes));
136
+
137
+
let result = verifier.verify_mst_structure(&cid, &blocks);
138
+
139
+
assert!(result.is_err());
140
+
let err = result.unwrap_err();
141
+
assert!(matches!(err, VerifyError::BlockNotFound(_)));
142
+
assert!(err.to_string().contains("subtree"));
143
+
}
144
+
145
+
#[test]
146
+
fn test_mst_validation_unsorted_keys() {
147
+
use ipld_core::ipld::Ipld;
148
+
149
+
let verifier = CarVerifier::new();
150
+
151
+
let record_cid = make_cid(b"record");
152
+
153
+
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
154
+
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
155
+
("v".to_string(), Ipld::Link(record_cid)),
156
+
("p".to_string(), Ipld::Integer(0)),
157
+
]));
158
+
159
+
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
160
+
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
161
+
("v".to_string(), Ipld::Link(record_cid)),
162
+
("p".to_string(), Ipld::Integer(0)),
163
+
]));
164
+
165
+
let node = Ipld::Map(std::collections::BTreeMap::from([
166
+
("e".to_string(), Ipld::List(vec![entry1, entry2])),
167
+
]));
168
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
169
+
let cid = make_cid(&node_bytes);
170
+
171
+
let mut blocks = HashMap::new();
172
+
blocks.insert(cid, Bytes::from(node_bytes));
173
+
174
+
let result = verifier.verify_mst_structure(&cid, &blocks);
175
+
176
+
assert!(result.is_err());
177
+
let err = result.unwrap_err();
178
+
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
179
+
assert!(err.to_string().contains("sorted"));
180
+
}
181
+
182
+
#[test]
183
+
fn test_mst_validation_sorted_keys_ok() {
184
+
use ipld_core::ipld::Ipld;
185
+
186
+
let verifier = CarVerifier::new();
187
+
188
+
let record_cid = make_cid(b"record");
189
+
190
+
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
191
+
("k".to_string(), Ipld::Bytes(b"aaa".to_vec())),
192
+
("v".to_string(), Ipld::Link(record_cid)),
193
+
("p".to_string(), Ipld::Integer(0)),
194
+
]));
195
+
196
+
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
197
+
("k".to_string(), Ipld::Bytes(b"bbb".to_vec())),
198
+
("v".to_string(), Ipld::Link(record_cid)),
199
+
("p".to_string(), Ipld::Integer(0)),
200
+
]));
201
+
202
+
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
203
+
("k".to_string(), Ipld::Bytes(b"zzz".to_vec())),
204
+
("v".to_string(), Ipld::Link(record_cid)),
205
+
("p".to_string(), Ipld::Integer(0)),
206
+
]));
207
+
208
+
let node = Ipld::Map(std::collections::BTreeMap::from([
209
+
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
210
+
]));
211
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
212
+
let cid = make_cid(&node_bytes);
213
+
214
+
let mut blocks = HashMap::new();
215
+
blocks.insert(cid, Bytes::from(node_bytes));
216
+
217
+
let result = verifier.verify_mst_structure(&cid, &blocks);
218
+
assert!(result.is_ok());
219
+
}
220
+
221
+
#[test]
222
+
fn test_mst_validation_with_valid_left_pointer() {
223
+
use ipld_core::ipld::Ipld;
224
+
225
+
let verifier = CarVerifier::new();
226
+
227
+
let left_node = Ipld::Map(std::collections::BTreeMap::from([
228
+
("e".to_string(), Ipld::List(vec![])),
229
+
]));
230
+
let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap();
231
+
let left_cid = make_cid(&left_node_bytes);
232
+
233
+
let root_node = Ipld::Map(std::collections::BTreeMap::from([
234
+
("l".to_string(), Ipld::Link(left_cid)),
235
+
("e".to_string(), Ipld::List(vec![])),
236
+
]));
237
+
let root_node_bytes = serde_ipld_dagcbor::to_vec(&root_node).unwrap();
238
+
let root_cid = make_cid(&root_node_bytes);
239
+
240
+
let mut blocks = HashMap::new();
241
+
blocks.insert(root_cid, Bytes::from(root_node_bytes));
242
+
blocks.insert(left_cid, Bytes::from(left_node_bytes));
243
+
244
+
let result = verifier.verify_mst_structure(&root_cid, &blocks);
245
+
assert!(result.is_ok());
246
+
}
247
+
248
+
#[test]
249
+
fn test_mst_validation_cycle_detection() {
250
+
let verifier = CarVerifier::new();
251
+
252
+
let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({
253
+
"e": []
254
+
})).unwrap();
255
+
let cid = make_cid(&node);
256
+
257
+
let mut blocks = HashMap::new();
258
+
blocks.insert(cid, Bytes::from(node));
259
+
260
+
let result = verifier.verify_mst_structure(&cid, &blocks);
261
+
assert!(result.is_ok());
262
+
}
263
+
264
+
#[tokio::test]
265
+
async fn test_unsupported_did_method() {
266
+
let verifier = CarVerifier::new();
267
+
let result = verifier.resolve_did_document("did:unknown:test").await;
268
+
269
+
assert!(result.is_err());
270
+
let err = result.unwrap_err();
271
+
assert!(matches!(err, VerifyError::DidResolutionFailed(_)));
272
+
assert!(err.to_string().contains("Unsupported"));
273
+
}
274
+
275
+
#[test]
276
+
fn test_mst_validation_with_prefix_compression() {
277
+
use ipld_core::ipld::Ipld;
278
+
279
+
let verifier = CarVerifier::new();
280
+
let record_cid = make_cid(b"record");
281
+
282
+
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
283
+
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())),
284
+
("v".to_string(), Ipld::Link(record_cid)),
285
+
("p".to_string(), Ipld::Integer(0)),
286
+
]));
287
+
288
+
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
289
+
("k".to_string(), Ipld::Bytes(b"def".to_vec())),
290
+
("v".to_string(), Ipld::Link(record_cid)),
291
+
("p".to_string(), Ipld::Integer(19)),
292
+
]));
293
+
294
+
let entry3 = Ipld::Map(std::collections::BTreeMap::from([
295
+
("k".to_string(), Ipld::Bytes(b"xyz".to_vec())),
296
+
("v".to_string(), Ipld::Link(record_cid)),
297
+
("p".to_string(), Ipld::Integer(19)),
298
+
]));
299
+
300
+
let node = Ipld::Map(std::collections::BTreeMap::from([
301
+
("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])),
302
+
]));
303
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
304
+
let cid = make_cid(&node_bytes);
305
+
306
+
let mut blocks = HashMap::new();
307
+
blocks.insert(cid, Bytes::from(node_bytes));
308
+
309
+
let result = verifier.verify_mst_structure(&cid, &blocks);
310
+
assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly");
311
+
}
312
+
313
+
#[test]
314
+
fn test_mst_validation_prefix_compression_unsorted() {
315
+
use ipld_core::ipld::Ipld;
316
+
317
+
let verifier = CarVerifier::new();
318
+
let record_cid = make_cid(b"record");
319
+
320
+
let entry1 = Ipld::Map(std::collections::BTreeMap::from([
321
+
("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())),
322
+
("v".to_string(), Ipld::Link(record_cid)),
323
+
("p".to_string(), Ipld::Integer(0)),
324
+
]));
325
+
326
+
let entry2 = Ipld::Map(std::collections::BTreeMap::from([
327
+
("k".to_string(), Ipld::Bytes(b"abc".to_vec())),
328
+
("v".to_string(), Ipld::Link(record_cid)),
329
+
("p".to_string(), Ipld::Integer(19)),
330
+
]));
331
+
332
+
let node = Ipld::Map(std::collections::BTreeMap::from([
333
+
("e".to_string(), Ipld::List(vec![entry1, entry2])),
334
+
]));
335
+
let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap();
336
+
let cid = make_cid(&node_bytes);
337
+
338
+
let mut blocks = HashMap::new();
339
+
blocks.insert(cid, Bytes::from(node_bytes));
340
+
341
+
let result = verifier.verify_mst_structure(&cid, &blocks);
342
+
assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation");
343
+
let err = result.unwrap_err();
344
+
assert!(matches!(err, VerifyError::MstValidationFailed(_)));
345
+
}
346
+
}
+103
src/util.rs
+103
src/util.rs
···
···
1
+
use rand::Rng;
2
+
use sqlx::PgPool;
3
+
use uuid::Uuid;
4
+
5
+
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
6
+
7
+
pub fn generate_token_code() -> String {
8
+
generate_token_code_parts(2, 5)
9
+
}
10
+
11
+
pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
12
+
let mut rng = rand::thread_rng();
13
+
let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
14
+
15
+
(0..parts)
16
+
.map(|_| {
17
+
(0..part_len)
18
+
.map(|_| chars[rng.gen_range(0..chars.len())])
19
+
.collect::<String>()
20
+
})
21
+
.collect::<Vec<_>>()
22
+
.join("-")
23
+
}
24
+
25
+
#[derive(Debug)]
26
+
pub enum DbLookupError {
27
+
NotFound,
28
+
DatabaseError(sqlx::Error),
29
+
}
30
+
31
+
impl From<sqlx::Error> for DbLookupError {
32
+
fn from(e: sqlx::Error) -> Self {
33
+
DbLookupError::DatabaseError(e)
34
+
}
35
+
}
36
+
37
+
pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
38
+
sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
39
+
.fetch_optional(db)
40
+
.await?
41
+
.ok_or(DbLookupError::NotFound)
42
+
}
43
+
44
+
pub struct UserInfo {
45
+
pub id: Uuid,
46
+
pub did: String,
47
+
pub handle: String,
48
+
}
49
+
50
+
pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
51
+
sqlx::query_as!(
52
+
UserInfo,
53
+
"SELECT id, did, handle FROM users WHERE did = $1",
54
+
did
55
+
)
56
+
.fetch_optional(db)
57
+
.await?
58
+
.ok_or(DbLookupError::NotFound)
59
+
}
60
+
61
+
pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> {
62
+
sqlx::query_as!(
63
+
UserInfo,
64
+
"SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
65
+
identifier
66
+
)
67
+
.fetch_optional(db)
68
+
.await?
69
+
.ok_or(DbLookupError::NotFound)
70
+
}
71
+
72
+
#[cfg(test)]
73
+
mod tests {
74
+
use super::*;
75
+
76
+
#[test]
77
+
fn test_generate_token_code() {
78
+
let code = generate_token_code();
79
+
assert_eq!(code.len(), 11);
80
+
assert!(code.contains('-'));
81
+
82
+
let parts: Vec<&str> = code.split('-').collect();
83
+
assert_eq!(parts.len(), 2);
84
+
assert_eq!(parts[0].len(), 5);
85
+
assert_eq!(parts[1].len(), 5);
86
+
87
+
for c in code.chars() {
88
+
if c != '-' {
89
+
assert!(BASE32_ALPHABET.contains(c));
90
+
}
91
+
}
92
+
}
93
+
94
+
#[test]
95
+
fn test_generate_token_code_parts() {
96
+
let code = generate_token_code_parts(3, 4);
97
+
let parts: Vec<&str> = code.split('-').collect();
98
+
assert_eq!(parts.len(), 3);
99
+
for part in parts {
100
+
assert_eq!(part.len(), 4);
101
+
}
102
+
}
103
+
}
+1
-1
tests/email_update.rs
+1
-1
tests/email_update.rs
+30
-12
tests/relay_client.rs
+30
-12
tests/relay_client.rs
···
13
async fn mock_relay_server(
14
listener: TcpListener,
15
event_tx: mpsc::Sender<Vec<u8>>,
16
-
ready_tx: mpsc::Sender<()>,
17
) {
18
let handler = |ws: axum::extract::ws::WebSocketUpgrade| async {
19
ws.on_upgrade(move |mut socket| async move {
20
-
ready_tx.send(()).await.unwrap();
21
-
if let Some(Ok(Message::Binary(bytes))) = socket.recv().await {
22
-
event_tx.send(bytes.to_vec()).await.unwrap();
23
}
24
})
25
};
···
35
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
36
let addr = listener.local_addr().unwrap();
37
let (event_tx, mut event_rx) = mpsc::channel(1);
38
-
let (ready_tx, ready_rx) = mpsc::channel(1);
39
-
tokio::spawn(mock_relay_server(listener, event_tx, ready_tx));
40
let relay_url = format!("ws://{}", addr);
41
42
let db_url = get_db_connection_string().await;
···
46
.unwrap();
47
let state = AppState::new(pool).await;
48
49
start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await;
50
51
-
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
52
53
let dummy_event = SequencedEvent {
54
seq: 1,
55
did: "did:plc:test".to_string(),
56
created_at: Utc::now(),
57
event_type: "commit".to_string(),
58
-
commit_cid: None,
59
prev_cid: None,
60
-
ops: None,
61
-
blobs: None,
62
-
blocks_cids: None,
63
};
64
state.firehose_tx.send(dummy_event).unwrap();
65
66
-
let received_bytes = event_rx.recv().await.expect("Did not receive event");
67
assert!(!received_bytes.is_empty());
68
}
···
13
async fn mock_relay_server(
14
listener: TcpListener,
15
event_tx: mpsc::Sender<Vec<u8>>,
16
+
connected_tx: mpsc::Sender<()>,
17
) {
18
let handler = |ws: axum::extract::ws::WebSocketUpgrade| async {
19
ws.on_upgrade(move |mut socket| async move {
20
+
let _ = connected_tx.send(()).await;
21
+
while let Some(Ok(msg)) = socket.recv().await {
22
+
if let Message::Binary(bytes) = msg {
23
+
let _ = event_tx.send(bytes.to_vec()).await;
24
+
break;
25
+
}
26
}
27
})
28
};
···
38
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
39
let addr = listener.local_addr().unwrap();
40
let (event_tx, mut event_rx) = mpsc::channel(1);
41
+
let (connected_tx, _connected_rx) = mpsc::channel::<()>(1);
42
+
tokio::spawn(mock_relay_server(listener, event_tx, connected_tx));
43
let relay_url = format!("ws://{}", addr);
44
45
let db_url = get_db_connection_string().await;
···
49
.unwrap();
50
let state = AppState::new(pool).await;
51
52
+
let (ready_tx, ready_rx) = mpsc::channel(1);
53
start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await;
54
55
+
tokio::time::timeout(
56
+
tokio::time::Duration::from_secs(5),
57
+
async {
58
+
ready_tx.closed().await;
59
+
}
60
+
)
61
+
.await
62
+
.expect("Timeout waiting for relay client to be ready");
63
64
let dummy_event = SequencedEvent {
65
seq: 1,
66
did: "did:plc:test".to_string(),
67
created_at: Utc::now(),
68
event_type: "commit".to_string(),
69
+
commit_cid: Some("bafyreihffx5a4o3qbv7vp6qmxpxok5mx5xvlsq6z4x3xv3zqv7vqvc7mzy".to_string()),
70
prev_cid: None,
71
+
ops: Some(serde_json::json!([])),
72
+
blobs: Some(vec![]),
73
+
blocks_cids: Some(vec![]),
74
};
75
state.firehose_tx.send(dummy_event).unwrap();
76
77
+
let received_bytes = tokio::time::timeout(
78
+
tokio::time::Duration::from_secs(5),
79
+
event_rx.recv()
80
+
)
81
+
.await
82
+
.expect("Timeout waiting for event")
83
+
.expect("Event channel closed");
84
+
85
assert!(!received_bytes.is_empty());
86
}