this repo has no description
1use reqwest::{Client, ClientBuilder, Url};
2use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
3use std::sync::OnceLock;
4use std::time::Duration;
5use tracing::warn;
/// Header-phase timeout (10 s). Presumably the budget for receiving upstream
/// response headers; it is not applied by `proxy_client` — confirm at use sites.
pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Installed as the shared client's overall request timeout by `proxy_client` (30 s).
pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Installed as the shared client's connect-phase timeout by `proxy_client` (5 s).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Response size cap of 10 MiB (10 * 1024 * 1024). Presumably a byte count
/// enforced by callers — confirm at use sites.
pub const MAX_RESPONSE_SIZE: u64 = 10 << 20;
10static PROXY_CLIENT: OnceLock<Client> = OnceLock::new();
11pub fn proxy_client() -> &'static Client {
12 PROXY_CLIENT.get_or_init(|| {
13 ClientBuilder::new()
14 .timeout(DEFAULT_BODY_TIMEOUT)
15 .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
16 .pool_max_idle_per_host(10)
17 .pool_idle_timeout(Duration::from_secs(90))
18 .redirect(reqwest::redirect::Policy::none())
19 .build()
20 .expect("Failed to build HTTP client - this indicates a TLS or system configuration issue")
21 })
22}
23pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
24 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?;
25 let scheme = parsed.scheme();
26 if scheme != "https" {
27 let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok()
28 || url.starts_with("http://127.0.0.1")
29 || url.starts_with("http://localhost");
30 if !allow_http {
31 return Err(SsrfError::InsecureProtocol(scheme.to_string()));
32 }
33 }
34 let host = parsed.host_str().ok_or(SsrfError::NoHost)?;
35 if host == "localhost" {
36 return Ok(());
37 }
38 if let Ok(ip) = host.parse::<IpAddr>() {
39 if ip.is_loopback() {
40 return Ok(());
41 }
42 if !is_unicast_ip(&ip) {
43 return Err(SsrfError::NonUnicastIp(ip.to_string()));
44 }
45 return Ok(());
46 }
47 let port = parsed.port().unwrap_or(if scheme == "https" { 443 } else { 80 });
48 let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() {
49 Ok(addrs) => addrs.collect(),
50 Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())),
51 };
52 for addr in &socket_addrs {
53 if !is_unicast_ip(&addr.ip()) {
54 warn!(
55 "DNS resolution for {} returned non-unicast IP: {}",
56 host,
57 addr.ip()
58 );
59 return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
60 }
61 }
62 Ok(())
63}
/// Returns `true` if `ip` looks like a publicly routable unicast address that
/// proxied requests may target.
///
/// Despite the name, this is stricter than "unicast". It rejects:
/// * IPv4: loopback, broadcast, multicast, `0.0.0.0/8` ("this network"),
///   link-local / private ranges (via `is_private_v4`), and the
///   `100.64.0.0/10` shared address space.
/// * IPv6: loopback, unspecified, multicast, unique-local (`fc00::/7`),
///   link-local (`fe80::/10`), and IPv4-mapped (`::ffff:a.b.c.d`) or
///   IPv4-compatible (`::a.b.c.d`) addresses whose embedded IPv4 address fails
///   the IPv4 rules.
fn is_unicast_ip(ip: &IpAddr) -> bool {
    /// IPv4 policy, shared by plain IPv4 and IPv4-embedded IPv6 addresses.
    fn allowed_v4(v4: &std::net::Ipv4Addr) -> bool {
        let [first, second, ..] = v4.octets();
        // 0.0.0.0/8; `is_unspecified` alone only covers 0.0.0.0.
        let this_network = first == 0;
        // 100.64.0.0/10 (RFC 6598 shared space), used by carrier-grade NAT and
        // cloud-internal endpoints such as Alibaba Cloud metadata (100.100.100.200).
        let shared = first == 100 && (64..=127).contains(&second);
        !(v4.is_loopback()
            || v4.is_broadcast()
            || v4.is_multicast()
            || v4.is_unspecified()
            || v4.is_link_local()
            || is_private_v4(v4)
            || this_network
            || shared)
    }

    match ip {
        IpAddr::V4(v4) => allowed_v4(v4),
        IpAddr::V6(v6) => {
            if v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() {
                return false;
            }
            // Embedded IPv4 forms reach the IPv4 host, so `::ffff:169.254.169.254`
            // must not slip past the IPv4 rules.
            if let Some(v4) = v6.to_ipv4() {
                return allowed_v4(&v4);
            }
            let first = v6.segments()[0];
            // fc00::/7 unique-local (e.g. AWS IMDS at fd00:ec2::254).
            let unique_local = (first & 0xfe00) == 0xfc00;
            // fe80::/10 link-local unicast.
            let link_local = (first & 0xffc0) == 0xfe80;
            !unique_local && !link_local
        }
    }
}

