This repository has no description.
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::{Path, Query, State}, 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8}; 9use base64::Engine; 10use k256::SecretKey; 11use k256::elliptic_curve::sec1::ToEncodedPoint; 12use reqwest; 13use serde::Deserialize; 14use serde_json::json; 15use tracing::{error, warn}; 16 17#[derive(Deserialize)] 18pub struct ResolveHandleParams { 19 pub handle: String, 20} 21 22pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, 25) -> Response { 26 let handle = params.handle.trim(); 27 if handle.is_empty() { 28 return ( 29 StatusCode::BAD_REQUEST, 30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 31 ) 32 .into_response(); 33 } 34 let cache_key = format!("handle:{}", handle); 35 if let Some(did) = state.cache.get(&cache_key).await { 36 return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 37 } 38 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 39 let suffix = format!(".{}", hostname); 40 let short_handle = if handle.ends_with(&suffix) { 41 handle.strip_suffix(&suffix).unwrap_or(handle) 42 } else { 43 handle 44 }; 45 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 46 .fetch_optional(&state.db) 47 .await; 48 match user { 49 Ok(Some(row)) => { 50 let _ = state 51 .cache 52 .set(&cache_key, &row.did, std::time::Duration::from_secs(300)) 53 .await; 54 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 55 } 56 Ok(None) => { 57 match crate::handle::resolve_handle(handle).await { 58 Ok(did) => { 59 let _ = state 60 .cache 61 .set(&cache_key, &did, std::time::Duration::from_secs(300)) 62 .await; 63 (StatusCode::OK, Json(json!({ "did": did }))).into_response() 64 } 65 Err(_) => ( 66 StatusCode::NOT_FOUND, 67 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 68 ) 69 
.into_response(), 70 } 71 } 72 Err(e) => { 73 error!("DB error resolving handle: {:?}", e); 74 ( 75 StatusCode::INTERNAL_SERVER_ERROR, 76 Json(json!({"error": "InternalError"})), 77 ) 78 .into_response() 79 } 80 } 81} 82 83pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 84 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 85 let public_key = secret_key.public_key(); 86 let encoded = public_key.to_encoded_point(false); 87 let x = encoded.x().ok_or("Missing x coordinate")?; 88 let y = encoded.y().ok_or("Missing y coordinate")?; 89 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 90 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 91 Ok(json!({ 92 "kty": "EC", 93 "crv": "secp256k1", 94 "x": x_b64, 95 "y": y_b64 96 })) 97} 98 99pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 100 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 101 // Kinda for local dev, encode hostname if it contains port 102 let did = if hostname.contains(':') { 103 format!("did:web:{}", hostname.replace(':', "%3A")) 104 } else { 105 format!("did:web:{}", hostname) 106 }; 107 Json(json!({ 108 "@context": ["https://www.w3.org/ns/did/v1"], 109 "id": did, 110 "service": [{ 111 "id": "#atproto_pds", 112 "type": "AtprotoPersonalDataServer", 113 "serviceEndpoint": format!("https://{}", hostname) 114 }] 115 })) 116} 117 118pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 119 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 120 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 121 .fetch_optional(&state.db) 122 .await; 123 let (user_id, did) = match user { 124 Ok(Some(row)) => (row.id, row.did), 125 Ok(None) => { 126 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 127 } 
128 Err(e) => { 129 error!("DB Error: {:?}", e); 130 return ( 131 StatusCode::INTERNAL_SERVER_ERROR, 132 Json(json!({"error": "InternalError"})), 133 ) 134 .into_response(); 135 } 136 }; 137 if !did.starts_with("did:web:") { 138 return ( 139 StatusCode::NOT_FOUND, 140 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 141 ) 142 .into_response(); 143 } 144 let key_row = sqlx::query!( 145 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 146 user_id 147 ) 148 .fetch_optional(&state.db) 149 .await; 150 let key_bytes: Vec<u8> = match key_row { 151 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 152 Ok(k) => k, 153 Err(_) => { 154 return ( 155 StatusCode::INTERNAL_SERVER_ERROR, 156 Json(json!({"error": "InternalError"})), 157 ) 158 .into_response(); 159 } 160 }, 161 _ => { 162 return ( 163 StatusCode::INTERNAL_SERVER_ERROR, 164 Json(json!({"error": "InternalError"})), 165 ) 166 .into_response(); 167 } 168 }; 169 let jwk = match get_jwk(&key_bytes) { 170 Ok(j) => j, 171 Err(e) => { 172 tracing::error!("Failed to generate JWK: {}", e); 173 return ( 174 StatusCode::INTERNAL_SERVER_ERROR, 175 Json(json!({"error": "InternalError"})), 176 ) 177 .into_response(); 178 } 179 }; 180 Json(json!({ 181 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 182 "id": did, 183 "alsoKnownAs": [format!("at://{}", handle)], 184 "verificationMethod": [{ 185 "id": format!("{}#atproto", did), 186 "type": "JsonWebKey2020", 187 "controller": did, 188 "publicKeyJwk": jwk 189 }], 190 "service": [{ 191 "id": "#atproto_pds", 192 "type": "AtprotoPersonalDataServer", 193 "serviceEndpoint": format!("https://{}", hostname) 194 }] 195 })).into_response() 196} 197 198pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 199 let expected_prefix = if hostname.contains(':') { 200 format!("did:web:{}", hostname.replace(':', "%3A")) 201 } else { 
202 format!("did:web:{}", hostname) 203 }; 204 if did.starts_with(&expected_prefix) { 205 let suffix = &did[expected_prefix.len()..]; 206 let expected_suffix = format!(":u:{}", handle); 207 if suffix == expected_suffix { 208 Ok(()) 209 } else { 210 Err(format!( 211 "Invalid DID path for this PDS. Expected {}", 212 expected_suffix 213 )) 214 } 215 } else { 216 let parts: Vec<&str> = did.split(':').collect(); 217 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 218 return Err("Invalid did:web format".into()); 219 } 220 let domain_segment = parts[2]; 221 let domain = domain_segment.replace("%3A", ":"); 222 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 223 "http" 224 } else { 225 "https" 226 }; 227 let url = if parts.len() == 3 { 228 format!("{}://{}/.well-known/did.json", scheme, domain) 229 } else { 230 let path = parts[3..].join("/"); 231 format!("{}://{}/{}/did.json", scheme, domain, path) 232 }; 233 let client = reqwest::Client::builder() 234 .timeout(std::time::Duration::from_secs(5)) 235 .build() 236 .map_err(|e| format!("Failed to create client: {}", e))?; 237 let resp = client 238 .get(&url) 239 .send() 240 .await 241 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 242 if !resp.status().is_success() { 243 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 244 } 245 let doc: serde_json::Value = resp 246 .json() 247 .await 248 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 249 let services = doc["service"] 250 .as_array() 251 .ok_or("No services found in DID doc")?; 252 let pds_endpoint = format!("https://{}", hostname); 253 let has_valid_service = services.iter().any(|s| { 254 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 255 }); 256 if has_valid_service { 257 Ok(()) 258 } else { 259 Err(format!( 260 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 261 pds_endpoint 262 )) 263 } 264 } 265} 266 
267#[derive(serde::Serialize)] 268#[serde(rename_all = "camelCase")] 269pub struct GetRecommendedDidCredentialsOutput { 270 pub rotation_keys: Vec<String>, 271 pub also_known_as: Vec<String>, 272 pub verification_methods: VerificationMethods, 273 pub services: Services, 274} 275 276#[derive(serde::Serialize)] 277#[serde(rename_all = "camelCase")] 278pub struct VerificationMethods { 279 pub atproto: String, 280} 281 282#[derive(serde::Serialize)] 283#[serde(rename_all = "camelCase")] 284pub struct Services { 285 pub atproto_pds: AtprotoPds, 286} 287 288#[derive(serde::Serialize)] 289#[serde(rename_all = "camelCase")] 290pub struct AtprotoPds { 291 #[serde(rename = "type")] 292 pub service_type: String, 293 pub endpoint: String, 294} 295 296pub async fn get_recommended_did_credentials( 297 State(state): State<AppState>, 298 headers: axum::http::HeaderMap, 299) -> Response { 300 let token = match crate::auth::extract_bearer_token_from_header( 301 headers.get("Authorization").and_then(|h| h.to_str().ok()), 302 ) { 303 Some(t) => t, 304 None => { 305 return ( 306 StatusCode::UNAUTHORIZED, 307 Json(json!({"error": "AuthenticationRequired"})), 308 ) 309 .into_response(); 310 } 311 }; 312 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 313 Ok(user) => user, 314 Err(e) => return ApiError::from(e).into_response(), 315 }; 316 let user = match sqlx::query!( 317 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 318 auth_user.did 319 ) 320 .fetch_optional(&state.db) 321 .await 322 { 323 Ok(Some(row)) => row, 324 _ => return ApiError::InternalError.into_response(), 325 }; 326 let key_bytes = match auth_user.key_bytes { 327 Some(kb) => kb, 328 None => { 329 return ApiError::AuthenticationFailedMsg( 330 "OAuth tokens cannot get DID credentials".into(), 331 ) 332 .into_response(); 333 } 334 }; 335 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 336 let pds_endpoint = 
format!("https://{}", hostname); 337 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 338 Ok(k) => k, 339 Err(_) => return ApiError::InternalError.into_response(), 340 }; 341 let public_key = secret_key.public_key(); 342 let encoded = public_key.to_encoded_point(true); 343 let did_key = format!( 344 "did:key:zQ3sh{}", 345 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 346 .chars() 347 .skip(1) 348 .collect::<String>() 349 ); 350 ( 351 StatusCode::OK, 352 Json(GetRecommendedDidCredentialsOutput { 353 rotation_keys: vec![did_key.clone()], 354 also_known_as: vec![format!("at://{}", user.handle)], 355 verification_methods: VerificationMethods { atproto: did_key }, 356 services: Services { 357 atproto_pds: AtprotoPds { 358 service_type: "AtprotoPersonalDataServer".to_string(), 359 endpoint: pds_endpoint, 360 }, 361 }, 362 }), 363 ) 364 .into_response() 365} 366 367#[derive(Deserialize)] 368pub struct UpdateHandleInput { 369 pub handle: String, 370} 371 372pub async fn update_handle( 373 State(state): State<AppState>, 374 headers: axum::http::HeaderMap, 375 Json(input): Json<UpdateHandleInput>, 376) -> Response { 377 let token = match crate::auth::extract_bearer_token_from_header( 378 headers.get("Authorization").and_then(|h| h.to_str().ok()), 379 ) { 380 Some(t) => t, 381 None => return ApiError::AuthenticationRequired.into_response(), 382 }; 383 let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 384 Ok(user) => user.did, 385 Err(e) => return ApiError::from(e).into_response(), 386 }; 387 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 388 .fetch_optional(&state.db) 389 .await 390 { 391 Ok(Some(id)) => id, 392 _ => return ApiError::InternalError.into_response(), 393 }; 394 let new_handle = input.handle.trim(); 395 if new_handle.is_empty() { 396 return ApiError::InvalidRequest("handle is required".into()).into_response(); 397 } 398 if !new_handle 399 .chars() 400 .all(|c| 
c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 401 { 402 return ( 403 StatusCode::BAD_REQUEST, 404 Json( 405 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}), 406 ), 407 ) 408 .into_response(); 409 } 410 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 411 let is_service_domain = crate::handle::is_service_domain_handle(new_handle, &hostname); 412 let (handle_to_store, full_handle) = if is_service_domain { 413 let suffix = format!(".{}", hostname); 414 let short_handle = if new_handle.ends_with(&suffix) { 415 new_handle.strip_suffix(&suffix).unwrap_or(new_handle) 416 } else { 417 new_handle 418 }; 419 (short_handle.to_string(), format!("{}.{}", short_handle, hostname)) 420 } else { 421 match crate::handle::verify_handle_ownership(new_handle, &did).await { 422 Ok(()) => {} 423 Err(crate::handle::HandleResolutionError::NotFound) => { 424 return ( 425 StatusCode::BAD_REQUEST, 426 Json(json!({ 427 "error": "HandleNotAvailable", 428 "message": "Handle verification failed. Please set up DNS TXT record at _atproto.{} or serve your DID at https://{}/.well-known/atproto-did", 429 "handle": new_handle 430 })), 431 ) 432 .into_response(); 433 } 434 Err(crate::handle::HandleResolutionError::DidMismatch { expected, actual }) => { 435 return ( 436 StatusCode::BAD_REQUEST, 437 Json(json!({ 438 "error": "HandleNotAvailable", 439 "message": format!("Handle points to different DID. 
Expected {}, got {}", expected, actual) 440 })), 441 ) 442 .into_response(); 443 } 444 Err(e) => { 445 warn!("Handle verification failed: {}", e); 446 return ( 447 StatusCode::BAD_REQUEST, 448 Json(json!({ 449 "error": "HandleNotAvailable", 450 "message": format!("Handle verification failed: {}", e) 451 })), 452 ) 453 .into_response(); 454 } 455 } 456 (new_handle.to_string(), new_handle.to_string()) 457 }; 458 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 459 .fetch_optional(&state.db) 460 .await 461 .ok() 462 .flatten(); 463 let existing = sqlx::query!( 464 "SELECT id FROM users WHERE handle = $1 AND id != $2", 465 handle_to_store, 466 user_id 467 ) 468 .fetch_optional(&state.db) 469 .await; 470 if let Ok(Some(_)) = existing { 471 return ( 472 StatusCode::BAD_REQUEST, 473 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 474 ) 475 .into_response(); 476 } 477 let result = sqlx::query!( 478 "UPDATE users SET handle = $1 WHERE id = $2", 479 handle_to_store, 480 user_id 481 ) 482 .execute(&state.db) 483 .await; 484 match result { 485 Ok(_) => { 486 if let Some(old) = old_handle { 487 let _ = state.cache.delete(&format!("handle:{}", old)).await; 488 } 489 let _ = state 490 .cache 491 .delete(&format!("handle:{}", handle_to_store)) 492 .await; 493 let _ = state.cache.delete(&format!("handle:{}", full_handle)).await; 494 if let Err(e) = 495 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)) 496 .await 497 { 498 warn!("Failed to sequence identity event for handle update: {}", e); 499 } 500 if let Err(e) = update_plc_handle(&state, &did, &full_handle).await { 501 warn!("Failed to update PLC handle: {}", e); 502 } 503 (StatusCode::OK, Json(json!({}))).into_response() 504 } 505 Err(e) => { 506 error!("DB error updating handle: {:?}", e); 507 ( 508 StatusCode::INTERNAL_SERVER_ERROR, 509 Json(json!({"error": "InternalError"})), 510 ) 511 .into_response() 512 } 513 } 514} 515 
516async fn update_plc_handle( 517 state: &AppState, 518 did: &str, 519 new_handle: &str, 520) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 521 if !did.starts_with("did:plc:") { 522 return Ok(()); 523 } 524 let user_row = sqlx::query!( 525 r#"SELECT u.id, uk.key_bytes, uk.encryption_version 526 FROM users u 527 JOIN user_keys uk ON u.id = uk.user_id 528 WHERE u.did = $1"#, 529 did 530 ) 531 .fetch_optional(&state.db) 532 .await?; 533 let user_row = match user_row { 534 Some(r) => r, 535 None => return Ok(()), 536 }; 537 let key_bytes = crate::config::decrypt_key(&user_row.key_bytes, user_row.encryption_version)?; 538 let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes)?; 539 let plc_client = crate::plc::PlcClient::new(None); 540 let last_op = plc_client.get_last_op(did).await?; 541 let new_also_known_as = vec![format!("at://{}", new_handle)]; 542 let update_op = crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?; 543 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?; 544 plc_client.send_operation(did, &signed_op).await?; 545 Ok(()) 546} 547 548pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 549 let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 550 Some(h) => h, 551 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 552 }; 553 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 554 let suffix = format!(".{}", hostname); 555 let handle = host.split(':').next().unwrap_or(host); 556 let short_handle = if handle.ends_with(&suffix) { 557 handle.strip_suffix(&suffix).unwrap_or(handle) 558 } else { 559 return (StatusCode::NOT_FOUND, "Handle not found").into_response(); 560 }; 561 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 562 .fetch_optional(&state.db) 563 .await; 564 match user { 565 Ok(Some(row)) => 
row.did.into_response(), 566 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(), 567 Err(e) => { 568 error!("DB error in well-known atproto-did: {:?}", e); 569 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 570 } 571 } 572}