this repo has no description
1use reqwest::{Client, ClientBuilder, Url}; 2use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; 3use std::sync::OnceLock; 4use std::time::Duration; 5use tracing::warn; 6pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10); 7pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30); 8pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); 9pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024; 10static PROXY_CLIENT: OnceLock<Client> = OnceLock::new(); 11pub fn proxy_client() -> &'static Client { 12 PROXY_CLIENT.get_or_init(|| { 13 ClientBuilder::new() 14 .timeout(DEFAULT_BODY_TIMEOUT) 15 .connect_timeout(DEFAULT_CONNECT_TIMEOUT) 16 .pool_max_idle_per_host(10) 17 .pool_idle_timeout(Duration::from_secs(90)) 18 .redirect(reqwest::redirect::Policy::none()) 19 .build() 20 .expect("Failed to build HTTP client - this indicates a TLS or system configuration issue") 21 }) 22} 23pub fn is_ssrf_safe(url: &str) -> Result<(), SsrfError> { 24 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?; 25 let scheme = parsed.scheme(); 26 if scheme != "https" { 27 let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok() 28 || url.starts_with("http://127.0.0.1") 29 || url.starts_with("http://localhost"); 30 if !allow_http { 31 return Err(SsrfError::InsecureProtocol(scheme.to_string())); 32 } 33 } 34 let host = parsed.host_str().ok_or(SsrfError::NoHost)?; 35 if host == "localhost" { 36 return Ok(()); 37 } 38 if let Ok(ip) = host.parse::<IpAddr>() { 39 if ip.is_loopback() { 40 return Ok(()); 41 } 42 if !is_unicast_ip(&ip) { 43 return Err(SsrfError::NonUnicastIp(ip.to_string())); 44 } 45 return Ok(()); 46 } 47 let port = parsed.port().unwrap_or(if scheme == "https" { 443 } else { 80 }); 48 let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() { 49 Ok(addrs) => addrs.collect(), 50 Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())), 51 }; 52 for addr in &socket_addrs { 53 if !is_unicast_ip(&addr.ip()) { 54 warn!( 55 "DNS resolution for {} returned non-unicast IP: {}", 56 host, 57 addr.ip() 58 ); 59 return Err(SsrfError::NonUnicastIp(addr.ip().to_string())); 60 } 61 } 62 Ok(()) 63} 64fn is_unicast_ip(ip: &IpAddr) -> bool { 65 match ip { 66 IpAddr::V4(v4) => { 67 !v4.is_loopback() 68 && !v4.is_broadcast() 69 && !v4.is_multicast() 70 && !v4.is_unspecified() 71 && !v4.is_link_local() 72 && !is_private_v4(v4) 73 } 74 IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(), 75 } 76} 77fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool { 78 let octets = ip.octets(); 79 octets[0] == 10 80 || (octets[0] == 172 && (16..=31).contains(&octets[1])) 81 || (octets[0] == 192 && octets[1] == 168) 82 || (octets[0] == 169 && octets[1] == 254) 83} 84#[derive(Debug, Clone)] 85pub enum SsrfError { 86 InvalidUrl, 87 InsecureProtocol(String), 88 NoHost, 89 NonUnicastIp(String), 90 DnsResolutionFailed(String), 91} 92impl std::fmt::Display for SsrfError { 93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 94 match self { 95 SsrfError::InvalidUrl => write!(f, "Invalid URL"), 96 SsrfError::InsecureProtocol(p) => write!(f, "Insecure protocol: {}", p), 97 SsrfError::NoHost => write!(f, "No host in URL"), 98 SsrfError::NonUnicastIp(ip) => write!(f, "Non-unicast IP address: {}", ip), 99 SsrfError::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for: {}", host), 100 } 101 } 102} 103impl std::error::Error for SsrfError {} 104pub const HEADERS_TO_FORWARD: &[&str] = &[ 105 "accept-language", 106 "atproto-accept-labelers", 107 "x-bsky-topics", 108]; 109pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[ 110 "atproto-repo-rev", 111 "atproto-content-labelers", 112 "retry-after", 113 "content-type", 114]; 115pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> { 116 if !uri.starts_with("at://") { 117 return Err("URI must start with at://"); 118 } 119 let path = uri.trim_start_matches("at://"); 120 let parts: Vec<&str> = path.split('/').collect(); 121 if parts.is_empty() { 122 return Err("URI missing DID"); 123 } 124 let did = parts[0]; 125 if !did.starts_with("did:") { 126 return Err("Invalid DID in URI"); 127 } 128 if parts.len() > 1 { 129 let collection = parts[1]; 130 if collection.is_empty() || !collection.contains('.') { 131 return Err("Invalid collection NSID"); 132 } 133 } 134 Ok(AtUriParts { 135 did: did.to_string(), 136 collection: parts.get(1).map(|s| s.to_string()), 137 rkey: parts.get(2).map(|s| s.to_string()), 138 }) 139} 140#[derive(Debug, Clone)] 141pub struct AtUriParts { 142 pub did: String, 143 pub collection: Option<String>, 144 pub rkey: Option<String>, 145} 146pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 { 147 match limit { 148 Some(l) if l == 0 => default, 149 Some(l) if l > max => max, 150 Some(l) => l, 151 None => default, 152 } 153} 154pub fn validate_did(did: &str) -> Result<(), &'static str> { 155 if !did.starts_with("did:") { 156 return Err("Invalid DID format"); 157 } 158 let parts: Vec<&str> = did.split(':').collect(); 159 if parts.len() < 3 { 160 return Err("DID must have at least method and identifier"); 161 } 162 let method = parts[1]; 163 if method != "plc" && method != "web" { 164 return Err("Unsupported DID method"); 165 } 166 Ok(()) 167} 168#[cfg(test)] 169mod tests { 170 use super::*; 171 #[test] 172 fn test_ssrf_safe_https() { 173 assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok()); 174 } 175 #[test] 176 fn test_ssrf_blocks_http_by_default() { 177 let result = is_ssrf_safe("http://external.example.com/xrpc/test"); 178 assert!(matches!(result, Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_)))); 179 } 180 #[test] 181 fn test_ssrf_allows_localhost_http() { 182 assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok()); 183 assert!(is_ssrf_safe("http://localhost:8080/test").is_ok()); 184 } 185 #[test] 186 fn test_validate_at_uri() { 187 let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123"); 188 assert!(result.is_ok()); 189 let parts = result.unwrap(); 190 assert_eq!(parts.did, "did:plc:test"); 191 assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string())); 192 assert_eq!(parts.rkey, Some("abc123".to_string())); 193 } 194 #[test] 195 fn test_validate_at_uri_invalid() { 196 assert!(validate_at_uri("https://example.com").is_err()); 197 assert!(validate_at_uri("at://notadid/collection/rkey").is_err()); 198 } 199 #[test] 200 fn test_validate_limit() { 201 assert_eq!(validate_limit(None, 50, 100), 50); 202 assert_eq!(validate_limit(Some(0), 50, 100), 50); 203 assert_eq!(validate_limit(Some(200), 50, 100), 100); 204 assert_eq!(validate_limit(Some(75), 50, 100), 75); 205 } 206 #[test] 207 fn test_validate_did() { 208 assert!(validate_did("did:plc:abc123").is_ok()); 209 assert!(validate_did("did:web:example.com").is_ok()); 210 assert!(validate_did("notadid").is_err()); 211 assert!(validate_did("did:unknown:test").is_err()); 212 } 213}