this repo has no description
1use crate::comms::{CommsChannel, channel_display_name, enqueue_2fa_code};
2use crate::oauth::{
3 Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, client::ClientMetadataCache, db, templates,
4};
5use crate::state::{AppState, RateLimitKind};
6use axum::{
7 Form, Json,
8 extract::{Query, State},
9 http::{
10 HeaderMap, StatusCode,
11 header::{LOCATION, SET_COOKIE},
12 },
13 response::{Html, IntoResponse, Redirect, Response},
14};
15use chrono::Utc;
16use serde::{Deserialize, Serialize};
17use subtle::ConstantTimeEq;
18use urlencoding::encode as url_encode;
19
/// Name of the cookie holding this browser's OAuth device identifier; the handlers use it
/// to look up (and link) accounts that have signed in on this device.
const DEVICE_COOKIE_NAME: &str = "oauth_device_id";
21
22fn redirect_see_other(uri: &str) -> Response {
23 (StatusCode::SEE_OTHER, [(LOCATION, uri.to_string())]).into_response()
24}
25
26fn extract_device_cookie(headers: &HeaderMap) -> Option<String> {
27 headers
28 .get("cookie")
29 .and_then(|v| v.to_str().ok())
30 .and_then(|cookie_str| {
31 for cookie in cookie_str.split(';') {
32 let cookie = cookie.trim();
33 if let Some(value) = cookie.strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) {
34 return Some(value.to_string());
35 }
36 }
37 None
38 })
39}
40
41fn extract_client_ip(headers: &HeaderMap) -> String {
42 if let Some(forwarded) = headers.get("x-forwarded-for")
43 && let Ok(value) = forwarded.to_str()
44 && let Some(first_ip) = value.split(',').next() {
45 return first_ip.trim().to_string();
46 }
47 if let Some(real_ip) = headers.get("x-real-ip")
48 && let Ok(value) = real_ip.to_str() {
49 return value.trim().to_string();
50 }
51 "0.0.0.0".to_string()
52}
53
54fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
55 headers
56 .get("user-agent")
57 .and_then(|v| v.to_str().ok())
58 .map(|s| s.to_string())
59}
60
61fn make_device_cookie(device_id: &str) -> String {
62 format!(
63 "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000",
64 DEVICE_COOKIE_NAME, device_id
65 )
66}
67
/// Query string accepted by the authorize page handlers (`authorize_get`,
/// `authorize_get_json`).
#[derive(Debug, Deserialize)]
pub struct AuthorizeQuery {
    /// `request_uri` issued by the PAR endpoint; the handlers reject requests without it.
    pub request_uri: Option<String>,
    /// NOTE(review): accepted but not read by the handlers in this file.
    pub client_id: Option<String>,
    /// When `Some(true)`, `authorize_get` skips the remembered-account picker and shows
    /// the login form.
    pub new_account: Option<bool>,
}
74
/// JSON description of a pending authorization request, returned to callers that ask for
/// `application/json`. Field order here is the JSON output order.
#[derive(Debug, Serialize)]
pub struct AuthorizeResponse {
    /// Client that initiated the request.
    pub client_id: String,
    /// Display name from the client's metadata, when it could be fetched.
    pub client_name: Option<String>,
    /// Requested scope, if one was supplied.
    pub scope: Option<String>,
    /// Where the user agent is sent once the request is approved or denied.
    pub redirect_uri: String,
    /// The client's opaque `state`, echoed back on redirect.
    pub state: Option<String>,
    /// Suggested account identifier used to prefill the login form.
    pub login_hint: Option<String>,
}
84
85#[derive(Debug, Deserialize)]
86pub struct AuthorizeSubmit {
87 pub request_uri: String,
88 pub username: String,
89 pub password: String,
90 #[serde(default)]
91 pub remember_device: bool,
92}
93
/// Form body posted when the user picks a remembered account on the account selector.
#[derive(Debug, Deserialize)]
pub struct AuthorizeSelectSubmit {
    /// The PAR `request_uri` being authorized.
    pub request_uri: String,
    /// DID of the chosen account; `authorize_select` checks it is linked to this device.
    pub did: String,
}
99
100fn wants_json(headers: &HeaderMap) -> bool {
101 headers
102 .get("accept")
103 .and_then(|v| v.to_str().ok())
104 .map(|accept| accept.contains("application/json"))
105 .unwrap_or(false)
106}
107
108pub async fn authorize_get(
109 State(state): State<AppState>,
110 headers: HeaderMap,
111 Query(query): Query<AuthorizeQuery>,
112) -> Response {
113 let request_uri = match query.request_uri {
114 Some(uri) => uri,
115 None => {
116 if wants_json(&headers) {
117 return (
118 axum::http::StatusCode::BAD_REQUEST,
119 Json(serde_json::json!({
120 "error": "invalid_request",
121 "error_description": "Missing request_uri parameter. Use PAR to initiate authorization."
122 })),
123 ).into_response();
124 }
125 return (
126 axum::http::StatusCode::BAD_REQUEST,
127 Html(templates::error_page(
128 "invalid_request",
129 Some("Missing request_uri parameter. Use PAR to initiate authorization."),
130 )),
131 )
132 .into_response();
133 }
134 };
135 let request_data = match db::get_authorization_request(&state.db, &request_uri).await {
136 Ok(Some(data)) => data,
137 Ok(None) => {
138 if wants_json(&headers) {
139 return (
140 axum::http::StatusCode::BAD_REQUEST,
141 Json(serde_json::json!({
142 "error": "invalid_request",
143 "error_description": "Invalid or expired request_uri. Please start a new authorization request."
144 })),
145 ).into_response();
146 }
147 return (
148 axum::http::StatusCode::BAD_REQUEST,
149 Html(templates::error_page(
150 "invalid_request",
151 Some(
152 "Invalid or expired request_uri. Please start a new authorization request.",
153 ),
154 )),
155 )
156 .into_response();
157 }
158 Err(e) => {
159 if wants_json(&headers) {
160 return (
161 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
162 Json(serde_json::json!({
163 "error": "server_error",
164 "error_description": format!("Database error: {:?}", e)
165 })),
166 )
167 .into_response();
168 }
169 return (
170 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
171 Html(templates::error_page(
172 "server_error",
173 Some(&format!("Database error: {:?}", e)),
174 )),
175 )
176 .into_response();
177 }
178 };
179 if request_data.expires_at < Utc::now() {
180 let _ = db::delete_authorization_request(&state.db, &request_uri).await;
181 if wants_json(&headers) {
182 return (
183 axum::http::StatusCode::BAD_REQUEST,
184 Json(serde_json::json!({
185 "error": "invalid_request",
186 "error_description": "Authorization request has expired. Please start a new request."
187 })),
188 ).into_response();
189 }
190 return (
191 axum::http::StatusCode::BAD_REQUEST,
192 Html(templates::error_page(
193 "invalid_request",
194 Some("Authorization request has expired. Please start a new request."),
195 )),
196 )
197 .into_response();
198 }
199 let client_cache = ClientMetadataCache::new(3600);
200 let client_name = client_cache
201 .get(&request_data.parameters.client_id)
202 .await
203 .ok()
204 .and_then(|m| m.client_name);
205 if wants_json(&headers) {
206 return Json(AuthorizeResponse {
207 client_id: request_data.parameters.client_id.clone(),
208 client_name: client_name.clone(),
209 scope: request_data.parameters.scope.clone(),
210 redirect_uri: request_data.parameters.redirect_uri.clone(),
211 state: request_data.parameters.state.clone(),
212 login_hint: request_data.parameters.login_hint.clone(),
213 })
214 .into_response();
215 }
216 let force_new_account = query.new_account.unwrap_or(false);
217 if !force_new_account
218 && let Some(device_id) = extract_device_cookie(&headers)
219 && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await
220 && !accounts.is_empty() {
221 let device_accounts: Vec<DeviceAccount> = accounts
222 .into_iter()
223 .map(|row| DeviceAccount {
224 did: row.did,
225 handle: row.handle,
226 email: row.email,
227 last_used_at: row.last_used_at,
228 })
229 .collect();
230 return Html(templates::account_selector_page(
231 &request_data.parameters.client_id,
232 client_name.as_deref(),
233 &request_uri,
234 &device_accounts,
235 ))
236 .into_response();
237 }
238 Html(templates::login_page(
239 &request_data.parameters.client_id,
240 client_name.as_deref(),
241 request_data.parameters.scope.as_deref(),
242 &request_uri,
243 None,
244 request_data.parameters.login_hint.as_deref(),
245 ))
246 .into_response()
247}
248
249pub async fn authorize_get_json(
250 State(state): State<AppState>,
251 Query(query): Query<AuthorizeQuery>,
252) -> Result<Json<AuthorizeResponse>, OAuthError> {
253 let request_uri = query
254 .request_uri
255 .ok_or_else(|| OAuthError::InvalidRequest("request_uri is required".to_string()))?;
256 let request_data = db::get_authorization_request(&state.db, &request_uri)
257 .await?
258 .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?;
259 if request_data.expires_at < Utc::now() {
260 db::delete_authorization_request(&state.db, &request_uri).await?;
261 return Err(OAuthError::InvalidRequest(
262 "request_uri has expired".to_string(),
263 ));
264 }
265 Ok(Json(AuthorizeResponse {
266 client_id: request_data.parameters.client_id.clone(),
267 client_name: None,
268 scope: request_data.parameters.scope.clone(),
269 redirect_uri: request_data.parameters.redirect_uri.clone(),
270 state: request_data.parameters.state.clone(),
271 login_hint: request_data.parameters.login_hint.clone(),
272 }))
273}
274
275pub async fn authorize_post(
276 State(state): State<AppState>,
277 headers: HeaderMap,
278 Form(form): Form<AuthorizeSubmit>,
279) -> Response {
280 let json_response = wants_json(&headers);
281 let client_ip = extract_client_ip(&headers);
282 if !state
283 .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip)
284 .await
285 {
286 tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded");
287 if json_response {
288 return (
289 axum::http::StatusCode::TOO_MANY_REQUESTS,
290 Json(serde_json::json!({
291 "error": "RateLimitExceeded",
292 "error_description": "Too many login attempts. Please try again later."
293 })),
294 )
295 .into_response();
296 }
297 return (
298 axum::http::StatusCode::TOO_MANY_REQUESTS,
299 Html(templates::error_page(
300 "RateLimitExceeded",
301 Some("Too many login attempts. Please try again later."),
302 )),
303 )
304 .into_response();
305 }
306 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
307 Ok(Some(data)) => data,
308 Ok(None) => {
309 if json_response {
310 return (
311 axum::http::StatusCode::BAD_REQUEST,
312 Json(serde_json::json!({
313 "error": "invalid_request",
314 "error_description": "Invalid or expired request_uri."
315 })),
316 )
317 .into_response();
318 }
319 return Html(templates::error_page(
320 "invalid_request",
321 Some("Invalid or expired request_uri. Please start a new authorization request."),
322 ))
323 .into_response();
324 }
325 Err(e) => {
326 if json_response {
327 return (
328 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
329 Json(serde_json::json!({
330 "error": "server_error",
331 "error_description": format!("Database error: {:?}", e)
332 })),
333 )
334 .into_response();
335 }
336 return Html(templates::error_page(
337 "server_error",
338 Some(&format!("Database error: {:?}", e)),
339 ))
340 .into_response();
341 }
342 };
343 if request_data.expires_at < Utc::now() {
344 let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
345 if json_response {
346 return (
347 axum::http::StatusCode::BAD_REQUEST,
348 Json(serde_json::json!({
349 "error": "invalid_request",
350 "error_description": "Authorization request has expired."
351 })),
352 )
353 .into_response();
354 }
355 return Html(templates::error_page(
356 "invalid_request",
357 Some("Authorization request has expired. Please start a new request."),
358 ))
359 .into_response();
360 }
361 let client_cache = ClientMetadataCache::new(3600);
362 let client_name = client_cache
363 .get(&request_data.parameters.client_id)
364 .await
365 .ok()
366 .and_then(|m| m.client_name);
367 let show_login_error = |error_msg: &str, json: bool| -> Response {
368 if json {
369 return (
370 axum::http::StatusCode::FORBIDDEN,
371 Json(serde_json::json!({
372 "error": "access_denied",
373 "error_description": error_msg
374 })),
375 )
376 .into_response();
377 }
378 Html(templates::login_page(
379 &request_data.parameters.client_id,
380 client_name.as_deref(),
381 request_data.parameters.scope.as_deref(),
382 &form.request_uri,
383 Some(error_msg),
384 Some(&form.username),
385 ))
386 .into_response()
387 };
388 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
389 let normalized_username = form.username.trim();
390 let normalized_username = normalized_username
391 .strip_prefix('@')
392 .unwrap_or(normalized_username);
393 let normalized_username = if let Some(bare_handle) =
394 normalized_username.strip_suffix(&format!(".{}", pds_hostname))
395 {
396 bare_handle.to_string()
397 } else {
398 normalized_username.to_string()
399 };
400 tracing::debug!(
401 original_username = %form.username,
402 normalized_username = %normalized_username,
403 pds_hostname = %pds_hostname,
404 "Normalized username for lookup"
405 );
406 let user = match sqlx::query!(
407 r#"
408 SELECT id, did, email, password_hash, two_factor_enabled,
409 preferred_comms_channel as "preferred_comms_channel: CommsChannel",
410 deactivated_at, takedown_ref,
411 email_verified, discord_verified, telegram_verified, signal_verified
412 FROM users
413 WHERE handle = $1 OR email = $1
414 "#,
415 normalized_username
416 )
417 .fetch_optional(&state.db)
418 .await
419 {
420 Ok(Some(u)) => u,
421 Ok(None) => {
422 let _ = bcrypt::verify(&form.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK");
423 return show_login_error("Invalid handle/email or password.", json_response);
424 }
425 Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
426 };
427 if user.deactivated_at.is_some() {
428 return show_login_error("This account has been deactivated.", json_response);
429 }
430 if user.takedown_ref.is_some() {
431 return show_login_error("This account has been taken down.", json_response);
432 }
433 let is_verified = user.email_verified
434 || user.discord_verified
435 || user.telegram_verified
436 || user.signal_verified;
437 if !is_verified {
438 return show_login_error("Please verify your account before logging in.", json_response);
439 }
440 let password_valid = match bcrypt::verify(&form.password, &user.password_hash) {
441 Ok(valid) => valid,
442 Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
443 };
444 if !password_valid {
445 return show_login_error("Invalid handle/email or password.", json_response);
446 }
447 if user.two_factor_enabled {
448 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
449 match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await {
450 Ok(challenge) => {
451 let hostname =
452 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
453 if let Err(e) =
454 enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await
455 {
456 tracing::warn!(
457 did = %user.did,
458 error = %e,
459 "Failed to enqueue 2FA notification"
460 );
461 }
462 let channel_name = channel_display_name(user.preferred_comms_channel);
463 let redirect_url = format!(
464 "/oauth/authorize/2fa?request_uri={}&channel={}",
465 url_encode(&form.request_uri),
466 url_encode(channel_name)
467 );
468 return Redirect::temporary(&redirect_url).into_response();
469 }
470 Err(_) => {
471 return show_login_error("An error occurred. Please try again.", json_response);
472 }
473 }
474 }
475 let code = Code::generate();
476 let mut device_id: Option<String> = extract_device_cookie(&headers);
477 let mut new_cookie: Option<String> = None;
478 if form.remember_device {
479 let final_device_id = if let Some(existing_id) = &device_id {
480 existing_id.clone()
481 } else {
482 let new_id = DeviceId::generate();
483 let device_data = DeviceData {
484 session_id: SessionId::generate().0,
485 user_agent: extract_user_agent(&headers),
486 ip_address: extract_client_ip(&headers),
487 last_seen_at: Utc::now(),
488 };
489 if db::create_device(&state.db, &new_id.0, &device_data)
490 .await
491 .is_ok()
492 {
493 new_cookie = Some(make_device_cookie(&new_id.0));
494 device_id = Some(new_id.0.clone());
495 }
496 new_id.0
497 };
498 let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await;
499 }
500 if db::update_authorization_request(
501 &state.db,
502 &form.request_uri,
503 &user.did,
504 device_id.as_deref(),
505 &code.0,
506 )
507 .await
508 .is_err()
509 {
510 return show_login_error("An error occurred. Please try again.", json_response);
511 }
512 let redirect_url = build_success_redirect(
513 &request_data.parameters.redirect_uri,
514 &code.0,
515 request_data.parameters.state.as_deref(),
516 );
517 if let Some(cookie) = new_cookie {
518 (
519 StatusCode::SEE_OTHER,
520 [(SET_COOKIE, cookie), (LOCATION, redirect_url)],
521 )
522 .into_response()
523 } else {
524 redirect_see_other(&redirect_url)
525 }
526}
527
528pub async fn authorize_select(
529 State(state): State<AppState>,
530 headers: HeaderMap,
531 Form(form): Form<AuthorizeSelectSubmit>,
532) -> Response {
533 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
534 Ok(Some(data)) => data,
535 Ok(None) => {
536 return Html(templates::error_page(
537 "invalid_request",
538 Some("Invalid or expired request_uri. Please start a new authorization request."),
539 ))
540 .into_response();
541 }
542 Err(_) => {
543 return Html(templates::error_page(
544 "server_error",
545 Some("An error occurred. Please try again."),
546 ))
547 .into_response();
548 }
549 };
550 if request_data.expires_at < Utc::now() {
551 let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
552 return Html(templates::error_page(
553 "invalid_request",
554 Some("Authorization request has expired. Please start a new request."),
555 ))
556 .into_response();
557 }
558 let device_id = match extract_device_cookie(&headers) {
559 Some(id) => id,
560 None => {
561 return Html(templates::error_page(
562 "invalid_request",
563 Some("No device session found. Please sign in."),
564 ))
565 .into_response();
566 }
567 };
568 let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await {
569 Ok(valid) => valid,
570 Err(_) => {
571 return Html(templates::error_page(
572 "server_error",
573 Some("An error occurred. Please try again."),
574 ))
575 .into_response();
576 }
577 };
578 if !account_valid {
579 return Html(templates::error_page(
580 "access_denied",
581 Some("This account is not available on this device. Please sign in."),
582 ))
583 .into_response();
584 }
585 let user = match sqlx::query!(
586 r#"
587 SELECT id, two_factor_enabled,
588 preferred_comms_channel as "preferred_comms_channel: CommsChannel",
589 email_verified, discord_verified, telegram_verified, signal_verified
590 FROM users
591 WHERE did = $1
592 "#,
593 form.did
594 )
595 .fetch_optional(&state.db)
596 .await
597 {
598 Ok(Some(u)) => u,
599 Ok(None) => {
600 return Html(templates::error_page(
601 "access_denied",
602 Some("Account not found. Please sign in."),
603 )).into_response();
604 }
605 Err(_) => {
606 return Html(templates::error_page(
607 "server_error",
608 Some("An error occurred. Please try again."),
609 )).into_response();
610 }
611 };
612 let is_verified = user.email_verified
613 || user.discord_verified
614 || user.telegram_verified
615 || user.signal_verified;
616 if !is_verified {
617 return Html(templates::error_page(
618 "access_denied",
619 Some("Please verify your account before logging in."),
620 ))
621 .into_response();
622 }
623 if user.two_factor_enabled {
624 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
625 match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await {
626 Ok(challenge) => {
627 let hostname =
628 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
629 if let Err(e) =
630 enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await
631 {
632 tracing::warn!(
633 did = %form.did,
634 error = %e,
635 "Failed to enqueue 2FA notification"
636 );
637 }
638 let channel_name = channel_display_name(user.preferred_comms_channel);
639 let redirect_url = format!(
640 "/oauth/authorize/2fa?request_uri={}&channel={}",
641 url_encode(&form.request_uri),
642 url_encode(channel_name)
643 );
644 return Redirect::temporary(&redirect_url).into_response();
645 }
646 Err(_) => {
647 return Html(templates::error_page(
648 "server_error",
649 Some("An error occurred. Please try again."),
650 ))
651 .into_response();
652 }
653 }
654 }
655 let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await;
656 let code = Code::generate();
657 if db::update_authorization_request(
658 &state.db,
659 &form.request_uri,
660 &form.did,
661 Some(&device_id),
662 &code.0,
663 )
664 .await
665 .is_err()
666 {
667 return Html(templates::error_page(
668 "server_error",
669 Some("An error occurred. Please try again."),
670 ))
671 .into_response();
672 }
673 let redirect_url = build_success_redirect(
674 &request_data.parameters.redirect_uri,
675 &code.0,
676 request_data.parameters.state.as_deref(),
677 );
678 redirect_see_other(&redirect_url)
679}
680
681fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String {
682 let mut redirect_url = redirect_uri.to_string();
683 let separator = if redirect_url.contains('?') { '&' } else { '?' };
684 redirect_url.push(separator);
685 redirect_url.push_str(&format!("code={}", url_encode(code)));
686 if let Some(req_state) = state {
687 redirect_url.push_str(&format!("&state={}", url_encode(req_state)));
688 }
689 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
690 redirect_url.push_str(&format!(
691 "&iss={}",
692 url_encode(&format!("https://{}", pds_hostname))
693 ));
694 redirect_url
695}
696
/// OAuth-style error body (`error` plus `error_description`).
///
/// NOTE(review): not constructed in the visible code; `authorize_deny` reports a denial via
/// redirect query parameters instead.
#[derive(Debug, Serialize)]
pub struct AuthorizeDenyResponse {
    /// OAuth error code, e.g. `access_denied`.
    pub error: String,
    /// Human-readable explanation of `error`.
    pub error_description: String,
}
702
703pub async fn authorize_deny(
704 State(state): State<AppState>,
705 Form(form): Form<AuthorizeDenyForm>,
706) -> Result<Response, OAuthError> {
707 let request_data = db::get_authorization_request(&state.db, &form.request_uri)
708 .await?
709 .ok_or_else(|| OAuthError::InvalidRequest("Invalid request_uri".to_string()))?;
710 db::delete_authorization_request(&state.db, &form.request_uri).await?;
711 let redirect_uri = &request_data.parameters.redirect_uri;
712 let mut redirect_url = redirect_uri.to_string();
713 let separator = if redirect_url.contains('?') { '&' } else { '?' };
714 redirect_url.push(separator);
715 redirect_url.push_str("error=access_denied");
716 redirect_url.push_str("&error_description=User%20denied%20the%20request");
717 if let Some(state) = &request_data.parameters.state {
718 redirect_url.push_str(&format!("&state={}", url_encode(state)));
719 }
720 Ok(redirect_see_other(&redirect_url))
721}
722
/// Form body posted when the user declines an authorization request.
#[derive(Debug, Deserialize)]
pub struct AuthorizeDenyForm {
    /// The PAR `request_uri` being declined.
    pub request_uri: String,
}
727
/// Query string for the 2FA code-entry page.
#[derive(Debug, Deserialize)]
pub struct Authorize2faQuery {
    /// Authorization request the 2FA challenge belongs to.
    pub request_uri: String,
    /// Display name of the channel the code was sent through; only shown to the user.
    /// `authorize_2fa_get` defaults it to `"email"`.
    pub channel: Option<String>,
}
733
/// Form body posted from the 2FA code-entry page.
///
/// NOTE(review): the derived `Debug` prints `code`; avoid logging this struct.
#[derive(Debug, Deserialize)]
pub struct Authorize2faSubmit {
    /// Authorization request the challenge belongs to.
    pub request_uri: String,
    /// Code as typed by the user; `authorize_2fa_post` trims surrounding whitespace.
    pub code: String,
}
739
/// Limit on failed code submissions per 2FA challenge. Once the recorded `attempts` reaches
/// this value, the next submission deletes the challenge and the user must start over.
const MAX_2FA_ATTEMPTS: i32 = 5;
741
742pub async fn authorize_2fa_get(
743 State(state): State<AppState>,
744 Query(query): Query<Authorize2faQuery>,
745) -> Response {
746 let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await {
747 Ok(Some(c)) => c,
748 Ok(None) => {
749 return Html(templates::error_page(
750 "invalid_request",
751 Some("No 2FA challenge found. Please start over."),
752 ))
753 .into_response();
754 }
755 Err(_) => {
756 return Html(templates::error_page(
757 "server_error",
758 Some("An error occurred. Please try again."),
759 ))
760 .into_response();
761 }
762 };
763 if challenge.expires_at < Utc::now() {
764 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
765 return Html(templates::error_page(
766 "invalid_request",
767 Some("2FA code has expired. Please start over."),
768 ))
769 .into_response();
770 }
771 let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await {
772 Ok(Some(d)) => d,
773 Ok(None) => {
774 return Html(templates::error_page(
775 "invalid_request",
776 Some("Authorization request not found. Please start over."),
777 ))
778 .into_response();
779 }
780 Err(_) => {
781 return Html(templates::error_page(
782 "server_error",
783 Some("An error occurred. Please try again."),
784 ))
785 .into_response();
786 }
787 };
788 let channel = query.channel.as_deref().unwrap_or("email");
789 Html(templates::two_factor_page(
790 &query.request_uri,
791 channel,
792 None,
793 ))
794 .into_response()
795}
796
797pub async fn authorize_2fa_post(
798 State(state): State<AppState>,
799 headers: HeaderMap,
800 Form(form): Form<Authorize2faSubmit>,
801) -> Response {
802 let client_ip = extract_client_ip(&headers);
803 if !state
804 .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip)
805 .await
806 {
807 tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded");
808 return (
809 axum::http::StatusCode::TOO_MANY_REQUESTS,
810 Html(templates::error_page(
811 "RateLimitExceeded",
812 Some("Too many attempts. Please try again later."),
813 )),
814 )
815 .into_response();
816 }
817 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await {
818 Ok(Some(c)) => c,
819 Ok(None) => {
820 return Html(templates::error_page(
821 "invalid_request",
822 Some("No 2FA challenge found. Please start over."),
823 ))
824 .into_response();
825 }
826 Err(_) => {
827 return Html(templates::error_page(
828 "server_error",
829 Some("An error occurred. Please try again."),
830 ))
831 .into_response();
832 }
833 };
834 if challenge.expires_at < Utc::now() {
835 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
836 return Html(templates::error_page(
837 "invalid_request",
838 Some("2FA code has expired. Please start over."),
839 ))
840 .into_response();
841 }
842 if challenge.attempts >= MAX_2FA_ATTEMPTS {
843 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
844 return Html(templates::error_page(
845 "access_denied",
846 Some("Too many failed attempts. Please start over."),
847 ))
848 .into_response();
849 }
850 let code_valid: bool = form
851 .code
852 .trim()
853 .as_bytes()
854 .ct_eq(challenge.code.as_bytes())
855 .into();
856 if !code_valid {
857 let _ = db::increment_2fa_attempts(&state.db, challenge.id).await;
858 let channel = match sqlx::query_scalar!(
859 r#"SELECT preferred_comms_channel as "channel: CommsChannel" FROM users WHERE did = $1"#,
860 challenge.did
861 )
862 .fetch_optional(&state.db)
863 .await
864 {
865 Ok(Some(ch)) => channel_display_name(ch).to_string(),
866 Ok(None) | Err(_) => "email".to_string(),
867 };
868 let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await
869 {
870 Ok(Some(d)) => d,
871 Ok(None) => {
872 return Html(templates::error_page(
873 "invalid_request",
874 Some("Authorization request not found. Please start over."),
875 ))
876 .into_response();
877 }
878 Err(_) => {
879 return Html(templates::error_page(
880 "server_error",
881 Some("An error occurred. Please try again."),
882 ))
883 .into_response();
884 }
885 };
886 return Html(templates::two_factor_page(
887 &form.request_uri,
888 &channel,
889 Some("Invalid verification code. Please try again."),
890 ))
891 .into_response();
892 }
893 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
894 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
895 Ok(Some(d)) => d,
896 Ok(None) => {
897 return Html(templates::error_page(
898 "invalid_request",
899 Some("Authorization request not found."),
900 ))
901 .into_response();
902 }
903 Err(_) => {
904 return Html(templates::error_page(
905 "server_error",
906 Some("An error occurred."),
907 ))
908 .into_response();
909 }
910 };
911 let code = Code::generate();
912 let device_id = extract_device_cookie(&headers);
913 if db::update_authorization_request(
914 &state.db,
915 &form.request_uri,
916 &challenge.did,
917 device_id.as_deref(),
918 &code.0,
919 )
920 .await
921 .is_err()
922 {
923 return Html(templates::error_page(
924 "server_error",
925 Some("An error occurred. Please try again."),
926 ))
927 .into_response();
928 }
929 let redirect_url = build_success_redirect(
930 &request_data.parameters.redirect_uri,
931 &code.0,
932 request_data.parameters.state.as_deref(),
933 );
934 redirect_see_other(&redirect_url)
935}