This repository has no description.
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::State,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use bcrypt::{hash, DEFAULT_COST};
9use chrono::{Duration, Utc};
10use rand::Rng;
11use serde::Deserialize;
12use serde_json::json;
13use tracing::{error, info, warn};
14
15fn generate_reset_code() -> String {
16 let mut rng = rand::thread_rng();
17 let chars: Vec<char> = "abcdefghijklmnopqrstuvwxyz234567".chars().collect();
18 let part1: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
19 let part2: String = (0..5).map(|_| chars[rng.gen_range(0..chars.len())]).collect();
20 format!("{}-{}", part1, part2)
21}
22
23#[derive(Deserialize)]
24pub struct RequestPasswordResetInput {
25 pub email: String,
26}
27
28pub async fn request_password_reset(
29 State(state): State<AppState>,
30 Json(input): Json<RequestPasswordResetInput>,
31) -> Response {
32 let email = input.email.trim().to_lowercase();
33 if email.is_empty() {
34 return (
35 StatusCode::BAD_REQUEST,
36 Json(json!({"error": "InvalidRequest", "message": "email is required"})),
37 )
38 .into_response();
39 }
40
41 let user = sqlx::query!(
42 "SELECT id, handle FROM users WHERE LOWER(email) = $1",
43 email
44 )
45 .fetch_optional(&state.db)
46 .await;
47
48 let (user_id, handle) = match user {
49 Ok(Some(row)) => (row.id, row.handle),
50 Ok(None) => {
51 info!("Password reset requested for unknown email: {}", email);
52 return (StatusCode::OK, Json(json!({}))).into_response();
53 }
54 Err(e) => {
55 error!("DB error in request_password_reset: {:?}", e);
56 return (
57 StatusCode::INTERNAL_SERVER_ERROR,
58 Json(json!({"error": "InternalError"})),
59 )
60 .into_response();
61 }
62 };
63
64 let code = generate_reset_code();
65 let expires_at = Utc::now() + Duration::minutes(10);
66
67 let update = sqlx::query!(
68 "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3",
69 code,
70 expires_at,
71 user_id
72 )
73 .execute(&state.db)
74 .await;
75
76 if let Err(e) = update {
77 error!("DB error setting reset code: {:?}", e);
78 return (
79 StatusCode::INTERNAL_SERVER_ERROR,
80 Json(json!({"error": "InternalError"})),
81 )
82 .into_response();
83 }
84
85 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
86 if let Err(e) = crate::notifications::enqueue_password_reset(
87 &state.db,
88 user_id,
89 &email,
90 &handle,
91 &code,
92 &hostname,
93 )
94 .await
95 {
96 warn!("Failed to enqueue password reset notification: {:?}", e);
97 }
98
99 info!("Password reset requested for user {}", user_id);
100
101 (StatusCode::OK, Json(json!({}))).into_response()
102}
103
104#[derive(Deserialize)]
105pub struct ResetPasswordInput {
106 pub token: String,
107 pub password: String,
108}
109
110pub async fn reset_password(
111 State(state): State<AppState>,
112 Json(input): Json<ResetPasswordInput>,
113) -> Response {
114 let token = input.token.trim();
115 let password = &input.password;
116
117 if token.is_empty() {
118 return (
119 StatusCode::BAD_REQUEST,
120 Json(json!({"error": "InvalidToken", "message": "token is required"})),
121 )
122 .into_response();
123 }
124
125 if password.is_empty() {
126 return (
127 StatusCode::BAD_REQUEST,
128 Json(json!({"error": "InvalidRequest", "message": "password is required"})),
129 )
130 .into_response();
131 }
132
133 let user = sqlx::query!(
134 "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1",
135 token
136 )
137 .fetch_optional(&state.db)
138 .await;
139
140 let (user_id, expires_at) = match user {
141 Ok(Some(row)) => {
142 let expires = row.password_reset_code_expires_at;
143 (row.id, expires)
144 }
145 Ok(None) => {
146 return (
147 StatusCode::BAD_REQUEST,
148 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
149 )
150 .into_response();
151 }
152 Err(e) => {
153 error!("DB error in reset_password: {:?}", e);
154 return (
155 StatusCode::INTERNAL_SERVER_ERROR,
156 Json(json!({"error": "InternalError"})),
157 )
158 .into_response();
159 }
160 };
161
162 if let Some(exp) = expires_at {
163 if Utc::now() > exp {
164 let _ = sqlx::query!(
165 "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1",
166 user_id
167 )
168 .execute(&state.db)
169 .await;
170
171 return (
172 StatusCode::BAD_REQUEST,
173 Json(json!({"error": "ExpiredToken", "message": "Token has expired"})),
174 )
175 .into_response();
176 }
177 } else {
178 return (
179 StatusCode::BAD_REQUEST,
180 Json(json!({"error": "InvalidToken", "message": "Invalid or expired token"})),
181 )
182 .into_response();
183 }
184
185 let password_hash = match hash(password, DEFAULT_COST) {
186 Ok(h) => h,
187 Err(e) => {
188 error!("Failed to hash password: {:?}", e);
189 return (
190 StatusCode::INTERNAL_SERVER_ERROR,
191 Json(json!({"error": "InternalError"})),
192 )
193 .into_response();
194 }
195 };
196
197 let update = sqlx::query!(
198 "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $2",
199 password_hash,
200 user_id
201 )
202 .execute(&state.db)
203 .await;
204
205 if let Err(e) = update {
206 error!("DB error updating password: {:?}", e);
207 return (
208 StatusCode::INTERNAL_SERVER_ERROR,
209 Json(json!({"error": "InternalError"})),
210 )
211 .into_response();
212 }
213
214 let _ = sqlx::query!("DELETE FROM session_tokens WHERE did = (SELECT did FROM users WHERE id = $1)", user_id)
215 .execute(&state.db)
216 .await;
217
218 info!("Password reset completed for user {}", user_id);
219
220 (StatusCode::OK, Json(json!({}))).into_response()
221}