this repo has no description
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::{Path, Query, State},
6 http::{HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8};
9use base64::Engine;
10use k256::SecretKey;
11use k256::elliptic_curve::sec1::ToEncodedPoint;
12use reqwest;
13use serde::Deserialize;
14use serde_json::json;
15use tracing::{error, warn};
16
17#[derive(Deserialize)]
18pub struct ResolveHandleParams {
19 pub handle: String,
20}
21
22pub async fn resolve_handle(
23 State(state): State<AppState>,
24 Query(params): Query<ResolveHandleParams>,
25) -> Response {
26 let handle = params.handle.trim();
27 if handle.is_empty() {
28 return (
29 StatusCode::BAD_REQUEST,
30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
31 )
32 .into_response();
33 }
34 let cache_key = format!("handle:{}", handle);
35 if let Some(did) = state.cache.get(&cache_key).await {
36 return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
37 }
38 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
39 let suffix = format!(".{}", hostname);
40 let short_handle = if handle.ends_with(&suffix) {
41 handle.strip_suffix(&suffix).unwrap_or(handle)
42 } else {
43 handle
44 };
45 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
46 .fetch_optional(&state.db)
47 .await;
48 match user {
49 Ok(Some(row)) => {
50 let _ = state
51 .cache
52 .set(&cache_key, &row.did, std::time::Duration::from_secs(300))
53 .await;
54 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
55 }
56 Ok(None) => (
57 StatusCode::NOT_FOUND,
58 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
59 )
60 .into_response(),
61 Err(e) => {
62 error!("DB error resolving handle: {:?}", e);
63 (
64 StatusCode::INTERNAL_SERVER_ERROR,
65 Json(json!({"error": "InternalError"})),
66 )
67 .into_response()
68 }
69 }
70}
71
72pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
73 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
74 let public_key = secret_key.public_key();
75 let encoded = public_key.to_encoded_point(false);
76 let x = encoded.x().ok_or("Missing x coordinate")?;
77 let y = encoded.y().ok_or("Missing y coordinate")?;
78 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
79 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
80 Ok(json!({
81 "kty": "EC",
82 "crv": "secp256k1",
83 "x": x_b64,
84 "y": y_b64
85 }))
86}
87
88pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
89 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
90 // Kinda for local dev, encode hostname if it contains port
91 let did = if hostname.contains(':') {
92 format!("did:web:{}", hostname.replace(':', "%3A"))
93 } else {
94 format!("did:web:{}", hostname)
95 };
96 Json(json!({
97 "@context": ["https://www.w3.org/ns/did/v1"],
98 "id": did,
99 "service": [{
100 "id": "#atproto_pds",
101 "type": "AtprotoPersonalDataServer",
102 "serviceEndpoint": format!("https://{}", hostname)
103 }]
104 }))
105}
106
107pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
108 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
109 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
110 .fetch_optional(&state.db)
111 .await;
112 let (user_id, did) = match user {
113 Ok(Some(row)) => (row.id, row.did),
114 Ok(None) => {
115 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
116 }
117 Err(e) => {
118 error!("DB Error: {:?}", e);
119 return (
120 StatusCode::INTERNAL_SERVER_ERROR,
121 Json(json!({"error": "InternalError"})),
122 )
123 .into_response();
124 }
125 };
126 if !did.starts_with("did:web:") {
127 return (
128 StatusCode::NOT_FOUND,
129 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
130 )
131 .into_response();
132 }
133 let key_row = sqlx::query!(
134 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
135 user_id
136 )
137 .fetch_optional(&state.db)
138 .await;
139 let key_bytes: Vec<u8> = match key_row {
140 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
141 Ok(k) => k,
142 Err(_) => {
143 return (
144 StatusCode::INTERNAL_SERVER_ERROR,
145 Json(json!({"error": "InternalError"})),
146 )
147 .into_response();
148 }
149 },
150 _ => {
151 return (
152 StatusCode::INTERNAL_SERVER_ERROR,
153 Json(json!({"error": "InternalError"})),
154 )
155 .into_response();
156 }
157 };
158 let jwk = match get_jwk(&key_bytes) {
159 Ok(j) => j,
160 Err(e) => {
161 tracing::error!("Failed to generate JWK: {}", e);
162 return (
163 StatusCode::INTERNAL_SERVER_ERROR,
164 Json(json!({"error": "InternalError"})),
165 )
166 .into_response();
167 }
168 };
169 Json(json!({
170 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
171 "id": did,
172 "alsoKnownAs": [format!("at://{}", handle)],
173 "verificationMethod": [{
174 "id": format!("{}#atproto", did),
175 "type": "JsonWebKey2020",
176 "controller": did,
177 "publicKeyJwk": jwk
178 }],
179 "service": [{
180 "id": "#atproto_pds",
181 "type": "AtprotoPersonalDataServer",
182 "serviceEndpoint": format!("https://{}", hostname)
183 }]
184 })).into_response()
185}
186
187pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
188 let expected_prefix = if hostname.contains(':') {
189 format!("did:web:{}", hostname.replace(':', "%3A"))
190 } else {
191 format!("did:web:{}", hostname)
192 };
193 if did.starts_with(&expected_prefix) {
194 let suffix = &did[expected_prefix.len()..];
195 let expected_suffix = format!(":u:{}", handle);
196 if suffix == expected_suffix {
197 Ok(())
198 } else {
199 Err(format!(
200 "Invalid DID path for this PDS. Expected {}",
201 expected_suffix
202 ))
203 }
204 } else {
205 let parts: Vec<&str> = did.split(':').collect();
206 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
207 return Err("Invalid did:web format".into());
208 }
209 let domain_segment = parts[2];
210 let domain = domain_segment.replace("%3A", ":");
211 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
212 "http"
213 } else {
214 "https"
215 };
216 let url = if parts.len() == 3 {
217 format!("{}://{}/.well-known/did.json", scheme, domain)
218 } else {
219 let path = parts[3..].join("/");
220 format!("{}://{}/{}/did.json", scheme, domain, path)
221 };
222 let client = reqwest::Client::builder()
223 .timeout(std::time::Duration::from_secs(5))
224 .build()
225 .map_err(|e| format!("Failed to create client: {}", e))?;
226 let resp = client
227 .get(&url)
228 .send()
229 .await
230 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
231 if !resp.status().is_success() {
232 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
233 }
234 let doc: serde_json::Value = resp
235 .json()
236 .await
237 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
238 let services = doc["service"]
239 .as_array()
240 .ok_or("No services found in DID doc")?;
241 let pds_endpoint = format!("https://{}", hostname);
242 let has_valid_service = services.iter().any(|s| {
243 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
244 });
245 if has_valid_service {
246 Ok(())
247 } else {
248 Err(format!(
249 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
250 pds_endpoint
251 ))
252 }
253 }
254}
255
256#[derive(serde::Serialize)]
257#[serde(rename_all = "camelCase")]
258pub struct GetRecommendedDidCredentialsOutput {
259 pub rotation_keys: Vec<String>,
260 pub also_known_as: Vec<String>,
261 pub verification_methods: VerificationMethods,
262 pub services: Services,
263}
264
265#[derive(serde::Serialize)]
266#[serde(rename_all = "camelCase")]
267pub struct VerificationMethods {
268 pub atproto: String,
269}
270
271#[derive(serde::Serialize)]
272#[serde(rename_all = "camelCase")]
273pub struct Services {
274 pub atproto_pds: AtprotoPds,
275}
276
277#[derive(serde::Serialize)]
278#[serde(rename_all = "camelCase")]
279pub struct AtprotoPds {
280 #[serde(rename = "type")]
281 pub service_type: String,
282 pub endpoint: String,
283}
284
285pub async fn get_recommended_did_credentials(
286 State(state): State<AppState>,
287 headers: axum::http::HeaderMap,
288) -> Response {
289 let token = match crate::auth::extract_bearer_token_from_header(
290 headers.get("Authorization").and_then(|h| h.to_str().ok()),
291 ) {
292 Some(t) => t,
293 None => {
294 return (
295 StatusCode::UNAUTHORIZED,
296 Json(json!({"error": "AuthenticationRequired"})),
297 )
298 .into_response();
299 }
300 };
301 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
302 Ok(user) => user,
303 Err(e) => return ApiError::from(e).into_response(),
304 };
305 let user = match sqlx::query!(
306 "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1",
307 auth_user.did
308 )
309 .fetch_optional(&state.db)
310 .await
311 {
312 Ok(Some(row)) => row,
313 _ => return ApiError::InternalError.into_response(),
314 };
315 let key_bytes = match auth_user.key_bytes {
316 Some(kb) => kb,
317 None => {
318 return ApiError::AuthenticationFailedMsg(
319 "OAuth tokens cannot get DID credentials".into(),
320 )
321 .into_response();
322 }
323 };
324 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
325 let pds_endpoint = format!("https://{}", hostname);
326 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
327 Ok(k) => k,
328 Err(_) => return ApiError::InternalError.into_response(),
329 };
330 let public_key = secret_key.public_key();
331 let encoded = public_key.to_encoded_point(true);
332 let did_key = format!(
333 "did:key:zQ3sh{}",
334 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
335 .chars()
336 .skip(1)
337 .collect::<String>()
338 );
339 (
340 StatusCode::OK,
341 Json(GetRecommendedDidCredentialsOutput {
342 rotation_keys: vec![did_key.clone()],
343 also_known_as: vec![format!("at://{}", user.handle)],
344 verification_methods: VerificationMethods { atproto: did_key },
345 services: Services {
346 atproto_pds: AtprotoPds {
347 service_type: "AtprotoPersonalDataServer".to_string(),
348 endpoint: pds_endpoint,
349 },
350 },
351 }),
352 )
353 .into_response()
354}
355
356#[derive(Deserialize)]
357pub struct UpdateHandleInput {
358 pub handle: String,
359}
360
361pub async fn update_handle(
362 State(state): State<AppState>,
363 headers: axum::http::HeaderMap,
364 Json(input): Json<UpdateHandleInput>,
365) -> Response {
366 let token = match crate::auth::extract_bearer_token_from_header(
367 headers.get("Authorization").and_then(|h| h.to_str().ok()),
368 ) {
369 Some(t) => t,
370 None => return ApiError::AuthenticationRequired.into_response(),
371 };
372 let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
373 Ok(user) => user.did,
374 Err(e) => return ApiError::from(e).into_response(),
375 };
376 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
377 .fetch_optional(&state.db)
378 .await
379 {
380 Ok(Some(id)) => id,
381 _ => return ApiError::InternalError.into_response(),
382 };
383 let new_handle = input.handle.trim();
384 if new_handle.is_empty() {
385 return ApiError::InvalidRequest("handle is required".into()).into_response();
386 }
387 if !new_handle
388 .chars()
389 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
390 {
391 return (
392 StatusCode::BAD_REQUEST,
393 Json(
394 json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}),
395 ),
396 )
397 .into_response();
398 }
399 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
400 .fetch_optional(&state.db)
401 .await
402 .ok()
403 .flatten();
404 let existing = sqlx::query!(
405 "SELECT id FROM users WHERE handle = $1 AND id != $2",
406 new_handle,
407 user_id
408 )
409 .fetch_optional(&state.db)
410 .await;
411 if let Ok(Some(_)) = existing {
412 return (
413 StatusCode::BAD_REQUEST,
414 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
415 )
416 .into_response();
417 }
418 let result = sqlx::query!(
419 "UPDATE users SET handle = $1 WHERE id = $2",
420 new_handle,
421 user_id
422 )
423 .execute(&state.db)
424 .await;
425 match result {
426 Ok(_) => {
427 if let Some(old) = old_handle {
428 let _ = state.cache.delete(&format!("handle:{}", old)).await;
429 }
430 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await;
431 let hostname =
432 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
433 let full_handle = format!("{}.{}", new_handle, hostname);
434 if let Err(e) =
435 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle))
436 .await
437 {
438 warn!("Failed to sequence identity event for handle update: {}", e);
439 }
440 (StatusCode::OK, Json(json!({}))).into_response()
441 }
442 Err(e) => {
443 error!("DB error updating handle: {:?}", e);
444 (
445 StatusCode::INTERNAL_SERVER_ERROR,
446 Json(json!({"error": "InternalError"})),
447 )
448 .into_response()
449 }
450 }
451}
452
453pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
454 let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
455 Some(h) => h,
456 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
457 };
458 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
459 let suffix = format!(".{}", hostname);
460 let handle = host.split(':').next().unwrap_or(host);
461 let short_handle = if handle.ends_with(&suffix) {
462 handle.strip_suffix(&suffix).unwrap_or(handle)
463 } else {
464 return (StatusCode::NOT_FOUND, "Handle not found").into_response();
465 };
466 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
467 .fetch_optional(&state.db)
468 .await;
469 match user {
470 Ok(Some(row)) => row.did.into_response(),
471 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(),
472 Err(e) => {
473 error!("DB error in well-known atproto-did: {:?}", e);
474 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
475 }
476 }
477}