this repo has no description
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::{Path, Query, State}, 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8}; 9use base64::Engine; 10use k256::SecretKey; 11use k256::elliptic_curve::sec1::ToEncodedPoint; 12use reqwest; 13use serde::Deserialize; 14use serde_json::json; 15use tracing::error; 16 17#[derive(Deserialize)] 18pub struct ResolveHandleParams { 19 pub handle: String, 20} 21 22pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, 25) -> Response { 26 let handle = params.handle.trim(); 27 28 if handle.is_empty() { 29 return ( 30 StatusCode::BAD_REQUEST, 31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 32 ) 33 .into_response(); 34 } 35 36 let cache_key = format!("handle:{}", handle); 37 if let Some(did) = state.cache.get(&cache_key).await { 38 return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 39 } 40 41 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 42 let suffix = format!(".{}", hostname); 43 let short_handle = if handle.ends_with(&suffix) { 44 handle.strip_suffix(&suffix).unwrap_or(handle) 45 } else { 46 handle 47 }; 48 49 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 50 .fetch_optional(&state.db) 51 .await; 52 53 match user { 54 Ok(Some(row)) => { 55 let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await; 56 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 57 } 58 Ok(None) => ( 59 StatusCode::NOT_FOUND, 60 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 61 ) 62 .into_response(), 63 Err(e) => { 64 error!("DB error resolving handle: {:?}", e); 65 ( 66 StatusCode::INTERNAL_SERVER_ERROR, 67 Json(json!({"error": "InternalError"})), 68 ) 69 .into_response() 70 } 71 } 72} 73 74pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 75 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 76 let public_key = secret_key.public_key(); 77 let encoded = public_key.to_encoded_point(false); 78 let x = encoded.x().ok_or("Missing x coordinate")?; 79 let y = encoded.y().ok_or("Missing y coordinate")?; 80 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 81 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 82 83 Ok(json!({ 84 "kty": "EC", 85 "crv": "secp256k1", 86 "x": x_b64, 87 "y": y_b64 88 })) 89} 90 91pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 92 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 93 // Kinda for local dev, encode hostname if it contains port 94 let did = if hostname.contains(':') { 95 format!("did:web:{}", hostname.replace(':', "%3A")) 96 } else { 97 format!("did:web:{}", hostname) 98 }; 99 100 Json(json!({ 101 "@context": ["https://www.w3.org/ns/did/v1"], 102 "id": did, 103 "service": [{ 104 "id": "#atproto_pds", 105 "type": "AtprotoPersonalDataServer", 106 "serviceEndpoint": format!("https://{}", hostname) 107 }] 108 })) 109} 110 111pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 112 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 113 114 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 115 .fetch_optional(&state.db) 116 .await; 117 118 let (user_id, did) = match user { 119 Ok(Some(row)) => (row.id, row.did), 120 Ok(None) => { 121 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 122 } 123 Err(e) => { 124 error!("DB Error: {:?}", e); 125 return ( 126 StatusCode::INTERNAL_SERVER_ERROR, 127 Json(json!({"error": "InternalError"})), 128 ) 129 .into_response(); 130 } 131 }; 132 133 if !did.starts_with("did:web:") { 134 return ( 135 StatusCode::NOT_FOUND, 136 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 137 ) 138 .into_response(); 139 } 140 141 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id) 142 .fetch_optional(&state.db) 143 .await; 144 145 let key_bytes: Vec<u8> = match key_row { 146 Ok(Some(row)) => { 147 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 148 Ok(k) => k, 149 Err(_) => { 150 return ( 151 StatusCode::INTERNAL_SERVER_ERROR, 152 Json(json!({"error": "InternalError"})), 153 ) 154 .into_response(); 155 } 156 } 157 } 158 _ => { 159 return ( 160 StatusCode::INTERNAL_SERVER_ERROR, 161 Json(json!({"error": "InternalError"})), 162 ) 163 .into_response(); 164 } 165 }; 166 167 let jwk = match get_jwk(&key_bytes) { 168 Ok(j) => j, 169 Err(e) => { 170 tracing::error!("Failed to generate JWK: {}", e); 171 return ( 172 StatusCode::INTERNAL_SERVER_ERROR, 173 Json(json!({"error": "InternalError"})), 174 ) 175 .into_response(); 176 } 177 }; 178 179 Json(json!({ 180 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 181 "id": did, 182 "alsoKnownAs": [format!("at://{}", handle)], 183 "verificationMethod": [{ 184 "id": format!("{}#atproto", did), 185 "type": "JsonWebKey2020", 186 "controller": did, 187 "publicKeyJwk": jwk 188 }], 189 "service": [{ 190 "id": "#atproto_pds", 191 "type": "AtprotoPersonalDataServer", 192 "serviceEndpoint": format!("https://{}", hostname) 193 }] 194 })).into_response() 195} 196 197pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 198 let expected_prefix = if hostname.contains(':') { 199 format!("did:web:{}", hostname.replace(':', "%3A")) 200 } else { 201 format!("did:web:{}", hostname) 202 }; 203 204 if did.starts_with(&expected_prefix) { 205 let suffix = &did[expected_prefix.len()..]; 206 let expected_suffix = format!(":u:{}", handle); 207 if suffix == expected_suffix { 208 Ok(()) 209 } else { 210 Err(format!( 211 "Invalid DID path for this PDS. Expected {}", 212 expected_suffix 213 )) 214 } 215 } else { 216 let parts: Vec<&str> = did.split(':').collect(); 217 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 218 return Err("Invalid did:web format".into()); 219 } 220 221 let domain_segment = parts[2]; 222 let domain = domain_segment.replace("%3A", ":"); 223 224 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 225 "http" 226 } else { 227 "https" 228 }; 229 230 let url = if parts.len() == 3 { 231 format!("{}://{}/.well-known/did.json", scheme, domain) 232 } else { 233 let path = parts[3..].join("/"); 234 format!("{}://{}/{}/did.json", scheme, domain, path) 235 }; 236 237 let client = reqwest::Client::builder() 238 .timeout(std::time::Duration::from_secs(5)) 239 .build() 240 .map_err(|e| format!("Failed to create client: {}", e))?; 241 242 let resp = client 243 .get(&url) 244 .send() 245 .await 246 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 247 248 if !resp.status().is_success() { 249 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 250 } 251 252 let doc: serde_json::Value = resp 253 .json() 254 .await 255 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 256 257 let services = doc["service"] 258 .as_array() 259 .ok_or("No services found in DID doc")?; 260 261 let pds_endpoint = format!("https://{}", hostname); 262 263 let has_valid_service = services.iter().any(|s| { 264 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 265 }); 266 267 if has_valid_service { 268 Ok(()) 269 } else { 270 Err(format!( 271 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 272 pds_endpoint 273 )) 274 } 275 } 276} 277 278#[derive(serde::Serialize)] 279#[serde(rename_all = "camelCase")] 280pub struct GetRecommendedDidCredentialsOutput { 281 pub rotation_keys: Vec<String>, 282 pub also_known_as: Vec<String>, 283 pub verification_methods: VerificationMethods, 284 pub services: Services, 285} 286 287#[derive(serde::Serialize)] 288#[serde(rename_all = "camelCase")] 289pub struct VerificationMethods { 290 pub atproto: String, 291} 292 293#[derive(serde::Serialize)] 294#[serde(rename_all = "camelCase")] 295pub struct Services { 296 pub atproto_pds: AtprotoPds, 297} 298 299#[derive(serde::Serialize)] 300#[serde(rename_all = "camelCase")] 301pub struct AtprotoPds { 302 #[serde(rename = "type")] 303 pub service_type: String, 304 pub endpoint: String, 305} 306 307pub async fn get_recommended_did_credentials( 308 State(state): State<AppState>, 309 headers: axum::http::HeaderMap, 310) -> Response { 311 let token = match crate::auth::extract_bearer_token_from_header( 312 headers.get("Authorization").and_then(|h| h.to_str().ok()) 313 ) { 314 Some(t) => t, 315 None => { 316 return ( 317 StatusCode::UNAUTHORIZED, 318 Json(json!({"error": "AuthenticationRequired"})), 319 ) 320 .into_response(); 321 } 322 }; 323 324 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 325 Ok(user) => user, 326 Err(e) => return ApiError::from(e).into_response(), 327 }; 328 329 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did) 330 .fetch_optional(&state.db) 331 .await 332 { 333 Ok(Some(row)) => row, 334 _ => return ApiError::InternalError.into_response(), 335 }; 336 337 let key_bytes = match auth_user.key_bytes { 338 Some(kb) => kb, 339 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(), 340 }; 341 342 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 343 let pds_endpoint = format!("https://{}", hostname); 344 345 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 346 Ok(k) => k, 347 Err(_) => return ApiError::InternalError.into_response(), 348 }; 349 350 let public_key = secret_key.public_key(); 351 let encoded = public_key.to_encoded_point(true); 352 let did_key = format!( 353 "did:key:zQ3sh{}", 354 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 355 .chars() 356 .skip(1) 357 .collect::<String>() 358 ); 359 360 ( 361 StatusCode::OK, 362 Json(GetRecommendedDidCredentialsOutput { 363 rotation_keys: vec![did_key.clone()], 364 also_known_as: vec![format!("at://{}", user.handle)], 365 verification_methods: VerificationMethods { atproto: did_key }, 366 services: Services { 367 atproto_pds: AtprotoPds { 368 service_type: "AtprotoPersonalDataServer".to_string(), 369 endpoint: pds_endpoint, 370 }, 371 }, 372 }), 373 ) 374 .into_response() 375} 376 377#[derive(Deserialize)] 378pub struct UpdateHandleInput { 379 pub handle: String, 380} 381 382pub async fn update_handle( 383 State(state): State<AppState>, 384 headers: axum::http::HeaderMap, 385 Json(input): Json<UpdateHandleInput>, 386) -> Response { 387 let token = match crate::auth::extract_bearer_token_from_header( 388 headers.get("Authorization").and_then(|h| h.to_str().ok()) 389 ) { 390 Some(t) => t, 391 None => return ApiError::AuthenticationRequired.into_response(), 392 }; 393 394 let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 395 Ok(user) => user.did, 396 Err(e) => return ApiError::from(e).into_response(), 397 }; 398 399 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 400 .fetch_optional(&state.db) 401 .await 402 { 403 Ok(Some(id)) => id, 404 _ => return ApiError::InternalError.into_response(), 405 }; 406 407 let new_handle = input.handle.trim(); 408 if new_handle.is_empty() { 409 return ApiError::InvalidRequest("handle is required".into()).into_response(); 410 } 411 412 if !new_handle 413 .chars() 414 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 415 { 416 return ( 417 StatusCode::BAD_REQUEST, 418 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 419 ) 420 .into_response(); 421 } 422 423 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 424 .fetch_optional(&state.db) 425 .await 426 .ok() 427 .flatten(); 428 429 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 430 .fetch_optional(&state.db) 431 .await; 432 433 if let Ok(Some(_)) = existing { 434 return ( 435 StatusCode::BAD_REQUEST, 436 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 437 ) 438 .into_response(); 439 } 440 441 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id) 442 .execute(&state.db) 443 .await; 444 445 match result { 446 Ok(_) => { 447 if let Some(old) = old_handle { 448 let _ = state.cache.delete(&format!("handle:{}", old)).await; 449 } 450 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await; 451 (StatusCode::OK, Json(json!({}))).into_response() 452 } 453 Err(e) => { 454 error!("DB error updating handle: {:?}", e); 455 ( 456 StatusCode::INTERNAL_SERVER_ERROR, 457 Json(json!({"error": "InternalError"})), 458 ) 459 .into_response() 460 } 461 } 462} 463 464pub async fn well_known_atproto_did( 465 State(state): State<AppState>, 466 headers: HeaderMap, 467) -> Response { 468 let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 469 Some(h) => h, 470 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 471 }; 472 473 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 474 let suffix = format!(".{}", hostname); 475 476 let handle = host.split(':').next().unwrap_or(host); 477 478 let short_handle = if handle.ends_with(&suffix) { 479 handle.strip_suffix(&suffix).unwrap_or(handle) 480 } else { 481 return (StatusCode::NOT_FOUND, "Handle not found").into_response(); 482 }; 483 484 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 485 .fetch_optional(&state.db) 486 .await; 487 488 match user { 489 Ok(Some(row)) => row.did.into_response(), 490 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(), 491 Err(e) => { 492 error!("DB error in well-known atproto-did: {:?}", e); 493 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 494 } 495 } 496}