// This repository has no description.
1use crate::state::{AppState, RateLimitKind}; 2use axum::{ 3 Json, 4 extract::State, 5 http::{HeaderMap, StatusCode}, 6 response::{IntoResponse, Response}, 7}; 8use bcrypt::{hash, DEFAULT_COST}; 9use chrono::{Duration, Utc}; 10use serde::Deserialize; 11use serde_json::json; 12use tracing::{error, info, warn}; 13 14fn generate_reset_code() -> String { 15 crate::util::generate_token_code() 16} 17fn extract_client_ip(headers: &HeaderMap) -> String { 18 if let Some(forwarded) = headers.get("x-forwarded-for") { 19 if let Ok(value) = forwarded.to_str() { 20 if let Some(first_ip) = value.split(',').next() { 21 return first_ip.trim().to_string(); 22 } 23 } 24 } 25 if let Some(real_ip) = headers.get("x-real-ip") { 26 if let Ok(value) = real_ip.to_str() { 27 return value.trim().to_string(); 28 } 29 } 30 "unknown".to_string() 31} 32 33#[derive(Deserialize)] 34pub struct RequestPasswordResetInput { 35 pub email: String, 36} 37 38pub async fn request_password_reset( 39 State(state): State<AppState>, 40 headers: HeaderMap, 41 Json(input): Json<RequestPasswordResetInput>, 42) -> Response { 43 let client_ip = extract_client_ip(&headers); 44 if !state.check_rate_limit(RateLimitKind::PasswordReset, &client_ip).await { 45 warn!(ip = %client_ip, "Password reset rate limit exceeded"); 46 return ( 47 StatusCode::TOO_MANY_REQUESTS, 48 Json(json!({ 49 "error": "RateLimitExceeded", 50 "message": "Too many password reset requests. Please try again later." 
51 })), 52 ) 53 .into_response(); 54 } 55 let email = input.email.trim().to_lowercase(); 56 if email.is_empty() { 57 return ( 58 StatusCode::BAD_REQUEST, 59 Json(json!({"error": "InvalidRequest", "message": "email is required"})), 60 ) 61 .into_response(); 62 } 63 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email) 64 .fetch_optional(&state.db) 65 .await; 66 let user_id = match user { 67 Ok(Some(row)) => row.id, 68 Ok(None) => { 69 info!("Password reset requested for unknown email"); 70 return (StatusCode::OK, Json(json!({}))).into_response(); 71 } 72 Err(e) => { 73 error!("DB error in request_password_reset: {:?}", e); 74 return ( 75 StatusCode::INTERNAL_SERVER_ERROR, 76 Json(json!({"error": "InternalError"})), 77 ) 78 .into_response(); 79 } 80 }; 81 let code = generate_reset_code(); 82 let expires_at = Utc::now() + Duration::minutes(10); 83 let update = sqlx::query!( 84 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3", 85 code, 86 expires_at, 87 user_id 88 ) 89 .execute(&state.db) 90 .await; 91 if let Err(e) = update { 92 error!("DB error setting reset code: {:?}", e); 93 return ( 94 StatusCode::INTERNAL_SERVER_ERROR, 95 Json(json!({"error": "InternalError"})), 96 ) 97 .into_response(); 98 } 99 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 100 if let Err(e) = 101 crate::notifications::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 102 { 103 warn!("Failed to enqueue password reset notification: {:?}", e); 104 } 105 info!("Password reset requested for user {}", user_id); 106 (StatusCode::OK, Json(json!({}))).into_response() 107} 108 109#[derive(Deserialize)] 110pub struct ResetPasswordInput { 111 pub token: String, 112 pub password: String, 113} 114 115pub async fn reset_password( 116 State(state): State<AppState>, 117 headers: HeaderMap, 118 Json(input): Json<ResetPasswordInput>, 119) -> Response { 120 let client_ip = 
extract_client_ip(&headers); 121 if !state.check_rate_limit(RateLimitKind::ResetPassword, &client_ip).await { 122 warn!(ip = %client_ip, "Reset password rate limit exceeded"); 123 return ( 124 StatusCode::TOO_MANY_REQUESTS, 125 Json(json!({ 126 "error": "RateLimitExceeded", 127 "message": "Too many requests. Please try again later." 128 })), 129 ).into_response(); 130 } 131 let token = input.token.trim(); 132 let password = &input.password; 133 if token.is_empty() { 134 return ( 135 StatusCode::BAD_REQUEST, 136 Json(json!({"error": "InvalidToken", "message": "token is required"})), 137 ) 138 .into_response(); 139 } 140 if password.is_empty() { 141 return ( 142 StatusCode::BAD_REQUEST, 143 Json(json!({"error": "InvalidRequest", "message": "password is required"})), 144 ) 145 .into_response(); 146 } 147 let user = sqlx::query!( 148 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1", 149 token 150 ) 151 .fetch_optional(&state.db) 152 .await; 153 let (user_id, expires_at) = match user { 154 Ok(Some(row)) => { 155 let expires = row.password_reset_code_expires_at; 156 (row.id, expires) 157 } 158 Ok(None) => { 159 return ( 160 StatusCode::BAD_REQUEST, 161 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 162 ) 163 .into_response(); 164 } 165 Err(e) => { 166 error!("DB error in reset_password: {:?}", e); 167 return ( 168 StatusCode::INTERNAL_SERVER_ERROR, 169 Json(json!({"error": "InternalError"})), 170 ) 171 .into_response(); 172 } 173 }; 174 if let Some(exp) = expires_at { 175 if Utc::now() > exp { 176 if let Err(e) = sqlx::query!( 177 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1", 178 user_id 179 ) 180 .execute(&state.db) 181 .await 182 { 183 error!("Failed to clear expired reset code: {:?}", e); 184 } 185 return ( 186 StatusCode::BAD_REQUEST, 187 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})), 188 ) 189 
.into_response(); 190 } 191 } else { 192 return ( 193 StatusCode::BAD_REQUEST, 194 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})), 195 ) 196 .into_response(); 197 } 198 let password_hash = match hash(password, DEFAULT_COST) { 199 Ok(h) => h, 200 Err(e) => { 201 error!("Failed to hash password: {:?}", e); 202 return ( 203 StatusCode::INTERNAL_SERVER_ERROR, 204 Json(json!({"error": "InternalError"})), 205 ) 206 .into_response(); 207 } 208 }; 209 let mut tx = match state.db.begin().await { 210 Ok(tx) => tx, 211 Err(e) => { 212 error!("Failed to begin transaction: {:?}", e); 213 return ( 214 StatusCode::INTERNAL_SERVER_ERROR, 215 Json(json!({"error": "InternalError"})), 216 ) 217 .into_response(); 218 } 219 }; 220 if let Err(e) = sqlx::query!( 221 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2", 222 password_hash, 223 user_id 224 ) 225 .execute(&mut *tx) 226 .await 227 { 228 error!("DB error updating password: {:?}", e); 229 return ( 230 StatusCode::INTERNAL_SERVER_ERROR, 231 Json(json!({"error": "InternalError"})), 232 ) 233 .into_response(); 234 } 235 let user_did = match sqlx::query_scalar!( 236 "SELECT did FROM users WHERE id = $1", 237 user_id 238 ) 239 .fetch_one(&mut *tx) 240 .await 241 { 242 Ok(did) => did, 243 Err(e) => { 244 error!("Failed to get DID for user {}: {:?}", user_id, e); 245 return ( 246 StatusCode::INTERNAL_SERVER_ERROR, 247 Json(json!({"error": "InternalError"})), 248 ) 249 .into_response(); 250 } 251 }; 252 let session_jtis: Vec<String> = match sqlx::query_scalar!( 253 "SELECT access_jti FROM session_tokens WHERE did = $1", 254 user_did 255 ) 256 .fetch_all(&mut *tx) 257 .await 258 { 259 Ok(jtis) => jtis, 260 Err(e) => { 261 error!("Failed to fetch session JTIs: {:?}", e); 262 vec![] 263 } 264 }; 265 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did) 266 .execute(&mut *tx) 267 .await 268 { 269 error!("Failed to 
invalidate sessions after password reset: {:?}", e); 270 return ( 271 StatusCode::INTERNAL_SERVER_ERROR, 272 Json(json!({"error": "InternalError"})), 273 ) 274 .into_response(); 275 } 276 if let Err(e) = tx.commit().await { 277 error!("Failed to commit password reset transaction: {:?}", e); 278 return ( 279 StatusCode::INTERNAL_SERVER_ERROR, 280 Json(json!({"error": "InternalError"})), 281 ) 282 .into_response(); 283 } 284 for jti in session_jtis { 285 let cache_key = format!("auth:session:{}:{}", user_did, jti); 286 if let Err(e) = state.cache.delete(&cache_key).await { 287 warn!("Failed to invalidate session cache for {}: {:?}", cache_key, e); 288 } 289 } 290 info!("Password reset completed for user {}", user_id); 291 (StatusCode::OK, Json(json!({}))).into_response() 292}