//! the actual gateway implementation use std::collections::HashMap; use std::ops::ControlFlow; use std::sync::Arc; use async_trait::async_trait; use cookie_rs::CookieJar; use http::status::StatusCode; use http::{HeaderName, HeaderValue}; use pingora::lb::selection::consistent::KetamaHashing; use pingora::prelude::*; use url::Url; use crate::gateway::oidc::{InProgressAuth, SESSION_COOKIE_NAME, UserInfo}; use crate::httputil::{internal_error, internal_error_from, status_error, status_error_from}; use crate::oauth::auth_code_flow; use crate::{config, cookies, httputil}; pub mod oidc; /// per-domain information about backends and such pub struct DomainInfo { /// the load balancer to use to select backends pub balancer: Arc>, /// whether or not we allow insecure connections from clients pub tls_mode: config::format::domain::TlsMode, /// the sni name of this domain, used to pass to backends pub sni_name: String, /// auth settings for this domain, if any pub oidc: Option, /// headers to mangle for requests on this domain pub headers: config::format::ManageHeaders, } /// the actual gateway logic pub struct AuthGateway { /// all known domains and their corresponding backends & settings pub domains: HashMap, } impl AuthGateway { /// fetch the domain info for this request fn domain_info<'s>(&'s self, session: &Session) -> Result<&'s DomainInfo> { let req = session.req_header(); // TODO(potential-bug): afaict, right now, afaict, pingora a) does not check that SNI matches the `Host` // header, b) does not support extracting the SNI info on rustls, so we'll have to switch // to boringssl and implement that ourselves T_T let host = req .headers .get(http::header::HOST) .ok_or_else(status_error( "no host set", ErrorSource::Downstream, StatusCode::BAD_REQUEST, ))? .to_str() .map_err(|e| { Error::because( ErrorType::HTTPStatus(StatusCode::BAD_REQUEST.into()), "no host", e, ) })?; let info = self.domains.get(host).ok_or_else(status_error( "unknown host", ErrorSource::Downstream, StatusCode::SERVICE_UNAVAILABLE, ))?; Ok(info) } /// mangle general headers, per [`config::format::ManageHeaders`] async fn strip_and_apply_general_headers( &self, session: &mut Session, info: &DomainInfo, is_https: bool, ) -> Result<()> { let remote_addr = session.client_addr().and_then(|addr| match addr { pingora::protocols::l4::socket::SocketAddr::Inet(socket_addr) => { Some(socket_addr.ip().to_string()) } pingora::protocols::l4::socket::SocketAddr::Unix(_) => None, }); let req = session.req_header_mut(); if let Some(header) = &info.headers.host { // TODO(cleanup): preprocess all header names let name = HeaderName::from_bytes(header.as_bytes()) .map_err(internal_error_from("invalid claim-to-header header name"))?; let val = req .headers .get(http::header::HOST) .expect("we had to have this to look up our backend") .clone(); req.headers.insert(name, val); } if let Some(header) = &info.headers.x_forwarded_for && let Some(addr) = &remote_addr { let name = HeaderName::from_bytes(header.as_bytes()) .map_err(internal_error_from("invalid claim-to-header header name"))?; let mut val = req .headers .get("x-forwarded-for") .map(|v| v.as_bytes()) .unwrap_or(b"") .to_owned(); val.extend(b","); val.extend(addr.as_bytes()); let val = HeaderValue::from_bytes(&val) .map_err(internal_error_from("invalid remote-addr header value"))?; req.headers.insert(name, val); } if let Some(header) = &info.headers.x_forwarded_proto { let name = HeaderName::from_bytes(header.as_bytes()) .map_err(internal_error_from("invalid claim-to-header header name"))?; req.headers.insert( name, HeaderValue::from_static(if is_https { "https" } else { "http" }), ); } if let Some(header) = &info.headers.remote_addr && let Some(addr) = &remote_addr { let name = HeaderName::from_bytes(header.as_bytes()) .map_err(internal_error_from("invalid claim-to-header header name"))?; let val = HeaderValue::from_str(addr) .map_err(internal_error_from("invalid remote-addr header value"))?; req.headers.insert(name, val); } Ok(()) } /// check auth, starting the flow if necessary async fn check_auth( &self, session: &mut Session, auth_info: &oidc::Info, ctx: &mut AuthCtx, ) -> Result> { use auth_code_flow::code_request; let req = session.req_header_mut(); let cookies = httputil::cookie_jar(req)?.unwrap_or_default(); let auth_cookie = cookies .get(SESSION_COOKIE_NAME) .map(|c| c.value()) .and_then(|c| cookies::CookieContents::contents(c, &auth_info.cookie_signing_key).ok()); { // auth_info map pin let sessions = auth_info.sessions.pin(); if let Some(valid_session) = auth_cookie .and_then(|c| sessions.get(&c.session_id)) .filter(|sess| sess.expires_at > jiff::Timestamp::now()) { if let Some(claim_map) = &auth_info.config.claims { for (claim, header) in &claim_map.claim_to_header { match valid_session.claims.get(claim) { Some(val) => { let val = HeaderValue::from_bytes(val.as_bytes()) .map_err(internal_error_from("invalid claim value"))?; let name = HeaderName::from_bytes(header.as_bytes()).map_err( internal_error_from("invalid claim-to-header header name"), )?; req.headers.insert(name, val) } None => req.headers.remove(header), }; } } ctx.session_valid = true; return Ok(ControlFlow::Continue(())); } } // otherwise! start the auth flow let meta_cache = auth_info.get_or_cache_metadata().await?; // TODO(cleanup): precompute scopes let redirect_info = code_request::redirect_to_auth_server( (&meta_cache.metadata).into(), code_request::Data::new( &auth_info.config.client_id, &auth_info .config .scopes .as_ref() .map(|s| { s.required .iter() .fold(auth_code_flow::Scopes::base_scopes(), |scopes, scope| { scopes.add_scope(scope) }) }) .unwrap_or(auth_code_flow::Scopes::base_scopes()), // technically this is a spec violate, but it's useful for testing &Url::parse(&format!( "https://{domain}/{path}", domain = auth_info.domain, path = OAUTH_CONTINUE_PATH, )) .map_err(internal_error_from( "unable to construct redirect url from domain", ))?, ), ) .map_err(internal_error_from("unable to construct redirect"))?; if auth_info .auth_states .pin() .try_insert( redirect_info.state, InProgressAuth { code_verifier: redirect_info.code_verifier, original_path: req.uri.path().to_string(), }, ) .is_err() { // this is _extremely_ unlikely to happen, but worth checking anyway return Err(internal_error("state id collision")()); }; httputil::redirect_response(session, redirect_info.url.as_str(), |_, _| Ok(())).await?; Ok(ControlFlow::Break(())) } /// continue auth from inbound redirects, or logout from a logout redirect async fn receive_redirect( &self, session: &mut Session, info: &DomainInfo, ) -> Result> { use auth_code_flow::{code_response, token_request, token_response}; let req = session.req_header(); let Some(pq) = req.uri.path_and_query() else { return Ok(ControlFlow::Continue(())); }; let Some(auth_info) = &info.oidc else { return Ok(ControlFlow::Continue(())); }; if pq.path() == OAUTH_LOGOUT_PATH { let Some(mut cookies) = httputil::cookie_jar(req)? else { // we're not logged in, just return fine httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| Ok(())) .await?; return Ok(ControlFlow::Break(())); }; { let Some(raw) = cookies .get(SESSION_COOKIE_NAME) .map(|raw| raw.value().to_string()) else { // we're not logged in, just return fine httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| { Ok(()) }) .await?; return Ok(ControlFlow::Break(())); }; cookies.remove(SESSION_COOKIE_NAME); let Some(cookies::CookieMessage { session_id }) = cookies::CookieContents::contents(&raw, &auth_info.cookie_signing_key).ok() else { // invalid cookie, just ignore httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| { Ok(()) }) .await?; return Ok(ControlFlow::Break(())); }; auth_info.sessions.pin().remove(&session_id); }; httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| Ok(())) .await?; return Ok(ControlFlow::Break(())); } if let Some(auth_info) = &info.oidc && pq.path().starts_with("/") && &pq.path()[1..] == OAUTH_CONTINUE_PATH { let Some(query) = pq.query() else { session .respond_error(StatusCode::BAD_REQUEST.into()) .await?; return Ok(ControlFlow::Break(())); }; let Some(meta_cache) = auth_info.meta_cache.lock().await.as_ref().cloned() else { // if we don't already have discovery metadata, something's real weird, cause // how did we start the flow session .respond_error(StatusCode::BAD_REQUEST.into()) .await?; return Ok(ControlFlow::Break(())); }; let status = code_response::receive_redirect(query, meta_cache.metadata.issuer.as_str()) .map_err(status_error_from( "unable to deserialize oauth2 response", ErrorSource::Internal, StatusCode::BAD_REQUEST, ))?; let resp = match status { Ok(resp) => resp, Err(err) => { auth_info.auth_states.pin().remove(&err.state); match err.error { code_response::ErrorType::AccessDenied => { session.respond_error(StatusCode::FORBIDDEN.into()).await?; return Ok(ControlFlow::Break(())); } code_response::ErrorType::TemporarilyUnavailable => { session .respond_error(StatusCode::SERVICE_UNAVAILABLE.into()) .await?; return Ok(ControlFlow::Break(())); } _ => { session .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into()) .await?; return Ok(ControlFlow::Break(())); } } } }; let Some(in_progress) = auth_info.auth_states.pin().remove(&resp.state).cloned() else { session .respond_error(StatusCode::BAD_REQUEST.into()) .await?; return Ok(ControlFlow::Break(())); }; let mut body = String::new(); let redirect_uri = &format!( "https://{domain}/{path}", domain = auth_info.domain, path = OAUTH_CONTINUE_PATH, ); let token_req = token_request::request_access_token( (&meta_cache.metadata).into(), token_request::Data { code: resp, client_id: &auth_info.config.client_id, client_secret: &auth_info.client_secret, // this should not be a clone, but it's a weird quirk of our threadsafe // hashmap choice code_verifier: in_progress.code_verifier, redirect_uri, }, &mut body, ) .map_err(internal_error_from("unable produce access token request"))?; let resp: token_response::Valid = { let client = reqwest::Client::new(); let resp = client .post(token_req.url.as_str().to_string()) .header( http::header::CONTENT_TYPE, "application/x-www-form-urlencoded", ) .body(body) .send() .await .map_err(internal_error_from("unable to make token request"))?; if resp.status() == StatusCode::BAD_REQUEST { let _resp: token_response::Error = resp .json() .await .map_err(internal_error_from("unable to deserialize response"))?; session .respond_error(StatusCode::BAD_REQUEST.into()) .await?; return Ok(ControlFlow::Break(())); // error per [the rfc][ref:draft-ietf-oauth-v2-1#3.2.4] } else if resp.status() == StatusCode::NOT_FOUND { // maybe it moved? try fetching the info again later auth_info.clear_metadata_cache().await; } else if !resp.status().is_success() { session .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into()) .await?; return Ok(ControlFlow::Break(())); } resp.json() .await .map_err(internal_error_from("unable to deserialize token response"))? }; use std::str::FromStr as _; let id_token = compact_jwt::JwtUnverified::from_str( &resp .id_token .ok_or_else(internal_error("no id token in response"))?, ) .map_err(internal_error_from("unable to deserialize id token"))?; // will be some if we had the option "verify" turned on, will be none otherwise // (will never be none if the option is on but we couldn't fetch the token) let id_token: compact_jwt::Jwt<()> = match &meta_cache.jws_verifier { Some(verifier) => { use compact_jwt::JwsVerifier as _; verifier .verify(&id_token) .map_err(internal_error_from("unable to verify id token"))? } None => { use compact_jwt::JwsVerifier as _; let verifier = compact_jwt::dangernoverify::JwsDangerReleaseWithoutVerify::default(); verifier .verify(&id_token) .map_err(internal_error_from("unable to deserialize id_token to jwt"))? } }; // per https://openid.net/specs/openid-connect-core-1_0-final.html#TokenResponseValidation, we _must_ to check // - iss (must match the expected issuer) // - aud (must match our client_id) // - exp (must expire in the future) if id_token .iss .is_none_or(|iss| iss != meta_cache.metadata.issuer.as_str()) { return Err(internal_error("issuer mismatch on id token")()); } if id_token .aud .is_none_or(|aud| aud != auth_info.config.client_id) { return Err(internal_error("audience mismatch on id token")()); } let expires_at = jiff::Timestamp::from_second( id_token .exp .ok_or_else(internal_error("missing exp on token"))?, ) .map_err(internal_error_from("unable to parse exp as timestamp"))?; if expires_at < jiff::Timestamp::now() { session .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into()) .await?; return Ok(ControlFlow::Break(())); } let user_info = UserInfo { expires_at, claims: id_token .claims .into_iter() // [`serde_json::Value`] implements display that's just "semi-infallible // serialize" .map(|(k, v)| (k, v.to_string())) .collect(), }; let expiry = user_info.expires_at; let mut rng = rand::rngs::OsRng; use rand::Rng as _; let session_id = rng.r#gen(); auth_info.sessions.pin().insert(session_id, user_info); let cookie = cookies::CookieContents::sign( cookies::CookieMessage { session_id }, &auth_info.cookie_signing_key, ) .map_err(|()| internal_error("unable to sign cookie")())?; let url = format!( "https://{domain}{original_path}", domain = auth_info.domain, original_path = in_progress.original_path ); httputil::redirect_response(session, &url, |resp, session| { let mut cookies = httputil::cookie_jar(session.req_header())?.unwrap_or_default(); cookies.set( cookie_rs::Cookie::builder(SESSION_COOKIE_NAME, cookie) .http_only(true) .secure(true) // utc technically potentially different than gmt, but this is just advisory (we enforce // elsewhere), so it's ok .max_age( std::time::Duration::try_from(expiry - jiff::Timestamp::now()) .expect("formed from timestamps, can't have relative parts"), ) .path("/") .build(), ); cookies .as_header_values() .into_iter() .try_for_each(|cookie| { let val = HeaderValue::from_bytes(cookie.as_bytes()) .map_err(internal_error_from("bad cookie header value"))?; resp.append_header(http::header::SET_COOKIE, val)?; Ok::<_, Box>(()) })?; Ok(()) }) .await?; return Ok(ControlFlow::Break(())); } Ok(ControlFlow::Continue(())) } } pub struct AuthCtx { session_valid: bool, } /// the oauth2 redirect path, without the leading slash const OAUTH_CONTINUE_PATH: &str = ".oauth2/continue"; /// the logout/cookie-clear path, _with_ the leading slash const OAUTH_LOGOUT_PATH: &str = "/.oauth2/logout"; #[async_trait] impl ProxyHttp for AuthGateway { type CTX = AuthCtx; fn new_ctx(&self) -> Self::CTX { AuthCtx { session_valid: false, } } async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result { let info = self.domain_info(session)?; // check if we need to terminate the connection cause someone sent us an http request and we // don't allow that let is_https = session .digest() .and_then(|d| d.ssl_digest.as_ref()) .is_some(); if !is_https { use config::format::domain::TlsMode; match info.tls_mode { TlsMode::Only => { // we should just drop the connection, although people should really just be // using HSTS session.shutdown().await; return Ok(true); } TlsMode::UnsafeAllowHttp => {} } } // next, check if we're in the middle of an oauth flow match self.receive_redirect(session, info).await? { ControlFlow::Continue(()) => {} ControlFlow::Break(()) => return Ok(true), } // finally check our actual auth state, starting the auth flow as needed if let Some(auth_info) = &info.oidc { match self.check_auth(session, auth_info, ctx).await? { ControlFlow::Continue(()) => {} ControlFlow::Break(()) => return Ok(true), } } // we're past auth and are processing as normal, proceed self.strip_and_apply_general_headers(session, info, is_https) .await?; Ok(false) } async fn upstream_peer( &self, session: &mut Session, _ctx: &mut Self::CTX, ) -> Result> { fn client_addr_key(sock_addr: &pingora::protocols::l4::socket::SocketAddr) -> Vec { use pingora::protocols::l4::socket::SocketAddr; match sock_addr { SocketAddr::Inet(socket_addr) => match socket_addr { std::net::SocketAddr::V4(v4) => Vec::from(v4.ip().octets()), std::net::SocketAddr::V6(v6) => Vec::from(v6.ip().octets()), }, _ => unreachable!(), } } let backends = self.domain_info(session)?; let backend = backends .balancer // NB: this means that CGNAT, other proxies, etc will? consistently hit the same // backend, so we might wanna take that into consideration. fine for now, this is // currently for personal use ;-) .select( &client_addr_key(session.client_addr().ok_or_else(status_error( "no client address", ErrorSource::Downstream, StatusCode::BAD_REQUEST, ))?), /* lb on client address */ 256, ) .ok_or_else(status_error( "no available backends", ErrorSource::Upstream, StatusCode::SERVICE_UNAVAILABLE, ))?; let needs_tls = backend .ext .get::() .map(|d| d.tls) .unwrap_or(true); Ok(Box::new(HttpPeer::new( backend, needs_tls, backends.sni_name.to_string(), ))) } async fn response_filter( &self, _session: &mut Session, upstream_response: &mut ResponseHeader, ctx: &mut Self::CTX, ) -> Result<()> where Self::CTX: Send + Sync, { // if we had no valid session, clear the cookie if !ctx.session_valid { let mut cookies = CookieJar::default(); cookies.remove(SESSION_COOKIE_NAME); cookies.as_header_values().into_iter().try_for_each(|v| { let v = HeaderValue::from_bytes(v.as_bytes()) .map_err(internal_error_from("invalid clear cookie header value"))?; upstream_response .headers .append(http::header::SET_COOKIE, v); Ok::<_, Box>(()) })?; } Ok(()) } } /// additional data stored in the load balancer's backend structure /// /// for use in [`AuthGateway::upstream_peer`] #[derive(Clone)] pub struct BackendData { /// does the backend want tls pub tls: bool, }