this repo has no description
1use crate::api::ApiError;
2use crate::auth::extractor::BearerAuthAdmin;
3use crate::auth::BearerAuth;
4use crate::state::AppState;
5use axum::{
6 Json,
7 extract::State,
8 response::{IntoResponse, Response},
9};
10use rand::Rng;
11use serde::{Deserialize, Serialize};
12use tracing::error;
13
14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567";
15
16fn gen_random_token() -> String {
17 let mut rng = rand::thread_rng();
18 let mut token = String::with_capacity(11);
19 for i in 0..10 {
20 if i == 5 {
21 token.push('-');
22 }
23 let idx = rng.gen_range(0..32);
24 token.push(BASE32_ALPHABET[idx] as char);
25 }
26 token
27}
28
29fn gen_invite_code() -> String {
30 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
31 let hostname_prefix = hostname.replace('.', "-");
32 format!("{}-{}", hostname_prefix, gen_random_token())
33}
34
35#[derive(Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct CreateInviteCodeInput {
38 pub use_count: i32,
39 pub for_account: Option<String>,
40}
41
42#[derive(Serialize)]
43pub struct CreateInviteCodeOutput {
44 pub code: String,
45}
46
47pub async fn create_invite_code(
48 State(state): State<AppState>,
49 BearerAuthAdmin(_auth_user): BearerAuthAdmin,
50 Json(input): Json<CreateInviteCodeInput>,
51) -> Response {
52 if input.use_count < 1 {
53 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
54 }
55
56 let for_account = input.for_account.unwrap_or_else(|| "admin".to_string());
57 let code = gen_invite_code();
58
59 match sqlx::query!(
60 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account)
61 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1",
62 code,
63 input.use_count,
64 for_account
65 )
66 .execute(&state.db)
67 .await
68 {
69 Ok(result) => {
70 if result.rows_affected() == 0 {
71 error!("No admin user found to create invite code");
72 return ApiError::InternalError.into_response();
73 }
74 Json(CreateInviteCodeOutput { code }).into_response()
75 }
76 Err(e) => {
77 error!("DB error creating invite code: {:?}", e);
78 ApiError::InternalError.into_response()
79 }
80 }
81}
82
83#[derive(Deserialize)]
84#[serde(rename_all = "camelCase")]
85pub struct CreateInviteCodesInput {
86 pub code_count: Option<i32>,
87 pub use_count: i32,
88 pub for_accounts: Option<Vec<String>>,
89}
90
91#[derive(Serialize)]
92pub struct CreateInviteCodesOutput {
93 pub codes: Vec<AccountCodes>,
94}
95
96#[derive(Serialize)]
97pub struct AccountCodes {
98 pub account: String,
99 pub codes: Vec<String>,
100}
101
102pub async fn create_invite_codes(
103 State(state): State<AppState>,
104 BearerAuthAdmin(_auth_user): BearerAuthAdmin,
105 Json(input): Json<CreateInviteCodesInput>,
106) -> Response {
107 if input.use_count < 1 {
108 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
109 }
110
111 let code_count = input.code_count.unwrap_or(1).max(1);
112 let for_accounts = input
113 .for_accounts
114 .filter(|v| !v.is_empty())
115 .unwrap_or_else(|| vec!["admin".to_string()]);
116
117 let admin_user_id = match sqlx::query_scalar!(
118 "SELECT id FROM users WHERE is_admin = true LIMIT 1"
119 )
120 .fetch_optional(&state.db)
121 .await
122 {
123 Ok(Some(id)) => id,
124 Ok(None) => {
125 error!("No admin user found to create invite codes");
126 return ApiError::InternalError.into_response();
127 }
128 Err(e) => {
129 error!("DB error looking up admin user: {:?}", e);
130 return ApiError::InternalError.into_response();
131 }
132 };
133
134 let mut result_codes = Vec::new();
135
136 for account in for_accounts {
137 let mut codes = Vec::new();
138 for _ in 0..code_count {
139 let code = gen_invite_code();
140 if let Err(e) = sqlx::query!(
141 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)",
142 code,
143 input.use_count,
144 admin_user_id,
145 account
146 )
147 .execute(&state.db)
148 .await
149 {
150 error!("DB error creating invite code: {:?}", e);
151 return ApiError::InternalError.into_response();
152 }
153 codes.push(code);
154 }
155 result_codes.push(AccountCodes { account, codes });
156 }
157
158 Json(CreateInviteCodesOutput {
159 codes: result_codes,
160 })
161 .into_response()
162}
163
164#[derive(Deserialize)]
165#[serde(rename_all = "camelCase")]
166pub struct GetAccountInviteCodesParams {
167 pub include_used: Option<bool>,
168 pub create_available: Option<bool>,
169}
170
171#[derive(Serialize)]
172#[serde(rename_all = "camelCase")]
173pub struct InviteCode {
174 pub code: String,
175 pub available: i32,
176 pub disabled: bool,
177 pub for_account: String,
178 pub created_by: String,
179 pub created_at: String,
180 pub uses: Vec<InviteCodeUse>,
181}
182
183#[derive(Serialize)]
184#[serde(rename_all = "camelCase")]
185pub struct InviteCodeUse {
186 pub used_by: String,
187 pub used_at: String,
188}
189
190#[derive(Serialize)]
191pub struct GetAccountInviteCodesOutput {
192 pub codes: Vec<InviteCode>,
193}
194
195pub async fn get_account_invite_codes(
196 State(state): State<AppState>,
197 BearerAuth(auth_user): BearerAuth,
198 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
199) -> Response {
200 let include_used = params.include_used.unwrap_or(true);
201
202 let codes_rows = match sqlx::query!(
203 r#"
204 SELECT
205 ic.code,
206 ic.available_uses,
207 ic.created_at,
208 ic.disabled,
209 ic.for_account,
210 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!"
211 FROM invite_codes ic
212 WHERE ic.for_account = $1
213 ORDER BY ic.created_at DESC
214 "#,
215 auth_user.did
216 )
217 .fetch_all(&state.db)
218 .await
219 {
220 Ok(rows) => rows,
221 Err(e) => {
222 error!("DB error fetching invite codes: {:?}", e);
223 return ApiError::InternalError.into_response();
224 }
225 };
226
227 let mut codes = Vec::new();
228 for row in codes_rows {
229 let disabled = row.disabled.unwrap_or(false);
230 if disabled {
231 continue;
232 }
233
234 let use_count = row.use_count;
235 if !include_used && use_count >= row.available_uses {
236 continue;
237 }
238
239 let uses = sqlx::query!(
240 r#"
241 SELECT u.did, icu.used_at
242 FROM invite_code_uses icu
243 JOIN users u ON icu.used_by_user = u.id
244 WHERE icu.code = $1
245 ORDER BY icu.used_at DESC
246 "#,
247 row.code
248 )
249 .fetch_all(&state.db)
250 .await
251 .map(|use_rows| {
252 use_rows
253 .iter()
254 .map(|u| InviteCodeUse {
255 used_by: u.did.clone(),
256 used_at: u.used_at.to_rfc3339(),
257 })
258 .collect()
259 })
260 .unwrap_or_default();
261
262 codes.push(InviteCode {
263 code: row.code,
264 available: row.available_uses,
265 disabled,
266 for_account: row.for_account,
267 created_by: "admin".to_string(),
268 created_at: row.created_at.to_rfc3339(),
269 uses,
270 });
271 }
272
273 Json(GetAccountInviteCodesOutput { codes }).into_response()
274}