This repository has no description.
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::{Path, Query, State},
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use base64::Engine;
9use k256::SecretKey;
10use k256::elliptic_curve::sec1::ToEncodedPoint;
11use reqwest;
12use serde::Deserialize;
13use serde_json::json;
14use tracing::error;
15
16#[derive(Deserialize)]
17pub struct ResolveHandleParams {
18 pub handle: String,
19}
20
21pub async fn resolve_handle(
22 State(state): State<AppState>,
23 Query(params): Query<ResolveHandleParams>,
24) -> Response {
25 let handle = params.handle.trim();
26
27 if handle.is_empty() {
28 return (
29 StatusCode::BAD_REQUEST,
30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
31 )
32 .into_response();
33 }
34
35 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle)
36 .fetch_optional(&state.db)
37 .await;
38
39 match user {
40 Ok(Some(row)) => {
41 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
42 }
43 Ok(None) => (
44 StatusCode::NOT_FOUND,
45 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
46 )
47 .into_response(),
48 Err(e) => {
49 error!("DB error resolving handle: {:?}", e);
50 (
51 StatusCode::INTERNAL_SERVER_ERROR,
52 Json(json!({"error": "InternalError"})),
53 )
54 .into_response()
55 }
56 }
57}
58
59pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
60 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
61 let public_key = secret_key.public_key();
62 let encoded = public_key.to_encoded_point(false);
63 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
64 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
65
66 json!({
67 "kty": "EC",
68 "crv": "secp256k1",
69 "x": x,
70 "y": y
71 })
72}
73
74pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
75 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
76 // Kinda for local dev, encode hostname if it contains port
77 let did = if hostname.contains(':') {
78 format!("did:web:{}", hostname.replace(':', "%3A"))
79 } else {
80 format!("did:web:{}", hostname)
81 };
82
83 Json(json!({
84 "@context": ["https://www.w3.org/ns/did/v1"],
85 "id": did,
86 "service": [{
87 "id": "#atproto_pds",
88 "type": "AtprotoPersonalDataServer",
89 "serviceEndpoint": format!("https://{}", hostname)
90 }]
91 }))
92}
93
94pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
95 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
96
97 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
98 .fetch_optional(&state.db)
99 .await;
100
101 let (user_id, did) = match user {
102 Ok(Some(row)) => (row.id, row.did),
103 Ok(None) => {
104 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
105 }
106 Err(e) => {
107 error!("DB Error: {:?}", e);
108 return (
109 StatusCode::INTERNAL_SERVER_ERROR,
110 Json(json!({"error": "InternalError"})),
111 )
112 .into_response();
113 }
114 };
115
116 if !did.starts_with("did:web:") {
117 return (
118 StatusCode::NOT_FOUND,
119 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
120 )
121 .into_response();
122 }
123
124 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id)
125 .fetch_optional(&state.db)
126 .await;
127
128 let key_bytes: Vec<u8> = match key_row {
129 Ok(Some(row)) => {
130 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
131 Ok(k) => k,
132 Err(_) => {
133 return (
134 StatusCode::INTERNAL_SERVER_ERROR,
135 Json(json!({"error": "InternalError"})),
136 )
137 .into_response();
138 }
139 }
140 }
141 _ => {
142 return (
143 StatusCode::INTERNAL_SERVER_ERROR,
144 Json(json!({"error": "InternalError"})),
145 )
146 .into_response();
147 }
148 };
149
150 let jwk = get_jwk(&key_bytes);
151
152 Json(json!({
153 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
154 "id": did,
155 "alsoKnownAs": [format!("at://{}", handle)],
156 "verificationMethod": [{
157 "id": format!("{}#atproto", did),
158 "type": "JsonWebKey2020",
159 "controller": did,
160 "publicKeyJwk": jwk
161 }],
162 "service": [{
163 "id": "#atproto_pds",
164 "type": "AtprotoPersonalDataServer",
165 "serviceEndpoint": format!("https://{}", hostname)
166 }]
167 })).into_response()
168}
169
170pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
171 let expected_prefix = if hostname.contains(':') {
172 format!("did:web:{}", hostname.replace(':', "%3A"))
173 } else {
174 format!("did:web:{}", hostname)
175 };
176
177 if did.starts_with(&expected_prefix) {
178 let suffix = &did[expected_prefix.len()..];
179 let expected_suffix = format!(":u:{}", handle);
180 if suffix == expected_suffix {
181 Ok(())
182 } else {
183 Err(format!(
184 "Invalid DID path for this PDS. Expected {}",
185 expected_suffix
186 ))
187 }
188 } else {
189 let parts: Vec<&str> = did.split(':').collect();
190 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
191 return Err("Invalid did:web format".into());
192 }
193
194 let domain_segment = parts[2];
195 let domain = domain_segment.replace("%3A", ":");
196
197 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
198 "http"
199 } else {
200 "https"
201 };
202
203 let url = if parts.len() == 3 {
204 format!("{}://{}/.well-known/did.json", scheme, domain)
205 } else {
206 let path = parts[3..].join("/");
207 format!("{}://{}/{}/did.json", scheme, domain, path)
208 };
209
210 let client = reqwest::Client::builder()
211 .timeout(std::time::Duration::from_secs(5))
212 .build()
213 .map_err(|e| format!("Failed to create client: {}", e))?;
214
215 let resp = client
216 .get(&url)
217 .send()
218 .await
219 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
220
221 if !resp.status().is_success() {
222 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
223 }
224
225 let doc: serde_json::Value = resp
226 .json()
227 .await
228 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
229
230 let services = doc["service"]
231 .as_array()
232 .ok_or("No services found in DID doc")?;
233
234 let pds_endpoint = format!("https://{}", hostname);
235
236 let has_valid_service = services.iter().any(|s| {
237 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
238 });
239
240 if has_valid_service {
241 Ok(())
242 } else {
243 Err(format!(
244 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
245 pds_endpoint
246 ))
247 }
248 }
249}
250
251#[derive(serde::Serialize)]
252#[serde(rename_all = "camelCase")]
253pub struct GetRecommendedDidCredentialsOutput {
254 pub rotation_keys: Vec<String>,
255 pub also_known_as: Vec<String>,
256 pub verification_methods: VerificationMethods,
257 pub services: Services,
258}
259
260#[derive(serde::Serialize)]
261#[serde(rename_all = "camelCase")]
262pub struct VerificationMethods {
263 pub atproto: String,
264}
265
266#[derive(serde::Serialize)]
267#[serde(rename_all = "camelCase")]
268pub struct Services {
269 pub atproto_pds: AtprotoPds,
270}
271
272#[derive(serde::Serialize)]
273#[serde(rename_all = "camelCase")]
274pub struct AtprotoPds {
275 #[serde(rename = "type")]
276 pub service_type: String,
277 pub endpoint: String,
278}
279
280pub async fn get_recommended_did_credentials(
281 State(state): State<AppState>,
282 headers: axum::http::HeaderMap,
283) -> Response {
284 let token = match crate::auth::extract_bearer_token_from_header(
285 headers.get("Authorization").and_then(|h| h.to_str().ok())
286 ) {
287 Some(t) => t,
288 None => {
289 return (
290 StatusCode::UNAUTHORIZED,
291 Json(json!({"error": "AuthenticationRequired"})),
292 )
293 .into_response();
294 }
295 };
296
297 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
298 let did = match auth_result {
299 Ok(ref user) => user.did.clone(),
300 Err(e) => {
301 return (
302 StatusCode::UNAUTHORIZED,
303 Json(json!({"error": e})),
304 )
305 .into_response();
306 }
307 };
308
309 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", did)
310 .fetch_optional(&state.db)
311 .await
312 {
313 Ok(Some(row)) => row,
314 _ => {
315 return (
316 StatusCode::INTERNAL_SERVER_ERROR,
317 Json(json!({"error": "InternalError"})),
318 )
319 .into_response();
320 }
321 };
322 let handle = user.handle;
323
324 let key_bytes = match auth_result.ok().and_then(|u| u.key_bytes) {
325 Some(kb) => kb,
326 None => {
327 return (
328 StatusCode::UNAUTHORIZED,
329 Json(json!({"error": "AuthenticationFailed", "message": "OAuth tokens cannot get DID credentials"})),
330 )
331 .into_response();
332 }
333 };
334
335 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
336 let pds_endpoint = format!("https://{}", hostname);
337
338 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
339 Ok(k) => k,
340 Err(_) => {
341 return (
342 StatusCode::INTERNAL_SERVER_ERROR,
343 Json(json!({"error": "InternalError"})),
344 )
345 .into_response();
346 }
347 };
348
349 let public_key = secret_key.public_key();
350 let encoded = public_key.to_encoded_point(true);
351 let did_key = format!(
352 "did:key:zQ3sh{}",
353 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
354 .chars()
355 .skip(1)
356 .collect::<String>()
357 );
358
359 (
360 StatusCode::OK,
361 Json(GetRecommendedDidCredentialsOutput {
362 rotation_keys: vec![did_key.clone()],
363 also_known_as: vec![format!("at://{}", handle)],
364 verification_methods: VerificationMethods { atproto: did_key },
365 services: Services {
366 atproto_pds: AtprotoPds {
367 service_type: "AtprotoPersonalDataServer".to_string(),
368 endpoint: pds_endpoint,
369 },
370 },
371 }),
372 )
373 .into_response()
374}
375
376#[derive(Deserialize)]
377pub struct UpdateHandleInput {
378 pub handle: String,
379}
380
381pub async fn update_handle(
382 State(state): State<AppState>,
383 headers: axum::http::HeaderMap,
384 Json(input): Json<UpdateHandleInput>,
385) -> Response {
386 let token = match crate::auth::extract_bearer_token_from_header(
387 headers.get("Authorization").and_then(|h| h.to_str().ok())
388 ) {
389 Some(t) => t,
390 None => {
391 return (
392 StatusCode::UNAUTHORIZED,
393 Json(json!({"error": "AuthenticationRequired"})),
394 )
395 .into_response();
396 }
397 };
398
399 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
400 let did = match auth_result {
401 Ok(user) => user.did,
402 Err(e) => {
403 return (
404 StatusCode::UNAUTHORIZED,
405 Json(json!({"error": e})),
406 )
407 .into_response();
408 }
409 };
410
411 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
412 .fetch_optional(&state.db)
413 .await
414 {
415 Ok(Some(id)) => id,
416 _ => {
417 return (
418 StatusCode::INTERNAL_SERVER_ERROR,
419 Json(json!({"error": "InternalError"})),
420 )
421 .into_response();
422 }
423 };
424
425 let new_handle = input.handle.trim();
426 if new_handle.is_empty() {
427 return (
428 StatusCode::BAD_REQUEST,
429 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
430 )
431 .into_response();
432 }
433
434 if !new_handle
435 .chars()
436 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
437 {
438 return (
439 StatusCode::BAD_REQUEST,
440 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
441 )
442 .into_response();
443 }
444
445 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
446 .fetch_optional(&state.db)
447 .await;
448
449 if let Ok(Some(_)) = existing {
450 return (
451 StatusCode::BAD_REQUEST,
452 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
453 )
454 .into_response();
455 }
456
457 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id)
458 .execute(&state.db)
459 .await;
460
461 match result {
462 Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
463 Err(e) => {
464 error!("DB error updating handle: {:?}", e);
465 (
466 StatusCode::INTERNAL_SERVER_ERROR,
467 Json(json!({"error": "InternalError"})),
468 )
469 .into_response()
470 }
471 }
472}