This repository has no description.
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::{Path, Query, State},
6 http::{HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8};
9use base64::Engine;
10use k256::SecretKey;
11use k256::elliptic_curve::sec1::ToEncodedPoint;
12use reqwest;
13use serde::Deserialize;
14use serde_json::json;
15use tracing::{error, warn};
16
17#[derive(Deserialize)]
18pub struct ResolveHandleParams {
19 pub handle: String,
20}
21
22pub async fn resolve_handle(
23 State(state): State<AppState>,
24 Query(params): Query<ResolveHandleParams>,
25) -> Response {
26 let handle = params.handle.trim();
27 if handle.is_empty() {
28 return (
29 StatusCode::BAD_REQUEST,
30 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
31 )
32 .into_response();
33 }
34 let cache_key = format!("handle:{}", handle);
35 if let Some(did) = state.cache.get(&cache_key).await {
36 return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
37 }
38 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
39 let suffix = format!(".{}", hostname);
40 let short_handle = if handle.ends_with(&suffix) {
41 handle.strip_suffix(&suffix).unwrap_or(handle)
42 } else {
43 handle
44 };
45 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
46 .fetch_optional(&state.db)
47 .await;
48 match user {
49 Ok(Some(row)) => {
50 let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await;
51 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
52 }
53 Ok(None) => (
54 StatusCode::NOT_FOUND,
55 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
56 )
57 .into_response(),
58 Err(e) => {
59 error!("DB error resolving handle: {:?}", e);
60 (
61 StatusCode::INTERNAL_SERVER_ERROR,
62 Json(json!({"error": "InternalError"})),
63 )
64 .into_response()
65 }
66 }
67}
68
69pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
70 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
71 let public_key = secret_key.public_key();
72 let encoded = public_key.to_encoded_point(false);
73 let x = encoded.x().ok_or("Missing x coordinate")?;
74 let y = encoded.y().ok_or("Missing y coordinate")?;
75 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
76 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
77 Ok(json!({
78 "kty": "EC",
79 "crv": "secp256k1",
80 "x": x_b64,
81 "y": y_b64
82 }))
83}
84
85pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
86 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
87 // Kinda for local dev, encode hostname if it contains port
88 let did = if hostname.contains(':') {
89 format!("did:web:{}", hostname.replace(':', "%3A"))
90 } else {
91 format!("did:web:{}", hostname)
92 };
93 Json(json!({
94 "@context": ["https://www.w3.org/ns/did/v1"],
95 "id": did,
96 "service": [{
97 "id": "#atproto_pds",
98 "type": "AtprotoPersonalDataServer",
99 "serviceEndpoint": format!("https://{}", hostname)
100 }]
101 }))
102}
103
104pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
105 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
106 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
107 .fetch_optional(&state.db)
108 .await;
109 let (user_id, did) = match user {
110 Ok(Some(row)) => (row.id, row.did),
111 Ok(None) => {
112 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
113 }
114 Err(e) => {
115 error!("DB Error: {:?}", e);
116 return (
117 StatusCode::INTERNAL_SERVER_ERROR,
118 Json(json!({"error": "InternalError"})),
119 )
120 .into_response();
121 }
122 };
123 if !did.starts_with("did:web:") {
124 return (
125 StatusCode::NOT_FOUND,
126 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
127 )
128 .into_response();
129 }
130 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id)
131 .fetch_optional(&state.db)
132 .await;
133 let key_bytes: Vec<u8> = match key_row {
134 Ok(Some(row)) => {
135 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
136 Ok(k) => k,
137 Err(_) => {
138 return (
139 StatusCode::INTERNAL_SERVER_ERROR,
140 Json(json!({"error": "InternalError"})),
141 )
142 .into_response();
143 }
144 }
145 }
146 _ => {
147 return (
148 StatusCode::INTERNAL_SERVER_ERROR,
149 Json(json!({"error": "InternalError"})),
150 )
151 .into_response();
152 }
153 };
154 let jwk = match get_jwk(&key_bytes) {
155 Ok(j) => j,
156 Err(e) => {
157 tracing::error!("Failed to generate JWK: {}", e);
158 return (
159 StatusCode::INTERNAL_SERVER_ERROR,
160 Json(json!({"error": "InternalError"})),
161 )
162 .into_response();
163 }
164 };
165 Json(json!({
166 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
167 "id": did,
168 "alsoKnownAs": [format!("at://{}", handle)],
169 "verificationMethod": [{
170 "id": format!("{}#atproto", did),
171 "type": "JsonWebKey2020",
172 "controller": did,
173 "publicKeyJwk": jwk
174 }],
175 "service": [{
176 "id": "#atproto_pds",
177 "type": "AtprotoPersonalDataServer",
178 "serviceEndpoint": format!("https://{}", hostname)
179 }]
180 })).into_response()
181}
182
183pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
184 let expected_prefix = if hostname.contains(':') {
185 format!("did:web:{}", hostname.replace(':', "%3A"))
186 } else {
187 format!("did:web:{}", hostname)
188 };
189 if did.starts_with(&expected_prefix) {
190 let suffix = &did[expected_prefix.len()..];
191 let expected_suffix = format!(":u:{}", handle);
192 if suffix == expected_suffix {
193 Ok(())
194 } else {
195 Err(format!(
196 "Invalid DID path for this PDS. Expected {}",
197 expected_suffix
198 ))
199 }
200 } else {
201 let parts: Vec<&str> = did.split(':').collect();
202 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
203 return Err("Invalid did:web format".into());
204 }
205 let domain_segment = parts[2];
206 let domain = domain_segment.replace("%3A", ":");
207 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
208 "http"
209 } else {
210 "https"
211 };
212 let url = if parts.len() == 3 {
213 format!("{}://{}/.well-known/did.json", scheme, domain)
214 } else {
215 let path = parts[3..].join("/");
216 format!("{}://{}/{}/did.json", scheme, domain, path)
217 };
218 let client = reqwest::Client::builder()
219 .timeout(std::time::Duration::from_secs(5))
220 .build()
221 .map_err(|e| format!("Failed to create client: {}", e))?;
222 let resp = client
223 .get(&url)
224 .send()
225 .await
226 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
227 if !resp.status().is_success() {
228 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
229 }
230 let doc: serde_json::Value = resp
231 .json()
232 .await
233 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
234 let services = doc["service"]
235 .as_array()
236 .ok_or("No services found in DID doc")?;
237 let pds_endpoint = format!("https://{}", hostname);
238 let has_valid_service = services.iter().any(|s| {
239 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
240 });
241 if has_valid_service {
242 Ok(())
243 } else {
244 Err(format!(
245 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
246 pds_endpoint
247 ))
248 }
249 }
250}
251
252#[derive(serde::Serialize)]
253#[serde(rename_all = "camelCase")]
254pub struct GetRecommendedDidCredentialsOutput {
255 pub rotation_keys: Vec<String>,
256 pub also_known_as: Vec<String>,
257 pub verification_methods: VerificationMethods,
258 pub services: Services,
259}
260
261#[derive(serde::Serialize)]
262#[serde(rename_all = "camelCase")]
263pub struct VerificationMethods {
264 pub atproto: String,
265}
266
267#[derive(serde::Serialize)]
268#[serde(rename_all = "camelCase")]
269pub struct Services {
270 pub atproto_pds: AtprotoPds,
271}
272
273#[derive(serde::Serialize)]
274#[serde(rename_all = "camelCase")]
275pub struct AtprotoPds {
276 #[serde(rename = "type")]
277 pub service_type: String,
278 pub endpoint: String,
279}
280
281pub async fn get_recommended_did_credentials(
282 State(state): State<AppState>,
283 headers: axum::http::HeaderMap,
284) -> Response {
285 let token = match crate::auth::extract_bearer_token_from_header(
286 headers.get("Authorization").and_then(|h| h.to_str().ok())
287 ) {
288 Some(t) => t,
289 None => {
290 return (
291 StatusCode::UNAUTHORIZED,
292 Json(json!({"error": "AuthenticationRequired"})),
293 )
294 .into_response();
295 }
296 };
297 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
298 Ok(user) => user,
299 Err(e) => return ApiError::from(e).into_response(),
300 };
301 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did)
302 .fetch_optional(&state.db)
303 .await
304 {
305 Ok(Some(row)) => row,
306 _ => return ApiError::InternalError.into_response(),
307 };
308 let key_bytes = match auth_user.key_bytes {
309 Some(kb) => kb,
310 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(),
311 };
312 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
313 let pds_endpoint = format!("https://{}", hostname);
314 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
315 Ok(k) => k,
316 Err(_) => return ApiError::InternalError.into_response(),
317 };
318 let public_key = secret_key.public_key();
319 let encoded = public_key.to_encoded_point(true);
320 let did_key = format!(
321 "did:key:zQ3sh{}",
322 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
323 .chars()
324 .skip(1)
325 .collect::<String>()
326 );
327 (
328 StatusCode::OK,
329 Json(GetRecommendedDidCredentialsOutput {
330 rotation_keys: vec![did_key.clone()],
331 also_known_as: vec![format!("at://{}", user.handle)],
332 verification_methods: VerificationMethods { atproto: did_key },
333 services: Services {
334 atproto_pds: AtprotoPds {
335 service_type: "AtprotoPersonalDataServer".to_string(),
336 endpoint: pds_endpoint,
337 },
338 },
339 }),
340 )
341 .into_response()
342}
343
344#[derive(Deserialize)]
345pub struct UpdateHandleInput {
346 pub handle: String,
347}
348
349pub async fn update_handle(
350 State(state): State<AppState>,
351 headers: axum::http::HeaderMap,
352 Json(input): Json<UpdateHandleInput>,
353) -> Response {
354 let token = match crate::auth::extract_bearer_token_from_header(
355 headers.get("Authorization").and_then(|h| h.to_str().ok())
356 ) {
357 Some(t) => t,
358 None => return ApiError::AuthenticationRequired.into_response(),
359 };
360 let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
361 Ok(user) => user.did,
362 Err(e) => return ApiError::from(e).into_response(),
363 };
364 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
365 .fetch_optional(&state.db)
366 .await
367 {
368 Ok(Some(id)) => id,
369 _ => return ApiError::InternalError.into_response(),
370 };
371 let new_handle = input.handle.trim();
372 if new_handle.is_empty() {
373 return ApiError::InvalidRequest("handle is required".into()).into_response();
374 }
375 if !new_handle
376 .chars()
377 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
378 {
379 return (
380 StatusCode::BAD_REQUEST,
381 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
382 )
383 .into_response();
384 }
385 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
386 .fetch_optional(&state.db)
387 .await
388 .ok()
389 .flatten();
390 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
391 .fetch_optional(&state.db)
392 .await;
393 if let Ok(Some(_)) = existing {
394 return (
395 StatusCode::BAD_REQUEST,
396 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
397 )
398 .into_response();
399 }
400 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id)
401 .execute(&state.db)
402 .await;
403 match result {
404 Ok(_) => {
405 if let Some(old) = old_handle {
406 let _ = state.cache.delete(&format!("handle:{}", old)).await;
407 }
408 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await;
409 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
410 let full_handle = format!("{}.{}", new_handle, hostname);
411 if let Err(e) = crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await {
412 warn!("Failed to sequence identity event for handle update: {}", e);
413 }
414 (StatusCode::OK, Json(json!({}))).into_response()
415 }
416 Err(e) => {
417 error!("DB error updating handle: {:?}", e);
418 (
419 StatusCode::INTERNAL_SERVER_ERROR,
420 Json(json!({"error": "InternalError"})),
421 )
422 .into_response()
423 }
424 }
425}
426
427pub async fn well_known_atproto_did(
428 State(state): State<AppState>,
429 headers: HeaderMap,
430) -> Response {
431 let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
432 Some(h) => h,
433 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
434 };
435 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
436 let suffix = format!(".{}", hostname);
437 let handle = host.split(':').next().unwrap_or(host);
438 let short_handle = if handle.ends_with(&suffix) {
439 handle.strip_suffix(&suffix).unwrap_or(handle)
440 } else {
441 return (StatusCode::NOT_FOUND, "Handle not found").into_response();
442 };
443 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
444 .fetch_optional(&state.db)
445 .await;
446 match user {
447 Ok(Some(row)) => row.did.into_response(),
448 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(),
449 Err(e) => {
450 error!("DB error in well-known atproto-did: {:?}", e);
451 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
452 }
453 }
454}