forked from
smokesignal.events/smokesignal
i18n+filtering fork - fluent-templates v2
1use anyhow::Result;
2use axum::response::Redirect;
3use axum::{extract::State, response::IntoResponse};
4use axum_extra::extract::{Cached, Form, Query};
5use axum_htmx::{HxBoosted, HxRedirect, HxRequest};
6use base64::{engine::general_purpose, Engine as _};
7use http::StatusCode;
8use minijinja::context as template_context;
9use p256::SecretKey;
10use rand::{distributions::Alphanumeric, Rng};
11use serde::Deserialize;
12use sha2::{Digest, Sha256};
13use std::borrow::Cow;
14
15use crate::{
16 contextual_error, create_renderer,
17 did::{plc::query as plc_query, web::query as web_query},
18 http::{
19 context::WebContext,
20 errors::{CommonError, LoginError, WebError},
21 middleware_auth::Auth,
22 middleware_i18n::Language,
23 utils::stringify,
24 },
25 jose,
26 oauth::{oauth_init, pds_resources},
27 resolve::{parse_input, resolve_subject, InputType},
28 storage::{
29 denylist::denylist_exists,
30 handle::handle_warm_up,
31 oauth::{model::OAuthRequestState, oauth_request_insert},
32 },
33};
34
35#[derive(Deserialize)]
36pub struct OAuthLoginForm {
37 pub handle: Option<String>,
38 pub destination: Option<String>,
39}
40
41#[derive(Deserialize)]
42pub struct Destination {
43 pub destination: Option<String>,
44}
45
46pub async fn handle_oauth_login(
47 State(web_context): State<WebContext>,
48 Language(language): Language,
49 Cached(auth): Cached<Auth>,
50 HxRequest(hx_request): HxRequest,
51 HxBoosted(hx_boosted): HxBoosted,
52 Query(destination): Query<Destination>,
53 Form(login_form): Form<OAuthLoginForm>,
54) -> Result<impl IntoResponse, WebError> {
55 // Create the template renderer with enhanced context
56 let renderer = create_renderer!(web_context.clone(), Language(language.clone()), hx_boosted, hx_request);
57
58 let canonical_url = format!("https://{}/oauth/login", web_context.config.external_base);
59 let current_handle = auth.0.as_ref();
60 let is_development = cfg!(debug_assertions);
61
62 // Create comprehensive default context like the original
63 let default_context = template_context! {
64 current_handle => current_handle,
65 language => language.to_string(),
66 canonical_url => canonical_url.clone(),
67 destination => destination.destination,
68 is_development => is_development,
69 };
70
71 if let Some(subject) = login_form.handle {
72 let resolved_did = resolve_subject(
73 &web_context.http_client,
74 &web_context.dns_resolver,
75 &subject,
76 )
77 .await;
78
79 if let Err(err) = resolved_did {
80 let error_context = template_context! {
81 is_development,
82 oauth_login => true,
83 has_error => true,
84 handle_input => subject,
85 };
86 return contextual_error!(renderer: renderer, err, error_context);
87 }
88
89 let resolved_did = resolved_did.unwrap();
90
91 let query_results = match parse_input(&resolved_did) {
92 Ok(InputType::Plc(did)) => {
93 plc_query(
94 &web_context.http_client,
95 &web_context.config.plc_hostname,
96 &did,
97 )
98 .await
99 }
100 Ok(InputType::Web(did)) => web_query(&web_context.http_client, &did).await,
101 _ => Err(LoginError::NoHandle.into()),
102 };
103
104 let did_document = match query_results {
105 Ok(value) => value,
106 Err(err) => {
107 let error_context = template_context! {
108 is_development,
109 oauth_login => true,
110 has_error => true,
111 handle_input => subject,
112 };
113 return contextual_error!(renderer: renderer, err, error_context);
114 }
115 };
116
117 let mut lookup_values: Vec<&str> = vec![&resolved_did, &did_document.id];
118 if let Some(pds) = did_document.pds_endpoint() {
119 lookup_values.push(pds);
120 }
121
122 let handle_denied = match denylist_exists(&web_context.pool, &lookup_values).await {
123 Ok(value) => value,
124 Err(err) => {
125 return contextual_error!(
126 renderer: renderer,
127 err,
128 template_context! {
129 is_development,
130 oauth_login => true,
131 has_error => true,
132 handle_input => subject,
133 }
134 );
135 }
136 };
137
138 if handle_denied {
139 return contextual_error!(
140 renderer: renderer,
141 CommonError::NotAuthorized,
142 template_context! {
143 is_development,
144 oauth_login => true,
145 has_error => true,
146 handle_input => subject,
147 }
148 );
149 }
150
151 let pds = match did_document.pds_endpoint() {
152 Some(value) => value,
153 None => {
154 let error_context = template_context! {
155 is_development,
156 oauth_login => true,
157 has_error => true,
158 handle_input => subject,
159 };
160 return contextual_error!(renderer: renderer, LoginError::NoPDS, error_context);
161 }
162 };
163
164 let primary_handle = match did_document.primary_handle() {
165 Some(value) => value,
166 None => {
167 let error_context = template_context! {
168 is_development,
169 oauth_login => true,
170 has_error => true,
171 handle_input => subject,
172 };
173 return contextual_error!(renderer: renderer, LoginError::NoHandle, error_context);
174 }
175 };
176
177 if let Err(err) =
178 handle_warm_up(&web_context.pool, &did_document.id, primary_handle, pds).await
179 {
180 return contextual_error!(renderer: renderer, err, default_context);
181 }
182
183 let state: String = rand::thread_rng()
184 .sample_iter(&Alphanumeric)
185 .take(30)
186 .map(char::from)
187 .collect();
188 let nonce: String = rand::thread_rng()
189 .sample_iter(&Alphanumeric)
190 .take(30)
191 .map(char::from)
192 .collect();
193 let (pkce_verifier, code_challenge) = gen_pkce();
194
195 let oauth_request_state = OAuthRequestState {
196 state,
197 nonce,
198 code_challenge,
199 };
200
201 let pds_auth_resources = pds_resources(&web_context.http_client, pds).await;
202
203 if let Err(err) = pds_auth_resources {
204 return contextual_error!(
205 renderer: renderer,
206 err,
207 template_context! {
208 is_development,
209 oauth_login => true,
210 has_error => true,
211 handle_input => subject,
212 }
213 );
214 }
215
216 let (_, authorization_server) = pds_auth_resources.unwrap();
217 tracing::info!(authorization_server = ?authorization_server, "resolved authorization server");
218
219 let signing_key = web_context.config.select_oauth_signing_key();
220 if let Err(err) = signing_key {
221 return contextual_error!(
222 renderer: renderer,
223 err,
224 template_context! {
225 is_development,
226 oauth_login => true,
227 has_error => true,
228 handle_input => subject,
229 }
230 );
231 }
232
233 let (key_id, signing_key) = signing_key.unwrap();
234
235 let dpop_jwk = jose::jwk::generate();
236 let dpop_secret_key = SecretKey::from_jwk(&dpop_jwk.jwk);
237
238 if let Err(err) = dpop_secret_key {
239 return contextual_error!(
240 renderer: renderer,
241 err,
242 template_context! {
243 is_development,
244 oauth_login => true,
245 has_error => true,
246 handle_input => subject,
247 }
248 );
249 }
250
251 let dpop_secret_key = dpop_secret_key.unwrap();
252
253 let par_response = oauth_init(
254 &web_context.http_client,
255 &web_context.config.external_base,
256 (&key_id, signing_key),
257 &dpop_secret_key,
258 primary_handle,
259 &authorization_server,
260 &oauth_request_state,
261 )
262 .await;
263
264 if let Err(err) = par_response {
265 return contextual_error!(
266 renderer: renderer,
267 err,
268 template_context! {
269 is_development,
270 oauth_login => true,
271 has_error => true,
272 handle_input => subject,
273 }
274 );
275 }
276
277 let par_response = par_response.unwrap();
278
279 let created_at = chrono::Utc::now();
280 let expires_at = created_at + chrono::Duration::seconds(par_response.expires_in as i64);
281
282 if let Err(err) = oauth_request_insert(
283 &web_context.pool,
284 crate::storage::oauth::OAuthRequestParams {
285 oauth_state: Cow::Owned(oauth_request_state.state.clone()),
286 issuer: Cow::Owned(authorization_server.issuer.clone()),
287 did: Cow::Owned(did_document.id.clone()),
288 nonce: Cow::Owned(oauth_request_state.nonce.clone()),
289 pkce_verifier: Cow::Owned(pkce_verifier.clone()),
290 secret_jwk_id: Cow::Owned(key_id.clone()),
291 dpop_jwk: Some(dpop_jwk.clone()),
292 destination: login_form.destination.clone().map(Cow::Owned),
293 created_at,
294 expires_at,
295 },
296 )
297 .await
298 {
299 return contextual_error!(
300 renderer: renderer,
301 err,
302 template_context! {
303 is_development,
304 oauth_login => true,
305 has_error => true,
306 handle_input => subject,
307 }
308 );
309 }
310
311 let oauth_args = [
312 (
313 "request_uri".to_string(),
314 urlencoding::encode(&par_response.request_uri).to_string(),
315 ),
316 (
317 "client_id".to_string(),
318 urlencoding::encode(&format!(
319 "https://{}/oauth/client-metadata.json",
320 web_context.config.external_base
321 ))
322 .to_string(),
323 ),
324 ];
325 let oauth_args = oauth_args.iter().map(|(k, v)| (&**k, &**v)).collect();
326
327 let destination = format!(
328 "{}?{}",
329 authorization_server.authorization_endpoint,
330 stringify(oauth_args)
331 );
332
333 if hx_request {
334 if let Ok(hx_redirect) = HxRedirect::try_from(destination.as_str()) {
335 return Ok((StatusCode::OK, hx_redirect, "").into_response());
336 }
337 }
338
339 return Ok(Redirect::temporary(destination.as_str()).into_response());
340 }
341
342 let final_context = template_context! {
343 is_development,
344 oauth_login => true,
345 destination => destination.destination,
346 };
347
348 Ok(renderer.render_template(
349 "login",
350 final_context,
351 current_handle,
352 &canonical_url,
353 ))
354}
355
356pub fn gen_pkce() -> (String, String) {
357 let token: String = rand::thread_rng()
358 .sample_iter(&Alphanumeric)
359 .take(100)
360 .map(char::from)
361 .collect();
362 (token.clone(), pkce_challenge(&token))
363}
364
365pub fn pkce_challenge(token: &str) -> String {
366 let mut hasher = Sha256::new();
367 hasher.update(token.as_bytes());
368 let result = hasher.finalize();
369
370 general_purpose::URL_SAFE_NO_PAD.encode(result)
371}