this repo has no description
1use crate::api::EmptyResponse;
2use crate::api::error::ApiError;
3use crate::auth::BearerAuthAdmin;
4use crate::state::AppState;
5use axum::{
6 Json,
7 extract::{Query, State},
8 http::StatusCode,
9 response::{IntoResponse, Response},
10};
11use serde::{Deserialize, Serialize};
12use tracing::error;
13
14#[derive(Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub struct DisableInviteCodesInput {
17 pub codes: Option<Vec<String>>,
18 pub accounts: Option<Vec<String>>,
19}
20
21pub async fn disable_invite_codes(
22 State(state): State<AppState>,
23 _auth: BearerAuthAdmin,
24 Json(input): Json<DisableInviteCodesInput>,
25) -> Response {
26 if let Some(codes) = &input.codes {
27 let _ = sqlx::query!(
28 "UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)",
29 codes as &[String]
30 )
31 .execute(&state.db)
32 .await;
33 }
34 if let Some(accounts) = &input.accounts {
35 let _ = sqlx::query!(
36 "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))",
37 accounts as &[String]
38 )
39 .execute(&state.db)
40 .await;
41 }
42 EmptyResponse::ok().into_response()
43}
44
45#[derive(Deserialize)]
46pub struct GetInviteCodesParams {
47 pub sort: Option<String>,
48 pub limit: Option<i64>,
49 pub cursor: Option<String>,
50}
51
52#[derive(Serialize)]
53#[serde(rename_all = "camelCase")]
54pub struct InviteCodeInfo {
55 pub code: String,
56 pub available: i32,
57 pub disabled: bool,
58 pub for_account: String,
59 pub created_by: String,
60 pub created_at: String,
61 pub uses: Vec<InviteCodeUseInfo>,
62}
63
64#[derive(Clone, Serialize)]
65#[serde(rename_all = "camelCase")]
66pub struct InviteCodeUseInfo {
67 pub used_by: String,
68 pub used_at: String,
69}
70
71#[derive(Serialize)]
72pub struct GetInviteCodesOutput {
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pub cursor: Option<String>,
75 pub codes: Vec<InviteCodeInfo>,
76}
77
78pub async fn get_invite_codes(
79 State(state): State<AppState>,
80 _auth: BearerAuthAdmin,
81 Query(params): Query<GetInviteCodesParams>,
82) -> Response {
83 let limit = params.limit.unwrap_or(100).clamp(1, 500);
84 let sort = params.sort.as_deref().unwrap_or("recent");
85 let order_clause = match sort {
86 "usage" => "available_uses DESC",
87 _ => "created_at DESC",
88 };
89 let codes_result = if let Some(cursor) = ¶ms.cursor {
90 sqlx::query_as::<
91 _,
92 (
93 String,
94 i32,
95 Option<bool>,
96 uuid::Uuid,
97 chrono::DateTime<chrono::Utc>,
98 ),
99 >(&format!(
100 r#"
101 SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at
102 FROM invite_codes ic
103 WHERE ic.created_at < (SELECT created_at FROM invite_codes WHERE code = $1)
104 ORDER BY {}
105 LIMIT $2
106 "#,
107 order_clause
108 ))
109 .bind(cursor)
110 .bind(limit)
111 .fetch_all(&state.db)
112 .await
113 } else {
114 sqlx::query_as::<
115 _,
116 (
117 String,
118 i32,
119 Option<bool>,
120 uuid::Uuid,
121 chrono::DateTime<chrono::Utc>,
122 ),
123 >(&format!(
124 r#"
125 SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at
126 FROM invite_codes ic
127 ORDER BY {}
128 LIMIT $1
129 "#,
130 order_clause
131 ))
132 .bind(limit)
133 .fetch_all(&state.db)
134 .await
135 };
136 let codes_rows = match codes_result {
137 Ok(rows) => rows,
138 Err(e) => {
139 error!("DB error fetching invite codes: {:?}", e);
140 return ApiError::InternalError(None).into_response();
141 }
142 };
143
144 let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|(_, _, _, uid, _)| *uid).collect();
145 let code_strings: Vec<String> = codes_rows.iter().map(|(c, _, _, _, _)| c.clone()).collect();
146
147 let mut creator_dids: std::collections::HashMap<uuid::Uuid, String> =
148 std::collections::HashMap::new();
149 sqlx::query!(
150 "SELECT id, did FROM users WHERE id = ANY($1)",
151 &user_ids
152 )
153 .fetch_all(&state.db)
154 .await
155 .unwrap_or_default()
156 .into_iter()
157 .for_each(|r| {
158 creator_dids.insert(r.id, r.did);
159 });
160
161 let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> =
162 std::collections::HashMap::new();
163 if !code_strings.is_empty() {
164 sqlx::query!(
165 r#"
166 SELECT icu.code, u.did, icu.used_at
167 FROM invite_code_uses icu
168 JOIN users u ON icu.used_by_user = u.id
169 WHERE icu.code = ANY($1)
170 ORDER BY icu.used_at DESC
171 "#,
172 &code_strings
173 )
174 .fetch_all(&state.db)
175 .await
176 .unwrap_or_default()
177 .into_iter()
178 .for_each(|r| {
179 uses_by_code
180 .entry(r.code)
181 .or_default()
182 .push(InviteCodeUseInfo {
183 used_by: r.did,
184 used_at: r.used_at.to_rfc3339(),
185 });
186 });
187 }
188
189 let codes: Vec<InviteCodeInfo> = codes_rows
190 .iter()
191 .map(|(code, available_uses, disabled, created_by_user, created_at)| {
192 let creator_did = creator_dids
193 .get(created_by_user)
194 .cloned()
195 .unwrap_or_else(|| "unknown".to_string());
196 InviteCodeInfo {
197 code: code.clone(),
198 available: *available_uses,
199 disabled: disabled.unwrap_or(false),
200 for_account: creator_did.clone(),
201 created_by: creator_did,
202 created_at: created_at.to_rfc3339(),
203 uses: uses_by_code.get(code).cloned().unwrap_or_default(),
204 }
205 })
206 .collect();
207
208 let next_cursor = if codes_rows.len() == limit as usize {
209 codes_rows.last().map(|(code, _, _, _, _)| code.clone())
210 } else {
211 None
212 };
213 (
214 StatusCode::OK,
215 Json(GetInviteCodesOutput {
216 cursor: next_cursor,
217 codes,
218 }),
219 )
220 .into_response()
221}
222
223#[derive(Deserialize)]
224pub struct DisableAccountInvitesInput {
225 pub account: String,
226}
227
228pub async fn disable_account_invites(
229 State(state): State<AppState>,
230 _auth: BearerAuthAdmin,
231 Json(input): Json<DisableAccountInvitesInput>,
232) -> Response {
233 let account = input.account.trim();
234 if account.is_empty() {
235 return ApiError::InvalidRequest("account is required".into()).into_response();
236 }
237 let result = sqlx::query!(
238 "UPDATE users SET invites_disabled = TRUE WHERE did = $1",
239 account
240 )
241 .execute(&state.db)
242 .await;
243 match result {
244 Ok(r) => {
245 if r.rows_affected() == 0 {
246 return ApiError::AccountNotFound.into_response();
247 }
248 EmptyResponse::ok().into_response()
249 }
250 Err(e) => {
251 error!("DB error disabling account invites: {:?}", e);
252 ApiError::InternalError(None).into_response()
253 }
254 }
255}
256
257#[derive(Deserialize)]
258pub struct EnableAccountInvitesInput {
259 pub account: String,
260}
261
262pub async fn enable_account_invites(
263 State(state): State<AppState>,
264 _auth: BearerAuthAdmin,
265 Json(input): Json<EnableAccountInvitesInput>,
266) -> Response {
267 let account = input.account.trim();
268 if account.is_empty() {
269 return ApiError::InvalidRequest("account is required".into()).into_response();
270 }
271 let result = sqlx::query!(
272 "UPDATE users SET invites_disabled = FALSE WHERE did = $1",
273 account
274 )
275 .execute(&state.db)
276 .await;
277 match result {
278 Ok(r) => {
279 if r.rows_affected() == 0 {
280 return ApiError::AccountNotFound.into_response();
281 }
282 EmptyResponse::ok().into_response()
283 }
284 Err(e) => {
285 error!("DB error enabling account invites: {:?}", e);
286 ApiError::InternalError(None).into_response()
287 }
288 }
289}