this repo has no description
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::{Path, Query, State},
6 http::{HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8};
9use base64::Engine;
10use k256::SecretKey;
11use k256::elliptic_curve::sec1::ToEncodedPoint;
12use reqwest;
13use serde::Deserialize;
14use serde_json::json;
15use tracing::error;
16
17#[derive(Deserialize)]
18pub struct ResolveHandleParams {
19 pub handle: String,
20}
21
22pub async fn resolve_handle(
23 State(state): State<AppState>,
24 Query(params): Query<ResolveHandleParams>,
25) -> Response {
26 let handle = params.handle.trim();
27
28 if handle.is_empty() {
29 return (
30 StatusCode::BAD_REQUEST,
31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
32 )
33 .into_response();
34 }
35
36 let cache_key = format!("handle:{}", handle);
37 if let Some(did) = state.cache.get(&cache_key).await {
38 return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
39 }
40
41 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
42 let suffix = format!(".{}", hostname);
43 let short_handle = if handle.ends_with(&suffix) {
44 handle.strip_suffix(&suffix).unwrap_or(handle)
45 } else {
46 handle
47 };
48
49 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
50 .fetch_optional(&state.db)
51 .await;
52
53 match user {
54 Ok(Some(row)) => {
55 let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await;
56 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
57 }
58 Ok(None) => (
59 StatusCode::NOT_FOUND,
60 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
61 )
62 .into_response(),
63 Err(e) => {
64 error!("DB error resolving handle: {:?}", e);
65 (
66 StatusCode::INTERNAL_SERVER_ERROR,
67 Json(json!({"error": "InternalError"})),
68 )
69 .into_response()
70 }
71 }
72}
73
74pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
75 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
76 let public_key = secret_key.public_key();
77 let encoded = public_key.to_encoded_point(false);
78 let x = encoded.x().ok_or("Missing x coordinate")?;
79 let y = encoded.y().ok_or("Missing y coordinate")?;
80 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
81 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
82
83 Ok(json!({
84 "kty": "EC",
85 "crv": "secp256k1",
86 "x": x_b64,
87 "y": y_b64
88 }))
89}
90
91pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
92 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
93 // Kinda for local dev, encode hostname if it contains port
94 let did = if hostname.contains(':') {
95 format!("did:web:{}", hostname.replace(':', "%3A"))
96 } else {
97 format!("did:web:{}", hostname)
98 };
99
100 Json(json!({
101 "@context": ["https://www.w3.org/ns/did/v1"],
102 "id": did,
103 "service": [{
104 "id": "#atproto_pds",
105 "type": "AtprotoPersonalDataServer",
106 "serviceEndpoint": format!("https://{}", hostname)
107 }]
108 }))
109}
110
111pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
112 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
113
114 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
115 .fetch_optional(&state.db)
116 .await;
117
118 let (user_id, did) = match user {
119 Ok(Some(row)) => (row.id, row.did),
120 Ok(None) => {
121 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
122 }
123 Err(e) => {
124 error!("DB Error: {:?}", e);
125 return (
126 StatusCode::INTERNAL_SERVER_ERROR,
127 Json(json!({"error": "InternalError"})),
128 )
129 .into_response();
130 }
131 };
132
133 if !did.starts_with("did:web:") {
134 return (
135 StatusCode::NOT_FOUND,
136 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
137 )
138 .into_response();
139 }
140
141 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id)
142 .fetch_optional(&state.db)
143 .await;
144
145 let key_bytes: Vec<u8> = match key_row {
146 Ok(Some(row)) => {
147 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
148 Ok(k) => k,
149 Err(_) => {
150 return (
151 StatusCode::INTERNAL_SERVER_ERROR,
152 Json(json!({"error": "InternalError"})),
153 )
154 .into_response();
155 }
156 }
157 }
158 _ => {
159 return (
160 StatusCode::INTERNAL_SERVER_ERROR,
161 Json(json!({"error": "InternalError"})),
162 )
163 .into_response();
164 }
165 };
166
167 let jwk = match get_jwk(&key_bytes) {
168 Ok(j) => j,
169 Err(e) => {
170 tracing::error!("Failed to generate JWK: {}", e);
171 return (
172 StatusCode::INTERNAL_SERVER_ERROR,
173 Json(json!({"error": "InternalError"})),
174 )
175 .into_response();
176 }
177 };
178
179 Json(json!({
180 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
181 "id": did,
182 "alsoKnownAs": [format!("at://{}", handle)],
183 "verificationMethod": [{
184 "id": format!("{}#atproto", did),
185 "type": "JsonWebKey2020",
186 "controller": did,
187 "publicKeyJwk": jwk
188 }],
189 "service": [{
190 "id": "#atproto_pds",
191 "type": "AtprotoPersonalDataServer",
192 "serviceEndpoint": format!("https://{}", hostname)
193 }]
194 })).into_response()
195}
196
197pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
198 let expected_prefix = if hostname.contains(':') {
199 format!("did:web:{}", hostname.replace(':', "%3A"))
200 } else {
201 format!("did:web:{}", hostname)
202 };
203
204 if did.starts_with(&expected_prefix) {
205 let suffix = &did[expected_prefix.len()..];
206 let expected_suffix = format!(":u:{}", handle);
207 if suffix == expected_suffix {
208 Ok(())
209 } else {
210 Err(format!(
211 "Invalid DID path for this PDS. Expected {}",
212 expected_suffix
213 ))
214 }
215 } else {
216 let parts: Vec<&str> = did.split(':').collect();
217 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
218 return Err("Invalid did:web format".into());
219 }
220
221 let domain_segment = parts[2];
222 let domain = domain_segment.replace("%3A", ":");
223
224 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
225 "http"
226 } else {
227 "https"
228 };
229
230 let url = if parts.len() == 3 {
231 format!("{}://{}/.well-known/did.json", scheme, domain)
232 } else {
233 let path = parts[3..].join("/");
234 format!("{}://{}/{}/did.json", scheme, domain, path)
235 };
236
237 let client = reqwest::Client::builder()
238 .timeout(std::time::Duration::from_secs(5))
239 .build()
240 .map_err(|e| format!("Failed to create client: {}", e))?;
241
242 let resp = client
243 .get(&url)
244 .send()
245 .await
246 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
247
248 if !resp.status().is_success() {
249 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
250 }
251
252 let doc: serde_json::Value = resp
253 .json()
254 .await
255 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
256
257 let services = doc["service"]
258 .as_array()
259 .ok_or("No services found in DID doc")?;
260
261 let pds_endpoint = format!("https://{}", hostname);
262
263 let has_valid_service = services.iter().any(|s| {
264 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
265 });
266
267 if has_valid_service {
268 Ok(())
269 } else {
270 Err(format!(
271 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
272 pds_endpoint
273 ))
274 }
275 }
276}
277
278#[derive(serde::Serialize)]
279#[serde(rename_all = "camelCase")]
280pub struct GetRecommendedDidCredentialsOutput {
281 pub rotation_keys: Vec<String>,
282 pub also_known_as: Vec<String>,
283 pub verification_methods: VerificationMethods,
284 pub services: Services,
285}
286
287#[derive(serde::Serialize)]
288#[serde(rename_all = "camelCase")]
289pub struct VerificationMethods {
290 pub atproto: String,
291}
292
293#[derive(serde::Serialize)]
294#[serde(rename_all = "camelCase")]
295pub struct Services {
296 pub atproto_pds: AtprotoPds,
297}
298
299#[derive(serde::Serialize)]
300#[serde(rename_all = "camelCase")]
301pub struct AtprotoPds {
302 #[serde(rename = "type")]
303 pub service_type: String,
304 pub endpoint: String,
305}
306
307pub async fn get_recommended_did_credentials(
308 State(state): State<AppState>,
309 headers: axum::http::HeaderMap,
310) -> Response {
311 let token = match crate::auth::extract_bearer_token_from_header(
312 headers.get("Authorization").and_then(|h| h.to_str().ok())
313 ) {
314 Some(t) => t,
315 None => {
316 return (
317 StatusCode::UNAUTHORIZED,
318 Json(json!({"error": "AuthenticationRequired"})),
319 )
320 .into_response();
321 }
322 };
323
324 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
325 Ok(user) => user,
326 Err(e) => return ApiError::from(e).into_response(),
327 };
328
329 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did)
330 .fetch_optional(&state.db)
331 .await
332 {
333 Ok(Some(row)) => row,
334 _ => return ApiError::InternalError.into_response(),
335 };
336
337 let key_bytes = match auth_user.key_bytes {
338 Some(kb) => kb,
339 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(),
340 };
341
342 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
343 let pds_endpoint = format!("https://{}", hostname);
344
345 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
346 Ok(k) => k,
347 Err(_) => return ApiError::InternalError.into_response(),
348 };
349
350 let public_key = secret_key.public_key();
351 let encoded = public_key.to_encoded_point(true);
352 let did_key = format!(
353 "did:key:zQ3sh{}",
354 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
355 .chars()
356 .skip(1)
357 .collect::<String>()
358 );
359
360 (
361 StatusCode::OK,
362 Json(GetRecommendedDidCredentialsOutput {
363 rotation_keys: vec![did_key.clone()],
364 also_known_as: vec![format!("at://{}", user.handle)],
365 verification_methods: VerificationMethods { atproto: did_key },
366 services: Services {
367 atproto_pds: AtprotoPds {
368 service_type: "AtprotoPersonalDataServer".to_string(),
369 endpoint: pds_endpoint,
370 },
371 },
372 }),
373 )
374 .into_response()
375}
376
377#[derive(Deserialize)]
378pub struct UpdateHandleInput {
379 pub handle: String,
380}
381
382pub async fn update_handle(
383 State(state): State<AppState>,
384 headers: axum::http::HeaderMap,
385 Json(input): Json<UpdateHandleInput>,
386) -> Response {
387 let token = match crate::auth::extract_bearer_token_from_header(
388 headers.get("Authorization").and_then(|h| h.to_str().ok())
389 ) {
390 Some(t) => t,
391 None => return ApiError::AuthenticationRequired.into_response(),
392 };
393
394 let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
395 Ok(user) => user.did,
396 Err(e) => return ApiError::from(e).into_response(),
397 };
398
399 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
400 .fetch_optional(&state.db)
401 .await
402 {
403 Ok(Some(id)) => id,
404 _ => return ApiError::InternalError.into_response(),
405 };
406
407 let new_handle = input.handle.trim();
408 if new_handle.is_empty() {
409 return ApiError::InvalidRequest("handle is required".into()).into_response();
410 }
411
412 if !new_handle
413 .chars()
414 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
415 {
416 return (
417 StatusCode::BAD_REQUEST,
418 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
419 )
420 .into_response();
421 }
422
423 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
424 .fetch_optional(&state.db)
425 .await
426 .ok()
427 .flatten();
428
429 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
430 .fetch_optional(&state.db)
431 .await;
432
433 if let Ok(Some(_)) = existing {
434 return (
435 StatusCode::BAD_REQUEST,
436 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
437 )
438 .into_response();
439 }
440
441 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id)
442 .execute(&state.db)
443 .await;
444
445 match result {
446 Ok(_) => {
447 if let Some(old) = old_handle {
448 let _ = state.cache.delete(&format!("handle:{}", old)).await;
449 }
450 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await;
451 (StatusCode::OK, Json(json!({}))).into_response()
452 }
453 Err(e) => {
454 error!("DB error updating handle: {:?}", e);
455 (
456 StatusCode::INTERNAL_SERVER_ERROR,
457 Json(json!({"error": "InternalError"})),
458 )
459 .into_response()
460 }
461 }
462}
463
464pub async fn well_known_atproto_did(
465 State(state): State<AppState>,
466 headers: HeaderMap,
467) -> Response {
468 let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
469 Some(h) => h,
470 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
471 };
472
473 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
474 let suffix = format!(".{}", hostname);
475
476 let handle = host.split(':').next().unwrap_or(host);
477
478 let short_handle = if handle.ends_with(&suffix) {
479 handle.strip_suffix(&suffix).unwrap_or(handle)
480 } else {
481 return (StatusCode::NOT_FOUND, "Handle not found").into_response();
482 };
483
484 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
485 .fetch_optional(&state.db)
486 .await;
487
488 match user {
489 Ok(Some(row)) => row.did.into_response(),
490 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(),
491 Err(e) => {
492 error!("DB error in well-known atproto-did: {:?}", e);
493 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
494 }
495 }
496}