// This repository has no description.
1use reqwest::{Client, ClientBuilder, Url}; 2use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; 3use std::sync::OnceLock; 4use std::time::Duration; 5use tracing::warn; 6 7pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10); 8pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30); 9pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); 10pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024; 11 12static PROXY_CLIENT: OnceLock<Client> = OnceLock::new(); 13 14pub fn proxy_client() -> &'static Client { 15 PROXY_CLIENT.get_or_init(|| { 16 ClientBuilder::new() 17 .timeout(DEFAULT_BODY_TIMEOUT) 18 .connect_timeout(DEFAULT_CONNECT_TIMEOUT) 19 .pool_max_idle_per_host(10) 20 .pool_idle_timeout(Duration::from_secs(90)) 21 .redirect(reqwest::redirect::Policy::none()) 22 .build() 23 .expect( 24 "Failed to build HTTP client - this indicates a TLS or system configuration issue", 25 ) 26 }) 27} 28 29pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> { 30 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?; 31 let scheme = parsed.scheme(); 32 if scheme != "https" { 33 let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok() 34 || url.starts_with("http://127.0.0.1") 35 || url.starts_with("http://localhost"); 36 if !allow_http { 37 return Err(SsrfError::InsecureProtocol(scheme.to_string())); 38 } 39 } 40 let host = parsed.host_str().ok_or(SsrfError::NoHost)?; 41 if host == "localhost" { 42 return Ok(()); 43 } 44 if let Ok(ip) = host.parse::<IpAddr>() { 45 if ip.is_loopback() { 46 return Ok(()); 47 } 48 if !is_unicast_ip(&ip) { 49 return Err(SsrfError::NonUnicastIp(ip.to_string())); 50 } 51 return Ok(()); 52 } 53 let port = parsed 54 .port() 55 .unwrap_or(if scheme == "https" { 443 } else { 80 }); 56 let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() { 57 Ok(addrs) => addrs.collect(), 58 Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())), 59 }; 60 for addr in &socket_addrs { 61 if 
!is_unicast_ip(&addr.ip()) { 62 warn!( 63 "DNS resolution for {} returned non-unicast IP: {}", 64 host, 65 addr.ip() 66 ); 67 return Err(SsrfError::NonUnicastIp(addr.ip().to_string())); 68 } 69 } 70 Ok(()) 71} 72 73fn is_unicast_ip(ip: &IpAddr) -> bool { 74 match ip { 75 IpAddr::V4(v4) => { 76 !v4.is_loopback() 77 && !v4.is_broadcast() 78 && !v4.is_multicast() 79 && !v4.is_unspecified() 80 && !v4.is_link_local() 81 && !is_private_v4(v4) 82 } 83 IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(), 84 } 85} 86 87fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool { 88 let octets = ip.octets(); 89 octets[0] == 10 90 || (octets[0] == 172 && (16..=31).contains(&octets[1])) 91 || (octets[0] == 192 && octets[1] == 168) 92 || (octets[0] == 169 && octets[1] == 254) 93} 94 95#[derive(Debug, Clone)] 96pub enum SsrfError { 97 InvalidUrl, 98 InsecureProtocol(String), 99 NoHost, 100 NonUnicastIp(String), 101 DnsResolutionFailed(String), 102} 103 104impl std::fmt::Display for SsrfError { 105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 106 match self { 107 SsrfError::InvalidUrl => write!(f, "Invalid URL"), 108 SsrfError::InsecureProtocol(p) => write!(f, "Insecure protocol: {}", p), 109 SsrfError::NoHost => write!(f, "No host in URL"), 110 SsrfError::NonUnicastIp(ip) => write!(f, "Non-unicast IP address: {}", ip), 111 SsrfError::DnsResolutionFailed(host) => { 112 write!(f, "DNS resolution failed for: {}", host) 113 } 114 } 115 } 116} 117 118impl std::error::Error for SsrfError {} 119 120pub const HEADERS_TO_FORWARD: &[&str] = &[ 121 "accept-language", 122 "atproto-accept-labelers", 123 "x-bsky-topics", 124 "content-type", 125]; 126pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[ 127 "atproto-repo-rev", 128 "atproto-content-labelers", 129 "retry-after", 130 "content-type", 131 "cache-control", 132 "etag", 133]; 134 135pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> { 136 if !uri.starts_with("at://") { 137 
return Err("URI must start with at://"); 138 } 139 let path = uri.trim_start_matches("at://"); 140 let parts: Vec<&str> = path.split('/').collect(); 141 if parts.is_empty() { 142 return Err("URI missing DID"); 143 } 144 let did = parts[0]; 145 if !did.starts_with("did:") { 146 return Err("Invalid DID in URI"); 147 } 148 if parts.len() > 1 { 149 let collection = parts[1]; 150 if collection.is_empty() || !collection.contains('.') { 151 return Err("Invalid collection NSID"); 152 } 153 } 154 Ok(AtUriParts { 155 did: did.to_string(), 156 collection: parts.get(1).map(|s| s.to_string()), 157 rkey: parts.get(2).map(|s| s.to_string()), 158 }) 159} 160 161#[derive(Debug, Clone)] 162pub struct AtUriParts { 163 pub did: String, 164 pub collection: Option<String>, 165 pub rkey: Option<String>, 166} 167 168pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 { 169 match limit { 170 Some(0) => default, 171 Some(l) if l > max => max, 172 Some(l) => l, 173 None => default, 174 } 175} 176 177pub fn validate_did(did: &str) -> Result<(), &'static str> { 178 if !did.starts_with("did:") { 179 return Err("Invalid DID format"); 180 } 181 let parts: Vec<&str> = did.split(':').collect(); 182 if parts.len() < 3 { 183 return Err("DID must have at least method and identifier"); 184 } 185 let method = parts[1]; 186 if method != "plc" && method != "web" { 187 return Err("Unsupported DID method"); 188 } 189 Ok(()) 190} 191 192#[cfg(test)] 193mod tests { 194 use super::*; 195 #[test] 196 fn test_ssrf_safe_https() { 197 assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok()); 198 } 199 #[test] 200 fn test_ssrf_blocks_http_by_default() { 201 let result = is_ssrf_safe("http://external.example.com/xrpc/test"); 202 assert!(matches!( 203 result, 204 Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_)) 205 )); 206 } 207 #[test] 208 fn test_ssrf_allows_localhost_http() { 209 assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok()); 210 
assert!(is_ssrf_safe("http://localhost:8080/test").is_ok()); 211 } 212 #[test] 213 fn test_validate_at_uri() { 214 let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123"); 215 assert!(result.is_ok()); 216 let parts = result.unwrap(); 217 assert_eq!(parts.did, "did:plc:test"); 218 assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string())); 219 assert_eq!(parts.rkey, Some("abc123".to_string())); 220 } 221 #[test] 222 fn test_validate_at_uri_invalid() { 223 assert!(validate_at_uri("https://example.com").is_err()); 224 assert!(validate_at_uri("at://notadid/collection/rkey").is_err()); 225 } 226 #[test] 227 fn test_validate_limit() { 228 assert_eq!(validate_limit(None, 50, 100), 50); 229 assert_eq!(validate_limit(Some(0), 50, 100), 50); 230 assert_eq!(validate_limit(Some(200), 50, 100), 100); 231 assert_eq!(validate_limit(Some(75), 50, 100), 75); 232 } 233 #[test] 234 fn test_validate_did() { 235 assert!(validate_did("did:plc:abc123").is_ok()); 236 assert!(validate_did("did:web:example.com").is_ok()); 237 assert!(validate_did("notadid").is_err()); 238 assert!(validate_did("did:unknown:test").is_err()); 239 } 240}