this repo has no description
1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use sqlx::PgPool;
7use tokio::sync::watch;
8use tokio::time::interval;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12use super::locale::{format_message, get_strings};
13use super::sender::{CommsSender, SendError};
14use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms};
15
16pub struct CommsService {
17 db: PgPool,
18 senders: HashMap<CommsChannel, Arc<dyn CommsSender>>,
19 poll_interval: Duration,
20 batch_size: i64,
21}
22
23impl CommsService {
24 pub fn new(db: PgPool) -> Self {
25 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
26 .ok()
27 .and_then(|v| v.parse().ok())
28 .unwrap_or(1000);
29 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE")
30 .ok()
31 .and_then(|v| v.parse().ok())
32 .unwrap_or(100);
33 Self {
34 db,
35 senders: HashMap::new(),
36 poll_interval: Duration::from_millis(poll_interval_ms),
37 batch_size,
38 }
39 }
40
41 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
42 self.poll_interval = interval;
43 self
44 }
45
46 pub fn with_batch_size(mut self, size: i64) -> Self {
47 self.batch_size = size;
48 self
49 }
50
51 pub fn register_sender<S: CommsSender + 'static>(mut self, sender: S) -> Self {
52 self.senders.insert(sender.channel(), Arc::new(sender));
53 self
54 }
55
56 pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, sqlx::Error> {
57 let id = sqlx::query_scalar!(
58 r#"
59 INSERT INTO comms_queue
60 (user_id, channel, comms_type, recipient, subject, body, metadata)
61 VALUES ($1, $2, $3, $4, $5, $6, $7)
62 RETURNING id
63 "#,
64 item.user_id,
65 item.channel as CommsChannel,
66 item.comms_type as super::types::CommsType,
67 item.recipient,
68 item.subject,
69 item.body,
70 item.metadata
71 )
72 .fetch_one(&self.db)
73 .await?;
74 debug!(comms_id = %id, "Comms enqueued");
75 Ok(id)
76 }
77
78 pub fn has_senders(&self) -> bool {
79 !self.senders.is_empty()
80 }
81
82 pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
83 if self.senders.is_empty() {
84 warn!(
85 "Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured."
86 );
87 }
88 info!(
89 poll_interval_secs = self.poll_interval.as_secs(),
90 batch_size = self.batch_size,
91 channels = ?self.senders.keys().collect::<Vec<_>>(),
92 "Starting comms service"
93 );
94 let mut ticker = interval(self.poll_interval);
95 loop {
96 tokio::select! {
97 _ = ticker.tick() => {
98 if let Err(e) = self.process_batch().await {
99 error!(error = %e, "Failed to process comms batch");
100 }
101 }
102 _ = shutdown.changed() => {
103 if *shutdown.borrow() {
104 info!("Comms service shutting down");
105 break;
106 }
107 }
108 }
109 }
110 }
111
112 async fn process_batch(&self) -> Result<(), sqlx::Error> {
113 let items = self.fetch_pending().await?;
114 if items.is_empty() {
115 return Ok(());
116 }
117 debug!(count = items.len(), "Processing comms batch");
118 for item in items {
119 self.process_item(item).await;
120 }
121 Ok(())
122 }
123
124 async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, sqlx::Error> {
125 let now = Utc::now();
126 sqlx::query_as!(
127 QueuedComms,
128 r#"
129 UPDATE comms_queue
130 SET status = 'processing', updated_at = NOW()
131 WHERE id IN (
132 SELECT id FROM comms_queue
133 WHERE status = 'pending'
134 AND scheduled_for <= $1
135 AND attempts < max_attempts
136 ORDER BY scheduled_for ASC
137 LIMIT $2
138 FOR UPDATE SKIP LOCKED
139 )
140 RETURNING
141 id, user_id,
142 channel as "channel: CommsChannel",
143 comms_type as "comms_type: super::types::CommsType",
144 status as "status: CommsStatus",
145 recipient, subject, body, metadata,
146 attempts, max_attempts, last_error,
147 created_at, updated_at, scheduled_for, processed_at
148 "#,
149 now,
150 self.batch_size
151 )
152 .fetch_all(&self.db)
153 .await
154 }
155
156 async fn process_item(&self, item: QueuedComms) {
157 let comms_id = item.id;
158 let channel = item.channel;
159 let result = match self.senders.get(&channel) {
160 Some(sender) => sender.send(&item).await,
161 None => {
162 warn!(
163 comms_id = %comms_id,
164 channel = ?channel,
165 "No sender registered for channel"
166 );
167 Err(SendError::NotConfigured(channel))
168 }
169 };
170 match result {
171 Ok(()) => {
172 debug!(comms_id = %comms_id, "Comms sent successfully");
173 if let Err(e) = self.mark_sent(comms_id).await {
174 error!(
175 comms_id = %comms_id,
176 error = %e,
177 "Failed to mark comms as sent"
178 );
179 }
180 }
181 Err(e) => {
182 let error_msg = e.to_string();
183 warn!(
184 comms_id = %comms_id,
185 error = %error_msg,
186 "Failed to send comms"
187 );
188 if let Err(db_err) = self.mark_failed(comms_id, &error_msg).await {
189 error!(
190 comms_id = %comms_id,
191 error = %db_err,
192 "Failed to mark comms as failed"
193 );
194 }
195 }
196 }
197 }
198
199 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
200 sqlx::query!(
201 r#"
202 UPDATE comms_queue
203 SET status = 'sent', processed_at = NOW(), updated_at = NOW()
204 WHERE id = $1
205 "#,
206 id
207 )
208 .execute(&self.db)
209 .await?;
210 Ok(())
211 }
212
213 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
214 sqlx::query!(
215 r#"
216 UPDATE comms_queue
217 SET
218 status = CASE
219 WHEN attempts + 1 >= max_attempts THEN 'failed'::comms_status
220 ELSE 'pending'::comms_status
221 END,
222 attempts = attempts + 1,
223 last_error = $2,
224 updated_at = NOW(),
225 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1))
226 WHERE id = $1
227 "#,
228 id,
229 error
230 )
231 .execute(&self.db)
232 .await?;
233 Ok(())
234 }
235}
236
237pub async fn enqueue_comms(db: &PgPool, item: NewComms) -> Result<Uuid, sqlx::Error> {
238 sqlx::query_scalar!(
239 r#"
240 INSERT INTO comms_queue
241 (user_id, channel, comms_type, recipient, subject, body, metadata)
242 VALUES ($1, $2, $3, $4, $5, $6, $7)
243 RETURNING id
244 "#,
245 item.user_id,
246 item.channel as CommsChannel,
247 item.comms_type as super::types::CommsType,
248 item.recipient,
249 item.subject,
250 item.body,
251 item.metadata
252 )
253 .fetch_one(db)
254 .await
255}
256
257pub struct UserCommsPrefs {
258 pub channel: CommsChannel,
259 pub email: Option<String>,
260 pub handle: String,
261 pub locale: String,
262}
263
264pub async fn get_user_comms_prefs(
265 db: &PgPool,
266 user_id: Uuid,
267) -> Result<UserCommsPrefs, sqlx::Error> {
268 let row = sqlx::query!(
269 r#"
270 SELECT
271 email,
272 handle,
273 preferred_comms_channel as "channel: CommsChannel",
274 preferred_locale
275 FROM users
276 WHERE id = $1
277 "#,
278 user_id
279 )
280 .fetch_one(db)
281 .await?;
282 Ok(UserCommsPrefs {
283 channel: row.channel,
284 email: row.email,
285 handle: row.handle,
286 locale: row.preferred_locale.unwrap_or_else(|| "en".to_string()),
287 })
288}
289
290pub async fn enqueue_welcome(
291 db: &PgPool,
292 user_id: Uuid,
293 hostname: &str,
294) -> Result<Uuid, sqlx::Error> {
295 let prefs = get_user_comms_prefs(db, user_id).await?;
296 let strings = get_strings(&prefs.locale);
297 let body = format_message(
298 strings.welcome_body,
299 &[("hostname", hostname), ("handle", &prefs.handle)],
300 );
301 let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]);
302 enqueue_comms(
303 db,
304 NewComms::new(
305 user_id,
306 prefs.channel,
307 super::types::CommsType::Welcome,
308 prefs.email.clone().unwrap_or_default(),
309 Some(subject),
310 body,
311 ),
312 )
313 .await
314}
315
316pub async fn enqueue_password_reset(
317 db: &PgPool,
318 user_id: Uuid,
319 code: &str,
320 hostname: &str,
321) -> Result<Uuid, sqlx::Error> {
322 let prefs = get_user_comms_prefs(db, user_id).await?;
323 let strings = get_strings(&prefs.locale);
324 let body = format_message(
325 strings.password_reset_body,
326 &[("handle", &prefs.handle), ("code", code)],
327 );
328 let subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]);
329 enqueue_comms(
330 db,
331 NewComms::new(
332 user_id,
333 prefs.channel,
334 super::types::CommsType::PasswordReset,
335 prefs.email.clone().unwrap_or_default(),
336 Some(subject),
337 body,
338 ),
339 )
340 .await
341}
342
343pub async fn enqueue_email_update(
344 db: &PgPool,
345 user_id: Uuid,
346 new_email: &str,
347 handle: &str,
348 code: &str,
349 hostname: &str,
350) -> Result<Uuid, sqlx::Error> {
351 let prefs = get_user_comms_prefs(db, user_id).await?;
352 let strings = get_strings(&prefs.locale);
353 let encoded_email = urlencoding::encode(new_email);
354 let encoded_token = urlencoding::encode(code);
355 let verify_page = format!("https://{}/#/verify", hostname);
356 let verify_link = format!(
357 "https://{}/#/verify?token={}&identifier={}",
358 hostname, encoded_token, encoded_email
359 );
360 let body = format_message(
361 strings.email_update_body,
362 &[
363 ("handle", handle),
364 ("code", code),
365 ("verify_page", &verify_page),
366 ("verify_link", &verify_link),
367 ],
368 );
369 let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]);
370 enqueue_comms(
371 db,
372 NewComms::email(
373 user_id,
374 super::types::CommsType::EmailUpdate,
375 new_email.to_string(),
376 subject,
377 body,
378 ),
379 )
380 .await
381}
382
383pub async fn enqueue_account_deletion(
384 db: &PgPool,
385 user_id: Uuid,
386 code: &str,
387 hostname: &str,
388) -> Result<Uuid, sqlx::Error> {
389 let prefs = get_user_comms_prefs(db, user_id).await?;
390 let strings = get_strings(&prefs.locale);
391 let body = format_message(
392 strings.account_deletion_body,
393 &[("handle", &prefs.handle), ("code", code)],
394 );
395 let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]);
396 enqueue_comms(
397 db,
398 NewComms::new(
399 user_id,
400 prefs.channel,
401 super::types::CommsType::AccountDeletion,
402 prefs.email.clone().unwrap_or_default(),
403 Some(subject),
404 body,
405 ),
406 )
407 .await
408}
409
410pub async fn enqueue_plc_operation(
411 db: &PgPool,
412 user_id: Uuid,
413 token: &str,
414 hostname: &str,
415) -> Result<Uuid, sqlx::Error> {
416 let prefs = get_user_comms_prefs(db, user_id).await?;
417 let strings = get_strings(&prefs.locale);
418 let body = format_message(
419 strings.plc_operation_body,
420 &[("handle", &prefs.handle), ("token", token)],
421 );
422 let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]);
423 enqueue_comms(
424 db,
425 NewComms::new(
426 user_id,
427 prefs.channel,
428 super::types::CommsType::PlcOperation,
429 prefs.email.clone().unwrap_or_default(),
430 Some(subject),
431 body,
432 ),
433 )
434 .await
435}
436
437pub async fn enqueue_2fa_code(
438 db: &PgPool,
439 user_id: Uuid,
440 code: &str,
441 hostname: &str,
442) -> Result<Uuid, sqlx::Error> {
443 let prefs = get_user_comms_prefs(db, user_id).await?;
444 let strings = get_strings(&prefs.locale);
445 let body = format_message(
446 strings.two_factor_code_body,
447 &[("handle", &prefs.handle), ("code", code)],
448 );
449 let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]);
450 enqueue_comms(
451 db,
452 NewComms::new(
453 user_id,
454 prefs.channel,
455 super::types::CommsType::TwoFactorCode,
456 prefs.email.clone().unwrap_or_default(),
457 Some(subject),
458 body,
459 ),
460 )
461 .await
462}
463
464pub async fn enqueue_passkey_recovery(
465 db: &PgPool,
466 user_id: Uuid,
467 recovery_url: &str,
468 hostname: &str,
469) -> Result<Uuid, sqlx::Error> {
470 let prefs = get_user_comms_prefs(db, user_id).await?;
471 let strings = get_strings(&prefs.locale);
472 let body = format_message(
473 strings.passkey_recovery_body,
474 &[("handle", &prefs.handle), ("url", recovery_url)],
475 );
476 let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]);
477 enqueue_comms(
478 db,
479 NewComms::new(
480 user_id,
481 prefs.channel,
482 super::types::CommsType::PasskeyRecovery,
483 prefs.email.clone().unwrap_or_default(),
484 Some(subject),
485 body,
486 ),
487 )
488 .await
489}
490
491pub fn channel_display_name(channel: CommsChannel) -> &'static str {
492 match channel {
493 CommsChannel::Email => "email",
494 CommsChannel::Discord => "Discord",
495 CommsChannel::Telegram => "Telegram",
496 CommsChannel::Signal => "Signal",
497 }
498}
499
500pub async fn enqueue_signup_verification(
501 db: &PgPool,
502 user_id: Uuid,
503 channel: &str,
504 recipient: &str,
505 code: &str,
506 locale: Option<&str>,
507) -> Result<Uuid, sqlx::Error> {
508 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
509 let comms_channel = match channel {
510 "email" => CommsChannel::Email,
511 "discord" => CommsChannel::Discord,
512 "telegram" => CommsChannel::Telegram,
513 "signal" => CommsChannel::Signal,
514 _ => CommsChannel::Email,
515 };
516 let strings = get_strings(locale.unwrap_or("en"));
517 let (verify_page, verify_link) = if comms_channel == CommsChannel::Email {
518 let encoded_email = urlencoding::encode(recipient);
519 let encoded_token = urlencoding::encode(code);
520 (
521 format!("https://{}/#/verify", hostname),
522 format!(
523 "https://{}/#/verify?token={}&identifier={}",
524 hostname, encoded_token, encoded_email
525 ),
526 )
527 } else {
528 (String::new(), String::new())
529 };
530 let body = format_message(
531 strings.signup_verification_body,
532 &[
533 ("code", code),
534 ("hostname", &hostname),
535 ("verify_page", &verify_page),
536 ("verify_link", &verify_link),
537 ],
538 );
539 let subject = match comms_channel {
540 CommsChannel::Email => Some(format_message(
541 strings.signup_verification_subject,
542 &[("hostname", &hostname)],
543 )),
544 _ => None,
545 };
546 enqueue_comms(
547 db,
548 NewComms::new(
549 user_id,
550 comms_channel,
551 super::types::CommsType::EmailVerification,
552 recipient.to_string(),
553 subject,
554 body,
555 ),
556 )
557 .await
558}
559
560pub async fn enqueue_migration_verification(
561 db: &PgPool,
562 user_id: Uuid,
563 email: &str,
564 token: &str,
565 hostname: &str,
566) -> Result<Uuid, sqlx::Error> {
567 let prefs = get_user_comms_prefs(db, user_id).await?;
568 let strings = get_strings(&prefs.locale);
569 let encoded_email = urlencoding::encode(email);
570 let encoded_token = urlencoding::encode(token);
571 let verify_page = format!("https://{}/#/verify", hostname);
572 let verify_link = format!(
573 "https://{}/#/verify?token={}&identifier={}",
574 hostname, encoded_token, encoded_email
575 );
576 let body = format_message(
577 strings.migration_verification_body,
578 &[
579 ("code", token),
580 ("hostname", hostname),
581 ("verify_page", &verify_page),
582 ("verify_link", &verify_link),
583 ],
584 );
585 let subject = format_message(
586 strings.migration_verification_subject,
587 &[("hostname", hostname)],
588 );
589 enqueue_comms(
590 db,
591 NewComms::email(
592 user_id,
593 super::types::CommsType::MigrationVerification,
594 email.to_string(),
595 subject,
596 body,
597 ),
598 )
599 .await
600}
601
602pub async fn queue_legacy_login_notification(
603 db: &PgPool,
604 user_id: Uuid,
605 hostname: &str,
606 client_ip: &str,
607 channel: CommsChannel,
608) -> Result<Uuid, sqlx::Error> {
609 let prefs = get_user_comms_prefs(db, user_id).await?;
610 let strings = get_strings(&prefs.locale);
611 let timestamp = chrono::Utc::now()
612 .format("%Y-%m-%d %H:%M:%S UTC")
613 .to_string();
614 let body = format_message(
615 strings.legacy_login_body,
616 &[
617 ("handle", &prefs.handle),
618 ("timestamp", ×tamp),
619 ("ip", client_ip),
620 ("hostname", hostname),
621 ],
622 );
623 let subject = format_message(strings.legacy_login_subject, &[("hostname", hostname)]);
624 enqueue_comms(
625 db,
626 NewComms::new(
627 user_id,
628 channel,
629 super::types::CommsType::LegacyLoginAlert,
630 prefs.email.clone().unwrap_or_default(),
631 Some(subject),
632 body,
633 ),
634 )
635 .await
636}