this repo has no description
1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use sqlx::PgPool;
7use tokio::sync::watch;
8use tokio::time::interval;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12use super::sender::{NotificationSender, SendError};
13use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification};
14
15pub struct NotificationService {
16 db: PgPool,
17 senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>,
18 poll_interval: Duration,
19 batch_size: i64,
20}
21
22impl NotificationService {
23 pub fn new(db: PgPool) -> Self {
24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
25 .ok()
26 .and_then(|v| v.parse().ok())
27 .unwrap_or(1000);
28 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE")
29 .ok()
30 .and_then(|v| v.parse().ok())
31 .unwrap_or(100);
32 Self {
33 db,
34 senders: HashMap::new(),
35 poll_interval: Duration::from_millis(poll_interval_ms),
36 batch_size,
37 }
38 }
39
40 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
41 self.poll_interval = interval;
42 self
43 }
44
45 pub fn with_batch_size(mut self, size: i64) -> Self {
46 self.batch_size = size;
47 self
48 }
49
50 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self {
51 self.senders.insert(sender.channel(), Arc::new(sender));
52 self
53 }
54
55 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
56 let id = sqlx::query_scalar!(
57 r#"
58 INSERT INTO notification_queue
59 (user_id, channel, notification_type, recipient, subject, body, metadata)
60 VALUES ($1, $2, $3, $4, $5, $6, $7)
61 RETURNING id
62 "#,
63 notification.user_id,
64 notification.channel as NotificationChannel,
65 notification.notification_type as super::types::NotificationType,
66 notification.recipient,
67 notification.subject,
68 notification.body,
69 notification.metadata
70 )
71 .fetch_one(&self.db)
72 .await?;
73 debug!(notification_id = %id, "Notification enqueued");
74 Ok(id)
75 }
76
77 pub fn has_senders(&self) -> bool {
78 !self.senders.is_empty()
79 }
80
81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
82 if self.senders.is_empty() {
83 warn!(
84 "Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured."
85 );
86 }
87 info!(
88 poll_interval_secs = self.poll_interval.as_secs(),
89 batch_size = self.batch_size,
90 channels = ?self.senders.keys().collect::<Vec<_>>(),
91 "Starting notification service"
92 );
93 let mut ticker = interval(self.poll_interval);
94 loop {
95 tokio::select! {
96 _ = ticker.tick() => {
97 if let Err(e) = self.process_batch().await {
98 error!(error = %e, "Failed to process notification batch");
99 }
100 }
101 _ = shutdown.changed() => {
102 if *shutdown.borrow() {
103 info!("Notification service shutting down");
104 break;
105 }
106 }
107 }
108 }
109 }
110
111 async fn process_batch(&self) -> Result<(), sqlx::Error> {
112 let notifications = self.fetch_pending_notifications().await?;
113 if notifications.is_empty() {
114 return Ok(());
115 }
116 debug!(count = notifications.len(), "Processing notification batch");
117 for notification in notifications {
118 self.process_notification(notification).await;
119 }
120 Ok(())
121 }
122
123 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> {
124 let now = Utc::now();
125 sqlx::query_as!(
126 QueuedNotification,
127 r#"
128 UPDATE notification_queue
129 SET status = 'processing', updated_at = NOW()
130 WHERE id IN (
131 SELECT id FROM notification_queue
132 WHERE status = 'pending'
133 AND scheduled_for <= $1
134 AND attempts < max_attempts
135 ORDER BY scheduled_for ASC
136 LIMIT $2
137 FOR UPDATE SKIP LOCKED
138 )
139 RETURNING
140 id, user_id,
141 channel as "channel: NotificationChannel",
142 notification_type as "notification_type: super::types::NotificationType",
143 status as "status: NotificationStatus",
144 recipient, subject, body, metadata,
145 attempts, max_attempts, last_error,
146 created_at, updated_at, scheduled_for, processed_at
147 "#,
148 now,
149 self.batch_size
150 )
151 .fetch_all(&self.db)
152 .await
153 }
154
155 async fn process_notification(&self, notification: QueuedNotification) {
156 let notification_id = notification.id;
157 let channel = notification.channel;
158 let result = match self.senders.get(&channel) {
159 Some(sender) => sender.send(¬ification).await,
160 None => {
161 warn!(
162 notification_id = %notification_id,
163 channel = ?channel,
164 "No sender registered for channel"
165 );
166 Err(SendError::NotConfigured(channel))
167 }
168 };
169 match result {
170 Ok(()) => {
171 debug!(notification_id = %notification_id, "Notification sent successfully");
172 if let Err(e) = self.mark_sent(notification_id).await {
173 error!(
174 notification_id = %notification_id,
175 error = %e,
176 "Failed to mark notification as sent"
177 );
178 }
179 }
180 Err(e) => {
181 let error_msg = e.to_string();
182 warn!(
183 notification_id = %notification_id,
184 error = %error_msg,
185 "Failed to send notification"
186 );
187 if let Err(db_err) = self.mark_failed(notification_id, &error_msg).await {
188 error!(
189 notification_id = %notification_id,
190 error = %db_err,
191 "Failed to mark notification as failed"
192 );
193 }
194 }
195 }
196 }
197
198 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
199 sqlx::query!(
200 r#"
201 UPDATE notification_queue
202 SET status = 'sent', processed_at = NOW(), updated_at = NOW()
203 WHERE id = $1
204 "#,
205 id
206 )
207 .execute(&self.db)
208 .await?;
209 Ok(())
210 }
211
212 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
213 sqlx::query!(
214 r#"
215 UPDATE notification_queue
216 SET
217 status = CASE
218 WHEN attempts + 1 >= max_attempts THEN 'failed'::notification_status
219 ELSE 'pending'::notification_status
220 END,
221 attempts = attempts + 1,
222 last_error = $2,
223 updated_at = NOW(),
224 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1))
225 WHERE id = $1
226 "#,
227 id,
228 error
229 )
230 .execute(&self.db)
231 .await?;
232 Ok(())
233 }
234}
235
236pub async fn enqueue_notification(
237 db: &PgPool,
238 notification: NewNotification,
239) -> Result<Uuid, sqlx::Error> {
240 sqlx::query_scalar!(
241 r#"
242 INSERT INTO notification_queue
243 (user_id, channel, notification_type, recipient, subject, body, metadata)
244 VALUES ($1, $2, $3, $4, $5, $6, $7)
245 RETURNING id
246 "#,
247 notification.user_id,
248 notification.channel as NotificationChannel,
249 notification.notification_type as super::types::NotificationType,
250 notification.recipient,
251 notification.subject,
252 notification.body,
253 notification.metadata
254 )
255 .fetch_one(db)
256 .await
257}
258
259pub struct UserNotificationPrefs {
260 pub channel: NotificationChannel,
261 pub email: Option<String>,
262 pub handle: String,
263}
264
265pub async fn get_user_notification_prefs(
266 db: &PgPool,
267 user_id: Uuid,
268) -> Result<UserNotificationPrefs, sqlx::Error> {
269 let row = sqlx::query!(
270 r#"
271 SELECT
272 email,
273 handle,
274 preferred_notification_channel as "channel: NotificationChannel"
275 FROM users
276 WHERE id = $1
277 "#,
278 user_id
279 )
280 .fetch_one(db)
281 .await?;
282 Ok(UserNotificationPrefs {
283 channel: row.channel,
284 email: row.email,
285 handle: row.handle,
286 })
287}
288
289pub async fn enqueue_welcome(
290 db: &PgPool,
291 user_id: Uuid,
292 hostname: &str,
293) -> Result<Uuid, sqlx::Error> {
294 let prefs = get_user_notification_prefs(db, user_id).await?;
295 let body = format!(
296 "Welcome to {}!\n\nYour handle is: @{}\n\nThank you for joining us.",
297 hostname, prefs.handle
298 );
299 enqueue_notification(
300 db,
301 NewNotification::new(
302 user_id,
303 prefs.channel,
304 super::types::NotificationType::Welcome,
305 prefs.email.clone().unwrap_or_default(),
306 Some(format!("Welcome to {}", hostname)),
307 body,
308 ),
309 )
310 .await
311}
312
313pub async fn enqueue_email_verification(
314 db: &PgPool,
315 user_id: Uuid,
316 email: &str,
317 handle: &str,
318 code: &str,
319 hostname: &str,
320) -> Result<Uuid, sqlx::Error> {
321 let body = format!(
322 "Hello @{},\n\nYour email verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
323 handle, code
324 );
325 enqueue_notification(
326 db,
327 NewNotification::email(
328 user_id,
329 super::types::NotificationType::EmailVerification,
330 email.to_string(),
331 format!("Verify your email - {}", hostname),
332 body,
333 ),
334 )
335 .await
336}
337
338pub async fn enqueue_password_reset(
339 db: &PgPool,
340 user_id: Uuid,
341 code: &str,
342 hostname: &str,
343) -> Result<Uuid, sqlx::Error> {
344 let prefs = get_user_notification_prefs(db, user_id).await?;
345 let body = format!(
346 "Hello @{},\n\nYour password reset code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this message.",
347 prefs.handle, code
348 );
349 enqueue_notification(
350 db,
351 NewNotification::new(
352 user_id,
353 prefs.channel,
354 super::types::NotificationType::PasswordReset,
355 prefs.email.clone().unwrap_or_default(),
356 Some(format!("Password Reset - {}", hostname)),
357 body,
358 ),
359 )
360 .await
361}
362
363pub async fn enqueue_email_update(
364 db: &PgPool,
365 user_id: Uuid,
366 new_email: &str,
367 handle: &str,
368 code: &str,
369 hostname: &str,
370) -> Result<Uuid, sqlx::Error> {
371 let body = format!(
372 "Hello @{},\n\nYour email update confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
373 handle, code
374 );
375 enqueue_notification(
376 db,
377 NewNotification::email(
378 user_id,
379 super::types::NotificationType::EmailUpdate,
380 new_email.to_string(),
381 format!("Confirm your new email - {}", hostname),
382 body,
383 ),
384 )
385 .await
386}
387
388pub async fn enqueue_account_deletion(
389 db: &PgPool,
390 user_id: Uuid,
391 code: &str,
392 hostname: &str,
393) -> Result<Uuid, sqlx::Error> {
394 let prefs = get_user_notification_prefs(db, user_id).await?;
395 let body = format!(
396 "Hello @{},\n\nYour account deletion confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
397 prefs.handle, code
398 );
399 enqueue_notification(
400 db,
401 NewNotification::new(
402 user_id,
403 prefs.channel,
404 super::types::NotificationType::AccountDeletion,
405 prefs.email.clone().unwrap_or_default(),
406 Some(format!("Account Deletion Request - {}", hostname)),
407 body,
408 ),
409 )
410 .await
411}
412
413pub async fn enqueue_plc_operation(
414 db: &PgPool,
415 user_id: Uuid,
416 token: &str,
417 hostname: &str,
418) -> Result<Uuid, sqlx::Error> {
419 let prefs = get_user_notification_prefs(db, user_id).await?;
420 let body = format!(
421 "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.",
422 prefs.handle, token
423 );
424 enqueue_notification(
425 db,
426 NewNotification::new(
427 user_id,
428 prefs.channel,
429 super::types::NotificationType::PlcOperation,
430 prefs.email.clone().unwrap_or_default(),
431 Some(format!("{} - PLC Operation Token", hostname)),
432 body,
433 ),
434 )
435 .await
436}
437
438pub async fn enqueue_2fa_code(
439 db: &PgPool,
440 user_id: Uuid,
441 code: &str,
442 hostname: &str,
443) -> Result<Uuid, sqlx::Error> {
444 let prefs = get_user_notification_prefs(db, user_id).await?;
445 let body = format!(
446 "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
447 prefs.handle, code
448 );
449 enqueue_notification(
450 db,
451 NewNotification::new(
452 user_id,
453 prefs.channel,
454 super::types::NotificationType::TwoFactorCode,
455 prefs.email.clone().unwrap_or_default(),
456 Some(format!("Sign-in Verification - {}", hostname)),
457 body,
458 ),
459 )
460 .await
461}
462
463pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
464 match channel {
465 NotificationChannel::Email => "email",
466 NotificationChannel::Discord => "Discord",
467 NotificationChannel::Telegram => "Telegram",
468 NotificationChannel::Signal => "Signal",
469 }
470}
471
472pub async fn enqueue_signup_verification(
473 db: &PgPool,
474 user_id: Uuid,
475 channel: &str,
476 recipient: &str,
477 code: &str,
478) -> Result<Uuid, sqlx::Error> {
479 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
480 let notification_channel = match channel {
481 "email" => NotificationChannel::Email,
482 "discord" => NotificationChannel::Discord,
483 "telegram" => NotificationChannel::Telegram,
484 "signal" => NotificationChannel::Signal,
485 _ => NotificationChannel::Email,
486 };
487 let body = format!(
488 "Welcome! Your account verification code is: {}\n\nThis code will expire in 30 minutes.\n\nEnter this code to complete your registration on {}.",
489 code, hostname
490 );
491 let subject = match notification_channel {
492 NotificationChannel::Email => Some(format!("Verify your account - {}", hostname)),
493 _ => None,
494 };
495 enqueue_notification(
496 db,
497 NewNotification::new(
498 user_id,
499 notification_channel,
500 super::types::NotificationType::EmailVerification,
501 recipient.to_string(),
502 subject,
503 body,
504 ),
505 )
506 .await
507}