This repository has no description.
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::State, 6 http::StatusCode, 7 response::{IntoResponse, Response}, 8}; 9use chrono::Utc; 10use serde::{Deserialize, Serialize}; 11use serde_json::json; 12 13#[derive(Debug, Clone, Serialize, Deserialize)] 14#[serde(rename_all = "camelCase")] 15pub struct VerificationMethod { 16 pub id: String, 17 #[serde(rename = "type")] 18 pub method_type: String, 19 pub public_key_multibase: String, 20} 21 22#[derive(Deserialize)] 23#[serde(rename_all = "camelCase")] 24pub struct UpdateDidDocumentInput { 25 pub verification_methods: Option<Vec<VerificationMethod>>, 26 pub also_known_as: Option<Vec<String>>, 27 pub service_endpoint: Option<String>, 28} 29 30#[derive(Serialize)] 31#[serde(rename_all = "camelCase")] 32pub struct UpdateDidDocumentOutput { 33 pub success: bool, 34 pub did_document: serde_json::Value, 35} 36 37pub async fn update_did_document( 38 State(state): State<AppState>, 39 headers: axum::http::HeaderMap, 40 Json(input): Json<UpdateDidDocumentInput>, 41) -> Response { 42 let extracted = match crate::auth::extract_auth_token_from_header( 43 headers.get("Authorization").and_then(|h| h.to_str().ok()), 44 ) { 45 Some(t) => t, 46 None => return ApiError::AuthenticationRequired.into_response(), 47 }; 48 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 49 let http_uri = format!( 50 "https://{}/xrpc/_account.updateDidDocument", 51 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 52 ); 53 let auth_user = match crate::auth::validate_token_with_dpop( 54 &state.db, 55 &extracted.token, 56 extracted.is_dpop, 57 dpop_proof, 58 "POST", 59 &http_uri, 60 true, 61 ) 62 .await 63 { 64 Ok(user) => user, 65 Err(e) => return ApiError::from(e).into_response(), 66 }; 67 68 if !auth_user.did.starts_with("did:web:") { 69 return ( 70 StatusCode::BAD_REQUEST, 71 Json(json!({ 72 "error": "InvalidRequest", 73 "message": "DID document updates are only available for 
did:web accounts" 74 })), 75 ) 76 .into_response(); 77 } 78 79 let user = match sqlx::query!( 80 "SELECT id, handle, deactivated_at FROM users WHERE did = $1", 81 auth_user.did 82 ) 83 .fetch_optional(&state.db) 84 .await 85 { 86 Ok(Some(row)) => row, 87 Ok(None) => return ApiError::AccountNotFound.into_response(), 88 Err(e) => { 89 tracing::error!("DB error getting user: {:?}", e); 90 return ApiError::InternalError.into_response(); 91 } 92 }; 93 94 if user.deactivated_at.is_some() { 95 return ApiError::AccountDeactivated.into_response(); 96 } 97 98 if let Some(ref methods) = input.verification_methods { 99 if methods.is_empty() { 100 return ApiError::InvalidRequest("verification_methods cannot be empty".into()) 101 .into_response(); 102 } 103 for method in methods { 104 if method.id.is_empty() { 105 return ApiError::InvalidRequest("verification method id is required".into()) 106 .into_response(); 107 } 108 if method.method_type != "Multikey" { 109 return ApiError::InvalidRequest( 110 "verification method type must be 'Multikey'".into(), 111 ) 112 .into_response(); 113 } 114 if !method.public_key_multibase.starts_with('z') { 115 return ApiError::InvalidRequest( 116 "publicKeyMultibase must start with 'z' (base58btc)".into(), 117 ) 118 .into_response(); 119 } 120 if method.public_key_multibase.len() < 40 { 121 return ApiError::InvalidRequest( 122 "publicKeyMultibase appears too short for a valid key".into(), 123 ) 124 .into_response(); 125 } 126 } 127 } 128 129 if let Some(ref handles) = input.also_known_as { 130 for handle in handles { 131 if !handle.starts_with("at://") { 132 return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into()) 133 .into_response(); 134 } 135 } 136 } 137 138 if let Some(ref endpoint) = input.service_endpoint { 139 let endpoint = endpoint.trim(); 140 if !endpoint.starts_with("https://") { 141 return ApiError::InvalidRequest("serviceEndpoint must start with https://".into()) 142 .into_response(); 143 } 144 } 145 146 let 
verification_methods_json = input 147 .verification_methods 148 .as_ref() 149 .map(|v| serde_json::to_value(v).unwrap_or_default()); 150 151 let also_known_as: Option<Vec<String>> = input.also_known_as.clone(); 152 153 let now = Utc::now(); 154 155 let upsert_result = sqlx::query!( 156 r#" 157 INSERT INTO did_web_overrides (user_id, verification_methods, also_known_as, updated_at) 158 VALUES ($1, COALESCE($2, '[]'::jsonb), COALESCE($3, '{}'::text[]), $4) 159 ON CONFLICT (user_id) DO UPDATE SET 160 verification_methods = CASE WHEN $2 IS NOT NULL THEN $2 ELSE did_web_overrides.verification_methods END, 161 also_known_as = CASE WHEN $3 IS NOT NULL THEN $3 ELSE did_web_overrides.also_known_as END, 162 updated_at = $4 163 "#, 164 user.id, 165 verification_methods_json, 166 also_known_as.as_deref(), 167 now 168 ) 169 .execute(&state.db) 170 .await; 171 172 if let Err(e) = upsert_result { 173 tracing::error!("DB error upserting did_web_overrides: {:?}", e); 174 return ApiError::InternalError.into_response(); 175 } 176 177 if let Some(ref endpoint) = input.service_endpoint { 178 let endpoint_clean = endpoint.trim().trim_end_matches('/'); 179 let update_result = sqlx::query!( 180 "UPDATE users SET migrated_to_pds = $1, migrated_at = $2 WHERE did = $3", 181 endpoint_clean, 182 now, 183 auth_user.did 184 ) 185 .execute(&state.db) 186 .await; 187 188 if let Err(e) = update_result { 189 tracing::error!("DB error updating service endpoint: {:?}", e); 190 return ApiError::InternalError.into_response(); 191 } 192 } 193 194 let did_doc = build_did_document(&state.db, &auth_user.did).await; 195 196 tracing::info!("Updated DID document for {}", auth_user.did); 197 198 ( 199 StatusCode::OK, 200 Json(UpdateDidDocumentOutput { 201 success: true, 202 did_document: did_doc, 203 }), 204 ) 205 .into_response() 206} 207 208pub async fn get_did_document( 209 State(state): State<AppState>, 210 headers: axum::http::HeaderMap, 211) -> Response { 212 let extracted = match 
crate::auth::extract_auth_token_from_header( 213 headers.get("Authorization").and_then(|h| h.to_str().ok()), 214 ) { 215 Some(t) => t, 216 None => return ApiError::AuthenticationRequired.into_response(), 217 }; 218 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 219 let http_uri = format!( 220 "https://{}/xrpc/_account.getDidDocument", 221 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 222 ); 223 let auth_user = match crate::auth::validate_token_with_dpop( 224 &state.db, 225 &extracted.token, 226 extracted.is_dpop, 227 dpop_proof, 228 "GET", 229 &http_uri, 230 true, 231 ) 232 .await 233 { 234 Ok(user) => user, 235 Err(e) => return ApiError::from(e).into_response(), 236 }; 237 238 if !auth_user.did.starts_with("did:web:") { 239 return ( 240 StatusCode::BAD_REQUEST, 241 Json(json!({ 242 "error": "InvalidRequest", 243 "message": "This endpoint is only available for did:web accounts" 244 })), 245 ) 246 .into_response(); 247 } 248 249 let did_doc = build_did_document(&state.db, &auth_user.did).await; 250 251 (StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response() 252} 253 254async fn build_did_document(db: &sqlx::PgPool, did: &str) -> serde_json::Value { 255 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 256 257 let user = match sqlx::query!( 258 "SELECT id, handle, migrated_to_pds FROM users WHERE did = $1", 259 did 260 ) 261 .fetch_optional(db) 262 .await 263 { 264 Ok(Some(row)) => row, 265 _ => { 266 return json!({ 267 "error": "User not found" 268 }); 269 } 270 }; 271 272 let overrides = sqlx::query!( 273 "SELECT verification_methods, also_known_as FROM did_web_overrides WHERE user_id = $1", 274 user.id 275 ) 276 .fetch_optional(db) 277 .await 278 .ok() 279 .flatten(); 280 281 let service_endpoint = user 282 .migrated_to_pds 283 .unwrap_or_else(|| format!("https://{}", hostname)); 284 285 if let Some((ovr, parsed)) = overrides.as_ref().and_then(|ovr| { 286 
serde_json::from_value::<Vec<VerificationMethod>>(ovr.verification_methods.clone()) 287 .ok() 288 .filter(|p| !p.is_empty()) 289 .map(|p| (ovr, p)) 290 }) { 291 let also_known_as = if !ovr.also_known_as.is_empty() { 292 ovr.also_known_as.clone() 293 } else { 294 vec![format!("at://{}", user.handle)] 295 }; 296 return json!({ 297 "@context": [ 298 "https://www.w3.org/ns/did/v1", 299 "https://w3id.org/security/multikey/v1", 300 "https://w3id.org/security/suites/secp256k1-2019/v1" 301 ], 302 "id": did, 303 "alsoKnownAs": also_known_as, 304 "verificationMethod": parsed.iter().map(|m| json!({ 305 "id": format!("{}{}", did, if m.id.starts_with('#') { m.id.clone() } else { format!("#{}", m.id) }), 306 "type": m.method_type, 307 "controller": did, 308 "publicKeyMultibase": m.public_key_multibase 309 })).collect::<Vec<_>>(), 310 "service": [{ 311 "id": "#atproto_pds", 312 "type": "AtprotoPersonalDataServer", 313 "serviceEndpoint": service_endpoint 314 }] 315 }); 316 } 317 318 let key_row = sqlx::query!( 319 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 320 user.id 321 ) 322 .fetch_optional(db) 323 .await; 324 325 let public_key_multibase = match key_row { 326 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 327 Ok(key_bytes) => crate::api::identity::did::get_public_key_multibase(&key_bytes) 328 .unwrap_or_else(|_| "error".to_string()), 329 Err(_) => "error".to_string(), 330 }, 331 _ => "error".to_string(), 332 }; 333 334 let also_known_as = if let Some(ref ovr) = overrides { 335 if !ovr.also_known_as.is_empty() { 336 ovr.also_known_as.clone() 337 } else { 338 vec![format!("at://{}", user.handle)] 339 } 340 } else { 341 vec![format!("at://{}", user.handle)] 342 }; 343 344 json!({ 345 "@context": [ 346 "https://www.w3.org/ns/did/v1", 347 "https://w3id.org/security/multikey/v1", 348 "https://w3id.org/security/suites/secp256k1-2019/v1" 349 ], 350 "id": did, 351 "alsoKnownAs": also_known_as, 352 
"verificationMethod": [{ 353 "id": format!("{}#atproto", did), 354 "type": "Multikey", 355 "controller": did, 356 "publicKeyMultibase": public_key_multibase 357 }], 358 "service": [{ 359 "id": "#atproto_pds", 360 "type": "AtprotoPersonalDataServer", 361 "serviceEndpoint": service_endpoint 362 }] 363 }) 364}