this repo has no description
1use crate::auth::BearerAuth;
2use crate::state::{AppState, RateLimitKind};
3use axum::{
4 Json,
5 extract::State,
6 http::{HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8};
9use bcrypt::{DEFAULT_COST, hash, verify};
10use chrono::{Duration, Utc};
11use uuid::Uuid;
12use serde::Deserialize;
13use serde_json::json;
14use tracing::{error, info, warn};
15
16fn generate_reset_code() -> String {
17 crate::util::generate_token_code()
18}
19fn extract_client_ip(headers: &HeaderMap) -> String {
20 if let Some(forwarded) = headers.get("x-forwarded-for")
21 && let Ok(value) = forwarded.to_str()
22 && let Some(first_ip) = value.split(',').next() {
23 return first_ip.trim().to_string();
24 }
25 if let Some(real_ip) = headers.get("x-real-ip")
26 && let Ok(value) = real_ip.to_str() {
27 return value.trim().to_string();
28 }
29 "unknown".to_string()
30}
31
32#[derive(Deserialize)]
33pub struct RequestPasswordResetInput {
34 pub email: String,
35}
36
37pub async fn request_password_reset(
38 State(state): State<AppState>,
39 headers: HeaderMap,
40 Json(input): Json<RequestPasswordResetInput>,
41) -> Response {
42 let client_ip = extract_client_ip(&headers);
43 if !state
44 .check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
45 .await
46 {
47 warn!(ip = %client_ip, "Password reset rate limit exceeded");
48 return (
49 StatusCode::TOO_MANY_REQUESTS,
50 Json(json!({
51 "error": "RateLimitExceeded",
52 "message": "Too many password reset requests. Please try again later."
53 })),
54 )
55 .into_response();
56 }
57 let email = input.email.trim().to_lowercase();
58 if email.is_empty() {
59 return (
60 StatusCode::BAD_REQUEST,
61 Json(json!({"error": "InvalidRequest", "message": "email is required"})),
62 )
63 .into_response();
64 }
65 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email)
66 .fetch_optional(&state.db)
67 .await;
68 let user_id = match user {
69 Ok(Some(row)) => row.id,
70 Ok(None) => {
71 info!("Password reset requested for unknown email");
72 return (StatusCode::OK, Json(json!({}))).into_response();
73 }
74 Err(e) => {
75 error!("DB error in request_password_reset: {:?}", e);
76 return (
77 StatusCode::INTERNAL_SERVER_ERROR,
78 Json(json!({"error": "InternalError"})),
79 )
80 .into_response();
81 }
82 };
83 let code = generate_reset_code();
84 let expires_at = Utc::now() + Duration::minutes(10);
85 let update = sqlx::query!(
86 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3",
87 code,
88 expires_at,
89 user_id
90 )
91 .execute(&state.db)
92 .await;
93 if let Err(e) = update {
94 error!("DB error setting reset code: {:?}", e);
95 return (
96 StatusCode::INTERNAL_SERVER_ERROR,
97 Json(json!({"error": "InternalError"})),
98 )
99 .into_response();
100 }
101 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
102 if let Err(e) =
103 crate::notifications::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
104 {
105 warn!("Failed to enqueue password reset notification: {:?}", e);
106 }
107 info!("Password reset requested for user {}", user_id);
108 (StatusCode::OK, Json(json!({}))).into_response()
109}
110
111#[derive(Deserialize)]
112pub struct ResetPasswordInput {
113 pub token: String,
114 pub password: String,
115}
116
117pub async fn reset_password(
118 State(state): State<AppState>,
119 headers: HeaderMap,
120 Json(input): Json<ResetPasswordInput>,
121) -> Response {
122 let client_ip = extract_client_ip(&headers);
123 if !state
124 .check_rate_limit(RateLimitKind::ResetPassword, &client_ip)
125 .await
126 {
127 warn!(ip = %client_ip, "Reset password rate limit exceeded");
128 return (
129 StatusCode::TOO_MANY_REQUESTS,
130 Json(json!({
131 "error": "RateLimitExceeded",
132 "message": "Too many requests. Please try again later."
133 })),
134 )
135 .into_response();
136 }
137 let token = input.token.trim();
138 let password = &input.password;
139 if token.is_empty() {
140 return (
141 StatusCode::BAD_REQUEST,
142 Json(json!({"error": "InvalidToken", "message": "token is required"})),
143 )
144 .into_response();
145 }
146 if password.is_empty() {
147 return (
148 StatusCode::BAD_REQUEST,
149 Json(json!({"error": "InvalidRequest", "message": "password is required"})),
150 )
151 .into_response();
152 }
153 let user = sqlx::query!(
154 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1",
155 token
156 )
157 .fetch_optional(&state.db)
158 .await;
159 let (user_id, expires_at) = match user {
160 Ok(Some(row)) => {
161 let expires = row.password_reset_code_expires_at;
162 (row.id, expires)
163 }
164 Ok(None) => {
165 return (
166 StatusCode::BAD_REQUEST,
167 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
168 )
169 .into_response();
170 }
171 Err(e) => {
172 error!("DB error in reset_password: {:?}", e);
173 return (
174 StatusCode::INTERNAL_SERVER_ERROR,
175 Json(json!({"error": "InternalError"})),
176 )
177 .into_response();
178 }
179 };
180 if let Some(exp) = expires_at {
181 if Utc::now() > exp {
182 if let Err(e) = sqlx::query!(
183 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
184 user_id
185 )
186 .execute(&state.db)
187 .await
188 {
189 error!("Failed to clear expired reset code: {:?}", e);
190 }
191 return (
192 StatusCode::BAD_REQUEST,
193 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
194 )
195 .into_response();
196 }
197 } else {
198 return (
199 StatusCode::BAD_REQUEST,
200 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
201 )
202 .into_response();
203 }
204 let password_hash = match hash(password, DEFAULT_COST) {
205 Ok(h) => h,
206 Err(e) => {
207 error!("Failed to hash password: {:?}", e);
208 return (
209 StatusCode::INTERNAL_SERVER_ERROR,
210 Json(json!({"error": "InternalError"})),
211 )
212 .into_response();
213 }
214 };
215 let mut tx = match state.db.begin().await {
216 Ok(tx) => tx,
217 Err(e) => {
218 error!("Failed to begin transaction: {:?}", e);
219 return (
220 StatusCode::INTERNAL_SERVER_ERROR,
221 Json(json!({"error": "InternalError"})),
222 )
223 .into_response();
224 }
225 };
226 if let Err(e) = sqlx::query!(
227 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
228 password_hash,
229 user_id
230 )
231 .execute(&mut *tx)
232 .await
233 {
234 error!("DB error updating password: {:?}", e);
235 return (
236 StatusCode::INTERNAL_SERVER_ERROR,
237 Json(json!({"error": "InternalError"})),
238 )
239 .into_response();
240 }
241 let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id)
242 .fetch_one(&mut *tx)
243 .await
244 {
245 Ok(did) => did,
246 Err(e) => {
247 error!("Failed to get DID for user {}: {:?}", user_id, e);
248 return (
249 StatusCode::INTERNAL_SERVER_ERROR,
250 Json(json!({"error": "InternalError"})),
251 )
252 .into_response();
253 }
254 };
255 let session_jtis: Vec<String> = match sqlx::query_scalar!(
256 "SELECT access_jti FROM session_tokens WHERE did = $1",
257 user_did
258 )
259 .fetch_all(&mut *tx)
260 .await
261 {
262 Ok(jtis) => jtis,
263 Err(e) => {
264 error!("Failed to fetch session JTIs: {:?}", e);
265 vec![]
266 }
267 };
268 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did)
269 .execute(&mut *tx)
270 .await
271 {
272 error!(
273 "Failed to invalidate sessions after password reset: {:?}",
274 e
275 );
276 return (
277 StatusCode::INTERNAL_SERVER_ERROR,
278 Json(json!({"error": "InternalError"})),
279 )
280 .into_response();
281 }
282 if let Err(e) = tx.commit().await {
283 error!("Failed to commit password reset transaction: {:?}", e);
284 return (
285 StatusCode::INTERNAL_SERVER_ERROR,
286 Json(json!({"error": "InternalError"})),
287 )
288 .into_response();
289 }
290 for jti in session_jtis {
291 let cache_key = format!("auth:session:{}:{}", user_did, jti);
292 if let Err(e) = state.cache.delete(&cache_key).await {
293 warn!(
294 "Failed to invalidate session cache for {}: {:?}",
295 cache_key, e
296 );
297 }
298 }
299 info!("Password reset completed for user {}", user_id);
300 (StatusCode::OK, Json(json!({}))).into_response()
301}
302
303#[derive(Deserialize)]
304#[serde(rename_all = "camelCase")]
305pub struct ChangePasswordInput {
306 pub current_password: String,
307 pub new_password: String,
308}
309
310pub async fn change_password(
311 State(state): State<AppState>,
312 auth: BearerAuth,
313 Json(input): Json<ChangePasswordInput>,
314) -> Response {
315 let current_password = &input.current_password;
316 let new_password = &input.new_password;
317 if current_password.is_empty() {
318 return (
319 StatusCode::BAD_REQUEST,
320 Json(json!({"error": "InvalidRequest", "message": "currentPassword is required"})),
321 )
322 .into_response();
323 }
324 if new_password.is_empty() {
325 return (
326 StatusCode::BAD_REQUEST,
327 Json(json!({"error": "InvalidRequest", "message": "newPassword is required"})),
328 )
329 .into_response();
330 }
331 if new_password.len() < 8 {
332 return (
333 StatusCode::BAD_REQUEST,
334 Json(json!({"error": "InvalidRequest", "message": "Password must be at least 8 characters"})),
335 )
336 .into_response();
337 }
338 let user = sqlx::query_as::<_, (Uuid, String)>(
339 "SELECT id, password_hash FROM users WHERE did = $1",
340 )
341 .bind(&auth.0.did)
342 .fetch_optional(&state.db)
343 .await;
344 let (user_id, password_hash) = match user {
345 Ok(Some(row)) => row,
346 Ok(None) => {
347 return (
348 StatusCode::NOT_FOUND,
349 Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
350 )
351 .into_response();
352 }
353 Err(e) => {
354 error!("DB error in change_password: {:?}", e);
355 return (
356 StatusCode::INTERNAL_SERVER_ERROR,
357 Json(json!({"error": "InternalError"})),
358 )
359 .into_response();
360 }
361 };
362 let valid = match verify(current_password, &password_hash) {
363 Ok(v) => v,
364 Err(e) => {
365 error!("Password verification error: {:?}", e);
366 return (
367 StatusCode::INTERNAL_SERVER_ERROR,
368 Json(json!({"error": "InternalError"})),
369 )
370 .into_response();
371 }
372 };
373 if !valid {
374 return (
375 StatusCode::UNAUTHORIZED,
376 Json(json!({"error": "InvalidPassword", "message": "Current password is incorrect"})),
377 )
378 .into_response();
379 }
380 let new_hash = match hash(new_password, DEFAULT_COST) {
381 Ok(h) => h,
382 Err(e) => {
383 error!("Failed to hash password: {:?}", e);
384 return (
385 StatusCode::INTERNAL_SERVER_ERROR,
386 Json(json!({"error": "InternalError"})),
387 )
388 .into_response();
389 }
390 };
391 if let Err(e) = sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2")
392 .bind(&new_hash)
393 .bind(user_id)
394 .execute(&state.db)
395 .await
396 {
397 error!("DB error updating password: {:?}", e);
398 return (
399 StatusCode::INTERNAL_SERVER_ERROR,
400 Json(json!({"error": "InternalError"})),
401 )
402 .into_response();
403 }
404 info!(did = %auth.0.did, "Password changed successfully");
405 (StatusCode::OK, Json(json!({}))).into_response()
406}