···11+ALTER TABLE invite_codes ADD COLUMN IF NOT EXISTS for_account TEXT NOT NULL DEFAULT 'admin';
22+CREATE INDEX IF NOT EXISTS idx_invite_codes_for_account ON invite_codes(for_account);
+99-112
src/api/server/invite.rs
···11use crate::api::ApiError;
22+use crate::auth::extractor::BearerAuthAdmin;
23use crate::auth::BearerAuth;
34use crate::state::AppState;
44-use crate::util::get_user_id_by_did;
55use axum::{
66 Json,
77 extract::State,
88 response::{IntoResponse, Response},
99};
1010+use rand::Rng;
1011use serde::{Deserialize, Serialize};
1112use tracing::error;
1212-use uuid::Uuid;
1313+1414+const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567";
1515+1616+fn gen_random_token() -> String {
1717+ let mut rng = rand::thread_rng();
1818+ let mut token = String::with_capacity(11);
1919+ for i in 0..10 {
2020+ if i == 5 {
2121+ token.push('-');
2222+ }
2323+ let idx = rng.gen_range(0..32);
2424+ token.push(BASE32_ALPHABET[idx] as char);
2525+ }
2626+ token
2727+}
2828+2929+fn gen_invite_code() -> String {
3030+ let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
3131+ let hostname_prefix = hostname.replace('.', "-");
3232+ format!("{}-{}", hostname_prefix, gen_random_token())
3333+}
13341435#[derive(Deserialize)]
1536#[serde(rename_all = "camelCase")]
···25462647pub async fn create_invite_code(
2748 State(state): State<AppState>,
2828- BearerAuth(auth_user): BearerAuth,
4949+ BearerAuthAdmin(_auth_user): BearerAuthAdmin,
2950 Json(input): Json<CreateInviteCodeInput>,
3051) -> Response {
3152 if input.use_count < 1 {
3253 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
3354 }
3434- let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
3535- Ok(id) => id,
3636- Err(e) => return ApiError::from(e).into_response(),
3737- };
3838- let creator_user_id = if let Some(for_account) = &input.for_account {
3939- match sqlx::query!("SELECT id FROM users WHERE did = $1", for_account)
4040- .fetch_optional(&state.db)
4141- .await
4242- {
4343- Ok(Some(row)) => row.id,
4444- Ok(None) => return ApiError::AccountNotFound.into_response(),
4545- Err(e) => {
4646- error!("DB error looking up target account: {:?}", e);
4747- return ApiError::InternalError.into_response();
4848- }
4949- }
5050- } else {
5151- user_id
5252- };
5353- let user_invites_disabled = sqlx::query_scalar!(
5454- "SELECT invites_disabled FROM users WHERE did = $1",
5555- auth_user.did
5656- )
5757- .fetch_optional(&state.db)
5858- .await
5959- .map_err(|e| {
6060- error!("DB error checking invites_disabled: {:?}", e);
6161- ApiError::InternalError
6262- })
6363- .ok()
6464- .flatten()
6565- .flatten()
6666- .unwrap_or(false);
6767- if user_invites_disabled {
6868- return ApiError::InvitesDisabled.into_response();
6969- }
7070- let code = Uuid::new_v4().to_string();
5555+5656+ let for_account = input.for_account.unwrap_or_else(|| "admin".to_string());
5757+ let code = gen_invite_code();
5858+7159 match sqlx::query!(
7272- "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
6060+ "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account)
6161+ SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1",
7362 code,
7463 input.use_count,
7575- creator_user_id
6464+ for_account
7665 )
7766 .execute(&state.db)
7867 .await
7968 {
8080- Ok(_) => Json(CreateInviteCodeOutput { code }).into_response(),
6969+ Ok(result) => {
7070+ if result.rows_affected() == 0 {
7171+ error!("No admin user found to create invite code");
7272+ return ApiError::InternalError.into_response();
7373+ }
7474+ Json(CreateInviteCodeOutput { code }).into_response()
7575+ }
8176 Err(e) => {
8277 error!("DB error creating invite code: {:?}", e);
8378 ApiError::InternalError.into_response()
···106101107102pub async fn create_invite_codes(
108103 State(state): State<AppState>,
109109- BearerAuth(auth_user): BearerAuth,
104104+ BearerAuthAdmin(_auth_user): BearerAuthAdmin,
110105 Json(input): Json<CreateInviteCodesInput>,
111106) -> Response {
112107 if input.use_count < 1 {
113108 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response();
114109 }
115115- let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
116116- Ok(id) => id,
117117- Err(e) => return ApiError::from(e).into_response(),
110110+111111+ let code_count = input.code_count.unwrap_or(1).max(1);
112112+ let for_accounts = input
113113+ .for_accounts
114114+ .filter(|v| !v.is_empty())
115115+ .unwrap_or_else(|| vec!["admin".to_string()]);
116116+117117+ let admin_user_id = match sqlx::query_scalar!(
118118+ "SELECT id FROM users WHERE is_admin = true LIMIT 1"
119119+ )
120120+ .fetch_optional(&state.db)
121121+ .await
122122+ {
123123+ Ok(Some(id)) => id,
124124+ Ok(None) => {
125125+ error!("No admin user found to create invite codes");
126126+ return ApiError::InternalError.into_response();
127127+ }
128128+ Err(e) => {
129129+ error!("DB error looking up admin user: {:?}", e);
130130+ return ApiError::InternalError.into_response();
131131+ }
118132 };
119119- let code_count = input.code_count.unwrap_or(1).max(1);
120120- let for_accounts = input.for_accounts.unwrap_or_default();
133133+121134 let mut result_codes = Vec::new();
122122- if for_accounts.is_empty() {
135135+136136+ for account in for_accounts {
123137 let mut codes = Vec::new();
124138 for _ in 0..code_count {
125125- let code = Uuid::new_v4().to_string();
139139+ let code = gen_invite_code();
126140 if let Err(e) = sqlx::query!(
127127- "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
141141+ "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)",
128142 code,
129143 input.use_count,
130130- user_id
144144+ admin_user_id,
145145+ account
131146 )
132147 .execute(&state.db)
133148 .await
···137152 }
138153 codes.push(code);
139154 }
140140- result_codes.push(AccountCodes {
141141- account: "admin".to_string(),
142142- codes,
143143- });
144144- } else {
145145- for account_did in for_accounts {
146146- let target_user_id =
147147- match sqlx::query!("SELECT id FROM users WHERE did = $1", account_did)
148148- .fetch_optional(&state.db)
149149- .await
150150- {
151151- Ok(Some(row)) => row.id,
152152- Ok(None) => continue,
153153- Err(e) => {
154154- error!("DB error looking up target account: {:?}", e);
155155- return ApiError::InternalError.into_response();
156156- }
157157- };
158158- let mut codes = Vec::new();
159159- for _ in 0..code_count {
160160- let code = Uuid::new_v4().to_string();
161161- if let Err(e) = sqlx::query!(
162162- "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)",
163163- code,
164164- input.use_count,
165165- target_user_id
166166- )
167167- .execute(&state.db)
168168- .await
169169- {
170170- error!("DB error creating invite code: {:?}", e);
171171- return ApiError::InternalError.into_response();
172172- }
173173- codes.push(code);
174174- }
175175- result_codes.push(AccountCodes {
176176- account: account_did,
177177- codes,
178178- });
179179- }
155155+ result_codes.push(AccountCodes { account, codes });
180156 }
157157+181158 Json(CreateInviteCodesOutput {
182159 codes: result_codes,
183160 })
···220197 BearerAuth(auth_user): BearerAuth,
221198 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>,
222199) -> Response {
223223- let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await {
224224- Ok(id) => id,
225225- Err(e) => return ApiError::from(e).into_response(),
226226- };
227200 let include_used = params.include_used.unwrap_or(true);
201201+228202 let codes_rows = match sqlx::query!(
229203 r#"
230230- SELECT code, available_uses, created_at, disabled
231231- FROM invite_codes
232232- WHERE created_by_user = $1
233233- ORDER BY created_at DESC
204204+ SELECT
205205+ ic.code,
206206+ ic.available_uses,
207207+ ic.created_at,
208208+ ic.disabled,
209209+ ic.for_account,
210210+ (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!"
211211+ FROM invite_codes ic
212212+ WHERE ic.for_account = $1
213213+ ORDER BY ic.created_at DESC
234214 "#,
235235- user_id
215215+ auth_user.did
236216 )
237217 .fetch_all(&state.db)
238218 .await
239219 {
240240- Ok(rows) => {
241241- if include_used {
242242- rows
243243- } else {
244244- rows.into_iter().filter(|r| r.available_uses > 0).collect()
245245- }
246246- }
220220+ Ok(rows) => rows,
247221 Err(e) => {
248222 error!("DB error fetching invite codes: {:?}", e);
249223 return ApiError::InternalError.into_response();
250224 }
251225 };
226226+252227 let mut codes = Vec::new();
253228 for row in codes_rows {
229229+ let disabled = row.disabled.unwrap_or(false);
230230+ if disabled {
231231+ continue;
232232+ }
233233+234234+ let use_count = row.use_count;
235235+ if !include_used && use_count >= row.available_uses {
236236+ continue;
237237+ }
238238+254239 let uses = sqlx::query!(
255240 r#"
256241 SELECT u.did, icu.used_at
···273258 .collect()
274259 })
275260 .unwrap_or_default();
261261+276262 codes.push(InviteCode {
277263 code: row.code,
278264 available: row.available_uses,
279279- disabled: row.disabled.unwrap_or(false),
280280- for_account: auth_user.did.clone(),
281281- created_by: auth_user.did.clone(),
265265+ disabled,
266266+ for_account: row.for_account,
267267+ created_by: "admin".to_string(),
282268 created_at: row.created_at.to_rfc3339(),
283269 uses,
284270 });
285271 }
272272+286273 Json(GetAccountInviteCodesOutput { codes }).into_response()
287274}
+11-88
tests/admin_invite.rs
···8484}
85858686#[tokio::test]
8787-async fn test_disable_account_invites_success() {
8888- let client = client();
8989- let (access_jwt, did) = create_admin_account_and_login(&client).await;
9090- let payload = json!({
9191- "account": did
9292- });
9393- let res = client
9494- .post(format!(
9595- "{}/xrpc/com.atproto.admin.disableAccountInvites",
9696- base_url().await
9797- ))
9898- .bearer_auth(&access_jwt)
9999- .json(&payload)
100100- .send()
101101- .await
102102- .expect("Failed to send request");
103103- assert_eq!(res.status(), StatusCode::OK);
104104- let create_payload = json!({
105105- "useCount": 1
106106- });
107107- let res = client
108108- .post(format!(
109109- "{}/xrpc/com.atproto.server.createInviteCode",
110110- base_url().await
111111- ))
112112- .bearer_auth(&access_jwt)
113113- .json(&create_payload)
114114- .send()
115115- .await
116116- .expect("Failed to send request");
117117- assert_eq!(res.status(), StatusCode::FORBIDDEN);
118118- let body: Value = res.json().await.expect("Response was not valid JSON");
119119- assert_eq!(body["error"], "InvitesDisabled");
120120-}
121121-122122-#[tokio::test]
123123-async fn test_enable_account_invites_success() {
124124- let client = client();
125125- let (access_jwt, did) = create_admin_account_and_login(&client).await;
126126- let disable_payload = json!({
127127- "account": did
128128- });
129129- let _ = client
130130- .post(format!(
131131- "{}/xrpc/com.atproto.admin.disableAccountInvites",
132132- base_url().await
133133- ))
134134- .bearer_auth(&access_jwt)
135135- .json(&disable_payload)
136136- .send()
137137- .await;
138138- let enable_payload = json!({
139139- "account": did
140140- });
141141- let res = client
142142- .post(format!(
143143- "{}/xrpc/com.atproto.admin.enableAccountInvites",
144144- base_url().await
145145- ))
146146- .bearer_auth(&access_jwt)
147147- .json(&enable_payload)
148148- .send()
149149- .await
150150- .expect("Failed to send request");
151151- assert_eq!(res.status(), StatusCode::OK);
152152- let create_payload = json!({
153153- "useCount": 1
154154- });
155155- let res = client
156156- .post(format!(
157157- "{}/xrpc/com.atproto.server.createInviteCode",
158158- base_url().await
159159- ))
160160- .bearer_auth(&access_jwt)
161161- .json(&create_payload)
162162- .send()
163163- .await
164164- .expect("Failed to send request");
165165- assert_eq!(res.status(), StatusCode::OK);
166166-}
167167-168168-#[tokio::test]
16987async fn test_disable_account_invites_no_auth() {
17088 let client = client();
17189 let payload = json!({
···206124#[tokio::test]
207125async fn test_disable_invite_codes_by_code() {
208126 let client = client();
209209- let (access_jwt, _did) = create_admin_account_and_login(&client).await;
127127+ let (access_jwt, admin_did) = create_admin_account_and_login(&client).await;
210128 let create_payload = json!({
211211- "useCount": 5
129129+ "useCount": 5,
130130+ "forAccount": admin_did
212131 });
213132 let create_res = client
214133 .post(format!(
···236155 .await
237156 .expect("Failed to send request");
238157 assert_eq!(res.status(), StatusCode::OK);
158158+239159 let list_res = client
240160 .get(format!(
241241- "{}/xrpc/com.atproto.server.getAccountInviteCodes",
161161+ "{}/xrpc/com.atproto.admin.getInviteCodes",
242162 base_url().await
243163 ))
244164 .bearer_auth(&access_jwt)
···258178 let (access_jwt, did) = create_admin_account_and_login(&client).await;
259179 for _ in 0..3 {
260180 let create_payload = json!({
261261- "useCount": 1
181181+ "useCount": 1,
182182+ "forAccount": did
262183 });
263184 let _ = client
264185 .post(format!(
···284205 .await
285206 .expect("Failed to send request");
286207 assert_eq!(res.status(), StatusCode::OK);
208208+287209 let list_res = client
288210 .get(format!(
289289- "{}/xrpc/com.atproto.server.getAccountInviteCodes",
211211+ "{}/xrpc/com.atproto.admin.getInviteCodes",
290212 base_url().await
291213 ))
292214 .bearer_auth(&access_jwt)
···295217 .expect("Failed to get invite codes");
296218 let list_body: Value = list_res.json().await.unwrap();
297219 let codes = list_body["codes"].as_array().unwrap();
298298- for code in codes {
220220+ let admin_codes: Vec<_> = codes.iter().filter(|c| c["forAccount"].as_str() == Some(&did)).collect();
221221+ for code in admin_codes {
299222 assert_eq!(code["disabled"], true);
300223 }
301224}
+124-14
tests/invite.rs
···66#[tokio::test]
77async fn test_create_invite_code_success() {
88 let client = client();
99- let (access_jwt, _did) = create_account_and_login(&client).await;
99+ let (access_jwt, _did) = create_admin_account_and_login(&client).await;
1010 let payload = json!({
1111 "useCount": 5
1212 });
···2525 assert!(body["code"].is_string());
2626 let code = body["code"].as_str().unwrap();
2727 assert!(!code.is_empty());
2828- assert!(code.contains('-'), "Code should be a UUID format");
2828+ assert!(code.contains('-'), "Code should be in hostname-xxxxx-xxxxx format");
2929+ let parts: Vec<&str> = code.split('-').collect();
3030+ assert!(parts.len() >= 3, "Code should have at least 3 parts (hostname + 2 random parts)");
2931}
30323133#[tokio::test]
···4951}
50525153#[tokio::test]
5454+async fn test_create_invite_code_non_admin() {
5555+ let client = client();
5656+ let (access_jwt, _did) = create_account_and_login(&client).await;
5757+ let payload = json!({
5858+ "useCount": 5
5959+ });
6060+ let res = client
6161+ .post(format!(
6262+ "{}/xrpc/com.atproto.server.createInviteCode",
6363+ base_url().await
6464+ ))
6565+ .bearer_auth(&access_jwt)
6666+ .json(&payload)
6767+ .send()
6868+ .await
6969+ .expect("Failed to send request");
7070+ assert_eq!(res.status(), StatusCode::FORBIDDEN);
7171+ let body: Value = res.json().await.expect("Response was not valid JSON");
7272+ assert_eq!(body["error"], "AdminRequired");
7373+}
7474+7575+#[tokio::test]
5276async fn test_create_invite_code_invalid_use_count() {
5377 let client = client();
5454- let (access_jwt, _did) = create_account_and_login(&client).await;
7878+ let (access_jwt, _did) = create_admin_account_and_login(&client).await;
5579 let payload = json!({
5680 "useCount": 0
5781 });
···7397#[tokio::test]
7498async fn test_create_invite_code_for_another_account() {
7599 let client = client();
7676- let (access_jwt1, _did1) = create_account_and_login(&client).await;
100100+ let (access_jwt1, _did1) = create_admin_account_and_login(&client).await;
77101 let (_access_jwt2, did2) = create_account_and_login(&client).await;
78102 let payload = json!({
79103 "useCount": 3,
···97121#[tokio::test]
98122async fn test_create_invite_codes_success() {
99123 let client = client();
100100- let (access_jwt, _did) = create_account_and_login(&client).await;
124124+ let (access_jwt, _did) = create_admin_account_and_login(&client).await;
101125 let payload = json!({
102126 "useCount": 2,
103127 "codeCount": 3
···117141 assert!(body["codes"].is_array());
118142 let codes = body["codes"].as_array().unwrap();
119143 assert_eq!(codes.len(), 1);
144144+ assert_eq!(codes[0]["account"], "admin");
120145 assert_eq!(codes[0]["codes"].as_array().unwrap().len(), 3);
121146}
122147123148#[tokio::test]
124149async fn test_create_invite_codes_for_multiple_accounts() {
125150 let client = client();
126126- let (access_jwt1, did1) = create_account_and_login(&client).await;
151151+ let (access_jwt1, did1) = create_admin_account_and_login(&client).await;
127152 let (_access_jwt2, did2) = create_account_and_login(&client).await;
128153 let payload = json!({
129154 "useCount": 1,
···169194}
170195171196#[tokio::test]
172172-async fn test_get_account_invite_codes_success() {
197197+async fn test_create_invite_codes_non_admin() {
173198 let client = client();
174199 let (access_jwt, _did) = create_account_and_login(&client).await;
200200+ let payload = json!({
201201+ "useCount": 2
202202+ });
203203+ let res = client
204204+ .post(format!(
205205+ "{}/xrpc/com.atproto.server.createInviteCodes",
206206+ base_url().await
207207+ ))
208208+ .bearer_auth(&access_jwt)
209209+ .json(&payload)
210210+ .send()
211211+ .await
212212+ .expect("Failed to send request");
213213+ assert_eq!(res.status(), StatusCode::FORBIDDEN);
214214+ let body: Value = res.json().await.expect("Response was not valid JSON");
215215+ assert_eq!(body["error"], "AdminRequired");
216216+}
217217+218218+#[tokio::test]
219219+async fn test_get_account_invite_codes_success() {
220220+ let client = client();
221221+ let (admin_jwt, _admin_did) = create_admin_account_and_login(&client).await;
222222+ let (user_jwt, user_did) = create_account_and_login(&client).await;
223223+175224 let create_payload = json!({
176176- "useCount": 5
225225+ "useCount": 5,
226226+ "forAccount": user_did
177227 });
178228 let _ = client
179229 .post(format!(
180230 "{}/xrpc/com.atproto.server.createInviteCode",
181231 base_url().await
182232 ))
183183- .bearer_auth(&access_jwt)
233233+ .bearer_auth(&admin_jwt)
184234 .json(&create_payload)
185235 .send()
186236 .await
187237 .expect("Failed to create invite code");
238238+188239 let res = client
189240 .get(format!(
190241 "{}/xrpc/com.atproto.server.getAccountInviteCodes",
191242 base_url().await
192243 ))
193193- .bearer_auth(&access_jwt)
244244+ .bearer_auth(&user_jwt)
194245 .send()
195246 .await
196247 .expect("Failed to send request");
···205256 assert!(code["disabled"].is_boolean());
206257 assert!(code["createdAt"].is_string());
207258 assert!(code["uses"].is_array());
259259+ assert_eq!(code["forAccount"], user_did);
260260+ assert_eq!(code["createdBy"], "admin");
208261}
209262210263#[tokio::test]
···224277#[tokio::test]
225278async fn test_get_account_invite_codes_include_used_filter() {
226279 let client = client();
227227- let (access_jwt, _did) = create_account_and_login(&client).await;
280280+ let (admin_jwt, _admin_did) = create_admin_account_and_login(&client).await;
281281+ let (user_jwt, user_did) = create_account_and_login(&client).await;
282282+228283 let create_payload = json!({
229229- "useCount": 5
284284+ "useCount": 5,
285285+ "forAccount": user_did
230286 });
231287 let _ = client
232288 .post(format!(
233289 "{}/xrpc/com.atproto.server.createInviteCode",
234290 base_url().await
235291 ))
236236- .bearer_auth(&access_jwt)
292292+ .bearer_auth(&admin_jwt)
237293 .json(&create_payload)
238294 .send()
239295 .await
240296 .expect("Failed to create invite code");
297297+241298 let res = client
242299 .get(format!(
243300 "{}/xrpc/com.atproto.server.getAccountInviteCodes",
244301 base_url().await
245302 ))
246246- .bearer_auth(&access_jwt)
303303+ .bearer_auth(&user_jwt)
247304 .query(&[("includeUsed", "false")])
248305 .send()
249306 .await
···255312 assert!(code["available"].as_i64().unwrap() > 0);
256313 }
257314}
315315+316316+#[tokio::test]
317317+async fn test_get_account_invite_codes_filters_disabled() {
318318+ let client = client();
319319+ let (admin_jwt, admin_did) = create_admin_account_and_login(&client).await;
320320+321321+ let create_payload = json!({
322322+ "useCount": 5,
323323+ "forAccount": admin_did
324324+ });
325325+ let create_res = client
326326+ .post(format!(
327327+ "{}/xrpc/com.atproto.server.createInviteCode",
328328+ base_url().await
329329+ ))
330330+ .bearer_auth(&admin_jwt)
331331+ .json(&create_payload)
332332+ .send()
333333+ .await
334334+ .expect("Failed to create invite code");
335335+ let create_body: Value = create_res.json().await.unwrap();
336336+ let code = create_body["code"].as_str().unwrap();
337337+338338+ let disable_payload = json!({
339339+ "codes": [code]
340340+ });
341341+ let _ = client
342342+ .post(format!(
343343+ "{}/xrpc/com.atproto.admin.disableInviteCodes",
344344+ base_url().await
345345+ ))
346346+ .bearer_auth(&admin_jwt)
347347+ .json(&disable_payload)
348348+ .send()
349349+ .await
350350+ .expect("Failed to disable invite code");
351351+352352+ let res = client
353353+ .get(format!(
354354+ "{}/xrpc/com.atproto.server.getAccountInviteCodes",
355355+ base_url().await
356356+ ))
357357+ .bearer_auth(&admin_jwt)
358358+ .send()
359359+ .await
360360+ .expect("Failed to send request");
361361+ assert_eq!(res.status(), StatusCode::OK);
362362+ let body: Value = res.json().await.expect("Response was not valid JSON");
363363+ let codes = body["codes"].as_array().unwrap();
364364+ for c in codes {
365365+ assert_ne!(c["code"].as_str().unwrap(), code, "Disabled code should be filtered out");
366366+ }
367367+}