This repository has no description.
use crate::comms::{CommsChannel, channel_display_name, enqueue_2fa_code};
use crate::oauth::{
    Code, DeviceData, DeviceId, OAuthError, SessionId, client::ClientMetadataCache, db,
};
use crate::state::{AppState, RateLimitKind};
use axum::{
    Json,
    extract::{Query, State},
    http::{
        HeaderMap, StatusCode,
        header::{LOCATION, SET_COOKIE},
    },
    response::{IntoResponse, Response},
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use urlencoding::encode as url_encode;

/// Name of the cookie that identifies a remembered browser ("device") across OAuth logins.
const DEVICE_COOKIE_NAME: &str = "oauth_device_id";

/// Builds a `303 See Other` response pointing at `uri`.
fn redirect_see_other(uri: &str) -> Response {
    (StatusCode::SEE_OTHER, [(LOCATION, uri.to_string())]).into_response()
}

/// Redirects to the frontend OAuth error page with URL-encoded `error` / `error_description`.
fn redirect_to_frontend_error(error: &str, description: &str) -> Response {
    redirect_see_other(&format!(
        "/#/oauth/error?error={}&error_description={}",
        url_encode(error),
        url_encode(description)
    ))
}

/// Returns the value of the device cookie, if the request carries one.
///
/// Every `Cookie` header is scanned: HTTP/2 clients may split cookies across several header
/// fields instead of joining them with `; `.
fn extract_device_cookie(headers: &HeaderMap) -> Option<String> {
    // Built once instead of once per cookie pair.
    let prefix = format!("{}=", DEVICE_COOKIE_NAME);
    headers
        .get_all("cookie")
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|cookie_str| cookie_str.split(';'))
        .map(str::trim)
        .find_map(|cookie| cookie.strip_prefix(prefix.as_str()))
        .map(str::to_string)
}

/// Best-effort client IP used for rate limiting and device records.
///
/// Prefers the first `X-Forwarded-For` entry, then `X-Real-IP`, and falls back to `"0.0.0.0"`.
// NOTE(review): both headers are client-controlled unless a trusted reverse proxy overwrites
// them; confirm the deployment strips inbound values, otherwise rate limits can be sidestepped.
fn extract_client_ip(headers: &HeaderMap) -> String {
    if let Some(forwarded) = headers.get("x-forwarded-for")
        && let Ok(value) = forwarded.to_str()
        && let Some(first_ip) = value.split(',').next()
    {
        return first_ip.trim().to_string();
    }
    if let Some(real_ip) = headers.get("x-real-ip")
        && let Ok(value) = real_ip.to_str()
    {
        return value.trim().to_string();
    }
    "0.0.0.0".to_string()
}

/// Returns the `User-Agent` header as a string, if present and valid visible ASCII.
fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
    headers
        .get("user-agent")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string())
}

/// Builds the `Set-Cookie` value for the device cookie (scoped to `/oauth`, one-year lifetime).
fn make_device_cookie(device_id: &str) -> String {
    format!(
        "{}={}; Path=/oauth; HttpOnly; Secure; SameSite=Lax; Max-Age=31536000",
        DEVICE_COOKIE_NAME, device_id
    )
}

/// Query string accepted by the authorize endpoints.
#[derive(Debug, Deserialize)]
pub struct AuthorizeQuery {
    pub request_uri: Option<String>,
    pub client_id: Option<String>,
    /// When `true`, skip the remembered-account picker and go straight to login.
    pub new_account: Option<bool>,
}

/// JSON view of a pending authorization request, shown by the frontend login page.
#[derive(Debug, Serialize)]
pub struct AuthorizeResponse {
    pub client_id: String,
    pub client_name: Option<String>,
    pub scope: Option<String>,
    pub redirect_uri: String,
    pub state: Option<String>,
    pub login_hint: Option<String>,
}

/// Credentials submitted from the login page.
#[derive(Debug, Deserialize)]
pub struct AuthorizeSubmit {
    pub request_uri: String,
    pub username: String,
    pub password: String,
    #[serde(default)]
    pub remember_device: bool,
}

/// Account chosen from the remembered-accounts picker.
#[derive(Debug, Deserialize)]
pub struct AuthorizeSelectSubmit {
    pub request_uri: String,
    pub did: String,
}

/// True when the `Accept` header mentions `application/json`.
fn wants_json(headers: &HeaderMap) -> bool {
    headers
        .get("accept")
        .and_then(|v| v.to_str().ok())
        .map(|accept| accept.contains("application/json"))
        .unwrap_or(false)
}

