this repo has no description
1use axum::{ 2 extract::{State, Path}, 3 Json, 4 response::{IntoResponse, Response}, 5 http::StatusCode, 6}; 7use serde::{Deserialize, Serialize}; 8use serde_json::json; 9use crate::state::AppState; 10use sqlx::Row; 11use bcrypt::{hash, DEFAULT_COST}; 12use tracing::{info, error}; 13use jacquard_repo::{mst::Mst, commit::Commit, storage::BlockStore}; 14use jacquard::types::{string::Tid, did::Did, integer::LimitedU32}; 15use std::sync::Arc; 16use k256::SecretKey; 17use rand::rngs::OsRng; 18use base64::Engine; 19use reqwest; 20 21#[derive(Deserialize)] 22pub struct CreateAccountInput { 23 pub handle: String, 24 pub email: String, 25 pub password: String, 26 #[serde(rename = "inviteCode")] 27 pub invite_code: Option<String>, 28 pub did: Option<String>, 29} 30 31#[derive(Serialize)] 32#[serde(rename_all = "camelCase")] 33pub struct CreateAccountOutput { 34 pub access_jwt: String, 35 pub refresh_jwt: String, 36 pub handle: String, 37 pub did: String, 38} 39 40pub async fn create_account( 41 State(state): State<AppState>, 42 Json(input): Json<CreateAccountInput>, 43) -> Response { 44 info!("create_account hit: {}", input.handle); 45 if input.handle.contains('!') || input.handle.contains('@') { 46 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}))).into_response(); 47 } 48 49 let did = if let Some(d) = &input.did { 50 if d.trim().is_empty() { 51 format!("did:plc:{}", uuid::Uuid::new_v4()) 52 } else { 53 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 54 if let Err(e) = verify_did_web(d, &hostname, &input.handle).await { 55 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidDid", "message": e}))).into_response(); 56 } 57 d.clone() 58 } 59 } else { 60 format!("did:plc:{}", uuid::Uuid::new_v4()) 61 }; 62 63 let mut tx = match state.db.begin().await { 64 Ok(tx) => tx, 65 Err(e) => { 66 error!("Error starting transaction: {:?}", e); 67 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 68 } 69 }; 70 71 let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1") 72 .bind(&input.handle) 73 .fetch_optional(&mut *tx) 74 .await; 75 76 match exists_query { 77 Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "HandleTaken", "message": "Handle already taken"}))).into_response(), 78 Err(e) => { 79 error!("Error checking handle: {:?}", e); 80 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 81 } 82 Ok(None) => {} 83 } 84 85 if let Some(code) = &input.invite_code { 86 let invite_query = sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE") 87 .bind(code) 88 .fetch_optional(&mut *tx) 89 .await; 90 91 match invite_query { 92 Ok(Some(row)) => { 93 let uses: i32 = row.get("available_uses"); 94 if uses <= 0 { 95 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response(); 96 } 97 98 let update_invite = sqlx::query("UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1") 99 .bind(code) 100 .execute(&mut *tx) 101 .await; 102 103 if let Err(e) = update_invite { 104 error!("Error updating invite code: {:?}", e); 105 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 106 } 107 }, 108 Ok(None) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"}))).into_response(), 109 Err(e) => { 110 error!("Error checking invite code: {:?}", e); 111 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 112 } 113 } 114 } 115 116 let password_hash = match hash(&input.password, DEFAULT_COST) { 117 Ok(h) => h, 118 Err(e) => { 119 error!("Error hashing password: {:?}", e); 120 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 121 } 122 }; 123 124 let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id") 125 .bind(&input.handle) 126 .bind(&input.email) 127 .bind(&did) 128 .bind(&password_hash) 129 .fetch_one(&mut *tx) 130 .await; 131 132 let user_id: uuid::Uuid = match user_insert { 133 Ok(row) => row.get("id"), 134 Err(e) => { 135 error!("Error inserting user: {:?}", e); 136 // TODO: Check for unique constraint violation on email/did specifically 137 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 138 } 139 }; 140 141 let secret_key = SecretKey::random(&mut OsRng); 142 let secret_key_bytes = secret_key.to_bytes(); 143 144 let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)") 145 .bind(user_id) 146 .bind(&secret_key_bytes[..]) 147 .execute(&mut *tx) 148 .await; 149 150 if let Err(e) = key_insert { 151 error!("Error inserting user key: {:?}", e); 152 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 153 } 154 155 let mst = Mst::new(Arc::new(state.block_store.clone())); 156 let mst_root = match mst.root().await { 157 Ok(c) => c, 158 Err(e) => { 159 error!("Error creating MST root: {:?}", e); 160 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 161 } 162 }; 163 164 let did_obj = match Did::new(&did) { 165 Ok(d) => d, 166 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(), 167 }; 168 169 let rev = Tid::now(LimitedU32::MIN); 170 171 let commit = Commit::new_unsigned( 172 did_obj, 173 mst_root, 174 rev, 175 None 176 ); 177 178 let commit_bytes = match commit.to_cbor() { 179 Ok(b) => b, 180 Err(e) => { 181 error!("Error serializing genesis commit: {:?}", e); 182 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 183 } 184 }; 185 186 let commit_cid = match state.block_store.put(&commit_bytes).await { 187 Ok(c) => c, 188 Err(e) => { 189 error!("Error saving genesis commit: {:?}", e); 190 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 191 } 192 }; 193 194 let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)") 195 .bind(user_id) 196 .bind(commit_cid.to_string()) 197 .execute(&mut *tx) 198 .await; 199 200 if let Err(e) = repo_insert { 201 error!("Error initializing repo: {:?}", e); 202 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 203 } 204 205 if let Some(code) = &input.invite_code { 206 let use_insert = sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)") 207 .bind(code) 208 .bind(user_id) 209 .execute(&mut *tx) 210 .await; 211 212 if let Err(e) = use_insert { 213 error!("Error recording invite usage: {:?}", e); 214 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 215 } 216 } 217 218 let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| { 219 error!("Error creating access token: {:?}", e); 220 (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response() 221 }); 222 let access_jwt = match access_jwt { 223 Ok(t) => t, 224 Err(r) => return r, 225 }; 226 227 let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| { 228 error!("Error creating refresh token: {:?}", e); 229 (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response() 230 }); 231 let refresh_jwt = match refresh_jwt { 232 Ok(t) => t, 233 Err(r) => return r, 234 }; 235 236 let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)") 237 .bind(&access_jwt) 238 .bind(&refresh_jwt) 239 .bind(&did) 240 .execute(&mut *tx) 241 .await; 242 243 if let Err(e) = session_insert { 244 error!("Error inserting session: {:?}", e); 245 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 246 } 247 248 if let Err(e) = tx.commit().await { 249 error!("Error committing transaction: {:?}", e); 250 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(); 251 } 252 253 (StatusCode::OK, Json(CreateAccountOutput { 254 access_jwt, 255 refresh_jwt, 256 handle: input.handle, 257 did, 258 })).into_response() 259} 260 261fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { 262 use k256::elliptic_curve::sec1::ToEncodedPoint; 263 264 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); 265 let public_key = secret_key.public_key(); 266 let encoded = public_key.to_encoded_point(false); 267 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); 268 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); 269 270 json!({ 271 "kty": "EC", 272 "crv": "secp256k1", 273 "x": x, 274 "y": y 275 }) 276} 277 278pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 279 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 280 // Kinda for local dev, encode hostname if it contains port 281 let did = if hostname.contains(':') { 282 format!("did:web:{}", hostname.replace(':', "%3A")) 283 } else { 284 format!("did:web:{}", hostname) 285 }; 286 287 Json(json!({ 288 "@context": ["https://www.w3.org/ns/did/v1"], 289 "id": did, 290 "service": [{ 291 "id": "#atproto_pds", 292 "type": "AtprotoPersonalDataServer", 293 "serviceEndpoint": format!("https://{}", hostname) 294 }] 295 })) 296} 297 298pub async fn user_did_doc( 299 State(state): State<AppState>, 300 Path(handle): Path<String>, 301) -> Response { 302 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 303 304 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1") 305 .bind(&handle) 306 .fetch_optional(&state.db) 307 .await; 308 309 let (user_id, did) = match user { 310 Ok(Some(row)) => { 311 let id: uuid::Uuid = row.get("id"); 312 let d: String = row.get("did"); 313 (id, d) 314 }, 315 Ok(None) => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(), 316 Err(e) => { 317 error!("DB Error: {:?}", e); 318 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response() 319 }, 320 }; 321 322 if !did.starts_with("did:web:") { 323 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "User is not did:web"}))).into_response(); 324 } 325 326 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1") 327 .bind(user_id) 328 .fetch_optional(&state.db) 329 .await; 330 331 let key_bytes: Vec<u8> = match key_row { 332 Ok(Some(row)) => row.get("key_bytes"), 333 _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(), 334 }; 335 336 let jwk = get_jwk(&key_bytes); 337 338 Json(json!({ 339 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 340 "id": did, 341 "alsoKnownAs": [format!("at://{}", handle)], 342 "verificationMethod": [{ 343 "id": format!("{}#atproto", did), 344 "type": "JsonWebKey2020", 345 "controller": did, 346 "publicKeyJwk": jwk 347 }], 348 "service": [{ 349 "id": "#atproto_pds", 350 "type": "AtprotoPersonalDataServer", 351 "serviceEndpoint": format!("https://{}", hostname) 352 }] 353 })).into_response() 354} 355 356async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 357 let expected_prefix = if hostname.contains(':') { 358 format!("did:web:{}", hostname.replace(':', "%3A")) 359 } else { 360 format!("did:web:{}", hostname) 361 }; 362 363 if did.starts_with(&expected_prefix) { 364 let suffix = &did[expected_prefix.len()..]; 365 let expected_suffix = format!(":u:{}", handle); 366 if suffix == expected_suffix { 367 Ok(()) 368 } else { 369 Err(format!("Invalid DID path for this PDS. Expected {}", expected_suffix)) 370 } 371 } else { 372 let parts: Vec<&str> = did.split(':').collect(); 373 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 374 return Err("Invalid did:web format".into()); 375 } 376 377 let domain_segment = parts[2]; 378 let domain = domain_segment.replace("%3A", ":"); 379 380 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 381 "http" 382 } else { 383 "https" 384 }; 385 386 let url = if parts.len() == 3 { 387 format!("{}://{}/.well-known/did.json", scheme, domain) 388 } else { 389 let path = parts[3..].join("/"); 390 format!("{}://{}/{}/did.json", scheme, domain, path) 391 }; 392 393 let client = reqwest::Client::builder() 394 .timeout(std::time::Duration::from_secs(5)) 395 .build() 396 .map_err(|e| format!("Failed to create client: {}", e))?; 397 398 let resp = client.get(&url).send().await 399 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 400 401 if !resp.status().is_success() { 402 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 403 } 404 405 let doc: serde_json::Value = resp.json().await 406 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 407 408 let services = doc["service"].as_array() 409 .ok_or("No services found in DID doc")?; 410 411 let pds_endpoint = format!("https://{}", hostname); 412 413 let has_valid_service = services.iter().any(|s| { 414 s["type"] == "AtprotoPersonalDataServer" && 415 s["serviceEndpoint"] == pds_endpoint 416 }); 417 418 if has_valid_service { 419 Ok(()) 420 } else { 421 Err(format!("DID document does not list this PDS ({}) as AtprotoPersonalDataServer", pds_endpoint)) 422 } 423 } 424}