// This repository has no description.
1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4use chrono::Utc;
5use sqlx::PgPool;
6use tokio::sync::watch;
7use tokio::time::interval;
8use tracing::{debug, error, info, warn};
9use uuid::Uuid;
10use super::sender::{NotificationSender, SendError};
11use super::types::{NewNotification, NotificationChannel, NotificationStatus, QueuedNotification};
12pub struct NotificationService {
13 db: PgPool,
14 senders: HashMap<NotificationChannel, Arc<dyn NotificationSender>>,
15 poll_interval: Duration,
16 batch_size: i64,
17}
18impl NotificationService {
19 pub fn new(db: PgPool) -> Self {
20 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
21 .ok()
22 .and_then(|v| v.parse().ok())
23 .unwrap_or(1000);
24 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE")
25 .ok()
26 .and_then(|v| v.parse().ok())
27 .unwrap_or(100);
28 Self {
29 db,
30 senders: HashMap::new(),
31 poll_interval: Duration::from_millis(poll_interval_ms),
32 batch_size,
33 }
34 }
35 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
36 self.poll_interval = interval;
37 self
38 }
39 pub fn with_batch_size(mut self, size: i64) -> Self {
40 self.batch_size = size;
41 self
42 }
43 pub fn register_sender<S: NotificationSender + 'static>(mut self, sender: S) -> Self {
44 self.senders.insert(sender.channel(), Arc::new(sender));
45 self
46 }
47 pub async fn enqueue(&self, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
48 let id = sqlx::query_scalar!(
49 r#"
50 INSERT INTO notification_queue
51 (user_id, channel, notification_type, recipient, subject, body, metadata)
52 VALUES ($1, $2, $3, $4, $5, $6, $7)
53 RETURNING id
54 "#,
55 notification.user_id,
56 notification.channel as NotificationChannel,
57 notification.notification_type as super::types::NotificationType,
58 notification.recipient,
59 notification.subject,
60 notification.body,
61 notification.metadata
62 )
63 .fetch_one(&self.db)
64 .await?;
65 debug!(notification_id = %id, "Notification enqueued");
66 Ok(id)
67 }
68 pub fn has_senders(&self) -> bool {
69 !self.senders.is_empty()
70 }
71 pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
72 if self.senders.is_empty() {
73 warn!("Notification service starting with no senders configured. Notifications will be queued but not delivered until senders are configured.");
74 }
75 info!(
76 poll_interval_secs = self.poll_interval.as_secs(),
77 batch_size = self.batch_size,
78 channels = ?self.senders.keys().collect::<Vec<_>>(),
79 "Starting notification service"
80 );
81 let mut ticker = interval(self.poll_interval);
82 loop {
83 tokio::select! {
84 _ = ticker.tick() => {
85 if let Err(e) = self.process_batch().await {
86 error!(error = %e, "Failed to process notification batch");
87 }
88 }
89 _ = shutdown.changed() => {
90 if *shutdown.borrow() {
91 info!("Notification service shutting down");
92 break;
93 }
94 }
95 }
96 }
97 }
98 async fn process_batch(&self) -> Result<(), sqlx::Error> {
99 let notifications = self.fetch_pending_notifications().await?;
100 if notifications.is_empty() {
101 return Ok(());
102 }
103 debug!(count = notifications.len(), "Processing notification batch");
104 for notification in notifications {
105 self.process_notification(notification).await;
106 }
107 Ok(())
108 }
109 async fn fetch_pending_notifications(&self) -> Result<Vec<QueuedNotification>, sqlx::Error> {
110 let now = Utc::now();
111 sqlx::query_as!(
112 QueuedNotification,
113 r#"
114 UPDATE notification_queue
115 SET status = 'processing', updated_at = NOW()
116 WHERE id IN (
117 SELECT id FROM notification_queue
118 WHERE status = 'pending'
119 AND scheduled_for <= $1
120 AND attempts < max_attempts
121 ORDER BY scheduled_for ASC
122 LIMIT $2
123 FOR UPDATE SKIP LOCKED
124 )
125 RETURNING
126 id, user_id,
127 channel as "channel: NotificationChannel",
128 notification_type as "notification_type: super::types::NotificationType",
129 status as "status: NotificationStatus",
130 recipient, subject, body, metadata,
131 attempts, max_attempts, last_error,
132 created_at, updated_at, scheduled_for, processed_at
133 "#,
134 now,
135 self.batch_size
136 )
137 .fetch_all(&self.db)
138 .await
139 }
140 async fn process_notification(&self, notification: QueuedNotification) {
141 let notification_id = notification.id;
142 let channel = notification.channel;
143 let result = match self.senders.get(&channel) {
144 Some(sender) => sender.send(¬ification).await,
145 None => {
146 warn!(
147 notification_id = %notification_id,
148 channel = ?channel,
149 "No sender registered for channel"
150 );
151 Err(SendError::NotConfigured(channel))
152 }
153 };
154 match result {
155 Ok(()) => {
156 debug!(notification_id = %notification_id, "Notification sent successfully");
157 if let Err(e) = self.mark_sent(notification_id).await {
158 error!(
159 notification_id = %notification_id,
160 error = %e,
161 "Failed to mark notification as sent"
162 );
163 }
164 }
165 Err(e) => {
166 let error_msg = e.to_string();
167 warn!(
168 notification_id = %notification_id,
169 error = %error_msg,
170 "Failed to send notification"
171 );
172 if let Err(db_err) = self.mark_failed(notification_id, &error_msg).await {
173 error!(
174 notification_id = %notification_id,
175 error = %db_err,
176 "Failed to mark notification as failed"
177 );
178 }
179 }
180 }
181 }
182 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
183 sqlx::query!(
184 r#"
185 UPDATE notification_queue
186 SET status = 'sent', processed_at = NOW(), updated_at = NOW()
187 WHERE id = $1
188 "#,
189 id
190 )
191 .execute(&self.db)
192 .await?;
193 Ok(())
194 }
195 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
196 sqlx::query!(
197 r#"
198 UPDATE notification_queue
199 SET
200 status = CASE
201 WHEN attempts + 1 >= max_attempts THEN 'failed'::notification_status
202 ELSE 'pending'::notification_status
203 END,
204 attempts = attempts + 1,
205 last_error = $2,
206 updated_at = NOW(),
207 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1))
208 WHERE id = $1
209 "#,
210 id,
211 error
212 )
213 .execute(&self.db)
214 .await?;
215 Ok(())
216 }
217}
218pub async fn enqueue_notification(db: &PgPool, notification: NewNotification) -> Result<Uuid, sqlx::Error> {
219 sqlx::query_scalar!(
220 r#"
221 INSERT INTO notification_queue
222 (user_id, channel, notification_type, recipient, subject, body, metadata)
223 VALUES ($1, $2, $3, $4, $5, $6, $7)
224 RETURNING id
225 "#,
226 notification.user_id,
227 notification.channel as NotificationChannel,
228 notification.notification_type as super::types::NotificationType,
229 notification.recipient,
230 notification.subject,
231 notification.body,
232 notification.metadata
233 )
234 .fetch_one(db)
235 .await
236}
237pub struct UserNotificationPrefs {
238 pub channel: NotificationChannel,
239 pub email: Option<String>,
240 pub handle: String,
241}
242pub async fn get_user_notification_prefs(
243 db: &PgPool,
244 user_id: Uuid,
245) -> Result<UserNotificationPrefs, sqlx::Error> {
246 let row = sqlx::query!(
247 r#"
248 SELECT
249 email,
250 handle,
251 preferred_notification_channel as "channel: NotificationChannel"
252 FROM users
253 WHERE id = $1
254 "#,
255 user_id
256 )
257 .fetch_one(db)
258 .await?;
259 Ok(UserNotificationPrefs {
260 channel: row.channel,
261 email: row.email,
262 handle: row.handle,
263 })
264}
265pub async fn enqueue_welcome(
266 db: &PgPool,
267 user_id: Uuid,
268 hostname: &str,
269) -> Result<Uuid, sqlx::Error> {
270 let prefs = get_user_notification_prefs(db, user_id).await?;
271 let body = format!(
272 "Welcome to {}!\n\nYour handle is: @{}\n\nThank you for joining us.",
273 hostname, prefs.handle
274 );
275 enqueue_notification(
276 db,
277 NewNotification::new(
278 user_id,
279 prefs.channel,
280 super::types::NotificationType::Welcome,
281 prefs.email.clone().unwrap_or_default(),
282 Some(format!("Welcome to {}", hostname)),
283 body,
284 ),
285 )
286 .await
287}
288pub async fn enqueue_email_verification(
289 db: &PgPool,
290 user_id: Uuid,
291 email: &str,
292 handle: &str,
293 code: &str,
294 hostname: &str,
295) -> Result<Uuid, sqlx::Error> {
296 let body = format!(
297 "Hello @{},\n\nYour email verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
298 handle, code
299 );
300 enqueue_notification(
301 db,
302 NewNotification::email(
303 user_id,
304 super::types::NotificationType::EmailVerification,
305 email.to_string(),
306 format!("Verify your email - {}", hostname),
307 body,
308 ),
309 )
310 .await
311}
312pub async fn enqueue_password_reset(
313 db: &PgPool,
314 user_id: Uuid,
315 code: &str,
316 hostname: &str,
317) -> Result<Uuid, sqlx::Error> {
318 let prefs = get_user_notification_prefs(db, user_id).await?;
319 let body = format!(
320 "Hello @{},\n\nYour password reset code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this message.",
321 prefs.handle, code
322 );
323 enqueue_notification(
324 db,
325 NewNotification::new(
326 user_id,
327 prefs.channel,
328 super::types::NotificationType::PasswordReset,
329 prefs.email.clone().unwrap_or_default(),
330 Some(format!("Password Reset - {}", hostname)),
331 body,
332 ),
333 )
334 .await
335}
336pub async fn enqueue_email_update(
337 db: &PgPool,
338 user_id: Uuid,
339 new_email: &str,
340 handle: &str,
341 code: &str,
342 hostname: &str,
343) -> Result<Uuid, sqlx::Error> {
344 let body = format!(
345 "Hello @{},\n\nYour email update confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
346 handle, code
347 );
348 enqueue_notification(
349 db,
350 NewNotification::email(
351 user_id,
352 super::types::NotificationType::EmailUpdate,
353 new_email.to_string(),
354 format!("Confirm your new email - {}", hostname),
355 body,
356 ),
357 )
358 .await
359}
360pub async fn enqueue_account_deletion(
361 db: &PgPool,
362 user_id: Uuid,
363 code: &str,
364 hostname: &str,
365) -> Result<Uuid, sqlx::Error> {
366 let prefs = get_user_notification_prefs(db, user_id).await?;
367 let body = format!(
368 "Hello @{},\n\nYour account deletion confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
369 prefs.handle, code
370 );
371 enqueue_notification(
372 db,
373 NewNotification::new(
374 user_id,
375 prefs.channel,
376 super::types::NotificationType::AccountDeletion,
377 prefs.email.clone().unwrap_or_default(),
378 Some(format!("Account Deletion Request - {}", hostname)),
379 body,
380 ),
381 )
382 .await
383}
384pub async fn enqueue_plc_operation(
385 db: &PgPool,
386 user_id: Uuid,
387 token: &str,
388 hostname: &str,
389) -> Result<Uuid, sqlx::Error> {
390 let prefs = get_user_notification_prefs(db, user_id).await?;
391 let body = format!(
392 "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.",
393 prefs.handle, token
394 );
395 enqueue_notification(
396 db,
397 NewNotification::new(
398 user_id,
399 prefs.channel,
400 super::types::NotificationType::PlcOperation,
401 prefs.email.clone().unwrap_or_default(),
402 Some(format!("{} - PLC Operation Token", hostname)),
403 body,
404 ),
405 )
406 .await
407}
408pub async fn enqueue_2fa_code(
409 db: &PgPool,
410 user_id: Uuid,
411 code: &str,
412 hostname: &str,
413) -> Result<Uuid, sqlx::Error> {
414 let prefs = get_user_notification_prefs(db, user_id).await?;
415 let body = format!(
416 "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
417 prefs.handle, code
418 );
419 enqueue_notification(
420 db,
421 NewNotification::new(
422 user_id,
423 prefs.channel,
424 super::types::NotificationType::TwoFactorCode,
425 prefs.email.clone().unwrap_or_default(),
426 Some(format!("Sign-in Verification - {}", hostname)),
427 body,
428 ),
429 )
430 .await
431}
432pub fn channel_display_name(channel: NotificationChannel) -> &'static str {
433 match channel {
434 NotificationChannel::Email => "email",
435 NotificationChannel::Discord => "Discord",
436 NotificationChannel::Telegram => "Telegram",
437 NotificationChannel::Signal => "Signal",
438 }
439}
440pub async fn enqueue_signup_verification(
441 db: &PgPool,
442 user_id: Uuid,
443 channel: &str,
444 recipient: &str,
445 code: &str,
446) -> Result<Uuid, sqlx::Error> {
447 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
448 let notification_channel = match channel {
449 "email" => NotificationChannel::Email,
450 "discord" => NotificationChannel::Discord,
451 "telegram" => NotificationChannel::Telegram,
452 "signal" => NotificationChannel::Signal,
453 _ => NotificationChannel::Email,
454 };
455 let body = format!(
456 "Welcome! Your account verification code is: {}\n\nThis code will expire in 30 minutes.\n\nEnter this code to complete your registration on {}.",
457 code, hostname
458 );
459 let subject = match notification_channel {
460 NotificationChannel::Email => Some(format!("Verify your account - {}", hostname)),
461 _ => None,
462 };
463 enqueue_notification(
464 db,
465 NewNotification::new(
466 user_id,
467 notification_channel,
468 super::types::NotificationType::EmailVerification,
469 recipient.to_string(),
470 subject,
471 body,
472 ),
473 )
474 .await
475}