// this repo has no description
1use crate::api::ApiError;
2use crate::auth::BearerAuth;
3use crate::auth::extractor::BearerAuthAdmin;
4use crate::state::AppState;
5use axum::{
6 Json,
7 extract::State,
8 response::{IntoResponse, Response},
9};
10use rand::Rng;
11use serde::{Deserialize, Serialize};
12use tracing::error;
13
14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567";
15
16fn gen_random_token() -> String {
17 let mut rng = rand::thread_rng();
18 let gen_segment = |rng: &mut rand::rngs::ThreadRng, len: usize| -> String {
19 (0..len)
20 .map(|_| BASE32_ALPHABET[rng.gen_range(0..32)] as char)
21 .collect()
22 };
23 format!("{}-{}", gen_segment(&mut rng, 5), gen_segment(&mut rng, 5))
24}
25
26fn gen_invite_code() -> String {
27 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
28 let hostname_prefix = hostname.replace('.', "-");
29 format!("{}-{}", hostname_prefix, gen_random_token())
30}
31
32#[derive(Deserialize)]
33#[serde(rename_all = "camelCase")]
34pub struct CreateInviteCodeInput {
35 pub use_count: i32,
36 pub for_account: Option<String>,
37}
38
39#[derive(Serialize)]
40pub struct CreateInviteCodeOutput {
41 pub code: String,
42}
43
44pub async fn create_invite_code(
45 State(state): State<AppState>,
46 BearerAuthAdmin(auth_user): BearerAuthAdmin,
47 Json(input): Json<CreateInviteCodeInput>,
48) -> Response {
49 if input.use_count < 1 {
50 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
51 }
52
53 let for_account = input
54 .for_account
55 .unwrap_or_else(|| auth_user.did.to_string());
56 let code = gen_invite_code();
57
58 match sqlx::query!(
59 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account)
60 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1",
61 code,
62 input.use_count,
63 for_account
64 )
65 .execute(&state.db)
66 .await
67 {
68 Ok(result) => {
69 if result.rows_affected() == 0 {
70 error!("No admin user found to create invite code");
71 return ApiError::InternalError(None).into_response();
72 }
73 Json(CreateInviteCodeOutput { code }).into_response()
74 }
75 Err(e) => {
76 error!("DB error creating invite code: {:?}", e);
77 ApiError::InternalError(None).into_response()
78 }
79 }
80}
81
82#[derive(Deserialize)]
83#[serde(rename_all = "camelCase")]
84pub struct CreateInviteCodesInput {
85 pub code_count: Option<i32>,
86 pub use_count: i32,
87 pub for_accounts: Option<Vec<String>>,
88}
89
90#[derive(Serialize)]
91pub struct CreateInviteCodesOutput {
92 pub codes: Vec<AccountCodes>,
93}
94
95#[derive(Serialize)]
96pub struct AccountCodes {
97 pub account: String,
98 pub codes: Vec<String>,
99}
100
101pub async fn create_invite_codes(
102 State(state): State<AppState>,
103 BearerAuthAdmin(auth_user): BearerAuthAdmin,
104 Json(input): Json<CreateInviteCodesInput>,
105) -> Response {
106 if input.use_count < 1 {
107 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
108 }
109
110 let code_count = input.code_count.unwrap_or(1).max(1);
111 let for_accounts = input
112 .for_accounts
113 .filter(|v| !v.is_empty())
114 .unwrap_or_else(|| vec![auth_user.did.to_string()]);
115
116 let admin_user_id =
117 match sqlx::query_scalar!("SELECT id FROM users WHERE is_admin = true LIMIT 1")
118 .fetch_optional(&state.db)
119 .await
120 {
121 Ok(Some(id)) => id,
122 Ok(None) => {
123 error!("No admin user found to create invite codes");
124 return ApiError::InternalError(None).into_response();
125 }
126 Err(e) => {
127 error!("DB error looking up admin user: {:?}", e);
128 return ApiError::InternalError(None).into_response();
129 }
130 };
131
132 let result = futures::future::try_join_all(for_accounts.into_iter().map(|account| {
133 let db = state.db.clone();
134 let use_count = input.use_count;
135 async move {
136 let codes: Vec<String> = (0..code_count).map(|_| gen_invite_code()).collect();
137 sqlx::query!(
138 r#"
139 INSERT INTO invite_codes (code, available_uses, created_by_user, for_account)
140 SELECT code, $2, $3, $4 FROM UNNEST($1::text[]) AS t(code)
141 "#,
142 &codes[..],
143 use_count,
144 admin_user_id,
145 account
146 )
147 .execute(&db)
148 .await
149 .map(|_| AccountCodes { account, codes })
150 }
151 }))
152 .await;
153
154 match result {
155 Ok(result_codes) => Json(CreateInviteCodesOutput {
156 codes: result_codes,
157 })
158 .into_response(),
159 Err(e) => {
160 error!("DB error creating invite codes: {:?}", e);
161 ApiError::InternalError(None).into_response()
162 }
163 }
164}
165
166#[derive(Deserialize)]
167#[serde(rename_all = "camelCase")]
168pub struct GetAccountInviteCodesParams {
169 pub include_used: Option<bool>,
170 pub create_available: Option<bool>,
171}
172
173#[derive(Serialize)]
174#[serde(rename_all = "camelCase")]
175pub struct InviteCode {
176 pub code: String,
177 pub available: i32,
178 pub disabled: bool,
179 pub for_account: String,
180 pub created_by: String,
181 pub created_at: String,
182 pub uses: Vec<InviteCodeUse>,
183}
184
185#[derive(Serialize)]
186#[serde(rename_all = "camelCase")]
187pub struct InviteCodeUse {
188 pub used_by: String,
189 #[serde(skip_serializing_if = "Option::is_none")]
190 pub used_by_handle: Option<String>,
191 pub used_at: String,
192}
193
194#[derive(Serialize)]
195pub struct GetAccountInviteCodesOutput {
196 pub codes: Vec<InviteCode>,
197}
198
199pub async fn get_account_invite_codes(
200 State(state): State<AppState>,
201 BearerAuth(auth_user): BearerAuth,
202 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
203) -> Response {
204 let include_used = params.include_used.unwrap_or(true);
205
206 let codes_rows = match sqlx::query!(
207 r#"
208 SELECT
209 ic.code,
210 ic.available_uses,
211 ic.created_at,
212 ic.disabled,
213 ic.for_account,
214 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!"
215 FROM invite_codes ic
216 WHERE ic.for_account = $1
217 ORDER BY ic.created_at DESC
218 "#,
219 &auth_user.did
220 )
221 .fetch_all(&state.db)
222 .await
223 {
224 Ok(rows) => rows,
225 Err(e) => {
226 error!("DB error fetching invite codes: {:?}", e);
227 return ApiError::InternalError(None).into_response();
228 }
229 };
230
231 let filtered_rows: Vec<_> = codes_rows
232 .into_iter()
233 .filter(|row| {
234 let disabled = row.disabled.unwrap_or(false);
235 !disabled && (include_used || row.use_count < row.available_uses)
236 })
237 .collect();
238
239 let codes = futures::future::join_all(filtered_rows.into_iter().map(|row| {
240 let db = state.db.clone();
241 async move {
242 let uses = sqlx::query!(
243 r#"
244 SELECT u.did, u.handle, icu.used_at
245 FROM invite_code_uses icu
246 JOIN users u ON icu.used_by_user = u.id
247 WHERE icu.code = $1
248 ORDER BY icu.used_at DESC
249 "#,
250 row.code
251 )
252 .fetch_all(&db)
253 .await
254 .map(|use_rows| {
255 use_rows
256 .iter()
257 .map(|u| InviteCodeUse {
258 used_by: u.did.clone(),
259 used_by_handle: Some(u.handle.clone()),
260 used_at: u.used_at.to_rfc3339(),
261 })
262 .collect()
263 })
264 .unwrap_or_default();
265
266 InviteCode {
267 code: row.code,
268 available: row.available_uses,
269 disabled: false,
270 for_account: row.for_account,
271 created_by: "admin".to_string(),
272 created_at: row.created_at.to_rfc3339(),
273 uses,
274 }
275 }
276 }))
277 .await;
278
279 Json(GetAccountInviteCodesOutput { codes }).into_response()
280}