this repo has no description
1use crate::api::ApiError;
2use crate::auth::BearerAuth;
3use crate::auth::extractor::BearerAuthAdmin;
4use crate::state::AppState;
5use axum::{
6 Json,
7 extract::State,
8 response::{IntoResponse, Response},
9};
10use rand::Rng;
11use serde::{Deserialize, Serialize};
12use tracing::error;
13
14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567";
15
16fn gen_random_token() -> String {
17 let mut rng = rand::thread_rng();
18 let mut token = String::with_capacity(11);
19 for i in 0..10 {
20 if i == 5 {
21 token.push('-');
22 }
23 let idx = rng.gen_range(0..32);
24 token.push(BASE32_ALPHABET[idx] as char);
25 }
26 token
27}
28
29fn gen_invite_code() -> String {
30 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
31 let hostname_prefix = hostname.replace('.', "-");
32 format!("{}-{}", hostname_prefix, gen_random_token())
33}
34
35#[derive(Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct CreateInviteCodeInput {
38 pub use_count: i32,
39 pub for_account: Option<String>,
40}
41
42#[derive(Serialize)]
43pub struct CreateInviteCodeOutput {
44 pub code: String,
45}
46
47pub async fn create_invite_code(
48 State(state): State<AppState>,
49 BearerAuthAdmin(auth_user): BearerAuthAdmin,
50 Json(input): Json<CreateInviteCodeInput>,
51) -> Response {
52 if input.use_count < 1 {
53 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
54 }
55
56 let for_account = input.for_account.unwrap_or_else(|| auth_user.did.to_string());
57 let code = gen_invite_code();
58
59 match sqlx::query!(
60 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account)
61 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1",
62 code,
63 input.use_count,
64 for_account
65 )
66 .execute(&state.db)
67 .await
68 {
69 Ok(result) => {
70 if result.rows_affected() == 0 {
71 error!("No admin user found to create invite code");
72 return ApiError::InternalError(None).into_response();
73 }
74 Json(CreateInviteCodeOutput { code }).into_response()
75 }
76 Err(e) => {
77 error!("DB error creating invite code: {:?}", e);
78 ApiError::InternalError(None).into_response()
79 }
80 }
81}
82
83#[derive(Deserialize)]
84#[serde(rename_all = "camelCase")]
85pub struct CreateInviteCodesInput {
86 pub code_count: Option<i32>,
87 pub use_count: i32,
88 pub for_accounts: Option<Vec<String>>,
89}
90
91#[derive(Serialize)]
92pub struct CreateInviteCodesOutput {
93 pub codes: Vec<AccountCodes>,
94}
95
96#[derive(Serialize)]
97pub struct AccountCodes {
98 pub account: String,
99 pub codes: Vec<String>,
100}
101
102pub async fn create_invite_codes(
103 State(state): State<AppState>,
104 BearerAuthAdmin(auth_user): BearerAuthAdmin,
105 Json(input): Json<CreateInviteCodesInput>,
106) -> Response {
107 if input.use_count < 1 {
108 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
109 }
110
111 let code_count = input.code_count.unwrap_or(1).max(1);
112 let for_accounts = input
113 .for_accounts
114 .filter(|v| !v.is_empty())
115 .unwrap_or_else(|| vec![auth_user.did.to_string()]);
116
117 let admin_user_id =
118 match sqlx::query_scalar!("SELECT id FROM users WHERE is_admin = true LIMIT 1")
119 .fetch_optional(&state.db)
120 .await
121 {
122 Ok(Some(id)) => id,
123 Ok(None) => {
124 error!("No admin user found to create invite codes");
125 return ApiError::InternalError(None).into_response();
126 }
127 Err(e) => {
128 error!("DB error looking up admin user: {:?}", e);
129 return ApiError::InternalError(None).into_response();
130 }
131 };
132
133 let mut result_codes = Vec::new();
134
135 for account in for_accounts {
136 let mut codes = Vec::new();
137 for _ in 0..code_count {
138 let code = gen_invite_code();
139 if let Err(e) = sqlx::query!(
140 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)",
141 code,
142 input.use_count,
143 admin_user_id,
144 account
145 )
146 .execute(&state.db)
147 .await
148 {
149 error!("DB error creating invite code: {:?}", e);
150 return ApiError::InternalError(None).into_response();
151 }
152 codes.push(code);
153 }
154 result_codes.push(AccountCodes { account, codes });
155 }
156
157 Json(CreateInviteCodesOutput {
158 codes: result_codes,
159 })
160 .into_response()
161}
162
163#[derive(Deserialize)]
164#[serde(rename_all = "camelCase")]
165pub struct GetAccountInviteCodesParams {
166 pub include_used: Option<bool>,
167 pub create_available: Option<bool>,
168}
169
170#[derive(Serialize)]
171#[serde(rename_all = "camelCase")]
172pub struct InviteCode {
173 pub code: String,
174 pub available: i32,
175 pub disabled: bool,
176 pub for_account: String,
177 pub created_by: String,
178 pub created_at: String,
179 pub uses: Vec<InviteCodeUse>,
180}
181
182#[derive(Serialize)]
183#[serde(rename_all = "camelCase")]
184pub struct InviteCodeUse {
185 pub used_by: String,
186 #[serde(skip_serializing_if = "Option::is_none")]
187 pub used_by_handle: Option<String>,
188 pub used_at: String,
189}
190
191#[derive(Serialize)]
192pub struct GetAccountInviteCodesOutput {
193 pub codes: Vec<InviteCode>,
194}
195
196pub async fn get_account_invite_codes(
197 State(state): State<AppState>,
198 BearerAuth(auth_user): BearerAuth,
199 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
200) -> Response {
201 let include_used = params.include_used.unwrap_or(true);
202
203 let codes_rows = match sqlx::query!(
204 r#"
205 SELECT
206 ic.code,
207 ic.available_uses,
208 ic.created_at,
209 ic.disabled,
210 ic.for_account,
211 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!"
212 FROM invite_codes ic
213 WHERE ic.for_account = $1
214 ORDER BY ic.created_at DESC
215 "#,
216 &auth_user.did
217 )
218 .fetch_all(&state.db)
219 .await
220 {
221 Ok(rows) => rows,
222 Err(e) => {
223 error!("DB error fetching invite codes: {:?}", e);
224 return ApiError::InternalError(None).into_response();
225 }
226 };
227
228 let mut codes = Vec::new();
229 for row in codes_rows {
230 let disabled = row.disabled.unwrap_or(false);
231 if disabled {
232 continue;
233 }
234
235 let use_count = row.use_count;
236 if !include_used && use_count >= row.available_uses {
237 continue;
238 }
239
240 let uses = sqlx::query!(
241 r#"
242 SELECT u.did, u.handle, icu.used_at
243 FROM invite_code_uses icu
244 JOIN users u ON icu.used_by_user = u.id
245 WHERE icu.code = $1
246 ORDER BY icu.used_at DESC
247 "#,
248 row.code
249 )
250 .fetch_all(&state.db)
251 .await
252 .map(|use_rows| {
253 use_rows
254 .iter()
255 .map(|u| InviteCodeUse {
256 used_by: u.did.clone(),
257 used_by_handle: Some(u.handle.clone()),
258 used_at: u.used_at.to_rfc3339(),
259 })
260 .collect()
261 })
262 .unwrap_or_default();
263
264 codes.push(InviteCode {
265 code: row.code,
266 available: row.available_uses,
267 disabled,
268 for_account: row.for_account,
269 created_by: "admin".to_string(),
270 created_at: row.created_at.to_rfc3339(),
271 uses,
272 });
273 }
274
275 Json(GetAccountInviteCodesOutput { codes }).into_response()
276}