this repo has no description
1use crate::state::AppState; 2use axum::{ 3 Json, 4 extract::State, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7}; 8use serde::{Deserialize, Serialize}; 9use serde_json::json; 10use tracing::error; 11use uuid::Uuid; 12 13#[derive(Deserialize)] 14#[serde(rename_all = "camelCase")] 15pub struct CreateInviteCodeInput { 16 pub use_count: i32, 17 pub for_account: Option<String>, 18} 19 20#[derive(Serialize)] 21pub struct CreateInviteCodeOutput { 22 pub code: String, 23} 24 25pub async fn create_invite_code( 26 State(state): State<AppState>, 27 headers: axum::http::HeaderMap, 28 Json(input): Json<CreateInviteCodeInput>, 29) -> Response { 30 let auth_header = headers.get("Authorization"); 31 if auth_header.is_none() { 32 return ( 33 StatusCode::UNAUTHORIZED, 34 Json(json!({"error": "AuthenticationRequired"})), 35 ) 36 .into_response(); 37 } 38 39 if input.use_count < 1 { 40 return ( 41 StatusCode::BAD_REQUEST, 42 Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})), 43 ) 44 .into_response(); 45 } 46 47 let token = auth_header 48 .unwrap() 49 .to_str() 50 .unwrap_or("") 51 .replace("Bearer ", ""); 52 53 let session = sqlx::query!( 54 r#" 55 SELECT s.did, k.key_bytes, u.id as user_id 56 FROM sessions s 57 JOIN users u ON s.did = u.did 58 JOIN user_keys k ON u.id = k.user_id 59 WHERE s.access_jwt = $1 60 "#, 61 token 62 ) 63 .fetch_optional(&state.db) 64 .await; 65 66 let (did, key_bytes, user_id) = match session { 67 Ok(Some(row)) => (row.did, row.key_bytes, row.user_id), 68 Ok(None) => { 69 return ( 70 StatusCode::UNAUTHORIZED, 71 Json(json!({"error": "AuthenticationFailed"})), 72 ) 73 .into_response(); 74 } 75 Err(e) => { 76 error!("DB error in create_invite_code: {:?}", e); 77 return ( 78 StatusCode::INTERNAL_SERVER_ERROR, 79 Json(json!({"error": "InternalError"})), 80 ) 81 .into_response(); 82 } 83 }; 84 85 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { 86 return ( 87 StatusCode::UNAUTHORIZED, 88 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), 89 ) 90 .into_response(); 91 } 92 93 let creator_user_id = if let Some(for_account) = &input.for_account { 94 let target = sqlx::query!("SELECT id FROM users WHERE did = $1", for_account) 95 .fetch_optional(&state.db) 96 .await; 97 98 match target { 99 Ok(Some(row)) => row.id, 100 Ok(None) => { 101 return ( 102 StatusCode::NOT_FOUND, 103 Json(json!({"error": "AccountNotFound", "message": "Target account not found"})), 104 ) 105 .into_response(); 106 } 107 Err(e) => { 108 error!("DB error looking up target account: {:?}", e); 109 return ( 110 StatusCode::INTERNAL_SERVER_ERROR, 111 Json(json!({"error": "InternalError"})), 112 ) 113 .into_response(); 114 } 115 } 116 } else { 117 user_id 118 }; 119 120 let user_invites_disabled = sqlx::query_scalar!( 121 "SELECT invites_disabled FROM users WHERE did = $1", 122 did 123 ) 124 .fetch_optional(&state.db) 125 .await 126 .ok() 127 .flatten() 128 .flatten() 129 .unwrap_or(false); 130 131 if user_invites_disabled { 132 return ( 133 StatusCode::FORBIDDEN, 134 Json(json!({"error": "InvitesDisabled", "message": "Invites are disabled for this account"})), 135 ) 136 .into_response(); 137 } 138 139 let code = Uuid::new_v4().to_string(); 140 141 let result = sqlx::query!( 142 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)", 143 code, 144 input.use_count, 145 creator_user_id 146 ) 147 .execute(&state.db) 148 .await; 149 150 match result { 151 Ok(_) => (StatusCode::OK, Json(CreateInviteCodeOutput { code })).into_response(), 152 Err(e) => { 153 error!("DB error creating invite code: {:?}", e); 154 ( 155 StatusCode::INTERNAL_SERVER_ERROR, 156 Json(json!({"error": "InternalError"})), 157 ) 158 .into_response() 159 } 160 } 161} 162 163#[derive(Deserialize)] 164#[serde(rename_all = "camelCase")] 165pub struct CreateInviteCodesInput { 166 pub code_count: Option<i32>, 167 pub use_count: i32, 168 pub for_accounts: Option<Vec<String>>, 169} 170 171#[derive(Serialize)] 172pub struct CreateInviteCodesOutput { 173 pub codes: Vec<AccountCodes>, 174} 175 176#[derive(Serialize)] 177pub struct AccountCodes { 178 pub account: String, 179 pub codes: Vec<String>, 180} 181 182pub async fn create_invite_codes( 183 State(state): State<AppState>, 184 headers: axum::http::HeaderMap, 185 Json(input): Json<CreateInviteCodesInput>, 186) -> Response { 187 let auth_header = headers.get("Authorization"); 188 if auth_header.is_none() { 189 return ( 190 StatusCode::UNAUTHORIZED, 191 Json(json!({"error": "AuthenticationRequired"})), 192 ) 193 .into_response(); 194 } 195 196 if input.use_count < 1 { 197 return ( 198 StatusCode::BAD_REQUEST, 199 Json(json!({"error": "InvalidRequest", "message": "useCount must be at least 1"})), 200 ) 201 .into_response(); 202 } 203 204 let token = auth_header 205 .unwrap() 206 .to_str() 207 .unwrap_or("") 208 .replace("Bearer ", ""); 209 210 let session = sqlx::query!( 211 r#" 212 SELECT s.did, k.key_bytes, u.id as user_id 213 FROM sessions s 214 JOIN users u ON s.did = u.did 215 JOIN user_keys k ON u.id = k.user_id 216 WHERE s.access_jwt = $1 217 "#, 218 token 219 ) 220 .fetch_optional(&state.db) 221 .await; 222 223 let (_did, key_bytes, user_id) = match session { 224 Ok(Some(row)) => (row.did, row.key_bytes, row.user_id), 225 Ok(None) => { 226 return ( 227 StatusCode::UNAUTHORIZED, 228 Json(json!({"error": "AuthenticationFailed"})), 229 ) 230 .into_response(); 231 } 232 Err(e) => { 233 error!("DB error in create_invite_codes: {:?}", e); 234 return ( 235 StatusCode::INTERNAL_SERVER_ERROR, 236 Json(json!({"error": "InternalError"})), 237 ) 238 .into_response(); 239 } 240 }; 241 242 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { 243 return ( 244 StatusCode::UNAUTHORIZED, 245 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), 246 ) 247 .into_response(); 248 } 249 250 let code_count = input.code_count.unwrap_or(1).max(1); 251 let for_accounts = input.for_accounts.unwrap_or_default(); 252 253 let mut result_codes = Vec::new(); 254 255 if for_accounts.is_empty() { 256 let mut codes = Vec::new(); 257 for _ in 0..code_count { 258 let code = Uuid::new_v4().to_string(); 259 260 let insert = sqlx::query!( 261 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)", 262 code, 263 input.use_count, 264 user_id 265 ) 266 .execute(&state.db) 267 .await; 268 269 if let Err(e) = insert { 270 error!("DB error creating invite code: {:?}", e); 271 return ( 272 StatusCode::INTERNAL_SERVER_ERROR, 273 Json(json!({"error": "InternalError"})), 274 ) 275 .into_response(); 276 } 277 278 codes.push(code); 279 } 280 281 result_codes.push(AccountCodes { 282 account: "admin".to_string(), 283 codes, 284 }); 285 } else { 286 for account_did in for_accounts { 287 let target = sqlx::query!("SELECT id FROM users WHERE did = $1", account_did) 288 .fetch_optional(&state.db) 289 .await; 290 291 let target_user_id = match target { 292 Ok(Some(row)) => row.id, 293 Ok(None) => { 294 continue; 295 } 296 Err(e) => { 297 error!("DB error looking up target account: {:?}", e); 298 return ( 299 StatusCode::INTERNAL_SERVER_ERROR, 300 Json(json!({"error": "InternalError"})), 301 ) 302 .into_response(); 303 } 304 }; 305 306 let mut codes = Vec::new(); 307 for _ in 0..code_count { 308 let code = Uuid::new_v4().to_string(); 309 310 let insert = sqlx::query!( 311 "INSERT INTO invite_codes (code, available_uses, created_by_user) VALUES ($1, $2, $3)", 312 code, 313 input.use_count, 314 target_user_id 315 ) 316 .execute(&state.db) 317 .await; 318 319 if let Err(e) = insert { 320 error!("DB error creating invite code: {:?}", e); 321 return ( 322 StatusCode::INTERNAL_SERVER_ERROR, 323 Json(json!({"error": "InternalError"})), 324 ) 325 .into_response(); 326 } 327 328 codes.push(code); 329 } 330 331 result_codes.push(AccountCodes { 332 account: account_did, 333 codes, 334 }); 335 } 336 } 337 338 (StatusCode::OK, Json(CreateInviteCodesOutput { codes: result_codes })).into_response() 339} 340 341#[derive(Deserialize)] 342#[serde(rename_all = "camelCase")] 343pub struct GetAccountInviteCodesParams { 344 pub include_used: Option<bool>, 345 pub create_available: Option<bool>, 346} 347 348#[derive(Serialize)] 349#[serde(rename_all = "camelCase")] 350pub struct InviteCode { 351 pub code: String, 352 pub available: i32, 353 pub disabled: bool, 354 pub for_account: String, 355 pub created_by: String, 356 pub created_at: String, 357 pub uses: Vec<InviteCodeUse>, 358} 359 360#[derive(Serialize)] 361#[serde(rename_all = "camelCase")] 362pub struct InviteCodeUse { 363 pub used_by: String, 364 pub used_at: String, 365} 366 367#[derive(Serialize)] 368pub struct GetAccountInviteCodesOutput { 369 pub codes: Vec<InviteCode>, 370} 371 372pub async fn get_account_invite_codes( 373 State(state): State<AppState>, 374 headers: axum::http::HeaderMap, 375 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 376) -> Response { 377 let auth_header = headers.get("Authorization"); 378 if auth_header.is_none() { 379 return ( 380 StatusCode::UNAUTHORIZED, 381 Json(json!({"error": "AuthenticationRequired"})), 382 ) 383 .into_response(); 384 } 385 386 let token = auth_header 387 .unwrap() 388 .to_str() 389 .unwrap_or("") 390 .replace("Bearer ", ""); 391 392 let session = sqlx::query!( 393 r#" 394 SELECT s.did, k.key_bytes, u.id as user_id 395 FROM sessions s 396 JOIN users u ON s.did = u.did 397 JOIN user_keys k ON u.id = k.user_id 398 WHERE s.access_jwt = $1 399 "#, 400 token 401 ) 402 .fetch_optional(&state.db) 403 .await; 404 405 let (did, key_bytes, user_id) = match session { 406 Ok(Some(row)) => (row.did, row.key_bytes, row.user_id), 407 Ok(None) => { 408 return ( 409 StatusCode::UNAUTHORIZED, 410 Json(json!({"error": "AuthenticationFailed"})), 411 ) 412 .into_response(); 413 } 414 Err(e) => { 415 error!("DB error in get_account_invite_codes: {:?}", e); 416 return ( 417 StatusCode::INTERNAL_SERVER_ERROR, 418 Json(json!({"error": "InternalError"})), 419 ) 420 .into_response(); 421 } 422 }; 423 424 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) { 425 return ( 426 StatusCode::UNAUTHORIZED, 427 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})), 428 ) 429 .into_response(); 430 } 431 432 let include_used = params.include_used.unwrap_or(true); 433 434 let codes_result = sqlx::query!( 435 r#" 436 SELECT code, available_uses, created_at, disabled 437 FROM invite_codes 438 WHERE created_by_user = $1 439 ORDER BY created_at DESC 440 "#, 441 user_id 442 ) 443 .fetch_all(&state.db) 444 .await; 445 446 let codes_rows = match codes_result { 447 Ok(rows) => { 448 if include_used { 449 rows 450 } else { 451 rows.into_iter().filter(|r| r.available_uses > 0).collect() 452 } 453 } 454 Err(e) => { 455 error!("DB error fetching invite codes: {:?}", e); 456 return ( 457 StatusCode::INTERNAL_SERVER_ERROR, 458 Json(json!({"error": "InternalError"})), 459 ) 460 .into_response(); 461 } 462 }; 463 464 let mut codes = Vec::new(); 465 for row in codes_rows { 466 let uses_result = sqlx::query!( 467 r#" 468 SELECT u.did, icu.used_at 469 FROM invite_code_uses icu 470 JOIN users u ON icu.used_by_user = u.id 471 WHERE icu.code = $1 472 ORDER BY icu.used_at DESC 473 "#, 474 row.code 475 ) 476 .fetch_all(&state.db) 477 .await; 478 479 let uses = match uses_result { 480 Ok(use_rows) => use_rows 481 .iter() 482 .map(|u| InviteCodeUse { 483 used_by: u.did.clone(), 484 used_at: u.used_at.to_rfc3339(), 485 }) 486 .collect(), 487 Err(_) => Vec::new(), 488 }; 489 490 codes.push(InviteCode { 491 code: row.code, 492 available: row.available_uses, 493 disabled: row.disabled.unwrap_or(false), 494 for_account: did.clone(), 495 created_by: did.clone(), 496 created_at: row.created_at.to_rfc3339(), 497 uses, 498 }); 499 } 500 501 (StatusCode::OK, Json(GetAccountInviteCodesOutput { codes })).into_response() 502}