this repo has no description
1use std::collections::HashMap; 2use std::sync::Arc; 3use std::time::Duration; 4 5use chrono::Utc; 6use sqlx::PgPool; 7use tokio::sync::watch; 8use tokio::time::interval; 9use tracing::{debug, error, info, warn}; 10use uuid::Uuid; 11 12use super::sender::{CommsSender, SendError}; 13use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms}; 14 15pub struct CommsService { 16 db: PgPool, 17 senders: HashMap<CommsChannel, Arc<dyn CommsSender>>, 18 poll_interval: Duration, 19 batch_size: i64, 20} 21 22impl CommsService { 23 pub fn new(db: PgPool) -> Self { 24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") 25 .ok() 26 .and_then(|v| v.parse().ok()) 27 .unwrap_or(1000); 28 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE") 29 .ok() 30 .and_then(|v| v.parse().ok()) 31 .unwrap_or(100); 32 Self { 33 db, 34 senders: HashMap::new(), 35 poll_interval: Duration::from_millis(poll_interval_ms), 36 batch_size, 37 } 38 } 39 40 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 41 self.poll_interval = interval; 42 self 43 } 44 45 pub fn with_batch_size(mut self, size: i64) -> Self { 46 self.batch_size = size; 47 self 48 } 49 50 pub fn register_sender<S: CommsSender + 'static>(mut self, sender: S) -> Self { 51 self.senders.insert(sender.channel(), Arc::new(sender)); 52 self 53 } 54 55 pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, sqlx::Error> { 56 let id = sqlx::query_scalar!( 57 r#" 58 INSERT INTO comms_queue 59 (user_id, channel, comms_type, recipient, subject, body, metadata) 60 VALUES ($1, $2, $3, $4, $5, $6, $7) 61 RETURNING id 62 "#, 63 item.user_id, 64 item.channel as CommsChannel, 65 item.comms_type as super::types::CommsType, 66 item.recipient, 67 item.subject, 68 item.body, 69 item.metadata 70 ) 71 .fetch_one(&self.db) 72 .await?; 73 debug!(comms_id = %id, "Comms enqueued"); 74 Ok(id) 75 } 76 77 pub fn has_senders(&self) -> bool { 78 !self.senders.is_empty() 79 } 80 81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 82 if self.senders.is_empty() { 83 warn!( 84 "Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured." 85 ); 86 } 87 info!( 88 poll_interval_secs = self.poll_interval.as_secs(), 89 batch_size = self.batch_size, 90 channels = ?self.senders.keys().collect::<Vec<_>>(), 91 "Starting comms service" 92 ); 93 let mut ticker = interval(self.poll_interval); 94 loop { 95 tokio::select! { 96 _ = ticker.tick() => { 97 if let Err(e) = self.process_batch().await { 98 error!(error = %e, "Failed to process comms batch"); 99 } 100 } 101 _ = shutdown.changed() => { 102 if *shutdown.borrow() { 103 info!("Comms service shutting down"); 104 break; 105 } 106 } 107 } 108 } 109 } 110 111 async fn process_batch(&self) -> Result<(), sqlx::Error> { 112 let items = self.fetch_pending().await?; 113 if items.is_empty() { 114 return Ok(()); 115 } 116 debug!(count = items.len(), "Processing comms batch"); 117 for item in items { 118 self.process_item(item).await; 119 } 120 Ok(()) 121 } 122 123 async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, sqlx::Error> { 124 let now = Utc::now(); 125 sqlx::query_as!( 126 QueuedComms, 127 r#" 128 UPDATE comms_queue 129 SET status = 'processing', updated_at = NOW() 130 WHERE id IN ( 131 SELECT id FROM comms_queue 132 WHERE status = 'pending' 133 AND scheduled_for <= $1 134 AND attempts < max_attempts 135 ORDER BY scheduled_for ASC 136 LIMIT $2 137 FOR UPDATE SKIP LOCKED 138 ) 139 RETURNING 140 id, user_id, 141 channel as "channel: CommsChannel", 142 comms_type as "comms_type: super::types::CommsType", 143 status as "status: CommsStatus", 144 recipient, subject, body, metadata, 145 attempts, max_attempts, last_error, 146 created_at, updated_at, scheduled_for, processed_at 147 "#, 148 now, 149 self.batch_size 150 ) 151 .fetch_all(&self.db) 152 .await 153 } 154 155 async fn process_item(&self, item: QueuedComms) { 156 let comms_id = item.id; 157 let channel = item.channel; 158 let result = match self.senders.get(&channel) { 159 Some(sender) => sender.send(&item).await, 160 None => { 161 warn!( 162 comms_id = %comms_id, 163 channel = ?channel, 164 "No sender registered for channel" 165 ); 166 Err(SendError::NotConfigured(channel)) 167 } 168 }; 169 match result { 170 Ok(()) => { 171 debug!(comms_id = %comms_id, "Comms sent successfully"); 172 if let Err(e) = self.mark_sent(comms_id).await { 173 error!( 174 comms_id = %comms_id, 175 error = %e, 176 "Failed to mark comms as sent" 177 ); 178 } 179 } 180 Err(e) => { 181 let error_msg = e.to_string(); 182 warn!( 183 comms_id = %comms_id, 184 error = %error_msg, 185 "Failed to send comms" 186 ); 187 if let Err(db_err) = self.mark_failed(comms_id, &error_msg).await { 188 error!( 189 comms_id = %comms_id, 190 error = %db_err, 191 "Failed to mark comms as failed" 192 ); 193 } 194 } 195 } 196 } 197 198 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 199 sqlx::query!( 200 r#" 201 UPDATE comms_queue 202 SET status = 'sent', processed_at = NOW(), updated_at = NOW() 203 WHERE id = $1 204 "#, 205 id 206 ) 207 .execute(&self.db) 208 .await?; 209 Ok(()) 210 } 211 212 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 213 sqlx::query!( 214 r#" 215 UPDATE comms_queue 216 SET 217 status = CASE 218 WHEN attempts + 1 >= max_attempts THEN 'failed'::comms_status 219 ELSE 'pending'::comms_status 220 END, 221 attempts = attempts + 1, 222 last_error = $2, 223 updated_at = NOW(), 224 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1)) 225 WHERE id = $1 226 "#, 227 id, 228 error 229 ) 230 .execute(&self.db) 231 .await?; 232 Ok(()) 233 } 234} 235 236pub async fn enqueue_comms(db: &PgPool, item: NewComms) -> Result<Uuid, sqlx::Error> { 237 sqlx::query_scalar!( 238 r#" 239 INSERT INTO comms_queue 240 (user_id, channel, comms_type, recipient, subject, body, metadata) 241 VALUES ($1, $2, $3, $4, $5, $6, $7) 242 RETURNING id 243 "#, 244 item.user_id, 245 item.channel as CommsChannel, 246 item.comms_type as super::types::CommsType, 247 item.recipient, 248 item.subject, 249 item.body, 250 item.metadata 251 ) 252 .fetch_one(db) 253 .await 254} 255 256pub struct UserCommsPrefs { 257 pub channel: CommsChannel, 258 pub email: Option<String>, 259 pub handle: String, 260} 261 262pub async fn get_user_comms_prefs( 263 db: &PgPool, 264 user_id: Uuid, 265) -> Result<UserCommsPrefs, sqlx::Error> { 266 let row = sqlx::query!( 267 r#" 268 SELECT 269 email, 270 handle, 271 preferred_comms_channel as "channel: CommsChannel" 272 FROM users 273 WHERE id = $1 274 "#, 275 user_id 276 ) 277 .fetch_one(db) 278 .await?; 279 Ok(UserCommsPrefs { 280 channel: row.channel, 281 email: row.email, 282 handle: row.handle, 283 }) 284} 285 286pub async fn enqueue_welcome( 287 db: &PgPool, 288 user_id: Uuid, 289 hostname: &str, 290) -> Result<Uuid, sqlx::Error> { 291 let prefs = get_user_comms_prefs(db, user_id).await?; 292 let body = format!( 293 "Welcome to {}!\n\nYour handle is: @{}\n\nThank you for joining us.", 294 hostname, prefs.handle 295 ); 296 enqueue_comms( 297 db, 298 NewComms::new( 299 user_id, 300 prefs.channel, 301 super::types::CommsType::Welcome, 302 prefs.email.clone().unwrap_or_default(), 303 Some(format!("Welcome to {}", hostname)), 304 body, 305 ), 306 ) 307 .await 308} 309 310pub async fn enqueue_email_verification( 311 db: &PgPool, 312 user_id: Uuid, 313 email: &str, 314 handle: &str, 315 code: &str, 316 hostname: &str, 317) -> Result<Uuid, sqlx::Error> { 318 let body = format!( 319 "Hello @{},\n\nYour email verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.", 320 handle, code 321 ); 322 enqueue_comms( 323 db, 324 NewComms::email( 325 user_id, 326 super::types::CommsType::EmailVerification, 327 email.to_string(), 328 format!("Verify your email - {}", hostname), 329 body, 330 ), 331 ) 332 .await 333} 334 335pub async fn enqueue_password_reset( 336 db: &PgPool, 337 user_id: Uuid, 338 code: &str, 339 hostname: &str, 340) -> Result<Uuid, sqlx::Error> { 341 let prefs = get_user_comms_prefs(db, user_id).await?; 342 let body = format!( 343 "Hello @{},\n\nYour password reset code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this message.", 344 prefs.handle, code 345 ); 346 enqueue_comms( 347 db, 348 NewComms::new( 349 user_id, 350 prefs.channel, 351 super::types::CommsType::PasswordReset, 352 prefs.email.clone().unwrap_or_default(), 353 Some(format!("Password Reset - {}", hostname)), 354 body, 355 ), 356 ) 357 .await 358} 359 360pub async fn enqueue_email_update( 361 db: &PgPool, 362 user_id: Uuid, 363 new_email: &str, 364 handle: &str, 365 code: &str, 366 hostname: &str, 367) -> Result<Uuid, sqlx::Error> { 368 let body = format!( 369 "Hello @{},\n\nYour email update confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.", 370 handle, code 371 ); 372 enqueue_comms( 373 db, 374 NewComms::email( 375 user_id, 376 super::types::CommsType::EmailUpdate, 377 new_email.to_string(), 378 format!("Confirm your new email - {}", hostname), 379 body, 380 ), 381 ) 382 .await 383} 384 385pub async fn enqueue_account_deletion( 386 db: &PgPool, 387 user_id: Uuid, 388 code: &str, 389 hostname: &str, 390) -> Result<Uuid, sqlx::Error> { 391 let prefs = get_user_comms_prefs(db, user_id).await?; 392 let body = format!( 393 "Hello @{},\n\nYour account deletion confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.", 394 prefs.handle, code 395 ); 396 enqueue_comms( 397 db, 398 NewComms::new( 399 user_id, 400 prefs.channel, 401 super::types::CommsType::AccountDeletion, 402 prefs.email.clone().unwrap_or_default(), 403 Some(format!("Account Deletion Request - {}", hostname)), 404 body, 405 ), 406 ) 407 .await 408} 409 410pub async fn enqueue_plc_operation( 411 db: &PgPool, 412 user_id: Uuid, 413 token: &str, 414 hostname: &str, 415) -> Result<Uuid, sqlx::Error> { 416 let prefs = get_user_comms_prefs(db, user_id).await?; 417 let body = format!( 418 "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.", 419 prefs.handle, token 420 ); 421 enqueue_comms( 422 db, 423 NewComms::new( 424 user_id, 425 prefs.channel, 426 super::types::CommsType::PlcOperation, 427 prefs.email.clone().unwrap_or_default(), 428 Some(format!("{} - PLC Operation Token", hostname)), 429 body, 430 ), 431 ) 432 .await 433} 434 435pub async fn enqueue_2fa_code( 436 db: &PgPool, 437 user_id: Uuid, 438 code: &str, 439 hostname: &str, 440) -> Result<Uuid, sqlx::Error> { 441 let prefs = get_user_comms_prefs(db, user_id).await?; 442 let body = format!( 443 "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.", 444 prefs.handle, code 445 ); 446 enqueue_comms( 447 db, 448 NewComms::new( 449 user_id, 450 prefs.channel, 451 super::types::CommsType::TwoFactorCode, 452 prefs.email.clone().unwrap_or_default(), 453 Some(format!("Sign-in Verification - {}", hostname)), 454 body, 455 ), 456 ) 457 .await 458} 459 460pub fn channel_display_name(channel: CommsChannel) -> &'static str { 461 match channel { 462 CommsChannel::Email => "email", 463 CommsChannel::Discord => "Discord", 464 CommsChannel::Telegram => "Telegram", 465 CommsChannel::Signal => "Signal", 466 } 467} 468 469pub async fn enqueue_signup_verification( 470 db: &PgPool, 471 user_id: Uuid, 472 channel: &str, 473 recipient: &str, 474 code: &str, 475) -> Result<Uuid, sqlx::Error> { 476 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 477 let comms_channel = match channel { 478 "email" => CommsChannel::Email, 479 "discord" => CommsChannel::Discord, 480 "telegram" => CommsChannel::Telegram, 481 "signal" => CommsChannel::Signal, 482 _ => CommsChannel::Email, 483 }; 484 let body = format!( 485 "Welcome! Your account verification code is: {}\n\nThis code will expire in 30 minutes.\n\nEnter this code to complete your registration on {}.", 486 code, hostname 487 ); 488 let subject = match comms_channel { 489 CommsChannel::Email => Some(format!("Verify your account - {}", hostname)), 490 _ => None, 491 }; 492 enqueue_comms( 493 db, 494 NewComms::new( 495 user_id, 496 comms_channel, 497 super::types::CommsType::EmailVerification, 498 recipient.to_string(), 499 subject, 500 body, 501 ), 502 ) 503 .await 504}