this repo has no description
at main 15 kB view raw
1use crate::api::error::ApiError; 2use axum::{ 3 Json, 4 extract::State, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7}; 8use chrono::{DateTime, Utc}; 9use serde::{Deserialize, Serialize}; 10use sqlx::PgPool; 11use tracing::{error, info, warn}; 12 13use crate::auth::BearerAuth; 14use crate::state::{AppState, RateLimitKind}; 15use crate::types::PlainPassword; 16 17const REAUTH_WINDOW_SECONDS: i64 = 300; 18 19#[derive(Serialize)] 20#[serde(rename_all = "camelCase")] 21pub struct ReauthStatusResponse { 22 pub last_reauth_at: Option<DateTime<Utc>>, 23 pub reauth_required: bool, 24 pub available_methods: Vec<String>, 25} 26 27pub async fn get_reauth_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 28 let session = sqlx::query!( 29 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 30 &auth.0.did 31 ) 32 .fetch_optional(&state.db) 33 .await; 34 35 let last_reauth_at = match session { 36 Ok(Some(row)) => row.last_reauth_at, 37 Ok(None) => None, 38 Err(e) => { 39 error!("DB error: {:?}", e); 40 return ApiError::InternalError(None).into_response(); 41 } 42 }; 43 44 let reauth_required = is_reauth_required(last_reauth_at); 45 let available_methods = get_available_reauth_methods(&state.db, &auth.0.did).await; 46 47 Json(ReauthStatusResponse { 48 last_reauth_at, 49 reauth_required, 50 available_methods, 51 }) 52 .into_response() 53} 54 55#[derive(Deserialize)] 56#[serde(rename_all = "camelCase")] 57pub struct PasswordReauthInput { 58 pub password: PlainPassword, 59} 60 61#[derive(Serialize)] 62#[serde(rename_all = "camelCase")] 63pub struct ReauthResponse { 64 pub reauthed_at: DateTime<Utc>, 65} 66 67pub async fn reauth_password( 68 State(state): State<AppState>, 69 auth: BearerAuth, 70 Json(input): Json<PasswordReauthInput>, 71) -> Response { 72 let user = sqlx::query!( 73 "SELECT password_hash FROM users WHERE did = $1", 74 &*&auth.0.did 75 ) 76 .fetch_optional(&state.db) 77 .await; 78 79 let password_hash = match user { 80 Ok(Some(row)) => row.password_hash, 81 Ok(None) => { 82 return ApiError::AccountNotFound.into_response(); 83 } 84 Err(e) => { 85 error!("DB error: {:?}", e); 86 return ApiError::InternalError(None).into_response(); 87 } 88 }; 89 90 let password_valid = password_hash 91 .as_ref() 92 .map(|h| bcrypt::verify(&input.password, h).unwrap_or(false)) 93 .unwrap_or(false); 94 95 if !password_valid { 96 let app_passwords = sqlx::query!( 97 "SELECT ap.password_hash FROM app_passwords ap 98 JOIN users u ON ap.user_id = u.id 99 WHERE u.did = $1", 100 &auth.0.did 101 ) 102 .fetch_all(&state.db) 103 .await 104 .unwrap_or_default(); 105 106 let app_password_valid = app_passwords 107 .iter() 108 .any(|ap| bcrypt::verify(&input.password, &ap.password_hash).unwrap_or(false)); 109 110 if !app_password_valid { 111 warn!(did = %&auth.0.did, "Re-auth failed: invalid password"); 112 return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 113 } 114 } 115 116 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 117 Ok(reauthed_at) => { 118 info!(did = %&auth.0.did, "Re-auth successful via password"); 119 Json(ReauthResponse { reauthed_at }).into_response() 120 } 121 Err(e) => { 122 error!("DB error updating reauth: {:?}", e); 123 ApiError::InternalError(None).into_response() 124 } 125 } 126} 127 128#[derive(Deserialize)] 129#[serde(rename_all = "camelCase")] 130pub struct TotpReauthInput { 131 pub code: String, 132} 133 134pub async fn reauth_totp( 135 State(state): State<AppState>, 136 auth: BearerAuth, 137 Json(input): Json<TotpReauthInput>, 138) -> Response { 139 if !state 140 .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 141 .await 142 { 143 warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 144 return ApiError::RateLimitExceeded(Some( 145 "Too many verification attempts. Please try again in a few minutes.".into(), 146 )) 147 .into_response(); 148 } 149 150 let valid = 151 crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.0.did, &input.code) 152 .await; 153 154 if !valid { 155 warn!(did = %&auth.0.did, "Re-auth failed: invalid TOTP code"); 156 return ApiError::InvalidCode(Some("Invalid TOTP or backup code".into())).into_response(); 157 } 158 159 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 160 Ok(reauthed_at) => { 161 info!(did = %&auth.0.did, "Re-auth successful via TOTP"); 162 Json(ReauthResponse { reauthed_at }).into_response() 163 } 164 Err(e) => { 165 error!("DB error updating reauth: {:?}", e); 166 ApiError::InternalError(None).into_response() 167 } 168 } 169} 170 171#[derive(Serialize)] 172#[serde(rename_all = "camelCase")] 173pub struct PasskeyReauthStartResponse { 174 pub options: serde_json::Value, 175} 176 177pub async fn reauth_passkey_start(State(state): State<AppState>, auth: BearerAuth) -> Response { 178 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 179 180 let stored_passkeys = 181 match crate::auth::webauthn::get_passkeys_for_user(&state.db, &auth.0.did).await { 182 Ok(pks) => pks, 183 Err(e) => { 184 error!("Failed to get passkeys: {:?}", e); 185 return ApiError::InternalError(None).into_response(); 186 } 187 }; 188 189 if stored_passkeys.is_empty() { 190 return ApiError::NoPasskeys.into_response(); 191 } 192 193 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys 194 .iter() 195 .filter_map(|sp| sp.to_security_key().ok()) 196 .collect(); 197 198 if passkeys.is_empty() { 199 return ApiError::InternalError(Some("Failed to load passkeys".into())).into_response(); 200 } 201 202 let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 203 Ok(w) => w, 204 Err(e) => { 205 error!("Failed to create WebAuthn config: {:?}", e); 206 return ApiError::InternalError(None).into_response(); 207 } 208 }; 209 210 let (rcr, auth_state) = match webauthn.start_authentication(passkeys) { 211 Ok(result) => result, 212 Err(e) => { 213 error!("Failed to start passkey authentication: {:?}", e); 214 return ApiError::InternalError(None).into_response(); 215 } 216 }; 217 218 if let Err(e) = 219 crate::auth::webauthn::save_authentication_state(&state.db, &auth.0.did, &auth_state).await 220 { 221 error!("Failed to save authentication state: {:?}", e); 222 return ApiError::InternalError(None).into_response(); 223 } 224 225 let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({})); 226 Json(PasskeyReauthStartResponse { options }).into_response() 227} 228 229#[derive(Deserialize)] 230#[serde(rename_all = "camelCase")] 231pub struct PasskeyReauthFinishInput { 232 pub credential: serde_json::Value, 233} 234 235pub async fn reauth_passkey_finish( 236 State(state): State<AppState>, 237 auth: BearerAuth, 238 Json(input): Json<PasskeyReauthFinishInput>, 239) -> Response { 240 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 241 242 let auth_state = 243 match crate::auth::webauthn::load_authentication_state(&state.db, &auth.0.did).await { 244 Ok(Some(s)) => s, 245 Ok(None) => { 246 return ApiError::NoChallengeInProgress.into_response(); 247 } 248 Err(e) => { 249 error!("Failed to load authentication state: {:?}", e); 250 return ApiError::InternalError(None).into_response(); 251 } 252 }; 253 254 let credential: webauthn_rs::prelude::PublicKeyCredential = 255 match serde_json::from_value(input.credential) { 256 Ok(c) => c, 257 Err(e) => { 258 warn!("Failed to parse credential: {:?}", e); 259 return ApiError::InvalidCredential.into_response(); 260 } 261 }; 262 263 let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 264 Ok(w) => w, 265 Err(e) => { 266 error!("Failed to create WebAuthn config: {:?}", e); 267 return ApiError::InternalError(None).into_response(); 268 } 269 }; 270 271 let auth_result = match webauthn.finish_authentication(&credential, &auth_state) { 272 Ok(r) => r, 273 Err(e) => { 274 warn!(did = %&auth.0.did, "Passkey re-auth failed: {:?}", e); 275 return ApiError::AuthenticationFailed(Some("Passkey authentication failed".into())) 276 .into_response(); 277 } 278 }; 279 280 let cred_id_bytes = auth_result.cred_id().as_ref(); 281 match crate::auth::webauthn::update_passkey_counter( 282 &state.db, 283 cred_id_bytes, 284 auth_result.counter(), 285 ) 286 .await 287 { 288 Ok(false) => { 289 warn!(did = %&auth.0.did, "Passkey counter anomaly detected - possible cloned key"); 290 let _ = 291 crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await; 292 return ApiError::PasskeyCounterAnomaly.into_response(); 293 } 294 Err(e) => { 295 error!("Failed to update passkey counter: {:?}", e); 296 } 297 Ok(true) => {} 298 } 299 300 let _ = crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await; 301 302 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 303 Ok(reauthed_at) => { 304 info!(did = %&auth.0.did, "Re-auth successful via passkey"); 305 Json(ReauthResponse { reauthed_at }).into_response() 306 } 307 Err(e) => { 308 error!("DB error updating reauth: {:?}", e); 309 ApiError::InternalError(None).into_response() 310 } 311 } 312} 313 314pub async fn update_last_reauth_cached( 315 db: &PgPool, 316 cache: &std::sync::Arc<dyn crate::cache::Cache>, 317 did: &str, 318) -> Result<DateTime<Utc>, sqlx::Error> { 319 let now = Utc::now(); 320 sqlx::query!( 321 "UPDATE session_tokens SET last_reauth_at = $1, mfa_verified = TRUE WHERE did = $2", 322 now, 323 did 324 ) 325 .execute(db) 326 .await?; 327 let cache_key = format!("reauth:{}", did); 328 let _ = cache 329 .set( 330 &cache_key, 331 &now.timestamp().to_string(), 332 std::time::Duration::from_secs(REAUTH_WINDOW_SECONDS as u64), 333 ) 334 .await; 335 Ok(now) 336} 337 338fn is_reauth_required(last_reauth_at: Option<DateTime<Utc>>) -> bool { 339 match last_reauth_at { 340 None => true, 341 Some(t) => { 342 let elapsed = Utc::now().signed_duration_since(t); 343 elapsed.num_seconds() > REAUTH_WINDOW_SECONDS 344 } 345 } 346} 347 348async fn get_available_reauth_methods(db: &PgPool, did: &str) -> Vec<String> { 349 let mut methods = Vec::new(); 350 351 let has_password = sqlx::query_scalar!( 352 "SELECT password_hash IS NOT NULL as has_pw FROM users WHERE did = $1", 353 did 354 ) 355 .fetch_optional(db) 356 .await 357 .ok() 358 .flatten() 359 .unwrap_or(Some(false)); 360 361 if has_password == Some(true) { 362 methods.push("password".to_string()); 363 } 364 365 let has_totp = crate::api::server::totp::has_totp_enabled_db(db, did).await; 366 if has_totp { 367 methods.push("totp".to_string()); 368 } 369 370 let has_passkeys = crate::api::server::passkeys::has_passkeys_for_user_db(db, did).await; 371 if has_passkeys { 372 methods.push("passkey".to_string()); 373 } 374 375 methods 376} 377 378pub async fn check_reauth_required(db: &PgPool, did: &str) -> bool { 379 let session = sqlx::query!( 380 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 381 did 382 ) 383 .fetch_optional(db) 384 .await; 385 386 match session { 387 Ok(Some(row)) => is_reauth_required(row.last_reauth_at), 388 _ => true, 389 } 390} 391 392pub async fn check_reauth_required_cached( 393 db: &PgPool, 394 cache: &std::sync::Arc<dyn crate::cache::Cache>, 395 did: &str, 396) -> bool { 397 let cache_key = format!("reauth:{}", did); 398 if let Some(timestamp_str) = cache.get(&cache_key).await 399 && let Ok(timestamp) = timestamp_str.parse::<i64>() 400 { 401 let reauth_time = chrono::DateTime::from_timestamp(timestamp, 0); 402 if let Some(t) = reauth_time { 403 let elapsed = Utc::now().signed_duration_since(t); 404 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS { 405 return false; 406 } 407 } 408 } 409 let session = sqlx::query!( 410 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 411 did 412 ) 413 .fetch_optional(db) 414 .await; 415 416 match session { 417 Ok(Some(row)) => is_reauth_required(row.last_reauth_at), 418 _ => true, 419 } 420} 421 422#[derive(Serialize)] 423#[serde(rename_all = "camelCase")] 424pub struct ReauthRequiredError { 425 pub error: String, 426 pub message: String, 427 pub reauth_methods: Vec<String>, 428} 429 430pub async fn reauth_required_response(db: &PgPool, did: &str) -> Response { 431 let methods = get_available_reauth_methods(db, did).await; 432 ( 433 StatusCode::UNAUTHORIZED, 434 Json(ReauthRequiredError { 435 error: "ReauthRequired".to_string(), 436 message: "Re-authentication required for this action".to_string(), 437 reauth_methods: methods, 438 }), 439 ) 440 .into_response() 441} 442 443pub async fn check_legacy_session_mfa(db: &PgPool, did: &str) -> bool { 444 let session = sqlx::query!( 445 "SELECT legacy_login, mfa_verified, last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 446 did 447 ) 448 .fetch_optional(db) 449 .await; 450 451 match session { 452 Ok(Some(row)) => { 453 if !row.legacy_login { 454 return true; 455 } 456 if row.mfa_verified { 457 return true; 458 } 459 if let Some(last_reauth) = row.last_reauth_at { 460 let elapsed = chrono::Utc::now().signed_duration_since(last_reauth); 461 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS { 462 return true; 463 } 464 } 465 false 466 } 467 _ => true, 468 } 469} 470 471pub async fn update_mfa_verified(db: &PgPool, did: &str) -> Result<(), sqlx::Error> { 472 sqlx::query!( 473 "UPDATE session_tokens SET mfa_verified = TRUE, last_reauth_at = NOW() WHERE did = $1", 474 did 475 ) 476 .execute(db) 477 .await?; 478 Ok(()) 479} 480 481pub async fn legacy_mfa_required_response(db: &PgPool, did: &str) -> Response { 482 let methods = get_available_reauth_methods(db, did).await; 483 ( 484 StatusCode::FORBIDDEN, 485 Json(MfaVerificationRequiredError { 486 error: "MfaVerificationRequired".to_string(), 487 message: "This sensitive operation requires MFA verification. Your session was created via a legacy app that doesn't support MFA during login.".to_string(), 488 reauth_methods: methods, 489 }), 490 ) 491 .into_response() 492} 493 494#[derive(Serialize)] 495#[serde(rename_all = "camelCase")] 496pub struct MfaVerificationRequiredError { 497 pub error: String, 498 pub message: String, 499 pub reauth_methods: Vec<String>, 500}