/// `GET` authorize: validates a PAR `request_uri` and routes the caller onward.
///
/// JSON callers (`Accept: application/json`) receive an [`AuthorizeResponse`] or a JSON error.
/// Browsers are redirected to the frontend: to the account picker when the device cookie has
/// remembered accounts (unless `new_account=true`), otherwise to the login page.
pub async fn authorize_get(
    State(state): State<AppState>,
    headers: HeaderMap,
    Query(query): Query<AuthorizeQuery>,
) -> Response {
    let json_response = wants_json(&headers);
    let Some(request_uri) = query.request_uri else {
        if json_response {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "invalid_request",
                    "error_description": "Missing request_uri parameter. Use PAR to initiate authorization."
                })),
            )
                .into_response();
        }
        return redirect_to_frontend_error(
            "invalid_request",
            "Missing request_uri parameter. Use PAR to initiate authorization.",
        );
    };
    let request_data = match db::get_authorization_request(&state.db, &request_uri).await {
        Ok(Some(data)) => data,
        Ok(None) => {
            if json_response {
                return (
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": "invalid_request",
                        "error_description": "Invalid or expired request_uri. Please start a new authorization request."
                    })),
                )
                    .into_response();
            }
            return redirect_to_frontend_error(
                "invalid_request",
                "Invalid or expired request_uri. Please start a new authorization request.",
            );
        }
        Err(e) => {
            // Log the details server-side; never echo internal error text to the client.
            tracing::error!(error = ?e, "Failed to load authorization request");
            if json_response {
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": "server_error",
                        "error_description": "A database error occurred."
                    })),
                )
                    .into_response();
            }
            return redirect_to_frontend_error("server_error", "A database error occurred.");
        }
    };
    if request_data.expires_at < Utc::now() {
        let _ = db::delete_authorization_request(&state.db, &request_uri).await;
        if json_response {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "invalid_request",
                    "error_description": "Authorization request has expired. Please start a new request."
                })),
            )
                .into_response();
        }
        return redirect_to_frontend_error(
            "invalid_request",
            "Authorization request has expired. Please start a new request.",
        );
    }
    if json_response {
        let client_name = ClientMetadataCache::new(3600)
            .get(&request_data.parameters.client_id)
            .await
            .ok()
            .and_then(|m| m.client_name);
        return Json(AuthorizeResponse {
            client_id: request_data.parameters.client_id.clone(),
            client_name,
            scope: request_data.parameters.scope.clone(),
            redirect_uri: request_data.parameters.redirect_uri.clone(),
            state: request_data.parameters.state.clone(),
            login_hint: request_data.parameters.login_hint.clone(),
        })
        .into_response();
    }
    let force_new_account = query.new_account.unwrap_or(false);
    if !force_new_account
        && let Some(device_id) = extract_device_cookie(&headers)
        && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await
        && !accounts.is_empty()
    {
        return redirect_see_other(&format!(
            "/#/oauth/accounts?request_uri={}",
            url_encode(&request_uri)
        ));
    }
    redirect_see_other(&format!(
        "/#/oauth/login?request_uri={}",
        url_encode(&request_uri)
    ))
}

/// JSON-only variant of [`authorize_get`]: returns the pending request's details.
///
/// # Errors
/// `OAuthError::InvalidRequest` when `request_uri` is missing, unknown or expired (an expired
/// request is deleted first); database failures propagate through `?`.
pub async fn authorize_get_json(
    State(state): State<AppState>,
    Query(query): Query<AuthorizeQuery>,
) -> Result<Json<AuthorizeResponse>, OAuthError> {
    let request_uri = query
        .request_uri
        .ok_or_else(|| OAuthError::InvalidRequest("request_uri is required".to_string()))?;
    let request_data = db::get_authorization_request(&state.db, &request_uri)
        .await?
        .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?;
    if request_data.expires_at < Utc::now() {
        db::delete_authorization_request(&state.db, &request_uri).await?;
        return Err(OAuthError::InvalidRequest(
            "request_uri has expired".to_string(),
        ));
    }
    // Same best-effort client-name lookup as the JSON branch of `authorize_get`.
    let client_name = ClientMetadataCache::new(3600)
        .get(&request_data.parameters.client_id)
        .await
        .ok()
        .and_then(|m| m.client_name);
    Ok(Json(AuthorizeResponse {
        client_id: request_data.parameters.client_id.clone(),
        client_name,
        scope: request_data.parameters.scope.clone(),
        redirect_uri: request_data.parameters.redirect_uri.clone(),
        state: request_data.parameters.state.clone(),
        login_hint: request_data.parameters.login_hint.clone(),
    }))
}

/// One remembered account shown in the account picker.
#[derive(Debug, Serialize)]
pub struct AccountInfo {
    pub did: String,
    pub handle: String,
    /// Masked email (see `mask_email`); omitted when the account has none.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub email: Option<String>,
}

/// Accounts remembered on the requesting device, echoed with the `request_uri`.
#[derive(Debug, Serialize)]
pub struct AccountsResponse {
    pub accounts: Vec<AccountInfo>,
    pub request_uri: String,
}

/// Masks the local part of an email: `alice@x.com` -> `a***e@x.com`, `ab@x.com` -> `a***@x.com`.
///
/// Lengths are counted in characters, so a short non-ASCII local part is never exposed whole.
/// Input without `@` becomes `"***"`.
fn mask_email(email: &str) -> String {
    let Some((local, domain)) = email.split_once('@') else {
        return "***".to_string();
    };
    let first = local.chars().next().unwrap_or('*');
    if local.chars().count() <= 2 {
        format!("{}***@{}", first, domain)
    } else {
        let last = local.chars().last().unwrap_or('*');
        format!("{}***{}@{}", first, last, domain)
    }
}

/// Lists the accounts remembered on this device (via the device cookie) for the picker.
///
/// A missing cookie or a database failure yields an empty list rather than an error.
pub async fn authorize_accounts(
    State(state): State<AppState>,
    headers: HeaderMap,
    Query(query): Query<AuthorizeQuery>,
) -> Response {
    let Some(request_uri) = query.request_uri else {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "invalid_request",
                "error_description": "Missing request_uri parameter"
            })),
        )
            .into_response();
    };
    let Some(device_id) = extract_device_cookie(&headers) else {
        return Json(AccountsResponse {
            accounts: vec![],
            request_uri,
        })
        .into_response();
    };
    let accounts = match db::get_device_accounts(&state.db, &device_id).await {
        Ok(accts) => accts,
        Err(_) => {
            return Json(AccountsResponse {
                accounts: vec![],
                request_uri,
            })
            .into_response();
        }
    };
    let account_infos: Vec<AccountInfo> = accounts
        .into_iter()
        .map(|row| AccountInfo {
            did: row.did,
            handle: row.handle,
            email: row.email.map(|e| mask_email(&e)),
        })
        .collect();
    Json(AccountsResponse {
        accounts: account_infos,
        request_uri,
    })
    .into_response()
}

