this repo has no description
1use axum::{
2 extract::{State, Path},
3 Json,
4 response::{IntoResponse, Response},
5 http::StatusCode,
6};
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use crate::state::AppState;
10use sqlx::Row;
11use bcrypt::{hash, DEFAULT_COST};
12use tracing::{info, error};
13use jacquard_repo::{mst::Mst, commit::Commit, storage::BlockStore};
14use jacquard::types::{string::Tid, did::Did, integer::LimitedU32};
15use std::sync::Arc;
16use k256::SecretKey;
17use rand::rngs::OsRng;
18use base64::Engine;
19
20#[derive(Deserialize)]
21pub struct CreateAccountInput {
22 pub handle: String,
23 pub email: String,
24 pub password: String,
25 #[serde(rename = "inviteCode")]
26 pub invite_code: Option<String>,
27 pub did: Option<String>,
28}
29
30#[derive(Serialize)]
31#[serde(rename_all = "camelCase")]
32pub struct CreateAccountOutput {
33 pub access_jwt: String,
34 pub refresh_jwt: String,
35 pub handle: String,
36 pub did: String,
37}
38
39pub async fn create_account(
40 State(state): State<AppState>,
41 Json(input): Json<CreateAccountInput>,
42) -> Response {
43 info!("create_account hit: {}", input.handle);
44 if input.handle.contains('!') || input.handle.contains('@') {
45 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidHandle", "message": "Handle contains invalid characters"}))).into_response();
46 }
47
48 let did = if let Some(d) = &input.did {
49 if d.trim().is_empty() {
50 format!("did:plc:{}", uuid::Uuid::new_v4())
51 } else {
52 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
53 let _expected_prefix = format!("did:web:{}", hostname);
54
55 // TODO: should verify we are the authority for it if it matches our hostname.
56 // TODO: if it's an external did:web, we should technically verify ownership via ServiceAuth, but skipping for now.
57 d.clone()
58 }
59 } else {
60 format!("did:plc:{}", uuid::Uuid::new_v4())
61 };
62
63 let mut tx = match state.db.begin().await {
64 Ok(tx) => tx,
65 Err(e) => {
66 error!("Error starting transaction: {:?}", e);
67 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
68 }
69 };
70
71 let exists_query = sqlx::query("SELECT 1 FROM users WHERE handle = $1")
72 .bind(&input.handle)
73 .fetch_optional(&mut *tx)
74 .await;
75
76 match exists_query {
77 Ok(Some(_)) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "HandleTaken", "message": "Handle already taken"}))).into_response(),
78 Err(e) => {
79 error!("Error checking handle: {:?}", e);
80 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
81 }
82 Ok(None) => {}
83 }
84
85 if let Some(code) = &input.invite_code {
86 let invite_query = sqlx::query("SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE")
87 .bind(code)
88 .fetch_optional(&mut *tx)
89 .await;
90
91 match invite_query {
92 Ok(Some(row)) => {
93 let uses: i32 = row.get("available_uses");
94 if uses <= 0 {
95 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code exhausted"}))).into_response();
96 }
97
98 let update_invite = sqlx::query("UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1")
99 .bind(code)
100 .execute(&mut *tx)
101 .await;
102
103 if let Err(e) = update_invite {
104 error!("Error updating invite code: {:?}", e);
105 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
106 }
107 },
108 Ok(None) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidInviteCode", "message": "Invite code not found"}))).into_response(),
109 Err(e) => {
110 error!("Error checking invite code: {:?}", e);
111 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
112 }
113 }
114 }
115
116 let password_hash = match hash(&input.password, DEFAULT_COST) {
117 Ok(h) => h,
118 Err(e) => {
119 error!("Error hashing password: {:?}", e);
120 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
121 }
122 };
123
124 let user_insert = sqlx::query("INSERT INTO users (handle, email, did, password_hash) VALUES ($1, $2, $3, $4) RETURNING id")
125 .bind(&input.handle)
126 .bind(&input.email)
127 .bind(&did)
128 .bind(&password_hash)
129 .fetch_one(&mut *tx)
130 .await;
131
132 let user_id: uuid::Uuid = match user_insert {
133 Ok(row) => row.get("id"),
134 Err(e) => {
135 error!("Error inserting user: {:?}", e);
136 // TODO: Check for unique constraint violation on email/did specifically
137 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
138 }
139 };
140
141 let secret_key = SecretKey::random(&mut OsRng);
142 let secret_key_bytes = secret_key.to_bytes();
143
144 let key_insert = sqlx::query("INSERT INTO user_keys (user_id, key_bytes) VALUES ($1, $2)")
145 .bind(user_id)
146 .bind(&secret_key_bytes[..])
147 .execute(&mut *tx)
148 .await;
149
150 if let Err(e) = key_insert {
151 error!("Error inserting user key: {:?}", e);
152 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
153 }
154
155 let mst = Mst::new(Arc::new(state.block_store.clone()));
156 let mst_root = match mst.root().await {
157 Ok(c) => c,
158 Err(e) => {
159 error!("Error creating MST root: {:?}", e);
160 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
161 }
162 };
163
164 let did_obj = match Did::new(&did) {
165 Ok(d) => d,
166 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Invalid DID"}))).into_response(),
167 };
168
169 let rev = Tid::now(LimitedU32::MIN);
170
171 let commit = Commit::new_unsigned(
172 did_obj,
173 mst_root,
174 rev,
175 None
176 );
177
178 let commit_bytes = match commit.to_cbor() {
179 Ok(b) => b,
180 Err(e) => {
181 error!("Error serializing genesis commit: {:?}", e);
182 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
183 }
184 };
185
186 let commit_cid = match state.block_store.put(&commit_bytes).await {
187 Ok(c) => c,
188 Err(e) => {
189 error!("Error saving genesis commit: {:?}", e);
190 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
191 }
192 };
193
194 let repo_insert = sqlx::query("INSERT INTO repos (user_id, repo_root_cid) VALUES ($1, $2)")
195 .bind(user_id)
196 .bind(commit_cid.to_string())
197 .execute(&mut *tx)
198 .await;
199
200 if let Err(e) = repo_insert {
201 error!("Error initializing repo: {:?}", e);
202 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
203 }
204
205 if let Some(code) = &input.invite_code {
206 let use_insert = sqlx::query("INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)")
207 .bind(code)
208 .bind(user_id)
209 .execute(&mut *tx)
210 .await;
211
212 if let Err(e) = use_insert {
213 error!("Error recording invite usage: {:?}", e);
214 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
215 }
216 }
217
218 let access_jwt = crate::auth::create_access_token(&did, &secret_key_bytes[..]).map_err(|e| {
219 error!("Error creating access token: {:?}", e);
220 (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
221 });
222 let access_jwt = match access_jwt {
223 Ok(t) => t,
224 Err(r) => return r,
225 };
226
227 let refresh_jwt = crate::auth::create_refresh_token(&did, &secret_key_bytes[..]).map_err(|e| {
228 error!("Error creating refresh token: {:?}", e);
229 (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
230 });
231 let refresh_jwt = match refresh_jwt {
232 Ok(t) => t,
233 Err(r) => return r,
234 };
235
236 let session_insert = sqlx::query("INSERT INTO sessions (access_jwt, refresh_jwt, did) VALUES ($1, $2, $3)")
237 .bind(&access_jwt)
238 .bind(&refresh_jwt)
239 .bind(&did)
240 .execute(&mut *tx)
241 .await;
242
243 if let Err(e) = session_insert {
244 error!("Error inserting session: {:?}", e);
245 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
246 }
247
248 if let Err(e) = tx.commit().await {
249 error!("Error committing transaction: {:?}", e);
250 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response();
251 }
252
253 (StatusCode::OK, Json(CreateAccountOutput {
254 access_jwt,
255 refresh_jwt,
256 handle: input.handle,
257 did,
258 })).into_response()
259}
260
261fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
262 use k256::elliptic_curve::sec1::ToEncodedPoint;
263
264 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
265 let public_key = secret_key.public_key();
266 let encoded = public_key.to_encoded_point(false);
267 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
268 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
269
270 json!({
271 "kty": "EC",
272 "crv": "secp256k1",
273 "x": x,
274 "y": y
275 })
276}
277
278pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
279 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
280 // Kinda for local dev, encode hostname if it contains port
281 let did = if hostname.contains(':') {
282 format!("did:web:{}", hostname.replace(':', "%3A"))
283 } else {
284 format!("did:web:{}", hostname)
285 };
286
287 Json(json!({
288 "@context": ["https://www.w3.org/ns/did/v1"],
289 "id": did,
290 "service": [{
291 "id": "#atproto_pds",
292 "type": "AtprotoPersonalDataServer",
293 "serviceEndpoint": format!("https://{}", hostname)
294 }]
295 }))
296}
297
298pub async fn user_did_doc(
299 State(state): State<AppState>,
300 Path(handle): Path<String>,
301) -> Response {
302 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
303
304 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
305 .bind(&handle)
306 .fetch_optional(&state.db)
307 .await;
308
309 let (user_id, did) = match user {
310 Ok(Some(row)) => {
311 let id: uuid::Uuid = row.get("id");
312 let d: String = row.get("did");
313 (id, d)
314 },
315 Ok(None) => return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(),
316 Err(e) => {
317 error!("DB Error: {:?}", e);
318 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response()
319 },
320 };
321
322 if !did.starts_with("did:web:") {
323 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound", "message": "User is not did:web"}))).into_response();
324 }
325
326 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
327 .bind(user_id)
328 .fetch_optional(&state.db)
329 .await;
330
331 let key_bytes: Vec<u8> = match key_row {
332 Ok(Some(row)) => row.get("key_bytes"),
333 _ => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError"}))).into_response(),
334 };
335
336 let jwk = get_jwk(&key_bytes);
337
338 Json(json!({
339 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
340 "id": did,
341 "alsoKnownAs": [format!("at://{}", handle)],
342 "verificationMethod": [{
343 "id": format!("{}#atproto", did),
344 "type": "JsonWebKey2020",
345 "controller": did,
346 "publicKeyJwk": jwk
347 }],
348 "service": [{
349 "id": "#atproto_pds",
350 "type": "AtprotoPersonalDataServer",
351 "serviceEndpoint": format!("https://{}", hostname)
352 }]
353 })).into_response()
354}