this repo has no description
1use reqwest::{Client, ClientBuilder, Url};
2use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
3use std::sync::OnceLock;
4use std::time::Duration;
5use tracing::warn;
6
/// Ten-second timeout. Judging by the name it bounds the wait for response
/// headers. NOTE(review): not referenced in the visible part of this file;
/// confirm how callers apply it.
pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10);
/// Thirty-second timeout; `proxy_client` uses it as its overall request timeout.
pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30);
/// Five-second connect timeout applied to every client built in this file.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// 10 MiB expressed in bytes. Presumably the cap on accepted response bodies;
/// it is not enforced anywhere in the visible part of this file.
pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024;
11
// Process-wide client slots. Each one is filled on first use by the accessor
// function of the matching name below and is never replaced afterwards.
/// Backing storage for `proxy_client()`.
static PROXY_CLIENT: OnceLock<Client> = OnceLock::new();
/// Backing storage for `did_resolution_client()`.
static DID_RESOLUTION_CLIENT: OnceLock<Client> = OnceLock::new();
/// Backing storage for `handle_resolution_client()`.
static HANDLE_RESOLUTION_CLIENT: OnceLock<Client> = OnceLock::new();
15
16pub fn proxy_client() -> &'static Client {
17 PROXY_CLIENT.get_or_init(|| {
18 ClientBuilder::new()
19 .timeout(DEFAULT_BODY_TIMEOUT)
20 .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
21 .pool_max_idle_per_host(10)
22 .pool_idle_timeout(Duration::from_secs(90))
23 .redirect(reqwest::redirect::Policy::none())
24 .build()
25 .expect(
26 "Failed to build HTTP client - this indicates a TLS or system configuration issue",
27 )
28 })
29}
30
31pub fn did_resolution_client() -> &'static Client {
32 DID_RESOLUTION_CLIENT.get_or_init(|| {
33 ClientBuilder::new()
34 .timeout(Duration::from_secs(5))
35 .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
36 .pool_max_idle_per_host(10)
37 .pool_idle_timeout(Duration::from_secs(90))
38 .build()
39 .expect(
40 "Failed to build DID resolution client - this indicates a TLS or system configuration issue",
41 )
42 })
43}
44
45pub fn handle_resolution_client() -> &'static Client {
46 HANDLE_RESOLUTION_CLIENT.get_or_init(|| {
47 ClientBuilder::new()
48 .timeout(Duration::from_secs(10))
49 .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
50 .pool_max_idle_per_host(10)
51 .pool_idle_timeout(Duration::from_secs(90))
52 .redirect(reqwest::redirect::Policy::limited(5))
53 .build()
54 .expect(
55 "Failed to build handle resolution client - this indicates a TLS or system configuration issue",
56 )
57 })
58}
59
60pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
61 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?;
62 let scheme = parsed.scheme();
63 if scheme != "https" {
64 let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok()
65 || url.starts_with("http://127.0.0.1")
66 || url.starts_with("http://localhost");
67 if !allow_http {
68 return Err(SsrfError::InsecureProtocol(scheme.to_string()));
69 }
70 }
71 let host = parsed.host_str().ok_or(SsrfError::NoHost)?;
72 if host == "localhost" {
73 return Ok(());
74 }
75 if let Ok(ip) = host.parse::<IpAddr>() {
76 if ip.is_loopback() {
77 return Ok(());
78 }
79 if !is_unicast_ip(&ip) {
80 return Err(SsrfError::NonUnicastIp(ip.to_string()));
81 }
82 return Ok(());
83 }
84 let port = parsed
85 .port()
86 .unwrap_or(if scheme == "https" { 443 } else { 80 });
87 let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() {
88 Ok(addrs) => addrs.collect(),
89 Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())),
90 };
91 if let Some(addr) = socket_addrs.iter().find(|addr| !is_unicast_ip(&addr.ip())) {
92 warn!(
93 "DNS resolution for {} returned non-unicast IP: {}",
94 host,
95 addr.ip()
96 );
97 return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
98 }
99 Ok(())
100}
101
/// Returns `true` when `ip` is an address outbound requests may target.
///
/// IPv4 addresses are judged by `is_unicast_v4`. An IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`) is judged by the IPv4 address it embeds, so it cannot
/// be used to smuggle a loopback or private target past the IPv4 rules. Any
/// other IPv6 address is rejected when it is loopback, multicast,
/// unspecified, unique-local (`fc00::/7`) or link-local (`fe80::/10`).
fn is_unicast_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_unicast_v4(v4),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(embedded) => is_unicast_v4(&embedded),
            None => {
                let first_segment = v6.segments()[0];
                let unique_local = (first_segment & 0xfe00) == 0xfc00;
                let link_local = (first_segment & 0xffc0) == 0xfe80;
                !v6.is_loopback()
                    && !v6.is_multicast()
                    && !v6.is_unspecified()
                    && !unique_local
                    && !link_local
            }
        },
    }
}

