use std::collections::HashMap; use color_eyre::eyre::Context as _; use pingora::lb; use pingora::lb::selection::consistent::KetamaHashing; use pingora::prelude::*; use self::gateway::{AuthGateway, BackendData, DomainInfo, oidc}; mod config; mod cookies; mod gateway; mod httputil; mod oauth; /// constructed load balancer, with [backend info][`BackendInfo`] to be passed to [`AuthGateway`] type BalancerInfo = ( Vec>>, HashMap, ); /// construct the load balancer and initialize the [backend info][`BackendInfo`] for the /// [`AuthGateway`] fn balancer(domains: &HashMap) -> color_eyre::Result { use lb::{self, Backend, discovery}; use pingora::protocols::l4::socket::SocketAddr; let mut balancers = HashMap::with_capacity(domains.len()); let mut svcs = Vec::with_capacity(domains.len()); for (name, domain) in domains { let backends = domain .https .iter() .map(|backend| { let mut ext = lb::Extensions::new(); ext.insert(BackendData { tls: true }); let mut backend = Backend::new_with_weight( &backend.addr, backend.weight.map(|w| w as usize).unwrap_or(1), ) .context("parsing addr of https socket backend")?; backend.ext = ext; Ok(backend) }) .chain(domain.http.iter().map(|backend| { let mut ext = lb::Extensions::new(); ext.insert(BackendData { tls: false }); let mut backend = Backend::new_with_weight( &backend.addr, backend.weight.map(|w| w as usize).unwrap_or(1), ) .context("parsing addr of http socket backend")?; backend.ext = ext; Ok(backend) })) .chain(domain.uds.iter().map(|backend| { let mut ext = lb::Extensions::new(); ext.insert(BackendData { tls: false }); Ok(Backend { addr: SocketAddr::Unix( std::os::unix::net::SocketAddr::from_pathname(&backend.path) .context("turning uds path into socketaddr")?, ), weight: backend.weight.map(|w| w as usize).unwrap_or(1), ext, }) })) .collect::>() .context("constucting backends for domain")?; let backends = lb::Backends::new(discovery::Static::new(backends)); let balancer = LoadBalancer::from_backends(backends); let svc = background_service("health checking", balancer); let info = DomainInfo { balancer: svc.task(), tls_mode: config::format::domain::TlsMode::try_from(domain.tls_mode) .context("invalid tls mode")?, sni_name: name.clone(), oidc: domain .oidc_auth .clone() .map(|config| oidc::Info::from_config(config, name.clone())) .transpose()?, headers: domain.manage_headers.clone().unwrap_or_default(), }; balancers.insert(name.clone(), info); svcs.push(svc) } Ok((svcs, balancers)) } fn main() -> color_eyre::Result<()> { use color_eyre::eyre::eyre; tracing_subscriber::fmt().init(); color_eyre::install()?; rustls::crypto::aws_lc_rs::default_provider() .install_default() .expect("unable to install crypto provider"); let opts = Opt::parse_args(); let config = config::load(opts.conf.as_ref().ok_or_else(|| { eyre!("no config file specified, refusing to do anything (try `-c FILE`?)") })?)?; let pingora_config = match config.pingora.as_ref().map(Into::into) { Some(conf) => conf, None => pingora::server::configuration::ServerConf::new_with_opt_override(&opts) .ok_or_else(|| { eyre!("could not create a base pingora config, and none was specified") })?, }; let mut server = Server::new_with_opt_and_conf(opts, pingora_config); server.bootstrap(); let (balancer_svcs, balancers) = balancer(&config.domains).context("setting up load balancing")?; let mut gateway = http_proxy_service(&server.configuration, AuthGateway { domains: balancers }); for binding in config.bind_to_tcp { match binding.tls { Some(tls) => gateway .add_tls(&binding.addr, &tls.cert_path, &tls.key_path) .context("setting up tls")?, None => gateway.add_tcp(&binding.addr), } } balancer_svcs .into_iter() .for_each(|svc| server.add_service(svc)); server.add_service(gateway); server.run_forever(); }