this repo has no description
1use crate::api::ApiError;
2use crate::auth::BearerAuth;
3use crate::state::AppState;
4use crate::util::get_user_id_by_did;
5use axum::{
6 Json,
7 extract::State,
8 response::{IntoResponse, Response},
9};
10use serde::{Deserialize, Serialize};
11use tracing::error;
12use uuid::Uuid;
13#[derive(Deserialize)]
14#[serde(rename_all = "camelCase")]
15pub struct CreateInviteCodeInput {
16 pub use_count: i32,
17 pub for_account: Option<String>,
18}
19#[derive(Serialize)]
20pub struct CreateInviteCodeOutput {
21 pub code: String,
22}
23pub async fn create_invite_code(
24 State(state): State<AppState>,
25 BearerAuth(auth_user): BearerAuth,
26 Json(input): Json<CreateInviteCodeInput>,
27) -> Response {
28 if input.use_count < 1 {
29 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
30 }
31 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
32 Ok(id) => id,
33 Err(e) => return ApiError::from(e).into_response(),
34 };
35 let creator_user_id = if let Some(for_account) = &input.for_account {
36 match sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
37 .fetch_optional(&state.db)
38 .await
39 {
40 Ok(Some(row)) => row.id,
41 Ok(None) => return ApiError::AccountNotFound.into_response(),
42 Err(e) => {
43 error!("DB error looking up target account: {:?}", e);
44 return ApiError::InternalError.into_response();
45 }
46 }
47 } else {
48 user_id
49 };
50 let user_invites_disabled = sqlx::query_scalar!(
51 "SELECT invites_disabled FROM users WHERE did = $1",
52 auth_user.did
53 )
54 .fetch_optional(&state.db)
55 .await
56 .map_err(|e| {
57 error!("DB error checking invites_disabled: {:?}", e);
58 ApiError::InternalError
59 })
60 .ok()
61 .flatten()
62 .flatten()
63 .unwrap_or(false);
64 if user_invites_disabled {
65 return ApiError::InvitesDisabled.into_response();
66 }
67 let code = Uuid::new_v4().to_string();
68 match sqlx::query!(
69 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
70 code,
71 input.use_count,
72 creator_user_id
73 )
74 .execute(&state.db)
75 .await
76 {
77 Ok(_) => Json(CreateInviteCodeOutput { code }).into_response(),
78 Err(e) => {
79 error!("DB error creating invite code: {:?}", e);
80 ApiError::InternalError.into_response()
81 }
82 }
83}
84#[derive(Deserialize)]
85#[serde(rename_all = "camelCase")]
86pub struct CreateInviteCodesInput {
87 pub code_count: Option<i32>,
88 pub use_count: i32,
89 pub for_accounts: Option<Vec<String>>,
90}
91#[derive(Serialize)]
92pub struct CreateInviteCodesOutput {
93 pub codes: Vec<AccountCodes>,
94}
95#[derive(Serialize)]
96pub struct AccountCodes {
97 pub account: String,
98 pub codes: Vec<String>,
99}
100pub async fn create_invite_codes(
101 State(state): State<AppState>,
102 BearerAuth(auth_user): BearerAuth,
103 Json(input): Json<CreateInviteCodesInput>,
104) -> Response {
105 if input.use_count < 1 {
106 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
107 }
108 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
109 Ok(id) => id,
110 Err(e) => return ApiError::from(e).into_response(),
111 };
112 let code_count = input.code_count.unwrap_or(1).max(1);
113 let for_accounts = input.for_accounts.unwrap_or_default();
114 let mut result_codes = Vec::new();
115 if for_accounts.is_empty() {
116 let mut codes = Vec::new();
117 for _ in 0..code_count {
118 let code = Uuid::new_v4().to_string();
119 if let Err(e) = sqlx::query!(
120 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
121 code,
122 input.use_count,
123 user_id
124 )
125 .execute(&state.db)
126 .await
127 {
128 error!("DB error creating invite code: {:?}", e);
129 return ApiError::InternalError.into_response();
130 }
131 codes.push(code);
132 }
133 result_codes.push(AccountCodes {
134 account: "admin".to_string(),
135 codes,
136 });
137 } else {
138 for account_did in for_accounts {
139 let target_user_id = match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
140 .fetch_optional(&state.db)
141 .await
142 {
143 Ok(Some(row)) => row.id,
144 Ok(None) => continue,
145 Err(e) => {
146 error!("DB error looking up target account: {:?}", e);
147 return ApiError::InternalError.into_response();
148 }
149 };
150 let mut codes = Vec::new();
151 for _ in 0..code_count {
152 let code = Uuid::new_v4().to_string();
153 if let Err(e) = sqlx::query!(
154 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
155 code,
156 input.use_count,
157 target_user_id
158 )
159 .execute(&state.db)
160 .await
161 {
162 error!("DB error creating invite code: {:?}", e);
163 return ApiError::InternalError.into_response();
164 }
165 codes.push(code);
166 }
167 result_codes.push(AccountCodes {
168 account: account_did,
169 codes,
170 });
171 }
172 }
173 Json(CreateInviteCodesOutput { codes: result_codes }).into_response()
174}
175#[derive(Deserialize)]
176#[serde(rename_all = "camelCase")]
177pub struct GetAccountInviteCodesParams {
178 pub include_used: Option<bool>,
179 pub create_available: Option<bool>,
180}
181#[derive(Serialize)]
182#[serde(rename_all = "camelCase")]
183pub struct InviteCode {
184 pub code: String,
185 pub available: i32,
186 pub disabled: bool,
187 pub for_account: String,
188 pub created_by: String,
189 pub created_at: String,
190 pub uses: Vec<InviteCodeUse>,
191}
192#[derive(Serialize)]
193#[serde(rename_all = "camelCase")]
194pub struct InviteCodeUse {
195 pub used_by: String,
196 pub used_at: String,
197}
198#[derive(Serialize)]
199pub struct GetAccountInviteCodesOutput {
200 pub codes: Vec<InviteCode>,
201}
202pub async fn get_account_invite_codes(
203 State(state): State<AppState>,
204 BearerAuth(auth_user): BearerAuth,
205 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
206) -> Response {
207 let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
208 Ok(id) => id,
209 Err(e) => return ApiError::from(e).into_response(),
210 };
211 let include_used = params.include_used.unwrap_or(true);
212 let codes_rows = match sqlx::query!(
213 r#"
214 SELECT code, available_uses, created_at, disabled
215 FROM invite_codes
216 WHERE created_by_user = $1
217 ORDER BY created_at DESC
218 "#,
219 user_id
220 )
221 .fetch_all(&state.db)
222 .await
223 {
224 Ok(rows) => {
225 if include_used {
226 rows
227 } else {
228 rows.into_iter().filter(|r| r.available_uses > 0).collect()
229 }
230 }
231 Err(e) => {
232 error!("DB error fetching invite codes: {:?}", e);
233 return ApiError::InternalError.into_response();
234 }
235 };
236 let mut codes = Vec::new();
237 for row in codes_rows {
238 let uses = sqlx::query!(
239 r#"
240 SELECT u.did, icu.used_at
241 FROM invite_code_uses icu
242 JOIN users u ON icu.used_by_user = u.id
243 WHERE icu.code = $1
244 ORDER BY icu.used_at DESC
245 "#,
246 row.code
247 )
248 .fetch_all(&state.db)
249 .await
250 .map(|use_rows| {
251 use_rows
252 .iter()
253 .map(|u| InviteCodeUse {
254 used_by: u.did.clone(),
255 used_at: u.used_at.to_rfc3339(),
256 })
257 .collect()
258 })
259 .unwrap_or_default();
260 codes.push(InviteCode {
261 code: row.code,
262 available: row.available_uses,
263 disabled: row.disabled.unwrap_or(false),
264 for_account: auth_user.did.clone(),
265 created_by: auth_user.did.clone(),
266 created_at: row.created_at.to_rfc3339(),
267 uses,
268 });
269 }
270 Json(GetAccountInviteCodesOutput { codes }).into_response()
271}