this repo has no description
1use crate::auth::BearerAuth; 2use crate::state::{AppState, RateLimitKind}; 3use axum::{ 4 Json, 5 extract::State, 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8}; 9use bcrypt::{DEFAULT_COST, hash, verify}; 10use chrono::{Duration, Utc}; 11use serde::Deserialize; 12use serde_json::json; 13use tracing::{error, info, warn}; 14use uuid::Uuid; 15 16fn generate_reset_code() -> String { 17 crate::util::generate_token_code() 18} 19fn extract_client_ip(headers: &HeaderMap) -> String { 20 if let Some(forwarded) = headers.get("x-forwarded-for") 21 && let Ok(value) = forwarded.to_str() 22 && let Some(first_ip) = value.split(',').next() 23 { 24 return first_ip.trim().to_string(); 25 } 26 if let Some(real_ip) = headers.get("x-real-ip") 27 && let Ok(value) = real_ip.to_str() 28 { 29 return value.trim().to_string(); 30 } 31 "unknown".to_string() 32} 33 34#[derive(Deserialize)] 35pub struct RequestPasswordResetInput { 36 #[serde(alias = "identifier")] 37 pub email: String, 38} 39 40pub async fn request_password_reset( 41 State(state): State<AppState>, 42 headers: HeaderMap, 43 Json(input): Json<RequestPasswordResetInput>, 44) -> Response { 45 let client_ip = extract_client_ip(&headers); 46 if !state 47 .check_rate_limit(RateLimitKind::PasswordReset, &client_ip) 48 .await 49 { 50 warn!(ip = %client_ip, "Password reset rate limit exceeded"); 51 return ( 52 StatusCode::TOO_MANY_REQUESTS, 53 Json(json!({ 54 "error": "RateLimitExceeded", 55 "message": "Too many password reset requests. Please try again later." 56 })), 57 ) 58 .into_response(); 59 } 60 let identifier = input.email.trim(); 61 if identifier.is_empty() { 62 return ( 63 StatusCode::BAD_REQUEST, 64 Json(json!({"error": "InvalidRequest", "message": "email or handle is required"})), 65 ) 66 .into_response(); 67 } 68 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 69 let normalized = identifier.to_lowercase(); 70 let normalized = normalized.strip_prefix('@').unwrap_or(&normalized); 71 let normalized_handle = if normalized.contains('@') || normalized.contains('.') { 72 normalized.to_string() 73 } else { 74 format!("{}.{}", normalized, pds_hostname) 75 }; 76 let user = sqlx::query!( 77 "SELECT id FROM users WHERE LOWER(email) = $1 OR handle = $2", 78 normalized, 79 normalized_handle 80 ) 81 .fetch_optional(&state.db) 82 .await; 83 let user_id = match user { 84 Ok(Some(row)) => row.id, 85 Ok(None) => { 86 info!("Password reset requested for unknown identifier"); 87 return (StatusCode::OK, Json(json!({}))).into_response(); 88 } 89 Err(e) => { 90 error!("DB error in request_password_reset: {:?}", e); 91 return ( 92 StatusCode::INTERNAL_SERVER_ERROR, 93 Json(json!({"error": "InternalError"})), 94 ) 95 .into_response(); 96 } 97 }; 98 let code = generate_reset_code(); 99 let expires_at = Utc::now() + Duration::minutes(10); 100 let update = sqlx::query!( 101 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3", 102 code, 103 expires_at, 104 user_id 105 ) 106 .execute(&state.db) 107 .await; 108 if let Err(e) = update { 109 error!("DB error setting reset code: {:?}", e); 110 return ( 111 StatusCode::INTERNAL_SERVER_ERROR, 112 Json(json!({"error": "InternalError"})), 113 ) 114 .into_response(); 115 } 116 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 117 if let Err(e) = crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 118 { 119 warn!("Failed to enqueue password reset notification: {:?}", e); 120 } 121 info!("Password reset requested for user {}", user_id); 122 (StatusCode::OK, Json(json!({}))).into_response() 123} 124 125#[derive(Deserialize)] 126pub struct ResetPasswordInput { 127 pub token: String, 128 pub password: String, 129} 130 131pub async fn reset_password( 132 State(state): State<AppState>, 133 headers: HeaderMap, 134 Json(input): Json<ResetPasswordInput>, 135) -> Response { 136 let client_ip = extract_client_ip(&headers); 137 if !state 138 .check_rate_limit(RateLimitKind::ResetPassword, &client_ip) 139 .await 140 { 141 warn!(ip = %client_ip, "Reset password rate limit exceeded"); 142 return ( 143 StatusCode::TOO_MANY_REQUESTS, 144 Json(json!({ 145 "error": "RateLimitExceeded", 146 "message": "Too many requests. Please try again later." 147 })), 148 ) 149 .into_response(); 150 } 151 let token = input.token.trim(); 152 let password = &input.password; 153 if token.is_empty() { 154 return ( 155 StatusCode::BAD_REQUEST, 156 Json(json!({"error": "InvalidToken", "message": "token is required"})), 157 ) 158 .into_response(); 159 } 160 if password.is_empty() { 161 return ( 162 StatusCode::BAD_REQUEST, 163 Json(json!({"error": "InvalidRequest", "message": "password is required"})), 164 ) 165 .into_response(); 166 } 167 let user = sqlx::query!( 168 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1", 169 token 170 ) 171 .fetch_optional(&state.db) 172 .await; 173 let (user_id, expires_at) = match user { 174 Ok(Some(row)) => { 175 let expires = row.password_reset_code_expires_at; 176 (row.id, expires) 177 } 178 Ok(None) => { 179 return ( 180 StatusCode::BAD_REQUEST, 181 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 182 ) 183 .into_response(); 184 } 185 Err(e) => { 186 error!("DB error in reset_password: {:?}", e); 187 return ( 188 StatusCode::INTERNAL_SERVER_ERROR, 189 Json(json!({"error": "InternalError"})), 190 ) 191 .into_response(); 192 } 193 }; 194 if let Some(exp) = expires_at { 195 if Utc::now() > exp { 196 if let Err(e) = sqlx::query!( 197 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1", 198 user_id 199 ) 200 .execute(&state.db) 201 .await 202 { 203 error!("Failed to clear expired reset code: {:?}", e); 204 } 205 return ( 206 StatusCode::BAD_REQUEST, 207 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})), 208 ) 209 .into_response(); 210 } 211 } else { 212 return ( 213 StatusCode::BAD_REQUEST, 214 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 215 ) 216 .into_response(); 217 } 218 let password_hash = match hash(password, DEFAULT_COST) { 219 Ok(h) => h, 220 Err(e) => { 221 error!("Failed to hash password: {:?}", e); 222 return ( 223 StatusCode::INTERNAL_SERVER_ERROR, 224 Json(json!({"error": "InternalError"})), 225 ) 226 .into_response(); 227 } 228 }; 229 let mut tx = match state.db.begin().await { 230 Ok(tx) => tx, 231 Err(e) => { 232 error!("Failed to begin transaction: {:?}", e); 233 return ( 234 StatusCode::INTERNAL_SERVER_ERROR, 235 Json(json!({"error": "InternalError"})), 236 ) 237 .into_response(); 238 } 239 }; 240 if let Err(e) = sqlx::query!( 241 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL, password_required = TRUE WHERE id = $2", 242 password_hash, 243 user_id 244 ) 245 .execute(&mut *tx) 246 .await 247 { 248 error!("DB error updating password: {:?}", e); 249 return ( 250 StatusCode::INTERNAL_SERVER_ERROR, 251 Json(json!({"error": "InternalError"})), 252 ) 253 .into_response(); 254 } 255 let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id) 256 .fetch_one(&mut *tx) 257 .await 258 { 259 Ok(did) => did, 260 Err(e) => { 261 error!("Failed to get DID for user {}: {:?}", user_id, e); 262 return ( 263 StatusCode::INTERNAL_SERVER_ERROR, 264 Json(json!({"error": "InternalError"})), 265 ) 266 .into_response(); 267 } 268 }; 269 let session_jtis: Vec<String> = match sqlx::query_scalar!( 270 "SELECT access_jti FROM session_tokens WHERE did = $1", 271 user_did 272 ) 273 .fetch_all(&mut *tx) 274 .await 275 { 276 Ok(jtis) => jtis, 277 Err(e) => { 278 error!("Failed to fetch session JTIs: {:?}", e); 279 vec![] 280 } 281 }; 282 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did) 283 .execute(&mut *tx) 284 .await 285 { 286 error!( 287 "Failed to invalidate sessions after password reset: {:?}", 288 e 289 ); 290 return ( 291 StatusCode::INTERNAL_SERVER_ERROR, 292 Json(json!({"error": "InternalError"})), 293 ) 294 .into_response(); 295 } 296 if let Err(e) = tx.commit().await { 297 error!("Failed to commit password reset transaction: {:?}", e); 298 return ( 299 StatusCode::INTERNAL_SERVER_ERROR, 300 Json(json!({"error": "InternalError"})), 301 ) 302 .into_response(); 303 } 304 for jti in session_jtis { 305 let cache_key = format!("auth:session:{}:{}", user_did, jti); 306 if let Err(e) = state.cache.delete(&cache_key).await { 307 warn!( 308 "Failed to invalidate session cache for {}: {:?}", 309 cache_key, e 310 ); 311 } 312 } 313 info!("Password reset completed for user {}", user_id); 314 (StatusCode::OK, Json(json!({}))).into_response() 315} 316 317#[derive(Deserialize)] 318#[serde(rename_all = "camelCase")] 319pub struct ChangePasswordInput { 320 pub current_password: String, 321 pub new_password: String, 322} 323 324pub async fn change_password( 325 State(state): State<AppState>, 326 auth: BearerAuth, 327 Json(input): Json<ChangePasswordInput>, 328) -> Response { 329 let current_password = &input.current_password; 330 let new_password = &input.new_password; 331 if current_password.is_empty() { 332 return ( 333 StatusCode::BAD_REQUEST, 334 Json(json!({"error": "InvalidRequest", "message": "currentPassword is required"})), 335 ) 336 .into_response(); 337 } 338 if new_password.is_empty() { 339 return ( 340 StatusCode::BAD_REQUEST, 341 Json(json!({"error": "InvalidRequest", "message": "newPassword is required"})), 342 ) 343 .into_response(); 344 } 345 if new_password.len() < 8 { 346 return ( 347 StatusCode::BAD_REQUEST, 348 Json(json!({"error": "InvalidRequest", "message": "Password must be at least 8 characters"})), 349 ) 350 .into_response(); 351 } 352 let user = 353 sqlx::query_as::<_, (Uuid, String)>("SELECT id, password_hash FROM users WHERE did = $1") 354 .bind(&auth.0.did) 355 .fetch_optional(&state.db) 356 .await; 357 let (user_id, password_hash) = match user { 358 Ok(Some(row)) => row, 359 Ok(None) => { 360 return ( 361 StatusCode::NOT_FOUND, 362 Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 363 ) 364 .into_response(); 365 } 366 Err(e) => { 367 error!("DB error in change_password: {:?}", e); 368 return ( 369 StatusCode::INTERNAL_SERVER_ERROR, 370 Json(json!({"error": "InternalError"})), 371 ) 372 .into_response(); 373 } 374 }; 375 let valid = match verify(current_password, &password_hash) { 376 Ok(v) => v, 377 Err(e) => { 378 error!("Password verification error: {:?}", e); 379 return ( 380 StatusCode::INTERNAL_SERVER_ERROR, 381 Json(json!({"error": "InternalError"})), 382 ) 383 .into_response(); 384 } 385 }; 386 if !valid { 387 return ( 388 StatusCode::UNAUTHORIZED, 389 Json(json!({"error": "InvalidPassword", "message": "Current password is incorrect"})), 390 ) 391 .into_response(); 392 } 393 let new_hash = match hash(new_password, DEFAULT_COST) { 394 Ok(h) => h, 395 Err(e) => { 396 error!("Failed to hash password: {:?}", e); 397 return ( 398 StatusCode::INTERNAL_SERVER_ERROR, 399 Json(json!({"error": "InternalError"})), 400 ) 401 .into_response(); 402 } 403 }; 404 if let Err(e) = sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2") 405 .bind(&new_hash) 406 .bind(user_id) 407 .execute(&state.db) 408 .await 409 { 410 error!("DB error updating password: {:?}", e); 411 return ( 412 StatusCode::INTERNAL_SERVER_ERROR, 413 Json(json!({"error": "InternalError"})), 414 ) 415 .into_response(); 416 } 417 info!(did = %auth.0.did, "Password changed successfully"); 418 (StatusCode::OK, Json(json!({}))).into_response() 419} 420 421pub async fn get_password_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 422 let user = sqlx::query!( 423 "SELECT password_hash IS NOT NULL as has_password FROM users WHERE did = $1", 424 auth.0.did 425 ) 426 .fetch_optional(&state.db) 427 .await; 428 429 match user { 430 Ok(Some(row)) => { 431 Json(json!({"hasPassword": row.has_password.unwrap_or(false)})).into_response() 432 } 433 Ok(None) => ( 434 StatusCode::NOT_FOUND, 435 Json(json!({"error": "AccountNotFound"})), 436 ) 437 .into_response(), 438 Err(e) => { 439 error!("DB error: {:?}", e); 440 ( 441 StatusCode::INTERNAL_SERVER_ERROR, 442 Json(json!({"error": "InternalError"})), 443 ) 444 .into_response() 445 } 446 } 447} 448 449pub async fn remove_password(State(state): State<AppState>, auth: BearerAuth) -> Response { 450 if crate::api::server::reauth::check_reauth_required(&state.db, &auth.0.did).await { 451 return crate::api::server::reauth::reauth_required_response(&state.db, &auth.0.did).await; 452 } 453 454 let has_passkeys = 455 crate::api::server::passkeys::has_passkeys_for_user_db(&state.db, &auth.0.did).await; 456 if !has_passkeys { 457 return ( 458 StatusCode::BAD_REQUEST, 459 Json(json!({ 460 "error": "NoPasskeys", 461 "message": "You must have at least one passkey registered before removing your password" 462 })), 463 ) 464 .into_response(); 465 } 466 467 let user = sqlx::query!( 468 "SELECT id, password_hash FROM users WHERE did = $1", 469 auth.0.did 470 ) 471 .fetch_optional(&state.db) 472 .await; 473 474 let user = match user { 475 Ok(Some(u)) => u, 476 Ok(None) => { 477 return ( 478 StatusCode::NOT_FOUND, 479 Json(json!({"error": "AccountNotFound"})), 480 ) 481 .into_response(); 482 } 483 Err(e) => { 484 error!("DB error: {:?}", e); 485 return ( 486 StatusCode::INTERNAL_SERVER_ERROR, 487 Json(json!({"error": "InternalError"})), 488 ) 489 .into_response(); 490 } 491 }; 492 493 if user.password_hash.is_none() { 494 return ( 495 StatusCode::BAD_REQUEST, 496 Json(json!({ 497 "error": "NoPassword", 498 "message": "Account already has no password" 499 })), 500 ) 501 .into_response(); 502 } 503 504 if let Err(e) = sqlx::query!( 505 "UPDATE users SET password_hash = NULL, password_required = FALSE WHERE id = $1", 506 user.id 507 ) 508 .execute(&state.db) 509 .await 510 { 511 error!("DB error removing password: {:?}", e); 512 return ( 513 StatusCode::INTERNAL_SERVER_ERROR, 514 Json(json!({"error": "InternalError"})), 515 ) 516 .into_response(); 517 } 518 519 info!(did = %auth.0.did, "Password removed - account is now passkey-only"); 520 (StatusCode::OK, Json(json!({"success": true}))).into_response() 521}