A trust and safety agent that interacts with Osprey for investigation, real-time analysis, and prevention implementations

add a domain tool

+169 -15
+1
pyproject.toml
··· 10 10 "atproto>=0.0.65", 11 11 "click>=8.3.1", 12 12 "clickhouse-connect>=0.10.0", 13 + "dnspython>=2.8.0", 13 14 "pydantic>=2.12.5", 14 15 "pydantic-settings>=2.12.0", 15 16 ]
+12 -7
src/agent/agent.py
··· 1 - from abc import ABC, abstractmethod 2 1 import asyncio 3 2 import json 4 3 import logging 4 + from abc import ABC, abstractmethod 5 5 from dataclasses import dataclass 6 6 from typing import Any, Literal 7 7 8 8 import anthropic 9 - from anthropic.types import TextBlock, ToolUseBlock 10 9 import httpx 10 + from anthropic.types import TextBlock, ToolUseBlock 11 11 12 12 from src.agent.prompt import build_system_prompt 13 13 from src.tools.executor import ToolExecutor ··· 134 134 json=payload, 135 135 ) 136 136 if not resp.is_success: 137 - logger.error( 138 - "API error %d: %s", resp.status_code, resp.text[:1000] 139 - ) 137 + logger.error("API error %d: %s", resp.status_code, resp.text[:1000]) 140 138 resp.raise_for_status() 141 139 data = resp.json() 142 140 ··· 240 238 241 239 stop_reason = "tool_use" if finish_reason == "tool_calls" else "end_turn" 242 240 reasoning_content = message.get("reasoning_content") 243 - return AgentResponse(content=content, stop_reason=stop_reason, reasoning_content=reasoning_content) 241 + return AgentResponse( 242 + content=content, 243 + stop_reason=stop_reason, 244 + reasoning_content=reasoning_content, 245 + ) 244 246 245 247 246 248 MAX_TOOL_RESULT_LENGTH = 10_000 ··· 323 325 } 324 326 ) 325 327 326 - assistant_msg: dict[str, Any] = {"role": "assistant", "content": assistant_content} 328 + assistant_msg: dict[str, Any] = { 329 + "role": "assistant", 330 + "content": assistant_content, 331 + } 327 332 if resp.reasoning_content: 328 333 assistant_msg["reasoning_content"] = resp.reasoning_content 329 334 self._conversation.append(assistant_msg)
+1
src/clickhouse/clickhouse.py
··· 1 1 from typing import Any 2 + 2 3 from clickhouse_connect import get_async_client # type: ignore 3 4 from clickhouse_connect.driver.asyncclient import AsyncClient # type: ignore 4 5
+6 -6
src/tools/__init__.py
··· 1 + # Import tool definitions so they register themselves with TOOL_REGISTRY 2 + import src.tools.definitions.clickhouse # noqa: F401 3 + import src.tools.definitions.domain 4 + import src.tools.definitions.osprey # noqa: F401 5 + import src.tools.definitions.ozone # noqa: F401 1 6 from src.tools.executor import ToolExecutor 2 7 from src.tools.registry import ( 8 + TOOL_REGISTRY, 3 9 Tool, 4 10 ToolContext, 5 11 ToolParameter, 6 12 ToolRegistry, 7 - TOOL_REGISTRY, 8 13 ) 9 - 10 - # Import tool definitions so they register themselves with TOOL_REGISTRY 11 - import src.tools.definitions.clickhouse # noqa: F401 12 - import src.tools.definitions.osprey # noqa: F401 13 - import src.tools.definitions.ozone # noqa: F401 14 14 15 15 __all__ = [ 16 16 "Tool",
+145
src/tools/definitions/domain.py
"""Domain investigation tool: DNS record lookups plus a basic HTTP probe."""

import asyncio
import logging
import re
from typing import Any

import dns.resolver
import httpx
from dns import asyncresolver

from src.tools.registry import TOOL_REGISTRY, ToolContext, ToolParameter

logger = logging.getLogger(__name__)

_DOMAIN_REGEX = re.compile(r"^https?://")


def _normalize_domain(value: str) -> str:
    """Reduce a domain-or-URL string to its bare host part.

    Strips surrounding whitespace, a leading ``http://`` / ``https://`` scheme,
    and everything from the first ``/`` onwards.
    """
    return _DOMAIN_REGEX.sub("", value.strip()).split("/")[0]


async def _check_http(domain: str) -> tuple[str | int, str | None]:
    """Probe the domain with a HEAD request and report status and redirect.

    HTTPS is tried first, then plain HTTP. Redirects are not followed so the
    ``Location`` header of the first response can be reported as-is.

    Returns:
        ``(status_code, location_or_None)`` from the first scheme that
        answers, or ``("unreachable", None)`` when neither does.
    """
    for scheme in ("https", "http"):
        try:
            async with httpx.AsyncClient(
                timeout=10.0, follow_redirects=False
            ) as client:
                response = await client.head(f"{scheme}://{domain}")
        except Exception:
            # Best-effort probe: any failure (TLS, connect, timeout, bad host)
            # just means "try the next scheme".
            logger.debug(
                "HEAD %s://%s failed", scheme, domain, exc_info=True
            )
            continue
        return response.status_code, response.headers.get("Location")

    return "unreachable", None


