A (hacky, WIP) multi-tenant, OIDC-terminating reverse proxy, written in anger on top of Pingora.
1use std::collections::HashMap;
2
3use color_eyre::eyre::Context as _;
4use pingora::lb;
5use pingora::lb::selection::consistent::KetamaHashing;
6use pingora::prelude::*;
7
8use self::gateway::{AuthGateway, BackendData, DomainInfo, oidc};
9
10mod config;
11mod cookies;
12mod gateway;
13mod httputil;
14mod oauth;
15
16/// constructed load balancer, with [backend info][`BackendInfo`] to be passed to [`AuthGateway`]
17type BalancerInfo = (
18 Vec<pingora::services::background::GenBackgroundService<LoadBalancer<KetamaHashing>>>,
19 HashMap<String, DomainInfo>,
20);
21
22/// construct the load balancer and initialize the [backend info][`BackendInfo`] for the
23/// [`AuthGateway`]
24fn balancer(domains: &HashMap<String, config::format::Domain>) -> color_eyre::Result<BalancerInfo> {
25 use lb::{self, Backend, discovery};
26 use pingora::protocols::l4::socket::SocketAddr;
27
28 let mut balancers = HashMap::with_capacity(domains.len());
29 let mut svcs = Vec::with_capacity(domains.len());
30 for (name, domain) in domains {
31 let backends = domain
32 .https
33 .iter()
34 .map(|backend| {
35 let mut ext = lb::Extensions::new();
36 ext.insert(BackendData { tls: true });
37 let mut backend = Backend::new_with_weight(
38 &backend.addr,
39 backend.weight.map(|w| w as usize).unwrap_or(1),
40 )
41 .context("parsing addr of https socket backend")?;
42 backend.ext = ext;
43 Ok(backend)
44 })
45 .chain(domain.http.iter().map(|backend| {
46 let mut ext = lb::Extensions::new();
47 ext.insert(BackendData { tls: false });
48 let mut backend = Backend::new_with_weight(
49 &backend.addr,
50 backend.weight.map(|w| w as usize).unwrap_or(1),
51 )
52 .context("parsing addr of http socket backend")?;
53 backend.ext = ext;
54 Ok(backend)
55 }))
56 .chain(domain.uds.iter().map(|backend| {
57 let mut ext = lb::Extensions::new();
58 ext.insert(BackendData { tls: false });
59 Ok(Backend {
60 addr: SocketAddr::Unix(
61 std::os::unix::net::SocketAddr::from_pathname(&backend.path)
62 .context("turning uds path into socketaddr")?,
63 ),
64 weight: backend.weight.map(|w| w as usize).unwrap_or(1),
65 ext,
66 })
67 }))
68 .collect::<color_eyre::Result<_>>()
69 .context("constucting backends for domain")?;
70 let backends = lb::Backends::new(discovery::Static::new(backends));
71 let balancer = LoadBalancer::from_backends(backends);
72 let svc = background_service("health checking", balancer);
73
74 let info = DomainInfo {
75 balancer: svc.task(),
76 tls_mode: config::format::domain::TlsMode::try_from(domain.tls_mode)
77 .context("invalid tls mode")?,
78 sni_name: name.clone(),
79 oidc: domain
80 .oidc_auth
81 .clone()
82 .map(|config| oidc::Info::from_config(config, name.clone()))
83 .transpose()?,
84 headers: domain.manage_headers.clone().unwrap_or_default(),
85 };
86
87 balancers.insert(name.clone(), info);
88 svcs.push(svc)
89 }
90
91 Ok((svcs, balancers))
92}
93
94fn main() -> color_eyre::Result<()> {
95 use color_eyre::eyre::eyre;
96
97 tracing_subscriber::fmt().init();
98 color_eyre::install()?;
99 rustls::crypto::aws_lc_rs::default_provider()
100 .install_default()
101 .expect("unable to install crypto provider");
102
103 let opts = Opt::parse_args();
104
105 let config = config::load(opts.conf.as_ref().ok_or_else(|| {
106 eyre!("no config file specified, refusing to do anything (try `-c FILE`?)")
107 })?)?;
108
109 let pingora_config = match config.pingora.as_ref().map(Into::into) {
110 Some(conf) => conf,
111 None => pingora::server::configuration::ServerConf::new_with_opt_override(&opts)
112 .ok_or_else(|| {
113 eyre!("could not create a base pingora config, and none was specified")
114 })?,
115 };
116 let mut server = Server::new_with_opt_and_conf(opts, pingora_config);
117 server.bootstrap();
118
119 let (balancer_svcs, balancers) =
120 balancer(&config.domains).context("setting up load balancing")?;
121 let mut gateway = http_proxy_service(&server.configuration, AuthGateway { domains: balancers });
122 for binding in config.bind_to_tcp {
123 match binding.tls {
124 Some(tls) => gateway
125 .add_tls(&binding.addr, &tls.cert_path, &tls.key_path)
126 .context("setting up tls")?,
127 None => gateway.add_tcp(&binding.addr),
128 }
129 }
130
131 balancer_svcs
132 .into_iter()
133 .for_each(|svc| server.add_service(svc));
134 server.add_service(gateway);
135
136 server.run_forever();
137}