this repo has no description
at main 8.2 kB view raw
1use crate::api::EmptyResponse; 2use crate::api::error::ApiError; 3use crate::auth::BearerAuthAdmin; 4use crate::state::AppState; 5use axum::{ 6 Json, 7 extract::{Query, State}, 8 http::StatusCode, 9 response::{IntoResponse, Response}, 10}; 11use serde::{Deserialize, Serialize}; 12use tracing::error; 13 14#[derive(Deserialize)] 15#[serde(rename_all = "camelCase")] 16pub struct DisableInviteCodesInput { 17 pub codes: Option<Vec<String>>, 18 pub accounts: Option<Vec<String>>, 19} 20 21pub async fn disable_invite_codes( 22 State(state): State<AppState>, 23 _auth: BearerAuthAdmin, 24 Json(input): Json<DisableInviteCodesInput>, 25) -> Response { 26 if let Some(codes) = &input.codes { 27 let _ = sqlx::query!( 28 "UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)", 29 codes as &[String] 30 ) 31 .execute(&state.db) 32 .await; 33 } 34 if let Some(accounts) = &input.accounts { 35 let _ = sqlx::query!( 36 "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))", 37 accounts as &[String] 38 ) 39 .execute(&state.db) 40 .await; 41 } 42 EmptyResponse::ok().into_response() 43} 44 45#[derive(Deserialize)] 46pub struct GetInviteCodesParams { 47 pub sort: Option<String>, 48 pub limit: Option<i64>, 49 pub cursor: Option<String>, 50} 51 52#[derive(Serialize)] 53#[serde(rename_all = "camelCase")] 54pub struct InviteCodeInfo { 55 pub code: String, 56 pub available: i32, 57 pub disabled: bool, 58 pub for_account: String, 59 pub created_by: String, 60 pub created_at: String, 61 pub uses: Vec<InviteCodeUseInfo>, 62} 63 64#[derive(Clone, Serialize)] 65#[serde(rename_all = "camelCase")] 66pub struct InviteCodeUseInfo { 67 pub used_by: String, 68 pub used_at: String, 69} 70 71#[derive(Serialize)] 72pub struct GetInviteCodesOutput { 73 #[serde(skip_serializing_if = "Option::is_none")] 74 pub cursor: Option<String>, 75 pub codes: Vec<InviteCodeInfo>, 76} 77 78pub async fn get_invite_codes( 79 State(state): State<AppState>, 80 _auth: BearerAuthAdmin, 81 Query(params): Query<GetInviteCodesParams>, 82) -> Response { 83 let limit = params.limit.unwrap_or(100).clamp(1, 500); 84 let sort = params.sort.as_deref().unwrap_or("recent"); 85 let order_clause = match sort { 86 "usage" => "available_uses DESC", 87 _ => "created_at DESC", 88 }; 89 let codes_result = if let Some(cursor) = &params.cursor { 90 sqlx::query_as::< 91 _, 92 ( 93 String, 94 i32, 95 Option<bool>, 96 uuid::Uuid, 97 chrono::DateTime<chrono::Utc>, 98 ), 99 >(&format!( 100 r#" 101 SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at 102 FROM invite_codes ic 103 WHERE ic.created_at < (SELECT created_at FROM invite_codes WHERE code = $1) 104 ORDER BY {} 105 LIMIT $2 106 "#, 107 order_clause 108 )) 109 .bind(cursor) 110 .bind(limit) 111 .fetch_all(&state.db) 112 .await 113 } else { 114 sqlx::query_as::< 115 _, 116 ( 117 String, 118 i32, 119 Option<bool>, 120 uuid::Uuid, 121 chrono::DateTime<chrono::Utc>, 122 ), 123 >(&format!( 124 r#" 125 SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at 126 FROM invite_codes ic 127 ORDER BY {} 128 LIMIT $1 129 "#, 130 order_clause 131 )) 132 .bind(limit) 133 .fetch_all(&state.db) 134 .await 135 }; 136 let codes_rows = match codes_result { 137 Ok(rows) => rows, 138 Err(e) => { 139 error!("DB error fetching invite codes: {:?}", e); 140 return ApiError::InternalError(None).into_response(); 141 } 142 }; 143 144 let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|(_, _, _, uid, _)| *uid).collect(); 145 let code_strings: Vec<String> = codes_rows.iter().map(|(c, _, _, _, _)| c.clone()).collect(); 146 147 let mut creator_dids: std::collections::HashMap<uuid::Uuid, String> = 148 std::collections::HashMap::new(); 149 sqlx::query!( 150 "SELECT id, did FROM users WHERE id = ANY($1)", 151 &user_ids 152 ) 153 .fetch_all(&state.db) 154 .await 155 .unwrap_or_default() 156 .into_iter() 157 .for_each(|r| { 158 creator_dids.insert(r.id, r.did); 159 }); 160 161 let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> = 162 std::collections::HashMap::new(); 163 if !code_strings.is_empty() { 164 sqlx::query!( 165 r#" 166 SELECT icu.code, u.did, icu.used_at 167 FROM invite_code_uses icu 168 JOIN users u ON icu.used_by_user = u.id 169 WHERE icu.code = ANY($1) 170 ORDER BY icu.used_at DESC 171 "#, 172 &code_strings 173 ) 174 .fetch_all(&state.db) 175 .await 176 .unwrap_or_default() 177 .into_iter() 178 .for_each(|r| { 179 uses_by_code 180 .entry(r.code) 181 .or_default() 182 .push(InviteCodeUseInfo { 183 used_by: r.did, 184 used_at: r.used_at.to_rfc3339(), 185 }); 186 }); 187 } 188 189 let codes: Vec<InviteCodeInfo> = codes_rows 190 .iter() 191 .map(|(code, available_uses, disabled, created_by_user, created_at)| { 192 let creator_did = creator_dids 193 .get(created_by_user) 194 .cloned() 195 .unwrap_or_else(|| "unknown".to_string()); 196 InviteCodeInfo { 197 code: code.clone(), 198 available: *available_uses, 199 disabled: disabled.unwrap_or(false), 200 for_account: creator_did.clone(), 201 created_by: creator_did, 202 created_at: created_at.to_rfc3339(), 203 uses: uses_by_code.get(code).cloned().unwrap_or_default(), 204 } 205 }) 206 .collect(); 207 208 let next_cursor = if codes_rows.len() == limit as usize { 209 codes_rows.last().map(|(code, _, _, _, _)| code.clone()) 210 } else { 211 None 212 }; 213 ( 214 StatusCode::OK, 215 Json(GetInviteCodesOutput { 216 cursor: next_cursor, 217 codes, 218 }), 219 ) 220 .into_response() 221} 222 223#[derive(Deserialize)] 224pub struct DisableAccountInvitesInput { 225 pub account: String, 226} 227 228pub async fn disable_account_invites( 229 State(state): State<AppState>, 230 _auth: BearerAuthAdmin, 231 Json(input): Json<DisableAccountInvitesInput>, 232) -> Response { 233 let account = input.account.trim(); 234 if account.is_empty() { 235 return ApiError::InvalidRequest("account is required".into()).into_response(); 236 } 237 let result = sqlx::query!( 238 "UPDATE users SET invites_disabled = TRUE WHERE did = $1", 239 account 240 ) 241 .execute(&state.db) 242 .await; 243 match result { 244 Ok(r) => { 245 if r.rows_affected() == 0 { 246 return ApiError::AccountNotFound.into_response(); 247 } 248 EmptyResponse::ok().into_response() 249 } 250 Err(e) => { 251 error!("DB error disabling account invites: {:?}", e); 252 ApiError::InternalError(None).into_response() 253 } 254 } 255} 256 257#[derive(Deserialize)] 258pub struct EnableAccountInvitesInput { 259 pub account: String, 260} 261 262pub async fn enable_account_invites( 263 State(state): State<AppState>, 264 _auth: BearerAuthAdmin, 265 Json(input): Json<EnableAccountInvitesInput>, 266) -> Response { 267 let account = input.account.trim(); 268 if account.is_empty() { 269 return ApiError::InvalidRequest("account is required".into()).into_response(); 270 } 271 let result = sqlx::query!( 272 "UPDATE users SET invites_disabled = FALSE WHERE did = $1", 273 account 274 ) 275 .execute(&state.db) 276 .await; 277 match result { 278 Ok(r) => { 279 if r.rows_affected() == 0 { 280 return ApiError::AccountNotFound.into_response(); 281 } 282 EmptyResponse::ok().into_response() 283 } 284 Err(e) => { 285 error!("DB error enabling account invites: {:?}", e); 286 ApiError::InternalError(None).into_response() 287 } 288 } 289}