a (hacky, wip) multi-tenant oidc-terminating reverse proxy, written in anger on top of pingora
1//! the actual gateway implementation
2
3use std::collections::HashMap;
4use std::ops::ControlFlow;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use cookie_rs::CookieJar;
9use http::status::StatusCode;
10use http::{HeaderName, HeaderValue};
11use pingora::lb::selection::consistent::KetamaHashing;
12use pingora::prelude::*;
13use url::Url;
14
15use crate::gateway::oidc::{InProgressAuth, SESSION_COOKIE_NAME, UserInfo};
16use crate::httputil::{internal_error, internal_error_from, status_error, status_error_from};
17use crate::oauth::auth_code_flow;
18use crate::{config, cookies, httputil};
19
20pub mod oidc;
21
/// per-domain information about backends and such
///
/// one of these exists per entry in [`AuthGateway::domains`]; it's looked up from the request's
/// `Host` header on every request.
pub struct DomainInfo {
    /// the load balancer to use to select backends
    ///
    /// backends are selected by hashing the client's ip address
    pub balancer: Arc<LoadBalancer<KetamaHashing>>,
    /// whether or not we allow insecure connections from clients
    ///
    /// non-tls requests are dropped unless this is `UnsafeAllowHttp`
    pub tls_mode: config::format::domain::TlsMode,
    /// the sni name of this domain, used to pass to backends
    pub sni_name: String,
    /// auth settings for this domain, if any
    ///
    /// when `None`, requests for this domain skip the auth check entirely
    pub oidc: Option<oidc::Info>,
    /// headers to mangle for requests on this domain
    pub headers: config::format::ManageHeaders,
}
35
/// the actual gateway logic
///
/// implements [`ProxyHttp`]: each request is matched to a [`DomainInfo`], run through the
/// oidc redirect/auth checks (when the domain has them configured), and then proxied to one of
/// the domain's backends.
pub struct AuthGateway {
    /// all known domains and their corresponding backends & settings
    ///
    /// keyed by the exact value of the `Host` header that clients send
    pub domains: HashMap<String, DomainInfo>,
}
41
42impl AuthGateway {
43 /// fetch the domain info for this request
44 fn domain_info<'s>(&'s self, session: &Session) -> Result<&'s DomainInfo> {
45 let req = session.req_header();
46 // TODO(potential-bug): afaict, right now, afaict, pingora a) does not check that SNI matches the `Host`
47 // header, b) does not support extracting the SNI info on rustls, so we'll have to switch
48 // to boringssl and implement that ourselves T_T
49 let host = req
50 .headers
51 .get(http::header::HOST)
52 .ok_or_else(status_error(
53 "no host set",
54 ErrorSource::Downstream,
55 StatusCode::BAD_REQUEST,
56 ))?
57 .to_str()
58 .map_err(|e| {
59 Error::because(
60 ErrorType::HTTPStatus(StatusCode::BAD_REQUEST.into()),
61 "no host",
62 e,
63 )
64 })?;
65 let info = self.domains.get(host).ok_or_else(status_error(
66 "unknown host",
67 ErrorSource::Downstream,
68 StatusCode::SERVICE_UNAVAILABLE,
69 ))?;
70
71 Ok(info)
72 }
73
74 /// mangle general headers, per [`config::format::ManageHeaders`]
75 async fn strip_and_apply_general_headers(
76 &self,
77 session: &mut Session,
78 info: &DomainInfo,
79 is_https: bool,
80 ) -> Result<()> {
81 let remote_addr = session.client_addr().and_then(|addr| match addr {
82 pingora::protocols::l4::socket::SocketAddr::Inet(socket_addr) => {
83 Some(socket_addr.ip().to_string())
84 }
85 pingora::protocols::l4::socket::SocketAddr::Unix(_) => None,
86 });
87 let req = session.req_header_mut();
88 if let Some(header) = &info.headers.host {
89 // TODO(cleanup): preprocess all header names
90 let name = HeaderName::from_bytes(header.as_bytes())
91 .map_err(internal_error_from("invalid claim-to-header header name"))?;
92 let val = req
93 .headers
94 .get(http::header::HOST)
95 .expect("we had to have this to look up our backend")
96 .clone();
97 req.headers.insert(name, val);
98 }
99 if let Some(header) = &info.headers.x_forwarded_for
100 && let Some(addr) = &remote_addr
101 {
102 let name = HeaderName::from_bytes(header.as_bytes())
103 .map_err(internal_error_from("invalid claim-to-header header name"))?;
104 let mut val = req
105 .headers
106 .get("x-forwarded-for")
107 .map(|v| v.as_bytes())
108 .unwrap_or(b"")
109 .to_owned();
110 val.extend(b",");
111 val.extend(addr.as_bytes());
112 let val = HeaderValue::from_bytes(&val)
113 .map_err(internal_error_from("invalid remote-addr header value"))?;
114 req.headers.insert(name, val);
115 }
116 if let Some(header) = &info.headers.x_forwarded_proto {
117 let name = HeaderName::from_bytes(header.as_bytes())
118 .map_err(internal_error_from("invalid claim-to-header header name"))?;
119 req.headers.insert(
120 name,
121 HeaderValue::from_static(if is_https { "https" } else { "http" }),
122 );
123 }
124 if let Some(header) = &info.headers.remote_addr
125 && let Some(addr) = &remote_addr
126 {
127 let name = HeaderName::from_bytes(header.as_bytes())
128 .map_err(internal_error_from("invalid claim-to-header header name"))?;
129 let val = HeaderValue::from_str(addr)
130 .map_err(internal_error_from("invalid remote-addr header value"))?;
131 req.headers.insert(name, val);
132 }
133
134 Ok(())
135 }
136
137 /// check auth, starting the flow if necessary
138 async fn check_auth(
139 &self,
140 session: &mut Session,
141 auth_info: &oidc::Info,
142 ctx: &mut AuthCtx,
143 ) -> Result<ControlFlow<()>> {
144 use auth_code_flow::code_request;
145
146 let req = session.req_header_mut();
147 let cookies = httputil::cookie_jar(req)?.unwrap_or_default();
148
149 let auth_cookie = cookies
150 .get(SESSION_COOKIE_NAME)
151 .map(|c| c.value())
152 .and_then(|c| cookies::CookieContents::contents(c, &auth_info.cookie_signing_key).ok());
153 {
154 // auth_info map pin
155 let sessions = auth_info.sessions.pin();
156 if let Some(valid_session) = auth_cookie
157 .and_then(|c| sessions.get(&c.session_id))
158 .filter(|sess| sess.expires_at > jiff::Timestamp::now())
159 {
160 if let Some(claim_map) = &auth_info.config.claims {
161 for (claim, header) in &claim_map.claim_to_header {
162 match valid_session.claims.get(claim) {
163 Some(val) => {
164 let val = HeaderValue::from_bytes(val.as_bytes())
165 .map_err(internal_error_from("invalid claim value"))?;
166 let name = HeaderName::from_bytes(header.as_bytes()).map_err(
167 internal_error_from("invalid claim-to-header header name"),
168 )?;
169 req.headers.insert(name, val)
170 }
171 None => req.headers.remove(header),
172 };
173 }
174 }
175
176 ctx.session_valid = true;
177
178 return Ok(ControlFlow::Continue(()));
179 }
180 }
181
182 // otherwise! start the auth flow
183 let meta_cache = auth_info.get_or_cache_metadata().await?;
184
185 // TODO(cleanup): precompute scopes
186 let redirect_info = code_request::redirect_to_auth_server(
187 (&meta_cache.metadata).into(),
188 code_request::Data::new(
189 &auth_info.config.client_id,
190 &auth_info
191 .config
192 .scopes
193 .as_ref()
194 .map(|s| {
195 s.required
196 .iter()
197 .fold(auth_code_flow::Scopes::base_scopes(), |scopes, scope| {
198 scopes.add_scope(scope)
199 })
200 })
201 .unwrap_or(auth_code_flow::Scopes::base_scopes()),
202 // technically this is a spec violate, but it's useful for testing
203 &Url::parse(&format!(
204 "https://{domain}/{path}",
205 domain = auth_info.domain,
206 path = OAUTH_CONTINUE_PATH,
207 ))
208 .map_err(internal_error_from(
209 "unable to construct redirect url from domain",
210 ))?,
211 ),
212 )
213 .map_err(internal_error_from("unable to construct redirect"))?;
214
215 if auth_info
216 .auth_states
217 .pin()
218 .try_insert(
219 redirect_info.state,
220 InProgressAuth {
221 code_verifier: redirect_info.code_verifier,
222 original_path: req.uri.path().to_string(),
223 },
224 )
225 .is_err()
226 {
227 // this is _extremely_ unlikely to happen, but worth checking anyway
228 return Err(internal_error("state id collision")());
229 };
230
231 httputil::redirect_response(session, redirect_info.url.as_str(), |_, _| Ok(())).await?;
232 Ok(ControlFlow::Break(()))
233 }
234
235 /// continue auth from inbound redirects, or logout from a logout redirect
236 async fn receive_redirect(
237 &self,
238 session: &mut Session,
239 info: &DomainInfo,
240 ) -> Result<ControlFlow<()>> {
241 use auth_code_flow::{code_response, token_request, token_response};
242
243 let req = session.req_header();
244 let Some(pq) = req.uri.path_and_query() else {
245 return Ok(ControlFlow::Continue(()));
246 };
247 let Some(auth_info) = &info.oidc else {
248 return Ok(ControlFlow::Continue(()));
249 };
250
251 if pq.path() == OAUTH_LOGOUT_PATH {
252 let Some(mut cookies) = httputil::cookie_jar(req)? else {
253 // we're not logged in, just return fine
254 httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| Ok(()))
255 .await?;
256 return Ok(ControlFlow::Break(()));
257 };
258
259 {
260 let Some(raw) = cookies
261 .get(SESSION_COOKIE_NAME)
262 .map(|raw| raw.value().to_string())
263 else {
264 // we're not logged in, just return fine
265 httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| {
266 Ok(())
267 })
268 .await?;
269 return Ok(ControlFlow::Break(()));
270 };
271 cookies.remove(SESSION_COOKIE_NAME);
272
273 let Some(cookies::CookieMessage { session_id }) =
274 cookies::CookieContents::contents(&raw, &auth_info.cookie_signing_key).ok()
275 else {
276 // invalid cookie, just ignore
277 httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| {
278 Ok(())
279 })
280 .await?;
281 return Ok(ControlFlow::Break(()));
282 };
283 auth_info.sessions.pin().remove(&session_id);
284 };
285
286 httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| Ok(()))
287 .await?;
288 return Ok(ControlFlow::Break(()));
289 }
290
291 if let Some(auth_info) = &info.oidc
292 && pq.path().starts_with("/")
293 && &pq.path()[1..] == OAUTH_CONTINUE_PATH
294 {
295 let Some(query) = pq.query() else {
296 session
297 .respond_error(StatusCode::BAD_REQUEST.into())
298 .await?;
299 return Ok(ControlFlow::Break(()));
300 };
301
302 let Some(meta_cache) = auth_info.meta_cache.lock().await.as_ref().cloned() else {
303 // if we don't already have discovery metadata, something's real weird, cause
304 // how did we start the flow
305 session
306 .respond_error(StatusCode::BAD_REQUEST.into())
307 .await?;
308 return Ok(ControlFlow::Break(()));
309 };
310
311 let status =
312 code_response::receive_redirect(query, meta_cache.metadata.issuer.as_str())
313 .map_err(status_error_from(
314 "unable to deserialize oauth2 response",
315 ErrorSource::Internal,
316 StatusCode::BAD_REQUEST,
317 ))?;
318 let resp = match status {
319 Ok(resp) => resp,
320 Err(err) => {
321 auth_info.auth_states.pin().remove(&err.state);
322 match err.error {
323 code_response::ErrorType::AccessDenied => {
324 session.respond_error(StatusCode::FORBIDDEN.into()).await?;
325 return Ok(ControlFlow::Break(()));
326 }
327 code_response::ErrorType::TemporarilyUnavailable => {
328 session
329 .respond_error(StatusCode::SERVICE_UNAVAILABLE.into())
330 .await?;
331 return Ok(ControlFlow::Break(()));
332 }
333 _ => {
334 session
335 .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into())
336 .await?;
337 return Ok(ControlFlow::Break(()));
338 }
339 }
340 }
341 };
342
343 let Some(in_progress) = auth_info.auth_states.pin().remove(&resp.state).cloned() else {
344 session
345 .respond_error(StatusCode::BAD_REQUEST.into())
346 .await?;
347 return Ok(ControlFlow::Break(()));
348 };
349
350 let mut body = String::new();
351 let redirect_uri = &format!(
352 "https://{domain}/{path}",
353 domain = auth_info.domain,
354 path = OAUTH_CONTINUE_PATH,
355 );
356 let token_req = token_request::request_access_token(
357 (&meta_cache.metadata).into(),
358 token_request::Data {
359 code: resp,
360 client_id: &auth_info.config.client_id,
361 client_secret: &auth_info.client_secret,
362 // this should not be a clone, but it's a weird quirk of our threadsafe
363 // hashmap choice
364 code_verifier: in_progress.code_verifier,
365 redirect_uri,
366 },
367 &mut body,
368 )
369 .map_err(internal_error_from("unable produce access token request"))?;
370
371 let resp: token_response::Valid = {
372 let client = reqwest::Client::new();
373 let resp = client
374 .post(token_req.url.as_str().to_string())
375 .header(
376 http::header::CONTENT_TYPE,
377 "application/x-www-form-urlencoded",
378 )
379 .body(body)
380 .send()
381 .await
382 .map_err(internal_error_from("unable to make token request"))?;
383 if resp.status() == StatusCode::BAD_REQUEST {
384 let _resp: token_response::Error = resp
385 .json()
386 .await
387 .map_err(internal_error_from("unable to deserialize response"))?;
388 session
389 .respond_error(StatusCode::BAD_REQUEST.into())
390 .await?;
391 return Ok(ControlFlow::Break(()));
392 // error per [the rfc][ref:draft-ietf-oauth-v2-1#3.2.4]
393 } else if resp.status() == StatusCode::NOT_FOUND {
394 // maybe it moved? try fetching the info again later
395 auth_info.clear_metadata_cache().await;
396 } else if !resp.status().is_success() {
397 session
398 .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into())
399 .await?;
400 return Ok(ControlFlow::Break(()));
401 }
402 resp.json()
403 .await
404 .map_err(internal_error_from("unable to deserialize token response"))?
405 };
406
407 use std::str::FromStr as _;
408 let id_token = compact_jwt::JwtUnverified::from_str(
409 &resp
410 .id_token
411 .ok_or_else(internal_error("no id token in response"))?,
412 )
413 .map_err(internal_error_from("unable to deserialize id token"))?;
414
415 // will be some if we had the option "verify" turned on, will be none otherwise
416 // (will never be none if the option is on but we couldn't fetch the token)
417 let id_token: compact_jwt::Jwt<()> = match &meta_cache.jws_verifier {
418 Some(verifier) => {
419 use compact_jwt::JwsVerifier as _;
420 verifier
421 .verify(&id_token)
422 .map_err(internal_error_from("unable to verify id token"))?
423 }
424 None => {
425 use compact_jwt::JwsVerifier as _;
426 let verifier =
427 compact_jwt::dangernoverify::JwsDangerReleaseWithoutVerify::default();
428 verifier
429 .verify(&id_token)
430 .map_err(internal_error_from("unable to deserialize id_token to jwt"))?
431 }
432 };
433
434 // per https://openid.net/specs/openid-connect-core-1_0-final.html#TokenResponseValidation, we _must_ to check
435 // - iss (must match the expected issuer)
436 // - aud (must match our client_id)
437 // - exp (must expire in the future)
438 if id_token
439 .iss
440 .is_none_or(|iss| iss != meta_cache.metadata.issuer.as_str())
441 {
442 return Err(internal_error("issuer mismatch on id token")());
443 }
444 if id_token
445 .aud
446 .is_none_or(|aud| aud != auth_info.config.client_id)
447 {
448 return Err(internal_error("audience mismatch on id token")());
449 }
450 let expires_at = jiff::Timestamp::from_second(
451 id_token
452 .exp
453 .ok_or_else(internal_error("missing exp on token"))?,
454 )
455 .map_err(internal_error_from("unable to parse exp as timestamp"))?;
456 if expires_at < jiff::Timestamp::now() {
457 session
458 .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into())
459 .await?;
460 return Ok(ControlFlow::Break(()));
461 }
462
463 let user_info = UserInfo {
464 expires_at,
465 claims: id_token
466 .claims
467 .into_iter()
468 // [`serde_json::Value`] implements display that's just "semi-infallible
469 // serialize"
470 .map(|(k, v)| (k, v.to_string()))
471 .collect(),
472 };
473 let expiry = user_info.expires_at;
474 let mut rng = rand::rngs::OsRng;
475 use rand::Rng as _;
476 let session_id = rng.r#gen();
477 auth_info.sessions.pin().insert(session_id, user_info);
478 let cookie = cookies::CookieContents::sign(
479 cookies::CookieMessage { session_id },
480 &auth_info.cookie_signing_key,
481 )
482 .map_err(|()| internal_error("unable to sign cookie")())?;
483
484 let url = format!(
485 "https://{domain}{original_path}",
486 domain = auth_info.domain,
487 original_path = in_progress.original_path
488 );
489 httputil::redirect_response(session, &url, |resp, session| {
490 let mut cookies = httputil::cookie_jar(session.req_header())?.unwrap_or_default();
491 cookies.set(
492 cookie_rs::Cookie::builder(SESSION_COOKIE_NAME, cookie)
493 .http_only(true)
494 .secure(true)
495 // utc technically potentially different than gmt, but this is just advisory (we enforce
496 // elsewhere), so it's ok
497 .max_age(
498 std::time::Duration::try_from(expiry - jiff::Timestamp::now())
499 .expect("formed from timestamps, can't have relative parts"),
500 )
501 .path("/")
502 .build(),
503 );
504 cookies
505 .as_header_values()
506 .into_iter()
507 .try_for_each(|cookie| {
508 let val = HeaderValue::from_bytes(cookie.as_bytes())
509 .map_err(internal_error_from("bad cookie header value"))?;
510
511 resp.append_header(http::header::SET_COOKIE, val)?;
512 Ok::<_, Box<Error>>(())
513 })?;
514 Ok(())
515 })
516 .await?;
517 return Ok(ControlFlow::Break(()));
518 }
519
520 Ok(ControlFlow::Continue(()))
521 }
522}
523
/// per-request state shared between the proxy hooks
#[derive(Debug)]
pub struct AuthCtx {
    /// whether the auth check found a live session for this request
    ///
    /// starts out `false`; when still `false` at response time, the response gets a header
    /// clearing the session cookie
    session_valid: bool,
}
527
/// the oauth2 redirect path, without the leading slash
///
/// the slash is left off because this gets formatted into the redirect uri as
/// `https://{domain}/{path}`
const OAUTH_CONTINUE_PATH: &str = ".oauth2/continue";
/// the logout/cookie-clear path, _with_ the leading slash
///
/// the slash is kept because this is compared directly against the request path
const OAUTH_LOGOUT_PATH: &str = "/.oauth2/logout";
532
533#[async_trait]
534impl ProxyHttp for AuthGateway {
    /// the per-request context threaded through the hooks below
    type CTX = AuthCtx;
    /// make a fresh context for a new request; a session is considered invalid until the auth
    /// check says otherwise
    fn new_ctx(&self) -> Self::CTX {
        AuthCtx {
            session_valid: false,
        }
    }
541
542 async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
543 let info = self.domain_info(session)?;
544
545 // check if we need to terminate the connection cause someone sent us an http request and we
546 // don't allow that
547 let is_https = session
548 .digest()
549 .and_then(|d| d.ssl_digest.as_ref())
550 .is_some();
551 if !is_https {
552 use config::format::domain::TlsMode;
553 match info.tls_mode {
554 TlsMode::Only => {
555 // we should just drop the connection, although people should really just be
556 // using HSTS
557 session.shutdown().await;
558 return Ok(true);
559 }
560 TlsMode::UnsafeAllowHttp => {}
561 }
562 }
563
564 // next, check if we're in the middle of an oauth flow
565 match self.receive_redirect(session, info).await? {
566 ControlFlow::Continue(()) => {}
567 ControlFlow::Break(()) => return Ok(true),
568 }
569
570 // finally check our actual auth state, starting the auth flow as needed
571 if let Some(auth_info) = &info.oidc {
572 match self.check_auth(session, auth_info, ctx).await? {
573 ControlFlow::Continue(()) => {}
574 ControlFlow::Break(()) => return Ok(true),
575 }
576 }
577
578 // we're past auth and are processing as normal, proceed
579 self.strip_and_apply_general_headers(session, info, is_https)
580 .await?;
581
582 Ok(false)
583 }
584
585 async fn upstream_peer(
586 &self,
587 session: &mut Session,
588 _ctx: &mut Self::CTX,
589 ) -> Result<Box<HttpPeer>> {
590 fn client_addr_key(sock_addr: &pingora::protocols::l4::socket::SocketAddr) -> Vec<u8> {
591 use pingora::protocols::l4::socket::SocketAddr;
592 match sock_addr {
593 SocketAddr::Inet(socket_addr) => match socket_addr {
594 std::net::SocketAddr::V4(v4) => Vec::from(v4.ip().octets()),
595 std::net::SocketAddr::V6(v6) => Vec::from(v6.ip().octets()),
596 },
597 _ => unreachable!(),
598 }
599 }
600
601 let backends = self.domain_info(session)?;
602 let backend = backends
603 .balancer
604 // NB: this means that CGNAT, other proxies, etc will? consistently hit the same
605 // backend, so we might wanna take that into consideration. fine for now, this is
606 // currently for personal use ;-)
607 .select(
608 &client_addr_key(session.client_addr().ok_or_else(status_error(
609 "no client address",
610 ErrorSource::Downstream,
611 StatusCode::BAD_REQUEST,
612 ))?), /* lb on client address */
613 256,
614 )
615 .ok_or_else(status_error(
616 "no available backends",
617 ErrorSource::Upstream,
618 StatusCode::SERVICE_UNAVAILABLE,
619 ))?;
620
621 let needs_tls = backend
622 .ext
623 .get::<BackendData>()
624 .map(|d| d.tls)
625 .unwrap_or(true);
626
627 Ok(Box::new(HttpPeer::new(
628 backend,
629 needs_tls,
630 backends.sni_name.to_string(),
631 )))
632 }
633
634 async fn response_filter(
635 &self,
636 _session: &mut Session,
637 upstream_response: &mut ResponseHeader,
638 ctx: &mut Self::CTX,
639 ) -> Result<()>
640 where
641 Self::CTX: Send + Sync,
642 {
643 // if we had no valid session, clear the cookie
644 if !ctx.session_valid {
645 let mut cookies = CookieJar::default();
646 cookies.remove(SESSION_COOKIE_NAME);
647
648 cookies.as_header_values().into_iter().try_for_each(|v| {
649 let v = HeaderValue::from_bytes(v.as_bytes())
650 .map_err(internal_error_from("invalid clear cookie header value"))?;
651
652 upstream_response
653 .headers
654 .append(http::header::SET_COOKIE, v);
655 Ok::<_, Box<Error>>(())
656 })?;
657 }
658 Ok(())
659 }
660}
661
/// additional data stored in the load balancer's backend structure
///
/// for use in [`AuthGateway::upstream_peer`], which reads it back out of the selected backend's
/// extensions; backends without it are treated as wanting tls
#[derive(Debug, Clone)]
pub struct BackendData {
    /// does the backend want tls
    pub tls: bool,
}