This repository has no description.
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::{Path, Query, State}, 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8}; 9use base64::Engine; 10use k256::SecretKey; 11use k256::elliptic_curve::sec1::ToEncodedPoint; 12use reqwest; 13use serde::Deserialize; 14use serde_json::json; 15use tracing::{error, warn}; 16#[derive(Deserialize)] 17pub struct ResolveHandleParams { 18 pub handle: String, 19} 20pub async fn resolve_handle( 21 State(state): State<AppState>, 22 Query(params): Query<ResolveHandleParams>, 23) -> Response { 24 let handle = params.handle.trim(); 25 if handle.is_empty() { 26 return ( 27 StatusCode::BAD_REQUEST, 28 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 29 ) 30 .into_response(); 31 } 32 let cache_key = format!("handle:{}", handle); 33 if let Some(did) = state.cache.get(&cache_key).await { 34 return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 35 } 36 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 37 let suffix = format!(".{}", hostname); 38 let short_handle = if handle.ends_with(&suffix) { 39 handle.strip_suffix(&suffix).unwrap_or(handle) 40 } else { 41 handle 42 }; 43 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 44 .fetch_optional(&state.db) 45 .await; 46 match user { 47 Ok(Some(row)) => { 48 let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await; 49 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 50 } 51 Ok(None) => ( 52 StatusCode::NOT_FOUND, 53 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 54 ) 55 .into_response(), 56 Err(e) => { 57 error!("DB error resolving handle: {:?}", e); 58 ( 59 StatusCode::INTERNAL_SERVER_ERROR, 60 Json(json!({"error": "InternalError"})), 61 ) 62 .into_response() 63 } 64 } 65} 66pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static 
str> { 67 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 68 let public_key = secret_key.public_key(); 69 let encoded = public_key.to_encoded_point(false); 70 let x = encoded.x().ok_or("Missing x coordinate")?; 71 let y = encoded.y().ok_or("Missing y coordinate")?; 72 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 73 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 74 Ok(json!({ 75 "kty": "EC", 76 "crv": "secp256k1", 77 "x": x_b64, 78 "y": y_b64 79 })) 80} 81pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 82 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 83 // Kinda for local dev, encode hostname if it contains port 84 let did = if hostname.contains(':') { 85 format!("did:web:{}", hostname.replace(':', "%3A")) 86 } else { 87 format!("did:web:{}", hostname) 88 }; 89 Json(json!({ 90 "@context": ["https://www.w3.org/ns/did/v1"], 91 "id": did, 92 "service": [{ 93 "id": "#atproto_pds", 94 "type": "AtprotoPersonalDataServer", 95 "serviceEndpoint": format!("https://{}", hostname) 96 }] 97 })) 98} 99pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 100 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 101 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 102 .fetch_optional(&state.db) 103 .await; 104 let (user_id, did) = match user { 105 Ok(Some(row)) => (row.id, row.did), 106 Ok(None) => { 107 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 108 } 109 Err(e) => { 110 error!("DB Error: {:?}", e); 111 return ( 112 StatusCode::INTERNAL_SERVER_ERROR, 113 Json(json!({"error": "InternalError"})), 114 ) 115 .into_response(); 116 } 117 }; 118 if !did.starts_with("did:web:") { 119 return ( 120 StatusCode::NOT_FOUND, 121 Json(json!({"error": "NotFound", "message": "User 
is not did:web"})), 122 ) 123 .into_response(); 124 } 125 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id) 126 .fetch_optional(&state.db) 127 .await; 128 let key_bytes: Vec<u8> = match key_row { 129 Ok(Some(row)) => { 130 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 131 Ok(k) => k, 132 Err(_) => { 133 return ( 134 StatusCode::INTERNAL_SERVER_ERROR, 135 Json(json!({"error": "InternalError"})), 136 ) 137 .into_response(); 138 } 139 } 140 } 141 _ => { 142 return ( 143 StatusCode::INTERNAL_SERVER_ERROR, 144 Json(json!({"error": "InternalError"})), 145 ) 146 .into_response(); 147 } 148 }; 149 let jwk = match get_jwk(&key_bytes) { 150 Ok(j) => j, 151 Err(e) => { 152 tracing::error!("Failed to generate JWK: {}", e); 153 return ( 154 StatusCode::INTERNAL_SERVER_ERROR, 155 Json(json!({"error": "InternalError"})), 156 ) 157 .into_response(); 158 } 159 }; 160 Json(json!({ 161 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 162 "id": did, 163 "alsoKnownAs": [format!("at://{}", handle)], 164 "verificationMethod": [{ 165 "id": format!("{}#atproto", did), 166 "type": "JsonWebKey2020", 167 "controller": did, 168 "publicKeyJwk": jwk 169 }], 170 "service": [{ 171 "id": "#atproto_pds", 172 "type": "AtprotoPersonalDataServer", 173 "serviceEndpoint": format!("https://{}", hostname) 174 }] 175 })).into_response() 176} 177pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 178 let expected_prefix = if hostname.contains(':') { 179 format!("did:web:{}", hostname.replace(':', "%3A")) 180 } else { 181 format!("did:web:{}", hostname) 182 }; 183 if did.starts_with(&expected_prefix) { 184 let suffix = &did[expected_prefix.len()..]; 185 let expected_suffix = format!(":u:{}", handle); 186 if suffix == expected_suffix { 187 Ok(()) 188 } else { 189 Err(format!( 190 "Invalid DID path for this PDS. 
Expected {}", 191 expected_suffix 192 )) 193 } 194 } else { 195 let parts: Vec<&str> = did.split(':').collect(); 196 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 197 return Err("Invalid did:web format".into()); 198 } 199 let domain_segment = parts[2]; 200 let domain = domain_segment.replace("%3A", ":"); 201 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 202 "http" 203 } else { 204 "https" 205 }; 206 let url = if parts.len() == 3 { 207 format!("{}://{}/.well-known/did.json", scheme, domain) 208 } else { 209 let path = parts[3..].join("/"); 210 format!("{}://{}/{}/did.json", scheme, domain, path) 211 }; 212 let client = reqwest::Client::builder() 213 .timeout(std::time::Duration::from_secs(5)) 214 .build() 215 .map_err(|e| format!("Failed to create client: {}", e))?; 216 let resp = client 217 .get(&url) 218 .send() 219 .await 220 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 221 if !resp.status().is_success() { 222 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 223 } 224 let doc: serde_json::Value = resp 225 .json() 226 .await 227 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 228 let services = doc["service"] 229 .as_array() 230 .ok_or("No services found in DID doc")?; 231 let pds_endpoint = format!("https://{}", hostname); 232 let has_valid_service = services.iter().any(|s| { 233 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 234 }); 235 if has_valid_service { 236 Ok(()) 237 } else { 238 Err(format!( 239 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 240 pds_endpoint 241 )) 242 } 243 } 244} 245#[derive(serde::Serialize)] 246#[serde(rename_all = "camelCase")] 247pub struct GetRecommendedDidCredentialsOutput { 248 pub rotation_keys: Vec<String>, 249 pub also_known_as: Vec<String>, 250 pub verification_methods: VerificationMethods, 251 pub services: Services, 252} 253#[derive(serde::Serialize)] 
254#[serde(rename_all = "camelCase")] 255pub struct VerificationMethods { 256 pub atproto: String, 257} 258#[derive(serde::Serialize)] 259#[serde(rename_all = "camelCase")] 260pub struct Services { 261 pub atproto_pds: AtprotoPds, 262} 263#[derive(serde::Serialize)] 264#[serde(rename_all = "camelCase")] 265pub struct AtprotoPds { 266 #[serde(rename = "type")] 267 pub service_type: String, 268 pub endpoint: String, 269} 270pub async fn get_recommended_did_credentials( 271 State(state): State<AppState>, 272 headers: axum::http::HeaderMap, 273) -> Response { 274 let token = match crate::auth::extract_bearer_token_from_header( 275 headers.get("Authorization").and_then(|h| h.to_str().ok()) 276 ) { 277 Some(t) => t, 278 None => { 279 return ( 280 StatusCode::UNAUTHORIZED, 281 Json(json!({"error": "AuthenticationRequired"})), 282 ) 283 .into_response(); 284 } 285 }; 286 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 287 Ok(user) => user, 288 Err(e) => return ApiError::from(e).into_response(), 289 }; 290 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did) 291 .fetch_optional(&state.db) 292 .await 293 { 294 Ok(Some(row)) => row, 295 _ => return ApiError::InternalError.into_response(), 296 }; 297 let key_bytes = match auth_user.key_bytes { 298 Some(kb) => kb, 299 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(), 300 }; 301 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 302 let pds_endpoint = format!("https://{}", hostname); 303 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 304 Ok(k) => k, 305 Err(_) => return ApiError::InternalError.into_response(), 306 }; 307 let public_key = secret_key.public_key(); 308 let encoded = public_key.to_encoded_point(true); 309 let did_key = format!( 310 "did:key:zQ3sh{}", 311 
multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 312 .chars() 313 .skip(1) 314 .collect::<String>() 315 ); 316 ( 317 StatusCode::OK, 318 Json(GetRecommendedDidCredentialsOutput { 319 rotation_keys: vec![did_key.clone()], 320 also_known_as: vec![format!("at://{}", user.handle)], 321 verification_methods: VerificationMethods { atproto: did_key }, 322 services: Services { 323 atproto_pds: AtprotoPds { 324 service_type: "AtprotoPersonalDataServer".to_string(), 325 endpoint: pds_endpoint, 326 }, 327 }, 328 }), 329 ) 330 .into_response() 331} 332#[derive(Deserialize)] 333pub struct UpdateHandleInput { 334 pub handle: String, 335} 336pub async fn update_handle( 337 State(state): State<AppState>, 338 headers: axum::http::HeaderMap, 339 Json(input): Json<UpdateHandleInput>, 340) -> Response { 341 let token = match crate::auth::extract_bearer_token_from_header( 342 headers.get("Authorization").and_then(|h| h.to_str().ok()) 343 ) { 344 Some(t) => t, 345 None => return ApiError::AuthenticationRequired.into_response(), 346 }; 347 let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 348 Ok(user) => user.did, 349 Err(e) => return ApiError::from(e).into_response(), 350 }; 351 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 352 .fetch_optional(&state.db) 353 .await 354 { 355 Ok(Some(id)) => id, 356 _ => return ApiError::InternalError.into_response(), 357 }; 358 let new_handle = input.handle.trim(); 359 if new_handle.is_empty() { 360 return ApiError::InvalidRequest("handle is required".into()).into_response(); 361 } 362 if !new_handle 363 .chars() 364 .all(|c| c.is_ascii_alphanumeric() || c == '.' 
|| c == '-' || c == '_') 365 { 366 return ( 367 StatusCode::BAD_REQUEST, 368 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 369 ) 370 .into_response(); 371 } 372 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 373 .fetch_optional(&state.db) 374 .await 375 .ok() 376 .flatten(); 377 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 378 .fetch_optional(&state.db) 379 .await; 380 if let Ok(Some(_)) = existing { 381 return ( 382 StatusCode::BAD_REQUEST, 383 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 384 ) 385 .into_response(); 386 } 387 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id) 388 .execute(&state.db) 389 .await; 390 match result { 391 Ok(_) => { 392 if let Some(old) = old_handle { 393 let _ = state.cache.delete(&format!("handle:{}", old)).await; 394 } 395 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await; 396 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 397 let full_handle = format!("{}.{}", new_handle, hostname); 398 if let Err(e) = crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await { 399 warn!("Failed to sequence identity event for handle update: {}", e); 400 } 401 (StatusCode::OK, Json(json!({}))).into_response() 402 } 403 Err(e) => { 404 error!("DB error updating handle: {:?}", e); 405 ( 406 StatusCode::INTERNAL_SERVER_ERROR, 407 Json(json!({"error": "InternalError"})), 408 ) 409 .into_response() 410 } 411 } 412} 413pub async fn well_known_atproto_did( 414 State(state): State<AppState>, 415 headers: HeaderMap, 416) -> Response { 417 let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 418 Some(h) => h, 419 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 420 }; 421 let hostname = 
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 422 let suffix = format!(".{}", hostname); 423 let handle = host.split(':').next().unwrap_or(host); 424 let short_handle = if handle.ends_with(&suffix) { 425 handle.strip_suffix(&suffix).unwrap_or(handle) 426 } else { 427 return (StatusCode::NOT_FOUND, "Handle not found").into_response(); 428 }; 429 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle) 430 .fetch_optional(&state.db) 431 .await; 432 match user { 433 Ok(Some(row)) => row.did.into_response(), 434 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(), 435 Err(e) => { 436 error!("DB error in well-known atproto-did: {:?}", e); 437 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() 438 } 439 } 440}