this repo has no description
1use crate::auth::BearerAuth; 2use crate::state::{AppState, RateLimitKind}; 3use axum::{ 4 Json, 5 extract::State, 6 http::{HeaderMap, StatusCode}, 7 response::{IntoResponse, Response}, 8}; 9use bcrypt::{DEFAULT_COST, hash, verify}; 10use chrono::{Duration, Utc}; 11use serde::Deserialize; 12use serde_json::json; 13use tracing::{error, info, warn}; 14use uuid::Uuid; 15 16fn generate_reset_code() -> String { 17 crate::util::generate_token_code() 18} 19fn extract_client_ip(headers: &HeaderMap) -> String { 20 if let Some(forwarded) = headers.get("x-forwarded-for") 21 && let Ok(value) = forwarded.to_str() 22 && let Some(first_ip) = value.split(',').next() 23 { 24 return first_ip.trim().to_string(); 25 } 26 if let Some(real_ip) = headers.get("x-real-ip") 27 && let Ok(value) = real_ip.to_str() 28 { 29 return value.trim().to_string(); 30 } 31 "unknown".to_string() 32} 33 34#[derive(Deserialize)] 35pub struct RequestPasswordResetInput { 36 pub email: String, 37} 38 39pub async fn request_password_reset( 40 State(state): State<AppState>, 41 headers: HeaderMap, 42 Json(input): Json<RequestPasswordResetInput>, 43) -> Response { 44 let client_ip = extract_client_ip(&headers); 45 if !state 46 .check_rate_limit(RateLimitKind::PasswordReset, &client_ip) 47 .await 48 { 49 warn!(ip = %client_ip, "Password reset rate limit exceeded"); 50 return ( 51 StatusCode::TOO_MANY_REQUESTS, 52 Json(json!({ 53 "error": "RateLimitExceeded", 54 "message": "Too many password reset requests. Please try again later." 55 })), 56 ) 57 .into_response(); 58 } 59 let email = input.email.trim().to_lowercase(); 60 if email.is_empty() { 61 return ( 62 StatusCode::BAD_REQUEST, 63 Json(json!({"error": "InvalidRequest", "message": "email is required"})), 64 ) 65 .into_response(); 66 } 67 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email) 68 .fetch_optional(&state.db) 69 .await; 70 let user_id = match user { 71 Ok(Some(row)) => row.id, 72 Ok(None) => { 73 info!("Password reset requested for unknown email"); 74 return (StatusCode::OK, Json(json!({}))).into_response(); 75 } 76 Err(e) => { 77 error!("DB error in request_password_reset: {:?}", e); 78 return ( 79 StatusCode::INTERNAL_SERVER_ERROR, 80 Json(json!({"error": "InternalError"})), 81 ) 82 .into_response(); 83 } 84 }; 85 let code = generate_reset_code(); 86 let expires_at = Utc::now() + Duration::minutes(10); 87 let update = sqlx::query!( 88 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3", 89 code, 90 expires_at, 91 user_id 92 ) 93 .execute(&state.db) 94 .await; 95 if let Err(e) = update { 96 error!("DB error setting reset code: {:?}", e); 97 return ( 98 StatusCode::INTERNAL_SERVER_ERROR, 99 Json(json!({"error": "InternalError"})), 100 ) 101 .into_response(); 102 } 103 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 104 if let Err(e) = crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 105 { 106 warn!("Failed to enqueue password reset notification: {:?}", e); 107 } 108 info!("Password reset requested for user {}", user_id); 109 (StatusCode::OK, Json(json!({}))).into_response() 110} 111 112#[derive(Deserialize)] 113pub struct ResetPasswordInput { 114 pub token: String, 115 pub password: String, 116} 117 118pub async fn reset_password( 119 State(state): State<AppState>, 120 headers: HeaderMap, 121 Json(input): Json<ResetPasswordInput>, 122) -> Response { 123 let client_ip = extract_client_ip(&headers); 124 if !state 125 .check_rate_limit(RateLimitKind::ResetPassword, &client_ip) 126 .await 127 { 128 warn!(ip = %client_ip, "Reset password rate limit exceeded"); 129 return ( 130 StatusCode::TOO_MANY_REQUESTS, 131 Json(json!({ 132 "error": "RateLimitExceeded", 133 "message": "Too many requests. Please try again later." 134 })), 135 ) 136 .into_response(); 137 } 138 let token = input.token.trim(); 139 let password = &input.password; 140 if token.is_empty() { 141 return ( 142 StatusCode::BAD_REQUEST, 143 Json(json!({"error": "InvalidToken", "message": "token is required"})), 144 ) 145 .into_response(); 146 } 147 if password.is_empty() { 148 return ( 149 StatusCode::BAD_REQUEST, 150 Json(json!({"error": "InvalidRequest", "message": "password is required"})), 151 ) 152 .into_response(); 153 } 154 let user = sqlx::query!( 155 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1", 156 token 157 ) 158 .fetch_optional(&state.db) 159 .await; 160 let (user_id, expires_at) = match user { 161 Ok(Some(row)) => { 162 let expires = row.password_reset_code_expires_at; 163 (row.id, expires) 164 } 165 Ok(None) => { 166 return ( 167 StatusCode::BAD_REQUEST, 168 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 169 ) 170 .into_response(); 171 } 172 Err(e) => { 173 error!("DB error in reset_password: {:?}", e); 174 return ( 175 StatusCode::INTERNAL_SERVER_ERROR, 176 Json(json!({"error": "InternalError"})), 177 ) 178 .into_response(); 179 } 180 }; 181 if let Some(exp) = expires_at { 182 if Utc::now() > exp { 183 if let Err(e) = sqlx::query!( 184 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1", 185 user_id 186 ) 187 .execute(&state.db) 188 .await 189 { 190 error!("Failed to clear expired reset code: {:?}", e); 191 } 192 return ( 193 StatusCode::BAD_REQUEST, 194 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})), 195 ) 196 .into_response(); 197 } 198 } else { 199 return ( 200 StatusCode::BAD_REQUEST, 201 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 202 ) 203 .into_response(); 204 } 205 let password_hash = match hash(password, DEFAULT_COST) { 206 Ok(h) => h, 207 Err(e) => { 208 error!("Failed to hash password: {:?}", e); 209 return ( 210 StatusCode::INTERNAL_SERVER_ERROR, 211 Json(json!({"error": "InternalError"})), 212 ) 213 .into_response(); 214 } 215 }; 216 let mut tx = match state.db.begin().await { 217 Ok(tx) => tx, 218 Err(e) => { 219 error!("Failed to begin transaction: {:?}", e); 220 return ( 221 StatusCode::INTERNAL_SERVER_ERROR, 222 Json(json!({"error": "InternalError"})), 223 ) 224 .into_response(); 225 } 226 }; 227 if let Err(e) = sqlx::query!( 228 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2", 229 password_hash, 230 user_id 231 ) 232 .execute(&mut *tx) 233 .await 234 { 235 error!("DB error updating password: {:?}", e); 236 return ( 237 StatusCode::INTERNAL_SERVER_ERROR, 238 Json(json!({"error": "InternalError"})), 239 ) 240 .into_response(); 241 } 242 let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id) 243 .fetch_one(&mut *tx) 244 .await 245 { 246 Ok(did) => did, 247 Err(e) => { 248 error!("Failed to get DID for user {}: {:?}", user_id, e); 249 return ( 250 StatusCode::INTERNAL_SERVER_ERROR, 251 Json(json!({"error": "InternalError"})), 252 ) 253 .into_response(); 254 } 255 }; 256 let session_jtis: Vec<String> = match sqlx::query_scalar!( 257 "SELECT access_jti FROM session_tokens WHERE did = $1", 258 user_did 259 ) 260 .fetch_all(&mut *tx) 261 .await 262 { 263 Ok(jtis) => jtis, 264 Err(e) => { 265 error!("Failed to fetch session JTIs: {:?}", e); 266 vec![] 267 } 268 }; 269 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did) 270 .execute(&mut *tx) 271 .await 272 { 273 error!( 274 "Failed to invalidate sessions after password reset: {:?}", 275 e 276 ); 277 return ( 278 StatusCode::INTERNAL_SERVER_ERROR, 279 Json(json!({"error": "InternalError"})), 280 ) 281 .into_response(); 282 } 283 if let Err(e) = tx.commit().await { 284 error!("Failed to commit password reset transaction: {:?}", e); 285 return ( 286 StatusCode::INTERNAL_SERVER_ERROR, 287 Json(json!({"error": "InternalError"})), 288 ) 289 .into_response(); 290 } 291 for jti in session_jtis { 292 let cache_key = format!("auth:session:{}:{}", user_did, jti); 293 if let Err(e) = state.cache.delete(&cache_key).await { 294 warn!( 295 "Failed to invalidate session cache for {}: {:?}", 296 cache_key, e 297 ); 298 } 299 } 300 info!("Password reset completed for user {}", user_id); 301 (StatusCode::OK, Json(json!({}))).into_response() 302} 303 304#[derive(Deserialize)] 305#[serde(rename_all = "camelCase")] 306pub struct ChangePasswordInput { 307 pub current_password: String, 308 pub new_password: String, 309} 310 311pub async fn change_password( 312 State(state): State<AppState>, 313 auth: BearerAuth, 314 Json(input): Json<ChangePasswordInput>, 315) -> Response { 316 let current_password = &input.current_password; 317 let new_password = &input.new_password; 318 if current_password.is_empty() { 319 return ( 320 StatusCode::BAD_REQUEST, 321 Json(json!({"error": "InvalidRequest", "message": "currentPassword is required"})), 322 ) 323 .into_response(); 324 } 325 if new_password.is_empty() { 326 return ( 327 StatusCode::BAD_REQUEST, 328 Json(json!({"error": "InvalidRequest", "message": "newPassword is required"})), 329 ) 330 .into_response(); 331 } 332 if new_password.len() < 8 { 333 return ( 334 StatusCode::BAD_REQUEST, 335 Json(json!({"error": "InvalidRequest", "message": "Password must be at least 8 characters"})), 336 ) 337 .into_response(); 338 } 339 let user = 340 sqlx::query_as::<_, (Uuid, String)>("SELECT id, password_hash FROM users WHERE did = $1") 341 .bind(&auth.0.did) 342 .fetch_optional(&state.db) 343 .await; 344 let (user_id, password_hash) = match user { 345 Ok(Some(row)) => row, 346 Ok(None) => { 347 return ( 348 StatusCode::NOT_FOUND, 349 Json(json!({"error": "AccountNotFound", "message": "Account not found"})), 350 ) 351 .into_response(); 352 } 353 Err(e) => { 354 error!("DB error in change_password: {:?}", e); 355 return ( 356 StatusCode::INTERNAL_SERVER_ERROR, 357 Json(json!({"error": "InternalError"})), 358 ) 359 .into_response(); 360 } 361 }; 362 let valid = match verify(current_password, &password_hash) { 363 Ok(v) => v, 364 Err(e) => { 365 error!("Password verification error: {:?}", e); 366 return ( 367 StatusCode::INTERNAL_SERVER_ERROR, 368 Json(json!({"error": "InternalError"})), 369 ) 370 .into_response(); 371 } 372 }; 373 if !valid { 374 return ( 375 StatusCode::UNAUTHORIZED, 376 Json(json!({"error": "InvalidPassword", "message": "Current password is incorrect"})), 377 ) 378 .into_response(); 379 } 380 let new_hash = match hash(new_password, DEFAULT_COST) { 381 Ok(h) => h, 382 Err(e) => { 383 error!("Failed to hash password: {:?}", e); 384 return ( 385 StatusCode::INTERNAL_SERVER_ERROR, 386 Json(json!({"error": "InternalError"})), 387 ) 388 .into_response(); 389 } 390 }; 391 if let Err(e) = sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2") 392 .bind(&new_hash) 393 .bind(user_id) 394 .execute(&state.db) 395 .await 396 { 397 error!("DB error updating password: {:?}", e); 398 return ( 399 StatusCode::INTERNAL_SERVER_ERROR, 400 Json(json!({"error": "InternalError"})), 401 ) 402 .into_response(); 403 } 404 info!(did = %auth.0.did, "Password changed successfully"); 405 (StatusCode::OK, Json(json!({}))).into_response() 406}