this repo has no description
1use crate::api::error::ApiError; 2use axum::{ 3 Json, 4 extract::State, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7}; 8use chrono::{DateTime, Utc}; 9use serde::{Deserialize, Serialize}; 10use sqlx::PgPool; 11use tracing::{error, info, warn}; 12 13use crate::auth::BearerAuth; 14use crate::state::{AppState, RateLimitKind}; 15use crate::types::PlainPassword; 16 17const REAUTH_WINDOW_SECONDS: i64 = 300; 18 19#[derive(Serialize)] 20#[serde(rename_all = "camelCase")] 21pub struct ReauthStatusResponse { 22 pub last_reauth_at: Option<DateTime<Utc>>, 23 pub reauth_required: bool, 24 pub available_methods: Vec<String>, 25} 26 27pub async fn get_reauth_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 28 let session = sqlx::query!( 29 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 30 &auth.0.did 31 ) 32 .fetch_optional(&state.db) 33 .await; 34 35 let last_reauth_at = match session { 36 Ok(Some(row)) => row.last_reauth_at, 37 Ok(None) => None, 38 Err(e) => { 39 error!("DB error: {:?}", e); 40 return ApiError::InternalError(None).into_response(); 41 } 42 }; 43 44 let reauth_required = is_reauth_required(last_reauth_at); 45 let available_methods = get_available_reauth_methods(&state.db, &auth.0.did).await; 46 47 Json(ReauthStatusResponse { 48 last_reauth_at, 49 reauth_required, 50 available_methods, 51 }) 52 .into_response() 53} 54 55#[derive(Deserialize)] 56#[serde(rename_all = "camelCase")] 57pub struct PasswordReauthInput { 58 pub password: PlainPassword, 59} 60 61#[derive(Serialize)] 62#[serde(rename_all = "camelCase")] 63pub struct ReauthResponse { 64 pub reauthed_at: DateTime<Utc>, 65} 66 67pub async fn reauth_password( 68 State(state): State<AppState>, 69 auth: BearerAuth, 70 Json(input): Json<PasswordReauthInput>, 71) -> Response { 72 let user = sqlx::query!("SELECT password_hash FROM users WHERE did = $1", &*&auth.0.did) 73 .fetch_optional(&state.db) 74 .await; 75 76 let password_hash = match user { 77 Ok(Some(row)) => row.password_hash, 78 Ok(None) => { 79 return ApiError::AccountNotFound.into_response(); 80 } 81 Err(e) => { 82 error!("DB error: {:?}", e); 83 return ApiError::InternalError(None).into_response(); 84 } 85 }; 86 87 let password_valid = password_hash 88 .as_ref() 89 .map(|h| bcrypt::verify(&input.password, h).unwrap_or(false)) 90 .unwrap_or(false); 91 92 if !password_valid { 93 let app_passwords = sqlx::query!( 94 "SELECT ap.password_hash FROM app_passwords ap 95 JOIN users u ON ap.user_id = u.id 96 WHERE u.did = $1", 97 &auth.0.did 98 ) 99 .fetch_all(&state.db) 100 .await 101 .unwrap_or_default(); 102 103 let app_password_valid = app_passwords 104 .iter() 105 .any(|ap| bcrypt::verify(&input.password, &ap.password_hash).unwrap_or(false)); 106 107 if !app_password_valid { 108 warn!(did = %&auth.0.did, "Re-auth failed: invalid password"); 109 return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 110 } 111 } 112 113 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 114 Ok(reauthed_at) => { 115 info!(did = %&auth.0.did, "Re-auth successful via password"); 116 Json(ReauthResponse { reauthed_at }).into_response() 117 } 118 Err(e) => { 119 error!("DB error updating reauth: {:?}", e); 120 ApiError::InternalError(None).into_response() 121 } 122 } 123} 124 125#[derive(Deserialize)] 126#[serde(rename_all = "camelCase")] 127pub struct TotpReauthInput { 128 pub code: String, 129} 130 131pub async fn reauth_totp( 132 State(state): State<AppState>, 133 auth: BearerAuth, 134 Json(input): Json<TotpReauthInput>, 135) -> Response { 136 if !state 137 .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 138 .await 139 { 140 warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 141 return ApiError::RateLimitExceeded(Some("Too many verification attempts. Please try again in a few minutes.".into(),)) 142 .into_response(); 143 } 144 145 let valid = 146 crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.0.did, &input.code) 147 .await; 148 149 if !valid { 150 warn!(did = %&auth.0.did, "Re-auth failed: invalid TOTP code"); 151 return ApiError::InvalidCode(Some("Invalid TOTP or backup code".into())).into_response(); 152 } 153 154 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 155 Ok(reauthed_at) => { 156 info!(did = %&auth.0.did, "Re-auth successful via TOTP"); 157 Json(ReauthResponse { reauthed_at }).into_response() 158 } 159 Err(e) => { 160 error!("DB error updating reauth: {:?}", e); 161 ApiError::InternalError(None).into_response() 162 } 163 } 164} 165 166#[derive(Serialize)] 167#[serde(rename_all = "camelCase")] 168pub struct PasskeyReauthStartResponse { 169 pub options: serde_json::Value, 170} 171 172pub async fn reauth_passkey_start(State(state): State<AppState>, auth: BearerAuth) -> Response { 173 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 174 175 let stored_passkeys = 176 match crate::auth::webauthn::get_passkeys_for_user(&state.db, &auth.0.did).await { 177 Ok(pks) => pks, 178 Err(e) => { 179 error!("Failed to get passkeys: {:?}", e); 180 return ApiError::InternalError(None).into_response(); 181 } 182 }; 183 184 if stored_passkeys.is_empty() { 185 return ApiError::NoPasskeys.into_response(); 186 } 187 188 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys 189 .iter() 190 .filter_map(|sp| sp.to_security_key().ok()) 191 .collect(); 192 193 if passkeys.is_empty() { 194 return ApiError::InternalError(Some("Failed to load passkeys".into())).into_response(); 195 } 196 197 let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 198 Ok(w) => w, 199 Err(e) => { 200 error!("Failed to create WebAuthn config: {:?}", e); 201 return ApiError::InternalError(None).into_response(); 202 } 203 }; 204 205 let (rcr, auth_state) = match webauthn.start_authentication(passkeys) { 206 Ok(result) => result, 207 Err(e) => { 208 error!("Failed to start passkey authentication: {:?}", e); 209 return ApiError::InternalError(None).into_response(); 210 } 211 }; 212 213 if let Err(e) = 214 crate::auth::webauthn::save_authentication_state(&state.db, &auth.0.did, &auth_state).await 215 { 216 error!("Failed to save authentication state: {:?}", e); 217 return ApiError::InternalError(None).into_response(); 218 } 219 220 let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({})); 221 Json(PasskeyReauthStartResponse { options }).into_response() 222} 223 224#[derive(Deserialize)] 225#[serde(rename_all = "camelCase")] 226pub struct PasskeyReauthFinishInput { 227 pub credential: serde_json::Value, 228} 229 230pub async fn reauth_passkey_finish( 231 State(state): State<AppState>, 232 auth: BearerAuth, 233 Json(input): Json<PasskeyReauthFinishInput>, 234) -> Response { 235 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 236 237 let auth_state = 238 match crate::auth::webauthn::load_authentication_state(&state.db, &auth.0.did).await { 239 Ok(Some(s)) => s, 240 Ok(None) => { 241 return ApiError::NoChallengeInProgress.into_response(); 242 } 243 Err(e) => { 244 error!("Failed to load authentication state: {:?}", e); 245 return ApiError::InternalError(None).into_response(); 246 } 247 }; 248 249 let credential: webauthn_rs::prelude::PublicKeyCredential = 250 match serde_json::from_value(input.credential) { 251 Ok(c) => c, 252 Err(e) => { 253 warn!("Failed to parse credential: {:?}", e); 254 return ApiError::InvalidCredential.into_response(); 255 } 256 }; 257 258 let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 259 Ok(w) => w, 260 Err(e) => { 261 error!("Failed to create WebAuthn config: {:?}", e); 262 return ApiError::InternalError(None).into_response(); 263 } 264 }; 265 266 let auth_result = match webauthn.finish_authentication(&credential, &auth_state) { 267 Ok(r) => r, 268 Err(e) => { 269 warn!(did = %&auth.0.did, "Passkey re-auth failed: {:?}", e); 270 return ApiError::AuthenticationFailed(Some("Passkey authentication failed".into())) 271 .into_response(); 272 } 273 }; 274 275 let cred_id_bytes = auth_result.cred_id().as_ref(); 276 match crate::auth::webauthn::update_passkey_counter( 277 &state.db, 278 cred_id_bytes, 279 auth_result.counter(), 280 ) 281 .await 282 { 283 Ok(false) => { 284 warn!(did = %&auth.0.did, "Passkey counter anomaly detected - possible cloned key"); 285 let _ = 286 crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await; 287 return ApiError::PasskeyCounterAnomaly.into_response(); 288 } 289 Err(e) => { 290 error!("Failed to update passkey counter: {:?}", e); 291 } 292 Ok(true) => {} 293 } 294 295 let _ = crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await; 296 297 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 298 Ok(reauthed_at) => { 299 info!(did = %&auth.0.did, "Re-auth successful via passkey"); 300 Json(ReauthResponse { reauthed_at }).into_response() 301 } 302 Err(e) => { 303 error!("DB error updating reauth: {:?}", e); 304 ApiError::InternalError(None).into_response() 305 } 306 } 307} 308 309pub async fn update_last_reauth_cached( 310 db: &PgPool, 311 cache: &std::sync::Arc<dyn crate::cache::Cache>, 312 did: &str, 313) -> Result<DateTime<Utc>, sqlx::Error> { 314 let now = Utc::now(); 315 sqlx::query!( 316 "UPDATE session_tokens SET last_reauth_at = $1, mfa_verified = TRUE WHERE did = $2", 317 now, 318 did 319 ) 320 .execute(db) 321 .await?; 322 let cache_key = format!("reauth:{}", did); 323 let _ = cache 324 .set( 325 &cache_key, 326 &now.timestamp().to_string(), 327 std::time::Duration::from_secs(REAUTH_WINDOW_SECONDS as u64), 328 ) 329 .await; 330 Ok(now) 331} 332 333fn is_reauth_required(last_reauth_at: Option<DateTime<Utc>>) -> bool { 334 match last_reauth_at { 335 None => true, 336 Some(t) => { 337 let elapsed = Utc::now().signed_duration_since(t); 338 elapsed.num_seconds() > REAUTH_WINDOW_SECONDS 339 } 340 } 341} 342 343async fn get_available_reauth_methods(db: &PgPool, did: &str) -> Vec<String> { 344 let mut methods = Vec::new(); 345 346 let has_password = sqlx::query_scalar!( 347 "SELECT password_hash IS NOT NULL as has_pw FROM users WHERE did = $1", 348 did 349 ) 350 .fetch_optional(db) 351 .await 352 .ok() 353 .flatten() 354 .unwrap_or(Some(false)); 355 356 if has_password == Some(true) { 357 methods.push("password".to_string()); 358 } 359 360 let has_totp = crate::api::server::totp::has_totp_enabled_db(db, did).await; 361 if has_totp { 362 methods.push("totp".to_string()); 363 } 364 365 let has_passkeys = crate::api::server::passkeys::has_passkeys_for_user_db(db, did).await; 366 if has_passkeys { 367 methods.push("passkey".to_string()); 368 } 369 370 methods 371} 372 373pub async fn check_reauth_required(db: &PgPool, did: &str) -> bool { 374 let session = sqlx::query!( 375 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 376 did 377 ) 378 .fetch_optional(db) 379 .await; 380 381 match session { 382 Ok(Some(row)) => is_reauth_required(row.last_reauth_at), 383 _ => true, 384 } 385} 386 387pub async fn check_reauth_required_cached( 388 db: &PgPool, 389 cache: &std::sync::Arc<dyn crate::cache::Cache>, 390 did: &str, 391) -> bool { 392 let cache_key = format!("reauth:{}", did); 393 if let Some(timestamp_str) = cache.get(&cache_key).await 394 && let Ok(timestamp) = timestamp_str.parse::<i64>() 395 { 396 let reauth_time = chrono::DateTime::from_timestamp(timestamp, 0); 397 if let Some(t) = reauth_time { 398 let elapsed = Utc::now().signed_duration_since(t); 399 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS { 400 return false; 401 } 402 } 403 } 404 let session = sqlx::query!( 405 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 406 did 407 ) 408 .fetch_optional(db) 409 .await; 410 411 match session { 412 Ok(Some(row)) => is_reauth_required(row.last_reauth_at), 413 _ => true, 414 } 415} 416 417#[derive(Serialize)] 418#[serde(rename_all = "camelCase")] 419pub struct ReauthRequiredError { 420 pub error: String, 421 pub message: String, 422 pub reauth_methods: Vec<String>, 423} 424 425pub async fn reauth_required_response(db: &PgPool, did: &str) -> Response { 426 let methods = get_available_reauth_methods(db, did).await; 427 ( 428 StatusCode::UNAUTHORIZED, 429 Json(ReauthRequiredError { 430 error: "ReauthRequired".to_string(), 431 message: "Re-authentication required for this action".to_string(), 432 reauth_methods: methods, 433 }), 434 ) 435 .into_response() 436} 437 438pub async fn check_legacy_session_mfa(db: &PgPool, did: &str) -> bool { 439 let session = sqlx::query!( 440 "SELECT legacy_login, mfa_verified, last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 441 did 442 ) 443 .fetch_optional(db) 444 .await; 445 446 match session { 447 Ok(Some(row)) => { 448 if !row.legacy_login { 449 return true; 450 } 451 if row.mfa_verified { 452 return true; 453 } 454 if let Some(last_reauth) = row.last_reauth_at { 455 let elapsed = chrono::Utc::now().signed_duration_since(last_reauth); 456 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS { 457 return true; 458 } 459 } 460 false 461 } 462 _ => true, 463 } 464} 465 466pub async fn update_mfa_verified(db: &PgPool, did: &str) -> Result<(), sqlx::Error> { 467 sqlx::query!( 468 "UPDATE session_tokens SET mfa_verified = TRUE, last_reauth_at = NOW() WHERE did = $1", 469 did 470 ) 471 .execute(db) 472 .await?; 473 Ok(()) 474} 475 476pub async fn legacy_mfa_required_response(db: &PgPool, did: &str) -> Response { 477 let methods = get_available_reauth_methods(db, did).await; 478 ( 479 StatusCode::FORBIDDEN, 480 Json(MfaVerificationRequiredError { 481 error: "MfaVerificationRequired".to_string(), 482 message: "This sensitive operation requires MFA verification. Your session was created via a legacy app that doesn't support MFA during login.".to_string(), 483 reauth_methods: methods, 484 }), 485 ) 486 .into_response() 487} 488 489#[derive(Serialize)] 490#[serde(rename_all = "camelCase")] 491pub struct MfaVerificationRequiredError { 492 pub error: String, 493 pub message: String, 494 pub reauth_methods: Vec<String>, 495}