this repo has no description
1use reqwest::Client;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7use tracing::{debug, error, info, warn};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct DidDocument {
11 pub id: String,
12 #[serde(default)]
13 pub service: Vec<DidService>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(rename_all = "camelCase")]
18pub struct DidService {
19 pub id: String,
20 #[serde(rename = "type")]
21 pub service_type: String,
22 pub service_endpoint: String,
23}
24
/// A cached resolution result, stamped with when it was produced so the
/// resolver can compare its age against the configured TTL.
#[derive(Clone)]
struct CachedDid {
    /// Resolved service endpoint URL.
    url: String,
    /// DID reported by the resolved document.
    did: String,
    /// Moment the entry was stored.
    resolved_at: Instant,
}
31
/// Outcome of a successful DID resolution: where to send requests and which
/// DID the fetched document identified itself as.
#[derive(Debug, Clone)]
pub struct ResolvedService {
    /// Service endpoint URL.
    pub url: String,
    /// The `id` of the resolved DID document.
    pub did: String,
}
37
38pub struct DidResolver {
39 did_cache: RwLock<HashMap<String, CachedDid>>,
40 client: Client,
41 cache_ttl: Duration,
42 plc_directory_url: String,
43}
44
45impl Clone for DidResolver {
46 fn clone(&self) -> Self {
47 Self {
48 did_cache: RwLock::new(HashMap::new()),
49 client: self.client.clone(),
50 cache_ttl: self.cache_ttl,
51 plc_directory_url: self.plc_directory_url.clone(),
52 }
53 }
54}
55
56impl DidResolver {
57 pub fn new() -> Self {
58 let cache_ttl_secs: u64 = std::env::var("DID_CACHE_TTL_SECS")
59 .ok()
60 .and_then(|v| v.parse().ok())
61 .unwrap_or(300);
62
63 let plc_directory_url = std::env::var("PLC_DIRECTORY_URL")
64 .unwrap_or_else(|_| "https://plc.directory".to_string());
65
66 let client = Client::builder()
67 .timeout(Duration::from_secs(10))
68 .connect_timeout(Duration::from_secs(5))
69 .pool_max_idle_per_host(10)
70 .build()
71 .unwrap_or_else(|_| Client::new());
72
73 info!("DID resolver initialized");
74
75 Self {
76 did_cache: RwLock::new(HashMap::new()),
77 client,
78 cache_ttl: Duration::from_secs(cache_ttl_secs),
79 plc_directory_url,
80 }
81 }
82
83 pub async fn resolve_did(&self, did: &str) -> Option<ResolvedService> {
84 {
85 let cache = self.did_cache.read().await;
86 if let Some(cached) = cache.get(did)
87 && cached.resolved_at.elapsed() < self.cache_ttl
88 {
89 return Some(ResolvedService {
90 url: cached.url.clone(),
91 did: cached.did.clone(),
92 });
93 }
94 }
95
96 let resolved = self.resolve_did_internal(did).await?;
97
98 {
99 let mut cache = self.did_cache.write().await;
100 cache.insert(
101 did.to_string(),
102 CachedDid {
103 url: resolved.url.clone(),
104 did: resolved.did.clone(),
105 resolved_at: Instant::now(),
106 },
107 );
108 }
109
110 Some(resolved)
111 }
112
113 async fn resolve_did_internal(&self, did: &str) -> Option<ResolvedService> {
114 let did_doc = if did.starts_with("did:web:") {
115 self.resolve_did_web(did).await
116 } else if did.starts_with("did:plc:") {
117 self.resolve_did_plc(did).await
118 } else {
119 warn!("Unsupported DID method: {}", did);
120 return None;
121 };
122
123 let doc = match did_doc {
124 Ok(doc) => doc,
125 Err(e) => {
126 error!("Failed to resolve DID {}: {}", did, e);
127 return None;
128 }
129 };
130
131 self.extract_service_endpoint(&doc)
132 }
133
134 async fn resolve_did_web(&self, did: &str) -> Result<DidDocument, String> {
135 let host = did
136 .strip_prefix("did:web:")
137 .ok_or("Invalid did:web format")?;
138
139 let (host, path) = if host.contains(':') {
140 let decoded = host.replace("%3A", ":");
141 let parts: Vec<&str> = decoded.splitn(2, '/').collect();
142 if parts.len() > 1 {
143 (parts[0].to_string(), format!("/{}", parts[1]))
144 } else {
145 (decoded, String::new())
146 }
147 } else {
148 let parts: Vec<&str> = host.splitn(2, ':').collect();
149 if parts.len() > 1 && parts[1].contains('/') {
150 let path_parts: Vec<&str> = parts[1].splitn(2, '/').collect();
151 if path_parts.len() > 1 {
152 (
153 format!("{}:{}", parts[0], path_parts[0]),
154 format!("/{}", path_parts[1]),
155 )
156 } else {
157 (host.to_string(), String::new())
158 }
159 } else {
160 (host.to_string(), String::new())
161 }
162 };
163
164 let scheme =
165 if host.starts_with("localhost") || host.starts_with("127.0.0.1") || host.contains(':')
166 {
167 "http"
168 } else {
169 "https"
170 };
171
172 let url = if path.is_empty() {
173 format!("{}://{}/.well-known/did.json", scheme, host)
174 } else {
175 format!("{}://{}{}/did.json", scheme, host, path)
176 };
177
178 debug!("Resolving did:web {} via {}", did, url);
179
180 let resp = self
181 .client
182 .get(&url)
183 .send()
184 .await
185 .map_err(|e| format!("HTTP request failed: {}", e))?;
186
187 if !resp.status().is_success() {
188 return Err(format!("HTTP {}", resp.status()));
189 }
190
191 resp.json::<DidDocument>()
192 .await
193 .map_err(|e| format!("Failed to parse DID document: {}", e))
194 }
195
196 async fn resolve_did_plc(&self, did: &str) -> Result<DidDocument, String> {
197 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did));
198
199 debug!("Resolving did:plc {} via {}", did, url);
200
201 let resp = self
202 .client
203 .get(&url)
204 .send()
205 .await
206 .map_err(|e| format!("HTTP request failed: {}", e))?;
207
208 if resp.status() == reqwest::StatusCode::NOT_FOUND {
209 return Err("DID not found".to_string());
210 }
211
212 if !resp.status().is_success() {
213 return Err(format!("HTTP {}", resp.status()));
214 }
215
216 resp.json::<DidDocument>()
217 .await
218 .map_err(|e| format!("Failed to parse DID document: {}", e))
219 }
220
221 fn extract_service_endpoint(&self, doc: &DidDocument) -> Option<ResolvedService> {
222 for service in &doc.service {
223 if service.service_type == "AtprotoAppView"
224 || service.id.contains("atproto_appview")
225 || service.id.ends_with("#bsky_appview")
226 {
227 return Some(ResolvedService {
228 url: service.service_endpoint.clone(),
229 did: doc.id.clone(),
230 });
231 }
232 }
233
234 for service in &doc.service {
235 if service.service_type.contains("AppView") || service.id.contains("appview") {
236 return Some(ResolvedService {
237 url: service.service_endpoint.clone(),
238 did: doc.id.clone(),
239 });
240 }
241 }
242
243 if let Some(service) = doc.service.first()
244 && service.service_endpoint.starts_with("http")
245 {
246 warn!(
247 "No explicit AppView service found for {}, using first service: {}",
248 doc.id, service.service_endpoint
249 );
250 return Some(ResolvedService {
251 url: service.service_endpoint.clone(),
252 did: doc.id.clone(),
253 });
254 }
255
256 if doc.id.starts_with("did:web:") {
257 let host = doc.id.strip_prefix("did:web:")?;
258 let decoded_host = host.replace("%3A", ":");
259 let base_host = decoded_host.split('/').next()?;
260 let scheme = if base_host.starts_with("localhost")
261 || base_host.starts_with("127.0.0.1")
262 || base_host.contains(':')
263 {
264 "http"
265 } else {
266 "https"
267 };
268 warn!(
269 "No service found for {}, deriving URL from DID: {}://{}",
270 doc.id, scheme, base_host
271 );
272 return Some(ResolvedService {
273 url: format!("{}://{}", scheme, base_host),
274 did: doc.id.clone(),
275 });
276 }
277
278 None
279 }
280
281 pub async fn invalidate_cache(&self, did: &str) {
282 let mut cache = self.did_cache.write().await;
283 cache.remove(did);
284 }
285}
286
287impl Default for DidResolver {
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
293pub fn create_did_resolver() -> Arc<DidResolver> {
294 Arc::new(DidResolver::new())
295}