this repo has no description
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::{Path, Query, State},
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use base64::Engine;
9use k256::SecretKey;
10use k256::elliptic_curve::sec1::ToEncodedPoint;
11use reqwest;
12use serde::Deserialize;
13use serde_json::json;
14use sqlx::Row;
15use tracing::error;
16
17#[derive(Deserialize)]
18pub struct ResolveHandleParams {
19 pub handle: String,
20}
21
22pub async fn resolve_handle(
23 State(state): State<AppState>,
24 Query(params): Query<ResolveHandleParams>,
25) -> Response {
26 let handle = params.handle.trim();
27
28 if handle.is_empty() {
29 return (
30 StatusCode::BAD_REQUEST,
31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})),
32 )
33 .into_response();
34 }
35
36 let user = sqlx::query("SELECT did FROM users WHERE handle = $1")
37 .bind(handle)
38 .fetch_optional(&state.db)
39 .await;
40
41 match user {
42 Ok(Some(row)) => {
43 let did: String = row.get("did");
44 (StatusCode::OK, Json(json!({ "did": did }))).into_response()
45 }
46 Ok(None) => (
47 StatusCode::NOT_FOUND,
48 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})),
49 )
50 .into_response(),
51 Err(e) => {
52 error!("DB error resolving handle: {:?}", e);
53 (
54 StatusCode::INTERNAL_SERVER_ERROR,
55 Json(json!({"error": "InternalError"})),
56 )
57 .into_response()
58 }
59 }
60}
61
62pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
63 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
64 let public_key = secret_key.public_key();
65 let encoded = public_key.to_encoded_point(false);
66 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
67 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
68
69 json!({
70 "kty": "EC",
71 "crv": "secp256k1",
72 "x": x,
73 "y": y
74 })
75}
76
77pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
78 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
79 // Kinda for local dev, encode hostname if it contains port
80 let did = if hostname.contains(':') {
81 format!("did:web:{}", hostname.replace(':', "%3A"))
82 } else {
83 format!("did:web:{}", hostname)
84 };
85
86 Json(json!({
87 "@context": ["https://www.w3.org/ns/did/v1"],
88 "id": did,
89 "service": [{
90 "id": "#atproto_pds",
91 "type": "AtprotoPersonalDataServer",
92 "serviceEndpoint": format!("https://{}", hostname)
93 }]
94 }))
95}
96
97pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
98 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
99
100 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
101 .bind(&handle)
102 .fetch_optional(&state.db)
103 .await;
104
105 let (user_id, did) = match user {
106 Ok(Some(row)) => {
107 let id: uuid::Uuid = row.get("id");
108 let d: String = row.get("did");
109 (id, d)
110 }
111 Ok(None) => {
112 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
113 }
114 Err(e) => {
115 error!("DB Error: {:?}", e);
116 return (
117 StatusCode::INTERNAL_SERVER_ERROR,
118 Json(json!({"error": "InternalError"})),
119 )
120 .into_response();
121 }
122 };
123
124 if !did.starts_with("did:web:") {
125 return (
126 StatusCode::NOT_FOUND,
127 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
128 )
129 .into_response();
130 }
131
132 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
133 .bind(user_id)
134 .fetch_optional(&state.db)
135 .await;
136
137 let key_bytes: Vec<u8> = match key_row {
138 Ok(Some(row)) => row.get("key_bytes"),
139 _ => {
140 return (
141 StatusCode::INTERNAL_SERVER_ERROR,
142 Json(json!({"error": "InternalError"})),
143 )
144 .into_response();
145 }
146 };
147
148 let jwk = get_jwk(&key_bytes);
149
150 Json(json!({
151 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
152 "id": did,
153 "alsoKnownAs": [format!("at://{}", handle)],
154 "verificationMethod": [{
155 "id": format!("{}#atproto", did),
156 "type": "JsonWebKey2020",
157 "controller": did,
158 "publicKeyJwk": jwk
159 }],
160 "service": [{
161 "id": "#atproto_pds",
162 "type": "AtprotoPersonalDataServer",
163 "serviceEndpoint": format!("https://{}", hostname)
164 }]
165 })).into_response()
166}
167
168pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
169 let expected_prefix = if hostname.contains(':') {
170 format!("did:web:{}", hostname.replace(':', "%3A"))
171 } else {
172 format!("did:web:{}", hostname)
173 };
174
175 if did.starts_with(&expected_prefix) {
176 let suffix = &did[expected_prefix.len()..];
177 let expected_suffix = format!(":u:{}", handle);
178 if suffix == expected_suffix {
179 Ok(())
180 } else {
181 Err(format!(
182 "Invalid DID path for this PDS. Expected {}",
183 expected_suffix
184 ))
185 }
186 } else {
187 let parts: Vec<&str> = did.split(':').collect();
188 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
189 return Err("Invalid did:web format".into());
190 }
191
192 let domain_segment = parts[2];
193 let domain = domain_segment.replace("%3A", ":");
194
195 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
196 "http"
197 } else {
198 "https"
199 };
200
201 let url = if parts.len() == 3 {
202 format!("{}://{}/.well-known/did.json", scheme, domain)
203 } else {
204 let path = parts[3..].join("/");
205 format!("{}://{}/{}/did.json", scheme, domain, path)
206 };
207
208 let client = reqwest::Client::builder()
209 .timeout(std::time::Duration::from_secs(5))
210 .build()
211 .map_err(|e| format!("Failed to create client: {}", e))?;
212
213 let resp = client
214 .get(&url)
215 .send()
216 .await
217 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
218
219 if !resp.status().is_success() {
220 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
221 }
222
223 let doc: serde_json::Value = resp
224 .json()
225 .await
226 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
227
228 let services = doc["service"]
229 .as_array()
230 .ok_or("No services found in DID doc")?;
231
232 let pds_endpoint = format!("https://{}", hostname);
233
234 let has_valid_service = services.iter().any(|s| {
235 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
236 });
237
238 if has_valid_service {
239 Ok(())
240 } else {
241 Err(format!(
242 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
243 pds_endpoint
244 ))
245 }
246 }
247}