// This repository has no description.
1use crate::state::{AppState, RateLimitKind};
2use axum::{
3 Json,
4 extract::State,
5 http::{HeaderMap, StatusCode},
6 response::{IntoResponse, Response},
7};
8use bcrypt::{DEFAULT_COST, hash};
9use chrono::{Duration, Utc};
10use serde::Deserialize;
11use serde_json::json;
12use tracing::{error, info, warn};
13
14fn generate_reset_code() -> String {
15 crate::util::generate_token_code()
16}
17fn extract_client_ip(headers: &HeaderMap) -> String {
18 if let Some(forwarded) = headers.get("x-forwarded-for")
19 && let Ok(value) = forwarded.to_str()
20 && let Some(first_ip) = value.split(',').next() {
21 return first_ip.trim().to_string();
22 }
23 if let Some(real_ip) = headers.get("x-real-ip")
24 && let Ok(value) = real_ip.to_str() {
25 return value.trim().to_string();
26 }
27 "unknown".to_string()
28}
29
30#[derive(Deserialize)]
31pub struct RequestPasswordResetInput {
32 pub email: String,
33}
34
35pub async fn request_password_reset(
36 State(state): State<AppState>,
37 headers: HeaderMap,
38 Json(input): Json<RequestPasswordResetInput>,
39) -> Response {
40 let client_ip = extract_client_ip(&headers);
41 if !state
42 .check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
43 .await
44 {
45 warn!(ip = %client_ip, "Password reset rate limit exceeded");
46 return (
47 StatusCode::TOO_MANY_REQUESTS,
48 Json(json!({
49 "error": "RateLimitExceeded",
50 "message": "Too many password reset requests. Please try again later."
51 })),
52 )
53 .into_response();
54 }
55 let email = input.email.trim().to_lowercase();
56 if email.is_empty() {
57 return (
58 StatusCode::BAD_REQUEST,
59 Json(json!({"error": "InvalidRequest", "message": "email is required"})),
60 )
61 .into_response();
62 }
63 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email)
64 .fetch_optional(&state.db)
65 .await;
66 let user_id = match user {
67 Ok(Some(row)) => row.id,
68 Ok(None) => {
69 info!("Password reset requested for unknown email");
70 return (StatusCode::OK, Json(json!({}))).into_response();
71 }
72 Err(e) => {
73 error!("DB error in request_password_reset: {:?}", e);
74 return (
75 StatusCode::INTERNAL_SERVER_ERROR,
76 Json(json!({"error": "InternalError"})),
77 )
78 .into_response();
79 }
80 };
81 let code = generate_reset_code();
82 let expires_at = Utc::now() + Duration::minutes(10);
83 let update = sqlx::query!(
84 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3",
85 code,
86 expires_at,
87 user_id
88 )
89 .execute(&state.db)
90 .await;
91 if let Err(e) = update {
92 error!("DB error setting reset code: {:?}", e);
93 return (
94 StatusCode::INTERNAL_SERVER_ERROR,
95 Json(json!({"error": "InternalError"})),
96 )
97 .into_response();
98 }
99 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
100 if let Err(e) =
101 crate::notifications::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
102 {
103 warn!("Failed to enqueue password reset notification: {:?}", e);
104 }
105 info!("Password reset requested for user {}", user_id);
106 (StatusCode::OK, Json(json!({}))).into_response()
107}
108
109#[derive(Deserialize)]
110pub struct ResetPasswordInput {
111 pub token: String,
112 pub password: String,
113}
114
115pub async fn reset_password(
116 State(state): State<AppState>,
117 headers: HeaderMap,
118 Json(input): Json<ResetPasswordInput>,
119) -> Response {
120 let client_ip = extract_client_ip(&headers);
121 if !state
122 .check_rate_limit(RateLimitKind::ResetPassword, &client_ip)
123 .await
124 {
125 warn!(ip = %client_ip, "Reset password rate limit exceeded");
126 return (
127 StatusCode::TOO_MANY_REQUESTS,
128 Json(json!({
129 "error": "RateLimitExceeded",
130 "message": "Too many requests. Please try again later."
131 })),
132 )
133 .into_response();
134 }
135 let token = input.token.trim();
136 let password = &input.password;
137 if token.is_empty() {
138 return (
139 StatusCode::BAD_REQUEST,
140 Json(json!({"error": "InvalidToken", "message": "token is required"})),
141 )
142 .into_response();
143 }
144 if password.is_empty() {
145 return (
146 StatusCode::BAD_REQUEST,
147 Json(json!({"error": "InvalidRequest", "message": "password is required"})),
148 )
149 .into_response();
150 }
151 let user = sqlx::query!(
152 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1",
153 token
154 )
155 .fetch_optional(&state.db)
156 .await;
157 let (user_id, expires_at) = match user {
158 Ok(Some(row)) => {
159 let expires = row.password_reset_code_expires_at;
160 (row.id, expires)
161 }
162 Ok(None) => {
163 return (
164 StatusCode::BAD_REQUEST,
165 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
166 )
167 .into_response();
168 }
169 Err(e) => {
170 error!("DB error in reset_password: {:?}", e);
171 return (
172 StatusCode::INTERNAL_SERVER_ERROR,
173 Json(json!({"error": "InternalError"})),
174 )
175 .into_response();
176 }
177 };
178 if let Some(exp) = expires_at {
179 if Utc::now() > exp {
180 if let Err(e) = sqlx::query!(
181 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
182 user_id
183 )
184 .execute(&state.db)
185 .await
186 {
187 error!("Failed to clear expired reset code: {:?}", e);
188 }
189 return (
190 StatusCode::BAD_REQUEST,
191 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
192 )
193 .into_response();
194 }
195 } else {
196 return (
197 StatusCode::BAD_REQUEST,
198 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
199 )
200 .into_response();
201 }
202 let password_hash = match hash(password, DEFAULT_COST) {
203 Ok(h) => h,
204 Err(e) => {
205 error!("Failed to hash password: {:?}", e);
206 return (
207 StatusCode::INTERNAL_SERVER_ERROR,
208 Json(json!({"error": "InternalError"})),
209 )
210 .into_response();
211 }
212 };
213 let mut tx = match state.db.begin().await {
214 Ok(tx) => tx,
215 Err(e) => {
216 error!("Failed to begin transaction: {:?}", e);
217 return (
218 StatusCode::INTERNAL_SERVER_ERROR,
219 Json(json!({"error": "InternalError"})),
220 )
221 .into_response();
222 }
223 };
224 if let Err(e) = sqlx::query!(
225 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
226 password_hash,
227 user_id
228 )
229 .execute(&mut *tx)
230 .await
231 {
232 error!("DB error updating password: {:?}", e);
233 return (
234 StatusCode::INTERNAL_SERVER_ERROR,
235 Json(json!({"error": "InternalError"})),
236 )
237 .into_response();
238 }
239 let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id)
240 .fetch_one(&mut *tx)
241 .await
242 {
243 Ok(did) => did,
244 Err(e) => {
245 error!("Failed to get DID for user {}: {:?}", user_id, e);
246 return (
247 StatusCode::INTERNAL_SERVER_ERROR,
248 Json(json!({"error": "InternalError"})),
249 )
250 .into_response();
251 }
252 };
253 let session_jtis: Vec<String> = match sqlx::query_scalar!(
254 "SELECT access_jti FROM session_tokens WHERE did = $1",
255 user_did
256 )
257 .fetch_all(&mut *tx)
258 .await
259 {
260 Ok(jtis) => jtis,
261 Err(e) => {
262 error!("Failed to fetch session JTIs: {:?}", e);
263 vec![]
264 }
265 };
266 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did)
267 .execute(&mut *tx)
268 .await
269 {
270 error!(
271 "Failed to invalidate sessions after password reset: {:?}",
272 e
273 );
274 return (
275 StatusCode::INTERNAL_SERVER_ERROR,
276 Json(json!({"error": "InternalError"})),
277 )
278 .into_response();
279 }
280 if let Err(e) = tx.commit().await {
281 error!("Failed to commit password reset transaction: {:?}", e);
282 return (
283 StatusCode::INTERNAL_SERVER_ERROR,
284 Json(json!({"error": "InternalError"})),
285 )
286 .into_response();
287 }
288 for jti in session_jtis {
289 let cache_key = format!("auth:session:{}:{}", user_did, jti);
290 if let Err(e) = state.cache.delete(&cache_key).await {
291 warn!(
292 "Failed to invalidate session cache for {}: {:?}",
293 cache_key, e
294 );
295 }
296 }
297 info!("Password reset completed for user {}", user_id);
298 (StatusCode::OK, Json(json!({}))).into_response()
299}