this repo has no description
1use crate::state::{AppState, RateLimitKind};
2use axum::{
3 Json,
4 extract::State,
5 http::{HeaderMap, StatusCode},
6 response::{IntoResponse, Response},
7};
8use bcrypt::{hash, DEFAULT_COST};
9use chrono::{Duration, Utc};
10use serde::Deserialize;
11use serde_json::json;
12use tracing::{error, info, warn};
13
14fn generate_reset_code() -> String {
15 crate::util::generate_token_code()
16}
17fn extract_client_ip(headers: &HeaderMap) -> String {
18 if let Some(forwarded) = headers.get("x-forwarded-for") {
19 if let Ok(value) = forwarded.to_str() {
20 if let Some(first_ip) = value.split(',').next() {
21 return first_ip.trim().to_string();
22 }
23 }
24 }
25 if let Some(real_ip) = headers.get("x-real-ip") {
26 if let Ok(value) = real_ip.to_str() {
27 return value.trim().to_string();
28 }
29 }
30 "unknown".to_string()
31}
32
33#[derive(Deserialize)]
34pub struct RequestPasswordResetInput {
35 pub email: String,
36}
37
38pub async fn request_password_reset(
39 State(state): State<AppState>,
40 headers: HeaderMap,
41 Json(input): Json<RequestPasswordResetInput>,
42) -> Response {
43 let client_ip = extract_client_ip(&headers);
44 if !state.check_rate_limit(RateLimitKind::PasswordReset, &client_ip).await {
45 warn!(ip = %client_ip, "Password reset rate limit exceeded");
46 return (
47 StatusCode::TOO_MANY_REQUESTS,
48 Json(json!({
49 "error": "RateLimitExceeded",
50 "message": "Too many password reset requests. Please try again later."
51 })),
52 )
53 .into_response();
54 }
55 let email = input.email.trim().to_lowercase();
56 if email.is_empty() {
57 return (
58 StatusCode::BAD_REQUEST,
59 Json(json!({"error": "InvalidRequest", "message": "email is required"})),
60 )
61 .into_response();
62 }
63 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email)
64 .fetch_optional(&state.db)
65 .await;
66 let user_id = match user {
67 Ok(Some(row)) => row.id,
68 Ok(None) => {
69 info!("Password reset requested for unknown email");
70 return (StatusCode::OK, Json(json!({}))).into_response();
71 }
72 Err(e) => {
73 error!("DB error in request_password_reset: {:?}", e);
74 return (
75 StatusCode::INTERNAL_SERVER_ERROR,
76 Json(json!({"error": "InternalError"})),
77 )
78 .into_response();
79 }
80 };
81 let code = generate_reset_code();
82 let expires_at = Utc::now() + Duration::minutes(10);
83 let update = sqlx::query!(
84 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3",
85 code,
86 expires_at,
87 user_id
88 )
89 .execute(&state.db)
90 .await;
91 if let Err(e) = update {
92 error!("DB error setting reset code: {:?}", e);
93 return (
94 StatusCode::INTERNAL_SERVER_ERROR,
95 Json(json!({"error": "InternalError"})),
96 )
97 .into_response();
98 }
99 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
100 if let Err(e) =
101 crate::notifications::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
102 {
103 warn!("Failed to enqueue password reset notification: {:?}", e);
104 }
105 info!("Password reset requested for user {}", user_id);
106 (StatusCode::OK, Json(json!({}))).into_response()
107}
108
109#[derive(Deserialize)]
110pub struct ResetPasswordInput {
111 pub token: String,
112 pub password: String,
113}
114
115pub async fn reset_password(
116 State(state): State<AppState>,
117 headers: HeaderMap,
118 Json(input): Json<ResetPasswordInput>,
119) -> Response {
120 let client_ip = extract_client_ip(&headers);
121 if !state.check_rate_limit(RateLimitKind::ResetPassword, &client_ip).await {
122 warn!(ip = %client_ip, "Reset password rate limit exceeded");
123 return (
124 StatusCode::TOO_MANY_REQUESTS,
125 Json(json!({
126 "error": "RateLimitExceeded",
127 "message": "Too many requests. Please try again later."
128 })),
129 ).into_response();
130 }
131 let token = input.token.trim();
132 let password = &input.password;
133 if token.is_empty() {
134 return (
135 StatusCode::BAD_REQUEST,
136 Json(json!({"error": "InvalidToken", "message": "token is required"})),
137 )
138 .into_response();
139 }
140 if password.is_empty() {
141 return (
142 StatusCode::BAD_REQUEST,
143 Json(json!({"error": "InvalidRequest", "message": "password is required"})),
144 )
145 .into_response();
146 }
147 let user = sqlx::query!(
148 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1",
149 token
150 )
151 .fetch_optional(&state.db)
152 .await;
153 let (user_id, expires_at) = match user {
154 Ok(Some(row)) => {
155 let expires = row.password_reset_code_expires_at;
156 (row.id, expires)
157 }
158 Ok(None) => {
159 return (
160 StatusCode::BAD_REQUEST,
161 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
162 )
163 .into_response();
164 }
165 Err(e) => {
166 error!("DB error in reset_password: {:?}", e);
167 return (
168 StatusCode::INTERNAL_SERVER_ERROR,
169 Json(json!({"error": "InternalError"})),
170 )
171 .into_response();
172 }
173 };
174 if let Some(exp) = expires_at {
175 if Utc::now() > exp {
176 if let Err(e) = sqlx::query!(
177 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
178 user_id
179 )
180 .execute(&state.db)
181 .await
182 {
183 error!("Failed to clear expired reset code: {:?}", e);
184 }
185 return (
186 StatusCode::BAD_REQUEST,
187 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
188 )
189 .into_response();
190 }
191 } else {
192 return (
193 StatusCode::BAD_REQUEST,
194 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
195 )
196 .into_response();
197 }
198 let password_hash = match hash(password, DEFAULT_COST) {
199 Ok(h) => h,
200 Err(e) => {
201 error!("Failed to hash password: {:?}", e);
202 return (
203 StatusCode::INTERNAL_SERVER_ERROR,
204 Json(json!({"error": "InternalError"})),
205 )
206 .into_response();
207 }
208 };
209 let mut tx = match state.db.begin().await {
210 Ok(tx) => tx,
211 Err(e) => {
212 error!("Failed to begin transaction: {:?}", e);
213 return (
214 StatusCode::INTERNAL_SERVER_ERROR,
215 Json(json!({"error": "InternalError"})),
216 )
217 .into_response();
218 }
219 };
220 if let Err(e) = sqlx::query!(
221 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
222 password_hash,
223 user_id
224 )
225 .execute(&mut *tx)
226 .await
227 {
228 error!("DB error updating password: {:?}", e);
229 return (
230 StatusCode::INTERNAL_SERVER_ERROR,
231 Json(json!({"error": "InternalError"})),
232 )
233 .into_response();
234 }
235 let user_did = match sqlx::query_scalar!(
236 "SELECT did FROM users WHERE id = $1",
237 user_id
238 )
239 .fetch_one(&mut *tx)
240 .await
241 {
242 Ok(did) => did,
243 Err(e) => {
244 error!("Failed to get DID for user {}: {:?}", user_id, e);
245 return (
246 StatusCode::INTERNAL_SERVER_ERROR,
247 Json(json!({"error": "InternalError"})),
248 )
249 .into_response();
250 }
251 };
252 let session_jtis: Vec<String> = match sqlx::query_scalar!(
253 "SELECT access_jti FROM session_tokens WHERE did = $1",
254 user_did
255 )
256 .fetch_all(&mut *tx)
257 .await
258 {
259 Ok(jtis) => jtis,
260 Err(e) => {
261 error!("Failed to fetch session JTIs: {:?}", e);
262 vec![]
263 }
264 };
265 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did)
266 .execute(&mut *tx)
267 .await
268 {
269 error!("Failed to invalidate sessions after password reset: {:?}", e);
270 return (
271 StatusCode::INTERNAL_SERVER_ERROR,
272 Json(json!({"error": "InternalError"})),
273 )
274 .into_response();
275 }
276 if let Err(e) = tx.commit().await {
277 error!("Failed to commit password reset transaction: {:?}", e);
278 return (
279 StatusCode::INTERNAL_SERVER_ERROR,
280 Json(json!({"error": "InternalError"})),
281 )
282 .into_response();
283 }
284 for jti in session_jtis {
285 let cache_key = format!("auth:session:{}:{}", user_did, jti);
286 if let Err(e) = state.cache.delete(&cache_key).await {
287 warn!("Failed to invalidate session cache for {}: {:?}", cache_key, e);
288 }
289 }
290 info!("Password reset completed for user {}", user_id);
291 (StatusCode::OK, Json(json!({}))).into_response()
292}