this repo has no description
1use reqwest::{Client, ClientBuilder, Url};
2use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
3use std::sync::OnceLock;
4use std::time::Duration;
5use tracing::warn;
6
/// Timeout budget for receiving an upstream's response headers.
///
/// NOTE(review): not referenced by `proxy_client`; presumably applied by callers.
pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10);
/// Total per-request timeout installed on the shared client by `proxy_client`.
pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30);
/// Connection-establishment timeout installed on the shared client by `proxy_client`.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Response size cap of 10 MiB (presumably bytes; enforcement lives with the callers).
pub const MAX_RESPONSE_SIZE: u64 = 10 << 20;
11
12static PROXY_CLIENT: OnceLock<Client> = OnceLock::new();
13
14pub fn proxy_client() -> &'static Client {
15 PROXY_CLIENT.get_or_init(|| {
16 ClientBuilder::new()
17 .timeout(DEFAULT_BODY_TIMEOUT)
18 .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
19 .pool_max_idle_per_host(10)
20 .pool_idle_timeout(Duration::from_secs(90))
21 .redirect(reqwest::redirect::Policy::none())
22 .build()
23 .expect(
24 "Failed to build HTTP client - this indicates a TLS or system configuration issue",
25 )
26 })
27}
28
29pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> {
30 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?;
31 let scheme = parsed.scheme();
32 if scheme != "https" {
33 let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok()
34 || url.starts_with("http://127.0.0.1")
35 || url.starts_with("http://localhost");
36 if !allow_http {
37 return Err(SsrfError::InsecureProtocol(scheme.to_string()));
38 }
39 }
40 let host = parsed.host_str().ok_or(SsrfError::NoHost)?;
41 if host == "localhost" {
42 return Ok(());
43 }
44 if let Ok(ip) = host.parse::<IpAddr>() {
45 if ip.is_loopback() {
46 return Ok(());
47 }
48 if !is_unicast_ip(&ip) {
49 return Err(SsrfError::NonUnicastIp(ip.to_string()));
50 }
51 return Ok(());
52 }
53 let port = parsed
54 .port()
55 .unwrap_or(if scheme == "https" { 443 } else { 80 });
56 let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() {
57 Ok(addrs) => addrs.collect(),
58 Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())),
59 };
60 for addr in &socket_addrs {
61 if !is_unicast_ip(&addr.ip()) {
62 warn!(
63 "DNS resolution for {} returned non-unicast IP: {}",
64 host,
65 addr.ip()
66 );
67 return Err(SsrfError::NonUnicastIp(addr.ip().to_string()));
68 }
69 }
70 Ok(())
71}
72
73fn is_unicast_ip(ip: &IpAddr) -> bool {
74 match ip {
75 IpAddr::V4(v4) => {
76 !v4.is_loopback()
77 && !v4.is_broadcast()
78 && !v4.is_multicast()
79 && !v4.is_unspecified()
80 && !v4.is_link_local()
81 && !is_private_v4(v4)
82 }
83 IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(),
84 }
85}
86
/// Returns `true` for IPv4 addresses in the RFC 1918 private ranges
/// (`10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`) or in the link-local
/// range `169.254.0.0/16`.
fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        [10, ..] | [172, 16..=31, ..] | [192, 168, ..] | [169, 254, ..]
    )
}
94
/// Reasons `is_ssrf_safe` rejects a URL.
#[derive(Debug, Clone)]
pub enum SsrfError {
    /// The input could not be parsed as a URL.
    InvalidUrl,
    /// The URL's scheme (carried as the payload) is not permitted.
    InsecureProtocol(String),
    /// The URL has no host component.
    NoHost,
    /// A literal or resolved IP address (carried in string form) was rejected.
    NonUnicastIp(String),
    /// Resolving the host (carried as the payload) failed.
    DnsResolutionFailed(String),
}

impl std::fmt::Display for SsrfError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidUrl => f.write_str("Invalid URL"),
            Self::InsecureProtocol(scheme) => write!(f, "Insecure protocol: {}", scheme),
            Self::NoHost => f.write_str("No host in URL"),
            Self::NonUnicastIp(addr) => write!(f, "Non-unicast IP address: {}", addr),
            Self::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for: {}", host),
        }
    }
}

impl std::error::Error for SsrfError {}
119
/// Lowercase names of request headers to forward to the upstream service.
///
/// NOTE(review): the code applying this allow-list is outside this section; it is
/// assumed to compare against lowercase-normalised header names.
pub const HEADERS_TO_FORWARD: &[&str] = &[
    // Content negotiation.
    "accept-language",
    // AT Protocol / Bluesky-specific request headers.
    "atproto-accept-labelers",
    "x-bsky-topics",
    // Request body description.
    "content-type",
];

