this repo has no description
1pub mod reserved;
2
3use hickory_resolver::TokioAsyncResolver;
4use hickory_resolver::config::{ResolverConfig, ResolverOpts};
5use reqwest::Client;
6use std::time::Duration;
7use thiserror::Error;
8
9#[derive(Error, Debug)]
10pub enum HandleResolutionError {
11 #[error("DNS lookup failed: {0}")]
12 DnsError(String),
13 #[error("HTTP request failed: {0}")]
14 HttpError(String),
15 #[error("No DID found for handle")]
16 NotFound,
17 #[error("Invalid DID format in record")]
18 InvalidDid,
19 #[error("DID mismatch: expected {expected}, got {actual}")]
20 DidMismatch { expected: String, actual: String },
21}
22
23pub async fn resolve_handle_dns(handle: &str) -> Result<String, HandleResolutionError> {
24 let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
25 let query_name = format!("_atproto.{}", handle);
26 let txt_lookup = resolver
27 .txt_lookup(&query_name)
28 .await
29 .map_err(|e| HandleResolutionError::DnsError(e.to_string()))?;
30 for record in txt_lookup.iter() {
31 for txt in record.txt_data() {
32 let txt_str = String::from_utf8_lossy(txt);
33 if let Some(did) = txt_str.strip_prefix("did=") {
34 let did = did.trim();
35 if did.starts_with("did:") {
36 return Ok(did.to_string());
37 }
38 }
39 }
40 }
41 Err(HandleResolutionError::NotFound)
42}
43
44pub async fn resolve_handle_http(handle: &str) -> Result<String, HandleResolutionError> {
45 let url = format!("https://{}/.well-known/atproto-did", handle);
46 let client = Client::builder()
47 .timeout(Duration::from_secs(10))
48 .redirect(reqwest::redirect::Policy::limited(5))
49 .build()
50 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?;
51 let response = client
52 .get(&url)
53 .header("Accept", "text/plain")
54 .send()
55 .await
56 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?;
57 if !response.status().is_success() {
58 return Err(HandleResolutionError::NotFound);
59 }
60 let body = response
61 .text()
62 .await
63 .map_err(|e| HandleResolutionError::HttpError(e.to_string()))?;
64 let did = body.trim();
65 if did.starts_with("did:") {
66 Ok(did.to_string())
67 } else {
68 Err(HandleResolutionError::InvalidDid)
69 }
70}
71
72pub async fn resolve_handle(handle: &str) -> Result<String, HandleResolutionError> {
73 match resolve_handle_dns(handle).await {
74 Ok(did) => return Ok(did),
75 Err(e) => {
76 tracing::debug!("DNS resolution failed for {}: {}, trying HTTP", handle, e);
77 }
78 }
79 resolve_handle_http(handle).await
80}
81
82pub async fn verify_handle_ownership(
83 handle: &str,
84 expected_did: &str,
85) -> Result<(), HandleResolutionError> {
86 let resolved_did = resolve_handle(handle).await?;
87 if resolved_did == expected_did {
88 Ok(())
89 } else {
90 Err(HandleResolutionError::DidMismatch {
91 expected: expected_did.to_string(),
92 actual: resolved_did,
93 })
94 }
95}
96
/// Reports whether `handle` falls under one of this service's handle domains.
///
/// The set of domains is the comma-separated list in the
/// `PDS_SERVICE_HANDLE_DOMAINS` environment variable (entries are trimmed);
/// when that variable is not set, `hostname` is the only domain.
///
/// Returns `true` when:
/// * `handle` contains no `.` at all, or
/// * `handle` equals a domain, or
/// * `handle` ends with `.<domain>` for some domain.
///
/// Empty domains (for example from a trailing comma or an empty variable)
/// are skipped: otherwise the empty domain would turn into the suffix `.` and
/// every handle that merely ends in a dot would be accepted.
pub fn is_service_domain_handle(handle: &str, hostname: &str) -> bool {
    if !handle.contains('.') {
        return true;
    }

    // `handle` is under `domain` when it is the domain itself or when what
    // precedes the domain ends on a label boundary (a dot).
    let is_under = |domain: &str| -> bool {
        if domain.is_empty() {
            return false;
        }
        match handle.strip_suffix(domain) {
            Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
            None => false,
        }
    };

    match std::env::var("PDS_SERVICE_HANDLE_DOMAINS") {
        Ok(configured) => configured.split(',').map(str::trim).any(is_under),
        Err(_) => is_under(hostname),
    }
}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `is_service_domain_handle` against a single hostname:
    /// sub-handles, the bare hostname and dot-less handles are accepted,
    /// handles under unrelated domains are rejected.
    #[test]
    fn test_is_service_domain_handle() {
        const HOST: &str = "example.com";
        let cases = [
            ("user.example.com", true),
            ("example.com", true),
            ("myhandle", true),
            ("user.other.com", false),
            ("myhandle.xyz", false),
        ];
        for (handle, expected) in cases.iter() {
            assert!(
                is_service_domain_handle(handle, HOST) == *expected,
                "unexpected result for handle {:?}",
                handle
            );
        }
    }
}
127}