this repo has no description
1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use hmac::Mac;
3use sha2::{Digest, Sha256};
4
5type HmacSha256 = hmac::Hmac<Sha256>;
6
/// Format version written as the first `|`-separated field of every token.
const TOKEN_VERSION: u8 = 1;
/// Default lifetime of a signup token, in minutes.
const DEFAULT_SIGNUP_EXPIRY_MINUTES: u64 = 30;
/// Default lifetime of a migration token, in hours (two days).
const DEFAULT_MIGRATION_EXPIRY_HOURS: u64 = 2 * 24;
/// Default lifetime of a channel-update token, in minutes.
const DEFAULT_CHANNEL_UPDATE_EXPIRY_MINUTES: u64 = 10;
11
/// The flow a verification token was issued for.
///
/// The purpose is part of the signed payload, and each `verify_*_token` function
/// rejects tokens minted for a different purpose.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VerificationPurpose {
    /// Issued by `generate_signup_token`, checked by `verify_signup_token`.
    Signup,
    /// Issued by `generate_migration_token` (always on the `"email"` channel),
    /// checked by `verify_migration_token`.
    Migration,
    /// Issued by `generate_channel_update_token`, checked by
    /// `verify_channel_update_token`.
    ChannelUpdate,
}
18
19impl VerificationPurpose {
20 fn as_str(&self) -> &'static str {
21 match self {
22 Self::Signup => "signup",
23 Self::Migration => "migration",
24 Self::ChannelUpdate => "channel_update",
25 }
26 }
27
28 fn from_str(s: &str) -> Option<Self> {
29 match s {
30 "signup" => Some(Self::Signup),
31 "migration" => Some(Self::Migration),
32 "channel_update" => Some(Self::ChannelUpdate),
33 _ => None,
34 }
35 }
36
37 fn default_expiry_seconds(&self) -> u64 {
38 match self {
39 Self::Signup => DEFAULT_SIGNUP_EXPIRY_MINUTES * 60,
40 Self::Migration => DEFAULT_MIGRATION_EXPIRY_HOURS * 3600,
41 Self::ChannelUpdate => DEFAULT_CHANNEL_UPDATE_EXPIRY_MINUTES * 60,
42 }
43 }
44}
45
46#[derive(Debug)]
47pub struct VerificationToken {
48 pub did: String,
49 pub purpose: VerificationPurpose,
50 pub channel: String,
51 pub identifier_hash: String,
52 pub expires_at: u64,
53}
54
55fn derive_verification_key() -> [u8; 32] {
56 use hkdf::Hkdf;
57 let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| {
58 if cfg!(test) || std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_ok() {
59 "test-master-key-not-for-production".to_string()
60 } else {
61 panic!("MASTER_KEY must be set");
62 }
63 });
64 let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes());
65 let mut key = [0u8; 32];
66 hk.expand(b"tranquil-pds-verification-token-v1", &mut key)
67 .expect("HKDF expansion failed");
68 key
69}
70
71pub fn hash_identifier(identifier: &str) -> String {
72 let mut hasher = Sha256::new();
73 hasher.update(identifier.to_lowercase().as_bytes());
74 let result = hasher.finalize();
75 URL_SAFE_NO_PAD.encode(&result[..16])
76}
77
78pub fn generate_signup_token(did: &str, channel: &str, identifier: &str) -> String {
79 generate_token(did, VerificationPurpose::Signup, channel, identifier)
80}
81
82pub fn generate_migration_token(did: &str, email: &str) -> String {
83 generate_token(did, VerificationPurpose::Migration, "email", email)
84}
85
86pub fn generate_channel_update_token(did: &str, channel: &str, identifier: &str) -> String {
87 generate_token(did, VerificationPurpose::ChannelUpdate, channel, identifier)
88}
89
90pub fn generate_token(
91 did: &str,
92 purpose: VerificationPurpose,
93 channel: &str,
94 identifier: &str,
95) -> String {
96 generate_token_with_expiry(
97 did,
98 purpose,
99 channel,
100 identifier,
101 purpose.default_expiry_seconds(),
102 )
103}
104
105pub fn generate_token_with_expiry(
106 did: &str,
107 purpose: VerificationPurpose,
108 channel: &str,
109 identifier: &str,
110 expiry_seconds: u64,
111) -> String {
112 let key = derive_verification_key();
113 let identifier_hash = hash_identifier(identifier);
114 let expires_at = std::time::SystemTime::now()
115 .duration_since(std::time::UNIX_EPOCH)
116 .unwrap_or_default()
117 .as_secs()
118 + expiry_seconds;
119
120 let payload = format!(
121 "{}|{}|{}|{}|{}",
122 did,
123 purpose.as_str(),
124 channel,
125 identifier_hash,
126 expires_at
127 );
128
129 let mut mac = <HmacSha256 as Mac>::new_from_slice(&key).expect("HMAC key size is valid");
130 mac.update(payload.as_bytes());
131 let signature = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
132
133 let token_data = format!(
134 "{}|{}|{}|{}|{}|{}|{}",
135 TOKEN_VERSION,
136 did,
137 purpose.as_str(),
138 channel,
139 identifier_hash,
140 expires_at,
141 signature
142 );
143 URL_SAFE_NO_PAD.encode(token_data.as_bytes())
144}
145
/// Reasons a verification token can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VerifyError {
    /// Not valid base64url or UTF-8, wrong number of `|`-separated fields, or an
    /// unparseable version, purpose, or expiry field.
    InvalidFormat,
    /// The version field is not the version this module issues.
    UnsupportedVersion,
    /// The token's expiry time has passed.
    Expired,
    /// The signature does not match the token's contents.
    InvalidSignature,
    /// The identifier does not match the token. `verify_token_for_did` also
    /// returns this when the DID does not match.
    IdentifierMismatch,
    /// The token was issued for a different purpose.
    PurposeMismatch,
    /// The token was issued for a different channel.
    ChannelMismatch,
}
156
157impl std::fmt::Display for VerifyError {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 match self {
160 Self::InvalidFormat => write!(f, "Invalid token format"),
161 Self::UnsupportedVersion => write!(f, "Unsupported token version"),
162 Self::Expired => write!(f, "Token has expired"),
163 Self::InvalidSignature => write!(f, "Invalid token signature"),
164 Self::IdentifierMismatch => write!(f, "Identifier does not match token"),
165 Self::PurposeMismatch => write!(f, "Token purpose does not match"),
166 Self::ChannelMismatch => write!(f, "Token channel does not match"),
167 }
168 }
169}
170
171pub fn verify_signup_token(
172 token: &str,
173 expected_channel: &str,
174 expected_identifier: &str,
175) -> Result<VerificationToken, VerifyError> {
176 let parsed = verify_token_signature(token)?;
177 if parsed.purpose != VerificationPurpose::Signup {
178 return Err(VerifyError::PurposeMismatch);
179 }
180 if parsed.channel != expected_channel {
181 return Err(VerifyError::ChannelMismatch);
182 }
183 let expected_hash = hash_identifier(expected_identifier);
184 if parsed.identifier_hash != expected_hash {
185 return Err(VerifyError::IdentifierMismatch);
186 }
187 Ok(parsed)
188}
189
190pub fn verify_migration_token(
191 token: &str,
192 expected_email: &str,
193) -> Result<VerificationToken, VerifyError> {
194 let parsed = verify_token_signature(token)?;
195 if parsed.purpose != VerificationPurpose::Migration {
196 return Err(VerifyError::PurposeMismatch);
197 }
198 if parsed.channel != "email" {
199 return Err(VerifyError::ChannelMismatch);
200 }
201 let expected_hash = hash_identifier(expected_email);
202 if parsed.identifier_hash != expected_hash {
203 return Err(VerifyError::IdentifierMismatch);
204 }
205 Ok(parsed)
206}
207
208pub fn verify_channel_update_token(
209 token: &str,
210 expected_channel: &str,
211 expected_identifier: &str,
212) -> Result<VerificationToken, VerifyError> {
213 let parsed = verify_token_signature(token)?;
214 if parsed.purpose != VerificationPurpose::ChannelUpdate {
215 return Err(VerifyError::PurposeMismatch);
216 }
217 if parsed.channel != expected_channel {
218 return Err(VerifyError::ChannelMismatch);
219 }
220 let expected_hash = hash_identifier(expected_identifier);
221 if parsed.identifier_hash != expected_hash {
222 return Err(VerifyError::IdentifierMismatch);
223 }
224 Ok(parsed)
225}
226
227pub fn verify_token_for_did(
228 token: &str,
229 expected_did: &str,
230) -> Result<VerificationToken, VerifyError> {
231 let parsed = verify_token_signature(token)?;
232 if parsed.did != expected_did {
233 return Err(VerifyError::IdentifierMismatch);
234 }
235 Ok(parsed)
236}
237
238pub fn verify_token_signature(token: &str) -> Result<VerificationToken, VerifyError> {
239 let token_bytes = URL_SAFE_NO_PAD
240 .decode(token.trim())
241 .map_err(|_| VerifyError::InvalidFormat)?;
242 let token_str = String::from_utf8(token_bytes).map_err(|_| VerifyError::InvalidFormat)?;
243
244 let parts: Vec<&str> = token_str.split('|').collect();
245 if parts.len() != 7 {
246 return Err(VerifyError::InvalidFormat);
247 }
248
249 let version: u8 = parts[0].parse().map_err(|_| VerifyError::InvalidFormat)?;
250 if version != TOKEN_VERSION {
251 return Err(VerifyError::UnsupportedVersion);
252 }
253
254 let did = parts[1];
255 let purpose_str = parts[2];
256 let channel = parts[3];
257 let identifier_hash = parts[4];
258 let expires_at: u64 = parts[5].parse().map_err(|_| VerifyError::InvalidFormat)?;
259 let provided_signature = parts[6];
260
261 let purpose = VerificationPurpose::from_str(purpose_str).ok_or(VerifyError::InvalidFormat)?;
262
263 let now = std::time::SystemTime::now()
264 .duration_since(std::time::UNIX_EPOCH)
265 .unwrap_or_default()
266 .as_secs();
267 if now > expires_at {
268 return Err(VerifyError::Expired);
269 }
270
271 let key = derive_verification_key();
272 let payload = format!(
273 "{}|{}|{}|{}|{}",
274 did, purpose_str, channel, identifier_hash, expires_at
275 );
276 let mut mac = <HmacSha256 as Mac>::new_from_slice(&key).expect("HMAC key size is valid");
277 mac.update(payload.as_bytes());
278 let expected_signature = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
279
280 use subtle::ConstantTimeEq;
281 let sig_matches: bool = provided_signature
282 .as_bytes()
283 .ct_eq(expected_signature.as_bytes())
284 .into();
285 if !sig_matches {
286 return Err(VerifyError::InvalidSignature);
287 }
288
289 Ok(VerificationToken {
290 did: did.to_string(),
291 purpose,
292 channel: channel.to_string(),
293 identifier_hash: identifier_hash.to_string(),
294 expires_at,
295 })
296}
297
/// Render a token for humans.
///
/// Strips any existing `-` and space characters, then joins the remaining
/// characters in groups of four with `-`. Grouping counts `char`s, so multi-byte
/// characters are never split.
///
/// NOTE(review): `-` is in the URL-safe base64 alphabet, so stripping it assumes
/// genuine tokens never contain `-`. For base64 of ASCII text, `-` only appears
/// for a `>` or `~` byte — confirm token fields never contain those.
pub fn format_token_for_display(token: &str) -> String {
    let kept: Vec<char> = token.chars().filter(|&c| c != '-' && c != ' ').collect();
    kept.chunks(4)
        .map(|group| group.iter().collect::<String>())
        .collect::<Vec<String>>()
        .join("-")
}
309
/// Clean user-typed token input.
///
/// Keeps only ASCII letters, ASCII digits, `_` and `=`, dropping everything else:
/// display dashes, whitespace, and any other punctuation or non-ASCII characters.
///
/// NOTE(review): this also drops `-`, which is in the URL-safe base64 alphabet.
/// That is only safe while genuine tokens contain no `-`. For base64 of ASCII
/// text, `-` only appears for a `>` or `~` byte — confirm token fields never
/// contain those.
pub fn normalize_token_input(input: &str) -> String {
    let mut cleaned = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '=' => cleaned.push(ch),
            _ => {}
        }
    }
    cleaned
}
316
#[cfg(test)]
mod tests {
    use super::*;
    // Explicit imports shadow any glob-imported names, so the helpers below do not
    // depend on which of the parent's imports `super::*` brings in.
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};

    /// Decode a token into its raw `|`-separated text.
    fn decode(token: &str) -> String {
        let bytes = URL_SAFE_NO_PAD.decode(token).expect("token is valid base64url");
        String::from_utf8(bytes).expect("token is valid UTF-8")
    }

    /// Re-encode raw token text as-is; its signature is not recomputed.
    fn encode(raw: &str) -> String {
        URL_SAFE_NO_PAD.encode(raw.as_bytes())
    }

    #[test]
    fn test_signup_token() {
        let did = "did:plc:test123";
        let channel = "email";
        let identifier = "test@example.com";
        let token = generate_signup_token(did, channel, identifier);
        let result = verify_signup_token(&token, channel, identifier);
        assert!(result.is_ok(), "Expected Ok, got {:?}", result);
        let parsed = result.unwrap();
        assert_eq!(parsed.did, did);
        assert_eq!(parsed.purpose, VerificationPurpose::Signup);
        assert_eq!(parsed.channel, channel);
    }

    #[test]
    fn test_migration_token() {
        let did = "did:plc:test123";
        let email = "test@example.com";
        let token = generate_migration_token(did, email);
        let result = verify_migration_token(&token, email);
        assert!(result.is_ok(), "Expected Ok, got {:?}", result);
        let parsed = result.unwrap();
        assert_eq!(parsed.did, did);
        assert_eq!(parsed.purpose, VerificationPurpose::Migration);
    }

    #[test]
    fn test_channel_update_token() {
        let did = "did:plc:test123";
        let token = generate_channel_update_token(did, "email", "new@example.com");
        let parsed = verify_channel_update_token(&token, "email", "new@example.com")
            .expect("channel update token should verify");
        assert_eq!(parsed.did, did);
        assert_eq!(parsed.purpose, VerificationPurpose::ChannelUpdate);
        // A token minted for one flow must not be accepted by another.
        let as_signup = verify_signup_token(&token, "email", "new@example.com");
        assert!(matches!(as_signup, Err(VerifyError::PurposeMismatch)));
    }

    #[test]
    fn test_token_case_insensitive() {
        let did = "did:plc:test123";
        let token = generate_signup_token(did, "email", "Test@Example.COM");
        let result = verify_signup_token(&token, "email", "test@example.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_token_wrong_identifier() {
        let did = "did:plc:test123";
        let token = generate_signup_token(did, "email", "test@example.com");
        let result = verify_signup_token(&token, "email", "other@example.com");
        assert!(matches!(result, Err(VerifyError::IdentifierMismatch)));
    }

    #[test]
    fn test_token_wrong_channel() {
        let did = "did:plc:test123";
        let token = generate_signup_token(did, "email", "test@example.com");
        let result = verify_signup_token(&token, "discord", "test@example.com");
        assert!(matches!(result, Err(VerifyError::ChannelMismatch)));
    }

    #[test]
    fn test_expired_token() {
        let did = "did:plc:test123";
        let token = generate_token_with_expiry(
            did,
            VerificationPurpose::Signup,
            "email",
            "test@example.com",
            0,
        );
        // Expiry has one-second resolution; wait until `now > expires_at`.
        std::thread::sleep(std::time::Duration::from_millis(1100));
        let result = verify_signup_token(&token, "email", "test@example.com");
        assert!(matches!(result, Err(VerifyError::Expired)));
    }

    #[test]
    fn test_invalid_token() {
        let result = verify_signup_token("invalid-token", "email", "test@example.com");
        assert!(matches!(result, Err(VerifyError::InvalidFormat)));
    }

    #[test]
    fn test_tampered_token_rejected() {
        let token = generate_signup_token("did:plc:test123", "email", "test@example.com");
        let tampered = decode(&token).replacen("did:plc:test123", "did:plc:attacker", 1);
        let result = verify_signup_token(&encode(&tampered), "email", "test@example.com");
        assert!(matches!(result, Err(VerifyError::InvalidSignature)));
    }

    #[test]
    fn test_unsupported_version() {
        let token = generate_signup_token("did:plc:test123", "email", "test@example.com");
        let current = format!("{}|", TOKEN_VERSION);
        let next = format!("{}|", TOKEN_VERSION + 1);
        let bumped = decode(&token).replacen(&current, &next, 1);
        let result = verify_signup_token(&encode(&bumped), "email", "test@example.com");
        assert!(matches!(result, Err(VerifyError::UnsupportedVersion)));
    }

    #[test]
    fn test_purpose_mismatch() {
        let did = "did:plc:test123";
        let email = "test@example.com";
        let signup_token = generate_signup_token(did, "email", email);
        let result = verify_migration_token(&signup_token, email);
        assert!(matches!(result, Err(VerifyError::PurposeMismatch)));
    }

    #[test]
    fn test_verify_token_for_did() {
        let token = generate_channel_update_token("did:plc:test123", "email", "new@example.com");
        let parsed = verify_token_for_did(&token, "did:plc:test123").expect("same DID verifies");
        assert_eq!(parsed.purpose, VerificationPurpose::ChannelUpdate);
        let other = verify_token_for_did(&token, "did:plc:other");
        assert!(matches!(other, Err(VerifyError::IdentifierMismatch)));
    }

    #[test]
    fn test_discord_channel() {
        let did = "did:plc:test123";
        let discord_id = "123456789012345678";
        let token = generate_signup_token(did, "discord", discord_id);
        let result = verify_signup_token(&token, "discord", discord_id);
        assert!(result.is_ok());
    }

    #[test]
    fn test_format_token_for_display() {
        let token = "ABCDEFGHIJKLMNOP";
        let formatted = format_token_for_display(token);
        assert_eq!(formatted, "ABCD-EFGH-IJKL-MNOP");
    }

    #[test]
    fn test_normalize_token_input() {
        let input = "ABCD-EFGH IJKL-MNOP";
        let normalized = normalize_token_input(input);
        assert_eq!(normalized, "ABCDEFGHIJKLMNOP");
    }

    #[test]
    fn test_display_roundtrip() {
        let token = generate_signup_token("did:plc:test123", "email", "test@example.com");
        let shown = format_token_for_display(&token);
        let typed_back = normalize_token_input(&shown);
        assert_eq!(typed_back, token);
        assert!(verify_signup_token(&typed_back, "email", "test@example.com").is_ok());
    }
}