this repo has no description
1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use hmac::Mac;
3use sha2::{Digest, Sha256};
4
5type HmacSha256 = hmac::Hmac<Sha256>;
6
/// Format version written as the first field of every token; verification
/// rejects tokens carrying any other version.
const TOKEN_VERSION: u8 = 1;
/// Default lifetime of a signup token, in minutes.
const DEFAULT_SIGNUP_EXPIRY_MINUTES: u64 = 30;
/// Default lifetime of a migration token, in hours (two days).
const DEFAULT_MIGRATION_EXPIRY_HOURS: u64 = 2 * 24;
/// Default lifetime of a channel-update token, in minutes.
const DEFAULT_CHANNEL_UPDATE_EXPIRY_MINUTES: u64 = 10;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13pub enum VerificationPurpose {
14 Signup,
15 Migration,
16 ChannelUpdate,
17}
18
19impl VerificationPurpose {
20 fn as_str(&self) -> &'static str {
21 match self {
22 Self::Signup => "signup",
23 Self::Migration => "migration",
24 Self::ChannelUpdate => "channel_update",
25 }
26 }
27
28 fn from_str(s: &str) -> Option<Self> {
29 match s {
30 "signup" => Some(Self::Signup),
31 "migration" => Some(Self::Migration),
32 "channel_update" => Some(Self::ChannelUpdate),
33 _ => None,
34 }
35 }
36
37 fn default_expiry_seconds(&self) -> u64 {
38 match self {
39 Self::Signup => DEFAULT_SIGNUP_EXPIRY_MINUTES * 60,
40 Self::Migration => DEFAULT_MIGRATION_EXPIRY_HOURS * 3600,
41 Self::ChannelUpdate => DEFAULT_CHANNEL_UPDATE_EXPIRY_MINUTES * 60,
42 }
43 }
44}
45
46#[derive(Debug)]
47pub struct VerificationToken {
48 pub did: String,
49 pub purpose: VerificationPurpose,
50 pub channel: String,
51 pub identifier_hash: String,
52 pub expires_at: u64,
53}
54
55fn derive_verification_key() -> [u8; 32] {
56 use hkdf::Hkdf;
57 let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| {
58 if cfg!(test) || std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_ok() {
59 "test-master-key-not-for-production".to_string()
60 } else {
61 panic!("MASTER_KEY must be set");
62 }
63 });
64 let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes());
65 let mut key = [0u8; 32];
66 hk.expand(b"tranquil-pds-verification-token-v1", &mut key)
67 .expect("HKDF expansion failed");
68 key
69}
70
71pub fn hash_identifier(identifier: &str) -> String {
72 let mut hasher = Sha256::new();
73 hasher.update(identifier.to_lowercase().as_bytes());
74 let result = hasher.finalize();
75 URL_SAFE_NO_PAD.encode(&result[..16])
76}
77
78pub fn generate_signup_token(did: &str, channel: &str, identifier: &str) -> String {
79 generate_token(did, VerificationPurpose::Signup, channel, identifier)
80}
81
82pub fn generate_migration_token(did: &str, email: &str) -> String {
83 generate_token(did, VerificationPurpose::Migration, "email", email)
84}
85
86pub fn generate_channel_update_token(did: &str, channel: &str, identifier: &str) -> String {
87 generate_token(did, VerificationPurpose::ChannelUpdate, channel, identifier)
88}
89
90pub fn generate_token(
91 did: &str,
92 purpose: VerificationPurpose,
93 channel: &str,
94 identifier: &str,
95) -> String {
96 generate_token_with_expiry(
97 did,
98 purpose,
99 channel,
100 identifier,
101 purpose.default_expiry_seconds(),
102 )
103}
104
105pub fn generate_token_with_expiry(
106 did: &str,
107 purpose: VerificationPurpose,
108 channel: &str,
109 identifier: &str,
110 expiry_seconds: u64,
111) -> String {
112 let key = derive_verification_key();
113 let identifier_hash = hash_identifier(identifier);
114 let expires_at = std::time::SystemTime::now()
115 .duration_since(std::time::UNIX_EPOCH)
116 .unwrap_or_default()
117 .as_secs()
118 + expiry_seconds;
119
120 let payload = format!(
121 "{}|{}|{}|{}|{}",
122 did,
123 purpose.as_str(),
124 channel,
125 identifier_hash,
126 expires_at
127 );
128
129 let mut mac = <HmacSha256 as Mac>::new_from_slice(&key).expect("HMAC key size is valid");
130 mac.update(payload.as_bytes());
131 let signature = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
132
133 let token_data = format!(
134 "{}|{}|{}|{}|{}|{}|{}",
135 TOKEN_VERSION,
136 did,
137 purpose.as_str(),
138 channel,
139 identifier_hash,
140 expires_at,
141 signature
142 );
143 URL_SAFE_NO_PAD.encode(token_data.as_bytes())
144}
145
/// Reasons a verification token can be rejected.
///
/// Implements `std::error::Error` so it composes with `?`, `Box<dyn Error>`
/// and error-wrapping libraries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VerifyError {
    /// The token is malformed. Causes include: not unpadded URL-safe base64,
    /// not UTF-8, the wrong number of `|`-separated fields, or a field that
    /// does not parse.
    InvalidFormat,
    /// The version field parsed but is not the current token version.
    UnsupportedVersion,
    /// The token's expiry time has passed.
    Expired,
    /// The signature does not match the token contents.
    InvalidSignature,
    /// The bound identifier does not match. `verify_token_for_did` also
    /// returns this when the DID differs.
    IdentifierMismatch,
    /// The token was issued for a different purpose.
    PurposeMismatch,
    /// The token was issued for a different channel.
    ChannelMismatch,
}

