this repo has no description
1use crate::api::ApiError;
2use crate::plc::signing_key_to_did_key;
3use crate::state::AppState;
4use axum::{
5 Json,
6 extract::{Path, Query, State},
7 http::{HeaderMap, StatusCode},
8 response::{IntoResponse, Response},
9};
10use base64::Engine;
11use k256::SecretKey;
12use k256::elliptic_curve::sec1::ToEncodedPoint;
13use reqwest;
14use serde::Deserialize;
15use serde_json::json;
16use tracing::{error, warn};
17
18#[derive(Deserialize)]
19pub struct ResolveHandleParams {
20 pub handle: String,
21}
22
23pub async fn resolve_handle(
24 State(state): State<AppState>,
25 Query(params): Query<ResolveHandleParams>,
26) -> Response {
27 let handle = params.handle.trim();
28 if handle.is_empty() {
29 return (
30 StatusCode::BAD_REQUEST,
31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
32 )
33 .into_response();
34 }
35 let cache_key = format!("handle:{}", handle);
36 if let Some(did) = state.cache.get(&cache_key).await {
37 return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
38 }
39 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
40 let suffix = format!(".{}", hostname);
41 let short_handle = if handle.ends_with(&suffix) {
42 handle.strip_suffix(&suffix).unwrap_or(handle)
43 } else {
44 handle
45 };
46 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
47 .fetch_optional(&state.db)
48 .await;
49 match user {
50 Ok(Some(row)) => {
51 let _ = state
52 .cache
53 .set(&cache_key, &row.did, std::time::Duration::from_secs(300))
54 .await;
55 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
56 }
57 Ok(None) => {
58 match crate::handle::resolve_handle(handle).await {
59 Ok(did) => {
60 let _ = state
61 .cache
62 .set(&cache_key, &did, std::time::Duration::from_secs(300))
63 .await;
64 (StatusCode::OK, Json(json!({ "did": did }))).into_response()
65 }
66 Err(_) => (
67 StatusCode::NOT_FOUND,
68 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
69 )
70 .into_response(),
71 }
72 }
73 Err(e) => {
74 error!("DB error resolving handle: {:?}", e);
75 (
76 StatusCode::INTERNAL_SERVER_ERROR,
77 Json(json!({"error": "InternalError"})),
78 )
79 .into_response()
80 }
81 }
82}
83
84pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
85 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
86 let public_key = secret_key.public_key();
87 let encoded = public_key.to_encoded_point(false);
88 let x = encoded.x().ok_or("Missing x coordinate")?;
89 let y = encoded.y().ok_or("Missing y coordinate")?;
90 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
91 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
92 Ok(json!({
93 "kty": "EC",
94 "crv": "secp256k1",
95 "x": x_b64,
96 "y": y_b64
97 }))
98}
99
100pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
101 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
102 // Kinda for local dev, encode hostname if it contains port
103 let did = if hostname.contains(':') {
104 format!("did:web:{}", hostname.replace(':', "%3A"))
105 } else {
106 format!("did:web:{}", hostname)
107 };
108 Json(json!({
109 "@context": ["https://www.w3.org/ns/did/v1"],
110 "id": did,
111 "service": [{
112 "id": "#atproto_pds",
113 "type": "AtprotoPersonalDataServer",
114 "serviceEndpoint": format!("https://{}", hostname)
115 }]
116 }))
117}
118
119pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
120 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
121 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
122 .fetch_optional(&state.db)
123 .await;
124 let (user_id, did) = match user {
125 Ok(Some(row)) => (row.id, row.did),
126 Ok(None) => {
127 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
128 }
129 Err(e) => {
130 error!("DB Error: {:?}", e);
131 return (
132 StatusCode::INTERNAL_SERVER_ERROR,
133 Json(json!({"error": "InternalError"})),
134 )
135 .into_response();
136 }
137 };
138 if !did.starts_with("did:web:") {
139 return (
140 StatusCode::NOT_FOUND,
141 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
142 )
143 .into_response();
144 }
145 let key_row = sqlx::query!(
146 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
147 user_id
148 )
149 .fetch_optional(&state.db)
150 .await;
151 let key_bytes: Vec<u8> = match key_row {
152 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
153 Ok(k) => k,
154 Err(_) => {
155 return (
156 StatusCode::INTERNAL_SERVER_ERROR,
157 Json(json!({"error": "InternalError"})),
158 )
159 .into_response();
160 }
161 },
162 _ => {
163 return (
164 StatusCode::INTERNAL_SERVER_ERROR,
165 Json(json!({"error": "InternalError"})),
166 )
167 .into_response();
168 }
169 };
170 let jwk = match get_jwk(&key_bytes) {
171 Ok(j) => j,
172 Err(e) => {
173 tracing::error!("Failed to generate JWK: {}", e);
174 return (
175 StatusCode::INTERNAL_SERVER_ERROR,
176 Json(json!({"error": "InternalError"})),
177 )
178 .into_response();
179 }
180 };
181 Json(json!({
182 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
183 "id": did,
184 "alsoKnownAs": [format!("at://{}", handle)],
185 "verificationMethod": [{
186 "id": format!("{}#atproto", did),
187 "type": "JsonWebKey2020",
188 "controller": did,
189 "publicKeyJwk": jwk
190 }],
191 "service": [{
192 "id": "#atproto_pds",
193 "type": "AtprotoPersonalDataServer",
194 "serviceEndpoint": format!("https://{}", hostname)
195 }]
196 })).into_response()
197}
198
199pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
200 let expected_prefix = if hostname.contains(':') {
201 format!("did:web:{}", hostname.replace(':', "%3A"))
202 } else {
203 format!("did:web:{}", hostname)
204 };
205 if did.starts_with(&expected_prefix) {
206 let suffix = &did[expected_prefix.len()..];
207 let expected_suffix = format!(":u:{}", handle);
208 if suffix == expected_suffix {
209 Ok(())
210 } else {
211 Err(format!(
212 "Invalid DID path for this PDS. Expected {}",
213 expected_suffix
214 ))
215 }
216 } else {
217 let parts: Vec<&str> = did.split(':').collect();
218 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
219 return Err("Invalid did:web format".into());
220 }
221 let domain_segment = parts[2];
222 let domain = domain_segment.replace("%3A", ":");
223 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
224 "http"
225 } else {
226 "https"
227 };
228 let url = if parts.len() == 3 {
229 format!("{}://{}/.well-known/did.json", scheme, domain)
230 } else {
231 let path = parts[3..].join("/");
232 format!("{}://{}/{}/did.json", scheme, domain, path)
233 };
234 let client = reqwest::Client::builder()
235 .timeout(std::time::Duration::from_secs(5))
236 .build()
237 .map_err(|e| format!("Failed to create client: {}", e))?;
238 let resp = client
239 .get(&url)
240 .send()
241 .await
242 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
243 if !resp.status().is_success() {
244 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
245 }
246 let doc: serde_json::Value = resp
247 .json()
248 .await
249 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
250 let services = doc["service"]
251 .as_array()
252 .ok_or("No services found in DID doc")?;
253 let pds_endpoint = format!("https://{}", hostname);
254 let has_valid_service = services.iter().any(|s| {
255 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
256 });
257 if has_valid_service {
258 Ok(())
259 } else {
260 Err(format!(
261 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
262 pds_endpoint
263 ))
264 }
265 }
266}
267
268#[derive(serde::Serialize)]
269#[serde(rename_all = "camelCase")]
270pub struct GetRecommendedDidCredentialsOutput {
271 pub rotation_keys: Vec<String>,
272 pub also_known_as: Vec<String>,
273 pub verification_methods: VerificationMethods,
274 pub services: Services,
275}
276
277#[derive(serde::Serialize)]
278#[serde(rename_all = "camelCase")]
279pub struct VerificationMethods {
280 pub atproto: String,
281}
282
283#[derive(serde::Serialize)]
284#[serde(rename_all = "camelCase")]
285pub struct Services {
286 pub atproto_pds: AtprotoPds,
287}
288
289#[derive(serde::Serialize)]
290#[serde(rename_all = "camelCase")]
291pub struct AtprotoPds {
292 #[serde(rename = "type")]
293 pub service_type: String,
294 pub endpoint: String,
295}
296
297pub async fn get_recommended_did_credentials(
298 State(state): State<AppState>,
299 headers: axum::http::HeaderMap,
300) -> Response {
301 let token = match crate::auth::extract_bearer_token_from_header(
302 headers.get("Authorization").and_then(|h| h.to_str().ok()),
303 ) {
304 Some(t) => t,
305 None => {
306 return (
307 StatusCode::UNAUTHORIZED,
308 Json(json!({"error": "AuthenticationRequired"})),
309 )
310 .into_response();
311 }
312 };
313 let auth_user = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
314 Ok(user) => user,
315 Err(e) => return ApiError::from(e).into_response(),
316 };
317 let user = match sqlx::query!(
318 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
319 auth_user.did
320 )
321 .fetch_optional(&state.db)
322 .await
323 {
324 Ok(Some(row)) => row,
325 _ => return ApiError::InternalError.into_response(),
326 };
327 let key_bytes = match auth_user.key_bytes {
328 Some(kb) => kb,
329 None => {
330 return ApiError::AuthenticationFailedMsg(
331 "OAuth tokens cannot get DID credentials".into(),
332 )
333 .into_response();
334 }
335 };
336 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
337 let pds_endpoint = format!("https://{}", hostname);
338 let full_handle = if user.handle.contains('.') {
339 user.handle.clone()
340 } else {
341 format!("{}.{}", user.handle, hostname)
342 };
343 let signing_key = match k256::ecdsa::SigningKey::from_slice(&key_bytes) {
344 Ok(k) => k,
345 Err(_) => return ApiError::InternalError.into_response(),
346 };
347 let did_key = signing_key_to_did_key(&signing_key);
348 (
349 StatusCode::OK,
350 Json(GetRecommendedDidCredentialsOutput {
351 rotation_keys: vec![did_key.clone()],
352 also_known_as: vec![format!("at://{}", full_handle)],
353 verification_methods: VerificationMethods { atproto: did_key },
354 services: Services {
355 atproto_pds: AtprotoPds {
356 service_type: "AtprotoPersonalDataServer".to_string(),
357 endpoint: pds_endpoint,
358 },
359 },
360 }),
361 )
362 .into_response()
363}
364
365#[derive(Deserialize)]
366pub struct UpdateHandleInput {
367 pub handle: String,
368}
369
370pub async fn update_handle(
371 State(state): State<AppState>,
372 headers: axum::http::HeaderMap,
373 Json(input): Json<UpdateHandleInput>,
374) -> Response {
375 let token = match crate::auth::extract_bearer_token_from_header(
376 headers.get("Authorization").and_then(|h| h.to_str().ok()),
377 ) {
378 Some(t) => t,
379 None => return ApiError::AuthenticationRequired.into_response(),
380 };
381 let did = match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
382 Ok(user) => user.did,
383 Err(e) => return ApiError::from(e).into_response(),
384 };
385 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
386 .fetch_optional(&state.db)
387 .await
388 {
389 Ok(Some(id)) => id,
390 _ => return ApiError::InternalError.into_response(),
391 };
392 let new_handle = input.handle.trim();
393 if new_handle.is_empty() {
394 return ApiError::InvalidRequest("handle is required".into()).into_response();
395 }
396 if !new_handle
397 .chars()
398 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
399 {
400 return (
401 StatusCode::BAD_REQUEST,
402 Json(
403 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
404 ),
405 )
406 .into_response();
407 }
408 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
409 let is_service_domain = crate::handle::is_service_domain_handle(new_handle, &hostname);
410 let (handle_to_store, full_handle) = if is_service_domain {
411 let suffix = format!(".{}", hostname);
412 let short_handle = if new_handle.ends_with(&suffix) {
413 new_handle.strip_suffix(&suffix).unwrap_or(new_handle)
414 } else {
415 new_handle
416 };
417 (short_handle.to_string(), format!("{}.{}", short_handle, hostname))
418 } else {
419 match crate::handle::verify_handle_ownership(new_handle, &did).await {
420 Ok(()) => {}
421 Err(crate::handle::HandleResolutionError::NotFound) => {
422 return (
423 StatusCode::BAD_REQUEST,
424 Json(json!({
425 "error": "HandleNotAvailable",
426 "message": "Handle verification failed. Please set up DNS TXT record at _atproto.{} or serve your DID at https://{}/.well-known/atproto-did",
427 "handle": new_handle
428 })),
429 )
430 .into_response();
431 }
432 Err(crate::handle::HandleResolutionError::DidMismatch { expected, actual }) => {
433 return (
434 StatusCode::BAD_REQUEST,
435 Json(json!({
436 "error": "HandleNotAvailable",
437 "message": format!("Handle points to different DID. Expected {}, got {}", expected, actual)
438 })),
439 )
440 .into_response();
441 }
442 Err(e) => {
443 warn!("Handle verification failed: {}", e);
444 return (
445 StatusCode::BAD_REQUEST,
446 Json(json!({
447 "error": "HandleNotAvailable",
448 "message": format!("Handle verification failed: {}", e)
449 })),
450 )
451 .into_response();
452 }
453 }
454 (new_handle.to_string(), new_handle.to_string())
455 };
456 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
457 .fetch_optional(&state.db)
458 .await
459 .ok()
460 .flatten();
461 let existing = sqlx::query!(
462 "SELECT id FROM users WHERE handle = $1 AND id != $2",
463 handle_to_store,
464 user_id
465 )
466 .fetch_optional(&state.db)
467 .await;
468 if let Ok(Some(_)) = existing {
469 return (
470 StatusCode::BAD_REQUEST,
471 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
472 )
473 .into_response();
474 }
475 let result = sqlx::query!(
476 "UPDATE users SET handle = $1 WHERE id = $2",
477 handle_to_store,
478 user_id
479 )
480 .execute(&state.db)
481 .await;
482 match result {
483 Ok(_) => {
484 if let Some(old) = old_handle {
485 let _ = state.cache.delete(&format!("handle:{}", old)).await;
486 }
487 let _ = state
488 .cache
489 .delete(&format!("handle:{}", handle_to_store))
490 .await;
491 let _ = state.cache.delete(&format!("handle:{}", full_handle)).await;
492 if let Err(e) =
493 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle))
494 .await
495 {
496 warn!("Failed to sequence identity event for handle update: {}", e);
497 }
498 if let Err(e) = update_plc_handle(&state, &did, &full_handle).await {
499 warn!("Failed to update PLC handle: {}", e);
500 }
501 (StatusCode::OK, Json(json!({}))).into_response()
502 }
503 Err(e) => {
504 error!("DB error updating handle: {:?}", e);
505 (
506 StatusCode::INTERNAL_SERVER_ERROR,
507 Json(json!({"error": "InternalError"})),
508 )
509 .into_response()
510 }
511 }
512}
513
514async fn update_plc_handle(
515 state: &AppState,
516 did: &str,
517 new_handle: &str,
518) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
519 if !did.starts_with("did:plc:") {
520 return Ok(());
521 }
522 let user_row = sqlx::query!(
523 r#"SELECT u.id, uk.key_bytes, uk.encryption_version
524 FROM users u
525 JOIN user_keys uk ON u.id = uk.user_id
526 WHERE u.did = $1"#,
527 did
528 )
529 .fetch_optional(&state.db)
530 .await?;
531 let user_row = match user_row {
532 Some(r) => r,
533 None => return Ok(()),
534 };
535 let key_bytes = crate::config::decrypt_key(&user_row.key_bytes, user_row.encryption_version)?;
536 let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes)?;
537 let plc_client = crate::plc::PlcClient::new(None);
538 let last_op = plc_client.get_last_op(did).await?;
539 let new_also_known_as = vec![format!("at://{}", new_handle)];
540 let update_op = crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?;
541 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?;
542 plc_client.send_operation(did, &signed_op).await?;
543 Ok(())
544}
545
546pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
547 let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
548 Some(h) => h,
549 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
550 };
551 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
552 let suffix = format!(".{}", hostname);
553 let handle = host.split(':').next().unwrap_or(host);
554 let short_handle = if handle.ends_with(&suffix) {
555 handle.strip_suffix(&suffix).unwrap_or(handle)
556 } else {
557 return (StatusCode::NOT_FOUND, "Handle not found").into_response();
558 };
559 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
560 .fetch_optional(&state.db)
561 .await;
562 match user {
563 Ok(Some(row)) => row.did.into_response(),
564 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(),
565 Err(e) => {
566 error!("DB error in well-known atproto-did: {:?}", e);
567 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
568 }
569 }
570}