this repo has no description
1use crate::api::ApiError;
2use crate::auth::BearerAuth;
3use crate::state::AppState;
4use crate::util::get_user_id_by_did;
5use axum::{
6 Json,
7 extract::State,
8 response::{IntoResponse, Response},
9};
10use serde::{Deserialize, Serialize};
11use tracing::error;
12use uuid::Uuid;
13
14#[derive(Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct CreateInviteCodeInput {
17 pub use_count: i32,
18 pub for_account: Option<String>,
19}
20
21#[derive(Serialize)]
22pub struct CreateInviteCodeOutput {
23 pub code: String,
24}
25
26pub async fn create_invite_code(
27 State(state): State<AppState>,
28 BearerAuth(auth_user): BearerAuth,
29 Json(input): Json<CreateInviteCodeInput>,
30) -> Response {
31 if input.use_count < 1 {
32 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
33 }
34 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
35 Ok(id) => id,
36 Err(e) => return ApiError::from(e).into_response(),
37 };
38 let creator_user_id = if let Some(for_account) = &input.for_account {
39 match sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
40 .fetch_optional(&state.db)
41 .await
42 {
43 Ok(Some(row)) => row.id,
44 Ok(None) => return ApiError::AccountNotFound.into_response(),
45 Err(e) => {
46 error!("DB error looking up target account: {:?}", e);
47 return ApiError::InternalError.into_response();
48 }
49 }
50 } else {
51 user_id
52 };
53 let user_invites_disabled = sqlx::query_scalar!(
54 "SELECT invites_disabled FROM users WHERE did = $1",
55 auth_user.did
56 )
57 .fetch_optional(&state.db)
58 .await
59 .map_err(|e| {
60 error!("DB error checking invites_disabled: {:?}", e);
61 ApiError::InternalError
62 })
63 .ok()
64 .flatten()
65 .flatten()
66 .unwrap_or(false);
67 if user_invites_disabled {
68 return ApiError::InvitesDisabled.into_response();
69 }
70 let code = Uuid::new_v4().to_string();
71 match sqlx::query!(
72 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
73 code,
74 input.use_count,
75 creator_user_id
76 )
77 .execute(&state.db)
78 .await
79 {
80 Ok(_) => Json(CreateInviteCodeOutput { code }).into_response(),
81 Err(e) => {
82 error!("DB error creating invite code: {:?}", e);
83 ApiError::InternalError.into_response()
84 }
85 }
86}
87
88#[derive(Deserialize)]
89#[serde(rename_all = "camelCase")]
90pub struct CreateInviteCodesInput {
91 pub code_count: Option<i32>,
92 pub use_count: i32,
93 pub for_accounts: Option<Vec<String>>,
94}
95
96#[derive(Serialize)]
97pub struct CreateInviteCodesOutput {
98 pub codes: Vec<AccountCodes>,
99}
100
101#[derive(Serialize)]
102pub struct AccountCodes {
103 pub account: String,
104 pub codes: Vec<String>,
105}
106
107pub async fn create_invite_codes(
108 State(state): State<AppState>,
109 BearerAuth(auth_user): BearerAuth,
110 Json(input): Json<CreateInviteCodesInput>,
111) -> Response {
112 if input.use_count < 1 {
113 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
114 }
115 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
116 Ok(id) => id,
117 Err(e) => return ApiError::from(e).into_response(),
118 };
119 let code_count = input.code_count.unwrap_or(1).max(1);
120 let for_accounts = input.for_accounts.unwrap_or_default();
121 let mut result_codes = Vec::new();
122 if for_accounts.is_empty() {
123 let mut codes = Vec::new();
124 for _ in 0..code_count {
125 let code = Uuid::new_v4().to_string();
126 if let Err(e) = sqlx::query!(
127 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
128 code,
129 input.use_count,
130 user_id
131 )
132 .execute(&state.db)
133 .await
134 {
135 error!("DB error creating invite code: {:?}", e);
136 return ApiError::InternalError.into_response();
137 }
138 codes.push(code);
139 }
140 result_codes.push(AccountCodes {
141 account: "admin".to_string(),
142 codes,
143 });
144 } else {
145 for account_did in for_accounts {
146 let target_user_id =
147 match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
148 .fetch_optional(&state.db)
149 .await
150 {
151 Ok(Some(row)) => row.id,
152 Ok(None) => continue,
153 Err(e) => {
154 error!("DB error looking up target account: {:?}", e);
155 return ApiError::InternalError.into_response();
156 }
157 };
158 let mut codes = Vec::new();
159 for _ in 0..code_count {
160 let code = Uuid::new_v4().to_string();
161 if let Err(e) = sqlx::query!(
162 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
163 code,
164 input.use_count,
165 target_user_id
166 )
167 .execute(&state.db)
168 .await
169 {
170 error!("DB error creating invite code: {:?}", e);
171 return ApiError::InternalError.into_response();
172 }
173 codes.push(code);
174 }
175 result_codes.push(AccountCodes {
176 account: account_did,
177 codes,
178 });
179 }
180 }
181 Json(CreateInviteCodesOutput {
182 codes: result_codes,
183 })
184 .into_response()
185}
186
187#[derive(Deserialize)]
188#[serde(rename_all = "camelCase")]
189pub struct GetAccountInviteCodesParams {
190 pub include_used: Option<bool>,
191 pub create_available: Option<bool>,
192}
193
194#[derive(Serialize)]
195#[serde(rename_all = "camelCase")]
196pub struct InviteCode {
197 pub code: String,
198 pub available: i32,
199 pub disabled: bool,
200 pub for_account: String,
201 pub created_by: String,
202 pub created_at: String,
203 pub uses: Vec<InviteCodeUse>,
204}
205
206#[derive(Serialize)]
207#[serde(rename_all = "camelCase")]
208pub struct InviteCodeUse {
209 pub used_by: String,
210 pub used_at: String,
211}
212
213#[derive(Serialize)]
214pub struct GetAccountInviteCodesOutput {
215 pub codes: Vec<InviteCode>,
216}
217
218pub async fn get_account_invite_codes(
219 State(state): State<AppState>,
220 BearerAuth(auth_user): BearerAuth,
221 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
222) -> Response {
223 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
224 Ok(id) => id,
225 Err(e) => return ApiError::from(e).into_response(),
226 };
227 let include_used = params.include_used.unwrap_or(true);
228 let codes_rows = match sqlx::query!(
229 r#"
230 SELECT code, available_uses, created_at, disabled
231 FROM invite_codes
232 WHERE created_by_user = $1
233 ORDER BY created_at DESC
234 "#,
235 user_id
236 )
237 .fetch_all(&state.db)
238 .await
239 {
240 Ok(rows) => {
241 if include_used {
242 rows
243 } else {
244 rows.into_iter().filter(|r| r.available_uses > 0).collect()
245 }
246 }
247 Err(e) => {
248 error!("DB error fetching invite codes: {:?}", e);
249 return ApiError::InternalError.into_response();
250 }
251 };
252 let mut codes = Vec::new();
253 for row in codes_rows {
254 let uses = sqlx::query!(
255 r#"
256 SELECT u.did, icu.used_at
257 FROM invite_code_uses icu
258 JOIN users u ON icu.used_by_user = u.id
259 WHERE icu.code = $1
260 ORDER BY icu.used_at DESC
261 "#,
262 row.code
263 )
264 .fetch_all(&state.db)
265 .await
266 .map(|use_rows| {
267 use_rows
268 .iter()
269 .map(|u| InviteCodeUse {
270 used_by: u.did.clone(),
271 used_at: u.used_at.to_rfc3339(),
272 })
273 .collect()
274 })
275 .unwrap_or_default();
276 codes.push(InviteCode {
277 code: row.code,
278 available: row.available_uses,
279 disabled: row.disabled.unwrap_or(false),
280 for_account: auth_user.did.clone(),
281 created_by: auth_user.did.clone(),
282 created_at: row.created_at.to_rfc3339(),
283 uses,
284 });
285 }
286 Json(GetAccountInviteCodesOutput { codes }).into_response()
287}