this repo has no description
1use std::collections::HashMap; 2use std::sync::Arc; 3use std::time::Duration; 4 5use chrono::Utc; 6use sqlx::PgPool; 7use tokio::sync::watch; 8use tokio::time::interval; 9use tracing::{debug, error, info, warn}; 10use uuid::Uuid; 11 12use super::locale::{format_message, get_strings}; 13use super::sender::{CommsSender, SendError}; 14use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms}; 15 16pub struct CommsService { 17 db: PgPool, 18 senders: HashMap<CommsChannel, Arc<dyn CommsSender>>, 19 poll_interval: Duration, 20 batch_size: i64, 21} 22 23impl CommsService { 24 pub fn new(db: PgPool) -> Self { 25 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") 26 .ok() 27 .and_then(|v| v.parse().ok()) 28 .unwrap_or(1000); 29 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE") 30 .ok() 31 .and_then(|v| v.parse().ok()) 32 .unwrap_or(100); 33 Self { 34 db, 35 senders: HashMap::new(), 36 poll_interval: Duration::from_millis(poll_interval_ms), 37 batch_size, 38 } 39 } 40 41 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 42 self.poll_interval = interval; 43 self 44 } 45 46 pub fn with_batch_size(mut self, size: i64) -> Self { 47 self.batch_size = size; 48 self 49 } 50 51 pub fn register_sender<S: CommsSender + 'static>(mut self, sender: S) -> Self { 52 self.senders.insert(sender.channel(), Arc::new(sender)); 53 self 54 } 55 56 pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, sqlx::Error> { 57 let id = sqlx::query_scalar!( 58 r#" 59 INSERT INTO comms_queue 60 (user_id, channel, comms_type, recipient, subject, body, metadata) 61 VALUES ($1, $2, $3, $4, $5, $6, $7) 62 RETURNING id 63 "#, 64 item.user_id, 65 item.channel as CommsChannel, 66 item.comms_type as super::types::CommsType, 67 item.recipient, 68 item.subject, 69 item.body, 70 item.metadata 71 ) 72 .fetch_one(&self.db) 73 .await?; 74 debug!(comms_id = %id, "Comms enqueued"); 75 Ok(id) 76 } 77 78 pub fn has_senders(&self) -> bool { 79 !self.senders.is_empty() 80 } 81 82 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 83 if self.senders.is_empty() { 84 warn!( 85 "Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured." 86 ); 87 } 88 info!( 89 poll_interval_secs = self.poll_interval.as_secs(), 90 batch_size = self.batch_size, 91 channels = ?self.senders.keys().collect::<Vec<_>>(), 92 "Starting comms service" 93 ); 94 let mut ticker = interval(self.poll_interval); 95 loop { 96 tokio::select! { 97 _ = ticker.tick() => { 98 if let Err(e) = self.process_batch().await { 99 error!(error = %e, "Failed to process comms batch"); 100 } 101 } 102 _ = shutdown.changed() => { 103 if *shutdown.borrow() { 104 info!("Comms service shutting down"); 105 break; 106 } 107 } 108 } 109 } 110 } 111 112 async fn process_batch(&self) -> Result<(), sqlx::Error> { 113 let items = self.fetch_pending().await?; 114 if items.is_empty() { 115 return Ok(()); 116 } 117 debug!(count = items.len(), "Processing comms batch"); 118 for item in items { 119 self.process_item(item).await; 120 } 121 Ok(()) 122 } 123 124 async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, sqlx::Error> { 125 let now = Utc::now(); 126 sqlx::query_as!( 127 QueuedComms, 128 r#" 129 UPDATE comms_queue 130 SET status = 'processing', updated_at = NOW() 131 WHERE id IN ( 132 SELECT id FROM comms_queue 133 WHERE status = 'pending' 134 AND scheduled_for <= $1 135 AND attempts < max_attempts 136 ORDER BY scheduled_for ASC 137 LIMIT $2 138 FOR UPDATE SKIP LOCKED 139 ) 140 RETURNING 141 id, user_id, 142 channel as "channel: CommsChannel", 143 comms_type as "comms_type: super::types::CommsType", 144 status as "status: CommsStatus", 145 recipient, subject, body, metadata, 146 attempts, max_attempts, last_error, 147 created_at, updated_at, scheduled_for, processed_at 148 "#, 149 now, 150 self.batch_size 151 ) 152 .fetch_all(&self.db) 153 .await 154 } 155 156 async fn process_item(&self, item: QueuedComms) { 157 let comms_id = item.id; 158 let channel = item.channel; 159 let result = match self.senders.get(&channel) { 160 Some(sender) => sender.send(&item).await, 161 None => { 162 warn!( 163 comms_id = %comms_id, 164 channel = ?channel, 165 "No sender registered for channel" 166 ); 167 Err(SendError::NotConfigured(channel)) 168 } 169 }; 170 match result { 171 Ok(()) => { 172 debug!(comms_id = %comms_id, "Comms sent successfully"); 173 if let Err(e) = self.mark_sent(comms_id).await { 174 error!( 175 comms_id = %comms_id, 176 error = %e, 177 "Failed to mark comms as sent" 178 ); 179 } 180 } 181 Err(e) => { 182 let error_msg = e.to_string(); 183 warn!( 184 comms_id = %comms_id, 185 error = %error_msg, 186 "Failed to send comms" 187 ); 188 if let Err(db_err) = self.mark_failed(comms_id, &error_msg).await { 189 error!( 190 comms_id = %comms_id, 191 error = %db_err, 192 "Failed to mark comms as failed" 193 ); 194 } 195 } 196 } 197 } 198 199 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 200 sqlx::query!( 201 r#" 202 UPDATE comms_queue 203 SET status = 'sent', processed_at = NOW(), updated_at = NOW() 204 WHERE id = $1 205 "#, 206 id 207 ) 208 .execute(&self.db) 209 .await?; 210 Ok(()) 211 } 212 213 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 214 sqlx::query!( 215 r#" 216 UPDATE comms_queue 217 SET 218 status = CASE 219 WHEN attempts + 1 >= max_attempts THEN 'failed'::comms_status 220 ELSE 'pending'::comms_status 221 END, 222 attempts = attempts + 1, 223 last_error = $2, 224 updated_at = NOW(), 225 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1)) 226 WHERE id = $1 227 "#, 228 id, 229 error 230 ) 231 .execute(&self.db) 232 .await?; 233 Ok(()) 234 } 235} 236 237pub async fn enqueue_comms(db: &PgPool, item: NewComms) -> Result<Uuid, sqlx::Error> { 238 sqlx::query_scalar!( 239 r#" 240 INSERT INTO comms_queue 241 (user_id, channel, comms_type, recipient, subject, body, metadata) 242 VALUES ($1, $2, $3, $4, $5, $6, $7) 243 RETURNING id 244 "#, 245 item.user_id, 246 item.channel as CommsChannel, 247 item.comms_type as super::types::CommsType, 248 item.recipient, 249 item.subject, 250 item.body, 251 item.metadata 252 ) 253 .fetch_one(db) 254 .await 255} 256 257pub struct UserCommsPrefs { 258 pub channel: CommsChannel, 259 pub email: Option<String>, 260 pub handle: String, 261 pub locale: String, 262} 263 264pub async fn get_user_comms_prefs( 265 db: &PgPool, 266 user_id: Uuid, 267) -> Result<UserCommsPrefs, sqlx::Error> { 268 let row = sqlx::query!( 269 r#" 270 SELECT 271 email, 272 handle, 273 preferred_comms_channel as "channel: CommsChannel", 274 preferred_locale 275 FROM users 276 WHERE id = $1 277 "#, 278 user_id 279 ) 280 .fetch_one(db) 281 .await?; 282 Ok(UserCommsPrefs { 283 channel: row.channel, 284 email: row.email, 285 handle: row.handle, 286 locale: row.preferred_locale.unwrap_or_else(|| "en".to_string()), 287 }) 288} 289 290pub async fn enqueue_welcome( 291 db: &PgPool, 292 user_id: Uuid, 293 hostname: &str, 294) -> Result<Uuid, sqlx::Error> { 295 let prefs = get_user_comms_prefs(db, user_id).await?; 296 let strings = get_strings(&prefs.locale); 297 let body = format_message( 298 strings.welcome_body, 299 &[("hostname", hostname), ("handle", &prefs.handle)], 300 ); 301 let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]); 302 enqueue_comms( 303 db, 304 NewComms::new( 305 user_id, 306 prefs.channel, 307 super::types::CommsType::Welcome, 308 prefs.email.clone().unwrap_or_default(), 309 Some(subject), 310 body, 311 ), 312 ) 313 .await 314} 315 316pub async fn enqueue_email_verification( 317 db: &PgPool, 318 user_id: Uuid, 319 email: &str, 320 handle: &str, 321 code: &str, 322 hostname: &str, 323) -> Result<Uuid, sqlx::Error> { 324 let prefs = get_user_comms_prefs(db, user_id).await?; 325 let strings = get_strings(&prefs.locale); 326 let body = format_message( 327 strings.email_verification_body, 328 &[("handle", handle), ("code", code)], 329 ); 330 let subject = format_message(strings.email_verification_subject, &[("hostname", hostname)]); 331 enqueue_comms( 332 db, 333 NewComms::email( 334 user_id, 335 super::types::CommsType::EmailVerification, 336 email.to_string(), 337 subject, 338 body, 339 ), 340 ) 341 .await 342} 343 344pub async fn enqueue_password_reset( 345 db: &PgPool, 346 user_id: Uuid, 347 code: &str, 348 hostname: &str, 349) -> Result<Uuid, sqlx::Error> { 350 let prefs = get_user_comms_prefs(db, user_id).await?; 351 let strings = get_strings(&prefs.locale); 352 let body = format_message( 353 strings.password_reset_body, 354 &[("handle", &prefs.handle), ("code", code)], 355 ); 356 let subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]); 357 enqueue_comms( 358 db, 359 NewComms::new( 360 user_id, 361 prefs.channel, 362 super::types::CommsType::PasswordReset, 363 prefs.email.clone().unwrap_or_default(), 364 Some(subject), 365 body, 366 ), 367 ) 368 .await 369} 370 371pub async fn enqueue_email_update( 372 db: &PgPool, 373 user_id: Uuid, 374 new_email: &str, 375 handle: &str, 376 code: &str, 377 hostname: &str, 378) -> Result<Uuid, sqlx::Error> { 379 let prefs = get_user_comms_prefs(db, user_id).await?; 380 let strings = get_strings(&prefs.locale); 381 let body = format_message( 382 strings.email_update_body, 383 &[("handle", handle), ("code", code)], 384 ); 385 let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]); 386 enqueue_comms( 387 db, 388 NewComms::email( 389 user_id, 390 super::types::CommsType::EmailUpdate, 391 new_email.to_string(), 392 subject, 393 body, 394 ), 395 ) 396 .await 397} 398 399pub async fn enqueue_account_deletion( 400 db: &PgPool, 401 user_id: Uuid, 402 code: &str, 403 hostname: &str, 404) -> Result<Uuid, sqlx::Error> { 405 let prefs = get_user_comms_prefs(db, user_id).await?; 406 let strings = get_strings(&prefs.locale); 407 let body = format_message( 408 strings.account_deletion_body, 409 &[("handle", &prefs.handle), ("code", code)], 410 ); 411 let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]); 412 enqueue_comms( 413 db, 414 NewComms::new( 415 user_id, 416 prefs.channel, 417 super::types::CommsType::AccountDeletion, 418 prefs.email.clone().unwrap_or_default(), 419 Some(subject), 420 body, 421 ), 422 ) 423 .await 424} 425 426pub async fn enqueue_plc_operation( 427 db: &PgPool, 428 user_id: Uuid, 429 token: &str, 430 hostname: &str, 431) -> Result<Uuid, sqlx::Error> { 432 let prefs = get_user_comms_prefs(db, user_id).await?; 433 let strings = get_strings(&prefs.locale); 434 let body = format_message( 435 strings.plc_operation_body, 436 &[("handle", &prefs.handle), ("token", token)], 437 ); 438 let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]); 439 enqueue_comms( 440 db, 441 NewComms::new( 442 user_id, 443 prefs.channel, 444 super::types::CommsType::PlcOperation, 445 prefs.email.clone().unwrap_or_default(), 446 Some(subject), 447 body, 448 ), 449 ) 450 .await 451} 452 453pub async fn enqueue_2fa_code( 454 db: &PgPool, 455 user_id: Uuid, 456 code: &str, 457 hostname: &str, 458) -> Result<Uuid, sqlx::Error> { 459 let prefs = get_user_comms_prefs(db, user_id).await?; 460 let strings = get_strings(&prefs.locale); 461 let body = format_message( 462 strings.two_factor_code_body, 463 &[("handle", &prefs.handle), ("code", code)], 464 ); 465 let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]); 466 enqueue_comms( 467 db, 468 NewComms::new( 469 user_id, 470 prefs.channel, 471 super::types::CommsType::TwoFactorCode, 472 prefs.email.clone().unwrap_or_default(), 473 Some(subject), 474 body, 475 ), 476 ) 477 .await 478} 479 480pub async fn enqueue_passkey_recovery( 481 db: &PgPool, 482 user_id: Uuid, 483 recovery_url: &str, 484 hostname: &str, 485) -> Result<Uuid, sqlx::Error> { 486 let prefs = get_user_comms_prefs(db, user_id).await?; 487 let strings = get_strings(&prefs.locale); 488 let body = format_message( 489 strings.passkey_recovery_body, 490 &[("handle", &prefs.handle), ("url", recovery_url)], 491 ); 492 let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]); 493 enqueue_comms( 494 db, 495 NewComms::new( 496 user_id, 497 prefs.channel, 498 super::types::CommsType::PasskeyRecovery, 499 prefs.email.clone().unwrap_or_default(), 500 Some(subject), 501 body, 502 ), 503 ) 504 .await 505} 506 507pub fn channel_display_name(channel: CommsChannel) -> &'static str { 508 match channel { 509 CommsChannel::Email => "email", 510 CommsChannel::Discord => "Discord", 511 CommsChannel::Telegram => "Telegram", 512 CommsChannel::Signal => "Signal", 513 } 514} 515 516pub async fn enqueue_signup_verification( 517 db: &PgPool, 518 user_id: Uuid, 519 channel: &str, 520 recipient: &str, 521 code: &str, 522 locale: Option<&str>, 523) -> Result<Uuid, sqlx::Error> { 524 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 525 let comms_channel = match channel { 526 "email" => CommsChannel::Email, 527 "discord" => CommsChannel::Discord, 528 "telegram" => CommsChannel::Telegram, 529 "signal" => CommsChannel::Signal, 530 _ => CommsChannel::Email, 531 }; 532 let strings = get_strings(locale.unwrap_or("en")); 533 let body = format_message( 534 strings.signup_verification_body, 535 &[("code", code), ("hostname", &hostname)], 536 ); 537 let subject = match comms_channel { 538 CommsChannel::Email => { 539 Some(format_message(strings.signup_verification_subject, &[("hostname", &hostname)])) 540 } 541 _ => None, 542 }; 543 enqueue_comms( 544 db, 545 NewComms::new( 546 user_id, 547 comms_channel, 548 super::types::CommsType::EmailVerification, 549 recipient.to_string(), 550 subject, 551 body, 552 ), 553 ) 554 .await 555} 556 557pub async fn queue_legacy_login_notification( 558 db: &PgPool, 559 user_id: Uuid, 560 hostname: &str, 561 client_ip: &str, 562 channel: CommsChannel, 563) -> Result<Uuid, sqlx::Error> { 564 let prefs = get_user_comms_prefs(db, user_id).await?; 565 let strings = get_strings(&prefs.locale); 566 let timestamp = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string(); 567 let body = format_message( 568 strings.legacy_login_body, 569 &[ 570 ("handle", &prefs.handle), 571 ("timestamp", &timestamp), 572 ("ip", client_ip), 573 ("hostname", hostname), 574 ], 575 ); 576 let subject = format_message(strings.legacy_login_subject, &[("hostname", hostname)]); 577 enqueue_comms( 578 db, 579 NewComms::new( 580 user_id, 581 channel, 582 super::types::CommsType::LegacyLoginAlert, 583 prefs.email.clone().unwrap_or_default(), 584 Some(subject), 585 body, 586 ), 587 ) 588 .await 589}