This repository has no description.
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::{Path, Query, State}, 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8}; 9use base64::Engine; 10use k256::SecretKey; 11use k256::elliptic_curve::sec1::ToEncodedPoint; 12use reqwest; 13use serde::Deserialize; 14use serde_json::json; 15use tracing::{error, warn}; 16 17#[derive(Deserialize)] 18pub struct ResolveHandleParams { 19 pub handle: String, 20} 21 22pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, 25) -> Response { 26 let handle = params.handle.trim(); 27 if handle.is_empty() { 28 return ( 29 StatusCode::BAD_REQUEST, 30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 31 ) 32 .into_response(); 33 } 34 let cache_key = format!("handle:{}", handle); 35 if let Some(did) = state.cache.get(&cache_key).await { 36 return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 37 } 38 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 39 let suffix = format!(".{}", hostname); 40 let short_handle = if handle.ends_with(&suffix) { 41 handle.strip_suffix(&suffix).unwrap_or(handle) 42 } else { 43 handle 44 }; 45 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 46 .fetch_optional(&state.db) 47 .await; 48 match user { 49 Ok(Some(row)) => { 50 let _ = state 51 .cache 52 .set(&cache_key, &row.did, std::time::Duration::from_secs(300)) 53 .await; 54 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 55 } 56 Ok(None) => ( 57 StatusCode::NOT_FOUND, 58 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 59 ) 60 .into_response(), 61 Err(e) => { 62 error!("DB error resolving handle: {:?}", e); 63 ( 64 StatusCode::INTERNAL_SERVER_ERROR, 65 Json(json!({"error": "InternalError"})), 66 ) 67 .into_response() 68 } 69 } 70} 71 72pub fn get_jwk(key_bytes: &[u8]) -> 
Result<serde_json::Value, &'static str> { 73 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 74 let public_key = secret_key.public_key(); 75 let encoded = public_key.to_encoded_point(false); 76 let x = encoded.x().ok_or("Missing x coordinate")?; 77 let y = encoded.y().ok_or("Missing y coordinate")?; 78 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 79 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 80 Ok(json!({ 81 "kty": "EC", 82 "crv": "secp256k1", 83 "x": x_b64, 84 "y": y_b64 85 })) 86} 87 88pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 89 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 90 // Kinda for local dev, encode hostname if it contains port 91 let did = if hostname.contains(':') { 92 format!("did:web:{}", hostname.replace(':', "%3A")) 93 } else { 94 format!("did:web:{}", hostname) 95 }; 96 Json(json!({ 97 "@context": ["https://www.w3.org/ns/did/v1"], 98 "id": did, 99 "service": [{ 100 "id": "#atproto_pds", 101 "type": "AtprotoPersonalDataServer", 102 "serviceEndpoint": format!("https://{}", hostname) 103 }] 104 })) 105} 106 107pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 108 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 109 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 110 .fetch_optional(&state.db) 111 .await; 112 let (user_id, did) = match user { 113 Ok(Some(row)) => (row.id, row.did), 114 Ok(None) => { 115 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 116 } 117 Err(e) => { 118 error!("DB Error: {:?}", e); 119 return ( 120 StatusCode::INTERNAL_SERVER_ERROR, 121 Json(json!({"error": "InternalError"})), 122 ) 123 .into_response(); 124 } 125 }; 126 if !did.starts_with("did:web:") { 127 return ( 128 StatusCode::NOT_FOUND, 129 
Json(json!({"error": "NotFound", "message": "User is not did:web"})), 130 ) 131 .into_response(); 132 } 133 let key_row = sqlx::query!( 134 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 135 user_id 136 ) 137 .fetch_optional(&state.db) 138 .await; 139 let key_bytes: Vec<u8> = match key_row { 140 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 141 Ok(k) => k, 142 Err(_) => { 143 return ( 144 StatusCode::INTERNAL_SERVER_ERROR, 145 Json(json!({"error": "InternalError"})), 146 ) 147 .into_response(); 148 } 149 }, 150 _ => { 151 return ( 152 StatusCode::INTERNAL_SERVER_ERROR, 153 Json(json!({"error": "InternalError"})), 154 ) 155 .into_response(); 156 } 157 }; 158 let jwk = match get_jwk(&key_bytes) { 159 Ok(j) => j, 160 Err(e) => { 161 tracing::error!("Failed to generate JWK: {}", e); 162 return ( 163 StatusCode::INTERNAL_SERVER_ERROR, 164 Json(json!({"error": "InternalError"})), 165 ) 166 .into_response(); 167 } 168 }; 169 Json(json!({ 170 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 171 "id": did, 172 "alsoKnownAs": [format!("at://{}", handle)], 173 "verificationMethod": [{ 174 "id": format!("{}#atproto", did), 175 "type": "JsonWebKey2020", 176 "controller": did, 177 "publicKeyJwk": jwk 178 }], 179 "service": [{ 180 "id": "#atproto_pds", 181 "type": "AtprotoPersonalDataServer", 182 "serviceEndpoint": format!("https://{}", hostname) 183 }] 184 })).into_response() 185} 186 187pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 188 let expected_prefix = if hostname.contains(':') { 189 format!("did:web:{}", hostname.replace(':', "%3A")) 190 } else { 191 format!("did:web:{}", hostname) 192 }; 193 if did.starts_with(&expected_prefix) { 194 let suffix = &did[expected_prefix.len()..]; 195 let expected_suffix = format!(":u:{}", handle); 196 if suffix == expected_suffix { 197 Ok(()) 198 } else { 199 Err(format!( 200 
"Invalid DID path for this PDS. Expected {}", 201 expected_suffix 202 )) 203 } 204 } else { 205 let parts: Vec<&str> = did.split(':').collect(); 206 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 207 return Err("Invalid did:web format".into()); 208 } 209 let domain_segment = parts[2]; 210 let domain = domain_segment.replace("%3A", ":"); 211 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 212 "http" 213 } else { 214 "https" 215 }; 216 let url = if parts.len() == 3 { 217 format!("{}://{}/.well-known/did.json", scheme, domain) 218 } else { 219 let path = parts[3..].join("/"); 220 format!("{}://{}/{}/did.json", scheme, domain, path) 221 }; 222 let client = reqwest::Client::builder() 223 .timeout(std::time::Duration::from_secs(5)) 224 .build() 225 .map_err(|e| format!("Failed to create client: {}", e))?; 226 let resp = client 227 .get(&url) 228 .send() 229 .await 230 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 231 if !resp.status().is_success() { 232 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 233 } 234 let doc: serde_json::Value = resp 235 .json() 236 .await 237 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 238 let services = doc["service"] 239 .as_array() 240 .ok_or("No services found in DID doc")?; 241 let pds_endpoint = format!("https://{}", hostname); 242 let has_valid_service = services.iter().any(|s| { 243 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 244 }); 245 if has_valid_service { 246 Ok(()) 247 } else { 248 Err(format!( 249 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 250 pds_endpoint 251 )) 252 } 253 } 254} 255 256#[derive(serde::Serialize)] 257#[serde(rename_all = "camelCase")] 258pub struct GetRecommendedDidCredentialsOutput { 259 pub rotation_keys: Vec<String>, 260 pub also_known_as: Vec<String>, 261 pub verification_methods: VerificationMethods, 262 pub services: Services, 263} 264 
265#[derive(serde::Serialize)] 266#[serde(rename_all = "camelCase")] 267pub struct VerificationMethods { 268 pub atproto: String, 269} 270 271#[derive(serde::Serialize)] 272#[serde(rename_all = "camelCase")] 273pub struct Services { 274 pub atproto_pds: AtprotoPds, 275} 276 277#[derive(serde::Serialize)] 278#[serde(rename_all = "camelCase")] 279pub struct AtprotoPds { 280 #[serde(rename = "type")] 281 pub service_type: String, 282 pub endpoint: String, 283} 284 285pub async fn get_recommended_did_credentials( 286 State(state): State<AppState>, 287 headers: axum::http::HeaderMap, 288) -> Response { 289 let token = match crate::auth::extract_bearer_token_from_header( 290 headers.get("Authorization").and_then(|h| h.to_str().ok()), 291 ) { 292 Some(t) => t, 293 None => { 294 return ( 295 StatusCode::UNAUTHORIZED, 296 Json(json!({"error": "AuthenticationRequired"})), 297 ) 298 .into_response(); 299 } 300 }; 301 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 302 Ok(user) => user, 303 Err(e) => return ApiError::from(e).into_response(), 304 }; 305 let user = match sqlx::query!( 306 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 307 auth_user.did 308 ) 309 .fetch_optional(&state.db) 310 .await 311 { 312 Ok(Some(row)) => row, 313 _ => return ApiError::InternalError.into_response(), 314 }; 315 let key_bytes = match auth_user.key_bytes { 316 Some(kb) => kb, 317 None => { 318 return ApiError::AuthenticationFailedMsg( 319 "OAuth tokens cannot get DID credentials".into(), 320 ) 321 .into_response(); 322 } 323 }; 324 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 325 let pds_endpoint = format!("https://{}", hostname); 326 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 327 Ok(k) => k, 328 Err(_) => return ApiError::InternalError.into_response(), 329 }; 330 let public_key = secret_key.public_key(); 331 let encoded = public_key.to_encoded_point(true); 
332 let did_key = format!( 333 "did:key:zQ3sh{}", 334 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 335 .chars() 336 .skip(1) 337 .collect::<String>() 338 ); 339 ( 340 StatusCode::OK, 341 Json(GetRecommendedDidCredentialsOutput { 342 rotation_keys: vec![did_key.clone()], 343 also_known_as: vec![format!("at://{}", user.handle)], 344 verification_methods: VerificationMethods { atproto: did_key }, 345 services: Services { 346 atproto_pds: AtprotoPds { 347 service_type: "AtprotoPersonalDataServer".to_string(), 348 endpoint: pds_endpoint, 349 }, 350 }, 351 }), 352 ) 353 .into_response() 354} 355 356#[derive(Deserialize)] 357pub struct UpdateHandleInput { 358 pub handle: String, 359} 360 361pub async fn update_handle( 362 State(state): State<AppState>, 363 headers: axum::http::HeaderMap, 364 Json(input): Json<UpdateHandleInput>, 365) -> Response { 366 let token = match crate::auth::extract_bearer_token_from_header( 367 headers.get("Authorization").and_then(|h| h.to_str().ok()), 368 ) { 369 Some(t) => t, 370 None => return ApiError::AuthenticationRequired.into_response(), 371 }; 372 let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 373 Ok(user) => user.did, 374 Err(e) => return ApiError::from(e).into_response(), 375 }; 376 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 377 .fetch_optional(&state.db) 378 .await 379 { 380 Ok(Some(id)) => id, 381 _ => return ApiError::InternalError.into_response(), 382 }; 383 let new_handle = input.handle.trim(); 384 if new_handle.is_empty() { 385 return ApiError::InvalidRequest("handle is required".into()).into_response(); 386 } 387 if !new_handle 388 .chars() 389 .all(|c| c.is_ascii_alphanumeric() || c == '.' 
|| c == '-' || c == '_') 390 { 391 return ( 392 StatusCode::BAD_REQUEST, 393 Json( 394 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}), 395 ), 396 ) 397 .into_response(); 398 } 399 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 400 .fetch_optional(&state.db) 401 .await 402 .ok() 403 .flatten(); 404 let existing = sqlx::query!( 405 "SELECT id FROM users WHERE handle = $1 AND id != $2", 406 new_handle, 407 user_id 408 ) 409 .fetch_optional(&state.db) 410 .await; 411 if let Ok(Some(_)) = existing { 412 return ( 413 StatusCode::BAD_REQUEST, 414 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 415 ) 416 .into_response(); 417 } 418 let result = sqlx::query!( 419 "UPDATE users SET handle = $1 WHERE id = $2", 420 new_handle, 421 user_id 422 ) 423 .execute(&state.db) 424 .await; 425 match result { 426 Ok(_) => { 427 if let Some(old) = old_handle { 428 let _ = state.cache.delete(&format!("handle:{}", old)).await; 429 } 430 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await; 431 let hostname = 432 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 433 let full_handle = format!("{}.{}", new_handle, hostname); 434 if let Err(e) = 435 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)) 436 .await 437 { 438 warn!("Failed to sequence identity event for handle update: {}", e); 439 } 440 (StatusCode::OK, Json(json!({}))).into_response() 441 } 442 Err(e) => { 443 error!("DB error updating handle: {:?}", e); 444 ( 445 StatusCode::INTERNAL_SERVER_ERROR, 446 Json(json!({"error": "InternalError"})), 447 ) 448 .into_response() 449 } 450 } 451} 452 453pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 454 let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 455 Some(h) => h, 456 None => return (StatusCode::BAD_REQUEST, "Missing host 
header").into_response(), 457 }; 458 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 459 let suffix = format!(".{}", hostname); 460 let handle = host.split(':').next().unwrap_or(host); 461 let short_handle = if handle.ends_with(&suffix) { 462 handle.strip_suffix(&suffix).unwrap_or(handle) 463 } else { 464 return (StatusCode::NOT_FOUND, "Handle not found").into_response(); 465 }; 466 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 467 .fetch_optional(&state.db) 468 .await; 469 match user { 470 Ok(Some(row)) => row.did.into_response(), 471 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(), 472 Err(e) => { 473 error!("DB error in well-known atproto-did: {:?}", e); 474 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 475 } 476 } 477}