A third-party ATProto AppView
at main 472 lines 19 kB view raw
#!/usr/bin/env python3
"""
DID Resolution Service for AT Protocol

Resolves handles to DIDs (DNS TXT record first, then the HTTPS well-known
endpoint), resolves DIDs (did:plc, did:web) to DID documents, and extracts the
PDS endpoint and handle from those documents. Fetched documents are rejected
unless their ``id`` matches the requested DID.

Python port of server/services/did-resolver.ts
"""

import asyncio
import logging
import re
import time
from collections import OrderedDict
from typing import Optional, Dict, Any, Tuple
from urllib.parse import unquote

import aiohttp
import aiodns

logger = logging.getLogger(__name__)

# Handle syntax from the AT Protocol handle specification: dot-separated DNS
# labels (1-63 chars of alphanumerics with inner hyphens), at least two labels,
# and a final label that starts with a letter. Used with fullmatch().
_HANDLE_RE = re.compile(
    r"([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+"
    r"[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?"
)
_MAX_HANDLE_LENGTH = 253

# DID methods this resolver knows how to dereference.
_SUPPORTED_DID_PREFIXES = ('did:plc:', 'did:web:')

# Characters that would let a did:web identifier smuggle a different host,
# userinfo, path or query into the URL built from it.
_UNSAFE_HOST_CHARS = frozenset('/?#@\\ \t\r\n')

# 4xx statuses that are transient and therefore still worth retrying.
_RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})


def _now_ms() -> float:
    """Return a monotonic timestamp in milliseconds (immune to wall-clock jumps)."""
    return time.monotonic() * 1000


def _is_valid_handle(handle: Any) -> bool:
    """Return True if *handle* is a syntactically valid AT Protocol handle."""
    return (
        isinstance(handle, str)
        and 0 < len(handle) <= _MAX_HANDLE_LENGTH
        and _HANDLE_RE.fullmatch(handle) is not None
    )


class _NonRetryableError(Exception):
    """A failure that retrying cannot fix (client error, invalid DID document)."""


def _http_status_error(status: int) -> Exception:
    """Build the exception to raise for an unexpected HTTP *status*.

    4xx responses other than 408/429 are permanent, so they are returned as
    _NonRetryableError to skip the backoff loop; anything else is retryable.
    The message is ``"HTTP <status>"`` in both cases.
    """
    message = f"HTTP {status}"
    if 400 <= status < 500 and status not in _RETRYABLE_CLIENT_STATUSES:
        return _NonRetryableError(message)
    return Exception(message)


