This repository has no description.
1use crate::api::error::ApiError;
2use axum::{
3 Json,
4 extract::State,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use sqlx::PgPool;
11use tracing::{error, info, warn};
12
13use crate::auth::BearerAuth;
14use crate::state::{AppState, RateLimitKind};
15use crate::types::PlainPassword;
16
/// Length, in seconds, of the window during which a recorded re-authentication
/// is considered fresh (see `is_reauth_required`). Also used as the TTL of the
/// `reauth:{did}` cache entry written by `update_last_reauth_cached`.
const REAUTH_WINDOW_SECONDS: i64 = 5 * 60;
18
19#[derive(Serialize)]
20#[serde(rename_all = "camelCase")]
21pub struct ReauthStatusResponse {
22 pub last_reauth_at: Option<DateTime<Utc>>,
23 pub reauth_required: bool,
24 pub available_methods: Vec<String>,
25}
26
27pub async fn get_reauth_status(State(state): State<AppState>, auth: BearerAuth) -> Response {
28 let session = sqlx::query!(
29 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1",
30 &auth.0.did
31 )
32 .fetch_optional(&state.db)
33 .await;
34
35 let last_reauth_at = match session {
36 Ok(Some(row)) => row.last_reauth_at,
37 Ok(None) => None,
38 Err(e) => {
39 error!("DB error: {:?}", e);
40 return ApiError::InternalError(None).into_response();
41 }
42 };
43
44 let reauth_required = is_reauth_required(last_reauth_at);
45 let available_methods = get_available_reauth_methods(&state.db, &auth.0.did).await;
46
47 Json(ReauthStatusResponse {
48 last_reauth_at,
49 reauth_required,
50 available_methods,
51 })
52 .into_response()
53}
54
55#[derive(Deserialize)]
56#[serde(rename_all = "camelCase")]
57pub struct PasswordReauthInput {
58 pub password: PlainPassword,
59}
60
61#[derive(Serialize)]
62#[serde(rename_all = "camelCase")]
63pub struct ReauthResponse {
64 pub reauthed_at: DateTime<Utc>,
65}
66
67pub async fn reauth_password(
68 State(state): State<AppState>,
69 auth: BearerAuth,
70 Json(input): Json<PasswordReauthInput>,
71) -> Response {
72 let user = sqlx::query!(
73 "SELECT password_hash FROM users WHERE did = $1",
74 &*&auth.0.did
75 )
76 .fetch_optional(&state.db)
77 .await;
78
79 let password_hash = match user {
80 Ok(Some(row)) => row.password_hash,
81 Ok(None) => {
82 return ApiError::AccountNotFound.into_response();
83 }
84 Err(e) => {
85 error!("DB error: {:?}", e);
86 return ApiError::InternalError(None).into_response();
87 }
88 };
89
90 let password_valid = password_hash
91 .as_ref()
92 .map(|h| bcrypt::verify(&input.password, h).unwrap_or(false))
93 .unwrap_or(false);
94
95 if !password_valid {
96 let app_passwords = sqlx::query!(
97 "SELECT ap.password_hash FROM app_passwords ap
98 JOIN users u ON ap.user_id = u.id
99 WHERE u.did = $1",
100 &auth.0.did
101 )
102 .fetch_all(&state.db)
103 .await
104 .unwrap_or_default();
105
106 let app_password_valid = app_passwords
107 .iter()
108 .any(|ap| bcrypt::verify(&input.password, &ap.password_hash).unwrap_or(false));
109
110 if !app_password_valid {
111 warn!(did = %&auth.0.did, "Re-auth failed: invalid password");
112 return ApiError::InvalidPassword("Password is incorrect".into()).into_response();
113 }
114 }
115
116 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await {
117 Ok(reauthed_at) => {
118 info!(did = %&auth.0.did, "Re-auth successful via password");
119 Json(ReauthResponse { reauthed_at }).into_response()
120 }
121 Err(e) => {
122 error!("DB error updating reauth: {:?}", e);
123 ApiError::InternalError(None).into_response()
124 }
125 }
126}
127
128#[derive(Deserialize)]
129#[serde(rename_all = "camelCase")]
130pub struct TotpReauthInput {
131 pub code: String,
132}
133
134pub async fn reauth_totp(
135 State(state): State<AppState>,
136 auth: BearerAuth,
137 Json(input): Json<TotpReauthInput>,
138) -> Response {
139 if !state
140 .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did)
141 .await
142 {
143 warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded");
144 return ApiError::RateLimitExceeded(Some(
145 "Too many verification attempts. Please try again in a few minutes.".into(),
146 ))
147 .into_response();
148 }
149
150 let valid =
151 crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.0.did, &input.code)
152 .await;
153
154 if !valid {
155 warn!(did = %&auth.0.did, "Re-auth failed: invalid TOTP code");
156 return ApiError::InvalidCode(Some("Invalid TOTP or backup code".into())).into_response();
157 }
158
159 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await {
160 Ok(reauthed_at) => {
161 info!(did = %&auth.0.did, "Re-auth successful via TOTP");
162 Json(ReauthResponse { reauthed_at }).into_response()
163 }
164 Err(e) => {
165 error!("DB error updating reauth: {:?}", e);
166 ApiError::InternalError(None).into_response()
167 }
168 }
169}
170
171#[derive(Serialize)]
172#[serde(rename_all = "camelCase")]
173pub struct PasskeyReauthStartResponse {
174 pub options: serde_json::Value,
175}
176
177pub async fn reauth_passkey_start(State(state): State<AppState>, auth: BearerAuth) -> Response {
178 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
179
180 let stored_passkeys =
181 match crate::auth::webauthn::get_passkeys_for_user(&state.db, &auth.0.did).await {
182 Ok(pks) => pks,
183 Err(e) => {
184 error!("Failed to get passkeys: {:?}", e);
185 return ApiError::InternalError(None).into_response();
186 }
187 };
188
189 if stored_passkeys.is_empty() {
190 return ApiError::NoPasskeys.into_response();
191 }
192
193 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys
194 .iter()
195 .filter_map(|sp| sp.to_security_key().ok())
196 .collect();
197
198 if passkeys.is_empty() {
199 return ApiError::InternalError(Some("Failed to load passkeys".into())).into_response();
200 }
201
202 let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) {
203 Ok(w) => w,
204 Err(e) => {
205 error!("Failed to create WebAuthn config: {:?}", e);
206 return ApiError::InternalError(None).into_response();
207 }
208 };
209
210 let (rcr, auth_state) = match webauthn.start_authentication(passkeys) {
211 Ok(result) => result,
212 Err(e) => {
213 error!("Failed to start passkey authentication: {:?}", e);
214 return ApiError::InternalError(None).into_response();
215 }
216 };
217
218 if let Err(e) =
219 crate::auth::webauthn::save_authentication_state(&state.db, &auth.0.did, &auth_state).await
220 {
221 error!("Failed to save authentication state: {:?}", e);
222 return ApiError::InternalError(None).into_response();
223 }
224
225 let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({}));
226 Json(PasskeyReauthStartResponse { options }).into_response()
227}
228
229#[derive(Deserialize)]
230#[serde(rename_all = "camelCase")]
231pub struct PasskeyReauthFinishInput {
232 pub credential: serde_json::Value,
233}
234
235pub async fn reauth_passkey_finish(
236 State(state): State<AppState>,
237 auth: BearerAuth,
238 Json(input): Json<PasskeyReauthFinishInput>,
239) -> Response {
240 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
241
242 let auth_state =
243 match crate::auth::webauthn::load_authentication_state(&state.db, &auth.0.did).await {
244 Ok(Some(s)) => s,
245 Ok(None) => {
246 return ApiError::NoChallengeInProgress.into_response();
247 }
248 Err(e) => {
249 error!("Failed to load authentication state: {:?}", e);
250 return ApiError::InternalError(None).into_response();
251 }
252 };
253
254 let credential: webauthn_rs::prelude::PublicKeyCredential =
255 match serde_json::from_value(input.credential) {
256 Ok(c) => c,
257 Err(e) => {
258 warn!("Failed to parse credential: {:?}", e);
259 return ApiError::InvalidCredential.into_response();
260 }
261 };
262
263 let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) {
264 Ok(w) => w,
265 Err(e) => {
266 error!("Failed to create WebAuthn config: {:?}", e);
267 return ApiError::InternalError(None).into_response();
268 }
269 };
270
271 let auth_result = match webauthn.finish_authentication(&credential, &auth_state) {
272 Ok(r) => r,
273 Err(e) => {
274 warn!(did = %&auth.0.did, "Passkey re-auth failed: {:?}", e);
275 return ApiError::AuthenticationFailed(Some("Passkey authentication failed".into()))
276 .into_response();
277 }
278 };
279
280 let cred_id_bytes = auth_result.cred_id().as_ref();
281 match crate::auth::webauthn::update_passkey_counter(
282 &state.db,
283 cred_id_bytes,
284 auth_result.counter(),
285 )
286 .await
287 {
288 Ok(false) => {
289 warn!(did = %&auth.0.did, "Passkey counter anomaly detected - possible cloned key");
290 let _ =
291 crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await;
292 return ApiError::PasskeyCounterAnomaly.into_response();
293 }
294 Err(e) => {
295 error!("Failed to update passkey counter: {:?}", e);
296 }
297 Ok(true) => {}
298 }
299
300 let _ = crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await;
301
302 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await {
303 Ok(reauthed_at) => {
304 info!(did = %&auth.0.did, "Re-auth successful via passkey");
305 Json(ReauthResponse { reauthed_at }).into_response()
306 }
307 Err(e) => {
308 error!("DB error updating reauth: {:?}", e);
309 ApiError::InternalError(None).into_response()
310 }
311 }
312}
313
314pub async fn update_last_reauth_cached(
315 db: &PgPool,
316 cache: &std::sync::Arc<dyn crate::cache::Cache>,
317 did: &str,
318) -> Result<DateTime<Utc>, sqlx::Error> {
319 let now = Utc::now();
320 sqlx::query!(
321 "UPDATE session_tokens SET last_reauth_at = $1, mfa_verified = TRUE WHERE did = $2",
322 now,
323 did
324 )
325 .execute(db)
326 .await?;
327 let cache_key = format!("reauth:{}", did);
328 let _ = cache
329 .set(
330 &cache_key,
331 &now.timestamp().to_string(),
332 std::time::Duration::from_secs(REAUTH_WINDOW_SECONDS as u64),
333 )
334 .await;
335 Ok(now)
336}
337
338fn is_reauth_required(last_reauth_at: Option<DateTime<Utc>>) -> bool {
339 match last_reauth_at {
340 None => true,
341 Some(t) => {
342 let elapsed = Utc::now().signed_duration_since(t);
343 elapsed.num_seconds() > REAUTH_WINDOW_SECONDS
344 }
345 }
346}
347
348async fn get_available_reauth_methods(db: &PgPool, did: &str) -> Vec<String> {
349 let mut methods = Vec::new();
350
351 let has_password = sqlx::query_scalar!(
352 "SELECT password_hash IS NOT NULL as has_pw FROM users WHERE did = $1",
353 did
354 )
355 .fetch_optional(db)
356 .await
357 .ok()
358 .flatten()
359 .unwrap_or(Some(false));
360
361 if has_password == Some(true) {
362 methods.push("password".to_string());
363 }
364
365 let has_totp = crate::api::server::totp::has_totp_enabled_db(db, did).await;
366 if has_totp {
367 methods.push("totp".to_string());
368 }
369
370 let has_passkeys = crate::api::server::passkeys::has_passkeys_for_user_db(db, did).await;
371 if has_passkeys {
372 methods.push("passkey".to_string());
373 }
374
375 methods
376}
377
378pub async fn check_reauth_required(db: &PgPool, did: &str) -> bool {
379 let session = sqlx::query!(
380 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1",
381 did
382 )
383 .fetch_optional(db)
384 .await;
385
386 match session {
387 Ok(Some(row)) => is_reauth_required(row.last_reauth_at),
388 _ => true,
389 }
390}
391
392pub async fn check_reauth_required_cached(
393 db: &PgPool,
394 cache: &std::sync::Arc<dyn crate::cache::Cache>,
395 did: &str,
396) -> bool {
397 let cache_key = format!("reauth:{}", did);
398 if let Some(timestamp_str) = cache.get(&cache_key).await
399 && let Ok(timestamp) = timestamp_str.parse::<i64>()
400 {
401 let reauth_time = chrono::DateTime::from_timestamp(timestamp, 0);
402 if let Some(t) = reauth_time {
403 let elapsed = Utc::now().signed_duration_since(t);
404 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS {
405 return false;
406 }
407 }
408 }
409 let session = sqlx::query!(
410 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1",
411 did
412 )
413 .fetch_optional(db)
414 .await;
415
416 match session {
417 Ok(Some(row)) => is_reauth_required(row.last_reauth_at),
418 _ => true,
419 }
420}
421
422#[derive(Serialize)]
423#[serde(rename_all = "camelCase")]
424pub struct ReauthRequiredError {
425 pub error: String,
426 pub message: String,
427 pub reauth_methods: Vec<String>,
428}
429
430pub async fn reauth_required_response(db: &PgPool, did: &str) -> Response {
431 let methods = get_available_reauth_methods(db, did).await;
432 (
433 StatusCode::UNAUTHORIZED,
434 Json(ReauthRequiredError {
435 error: "ReauthRequired".to_string(),
436 message: "Re-authentication required for this action".to_string(),
437 reauth_methods: methods,
438 }),
439 )
440 .into_response()
441}
442
443pub async fn check_legacy_session_mfa(db: &PgPool, did: &str) -> bool {
444 let session = sqlx::query!(
445 "SELECT legacy_login, mfa_verified, last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1",
446 did
447 )
448 .fetch_optional(db)
449 .await;
450
451 match session {
452 Ok(Some(row)) => {
453 if !row.legacy_login {
454 return true;
455 }
456 if row.mfa_verified {
457 return true;
458 }
459 if let Some(last_reauth) = row.last_reauth_at {
460 let elapsed = chrono::Utc::now().signed_duration_since(last_reauth);
461 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS {
462 return true;
463 }
464 }
465 false
466 }
467 _ => true,
468 }
469}
470
471pub async fn update_mfa_verified(db: &PgPool, did: &str) -> Result<(), sqlx::Error> {
472 sqlx::query!(
473 "UPDATE session_tokens SET mfa_verified = TRUE, last_reauth_at = NOW() WHERE did = $1",
474 did
475 )
476 .execute(db)
477 .await?;
478 Ok(())
479}
480
481pub async fn legacy_mfa_required_response(db: &PgPool, did: &str) -> Response {
482 let methods = get_available_reauth_methods(db, did).await;
483 (
484 StatusCode::FORBIDDEN,
485 Json(MfaVerificationRequiredError {
486 error: "MfaVerificationRequired".to_string(),
487 message: "This sensitive operation requires MFA verification. Your session was created via a legacy app that doesn't support MFA during login.".to_string(),
488 reauth_methods: methods,
489 }),
490 )
491 .into_response()
492}
493
494#[derive(Serialize)]
495#[serde(rename_all = "camelCase")]
496pub struct MfaVerificationRequiredError {
497 pub error: String,
498 pub message: String,
499 pub reauth_methods: Vec<String>,
500}