this repo has no description
1pub mod reserved; 2 3use hickory_resolver::TokioAsyncResolver; 4use hickory_resolver::config::{ResolverConfig, ResolverOpts}; 5use reqwest::Client; 6use std::time::Duration; 7use thiserror::Error; 8 9#[derive(Error, Debug)] 10pub enum HandleResolutionError { 11 #[error("DNS lookup failed: {0}")] 12 DnsError(String), 13 #[error("HTTP request failed: {0}")] 14 HttpError(String), 15 #[error("No DID found for handle")] 16 NotFound, 17 #[error("Invalid DID format in record")] 18 InvalidDid, 19 #[error("DID mismatch: expected {expected}, got {actual}")] 20 DidMismatch { expected: String, actual: String }, 21} 22 23pub async fn resolve_handle_dns(handle: &str) -> Result<String, HandleResolutionError> { 24 let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default()); 25 let query_name = format!("_atproto.{}", handle); 26 let txt_lookup = resolver 27 .txt_lookup(&query_name) 28 .await 29 .map_err(|e| HandleResolutionError::DnsError(e.to_string()))?; 30 for record in txt_lookup.iter() { 31 for txt in record.txt_data() { 32 let txt_str = String::from_utf8_lossy(txt); 33 if let Some(did) = txt_str.strip_prefix("did=") { 34 let did = did.trim(); 35 if did.starts_with("did:") { 36 return Ok(did.to_string()); 37 } 38 } 39 } 40 } 41 Err(HandleResolutionError::NotFound) 42} 43 44pub async fn resolve_handle_http(handle: &str) -> Result<String, HandleResolutionError> { 45 let url = format!("https://{}/.well-known/atproto-did", handle); 46 let client = Client::builder() 47 .timeout(Duration::from_secs(10)) 48 .redirect(reqwest::redirect::Policy::limited(5)) 49 .build() 50 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?; 51 let response = client 52 .get(&url) 53 .header("Accept", "text/plain") 54 .send() 55 .await 56 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?; 57 if !response.status().is_success() { 58 return Err(HandleResolutionError::NotFound); 59 } 60 let body = response 61 .text() 62 .await 63 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?; 64 let did = body.trim(); 65 if did.starts_with("did:") { 66 Ok(did.to_string()) 67 } else { 68 Err(HandleResolutionError::InvalidDid) 69 } 70} 71 72pub async fn resolve_handle(handle: &str) -> Result<String, HandleResolutionError> { 73 match resolve_handle_dns(handle).await { 74 Ok(did) => return Ok(did), 75 Err(e) => { 76 tracing::debug!("DNS resolution failed for {}: {}, trying HTTP", handle, e); 77 } 78 } 79 resolve_handle_http(handle).await 80} 81 82pub async fn verify_handle_ownership( 83 handle: &str, 84 expected_did: &str, 85) -> Result<(), HandleResolutionError> { 86 let resolved_did = resolve_handle(handle).await?; 87 if resolved_did == expected_did { 88 Ok(()) 89 } else { 90 Err(HandleResolutionError::DidMismatch { 91 expected: expected_did.to_string(), 92 actual: resolved_did, 93 }) 94 } 95} 96 97pub fn is_service_domain_handle(handle: &str, hostname: &str) -> bool { 98 if !handle.contains('.') { 99 return true; 100 } 101 let service_domains: Vec<String> = std::env::var("PDS_SERVICE_HANDLE_DOMAINS") 102 .map(|s| s.split(',').map(|d| d.trim().to_string()).collect()) 103 .unwrap_or_else(|_| vec![hostname.to_string()]); 104 for domain in service_domains { 105 if handle.ends_with(&format!(".{}", domain)) { 106 return true; 107 } 108 if handle == domain { 109 return true; 110 } 111 } 112 false 113} 114 115#[cfg(test)] 116mod tests { 117 use super::*; 118 119 #[test] 120 fn test_is_service_domain_handle() { 121 assert!(is_service_domain_handle("user.example.com", "example.com")); 122 assert!(is_service_domain_handle("example.com", "example.com")); 123 assert!(is_service_domain_handle("myhandle", "example.com")); 124 assert!(!is_service_domain_handle("user.other.com", "example.com")); 125 assert!(!is_service_domain_handle("myhandle.xyz", "example.com")); 126 } 127}