this repo has no description
1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use sqlx::PgPool;
7use tokio::sync::watch;
8use tokio::time::interval;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12use super::sender::{NotificationSender, SendError};
13use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification};
14
15pub struct NotificationService {
16 db: PgPool,
17 senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>,
18 poll_interval: Duration,
19 batch_size: i64,
20}
21
22impl NotificationService {
23 pub fn new(db: PgPool) -> Self {
24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
25 .ok()
26 .and_then(|v| v.parse().ok())
27 .unwrap_or(1000);
28 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE")
29 .ok()
30 .and_then(|v| v.parse().ok())
31 .unwrap_or(100);
32 Self {
33 db,
34 senders: HashMap::new(),
35 poll_interval: Duration::from_millis(poll_interval_ms),
36 batch_size,
37 }
38 }
39
40 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
41 self.poll_interval = interval;
42 self
43 }
44
45 pub fn with_batch_size(mut self, size: i64) -> Self {
46 self.batch_size = size;
47 self
48 }
49
50 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self {
51 self.senders.insert(sender.channel(), Arc::new(sender));
52 self
53 }
54
55 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
56 let id = sqlx::query_scalar!(
57 r#"
58 INSERT INTO notification_queue
59 (user_id, channel, notification_type, recipient, subject, body, metadata)
60 VALUES ($1, $2, $3, $4, $5, $6, $7)
61 RETURNING id
62 "#,
63 notification.user_id,
64 notification.channel as NotificationChannel,
65 notification.notification_type as super::types::NotificationType,
66 notification.recipient,
67 notification.subject,
68 notification.body,
69 notification.metadata
70 )
71 .fetch_one(&self.db)
72 .await?;
73 debug!(notification_id = %id, "Notification enqueued");
74 Ok(id)
75 }
76
77 pub fn has_senders(&self) -> bool {
78 !self.senders.is_empty()
79 }
80
81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
82 if self.senders.is_empty() {
83 warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured.");
84 }
85 info!(
86 poll_interval_secs = self.poll_interval.as_secs(),
87 batch_size = self.batch_size,
88 channels = ?self.senders.keys().collect::<Vec<_>>(),
89 "Starting notification service"
90 );
91 let mut ticker = interval(self.poll_interval);
92 loop {
93 tokio::select! {
94 _ = ticker.tick() => {
95 if let Err(e) = self.process_batch().await {
96 error!(error = %e, "Failed to process notification batch");
97 }
98 }
99 _ = shutdown.changed() => {
100 if *shutdown.borrow() {
101 info!("Notification service shutting down");
102 break;
103 }
104 }
105 }
106 }
107 }
108
109 async fn process_batch(&self) -> Result<(), sqlx::Error> {
110 let notifications = self.fetch_pending_notifications().await?;
111 if notifications.is_empty() {
112 return Ok(());
113 }
114 debug!(count = notifications.len(), "Processing notification batch");
115 for notification in notifications {
116 self.process_notification(notification).await;
117 }
118 Ok(())
119 }
120
121 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> {
122 let now = Utc::now();
123 sqlx::query_as!(
124 QueuedNotification,
125 r#"
126 UPDATE notification_queue
127 SET status = 'processing', updated_at = NOW()
128 WHERE id IN (
129 SELECT id FROM notification_queue
130 WHERE status = 'pending'
131 AND scheduled_for <= $1
132 AND attempts < max_attempts
133 ORDER BY scheduled_for ASC
134 LIMIT $2
135 FOR UPDATE SKIP LOCKED
136 )
137 RETURNING
138 id, user_id,
139 channel as "channel: NotificationChannel",
140 notification_type as "notification_type: super::types::NotificationType",
141 status as "status: NotificationStatus",
142 recipient, subject, body, metadata,
143 attempts, max_attempts, last_error,
144 created_at, updated_at, scheduled_for, processed_at
145 "#,
146 now,
147 self.batch_size
148 )
149 .fetch_all(&self.db)
150 .await
151 }
152
153 async fn process_notification(&self, notification: QueuedNotification) {
154 let notification_id = notification.id;
155 let channel = notification.channel;
156 let result = match self.senders.get(&channel) {
157 Some(sender) => sender.send(¬ification).await,
158 None => {
159 warn!(
160 notification_id = %notification_id,
161 channel = ?channel,
162 "No sender registered for channel"
163 );
164 Err(SendError::NotConfigured(channel))
165 }
166 };
167 match result {
168 Ok(()) => {
169 debug!(notification_id = %notification_id, "Notification sent successfully");
170 if let Err(e) = self.mark_sent(notification_id).await {
171 error!(
172 notification_id = %notification_id,
173 error = %e,
174 "Failed to mark notification as sent"
175 );
176 }
177 }
178 Err(e) => {
179 let error_msg = e.to_string();
180 warn!(
181 notification_id = %notification_id,
182 error = %error_msg,
183 "Failed to send notification"
184 );
185 if let Err(db_err) = self.mark_failed(notification_id, &error_msg).await {
186 error!(
187 notification_id = %notification_id,
188 error = %db_err,
189 "Failed to mark notification as failed"
190 );
191 }
192 }
193 }
194 }
195
196 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
197 sqlx::query!(
198 r#"
199 UPDATE notification_queue
200 SET status = 'sent', processed_at = NOW(), updated_at = NOW()
201 WHERE id = $1
202 "#,
203 id
204 )
205 .execute(&self.db)
206 .await?;
207 Ok(())
208 }
209
210 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
211 sqlx::query!(
212 r#"
213 UPDATE notification_queue
214 SET
215 status = CASE
216 WHEN attempts + 1 >= max_attempts THEN 'failed'::notification_status
217 ELSE 'pending'::notification_status
218 END,
219 attempts = attempts + 1,
220 last_error = $2,
221 updated_at = NOW(),
222 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1))
223 WHERE id = $1
224 "#,
225 id,
226 error
227 )
228 .execute(&self.db)
229 .await?;
230 Ok(())
231 }
232}
233
234pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
235 sqlx::query_scalar!(
236 r#"
237 INSERT INTO notification_queue
238 (user_id, channel, notification_type, recipient, subject, body, metadata)
239 VALUES ($1, $2, $3, $4, $5, $6, $7)
240 RETURNING id
241 "#,
242 notification.user_id,
243 notification.channel as NotificationChannel,
244 notification.notification_type as super::types::NotificationType,
245 notification.recipient,
246 notification.subject,
247 notification.body,
248 notification.metadata
249 )
250 .fetch_one(db)
251 .await
252}
253
254pub struct UserNotificationPrefs {
255 pub channel: NotificationChannel,
256 pub email: Option<String>,
257 pub handle: String,
258}
259
260pub async fn get_user_notification_prefs(
261 db: &PgPool,
262 user_id: Uuid,
263) -> Result<UserNotificationPrefs, sqlx::Error> {
264 let row = sqlx::query!(
265 r#"
266 SELECT
267 email,
268 handle,
269 preferred_notification_channel as "channel: NotificationChannel"
270 FROM users
271 WHERE id = $1
272 "#,
273 user_id
274 )
275 .fetch_one(db)
276 .await?;
277 Ok(UserNotificationPrefs {
278 channel: row.channel,
279 email: row.email,
280 handle: row.handle,
281 })
282}
283
284pub async fn enqueue_welcome(
285 db: &PgPool,
286 user_id: Uuid,
287 hostname: &str,
288) -> Result<Uuid, sqlx::Error> {
289 let prefs = get_user_notification_prefs(db, user_id).await?;
290 let body = format!(
291 "Welcome to {}!\n\nYour handle is: @{}\n\nThank you for joining us.",
292 hostname, prefs.handle
293 );
294 enqueue_notification(
295 db,
296 NewNotification::new(
297 user_id,
298 prefs.channel,
299 super::types::NotificationType::Welcome,
300 prefs.email.clone().unwrap_or_default(),
301 Some(format!("Welcome to {}", hostname)),
302 body,
303 ),
304 )
305 .await
306}
307
308pub async fn enqueue_email_verification(
309 db: &PgPool,
310 user_id: Uuid,
311 email: &str,
312 handle: &str,
313 code: &str,
314 hostname: &str,
315) -> Result<Uuid, sqlx::Error> {
316 let body = format!(
317 "Hello @{},\n\nYour email verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
318 handle, code
319 );
320 enqueue_notification(
321 db,
322 NewNotification::email(
323 user_id,
324 super::types::NotificationType::EmailVerification,
325 email.to_string(),
326 format!("Verify your email - {}", hostname),
327 body,
328 ),
329 )
330 .await
331}
332
333pub async fn enqueue_password_reset(
334 db: &PgPool,
335 user_id: Uuid,
336 code: &str,
337 hostname: &str,
338) -> Result<Uuid, sqlx::Error> {
339 let prefs = get_user_notification_prefs(db, user_id).await?;
340 let body = format!(
341 "Hello @{},\n\nYour password reset code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this message.",
342 prefs.handle, code
343 );
344 enqueue_notification(
345 db,
346 NewNotification::new(
347 user_id,
348 prefs.channel,
349 super::types::NotificationType::PasswordReset,
350 prefs.email.clone().unwrap_or_default(),
351 Some(format!("Password Reset - {}", hostname)),
352 body,
353 ),
354 )
355 .await
356}
357
358pub async fn enqueue_email_update(
359 db: &PgPool,
360 user_id: Uuid,
361 new_email: &str,
362 handle: &str,
363 code: &str,
364 hostname: &str,
365) -> Result<Uuid, sqlx::Error> {
366 let body = format!(
367 "Hello @{},\n\nYour email update confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
368 handle, code
369 );
370 enqueue_notification(
371 db,
372 NewNotification::email(
373 user_id,
374 super::types::NotificationType::EmailUpdate,
375 new_email.to_string(),
376 format!("Confirm your new email - {}", hostname),
377 body,
378 ),
379 )
380 .await
381}
382
383pub async fn enqueue_account_deletion(
384 db: &PgPool,
385 user_id: Uuid,
386 code: &str,
387 hostname: &str,
388) -> Result<Uuid, sqlx::Error> {
389 let prefs = get_user_notification_prefs(db, user_id).await?;
390 let body = format!(
391 "Hello @{},\n\nYour account deletion confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
392 prefs.handle, code
393 );
394 enqueue_notification(
395 db,
396 NewNotification::new(
397 user_id,
398 prefs.channel,
399 super::types::NotificationType::AccountDeletion,
400 prefs.email.clone().unwrap_or_default(),
401 Some(format!("Account Deletion Request - {}", hostname)),
402 body,
403 ),
404 )
405 .await
406}
407
408pub async fn enqueue_plc_operation(
409 db: &PgPool,
410 user_id: Uuid,
411 token: &str,
412 hostname: &str,
413) -> Result<Uuid, sqlx::Error> {
414 let prefs = get_user_notification_prefs(db, user_id).await?;
415 let body = format!(
416 "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.",
417 prefs.handle, token
418 );
419 enqueue_notification(
420 db,
421 NewNotification::new(
422 user_id,
423 prefs.channel,
424 super::types::NotificationType::PlcOperation,
425 prefs.email.clone().unwrap_or_default(),
426 Some(format!("{} - PLC Operation Token", hostname)),
427 body,
428 ),
429 )
430 .await
431}
432
433pub async fn enqueue_2fa_code(
434 db: &PgPool,
435 user_id: Uuid,
436 code: &str,
437 hostname: &str,
438) -> Result<Uuid, sqlx::Error> {
439 let prefs = get_user_notification_prefs(db, user_id).await?;
440 let body = format!(
441 "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
442 prefs.handle, code
443 );
444 enqueue_notification(
445 db,
446 NewNotification::new(
447 user_id,
448 prefs.channel,
449 super::types::NotificationType::TwoFactorCode,
450 prefs.email.clone().unwrap_or_default(),
451 Some(format!("Sign-in Verification - {}", hostname)),
452 body,
453 ),
454 )
455 .await
456}
457
458pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
459 match channel {
460 NotificationChannel::Email => "email",
461 NotificationChannel::Discord => "Discord",
462 NotificationChannel::Telegram => "Telegram",
463 NotificationChannel::Signal => "Signal",
464 }
465}
466
467pub async fn enqueue_signup_verification(
468 db: &PgPool,
469 user_id: Uuid,
470 channel: &str,
471 recipient: &str,
472 code: &str,
473) -> Result<Uuid, sqlx::Error> {
474 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
475 let notification_channel = match channel {
476 "email" => NotificationChannel::Email,
477 "discord" => NotificationChannel::Discord,
478 "telegram" => NotificationChannel::Telegram,
479 "signal" => NotificationChannel::Signal,
480 _ => NotificationChannel::Email,
481 };
482 let body = format!(
483 "Welcome! Your account verification code is: {}\n\nThis code will expire in 30 minutes.\n\nEnter this code to complete your registration on {}.",
484 code, hostname
485 );
486 let subject = match notification_channel {
487 NotificationChannel::Email => Some(format!("Verify your account - {}", hostname)),
488 _ => None,
489 };
490 enqueue_notification(
491 db,
492 NewNotification::new(
493 user_id,
494 notification_channel,
495 super::types::NotificationType::EmailVerification,
496 recipient.to_string(),
497 subject,
498 body,
499 ),
500 )
501 .await
502}