// Forked from smokesignal.events/smokesignal
// (i18n+filtering fork - fluent-templates v2).
1use anyhow::Result;
2use axum::{
3 extract::State,
4 response::{IntoResponse, Redirect},
5};
6use axum_extra::extract::{
7 cookie::{Cookie, SameSite},
8 Form, PrivateCookieJar,
9};
10use deadpool_redis::redis::AsyncCommands as _;
11use minijinja::context as template_context;
12use p256::SecretKey;
13use serde::{Deserialize, Serialize};
14use std::borrow::Cow;
15
16use crate::jose_errors::JwkError;
17use crate::storage::errors::CacheError;
18
19use crate::{
20 contextual_error,
21 oauth::oauth_complete,
22 select_template,
23 storage::{
24 cache::OAUTH_REFRESH_QUEUE,
25 handle::handle_for_did,
26 oauth::{oauth_request_get, oauth_request_remove, oauth_session_insert},
27 },
28};
29
30use super::{
31 context::WebContext,
32 errors::{LoginError, WebError},
33 middleware_auth::{WebSession, AUTH_COOKIE_NAME},
34 middleware_i18n::Language,
35};
36
37#[derive(Deserialize, Serialize)]
38pub struct OAuthCallbackForm {
39 pub state: Option<String>,
40 pub iss: Option<String>,
41 pub code: Option<String>,
42}
43
44pub async fn handle_oauth_callback(
45 State(web_context): State<WebContext>,
46 Language(language): Language,
47 jar: PrivateCookieJar,
48 Form(callback_form): Form<OAuthCallbackForm>,
49) -> Result<impl IntoResponse, WebError> {
50 let default_context = template_context! {
51 language => language.to_string(),
52 canonical_url => format!("https://{}/oauth/callback", web_context.config.external_base),
53 };
54
55 let error_template = select_template!(false, false, language);
56
57 let (callback_code, callback_iss, callback_state) =
58 match (callback_form.code, callback_form.iss, callback_form.state) {
59 (Some(x), Some(y), Some(z)) => (x, y, z),
60 _ => {
61 return contextual_error!(
62 web_context,
63 language,
64 error_template,
65 default_context,
66 LoginError::OAuthCallbackIncomplete
67 );
68 }
69 };
70
71 let oauth_request = oauth_request_get(&web_context.pool, &callback_state).await;
72 if let Err(err) = oauth_request {
73 return contextual_error!(web_context, language, error_template, default_context, err);
74 }
75
76 let oauth_request = oauth_request.unwrap();
77
78 if oauth_request.issuer != callback_iss {
79 return contextual_error!(
80 web_context,
81 language,
82 error_template,
83 default_context,
84 LoginError::OAuthIssuerMismatch
85 );
86 }
87
88 let handle = handle_for_did(&web_context.pool, &oauth_request.did).await;
89 if let Err(err) = handle {
90 return contextual_error!(web_context, language, error_template, default_context, err);
91 }
92
93 let handle = handle.unwrap();
94
95 let secret_signing_key = web_context
96 .config
97 .signing_keys
98 .as_ref()
99 .get(&oauth_request.secret_jwk_id)
100 .cloned()
101 .ok_or(JwkError::SecretKeyNotFound);
102
103 if let Err(err) = secret_signing_key {
104 return contextual_error!(web_context, language, error_template, default_context, err);
105 }
106 let secret_signing_key = secret_signing_key.unwrap();
107
108 let dpop_secret_key = SecretKey::from_jwk(&oauth_request.dpop_jwk.jwk);
109
110 if let Err(err) = dpop_secret_key {
111 return contextual_error!(web_context, language, error_template, default_context, err);
112 }
113 let dpop_secret_key = dpop_secret_key.unwrap();
114
115 let token_response = oauth_complete(
116 &web_context.http_client,
117 &web_context.config.external_base,
118 (&oauth_request.secret_jwk_id, secret_signing_key),
119 &callback_code,
120 &oauth_request,
121 &handle,
122 &dpop_secret_key,
123 )
124 .await;
125 if let Err(err) = token_response {
126 return contextual_error!(web_context, language, error_template, default_context, err);
127 }
128
129 let token_response = token_response.unwrap();
130
131 if let Err(err) = oauth_request_remove(&web_context.pool, &oauth_request.oauth_state).await {
132 tracing::error!(error = ?err, "Unable to remove oauth_request");
133 }
134
135 let session_group = ulid::Ulid::new().to_string();
136 let now = chrono::Utc::now();
137
138 if let Err(err) = oauth_session_insert(
139 &web_context.pool,
140 crate::storage::oauth::OAuthSessionParams {
141 session_group: Cow::Owned(session_group.clone()),
142 access_token: Cow::Owned(token_response.access_token.clone()),
143 did: Cow::Owned(token_response.sub.clone()),
144 issuer: Cow::Owned(oauth_request.issuer.clone()),
145 refresh_token: Cow::Owned(token_response.refresh_token.clone()),
146 secret_jwk_id: Cow::Owned(oauth_request.secret_jwk_id.clone()),
147 dpop_jwk: oauth_request.dpop_jwk.0.clone(),
148 created_at: now,
149 access_token_expires_at: now
150 + chrono::Duration::seconds(token_response.expires_in as i64),
151 },
152 )
153 .await
154 {
155 return contextual_error!(web_context, language, error_template, default_context, err);
156 }
157
158 {
159 let mut conn = web_context
160 .cache_pool
161 .get()
162 .await
163 .map_err(CacheError::FailedToGetConnection)?;
164
165 let modified_expires_at = ((token_response.expires_in as f64) * 0.8).round() as i64;
166 let refresh_at = (now + chrono::Duration::seconds(modified_expires_at)).timestamp_millis();
167
168 let _: () = conn
169 .zadd(OAUTH_REFRESH_QUEUE, &session_group, refresh_at)
170 .await
171 .map_err(CacheError::FailedToPlaceInRefreshQueue)?;
172 }
173
174 let cookie_value: String = WebSession {
175 did: token_response.sub.clone(),
176 session_group: session_group.clone(),
177 }
178 .try_into()?;
179
180 let mut cookie = Cookie::new(AUTH_COOKIE_NAME, cookie_value);
181 cookie.set_domain(web_context.config.external_base.clone());
182 cookie.set_path("/");
183 cookie.set_http_only(true);
184 cookie.set_secure(true);
185 cookie.set_max_age(Some(cookie::time::Duration::days(1)));
186 cookie.set_same_site(Some(SameSite::Lax));
187
188 let updated_jar = jar.add(cookie);
189
190 let destination = match oauth_request.destination {
191 Some(destination) => destination,
192 None => "/".to_string(),
193 };
194
195 Ok((updated_jar, Redirect::to(&destination)).into_response())
196}