this repo has no description
at main 8.3 kB view raw
1use crate::api::ApiError; 2use crate::auth::BearerAuth; 3use crate::auth::extractor::BearerAuthAdmin; 4use crate::state::AppState; 5use axum::{ 6 Json, 7 extract::State, 8 response::{IntoResponse, Response}, 9}; 10use rand::Rng; 11use serde::{Deserialize, Serialize}; 12use tracing::error; 13 14const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567"; 15 16fn gen_random_token() -> String { 17 let mut rng = rand::thread_rng(); 18 let gen_segment = |rng: &mut rand::rngs::ThreadRng, len: usize| -> String { 19 (0..len) 20 .map(|_| BASE32_ALPHABET[rng.gen_range(0..32)] as char) 21 .collect() 22 }; 23 format!("{}-{}", gen_segment(&mut rng, 5), gen_segment(&mut rng, 5)) 24} 25 26fn gen_invite_code() -> String { 27 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 28 let hostname_prefix = hostname.replace('.', "-"); 29 format!("{}-{}", hostname_prefix, gen_random_token()) 30} 31 32#[derive(Deserialize)] 33#[serde(rename_all = "camelCase")] 34pub struct CreateInviteCodeInput { 35 pub use_count: i32, 36 pub for_account: Option<String>, 37} 38 39#[derive(Serialize)] 40pub struct CreateInviteCodeOutput { 41 pub code: String, 42} 43 44pub async fn create_invite_code( 45 State(state): State<AppState>, 46 BearerAuthAdmin(auth_user): BearerAuthAdmin, 47 Json(input): Json<CreateInviteCodeInput>, 48) -> Response { 49 if input.use_count < 1 { 50 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 51 } 52 53 let for_account = input 54 .for_account 55 .unwrap_or_else(|| auth_user.did.to_string()); 56 let code = gen_invite_code(); 57 58 match sqlx::query!( 59 "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) 60 SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1", 61 code, 62 input.use_count, 63 for_account 64 ) 65 .execute(&state.db) 66 .await 67 { 68 Ok(result) => { 69 if result.rows_affected() == 0 { 70 error!("No admin user found to create invite code"); 71 return ApiError::InternalError(None).into_response(); 72 } 73 Json(CreateInviteCodeOutput { code }).into_response() 74 } 75 Err(e) => { 76 error!("DB error creating invite code: {:?}", e); 77 ApiError::InternalError(None).into_response() 78 } 79 } 80} 81 82#[derive(Deserialize)] 83#[serde(rename_all = "camelCase")] 84pub struct CreateInviteCodesInput { 85 pub code_count: Option<i32>, 86 pub use_count: i32, 87 pub for_accounts: Option<Vec<String>>, 88} 89 90#[derive(Serialize)] 91pub struct CreateInviteCodesOutput { 92 pub codes: Vec<AccountCodes>, 93} 94 95#[derive(Serialize)] 96pub struct AccountCodes { 97 pub account: String, 98 pub codes: Vec<String>, 99} 100 101pub async fn create_invite_codes( 102 State(state): State<AppState>, 103 BearerAuthAdmin(auth_user): BearerAuthAdmin, 104 Json(input): Json<CreateInviteCodesInput>, 105) -> Response { 106 if input.use_count < 1 { 107 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 108 } 109 110 let code_count = input.code_count.unwrap_or(1).max(1); 111 let for_accounts = input 112 .for_accounts 113 .filter(|v| !v.is_empty()) 114 .unwrap_or_else(|| vec![auth_user.did.to_string()]); 115 116 let admin_user_id = 117 match sqlx::query_scalar!("SELECT id FROM users WHERE is_admin = true LIMIT 1") 118 .fetch_optional(&state.db) 119 .await 120 { 121 Ok(Some(id)) => id, 122 Ok(None) => { 123 error!("No admin user found to create invite codes"); 124 return ApiError::InternalError(None).into_response(); 125 } 126 Err(e) => { 127 error!("DB error looking up admin user: {:?}", e); 128 return ApiError::InternalError(None).into_response(); 129 } 130 }; 131 132 let result = futures::future::try_join_all(for_accounts.into_iter().map(|account| { 133 let db = state.db.clone(); 134 let use_count = input.use_count; 135 async move { 136 let codes: Vec<String> = (0..code_count).map(|_| gen_invite_code()).collect(); 137 sqlx::query!( 138 r#" 139 INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) 140 SELECT code, $2, $3, $4 FROM UNNEST($1::text[]) AS t(code) 141 "#, 142 &codes[..], 143 use_count, 144 admin_user_id, 145 account 146 ) 147 .execute(&db) 148 .await 149 .map(|_| AccountCodes { account, codes }) 150 } 151 })) 152 .await; 153 154 match result { 155 Ok(result_codes) => Json(CreateInviteCodesOutput { 156 codes: result_codes, 157 }) 158 .into_response(), 159 Err(e) => { 160 error!("DB error creating invite codes: {:?}", e); 161 ApiError::InternalError(None).into_response() 162 } 163 } 164} 165 166#[derive(Deserialize)] 167#[serde(rename_all = "camelCase")] 168pub struct GetAccountInviteCodesParams { 169 pub include_used: Option<bool>, 170 pub create_available: Option<bool>, 171} 172 173#[derive(Serialize)] 174#[serde(rename_all = "camelCase")] 175pub struct InviteCode { 176 pub code: String, 177 pub available: i32, 178 pub disabled: bool, 179 pub for_account: String, 180 pub created_by: String, 181 pub created_at: String, 182 pub uses: Vec<InviteCodeUse>, 183} 184 185#[derive(Serialize)] 186#[serde(rename_all = "camelCase")] 187pub struct InviteCodeUse { 188 pub used_by: String, 189 #[serde(skip_serializing_if = "Option::is_none")] 190 pub used_by_handle: Option<String>, 191 pub used_at: String, 192} 193 194#[derive(Serialize)] 195pub struct GetAccountInviteCodesOutput { 196 pub codes: Vec<InviteCode>, 197} 198 199pub async fn get_account_invite_codes( 200 State(state): State<AppState>, 201 BearerAuth(auth_user): BearerAuth, 202 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 203) -> Response { 204 let include_used = params.include_used.unwrap_or(true); 205 206 let codes_rows = match sqlx::query!( 207 r#" 208 SELECT 209 ic.code, 210 ic.available_uses, 211 ic.created_at, 212 ic.disabled, 213 ic.for_account, 214 (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!" 215 FROM invite_codes ic 216 WHERE ic.for_account = $1 217 ORDER BY ic.created_at DESC 218 "#, 219 &auth_user.did 220 ) 221 .fetch_all(&state.db) 222 .await 223 { 224 Ok(rows) => rows, 225 Err(e) => { 226 error!("DB error fetching invite codes: {:?}", e); 227 return ApiError::InternalError(None).into_response(); 228 } 229 }; 230 231 let filtered_rows: Vec<_> = codes_rows 232 .into_iter() 233 .filter(|row| { 234 let disabled = row.disabled.unwrap_or(false); 235 !disabled && (include_used || row.use_count < row.available_uses) 236 }) 237 .collect(); 238 239 let codes = futures::future::join_all(filtered_rows.into_iter().map(|row| { 240 let db = state.db.clone(); 241 async move { 242 let uses = sqlx::query!( 243 r#" 244 SELECT u.did, u.handle, icu.used_at 245 FROM invite_code_uses icu 246 JOIN users u ON icu.used_by_user = u.id 247 WHERE icu.code = $1 248 ORDER BY icu.used_at DESC 249 "#, 250 row.code 251 ) 252 .fetch_all(&db) 253 .await 254 .map(|use_rows| { 255 use_rows 256 .iter() 257 .map(|u| InviteCodeUse { 258 used_by: u.did.clone(), 259 used_by_handle: Some(u.handle.clone()), 260 used_at: u.used_at.to_rfc3339(), 261 }) 262 .collect() 263 }) 264 .unwrap_or_default(); 265 266 InviteCode { 267 code: row.code, 268 available: row.available_uses, 269 disabled: false, 270 for_account: row.for_account, 271 created_by: "admin".to_string(), 272 created_at: row.created_at.to_rfc3339(), 273 uses, 274 } 275 } 276 })) 277 .await; 278 279 Json(GetAccountInviteCodesOutput { codes }).into_response() 280}