impl std::fmt::Display for VerifyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::InvalidFormat => "Invalid token format",
            Self::UnsupportedVersion => "Unsupported token version",
            Self::Expired => "Token has expired",
            Self::InvalidSignature => "Invalid token signature",
            Self::IdentifierMismatch => "Identifier does not match token",
            Self::PurposeMismatch => "Token purpose does not match",
            Self::ChannelMismatch => "Token channel does not match",
        })
    }
}

impl std::error::Error for VerifyError {}
170
171pub fn verify_signup_token(
172 token: &str,
173 expected_channel: &str,
174 expected_identifier: &str,
175) -> Result<VerificationToken, VerifyError> {
176 let parsed = verify_token_signature(token)?;
177 if parsed.purpose != VerificationPurpose::Signup {
178 return Err(VerifyError::PurposeMismatch);
179 }
180 if parsed.channel != expected_channel {
181 return Err(VerifyError::ChannelMismatch);
182 }
183 let expected_hash = hash_identifier(expected_identifier);
184 if parsed.identifier_hash != expected_hash {
185 return Err(VerifyError::IdentifierMismatch);
186 }
187 Ok(parsed)
188}
189
190pub fn verify_migration_token(
191 token: &str,
192 expected_email: &str,
193) -> Result<VerificationToken, VerifyError> {
194 let parsed = verify_token_signature(token)?;
195 if parsed.purpose != VerificationPurpose::Migration {
196 return Err(VerifyError::PurposeMismatch);
197 }
198 if parsed.channel != "email" {
199 return Err(VerifyError::ChannelMismatch);
200 }
201 let expected_hash = hash_identifier(expected_email);
202 if parsed.identifier_hash != expected_hash {
203 return Err(VerifyError::IdentifierMismatch);
204 }
205 Ok(parsed)
206}
207
208pub fn verify_channel_update_token(
209 token: &str,
210 expected_channel: &str,
211 expected_identifier: &str,
212) -> Result<VerificationToken, VerifyError> {
213 let parsed = verify_token_signature(token)?;
214 if parsed.purpose != VerificationPurpose::ChannelUpdate {
215 return Err(VerifyError::PurposeMismatch);
216 }
217 if parsed.channel != expected_channel {
218 return Err(VerifyError::ChannelMismatch);
219 }
220 let expected_hash = hash_identifier(expected_identifier);
221 if parsed.identifier_hash != expected_hash {
222 return Err(VerifyError::IdentifierMismatch);
223 }
224 Ok(parsed)
225}
226
227pub fn verify_token_for_did(
228 token: &str,
229 expected_did: &str,
230) -> Result<VerificationToken, VerifyError> {
231 let parsed = verify_token_signature(token)?;
232 if parsed.did != expected_did {
233 return Err(VerifyError::IdentifierMismatch);
234 }
235 Ok(parsed)
236}
237
238pub fn verify_token_signature(token: &str) -> Result<VerificationToken, VerifyError> {
239 let token_bytes = URL_SAFE_NO_PAD
240 .decode(token.trim())
241 .map_err(|_| VerifyError::InvalidFormat)?;
242 let token_str = String::from_utf8(token_bytes).map_err(|_| VerifyError::InvalidFormat)?;
243
244 let parts: Vec<&str> = token_str.split('|').collect();
245 if parts.len() != 7 {
246 return Err(VerifyError::InvalidFormat);
247 }
248
249 let version: u8 = parts[0].parse().map_err(|_| VerifyError::InvalidFormat)?;
250 if version != TOKEN_VERSION {
251 return Err(VerifyError::UnsupportedVersion);
252 }
253
254 let did = parts[1];
255 let purpose_str = parts[2];
256 let channel = parts[3];
257 let identifier_hash = parts[4];
258 let expires_at: u64 = parts[5].parse().map_err(|_| VerifyError::InvalidFormat)?;
259 let provided_signature = parts[6];
260
261 let purpose = VerificationPurpose::from_str(purpose_str).ok_or(VerifyError::InvalidFormat)?;
262
263 let now = std::time::SystemTime::now()
264 .duration_since(std::time::UNIX_EPOCH)
265 .unwrap_or_default()
266 .as_secs();
267 if now > expires_at {
268 return Err(VerifyError::Expired);
269 }
270
271 let key = derive_verification_key();
272 let payload = format!(
273 "{}|{}|{}|{}|{}",
274 did, purpose_str, channel, identifier_hash, expires_at
275 );
276 let mut mac = <HmacSha256 as Mac>::new_from_slice(&key).expect("HMAC key size is valid");
277 mac.update(payload.as_bytes());
278 let expected_signature = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
279
280 use subtle::ConstantTimeEq;
281 let sig_matches: bool = provided_signature
282 .as_bytes()
283 .ct_eq(expected_signature.as_bytes())
284 .into();
285 if !sig_matches {
286 return Err(VerifyError::InvalidSignature);
287 }
288
289 Ok(VerificationToken {
290 did: did.to_string(),
291 purpose,
292 channel: channel.to_string(),
293 identifier_hash: identifier_hash.to_string(),
294 expires_at,
295 })
296}
297
/// Formats `token` for display as dash-separated groups of four characters.
///
/// Any `-` and space characters already present are dropped first. For
/// example, `"ABCDEFGHIJ"` becomes `"ABCD-EFGH-IJ"`.
///
/// NOTE(review): `-` is part of the URL-safe base64 alphabet used for this
/// module's tokens. Applying this to such a token would silently drop
/// characters. Confirm it is only used on codes that never contain `-`.
pub fn format_token_for_display(token: &str) -> String {
    let mut grouped = String::with_capacity(token.len() + token.len() / 4);
    let kept = token.chars().filter(|&c| c != '-' && c != ' ');
    for (index, ch) in kept.enumerate() {
        if index > 0 && index % 4 == 0 {
            grouped.push('-');
        }
        grouped.push(ch);
    }
    grouped
}
308
/// Cleans user-typed token input by keeping only ASCII letters, ASCII digits,
/// `_`, and `=`.
///
/// Everything else is discarded, including dashes, whitespace, and non-ASCII
/// characters.
///
/// NOTE(review): `-` belongs to the URL-safe base64 alphabet used for this
/// module's tokens and is dropped here. Confirm the input is never such a
/// token.
pub fn normalize_token_input(input: &str) -> String {
    let mut cleaned = input.to_owned();
    cleaned.retain(|ch| ch.is_ascii_alphanumeric() || ch == '_' || ch == '=');
    cleaned
}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_signup_token() {
        let did = "did:plc:test123";
        let channel = "email";
        let identifier = "test@example.com";
        let token = generate_signup_token(did, channel, identifier);
        let result = verify_signup_token(&token, channel, identifier);
        assert!(result.is_ok(), "Expected Ok, got {:?}", result);
        let parsed = result.unwrap();
        assert_eq!(parsed.did, did);
        assert_eq!(parsed.purpose, VerificationPurpose::Signup);
        assert_eq!(parsed.channel, channel);
    }

    #[test]
    fn test_migration_token() {
        let did = "did:plc:test123";
        let email = "test@example.com";
        let token = generate_migration_token(did, email);
        let result = verify_migration_token(&token, email);
        assert!(result.is_ok(), "Expected Ok, got {:?}", result);
        let parsed = result.unwrap();
        assert_eq!(parsed.did, did);
        assert_eq!(parsed.purpose, VerificationPurpose::Migration);
    }

    #[test]
    fn test_channel_update_token() {
        let did = "did:plc:test123";
        let token = generate_channel_update_token(did, "email", "new@example.com");
        let parsed = verify_channel_update_token(&token, "email", "new@example.com")
            .expect("channel update token should verify");
        assert_eq!(parsed.did, did);
        assert_eq!(parsed.purpose, VerificationPurpose::ChannelUpdate);
        assert!(matches!(
            verify_signup_token(&token, "email", "new@example.com"),
            Err(VerifyError::PurposeMismatch)
        ));
    }

    #[test]
    fn test_token_case_insensitive() {
        let did = "did:plc:test123";
        let token = generate_signup_token(did, "email", "Test@Example.COM");
        let result = verify_signup_token(&token, "email", "test@example.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_token_wrong_identifier() {
        let did = "did:plc:test123";
        let token = generate_signup_token(did, "email", "test@example.com");
        let result = verify_signup_token(&token, "email", "other@example.com");
        assert!(matches!(result, Err(VerifyError::IdentifierMismatch)));
    }

    #[test]
    fn test_token_wrong_channel() {
        let did = "did:plc:test123";
        let token = generate_signup_token(did, "email", "test@example.com");
        let result = verify_signup_token(&token, "discord", "test@example.com");
        assert!(matches!(result, Err(VerifyError::ChannelMismatch)));
    }

    #[test]
    fn test_expired_token() {
        let did = "did:plc:test123";
        let token = generate_token_with_expiry(
            did,
            VerificationPurpose::Signup,
            "email",
            "test@example.com",
            0,
        );
        std::thread::sleep(std::time::Duration::from_millis(1100));
        let result = verify_signup_token(&token, "email", "test@example.com");
        assert!(matches!(result, Err(VerifyError::Expired)));
    }

    #[test]
    fn test_invalid_token() {
        let result = verify_signup_token("invalid-token", "email", "test@example.com");
        assert!(matches!(result, Err(VerifyError::InvalidFormat)));
    }

    #[test]
    fn test_purpose_mismatch() {
        let did = "did:plc:test123";
        let email = "test@example.com";
        let signup_token = generate_signup_token(did, "email", email);
        let result = verify_migration_token(&signup_token, email);
        assert!(matches!(result, Err(VerifyError::PurposeMismatch)));
    }

    #[test]
    fn test_discord_channel() {
        let did = "did:plc:test123";
        let discord_id = "123456789012345678";
        let token = generate_signup_token(did, "discord", discord_id);
        let result = verify_signup_token(&token, "discord", discord_id);
        assert!(result.is_ok());
    }

    #[test]
    fn test_verify_token_for_did() {
        let token = generate_signup_token("did:plc:alice", "email", "a@example.com");
        let parsed = verify_token_for_did(&token, "did:plc:alice").expect("DID should match");
        assert_eq!(parsed.did, "did:plc:alice");
        assert!(matches!(
            verify_token_for_did(&token, "did:plc:bob"),
            Err(VerifyError::IdentifierMismatch)
        ));
    }

    #[test]
    fn test_tampered_token_rejected() {
        use base64::Engine as _;
        let token = generate_signup_token("did:plc:alice", "email", "a@example.com");
        let decoded = String::from_utf8(URL_SAFE_NO_PAD.decode(&token).unwrap()).unwrap();
        let tampered = URL_SAFE_NO_PAD.encode(decoded.replace("did:plc:alice", "did:plc:mallory"));
        assert!(matches!(
            verify_token_signature(&tampered),
            Err(VerifyError::InvalidSignature)
        ));
    }

    #[test]
    fn test_forged_signature_rejected() {
        use base64::Engine as _;
        let forged = URL_SAFE_NO_PAD.encode(format!(
            "{}|did:plc:test123|signup|email|{}|{}|not-a-real-signature",
            TOKEN_VERSION,
            hash_identifier("test@example.com"),
            u64::MAX
        ));
        assert!(matches!(
            verify_signup_token(&forged, "email", "test@example.com"),
            Err(VerifyError::InvalidSignature)
        ));
    }

    #[test]
    fn test_unsupported_version() {
        use base64::Engine as _;
        let token = URL_SAFE_NO_PAD.encode("99|did:plc:x|signup|email|h|1|sig");
        assert!(matches!(
            verify_token_signature(&token),
            Err(VerifyError::UnsupportedVersion)
        ));
    }

    #[test]
    fn test_malformed_payloads() {
        use base64::Engine as _;
        let too_few = URL_SAFE_NO_PAD.encode(format!("{}|did:plc:x|signup", TOKEN_VERSION));
        assert!(matches!(
            verify_token_signature(&too_few),
            Err(VerifyError::InvalidFormat)
        ));
        let bad_purpose = URL_SAFE_NO_PAD.encode(format!(
            "{}|did:plc:x|bogus|email|h|{}|sig",
            TOKEN_VERSION,
            u64::MAX
        ));
        assert!(matches!(
            verify_token_signature(&bad_purpose),
            Err(VerifyError::InvalidFormat)
        ));
        let bad_expiry =
            URL_SAFE_NO_PAD.encode(format!("{}|did:plc:x|signup|email|h|soon|sig", TOKEN_VERSION));
        assert!(matches!(
            verify_token_signature(&bad_expiry),
            Err(VerifyError::InvalidFormat)
        ));
        let not_utf8 = URL_SAFE_NO_PAD.encode([0xffu8, 0xfe, 0xfd]);
        assert!(matches!(
            verify_token_signature(&not_utf8),
            Err(VerifyError::InvalidFormat)
        ));
    }

    #[test]
    fn test_token_with_surrounding_whitespace() {
        let token = generate_signup_token("did:plc:test123", "email", "test@example.com");
        let padded = format!("  {}\n", token);
        assert!(verify_signup_token(&padded, "email", "test@example.com").is_ok());
    }

    #[test]
    fn test_format_token_for_display() {
        let token = "ABCDEFGHIJKLMNOP";
        let formatted = format_token_for_display(token);
        assert_eq!(formatted, "ABCD-EFGH-IJKL-MNOP");
    }

    #[test]
    fn test_format_token_for_display_edge_cases() {
        assert_eq!(format_token_for_display(""), "");
        assert_eq!(format_token_for_display("ABCDEFGHIJ"), "ABCD-EFGH-IJ");
        assert_eq!(format_token_for_display("AB-CD EF"), "ABCD-EF");
    }

    #[test]
    fn test_normalize_token_input() {
        let input = "ABCD-EFGH IJKL-MNOP";
        let normalized = normalize_token_input(input);
        assert_eq!(normalized, "ABCDEFGHIJKLMNOP");
    }

    #[test]
    fn test_normalize_token_input_keeps_underscore_and_equals() {
        assert_eq!(normalize_token_input("a_b=c-d é!"), "a_b=cd");
        assert_eq!(normalize_token_input(""), "");
    }
}