this repo has no description
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::{Path, Query, State},
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use base64::Engine;
9use k256::SecretKey;
10use k256::elliptic_curve::sec1::ToEncodedPoint;
11use reqwest;
12use serde::Deserialize;
13use serde_json::json;
14use sqlx::Row;
15use tracing::error;
16
17#[derive(Deserialize)]
18pub struct ResolveHandleParams {
19 pub handle: String,
20}
21
22pub async fn resolve_handle(
23 State(state): State<AppState>,
24 Query(params): Query<ResolveHandleParams>,
25) -> Response {
26 let handle = params.handle.trim();
27
28 if handle.is_empty() {
29 return (
30 StatusCode::BAD_REQUEST,
31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
32 )
33 .into_response();
34 }
35
36 let user = sqlx::query("SELECT did FROM users WHERE handle = $1")
37 .bind(handle)
38 .fetch_optional(&state.db)
39 .await;
40
41 match user {
42 Ok(Some(row)) => {
43 let did: String = row.get("did");
44 (StatusCode::OK, Json(json!({ "did": did }))).into_response()
45 }
46 Ok(None) => (
47 StatusCode::NOT_FOUND,
48 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
49 )
50 .into_response(),
51 Err(e) => {
52 error!("DB error resolving handle: {:?}", e);
53 (
54 StatusCode::INTERNAL_SERVER_ERROR,
55 Json(json!({"error": "InternalError"})),
56 )
57 .into_response()
58 }
59 }
60}
61
62pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
63 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
64 let public_key = secret_key.public_key();
65 let encoded = public_key.to_encoded_point(false);
66 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
67 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
68
69 json!({
70 "kty": "EC",
71 "crv": "secp256k1",
72 "x": x,
73 "y": y
74 })
75}
76
77pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
78 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
79 // Kinda for local dev, encode hostname if it contains port
80 let did = if hostname.contains(':') {
81 format!("did:web:{}", hostname.replace(':', "%3A"))
82 } else {
83 format!("did:web:{}", hostname)
84 };
85
86 Json(json!({
87 "@context": ["https://www.w3.org/ns/did/v1"],
88 "id": did,
89 "service": [{
90 "id": "#atproto_pds",
91 "type": "AtprotoPersonalDataServer",
92 "serviceEndpoint": format!("https://{}", hostname)
93 }]
94 }))
95}
96
97pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
98 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
99
100 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
101 .bind(&handle)
102 .fetch_optional(&state.db)
103 .await;
104
105 let (user_id, did) = match user {
106 Ok(Some(row)) => {
107 let id: uuid::Uuid = row.get("id");
108 let d: String = row.get("did");
109 (id, d)
110 }
111 Ok(None) => {
112 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
113 }
114 Err(e) => {
115 error!("DB Error: {:?}", e);
116 return (
117 StatusCode::INTERNAL_SERVER_ERROR,
118 Json(json!({"error": "InternalError"})),
119 )
120 .into_response();
121 }
122 };
123
124 if !did.starts_with("did:web:") {
125 return (
126 StatusCode::NOT_FOUND,
127 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
128 )
129 .into_response();
130 }
131
132 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
133 .bind(user_id)
134 .fetch_optional(&state.db)
135 .await;
136
137 let key_bytes: Vec<u8> = match key_row {
138 Ok(Some(row)) => row.get("key_bytes"),
139 _ => {
140 return (
141 StatusCode::INTERNAL_SERVER_ERROR,
142 Json(json!({"error": "InternalError"})),
143 )
144 .into_response();
145 }
146 };
147
148 let jwk = get_jwk(&key_bytes);
149
150 Json(json!({
151 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
152 "id": did,
153 "alsoKnownAs": [format!("at://{}", handle)],
154 "verificationMethod": [{
155 "id": format!("{}#atproto", did),
156 "type": "JsonWebKey2020",
157 "controller": did,
158 "publicKeyJwk": jwk
159 }],
160 "service": [{
161 "id": "#atproto_pds",
162 "type": "AtprotoPersonalDataServer",
163 "serviceEndpoint": format!("https://{}", hostname)
164 }]
165 })).into_response()
166}
167
168pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
169 let expected_prefix = if hostname.contains(':') {
170 format!("did:web:{}", hostname.replace(':', "%3A"))
171 } else {
172 format!("did:web:{}", hostname)
173 };
174
175 if did.starts_with(&expected_prefix) {
176 let suffix = &did[expected_prefix.len()..];
177 let expected_suffix = format!(":u:{}", handle);
178 if suffix == expected_suffix {
179 Ok(())
180 } else {
181 Err(format!(
182 "Invalid DID path for this PDS. Expected {}",
183 expected_suffix
184 ))
185 }
186 } else {
187 let parts: Vec<&str> = did.split(':').collect();
188 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
189 return Err("Invalid did:web format".into());
190 }
191
192 let domain_segment = parts[2];
193 let domain = domain_segment.replace("%3A", ":");
194
195 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
196 "http"
197 } else {
198 "https"
199 };
200
201 let url = if parts.len() == 3 {
202 format!("{}://{}/.well-known/did.json", scheme, domain)
203 } else {
204 let path = parts[3..].join("/");
205 format!("{}://{}/{}/did.json", scheme, domain, path)
206 };
207
208 let client = reqwest::Client::builder()
209 .timeout(std::time::Duration::from_secs(5))
210 .build()
211 .map_err(|e| format!("Failed to create client: {}", e))?;
212
213 let resp = client
214 .get(&url)
215 .send()
216 .await
217 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
218
219 if !resp.status().is_success() {
220 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
221 }
222
223 let doc: serde_json::Value = resp
224 .json()
225 .await
226 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
227
228 let services = doc["service"]
229 .as_array()
230 .ok_or("No services found in DID doc")?;
231
232 let pds_endpoint = format!("https://{}", hostname);
233
234 let has_valid_service = services.iter().any(|s| {
235 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
236 });
237
238 if has_valid_service {
239 Ok(())
240 } else {
241 Err(format!(
242 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
243 pds_endpoint
244 ))
245 }
246 }
247}
248
249#[derive(serde::Serialize)]
250#[serde(rename_all = "camelCase")]
251pub struct GetRecommendedDidCredentialsOutput {
252 pub rotation_keys: Vec<String>,
253 pub also_known_as: Vec<String>,
254 pub verification_methods: VerificationMethods,
255 pub services: Services,
256}
257
258#[derive(serde::Serialize)]
259#[serde(rename_all = "camelCase")]
260pub struct VerificationMethods {
261 pub atproto: String,
262}
263
264#[derive(serde::Serialize)]
265#[serde(rename_all = "camelCase")]
266pub struct Services {
267 pub atproto_pds: AtprotoPds,
268}
269
270#[derive(serde::Serialize)]
271#[serde(rename_all = "camelCase")]
272pub struct AtprotoPds {
273 #[serde(rename = "type")]
274 pub service_type: String,
275 pub endpoint: String,
276}
277
278pub async fn get_recommended_did_credentials(
279 State(state): State<AppState>,
280 headers: axum::http::HeaderMap,
281) -> Response {
282 let auth_header = headers.get("Authorization");
283 if auth_header.is_none() {
284 return (
285 StatusCode::UNAUTHORIZED,
286 Json(json!({"error": "AuthenticationRequired"})),
287 )
288 .into_response();
289 }
290
291 let token = auth_header
292 .unwrap()
293 .to_str()
294 .unwrap_or("")
295 .replace("Bearer ", "");
296
297 let session = sqlx::query(
298 r#"
299 SELECT s.did, k.key_bytes, u.handle
300 FROM sessions s
301 JOIN users u ON s.did = u.did
302 JOIN user_keys k ON u.id = k.user_id
303 WHERE s.access_jwt = $1
304 "#,
305 )
306 .bind(&token)
307 .fetch_optional(&state.db)
308 .await;
309
310 let (_did, key_bytes, handle) = match session {
311 Ok(Some(row)) => (
312 row.get::<String, _>("did"),
313 row.get::<Vec<u8>, _>("key_bytes"),
314 row.get::<String, _>("handle"),
315 ),
316 Ok(None) => {
317 return (
318 StatusCode::UNAUTHORIZED,
319 Json(json!({"error": "AuthenticationFailed"})),
320 )
321 .into_response();
322 }
323 Err(e) => {
324 error!("DB error in get_recommended_did_credentials: {:?}", e);
325 return (
326 StatusCode::INTERNAL_SERVER_ERROR,
327 Json(json!({"error": "InternalError"})),
328 )
329 .into_response();
330 }
331 };
332
333 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
334 return (
335 StatusCode::UNAUTHORIZED,
336 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
337 )
338 .into_response();
339 }
340
341 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
342 let pds_endpoint = format!("https://{}", hostname);
343
344 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
345 Ok(k) => k,
346 Err(_) => {
347 return (
348 StatusCode::INTERNAL_SERVER_ERROR,
349 Json(json!({"error": "InternalError"})),
350 )
351 .into_response();
352 }
353 };
354
355 let public_key = secret_key.public_key();
356 let encoded = public_key.to_encoded_point(true);
357 let did_key = format!(
358 "did:key:zQ3sh{}",
359 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
360 .chars()
361 .skip(1)
362 .collect::<String>()
363 );
364
365 (
366 StatusCode::OK,
367 Json(GetRecommendedDidCredentialsOutput {
368 rotation_keys: vec![did_key.clone()],
369 also_known_as: vec![format!("at://{}", handle)],
370 verification_methods: VerificationMethods { atproto: did_key },
371 services: Services {
372 atproto_pds: AtprotoPds {
373 service_type: "AtprotoPersonalDataServer".to_string(),
374 endpoint: pds_endpoint,
375 },
376 },
377 }),
378 )
379 .into_response()
380}
381
382#[derive(Deserialize)]
383pub struct UpdateHandleInput {
384 pub handle: String,
385}
386
387pub async fn update_handle(
388 State(state): State<AppState>,
389 headers: axum::http::HeaderMap,
390 Json(input): Json<UpdateHandleInput>,
391) -> Response {
392 let auth_header = headers.get("Authorization");
393 if auth_header.is_none() {
394 return (
395 StatusCode::UNAUTHORIZED,
396 Json(json!({"error": "AuthenticationRequired"})),
397 )
398 .into_response();
399 }
400
401 let token = auth_header
402 .unwrap()
403 .to_str()
404 .unwrap_or("")
405 .replace("Bearer ", "");
406
407 let session = sqlx::query(
408 r#"
409 SELECT s.did, k.key_bytes, u.id as user_id
410 FROM sessions s
411 JOIN users u ON s.did = u.did
412 JOIN user_keys k ON u.id = k.user_id
413 WHERE s.access_jwt = $1
414 "#,
415 )
416 .bind(&token)
417 .fetch_optional(&state.db)
418 .await;
419
420 let (_did, key_bytes, user_id) = match session {
421 Ok(Some(row)) => (
422 row.get::<String, _>("did"),
423 row.get::<Vec<u8>, _>("key_bytes"),
424 row.get::<uuid::Uuid, _>("user_id"),
425 ),
426 Ok(None) => {
427 return (
428 StatusCode::UNAUTHORIZED,
429 Json(json!({"error": "AuthenticationFailed"})),
430 )
431 .into_response();
432 }
433 Err(e) => {
434 error!("DB error in update_handle: {:?}", e);
435 return (
436 StatusCode::INTERNAL_SERVER_ERROR,
437 Json(json!({"error": "InternalError"})),
438 )
439 .into_response();
440 }
441 };
442
443 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
444 return (
445 StatusCode::UNAUTHORIZED,
446 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
447 )
448 .into_response();
449 }
450
451 let new_handle = input.handle.trim();
452 if new_handle.is_empty() {
453 return (
454 StatusCode::BAD_REQUEST,
455 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
456 )
457 .into_response();
458 }
459
460 if !new_handle
461 .chars()
462 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
463 {
464 return (
465 StatusCode::BAD_REQUEST,
466 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
467 )
468 .into_response();
469 }
470
471 let existing = sqlx::query("SELECT id FROM users WHERE handle = $1 AND id != $2")
472 .bind(new_handle)
473 .bind(user_id)
474 .fetch_optional(&state.db)
475 .await;
476
477 if let Ok(Some(_)) = existing {
478 return (
479 StatusCode::BAD_REQUEST,
480 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
481 )
482 .into_response();
483 }
484
485 let result = sqlx::query("UPDATE users SET handle = $1 WHERE id = $2")
486 .bind(new_handle)
487 .bind(user_id)
488 .execute(&state.db)
489 .await;
490
491 match result {
492 Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
493 Err(e) => {
494 error!("DB error updating handle: {:?}", e);
495 (
496 StatusCode::INTERNAL_SERVER_ERROR,
497 Json(json!({"error": "InternalError"})),
498 )
499 .into_response()
500 }
501 }
502}