/// IPv4 half of `is_unicast_ip`: rejects loopback, broadcast, multicast,
/// unspecified, link-local and every range covered by `is_private_v4`.
fn is_unicast_v4(v4: &std::net::Ipv4Addr) -> bool {
    !v4.is_loopback()
        && !v4.is_broadcast()
        && !v4.is_multicast()
        && !v4.is_unspecified()
        && !v4.is_link_local()
        && !is_private_v4(v4)
}

/// Returns `true` for IPv4 ranges that are not reachable from the public
/// internet: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`, link-local
/// `169.254.0.0/16`, and the shared address space `100.64.0.0/10`.
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();
    // 100.64.0.0/10 is carrier-grade NAT space; it is not covered by
    // `Ipv4Addr::is_private`, so it is checked explicitly.
    let shared_space = first == 100 && (64..=127).contains(&second);
    ip.is_private() || ip.is_link_local() || shared_space
}
123
/// Reasons `is_ssrf_safe` rejects a URL.
#[derive(Debug, Clone)]
pub enum SsrfError {
    /// The input could not be parsed as a URL.
    InvalidUrl,
    /// The URL scheme (carried as the payload) is not permitted.
    InsecureProtocol(String),
    /// The URL has no host component.
    NoHost,
    /// The target address (carried as text) failed the unicast check.
    NonUnicastIp(String),
    /// The host name (carried as the payload) could not be resolved.
    DnsResolutionFailed(String),
}

impl std::fmt::Display for SsrfError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants without a payload are fixed messages.
            Self::InvalidUrl => f.write_str("Invalid URL"),
            Self::NoHost => f.write_str("No host in URL"),
            // Variants with a payload append it to their label.
            Self::InsecureProtocol(scheme) => write!(f, "Insecure protocol: {}", scheme),
            Self::NonUnicastIp(address) => write!(f, "Non-unicast IP address: {}", address),
            Self::DnsResolutionFailed(name) => {
                write!(f, "DNS resolution failed for: {}", name)
            }
        }
    }
}

impl std::error::Error for SsrfError {}
148
/// Lower-case header names. Judging by the name, this is the allow-list of
/// request headers copied onto a proxied upstream request.
/// NOTE(review): nothing in the visible part of this file consumes the list;
/// confirm at the call sites.
pub const HEADERS_TO_FORWARD: &[&str] = &[
    "accept-language",
    "atproto-accept-labelers",
    "x-bsky-topics",
    "content-type",
];
/// Lower-case header names. Judging by the name, this is the allow-list of
/// upstream response headers copied back to the original caller.
/// NOTE(review): nothing in the visible part of this file consumes the list;
/// confirm at the call sites.
pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[
    "atproto-repo-rev",
    "atproto-content-labelers",
    "retry-after",
    "content-type",
    "cache-control",
    "etag",
];
163
/// Splits an `at://` URI into its DID, optional collection and optional
/// record key.
///
/// Accepted shapes are `at://<did>`, `at://<did>/<collection>` and
/// `at://<did>/<collection>/<rkey>`.
///
/// # Errors
///
/// Returns a static message when:
/// * the URI does not start with `at://`;
/// * the authority is empty or does not start with `did:`;
/// * a collection segment is present but empty or contains no `.`;
/// * a record-key segment is present but empty;
/// * anything follows the record key.
pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> {
    // `strip_prefix` removes the scheme exactly once, so a doubled prefix
    // such as `at://at://did:…` is not silently normalized into a valid URI.
    let path = uri
        .strip_prefix("at://")
        .ok_or("URI must start with at://")?;
    let mut segments = path.split('/');

    // `split` always yields at least one item, so a missing DID shows up as
    // an empty first segment rather than an empty iterator.
    let did = segments.next().unwrap_or_default();
    if did.is_empty() {
        return Err("URI missing DID");
    }
    if !did.starts_with("did:") {
        return Err("Invalid DID in URI");
    }

    let collection = segments.next();
    if let Some(nsid) = collection {
        if nsid.is_empty() || !nsid.contains('.') {
            return Err("Invalid collection NSID");
        }
    }

    let rkey = segments.next();
    if matches!(rkey, Some("")) {
        return Err("Invalid record key");
    }

    // Reject trailing segments instead of dropping them silently.
    if segments.next().is_some() {
        return Err("Too many path segments in URI");
    }

    Ok(AtUriParts {
        did: did.to_string(),
        collection: collection.map(str::to_string),
        rkey: rkey.map(str::to_string),
    })
}

