this repo has no description
1use std::collections::HashMap; 2use std::sync::Arc; 3use std::time::Duration; 4 5use chrono::Utc; 6use sqlx::PgPool; 7use tokio::sync::watch; 8use tokio::time::interval; 9use tracing::{debug, error, info, warn}; 10use uuid::Uuid; 11 12use super::sender::{NotificationSender, SendError}; 13use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification}; 14 15pub struct NotificationService { 16 db: PgPool, 17 senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>, 18 poll_interval: Duration, 19 batch_size: i64, 20} 21 22impl NotificationService { 23 pub fn new(db: PgPool) -> Self { 24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") 25 .ok() 26 .and_then(|v| v.parse().ok()) 27 .unwrap_or(1000); 28 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE") 29 .ok() 30 .and_then(|v| v.parse().ok()) 31 .unwrap_or(100); 32 Self { 33 db, 34 senders: HashMap::new(), 35 poll_interval: Duration::from_millis(poll_interval_ms), 36 batch_size, 37 } 38 } 39 40 pub fn with_poll_interval(mut self, interval: Duration) -> Self { 41 self.poll_interval = interval; 42 self 43 } 44 45 pub fn with_batch_size(mut self, size: i64) -> Self { 46 self.batch_size = size; 47 self 48 } 49 50 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self { 51 self.senders.insert(sender.channel(), Arc::new(sender)); 52 self 53 } 54 55 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> { 56 let id = sqlx::query_scalar!( 57 r#" 58 INSERT INTO notification_queue 59 (user_id, channel, notification_type, recipient, subject, body, metadata) 60 VALUES ($1, $2, $3, $4, $5, $6, $7) 61 RETURNING id 62 "#, 63 notification.user_id, 64 notification.channel as NotificationChannel, 65 notification.notification_type as super::types::NotificationType, 66 notification.recipient, 67 notification.subject, 68 notification.body, 69 notification.metadata 70 ) 71 .fetch_one(&self.db) 72 .await?; 73 debug!(notification_id = %id, "Notification enqueued"); 74 Ok(id) 75 } 76 77 pub fn has_senders(&self) -> bool { 78 !self.senders.is_empty() 79 } 80 81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) { 82 if self.senders.is_empty() { 83 warn!( 84 "Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured." 85 ); 86 } 87 info!( 88 poll_interval_secs = self.poll_interval.as_secs(), 89 batch_size = self.batch_size, 90 channels = ?self.senders.keys().collect::<Vec<_>>(), 91 "Starting notification service" 92 ); 93 let mut ticker = interval(self.poll_interval); 94 loop { 95 tokio::select! { 96 _ = ticker.tick() => { 97 if let Err(e) = self.process_batch().await { 98 error!(error = %e, "Failed to process notification batch"); 99 } 100 } 101 _ = shutdown.changed() => { 102 if *shutdown.borrow() { 103 info!("Notification service shutting down"); 104 break; 105 } 106 } 107 } 108 } 109 } 110 111 async fn process_batch(&self) -> Result<(), sqlx::Error> { 112 let notifications = self.fetch_pending_notifications().await?; 113 if notifications.is_empty() { 114 return Ok(()); 115 } 116 debug!(count = notifications.len(), "Processing notification batch"); 117 for notification in notifications { 118 self.process_notification(notification).await; 119 } 120 Ok(()) 121 } 122 123 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> { 124 let now = Utc::now(); 125 sqlx::query_as!( 126 QueuedNotification, 127 r#" 128 UPDATE notification_queue 129 SET status = 'processing', updated_at = NOW() 130 WHERE id IN ( 131 SELECT id FROM notification_queue 132 WHERE status = 'pending' 133 AND scheduled_for <= $1 134 AND attempts < max_attempts 135 ORDER BY scheduled_for ASC 136 LIMIT $2 137 FOR UPDATE SKIP LOCKED 138 ) 139 RETURNING 140 id, user_id, 141 channel as "channel: NotificationChannel", 142 notification_type as "notification_type: super::types::NotificationType", 143 status as "status: NotificationStatus", 144 recipient, subject, body, metadata, 145 attempts, max_attempts, last_error, 146 created_at, updated_at, scheduled_for, processed_at 147 "#, 148 now, 149 self.batch_size 150 ) 151 .fetch_all(&self.db) 152 .await 153 } 154 155 async fn process_notification(&self, notification: QueuedNotification) { 156 let notification_id = notification.id; 157 let channel = notification.channel; 158 let result = match self.senders.get(&channel) { 159 Some(sender) => sender.send(&notification).await, 160 None => { 161 warn!( 162 notification_id = %notification_id, 163 channel = ?channel, 164 "No sender registered for channel" 165 ); 166 Err(SendError::NotConfigured(channel)) 167 } 168 }; 169 match result { 170 Ok(()) => { 171 debug!(notification_id = %notification_id, "Notification sent successfully"); 172 if let Err(e) = self.mark_sent(notification_id).await { 173 error!( 174 notification_id = %notification_id, 175 error = %e, 176 "Failed to mark notification as sent" 177 ); 178 } 179 } 180 Err(e) => { 181 let error_msg = e.to_string(); 182 warn!( 183 notification_id = %notification_id, 184 error = %error_msg, 185 "Failed to send notification" 186 ); 187 if let Err(db_err) = self.mark_failed(notification_id, &error_msg).await { 188 error!( 189 notification_id = %notification_id, 190 error = %db_err, 191 "Failed to mark notification as failed" 192 ); 193 } 194 } 195 } 196 } 197 198 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 199 sqlx::query!( 200 r#" 201 UPDATE notification_queue 202 SET status = 'sent', processed_at = NOW(), updated_at = NOW() 203 WHERE id = $1 204 "#, 205 id 206 ) 207 .execute(&self.db) 208 .await?; 209 Ok(()) 210 } 211 212 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 213 sqlx::query!( 214 r#" 215 UPDATE notification_queue 216 SET 217 status = CASE 218 WHEN attempts + 1 >= max_attempts THEN 'failed'::notification_status 219 ELSE 'pending'::notification_status 220 END, 221 attempts = attempts + 1, 222 last_error = $2, 223 updated_at = NOW(), 224 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1)) 225 WHERE id = $1 226 "#, 227 id, 228 error 229 ) 230 .execute(&self.db) 231 .await?; 232 Ok(()) 233 } 234} 235 236pub async fn enqueue_notification( 237 db: &PgPool, 238 notification: NewNotification, 239) -> Result<Uuid, sqlx::Error> { 240 sqlx::query_scalar!( 241 r#" 242 INSERT INTO notification_queue 243 (user_id, channel, notification_type, recipient, subject, body, metadata) 244 VALUES ($1, $2, $3, $4, $5, $6, $7) 245 RETURNING id 246 "#, 247 notification.user_id, 248 notification.channel as NotificationChannel, 249 notification.notification_type as super::types::NotificationType, 250 notification.recipient, 251 notification.subject, 252 notification.body, 253 notification.metadata 254 ) 255 .fetch_one(db) 256 .await 257} 258 259pub struct UserNotificationPrefs { 260 pub channel: NotificationChannel, 261 pub email: Option<String>, 262 pub handle: String, 263} 264 265pub async fn get_user_notification_prefs( 266 db: &PgPool, 267 user_id: Uuid, 268) -> Result<UserNotificationPrefs, sqlx::Error> { 269 let row = sqlx::query!( 270 r#" 271 SELECT 272 email, 273 handle, 274 preferred_notification_channel as "channel: NotificationChannel" 275 FROM users 276 WHERE id = $1 277 "#, 278 user_id 279 ) 280 .fetch_one(db) 281 .await?; 282 Ok(UserNotificationPrefs { 283 channel: row.channel, 284 email: row.email, 285 handle: row.handle, 286 }) 287} 288 289pub async fn enqueue_welcome( 290 db: &PgPool, 291 user_id: Uuid, 292 hostname: &str, 293) -> Result<Uuid, sqlx::Error> { 294 let prefs = get_user_notification_prefs(db, user_id).await?; 295 let body = format!( 296 "Welcome to {}!\n\nYour handle is: @{}\n\nThank you for joining us.", 297 hostname, prefs.handle 298 ); 299 enqueue_notification( 300 db, 301 NewNotification::new( 302 user_id, 303 prefs.channel, 304 super::types::NotificationType::Welcome, 305 prefs.email.clone().unwrap_or_default(), 306 Some(format!("Welcome to {}", hostname)), 307 body, 308 ), 309 ) 310 .await 311} 312 313pub async fn enqueue_email_verification( 314 db: &PgPool, 315 user_id: Uuid, 316 email: &str, 317 handle: &str, 318 code: &str, 319 hostname: &str, 320) -> Result<Uuid, sqlx::Error> { 321 let body = format!( 322 "Hello @{},\n\nYour email verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.", 323 handle, code 324 ); 325 enqueue_notification( 326 db, 327 NewNotification::email( 328 user_id, 329 super::types::NotificationType::EmailVerification, 330 email.to_string(), 331 format!("Verify your email - {}", hostname), 332 body, 333 ), 334 ) 335 .await 336} 337 338pub async fn enqueue_password_reset( 339 db: &PgPool, 340 user_id: Uuid, 341 code: &str, 342 hostname: &str, 343) -> Result<Uuid, sqlx::Error> { 344 let prefs = get_user_notification_prefs(db, user_id).await?; 345 let body = format!( 346 "Hello @{},\n\nYour password reset code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this message.", 347 prefs.handle, code 348 ); 349 enqueue_notification( 350 db, 351 NewNotification::new( 352 user_id, 353 prefs.channel, 354 super::types::NotificationType::PasswordReset, 355 prefs.email.clone().unwrap_or_default(), 356 Some(format!("Password Reset - {}", hostname)), 357 body, 358 ), 359 ) 360 .await 361} 362 363pub async fn enqueue_email_update( 364 db: &PgPool, 365 user_id: Uuid, 366 new_email: &str, 367 handle: &str, 368 code: &str, 369 hostname: &str, 370) -> Result<Uuid, sqlx::Error> { 371 let body = format!( 372 "Hello @{},\n\nYour email update confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.", 373 handle, code 374 ); 375 enqueue_notification( 376 db, 377 NewNotification::email( 378 user_id, 379 super::types::NotificationType::EmailUpdate, 380 new_email.to_string(), 381 format!("Confirm your new email - {}", hostname), 382 body, 383 ), 384 ) 385 .await 386} 387 388pub async fn enqueue_account_deletion( 389 db: &PgPool, 390 user_id: Uuid, 391 code: &str, 392 hostname: &str, 393) -> Result<Uuid, sqlx::Error> { 394 let prefs = get_user_notification_prefs(db, user_id).await?; 395 let body = format!( 396 "Hello @{},\n\nYour account deletion confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.", 397 prefs.handle, code 398 ); 399 enqueue_notification( 400 db, 401 NewNotification::new( 402 user_id, 403 prefs.channel, 404 super::types::NotificationType::AccountDeletion, 405 prefs.email.clone().unwrap_or_default(), 406 Some(format!("Account Deletion Request - {}", hostname)), 407 body, 408 ), 409 ) 410 .await 411} 412 413pub async fn enqueue_plc_operation( 414 db: &PgPool, 415 user_id: Uuid, 416 token: &str, 417 hostname: &str, 418) -> Result<Uuid, sqlx::Error> { 419 let prefs = get_user_notification_prefs(db, user_id).await?; 420 let body = format!( 421 "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.", 422 prefs.handle, token 423 ); 424 enqueue_notification( 425 db, 426 NewNotification::new( 427 user_id, 428 prefs.channel, 429 super::types::NotificationType::PlcOperation, 430 prefs.email.clone().unwrap_or_default(), 431 Some(format!("{} - PLC Operation Token", hostname)), 432 body, 433 ), 434 ) 435 .await 436} 437 438pub async fn enqueue_2fa_code( 439 db: &PgPool, 440 user_id: Uuid, 441 code: &str, 442 hostname: &str, 443) -> Result<Uuid, sqlx::Error> { 444 let prefs = get_user_notification_prefs(db, user_id).await?; 445 let body = format!( 446 "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.", 447 prefs.handle, code 448 ); 449 enqueue_notification( 450 db, 451 NewNotification::new( 452 user_id, 453 prefs.channel, 454 super::types::NotificationType::TwoFactorCode, 455 prefs.email.clone().unwrap_or_default(), 456 Some(format!("Sign-in Verification - {}", hostname)), 457 body, 458 ), 459 ) 460 .await 461} 462 463pub fn channel_display_name(channel: NotificationChannel) -> &'static str { 464 match channel { 465 NotificationChannel::Email => "email", 466 NotificationChannel::Discord => "Discord", 467 NotificationChannel::Telegram => "Telegram", 468 NotificationChannel::Signal => "Signal", 469 } 470} 471 472pub async fn enqueue_signup_verification( 473 db: &PgPool, 474 user_id: Uuid, 475 channel: &str, 476 recipient: &str, 477 code: &str, 478) -> Result<Uuid, sqlx::Error> { 479 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 480 let notification_channel = match channel { 481 "email" => NotificationChannel::Email, 482 "discord" => NotificationChannel::Discord, 483 "telegram" => NotificationChannel::Telegram, 484 "signal" => NotificationChannel::Signal, 485 _ => NotificationChannel::Email, 486 }; 487 let body = format!( 488 "Welcome! Your account verification code is: {}\n\nThis code will expire in 30 minutes.\n\nEnter this code to complete your registration on {}.", 489 code, hostname 490 ); 491 let subject = match notification_channel { 492 NotificationChannel::Email => Some(format!("Verify your account - {}", hostname)), 493 _ => None, 494 }; 495 enqueue_notification( 496 db, 497 NewNotification::new( 498 user_id, 499 notification_channel, 500 super::types::NotificationType::EmailVerification, 501 recipient.to_string(), 502 subject, 503 body, 504 ), 505 ) 506 .await 507}