// this repo has no description
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::{Path, Query, State}, 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8}; 9use base64::Engine; 10use k256::SecretKey; 11use k256::elliptic_curve::sec1::ToEncodedPoint; 12use reqwest; 13use serde::Deserialize; 14use serde_json::json; 15use tracing::{error, warn}; 16 17#[derive(Deserialize)] 18pub struct ResolveHandleParams { 19 pub handle: String, 20} 21 22pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, 25) -> Response { 26 let handle = params.handle.trim(); 27 if handle.is_empty() { 28 return ( 29 StatusCode::BAD_REQUEST, 30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 31 ) 32 .into_response(); 33 } 34 let cache_key = format!("handle:{}", handle); 35 if let Some(did) = state.cache.get(&cache_key).await { 36 return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 37 } 38 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 39 let suffix = format!(".{}", hostname); 40 let short_handle = if handle.ends_with(&suffix) { 41 handle.strip_suffix(&suffix).unwrap_or(handle) 42 } else { 43 handle 44 }; 45 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 46 .fetch_optional(&state.db) 47 .await; 48 match user { 49 Ok(Some(row)) => { 50 let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await; 51 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 52 } 53 Ok(None) => ( 54 StatusCode::NOT_FOUND, 55 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 56 ) 57 .into_response(), 58 Err(e) => { 59 error!("DB error resolving handle: {:?}", e); 60 ( 61 StatusCode::INTERNAL_SERVER_ERROR, 62 Json(json!({"error": "InternalError"})), 63 ) 64 .into_response() 65 } 66 } 67} 68 69pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, 
&'static str> { 70 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 71 let public_key = secret_key.public_key(); 72 let encoded = public_key.to_encoded_point(false); 73 let x = encoded.x().ok_or("Missing x coordinate")?; 74 let y = encoded.y().ok_or("Missing y coordinate")?; 75 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 76 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 77 Ok(json!({ 78 "kty": "EC", 79 "crv": "secp256k1", 80 "x": x_b64, 81 "y": y_b64 82 })) 83} 84 85pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 86 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 87 // Kinda for local dev, encode hostname if it contains port 88 let did = if hostname.contains(':') { 89 format!("did:web:{}", hostname.replace(':', "%3A")) 90 } else { 91 format!("did:web:{}", hostname) 92 }; 93 Json(json!({ 94 "@context": ["https://www.w3.org/ns/did/v1"], 95 "id": did, 96 "service": [{ 97 "id": "#atproto_pds", 98 "type": "AtprotoPersonalDataServer", 99 "serviceEndpoint": format!("https://{}", hostname) 100 }] 101 })) 102} 103 104pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 105 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 106 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 107 .fetch_optional(&state.db) 108 .await; 109 let (user_id, did) = match user { 110 Ok(Some(row)) => (row.id, row.did), 111 Ok(None) => { 112 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 113 } 114 Err(e) => { 115 error!("DB Error: {:?}", e); 116 return ( 117 StatusCode::INTERNAL_SERVER_ERROR, 118 Json(json!({"error": "InternalError"})), 119 ) 120 .into_response(); 121 } 122 }; 123 if !did.starts_with("did:web:") { 124 return ( 125 StatusCode::NOT_FOUND, 126 Json(json!({"error": 
"NotFound", "message": "User is not did:web"})), 127 ) 128 .into_response(); 129 } 130 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id) 131 .fetch_optional(&state.db) 132 .await; 133 let key_bytes: Vec<u8> = match key_row { 134 Ok(Some(row)) => { 135 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 136 Ok(k) => k, 137 Err(_) => { 138 return ( 139 StatusCode::INTERNAL_SERVER_ERROR, 140 Json(json!({"error": "InternalError"})), 141 ) 142 .into_response(); 143 } 144 } 145 } 146 _ => { 147 return ( 148 StatusCode::INTERNAL_SERVER_ERROR, 149 Json(json!({"error": "InternalError"})), 150 ) 151 .into_response(); 152 } 153 }; 154 let jwk = match get_jwk(&key_bytes) { 155 Ok(j) => j, 156 Err(e) => { 157 tracing::error!("Failed to generate JWK: {}", e); 158 return ( 159 StatusCode::INTERNAL_SERVER_ERROR, 160 Json(json!({"error": "InternalError"})), 161 ) 162 .into_response(); 163 } 164 }; 165 Json(json!({ 166 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 167 "id": did, 168 "alsoKnownAs": [format!("at://{}", handle)], 169 "verificationMethod": [{ 170 "id": format!("{}#atproto", did), 171 "type": "JsonWebKey2020", 172 "controller": did, 173 "publicKeyJwk": jwk 174 }], 175 "service": [{ 176 "id": "#atproto_pds", 177 "type": "AtprotoPersonalDataServer", 178 "serviceEndpoint": format!("https://{}", hostname) 179 }] 180 })).into_response() 181} 182 183pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 184 let expected_prefix = if hostname.contains(':') { 185 format!("did:web:{}", hostname.replace(':', "%3A")) 186 } else { 187 format!("did:web:{}", hostname) 188 }; 189 if did.starts_with(&expected_prefix) { 190 let suffix = &did[expected_prefix.len()..]; 191 let expected_suffix = format!(":u:{}", handle); 192 if suffix == expected_suffix { 193 Ok(()) 194 } else { 195 Err(format!( 196 "Invalid DID path for this 
PDS. Expected {}", 197 expected_suffix 198 )) 199 } 200 } else { 201 let parts: Vec<&str> = did.split(':').collect(); 202 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 203 return Err("Invalid did:web format".into()); 204 } 205 let domain_segment = parts[2]; 206 let domain = domain_segment.replace("%3A", ":"); 207 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 208 "http" 209 } else { 210 "https" 211 }; 212 let url = if parts.len() == 3 { 213 format!("{}://{}/.well-known/did.json", scheme, domain) 214 } else { 215 let path = parts[3..].join("/"); 216 format!("{}://{}/{}/did.json", scheme, domain, path) 217 }; 218 let client = reqwest::Client::builder() 219 .timeout(std::time::Duration::from_secs(5)) 220 .build() 221 .map_err(|e| format!("Failed to create client: {}", e))?; 222 let resp = client 223 .get(&url) 224 .send() 225 .await 226 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 227 if !resp.status().is_success() { 228 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 229 } 230 let doc: serde_json::Value = resp 231 .json() 232 .await 233 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 234 let services = doc["service"] 235 .as_array() 236 .ok_or("No services found in DID doc")?; 237 let pds_endpoint = format!("https://{}", hostname); 238 let has_valid_service = services.iter().any(|s| { 239 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 240 }); 241 if has_valid_service { 242 Ok(()) 243 } else { 244 Err(format!( 245 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 246 pds_endpoint 247 )) 248 } 249 } 250} 251 252#[derive(serde::Serialize)] 253#[serde(rename_all = "camelCase")] 254pub struct GetRecommendedDidCredentialsOutput { 255 pub rotation_keys: Vec<String>, 256 pub also_known_as: Vec<String>, 257 pub verification_methods: VerificationMethods, 258 pub services: Services, 259} 260 261#[derive(serde::Serialize)] 
262#[serde(rename_all = "camelCase")] 263pub struct VerificationMethods { 264 pub atproto: String, 265} 266 267#[derive(serde::Serialize)] 268#[serde(rename_all = "camelCase")] 269pub struct Services { 270 pub atproto_pds: AtprotoPds, 271} 272 273#[derive(serde::Serialize)] 274#[serde(rename_all = "camelCase")] 275pub struct AtprotoPds { 276 #[serde(rename = "type")] 277 pub service_type: String, 278 pub endpoint: String, 279} 280 281pub async fn get_recommended_did_credentials( 282 State(state): State<AppState>, 283 headers: axum::http::HeaderMap, 284) -> Response { 285 let token = match crate::auth::extract_bearer_token_from_header( 286 headers.get("Authorization").and_then(|h| h.to_str().ok()) 287 ) { 288 Some(t) => t, 289 None => { 290 return ( 291 StatusCode::UNAUTHORIZED, 292 Json(json!({"error": "AuthenticationRequired"})), 293 ) 294 .into_response(); 295 } 296 }; 297 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 298 Ok(user) => user, 299 Err(e) => return ApiError::from(e).into_response(), 300 }; 301 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did) 302 .fetch_optional(&state.db) 303 .await 304 { 305 Ok(Some(row)) => row, 306 _ => return ApiError::InternalError.into_response(), 307 }; 308 let key_bytes = match auth_user.key_bytes { 309 Some(kb) => kb, 310 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(), 311 }; 312 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 313 let pds_endpoint = format!("https://{}", hostname); 314 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 315 Ok(k) => k, 316 Err(_) => return ApiError::InternalError.into_response(), 317 }; 318 let public_key = secret_key.public_key(); 319 let encoded = public_key.to_encoded_point(true); 320 let did_key = format!( 321 "did:key:zQ3sh{}", 322 
multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 323 .chars() 324 .skip(1) 325 .collect::<String>() 326 ); 327 ( 328 StatusCode::OK, 329 Json(GetRecommendedDidCredentialsOutput { 330 rotation_keys: vec![did_key.clone()], 331 also_known_as: vec![format!("at://{}", user.handle)], 332 verification_methods: VerificationMethods { atproto: did_key }, 333 services: Services { 334 atproto_pds: AtprotoPds { 335 service_type: "AtprotoPersonalDataServer".to_string(), 336 endpoint: pds_endpoint, 337 }, 338 }, 339 }), 340 ) 341 .into_response() 342} 343 344#[derive(Deserialize)] 345pub struct UpdateHandleInput { 346 pub handle: String, 347} 348 349pub async fn update_handle( 350 State(state): State<AppState>, 351 headers: axum::http::HeaderMap, 352 Json(input): Json<UpdateHandleInput>, 353) -> Response { 354 let token = match crate::auth::extract_bearer_token_from_header( 355 headers.get("Authorization").and_then(|h| h.to_str().ok()) 356 ) { 357 Some(t) => t, 358 None => return ApiError::AuthenticationRequired.into_response(), 359 }; 360 let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 361 Ok(user) => user.did, 362 Err(e) => return ApiError::from(e).into_response(), 363 }; 364 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 365 .fetch_optional(&state.db) 366 .await 367 { 368 Ok(Some(id)) => id, 369 _ => return ApiError::InternalError.into_response(), 370 }; 371 let new_handle = input.handle.trim(); 372 if new_handle.is_empty() { 373 return ApiError::InvalidRequest("handle is required".into()).into_response(); 374 } 375 if !new_handle 376 .chars() 377 .all(|c| c.is_ascii_alphanumeric() || c == '.' 
|| c == '-' || c == '_') 378 { 379 return ( 380 StatusCode::BAD_REQUEST, 381 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 382 ) 383 .into_response(); 384 } 385 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 386 .fetch_optional(&state.db) 387 .await 388 .ok() 389 .flatten(); 390 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 391 .fetch_optional(&state.db) 392 .await; 393 if let Ok(Some(_)) = existing { 394 return ( 395 StatusCode::BAD_REQUEST, 396 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 397 ) 398 .into_response(); 399 } 400 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id) 401 .execute(&state.db) 402 .await; 403 match result { 404 Ok(_) => { 405 if let Some(old) = old_handle { 406 let _ = state.cache.delete(&format!("handle:{}", old)).await; 407 } 408 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await; 409 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 410 let full_handle = format!("{}.{}", new_handle, hostname); 411 if let Err(e) = crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await { 412 warn!("Failed to sequence identity event for handle update: {}", e); 413 } 414 (StatusCode::OK, Json(json!({}))).into_response() 415 } 416 Err(e) => { 417 error!("DB error updating handle: {:?}", e); 418 ( 419 StatusCode::INTERNAL_SERVER_ERROR, 420 Json(json!({"error": "InternalError"})), 421 ) 422 .into_response() 423 } 424 } 425} 426 427pub async fn well_known_atproto_did( 428 State(state): State<AppState>, 429 headers: HeaderMap, 430) -> Response { 431 let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 432 Some(h) => h, 433 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 434 }; 435 let hostname = 
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 436 let suffix = format!(".{}", hostname); 437 let handle = host.split(':').next().unwrap_or(host); 438 let short_handle = if handle.ends_with(&suffix) { 439 handle.strip_suffix(&suffix).unwrap_or(handle) 440 } else { 441 return (StatusCode::NOT_FOUND, "Handle not found").into_response(); 442 }; 443 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 444 .fetch_optional(&state.db) 445 .await; 446 match user { 447 Ok(Some(row)) => row.did.into_response(), 448 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(), 449 Err(e) => { 450 error!("DB error in well-known atproto-did: {:?}", e); 451 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 452 } 453 } 454}