This repository has no description.
1use std::collections::HashMap; 2use std::sync::Arc; 3use std::time::Duration; 4 5use chrono::Utc; 6use sqlx::PgPool; 7use tokio::sync::watch; 8use tokio::time::interval; 9use tracing::{debug, error, info, warn}; 10use uuid::Uuid; 11 12use super::locale::{format_message, get_strings}; 13use super::sender::{CommsSender, SendError}; 14use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms}; 15 16pub struct CommsService { 17 db: PgPool, 18 senders: HashMap<CommsChannel, Arc<dyn CommsSender>>, 19 poll_interval: Duration, 20 batch_size: i64, 21} 22 23impl CommsService { 24 pub fn new(db: PgPool) -> Self { 25 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") 26 .ok() 27 .and_then(|v| v.parse().ok()) 28 .unwrap_or(1000); 29 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE") 30 .ok() 31 .and_then(|v| v.parse().ok()) 32 .unwrap_or(100); 33 Self { 34 db, 35 senders: HashMap::new(), 36 poll_interval: Duration::from_millis(poll_interval_ms), 37 batch_size, 38 } 39 } 40 41 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 42 self.poll_interval = interval; 43 self 44 } 45 46 pub fn with_batch_size(mut self, size: i64) -> Self { 47 self.batch_size = size; 48 self 49 } 50 51 pub fn register_sender<S: CommsSender + 'static>(mut self, sender: S) -> Self { 52 self.senders.insert(sender.channel(), Arc::new(sender)); 53 self 54 } 55 56 pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, sqlx::Error> { 57 let id = sqlx::query_scalar!( 58 r#" 59 INSERT INTO comms_queue 60 (user_id, channel, comms_type, recipient, subject, body, metadata) 61 VALUES ($1, $2, $3, $4, $5, $6, $7) 62 RETURNING id 63 "#, 64 item.user_id, 65 item.channel as CommsChannel, 66 item.comms_type as super::types::CommsType, 67 item.recipient, 68 item.subject, 69 item.body, 70 item.metadata 71 ) 72 .fetch_one(&self.db) 73 .await?; 74 debug!(comms_id = %id, "Comms enqueued"); 75 Ok(id) 76 } 77 78 pub fn has_senders(&self) -> bool { 79 
!self.senders.is_empty() 80 } 81 82 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 83 if self.senders.is_empty() { 84 warn!( 85 "Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured." 86 ); 87 } 88 info!( 89 poll_interval_secs = self.poll_interval.as_secs(), 90 batch_size = self.batch_size, 91 channels = ?self.senders.keys().collect::<Vec<_>>(), 92 "Starting comms service" 93 ); 94 let mut ticker = interval(self.poll_interval); 95 loop { 96 tokio::select! { 97 _ = ticker.tick() => { 98 if let Err(e) = self.process_batch().await { 99 error!(error = %e, "Failed to process comms batch"); 100 } 101 } 102 _ = shutdown.changed() => { 103 if *shutdown.borrow() { 104 info!("Comms service shutting down"); 105 break; 106 } 107 } 108 } 109 } 110 } 111 112 async fn process_batch(&self) -> Result<(), sqlx::Error> { 113 let items = self.fetch_pending().await?; 114 if items.is_empty() { 115 return Ok(()); 116 } 117 debug!(count = items.len(), "Processing comms batch"); 118 for item in items { 119 self.process_item(item).await; 120 } 121 Ok(()) 122 } 123 124 async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, sqlx::Error> { 125 let now = Utc::now(); 126 sqlx::query_as!( 127 QueuedComms, 128 r#" 129 UPDATE comms_queue 130 SET status = 'processing', updated_at = NOW() 131 WHERE id IN ( 132 SELECT id FROM comms_queue 133 WHERE status = 'pending' 134 AND scheduled_for <= $1 135 AND attempts < max_attempts 136 ORDER BY scheduled_for ASC 137 LIMIT $2 138 FOR UPDATE SKIP LOCKED 139 ) 140 RETURNING 141 id, user_id, 142 channel as "channel: CommsChannel", 143 comms_type as "comms_type: super::types::CommsType", 144 status as "status: CommsStatus", 145 recipient, subject, body, metadata, 146 attempts, max_attempts, last_error, 147 created_at, updated_at, scheduled_for, processed_at 148 "#, 149 now, 150 self.batch_size 151 ) 152 .fetch_all(&self.db) 153 .await 154 } 155 156 async fn process_item(&self, 
item: QueuedComms) { 157 let comms_id = item.id; 158 let channel = item.channel; 159 let result = match self.senders.get(&channel) { 160 Some(sender) => sender.send(&item).await, 161 None => { 162 warn!( 163 comms_id = %comms_id, 164 channel = ?channel, 165 "No sender registered for channel" 166 ); 167 Err(SendError::NotConfigured(channel)) 168 } 169 }; 170 match result { 171 Ok(()) => { 172 debug!(comms_id = %comms_id, "Comms sent successfully"); 173 if let Err(e) = self.mark_sent(comms_id).await { 174 error!( 175 comms_id = %comms_id, 176 error = %e, 177 "Failed to mark comms as sent" 178 ); 179 } 180 } 181 Err(e) => { 182 let error_msg = e.to_string(); 183 warn!( 184 comms_id = %comms_id, 185 error = %error_msg, 186 "Failed to send comms" 187 ); 188 if let Err(db_err) = self.mark_failed(comms_id, &error_msg).await { 189 error!( 190 comms_id = %comms_id, 191 error = %db_err, 192 "Failed to mark comms as failed" 193 ); 194 } 195 } 196 } 197 } 198 199 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 200 sqlx::query!( 201 r#" 202 UPDATE comms_queue 203 SET status = 'sent', processed_at = NOW(), updated_at = NOW() 204 WHERE id = $1 205 "#, 206 id 207 ) 208 .execute(&self.db) 209 .await?; 210 Ok(()) 211 } 212 213 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 214 sqlx::query!( 215 r#" 216 UPDATE comms_queue 217 SET 218 status = CASE 219 WHEN attempts + 1 >= max_attempts THEN 'failed'::comms_status 220 ELSE 'pending'::comms_status 221 END, 222 attempts = attempts + 1, 223 last_error = $2, 224 updated_at = NOW(), 225 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1)) 226 WHERE id = $1 227 "#, 228 id, 229 error 230 ) 231 .execute(&self.db) 232 .await?; 233 Ok(()) 234 } 235} 236 237pub async fn enqueue_comms(db: &PgPool, item: NewComms) -> Result<Uuid, sqlx::Error> { 238 sqlx::query_scalar!( 239 r#" 240 INSERT INTO comms_queue 241 (user_id, channel, comms_type, recipient, subject, body, metadata) 242 VALUES ($1, 
$2, $3, $4, $5, $6, $7) 243 RETURNING id 244 "#, 245 item.user_id, 246 item.channel as CommsChannel, 247 item.comms_type as super::types::CommsType, 248 item.recipient, 249 item.subject, 250 item.body, 251 item.metadata 252 ) 253 .fetch_one(db) 254 .await 255} 256 257pub struct UserCommsPrefs { 258 pub channel: CommsChannel, 259 pub email: Option<String>, 260 pub handle: String, 261 pub locale: String, 262} 263 264pub async fn get_user_comms_prefs( 265 db: &PgPool, 266 user_id: Uuid, 267) -> Result<UserCommsPrefs, sqlx::Error> { 268 let row = sqlx::query!( 269 r#" 270 SELECT 271 email, 272 handle, 273 preferred_comms_channel as "channel: CommsChannel", 274 preferred_locale 275 FROM users 276 WHERE id = $1 277 "#, 278 user_id 279 ) 280 .fetch_one(db) 281 .await?; 282 Ok(UserCommsPrefs { 283 channel: row.channel, 284 email: row.email, 285 handle: row.handle, 286 locale: row.preferred_locale.unwrap_or_else(|| "en".to_string()), 287 }) 288} 289 290pub async fn enqueue_welcome( 291 db: &PgPool, 292 user_id: Uuid, 293 hostname: &str, 294) -> Result<Uuid, sqlx::Error> { 295 let prefs = get_user_comms_prefs(db, user_id).await?; 296 let strings = get_strings(&prefs.locale); 297 let body = format_message( 298 strings.welcome_body, 299 &[("hostname", hostname), ("handle", &prefs.handle)], 300 ); 301 let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]); 302 enqueue_comms( 303 db, 304 NewComms::new( 305 user_id, 306 prefs.channel, 307 super::types::CommsType::Welcome, 308 prefs.email.clone().unwrap_or_default(), 309 Some(subject), 310 body, 311 ), 312 ) 313 .await 314} 315 316pub async fn enqueue_password_reset( 317 db: &PgPool, 318 user_id: Uuid, 319 code: &str, 320 hostname: &str, 321) -> Result<Uuid, sqlx::Error> { 322 let prefs = get_user_comms_prefs(db, user_id).await?; 323 let strings = get_strings(&prefs.locale); 324 let body = format_message( 325 strings.password_reset_body, 326 &[("handle", &prefs.handle), ("code", code)], 327 ); 328 let 
subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]); 329 enqueue_comms( 330 db, 331 NewComms::new( 332 user_id, 333 prefs.channel, 334 super::types::CommsType::PasswordReset, 335 prefs.email.clone().unwrap_or_default(), 336 Some(subject), 337 body, 338 ), 339 ) 340 .await 341} 342 343pub async fn enqueue_email_update( 344 db: &PgPool, 345 user_id: Uuid, 346 new_email: &str, 347 handle: &str, 348 code: &str, 349 hostname: &str, 350) -> Result<Uuid, sqlx::Error> { 351 let prefs = get_user_comms_prefs(db, user_id).await?; 352 let strings = get_strings(&prefs.locale); 353 let encoded_email = urlencoding::encode(new_email); 354 let encoded_token = urlencoding::encode(code); 355 let verify_page = format!("https://{}/#/verify", hostname); 356 let verify_link = format!( 357 "https://{}/#/verify?token={}&identifier={}", 358 hostname, encoded_token, encoded_email 359 ); 360 let body = format_message( 361 strings.email_update_body, 362 &[ 363 ("handle", handle), 364 ("code", code), 365 ("verify_page", &verify_page), 366 ("verify_link", &verify_link), 367 ], 368 ); 369 let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]); 370 enqueue_comms( 371 db, 372 NewComms::email( 373 user_id, 374 super::types::CommsType::EmailUpdate, 375 new_email.to_string(), 376 subject, 377 body, 378 ), 379 ) 380 .await 381} 382 383pub async fn enqueue_account_deletion( 384 db: &PgPool, 385 user_id: Uuid, 386 code: &str, 387 hostname: &str, 388) -> Result<Uuid, sqlx::Error> { 389 let prefs = get_user_comms_prefs(db, user_id).await?; 390 let strings = get_strings(&prefs.locale); 391 let body = format_message( 392 strings.account_deletion_body, 393 &[("handle", &prefs.handle), ("code", code)], 394 ); 395 let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]); 396 enqueue_comms( 397 db, 398 NewComms::new( 399 user_id, 400 prefs.channel, 401 super::types::CommsType::AccountDeletion, 402 
prefs.email.clone().unwrap_or_default(), 403 Some(subject), 404 body, 405 ), 406 ) 407 .await 408} 409 410pub async fn enqueue_plc_operation( 411 db: &PgPool, 412 user_id: Uuid, 413 token: &str, 414 hostname: &str, 415) -> Result<Uuid, sqlx::Error> { 416 let prefs = get_user_comms_prefs(db, user_id).await?; 417 let strings = get_strings(&prefs.locale); 418 let body = format_message( 419 strings.plc_operation_body, 420 &[("handle", &prefs.handle), ("token", token)], 421 ); 422 let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]); 423 enqueue_comms( 424 db, 425 NewComms::new( 426 user_id, 427 prefs.channel, 428 super::types::CommsType::PlcOperation, 429 prefs.email.clone().unwrap_or_default(), 430 Some(subject), 431 body, 432 ), 433 ) 434 .await 435} 436 437pub async fn enqueue_2fa_code( 438 db: &PgPool, 439 user_id: Uuid, 440 code: &str, 441 hostname: &str, 442) -> Result<Uuid, sqlx::Error> { 443 let prefs = get_user_comms_prefs(db, user_id).await?; 444 let strings = get_strings(&prefs.locale); 445 let body = format_message( 446 strings.two_factor_code_body, 447 &[("handle", &prefs.handle), ("code", code)], 448 ); 449 let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]); 450 enqueue_comms( 451 db, 452 NewComms::new( 453 user_id, 454 prefs.channel, 455 super::types::CommsType::TwoFactorCode, 456 prefs.email.clone().unwrap_or_default(), 457 Some(subject), 458 body, 459 ), 460 ) 461 .await 462} 463 464pub async fn enqueue_passkey_recovery( 465 db: &PgPool, 466 user_id: Uuid, 467 recovery_url: &str, 468 hostname: &str, 469) -> Result<Uuid, sqlx::Error> { 470 let prefs = get_user_comms_prefs(db, user_id).await?; 471 let strings = get_strings(&prefs.locale); 472 let body = format_message( 473 strings.passkey_recovery_body, 474 &[("handle", &prefs.handle), ("url", recovery_url)], 475 ); 476 let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]); 477 enqueue_comms( 478 
db, 479 NewComms::new( 480 user_id, 481 prefs.channel, 482 super::types::CommsType::PasskeyRecovery, 483 prefs.email.clone().unwrap_or_default(), 484 Some(subject), 485 body, 486 ), 487 ) 488 .await 489} 490 491pub fn channel_display_name(channel: CommsChannel) -> &'static str { 492 match channel { 493 CommsChannel::Email => "email", 494 CommsChannel::Discord => "Discord", 495 CommsChannel::Telegram => "Telegram", 496 CommsChannel::Signal => "Signal", 497 } 498} 499 500pub async fn enqueue_signup_verification( 501 db: &PgPool, 502 user_id: Uuid, 503 channel: &str, 504 recipient: &str, 505 code: &str, 506 locale: Option<&str>, 507) -> Result<Uuid, sqlx::Error> { 508 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 509 let comms_channel = match channel { 510 "email" => CommsChannel::Email, 511 "discord" => CommsChannel::Discord, 512 "telegram" => CommsChannel::Telegram, 513 "signal" => CommsChannel::Signal, 514 _ => CommsChannel::Email, 515 }; 516 let strings = get_strings(locale.unwrap_or("en")); 517 let (verify_page, verify_link) = if comms_channel == CommsChannel::Email { 518 let encoded_email = urlencoding::encode(recipient); 519 let encoded_token = urlencoding::encode(code); 520 ( 521 format!("https://{}/#/verify", hostname), 522 format!( 523 "https://{}/#/verify?token={}&identifier={}", 524 hostname, encoded_token, encoded_email 525 ), 526 ) 527 } else { 528 (String::new(), String::new()) 529 }; 530 let body = format_message( 531 strings.signup_verification_body, 532 &[ 533 ("code", code), 534 ("hostname", &hostname), 535 ("verify_page", &verify_page), 536 ("verify_link", &verify_link), 537 ], 538 ); 539 let subject = match comms_channel { 540 CommsChannel::Email => Some(format_message( 541 strings.signup_verification_subject, 542 &[("hostname", &hostname)], 543 )), 544 _ => None, 545 }; 546 enqueue_comms( 547 db, 548 NewComms::new( 549 user_id, 550 comms_channel, 551 super::types::CommsType::EmailVerification, 552 
recipient.to_string(), 553 subject, 554 body, 555 ), 556 ) 557 .await 558} 559 560pub async fn enqueue_migration_verification( 561 db: &PgPool, 562 user_id: Uuid, 563 email: &str, 564 token: &str, 565 hostname: &str, 566) -> Result<Uuid, sqlx::Error> { 567 let prefs = get_user_comms_prefs(db, user_id).await?; 568 let strings = get_strings(&prefs.locale); 569 let encoded_email = urlencoding::encode(email); 570 let encoded_token = urlencoding::encode(token); 571 let verify_page = format!("https://{}/#/verify", hostname); 572 let verify_link = format!( 573 "https://{}/#/verify?token={}&identifier={}", 574 hostname, encoded_token, encoded_email 575 ); 576 let body = format_message( 577 strings.migration_verification_body, 578 &[ 579 ("code", token), 580 ("hostname", hostname), 581 ("verify_page", &verify_page), 582 ("verify_link", &verify_link), 583 ], 584 ); 585 let subject = format_message( 586 strings.migration_verification_subject, 587 &[("hostname", hostname)], 588 ); 589 enqueue_comms( 590 db, 591 NewComms::email( 592 user_id, 593 super::types::CommsType::MigrationVerification, 594 email.to_string(), 595 subject, 596 body, 597 ), 598 ) 599 .await 600} 601 602pub async fn queue_legacy_login_notification( 603 db: &PgPool, 604 user_id: Uuid, 605 hostname: &str, 606 client_ip: &str, 607 channel: CommsChannel, 608) -> Result<Uuid, sqlx::Error> { 609 let prefs = get_user_comms_prefs(db, user_id).await?; 610 let strings = get_strings(&prefs.locale); 611 let timestamp = chrono::Utc::now() 612 .format("%Y-%m-%d %H:%M:%S UTC") 613 .to_string(); 614 let body = format_message( 615 strings.legacy_login_body, 616 &[ 617 ("handle", &prefs.handle), 618 ("timestamp", &timestamp), 619 ("ip", client_ip), 620 ("hostname", hostname), 621 ], 622 ); 623 let subject = format_message(strings.legacy_login_subject, &[("hostname", hostname)]); 624 enqueue_comms( 625 db, 626 NewComms::new( 627 user_id, 628 channel, 629 super::types::CommsType::LegacyLoginAlert, 630 
prefs.email.clone().unwrap_or_default(), 631 Some(subject), 632 body, 633 ), 634 ) 635 .await 636}