this repo has no description
1use crate::api::ApiError; 2use crate::auth::extractor::BearerAuthAdmin; 3use crate::auth::BearerAuth; 4use crate::state::AppState; 5use axum::{ 6 Json, 7 extract::State, 8 response::{IntoResponse, Response}, 9}; 10use rand::Rng; 11use serde::{Deserialize, Serialize}; 12use tracing::error; 13 14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567"; 15 16fn gen_random_token() -> String { 17 let mut rng = rand::thread_rng(); 18 let mut token = String::with_capacity(11); 19 for i in 0..10 { 20 if i == 5 { 21 token.push('-'); 22 } 23 let idx = rng.gen_range(0..32); 24 token.push(BASE32_ALPHABET[idx] as char); 25 } 26 token 27} 28 29fn gen_invite_code() -> String { 30 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 31 let hostname_prefix = hostname.replace('.', "-"); 32 format!("{}-{}", hostname_prefix, gen_random_token()) 33} 34 35#[derive(Deserialize)] 36#[serde(rename_all = "camelCase")] 37pub struct CreateInviteCodeInput { 38 pub use_count: i32, 39 pub for_account: Option<String>, 40} 41 42#[derive(Serialize)] 43pub struct CreateInviteCodeOutput { 44 pub code: String, 45} 46 47pub async fn create_invite_code( 48 State(state): State<AppState>, 49 BearerAuthAdmin(_auth_user): BearerAuthAdmin, 50 Json(input): Json<CreateInviteCodeInput>, 51) -> Response { 52 if input.use_count < 1 { 53 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 54 } 55 56 let for_account = input.for_account.unwrap_or_else(|| "admin".to_string()); 57 let code = gen_invite_code(); 58 59 match sqlx::query!( 60 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) 61 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1", 62 code, 63 input.use_count, 64 for_account 65 ) 66 .execute(&state.db) 67 .await 68 { 69 Ok(result) => { 70 if result.rows_affected() == 0 { 71 error!("No admin user found to create invite code"); 72 return ApiError::InternalError.into_response(); 73 } 74 Json(CreateInviteCodeOutput { code }).into_response() 75 } 76 Err(e) => { 77 error!("DB error creating invite code: {:?}", e); 78 ApiError::InternalError.into_response() 79 } 80 } 81} 82 83#[derive(Deserialize)] 84#[serde(rename_all = "camelCase")] 85pub struct CreateInviteCodesInput { 86 pub code_count: Option<i32>, 87 pub use_count: i32, 88 pub for_accounts: Option<Vec<String>>, 89} 90 91#[derive(Serialize)] 92pub struct CreateInviteCodesOutput { 93 pub codes: Vec<AccountCodes>, 94} 95 96#[derive(Serialize)] 97pub struct AccountCodes { 98 pub account: String, 99 pub codes: Vec<String>, 100} 101 102pub async fn create_invite_codes( 103 State(state): State<AppState>, 104 BearerAuthAdmin(_auth_user): BearerAuthAdmin, 105 Json(input): Json<CreateInviteCodesInput>, 106) -> Response { 107 if input.use_count < 1 { 108 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 109 } 110 111 let code_count = input.code_count.unwrap_or(1).max(1); 112 let for_accounts = input 113 .for_accounts 114 .filter(|v| !v.is_empty()) 115 .unwrap_or_else(|| vec!["admin".to_string()]); 116 117 let admin_user_id = match sqlx::query_scalar!( 118 "SELECT id FROM users WHERE is_admin = true LIMIT 1" 119 ) 120 .fetch_optional(&state.db) 121 .await 122 { 123 Ok(Some(id)) => id, 124 Ok(None) => { 125 error!("No admin user found to create invite codes"); 126 return ApiError::InternalError.into_response(); 127 } 128 Err(e) => { 129 error!("DB error looking up admin user: {:?}", e); 130 return ApiError::InternalError.into_response(); 131 } 132 }; 133 134 let mut result_codes = Vec::new(); 135 136 for account in for_accounts { 137 let mut codes = Vec::new(); 138 for _ in 0..code_count { 139 let code = gen_invite_code(); 140 if let Err(e) = sqlx::query!( 141 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)", 142 code, 143 input.use_count, 144 admin_user_id, 145 account 146 ) 147 .execute(&state.db) 148 .await 149 { 150 error!("DB error creating invite code: {:?}", e); 151 return ApiError::InternalError.into_response(); 152 } 153 codes.push(code); 154 } 155 result_codes.push(AccountCodes { account, codes }); 156 } 157 158 Json(CreateInviteCodesOutput { 159 codes: result_codes, 160 }) 161 .into_response() 162} 163 164#[derive(Deserialize)] 165#[serde(rename_all = "camelCase")] 166pub struct GetAccountInviteCodesParams { 167 pub include_used: Option<bool>, 168 pub create_available: Option<bool>, 169} 170 171#[derive(Serialize)] 172#[serde(rename_all = "camelCase")] 173pub struct InviteCode { 174 pub code: String, 175 pub available: i32, 176 pub disabled: bool, 177 pub for_account: String, 178 pub created_by: String, 179 pub created_at: String, 180 pub uses: Vec<InviteCodeUse>, 181} 182 183#[derive(Serialize)] 184#[serde(rename_all = "camelCase")] 185pub struct InviteCodeUse { 186 pub used_by: String, 187 pub used_at: String, 188} 189 190#[derive(Serialize)] 191pub struct GetAccountInviteCodesOutput { 192 pub codes: Vec<InviteCode>, 193} 194 195pub async fn get_account_invite_codes( 196 State(state): State<AppState>, 197 BearerAuth(auth_user): BearerAuth, 198 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 199) -> Response { 200 let include_used = params.include_used.unwrap_or(true); 201 202 let codes_rows = match sqlx::query!( 203 r#" 204 SELECT 205 ic.code, 206 ic.available_uses, 207 ic.created_at, 208 ic.disabled, 209 ic.for_account, 210 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!" 211 FROM invite_codes ic 212 WHERE ic.for_account = $1 213 ORDER BY ic.created_at DESC 214 "#, 215 auth_user.did 216 ) 217 .fetch_all(&state.db) 218 .await 219 { 220 Ok(rows) => rows, 221 Err(e) => { 222 error!("DB error fetching invite codes: {:?}", e); 223 return ApiError::InternalError.into_response(); 224 } 225 }; 226 227 let mut codes = Vec::new(); 228 for row in codes_rows { 229 let disabled = row.disabled.unwrap_or(false); 230 if disabled { 231 continue; 232 } 233 234 let use_count = row.use_count; 235 if !include_used && use_count >= row.available_uses { 236 continue; 237 } 238 239 let uses = sqlx::query!( 240 r#" 241 SELECT u.did, icu.used_at 242 FROM invite_code_uses icu 243 JOIN users u ON icu.used_by_user = u.id 244 WHERE icu.code = $1 245 ORDER BY icu.used_at DESC 246 "#, 247 row.code 248 ) 249 .fetch_all(&state.db) 250 .await 251 .map(|use_rows| { 252 use_rows 253 .iter() 254 .map(|u| InviteCodeUse { 255 used_by: u.did.clone(), 256 used_at: u.used_at.to_rfc3339(), 257 }) 258 .collect() 259 }) 260 .unwrap_or_default(); 261 262 codes.push(InviteCode { 263 code: row.code, 264 available: row.available_uses, 265 disabled, 266 for_account: row.for_account, 267 created_by: "admin".to_string(), 268 created_at: row.created_at.to_rfc3339(), 269 uses, 270 }); 271 } 272 273 Json(GetAccountInviteCodesOutput { codes }).into_response() 274}