this repo has no description
1use crate::api::error::ApiError;
2use axum::{
3 Json,
4 extract::State,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use sqlx::PgPool;
11use tracing::{error, info, warn};
12
13use crate::auth::BearerAuth;
14use crate::state::{AppState, RateLimitKind};
15use crate::types::PlainPassword;
16
/// Length of the re-authentication grace window, in seconds (five minutes).
///
/// A recorded re-auth counts as fresh while no more than this many whole seconds
/// have elapsed; the same value is used as the TTL of the `reauth:{did}` cache entry.
const REAUTH_WINDOW_SECONDS: i64 = 5 * 60;
18
19#[derive(Serialize)]
20#[serde(rename_all = "camelCase")]
21pub struct ReauthStatusResponse {
22 pub last_reauth_at: Option<DateTime<Utc>>,
23 pub reauth_required: bool,
24 pub available_methods: Vec<String>,
25}
26
27pub async fn get_reauth_status(State(state): State<AppState>, auth: BearerAuth) -> Response {
28 let session = sqlx::query!(
29 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1",
30 &auth.0.did
31 )
32 .fetch_optional(&state.db)
33 .await;
34
35 let last_reauth_at = match session {
36 Ok(Some(row)) => row.last_reauth_at,
37 Ok(None) => None,
38 Err(e) => {
39 error!("DB error: {:?}", e);
40 return ApiError::InternalError(None).into_response();
41 }
42 };
43
44 let reauth_required = is_reauth_required(last_reauth_at);
45 let available_methods = get_available_reauth_methods(&state.db, &auth.0.did).await;
46
47 Json(ReauthStatusResponse {
48 last_reauth_at,
49 reauth_required,
50 available_methods,
51 })
52 .into_response()
53}
54
55#[derive(Deserialize)]
56#[serde(rename_all = "camelCase")]
57pub struct PasswordReauthInput {
58 pub password: PlainPassword,
59}
60
61#[derive(Serialize)]
62#[serde(rename_all = "camelCase")]
63pub struct ReauthResponse {
64 pub reauthed_at: DateTime<Utc>,
65}
66
67pub async fn reauth_password(
68 State(state): State<AppState>,
69 auth: BearerAuth,
70 Json(input): Json<PasswordReauthInput>,
71) -> Response {
72 let user = sqlx::query!("SELECT password_hash FROM users WHERE did = $1", &*&auth.0.did)
73 .fetch_optional(&state.db)
74 .await;
75
76 let password_hash = match user {
77 Ok(Some(row)) => row.password_hash,
78 Ok(None) => {
79 return ApiError::AccountNotFound.into_response();
80 }
81 Err(e) => {
82 error!("DB error: {:?}", e);
83 return ApiError::InternalError(None).into_response();
84 }
85 };
86
87 let password_valid = password_hash
88 .as_ref()
89 .map(|h| bcrypt::verify(&input.password, h).unwrap_or(false))
90 .unwrap_or(false);
91
92 if !password_valid {
93 let app_passwords = sqlx::query!(
94 "SELECT ap.password_hash FROM app_passwords ap
95 JOIN users u ON ap.user_id = u.id
96 WHERE u.did = $1",
97 &auth.0.did
98 )
99 .fetch_all(&state.db)
100 .await
101 .unwrap_or_default();
102
103 let app_password_valid = app_passwords
104 .iter()
105 .any(|ap| bcrypt::verify(&input.password, &ap.password_hash).unwrap_or(false));
106
107 if !app_password_valid {
108 warn!(did = %&auth.0.did, "Re-auth failed: invalid password");
109 return ApiError::InvalidPassword("Password is incorrect".into()).into_response();
110 }
111 }
112
113 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await {
114 Ok(reauthed_at) => {
115 info!(did = %&auth.0.did, "Re-auth successful via password");
116 Json(ReauthResponse { reauthed_at }).into_response()
117 }
118 Err(e) => {
119 error!("DB error updating reauth: {:?}", e);
120 ApiError::InternalError(None).into_response()
121 }
122 }
123}
124
125#[derive(Deserialize)]
126#[serde(rename_all = "camelCase")]
127pub struct TotpReauthInput {
128 pub code: String,
129}
130
131pub async fn reauth_totp(
132 State(state): State<AppState>,
133 auth: BearerAuth,
134 Json(input): Json<TotpReauthInput>,
135) -> Response {
136 if !state
137 .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did)
138 .await
139 {
140 warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded");
141 return ApiError::RateLimitExceeded(Some("Too many verification attempts. Please try again in a few minutes.".into(),))
142 .into_response();
143 }
144
145 let valid =
146 crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.0.did, &input.code)
147 .await;
148
149 if !valid {
150 warn!(did = %&auth.0.did, "Re-auth failed: invalid TOTP code");
151 return ApiError::InvalidCode(Some("Invalid TOTP or backup code".into())).into_response();
152 }
153
154 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await {
155 Ok(reauthed_at) => {
156 info!(did = %&auth.0.did, "Re-auth successful via TOTP");
157 Json(ReauthResponse { reauthed_at }).into_response()
158 }
159 Err(e) => {
160 error!("DB error updating reauth: {:?}", e);
161 ApiError::InternalError(None).into_response()
162 }
163 }
164}
165
166#[derive(Serialize)]
167#[serde(rename_all = "camelCase")]
168pub struct PasskeyReauthStartResponse {
169 pub options: serde_json::Value,
170}
171
172pub async fn reauth_passkey_start(State(state): State<AppState>, auth: BearerAuth) -> Response {
173 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
174
175 let stored_passkeys =
176 match crate::auth::webauthn::get_passkeys_for_user(&state.db, &auth.0.did).await {
177 Ok(pks) => pks,
178 Err(e) => {
179 error!("Failed to get passkeys: {:?}", e);
180 return ApiError::InternalError(None).into_response();
181 }
182 };
183
184 if stored_passkeys.is_empty() {
185 return ApiError::NoPasskeys.into_response();
186 }
187
188 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys
189 .iter()
190 .filter_map(|sp| sp.to_security_key().ok())
191 .collect();
192
193 if passkeys.is_empty() {
194 return ApiError::InternalError(Some("Failed to load passkeys".into())).into_response();
195 }
196
197 let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) {
198 Ok(w) => w,
199 Err(e) => {
200 error!("Failed to create WebAuthn config: {:?}", e);
201 return ApiError::InternalError(None).into_response();
202 }
203 };
204
205 let (rcr, auth_state) = match webauthn.start_authentication(passkeys) {
206 Ok(result) => result,
207 Err(e) => {
208 error!("Failed to start passkey authentication: {:?}", e);
209 return ApiError::InternalError(None).into_response();
210 }
211 };
212
213 if let Err(e) =
214 crate::auth::webauthn::save_authentication_state(&state.db, &auth.0.did, &auth_state).await
215 {
216 error!("Failed to save authentication state: {:?}", e);
217 return ApiError::InternalError(None).into_response();
218 }
219
220 let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({}));
221 Json(PasskeyReauthStartResponse { options }).into_response()
222}
223
224#[derive(Deserialize)]
225#[serde(rename_all = "camelCase")]
226pub struct PasskeyReauthFinishInput {
227 pub credential: serde_json::Value,
228}
229
230pub async fn reauth_passkey_finish(
231 State(state): State<AppState>,
232 auth: BearerAuth,
233 Json(input): Json<PasskeyReauthFinishInput>,
234) -> Response {
235 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
236
237 let auth_state =
238 match crate::auth::webauthn::load_authentication_state(&state.db, &auth.0.did).await {
239 Ok(Some(s)) => s,
240 Ok(None) => {
241 return ApiError::NoChallengeInProgress.into_response();
242 }
243 Err(e) => {
244 error!("Failed to load authentication state: {:?}", e);
245 return ApiError::InternalError(None).into_response();
246 }
247 };
248
249 let credential: webauthn_rs::prelude::PublicKeyCredential =
250 match serde_json::from_value(input.credential) {
251 Ok(c) => c,
252 Err(e) => {
253 warn!("Failed to parse credential: {:?}", e);
254 return ApiError::InvalidCredential.into_response();
255 }
256 };
257
258 let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) {
259 Ok(w) => w,
260 Err(e) => {
261 error!("Failed to create WebAuthn config: {:?}", e);
262 return ApiError::InternalError(None).into_response();
263 }
264 };
265
266 let auth_result = match webauthn.finish_authentication(&credential, &auth_state) {
267 Ok(r) => r,
268 Err(e) => {
269 warn!(did = %&auth.0.did, "Passkey re-auth failed: {:?}", e);
270 return ApiError::AuthenticationFailed(Some("Passkey authentication failed".into()))
271 .into_response();
272 }
273 };
274
275 let cred_id_bytes = auth_result.cred_id().as_ref();
276 match crate::auth::webauthn::update_passkey_counter(
277 &state.db,
278 cred_id_bytes,
279 auth_result.counter(),
280 )
281 .await
282 {
283 Ok(false) => {
284 warn!(did = %&auth.0.did, "Passkey counter anomaly detected - possible cloned key");
285 let _ =
286 crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await;
287 return ApiError::PasskeyCounterAnomaly.into_response();
288 }
289 Err(e) => {
290 error!("Failed to update passkey counter: {:?}", e);
291 }
292 Ok(true) => {}
293 }
294
295 let _ = crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await;
296
297 match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await {
298 Ok(reauthed_at) => {
299 info!(did = %&auth.0.did, "Re-auth successful via passkey");
300 Json(ReauthResponse { reauthed_at }).into_response()
301 }
302 Err(e) => {
303 error!("DB error updating reauth: {:?}", e);
304 ApiError::InternalError(None).into_response()
305 }
306 }
307}
308
309pub async fn update_last_reauth_cached(
310 db: &PgPool,
311 cache: &std::sync::Arc<dyn crate::cache::Cache>,
312 did: &str,
313) -> Result<DateTime<Utc>, sqlx::Error> {
314 let now = Utc::now();
315 sqlx::query!(
316 "UPDATE session_tokens SET last_reauth_at = $1, mfa_verified = TRUE WHERE did = $2",
317 now,
318 did
319 )
320 .execute(db)
321 .await?;
322 let cache_key = format!("reauth:{}", did);
323 let _ = cache
324 .set(
325 &cache_key,
326 &now.timestamp().to_string(),
327 std::time::Duration::from_secs(REAUTH_WINDOW_SECONDS as u64),
328 )
329 .await;
330 Ok(now)
331}
332
333fn is_reauth_required(last_reauth_at: Option<DateTime<Utc>>) -> bool {
334 match last_reauth_at {
335 None => true,
336 Some(t) => {
337 let elapsed = Utc::now().signed_duration_since(t);
338 elapsed.num_seconds() > REAUTH_WINDOW_SECONDS
339 }
340 }
341}
342
343async fn get_available_reauth_methods(db: &PgPool, did: &str) -> Vec<String> {
344 let mut methods = Vec::new();
345
346 let has_password = sqlx::query_scalar!(
347 "SELECT password_hash IS NOT NULL as has_pw FROM users WHERE did = $1",
348 did
349 )
350 .fetch_optional(db)
351 .await
352 .ok()
353 .flatten()
354 .unwrap_or(Some(false));
355
356 if has_password == Some(true) {
357 methods.push("password".to_string());
358 }
359
360 let has_totp = crate::api::server::totp::has_totp_enabled_db(db, did).await;
361 if has_totp {
362 methods.push("totp".to_string());
363 }
364
365 let has_passkeys = crate::api::server::passkeys::has_passkeys_for_user_db(db, did).await;
366 if has_passkeys {
367 methods.push("passkey".to_string());
368 }
369
370 methods
371}
372
373pub async fn check_reauth_required(db: &PgPool, did: &str) -> bool {
374 let session = sqlx::query!(
375 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1",
376 did
377 )
378 .fetch_optional(db)
379 .await;
380
381 match session {
382 Ok(Some(row)) => is_reauth_required(row.last_reauth_at),
383 _ => true,
384 }
385}
386
387pub async fn check_reauth_required_cached(
388 db: &PgPool,
389 cache: &std::sync::Arc<dyn crate::cache::Cache>,
390 did: &str,
391) -> bool {
392 let cache_key = format!("reauth:{}", did);
393 if let Some(timestamp_str) = cache.get(&cache_key).await
394 && let Ok(timestamp) = timestamp_str.parse::<i64>()
395 {
396 let reauth_time = chrono::DateTime::from_timestamp(timestamp, 0);
397 if let Some(t) = reauth_time {
398 let elapsed = Utc::now().signed_duration_since(t);
399 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS {
400 return false;
401 }
402 }
403 }
404 let session = sqlx::query!(
405 "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1",
406 did
407 )
408 .fetch_optional(db)
409 .await;
410
411 match session {
412 Ok(Some(row)) => is_reauth_required(row.last_reauth_at),
413 _ => true,
414 }
415}
416
417#[derive(Serialize)]
418#[serde(rename_all = "camelCase")]
419pub struct ReauthRequiredError {
420 pub error: String,
421 pub message: String,
422 pub reauth_methods: Vec<String>,
423}
424
425pub async fn reauth_required_response(db: &PgPool, did: &str) -> Response {
426 let methods = get_available_reauth_methods(db, did).await;
427 (
428 StatusCode::UNAUTHORIZED,
429 Json(ReauthRequiredError {
430 error: "ReauthRequired".to_string(),
431 message: "Re-authentication required for this action".to_string(),
432 reauth_methods: methods,
433 }),
434 )
435 .into_response()
436}
437
438pub async fn check_legacy_session_mfa(db: &PgPool, did: &str) -> bool {
439 let session = sqlx::query!(
440 "SELECT legacy_login, mfa_verified, last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1",
441 did
442 )
443 .fetch_optional(db)
444 .await;
445
446 match session {
447 Ok(Some(row)) => {
448 if !row.legacy_login {
449 return true;
450 }
451 if row.mfa_verified {
452 return true;
453 }
454 if let Some(last_reauth) = row.last_reauth_at {
455 let elapsed = chrono::Utc::now().signed_duration_since(last_reauth);
456 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS {
457 return true;
458 }
459 }
460 false
461 }
462 _ => true,
463 }
464}
465
466pub async fn update_mfa_verified(db: &PgPool, did: &str) -> Result<(), sqlx::Error> {
467 sqlx::query!(
468 "UPDATE session_tokens SET mfa_verified = TRUE, last_reauth_at = NOW() WHERE did = $1",
469 did
470 )
471 .execute(db)
472 .await?;
473 Ok(())
474}
475
476pub async fn legacy_mfa_required_response(db: &PgPool, did: &str) -> Response {
477 let methods = get_available_reauth_methods(db, did).await;
478 (
479 StatusCode::FORBIDDEN,
480 Json(MfaVerificationRequiredError {
481 error: "MfaVerificationRequired".to_string(),
482 message: "This sensitive operation requires MFA verification. Your session was created via a legacy app that doesn't support MFA during login.".to_string(),
483 reauth_methods: methods,
484 }),
485 )
486 .into_response()
487}
488
489#[derive(Serialize)]
490#[serde(rename_all = "camelCase")]
491pub struct MfaVerificationRequiredError {
492 pub error: String,
493 pub message: String,
494 pub reauth_methods: Vec<String>,
495}