// This repository has no description.
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::{Path, Query, State}, 6 http::StatusCode, 7 response::{IntoResponse, Response}, 8}; 9use base64::Engine; 10use k256::SecretKey; 11use k256::elliptic_curve::sec1::ToEncodedPoint; 12use reqwest; 13use serde::Deserialize; 14use serde_json::json; 15use tracing::error; 16 17#[derive(Deserialize)] 18pub struct ResolveHandleParams { 19 pub handle: String, 20} 21 22pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, 25) -> Response { 26 let handle = params.handle.trim(); 27 28 if handle.is_empty() { 29 return ( 30 StatusCode::BAD_REQUEST, 31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 32 ) 33 .into_response(); 34 } 35 36 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle) 37 .fetch_optional(&state.db) 38 .await; 39 40 match user { 41 Ok(Some(row)) => { 42 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response() 43 } 44 Ok(None) => ( 45 StatusCode::NOT_FOUND, 46 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 47 ) 48 .into_response(), 49 Err(e) => { 50 error!("DB error resolving handle: {:?}", e); 51 ( 52 StatusCode::INTERNAL_SERVER_ERROR, 53 Json(json!({"error": "InternalError"})), 54 ) 55 .into_response() 56 } 57 } 58} 59 60pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> { 61 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?; 62 let public_key = secret_key.public_key(); 63 let encoded = public_key.to_encoded_point(false); 64 let x = encoded.x().ok_or("Missing x coordinate")?; 65 let y = encoded.y().ok_or("Missing y coordinate")?; 66 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x); 67 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y); 68 69 Ok(json!({ 70 "kty": "EC", 71 "crv": "secp256k1", 72 "x": x_b64, 73 "y": y_b64 74 })) 75} 76 
77pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 78 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 79 // Kinda for local dev, encode hostname if it contains port 80 let did = if hostname.contains(':') { 81 format!("did:web:{}", hostname.replace(':', "%3A")) 82 } else { 83 format!("did:web:{}", hostname) 84 }; 85 86 Json(json!({ 87 "@context": ["https://www.w3.org/ns/did/v1"], 88 "id": did, 89 "service": [{ 90 "id": "#atproto_pds", 91 "type": "AtprotoPersonalDataServer", 92 "serviceEndpoint": format!("https://{}", hostname) 93 }] 94 })) 95} 96 97pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 98 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 99 100 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle) 101 .fetch_optional(&state.db) 102 .await; 103 104 let (user_id, did) = match user { 105 Ok(Some(row)) => (row.id, row.did), 106 Ok(None) => { 107 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 108 } 109 Err(e) => { 110 error!("DB Error: {:?}", e); 111 return ( 112 StatusCode::INTERNAL_SERVER_ERROR, 113 Json(json!({"error": "InternalError"})), 114 ) 115 .into_response(); 116 } 117 }; 118 119 if !did.starts_with("did:web:") { 120 return ( 121 StatusCode::NOT_FOUND, 122 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 123 ) 124 .into_response(); 125 } 126 127 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id) 128 .fetch_optional(&state.db) 129 .await; 130 131 let key_bytes: Vec<u8> = match key_row { 132 Ok(Some(row)) => { 133 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 134 Ok(k) => k, 135 Err(_) => { 136 return ( 137 StatusCode::INTERNAL_SERVER_ERROR, 138 Json(json!({"error": "InternalError"})), 139 ) 140 .into_response(); 141 } 
142 } 143 } 144 _ => { 145 return ( 146 StatusCode::INTERNAL_SERVER_ERROR, 147 Json(json!({"error": "InternalError"})), 148 ) 149 .into_response(); 150 } 151 }; 152 153 let jwk = match get_jwk(&key_bytes) { 154 Ok(j) => j, 155 Err(e) => { 156 tracing::error!("Failed to generate JWK: {}", e); 157 return ( 158 StatusCode::INTERNAL_SERVER_ERROR, 159 Json(json!({"error": "InternalError"})), 160 ) 161 .into_response(); 162 } 163 }; 164 165 Json(json!({ 166 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 167 "id": did, 168 "alsoKnownAs": [format!("at://{}", handle)], 169 "verificationMethod": [{ 170 "id": format!("{}#atproto", did), 171 "type": "JsonWebKey2020", 172 "controller": did, 173 "publicKeyJwk": jwk 174 }], 175 "service": [{ 176 "id": "#atproto_pds", 177 "type": "AtprotoPersonalDataServer", 178 "serviceEndpoint": format!("https://{}", hostname) 179 }] 180 })).into_response() 181} 182 183pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 184 let expected_prefix = if hostname.contains(':') { 185 format!("did:web:{}", hostname.replace(':', "%3A")) 186 } else { 187 format!("did:web:{}", hostname) 188 }; 189 190 if did.starts_with(&expected_prefix) { 191 let suffix = &did[expected_prefix.len()..]; 192 let expected_suffix = format!(":u:{}", handle); 193 if suffix == expected_suffix { 194 Ok(()) 195 } else { 196 Err(format!( 197 "Invalid DID path for this PDS. 
Expected {}", 198 expected_suffix 199 )) 200 } 201 } else { 202 let parts: Vec<&str> = did.split(':').collect(); 203 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 204 return Err("Invalid did:web format".into()); 205 } 206 207 let domain_segment = parts[2]; 208 let domain = domain_segment.replace("%3A", ":"); 209 210 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 211 "http" 212 } else { 213 "https" 214 }; 215 216 let url = if parts.len() == 3 { 217 format!("{}://{}/.well-known/did.json", scheme, domain) 218 } else { 219 let path = parts[3..].join("/"); 220 format!("{}://{}/{}/did.json", scheme, domain, path) 221 }; 222 223 let client = reqwest::Client::builder() 224 .timeout(std::time::Duration::from_secs(5)) 225 .build() 226 .map_err(|e| format!("Failed to create client: {}", e))?; 227 228 let resp = client 229 .get(&url) 230 .send() 231 .await 232 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 233 234 if !resp.status().is_success() { 235 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 236 } 237 238 let doc: serde_json::Value = resp 239 .json() 240 .await 241 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 242 243 let services = doc["service"] 244 .as_array() 245 .ok_or("No services found in DID doc")?; 246 247 let pds_endpoint = format!("https://{}", hostname); 248 249 let has_valid_service = services.iter().any(|s| { 250 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 251 }); 252 253 if has_valid_service { 254 Ok(()) 255 } else { 256 Err(format!( 257 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 258 pds_endpoint 259 )) 260 } 261 } 262} 263 264#[derive(serde::Serialize)] 265#[serde(rename_all = "camelCase")] 266pub struct GetRecommendedDidCredentialsOutput { 267 pub rotation_keys: Vec<String>, 268 pub also_known_as: Vec<String>, 269 pub verification_methods: VerificationMethods, 270 pub services: Services, 
271} 272 273#[derive(serde::Serialize)] 274#[serde(rename_all = "camelCase")] 275pub struct VerificationMethods { 276 pub atproto: String, 277} 278 279#[derive(serde::Serialize)] 280#[serde(rename_all = "camelCase")] 281pub struct Services { 282 pub atproto_pds: AtprotoPds, 283} 284 285#[derive(serde::Serialize)] 286#[serde(rename_all = "camelCase")] 287pub struct AtprotoPds { 288 #[serde(rename = "type")] 289 pub service_type: String, 290 pub endpoint: String, 291} 292 293pub async fn get_recommended_did_credentials( 294 State(state): State<AppState>, 295 headers: axum::http::HeaderMap, 296) -> Response { 297 let token = match crate::auth::extract_bearer_token_from_header( 298 headers.get("Authorization").and_then(|h| h.to_str().ok()) 299 ) { 300 Some(t) => t, 301 None => { 302 return ( 303 StatusCode::UNAUTHORIZED, 304 Json(json!({"error": "AuthenticationRequired"})), 305 ) 306 .into_response(); 307 } 308 }; 309 310 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 311 Ok(user) => user, 312 Err(e) => return ApiError::from(e).into_response(), 313 }; 314 315 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did) 316 .fetch_optional(&state.db) 317 .await 318 { 319 Ok(Some(row)) => row, 320 _ => return ApiError::InternalError.into_response(), 321 }; 322 323 let key_bytes = match auth_user.key_bytes { 324 Some(kb) => kb, 325 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(), 326 }; 327 328 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 329 let pds_endpoint = format!("https://{}", hostname); 330 331 let secret_key = match k256::SecretKey::from_slice(&key_bytes) { 332 Ok(k) => k, 333 Err(_) => return ApiError::InternalError.into_response(), 334 }; 335 336 let public_key = secret_key.public_key(); 337 let encoded = public_key.to_encoded_point(true); 338 let 
did_key = format!( 339 "did:key:zQ3sh{}", 340 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes()) 341 .chars() 342 .skip(1) 343 .collect::<String>() 344 ); 345 346 ( 347 StatusCode::OK, 348 Json(GetRecommendedDidCredentialsOutput { 349 rotation_keys: vec![did_key.clone()], 350 also_known_as: vec![format!("at://{}", user.handle)], 351 verification_methods: VerificationMethods { atproto: did_key }, 352 services: Services { 353 atproto_pds: AtprotoPds { 354 service_type: "AtprotoPersonalDataServer".to_string(), 355 endpoint: pds_endpoint, 356 }, 357 }, 358 }), 359 ) 360 .into_response() 361} 362 363#[derive(Deserialize)] 364pub struct UpdateHandleInput { 365 pub handle: String, 366} 367 368pub async fn update_handle( 369 State(state): State<AppState>, 370 headers: axum::http::HeaderMap, 371 Json(input): Json<UpdateHandleInput>, 372) -> Response { 373 let token = match crate::auth::extract_bearer_token_from_header( 374 headers.get("Authorization").and_then(|h| h.to_str().ok()) 375 ) { 376 Some(t) => t, 377 None => return ApiError::AuthenticationRequired.into_response(), 378 }; 379 380 let did = match crate::auth::validate_bearer_token(&state.db, &token).await { 381 Ok(user) => user.did, 382 Err(e) => return ApiError::from(e).into_response(), 383 }; 384 385 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 386 .fetch_optional(&state.db) 387 .await 388 { 389 Ok(Some(id)) => id, 390 _ => return ApiError::InternalError.into_response(), 391 }; 392 393 let new_handle = input.handle.trim(); 394 if new_handle.is_empty() { 395 return ApiError::InvalidRequest("handle is required".into()).into_response(); 396 } 397 398 if !new_handle 399 .chars() 400 .all(|c| c.is_ascii_alphanumeric() || c == '.' 
|| c == '-' || c == '_') 401 { 402 return ( 403 StatusCode::BAD_REQUEST, 404 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})), 405 ) 406 .into_response(); 407 } 408 409 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id) 410 .fetch_optional(&state.db) 411 .await; 412 413 if let Ok(Some(_)) = existing { 414 return ( 415 StatusCode::BAD_REQUEST, 416 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})), 417 ) 418 .into_response(); 419 } 420 421 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id) 422 .execute(&state.db) 423 .await; 424 425 match result { 426 Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(), 427 Err(e) => { 428 error!("DB error updating handle: {:?}", e); 429 ( 430 StatusCode::INTERNAL_SERVER_ERROR, 431 Json(json!({"error": "InternalError"})), 432 ) 433 .into_response() 434 } 435 } 436}