this repo has no description
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::{Path, Query, State},
6 http::{HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8};
9use base64::Engine;
10use k256::SecretKey;
11use k256::elliptic_curve::sec1::ToEncodedPoint;
12use reqwest;
13use serde::Deserialize;
14use serde_json::json;
15use tracing::{error, warn};
16#[derive(Deserialize)]
17pub struct ResolveHandleParams {
18 pub handle: String,
19}
20pub async fn resolve_handle(
21 State(state): State<AppState>,
22 Query(params): Query<ResolveHandleParams>,
23) -> Response {
24 let handle = params.handle.trim();
25 if handle.is_empty() {
26 return (
27 StatusCode::BAD_REQUEST,
28 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
29 )
30 .into_response();
31 }
32 let cache_key = format!("handle:{}", handle);
33 if let Some(did) = state.cache.get(&cache_key).await {
34 return (StatusCode::OK, Json(json!({ "did": did }))).into_response();
35 }
36 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
37 let suffix = format!(".{}", hostname);
38 let short_handle = if handle.ends_with(&suffix) {
39 handle.strip_suffix(&suffix).unwrap_or(handle)
40 } else {
41 handle
42 };
43 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
44 .fetch_optional(&state.db)
45 .await;
46 match user {
47 Ok(Some(row)) => {
48 let _ = state.cache.set(&cache_key, &row.did, std::time::Duration::from_secs(300)).await;
49 (StatusCode::OK, Json(json!({ "did": row.did }))).into_response()
50 }
51 Ok(None) => (
52 StatusCode::NOT_FOUND,
53 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
54 )
55 .into_response(),
56 Err(e) => {
57 error!("DB error resolving handle: {:?}", e);
58 (
59 StatusCode::INTERNAL_SERVER_ERROR,
60 Json(json!({"error": "InternalError"})),
61 )
62 .into_response()
63 }
64 }
65}
66pub fn get_jwk(key_bytes: &[u8]) -> Result<serde_json::Value, &'static str> {
67 let secret_key = SecretKey::from_slice(key_bytes).map_err(|_| "Invalid key length")?;
68 let public_key = secret_key.public_key();
69 let encoded = public_key.to_encoded_point(false);
70 let x = encoded.x().ok_or("Missing x coordinate")?;
71 let y = encoded.y().ok_or("Missing y coordinate")?;
72 let x_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x);
73 let y_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y);
74 Ok(json!({
75 "kty": "EC",
76 "crv": "secp256k1",
77 "x": x_b64,
78 "y": y_b64
79 }))
80}
81pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
82 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
83 // Kinda for local dev, encode hostname if it contains port
84 let did = if hostname.contains(':') {
85 format!("did:web:{}", hostname.replace(':', "%3A"))
86 } else {
87 format!("did:web:{}", hostname)
88 };
89 Json(json!({
90 "@context": ["https://www.w3.org/ns/did/v1"],
91 "id": did,
92 "service": [{
93 "id": "#atproto_pds",
94 "type": "AtprotoPersonalDataServer",
95 "serviceEndpoint": format!("https://{}", hostname)
96 }]
97 }))
98}
99pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
100 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
101 let user = sqlx::query!("SELECT id, did FROM users WHERE handle = $1", handle)
102 .fetch_optional(&state.db)
103 .await;
104 let (user_id, did) = match user {
105 Ok(Some(row)) => (row.id, row.did),
106 Ok(None) => {
107 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
108 }
109 Err(e) => {
110 error!("DB Error: {:?}", e);
111 return (
112 StatusCode::INTERNAL_SERVER_ERROR,
113 Json(json!({"error": "InternalError"})),
114 )
115 .into_response();
116 }
117 };
118 if !did.starts_with("did:web:") {
119 return (
120 StatusCode::NOT_FOUND,
121 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
122 )
123 .into_response();
124 }
125 let key_row = sqlx::query!("SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", user_id)
126 .fetch_optional(&state.db)
127 .await;
128 let key_bytes: Vec<u8> = match key_row {
129 Ok(Some(row)) => {
130 match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
131 Ok(k) => k,
132 Err(_) => {
133 return (
134 StatusCode::INTERNAL_SERVER_ERROR,
135 Json(json!({"error": "InternalError"})),
136 )
137 .into_response();
138 }
139 }
140 }
141 _ => {
142 return (
143 StatusCode::INTERNAL_SERVER_ERROR,
144 Json(json!({"error": "InternalError"})),
145 )
146 .into_response();
147 }
148 };
149 let jwk = match get_jwk(&key_bytes) {
150 Ok(j) => j,
151 Err(e) => {
152 tracing::error!("Failed to generate JWK: {}", e);
153 return (
154 StatusCode::INTERNAL_SERVER_ERROR,
155 Json(json!({"error": "InternalError"})),
156 )
157 .into_response();
158 }
159 };
160 Json(json!({
161 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
162 "id": did,
163 "alsoKnownAs": [format!("at://{}", handle)],
164 "verificationMethod": [{
165 "id": format!("{}#atproto", did),
166 "type": "JsonWebKey2020",
167 "controller": did,
168 "publicKeyJwk": jwk
169 }],
170 "service": [{
171 "id": "#atproto_pds",
172 "type": "AtprotoPersonalDataServer",
173 "serviceEndpoint": format!("https://{}", hostname)
174 }]
175 })).into_response()
176}
177pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
178 let expected_prefix = if hostname.contains(':') {
179 format!("did:web:{}", hostname.replace(':', "%3A"))
180 } else {
181 format!("did:web:{}", hostname)
182 };
183 if did.starts_with(&expected_prefix) {
184 let suffix = &did[expected_prefix.len()..];
185 let expected_suffix = format!(":u:{}", handle);
186 if suffix == expected_suffix {
187 Ok(())
188 } else {
189 Err(format!(
190 "Invalid DID path for this PDS. Expected {}",
191 expected_suffix
192 ))
193 }
194 } else {
195 let parts: Vec<&str> = did.split(':').collect();
196 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
197 return Err("Invalid did:web format".into());
198 }
199 let domain_segment = parts[2];
200 let domain = domain_segment.replace("%3A", ":");
201 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
202 "http"
203 } else {
204 "https"
205 };
206 let url = if parts.len() == 3 {
207 format!("{}://{}/.well-known/did.json", scheme, domain)
208 } else {
209 let path = parts[3..].join("/");
210 format!("{}://{}/{}/did.json", scheme, domain, path)
211 };
212 let client = reqwest::Client::builder()
213 .timeout(std::time::Duration::from_secs(5))
214 .build()
215 .map_err(|e| format!("Failed to create client: {}", e))?;
216 let resp = client
217 .get(&url)
218 .send()
219 .await
220 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
221 if !resp.status().is_success() {
222 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
223 }
224 let doc: serde_json::Value = resp
225 .json()
226 .await
227 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
228 let services = doc["service"]
229 .as_array()
230 .ok_or("No services found in DID doc")?;
231 let pds_endpoint = format!("https://{}", hostname);
232 let has_valid_service = services.iter().any(|s| {
233 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
234 });
235 if has_valid_service {
236 Ok(())
237 } else {
238 Err(format!(
239 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
240 pds_endpoint
241 ))
242 }
243 }
244}
245#[derive(serde::Serialize)]
246#[serde(rename_all = "camelCase")]
247pub struct GetRecommendedDidCredentialsOutput {
248 pub rotation_keys: Vec<String>,
249 pub also_known_as: Vec<String>,
250 pub verification_methods: VerificationMethods,
251 pub services: Services,
252}
253#[derive(serde::Serialize)]
254#[serde(rename_all = "camelCase")]
255pub struct VerificationMethods {
256 pub atproto: String,
257}
258#[derive(serde::Serialize)]
259#[serde(rename_all = "camelCase")]
260pub struct Services {
261 pub atproto_pds: AtprotoPds,
262}
263#[derive(serde::Serialize)]
264#[serde(rename_all = "camelCase")]
265pub struct AtprotoPds {
266 #[serde(rename = "type")]
267 pub service_type: String,
268 pub endpoint: String,
269}
270pub async fn get_recommended_did_credentials(
271 State(state): State<AppState>,
272 headers: axum::http::HeaderMap,
273) -> Response {
274 let token = match crate::auth::extract_bearer_token_from_header(
275 headers.get("Authorization").and_then(|h| h.to_str().ok())
276 ) {
277 Some(t) => t,
278 None => {
279 return (
280 StatusCode::UNAUTHORIZED,
281 Json(json!({"error": "AuthenticationRequired"})),
282 )
283 .into_response();
284 }
285 };
286 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
287 Ok(user) => user,
288 Err(e) => return ApiError::from(e).into_response(),
289 };
290 let user = match sqlx::query!("SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", auth_user.did)
291 .fetch_optional(&state.db)
292 .await
293 {
294 Ok(Some(row)) => row,
295 _ => return ApiError::InternalError.into_response(),
296 };
297 let key_bytes = match auth_user.key_bytes {
298 Some(kb) => kb,
299 None => return ApiError::AuthenticationFailedMsg("OAuth tokens cannot get DID credentials".into()).into_response(),
300 };
301 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
302 let pds_endpoint = format!("https://{}", hostname);
303 let secret_key = match k256::SecretKey::from_slice(&key_bytes) {
304 Ok(k) => k,
305 Err(_) => return ApiError::InternalError.into_response(),
306 };
307 let public_key = secret_key.public_key();
308 let encoded = public_key.to_encoded_point(true);
309 let did_key = format!(
310 "did:key:zQ3sh{}",
311 multibase::encode(multibase::Base::Base58Btc, encoded.as_bytes())
312 .chars()
313 .skip(1)
314 .collect::<String>()
315 );
316 (
317 StatusCode::OK,
318 Json(GetRecommendedDidCredentialsOutput {
319 rotation_keys: vec![did_key.clone()],
320 also_known_as: vec![format!("at://{}", user.handle)],
321 verification_methods: VerificationMethods { atproto: did_key },
322 services: Services {
323 atproto_pds: AtprotoPds {
324 service_type: "AtprotoPersonalDataServer".to_string(),
325 endpoint: pds_endpoint,
326 },
327 },
328 }),
329 )
330 .into_response()
331}
332#[derive(Deserialize)]
333pub struct UpdateHandleInput {
334 pub handle: String,
335}
336pub async fn update_handle(
337 State(state): State<AppState>,
338 headers: axum::http::HeaderMap,
339 Json(input): Json<UpdateHandleInput>,
340) -> Response {
341 let token = match crate::auth::extract_bearer_token_from_header(
342 headers.get("Authorization").and_then(|h| h.to_str().ok())
343 ) {
344 Some(t) => t,
345 None => return ApiError::AuthenticationRequired.into_response(),
346 };
347 let did = match crate::auth::validate_bearer_token(&state.db, &token).await {
348 Ok(user) => user.did,
349 Err(e) => return ApiError::from(e).into_response(),
350 };
351 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
352 .fetch_optional(&state.db)
353 .await
354 {
355 Ok(Some(id)) => id,
356 _ => return ApiError::InternalError.into_response(),
357 };
358 let new_handle = input.handle.trim();
359 if new_handle.is_empty() {
360 return ApiError::InvalidRequest("handle is required".into()).into_response();
361 }
362 if !new_handle
363 .chars()
364 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_')
365 {
366 return (
367 StatusCode::BAD_REQUEST,
368 Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"})),
369 )
370 .into_response();
371 }
372 let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE id = $1", user_id)
373 .fetch_optional(&state.db)
374 .await
375 .ok()
376 .flatten();
377 let existing = sqlx::query!("SELECT id FROM users WHERE handle = $1 AND id != $2", new_handle, user_id)
378 .fetch_optional(&state.db)
379 .await;
380 if let Ok(Some(_)) = existing {
381 return (
382 StatusCode::BAD_REQUEST,
383 Json(json!({"error": "HandleTaken", "message": "Handle is already in use"})),
384 )
385 .into_response();
386 }
387 let result = sqlx::query!("UPDATE users SET handle = $1 WHERE id = $2", new_handle, user_id)
388 .execute(&state.db)
389 .await;
390 match result {
391 Ok(_) => {
392 if let Some(old) = old_handle {
393 let _ = state.cache.delete(&format!("handle:{}", old)).await;
394 }
395 let _ = state.cache.delete(&format!("handle:{}", new_handle)).await;
396 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
397 let full_handle = format!("{}.{}", new_handle, hostname);
398 if let Err(e) = crate::api::repo::record::sequence_identity_event(&state, &did, Some(&full_handle)).await {
399 warn!("Failed to sequence identity event for handle update: {}", e);
400 }
401 (StatusCode::OK, Json(json!({}))).into_response()
402 }
403 Err(e) => {
404 error!("DB error updating handle: {:?}", e);
405 (
406 StatusCode::INTERNAL_SERVER_ERROR,
407 Json(json!({"error": "InternalError"})),
408 )
409 .into_response()
410 }
411 }
412}
413pub async fn well_known_atproto_did(
414 State(state): State<AppState>,
415 headers: HeaderMap,
416) -> Response {
417 let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
418 Some(h) => h,
419 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
420 };
421 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
422 let suffix = format!(".{}", hostname);
423 let handle = host.split(':').next().unwrap_or(host);
424 let short_handle = if handle.ends_with(&suffix) {
425 handle.strip_suffix(&suffix).unwrap_or(handle)
426 } else {
427 return (StatusCode::NOT_FOUND, "Handle not found").into_response();
428 };
429 let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", short_handle)
430 .fetch_optional(&state.db)
431 .await;
432 match user {
433 Ok(Some(row)) => row.did.into_response(),
434 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(),
435 Err(e) => {
436 error!("DB error in well-known atproto-did: {:?}", e);
437 (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response()
438 }
439 }
440}