this repo has no description
1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use sqlx::PgPool;
7use tokio::sync::watch;
8use tokio::time::interval;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12use super::locale::{format_message, get_strings};
13use super::sender::{CommsSender, SendError};
14use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms};
15
16pub struct CommsService {
17 db: PgPool,
18 senders: HashMap<CommsChannel, Arc<dyn CommsSender>>,
19 poll_interval: Duration,
20 batch_size: i64,
21}
22
23impl CommsService {
24 pub fn new(db: PgPool) -> Self {
25 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
26 .ok()
27 .and_then(|v| v.parse().ok())
28 .unwrap_or(1000);
29 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE")
30 .ok()
31 .and_then(|v| v.parse().ok())
32 .unwrap_or(100);
33 Self {
34 db,
35 senders: HashMap::new(),
36 poll_interval: Duration::from_millis(poll_interval_ms),
37 batch_size,
38 }
39 }
40
41 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
42 self.poll_interval = interval;
43 self
44 }
45
46 pub fn with_batch_size(mut self, size: i64) -> Self {
47 self.batch_size = size;
48 self
49 }
50
51 pub fn register_sender<S: CommsSender + 'static>(mut self, sender: S) -> Self {
52 self.senders.insert(sender.channel(), Arc::new(sender));
53 self
54 }
55
56 pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, sqlx::Error> {
57 let id = sqlx::query_scalar!(
58 r#"
59 INSERT INTO comms_queue
60 (user_id, channel, comms_type, recipient, subject, body, metadata)
61 VALUES ($1, $2, $3, $4, $5, $6, $7)
62 RETURNING id
63 "#,
64 item.user_id,
65 item.channel as CommsChannel,
66 item.comms_type as super::types::CommsType,
67 item.recipient,
68 item.subject,
69 item.body,
70 item.metadata
71 )
72 .fetch_one(&self.db)
73 .await?;
74 debug!(comms_id = %id, "Comms enqueued");
75 Ok(id)
76 }
77
78 pub fn has_senders(&self) -> bool {
79 !self.senders.is_empty()
80 }
81
82 pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
83 if self.senders.is_empty() {
84 warn!(
85 "Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured."
86 );
87 }
88 info!(
89 poll_interval_secs = self.poll_interval.as_secs(),
90 batch_size = self.batch_size,
91 channels = ?self.senders.keys().collect::<Vec<_>>(),
92 "Starting comms service"
93 );
94 let mut ticker = interval(self.poll_interval);
95 loop {
96 tokio::select! {
97 _ = ticker.tick() => {
98 if let Err(e) = self.process_batch().await {
99 error!(error = %e, "Failed to process comms batch");
100 }
101 }
102 _ = shutdown.changed() => {
103 if *shutdown.borrow() {
104 info!("Comms service shutting down");
105 break;
106 }
107 }
108 }
109 }
110 }
111
112 async fn process_batch(&self) -> Result<(), sqlx::Error> {
113 let items = self.fetch_pending().await?;
114 if items.is_empty() {
115 return Ok(());
116 }
117 debug!(count = items.len(), "Processing comms batch");
118 for item in items {
119 self.process_item(item).await;
120 }
121 Ok(())
122 }
123
124 async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, sqlx::Error> {
125 let now = Utc::now();
126 sqlx::query_as!(
127 QueuedComms,
128 r#"
129 UPDATE comms_queue
130 SET status = 'processing', updated_at = NOW()
131 WHERE id IN (
132 SELECT id FROM comms_queue
133 WHERE status = 'pending'
134 AND scheduled_for <= $1
135 AND attempts < max_attempts
136 ORDER BY scheduled_for ASC
137 LIMIT $2
138 FOR UPDATE SKIP LOCKED
139 )
140 RETURNING
141 id, user_id,
142 channel as "channel: CommsChannel",
143 comms_type as "comms_type: super::types::CommsType",
144 status as "status: CommsStatus",
145 recipient, subject, body, metadata,
146 attempts, max_attempts, last_error,
147 created_at, updated_at, scheduled_for, processed_at
148 "#,
149 now,
150 self.batch_size
151 )
152 .fetch_all(&self.db)
153 .await
154 }
155
156 async fn process_item(&self, item: QueuedComms) {
157 let comms_id = item.id;
158 let channel = item.channel;
159 let result = match self.senders.get(&channel) {
160 Some(sender) => sender.send(&item).await,
161 None => {
162 warn!(
163 comms_id = %comms_id,
164 channel = ?channel,
165 "No sender registered for channel"
166 );
167 Err(SendError::NotConfigured(channel))
168 }
169 };
170 match result {
171 Ok(()) => {
172 debug!(comms_id = %comms_id, "Comms sent successfully");
173 if let Err(e) = self.mark_sent(comms_id).await {
174 error!(
175 comms_id = %comms_id,
176 error = %e,
177 "Failed to mark comms as sent"
178 );
179 }
180 }
181 Err(e) => {
182 let error_msg = e.to_string();
183 warn!(
184 comms_id = %comms_id,
185 error = %error_msg,
186 "Failed to send comms"
187 );
188 if let Err(db_err) = self.mark_failed(comms_id, &error_msg).await {
189 error!(
190 comms_id = %comms_id,
191 error = %db_err,
192 "Failed to mark comms as failed"
193 );
194 }
195 }
196 }
197 }
198
199 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
200 sqlx::query!(
201 r#"
202 UPDATE comms_queue
203 SET status = 'sent', processed_at = NOW(), updated_at = NOW()
204 WHERE id = $1
205 "#,
206 id
207 )
208 .execute(&self.db)
209 .await?;
210 Ok(())
211 }
212
213 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
214 sqlx::query!(
215 r#"
216 UPDATE comms_queue
217 SET
218 status = CASE
219 WHEN attempts + 1 >= max_attempts THEN 'failed'::comms_status
220 ELSE 'pending'::comms_status
221 END,
222 attempts = attempts + 1,
223 last_error = $2,
224 updated_at = NOW(),
225 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1))
226 WHERE id = $1
227 "#,
228 id,
229 error
230 )
231 .execute(&self.db)
232 .await?;
233 Ok(())
234 }
235}
236
237pub async fn enqueue_comms(db: &PgPool, item: NewComms) -> Result<Uuid, sqlx::Error> {
238 sqlx::query_scalar!(
239 r#"
240 INSERT INTO comms_queue
241 (user_id, channel, comms_type, recipient, subject, body, metadata)
242 VALUES ($1, $2, $3, $4, $5, $6, $7)
243 RETURNING id
244 "#,
245 item.user_id,
246 item.channel as CommsChannel,
247 item.comms_type as super::types::CommsType,
248 item.recipient,
249 item.subject,
250 item.body,
251 item.metadata
252 )
253 .fetch_one(db)
254 .await
255}
256
257pub struct UserCommsPrefs {
258 pub channel: CommsChannel,
259 pub email: Option<String>,
260 pub handle: String,
261 pub locale: String,
262}
263
264pub async fn get_user_comms_prefs(
265 db: &PgPool,
266 user_id: Uuid,
267) -> Result<UserCommsPrefs, sqlx::Error> {
268 let row = sqlx::query!(
269 r#"
270 SELECT
271 email,
272 handle,
273 preferred_comms_channel as "channel: CommsChannel",
274 preferred_locale
275 FROM users
276 WHERE id = $1
277 "#,
278 user_id
279 )
280 .fetch_one(db)
281 .await?;
282 Ok(UserCommsPrefs {
283 channel: row.channel,
284 email: row.email,
285 handle: row.handle,
286 locale: row.preferred_locale.unwrap_or_else(|| "en".to_string()),
287 })
288}
289
290pub async fn enqueue_welcome(
291 db: &PgPool,
292 user_id: Uuid,
293 hostname: &str,
294) -> Result<Uuid, sqlx::Error> {
295 let prefs = get_user_comms_prefs(db, user_id).await?;
296 let strings = get_strings(&prefs.locale);
297 let body = format_message(
298 strings.welcome_body,
299 &[("hostname", hostname), ("handle", &prefs.handle)],
300 );
301 let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]);
302 enqueue_comms(
303 db,
304 NewComms::new(
305 user_id,
306 prefs.channel,
307 super::types::CommsType::Welcome,
308 prefs.email.clone().unwrap_or_default(),
309 Some(subject),
310 body,
311 ),
312 )
313 .await
314}
315
316pub async fn enqueue_email_verification(
317 db: &PgPool,
318 user_id: Uuid,
319 email: &str,
320 handle: &str,
321 code: &str,
322 hostname: &str,
323) -> Result<Uuid, sqlx::Error> {
324 let prefs = get_user_comms_prefs(db, user_id).await?;
325 let strings = get_strings(&prefs.locale);
326 let body = format_message(
327 strings.email_verification_body,
328 &[("handle", handle), ("code", code)],
329 );
330 let subject = format_message(strings.email_verification_subject, &[("hostname", hostname)]);
331 enqueue_comms(
332 db,
333 NewComms::email(
334 user_id,
335 super::types::CommsType::EmailVerification,
336 email.to_string(),
337 subject,
338 body,
339 ),
340 )
341 .await
342}
343
344pub async fn enqueue_password_reset(
345 db: &PgPool,
346 user_id: Uuid,
347 code: &str,
348 hostname: &str,
349) -> Result<Uuid, sqlx::Error> {
350 let prefs = get_user_comms_prefs(db, user_id).await?;
351 let strings = get_strings(&prefs.locale);
352 let body = format_message(
353 strings.password_reset_body,
354 &[("handle", &prefs.handle), ("code", code)],
355 );
356 let subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]);
357 enqueue_comms(
358 db,
359 NewComms::new(
360 user_id,
361 prefs.channel,
362 super::types::CommsType::PasswordReset,
363 prefs.email.clone().unwrap_or_default(),
364 Some(subject),
365 body,
366 ),
367 )
368 .await
369}
370
371pub async fn enqueue_email_update(
372 db: &PgPool,
373 user_id: Uuid,
374 new_email: &str,
375 handle: &str,
376 code: &str,
377 hostname: &str,
378) -> Result<Uuid, sqlx::Error> {
379 let prefs = get_user_comms_prefs(db, user_id).await?;
380 let strings = get_strings(&prefs.locale);
381 let body = format_message(
382 strings.email_update_body,
383 &[("handle", handle), ("code", code)],
384 );
385 let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]);
386 enqueue_comms(
387 db,
388 NewComms::email(
389 user_id,
390 super::types::CommsType::EmailUpdate,
391 new_email.to_string(),
392 subject,
393 body,
394 ),
395 )
396 .await
397}
398
399pub async fn enqueue_account_deletion(
400 db: &PgPool,
401 user_id: Uuid,
402 code: &str,
403 hostname: &str,
404) -> Result<Uuid, sqlx::Error> {
405 let prefs = get_user_comms_prefs(db, user_id).await?;
406 let strings = get_strings(&prefs.locale);
407 let body = format_message(
408 strings.account_deletion_body,
409 &[("handle", &prefs.handle), ("code", code)],
410 );
411 let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]);
412 enqueue_comms(
413 db,
414 NewComms::new(
415 user_id,
416 prefs.channel,
417 super::types::CommsType::AccountDeletion,
418 prefs.email.clone().unwrap_or_default(),
419 Some(subject),
420 body,
421 ),
422 )
423 .await
424}
425
426pub async fn enqueue_plc_operation(
427 db: &PgPool,
428 user_id: Uuid,
429 token: &str,
430 hostname: &str,
431) -> Result<Uuid, sqlx::Error> {
432 let prefs = get_user_comms_prefs(db, user_id).await?;
433 let strings = get_strings(&prefs.locale);
434 let body = format_message(
435 strings.plc_operation_body,
436 &[("handle", &prefs.handle), ("token", token)],
437 );
438 let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]);
439 enqueue_comms(
440 db,
441 NewComms::new(
442 user_id,
443 prefs.channel,
444 super::types::CommsType::PlcOperation,
445 prefs.email.clone().unwrap_or_default(),
446 Some(subject),
447 body,
448 ),
449 )
450 .await
451}
452
453pub async fn enqueue_2fa_code(
454 db: &PgPool,
455 user_id: Uuid,
456 code: &str,
457 hostname: &str,
458) -> Result<Uuid, sqlx::Error> {
459 let prefs = get_user_comms_prefs(db, user_id).await?;
460 let strings = get_strings(&prefs.locale);
461 let body = format_message(
462 strings.two_factor_code_body,
463 &[("handle", &prefs.handle), ("code", code)],
464 );
465 let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]);
466 enqueue_comms(
467 db,
468 NewComms::new(
469 user_id,
470 prefs.channel,
471 super::types::CommsType::TwoFactorCode,
472 prefs.email.clone().unwrap_or_default(),
473 Some(subject),
474 body,
475 ),
476 )
477 .await
478}
479
480pub async fn enqueue_passkey_recovery(
481 db: &PgPool,
482 user_id: Uuid,
483 recovery_url: &str,
484 hostname: &str,
485) -> Result<Uuid, sqlx::Error> {
486 let prefs = get_user_comms_prefs(db, user_id).await?;
487 let strings = get_strings(&prefs.locale);
488 let body = format_message(
489 strings.passkey_recovery_body,
490 &[("handle", &prefs.handle), ("url", recovery_url)],
491 );
492 let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]);
493 enqueue_comms(
494 db,
495 NewComms::new(
496 user_id,
497 prefs.channel,
498 super::types::CommsType::PasskeyRecovery,
499 prefs.email.clone().unwrap_or_default(),
500 Some(subject),
501 body,
502 ),
503 )
504 .await
505}
506
507pub fn channel_display_name(channel: CommsChannel) -> &'static str {
508 match channel {
509 CommsChannel::Email => "email",
510 CommsChannel::Discord => "Discord",
511 CommsChannel::Telegram => "Telegram",
512 CommsChannel::Signal => "Signal",
513 }
514}
515
516pub async fn enqueue_signup_verification(
517 db: &PgPool,
518 user_id: Uuid,
519 channel: &str,
520 recipient: &str,
521 code: &str,
522 locale: Option<&str>,
523) -> Result<Uuid, sqlx::Error> {
524 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
525 let comms_channel = match channel {
526 "email" => CommsChannel::Email,
527 "discord" => CommsChannel::Discord,
528 "telegram" => CommsChannel::Telegram,
529 "signal" => CommsChannel::Signal,
530 _ => CommsChannel::Email,
531 };
532 let strings = get_strings(locale.unwrap_or("en"));
533 let body = format_message(
534 strings.signup_verification_body,
535 &[("code", code), ("hostname", &hostname)],
536 );
537 let subject = match comms_channel {
538 CommsChannel::Email => {
539 Some(format_message(strings.signup_verification_subject, &[("hostname", &hostname)]))
540 }
541 _ => None,
542 };
543 enqueue_comms(
544 db,
545 NewComms::new(
546 user_id,
547 comms_channel,
548 super::types::CommsType::EmailVerification,
549 recipient.to_string(),
550 subject,
551 body,
552 ),
553 )
554 .await
555}
556
557pub async fn queue_legacy_login_notification(
558 db: &PgPool,
559 user_id: Uuid,
560 hostname: &str,
561 client_ip: &str,
562 channel: CommsChannel,
563) -> Result<Uuid, sqlx::Error> {
564 let prefs = get_user_comms_prefs(db, user_id).await?;
565 let strings = get_strings(&prefs.locale);
566 let timestamp = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string();
567 let body = format_message(
568 strings.legacy_login_body,
569 &[
570 ("handle", &prefs.handle),
571 ("timestamp", ×tamp),
572 ("ip", client_ip),
573 ("hostname", hostname),
574 ],
575 );
576 let subject = format_message(strings.legacy_login_subject, &[("hostname", hostname)]);
577 enqueue_comms(
578 db,
579 NewComms::new(
580 user_id,
581 channel,
582 super::types::CommsType::LegacyLoginAlert,
583 prefs.email.clone().unwrap_or_default(),
584 Some(subject),
585 body,
586 ),
587 )
588 .await
589}