this repo has no description
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::{Path, Query, State},
6 http::StatusCode,
7 response::{IntoResponse, Response},
8};
9use base64::Engine;
10use k256::SecretKey;
11use k256::elliptic_curve::sec1::ToEncodedPoint;
12use reqwest;
13use serde::Deserialize;
14use serde_json::json;
15use tracing::error;
16
17#[derive(Deserialize)]
18pub struct ResolveHandleParams {
19 pub handle: String,
20}
21
22pub async fn resolve_handle(
23 State(state): State<AppState>,
24 Query(params): Query<ResolveHandleParams>,
25) -> Response {
26 let handle = params.handle.trim();
27
28 if handle.is_empty() {
29 return (
30 StatusCode::BAD_REQUEST,
31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
32 )
33 .into_response();
34 }
35
36 let cache_key = format!("handle:{}", handle);
37 if let Some(did) = state.cache.get(&cache_key).await {
38 return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
39 }
40
41 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle)
42 .fetch_optional(&state.db)
43 .await;
44
45 match user {
46 Ok(Some(row)) => {
47 let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await;
48 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
49 }
50 Ok(None) => (
51 StatusCode::NOT_FOUND,
52 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
53 )
54 .into_response(),
55 Err(e) => {
56 error!("DB error resolving handle: {:?}", e);
57 (
58 StatusCode::INTERNAL_SERVER_ERROR,
59 Json(json!({"error": "InternalError"})),
60 )
61 .into_response()
62 }
63 }
64}
65
66pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
67 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
68 let public_key = secret_key.public_key();
69 let encoded = public_key.to_encoded_point(false);
70 let x = encoded.x().ok_or("Missing x coordinate")?;
71 let y = encoded.y().ok_or("Missing y coordinate")?;
72 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
73 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
74
75 Ok(json!({
76 "kty": "EC",
77 "crv": "secp256k1",
78 "x": x_b64,
79 "y": y_b64
80 }))
81}
82
83pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
84 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
85 // Kinda for local dev, encode hostname if it contains port
86 let did = if hostname.contains(':') {
87 format!("did:web:{}", hostname.replace(':', "%3A"))
88 } else {
89 format!("did:web:{}", hostname)
90 };
91
92 Json(json!({
93 "@context": ["https://www.w3.org/ns/did/v1"],
94 "id": did,
95 "service": [{
96 "id": "#atproto_pds",
97 "type": "AtprotoPersonalDataServer",
98 "serviceEndpoint": format!("https://{}", hostname)
99 }]
100 }))
101}
102
103pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
104 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
105
106 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
107 .fetch_optional(&state.db)
108 .await;
109
110 let (user_id, did) = match user {
111 Ok(Some(row)) => (row.id, row.did),
112 Ok(None) => {
113 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
114 }
115 Err(e) => {
116 error!("DB Error: {:?}", e);
117 return (
118 StatusCode::INTERNAL_SERVER_ERROR,
119 Json(json!({"error": "InternalError"})),
120 )
121 .into_response();
122 }
123 };
124
125 if !did.starts_with("did:web:") {
126 return (
127 StatusCode::NOT_FOUND,
128 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
129 )
130 .into_response();
131 }
132
133 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id)
134 .fetch_optional(&state.db)
135 .await;
136
137 let key_bytes: Vec<u8> = match key_row {
138 Ok(Some(row)) => {
139 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
140 Ok(k) => k,
141 Err(_) => {
142 return (
143 StatusCode::INTERNAL_SERVER_ERROR,
144 Json(json!({"error": "InternalError"})),
145 )
146 .into_response();
147 }
148 }
149 }
150 _ => {
151 return (
152 StatusCode::INTERNAL_SERVER_ERROR,
153 Json(json!({"error": "InternalError"})),
154 )
155 .into_response();
156 }
157 };
158
159 let jwk = match get_jwk(&key_bytes) {
160 Ok(j) => j,
161 Err(e) => {
162 tracing::error!("Failed to generate JWK: {}", e);
163 return (
164 StatusCode::INTERNAL_SERVER_ERROR,
165 Json(json!({"error": "InternalError"})),
166 )
167 .into_response();
168 }
169 };
170
171 Json(json!({
172 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
173 "id": did,
174 "alsoKnownAs": [format!("at://{}", handle)],
175 "verificationMethod": [{
176 "id": format!("{}#atproto", did),
177 "type": "JsonWebKey2020",
178 "controller": did,
179 "publicKeyJwk": jwk
180 }],
181 "service": [{
182 "id": "#atproto_pds",
183 "type": "AtprotoPersonalDataServer",
184 "serviceEndpoint": format!("https://{}", hostname)
185 }]
186 })).into_response()
187}
188
189pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
190 let expected_prefix = if hostname.contains(':') {
191 format!("did:web:{}", hostname.replace(':', "%3A"))
192 } else {
193 format!("did:web:{}", hostname)
194 };
195
196 if did.starts_with(&expected_prefix) {
197 let suffix = &did[expected_prefix.len()..];
198 let expected_suffix = format!(":u:{}", handle);
199 if suffix == expected_suffix {
200 Ok(())
201 } else {
202 Err(format!(
203 "Invalid DID path for this PDS. Expected {}",
204 expected_suffix
205 ))
206 }
207 } else {
208 let parts: Vec<&str> = did.split(':').collect();
209 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
210 return Err("Invalid did:web format".into());
211 }
212
213 let domain_segment = parts[2];
214 let domain = domain_segment.replace("%3A", ":");
215
216 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
217 "http"
218 } else {
219 "https"
220 };
221
222 let url = if parts.len() == 3 {
223 format!("{}://{}/.well-known/did.json", scheme, domain)
224 } else {
225 let path = parts[3..].join("/");
226 format!("{}://{}/{}/did.json", scheme, domain, path)
227 };
228
229 let client = reqwest::Client::builder()
230 .timeout(std::time::Duration::from_secs(5))
231 .build()
232 .map_err(|e| format!("Failed to create client: {}", e))?;
233
234 let resp = client
235 .get(&url)
236 .send()
237 .await
238 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
239
240 if !resp.status().is_success() {
241 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
242 }
243
244 let doc: serde_json::Value = resp
245 .json()
246 .await
247 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
248
249 let services = doc["service"]
250 .as_array()
251 .ok_or("No services found in DID doc")?;
252
253 let pds_endpoint = format!("https://{}", hostname);
254
255 let has_valid_service = services.iter().any(|s| {
256 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
257 });
258
259 if has_valid_service {
260 Ok(())
261 } else {
262 Err(format!(
263 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
264 pds_endpoint
265 ))
266 }
267 }
268}
269
270#[derive(serde::Serialize)]
271#[serde(rename_all = "camelCase")]
272pub struct GetRecommendedDidCredentialsOutput {
273 pub rotation_keys: Vec<String>,
274 pub also_known_as: Vec<String>,
275 pub verification_methods: VerificationMethods,
276 pub services: Services,
277}
278
279#[derive(serde::Serialize)]
280#[serde(rename_all = "camelCase")]
281pub struct VerificationMethods {
282 pub atproto: String,
283}
284
285#[derive(serde::Serialize)]
286#[serde(rename_all = "camelCase")]
287pub struct Services {
288 pub atproto_pds: AtprotoPds,
289}
290
291#[derive(serde::Serialize)]
292#[serde(rename_all = "camelCase")]
293pub struct AtprotoPds {
294 #[serde(rename = "type")]
295 pub service_type: String,
296 pub endpoint: String,
297}
298
299pub async fn get_recommended_did_credentials(
300 State(state): State<AppState>,
301 headers: axum::http::HeaderMap,
302) -> Response {
303 let token = match crate::auth::extract_bearer_token_from_header(
304 headers.get("Authorization").and_then(|h| h.to_str().ok())
305 ) {
306 Some(t) => t,
307 None => {
308 return (
309 StatusCode::UNAUTHORIZED,
310 Json(json!({"error": "AuthenticationRequired"})),
311 )
312 .into_response();
313 }
314 };
315
316 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
317 Ok(user) => user,
318 Err(e) => return ApiError::from(e).into_response(),
319 };
320
321 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did)
322 .fetch_optional(&state.db)
323 .await
324 {
325 Ok(Some(row)) => row,
326 _ => return ApiError::InternalError.into_response(),
327 };
328
329 let key_bytes = match auth_user.key_bytes {
330 Some(kb) => kb,
331 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(),
332 };
333
334 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
335 let pds_endpoint = format!("https://{}", hostname);
336
337 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
338 Ok(k) => k,
339 Err(_) => return ApiError::InternalError.into_response(),
340 };
341
342 let public_key = secret_key.public_key();
343 let encoded = public_key.to_encoded_point(true);
344 let did_key = format!(
345 "did:key:zQ3sh{}",
346 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
347 .chars()
348 .skip(1)
349 .collect::<String>()
350 );
351
352 (
353 StatusCode::OK,
354 Json(GetRecommendedDidCredentialsOutput {
355 rotation_keys: vec![did_key.clone()],
356 also_known_as: vec![format!("at://{}", user.handle)],
357 verification_methods: VerificationMethods { atproto: did_key },
358 services: Services {
359 atproto_pds: AtprotoPds {
360 service_type: "AtprotoPersonalDataServer".to_string(),
361 endpoint: pds_endpoint,
362 },
363 },
364 }),
365 )
366 .into_response()
367}
368
369#[derive(Deserialize)]
370pub struct UpdateHandleInput {
371 pub handle: String,
372}
373
374pub async fn update_handle(
375 State(state): State<AppState>,
376 headers: axum::http::HeaderMap,
377 Json(input): Json<UpdateHandleInput>,
378) -> Response {
379 let token = match crate::auth::extract_bearer_token_from_header(
380 headers.get("Authorization").and_then(|h| h.to_str().ok())
381 ) {
382 Some(t) => t,
383 None => return ApiError::AuthenticationRequired.into_response(),
384 };
385
386 let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
387 Ok(user) => user.did,
388 Err(e) => return ApiError::from(e).into_response(),
389 };
390
391 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
392 .fetch_optional(&state.db)
393 .await
394 {
395 Ok(Some(id)) => id,
396 _ => return ApiError::InternalError.into_response(),
397 };
398
399 let new_handle = input.handle.trim();
400 if new_handle.is_empty() {
401 return ApiError::InvalidRequest("handle is required".into()).into_response();
402 }
403
404 if !new_handle
405 .chars()
406 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
407 {
408 return (
409 StatusCode::BAD_REQUEST,
410 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
411 )
412 .into_response();
413 }
414
415 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
416 .fetch_optional(&state.db)
417 .await
418 .ok()
419 .flatten();
420
421 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
422 .fetch_optional(&state.db)
423 .await;
424
425 if let Ok(Some(_)) = existing {
426 return (
427 StatusCode::BAD_REQUEST,
428 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
429 )
430 .into_response();
431 }
432
433 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id)
434 .execute(&state.db)
435 .await;
436
437 match result {
438 Ok(_) => {
439 if let Some(old) = old_handle {
440 let _ = state.cache.delete(&format!("handle:{}", old)).await;
441 }
442 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await;
443 (StatusCode::OK, Json(json!({}))).into_response()
444 }
445 Err(e) => {
446 error!("DB error updating handle: {:?}", e);
447 (
448 StatusCode::INTERNAL_SERVER_ERROR,
449 Json(json!({"error": "InternalError"})),
450 )
451 .into_response()
452 }
453 }
454}