/// `POST` login: verifies credentials and advances the authorization request.
///
/// Flow:
/// 1. Rate-limit by client IP.
/// 2. Validate the request and check the password.
/// 3. Branch to 2FA when it is enabled.
/// 4. Optionally remember the device.
/// 5. Then send the caller to the consent screen or straight back to the client with a code.
///
/// Responses are JSON when the caller asks for JSON, otherwise redirects.
pub async fn authorize_post(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(form): Json<AuthorizeSubmit>,
) -> Response {
    let json_response = wants_json(&headers);
    let client_ip = extract_client_ip(&headers);
    if !state
        .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip)
        .await
    {
        tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded");
        if json_response {
            return (
                StatusCode::TOO_MANY_REQUESTS,
                Json(serde_json::json!({
                    "error": "RateLimitExceeded",
                    "error_description": "Too many login attempts. Please try again later."
                })),
            )
                .into_response();
        }
        return redirect_to_frontend_error(
            "RateLimitExceeded",
            "Too many login attempts. Please try again later.",
        );
    }
    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
        Ok(Some(data)) => data,
        Ok(None) => {
            if json_response {
                return (
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": "invalid_request",
                        "error_description": "Invalid or expired request_uri."
                    })),
                )
                    .into_response();
            }
            return redirect_to_frontend_error(
                "invalid_request",
                "Invalid or expired request_uri. Please start a new authorization request.",
            );
        }
        Err(e) => {
            // Log the details server-side; never echo internal error text to the client.
            tracing::error!(error = ?e, "Failed to load authorization request");
            if json_response {
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": "server_error",
                        "error_description": "A database error occurred."
                    })),
                )
                    .into_response();
            }
            return redirect_to_frontend_error("server_error", "A database error occurred.");
        }
    };
    if request_data.expires_at < Utc::now() {
        let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
        if json_response {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "invalid_request",
                    "error_description": "Authorization request has expired."
                })),
            )
                .into_response();
        }
        return redirect_to_frontend_error(
            "invalid_request",
            "Authorization request has expired. Please start a new request.",
        );
    }
    // Login failures go back to the login page (or a 403 JSON body) with the message attached.
    let show_login_error = |error_msg: &str, json: bool| -> Response {
        if json {
            return (
                StatusCode::FORBIDDEN,
                Json(serde_json::json!({
                    "error": "access_denied",
                    "error_description": error_msg
                })),
            )
                .into_response();
        }
        redirect_see_other(&format!(
            "/#/oauth/login?request_uri={}&error={}",
            url_encode(&form.request_uri),
            url_encode(error_msg)
        ))
    };
    // Accept "@alice", "alice.<pds-host>" and "alice" as the same stored handle.
    let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
    let normalized_username = form.username.trim();
    let normalized_username = normalized_username
        .strip_prefix('@')
        .unwrap_or(normalized_username);
    let normalized_username = normalized_username
        .strip_suffix(&format!(".{}", pds_hostname))
        .unwrap_or(normalized_username)
        .to_string();
    tracing::debug!(
        original_username = %form.username,
        normalized_username = %normalized_username,
        pds_hostname = %pds_hostname,
        "Normalized username for lookup"
    );
    let user = match sqlx::query!(
        r#"
        SELECT id, did, email, password_hash, two_factor_enabled,
            preferred_comms_channel as "preferred_comms_channel: CommsChannel",
            deactivated_at, takedown_ref,
            email_verified, discord_verified, telegram_verified, signal_verified
        FROM users
        WHERE handle = $1 OR email = $1
        "#,
        normalized_username
    )
    .fetch_optional(&state.db)
    .await
    {
        Ok(Some(u)) => u,
        Ok(None) => {
            // Verify against a dummy hash so unknown users cost the same time as wrong passwords.
            let _ = bcrypt::verify(
                &form.password,
                "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/X4.VTtYw1ZzQKZqmK",
            );
            return show_login_error("Invalid handle/email or password.", json_response);
        }
        Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
    };
    if user.deactivated_at.is_some() {
        return show_login_error("This account has been deactivated.", json_response);
    }
    if user.takedown_ref.is_some() {
        return show_login_error("This account has been taken down.", json_response);
    }
    let is_verified = user.email_verified
        || user.discord_verified
        || user.telegram_verified
        || user.signal_verified;
    if !is_verified {
        return show_login_error(
            "Please verify your account before logging in.",
            json_response,
        );
    }
    let password_valid = match bcrypt::verify(&form.password, &user.password_hash) {
        Ok(valid) => valid,
        Err(_) => return show_login_error("An error occurred. Please try again.", json_response),
    };
    if !password_valid {
        return show_login_error("Invalid handle/email or password.", json_response);
    }
    if user.two_factor_enabled {
        // Replace any earlier challenge for this request before issuing a new code.
        let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
        match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await {
            Ok(challenge) => {
                let hostname =
                    std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
                if let Err(e) =
                    enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await
                {
                    tracing::warn!(
                        did = %user.did,
                        error = %e,
                        "Failed to enqueue 2FA notification"
                    );
                }
                let channel_name = channel_display_name(user.preferred_comms_channel);
                if json_response {
                    return Json(serde_json::json!({
                        "needs_2fa": true,
                        "channel": channel_name
                    }))
                    .into_response();
                }
                return redirect_see_other(&format!(
                    "/#/oauth/2fa?request_uri={}&channel={}",
                    url_encode(&form.request_uri),
                    url_encode(channel_name)
                ));
            }
            Err(_) => {
                return show_login_error("An error occurred. Please try again.", json_response);
            }
        }
    }
    let mut device_id: Option<String> = extract_device_cookie(&headers);
    let mut new_cookie: Option<String> = None;
    if form.remember_device {
        if device_id.is_none() {
            let new_id = DeviceId::generate();
            let device_data = DeviceData {
                session_id: SessionId::generate().0,
                user_agent: extract_user_agent(&headers),
                ip_address: extract_client_ip(&headers),
                last_seen_at: Utc::now(),
            };
            if db::create_device(&state.db, &new_id.0, &device_data)
                .await
                .is_ok()
            {
                new_cookie = Some(make_device_cookie(&new_id.0));
                device_id = Some(new_id.0);
            } else {
                tracing::warn!(did = %user.did, "Failed to create device record");
            }
        }
        // Only link the account to a device that actually exists (cookie or freshly created).
        if let Some(id) = device_id.as_deref() {
            let _ = db::upsert_account_device(&state.db, &user.did, id).await;
        }
    }
    if db::set_authorization_did(
        &state.db,
        &form.request_uri,
        &user.did,
        device_id.as_deref(),
    )
    .await
    .is_err()
    {
        return show_login_error("An error occurred. Please try again.", json_response);
    }
    let requested_scopes: Vec<String> = request_data
        .parameters
        .scope
        .as_deref()
        .unwrap_or("atproto")
        .split_whitespace()
        .map(str::to_string)
        .collect();
    // Fail closed: if the consent lookup errors, show the consent screen.
    let needs_consent = db::should_show_consent(
        &state.db,
        &user.did,
        &request_data.parameters.client_id,
        &requested_scopes,
    )
    .await
    .unwrap_or(true);
    if needs_consent {
        let consent_url = format!(
            "/#/oauth/consent?request_uri={}",
            url_encode(&form.request_uri)
        );
        return match (json_response, new_cookie) {
            (true, Some(cookie)) => (
                StatusCode::OK,
                [(SET_COOKIE, cookie)],
                Json(serde_json::json!({"redirect_uri": consent_url})),
            )
                .into_response(),
            (true, None) => Json(serde_json::json!({"redirect_uri": consent_url})).into_response(),
            (false, Some(cookie)) => (
                StatusCode::SEE_OTHER,
                [(SET_COOKIE, cookie), (LOCATION, consent_url)],
            )
                .into_response(),
            (false, None) => redirect_see_other(&consent_url),
        };
    }
    let code = Code::generate();
    if db::update_authorization_request(
        &state.db,
        &form.request_uri,
        &user.did,
        device_id.as_deref(),
        &code.0,
    )
    .await
    .is_err()
    {
        return show_login_error("An error occurred. Please try again.", json_response);
    }
    let redirect_url = build_success_redirect(
        &request_data.parameters.redirect_uri,
        &code.0,
        request_data.parameters.state.as_deref(),
        request_data.parameters.response_mode.as_deref(),
    );
    match (json_response, new_cookie) {
        (true, Some(cookie)) => (
            StatusCode::OK,
            [(SET_COOKIE, cookie)],
            Json(serde_json::json!({"redirect_uri": redirect_url})),
        )
            .into_response(),
        (true, None) => Json(serde_json::json!({"redirect_uri": redirect_url})).into_response(),
        (false, Some(cookie)) => (
            StatusCode::SEE_OTHER,
            [(SET_COOKIE, cookie), (LOCATION, redirect_url)],
        )
            .into_response(),
        (false, None) => redirect_see_other(&redirect_url),
    }
}

