this repo has no description
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::{Path, Query, State},
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use base64::Engine;
9use k256::SecretKey;
10use k256::elliptic_curve::sec1::ToEncodedPoint;
11use reqwest;
12use serde::Deserialize;
13use serde_json::json;
14use tracing::error;
15
16#[derive(Deserialize)]
17pub struct ResolveHandleParams {
18 pub handle: String,
19}
20
21pub async fn resolve_handle(
22 State(state): State<AppState>,
23 Query(params): Query<ResolveHandleParams>,
24) -> Response {
25 let handle = params.handle.trim();
26
27 if handle.is_empty() {
28 return (
29 StatusCode::BAD_REQUEST,
30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
31 )
32 .into_response();
33 }
34
35 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle)
36 .fetch_optional(&state.db)
37 .await;
38
39 match user {
40 Ok(Some(row)) => {
41 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
42 }
43 Ok(None) => (
44 StatusCode::NOT_FOUND,
45 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
46 )
47 .into_response(),
48 Err(e) => {
49 error!("DB error resolving handle: {:?}", e);
50 (
51 StatusCode::INTERNAL_SERVER_ERROR,
52 Json(json!({"error": "InternalError"})),
53 )
54 .into_response()
55 }
56 }
57}
58
59pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
60 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
61 let public_key = secret_key.public_key();
62 let encoded = public_key.to_encoded_point(false);
63 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
64 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
65
66 json!({
67 "kty": "EC",
68 "crv": "secp256k1",
69 "x": x,
70 "y": y
71 })
72}
73
74pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
75 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
76 // Kinda for local dev, encode hostname if it contains port
77 let did = if hostname.contains(':') {
78 format!("did:web:{}", hostname.replace(':', "%3A"))
79 } else {
80 format!("did:web:{}", hostname)
81 };
82
83 Json(json!({
84 "@context": ["https://www.w3.org/ns/did/v1"],
85 "id": did,
86 "service": [{
87 "id": "#atproto_pds",
88 "type": "AtprotoPersonalDataServer",
89 "serviceEndpoint": format!("https://{}", hostname)
90 }]
91 }))
92}
93
94pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
95 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
96
97 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
98 .fetch_optional(&state.db)
99 .await;
100
101 let (user_id, did) = match user {
102 Ok(Some(row)) => (row.id, row.did),
103 Ok(None) => {
104 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
105 }
106 Err(e) => {
107 error!("DB Error: {:?}", e);
108 return (
109 StatusCode::INTERNAL_SERVER_ERROR,
110 Json(json!({"error": "InternalError"})),
111 )
112 .into_response();
113 }
114 };
115
116 if !did.starts_with("did:web:") {
117 return (
118 StatusCode::NOT_FOUND,
119 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
120 )
121 .into_response();
122 }
123
124 let key_row = sqlx::query!("SELECT key_bytes FROM user_keys WHERE user_id = $1", user_id)
125 .fetch_optional(&state.db)
126 .await;
127
128 let key_bytes: Vec<u8> = match key_row {
129 Ok(Some(row)) => row.key_bytes,
130 _ => {
131 return (
132 StatusCode::INTERNAL_SERVER_ERROR,
133 Json(json!({"error": "InternalError"})),
134 )
135 .into_response();
136 }
137 };
138
139 let jwk = get_jwk(&key_bytes);
140
141 Json(json!({
142 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
143 "id": did,
144 "alsoKnownAs": [format!("at://{}", handle)],
145 "verificationMethod": [{
146 "id": format!("{}#atproto", did),
147 "type": "JsonWebKey2020",
148 "controller": did,
149 "publicKeyJwk": jwk
150 }],
151 "service": [{
152 "id": "#atproto_pds",
153 "type": "AtprotoPersonalDataServer",
154 "serviceEndpoint": format!("https://{}", hostname)
155 }]
156 })).into_response()
157}
158
159pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
160 let expected_prefix = if hostname.contains(':') {
161 format!("did:web:{}", hostname.replace(':', "%3A"))
162 } else {
163 format!("did:web:{}", hostname)
164 };
165
166 if did.starts_with(&expected_prefix) {
167 let suffix = &did[expected_prefix.len()..];
168 let expected_suffix = format!(":u:{}", handle);
169 if suffix == expected_suffix {
170 Ok(())
171 } else {
172 Err(format!(
173 "Invalid DID path for this PDS. Expected {}",
174 expected_suffix
175 ))
176 }
177 } else {
178 let parts: Vec<&str> = did.split(':').collect();
179 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
180 return Err("Invalid did:web format".into());
181 }
182
183 let domain_segment = parts[2];
184 let domain = domain_segment.replace("%3A", ":");
185
186 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
187 "http"
188 } else {
189 "https"
190 };
191
192 let url = if parts.len() == 3 {
193 format!("{}://{}/.well-known/did.json", scheme, domain)
194 } else {
195 let path = parts[3..].join("/");
196 format!("{}://{}/{}/did.json", scheme, domain, path)
197 };
198
199 let client = reqwest::Client::builder()
200 .timeout(std::time::Duration::from_secs(5))
201 .build()
202 .map_err(|e| format!("Failed to create client: {}", e))?;
203
204 let resp = client
205 .get(&url)
206 .send()
207 .await
208 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
209
210 if !resp.status().is_success() {
211 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
212 }
213
214 let doc: serde_json::Value = resp
215 .json()
216 .await
217 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
218
219 let services = doc["service"]
220 .as_array()
221 .ok_or("No services found in DID doc")?;
222
223 let pds_endpoint = format!("https://{}", hostname);
224
225 let has_valid_service = services.iter().any(|s| {
226 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
227 });
228
229 if has_valid_service {
230 Ok(())
231 } else {
232 Err(format!(
233 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
234 pds_endpoint
235 ))
236 }
237 }
238}
239
240#[derive(serde::Serialize)]
241#[serde(rename_all = "camelCase")]
242pub struct GetRecommendedDidCredentialsOutput {
243 pub rotation_keys: Vec<String>,
244 pub also_known_as: Vec<String>,
245 pub verification_methods: VerificationMethods,
246 pub services: Services,
247}
248
249#[derive(serde::Serialize)]
250#[serde(rename_all = "camelCase")]
251pub struct VerificationMethods {
252 pub atproto: String,
253}
254
255#[derive(serde::Serialize)]
256#[serde(rename_all = "camelCase")]
257pub struct Services {
258 pub atproto_pds: AtprotoPds,
259}
260
261#[derive(serde::Serialize)]
262#[serde(rename_all = "camelCase")]
263pub struct AtprotoPds {
264 #[serde(rename = "type")]
265 pub service_type: String,
266 pub endpoint: String,
267}
268
269pub async fn get_recommended_did_credentials(
270 State(state): State<AppState>,
271 headers: axum::http::HeaderMap,
272) -> Response {
273 let auth_header = headers.get("Authorization");
274 if auth_header.is_none() {
275 return (
276 StatusCode::UNAUTHORIZED,
277 Json(json!({"error": "AuthenticationRequired"})),
278 )
279 .into_response();
280 }
281
282 let token = auth_header
283 .unwrap()
284 .to_str()
285 .unwrap_or("")
286 .replace("Bearer ", "");
287
288 let session = sqlx::query!(
289 r#"
290 SELECT s.did, k.key_bytes, u.handle
291 FROM sessions s
292 JOIN users u ON s.did = u.did
293 JOIN user_keys k ON u.id = k.user_id
294 WHERE s.access_jwt = $1
295 "#,
296 token
297 )
298 .fetch_optional(&state.db)
299 .await;
300
301 let (_did, key_bytes, handle) = match session {
302 Ok(Some(row)) => (row.did, row.key_bytes, row.handle),
303 Ok(None) => {
304 return (
305 StatusCode::UNAUTHORIZED,
306 Json(json!({"error": "AuthenticationFailed"})),
307 )
308 .into_response();
309 }
310 Err(e) => {
311 error!("DB error in get_recommended_did_credentials: {:?}", e);
312 return (
313 StatusCode::INTERNAL_SERVER_ERROR,
314 Json(json!({"error": "InternalError"})),
315 )
316 .into_response();
317 }
318 };
319
320 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
321 return (
322 StatusCode::UNAUTHORIZED,
323 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
324 )
325 .into_response();
326 }
327
328 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
329 let pds_endpoint = format!("https://{}", hostname);
330
331 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
332 Ok(k) => k,
333 Err(_) => {
334 return (
335 StatusCode::INTERNAL_SERVER_ERROR,
336 Json(json!({"error": "InternalError"})),
337 )
338 .into_response();
339 }
340 };
341
342 let public_key = secret_key.public_key();
343 let encoded = public_key.to_encoded_point(true);
344 let did_key = format!(
345 "did:key:zQ3sh{}",
346 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
347 .chars()
348 .skip(1)
349 .collect::<String>()
350 );
351
352 (
353 StatusCode::OK,
354 Json(GetRecommendedDidCredentialsOutput {
355 rotation_keys: vec![did_key.clone()],
356 also_known_as: vec![format!("at://{}", handle)],
357 verification_methods: VerificationMethods { atproto: did_key },
358 services: Services {
359 atproto_pds: AtprotoPds {
360 service_type: "AtprotoPersonalDataServer".to_string(),
361 endpoint: pds_endpoint,
362 },
363 },
364 }),
365 )
366 .into_response()
367}
368
369#[derive(Deserialize)]
370pub struct UpdateHandleInput {
371 pub handle: String,
372}
373
374pub async fn update_handle(
375 State(state): State<AppState>,
376 headers: axum::http::HeaderMap,
377 Json(input): Json<UpdateHandleInput>,
378) -> Response {
379 let auth_header = headers.get("Authorization");
380 if auth_header.is_none() {
381 return (
382 StatusCode::UNAUTHORIZED,
383 Json(json!({"error": "AuthenticationRequired"})),
384 )
385 .into_response();
386 }
387
388 let token = auth_header
389 .unwrap()
390 .to_str()
391 .unwrap_or("")
392 .replace("Bearer ", "");
393
394 let session = sqlx::query!(
395 r#"
396 SELECT s.did, k.key_bytes, u.id as user_id
397 FROM sessions s
398 JOIN users u ON s.did = u.did
399 JOIN user_keys k ON u.id = k.user_id
400 WHERE s.access_jwt = $1
401 "#,
402 token
403 )
404 .fetch_optional(&state.db)
405 .await;
406
407 let (_did, key_bytes, user_id) = match session {
408 Ok(Some(row)) => (row.did, row.key_bytes, row.user_id),
409 Ok(None) => {
410 return (
411 StatusCode::UNAUTHORIZED,
412 Json(json!({"error": "AuthenticationFailed"})),
413 )
414 .into_response();
415 }
416 Err(e) => {
417 error!("DB error in update_handle: {:?}", e);
418 return (
419 StatusCode::INTERNAL_SERVER_ERROR,
420 Json(json!({"error": "InternalError"})),
421 )
422 .into_response();
423 }
424 };
425
426 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
427 return (
428 StatusCode::UNAUTHORIZED,
429 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
430 )
431 .into_response();
432 }
433
434 let new_handle = input.handle.trim();
435 if new_handle.is_empty() {
436 return (
437 StatusCode::BAD_REQUEST,
438 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
439 )
440 .into_response();
441 }
442
443 if !new_handle
444 .chars()
445 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
446 {
447 return (
448 StatusCode::BAD_REQUEST,
449 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
450 )
451 .into_response();
452 }
453
454 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
455 .fetch_optional(&state.db)
456 .await;
457
458 if let Ok(Some(_)) = existing {
459 return (
460 StatusCode::BAD_REQUEST,
461 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
462 )
463 .into_response();
464 }
465
466 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id)
467 .execute(&state.db)
468 .await;
469
470 match result {
471 Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
472 Err(e) => {
473 error!("DB error updating handle: {:?}", e);
474 (
475 StatusCode::INTERNAL_SERVER_ERROR,
476 Json(json!({"error": "InternalError"})),
477 )
478 .into_response()
479 }
480 }
481}