a (hacky, wip) multi-tenant oidc-terminating reverse proxy, written in anger on top of pingora
at wip/primary 669 lines 26 kB view raw
1//! the actual gateway implementation 2 3use std::collections::HashMap; 4use std::ops::ControlFlow; 5use std::sync::Arc; 6 7use async_trait::async_trait; 8use cookie_rs::CookieJar; 9use http::status::StatusCode; 10use http::{HeaderName, HeaderValue}; 11use pingora::lb::selection::consistent::KetamaHashing; 12use pingora::prelude::*; 13use url::Url; 14 15use crate::gateway::oidc::{InProgressAuth, SESSION_COOKIE_NAME, UserInfo}; 16use crate::httputil::{internal_error, internal_error_from, status_error, status_error_from}; 17use crate::oauth::auth_code_flow; 18use crate::{config, cookies, httputil}; 19 20pub mod oidc; 21 22/// per-domain information about backends and such 23pub struct DomainInfo { 24 /// the load balancer to use to select backends 25 pub balancer: Arc<LoadBalancer<KetamaHashing>>, 26 /// whether or not we allow insecure connections from clients 27 pub tls_mode: config::format::domain::TlsMode, 28 /// the sni name of this domain, used to pass to backends 29 pub sni_name: String, 30 /// auth settings for this domain, if any 31 pub oidc: Option<oidc::Info>, 32 /// headers to mangle for requests on this domain 33 pub headers: config::format::ManageHeaders, 34} 35 36/// the actual gateway logic 37pub struct AuthGateway { 38 /// all known domains and their corresponding backends & settings 39 pub domains: HashMap<String, DomainInfo>, 40} 41 42impl AuthGateway { 43 /// fetch the domain info for this request 44 fn domain_info<'s>(&'s self, session: &Session) -> Result<&'s DomainInfo> { 45 let req = session.req_header(); 46 // TODO(potential-bug): afaict, right now, afaict, pingora a) does not check that SNI matches the `Host` 47 // header, b) does not support extracting the SNI info on rustls, so we'll have to switch 48 // to boringssl and implement that ourselves T_T 49 let host = req 50 .headers 51 .get(http::header::HOST) 52 .ok_or_else(status_error( 53 "no host set", 54 ErrorSource::Downstream, 55 StatusCode::BAD_REQUEST, 56 ))? 57 .to_str() 58 .map_err(|e| { 59 Error::because( 60 ErrorType::HTTPStatus(StatusCode::BAD_REQUEST.into()), 61 "no host", 62 e, 63 ) 64 })?; 65 let info = self.domains.get(host).ok_or_else(status_error( 66 "unknown host", 67 ErrorSource::Downstream, 68 StatusCode::SERVICE_UNAVAILABLE, 69 ))?; 70 71 Ok(info) 72 } 73 74 /// mangle general headers, per [`config::format::ManageHeaders`] 75 async fn strip_and_apply_general_headers( 76 &self, 77 session: &mut Session, 78 info: &DomainInfo, 79 is_https: bool, 80 ) -> Result<()> { 81 let remote_addr = session.client_addr().and_then(|addr| match addr { 82 pingora::protocols::l4::socket::SocketAddr::Inet(socket_addr) => { 83 Some(socket_addr.ip().to_string()) 84 } 85 pingora::protocols::l4::socket::SocketAddr::Unix(_) => None, 86 }); 87 let req = session.req_header_mut(); 88 if let Some(header) = &info.headers.host { 89 // TODO(cleanup): preprocess all header names 90 let name = HeaderName::from_bytes(header.as_bytes()) 91 .map_err(internal_error_from("invalid claim-to-header header name"))?; 92 let val = req 93 .headers 94 .get(http::header::HOST) 95 .expect("we had to have this to look up our backend") 96 .clone(); 97 req.headers.insert(name, val); 98 } 99 if let Some(header) = &info.headers.x_forwarded_for 100 && let Some(addr) = &remote_addr 101 { 102 let name = HeaderName::from_bytes(header.as_bytes()) 103 .map_err(internal_error_from("invalid claim-to-header header name"))?; 104 let mut val = req 105 .headers 106 .get("x-forwarded-for") 107 .map(|v| v.as_bytes()) 108 .unwrap_or(b"") 109 .to_owned(); 110 val.extend(b","); 111 val.extend(addr.as_bytes()); 112 let val = HeaderValue::from_bytes(&val) 113 .map_err(internal_error_from("invalid remote-addr header value"))?; 114 req.headers.insert(name, val); 115 } 116 if let Some(header) = &info.headers.x_forwarded_proto { 117 let name = HeaderName::from_bytes(header.as_bytes()) 118 .map_err(internal_error_from("invalid claim-to-header header name"))?; 119 req.headers.insert( 120 name, 121 HeaderValue::from_static(if is_https { "https" } else { "http" }), 122 ); 123 } 124 if let Some(header) = &info.headers.remote_addr 125 && let Some(addr) = &remote_addr 126 { 127 let name = HeaderName::from_bytes(header.as_bytes()) 128 .map_err(internal_error_from("invalid claim-to-header header name"))?; 129 let val = HeaderValue::from_str(addr) 130 .map_err(internal_error_from("invalid remote-addr header value"))?; 131 req.headers.insert(name, val); 132 } 133 134 Ok(()) 135 } 136 137 /// check auth, starting the flow if necessary 138 async fn check_auth( 139 &self, 140 session: &mut Session, 141 auth_info: &oidc::Info, 142 ctx: &mut AuthCtx, 143 ) -> Result<ControlFlow<()>> { 144 use auth_code_flow::code_request; 145 146 let req = session.req_header_mut(); 147 let cookies = httputil::cookie_jar(req)?.unwrap_or_default(); 148 149 let auth_cookie = cookies 150 .get(SESSION_COOKIE_NAME) 151 .map(|c| c.value()) 152 .and_then(|c| cookies::CookieContents::contents(c, &auth_info.cookie_signing_key).ok()); 153 { 154 // auth_info map pin 155 let sessions = auth_info.sessions.pin(); 156 if let Some(valid_session) = auth_cookie 157 .and_then(|c| sessions.get(&c.session_id)) 158 .filter(|sess| sess.expires_at > jiff::Timestamp::now()) 159 { 160 if let Some(claim_map) = &auth_info.config.claims { 161 for (claim, header) in &claim_map.claim_to_header { 162 match valid_session.claims.get(claim) { 163 Some(val) => { 164 let val = HeaderValue::from_bytes(val.as_bytes()) 165 .map_err(internal_error_from("invalid claim value"))?; 166 let name = HeaderName::from_bytes(header.as_bytes()).map_err( 167 internal_error_from("invalid claim-to-header header name"), 168 )?; 169 req.headers.insert(name, val) 170 } 171 None => req.headers.remove(header), 172 }; 173 } 174 } 175 176 ctx.session_valid = true; 177 178 return Ok(ControlFlow::Continue(())); 179 } 180 } 181 182 // otherwise! start the auth flow 183 let meta_cache = auth_info.get_or_cache_metadata().await?; 184 185 // TODO(cleanup): precompute scopes 186 let redirect_info = code_request::redirect_to_auth_server( 187 (&meta_cache.metadata).into(), 188 code_request::Data::new( 189 &auth_info.config.client_id, 190 &auth_info 191 .config 192 .scopes 193 .as_ref() 194 .map(|s| { 195 s.required 196 .iter() 197 .fold(auth_code_flow::Scopes::base_scopes(), |scopes, scope| { 198 scopes.add_scope(scope) 199 }) 200 }) 201 .unwrap_or(auth_code_flow::Scopes::base_scopes()), 202 // technically this is a spec violate, but it's useful for testing 203 &Url::parse(&format!( 204 "https://{domain}/{path}", 205 domain = auth_info.domain, 206 path = OAUTH_CONTINUE_PATH, 207 )) 208 .map_err(internal_error_from( 209 "unable to construct redirect url from domain", 210 ))?, 211 ), 212 ) 213 .map_err(internal_error_from("unable to construct redirect"))?; 214 215 if auth_info 216 .auth_states 217 .pin() 218 .try_insert( 219 redirect_info.state, 220 InProgressAuth { 221 code_verifier: redirect_info.code_verifier, 222 original_path: req.uri.path().to_string(), 223 }, 224 ) 225 .is_err() 226 { 227 // this is _extremely_ unlikely to happen, but worth checking anyway 228 return Err(internal_error("state id collision")()); 229 }; 230 231 httputil::redirect_response(session, redirect_info.url.as_str(), |_, _| Ok(())).await?; 232 Ok(ControlFlow::Break(())) 233 } 234 235 /// continue auth from inbound redirects, or logout from a logout redirect 236 async fn receive_redirect( 237 &self, 238 session: &mut Session, 239 info: &DomainInfo, 240 ) -> Result<ControlFlow<()>> { 241 use auth_code_flow::{code_response, token_request, token_response}; 242 243 let req = session.req_header(); 244 let Some(pq) = req.uri.path_and_query() else { 245 return Ok(ControlFlow::Continue(())); 246 }; 247 let Some(auth_info) = &info.oidc else { 248 return Ok(ControlFlow::Continue(())); 249 }; 250 251 if pq.path() == OAUTH_LOGOUT_PATH { 252 let Some(mut cookies) = httputil::cookie_jar(req)? else { 253 // we're not logged in, just return fine 254 httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| Ok(())) 255 .await?; 256 return Ok(ControlFlow::Break(())); 257 }; 258 259 { 260 let Some(raw) = cookies 261 .get(SESSION_COOKIE_NAME) 262 .map(|raw| raw.value().to_string()) 263 else { 264 // we're not logged in, just return fine 265 httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| { 266 Ok(()) 267 }) 268 .await?; 269 return Ok(ControlFlow::Break(())); 270 }; 271 cookies.remove(SESSION_COOKIE_NAME); 272 273 let Some(cookies::CookieMessage { session_id }) = 274 cookies::CookieContents::contents(&raw, &auth_info.cookie_signing_key).ok() 275 else { 276 // invalid cookie, just ignore 277 httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| { 278 Ok(()) 279 }) 280 .await?; 281 return Ok(ControlFlow::Break(())); 282 }; 283 auth_info.sessions.pin().remove(&session_id); 284 }; 285 286 httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| Ok(())) 287 .await?; 288 return Ok(ControlFlow::Break(())); 289 } 290 291 if let Some(auth_info) = &info.oidc 292 && pq.path().starts_with("/") 293 && &pq.path()[1..] == OAUTH_CONTINUE_PATH 294 { 295 let Some(query) = pq.query() else { 296 session 297 .respond_error(StatusCode::BAD_REQUEST.into()) 298 .await?; 299 return Ok(ControlFlow::Break(())); 300 }; 301 302 let Some(meta_cache) = auth_info.meta_cache.lock().await.as_ref().cloned() else { 303 // if we don't already have discovery metadata, something's real weird, cause 304 // how did we start the flow 305 session 306 .respond_error(StatusCode::BAD_REQUEST.into()) 307 .await?; 308 return Ok(ControlFlow::Break(())); 309 }; 310 311 let status = 312 code_response::receive_redirect(query, meta_cache.metadata.issuer.as_str()) 313 .map_err(status_error_from( 314 "unable to deserialize oauth2 response", 315 ErrorSource::Internal, 316 StatusCode::BAD_REQUEST, 317 ))?; 318 let resp = match status { 319 Ok(resp) => resp, 320 Err(err) => { 321 auth_info.auth_states.pin().remove(&err.state); 322 match err.error { 323 code_response::ErrorType::AccessDenied => { 324 session.respond_error(StatusCode::FORBIDDEN.into()).await?; 325 return Ok(ControlFlow::Break(())); 326 } 327 code_response::ErrorType::TemporarilyUnavailable => { 328 session 329 .respond_error(StatusCode::SERVICE_UNAVAILABLE.into()) 330 .await?; 331 return Ok(ControlFlow::Break(())); 332 } 333 _ => { 334 session 335 .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into()) 336 .await?; 337 return Ok(ControlFlow::Break(())); 338 } 339 } 340 } 341 }; 342 343 let Some(in_progress) = auth_info.auth_states.pin().remove(&resp.state).cloned() else { 344 session 345 .respond_error(StatusCode::BAD_REQUEST.into()) 346 .await?; 347 return Ok(ControlFlow::Break(())); 348 }; 349 350 let mut body = String::new(); 351 let redirect_uri = &format!( 352 "https://{domain}/{path}", 353 domain = auth_info.domain, 354 path = OAUTH_CONTINUE_PATH, 355 ); 356 let token_req = token_request::request_access_token( 357 (&meta_cache.metadata).into(), 358 token_request::Data { 359 code: resp, 360 client_id: &auth_info.config.client_id, 361 client_secret: &auth_info.client_secret, 362 // this should not be a clone, but it's a weird quirk of our threadsafe 363 // hashmap choice 364 code_verifier: in_progress.code_verifier, 365 redirect_uri, 366 }, 367 &mut body, 368 ) 369 .map_err(internal_error_from("unable produce access token request"))?; 370 371 let resp: token_response::Valid = { 372 let client = reqwest::Client::new(); 373 let resp = client 374 .post(token_req.url.as_str().to_string()) 375 .header( 376 http::header::CONTENT_TYPE, 377 "application/x-www-form-urlencoded", 378 ) 379 .body(body) 380 .send() 381 .await 382 .map_err(internal_error_from("unable to make token request"))?; 383 if resp.status() == StatusCode::BAD_REQUEST { 384 let _resp: token_response::Error = resp 385 .json() 386 .await 387 .map_err(internal_error_from("unable to deserialize response"))?; 388 session 389 .respond_error(StatusCode::BAD_REQUEST.into()) 390 .await?; 391 return Ok(ControlFlow::Break(())); 392 // error per [the rfc][ref:draft-ietf-oauth-v2-1#3.2.4] 393 } else if resp.status() == StatusCode::NOT_FOUND { 394 // maybe it moved? try fetching the info again later 395 auth_info.clear_metadata_cache().await; 396 } else if !resp.status().is_success() { 397 session 398 .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into()) 399 .await?; 400 return Ok(ControlFlow::Break(())); 401 } 402 resp.json() 403 .await 404 .map_err(internal_error_from("unable to deserialize token response"))? 405 }; 406 407 use std::str::FromStr as _; 408 let id_token = compact_jwt::JwtUnverified::from_str( 409 &resp 410 .id_token 411 .ok_or_else(internal_error("no id token in response"))?, 412 ) 413 .map_err(internal_error_from("unable to deserialize id token"))?; 414 415 // will be some if we had the option "verify" turned on, will be none otherwise 416 // (will never be none if the option is on but we couldn't fetch the token) 417 let id_token: compact_jwt::Jwt<()> = match &meta_cache.jws_verifier { 418 Some(verifier) => { 419 use compact_jwt::JwsVerifier as _; 420 verifier 421 .verify(&id_token) 422 .map_err(internal_error_from("unable to verify id token"))? 423 } 424 None => { 425 use compact_jwt::JwsVerifier as _; 426 let verifier = 427 compact_jwt::dangernoverify::JwsDangerReleaseWithoutVerify::default(); 428 verifier 429 .verify(&id_token) 430 .map_err(internal_error_from("unable to deserialize id_token to jwt"))? 431 } 432 }; 433 434 // per https://openid.net/specs/openid-connect-core-1_0-final.html#TokenResponseValidation, we _must_ to check 435 // - iss (must match the expected issuer) 436 // - aud (must match our client_id) 437 // - exp (must expire in the future) 438 if id_token 439 .iss 440 .is_none_or(|iss| iss != meta_cache.metadata.issuer.as_str()) 441 { 442 return Err(internal_error("issuer mismatch on id token")()); 443 } 444 if id_token 445 .aud 446 .is_none_or(|aud| aud != auth_info.config.client_id) 447 { 448 return Err(internal_error("audience mismatch on id token")()); 449 } 450 let expires_at = jiff::Timestamp::from_second( 451 id_token 452 .exp 453 .ok_or_else(internal_error("missing exp on token"))?, 454 ) 455 .map_err(internal_error_from("unable to parse exp as timestamp"))?; 456 if expires_at < jiff::Timestamp::now() { 457 session 458 .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into()) 459 .await?; 460 return Ok(ControlFlow::Break(())); 461 } 462 463 let user_info = UserInfo { 464 expires_at, 465 claims: id_token 466 .claims 467 .into_iter() 468 // [`serde_json::Value`] implements display that's just "semi-infallible 469 // serialize" 470 .map(|(k, v)| (k, v.to_string())) 471 .collect(), 472 }; 473 let expiry = user_info.expires_at; 474 let mut rng = rand::rngs::OsRng; 475 use rand::Rng as _; 476 let session_id = rng.r#gen(); 477 auth_info.sessions.pin().insert(session_id, user_info); 478 let cookie = cookies::CookieContents::sign( 479 cookies::CookieMessage { session_id }, 480 &auth_info.cookie_signing_key, 481 ) 482 .map_err(|()| internal_error("unable to sign cookie")())?; 483 484 let url = format!( 485 "https://{domain}{original_path}", 486 domain = auth_info.domain, 487 original_path = in_progress.original_path 488 ); 489 httputil::redirect_response(session, &url, |resp, session| { 490 let mut cookies = httputil::cookie_jar(session.req_header())?.unwrap_or_default(); 491 cookies.set( 492 cookie_rs::Cookie::builder(SESSION_COOKIE_NAME, cookie) 493 .http_only(true) 494 .secure(true) 495 // utc technically potentially different than gmt, but this is just advisory (we enforce 496 // elsewhere), so it's ok 497 .max_age( 498 std::time::Duration::try_from(expiry - jiff::Timestamp::now()) 499 .expect("formed from timestamps, can't have relative parts"), 500 ) 501 .path("/") 502 .build(), 503 ); 504 cookies 505 .as_header_values() 506 .into_iter() 507 .try_for_each(|cookie| { 508 let val = HeaderValue::from_bytes(cookie.as_bytes()) 509 .map_err(internal_error_from("bad cookie header value"))?; 510 511 resp.append_header(http::header::SET_COOKIE, val)?; 512 Ok::<_, Box<Error>>(()) 513 })?; 514 Ok(()) 515 }) 516 .await?; 517 return Ok(ControlFlow::Break(())); 518 } 519 520 Ok(ControlFlow::Continue(())) 521 } 522} 523 524pub struct AuthCtx { 525 session_valid: bool, 526} 527 528/// the oauth2 redirect path, without the leading slash 529const OAUTH_CONTINUE_PATH: &str = ".oauth2/continue"; 530/// the logout/cookie-clear path, _with_ the leading slash 531const OAUTH_LOGOUT_PATH: &str = "/.oauth2/logout"; 532 533#[async_trait] 534impl ProxyHttp for AuthGateway { 535 type CTX = AuthCtx; 536 fn new_ctx(&self) -> Self::CTX { 537 AuthCtx { 538 session_valid: false, 539 } 540 } 541 542 async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> { 543 let info = self.domain_info(session)?; 544 545 // check if we need to terminate the connection cause someone sent us an http request and we 546 // don't allow that 547 let is_https = session 548 .digest() 549 .and_then(|d| d.ssl_digest.as_ref()) 550 .is_some(); 551 if !is_https { 552 use config::format::domain::TlsMode; 553 match info.tls_mode { 554 TlsMode::Only => { 555 // we should just drop the connection, although people should really just be 556 // using HSTS 557 session.shutdown().await; 558 return Ok(true); 559 } 560 TlsMode::UnsafeAllowHttp => {} 561 } 562 } 563 564 // next, check if we're in the middle of an oauth flow 565 match self.receive_redirect(session, info).await? { 566 ControlFlow::Continue(()) => {} 567 ControlFlow::Break(()) => return Ok(true), 568 } 569 570 // finally check our actual auth state, starting the auth flow as needed 571 if let Some(auth_info) = &info.oidc { 572 match self.check_auth(session, auth_info, ctx).await? { 573 ControlFlow::Continue(()) => {} 574 ControlFlow::Break(()) => return Ok(true), 575 } 576 } 577 578 // we're past auth and are processing as normal, proceed 579 self.strip_and_apply_general_headers(session, info, is_https) 580 .await?; 581 582 Ok(false) 583 } 584 585 async fn upstream_peer( 586 &self, 587 session: &mut Session, 588 _ctx: &mut Self::CTX, 589 ) -> Result<Box<HttpPeer>> { 590 fn client_addr_key(sock_addr: &pingora::protocols::l4::socket::SocketAddr) -> Vec<u8> { 591 use pingora::protocols::l4::socket::SocketAddr; 592 match sock_addr { 593 SocketAddr::Inet(socket_addr) => match socket_addr { 594 std::net::SocketAddr::V4(v4) => Vec::from(v4.ip().octets()), 595 std::net::SocketAddr::V6(v6) => Vec::from(v6.ip().octets()), 596 }, 597 _ => unreachable!(), 598 } 599 } 600 601 let backends = self.domain_info(session)?; 602 let backend = backends 603 .balancer 604 // NB: this means that CGNAT, other proxies, etc will? consistently hit the same 605 // backend, so we might wanna take that into consideration. fine for now, this is 606 // currently for personal use ;-) 607 .select( 608 &client_addr_key(session.client_addr().ok_or_else(status_error( 609 "no client address", 610 ErrorSource::Downstream, 611 StatusCode::BAD_REQUEST, 612 ))?), /* lb on client address */ 613 256, 614 ) 615 .ok_or_else(status_error( 616 "no available backends", 617 ErrorSource::Upstream, 618 StatusCode::SERVICE_UNAVAILABLE, 619 ))?; 620 621 let needs_tls = backend 622 .ext 623 .get::<BackendData>() 624 .map(|d| d.tls) 625 .unwrap_or(true); 626 627 Ok(Box::new(HttpPeer::new( 628 backend, 629 needs_tls, 630 backends.sni_name.to_string(), 631 ))) 632 } 633 634 async fn response_filter( 635 &self, 636 _session: &mut Session, 637 upstream_response: &mut ResponseHeader, 638 ctx: &mut Self::CTX, 639 ) -> Result<()> 640 where 641 Self::CTX: Send + Sync, 642 { 643 // if we had no valid session, clear the cookie 644 if !ctx.session_valid { 645 let mut cookies = CookieJar::default(); 646 cookies.remove(SESSION_COOKIE_NAME); 647 648 cookies.as_header_values().into_iter().try_for_each(|v| { 649 let v = HeaderValue::from_bytes(v.as_bytes()) 650 .map_err(internal_error_from("invalid clear cookie header value"))?; 651 652 upstream_response 653 .headers 654 .append(http::header::SET_COOKIE, v); 655 Ok::<_, Box<Error>>(()) 656 })?; 657 } 658 Ok(()) 659 } 660} 661 662/// additional data stored in the load balancer's backend structure 663/// 664/// for use in [`AuthGateway::upstream_peer`] 665#[derive(Clone)] 666pub struct BackendData { 667 /// does the backend want tls 668 pub tls: bool, 669}