···991010 // bind to tcp ports, with optional tls
1111 repeated TCPBinding bind_to_tcp = 2;
1212- // bind to unix domain sockets
1313- repeated UDSBinding bind_to_uds = 4;
14121513 // lower-level pingora config
1614 Pingora pingora = 3;
···1917message Domain {
2018 // require oidc auth if this is set
2119 optional OIDC oidc_auth = 1;
2222-2323- // TODO: ACME challenge hosting support natively?
24202521 // https backends
2622 repeated HTTPSBackend https = 3;
···55515652 Scopes scopes = 4;
5753 Claims claims = 5;
5454+5555+ // per oidc core v1-with-errata-2§3.1.3.7 point 6, we _may_ skip validation
5656+ // of the id token if it was received over tls. which it will be, in our
5757+ // case. some folks may want to be extra paranoid, but generally you either
5858+ // trust tls, or you can't trust discovery, and thus can't trust the jwks info,
5959+ // so default this to false.
6060+ bool validate_with_jwk = 6;
6161+6262+ // where to redirect to on logout
6363+ string logout_url = 7;
5864}
59656066message Scopes {
···97103 // set an `X-Forwarded-Proto`-style header to the original scheme of the request
98104 optional string x_forwarded_proto = 3;
99105 // set an `X-Real-IP`-style header (i.e. _just_ the remote address)
100100- repeated string remote_addr = 4;
106106+ optional string remote_addr = 4;
101107102108 // always clear these headers
103109 repeated string always_clear = 5;
···143149 // tls, if desired
144150 optional TLS tls = 2;
145151146146- // TODO: surface tcp options from pingora
147147-}
148148-message UDSBinding {
149149- message Permissions {
150150- uint32 mode = 1;
151151- }
152152- // socket path
153153- string path = 1;
154154- // permissions to set on the socket path
155155- optional Permissions permissions = 2;
152152+ // TODO(feature): surface tcp options from pingora
156153}
+85
src/cookies.rs
···11+//! # Cookie handling
22+//!
33+//! cookies are stored as postcard-encoded ed25519-signed tokens, [Protocol Buffer Tokens], which
44+//! this is inspired by. They contain session ids, _not_ the actual access token, which are stored
55+//! in an in-memory session store.
66+//!
77+//! the signing key is generated on server start this does mean server restarts invalidate active
88+//! sessions. this is not a big deal for my current usecase, and we could probably do secret
99+//! handover or have configurable secrets should the need arise.
1010+//!
1111+//! [Protocol Buffer Tokens]: https://fly.io/blog/api-tokens-a-tedious-survey/
1212+1313+use std::marker::PhantomData;
1414+1515+use base64::Engine as _;
1616+use base64::prelude::BASE64_STANDARD;
1717+use color_eyre::Result;
1818+use serde::de::DeserializeOwned;
1919+use serde::{Deserialize, Serialize};
2020+2121+#[derive(Serialize, Deserialize)]
2222+pub struct Signed<MSG> {
2323+ signature: ed25519_dalek::Signature,
2424+ message: Vec<u8>,
2525+ _kind: PhantomData<MSG>,
2626+}
2727+impl<MSG: DeserializeOwned + Versioned> Signed<MSG> {
2828+ pub fn contents(raw: &str, key: &ed25519_dalek::SigningKey) -> Result<MSG, ()> {
2929+ let raw = BASE64_STANDARD.decode(raw).map_err(drop)?;
3030+ // timing threats: technically, someone could try to discover the appropriate shape of our
3131+ // tokens by seeing if we early-return from the postcard::from_bytes message.
3232+ // this doesn't seem like a huge threat, since we're not encrypting our cookies, so the
3333+ // structure is already pretty obvious
3434+3535+ let envelope: Self = postcard::from_bytes(&raw).map_err(drop)?;
3636+ key.verify(&envelope.message, &envelope.signature)
3737+ .map_err(drop)?;
3838+ let (version_num, msg_raw): (u64, _) =
3939+ postcard::take_from_bytes(&envelope.message).map_err(drop)?;
4040+ // we're past the signature part, so we know this is ours.
4141+ // worst we're getting here is an attempt to check if old tokens still work,
4242+ // so we don't need to be as stressed about timing sensitivity
4343+ if version_num != MSG::VERSION {
4444+ return Err(());
4545+ }
4646+ postcard::from_bytes(msg_raw).map_err(drop)
4747+ }
4848+}
4949+5050+impl<MSG: Serialize + Versioned> Signed<MSG> {
5151+ pub fn sign(msg: MSG, key: &ed25519_dalek::SigningKey) -> Result<String, ()> {
5252+ use ed25519_dalek::ed25519::signature::Signer as _;
5353+5454+ let raw = {
5555+ let raw = Vec::new();
5656+ let raw = postcard::to_extend(&MSG::VERSION, raw).map_err(drop)?;
5757+ postcard::to_extend(&msg, raw).map_err(drop)?
5858+ };
5959+ let signature = key.sign(&raw);
6060+ let envelope = Self {
6161+ signature,
6262+ message: raw,
6363+ _kind: Default::default(),
6464+ };
6565+ let raw = postcard::to_extend(&envelope, Vec::new()).map_err(drop)?;
6666+ Ok(BASE64_STANDARD.encode(raw))
6767+ }
6868+}
6969+7070+pub trait Versioned {
7171+ const VERSION: u64;
7272+}
7373+7474+#[derive(Serialize, Deserialize)]
7575+pub struct CookieMessage {
7676+ // NB: per rfc:draft-ietf-oauth-v2-1#7.1.3.4, since we're signing our cookies and not encrypting
7777+ // them, we can't store the access token directly. instead we'll store the session id, which
7878+ // we'll use to look up the token in a session store.
7979+ pub session_id: u64,
8080+}
8181+impl Versioned for CookieMessage {
8282+ const VERSION: u64 = 1;
8383+}
8484+8585+pub type CookieContents = Signed<CookieMessage>;
+669
src/gateway.rs
···11+//! the actual gateway implementation
22+33+use std::collections::HashMap;
44+use std::ops::ControlFlow;
55+use std::sync::Arc;
66+77+use async_trait::async_trait;
88+use cookie_rs::CookieJar;
99+use http::status::StatusCode;
1010+use http::{HeaderName, HeaderValue};
1111+use pingora::lb::selection::consistent::KetamaHashing;
1212+use pingora::prelude::*;
1313+use url::Url;
1414+1515+use crate::gateway::oidc::{InProgressAuth, SESSION_COOKIE_NAME, UserInfo};
1616+use crate::httputil::{internal_error, internal_error_from, status_error, status_error_from};
1717+use crate::oauth::auth_code_flow;
1818+use crate::{config, cookies, httputil};
1919+2020+pub mod oidc;
2121+2222+/// per-domain information about backends and such
2323+pub struct DomainInfo {
2424+ /// the load balancer to use to select backends
2525+ pub balancer: Arc<LoadBalancer<KetamaHashing>>,
2626+ /// whether or not we allow insecure connections from clients
2727+ pub tls_mode: config::format::domain::TlsMode,
2828+ /// the sni name of this domain, used to pass to backends
2929+ pub sni_name: String,
3030+ /// auth settings for this domain, if any
3131+ pub oidc: Option<oidc::Info>,
3232+ /// headers to mangle for requests on this domain
3333+ pub headers: config::format::ManageHeaders,
3434+}
3535+3636+/// the actual gateway logic
3737+pub struct AuthGateway {
3838+ /// all known domains and their corresponding backends & settings
3939+ pub domains: HashMap<String, DomainInfo>,
4040+}
4141+4242+impl AuthGateway {
4343+ /// fetch the domain info for this request
4444+ fn domain_info<'s>(&'s self, session: &Session) -> Result<&'s DomainInfo> {
4545+ let req = session.req_header();
4646+ // TODO(potential-bug): afaict, right now, afaict, pingora a) does not check that SNI matches the `Host`
4747+ // header, b) does not support extracting the SNI info on rustls, so we'll have to switch
4848+ // to boringssl and implement that ourselves T_T
4949+ let host = req
5050+ .headers
5151+ .get(http::header::HOST)
5252+ .ok_or_else(status_error(
5353+ "no host set",
5454+ ErrorSource::Downstream,
5555+ StatusCode::BAD_REQUEST,
5656+ ))?
5757+ .to_str()
5858+ .map_err(|e| {
5959+ Error::because(
6060+ ErrorType::HTTPStatus(StatusCode::BAD_REQUEST.into()),
6161+ "no host",
6262+ e,
6363+ )
6464+ })?;
6565+ let info = self.domains.get(host).ok_or_else(status_error(
6666+ "unknown host",
6767+ ErrorSource::Downstream,
6868+ StatusCode::SERVICE_UNAVAILABLE,
6969+ ))?;
7070+7171+ Ok(info)
7272+ }
7373+7474+ /// mangle general headers, per [`config::format::ManageHeaders`]
7575+ async fn strip_and_apply_general_headers(
7676+ &self,
7777+ session: &mut Session,
7878+ info: &DomainInfo,
7979+ is_https: bool,
8080+ ) -> Result<()> {
8181+ let remote_addr = session.client_addr().and_then(|addr| match addr {
8282+ pingora::protocols::l4::socket::SocketAddr::Inet(socket_addr) => {
8383+ Some(socket_addr.ip().to_string())
8484+ }
8585+ pingora::protocols::l4::socket::SocketAddr::Unix(_) => None,
8686+ });
8787+ let req = session.req_header_mut();
8888+ if let Some(header) = &info.headers.host {
8989+ // TODO(cleanup): preprocess all header names
9090+ let name = HeaderName::from_bytes(header.as_bytes())
9191+ .map_err(internal_error_from("invalid claim-to-header header name"))?;
9292+ let val = req
9393+ .headers
9494+ .get(http::header::HOST)
9595+ .expect("we had to have this to look up our backend")
9696+ .clone();
9797+ req.headers.insert(name, val);
9898+ }
9999+ if let Some(header) = &info.headers.x_forwarded_for
100100+ && let Some(addr) = &remote_addr
101101+ {
102102+ let name = HeaderName::from_bytes(header.as_bytes())
103103+ .map_err(internal_error_from("invalid claim-to-header header name"))?;
104104+ let mut val = req
105105+ .headers
106106+ .get("x-forwarded-for")
107107+ .map(|v| v.as_bytes())
108108+ .unwrap_or(b"")
109109+ .to_owned();
110110+ val.extend(b",");
111111+ val.extend(addr.as_bytes());
112112+ let val = HeaderValue::from_bytes(&val)
113113+ .map_err(internal_error_from("invalid remote-addr header value"))?;
114114+ req.headers.insert(name, val);
115115+ }
116116+ if let Some(header) = &info.headers.x_forwarded_proto {
117117+ let name = HeaderName::from_bytes(header.as_bytes())
118118+ .map_err(internal_error_from("invalid claim-to-header header name"))?;
119119+ req.headers.insert(
120120+ name,
121121+ HeaderValue::from_static(if is_https { "https" } else { "http" }),
122122+ );
123123+ }
124124+ if let Some(header) = &info.headers.remote_addr
125125+ && let Some(addr) = &remote_addr
126126+ {
127127+ let name = HeaderName::from_bytes(header.as_bytes())
128128+ .map_err(internal_error_from("invalid claim-to-header header name"))?;
129129+ let val = HeaderValue::from_str(addr)
130130+ .map_err(internal_error_from("invalid remote-addr header value"))?;
131131+ req.headers.insert(name, val);
132132+ }
133133+134134+ Ok(())
135135+ }
136136+137137+ /// check auth, starting the flow if necessary
138138+ async fn check_auth(
139139+ &self,
140140+ session: &mut Session,
141141+ auth_info: &oidc::Info,
142142+ ctx: &mut AuthCtx,
143143+ ) -> Result<ControlFlow<()>> {
144144+ use auth_code_flow::code_request;
145145+146146+ let req = session.req_header_mut();
147147+ let cookies = httputil::cookie_jar(req)?.unwrap_or_default();
148148+149149+ let auth_cookie = cookies
150150+ .get(SESSION_COOKIE_NAME)
151151+ .map(|c| c.value())
152152+ .and_then(|c| cookies::CookieContents::contents(c, &auth_info.cookie_signing_key).ok());
153153+ {
154154+ // auth_info map pin
155155+ let sessions = auth_info.sessions.pin();
156156+ if let Some(valid_session) = auth_cookie
157157+ .and_then(|c| sessions.get(&c.session_id))
158158+ .filter(|sess| sess.expires_at > jiff::Timestamp::now())
159159+ {
160160+ if let Some(claim_map) = &auth_info.config.claims {
161161+ for (claim, header) in &claim_map.claim_to_header {
162162+ match valid_session.claims.get(claim) {
163163+ Some(val) => {
164164+ let val = HeaderValue::from_bytes(val.as_bytes())
165165+ .map_err(internal_error_from("invalid claim value"))?;
166166+ let name = HeaderName::from_bytes(header.as_bytes()).map_err(
167167+ internal_error_from("invalid claim-to-header header name"),
168168+ )?;
169169+ req.headers.insert(name, val)
170170+ }
171171+ None => req.headers.remove(header),
172172+ };
173173+ }
174174+ }
175175+176176+ ctx.session_valid = true;
177177+178178+ return Ok(ControlFlow::Continue(()));
179179+ }
180180+ }
181181+182182+ // otherwise! start the auth flow
183183+ let meta_cache = auth_info.get_or_cache_metadata().await?;
184184+185185+ // TODO(cleanup): precompute scopes
186186+ let redirect_info = code_request::redirect_to_auth_server(
187187+ (&meta_cache.metadata).into(),
188188+ code_request::Data::new(
189189+ &auth_info.config.client_id,
190190+ &auth_info
191191+ .config
192192+ .scopes
193193+ .as_ref()
194194+ .map(|s| {
195195+ s.required
196196+ .iter()
197197+ .fold(auth_code_flow::Scopes::base_scopes(), |scopes, scope| {
198198+ scopes.add_scope(scope)
199199+ })
200200+ })
201201+ .unwrap_or(auth_code_flow::Scopes::base_scopes()),
202202+ // technically this is a spec violate, but it's useful for testing
203203+ &Url::parse(&format!(
204204+ "https://{domain}/{path}",
205205+ domain = auth_info.domain,
206206+ path = OAUTH_CONTINUE_PATH,
207207+ ))
208208+ .map_err(internal_error_from(
209209+ "unable to construct redirect url from domain",
210210+ ))?,
211211+ ),
212212+ )
213213+ .map_err(internal_error_from("unable to construct redirect"))?;
214214+215215+ if auth_info
216216+ .auth_states
217217+ .pin()
218218+ .try_insert(
219219+ redirect_info.state,
220220+ InProgressAuth {
221221+ code_verifier: redirect_info.code_verifier,
222222+ original_path: req.uri.path().to_string(),
223223+ },
224224+ )
225225+ .is_err()
226226+ {
227227+ // this is _extremely_ unlikely to happen, but worth checking anyway
228228+ return Err(internal_error("state id collision")());
229229+ };
230230+231231+ httputil::redirect_response(session, redirect_info.url.as_str(), |_, _| Ok(())).await?;
232232+ Ok(ControlFlow::Break(()))
233233+ }
234234+235235+ /// continue auth from inbound redirects, or logout from a logout redirect
236236+ async fn receive_redirect(
237237+ &self,
238238+ session: &mut Session,
239239+ info: &DomainInfo,
240240+ ) -> Result<ControlFlow<()>> {
241241+ use auth_code_flow::{code_response, token_request, token_response};
242242+243243+ let req = session.req_header();
244244+ let Some(pq) = req.uri.path_and_query() else {
245245+ return Ok(ControlFlow::Continue(()));
246246+ };
247247+ let Some(auth_info) = &info.oidc else {
248248+ return Ok(ControlFlow::Continue(()));
249249+ };
250250+251251+ if pq.path() == OAUTH_LOGOUT_PATH {
252252+ let Some(mut cookies) = httputil::cookie_jar(req)? else {
253253+ // we're not logged in, just return fine
254254+ httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| Ok(()))
255255+ .await?;
256256+ return Ok(ControlFlow::Break(()));
257257+ };
258258+259259+ {
260260+ let Some(raw) = cookies
261261+ .get(SESSION_COOKIE_NAME)
262262+ .map(|raw| raw.value().to_string())
263263+ else {
264264+ // we're not logged in, just return fine
265265+ httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| {
266266+ Ok(())
267267+ })
268268+ .await?;
269269+ return Ok(ControlFlow::Break(()));
270270+ };
271271+ cookies.remove(SESSION_COOKIE_NAME);
272272+273273+ let Some(cookies::CookieMessage { session_id }) =
274274+ cookies::CookieContents::contents(&raw, &auth_info.cookie_signing_key).ok()
275275+ else {
276276+ // invalid cookie, just ignore
277277+ httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| {
278278+ Ok(())
279279+ })
280280+ .await?;
281281+ return Ok(ControlFlow::Break(()));
282282+ };
283283+ auth_info.sessions.pin().remove(&session_id);
284284+ };
285285+286286+ httputil::redirect_response(session, &auth_info.config.logout_url, |_, _| Ok(()))
287287+ .await?;
288288+ return Ok(ControlFlow::Break(()));
289289+ }
290290+291291+ if let Some(auth_info) = &info.oidc
292292+ && pq.path().starts_with("/")
293293+ && &pq.path()[1..] == OAUTH_CONTINUE_PATH
294294+ {
295295+ let Some(query) = pq.query() else {
296296+ session
297297+ .respond_error(StatusCode::BAD_REQUEST.into())
298298+ .await?;
299299+ return Ok(ControlFlow::Break(()));
300300+ };
301301+302302+ let Some(meta_cache) = auth_info.meta_cache.lock().await.as_ref().cloned() else {
303303+ // if we don't already have discovery metadata, something's real weird, cause
304304+ // how did we start the flow
305305+ session
306306+ .respond_error(StatusCode::BAD_REQUEST.into())
307307+ .await?;
308308+ return Ok(ControlFlow::Break(()));
309309+ };
310310+311311+ let status =
312312+ code_response::receive_redirect(query, meta_cache.metadata.issuer.as_str())
313313+ .map_err(status_error_from(
314314+ "unable to deserialize oauth2 response",
315315+ ErrorSource::Internal,
316316+ StatusCode::BAD_REQUEST,
317317+ ))?;
318318+ let resp = match status {
319319+ Ok(resp) => resp,
320320+ Err(err) => {
321321+ auth_info.auth_states.pin().remove(&err.state);
322322+ match err.error {
323323+ code_response::ErrorType::AccessDenied => {
324324+ session.respond_error(StatusCode::FORBIDDEN.into()).await?;
325325+ return Ok(ControlFlow::Break(()));
326326+ }
327327+ code_response::ErrorType::TemporarilyUnavailable => {
328328+ session
329329+ .respond_error(StatusCode::SERVICE_UNAVAILABLE.into())
330330+ .await?;
331331+ return Ok(ControlFlow::Break(()));
332332+ }
333333+ _ => {
334334+ session
335335+ .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into())
336336+ .await?;
337337+ return Ok(ControlFlow::Break(()));
338338+ }
339339+ }
340340+ }
341341+ };
342342+343343+ let Some(in_progress) = auth_info.auth_states.pin().remove(&resp.state).cloned() else {
344344+ session
345345+ .respond_error(StatusCode::BAD_REQUEST.into())
346346+ .await?;
347347+ return Ok(ControlFlow::Break(()));
348348+ };
349349+350350+ let mut body = String::new();
351351+ let redirect_uri = &format!(
352352+ "https://{domain}/{path}",
353353+ domain = auth_info.domain,
354354+ path = OAUTH_CONTINUE_PATH,
355355+ );
356356+ let token_req = token_request::request_access_token(
357357+ (&meta_cache.metadata).into(),
358358+ token_request::Data {
359359+ code: resp,
360360+ client_id: &auth_info.config.client_id,
361361+ client_secret: &auth_info.client_secret,
362362+ // this should not be a clone, but it's a weird quirk of our threadsafe
363363+ // hashmap choice
364364+ code_verifier: in_progress.code_verifier,
365365+ redirect_uri,
366366+ },
367367+ &mut body,
368368+ )
369369+ .map_err(internal_error_from("unable produce access token request"))?;
370370+371371+ let resp: token_response::Valid = {
372372+ let client = reqwest::Client::new();
373373+ let resp = client
374374+ .post(token_req.url.as_str().to_string())
375375+ .header(
376376+ http::header::CONTENT_TYPE,
377377+ "application/x-www-form-urlencoded",
378378+ )
379379+ .body(body)
380380+ .send()
381381+ .await
382382+ .map_err(internal_error_from("unable to make token request"))?;
383383+ if resp.status() == StatusCode::BAD_REQUEST {
384384+ let _resp: token_response::Error = resp
385385+ .json()
386386+ .await
387387+ .map_err(internal_error_from("unable to deserialize response"))?;
388388+ session
389389+ .respond_error(StatusCode::BAD_REQUEST.into())
390390+ .await?;
391391+ return Ok(ControlFlow::Break(()));
392392+ // error per [the rfc][ref:draft-ietf-oauth-v2-1#3.2.4]
393393+ } else if resp.status() == StatusCode::NOT_FOUND {
394394+ // maybe it moved? try fetching the info again later
395395+ auth_info.clear_metadata_cache().await;
396396+ } else if !resp.status().is_success() {
397397+ session
398398+ .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into())
399399+ .await?;
400400+ return Ok(ControlFlow::Break(()));
401401+ }
402402+ resp.json()
403403+ .await
404404+ .map_err(internal_error_from("unable to deserialize token response"))?
405405+ };
406406+407407+ use std::str::FromStr as _;
408408+ let id_token = compact_jwt::JwtUnverified::from_str(
409409+ &resp
410410+ .id_token
411411+ .ok_or_else(internal_error("no id token in response"))?,
412412+ )
413413+ .map_err(internal_error_from("unable to deserialize id token"))?;
414414+415415+ // will be some if we had the option "verify" turned on, will be none otherwise
416416+ // (will never be none if the option is on but we couldn't fetch the token)
417417+ let id_token: compact_jwt::Jwt<()> = match &meta_cache.jws_verifier {
418418+ Some(verifier) => {
419419+ use compact_jwt::JwsVerifier as _;
420420+ verifier
421421+ .verify(&id_token)
422422+ .map_err(internal_error_from("unable to verify id token"))?
423423+ }
424424+ None => {
425425+ use compact_jwt::JwsVerifier as _;
426426+ let verifier =
427427+ compact_jwt::dangernoverify::JwsDangerReleaseWithoutVerify::default();
428428+ verifier
429429+ .verify(&id_token)
430430+ .map_err(internal_error_from("unable to deserialize id_token to jwt"))?
431431+ }
432432+ };
433433+434434+ // per https://openid.net/specs/openid-connect-core-1_0-final.html#TokenResponseValidation, we _must_ to check
435435+ // - iss (must match the expected issuer)
436436+ // - aud (must match our client_id)
437437+ // - exp (must expire in the future)
438438+ if id_token
439439+ .iss
440440+ .is_none_or(|iss| iss != meta_cache.metadata.issuer.as_str())
441441+ {
442442+ return Err(internal_error("issuer mismatch on id token")());
443443+ }
444444+ if id_token
445445+ .aud
446446+ .is_none_or(|aud| aud != auth_info.config.client_id)
447447+ {
448448+ return Err(internal_error("audience mismatch on id token")());
449449+ }
450450+ let expires_at = jiff::Timestamp::from_second(
451451+ id_token
452452+ .exp
453453+ .ok_or_else(internal_error("missing exp on token"))?,
454454+ )
455455+ .map_err(internal_error_from("unable to parse exp as timestamp"))?;
456456+ if expires_at < jiff::Timestamp::now() {
457457+ session
458458+ .respond_error(StatusCode::INTERNAL_SERVER_ERROR.into())
459459+ .await?;
460460+ return Ok(ControlFlow::Break(()));
461461+ }
462462+463463+ let user_info = UserInfo {
464464+ expires_at,
465465+ claims: id_token
466466+ .claims
467467+ .into_iter()
468468+ // [`serde_json::Value`] implements display that's just "semi-infallible
469469+ // serialize"
470470+ .map(|(k, v)| (k, v.to_string()))
471471+ .collect(),
472472+ };
473473+ let expiry = user_info.expires_at;
474474+ let mut rng = rand::rngs::OsRng;
475475+ use rand::Rng as _;
476476+ let session_id = rng.r#gen();
477477+ auth_info.sessions.pin().insert(session_id, user_info);
478478+ let cookie = cookies::CookieContents::sign(
479479+ cookies::CookieMessage { session_id },
480480+ &auth_info.cookie_signing_key,
481481+ )
482482+ .map_err(|()| internal_error("unable to sign cookie")())?;
483483+484484+ let url = format!(
485485+ "https://{domain}{original_path}",
486486+ domain = auth_info.domain,
487487+ original_path = in_progress.original_path
488488+ );
489489+ httputil::redirect_response(session, &url, |resp, session| {
490490+ let mut cookies = httputil::cookie_jar(session.req_header())?.unwrap_or_default();
491491+ cookies.set(
492492+ cookie_rs::Cookie::builder(SESSION_COOKIE_NAME, cookie)
493493+ .http_only(true)
494494+ .secure(true)
495495+ // utc technically potentially different than gmt, but this is just advisory (we enforce
496496+ // elsewhere), so it's ok
497497+ .max_age(
498498+ std::time::Duration::try_from(expiry - jiff::Timestamp::now())
499499+ .expect("formed from timestamps, can't have relative parts"),
500500+ )
501501+ .path("/")
502502+ .build(),
503503+ );
504504+ cookies
505505+ .as_header_values()
506506+ .into_iter()
507507+ .try_for_each(|cookie| {
508508+ let val = HeaderValue::from_bytes(cookie.as_bytes())
509509+ .map_err(internal_error_from("bad cookie header value"))?;
510510+511511+ resp.append_header(http::header::SET_COOKIE, val)?;
512512+ Ok::<_, Box<Error>>(())
513513+ })?;
514514+ Ok(())
515515+ })
516516+ .await?;
517517+ return Ok(ControlFlow::Break(()));
518518+ }
519519+520520+ Ok(ControlFlow::Continue(()))
521521+ }
522522+}
523523+524524+pub struct AuthCtx {
525525+ session_valid: bool,
526526+}
527527+528528+/// the oauth2 redirect path, without the leading slash
529529+const OAUTH_CONTINUE_PATH: &str = ".oauth2/continue";
530530+/// the logout/cookie-clear path, _with_ the leading slash
531531+const OAUTH_LOGOUT_PATH: &str = "/.oauth2/logout";
532532+533533+#[async_trait]
534534+impl ProxyHttp for AuthGateway {
535535+ type CTX = AuthCtx;
536536+ fn new_ctx(&self) -> Self::CTX {
537537+ AuthCtx {
538538+ session_valid: false,
539539+ }
540540+ }
541541+542542+ async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
543543+ let info = self.domain_info(session)?;
544544+545545+ // check if we need to terminate the connection cause someone sent us an http request and we
546546+ // don't allow that
547547+ let is_https = session
548548+ .digest()
549549+ .and_then(|d| d.ssl_digest.as_ref())
550550+ .is_some();
551551+ if !is_https {
552552+ use config::format::domain::TlsMode;
553553+ match info.tls_mode {
554554+ TlsMode::Only => {
555555+ // we should just drop the connection, although people should really just be
556556+ // using HSTS
557557+ session.shutdown().await;
558558+ return Ok(true);
559559+ }
560560+ TlsMode::UnsafeAllowHttp => {}
561561+ }
562562+ }
563563+564564+ // next, check if we're in the middle of an oauth flow
565565+ match self.receive_redirect(session, info).await? {
566566+ ControlFlow::Continue(()) => {}
567567+ ControlFlow::Break(()) => return Ok(true),
568568+ }
569569+570570+ // finally check our actual auth state, starting the auth flow as needed
571571+ if let Some(auth_info) = &info.oidc {
572572+ match self.check_auth(session, auth_info, ctx).await? {
573573+ ControlFlow::Continue(()) => {}
574574+ ControlFlow::Break(()) => return Ok(true),
575575+ }
576576+ }
577577+578578+ // we're past auth and are processing as normal, proceed
579579+ self.strip_and_apply_general_headers(session, info, is_https)
580580+ .await?;
581581+582582+ Ok(false)
583583+ }
584584+585585+ async fn upstream_peer(
586586+ &self,
587587+ session: &mut Session,
588588+ _ctx: &mut Self::CTX,
589589+ ) -> Result<Box<HttpPeer>> {
590590+ fn client_addr_key(sock_addr: &pingora::protocols::l4::socket::SocketAddr) -> Vec<u8> {
591591+ use pingora::protocols::l4::socket::SocketAddr;
592592+ match sock_addr {
593593+ SocketAddr::Inet(socket_addr) => match socket_addr {
594594+ std::net::SocketAddr::V4(v4) => Vec::from(v4.ip().octets()),
595595+ std::net::SocketAddr::V6(v6) => Vec::from(v6.ip().octets()),
596596+ },
597597+ _ => unreachable!(),
598598+ }
599599+ }
600600+601601+ let backends = self.domain_info(session)?;
602602+ let backend = backends
603603+ .balancer
604604+ // NB: this means that CGNAT, other proxies, etc will? consistently hit the same
605605+ // backend, so we might wanna take that into consideration. fine for now, this is
606606+ // currently for personal use ;-)
607607+ .select(
608608+ &client_addr_key(session.client_addr().ok_or_else(status_error(
609609+ "no client address",
610610+ ErrorSource::Downstream,
611611+ StatusCode::BAD_REQUEST,
612612+ ))?), /* lb on client address */
613613+ 256,
614614+ )
615615+ .ok_or_else(status_error(
616616+ "no available backends",
617617+ ErrorSource::Upstream,
618618+ StatusCode::SERVICE_UNAVAILABLE,
619619+ ))?;
620620+621621+ let needs_tls = backend
622622+ .ext
623623+ .get::<BackendData>()
624624+ .map(|d| d.tls)
625625+ .unwrap_or(true);
626626+627627+ Ok(Box::new(HttpPeer::new(
628628+ backend,
629629+ needs_tls,
630630+ backends.sni_name.to_string(),
631631+ )))
632632+ }
633633+634634+ async fn response_filter(
635635+ &self,
636636+ _session: &mut Session,
637637+ upstream_response: &mut ResponseHeader,
638638+ ctx: &mut Self::CTX,
639639+ ) -> Result<()>
640640+ where
641641+ Self::CTX: Send + Sync,
642642+ {
643643+ // if we had no valid session, clear the cookie
644644+ if !ctx.session_valid {
645645+ let mut cookies = CookieJar::default();
646646+ cookies.remove(SESSION_COOKIE_NAME);
647647+648648+ cookies.as_header_values().into_iter().try_for_each(|v| {
649649+ let v = HeaderValue::from_bytes(v.as_bytes())
650650+ .map_err(internal_error_from("invalid clear cookie header value"))?;
651651+652652+ upstream_response
653653+ .headers
654654+ .append(http::header::SET_COOKIE, v);
655655+ Ok::<_, Box<Error>>(())
656656+ })?;
657657+ }
658658+ Ok(())
659659+ }
660660+}
661661+662662+/// additional data stored in the load balancer's backend structure
663663+///
664664+/// for use in [`AuthGateway::upstream_peer`]
665665+#[derive(Clone)]
666666+pub struct BackendData {
667667+ /// does the backend want tls
668668+ pub tls: bool,
669669+}
+193
src/gateway/oidc.rs
···11+use std::collections::HashMap;
22+use std::sync::Arc;
33+44+use color_eyre::eyre::Context as _;
55+use http::status::StatusCode;
66+use pingora::prelude::*;
77+use tokio::sync::Mutex as TokioMutex;
88+use url::Url;
99+1010+use crate::config;
1111+use crate::httputil::{internal_error, internal_error_from, status_error, status_error_from};
1212+use crate::oauth;
1313+1414+/// see https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/Cookies#cookie_prefixes
1515+pub const SESSION_COOKIE_NAME: &str = "__Host-Http-oauth-session";
1616+1717+/// active session user information
1818+pub struct UserInfo {
1919+ /// when this session expires, from the id token `exp` claim
2020+ pub expires_at: jiff::Timestamp,
2121+ /// all other non-default claims attached to the id token
2222+ pub claims: HashMap<String, String>,
2323+}
2424+2525+/// in-progress auth flow state
2626+#[derive(Clone)] // only needed cause papaya
2727+pub struct InProgressAuth {
2828+ /// the code verifier, whence the code challenge was derived
2929+ pub code_verifier: String,
3030+ /// the original path we were trying to go to on this domain,
3131+ /// for redirection once the auth flow is over
3232+ pub original_path: String,
3333+}
3434+3535+/// cache of auth server metadata and associated bits
3636+pub struct MetadataCache {
3737+ /// the metadata itself
3838+ pub metadata: oauth::metadata::AuthServerMetadata,
3939+ /// the fetched, parsed jwks info
4040+ /// "none" means validation was disabled in our config, _NOT_ "we couldn't fetch the jwks"
4141+ pub jws_verifier: Option<compact_jwt::JwsEs256Verifier>,
4242+}
4343+4444+/// overall auth info
4545+pub struct Info {
4646+ /// the raw config from our config file
4747+ pub config: config::format::Oidc,
4848+ // needs to be tokio because we need to hold it across an await point
4949+ /// cache of auth server metadata
5050+ pub meta_cache: TokioMutex<Option<Arc<MetadataCache>>>,
5151+ /// the current in-progress authorization flows, bound to states submitted to the auth serve
5252+ pub auth_states: papaya::HashMap<uuid::Uuid, InProgressAuth>,
5353+ /// the currently active sessions, by id from the session id cookie
5454+ pub sessions: papaya::HashMap<u64, UserInfo>,
5555+ /// the root domain
5656+ pub domain: String,
5757+ /// the oauth client secret
5858+ pub client_secret: String,
5959+6060+ /// the signing key used to sign session cookies
6161+ pub cookie_signing_key: ed25519_dalek::SigningKey,
6262+}
6363+impl Info {
6464+ /// clear the metadata cache
6565+ pub async fn clear_metadata_cache(&self) {
6666+ *self.meta_cache.lock().await = None;
6767+ }
6868+ /// get the metadata, or cache it (and the accompanying jwks data if needed) if it hasn't get
6969+ /// been fetched
7070+ pub async fn get_or_cache_metadata(&self) -> Result<Arc<MetadataCache>> {
7171+ let mut cache = self.meta_cache.lock().await;
7272+ if let Some(ref meta) = *cache {
7373+ return Ok(meta.clone());
7474+ }
7575+ let discovery_url = {
7676+ let url = Url::parse(&self.config.discovery_url_base)
7777+ .map_err(internal_error_from("invalid discovery url"))?;
7878+ oauth::metadata::oidc_discovery_uri(&url)
7979+ .map_err(internal_error_from("invalid discovery url suffix"))?
8080+ };
8181+8282+ let meta: oauth::metadata::AuthServerMetadata = {
8383+ let resp = reqwest::Client::new()
8484+ .get(discovery_url.as_str())
8585+ .header(http::header::ACCEPT, "application/json")
8686+ .send()
8787+ .await
8888+ .map_err(internal_error_from("unable to fetch oauth metadata doc"))?;
8989+9090+ if !resp.status().is_success() {
9191+ return Err(status_error(
9292+ "unable to fetch discovery info",
9393+ ErrorSource::Internal,
9494+ StatusCode::SERVICE_UNAVAILABLE,
9595+ )());
9696+ }
9797+ resp.json().await.map_err(internal_error_from(
9898+ "unable to deserialize oauth metadata doc",
9999+ ))?
100100+ };
101101+102102+ meta.generally_as_expected().map_err(internal_error_from(
103103+ "auth server not generally as expected/required",
104104+ ))?;
105105+106106+ let jws_verifier = if self.config.validate_with_jwk {
107107+ if !meta
108108+ .id_token_signing_alg_values_supported
109109+ .as_ref()
110110+ .is_some_and(|s| s.contains(&oauth::metadata::SigningAlgValue::ES256))
111111+ {
112112+ return Err(internal_error("es256 signing not supported by endpoint")());
113113+ }
114114+115115+ let Some(jwks_uri) = &meta.jwks_uri else {
116116+ return Err(internal_error(
117117+ "jwks not available or es256 signing not supported by endpoint",
118118+ )());
119119+ };
120120+ let resp = reqwest::Client::new()
121121+ .get(jwks_uri.as_str())
122122+ .header(http::header::ACCEPT, "application/json")
123123+ .send()
124124+ .await
125125+ .map_err(status_error_from(
126126+ "unable to fetch jwks",
127127+ ErrorSource::Internal,
128128+ StatusCode::SERVICE_UNAVAILABLE,
129129+ ))?;
130130+ if !resp.status().is_success() {
131131+ return Err(status_error(
132132+ "unable to fetch jwks",
133133+ ErrorSource::Internal,
134134+ StatusCode::SERVICE_UNAVAILABLE,
135135+ )());
136136+ }
137137+ let jwks: compact_jwt::JwkKeySet = resp
138138+ .json()
139139+ .await
140140+ .map_err(internal_error_from("unable to deserialize jwks"))?;
141141+ // per oidc discovery v1 section 3, this either contains only signing keys, or has
142142+ // keys with a `use` option. we're going to choose to only support 1 key per use
143143+ // here and require a use anyway. this whole thing is so poorly specified. the jwt
144144+ // ecosystem really is a tire fire
145145+ Some(
146146+ jwks.keys
147147+ .iter()
148148+ .filter_map(|key| {
149149+ let compact_jwt::Jwk::EC { use_: r#use, .. } = &key else {
150150+ return None;
151151+ };
152152+ if !r#use
153153+ .as_ref()
154154+ .is_some_and(|r#use| r#use == &compact_jwt::JwkUse::Sig)
155155+ {
156156+ return None;
157157+ }
158158+ compact_jwt::JwsEs256Verifier::try_from(key).ok()
159159+ })
160160+ .next()
161161+ .ok_or_else(internal_error("no sig keys availabe from jwks"))?,
162162+ )
163163+ } else {
164164+ None
165165+ };
166166+167167+ Ok(cache
168168+ .insert(Arc::new(MetadataCache {
169169+ metadata: meta,
170170+ jws_verifier,
171171+ }))
172172+ .clone())
173173+ }
174174+175175+ pub fn from_config(config: config::format::Oidc, domain: String) -> color_eyre::Result<Self> {
176176+ let mut rng = rand::rngs::OsRng;
177177+ // TODO(feature): check & warn on permissions here?
178178+ let client_secret = std::fs::read_to_string(&config.client_secret_path)
179179+ .context("reading client secret")?
180180+ .trim()
181181+ .to_string();
182182+ Ok(Self {
183183+ config,
184184+ meta_cache: TokioMutex::new(None),
185185+ auth_states: Default::default(),
186186+ sessions: Default::default(),
187187+ client_secret,
188188+189189+ cookie_signing_key: ed25519_dalek::SigningKey::generate(&mut rng),
190190+ domain,
191191+ })
192192+ }
193193+}
+118
src/httputil.rs
···11+//! http utilities
22+33+use http::StatusCode;
44+use pingora::prelude::*;
55+66+/// closure that returns the given status with no cause
77+///
88+/// use with [`Option::ok_or_else`]
99+pub fn status_error(
1010+ why: &'static str,
1111+ src: ErrorSource,
1212+ status: StatusCode,
1313+) -> impl FnOnce() -> Box<Error> {
1414+ move || {
1515+ Error::create(
1616+ ErrorType::HTTPStatus(status.into()),
1717+ src,
1818+ Some(why.into()),
1919+ None,
2020+ )
2121+ }
2222+}
2323+/// closure that returns `500 Internal Server Error`, marked as caused by the error returned to
2424+/// the given closure
2525+///
2626+/// use with [`Result::map_err`]
2727+pub fn internal_error_from<E>(why: &'static str) -> impl FnOnce(E) -> Box<Error>
2828+where
2929+ E: Into<Box<dyn ErrorTrait + Send + Sync>>,
3030+{
3131+ move |cause| {
3232+ Error::create(
3333+ ErrorType::HTTPStatus(StatusCode::INTERNAL_SERVER_ERROR.into()),
3434+ ErrorSource::Internal,
3535+ Some(why.into()),
3636+ Some(cause.into()),
3737+ )
3838+ }
3939+}
4040+4141+/// closure that returns `500 Internal Server Error` with no cause
4242+///
4343+/// use with [`Option::ok_or_else`]
4444+pub fn internal_error(why: &'static str) -> impl FnOnce() -> Box<Error> {
4545+ move || {
4646+ Error::create(
4747+ ErrorType::HTTPStatus(StatusCode::INTERNAL_SERVER_ERROR.into()),
4848+ ErrorSource::Internal,
4949+ Some(why.into()),
5050+ None,
5151+ )
5252+ }
5353+}
5454+/// closure that returns the given status, marked as caused by the error returned to the given closure
5555+///
5656+/// use with [`Result::map_err`]
5757+pub fn status_error_from<E>(
5858+ why: &'static str,
5959+ src: ErrorSource,
6060+ status: http::StatusCode,
6161+) -> impl FnOnce(E) -> Box<Error>
6262+where
6363+ E: Into<Box<dyn ErrorTrait + Send + Sync>>,
6464+{
6565+ move |cause| {
6666+ Error::create(
6767+ ErrorType::HTTPStatus(status.into()),
6868+ src,
6969+ Some(why.into()),
7070+ Some(cause.into()),
7171+ )
7272+ }
7373+}
7474+7575+/// redirect to the given location
7676+///
7777+/// the given callback can be used to inject additional headers, like `set-cookie`
7878+pub async fn redirect_response(
7979+ session: &mut Session,
8080+ to: &str,
8181+ bld_resp_header: impl FnOnce(&mut ResponseHeader, &Session) -> Result<()>,
8282+) -> Result<()> {
8383+ session.set_keepalive(None);
8484+ session
8585+ .write_response_header(
8686+ Box::new({
8787+ // per <rfc:draft-ietf-oauth-v2-1#1.6>, any redirect is fine save 307, but HTTP 302 seems to
8888+ // be their example.
8989+ let mut resp = ResponseHeader::build(StatusCode::FOUND, Some(0))?;
9090+ resp.insert_header(http::header::LOCATION, to)?;
9191+ bld_resp_header(&mut resp, session)?;
9292+ resp
9393+ }),
9494+ true,
9595+ )
9696+ .await?;
9797+ session.finish_body().await?;
9898+ Ok(())
9999+}
100100+101101+/// fetch the cookies for the current request
102102+pub fn cookie_jar(req: &'_ RequestHeader) -> Result<Option<cookie_rs::CookieJar<'_>>> {
103103+ use cookie_rs::CookieJar;
104104+105105+ let Some(raw) = req.headers.get(http::header::COOKIE) else {
106106+ return Ok(None);
107107+ };
108108+ Ok(Some(
109109+ raw.to_str()
110110+ .map_err(Box::<dyn ErrorTrait + Send + Sync>::from)
111111+ .and_then(|c| Ok(CookieJar::parse(c)?))
112112+ .map_err(status_error_from(
113113+ "bad cookie header",
114114+ ErrorSource::Downstream,
115115+ StatusCode::BAD_REQUEST,
116116+ ))?,
117117+ ))
118118+}
+23-189
src/main.rs
···11use std::collections::HashMap;
22-use std::sync::Arc;
3244-use async_trait::async_trait;
53use color_eyre::eyre::Context as _;
66-use http::status::StatusCode;
74use pingora::lb;
85use pingora::lb::selection::consistent::KetamaHashing;
99-use pingora::modules::http::HttpModules;
1010-use pingora::modules::http::compression::ResponseCompressionBuilder;
116use pingora::prelude::*;
1271313-mod config;
1414-1515-struct BackendInfo {
1616- balancer: Arc<LoadBalancer<KetamaHashing>>,
1717- tls_mode: config::format::domain::TlsMode,
1818- name: String,
1919- // TODO: force ssl
2020-}
2121-2222-pub struct AuthGateway {
2323- backends: HashMap<String, BackendInfo>,
2424-}
2525-2626-fn status_error(
2727- why: &'static str,
2828- src: ErrorSource,
2929- status: http::StatusCode,
3030-) -> impl FnOnce() -> Box<Error> {
3131- move || {
3232- Error::create(
3333- ErrorType::HTTPStatus(status.into()),
3434- src,
3535- Some(why.into()),
3636- None,
3737- )
3838- }
3939-}
4040-4141-impl AuthGateway {
4242- fn backend_info<'s>(&'s self, session: &Session) -> Result<&'s BackendInfo> {
4343- let req = session.req_header();
4444- // TODO: afaict, right now, afaict, pingora a) does not check that SNI matches the `Host`
4545- // header, b) does not support extracting the SNI info on rustls, so we'll have to switch
4646- // to boringssl and implement that ourselves T_T
4747- let host = req
4848- .headers
4949- .get(http::header::HOST)
5050- .ok_or_else(status_error(
5151- "no host set",
5252- ErrorSource::Downstream,
5353- StatusCode::BAD_REQUEST,
5454- ))?
5555- .to_str()
5656- .map_err(|e| {
5757- Error::because(
5858- ErrorType::HTTPStatus(StatusCode::BAD_REQUEST.into()),
5959- "no host",
6060- e,
6161- )
6262- })?;
6363- let info = self.backends.get(host).ok_or_else(status_error(
6464- "unknown host",
6565- ErrorSource::Downstream,
6666- StatusCode::SERVICE_UNAVAILABLE,
6767- ))?;
6868-6969- Ok(info)
7070- }
7171-}
7272-7373-#[async_trait]
7474-impl ProxyHttp for AuthGateway {
7575- type CTX = ();
7676- fn new_ctx(&self) -> Self::CTX {}
7777-7878- fn init_downstream_modules(&self, modules: &mut HttpModules) {
7979- // TODO: make this configurable?
8080- modules.add_module(ResponseCompressionBuilder::enable(1));
8181- }
8282-8383- async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool> {
8484- // check if this is http, and redirect
8585- // TODO: maybe should be a module?
8686- let is_https = session
8787- .digest()
8888- .and_then(|d| d.ssl_digest.as_ref())
8989- .is_some();
9090- if !is_https {
9191- use config::format::domain::TlsMode;
9292- let info = self.backend_info(session)?;
9393- match info.tls_mode {
9494- TlsMode::Only => {
9595- // we should just drop the connection, although people should really just be
9696- // using HSTS
9797- session.shutdown().await;
9898- return Ok(true);
9999- }
100100- TlsMode::UnsafeAllowHttp => {}
101101- }
102102- }
103103-104104- Ok(false)
105105- }
106106-107107- async fn upstream_peer(&self, session: &mut Session, _ctx: &mut ()) -> Result<Box<HttpPeer>> {
108108- fn client_addr_key(sock_addr: &pingora::protocols::l4::socket::SocketAddr) -> Vec<u8> {
109109- use pingora::protocols::l4::socket::SocketAddr;
110110- match sock_addr {
111111- SocketAddr::Inet(socket_addr) => match socket_addr {
112112- std::net::SocketAddr::V4(v4) => Vec::from(v4.ip().octets()),
113113- std::net::SocketAddr::V6(v6) => Vec::from(v6.ip().octets()),
114114- },
115115- // TODO: this is... not a great key for hashing
116116- SocketAddr::Unix(_socket_addr) => vec![],
117117- }
118118- }
119119-120120- let backends = self.backend_info(session)?;
121121- let backend = backends
122122- .balancer
123123- // NB: this means that CGNAT, other proxies, etc will? consistently hit the same
124124- // backend, so we might wanna take that into consideration. fine for now, this is
125125- // currently for personal use ;-)
126126- .select(
127127- &client_addr_key(session.client_addr().ok_or_else(status_error(
128128- "no client address",
129129- ErrorSource::Downstream,
130130- StatusCode::BAD_REQUEST,
131131- ))?), /* lb on client address */
132132- 256,
133133- )
134134- .ok_or_else(status_error(
135135- "no available backends",
136136- ErrorSource::Upstream,
137137- StatusCode::SERVICE_UNAVAILABLE,
138138- ))?;
139139-140140- let needs_tls = backend
141141- .ext
142142- .get::<BackendData>()
143143- .map(|d| d.tls)
144144- .unwrap_or(true);
145145-146146- Ok(Box::new(HttpPeer::new(
147147- backend,
148148- needs_tls,
149149- backends.name.to_string(),
150150- )))
151151- }
152152-153153- // TODO: upstream_request_filter to insert the right headers
154154-155155- async fn response_filter(
156156- &self,
157157- _session: &mut Session,
158158- _upstream_response: &mut ResponseHeader,
159159- _ctx: &mut Self::CTX,
160160- ) -> Result<()>
161161- where
162162- Self::CTX: Send + Sync,
163163- {
164164- Ok(())
165165- }
166166-167167- // TODO: logging
168168-}
88+use self::gateway::{AuthGateway, BackendData, DomainInfo, oidc};
1699170170-#[derive(Clone)]
171171-struct BackendData {
172172- tls: bool,
173173-}
1010+mod config;
1111+mod cookies;
1212+mod gateway;
1313+mod httputil;
1414+mod oauth;
17415175175-fn balancer(
176176- domains: &HashMap<String, config::format::Domain>,
177177-) -> color_eyre::Result<(
1616+/// constructed load balancer, with [backend info][`BackendInfo`] to be passed to [`AuthGateway`]
1717+type BalancerInfo = (
17818 Vec<pingora::services::background::GenBackgroundService<LoadBalancer<KetamaHashing>>>,
179179- HashMap<String, BackendInfo>,
180180-)> {
1919+ HashMap<String, DomainInfo>,
2020+);
2121+2222+/// construct the load balancer and initialize the [backend info][`BackendInfo`] for the
2323+/// [`AuthGateway`]
2424+fn balancer(domains: &HashMap<String, config::format::Domain>) -> color_eyre::Result<BalancerInfo> {
18125 use lb::{self, Backend, discovery};
18226 use pingora::protocols::l4::socket::SocketAddr;
18327···22468 .collect::<color_eyre::Result<_>>()
22569 .context("constucting backends for domain")?;
22670 let backends = lb::Backends::new(discovery::Static::new(backends));
227227- // TODO: allow configuring healthchecks
22871 let balancer = LoadBalancer::from_backends(backends);
22972 let svc = background_service("health checking", balancer);
23073231231- let info = BackendInfo {
7474+ let info = DomainInfo {
23275 balancer: svc.task(),
23376 tls_mode: config::format::domain::TlsMode::try_from(domain.tls_mode)
23477 .context("invalid tls mode")?,
235235- name: name.clone(),
7878+ sni_name: name.clone(),
7979+ oidc: domain
8080+ .oidc_auth
8181+ .clone()
8282+ .map(|config| oidc::Info::from_config(config, name.clone()))
8383+ .transpose()?,
8484+ headers: domain.manage_headers.clone().unwrap_or_default(),
23685 };
2378623887 balancers.insert(name.clone(), info);
···2449324594fn main() -> color_eyre::Result<()> {
24695 use color_eyre::eyre::eyre;
247247-248248- use std::os::unix::fs::PermissionsExt as _;
2499625097 tracing_subscriber::fmt().init();
25198 color_eyre::install()?;
···271118272119 let (balancer_svcs, balancers) =
273120 balancer(&config.domains).context("setting up load balancing")?;
274274- let mut gateway = http_proxy_service(
275275- &server.configuration,
276276- AuthGateway {
277277- backends: balancers,
278278- },
279279- );
121121+ let mut gateway = http_proxy_service(&server.configuration, AuthGateway { domains: balancers });
280122 for binding in config.bind_to_tcp {
281123 match binding.tls {
282124 Some(tls) => gateway
···284126 .context("setting up tls")?,
285127 None => gateway.add_tcp(&binding.addr),
286128 }
287287- }
288288- for binding in config.bind_to_uds {
289289- gateway.add_uds(
290290- &binding.path,
291291- binding
292292- .permissions
293293- .map(|p| std::fs::Permissions::from_mode(p.mode)),
294294- );
295129 }
296130297131 balancer_svcs
+596
src/oauth.rs
···11+//! # Background
22+//!
33+//! so! there exists oauth & oidc packages. they're what kanidm uses to implement oidc.
44+//! unfortunately, they're poorly maintained on the client side (e.g. their reqwest bindings don't
55+//! work with the latest version of reqwest), and rather clunkily implemented through callbacks
66+//! instead of typestate, which mean they're tightly bound to exact implementation details.
77+//!
88+//! this is... annoying. so we have this instead
99+//!
1010+//! # Overview
1111+//!
1212+//! this implementation is based on [oauth2.1]. this means, basically, it's [oauth2.0] with best
1313+//! practices applied.
1414+//!
1515+//! the oidc half is based on [oidc core 1.0 + errata 2][oidc1], but with... some of the more
1616+//! ill-advised ignorable parts duely ignored.
1717+//!
1818+//! [oauth2.1]: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1
1919+//! [oauth2.0]: https://www.rfc-editor.org/rfc/rfc8414.html
2020+//! [oidc1]: https://openid.net/specs/openid-connect-core-1_0.html
2121+2222+/// # [OAuth 2.0 Authorization Server Metadata][rfc:8414] and related helpers
2323+///
2424+/// most enums here also have additional values specified in the [iana registry].
2525+///
2626+/// [iana registry]: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
2727+pub mod metadata {
2828+ use std::collections::HashSet;
2929+3030+ use serde::Deserialize;
3131+ use url::Url;
3232+3333+ /// [OAuth 2.0 Authorization Server Metadata][rfc:8414#2]
3434+ ///
3535+ /// see the rfc for field descriptions.
3636+ #[derive(Deserialize)]
3737+ pub struct AuthServerMetadata {
3838+ pub issuer: Url,
3939+ // > \[authorization_endpoint\] is REQUIRED unless no grant types are supported that use the
4040+ // > authorization endpoint
4141+ //
4242+ // we require the authorization flow, so we leave this as required
4343+ pub authorization_endpoint: Url,
4444+ // same here
4545+ pub token_endpoint: Url,
4646+ pub jwks_uri: Option<Url>,
4747+ pub response_types_supported: HashSet<ResponseType>,
4848+ pub response_modes_supported: Option<HashSet<ResponseMode>>,
4949+ pub grant_types_supported: Option<HashSet<GrantType>>,
5050+ // defaults to [`TokenEndpointAuthMethod::ClientSecretBasic`], but oauth 2.1 [adds in post
5151+ // too][rfc:draft-ietf-v2.1#10.1]
5252+ pub token_endpoint_auth_methods_supported: Option<HashSet<AuthMethod>>,
5353+ pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,
5454+5555+ // per https://openid.net/specs/openid-connect-discovery-1_0.html
5656+ pub id_token_signing_alg_values_supported: Option<HashSet<SigningAlgValue>>,
5757+ // per the spec, extra fields are defined in [OIDC Discovery 1.0 with errata
5858+ // 2][oidc-discovery-1].
5959+ //
6060+ // the rfc also contains a bunch of extra fields that we don't use, so aren't captured
6161+ // here
6262+ //
6363+ // [oidc-discovery-1]: https://openid.net/specs/openid-connect-discovery-1_0.html
6464+ }
6565+ impl AuthServerMetadata {
6666+ /// check if this metadata conforms to our expectations of a modern oauth v2.1 & oidc core
6767+ /// v1 server
6868+ pub fn generally_as_expected(&self) -> color_eyre::Result<()> {
6969+ use color_eyre::eyre::eyre;
7070+ if !self.response_types_supported.contains(&ResponseType::Code) {
7171+ return Err(eyre!("response type `code` not supported by auth server"));
7272+ }
7373+ // if this is missing, assume query is supported
7474+ if !ResponseMode::Query.is_supported_for(self.response_modes_supported.as_ref()) {
7575+ return Err(eyre!("response mode `query` not supported by auth server"));
7676+ }
7777+ if !GrantType::AuthorizationCode.is_supported_for(self.grant_types_supported.as_ref()) {
7878+ return Err(eyre!(
7979+ "grant type `authorization_code` not supported by auth server"
8080+ ));
8181+ }
8282+ if !AuthMethod::ClientSecretPost
8383+ .is_supported_for(self.token_endpoint_auth_methods_supported.as_ref())
8484+ {
8585+ return Err(eyre!(
8686+ "client_secret_post auth method not supported, not a valid oauth 2.1 server, and honestly not a good oauth 2.0 server either"
8787+ ));
8888+ }
8989+ if self
9090+ .code_challenge_methods_supported
9191+ .as_ref()
9292+ .is_none_or(|m| !m.contains(&CodeChallengeMethod::S256))
9393+ {
9494+ return Err(eyre!(
9595+ "auth server does not support pkce, or does not support S256 pkce"
9696+ ));
9797+ }
9898+ if self.authorization_endpoint.host() != self.issuer.host() {
9999+ return Err(eyre!(
100100+ "authorization endpoint not on issuer server: {} vs {}",
101101+ self.authorization_endpoint.as_str(),
102102+ self.issuer.as_str()
103103+ ));
104104+ }
105105+ if self.token_endpoint.host() != self.issuer.host() {
106106+ return Err(eyre!("token endpoint not on issuer server"));
107107+ }
108108+ if self
109109+ .jwks_uri
110110+ .as_ref()
111111+ .is_some_and(|jwks_uri| jwks_uri.host() != self.issuer.host())
112112+ {
113113+ return Err(eyre!("jwks uri not on issuer server"));
114114+ }
115115+ // the rest need to be checked if we ever use them
116116+117117+ // signing methods only checked if we want to actually verify the id tokens (see the
118118+ // config)
119119+120120+ Ok(())
121121+ }
122122+ }
123123+124124+ /// [OAuth 2.0 Dynamic Client Registration response types][rfc:7591#2]
125125+ ///
126126+ /// as linked in the [`AuthServerMetadata::response_types_supported`] specification.
127127+ #[derive(Deserialize, Eq, PartialEq, Hash)]
128128+ #[serde(rename_all = "snake_case")]
129129+ pub enum ResponseType {
130130+ Code,
131131+ Token,
132132+ #[serde(untagged)]
133133+ Other(String),
134134+ }
135135+136136+ /// per <rfc8414#2> this is a mix of [OAuth.Response], and [OAuth.Post] from the openid folks.
137137+ ///
138138+ ///
139139+ /// [OAuth.Response]: https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html
140140+ /// [OAuth.Post]: https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html
141141+ ///
142142+ /// as linked in the [`AuthServerMetadata::response_types_supported`] specification.
143143+ ///
144144+ /// defaults to [`ResponseMode::Query`] and [`ResponseMode::Fragment`] if missing
145145+ #[derive(Deserialize, Eq, PartialEq, Hash)]
146146+ #[serde(rename_all = "snake_case")]
147147+ pub enum ResponseMode {
148148+ Query,
149149+ Fragment,
150150+ FormPost,
151151+ #[serde(untagged)]
152152+ Other(String),
153153+ }
154154+ impl ResponseMode {
155155+ pub fn is_supported_for(&self, options: Option<&HashSet<Self>>) -> bool {
156156+ match options {
157157+ Some(opts) => opts.contains(self),
158158+ // per the field documentation (see [`Self`])
159159+ None if *self == Self::Query => true,
160160+ None if *self == Self::Fragment => true,
161161+ None => false,
162162+ }
163163+ }
164164+ }
165165+166166+ /// [OAuth 2.0 Dynamic Client Registration grant types][rfc:7591#2]
167167+ ///
168168+ /// as linked in the [`AuthServerMetadata::grant_types_supported`] specification.
169169+ ///
170170+ /// defaults to [`GrantType::AuthorizationCode`] and [`GrantType::Implicit`] if missing per
171171+ /// the rfc, but oauth 2.1 [removes the implict grant][rfc:draft-ietf-v2.1#10.1]
172172+ #[derive(Deserialize, Eq, PartialEq, Hash)]
173173+ #[serde(rename_all = "snake_case")]
174174+ pub enum GrantType {
175175+ AuthorizationCode,
176176+ Implicit,
177177+ Password,
178178+ ClientCredentials,
179179+ RefreshToken,
180180+ #[serde(rename = "urn:ietf:params:oauth:grant-type:jwt-bearer")]
181181+ UrnIetfParamsOauthGrantTypeJwtBearer,
182182+ #[serde(rename = "urn:ietf:params:oauth:grant-type:saml2-bearer")]
183183+ UrnIetfParamsOauthGrantTypeSaml2Bearer,
184184+ #[serde(untagged)]
185185+ Other(String),
186186+ }
187187+ impl GrantType {
188188+ pub fn is_supported_for(&self, options: Option<&HashSet<Self>>) -> bool {
189189+ match options {
190190+ Some(opts) => opts.contains(self),
191191+ // per the field documentation (see [`Self`])
192192+ None if *self == Self::AuthorizationCode => true,
193193+ // NB: oauth 2.1 removes the implicit grant
194194+ // None if *self == Self::Implicit => true,
195195+ None => false,
196196+ }
197197+ }
198198+ }
199199+ /// [OAuth 2.0 Dynamic Client Registration token endpoint auth methods][rfc:7591#2]
200200+ ///
201201+ /// as linked in the [`AuthServerMetadata::token_endpoint_auth_methods_supported`] specification.
202202+ #[derive(Deserialize, Eq, PartialEq, Hash)]
203203+ #[serde(rename_all = "snake_case")]
204204+ pub enum AuthMethod {
205205+ None,
206206+ ClientSecretPost,
207207+ ClientSecretBasic,
208208+ #[serde(untagged)]
209209+ Other(String),
210210+ }
211211+ impl AuthMethod {
212212+ pub fn is_supported_for(&self, options: Option<&HashSet<Self>>) -> bool {
213213+ match options {
214214+ Some(opts) => opts.contains(self),
215215+ // per the field documentation (see [`Self`]), this MUST be supported
216216+ None if *self == Self::ClientSecretBasic => true,
217217+ // per <rfc:draft-ietf-oauth-v2.1#2.5>, servers MUST support this too
218218+ None if *self == Self::ClientSecretPost => true,
219219+ None => false,
220220+ }
221221+ }
222222+ }
223223+ /// JWT signing alg values
224224+ ///
225225+ /// as linked in the [`AuthServerMetadata::token_endpoint_auth_signing_alg_values_supported`] specification.
226226+ #[derive(Deserialize, Eq, PartialEq, Hash, Debug)]
227227+ pub enum SigningAlgValue {
228228+ /// NB: this must be rejected, but we capture it here to avoid it going into [`Self::Other`]
229229+ None,
230230+ RS256,
231231+ ES256,
232232+ #[serde(untagged)]
233233+ Other(String),
234234+ }
235235+ /// [PKCE challenge methods][rfc:7636#4.3]
236236+ ///
237237+ /// as linked in the [`AuthServerMetadata::code_challenge_methods_supported`] specification.
238238+ #[derive(Deserialize, Eq, PartialEq, Hash)]
239239+ pub enum CodeChallengeMethod {
240240+ /// should never be used, but we want to catch it so it doesn't go in Other
241241+ Plain,
242242+ S256,
243243+ #[serde(untagged)]
244244+ Other(String),
245245+ }
246246+247247+ /// get the [oidc discovery 1.0+errata 2][oidc-discovery-1] well-known endpoint for a given base
248248+ /// url
249249+ ///
250250+ /// users are expected to perform (both) issuer-id transformations themselves, if need-be
251251+ /// (per [rfc:8414#5])
252252+ ///
253253+ /// [oidc-discovery-1]: https://openid.net/specs/openid-connect-discovery-1_0.html
254254+ pub fn oidc_discovery_uri(base_url: &Url) -> color_eyre::Result<Url> {
255255+ Ok(base_url.join(".well-known/openid-configuration")?)
256256+ }
257257+}
258258+259259+pub mod auth_code_flow {
260260+ //! [`metadata::GrantType::AuthorizationCode`] flow
261261+ //!
262262+ //! # [oauth 2.1 authorization code grant][rfc:draft-ietf-oauth-v2.1#4.1]
263263+ //!
264264+ //! 1. [`self::code_request::redirect_to_auth_server`]
265265+ //! 2. [`self::code_response::receive_redirect`]
266266+ //! 3. [`self::token_request::request_access_token`]
267267+ //! 4. deserialize response into either [`self::token_response::Valid`] or
268268+ //! [`self::token_response::Error`]
269269+270270+ use std::borrow::Cow;
271271+272272+ use serde::Deserialize;
273273+274274+ /// auto-join/split authorization code scopes
275275+ #[derive(Deserialize)]
276276+ #[serde(try_from = "Cow<'_, str>")]
277277+ pub struct Scopes<'u>(Cow<'u, str>);
278278+279279+ #[allow(clippy::infallible_try_from, reason = "required for serde")]
280280+ impl<'u> TryFrom<Cow<'u, str>> for Scopes<'u> {
281281+ type Error = std::convert::Infallible;
282282+283283+ fn try_from(value: Cow<'u, str>) -> std::result::Result<Self, Self::Error> {
284284+ Ok(Self(value))
285285+ }
286286+ }
287287+ impl<'u> Scopes<'u> {
288288+ pub fn base_scopes() -> Self {
289289+ Self(Cow::Borrowed("openid"))
290290+ }
291291+ pub fn add_scope(mut self, scope: impl AsRef<str>) -> Self {
292292+ match &mut self.0 {
293293+ Cow::Borrowed(b) => {
294294+ self.0 = format!("{b} {}", scope.as_ref()).into();
295295+ }
296296+ Cow::Owned(v) => {
297297+ v.push(' ');
298298+ v.push_str(scope.as_ref());
299299+ }
300300+ }
301301+ self
302302+ }
303303+ }
304304+ impl<'u, S: AsRef<str>> FromIterator<S> for Scopes<'u> {
305305+ fn from_iter<T: IntoIterator<Item = S>>(iter: T) -> Self {
306306+ Self(
307307+ iter.into_iter()
308308+ .fold(String::new(), |mut acc, elem| {
309309+ if !acc.is_empty() {
310310+ acc.push(' ');
311311+ }
312312+ acc.push_str(elem.as_ref());
313313+ acc
314314+ })
315315+ .into(),
316316+ )
317317+ }
318318+ }
319319+320320+ /// Step 1
321321+ pub mod code_request {
322322+323323+ use color_eyre::Result;
324324+ use url::Url;
325325+326326+ use super::super::metadata::AuthServerMetadata;
327327+ use super::Scopes;
328328+329329+ /// data used to construct the initial authorization code browser redirect url
330330+ pub struct Data<'u> {
331331+ // owned cause it's unique every time
332332+ code_verifier: String,
333333+ client_id: &'u str,
334334+ scope: &'u Scopes<'u>,
335335+ state: uuid::Uuid,
336336+ redirect_uri: &'u Url,
337337+ }
338338+ impl<'u> Data<'u> {
339339+ pub fn new(client_id: &'u str, scope: &'u Scopes<'u>, redirect_uri: &'u Url) -> Self {
340340+ use base64::prelude::*;
341341+ use rand::Rng as _;
342342+ use sha2::Digest as _;
343343+344344+ let mut rng = rand::rngs::OsRng;
345345+ Self {
346346+ code_verifier: BASE64_URL_SAFE_NO_PAD
347347+ .encode(sha2::Sha256::digest(rng.r#gen::<[u8; 32]>())),
348348+ client_id,
349349+ scope,
350350+ state: uuid::Uuid::new_v4(),
351351+ redirect_uri,
352352+ }
353353+ }
354354+ }
355355+356356+ /// slice of [`super::metadata::AuthServerMetadata`] needed for the initial
357357+ /// [`redirect_to_auth_server`] call
358358+ pub struct Metadata<'u> {
359359+ authorization_endpoint: &'u Url,
360360+ }
361361+ impl<'u> From<&'u AuthServerMetadata> for Metadata<'u> {
362362+ fn from(orig: &'u AuthServerMetadata) -> Self {
363363+ Self {
364364+ authorization_endpoint: &orig.authorization_endpoint,
365365+ }
366366+ }
367367+ }
368368+369369+ /// the information required to start the authorization code flow and redirect a browser
370370+ pub struct RedirectInfo {
371371+ /// the url to send to the browser
372372+ pub url: Url,
373373+ /// the code verifier, to use when submitting the token request
374374+ pub code_verifier: String,
375375+ /// the state, to use to associate the authorization server's response back to the code
376376+ /// verifier and such
377377+ pub state: uuid::Uuid,
378378+ }
379379+380380+ /// construct the url used to redirect the user to the authorization server login
381381+ ///
382382+ /// take the returned state and use it to save the code verifier
383383+ pub fn redirect_to_auth_server(
384384+ meta: Metadata<'_>,
385385+ params: Data<'_>,
386386+ ) -> Result<RedirectInfo> {
387387+ let mut url = meta.authorization_endpoint.clone();
388388+ let mut query = url.query_pairs_mut();
389389+ query.append_pair("response_type", "code");
390390+ query.append_pair("client_id", params.client_id);
391391+ {
392392+ use base64::prelude::*;
393393+ use sha2::Digest as _;
394394+395395+ query.append_pair("code_challenge_method", "S256");
396396+ let challenge =
397397+ BASE64_URL_SAFE.encode(sha2::Sha256::digest(params.code_verifier.as_bytes()));
398398+ query.append_pair("code_challenge", &challenge);
399399+ }
400400+401401+ // NB: there's some optional oidc parameters here, but they're mostly worth skipping
402402+ // the main one that's useful is the nonce, but pkce takes the place of that and is
403403+ // more broadly standardized in oauth2 v2.1
404404+405405+ query.append_pair("redirect_uri", params.redirect_uri.as_str());
406406+ query.append_pair("scope", ¶ms.scope.0);
407407+ query.append_pair("state", ¶ms.state.to_string());
408408+ drop(query);
409409+410410+ Ok(RedirectInfo {
411411+ url,
412412+ code_verifier: params.code_verifier,
413413+ state: params.state,
414414+ })
415415+ }
416416+ }
417417+418418+ /// Step 2
419419+ pub mod code_response {
420420+ use std::borrow::Cow;
421421+422422+ use color_eyre::Result;
423423+ use serde::Deserialize;
424424+425425+ /// types of errors that a server can respond with, having failed the initial auth code request
426426+ #[derive(Deserialize)]
427427+ #[serde(rename_all = "snake_case")]
428428+ pub enum ErrorType<'u> {
429429+ // oauth 2.1 per [rfc:draft-ietf-oauth-v2-1#4.1.2.1]
430430+ InvalidRequest,
431431+ UnauthorizedClient,
432432+ AccessDenied,
433433+ UnsupportedResponseType,
434434+ InvalidScope,
435435+ ServerError,
436436+ TemporarilyUnavailable,
437437+438438+ // TODO(on-oss):
439439+ // [oidc-core-1](https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.2.6)
440440+ // defines some, but we don't really care about those for now, so they can go in "other"
441441+442442+ // anything else a server happens to randomly shove in there
443443+ #[serde(untagged)]
444444+ #[allow(
445445+ dead_code,
446446+ reason = "gonna make use of these shortly to surface errors better"
447447+ )]
448448+ Other(Cow<'u, str>),
449449+ }
450450+ /// an error that the server responds with when an authorization code request fails
451451+ #[allow(
452452+ dead_code,
453453+ reason = "gonna make use of these shortly to surface errors better"
454454+ )]
455455+ #[derive(Deserialize)]
456456+ pub struct Error<'u> {
457457+ pub error: ErrorType<'u>,
458458+ pub error_description: Option<Cow<'u, str>>,
459459+ pub error_uri: Option<Cow<'u, str>>,
460460+ pub state: uuid::Uuid,
461461+ pub iss: Option<Cow<'u, str>>,
462462+ }
463463+464464+ /// the query parameters for a successful authorization code response
465465+ pub struct Response<'u> {
466466+ pub code: Cow<'u, str>,
467467+ pub state: uuid::Uuid,
468468+ }
469469+470470+ pub fn receive_redirect<'u>(
471471+ query: &'u str,
472472+ issuer: &'u str,
473473+ ) -> Result<Result<Response<'u>, Error<'u>>> {
474474+ #[derive(Deserialize)]
475475+ #[serde(untagged)]
476476+ enum Params<'u> {
477477+ Valid {
478478+ code: Cow<'u, str>,
479479+ state: uuid::Uuid,
480480+ iss: Option<Cow<'u, str>>,
481481+ },
482482+ Error(Error<'u>),
483483+ }
484484+485485+ let params =
486486+ Params::deserialize(serde_html_form::Deserializer::from_bytes(query.as_bytes()))?;
487487+ let (code, state, iss) = match params {
488488+ Params::Valid { code, state, iss } => (code, state, iss),
489489+ Params::Error(err) => return Ok(Err(err)),
490490+ };
491491+ if iss.as_ref().is_some_and(|iss| iss != issuer) {
492492+ // NB: it's unlikely to happen except via a misconfiguration, but technically this
493493+ // could cause us to leak our in_progress states
494494+ // per [rfc:draft-ietf-oauth-v2-1#7.14], we _must_ validate this if present
495495+ return Err(color_eyre::eyre::eyre!("issuer mismatch"));
496496+ }
497497+498498+ Ok(Ok(Response { code, state }))
499499+ }
500500+ }
501501+502502+ /// Step 3
503503+ pub mod token_request {
504504+ use color_eyre::Result;
505505+ use url::Url;
506506+507507+ use super::super::metadata::AuthServerMetadata;
508508+ use super::code_response::Response;
509509+510510+ /// slice of [`super::metadata::AuthServerMetadata`] needed for the initial
511511+ /// [`request_access_token`] call
512512+ pub struct Metadata<'u> {
513513+ token_endpoint: &'u Url,
514514+ }
515515+ impl<'u> From<&'u AuthServerMetadata> for Metadata<'u> {
516516+ fn from(orig: &'u AuthServerMetadata) -> Self {
517517+ Self {
518518+ token_endpoint: &orig.token_endpoint,
519519+ }
520520+ }
521521+ }
522522+523523+ pub struct Data<'u> {
524524+ pub code: Response<'u>, // owned, is consumed
525525+ pub client_id: &'u str,
526526+ pub client_secret: &'u str,
527527+ // grant type is hardcoded for this
528528+ pub code_verifier: String, // owned, consumed
529529+ pub redirect_uri: &'u str,
530530+ }
531531+532532+ pub struct Request<'u> {
533533+ pub url: &'u Url,
534534+ }
535535+536536+ /// Step 4: request an access token
537537+ pub fn request_access_token<'u, 'm: 'u, BODY: form_urlencoded::Target>(
538538+ meta: Metadata<'m>,
539539+ data: Data<'u>,
540540+ body: BODY,
541541+ ) -> Result<Request<'u>> {
542542+ let mut body = form_urlencoded::Serializer::new(body);
543543+ body.append_pair("grant_type", "authorization_code");
544544+ body.append_pair("code", &data.code.code);
545545+ body.append_pair("redirect_uri", data.redirect_uri); // oauth 2.0 only
546546+ body.append_pair("client_id", data.client_id);
547547+ body.append_pair("code_verifier", &data.code_verifier);
548548+ body.append_pair("client_secret", data.client_secret);
549549+550550+ Ok(Request {
551551+ url: meta.token_endpoint,
552552+ })
553553+ }
554554+ }
555555+556556+ /// Step 4
557557+ pub mod token_response {
558558+ use serde::Deserialize;
559559+ use std::borrow::Cow;
560560+561561+ #[derive(Deserialize)]
562562+ pub struct Valid<'u> {
563563+ /// required per
564564+ /// [oidc-core-1](https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.3.1.3.3)
565565+ /// when oidc is in play
566566+ pub id_token: Option<Cow<'u, str>>,
567567+ }
568568+569569+ #[derive(Deserialize)]
570570+ #[serde(rename_all = "snake_case")]
571571+ pub enum ErrorType<'u> {
572572+ // oauth 2.1 per [rfc:draft-ietf-oauth-v2-1#4.3.2]
573573+ InvalidRequest,
574574+ InvalidClient,
575575+ InvalidGrant,
576576+ UnauthorizedClient,
577577+ UnsupportedGrantType,
578578+ InvalidScope,
579579+580580+ // anything else a server happens to randomly shove in there
581581+ #[serde(untagged)]
582582+ #[allow(dead_code, reason = "deserialization purposes")]
583583+ Other(Cow<'u, str>),
584584+ }
585585+ #[derive(Deserialize)]
586586+ #[allow(
587587+ dead_code,
588588+ reason = "gonna make use of these shortly to surface errors better"
589589+ )]
590590+ pub struct Error<'u> {
591591+ pub error: ErrorType<'u>,
592592+ pub error_description: Option<Cow<'u, str>>,
593593+ pub error_uri: Option<Cow<'u, str>>,
594594+ }
595595+ }
596596+}