/// Lowercase names of upstream response headers to forward back to the client.
///
/// NOTE(review): the code applying this allow-list is outside this section; it is
/// assumed to compare against lowercase-normalised header names.
pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[
    // AT Protocol-specific response headers.
    "atproto-repo-rev",
    "atproto-content-labelers",
    // Rate limiting.
    "retry-after",
    // Body description and caching.
    "content-type",
    "cache-control",
    "etag",
];
134
/// Parses and validates an `at://` URI of the form `at://<did>[/<collection>[/<rkey>]]`.
///
/// Checks, in order:
/// * the URI starts with exactly one `at://` prefix;
/// * the authority is non-empty and starts with `did:` (handles are not accepted);
/// * a collection segment, when present, is non-empty and contains a `.` (a minimal
///   NSID sanity check, not full NSID validation);
/// * a record-key segment, when present, is non-empty;
/// * there are no further path segments.
///
/// # Errors
///
/// Returns a static description of the first check that failed.
pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> {
    // `strip_prefix` removes exactly one prefix; `trim_start_matches` would also
    // strip repeated prefixes and accept `at://at://did:...`.
    let path = uri
        .strip_prefix("at://")
        .ok_or("URI must start with at://")?;
    let mut segments = path.split('/');
    // `split` always yields at least one (possibly empty) item.
    let did = segments.next().unwrap_or("");
    let collection = segments.next();
    let rkey = segments.next();

    if did.is_empty() {
        return Err("URI missing DID");
    }
    if !did.starts_with("did:") {
        return Err("Invalid DID in URI");
    }
    if let Some(collection) = collection {
        if collection.is_empty() || !collection.contains('.') {
            return Err("Invalid collection NSID");
        }
    }
    // A trailing slash after the collection yields an empty record key.
    if rkey == Some("") {
        return Err("Invalid record key");
    }
    if segments.next().is_some() {
        return Err("URI has too many path segments");
    }
    Ok(AtUriParts {
        did: did.to_string(),
        collection: collection.map(str::to_string),
        rkey: rkey.map(str::to_string),
    })
}

/// Components of an `at://` URI as returned by [`validate_at_uri`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AtUriParts {
    /// The authority segment; `validate_at_uri` guarantees it starts with `did:`.
    pub did: String,
    /// The collection NSID, present when the URI has a second segment.
    pub collection: Option<String>,
    /// The record key, present when the URI has a third segment.
    pub rkey: Option<String>,
}
167
/// Normalises an optional limit parameter.
///
/// A missing value or `0` falls back to `default`; a value above `max` is clamped to
/// `max`. `default` itself is returned unchanged, even if it exceeds `max`.
pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 {
    limit
        .filter(|&requested| requested != 0)
        .map_or(default, |requested| requested.min(max))
}
176
/// Checks that `did` looks like a `did:plc:` or `did:web:` identifier with a non-empty
/// method-specific identifier.
///
/// This is a shallow syntactic check. It does not validate the identifier's characters
/// or length for either method.
///
/// # Errors
///
/// * `"Invalid DID format"`: the `did:` prefix is missing.
/// * `"DID must have at least method and identifier"`: there is no `:` after the
///   method, or the identifier after it is empty (e.g. `did:plc:`).
/// * `"Unsupported DID method"`: the method is neither `plc` nor `web`.
pub fn validate_did(did: &str) -> Result<(), &'static str> {
    let Some(rest) = did.strip_prefix("did:") else {
        return Err("Invalid DID format");
    };
    // Previously `did:plc:` passed because only the number of `:`-separated parts was
    // checked; require the identifier after the method to be non-empty.
    let Some((method, _identifier)) = rest
        .split_once(':')
        .filter(|(_, identifier)| !identifier.is_empty())
    else {
        return Err("DID must have at least method and identifier");
    };
    if method != "plc" && method != "web" {
        return Err("Unsupported DID method");
    }
    Ok(())
}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ssrf_safe_https() {
        // IP literal: exercises the https acceptance path without any network access.
        assert!(is_ssrf_safe("https://1.1.1.1/xrpc/test").is_ok());
    }

    #[test]
    #[ignore = "performs a live DNS lookup; run with `cargo test -- --ignored`"]
    fn test_ssrf_safe_https_dns() {
        assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok());
    }

    #[test]
    fn test_ssrf_blocks_private_ip_literals() {
        for url in [
            "https://10.0.0.1/",
            "https://172.16.0.1/",
            "https://192.168.1.1/",
            "https://169.254.169.254/latest/meta-data",
        ] {
            assert!(
                matches!(is_ssrf_safe(url), Err(SsrfError::NonUnicastIp(_))),
                "expected NonUnicastIp for {}",
                url
            );
        }
    }

    #[test]
    fn test_ssrf_rejects_invalid_url() {
        assert!(matches!(is_ssrf_safe("not a url"), Err(SsrfError::InvalidUrl)));
    }

    #[test]
    fn test_ssrf_blocks_http_by_default() {
        // DnsResolutionFailed is tolerated in case ALLOW_HTTP_PROXY is set in the
        // test environment.
        let result = is_ssrf_safe("http://external.example.com/xrpc/test");
        assert!(matches!(
            result,
            Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_))
        ));
    }

    #[test]
    fn test_ssrf_allows_localhost_http() {
        assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok());
        assert!(is_ssrf_safe("http://localhost:8080/test").is_ok());
    }

    #[test]
    fn test_validate_at_uri() {
        let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123");
        assert!(result.is_ok());
        let parts = result.unwrap();
        assert_eq!(parts.did, "did:plc:test");
        assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string()));
        assert_eq!(parts.rkey, Some("abc123".to_string()));
    }

    #[test]
    fn test_validate_at_uri_invalid() {
        assert!(validate_at_uri("https://example.com").is_err());
        assert!(validate_at_uri("at://notadid/collection/rkey").is_err());
    }

    #[test]
    fn test_validate_limit() {
        assert_eq!(validate_limit(None, 50, 100), 50);
        assert_eq!(validate_limit(Some(0), 50, 100), 50);
        assert_eq!(validate_limit(Some(200), 50, 100), 100);
        assert_eq!(validate_limit(Some(75), 50, 100), 75);
    }

    #[test]
    fn test_validate_did() {
        assert!(validate_did("did:plc:abc123").is_ok());
        assert!(validate_did("did:web:example.com").is_ok());
        assert!(validate_did("notadid").is_err());
        assert!(validate_did("did:unknown:test").is_err());
    }
}
240}