this repo has no description
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::{Path, Query, State}, 6 http::StatusCode, 7 response::{IntoResponse, Response}, 8}; 9use base64::Engine; 10use k256::SecretKey; 11use k256::elliptic_curve::sec1::ToEncodedPoint; 12use reqwest; 13use serde::Deserialize; 14use serde_json::json; 15use tracing::error; 16 17#[derive(Deserialize)] 18pub struct ResolveHandleParams { 19 pub handle: String, 20} 21 22pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, 25) -> Response { 26 let handle = params.handle.trim(); 27 28 if handle.is_empty() { 29 return ( 30 StatusCode::BAD_REQUEST, 31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 32 ) 33 .into_response(); 34 } 35 36 let cache_key = format!("handle:{}", handle); 37 if let Some(did) = state.cache.get(&cache_key).await { 38 return (StatusCode::OK, Json(json!({ "did": did }))).into_response(); 39 } 40 41 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle) 42 .fetch_optional(&state.db) 43 .await; 44 45 match user { 46 Ok(Some(row)) => { 47 let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await; 48 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 49 } 50 Ok(None) => ( 51 StatusCode::NOT_FOUND, 52 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 53 ) 54 .into_response(), 55 Err(e) => { 56 error!("DB error resolving handle: {:?}", e); 57 ( 58 StatusCode::INTERNAL_SERVER_ERROR, 59 Json(json!({"error": "InternalError"})), 60 ) 61 .into_response() 62 } 63 } 64} 65 66pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 67 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 68 let public_key = secret_key.public_key(); 69 let encoded = public_key.to_encoded_point(false); 70 let x = encoded.x().ok_or("Missing x coordinate")?; 71 let y = encoded.y().ok_or("Missing y coordinate")?; 72 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 73 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 74 75 Ok(json!({ 76 "kty": "EC", 77 "crv": "secp256k1", 78 "x": x_b64, 79 "y": y_b64 80 })) 81} 82 83pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 84 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 85 // Kinda for local dev, encode hostname if it contains port 86 let did = if hostname.contains(':') { 87 format!("did:web:{}", hostname.replace(':', "%3A")) 88 } else { 89 format!("did:web:{}", hostname) 90 }; 91 92 Json(json!({ 93 "@context": ["https://www.w3.org/ns/did/v1"], 94 "id": did, 95 "service": [{ 96 "id": "#atproto_pds", 97 "type": "AtprotoPersonalDataServer", 98 "serviceEndpoint": format!("https://{}", hostname) 99 }] 100 })) 101} 102 103pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 104 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 105 106 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 107 .fetch_optional(&state.db) 108 .await; 109 110 let (user_id, did) = match user { 111 Ok(Some(row)) => (row.id, row.did), 112 Ok(None) => { 113 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 114 } 115 Err(e) => { 116 error!("DB Error: {:?}", e); 117 return ( 118 StatusCode::INTERNAL_SERVER_ERROR, 119 Json(json!({"error": "InternalError"})), 120 ) 121 .into_response(); 122 } 123 }; 124 125 if !did.starts_with("did:web:") { 126 return ( 127 StatusCode::NOT_FOUND, 128 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 129 ) 130 .into_response(); 131 } 132 133 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id) 134 .fetch_optional(&state.db) 135 .await; 136 137 let key_bytes: Vec<u8> = match key_row { 138 Ok(Some(row)) => { 139 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 140 Ok(k) => k, 141 Err(_) => { 142 return ( 143 StatusCode::INTERNAL_SERVER_ERROR, 144 Json(json!({"error": "InternalError"})), 145 ) 146 .into_response(); 147 } 148 } 149 } 150 _ => { 151 return ( 152 StatusCode::INTERNAL_SERVER_ERROR, 153 Json(json!({"error": "InternalError"})), 154 ) 155 .into_response(); 156 } 157 }; 158 159 let jwk = match get_jwk(&key_bytes) { 160 Ok(j) => j, 161 Err(e) => { 162 tracing::error!("Failed to generate JWK: {}", e); 163 return ( 164 StatusCode::INTERNAL_SERVER_ERROR, 165 Json(json!({"error": "InternalError"})), 166 ) 167 .into_response(); 168 } 169 }; 170 171 Json(json!({ 172 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 173 "id": did, 174 "alsoKnownAs": [format!("at://{}", handle)], 175 "verificationMethod": [{ 176 "id": format!("{}#atproto", did), 177 "type": "JsonWebKey2020", 178 "controller": did, 179 "publicKeyJwk": jwk 180 }], 181 "service": [{ 182 "id": "#atproto_pds", 183 "type": "AtprotoPersonalDataServer", 184 "serviceEndpoint": format!("https://{}", hostname) 185 }] 186 })).into_response() 187} 188 189pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 190 let expected_prefix = if hostname.contains(':') { 191 format!("did:web:{}", hostname.replace(':', "%3A")) 192 } else { 193 format!("did:web:{}", hostname) 194 }; 195 196 if did.starts_with(&expected_prefix) { 197 let suffix = &did[expected_prefix.len()..]; 198 let expected_suffix = format!(":u:{}", handle); 199 if suffix == expected_suffix { 200 Ok(()) 201 } else { 202 Err(format!( 203 "Invalid DID path for this PDS. Expected {}", 204 expected_suffix 205 )) 206 } 207 } else { 208 let parts: Vec<&str> = did.split(':').collect(); 209 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 210 return Err("Invalid did:web format".into()); 211 } 212 213 let domain_segment = parts[2]; 214 let domain = domain_segment.replace("%3A", ":"); 215 216 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 217 "http" 218 } else { 219 "https" 220 }; 221 222 let url = if parts.len() == 3 { 223 format!("{}://{}/.well-known/did.json", scheme, domain) 224 } else { 225 let path = parts[3..].join("/"); 226 format!("{}://{}/{}/did.json", scheme, domain, path) 227 }; 228 229 let client = reqwest::Client::builder() 230 .timeout(std::time::Duration::from_secs(5)) 231 .build() 232 .map_err(|e| format!("Failed to create client: {}", e))?; 233 234 let resp = client 235 .get(&url) 236 .send() 237 .await 238 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 239 240 if !resp.status().is_success() { 241 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 242 } 243 244 let doc: serde_json::Value = resp 245 .json() 246 .await 247 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 248 249 let services = doc["service"] 250 .as_array() 251 .ok_or("No services found in DID doc")?; 252 253 let pds_endpoint = format!("https://{}", hostname); 254 255 let has_valid_service = services.iter().any(|s| { 256 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 257 }); 258 259 if has_valid_service { 260 Ok(()) 261 } else { 262 Err(format!( 263 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 264 pds_endpoint 265 )) 266 } 267 } 268} 269 270#[derive(serde::Serialize)] 271#[serde(rename_all = "camelCase")] 272pub struct GetRecommendedDidCredentialsOutput { 273 pub rotation_keys: Vec<String>, 274 pub also_known_as: Vec<String>, 275 pub verification_methods: VerificationMethods, 276 pub services: Services, 277} 278 279#[derive(serde::Serialize)] 280#[serde(rename_all = "camelCase")] 281pub struct VerificationMethods { 282 pub atproto: String, 283} 284 285#[derive(serde::Serialize)] 286#[serde(rename_all = "camelCase")] 287pub struct Services { 288 pub atproto_pds: AtprotoPds, 289} 290 291#[derive(serde::Serialize)] 292#[serde(rename_all = "camelCase")] 293pub struct AtprotoPds { 294 #[serde(rename = "type")] 295 pub service_type: String, 296 pub endpoint: String, 297} 298 299pub async fn get_recommended_did_credentials( 300 State(state): State<AppState>, 301 headers: axum::http::HeaderMap, 302) -> Response { 303 let token = match crate::auth::extract_bearer_token_from_header( 304 headers.get("Authorization").and_then(|h| h.to_str().ok()) 305 ) { 306 Some(t) => t, 307 None => { 308 return ( 309 StatusCode::UNAUTHORIZED, 310 Json(json!({"error": "AuthenticationRequired"})), 311 ) 312 .into_response(); 313 } 314 }; 315 316 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 317 Ok(user) => user, 318 Err(e) => return ApiError::from(e).into_response(), 319 }; 320 321 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did) 322 .fetch_optional(&state.db) 323 .await 324 { 325 Ok(Some(row)) => row, 326 _ => return ApiError::InternalError.into_response(), 327 }; 328 329 let key_bytes = match auth_user.key_bytes { 330 Some(kb) => kb, 331 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(), 332 }; 333 334 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 335 let pds_endpoint = format!("https://{}", hostname); 336 337 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 338 Ok(k) => k, 339 Err(_) => return ApiError::InternalError.into_response(), 340 }; 341 342 let public_key = secret_key.public_key(); 343 let encoded = public_key.to_encoded_point(true); 344 let did_key = format!( 345 "did:key:zQ3sh{}", 346 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 347 .chars() 348 .skip(1) 349 .collect::<String>() 350 ); 351 352 ( 353 StatusCode::OK, 354 Json(GetRecommendedDidCredentialsOutput { 355 rotation_keys: vec![did_key.clone()], 356 also_known_as: vec![format!("at://{}", user.handle)], 357 verification_methods: VerificationMethods { atproto: did_key }, 358 services: Services { 359 atproto_pds: AtprotoPds { 360 service_type: "AtprotoPersonalDataServer".to_string(), 361 endpoint: pds_endpoint, 362 }, 363 }, 364 }), 365 ) 366 .into_response() 367} 368 369#[derive(Deserialize)] 370pub struct UpdateHandleInput { 371 pub handle: String, 372} 373 374pub async fn update_handle( 375 State(state): State<AppState>, 376 headers: axum::http::HeaderMap, 377 Json(input): Json<UpdateHandleInput>, 378) -> Response { 379 let token = match crate::auth::extract_bearer_token_from_header( 380 headers.get("Authorization").and_then(|h| h.to_str().ok()) 381 ) { 382 Some(t) => t, 383 None => return ApiError::AuthenticationRequired.into_response(), 384 }; 385 386 let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 387 Ok(user) => user.did, 388 Err(e) => return ApiError::from(e).into_response(), 389 }; 390 391 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 392 .fetch_optional(&state.db) 393 .await 394 { 395 Ok(Some(id)) => id, 396 _ => return ApiError::InternalError.into_response(), 397 }; 398 399 let new_handle = input.handle.trim(); 400 if new_handle.is_empty() { 401 return ApiError::InvalidRequest("handle is required".into()).into_response(); 402 } 403 404 if !new_handle 405 .chars() 406 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 407 { 408 return ( 409 StatusCode::BAD_REQUEST, 410 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 411 ) 412 .into_response(); 413 } 414 415 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id) 416 .fetch_optional(&state.db) 417 .await 418 .ok() 419 .flatten(); 420 421 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 422 .fetch_optional(&state.db) 423 .await; 424 425 if let Ok(Some(_)) = existing { 426 return ( 427 StatusCode::BAD_REQUEST, 428 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 429 ) 430 .into_response(); 431 } 432 433 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id) 434 .execute(&state.db) 435 .await; 436 437 match result { 438 Ok(_) => { 439 if let Some(old) = old_handle { 440 let _ = state.cache.delete(&format!("handle:{}", old)).await; 441 } 442 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await; 443 (StatusCode::OK, Json(json!({}))).into_response() 444 } 445 Err(e) => { 446 error!("DB error updating handle: {:?}", e); 447 ( 448 StatusCode::INTERNAL_SERVER_ERROR, 449 Json(json!({"error": "InternalError"})), 450 ) 451 .into_response() 452 } 453 } 454}