this repo has no description
1use crate::api::ApiError;
2use crate::auth::BearerAuth;
3use crate::state::AppState;
4use crate::util::get_user_id_by_did;
5use axum::{
6 Json,
7 extract::State,
8 response::{IntoResponse, Response},
9};
10use serde::{Deserialize, Serialize};
11use tracing::error;
12use uuid::Uuid;
13
14#[derive(Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct CreateInviteCodeInput {
17 pub use_count: i32,
18 pub for_account: Option<String>,
19}
20
21#[derive(Serialize)]
22pub struct CreateInviteCodeOutput {
23 pub code: String,
24}
25
26pub async fn create_invite_code(
27 State(state): State<AppState>,
28 BearerAuth(auth_user): BearerAuth,
29 Json(input): Json<CreateInviteCodeInput>,
30) -> Response {
31 if input.use_count < 1 {
32 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
33 }
34
35 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
36 Ok(id) => id,
37 Err(e) => return ApiError::from(e).into_response(),
38 };
39
40 let creator_user_id = if let Some(for_account) = &input.for_account {
41 match sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
42 .fetch_optional(&state.db)
43 .await
44 {
45 Ok(Some(row)) => row.id,
46 Ok(None) => return ApiError::AccountNotFound.into_response(),
47 Err(e) => {
48 error!("DB error looking up target account: {:?}", e);
49 return ApiError::InternalError.into_response();
50 }
51 }
52 } else {
53 user_id
54 };
55
56 let user_invites_disabled = sqlx::query_scalar!(
57 "SELECT invites_disabled FROM users WHERE did = $1",
58 auth_user.did
59 )
60 .fetch_optional(&state.db)
61 .await
62 .map_err(|e| {
63 error!("DB error checking invites_disabled: {:?}", e);
64 ApiError::InternalError
65 })
66 .ok()
67 .flatten()
68 .flatten()
69 .unwrap_or(false);
70
71 if user_invites_disabled {
72 return ApiError::InvitesDisabled.into_response();
73 }
74
75 let code = Uuid::new_v4().to_string();
76
77 match sqlx::query!(
78 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
79 code,
80 input.use_count,
81 creator_user_id
82 )
83 .execute(&state.db)
84 .await
85 {
86 Ok(_) => Json(CreateInviteCodeOutput { code }).into_response(),
87 Err(e) => {
88 error!("DB error creating invite code: {:?}", e);
89 ApiError::InternalError.into_response()
90 }
91 }
92}
93
94#[derive(Deserialize)]
95#[serde(rename_all = "camelCase")]
96pub struct CreateInviteCodesInput {
97 pub code_count: Option<i32>,
98 pub use_count: i32,
99 pub for_accounts: Option<Vec<String>>,
100}
101
102#[derive(Serialize)]
103pub struct CreateInviteCodesOutput {
104 pub codes: Vec<AccountCodes>,
105}
106
107#[derive(Serialize)]
108pub struct AccountCodes {
109 pub account: String,
110 pub codes: Vec<String>,
111}
112
113pub async fn create_invite_codes(
114 State(state): State<AppState>,
115 BearerAuth(auth_user): BearerAuth,
116 Json(input): Json<CreateInviteCodesInput>,
117) -> Response {
118 if input.use_count < 1 {
119 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
120 }
121
122 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
123 Ok(id) => id,
124 Err(e) => return ApiError::from(e).into_response(),
125 };
126
127 let code_count = input.code_count.unwrap_or(1).max(1);
128 let for_accounts = input.for_accounts.unwrap_or_default();
129
130 let mut result_codes = Vec::new();
131
132 if for_accounts.is_empty() {
133 let mut codes = Vec::new();
134 for _ in 0..code_count {
135 let code = Uuid::new_v4().to_string();
136
137 if let Err(e) = sqlx::query!(
138 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
139 code,
140 input.use_count,
141 user_id
142 )
143 .execute(&state.db)
144 .await
145 {
146 error!("DB error creating invite code: {:?}", e);
147 return ApiError::InternalError.into_response();
148 }
149
150 codes.push(code);
151 }
152
153 result_codes.push(AccountCodes {
154 account: "admin".to_string(),
155 codes,
156 });
157 } else {
158 for account_did in for_accounts {
159 let target_user_id = match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
160 .fetch_optional(&state.db)
161 .await
162 {
163 Ok(Some(row)) => row.id,
164 Ok(None) => continue,
165 Err(e) => {
166 error!("DB error looking up target account: {:?}", e);
167 return ApiError::InternalError.into_response();
168 }
169 };
170
171 let mut codes = Vec::new();
172 for _ in 0..code_count {
173 let code = Uuid::new_v4().to_string();
174
175 if let Err(e) = sqlx::query!(
176 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
177 code,
178 input.use_count,
179 target_user_id
180 )
181 .execute(&state.db)
182 .await
183 {
184 error!("DB error creating invite code: {:?}", e);
185 return ApiError::InternalError.into_response();
186 }
187
188 codes.push(code);
189 }
190
191 result_codes.push(AccountCodes {
192 account: account_did,
193 codes,
194 });
195 }
196 }
197
198 Json(CreateInviteCodesOutput { codes: result_codes }).into_response()
199}
200
201#[derive(Deserialize)]
202#[serde(rename_all = "camelCase")]
203pub struct GetAccountInviteCodesParams {
204 pub include_used: Option<bool>,
205 pub create_available: Option<bool>,
206}
207
208#[derive(Serialize)]
209#[serde(rename_all = "camelCase")]
210pub struct InviteCode {
211 pub code: String,
212 pub available: i32,
213 pub disabled: bool,
214 pub for_account: String,
215 pub created_by: String,
216 pub created_at: String,
217 pub uses: Vec<InviteCodeUse>,
218}
219
220#[derive(Serialize)]
221#[serde(rename_all = "camelCase")]
222pub struct InviteCodeUse {
223 pub used_by: String,
224 pub used_at: String,
225}
226
227#[derive(Serialize)]
228pub struct GetAccountInviteCodesOutput {
229 pub codes: Vec<InviteCode>,
230}
231
232pub async fn get_account_invite_codes(
233 State(state): State<AppState>,
234 BearerAuth(auth_user): BearerAuth,
235 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
236) -> Response {
237 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
238 Ok(id) => id,
239 Err(e) => return ApiError::from(e).into_response(),
240 };
241
242 let include_used = params.include_used.unwrap_or(true);
243
244 let codes_rows = match sqlx::query!(
245 r#"
246 SELECT code, available_uses, created_at, disabled
247 FROM invite_codes
248 WHERE created_by_user = $1
249 ORDER BY created_at DESC
250 "#,
251 user_id
252 )
253 .fetch_all(&state.db)
254 .await
255 {
256 Ok(rows) => {
257 if include_used {
258 rows
259 } else {
260 rows.into_iter().filter(|r| r.available_uses > 0).collect()
261 }
262 }
263 Err(e) => {
264 error!("DB error fetching invite codes: {:?}", e);
265 return ApiError::InternalError.into_response();
266 }
267 };
268
269 let mut codes = Vec::new();
270 for row in codes_rows {
271 let uses = sqlx::query!(
272 r#"
273 SELECT u.did, icu.used_at
274 FROM invite_code_uses icu
275 JOIN users u ON icu.used_by_user = u.id
276 WHERE icu.code = $1
277 ORDER BY icu.used_at DESC
278 "#,
279 row.code
280 )
281 .fetch_all(&state.db)
282 .await
283 .map(|use_rows| {
284 use_rows
285 .iter()
286 .map(|u| InviteCodeUse {
287 used_by: u.did.clone(),
288 used_at: u.used_at.to_rfc3339(),
289 })
290 .collect()
291 })
292 .unwrap_or_default();
293
294 codes.push(InviteCode {
295 code: row.code,
296 available: row.available_uses,
297 disabled: row.disabled.unwrap_or(false),
298 for_account: auth_user.did.clone(),
299 created_by: auth_user.did.clone(),
300 created_at: row.created_at.to_rfc3339(),
301 uses,
302 });
303 }
304
305 Json(GetAccountInviteCodesOutput { codes }).into_response()
306}