this repo has no description
1use crate::auth::BearerAuth;
2use crate::state::{AppState, RateLimitKind};
3use axum::{
4 Json,
5 extract::State,
6 http::{HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8};
9use bcrypt::{DEFAULT_COST, hash, verify};
10use chrono::{Duration, Utc};
11use serde::Deserialize;
12use serde_json::json;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16fn generate_reset_code() -> String {
17 crate::util::generate_token_code()
18}
19fn extract_client_ip(headers: &HeaderMap) -> String {
20 if let Some(forwarded) = headers.get("x-forwarded-for")
21 && let Ok(value) = forwarded.to_str()
22 && let Some(first_ip) = value.split(',').next()
23 {
24 return first_ip.trim().to_string();
25 }
26 if let Some(real_ip) = headers.get("x-real-ip")
27 && let Ok(value) = real_ip.to_str()
28 {
29 return value.trim().to_string();
30 }
31 "unknown".to_string()
32}
33
/// JSON request body accepted by `request_password_reset`.
#[derive(Deserialize)]
pub struct RequestPasswordResetInput {
    /// E-mail address of the account to reset. The handler trims and
    /// lower-cases it before looking the user up.
    pub email: String,
}
38
39pub async fn request_password_reset(
40 State(state): State<AppState>,
41 headers: HeaderMap,
42 Json(input): Json<RequestPasswordResetInput>,
43) -> Response {
44 let client_ip = extract_client_ip(&headers);
45 if !state
46 .check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
47 .await
48 {
49 warn!(ip = %client_ip, "Password reset rate limit exceeded");
50 return (
51 StatusCode::TOO_MANY_REQUESTS,
52 Json(json!({
53 "error": "RateLimitExceeded",
54 "message": "Too many password reset requests. Please try again later."
55 })),
56 )
57 .into_response();
58 }
59 let email = input.email.trim().to_lowercase();
60 if email.is_empty() {
61 return (
62 StatusCode::BAD_REQUEST,
63 Json(json!({"error": "InvalidRequest", "message": "email is required"})),
64 )
65 .into_response();
66 }
67 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email)
68 .fetch_optional(&state.db)
69 .await;
70 let user_id = match user {
71 Ok(Some(row)) => row.id,
72 Ok(None) => {
73 info!("Password reset requested for unknown email");
74 return (StatusCode::OK, Json(json!({}))).into_response();
75 }
76 Err(e) => {
77 error!("DB error in request_password_reset: {:?}", e);
78 return (
79 StatusCode::INTERNAL_SERVER_ERROR,
80 Json(json!({"error": "InternalError"})),
81 )
82 .into_response();
83 }
84 };
85 let code = generate_reset_code();
86 let expires_at = Utc::now() + Duration::minutes(10);
87 let update = sqlx::query!(
88 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3",
89 code,
90 expires_at,
91 user_id
92 )
93 .execute(&state.db)
94 .await;
95 if let Err(e) = update {
96 error!("DB error setting reset code: {:?}", e);
97 return (
98 StatusCode::INTERNAL_SERVER_ERROR,
99 Json(json!({"error": "InternalError"})),
100 )
101 .into_response();
102 }
103 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
104 if let Err(e) = crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
105 {
106 warn!("Failed to enqueue password reset notification: {:?}", e);
107 }
108 info!("Password reset requested for user {}", user_id);
109 (StatusCode::OK, Json(json!({}))).into_response()
110}
111
/// JSON request body accepted by `reset_password`.
#[derive(Deserialize)]
pub struct ResetPasswordInput {
    /// Reset code previously issued to the user. The handler trims it and
    /// matches it against `users.password_reset_code`.
    pub token: String,
    /// New password; the handler hashes it with bcrypt before storing it.
    pub password: String,
}
117
118pub async fn reset_password(
119 State(state): State<AppState>,
120 headers: HeaderMap,
121 Json(input): Json<ResetPasswordInput>,
122) -> Response {
123 let client_ip = extract_client_ip(&headers);
124 if !state
125 .check_rate_limit(RateLimitKind::ResetPassword, &client_ip)
126 .await
127 {
128 warn!(ip = %client_ip, "Reset password rate limit exceeded");
129 return (
130 StatusCode::TOO_MANY_REQUESTS,
131 Json(json!({
132 "error": "RateLimitExceeded",
133 "message": "Too many requests. Please try again later."
134 })),
135 )
136 .into_response();
137 }
138 let token = input.token.trim();
139 let password = &input.password;
140 if token.is_empty() {
141 return (
142 StatusCode::BAD_REQUEST,
143 Json(json!({"error": "InvalidToken", "message": "token is required"})),
144 )
145 .into_response();
146 }
147 if password.is_empty() {
148 return (
149 StatusCode::BAD_REQUEST,
150 Json(json!({"error": "InvalidRequest", "message": "password is required"})),
151 )
152 .into_response();
153 }
154 let user = sqlx::query!(
155 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1",
156 token
157 )
158 .fetch_optional(&state.db)
159 .await;
160 let (user_id, expires_at) = match user {
161 Ok(Some(row)) => {
162 let expires = row.password_reset_code_expires_at;
163 (row.id, expires)
164 }
165 Ok(None) => {
166 return (
167 StatusCode::BAD_REQUEST,
168 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
169 )
170 .into_response();
171 }
172 Err(e) => {
173 error!("DB error in reset_password: {:?}", e);
174 return (
175 StatusCode::INTERNAL_SERVER_ERROR,
176 Json(json!({"error": "InternalError"})),
177 )
178 .into_response();
179 }
180 };
181 if let Some(exp) = expires_at {
182 if Utc::now() > exp {
183 if let Err(e) = sqlx::query!(
184 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
185 user_id
186 )
187 .execute(&state.db)
188 .await
189 {
190 error!("Failed to clear expired reset code: {:?}", e);
191 }
192 return (
193 StatusCode::BAD_REQUEST,
194 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
195 )
196 .into_response();
197 }
198 } else {
199 return (
200 StatusCode::BAD_REQUEST,
201 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
202 )
203 .into_response();
204 }
205 let password_hash = match hash(password, DEFAULT_COST) {
206 Ok(h) => h,
207 Err(e) => {
208 error!("Failed to hash password: {:?}", e);
209 return (
210 StatusCode::INTERNAL_SERVER_ERROR,
211 Json(json!({"error": "InternalError"})),
212 )
213 .into_response();
214 }
215 };
216 let mut tx = match state.db.begin().await {
217 Ok(tx) => tx,
218 Err(e) => {
219 error!("Failed to begin transaction: {:?}", e);
220 return (
221 StatusCode::INTERNAL_SERVER_ERROR,
222 Json(json!({"error": "InternalError"})),
223 )
224 .into_response();
225 }
226 };
227 if let Err(e) = sqlx::query!(
228 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
229 password_hash,
230 user_id
231 )
232 .execute(&mut *tx)
233 .await
234 {
235 error!("DB error updating password: {:?}", e);
236 return (
237 StatusCode::INTERNAL_SERVER_ERROR,
238 Json(json!({"error": "InternalError"})),
239 )
240 .into_response();
241 }
242 let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id)
243 .fetch_one(&mut *tx)
244 .await
245 {
246 Ok(did) => did,
247 Err(e) => {
248 error!("Failed to get DID for user {}: {:?}", user_id, e);
249 return (
250 StatusCode::INTERNAL_SERVER_ERROR,
251 Json(json!({"error": "InternalError"})),
252 )
253 .into_response();
254 }
255 };
256 let session_jtis: Vec<String> = match sqlx::query_scalar!(
257 "SELECT access_jti FROM session_tokens WHERE did = $1",
258 user_did
259 )
260 .fetch_all(&mut *tx)
261 .await
262 {
263 Ok(jtis) => jtis,
264 Err(e) => {
265 error!("Failed to fetch session JTIs: {:?}", e);
266 vec![]
267 }
268 };
269 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did)
270 .execute(&mut *tx)
271 .await
272 {
273 error!(
274 "Failed to invalidate sessions after password reset: {:?}",
275 e
276 );
277 return (
278 StatusCode::INTERNAL_SERVER_ERROR,
279 Json(json!({"error": "InternalError"})),
280 )
281 .into_response();
282 }
283 if let Err(e) = tx.commit().await {
284 error!("Failed to commit password reset transaction: {:?}", e);
285 return (
286 StatusCode::INTERNAL_SERVER_ERROR,
287 Json(json!({"error": "InternalError"})),
288 )
289 .into_response();
290 }
291 for jti in session_jtis {
292 let cache_key = format!("auth:session:{}:{}", user_did, jti);
293 if let Err(e) = state.cache.delete(&cache_key).await {
294 warn!(
295 "Failed to invalidate session cache for {}: {:?}",
296 cache_key, e
297 );
298 }
299 }
300 info!("Password reset completed for user {}", user_id);
301 (StatusCode::OK, Json(json!({}))).into_response()
302}
303
/// JSON request body accepted by `change_password`.
///
/// Fields are deserialized from camelCase keys (`currentPassword`,
/// `newPassword`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ChangePasswordInput {
    /// The account's existing password; verified against the stored hash.
    pub current_password: String,
    /// The replacement password; must be at least eight bytes long.
    pub new_password: String,
}
310
311pub async fn change_password(
312 State(state): State<AppState>,
313 auth: BearerAuth,
314 Json(input): Json<ChangePasswordInput>,
315) -> Response {
316 let current_password = &input.current_password;
317 let new_password = &input.new_password;
318 if current_password.is_empty() {
319 return (
320 StatusCode::BAD_REQUEST,
321 Json(json!({"error": "InvalidRequest", "message": "currentPassword is required"})),
322 )
323 .into_response();
324 }
325 if new_password.is_empty() {
326 return (
327 StatusCode::BAD_REQUEST,
328 Json(json!({"error": "InvalidRequest", "message": "newPassword is required"})),
329 )
330 .into_response();
331 }
332 if new_password.len() < 8 {
333 return (
334 StatusCode::BAD_REQUEST,
335 Json(json!({"error": "InvalidRequest", "message": "Password must be at least 8 characters"})),
336 )
337 .into_response();
338 }
339 let user =
340 sqlx::query_as::<_, (Uuid, String)>("SELECT id, password_hash FROM users WHERE did = $1")
341 .bind(&auth.0.did)
342 .fetch_optional(&state.db)
343 .await;
344 let (user_id, password_hash) = match user {
345 Ok(Some(row)) => row,
346 Ok(None) => {
347 return (
348 StatusCode::NOT_FOUND,
349 Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
350 )
351 .into_response();
352 }
353 Err(e) => {
354 error!("DB error in change_password: {:?}", e);
355 return (
356 StatusCode::INTERNAL_SERVER_ERROR,
357 Json(json!({"error": "InternalError"})),
358 )
359 .into_response();
360 }
361 };
362 let valid = match verify(current_password, &password_hash) {
363 Ok(v) => v,
364 Err(e) => {
365 error!("Password verification error: {:?}", e);
366 return (
367 StatusCode::INTERNAL_SERVER_ERROR,
368 Json(json!({"error": "InternalError"})),
369 )
370 .into_response();
371 }
372 };
373 if !valid {
374 return (
375 StatusCode::UNAUTHORIZED,
376 Json(json!({"error": "InvalidPassword", "message": "Current password is incorrect"})),
377 )
378 .into_response();
379 }
380 let new_hash = match hash(new_password, DEFAULT_COST) {
381 Ok(h) => h,
382 Err(e) => {
383 error!("Failed to hash password: {:?}", e);
384 return (
385 StatusCode::INTERNAL_SERVER_ERROR,
386 Json(json!({"error": "InternalError"})),
387 )
388 .into_response();
389 }
390 };
391 if let Err(e) = sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2")
392 .bind(&new_hash)
393 .bind(user_id)
394 .execute(&state.db)
395 .await
396 {
397 error!("DB error updating password: {:?}", e);
398 return (
399 StatusCode::INTERNAL_SERVER_ERROR,
400 Json(json!({"error": "InternalError"})),
401 )
402 .into_response();
403 }
404 info!(did = %auth.0.did, "Password changed successfully");
405 (StatusCode::OK, Json(json!({}))).into_response()
406}