/// Continues authorization with an account already remembered on this device (JSON only).
///
/// Enforces the same account-status checks as [`authorize_post`] (deactivated, taken down,
/// unverified), then the same 2FA and consent gating, before issuing an authorization code.
pub async fn authorize_select(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(form): Json<AuthorizeSelectSubmit>,
) -> Response {
    let json_error = |status: StatusCode, error: &str, description: &str| -> Response {
        (
            status,
            Json(serde_json::json!({
                "error": error,
                "error_description": description
            })),
        )
            .into_response()
    };
    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
        Ok(Some(data)) => data,
        Ok(None) => {
            return json_error(
                StatusCode::BAD_REQUEST,
                "invalid_request",
                "Invalid or expired request_uri. Please start a new authorization request.",
            );
        }
        Err(_) => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "An error occurred. Please try again.",
            );
        }
    };
    if request_data.expires_at < Utc::now() {
        let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
        return json_error(
            StatusCode::BAD_REQUEST,
            "invalid_request",
            "Authorization request has expired. Please start a new request.",
        );
    }
    let Some(device_id) = extract_device_cookie(&headers) else {
        return json_error(
            StatusCode::BAD_REQUEST,
            "invalid_request",
            "No device session found. Please sign in.",
        );
    };
    let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await {
        Ok(valid) => valid,
        Err(_) => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "An error occurred. Please try again.",
            );
        }
    };
    if !account_valid {
        return json_error(
            StatusCode::FORBIDDEN,
            "access_denied",
            "This account is not available on this device. Please sign in.",
        );
    }
    let user = match sqlx::query!(
        r#"
        SELECT id, two_factor_enabled,
            preferred_comms_channel as "preferred_comms_channel: CommsChannel",
            deactivated_at, takedown_ref,
            email_verified, discord_verified, telegram_verified, signal_verified
        FROM users
        WHERE did = $1
        "#,
        form.did
    )
    .fetch_optional(&state.db)
    .await
    {
        Ok(Some(u)) => u,
        Ok(None) => {
            return json_error(
                StatusCode::FORBIDDEN,
                "access_denied",
                "Account not found. Please sign in.",
            );
        }
        Err(_) => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "An error occurred. Please try again.",
            );
        }
    };
    // A remembered device must not bypass the account-status checks done at password login.
    if user.deactivated_at.is_some() {
        return json_error(
            StatusCode::FORBIDDEN,
            "access_denied",
            "This account has been deactivated.",
        );
    }
    if user.takedown_ref.is_some() {
        return json_error(
            StatusCode::FORBIDDEN,
            "access_denied",
            "This account has been taken down.",
        );
    }
    let is_verified = user.email_verified
        || user.discord_verified
        || user.telegram_verified
        || user.signal_verified;
    if !is_verified {
        return json_error(
            StatusCode::FORBIDDEN,
            "access_denied",
            "Please verify your account before logging in.",
        );
    }
    if user.two_factor_enabled {
        let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await;
        match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await {
            Ok(challenge) => {
                let hostname =
                    std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
                if let Err(e) =
                    enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await
                {
                    tracing::warn!(
                        did = %form.did,
                        error = %e,
                        "Failed to enqueue 2FA notification"
                    );
                }
                let channel_name = channel_display_name(user.preferred_comms_channel);
                return Json(serde_json::json!({
                    "needs_2fa": true,
                    "channel": channel_name
                }))
                .into_response();
            }
            Err(_) => {
                return json_error(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "server_error",
                    "An error occurred. Please try again.",
                );
            }
        }
    }
    let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await;
    // Record the chosen account on the request so the consent screen can identify the user.
    if db::set_authorization_did(
        &state.db,
        &form.request_uri,
        &form.did,
        Some(device_id.as_str()),
    )
    .await
    .is_err()
    {
        return json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "An error occurred. Please try again.",
        );
    }
    let requested_scopes: Vec<String> = request_data
        .parameters
        .scope
        .as_deref()
        .unwrap_or("atproto")
        .split_whitespace()
        .map(str::to_string)
        .collect();
    // Same consent gating as `authorize_post`; fail closed if the lookup errors.
    let needs_consent = db::should_show_consent(
        &state.db,
        &form.did,
        &request_data.parameters.client_id,
        &requested_scopes,
    )
    .await
    .unwrap_or(true);
    if needs_consent {
        let consent_url = format!(
            "/#/oauth/consent?request_uri={}",
            url_encode(&form.request_uri)
        );
        return Json(serde_json::json!({ "redirect_uri": consent_url })).into_response();
    }
    let code = Code::generate();
    if db::update_authorization_request(
        &state.db,
        &form.request_uri,
        &form.did,
        Some(&device_id),
        &code.0,
    )
    .await
    .is_err()
    {
        return json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "An error occurred. Please try again.",
        );
    }
    let redirect_url = build_success_redirect(
        &request_data.parameters.redirect_uri,
        &code.0,
        request_data.parameters.state.as_deref(),
        request_data.parameters.response_mode.as_deref(),
    );
    Json(serde_json::json!({
        "redirect_uri": redirect_url
    }))
    .into_response()
}