/// `true` for the RFC 1918 private ranges (10/8, 172.16/12, 192.168/16) and
/// the 169.254/16 link-local range.
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
    ip.is_private() || ip.is_link_local()
}
/// Reasons `is_ssrf_safe` can reject a URL.
#[derive(Debug, Clone)]
pub enum SsrfError {
    /// The input could not be parsed as a URL.
    InvalidUrl,
    /// The URL's scheme (carried here) is not permitted.
    InsecureProtocol(String),
    /// The URL has no host component.
    NoHost,
    /// A literal or resolved IP address (carried here as text) was rejected.
    NonUnicastIp(String),
    /// The host (carried here) could not be resolved.
    DnsResolutionFailed(String),
}

impl std::fmt::Display for SsrfError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidUrl => f.write_str("Invalid URL"),
            Self::InsecureProtocol(scheme) => write!(f, "Insecure protocol: {scheme}"),
            Self::NoHost => f.write_str("No host in URL"),
            Self::NonUnicastIp(addr) => write!(f, "Non-unicast IP address: {addr}"),
            Self::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for: {host}"),
        }
    }
}

impl std::error::Error for SsrfError {}
/// Request header names to forward, written in lowercase. Presumably copied
/// from the incoming request onto the upstream request — confirm at use sites.
pub const HEADERS_TO_FORWARD: &[&str] =
    &["accept-language", "atproto-accept-labelers", "x-bsky-topics"];

/// Response header names to forward, written in lowercase. Presumably copied
/// from the upstream response back to the client — confirm at use sites.
pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] =
    &["atproto-repo-rev", "atproto-content-labelers", "retry-after", "content-type"];
/// Validates an AT URI of the form `at://<did>[/<collection>[/<rkey>]]` and
/// splits it into its parts.
///
/// Rules:
/// * the URI must start with exactly one `at://` prefix;
/// * the authority must be non-empty and start with `did:` (the DID is not
///   otherwise validated here);
/// * a collection, if present, must be non-empty and contain a `.`;
/// * a record key, if present, must be non-empty;
/// * no path segments may follow the record key.
///
/// # Errors
/// Returns a static message naming the first rule that failed.
pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> {
    // `strip_prefix` removes the scheme once; `trim_start_matches` would strip
    // repeated prefixes and accept `at://at://did:...`.
    let path = uri.strip_prefix("at://").ok_or("URI must start with at://")?;
    let mut segments = path.split('/');

    // `split` always yields at least one (possibly empty) segment.
    let did = segments.next().unwrap_or("");
    if did.is_empty() {
        return Err("URI missing DID");
    }
    if !did.starts_with("did:") {
        return Err("Invalid DID in URI");
    }

    let collection = segments.next();
    if collection.is_some_and(|c| c.is_empty() || !c.contains('.')) {
        return Err("Invalid collection NSID");
    }

    let rkey = segments.next();
    if rkey.is_some_and(str::is_empty) {
        return Err("Invalid record key");
    }
    if segments.next().is_some() {
        return Err("Too many path segments in URI");
    }

    Ok(AtUriParts {
        did: did.to_string(),
        collection: collection.map(str::to_string),
        rkey: rkey.map(str::to_string),
    })
}

