this repo has no description
1use std::collections::HashMap; 2use std::sync::Arc; 3use std::time::Duration; 4 5use chrono::Utc; 6use sqlx::PgPool; 7use tokio::sync::watch; 8use tokio::time::interval; 9use tracing::{debug, error, info, warn}; 10use uuid::Uuid; 11 12use super::sender::{NotificationSender, SendError}; 13use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification}; 14 15pub struct NotificationService { 16 db: PgPool, 17 senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>, 18 poll_interval: Duration, 19 batch_size: i64, 20} 21 22impl NotificationService { 23 pub fn new(db: PgPool) -> Self { 24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") 25 .ok() 26 .and_then(|v| v.parse().ok()) 27 .unwrap_or(1000); 28 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE") 29 .ok() 30 .and_then(|v| v.parse().ok()) 31 .unwrap_or(100); 32 Self { 33 db, 34 senders: HashMap::new(), 35 poll_interval: Duration::from_millis(poll_interval_ms), 36 batch_size, 37 } 38 } 39 40 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 41 self.poll_interval = interval; 42 self 43 } 44 45 pub fn with_batch_size(mut self, size: i64) -> Self { 46 self.batch_size = size; 47 self 48 } 49 50 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self { 51 self.senders.insert(sender.channel(), Arc::new(sender)); 52 self 53 } 54 55 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 56 let id = sqlx::query_scalar!( 57 r#" 58 INSERT INTO notification_queue 59 (user_id, channel, notification_type, recipient, subject, body, metadata) 60 VALUES ($1, $2, $3, $4, $5, $6, $7) 61 RETURNING id 62 "#, 63 notification.user_id, 64 notification.channel as NotificationChannel, 65 notification.notification_type as super::types::NotificationType, 66 notification.recipient, 67 notification.subject, 68 notification.body, 69 notification.metadata 70 ) 71 .fetch_one(&self.db) 72 .await?; 73 debug!(notification_id = %id, "Notification enqueued"); 74 Ok(id) 75 } 76 77 pub fn has_senders(&self) -> bool { 78 !self.senders.is_empty() 79 } 80 81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 82 if self.senders.is_empty() { 83 warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured."); 84 } 85 info!( 86 poll_interval_secs = self.poll_interval.as_secs(), 87 batch_size = self.batch_size, 88 channels = ?self.senders.keys().collect::<Vec<_>>(), 89 "Starting notification service" 90 ); 91 let mut ticker = interval(self.poll_interval); 92 loop { 93 tokio::select! { 94 _ = ticker.tick() => { 95 if let Err(e) = self.process_batch().await { 96 error!(error = %e, "Failed to process notification batch"); 97 } 98 } 99 _ = shutdown.changed() => { 100 if *shutdown.borrow() { 101 info!("Notification service shutting down"); 102 break; 103 } 104 } 105 } 106 } 107 } 108 109 async fn process_batch(&self) -> Result<(), sqlx::Error> { 110 let notifications = self.fetch_pending_notifications().await?; 111 if notifications.is_empty() { 112 return Ok(()); 113 } 114 debug!(count = notifications.len(), "Processing notification batch"); 115 for notification in notifications { 116 self.process_notification(notification).await; 117 } 118 Ok(()) 119 } 120 121 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> { 122 let now = Utc::now(); 123 sqlx::query_as!( 124 QueuedNotification, 125 r#" 126 UPDATE notification_queue 127 SET status = 'processing', updated_at = NOW() 128 WHERE id IN ( 129 SELECT id FROM notification_queue 130 WHERE status = 'pending' 131 AND scheduled_for <= $1 132 AND attempts < max_attempts 133 ORDER BY scheduled_for ASC 134 LIMIT $2 135 FOR UPDATE SKIP LOCKED 136 ) 137 RETURNING 138 id, user_id, 139 channel as "channel: NotificationChannel", 140 notification_type as "notification_type: super::types::NotificationType", 141 status as "status: NotificationStatus", 142 recipient, subject, body, metadata, 143 attempts, max_attempts, last_error, 144 created_at, updated_at, scheduled_for, processed_at 145 "#, 146 now, 147 self.batch_size 148 ) 149 .fetch_all(&self.db) 150 .await 151 } 152 153 async fn process_notification(&self, notification: QueuedNotification) { 154 let notification_id = notification.id; 155 let channel = notification.channel; 156 let result = match self.senders.get(&channel) { 157 Some(sender) => sender.send(&notification).await, 158 None => { 159 warn!( 160 notification_id = %notification_id, 161 channel = ?channel, 162 "No sender registered for channel" 163 ); 164 Err(SendError::NotConfigured(channel)) 165 } 166 }; 167 match result { 168 Ok(()) => { 169 debug!(notification_id = %notification_id, "Notification sent successfully"); 170 if let Err(e) = self.mark_sent(notification_id).await { 171 error!( 172 notification_id = %notification_id, 173 error = %e, 174 "Failed to mark notification as sent" 175 ); 176 } 177 } 178 Err(e) => { 179 let error_msg = e.to_string(); 180 warn!( 181 notification_id = %notification_id, 182 error = %error_msg, 183 "Failed to send notification" 184 ); 185 if let Err(db_err) = self.mark_failed(notification_id, &error_msg).await { 186 error!( 187 notification_id = %notification_id, 188 error = %db_err, 189 "Failed to mark notification as failed" 190 ); 191 } 192 } 193 } 194 } 195 196 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 197 sqlx::query!( 198 r#" 199 UPDATE notification_queue 200 SET status = 'sent', processed_at = NOW(), updated_at = NOW() 201 WHERE id = $1 202 "#, 203 id 204 ) 205 .execute(&self.db) 206 .await?; 207 Ok(()) 208 } 209 210 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 211 sqlx::query!( 212 r#" 213 UPDATE notification_queue 214 SET 215 status = CASE 216 WHEN attempts + 1 >= max_attempts THEN 'failed'::notification_status 217 ELSE 'pending'::notification_status 218 END, 219 attempts = attempts + 1, 220 last_error = $2, 221 updated_at = NOW(), 222 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1)) 223 WHERE id = $1 224 "#, 225 id, 226 error 227 ) 228 .execute(&self.db) 229 .await?; 230 Ok(()) 231 } 232} 233 234pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 235 sqlx::query_scalar!( 236 r#" 237 INSERT INTO notification_queue 238 (user_id, channel, notification_type, recipient, subject, body, metadata) 239 VALUES ($1, $2, $3, $4, $5, $6, $7) 240 RETURNING id 241 "#, 242 notification.user_id, 243 notification.channel as NotificationChannel, 244 notification.notification_type as super::types::NotificationType, 245 notification.recipient, 246 notification.subject, 247 notification.body, 248 notification.metadata 249 ) 250 .fetch_one(db) 251 .await 252} 253 254pub struct UserNotificationPrefs { 255 pub channel: NotificationChannel, 256 pub email: Option<String>, 257 pub handle: String, 258} 259 260pub async fn get_user_notification_prefs( 261 db: &PgPool, 262 user_id: Uuid, 263) -> Result<UserNotificationPrefs, sqlx::Error> { 264 let row = sqlx::query!( 265 r#" 266 SELECT 267 email, 268 handle, 269 preferred_notification_channel as "channel: NotificationChannel" 270 FROM users 271 WHERE id = $1 272 "#, 273 user_id 274 ) 275 .fetch_one(db) 276 .await?; 277 Ok(UserNotificationPrefs { 278 channel: row.channel, 279 email: row.email, 280 handle: row.handle, 281 }) 282} 283 284pub async fn enqueue_welcome( 285 db: &PgPool, 286 user_id: Uuid, 287 hostname: &str, 288) -> Result<Uuid, sqlx::Error> { 289 let prefs = get_user_notification_prefs(db, user_id).await?; 290 let body = format!( 291 "Welcome to {}!\n\nYour handle is: @{}\n\nThank you for joining us.", 292 hostname, prefs.handle 293 ); 294 enqueue_notification( 295 db, 296 NewNotification::new( 297 user_id, 298 prefs.channel, 299 super::types::NotificationType::Welcome, 300 prefs.email.clone().unwrap_or_default(), 301 Some(format!("Welcome to {}", hostname)), 302 body, 303 ), 304 ) 305 .await 306} 307 308pub async fn enqueue_email_verification( 309 db: &PgPool, 310 user_id: Uuid, 311 email: &str, 312 handle: &str, 313 code: &str, 314 hostname: &str, 315) -> Result<Uuid, sqlx::Error> { 316 let body = format!( 317 "Hello @{},\n\nYour email verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.", 318 handle, code 319 ); 320 enqueue_notification( 321 db, 322 NewNotification::email( 323 user_id, 324 super::types::NotificationType::EmailVerification, 325 email.to_string(), 326 format!("Verify your email - {}", hostname), 327 body, 328 ), 329 ) 330 .await 331} 332 333pub async fn enqueue_password_reset( 334 db: &PgPool, 335 user_id: Uuid, 336 code: &str, 337 hostname: &str, 338) -> Result<Uuid, sqlx::Error> { 339 let prefs = get_user_notification_prefs(db, user_id).await?; 340 let body = format!( 341 "Hello @{},\n\nYour password reset code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this message.", 342 prefs.handle, code 343 ); 344 enqueue_notification( 345 db, 346 NewNotification::new( 347 user_id, 348 prefs.channel, 349 super::types::NotificationType::PasswordReset, 350 prefs.email.clone().unwrap_or_default(), 351 Some(format!("Password Reset - {}", hostname)), 352 body, 353 ), 354 ) 355 .await 356} 357 358pub async fn enqueue_email_update( 359 db: &PgPool, 360 user_id: Uuid, 361 new_email: &str, 362 handle: &str, 363 code: &str, 364 hostname: &str, 365) -> Result<Uuid, sqlx::Error> { 366 let body = format!( 367 "Hello @{},\n\nYour email update confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.", 368 handle, code 369 ); 370 enqueue_notification( 371 db, 372 NewNotification::email( 373 user_id, 374 super::types::NotificationType::EmailUpdate, 375 new_email.to_string(), 376 format!("Confirm your new email - {}", hostname), 377 body, 378 ), 379 ) 380 .await 381} 382 383pub async fn enqueue_account_deletion( 384 db: &PgPool, 385 user_id: Uuid, 386 code: &str, 387 hostname: &str, 388) -> Result<Uuid, sqlx::Error> { 389 let prefs = get_user_notification_prefs(db, user_id).await?; 390 let body = format!( 391 "Hello @{},\n\nYour account deletion confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.", 392 prefs.handle, code 393 ); 394 enqueue_notification( 395 db, 396 NewNotification::new( 397 user_id, 398 prefs.channel, 399 super::types::NotificationType::AccountDeletion, 400 prefs.email.clone().unwrap_or_default(), 401 Some(format!("Account Deletion Request - {}", hostname)), 402 body, 403 ), 404 ) 405 .await 406} 407 408pub async fn enqueue_plc_operation( 409 db: &PgPool, 410 user_id: Uuid, 411 token: &str, 412 hostname: &str, 413) -> Result<Uuid, sqlx::Error> { 414 let prefs = get_user_notification_prefs(db, user_id).await?; 415 let body = format!( 416 "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.", 417 prefs.handle, token 418 ); 419 enqueue_notification( 420 db, 421 NewNotification::new( 422 user_id, 423 prefs.channel, 424 super::types::NotificationType::PlcOperation, 425 prefs.email.clone().unwrap_or_default(), 426 Some(format!("{} - PLC Operation Token", hostname)), 427 body, 428 ), 429 ) 430 .await 431} 432 433pub async fn enqueue_2fa_code( 434 db: &PgPool, 435 user_id: Uuid, 436 code: &str, 437 hostname: &str, 438) -> Result<Uuid, sqlx::Error> { 439 let prefs = get_user_notification_prefs(db, user_id).await?; 440 let body = format!( 441 "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.", 442 prefs.handle, code 443 ); 444 enqueue_notification( 445 db, 446 NewNotification::new( 447 user_id, 448 prefs.channel, 449 super::types::NotificationType::TwoFactorCode, 450 prefs.email.clone().unwrap_or_default(), 451 Some(format!("Sign-in Verification - {}", hostname)), 452 body, 453 ), 454 ) 455 .await 456} 457 458pub fn channel_display_name(channel: NotificationChannel) -> &'static str { 459 match channel { 460 NotificationChannel::Email => "email", 461 NotificationChannel::Discord => "Discord", 462 NotificationChannel::Telegram => "Telegram", 463 NotificationChannel::Signal => "Signal", 464 } 465} 466 467pub async fn enqueue_signup_verification( 468 db: &PgPool, 469 user_id: Uuid, 470 channel: &str, 471 recipient: &str, 472 code: &str, 473) -> Result<Uuid, sqlx::Error> { 474 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 475 let notification_channel = match channel { 476 "email" => NotificationChannel::Email, 477 "discord" => NotificationChannel::Discord, 478 "telegram" => NotificationChannel::Telegram, 479 "signal" => NotificationChannel::Signal, 480 _ => NotificationChannel::Email, 481 }; 482 let body = format!( 483 "Welcome! Your account verification code is: {}\n\nThis code will expire in 30 minutes.\n\nEnter this code to complete your registration on {}.", 484 code, hostname 485 ); 486 let subject = match notification_channel { 487 NotificationChannel::Email => Some(format!("Verify your account - {}", hostname)), 488 _ => None, 489 }; 490 enqueue_notification( 491 db, 492 NewNotification::new( 493 user_id, 494 notification_channel, 495 super::types::NotificationType::EmailVerification, 496 recipient.to_string(), 497 subject, 498 body, 499 ), 500 ) 501 .await 502}