this repo has no description
1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use chrono::Utc;
6use sqlx::PgPool;
7use tokio::sync::watch;
8use tokio::time::interval;
9use tracing::{debug, error, info, warn};
10use uuid::Uuid;
11
12use super::sender::{CommsSender, SendError};
13use super::types::{CommsChannel, CommsStatus, NewComms, QueuedComms};
14
15pub struct CommsService {
16 db: PgPool,
17 senders: HashMap<CommsChannel, Arc<dyn CommsSender>>,
18 poll_interval: Duration,
19 batch_size: i64,
20}
21
22impl CommsService {
23 pub fn new(db: PgPool) -> Self {
24 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS")
25 .ok()
26 .and_then(|v| v.parse().ok())
27 .unwrap_or(1000);
28 let batch_size: i64 = std::env::var("NOTIFICATION_BATCH_SIZE")
29 .ok()
30 .and_then(|v| v.parse().ok())
31 .unwrap_or(100);
32 Self {
33 db,
34 senders: HashMap::new(),
35 poll_interval: Duration::from_millis(poll_interval_ms),
36 batch_size,
37 }
38 }
39
40 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
41 self.poll_interval = interval;
42 self
43 }
44
45 pub fn with_batch_size(mut self, size: i64) -> Self {
46 self.batch_size = size;
47 self
48 }
49
50 pub fn register_sender<S: CommsSender + 'static>(mut self, sender: S) -> Self {
51 self.senders.insert(sender.channel(), Arc::new(sender));
52 self
53 }
54
55 pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, sqlx::Error> {
56 let id = sqlx::query_scalar!(
57 r#"
58 INSERT INTO comms_queue
59 (user_id, channel, comms_type, recipient, subject, body, metadata)
60 VALUES ($1, $2, $3, $4, $5, $6, $7)
61 RETURNING id
62 "#,
63 item.user_id,
64 item.channel as CommsChannel,
65 item.comms_type as super::types::CommsType,
66 item.recipient,
67 item.subject,
68 item.body,
69 item.metadata
70 )
71 .fetch_one(&self.db)
72 .await?;
73 debug!(comms_id = %id, "Comms enqueued");
74 Ok(id)
75 }
76
77 pub fn has_senders(&self) -> bool {
78 !self.senders.is_empty()
79 }
80
81 pub async fn run(self, mut shutdown: watch::Receiver<bool>) {
82 if self.senders.is_empty() {
83 warn!(
84 "Comms service starting with no senders configured. Messages will be queued but not delivered until senders are configured."
85 );
86 }
87 info!(
88 poll_interval_secs = self.poll_interval.as_secs(),
89 batch_size = self.batch_size,
90 channels = ?self.senders.keys().collect::<Vec<_>>(),
91 "Starting comms service"
92 );
93 let mut ticker = interval(self.poll_interval);
94 loop {
95 tokio::select! {
96 _ = ticker.tick() => {
97 if let Err(e) = self.process_batch().await {
98 error!(error = %e, "Failed to process comms batch");
99 }
100 }
101 _ = shutdown.changed() => {
102 if *shutdown.borrow() {
103 info!("Comms service shutting down");
104 break;
105 }
106 }
107 }
108 }
109 }
110
111 async fn process_batch(&self) -> Result<(), sqlx::Error> {
112 let items = self.fetch_pending().await?;
113 if items.is_empty() {
114 return Ok(());
115 }
116 debug!(count = items.len(), "Processing comms batch");
117 for item in items {
118 self.process_item(item).await;
119 }
120 Ok(())
121 }
122
123 async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, sqlx::Error> {
124 let now = Utc::now();
125 sqlx::query_as!(
126 QueuedComms,
127 r#"
128 UPDATE comms_queue
129 SET status = 'processing', updated_at = NOW()
130 WHERE id IN (
131 SELECT id FROM comms_queue
132 WHERE status = 'pending'
133 AND scheduled_for <= $1
134 AND attempts < max_attempts
135 ORDER BY scheduled_for ASC
136 LIMIT $2
137 FOR UPDATE SKIP LOCKED
138 )
139 RETURNING
140 id, user_id,
141 channel as "channel: CommsChannel",
142 comms_type as "comms_type: super::types::CommsType",
143 status as "status: CommsStatus",
144 recipient, subject, body, metadata,
145 attempts, max_attempts, last_error,
146 created_at, updated_at, scheduled_for, processed_at
147 "#,
148 now,
149 self.batch_size
150 )
151 .fetch_all(&self.db)
152 .await
153 }
154
155 async fn process_item(&self, item: QueuedComms) {
156 let comms_id = item.id;
157 let channel = item.channel;
158 let result = match self.senders.get(&channel) {
159 Some(sender) => sender.send(&item).await,
160 None => {
161 warn!(
162 comms_id = %comms_id,
163 channel = ?channel,
164 "No sender registered for channel"
165 );
166 Err(SendError::NotConfigured(channel))
167 }
168 };
169 match result {
170 Ok(()) => {
171 debug!(comms_id = %comms_id, "Comms sent successfully");
172 if let Err(e) = self.mark_sent(comms_id).await {
173 error!(
174 comms_id = %comms_id,
175 error = %e,
176 "Failed to mark comms as sent"
177 );
178 }
179 }
180 Err(e) => {
181 let error_msg = e.to_string();
182 warn!(
183 comms_id = %comms_id,
184 error = %error_msg,
185 "Failed to send comms"
186 );
187 if let Err(db_err) = self.mark_failed(comms_id, &error_msg).await {
188 error!(
189 comms_id = %comms_id,
190 error = %db_err,
191 "Failed to mark comms as failed"
192 );
193 }
194 }
195 }
196 }
197
198 async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> {
199 sqlx::query!(
200 r#"
201 UPDATE comms_queue
202 SET status = 'sent', processed_at = NOW(), updated_at = NOW()
203 WHERE id = $1
204 "#,
205 id
206 )
207 .execute(&self.db)
208 .await?;
209 Ok(())
210 }
211
212 async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> {
213 sqlx::query!(
214 r#"
215 UPDATE comms_queue
216 SET
217 status = CASE
218 WHEN attempts + 1 >= max_attempts THEN 'failed'::comms_status
219 ELSE 'pending'::comms_status
220 END,
221 attempts = attempts + 1,
222 last_error = $2,
223 updated_at = NOW(),
224 scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1))
225 WHERE id = $1
226 "#,
227 id,
228 error
229 )
230 .execute(&self.db)
231 .await?;
232 Ok(())
233 }
234}
235
236pub async fn enqueue_comms(db: &PgPool, item: NewComms) -> Result<Uuid, sqlx::Error> {
237 sqlx::query_scalar!(
238 r#"
239 INSERT INTO comms_queue
240 (user_id, channel, comms_type, recipient, subject, body, metadata)
241 VALUES ($1, $2, $3, $4, $5, $6, $7)
242 RETURNING id
243 "#,
244 item.user_id,
245 item.channel as CommsChannel,
246 item.comms_type as super::types::CommsType,
247 item.recipient,
248 item.subject,
249 item.body,
250 item.metadata
251 )
252 .fetch_one(db)
253 .await
254}
255
256pub struct UserCommsPrefs {
257 pub channel: CommsChannel,
258 pub email: Option<String>,
259 pub handle: String,
260}
261
262pub async fn get_user_comms_prefs(
263 db: &PgPool,
264 user_id: Uuid,
265) -> Result<UserCommsPrefs, sqlx::Error> {
266 let row = sqlx::query!(
267 r#"
268 SELECT
269 email,
270 handle,
271 preferred_comms_channel as "channel: CommsChannel"
272 FROM users
273 WHERE id = $1
274 "#,
275 user_id
276 )
277 .fetch_one(db)
278 .await?;
279 Ok(UserCommsPrefs {
280 channel: row.channel,
281 email: row.email,
282 handle: row.handle,
283 })
284}
285
286pub async fn enqueue_welcome(
287 db: &PgPool,
288 user_id: Uuid,
289 hostname: &str,
290) -> Result<Uuid, sqlx::Error> {
291 let prefs = get_user_comms_prefs(db, user_id).await?;
292 let body = format!(
293 "Welcome to {}!\n\nYour handle is: @{}\n\nThank you for joining us.",
294 hostname, prefs.handle
295 );
296 enqueue_comms(
297 db,
298 NewComms::new(
299 user_id,
300 prefs.channel,
301 super::types::CommsType::Welcome,
302 prefs.email.clone().unwrap_or_default(),
303 Some(format!("Welcome to {}", hostname)),
304 body,
305 ),
306 )
307 .await
308}
309
310pub async fn enqueue_email_verification(
311 db: &PgPool,
312 user_id: Uuid,
313 email: &str,
314 handle: &str,
315 code: &str,
316 hostname: &str,
317) -> Result<Uuid, sqlx::Error> {
318 let body = format!(
319 "Hello @{},\n\nYour email verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
320 handle, code
321 );
322 enqueue_comms(
323 db,
324 NewComms::email(
325 user_id,
326 super::types::CommsType::EmailVerification,
327 email.to_string(),
328 format!("Verify your email - {}", hostname),
329 body,
330 ),
331 )
332 .await
333}
334
335pub async fn enqueue_password_reset(
336 db: &PgPool,
337 user_id: Uuid,
338 code: &str,
339 hostname: &str,
340) -> Result<Uuid, sqlx::Error> {
341 let prefs = get_user_comms_prefs(db, user_id).await?;
342 let body = format!(
343 "Hello @{},\n\nYour password reset code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this message.",
344 prefs.handle, code
345 );
346 enqueue_comms(
347 db,
348 NewComms::new(
349 user_id,
350 prefs.channel,
351 super::types::CommsType::PasswordReset,
352 prefs.email.clone().unwrap_or_default(),
353 Some(format!("Password Reset - {}", hostname)),
354 body,
355 ),
356 )
357 .await
358}
359
360pub async fn enqueue_email_update(
361 db: &PgPool,
362 user_id: Uuid,
363 new_email: &str,
364 handle: &str,
365 code: &str,
366 hostname: &str,
367) -> Result<Uuid, sqlx::Error> {
368 let body = format!(
369 "Hello @{},\n\nYour email update confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please ignore this email.",
370 handle, code
371 );
372 enqueue_comms(
373 db,
374 NewComms::email(
375 user_id,
376 super::types::CommsType::EmailUpdate,
377 new_email.to_string(),
378 format!("Confirm your new email - {}", hostname),
379 body,
380 ),
381 )
382 .await
383}
384
385pub async fn enqueue_account_deletion(
386 db: &PgPool,
387 user_id: Uuid,
388 code: &str,
389 hostname: &str,
390) -> Result<Uuid, sqlx::Error> {
391 let prefs = get_user_comms_prefs(db, user_id).await?;
392 let body = format!(
393 "Hello @{},\n\nYour account deletion confirmation code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
394 prefs.handle, code
395 );
396 enqueue_comms(
397 db,
398 NewComms::new(
399 user_id,
400 prefs.channel,
401 super::types::CommsType::AccountDeletion,
402 prefs.email.clone().unwrap_or_default(),
403 Some(format!("Account Deletion Request - {}", hostname)),
404 body,
405 ),
406 )
407 .await
408}
409
410pub async fn enqueue_plc_operation(
411 db: &PgPool,
412 user_id: Uuid,
413 token: &str,
414 hostname: &str,
415) -> Result<Uuid, sqlx::Error> {
416 let prefs = get_user_comms_prefs(db, user_id).await?;
417 let body = format!(
418 "Hello @{},\n\nYou requested to sign a PLC operation for your account.\n\nYour verification token is: {}\n\nThis token will expire in 10 minutes.\n\nIf you did not request this, you can safely ignore this message.",
419 prefs.handle, token
420 );
421 enqueue_comms(
422 db,
423 NewComms::new(
424 user_id,
425 prefs.channel,
426 super::types::CommsType::PlcOperation,
427 prefs.email.clone().unwrap_or_default(),
428 Some(format!("{} - PLC Operation Token", hostname)),
429 body,
430 ),
431 )
432 .await
433}
434
435pub async fn enqueue_2fa_code(
436 db: &PgPool,
437 user_id: Uuid,
438 code: &str,
439 hostname: &str,
440) -> Result<Uuid, sqlx::Error> {
441 let prefs = get_user_comms_prefs(db, user_id).await?;
442 let body = format!(
443 "Hello @{},\n\nYour sign-in verification code is: {}\n\nThis code will expire in 10 minutes.\n\nIf you did not request this, please secure your account immediately.",
444 prefs.handle, code
445 );
446 enqueue_comms(
447 db,
448 NewComms::new(
449 user_id,
450 prefs.channel,
451 super::types::CommsType::TwoFactorCode,
452 prefs.email.clone().unwrap_or_default(),
453 Some(format!("Sign-in Verification - {}", hostname)),
454 body,
455 ),
456 )
457 .await
458}
459
460pub fn channel_display_name(channel: CommsChannel) -> &'static str {
461 match channel {
462 CommsChannel::Email => "email",
463 CommsChannel::Discord => "Discord",
464 CommsChannel::Telegram => "Telegram",
465 CommsChannel::Signal => "Signal",
466 }
467}
468
469pub async fn enqueue_signup_verification(
470 db: &PgPool,
471 user_id: Uuid,
472 channel: &str,
473 recipient: &str,
474 code: &str,
475) -> Result<Uuid, sqlx::Error> {
476 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
477 let comms_channel = match channel {
478 "email" => CommsChannel::Email,
479 "discord" => CommsChannel::Discord,
480 "telegram" => CommsChannel::Telegram,
481 "signal" => CommsChannel::Signal,
482 _ => CommsChannel::Email,
483 };
484 let body = format!(
485 "Welcome! Your account verification code is: {}\n\nThis code will expire in 30 minutes.\n\nEnter this code to complete your registration on {}.",
486 code, hostname
487 );
488 let subject = match comms_channel {
489 CommsChannel::Email => Some(format!("Verify your account - {}", hostname)),
490 _ => None,
491 };
492 enqueue_comms(
493 db,
494 NewComms::new(
495 user_id,
496 comms_channel,
497 super::types::CommsType::EmailVerification,
498 recipient.to_string(),
499 subject,
500 body,
501 ),
502 )
503 .await
504}