this repo has no description
1use crate::api::ApiError;
2use crate::plc::signing_key_to_did_key;
3use crate::state::AppState;
4use axum::{
5 Json,
6 extract::{Path, Query, State},
7 http::{HeaderMap, StatusCode},
8 response::{IntoResponse, Response},
9};
10use base64::Engine;
11use k256::SecretKey;
12use k256::elliptic_curve::sec1::ToEncodedPoint;
13use reqwest;
14use serde::Deserialize;
15use serde_json::json;
16use tracing::{error, warn};
17
18#[derive(Deserialize)]
19pub struct ResolveHandleParams {
20 pub handle: String,
21}
22
23pub async fn resolve_handle(
24 State(state): State<AppState>,
25 Query(params): Query<ResolveHandleParams>,
26) -> Response {
27 let handle = params.handle.trim();
28 if handle.is_empty() {
29 return (
30 StatusCode::BAD_REQUEST,
31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
32 )
33 .into_response();
34 }
35 let cache_key = format!("handle:{}", handle);
36 if let Some(did) = state.cache.get(&cache_key).await {
37 return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
38 }
39 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
40 let suffix = format!(".{}", hostname);
41 let short_handle = if handle.ends_with(&suffix) {
42 handle.strip_suffix(&suffix).unwrap_or(handle)
43 } else {
44 handle
45 };
46 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
47 .fetch_optional(&state.db)
48 .await;
49 match user {
50 Ok(Some(row)) => {
51 let _ = state
52 .cache
53 .set(&cache_key, &row.did, std::time::Duration::from_secs(300))
54 .await;
55 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
56 }
57 Ok(None) => match crate::handle::resolve_handle(handle).await {
58 Ok(did) => {
59 let _ = state
60 .cache
61 .set(&cache_key, &did, std::time::Duration::from_secs(300))
62 .await;
63 (StatusCode::OK, Json(json!({ "did": did }))).into_response()
64 }
65 Err(_) => (
66 StatusCode::NOT_FOUND,
67 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
68 )
69 .into_response(),
70 },
71 Err(e) => {
72 error!("DB error resolving handle: {:?}", e);
73 (
74 StatusCode::INTERNAL_SERVER_ERROR,
75 Json(json!({"error": "InternalError"})),
76 )
77 .into_response()
78 }
79 }
80}
81
82pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
83 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
84 let public_key = secret_key.public_key();
85 let encoded = public_key.to_encoded_point(false);
86 let x = encoded.x().ok_or("Missing x coordinate")?;
87 let y = encoded.y().ok_or("Missing y coordinate")?;
88 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
89 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
90 Ok(json!({
91 "kty": "EC",
92 "crv": "secp256k1",
93 "x": x_b64,
94 "y": y_b64
95 }))
96}
97
98pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
99 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
100 // Kinda for local dev, encode hostname if it contains port
101 let did = if hostname.contains(':') {
102 format!("did:web:{}", hostname.replace(':', "%3A"))
103 } else {
104 format!("did:web:{}", hostname)
105 };
106 Json(json!({
107 "@context": ["https://www.w3.org/ns/did/v1"],
108 "id": did,
109 "service": [{
110 "id": "#atproto_pds",
111 "type": "AtprotoPersonalDataServer",
112 "serviceEndpoint": format!("https://{}", hostname)
113 }]
114 }))
115}
116
117pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
118 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
119 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
120 .fetch_optional(&state.db)
121 .await;
122 let (user_id, did) = match user {
123 Ok(Some(row)) => (row.id, row.did),
124 Ok(None) => {
125 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
126 }
127 Err(e) => {
128 error!("DB Error: {:?}", e);
129 return (
130 StatusCode::INTERNAL_SERVER_ERROR,
131 Json(json!({"error": "InternalError"})),
132 )
133 .into_response();
134 }
135 };
136 if !did.starts_with("did:web:") {
137 return (
138 StatusCode::NOT_FOUND,
139 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
140 )
141 .into_response();
142 }
143 let key_row = sqlx::query!(
144 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
145 user_id
146 )
147 .fetch_optional(&state.db)
148 .await;
149 let key_bytes: Vec<u8> = match key_row {
150 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
151 Ok(k) => k,
152 Err(_) => {
153 return (
154 StatusCode::INTERNAL_SERVER_ERROR,
155 Json(json!({"error": "InternalError"})),
156 )
157 .into_response();
158 }
159 },
160 _ => {
161 return (
162 StatusCode::INTERNAL_SERVER_ERROR,
163 Json(json!({"error": "InternalError"})),
164 )
165 .into_response();
166 }
167 };
168 let jwk = match get_jwk(&key_bytes) {
169 Ok(j) => j,
170 Err(e) => {
171 tracing::error!("Failed to generate JWK: {}", e);
172 return (
173 StatusCode::INTERNAL_SERVER_ERROR,
174 Json(json!({"error": "InternalError"})),
175 )
176 .into_response();
177 }
178 };
179 Json(json!({
180 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
181 "id": did,
182 "alsoKnownAs": [format!("at://{}", handle)],
183 "verificationMethod": [{
184 "id": format!("{}#atproto", did),
185 "type": "JsonWebKey2020",
186 "controller": did,
187 "publicKeyJwk": jwk
188 }],
189 "service": [{
190 "id": "#atproto_pds",
191 "type": "AtprotoPersonalDataServer",
192 "serviceEndpoint": format!("https://{}", hostname)
193 }]
194 })).into_response()
195}
196
197pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
198 let expected_prefix = if hostname.contains(':') {
199 format!("did:web:{}", hostname.replace(':', "%3A"))
200 } else {
201 format!("did:web:{}", hostname)
202 };
203 if did.starts_with(&expected_prefix) {
204 let suffix = &did[expected_prefix.len()..];
205 let expected_suffix = format!(":u:{}", handle);
206 if suffix == expected_suffix {
207 Ok(())
208 } else {
209 Err(format!(
210 "Invalid DID path for this PDS. Expected {}",
211 expected_suffix
212 ))
213 }
214 } else {
215 let parts: Vec<&str> = did.split(':').collect();
216 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
217 return Err("Invalid did:web format".into());
218 }
219 let domain_segment = parts[2];
220 let domain = domain_segment.replace("%3A", ":");
221 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
222 "http"
223 } else {
224 "https"
225 };
226 let url = if parts.len() == 3 {
227 format!("{}://{}/.well-known/did.json", scheme, domain)
228 } else {
229 let path = parts[3..].join("/");
230 format!("{}://{}/{}/did.json", scheme, domain, path)
231 };
232 let client = reqwest::Client::builder()
233 .timeout(std::time::Duration::from_secs(5))
234 .build()
235 .map_err(|e| format!("Failed to create client: {}", e))?;
236 let resp = client
237 .get(&url)
238 .send()
239 .await
240 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
241 if !resp.status().is_success() {
242 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
243 }
244 let doc: serde_json::Value = resp
245 .json()
246 .await
247 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
248 let services = doc["service"]
249 .as_array()
250 .ok_or("No services found in DID doc")?;
251 let pds_endpoint = format!("https://{}", hostname);
252 let has_valid_service = services.iter().any(|s| {
253 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
254 });
255 if has_valid_service {
256 Ok(())
257 } else {
258 Err(format!(
259 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
260 pds_endpoint
261 ))
262 }
263 }
264}
265
266#[derive(serde::Serialize)]
267#[serde(rename_all = "camelCase")]
268pub struct GetRecommendedDidCredentialsOutput {
269 pub rotation_keys: Vec<String>,
270 pub also_known_as: Vec<String>,
271 pub verification_methods: VerificationMethods,
272 pub services: Services,
273}
274
275#[derive(serde::Serialize)]
276#[serde(rename_all = "camelCase")]
277pub struct VerificationMethods {
278 pub atproto: String,
279}
280
281#[derive(serde::Serialize)]
282#[serde(rename_all = "camelCase")]
283pub struct Services {
284 pub atproto_pds: AtprotoPds,
285}
286
287#[derive(serde::Serialize)]
288#[serde(rename_all = "camelCase")]
289pub struct AtprotoPds {
290 #[serde(rename = "type")]
291 pub service_type: String,
292 pub endpoint: String,
293}
294
295pub async fn get_recommended_did_credentials(
296 State(state): State<AppState>,
297 headers: axum::http::HeaderMap,
298) -> Response {
299 let token = match crate::auth::extract_bearer_token_from_header(
300 headers.get("Authorization").and_then(|h| h.to_str().ok()),
301 ) {
302 Some(t) => t,
303 None => {
304 return (
305 StatusCode::UNAUTHORIZED,
306 Json(json!({"error": "AuthenticationRequired"})),
307 )
308 .into_response();
309 }
310 };
311 let auth_user =
312 match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
313 Ok(user) => user,
314 Err(e) => return ApiError::from(e).into_response(),
315 };
316 let user = match sqlx::query!(
317 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
318 auth_user.did
319 )
320 .fetch_optional(&state.db)
321 .await
322 {
323 Ok(Some(row)) => row,
324 _ => return ApiError::InternalError.into_response(),
325 };
326 let key_bytes = match auth_user.key_bytes {
327 Some(kb) => kb,
328 None => {
329 return ApiError::AuthenticationFailedMsg(
330 "OAuth tokens cannot get DID credentials".into(),
331 )
332 .into_response();
333 }
334 };
335 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
336 let pds_endpoint = format!("https://{}", hostname);
337 let full_handle = if user.handle.contains('.') {
338 user.handle.clone()
339 } else {
340 format!("{}.{}", user.handle, hostname)
341 };
342 let signing_key = match k256::ecdsa::SigningKey::from_slice(&key_bytes) {
343 Ok(k) => k,
344 Err(_) => return ApiError::InternalError.into_response(),
345 };
346 let did_key = signing_key_to_did_key(&signing_key);
347 (
348 StatusCode::OK,
349 Json(GetRecommendedDidCredentialsOutput {
350 rotation_keys: vec![did_key.clone()],
351 also_known_as: vec![format!("at://{}", full_handle)],
352 verification_methods: VerificationMethods { atproto: did_key },
353 services: Services {
354 atproto_pds: AtprotoPds {
355 service_type: "AtprotoPersonalDataServer".to_string(),
356 endpoint: pds_endpoint,
357 },
358 },
359 }),
360 )
361 .into_response()
362}
363
364#[derive(Deserialize)]
365pub struct UpdateHandleInput {
366 pub handle: String,
367}
368
369pub async fn update_handle(
370 State(state): State<AppState>,
371 headers: axum::http::HeaderMap,
372 Json(input): Json<UpdateHandleInput>,
373) -> Response {
374 let token = match crate::auth::extract_bearer_token_from_header(
375 headers.get("Authorization").and_then(|h| h.to_str().ok()),
376 ) {
377 Some(t) => t,
378 None => return ApiError::AuthenticationRequired.into_response(),
379 };
380 let auth_user =
381 match crate::auth::validate_bearer_token_allow_deactivated(&state.db, &token).await {
382 Ok(user) => user,
383 Err(e) => return ApiError::from(e).into_response(),
384 };
385 if let Err(e) = crate::auth::scope_check::check_identity_scope(
386 auth_user.is_oauth,
387 auth_user.scope.as_deref(),
388 crate::oauth::scopes::IdentityAttr::Handle,
389 ) {
390 return e;
391 }
392 let did = auth_user.did;
393 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
394 .fetch_optional(&state.db)
395 .await
396 {
397 Ok(Some(id)) => id,
398 _ => return ApiError::InternalError.into_response(),
399 };
400 let new_handle = input.handle.trim();
401 if new_handle.is_empty() {
402 return ApiError::InvalidRequest("handle is required".into()).into_response();
403 }
404 if !new_handle
405 .chars()
406 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
407 {
408 return (
409 StatusCode::BAD_REQUEST,
410 Json(
411 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
412 ),
413 )
414 .into_response();
415 }
416 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
417 let is_service_domain = crate::handle::is_service_domain_handle(new_handle, &hostname);
418 let (handle_to_store, full_handle) = if is_service_domain {
419 let suffix = format!(".{}", hostname);
420 let short_handle = if new_handle.ends_with(&suffix) {
421 new_handle.strip_suffix(&suffix).unwrap_or(new_handle)
422 } else {
423 new_handle
424 };
425 (
426 short_handle.to_string(),
427 format!("{}.{}", short_handle, hostname),
428 )
429 } else {
430 match crate::handle::verify_handle_ownership(new_handle, &did).await {
431 Ok(()) => {}
432 Err(crate::handle::HandleResolutionError::NotFound) => {
433 return (
434 StatusCode::BAD_REQUEST,
435 Json(json!({
436 "error": "HandleNotAvailable",
437 "message": "Handle verification failed. Please set up DNS TXT record at _atproto.{} or serve your DID at https://{}/.well-known/atproto-did",
438 "handle": new_handle
439 })),
440 )
441 .into_response();
442 }
443 Err(crate::handle::HandleResolutionError::DidMismatch { expected, actual }) => {
444 return (
445 StatusCode::BAD_REQUEST,
446 Json(json!({
447 "error": "HandleNotAvailable",
448 "message": format!("Handle points to different DID. Expected {}, got {}", expected, actual)
449 })),
450 )
451 .into_response();
452 }
453 Err(e) => {
454 warn!("Handle verification failed: {}", e);
455 return (
456 StatusCode::BAD_REQUEST,
457 Json(json!({
458 "error": "HandleNotAvailable",
459 "message": format!("Handle verification failed: {}", e)
460 })),
461 )
462 .into_response();
463 }
464 }
465 (new_handle.to_string(), new_handle.to_string())
466 };
467 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
468 .fetch_optional(&state.db)
469 .await
470 .ok()
471 .flatten();
472 let existing = sqlx::query!(
473 "SELECT id FROM users WHERE handle = $1 AND id != $2",
474 handle_to_store,
475 user_id
476 )
477 .fetch_optional(&state.db)
478 .await;
479 if let Ok(Some(_)) = existing {
480 return (
481 StatusCode::BAD_REQUEST,
482 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
483 )
484 .into_response();
485 }
486 let result = sqlx::query!(
487 "UPDATE users SET handle = $1 WHERE id = $2",
488 handle_to_store,
489 user_id
490 )
491 .execute(&state.db)
492 .await;
493 match result {
494 Ok(_) => {
495 if let Some(old) = old_handle {
496 let _ = state.cache.delete(&format!("handle:{}", old)).await;
497 }
498 let _ = state
499 .cache
500 .delete(&format!("handle:{}", handle_to_store))
501 .await;
502 let _ = state.cache.delete(&format!("handle:{}", full_handle)).await;
503 if let Err(e) =
504 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle))
505 .await
506 {
507 warn!("Failed to sequence identity event for handle update: {}", e);
508 }
509 if let Err(e) = update_plc_handle(&state, &did, &full_handle).await {
510 warn!("Failed to update PLC handle: {}", e);
511 }
512 (StatusCode::OK, Json(json!({}))).into_response()
513 }
514 Err(e) => {
515 error!("DB error updating handle: {:?}", e);
516 (
517 StatusCode::INTERNAL_SERVER_ERROR,
518 Json(json!({"error": "InternalError"})),
519 )
520 .into_response()
521 }
522 }
523}
524
525async fn update_plc_handle(
526 state: &AppState,
527 did: &str,
528 new_handle: &str,
529) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
530 if !did.starts_with("did:plc:") {
531 return Ok(());
532 }
533 let user_row = sqlx::query!(
534 r#"SELECT u.id, uk.key_bytes, uk.encryption_version
535 FROM users u
536 JOIN user_keys uk ON u.id = uk.user_id
537 WHERE u.did = $1"#,
538 did
539 )
540 .fetch_optional(&state.db)
541 .await?;
542 let user_row = match user_row {
543 Some(r) => r,
544 None => return Ok(()),
545 };
546 let key_bytes = crate::config::decrypt_key(&user_row.key_bytes, user_row.encryption_version)?;
547 let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes)?;
548 let plc_client = crate::plc::PlcClient::new(None);
549 let last_op = plc_client.get_last_op(did).await?;
550 let new_also_known_as = vec![format!("at://{}", new_handle)];
551 let update_op =
552 crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?;
553 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?;
554 plc_client.send_operation(did, &signed_op).await?;
555 Ok(())
556}
557
558pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
559 let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
560 Some(h) => h,
561 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
562 };
563 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
564 let suffix = format!(".{}", hostname);
565 let handle = host.split(':').next().unwrap_or(host);
566 let short_handle = if handle.ends_with(&suffix) {
567 handle.strip_suffix(&suffix).unwrap_or(handle)
568 } else {
569 return (StatusCode::NOT_FOUND, "Handle not found").into_response();
570 };
571 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
572 .fetch_optional(&state.db)
573 .await;
574 match user {
575 Ok(Some(row)) => row.did.into_response(),
576 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(),
577 Err(e) => {
578 error!("DB error in well-known atproto-did: {:?}", e);
579 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
580 }
581 }
582}