/// Builds the client redirect carrying `code`, optional `state` and `iss`.
///
/// Parameters go in the fragment when `response_mode == "fragment"`, otherwise in the query
/// string (appended with `&` if `redirect_uri` already has one). `iss` is
/// `https://$PDS_HOSTNAME`, defaulting to `localhost`.
fn build_success_redirect(
    redirect_uri: &str,
    code: &str,
    state: Option<&str>,
    response_mode: Option<&str>,
) -> String {
    let mut redirect_url = redirect_uri.to_string();
    let use_fragment = response_mode == Some("fragment");
    let separator = if use_fragment {
        '#'
    } else if redirect_url.contains('?') {
        '&'
    } else {
        '?'
    };
    redirect_url.push(separator);
    redirect_url.push_str(&format!("code={}", url_encode(code)));
    if let Some(req_state) = state {
        redirect_url.push_str(&format!("&state={}", url_encode(req_state)));
    }
    let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
    redirect_url.push_str(&format!(
        "&iss={}",
        url_encode(&format!("https://{}", pds_hostname))
    ));
    redirect_url
}

/// OAuth error body (`error` / `error_description`).
#[derive(Debug, Serialize)]
pub struct AuthorizeDenyResponse {
    pub error: String,
    pub error_description: String,
}

/// Handles the user denying the request.
///
/// Deletes the pending request and returns the client redirect carrying
/// `error=access_denied`. Parameter placement (query vs fragment) and the `iss` parameter
/// match [`build_success_redirect`], so clients parse denials the same way as successes.
pub async fn authorize_deny(
    State(state): State<AppState>,
    Json(form): Json<AuthorizeDenyForm>,
) -> Response {
    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
        Ok(Some(data)) => data,
        Ok(None) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "invalid_request",
                    "error_description": "Invalid request_uri"
                })),
            )
                .into_response();
        }
        Err(_) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": "server_error",
                    "error_description": "An error occurred"
                })),
            )
                .into_response();
        }
    };
    let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
    let params = &request_data.parameters;
    let mut redirect_url = params.redirect_uri.to_string();
    let separator = if params.response_mode.as_deref() == Some("fragment") {
        '#'
    } else if redirect_url.contains('?') {
        '&'
    } else {
        '?'
    };
    redirect_url.push(separator);
    redirect_url.push_str("error=access_denied");
    redirect_url.push_str("&error_description=User%20denied%20the%20request");
    if let Some(req_state) = &params.state {
        redirect_url.push_str(&format!("&state={}", url_encode(req_state)));
    }
    // RFC 9207: identify the issuer in error responses too.
    let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
    redirect_url.push_str(&format!(
        "&iss={}",
        url_encode(&format!("https://{}", pds_hostname))
    ));
    Json(serde_json::json!({
        "redirect_uri": redirect_url
    }))
    .into_response()
}

