this repo has no description

Idk. Code quality in general?

lewis 61dcea2c d922ae03

Changed files
+4398 -4055
.sqlx
src
tests
+2 -2
.sqlx/query-3889903e58405370152b9ded229d843c0114e71454ea7da2b212519e98d09817.json .sqlx/query-e2e51654f146a3a336f5a28cbd47addbdd311aeaead530c00c1891c95bede0b8.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT st.id, st.did, k.key_bytes, k.encryption_version\n FROM session_tokens st\n JOIN users u ON st.did = u.did\n JOIN user_keys k ON u.id = k.user_id\n WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()", 3 + "query": "SELECT st.id, st.did, k.key_bytes, k.encryption_version\n FROM session_tokens st\n JOIN users u ON st.did = u.did\n JOIN user_keys k ON u.id = k.user_id\n WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()\n FOR UPDATE OF st", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 36 36 true 37 37 ] 38 38 }, 39 - "hash": "3889903e58405370152b9ded229d843c0114e71454ea7da2b212519e98d09817" 39 + "hash": "e2e51654f146a3a336f5a28cbd47addbdd311aeaead530c00c1891c95bede0b8" 40 40 }
+3 -2
.sqlx/query-51809819130908ef3600e5843f6098fb510afb4c827a41bc3a32ad78ea10184c.json .sqlx/query-b1c54d3f3e2d3031c0d926ccb0d39a0250320d41d08df65d6d9dcc640451527d.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "\n SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n ", 3 + "query": "\n SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids\n FROM repo_seq\n WHERE seq > $1\n ORDER BY seq ASC\n LIMIT $2\n ", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 51 51 ], 52 52 "parameters": { 53 53 "Left": [ 54 + "Int8", 54 55 "Int8" 55 56 ] 56 57 }, ··· 66 67 true 67 68 ] 68 69 }, 69 - "hash": "51809819130908ef3600e5843f6098fb510afb4c827a41bc3a32ad78ea10184c" 70 + "hash": "b1c54d3f3e2d3031c0d926ccb0d39a0250320d41d08df65d6d9dcc640451527d" 70 71 }
+14
.sqlx/query-642b7199f2cbde74af72fc5b5b80f9e2b3efe901a3fdfc732f0d36d00db6326f.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "DELETE FROM invite_codes WHERE created_by_user = $1", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Uuid" 9 + ] 10 + }, 11 + "nullable": [] 12 + }, 13 + "hash": "642b7199f2cbde74af72fc5b5b80f9e2b3efe901a3fdfc732f0d36d00db6326f" 14 + }
+14
.sqlx/query-6c71c4ac31f897e9d33a3637d89377c5977f76a117b042e1800b890b84a655ea.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "DELETE FROM invite_code_uses WHERE used_by_user = $1", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Uuid" 9 + ] 10 + }, 11 + "nullable": [] 12 + }, 13 + "hash": "6c71c4ac31f897e9d33a3637d89377c5977f76a117b042e1800b890b84a655ea" 14 + }
+2 -2
.sqlx/query-7b76e2fcd809a1536465306c79da7985354175e0f025b29c6004dffa310feebd.json .sqlx/query-9a8b9c1cfecf02d1266b1544d5cb2dd8f1254b66b884ff22f983a2ba9dee0529.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2)", 3 + "query": "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING", 4 4 "describe": { 5 5 "columns": [], 6 6 "parameters": { ··· 11 11 }, 12 12 "nullable": [] 13 13 }, 14 - "hash": "7b76e2fcd809a1536465306c79da7985354175e0f025b29c6004dffa310feebd" 14 + "hash": "9a8b9c1cfecf02d1266b1544d5cb2dd8f1254b66b884ff22f983a2ba9dee0529" 15 15 }
+34
.sqlx/query-9f435d95d7c270c82a164c59e9d0caa80ffd7107aff32c806709973fdc6b0020.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "id", 9 + "type_info": "Uuid" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "did", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "handle", 19 + "type_info": "Text" 20 + } 21 + ], 22 + "parameters": { 23 + "Left": [ 24 + "Text" 25 + ] 26 + }, 27 + "nullable": [ 28 + false, 29 + false, 30 + false 31 + ] 32 + }, 33 + "hash": "9f435d95d7c270c82a164c59e9d0caa80ffd7107aff32c806709973fdc6b0020" 34 + }
+34
.sqlx/query-b22827038d6041ad1f3b7eae07d77433def15237391fe26004577b12cb7e95b3.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "SELECT id, did, handle FROM users WHERE did = $1", 4 + "describe": { 5 + "columns": [ 6 + { 7 + "ordinal": 0, 8 + "name": "id", 9 + "type_info": "Uuid" 10 + }, 11 + { 12 + "ordinal": 1, 13 + "name": "did", 14 + "type_info": "Text" 15 + }, 16 + { 17 + "ordinal": 2, 18 + "name": "handle", 19 + "type_info": "Text" 20 + } 21 + ], 22 + "parameters": { 23 + "Left": [ 24 + "Text" 25 + ] 26 + }, 27 + "nullable": [ 28 + false, 29 + false, 30 + false 31 + ] 32 + }, 33 + "hash": "b22827038d6041ad1f3b7eae07d77433def15237391fe26004577b12cb7e95b3" 34 + }
+14
.sqlx/query-c583f0016bf5f61c17781f55d121698e81b2314465321a01916ee7902b17e813.json
··· 1 + { 2 + "db_name": "PostgreSQL", 3 + "query": "DELETE FROM used_refresh_tokens WHERE session_id IN (SELECT id FROM session_tokens WHERE did = $1)", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Left": [ 8 + "Text" 9 + ] 10 + }, 11 + "nullable": [] 12 + }, 13 + "hash": "c583f0016bf5f61c17781f55d121698e81b2314465321a01916ee7902b17e813" 14 + }
+2 -2
.sqlx/query-fcd868a192d27fd4eccae92a884e881b8d6f09bf7ae08a9b431a44acbf2f91f3.json .sqlx/query-b2e1736dbe2ab9114e373353bcc299176417f3c9220025f9521591ba62928bd7.json
··· 1 1 { 2 2 "db_name": "PostgreSQL", 3 - "query": "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1", 3 + "query": "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1 FOR UPDATE", 4 4 "describe": { 5 5 "columns": [ 6 6 { ··· 18 18 false 19 19 ] 20 20 }, 21 - "hash": "fcd868a192d27fd4eccae92a884e881b8d6f09bf7ae08a9b431a44acbf2f91f3" 21 + "hash": "b2e1736dbe2ab9114e373353bcc299176417f3c9220025f9521591ba62928bd7" 22 22 }
+4 -4
TODO.md
··· 253 253 ### Frontend Views 254 254 Uses existing ATProto endpoints where possible: 255 255 256 - **User Dashboard** 256 + User Dashboard 257 257 - [ ] Account overview (uses `com.atproto.server.getSession`, `com.atproto.admin.getAccountInfo`) 258 258 - [ ] Active sessions view (needs new endpoint or extend existing) 259 259 - [ ] App passwords (uses `com.atproto.server.listAppPasswords`, `createAppPassword`, `revokeAppPassword`) 260 260 - [ ] Invite codes (uses `com.atproto.server.getAccountInviteCodes`, `createInviteCode`) 261 261 262 - **Notification Preferences** 262 + Notification Preferences 263 263 - [ ] Channel selector (uses `com.bspds.account.*` endpoints above) 264 264 - [ ] Verification flows for Discord/Telegram/Signal 265 265 - [ ] Notification history view 266 266 267 - **Account Settings** 267 + Account Settings 268 268 - [ ] Email change (uses `com.atproto.server.requestEmailUpdate`, `updateEmail`) 269 269 - [ ] Password change (uses `com.atproto.server.requestPasswordReset`, `resetPassword`) 270 270 - [ ] Handle change (uses `com.atproto.identity.updateHandle`) 271 271 - [ ] Account deletion (uses `com.atproto.server.requestAccountDelete`, `deleteAccount`) 272 272 - [ ] Data export (uses `com.atproto.sync.getRepo`) 273 273 274 - **Admin Dashboard** (privileged users only) 274 + Admin Dashboard (privileged users only) 275 275 - [ ] User list (uses `com.atproto.admin.getAccountInfos` with pagination) 276 276 - [ ] User detail/actions (uses `com.atproto.admin.*` endpoints) 277 277 - [ ] Invite management (uses `com.atproto.admin.getInviteCodes`, `disableInviteCodes`)
+23 -1
src/api/actor/preferences.rs
··· 9 9 use serde_json::{json, Value}; 10 10 11 11 const APP_BSKY_NAMESPACE: &str = "app.bsky"; 12 + const MAX_PREFERENCES_COUNT: usize = 100; 13 + const MAX_PREFERENCE_SIZE: usize = 10_000; 12 14 13 15 #[derive(Serialize)] 14 16 pub struct GetPreferencesOutput { ··· 141 143 } 142 144 }; 143 145 146 + if input.preferences.len() > MAX_PREFERENCES_COUNT { 147 + return ( 148 + StatusCode::BAD_REQUEST, 149 + Json(json!({"error": "InvalidRequest", "message": format!("Too many preferences: {} exceeds limit of {}", input.preferences.len(), MAX_PREFERENCES_COUNT)})), 150 + ) 151 + .into_response(); 152 + } 153 + 144 154 for pref in &input.preferences { 155 + let pref_str = serde_json::to_string(pref).unwrap_or_default(); 156 + if pref_str.len() > MAX_PREFERENCE_SIZE { 157 + return ( 158 + StatusCode::BAD_REQUEST, 159 + Json(json!({"error": "InvalidRequest", "message": format!("Preference too large: {} bytes exceeds limit of {}", pref_str.len(), MAX_PREFERENCE_SIZE)})), 160 + ) 161 + .into_response(); 162 + } 163 + 145 164 let pref_type = match pref.get("$type").and_then(|t| t.as_str()) { 146 165 Some(t) => t, 147 166 None => { ··· 200 219 } 201 220 202 221 for pref in input.preferences { 203 - let pref_type = pref.get("$type").and_then(|t| t.as_str()).unwrap(); 222 + let pref_type = match pref.get("$type").and_then(|t| t.as_str()) { 223 + Some(t) => t, 224 + None => continue, 225 + }; 204 226 205 227 let insert_result = sqlx::query!( 206 228 "INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, $2, $3)",
-564
src/api/admin/account.rs
··· 1 - use crate::state::AppState; 2 - use axum::{ 3 - Json, 4 - extract::{Query, State}, 5 - http::StatusCode, 6 - response::{IntoResponse, Response}, 7 - }; 8 - use serde::{Deserialize, Serialize}; 9 - use serde_json::json; 10 - use tracing::{error, warn}; 11 - 12 - #[derive(Deserialize)] 13 - pub struct GetAccountInfoParams { 14 - pub did: String, 15 - } 16 - 17 - #[derive(Serialize)] 18 - #[serde(rename_all = "camelCase")] 19 - pub struct AccountInfo { 20 - pub did: String, 21 - pub handle: String, 22 - pub email: Option<String>, 23 - pub indexed_at: String, 24 - pub invite_note: Option<String>, 25 - pub invites_disabled: bool, 26 - pub email_confirmed_at: Option<String>, 27 - pub deactivated_at: Option<String>, 28 - } 29 - 30 - #[derive(Serialize)] 31 - #[serde(rename_all = "camelCase")] 32 - pub struct GetAccountInfosOutput { 33 - pub infos: Vec<AccountInfo>, 34 - } 35 - 36 - pub async fn get_account_info( 37 - State(state): State<AppState>, 38 - headers: axum::http::HeaderMap, 39 - Query(params): Query<GetAccountInfoParams>, 40 - ) -> Response { 41 - let auth_header = headers.get("Authorization"); 42 - if auth_header.is_none() { 43 - return ( 44 - StatusCode::UNAUTHORIZED, 45 - Json(json!({"error": "AuthenticationRequired"})), 46 - ) 47 - .into_response(); 48 - } 49 - 50 - let did = params.did.trim(); 51 - if did.is_empty() { 52 - return ( 53 - StatusCode::BAD_REQUEST, 54 - Json(json!({"error": "InvalidRequest", "message": "did is required"})), 55 - ) 56 - .into_response(); 57 - } 58 - 59 - let result = sqlx::query!( 60 - r#" 61 - SELECT did, handle, email, created_at 62 - FROM users 63 - WHERE did = $1 64 - "#, 65 - did 66 - ) 67 - .fetch_optional(&state.db) 68 - .await; 69 - 70 - match result { 71 - Ok(Some(row)) => { 72 - ( 73 - StatusCode::OK, 74 - Json(AccountInfo { 75 - did: row.did, 76 - handle: row.handle, 77 - email: Some(row.email), 78 - indexed_at: row.created_at.to_rfc3339(), 79 - invite_note: None, 80 - invites_disabled: false, 81 - email_confirmed_at: None, 82 - deactivated_at: None, 83 - }), 84 - ) 85 - .into_response() 86 - } 87 - Ok(None) => ( 88 - StatusCode::NOT_FOUND, 89 - Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 90 - ) 91 - .into_response(), 92 - Err(e) => { 93 - error!("DB error in get_account_info: {:?}", e); 94 - ( 95 - StatusCode::INTERNAL_SERVER_ERROR, 96 - Json(json!({"error": "InternalError"})), 97 - ) 98 - .into_response() 99 - } 100 - } 101 - } 102 - 103 - #[derive(Deserialize)] 104 - pub struct GetAccountInfosParams { 105 - pub dids: String, 106 - } 107 - 108 - pub async fn get_account_infos( 109 - State(state): State<AppState>, 110 - headers: axum::http::HeaderMap, 111 - Query(params): Query<GetAccountInfosParams>, 112 - ) -> Response { 113 - let auth_header = headers.get("Authorization"); 114 - if auth_header.is_none() { 115 - return ( 116 - StatusCode::UNAUTHORIZED, 117 - Json(json!({"error": "AuthenticationRequired"})), 118 - ) 119 - .into_response(); 120 - } 121 - 122 - let dids: Vec<&str> = params.dids.split(',').map(|s| s.trim()).collect(); 123 - if dids.is_empty() { 124 - return ( 125 - StatusCode::BAD_REQUEST, 126 - Json(json!({"error": "InvalidRequest", "message": "dids is required"})), 127 - ) 128 - .into_response(); 129 - } 130 - 131 - let mut infos = Vec::new(); 132 - 133 - for did in dids { 134 - if did.is_empty() { 135 - continue; 136 - } 137 - 138 - let result = sqlx::query!( 139 - r#" 140 - SELECT did, handle, email, created_at 141 - FROM users 142 - WHERE did = $1 143 - "#, 144 - did 145 - ) 146 - .fetch_optional(&state.db) 147 - .await; 148 - 149 - if let Ok(Some(row)) = result { 150 - infos.push(AccountInfo { 151 - did: row.did, 152 - handle: row.handle, 153 - email: Some(row.email), 154 - indexed_at: row.created_at.to_rfc3339(), 155 - invite_note: None, 156 - invites_disabled: false, 157 - email_confirmed_at: None, 158 - deactivated_at: None, 159 - }); 160 - } 161 - } 162 - 163 - (StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response() 164 - } 165 - 166 - #[derive(Deserialize)] 167 - pub struct DeleteAccountInput { 168 - pub did: String, 169 - } 170 - 171 - pub async fn delete_account( 172 - State(state): State<AppState>, 173 - headers: axum::http::HeaderMap, 174 - Json(input): Json<DeleteAccountInput>, 175 - ) -> Response { 176 - let auth_header = headers.get("Authorization"); 177 - if auth_header.is_none() { 178 - return ( 179 - StatusCode::UNAUTHORIZED, 180 - Json(json!({"error": "AuthenticationRequired"})), 181 - ) 182 - .into_response(); 183 - } 184 - 185 - let did = input.did.trim(); 186 - if did.is_empty() { 187 - return ( 188 - StatusCode::BAD_REQUEST, 189 - Json(json!({"error": "InvalidRequest", "message": "did is required"})), 190 - ) 191 - .into_response(); 192 - } 193 - 194 - let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did) 195 - .fetch_optional(&state.db) 196 - .await; 197 - 198 - let user_id = match user { 199 - Ok(Some(row)) => row.id, 200 - Ok(None) => { 201 - return ( 202 - StatusCode::NOT_FOUND, 203 - Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 204 - ) 205 - .into_response(); 206 - } 207 - Err(e) => { 208 - error!("DB error in delete_account: {:?}", e); 209 - return ( 210 - StatusCode::INTERNAL_SERVER_ERROR, 211 - Json(json!({"error": "InternalError"})), 212 - ) 213 - .into_response(); 214 - } 215 - }; 216 - 217 - let _ = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did) 218 - .execute(&state.db) 219 - .await; 220 - 221 - let _ = sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id) 222 - .execute(&state.db) 223 - .await; 224 - 225 - let _ = sqlx::query!("DELETE FROM repos WHERE user_id = $1", user_id) 226 - .execute(&state.db) 227 - .await; 228 - 229 - let _ = sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id) 230 - .execute(&state.db) 231 - .await; 232 - 233 - let _ = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id) 234 - .execute(&state.db) 235 - .await; 236 - 237 - let result = sqlx::query!("DELETE FROM users WHERE id = $1", user_id) 238 - .execute(&state.db) 239 - .await; 240 - 241 - match result { 242 - Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(), 243 - Err(e) => { 244 - error!("DB error deleting account: {:?}", e); 245 - ( 246 - StatusCode::INTERNAL_SERVER_ERROR, 247 - Json(json!({"error": "InternalError"})), 248 - ) 249 - .into_response() 250 - } 251 - } 252 - } 253 - 254 - #[derive(Deserialize)] 255 - pub struct UpdateAccountEmailInput { 256 - pub account: String, 257 - pub email: String, 258 - } 259 - 260 - pub async fn update_account_email( 261 - State(state): State<AppState>, 262 - headers: axum::http::HeaderMap, 263 - Json(input): Json<UpdateAccountEmailInput>, 264 - ) -> Response { 265 - let auth_header = headers.get("Authorization"); 266 - if auth_header.is_none() { 267 - return ( 268 - StatusCode::UNAUTHORIZED, 269 - Json(json!({"error": "AuthenticationRequired"})), 270 - ) 271 - .into_response(); 272 - } 273 - 274 - let account = input.account.trim(); 275 - let email = input.email.trim(); 276 - 277 - if account.is_empty() || email.is_empty() { 278 - return ( 279 - StatusCode::BAD_REQUEST, 280 - Json(json!({"error": "InvalidRequest", "message": "account and email are required"})), 281 - ) 282 - .into_response(); 283 - } 284 - 285 - let result = sqlx::query!("UPDATE users SET email = $1 WHERE did = $2", email, account) 286 - .execute(&state.db) 287 - .await; 288 - 289 - match result { 290 - Ok(r) => { 291 - if r.rows_affected() == 0 { 292 - return ( 293 - StatusCode::NOT_FOUND, 294 - Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 295 - ) 296 - .into_response(); 297 - } 298 - (StatusCode::OK, Json(json!({}))).into_response() 299 - } 300 - Err(e) => { 301 - error!("DB error updating email: {:?}", e); 302 - ( 303 - StatusCode::INTERNAL_SERVER_ERROR, 304 - Json(json!({"error": "InternalError"})), 305 - ) 306 - .into_response() 307 - } 308 - } 309 - } 310 - 311 - #[derive(Deserialize)] 312 - pub struct UpdateAccountHandleInput { 313 - pub did: String, 314 - pub handle: String, 315 - } 316 - 317 - pub async fn update_account_handle( 318 - State(state): State<AppState>, 319 - headers: axum::http::HeaderMap, 320 - Json(input): Json<UpdateAccountHandleInput>, 321 - ) -> Response { 322 - let auth_header = headers.get("Authorization"); 323 - if auth_header.is_none() { 324 - return ( 325 - StatusCode::UNAUTHORIZED, 326 - Json(json!({"error": "AuthenticationRequired"})), 327 - ) 328 - .into_response(); 329 - } 330 - 331 - let did = input.did.trim(); 332 - let handle = input.handle.trim(); 333 - 334 - if did.is_empty() || handle.is_empty() { 335 - return ( 336 - StatusCode::BAD_REQUEST, 337 - Json(json!({"error": "InvalidRequest", "message": "did and handle are required"})), 338 - ) 339 - .into_response(); 340 - } 341 - 342 - if !handle 343 - .chars() 344 - .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 345 - { 346 - return ( 347 - StatusCode::BAD_REQUEST, 348 - Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 349 - ) 350 - .into_response(); 351 - } 352 - 353 - let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did) 354 - .fetch_optional(&state.db) 355 - .await; 356 - 357 - if let Ok(Some(_)) = existing { 358 - return ( 359 - StatusCode::BAD_REQUEST, 360 - Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 361 - ) 362 - .into_response(); 363 - } 364 - 365 - let result = sqlx::query!("UPDATE users SET handle = $1 WHERE did = $2", handle, did) 366 - .execute(&state.db) 367 - .await; 368 - 369 - match result { 370 - Ok(r) => { 371 - if r.rows_affected() == 0 { 372 - return ( 373 - StatusCode::NOT_FOUND, 374 - Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 375 - ) 376 - .into_response(); 377 - } 378 - (StatusCode::OK, Json(json!({}))).into_response() 379 - } 380 - Err(e) => { 381 - error!("DB error updating handle: {:?}", e); 382 - ( 383 - StatusCode::INTERNAL_SERVER_ERROR, 384 - Json(json!({"error": "InternalError"})), 385 - ) 386 - .into_response() 387 - } 388 - } 389 - } 390 - 391 - #[derive(Deserialize)] 392 - pub struct UpdateAccountPasswordInput { 393 - pub did: String, 394 - pub password: String, 395 - } 396 - 397 - pub async fn update_account_password( 398 - State(state): State<AppState>, 399 - headers: axum::http::HeaderMap, 400 - Json(input): Json<UpdateAccountPasswordInput>, 401 - ) -> Response { 402 - let auth_header = headers.get("Authorization"); 403 - if auth_header.is_none() { 404 - return ( 405 - StatusCode::UNAUTHORIZED, 406 - Json(json!({"error": "AuthenticationRequired"})), 407 - ) 408 - .into_response(); 409 - } 410 - 411 - let did = input.did.trim(); 412 - let password = input.password.trim(); 413 - 414 - if did.is_empty() || password.is_empty() { 415 - return ( 416 - StatusCode::BAD_REQUEST, 417 - Json(json!({"error": "InvalidRequest", "message": "did and password are required"})), 418 - ) 419 - .into_response(); 420 - } 421 - 422 - let password_hash = match bcrypt::hash(password, bcrypt::DEFAULT_COST) { 423 - Ok(h) => h, 424 - Err(e) => { 425 - error!("Failed to hash password: {:?}", e); 426 - return ( 427 - StatusCode::INTERNAL_SERVER_ERROR, 428 - Json(json!({"error": "InternalError"})), 429 - ) 430 - .into_response(); 431 - } 432 - }; 433 - 434 - let result = sqlx::query!("UPDATE users SET password_hash = $1 WHERE did = $2", password_hash, did) 435 - .execute(&state.db) 436 - .await; 437 - 438 - match result { 439 - Ok(r) => { 440 - if r.rows_affected() == 0 { 441 - return ( 442 - StatusCode::NOT_FOUND, 443 - Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 444 - ) 445 - .into_response(); 446 - } 447 - (StatusCode::OK, Json(json!({}))).into_response() 448 - } 449 - Err(e) => { 450 - error!("DB error updating password: {:?}", e); 451 - ( 452 - StatusCode::INTERNAL_SERVER_ERROR, 453 - Json(json!({"error": "InternalError"})), 454 - ) 455 - .into_response() 456 - } 457 - } 458 - } 459 - 460 - #[derive(Deserialize)] 461 - #[serde(rename_all = "camelCase")] 462 - pub struct SendEmailInput { 463 - pub recipient_did: String, 464 - pub sender_did: String, 465 - pub content: String, 466 - pub subject: Option<String>, 467 - pub comment: Option<String>, 468 - } 469 - 470 - #[derive(Serialize)] 471 - pub struct SendEmailOutput { 472 - pub sent: bool, 473 - } 474 - 475 - pub async fn send_email( 476 - State(state): State<AppState>, 477 - headers: axum::http::HeaderMap, 478 - Json(input): Json<SendEmailInput>, 479 - ) -> Response { 480 - let auth_header = headers.get("Authorization"); 481 - if auth_header.is_none() { 482 - return ( 483 - StatusCode::UNAUTHORIZED, 484 - Json(json!({"error": "AuthenticationRequired"})), 485 - ) 486 - .into_response(); 487 - } 488 - 489 - let recipient_did = input.recipient_did.trim(); 490 - let content = input.content.trim(); 491 - 492 - if recipient_did.is_empty() { 493 - return ( 494 - StatusCode::BAD_REQUEST, 495 - Json(json!({"error": "InvalidRequest", "message": "recipientDid is required"})), 496 - ) 497 - .into_response(); 498 - } 499 - 500 - if content.is_empty() { 501 - return ( 502 - StatusCode::BAD_REQUEST, 503 - Json(json!({"error": "InvalidRequest", "message": "content is required"})), 504 - ) 505 - .into_response(); 506 - } 507 - 508 - let user = sqlx::query!( 509 - "SELECT id, email, handle FROM users WHERE did = $1", 510 - recipient_did 511 - ) 512 - .fetch_optional(&state.db) 513 - .await; 514 - 515 - let (user_id, email, handle) = match user { 516 - Ok(Some(row)) => (row.id, row.email, row.handle), 517 - Ok(None) => { 518 - return ( 519 - StatusCode::NOT_FOUND, 520 - Json(json!({"error": "AccountNotFound", "message": "Recipient account not found"})), 521 - ) 522 - .into_response(); 523 - } 524 - Err(e) => { 525 - error!("DB error in send_email: {:?}", e); 526 - return ( 527 - StatusCode::INTERNAL_SERVER_ERROR, 528 - Json(json!({"error": "InternalError"})), 529 - ) 530 - .into_response(); 531 - } 532 - }; 533 - 534 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 535 - let subject = input 536 - .subject 537 - .clone() 538 - .unwrap_or_else(|| format!("Message from {}", hostname)); 539 - 540 - let notification = crate::notifications::NewNotification::email( 541 - user_id, 542 - crate::notifications::NotificationType::AdminEmail, 543 - email, 544 - subject, 545 - content.to_string(), 546 - ); 547 - 548 - let result = crate::notifications::enqueue_notification(&state.db, notification).await; 549 - 550 - match result { 551 - Ok(_) => { 552 - tracing::info!( 553 - "Admin email queued for {} ({})", 554 - handle, 555 - recipient_did 556 - ); 557 - (StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response() 558 - } 559 - Err(e) => { 560 - warn!("Failed to enqueue admin email: {:?}", e); 561 - (StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response() 562 - } 563 - } 564 - }
+190
src/api/admin/account/delete.rs
··· 1 + use crate::state::AppState; 2 + use axum::{ 3 + Json, 4 + extract::State, 5 + http::StatusCode, 6 + response::{IntoResponse, Response}, 7 + }; 8 + use serde::Deserialize; 9 + use serde_json::json; 10 + use tracing::error; 11 + 12 + #[derive(Deserialize)] 13 + pub struct DeleteAccountInput { 14 + pub did: String, 15 + } 16 + 17 + pub async fn delete_account( 18 + State(state): State<AppState>, 19 + headers: axum::http::HeaderMap, 20 + Json(input): Json<DeleteAccountInput>, 21 + ) -> Response { 22 + let auth_header = headers.get("Authorization"); 23 + if auth_header.is_none() { 24 + return ( 25 + StatusCode::UNAUTHORIZED, 26 + Json(json!({"error": "AuthenticationRequired"})), 27 + ) 28 + .into_response(); 29 + } 30 + 31 + let did = input.did.trim(); 32 + if did.is_empty() { 33 + return ( 34 + StatusCode::BAD_REQUEST, 35 + Json(json!({"error": "InvalidRequest", "message": "did is required"})), 36 + ) 37 + .into_response(); 38 + } 39 + 40 + let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did) 41 + .fetch_optional(&state.db) 42 + .await; 43 + 44 + let user_id = match user { 45 + Ok(Some(row)) => row.id, 46 + Ok(None) => { 47 + return ( 48 + StatusCode::NOT_FOUND, 49 + Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 50 + ) 51 + .into_response(); 52 + } 53 + Err(e) => { 54 + error!("DB error in delete_account: {:?}", e); 55 + return ( 56 + StatusCode::INTERNAL_SERVER_ERROR, 57 + Json(json!({"error": "InternalError"})), 58 + ) 59 + .into_response(); 60 + } 61 + }; 62 + 63 + let mut tx = match state.db.begin().await { 64 + Ok(tx) => tx, 65 + Err(e) => { 66 + error!("Failed to begin transaction for account deletion: {:?}", e); 67 + return ( 68 + StatusCode::INTERNAL_SERVER_ERROR, 69 + Json(json!({"error": "InternalError"})), 70 + ) 71 + .into_response(); 72 + } 73 + }; 74 + 75 + if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did) 76 + .execute(&mut *tx) 77 + .await 78 + { 79 + error!("Failed to delete session tokens for {}: {:?}", did, e); 80 + return ( 81 + StatusCode::INTERNAL_SERVER_ERROR, 82 + Json(json!({"error": "InternalError", "message": "Failed to delete session tokens"})), 83 + ) 84 + .into_response(); 85 + } 86 + 87 + if let Err(e) = sqlx::query!("DELETE FROM used_refresh_tokens WHERE session_id IN (SELECT id FROM session_tokens WHERE did = $1)", did) 88 + .execute(&mut *tx) 89 + .await 90 + { 91 + error!("Failed to delete used refresh tokens for {}: {:?}", did, e); 92 + } 93 + 94 + if let Err(e) = sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id) 95 + .execute(&mut *tx) 96 + .await 97 + { 98 + error!("Failed to delete records for user {}: {:?}", user_id, e); 99 + return ( 100 + StatusCode::INTERNAL_SERVER_ERROR, 101 + Json(json!({"error": "InternalError", "message": "Failed to delete records"})), 102 + ) 103 + .into_response(); 104 + } 105 + 106 + if let Err(e) = sqlx::query!("DELETE FROM repos WHERE user_id = $1", user_id) 107 + .execute(&mut *tx) 108 + .await 109 + { 110 + error!("Failed to delete repos for user {}: {:?}", user_id, e); 111 + return ( 112 + StatusCode::INTERNAL_SERVER_ERROR, 113 + Json(json!({"error": "InternalError", "message": "Failed to delete repos"})), 114 + ) 115 + .into_response(); 116 + } 117 + 118 + if let Err(e) = sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id) 119 + .execute(&mut *tx) 120 + .await 121 + { 122 + error!("Failed to delete blobs for user {}: {:?}", user_id, e); 123 + return ( 124 + StatusCode::INTERNAL_SERVER_ERROR, 125 + Json(json!({"error": "InternalError", "message": "Failed to delete blobs"})), 126 + ) 127 + .into_response(); 128 + } 129 + 130 + if let Err(e) = sqlx::query!("DELETE FROM app_passwords WHERE user_id = $1", user_id) 131 + .execute(&mut *tx) 132 + .await 133 + { 134 + error!("Failed to delete app passwords for user {}: {:?}", user_id, e); 135 + return ( 136 + StatusCode::INTERNAL_SERVER_ERROR, 137 + Json(json!({"error": "InternalError", "message": "Failed to delete app passwords"})), 138 + ) 139 + .into_response(); 140 + } 141 + 142 + if let Err(e) = sqlx::query!("DELETE FROM invite_code_uses WHERE used_by_user = $1", user_id) 143 + .execute(&mut *tx) 144 + .await 145 + { 146 + error!("Failed to delete invite code uses for user {}: {:?}", user_id, e); 147 + } 148 + 149 + if let Err(e) = sqlx::query!("DELETE FROM invite_codes WHERE created_by_user = $1", user_id) 150 + .execute(&mut *tx) 151 + .await 152 + { 153 + error!("Failed to delete invite codes for user {}: {:?}", user_id, e); 154 + } 155 + 156 + if let Err(e) = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id) 157 + .execute(&mut *tx) 158 + .await 159 + { 160 + error!("Failed to delete user keys for user {}: {:?}", user_id, e); 161 + return ( 162 + StatusCode::INTERNAL_SERVER_ERROR, 163 + Json(json!({"error": "InternalError", "message": "Failed to delete user keys"})), 164 + ) 165 + .into_response(); 166 + } 167 + 168 + if let Err(e) = sqlx::query!("DELETE FROM users WHERE id = $1", user_id) 169 + .execute(&mut *tx) 170 + .await 171 + { 172 + error!("Failed to delete user {}: {:?}", user_id, e); 173 + return ( 174 + StatusCode::INTERNAL_SERVER_ERROR, 175 + Json(json!({"error": "InternalError", "message": "Failed to delete user"})), 176 + ) 177 + .into_response(); 178 + } 179 + 180 + if let Err(e) = tx.commit().await { 181 + error!("Failed to commit account deletion transaction: {:?}", e); 182 + return ( 183 + StatusCode::INTERNAL_SERVER_ERROR, 184 + Json(json!({"error": "InternalError", "message": "Failed to commit deletion"})), 185 + ) 186 + .into_response(); 187 + } 188 + 189 + (StatusCode::OK, Json(json!({}))).into_response() 190 + }
+116
src/api/admin/account/email.rs
··· 1 + use crate::state::AppState; 2 + use axum::{ 3 + Json, 4 + extract::State, 5 + http::StatusCode, 6 + response::{IntoResponse, Response}, 7 + }; 8 + use serde::{Deserialize, Serialize}; 9 + use serde_json::json; 10 + use tracing::{error, warn}; 11 + 12 + #[derive(Deserialize)] 13 + #[serde(rename_all = "camelCase")] 14 + pub struct SendEmailInput { 15 + pub recipient_did: String, 16 + pub sender_did: String, 17 + pub content: String, 18 + pub subject: Option<String>, 19 + pub comment: Option<String>, 20 + } 21 + 22 + #[derive(Serialize)] 23 + pub struct SendEmailOutput { 24 + pub sent: bool, 25 + } 26 + 27 + pub async fn send_email( 28 + State(state): State<AppState>, 29 + headers: axum::http::HeaderMap, 30 + Json(input): Json<SendEmailInput>, 31 + ) -> Response { 32 + let auth_header = headers.get("Authorization"); 33 + if auth_header.is_none() { 34 + return ( 35 + StatusCode::UNAUTHORIZED, 36 + Json(json!({"error": "AuthenticationRequired"})), 37 + ) 38 + .into_response(); 39 + } 40 + 41 + let recipient_did = input.recipient_did.trim(); 42 + let content = input.content.trim(); 43 + 44 + if recipient_did.is_empty() { 45 + return ( 46 + StatusCode::BAD_REQUEST, 47 + Json(json!({"error": "InvalidRequest", "message": "recipientDid is required"})), 48 + ) 49 + .into_response(); 50 + } 51 + 52 + if content.is_empty() { 53 + return ( 54 + StatusCode::BAD_REQUEST, 55 + Json(json!({"error": "InvalidRequest", "message": "content is required"})), 56 + ) 57 + .into_response(); 58 + } 59 + 60 + let user = sqlx::query!( 61 + "SELECT id, email, handle FROM users WHERE did = $1", 62 + recipient_did 63 + ) 64 + .fetch_optional(&state.db) 65 + .await; 66 + 67 + let (user_id, email, handle) = match user { 68 + Ok(Some(row)) => (row.id, row.email, row.handle), 69 + Ok(None) => { 70 + return ( 71 + StatusCode::NOT_FOUND, 72 + Json(json!({"error": "AccountNotFound", "message": "Recipient account not found"})), 73 + ) 74 + .into_response(); 75 + } 76 + Err(e) => { 77 + error!("DB error in send_email: {:?}", e); 78 + return ( 79 + StatusCode::INTERNAL_SERVER_ERROR, 80 + Json(json!({"error": "InternalError"})), 81 + ) 82 + .into_response(); 83 + } 84 + }; 85 + 86 + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 87 + let subject = input 88 + .subject 89 + .clone() 90 + .unwrap_or_else(|| format!("Message from {}", hostname)); 91 + 92 + let notification = crate::notifications::NewNotification::email( 93 + user_id, 94 + crate::notifications::NotificationType::AdminEmail, 95 + email, 96 + subject, 97 + content.to_string(), 98 + ); 99 + 100 + let result = crate::notifications::enqueue_notification(&state.db, notification).await; 101 + 102 + match result { 103 + Ok(_) => { 104 + tracing::info!( 105 + "Admin email queued for {} ({})", 106 + handle, 107 + recipient_did 108 + ); 109 + (StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response() 110 + } 111 + Err(e) => { 112 + warn!("Failed to enqueue admin email: {:?}", e); 113 + (StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response() 114 + } 115 + } 116 + }
+164
src/api/admin/account/info.rs
··· 1 + use crate::state::AppState; 2 + use axum::{ 3 + Json, 4 + extract::{Query, State}, 5 + http::StatusCode, 6 + response::{IntoResponse, Response}, 7 + }; 8 + use serde::{Deserialize, Serialize}; 9 + use serde_json::json; 10 + use tracing::error; 11 + 12 + #[derive(Deserialize)] 13 + pub struct GetAccountInfoParams { 14 + pub did: String, 15 + } 16 + 17 + #[derive(Serialize)] 18 + #[serde(rename_all = "camelCase")] 19 + pub struct AccountInfo { 20 + pub did: String, 21 + pub handle: String, 22 + pub email: Option<String>, 23 + pub indexed_at: String, 24 + pub invite_note: Option<String>, 25 + pub invites_disabled: bool, 26 + pub email_confirmed_at: Option<String>, 27 + pub deactivated_at: Option<String>, 28 + } 29 + 30 + #[derive(Serialize)] 31 + #[serde(rename_all = "camelCase")] 32 + pub struct GetAccountInfosOutput { 33 + pub infos: Vec<AccountInfo>, 34 + } 35 + 36 + pub async fn get_account_info( 37 + State(state): State<AppState>, 38 + headers: axum::http::HeaderMap, 39 + Query(params): Query<GetAccountInfoParams>, 40 + ) -> Response { 41 + let auth_header = headers.get("Authorization"); 42 + if auth_header.is_none() { 43 + return ( 44 + StatusCode::UNAUTHORIZED, 45 + Json(json!({"error": "AuthenticationRequired"})), 46 + ) 47 + .into_response(); 48 + } 49 + 50 + let did = params.did.trim(); 51 + if did.is_empty() { 52 + return ( 53 + StatusCode::BAD_REQUEST, 54 + Json(json!({"error": "InvalidRequest", "message": "did is required"})), 55 + ) 56 + .into_response(); 57 + } 58 + 59 + let result = sqlx::query!( 60 + r#" 61 + SELECT did, handle, email, created_at 62 + FROM users 63 + WHERE did = $1 64 + "#, 65 + did 66 + ) 67 + .fetch_optional(&state.db) 68 + .await; 69 + 70 + match result { 71 + Ok(Some(row)) => { 72 + ( 73 + StatusCode::OK, 74 + Json(AccountInfo { 75 + did: row.did, 76 + handle: row.handle, 77 + email: Some(row.email), 78 + indexed_at: row.created_at.to_rfc3339(), 79 + invite_note: None, 80 + invites_disabled: false, 81 + email_confirmed_at: None, 82 + deactivated_at: None, 83 + }), 84 + ) 85 + .into_response() 86 + } 87 + Ok(None) => ( 88 + StatusCode::NOT_FOUND, 89 + Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 90 + ) 91 + .into_response(), 92 + Err(e) => { 93 + error!("DB error in get_account_info: {:?}", e); 94 + ( 95 + StatusCode::INTERNAL_SERVER_ERROR, 96 + Json(json!({"error": "InternalError"})), 97 + ) 98 + .into_response() 99 + } 100 + } 101 + } 102 + 103 + #[derive(Deserialize)] 104 + pub struct GetAccountInfosParams { 105 + pub dids: String, 106 + } 107 + 108 + pub async fn get_account_infos( 109 + State(state): State<AppState>, 110 + headers: axum::http::HeaderMap, 111 + Query(params): Query<GetAccountInfosParams>, 112 + ) -> Response { 113 + let auth_header = headers.get("Authorization"); 114 + if auth_header.is_none() { 115 + return ( 116 + StatusCode::UNAUTHORIZED, 117 + Json(json!({"error": "AuthenticationRequired"})), 118 + ) 119 + .into_response(); 120 + } 121 + 122 + let dids: Vec<&str> = params.dids.split(',').map(|s| s.trim()).collect(); 123 + if dids.is_empty() { 124 + return ( 125 + StatusCode::BAD_REQUEST, 126 + Json(json!({"error": "InvalidRequest", "message": "dids is required"})), 127 + ) 128 + .into_response(); 129 + } 130 + 131 + let mut infos = Vec::new(); 132 + 133 + for did in dids { 134 + if did.is_empty() { 135 + continue; 136 + } 137 + 138 + let result = sqlx::query!( 139 + r#" 140 + SELECT did, handle, email, created_at 141 + FROM users 142 + WHERE did = $1 143 + "#, 144 + did 145 + ) 146 + .fetch_optional(&state.db) 147 + .await; 148 + 149 + if let Ok(Some(row)) = result { 150 + infos.push(AccountInfo { 151 + did: row.did, 152 + handle: row.handle, 153 + email: Some(row.email), 154 + indexed_at: row.created_at.to_rfc3339(), 155 + invite_note: None, 156 + invites_disabled: false, 157 + email_confirmed_at: None, 158 + deactivated_at: None, 159 + }); 160 + } 161 + } 162 + 163 + (StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response() 164 + }
+15
src/api/admin/account/mod.rs
··· 1 + mod delete; 2 + mod email; 3 + mod info; 4 + mod update; 5 + 6 + pub use delete::{delete_account, DeleteAccountInput}; 7 + pub use email::{send_email, SendEmailInput, SendEmailOutput}; 8 + pub use info::{ 9 + get_account_info, get_account_infos, AccountInfo, GetAccountInfoParams, GetAccountInfosOutput, 10 + GetAccountInfosParams, 11 + }; 12 + pub use update::{ 13 + update_account_email, update_account_handle, update_account_password, UpdateAccountEmailInput, 14 + UpdateAccountHandleInput, UpdateAccountPasswordInput, 15 + };
+216
src/api/admin/account/update.rs
··· 1 + use crate::state::AppState; 2 + use axum::{ 3 + Json, 4 + extract::State, 5 + http::StatusCode, 6 + response::{IntoResponse, Response}, 7 + }; 8 + use serde::Deserialize; 9 + use serde_json::json; 10 + use tracing::error; 11 + 12 + #[derive(Deserialize)] 13 + pub struct UpdateAccountEmailInput { 14 + pub account: String, 15 + pub email: String, 16 + } 17 + 18 + pub async fn update_account_email( 19 + State(state): State<AppState>, 20 + headers: axum::http::HeaderMap, 21 + Json(input): Json<UpdateAccountEmailInput>, 22 + ) -> Response { 23 + let auth_header = headers.get("Authorization"); 24 + if auth_header.is_none() { 25 + return ( 26 + StatusCode::UNAUTHORIZED, 27 + Json(json!({"error": "AuthenticationRequired"})), 28 + ) 29 + .into_response(); 30 + } 31 + 32 + let account = input.account.trim(); 33 + let email = input.email.trim(); 34 + 35 + if account.is_empty() || email.is_empty() { 36 + return ( 37 + StatusCode::BAD_REQUEST, 38 + Json(json!({"error": "InvalidRequest", "message": "account and email are required"})), 39 + ) 40 + .into_response(); 41 + } 42 + 43 + let result = sqlx::query!("UPDATE users SET email = $1 WHERE did = $2", email, account) 44 + .execute(&state.db) 45 + .await; 46 + 47 + match result { 48 + Ok(r) => { 49 + if r.rows_affected() == 0 { 50 + return ( 51 + StatusCode::NOT_FOUND, 52 + Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 53 + ) 54 + .into_response(); 55 + } 56 + (StatusCode::OK, Json(json!({}))).into_response() 57 + } 58 + Err(e) => { 59 + error!("DB error updating email: {:?}", e); 60 + ( 61 + StatusCode::INTERNAL_SERVER_ERROR, 62 + Json(json!({"error": "InternalError"})), 63 + ) 64 + .into_response() 65 + } 66 + } 67 + } 68 + 69 + #[derive(Deserialize)] 70 + pub struct UpdateAccountHandleInput { 71 + pub did: String, 72 + pub handle: String, 73 + } 74 + 75 + pub async fn update_account_handle( 76 + State(state): State<AppState>, 77 + headers: axum::http::HeaderMap, 78 + Json(input): Json<UpdateAccountHandleInput>, 79 + ) -> Response { 80 + let auth_header = headers.get("Authorization"); 81 + if auth_header.is_none() { 82 + return ( 83 + StatusCode::UNAUTHORIZED, 84 + Json(json!({"error": "AuthenticationRequired"})), 85 + ) 86 + .into_response(); 87 + } 88 + 89 + let did = input.did.trim(); 90 + let handle = input.handle.trim(); 91 + 92 + if did.is_empty() || handle.is_empty() { 93 + return ( 94 + StatusCode::BAD_REQUEST, 95 + Json(json!({"error": "InvalidRequest", "message": "did and handle are required"})), 96 + ) 97 + .into_response(); 98 + } 99 + 100 + if !handle 101 + .chars() 102 + .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 103 + { 104 + return ( 105 + StatusCode::BAD_REQUEST, 106 + Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 107 + ) 108 + .into_response(); 109 + } 110 + 111 + let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND did != $2", handle, did) 112 + .fetch_optional(&state.db) 113 + .await; 114 + 115 + if let Ok(Some(_)) = existing { 116 + return ( 117 + StatusCode::BAD_REQUEST, 118 + Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 119 + ) 120 + .into_response(); 121 + } 122 + 123 + let result = sqlx::query!("UPDATE users SET handle = $1 WHERE did = $2", handle, did) 124 + .execute(&state.db) 125 + .await; 126 + 127 + match result { 128 + Ok(r) => { 129 + if r.rows_affected() == 0 { 130 + return ( 131 + StatusCode::NOT_FOUND, 132 + Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 133 + ) 134 + .into_response(); 135 + } 136 + (StatusCode::OK, Json(json!({}))).into_response() 137 + } 138 + Err(e) => { 139 + error!("DB error updating handle: {:?}", e); 140 + ( 141 + StatusCode::INTERNAL_SERVER_ERROR, 142 + Json(json!({"error": "InternalError"})), 143 + ) 144 + .into_response() 145 + } 146 + } 147 + } 148 + 149 + #[derive(Deserialize)] 150 + pub struct UpdateAccountPasswordInput { 151 + pub did: String, 152 + pub password: String, 153 + } 154 + 155 + pub async fn update_account_password( 156 + State(state): State<AppState>, 157 + headers: axum::http::HeaderMap, 158 + Json(input): Json<UpdateAccountPasswordInput>, 159 + ) -> Response { 160 + let auth_header = headers.get("Authorization"); 161 + if auth_header.is_none() { 162 + return ( 163 + StatusCode::UNAUTHORIZED, 164 + Json(json!({"error": "AuthenticationRequired"})), 165 + ) 166 + .into_response(); 167 + } 168 + 169 + let did = input.did.trim(); 170 + let password = input.password.trim(); 171 + 172 + if did.is_empty() || password.is_empty() { 173 + return ( 174 + StatusCode::BAD_REQUEST, 175 + Json(json!({"error": "InvalidRequest", "message": "did and password are required"})), 176 + ) 177 + .into_response(); 178 + } 179 + 180 + let password_hash = match bcrypt::hash(password, bcrypt::DEFAULT_COST) { 181 + Ok(h) => h, 182 + Err(e) => { 183 + error!("Failed to hash password: {:?}", e); 184 + return ( 185 + StatusCode::INTERNAL_SERVER_ERROR, 186 + Json(json!({"error": "InternalError"})), 187 + ) 188 + .into_response(); 189 + } 190 + }; 191 + 192 + let result = sqlx::query!("UPDATE users SET password_hash = $1 WHERE did = $2", password_hash, did) 193 + .execute(&state.db) 194 + .await; 195 + 196 + match result { 197 + Ok(r) => { 198 + if r.rows_affected() == 0 { 199 + return ( 200 + StatusCode::NOT_FOUND, 201 + Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 202 + ) 203 + .into_response(); 204 + } 205 + (StatusCode::OK, Json(json!({}))).into_response() 206 + } 207 + Err(e) => { 208 + error!("DB error updating password: {:?}", e); 209 + ( 210 + StatusCode::INTERNAL_SERVER_ERROR, 211 + Json(json!({"error": "InternalError"})), 212 + ) 213 + .into_response() 214 + } 215 + } 216 + }
+1 -1
src/api/admin/invite.rs
··· 104 104 .into_response(); 105 105 } 106 106 107 - let limit = params.limit.unwrap_or(100).min(500); 107 + let limit = params.limit.unwrap_or(100).clamp(1, 500); 108 108 let sort = params.sort.as_deref().unwrap_or("recent"); 109 109 110 110 let order_clause = match sort {
+68 -14
src/api/admin/status.rs
··· 234 234 Some("com.atproto.admin.defs#repoRef") => { 235 235 let did = input.subject.get("did").and_then(|d| d.as_str()); 236 236 if let Some(did) = did { 237 + let mut tx = match state.db.begin().await { 238 + Ok(tx) => tx, 239 + Err(e) => { 240 + error!("Failed to begin transaction: {:?}", e); 241 + return ( 242 + StatusCode::INTERNAL_SERVER_ERROR, 243 + Json(json!({"error": "InternalError"})), 244 + ) 245 + .into_response(); 246 + } 247 + }; 248 + 237 249 if let Some(takedown) = &input.takedown { 238 250 let takedown_ref = if takedown.apply { 239 251 takedown.r#ref.clone() 240 252 } else { 241 253 None 242 254 }; 243 - let _ = sqlx::query!( 255 + if let Err(e) = sqlx::query!( 244 256 "UPDATE users SET takedown_ref = $1 WHERE did = $2", 245 257 takedown_ref, 246 258 did 247 259 ) 248 - .execute(&state.db) 249 - .await; 260 + .execute(&mut *tx) 261 + .await 262 + { 263 + error!("Failed to update user takedown status for {}: {:?}", did, e); 264 + return ( 265 + StatusCode::INTERNAL_SERVER_ERROR, 266 + Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})), 267 + ) 268 + .into_response(); 269 + } 250 270 } 251 271 252 272 if let Some(deactivated) = &input.deactivated { 253 - if deactivated.apply { 254 - let _ = sqlx::query!( 273 + let result = if deactivated.apply { 274 + sqlx::query!( 255 275 "UPDATE users SET deactivated_at = NOW() WHERE did = $1", 256 276 did 257 277 ) 258 - .execute(&state.db) 259 - .await; 278 + .execute(&mut *tx) 279 + .await 260 280 } else { 261 - let _ = sqlx::query!( 281 + sqlx::query!( 262 282 "UPDATE users SET deactivated_at = NULL WHERE did = $1", 263 283 did 264 284 ) 265 - .execute(&state.db) 266 - .await; 285 + .execute(&mut *tx) 286 + .await 287 + }; 288 + 289 + if let Err(e) = result { 290 + error!("Failed to update user deactivation status for {}: {:?}", did, e); 291 + return ( 292 + StatusCode::INTERNAL_SERVER_ERROR, 293 + Json(json!({"error": "InternalError", "message": "Failed to update deactivation status"})), 294 + ) 295 + .into_response(); 267 296 } 297 + } 298 + 299 + if let Err(e) = tx.commit().await { 300 + error!("Failed to commit transaction: {:?}", e); 301 + return ( 302 + StatusCode::INTERNAL_SERVER_ERROR, 303 + Json(json!({"error": "InternalError"})), 304 + ) 305 + .into_response(); 268 306 } 269 307 270 308 return ( ··· 292 330 } else { 293 331 None 294 332 }; 295 - let _ = sqlx::query!( 333 + if let Err(e) = sqlx::query!( 296 334 "UPDATE records SET takedown_ref = $1 WHERE record_cid = $2", 297 335 takedown_ref, 298 336 uri 299 337 ) 300 338 .execute(&state.db) 301 - .await; 339 + .await 340 + { 341 + error!("Failed to update record takedown status for {}: {:?}", uri, e); 342 + return ( 343 + StatusCode::INTERNAL_SERVER_ERROR, 344 + Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})), 345 + ) 346 + .into_response(); 347 + } 302 348 } 303 349 304 350 return ( ··· 323 369 } else { 324 370 None 325 371 }; 326 - let _ = sqlx::query!( 372 + if let Err(e) = sqlx::query!( 327 373 "UPDATE blobs SET takedown_ref = $1 WHERE cid = $2", 328 374 takedown_ref, 329 375 cid 330 376 ) 331 377 .execute(&state.db) 332 - .await; 378 + .await 379 + { 380 + error!("Failed to update blob takedown status for {}: {:?}", cid, e); 381 + return ( 382 + StatusCode::INTERNAL_SERVER_ERROR, 383 + Json(json!({"error": "InternalError", "message": "Failed to update takedown status"})), 384 + ) 385 + .into_response(); 386 + } 333 387 } 334 388 335 389 return (
+163
src/api/error.rs
··· 1 + use axum::{ 2 + Json, 3 + http::StatusCode, 4 + response::{IntoResponse, Response}, 5 + }; 6 + use serde::Serialize; 7 + 8 + #[derive(Debug, Serialize)] 9 + struct ErrorBody { 10 + error: &'static str, 11 + #[serde(skip_serializing_if = "Option::is_none")] 12 + message: Option<String>, 13 + } 14 + 15 + #[derive(Debug)] 16 + pub enum ApiError { 17 + InternalError, 18 + AuthenticationRequired, 19 + AuthenticationFailed, 20 + AuthenticationFailedMsg(String), 21 + InvalidRequest(String), 22 + InvalidToken, 23 + ExpiredToken, 24 + ExpiredTokenMsg(String), 25 + TokenRequired, 26 + AccountDeactivated, 27 + AccountTakedown, 28 + AccountNotFound, 29 + RepoNotFound, 30 + RepoNotFoundMsg(String), 31 + RecordNotFound, 32 + BlobNotFound, 33 + InvalidHandle, 34 + HandleNotAvailable, 35 + HandleTaken, 36 + InvalidEmail, 37 + EmailTaken, 38 + InvalidInviteCode, 39 + DuplicateCreate, 40 + DuplicateAppPassword, 41 + AppPasswordNotFound, 42 + InvalidSwap, 43 + Forbidden, 44 + InvitesDisabled, 45 + DatabaseError, 46 + UpstreamFailure, 47 + } 48 + 49 + impl ApiError { 50 + fn status_code(&self) -> StatusCode { 51 + match self { 52 + Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => { 53 + StatusCode::INTERNAL_SERVER_ERROR 54 + } 55 + Self::AuthenticationRequired 56 + | Self::AuthenticationFailed 57 + | Self::AuthenticationFailedMsg(_) 58 + | Self::InvalidToken 59 + | Self::ExpiredToken 60 + | Self::ExpiredTokenMsg(_) 61 + | Self::TokenRequired 62 + | Self::AccountDeactivated 63 + | Self::AccountTakedown => StatusCode::UNAUTHORIZED, 64 + Self::Forbidden | Self::InvitesDisabled => StatusCode::FORBIDDEN, 65 + Self::AccountNotFound 66 + | Self::RepoNotFound 67 + | Self::RepoNotFoundMsg(_) 68 + | Self::RecordNotFound 69 + | Self::BlobNotFound 70 + | Self::AppPasswordNotFound => StatusCode::NOT_FOUND, 71 + Self::InvalidRequest(_) 72 + | Self::InvalidHandle 73 + | Self::HandleNotAvailable 74 + | Self::HandleTaken 75 + | Self::InvalidEmail 76 + | Self::EmailTaken 77 + | Self::InvalidInviteCode 78 + | Self::DuplicateCreate 79 + | Self::DuplicateAppPassword 80 + | Self::InvalidSwap => StatusCode::BAD_REQUEST, 81 + } 82 + } 83 + 84 + fn error_name(&self) -> &'static str { 85 + match self { 86 + Self::InternalError | Self::DatabaseError | Self::UpstreamFailure => "InternalError", 87 + Self::AuthenticationRequired => "AuthenticationRequired", 88 + Self::AuthenticationFailed | Self::AuthenticationFailedMsg(_) => "AuthenticationFailed", 89 + Self::InvalidToken => "InvalidToken", 90 + Self::ExpiredToken | Self::ExpiredTokenMsg(_) => "ExpiredToken", 91 + Self::TokenRequired => "TokenRequired", 92 + Self::AccountDeactivated => "AccountDeactivated", 93 + Self::AccountTakedown => "AccountTakedown", 94 + Self::Forbidden => "Forbidden", 95 + Self::InvitesDisabled => "InvitesDisabled", 96 + Self::AccountNotFound => "AccountNotFound", 97 + Self::RepoNotFound | Self::RepoNotFoundMsg(_) => "RepoNotFound", 98 + Self::RecordNotFound => "RecordNotFound", 99 + Self::BlobNotFound => "BlobNotFound", 100 + Self::AppPasswordNotFound => "AppPasswordNotFound", 101 + Self::InvalidRequest(_) => "InvalidRequest", 102 + Self::InvalidHandle => "InvalidHandle", 103 + Self::HandleNotAvailable => "HandleNotAvailable", 104 + Self::HandleTaken => "HandleTaken", 105 + Self::InvalidEmail => "InvalidEmail", 106 + Self::EmailTaken => "EmailTaken", 107 + Self::InvalidInviteCode => "InvalidInviteCode", 108 + Self::DuplicateCreate => "DuplicateCreate", 109 + Self::DuplicateAppPassword => "DuplicateAppPassword", 110 + Self::InvalidSwap => "InvalidSwap", 111 + } 112 + } 113 + 114 + fn message(&self) -> Option<String> { 115 + match self { 116 + Self::AuthenticationFailedMsg(msg) 117 + | Self::ExpiredTokenMsg(msg) 118 + | Self::InvalidRequest(msg) 119 + | Self::RepoNotFoundMsg(msg) => Some(msg.clone()), 120 + _ => None, 121 + } 122 + } 123 + } 124 + 125 + impl IntoResponse for ApiError { 126 + fn into_response(self) -> Response { 127 + let body = ErrorBody { 128 + error: self.error_name(), 129 + message: self.message(), 130 + }; 131 + (self.status_code(), Json(body)).into_response() 132 + } 133 + } 134 + 135 + impl From<sqlx::Error> for ApiError { 136 + fn from(e: sqlx::Error) -> Self { 137 + tracing::error!("Database error: {:?}", e); 138 + Self::DatabaseError 139 + } 140 + } 141 + 142 + impl From<crate::auth::TokenValidationError> for ApiError { 143 + fn from(e: crate::auth::TokenValidationError) -> Self { 144 + match e { 145 + crate::auth::TokenValidationError::AccountDeactivated => Self::AccountDeactivated, 146 + crate::auth::TokenValidationError::AccountTakedown => Self::AccountTakedown, 147 + crate::auth::TokenValidationError::KeyDecryptionFailed => Self::InternalError, 148 + crate::auth::TokenValidationError::AuthenticationFailed => Self::AuthenticationFailed, 149 + } 150 + } 151 + } 152 + 153 + impl From<crate::util::DbLookupError> for ApiError { 154 + fn from(e: crate::util::DbLookupError) -> Self { 155 + match e { 156 + crate::util::DbLookupError::NotFound => Self::AccountNotFound, 157 + crate::util::DbLookupError::DatabaseError(db_err) => { 158 + tracing::error!("Database error: {:?}", db_err); 159 + Self::DatabaseError 160 + } 161 + } 162 + } 163 + }
+9 -1
src/api/identity/account.rs
··· 40 40 State(state): State<AppState>, 41 41 Json(input): Json<CreateAccountInput>, 42 42 ) -> Response { 43 - info!("create_account hit: {}", input.handle); 43 + info!("create_account called"); 44 44 if input.handle.contains('!') || input.handle.contains('@') { 45 45 return ( 46 46 StatusCode::BAD_REQUEST, 47 47 Json( 48 48 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}), 49 49 ), 50 + ) 51 + .into_response(); 52 + } 53 + 54 + if !crate::api::validation::is_valid_email(&input.email) { 55 + return ( 56 + StatusCode::BAD_REQUEST, 57 + Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})), 50 58 ) 51 59 .into_response(); 52 60 }
+36 -72
src/api/identity/did.rs
··· 1 + use crate::api::ApiError; 1 2 use crate::state::AppState; 2 3 use axum::{ 3 4 Json, ··· 56 57 } 57 58 } 58 59 59 - pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { 60 - let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); 60 + pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 61 + let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 61 62 let public_key = secret_key.public_key(); 62 63 let encoded = public_key.to_encoded_point(false); 63 - let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); 64 - let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); 64 + let x = encoded.x().ok_or("Missing x coordinate")?; 65 + let y = encoded.y().ok_or("Missing y coordinate")?; 66 + let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 67 + let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 65 68 66 - json!({ 69 + Ok(json!({ 67 70 "kty": "EC", 68 71 "crv": "secp256k1", 69 - "x": x, 70 - "y": y 71 - }) 72 + "x": x_b64, 73 + "y": y_b64 74 + })) 72 75 } 73 76 74 77 pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { ··· 147 150 } 148 151 }; 149 152 150 - let jwk = get_jwk(&key_bytes); 153 + let jwk = match get_jwk(&key_bytes) { 154 + Ok(j) => j, 155 + Err(e) => { 156 + tracing::error!("Failed to generate JWK: {}", e); 157 + return ( 158 + StatusCode::INTERNAL_SERVER_ERROR, 159 + Json(json!({"error": "InternalError"})), 160 + ) 161 + .into_response(); 162 + } 163 + }; 151 164 152 165 Json(json!({ 153 166 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], ··· 294 307 } 295 308 }; 296 309 297 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 298 - let did = match auth_result { 299 - Ok(ref user) => user.did.clone(), 300 - Err(e) => { 301 - return ( 302 - StatusCode::UNAUTHORIZED, 303 - Json(json!({"error": e})), 304 - ) 305 - .into_response(); 306 - } 310 + let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 311 + Ok(user) => user, 312 + Err(e) => return ApiError::from(e).into_response(), 307 313 }; 308 314 309 - let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", did) 315 + let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did) 310 316 .fetch_optional(&state.db) 311 317 .await 312 318 { 313 319 Ok(Some(row)) => row, 314 - _ => { 315 - return ( 316 - StatusCode::INTERNAL_SERVER_ERROR, 317 - Json(json!({"error": "InternalError"})), 318 - ) 319 - .into_response(); 320 - } 320 + _ => return ApiError::InternalError.into_response(), 321 321 }; 322 - let handle = user.handle; 323 322 324 - let key_bytes = match auth_result.ok().and_then(|u| u.key_bytes) { 323 + let key_bytes = match auth_user.key_bytes { 325 324 Some(kb) => kb, 326 - None => { 327 - return ( 328 - StatusCode::UNAUTHORIZED, 329 - Json(json!({"error": "AuthenticationFailed", "message": "OAuth tokens cannot get DID credentials"})), 330 - ) 331 - .into_response(); 332 - } 325 + None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(), 333 326 }; 334 327 335 328 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); ··· 337 330 338 331 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 339 332 Ok(k) => k, 340 - Err(_) => { 341 - return ( 342 - StatusCode::INTERNAL_SERVER_ERROR, 343 - Json(json!({"error": "InternalError"})), 344 - ) 345 - .into_response(); 346 - } 333 + Err(_) => return ApiError::InternalError.into_response(), 347 334 }; 348 335 349 336 let public_key = secret_key.public_key(); ··· 360 347 StatusCode::OK, 361 348 Json(GetRecommendedDidCredentialsOutput { 362 349 rotation_keys: vec![did_key.clone()], 363 - also_known_as: vec![format!("at://{}", handle)], 350 + also_known_as: vec![format!("at://{}", user.handle)], 364 351 verification_methods: VerificationMethods { atproto: did_key }, 365 352 services: Services { 366 353 atproto_pds: AtprotoPds { ··· 387 374 headers.get("Authorization").and_then(|h| h.to_str().ok()) 388 375 ) { 389 376 Some(t) => t, 390 - None => { 391 - return ( 392 - StatusCode::UNAUTHORIZED, 393 - Json(json!({"error": "AuthenticationRequired"})), 394 - ) 395 - .into_response(); 396 - } 377 + None => return ApiError::AuthenticationRequired.into_response(), 397 378 }; 398 379 399 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 400 - let did = match auth_result { 380 + let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 401 381 Ok(user) => user.did, 402 - Err(e) => { 403 - return ( 404 - StatusCode::UNAUTHORIZED, 405 - Json(json!({"error": e})), 406 - ) 407 - .into_response(); 408 - } 382 + Err(e) => return ApiError::from(e).into_response(), 409 383 }; 410 384 411 385 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) ··· 413 387 .await 414 388 { 415 389 Ok(Some(id)) => id, 416 - _ => { 417 - return ( 418 - StatusCode::INTERNAL_SERVER_ERROR, 419 - Json(json!({"error": "InternalError"})), 420 - ) 421 - .into_response(); 422 - } 390 + _ => return ApiError::InternalError.into_response(), 423 391 }; 424 392 425 393 let new_handle = input.handle.trim(); 426 394 if new_handle.is_empty() { 427 - return ( 428 - StatusCode::BAD_REQUEST, 429 - Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 430 - ) 431 - .into_response(); 395 + return ApiError::InvalidRequest("handle is required".into()).into_response(); 432 396 } 433 397 434 398 if !new_handle
-618
src/api/identity/plc.rs
··· 1 - use crate::plc::{ 2 - create_update_op, sign_operation, signing_key_to_did_key, validate_plc_operation, 3 - PlcClient, PlcError, PlcService, 4 - }; 5 - use crate::state::AppState; 6 - use axum::{ 7 - extract::State, 8 - http::StatusCode, 9 - response::{IntoResponse, Response}, 10 - Json, 11 - }; 12 - use chrono::{Duration, Utc}; 13 - use k256::ecdsa::SigningKey; 14 - use rand::Rng; 15 - use serde::{Deserialize, Serialize}; 16 - use serde_json::{json, Value}; 17 - use std::collections::HashMap; 18 - use tracing::{error, info, warn}; 19 - 20 - fn generate_plc_token() -> String { 21 - let mut rng = rand::thread_rng(); 22 - let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect(); 23 - let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect(); 24 - let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect(); 25 - format!("{}-{}", part1, part2) 26 - } 27 - 28 - pub async fn request_plc_operation_signature( 29 - State(state): State<AppState>, 30 - headers: axum::http::HeaderMap, 31 - ) -> Response { 32 - let token = match crate::auth::extract_bearer_token_from_header( 33 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 34 - ) { 35 - Some(t) => t, 36 - None => { 37 - return ( 38 - StatusCode::UNAUTHORIZED, 39 - Json(json!({"error": "AuthenticationRequired"})), 40 - ) 41 - .into_response(); 42 - } 43 - }; 44 - 45 - let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 46 - Ok(user) => user, 47 - Err(e) => { 48 - return ( 49 - StatusCode::UNAUTHORIZED, 50 - Json(json!({"error": "AuthenticationFailed", "message": e})), 51 - ) 52 - .into_response(); 53 - } 54 - }; 55 - 56 - let did = &auth_user.did; 57 - 58 - let user = match sqlx::query!( 59 - "SELECT id FROM users WHERE did = $1", 60 - did 61 - ) 62 - .fetch_optional(&state.db) 63 - .await 64 - { 65 - Ok(Some(row)) => row, 66 - Ok(None) => { 67 - return ( 68 - StatusCode::NOT_FOUND, 69 - Json(json!({"error": "AccountNotFound"})), 70 - ) 71 - .into_response(); 72 - } 73 - Err(e) => { 74 - error!("DB error: {:?}", e); 75 - return ( 76 - StatusCode::INTERNAL_SERVER_ERROR, 77 - Json(json!({"error": "InternalError"})), 78 - ) 79 - .into_response(); 80 - } 81 - }; 82 - 83 - let _ = sqlx::query!( 84 - "DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()", 85 - user.id 86 - ) 87 - .execute(&state.db) 88 - .await; 89 - 90 - let plc_token = generate_plc_token(); 91 - let expires_at = Utc::now() + Duration::minutes(10); 92 - 93 - if let Err(e) = sqlx::query!( 94 - r#" 95 - INSERT INTO plc_operation_tokens (user_id, token, expires_at) 96 - VALUES ($1, $2, $3) 97 - "#, 98 - user.id, 99 - plc_token, 100 - expires_at 101 - ) 102 - .execute(&state.db) 103 - .await 104 - { 105 - error!("Failed to create PLC token: {:?}", e); 106 - return ( 107 - StatusCode::INTERNAL_SERVER_ERROR, 108 - Json(json!({"error": "InternalError"})), 109 - ) 110 - .into_response(); 111 - } 112 - 113 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 114 - 115 - if let Err(e) = crate::notifications::enqueue_plc_operation( 116 - &state.db, 117 - user.id, 118 - &plc_token, 119 - &hostname, 120 - ) 121 - .await 122 - { 123 - warn!("Failed to enqueue PLC operation notification: {:?}", e); 124 - } 125 - 126 - info!("PLC operation signature requested for user {}", did); 127 - 128 - (StatusCode::OK, Json(json!({}))).into_response() 129 - } 130 - 131 - #[derive(Debug, Deserialize)] 132 - #[serde(rename_all = "camelCase")] 133 - pub struct SignPlcOperationInput { 134 - pub token: Option<String>, 135 - pub rotation_keys: Option<Vec<String>>, 136 - pub also_known_as: Option<Vec<String>>, 137 - pub verification_methods: Option<HashMap<String, String>>, 138 - pub services: Option<HashMap<String, ServiceInput>>, 139 - } 140 - 141 - #[derive(Debug, Deserialize, Clone)] 142 - pub struct ServiceInput { 143 - #[serde(rename = "type")] 144 - pub service_type: String, 145 - pub endpoint: String, 146 - } 147 - 148 - #[derive(Debug, Serialize)] 149 - pub struct SignPlcOperationOutput { 150 - pub operation: Value, 151 - } 152 - 153 - pub async fn sign_plc_operation( 154 - State(state): State<AppState>, 155 - headers: axum::http::HeaderMap, 156 - Json(input): Json<SignPlcOperationInput>, 157 - ) -> Response { 158 - let bearer = match crate::auth::extract_bearer_token_from_header( 159 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 160 - ) { 161 - Some(t) => t, 162 - None => { 163 - return ( 164 - StatusCode::UNAUTHORIZED, 165 - Json(json!({"error": "AuthenticationRequired"})), 166 - ) 167 - .into_response(); 168 - } 169 - }; 170 - 171 - let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await { 172 - Ok(user) => user, 173 - Err(e) => { 174 - return ( 175 - StatusCode::UNAUTHORIZED, 176 - Json(json!({"error": "AuthenticationFailed", "message": e})), 177 - ) 178 - .into_response(); 179 - } 180 - }; 181 - 182 - let did = &auth_user.did; 183 - 184 - let token = match &input.token { 185 - Some(t) => t, 186 - None => { 187 - return ( 188 - StatusCode::BAD_REQUEST, 189 - Json(json!({ 190 - "error": "InvalidRequest", 191 - "message": "Email confirmation token required to sign PLC operations" 192 - })), 193 - ) 194 - .into_response(); 195 - } 196 - }; 197 - 198 - let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did) 199 - .fetch_optional(&state.db) 200 - .await 201 - { 202 - Ok(Some(row)) => row, 203 - _ => { 204 - return ( 205 - StatusCode::NOT_FOUND, 206 - Json(json!({"error": "AccountNotFound"})), 207 - ) 208 - .into_response(); 209 - } 210 - }; 211 - 212 - let token_row = match sqlx::query!( 213 - "SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2", 214 - user.id, 215 - token 216 - ) 217 - .fetch_optional(&state.db) 218 - .await 219 - { 220 - Ok(Some(row)) => row, 221 - Ok(None) => { 222 - return ( 223 - StatusCode::BAD_REQUEST, 224 - Json(json!({ 225 - "error": "InvalidToken", 226 - "message": "Invalid or expired token" 227 - })), 228 - ) 229 - .into_response(); 230 - } 231 - Err(e) => { 232 - error!("DB error: {:?}", e); 233 - return ( 234 - StatusCode::INTERNAL_SERVER_ERROR, 235 - Json(json!({"error": "InternalError"})), 236 - ) 237 - .into_response(); 238 - } 239 - }; 240 - 241 - if Utc::now() > token_row.expires_at { 242 - let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id) 243 - .execute(&state.db) 244 - .await; 245 - return ( 246 - StatusCode::BAD_REQUEST, 247 - Json(json!({ 248 - "error": "ExpiredToken", 249 - "message": "Token has expired" 250 - })), 251 - ) 252 - .into_response(); 253 - } 254 - 255 - let key_row = match sqlx::query!( 256 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 257 - user.id 258 - ) 259 - .fetch_optional(&state.db) 260 - .await 261 - { 262 - Ok(Some(row)) => row, 263 - _ => { 264 - return ( 265 - StatusCode::INTERNAL_SERVER_ERROR, 266 - Json(json!({"error": "InternalError", "message": "User signing key not found"})), 267 - ) 268 - .into_response(); 269 - } 270 - }; 271 - 272 - let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 273 - { 274 - Ok(k) => k, 275 - Err(e) => { 276 - error!("Failed to decrypt user key: {}", e); 277 - return ( 278 - StatusCode::INTERNAL_SERVER_ERROR, 279 - Json(json!({"error": "InternalError"})), 280 - ) 281 - .into_response(); 282 - } 283 - }; 284 - 285 - let signing_key = match SigningKey::from_slice(&key_bytes) { 286 - Ok(k) => k, 287 - Err(e) => { 288 - error!("Failed to create signing key: {:?}", e); 289 - return ( 290 - StatusCode::INTERNAL_SERVER_ERROR, 291 - Json(json!({"error": "InternalError"})), 292 - ) 293 - .into_response(); 294 - } 295 - }; 296 - 297 - let plc_client = PlcClient::new(None); 298 - let last_op = match plc_client.get_last_op(did).await { 299 - Ok(op) => op, 300 - Err(PlcError::NotFound) => { 301 - return ( 302 - StatusCode::NOT_FOUND, 303 - Json(json!({ 304 - "error": "NotFound", 305 - "message": "DID not found in PLC directory" 306 - })), 307 - ) 308 - .into_response(); 309 - } 310 - Err(e) => { 311 - error!("Failed to fetch PLC operation: {:?}", e); 312 - return ( 313 - StatusCode::BAD_GATEWAY, 314 - Json(json!({ 315 - "error": "UpstreamError", 316 - "message": "Failed to communicate with PLC directory" 317 - })), 318 - ) 319 - .into_response(); 320 - } 321 - }; 322 - 323 - if last_op.is_tombstone() { 324 - return ( 325 - StatusCode::BAD_REQUEST, 326 - Json(json!({ 327 - "error": "InvalidRequest", 328 - "message": "DID is tombstoned" 329 - })), 330 - ) 331 - .into_response(); 332 - } 333 - 334 - let services = input.services.map(|s| { 335 - s.into_iter() 336 - .map(|(k, v)| { 337 - ( 338 - k, 339 - PlcService { 340 - service_type: v.service_type, 341 - endpoint: v.endpoint, 342 - }, 343 - ) 344 - }) 345 - .collect() 346 - }); 347 - 348 - let unsigned_op = match create_update_op( 349 - &last_op, 350 - input.rotation_keys, 351 - input.verification_methods, 352 - input.also_known_as, 353 - services, 354 - ) { 355 - Ok(op) => op, 356 - Err(PlcError::Tombstoned) => { 357 - return ( 358 - StatusCode::BAD_REQUEST, 359 - Json(json!({ 360 - "error": "InvalidRequest", 361 - "message": "Cannot update tombstoned DID" 362 - })), 363 - ) 364 - .into_response(); 365 - } 366 - Err(e) => { 367 - error!("Failed to create PLC operation: {:?}", e); 368 - return ( 369 - StatusCode::INTERNAL_SERVER_ERROR, 370 - Json(json!({"error": "InternalError"})), 371 - ) 372 - .into_response(); 373 - } 374 - }; 375 - 376 - let signed_op = match sign_operation(&unsigned_op, &signing_key) { 377 - Ok(op) => op, 378 - Err(e) => { 379 - error!("Failed to sign PLC operation: {:?}", e); 380 - return ( 381 - StatusCode::INTERNAL_SERVER_ERROR, 382 - Json(json!({"error": "InternalError"})), 383 - ) 384 - .into_response(); 385 - } 386 - }; 387 - 388 - let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id) 389 - .execute(&state.db) 390 - .await; 391 - 392 - info!("Signed PLC operation for user {}", did); 393 - 394 - ( 395 - StatusCode::OK, 396 - Json(SignPlcOperationOutput { 397 - operation: signed_op, 398 - }), 399 - ) 400 - .into_response() 401 - } 402 - 403 - #[derive(Debug, Deserialize)] 404 - pub struct SubmitPlcOperationInput { 405 - pub operation: Value, 406 - } 407 - 408 - pub async fn submit_plc_operation( 409 - State(state): State<AppState>, 410 - headers: axum::http::HeaderMap, 411 - Json(input): Json<SubmitPlcOperationInput>, 412 - ) -> Response { 413 - let bearer = match crate::auth::extract_bearer_token_from_header( 414 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 415 - ) { 416 - Some(t) => t, 417 - None => { 418 - return ( 419 - StatusCode::UNAUTHORIZED, 420 - Json(json!({"error": "AuthenticationRequired"})), 421 - ) 422 - .into_response(); 423 - } 424 - }; 425 - 426 - let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await { 427 - Ok(user) => user, 428 - Err(e) => { 429 - return ( 430 - StatusCode::UNAUTHORIZED, 431 - Json(json!({"error": "AuthenticationFailed", "message": e})), 432 - ) 433 - .into_response(); 434 - } 435 - }; 436 - 437 - let did = &auth_user.did; 438 - 439 - if let Err(e) = validate_plc_operation(&input.operation) { 440 - return ( 441 - StatusCode::BAD_REQUEST, 442 - Json(json!({ 443 - "error": "InvalidRequest", 444 - "message": format!("Invalid operation: {}", e) 445 - })), 446 - ) 447 - .into_response(); 448 - } 449 - 450 - let op = &input.operation; 451 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 452 - let public_url = format!("https://{}", hostname); 453 - 454 - let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did) 455 - .fetch_optional(&state.db) 456 - .await 457 - { 458 - Ok(Some(row)) => row, 459 - _ => { 460 - return ( 461 - StatusCode::NOT_FOUND, 462 - Json(json!({"error": "AccountNotFound"})), 463 - ) 464 - .into_response(); 465 - } 466 - }; 467 - 468 - let key_row = match sqlx::query!( 469 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 470 - user.id 471 - ) 472 - .fetch_optional(&state.db) 473 - .await 474 - { 475 - Ok(Some(row)) => row, 476 - _ => { 477 - return ( 478 - StatusCode::INTERNAL_SERVER_ERROR, 479 - Json(json!({"error": "InternalError", "message": "User signing key not found"})), 480 - ) 481 - .into_response(); 482 - } 483 - }; 484 - 485 - let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 486 - { 487 - Ok(k) => k, 488 - Err(e) => { 489 - error!("Failed to decrypt user key: {}", e); 490 - return ( 491 - StatusCode::INTERNAL_SERVER_ERROR, 492 - Json(json!({"error": "InternalError"})), 493 - ) 494 - .into_response(); 495 - } 496 - }; 497 - 498 - let signing_key = match SigningKey::from_slice(&key_bytes) { 499 - Ok(k) => k, 500 - Err(e) => { 501 - error!("Failed to create signing key: {:?}", e); 502 - return ( 503 - StatusCode::INTERNAL_SERVER_ERROR, 504 - Json(json!({"error": "InternalError"})), 505 - ) 506 - .into_response(); 507 - } 508 - }; 509 - 510 - let user_did_key = signing_key_to_did_key(&signing_key); 511 - 512 - if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) { 513 - let server_rotation_key = 514 - std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); 515 - 516 - let has_server_key = rotation_keys 517 - .iter() 518 - .any(|k| k.as_str() == Some(&server_rotation_key)); 519 - 520 - if !has_server_key { 521 - return ( 522 - StatusCode::BAD_REQUEST, 523 - Json(json!({ 524 - "error": "InvalidRequest", 525 - "message": "Rotation keys do not include server's rotation key" 526 - })), 527 - ) 528 - .into_response(); 529 - } 530 - } 531 - 532 - if let Some(services) = op.get("services").and_then(|v| v.as_object()) { 533 - if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) { 534 - let service_type = pds.get("type").and_then(|v| v.as_str()); 535 - let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 536 - 537 - if service_type != Some("AtprotoPersonalDataServer") { 538 - return ( 539 - StatusCode::BAD_REQUEST, 540 - Json(json!({ 541 - "error": "InvalidRequest", 542 - "message": "Incorrect type on atproto_pds service" 543 - })), 544 - ) 545 - .into_response(); 546 - } 547 - 548 - if endpoint != Some(&public_url) { 549 - return ( 550 - StatusCode::BAD_REQUEST, 551 - Json(json!({ 552 - "error": "InvalidRequest", 553 - "message": "Incorrect endpoint on atproto_pds service" 554 - })), 555 - ) 556 - .into_response(); 557 - } 558 - } 559 - } 560 - 561 - if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) { 562 - if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) { 563 - if atproto_key != user_did_key { 564 - return ( 565 - StatusCode::BAD_REQUEST, 566 - Json(json!({ 567 - "error": "InvalidRequest", 568 - "message": "Incorrect signing key in verificationMethods" 569 - })), 570 - ) 571 - .into_response(); 572 - } 573 - } 574 - } 575 - 576 - if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) { 577 - let expected_handle = format!("at://{}", user.handle); 578 - let first_aka = also_known_as.first().and_then(|v| v.as_str()); 579 - 580 - if first_aka != Some(&expected_handle) { 581 - return ( 582 - StatusCode::BAD_REQUEST, 583 - Json(json!({ 584 - "error": "InvalidRequest", 585 - "message": "Incorrect handle in alsoKnownAs" 586 - })), 587 - ) 588 - .into_response(); 589 - } 590 - } 591 - 592 - let plc_client = PlcClient::new(None); 593 - if let Err(e) = plc_client.send_operation(did, &input.operation).await { 594 - error!("Failed to submit PLC operation: {:?}", e); 595 - return ( 596 - StatusCode::BAD_GATEWAY, 597 - Json(json!({ 598 - "error": "UpstreamError", 599 - "message": format!("Failed to submit to PLC directory: {}", e) 600 - })), 601 - ) 602 - .into_response(); 603 - } 604 - 605 - if let Err(e) = sqlx::query!( 606 - "INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')", 607 - did 608 - ) 609 - .execute(&state.db) 610 - .await 611 - { 612 - warn!("Failed to sequence identity event: {:?}", e); 613 - } 614 - 615 - info!("Submitted PLC operation for user {}", did); 616 - 617 - (StatusCode::OK, Json(json!({}))).into_response() 618 - }
+7
src/api/identity/plc/mod.rs
··· 1 + mod request; 2 + mod sign; 3 + mod submit; 4 + 5 + pub use request::request_plc_operation_signature; 6 + pub use sign::{sign_plc_operation, ServiceInput, SignPlcOperationInput, SignPlcOperationOutput}; 7 + pub use submit::{submit_plc_operation, SubmitPlcOperationInput};
+91
src/api/identity/plc/request.rs
··· 1 + use crate::api::ApiError; 2 + use crate::state::AppState; 3 + use axum::{ 4 + extract::State, 5 + http::StatusCode, 6 + response::{IntoResponse, Response}, 7 + Json, 8 + }; 9 + use chrono::{Duration, Utc}; 10 + use serde_json::json; 11 + use tracing::{error, info, warn}; 12 + 13 + fn generate_plc_token() -> String { 14 + crate::util::generate_token_code() 15 + } 16 + 17 + pub async fn request_plc_operation_signature( 18 + State(state): State<AppState>, 19 + headers: axum::http::HeaderMap, 20 + ) -> Response { 21 + let token = match crate::auth::extract_bearer_token_from_header( 22 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 23 + ) { 24 + Some(t) => t, 25 + None => return ApiError::AuthenticationRequired.into_response(), 26 + }; 27 + 28 + let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 29 + Ok(user) => user, 30 + Err(e) => return ApiError::from(e).into_response(), 31 + }; 32 + 33 + let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth_user.did) 34 + .fetch_optional(&state.db) 35 + .await 36 + { 37 + Ok(Some(row)) => row, 38 + Ok(None) => return ApiError::AccountNotFound.into_response(), 39 + Err(e) => { 40 + error!("DB error: {:?}", e); 41 + return ApiError::InternalError.into_response(); 42 + } 43 + }; 44 + 45 + let _ = sqlx::query!( 46 + "DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()", 47 + user.id 48 + ) 49 + .execute(&state.db) 50 + .await; 51 + 52 + let plc_token = generate_plc_token(); 53 + let expires_at = Utc::now() + Duration::minutes(10); 54 + 55 + if let Err(e) = sqlx::query!( 56 + r#" 57 + INSERT INTO plc_operation_tokens (user_id, token, expires_at) 58 + VALUES ($1, $2, $3) 59 + "#, 60 + user.id, 61 + plc_token, 62 + expires_at 63 + ) 64 + .execute(&state.db) 65 + .await 66 + { 67 + error!("Failed to create PLC token: {:?}", e); 68 + return ( 69 + StatusCode::INTERNAL_SERVER_ERROR, 70 + Json(json!({"error": "InternalError"})), 71 + ) 72 + .into_response(); 73 + } 74 + 75 + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 76 + 77 + if let Err(e) = crate::notifications::enqueue_plc_operation( 78 + &state.db, 79 + user.id, 80 + &plc_token, 81 + &hostname, 82 + ) 83 + .await 84 + { 85 + warn!("Failed to enqueue PLC operation notification: {:?}", e); 86 + } 87 + 88 + info!("PLC operation signature requested for user {}", auth_user.did); 89 + 90 + (StatusCode::OK, Json(json!({}))).into_response() 91 + }
+272
src/api/identity/plc/sign.rs
··· 1 + use crate::api::ApiError; 2 + use crate::plc::{ 3 + create_update_op, sign_operation, PlcClient, PlcError, PlcService, 4 + }; 5 + use crate::state::AppState; 6 + use axum::{ 7 + extract::State, 8 + http::StatusCode, 9 + response::{IntoResponse, Response}, 10 + Json, 11 + }; 12 + use chrono::Utc; 13 + use k256::ecdsa::SigningKey; 14 + use serde::{Deserialize, Serialize}; 15 + use serde_json::{json, Value}; 16 + use std::collections::HashMap; 17 + use tracing::{error, info}; 18 + 19 + #[derive(Debug, Deserialize)] 20 + #[serde(rename_all = "camelCase")] 21 + pub struct SignPlcOperationInput { 22 + pub token: Option<String>, 23 + pub rotation_keys: Option<Vec<String>>, 24 + pub also_known_as: Option<Vec<String>>, 25 + pub verification_methods: Option<HashMap<String, String>>, 26 + pub services: Option<HashMap<String, ServiceInput>>, 27 + } 28 + 29 + #[derive(Debug, Deserialize, Clone)] 30 + pub struct ServiceInput { 31 + #[serde(rename = "type")] 32 + pub service_type: String, 33 + pub endpoint: String, 34 + } 35 + 36 + #[derive(Debug, Serialize)] 37 + pub struct SignPlcOperationOutput { 38 + pub operation: Value, 39 + } 40 + 41 + pub async fn sign_plc_operation( 42 + State(state): State<AppState>, 43 + headers: axum::http::HeaderMap, 44 + Json(input): Json<SignPlcOperationInput>, 45 + ) -> Response { 46 + let bearer = match crate::auth::extract_bearer_token_from_header( 47 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 48 + ) { 49 + Some(t) => t, 50 + None => return ApiError::AuthenticationRequired.into_response(), 51 + }; 52 + 53 + let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await { 54 + Ok(user) => user, 55 + Err(e) => return ApiError::from(e).into_response(), 56 + }; 57 + 58 + let did = &auth_user.did; 59 + 60 + let token = match &input.token { 61 + Some(t) => t, 62 + None => { 63 + return ApiError::InvalidRequest( 64 + "Email confirmation token required to sign PLC operations".into() 65 + ).into_response(); 66 + } 67 + }; 68 + 69 + let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did) 70 + .fetch_optional(&state.db) 71 + .await 72 + { 73 + Ok(Some(row)) => row, 74 + _ => { 75 + return ( 76 + StatusCode::NOT_FOUND, 77 + Json(json!({"error": "AccountNotFound"})), 78 + ) 79 + .into_response(); 80 + } 81 + }; 82 + 83 + let token_row = match sqlx::query!( 84 + "SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2", 85 + user.id, 86 + token 87 + ) 88 + .fetch_optional(&state.db) 89 + .await 90 + { 91 + Ok(Some(row)) => row, 92 + Ok(None) => { 93 + return ( 94 + StatusCode::BAD_REQUEST, 95 + Json(json!({ 96 + "error": "InvalidToken", 97 + "message": "Invalid or expired token" 98 + })), 99 + ) 100 + .into_response(); 101 + } 102 + Err(e) => { 103 + error!("DB error: {:?}", e); 104 + return ( 105 + StatusCode::INTERNAL_SERVER_ERROR, 106 + Json(json!({"error": "InternalError"})), 107 + ) 108 + .into_response(); 109 + } 110 + }; 111 + 112 + if Utc::now() > token_row.expires_at { 113 + let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id) 114 + .execute(&state.db) 115 + .await; 116 + return ( 117 + StatusCode::BAD_REQUEST, 118 + Json(json!({ 119 + "error": "ExpiredToken", 120 + "message": "Token has expired" 121 + })), 122 + ) 123 + .into_response(); 124 + } 125 + 126 + let key_row = match sqlx::query!( 127 + "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 128 + user.id 129 + ) 130 + .fetch_optional(&state.db) 131 + .await 132 + { 133 + Ok(Some(row)) => row, 134 + _ => { 135 + return ( 136 + StatusCode::INTERNAL_SERVER_ERROR, 137 + Json(json!({"error": "InternalError", "message": "User signing key not found"})), 138 + ) 139 + .into_response(); 140 + } 141 + }; 142 + 143 + let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 144 + { 145 + Ok(k) => k, 146 + Err(e) => { 147 + error!("Failed to decrypt user key: {}", e); 148 + return ( 149 + StatusCode::INTERNAL_SERVER_ERROR, 150 + Json(json!({"error": "InternalError"})), 151 + ) 152 + .into_response(); 153 + } 154 + }; 155 + 156 + let signing_key = match SigningKey::from_slice(&key_bytes) { 157 + Ok(k) => k, 158 + Err(e) => { 159 + error!("Failed to create signing key: {:?}", e); 160 + return ( 161 + StatusCode::INTERNAL_SERVER_ERROR, 162 + Json(json!({"error": "InternalError"})), 163 + ) 164 + .into_response(); 165 + } 166 + }; 167 + 168 + let plc_client = PlcClient::new(None); 169 + let last_op = match plc_client.get_last_op(did).await { 170 + Ok(op) => op, 171 + Err(PlcError::NotFound) => { 172 + return ( 173 + StatusCode::NOT_FOUND, 174 + Json(json!({ 175 + "error": "NotFound", 176 + "message": "DID not found in PLC directory" 177 + })), 178 + ) 179 + .into_response(); 180 + } 181 + Err(e) => { 182 + error!("Failed to fetch PLC operation: {:?}", e); 183 + return ( 184 + StatusCode::BAD_GATEWAY, 185 + Json(json!({ 186 + "error": "UpstreamError", 187 + "message": "Failed to communicate with PLC directory" 188 + })), 189 + ) 190 + .into_response(); 191 + } 192 + }; 193 + 194 + if last_op.is_tombstone() { 195 + return ( 196 + StatusCode::BAD_REQUEST, 197 + Json(json!({ 198 + "error": "InvalidRequest", 199 + "message": "DID is tombstoned" 200 + })), 201 + ) 202 + .into_response(); 203 + } 204 + 205 + let services = input.services.map(|s| { 206 + s.into_iter() 207 + .map(|(k, v)| { 208 + ( 209 + k, 210 + PlcService { 211 + service_type: v.service_type, 212 + endpoint: v.endpoint, 213 + }, 214 + ) 215 + }) 216 + .collect() 217 + }); 218 + 219 + let unsigned_op = match create_update_op( 220 + &last_op, 221 + input.rotation_keys, 222 + input.verification_methods, 223 + input.also_known_as, 224 + services, 225 + ) { 226 + Ok(op) => op, 227 + Err(PlcError::Tombstoned) => { 228 + return ( 229 + StatusCode::BAD_REQUEST, 230 + Json(json!({ 231 + "error": "InvalidRequest", 232 + "message": "Cannot update tombstoned DID" 233 + })), 234 + ) 235 + .into_response(); 236 + } 237 + Err(e) => { 238 + error!("Failed to create PLC operation: {:?}", e); 239 + return ( 240 + StatusCode::INTERNAL_SERVER_ERROR, 241 + Json(json!({"error": "InternalError"})), 242 + ) 243 + .into_response(); 244 + } 245 + }; 246 + 247 + let signed_op = match sign_operation(&unsigned_op, &signing_key) { 248 + Ok(op) => op, 249 + Err(e) => { 250 + error!("Failed to sign PLC operation: {:?}", e); 251 + return ( 252 + StatusCode::INTERNAL_SERVER_ERROR, 253 + Json(json!({"error": "InternalError"})), 254 + ) 255 + .into_response(); 256 + } 257 + }; 258 + 259 + let _ = sqlx::query!("DELETE FROM plc_operation_tokens WHERE id = $1", token_row.id) 260 + .execute(&state.db) 261 + .await; 262 + 263 + info!("Signed PLC operation for user {}", did); 264 + 265 + ( 266 + StatusCode::OK, 267 + Json(SignPlcOperationOutput { 268 + operation: signed_op, 269 + }), 270 + ) 271 + .into_response() 272 + }
+211
src/api/identity/plc/submit.rs
··· 1 + use crate::api::ApiError; 2 + use crate::plc::{signing_key_to_did_key, validate_plc_operation, PlcClient}; 3 + use crate::state::AppState; 4 + use axum::{ 5 + extract::State, 6 + http::StatusCode, 7 + response::{IntoResponse, Response}, 8 + Json, 9 + }; 10 + use k256::ecdsa::SigningKey; 11 + use serde::Deserialize; 12 + use serde_json::{json, Value}; 13 + use tracing::{error, info, warn}; 14 + 15 + #[derive(Debug, Deserialize)] 16 + pub struct SubmitPlcOperationInput { 17 + pub operation: Value, 18 + } 19 + 20 + pub async fn submit_plc_operation( 21 + State(state): State<AppState>, 22 + headers: axum::http::HeaderMap, 23 + Json(input): Json<SubmitPlcOperationInput>, 24 + ) -> Response { 25 + let bearer = match crate::auth::extract_bearer_token_from_header( 26 + headers.get("Authorization").and_then(|h| h.to_str().ok()), 27 + ) { 28 + Some(t) => t, 29 + None => return ApiError::AuthenticationRequired.into_response(), 30 + }; 31 + 32 + let auth_user = match crate::auth::validate_bearer_token(&state.db, &bearer).await { 33 + Ok(user) => user, 34 + Err(e) => return ApiError::from(e).into_response(), 35 + }; 36 + 37 + let did = &auth_user.did; 38 + 39 + if let Err(e) = validate_plc_operation(&input.operation) { 40 + return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response(); 41 + } 42 + 43 + let op = &input.operation; 44 + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 45 + let public_url = format!("https://{}", hostname); 46 + 47 + let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did) 48 + .fetch_optional(&state.db) 49 + .await 50 + { 51 + Ok(Some(row)) => row, 52 + _ => { 53 + return ( 54 + StatusCode::NOT_FOUND, 55 + Json(json!({"error": "AccountNotFound"})), 56 + ) 57 + .into_response(); 58 + } 59 + }; 60 + 61 + let key_row = match sqlx::query!( 62 + "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 63 + user.id 64 + ) 65 + .fetch_optional(&state.db) 66 + .await 67 + { 68 + Ok(Some(row)) => row, 69 + _ => { 70 + return ( 71 + StatusCode::INTERNAL_SERVER_ERROR, 72 + Json(json!({"error": "InternalError", "message": "User signing key not found"})), 73 + ) 74 + .into_response(); 75 + } 76 + }; 77 + 78 + let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 79 + { 80 + Ok(k) => k, 81 + Err(e) => { 82 + error!("Failed to decrypt user key: {}", e); 83 + return ( 84 + StatusCode::INTERNAL_SERVER_ERROR, 85 + Json(json!({"error": "InternalError"})), 86 + ) 87 + .into_response(); 88 + } 89 + }; 90 + 91 + let signing_key = match SigningKey::from_slice(&key_bytes) { 92 + Ok(k) => k, 93 + Err(e) => { 94 + error!("Failed to create signing key: {:?}", e); 95 + return ( 96 + StatusCode::INTERNAL_SERVER_ERROR, 97 + Json(json!({"error": "InternalError"})), 98 + ) 99 + .into_response(); 100 + } 101 + }; 102 + 103 + let user_did_key = signing_key_to_did_key(&signing_key); 104 + 105 + if let Some(rotation_keys) = op.get("rotationKeys").and_then(|v| v.as_array()) { 106 + let server_rotation_key = 107 + std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); 108 + 109 + let has_server_key = rotation_keys 110 + .iter() 111 + .any(|k| k.as_str() == Some(&server_rotation_key)); 112 + 113 + if !has_server_key { 114 + return ( 115 + StatusCode::BAD_REQUEST, 116 + Json(json!({ 117 + "error": "InvalidRequest", 118 + "message": "Rotation keys do not include server's rotation key" 119 + })), 120 + ) 121 + .into_response(); 122 + } 123 + } 124 + 125 + if let Some(services) = op.get("services").and_then(|v| v.as_object()) { 126 + if let Some(pds) = services.get("atproto_pds").and_then(|v| v.as_object()) { 127 + let service_type = pds.get("type").and_then(|v| v.as_str()); 128 + let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 129 + 130 + if service_type != Some("AtprotoPersonalDataServer") { 131 + return ( 132 + StatusCode::BAD_REQUEST, 133 + Json(json!({ 134 + "error": "InvalidRequest", 135 + "message": "Incorrect type on atproto_pds service" 136 + })), 137 + ) 138 + .into_response(); 139 + } 140 + 141 + if endpoint != Some(&public_url) { 142 + return ( 143 + StatusCode::BAD_REQUEST, 144 + Json(json!({ 145 + "error": "InvalidRequest", 146 + "message": "Incorrect endpoint on atproto_pds service" 147 + })), 148 + ) 149 + .into_response(); 150 + } 151 + } 152 + } 153 + 154 + if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) { 155 + if let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) { 156 + if atproto_key != user_did_key { 157 + return ( 158 + StatusCode::BAD_REQUEST, 159 + Json(json!({ 160 + "error": "InvalidRequest", 161 + "message": "Incorrect signing key in verificationMethods" 162 + })), 163 + ) 164 + .into_response(); 165 + } 166 + } 167 + } 168 + 169 + if let Some(also_known_as) = op.get("alsoKnownAs").and_then(|v| v.as_array()) { 170 + let expected_handle = format!("at://{}", user.handle); 171 + let first_aka = also_known_as.first().and_then(|v| v.as_str()); 172 + 173 + if first_aka != Some(&expected_handle) { 174 + return ( 175 + StatusCode::BAD_REQUEST, 176 + Json(json!({ 177 + "error": "InvalidRequest", 178 + "message": "Incorrect handle in alsoKnownAs" 179 + })), 180 + ) 181 + .into_response(); 182 + } 183 + } 184 + 185 + let plc_client = PlcClient::new(None); 186 + if let Err(e) = plc_client.send_operation(did, &input.operation).await { 187 + error!("Failed to submit PLC operation: {:?}", e); 188 + return ( 189 + StatusCode::BAD_GATEWAY, 190 + Json(json!({ 191 + "error": "UpstreamError", 192 + "message": format!("Failed to submit to PLC directory: {}", e) 193 + })), 194 + ) 195 + .into_response(); 196 + } 197 + 198 + if let Err(e) = sqlx::query!( 199 + "INSERT INTO repo_seq (did, event_type) VALUES ($1, 'identity')", 200 + did 201 + ) 202 + .execute(&state.db) 203 + .await 204 + { 205 + warn!("Failed to sequence identity event: {:?}", e); 206 + } 207 + 208 + info!("Submitted PLC operation for user {}", did); 209 + 210 + (StatusCode::OK, Json(json!({}))).into_response() 211 + }
+4
src/api/mod.rs
··· 1 1 pub mod actor; 2 2 pub mod admin; 3 + pub mod error; 3 4 pub mod feed; 4 5 pub mod identity; 5 6 pub mod moderation; 6 7 pub mod proxy; 7 8 pub mod repo; 8 9 pub mod server; 10 + pub mod validation; 11 + 12 + pub use error::ApiError;
+4 -16
src/api/moderation/mod.rs
··· 1 + use crate::api::ApiError; 1 2 use crate::state::AppState; 2 3 use axum::{ 3 4 Json, ··· 37 38 headers.get("Authorization").and_then(|h| h.to_str().ok()) 38 39 ) { 39 40 Some(t) => t, 40 - None => { 41 - return ( 42 - StatusCode::UNAUTHORIZED, 43 - Json(json!({"error": "AuthenticationRequired"})), 44 - ) 45 - .into_response(); 46 - } 41 + None => return ApiError::AuthenticationRequired.into_response(), 47 42 }; 48 43 49 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 50 - let did = match auth_result { 44 + let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 51 45 Ok(user) => user.did, 52 - Err(e) => { 53 - return ( 54 - StatusCode::UNAUTHORIZED, 55 - Json(json!({"error": e})), 56 - ) 57 - .into_response(); 58 - } 46 + Err(e) => return ApiError::from(e).into_response(), 59 47 }; 60 48 61 49 let valid_reason_types = [
+22 -2
src/api/repo/blob.rs
··· 15 15 use std::str::FromStr; 16 16 use tracing::error; 17 17 18 + const MAX_BLOB_SIZE: usize = 1_000_000; 19 + 18 20 pub async fn upload_blob( 19 21 State(state): State<AppState>, 20 22 headers: axum::http::HeaderMap, 21 23 body: Bytes, 22 24 ) -> Response { 25 + if body.len() > MAX_BLOB_SIZE { 26 + return ( 27 + StatusCode::PAYLOAD_TOO_LARGE, 28 + Json(json!({"error": "BlobTooLarge", "message": format!("Blob size {} exceeds maximum of {} bytes", body.len(), MAX_BLOB_SIZE)})), 29 + ) 30 + .into_response(); 31 + } 32 + 23 33 let token = match crate::auth::extract_bearer_token_from_header( 24 34 headers.get("Authorization").and_then(|h| h.to_str().ok()) 25 35 ) { ··· 57 67 let mut hasher = Sha256::new(); 58 68 hasher.update(&data); 59 69 let hash = hasher.finalize(); 60 - let multihash = Multihash::wrap(0x12, &hash).unwrap(); 70 + let multihash = match Multihash::wrap(0x12, &hash) { 71 + Ok(mh) => mh, 72 + Err(e) => { 73 + error!("Failed to create multihash for blob: {:?}", e); 74 + return ( 75 + StatusCode::INTERNAL_SERVER_ERROR, 76 + Json(json!({"error": "InternalError", "message": "Failed to hash blob"})), 77 + ) 78 + .into_response(); 79 + } 80 + }; 61 81 let cid = Cid::new_v1(0x55, multihash); 62 82 let cid_str = cid.to_string(); 63 83 ··· 207 227 } 208 228 }; 209 229 210 - let limit = params.limit.unwrap_or(500).min(1000); 230 + let limit = params.limit.unwrap_or(500).clamp(1, 1000); 211 231 let cursor_str = params.cursor.unwrap_or_default(); 212 232 let (cursor_collection, cursor_rkey) = if cursor_str.contains('|') { 213 233 let parts: Vec<&str> = cursor_str.split('|').collect();
+3 -14
src/api/repo/import.rs
··· 1 + use crate::api::ApiError; 1 2 use crate::state::AppState; 2 3 use crate::sync::import::{apply_import, parse_car, ImportError}; 3 4 use crate::sync::verify::CarVerifier; ··· 54 55 headers.get("Authorization").and_then(|h| h.to_str().ok()), 55 56 ) { 56 57 Some(t) => t, 57 - None => { 58 - return ( 59 - StatusCode::UNAUTHORIZED, 60 - Json(json!({"error": "AuthenticationRequired"})), 61 - ) 62 - .into_response(); 63 - } 58 + None => return ApiError::AuthenticationRequired.into_response(), 64 59 }; 65 60 66 61 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 67 62 Ok(user) => user, 68 - Err(e) => { 69 - return ( 70 - StatusCode::UNAUTHORIZED, 71 - Json(json!({"error": "AuthenticationFailed", "message": e})), 72 - ) 73 - .into_response(); 74 - } 63 + Err(e) => return ApiError::from(e).into_response(), 75 64 }; 76 65 77 66 let did = &auth_user.did;
+49 -13
src/api/repo/record/batch.rs
··· 17 17 use std::sync::Arc; 18 18 use tracing::error; 19 19 20 + const MAX_BATCH_WRITES: usize = 200; 21 + 20 22 #[derive(Deserialize)] 21 23 #[serde(tag = "$type")] 22 24 pub enum WriteOp { ··· 115 117 .into_response(); 116 118 } 117 119 118 - if input.writes.len() > 200 { 120 + if input.writes.len() > MAX_BATCH_WRITES { 119 121 return ( 120 122 StatusCode::BAD_REQUEST, 121 - Json(json!({"error": "InvalidRequest", "message": "Too many writes (max 200)"})), 123 + Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})), 122 124 ) 123 125 .into_response(); 124 126 } ··· 213 215 .clone() 214 216 .unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string()); 215 217 let mut record_bytes = Vec::new(); 216 - serde_ipld_dagcbor::to_writer(&mut record_bytes, value).unwrap(); 217 - let record_cid = tracking_store.put(&record_bytes).await.unwrap(); 218 + if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 219 + return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 220 + } 221 + let record_cid = match tracking_store.put(&record_bytes).await { 222 + Ok(c) => c, 223 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(), 224 + }; 218 225 219 - let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey); 220 - mst = mst.add(&key, record_cid).await.unwrap(); 226 + let collection_nsid = match collection.parse::<Nsid>() { 227 + Ok(n) => n, 228 + Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 229 + }; 230 + let key = format!("{}/{}", collection_nsid, rkey); 231 + mst = match mst.add(&key, record_cid).await { 232 + Ok(m) => m, 233 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(), 234 + }; 221 235 222 236 let uri = format!("at://{}/{}/{}", did, collection, rkey); 223 237 results.push(WriteResult::CreateResult { ··· 236 250 value, 237 251 } => { 238 252 let mut record_bytes = Vec::new(); 239 - serde_ipld_dagcbor::to_writer(&mut record_bytes, value).unwrap(); 240 - let record_cid = tracking_store.put(&record_bytes).await.unwrap(); 253 + if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 254 + return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 255 + } 256 + let record_cid = match tracking_store.put(&record_bytes).await { 257 + Ok(c) => c, 258 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(), 259 + }; 241 260 242 - let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey); 243 - mst = mst.update(&key, record_cid).await.unwrap(); 261 + let collection_nsid = match collection.parse::<Nsid>() { 262 + Ok(n) => n, 263 + Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 264 + }; 265 + let key = format!("{}/{}", collection_nsid, rkey); 266 + mst = match mst.update(&key, record_cid).await { 267 + Ok(m) => m, 268 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(), 269 + }; 244 270 245 271 let uri = format!("at://{}/{}/{}", did, collection, rkey); 246 272 results.push(WriteResult::UpdateResult { ··· 254 280 }); 255 281 } 256 282 WriteOp::Delete { collection, rkey } => { 257 - let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey); 258 - mst = mst.delete(&key).await.unwrap(); 283 + let collection_nsid = match collection.parse::<Nsid>() { 284 + Ok(n) => n, 285 + Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 286 + }; 287 + let key = format!("{}/{}", collection_nsid, rkey); 288 + mst = match mst.delete(&key).await { 289 + Ok(m) => m, 290 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(), 291 + }; 259 292 260 293 results.push(WriteResult::DeleteResult {}); 261 294 ops.push(RecordOp::Delete { ··· 266 299 } 267 300 } 268 301 269 - let new_mst_root = mst.persist().await.unwrap(); 302 + let new_mst_root = match mst.persist().await { 303 + Ok(c) => c, 304 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(), 305 + }; 270 306 let written_cids = tracking_store.get_written_cids(); 271 307 let written_cids_str = written_cids 272 308 .iter()
+11 -5
src/api/repo/record/utils.rs
··· 55 55 let new_root_cid = state.block_store.put(&new_commit_bytes).await 56 56 .map_err(|e| format!("Failed to save commit block: {:?}", e))?; 57 57 58 + let mut tx = state.db.begin().await 59 + .map_err(|e| format!("Failed to begin transaction: {}", e))?; 60 + 58 61 sqlx::query!("UPDATE repos SET repo_root_cid = $1 WHERE user_id = $2", new_root_cid.to_string(), user_id) 59 - .execute(&state.db) 62 + .execute(&mut *tx) 60 63 .await 61 64 .map_err(|e| format!("DB Error (repos): {}", e))?; 62 65 ··· 71 74 rkey, 72 75 cid.to_string() 73 76 ) 74 - .execute(&state.db) 77 + .execute(&mut *tx) 75 78 .await 76 79 .map_err(|e| format!("DB Error (records): {}", e))?; 77 80 } ··· 82 85 collection, 83 86 rkey 84 87 ) 85 - .execute(&state.db) 88 + .execute(&mut *tx) 86 89 .await 87 90 .map_err(|e| format!("DB Error (records): {}", e))?; 88 91 } ··· 126 129 &[] as &[String], 127 130 blocks_cids, 128 131 ) 129 - .fetch_one(&state.db) 132 + .fetch_one(&mut *tx) 130 133 .await 131 134 .map_err(|e| format!("DB Error (repo_seq): {}", e))?; 132 135 133 136 sqlx::query( 134 137 &format!("NOTIFY repo_updates, '{}'", seq_row.seq) 135 138 ) 136 - .execute(&state.db) 139 + .execute(&mut *tx) 137 140 .await 138 141 .map_err(|e| format!("DB Error (notify): {}", e))?; 142 + 143 + tx.commit().await 144 + .map_err(|e| format!("Failed to commit transaction: {}", e))?; 139 145 140 146 Ok(CommitResult { 141 147 commit_cid: new_root_cid,
+12 -3
src/api/repo/record/write.rs
··· 294 294 }; 295 295 296 296 let new_mst = if existing_cid.is_some() { 297 - mst.update(&key, record_cid).await.unwrap() 297 + match mst.update(&key, record_cid).await { 298 + Ok(m) => m, 299 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(), 300 + } 298 301 } else { 299 - mst.add(&key, record_cid).await.unwrap() 302 + match mst.add(&key, record_cid).await { 303 + Ok(m) => m, 304 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(), 305 + } 306 + }; 307 + let new_mst_root = match new_mst.persist().await { 308 + Ok(c) => c, 309 + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(), 300 310 }; 301 - let new_mst_root = new_mst.persist().await.unwrap(); 302 311 303 312 let op = if existing_cid.is_some() { 304 313 RecordOp::Update { collection: input.collection.clone(), rkey: input.rkey.clone(), cid: record_cid }
+13 -64
src/api/server/account_status.rs
··· 1 + use crate::api::ApiError; 1 2 use crate::state::AppState; 2 3 use axum::{ 3 4 Json, ··· 34 35 headers.get("Authorization").and_then(|h| h.to_str().ok()) 35 36 ) { 36 37 Some(t) => t, 37 - None => { 38 - return ( 39 - StatusCode::UNAUTHORIZED, 40 - Json(json!({"error": "AuthenticationRequired"})), 41 - ) 42 - .into_response(); 43 - } 38 + None => return ApiError::AuthenticationRequired.into_response(), 44 39 }; 45 40 46 - let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await; 47 - let did = match auth_result { 41 + let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 48 42 Ok(user) => user.did, 49 - Err(e) => { 50 - return ( 51 - StatusCode::UNAUTHORIZED, 52 - Json(json!({"error": e})), 53 - ) 54 - .into_response(); 55 - } 43 + Err(e) => return ApiError::from(e).into_response(), 56 44 }; 57 45 58 46 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) ··· 127 115 headers.get("Authorization").and_then(|h| h.to_str().ok()) 128 116 ) { 129 117 Some(t) => t, 130 - None => { 131 - return ( 132 - StatusCode::UNAUTHORIZED, 133 - Json(json!({"error": "AuthenticationRequired"})), 134 - ) 135 - .into_response(); 136 - } 118 + None => return ApiError::AuthenticationRequired.into_response(), 137 119 }; 138 120 139 - let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await; 140 - let did = match auth_result { 121 + let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 141 122 Ok(user) => user.did, 142 - Err(e) => { 143 - return ( 144 - StatusCode::UNAUTHORIZED, 145 - Json(json!({"error": e})), 146 - ) 147 - .into_response(); 148 - } 123 + Err(e) => return ApiError::from(e).into_response(), 149 124 }; 150 125 151 126 let result = sqlx::query!("UPDATE users SET deactivated_at = NULL WHERE did = $1", did) ··· 180 155 headers.get("Authorization").and_then(|h| h.to_str().ok()) 181 156 ) { 182 157 Some(t) => t, 183 - None => { 184 - return ( 185 - StatusCode::UNAUTHORIZED, 186 - Json(json!({"error": "AuthenticationRequired"})), 187 - ) 188 - .into_response(); 189 - } 158 + None => return ApiError::AuthenticationRequired.into_response(), 190 159 }; 191 160 192 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 193 - let did = match auth_result { 161 + let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 194 162 Ok(user) => user.did, 195 - Err(e) => { 196 - return ( 197 - StatusCode::UNAUTHORIZED, 198 - Json(json!({"error": e})), 199 - ) 200 - .into_response(); 201 - } 163 + Err(e) => return ApiError::from(e).into_response(), 202 164 }; 203 165 204 166 let result = sqlx::query!("UPDATE users SET deactivated_at = NOW() WHERE did = $1", did) ··· 226 188 headers.get("Authorization").and_then(|h| h.to_str().ok()) 227 189 ) { 228 190 Some(t) => t, 229 - None => { 230 - return ( 231 - StatusCode::UNAUTHORIZED, 232 - Json(json!({"error": "AuthenticationRequired"})), 233 - ) 234 - .into_response(); 235 - } 191 + None => return ApiError::AuthenticationRequired.into_response(), 236 192 }; 237 193 238 - let auth_result = crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await; 239 - let did = match auth_result { 194 + let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 240 195 Ok(user) => user.did, 241 - Err(e) => { 242 - return ( 243 - StatusCode::UNAUTHORIZED, 244 - Json(json!({"error": e})), 245 - ) 246 - .into_response(); 247 - } 196 + Err(e) => return ApiError::from(e).into_response(), 248 197 }; 249 198 250 199 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
+63 -190
src/api/server/app_password.rs
··· 1 + use crate::api::ApiError; 2 + use crate::auth::BearerAuth; 1 3 use crate::state::AppState; 4 + use crate::util::get_user_id_by_did; 2 5 use axum::{ 3 6 Json, 4 7 extract::State, 5 - http::StatusCode, 6 8 response::{IntoResponse, Response}, 7 9 }; 8 10 use serde::{Deserialize, Serialize}; ··· 24 26 25 27 pub async fn list_app_passwords( 26 28 State(state): State<AppState>, 27 - headers: axum::http::HeaderMap, 29 + BearerAuth(auth_user): BearerAuth, 28 30 ) -> Response { 29 - let token = match crate::auth::extract_bearer_token_from_header( 30 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 31 - ) { 32 - Some(t) => t, 33 - None => { 34 - return ( 35 - StatusCode::UNAUTHORIZED, 36 - Json(json!({"error": "AuthenticationRequired"})), 37 - ) 38 - .into_response(); 39 - } 40 - }; 41 - 42 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 43 - let did = match auth_result { 44 - Ok(user) => user.did, 45 - Err(e) => { 46 - return ( 47 - StatusCode::UNAUTHORIZED, 48 - Json(json!({"error": e})), 49 - ) 50 - .into_response(); 51 - } 31 + let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 32 + Ok(id) => id, 33 + Err(e) => return ApiError::from(e).into_response(), 52 34 }; 53 35 54 - let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 55 - .fetch_optional(&state.db) 56 - .await 36 + match sqlx::query!( 37 + "SELECT name, created_at, privileged FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC", 38 + user_id 39 + ) 40 + .fetch_all(&state.db) 41 + .await 57 42 { 58 - Ok(Some(id)) => id, 59 - _ => { 60 - return ( 61 - StatusCode::INTERNAL_SERVER_ERROR, 62 - Json(json!({"error": "InternalError"})), 63 - ) 64 - .into_response(); 65 - } 66 - }; 67 - 68 - let result = sqlx::query!("SELECT name, created_at, privileged FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC", user_id) 69 - .fetch_all(&state.db) 70 - .await; 71 - 72 - match result { 73 43 Ok(rows) => { 74 44 let passwords: Vec<AppPassword> = rows 75 45 .iter() 76 - .map(|row| { 77 - AppPassword { 78 - name: row.name.clone(), 79 - created_at: row.created_at.to_rfc3339(), 80 - privileged: row.privileged, 81 - } 46 + .map(|row| AppPassword { 47 + name: row.name.clone(), 48 + created_at: row.created_at.to_rfc3339(), 49 + privileged: row.privileged, 82 50 }) 83 51 .collect(); 84 52 85 - (StatusCode::OK, Json(ListAppPasswordsOutput { passwords })).into_response() 53 + Json(ListAppPasswordsOutput { passwords }).into_response() 86 54 } 87 55 Err(e) => { 88 56 error!("DB error listing app passwords: {:?}", e); 89 - ( 90 - StatusCode::INTERNAL_SERVER_ERROR, 91 - Json(json!({"error": "InternalError"})), 92 - ) 93 - .into_response() 57 + ApiError::InternalError.into_response() 94 58 } 95 59 } 96 60 } ··· 112 76 113 77 pub async fn create_app_password( 114 78 State(state): State<AppState>, 115 - headers: axum::http::HeaderMap, 79 + BearerAuth(auth_user): BearerAuth, 116 80 Json(input): Json<CreateAppPasswordInput>, 117 81 ) -> Response { 118 - let token = match crate::auth::extract_bearer_token_from_header( 119 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 120 - ) { 121 - Some(t) => t, 122 - None => { 123 - return ( 124 - StatusCode::UNAUTHORIZED, 125 - Json(json!({"error": "AuthenticationRequired"})), 126 - ) 127 - .into_response(); 128 - } 129 - }; 130 - 131 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 132 - let did = match auth_result { 133 - Ok(user) => user.did, 134 - Err(e) => { 135 - return ( 136 - StatusCode::UNAUTHORIZED, 137 - Json(json!({"error": e})), 138 - ) 139 - .into_response(); 140 - } 141 - }; 142 - 143 - let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 144 - .fetch_optional(&state.db) 145 - .await 146 - { 147 - Ok(Some(id)) => id, 148 - _ => { 149 - return ( 150 - StatusCode::INTERNAL_SERVER_ERROR, 151 - Json(json!({"error": "InternalError"})), 152 - ) 153 - .into_response(); 154 - } 82 + let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 83 + Ok(id) => id, 84 + Err(e) => return ApiError::from(e).into_response(), 155 85 }; 156 86 157 87 let name = input.name.trim(); 158 88 if name.is_empty() { 159 - return ( 160 - StatusCode::BAD_REQUEST, 161 - Json(json!({"error": "InvalidRequest", "message": "name is required"})), 162 - ) 163 - .into_response(); 89 + return ApiError::InvalidRequest("name is required".into()).into_response(); 164 90 } 165 91 166 - let existing = sqlx::query!("SELECT id FROM app_passwords WHERE user_id = $1 AND name = $2", user_id, name) 167 - .fetch_optional(&state.db) 168 - .await; 92 + let existing = sqlx::query!( 93 + "SELECT id FROM app_passwords WHERE user_id = $1 AND name = $2", 94 + user_id, 95 + name 96 + ) 97 + .fetch_optional(&state.db) 98 + .await; 169 99 170 100 if let Ok(Some(_)) = existing { 171 - return ( 172 - StatusCode::BAD_REQUEST, 173 - Json(json!({"error": "DuplicateAppPassword", "message": "App password with this name already exists"})), 174 - ) 175 - .into_response(); 101 + return ApiError::DuplicateAppPassword.into_response(); 176 102 } 177 103 178 104 let password: String = (0..4) ··· 180 106 use rand::Rng; 181 107 let mut rng = rand::thread_rng(); 182 108 let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect(); 183 - (0..4).map(|_| chars[rng.gen_range(0..chars.len())]).collect::<String>() 109 + (0..4) 110 + .map(|_| chars[rng.gen_range(0..chars.len())]) 111 + .collect::<String>() 184 112 }) 185 113 .collect::<Vec<String>>() 186 114 .join("-"); ··· 189 117 Ok(h) => h, 190 118 Err(e) => { 191 119 error!("Failed to hash password: {:?}", e); 192 - return ( 193 - StatusCode::INTERNAL_SERVER_ERROR, 194 - Json(json!({"error": "InternalError"})), 195 - ) 196 - .into_response(); 120 + return ApiError::InternalError.into_response(); 197 121 } 198 122 }; 199 123 200 124 let privileged = input.privileged.unwrap_or(false); 201 125 let created_at = chrono::Utc::now(); 202 126 203 - let result = sqlx::query!( 127 + match sqlx::query!( 204 128 "INSERT INTO app_passwords (user_id, name, password_hash, created_at, privileged) VALUES ($1, $2, $3, $4, $5)", 205 129 user_id, 206 130 name, ··· 209 133 privileged 210 134 ) 211 135 .execute(&state.db) 212 - .await; 213 - 214 - match result { 215 - Ok(_) => ( 216 - StatusCode::OK, 217 - Json(CreateAppPasswordOutput { 218 - name: name.to_string(), 219 - password, 220 - created_at: created_at.to_rfc3339(), 221 - privileged, 222 - }), 223 - ) 224 - .into_response(), 136 + .await 137 + { 138 + Ok(_) => Json(CreateAppPasswordOutput { 139 + name: name.to_string(), 140 + password, 141 + created_at: created_at.to_rfc3339(), 142 + privileged, 143 + }) 144 + .into_response(), 225 145 Err(e) => { 226 146 error!("DB error creating app password: {:?}", e); 227 - ( 228 - StatusCode::INTERNAL_SERVER_ERROR, 229 - Json(json!({"error": "InternalError"})), 230 - ) 231 - .into_response() 147 + ApiError::InternalError.into_response() 232 148 } 233 149 } 234 150 } ··· 240 156 241 157 pub async fn revoke_app_password( 242 158 State(state): State<AppState>, 243 - headers: axum::http::HeaderMap, 159 + BearerAuth(auth_user): BearerAuth, 244 160 Json(input): Json<RevokeAppPasswordInput>, 245 161 ) -> Response { 246 - let token = match crate::auth::extract_bearer_token_from_header( 247 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 248 - ) { 249 - Some(t) => t, 250 - None => { 251 - return ( 252 - StatusCode::UNAUTHORIZED, 253 - Json(json!({"error": "AuthenticationRequired"})), 254 - ) 255 - .into_response(); 256 - } 257 - }; 258 - 259 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 260 - let did = match auth_result { 261 - Ok(user) => user.did, 262 - Err(e) => { 263 - return ( 264 - StatusCode::UNAUTHORIZED, 265 - Json(json!({"error": e})), 266 - ) 267 - .into_response(); 268 - } 269 - }; 270 - 271 - let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 272 - .fetch_optional(&state.db) 273 - .await 274 - { 275 - Ok(Some(id)) => id, 276 - _ => { 277 - return ( 278 - StatusCode::INTERNAL_SERVER_ERROR, 279 - Json(json!({"error": "InternalError"})), 280 - ) 281 - .into_response(); 282 - } 162 + let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 163 + Ok(id) => id, 164 + Err(e) => return ApiError::from(e).into_response(), 283 165 }; 284 166 285 167 let name = input.name.trim(); 286 168 if name.is_empty() { 287 - return ( 288 - StatusCode::BAD_REQUEST, 289 - Json(json!({"error": "InvalidRequest", "message": "name is required"})), 290 - ) 291 - .into_response(); 169 + return ApiError::InvalidRequest("name is required".into()).into_response(); 292 170 } 293 171 294 - let result = sqlx::query!("DELETE FROM app_passwords WHERE user_id = $1 AND name = $2", user_id, name) 295 - .execute(&state.db) 296 - .await; 297 - 298 - match result { 172 + match sqlx::query!( 173 + "DELETE FROM app_passwords WHERE user_id = $1 AND name = $2", 174 + user_id, 175 + name 176 + ) 177 + .execute(&state.db) 178 + .await 179 + { 299 180 Ok(r) => { 300 181 if r.rows_affected() == 0 { 301 - return ( 302 - StatusCode::NOT_FOUND, 303 - Json(json!({"error": "AppPasswordNotFound", "message": "App password not found"})), 304 - ) 305 - .into_response(); 182 + return ApiError::AppPasswordNotFound.into_response(); 306 183 } 307 - (StatusCode::OK, Json(json!({}))).into_response() 184 + Json(json!({})).into_response() 308 185 } 309 186 Err(e) => { 310 187 error!("DB error revoking app password: {:?}", e); 311 - ( 312 - StatusCode::INTERNAL_SERVER_ERROR, 313 - Json(json!({"error": "InternalError"})), 314 - ) 315 - .into_response() 188 + ApiError::InternalError.into_response() 316 189 } 317 190 } 318 191 }
+47 -54
src/api/server/email.rs
··· 1 + use crate::api::ApiError; 1 2 use crate::state::AppState; 2 3 use axum::{ 3 4 Json, ··· 6 7 response::{IntoResponse, Response}, 7 8 }; 8 9 use chrono::{Duration, Utc}; 9 - use rand::Rng; 10 10 use serde::Deserialize; 11 11 use serde_json::json; 12 12 use tracing::{error, info, warn}; 13 13 14 14 fn generate_confirmation_code() -> String { 15 - let mut rng = rand::thread_rng(); 16 - let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect(); 17 - let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect(); 18 - let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect(); 19 - format!("{}-{}", part1, part2) 15 + crate::util::generate_token_code() 20 16 } 21 17 22 18 #[derive(Deserialize)] ··· 46 42 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 47 43 let did = match auth_result { 48 44 Ok(user) => user.did, 49 - Err(e) => { 50 - return ( 51 - StatusCode::UNAUTHORIZED, 52 - Json(json!({"error": e})), 53 - ) 54 - .into_response(); 55 - } 45 + Err(e) => return ApiError::from(e).into_response(), 56 46 }; 57 47 58 48 let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did) ··· 72 62 let handle = user.handle; 73 63 74 64 let email = input.email.trim().to_lowercase(); 75 - if email.is_empty() { 65 + if !crate::api::validation::is_valid_email(&email) { 76 66 return ( 77 67 StatusCode::BAD_REQUEST, 78 - Json(json!({"error": "InvalidRequest", "message": "email is required"})), 68 + Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})), 79 69 ) 80 70 .into_response(); 81 71 } ··· 161 151 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 162 152 let did = match auth_result { 163 153 Ok(user) => user.did, 164 - Err(e) => { 165 - return ( 166 - StatusCode::UNAUTHORIZED, 167 - Json(json!({"error": e})), 168 - ) 169 - .into_response(); 170 - } 154 + Err(e) => return ApiError::from(e).into_response(), 171 155 }; 172 156 173 157 let user = match sqlx::query!( ··· 194 178 let email = input.email.trim().to_lowercase(); 195 179 let confirmation_code = input.token.trim(); 196 180 197 - if email_pending_verification.is_none() || stored_code.is_none() || expires_at.is_none() { 198 - return ( 199 - StatusCode::BAD_REQUEST, 200 - Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})), 201 - ) 202 - .into_response(); 203 - } 181 + let (pending_email, saved_code, expiry) = match (email_pending_verification, stored_code, expires_at) { 182 + (Some(p), Some(c), Some(e)) => (p, c, e), 183 + _ => { 184 + return ( 185 + StatusCode::BAD_REQUEST, 186 + Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})), 187 + ) 188 + .into_response(); 189 + } 190 + }; 204 191 205 - let email_pending_verification = email_pending_verification.unwrap(); 206 - if email_pending_verification != email { 192 + if pending_email != email { 207 193 return ( 208 194 StatusCode::BAD_REQUEST, 209 195 Json(json!({"error": "InvalidRequest", "message": "Email does not match pending update"})), ··· 211 197 .into_response(); 212 198 } 213 199 214 - if stored_code.unwrap() != confirmation_code { 200 + if saved_code != confirmation_code { 215 201 return ( 216 202 StatusCode::BAD_REQUEST, 217 203 Json(json!({"error": "InvalidToken", "message": "Invalid token"})), ··· 219 205 .into_response(); 220 206 } 221 207 222 - if Utc::now() > expires_at.unwrap() { 208 + if Utc::now() > expiry { 223 209 return ( 224 210 StatusCode::BAD_REQUEST, 225 211 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})), ··· 229 215 230 216 let update = sqlx::query!( 231 217 "UPDATE users SET email = $1, email_pending_verification = NULL, email_confirmation_code = NULL, email_confirmation_code_expires_at = NULL WHERE id = $2", 232 - email_pending_verification, 218 + pending_email, 233 219 user_id 234 220 ) 235 221 .execute(&state.db) ··· 287 273 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 288 274 let did = match auth_result { 289 275 Ok(user) => user.did, 290 - Err(e) => { 291 - return ( 292 - StatusCode::UNAUTHORIZED, 293 - Json(json!({"error": e})), 294 - ) 295 - .into_response(); 296 - } 276 + Err(e) => return ApiError::from(e).into_response(), 297 277 }; 298 278 299 279 let user = match sqlx::query!( ··· 319 299 let email_pending_verification = user.email_pending_verification; 320 300 321 301 let new_email = input.email.trim().to_lowercase(); 322 - if new_email.is_empty() { 302 + if !crate::api::validation::is_valid_email(&new_email) { 323 303 return ( 324 304 StatusCode::BAD_REQUEST, 325 - Json(json!({"error": "InvalidRequest", "message": "email is required"})), 326 - ) 327 - .into_response(); 328 - } 329 - 330 - if !new_email.contains('@') || !new_email.contains('.') { 331 - return ( 332 - StatusCode::BAD_REQUEST, 333 - Json(json!({"error": "InvalidRequest", "message": "Invalid email format"})), 305 + Json(json!({"error": "InvalidEmail", "message": "Invalid email format"})), 334 306 ) 335 307 .into_response(); 336 308 } ··· 353 325 } 354 326 }; 355 327 356 - let pending_email = email_pending_verification.unwrap(); 328 + let pending_email = match email_pending_verification { 329 + Some(p) => p, 330 + None => { 331 + return ( 332 + StatusCode::BAD_REQUEST, 333 + Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})), 334 + ) 335 + .into_response(); 336 + } 337 + }; 338 + 357 339 if pending_email.to_lowercase() != new_email { 358 340 return ( 359 341 StatusCode::BAD_REQUEST, ··· 362 344 .into_response(); 363 345 } 364 346 365 - if stored_code.unwrap() != confirmation_token { 347 + let saved_code = match stored_code { 348 + Some(c) => c, 349 + None => { 350 + return ( 351 + StatusCode::BAD_REQUEST, 352 + Json(json!({"error": "InvalidRequest", "message": "No pending email update found"})), 353 + ) 354 + .into_response(); 355 + } 356 + }; 357 + 358 + if saved_code != confirmation_token { 366 359 return ( 367 360 StatusCode::BAD_REQUEST, 368 361 Json(json!({"error": "InvalidToken", "message": "Invalid token"})), ··· 415 408 416 409 match update { 417 410 Ok(_) => { 418 - info!("Email updated to {} for user {}", new_email, user_id); 411 + info!("Email updated for user {}", user_id); 419 412 (StatusCode::OK, Json(json!({}))).into_response() 420 413 } 421 414 Err(e) => {
+61 -209
src/api/server/invite.rs
··· 1 + use crate::api::ApiError; 2 + use crate::auth::BearerAuth; 1 3 use crate::state::AppState; 4 + use crate::util::get_user_id_by_did; 2 5 use axum::{ 3 6 Json, 4 7 extract::State, 5 - http::StatusCode, 6 8 response::{IntoResponse, Response}, 7 9 }; 8 10 use serde::{Deserialize, Serialize}; 9 - use serde_json::json; 10 11 use tracing::error; 11 12 use uuid::Uuid; 12 13 ··· 24 25 25 26 pub async fn create_invite_code( 26 27 State(state): State<AppState>, 27 - headers: axum::http::HeaderMap, 28 + BearerAuth(auth_user): BearerAuth, 28 29 Json(input): Json<CreateInviteCodeInput>, 29 30 ) -> Response { 30 - let token = match crate::auth::extract_bearer_token_from_header( 31 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 32 - ) { 33 - Some(t) => t, 34 - None => { 35 - return ( 36 - StatusCode::UNAUTHORIZED, 37 - Json(json!({"error": "AuthenticationRequired"})), 38 - ) 39 - .into_response(); 40 - } 41 - }; 42 - 43 31 if input.use_count < 1 { 44 - return ( 45 - StatusCode::BAD_REQUEST, 46 - Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})), 47 - ) 48 - .into_response(); 32 + return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 49 33 } 50 34 51 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 52 - let did = match auth_result { 53 - Ok(user) => user.did, 54 - Err(e) => { 55 - return ( 56 - StatusCode::UNAUTHORIZED, 57 - Json(json!({"error": e})), 58 - ) 59 - .into_response(); 60 - } 61 - }; 62 - 63 - let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 64 - .fetch_optional(&state.db) 65 - .await 66 - { 67 - Ok(Some(id)) => id, 68 - _ => { 69 - return ( 70 - StatusCode::INTERNAL_SERVER_ERROR, 71 - Json(json!({"error": "InternalError"})), 72 - ) 73 - .into_response(); 74 - } 35 + let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 36 + Ok(id) => id, 37 + Err(e) => return ApiError::from(e).into_response(), 75 38 }; 76 39 77 40 let creator_user_id = if let Some(for_account) = &input.for_account { 78 - let target = sqlx::query!("SELECT id FROM users WHERE did = $1", for_account) 41 + match sqlx::query!("SELECT id FROM users WHERE did = $1", for_account) 79 42 .fetch_optional(&state.db) 80 - .await; 81 - 82 - match target { 43 + .await 44 + { 83 45 Ok(Some(row)) => row.id, 84 - Ok(None) => { 85 - return ( 86 - StatusCode::NOT_FOUND, 87 - Json(json!({"error": "AccountNotFound", "message": "Target account not found"})), 88 - ) 89 - .into_response(); 90 - } 46 + Ok(None) => return ApiError::AccountNotFound.into_response(), 91 47 Err(e) => { 92 48 error!("DB error looking up target account: {:?}", e); 93 - return ( 94 - StatusCode::INTERNAL_SERVER_ERROR, 95 - Json(json!({"error": "InternalError"})), 96 - ) 97 - .into_response(); 49 + return ApiError::InternalError.into_response(); 98 50 } 99 51 } 100 52 } else { ··· 103 55 104 56 let user_invites_disabled = sqlx::query_scalar!( 105 57 "SELECT invites_disabled FROM users WHERE did = $1", 106 - did 58 + auth_user.did 107 59 ) 108 60 .fetch_optional(&state.db) 109 61 .await 62 + .map_err(|e| { 63 + error!("DB error checking invites_disabled: {:?}", e); 64 + ApiError::InternalError 65 + }) 110 66 .ok() 111 67 .flatten() 112 68 .flatten() 113 69 .unwrap_or(false); 114 70 115 71 if user_invites_disabled { 116 - return ( 117 - StatusCode::FORBIDDEN, 118 - Json(json!({"error": "InvitesDisabled", "message": "Invites are disabled for this account"})), 119 - ) 120 - .into_response(); 72 + return ApiError::InvitesDisabled.into_response(); 121 73 } 122 74 123 75 let code = Uuid::new_v4().to_string(); 124 76 125 - let result = sqlx::query!( 77 + match sqlx::query!( 126 78 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)", 127 79 code, 128 80 input.use_count, 129 81 creator_user_id 130 82 ) 131 83 .execute(&state.db) 132 - .await; 133 - 134 - match result { 135 - Ok(_) => (StatusCode::OK, Json(CreateInviteCodeOutput { code })).into_response(), 84 + .await 85 + { 86 + Ok(_) => Json(CreateInviteCodeOutput { code }).into_response(), 136 87 Err(e) => { 137 88 error!("DB error creating invite code: {:?}", e); 138 - ( 139 - StatusCode::INTERNAL_SERVER_ERROR, 140 - Json(json!({"error": "InternalError"})), 141 - ) 142 - .into_response() 89 + ApiError::InternalError.into_response() 143 90 } 144 91 } 145 92 } ··· 165 112 166 113 pub async fn create_invite_codes( 167 114 State(state): State<AppState>, 168 - headers: axum::http::HeaderMap, 115 + BearerAuth(auth_user): BearerAuth, 169 116 Json(input): Json<CreateInviteCodesInput>, 170 117 ) -> Response { 171 - let token = match crate::auth::extract_bearer_token_from_header( 172 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 173 - ) { 174 - Some(t) => t, 175 - None => { 176 - return ( 177 - StatusCode::UNAUTHORIZED, 178 - Json(json!({"error": "AuthenticationRequired"})), 179 - ) 180 - .into_response(); 181 - } 182 - }; 183 - 184 118 if input.use_count < 1 { 185 - return ( 186 - StatusCode::BAD_REQUEST, 187 - Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})), 188 - ) 189 - .into_response(); 119 + return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 190 120 } 191 121 192 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 193 - let did = match auth_result { 194 - Ok(user) => user.did, 195 - Err(e) => { 196 - return ( 197 - StatusCode::UNAUTHORIZED, 198 - Json(json!({"error": e})), 199 - ) 200 - .into_response(); 201 - } 202 - }; 203 - 204 - let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 205 - .fetch_optional(&state.db) 206 - .await 207 - { 208 - Ok(Some(id)) => id, 209 - _ => { 210 - return ( 211 - StatusCode::INTERNAL_SERVER_ERROR, 212 - Json(json!({"error": "InternalError"})), 213 - ) 214 - .into_response(); 215 - } 122 + let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 123 + Ok(id) => id, 124 + Err(e) => return ApiError::from(e).into_response(), 216 125 }; 217 126 218 127 let code_count = input.code_count.unwrap_or(1).max(1); ··· 225 134 for _ in 0..code_count { 226 135 let code = Uuid::new_v4().to_string(); 227 136 228 - let insert = sqlx::query!( 137 + if let Err(e) = sqlx::query!( 229 138 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)", 230 139 code, 231 140 input.use_count, 232 141 user_id 233 142 ) 234 143 .execute(&state.db) 235 - .await; 236 - 237 - if let Err(e) = insert { 144 + .await 145 + { 238 146 error!("DB error creating invite code: {:?}", e); 239 - return ( 240 - StatusCode::INTERNAL_SERVER_ERROR, 241 - Json(json!({"error": "InternalError"})), 242 - ) 243 - .into_response(); 147 + return ApiError::InternalError.into_response(); 244 148 } 245 149 246 150 codes.push(code); ··· 252 156 }); 253 157 } else { 254 158 for account_did in for_accounts { 255 - let target = sqlx::query!("SELECT id FROM users WHERE did = $1", account_did) 159 + let target_user_id = match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did) 256 160 .fetch_optional(&state.db) 257 - .await; 258 - 259 - let target_user_id = match target { 161 + .await 162 + { 260 163 Ok(Some(row)) => row.id, 261 - Ok(None) => { 262 - continue; 263 - } 164 + Ok(None) => continue, 264 165 Err(e) => { 265 166 error!("DB error looking up target account: {:?}", e); 266 - return ( 267 - StatusCode::INTERNAL_SERVER_ERROR, 268 - Json(json!({"error": "InternalError"})), 269 - ) 270 - .into_response(); 167 + return ApiError::InternalError.into_response(); 271 168 } 272 169 }; 273 170 ··· 275 172 for _ in 0..code_count { 276 173 let code = Uuid::new_v4().to_string(); 277 174 278 - let insert = sqlx::query!( 175 + if let Err(e) = sqlx::query!( 279 176 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)", 280 177 code, 281 178 input.use_count, 282 179 target_user_id 283 180 ) 284 181 .execute(&state.db) 285 - .await; 286 - 287 - if let Err(e) = insert { 182 + .await 183 + { 288 184 error!("DB error creating invite code: {:?}", e); 289 - return ( 290 - StatusCode::INTERNAL_SERVER_ERROR, 291 - Json(json!({"error": "InternalError"})), 292 - ) 293 - .into_response(); 185 + return ApiError::InternalError.into_response(); 294 186 } 295 187 296 188 codes.push(code); ··· 303 195 } 304 196 } 305 197 306 - (StatusCode::OK, Json(CreateInviteCodesOutput { codes: result_codes })).into_response() 198 + Json(CreateInviteCodesOutput { codes: result_codes }).into_response() 307 199 } 308 200 309 201 #[derive(Deserialize)] ··· 339 231 340 232 pub async fn get_account_invite_codes( 341 233 State(state): State<AppState>, 342 - headers: axum::http::HeaderMap, 234 + BearerAuth(auth_user): BearerAuth, 343 235 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 344 236 ) -> Response { 345 - let token = match crate::auth::extract_bearer_token_from_header( 346 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 347 - ) { 348 - Some(t) => t, 349 - None => { 350 - return ( 351 - StatusCode::UNAUTHORIZED, 352 - Json(json!({"error": "AuthenticationRequired"})), 353 - ) 354 - .into_response(); 355 - } 356 - }; 357 - 358 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 359 - let did = match auth_result { 360 - Ok(user) => user.did, 361 - Err(e) => { 362 - return ( 363 - StatusCode::UNAUTHORIZED, 364 - Json(json!({"error": e})), 365 - ) 366 - .into_response(); 367 - } 368 - }; 369 - 370 - let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 371 - .fetch_optional(&state.db) 372 - .await 373 - { 374 - Ok(Some(id)) => id, 375 - _ => { 376 - return ( 377 - StatusCode::INTERNAL_SERVER_ERROR, 378 - Json(json!({"error": "InternalError"})), 379 - ) 380 - .into_response(); 381 - } 237 + let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 238 + Ok(id) => id, 239 + Err(e) => return ApiError::from(e).into_response(), 382 240 }; 383 241 384 242 let include_used = params.include_used.unwrap_or(true); 385 243 386 - let codes_result = sqlx::query!( 244 + let codes_rows = match sqlx::query!( 387 245 r#" 388 246 SELECT code, available_uses, created_at, disabled 389 247 FROM invite_codes ··· 393 251 user_id 394 252 ) 395 253 .fetch_all(&state.db) 396 - .await; 397 - 398 - let codes_rows = match codes_result { 254 + .await 255 + { 399 256 Ok(rows) => { 400 257 if include_used { 401 258 rows ··· 405 262 } 406 263 Err(e) => { 407 264 error!("DB error fetching invite codes: {:?}", e); 408 - return ( 409 - StatusCode::INTERNAL_SERVER_ERROR, 410 - Json(json!({"error": "InternalError"})), 411 - ) 412 - .into_response(); 265 + return ApiError::InternalError.into_response(); 413 266 } 414 267 }; 415 268 416 269 let mut codes = Vec::new(); 417 270 for row in codes_rows { 418 - let uses_result = sqlx::query!( 271 + let uses = sqlx::query!( 419 272 r#" 420 273 SELECT u.did, icu.used_at 421 274 FROM invite_code_uses icu ··· 426 279 row.code 427 280 ) 428 281 .fetch_all(&state.db) 429 - .await; 430 - 431 - let uses = match uses_result { 432 - Ok(use_rows) => use_rows 282 + .await 283 + .map(|use_rows| { 284 + use_rows 433 285 .iter() 434 286 .map(|u| InviteCodeUse { 435 287 used_by: u.did.clone(), 436 288 used_at: u.used_at.to_rfc3339(), 437 289 }) 438 - .collect(), 439 - Err(_) => Vec::new(), 440 - }; 290 + .collect() 291 + }) 292 + .unwrap_or_default(); 441 293 442 294 codes.push(InviteCode { 443 295 code: row.code, 444 296 available: row.available_uses, 445 297 disabled: row.disabled.unwrap_or(false), 446 - for_account: did.clone(), 447 - created_by: did.clone(), 298 + for_account: auth_user.did.clone(), 299 + created_by: auth_user.did.clone(), 448 300 created_at: row.created_at.to_rfc3339(), 449 301 uses, 450 302 }); 451 303 } 452 304 453 - (StatusCode::OK, Json(GetAccountInviteCodesOutput { codes })).into_response() 305 + Json(GetAccountInviteCodesOutput { codes }).into_response() 454 306 }
+3 -3
src/api/server/mod.rs
··· 4 4 pub mod invite; 5 5 pub mod meta; 6 6 pub mod password; 7 + pub mod service_auth; 7 8 pub mod session; 8 9 pub mod signing_key; 9 10 ··· 16 17 pub use invite::{create_invite_code, create_invite_codes, get_account_invite_codes}; 17 18 pub use meta::{describe_server, health}; 18 19 pub use password::{request_password_reset, reset_password}; 19 - pub use session::{ 20 - create_session, delete_session, get_service_auth, get_session, refresh_session, 21 - }; 20 + pub use service_auth::get_service_auth; 21 + pub use session::{create_session, delete_session, get_session, refresh_session}; 22 22 pub use signing_key::reserve_signing_key;
+43 -17
src/api/server/password.rs
··· 7 7 }; 8 8 use bcrypt::{hash, DEFAULT_COST}; 9 9 use chrono::{Duration, Utc}; 10 - use rand::Rng; 11 10 use serde::Deserialize; 12 11 use serde_json::json; 13 12 use tracing::{error, info, warn}; 14 13 15 14 fn generate_reset_code() -> String { 16 - let mut rng = rand::thread_rng(); 17 - let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect(); 18 - let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect(); 19 - let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect(); 20 - format!("{}-{}", part1, part2) 15 + crate::util::generate_token_code() 21 16 } 22 17 23 18 #[derive(Deserialize)] ··· 45 40 let user_id = match user { 46 41 Ok(Some(row)) => row.id, 47 42 Ok(None) => { 48 - info!("Password reset requested for unknown email: {}", email); 43 + info!("Password reset requested for unknown email"); 49 44 return (StatusCode::OK, Json(json!({}))).into_response(); 50 45 } 51 46 Err(e) => { ··· 151 146 152 147 if let Some(exp) = expires_at { 153 148 if Utc::now() > exp { 154 - let _ = sqlx::query!( 149 + if let Err(e) = sqlx::query!( 155 150 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1", 156 151 user_id 157 152 ) 158 153 .execute(&state.db) 159 - .await; 154 + .await 155 + { 156 + error!("Failed to clear expired reset code: {:?}", e); 157 + } 160 158 161 159 return ( 162 160 StatusCode::BAD_REQUEST, ··· 184 182 } 185 183 }; 186 184 187 - let update = sqlx::query!( 185 + let mut tx = match state.db.begin().await { 186 + Ok(tx) => tx, 187 + Err(e) => { 188 + error!("Failed to begin transaction: {:?}", e); 189 + return ( 190 + StatusCode::INTERNAL_SERVER_ERROR, 191 + Json(json!({"error": "InternalError"})), 192 + ) 193 + .into_response(); 194 + } 195 + }; 196 + 197 + if let Err(e) = sqlx::query!( 188 198 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2", 189 199 password_hash, 190 200 user_id 191 201 ) 192 - .execute(&state.db) 193 - .await; 194 - 195 - if let Err(e) = update { 202 + .execute(&mut *tx) 203 + .await 204 + { 196 205 error!("DB error updating password: {:?}", e); 197 206 return ( 198 207 StatusCode::INTERNAL_SERVER_ERROR, ··· 201 210 .into_response(); 202 211 } 203 212 204 - let _ = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id) 205 - .execute(&state.db) 206 - .await; 213 + if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id) 214 + .execute(&mut *tx) 215 + .await 216 + { 217 + error!("Failed to invalidate sessions after password reset: {:?}", e); 218 + return ( 219 + StatusCode::INTERNAL_SERVER_ERROR, 220 + Json(json!({"error": "InternalError"})), 221 + ) 222 + .into_response(); 223 + } 224 + 225 + if let Err(e) = tx.commit().await { 226 + error!("Failed to commit password reset transaction: {:?}", e); 227 + return ( 228 + StatusCode::INTERNAL_SERVER_ERROR, 229 + Json(json!({"error": "InternalError"})), 230 + ) 231 + .into_response(); 232 + } 207 233 208 234 info!("Password reset completed for user {}", user_id); 209 235
+63
src/api/server/service_auth.rs
··· 1 + use crate::api::ApiError; 2 + use crate::state::AppState; 3 + use axum::{ 4 + Json, 5 + extract::{Query, State}, 6 + http::StatusCode, 7 + response::{IntoResponse, Response}, 8 + }; 9 + use serde::{Deserialize, Serialize}; 10 + use serde_json::json; 11 + use tracing::error; 12 + 13 + #[derive(Deserialize)] 14 + pub struct GetServiceAuthParams { 15 + pub aud: String, 16 + pub lxm: Option<String>, 17 + pub exp: Option<i64>, 18 + } 19 + 20 + #[derive(Serialize)] 21 + pub struct GetServiceAuthOutput { 22 + pub token: String, 23 + } 24 + 25 + pub async fn get_service_auth( 26 + State(state): State<AppState>, 27 + headers: axum::http::HeaderMap, 28 + Query(params): Query<GetServiceAuthParams>, 29 + ) -> Response { 30 + let token = match crate::auth::extract_bearer_token_from_header( 31 + headers.get("Authorization").and_then(|h| h.to_str().ok()) 32 + ) { 33 + Some(t) => t, 34 + None => return ApiError::AuthenticationRequired.into_response(), 35 + }; 36 + 37 + let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 38 + Ok(user) => user, 39 + Err(e) => return ApiError::from(e).into_response(), 40 + }; 41 + 42 + let key_bytes = match auth_user.key_bytes { 43 + Some(kb) => kb, 44 + None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot create service auth".into()).into_response(), 45 + }; 46 + 47 + let lxm = params.lxm.as_deref().unwrap_or("*"); 48 + 49 + let service_token = match crate::auth::create_service_token(&auth_user.did, &params.aud, lxm, &key_bytes) 50 + { 51 + Ok(t) => t, 52 + Err(e) => { 53 + error!("Failed to create service token: {:?}", e); 54 + return ( 55 + StatusCode::INTERNAL_SERVER_ERROR, 56 + Json(json!({"error": "InternalError"})), 57 + ) 58 + .into_response(); 59 + } 60 + }; 61 + 62 + (StatusCode::OK, Json(GetServiceAuthOutput { token: service_token })).into_response() 63 + }
+199 -512
src/api/server/session.rs
··· 1 + use crate::api::ApiError; 2 + use crate::auth::BearerAuth; 1 3 use crate::state::AppState; 2 4 use axum::{ 3 5 Json, 4 - extract::{Query, State}, 5 - http::StatusCode, 6 + extract::State, 6 7 response::{IntoResponse, Response}, 7 8 }; 8 9 use bcrypt::verify; ··· 11 12 use tracing::{error, info, warn}; 12 13 13 14 #[derive(Deserialize)] 14 - pub struct GetServiceAuthParams { 15 - pub aud: String, 16 - pub lxm: Option<String>, 17 - pub exp: Option<i64>, 18 - } 19 - 20 - #[derive(Serialize)] 21 - pub struct GetServiceAuthOutput { 22 - pub token: String, 23 - } 24 - 25 - pub async fn get_service_auth( 26 - State(state): State<AppState>, 27 - headers: axum::http::HeaderMap, 28 - Query(params): Query<GetServiceAuthParams>, 29 - ) -> Response { 30 - let token = match crate::auth::extract_bearer_token_from_header( 31 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 32 - ) { 33 - Some(t) => t, 34 - None => { 35 - return ( 36 - StatusCode::UNAUTHORIZED, 37 - Json(json!({"error": "AuthenticationRequired"})), 38 - ) 39 - .into_response(); 40 - } 41 - }; 42 - 43 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 44 - let (did, key_bytes) = match auth_result { 45 - Ok(user) => { 46 - let kb = match user.key_bytes { 47 - Some(kb) => kb, 48 - None => { 49 - return ( 50 - StatusCode::UNAUTHORIZED, 51 - Json(json!({"error": "AuthenticationFailed", "message": "OAuth tokens cannot create service auth"})), 52 - ) 53 - .into_response(); 54 - } 55 - }; 56 - (user.did, kb) 57 - } 58 - Err(e) => { 59 - return ( 60 - StatusCode::UNAUTHORIZED, 61 - Json(json!({"error": e})), 62 - ) 63 - .into_response(); 64 - } 65 - }; 66 - 67 - let lxm = params.lxm.as_deref().unwrap_or("*"); 68 - 69 - let service_token = match crate::auth::create_service_token(&did, &params.aud, lxm, &key_bytes) 70 - { 71 - Ok(t) => t, 72 - Err(e) => { 73 - error!("Failed to create service token: {:?}", e); 74 - return ( 75 - StatusCode::INTERNAL_SERVER_ERROR, 76 - Json(json!({"error": "InternalError"})), 77 - ) 78 - .into_response(); 79 - } 80 - }; 81 - 82 - (StatusCode::OK, Json(GetServiceAuthOutput { token: service_token })).into_response() 83 - } 84 - 85 - #[derive(Deserialize)] 86 15 pub struct CreateSessionInput { 87 16 pub identifier: String, 88 17 pub password: String, ··· 101 30 State(state): State<AppState>, 102 31 Json(input): Json<CreateSessionInput>, 103 32 ) -> Response { 104 - info!("create_session: identifier='{}'", input.identifier); 33 + info!("create_session called"); 105 34 106 - let user_row = sqlx::query!( 35 + let row = match sqlx::query!( 107 36 "SELECT u.id, u.did, u.handle, u.password_hash, k.key_bytes, k.encryption_version FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.handle = $1 OR u.email = $1", 108 37 input.identifier 109 38 ) 110 - .fetch_optional(&state.db) 111 - .await; 112 - 113 - match user_row { 114 - Ok(Some(row)) => { 115 - let user_id = row.id; 116 - let stored_hash = &row.password_hash; 117 - let did = &row.did; 118 - let handle = &row.handle; 119 - let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 120 - Ok(k) => k, 121 - Err(e) => { 122 - error!("Failed to decrypt user key: {:?}", e); 123 - return ( 124 - StatusCode::INTERNAL_SERVER_ERROR, 125 - Json(json!({"error": "InternalError"})), 126 - ) 127 - .into_response(); 128 - } 129 - }; 130 - 131 - let password_valid = if verify(&input.password, stored_hash).unwrap_or(false) { 132 - true 133 - } else { 134 - let app_pass_rows = sqlx::query!("SELECT password_hash FROM app_passwords WHERE user_id = $1", user_id) 135 - .fetch_all(&state.db) 136 - .await 137 - .unwrap_or_default(); 138 - 139 - app_pass_rows.iter().any(|row| { 140 - verify(&input.password, &row.password_hash).unwrap_or(false) 141 - }) 142 - }; 143 - 144 - if password_valid { 145 - let access_meta = match crate::auth::create_access_token_with_metadata(did, &key_bytes) { 146 - Ok(m) => m, 147 - Err(e) => { 148 - error!("Failed to create access token: {:?}", e); 149 - return ( 150 - StatusCode::INTERNAL_SERVER_ERROR, 151 - Json(json!({"error": "InternalError"})), 152 - ) 153 - .into_response(); 154 - } 155 - }; 156 - 157 - let refresh_meta = match crate::auth::create_refresh_token_with_metadata(did, &key_bytes) { 158 - Ok(m) => m, 159 - Err(e) => { 160 - error!("Failed to create refresh token: {:?}", e); 161 - return ( 162 - StatusCode::INTERNAL_SERVER_ERROR, 163 - Json(json!({"error": "InternalError"})), 164 - ) 165 - .into_response(); 166 - } 167 - }; 168 - 169 - let session_insert = sqlx::query!( 170 - "INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)", 171 - did, 172 - access_meta.jti, 173 - refresh_meta.jti, 174 - access_meta.expires_at, 175 - refresh_meta.expires_at 176 - ) 177 - .execute(&state.db) 178 - .await; 179 - 180 - match session_insert { 181 - Ok(_) => { 182 - return ( 183 - StatusCode::OK, 184 - Json(CreateSessionOutput { 185 - access_jwt: access_meta.token, 186 - refresh_jwt: refresh_meta.token, 187 - handle: handle.clone(), 188 - did: did.clone(), 189 - }), 190 - ) 191 - .into_response(); 192 - } 193 - Err(e) => { 194 - error!("Failed to insert session: {:?}", e); 195 - return ( 196 - StatusCode::INTERNAL_SERVER_ERROR, 197 - Json(json!({"error": "InternalError"})), 198 - ) 199 - .into_response(); 200 - } 201 - } 202 - } else { 203 - warn!( 204 - "Password verification failed for identifier: {}", 205 - input.identifier 206 - ); 207 - } 208 - } 39 + .fetch_optional(&state.db) 40 + .await 41 + { 42 + Ok(Some(row)) => row, 209 43 Ok(None) => { 210 - warn!("User not found for identifier: {}", input.identifier); 44 + warn!("User not found for login attempt"); 45 + return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response(); 211 46 } 212 47 Err(e) => { 213 48 error!("Database error fetching user: {:?}", e); 214 - return ( 215 - StatusCode::INTERNAL_SERVER_ERROR, 216 - Json(json!({"error": "InternalError"})), 217 - ) 218 - .into_response(); 49 + return ApiError::InternalError.into_response(); 50 + } 51 + }; 52 + 53 + let key_bytes = match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 54 + Ok(k) => k, 55 + Err(e) => { 56 + error!("Failed to decrypt user key: {:?}", e); 57 + return ApiError::InternalError.into_response(); 219 58 } 59 + }; 60 + 61 + let password_valid = verify(&input.password, &row.password_hash).unwrap_or(false) 62 + || sqlx::query!("SELECT password_hash FROM app_passwords WHERE user_id = $1", row.id) 63 + .fetch_all(&state.db) 64 + .await 65 + .unwrap_or_default() 66 + .iter() 67 + .any(|app| verify(&input.password, &app.password_hash).unwrap_or(false)); 68 + 69 + if !password_valid { 70 + warn!("Password verification failed for login attempt"); 71 + return ApiError::AuthenticationFailedMsg("Invalid identifier or password".into()).into_response(); 220 72 } 221 73 222 - ( 223 - StatusCode::UNAUTHORIZED, 224 - Json(json!({"error": "AuthenticationFailed", "message": "Invalid identifier or password"})), 225 - ) 226 - .into_response() 227 - } 228 - 229 - pub async fn get_session( 230 - State(state): State<AppState>, 231 - headers: axum::http::HeaderMap, 232 - ) -> Response { 233 - let token = match crate::auth::extract_bearer_token_from_header( 234 - headers.get("Authorization").and_then(|h| h.to_str().ok()) 235 - ) { 236 - Some(t) => t, 237 - None => { 238 - return ( 239 - StatusCode::UNAUTHORIZED, 240 - Json(json!({"error": "AuthenticationRequired", "message": "Invalid Authorization header format"})), 241 - ) 242 - .into_response(); 74 + let access_meta = match crate::auth::create_access_token_with_metadata(&row.did, &key_bytes) { 75 + Ok(m) => m, 76 + Err(e) => { 77 + error!("Failed to create access token: {:?}", e); 78 + return ApiError::InternalError.into_response(); 243 79 } 244 80 }; 245 81 246 - let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 247 - let did = match auth_result { 248 - Ok(user) => user.did, 82 + let refresh_meta = match crate::auth::create_refresh_token_with_metadata(&row.did, &key_bytes) { 83 + Ok(m) => m, 249 84 Err(e) => { 250 - return ( 251 - StatusCode::UNAUTHORIZED, 252 - Json(json!({"error": e})), 253 - ) 254 - .into_response(); 85 + error!("Failed to create refresh token: {:?}", e); 86 + return ApiError::InternalError.into_response(); 255 87 } 256 88 }; 257 89 258 - let user = sqlx::query!( 259 - "SELECT handle, email FROM users WHERE did = $1", 260 - did 90 + if let Err(e) = sqlx::query!( 91 + "INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)", 92 + row.did, 93 + access_meta.jti, 94 + refresh_meta.jti, 95 + access_meta.expires_at, 96 + refresh_meta.expires_at 261 97 ) 262 - .fetch_optional(&state.db) 263 - .await; 98 + .execute(&state.db) 99 + .await 100 + { 101 + error!("Failed to insert session: {:?}", e); 102 + return ApiError::InternalError.into_response(); 103 + } 264 104 265 - match user { 266 - Ok(Some(row)) => { 267 - return ( 268 - StatusCode::OK, 269 - Json(json!({ 270 - "handle": row.handle, 271 - "did": did, 272 - "email": row.email, 273 - "didDoc": {} 274 - })), 275 - ) 276 - .into_response(); 277 - } 278 - Ok(None) => { 279 - return ( 280 - StatusCode::UNAUTHORIZED, 281 - Json(json!({"error": "AuthenticationFailed"})), 282 - ) 283 - .into_response(); 284 - } 105 + Json(CreateSessionOutput { 106 + access_jwt: access_meta.token, 107 + refresh_jwt: refresh_meta.token, 108 + handle: row.handle, 109 + did: row.did, 110 + }).into_response() 111 + } 112 + 113 + pub async fn get_session( 114 + State(state): State<AppState>, 115 + BearerAuth(auth_user): BearerAuth, 116 + ) -> Response { 117 + match sqlx::query!("SELECT handle, email FROM users WHERE did = $1", auth_user.did) 118 + .fetch_optional(&state.db) 119 + .await 120 + { 121 + Ok(Some(row)) => Json(json!({ 122 + "handle": row.handle, 123 + "did": auth_user.did, 124 + "email": row.email, 125 + "didDoc": {} 126 + })).into_response(), 127 + Ok(None) => ApiError::AuthenticationFailed.into_response(), 285 128 Err(e) => { 286 129 error!("Database error in get_session: {:?}", e); 287 - return ( 288 - StatusCode::INTERNAL_SERVER_ERROR, 289 - Json(json!({"error": "InternalError"})), 290 - ) 291 - .into_response(); 130 + ApiError::InternalError.into_response() 292 131 } 293 132 } 294 133 } ··· 301 140 headers.get("Authorization").and_then(|h| h.to_str().ok()) 302 141 ) { 303 142 Some(t) => t, 304 - None => { 305 - return ( 306 - StatusCode::UNAUTHORIZED, 307 - Json(json!({"error": "AuthenticationRequired"})), 308 - ) 309 - .into_response(); 310 - } 143 + None => return ApiError::AuthenticationRequired.into_response(), 311 144 }; 312 145 313 - let jti = match crate::auth::get_did_from_token(&token) { 314 - Ok(_) => { 315 - let parts: Vec<&str> = token.split('.').collect(); 316 - if parts.len() != 3 { 317 - return ( 318 - StatusCode::UNAUTHORIZED, 319 - Json(json!({"error": "AuthenticationFailed"})), 320 - ) 321 - .into_response(); 322 - } 323 - use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 324 - let claims_json = match URL_SAFE_NO_PAD.decode(parts[1]) { 325 - Ok(bytes) => bytes, 326 - Err(_) => { 327 - return ( 328 - StatusCode::UNAUTHORIZED, 329 - Json(json!({"error": "AuthenticationFailed"})), 330 - ) 331 - .into_response(); 332 - } 333 - }; 334 - let claims: serde_json::Value = match serde_json::from_slice(&claims_json) { 335 - Ok(c) => c, 336 - Err(_) => { 337 - return ( 338 - StatusCode::UNAUTHORIZED, 339 - Json(json!({"error": "AuthenticationFailed"})), 340 - ) 341 - .into_response(); 342 - } 343 - }; 344 - match claims.get("jti").and_then(|j| j.as_str()) { 345 - Some(jti) => jti.to_string(), 346 - None => { 347 - return ( 348 - StatusCode::UNAUTHORIZED, 349 - Json(json!({"error": "AuthenticationFailed"})), 350 - ) 351 - .into_response(); 352 - } 353 - } 354 - } 355 - Err(_) => { 356 - return ( 357 - StatusCode::UNAUTHORIZED, 358 - Json(json!({"error": "AuthenticationFailed"})), 359 - ) 360 - .into_response(); 361 - } 146 + let jti = match crate::auth::get_jti_from_token(&token) { 147 + Ok(jti) => jti, 148 + Err(_) => return ApiError::AuthenticationFailed.into_response(), 362 149 }; 363 150 364 - let result = sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti) 151 + match sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti) 365 152 .execute(&state.db) 366 - .await; 367 - 368 - match result { 369 - Ok(res) => { 370 - if res.rows_affected() > 0 { 371 - return (StatusCode::OK, Json(json!({}))).into_response(); 372 - } 373 - } 153 + .await 154 + { 155 + Ok(res) if res.rows_affected() > 0 => Json(json!({})).into_response(), 156 + Ok(_) => ApiError::AuthenticationFailed.into_response(), 374 157 Err(e) => { 375 158 error!("Database error in delete_session: {:?}", e); 159 + ApiError::AuthenticationFailed.into_response() 376 160 } 377 161 } 378 - 379 - ( 380 - StatusCode::UNAUTHORIZED, 381 - Json(json!({"error": "AuthenticationFailed"})), 382 - ) 383 - .into_response() 384 162 } 385 163 386 164 pub async fn refresh_session( 387 165 State(state): State<AppState>, 388 166 headers: axum::http::HeaderMap, 389 167 ) -> Response { 390 - use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 391 - 392 168 let refresh_token = match crate::auth::extract_bearer_token_from_header( 393 169 headers.get("Authorization").and_then(|h| h.to_str().ok()) 394 170 ) { 395 171 Some(t) => t, 396 - None => { 397 - return ( 398 - StatusCode::UNAUTHORIZED, 399 - Json(json!({"error": "AuthenticationRequired"})), 400 - ) 401 - .into_response(); 402 - } 172 + None => return ApiError::AuthenticationRequired.into_response(), 403 173 }; 404 174 405 - let refresh_jti = { 406 - let parts: Vec<&str> = refresh_token.split('.').collect(); 407 - if parts.len() != 3 { 408 - return ( 409 - StatusCode::UNAUTHORIZED, 410 - Json(json!({"error": "AuthenticationFailed", "message": "Invalid token format"})), 411 - ) 412 - .into_response(); 413 - } 414 - let claims_bytes = match URL_SAFE_NO_PAD.decode(parts[1]) { 415 - Ok(b) => b, 416 - Err(_) => { 417 - return ( 418 - StatusCode::UNAUTHORIZED, 419 - Json(json!({"error": "AuthenticationFailed"})), 420 - ) 421 - .into_response(); 422 - } 423 - }; 424 - let claims: serde_json::Value = match serde_json::from_slice(&claims_bytes) { 425 - Ok(c) => c, 426 - Err(_) => { 427 - return ( 428 - StatusCode::UNAUTHORIZED, 429 - Json(json!({"error": "AuthenticationFailed"})), 430 - ) 431 - .into_response(); 432 - } 433 - }; 434 - match claims.get("jti").and_then(|j| j.as_str()) { 435 - Some(jti) => jti.to_string(), 436 - None => { 437 - return ( 438 - StatusCode::UNAUTHORIZED, 439 - Json(json!({"error": "AuthenticationFailed"})), 440 - ) 441 - .into_response(); 442 - } 175 + let refresh_jti = match crate::auth::get_jti_from_token(&refresh_token) { 176 + Ok(jti) => jti, 177 + Err(_) => return ApiError::AuthenticationFailedMsg("Invalid token format".into()).into_response(), 178 + }; 179 + 180 + let mut tx = match state.db.begin().await { 181 + Ok(tx) => tx, 182 + Err(e) => { 183 + error!("Failed to begin transaction: {:?}", e); 184 + return ApiError::InternalError.into_response(); 443 185 } 444 186 }; 445 187 446 - let reuse_check = sqlx::query_scalar!( 447 - "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1", 188 + if let Ok(Some(session_id)) = sqlx::query_scalar!( 189 + "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1 FOR UPDATE", 448 190 refresh_jti 449 191 ) 450 - .fetch_optional(&state.db) 451 - .await; 452 - 453 - if let Ok(Some(session_id)) = reuse_check { 192 + .fetch_optional(&mut *tx) 193 + .await 194 + { 454 195 warn!("Refresh token reuse detected! Revoking token family for session_id: {}", session_id); 455 196 let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id) 456 - .execute(&state.db) 197 + .execute(&mut *tx) 457 198 .await; 458 - return ( 459 - StatusCode::UNAUTHORIZED, 460 - Json(json!({"error": "ExpiredToken", "message": "Refresh token has been revoked due to suspected compromise"})), 461 - ) 462 - .into_response(); 199 + let _ = tx.commit().await; 200 + return ApiError::ExpiredTokenMsg("Refresh token has been revoked due to suspected compromise".into()).into_response(); 463 201 } 464 202 465 - let session = sqlx::query!( 203 + let session_row = match sqlx::query!( 466 204 r#"SELECT st.id, st.did, k.key_bytes, k.encryption_version 467 205 FROM session_tokens st 468 206 JOIN users u ON st.did = u.did 469 207 JOIN user_keys k ON u.id = k.user_id 470 - WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()"#, 208 + WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW() 209 + FOR UPDATE OF st"#, 471 210 refresh_jti 472 211 ) 473 - .fetch_optional(&state.db) 474 - .await; 212 + .fetch_optional(&mut *tx) 213 + .await 214 + { 215 + Ok(Some(row)) => row, 216 + Ok(None) => return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response(), 217 + Err(e) => { 218 + error!("Database error fetching session: {:?}", e); 219 + return ApiError::InternalError.into_response(); 220 + } 221 + }; 475 222 476 - match session { 477 - Ok(Some(session_row)) => { 478 - let session_id = session_row.id; 479 - let did = &session_row.did; 480 - let key_bytes = match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) { 481 - Ok(k) => k, 482 - Err(e) => { 483 - error!("Failed to decrypt user key: {:?}", e); 484 - return ( 485 - StatusCode::INTERNAL_SERVER_ERROR, 486 - Json(json!({"error": "InternalError"})), 487 - ) 488 - .into_response(); 489 - } 490 - }; 223 + let key_bytes = match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) { 224 + Ok(k) => k, 225 + Err(e) => { 226 + error!("Failed to decrypt user key: {:?}", e); 227 + return ApiError::InternalError.into_response(); 228 + } 229 + }; 491 230 492 - if let Err(_) = crate::auth::verify_refresh_token(&refresh_token, &key_bytes) { 493 - return (StatusCode::UNAUTHORIZED, Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"}))).into_response(); 494 - } 231 + if crate::auth::verify_refresh_token(&refresh_token, &key_bytes).is_err() { 232 + return ApiError::AuthenticationFailedMsg("Invalid refresh token".into()).into_response(); 233 + } 495 234 496 - let new_access_meta = match crate::auth::create_access_token_with_metadata(did, &key_bytes) { 497 - Ok(m) => m, 498 - Err(e) => { 499 - error!("Failed to create access token: {:?}", e); 500 - return ( 501 - StatusCode::INTERNAL_SERVER_ERROR, 502 - Json(json!({"error": "InternalError"})), 503 - ) 504 - .into_response(); 505 - } 506 - }; 507 - let new_refresh_meta = match crate::auth::create_refresh_token_with_metadata(did, &key_bytes) { 508 - Ok(m) => m, 509 - Err(e) => { 510 - error!("Failed to create refresh token: {:?}", e); 511 - return ( 512 - StatusCode::INTERNAL_SERVER_ERROR, 513 - Json(json!({"error": "InternalError"})), 514 - ) 515 - .into_response(); 516 - } 517 - }; 235 + let new_access_meta = match crate::auth::create_access_token_with_metadata(&session_row.did, &key_bytes) { 236 + Ok(m) => m, 237 + Err(e) => { 238 + error!("Failed to create access token: {:?}", e); 239 + return ApiError::InternalError.into_response(); 240 + } 241 + }; 518 242 519 - let mut tx = match state.db.begin().await { 520 - Ok(tx) => tx, 521 - Err(e) => { 522 - error!("Failed to begin transaction: {:?}", e); 523 - return ( 524 - StatusCode::INTERNAL_SERVER_ERROR, 525 - Json(json!({"error": "InternalError"})), 526 - ) 527 - .into_response(); 528 - } 529 - }; 243 + let new_refresh_meta = match crate::auth::create_refresh_token_with_metadata(&session_row.did, &key_bytes) { 244 + Ok(m) => m, 245 + Err(e) => { 246 + error!("Failed to create refresh token: {:?}", e); 247 + return ApiError::InternalError.into_response(); 248 + } 249 + }; 530 250 531 - if let Err(e) = sqlx::query!( 532 - "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2)", 533 - refresh_jti, 534 - session_id 535 - ) 536 - .execute(&mut *tx) 537 - .await 538 - { 539 - error!("Failed to record used refresh token: {:?}", e); 540 - return ( 541 - StatusCode::INTERNAL_SERVER_ERROR, 542 - Json(json!({"error": "InternalError"})), 543 - ) 544 - .into_response(); 545 - } 251 + match sqlx::query!( 252 + "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING", 253 + refresh_jti, 254 + session_row.id 255 + ) 256 + .execute(&mut *tx) 257 + .await 258 + { 259 + Ok(result) if result.rows_affected() == 0 => { 260 + warn!("Concurrent refresh token reuse detected for session_id: {}", session_row.id); 261 + let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_row.id) 262 + .execute(&mut *tx) 263 + .await; 264 + let _ = tx.commit().await; 265 + return ApiError::ExpiredTokenMsg("Refresh token has been revoked due to suspected compromise".into()).into_response(); 266 + } 267 + Err(e) => { 268 + error!("Failed to record used refresh token: {:?}", e); 269 + return ApiError::InternalError.into_response(); 270 + } 271 + Ok(_) => {} 272 + } 546 273 547 - if let Err(e) = sqlx::query!( 548 - "UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5", 549 - new_access_meta.jti, 550 - new_refresh_meta.jti, 551 - new_access_meta.expires_at, 552 - new_refresh_meta.expires_at, 553 - session_id 554 - ) 555 - .execute(&mut *tx) 556 - .await 557 - { 558 - error!("Database error updating session: {:?}", e); 559 - return ( 560 - StatusCode::INTERNAL_SERVER_ERROR, 561 - Json(json!({"error": "InternalError"})), 562 - ) 563 - .into_response(); 564 - } 274 + if let Err(e) = sqlx::query!( 275 + "UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5", 276 + new_access_meta.jti, 277 + new_refresh_meta.jti, 278 + new_access_meta.expires_at, 279 + new_refresh_meta.expires_at, 280 + session_row.id 281 + ) 282 + .execute(&mut *tx) 283 + .await 284 + { 285 + error!("Database error updating session: {:?}", e); 286 + return ApiError::InternalError.into_response(); 287 + } 565 288 566 - if let Err(e) = tx.commit().await { 567 - error!("Failed to commit transaction: {:?}", e); 568 - return ( 569 - StatusCode::INTERNAL_SERVER_ERROR, 570 - Json(json!({"error": "InternalError"})), 571 - ) 572 - .into_response(); 573 - } 289 + if let Err(e) = tx.commit().await { 290 + error!("Failed to commit transaction: {:?}", e); 291 + return ApiError::InternalError.into_response(); 292 + } 574 293 575 - let user = sqlx::query!("SELECT handle FROM users WHERE did = $1", did) 576 - .fetch_optional(&state.db) 577 - .await; 578 - 579 - match user { 580 - Ok(Some(u)) => { 581 - return ( 582 - StatusCode::OK, 583 - Json(json!({ 584 - "accessJwt": new_access_meta.token, 585 - "refreshJwt": new_refresh_meta.token, 586 - "handle": u.handle, 587 - "did": did 588 - })), 589 - ) 590 - .into_response(); 591 - } 592 - Ok(None) => { 593 - error!("User not found for existing session: {}", did); 594 - return ( 595 - StatusCode::INTERNAL_SERVER_ERROR, 596 - Json(json!({"error": "InternalError"})), 597 - ) 598 - .into_response(); 599 - } 600 - Err(e) => { 601 - error!("Database error fetching user: {:?}", e); 602 - return ( 603 - StatusCode::INTERNAL_SERVER_ERROR, 604 - Json(json!({"error": "InternalError"})), 605 - ) 606 - .into_response(); 607 - } 608 - } 609 - } 294 + match sqlx::query!("SELECT handle FROM users WHERE did = $1", session_row.did) 295 + .fetch_optional(&state.db) 296 + .await 297 + { 298 + Ok(Some(u)) => Json(json!({ 299 + "accessJwt": new_access_meta.token, 300 + "refreshJwt": new_refresh_meta.token, 301 + "handle": u.handle, 302 + "did": session_row.did 303 + })).into_response(), 610 304 Ok(None) => { 611 - return ( 612 - StatusCode::UNAUTHORIZED, 613 - Json(json!({"error": "AuthenticationFailed", "message": "Invalid refresh token"})), 614 - ) 615 - .into_response(); 305 + error!("User not found for existing session: {}", session_row.did); 306 + ApiError::InternalError.into_response() 616 307 } 617 308 Err(e) => { 618 - error!("Database error fetching session: {:?}", e); 619 - return ( 620 - StatusCode::INTERNAL_SERVER_ERROR, 621 - Json(json!({"error": "InternalError"})), 622 - ) 623 - .into_response(); 309 + error!("Database error fetching user: {:?}", e); 310 + ApiError::InternalError.into_response() 624 311 } 625 312 } 626 313 }
+104
src/api/validation.rs
··· 1 + pub const MAX_EMAIL_LENGTH: usize = 254; 2 + pub const MAX_LOCAL_PART_LENGTH: usize = 64; 3 + pub const MAX_DOMAIN_LENGTH: usize = 253; 4 + pub const MAX_DOMAIN_LABEL_LENGTH: usize = 63; 5 + 6 + const EMAIL_LOCAL_SPECIAL_CHARS: &str = ".!#$%&'*+/=?^_`{|}~-"; 7 + 8 + pub fn is_valid_email(email: &str) -> bool { 9 + let email = email.trim(); 10 + 11 + if email.is_empty() || email.len() > MAX_EMAIL_LENGTH { 12 + return false; 13 + } 14 + 15 + let parts: Vec<&str> = email.rsplitn(2, '@').collect(); 16 + if parts.len() != 2 { 17 + return false; 18 + } 19 + 20 + let domain = parts[0]; 21 + let local = parts[1]; 22 + 23 + if local.is_empty() || local.len() > MAX_LOCAL_PART_LENGTH { 24 + return false; 25 + } 26 + 27 + if local.starts_with('.') || local.ends_with('.') { 28 + return false; 29 + } 30 + 31 + if local.contains("..") { 32 + return false; 33 + } 34 + 35 + for c in local.chars() { 36 + if !c.is_ascii_alphanumeric() && !EMAIL_LOCAL_SPECIAL_CHARS.contains(c) { 37 + return false; 38 + } 39 + } 40 + 41 + if domain.is_empty() || domain.len() > MAX_DOMAIN_LENGTH { 42 + return false; 43 + } 44 + 45 + if !domain.contains('.') { 46 + return false; 47 + } 48 + 49 + for label in domain.split('.') { 50 + if label.is_empty() || label.len() > MAX_DOMAIN_LABEL_LENGTH { 51 + return false; 52 + } 53 + 54 + if label.starts_with('-') || label.ends_with('-') { 55 + return false; 56 + } 57 + 58 + for c in label.chars() { 59 + if !c.is_ascii_alphanumeric() && c != '-' { 60 + return false; 61 + } 62 + } 63 + } 64 + 65 + true 66 + } 67 + 68 + #[cfg(test)] 69 + mod tests { 70 + use super::*; 71 + 72 + #[test] 73 + fn test_valid_emails() { 74 + assert!(is_valid_email("user@example.com")); 75 + assert!(is_valid_email("user.name@example.com")); 76 + assert!(is_valid_email("user+tag@example.com")); 77 + assert!(is_valid_email("user@sub.example.com")); 78 + assert!(is_valid_email("USER@EXAMPLE.COM")); 79 + assert!(is_valid_email("user123@example123.com")); 80 + assert!(is_valid_email("a@b.co")); 81 + } 82 + 83 + #[test] 84 + fn test_invalid_emails() { 85 + assert!(!is_valid_email("")); 86 + assert!(!is_valid_email("user")); 87 + assert!(!is_valid_email("user@")); 88 + assert!(!is_valid_email("@example.com")); 89 + assert!(!is_valid_email("user@example")); 90 + assert!(!is_valid_email("user@@example.com")); 91 + assert!(!is_valid_email("user@.example.com")); 92 + assert!(!is_valid_email("user@example..com")); 93 + assert!(!is_valid_email(".user@example.com")); 94 + assert!(!is_valid_email("user.@example.com")); 95 + assert!(!is_valid_email("user..name@example.com")); 96 + assert!(!is_valid_email("user@-example.com")); 97 + assert!(!is_valid_email("user@example-.com")); 98 + } 99 + 100 + #[test] 101 + fn test_trimmed_whitespace() { 102 + assert!(is_valid_email(" user@example.com ")); 103 + } 104 + }
+29 -3
src/auth/extractor.rs
··· 7 7 use serde_json::json; 8 8 9 9 use crate::state::AppState; 10 - use super::{AuthenticatedUser, validate_bearer_token}; 10 + use super::{AuthenticatedUser, TokenValidationError, validate_bearer_token, validate_bearer_token_allow_deactivated}; 11 11 12 12 pub struct BearerAuth(pub AuthenticatedUser); 13 13 ··· 112 112 113 113 match validate_bearer_token(&state.db, token).await { 114 114 Ok(user) => Ok(BearerAuth(user)), 115 - Err("AccountDeactivated") => Err(AuthError::AccountDeactivated), 116 - Err("AccountTakedown") => Err(AuthError::AccountTakedown), 115 + Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 116 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 117 + Err(_) => Err(AuthError::AuthenticationFailed), 118 + } 119 + } 120 + } 121 + 122 + pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 123 + 124 + impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 125 + type Rejection = AuthError; 126 + 127 + async fn from_request_parts( 128 + parts: &mut Parts, 129 + state: &AppState, 130 + ) -> Result<Self, Self::Rejection> { 131 + let auth_header = parts 132 + .headers 133 + .get(AUTHORIZATION) 134 + .ok_or(AuthError::MissingToken)? 135 + .to_str() 136 + .map_err(|_| AuthError::InvalidFormat)?; 137 + 138 + let token = extract_bearer_token(auth_header)?; 139 + 140 + match validate_bearer_token_allow_deactivated(&state.db, token).await { 141 + Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 142 + Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 117 143 Err(_) => Err(AuthError::AuthenticationFailed), 118 144 } 119 145 }
+31 -13
src/auth/mod.rs
··· 1 1 use serde::{Deserialize, Serialize}; 2 2 use sqlx::PgPool; 3 + use std::fmt; 3 4 4 5 pub mod extractor; 5 6 pub mod token; 6 7 pub mod verify; 7 8 8 - pub use extractor::{BearerAuth, AuthError, extract_bearer_token_from_header}; 9 + pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header}; 9 10 pub use token::{ 10 11 create_access_token, create_refresh_token, create_service_token, 11 12 create_access_token_with_metadata, create_refresh_token_with_metadata, ··· 14 15 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, 15 16 }; 16 17 pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token}; 18 + 19 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 20 + pub enum TokenValidationError { 21 + AccountDeactivated, 22 + AccountTakedown, 23 + KeyDecryptionFailed, 24 + AuthenticationFailed, 25 + } 26 + 27 + impl fmt::Display for TokenValidationError { 28 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 29 + match self { 30 + Self::AccountDeactivated => write!(f, "AccountDeactivated"), 31 + Self::AccountTakedown => write!(f, "AccountTakedown"), 32 + Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), 33 + Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), 34 + } 35 + } 36 + } 17 37 18 38 pub struct AuthenticatedUser { 19 39 pub did: String, ··· 24 44 pub async fn validate_bearer_token( 25 45 db: &PgPool, 26 46 token: &str, 27 - ) -> Result<AuthenticatedUser, &'static str> { 47 + ) -> Result<AuthenticatedUser, TokenValidationError> { 28 48 validate_bearer_token_with_options(db, token, false).await 29 49 } 30 50 31 51 pub async fn validate_bearer_token_allow_deactivated( 32 52 db: &PgPool, 33 53 token: &str, 34 - ) -> Result<AuthenticatedUser, &'static str> { 54 + ) -> Result<AuthenticatedUser, TokenValidationError> { 35 55 validate_bearer_token_with_options(db, token, true).await 36 56 } 37 57 ··· 39 59 db: &PgPool, 40 60 token: &str, 41 61 allow_deactivated: bool, 42 - ) -> Result<AuthenticatedUser, &'static str> { 62 + ) -> Result<AuthenticatedUser, TokenValidationError> { 43 63 let did_from_token = get_did_from_token(token).ok(); 44 64 45 65 if let Some(ref did) = did_from_token { ··· 56 76 .flatten() 57 77 { 58 78 if !allow_deactivated && user.deactivated_at.is_some() { 59 - return Err("AccountDeactivated"); 79 + return Err(TokenValidationError::AccountDeactivated); 60 80 } 61 81 if user.takedown_ref.is_some() { 62 - return Err("AccountTakedown"); 82 + return Err(TokenValidationError::AccountTakedown); 63 83 } 64 84 65 - let decrypted_key = match crate::config::decrypt_key(&user.key_bytes, user.encryption_version) { 66 - Ok(k) => k, 67 - Err(_) => return Err("KeyDecryptionFailed"), 68 - }; 85 + let decrypted_key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 86 + .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 69 87 70 88 if let Ok(token_data) = verify_access_token(token, &decrypted_key) { 71 89 let session_exists = sqlx::query_scalar!( ··· 103 121 .flatten() 104 122 { 105 123 if !allow_deactivated && oauth_token.deactivated_at.is_some() { 106 - return Err("AccountDeactivated"); 124 + return Err(TokenValidationError::AccountDeactivated); 107 125 } 108 126 if oauth_token.takedown_ref.is_some() { 109 - return Err("AccountTakedown"); 127 + return Err(TokenValidationError::AccountTakedown); 110 128 } 111 129 112 130 let now = chrono::Utc::now(); ··· 120 138 } 121 139 } 122 140 123 - Err("AuthenticationFailed") 141 + Err(TokenValidationError::AuthenticationFailed) 124 142 } 125 143 126 144 #[derive(Debug, Serialize, Deserialize)]
+7 -3
src/config.rs
··· 62 62 let seed = hasher.finalize(); 63 63 64 64 let signing_key = SigningKey::from_slice(&seed) 65 - .expect("Failed to create signing key from seed"); 65 + .unwrap_or_else(|e| panic!("Failed to create signing key from seed: {}. This is a bug.", e)); 66 66 67 67 let verifying_key = signing_key.verifying_key(); 68 68 let point = verifying_key.to_encoded_point(false); 69 69 70 - let signing_key_x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 71 - let signing_key_y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 70 + let signing_key_x = URL_SAFE_NO_PAD.encode( 71 + point.x().expect("EC point missing X coordinate - this should never happen") 72 + ); 73 + let signing_key_y = URL_SAFE_NO_PAD.encode( 74 + point.y().expect("EC point missing Y coordinate - this should never happen") 75 + ); 72 76 73 77 let mut kid_hasher = Sha256::new(); 74 78 kid_hasher.update(signing_key_x.as_bytes());
+1
src/lib.rs
··· 8 8 pub mod state; 9 9 pub mod storage; 10 10 pub mod sync; 11 + pub mod util; 11 12 12 13 use axum::{ 13 14 Router,
+43 -15
src/main.rs
··· 1 1 use bspds::notifications::{EmailSender, NotificationService}; 2 2 use bspds::state::AppState; 3 3 use std::net::SocketAddr; 4 + use std::process::ExitCode; 4 5 use tokio::sync::watch; 5 - use tracing::{info, warn}; 6 + use tracing::{error, info, warn}; 6 7 7 8 #[tokio::main] 8 - async fn main() { 9 + async fn main() -> ExitCode { 9 10 dotenvy::dotenv().ok(); 10 11 tracing_subscriber::fmt::init(); 11 12 12 - let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); 13 + match run().await { 14 + Ok(()) => ExitCode::SUCCESS, 15 + Err(e) => { 16 + error!("Fatal error: {}", e); 17 + ExitCode::FAILURE 18 + } 19 + } 20 + } 21 + 22 + async fn run() -> Result<(), Box<dyn std::error::Error>> { 23 + let database_url = std::env::var("DATABASE_URL") 24 + .map_err(|_| "DATABASE_URL environment variable must be set")?; 13 25 14 26 let pool = sqlx::postgres::PgPoolOptions::new() 15 - .max_connections(5) 27 + .max_connections(20) 28 + .min_connections(2) 29 + .acquire_timeout(std::time::Duration::from_secs(10)) 30 + .idle_timeout(std::time::Duration::from_secs(300)) 31 + .max_lifetime(std::time::Duration::from_secs(1800)) 16 32 .connect(&database_url) 17 33 .await 18 - .expect("Failed to connect to Postgres"); 34 + .map_err(|e| format!("Failed to connect to Postgres: {}", e))?; 19 35 20 36 sqlx::migrate!("./migrations") 21 37 .run(&pool) 22 38 .await 23 - .expect("Failed to run migrations"); 39 + .map_err(|e| format!("Failed to run migrations: {}", e))?; 24 40 25 41 let state = AppState::new(pool.clone()).await; 26 42 ··· 50 66 51 67 let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); 52 68 info!("listening on {}", addr); 53 - let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); 69 + let listener = tokio::net::TcpListener::bind(addr) 70 + .await 71 + .map_err(|e| format!("Failed to bind to {}: {}", addr, e))?; 54 72 55 73 let server_result = axum::serve(listener, app) 56 74 .with_graceful_shutdown(shutdown_signal(shutdown_tx)) ··· 59 77 notification_handle.await.ok(); 60 78 61 79 if let Err(e) = server_result { 62 - tracing::error!("Server error: {}", e); 80 + return Err(format!("Server error: {}", e).into()); 63 81 } 82 + 83 + Ok(()) 64 84 } 65 85 66 86 async fn shutdown_signal(shutdown_tx: watch::Sender<bool>) { 67 87 let ctrl_c = async { 68 - tokio::signal::ctrl_c() 69 - .await 70 - .expect("Failed to install Ctrl+C handler"); 88 + match tokio::signal::ctrl_c().await { 89 + Ok(()) => {} 90 + Err(e) => { 91 + error!("Failed to install Ctrl+C handler: {}", e); 92 + } 93 + } 71 94 }; 72 95 73 96 #[cfg(unix)] 74 97 let terminate = async { 75 - tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) 76 - .expect("Failed to install signal handler") 77 - .recv() 78 - .await; 98 + match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) { 99 + Ok(mut signal) => { 100 + signal.recv().await; 101 + } 102 + Err(e) => { 103 + error!("Failed to install SIGTERM handler: {}", e); 104 + std::future::pending::<()>().await; 105 + } 106 + } 79 107 }; 80 108 81 109 #[cfg(not(unix))]
-641
src/oauth/db.rs
··· 1 - use chrono::{DateTime, Utc}; 2 - use serde::{de::DeserializeOwned, Serialize}; 3 - use sqlx::PgPool; 4 - 5 - use super::{ 6 - AuthorizationRequestParameters, ClientAuth, DeviceData, OAuthError, RequestData, TokenData, 7 - AuthorizedClientData, 8 - }; 9 - 10 - fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> { 11 - serde_json::to_value(value).map_err(|e| { 12 - tracing::error!("JSON serialization error: {}", e); 13 - OAuthError::ServerError("Internal serialization error".to_string()) 14 - }) 15 - } 16 - 17 - fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> { 18 - serde_json::from_value(value).map_err(|e| { 19 - tracing::error!("JSON deserialization error: {}", e); 20 - OAuthError::ServerError("Internal data corruption".to_string()) 21 - }) 22 - } 23 - 24 - pub async fn create_device( 25 - pool: &PgPool, 26 - device_id: &str, 27 - data: &DeviceData, 28 - ) -> Result<(), OAuthError> { 29 - sqlx::query!( 30 - r#" 31 - INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at) 32 - VALUES ($1, $2, $3, $4, $5) 33 - "#, 34 - device_id, 35 - data.session_id, 36 - data.user_agent, 37 - data.ip_address, 38 - data.last_seen_at, 39 - ) 40 - .execute(pool) 41 - .await?; 42 - 43 - Ok(()) 44 - } 45 - 46 - pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> { 47 - let row = sqlx::query!( 48 - r#" 49 - SELECT session_id, user_agent, ip_address, last_seen_at 50 - FROM oauth_device 51 - WHERE id = $1 52 - "#, 53 - device_id 54 - ) 55 - .fetch_optional(pool) 56 - .await?; 57 - 58 - Ok(row.map(|r| DeviceData { 59 - session_id: r.session_id, 60 - user_agent: r.user_agent, 61 - ip_address: r.ip_address, 62 - last_seen_at: r.last_seen_at, 63 - })) 64 - } 65 - 66 - pub async fn update_device_last_seen( 67 - pool: &PgPool, 68 - device_id: &str, 69 - ) -> Result<(), OAuthError> { 70 - sqlx::query!( 71 - r#" 72 - UPDATE oauth_device 73 - SET last_seen_at = NOW() 74 - WHERE id = $1 75 - "#, 76 - device_id 77 - ) 78 - .execute(pool) 79 - .await?; 80 - 81 - Ok(()) 82 - } 83 - 84 - pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> { 85 - sqlx::query!( 86 - r#" 87 - DELETE FROM oauth_device WHERE id = $1 88 - "#, 89 - device_id 90 - ) 91 - .execute(pool) 92 - .await?; 93 - 94 - Ok(()) 95 - } 96 - 97 - pub async fn create_authorization_request( 98 - pool: &PgPool, 99 - request_id: &str, 100 - data: &RequestData, 101 - ) -> Result<(), OAuthError> { 102 - let client_auth_json = match &data.client_auth { 103 - Some(ca) => Some(to_json(ca)?), 104 - None => None, 105 - }; 106 - let parameters_json = to_json(&data.parameters)?; 107 - 108 - sqlx::query!( 109 - r#" 110 - INSERT INTO oauth_authorization_request 111 - (id, did, device_id, client_id, client_auth, parameters, expires_at, code) 112 - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 113 - "#, 114 - request_id, 115 - data.did, 116 - data.device_id, 117 - data.client_id, 118 - client_auth_json, 119 - parameters_json, 120 - data.expires_at, 121 - data.code, 122 - ) 123 - .execute(pool) 124 - .await?; 125 - 126 - Ok(()) 127 - } 128 - 129 - pub async fn get_authorization_request( 130 - pool: &PgPool, 131 - request_id: &str, 132 - ) -> Result<Option<RequestData>, OAuthError> { 133 - let row = sqlx::query!( 134 - r#" 135 - SELECT did, device_id, client_id, client_auth, parameters, expires_at, code 136 - FROM oauth_authorization_request 137 - WHERE id = $1 138 - "#, 139 - request_id 140 - ) 141 - .fetch_optional(pool) 142 - .await?; 143 - 144 - match row { 145 - Some(r) => { 146 - let client_auth: Option<ClientAuth> = match r.client_auth { 147 - Some(v) => Some(from_json(v)?), 148 - None => None, 149 - }; 150 - let parameters: AuthorizationRequestParameters = from_json(r.parameters)?; 151 - 152 - Ok(Some(RequestData { 153 - client_id: r.client_id, 154 - client_auth, 155 - parameters, 156 - expires_at: r.expires_at, 157 - did: r.did, 158 - device_id: r.device_id, 159 - code: r.code, 160 - })) 161 - } 162 - None => Ok(None), 163 - } 164 - } 165 - 166 - pub async fn update_authorization_request( 167 - pool: &PgPool, 168 - request_id: &str, 169 - did: &str, 170 - device_id: Option<&str>, 171 - code: &str, 172 - ) -> Result<(), OAuthError> { 173 - sqlx::query!( 174 - r#" 175 - UPDATE oauth_authorization_request 176 - SET did = $2, device_id = $3, code = $4 177 - WHERE id = $1 178 - "#, 179 - request_id, 180 - did, 181 - device_id, 182 - code 183 - ) 184 - .execute(pool) 185 - .await?; 186 - 187 - Ok(()) 188 - } 189 - 190 - pub async fn consume_authorization_request_by_code( 191 - pool: &PgPool, 192 - code: &str, 193 - ) -> Result<Option<RequestData>, OAuthError> { 194 - let row = sqlx::query!( 195 - r#" 196 - DELETE FROM oauth_authorization_request 197 - WHERE code = $1 198 - RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code 199 - "#, 200 - code 201 - ) 202 - .fetch_optional(pool) 203 - .await?; 204 - 205 - match row { 206 - Some(r) => { 207 - let client_auth: Option<ClientAuth> = match r.client_auth { 208 - Some(v) => Some(from_json(v)?), 209 - None => None, 210 - }; 211 - let parameters: AuthorizationRequestParameters = from_json(r.parameters)?; 212 - 213 - Ok(Some(RequestData { 214 - client_id: r.client_id, 215 - client_auth, 216 - parameters, 217 - expires_at: r.expires_at, 218 - did: r.did, 219 - device_id: r.device_id, 220 - code: r.code, 221 - })) 222 - } 223 - None => Ok(None), 224 - } 225 - } 226 - 227 - pub async fn delete_authorization_request( 228 - pool: &PgPool, 229 - request_id: &str, 230 - ) -> Result<(), OAuthError> { 231 - sqlx::query!( 232 - r#" 233 - DELETE FROM oauth_authorization_request WHERE id = $1 234 - "#, 235 - request_id 236 - ) 237 - .execute(pool) 238 - .await?; 239 - 240 - Ok(()) 241 - } 242 - 243 - pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> { 244 - let result = sqlx::query!( 245 - r#" 246 - DELETE FROM oauth_authorization_request 247 - WHERE expires_at < NOW() 248 - "# 249 - ) 250 - .execute(pool) 251 - .await?; 252 - 253 - Ok(result.rows_affected()) 254 - } 255 - 256 - pub async fn create_token( 257 - pool: &PgPool, 258 - data: &TokenData, 259 - ) -> Result<i32, OAuthError> { 260 - let client_auth_json = to_json(&data.client_auth)?; 261 - let parameters_json = to_json(&data.parameters)?; 262 - 263 - let row = sqlx::query!( 264 - r#" 265 - INSERT INTO oauth_token 266 - (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 267 - device_id, parameters, details, code, current_refresh_token, scope) 268 - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) 269 - RETURNING id 270 - "#, 271 - data.did, 272 - data.token_id, 273 - data.created_at, 274 - data.updated_at, 275 - data.expires_at, 276 - data.client_id, 277 - client_auth_json, 278 - data.device_id, 279 - parameters_json, 280 - data.details, 281 - data.code, 282 - data.current_refresh_token, 283 - data.scope, 284 - ) 285 - .fetch_one(pool) 286 - .await?; 287 - 288 - Ok(row.id) 289 - } 290 - 291 - pub async fn get_token_by_id( 292 - pool: &PgPool, 293 - token_id: &str, 294 - ) -> Result<Option<TokenData>, OAuthError> { 295 - let row = sqlx::query!( 296 - r#" 297 - SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 298 - device_id, parameters, details, code, current_refresh_token, scope 299 - FROM oauth_token 300 - WHERE token_id = $1 301 - "#, 302 - token_id 303 - ) 304 - .fetch_optional(pool) 305 - .await?; 306 - 307 - match row { 308 - Some(r) => Ok(Some(TokenData { 309 - did: r.did, 310 - token_id: r.token_id, 311 - created_at: r.created_at, 312 - updated_at: r.updated_at, 313 - expires_at: r.expires_at, 314 - client_id: r.client_id, 315 - client_auth: from_json(r.client_auth)?, 316 - device_id: r.device_id, 317 - parameters: from_json(r.parameters)?, 318 - details: r.details, 319 - code: r.code, 320 - current_refresh_token: r.current_refresh_token, 321 - scope: r.scope, 322 - })), 323 - None => Ok(None), 324 - } 325 - } 326 - 327 - pub async fn get_token_by_refresh_token( 328 - pool: &PgPool, 329 - refresh_token: &str, 330 - ) -> Result<Option<(i32, TokenData)>, OAuthError> { 331 - let row = sqlx::query!( 332 - r#" 333 - SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 334 - device_id, parameters, details, code, current_refresh_token, scope 335 - FROM oauth_token 336 - WHERE current_refresh_token = $1 337 - "#, 338 - refresh_token 339 - ) 340 - .fetch_optional(pool) 341 - .await?; 342 - 343 - match row { 344 - Some(r) => Ok(Some(( 345 - r.id, 346 - TokenData { 347 - did: r.did, 348 - token_id: r.token_id, 349 - created_at: r.created_at, 350 - updated_at: r.updated_at, 351 - expires_at: r.expires_at, 352 - client_id: r.client_id, 353 - client_auth: from_json(r.client_auth)?, 354 - device_id: r.device_id, 355 - parameters: from_json(r.parameters)?, 356 - details: r.details, 357 - code: r.code, 358 - current_refresh_token: r.current_refresh_token, 359 - scope: r.scope, 360 - }, 361 - ))), 362 - None => Ok(None), 363 - } 364 - } 365 - 366 - pub async fn rotate_token( 367 - pool: &PgPool, 368 - old_db_id: i32, 369 - new_token_id: &str, 370 - new_refresh_token: &str, 371 - new_expires_at: DateTime<Utc>, 372 - ) -> Result<(), OAuthError> { 373 - let mut tx = pool.begin().await?; 374 - 375 - let old_refresh = sqlx::query_scalar!( 376 - r#" 377 - SELECT current_refresh_token FROM oauth_token WHERE id = $1 378 - "#, 379 - old_db_id 380 - ) 381 - .fetch_one(&mut *tx) 382 - .await?; 383 - 384 - if let Some(old_rt) = old_refresh { 385 - sqlx::query!( 386 - r#" 387 - INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 388 - VALUES ($1, $2) 389 - "#, 390 - old_rt, 391 - old_db_id 392 - ) 393 - .execute(&mut *tx) 394 - .await?; 395 - } 396 - 397 - sqlx::query!( 398 - r#" 399 - UPDATE oauth_token 400 - SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW() 401 - WHERE id = $1 402 - "#, 403 - old_db_id, 404 - new_token_id, 405 - new_refresh_token, 406 - new_expires_at 407 - ) 408 - .execute(&mut *tx) 409 - .await?; 410 - 411 - tx.commit().await?; 412 - Ok(()) 413 - } 414 - 415 - pub async fn check_refresh_token_used( 416 - pool: &PgPool, 417 - refresh_token: &str, 418 - ) -> Result<Option<i32>, OAuthError> { 419 - let row = sqlx::query_scalar!( 420 - r#" 421 - SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 422 - "#, 423 - refresh_token 424 - ) 425 - .fetch_optional(pool) 426 - .await?; 427 - 428 - Ok(row) 429 - } 430 - 431 - pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 432 - sqlx::query!( 433 - r#" 434 - DELETE FROM oauth_token WHERE token_id = $1 435 - "#, 436 - token_id 437 - ) 438 - .execute(pool) 439 - .await?; 440 - 441 - Ok(()) 442 - } 443 - 444 - pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 445 - sqlx::query!( 446 - r#" 447 - DELETE FROM oauth_token WHERE id = $1 448 - "#, 449 - db_id 450 - ) 451 - .execute(pool) 452 - .await?; 453 - 454 - Ok(()) 455 - } 456 - 457 - pub async fn upsert_account_device( 458 - pool: &PgPool, 459 - did: &str, 460 - device_id: &str, 461 - ) -> Result<(), OAuthError> { 462 - sqlx::query!( 463 - r#" 464 - INSERT INTO oauth_account_device (did, device_id, created_at, updated_at) 465 - VALUES ($1, $2, NOW(), NOW()) 466 - ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW() 467 - "#, 468 - did, 469 - device_id 470 - ) 471 - .execute(pool) 472 - .await?; 473 - 474 - Ok(()) 475 - } 476 - 477 - pub async fn upsert_authorized_client( 478 - pool: &PgPool, 479 - did: &str, 480 - client_id: &str, 481 - data: &AuthorizedClientData, 482 - ) -> Result<(), OAuthError> { 483 - let data_json = to_json(data)?; 484 - 485 - sqlx::query!( 486 - r#" 487 - INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data) 488 - VALUES ($1, $2, NOW(), NOW(), $3) 489 - ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3 490 - "#, 491 - did, 492 - client_id, 493 - data_json 494 - ) 495 - .execute(pool) 496 - .await?; 497 - 498 - Ok(()) 499 - } 500 - 501 - pub async fn get_authorized_client( 502 - pool: &PgPool, 503 - did: &str, 504 - client_id: &str, 505 - ) -> Result<Option<AuthorizedClientData>, OAuthError> { 506 - let row = sqlx::query_scalar!( 507 - r#" 508 - SELECT data FROM oauth_authorized_client 509 - WHERE did = $1 AND client_id = $2 510 - "#, 511 - did, 512 - client_id 513 - ) 514 - .fetch_optional(pool) 515 - .await?; 516 - 517 - match row { 518 - Some(v) => Ok(Some(from_json(v)?)), 519 - None => Ok(None), 520 - } 521 - } 522 - 523 - pub async fn list_tokens_for_user( 524 - pool: &PgPool, 525 - did: &str, 526 - ) -> Result<Vec<TokenData>, OAuthError> { 527 - let rows = sqlx::query!( 528 - r#" 529 - SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 530 - device_id, parameters, details, code, current_refresh_token, scope 531 - FROM oauth_token 532 - WHERE did = $1 533 - "#, 534 - did 535 - ) 536 - .fetch_all(pool) 537 - .await?; 538 - 539 - let mut tokens = Vec::with_capacity(rows.len()); 540 - for r in rows { 541 - tokens.push(TokenData { 542 - did: r.did, 543 - token_id: r.token_id, 544 - created_at: r.created_at, 545 - updated_at: r.updated_at, 546 - expires_at: r.expires_at, 547 - client_id: r.client_id, 548 - client_auth: from_json(r.client_auth)?, 549 - device_id: r.device_id, 550 - parameters: from_json(r.parameters)?, 551 - details: r.details, 552 - code: r.code, 553 - current_refresh_token: r.current_refresh_token, 554 - scope: r.scope, 555 - }); 556 - } 557 - Ok(tokens) 558 - } 559 - 560 - pub async fn check_and_record_dpop_jti( 561 - pool: &PgPool, 562 - jti: &str, 563 - ) -> Result<bool, OAuthError> { 564 - let result = sqlx::query!( 565 - r#" 566 - INSERT INTO oauth_dpop_jti (jti) 567 - VALUES ($1) 568 - ON CONFLICT (jti) DO NOTHING 569 - "#, 570 - jti 571 - ) 572 - .execute(pool) 573 - .await?; 574 - 575 - Ok(result.rows_affected() > 0) 576 - } 577 - 578 - pub async fn cleanup_expired_dpop_jtis( 579 - pool: &PgPool, 580 - max_age_secs: i64, 581 - ) -> Result<u64, OAuthError> { 582 - let result = sqlx::query!( 583 - r#" 584 - DELETE FROM oauth_dpop_jti 585 - WHERE created_at < NOW() - INTERVAL '1 second' * $1 586 - "#, 587 - max_age_secs as f64 588 - ) 589 - .execute(pool) 590 - .await?; 591 - 592 - Ok(result.rows_affected()) 593 - } 594 - 595 - pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 596 - let count = sqlx::query_scalar!( 597 - r#" 598 - SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1 599 - "#, 600 - did 601 - ) 602 - .fetch_one(pool) 603 - .await?; 604 - 605 - Ok(count) 606 - } 607 - 608 - pub async fn delete_oldest_tokens_for_user( 609 - pool: &PgPool, 610 - did: &str, 611 - keep_count: i64, 612 - ) -> Result<u64, OAuthError> { 613 - let result = sqlx::query!( 614 - r#" 615 - DELETE FROM oauth_token 616 - WHERE id IN ( 617 - SELECT id FROM oauth_token 618 - WHERE did = $1 619 - ORDER BY updated_at ASC 620 - OFFSET $2 621 - ) 622 - "#, 623 - did, 624 - keep_count 625 - ) 626 - .execute(pool) 627 - .await?; 628 - 629 - Ok(result.rows_affected()) 630 - } 631 - 632 - const MAX_TOKENS_PER_USER: i64 = 100; 633 - 634 - pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 635 - let count = count_tokens_for_user(pool, did).await?; 636 - if count > MAX_TOKENS_PER_USER { 637 - let to_keep = MAX_TOKENS_PER_USER - 1; 638 - delete_oldest_tokens_for_user(pool, did, to_keep).await?; 639 - } 640 - Ok(()) 641 - }
+50
src/oauth/db/client.rs
··· 1 + use sqlx::PgPool; 2 + 3 + use super::super::{AuthorizedClientData, OAuthError}; 4 + use super::helpers::{from_json, to_json}; 5 + 6 + pub async fn upsert_authorized_client( 7 + pool: &PgPool, 8 + did: &str, 9 + client_id: &str, 10 + data: &AuthorizedClientData, 11 + ) -> Result<(), OAuthError> { 12 + let data_json = to_json(data)?; 13 + 14 + sqlx::query!( 15 + r#" 16 + INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data) 17 + VALUES ($1, $2, NOW(), NOW(), $3) 18 + ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3 19 + "#, 20 + did, 21 + client_id, 22 + data_json 23 + ) 24 + .execute(pool) 25 + .await?; 26 + 27 + Ok(()) 28 + } 29 + 30 + pub async fn get_authorized_client( 31 + pool: &PgPool, 32 + did: &str, 33 + client_id: &str, 34 + ) -> Result<Option<AuthorizedClientData>, OAuthError> { 35 + let row = sqlx::query_scalar!( 36 + r#" 37 + SELECT data FROM oauth_authorized_client 38 + WHERE did = $1 AND client_id = $2 39 + "#, 40 + did, 41 + client_id 42 + ) 43 + .fetch_optional(pool) 44 + .await?; 45 + 46 + match row { 47 + Some(v) => Ok(Some(from_json(v)?)), 48 + None => Ok(None), 49 + } 50 + }
+96
src/oauth/db/device.rs
··· 1 + use sqlx::PgPool; 2 + 3 + use super::super::{DeviceData, OAuthError}; 4 + 5 + pub async fn create_device( 6 + pool: &PgPool, 7 + device_id: &str, 8 + data: &DeviceData, 9 + ) -> Result<(), OAuthError> { 10 + sqlx::query!( 11 + r#" 12 + INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at) 13 + VALUES ($1, $2, $3, $4, $5) 14 + "#, 15 + device_id, 16 + data.session_id, 17 + data.user_agent, 18 + data.ip_address, 19 + data.last_seen_at, 20 + ) 21 + .execute(pool) 22 + .await?; 23 + 24 + Ok(()) 25 + } 26 + 27 + pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> { 28 + let row = sqlx::query!( 29 + r#" 30 + SELECT session_id, user_agent, ip_address, last_seen_at 31 + FROM oauth_device 32 + WHERE id = $1 33 + "#, 34 + device_id 35 + ) 36 + .fetch_optional(pool) 37 + .await?; 38 + 39 + Ok(row.map(|r| DeviceData { 40 + session_id: r.session_id, 41 + user_agent: r.user_agent, 42 + ip_address: r.ip_address, 43 + last_seen_at: r.last_seen_at, 44 + })) 45 + } 46 + 47 + pub async fn update_device_last_seen( 48 + pool: &PgPool, 49 + device_id: &str, 50 + ) -> Result<(), OAuthError> { 51 + sqlx::query!( 52 + r#" 53 + UPDATE oauth_device 54 + SET last_seen_at = NOW() 55 + WHERE id = $1 56 + "#, 57 + device_id 58 + ) 59 + .execute(pool) 60 + .await?; 61 + 62 + Ok(()) 63 + } 64 + 65 + pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> { 66 + sqlx::query!( 67 + r#" 68 + DELETE FROM oauth_device WHERE id = $1 69 + "#, 70 + device_id 71 + ) 72 + .execute(pool) 73 + .await?; 74 + 75 + Ok(()) 76 + } 77 + 78 + pub async fn upsert_account_device( 79 + pool: &PgPool, 80 + did: &str, 81 + device_id: &str, 82 + ) -> Result<(), OAuthError> { 83 + sqlx::query!( 84 + r#" 85 + INSERT INTO oauth_account_device (did, device_id, created_at, updated_at) 86 + VALUES ($1, $2, NOW(), NOW()) 87 + ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW() 88 + "#, 89 + did, 90 + device_id 91 + ) 92 + .execute(pool) 93 + .await?; 94 + 95 + Ok(()) 96 + }
+38
src/oauth/db/dpop.rs
··· 1 + use sqlx::PgPool; 2 + 3 + use super::super::OAuthError; 4 + 5 + pub async fn check_and_record_dpop_jti( 6 + pool: &PgPool, 7 + jti: &str, 8 + ) -> Result<bool, OAuthError> { 9 + let result = sqlx::query!( 10 + r#" 11 + INSERT INTO oauth_dpop_jti (jti) 12 + VALUES ($1) 13 + ON CONFLICT (jti) DO NOTHING 14 + "#, 15 + jti 16 + ) 17 + .execute(pool) 18 + .await?; 19 + 20 + Ok(result.rows_affected() > 0) 21 + } 22 + 23 + pub async fn cleanup_expired_dpop_jtis( 24 + pool: &PgPool, 25 + max_age_secs: i64, 26 + ) -> Result<u64, OAuthError> { 27 + let result = sqlx::query!( 28 + r#" 29 + DELETE FROM oauth_dpop_jti 30 + WHERE created_at < NOW() - INTERVAL '1 second' * $1 31 + "#, 32 + max_age_secs as f64 33 + ) 34 + .execute(pool) 35 + .await?; 36 + 37 + Ok(result.rows_affected()) 38 + }
+17
src/oauth/db/helpers.rs
··· 1 + use serde::{de::DeserializeOwned, Serialize}; 2 + 3 + use super::super::OAuthError; 4 + 5 + pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> { 6 + serde_json::to_value(value).map_err(|e| { 7 + tracing::error!("JSON serialization error: {}", e); 8 + OAuthError::ServerError("Internal serialization error".to_string()) 9 + }) 10 + } 11 + 12 + pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> { 13 + serde_json::from_value(value).map_err(|e| { 14 + tracing::error!("JSON deserialization error: {}", e); 15 + OAuthError::ServerError("Internal data corruption".to_string()) 16 + }) 17 + }
+22
src/oauth/db/mod.rs
··· 1 + mod client; 2 + mod device; 3 + mod dpop; 4 + mod helpers; 5 + mod request; 6 + mod token; 7 + 8 + pub use client::{get_authorized_client, upsert_authorized_client}; 9 + pub use device::{ 10 + create_device, delete_device, get_device, update_device_last_seen, upsert_account_device, 11 + }; 12 + pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis}; 13 + pub use request::{ 14 + consume_authorization_request_by_code, create_authorization_request, 15 + delete_authorization_request, delete_expired_authorization_requests, get_authorization_request, 16 + update_authorization_request, 17 + }; 18 + pub use token::{ 19 + check_refresh_token_used, count_tokens_for_user, create_token, delete_oldest_tokens_for_user, 20 + delete_token, delete_token_family, enforce_token_limit_for_user, get_token_by_id, 21 + get_token_by_refresh_token, list_tokens_for_user, rotate_token, 22 + };
+163
src/oauth/db/request.rs
··· 1 + use sqlx::PgPool; 2 + 3 + use super::super::{AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData}; 4 + use super::helpers::{from_json, to_json}; 5 + 6 + pub async fn create_authorization_request( 7 + pool: &PgPool, 8 + request_id: &str, 9 + data: &RequestData, 10 + ) -> Result<(), OAuthError> { 11 + let client_auth_json = match &data.client_auth { 12 + Some(ca) => Some(to_json(ca)?), 13 + None => None, 14 + }; 15 + let parameters_json = to_json(&data.parameters)?; 16 + 17 + sqlx::query!( 18 + r#" 19 + INSERT INTO oauth_authorization_request 20 + (id, did, device_id, client_id, client_auth, parameters, expires_at, code) 21 + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 22 + "#, 23 + request_id, 24 + data.did, 25 + data.device_id, 26 + data.client_id, 27 + client_auth_json, 28 + parameters_json, 29 + data.expires_at, 30 + data.code, 31 + ) 32 + .execute(pool) 33 + .await?; 34 + 35 + Ok(()) 36 + } 37 + 38 + pub async fn get_authorization_request( 39 + pool: &PgPool, 40 + request_id: &str, 41 + ) -> Result<Option<RequestData>, OAuthError> { 42 + let row = sqlx::query!( 43 + r#" 44 + SELECT did, device_id, client_id, client_auth, parameters, expires_at, code 45 + FROM oauth_authorization_request 46 + WHERE id = $1 47 + "#, 48 + request_id 49 + ) 50 + .fetch_optional(pool) 51 + .await?; 52 + 53 + match row { 54 + Some(r) => { 55 + let client_auth: Option<ClientAuth> = match r.client_auth { 56 + Some(v) => Some(from_json(v)?), 57 + None => None, 58 + }; 59 + let parameters: AuthorizationRequestParameters = from_json(r.parameters)?; 60 + 61 + Ok(Some(RequestData { 62 + client_id: r.client_id, 63 + client_auth, 64 + parameters, 65 + expires_at: r.expires_at, 66 + did: r.did, 67 + device_id: r.device_id, 68 + code: r.code, 69 + })) 70 + } 71 + None => Ok(None), 72 + } 73 + } 74 + 75 + pub async fn update_authorization_request( 76 + pool: &PgPool, 77 + request_id: &str, 78 + did: &str, 79 + device_id: Option<&str>, 80 + code: &str, 81 + ) -> Result<(), OAuthError> { 82 + sqlx::query!( 83 + r#" 84 + UPDATE oauth_authorization_request 85 + SET did = $2, device_id = $3, code = $4 86 + WHERE id = $1 87 + "#, 88 + request_id, 89 + did, 90 + device_id, 91 + code 92 + ) 93 + .execute(pool) 94 + .await?; 95 + 96 + Ok(()) 97 + } 98 + 99 + pub async fn consume_authorization_request_by_code( 100 + pool: &PgPool, 101 + code: &str, 102 + ) -> Result<Option<RequestData>, OAuthError> { 103 + let row = sqlx::query!( 104 + r#" 105 + DELETE FROM oauth_authorization_request 106 + WHERE code = $1 107 + RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code 108 + "#, 109 + code 110 + ) 111 + .fetch_optional(pool) 112 + .await?; 113 + 114 + match row { 115 + Some(r) => { 116 + let client_auth: Option<ClientAuth> = match r.client_auth { 117 + Some(v) => Some(from_json(v)?), 118 + None => None, 119 + }; 120 + let parameters: AuthorizationRequestParameters = from_json(r.parameters)?; 121 + 122 + Ok(Some(RequestData { 123 + client_id: r.client_id, 124 + client_auth, 125 + parameters, 126 + expires_at: r.expires_at, 127 + did: r.did, 128 + device_id: r.device_id, 129 + code: r.code, 130 + })) 131 + } 132 + None => Ok(None), 133 + } 134 + } 135 + 136 + pub async fn delete_authorization_request( 137 + pool: &PgPool, 138 + request_id: &str, 139 + ) -> Result<(), OAuthError> { 140 + sqlx::query!( 141 + r#" 142 + DELETE FROM oauth_authorization_request WHERE id = $1 143 + "#, 144 + request_id 145 + ) 146 + .execute(pool) 147 + .await?; 148 + 149 + Ok(()) 150 + } 151 + 152 + pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> { 153 + let result = sqlx::query!( 154 + r#" 155 + DELETE FROM oauth_authorization_request 156 + WHERE expires_at < NOW() 157 + "# 158 + ) 159 + .execute(pool) 160 + .await?; 161 + 162 + Ok(result.rows_affected()) 163 + }
+291
src/oauth/db/token.rs
··· 1 + use chrono::{DateTime, Utc}; 2 + use sqlx::PgPool; 3 + 4 + use super::super::{OAuthError, TokenData}; 5 + use super::helpers::{from_json, to_json}; 6 + 7 + pub async fn create_token( 8 + pool: &PgPool, 9 + data: &TokenData, 10 + ) -> Result<i32, OAuthError> { 11 + let client_auth_json = to_json(&data.client_auth)?; 12 + let parameters_json = to_json(&data.parameters)?; 13 + 14 + let row = sqlx::query!( 15 + r#" 16 + INSERT INTO oauth_token 17 + (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 18 + device_id, parameters, details, code, current_refresh_token, scope) 19 + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) 20 + RETURNING id 21 + "#, 22 + data.did, 23 + data.token_id, 24 + data.created_at, 25 + data.updated_at, 26 + data.expires_at, 27 + data.client_id, 28 + client_auth_json, 29 + data.device_id, 30 + parameters_json, 31 + data.details, 32 + data.code, 33 + data.current_refresh_token, 34 + data.scope, 35 + ) 36 + .fetch_one(pool) 37 + .await?; 38 + 39 + Ok(row.id) 40 + } 41 + 42 + pub async fn get_token_by_id( 43 + pool: &PgPool, 44 + token_id: &str, 45 + ) -> Result<Option<TokenData>, OAuthError> { 46 + let row = sqlx::query!( 47 + r#" 48 + SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 49 + device_id, parameters, details, code, current_refresh_token, scope 50 + FROM oauth_token 51 + WHERE token_id = $1 52 + "#, 53 + token_id 54 + ) 55 + .fetch_optional(pool) 56 + .await?; 57 + 58 + match row { 59 + Some(r) => Ok(Some(TokenData { 60 + did: r.did, 61 + token_id: r.token_id, 62 + created_at: r.created_at, 63 + updated_at: r.updated_at, 64 + expires_at: r.expires_at, 65 + client_id: r.client_id, 66 + client_auth: from_json(r.client_auth)?, 67 + device_id: r.device_id, 68 + parameters: from_json(r.parameters)?, 69 + details: r.details, 70 + code: r.code, 71 + current_refresh_token: r.current_refresh_token, 72 + scope: r.scope, 73 + })), 74 + None => Ok(None), 75 + } 76 + } 77 + 78 + pub async fn get_token_by_refresh_token( 79 + pool: &PgPool, 80 + refresh_token: &str, 81 + ) -> Result<Option<(i32, TokenData)>, OAuthError> { 82 + let row = sqlx::query!( 83 + r#" 84 + SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 85 + device_id, parameters, details, code, current_refresh_token, scope 86 + FROM oauth_token 87 + WHERE current_refresh_token = $1 88 + "#, 89 + refresh_token 90 + ) 91 + .fetch_optional(pool) 92 + .await?; 93 + 94 + match row { 95 + Some(r) => Ok(Some(( 96 + r.id, 97 + TokenData { 98 + did: r.did, 99 + token_id: r.token_id, 100 + created_at: r.created_at, 101 + updated_at: r.updated_at, 102 + expires_at: r.expires_at, 103 + client_id: r.client_id, 104 + client_auth: from_json(r.client_auth)?, 105 + device_id: r.device_id, 106 + parameters: from_json(r.parameters)?, 107 + details: r.details, 108 + code: r.code, 109 + current_refresh_token: r.current_refresh_token, 110 + scope: r.scope, 111 + }, 112 + ))), 113 + None => Ok(None), 114 + } 115 + } 116 + 117 + pub async fn rotate_token( 118 + pool: &PgPool, 119 + old_db_id: i32, 120 + new_token_id: &str, 121 + new_refresh_token: &str, 122 + new_expires_at: DateTime<Utc>, 123 + ) -> Result<(), OAuthError> { 124 + let mut tx = pool.begin().await?; 125 + 126 + let old_refresh = sqlx::query_scalar!( 127 + r#" 128 + SELECT current_refresh_token FROM oauth_token WHERE id = $1 129 + "#, 130 + old_db_id 131 + ) 132 + .fetch_one(&mut *tx) 133 + .await?; 134 + 135 + if let Some(old_rt) = old_refresh { 136 + sqlx::query!( 137 + r#" 138 + INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 139 + VALUES ($1, $2) 140 + "#, 141 + old_rt, 142 + old_db_id 143 + ) 144 + .execute(&mut *tx) 145 + .await?; 146 + } 147 + 148 + sqlx::query!( 149 + r#" 150 + UPDATE oauth_token 151 + SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW() 152 + WHERE id = $1 153 + "#, 154 + old_db_id, 155 + new_token_id, 156 + new_refresh_token, 157 + new_expires_at 158 + ) 159 + .execute(&mut *tx) 160 + .await?; 161 + 162 + tx.commit().await?; 163 + Ok(()) 164 + } 165 + 166 + pub async fn check_refresh_token_used( 167 + pool: &PgPool, 168 + refresh_token: &str, 169 + ) -> Result<Option<i32>, OAuthError> { 170 + let row = sqlx::query_scalar!( 171 + r#" 172 + SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 173 + "#, 174 + refresh_token 175 + ) 176 + .fetch_optional(pool) 177 + .await?; 178 + 179 + Ok(row) 180 + } 181 + 182 + pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 183 + sqlx::query!( 184 + r#" 185 + DELETE FROM oauth_token WHERE token_id = $1 186 + "#, 187 + token_id 188 + ) 189 + .execute(pool) 190 + .await?; 191 + 192 + Ok(()) 193 + } 194 + 195 + pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 196 + sqlx::query!( 197 + r#" 198 + DELETE FROM oauth_token WHERE id = $1 199 + "#, 200 + db_id 201 + ) 202 + .execute(pool) 203 + .await?; 204 + 205 + Ok(()) 206 + } 207 + 208 + pub async fn list_tokens_for_user( 209 + pool: &PgPool, 210 + did: &str, 211 + ) -> Result<Vec<TokenData>, OAuthError> { 212 + let rows = sqlx::query!( 213 + r#" 214 + SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 215 + device_id, parameters, details, code, current_refresh_token, scope 216 + FROM oauth_token 217 + WHERE did = $1 218 + "#, 219 + did 220 + ) 221 + .fetch_all(pool) 222 + .await?; 223 + 224 + let mut tokens = Vec::with_capacity(rows.len()); 225 + for r in rows { 226 + tokens.push(TokenData { 227 + did: r.did, 228 + token_id: r.token_id, 229 + created_at: r.created_at, 230 + updated_at: r.updated_at, 231 + expires_at: r.expires_at, 232 + client_id: r.client_id, 233 + client_auth: from_json(r.client_auth)?, 234 + device_id: r.device_id, 235 + parameters: from_json(r.parameters)?, 236 + details: r.details, 237 + code: r.code, 238 + current_refresh_token: r.current_refresh_token, 239 + scope: r.scope, 240 + }); 241 + } 242 + Ok(tokens) 243 + } 244 + 245 + pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 246 + let count = sqlx::query_scalar!( 247 + r#" 248 + SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1 249 + "#, 250 + did 251 + ) 252 + .fetch_one(pool) 253 + .await?; 254 + 255 + Ok(count) 256 + } 257 + 258 + pub async fn delete_oldest_tokens_for_user( 259 + pool: &PgPool, 260 + did: &str, 261 + keep_count: i64, 262 + ) -> Result<u64, OAuthError> { 263 + let result = sqlx::query!( 264 + r#" 265 + DELETE FROM oauth_token 266 + WHERE id IN ( 267 + SELECT id FROM oauth_token 268 + WHERE did = $1 269 + ORDER BY updated_at ASC 270 + OFFSET $2 271 + ) 272 + "#, 273 + did, 274 + keep_count 275 + ) 276 + .execute(pool) 277 + .await?; 278 + 279 + Ok(result.rows_affected()) 280 + } 281 + 282 + const MAX_TOKENS_PER_USER: i64 = 100; 283 + 284 + pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 285 + let count = count_tokens_for_user(pool, did).await?; 286 + if count > MAX_TOKENS_PER_USER { 287 + let to_keep = MAX_TOKENS_PER_USER - 1; 288 + delete_oldest_tokens_for_user(pool, did, to_keep).await?; 289 + } 290 + Ok(()) 291 + }
+8 -10
src/oauth/dpop.rs
··· 237 237 false, 238 238 ); 239 239 240 - let affine = AffinePoint::from_encoded_point(&point); 241 - if affine.is_none().into() { 242 - return Err(OAuthError::InvalidDpopProof("Invalid EC point".to_string())); 243 - } 240 + let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into(); 241 + let affine = affine_opt 242 + .ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?; 244 243 245 - let verifying_key = VerifyingKey::from_affine(affine.unwrap()) 244 + let verifying_key = VerifyingKey::from_affine(affine) 246 245 .map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?; 247 246 248 247 let sig = Signature::from_slice(signature) ··· 287 286 false, 288 287 ); 289 288 290 - let affine = AffinePoint::from_encoded_point(&point); 291 - if affine.is_none().into() { 292 - return Err(OAuthError::InvalidDpopProof("Invalid EC point".to_string())); 293 - } 289 + let affine_opt: Option<AffinePoint> = AffinePoint::from_encoded_point(&point).into(); 290 + let affine = affine_opt 291 + .ok_or_else(|| OAuthError::InvalidDpopProof("Invalid EC point".to_string()))?; 294 292 295 - let verifying_key = VerifyingKey::from_affine(affine.unwrap()) 293 + let verifying_key = VerifyingKey::from_affine(affine) 296 294 .map_err(|_| OAuthError::InvalidDpopProof("Invalid verifying key".to_string()))?; 297 295 298 296 let sig = Signature::from_slice(signature)
-558
src/oauth/endpoints/token.rs
··· 1 - use axum::{ 2 - Form, Json, 3 - extract::State, 4 - http::{HeaderMap, StatusCode}, 5 - }; 6 - use base64::Engine; 7 - use base64::engine::general_purpose::URL_SAFE_NO_PAD; 8 - use chrono::{Duration, Utc}; 9 - use hmac::Mac; 10 - use serde::{Deserialize, Serialize}; 11 - use sha2::{Digest, Sha256}; 12 - use subtle::ConstantTimeEq; 13 - 14 - use crate::config::AuthConfig; 15 - use crate::state::AppState; 16 - use crate::oauth::{ 17 - ClientAuth, OAuthError, RefreshToken, TokenData, TokenId, 18 - client::{ClientMetadataCache, verify_client_auth}, 19 - db, 20 - dpop::DPoPVerifier, 21 - }; 22 - 23 - const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 24 - const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60; 25 - 26 - #[derive(Debug, Deserialize)] 27 - pub struct TokenRequest { 28 - pub grant_type: String, 29 - #[serde(default)] 30 - pub code: Option<String>, 31 - #[serde(default)] 32 - pub redirect_uri: Option<String>, 33 - #[serde(default)] 34 - pub code_verifier: Option<String>, 35 - #[serde(default)] 36 - pub refresh_token: Option<String>, 37 - #[serde(default)] 38 - pub client_id: Option<String>, 39 - #[serde(default)] 40 - pub client_secret: Option<String>, 41 - #[serde(default)] 42 - pub client_assertion: Option<String>, 43 - #[serde(default)] 44 - pub client_assertion_type: Option<String>, 45 - } 46 - 47 - #[derive(Debug, Serialize)] 48 - pub struct TokenResponse { 49 - pub access_token: String, 50 - pub token_type: String, 51 - pub expires_in: u64, 52 - #[serde(skip_serializing_if = "Option::is_none")] 53 - pub refresh_token: Option<String>, 54 - #[serde(skip_serializing_if = "Option::is_none")] 55 - pub scope: Option<String>, 56 - #[serde(skip_serializing_if = "Option::is_none")] 57 - pub sub: Option<String>, 58 - } 59 - 60 - pub async fn token_endpoint( 61 - State(state): State<AppState>, 62 - headers: HeaderMap, 63 - Form(request): Form<TokenRequest>, 64 - ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 65 - let dpop_proof = headers 66 - .get("DPoP") 67 - .and_then(|v| v.to_str().ok()) 68 - .map(|s| s.to_string()); 69 - 70 - match request.grant_type.as_str() { 71 - "authorization_code" => { 72 - handle_authorization_code_grant(state, headers, request, dpop_proof).await 73 - } 74 - "refresh_token" => { 75 - handle_refresh_token_grant(state, headers, request, dpop_proof).await 76 - } 77 - _ => Err(OAuthError::UnsupportedGrantType(format!( 78 - "Unsupported grant_type: {}", 79 - request.grant_type 80 - ))), 81 - } 82 - } 83 - 84 - async fn handle_authorization_code_grant( 85 - state: AppState, 86 - _headers: HeaderMap, 87 - request: TokenRequest, 88 - dpop_proof: Option<String>, 89 - ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 90 - let code = request 91 - .code 92 - .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?; 93 - 94 - let code_verifier = request 95 - .code_verifier 96 - .ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?; 97 - 98 - let auth_request = db::consume_authorization_request_by_code(&state.db, &code) 99 - .await? 100 - .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?; 101 - 102 - if auth_request.expires_at < Utc::now() { 103 - return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string())); 104 - } 105 - 106 - if let Some(request_client_id) = &request.client_id { 107 - if request_client_id != &auth_request.client_id { 108 - return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 109 - } 110 - } 111 - 112 - let did = auth_request 113 - .did 114 - .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; 115 - 116 - let client_metadata_cache = ClientMetadataCache::new(3600); 117 - let client_metadata = client_metadata_cache 118 - .get(&auth_request.client_id) 119 - .await?; 120 - let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None); 121 - verify_client_auth(&client_metadata, &client_auth)?; 122 - 123 - verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 124 - 125 - if let Some(redirect_uri) = &request.redirect_uri { 126 - if redirect_uri != &auth_request.parameters.redirect_uri { 127 - return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string())); 128 - } 129 - } 130 - 131 - let dpop_jkt = if let Some(proof) = &dpop_proof { 132 - let config = AuthConfig::get(); 133 - let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 134 - 135 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 136 - let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 137 - 138 - let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 139 - 140 - if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 141 - return Err(OAuthError::InvalidDpopProof( 142 - "DPoP proof has already been used".to_string(), 143 - )); 144 - } 145 - 146 - if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt { 147 - if &result.jkt != expected_jkt { 148 - return Err(OAuthError::InvalidDpopProof( 149 - "DPoP key binding mismatch".to_string(), 150 - )); 151 - } 152 - } 153 - 154 - Some(result.jkt) 155 - } else if auth_request.parameters.dpop_jkt.is_some() { 156 - return Err(OAuthError::InvalidRequest( 157 - "DPoP proof required for this authorization".to_string(), 158 - )); 159 - } else { 160 - None 161 - }; 162 - 163 - let token_id = TokenId::generate(); 164 - let refresh_token = RefreshToken::generate(); 165 - let now = Utc::now(); 166 - 167 - let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?; 168 - 169 - let token_data = TokenData { 170 - did: did.clone(), 171 - token_id: token_id.0.clone(), 172 - created_at: now, 173 - updated_at: now, 174 - expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS), 175 - client_id: auth_request.client_id.clone(), 176 - client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None), 177 - device_id: auth_request.device_id, 178 - parameters: auth_request.parameters.clone(), 179 - details: None, 180 - code: None, 181 - current_refresh_token: Some(refresh_token.0.clone()), 182 - scope: auth_request.parameters.scope.clone(), 183 - }; 184 - 185 - db::create_token(&state.db, &token_data).await?; 186 - 187 - tokio::spawn({ 188 - let pool = state.db.clone(); 189 - let did_clone = did.clone(); 190 - async move { 191 - if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await { 192 - tracing::warn!("Failed to enforce token limit for user: {:?}", e); 193 - } 194 - } 195 - }); 196 - 197 - let mut response_headers = HeaderMap::new(); 198 - let config = AuthConfig::get(); 199 - let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 200 - response_headers.insert( 201 - "DPoP-Nonce", 202 - verifier.generate_nonce().parse().unwrap(), 203 - ); 204 - 205 - Ok(( 206 - response_headers, 207 - Json(TokenResponse { 208 - access_token, 209 - token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 210 - expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 211 - refresh_token: Some(refresh_token.0), 212 - scope: auth_request.parameters.scope, 213 - sub: Some(did), 214 - }), 215 - )) 216 - } 217 - 218 - async fn handle_refresh_token_grant( 219 - state: AppState, 220 - _headers: HeaderMap, 221 - request: TokenRequest, 222 - dpop_proof: Option<String>, 223 - ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 224 - let refresh_token_str = request 225 - .refresh_token 226 - .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?; 227 - 228 - if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? { 229 - db::delete_token_family(&state.db, token_id).await?; 230 - return Err(OAuthError::InvalidGrant( 231 - "Refresh token reuse detected, token family revoked".to_string(), 232 - )); 233 - } 234 - 235 - let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str) 236 - .await? 237 - .ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?; 238 - 239 - if token_data.expires_at < Utc::now() { 240 - db::delete_token_family(&state.db, db_id).await?; 241 - return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string())); 242 - } 243 - 244 - let dpop_jkt = if let Some(proof) = &dpop_proof { 245 - let config = AuthConfig::get(); 246 - let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 247 - 248 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 249 - let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 250 - 251 - let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 252 - 253 - if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 254 - return Err(OAuthError::InvalidDpopProof( 255 - "DPoP proof has already been used".to_string(), 256 - )); 257 - } 258 - 259 - if let Some(expected_jkt) = &token_data.parameters.dpop_jkt { 260 - if &result.jkt != expected_jkt { 261 - return Err(OAuthError::InvalidDpopProof( 262 - "DPoP key binding mismatch".to_string(), 263 - )); 264 - } 265 - } 266 - 267 - Some(result.jkt) 268 - } else if token_data.parameters.dpop_jkt.is_some() { 269 - return Err(OAuthError::InvalidRequest( 270 - "DPoP proof required".to_string(), 271 - )); 272 - } else { 273 - None 274 - }; 275 - 276 - let new_token_id = TokenId::generate(); 277 - let new_refresh_token = RefreshToken::generate(); 278 - let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS); 279 - 280 - db::rotate_token( 281 - &state.db, 282 - db_id, 283 - &new_token_id.0, 284 - &new_refresh_token.0, 285 - new_expires_at, 286 - ) 287 - .await?; 288 - 289 - let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?; 290 - 291 - let mut response_headers = HeaderMap::new(); 292 - let config = AuthConfig::get(); 293 - let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 294 - response_headers.insert( 295 - "DPoP-Nonce", 296 - verifier.generate_nonce().parse().unwrap(), 297 - ); 298 - 299 - Ok(( 300 - response_headers, 301 - Json(TokenResponse { 302 - access_token, 303 - token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 304 - expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 305 - refresh_token: Some(new_refresh_token.0), 306 - scope: token_data.scope, 307 - sub: Some(token_data.did), 308 - }), 309 - )) 310 - } 311 - 312 - fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> { 313 - use subtle::ConstantTimeEq; 314 - 315 - let mut hasher = Sha256::new(); 316 - hasher.update(code_verifier.as_bytes()); 317 - let hash = hasher.finalize(); 318 - let computed_challenge = URL_SAFE_NO_PAD.encode(&hash); 319 - 320 - if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) { 321 - return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string())); 322 - } 323 - 324 - Ok(()) 325 - } 326 - 327 - fn create_access_token( 328 - token_id: &str, 329 - sub: &str, 330 - dpop_jkt: Option<&str>, 331 - ) -> Result<String, OAuthError> { 332 - use serde_json::json; 333 - 334 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 335 - let issuer = format!("https://{}", pds_hostname); 336 - 337 - let now = Utc::now().timestamp(); 338 - let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS; 339 - 340 - let mut payload = json!({ 341 - "iss": issuer, 342 - "sub": sub, 343 - "aud": issuer, 344 - "iat": now, 345 - "exp": exp, 346 - "jti": token_id, 347 - "scope": "atproto" 348 - }); 349 - 350 - if let Some(jkt) = dpop_jkt { 351 - payload["cnf"] = json!({ "jkt": jkt }); 352 - } 353 - 354 - let header = json!({ 355 - "alg": "HS256", 356 - "typ": "at+jwt" 357 - }); 358 - 359 - let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 360 - let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 361 - 362 - let signing_input = format!("{}.{}", header_b64, payload_b64); 363 - 364 - let config = AuthConfig::get(); 365 - 366 - use sha2::Sha256 as HmacSha256; 367 - use hmac::{Hmac, Mac}; 368 - type HmacSha256Type = Hmac<HmacSha256>; 369 - 370 - let mut mac = HmacSha256Type::new_from_slice(config.jwt_secret().as_bytes()) 371 - .map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?; 372 - mac.update(signing_input.as_bytes()); 373 - let signature = mac.finalize().into_bytes(); 374 - 375 - let signature_b64 = URL_SAFE_NO_PAD.encode(&signature); 376 - 377 - Ok(format!("{}.{}", signing_input, signature_b64)) 378 - } 379 - 380 - pub async fn revoke_token( 381 - State(state): State<AppState>, 382 - Form(request): Form<RevokeRequest>, 383 - ) -> Result<StatusCode, OAuthError> { 384 - if let Some(token) = &request.token { 385 - if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? { 386 - db::delete_token_family(&state.db, db_id).await?; 387 - } else { 388 - db::delete_token(&state.db, token).await?; 389 - } 390 - } 391 - 392 - Ok(StatusCode::OK) 393 - } 394 - 395 - #[derive(Debug, Deserialize)] 396 - pub struct RevokeRequest { 397 - pub token: Option<String>, 398 - #[serde(default)] 399 - pub token_type_hint: Option<String>, 400 - } 401 - 402 - #[derive(Debug, Deserialize)] 403 - pub struct IntrospectRequest { 404 - pub token: String, 405 - #[serde(default)] 406 - pub token_type_hint: Option<String>, 407 - } 408 - 409 - #[derive(Debug, Serialize)] 410 - pub struct IntrospectResponse { 411 - pub active: bool, 412 - #[serde(skip_serializing_if = "Option::is_none")] 413 - pub scope: Option<String>, 414 - #[serde(skip_serializing_if = "Option::is_none")] 415 - pub client_id: Option<String>, 416 - #[serde(skip_serializing_if = "Option::is_none")] 417 - pub username: Option<String>, 418 - #[serde(skip_serializing_if = "Option::is_none")] 419 - pub token_type: Option<String>, 420 - #[serde(skip_serializing_if = "Option::is_none")] 421 - pub exp: Option<i64>, 422 - #[serde(skip_serializing_if = "Option::is_none")] 423 - pub iat: Option<i64>, 424 - #[serde(skip_serializing_if = "Option::is_none")] 425 - pub nbf: Option<i64>, 426 - #[serde(skip_serializing_if = "Option::is_none")] 427 - pub sub: Option<String>, 428 - #[serde(skip_serializing_if = "Option::is_none")] 429 - pub aud: Option<String>, 430 - #[serde(skip_serializing_if = "Option::is_none")] 431 - pub iss: Option<String>, 432 - #[serde(skip_serializing_if = "Option::is_none")] 433 - pub jti: Option<String>, 434 - } 435 - 436 - pub async fn introspect_token( 437 - State(state): State<AppState>, 438 - Form(request): Form<IntrospectRequest>, 439 - ) -> Json<IntrospectResponse> { 440 - let inactive_response = IntrospectResponse { 441 - active: false, 442 - scope: None, 443 - client_id: None, 444 - username: None, 445 - token_type: None, 446 - exp: None, 447 - iat: None, 448 - nbf: None, 449 - sub: None, 450 - aud: None, 451 - iss: None, 452 - jti: None, 453 - }; 454 - 455 - let token_info = match extract_token_claims(&request.token) { 456 - Ok(info) => info, 457 - Err(_) => return Json(inactive_response), 458 - }; 459 - 460 - let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await { 461 - Ok(Some(data)) => data, 462 - _ => return Json(inactive_response), 463 - }; 464 - 465 - if token_data.expires_at < Utc::now() { 466 - return Json(inactive_response); 467 - } 468 - 469 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 470 - let issuer = format!("https://{}", pds_hostname); 471 - 472 - Json(IntrospectResponse { 473 - active: true, 474 - scope: token_data.scope, 475 - client_id: Some(token_data.client_id), 476 - username: None, 477 - token_type: if token_data.parameters.dpop_jkt.is_some() { 478 - Some("DPoP".to_string()) 479 - } else { 480 - Some("Bearer".to_string()) 481 - }, 482 - exp: Some(token_info.exp), 483 - iat: Some(token_info.iat), 484 - nbf: Some(token_info.iat), 485 - sub: Some(token_data.did), 486 - aud: Some(issuer.clone()), 487 - iss: Some(issuer), 488 - jti: Some(token_info.jti), 489 - }) 490 - } 491 - 492 - struct TokenClaims { 493 - jti: String, 494 - exp: i64, 495 - iat: i64, 496 - } 497 - 498 - fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> { 499 - let parts: Vec<&str> = token.split('.').collect(); 500 - if parts.len() != 3 { 501 - return Err(OAuthError::InvalidToken("Invalid token format".to_string())); 502 - } 503 - 504 - let header_bytes = URL_SAFE_NO_PAD 505 - .decode(parts[0]) 506 - .map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?; 507 - let header: serde_json::Value = serde_json::from_slice(&header_bytes) 508 - .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?; 509 - 510 - if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") { 511 - return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string())); 512 - } 513 - if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") { 514 - return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string())); 515 - } 516 - 517 - let config = AuthConfig::get(); 518 - let secret = config.jwt_secret(); 519 - 520 - let signing_input = format!("{}.{}", parts[0], parts[1]); 521 - let provided_sig = URL_SAFE_NO_PAD 522 - .decode(parts[2]) 523 - .map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?; 524 - 525 - type HmacSha256 = hmac::Hmac<Sha256>; 526 - let mut mac = HmacSha256::new_from_slice(secret.as_bytes()) 527 - .map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?; 528 - mac.update(signing_input.as_bytes()); 529 - let expected_sig = mac.finalize().into_bytes(); 530 - 531 - if !bool::from(expected_sig.ct_eq(&provided_sig)) { 532 - return Err(OAuthError::InvalidToken("Invalid token signature".to_string())); 533 - } 534 - 535 - let payload_bytes = URL_SAFE_NO_PAD 536 - .decode(parts[1]) 537 - .map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?; 538 - let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) 539 - .map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?; 540 - 541 - let jti = payload 542 - .get("jti") 543 - .and_then(|j| j.as_str()) 544 - .ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))? 545 - .to_string(); 546 - 547 - let exp = payload 548 - .get("exp") 549 - .and_then(|e| e.as_i64()) 550 - .ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?; 551 - 552 - let iat = payload 553 - .get("iat") 554 - .and_then(|i| i.as_i64()) 555 - .ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?; 556 - 557 - Ok(TokenClaims { jti, exp, iat }) 558 - }
+246
src/oauth/endpoints/token/grants.rs
··· 1 + use axum::http::HeaderMap; 2 + use axum::Json; 3 + use chrono::{Duration, Utc}; 4 + 5 + use crate::config::AuthConfig; 6 + use crate::state::AppState; 7 + use crate::oauth::{ 8 + ClientAuth, OAuthError, RefreshToken, TokenData, TokenId, 9 + client::{ClientMetadataCache, verify_client_auth}, 10 + db, 11 + dpop::DPoPVerifier, 12 + }; 13 + 14 + use super::types::{TokenRequest, TokenResponse}; 15 + use super::helpers::{create_access_token, verify_pkce}; 16 + 17 + const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 18 + const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60; 19 + 20 + pub async fn handle_authorization_code_grant( 21 + state: AppState, 22 + _headers: HeaderMap, 23 + request: TokenRequest, 24 + dpop_proof: Option<String>, 25 + ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 26 + let code = request 27 + .code 28 + .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?; 29 + 30 + let code_verifier = request 31 + .code_verifier 32 + .ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?; 33 + 34 + let auth_request = db::consume_authorization_request_by_code(&state.db, &code) 35 + .await? 36 + .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?; 37 + 38 + if auth_request.expires_at < Utc::now() { 39 + return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string())); 40 + } 41 + 42 + if let Some(request_client_id) = &request.client_id { 43 + if request_client_id != &auth_request.client_id { 44 + return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 45 + } 46 + } 47 + 48 + let did = auth_request 49 + .did 50 + .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; 51 + 52 + let client_metadata_cache = ClientMetadataCache::new(3600); 53 + let client_metadata = client_metadata_cache 54 + .get(&auth_request.client_id) 55 + .await?; 56 + let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None); 57 + verify_client_auth(&client_metadata, &client_auth)?; 58 + 59 + verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 60 + 61 + if let Some(redirect_uri) = &request.redirect_uri { 62 + if redirect_uri != &auth_request.parameters.redirect_uri { 63 + return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string())); 64 + } 65 + } 66 + 67 + let dpop_jkt = if let Some(proof) = &dpop_proof { 68 + let config = AuthConfig::get(); 69 + let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 70 + 71 + let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 72 + let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 73 + 74 + let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 75 + 76 + if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 77 + return Err(OAuthError::InvalidDpopProof( 78 + "DPoP proof has already been used".to_string(), 79 + )); 80 + } 81 + 82 + if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt { 83 + if &result.jkt != expected_jkt { 84 + return Err(OAuthError::InvalidDpopProof( 85 + "DPoP key binding mismatch".to_string(), 86 + )); 87 + } 88 + } 89 + 90 + Some(result.jkt) 91 + } else if auth_request.parameters.dpop_jkt.is_some() { 92 + return Err(OAuthError::InvalidRequest( 93 + "DPoP proof required for this authorization".to_string(), 94 + )); 95 + } else { 96 + None 97 + }; 98 + 99 + let token_id = TokenId::generate(); 100 + let refresh_token = RefreshToken::generate(); 101 + let now = Utc::now(); 102 + 103 + let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?; 104 + 105 + let token_data = TokenData { 106 + did: did.clone(), 107 + token_id: token_id.0.clone(), 108 + created_at: now, 109 + updated_at: now, 110 + expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS), 111 + client_id: auth_request.client_id.clone(), 112 + client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None), 113 + device_id: auth_request.device_id, 114 + parameters: auth_request.parameters.clone(), 115 + details: None, 116 + code: None, 117 + current_refresh_token: Some(refresh_token.0.clone()), 118 + scope: auth_request.parameters.scope.clone(), 119 + }; 120 + 121 + db::create_token(&state.db, &token_data).await?; 122 + 123 + tokio::spawn({ 124 + let pool = state.db.clone(); 125 + let did_clone = did.clone(); 126 + async move { 127 + if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await { 128 + tracing::warn!("Failed to enforce token limit for user: {:?}", e); 129 + } 130 + } 131 + }); 132 + 133 + let mut response_headers = HeaderMap::new(); 134 + let config = AuthConfig::get(); 135 + let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 136 + response_headers.insert( 137 + "DPoP-Nonce", 138 + verifier.generate_nonce().parse().unwrap(), 139 + ); 140 + 141 + Ok(( 142 + response_headers, 143 + Json(TokenResponse { 144 + access_token, 145 + token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 146 + expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 147 + refresh_token: Some(refresh_token.0), 148 + scope: auth_request.parameters.scope, 149 + sub: Some(did), 150 + }), 151 + )) 152 + } 153 + 154 + pub async fn handle_refresh_token_grant( 155 + state: AppState, 156 + _headers: HeaderMap, 157 + request: TokenRequest, 158 + dpop_proof: Option<String>, 159 + ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 160 + let refresh_token_str = request 161 + .refresh_token 162 + .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?; 163 + 164 + if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? { 165 + db::delete_token_family(&state.db, token_id).await?; 166 + return Err(OAuthError::InvalidGrant( 167 + "Refresh token reuse detected, token family revoked".to_string(), 168 + )); 169 + } 170 + 171 + let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str) 172 + .await? 173 + .ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?; 174 + 175 + if token_data.expires_at < Utc::now() { 176 + db::delete_token_family(&state.db, db_id).await?; 177 + return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string())); 178 + } 179 + 180 + let dpop_jkt = if let Some(proof) = &dpop_proof { 181 + let config = AuthConfig::get(); 182 + let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 183 + 184 + let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 185 + let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 186 + 187 + let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 188 + 189 + if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 190 + return Err(OAuthError::InvalidDpopProof( 191 + "DPoP proof has already been used".to_string(), 192 + )); 193 + } 194 + 195 + if let Some(expected_jkt) = &token_data.parameters.dpop_jkt { 196 + if &result.jkt != expected_jkt { 197 + return Err(OAuthError::InvalidDpopProof( 198 + "DPoP key binding mismatch".to_string(), 199 + )); 200 + } 201 + } 202 + 203 + Some(result.jkt) 204 + } else if token_data.parameters.dpop_jkt.is_some() { 205 + return Err(OAuthError::InvalidRequest( 206 + "DPoP proof required".to_string(), 207 + )); 208 + } else { 209 + None 210 + }; 211 + 212 + let new_token_id = TokenId::generate(); 213 + let new_refresh_token = RefreshToken::generate(); 214 + let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS); 215 + 216 + db::rotate_token( 217 + &state.db, 218 + db_id, 219 + &new_token_id.0, 220 + &new_refresh_token.0, 221 + new_expires_at, 222 + ) 223 + .await?; 224 + 225 + let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?; 226 + 227 + let mut response_headers = HeaderMap::new(); 228 + let config = AuthConfig::get(); 229 + let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 230 + response_headers.insert( 231 + "DPoP-Nonce", 232 + verifier.generate_nonce().parse().unwrap(), 233 + ); 234 + 235 + Ok(( 236 + response_headers, 237 + Json(TokenResponse { 238 + access_token, 239 + token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 240 + expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 241 + refresh_token: Some(new_refresh_token.0), 242 + scope: token_data.scope, 243 + sub: Some(token_data.did), 244 + }), 245 + )) 246 + }
+143
src/oauth/endpoints/token/helpers.rs
··· 1 + use base64::Engine; 2 + use base64::engine::general_purpose::URL_SAFE_NO_PAD; 3 + use chrono::Utc; 4 + use hmac::Mac; 5 + use sha2::{Digest, Sha256}; 6 + use subtle::ConstantTimeEq; 7 + 8 + use crate::config::AuthConfig; 9 + use crate::oauth::OAuthError; 10 + 11 + const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 12 + 13 + pub struct TokenClaims { 14 + pub jti: String, 15 + pub exp: i64, 16 + pub iat: i64, 17 + } 18 + 19 + pub fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> { 20 + let mut hasher = Sha256::new(); 21 + hasher.update(code_verifier.as_bytes()); 22 + let hash = hasher.finalize(); 23 + let computed_challenge = URL_SAFE_NO_PAD.encode(&hash); 24 + 25 + if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) { 26 + return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string())); 27 + } 28 + 29 + Ok(()) 30 + } 31 + 32 + pub fn create_access_token( 33 + token_id: &str, 34 + sub: &str, 35 + dpop_jkt: Option<&str>, 36 + ) -> Result<String, OAuthError> { 37 + use serde_json::json; 38 + 39 + let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 40 + let issuer = format!("https://{}", pds_hostname); 41 + 42 + let now = Utc::now().timestamp(); 43 + let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS; 44 + 45 + let mut payload = json!({ 46 + "iss": issuer, 47 + "sub": sub, 48 + "aud": issuer, 49 + "iat": now, 50 + "exp": exp, 51 + "jti": token_id, 52 + "scope": "atproto" 53 + }); 54 + 55 + if let Some(jkt) = dpop_jkt { 56 + payload["cnf"] = json!({ "jkt": jkt }); 57 + } 58 + 59 + let header = json!({ 60 + "alg": "HS256", 61 + "typ": "at+jwt" 62 + }); 63 + 64 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 65 + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 66 + 67 + let signing_input = format!("{}.{}", header_b64, payload_b64); 68 + 69 + let config = AuthConfig::get(); 70 + 71 + type HmacSha256 = hmac::Hmac<Sha256>; 72 + 73 + let mut mac = HmacSha256::new_from_slice(config.jwt_secret().as_bytes()) 74 + .map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?; 75 + mac.update(signing_input.as_bytes()); 76 + let signature = mac.finalize().into_bytes(); 77 + 78 + let signature_b64 = URL_SAFE_NO_PAD.encode(&signature); 79 + 80 + Ok(format!("{}.{}", signing_input, signature_b64)) 81 + } 82 + 83 + pub fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> { 84 + let parts: Vec<&str> = token.split('.').collect(); 85 + if parts.len() != 3 { 86 + return Err(OAuthError::InvalidToken("Invalid token format".to_string())); 87 + } 88 + 89 + let header_bytes = URL_SAFE_NO_PAD 90 + .decode(parts[0]) 91 + .map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?; 92 + let header: serde_json::Value = serde_json::from_slice(&header_bytes) 93 + .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?; 94 + 95 + if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") { 96 + return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string())); 97 + } 98 + if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") { 99 + return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string())); 100 + } 101 + 102 + let config = AuthConfig::get(); 103 + let secret = config.jwt_secret(); 104 + 105 + let signing_input = format!("{}.{}", parts[0], parts[1]); 106 + let provided_sig = URL_SAFE_NO_PAD 107 + .decode(parts[2]) 108 + .map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?; 109 + 110 + type HmacSha256 = hmac::Hmac<Sha256>; 111 + let mut mac = HmacSha256::new_from_slice(secret.as_bytes()) 112 + .map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?; 113 + mac.update(signing_input.as_bytes()); 114 + let expected_sig = mac.finalize().into_bytes(); 115 + 116 + if !bool::from(expected_sig.ct_eq(&provided_sig)) { 117 + return Err(OAuthError::InvalidToken("Invalid token signature".to_string())); 118 + } 119 + 120 + let payload_bytes = URL_SAFE_NO_PAD 121 + .decode(parts[1]) 122 + .map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?; 123 + let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) 124 + .map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?; 125 + 126 + let jti = payload 127 + .get("jti") 128 + .and_then(|j| j.as_str()) 129 + .ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))? 130 + .to_string(); 131 + 132 + let exp = payload 133 + .get("exp") 134 + .and_then(|e| e.as_i64()) 135 + .ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?; 136 + 137 + let iat = payload 138 + .get("iat") 139 + .and_then(|i| i.as_i64()) 140 + .ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?; 141 + 142 + Ok(TokenClaims { jti, exp, iat }) 143 + }
+122
src/oauth/endpoints/token/introspect.rs
··· 1 + use axum::{Form, Json}; 2 + use axum::extract::State; 3 + use axum::http::StatusCode; 4 + use chrono::Utc; 5 + use serde::{Deserialize, Serialize}; 6 + 7 + use crate::state::AppState; 8 + use crate::oauth::{OAuthError, db}; 9 + 10 + use super::helpers::extract_token_claims; 11 + 12 + #[derive(Debug, Deserialize)] 13 + pub struct RevokeRequest { 14 + pub token: Option<String>, 15 + #[serde(default)] 16 + pub token_type_hint: Option<String>, 17 + } 18 + 19 + pub async fn revoke_token( 20 + State(state): State<AppState>, 21 + Form(request): Form<RevokeRequest>, 22 + ) -> Result<StatusCode, OAuthError> { 23 + if let Some(token) = &request.token { 24 + if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? { 25 + db::delete_token_family(&state.db, db_id).await?; 26 + } else { 27 + db::delete_token(&state.db, token).await?; 28 + } 29 + } 30 + 31 + Ok(StatusCode::OK) 32 + } 33 + 34 + #[derive(Debug, Deserialize)] 35 + pub struct IntrospectRequest { 36 + pub token: String, 37 + #[serde(default)] 38 + pub token_type_hint: Option<String>, 39 + } 40 + 41 + #[derive(Debug, Serialize)] 42 + pub struct IntrospectResponse { 43 + pub active: bool, 44 + #[serde(skip_serializing_if = "Option::is_none")] 45 + pub scope: Option<String>, 46 + #[serde(skip_serializing_if = "Option::is_none")] 47 + pub client_id: Option<String>, 48 + #[serde(skip_serializing_if = "Option::is_none")] 49 + pub username: Option<String>, 50 + #[serde(skip_serializing_if = "Option::is_none")] 51 + pub token_type: Option<String>, 52 + #[serde(skip_serializing_if = "Option::is_none")] 53 + pub exp: Option<i64>, 54 + #[serde(skip_serializing_if = "Option::is_none")] 55 + pub iat: Option<i64>, 56 + #[serde(skip_serializing_if = "Option::is_none")] 57 + pub nbf: Option<i64>, 58 + #[serde(skip_serializing_if = "Option::is_none")] 59 + pub sub: Option<String>, 60 + #[serde(skip_serializing_if = "Option::is_none")] 61 + pub aud: Option<String>, 62 + #[serde(skip_serializing_if = "Option::is_none")] 63 + pub iss: Option<String>, 64 + #[serde(skip_serializing_if = "Option::is_none")] 65 + pub jti: Option<String>, 66 + } 67 + 68 + pub async fn introspect_token( 69 + State(state): State<AppState>, 70 + Form(request): Form<IntrospectRequest>, 71 + ) -> Json<IntrospectResponse> { 72 + let inactive_response = IntrospectResponse { 73 + active: false, 74 + scope: None, 75 + client_id: None, 76 + username: None, 77 + token_type: None, 78 + exp: None, 79 + iat: None, 80 + nbf: None, 81 + sub: None, 82 + aud: None, 83 + iss: None, 84 + jti: None, 85 + }; 86 + 87 + let token_info = match extract_token_claims(&request.token) { 88 + Ok(info) => info, 89 + Err(_) => return Json(inactive_response), 90 + }; 91 + 92 + let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await { 93 + Ok(Some(data)) => data, 94 + _ => return Json(inactive_response), 95 + }; 96 + 97 + if token_data.expires_at < Utc::now() { 98 + return Json(inactive_response); 99 + } 100 + 101 + let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 102 + let issuer = format!("https://{}", pds_hostname); 103 + 104 + Json(IntrospectResponse { 105 + active: true, 106 + scope: token_data.scope, 107 + client_id: Some(token_data.client_id), 108 + username: None, 109 + token_type: if token_data.parameters.dpop_jkt.is_some() { 110 + Some("DPoP".to_string()) 111 + } else { 112 + Some("Bearer".to_string()) 113 + }, 114 + exp: Some(token_info.exp), 115 + iat: Some(token_info.iat), 116 + nbf: Some(token_info.iat), 117 + sub: Some(token_data.did), 118 + aud: Some(issuer.clone()), 119 + iss: Some(issuer), 120 + jti: Some(token_info.jti), 121 + }) 122 + }
+44
src/oauth/endpoints/token/mod.rs
··· 1 + mod grants; 2 + mod helpers; 3 + mod introspect; 4 + mod types; 5 + 6 + use axum::{ 7 + Form, Json, 8 + extract::State, 9 + http::HeaderMap, 10 + }; 11 + 12 + use crate::state::AppState; 13 + use crate::oauth::OAuthError; 14 + 15 + pub use grants::{handle_authorization_code_grant, handle_refresh_token_grant}; 16 + pub use helpers::{create_access_token, extract_token_claims, verify_pkce, TokenClaims}; 17 + pub use introspect::{ 18 + introspect_token, revoke_token, IntrospectRequest, IntrospectResponse, RevokeRequest, 19 + }; 20 + pub use types::{TokenRequest, TokenResponse}; 21 + 22 + pub async fn token_endpoint( 23 + State(state): State<AppState>, 24 + headers: HeaderMap, 25 + Form(request): Form<TokenRequest>, 26 + ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 27 + let dpop_proof = headers 28 + .get("DPoP") 29 + .and_then(|v| v.to_str().ok()) 30 + .map(|s| s.to_string()); 31 + 32 + match request.grant_type.as_str() { 33 + "authorization_code" => { 34 + handle_authorization_code_grant(state, headers, request, dpop_proof).await 35 + } 36 + "refresh_token" => { 37 + handle_refresh_token_grant(state, headers, request, dpop_proof).await 38 + } 39 + _ => Err(OAuthError::UnsupportedGrantType(format!( 40 + "Unsupported grant_type: {}", 41 + request.grant_type 42 + ))), 43 + } 44 + }
+35
src/oauth/endpoints/token/types.rs
··· 1 + use serde::{Deserialize, Serialize}; 2 + 3 + #[derive(Debug, Deserialize)] 4 + pub struct TokenRequest { 5 + pub grant_type: String, 6 + #[serde(default)] 7 + pub code: Option<String>, 8 + #[serde(default)] 9 + pub redirect_uri: Option<String>, 10 + #[serde(default)] 11 + pub code_verifier: Option<String>, 12 + #[serde(default)] 13 + pub refresh_token: Option<String>, 14 + #[serde(default)] 15 + pub client_id: Option<String>, 16 + #[serde(default)] 17 + pub client_secret: Option<String>, 18 + #[serde(default)] 19 + pub client_assertion: Option<String>, 20 + #[serde(default)] 21 + pub client_assertion_type: Option<String>, 22 + } 23 + 24 + #[derive(Debug, Serialize)] 25 + pub struct TokenResponse { 26 + pub access_token: String, 27 + pub token_type: String, 28 + pub expires_in: u64, 29 + #[serde(skip_serializing_if = "Option::is_none")] 30 + pub refresh_token: Option<String>, 31 + #[serde(skip_serializing_if = "Option::is_none")] 32 + pub scope: Option<String>, 33 + #[serde(skip_serializing_if = "Option::is_none")] 34 + pub sub: Option<String>, 35 + }
+2 -1
src/repo/mod.rs
··· 38 38 let mut hasher = Sha256::new(); 39 39 hasher.update(data); 40 40 let hash = hasher.finalize(); 41 - let multihash = Multihash::wrap(0x12, &hash).unwrap(); 41 + let multihash = Multihash::wrap(0x12, &hash) 42 + .map_err(|e| RepoError::storage(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to wrap multihash: {:?}", e))))?; 42 43 let cid = Cid::new_v1(0x71, multihash); 43 44 let cid_bytes = cid.to_bytes(); 44 45
+12 -3
src/repo/tracking.rs
··· 21 21 } 22 22 23 23 pub fn get_written_cids(&self) -> Vec<Cid> { 24 - self.written_cids.lock().unwrap().clone() 24 + match self.written_cids.lock() { 25 + Ok(guard) => guard.clone(), 26 + Err(poisoned) => poisoned.into_inner().clone(), 27 + } 25 28 } 26 29 } 27 30 ··· 32 35 33 36 async fn put(&self, data: &[u8]) -> Result<Cid, RepoError> { 34 37 let cid = self.inner.put(data).await?; 35 - self.written_cids.lock().unwrap().push(cid.clone()); 38 + match self.written_cids.lock() { 39 + Ok(mut guard) => guard.push(cid.clone()), 40 + Err(poisoned) => poisoned.into_inner().push(cid.clone()), 41 + } 36 42 Ok(cid) 37 43 } 38 44 ··· 47 53 let blocks: Vec<_> = blocks.into_iter().collect(); 48 54 let cids: Vec<Cid> = blocks.iter().map(|(cid, _)| cid.clone()).collect(); 49 55 self.inner.put_many(blocks).await?; 50 - self.written_cids.lock().unwrap().extend(cids); 56 + match self.written_cids.lock() { 57 + Ok(mut guard) => guard.extend(cids), 58 + Err(poisoned) => poisoned.into_inner().extend(cids), 59 + } 51 60 Ok(()) 52 61 } 53 62
+1 -1
src/sync/blob.rs
··· 132 132 .into_response(); 133 133 } 134 134 135 - let limit = params.limit.unwrap_or(500).min(1000); 135 + let limit = params.limit.unwrap_or(500).clamp(1, 1000); 136 136 let cursor_cid = params.cursor.as_deref().unwrap_or(""); 137 137 138 138 let user_result = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
+5 -4
src/sync/car.rs
··· 23 23 Ok(()) 24 24 } 25 25 26 - pub fn encode_car_header(root_cid: &Cid) -> Vec<u8> { 26 + pub fn encode_car_header(root_cid: &Cid) -> Result<Vec<u8>, String> { 27 27 let header = CarHeader::new_v1(vec![root_cid.clone()]); 28 - let header_cbor = header.encode().unwrap_or_default(); 28 + let header_cbor = header.encode().map_err(|e| format!("Failed to encode CAR header: {:?}", e))?; 29 29 30 30 let mut result = Vec::new(); 31 - write_varint(&mut result, header_cbor.len() as u64).unwrap(); 31 + write_varint(&mut result, header_cbor.len() as u64) 32 + .expect("Writing to Vec<u8> should never fail"); 32 33 result.extend_from_slice(&header_cbor); 33 - result 34 + Ok(result) 34 35 }
+1 -1
src/sync/commit.rs
··· 98 98 State(state): State<AppState>, 99 99 Query(params): Query<ListReposParams>, 100 100 ) -> Response { 101 - let limit = params.limit.unwrap_or(50).min(1000); 101 + let limit = params.limit.unwrap_or(50).clamp(1, 1000); 102 102 let cursor_did = params.cursor.as_deref().unwrap_or(""); 103 103 104 104 let result = sqlx::query!(
+9 -5
src/sync/frame.rs
··· 38 38 pub cid: Option<String>, 39 39 } 40 40 41 - impl From<SequencedEvent> for CommitFrame { 42 - fn from(event: SequencedEvent) -> Self { 41 + impl TryFrom<SequencedEvent> for CommitFrame { 42 + type Error = &'static str; 43 + 44 + fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> { 43 45 let ops = serde_json::from_value::<Vec<RepoOp>>(event.ops.unwrap_or_default()) 44 46 .unwrap_or_else(|_| vec![]); 45 47 46 - CommitFrame { 48 + let commit_cid = event.commit_cid.ok_or("Missing commit_cid in event")?; 49 + 50 + Ok(CommitFrame { 47 51 seq: event.seq, 48 52 rebase: false, 49 53 too_big: false, 50 54 repo: event.did, 51 - commit: event.commit_cid.unwrap_or_default(), 55 + commit: commit_cid, 52 56 prev: event.prev_cid, 53 57 blocks: Vec::new(), 54 58 ops, 55 59 blobs: event.blobs.unwrap_or_default(), 56 60 time: event.created_at.to_rfc3339(), 57 - } 61 + }) 58 62 } 59 63 }
+1 -2
src/sync/relay_client.rs
··· 12 12 match connect_async(&url).await { 13 13 Ok((mut ws_stream, _)) => { 14 14 info!("Connected to firehose relay: {}", url); 15 + let mut rx = state.firehose_tx.subscribe(); 15 16 if let Some(tx) = ready_tx.as_ref() { 16 17 tx.send(()).await.ok(); 17 18 } 18 - 19 - let mut rx = state.firehose_tx.subscribe(); 20 19 21 20 loop { 22 21 tokio::select! {
+48 -17
src/sync/repo.rs
··· 15 15 use std::str::FromStr; 16 16 use tracing::error; 17 17 18 + const MAX_REPO_BLOCKS_TRAVERSAL: usize = 20_000; 19 + 18 20 #[derive(Deserialize)] 19 21 pub struct GetBlocksQuery { 20 22 pub did: String, ··· 52 54 } 53 55 }; 54 56 55 - let root_cid = cids.first().cloned().unwrap_or_default(); 56 - 57 57 if cids.is_empty() { 58 58 return (StatusCode::BAD_REQUEST, "No CIDs provided").into_response(); 59 59 } 60 60 61 - let header = encode_car_header(&root_cid); 61 + let root_cid = cids[0]; 62 + 63 + let header = match encode_car_header(&root_cid) { 64 + Ok(h) => h, 65 + Err(e) => { 66 + error!("Failed to encode CAR header: {}", e); 67 + return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to encode CAR").into_response(); 68 + } 69 + }; 62 70 63 71 let mut car_bytes = header; 64 72 ··· 69 77 let total_len = cid_bytes.len() + block.len(); 70 78 71 79 let mut writer = Vec::new(); 72 - crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap(); 73 - writer.write_all(&cid_bytes).unwrap(); 74 - writer.write_all(&block).unwrap(); 80 + crate::sync::car::write_varint(&mut writer, total_len as u64) 81 + .expect("Writing to Vec<u8> should never fail"); 82 + writer.write_all(&cid_bytes) 83 + .expect("Writing to Vec<u8> should never fail"); 84 + writer.write_all(&block) 85 + .expect("Writing to Vec<u8> should never fail"); 75 86 76 87 car_bytes.extend_from_slice(&writer); 77 88 } ··· 143 154 } 144 155 }; 145 156 146 - let mut car_bytes = encode_car_header(&head_cid); 157 + let mut car_bytes = match encode_car_header(&head_cid) { 158 + Ok(h) => h, 159 + Err(e) => { 160 + return ( 161 + StatusCode::INTERNAL_SERVER_ERROR, 162 + Json(json!({"error": "InternalError", "message": format!("Failed to encode CAR header: {}", e)})), 163 + ) 164 + .into_response(); 165 + } 166 + }; 147 167 148 168 let mut stack = vec![head_cid]; 149 169 let mut visited = std::collections::HashSet::new(); 150 - let mut limit = 20000; 170 + let mut remaining = MAX_REPO_BLOCKS_TRAVERSAL; 151 171 152 172 while let Some(cid) = stack.pop() { 153 173 if visited.contains(&cid) { 154 174 continue; 155 175 } 156 176 visited.insert(cid); 157 - if limit == 0 { break; } 158 - limit -= 1; 177 + if remaining == 0 { break; } 178 + remaining -= 1; 159 179 160 180 if let Ok(Some(block)) = state.block_store.get(&cid).await { 161 181 let cid_bytes = cid.to_bytes(); 162 182 let total_len = cid_bytes.len() + block.len(); 163 183 let mut writer = Vec::new(); 164 - crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap(); 165 - writer.write_all(&cid_bytes).unwrap(); 166 - writer.write_all(&block).unwrap(); 184 + crate::sync::car::write_varint(&mut writer, total_len as u64) 185 + .expect("Writing to Vec<u8> should never fail"); 186 + writer.write_all(&cid_bytes) 187 + .expect("Writing to Vec<u8> should never fail"); 188 + writer.write_all(&block) 189 + .expect("Writing to Vec<u8> should never fail"); 167 190 car_bytes.extend_from_slice(&writer); 168 191 169 192 if let Ok(value) = serde_ipld_dagcbor::from_slice::<Ipld>(&block) { ··· 258 281 _ => return (StatusCode::NOT_FOUND, "Block not found").into_response(), 259 282 }; 260 283 261 - let header = encode_car_header(&cid); 284 + let header = match encode_car_header(&cid) { 285 + Ok(h) => h, 286 + Err(e) => { 287 + return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to encode CAR header: {}", e)).into_response(); 288 + } 289 + }; 262 290 let mut car_bytes = header; 263 291 264 292 let cid_bytes = cid.to_bytes(); 265 293 let total_len = cid_bytes.len() + block.len(); 266 294 let mut writer = Vec::new(); 267 - crate::sync::car::write_varint(&mut writer, total_len as u64).unwrap(); 268 - writer.write_all(&cid_bytes).unwrap(); 269 - writer.write_all(&block).unwrap(); 295 + crate::sync::car::write_varint(&mut writer, total_len as u64) 296 + .expect("Writing to Vec<u8> should never fail"); 297 + writer.write_all(&cid_bytes) 298 + .expect("Writing to Vec<u8> should never fail"); 299 + writer.write_all(&block) 300 + .expect("Writing to Vec<u8> should never fail"); 270 301 car_bytes.extend_from_slice(&writer); 271 302 272 303 (
+37 -23
src/sync/subscribe_repos.rs
··· 9 9 use serde::Deserialize; 10 10 use tracing::{error, info, warn}; 11 11 12 + const BACKFILL_BATCH_SIZE: i64 = 1000; 13 + 12 14 #[derive(Deserialize)] 13 15 pub struct SubscribeReposParams { 14 16 pub cursor: Option<i64>, ··· 37 39 info!(cursor = ?params.cursor, "New firehose subscriber"); 38 40 39 41 if let Some(cursor) = params.cursor { 40 - let events = sqlx::query_as!( 41 - SequencedEvent, 42 - r#" 43 - SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids 44 - FROM repo_seq 45 - WHERE seq > $1 46 - ORDER BY seq ASC 47 - "#, 48 - cursor 49 - ) 50 - .fetch_all(&state.db) 51 - .await; 42 + let mut current_cursor = cursor; 43 + loop { 44 + let events = sqlx::query_as!( 45 + SequencedEvent, 46 + r#" 47 + SELECT seq, did, created_at, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids 48 + FROM repo_seq 49 + WHERE seq > $1 50 + ORDER BY seq ASC 51 + LIMIT $2 52 + "#, 53 + current_cursor, 54 + BACKFILL_BATCH_SIZE 55 + ) 56 + .fetch_all(&state.db) 57 + .await; 52 58 53 - match events { 54 - Ok(events) => { 55 - for event in events { 56 - if let Err(e) = send_event(&mut socket, &state, event).await { 57 - warn!("Failed to send backfill event: {}", e); 58 - return; 59 + match events { 60 + Ok(events) => { 61 + if events.is_empty() { 62 + break; 63 + } 64 + for event in &events { 65 + current_cursor = event.seq; 66 + if let Err(e) = send_event(&mut socket, &state, event.clone()).await { 67 + warn!("Failed to send backfill event: {}", e); 68 + return; 69 + } 70 + } 71 + if (events.len() as i64) < BACKFILL_BATCH_SIZE { 72 + break; 59 73 } 60 74 } 61 - } 62 - Err(e) => { 63 - error!("Failed to fetch backfill events: {}", e); 64 - socket.close().await.ok(); 65 - return; 75 + Err(e) => { 76 + error!("Failed to fetch backfill events: {}", e); 77 + socket.close().await.ok(); 78 + return; 79 + } 66 80 } 67 81 } 68 82 }
+8 -15
src/sync/util.rs
··· 2 2 use crate::sync::firehose::SequencedEvent; 3 3 use crate::sync::frame::{CommitFrame, Frame, FrameData}; 4 4 use cid::Cid; 5 - use jacquard_repo::car::write_car; 5 + use jacquard_repo::car::write_car_bytes; 6 6 use jacquard_repo::storage::BlockStore; 7 - use std::fs; 8 7 use std::str::FromStr; 9 - use tokio::fs::File; 10 - use tokio::io::AsyncReadExt; 11 - use uuid::Uuid; 12 8 13 9 pub async fn format_event_for_sending( 14 10 state: &AppState, 15 11 event: SequencedEvent, 16 12 ) -> Result<Vec<u8>, anyhow::Error> { 17 13 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 18 - let mut frame: CommitFrame = event.into(); 14 + let mut frame: CommitFrame = event.try_into() 15 + .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 19 16 20 - let mut car_bytes = Vec::new(); 21 - if !block_cids_str.is_empty() { 22 - let temp_path = format!("/tmp/{}.car", Uuid::new_v4()); 17 + let car_bytes = if !block_cids_str.is_empty() { 23 18 let mut blocks = std::collections::BTreeMap::new(); 24 19 25 20 for cid_str in block_cids_str { ··· 33 28 } 34 29 35 30 let root = Cid::from_str(&frame.commit)?; 36 - write_car(&temp_path, vec![root], blocks).await?; 37 - 38 - let mut file = File::open(&temp_path).await?; 39 - file.read_to_end(&mut car_bytes).await?; 40 - fs::remove_file(&temp_path)?; 41 - } 31 + write_car_bytes(root, blocks).await? 32 + } else { 33 + Vec::new() 34 + }; 42 35 frame.blocks = car_bytes; 43 36 44 37 let frame = Frame {
+2 -342
src/sync/verify.rs
··· 302 302 } 303 303 304 304 #[cfg(test)] 305 - mod tests { 306 - use super::*; 307 - use sha2::{Digest, Sha256}; 308 - 309 - fn make_cid(data: &[u8]) -> Cid { 310 - let mut hasher = Sha256::new(); 311 - hasher.update(data); 312 - let hash = hasher.finalize(); 313 - let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); 314 - Cid::new_v1(0x71, multihash) 315 - } 316 - 317 - #[test] 318 - fn test_verifier_creation() { 319 - let _verifier = CarVerifier::new(); 320 - } 321 - 322 - #[test] 323 - fn test_verify_error_display() { 324 - let err = VerifyError::DidMismatch { 325 - commit_did: "did:plc:abc".to_string(), 326 - expected_did: "did:plc:xyz".to_string(), 327 - }; 328 - assert!(err.to_string().contains("did:plc:abc")); 329 - assert!(err.to_string().contains("did:plc:xyz")); 330 - 331 - let err = VerifyError::InvalidSignature; 332 - assert!(err.to_string().contains("signature")); 333 - 334 - let err = VerifyError::NoSigningKey; 335 - assert!(err.to_string().contains("signing key")); 336 - 337 - let err = VerifyError::MstValidationFailed("test error".to_string()); 338 - assert!(err.to_string().contains("test error")); 339 - } 340 - 341 - #[test] 342 - fn test_mst_validation_missing_root_block() { 343 - let verifier = CarVerifier::new(); 344 - let blocks: HashMap<Cid, Bytes> = HashMap::new(); 345 - 346 - let fake_cid = make_cid(b"fake data"); 347 - let result = verifier.verify_mst_structure(&fake_cid, &blocks); 348 - 349 - assert!(result.is_err()); 350 - let err = result.unwrap_err(); 351 - assert!(matches!(err, VerifyError::BlockNotFound(_))); 352 - } 353 - 354 - #[test] 355 - fn test_mst_validation_invalid_cbor() { 356 - let verifier = CarVerifier::new(); 357 - 358 - let bad_cbor = Bytes::from(vec![0xFF, 0xFF, 0xFF]); 359 - let cid = make_cid(&bad_cbor); 360 - 361 - let mut blocks = HashMap::new(); 362 - blocks.insert(cid, bad_cbor); 363 - 364 - let result = verifier.verify_mst_structure(&cid, &blocks); 365 - 366 - assert!(result.is_err()); 367 - let err = result.unwrap_err(); 368 - assert!(matches!(err, VerifyError::InvalidCbor(_))); 369 - } 370 - 371 - #[test] 372 - fn test_mst_validation_empty_node() { 373 - let verifier = CarVerifier::new(); 374 - 375 - let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({ 376 - "e": [] 377 - })).unwrap(); 378 - let cid = make_cid(&empty_node); 379 - 380 - let mut blocks = HashMap::new(); 381 - blocks.insert(cid, Bytes::from(empty_node)); 382 - 383 - let result = verifier.verify_mst_structure(&cid, &blocks); 384 - assert!(result.is_ok()); 385 - } 386 - 387 - #[test] 388 - fn test_mst_validation_missing_left_pointer() { 389 - use ipld_core::ipld::Ipld; 390 - 391 - let verifier = CarVerifier::new(); 392 - 393 - let missing_left_cid = make_cid(b"missing left"); 394 - let node = Ipld::Map(std::collections::BTreeMap::from([ 395 - ("l".to_string(), Ipld::Link(missing_left_cid)), 396 - ("e".to_string(), Ipld::List(vec![])), 397 - ])); 398 - let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 399 - let cid = make_cid(&node_bytes); 400 - 401 - let mut blocks = HashMap::new(); 402 - blocks.insert(cid, Bytes::from(node_bytes)); 403 - 404 - let result = verifier.verify_mst_structure(&cid, &blocks); 405 - 406 - assert!(result.is_err()); 407 - let err = result.unwrap_err(); 408 - assert!(matches!(err, VerifyError::BlockNotFound(_))); 409 - assert!(err.to_string().contains("left pointer")); 410 - } 411 - 412 - #[test] 413 - fn test_mst_validation_missing_subtree() { 414 - use ipld_core::ipld::Ipld; 415 - 416 - let verifier = CarVerifier::new(); 417 - 418 - let missing_subtree_cid = make_cid(b"missing subtree"); 419 - let record_cid = make_cid(b"record"); 420 - 421 - let entry = Ipld::Map(std::collections::BTreeMap::from([ 422 - ("k".to_string(), Ipld::Bytes(b"key1".to_vec())), 423 - ("v".to_string(), Ipld::Link(record_cid)), 424 - ("p".to_string(), Ipld::Integer(0)), 425 - ("t".to_string(), Ipld::Link(missing_subtree_cid)), 426 - ])); 427 - 428 - let node = Ipld::Map(std::collections::BTreeMap::from([ 429 - ("e".to_string(), Ipld::List(vec![entry])), 430 - ])); 431 - let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 432 - let cid = make_cid(&node_bytes); 433 - 434 - let mut blocks = HashMap::new(); 435 - blocks.insert(cid, Bytes::from(node_bytes)); 436 - 437 - let result = verifier.verify_mst_structure(&cid, &blocks); 438 - 439 - assert!(result.is_err()); 440 - let err = result.unwrap_err(); 441 - assert!(matches!(err, VerifyError::BlockNotFound(_))); 442 - assert!(err.to_string().contains("subtree")); 443 - } 444 - 445 - #[test] 446 - fn test_mst_validation_unsorted_keys() { 447 - use ipld_core::ipld::Ipld; 448 - 449 - let verifier = CarVerifier::new(); 450 - 451 - let record_cid = make_cid(b"record"); 452 - 453 - let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 454 - ("k".to_string(), Ipld::Bytes(b"zzz".to_vec())), 455 - ("v".to_string(), Ipld::Link(record_cid)), 456 - ("p".to_string(), Ipld::Integer(0)), 457 - ])); 458 - 459 - let entry2 = Ipld::Map(std::collections::BTreeMap::from([ 460 - ("k".to_string(), Ipld::Bytes(b"aaa".to_vec())), 461 - ("v".to_string(), Ipld::Link(record_cid)), 462 - ("p".to_string(), Ipld::Integer(0)), 463 - ])); 464 - 465 - let node = Ipld::Map(std::collections::BTreeMap::from([ 466 - ("e".to_string(), Ipld::List(vec![entry1, entry2])), 467 - ])); 468 - let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 469 - let cid = make_cid(&node_bytes); 470 - 471 - let mut blocks = HashMap::new(); 472 - blocks.insert(cid, Bytes::from(node_bytes)); 473 - 474 - let result = verifier.verify_mst_structure(&cid, &blocks); 475 - 476 - assert!(result.is_err()); 477 - let err = result.unwrap_err(); 478 - assert!(matches!(err, VerifyError::MstValidationFailed(_))); 479 - assert!(err.to_string().contains("sorted")); 480 - } 481 - 482 - #[test] 483 - fn test_mst_validation_sorted_keys_ok() { 484 - use ipld_core::ipld::Ipld; 485 - 486 - let verifier = CarVerifier::new(); 487 - 488 - let record_cid = make_cid(b"record"); 489 - 490 - let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 491 - ("k".to_string(), Ipld::Bytes(b"aaa".to_vec())), 492 - ("v".to_string(), Ipld::Link(record_cid)), 493 - ("p".to_string(), Ipld::Integer(0)), 494 - ])); 495 - 496 - let entry2 = Ipld::Map(std::collections::BTreeMap::from([ 497 - ("k".to_string(), Ipld::Bytes(b"bbb".to_vec())), 498 - ("v".to_string(), Ipld::Link(record_cid)), 499 - ("p".to_string(), Ipld::Integer(0)), 500 - ])); 501 - 502 - let entry3 = Ipld::Map(std::collections::BTreeMap::from([ 503 - ("k".to_string(), Ipld::Bytes(b"zzz".to_vec())), 504 - ("v".to_string(), Ipld::Link(record_cid)), 505 - ("p".to_string(), Ipld::Integer(0)), 506 - ])); 507 - 508 - let node = Ipld::Map(std::collections::BTreeMap::from([ 509 - ("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])), 510 - ])); 511 - let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 512 - let cid = make_cid(&node_bytes); 513 - 514 - let mut blocks = HashMap::new(); 515 - blocks.insert(cid, Bytes::from(node_bytes)); 516 - 517 - let result = verifier.verify_mst_structure(&cid, &blocks); 518 - assert!(result.is_ok()); 519 - } 520 - 521 - #[test] 522 - fn test_mst_validation_with_valid_left_pointer() { 523 - use ipld_core::ipld::Ipld; 524 - 525 - let verifier = CarVerifier::new(); 526 - 527 - let left_node = Ipld::Map(std::collections::BTreeMap::from([ 528 - ("e".to_string(), Ipld::List(vec![])), 529 - ])); 530 - let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap(); 531 - let left_cid = make_cid(&left_node_bytes); 532 - 533 - let root_node = Ipld::Map(std::collections::BTreeMap::from([ 534 - ("l".to_string(), Ipld::Link(left_cid)), 535 - ("e".to_string(), Ipld::List(vec![])), 536 - ])); 537 - let root_node_bytes = serde_ipld_dagcbor::to_vec(&root_node).unwrap(); 538 - let root_cid = make_cid(&root_node_bytes); 539 - 540 - let mut blocks = HashMap::new(); 541 - blocks.insert(root_cid, Bytes::from(root_node_bytes)); 542 - blocks.insert(left_cid, Bytes::from(left_node_bytes)); 543 - 544 - let result = verifier.verify_mst_structure(&root_cid, &blocks); 545 - assert!(result.is_ok()); 546 - } 547 - 548 - #[test] 549 - fn test_mst_validation_cycle_detection() { 550 - let verifier = CarVerifier::new(); 551 - 552 - let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({ 553 - "e": [] 554 - })).unwrap(); 555 - let cid = make_cid(&node); 556 - 557 - let mut blocks = HashMap::new(); 558 - blocks.insert(cid, Bytes::from(node)); 559 - 560 - let result = verifier.verify_mst_structure(&cid, &blocks); 561 - assert!(result.is_ok()); 562 - } 563 - 564 - #[tokio::test] 565 - async fn test_unsupported_did_method() { 566 - let verifier = CarVerifier::new(); 567 - let result = verifier.resolve_did_document("did:unknown:test").await; 568 - 569 - assert!(result.is_err()); 570 - let err = result.unwrap_err(); 571 - assert!(matches!(err, VerifyError::DidResolutionFailed(_))); 572 - assert!(err.to_string().contains("Unsupported")); 573 - } 574 - 575 - #[test] 576 - fn test_mst_validation_with_prefix_compression() { 577 - use ipld_core::ipld::Ipld; 578 - 579 - let verifier = CarVerifier::new(); 580 - let record_cid = make_cid(b"record"); 581 - 582 - let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 583 - ("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())), 584 - ("v".to_string(), Ipld::Link(record_cid)), 585 - ("p".to_string(), Ipld::Integer(0)), 586 - ])); 587 - 588 - let entry2 = Ipld::Map(std::collections::BTreeMap::from([ 589 - ("k".to_string(), Ipld::Bytes(b"def".to_vec())), 590 - ("v".to_string(), Ipld::Link(record_cid)), 591 - ("p".to_string(), Ipld::Integer(19)), 592 - ])); 593 - 594 - let entry3 = Ipld::Map(std::collections::BTreeMap::from([ 595 - ("k".to_string(), Ipld::Bytes(b"xyz".to_vec())), 596 - ("v".to_string(), Ipld::Link(record_cid)), 597 - ("p".to_string(), Ipld::Integer(19)), 598 - ])); 599 - 600 - let node = Ipld::Map(std::collections::BTreeMap::from([ 601 - ("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])), 602 - ])); 603 - let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 604 - let cid = make_cid(&node_bytes); 605 - 606 - let mut blocks = HashMap::new(); 607 - blocks.insert(cid, Bytes::from(node_bytes)); 608 - 609 - let result = verifier.verify_mst_structure(&cid, &blocks); 610 - assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly"); 611 - } 612 - 613 - #[test] 614 - fn test_mst_validation_prefix_compression_unsorted() { 615 - use ipld_core::ipld::Ipld; 616 - 617 - let verifier = CarVerifier::new(); 618 - let record_cid = make_cid(b"record"); 619 - 620 - let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 621 - ("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())), 622 - ("v".to_string(), Ipld::Link(record_cid)), 623 - ("p".to_string(), Ipld::Integer(0)), 624 - ])); 625 - 626 - let entry2 = Ipld::Map(std::collections::BTreeMap::from([ 627 - ("k".to_string(), Ipld::Bytes(b"abc".to_vec())), 628 - ("v".to_string(), Ipld::Link(record_cid)), 629 - ("p".to_string(), Ipld::Integer(19)), 630 - ])); 631 - 632 - let node = Ipld::Map(std::collections::BTreeMap::from([ 633 - ("e".to_string(), Ipld::List(vec![entry1, entry2])), 634 - ])); 635 - let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 636 - let cid = make_cid(&node_bytes); 637 - 638 - let mut blocks = HashMap::new(); 639 - blocks.insert(cid, Bytes::from(node_bytes)); 640 - 641 - let result = verifier.verify_mst_structure(&cid, &blocks); 642 - assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation"); 643 - let err = result.unwrap_err(); 644 - assert!(matches!(err, VerifyError::MstValidationFailed(_))); 645 - } 646 - } 305 + #[path = "verify_tests.rs"] 306 + mod tests;
+346
src/sync/verify_tests.rs
··· 1 + #[cfg(test)] 2 + mod tests { 3 + use crate::sync::verify::{CarVerifier, VerifyError}; 4 + use bytes::Bytes; 5 + use cid::Cid; 6 + use sha2::{Digest, Sha256}; 7 + use std::collections::HashMap; 8 + 9 + fn make_cid(data: &[u8]) -> Cid { 10 + let mut hasher = Sha256::new(); 11 + hasher.update(data); 12 + let hash = hasher.finalize(); 13 + let multihash = multihash::Multihash::wrap(0x12, &hash).unwrap(); 14 + Cid::new_v1(0x71, multihash) 15 + } 16 + 17 + #[test] 18 + fn test_verifier_creation() { 19 + let _verifier = CarVerifier::new(); 20 + } 21 + 22 + #[test] 23 + fn test_verify_error_display() { 24 + let err = VerifyError::DidMismatch { 25 + commit_did: "did:plc:abc".to_string(), 26 + expected_did: "did:plc:xyz".to_string(), 27 + }; 28 + assert!(err.to_string().contains("did:plc:abc")); 29 + assert!(err.to_string().contains("did:plc:xyz")); 30 + 31 + let err = VerifyError::InvalidSignature; 32 + assert!(err.to_string().contains("signature")); 33 + 34 + let err = VerifyError::NoSigningKey; 35 + assert!(err.to_string().contains("signing key")); 36 + 37 + let err = VerifyError::MstValidationFailed("test error".to_string()); 38 + assert!(err.to_string().contains("test error")); 39 + } 40 + 41 + #[test] 42 + fn test_mst_validation_missing_root_block() { 43 + let verifier = CarVerifier::new(); 44 + let blocks: HashMap<Cid, Bytes> = HashMap::new(); 45 + 46 + let fake_cid = make_cid(b"fake data"); 47 + let result = verifier.verify_mst_structure(&fake_cid, &blocks); 48 + 49 + assert!(result.is_err()); 50 + let err = result.unwrap_err(); 51 + assert!(matches!(err, VerifyError::BlockNotFound(_))); 52 + } 53 + 54 + #[test] 55 + fn test_mst_validation_invalid_cbor() { 56 + let verifier = CarVerifier::new(); 57 + 58 + let bad_cbor = Bytes::from(vec![0xFF, 0xFF, 0xFF]); 59 + let cid = make_cid(&bad_cbor); 60 + 61 + let mut blocks = HashMap::new(); 62 + blocks.insert(cid, bad_cbor); 63 + 64 + let result = verifier.verify_mst_structure(&cid, &blocks); 65 + 66 + assert!(result.is_err()); 67 + let err = result.unwrap_err(); 68 + assert!(matches!(err, VerifyError::InvalidCbor(_))); 69 + } 70 + 71 + #[test] 72 + fn test_mst_validation_empty_node() { 73 + let verifier = CarVerifier::new(); 74 + 75 + let empty_node = serde_ipld_dagcbor::to_vec(&serde_json::json!({ 76 + "e": [] 77 + })).unwrap(); 78 + let cid = make_cid(&empty_node); 79 + 80 + let mut blocks = HashMap::new(); 81 + blocks.insert(cid, Bytes::from(empty_node)); 82 + 83 + let result = verifier.verify_mst_structure(&cid, &blocks); 84 + assert!(result.is_ok()); 85 + } 86 + 87 + #[test] 88 + fn test_mst_validation_missing_left_pointer() { 89 + use ipld_core::ipld::Ipld; 90 + 91 + let verifier = CarVerifier::new(); 92 + 93 + let missing_left_cid = make_cid(b"missing left"); 94 + let node = Ipld::Map(std::collections::BTreeMap::from([ 95 + ("l".to_string(), Ipld::Link(missing_left_cid)), 96 + ("e".to_string(), Ipld::List(vec![])), 97 + ])); 98 + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 99 + let cid = make_cid(&node_bytes); 100 + 101 + let mut blocks = HashMap::new(); 102 + blocks.insert(cid, Bytes::from(node_bytes)); 103 + 104 + let result = verifier.verify_mst_structure(&cid, &blocks); 105 + 106 + assert!(result.is_err()); 107 + let err = result.unwrap_err(); 108 + assert!(matches!(err, VerifyError::BlockNotFound(_))); 109 + assert!(err.to_string().contains("left pointer")); 110 + } 111 + 112 + #[test] 113 + fn test_mst_validation_missing_subtree() { 114 + use ipld_core::ipld::Ipld; 115 + 116 + let verifier = CarVerifier::new(); 117 + 118 + let missing_subtree_cid = make_cid(b"missing subtree"); 119 + let record_cid = make_cid(b"record"); 120 + 121 + let entry = Ipld::Map(std::collections::BTreeMap::from([ 122 + ("k".to_string(), Ipld::Bytes(b"key1".to_vec())), 123 + ("v".to_string(), Ipld::Link(record_cid)), 124 + ("p".to_string(), Ipld::Integer(0)), 125 + ("t".to_string(), Ipld::Link(missing_subtree_cid)), 126 + ])); 127 + 128 + let node = Ipld::Map(std::collections::BTreeMap::from([ 129 + ("e".to_string(), Ipld::List(vec![entry])), 130 + ])); 131 + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 132 + let cid = make_cid(&node_bytes); 133 + 134 + let mut blocks = HashMap::new(); 135 + blocks.insert(cid, Bytes::from(node_bytes)); 136 + 137 + let result = verifier.verify_mst_structure(&cid, &blocks); 138 + 139 + assert!(result.is_err()); 140 + let err = result.unwrap_err(); 141 + assert!(matches!(err, VerifyError::BlockNotFound(_))); 142 + assert!(err.to_string().contains("subtree")); 143 + } 144 + 145 + #[test] 146 + fn test_mst_validation_unsorted_keys() { 147 + use ipld_core::ipld::Ipld; 148 + 149 + let verifier = CarVerifier::new(); 150 + 151 + let record_cid = make_cid(b"record"); 152 + 153 + let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 154 + ("k".to_string(), Ipld::Bytes(b"zzz".to_vec())), 155 + ("v".to_string(), Ipld::Link(record_cid)), 156 + ("p".to_string(), Ipld::Integer(0)), 157 + ])); 158 + 159 + let entry2 = Ipld::Map(std::collections::BTreeMap::from([ 160 + ("k".to_string(), Ipld::Bytes(b"aaa".to_vec())), 161 + ("v".to_string(), Ipld::Link(record_cid)), 162 + ("p".to_string(), Ipld::Integer(0)), 163 + ])); 164 + 165 + let node = Ipld::Map(std::collections::BTreeMap::from([ 166 + ("e".to_string(), Ipld::List(vec![entry1, entry2])), 167 + ])); 168 + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 169 + let cid = make_cid(&node_bytes); 170 + 171 + let mut blocks = HashMap::new(); 172 + blocks.insert(cid, Bytes::from(node_bytes)); 173 + 174 + let result = verifier.verify_mst_structure(&cid, &blocks); 175 + 176 + assert!(result.is_err()); 177 + let err = result.unwrap_err(); 178 + assert!(matches!(err, VerifyError::MstValidationFailed(_))); 179 + assert!(err.to_string().contains("sorted")); 180 + } 181 + 182 + #[test] 183 + fn test_mst_validation_sorted_keys_ok() { 184 + use ipld_core::ipld::Ipld; 185 + 186 + let verifier = CarVerifier::new(); 187 + 188 + let record_cid = make_cid(b"record"); 189 + 190 + let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 191 + ("k".to_string(), Ipld::Bytes(b"aaa".to_vec())), 192 + ("v".to_string(), Ipld::Link(record_cid)), 193 + ("p".to_string(), Ipld::Integer(0)), 194 + ])); 195 + 196 + let entry2 = Ipld::Map(std::collections::BTreeMap::from([ 197 + ("k".to_string(), Ipld::Bytes(b"bbb".to_vec())), 198 + ("v".to_string(), Ipld::Link(record_cid)), 199 + ("p".to_string(), Ipld::Integer(0)), 200 + ])); 201 + 202 + let entry3 = Ipld::Map(std::collections::BTreeMap::from([ 203 + ("k".to_string(), Ipld::Bytes(b"zzz".to_vec())), 204 + ("v".to_string(), Ipld::Link(record_cid)), 205 + ("p".to_string(), Ipld::Integer(0)), 206 + ])); 207 + 208 + let node = Ipld::Map(std::collections::BTreeMap::from([ 209 + ("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])), 210 + ])); 211 + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 212 + let cid = make_cid(&node_bytes); 213 + 214 + let mut blocks = HashMap::new(); 215 + blocks.insert(cid, Bytes::from(node_bytes)); 216 + 217 + let result = verifier.verify_mst_structure(&cid, &blocks); 218 + assert!(result.is_ok()); 219 + } 220 + 221 + #[test] 222 + fn test_mst_validation_with_valid_left_pointer() { 223 + use ipld_core::ipld::Ipld; 224 + 225 + let verifier = CarVerifier::new(); 226 + 227 + let left_node = Ipld::Map(std::collections::BTreeMap::from([ 228 + ("e".to_string(), Ipld::List(vec![])), 229 + ])); 230 + let left_node_bytes = serde_ipld_dagcbor::to_vec(&left_node).unwrap(); 231 + let left_cid = make_cid(&left_node_bytes); 232 + 233 + let root_node = Ipld::Map(std::collections::BTreeMap::from([ 234 + ("l".to_string(), Ipld::Link(left_cid)), 235 + ("e".to_string(), Ipld::List(vec![])), 236 + ])); 237 + let root_node_bytes = serde_ipld_dagcbor::to_vec(&root_node).unwrap(); 238 + let root_cid = make_cid(&root_node_bytes); 239 + 240 + let mut blocks = HashMap::new(); 241 + blocks.insert(root_cid, Bytes::from(root_node_bytes)); 242 + blocks.insert(left_cid, Bytes::from(left_node_bytes)); 243 + 244 + let result = verifier.verify_mst_structure(&root_cid, &blocks); 245 + assert!(result.is_ok()); 246 + } 247 + 248 + #[test] 249 + fn test_mst_validation_cycle_detection() { 250 + let verifier = CarVerifier::new(); 251 + 252 + let node = serde_ipld_dagcbor::to_vec(&serde_json::json!({ 253 + "e": [] 254 + })).unwrap(); 255 + let cid = make_cid(&node); 256 + 257 + let mut blocks = HashMap::new(); 258 + blocks.insert(cid, Bytes::from(node)); 259 + 260 + let result = verifier.verify_mst_structure(&cid, &blocks); 261 + assert!(result.is_ok()); 262 + } 263 + 264 + #[tokio::test] 265 + async fn test_unsupported_did_method() { 266 + let verifier = CarVerifier::new(); 267 + let result = verifier.resolve_did_document("did:unknown:test").await; 268 + 269 + assert!(result.is_err()); 270 + let err = result.unwrap_err(); 271 + assert!(matches!(err, VerifyError::DidResolutionFailed(_))); 272 + assert!(err.to_string().contains("Unsupported")); 273 + } 274 + 275 + #[test] 276 + fn test_mst_validation_with_prefix_compression() { 277 + use ipld_core::ipld::Ipld; 278 + 279 + let verifier = CarVerifier::new(); 280 + let record_cid = make_cid(b"record"); 281 + 282 + let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 283 + ("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/abc".to_vec())), 284 + ("v".to_string(), Ipld::Link(record_cid)), 285 + ("p".to_string(), Ipld::Integer(0)), 286 + ])); 287 + 288 + let entry2 = Ipld::Map(std::collections::BTreeMap::from([ 289 + ("k".to_string(), Ipld::Bytes(b"def".to_vec())), 290 + ("v".to_string(), Ipld::Link(record_cid)), 291 + ("p".to_string(), Ipld::Integer(19)), 292 + ])); 293 + 294 + let entry3 = Ipld::Map(std::collections::BTreeMap::from([ 295 + ("k".to_string(), Ipld::Bytes(b"xyz".to_vec())), 296 + ("v".to_string(), Ipld::Link(record_cid)), 297 + ("p".to_string(), Ipld::Integer(19)), 298 + ])); 299 + 300 + let node = Ipld::Map(std::collections::BTreeMap::from([ 301 + ("e".to_string(), Ipld::List(vec![entry1, entry2, entry3])), 302 + ])); 303 + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 304 + let cid = make_cid(&node_bytes); 305 + 306 + let mut blocks = HashMap::new(); 307 + blocks.insert(cid, Bytes::from(node_bytes)); 308 + 309 + let result = verifier.verify_mst_structure(&cid, &blocks); 310 + assert!(result.is_ok(), "Prefix-compressed keys should be validated correctly"); 311 + } 312 + 313 + #[test] 314 + fn test_mst_validation_prefix_compression_unsorted() { 315 + use ipld_core::ipld::Ipld; 316 + 317 + let verifier = CarVerifier::new(); 318 + let record_cid = make_cid(b"record"); 319 + 320 + let entry1 = Ipld::Map(std::collections::BTreeMap::from([ 321 + ("k".to_string(), Ipld::Bytes(b"app.bsky.feed.post/xyz".to_vec())), 322 + ("v".to_string(), Ipld::Link(record_cid)), 323 + ("p".to_string(), Ipld::Integer(0)), 324 + ])); 325 + 326 + let entry2 = Ipld::Map(std::collections::BTreeMap::from([ 327 + ("k".to_string(), Ipld::Bytes(b"abc".to_vec())), 328 + ("v".to_string(), Ipld::Link(record_cid)), 329 + ("p".to_string(), Ipld::Integer(19)), 330 + ])); 331 + 332 + let node = Ipld::Map(std::collections::BTreeMap::from([ 333 + ("e".to_string(), Ipld::List(vec![entry1, entry2])), 334 + ])); 335 + let node_bytes = serde_ipld_dagcbor::to_vec(&node).unwrap(); 336 + let cid = make_cid(&node_bytes); 337 + 338 + let mut blocks = HashMap::new(); 339 + blocks.insert(cid, Bytes::from(node_bytes)); 340 + 341 + let result = verifier.verify_mst_structure(&cid, &blocks); 342 + assert!(result.is_err(), "Unsorted prefix-compressed keys should fail validation"); 343 + let err = result.unwrap_err(); 344 + assert!(matches!(err, VerifyError::MstValidationFailed(_))); 345 + } 346 + }
+103
src/util.rs
··· 1 + use rand::Rng; 2 + use sqlx::PgPool; 3 + use uuid::Uuid; 4 + 5 + const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 6 + 7 + pub fn generate_token_code() -> String { 8 + generate_token_code_parts(2, 5) 9 + } 10 + 11 + pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 12 + let mut rng = rand::thread_rng(); 13 + let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 14 + 15 + (0..parts) 16 + .map(|_| { 17 + (0..part_len) 18 + .map(|_| chars[rng.gen_range(0..chars.len())]) 19 + .collect::<String>() 20 + }) 21 + .collect::<Vec<_>>() 22 + .join("-") 23 + } 24 + 25 + #[derive(Debug)] 26 + pub enum DbLookupError { 27 + NotFound, 28 + DatabaseError(sqlx::Error), 29 + } 30 + 31 + impl From<sqlx::Error> for DbLookupError { 32 + fn from(e: sqlx::Error) -> Self { 33 + DbLookupError::DatabaseError(e) 34 + } 35 + } 36 + 37 + pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 38 + sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 39 + .fetch_optional(db) 40 + .await? 41 + .ok_or(DbLookupError::NotFound) 42 + } 43 + 44 + pub struct UserInfo { 45 + pub id: Uuid, 46 + pub did: String, 47 + pub handle: String, 48 + } 49 + 50 + pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 51 + sqlx::query_as!( 52 + UserInfo, 53 + "SELECT id, did, handle FROM users WHERE did = $1", 54 + did 55 + ) 56 + .fetch_optional(db) 57 + .await? 58 + .ok_or(DbLookupError::NotFound) 59 + } 60 + 61 + pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> { 62 + sqlx::query_as!( 63 + UserInfo, 64 + "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 65 + identifier 66 + ) 67 + .fetch_optional(db) 68 + .await? 69 + .ok_or(DbLookupError::NotFound) 70 + } 71 + 72 + #[cfg(test)] 73 + mod tests { 74 + use super::*; 75 + 76 + #[test] 77 + fn test_generate_token_code() { 78 + let code = generate_token_code(); 79 + assert_eq!(code.len(), 11); 80 + assert!(code.contains('-')); 81 + 82 + let parts: Vec<&str> = code.split('-').collect(); 83 + assert_eq!(parts.len(), 2); 84 + assert_eq!(parts[0].len(), 5); 85 + assert_eq!(parts[1].len(), 5); 86 + 87 + for c in code.chars() { 88 + if c != '-' { 89 + assert!(BASE32_ALPHABET.contains(c)); 90 + } 91 + } 92 + } 93 + 94 + #[test] 95 + fn test_generate_token_code_parts() { 96 + let code = generate_token_code_parts(3, 4); 97 + let parts: Vec<&str> = code.split('-').collect(); 98 + assert_eq!(parts.len(), 3); 99 + for part in parts { 100 + assert_eq!(part.len(), 4); 101 + } 102 + } 103 + }
+1 -1
tests/email_update.rs
··· 556 556 557 557 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 558 558 let body: Value = res.json().await.expect("Invalid JSON"); 559 - assert_eq!(body["error"], "InvalidRequest"); 559 + assert_eq!(body["error"], "InvalidEmail"); 560 560 }
+30 -12
tests/relay_client.rs
··· 13 13 async fn mock_relay_server( 14 14 listener: TcpListener, 15 15 event_tx: mpsc::Sender<Vec<u8>>, 16 - ready_tx: mpsc::Sender<()>, 16 + connected_tx: mpsc::Sender<()>, 17 17 ) { 18 18 let handler = |ws: axum::extract::ws::WebSocketUpgrade| async { 19 19 ws.on_upgrade(move |mut socket| async move { 20 - ready_tx.send(()).await.unwrap(); 21 - if let Some(Ok(Message::Binary(bytes))) = socket.recv().await { 22 - event_tx.send(bytes.to_vec()).await.unwrap(); 20 + let _ = connected_tx.send(()).await; 21 + while let Some(Ok(msg)) = socket.recv().await { 22 + if let Message::Binary(bytes) = msg { 23 + let _ = event_tx.send(bytes.to_vec()).await; 24 + break; 25 + } 23 26 } 24 27 }) 25 28 }; ··· 35 38 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 36 39 let addr = listener.local_addr().unwrap(); 37 40 let (event_tx, mut event_rx) = mpsc::channel(1); 38 - let (ready_tx, ready_rx) = mpsc::channel(1); 39 - tokio::spawn(mock_relay_server(listener, event_tx, ready_tx)); 41 + let (connected_tx, _connected_rx) = mpsc::channel::<()>(1); 42 + tokio::spawn(mock_relay_server(listener, event_tx, connected_tx)); 40 43 let relay_url = format!("ws://{}", addr); 41 44 42 45 let db_url = get_db_connection_string().await; ··· 46 49 .unwrap(); 47 50 let state = AppState::new(pool).await; 48 51 52 + let (ready_tx, ready_rx) = mpsc::channel(1); 49 53 start_relay_clients(state.clone(), vec![relay_url], Some(ready_rx)).await; 50 54 51 - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; 55 + tokio::time::timeout( 56 + tokio::time::Duration::from_secs(5), 57 + async { 58 + ready_tx.closed().await; 59 + } 60 + ) 61 + .await 62 + .expect("Timeout waiting for relay client to be ready"); 52 63 53 64 let dummy_event = SequencedEvent { 54 65 seq: 1, 55 66 did: "did:plc:test".to_string(), 56 67 created_at: Utc::now(), 57 68 event_type: "commit".to_string(), 58 - commit_cid: None, 69 + commit_cid: Some("bafyreihffx5a4o3qbv7vp6qmxpxok5mx5xvlsq6z4x3xv3zqv7vqvc7mzy".to_string()), 59 70 prev_cid: None, 60 - ops: None, 61 - blobs: None, 62 - blocks_cids: None, 71 + ops: Some(serde_json::json!([])), 72 + blobs: Some(vec![]), 73 + blocks_cids: Some(vec![]), 63 74 }; 64 75 state.firehose_tx.send(dummy_event).unwrap(); 65 76 66 - let received_bytes = event_rx.recv().await.expect("Did not receive event"); 77 + let received_bytes = tokio::time::timeout( 78 + tokio::time::Duration::from_secs(5), 79 + event_rx.recv() 80 + ) 81 + .await 82 + .expect("Timeout waiting for event") 83 + .expect("Event channel closed"); 84 + 67 85 assert!(!received_bytes.is_empty()); 68 86 }