this repo has no description
1use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2use chrono::{Duration, Utc};
3use sqlx::{PgPool, Row};
4use uuid::Uuid;
5use webauthn_rs::prelude::*;
6
7pub struct WebAuthnConfig {
8 webauthn: Webauthn,
9}
10
11impl WebAuthnConfig {
12 pub fn new(hostname: &str) -> Result<Self, String> {
13 let rp_id = hostname.to_string();
14 let rp_origin = Url::parse(&format!("https://{}", hostname))
15 .map_err(|e| format!("Invalid origin URL: {}", e))?;
16
17 let builder = WebauthnBuilder::new(&rp_id, &rp_origin)
18 .map_err(|e| format!("Failed to create WebAuthn builder: {}", e))?
19 .rp_name("Tranquil PDS")
20 .danger_set_user_presence_only_security_keys(true);
21
22 let webauthn = builder
23 .build()
24 .map_err(|e| format!("Failed to build WebAuthn: {}", e))?;
25
26 Ok(Self { webauthn })
27 }
28
29 pub fn start_registration(
30 &self,
31 user_id: &str,
32 username: &str,
33 display_name: &str,
34 exclude_credentials: Vec<CredentialID>,
35 ) -> Result<(CreationChallengeResponse, SecurityKeyRegistration), String> {
36 let user_unique_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, user_id.as_bytes());
37
38 self.webauthn
39 .start_securitykey_registration(
40 user_unique_id,
41 username,
42 display_name,
43 if exclude_credentials.is_empty() {
44 None
45 } else {
46 Some(exclude_credentials)
47 },
48 None,
49 None,
50 )
51 .map_err(|e| format!("Failed to start registration: {}", e))
52 }
53
54 pub fn finish_registration(
55 &self,
56 reg: &RegisterPublicKeyCredential,
57 state: &SecurityKeyRegistration,
58 ) -> Result<SecurityKey, String> {
59 self.webauthn
60 .finish_securitykey_registration(reg, state)
61 .map_err(|e| format!("Failed to finish registration: {}", e))
62 }
63
64 pub fn start_authentication(
65 &self,
66 credentials: Vec<SecurityKey>,
67 ) -> Result<(RequestChallengeResponse, SecurityKeyAuthentication), String> {
68 self.webauthn
69 .start_securitykey_authentication(&credentials)
70 .map_err(|e| format!("Failed to start authentication: {}", e))
71 }
72
73 pub fn finish_authentication(
74 &self,
75 auth: &PublicKeyCredential,
76 state: &SecurityKeyAuthentication,
77 ) -> Result<AuthenticationResult, String> {
78 self.webauthn
79 .finish_securitykey_authentication(auth, state)
80 .map_err(|e| format!("Failed to finish authentication: {}", e))
81 }
82}
83
84pub async fn save_registration_state(
85 pool: &PgPool,
86 did: &str,
87 state: &SecurityKeyRegistration,
88) -> Result<Uuid, sqlx::Error> {
89 let id = Uuid::new_v4();
90 let state_json = serde_json::to_string(state)
91 .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize state: {}", e)))?;
92 let challenge = id.as_bytes().to_vec();
93 let expires_at = Utc::now() + Duration::minutes(5);
94
95 sqlx::query!(
96 r#"
97 INSERT INTO webauthn_challenges (id, did, challenge, challenge_type, state_json, expires_at)
98 VALUES ($1, $2, $3, 'registration', $4, $5)
99 "#,
100 id,
101 did,
102 challenge,
103 state_json,
104 expires_at,
105 )
106 .execute(pool)
107 .await?;
108
109 Ok(id)
110}
111
112pub async fn load_registration_state(
113 pool: &PgPool,
114 did: &str,
115) -> Result<Option<SecurityKeyRegistration>, sqlx::Error> {
116 let row = sqlx::query!(
117 r#"
118 SELECT state_json FROM webauthn_challenges
119 WHERE did = $1 AND challenge_type = 'registration' AND expires_at > NOW()
120 ORDER BY created_at DESC
121 LIMIT 1
122 "#,
123 did,
124 )
125 .fetch_optional(pool)
126 .await?;
127
128 match row {
129 Some(r) => {
130 let state: SecurityKeyRegistration =
131 serde_json::from_str(&r.state_json).map_err(|e| {
132 sqlx::Error::Protocol(format!("Failed to deserialize state: {}", e))
133 })?;
134 Ok(Some(state))
135 }
136 None => Ok(None),
137 }
138}
139
140pub async fn delete_registration_state(pool: &PgPool, did: &str) -> Result<(), sqlx::Error> {
141 sqlx::query!(
142 "DELETE FROM webauthn_challenges WHERE did = $1 AND challenge_type = 'registration'",
143 did,
144 )
145 .execute(pool)
146 .await?;
147 Ok(())
148}
149
150pub async fn save_authentication_state(
151 pool: &PgPool,
152 did: &str,
153 state: &SecurityKeyAuthentication,
154) -> Result<Uuid, sqlx::Error> {
155 let id = Uuid::new_v4();
156 let state_json = serde_json::to_string(state)
157 .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize state: {}", e)))?;
158 let challenge = id.as_bytes().to_vec();
159 let expires_at = Utc::now() + Duration::minutes(5);
160
161 sqlx::query!(
162 r#"
163 INSERT INTO webauthn_challenges (id, did, challenge, challenge_type, state_json, expires_at)
164 VALUES ($1, $2, $3, 'authentication', $4, $5)
165 "#,
166 id,
167 did,
168 challenge,
169 state_json,
170 expires_at,
171 )
172 .execute(pool)
173 .await?;
174
175 Ok(id)
176}
177
178pub async fn load_authentication_state(
179 pool: &PgPool,
180 did: &str,
181) -> Result<Option<SecurityKeyAuthentication>, sqlx::Error> {
182 let row = sqlx::query!(
183 r#"
184 SELECT state_json FROM webauthn_challenges
185 WHERE did = $1 AND challenge_type = 'authentication' AND expires_at > NOW()
186 ORDER BY created_at DESC
187 LIMIT 1
188 "#,
189 did,
190 )
191 .fetch_optional(pool)
192 .await?;
193
194 match row {
195 Some(r) => {
196 let state: SecurityKeyAuthentication =
197 serde_json::from_str(&r.state_json).map_err(|e| {
198 sqlx::Error::Protocol(format!("Failed to deserialize state: {}", e))
199 })?;
200 Ok(Some(state))
201 }
202 None => Ok(None),
203 }
204}
205
206pub async fn delete_authentication_state(pool: &PgPool, did: &str) -> Result<(), sqlx::Error> {
207 sqlx::query!(
208 "DELETE FROM webauthn_challenges WHERE did = $1 AND challenge_type = 'authentication'",
209 did,
210 )
211 .execute(pool)
212 .await?;
213 Ok(())
214}
215
216pub async fn cleanup_expired_challenges(pool: &PgPool) -> Result<u64, sqlx::Error> {
217 let result = sqlx::query!("DELETE FROM webauthn_challenges WHERE expires_at < NOW()")
218 .execute(pool)
219 .await?;
220 Ok(result.rows_affected())
221}
222
223#[derive(Debug, Clone)]
224pub struct StoredPasskey {
225 pub id: Uuid,
226 pub did: String,
227 pub credential_id: Vec<u8>,
228 pub public_key: Vec<u8>,
229 pub sign_count: i32,
230 pub created_at: chrono::DateTime<Utc>,
231 pub last_used: Option<chrono::DateTime<Utc>>,
232 pub friendly_name: Option<String>,
233 pub aaguid: Option<Vec<u8>>,
234 pub transports: Option<Vec<String>>,
235}
236
237impl StoredPasskey {
238 pub fn to_security_key(&self) -> Result<SecurityKey, String> {
239 serde_json::from_slice(&self.public_key)
240 .map_err(|e| format!("Failed to deserialize security key: {}", e))
241 }
242
243 pub fn credential_id_base64(&self) -> String {
244 URL_SAFE_NO_PAD.encode(&self.credential_id)
245 }
246}
247
248pub async fn save_passkey(
249 pool: &PgPool,
250 did: &str,
251 security_key: &SecurityKey,
252 friendly_name: Option<&str>,
253) -> Result<Uuid, sqlx::Error> {
254 let id = Uuid::new_v4();
255 let credential_id = security_key.cred_id().to_vec();
256 let public_key = serde_json::to_vec(security_key)
257 .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize security key: {}", e)))?;
258 let aaguid: Option<Vec<u8>> = None;
259
260 sqlx::query!(
261 r#"
262 INSERT INTO passkeys (id, did, credential_id, public_key, sign_count, friendly_name, aaguid)
263 VALUES ($1, $2, $3, $4, 0, $5, $6)
264 "#,
265 id,
266 did,
267 credential_id,
268 public_key,
269 friendly_name,
270 aaguid,
271 )
272 .execute(pool)
273 .await?;
274
275 Ok(id)
276}
277
278pub async fn get_passkeys_for_user(
279 pool: &PgPool,
280 did: &str,
281) -> Result<Vec<StoredPasskey>, sqlx::Error> {
282 let rows = sqlx::query!(
283 r#"
284 SELECT id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name, aaguid, transports
285 FROM passkeys
286 WHERE did = $1
287 ORDER BY created_at DESC
288 "#,
289 did,
290 )
291 .fetch_all(pool)
292 .await?;
293
294 Ok(rows
295 .into_iter()
296 .map(|r| StoredPasskey {
297 id: r.id,
298 did: r.did,
299 credential_id: r.credential_id,
300 public_key: r.public_key,
301 sign_count: r.sign_count,
302 created_at: r.created_at,
303 last_used: r.last_used,
304 friendly_name: r.friendly_name,
305 aaguid: r.aaguid,
306 transports: r.transports,
307 })
308 .collect())
309}
310
311pub async fn get_passkey_by_credential_id(
312 pool: &PgPool,
313 credential_id: &[u8],
314) -> Result<Option<StoredPasskey>, sqlx::Error> {
315 let row = sqlx::query!(
316 r#"
317 SELECT id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name, aaguid, transports
318 FROM passkeys
319 WHERE credential_id = $1
320 "#,
321 credential_id,
322 )
323 .fetch_optional(pool)
324 .await?;
325
326 Ok(row.map(|r| StoredPasskey {
327 id: r.id,
328 did: r.did,
329 credential_id: r.credential_id,
330 public_key: r.public_key,
331 sign_count: r.sign_count,
332 created_at: r.created_at,
333 last_used: r.last_used,
334 friendly_name: r.friendly_name,
335 aaguid: r.aaguid,
336 transports: r.transports,
337 }))
338}
339
340pub async fn update_passkey_counter(
341 pool: &PgPool,
342 credential_id: &[u8],
343 new_counter: u32,
344) -> Result<bool, sqlx::Error> {
345 let stored = get_passkey_by_credential_id(pool, credential_id).await?;
346 let Some(stored) = stored else {
347 return Err(sqlx::Error::RowNotFound);
348 };
349
350 if new_counter > 0 && new_counter <= stored.sign_count as u32 {
351 tracing::warn!(
352 credential_id = ?credential_id,
353 stored_counter = stored.sign_count,
354 new_counter = new_counter,
355 "Passkey counter did not increment - possible cloned key!"
356 );
357 return Ok(false);
358 }
359
360 sqlx::query!(
361 "UPDATE passkeys SET sign_count = $1, last_used = NOW() WHERE credential_id = $2",
362 new_counter as i32,
363 credential_id,
364 )
365 .execute(pool)
366 .await?;
367 Ok(true)
368}
369
370pub async fn delete_passkey(pool: &PgPool, id: Uuid, did: &str) -> Result<bool, sqlx::Error> {
371 let result = sqlx::query("DELETE FROM passkeys WHERE id = $1 AND did = $2")
372 .bind(id)
373 .bind(did)
374 .execute(pool)
375 .await?;
376 Ok(result.rows_affected() > 0)
377}
378
379pub async fn update_passkey_name(
380 pool: &PgPool,
381 id: Uuid,
382 did: &str,
383 name: &str,
384) -> Result<bool, sqlx::Error> {
385 let result = sqlx::query("UPDATE passkeys SET friendly_name = $1 WHERE id = $2 AND did = $3")
386 .bind(name)
387 .bind(id)
388 .bind(did)
389 .execute(pool)
390 .await?;
391 Ok(result.rows_affected() > 0)
392}
393
394pub async fn has_passkeys(pool: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
395 let row = sqlx::query("SELECT COUNT(*) as count FROM passkeys WHERE did = $1")
396 .bind(did)
397 .fetch_one(pool)
398 .await?;
399 let count: i64 = row.get("count");
400 Ok(count > 0)
401}