/// Body of the deny endpoint.
#[derive(Debug, Deserialize)]
pub struct AuthorizeDenyForm {
    pub request_uri: String,
}

/// Query string of the 2FA page.
#[derive(Debug, Deserialize)]
pub struct Authorize2faQuery {
    pub request_uri: String,
    pub channel: Option<String>,
}

/// Submitted 2FA code.
#[derive(Debug, Deserialize)]
pub struct Authorize2faSubmit {
    pub request_uri: String,
    pub code: String,
}

/// Maximum number of 2FA code attempts allowed per challenge.
const MAX_2FA_ATTEMPTS: i32 = 5;

/// Validates that a live 2FA challenge and its authorization request exist, then redirects to
/// the frontend 2FA page (channel defaults to `"email"`).
pub async fn authorize_2fa_get(
    State(state): State<AppState>,
    Query(query): Query<Authorize2faQuery>,
) -> Response {
    let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await {
        Ok(Some(c)) => c,
        Ok(None) => {
            return redirect_to_frontend_error(
                "invalid_request",
                "No 2FA challenge found. Please start over.",
            );
        }
        Err(_) => {
            return redirect_to_frontend_error(
                "server_error",
                "An error occurred. Please try again.",
            );
        }
    };
    if challenge.expires_at < Utc::now() {
        let _ = db::delete_2fa_challenge(&state.db, challenge.id).await;
        return redirect_to_frontend_error(
            "invalid_request",
            "2FA code has expired. Please start over.",
        );
    }
    // Existence check only; the request data itself is not needed here.
    match db::get_authorization_request(&state.db, &query.request_uri).await {
        Ok(Some(_)) => {}
        Ok(None) => {
            return redirect_to_frontend_error(
                "invalid_request",
                "Authorization request not found. Please start over.",
            );
        }
        Err(_) => {
            return redirect_to_frontend_error(
                "server_error",
                "An error occurred. Please try again.",
            );
        }
    }
    let channel = query.channel.as_deref().unwrap_or("email");
    redirect_see_other(&format!(
        "/#/oauth/2fa?request_uri={}&channel={}",
        url_encode(&query.request_uri),
        url_encode(channel)
    ))
}

/// One requested scope as presented on the consent screen.
#[derive(Debug, Serialize)]
pub struct ScopeInfo {
    pub scope: String,
    pub category: String,
    pub required: bool,
    pub description: String,
    pub display_name: String,
    /// Previously remembered decision for this scope, if any.
    pub granted: Option<bool>,
}

/// Data for rendering the consent screen.
#[derive(Debug, Serialize)]
pub struct ConsentResponse {
    pub request_uri: String,
    pub client_id: String,
    pub client_name: Option<String>,
    pub client_uri: Option<String>,
    pub logo_uri: Option<String>,
    pub scopes: Vec<ScopeInfo>,
    pub show_consent: bool,
    pub did: String,
}

/// Query string of the consent page.
#[derive(Debug, Deserialize)]
pub struct ConsentQuery {
    pub request_uri: String,
}

/// Consent decision submitted by the user.
#[derive(Debug, Deserialize)]
pub struct ConsentSubmit {
    pub request_uri: String,
    pub approved_scopes: Vec<String>,
    /// Persist per-scope decisions for this client.
    pub remember: bool,
}

