// This repo has no description.
use std::sync::Arc;

use axum::{
    extract::{Path, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use base64::Engine;
use bcrypt::{hash, DEFAULT_COST};
use jacquard::types::{did::Did, integer::LimitedU32, string::Tid};
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use k256::SecretKey;
use rand::rngs::OsRng;
use reqwest;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::Row;
use tracing::{error, info};

use crate::state::AppState;
20
21#[derive(Deserialize)]
22pub struct CreateAccountInput {
23 pub handle: String,
24 pub email: String,
25 pub password: String,
26 #[serde(rename = "inviteCode")]
27 pub invite_code: Option<String>,
28 pub did: Option<String>,
29}
30
31#[derive(Serialize)]
32#[serde(rename_all = "camelCase")]
33pub struct CreateAccountOutput {
34 pub access_jwt: String,
35 pub refresh_jwt: String,
36 pub handle: String,
37 pub did: String,
38}
39
40pub async fn create_account(
41 State(state): State<AppState>,
42 Json(input): Json<CreateAccountInput>,
43) -> Response {
44 info!("create_account hit: {}", input.handle);
45 if input.handle.contains('!') || input.handle.contains('@') {
46 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}))).into_response();
47 }
48
49 let did = if let Some(d) = &input.did {
50 if d.trim().is_empty() {
51 format!("did:plc:{}", uuid::Uuid::new_v4())
52 } else {
53 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
54 if let Err(e) = verify_did_web(d, &hostname, &input.handle).await {
55 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidDid", "message": e}))).into_response();
56 }
57 d.clone()
58 }
59 } else {
60 format!("did:plc:{}", uuid::Uuid::new_v4())
61 };
62
63 let mut tx = match state.db.begin().await {
64 Ok(tx) => tx,
65 Err(e) => {
66 error!("Error starting transaction: {:?}", e);
67 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
68 }
69 };
70
71 let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1")
72 .bind(&input.handle)
73 .fetch_optional(&mut *tx)
74 .await;
75
76 match exists_query {
77 Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "HandleTaken", "message": "Handle already taken"}))).into_response(),
78 Err(e) => {
79 error!("Error checking handle: {:?}", e);
80 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
81 }
82 Ok(None) => {}
83 }
84
85 if let Some(code) = &input.invite_code {
86 let invite_query = sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE")
87 .bind(code)
88 .fetch_optional(&mut *tx)
89 .await;
90
91 match invite_query {
92 Ok(Some(row)) => {
93 let uses: i32 = row.get("available_uses");
94 if uses <= 0 {
95 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response();
96 }
97
98 let update_invite = sqlx::query("UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1")
99 .bind(code)
100 .execute(&mut *tx)
101 .await;
102
103 if let Err(e) = update_invite {
104 error!("Error updating invite code: {:?}", e);
105 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
106 }
107 },
108 Ok(None) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"}))).into_response(),
109 Err(e) => {
110 error!("Error checking invite code: {:?}", e);
111 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
112 }
113 }
114 }
115
116 let password_hash = match hash(&input.password, DEFAULT_COST) {
117 Ok(h) => h,
118 Err(e) => {
119 error!("Error hashing password: {:?}", e);
120 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
121 }
122 };
123
124 let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id")
125 .bind(&input.handle)
126 .bind(&input.email)
127 .bind(&did)
128 .bind(&password_hash)
129 .fetch_one(&mut *tx)
130 .await;
131
132 let user_id: uuid::Uuid = match user_insert {
133 Ok(row) => row.get("id"),
134 Err(e) => {
135 error!("Error inserting user: {:?}", e);
136 // TODO: Check for unique constraint violation on email/did specifically
137 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
138 }
139 };
140
141 let secret_key = SecretKey::random(&mut OsRng);
142 let secret_key_bytes = secret_key.to_bytes();
143
144 let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)")
145 .bind(user_id)
146 .bind(&secret_key_bytes[..])
147 .execute(&mut *tx)
148 .await;
149
150 if let Err(e) = key_insert {
151 error!("Error inserting user key: {:?}", e);
152 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
153 }
154
155 let mst = Mst::new(Arc::new(state.block_store.clone()));
156 let mst_root = match mst.root().await {
157 Ok(c) => c,
158 Err(e) => {
159 error!("Error creating MST root: {:?}", e);
160 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
161 }
162 };
163
164 let did_obj = match Did::new(&did) {
165 Ok(d) => d,
166 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
167 };
168
169 let rev = Tid::now(LimitedU32::MIN);
170
171 let commit = Commit::new_unsigned(
172 did_obj,
173 mst_root,
174 rev,
175 None
176 );
177
178 let commit_bytes = match commit.to_cbor() {
179 Ok(b) => b,
180 Err(e) => {
181 error!("Error serializing genesis commit: {:?}", e);
182 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
183 }
184 };
185
186 let commit_cid = match state.block_store.put(&commit_bytes).await {
187 Ok(c) => c,
188 Err(e) => {
189 error!("Error saving genesis commit: {:?}", e);
190 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
191 }
192 };
193
194 let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)")
195 .bind(user_id)
196 .bind(commit_cid.to_string())
197 .execute(&mut *tx)
198 .await;
199
200 if let Err(e) = repo_insert {
201 error!("Error initializing repo: {:?}", e);
202 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
203 }
204
205 if let Some(code) = &input.invite_code {
206 let use_insert = sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)")
207 .bind(code)
208 .bind(user_id)
209 .execute(&mut *tx)
210 .await;
211
212 if let Err(e) = use_insert {
213 error!("Error recording invite usage: {:?}", e);
214 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
215 }
216 }
217
218 let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| {
219 error!("Error creating access token: {:?}", e);
220 (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
221 });
222 let access_jwt = match access_jwt {
223 Ok(t) => t,
224 Err(r) => return r,
225 };
226
227 let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| {
228 error!("Error creating refresh token: {:?}", e);
229 (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
230 });
231 let refresh_jwt = match refresh_jwt {
232 Ok(t) => t,
233 Err(r) => return r,
234 };
235
236 let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)")
237 .bind(&access_jwt)
238 .bind(&refresh_jwt)
239 .bind(&did)
240 .execute(&mut *tx)
241 .await;
242
243 if let Err(e) = session_insert {
244 error!("Error inserting session: {:?}", e);
245 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
246 }
247
248 if let Err(e) = tx.commit().await {
249 error!("Error committing transaction: {:?}", e);
250 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
251 }
252
253 (StatusCode::OK, Json(CreateAccountOutput {
254 access_jwt,
255 refresh_jwt,
256 handle: input.handle,
257 did,
258 })).into_response()
259}
260
261fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
262 use k256::elliptic_curve::sec1::ToEncodedPoint;
263
264 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
265 let public_key = secret_key.public_key();
266 let encoded = public_key.to_encoded_point(false);
267 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
268 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
269
270 json!({
271 "kty": "EC",
272 "crv": "secp256k1",
273 "x": x,
274 "y": y
275 })
276}
277
278pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
279 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
280 // Kinda for local dev, encode hostname if it contains port
281 let did = if hostname.contains(':') {
282 format!("did:web:{}", hostname.replace(':', "%3A"))
283 } else {
284 format!("did:web:{}", hostname)
285 };
286
287 Json(json!({
288 "@context": ["https://www.w3.org/ns/did/v1"],
289 "id": did,
290 "service": [{
291 "id": "#atproto_pds",
292 "type": "AtprotoPersonalDataServer",
293 "serviceEndpoint": format!("https://{}", hostname)
294 }]
295 }))
296}
297
298pub async fn user_did_doc(
299 State(state): State<AppState>,
300 Path(handle): Path<String>,
301) -> Response {
302 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
303
304 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
305 .bind(&handle)
306 .fetch_optional(&state.db)
307 .await;
308
309 let (user_id, did) = match user {
310 Ok(Some(row)) => {
311 let id: uuid::Uuid = row.get("id");
312 let d: String = row.get("did");
313 (id, d)
314 },
315 Ok(None) => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(),
316 Err(e) => {
317 error!("DB Error: {:?}", e);
318 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
319 },
320 };
321
322 if !did.starts_with("did:web:") {
323 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "User is not did:web"}))).into_response();
324 }
325
326 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
327 .bind(user_id)
328 .fetch_optional(&state.db)
329 .await;
330
331 let key_bytes: Vec<u8> = match key_row {
332 Ok(Some(row)) => row.get("key_bytes"),
333 _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(),
334 };
335
336 let jwk = get_jwk(&key_bytes);
337
338 Json(json!({
339 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
340 "id": did,
341 "alsoKnownAs": [format!("at://{}", handle)],
342 "verificationMethod": [{
343 "id": format!("{}#atproto", did),
344 "type": "JsonWebKey2020",
345 "controller": did,
346 "publicKeyJwk": jwk
347 }],
348 "service": [{
349 "id": "#atproto_pds",
350 "type": "AtprotoPersonalDataServer",
351 "serviceEndpoint": format!("https://{}", hostname)
352 }]
353 })).into_response()
354}
355
356async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
357 let expected_prefix = if hostname.contains(':') {
358 format!("did:web:{}", hostname.replace(':', "%3A"))
359 } else {
360 format!("did:web:{}", hostname)
361 };
362
363 if did.starts_with(&expected_prefix) {
364 let suffix = &did[expected_prefix.len()..];
365 let expected_suffix = format!(":u:{}", handle);
366 if suffix == expected_suffix {
367 Ok(())
368 } else {
369 Err(format!("Invalid DID path for this PDS. Expected {}", expected_suffix))
370 }
371 } else {
372 let parts: Vec<&str> = did.split(':').collect();
373 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
374 return Err("Invalid did:web format".into());
375 }
376
377 let domain_segment = parts[2];
378 let domain = domain_segment.replace("%3A", ":");
379
380 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
381 "http"
382 } else {
383 "https"
384 };
385
386 let url = if parts.len() == 3 {
387 format!("{}://{}/.well-known/did.json", scheme, domain)
388 } else {
389 let path = parts[3..].join("/");
390 format!("{}://{}/{}/did.json", scheme, domain, path)
391 };
392
393 let client = reqwest::Client::builder()
394 .timeout(std::time::Duration::from_secs(5))
395 .build()
396 .map_err(|e| format!("Failed to create client: {}", e))?;
397
398 let resp = client.get(&url).send().await
399 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
400
401 if !resp.status().is_success() {
402 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
403 }
404
405 let doc: serde_json::Value = resp.json().await
406 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
407
408 let services = doc["service"].as_array()
409 .ok_or("No services found in DID doc")?;
410
411 let pds_endpoint = format!("https://{}", hostname);
412
413 let has_valid_service = services.iter().any(|s| {
414 s["type"] == "AtprotoPersonalDataServer" &&
415 s["serviceEndpoint"] == pds_endpoint
416 });
417
418 if has_valid_service {
419 Ok(())
420 } else {
421 Err(format!("DID document does not list this PDS ({}) as AtprotoPersonalDataServer", pds_endpoint))
422 }
423 }
424}