// This repository has no description.
1use crate::state::AppState; 2use axum::{ 3 Json, 4 extract::{Path, Query, State}, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7}; 8use base64::Engine; 9use k256::SecretKey; 10use k256::elliptic_curve::sec1::ToEncodedPoint; 11use reqwest; 12use serde::Deserialize; 13use serde_json::json; 14use tracing::error; 15 16#[derive(Deserialize)] 17pub struct ResolveHandleParams { 18 pub handle: String, 19} 20 21pub async fn resolve_handle( 22 State(state): State<AppState>, 23 Query(params): Query<ResolveHandleParams>, 24) -> Response { 25 let handle = params.handle.trim(); 26 27 if handle.is_empty() { 28 return ( 29 StatusCode::BAD_REQUEST, 30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 31 ) 32 .into_response(); 33 } 34 35 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle) 36 .fetch_optional(&state.db) 37 .await; 38 39 match user { 40 Ok(Some(row)) => { 41 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 42 } 43 Ok(None) => ( 44 StatusCode::NOT_FOUND, 45 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 46 ) 47 .into_response(), 48 Err(e) => { 49 error!("DB error resolving handle: {:?}", e); 50 ( 51 StatusCode::INTERNAL_SERVER_ERROR, 52 Json(json!({"error": "InternalError"})), 53 ) 54 .into_response() 55 } 56 } 57} 58 59pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { 60 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); 61 let public_key = secret_key.public_key(); 62 let encoded = public_key.to_encoded_point(false); 63 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); 64 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); 65 66 json!({ 67 "kty": "EC", 68 "crv": "secp256k1", 69 "x": x, 70 "y": y 71 }) 72} 73 74pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 75 let hostname = 
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 76 // Kinda for local dev, encode hostname if it contains port 77 let did = if hostname.contains(':') { 78 format!("did:web:{}", hostname.replace(':', "%3A")) 79 } else { 80 format!("did:web:{}", hostname) 81 }; 82 83 Json(json!({ 84 "@context": ["https://www.w3.org/ns/did/v1"], 85 "id": did, 86 "service": [{ 87 "id": "#atproto_pds", 88 "type": "AtprotoPersonalDataServer", 89 "serviceEndpoint": format!("https://{}", hostname) 90 }] 91 })) 92} 93 94pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 95 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 96 97 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 98 .fetch_optional(&state.db) 99 .await; 100 101 let (user_id, did) = match user { 102 Ok(Some(row)) => (row.id, row.did), 103 Ok(None) => { 104 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 105 } 106 Err(e) => { 107 error!("DB Error: {:?}", e); 108 return ( 109 StatusCode::INTERNAL_SERVER_ERROR, 110 Json(json!({"error": "InternalError"})), 111 ) 112 .into_response(); 113 } 114 }; 115 116 if !did.starts_with("did:web:") { 117 return ( 118 StatusCode::NOT_FOUND, 119 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 120 ) 121 .into_response(); 122 } 123 124 let key_row = sqlx::query!("SELECT key_bytes FROM user_keys WHERE user_id = $1", user_id) 125 .fetch_optional(&state.db) 126 .await; 127 128 let key_bytes: Vec<u8> = match key_row { 129 Ok(Some(row)) => row.key_bytes, 130 _ => { 131 return ( 132 StatusCode::INTERNAL_SERVER_ERROR, 133 Json(json!({"error": "InternalError"})), 134 ) 135 .into_response(); 136 } 137 }; 138 139 let jwk = get_jwk(&key_bytes); 140 141 Json(json!({ 142 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 143 "id": did, 144 "alsoKnownAs": 
[format!("at://{}", handle)], 145 "verificationMethod": [{ 146 "id": format!("{}#atproto", did), 147 "type": "JsonWebKey2020", 148 "controller": did, 149 "publicKeyJwk": jwk 150 }], 151 "service": [{ 152 "id": "#atproto_pds", 153 "type": "AtprotoPersonalDataServer", 154 "serviceEndpoint": format!("https://{}", hostname) 155 }] 156 })).into_response() 157} 158 159pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 160 let expected_prefix = if hostname.contains(':') { 161 format!("did:web:{}", hostname.replace(':', "%3A")) 162 } else { 163 format!("did:web:{}", hostname) 164 }; 165 166 if did.starts_with(&expected_prefix) { 167 let suffix = &did[expected_prefix.len()..]; 168 let expected_suffix = format!(":u:{}", handle); 169 if suffix == expected_suffix { 170 Ok(()) 171 } else { 172 Err(format!( 173 "Invalid DID path for this PDS. Expected {}", 174 expected_suffix 175 )) 176 } 177 } else { 178 let parts: Vec<&str> = did.split(':').collect(); 179 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 180 return Err("Invalid did:web format".into()); 181 } 182 183 let domain_segment = parts[2]; 184 let domain = domain_segment.replace("%3A", ":"); 185 186 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 187 "http" 188 } else { 189 "https" 190 }; 191 192 let url = if parts.len() == 3 { 193 format!("{}://{}/.well-known/did.json", scheme, domain) 194 } else { 195 let path = parts[3..].join("/"); 196 format!("{}://{}/{}/did.json", scheme, domain, path) 197 }; 198 199 let client = reqwest::Client::builder() 200 .timeout(std::time::Duration::from_secs(5)) 201 .build() 202 .map_err(|e| format!("Failed to create client: {}", e))?; 203 204 let resp = client 205 .get(&url) 206 .send() 207 .await 208 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 209 210 if !resp.status().is_success() { 211 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 212 } 213 214 let doc: 
serde_json::Value = resp 215 .json() 216 .await 217 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 218 219 let services = doc["service"] 220 .as_array() 221 .ok_or("No services found in DID doc")?; 222 223 let pds_endpoint = format!("https://{}", hostname); 224 225 let has_valid_service = services.iter().any(|s| { 226 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 227 }); 228 229 if has_valid_service { 230 Ok(()) 231 } else { 232 Err(format!( 233 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 234 pds_endpoint 235 )) 236 } 237 } 238} 239 240#[derive(serde::Serialize)] 241#[serde(rename_all = "camelCase")] 242pub struct GetRecommendedDidCredentialsOutput { 243 pub rotation_keys: Vec<String>, 244 pub also_known_as: Vec<String>, 245 pub verification_methods: VerificationMethods, 246 pub services: Services, 247} 248 249#[derive(serde::Serialize)] 250#[serde(rename_all = "camelCase")] 251pub struct VerificationMethods { 252 pub atproto: String, 253} 254 255#[derive(serde::Serialize)] 256#[serde(rename_all = "camelCase")] 257pub struct Services { 258 pub atproto_pds: AtprotoPds, 259} 260 261#[derive(serde::Serialize)] 262#[serde(rename_all = "camelCase")] 263pub struct AtprotoPds { 264 #[serde(rename = "type")] 265 pub service_type: String, 266 pub endpoint: String, 267} 268 269pub async fn get_recommended_did_credentials( 270 State(state): State<AppState>, 271 headers: axum::http::HeaderMap, 272) -> Response { 273 let auth_header = headers.get("Authorization"); 274 if auth_header.is_none() { 275 return ( 276 StatusCode::UNAUTHORIZED, 277 Json(json!({"error": "AuthenticationRequired"})), 278 ) 279 .into_response(); 280 } 281 282 let token = auth_header 283 .unwrap() 284 .to_str() 285 .unwrap_or("") 286 .replace("Bearer ", ""); 287 288 let session = sqlx::query!( 289 r#" 290 SELECT s.did, k.key_bytes, u.handle 291 FROM sessions s 292 JOIN users u ON s.did = u.did 293 JOIN user_keys k ON u.id = k.user_id 
294 WHERE s.access_jwt = $1 295 "#, 296 token 297 ) 298 .fetch_optional(&state.db) 299 .await; 300 301 let (_did, key_bytes, handle) = match session { 302 Ok(Some(row)) => (row.did, row.key_bytes, row.handle), 303 Ok(None) => { 304 return ( 305 StatusCode::UNAUTHORIZED, 306 Json(json!({"error": "AuthenticationFailed"})), 307 ) 308 .into_response(); 309 } 310 Err(e) => { 311 error!("DB error in get_recommended_did_credentials: {:?}", e); 312 return ( 313 StatusCode::INTERNAL_SERVER_ERROR, 314 Json(json!({"error": "InternalError"})), 315 ) 316 .into_response(); 317 } 318 }; 319 320 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { 321 return ( 322 StatusCode::UNAUTHORIZED, 323 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), 324 ) 325 .into_response(); 326 } 327 328 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 329 let pds_endpoint = format!("https://{}", hostname); 330 331 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 332 Ok(k) => k, 333 Err(_) => { 334 return ( 335 StatusCode::INTERNAL_SERVER_ERROR, 336 Json(json!({"error": "InternalError"})), 337 ) 338 .into_response(); 339 } 340 }; 341 342 let public_key = secret_key.public_key(); 343 let encoded = public_key.to_encoded_point(true); 344 let did_key = format!( 345 "did:key:zQ3sh{}", 346 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 347 .chars() 348 .skip(1) 349 .collect::<String>() 350 ); 351 352 ( 353 StatusCode::OK, 354 Json(GetRecommendedDidCredentialsOutput { 355 rotation_keys: vec![did_key.clone()], 356 also_known_as: vec![format!("at://{}", handle)], 357 verification_methods: VerificationMethods { atproto: did_key }, 358 services: Services { 359 atproto_pds: AtprotoPds { 360 service_type: "AtprotoPersonalDataServer".to_string(), 361 endpoint: pds_endpoint, 362 }, 363 }, 364 }), 365 ) 366 .into_response() 367} 368 369#[derive(Deserialize)] 370pub struct UpdateHandleInput { 371 
pub handle: String, 372} 373 374pub async fn update_handle( 375 State(state): State<AppState>, 376 headers: axum::http::HeaderMap, 377 Json(input): Json<UpdateHandleInput>, 378) -> Response { 379 let auth_header = headers.get("Authorization"); 380 if auth_header.is_none() { 381 return ( 382 StatusCode::UNAUTHORIZED, 383 Json(json!({"error": "AuthenticationRequired"})), 384 ) 385 .into_response(); 386 } 387 388 let token = auth_header 389 .unwrap() 390 .to_str() 391 .unwrap_or("") 392 .replace("Bearer ", ""); 393 394 let session = sqlx::query!( 395 r#" 396 SELECT s.did, k.key_bytes, u.id as user_id 397 FROM sessions s 398 JOIN users u ON s.did = u.did 399 JOIN user_keys k ON u.id = k.user_id 400 WHERE s.access_jwt = $1 401 "#, 402 token 403 ) 404 .fetch_optional(&state.db) 405 .await; 406 407 let (_did, key_bytes, user_id) = match session { 408 Ok(Some(row)) => (row.did, row.key_bytes, row.user_id), 409 Ok(None) => { 410 return ( 411 StatusCode::UNAUTHORIZED, 412 Json(json!({"error": "AuthenticationFailed"})), 413 ) 414 .into_response(); 415 } 416 Err(e) => { 417 error!("DB error in update_handle: {:?}", e); 418 return ( 419 StatusCode::INTERNAL_SERVER_ERROR, 420 Json(json!({"error": "InternalError"})), 421 ) 422 .into_response(); 423 } 424 }; 425 426 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { 427 return ( 428 StatusCode::UNAUTHORIZED, 429 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), 430 ) 431 .into_response(); 432 } 433 434 let new_handle = input.handle.trim(); 435 if new_handle.is_empty() { 436 return ( 437 StatusCode::BAD_REQUEST, 438 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 439 ) 440 .into_response(); 441 } 442 443 if !new_handle 444 .chars() 445 .all(|c| c.is_ascii_alphanumeric() || c == '.' 
|| c == '-' || c == '_') 446 { 447 return ( 448 StatusCode::BAD_REQUEST, 449 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 450 ) 451 .into_response(); 452 } 453 454 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 455 .fetch_optional(&state.db) 456 .await; 457 458 if let Ok(Some(_)) = existing { 459 return ( 460 StatusCode::BAD_REQUEST, 461 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 462 ) 463 .into_response(); 464 } 465 466 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id) 467 .execute(&state.db) 468 .await; 469 470 match result { 471 Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(), 472 Err(e) => { 473 error!("DB error updating handle: {:?}", e); 474 ( 475 StatusCode::INTERNAL_SERVER_ERROR, 476 Json(json!({"error": "InternalError"})), 477 ) 478 .into_response() 479 } 480 } 481}