this repo has no description
1use crate::auth::BearerAuth;
2use crate::state::{AppState, RateLimitKind};
3use axum::{
4 Json,
5 extract::State,
6 http::{HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8};
9use bcrypt::{DEFAULT_COST, hash, verify};
10use chrono::{Duration, Utc};
11use serde::Deserialize;
12use serde_json::json;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16fn generate_reset_code() -> String {
17 crate::util::generate_token_code()
18}
19fn extract_client_ip(headers: &HeaderMap) -> String {
20 if let Some(forwarded) = headers.get("x-forwarded-for")
21 && let Ok(value) = forwarded.to_str()
22 && let Some(first_ip) = value.split(',').next()
23 {
24 return first_ip.trim().to_string();
25 }
26 if let Some(real_ip) = headers.get("x-real-ip")
27 && let Ok(value) = real_ip.to_str()
28 {
29 return value.trim().to_string();
30 }
31 "unknown".to_string()
32}
33
34#[derive(Deserialize)]
35pub struct RequestPasswordResetInput {
36 #[serde(alias = "identifier")]
37 pub email: String,
38}
39
40pub async fn request_password_reset(
41 State(state): State<AppState>,
42 headers: HeaderMap,
43 Json(input): Json<RequestPasswordResetInput>,
44) -> Response {
45 let client_ip = extract_client_ip(&headers);
46 if !state
47 .check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
48 .await
49 {
50 warn!(ip = %client_ip, "Password reset rate limit exceeded");
51 return (
52 StatusCode::TOO_MANY_REQUESTS,
53 Json(json!({
54 "error": "RateLimitExceeded",
55 "message": "Too many password reset requests. Please try again later."
56 })),
57 )
58 .into_response();
59 }
60 let identifier = input.email.trim();
61 if identifier.is_empty() {
62 return (
63 StatusCode::BAD_REQUEST,
64 Json(json!({"error": "InvalidRequest", "message": "email or handle is required"})),
65 )
66 .into_response();
67 }
68 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
69 let normalized = identifier.to_lowercase();
70 let normalized = normalized.strip_prefix('@').unwrap_or(&normalized);
71 let normalized_handle = if normalized.contains('@') || normalized.contains('.') {
72 normalized.to_string()
73 } else {
74 format!("{}.{}", normalized, pds_hostname)
75 };
76 let user = sqlx::query!(
77 "SELECT id FROM users WHERE LOWER(email) = $1 OR handle = $2",
78 normalized,
79 normalized_handle
80 )
81 .fetch_optional(&state.db)
82 .await;
83 let user_id = match user {
84 Ok(Some(row)) => row.id,
85 Ok(None) => {
86 info!("Password reset requested for unknown identifier");
87 return (StatusCode::OK, Json(json!({}))).into_response();
88 }
89 Err(e) => {
90 error!("DB error in request_password_reset: {:?}", e);
91 return (
92 StatusCode::INTERNAL_SERVER_ERROR,
93 Json(json!({"error": "InternalError"})),
94 )
95 .into_response();
96 }
97 };
98 let code = generate_reset_code();
99 let expires_at = Utc::now() + Duration::minutes(10);
100 let update = sqlx::query!(
101 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3",
102 code,
103 expires_at,
104 user_id
105 )
106 .execute(&state.db)
107 .await;
108 if let Err(e) = update {
109 error!("DB error setting reset code: {:?}", e);
110 return (
111 StatusCode::INTERNAL_SERVER_ERROR,
112 Json(json!({"error": "InternalError"})),
113 )
114 .into_response();
115 }
116 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
117 if let Err(e) = crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
118 {
119 warn!("Failed to enqueue password reset notification: {:?}", e);
120 }
121 info!("Password reset requested for user {}", user_id);
122 (StatusCode::OK, Json(json!({}))).into_response()
123}
124
125#[derive(Deserialize)]
126pub struct ResetPasswordInput {
127 pub token: String,
128 pub password: String,
129}
130
131pub async fn reset_password(
132 State(state): State<AppState>,
133 headers: HeaderMap,
134 Json(input): Json<ResetPasswordInput>,
135) -> Response {
136 let client_ip = extract_client_ip(&headers);
137 if !state
138 .check_rate_limit(RateLimitKind::ResetPassword, &client_ip)
139 .await
140 {
141 warn!(ip = %client_ip, "Reset password rate limit exceeded");
142 return (
143 StatusCode::TOO_MANY_REQUESTS,
144 Json(json!({
145 "error": "RateLimitExceeded",
146 "message": "Too many requests. Please try again later."
147 })),
148 )
149 .into_response();
150 }
151 let token = input.token.trim();
152 let password = &input.password;
153 if token.is_empty() {
154 return (
155 StatusCode::BAD_REQUEST,
156 Json(json!({"error": "InvalidToken", "message": "token is required"})),
157 )
158 .into_response();
159 }
160 if password.is_empty() {
161 return (
162 StatusCode::BAD_REQUEST,
163 Json(json!({"error": "InvalidRequest", "message": "password is required"})),
164 )
165 .into_response();
166 }
167 let user = sqlx::query!(
168 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1",
169 token
170 )
171 .fetch_optional(&state.db)
172 .await;
173 let (user_id, expires_at) = match user {
174 Ok(Some(row)) => {
175 let expires = row.password_reset_code_expires_at;
176 (row.id, expires)
177 }
178 Ok(None) => {
179 return (
180 StatusCode::BAD_REQUEST,
181 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
182 )
183 .into_response();
184 }
185 Err(e) => {
186 error!("DB error in reset_password: {:?}", e);
187 return (
188 StatusCode::INTERNAL_SERVER_ERROR,
189 Json(json!({"error": "InternalError"})),
190 )
191 .into_response();
192 }
193 };
194 if let Some(exp) = expires_at {
195 if Utc::now() > exp {
196 if let Err(e) = sqlx::query!(
197 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
198 user_id
199 )
200 .execute(&state.db)
201 .await
202 {
203 error!("Failed to clear expired reset code: {:?}", e);
204 }
205 return (
206 StatusCode::BAD_REQUEST,
207 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
208 )
209 .into_response();
210 }
211 } else {
212 return (
213 StatusCode::BAD_REQUEST,
214 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
215 )
216 .into_response();
217 }
218 let password_hash = match hash(password, DEFAULT_COST) {
219 Ok(h) => h,
220 Err(e) => {
221 error!("Failed to hash password: {:?}", e);
222 return (
223 StatusCode::INTERNAL_SERVER_ERROR,
224 Json(json!({"error": "InternalError"})),
225 )
226 .into_response();
227 }
228 };
229 let mut tx = match state.db.begin().await {
230 Ok(tx) => tx,
231 Err(e) => {
232 error!("Failed to begin transaction: {:?}", e);
233 return (
234 StatusCode::INTERNAL_SERVER_ERROR,
235 Json(json!({"error": "InternalError"})),
236 )
237 .into_response();
238 }
239 };
240 if let Err(e) = sqlx::query!(
241 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL, password_required = TRUE WHERE id = $2",
242 password_hash,
243 user_id
244 )
245 .execute(&mut *tx)
246 .await
247 {
248 error!("DB error updating password: {:?}", e);
249 return (
250 StatusCode::INTERNAL_SERVER_ERROR,
251 Json(json!({"error": "InternalError"})),
252 )
253 .into_response();
254 }
255 let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id)
256 .fetch_one(&mut *tx)
257 .await
258 {
259 Ok(did) => did,
260 Err(e) => {
261 error!("Failed to get DID for user {}: {:?}", user_id, e);
262 return (
263 StatusCode::INTERNAL_SERVER_ERROR,
264 Json(json!({"error": "InternalError"})),
265 )
266 .into_response();
267 }
268 };
269 let session_jtis: Vec<String> = match sqlx::query_scalar!(
270 "SELECT access_jti FROM session_tokens WHERE did = $1",
271 user_did
272 )
273 .fetch_all(&mut *tx)
274 .await
275 {
276 Ok(jtis) => jtis,
277 Err(e) => {
278 error!("Failed to fetch session JTIs: {:?}", e);
279 vec![]
280 }
281 };
282 if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did)
283 .execute(&mut *tx)
284 .await
285 {
286 error!(
287 "Failed to invalidate sessions after password reset: {:?}",
288 e
289 );
290 return (
291 StatusCode::INTERNAL_SERVER_ERROR,
292 Json(json!({"error": "InternalError"})),
293 )
294 .into_response();
295 }
296 if let Err(e) = tx.commit().await {
297 error!("Failed to commit password reset transaction: {:?}", e);
298 return (
299 StatusCode::INTERNAL_SERVER_ERROR,
300 Json(json!({"error": "InternalError"})),
301 )
302 .into_response();
303 }
304 for jti in session_jtis {
305 let cache_key = format!("auth:session:{}:{}", user_did, jti);
306 if let Err(e) = state.cache.delete(&cache_key).await {
307 warn!(
308 "Failed to invalidate session cache for {}: {:?}",
309 cache_key, e
310 );
311 }
312 }
313 info!("Password reset completed for user {}", user_id);
314 (StatusCode::OK, Json(json!({}))).into_response()
315}
316
317#[derive(Deserialize)]
318#[serde(rename_all = "camelCase")]
319pub struct ChangePasswordInput {
320 pub current_password: String,
321 pub new_password: String,
322}
323
324pub async fn change_password(
325 State(state): State<AppState>,
326 auth: BearerAuth,
327 Json(input): Json<ChangePasswordInput>,
328) -> Response {
329 let current_password = &input.current_password;
330 let new_password = &input.new_password;
331 if current_password.is_empty() {
332 return (
333 StatusCode::BAD_REQUEST,
334 Json(json!({"error": "InvalidRequest", "message": "currentPassword is required"})),
335 )
336 .into_response();
337 }
338 if new_password.is_empty() {
339 return (
340 StatusCode::BAD_REQUEST,
341 Json(json!({"error": "InvalidRequest", "message": "newPassword is required"})),
342 )
343 .into_response();
344 }
345 if new_password.len() < 8 {
346 return (
347 StatusCode::BAD_REQUEST,
348 Json(json!({"error": "InvalidRequest", "message": "Password must be at least 8 characters"})),
349 )
350 .into_response();
351 }
352 let user =
353 sqlx::query_as::<_, (Uuid, String)>("SELECT id, password_hash FROM users WHERE did = $1")
354 .bind(&auth.0.did)
355 .fetch_optional(&state.db)
356 .await;
357 let (user_id, password_hash) = match user {
358 Ok(Some(row)) => row,
359 Ok(None) => {
360 return (
361 StatusCode::NOT_FOUND,
362 Json(json!({"error": "AccountNotFound", "message": "Account not found"})),
363 )
364 .into_response();
365 }
366 Err(e) => {
367 error!("DB error in change_password: {:?}", e);
368 return (
369 StatusCode::INTERNAL_SERVER_ERROR,
370 Json(json!({"error": "InternalError"})),
371 )
372 .into_response();
373 }
374 };
375 let valid = match verify(current_password, &password_hash) {
376 Ok(v) => v,
377 Err(e) => {
378 error!("Password verification error: {:?}", e);
379 return (
380 StatusCode::INTERNAL_SERVER_ERROR,
381 Json(json!({"error": "InternalError"})),
382 )
383 .into_response();
384 }
385 };
386 if !valid {
387 return (
388 StatusCode::UNAUTHORIZED,
389 Json(json!({"error": "InvalidPassword", "message": "Current password is incorrect"})),
390 )
391 .into_response();
392 }
393 let new_hash = match hash(new_password, DEFAULT_COST) {
394 Ok(h) => h,
395 Err(e) => {
396 error!("Failed to hash password: {:?}", e);
397 return (
398 StatusCode::INTERNAL_SERVER_ERROR,
399 Json(json!({"error": "InternalError"})),
400 )
401 .into_response();
402 }
403 };
404 if let Err(e) = sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2")
405 .bind(&new_hash)
406 .bind(user_id)
407 .execute(&state.db)
408 .await
409 {
410 error!("DB error updating password: {:?}", e);
411 return (
412 StatusCode::INTERNAL_SERVER_ERROR,
413 Json(json!({"error": "InternalError"})),
414 )
415 .into_response();
416 }
417 info!(did = %auth.0.did, "Password changed successfully");
418 (StatusCode::OK, Json(json!({}))).into_response()
419}
420
421pub async fn get_password_status(State(state): State<AppState>, auth: BearerAuth) -> Response {
422 let user = sqlx::query!(
423 "SELECT password_hash IS NOT NULL as has_password FROM users WHERE did = $1",
424 auth.0.did
425 )
426 .fetch_optional(&state.db)
427 .await;
428
429 match user {
430 Ok(Some(row)) => {
431 Json(json!({"hasPassword": row.has_password.unwrap_or(false)})).into_response()
432 }
433 Ok(None) => (
434 StatusCode::NOT_FOUND,
435 Json(json!({"error": "AccountNotFound"})),
436 )
437 .into_response(),
438 Err(e) => {
439 error!("DB error: {:?}", e);
440 (
441 StatusCode::INTERNAL_SERVER_ERROR,
442 Json(json!({"error": "InternalError"})),
443 )
444 .into_response()
445 }
446 }
447}
448
449pub async fn remove_password(State(state): State<AppState>, auth: BearerAuth) -> Response {
450 if crate::api::server::reauth::check_reauth_required(&state.db, &auth.0.did).await {
451 return crate::api::server::reauth::reauth_required_response(&state.db, &auth.0.did).await;
452 }
453
454 let has_passkeys =
455 crate::api::server::passkeys::has_passkeys_for_user_db(&state.db, &auth.0.did).await;
456 if !has_passkeys {
457 return (
458 StatusCode::BAD_REQUEST,
459 Json(json!({
460 "error": "NoPasskeys",
461 "message": "You must have at least one passkey registered before removing your password"
462 })),
463 )
464 .into_response();
465 }
466
467 let user = sqlx::query!(
468 "SELECT id, password_hash FROM users WHERE did = $1",
469 auth.0.did
470 )
471 .fetch_optional(&state.db)
472 .await;
473
474 let user = match user {
475 Ok(Some(u)) => u,
476 Ok(None) => {
477 return (
478 StatusCode::NOT_FOUND,
479 Json(json!({"error": "AccountNotFound"})),
480 )
481 .into_response();
482 }
483 Err(e) => {
484 error!("DB error: {:?}", e);
485 return (
486 StatusCode::INTERNAL_SERVER_ERROR,
487 Json(json!({"error": "InternalError"})),
488 )
489 .into_response();
490 }
491 };
492
493 if user.password_hash.is_none() {
494 return (
495 StatusCode::BAD_REQUEST,
496 Json(json!({
497 "error": "NoPassword",
498 "message": "Account already has no password"
499 })),
500 )
501 .into_response();
502 }
503
504 if let Err(e) = sqlx::query!(
505 "UPDATE users SET password_hash = NULL, password_required = FALSE WHERE id = $1",
506 user.id
507 )
508 .execute(&state.db)
509 .await
510 {
511 error!("DB error removing password: {:?}", e);
512 return (
513 StatusCode::INTERNAL_SERVER_ERROR,
514 Json(json!({"error": "InternalError"})),
515 )
516 .into_response();
517 }
518
519 info!(did = %auth.0.did, "Password removed - account is now passkey-only");
520 (StatusCode::OK, Json(json!({"success": true}))).into_response()
521}