async def _query_dns(
    resolver: asyncresolver.Resolver, domain: str, record_type: str
) -> list[str] | str | None:
    """Query one DNS record type for a domain using the given resolver.

    Returns:
        For ``SOA``: the single record as a string, or ``None`` when absent.
        For every other type: a list of record strings (empty when there is
        no answer or the lookup fails). ``MX`` entries are formatted as
        ``"<preference> <exchange>"``; ``TXT`` entries have their character
        strings decoded and joined with a space.
    """
    empty: list[str] | None = None if record_type == "SOA" else []

    try:
        answers = await resolver.resolve(domain, record_type)

        if record_type == "SOA":
            # soa returns a single answer
            return str(answers[0]) if answers else None
        if record_type == "MX":
            # mx have priority
            return [f"{answer.preference} {answer.exchange}" for answer in answers]
        if record_type == "TXT":
            # Decode leniently so one non-UTF-8 string does not discard the
            # whole TXT answer.
            return [
                " ".join(
                    s.decode("utf-8", errors="replace")
                    if isinstance(s, bytes)
                    else str(s)
                    for s in answer.strings
                )
                for answer in answers
            ]
        return [str(answer) for answer in answers]
    except (
        dns.resolver.NoAnswer,
        dns.resolver.NXDOMAIN,
        dns.resolver.NoNameservers,
    ):
        # Expected outcomes for domains that lack this record or don't exist.
        return empty
    except Exception:
        # Unexpected failure (timeout, malformed name, ...): still report
        # "no records", but leave a trace for debugging.
        logger.debug(
            "DNS %s lookup for %s failed", record_type, domain, exc_info=True
        )
        return empty


# NOTE: this tool was previously registered as "clickhouse.query", which
# collides with the ClickHouse SQL query tool. It is given its own name here,
# following the "<namespace>.<camelCase>" pattern of the other tools; any
# generated bindings (e.g. the Deno tools.ts) need regenerating to match.
@TOOL_REGISTRY.tool(
    name="domain.checkDomain",
    description="Lookup A, AAAA, NS, MX, TXT, CNAME, and SOA for a given input domain",
    parameters=[
        ToolParameter(
            name="domain",
            type="string",
            description="The domain name (not a URL) to check",
        ),
    ],
)
async def check_domain(ctx: ToolContext, domain: str) -> dict[str, Any]:
    """Collect DNS records and HTTP reachability for a domain.

    Args:
        ctx: Tool context supplied by the registry (not used by this tool).
        domain: Domain name to inspect. A URL is tolerated: the scheme and
            path are stripped before any lookup.

    Returns:
        On success, a dict with ``success=True``, the normalized ``domain``,
        ``resolves`` (whether any A/AAAA/CNAME record exists), a ``dns`` dict
        keyed by lowercase record type, ``http_status`` and ``redirects_to``.
        On failure, ``{"success": False, "domain": ..., "error": ...}``.
    """
    # defensive in case the model decides to stick a url in instead of a domain
    domain = _normalize_domain(domain)
    if not domain:
        return {"success": False, "domain": domain, "error": "no domain provided"}

    try:
        resolver = asyncresolver.Resolver()

        record_types = ("A", "AAAA", "NS", "MX", "TXT", "CNAME", "SOA")

        # run all of the lookups in parallel
        outcomes = await asyncio.gather(
            *(_query_dns(resolver, domain, rtype) for rtype in record_types),
            return_exceptions=True,
        )

        dns_data: dict[str, Any] = {}
        for rtype, outcome in zip(record_types, outcomes):
            if isinstance(outcome, BaseException):
                # A lookup that still managed to raise counts as "no records".
                outcome = None if rtype == "SOA" else []
            dns_data[rtype.lower()] = outcome

        http_status, redirects_to = await _check_http(domain)

        return {
            "success": True,
            "domain": domain,
            "resolves": bool(
                dns_data["a"] or dns_data["aaaa"] or dns_data["cname"]
            ),
            "dns": {
                "a": dns_data["a"],
                "aaaa": dns_data["aaaa"],
                "cname": dns_data["cname"],
                "ns": dns_data["ns"],
                "mx": dns_data["mx"],
                "txt": dns_data["txt"],
                "soa": dns_data["soa"],
            },
            "http_status": http_status,
            "redirects_to": redirects_to,
        }

    except Exception as e:
        return {"success": False, "domain": domain, "error": str(e)}
+2 -2
src/tools/deno/tools.ts
··· 5 5 /** Get Osprey/network table schema information including tables and their columns. Schema is for the table default.osprey_execution_results */ 6 6 getSchema: (): Promise<unknown> => callTool("clickhouse.getSchema", {}), 7 7 8 - /** Execute a SQL query against ClickHouse and return the results. All queries must include a LIMIT, and all queries must be executed on default.osprey_execution_results. */ 9 - query: (sql: string): Promise<unknown> => callTool("clickhouse.query", { sql }), 8 + /** Lookup A, AAAA, NS, MX, TXT, CNAME, and SOA for a given input domain */ 9 + query: (domain: string): Promise<unknown> => callTool("clickhouse.query", { domain }), 10 10 }; 11 11 12 12 export const osprey = {
+2
uv.lock
··· 603 603 { name = "atproto" }, 604 604 { name = "click" }, 605 605 { name = "clickhouse-connect" }, 606 + { name = "dnspython" }, 606 607 { name = "pydantic" }, 607 608 { name = "pydantic-settings" }, 608 609 ] ··· 614 615 { name = "atproto", specifier = ">=0.0.65" }, 615 616 { name = "click", specifier = ">=8.3.1" }, 616 617 { name = "clickhouse-connect", specifier = ">=0.10.0" }, 618 + { name = "dnspython", specifier = ">=2.8.0" }, 617 619 { name = "pydantic", specifier = ">=2.12.5" }, 618 620 { name = "pydantic-settings", specifier = ">=2.12.0" }, 619 621 ]