this repo has no description
1use crate::api::ApiError; 2use crate::auth::BearerAuth; 3use crate::auth::extractor::BearerAuthAdmin; 4use crate::state::AppState; 5use axum::{ 6 Json, 7 extract::State, 8 response::{IntoResponse, Response}, 9}; 10use rand::Rng; 11use serde::{Deserialize, Serialize}; 12use tracing::error; 13 14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567"; 15 16fn gen_random_token() -> String { 17 let mut rng = rand::thread_rng(); 18 let mut token = String::with_capacity(11); 19 for i in 0..10 { 20 if i == 5 { 21 token.push('-'); 22 } 23 let idx = rng.gen_range(0..32); 24 token.push(BASE32_ALPHABET[idx] as char); 25 } 26 token 27} 28 29fn gen_invite_code() -> String { 30 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 31 let hostname_prefix = hostname.replace('.', "-"); 32 format!("{}-{}", hostname_prefix, gen_random_token()) 33} 34 35#[derive(Deserialize)] 36#[serde(rename_all = "camelCase")] 37pub struct CreateInviteCodeInput { 38 pub use_count: i32, 39 pub for_account: Option<String>, 40} 41 42#[derive(Serialize)] 43pub struct CreateInviteCodeOutput { 44 pub code: String, 45} 46 47pub async fn create_invite_code( 48 State(state): State<AppState>, 49 BearerAuthAdmin(auth_user): BearerAuthAdmin, 50 Json(input): Json<CreateInviteCodeInput>, 51) -> Response { 52 if input.use_count < 1 { 53 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 54 } 55 56 let for_account = input 57 .for_account 58 .unwrap_or_else(|| auth_user.did.to_string()); 59 let code = gen_invite_code(); 60 61 match sqlx::query!( 62 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) 63 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1", 64 code, 65 input.use_count, 66 for_account 67 ) 68 .execute(&state.db) 69 .await 70 { 71 Ok(result) => { 72 if result.rows_affected() == 0 { 73 error!("No admin user found to create invite code"); 74 return ApiError::InternalError(None).into_response(); 75 } 76 Json(CreateInviteCodeOutput { code }).into_response() 77 } 78 Err(e) => { 79 error!("DB error creating invite code: {:?}", e); 80 ApiError::InternalError(None).into_response() 81 } 82 } 83} 84 85#[derive(Deserialize)] 86#[serde(rename_all = "camelCase")] 87pub struct CreateInviteCodesInput { 88 pub code_count: Option<i32>, 89 pub use_count: i32, 90 pub for_accounts: Option<Vec<String>>, 91} 92 93#[derive(Serialize)] 94pub struct CreateInviteCodesOutput { 95 pub codes: Vec<AccountCodes>, 96} 97 98#[derive(Serialize)] 99pub struct AccountCodes { 100 pub account: String, 101 pub codes: Vec<String>, 102} 103 104pub async fn create_invite_codes( 105 State(state): State<AppState>, 106 BearerAuthAdmin(auth_user): BearerAuthAdmin, 107 Json(input): Json<CreateInviteCodesInput>, 108) -> Response { 109 if input.use_count < 1 { 110 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 111 } 112 113 let code_count = input.code_count.unwrap_or(1).max(1); 114 let for_accounts = input 115 .for_accounts 116 .filter(|v| !v.is_empty()) 117 .unwrap_or_else(|| vec![auth_user.did.to_string()]); 118 119 let admin_user_id = 120 match sqlx::query_scalar!("SELECT id FROM users WHERE is_admin = true LIMIT 1") 121 .fetch_optional(&state.db) 122 .await 123 { 124 Ok(Some(id)) => id, 125 Ok(None) => { 126 error!("No admin user found to create invite codes"); 127 return ApiError::InternalError(None).into_response(); 128 } 129 Err(e) => { 130 error!("DB error looking up admin user: {:?}", e); 131 return ApiError::InternalError(None).into_response(); 132 } 133 }; 134 135 let mut result_codes = Vec::new(); 136 137 for account in for_accounts { 138 let mut codes = Vec::new(); 139 for _ in 0..code_count { 140 let code = gen_invite_code(); 141 if let Err(e) = sqlx::query!( 142 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)", 143 code, 144 input.use_count, 145 admin_user_id, 146 account 147 ) 148 .execute(&state.db) 149 .await 150 { 151 error!("DB error creating invite code: {:?}", e); 152 return ApiError::InternalError(None).into_response(); 153 } 154 codes.push(code); 155 } 156 result_codes.push(AccountCodes { account, codes }); 157 } 158 159 Json(CreateInviteCodesOutput { 160 codes: result_codes, 161 }) 162 .into_response() 163} 164 165#[derive(Deserialize)] 166#[serde(rename_all = "camelCase")] 167pub struct GetAccountInviteCodesParams { 168 pub include_used: Option<bool>, 169 pub create_available: Option<bool>, 170} 171 172#[derive(Serialize)] 173#[serde(rename_all = "camelCase")] 174pub struct InviteCode { 175 pub code: String, 176 pub available: i32, 177 pub disabled: bool, 178 pub for_account: String, 179 pub created_by: String, 180 pub created_at: String, 181 pub uses: Vec<InviteCodeUse>, 182} 183 184#[derive(Serialize)] 185#[serde(rename_all = "camelCase")] 186pub struct InviteCodeUse { 187 pub used_by: String, 188 #[serde(skip_serializing_if = "Option::is_none")] 189 pub used_by_handle: Option<String>, 190 pub used_at: String, 191} 192 193#[derive(Serialize)] 194pub struct GetAccountInviteCodesOutput { 195 pub codes: Vec<InviteCode>, 196} 197 198pub async fn get_account_invite_codes( 199 State(state): State<AppState>, 200 BearerAuth(auth_user): BearerAuth, 201 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 202) -> Response { 203 let include_used = params.include_used.unwrap_or(true); 204 205 let codes_rows = match sqlx::query!( 206 r#" 207 SELECT 208 ic.code, 209 ic.available_uses, 210 ic.created_at, 211 ic.disabled, 212 ic.for_account, 213 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!" 214 FROM invite_codes ic 215 WHERE ic.for_account = $1 216 ORDER BY ic.created_at DESC 217 "#, 218 &auth_user.did 219 ) 220 .fetch_all(&state.db) 221 .await 222 { 223 Ok(rows) => rows, 224 Err(e) => { 225 error!("DB error fetching invite codes: {:?}", e); 226 return ApiError::InternalError(None).into_response(); 227 } 228 }; 229 230 let mut codes = Vec::new(); 231 for row in codes_rows { 232 let disabled = row.disabled.unwrap_or(false); 233 if disabled { 234 continue; 235 } 236 237 let use_count = row.use_count; 238 if !include_used && use_count >= row.available_uses { 239 continue; 240 } 241 242 let uses = sqlx::query!( 243 r#" 244 SELECT u.did, u.handle, icu.used_at 245 FROM invite_code_uses icu 246 JOIN users u ON icu.used_by_user = u.id 247 WHERE icu.code = $1 248 ORDER BY icu.used_at DESC 249 "#, 250 row.code 251 ) 252 .fetch_all(&state.db) 253 .await 254 .map(|use_rows| { 255 use_rows 256 .iter() 257 .map(|u| InviteCodeUse { 258 used_by: u.did.clone(), 259 used_by_handle: Some(u.handle.clone()), 260 used_at: u.used_at.to_rfc3339(), 261 }) 262 .collect() 263 }) 264 .unwrap_or_default(); 265 266 codes.push(InviteCode { 267 code: row.code, 268 available: row.available_uses, 269 disabled, 270 for_account: row.for_account, 271 created_by: "admin".to_string(), 272 created_at: row.created_at.to_rfc3339(), 273 uses, 274 }); 275 } 276 277 Json(GetAccountInviteCodesOutput { codes }).into_response() 278}