// This repository has no description.
1use crate::api::ApiError; 2use crate::plc::signing_key_to_did_key; 3use crate::state::AppState; 4use axum::{ 5 Json, 6 extract::{Path, Query, State}, 7 http::{HeaderMap, StatusCode}, 8 response::{IntoResponse, Response}, 9}; 10use base64::Engine; 11use k256::SecretKey; 12use k256::elliptic_curve::sec1::ToEncodedPoint; 13use reqwest; 14use serde::Deserialize; 15use serde_json::json; 16use tracing::{error, warn}; 17 18#[derive(Deserialize)] 19pub struct ResolveHandleParams { 20 pub handle: String, 21} 22 23pub async fn resolve_handle( 24 State(state): State<AppState>, 25 Query(params): Query<ResolveHandleParams>, 26) -> Response { 27 let handle = params.handle.trim(); 28 if handle.is_empty() { 29 return ( 30 StatusCode::BAD_REQUEST, 31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 32 ) 33 .into_response(); 34 } 35 let cache_key = format!("handle:{}", handle); 36 if let Some(did) = state.cache.get(&cache_key).await { 37 return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 38 } 39 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 40 let suffix = format!(".{}", hostname); 41 let short_handle = if handle.ends_with(&suffix) { 42 handle.strip_suffix(&suffix).unwrap_or(handle) 43 } else { 44 handle 45 }; 46 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 47 .fetch_optional(&state.db) 48 .await; 49 match user { 50 Ok(Some(row)) => { 51 let _ = state 52 .cache 53 .set(&cache_key, &row.did, std::time::Duration::from_secs(300)) 54 .await; 55 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 56 } 57 Ok(None) => match crate::handle::resolve_handle(handle).await { 58 Ok(did) => { 59 let _ = state 60 .cache 61 .set(&cache_key, &did, std::time::Duration::from_secs(300)) 62 .await; 63 (StatusCode::OK, Json(json!({ "did": did }))).into_response() 64 } 65 Err(_) => ( 66 StatusCode::NOT_FOUND, 67 Json(json!({"error": "HandleNotFound", "message": "Unable 
to resolve handle"})), 68 ) 69 .into_response(), 70 }, 71 Err(e) => { 72 error!("DB error resolving handle: {:?}", e); 73 ( 74 StatusCode::INTERNAL_SERVER_ERROR, 75 Json(json!({"error": "InternalError"})), 76 ) 77 .into_response() 78 } 79 } 80} 81 82pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 83 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 84 let public_key = secret_key.public_key(); 85 let encoded = public_key.to_encoded_point(false); 86 let x = encoded.x().ok_or("Missing x coordinate")?; 87 let y = encoded.y().ok_or("Missing y coordinate")?; 88 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 89 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 90 Ok(json!({ 91 "kty": "EC", 92 "crv": "secp256k1", 93 "x": x_b64, 94 "y": y_b64 95 })) 96} 97 98pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 99 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 100 // Kinda for local dev, encode hostname if it contains port 101 let did = if hostname.contains(':') { 102 format!("did:web:{}", hostname.replace(':', "%3A")) 103 } else { 104 format!("did:web:{}", hostname) 105 }; 106 Json(json!({ 107 "@context": ["https://www.w3.org/ns/did/v1"], 108 "id": did, 109 "service": [{ 110 "id": "#atproto_pds", 111 "type": "AtprotoPersonalDataServer", 112 "serviceEndpoint": format!("https://{}", hostname) 113 }] 114 })) 115} 116 117pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 118 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 119 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 120 .fetch_optional(&state.db) 121 .await; 122 let (user_id, did) = match user { 123 Ok(Some(row)) => (row.id, row.did), 124 Ok(None) => { 125 return (StatusCode::NOT_FOUND, Json(json!({"error": 
"NotFound"}))).into_response(); 126 } 127 Err(e) => { 128 error!("DB Error: {:?}", e); 129 return ( 130 StatusCode::INTERNAL_SERVER_ERROR, 131 Json(json!({"error": "InternalError"})), 132 ) 133 .into_response(); 134 } 135 }; 136 if !did.starts_with("did:web:") { 137 return ( 138 StatusCode::NOT_FOUND, 139 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 140 ) 141 .into_response(); 142 } 143 let key_row = sqlx::query!( 144 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 145 user_id 146 ) 147 .fetch_optional(&state.db) 148 .await; 149 let key_bytes: Vec<u8> = match key_row { 150 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 151 Ok(k) => k, 152 Err(_) => { 153 return ( 154 StatusCode::INTERNAL_SERVER_ERROR, 155 Json(json!({"error": "InternalError"})), 156 ) 157 .into_response(); 158 } 159 }, 160 _ => { 161 return ( 162 StatusCode::INTERNAL_SERVER_ERROR, 163 Json(json!({"error": "InternalError"})), 164 ) 165 .into_response(); 166 } 167 }; 168 let jwk = match get_jwk(&key_bytes) { 169 Ok(j) => j, 170 Err(e) => { 171 tracing::error!("Failed to generate JWK: {}", e); 172 return ( 173 StatusCode::INTERNAL_SERVER_ERROR, 174 Json(json!({"error": "InternalError"})), 175 ) 176 .into_response(); 177 } 178 }; 179 Json(json!({ 180 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 181 "id": did, 182 "alsoKnownAs": [format!("at://{}", handle)], 183 "verificationMethod": [{ 184 "id": format!("{}#atproto", did), 185 "type": "JsonWebKey2020", 186 "controller": did, 187 "publicKeyJwk": jwk 188 }], 189 "service": [{ 190 "id": "#atproto_pds", 191 "type": "AtprotoPersonalDataServer", 192 "serviceEndpoint": format!("https://{}", hostname) 193 }] 194 })).into_response() 195} 196 197pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 198 let expected_prefix = if hostname.contains(':') { 199 format!("did:web:{}", 
hostname.replace(':', "%3A")) 200 } else { 201 format!("did:web:{}", hostname) 202 }; 203 if did.starts_with(&expected_prefix) { 204 let suffix = &did[expected_prefix.len()..]; 205 let expected_suffix = format!(":u:{}", handle); 206 if suffix == expected_suffix { 207 Ok(()) 208 } else { 209 Err(format!( 210 "Invalid DID path for this PDS. Expected {}", 211 expected_suffix 212 )) 213 } 214 } else { 215 let parts: Vec<&str> = did.split(':').collect(); 216 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 217 return Err("Invalid did:web format".into()); 218 } 219 let domain_segment = parts[2]; 220 let domain = domain_segment.replace("%3A", ":"); 221 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 222 "http" 223 } else { 224 "https" 225 }; 226 let url = if parts.len() == 3 { 227 format!("{}://{}/.well-known/did.json", scheme, domain) 228 } else { 229 let path = parts[3..].join("/"); 230 format!("{}://{}/{}/did.json", scheme, domain, path) 231 }; 232 let client = reqwest::Client::builder() 233 .timeout(std::time::Duration::from_secs(5)) 234 .build() 235 .map_err(|e| format!("Failed to create client: {}", e))?; 236 let resp = client 237 .get(&url) 238 .send() 239 .await 240 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 241 if !resp.status().is_success() { 242 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 243 } 244 let doc: serde_json::Value = resp 245 .json() 246 .await 247 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 248 let services = doc["service"] 249 .as_array() 250 .ok_or("No services found in DID doc")?; 251 let pds_endpoint = format!("https://{}", hostname); 252 let has_valid_service = services.iter().any(|s| { 253 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 254 }); 255 if has_valid_service { 256 Ok(()) 257 } else { 258 Err(format!( 259 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 260 pds_endpoint 
261 )) 262 } 263 } 264} 265 266#[derive(serde::Serialize)] 267#[serde(rename_all = "camelCase")] 268pub struct GetRecommendedDidCredentialsOutput { 269 pub rotation_keys: Vec<String>, 270 pub also_known_as: Vec<String>, 271 pub verification_methods: VerificationMethods, 272 pub services: Services, 273} 274 275#[derive(serde::Serialize)] 276#[serde(rename_all = "camelCase")] 277pub struct VerificationMethods { 278 pub atproto: String, 279} 280 281#[derive(serde::Serialize)] 282#[serde(rename_all = "camelCase")] 283pub struct Services { 284 pub atproto_pds: AtprotoPds, 285} 286 287#[derive(serde::Serialize)] 288#[serde(rename_all = "camelCase")] 289pub struct AtprotoPds { 290 #[serde(rename = "type")] 291 pub service_type: String, 292 pub endpoint: String, 293} 294 295pub async fn get_recommended_did_credentials( 296 State(state): State<AppState>, 297 headers: axum::http::HeaderMap, 298) -> Response { 299 let token = match crate::auth::extract_bearer_token_from_header( 300 headers.get("Authorization").and_then(|h| h.to_str().ok()), 301 ) { 302 Some(t) => t, 303 None => { 304 return ( 305 StatusCode::UNAUTHORIZED, 306 Json(json!({"error": "AuthenticationRequired"})), 307 ) 308 .into_response(); 309 } 310 }; 311 let auth_user = 312 match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 313 Ok(user) => user, 314 Err(e) => return ApiError::from(e).into_response(), 315 }; 316 let user = match sqlx::query!( 317 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 318 auth_user.did 319 ) 320 .fetch_optional(&state.db) 321 .await 322 { 323 Ok(Some(row)) => row, 324 _ => return ApiError::InternalError.into_response(), 325 }; 326 let key_bytes = match auth_user.key_bytes { 327 Some(kb) => kb, 328 None => { 329 return ApiError::AuthenticationFailedMsg( 330 "OAuth tokens cannot get DID credentials".into(), 331 ) 332 .into_response(); 333 } 334 }; 335 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| 
"localhost".to_string()); 336 let pds_endpoint = format!("https://{}", hostname); 337 let full_handle = if user.handle.contains('.') { 338 user.handle.clone() 339 } else { 340 format!("{}.{}", user.handle, hostname) 341 }; 342 let signing_key = match k256::ecdsa::SigningKey::from_slice(&key_bytes) { 343 Ok(k) => k, 344 Err(_) => return ApiError::InternalError.into_response(), 345 }; 346 let did_key = signing_key_to_did_key(&signing_key); 347 ( 348 StatusCode::OK, 349 Json(GetRecommendedDidCredentialsOutput { 350 rotation_keys: vec![did_key.clone()], 351 also_known_as: vec![format!("at://{}", full_handle)], 352 verification_methods: VerificationMethods { atproto: did_key }, 353 services: Services { 354 atproto_pds: AtprotoPds { 355 service_type: "AtprotoPersonalDataServer".to_string(), 356 endpoint: pds_endpoint, 357 }, 358 }, 359 }), 360 ) 361 .into_response() 362} 363 364#[derive(Deserialize)] 365pub struct UpdateHandleInput { 366 pub handle: String, 367} 368 369pub async fn update_handle( 370 State(state): State<AppState>, 371 headers: axum::http::HeaderMap, 372 Json(input): Json<UpdateHandleInput>, 373) -> Response { 374 let token = match crate::auth::extract_bearer_token_from_header( 375 headers.get("Authorization").and_then(|h| h.to_str().ok()), 376 ) { 377 Some(t) => t, 378 None => return ApiError::AuthenticationRequired.into_response(), 379 }; 380 let auth_user = 381 match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 382 Ok(user) => user, 383 Err(e) => return ApiError::from(e).into_response(), 384 }; 385 if let Err(e) = crate::auth::scope_check::check_identity_scope( 386 auth_user.is_oauth, 387 auth_user.scope.as_deref(), 388 crate::oauth::scopes::IdentityAttr::Handle, 389 ) { 390 return e; 391 } 392 let did = auth_user.did; 393 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 394 .fetch_optional(&state.db) 395 .await 396 { 397 Ok(Some(id)) => id, 398 _ => return 
ApiError::InternalError.into_response(), 399 }; 400 let new_handle = input.handle.trim(); 401 if new_handle.is_empty() { 402 return ApiError::InvalidRequest("handle is required".into()).into_response(); 403 } 404 if !new_handle 405 .chars() 406 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 407 { 408 return ( 409 StatusCode::BAD_REQUEST, 410 Json( 411 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}), 412 ), 413 ) 414 .into_response(); 415 } 416 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 417 let is_service_domain = crate::handle::is_service_domain_handle(new_handle, &hostname); 418 let (handle_to_store, full_handle) = if is_service_domain { 419 let suffix = format!(".{}", hostname); 420 let short_handle = if new_handle.ends_with(&suffix) { 421 new_handle.strip_suffix(&suffix).unwrap_or(new_handle) 422 } else { 423 new_handle 424 }; 425 ( 426 short_handle.to_string(), 427 format!("{}.{}", short_handle, hostname), 428 ) 429 } else { 430 match crate::handle::verify_handle_ownership(new_handle, &did).await { 431 Ok(()) => {} 432 Err(crate::handle::HandleResolutionError::NotFound) => { 433 return ( 434 StatusCode::BAD_REQUEST, 435 Json(json!({ 436 "error": "HandleNotAvailable", 437 "message": "Handle verification failed. Please set up DNS TXT record at _atproto.{} or serve your DID at https://{}/.well-known/atproto-did", 438 "handle": new_handle 439 })), 440 ) 441 .into_response(); 442 } 443 Err(crate::handle::HandleResolutionError::DidMismatch { expected, actual }) => { 444 return ( 445 StatusCode::BAD_REQUEST, 446 Json(json!({ 447 "error": "HandleNotAvailable", 448 "message": format!("Handle points to different DID. 
Expected {}, got {}", expected, actual) 449 })), 450 ) 451 .into_response(); 452 } 453 Err(e) => { 454 warn!("Handle verification failed: {}", e); 455 return ( 456 StatusCode::BAD_REQUEST, 457 Json(json!({ 458 "error": "HandleNotAvailable", 459 "message": format!("Handle verification failed: {}", e) 460 })), 461 ) 462 .into_response(); 463 } 464 } 465 (new_handle.to_string(), new_handle.to_string()) 466 }; 467 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 468 .fetch_optional(&state.db) 469 .await 470 .ok() 471 .flatten(); 472 let existing = sqlx::query!( 473 "SELECT id FROM users WHERE handle = $1 AND id != $2", 474 handle_to_store, 475 user_id 476 ) 477 .fetch_optional(&state.db) 478 .await; 479 if let Ok(Some(_)) = existing { 480 return ( 481 StatusCode::BAD_REQUEST, 482 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 483 ) 484 .into_response(); 485 } 486 let result = sqlx::query!( 487 "UPDATE users SET handle = $1 WHERE id = $2", 488 handle_to_store, 489 user_id 490 ) 491 .execute(&state.db) 492 .await; 493 match result { 494 Ok(_) => { 495 if let Some(old) = old_handle { 496 let _ = state.cache.delete(&format!("handle:{}", old)).await; 497 } 498 let _ = state 499 .cache 500 .delete(&format!("handle:{}", handle_to_store)) 501 .await; 502 let _ = state.cache.delete(&format!("handle:{}", full_handle)).await; 503 if let Err(e) = 504 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)) 505 .await 506 { 507 warn!("Failed to sequence identity event for handle update: {}", e); 508 } 509 if let Err(e) = update_plc_handle(&state, &did, &full_handle).await { 510 warn!("Failed to update PLC handle: {}", e); 511 } 512 (StatusCode::OK, Json(json!({}))).into_response() 513 } 514 Err(e) => { 515 error!("DB error updating handle: {:?}", e); 516 ( 517 StatusCode::INTERNAL_SERVER_ERROR, 518 Json(json!({"error": "InternalError"})), 519 ) 520 .into_response() 521 } 522 } 523} 524 
525async fn update_plc_handle( 526 state: &AppState, 527 did: &str, 528 new_handle: &str, 529) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 530 if !did.starts_with("did:plc:") { 531 return Ok(()); 532 } 533 let user_row = sqlx::query!( 534 r#"SELECT u.id, uk.key_bytes, uk.encryption_version 535 FROM users u 536 JOIN user_keys uk ON u.id = uk.user_id 537 WHERE u.did = $1"#, 538 did 539 ) 540 .fetch_optional(&state.db) 541 .await?; 542 let user_row = match user_row { 543 Some(r) => r, 544 None => return Ok(()), 545 }; 546 let key_bytes = crate::config::decrypt_key(&user_row.key_bytes, user_row.encryption_version)?; 547 let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes)?; 548 let plc_client = crate::plc::PlcClient::new(None); 549 let last_op = plc_client.get_last_op(did).await?; 550 let new_also_known_as = vec![format!("at://{}", new_handle)]; 551 let update_op = 552 crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?; 553 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?; 554 plc_client.send_operation(did, &signed_op).await?; 555 Ok(()) 556} 557 558pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 559 let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 560 Some(h) => h, 561 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 562 }; 563 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 564 let suffix = format!(".{}", hostname); 565 let handle = host.split(':').next().unwrap_or(host); 566 let short_handle = if handle.ends_with(&suffix) { 567 handle.strip_suffix(&suffix).unwrap_or(handle) 568 } else { 569 return (StatusCode::NOT_FOUND, "Handle not found").into_response(); 570 }; 571 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 572 .fetch_optional(&state.db) 573 .await; 574 match user { 575 Ok(Some(row)) => 
row.did.into_response(), 576 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(), 577 Err(e) => { 578 error!("DB error in well-known atproto-did: {:?}", e); 579 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 580 } 581 } 582}