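//! WebAuthn passkey support for Tranquil PDS: relying-party setup, short-lived
//! ceremony-state storage, and passkey persistence in Postgres.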
1use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; 2use chrono::{Duration, Utc}; 3use sqlx::{PgPool, Row}; 4use uuid::Uuid; 5use webauthn_rs::prelude::*; 6 7pub struct WebAuthnConfig { 8 webauthn: Webauthn, 9} 10 11impl WebAuthnConfig { 12 pub fn new(hostname: &str) -> Result<Self, String> { 13 let rp_id = hostname.to_string(); 14 let rp_origin = Url::parse(&format!("https://{}", hostname)) 15 .map_err(|e| format!("Invalid origin URL: {}", e))?; 16 17 let builder = WebauthnBuilder::new(&rp_id, &rp_origin) 18 .map_err(|e| format!("Failed to create WebAuthn builder: {}", e))? 19 .rp_name("Tranquil PDS") 20 .danger_set_user_presence_only_security_keys(true); 21 22 let webauthn = builder 23 .build() 24 .map_err(|e| format!("Failed to build WebAuthn: {}", e))?; 25 26 Ok(Self { webauthn }) 27 } 28 29 pub fn start_registration( 30 &self, 31 user_id: &str, 32 username: &str, 33 display_name: &str, 34 exclude_credentials: Vec<CredentialID>, 35 ) -> Result<(CreationChallengeResponse, SecurityKeyRegistration), String> { 36 let user_unique_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, user_id.as_bytes()); 37 38 self.webauthn 39 .start_securitykey_registration( 40 user_unique_id, 41 username, 42 display_name, 43 if exclude_credentials.is_empty() { 44 None 45 } else { 46 Some(exclude_credentials) 47 }, 48 None, 49 None, 50 ) 51 .map_err(|e| format!("Failed to start registration: {}", e)) 52 } 53 54 pub fn finish_registration( 55 &self, 56 reg: &RegisterPublicKeyCredential, 57 state: &SecurityKeyRegistration, 58 ) -> Result<SecurityKey, String> { 59 self.webauthn 60 .finish_securitykey_registration(reg, state) 61 .map_err(|e| format!("Failed to finish registration: {}", e)) 62 } 63 64 pub fn start_authentication( 65 &self, 66 credentials: Vec<SecurityKey>, 67 ) -> Result<(RequestChallengeResponse, SecurityKeyAuthentication), String> { 68 self.webauthn 69 .start_securitykey_authentication(&credentials) 70 .map_err(|e| format!("Failed to start authentication: {}", e)) 71 } 
72 73 pub fn finish_authentication( 74 &self, 75 auth: &PublicKeyCredential, 76 state: &SecurityKeyAuthentication, 77 ) -> Result<AuthenticationResult, String> { 78 self.webauthn 79 .finish_securitykey_authentication(auth, state) 80 .map_err(|e| format!("Failed to finish authentication: {}", e)) 81 } 82} 83 84pub async fn save_registration_state( 85 pool: &PgPool, 86 did: &str, 87 state: &SecurityKeyRegistration, 88) -> Result<Uuid, sqlx::Error> { 89 let id = Uuid::new_v4(); 90 let state_json = serde_json::to_string(state) 91 .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize state: {}", e)))?; 92 let challenge = id.as_bytes().to_vec(); 93 let expires_at = Utc::now() + Duration::minutes(5); 94 95 sqlx::query!( 96 r#" 97 INSERT INTO webauthn_challenges (id, did, challenge, challenge_type, state_json, expires_at) 98 VALUES ($1, $2, $3, 'registration', $4, $5) 99 "#, 100 id, 101 did, 102 challenge, 103 state_json, 104 expires_at, 105 ) 106 .execute(pool) 107 .await?; 108 109 Ok(id) 110} 111 112pub async fn load_registration_state( 113 pool: &PgPool, 114 did: &str, 115) -> Result<Option<SecurityKeyRegistration>, sqlx::Error> { 116 let row = sqlx::query!( 117 r#" 118 SELECT state_json FROM webauthn_challenges 119 WHERE did = $1 AND challenge_type = 'registration' AND expires_at > NOW() 120 ORDER BY created_at DESC 121 LIMIT 1 122 "#, 123 did, 124 ) 125 .fetch_optional(pool) 126 .await?; 127 128 match row { 129 Some(r) => { 130 let state: SecurityKeyRegistration = 131 serde_json::from_str(&r.state_json).map_err(|e| { 132 sqlx::Error::Protocol(format!("Failed to deserialize state: {}", e)) 133 })?; 134 Ok(Some(state)) 135 } 136 None => Ok(None), 137 } 138} 139 140pub async fn delete_registration_state(pool: &PgPool, did: &str) -> Result<(), sqlx::Error> { 141 sqlx::query!( 142 "DELETE FROM webauthn_challenges WHERE did = $1 AND challenge_type = 'registration'", 143 did, 144 ) 145 .execute(pool) 146 .await?; 147 Ok(()) 148} 149 150pub async fn 
save_authentication_state( 151 pool: &PgPool, 152 did: &str, 153 state: &SecurityKeyAuthentication, 154) -> Result<Uuid, sqlx::Error> { 155 let id = Uuid::new_v4(); 156 let state_json = serde_json::to_string(state) 157 .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize state: {}", e)))?; 158 let challenge = id.as_bytes().to_vec(); 159 let expires_at = Utc::now() + Duration::minutes(5); 160 161 sqlx::query!( 162 r#" 163 INSERT INTO webauthn_challenges (id, did, challenge, challenge_type, state_json, expires_at) 164 VALUES ($1, $2, $3, 'authentication', $4, $5) 165 "#, 166 id, 167 did, 168 challenge, 169 state_json, 170 expires_at, 171 ) 172 .execute(pool) 173 .await?; 174 175 Ok(id) 176} 177 178pub async fn load_authentication_state( 179 pool: &PgPool, 180 did: &str, 181) -> Result<Option<SecurityKeyAuthentication>, sqlx::Error> { 182 let row = sqlx::query!( 183 r#" 184 SELECT state_json FROM webauthn_challenges 185 WHERE did = $1 AND challenge_type = 'authentication' AND expires_at > NOW() 186 ORDER BY created_at DESC 187 LIMIT 1 188 "#, 189 did, 190 ) 191 .fetch_optional(pool) 192 .await?; 193 194 match row { 195 Some(r) => { 196 let state: SecurityKeyAuthentication = 197 serde_json::from_str(&r.state_json).map_err(|e| { 198 sqlx::Error::Protocol(format!("Failed to deserialize state: {}", e)) 199 })?; 200 Ok(Some(state)) 201 } 202 None => Ok(None), 203 } 204} 205 206pub async fn delete_authentication_state(pool: &PgPool, did: &str) -> Result<(), sqlx::Error> { 207 sqlx::query!( 208 "DELETE FROM webauthn_challenges WHERE did = $1 AND challenge_type = 'authentication'", 209 did, 210 ) 211 .execute(pool) 212 .await?; 213 Ok(()) 214} 215 216pub async fn cleanup_expired_challenges(pool: &PgPool) -> Result<u64, sqlx::Error> { 217 let result = sqlx::query!("DELETE FROM webauthn_challenges WHERE expires_at < NOW()") 218 .execute(pool) 219 .await?; 220 Ok(result.rows_affected()) 221} 222 223#[derive(Debug, Clone)] 224pub struct StoredPasskey { 225 pub id: 
Uuid, 226 pub did: String, 227 pub credential_id: Vec<u8>, 228 pub public_key: Vec<u8>, 229 pub sign_count: i32, 230 pub created_at: chrono::DateTime<Utc>, 231 pub last_used: Option<chrono::DateTime<Utc>>, 232 pub friendly_name: Option<String>, 233 pub aaguid: Option<Vec<u8>>, 234 pub transports: Option<Vec<String>>, 235} 236 237impl StoredPasskey { 238 pub fn to_security_key(&self) -> Result<SecurityKey, String> { 239 serde_json::from_slice(&self.public_key) 240 .map_err(|e| format!("Failed to deserialize security key: {}", e)) 241 } 242 243 pub fn credential_id_base64(&self) -> String { 244 URL_SAFE_NO_PAD.encode(&self.credential_id) 245 } 246} 247 248pub async fn save_passkey( 249 pool: &PgPool, 250 did: &str, 251 security_key: &SecurityKey, 252 friendly_name: Option<&str>, 253) -> Result<Uuid, sqlx::Error> { 254 let id = Uuid::new_v4(); 255 let credential_id = security_key.cred_id().to_vec(); 256 let public_key = serde_json::to_vec(security_key) 257 .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize security key: {}", e)))?; 258 let aaguid: Option<Vec<u8>> = None; 259 260 sqlx::query!( 261 r#" 262 INSERT INTO passkeys (id, did, credential_id, public_key, sign_count, friendly_name, aaguid) 263 VALUES ($1, $2, $3, $4, 0, $5, $6) 264 "#, 265 id, 266 did, 267 credential_id, 268 public_key, 269 friendly_name, 270 aaguid, 271 ) 272 .execute(pool) 273 .await?; 274 275 Ok(id) 276} 277 278pub async fn get_passkeys_for_user( 279 pool: &PgPool, 280 did: &str, 281) -> Result<Vec<StoredPasskey>, sqlx::Error> { 282 let rows = sqlx::query!( 283 r#" 284 SELECT id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name, aaguid, transports 285 FROM passkeys 286 WHERE did = $1 287 ORDER BY created_at DESC 288 "#, 289 did, 290 ) 291 .fetch_all(pool) 292 .await?; 293 294 Ok(rows 295 .into_iter() 296 .map(|r| StoredPasskey { 297 id: r.id, 298 did: r.did, 299 credential_id: r.credential_id, 300 public_key: r.public_key, 301 sign_count: 
r.sign_count, 302 created_at: r.created_at, 303 last_used: r.last_used, 304 friendly_name: r.friendly_name, 305 aaguid: r.aaguid, 306 transports: r.transports, 307 }) 308 .collect()) 309} 310 311pub async fn get_passkey_by_credential_id( 312 pool: &PgPool, 313 credential_id: &[u8], 314) -> Result<Option<StoredPasskey>, sqlx::Error> { 315 let row = sqlx::query!( 316 r#" 317 SELECT id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name, aaguid, transports 318 FROM passkeys 319 WHERE credential_id = $1 320 "#, 321 credential_id, 322 ) 323 .fetch_optional(pool) 324 .await?; 325 326 Ok(row.map(|r| StoredPasskey { 327 id: r.id, 328 did: r.did, 329 credential_id: r.credential_id, 330 public_key: r.public_key, 331 sign_count: r.sign_count, 332 created_at: r.created_at, 333 last_used: r.last_used, 334 friendly_name: r.friendly_name, 335 aaguid: r.aaguid, 336 transports: r.transports, 337 })) 338} 339 340pub async fn update_passkey_counter( 341 pool: &PgPool, 342 credential_id: &[u8], 343 new_counter: u32, 344) -> Result<bool, sqlx::Error> { 345 let stored = get_passkey_by_credential_id(pool, credential_id).await?; 346 let Some(stored) = stored else { 347 return Err(sqlx::Error::RowNotFound); 348 }; 349 350 if new_counter > 0 && new_counter <= stored.sign_count as u32 { 351 tracing::warn!( 352 credential_id = ?credential_id, 353 stored_counter = stored.sign_count, 354 new_counter = new_counter, 355 "Passkey counter did not increment - possible cloned key!" 
356 ); 357 return Ok(false); 358 } 359 360 sqlx::query!( 361 "UPDATE passkeys SET sign_count = $1, last_used = NOW() WHERE credential_id = $2", 362 new_counter as i32, 363 credential_id, 364 ) 365 .execute(pool) 366 .await?; 367 Ok(true) 368} 369 370pub async fn delete_passkey(pool: &PgPool, id: Uuid, did: &str) -> Result<bool, sqlx::Error> { 371 let result = sqlx::query("DELETE FROM passkeys WHERE id = $1 AND did = $2") 372 .bind(id) 373 .bind(did) 374 .execute(pool) 375 .await?; 376 Ok(result.rows_affected() > 0) 377} 378 379pub async fn update_passkey_name( 380 pool: &PgPool, 381 id: Uuid, 382 did: &str, 383 name: &str, 384) -> Result<bool, sqlx::Error> { 385 let result = sqlx::query("UPDATE passkeys SET friendly_name = $1 WHERE id = $2 AND did = $3") 386 .bind(name) 387 .bind(id) 388 .bind(did) 389 .execute(pool) 390 .await?; 391 Ok(result.rows_affected() > 0) 392} 393 394pub async fn has_passkeys(pool: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 395 let row = sqlx::query("SELECT COUNT(*) as count FROM passkeys WHERE did = $1") 396 .bind(did) 397 .fetch_one(pool) 398 .await?; 399 let count: i64 = row.get("count"); 400 Ok(count > 0) 401}