this repo has no description
1use crate::api::ApiError; 2use crate::auth::extractor::BearerAuthAdmin; 3use crate::auth::BearerAuth; 4use crate::state::AppState; 5use axum::{ 6 Json, 7 extract::State, 8 response::{IntoResponse, Response}, 9}; 10use rand::Rng; 11use serde::{Deserialize, Serialize}; 12use tracing::error; 13 14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567"; 15 16fn gen_random_token() -> String { 17 let mut rng = rand::thread_rng(); 18 let mut token = String::with_capacity(11); 19 for i in 0..10 { 20 if i == 5 { 21 token.push('-'); 22 } 23 let idx = rng.gen_range(0..32); 24 token.push(BASE32_ALPHABET[idx] as char); 25 } 26 token 27} 28 29fn gen_invite_code() -> String { 30 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 31 let hostname_prefix = hostname.replace('.', "-"); 32 format!("{}-{}", hostname_prefix, gen_random_token()) 33} 34 35#[derive(Deserialize)] 36#[serde(rename_all = "camelCase")] 37pub struct CreateInviteCodeInput { 38 pub use_count: i32, 39 pub for_account: Option<String>, 40} 41 42#[derive(Serialize)] 43pub struct CreateInviteCodeOutput { 44 pub code: String, 45} 46 47pub async fn create_invite_code( 48 State(state): State<AppState>, 49 BearerAuthAdmin(auth_user): BearerAuthAdmin, 50 Json(input): Json<CreateInviteCodeInput>, 51) -> Response { 52 if input.use_count < 1 { 53 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 54 } 55 56 let for_account = input.for_account.unwrap_or_else(|| auth_user.did.clone()); 57 let code = gen_invite_code(); 58 59 match sqlx::query!( 60 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) 61 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1", 62 code, 63 input.use_count, 64 for_account 65 ) 66 .execute(&state.db) 67 .await 68 { 69 Ok(result) => { 70 if result.rows_affected() == 0 { 71 error!("No admin user found to create invite code"); 72 return ApiError::InternalError.into_response(); 73 } 74 Json(CreateInviteCodeOutput { code }).into_response() 75 } 76 Err(e) => { 77 error!("DB error creating invite code: {:?}", e); 78 ApiError::InternalError.into_response() 79 } 80 } 81} 82 83#[derive(Deserialize)] 84#[serde(rename_all = "camelCase")] 85pub struct CreateInviteCodesInput { 86 pub code_count: Option<i32>, 87 pub use_count: i32, 88 pub for_accounts: Option<Vec<String>>, 89} 90 91#[derive(Serialize)] 92pub struct CreateInviteCodesOutput { 93 pub codes: Vec<AccountCodes>, 94} 95 96#[derive(Serialize)] 97pub struct AccountCodes { 98 pub account: String, 99 pub codes: Vec<String>, 100} 101 102pub async fn create_invite_codes( 103 State(state): State<AppState>, 104 BearerAuthAdmin(auth_user): BearerAuthAdmin, 105 Json(input): Json<CreateInviteCodesInput>, 106) -> Response { 107 if input.use_count < 1 { 108 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 109 } 110 111 let code_count = input.code_count.unwrap_or(1).max(1); 112 let for_accounts = input 113 .for_accounts 114 .filter(|v| !v.is_empty()) 115 .unwrap_or_else(|| vec![auth_user.did.clone()]); 116 117 let admin_user_id = match sqlx::query_scalar!( 118 "SELECT id FROM users WHERE is_admin = true LIMIT 1" 119 ) 120 .fetch_optional(&state.db) 121 .await 122 { 123 Ok(Some(id)) => id, 124 Ok(None) => { 125 error!("No admin user found to create invite codes"); 126 return ApiError::InternalError.into_response(); 127 } 128 Err(e) => { 129 error!("DB error looking up admin user: {:?}", e); 130 return ApiError::InternalError.into_response(); 131 } 132 }; 133 134 let mut result_codes = Vec::new(); 135 136 for account in for_accounts { 137 let mut codes = Vec::new(); 138 for _ in 0..code_count { 139 let code = gen_invite_code(); 140 if let Err(e) = sqlx::query!( 141 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)", 142 code, 143 input.use_count, 144 admin_user_id, 145 account 146 ) 147 .execute(&state.db) 148 .await 149 { 150 error!("DB error creating invite code: {:?}", e); 151 return ApiError::InternalError.into_response(); 152 } 153 codes.push(code); 154 } 155 result_codes.push(AccountCodes { account, codes }); 156 } 157 158 Json(CreateInviteCodesOutput { 159 codes: result_codes, 160 }) 161 .into_response() 162} 163 164#[derive(Deserialize)] 165#[serde(rename_all = "camelCase")] 166pub struct GetAccountInviteCodesParams { 167 pub include_used: Option<bool>, 168 pub create_available: Option<bool>, 169} 170 171#[derive(Serialize)] 172#[serde(rename_all = "camelCase")] 173pub struct InviteCode { 174 pub code: String, 175 pub available: i32, 176 pub disabled: bool, 177 pub for_account: String, 178 pub created_by: String, 179 pub created_at: String, 180 pub uses: Vec<InviteCodeUse>, 181} 182 183#[derive(Serialize)] 184#[serde(rename_all = "camelCase")] 185pub struct InviteCodeUse { 186 pub used_by: String, 187 #[serde(skip_serializing_if = "Option::is_none")] 188 pub used_by_handle: Option<String>, 189 pub used_at: String, 190} 191 192#[derive(Serialize)] 193pub struct GetAccountInviteCodesOutput { 194 pub codes: Vec<InviteCode>, 195} 196 197pub async fn get_account_invite_codes( 198 State(state): State<AppState>, 199 BearerAuth(auth_user): BearerAuth, 200 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 201) -> Response { 202 let include_used = params.include_used.unwrap_or(true); 203 204 let codes_rows = match sqlx::query!( 205 r#" 206 SELECT 207 ic.code, 208 ic.available_uses, 209 ic.created_at, 210 ic.disabled, 211 ic.for_account, 212 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!" 213 FROM invite_codes ic 214 WHERE ic.for_account = $1 215 ORDER BY ic.created_at DESC 216 "#, 217 auth_user.did 218 ) 219 .fetch_all(&state.db) 220 .await 221 { 222 Ok(rows) => rows, 223 Err(e) => { 224 error!("DB error fetching invite codes: {:?}", e); 225 return ApiError::InternalError.into_response(); 226 } 227 }; 228 229 let mut codes = Vec::new(); 230 for row in codes_rows { 231 let disabled = row.disabled.unwrap_or(false); 232 if disabled { 233 continue; 234 } 235 236 let use_count = row.use_count; 237 if !include_used && use_count >= row.available_uses { 238 continue; 239 } 240 241 let uses = sqlx::query!( 242 r#" 243 SELECT u.did, u.handle, icu.used_at 244 FROM invite_code_uses icu 245 JOIN users u ON icu.used_by_user = u.id 246 WHERE icu.code = $1 247 ORDER BY icu.used_at DESC 248 "#, 249 row.code 250 ) 251 .fetch_all(&state.db) 252 .await 253 .map(|use_rows| { 254 use_rows 255 .iter() 256 .map(|u| InviteCodeUse { 257 used_by: u.did.clone(), 258 used_by_handle: Some(u.handle.clone()), 259 used_at: u.used_at.to_rfc3339(), 260 }) 261 .collect() 262 }) 263 .unwrap_or_default(); 264 265 codes.push(InviteCode { 266 code: row.code, 267 available: row.available_uses, 268 disabled, 269 for_account: row.for_account, 270 created_by: "admin".to_string(), 271 created_at: row.created_at.to_rfc3339(), 272 uses, 273 }); 274 } 275 276 Json(GetAccountInviteCodesOutput { codes }).into_response() 277}