class LRUCache:
    """Size-bounded LRU cache whose entries expire ``ttl_ms`` after insertion."""

    def __init__(self, max_size: int, ttl_ms: int):
        # key -> (value, inserted_at_ms); insertion order doubles as LRU order.
        self.cache: OrderedDict = OrderedDict()
        self.max_size = max_size
        self.ttl_ms = ttl_ms

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if absent or expired.

        A hit marks the entry as most recently used; an expired entry is
        removed on access.
        """
        entry = self.cache.get(key)
        if entry is None:
            return None

        value, timestamp = entry
        if _now_ms() - timestamp > self.ttl_ms:
            del self.cache[key]
            return None

        self.cache.move_to_end(key)
        return value

    def set(self, key: str, value: Any):
        """Insert or replace *key*, evicting least-recently-used entries when full.

        A non-positive ``max_size`` disables caching instead of raising.
        """
        if self.max_size <= 0:
            return

        # Re-inserting moves an existing key to the most-recently-used end.
        self.cache.pop(key, None)
        while len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)

        self.cache[key] = (value, _now_ms())

    def clear(self):
        """Remove all cache entries."""
        self.cache.clear()

    def size(self) -> int:
        """Return the number of stored entries (expired ones included until accessed)."""
        return len(self.cache)


class RequestQueue:
    """Runs awaitable operations with at most ``max_concurrent`` in flight."""

    def __init__(self, max_concurrent: int):
        self.max_concurrent = max_concurrent
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.active_count = 0
        self.completed_count = 0
        self.failed_count = 0

    async def enqueue(self, operation):
        """Await ``operation()`` once a concurrency slot is free and return its result.

        Exceptions propagate to the caller after being counted in ``failed_count``.
        """
        async with self.semaphore:
            self.active_count += 1
            try:
                result = await operation()
                self.completed_count += 1
                return result
            except Exception:
                self.failed_count += 1
                raise
            finally:
                self.active_count -= 1


class DIDResolver:
    """DID and handle resolution service for AT Protocol."""

    def __init__(self):
        self.plc_directory = "https://plc.directory"
        self.max_retries = 3
        self.base_timeout = 15  # seconds
        self.retry_delay = 1.0  # seconds
        self.circuit_breaker_threshold = 5
        self.circuit_breaker_timeout = 60000  # ms
        self.failure_count = 0
        self.last_failure_time = 0  # ms, monotonic clock
        self.circuit_open = False
        self.resolution_count = 0
        self.batch_log_size = 5000

        # Caching
        self.did_document_cache = LRUCache(100000, 24 * 60 * 60 * 1000)  # 24 hour TTL
        self.handle_cache = LRUCache(100000, 24 * 60 * 60 * 1000)
        self.cache_hits = 0
        self.cache_misses = 0

        # Request queue for rate limiting PLC directory calls
        self.request_queue = RequestQueue(15)  # Max 15 concurrent requests

        # HTTP session, created lazily (or by initialize()).
        self.session: Optional[aiohttp.ClientSession] = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it if missing or closed.

        Lets the resolver work even when callers never invoked initialize().
        Must be called from within a running event loop.
        """
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
        return self.session

    async def initialize(self):
        """Create the shared HTTP session if it does not already exist."""
        self._ensure_session()

    async def close(self):
        """Close the shared HTTP session, if any."""
        if self.session:
            await self.session.close()
            self.session = None

    def is_circuit_open(self) -> bool:
        """Return True while the PLC circuit breaker is open.

        Once ``circuit_breaker_timeout`` ms have passed since the last failure,
        the breaker closes and the failure count is reset.
        """
        if not self.circuit_open:
            return False

        if _now_ms() - self.last_failure_time > self.circuit_breaker_timeout:
            self.circuit_open = False
            self.failure_count = 0
            logger.info("[DID_RESOLVER] Circuit breaker reset")
            return False

        return True

    def record_success(self):
        """Record a successful PLC operation: reset failures and close the breaker."""
        self.failure_count = 0
        self.circuit_open = False

    def record_failure(self):
        """Record a failed PLC operation; open the breaker at the threshold."""
        self.failure_count += 1
        self.last_failure_time = _now_ms()

        if self.failure_count >= self.circuit_breaker_threshold:
            self.circuit_open = True
            logger.warning(f"[DID_RESOLVER] Circuit breaker opened after {self.failure_count} failures")

    async def retry_with_backoff(self, operation, max_retries: Optional[int] = None,
                                 base_delay: Optional[float] = None):
        """Await ``operation()``, retrying failures with exponential backoff.

        Args:
            operation: zero-argument callable returning an awaitable.
            max_retries: retries after the first attempt (None -> self.max_retries;
                0 means a single attempt).
            base_delay: first delay in seconds, doubled each retry
                (None -> self.retry_delay).

        Returns:
            The result of the first successful attempt.

        Raises:
            The last exception once retries are exhausted. _NonRetryableError
            is re-raised immediately without retrying.
        """
        if max_retries is None:
            max_retries = self.max_retries
        if base_delay is None:
            base_delay = self.retry_delay
        max_retries = max(0, max_retries)

        for attempt in range(max_retries + 1):
            try:
                return await operation()
            except _NonRetryableError:
                raise
            except Exception as e:
                if attempt >= max_retries:
                    raise

                delay = base_delay * (2 ** attempt)
                logger.warning(f"[DID_RESOLVER] Attempt {attempt + 1} failed, retrying in {delay}s: {str(e)}")
                await asyncio.sleep(delay)

    async def resolve_handle_via_dns(self, handle: str) -> Optional[str]:
        """Resolve *handle* to a DID via the ``_atproto.<handle>`` TXT record.

        Per the AT Protocol handle spec the record value is ``did=<DID>``; a bare
        ``did:...`` value is also accepted for backward compatibility.

        Returns:
            The DID, or None if the handle is malformed, no usable record exists,
            the records name more than one distinct DID (ambiguous), or the
            lookup fails.
        """
        if not _is_valid_handle(handle):
            return None

        try:
            resolver = aiodns.DNSResolver()
            txt_records = await resolver.query(f"_atproto.{handle}", 'TXT')

            dids = []
            for record in txt_records:
                text = record.text
                # Depending on the pycares version, TXT payloads may arrive as bytes.
                if isinstance(text, bytes):
                    text = text.decode('utf-8', errors='replace')
                value = text.strip()
                if value.startswith('did='):
                    value = value[len('did='):].strip()
                if value.startswith('did:') and value not in dids:
                    dids.append(value)

            if not dids:
                return None

            if len(dids) > 1:
                logger.warning(f"[DID_RESOLVER] Multiple DIDs in DNS for {handle}, refusing ambiguous result: {dids}")
                return None

            did = dids[0]
            if not did.startswith(_SUPPORTED_DID_PREFIXES):
                logger.warning(f"[DID_RESOLVER] Unsupported DID method in DNS for {handle}: {did}")
            return did
        except Exception as e:
            # DNS errors are common (most handles have no TXT record): log at
            # debug level only, and not at all for NXDOMAIN/NODATA.
            if 'NXDOMAIN' not in str(e) and 'NODATA' not in str(e):
                logger.debug(f"[DID_RESOLVER] DNS error for {handle}: {str(e)}")
            return None

    async def resolve_handle_via_https(self, handle: str) -> Optional[str]:
        """Resolve *handle* to a DID via ``https://<handle>/.well-known/atproto-did``.

        Returns:
            The DID from the plain-text response body, or None if the handle is
            malformed, the request fails or times out, the status is not 200,
            or the body does not look like a DID.
        """
        if not _is_valid_handle(handle):
            return None

        try:
            session = self._ensure_session()
            url = f"https://{handle}/.well-known/atproto-did"
            timeout = aiohttp.ClientTimeout(total=self.base_timeout)

            async with session.get(url, headers={'Accept': 'text/plain'}, timeout=timeout) as response:
                if response.status == 404:
                    return None

                if response.status != 200:
                    logger.warning(f"[DID_RESOLVER] HTTP {response.status} for {handle}/.well-known/atproto-did")
                    return None

                did = (await response.text()).strip()

            # Hosts without the endpoint often serve an HTML or JSON page instead.
            if did.startswith(('<', '{')):
                return None

            if not did.startswith('did:'):
                return None

            if not did.startswith(_SUPPORTED_DID_PREFIXES):
                logger.warning(f"[DID_RESOLVER] Unsupported DID method for {handle}: {did}")

            return did
        except asyncio.TimeoutError:
            logger.warning(f"[DID_RESOLVER] Timeout resolving {handle}/.well-known/atproto-did")
            return None
        except Exception as e:
            logger.debug(f"[DID_RESOLVER] HTTPS error for {handle}: {str(e)}")
            return None

    async def resolve_handle(self, handle: str) -> Optional[str]:
        """Resolve *handle* to a DID: DNS TXT first, then the HTTPS well-known fallback.

        The handle is stripped and lower-cased (handles are case-insensitive)
        before lookup. Returns None if it is malformed or neither method yields
        a DID.
        """
        try:
            normalized = handle.strip().lower()
            if not _is_valid_handle(normalized):
                logger.debug(f"[DID_RESOLVER] Invalid handle syntax: {handle}")
                return None

            did = await self.resolve_handle_via_dns(normalized)
            if did:
                return did

            return await self.resolve_handle_via_https(normalized)
        except Exception as e:
            logger.error(f"[DID_RESOLVER] Error resolving handle {handle}: {str(e)}")
            return None

    @staticmethod
    def _validate_did_document(data: Any, did: str, mismatch_context: str) -> Dict[str, Any]:
        """Return *data* if it is a DID document whose ``id`` equals *did*.

        Args:
            data: decoded JSON body.
            did: the DID that was requested.
            mismatch_context: text inserted into the SECURITY log line, e.g.
                ``"from PLC for"`` or ``"for did:web"``.

        Raises:
            _NonRetryableError: if *data* is not a non-empty dict, lacks ``id``,
                or names a different DID. Refetching would not change the answer.
        """
        if not data or not isinstance(data, dict):
            raise _NonRetryableError("Invalid DID document: not an object")

        if not data.get('id'):
            raise _NonRetryableError("Invalid DID document: missing id")

        # Security: a document for a different DID must never be accepted.
        if data['id'] != did:
            logger.error(f"[DID_RESOLVER] SECURITY: DID mismatch {mismatch_context} {did}: document contains {data['id']}")
            raise _NonRetryableError(f"DID mismatch: expected {did}, got {data['id']}")

        return data

    async def resolve_plc_did(self, did: str) -> Optional[Dict[str, Any]]:
        """Fetch the DID document for a did:plc identifier from the PLC directory.

        Requests go through the concurrency-limited request queue, with retries
        and a circuit breaker. Only retryable failures (network, 5xx, 408/429)
        count toward the breaker. Permanent answers (other 4xx, invalid or
        mismatched documents) come from a healthy directory, so they leave the
        breaker untouched.

        Returns:
            The validated DID document, or None if the breaker is open, the DID
            is unknown (404) or gone (410), or resolution fails.
        """
        if self.is_circuit_open():
            logger.warning(f"[DID_RESOLVER] Circuit breaker open, skipping PLC DID resolution for {did}")
            return None

        url = f"{self.plc_directory}/{did}"

        async def fetch_once() -> Optional[Dict[str, Any]]:
            session = self._ensure_session()
            timeout = aiohttp.ClientTimeout(total=self.base_timeout)

            async with session.get(url, headers={'Accept': 'application/did+ld+json, application/json'}, timeout=timeout) as response:
                if response.status == 404:
                    logger.warning(f"[DID_RESOLVER] DID not found in PLC: {did}")
                    return None

                if response.status == 410:
                    logger.warning(f"[DID_RESOLVER] DID no longer available in PLC (HTTP 410): {did}")
                    return None

                if response.status >= 500:
                    raise Exception(f"PLC directory server error {response.status}")

                if response.status != 200:
                    raise _http_status_error(response.status)

                # content_type=None: the directory answers with
                # application/did+ld+json, so skip aiohttp's JSON content-type check.
                data = await response.json(content_type=None)

            return self._validate_did_document(data, did, "from PLC for")

        async def queued_fetch() -> Optional[Dict[str, Any]]:
            # Queue each attempt separately so backoff sleeps do not hold a slot.
            return await self.request_queue.enqueue(fetch_once)

        try:
            result = await self.retry_with_backoff(queued_fetch)
        except _NonRetryableError as e:
            logger.error(f"[DID_RESOLVER] Error resolving PLC DID {did}: {str(e)}")
            return None
        except Exception as e:
            self.record_failure()
            logger.error(f"[DID_RESOLVER] Error resolving PLC DID {did}: {str(e)}")
            return None

        self.record_success()
        return result

    @staticmethod
    def _web_did_url(did: str) -> str:
        """Build the did.json URL for a did:web identifier (W3C did:web method).

        ``did:web:example.com`` maps to ``https://example.com/.well-known/did.json``.
        ``did:web:example.com:u:alice`` maps to ``https://example.com/u/alice/did.json``.
        A percent-encoded port (``example.com%3A8443``) is decoded.

        Raises:
            ValueError: if *did* is not did:web, or its host part is empty or
                contains characters that could redirect the request elsewhere.
        """
        prefix = 'did:web:'
        if not did.startswith(prefix):
            raise ValueError(f"Not a did:web identifier: {did}")

        did_parts = did[len(prefix):].split(':')
        domain = unquote(did_parts[0])
        if not domain or any(ch in _UNSAFE_HOST_CHARS for ch in domain):
            raise ValueError(f"Invalid did:web host: {did_parts[0]}")

        if len(did_parts) > 1:
            return f"https://{domain}/{'/'.join(did_parts[1:])}/did.json"
        return f"https://{domain}/.well-known/did.json"

    async def resolve_web_did(self, did: str) -> Optional[Dict[str, Any]]:
        """Fetch and validate the DID document for a did:web identifier.

        Returns:
            The validated DID document, or None if the identifier is malformed,
            the document is not found (404), or resolution fails after retries.
        """
        try:
            url = self._web_did_url(did)
        except ValueError as e:
            logger.error(f"[DID_RESOLVER] Error resolving Web DID {did}: {str(e)}")
            return None

        logger.info(f"[DID_RESOLVER] Resolving Web DID from: {url}")

        async def fetch_once() -> Optional[Dict[str, Any]]:
            session = self._ensure_session()
            timeout = aiohttp.ClientTimeout(total=self.base_timeout)

            async with session.get(url, headers={'Accept': 'application/did+ld+json, application/json'}, timeout=timeout) as response:
                if response.status == 404:
                    logger.warning(f"[DID_RESOLVER] Web DID not found: {did} at {url}")
                    return None

                if response.status != 200:
                    raise _http_status_error(response.status)

                # Static hosts may serve did.json with a non-JSON content type.
                data = await response.json(content_type=None)

            return self._validate_did_document(data, did, "for did:web")

        try:
            return await self.retry_with_backoff(fetch_once)
        except Exception as e:
            logger.error(f"[DID_RESOLVER] Error resolving Web DID {did}: {str(e)}")
            return None

    async def resolve_did(self, did: str) -> Optional[Dict[str, Any]]:
        """Resolve *did* to its DID document, using the 24-hour document cache.

        Supports did:plc and did:web. Returns None for other methods or on
        failure. Only successful resolutions are cached.
        """
        cached = self.did_document_cache.get(did)
        if cached:
            self.cache_hits += 1
            return cached

        self.cache_misses += 1

        try:
            if did.startswith('did:plc:'):
                did_doc = await self.resolve_plc_did(did)
            elif did.startswith('did:web:'):
                did_doc = await self.resolve_web_did(did)
            else:
                logger.error(f"[DID_RESOLVER] Unsupported DID method: {did}")
                return None

            if did_doc:
                self.did_document_cache.set(did, did_doc)

            return did_doc
        except Exception as e:
            logger.error(f"[DID_RESOLVER] Error resolving DID {did}: {str(e)}")
            return None

    def get_pds_endpoint(self, did_doc: Dict[str, Any]) -> Optional[str]:
        """Return the HTTP(S) endpoint of the first PDS service in *did_doc*.

        A service entry matches by id (``atproto_pds`` or any id ending in
        ``#atproto_pds``) or by type ``AtprotoPersonalDataServer``. Returns
        None if there is no services list, no match, or the first match has
        a missing, non-string or non-HTTP(S) ``serviceEndpoint``.
        """
        services = did_doc.get('service')
        if not services or not isinstance(services, list):
            logger.warning(f"[DID_RESOLVER] No services array in DID document for {did_doc.get('id')}")
            return None

        for service in services:
            if not isinstance(service, dict):
                continue

            service_id = service.get('id')
            id_matches = isinstance(service_id, str) and (
                service_id == 'atproto_pds' or service_id.endswith('#atproto_pds')
            )
            type_matches = service.get('type') in ('AtprotoPersonalDataServer', 'AtProtoPersonalDataServer')
            if not (id_matches or type_matches):
                continue

            endpoint = service.get('serviceEndpoint')

            if not endpoint or not isinstance(endpoint, str):
                logger.warning(f"[DID_RESOLVER] Invalid PDS endpoint format for {did_doc.get('id')}")
                return None

            if not endpoint.startswith(('https://', 'http://')):
                logger.warning(f"[DID_RESOLVER] PDS endpoint must be HTTP(S) URL: {endpoint}")
                return None

            return endpoint

        logger.warning(f"[DID_RESOLVER] No PDS service found in DID document for {did_doc.get('id')}")
        return None

    def get_handle_from_did_document(self, did_doc: Dict[str, Any]) -> Optional[str]:
        """Return the first ``at://`` handle in *did_doc*'s ``alsoKnownAs``, if any.

        Only entries whose remainder contains a dot are accepted. The handle is
        NOT verified to resolve back to this DID.
        """
        also_known_as = did_doc.get('alsoKnownAs')
        if not also_known_as or not isinstance(also_known_as, list):
            return None

        prefix = 'at://'
        for uri in also_known_as:
            if isinstance(uri, str) and uri.startswith(prefix):
                handle = uri[len(prefix):]
                if handle and '.' in handle:
                    return handle

        return None

    async def resolve_did_to_pds(self, did: str) -> Optional[str]:
        """Resolve *did* to its PDS endpoint URL, or None if unavailable."""
        try:
            did_doc = await self.resolve_did(did)
            if not did_doc:
                return None

            return self.get_pds_endpoint(did_doc)
        except Exception as e:
            logger.error(f"[DID_RESOLVER] Error resolving DID {did} to PDS: {str(e)}")
            return None

    async def resolve_did_to_handle(self, did: str) -> Optional[str]:
        """Resolve *did* to the handle claimed in its DID document (cached 24h).

        NOTE(review): the claimed handle is not verified bidirectionally (i.e.
        that the handle resolves back to *did*); callers needing that guarantee
        must check it themselves.
        """
        cached_handle = self.handle_cache.get(did)
        if cached_handle:
            self.cache_hits += 1
            return cached_handle

        self.cache_misses += 1

        try:
            did_doc = await self.resolve_did(did)
            if not did_doc:
                logger.warning(f"[DID_RESOLVER] Could not resolve DID document for {did}")
                return None

            handle = self.get_handle_from_did_document(did_doc)
            if not handle:
                logger.warning(f"[DID_RESOLVER] No handle found in DID document for {did}")
                return None

            self.handle_cache.set(did, handle)

            # Periodic progress/cache statistics instead of per-resolution logs.
            self.resolution_count += 1
            if self.resolution_count % self.batch_log_size == 0:
                total_requests = self.cache_hits + self.cache_misses
                cache_hit_rate = (self.cache_hits / total_requests * 100) if total_requests > 0 else 0
                logger.info(f"[DID_RESOLVER] Resolved {self.batch_log_size} DIDs (total: {self.resolution_count}, cache hit rate: {cache_hit_rate:.1f}%)")

            return handle
        except Exception as e:
            logger.error(f"[DID_RESOLVER] Error resolving DID {did} to handle: {str(e)}")
            return None

    def clear_caches(self):
        """Clear the document and handle caches and reset hit/miss counters."""
        self.did_document_cache.clear()
        self.handle_cache.clear()
        self.cache_hits = 0
        self.cache_misses = 0
        logger.info("[DID_RESOLVER] Caches cleared")


# Global singleton instance
did_resolver = DIDResolver()