this repo has no description
1use crate::api::ApiError;
2use crate::auth::BearerAuth;
3use crate::auth::extractor::BearerAuthAdmin;
4use crate::state::AppState;
5use axum::{
6 Json,
7 extract::State,
8 response::{IntoResponse, Response},
9};
10use rand::Rng;
11use serde::{Deserialize, Serialize};
12use tracing::error;
13
14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567";
15
16fn gen_random_token() -> String {
17 let mut rng = rand::thread_rng();
18 let mut token = String::with_capacity(11);
19 for i in 0..10 {
20 if i == 5 {
21 token.push('-');
22 }
23 let idx = rng.gen_range(0..32);
24 token.push(BASE32_ALPHABET[idx] as char);
25 }
26 token
27}
28
29fn gen_invite_code() -> String {
30 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
31 let hostname_prefix = hostname.replace('.', "-");
32 format!("{}-{}", hostname_prefix, gen_random_token())
33}
34
35#[derive(Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct CreateInviteCodeInput {
38 pub use_count: i32,
39 pub for_account: Option<String>,
40}
41
42#[derive(Serialize)]
43pub struct CreateInviteCodeOutput {
44 pub code: String,
45}
46
47pub async fn create_invite_code(
48 State(state): State<AppState>,
49 BearerAuthAdmin(auth_user): BearerAuthAdmin,
50 Json(input): Json<CreateInviteCodeInput>,
51) -> Response {
52 if input.use_count < 1 {
53 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
54 }
55
56 let for_account = input
57 .for_account
58 .unwrap_or_else(|| auth_user.did.to_string());
59 let code = gen_invite_code();
60
61 match sqlx::query!(
62 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account)
63 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1",
64 code,
65 input.use_count,
66 for_account
67 )
68 .execute(&state.db)
69 .await
70 {
71 Ok(result) => {
72 if result.rows_affected() == 0 {
73 error!("No admin user found to create invite code");
74 return ApiError::InternalError(None).into_response();
75 }
76 Json(CreateInviteCodeOutput { code }).into_response()
77 }
78 Err(e) => {
79 error!("DB error creating invite code: {:?}", e);
80 ApiError::InternalError(None).into_response()
81 }
82 }
83}
84
85#[derive(Deserialize)]
86#[serde(rename_all = "camelCase")]
87pub struct CreateInviteCodesInput {
88 pub code_count: Option<i32>,
89 pub use_count: i32,
90 pub for_accounts: Option<Vec<String>>,
91}
92
93#[derive(Serialize)]
94pub struct CreateInviteCodesOutput {
95 pub codes: Vec<AccountCodes>,
96}
97
98#[derive(Serialize)]
99pub struct AccountCodes {
100 pub account: String,
101 pub codes: Vec<String>,
102}
103
104pub async fn create_invite_codes(
105 State(state): State<AppState>,
106 BearerAuthAdmin(auth_user): BearerAuthAdmin,
107 Json(input): Json<CreateInviteCodesInput>,
108) -> Response {
109 if input.use_count < 1 {
110 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
111 }
112
113 let code_count = input.code_count.unwrap_or(1).max(1);
114 let for_accounts = input
115 .for_accounts
116 .filter(|v| !v.is_empty())
117 .unwrap_or_else(|| vec![auth_user.did.to_string()]);
118
119 let admin_user_id =
120 match sqlx::query_scalar!("SELECT id FROM users WHERE is_admin = true LIMIT 1")
121 .fetch_optional(&state.db)
122 .await
123 {
124 Ok(Some(id)) => id,
125 Ok(None) => {
126 error!("No admin user found to create invite codes");
127 return ApiError::InternalError(None).into_response();
128 }
129 Err(e) => {
130 error!("DB error looking up admin user: {:?}", e);
131 return ApiError::InternalError(None).into_response();
132 }
133 };
134
135 let mut result_codes = Vec::new();
136
137 for account in for_accounts {
138 let mut codes = Vec::new();
139 for _ in 0..code_count {
140 let code = gen_invite_code();
141 if let Err(e) = sqlx::query!(
142 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)",
143 code,
144 input.use_count,
145 admin_user_id,
146 account
147 )
148 .execute(&state.db)
149 .await
150 {
151 error!("DB error creating invite code: {:?}", e);
152 return ApiError::InternalError(None).into_response();
153 }
154 codes.push(code);
155 }
156 result_codes.push(AccountCodes { account, codes });
157 }
158
159 Json(CreateInviteCodesOutput {
160 codes: result_codes,
161 })
162 .into_response()
163}
164
165#[derive(Deserialize)]
166#[serde(rename_all = "camelCase")]
167pub struct GetAccountInviteCodesParams {
168 pub include_used: Option<bool>,
169 pub create_available: Option<bool>,
170}
171
172#[derive(Serialize)]
173#[serde(rename_all = "camelCase")]
174pub struct InviteCode {
175 pub code: String,
176 pub available: i32,
177 pub disabled: bool,
178 pub for_account: String,
179 pub created_by: String,
180 pub created_at: String,
181 pub uses: Vec<InviteCodeUse>,
182}
183
184#[derive(Serialize)]
185#[serde(rename_all = "camelCase")]
186pub struct InviteCodeUse {
187 pub used_by: String,
188 #[serde(skip_serializing_if = "Option::is_none")]
189 pub used_by_handle: Option<String>,
190 pub used_at: String,
191}
192
193#[derive(Serialize)]
194pub struct GetAccountInviteCodesOutput {
195 pub codes: Vec<InviteCode>,
196}
197
198pub async fn get_account_invite_codes(
199 State(state): State<AppState>,
200 BearerAuth(auth_user): BearerAuth,
201 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
202) -> Response {
203 let include_used = params.include_used.unwrap_or(true);
204
205 let codes_rows = match sqlx::query!(
206 r#"
207 SELECT
208 ic.code,
209 ic.available_uses,
210 ic.created_at,
211 ic.disabled,
212 ic.for_account,
213 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!"
214 FROM invite_codes ic
215 WHERE ic.for_account = $1
216 ORDER BY ic.created_at DESC
217 "#,
218 &auth_user.did
219 )
220 .fetch_all(&state.db)
221 .await
222 {
223 Ok(rows) => rows,
224 Err(e) => {
225 error!("DB error fetching invite codes: {:?}", e);
226 return ApiError::InternalError(None).into_response();
227 }
228 };
229
230 let mut codes = Vec::new();
231 for row in codes_rows {
232 let disabled = row.disabled.unwrap_or(false);
233 if disabled {
234 continue;
235 }
236
237 let use_count = row.use_count;
238 if !include_used && use_count >= row.available_uses {
239 continue;
240 }
241
242 let uses = sqlx::query!(
243 r#"
244 SELECT u.did, u.handle, icu.used_at
245 FROM invite_code_uses icu
246 JOIN users u ON icu.used_by_user = u.id
247 WHERE icu.code = $1
248 ORDER BY icu.used_at DESC
249 "#,
250 row.code
251 )
252 .fetch_all(&state.db)
253 .await
254 .map(|use_rows| {
255 use_rows
256 .iter()
257 .map(|u| InviteCodeUse {
258 used_by: u.did.clone(),
259 used_by_handle: Some(u.handle.clone()),
260 used_at: u.used_at.to_rfc3339(),
261 })
262 .collect()
263 })
264 .unwrap_or_default();
265
266 codes.push(InviteCode {
267 code: row.code,
268 available: row.available_uses,
269 disabled,
270 for_account: row.for_account,
271 created_by: "admin".to_string(),
272 created_at: row.created_at.to_rfc3339(),
273 uses,
274 });
275 }
276
277 Json(GetAccountInviteCodesOutput { codes }).into_response()
278}