this repo has no description
1use crate::state::AppState; 2use axum::{ 3 Json, 4 extract::{Path, Query, State}, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7}; 8use base64::Engine; 9use k256::SecretKey; 10use k256::elliptic_curve::sec1::ToEncodedPoint; 11use reqwest; 12use serde::Deserialize; 13use serde_json::json; 14use tracing::error; 15 16#[derive(Deserialize)] 17pub struct ResolveHandleParams { 18 pub handle: String, 19} 20 21pub async fn resolve_handle( 22 State(state): State<AppState>, 23 Query(params): Query<ResolveHandleParams>, 24) -> Response { 25 let handle = params.handle.trim(); 26 27 if handle.is_empty() { 28 return ( 29 StatusCode::BAD_REQUEST, 30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 31 ) 32 .into_response(); 33 } 34 35 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle) 36 .fetch_optional(&state.db) 37 .await; 38 39 match user { 40 Ok(Some(row)) => { 41 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 42 } 43 Ok(None) => ( 44 StatusCode::NOT_FOUND, 45 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 46 ) 47 .into_response(), 48 Err(e) => { 49 error!("DB error resolving handle: {:?}", e); 50 ( 51 StatusCode::INTERNAL_SERVER_ERROR, 52 Json(json!({"error": "InternalError"})), 53 ) 54 .into_response() 55 } 56 } 57} 58 59pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { 60 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); 61 let public_key = secret_key.public_key(); 62 let encoded = public_key.to_encoded_point(false); 63 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); 64 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); 65 66 json!({ 67 "kty": "EC", 68 "crv": "secp256k1", 69 "x": x, 70 "y": y 71 }) 72} 73 74pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 75 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 76 // Kinda for local dev, encode hostname if it contains port 77 let did = if hostname.contains(':') { 78 format!("did:web:{}", hostname.replace(':', "%3A")) 79 } else { 80 format!("did:web:{}", hostname) 81 }; 82 83 Json(json!({ 84 "@context": ["https://www.w3.org/ns/did/v1"], 85 "id": did, 86 "service": [{ 87 "id": "#atproto_pds", 88 "type": "AtprotoPersonalDataServer", 89 "serviceEndpoint": format!("https://{}", hostname) 90 }] 91 })) 92} 93 94pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 95 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 96 97 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 98 .fetch_optional(&state.db) 99 .await; 100 101 let (user_id, did) = match user { 102 Ok(Some(row)) => (row.id, row.did), 103 Ok(None) => { 104 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 105 } 106 Err(e) => { 107 error!("DB Error: {:?}", e); 108 return ( 109 StatusCode::INTERNAL_SERVER_ERROR, 110 Json(json!({"error": "InternalError"})), 111 ) 112 .into_response(); 113 } 114 }; 115 116 if !did.starts_with("did:web:") { 117 return ( 118 StatusCode::NOT_FOUND, 119 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 120 ) 121 .into_response(); 122 } 123 124 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id) 125 .fetch_optional(&state.db) 126 .await; 127 128 let key_bytes: Vec<u8> = match key_row { 129 Ok(Some(row)) => { 130 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 131 Ok(k) => k, 132 Err(_) => { 133 return ( 134 StatusCode::INTERNAL_SERVER_ERROR, 135 Json(json!({"error": "InternalError"})), 136 ) 137 .into_response(); 138 } 139 } 140 } 141 _ => { 142 return ( 143 StatusCode::INTERNAL_SERVER_ERROR, 144 Json(json!({"error": "InternalError"})), 145 ) 146 .into_response(); 147 } 148 }; 149 150 let jwk = get_jwk(&key_bytes); 151 152 Json(json!({ 153 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 154 "id": did, 155 "alsoKnownAs": [format!("at://{}", handle)], 156 "verificationMethod": [{ 157 "id": format!("{}#atproto", did), 158 "type": "JsonWebKey2020", 159 "controller": did, 160 "publicKeyJwk": jwk 161 }], 162 "service": [{ 163 "id": "#atproto_pds", 164 "type": "AtprotoPersonalDataServer", 165 "serviceEndpoint": format!("https://{}", hostname) 166 }] 167 })).into_response() 168} 169 170pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 171 let expected_prefix = if hostname.contains(':') { 172 format!("did:web:{}", hostname.replace(':', "%3A")) 173 } else { 174 format!("did:web:{}", hostname) 175 }; 176 177 if did.starts_with(&expected_prefix) { 178 let suffix = &did[expected_prefix.len()..]; 179 let expected_suffix = format!(":u:{}", handle); 180 if suffix == expected_suffix { 181 Ok(()) 182 } else { 183 Err(format!( 184 "Invalid DID path for this PDS. Expected {}", 185 expected_suffix 186 )) 187 } 188 } else { 189 let parts: Vec<&str> = did.split(':').collect(); 190 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 191 return Err("Invalid did:web format".into()); 192 } 193 194 let domain_segment = parts[2]; 195 let domain = domain_segment.replace("%3A", ":"); 196 197 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 198 "http" 199 } else { 200 "https" 201 }; 202 203 let url = if parts.len() == 3 { 204 format!("{}://{}/.well-known/did.json", scheme, domain) 205 } else { 206 let path = parts[3..].join("/"); 207 format!("{}://{}/{}/did.json", scheme, domain, path) 208 }; 209 210 let client = reqwest::Client::builder() 211 .timeout(std::time::Duration::from_secs(5)) 212 .build() 213 .map_err(|e| format!("Failed to create client: {}", e))?; 214 215 let resp = client 216 .get(&url) 217 .send() 218 .await 219 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 220 221 if !resp.status().is_success() { 222 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 223 } 224 225 let doc: serde_json::Value = resp 226 .json() 227 .await 228 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 229 230 let services = doc["service"] 231 .as_array() 232 .ok_or("No services found in DID doc")?; 233 234 let pds_endpoint = format!("https://{}", hostname); 235 236 let has_valid_service = services.iter().any(|s| { 237 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 238 }); 239 240 if has_valid_service { 241 Ok(()) 242 } else { 243 Err(format!( 244 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 245 pds_endpoint 246 )) 247 } 248 } 249} 250 251#[derive(serde::Serialize)] 252#[serde(rename_all = "camelCase")] 253pub struct GetRecommendedDidCredentialsOutput { 254 pub rotation_keys: Vec<String>, 255 pub also_known_as: Vec<String>, 256 pub verification_methods: VerificationMethods, 257 pub services: Services, 258} 259 260#[derive(serde::Serialize)] 261#[serde(rename_all = "camelCase")] 262pub struct VerificationMethods { 263 pub atproto: String, 264} 265 266#[derive(serde::Serialize)] 267#[serde(rename_all = "camelCase")] 268pub struct Services { 269 pub atproto_pds: AtprotoPds, 270} 271 272#[derive(serde::Serialize)] 273#[serde(rename_all = "camelCase")] 274pub struct AtprotoPds { 275 #[serde(rename = "type")] 276 pub service_type: String, 277 pub endpoint: String, 278} 279 280pub async fn get_recommended_did_credentials( 281 State(state): State<AppState>, 282 headers: axum::http::HeaderMap, 283) -> Response { 284 let token = match crate::auth::extract_bearer_token_from_header( 285 headers.get("Authorization").and_then(|h| h.to_str().ok()) 286 ) { 287 Some(t) => t, 288 None => { 289 return ( 290 StatusCode::UNAUTHORIZED, 291 Json(json!({"error": "AuthenticationRequired"})), 292 ) 293 .into_response(); 294 } 295 }; 296 297 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 298 let did = match auth_result { 299 Ok(ref user) => user.did.clone(), 300 Err(e) => { 301 return ( 302 StatusCode::UNAUTHORIZED, 303 Json(json!({"error": e})), 304 ) 305 .into_response(); 306 } 307 }; 308 309 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", did) 310 .fetch_optional(&state.db) 311 .await 312 { 313 Ok(Some(row)) => row, 314 _ => { 315 return ( 316 StatusCode::INTERNAL_SERVER_ERROR, 317 Json(json!({"error": "InternalError"})), 318 ) 319 .into_response(); 320 } 321 }; 322 let handle = user.handle; 323 324 let key_bytes = match auth_result.ok().and_then(|u| u.key_bytes) { 325 Some(kb) => kb, 326 None => { 327 return ( 328 StatusCode::UNAUTHORIZED, 329 Json(json!({"error": "AuthenticationFailed", "message": "OAuth tokens cannot get DID credentials"})), 330 ) 331 .into_response(); 332 } 333 }; 334 335 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 336 let pds_endpoint = format!("https://{}", hostname); 337 338 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 339 Ok(k) => k, 340 Err(_) => { 341 return ( 342 StatusCode::INTERNAL_SERVER_ERROR, 343 Json(json!({"error": "InternalError"})), 344 ) 345 .into_response(); 346 } 347 }; 348 349 let public_key = secret_key.public_key(); 350 let encoded = public_key.to_encoded_point(true); 351 let did_key = format!( 352 "did:key:zQ3sh{}", 353 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 354 .chars() 355 .skip(1) 356 .collect::<String>() 357 ); 358 359 ( 360 StatusCode::OK, 361 Json(GetRecommendedDidCredentialsOutput { 362 rotation_keys: vec![did_key.clone()], 363 also_known_as: vec![format!("at://{}", handle)], 364 verification_methods: VerificationMethods { atproto: did_key }, 365 services: Services { 366 atproto_pds: AtprotoPds { 367 service_type: "AtprotoPersonalDataServer".to_string(), 368 endpoint: pds_endpoint, 369 }, 370 }, 371 }), 372 ) 373 .into_response() 374} 375 376#[derive(Deserialize)] 377pub struct UpdateHandleInput { 378 pub handle: String, 379} 380 381pub async fn update_handle( 382 State(state): State<AppState>, 383 headers: axum::http::HeaderMap, 384 Json(input): Json<UpdateHandleInput>, 385) -> Response { 386 let token = match crate::auth::extract_bearer_token_from_header( 387 headers.get("Authorization").and_then(|h| h.to_str().ok()) 388 ) { 389 Some(t) => t, 390 None => { 391 return ( 392 StatusCode::UNAUTHORIZED, 393 Json(json!({"error": "AuthenticationRequired"})), 394 ) 395 .into_response(); 396 } 397 }; 398 399 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await; 400 let did = match auth_result { 401 Ok(user) => user.did, 402 Err(e) => { 403 return ( 404 StatusCode::UNAUTHORIZED, 405 Json(json!({"error": e})), 406 ) 407 .into_response(); 408 } 409 }; 410 411 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 412 .fetch_optional(&state.db) 413 .await 414 { 415 Ok(Some(id)) => id, 416 _ => { 417 return ( 418 StatusCode::INTERNAL_SERVER_ERROR, 419 Json(json!({"error": "InternalError"})), 420 ) 421 .into_response(); 422 } 423 }; 424 425 let new_handle = input.handle.trim(); 426 if new_handle.is_empty() { 427 return ( 428 StatusCode::BAD_REQUEST, 429 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 430 ) 431 .into_response(); 432 } 433 434 if !new_handle 435 .chars() 436 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 437 { 438 return ( 439 StatusCode::BAD_REQUEST, 440 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 441 ) 442 .into_response(); 443 } 444 445 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 446 .fetch_optional(&state.db) 447 .await; 448 449 if let Ok(Some(_)) = existing { 450 return ( 451 StatusCode::BAD_REQUEST, 452 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 453 ) 454 .into_response(); 455 } 456 457 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id) 458 .execute(&state.db) 459 .await; 460 461 match result { 462 Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(), 463 Err(e) => { 464 error!("DB error updating handle: {:?}", e); 465 ( 466 StatusCode::INTERNAL_SERVER_ERROR, 467 Json(json!({"error": "InternalError"})), 468 ) 469 .into_response() 470 } 471 } 472}