This repository has no description.
1use crate::state::{AppState, RateLimitKind}; 2use axum::{ 3 Json, 4 extract::State, 5 http::{HeaderMap, StatusCode}, 6 response::{IntoResponse, Response}, 7}; 8use bcrypt::{DEFAULT_COST, hash}; 9use chrono::{Duration, Utc}; 10use serde::Deserialize; 11use serde_json::json; 12use tracing::{error, info, warn}; 13 14fn generate_reset_code() -> String { 15 crate::util::generate_token_code() 16} 17fn extract_client_ip(headers: &HeaderMap) -> String { 18 if let Some(forwarded) = headers.get("x-forwarded-for") 19 && let Ok(value) = forwarded.to_str() 20 && let Some(first_ip) = value.split(',').next() { 21 return first_ip.trim().to_string(); 22 } 23 if let Some(real_ip) = headers.get("x-real-ip") 24 && let Ok(value) = real_ip.to_str() { 25 return value.trim().to_string(); 26 } 27 "unknown".to_string() 28} 29 30#[derive(Deserialize)] 31pub struct RequestPasswordResetInput { 32 pub email: String, 33} 34 35pub async fn request_password_reset( 36 State(state): State<AppState>, 37 headers: HeaderMap, 38 Json(input): Json<RequestPasswordResetInput>, 39) -> Response { 40 let client_ip = extract_client_ip(&headers); 41 if !state 42 .check_rate_limit(RateLimitKind::PasswordReset, &client_ip) 43 .await 44 { 45 warn!(ip = %client_ip, "Password reset rate limit exceeded"); 46 return ( 47 StatusCode::TOO_MANY_REQUESTS, 48 Json(json!({ 49 "error": "RateLimitExceeded", 50 "message": "Too many password reset requests. Please try again later." 
51 })), 52 ) 53 .into_response(); 54 } 55 let email = input.email.trim().to_lowercase(); 56 if email.is_empty() { 57 return ( 58 StatusCode::BAD_REQUEST, 59 Json(json!({"error": "InvalidRequest", "message": "email is required"})), 60 ) 61 .into_response(); 62 } 63 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email) 64 .fetch_optional(&state.db) 65 .await; 66 let user_id = match user { 67 Ok(Some(row)) => row.id, 68 Ok(None) => { 69 info!("Password reset requested for unknown email"); 70 return (StatusCode::OK, Json(json!({}))).into_response(); 71 } 72 Err(e) => { 73 error!("DB error in request_password_reset: {:?}", e); 74 return ( 75 StatusCode::INTERNAL_SERVER_ERROR, 76 Json(json!({"error": "InternalError"})), 77 ) 78 .into_response(); 79 } 80 }; 81 let code = generate_reset_code(); 82 let expires_at = Utc::now() + Duration::minutes(10); 83 let update = sqlx::query!( 84 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3", 85 code, 86 expires_at, 87 user_id 88 ) 89 .execute(&state.db) 90 .await; 91 if let Err(e) = update { 92 error!("DB error setting reset code: {:?}", e); 93 return ( 94 StatusCode::INTERNAL_SERVER_ERROR, 95 Json(json!({"error": "InternalError"})), 96 ) 97 .into_response(); 98 } 99 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 100 if let Err(e) = 101 crate::notifications::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 102 { 103 warn!("Failed to enqueue password reset notification: {:?}", e); 104 } 105 info!("Password reset requested for user {}", user_id); 106 (StatusCode::OK, Json(json!({}))).into_response() 107} 108 109#[derive(Deserialize)] 110pub struct ResetPasswordInput { 111 pub token: String, 112 pub password: String, 113} 114 115pub async fn reset_password( 116 State(state): State<AppState>, 117 headers: HeaderMap, 118 Json(input): Json<ResetPasswordInput>, 119) -> Response { 120 let client_ip = 
extract_client_ip(&headers); 121 if !state 122 .check_rate_limit(RateLimitKind::ResetPassword, &client_ip) 123 .await 124 { 125 warn!(ip = %client_ip, "Reset password rate limit exceeded"); 126 return ( 127 StatusCode::TOO_MANY_REQUESTS, 128 Json(json!({ 129 "error": "RateLimitExceeded", 130 "message": "Too many requests. Please try again later." 131 })), 132 ) 133 .into_response(); 134 } 135 let token = input.token.trim(); 136 let password = &input.password; 137 if token.is_empty() { 138 return ( 139 StatusCode::BAD_REQUEST, 140 Json(json!({"error": "InvalidToken", "message": "token is required"})), 141 ) 142 .into_response(); 143 } 144 if password.is_empty() { 145 return ( 146 StatusCode::BAD_REQUEST, 147 Json(json!({"error": "InvalidRequest", "message": "password is required"})), 148 ) 149 .into_response(); 150 } 151 let user = sqlx::query!( 152 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1", 153 token 154 ) 155 .fetch_optional(&state.db) 156 .await; 157 let (user_id, expires_at) = match user { 158 Ok(Some(row)) => { 159 let expires = row.password_reset_code_expires_at; 160 (row.id, expires) 161 } 162 Ok(None) => { 163 return ( 164 StatusCode::BAD_REQUEST, 165 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 166 ) 167 .into_response(); 168 } 169 Err(e) => { 170 error!("DB error in reset_password: {:?}", e); 171 return ( 172 StatusCode::INTERNAL_SERVER_ERROR, 173 Json(json!({"error": "InternalError"})), 174 ) 175 .into_response(); 176 } 177 }; 178 if let Some(exp) = expires_at { 179 if Utc::now() > exp { 180 if let Err(e) = sqlx::query!( 181 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1", 182 user_id 183 ) 184 .execute(&state.db) 185 .await 186 { 187 error!("Failed to clear expired reset code: {:?}", e); 188 } 189 return ( 190 StatusCode::BAD_REQUEST, 191 Json(json!({"error": "ExpiredToken", "message": "Token has 
expired"})), 192 ) 193 .into_response(); 194 } 195 } else { 196 return ( 197 StatusCode::BAD_REQUEST, 198 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 199 ) 200 .into_response(); 201 } 202 let password_hash = match hash(password, DEFAULT_COST) { 203 Ok(h) => h, 204 Err(e) => { 205 error!("Failed to hash password: {:?}", e); 206 return ( 207 StatusCode::INTERNAL_SERVER_ERROR, 208 Json(json!({"error": "InternalError"})), 209 ) 210 .into_response(); 211 } 212 }; 213 let mut tx = match state.db.begin().await { 214 Ok(tx) => tx, 215 Err(e) => { 216 error!("Failed to begin transaction: {:?}", e); 217 return ( 218 StatusCode::INTERNAL_SERVER_ERROR, 219 Json(json!({"error": "InternalError"})), 220 ) 221 .into_response(); 222 } 223 }; 224 if let Err(e) = sqlx::query!( 225 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2", 226 password_hash, 227 user_id 228 ) 229 .execute(&mut *tx) 230 .await 231 { 232 error!("DB error updating password: {:?}", e); 233 return ( 234 StatusCode::INTERNAL_SERVER_ERROR, 235 Json(json!({"error": "InternalError"})), 236 ) 237 .into_response(); 238 } 239 let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id) 240 .fetch_one(&mut *tx) 241 .await 242 { 243 Ok(did) => did, 244 Err(e) => { 245 error!("Failed to get DID for user {}: {:?}", user_id, e); 246 return ( 247 StatusCode::INTERNAL_SERVER_ERROR, 248 Json(json!({"error": "InternalError"})), 249 ) 250 .into_response(); 251 } 252 }; 253 let session_jtis: Vec<String> = match sqlx::query_scalar!( 254 "SELECT access_jti FROM session_tokens WHERE did = $1", 255 user_did 256 ) 257 .fetch_all(&mut *tx) 258 .await 259 { 260 Ok(jtis) => jtis, 261 Err(e) => { 262 error!("Failed to fetch session JTIs: {:?}", e); 263 vec![] 264 } 265 }; 266 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did) 267 .execute(&mut *tx) 268 .await 269 { 270 error!( 
271 "Failed to invalidate sessions after password reset: {:?}", 272 e 273 ); 274 return ( 275 StatusCode::INTERNAL_SERVER_ERROR, 276 Json(json!({"error": "InternalError"})), 277 ) 278 .into_response(); 279 } 280 if let Err(e) = tx.commit().await { 281 error!("Failed to commit password reset transaction: {:?}", e); 282 return ( 283 StatusCode::INTERNAL_SERVER_ERROR, 284 Json(json!({"error": "InternalError"})), 285 ) 286 .into_response(); 287 } 288 for jti in session_jtis { 289 let cache_key = format!("auth:session:{}:{}", user_did, jti); 290 if let Err(e) = state.cache.delete(&cache_key).await { 291 warn!( 292 "Failed to invalidate session cache for {}: {:?}", 293 cache_key, e 294 ); 295 } 296 } 297 info!("Password reset completed for user {}", user_id); 298 (StatusCode::OK, Json(json!({}))).into_response() 299}