this repo has no description
1use crate::state::{AppState, RateLimitKind}; 2use axum::{ 3 Json, 4 extract::State, 5 http::{HeaderMap, StatusCode}, 6 response::{IntoResponse, Response}, 7}; 8use bcrypt::{hash, DEFAULT_COST}; 9use chrono::{Duration, Utc}; 10use serde::Deserialize; 11use serde_json::json; 12use tracing::{error, info, warn}; 13fn generate_reset_code() -> String { 14 crate::util::generate_token_code() 15} 16fn extract_client_ip(headers: &HeaderMap) -> String { 17 if let Some(forwarded) = headers.get("x-forwarded-for") { 18 if let Ok(value) = forwarded.to_str() { 19 if let Some(first_ip) = value.split(',').next() { 20 return first_ip.trim().to_string(); 21 } 22 } 23 } 24 if let Some(real_ip) = headers.get("x-real-ip") { 25 if let Ok(value) = real_ip.to_str() { 26 return value.trim().to_string(); 27 } 28 } 29 "unknown".to_string() 30} 31#[derive(Deserialize)] 32pub struct RequestPasswordResetInput { 33 pub email: String, 34} 35pub async fn request_password_reset( 36 State(state): State<AppState>, 37 headers: HeaderMap, 38 Json(input): Json<RequestPasswordResetInput>, 39) -> Response { 40 let client_ip = extract_client_ip(&headers); 41 if !state.check_rate_limit(RateLimitKind::PasswordReset, &client_ip).await { 42 warn!(ip = %client_ip, "Password reset rate limit exceeded"); 43 return ( 44 StatusCode::TOO_MANY_REQUESTS, 45 Json(json!({ 46 "error": "RateLimitExceeded", 47 "message": "Too many password reset requests. Please try again later." 48 })), 49 ) 50 .into_response(); 51 } 52 let email = input.email.trim().to_lowercase(); 53 if email.is_empty() { 54 return ( 55 StatusCode::BAD_REQUEST, 56 Json(json!({"error": "InvalidRequest", "message": "email is required"})), 57 ) 58 .into_response(); 59 } 60 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email) 61 .fetch_optional(&state.db) 62 .await; 63 let user_id = match user { 64 Ok(Some(row)) => row.id, 65 Ok(None) => { 66 info!("Password reset requested for unknown email"); 67 return (StatusCode::OK, Json(json!({}))).into_response(); 68 } 69 Err(e) => { 70 error!("DB error in request_password_reset: {:?}", e); 71 return ( 72 StatusCode::INTERNAL_SERVER_ERROR, 73 Json(json!({"error": "InternalError"})), 74 ) 75 .into_response(); 76 } 77 }; 78 let code = generate_reset_code(); 79 let expires_at = Utc::now() + Duration::minutes(10); 80 let update = sqlx::query!( 81 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3", 82 code, 83 expires_at, 84 user_id 85 ) 86 .execute(&state.db) 87 .await; 88 if let Err(e) = update { 89 error!("DB error setting reset code: {:?}", e); 90 return ( 91 StatusCode::INTERNAL_SERVER_ERROR, 92 Json(json!({"error": "InternalError"})), 93 ) 94 .into_response(); 95 } 96 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 97 if let Err(e) = 98 crate::notifications::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 99 { 100 warn!("Failed to enqueue password reset notification: {:?}", e); 101 } 102 info!("Password reset requested for user {}", user_id); 103 (StatusCode::OK, Json(json!({}))).into_response() 104} 105#[derive(Deserialize)] 106pub struct ResetPasswordInput { 107 pub token: String, 108 pub password: String, 109} 110pub async fn reset_password( 111 State(state): State<AppState>, 112 headers: HeaderMap, 113 Json(input): Json<ResetPasswordInput>, 114) -> Response { 115 let client_ip = extract_client_ip(&headers); 116 if !state.check_rate_limit(RateLimitKind::ResetPassword, &client_ip).await { 117 warn!(ip = %client_ip, "Reset password rate limit exceeded"); 118 return ( 119 StatusCode::TOO_MANY_REQUESTS, 120 Json(json!({ 121 "error": "RateLimitExceeded", 122 "message": "Too many requests. Please try again later." 123 })), 124 ).into_response(); 125 } 126 let token = input.token.trim(); 127 let password = &input.password; 128 if token.is_empty() { 129 return ( 130 StatusCode::BAD_REQUEST, 131 Json(json!({"error": "InvalidToken", "message": "token is required"})), 132 ) 133 .into_response(); 134 } 135 if password.is_empty() { 136 return ( 137 StatusCode::BAD_REQUEST, 138 Json(json!({"error": "InvalidRequest", "message": "password is required"})), 139 ) 140 .into_response(); 141 } 142 let user = sqlx::query!( 143 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1", 144 token 145 ) 146 .fetch_optional(&state.db) 147 .await; 148 let (user_id, expires_at) = match user { 149 Ok(Some(row)) => { 150 let expires = row.password_reset_code_expires_at; 151 (row.id, expires) 152 } 153 Ok(None) => { 154 return ( 155 StatusCode::BAD_REQUEST, 156 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 157 ) 158 .into_response(); 159 } 160 Err(e) => { 161 error!("DB error in reset_password: {:?}", e); 162 return ( 163 StatusCode::INTERNAL_SERVER_ERROR, 164 Json(json!({"error": "InternalError"})), 165 ) 166 .into_response(); 167 } 168 }; 169 if let Some(exp) = expires_at { 170 if Utc::now() > exp { 171 if let Err(e) = sqlx::query!( 172 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1", 173 user_id 174 ) 175 .execute(&state.db) 176 .await 177 { 178 error!("Failed to clear expired reset code: {:?}", e); 179 } 180 return ( 181 StatusCode::BAD_REQUEST, 182 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})), 183 ) 184 .into_response(); 185 } 186 } else { 187 return ( 188 StatusCode::BAD_REQUEST, 189 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 190 ) 191 .into_response(); 192 } 193 let password_hash = match hash(password, DEFAULT_COST) { 194 Ok(h) => h, 195 Err(e) => { 196 error!("Failed to hash password: {:?}", e); 197 return ( 198 StatusCode::INTERNAL_SERVER_ERROR, 199 Json(json!({"error": "InternalError"})), 200 ) 201 .into_response(); 202 } 203 }; 204 let mut tx = match state.db.begin().await { 205 Ok(tx) => tx, 206 Err(e) => { 207 error!("Failed to begin transaction: {:?}", e); 208 return ( 209 StatusCode::INTERNAL_SERVER_ERROR, 210 Json(json!({"error": "InternalError"})), 211 ) 212 .into_response(); 213 } 214 }; 215 if let Err(e) = sqlx::query!( 216 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2", 217 password_hash, 218 user_id 219 ) 220 .execute(&mut *tx) 221 .await 222 { 223 error!("DB error updating password: {:?}", e); 224 return ( 225 StatusCode::INTERNAL_SERVER_ERROR, 226 Json(json!({"error": "InternalError"})), 227 ) 228 .into_response(); 229 } 230 let user_did = match sqlx::query_scalar!( 231 "SELECT did FROM users WHERE id = $1", 232 user_id 233 ) 234 .fetch_one(&mut *tx) 235 .await 236 { 237 Ok(did) => did, 238 Err(e) => { 239 error!("Failed to get DID for user {}: {:?}", user_id, e); 240 return ( 241 StatusCode::INTERNAL_SERVER_ERROR, 242 Json(json!({"error": "InternalError"})), 243 ) 244 .into_response(); 245 } 246 }; 247 let session_jtis: Vec<String> = match sqlx::query_scalar!( 248 "SELECT access_jti FROM session_tokens WHERE did = $1", 249 user_did 250 ) 251 .fetch_all(&mut *tx) 252 .await 253 { 254 Ok(jtis) => jtis, 255 Err(e) => { 256 error!("Failed to fetch session JTIs: {:?}", e); 257 vec![] 258 } 259 }; 260 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did) 261 .execute(&mut *tx) 262 .await 263 { 264 error!("Failed to invalidate sessions after password reset: {:?}", e); 265 return ( 266 StatusCode::INTERNAL_SERVER_ERROR, 267 Json(json!({"error": "InternalError"})), 268 ) 269 .into_response(); 270 } 271 if let Err(e) = tx.commit().await { 272 error!("Failed to commit password reset transaction: {:?}", e); 273 return ( 274 StatusCode::INTERNAL_SERVER_ERROR, 275 Json(json!({"error": "InternalError"})), 276 ) 277 .into_response(); 278 } 279 for jti in session_jtis { 280 let cache_key = format!("auth:session:{}:{}", user_did, jti); 281 if let Err(e) = state.cache.delete(&cache_key).await { 282 warn!("Failed to invalidate session cache for {}: {:?}", cache_key, e); 283 } 284 } 285 info!("Password reset completed for user {}", user_id); 286 (StatusCode::OK, Json(json!({}))).into_response() 287}