/// Returns the consent-screen data for an authenticated pending request.
///
/// This includes client metadata, the requested scopes with descriptions and any remembered
/// decisions, and whether consent must be shown.
pub async fn consent_get(
    State(state): State<AppState>,
    Query(query): Query<ConsentQuery>,
) -> Response {
    let request_data = match db::get_authorization_request(&state.db, &query.request_uri).await {
        Ok(Some(data)) => data,
        Ok(None) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "invalid_request",
                    "error_description": "Invalid or expired request_uri"
                })),
            )
                .into_response();
        }
        Err(e) => {
            // Log the details server-side; never echo internal error text to the client.
            tracing::error!(error = ?e, "Failed to load authorization request");
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": "server_error",
                    "error_description": "A database error occurred."
                })),
            )
                .into_response();
        }
    };
    if request_data.expires_at < Utc::now() {
        let _ = db::delete_authorization_request(&state.db, &query.request_uri).await;
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "invalid_request",
                "error_description": "Authorization request has expired"
            })),
        )
            .into_response();
    }
    let Some(did) = request_data.did.clone() else {
        return (
            StatusCode::FORBIDDEN,
            Json(serde_json::json!({
                "error": "access_denied",
                "error_description": "Not authenticated"
            })),
        )
            .into_response();
    };
    let client_cache = ClientMetadataCache::new(3600);
    let client_metadata = client_cache
        .get(&request_data.parameters.client_id)
        .await
        .ok();
    let requested_scope_str = request_data
        .parameters
        .scope
        .as_deref()
        .unwrap_or("atproto");
    let requested_scopes: Vec<&str> = requested_scope_str.split_whitespace().collect();
    let preferences =
        db::get_scope_preferences(&state.db, &did, &request_data.parameters.client_id)
            .await
            .unwrap_or_default();
    let pref_map: std::collections::HashMap<_, _> = preferences
        .iter()
        .map(|p| (p.scope.as_str(), p.granted))
        .collect();
    let requested_scope_strings: Vec<String> =
        requested_scopes.iter().map(|s| s.to_string()).collect();
    let show_consent = db::should_show_consent(
        &state.db,
        &did,
        &request_data.parameters.client_id,
        &requested_scope_strings,
    )
    .await
    .unwrap_or(true);
    let mut scopes = Vec::with_capacity(requested_scopes.len());
    for scope in &requested_scopes {
        let (category, required, description, display_name) =
            if let Some(def) = crate::oauth::scopes::SCOPE_DEFINITIONS.get(*scope) {
                (
                    def.category.display_name().to_string(),
                    def.required,
                    def.description.to_string(),
                    def.display_name.to_string(),
                )
            } else if scope.starts_with("ref:") {
                (
                    "Reference".to_string(),
                    false,
                    "Referenced scope".to_string(),
                    scope.to_string(),
                )
            } else {
                (
                    "Other".to_string(),
                    false,
                    format!("Access to {}", scope),
                    scope.to_string(),
                )
            };
        let granted = pref_map.get(*scope).copied();
        scopes.push(ScopeInfo {
            scope: scope.to_string(),
            category,
            required,
            description,
            display_name,
            granted,
        });
    }
    Json(ConsentResponse {
        request_uri: query.request_uri.clone(),
        client_id: request_data.parameters.client_id.clone(),
        client_name: client_metadata.as_ref().and_then(|m| m.client_name.clone()),
        client_uri: client_metadata.as_ref().and_then(|m| m.client_uri.clone()),
        logo_uri: client_metadata.as_ref().and_then(|m| m.logo_uri.clone()),
        scopes,
        show_consent,
        did,
    })
    .into_response()
}

