this repo has no description
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::State,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use tracing::error;
11use uuid::Uuid;
12
13#[derive(Deserialize)]
14#[serde(rename_all = "camelCase")]
15pub struct CreateInviteCodeInput {
16 pub use_count: i32,
17 pub for_account: Option<String>,
18}
19
20#[derive(Serialize)]
21pub struct CreateInviteCodeOutput {
22 pub code: String,
23}
24
25pub async fn create_invite_code(
26 State(state): State<AppState>,
27 headers: axum::http::HeaderMap,
28 Json(input): Json<CreateInviteCodeInput>,
29) -> Response {
30 let auth_header = headers.get("Authorization");
31 if auth_header.is_none() {
32 return (
33 StatusCode::UNAUTHORIZED,
34 Json(json!({"error": "AuthenticationRequired"})),
35 )
36 .into_response();
37 }
38
39 if input.use_count < 1 {
40 return (
41 StatusCode::BAD_REQUEST,
42 Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
43 )
44 .into_response();
45 }
46
47 let token = auth_header
48 .unwrap()
49 .to_str()
50 .unwrap_or("")
51 .replace("Bearer ", "");
52
53 let session = sqlx::query!(
54 r#"
55 SELECT s.did, k.key_bytes, u.id as user_id
56 FROM sessions s
57 JOIN users u ON s.did = u.did
58 JOIN user_keys k ON u.id = k.user_id
59 WHERE s.access_jwt = $1
60 "#,
61 token
62 )
63 .fetch_optional(&state.db)
64 .await;
65
66 let (did, key_bytes, user_id) = match session {
67 Ok(Some(row)) => (row.did, row.key_bytes, row.user_id),
68 Ok(None) => {
69 return (
70 StatusCode::UNAUTHORIZED,
71 Json(json!({"error": "AuthenticationFailed"})),
72 )
73 .into_response();
74 }
75 Err(e) => {
76 error!("DB error in create_invite_code: {:?}", e);
77 return (
78 StatusCode::INTERNAL_SERVER_ERROR,
79 Json(json!({"error": "InternalError"})),
80 )
81 .into_response();
82 }
83 };
84
85 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
86 return (
87 StatusCode::UNAUTHORIZED,
88 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
89 )
90 .into_response();
91 }
92
93 let creator_user_id = if let Some(for_account) = &input.for_account {
94 let target = sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
95 .fetch_optional(&state.db)
96 .await;
97
98 match target {
99 Ok(Some(row)) => row.id,
100 Ok(None) => {
101 return (
102 StatusCode::NOT_FOUND,
103 Json(json!({"error": "AccountNotFound", "message": "Target account not found"})),
104 )
105 .into_response();
106 }
107 Err(e) => {
108 error!("DB error looking up target account: {:?}", e);
109 return (
110 StatusCode::INTERNAL_SERVER_ERROR,
111 Json(json!({"error": "InternalError"})),
112 )
113 .into_response();
114 }
115 }
116 } else {
117 user_id
118 };
119
120 let user_invites_disabled = sqlx::query_scalar!(
121 "SELECT invites_disabled FROM users WHERE did = $1",
122 did
123 )
124 .fetch_optional(&state.db)
125 .await
126 .ok()
127 .flatten()
128 .flatten()
129 .unwrap_or(false);
130
131 if user_invites_disabled {
132 return (
133 StatusCode::FORBIDDEN,
134 Json(json!({"error": "InvitesDisabled", "message": "Invites are disabled for this account"})),
135 )
136 .into_response();
137 }
138
139 let code = Uuid::new_v4().to_string();
140
141 let result = sqlx::query!(
142 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
143 code,
144 input.use_count,
145 creator_user_id
146 )
147 .execute(&state.db)
148 .await;
149
150 match result {
151 Ok(_) => (StatusCode::OK, Json(CreateInviteCodeOutput { code })).into_response(),
152 Err(e) => {
153 error!("DB error creating invite code: {:?}", e);
154 (
155 StatusCode::INTERNAL_SERVER_ERROR,
156 Json(json!({"error": "InternalError"})),
157 )
158 .into_response()
159 }
160 }
161}
162
163#[derive(Deserialize)]
164#[serde(rename_all = "camelCase")]
165pub struct CreateInviteCodesInput {
166 pub code_count: Option<i32>,
167 pub use_count: i32,
168 pub for_accounts: Option<Vec<String>>,
169}
170
171#[derive(Serialize)]
172pub struct CreateInviteCodesOutput {
173 pub codes: Vec<AccountCodes>,
174}
175
176#[derive(Serialize)]
177pub struct AccountCodes {
178 pub account: String,
179 pub codes: Vec<String>,
180}
181
182pub async fn create_invite_codes(
183 State(state): State<AppState>,
184 headers: axum::http::HeaderMap,
185 Json(input): Json<CreateInviteCodesInput>,
186) -> Response {
187 let auth_header = headers.get("Authorization");
188 if auth_header.is_none() {
189 return (
190 StatusCode::UNAUTHORIZED,
191 Json(json!({"error": "AuthenticationRequired"})),
192 )
193 .into_response();
194 }
195
196 if input.use_count < 1 {
197 return (
198 StatusCode::BAD_REQUEST,
199 Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
200 )
201 .into_response();
202 }
203
204 let token = auth_header
205 .unwrap()
206 .to_str()
207 .unwrap_or("")
208 .replace("Bearer ", "");
209
210 let session = sqlx::query!(
211 r#"
212 SELECT s.did, k.key_bytes, u.id as user_id
213 FROM sessions s
214 JOIN users u ON s.did = u.did
215 JOIN user_keys k ON u.id = k.user_id
216 WHERE s.access_jwt = $1
217 "#,
218 token
219 )
220 .fetch_optional(&state.db)
221 .await;
222
223 let (_did, key_bytes, user_id) = match session {
224 Ok(Some(row)) => (row.did, row.key_bytes, row.user_id),
225 Ok(None) => {
226 return (
227 StatusCode::UNAUTHORIZED,
228 Json(json!({"error": "AuthenticationFailed"})),
229 )
230 .into_response();
231 }
232 Err(e) => {
233 error!("DB error in create_invite_codes: {:?}", e);
234 return (
235 StatusCode::INTERNAL_SERVER_ERROR,
236 Json(json!({"error": "InternalError"})),
237 )
238 .into_response();
239 }
240 };
241
242 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
243 return (
244 StatusCode::UNAUTHORIZED,
245 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
246 )
247 .into_response();
248 }
249
250 let code_count = input.code_count.unwrap_or(1).max(1);
251 let for_accounts = input.for_accounts.unwrap_or_default();
252
253 let mut result_codes = Vec::new();
254
255 if for_accounts.is_empty() {
256 let mut codes = Vec::new();
257 for _ in 0..code_count {
258 let code = Uuid::new_v4().to_string();
259
260 let insert = sqlx::query!(
261 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
262 code,
263 input.use_count,
264 user_id
265 )
266 .execute(&state.db)
267 .await;
268
269 if let Err(e) = insert {
270 error!("DB error creating invite code: {:?}", e);
271 return (
272 StatusCode::INTERNAL_SERVER_ERROR,
273 Json(json!({"error": "InternalError"})),
274 )
275 .into_response();
276 }
277
278 codes.push(code);
279 }
280
281 result_codes.push(AccountCodes {
282 account: "admin".to_string(),
283 codes,
284 });
285 } else {
286 for account_did in for_accounts {
287 let target = sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
288 .fetch_optional(&state.db)
289 .await;
290
291 let target_user_id = match target {
292 Ok(Some(row)) => row.id,
293 Ok(None) => {
294 continue;
295 }
296 Err(e) => {
297 error!("DB error looking up target account: {:?}", e);
298 return (
299 StatusCode::INTERNAL_SERVER_ERROR,
300 Json(json!({"error": "InternalError"})),
301 )
302 .into_response();
303 }
304 };
305
306 let mut codes = Vec::new();
307 for _ in 0..code_count {
308 let code = Uuid::new_v4().to_string();
309
310 let insert = sqlx::query!(
311 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
312 code,
313 input.use_count,
314 target_user_id
315 )
316 .execute(&state.db)
317 .await;
318
319 if let Err(e) = insert {
320 error!("DB error creating invite code: {:?}", e);
321 return (
322 StatusCode::INTERNAL_SERVER_ERROR,
323 Json(json!({"error": "InternalError"})),
324 )
325 .into_response();
326 }
327
328 codes.push(code);
329 }
330
331 result_codes.push(AccountCodes {
332 account: account_did,
333 codes,
334 });
335 }
336 }
337
338 (StatusCode::OK, Json(CreateInviteCodesOutput { codes: result_codes })).into_response()
339}
340
341#[derive(Deserialize)]
342#[serde(rename_all = "camelCase")]
343pub struct GetAccountInviteCodesParams {
344 pub include_used: Option<bool>,
345 pub create_available: Option<bool>,
346}
347
348#[derive(Serialize)]
349#[serde(rename_all = "camelCase")]
350pub struct InviteCode {
351 pub code: String,
352 pub available: i32,
353 pub disabled: bool,
354 pub for_account: String,
355 pub created_by: String,
356 pub created_at: String,
357 pub uses: Vec<InviteCodeUse>,
358}
359
360#[derive(Serialize)]
361#[serde(rename_all = "camelCase")]
362pub struct InviteCodeUse {
363 pub used_by: String,
364 pub used_at: String,
365}
366
367#[derive(Serialize)]
368pub struct GetAccountInviteCodesOutput {
369 pub codes: Vec<InviteCode>,
370}
371
372pub async fn get_account_invite_codes(
373 State(state): State<AppState>,
374 headers: axum::http::HeaderMap,
375 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
376) -> Response {
377 let auth_header = headers.get("Authorization");
378 if auth_header.is_none() {
379 return (
380 StatusCode::UNAUTHORIZED,
381 Json(json!({"error": "AuthenticationRequired"})),
382 )
383 .into_response();
384 }
385
386 let token = auth_header
387 .unwrap()
388 .to_str()
389 .unwrap_or("")
390 .replace("Bearer ", "");
391
392 let session = sqlx::query!(
393 r#"
394 SELECT s.did, k.key_bytes, u.id as user_id
395 FROM sessions s
396 JOIN users u ON s.did = u.did
397 JOIN user_keys k ON u.id = k.user_id
398 WHERE s.access_jwt = $1
399 "#,
400 token
401 )
402 .fetch_optional(&state.db)
403 .await;
404
405 let (did, key_bytes, user_id) = match session {
406 Ok(Some(row)) => (row.did, row.key_bytes, row.user_id),
407 Ok(None) => {
408 return (
409 StatusCode::UNAUTHORIZED,
410 Json(json!({"error": "AuthenticationFailed"})),
411 )
412 .into_response();
413 }
414 Err(e) => {
415 error!("DB error in get_account_invite_codes: {:?}", e);
416 return (
417 StatusCode::INTERNAL_SERVER_ERROR,
418 Json(json!({"error": "InternalError"})),
419 )
420 .into_response();
421 }
422 };
423
424 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
425 return (
426 StatusCode::UNAUTHORIZED,
427 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
428 )
429 .into_response();
430 }
431
432 let include_used = params.include_used.unwrap_or(true);
433
434 let codes_result = sqlx::query!(
435 r#"
436 SELECT code, available_uses, created_at, disabled
437 FROM invite_codes
438 WHERE created_by_user = $1
439 ORDER BY created_at DESC
440 "#,
441 user_id
442 )
443 .fetch_all(&state.db)
444 .await;
445
446 let codes_rows = match codes_result {
447 Ok(rows) => {
448 if include_used {
449 rows
450 } else {
451 rows.into_iter().filter(|r| r.available_uses > 0).collect()
452 }
453 }
454 Err(e) => {
455 error!("DB error fetching invite codes: {:?}", e);
456 return (
457 StatusCode::INTERNAL_SERVER_ERROR,
458 Json(json!({"error": "InternalError"})),
459 )
460 .into_response();
461 }
462 };
463
464 let mut codes = Vec::new();
465 for row in codes_rows {
466 let uses_result = sqlx::query!(
467 r#"
468 SELECT u.did, icu.used_at
469 FROM invite_code_uses icu
470 JOIN users u ON icu.used_by_user = u.id
471 WHERE icu.code = $1
472 ORDER BY icu.used_at DESC
473 "#,
474 row.code
475 )
476 .fetch_all(&state.db)
477 .await;
478
479 let uses = match uses_result {
480 Ok(use_rows) => use_rows
481 .iter()
482 .map(|u| InviteCodeUse {
483 used_by: u.did.clone(),
484 used_at: u.used_at.to_rfc3339(),
485 })
486 .collect(),
487 Err(_) => Vec::new(),
488 };
489
490 codes.push(InviteCode {
491 code: row.code,
492 available: row.available_uses,
493 disabled: row.disabled.unwrap_or(false),
494 for_account: did.clone(),
495 created_by: did.clone(),
496 created_at: row.created_at.to_rfc3339(),
497 uses,
498 });
499 }
500
501 (StatusCode::OK, Json(GetAccountInviteCodesOutput { codes })).into_response()
502}