this repo has no description
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::State,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use bcrypt::{hash, DEFAULT_COST};
9use chrono::{Duration, Utc};
10use rand::Rng;
11use serde::Deserialize;
12use serde_json::json;
13use tracing::{error, info, warn};
14
15fn generate_reset_code() -> String {
16 let mut rng = rand::thread_rng();
17 let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
18 let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
19 let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
20 format!("{}-{}", part1, part2)
21}
22
23#[derive(Deserialize)]
24pub struct RequestPasswordResetInput {
25 pub email: String,
26}
27
28pub async fn request_password_reset(
29 State(state): State<AppState>,
30 Json(input): Json<RequestPasswordResetInput>,
31) -> Response {
32 let email = input.email.trim().to_lowercase();
33 if email.is_empty() {
34 return (
35 StatusCode::BAD_REQUEST,
36 Json(json!({"error": "InvalidRequest", "message": "email is required"})),
37 )
38 .into_response();
39 }
40
41 let user = sqlx::query!("SELECT id FROM users WHERE LOWER(email) = $1", email)
42 .fetch_optional(&state.db)
43 .await;
44
45 let user_id = match user {
46 Ok(Some(row)) => row.id,
47 Ok(None) => {
48 info!("Password reset requested for unknown email: {}", email);
49 return (StatusCode::OK, Json(json!({}))).into_response();
50 }
51 Err(e) => {
52 error!("DB error in request_password_reset: {:?}", e);
53 return (
54 StatusCode::INTERNAL_SERVER_ERROR,
55 Json(json!({"error": "InternalError"})),
56 )
57 .into_response();
58 }
59 };
60
61 let code = generate_reset_code();
62 let expires_at = Utc::now() + Duration::minutes(10);
63
64 let update = sqlx::query!(
65 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3",
66 code,
67 expires_at,
68 user_id
69 )
70 .execute(&state.db)
71 .await;
72
73 if let Err(e) = update {
74 error!("DB error setting reset code: {:?}", e);
75 return (
76 StatusCode::INTERNAL_SERVER_ERROR,
77 Json(json!({"error": "InternalError"})),
78 )
79 .into_response();
80 }
81
82 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
83 if let Err(e) =
84 crate::notifications::enqueue_password_reset(&state.db, user_id, &code, &hostname).await
85 {
86 warn!("Failed to enqueue password reset notification: {:?}", e);
87 }
88
89 info!("Password reset requested for user {}", user_id);
90
91 (StatusCode::OK, Json(json!({}))).into_response()
92}
93
94#[derive(Deserialize)]
95pub struct ResetPasswordInput {
96 pub token: String,
97 pub password: String,
98}
99
100pub async fn reset_password(
101 State(state): State<AppState>,
102 Json(input): Json<ResetPasswordInput>,
103) -> Response {
104 let token = input.token.trim();
105 let password = &input.password;
106
107 if token.is_empty() {
108 return (
109 StatusCode::BAD_REQUEST,
110 Json(json!({"error": "InvalidToken", "message": "token is required"})),
111 )
112 .into_response();
113 }
114
115 if password.is_empty() {
116 return (
117 StatusCode::BAD_REQUEST,
118 Json(json!({"error": "InvalidRequest", "message": "password is required"})),
119 )
120 .into_response();
121 }
122
123 let user = sqlx::query!(
124 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1",
125 token
126 )
127 .fetch_optional(&state.db)
128 .await;
129
130 let (user_id, expires_at) = match user {
131 Ok(Some(row)) => {
132 let expires = row.password_reset_code_expires_at;
133 (row.id, expires)
134 }
135 Ok(None) => {
136 return (
137 StatusCode::BAD_REQUEST,
138 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
139 )
140 .into_response();
141 }
142 Err(e) => {
143 error!("DB error in reset_password: {:?}", e);
144 return (
145 StatusCode::INTERNAL_SERVER_ERROR,
146 Json(json!({"error": "InternalError"})),
147 )
148 .into_response();
149 }
150 };
151
152 if let Some(exp) = expires_at {
153 if Utc::now() > exp {
154 let _ = sqlx::query!(
155 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
156 user_id
157 )
158 .execute(&state.db)
159 .await;
160
161 return (
162 StatusCode::BAD_REQUEST,
163 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
164 )
165 .into_response();
166 }
167 } else {
168 return (
169 StatusCode::BAD_REQUEST,
170 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
171 )
172 .into_response();
173 }
174
175 let password_hash = match hash(password, DEFAULT_COST) {
176 Ok(h) => h,
177 Err(e) => {
178 error!("Failed to hash password: {:?}", e);
179 return (
180 StatusCode::INTERNAL_SERVER_ERROR,
181 Json(json!({"error": "InternalError"})),
182 )
183 .into_response();
184 }
185 };
186
187 let update = sqlx::query!(
188 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
189 password_hash,
190 user_id
191 )
192 .execute(&state.db)
193 .await;
194
195 if let Err(e) = update {
196 error!("DB error updating password: {:?}", e);
197 return (
198 StatusCode::INTERNAL_SERVER_ERROR,
199 Json(json!({"error": "InternalError"})),
200 )
201 .into_response();
202 }
203
204 let _ = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id)
205 .execute(&state.db)
206 .await;
207
208 info!("Password reset completed for user {}", user_id);
209
210 (StatusCode::OK, Json(json!({}))).into_response()
211}