/// Records the user's consent decision and issues the authorization code.
///
/// Approved scopes must be a subset of the requested ones and of a recognised form.
pub async fn consent_post(
    State(state): State<AppState>,
    Json(form): Json<ConsentSubmit>,
) -> Response {
    let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await {
        Ok(Some(data)) => data,
        Ok(None) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "invalid_request",
                    "error_description": "Invalid or expired request_uri"
                })),
            )
                .into_response();
        }
        Err(e) => {
            // Log the details server-side; never echo internal error text to the client.
            tracing::error!(error = ?e, "Failed to load authorization request");
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": "server_error",
                    "error_description": "A database error occurred."
                })),
            )
                .into_response();
        }
    };
    if request_data.expires_at < Utc::now() {
        let _ = db::delete_authorization_request(&state.db, &form.request_uri).await;
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "invalid_request",
                "error_description": "Authorization request has expired"
            })),
        )
            .into_response();
    }
    let did = match &request_data.did {
        Some(d) => d.clone(),
        None => {
            return (
                StatusCode::FORBIDDEN,
                Json(serde_json::json!({
                    "error": "access_denied",
                    "error_description": "Not authenticated"
                })),
            )
                .into_response();
        }
    };
    let requested_scope_str = request_data
        .parameters
        .scope
        .as_deref()
        .unwrap_or("atproto");
    let requested_scopes: Vec<&str> = requested_scope_str.split_whitespace().collect();
    // The user may only narrow what the client asked for, never add to it.
    if form
        .approved_scopes
        .iter()
        .any(|s| !requested_scopes.contains(&s.as_str()))
    {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "invalid_request",
                "error_description": "Approved scopes must be a subset of the requested scopes"
            })),
        )
            .into_response();
    }
    /// Scopes the user may approve or deny individually on the consent screen.
    fn is_granular(scope: &str) -> bool {
        ["repo:", "blob:", "rpc:", "account:", "identity:"]
            .iter()
            .any(|prefix| scope.starts_with(*prefix))
    }
    let has_granular_scopes = requested_scopes.iter().any(|s| is_granular(s));
    let user_denied_some_granular = has_granular_scopes
        && requested_scopes
            .iter()
            .filter(|s| is_granular(s))
            .any(|s| !form.approved_scopes.contains(&s.to_string()));
    let atproto_was_requested = requested_scopes.contains(&"atproto");
    if atproto_was_requested
        && !has_granular_scopes
        && !form.approved_scopes.contains(&"atproto".to_string())
    {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "invalid_request",
                "error_description": "The atproto scope was requested and must be approved"
            })),
        )
            .into_response();
    }
    let final_approved: Vec<String> = if user_denied_some_granular {
        form.approved_scopes
            .iter()
            .filter(|s| *s != "atproto")
            .cloned()
            .collect()
    } else {
        form.approved_scopes.clone()
    };
    if final_approved.is_empty() {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "invalid_request",
                "error_description": "At least one scope must be approved"
            })),
        )
            .into_response();
    }
    let approved_scope_str = final_approved.join(" ");
    // Every granular prefix (including `identity:`) is a valid approved scope.
    let has_valid_scope = final_approved.iter().all(|s| {
        matches!(
            s.as_str(),
            "atproto" | "transition:generic" | "transition:chat.bsky" | "transition:email"
        ) || is_granular(s)
            || s.starts_with("include:")
    });
    if !has_valid_scope {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "error": "invalid_request",
                "error_description": "Invalid scope format"
            })),
        )
            .into_response();
    }
    if form.remember {
        let preferences: Vec<db::ScopePreference> = requested_scopes
            .iter()
            .map(|s| db::ScopePreference {
                scope: s.to_string(),
                granted: form.approved_scopes.contains(&s.to_string()),
            })
            .collect();
        let _ = db::upsert_scope_preferences(
            &state.db,
            &did,
            &request_data.parameters.client_id,
            &preferences,
        )
        .await;
    }
    if let Err(e) =
        db::update_request_scope(&state.db, &form.request_uri, &approved_scope_str).await
    {
        tracing::warn!("Failed to update request scope: {:?}", e);
    }
    let code = Code::generate();
    if db::update_authorization_request(
        &state.db,
        &form.request_uri,
        &did,
        request_data.device_id.as_deref(),
        &code.0,
    )
    .await
    .is_err()
    {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({
                "error": "server_error",
                "error_description": "Failed to complete authorization"
            })),
        )
            .into_response();
    }
    let redirect_url = build_success_redirect(
&request_data.parameters.redirect_uri, 1289 &code.0, 1290 request_data.parameters.state.as_deref(), 1291 request_data.parameters.response_mode.as_deref(), 1292 ); 1293 Json(serde_json::json!({ 1294 "redirect_uri": redirect_url 1295 })) 1296 .into_response() 1297} 1298 1299pub async fn authorize_2fa_post( 1300 State(state): State<AppState>, 1301 headers: HeaderMap, 1302 Json(form): Json<Authorize2faSubmit>, 1303) -> Response { 1304 let json_error = |status: StatusCode, error: &str, description: &str| -> Response { 1305 ( 1306 status, 1307 Json(serde_json::json!({ 1308 "error": error, 1309 "error_description": description 1310 })), 1311 ) 1312 .into_response() 1313 }; 1314 let client_ip = extract_client_ip(&headers); 1315 if !state 1316 .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 1317 .await 1318 { 1319 tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 1320 return json_error( 1321 StatusCode::TOO_MANY_REQUESTS, 1322 "RateLimitExceeded", 1323 "Too many attempts. Please try again later.", 1324 ); 1325 } 1326 let challenge = match db::get_2fa_challenge(&state.db, &form.request_uri).await { 1327 Ok(Some(c)) => c, 1328 Ok(None) => { 1329 return json_error( 1330 StatusCode::BAD_REQUEST, 1331 "invalid_request", 1332 "No 2FA challenge found. Please start over.", 1333 ); 1334 } 1335 Err(_) => { 1336 return json_error( 1337 StatusCode::INTERNAL_SERVER_ERROR, 1338 "server_error", 1339 "An error occurred. Please try again.", 1340 ); 1341 } 1342 }; 1343 if challenge.expires_at < Utc::now() { 1344 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1345 return json_error( 1346 StatusCode::BAD_REQUEST, 1347 "invalid_request", 1348 "2FA code has expired. Please start over.", 1349 ); 1350 } 1351 if challenge.attempts >= MAX_2FA_ATTEMPTS { 1352 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1353 return json_error( 1354 StatusCode::FORBIDDEN, 1355 "access_denied", 1356 "Too many failed attempts. 
Please start over.", 1357 ); 1358 } 1359 let code_valid: bool = form 1360 .code 1361 .trim() 1362 .as_bytes() 1363 .ct_eq(challenge.code.as_bytes()) 1364 .into(); 1365 if !code_valid { 1366 let _ = db::increment_2fa_attempts(&state.db, challenge.id).await; 1367 return json_error( 1368 StatusCode::FORBIDDEN, 1369 "invalid_code", 1370 "Invalid verification code. Please try again.", 1371 ); 1372 } 1373 let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1374 let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 1375 Ok(Some(d)) => d, 1376 Ok(None) => { 1377 return json_error( 1378 StatusCode::BAD_REQUEST, 1379 "invalid_request", 1380 "Authorization request not found.", 1381 ); 1382 } 1383 Err(_) => { 1384 return json_error( 1385 StatusCode::INTERNAL_SERVER_ERROR, 1386 "server_error", 1387 "An error occurred.", 1388 ); 1389 } 1390 }; 1391 let code = Code::generate(); 1392 let device_id = extract_device_cookie(&headers); 1393 if db::update_authorization_request( 1394 &state.db, 1395 &form.request_uri, 1396 &challenge.did, 1397 device_id.as_deref(), 1398 &code.0, 1399 ) 1400 .await 1401 .is_err() 1402 { 1403 return json_error( 1404 StatusCode::INTERNAL_SERVER_ERROR, 1405 "server_error", 1406 "An error occurred. Please try again.", 1407 ); 1408 } 1409 let redirect_url = build_success_redirect( 1410 &request_data.parameters.redirect_uri, 1411 &code.0, 1412 request_data.parameters.state.as_deref(), 1413 request_data.parameters.response_mode.as_deref(), 1414 ); 1415 Json(serde_json::json!({ 1416 "redirect_uri": redirect_url 1417 })) 1418 .into_response() 1419}