this repo has no description
1use crate::state::AppState; 2use axum::{ 3 Json, 4 extract::{Path, Query, State}, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7}; 8use base64::Engine; 9use k256::SecretKey; 10use k256::elliptic_curve::sec1::ToEncodedPoint; 11use reqwest; 12use serde::Deserialize; 13use serde_json::json; 14use sqlx::Row; 15use tracing::error; 16 17#[derive(Deserialize)] 18pub struct ResolveHandleParams { 19 pub handle: String, 20} 21 22pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, 25) -> Response { 26 let handle = params.handle.trim(); 27 28 if handle.is_empty() { 29 return ( 30 StatusCode::BAD_REQUEST, 31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 32 ) 33 .into_response(); 34 } 35 36 let user = sqlx::query("SELECT did FROM users WHERE handle = $1") 37 .bind(handle) 38 .fetch_optional(&state.db) 39 .await; 40 41 match user { 42 Ok(Some(row)) => { 43 let did: String = row.get("did"); 44 (StatusCode::OK, Json(json!({ "did": did }))).into_response() 45 } 46 Ok(None) => ( 47 StatusCode::NOT_FOUND, 48 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 49 ) 50 .into_response(), 51 Err(e) => { 52 error!("DB error resolving handle: {:?}", e); 53 ( 54 StatusCode::INTERNAL_SERVER_ERROR, 55 Json(json!({"error": "InternalError"})), 56 ) 57 .into_response() 58 } 59 } 60} 61 62pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { 63 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); 64 let public_key = secret_key.public_key(); 65 let encoded = public_key.to_encoded_point(false); 66 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); 67 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); 68 69 json!({ 70 "kty": "EC", 71 "crv": "secp256k1", 72 "x": x, 73 "y": y 74 }) 75} 76 77pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 78 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 79 // Kinda for local dev, encode hostname if it contains port 80 let did = if hostname.contains(':') { 81 format!("did:web:{}", hostname.replace(':', "%3A")) 82 } else { 83 format!("did:web:{}", hostname) 84 }; 85 86 Json(json!({ 87 "@context": ["https://www.w3.org/ns/did/v1"], 88 "id": did, 89 "service": [{ 90 "id": "#atproto_pds", 91 "type": "AtprotoPersonalDataServer", 92 "serviceEndpoint": format!("https://{}", hostname) 93 }] 94 })) 95} 96 97pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 98 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 99 100 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1") 101 .bind(&handle) 102 .fetch_optional(&state.db) 103 .await; 104 105 let (user_id, did) = match user { 106 Ok(Some(row)) => { 107 let id: uuid::Uuid = row.get("id"); 108 let d: String = row.get("did"); 109 (id, d) 110 } 111 Ok(None) => { 112 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 113 } 114 Err(e) => { 115 error!("DB Error: {:?}", e); 116 return ( 117 StatusCode::INTERNAL_SERVER_ERROR, 118 Json(json!({"error": "InternalError"})), 119 ) 120 .into_response(); 121 } 122 }; 123 124 if !did.starts_with("did:web:") { 125 return ( 126 StatusCode::NOT_FOUND, 127 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 128 ) 129 .into_response(); 130 } 131 132 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1") 133 .bind(user_id) 134 .fetch_optional(&state.db) 135 .await; 136 137 let key_bytes: Vec<u8> = match key_row { 138 Ok(Some(row)) => row.get("key_bytes"), 139 _ => { 140 return ( 141 StatusCode::INTERNAL_SERVER_ERROR, 142 Json(json!({"error": "InternalError"})), 143 ) 144 .into_response(); 145 } 146 }; 147 148 let jwk = get_jwk(&key_bytes); 149 150 Json(json!({ 151 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 152 "id": did, 153 "alsoKnownAs": [format!("at://{}", handle)], 154 "verificationMethod": [{ 155 "id": format!("{}#atproto", did), 156 "type": "JsonWebKey2020", 157 "controller": did, 158 "publicKeyJwk": jwk 159 }], 160 "service": [{ 161 "id": "#atproto_pds", 162 "type": "AtprotoPersonalDataServer", 163 "serviceEndpoint": format!("https://{}", hostname) 164 }] 165 })).into_response() 166} 167 168pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 169 let expected_prefix = if hostname.contains(':') { 170 format!("did:web:{}", hostname.replace(':', "%3A")) 171 } else { 172 format!("did:web:{}", hostname) 173 }; 174 175 if did.starts_with(&expected_prefix) { 176 let suffix = &did[expected_prefix.len()..]; 177 let expected_suffix = format!(":u:{}", handle); 178 if suffix == expected_suffix { 179 Ok(()) 180 } else { 181 Err(format!( 182 "Invalid DID path for this PDS. Expected {}", 183 expected_suffix 184 )) 185 } 186 } else { 187 let parts: Vec<&str> = did.split(':').collect(); 188 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 189 return Err("Invalid did:web format".into()); 190 } 191 192 let domain_segment = parts[2]; 193 let domain = domain_segment.replace("%3A", ":"); 194 195 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 196 "http" 197 } else { 198 "https" 199 }; 200 201 let url = if parts.len() == 3 { 202 format!("{}://{}/.well-known/did.json", scheme, domain) 203 } else { 204 let path = parts[3..].join("/"); 205 format!("{}://{}/{}/did.json", scheme, domain, path) 206 }; 207 208 let client = reqwest::Client::builder() 209 .timeout(std::time::Duration::from_secs(5)) 210 .build() 211 .map_err(|e| format!("Failed to create client: {}", e))?; 212 213 let resp = client 214 .get(&url) 215 .send() 216 .await 217 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 218 219 if !resp.status().is_success() { 220 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 221 } 222 223 let doc: serde_json::Value = resp 224 .json() 225 .await 226 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 227 228 let services = doc["service"] 229 .as_array() 230 .ok_or("No services found in DID doc")?; 231 232 let pds_endpoint = format!("https://{}", hostname); 233 234 let has_valid_service = services.iter().any(|s| { 235 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 236 }); 237 238 if has_valid_service { 239 Ok(()) 240 } else { 241 Err(format!( 242 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 243 pds_endpoint 244 )) 245 } 246 } 247} 248 249#[derive(serde::Serialize)] 250#[serde(rename_all = "camelCase")] 251pub struct GetRecommendedDidCredentialsOutput { 252 pub rotation_keys: Vec<String>, 253 pub also_known_as: Vec<String>, 254 pub verification_methods: VerificationMethods, 255 pub services: Services, 256} 257 258#[derive(serde::Serialize)] 259#[serde(rename_all = "camelCase")] 260pub struct VerificationMethods { 261 pub atproto: String, 262} 263 264#[derive(serde::Serialize)] 265#[serde(rename_all = "camelCase")] 266pub struct Services { 267 pub atproto_pds: AtprotoPds, 268} 269 270#[derive(serde::Serialize)] 271#[serde(rename_all = "camelCase")] 272pub struct AtprotoPds { 273 #[serde(rename = "type")] 274 pub service_type: String, 275 pub endpoint: String, 276} 277 278pub async fn get_recommended_did_credentials( 279 State(state): State<AppState>, 280 headers: axum::http::HeaderMap, 281) -> Response { 282 let auth_header = headers.get("Authorization"); 283 if auth_header.is_none() { 284 return ( 285 StatusCode::UNAUTHORIZED, 286 Json(json!({"error": "AuthenticationRequired"})), 287 ) 288 .into_response(); 289 } 290 291 let token = auth_header 292 .unwrap() 293 .to_str() 294 .unwrap_or("") 295 .replace("Bearer ", ""); 296 297 let session = sqlx::query( 298 r#" 299 SELECT s.did, k.key_bytes, u.handle 300 FROM sessions s 301 JOIN users u ON s.did = u.did 302 JOIN user_keys k ON u.id = k.user_id 303 WHERE s.access_jwt = $1 304 "#, 305 ) 306 .bind(&token) 307 .fetch_optional(&state.db) 308 .await; 309 310 let (_did, key_bytes, handle) = match session { 311 Ok(Some(row)) => ( 312 row.get::<String, _>("did"), 313 row.get::<Vec<u8>, _>("key_bytes"), 314 row.get::<String, _>("handle"), 315 ), 316 Ok(None) => { 317 return ( 318 StatusCode::UNAUTHORIZED, 319 Json(json!({"error": "AuthenticationFailed"})), 320 ) 321 .into_response(); 322 } 323 Err(e) => { 324 error!("DB error in get_recommended_did_credentials: {:?}", e); 325 return ( 326 StatusCode::INTERNAL_SERVER_ERROR, 327 Json(json!({"error": "InternalError"})), 328 ) 329 .into_response(); 330 } 331 }; 332 333 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { 334 return ( 335 StatusCode::UNAUTHORIZED, 336 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), 337 ) 338 .into_response(); 339 } 340 341 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 342 let pds_endpoint = format!("https://{}", hostname); 343 344 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 345 Ok(k) => k, 346 Err(_) => { 347 return ( 348 StatusCode::INTERNAL_SERVER_ERROR, 349 Json(json!({"error": "InternalError"})), 350 ) 351 .into_response(); 352 } 353 }; 354 355 let public_key = secret_key.public_key(); 356 let encoded = public_key.to_encoded_point(true); 357 let did_key = format!( 358 "did:key:zQ3sh{}", 359 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 360 .chars() 361 .skip(1) 362 .collect::<String>() 363 ); 364 365 ( 366 StatusCode::OK, 367 Json(GetRecommendedDidCredentialsOutput { 368 rotation_keys: vec![did_key.clone()], 369 also_known_as: vec![format!("at://{}", handle)], 370 verification_methods: VerificationMethods { atproto: did_key }, 371 services: Services { 372 atproto_pds: AtprotoPds { 373 service_type: "AtprotoPersonalDataServer".to_string(), 374 endpoint: pds_endpoint, 375 }, 376 }, 377 }), 378 ) 379 .into_response() 380} 381 382#[derive(Deserialize)] 383pub struct UpdateHandleInput { 384 pub handle: String, 385} 386 387pub async fn update_handle( 388 State(state): State<AppState>, 389 headers: axum::http::HeaderMap, 390 Json(input): Json<UpdateHandleInput>, 391) -> Response { 392 let auth_header = headers.get("Authorization"); 393 if auth_header.is_none() { 394 return ( 395 StatusCode::UNAUTHORIZED, 396 Json(json!({"error": "AuthenticationRequired"})), 397 ) 398 .into_response(); 399 } 400 401 let token = auth_header 402 .unwrap() 403 .to_str() 404 .unwrap_or("") 405 .replace("Bearer ", ""); 406 407 let session = sqlx::query( 408 r#" 409 SELECT s.did, k.key_bytes, u.id as user_id 410 FROM sessions s 411 JOIN users u ON s.did = u.did 412 JOIN user_keys k ON u.id = k.user_id 413 WHERE s.access_jwt = $1 414 "#, 415 ) 416 .bind(&token) 417 .fetch_optional(&state.db) 418 .await; 419 420 let (_did, key_bytes, user_id) = match session { 421 Ok(Some(row)) => ( 422 row.get::<String, _>("did"), 423 row.get::<Vec<u8>, _>("key_bytes"), 424 row.get::<uuid::Uuid, _>("user_id"), 425 ), 426 Ok(None) => { 427 return ( 428 StatusCode::UNAUTHORIZED, 429 Json(json!({"error": "AuthenticationFailed"})), 430 ) 431 .into_response(); 432 } 433 Err(e) => { 434 error!("DB error in update_handle: {:?}", e); 435 return ( 436 StatusCode::INTERNAL_SERVER_ERROR, 437 Json(json!({"error": "InternalError"})), 438 ) 439 .into_response(); 440 } 441 }; 442 443 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { 444 return ( 445 StatusCode::UNAUTHORIZED, 446 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), 447 ) 448 .into_response(); 449 } 450 451 let new_handle = input.handle.trim(); 452 if new_handle.is_empty() { 453 return ( 454 StatusCode::BAD_REQUEST, 455 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 456 ) 457 .into_response(); 458 } 459 460 if !new_handle 461 .chars() 462 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 463 { 464 return ( 465 StatusCode::BAD_REQUEST, 466 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 467 ) 468 .into_response(); 469 } 470 471 let existing = sqlx::query("SELECT id FROM users WHERE handle = $1 AND id != $2") 472 .bind(new_handle) 473 .bind(user_id) 474 .fetch_optional(&state.db) 475 .await; 476 477 if let Ok(Some(_)) = existing { 478 return ( 479 StatusCode::BAD_REQUEST, 480 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 481 ) 482 .into_response(); 483 } 484 485 let result = sqlx::query("UPDATE users SET handle = $1 WHERE id = $2") 486 .bind(new_handle) 487 .bind(user_id) 488 .execute(&state.db) 489 .await; 490 491 match result { 492 Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(), 493 Err(e) => { 494 error!("DB error updating handle: {:?}", e); 495 ( 496 StatusCode::INTERNAL_SERVER_ERROR, 497 Json(json!({"error": "InternalError"})), 498 ) 499 .into_response() 500 } 501 } 502}