this repo has no description
1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use sqlx::PgPool;
7use tokio::sync::watch;
8use tokio::time::interval;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12use super::sender::{NotificationSender, SendError};
13use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification};
14
/// Background worker that delivers rows from the `notification_queue` table.
///
/// Rows are claimed in batches and each one is handed to the
/// [`NotificationSender`] registered for its channel.
pub struct NotificationService {
    /// Connection pool used for every queue read and write.
    db: PgPool,
    /// One sender per channel; a row whose channel has no entry fails with
    /// `SendError::NotConfigured`.
    senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>,
    /// Delay between polls of the queue.
    poll_interval: Duration,
    /// Maximum number of rows claimed per poll (bound as the SQL `LIMIT`).
    batch_size: i64,
}
21
22impl NotificationService {
23 pub fn new(db: PgPool) -> Self {
24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
25 .ok()
26 .and_then(|v| v.parse().ok())
27 .unwrap_or(1000);
28
29 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE")
30 .ok()
31 .and_then(|v| v.parse().ok())
32 .unwrap_or(100);
33
34 Self {
35 db,
36 senders: HashMap::new(),
37 poll_interval: Duration::from_millis(poll_interval_ms),
38 batch_size,
39 }
40 }
41
42 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
43 self.poll_interval = interval;
44 self
45 }
46
47 pub fn with_batch_size(mut self, size: i64) -> Self {
48 self.batch_size = size;
49 self
50 }
51
52 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self {
53 self.senders.insert(sender.channel(), Arc::new(sender));
54 self
55 }
56
57 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
58 let id = sqlx::query_scalar!(
59 r#"
60 INSERT INTO notification_queue
61 (user_id, channel, notification_type, recipient, subject, body, metadata)
62 VALUES ($1, $2, $3, $4, $5, $6, $7)
63 RETURNING id
64 "#,
65 notification.user_id,
66 notification.channel as NotificationChannel,
67 notification.notification_type as super::types::NotificationType,
68 notification.recipient,
69 notification.subject,
70 notification.body,
71 notification.metadata
72 )
73 .fetch_one(&self.db)
74 .await?;
75
76 debug!(notification_id = %id, "Notification enqueued");
77 Ok(id)
78 }
79
80 pub fn has_senders(&self) -> bool {
81 !self.senders.is_empty()
82 }
83
84 pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
85 if self.senders.is_empty() {
86 warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured.");
87 }
88
89 info!(
90 poll_interval_secs = self.poll_interval.as_secs(),
91 batch_size = self.batch_size,
92 channels = ?self.senders.keys().collect::<Vec<_>>(),
93 "Starting notification service"
94 );
95
96 let mut ticker = interval(self.poll_interval);
97
98 loop {
99 tokio::select! {
100 _ = ticker.tick() => {
101 if let Err(e) = self.process_batch().await {
102 error!(error = %e, "Failed to process notification batch");
103 }
104 }
105 _ = shutdown.changed() => {
106 if *shutdown.borrow() {
107 info!("Notification service shutting down");
108 break;
109 }
110 }
111 }
112 }
113 }
114
115 async fn process_batch(&self) -> Result<(), sqlx::Error> {
116 let notifications = self.fetch_pending_notifications().await?;
117
118 if notifications.is_empty() {
119 return Ok(());
120 }
121
122 debug!(count = notifications.len(), "Processing notification batch");
123
124 for notification in notifications {
125 self.process_notification(notification).await;
126 }
127
128 Ok(())
129 }
130
131 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> {
132 let now = Utc::now();
133
134 sqlx::query_as!(
135 QueuedNotification,
136 r#"
137 UPDATE notification_queue
138 SET status = 'processing', updated_at = NOW()
139 WHERE id IN (
140 SELECT id FROM notification_queue
141 WHERE status = 'pending'
142 AND scheduled_for <= $1
143 AND attempts < max_attempts
144 ORDER BY scheduled_for ASC
145 LIMIT $2
146 FOR UPDATE SKIP LOCKED
147 )
148 RETURNING
149 id, user_id,
150 channel as "channel: NotificationChannel",
151 notification_type as "notification_type: super::types::NotificationType",
152 status as "status: NotificationStatus",
153 recipient, subject, body, metadata,
154 attempts, max_attempts, last_error,
155 created_at, updated_at, scheduled_for, processed_at
156 "#,
157 now,
158 self.batch_size
159 )
160 .fetch_all(&self.db)
161 .await
162 }
163
164 async fn process_notification(&self, notification: QueuedNotification) {
165 let notification_id = notification.id;
166 let channel = notification.channel;
167
168 let result = match self.senders.get(&channel) {
169 Some(sender) => sender.send(¬ification).await,
170 None => {
171 warn!(
172 notification_id = %notification_id,
173 channel = ?channel,
174 "No sender registered for channel"
175 );
176 Err(SendError::NotConfigured(channel))
177 }
178 };
179
180 match result {
181 Ok(()) => {
182 debug!(notification_id = %notification_id, "Notification sent successfully");
183 if let Err(e) = self.mark_sent(notification_id).await {
184 error!(
185 notification_id = %notification_id,
186 error = %e,
187 "Failed to mark notification as sent"
188 );
189 }
190 }
191 Err(e) => {
192 let error_msg = e.to_string();
193 warn!(
194 notification_id = %notification_id,
195 error = %error_msg,
196 "Failed to send notification"
197 );
198 if let Err(db_err) = self.mark_failed(notification_id, &error_msg).await {
199 error!(
200 notification_id = %notification_id,
201 error = %db_err,
202 "Failed to mark notification as failed"
203 );
204 }
205 }
206 }
207 }
208
209 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
210 sqlx::query!(
211 r#"
212 UPDATE notification_queue
213 SET status = 'sent', processed_at = NOW(), updated_at = NOW()
214 WHERE id = $1
215 "#,
216 id
217 )
218 .execute(&self.db)
219 .await?;
220 Ok(())
221 }
222
223 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
224 sqlx::query!(
225 r#"
226 UPDATE notification_queue
227 SET
228 status = CASE
229 WHEN attempts + 1 >= max_attempts THEN 'failed'::notification_status
230 ELSE 'pending'::notification_status
231 END,
232 attempts = attempts + 1,
233 last_error = $2,
234 updated_at = NOW(),
235 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1))
236 WHERE id = $1
237 "#,
238 id,
239 error
240 )
241 .execute(&self.db)
242 .await?;
243 Ok(())
244 }
245}
246
247pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
248 sqlx::query_scalar!(
249 r#"
250 INSERT INTO notification_queue
251 (user_id, channel, notification_type, recipient, subject, body, metadata)
252 VALUES ($1, $2, $3, $4, $5, $6, $7)
253 RETURNING id
254 "#,
255 notification.user_id,
256 notification.channel as NotificationChannel,
257 notification.notification_type as super::types::NotificationType,
258 notification.recipient,
259 notification.subject,
260 notification.body,
261 notification.metadata
262 )
263 .fetch_one(db)
264 .await
265}
266
267pub struct UserNotificationPrefs {
268 pub channel: NotificationChannel,
269 pub email: Option<String>,
270 pub handle: String,
271}
272
273pub async fn get_user_notification_prefs(
274 db: &PgPool,
275 user_id: Uuid,
276) -> Result<UserNotificationPrefs, sqlx::Error> {
277 let row = sqlx::query!(
278 r#"
279 SELECT
280 email,
281 handle,
282 preferred_notification_channel as "channel: NotificationChannel"
283 FROM users
284 WHERE id = $1
285 "#,
286 user_id
287 )
288 .fetch_one(db)
289 .await?;
290
291 Ok(UserNotificationPrefs {
292 channel: row.channel,
293 email: row.email,
294 handle: row.handle,
295 })
296}
297
298pub async fn enqueue_welcome(
299 db: &PgPool,
300 user_id: Uuid,
301 hostname: &str,
302) -> Result<Uuid, sqlx::Error> {
303 let prefs = get_user_notification_prefs(db, user_id).await?;
304
305 let body = format!(
306 "Welcome to {}!\n\nYour handle is: @{}\n\nThank you for joining us.",
307 hostname, prefs.handle
308 );
309
310 enqueue_notification(
311 db,
312 NewNotification::new(
313 user_id,
314 prefs.channel,
315 super::types::NotificationType::Welcome,
316 prefs.email.clone().unwrap_or_default(),
317 Some(format!("Welcome to {}", hostname)),
318 body,
319 ),
320 )
321 .await
322}
323
324pub async fn enqueue_email_verification(
325 db: &PgPool,
326 user_id: Uuid,
327 email: &str,
328 handle: &str,
329 code: &str,
330 hostname: &str,
331) -> Result<Uuid, sqlx::Error> {
332 let body = format!(
333 "Hello @{},\n\nYour email verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
334 handle, code
335 );
336
337 enqueue_notification(
338 db,
339 NewNotification::email(
340 user_id,
341 super::types::NotificationType::EmailVerification,
342 email.to_string(),
343 format!("Verify your email - {}", hostname),
344 body,
345 ),
346 )
347 .await
348}
349
350pub async fn enqueue_password_reset(
351 db: &PgPool,
352 user_id: Uuid,
353 code: &str,
354 hostname: &str,
355) -> Result<Uuid, sqlx::Error> {
356 let prefs = get_user_notification_prefs(db, user_id).await?;
357
358 let body = format!(
359 "Hello @{},\n\nYour password reset code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this message.",
360 prefs.handle, code
361 );
362
363 enqueue_notification(
364 db,
365 NewNotification::new(
366 user_id,
367 prefs.channel,
368 super::types::NotificationType::PasswordReset,
369 prefs.email.clone().unwrap_or_default(),
370 Some(format!("Password Reset - {}", hostname)),
371 body,
372 ),
373 )
374 .await
375}
376
377pub async fn enqueue_email_update(
378 db: &PgPool,
379 user_id: Uuid,
380 new_email: &str,
381 handle: &str,
382 code: &str,
383 hostname: &str,
384) -> Result<Uuid, sqlx::Error> {
385 let body = format!(
386 "Hello @{},\n\nYour email update confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
387 handle, code
388 );
389
390 enqueue_notification(
391 db,
392 NewNotification::email(
393 user_id,
394 super::types::NotificationType::EmailUpdate,
395 new_email.to_string(),
396 format!("Confirm your new email - {}", hostname),
397 body,
398 ),
399 )
400 .await
401}
402
403pub async fn enqueue_account_deletion(
404 db: &PgPool,
405 user_id: Uuid,
406 code: &str,
407 hostname: &str,
408) -> Result<Uuid, sqlx::Error> {
409 let prefs = get_user_notification_prefs(db, user_id).await?;
410
411 let body = format!(
412 "Hello @{},\n\nYour account deletion confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
413 prefs.handle, code
414 );
415
416 enqueue_notification(
417 db,
418 NewNotification::new(
419 user_id,
420 prefs.channel,
421 super::types::NotificationType::AccountDeletion,
422 prefs.email.clone().unwrap_or_default(),
423 Some(format!("Account Deletion Request - {}", hostname)),
424 body,
425 ),
426 )
427 .await
428}
429
430pub async fn enqueue_plc_operation(
431 db: &PgPool,
432 user_id: Uuid,
433 token: &str,
434 hostname: &str,
435) -> Result<Uuid, sqlx::Error> {
436 let prefs = get_user_notification_prefs(db, user_id).await?;
437
438 let body = format!(
439 "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.",
440 prefs.handle, token
441 );
442
443 enqueue_notification(
444 db,
445 NewNotification::new(
446 user_id,
447 prefs.channel,
448 super::types::NotificationType::PlcOperation,
449 prefs.email.clone().unwrap_or_default(),
450 Some(format!("{} - PLC Operation Token", hostname)),
451 body,
452 ),
453 )
454 .await
455}
456
457pub async fn enqueue_2fa_code(
458 db: &PgPool,
459 user_id: Uuid,
460 code: &str,
461 hostname: &str,
462) -> Result<Uuid, sqlx::Error> {
463 let prefs = get_user_notification_prefs(db, user_id).await?;
464
465 let body = format!(
466 "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
467 prefs.handle, code
468 );
469
470 enqueue_notification(
471 db,
472 NewNotification::new(
473 user_id,
474 prefs.channel,
475 super::types::NotificationType::TwoFactorCode,
476 prefs.email.clone().unwrap_or_default(),
477 Some(format!("Sign-in Verification - {}", hostname)),
478 body,
479 ),
480 )
481 .await
482}
483
484pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
485 match channel {
486 NotificationChannel::Email => "email",
487 NotificationChannel::Discord => "Discord",
488 NotificationChannel::Telegram => "Telegram",
489 NotificationChannel::Signal => "Signal",
490 }
491}
492
493pub async fn enqueue_signup_verification(
494 db: &PgPool,
495 user_id: Uuid,
496 channel: &str,
497 recipient: &str,
498 code: &str,
499) -> Result<Uuid, sqlx::Error> {
500 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
501
502 let notification_channel = match channel {
503 "email" => NotificationChannel::Email,
504 "discord" => NotificationChannel::Discord,
505 "telegram" => NotificationChannel::Telegram,
506 "signal" => NotificationChannel::Signal,
507 _ => NotificationChannel::Email,
508 };
509
510 let body = format!(
511 "Welcome! Your account verification code is: {}\n\nThis code will expire in 30 minutes.\n\nEnter this code to complete your registration on {}.",
512 code, hostname
513 );
514
515 let subject = match notification_channel {
516 NotificationChannel::Email => Some(format!("Verify your account - {}", hostname)),
517 _ => None,
518 };
519
520 enqueue_notification(
521 db,
522 NewNotification::new(
523 user_id,
524 notification_channel,
525 super::types::NotificationType::EmailVerification,
526 recipient.to_string(),
527 subject,
528 body,
529 ),
530 )
531 .await
532}