This repository has no description.
1use crate::notifications::{NotificationChannel, channel_display_name, enqueue_2fa_code};
2use crate::oauth::{
3 Code, DeviceAccount, DeviceData, DeviceId, OAuthError, SessionId, client::ClientMetadataCache, db, templates,
4};
5use crate::state::{AppState, RateLimitKind};
6use axum::{
7 Form, Json,
8 extract::{Query, State},
9 http::{
10 HeaderMap, StatusCode,
11 header::{LOCATION, SET_COOKIE},
12 },
13 response::{Html, IntoResponse, Redirect, Response},
14};
15use chrono::Utc;
16use serde::{Deserialize, Serialize};
17use subtle::ConstantTimeEq;
18use urlencoding::encode as url_encode;
19
/// Name of the cookie that carries this browser's OAuth device identifier.
const DEVICE_COOKIE_NAME: &str = "oauth_device_id";
21
22fn redirect_see_other(uri: &str) -> Response {
23 (StatusCode::SEE_OTHER, [(LOCATION, uri.to_string())]).into_response()
24}
25
26fn extract_device_cookie(headers: &HeaderMap) -> Option<String> {
27 headers
28 .get("cookie")
29 .and_then(|v| v.to_str().ok())
30 .and_then(|cookie_str| {
31 for cookie in cookie_str.split(';') {
32 let cookie = cookie.trim();
33 if let Some(value) = cookie.strip_prefix(&format!("{}=", DEVICE_COOKIE_NAME)) {
34 return Some(value.to_string());
35 }
36 }
37 None
38 })
39}
40
41fn extract_client_ip(headers: &HeaderMap) -> String {
42 if let Some(forwarded) = headers.get("x-forwarded-for")
43 && let Ok(value) = forwarded.to_str()
44 && let Some(first_ip) = value.split(',').next() {
45 return first_ip.trim().to_string();
46 }
47 if let Some(real_ip) = headers.get("x-real-ip")
48 && let Ok(value) = real_ip.to_str() {
49 return value.trim().to_string();
50 }
51 "0.0.0.0".to_string()
52}
53
54fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
55 headers
56 .get("user-agent")
57 .and_then(|v| v.to_str().ok())
58 .map(|s| s.to_string())
59}
60
61fn make_device_cookie(device_id: &str) -> String {
62 format!(
63 "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000",
64 DEVICE_COOKIE_NAME, device_id
65 )
66}
67
68#[derive(Debug, Deserialize)]
69pub struct AuthorizeQuery {
70 pub request_uri: Option<String>,
71 pub client_id: Option<String>,
72 pub new_account: Option<bool>,
73}
74
75#[derive(Debug, Serialize)]
76pub struct AuthorizeResponse {
77 pub client_id: String,
78 pub client_name: Option<String>,
79 pub scope: Option<String>,
80 pub redirect_uri: String,
81 pub state: Option<String>,
82 pub login_hint: Option<String>,
83}
84
85#[derive(Debug, Deserialize)]
86pub struct AuthorizeSubmit {
87 pub request_uri: String,
88 pub username: String,
89 pub password: String,
90 #[serde(default)]
91 pub remember_device: bool,
92}
93
94#[derive(Debug, Deserialize)]
95pub struct AuthorizeSelectSubmit {
96 pub request_uri: String,
97 pub did: String,
98}
99
100fn wants_json(headers: &HeaderMap) -> bool {
101 headers
102 .get("accept")
103 .and_then(|v| v.to_str().ok())
104 .map(|accept| accept.contains("application/json"))
105 .unwrap_or(false)
106}
107
108pub async fn authorize_get(
109 State(state): State<AppState>,
110 headers: HeaderMap,
111 Query(query): Query<AuthorizeQuery>,
112) -> Response {
113 let request_uri = match query.request_uri {
114 Some(uri) => uri,
115 None => {
116 if wants_json(&headers) {
117 return (
118 axum::http::StatusCode::BAD_REQUEST,
119 Json(serde_json::json!({
120 "error": "invalid_request",
121 "error_description": "Missing request_uri parameter. Use PAR to initiate authorization."
122 })),
123 ).into_response();
124 }
125 return (
126 axum::http::StatusCode::BAD_REQUEST,
127 Html(templates::error_page(
128 "invalid_request",
129 Some("Missing request_uri parameter. Use PAR to initiate authorization."),
130 )),
131 )
132 .into_response();
133 }
134 };
135 let request_data = match db::get_authorization_request(&state.db, &request_uri).await {
136 Ok(Some(data)) => data,
137 Ok(None) => {
138 if wants_json(&headers) {
139 return (
140 axum::http::StatusCode::BAD_REQUEST,
141 Json(serde_json::json!({
142 "error": "invalid_request",
143 "error_description": "Invalid or expired request_uri. Please start a new authorization request."
144 })),
145 ).into_response();
146 }
147 return (
148 axum::http::StatusCode::BAD_REQUEST,
149 Html(templates::error_page(
150 "invalid_request",
151 Some(
152 "Invalid or expired request_uri. Please start a new authorization request.",
153 ),
154 )),
155 )
156 .into_response();
157 }
158 Err(e) => {
159 if wants_json(&headers) {
160 return (
161 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
162 Json(serde_json::json!({
163 "error": "server_error",
164 "error_description": format!("Database error: {:?}", e)
165 })),
166 )
167 .into_response();
168 }
169 return (
170 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
171 Html(templates::error_page(
172 "server_error",
173 Some(&format!("Database error: {:?}", e)),
174 )),
175 )
176 .into_response();
177 }
178 };
179 if request_data.expires_at < Utc::now() {
180 let _ = db::delete_authorization_request(&state.db, &request_uri).await;
181 if wants_json(&headers) {
182 return (
183 axum::http::StatusCode::BAD_REQUEST,
184 Json(serde_json::json!({
185 "error": "invalid_request",
186 "error_description": "Authorization request has expired. Please start a new request."
187 })),
188 ).into_response();
189 }
190 return (
191 axum::http::StatusCode::BAD_REQUEST,
192 Html(templates::error_page(
193 "invalid_request",
194 Some("Authorization request has expired. Please start a new request."),
195 )),
196 )
197 .into_response();
198 }
199 let client_cache = ClientMetadataCache::new(3600);
200 let client_name = client_cache
201 .get(&request_data.parameters.client_id)
202 .await
203 .ok()
204 .and_then(|m| m.client_name);
205 if wants_json(&headers) {
206 return Json(AuthorizeResponse {
207 client_id: request_data.parameters.client_id.clone(),
208 client_name: client_name.clone(),
209 scope: request_data.parameters.scope.clone(),
210 redirect_uri: request_data.parameters.redirect_uri.clone(),
211 state: request_data.parameters.state.clone(),
212 login_hint: request_data.parameters.login_hint.clone(),
213 })
214 .into_response();
215 }
216 let force_new_account = query.new_account.unwrap_or(false);
217 if !force_new_account
218 && let Some(device_id) = extract_device_cookie(&headers)
219 && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await
220 && !accounts.is_empty() {
221 let device_accounts: Vec<DeviceAccount> = accounts
222 .into_iter()
223 .map(|row| DeviceAccount {
224 did: row.did,
225 handle: row.handle,
226 email: row.email,
227 last_used_at: row.last_used_at,
228 })
229 .collect();
230 return Html(templates::account_selector_page(
231 &request_data.parameters.client_id,
232 client_name.as_deref(),
233 &request_uri,
234 &device_accounts,
235 ))
236 .into_response();
237 }
238 Html(templates::login_page(
239 &request_data.parameters.client_id,
240 client_name.as_deref(),
241 request_data.parameters.scope.as_deref(),
242 &request_uri,
243 None,
244 request_data.parameters.login_hint.as_deref(),
245 ))
246 .into_response()
247}
248
249pub async fn authorize_get_json(
250 State(state): State<AppState>,
251 Query(query): Query<AuthorizeQuery>,
252) -> Result<Json<AuthorizeResponse>, OAuthError> {
253 let request_uri = query
254 .request_uri
255 .ok_or_else(|| OAuthError::InvalidRequest("request_uri is required".to_string()))?;
256 let request_data = db::get_authorization_request(&state.db, &request_uri)
257 .await?
258 .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?;
259 if request_data.expires_at < Utc::now() {
260 db::delete_authorization_request(&state.db, &request_uri).await?;
261 return Err(OAuthError::InvalidRequest(
262 "request_uri has expired".to_string(),
263 ));
264 }
265 Ok(Json(AuthorizeResponse {
266 client_id: request_data.parameters.client_id.clone(),
267 client_name: None,
268 scope: request_data.parameters.scope.clone(),
269 redirect_uri: request_data.parameters.redirect_uri.clone(),
270 state: request_data.parameters.state.clone(),
271 login_hint: request_data.parameters.login_hint.clone(),
272 }))
273}
274
275pub async fn authorize_post(
276 State(state): State<AppState>,
277 headers: HeaderMap,
278 Form(form): Form<AuthorizeSubmit>,
279) -> Response {
280 let json_response = wants_json(&headers);
281 let client_ip = extract_client_ip(&headers);
282 if !state
283 .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip)
284 .await
285 {
286 tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded");
287 if json_response {
288 return (
289 axum::http::StatusCode::TOO_MANY_REQUESTS,
290 Json(serde_json::json!({
291 "error": "RateLimitExceeded",
292 "error_description": "Too many login attempts. Please try again later."
293 })),
294 )
295 .into_response();
296 }
297 return (
298 axum::http::StatusCode::TOO_MANY_REQUESTS,
299 Html(templates::error_page(
300 "RateLimitExceeded",
301 Some("Too many login attempts. Please try again later."),
302 )),
303 )
304 .into_response();
305 }
306 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
307 Ok(Some(data)) => data,
308 Ok(None) => {
309 if json_response {
310 return (
311 axum::http::StatusCode::BAD_REQUEST,
312 Json(serde_json::json!({
313 "error": "invalid_request",
314 "error_description": "Invalid or expired request_uri."
315 })),
316 )
317 .into_response();
318 }
319 return Html(templates::error_page(
320 "invalid_request",
321 Some("Invalid or expired request_uri. Please start a new authorization request."),
322 ))
323 .into_response();
324 }
325 Err(e) => {
326 if json_response {
327 return (
328 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
329 Json(serde_json::json!({
330 "error": "server_error",
331 "error_description": format!("Database error: {:?}", e)
332 })),
333 )
334 .into_response();
335 }
336 return Html(templates::error_page(
337 "server_error",
338 Some(&format!("Database error: {:?}", e)),
339 ))
340 .into_response();
341 }
342 };
343 if request_data.expires_at < Utc::now() {
344 let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
345 if json_response {
346 return (
347 axum::http::StatusCode::BAD_REQUEST,
348 Json(serde_json::json!({
349 "error": "invalid_request",
350 "error_description": "Authorization request has expired."
351 })),
352 )
353 .into_response();
354 }
355 return Html(templates::error_page(
356 "invalid_request",
357 Some("Authorization request has expired. Please start a new request."),
358 ))
359 .into_response();
360 }
361 let client_cache = ClientMetadataCache::new(3600);
362 let client_name = client_cache
363 .get(&request_data.parameters.client_id)
364 .await
365 .ok()
366 .and_then(|m| m.client_name);
367 let show_login_error = |error_msg: &str, json: bool| -> Response {
368 if json {
369 return (
370 axum::http::StatusCode::FORBIDDEN,
371 Json(serde_json::json!({
372 "error": "access_denied",
373 "error_description": error_msg
374 })),
375 )
376 .into_response();
377 }
378 Html(templates::login_page(
379 &request_data.parameters.client_id,
380 client_name.as_deref(),
381 request_data.parameters.scope.as_deref(),
382 &form.request_uri,
383 Some(error_msg),
384 Some(&form.username),
385 ))
386 .into_response()
387 };
388 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
389 let normalized_username = form.username.trim();
390 let normalized_username = normalized_username
391 .strip_prefix('@')
392 .unwrap_or(normalized_username);
393 let normalized_username = if let Some(bare_handle) =
394 normalized_username.strip_suffix(&format!(".{}", pds_hostname))
395 {
396 bare_handle.to_string()
397 } else {
398 normalized_username.to_string()
399 };
400 tracing::debug!(
401 original_username = %form.username,
402 normalized_username = %normalized_username,
403 pds_hostname = %pds_hostname,
404 "Normalized username for lookup"
405 );
406 let user = match sqlx::query!(
407 r#"
408 SELECT id, did, email, password_hash, two_factor_enabled,
409 preferred_notification_channel as "preferred_notification_channel: NotificationChannel",
410 deactivated_at, takedown_ref
411 FROM users
412 WHERE handle = $1 OR email = $1
413 "#,
414 normalized_username
415 )
416 .fetch_optional(&state.db)
417 .await
418 {
419 Ok(Some(u)) => u,
420 Ok(None) => {
421 let _ = bcrypt::verify(&form.password, "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK");
422 return show_login_error("Invalid handle/email or password.", json_response);
423 }
424 Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
425 };
426 if user.deactivated_at.is_some() {
427 return show_login_error("This account has been deactivated.", json_response);
428 }
429 if user.takedown_ref.is_some() {
430 return show_login_error("This account has been taken down.", json_response);
431 }
432 let password_valid = match bcrypt::verify(&form.password, &user.password_hash) {
433 Ok(valid) => valid,
434 Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
435 };
436 if !password_valid {
437 return show_login_error("Invalid handle/email or password.", json_response);
438 }
439 if user.two_factor_enabled {
440 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
441 match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await {
442 Ok(challenge) => {
443 let hostname =
444 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
445 if let Err(e) =
446 enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await
447 {
448 tracing::warn!(
449 did = %user.did,
450 error = %e,
451 "Failed to enqueue 2FA notification"
452 );
453 }
454 let channel_name = channel_display_name(user.preferred_notification_channel);
455 let redirect_url = format!(
456 "/oauth/authorize/2fa?request_uri={}&channel={}",
457 url_encode(&form.request_uri),
458 url_encode(channel_name)
459 );
460 return Redirect::temporary(&redirect_url).into_response();
461 }
462 Err(_) => {
463 return show_login_error("An error occurred. Please try again.", json_response);
464 }
465 }
466 }
467 let code = Code::generate();
468 let mut device_id: Option<String> = extract_device_cookie(&headers);
469 let mut new_cookie: Option<String> = None;
470 if form.remember_device {
471 let final_device_id = if let Some(existing_id) = &device_id {
472 existing_id.clone()
473 } else {
474 let new_id = DeviceId::generate();
475 let device_data = DeviceData {
476 session_id: SessionId::generate().0,
477 user_agent: extract_user_agent(&headers),
478 ip_address: extract_client_ip(&headers),
479 last_seen_at: Utc::now(),
480 };
481 if db::create_device(&state.db, &new_id.0, &device_data)
482 .await
483 .is_ok()
484 {
485 new_cookie = Some(make_device_cookie(&new_id.0));
486 device_id = Some(new_id.0.clone());
487 }
488 new_id.0
489 };
490 let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await;
491 }
492 if db::update_authorization_request(
493 &state.db,
494 &form.request_uri,
495 &user.did,
496 device_id.as_deref(),
497 &code.0,
498 )
499 .await
500 .is_err()
501 {
502 return show_login_error("An error occurred. Please try again.", json_response);
503 }
504 let redirect_url = build_success_redirect(
505 &request_data.parameters.redirect_uri,
506 &code.0,
507 request_data.parameters.state.as_deref(),
508 );
509 if let Some(cookie) = new_cookie {
510 (
511 StatusCode::SEE_OTHER,
512 [(SET_COOKIE, cookie), (LOCATION, redirect_url)],
513 )
514 .into_response()
515 } else {
516 redirect_see_other(&redirect_url)
517 }
518}
519
520pub async fn authorize_select(
521 State(state): State<AppState>,
522 headers: HeaderMap,
523 Form(form): Form<AuthorizeSelectSubmit>,
524) -> Response {
525 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
526 Ok(Some(data)) => data,
527 Ok(None) => {
528 return Html(templates::error_page(
529 "invalid_request",
530 Some("Invalid or expired request_uri. Please start a new authorization request."),
531 ))
532 .into_response();
533 }
534 Err(_) => {
535 return Html(templates::error_page(
536 "server_error",
537 Some("An error occurred. Please try again."),
538 ))
539 .into_response();
540 }
541 };
542 if request_data.expires_at < Utc::now() {
543 let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
544 return Html(templates::error_page(
545 "invalid_request",
546 Some("Authorization request has expired. Please start a new request."),
547 ))
548 .into_response();
549 }
550 let device_id = match extract_device_cookie(&headers) {
551 Some(id) => id,
552 None => {
553 return Html(templates::error_page(
554 "invalid_request",
555 Some("No device session found. Please sign in."),
556 ))
557 .into_response();
558 }
559 };
560 let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await {
561 Ok(valid) => valid,
562 Err(_) => {
563 return Html(templates::error_page(
564 "server_error",
565 Some("An error occurred. Please try again."),
566 ))
567 .into_response();
568 }
569 };
570 if !account_valid {
571 return Html(templates::error_page(
572 "access_denied",
573 Some("This account is not available on this device. Please sign in."),
574 ))
575 .into_response();
576 }
577 let user = match sqlx::query!(
578 r#"
579 SELECT id, two_factor_enabled,
580 preferred_notification_channel as "preferred_notification_channel: NotificationChannel"
581 FROM users
582 WHERE did = $1
583 "#,
584 form.did
585 )
586 .fetch_optional(&state.db)
587 .await
588 {
589 Ok(Some(u)) => u,
590 Ok(None) => {
591 return Html(templates::error_page(
592 "access_denied",
593 Some("Account not found. Please sign in."),
594 )).into_response();
595 }
596 Err(_) => {
597 return Html(templates::error_page(
598 "server_error",
599 Some("An error occurred. Please try again."),
600 )).into_response();
601 }
602 };
603 if user.two_factor_enabled {
604 let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
605 match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await {
606 Ok(challenge) => {
607 let hostname =
608 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
609 if let Err(e) =
610 enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await
611 {
612 tracing::warn!(
613 did = %form.did,
614 error = %e,
615 "Failed to enqueue 2FA notification"
616 );
617 }
618 let channel_name = channel_display_name(user.preferred_notification_channel);
619 let redirect_url = format!(
620 "/oauth/authorize/2fa?request_uri={}&channel={}",
621 url_encode(&form.request_uri),
622 url_encode(channel_name)
623 );
624 return Redirect::temporary(&redirect_url).into_response();
625 }
626 Err(_) => {
627 return Html(templates::error_page(
628 "server_error",
629 Some("An error occurred. Please try again."),
630 ))
631 .into_response();
632 }
633 }
634 }
635 let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await;
636 let code = Code::generate();
637 if db::update_authorization_request(
638 &state.db,
639 &form.request_uri,
640 &form.did,
641 Some(&device_id),
642 &code.0,
643 )
644 .await
645 .is_err()
646 {
647 return Html(templates::error_page(
648 "server_error",
649 Some("An error occurred. Please try again."),
650 ))
651 .into_response();
652 }
653 let redirect_url = build_success_redirect(
654 &request_data.parameters.redirect_uri,
655 &code.0,
656 request_data.parameters.state.as_deref(),
657 );
658 redirect_see_other(&redirect_url)
659}
660
661fn build_success_redirect(redirect_uri: &str, code: &str, state: Option<&str>) -> String {
662 let mut redirect_url = redirect_uri.to_string();
663 let separator = if redirect_url.contains('?') { '&' } else { '?' };
664 redirect_url.push(separator);
665 redirect_url.push_str(&format!("code={}", url_encode(code)));
666 if let Some(req_state) = state {
667 redirect_url.push_str(&format!("&state={}", url_encode(req_state)));
668 }
669 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
670 redirect_url.push_str(&format!(
671 "&iss={}",
672 url_encode(&format!("https://{}", pds_hostname))
673 ));
674 redirect_url
675}
676
677#[derive(Debug, Serialize)]
678pub struct AuthorizeDenyResponse {
679 pub error: String,
680 pub error_description: String,
681}
682
683pub async fn authorize_deny(
684 State(state): State<AppState>,
685 Form(form): Form<AuthorizeDenyForm>,
686) -> Result<Response, OAuthError> {
687 let request_data = db::get_authorization_request(&state.db, &form.request_uri)
688 .await?
689 .ok_or_else(|| OAuthError::InvalidRequest("Invalid request_uri".to_string()))?;
690 db::delete_authorization_request(&state.db, &form.request_uri).await?;
691 let redirect_uri = &request_data.parameters.redirect_uri;
692 let mut redirect_url = redirect_uri.to_string();
693 let separator = if redirect_url.contains('?') { '&' } else { '?' };
694 redirect_url.push(separator);
695 redirect_url.push_str("error=access_denied");
696 redirect_url.push_str("&error_description=User%20denied%20the%20request");
697 if let Some(state) = &request_data.parameters.state {
698 redirect_url.push_str(&format!("&state={}", url_encode(state)));
699 }
700 Ok(redirect_see_other(&redirect_url))
701}
702
703#[derive(Debug, Deserialize)]
704pub struct AuthorizeDenyForm {
705 pub request_uri: String,
706}
707
708#[derive(Debug, Deserialize)]
709pub struct Authorize2faQuery {
710 pub request_uri: String,
711 pub channel: Option<String>,
712}
713
714#[derive(Debug, Deserialize)]
715pub struct Authorize2faSubmit {
716 pub request_uri: String,
717 pub code: String,
718}
719
/// Number of failed attempts after which a 2FA challenge is discarded.
const MAX_2FA_ATTEMPTS: i32 = 5;
721
722pub async fn authorize_2fa_get(
723 State(state): State<AppState>,
724 Query(query): Query<Authorize2faQuery>,
725) -> Response {
726 let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await {
727 Ok(Some(c)) => c,
728 Ok(None) => {
729 return Html(templates::error_page(
730 "invalid_request",
731 Some("No 2FA challenge found. Please start over."),
732 ))
733 .into_response();
734 }
735 Err(_) => {
736 return Html(templates::error_page(
737 "server_error",
738 Some("An error occurred. Please try again."),
739 ))
740 .into_response();
741 }
742 };
743 if challenge.expires_at < Utc::now() {
744 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
745 return Html(templates::error_page(
746 "invalid_request",
747 Some("2FA code has expired. Please start over."),
748 ))
749 .into_response();
750 }
751 let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await {
752 Ok(Some(d)) => d,
753 Ok(None) => {
754 return Html(templates::error_page(
755 "invalid_request",
756 Some("Authorization request not found. Please start over."),
757 ))
758 .into_response();
759 }
760 Err(_) => {
761 return Html(templates::error_page(
762 "server_error",
763 Some("An error occurred. Please try again."),
764 ))
765 .into_response();
766 }
767 };
768 let channel = query.channel.as_deref().unwrap_or("email");
769 Html(templates::two_factor_page(
770 &query.request_uri,
771 channel,
772 None,
773 ))
774 .into_response()
775}
776
777pub async fn authorize_2fa_post(
778 State(state): State<AppState>,
779 headers: HeaderMap,
780 Form(form): Form<Authorize2faSubmit>,
781) -> Response {
782 let client_ip = extract_client_ip(&headers);
783 if !state
784 .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip)
785 .await
786 {
787 tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded");
788 return (
789 axum::http::StatusCode::TOO_MANY_REQUESTS,
790 Html(templates::error_page(
791 "RateLimitExceeded",
792 Some("Too many attempts. Please try again later."),
793 )),
794 )
795 .into_response();
796 }
797 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await {
798 Ok(Some(c)) => c,
799 Ok(None) => {
800 return Html(templates::error_page(
801 "invalid_request",
802 Some("No 2FA challenge found. Please start over."),
803 ))
804 .into_response();
805 }
806 Err(_) => {
807 return Html(templates::error_page(
808 "server_error",
809 Some("An error occurred. Please try again."),
810 ))
811 .into_response();
812 }
813 };
814 if challenge.expires_at < Utc::now() {
815 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
816 return Html(templates::error_page(
817 "invalid_request",
818 Some("2FA code has expired. Please start over."),
819 ))
820 .into_response();
821 }
822 if challenge.attempts >= MAX_2FA_ATTEMPTS {
823 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
824 return Html(templates::error_page(
825 "access_denied",
826 Some("Too many failed attempts. Please start over."),
827 ))
828 .into_response();
829 }
830 let code_valid: bool = form
831 .code
832 .trim()
833 .as_bytes()
834 .ct_eq(challenge.code.as_bytes())
835 .into();
836 if !code_valid {
837 let _ = db::increment_2fa_attempts(&state.db, challenge.id).await;
838 let channel = match sqlx::query_scalar!(
839 r#"SELECT preferred_notification_channel as "channel: NotificationChannel" FROM users WHERE did = $1"#,
840 challenge.did
841 )
842 .fetch_optional(&state.db)
843 .await
844 {
845 Ok(Some(ch)) => channel_display_name(ch).to_string(),
846 Ok(None) | Err(_) => "email".to_string(),
847 };
848 let _request_data = match db::get_authorization_request(&state.db, &form.request_uri).await
849 {
850 Ok(Some(d)) => d,
851 Ok(None) => {
852 return Html(templates::error_page(
853 "invalid_request",
854 Some("Authorization request not found. Please start over."),
855 ))
856 .into_response();
857 }
858 Err(_) => {
859 return Html(templates::error_page(
860 "server_error",
861 Some("An error occurred. Please try again."),
862 ))
863 .into_response();
864 }
865 };
866 return Html(templates::two_factor_page(
867 &form.request_uri,
868 &channel,
869 Some("Invalid verification code. Please try again."),
870 ))
871 .into_response();
872 }
873 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
874 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
875 Ok(Some(d)) => d,
876 Ok(None) => {
877 return Html(templates::error_page(
878 "invalid_request",
879 Some("Authorization request not found."),
880 ))
881 .into_response();
882 }
883 Err(_) => {
884 return Html(templates::error_page(
885 "server_error",
886 Some("An error occurred."),
887 ))
888 .into_response();
889 }
890 };
891 let code = Code::generate();
892 let device_id = extract_device_cookie(&headers);
893 if db::update_authorization_request(
894 &state.db,
895 &form.request_uri,
896 &challenge.did,
897 device_id.as_deref(),
898 &code.0,
899 )
900 .await
901 .is_err()
902 {
903 return Html(templates::error_page(
904 "server_error",
905 Some("An error occurred. Please try again."),
906 ))
907 .into_response();
908 }
909 let redirect_url = build_success_redirect(
910 &request_data.parameters.redirect_uri,
911 &code.0,
912 request_data.parameters.state.as_deref(),
913 );
914 redirect_see_other(&redirect_url)
915}