this repo has no description
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::State,
6 http::StatusCode,
7 response::{IntoResponse, Response},
8};
9use chrono::Utc;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(rename_all = "camelCase")]
15pub struct VerificationMethod {
16 pub id: String,
17 #[serde(rename = "type")]
18 pub method_type: String,
19 pub public_key_multibase: String,
20}
21
22#[derive(Deserialize)]
23#[serde(rename_all = "camelCase")]
24pub struct UpdateDidDocumentInput {
25 pub verification_methods: Option<Vec<VerificationMethod>>,
26 pub also_known_as: Option<Vec<String>>,
27 pub service_endpoint: Option<String>,
28}
29
30#[derive(Serialize)]
31#[serde(rename_all = "camelCase")]
32pub struct UpdateDidDocumentOutput {
33 pub success: bool,
34 pub did_document: serde_json::Value,
35}
36
37pub async fn update_did_document(
38 State(state): State<AppState>,
39 headers: axum::http::HeaderMap,
40 Json(input): Json<UpdateDidDocumentInput>,
41) -> Response {
42 let extracted = match crate::auth::extract_auth_token_from_header(
43 headers.get("Authorization").and_then(|h| h.to_str().ok()),
44 ) {
45 Some(t) => t,
46 None => return ApiError::AuthenticationRequired.into_response(),
47 };
48 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
49 let http_uri = format!(
50 "https://{}/xrpc/_account.updateDidDocument",
51 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
52 );
53 let auth_user = match crate::auth::validate_token_with_dpop(
54 &state.db,
55 &extracted.token,
56 extracted.is_dpop,
57 dpop_proof,
58 "POST",
59 &http_uri,
60 true,
61 )
62 .await
63 {
64 Ok(user) => user,
65 Err(e) => return ApiError::from(e).into_response(),
66 };
67
68 if !auth_user.did.starts_with("did:web:") {
69 return (
70 StatusCode::BAD_REQUEST,
71 Json(json!({
72 "error": "InvalidRequest",
73 "message": "DID document updates are only available for did:web accounts"
74 })),
75 )
76 .into_response();
77 }
78
79 let user = match sqlx::query!(
80 "SELECT id, handle, deactivated_at FROM users WHERE did = $1",
81 auth_user.did
82 )
83 .fetch_optional(&state.db)
84 .await
85 {
86 Ok(Some(row)) => row,
87 Ok(None) => return ApiError::AccountNotFound.into_response(),
88 Err(e) => {
89 tracing::error!("DB error getting user: {:?}", e);
90 return ApiError::InternalError.into_response();
91 }
92 };
93
94 if user.deactivated_at.is_some() {
95 return ApiError::AccountDeactivated.into_response();
96 }
97
98 if let Some(ref methods) = input.verification_methods {
99 if methods.is_empty() {
100 return ApiError::InvalidRequest("verification_methods cannot be empty".into())
101 .into_response();
102 }
103 for method in methods {
104 if method.id.is_empty() {
105 return ApiError::InvalidRequest("verification method id is required".into())
106 .into_response();
107 }
108 if method.method_type != "Multikey" {
109 return ApiError::InvalidRequest(
110 "verification method type must be 'Multikey'".into(),
111 )
112 .into_response();
113 }
114 if !method.public_key_multibase.starts_with('z') {
115 return ApiError::InvalidRequest(
116 "publicKeyMultibase must start with 'z' (base58btc)".into(),
117 )
118 .into_response();
119 }
120 if method.public_key_multibase.len() < 40 {
121 return ApiError::InvalidRequest(
122 "publicKeyMultibase appears too short for a valid key".into(),
123 )
124 .into_response();
125 }
126 }
127 }
128
129 if let Some(ref handles) = input.also_known_as {
130 for handle in handles {
131 if !handle.starts_with("at://") {
132 return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into())
133 .into_response();
134 }
135 }
136 }
137
138 if let Some(ref endpoint) = input.service_endpoint {
139 let endpoint = endpoint.trim();
140 if !endpoint.starts_with("https://") {
141 return ApiError::InvalidRequest("serviceEndpoint must start with https://".into())
142 .into_response();
143 }
144 }
145
146 let verification_methods_json = input
147 .verification_methods
148 .as_ref()
149 .map(|v| serde_json::to_value(v).unwrap_or_default());
150
151 let also_known_as: Option<Vec<String>> = input.also_known_as.clone();
152
153 let now = Utc::now();
154
155 let upsert_result = sqlx::query!(
156 r#"
157 INSERT INTO did_web_overrides (user_id, verification_methods, also_known_as, updated_at)
158 VALUES ($1, COALESCE($2, '[]'::jsonb), COALESCE($3, '{}'::text[]), $4)
159 ON CONFLICT (user_id) DO UPDATE SET
160 verification_methods = CASE WHEN $2 IS NOT NULL THEN $2 ELSE did_web_overrides.verification_methods END,
161 also_known_as = CASE WHEN $3 IS NOT NULL THEN $3 ELSE did_web_overrides.also_known_as END,
162 updated_at = $4
163 "#,
164 user.id,
165 verification_methods_json,
166 also_known_as.as_deref(),
167 now
168 )
169 .execute(&state.db)
170 .await;
171
172 if let Err(e) = upsert_result {
173 tracing::error!("DB error upserting did_web_overrides: {:?}", e);
174 return ApiError::InternalError.into_response();
175 }
176
177 if let Some(ref endpoint) = input.service_endpoint {
178 let endpoint_clean = endpoint.trim().trim_end_matches('/');
179 let update_result = sqlx::query!(
180 "UPDATE users SET migrated_to_pds = $1, migrated_at = $2 WHERE did = $3",
181 endpoint_clean,
182 now,
183 auth_user.did
184 )
185 .execute(&state.db)
186 .await;
187
188 if let Err(e) = update_result {
189 tracing::error!("DB error updating service endpoint: {:?}", e);
190 return ApiError::InternalError.into_response();
191 }
192 }
193
194 let did_doc = build_did_document(&state.db, &auth_user.did).await;
195
196 tracing::info!("Updated DID document for {}", auth_user.did);
197
198 (
199 StatusCode::OK,
200 Json(UpdateDidDocumentOutput {
201 success: true,
202 did_document: did_doc,
203 }),
204 )
205 .into_response()
206}
207
208pub async fn get_did_document(
209 State(state): State<AppState>,
210 headers: axum::http::HeaderMap,
211) -> Response {
212 let extracted = match crate::auth::extract_auth_token_from_header(
213 headers.get("Authorization").and_then(|h| h.to_str().ok()),
214 ) {
215 Some(t) => t,
216 None => return ApiError::AuthenticationRequired.into_response(),
217 };
218 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
219 let http_uri = format!(
220 "https://{}/xrpc/_account.getDidDocument",
221 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
222 );
223 let auth_user = match crate::auth::validate_token_with_dpop(
224 &state.db,
225 &extracted.token,
226 extracted.is_dpop,
227 dpop_proof,
228 "GET",
229 &http_uri,
230 true,
231 )
232 .await
233 {
234 Ok(user) => user,
235 Err(e) => return ApiError::from(e).into_response(),
236 };
237
238 if !auth_user.did.starts_with("did:web:") {
239 return (
240 StatusCode::BAD_REQUEST,
241 Json(json!({
242 "error": "InvalidRequest",
243 "message": "This endpoint is only available for did:web accounts"
244 })),
245 )
246 .into_response();
247 }
248
249 let did_doc = build_did_document(&state.db, &auth_user.did).await;
250
251 (StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response()
252}
253
254async fn build_did_document(db: &sqlx::PgPool, did: &str) -> serde_json::Value {
255 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
256
257 let user = match sqlx::query!(
258 "SELECT id, handle, migrated_to_pds FROM users WHERE did = $1",
259 did
260 )
261 .fetch_optional(db)
262 .await
263 {
264 Ok(Some(row)) => row,
265 _ => {
266 return json!({
267 "error": "User not found"
268 });
269 }
270 };
271
272 let overrides = sqlx::query!(
273 "SELECT verification_methods, also_known_as FROM did_web_overrides WHERE user_id = $1",
274 user.id
275 )
276 .fetch_optional(db)
277 .await
278 .ok()
279 .flatten();
280
281 let service_endpoint = user
282 .migrated_to_pds
283 .unwrap_or_else(|| format!("https://{}", hostname));
284
285 if let Some((ovr, parsed)) = overrides.as_ref().and_then(|ovr| {
286 serde_json::from_value::<Vec<VerificationMethod>>(ovr.verification_methods.clone())
287 .ok()
288 .filter(|p| !p.is_empty())
289 .map(|p| (ovr, p))
290 }) {
291 let also_known_as = if !ovr.also_known_as.is_empty() {
292 ovr.also_known_as.clone()
293 } else {
294 vec![format!("at://{}", user.handle)]
295 };
296 return json!({
297 "@context": [
298 "https://www.w3.org/ns/did/v1",
299 "https://w3id.org/security/multikey/v1",
300 "https://w3id.org/security/suites/secp256k1-2019/v1"
301 ],
302 "id": did,
303 "alsoKnownAs": also_known_as,
304 "verificationMethod": parsed.iter().map(|m| json!({
305 "id": format!("{}{}", did, if m.id.starts_with('#') { m.id.clone() } else { format!("#{}", m.id) }),
306 "type": m.method_type,
307 "controller": did,
308 "publicKeyMultibase": m.public_key_multibase
309 })).collect::<Vec<_>>(),
310 "service": [{
311 "id": "#atproto_pds",
312 "type": "AtprotoPersonalDataServer",
313 "serviceEndpoint": service_endpoint
314 }]
315 });
316 }
317
318 let key_row = sqlx::query!(
319 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
320 user.id
321 )
322 .fetch_optional(db)
323 .await;
324
325 let public_key_multibase = match key_row {
326 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
327 Ok(key_bytes) => crate::api::identity::did::get_public_key_multibase(&key_bytes)
328 .unwrap_or_else(|_| "error".to_string()),
329 Err(_) => "error".to_string(),
330 },
331 _ => "error".to_string(),
332 };
333
334 let also_known_as = if let Some(ref ovr) = overrides {
335 if !ovr.also_known_as.is_empty() {
336 ovr.also_known_as.clone()
337 } else {
338 vec![format!("at://{}", user.handle)]
339 }
340 } else {
341 vec![format!("at://{}", user.handle)]
342 };
343
344 json!({
345 "@context": [
346 "https://www.w3.org/ns/did/v1",
347 "https://w3id.org/security/multikey/v1",
348 "https://w3id.org/security/suites/secp256k1-2019/v1"
349 ],
350 "id": did,
351 "alsoKnownAs": also_known_as,
352 "verificationMethod": [{
353 "id": format!("{}#atproto", did),
354 "type": "Multikey",
355 "controller": did,
356 "publicKeyMultibase": public_key_multibase
357 }],
358 "service": [{
359 "id": "#atproto_pds",
360 "type": "AtprotoPersonalDataServer",
361 "serviceEndpoint": service_endpoint
362 }]
363 })
364}