this repo has no description
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::{Path, Query, State},
6 http::{HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8};
9use base64::Engine;
10use k256::SecretKey;
11use k256::elliptic_curve::sec1::ToEncodedPoint;
12use reqwest;
13use serde::Deserialize;
14use serde_json::json;
15use tracing::{error, warn};
16
17#[derive(Deserialize)]
18pub struct ResolveHandleParams {
19 pub handle: String,
20}
21
22pub async fn resolve_handle(
23 State(state): State<AppState>,
24 Query(params): Query<ResolveHandleParams>,
25) -> Response {
26 let handle = params.handle.trim();
27 if handle.is_empty() {
28 return (
29 StatusCode::BAD_REQUEST,
30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
31 )
32 .into_response();
33 }
34 let cache_key = format!("handle:{}", handle);
35 if let Some(did) = state.cache.get(&cache_key).await {
36 return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
37 }
38 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
39 let suffix = format!(".{}", hostname);
40 let short_handle = if handle.ends_with(&suffix) {
41 handle.strip_suffix(&suffix).unwrap_or(handle)
42 } else {
43 handle
44 };
45 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
46 .fetch_optional(&state.db)
47 .await;
48 match user {
49 Ok(Some(row)) => {
50 let _ = state
51 .cache
52 .set(&cache_key, &row.did, std::time::Duration::from_secs(300))
53 .await;
54 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
55 }
56 Ok(None) => {
57 match crate::handle::resolve_handle(handle).await {
58 Ok(did) => {
59 let _ = state
60 .cache
61 .set(&cache_key, &did, std::time::Duration::from_secs(300))
62 .await;
63 (StatusCode::OK, Json(json!({ "did": did }))).into_response()
64 }
65 Err(_) => (
66 StatusCode::NOT_FOUND,
67 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
68 )
69 .into_response(),
70 }
71 }
72 Err(e) => {
73 error!("DB error resolving handle: {:?}", e);
74 (
75 StatusCode::INTERNAL_SERVER_ERROR,
76 Json(json!({"error": "InternalError"})),
77 )
78 .into_response()
79 }
80 }
81}
82
83pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
84 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
85 let public_key = secret_key.public_key();
86 let encoded = public_key.to_encoded_point(false);
87 let x = encoded.x().ok_or("Missing x coordinate")?;
88 let y = encoded.y().ok_or("Missing y coordinate")?;
89 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
90 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
91 Ok(json!({
92 "kty": "EC",
93 "crv": "secp256k1",
94 "x": x_b64,
95 "y": y_b64
96 }))
97}
98
99pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
100 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
101 // Kinda for local dev, encode hostname if it contains port
102 let did = if hostname.contains(':') {
103 format!("did:web:{}", hostname.replace(':', "%3A"))
104 } else {
105 format!("did:web:{}", hostname)
106 };
107 Json(json!({
108 "@context": ["https://www.w3.org/ns/did/v1"],
109 "id": did,
110 "service": [{
111 "id": "#atproto_pds",
112 "type": "AtprotoPersonalDataServer",
113 "serviceEndpoint": format!("https://{}", hostname)
114 }]
115 }))
116}
117
118pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
119 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
120 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
121 .fetch_optional(&state.db)
122 .await;
123 let (user_id, did) = match user {
124 Ok(Some(row)) => (row.id, row.did),
125 Ok(None) => {
126 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
127 }
128 Err(e) => {
129 error!("DB Error: {:?}", e);
130 return (
131 StatusCode::INTERNAL_SERVER_ERROR,
132 Json(json!({"error": "InternalError"})),
133 )
134 .into_response();
135 }
136 };
137 if !did.starts_with("did:web:") {
138 return (
139 StatusCode::NOT_FOUND,
140 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
141 )
142 .into_response();
143 }
144 let key_row = sqlx::query!(
145 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
146 user_id
147 )
148 .fetch_optional(&state.db)
149 .await;
150 let key_bytes: Vec<u8> = match key_row {
151 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
152 Ok(k) => k,
153 Err(_) => {
154 return (
155 StatusCode::INTERNAL_SERVER_ERROR,
156 Json(json!({"error": "InternalError"})),
157 )
158 .into_response();
159 }
160 },
161 _ => {
162 return (
163 StatusCode::INTERNAL_SERVER_ERROR,
164 Json(json!({"error": "InternalError"})),
165 )
166 .into_response();
167 }
168 };
169 let jwk = match get_jwk(&key_bytes) {
170 Ok(j) => j,
171 Err(e) => {
172 tracing::error!("Failed to generate JWK: {}", e);
173 return (
174 StatusCode::INTERNAL_SERVER_ERROR,
175 Json(json!({"error": "InternalError"})),
176 )
177 .into_response();
178 }
179 };
180 Json(json!({
181 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
182 "id": did,
183 "alsoKnownAs": [format!("at://{}", handle)],
184 "verificationMethod": [{
185 "id": format!("{}#atproto", did),
186 "type": "JsonWebKey2020",
187 "controller": did,
188 "publicKeyJwk": jwk
189 }],
190 "service": [{
191 "id": "#atproto_pds",
192 "type": "AtprotoPersonalDataServer",
193 "serviceEndpoint": format!("https://{}", hostname)
194 }]
195 })).into_response()
196}
197
198pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
199 let expected_prefix = if hostname.contains(':') {
200 format!("did:web:{}", hostname.replace(':', "%3A"))
201 } else {
202 format!("did:web:{}", hostname)
203 };
204 if did.starts_with(&expected_prefix) {
205 let suffix = &did[expected_prefix.len()..];
206 let expected_suffix = format!(":u:{}", handle);
207 if suffix == expected_suffix {
208 Ok(())
209 } else {
210 Err(format!(
211 "Invalid DID path for this PDS. Expected {}",
212 expected_suffix
213 ))
214 }
215 } else {
216 let parts: Vec<&str> = did.split(':').collect();
217 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
218 return Err("Invalid did:web format".into());
219 }
220 let domain_segment = parts[2];
221 let domain = domain_segment.replace("%3A", ":");
222 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
223 "http"
224 } else {
225 "https"
226 };
227 let url = if parts.len() == 3 {
228 format!("{}://{}/.well-known/did.json", scheme, domain)
229 } else {
230 let path = parts[3..].join("/");
231 format!("{}://{}/{}/did.json", scheme, domain, path)
232 };
233 let client = reqwest::Client::builder()
234 .timeout(std::time::Duration::from_secs(5))
235 .build()
236 .map_err(|e| format!("Failed to create client: {}", e))?;
237 let resp = client
238 .get(&url)
239 .send()
240 .await
241 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
242 if !resp.status().is_success() {
243 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
244 }
245 let doc: serde_json::Value = resp
246 .json()
247 .await
248 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
249 let services = doc["service"]
250 .as_array()
251 .ok_or("No services found in DID doc")?;
252 let pds_endpoint = format!("https://{}", hostname);
253 let has_valid_service = services.iter().any(|s| {
254 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
255 });
256 if has_valid_service {
257 Ok(())
258 } else {
259 Err(format!(
260 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
261 pds_endpoint
262 ))
263 }
264 }
265}
266
267#[derive(serde::Serialize)]
268#[serde(rename_all = "camelCase")]
269pub struct GetRecommendedDidCredentialsOutput {
270 pub rotation_keys: Vec<String>,
271 pub also_known_as: Vec<String>,
272 pub verification_methods: VerificationMethods,
273 pub services: Services,
274}
275
276#[derive(serde::Serialize)]
277#[serde(rename_all = "camelCase")]
278pub struct VerificationMethods {
279 pub atproto: String,
280}
281
282#[derive(serde::Serialize)]
283#[serde(rename_all = "camelCase")]
284pub struct Services {
285 pub atproto_pds: AtprotoPds,
286}
287
288#[derive(serde::Serialize)]
289#[serde(rename_all = "camelCase")]
290pub struct AtprotoPds {
291 #[serde(rename = "type")]
292 pub service_type: String,
293 pub endpoint: String,
294}
295
296pub async fn get_recommended_did_credentials(
297 State(state): State<AppState>,
298 headers: axum::http::HeaderMap,
299) -> Response {
300 let token = match crate::auth::extract_bearer_token_from_header(
301 headers.get("Authorization").and_then(|h| h.to_str().ok()),
302 ) {
303 Some(t) => t,
304 None => {
305 return (
306 StatusCode::UNAUTHORIZED,
307 Json(json!({"error": "AuthenticationRequired"})),
308 )
309 .into_response();
310 }
311 };
312 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
313 Ok(user) => user,
314 Err(e) => return ApiError::from(e).into_response(),
315 };
316 let user = match sqlx::query!(
317 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
318 auth_user.did
319 )
320 .fetch_optional(&state.db)
321 .await
322 {
323 Ok(Some(row)) => row,
324 _ => return ApiError::InternalError.into_response(),
325 };
326 let key_bytes = match auth_user.key_bytes {
327 Some(kb) => kb,
328 None => {
329 return ApiError::AuthenticationFailedMsg(
330 "OAuth tokens cannot get DID credentials".into(),
331 )
332 .into_response();
333 }
334 };
335 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
336 let pds_endpoint = format!("https://{}", hostname);
337 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
338 Ok(k) => k,
339 Err(_) => return ApiError::InternalError.into_response(),
340 };
341 let public_key = secret_key.public_key();
342 let encoded = public_key.to_encoded_point(true);
343 let did_key = format!(
344 "did:key:zQ3sh{}",
345 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
346 .chars()
347 .skip(1)
348 .collect::<String>()
349 );
350 (
351 StatusCode::OK,
352 Json(GetRecommendedDidCredentialsOutput {
353 rotation_keys: vec![did_key.clone()],
354 also_known_as: vec![format!("at://{}", user.handle)],
355 verification_methods: VerificationMethods { atproto: did_key },
356 services: Services {
357 atproto_pds: AtprotoPds {
358 service_type: "AtprotoPersonalDataServer".to_string(),
359 endpoint: pds_endpoint,
360 },
361 },
362 }),
363 )
364 .into_response()
365}
366
367#[derive(Deserialize)]
368pub struct UpdateHandleInput {
369 pub handle: String,
370}
371
372pub async fn update_handle(
373 State(state): State<AppState>,
374 headers: axum::http::HeaderMap,
375 Json(input): Json<UpdateHandleInput>,
376) -> Response {
377 let token = match crate::auth::extract_bearer_token_from_header(
378 headers.get("Authorization").and_then(|h| h.to_str().ok()),
379 ) {
380 Some(t) => t,
381 None => return ApiError::AuthenticationRequired.into_response(),
382 };
383 let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
384 Ok(user) => user.did,
385 Err(e) => return ApiError::from(e).into_response(),
386 };
387 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
388 .fetch_optional(&state.db)
389 .await
390 {
391 Ok(Some(id)) => id,
392 _ => return ApiError::InternalError.into_response(),
393 };
394 let new_handle = input.handle.trim();
395 if new_handle.is_empty() {
396 return ApiError::InvalidRequest("handle is required".into()).into_response();
397 }
398 if !new_handle
399 .chars()
400 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
401 {
402 return (
403 StatusCode::BAD_REQUEST,
404 Json(
405 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
406 ),
407 )
408 .into_response();
409 }
410 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
411 let is_service_domain = crate::handle::is_service_domain_handle(new_handle, &hostname);
412 let (handle_to_store, full_handle) = if is_service_domain {
413 let suffix = format!(".{}", hostname);
414 let short_handle = if new_handle.ends_with(&suffix) {
415 new_handle.strip_suffix(&suffix).unwrap_or(new_handle)
416 } else {
417 new_handle
418 };
419 (short_handle.to_string(), format!("{}.{}", short_handle, hostname))
420 } else {
421 match crate::handle::verify_handle_ownership(new_handle, &did).await {
422 Ok(()) => {}
423 Err(crate::handle::HandleResolutionError::NotFound) => {
424 return (
425 StatusCode::BAD_REQUEST,
426 Json(json!({
427 "error": "HandleNotAvailable",
428 "message": "Handle verification failed. Please set up DNS TXT record at _atproto.{} or serve your DID at https://{}/.well-known/atproto-did",
429 "handle": new_handle
430 })),
431 )
432 .into_response();
433 }
434 Err(crate::handle::HandleResolutionError::DidMismatch { expected, actual }) => {
435 return (
436 StatusCode::BAD_REQUEST,
437 Json(json!({
438 "error": "HandleNotAvailable",
439 "message": format!("Handle points to different DID. Expected {}, got {}", expected, actual)
440 })),
441 )
442 .into_response();
443 }
444 Err(e) => {
445 warn!("Handle verification failed: {}", e);
446 return (
447 StatusCode::BAD_REQUEST,
448 Json(json!({
449 "error": "HandleNotAvailable",
450 "message": format!("Handle verification failed: {}", e)
451 })),
452 )
453 .into_response();
454 }
455 }
456 (new_handle.to_string(), new_handle.to_string())
457 };
458 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
459 .fetch_optional(&state.db)
460 .await
461 .ok()
462 .flatten();
463 let existing = sqlx::query!(
464 "SELECT id FROM users WHERE handle = $1 AND id != $2",
465 handle_to_store,
466 user_id
467 )
468 .fetch_optional(&state.db)
469 .await;
470 if let Ok(Some(_)) = existing {
471 return (
472 StatusCode::BAD_REQUEST,
473 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
474 )
475 .into_response();
476 }
477 let result = sqlx::query!(
478 "UPDATE users SET handle = $1 WHERE id = $2",
479 handle_to_store,
480 user_id
481 )
482 .execute(&state.db)
483 .await;
484 match result {
485 Ok(_) => {
486 if let Some(old) = old_handle {
487 let _ = state.cache.delete(&format!("handle:{}", old)).await;
488 }
489 let _ = state
490 .cache
491 .delete(&format!("handle:{}", handle_to_store))
492 .await;
493 let _ = state.cache.delete(&format!("handle:{}", full_handle)).await;
494 if let Err(e) =
495 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle))
496 .await
497 {
498 warn!("Failed to sequence identity event for handle update: {}", e);
499 }
500 if let Err(e) = update_plc_handle(&state, &did, &full_handle).await {
501 warn!("Failed to update PLC handle: {}", e);
502 }
503 (StatusCode::OK, Json(json!({}))).into_response()
504 }
505 Err(e) => {
506 error!("DB error updating handle: {:?}", e);
507 (
508 StatusCode::INTERNAL_SERVER_ERROR,
509 Json(json!({"error": "InternalError"})),
510 )
511 .into_response()
512 }
513 }
514}
515
516async fn update_plc_handle(
517 state: &AppState,
518 did: &str,
519 new_handle: &str,
520) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
521 if !did.starts_with("did:plc:") {
522 return Ok(());
523 }
524 let user_row = sqlx::query!(
525 r#"SELECT u.id, uk.key_bytes, uk.encryption_version
526 FROM users u
527 JOIN user_keys uk ON u.id = uk.user_id
528 WHERE u.did = $1"#,
529 did
530 )
531 .fetch_optional(&state.db)
532 .await?;
533 let user_row = match user_row {
534 Some(r) => r,
535 None => return Ok(()),
536 };
537 let key_bytes = crate::config::decrypt_key(&user_row.key_bytes, user_row.encryption_version)?;
538 let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes)?;
539 let plc_client = crate::plc::PlcClient::new(None);
540 let last_op = plc_client.get_last_op(did).await?;
541 let new_also_known_as = vec![format!("at://{}", new_handle)];
542 let update_op = crate::plc::create_update_op(&last_op, None, None, Some(new_also_known_as), None)?;
543 let signed_op = crate::plc::sign_operation(&update_op, &signing_key)?;
544 plc_client.send_operation(did, &signed_op).await?;
545 Ok(())
546}
547
548pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
549 let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
550 Some(h) => h,
551 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
552 };
553 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
554 let suffix = format!(".{}", hostname);
555 let handle = host.split(':').next().unwrap_or(host);
556 let short_handle = if handle.ends_with(&suffix) {
557 handle.strip_suffix(&suffix).unwrap_or(handle)
558 } else {
559 return (StatusCode::NOT_FOUND, "Handle not found").into_response();
560 };
561 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
562 .fetch_optional(&state.db)
563 .await;
564 match user {
565 Ok(Some(row)) => row.did.into_response(),
566 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(),
567 Err(e) => {
568 error!("DB error in well-known atproto-did: {:?}", e);
569 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
570 }
571 }
572}