this repo has no description
1use crate::state::{AppState, RateLimitKind};
2use axum::{
3 Json,
4 extract::State,
5 http::{HeaderMap, StatusCode},
6 response::{IntoResponse, Response},
7};
8use bcrypt::{hash, DEFAULT_COST};
9use chrono::{Duration, Utc};
10use serde::Deserialize;
11use serde_json::json;
12use tracing::{error, info, warn};
13fn generate_reset_code() -> String {
14 crate::util::generate_token_code()
15}
16fn extract_client_ip(headers: &HeaderMap) -> String {
17 if let Some(forwarded) = headers.get("x-forwarded-for") {
18 if let Ok(value) = forwarded.to_str() {
19 if let Some(first_ip) = value.split(',').next() {
20 return first_ip.trim().to_string();
21 }
22 }
23 }
24 if let Some(real_ip) = headers.get("x-real-ip") {
25 if let Ok(value) = real_ip.to_str() {
26 return value.trim().to_string();
27 }
28 }
29 "unknown".to_string()
30}
31#[derive(Deserialize)]
32pub struct RequestPasswordResetInput {
33 pub email: String,
34}
35pub async fn request_password_reset(
36 State(state): State<AppState>,
37 headers: HeaderMap,
38 Json(input): Json<RequestPasswordResetInput>,
39) -> Response {
40 let client_ip = extract_client_ip(&headers);
41 if !state.check_rate_limit(RateLimitKind::PasswordReset, &client_ip).await {
42 warn!(ip = %client_ip, "Password reset rate limit exceeded");
43 return (
44 StatusCode::TOO_MANY_REQUESTS,
45 Json(json!({
46 "error": "RateLimitExceeded",
47 "message": "Too many password reset requests. Please try again later."
48 })),
49 )
50 .into_response();
51 }
52 let email = input.email.trim().to_lowercase();
53 if email.is_empty() {
54 return (
55 StatusCode::BAD_REQUEST,
56 Json(json!({"error": "InvalidRequest", "message": "email is required"})),
57 )
58 .into_response();
59 }
60 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email)
61 .fetch_optional(&state.db)
62 .await;
63 let user_id = match user {
64 Ok(Some(row)) => row.id,
65 Ok(None) => {
66 info!("Password reset requested for unknown email");
67 return (StatusCode::OK, Json(json!({}))).into_response();
68 }
69 Err(e) => {
70 error!("DB error in request_password_reset: {:?}", e);
71 return (
72 StatusCode::INTERNAL_SERVER_ERROR,
73 Json(json!({"error": "InternalError"})),
74 )
75 .into_response();
76 }
77 };
78 let code = generate_reset_code();
79 let expires_at = Utc::now() + Duration::minutes(10);
80 let update = sqlx::query!(
81 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3",
82 code,
83 expires_at,
84 user_id
85 )
86 .execute(&state.db)
87 .await;
88 if let Err(e) = update {
89 error!("DB error setting reset code: {:?}", e);
90 return (
91 StatusCode::INTERNAL_SERVER_ERROR,
92 Json(json!({"error": "InternalError"})),
93 )
94 .into_response();
95 }
96 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
97 if let Err(e) =
98 crate::notifications::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
99 {
100 warn!("Failed to enqueue password reset notification: {:?}", e);
101 }
102 info!("Password reset requested for user {}", user_id);
103 (StatusCode::OK, Json(json!({}))).into_response()
104}
105#[derive(Deserialize)]
106pub struct ResetPasswordInput {
107 pub token: String,
108 pub password: String,
109}
110pub async fn reset_password(
111 State(state): State<AppState>,
112 headers: HeaderMap,
113 Json(input): Json<ResetPasswordInput>,
114) -> Response {
115 let client_ip = extract_client_ip(&headers);
116 if !state.check_rate_limit(RateLimitKind::ResetPassword, &client_ip).await {
117 warn!(ip = %client_ip, "Reset password rate limit exceeded");
118 return (
119 StatusCode::TOO_MANY_REQUESTS,
120 Json(json!({
121 "error": "RateLimitExceeded",
122 "message": "Too many requests. Please try again later."
123 })),
124 ).into_response();
125 }
126 let token = input.token.trim();
127 let password = &input.password;
128 if token.is_empty() {
129 return (
130 StatusCode::BAD_REQUEST,
131 Json(json!({"error": "InvalidToken", "message": "token is required"})),
132 )
133 .into_response();
134 }
135 if password.is_empty() {
136 return (
137 StatusCode::BAD_REQUEST,
138 Json(json!({"error": "InvalidRequest", "message": "password is required"})),
139 )
140 .into_response();
141 }
142 let user = sqlx::query!(
143 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1",
144 token
145 )
146 .fetch_optional(&state.db)
147 .await;
148 let (user_id, expires_at) = match user {
149 Ok(Some(row)) => {
150 let expires = row.password_reset_code_expires_at;
151 (row.id, expires)
152 }
153 Ok(None) => {
154 return (
155 StatusCode::BAD_REQUEST,
156 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
157 )
158 .into_response();
159 }
160 Err(e) => {
161 error!("DB error in reset_password: {:?}", e);
162 return (
163 StatusCode::INTERNAL_SERVER_ERROR,
164 Json(json!({"error": "InternalError"})),
165 )
166 .into_response();
167 }
168 };
169 if let Some(exp) = expires_at {
170 if Utc::now() > exp {
171 if let Err(e) = sqlx::query!(
172 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
173 user_id
174 )
175 .execute(&state.db)
176 .await
177 {
178 error!("Failed to clear expired reset code: {:?}", e);
179 }
180 return (
181 StatusCode::BAD_REQUEST,
182 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
183 )
184 .into_response();
185 }
186 } else {
187 return (
188 StatusCode::BAD_REQUEST,
189 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
190 )
191 .into_response();
192 }
193 let password_hash = match hash(password, DEFAULT_COST) {
194 Ok(h) => h,
195 Err(e) => {
196 error!("Failed to hash password: {:?}", e);
197 return (
198 StatusCode::INTERNAL_SERVER_ERROR,
199 Json(json!({"error": "InternalError"})),
200 )
201 .into_response();
202 }
203 };
204 let mut tx = match state.db.begin().await {
205 Ok(tx) => tx,
206 Err(e) => {
207 error!("Failed to begin transaction: {:?}", e);
208 return (
209 StatusCode::INTERNAL_SERVER_ERROR,
210 Json(json!({"error": "InternalError"})),
211 )
212 .into_response();
213 }
214 };
215 if let Err(e) = sqlx::query!(
216 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
217 password_hash,
218 user_id
219 )
220 .execute(&mut *tx)
221 .await
222 {
223 error!("DB error updating password: {:?}", e);
224 return (
225 StatusCode::INTERNAL_SERVER_ERROR,
226 Json(json!({"error": "InternalError"})),
227 )
228 .into_response();
229 }
230 let user_did = match sqlx::query_scalar!(
231 "SELECT did FROM users WHERE id = $1",
232 user_id
233 )
234 .fetch_one(&mut *tx)
235 .await
236 {
237 Ok(did) => did,
238 Err(e) => {
239 error!("Failed to get DID for user {}: {:?}", user_id, e);
240 return (
241 StatusCode::INTERNAL_SERVER_ERROR,
242 Json(json!({"error": "InternalError"})),
243 )
244 .into_response();
245 }
246 };
247 let session_jtis: Vec<String> = match sqlx::query_scalar!(
248 "SELECT access_jti FROM session_tokens WHERE did = $1",
249 user_did
250 )
251 .fetch_all(&mut *tx)
252 .await
253 {
254 Ok(jtis) => jtis,
255 Err(e) => {
256 error!("Failed to fetch session JTIs: {:?}", e);
257 vec![]
258 }
259 };
260 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did)
261 .execute(&mut *tx)
262 .await
263 {
264 error!("Failed to invalidate sessions after password reset: {:?}", e);
265 return (
266 StatusCode::INTERNAL_SERVER_ERROR,
267 Json(json!({"error": "InternalError"})),
268 )
269 .into_response();
270 }
271 if let Err(e) = tx.commit().await {
272 error!("Failed to commit password reset transaction: {:?}", e);
273 return (
274 StatusCode::INTERNAL_SERVER_ERROR,
275 Json(json!({"error": "InternalError"})),
276 )
277 .into_response();
278 }
279 for jti in session_jtis {
280 let cache_key = format!("auth:session:{}:{}", user_did, jti);
281 if let Err(e) = state.cache.delete(&cache_key).await {
282 warn!("Failed to invalidate session cache for {}: {:?}", cache_key, e);
283 }
284 }
285 info!("Password reset completed for user {}", user_id);
286 (StatusCode::OK, Json(json!({}))).into_response()
287}