this repo has no description
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::{Path, Query, State},
6 http::StatusCode,
7 response::{IntoResponse, Response},
8};
9use base64::Engine;
10use k256::SecretKey;
11use k256::elliptic_curve::sec1::ToEncodedPoint;
12use reqwest;
13use serde::Deserialize;
14use serde_json::json;
15use tracing::error;
16
17#[derive(Deserialize)]
18pub struct ResolveHandleParams {
19 pub handle: String,
20}
21
22pub async fn resolve_handle(
23 State(state): State<AppState>,
24 Query(params): Query<ResolveHandleParams>,
25) -> Response {
26 let handle = params.handle.trim();
27
28 if handle.is_empty() {
29 return (
30 StatusCode::BAD_REQUEST,
31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
32 )
33 .into_response();
34 }
35
36 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle)
37 .fetch_optional(&state.db)
38 .await;
39
40 match user {
41 Ok(Some(row)) => {
42 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
43 }
44 Ok(None) => (
45 StatusCode::NOT_FOUND,
46 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
47 )
48 .into_response(),
49 Err(e) => {
50 error!("DB error resolving handle: {:?}", e);
51 (
52 StatusCode::INTERNAL_SERVER_ERROR,
53 Json(json!({"error": "InternalError"})),
54 )
55 .into_response()
56 }
57 }
58}
59
60pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
61 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
62 let public_key = secret_key.public_key();
63 let encoded = public_key.to_encoded_point(false);
64 let x = encoded.x().ok_or("Missing x coordinate")?;
65 let y = encoded.y().ok_or("Missing y coordinate")?;
66 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
67 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
68
69 Ok(json!({
70 "kty": "EC",
71 "crv": "secp256k1",
72 "x": x_b64,
73 "y": y_b64
74 }))
75}
76
77pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
78 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
79 // Kinda for local dev, encode hostname if it contains port
80 let did = if hostname.contains(':') {
81 format!("did:web:{}", hostname.replace(':', "%3A"))
82 } else {
83 format!("did:web:{}", hostname)
84 };
85
86 Json(json!({
87 "@context": ["https://www.w3.org/ns/did/v1"],
88 "id": did,
89 "service": [{
90 "id": "#atproto_pds",
91 "type": "AtprotoPersonalDataServer",
92 "serviceEndpoint": format!("https://{}", hostname)
93 }]
94 }))
95}
96
97pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
98 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
99
100 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
101 .fetch_optional(&state.db)
102 .await;
103
104 let (user_id, did) = match user {
105 Ok(Some(row)) => (row.id, row.did),
106 Ok(None) => {
107 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
108 }
109 Err(e) => {
110 error!("DB Error: {:?}", e);
111 return (
112 StatusCode::INTERNAL_SERVER_ERROR,
113 Json(json!({"error": "InternalError"})),
114 )
115 .into_response();
116 }
117 };
118
119 if !did.starts_with("did:web:") {
120 return (
121 StatusCode::NOT_FOUND,
122 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
123 )
124 .into_response();
125 }
126
127 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id)
128 .fetch_optional(&state.db)
129 .await;
130
131 let key_bytes: Vec<u8> = match key_row {
132 Ok(Some(row)) => {
133 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
134 Ok(k) => k,
135 Err(_) => {
136 return (
137 StatusCode::INTERNAL_SERVER_ERROR,
138 Json(json!({"error": "InternalError"})),
139 )
140 .into_response();
141 }
142 }
143 }
144 _ => {
145 return (
146 StatusCode::INTERNAL_SERVER_ERROR,
147 Json(json!({"error": "InternalError"})),
148 )
149 .into_response();
150 }
151 };
152
153 let jwk = match get_jwk(&key_bytes) {
154 Ok(j) => j,
155 Err(e) => {
156 tracing::error!("Failed to generate JWK: {}", e);
157 return (
158 StatusCode::INTERNAL_SERVER_ERROR,
159 Json(json!({"error": "InternalError"})),
160 )
161 .into_response();
162 }
163 };
164
165 Json(json!({
166 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
167 "id": did,
168 "alsoKnownAs": [format!("at://{}", handle)],
169 "verificationMethod": [{
170 "id": format!("{}#atproto", did),
171 "type": "JsonWebKey2020",
172 "controller": did,
173 "publicKeyJwk": jwk
174 }],
175 "service": [{
176 "id": "#atproto_pds",
177 "type": "AtprotoPersonalDataServer",
178 "serviceEndpoint": format!("https://{}", hostname)
179 }]
180 })).into_response()
181}
182
183pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
184 let expected_prefix = if hostname.contains(':') {
185 format!("did:web:{}", hostname.replace(':', "%3A"))
186 } else {
187 format!("did:web:{}", hostname)
188 };
189
190 if did.starts_with(&expected_prefix) {
191 let suffix = &did[expected_prefix.len()..];
192 let expected_suffix = format!(":u:{}", handle);
193 if suffix == expected_suffix {
194 Ok(())
195 } else {
196 Err(format!(
197 "Invalid DID path for this PDS. Expected {}",
198 expected_suffix
199 ))
200 }
201 } else {
202 let parts: Vec<&str> = did.split(':').collect();
203 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
204 return Err("Invalid did:web format".into());
205 }
206
207 let domain_segment = parts[2];
208 let domain = domain_segment.replace("%3A", ":");
209
210 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
211 "http"
212 } else {
213 "https"
214 };
215
216 let url = if parts.len() == 3 {
217 format!("{}://{}/.well-known/did.json", scheme, domain)
218 } else {
219 let path = parts[3..].join("/");
220 format!("{}://{}/{}/did.json", scheme, domain, path)
221 };
222
223 let client = reqwest::Client::builder()
224 .timeout(std::time::Duration::from_secs(5))
225 .build()
226 .map_err(|e| format!("Failed to create client: {}", e))?;
227
228 let resp = client
229 .get(&url)
230 .send()
231 .await
232 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
233
234 if !resp.status().is_success() {
235 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
236 }
237
238 let doc: serde_json::Value = resp
239 .json()
240 .await
241 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
242
243 let services = doc["service"]
244 .as_array()
245 .ok_or("No services found in DID doc")?;
246
247 let pds_endpoint = format!("https://{}", hostname);
248
249 let has_valid_service = services.iter().any(|s| {
250 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
251 });
252
253 if has_valid_service {
254 Ok(())
255 } else {
256 Err(format!(
257 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
258 pds_endpoint
259 ))
260 }
261 }
262}
263
264#[derive(serde::Serialize)]
265#[serde(rename_all = "camelCase")]
266pub struct GetRecommendedDidCredentialsOutput {
267 pub rotation_keys: Vec<String>,
268 pub also_known_as: Vec<String>,
269 pub verification_methods: VerificationMethods,
270 pub services: Services,
271}
272
273#[derive(serde::Serialize)]
274#[serde(rename_all = "camelCase")]
275pub struct VerificationMethods {
276 pub atproto: String,
277}
278
279#[derive(serde::Serialize)]
280#[serde(rename_all = "camelCase")]
281pub struct Services {
282 pub atproto_pds: AtprotoPds,
283}
284
285#[derive(serde::Serialize)]
286#[serde(rename_all = "camelCase")]
287pub struct AtprotoPds {
288 #[serde(rename = "type")]
289 pub service_type: String,
290 pub endpoint: String,
291}
292
293pub async fn get_recommended_did_credentials(
294 State(state): State<AppState>,
295 headers: axum::http::HeaderMap,
296) -> Response {
297 let token = match crate::auth::extract_bearer_token_from_header(
298 headers.get("Authorization").and_then(|h| h.to_str().ok())
299 ) {
300 Some(t) => t,
301 None => {
302 return (
303 StatusCode::UNAUTHORIZED,
304 Json(json!({"error": "AuthenticationRequired"})),
305 )
306 .into_response();
307 }
308 };
309
310 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
311 Ok(user) => user,
312 Err(e) => return ApiError::from(e).into_response(),
313 };
314
315 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did)
316 .fetch_optional(&state.db)
317 .await
318 {
319 Ok(Some(row)) => row,
320 _ => return ApiError::InternalError.into_response(),
321 };
322
323 let key_bytes = match auth_user.key_bytes {
324 Some(kb) => kb,
325 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(),
326 };
327
328 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
329 let pds_endpoint = format!("https://{}", hostname);
330
331 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
332 Ok(k) => k,
333 Err(_) => return ApiError::InternalError.into_response(),
334 };
335
336 let public_key = secret_key.public_key();
337 let encoded = public_key.to_encoded_point(true);
338 let did_key = format!(
339 "did:key:zQ3sh{}",
340 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
341 .chars()
342 .skip(1)
343 .collect::<String>()
344 );
345
346 (
347 StatusCode::OK,
348 Json(GetRecommendedDidCredentialsOutput {
349 rotation_keys: vec![did_key.clone()],
350 also_known_as: vec![format!("at://{}", user.handle)],
351 verification_methods: VerificationMethods { atproto: did_key },
352 services: Services {
353 atproto_pds: AtprotoPds {
354 service_type: "AtprotoPersonalDataServer".to_string(),
355 endpoint: pds_endpoint,
356 },
357 },
358 }),
359 )
360 .into_response()
361}
362
363#[derive(Deserialize)]
364pub struct UpdateHandleInput {
365 pub handle: String,
366}
367
368pub async fn update_handle(
369 State(state): State<AppState>,
370 headers: axum::http::HeaderMap,
371 Json(input): Json<UpdateHandleInput>,
372) -> Response {
373 let token = match crate::auth::extract_bearer_token_from_header(
374 headers.get("Authorization").and_then(|h| h.to_str().ok())
375 ) {
376 Some(t) => t,
377 None => return ApiError::AuthenticationRequired.into_response(),
378 };
379
380 let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
381 Ok(user) => user.did,
382 Err(e) => return ApiError::from(e).into_response(),
383 };
384
385 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
386 .fetch_optional(&state.db)
387 .await
388 {
389 Ok(Some(id)) => id,
390 _ => return ApiError::InternalError.into_response(),
391 };
392
393 let new_handle = input.handle.trim();
394 if new_handle.is_empty() {
395 return ApiError::InvalidRequest("handle is required".into()).into_response();
396 }
397
398 if !new_handle
399 .chars()
400 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
401 {
402 return (
403 StatusCode::BAD_REQUEST,
404 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
405 )
406 .into_response();
407 }
408
409 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
410 .fetch_optional(&state.db)
411 .await;
412
413 if let Ok(Some(_)) = existing {
414 return (
415 StatusCode::BAD_REQUEST,
416 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
417 )
418 .into_response();
419 }
420
421 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id)
422 .execute(&state.db)
423 .await;
424
425 match result {
426 Ok(_) => (StatusCode::OK, Json(json!({}))).into_response(),
427 Err(e) => {
428 error!("DB error updating handle: {:?}", e);
429 (
430 StatusCode::INTERNAL_SERVER_ERROR,
431 Json(json!({"error": "InternalError"})),
432 )
433 .into_response()
434 }
435 }
436}