this repo has no description
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::{Query, State},
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use tracing::error;
11
12#[derive(Deserialize)]
13#[serde(rename_all = "camelCase")]
14pub struct DisableInviteCodesInput {
15 pub codes: Option<Vec<String>>,
16 pub accounts: Option<Vec<String>>,
17}
18
19pub async fn disable_invite_codes(
20 State(state): State<AppState>,
21 headers: axum::http::HeaderMap,
22 Json(input): Json<DisableInviteCodesInput>,
23) -> Response {
24 let auth_header = headers.get("Authorization");
25 if auth_header.is_none() {
26 return (
27 StatusCode::UNAUTHORIZED,
28 Json(json!({"error": "AuthenticationRequired"})),
29 )
30 .into_response();
31 }
32 if let Some(codes) = &input.codes {
33 for code in codes {
34 let _ = sqlx::query!(
35 "UPDATE invite_codes SET disabled = TRUE WHERE code = $1",
36 code
37 )
38 .execute(&state.db)
39 .await;
40 }
41 }
42 if let Some(accounts) = &input.accounts {
43 for account in accounts {
44 let user = sqlx::query!("SELECT id FROM users WHERE did = $1", account)
45 .fetch_optional(&state.db)
46 .await;
47 if let Ok(Some(user_row)) = user {
48 let _ = sqlx::query!(
49 "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user = $1",
50 user_row.id
51 )
52 .execute(&state.db)
53 .await;
54 }
55 }
56 }
57 (StatusCode::OK, Json(json!({}))).into_response()
58}
59
60#[derive(Deserialize)]
61pub struct GetInviteCodesParams {
62 pub sort: Option<String>,
63 pub limit: Option<i64>,
64 pub cursor: Option<String>,
65}
66
67#[derive(Serialize)]
68#[serde(rename_all = "camelCase")]
69pub struct InviteCodeInfo {
70 pub code: String,
71 pub available: i32,
72 pub disabled: bool,
73 pub for_account: String,
74 pub created_by: String,
75 pub created_at: String,
76 pub uses: Vec<InviteCodeUseInfo>,
77}
78
79#[derive(Serialize)]
80#[serde(rename_all = "camelCase")]
81pub struct InviteCodeUseInfo {
82 pub used_by: String,
83 pub used_at: String,
84}
85
86#[derive(Serialize)]
87pub struct GetInviteCodesOutput {
88 pub cursor: Option<String>,
89 pub codes: Vec<InviteCodeInfo>,
90}
91
92pub async fn get_invite_codes(
93 State(state): State<AppState>,
94 headers: axum::http::HeaderMap,
95 Query(params): Query<GetInviteCodesParams>,
96) -> Response {
97 let auth_header = headers.get("Authorization");
98 if auth_header.is_none() {
99 return (
100 StatusCode::UNAUTHORIZED,
101 Json(json!({"error": "AuthenticationRequired"})),
102 )
103 .into_response();
104 }
105 let limit = params.limit.unwrap_or(100).clamp(1, 500);
106 let sort = params.sort.as_deref().unwrap_or("recent");
107 let order_clause = match sort {
108 "usage" => "available_uses DESC",
109 _ => "created_at DESC",
110 };
111 let codes_result = if let Some(cursor) = ¶ms.cursor {
112 sqlx::query_as::<
113 _,
114 (
115 String,
116 i32,
117 Option<bool>,
118 uuid::Uuid,
119 chrono::DateTime<chrono::Utc>,
120 ),
121 >(&format!(
122 r#"
123 SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at
124 FROM invite_codes ic
125 WHERE ic.created_at < (SELECT created_at FROM invite_codes WHERE code = $1)
126 ORDER BY {}
127 LIMIT $2
128 "#,
129 order_clause
130 ))
131 .bind(cursor)
132 .bind(limit)
133 .fetch_all(&state.db)
134 .await
135 } else {
136 sqlx::query_as::<
137 _,
138 (
139 String,
140 i32,
141 Option<bool>,
142 uuid::Uuid,
143 chrono::DateTime<chrono::Utc>,
144 ),
145 >(&format!(
146 r#"
147 SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at
148 FROM invite_codes ic
149 ORDER BY {}
150 LIMIT $1
151 "#,
152 order_clause
153 ))
154 .bind(limit)
155 .fetch_all(&state.db)
156 .await
157 };
158 let codes_rows = match codes_result {
159 Ok(rows) => rows,
160 Err(e) => {
161 error!("DB error fetching invite codes: {:?}", e);
162 return (
163 StatusCode::INTERNAL_SERVER_ERROR,
164 Json(json!({"error": "InternalError"})),
165 )
166 .into_response();
167 }
168 };
169 let mut codes = Vec::new();
170 for (code, available_uses, disabled, created_by_user, created_at) in &codes_rows {
171 let creator_did =
172 sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", created_by_user)
173 .fetch_optional(&state.db)
174 .await
175 .ok()
176 .flatten()
177 .unwrap_or_else(|| "unknown".to_string());
178 let uses_result = sqlx::query!(
179 r#"
180 SELECT u.did, icu.used_at
181 FROM invite_code_uses icu
182 JOIN users u ON icu.used_by_user = u.id
183 WHERE icu.code = $1
184 ORDER BY icu.used_at DESC
185 "#,
186 code
187 )
188 .fetch_all(&state.db)
189 .await;
190 let uses = match uses_result {
191 Ok(use_rows) => use_rows
192 .iter()
193 .map(|u| InviteCodeUseInfo {
194 used_by: u.did.clone(),
195 used_at: u.used_at.to_rfc3339(),
196 })
197 .collect(),
198 Err(_) => Vec::new(),
199 };
200 codes.push(InviteCodeInfo {
201 code: code.clone(),
202 available: *available_uses,
203 disabled: disabled.unwrap_or(false),
204 for_account: creator_did.clone(),
205 created_by: creator_did,
206 created_at: created_at.to_rfc3339(),
207 uses,
208 });
209 }
210 let next_cursor = if codes_rows.len() == limit as usize {
211 codes_rows.last().map(|(code, _, _, _, _)| code.clone())
212 } else {
213 None
214 };
215 (
216 StatusCode::OK,
217 Json(GetInviteCodesOutput {
218 cursor: next_cursor,
219 codes,
220 }),
221 )
222 .into_response()
223}
224
225#[derive(Deserialize)]
226pub struct DisableAccountInvitesInput {
227 pub account: String,
228}
229
230pub async fn disable_account_invites(
231 State(state): State<AppState>,
232 headers: axum::http::HeaderMap,
233 Json(input): Json<DisableAccountInvitesInput>,
234) -> Response {
235 let auth_header = headers.get("Authorization");
236 if auth_header.is_none() {
237 return (
238 StatusCode::UNAUTHORIZED,
239 Json(json!({"error": "AuthenticationRequired"})),
240 )
241 .into_response();
242 }
243 let account = input.account.trim();
244 if account.is_empty() {
245 return (
246 StatusCode::BAD_REQUEST,
247 Json(json!({"error": "InvalidRequest", "message": "account is required"})),
248 )
249 .into_response();
250 }
251 let result = sqlx::query!(
252 "UPDATE users SET invites_disabled = TRUE WHERE did = $1",
253 account
254 )
255 .execute(&state.db)
256 .await;
257 match result {
258 Ok(r) => {
259 if r.rows_affected() == 0 {
260 return (
261 StatusCode::NOT_FOUND,
262 Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
263 )
264 .into_response();
265 }
266 (StatusCode::OK, Json(json!({}))).into_response()
267 }
268 Err(e) => {
269 error!("DB error disabling account invites: {:?}", e);
270 (
271 StatusCode::INTERNAL_SERVER_ERROR,
272 Json(json!({"error": "InternalError"})),
273 )
274 .into_response()
275 }
276 }
277}
278
279#[derive(Deserialize)]
280pub struct EnableAccountInvitesInput {
281 pub account: String,
282}
283
284pub async fn enable_account_invites(
285 State(state): State<AppState>,
286 headers: axum::http::HeaderMap,
287 Json(input): Json<EnableAccountInvitesInput>,
288) -> Response {
289 let auth_header = headers.get("Authorization");
290 if auth_header.is_none() {
291 return (
292 StatusCode::UNAUTHORIZED,
293 Json(json!({"error": "AuthenticationRequired"})),
294 )
295 .into_response();
296 }
297 let account = input.account.trim();
298 if account.is_empty() {
299 return (
300 StatusCode::BAD_REQUEST,
301 Json(json!({"error": "InvalidRequest", "message": "account is required"})),
302 )
303 .into_response();
304 }
305 let result = sqlx::query!(
306 "UPDATE users SET invites_disabled = FALSE WHERE did = $1",
307 account
308 )
309 .execute(&state.db)
310 .await;
311 match result {
312 Ok(r) => {
313 if r.rows_affected() == 0 {
314 return (
315 StatusCode::NOT_FOUND,
316 Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
317 )
318 .into_response();
319 }
320 (StatusCode::OK, Json(json!({}))).into_response()
321 }
322 Err(e) => {
323 error!("DB error enabling account invites: {:?}", e);
324 (
325 StatusCode::INTERNAL_SERVER_ERROR,
326 Json(json!({"error": "InternalError"})),
327 )
328 .into_response()
329 }
330 }
331}