this repo has no description
1use crate::api::ApiError;
2use crate::auth::extractor::BearerAuthAdmin;
3use crate::auth::BearerAuth;
4use crate::state::AppState;
5use axum::{
6 Json,
7 extract::State,
8 response::{IntoResponse, Response},
9};
10use rand::Rng;
11use serde::{Deserialize, Serialize};
12use tracing::error;
13
14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567";
15
16fn gen_random_token() -> String {
17 let mut rng = rand::thread_rng();
18 let mut token = String::with_capacity(11);
19 for i in 0..10 {
20 if i == 5 {
21 token.push('-');
22 }
23 let idx = rng.gen_range(0..32);
24 token.push(BASE32_ALPHABET[idx] as char);
25 }
26 token
27}
28
29fn gen_invite_code() -> String {
30 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
31 let hostname_prefix = hostname.replace('.', "-");
32 format!("{}-{}", hostname_prefix, gen_random_token())
33}
34
35#[derive(Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct CreateInviteCodeInput {
38 pub use_count: i32,
39 pub for_account: Option<String>,
40}
41
42#[derive(Serialize)]
43pub struct CreateInviteCodeOutput {
44 pub code: String,
45}
46
47pub async fn create_invite_code(
48 State(state): State<AppState>,
49 BearerAuthAdmin(auth_user): BearerAuthAdmin,
50 Json(input): Json<CreateInviteCodeInput>,
51) -> Response {
52 if input.use_count < 1 {
53 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
54 }
55
56 let for_account = input.for_account.unwrap_or_else(|| auth_user.did.clone());
57 let code = gen_invite_code();
58
59 match sqlx::query!(
60 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account)
61 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1",
62 code,
63 input.use_count,
64 for_account
65 )
66 .execute(&state.db)
67 .await
68 {
69 Ok(result) => {
70 if result.rows_affected() == 0 {
71 error!("No admin user found to create invite code");
72 return ApiError::InternalError.into_response();
73 }
74 Json(CreateInviteCodeOutput { code }).into_response()
75 }
76 Err(e) => {
77 error!("DB error creating invite code: {:?}", e);
78 ApiError::InternalError.into_response()
79 }
80 }
81}
82
83#[derive(Deserialize)]
84#[serde(rename_all = "camelCase")]
85pub struct CreateInviteCodesInput {
86 pub code_count: Option<i32>,
87 pub use_count: i32,
88 pub for_accounts: Option<Vec<String>>,
89}
90
91#[derive(Serialize)]
92pub struct CreateInviteCodesOutput {
93 pub codes: Vec<AccountCodes>,
94}
95
96#[derive(Serialize)]
97pub struct AccountCodes {
98 pub account: String,
99 pub codes: Vec<String>,
100}
101
102pub async fn create_invite_codes(
103 State(state): State<AppState>,
104 BearerAuthAdmin(auth_user): BearerAuthAdmin,
105 Json(input): Json<CreateInviteCodesInput>,
106) -> Response {
107 if input.use_count < 1 {
108 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
109 }
110
111 let code_count = input.code_count.unwrap_or(1).max(1);
112 let for_accounts = input
113 .for_accounts
114 .filter(|v| !v.is_empty())
115 .unwrap_or_else(|| vec![auth_user.did.clone()]);
116
117 let admin_user_id = match sqlx::query_scalar!(
118 "SELECT id FROM users WHERE is_admin = true LIMIT 1"
119 )
120 .fetch_optional(&state.db)
121 .await
122 {
123 Ok(Some(id)) => id,
124 Ok(None) => {
125 error!("No admin user found to create invite codes");
126 return ApiError::InternalError.into_response();
127 }
128 Err(e) => {
129 error!("DB error looking up admin user: {:?}", e);
130 return ApiError::InternalError.into_response();
131 }
132 };
133
134 let mut result_codes = Vec::new();
135
136 for account in for_accounts {
137 let mut codes = Vec::new();
138 for _ in 0..code_count {
139 let code = gen_invite_code();
140 if let Err(e) = sqlx::query!(
141 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)",
142 code,
143 input.use_count,
144 admin_user_id,
145 account
146 )
147 .execute(&state.db)
148 .await
149 {
150 error!("DB error creating invite code: {:?}", e);
151 return ApiError::InternalError.into_response();
152 }
153 codes.push(code);
154 }
155 result_codes.push(AccountCodes { account, codes });
156 }
157
158 Json(CreateInviteCodesOutput {
159 codes: result_codes,
160 })
161 .into_response()
162}
163
164#[derive(Deserialize)]
165#[serde(rename_all = "camelCase")]
166pub struct GetAccountInviteCodesParams {
167 pub include_used: Option<bool>,
168 pub create_available: Option<bool>,
169}
170
171#[derive(Serialize)]
172#[serde(rename_all = "camelCase")]
173pub struct InviteCode {
174 pub code: String,
175 pub available: i32,
176 pub disabled: bool,
177 pub for_account: String,
178 pub created_by: String,
179 pub created_at: String,
180 pub uses: Vec<InviteCodeUse>,
181}
182
183#[derive(Serialize)]
184#[serde(rename_all = "camelCase")]
185pub struct InviteCodeUse {
186 pub used_by: String,
187 #[serde(skip_serializing_if = "Option::is_none")]
188 pub used_by_handle: Option<String>,
189 pub used_at: String,
190}
191
192#[derive(Serialize)]
193pub struct GetAccountInviteCodesOutput {
194 pub codes: Vec<InviteCode>,
195}
196
197pub async fn get_account_invite_codes(
198 State(state): State<AppState>,
199 BearerAuth(auth_user): BearerAuth,
200 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
201) -> Response {
202 let include_used = params.include_used.unwrap_or(true);
203
204 let codes_rows = match sqlx::query!(
205 r#"
206 SELECT
207 ic.code,
208 ic.available_uses,
209 ic.created_at,
210 ic.disabled,
211 ic.for_account,
212 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!"
213 FROM invite_codes ic
214 WHERE ic.for_account = $1
215 ORDER BY ic.created_at DESC
216 "#,
217 auth_user.did
218 )
219 .fetch_all(&state.db)
220 .await
221 {
222 Ok(rows) => rows,
223 Err(e) => {
224 error!("DB error fetching invite codes: {:?}", e);
225 return ApiError::InternalError.into_response();
226 }
227 };
228
229 let mut codes = Vec::new();
230 for row in codes_rows {
231 let disabled = row.disabled.unwrap_or(false);
232 if disabled {
233 continue;
234 }
235
236 let use_count = row.use_count;
237 if !include_used && use_count >= row.available_uses {
238 continue;
239 }
240
241 let uses = sqlx::query!(
242 r#"
243 SELECT u.did, u.handle, icu.used_at
244 FROM invite_code_uses icu
245 JOIN users u ON icu.used_by_user = u.id
246 WHERE icu.code = $1
247 ORDER BY icu.used_at DESC
248 "#,
249 row.code
250 )
251 .fetch_all(&state.db)
252 .await
253 .map(|use_rows| {
254 use_rows
255 .iter()
256 .map(|u| InviteCodeUse {
257 used_by: u.did.clone(),
258 used_by_handle: Some(u.handle.clone()),
259 used_at: u.used_at.to_rfc3339(),
260 })
261 .collect()
262 })
263 .unwrap_or_default();
264
265 codes.push(InviteCode {
266 code: row.code,
267 available: row.available_uses,
268 disabled,
269 for_account: row.for_account,
270 created_by: "admin".to_string(),
271 created_at: row.created_at.to_rfc3339(),
272 uses,
273 });
274 }
275
276 Json(GetAccountInviteCodesOutput { codes }).into_response()
277}