/// Components of an AT URI, as produced by `validate_at_uri`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AtUriParts {
    /// The authority segment; always starts with `did:`.
    pub did: String,
    /// Collection NSID (second path segment), if present.
    pub collection: Option<String>,
    /// Record key (third path segment), if present.
    pub rkey: Option<String>,
}
/// Normalises an optional `limit` argument.
///
/// * `None` or `Some(0)` fall back to `default` (which is returned as-is, not
///   capped at `max`).
/// * Any other value is capped at `max`.
pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 {
    match limit {
        None | Some(0) => default,
        Some(requested) => requested.min(max),
    }
}
/// Checks that `did` is a syntactically valid DID using a supported method.
///
/// Accepted shape: `did:<method>:<id>` where
/// * `<method>` is `plc` or `web`;
/// * `<id>` is non-empty, does not end with `:`, and contains only ASCII
///   letters, digits, `.`, `-`, `_`, `:` and `%`. This is the character set of
///   the W3C DID syntax; percent-escapes are not checked further.
///
/// Restricting the character set keeps values such as `did:plc:a/../b` out
/// in case the DID is later interpolated into a URL or path (not visible here).
///
/// # Errors
/// Returns a static message describing the first rule that failed.
pub fn validate_did(did: &str) -> Result<(), &'static str> {
    let Some(rest) = did.strip_prefix("did:") else {
        return Err("Invalid DID format");
    };
    let Some((method, id)) = rest.split_once(':') else {
        return Err("DID must have at least method and identifier");
    };
    if method != "plc" && method != "web" {
        return Err("Unsupported DID method");
    }
    if id.is_empty() {
        return Err("DID must have at least method and identifier");
    }
    let valid_chars = id
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_' | ':' | '%'));
    // The DID grammar forbids a trailing ':' in the method-specific id.
    if !valid_chars || id.ends_with(':') {
        return Err("Invalid DID identifier");
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Resolves a real hostname, so it only passes with working DNS. Run it
    // explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires network access (resolves api.bsky.app)"]
    fn test_ssrf_safe_https() {
        assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok());
    }

    #[test]
    fn test_ssrf_safe_https_public_ip_literal() {
        // IP literals skip DNS, so this exercises the accept path offline.
        assert!(is_ssrf_safe("https://8.8.8.8/xrpc/test").is_ok());
    }

    #[test]
    fn test_ssrf_blocks_private_ip_literals() {
        for url in [
            "https://10.0.0.1/xrpc/test",
            "https://172.16.0.1/xrpc/test",
            "https://192.168.1.1/xrpc/test",
            "https://169.254.169.254/latest/meta-data",
        ] {
            assert!(
                matches!(is_ssrf_safe(url), Err(SsrfError::NonUnicastIp(_))),
                "expected {} to be rejected",
                url
            );
        }
    }

    #[test]
    fn test_ssrf_rejects_unparseable_url() {
        assert!(matches!(is_ssrf_safe("not a url"), Err(SsrfError::InvalidUrl)));
    }

    #[test]
    fn test_ssrf_blocks_http_by_default() {
        let result = is_ssrf_safe("http://external.example.com/xrpc/test");
        assert!(matches!(
            result,
            Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_))
        ));
    }

    #[test]
    fn test_ssrf_allows_localhost_http() {
        assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok());
        assert!(is_ssrf_safe("http://localhost:8080/test").is_ok());
    }

    #[test]
    fn test_validate_at_uri() {
        let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123");
        assert!(result.is_ok());
        let parts = result.unwrap();
        assert_eq!(parts.did, "did:plc:test");
        assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string()));
        assert_eq!(parts.rkey, Some("abc123".to_string()));
    }

    #[test]
    fn test_validate_at_uri_without_collection() {
        let parts = validate_at_uri("at://did:plc:test").unwrap();
        assert_eq!(parts.did, "did:plc:test");
        assert_eq!(parts.collection, None);
        assert_eq!(parts.rkey, None);
    }

    #[test]
    fn test_validate_at_uri_invalid() {
        assert!(validate_at_uri("https://example.com").is_err());
        assert!(validate_at_uri("at://notadid/collection/rkey").is_err());
        assert!(validate_at_uri("at://did:plc:test/nodots").is_err());
    }

    #[test]
    fn test_validate_limit() {
        assert_eq!(validate_limit(None, 50, 100), 50);
        assert_eq!(validate_limit(Some(0), 50, 100), 50);
        assert_eq!(validate_limit(Some(200), 50, 100), 100);
        assert_eq!(validate_limit(Some(75), 50, 100), 75);
    }

    #[test]
    fn test_validate_did() {
        assert!(validate_did("did:plc:abc123").is_ok());
        assert!(validate_did("did:web:example.com").is_ok());
        assert!(validate_did("notadid").is_err());
        assert!(validate_did("did:plc").is_err());
        assert!(validate_did("did:unknown:test").is_err());
    }
}