this repo has no description
at main 13 kB view raw
1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 2use hmac::Mac; 3use sha2::{Digest, Sha256}; 4 5type HmacSha256 = hmac::Hmac<Sha256>; 6 7const TOKEN_VERSION: u8 = 1; 8const DEFAULT_SIGNUP_EXPIRY_MINUTES: u64 = 30; 9const DEFAULT_MIGRATION_EXPIRY_HOURS: u64 = 48; 10const DEFAULT_CHANNEL_UPDATE_EXPIRY_MINUTES: u64 = 10; 11 12#[derive(Debug, Clone, Copy, PartialEq, Eq)] 13pub enum VerificationPurpose { 14 Signup, 15 Migration, 16 ChannelUpdate, 17} 18 19impl VerificationPurpose { 20 fn as_str(&self) -> &'static str { 21 match self { 22 Self::Signup => "signup", 23 Self::Migration => "migration", 24 Self::ChannelUpdate => "channel_update", 25 } 26 } 27 28 fn from_str(s: &str) -> Option<Self> { 29 match s { 30 "signup" => Some(Self::Signup), 31 "migration" => Some(Self::Migration), 32 "channel_update" => Some(Self::ChannelUpdate), 33 _ => None, 34 } 35 } 36 37 fn default_expiry_seconds(&self) -> u64 { 38 match self { 39 Self::Signup => DEFAULT_SIGNUP_EXPIRY_MINUTES * 60, 40 Self::Migration => DEFAULT_MIGRATION_EXPIRY_HOURS * 3600, 41 Self::ChannelUpdate => DEFAULT_CHANNEL_UPDATE_EXPIRY_MINUTES * 60, 42 } 43 } 44} 45 46#[derive(Debug)] 47pub struct VerificationToken { 48 pub did: String, 49 pub purpose: VerificationPurpose, 50 pub channel: String, 51 pub identifier_hash: String, 52 pub expires_at: u64, 53} 54 55fn derive_verification_key() -> [u8; 32] { 56 use hkdf::Hkdf; 57 let master_key = std::env::var("MASTER_KEY").unwrap_or_else(|_| { 58 if cfg!(test) || std::env::var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS").is_ok() { 59 "test-master-key-not-for-production".to_string() 60 } else { 61 panic!("MASTER_KEY must be set"); 62 } 63 }); 64 let hk = Hkdf::<Sha256>::new(None, master_key.as_bytes()); 65 let mut key = [0u8; 32]; 66 hk.expand(b"tranquil-pds-verification-token-v1", &mut key) 67 .expect("HKDF expansion failed"); 68 key 69} 70 71pub fn hash_identifier(identifier: &str) -> String { 72 let mut hasher = Sha256::new(); 73 hasher.update(identifier.to_lowercase().as_bytes()); 74 let result = hasher.finalize(); 75 URL_SAFE_NO_PAD.encode(&result[..16]) 76} 77 78pub fn generate_signup_token(did: &str, channel: &str, identifier: &str) -> String { 79 generate_token(did, VerificationPurpose::Signup, channel, identifier) 80} 81 82pub fn generate_migration_token(did: &str, email: &str) -> String { 83 generate_token(did, VerificationPurpose::Migration, "email", email) 84} 85 86pub fn generate_channel_update_token(did: &str, channel: &str, identifier: &str) -> String { 87 generate_token(did, VerificationPurpose::ChannelUpdate, channel, identifier) 88} 89 90pub fn generate_token( 91 did: &str, 92 purpose: VerificationPurpose, 93 channel: &str, 94 identifier: &str, 95) -> String { 96 generate_token_with_expiry( 97 did, 98 purpose, 99 channel, 100 identifier, 101 purpose.default_expiry_seconds(), 102 ) 103} 104 105pub fn generate_token_with_expiry( 106 did: &str, 107 purpose: VerificationPurpose, 108 channel: &str, 109 identifier: &str, 110 expiry_seconds: u64, 111) -> String { 112 let key = derive_verification_key(); 113 let identifier_hash = hash_identifier(identifier); 114 let expires_at = std::time::SystemTime::now() 115 .duration_since(std::time::UNIX_EPOCH) 116 .unwrap_or_default() 117 .as_secs() 118 + expiry_seconds; 119 120 let payload = format!( 121 "{}|{}|{}|{}|{}", 122 did, 123 purpose.as_str(), 124 channel, 125 identifier_hash, 126 expires_at 127 ); 128 129 let mut mac = <HmacSha256 as Mac>::new_from_slice(&key).expect("HMAC key size is valid"); 130 mac.update(payload.as_bytes()); 131 let signature = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes()); 132 133 let token_data = format!( 134 "{}|{}|{}|{}|{}|{}|{}", 135 TOKEN_VERSION, 136 did, 137 purpose.as_str(), 138 channel, 139 identifier_hash, 140 expires_at, 141 signature 142 ); 143 URL_SAFE_NO_PAD.encode(token_data.as_bytes()) 144} 145 146#[derive(Debug)] 147pub enum VerifyError { 148 InvalidFormat, 149 UnsupportedVersion, 150 Expired, 151 InvalidSignature, 152 IdentifierMismatch, 153 PurposeMismatch, 154 ChannelMismatch, 155} 156 157impl std::fmt::Display for VerifyError { 158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 159 match self { 160 Self::InvalidFormat => write!(f, "Invalid token format"), 161 Self::UnsupportedVersion => write!(f, "Unsupported token version"), 162 Self::Expired => write!(f, "Token has expired"), 163 Self::InvalidSignature => write!(f, "Invalid token signature"), 164 Self::IdentifierMismatch => write!(f, "Identifier does not match token"), 165 Self::PurposeMismatch => write!(f, "Token purpose does not match"), 166 Self::ChannelMismatch => write!(f, "Token channel does not match"), 167 } 168 } 169} 170 171pub fn verify_signup_token( 172 token: &str, 173 expected_channel: &str, 174 expected_identifier: &str, 175) -> Result<VerificationToken, VerifyError> { 176 let parsed = verify_token_signature(token)?; 177 if parsed.purpose != VerificationPurpose::Signup { 178 return Err(VerifyError::PurposeMismatch); 179 } 180 if parsed.channel != expected_channel { 181 return Err(VerifyError::ChannelMismatch); 182 } 183 let expected_hash = hash_identifier(expected_identifier); 184 if parsed.identifier_hash != expected_hash { 185 return Err(VerifyError::IdentifierMismatch); 186 } 187 Ok(parsed) 188} 189 190pub fn verify_migration_token( 191 token: &str, 192 expected_email: &str, 193) -> Result<VerificationToken, VerifyError> { 194 let parsed = verify_token_signature(token)?; 195 if parsed.purpose != VerificationPurpose::Migration { 196 return Err(VerifyError::PurposeMismatch); 197 } 198 if parsed.channel != "email" { 199 return Err(VerifyError::ChannelMismatch); 200 } 201 let expected_hash = hash_identifier(expected_email); 202 if parsed.identifier_hash != expected_hash { 203 return Err(VerifyError::IdentifierMismatch); 204 } 205 Ok(parsed) 206} 207 208pub fn verify_channel_update_token( 209 token: &str, 210 expected_channel: &str, 211 expected_identifier: &str, 212) -> Result<VerificationToken, VerifyError> { 213 let parsed = verify_token_signature(token)?; 214 if parsed.purpose != VerificationPurpose::ChannelUpdate { 215 return Err(VerifyError::PurposeMismatch); 216 } 217 if parsed.channel != expected_channel { 218 return Err(VerifyError::ChannelMismatch); 219 } 220 let expected_hash = hash_identifier(expected_identifier); 221 if parsed.identifier_hash != expected_hash { 222 return Err(VerifyError::IdentifierMismatch); 223 } 224 Ok(parsed) 225} 226 227pub fn verify_token_for_did( 228 token: &str, 229 expected_did: &str, 230) -> Result<VerificationToken, VerifyError> { 231 let parsed = verify_token_signature(token)?; 232 if parsed.did != expected_did { 233 return Err(VerifyError::IdentifierMismatch); 234 } 235 Ok(parsed) 236} 237 238pub fn verify_token_signature(token: &str) -> Result<VerificationToken, VerifyError> { 239 let token_bytes = URL_SAFE_NO_PAD 240 .decode(token.trim()) 241 .map_err(|_| VerifyError::InvalidFormat)?; 242 let token_str = String::from_utf8(token_bytes).map_err(|_| VerifyError::InvalidFormat)?; 243 244 let parts: Vec<&str> = token_str.split('|').collect(); 245 if parts.len() != 7 { 246 return Err(VerifyError::InvalidFormat); 247 } 248 249 let version: u8 = parts[0].parse().map_err(|_| VerifyError::InvalidFormat)?; 250 if version != TOKEN_VERSION { 251 return Err(VerifyError::UnsupportedVersion); 252 } 253 254 let did = parts[1]; 255 let purpose_str = parts[2]; 256 let channel = parts[3]; 257 let identifier_hash = parts[4]; 258 let expires_at: u64 = parts[5].parse().map_err(|_| VerifyError::InvalidFormat)?; 259 let provided_signature = parts[6]; 260 261 let purpose = VerificationPurpose::from_str(purpose_str).ok_or(VerifyError::InvalidFormat)?; 262 263 let now = std::time::SystemTime::now() 264 .duration_since(std::time::UNIX_EPOCH) 265 .unwrap_or_default() 266 .as_secs(); 267 if now > expires_at { 268 return Err(VerifyError::Expired); 269 } 270 271 let key = derive_verification_key(); 272 let payload = format!( 273 "{}|{}|{}|{}|{}", 274 did, purpose_str, channel, identifier_hash, expires_at 275 ); 276 let mut mac = <HmacSha256 as Mac>::new_from_slice(&key).expect("HMAC key size is valid"); 277 mac.update(payload.as_bytes()); 278 let expected_signature = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes()); 279 280 use subtle::ConstantTimeEq; 281 let sig_matches: bool = provided_signature 282 .as_bytes() 283 .ct_eq(expected_signature.as_bytes()) 284 .into(); 285 if !sig_matches { 286 return Err(VerifyError::InvalidSignature); 287 } 288 289 Ok(VerificationToken { 290 did: did.to_string(), 291 purpose, 292 channel: channel.to_string(), 293 identifier_hash: identifier_hash.to_string(), 294 expires_at, 295 }) 296} 297 298pub fn format_token_for_display(token: &str) -> String { 299 token 300 .replace(['-', ' '], "") 301 .chars() 302 .collect::<Vec<_>>() 303 .chunks(4) 304 .map(|chunk| chunk.iter().collect::<String>()) 305 .collect::<Vec<_>>() 306 .join("-") 307} 308 309pub fn normalize_token_input(input: &str) -> String { 310 input 311 .chars() 312 .filter(|c| c.is_ascii_alphanumeric() || *c == '_' || *c == '=') 313 .collect() 314} 315 316#[cfg(test)] 317mod tests { 318 use super::*; 319 320 #[test] 321 fn test_signup_token() { 322 let did = "did:plc:test123"; 323 let channel = "email"; 324 let identifier = "test@example.com"; 325 let token = generate_signup_token(did, channel, identifier); 326 let result = verify_signup_token(&token, channel, identifier); 327 assert!(result.is_ok(), "Expected Ok, got {:?}", result); 328 let parsed = result.unwrap(); 329 assert_eq!(parsed.did, did); 330 assert_eq!(parsed.purpose, VerificationPurpose::Signup); 331 assert_eq!(parsed.channel, channel); 332 } 333 334 #[test] 335 fn test_migration_token() { 336 let did = "did:plc:test123"; 337 let email = "test@example.com"; 338 let token = generate_migration_token(did, email); 339 let result = verify_migration_token(&token, email); 340 assert!(result.is_ok(), "Expected Ok, got {:?}", result); 341 let parsed = result.unwrap(); 342 assert_eq!(parsed.did, did); 343 assert_eq!(parsed.purpose, VerificationPurpose::Migration); 344 } 345 346 #[test] 347 fn test_token_case_insensitive() { 348 let did = "did:plc:test123"; 349 let token = generate_signup_token(did, "email", "Test@Example.COM"); 350 let result = verify_signup_token(&token, "email", "test@example.com"); 351 assert!(result.is_ok()); 352 } 353 354 #[test] 355 fn test_token_wrong_identifier() { 356 let did = "did:plc:test123"; 357 let token = generate_signup_token(did, "email", "test@example.com"); 358 let result = verify_signup_token(&token, "email", "other@example.com"); 359 assert!(matches!(result, Err(VerifyError::IdentifierMismatch))); 360 } 361 362 #[test] 363 fn test_token_wrong_channel() { 364 let did = "did:plc:test123"; 365 let token = generate_signup_token(did, "email", "test@example.com"); 366 let result = verify_signup_token(&token, "discord", "test@example.com"); 367 assert!(matches!(result, Err(VerifyError::ChannelMismatch))); 368 } 369 370 #[test] 371 fn test_expired_token() { 372 let did = "did:plc:test123"; 373 let token = generate_token_with_expiry( 374 did, 375 VerificationPurpose::Signup, 376 "email", 377 "test@example.com", 378 0, 379 ); 380 std::thread::sleep(std::time::Duration::from_millis(1100)); 381 let result = verify_signup_token(&token, "email", "test@example.com"); 382 assert!(matches!(result, Err(VerifyError::Expired))); 383 } 384 385 #[test] 386 fn test_invalid_token() { 387 let result = verify_signup_token("invalid-token", "email", "test@example.com"); 388 assert!(matches!(result, Err(VerifyError::InvalidFormat))); 389 } 390 391 #[test] 392 fn test_purpose_mismatch() { 393 let did = "did:plc:test123"; 394 let email = "test@example.com"; 395 let signup_token = generate_signup_token(did, "email", email); 396 let result = verify_migration_token(&signup_token, email); 397 assert!(matches!(result, Err(VerifyError::PurposeMismatch))); 398 } 399 400 #[test] 401 fn test_discord_channel() { 402 let did = "did:plc:test123"; 403 let discord_id = "123456789012345678"; 404 let token = generate_signup_token(did, "discord", discord_id); 405 let result = verify_signup_token(&token, "discord", discord_id); 406 assert!(result.is_ok()); 407 } 408 409 #[test] 410 fn test_format_token_for_display() { 411 let token = "ABCDEFGHIJKLMNOP"; 412 let formatted = format_token_for_display(token); 413 assert_eq!(formatted, "ABCD-EFGH-IJKL-MNOP"); 414 } 415 416 #[test] 417 fn test_normalize_token_input() { 418 let input = "ABCD-EFGH IJKL-MNOP"; 419 let normalized = normalize_token_input(input); 420 assert_eq!(normalized, "ABCDEFGHIJKLMNOP"); 421 } 422}