This repository has no description.
1use crate::api::ApiError; 2use crate::plc::signing_key_to_did_key; 3use crate::state::AppState; 4use axum::{ 5 Json, 6 extract::{Path, Query, State}, 7 http::{HeaderMap, StatusCode}, 8 response::{IntoResponse, Response}, 9}; 10use base64::Engine; 11use k256::SecretKey; 12use k256::elliptic_curve::sec1::ToEncodedPoint; 13use reqwest; 14use serde::Deserialize; 15use serde_json::json; 16use tracing::{error, warn}; 17 18#[derive(Deserialize)] 19pub struct ResolveHandleParams { 20 pub handle: String, 21} 22 23pub async fn resolve_handle( 24 State(state): State<AppState>, 25 Query(params): Query<ResolveHandleParams>, 26) -> Response { 27 let handle = params.handle.trim(); 28 if handle.is_empty() { 29 return ( 30 StatusCode::BAD_REQUEST, 31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 32 ) 33 .into_response(); 34 } 35 let cache_key = format!("handle:{}", handle); 36 if let Some(did) = state.cache.get(&cache_key).await { 37 return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 38 } 39 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 40 let suffix = format!(".{}", hostname); 41 let short_handle = if handle.ends_with(&suffix) { 42 handle.strip_suffix(&suffix).unwrap_or(handle) 43 } else { 44 handle 45 }; 46 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 47 .fetch_optional(&state.db) 48 .await; 49 match user { 50 Ok(Some(row)) => { 51 let _ = state 52 .cache 53 .set(&cache_key, &row.did, std::time::Duration::from_secs(300)) 54 .await; 55 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 56 } 57 Ok(None) => { 58 match crate::handle::resolve_handle(handle).await { 59 Ok(did) => { 60 let _ = state 61 .cache 62 .set(&cache_key, &did, std::time::Duration::from_secs(300)) 63 .await; 64 (StatusCode::OK, Json(json!({ "did": did }))).into_response() 65 } 66 Err(_) => ( 67 StatusCode::NOT_FOUND, 68 Json(json!({"error": "HandleNotFound", "message": 
"Unable to resolve handle"})), 69 ) 70 .into_response(), 71 } 72 } 73 Err(e) => { 74 error!("DB error resolving handle: {:?}", e); 75 ( 76 StatusCode::INTERNAL_SERVER_ERROR, 77 Json(json!({"error": "InternalError"})), 78 ) 79 .into_response() 80 } 81 } 82} 83 84pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 85 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 86 let public_key = secret_key.public_key(); 87 let encoded = public_key.to_encoded_point(false); 88 let x = encoded.x().ok_or("Missing x coordinate")?; 89 let y = encoded.y().ok_or("Missing y coordinate")?; 90 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 91 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 92 Ok(json!({ 93 "kty": "EC", 94 "crv": "secp256k1", 95 "x": x_b64, 96 "y": y_b64 97 })) 98} 99 100pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 101 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 102 // Kinda for local dev, encode hostname if it contains port 103 let did = if hostname.contains(':') { 104 format!("did:web:{}", hostname.replace(':', "%3A")) 105 } else { 106 format!("did:web:{}", hostname) 107 }; 108 Json(json!({ 109 "@context": ["https://www.w3.org/ns/did/v1"], 110 "id": did, 111 "service": [{ 112 "id": "#atproto_pds", 113 "type": "AtprotoPersonalDataServer", 114 "serviceEndpoint": format!("https://{}", hostname) 115 }] 116 })) 117} 118 119pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 120 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 121 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 122 .fetch_optional(&state.db) 123 .await; 124 let (user_id, did) = match user { 125 Ok(Some(row)) => (row.id, row.did), 126 Ok(None) => { 127 return (StatusCode::NOT_FOUND, 
Json(json!({"error": "NotFound"}))).into_response(); 128 } 129 Err(e) => { 130 error!("DB Error: {:?}", e); 131 return ( 132 StatusCode::INTERNAL_SERVER_ERROR, 133 Json(json!({"error": "InternalError"})), 134 ) 135 .into_response(); 136 } 137 }; 138 if !did.starts_with("did:web:") { 139 return ( 140 StatusCode::NOT_FOUND, 141 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 142 ) 143 .into_response(); 144 } 145 let key_row = sqlx::query!( 146 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 147 user_id 148 ) 149 .fetch_optional(&state.db) 150 .await; 151 let key_bytes: Vec<u8> = match key_row { 152 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 153 Ok(k) => k, 154 Err(_) => { 155 return ( 156 StatusCode::INTERNAL_SERVER_ERROR, 157 Json(json!({"error": "InternalError"})), 158 ) 159 .into_response(); 160 } 161 }, 162 _ => { 163 return ( 164 StatusCode::INTERNAL_SERVER_ERROR, 165 Json(json!({"error": "InternalError"})), 166 ) 167 .into_response(); 168 } 169 }; 170 let jwk = match get_jwk(&key_bytes) { 171 Ok(j) => j, 172 Err(e) => { 173 tracing::error!("Failed to generate JWK: {}", e); 174 return ( 175 StatusCode::INTERNAL_SERVER_ERROR, 176 Json(json!({"error": "InternalError"})), 177 ) 178 .into_response(); 179 } 180 }; 181 Json(json!({ 182 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 183 "id": did, 184 "alsoKnownAs": [format!("at://{}", handle)], 185 "verificationMethod": [{ 186 "id": format!("{}#atproto", did), 187 "type": "JsonWebKey2020", 188 "controller": did, 189 "publicKeyJwk": jwk 190 }], 191 "service": [{ 192 "id": "#atproto_pds", 193 "type": "AtprotoPersonalDataServer", 194 "serviceEndpoint": format!("https://{}", hostname) 195 }] 196 })).into_response() 197} 198 199pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 200 let expected_prefix = if hostname.contains(':') { 201 
format!("did:web:{}", hostname.replace(':', "%3A")) 202 } else { 203 format!("did:web:{}", hostname) 204 }; 205 if did.starts_with(&expected_prefix) { 206 let suffix = &did[expected_prefix.len()..]; 207 let expected_suffix = format!(":u:{}", handle); 208 if suffix == expected_suffix { 209 Ok(()) 210 } else { 211 Err(format!( 212 "Invalid DID path for this PDS. Expected {}", 213 expected_suffix 214 )) 215 } 216 } else { 217 let parts: Vec<&str> = did.split(':').collect(); 218 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 219 return Err("Invalid did:web format".into()); 220 } 221 let domain_segment = parts[2]; 222 let domain = domain_segment.replace("%3A", ":"); 223 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 224 "http" 225 } else { 226 "https" 227 }; 228 let url = if parts.len() == 3 { 229 format!("{}://{}/.well-known/did.json", scheme, domain) 230 } else { 231 let path = parts[3..].join("/"); 232 format!("{}://{}/{}/did.json", scheme, domain, path) 233 }; 234 let client = reqwest::Client::builder() 235 .timeout(std::time::Duration::from_secs(5)) 236 .build() 237 .map_err(|e| format!("Failed to create client: {}", e))?; 238 let resp = client 239 .get(&url) 240 .send() 241 .await 242 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 243 if !resp.status().is_success() { 244 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 245 } 246 let doc: serde_json::Value = resp 247 .json() 248 .await 249 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 250 let services = doc["service"] 251 .as_array() 252 .ok_or("No services found in DID doc")?; 253 let pds_endpoint = format!("https://{}", hostname); 254 let has_valid_service = services.iter().any(|s| { 255 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 256 }); 257 if has_valid_service { 258 Ok(()) 259 } else { 260 Err(format!( 261 "DID document does not list this PDS ({}) as 
AtprotoPersonalDataServer", 262 pds_endpoint 263 )) 264 } 265 } 266} 267 268#[derive(serde::Serialize)] 269#[serde(rename_all = "camelCase")] 270pub struct GetRecommendedDidCredentialsOutput { 271 pub rotation_keys: Vec<String>, 272 pub also_known_as: Vec<String>, 273 pub verification_methods: VerificationMethods, 274 pub services: Services, 275} 276 277#[derive(serde::Serialize)] 278#[serde(rename_all = "camelCase")] 279pub struct VerificationMethods { 280 pub atproto: String, 281} 282 283#[derive(serde::Serialize)] 284#[serde(rename_all = "camelCase")] 285pub struct Services { 286 pub atproto_pds: AtprotoPds, 287} 288 289#[derive(serde::Serialize)] 290#[serde(rename_all = "camelCase")] 291pub struct AtprotoPds { 292 #[serde(rename = "type")] 293 pub service_type: String, 294 pub endpoint: String, 295} 296 297pub async fn get_recommended_did_credentials( 298 State(state): State<AppState>, 299 headers: axum::http::HeaderMap, 300) -> Response { 301 let token = match crate::auth::extract_bearer_token_from_header( 302 headers.get("Authorization").and_then(|h| h.to_str().ok()), 303 ) { 304 Some(t) => t, 305 None => { 306 return ( 307 StatusCode::UNAUTHORIZED, 308 Json(json!({"error": "AuthenticationRequired"})), 309 ) 310 .into_response(); 311 } 312 }; 313 let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 314 Ok(user) => user, 315 Err(e) => return ApiError::from(e).into_response(), 316 }; 317 let user = match sqlx::query!( 318 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 319 auth_user.did 320 ) 321 .fetch_optional(&state.db) 322 .await 323 { 324 Ok(Some(row)) => row, 325 _ => return ApiError::InternalError.into_response(), 326 }; 327 let key_bytes = match auth_user.key_bytes { 328 Some(kb) => kb, 329 None => { 330 return ApiError::AuthenticationFailedMsg( 331 "OAuth tokens cannot get DID credentials".into(), 332 ) 333 .into_response(); 334 } 335 }; 336 let hostname = 
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 337 let pds_endpoint = format!("https://{}", hostname); 338 let full_handle = if user.handle.contains('.') { 339 user.handle.clone() 340 } else { 341 format!("{}.{}", user.handle, hostname) 342 }; 343 let signing_key = match k256::ecdsa::SigningKey::from_slice(&key_bytes) { 344 Ok(k) => k, 345 Err(_) => return ApiError::InternalError.into_response(), 346 }; 347 let did_key = signing_key_to_did_key(&signing_key); 348 ( 349 StatusCode::OK, 350 Json(GetRecommendedDidCredentialsOutput { 351 rotation_keys: vec![did_key.clone()], 352 also_known_as: vec![format!("at://{}", full_handle)], 353 verification_methods: VerificationMethods { atproto: did_key }, 354 services: Services { 355 atproto_pds: AtprotoPds { 356 service_type: "AtprotoPersonalDataServer".to_string(), 357 endpoint: pds_endpoint, 358 }, 359 }, 360 }), 361 ) 362 .into_response() 363} 364 365#[derive(Deserialize)] 366pub struct UpdateHandleInput { 367 pub handle: String, 368} 369 370pub async fn update_handle( 371 State(state): State<AppState>, 372 headers: axum::http::HeaderMap, 373 Json(input): Json<UpdateHandleInput>, 374) -> Response { 375 let token = match crate::auth::extract_bearer_token_from_header( 376 headers.get("Authorization").and_then(|h| h.to_str().ok()), 377 ) { 378 Some(t) => t, 379 None => return ApiError::AuthenticationRequired.into_response(), 380 }; 381 let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await { 382 Ok(user) => user.did, 383 Err(e) => return ApiError::from(e).into_response(), 384 }; 385 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 386 .fetch_optional(&state.db) 387 .await 388 { 389 Ok(Some(id)) => id, 390 _ => return ApiError::InternalError.into_response(), 391 }; 392 let new_handle = input.handle.trim(); 393 if new_handle.is_empty() { 394 return ApiError::InvalidRequest("handle is required".into()).into_response(); 395 } 396 
if !new_handle 397 .chars() 398 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 399 { 400 return ( 401 StatusCode::BAD_REQUEST, 402 Json( 403 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}), 404 ), 405 ) 406 .into_response(); 407 } 408 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 409 let is_service_domain = crate::handle::is_service_domain_handle(new_handle, &hostname); 410 let (handle_to_store, full_handle) = if is_service_domain { 411 let suffix = format!(".{}", hostname); 412 let short_handle = if new_handle.ends_with(&suffix) { 413 new_handle.strip_suffix(&suffix).unwrap_or(new_handle) 414 } else { 415 new_handle 416 }; 417 (short_handle.to_string(), format!("{}.{}", short_handle, hostname)) 418 } else { 419 match crate::handle::verify_handle_ownership(new_handle, &did).await { 420 Ok(()) => {} 421 Err(crate::handle::HandleResolutionError::NotFound) => { 422 return ( 423 StatusCode::BAD_REQUEST, 424 Json(json!({ 425 "error": "HandleNotAvailable", 426 "message": "Handle verification failed. Please set up DNS TXT record at _atproto.{} or serve your DID at https://{}/.well-known/atproto-did", 427 "handle": new_handle 428 })), 429 ) 430 .into_response(); 431 } 432 Err(crate::handle::HandleResolutionError::DidMismatch { expected, actual }) => { 433 return ( 434 StatusCode::BAD_REQUEST, 435 Json(json!({ 436 "error": "HandleNotAvailable", 437 "message": format!("Handle points to different DID. 
Expected {}, got {}", expected, actual) 438 })), 439 ) 440 .into_response(); 441 } 442 Err(e) => { 443 warn!("Handle verification failed: {}", e); 444 return ( 445 StatusCode::BAD_REQUEST, 446 Json(json!({ 447 "error": "HandleNotAvailable", 448 "message": format!("Handle verification failed: {}", e) 449 })), 450 ) 451 .into_response(); 452 } 453 } 454 (new_handle.to_string(), new_handle.to_string()) 455 }; 456 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 457 .fetch_optional(&state.db) 458 .await 459 .ok() 460 .flatten(); 461 let existing = sqlx::query!( 462 "SELECT id FROM users WHERE handle = $1 AND id != $2", 463 handle_to_store, 464 user_id 465 ) 466 .fetch_optional(&state.db) 467 .await; 468 if let Ok(Some(_)) = existing { 469 return ( 470 StatusCode::BAD_REQUEST, 471 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 472 ) 473 .into_response(); 474 } 475 let result = sqlx::query!( 476 "UPDATE users SET handle = $1 WHERE id = $2", 477 handle_to_store, 478 user_id 479 ) 480 .execute(&state.db) 481 .await; 482 match result { 483 Ok(_) => { 484 if let Some(old) = old_handle { 485 let _ = state.cache.delete(&format!("handle:{}", old)).await; 486 } 487 let _ = state 488 .cache 489 .delete(&format!("handle:{}", handle_to_store)) 490 .await; 491 let _ = state.cache.delete(&format!("handle:{}", full_handle)).await; 492 if let Err(e) = 493 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)) 494 .await 495 { 496 warn!("Failed to sequence identity event for handle update: {}", e); 497 } 498 if let Err(e) = update_plc_handle(&state, &did, &full_handle).await { 499 warn!("Failed to update PLC handle: {}", e); 500 } 501 (StatusCode::OK, Json(json!({}))).into_response() 502 } 503 Err(e) => { 504 error!("DB error updating handle: {:?}", e); 505 ( 506 StatusCode::INTERNAL_SERVER_ERROR, 507 Json(json!({"error": "InternalError"})), 508 ) 509 .into_response() 510 } 511 } 512} 513 
514async fn update_plc_handle( 515 state: &AppState, 516 did: &str, 517 new_handle: &str, 518) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 519 if !did.starts_with("did:plc:") { 520 return Ok(()); 521 } 522 let user_row = sqlx::query!( 523 r#"SELECT u.id, uk.key_bytes, uk.encryption_version 524 FROM users u 525 JOIN user_keys uk ON u.id = uk.user_id 526 WHERE u.did = $1"#, 527 did 528 ) 529 .fetch_optional(&state.db) 530 .await?; 531 let user_row = match user_row { 532 Some(r) => r, 533 None => return Ok(()), 534 }; 535 let key_bytes = crate::config::decrypt_key(&user_row.key_bytes, user_row.encryption_version)?; 536 let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes)?; 537 let plc_client = crate::plc::PlcClient::new(None); 538 let last_op = plc_client.get_last_op(did).await?; 539 let new_also_known_as = vec![format!("at://{}", new_handle)]; 540 let update_op = crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?; 541 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?; 542 plc_client.send_operation(did, &signed_op).await?; 543 Ok(()) 544} 545 546pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 547 let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 548 Some(h) => h, 549 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 550 }; 551 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 552 let suffix = format!(".{}", hostname); 553 let handle = host.split(':').next().unwrap_or(host); 554 let short_handle = if handle.ends_with(&suffix) { 555 handle.strip_suffix(&suffix).unwrap_or(handle) 556 } else { 557 return (StatusCode::NOT_FOUND, "Handle not found").into_response(); 558 }; 559 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 560 .fetch_optional(&state.db) 561 .await; 562 match user { 563 Ok(Some(row)) => 
row.did.into_response(), 564 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(), 565 Err(e) => { 566 error!("DB error in well-known atproto-did: {:?}", e); 567 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 568 } 569 } 570}