This repository has no description.
1use hickory_resolver::TokioAsyncResolver;
2use hickory_resolver::config::{ResolverConfig, ResolverOpts};
3use reqwest::Client;
4use std::time::Duration;
5use thiserror::Error;
6
7#[derive(Error, Debug)]
8pub enum HandleResolutionError {
9 #[error("DNS lookup failed: {0}")]
10 DnsError(String),
11 #[error("HTTP request failed: {0}")]
12 HttpError(String),
13 #[error("No DID found for handle")]
14 NotFound,
15 #[error("Invalid DID format in record")]
16 InvalidDid,
17 #[error("DID mismatch: expected {expected}, got {actual}")]
18 DidMismatch { expected: String, actual: String },
19}
20
21pub async fn resolve_handle_dns(handle: &str) -> Result<String, HandleResolutionError> {
22 let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
23 let query_name = format!("_atproto.{}", handle);
24 let txt_lookup = resolver
25 .txt_lookup(&query_name)
26 .await
27 .map_err(|e| HandleResolutionError::DnsError(e.to_string()))?;
28 for record in txt_lookup.iter() {
29 for txt in record.txt_data() {
30 let txt_str = String::from_utf8_lossy(txt);
31 if let Some(did) = txt_str.strip_prefix("did=") {
32 let did = did.trim();
33 if did.starts_with("did:") {
34 return Ok(did.to_string());
35 }
36 }
37 }
38 }
39 Err(HandleResolutionError::NotFound)
40}
41
42pub async fn resolve_handle_http(handle: &str) -> Result<String, HandleResolutionError> {
43 let url = format!("https://{}/.well-known/atproto-did", handle);
44 let client = Client::builder()
45 .timeout(Duration::from_secs(10))
46 .redirect(reqwest::redirect::Policy::limited(5))
47 .build()
48 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?;
49 let response = client
50 .get(&url)
51 .header("Accept", "text/plain")
52 .send()
53 .await
54 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?;
55 if !response.status().is_success() {
56 return Err(HandleResolutionError::NotFound);
57 }
58 let body = response
59 .text()
60 .await
61 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?;
62 let did = body.trim();
63 if did.starts_with("did:") {
64 Ok(did.to_string())
65 } else {
66 Err(HandleResolutionError::InvalidDid)
67 }
68}
69
70pub async fn resolve_handle(handle: &str) -> Result<String, HandleResolutionError> {
71 match resolve_handle_dns(handle).await {
72 Ok(did) => return Ok(did),
73 Err(e) => {
74 tracing::debug!("DNS resolution failed for {}: {}, trying HTTP", handle, e);
75 }
76 }
77 resolve_handle_http(handle).await
78}
79
80pub async fn verify_handle_ownership(
81 handle: &str,
82 expected_did: &str,
83) -> Result<(), HandleResolutionError> {
84 let resolved_did = resolve_handle(handle).await?;
85 if resolved_did == expected_did {
86 Ok(())
87 } else {
88 Err(HandleResolutionError::DidMismatch {
89 expected: expected_did.to_string(),
90 actual: resolved_did,
91 })
92 }
93}
94
/// Returns `true` if `handle` belongs to one of this server's service handle domains.
///
/// Where the domains come from:
/// - The comma-separated `PDS_SERVICE_HANDLE_DOMAINS` environment variable, with each
///   entry trimmed.
/// - `hostname` alone, when that variable is unset or not valid Unicode.
///
/// How a handle is matched against a domain:
/// - It matches when it equals the domain or is a subdomain of it (`<labels>.<domain>`).
/// - The comparison ignores ASCII case, because DNS names are case-insensitive.
/// - Leading dots on a domain are ignored, so `.example.com` and `example.com` are
///   equivalent.
/// - Empty domains, such as a trailing comma, never match.
pub fn is_service_domain_handle(handle: &str, hostname: &str) -> bool {
    match std::env::var("PDS_SERVICE_HANDLE_DOMAINS") {
        Ok(list) => list
            .split(',')
            .any(|domain| handle_matches_domain(handle, domain.trim())),
        Err(_) => handle_matches_domain(handle, hostname),
    }
}

/// Returns `true` if `handle` equals `domain` or ends with `.<domain>`.
///
/// Leading dots on `domain` are ignored and ASCII case is ignored. An empty `domain`
/// never matches.
fn handle_matches_domain(handle: &str, domain: &str) -> bool {
    let domain = domain.trim_start_matches('.');
    if domain.is_empty() {
        return false;
    }
    // Compare raw bytes rather than slicing `handle`: a multi-byte character near the
    // split point would otherwise cause a char-boundary panic.
    let (h, d) = (handle.as_bytes(), domain.as_bytes());
    match h.len().checked_sub(d.len()) {
        Some(0) => h.eq_ignore_ascii_case(d),
        Some(split) => h[split - 1] == b'.' && h[split..].eq_ignore_ascii_case(d),
        None => false,
    }
}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_service_domain_handle() {
        // (handle, hostname, expected). These cases assume PDS_SERVICE_HANDLE_DOMAINS is
        // unset, so `hostname` is the only service domain consulted.
        let cases = [
            ("user.example.com", "example.com", true),
            ("example.com", "example.com", true),
            ("user.other.com", "example.com", false),
            ("myhandle.xyz", "example.com", false),
        ];
        for &(handle, hostname, expected) in &cases {
            assert_eq!(
                is_service_domain_handle(handle, hostname),
                expected,
                "handle = {}, hostname = {}",
                handle,
                hostname
            );
        }
    }
}