/// Components of an `at://` URI as produced by `validate_at_uri`.
#[derive(Debug, Clone)]
pub struct AtUriParts {
    /// The authority segment; always starts with `did:`.
    pub did: String,
    /// The collection segment, when the URI has one.
    pub collection: Option<String>,
    /// The record-key segment, when the URI has one.
    pub rkey: Option<String>,
}
196
/// Normalizes a caller-supplied page size.
///
/// A missing or zero `limit` yields `default`; any other value is capped at
/// `max`.
pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 {
    limit
        // Zero is treated the same as "not provided".
        .filter(|&requested| requested != 0)
        .map_or(default, |requested| requested.min(max))
}
205
/// Checks that `did` has the shape `did:<method>:<identifier>` with a
/// supported method.
///
/// Only the `plc` and `web` methods are accepted. The identifier may itself
/// contain `:` but must not be empty.
///
/// # Errors
///
/// Returns a static message when the `did:` prefix is missing, when the
/// method or identifier part is absent, or when the method is unsupported.
pub fn validate_did(did: &str) -> Result<(), &'static str> {
    let rest = did.strip_prefix("did:").ok_or("Invalid DID format")?;
    let (method, identifier) = rest
        .split_once(':')
        .ok_or("DID must have at least method and identifier")?;
    if !matches!(method, "plc" | "web") {
        return Err("Unsupported DID method");
    }
    // `did:plc:` has the right number of separators but names nothing.
    if identifier.is_empty() {
        return Err("DID must have at least method and identifier");
    }
    Ok(())
}
220
// Unit tests for the SSRF guard and the input validators in this file.
#[cfg(test)]
mod tests {
    use super::*;
    // The host here is a domain name, so `is_ssrf_safe` performs a real DNS
    // lookup; this test needs working name resolution to pass.
    #[test]
    fn test_ssrf_safe_https() {
        assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok());
    }
    // Plain HTTP to a non-local host must be refused. A DNS failure is also
    // accepted as an outcome, which covers runs where `ALLOW_HTTP_PROXY` is
    // set and the lookup does not succeed.
    #[test]
    fn test_ssrf_blocks_http_by_default() {
        let result = is_ssrf_safe("http://external.example.com/xrpc/test");
        assert!(matches!(
            result,
            Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_))
        ));
    }
    // Loopback targets are allowed over plain HTTP, with or without a port.
    #[test]
    fn test_ssrf_allows_localhost_http() {
        assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok());
        assert!(is_ssrf_safe("http://localhost:8080/test").is_ok());
    }
    // A full three-part AT URI is split into DID, collection and record key.
    #[test]
    fn test_validate_at_uri() {
        let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123");
        assert!(result.is_ok());
        let parts = result.unwrap();
        assert_eq!(parts.did, "did:plc:test");
        assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string()));
        assert_eq!(parts.rkey, Some("abc123".to_string()));
    }
    // A wrong scheme and a non-DID authority are both rejected.
    #[test]
    fn test_validate_at_uri_invalid() {
        assert!(validate_at_uri("https://example.com").is_err());
        assert!(validate_at_uri("at://notadid/collection/rkey").is_err());
    }
    // Missing and zero limits fall back to the default; large ones are capped.
    #[test]
    fn test_validate_limit() {
        assert_eq!(validate_limit(None, 50, 100), 50);
        assert_eq!(validate_limit(Some(0), 50, 100), 50);
        assert_eq!(validate_limit(Some(200), 50, 100), 100);
        assert_eq!(validate_limit(Some(75), 50, 100), 75);
    }
    // Only the `plc` and `web` DID methods are accepted.
    #[test]
    fn test_validate_did() {
        assert!(validate_did("did:plc:abc123").is_ok());
        assert!(validate_did("did:web:example.com").is_ok());
        assert!(validate_did("notadid").is_err());
        assert!(validate_did("did:unknown:test").is_err());
    }
}