this repo has no description
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::State,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use tracing::error;
11use uuid::Uuid;
12
13#[derive(Deserialize)]
14#[serde(rename_all = "camelCase")]
15pub struct CreateInviteCodeInput {
16 pub use_count: i32,
17 pub for_account: Option<String>,
18}
19
20#[derive(Serialize)]
21pub struct CreateInviteCodeOutput {
22 pub code: String,
23}
24
25pub async fn create_invite_code(
26 State(state): State<AppState>,
27 headers: axum::http::HeaderMap,
28 Json(input): Json<CreateInviteCodeInput>,
29) -> Response {
30 let token = match crate::auth::extract_bearer_token_from_header(
31 headers.get("Authorization").and_then(|h| h.to_str().ok())
32 ) {
33 Some(t) => t,
34 None => {
35 return (
36 StatusCode::UNAUTHORIZED,
37 Json(json!({"error": "AuthenticationRequired"})),
38 )
39 .into_response();
40 }
41 };
42
43 if input.use_count < 1 {
44 return (
45 StatusCode::BAD_REQUEST,
46 Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
47 )
48 .into_response();
49 }
50
51 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
52 let did = match auth_result {
53 Ok(user) => user.did,
54 Err(e) => {
55 return (
56 StatusCode::UNAUTHORIZED,
57 Json(json!({"error": e})),
58 )
59 .into_response();
60 }
61 };
62
63 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
64 .fetch_optional(&state.db)
65 .await
66 {
67 Ok(Some(id)) => id,
68 _ => {
69 return (
70 StatusCode::INTERNAL_SERVER_ERROR,
71 Json(json!({"error": "InternalError"})),
72 )
73 .into_response();
74 }
75 };
76
77 let creator_user_id = if let Some(for_account) = &input.for_account {
78 let target = sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
79 .fetch_optional(&state.db)
80 .await;
81
82 match target {
83 Ok(Some(row)) => row.id,
84 Ok(None) => {
85 return (
86 StatusCode::NOT_FOUND,
87 Json(json!({"error": "AccountNotFound", "message": "Target account not found"})),
88 )
89 .into_response();
90 }
91 Err(e) => {
92 error!("DB error looking up target account: {:?}", e);
93 return (
94 StatusCode::INTERNAL_SERVER_ERROR,
95 Json(json!({"error": "InternalError"})),
96 )
97 .into_response();
98 }
99 }
100 } else {
101 user_id
102 };
103
104 let user_invites_disabled = sqlx::query_scalar!(
105 "SELECT invites_disabled FROM users WHERE did = $1",
106 did
107 )
108 .fetch_optional(&state.db)
109 .await
110 .ok()
111 .flatten()
112 .flatten()
113 .unwrap_or(false);
114
115 if user_invites_disabled {
116 return (
117 StatusCode::FORBIDDEN,
118 Json(json!({"error": "InvitesDisabled", "message": "Invites are disabled for this account"})),
119 )
120 .into_response();
121 }
122
123 let code = Uuid::new_v4().to_string();
124
125 let result = sqlx::query!(
126 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
127 code,
128 input.use_count,
129 creator_user_id
130 )
131 .execute(&state.db)
132 .await;
133
134 match result {
135 Ok(_) => (StatusCode::OK, Json(CreateInviteCodeOutput { code })).into_response(),
136 Err(e) => {
137 error!("DB error creating invite code: {:?}", e);
138 (
139 StatusCode::INTERNAL_SERVER_ERROR,
140 Json(json!({"error": "InternalError"})),
141 )
142 .into_response()
143 }
144 }
145}
146
147#[derive(Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct CreateInviteCodesInput {
150 pub code_count: Option<i32>,
151 pub use_count: i32,
152 pub for_accounts: Option<Vec<String>>,
153}
154
155#[derive(Serialize)]
156pub struct CreateInviteCodesOutput {
157 pub codes: Vec<AccountCodes>,
158}
159
160#[derive(Serialize)]
161pub struct AccountCodes {
162 pub account: String,
163 pub codes: Vec<String>,
164}
165
166pub async fn create_invite_codes(
167 State(state): State<AppState>,
168 headers: axum::http::HeaderMap,
169 Json(input): Json<CreateInviteCodesInput>,
170) -> Response {
171 let token = match crate::auth::extract_bearer_token_from_header(
172 headers.get("Authorization").and_then(|h| h.to_str().ok())
173 ) {
174 Some(t) => t,
175 None => {
176 return (
177 StatusCode::UNAUTHORIZED,
178 Json(json!({"error": "AuthenticationRequired"})),
179 )
180 .into_response();
181 }
182 };
183
184 if input.use_count < 1 {
185 return (
186 StatusCode::BAD_REQUEST,
187 Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})),
188 )
189 .into_response();
190 }
191
192 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
193 let did = match auth_result {
194 Ok(user) => user.did,
195 Err(e) => {
196 return (
197 StatusCode::UNAUTHORIZED,
198 Json(json!({"error": e})),
199 )
200 .into_response();
201 }
202 };
203
204 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
205 .fetch_optional(&state.db)
206 .await
207 {
208 Ok(Some(id)) => id,
209 _ => {
210 return (
211 StatusCode::INTERNAL_SERVER_ERROR,
212 Json(json!({"error": "InternalError"})),
213 )
214 .into_response();
215 }
216 };
217
218 let code_count = input.code_count.unwrap_or(1).max(1);
219 let for_accounts = input.for_accounts.unwrap_or_default();
220
221 let mut result_codes = Vec::new();
222
223 if for_accounts.is_empty() {
224 let mut codes = Vec::new();
225 for _ in 0..code_count {
226 let code = Uuid::new_v4().to_string();
227
228 let insert = sqlx::query!(
229 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
230 code,
231 input.use_count,
232 user_id
233 )
234 .execute(&state.db)
235 .await;
236
237 if let Err(e) = insert {
238 error!("DB error creating invite code: {:?}", e);
239 return (
240 StatusCode::INTERNAL_SERVER_ERROR,
241 Json(json!({"error": "InternalError"})),
242 )
243 .into_response();
244 }
245
246 codes.push(code);
247 }
248
249 result_codes.push(AccountCodes {
250 account: "admin".to_string(),
251 codes,
252 });
253 } else {
254 for account_did in for_accounts {
255 let target = sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
256 .fetch_optional(&state.db)
257 .await;
258
259 let target_user_id = match target {
260 Ok(Some(row)) => row.id,
261 Ok(None) => {
262 continue;
263 }
264 Err(e) => {
265 error!("DB error looking up target account: {:?}", e);
266 return (
267 StatusCode::INTERNAL_SERVER_ERROR,
268 Json(json!({"error": "InternalError"})),
269 )
270 .into_response();
271 }
272 };
273
274 let mut codes = Vec::new();
275 for _ in 0..code_count {
276 let code = Uuid::new_v4().to_string();
277
278 let insert = sqlx::query!(
279 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
280 code,
281 input.use_count,
282 target_user_id
283 )
284 .execute(&state.db)
285 .await;
286
287 if let Err(e) = insert {
288 error!("DB error creating invite code: {:?}", e);
289 return (
290 StatusCode::INTERNAL_SERVER_ERROR,
291 Json(json!({"error": "InternalError"})),
292 )
293 .into_response();
294 }
295
296 codes.push(code);
297 }
298
299 result_codes.push(AccountCodes {
300 account: account_did,
301 codes,
302 });
303 }
304 }
305
306 (StatusCode::OK, Json(CreateInviteCodesOutput { codes: result_codes })).into_response()
307}
308
309#[derive(Deserialize)]
310#[serde(rename_all = "camelCase")]
311pub struct GetAccountInviteCodesParams {
312 pub include_used: Option<bool>,
313 pub create_available: Option<bool>,
314}
315
316#[derive(Serialize)]
317#[serde(rename_all = "camelCase")]
318pub struct InviteCode {
319 pub code: String,
320 pub available: i32,
321 pub disabled: bool,
322 pub for_account: String,
323 pub created_by: String,
324 pub created_at: String,
325 pub uses: Vec<InviteCodeUse>,
326}
327
328#[derive(Serialize)]
329#[serde(rename_all = "camelCase")]
330pub struct InviteCodeUse {
331 pub used_by: String,
332 pub used_at: String,
333}
334
335#[derive(Serialize)]
336pub struct GetAccountInviteCodesOutput {
337 pub codes: Vec<InviteCode>,
338}
339
340pub async fn get_account_invite_codes(
341 State(state): State<AppState>,
342 headers: axum::http::HeaderMap,
343 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
344) -> Response {
345 let token = match crate::auth::extract_bearer_token_from_header(
346 headers.get("Authorization").and_then(|h| h.to_str().ok())
347 ) {
348 Some(t) => t,
349 None => {
350 return (
351 StatusCode::UNAUTHORIZED,
352 Json(json!({"error": "AuthenticationRequired"})),
353 )
354 .into_response();
355 }
356 };
357
358 let auth_result = crate::auth::validate_bearer_token(&state.db, &token).await;
359 let did = match auth_result {
360 Ok(user) => user.did,
361 Err(e) => {
362 return (
363 StatusCode::UNAUTHORIZED,
364 Json(json!({"error": e})),
365 )
366 .into_response();
367 }
368 };
369
370 let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
371 .fetch_optional(&state.db)
372 .await
373 {
374 Ok(Some(id)) => id,
375 _ => {
376 return (
377 StatusCode::INTERNAL_SERVER_ERROR,
378 Json(json!({"error": "InternalError"})),
379 )
380 .into_response();
381 }
382 };
383
384 let include_used = params.include_used.unwrap_or(true);
385
386 let codes_result = sqlx::query!(
387 r#"
388 SELECT code, available_uses, created_at, disabled
389 FROM invite_codes
390 WHERE created_by_user = $1
391 ORDER BY created_at DESC
392 "#,
393 user_id
394 )
395 .fetch_all(&state.db)
396 .await;
397
398 let codes_rows = match codes_result {
399 Ok(rows) => {
400 if include_used {
401 rows
402 } else {
403 rows.into_iter().filter(|r| r.available_uses > 0).collect()
404 }
405 }
406 Err(e) => {
407 error!("DB error fetching invite codes: {:?}", e);
408 return (
409 StatusCode::INTERNAL_SERVER_ERROR,
410 Json(json!({"error": "InternalError"})),
411 )
412 .into_response();
413 }
414 };
415
416 let mut codes = Vec::new();
417 for row in codes_rows {
418 let uses_result = sqlx::query!(
419 r#"
420 SELECT u.did, icu.used_at
421 FROM invite_code_uses icu
422 JOIN users u ON icu.used_by_user = u.id
423 WHERE icu.code = $1
424 ORDER BY icu.used_at DESC
425 "#,
426 row.code
427 )
428 .fetch_all(&state.db)
429 .await;
430
431 let uses = match uses_result {
432 Ok(use_rows) => use_rows
433 .iter()
434 .map(|u| InviteCodeUse {
435 used_by: u.did.clone(),
436 used_at: u.used_at.to_rfc3339(),
437 })
438 .collect(),
439 Err(_) => Vec::new(),
440 };
441
442 codes.push(InviteCode {
443 code: row.code,
444 available: row.available_uses,
445 disabled: row.disabled.unwrap_or(false),
446 for_account: did.clone(),
447 created_by: did.clone(),
448 created_at: row.created_at.to_rfc3339(),
449 uses,
450 });
451 }
452
453 (StatusCode::OK, Json(GetAccountInviteCodesOutput { codes })).into_response()
454}