// This repository has no description.
1use crate::api::ApiError; 2use crate::auth::BearerAuth; 3use crate::auth::extractor::BearerAuthAdmin; 4use crate::state::AppState; 5use axum::{ 6 Json, 7 extract::State, 8 response::{IntoResponse, Response}, 9}; 10use rand::Rng; 11use serde::{Deserialize, Serialize}; 12use tracing::error; 13 14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567"; 15 16fn gen_random_token() -> String { 17 let mut rng = rand::thread_rng(); 18 let mut token = String::with_capacity(11); 19 for i in 0..10 { 20 if i == 5 { 21 token.push('-'); 22 } 23 let idx = rng.gen_range(0..32); 24 token.push(BASE32_ALPHABET[idx] as char); 25 } 26 token 27} 28 29fn gen_invite_code() -> String { 30 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 31 let hostname_prefix = hostname.replace('.', "-"); 32 format!("{}-{}", hostname_prefix, gen_random_token()) 33} 34 35#[derive(Deserialize)] 36#[serde(rename_all = "camelCase")] 37pub struct CreateInviteCodeInput { 38 pub use_count: i32, 39 pub for_account: Option<String>, 40} 41 42#[derive(Serialize)] 43pub struct CreateInviteCodeOutput { 44 pub code: String, 45} 46 47pub async fn create_invite_code( 48 State(state): State<AppState>, 49 BearerAuthAdmin(auth_user): BearerAuthAdmin, 50 Json(input): Json<CreateInviteCodeInput>, 51) -> Response { 52 if input.use_count < 1 { 53 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 54 } 55 56 let for_account = input.for_account.unwrap_or_else(|| auth_user.did.clone()); 57 let code = gen_invite_code(); 58 59 match sqlx::query!( 60 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) 61 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1", 62 code, 63 input.use_count, 64 for_account 65 ) 66 .execute(&state.db) 67 .await 68 { 69 Ok(result) => { 70 if result.rows_affected() == 0 { 71 error!("No admin user found to create invite code"); 72 return ApiError::InternalError.into_response(); 73 } 74 
Json(CreateInviteCodeOutput { code }).into_response() 75 } 76 Err(e) => { 77 error!("DB error creating invite code: {:?}", e); 78 ApiError::InternalError.into_response() 79 } 80 } 81} 82 83#[derive(Deserialize)] 84#[serde(rename_all = "camelCase")] 85pub struct CreateInviteCodesInput { 86 pub code_count: Option<i32>, 87 pub use_count: i32, 88 pub for_accounts: Option<Vec<String>>, 89} 90 91#[derive(Serialize)] 92pub struct CreateInviteCodesOutput { 93 pub codes: Vec<AccountCodes>, 94} 95 96#[derive(Serialize)] 97pub struct AccountCodes { 98 pub account: String, 99 pub codes: Vec<String>, 100} 101 102pub async fn create_invite_codes( 103 State(state): State<AppState>, 104 BearerAuthAdmin(auth_user): BearerAuthAdmin, 105 Json(input): Json<CreateInviteCodesInput>, 106) -> Response { 107 if input.use_count < 1 { 108 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 109 } 110 111 let code_count = input.code_count.unwrap_or(1).max(1); 112 let for_accounts = input 113 .for_accounts 114 .filter(|v| !v.is_empty()) 115 .unwrap_or_else(|| vec![auth_user.did.clone()]); 116 117 let admin_user_id = 118 match sqlx::query_scalar!("SELECT id FROM users WHERE is_admin = true LIMIT 1") 119 .fetch_optional(&state.db) 120 .await 121 { 122 Ok(Some(id)) => id, 123 Ok(None) => { 124 error!("No admin user found to create invite codes"); 125 return ApiError::InternalError.into_response(); 126 } 127 Err(e) => { 128 error!("DB error looking up admin user: {:?}", e); 129 return ApiError::InternalError.into_response(); 130 } 131 }; 132 133 let mut result_codes = Vec::new(); 134 135 for account in for_accounts { 136 let mut codes = Vec::new(); 137 for _ in 0..code_count { 138 let code = gen_invite_code(); 139 if let Err(e) = sqlx::query!( 140 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) VALUES ($1, $2, $3, $4)", 141 code, 142 input.use_count, 143 admin_user_id, 144 account 145 ) 146 .execute(&state.db) 147 .await 148 { 149 
error!("DB error creating invite code: {:?}", e); 150 return ApiError::InternalError.into_response(); 151 } 152 codes.push(code); 153 } 154 result_codes.push(AccountCodes { account, codes }); 155 } 156 157 Json(CreateInviteCodesOutput { 158 codes: result_codes, 159 }) 160 .into_response() 161} 162 163#[derive(Deserialize)] 164#[serde(rename_all = "camelCase")] 165pub struct GetAccountInviteCodesParams { 166 pub include_used: Option<bool>, 167 pub create_available: Option<bool>, 168} 169 170#[derive(Serialize)] 171#[serde(rename_all = "camelCase")] 172pub struct InviteCode { 173 pub code: String, 174 pub available: i32, 175 pub disabled: bool, 176 pub for_account: String, 177 pub created_by: String, 178 pub created_at: String, 179 pub uses: Vec<InviteCodeUse>, 180} 181 182#[derive(Serialize)] 183#[serde(rename_all = "camelCase")] 184pub struct InviteCodeUse { 185 pub used_by: String, 186 #[serde(skip_serializing_if = "Option::is_none")] 187 pub used_by_handle: Option<String>, 188 pub used_at: String, 189} 190 191#[derive(Serialize)] 192pub struct GetAccountInviteCodesOutput { 193 pub codes: Vec<InviteCode>, 194} 195 196pub async fn get_account_invite_codes( 197 State(state): State<AppState>, 198 BearerAuth(auth_user): BearerAuth, 199 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 200) -> Response { 201 let include_used = params.include_used.unwrap_or(true); 202 203 let codes_rows = match sqlx::query!( 204 r#" 205 SELECT 206 ic.code, 207 ic.available_uses, 208 ic.created_at, 209 ic.disabled, 210 ic.for_account, 211 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!" 
212 FROM invite_codes ic 213 WHERE ic.for_account = $1 214 ORDER BY ic.created_at DESC 215 "#, 216 auth_user.did 217 ) 218 .fetch_all(&state.db) 219 .await 220 { 221 Ok(rows) => rows, 222 Err(e) => { 223 error!("DB error fetching invite codes: {:?}", e); 224 return ApiError::InternalError.into_response(); 225 } 226 }; 227 228 let mut codes = Vec::new(); 229 for row in codes_rows { 230 let disabled = row.disabled.unwrap_or(false); 231 if disabled { 232 continue; 233 } 234 235 let use_count = row.use_count; 236 if !include_used && use_count >= row.available_uses { 237 continue; 238 } 239 240 let uses = sqlx::query!( 241 r#" 242 SELECT u.did, u.handle, icu.used_at 243 FROM invite_code_uses icu 244 JOIN users u ON icu.used_by_user = u.id 245 WHERE icu.code = $1 246 ORDER BY icu.used_at DESC 247 "#, 248 row.code 249 ) 250 .fetch_all(&state.db) 251 .await 252 .map(|use_rows| { 253 use_rows 254 .iter() 255 .map(|u| InviteCodeUse { 256 used_by: u.did.clone(), 257 used_by_handle: Some(u.handle.clone()), 258 used_at: u.used_at.to_rfc3339(), 259 }) 260 .collect() 261 }) 262 .unwrap_or_default(); 263 264 codes.push(InviteCode { 265 code: row.code, 266 available: row.available_uses, 267 disabled, 268 for_account: row.for_account, 269 created_by: "admin".to_string(), 270 created_at: row.created_at.to_rfc3339(), 271 uses, 272 }); 273 } 274 275 Json(GetAccountInviteCodesOutput { codes }).into_response() 276}