this repo has no description
1use std::collections::HashMap; 2use std::sync::Arc; 3use std::time::Duration; 4 5use chrono::Utc; 6use sqlx::PgPool; 7use tokio::sync::watch; 8use tokio::time::interval; 9use tracing::{debug, error, info, warn}; 10use uuid::Uuid; 11 12use super::locale::{format_message, get_strings}; 13use super::sender::{CommsSender, SendError}; 14use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms}; 15 16pub struct CommsService { 17 db: PgPool, 18 senders: HashMap<CommsChannel, Arc<dyn CommsSender>>, 19 poll_interval: Duration, 20 batch_size: i64, 21} 22 23impl CommsService { 24 pub fn new(db: PgPool) -> Self { 25 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") 26 .ok() 27 .and_then(|v| v.parse().ok()) 28 .unwrap_or(1000); 29 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE") 30 .ok() 31 .and_then(|v| v.parse().ok()) 32 .unwrap_or(100); 33 Self { 34 db, 35 senders: HashMap::new(), 36 poll_interval: Duration::from_millis(poll_interval_ms), 37 batch_size, 38 } 39 } 40 41 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 42 self.poll_interval = interval; 43 self 44 } 45 46 pub fn with_batch_size(mut self, size: i64) -> Self { 47 self.batch_size = size; 48 self 49 } 50 51 pub fn register_sender<S: CommsSender + 'static>(mut self, sender: S) -> Self { 52 self.senders.insert(sender.channel(), Arc::new(sender)); 53 self 54 } 55 56 pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, sqlx::Error> { 57 let id = sqlx::query_scalar!( 58 r#" 59 INSERT INTO comms_queue 60 (user_id, channel, comms_type, recipient, subject, body, metadata) 61 VALUES ($1, $2, $3, $4, $5, $6, $7) 62 RETURNING id 63 "#, 64 item.user_id, 65 item.channel as CommsChannel, 66 item.comms_type as super::types::CommsType, 67 item.recipient, 68 item.subject, 69 item.body, 70 item.metadata 71 ) 72 .fetch_one(&self.db) 73 .await?; 74 debug!(comms_id = %id, "Comms enqueued"); 75 Ok(id) 76 } 77 78 pub fn has_senders(&self) -> bool { 79 !self.senders.is_empty() 80 } 81 82 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 83 if self.senders.is_empty() { 84 warn!( 85 "Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured." 86 ); 87 } 88 info!( 89 poll_interval_secs = self.poll_interval.as_secs(), 90 batch_size = self.batch_size, 91 channels = ?self.senders.keys().collect::<Vec<_>>(), 92 "Starting comms service" 93 ); 94 let mut ticker = interval(self.poll_interval); 95 loop { 96 tokio::select! { 97 _ = ticker.tick() => { 98 if let Err(e) = self.process_batch().await { 99 error!(error = %e, "Failed to process comms batch"); 100 } 101 } 102 _ = shutdown.changed() => { 103 if *shutdown.borrow() { 104 info!("Comms service shutting down"); 105 break; 106 } 107 } 108 } 109 } 110 } 111 112 async fn process_batch(&self) -> Result<(), sqlx::Error> { 113 let items = self.fetch_pending().await?; 114 if items.is_empty() { 115 return Ok(()); 116 } 117 debug!(count = items.len(), "Processing comms batch"); 118 for item in items { 119 self.process_item(item).await; 120 } 121 Ok(()) 122 } 123 124 async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, sqlx::Error> { 125 let now = Utc::now(); 126 sqlx::query_as!( 127 QueuedComms, 128 r#" 129 UPDATE comms_queue 130 SET status = 'processing', updated_at = NOW() 131 WHERE id IN ( 132 SELECT id FROM comms_queue 133 WHERE status = 'pending' 134 AND scheduled_for <= $1 135 AND attempts < max_attempts 136 ORDER BY scheduled_for ASC 137 LIMIT $2 138 FOR UPDATE SKIP LOCKED 139 ) 140 RETURNING 141 id, user_id, 142 channel as "channel: CommsChannel", 143 comms_type as "comms_type: super::types::CommsType", 144 status as "status: CommsStatus", 145 recipient, subject, body, metadata, 146 attempts, max_attempts, last_error, 147 created_at, updated_at, scheduled_for, processed_at 148 "#, 149 now, 150 self.batch_size 151 ) 152 .fetch_all(&self.db) 153 .await 154 } 155 156 async fn process_item(&self, item: QueuedComms) { 157 let comms_id = item.id; 158 let channel = item.channel; 159 let result = match self.senders.get(&channel) { 160 Some(sender) => sender.send(&item).await, 161 None => { 162 warn!( 163 comms_id = %comms_id, 164 channel = ?channel, 165 "No sender registered for channel" 166 ); 167 Err(SendError::NotConfigured(channel)) 168 } 169 }; 170 match result { 171 Ok(()) => { 172 debug!(comms_id = %comms_id, "Comms sent successfully"); 173 if let Err(e) = self.mark_sent(comms_id).await { 174 error!( 175 comms_id = %comms_id, 176 error = %e, 177 "Failed to mark comms as sent" 178 ); 179 } 180 } 181 Err(e) => { 182 let error_msg = e.to_string(); 183 warn!( 184 comms_id = %comms_id, 185 error = %error_msg, 186 "Failed to send comms" 187 ); 188 if let Err(db_err) = self.mark_failed(comms_id, &error_msg).await { 189 error!( 190 comms_id = %comms_id, 191 error = %db_err, 192 "Failed to mark comms as failed" 193 ); 194 } 195 } 196 } 197 } 198 199 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 200 sqlx::query!( 201 r#" 202 UPDATE comms_queue 203 SET status = 'sent', processed_at = NOW(), updated_at = NOW() 204 WHERE id = $1 205 "#, 206 id 207 ) 208 .execute(&self.db) 209 .await?; 210 Ok(()) 211 } 212 213 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 214 sqlx::query!( 215 r#" 216 UPDATE comms_queue 217 SET 218 status = CASE 219 WHEN attempts + 1 >= max_attempts THEN 'failed'::comms_status 220 ELSE 'pending'::comms_status 221 END, 222 attempts = attempts + 1, 223 last_error = $2, 224 updated_at = NOW(), 225 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1)) 226 WHERE id = $1 227 "#, 228 id, 229 error 230 ) 231 .execute(&self.db) 232 .await?; 233 Ok(()) 234 } 235} 236 237pub async fn enqueue_comms(db: &PgPool, item: NewComms) -> Result<Uuid, sqlx::Error> { 238 sqlx::query_scalar!( 239 r#" 240 INSERT INTO comms_queue 241 (user_id, channel, comms_type, recipient, subject, body, metadata) 242 VALUES ($1, $2, $3, $4, $5, $6, $7) 243 RETURNING id 244 "#, 245 item.user_id, 246 item.channel as CommsChannel, 247 item.comms_type as super::types::CommsType, 248 item.recipient, 249 item.subject, 250 item.body, 251 item.metadata 252 ) 253 .fetch_one(db) 254 .await 255} 256 257pub struct UserCommsPrefs { 258 pub channel: CommsChannel, 259 pub email: Option<String>, 260 pub handle: String, 261 pub locale: String, 262} 263 264pub async fn get_user_comms_prefs( 265 db: &PgPool, 266 user_id: Uuid, 267) -> Result<UserCommsPrefs, sqlx::Error> { 268 let row = sqlx::query!( 269 r#" 270 SELECT 271 email, 272 handle, 273 preferred_comms_channel as "channel: CommsChannel", 274 preferred_locale 275 FROM users 276 WHERE id = $1 277 "#, 278 user_id 279 ) 280 .fetch_one(db) 281 .await?; 282 Ok(UserCommsPrefs { 283 channel: row.channel, 284 email: row.email, 285 handle: row.handle, 286 locale: row.preferred_locale.unwrap_or_else(|| "en".to_string()), 287 }) 288} 289 290pub async fn enqueue_welcome( 291 db: &PgPool, 292 user_id: Uuid, 293 hostname: &str, 294) -> Result<Uuid, sqlx::Error> { 295 let prefs = get_user_comms_prefs(db, user_id).await?; 296 let strings = get_strings(&prefs.locale); 297 let body = format_message( 298 strings.welcome_body, 299 &[("hostname", hostname), ("handle", &prefs.handle)], 300 ); 301 let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]); 302 enqueue_comms( 303 db, 304 NewComms::new( 305 user_id, 306 prefs.channel, 307 super::types::CommsType::Welcome, 308 prefs.email.clone().unwrap_or_default(), 309 Some(subject), 310 body, 311 ), 312 ) 313 .await 314} 315 316pub async fn enqueue_password_reset( 317 db: &PgPool, 318 user_id: Uuid, 319 code: &str, 320 hostname: &str, 321) -> Result<Uuid, sqlx::Error> { 322 let prefs = get_user_comms_prefs(db, user_id).await?; 323 let strings = get_strings(&prefs.locale); 324 let body = format_message( 325 strings.password_reset_body, 326 &[("handle", &prefs.handle), ("code", code)], 327 ); 328 let subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]); 329 enqueue_comms( 330 db, 331 NewComms::new( 332 user_id, 333 prefs.channel, 334 super::types::CommsType::PasswordReset, 335 prefs.email.clone().unwrap_or_default(), 336 Some(subject), 337 body, 338 ), 339 ) 340 .await 341} 342 343pub async fn enqueue_email_update( 344 db: &PgPool, 345 user_id: Uuid, 346 new_email: &str, 347 handle: &str, 348 code: &str, 349 hostname: &str, 350) -> Result<Uuid, sqlx::Error> { 351 let prefs = get_user_comms_prefs(db, user_id).await?; 352 let strings = get_strings(&prefs.locale); 353 let encoded_email = urlencoding::encode(new_email); 354 let encoded_token = urlencoding::encode(code); 355 let verify_page = format!("https://{}/app/verify", hostname); 356 let verify_link = format!( 357 "https://{}/app/verify?token={}&identifier={}", 358 hostname, encoded_token, encoded_email 359 ); 360 let body = format_message( 361 strings.email_update_body, 362 &[ 363 ("handle", handle), 364 ("code", code), 365 ("verify_page", &verify_page), 366 ("verify_link", &verify_link), 367 ], 368 ); 369 let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]); 370 enqueue_comms( 371 db, 372 NewComms::email( 373 user_id, 374 super::types::CommsType::EmailUpdate, 375 new_email.to_string(), 376 subject, 377 body, 378 ), 379 ) 380 .await 381} 382 383pub async fn enqueue_email_update_token( 384 db: &PgPool, 385 user_id: Uuid, 386 code: &str, 387 hostname: &str, 388) -> Result<Uuid, sqlx::Error> { 389 let prefs = get_user_comms_prefs(db, user_id).await?; 390 let strings = get_strings(&prefs.locale); 391 let current_email = prefs.email.clone().unwrap_or_default(); 392 let verify_page = format!("https://{}/app/verify?type=email-update", hostname); 393 let verify_link = format!( 394 "https://{}/app/verify?type=email-update&token={}", 395 hostname, 396 urlencoding::encode(code) 397 ); 398 let body = format_message( 399 strings.email_update_body, 400 &[ 401 ("handle", &prefs.handle), 402 ("code", code), 403 ("verify_page", &verify_page), 404 ("verify_link", &verify_link), 405 ], 406 ); 407 let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]); 408 enqueue_comms( 409 db, 410 NewComms::email( 411 user_id, 412 super::types::CommsType::EmailUpdate, 413 current_email, 414 subject, 415 body, 416 ), 417 ) 418 .await 419} 420 421pub async fn enqueue_account_deletion( 422 db: &PgPool, 423 user_id: Uuid, 424 code: &str, 425 hostname: &str, 426) -> Result<Uuid, sqlx::Error> { 427 let prefs = get_user_comms_prefs(db, user_id).await?; 428 let strings = get_strings(&prefs.locale); 429 let body = format_message( 430 strings.account_deletion_body, 431 &[("handle", &prefs.handle), ("code", code)], 432 ); 433 let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]); 434 enqueue_comms( 435 db, 436 NewComms::new( 437 user_id, 438 prefs.channel, 439 super::types::CommsType::AccountDeletion, 440 prefs.email.clone().unwrap_or_default(), 441 Some(subject), 442 body, 443 ), 444 ) 445 .await 446} 447 448pub async fn enqueue_plc_operation( 449 db: &PgPool, 450 user_id: Uuid, 451 token: &str, 452 hostname: &str, 453) -> Result<Uuid, sqlx::Error> { 454 let prefs = get_user_comms_prefs(db, user_id).await?; 455 let strings = get_strings(&prefs.locale); 456 let body = format_message( 457 strings.plc_operation_body, 458 &[("handle", &prefs.handle), ("token", token)], 459 ); 460 let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]); 461 enqueue_comms( 462 db, 463 NewComms::new( 464 user_id, 465 prefs.channel, 466 super::types::CommsType::PlcOperation, 467 prefs.email.clone().unwrap_or_default(), 468 Some(subject), 469 body, 470 ), 471 ) 472 .await 473} 474 475pub async fn enqueue_2fa_code( 476 db: &PgPool, 477 user_id: Uuid, 478 code: &str, 479 hostname: &str, 480) -> Result<Uuid, sqlx::Error> { 481 let prefs = get_user_comms_prefs(db, user_id).await?; 482 let strings = get_strings(&prefs.locale); 483 let body = format_message( 484 strings.two_factor_code_body, 485 &[("handle", &prefs.handle), ("code", code)], 486 ); 487 let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]); 488 enqueue_comms( 489 db, 490 NewComms::new( 491 user_id, 492 prefs.channel, 493 super::types::CommsType::TwoFactorCode, 494 prefs.email.clone().unwrap_or_default(), 495 Some(subject), 496 body, 497 ), 498 ) 499 .await 500} 501 502pub async fn enqueue_passkey_recovery( 503 db: &PgPool, 504 user_id: Uuid, 505 recovery_url: &str, 506 hostname: &str, 507) -> Result<Uuid, sqlx::Error> { 508 let prefs = get_user_comms_prefs(db, user_id).await?; 509 let strings = get_strings(&prefs.locale); 510 let body = format_message( 511 strings.passkey_recovery_body, 512 &[("handle", &prefs.handle), ("url", recovery_url)], 513 ); 514 let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]); 515 enqueue_comms( 516 db, 517 NewComms::new( 518 user_id, 519 prefs.channel, 520 super::types::CommsType::PasskeyRecovery, 521 prefs.email.clone().unwrap_or_default(), 522 Some(subject), 523 body, 524 ), 525 ) 526 .await 527} 528 529pub fn channel_display_name(channel: CommsChannel) -> &'static str { 530 match channel { 531 CommsChannel::Email => "email", 532 CommsChannel::Discord => "Discord", 533 CommsChannel::Telegram => "Telegram", 534 CommsChannel::Signal => "Signal", 535 } 536} 537 538pub async fn enqueue_signup_verification( 539 db: &PgPool, 540 user_id: Uuid, 541 channel: &str, 542 recipient: &str, 543 code: &str, 544 locale: Option<&str>, 545) -> Result<Uuid, sqlx::Error> { 546 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 547 let comms_channel = match channel { 548 "email" => CommsChannel::Email, 549 "discord" => CommsChannel::Discord, 550 "telegram" => CommsChannel::Telegram, 551 "signal" => CommsChannel::Signal, 552 _ => CommsChannel::Email, 553 }; 554 let strings = get_strings(locale.unwrap_or("en")); 555 let (verify_page, verify_link) = if comms_channel == CommsChannel::Email { 556 let encoded_email = urlencoding::encode(recipient); 557 let encoded_token = urlencoding::encode(code); 558 ( 559 format!("https://{}/app/verify", hostname), 560 format!( 561 "https://{}/app/verify?token={}&identifier={}", 562 hostname, encoded_token, encoded_email 563 ), 564 ) 565 } else { 566 (String::new(), String::new()) 567 }; 568 let body = format_message( 569 strings.signup_verification_body, 570 &[ 571 ("code", code), 572 ("hostname", &hostname), 573 ("verify_page", &verify_page), 574 ("verify_link", &verify_link), 575 ], 576 ); 577 let subject = match comms_channel { 578 CommsChannel::Email => Some(format_message( 579 strings.signup_verification_subject, 580 &[("hostname", &hostname)], 581 )), 582 _ => None, 583 }; 584 enqueue_comms( 585 db, 586 NewComms::new( 587 user_id, 588 comms_channel, 589 super::types::CommsType::EmailVerification, 590 recipient.to_string(), 591 subject, 592 body, 593 ), 594 ) 595 .await 596} 597 598pub async fn enqueue_migration_verification( 599 db: &PgPool, 600 user_id: Uuid, 601 email: &str, 602 token: &str, 603 hostname: &str, 604) -> Result<Uuid, sqlx::Error> { 605 let prefs = get_user_comms_prefs(db, user_id).await?; 606 let strings = get_strings(&prefs.locale); 607 let encoded_email = urlencoding::encode(email); 608 let encoded_token = urlencoding::encode(token); 609 let verify_page = format!("https://{}/app/verify", hostname); 610 let verify_link = format!( 611 "https://{}/app/verify?token={}&identifier={}", 612 hostname, encoded_token, encoded_email 613 ); 614 let body = format_message( 615 strings.migration_verification_body, 616 &[ 617 ("code", token), 618 ("hostname", hostname), 619 ("verify_page", &verify_page), 620 ("verify_link", &verify_link), 621 ], 622 ); 623 let subject = format_message( 624 strings.migration_verification_subject, 625 &[("hostname", hostname)], 626 ); 627 enqueue_comms( 628 db, 629 NewComms::email( 630 user_id, 631 super::types::CommsType::MigrationVerification, 632 email.to_string(), 633 subject, 634 body, 635 ), 636 ) 637 .await 638} 639 640pub async fn queue_legacy_login_notification( 641 db: &PgPool, 642 user_id: Uuid, 643 hostname: &str, 644 client_ip: &str, 645 channel: CommsChannel, 646) -> Result<Uuid, sqlx::Error> { 647 let prefs = get_user_comms_prefs(db, user_id).await?; 648 let strings = get_strings(&prefs.locale); 649 let timestamp = chrono::Utc::now() 650 .format("%Y-%m-%d %H:%M:%S UTC") 651 .to_string(); 652 let body = format_message( 653 strings.legacy_login_body, 654 &[ 655 ("handle", &prefs.handle), 656 ("timestamp", &timestamp), 657 ("ip", client_ip), 658 ("hostname", hostname), 659 ], 660 ); 661 let subject = format_message(strings.legacy_login_subject, &[("hostname", hostname)]); 662 enqueue_comms( 663 db, 664 NewComms::new( 665 user_id, 666 channel, 667 super::types::CommsType::LegacyLoginAlert, 668 prefs.email.clone().unwrap_or_default(), 669 Some(subject), 670 body, 671 ), 672 ) 673 .await 674}