This repository has no description.
1use reqwest::Client; 2use serde::{Deserialize, Serialize}; 3use std::collections::HashMap; 4use std::sync::Arc; 5use tokio::sync::RwLock; 6 7use super::OAuthError; 8 9#[derive(Debug, Clone, Serialize, Deserialize)] 10pub struct ClientMetadata { 11 pub client_id: String, 12 #[serde(skip_serializing_if = "Option::is_none")] 13 pub client_name: Option<String>, 14 #[serde(skip_serializing_if = "Option::is_none")] 15 pub client_uri: Option<String>, 16 #[serde(skip_serializing_if = "Option::is_none")] 17 pub logo_uri: Option<String>, 18 pub redirect_uris: Vec<String>, 19 #[serde(default)] 20 pub grant_types: Vec<String>, 21 #[serde(default)] 22 pub response_types: Vec<String>, 23 #[serde(skip_serializing_if = "Option::is_none")] 24 pub scope: Option<String>, 25 #[serde(skip_serializing_if = "Option::is_none")] 26 pub token_endpoint_auth_method: Option<String>, 27 #[serde(skip_serializing_if = "Option::is_none")] 28 pub dpop_bound_access_tokens: Option<bool>, 29 #[serde(skip_serializing_if = "Option::is_none")] 30 pub jwks: Option<serde_json::Value>, 31 #[serde(skip_serializing_if = "Option::is_none")] 32 pub jwks_uri: Option<String>, 33 #[serde(skip_serializing_if = "Option::is_none")] 34 pub application_type: Option<String>, 35} 36 37impl Default for ClientMetadata { 38 fn default() -> Self { 39 Self { 40 client_id: String::new(), 41 client_name: None, 42 client_uri: None, 43 logo_uri: None, 44 redirect_uris: Vec::new(), 45 grant_types: vec!["authorization_code".to_string()], 46 response_types: vec!["code".to_string()], 47 scope: None, 48 token_endpoint_auth_method: Some("none".to_string()), 49 dpop_bound_access_tokens: None, 50 jwks: None, 51 jwks_uri: None, 52 application_type: None, 53 } 54 } 55} 56 57#[derive(Clone)] 58pub struct ClientMetadataCache { 59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, 60 jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>, 61 http_client: Client, 62 cache_ttl_secs: u64, 63} 64 65struct CachedMetadata { 66 metadata: 
ClientMetadata, 67 cached_at: std::time::Instant, 68} 69 70struct CachedJwks { 71 jwks: serde_json::Value, 72 cached_at: std::time::Instant, 73} 74 75impl ClientMetadataCache { 76 pub fn new(cache_ttl_secs: u64) -> Self { 77 Self { 78 cache: Arc::new(RwLock::new(HashMap::new())), 79 jwks_cache: Arc::new(RwLock::new(HashMap::new())), 80 http_client: Client::builder() 81 .timeout(std::time::Duration::from_secs(30)) 82 .connect_timeout(std::time::Duration::from_secs(10)) 83 .build() 84 .unwrap_or_else(|_| Client::new()), 85 cache_ttl_secs, 86 } 87 } 88 89 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 90 { 91 let cache = self.cache.read().await; 92 if let Some(cached) = cache.get(client_id) { 93 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 94 return Ok(cached.metadata.clone()); 95 } 96 } 97 } 98 99 let metadata = self.fetch_metadata(client_id).await?; 100 101 { 102 let mut cache = self.cache.write().await; 103 cache.insert( 104 client_id.to_string(), 105 CachedMetadata { 106 metadata: metadata.clone(), 107 cached_at: std::time::Instant::now(), 108 }, 109 ); 110 } 111 112 Ok(metadata) 113 } 114 115 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> { 116 if let Some(jwks) = &metadata.jwks { 117 return Ok(jwks.clone()); 118 } 119 120 let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| { 121 OAuthError::InvalidClient( 122 "Client using private_key_jwt must have jwks or jwks_uri".to_string(), 123 ) 124 })?; 125 126 { 127 let cache = self.jwks_cache.read().await; 128 if let Some(cached) = cache.get(jwks_uri) { 129 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 130 return Ok(cached.jwks.clone()); 131 } 132 } 133 } 134 135 let jwks = self.fetch_jwks(jwks_uri).await?; 136 137 { 138 let mut cache = self.jwks_cache.write().await; 139 cache.insert( 140 jwks_uri.clone(), 141 CachedJwks { 142 jwks: jwks.clone(), 143 cached_at: std::time::Instant::now(), 144 }, 
145 ); 146 } 147 148 Ok(jwks) 149 } 150 151 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> { 152 if !jwks_uri.starts_with("https://") { 153 if !jwks_uri.starts_with("http://") 154 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1")) 155 { 156 return Err(OAuthError::InvalidClient( 157 "jwks_uri must use https (except for localhost)".to_string(), 158 )); 159 } 160 } 161 162 let response = self 163 .http_client 164 .get(jwks_uri) 165 .header("Accept", "application/json") 166 .send() 167 .await 168 .map_err(|e| { 169 OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e)) 170 })?; 171 172 if !response.status().is_success() { 173 return Err(OAuthError::InvalidClient(format!( 174 "Failed to fetch JWKS: HTTP {}", 175 response.status() 176 ))); 177 } 178 179 let jwks: serde_json::Value = response 180 .json() 181 .await 182 .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?; 183 184 if jwks.get("keys").and_then(|k| k.as_array()).is_none() { 185 return Err(OAuthError::InvalidClient( 186 "JWKS must contain a 'keys' array".to_string(), 187 )); 188 } 189 190 Ok(jwks) 191 } 192 193 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 194 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 195 return Err(OAuthError::InvalidClient( 196 "client_id must be a URL".to_string(), 197 )); 198 } 199 200 if client_id.starts_with("http://") 201 && !client_id.contains("localhost") 202 && !client_id.contains("127.0.0.1") 203 { 204 return Err(OAuthError::InvalidClient( 205 "Non-localhost client_id must use https".to_string(), 206 )); 207 } 208 209 let response = self 210 .http_client 211 .get(client_id) 212 .header("Accept", "application/json") 213 .send() 214 .await 215 .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?; 216 217 if !response.status().is_success() { 218 return 
Err(OAuthError::InvalidClient(format!( 219 "Failed to fetch client metadata: HTTP {}", 220 response.status() 221 ))); 222 } 223 224 let mut metadata: ClientMetadata = response 225 .json() 226 .await 227 .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?; 228 229 if metadata.client_id.is_empty() { 230 metadata.client_id = client_id.to_string(); 231 } else if metadata.client_id != client_id { 232 return Err(OAuthError::InvalidClient( 233 "client_id in metadata does not match request".to_string(), 234 )); 235 } 236 237 self.validate_metadata(&metadata)?; 238 239 Ok(metadata) 240 } 241 242 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> { 243 if metadata.redirect_uris.is_empty() { 244 return Err(OAuthError::InvalidClient( 245 "redirect_uris is required".to_string(), 246 )); 247 } 248 249 for uri in &metadata.redirect_uris { 250 self.validate_redirect_uri_format(uri)?; 251 } 252 253 if !metadata.grant_types.is_empty() 254 && !metadata.grant_types.contains(&"authorization_code".to_string()) 255 { 256 return Err(OAuthError::InvalidClient( 257 "authorization_code grant type is required".to_string(), 258 )); 259 } 260 261 if !metadata.response_types.is_empty() 262 && !metadata.response_types.contains(&"code".to_string()) 263 { 264 return Err(OAuthError::InvalidClient( 265 "code response type is required".to_string(), 266 )); 267 } 268 269 Ok(()) 270 } 271 272 pub fn validate_redirect_uri( 273 &self, 274 metadata: &ClientMetadata, 275 redirect_uri: &str, 276 ) -> Result<(), OAuthError> { 277 if !metadata.redirect_uris.contains(&redirect_uri.to_string()) { 278 return Err(OAuthError::InvalidRequest( 279 "redirect_uri not registered for client".to_string(), 280 )); 281 } 282 Ok(()) 283 } 284 285 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> { 286 if uri.contains('#') { 287 return Err(OAuthError::InvalidClient( 288 "redirect_uri must not contain a fragment".to_string(), 289 )); 
290 } 291 292 let parsed = reqwest::Url::parse(uri).map_err(|_| { 293 OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri)) 294 })?; 295 296 let scheme = parsed.scheme(); 297 298 if scheme == "http" { 299 let host = parsed.host_str().unwrap_or(""); 300 if host != "localhost" && host != "127.0.0.1" && host != "[::1]" { 301 return Err(OAuthError::InvalidClient( 302 "http redirect_uri only allowed for localhost".to_string(), 303 )); 304 } 305 } else if scheme == "https" { 306 } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') { 307 if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) { 308 return Err(OAuthError::InvalidClient(format!( 309 "Invalid redirect_uri scheme: {}", 310 scheme 311 ))); 312 } 313 } else { 314 return Err(OAuthError::InvalidClient(format!( 315 "Invalid redirect_uri scheme: {}", 316 scheme 317 ))); 318 } 319 320 Ok(()) 321 } 322} 323 324impl ClientMetadata { 325 pub fn requires_dpop(&self) -> bool { 326 self.dpop_bound_access_tokens.unwrap_or(false) 327 } 328 329 pub fn auth_method(&self) -> &str { 330 self.token_endpoint_auth_method 331 .as_deref() 332 .unwrap_or("none") 333 } 334} 335 336pub async fn verify_client_auth( 337 cache: &ClientMetadataCache, 338 metadata: &ClientMetadata, 339 client_auth: &super::ClientAuth, 340) -> Result<(), OAuthError> { 341 let expected_method = metadata.auth_method(); 342 343 match (expected_method, client_auth) { 344 ("none", super::ClientAuth::None) => Ok(()), 345 346 ("none", _) => Err(OAuthError::InvalidClient( 347 "Client is configured for no authentication, but credentials were provided".to_string(), 348 )), 349 350 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => { 351 verify_private_key_jwt_async(cache, metadata, client_assertion).await 352 } 353 354 ("private_key_jwt", _) => Err(OAuthError::InvalidClient( 355 "Client requires private_key_jwt authentication".to_string(), 356 )), 
357 358 ("client_secret_post", super::ClientAuth::SecretPost { .. }) => { 359 Err(OAuthError::InvalidClient( 360 "client_secret_post is not supported for ATProto OAuth".to_string(), 361 )) 362 } 363 364 ("client_secret_basic", super::ClientAuth::SecretBasic { .. }) => { 365 Err(OAuthError::InvalidClient( 366 "client_secret_basic is not supported for ATProto OAuth".to_string(), 367 )) 368 } 369 370 (method, _) => Err(OAuthError::InvalidClient(format!( 371 "Unsupported or mismatched authentication method: {}", 372 method 373 ))), 374 } 375} 376 377async fn verify_private_key_jwt_async( 378 cache: &ClientMetadataCache, 379 metadata: &ClientMetadata, 380 client_assertion: &str, 381) -> Result<(), OAuthError> { 382 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 383 384 let parts: Vec<&str> = client_assertion.split('.').collect(); 385 if parts.len() != 3 { 386 return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string())); 387 } 388 389 let header_bytes = URL_SAFE_NO_PAD 390 .decode(parts[0]) 391 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?; 392 let header: serde_json::Value = serde_json::from_slice(&header_bytes) 393 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?; 394 395 let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| { 396 OAuthError::InvalidClient("Missing alg in client_assertion".to_string()) 397 })?; 398 399 if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") { 400 return Err(OAuthError::InvalidClient(format!( 401 "Unsupported client_assertion algorithm: {}", 402 alg 403 ))); 404 } 405 406 let kid = header.get("kid").and_then(|k| k.as_str()); 407 408 let payload_bytes = URL_SAFE_NO_PAD 409 .decode(parts[1]) 410 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?; 411 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) 412 .map_err(|_| 
OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?; 413 414 let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| { 415 OAuthError::InvalidClient("Missing iss in client_assertion".to_string()) 416 })?; 417 if iss != metadata.client_id { 418 return Err(OAuthError::InvalidClient( 419 "client_assertion iss does not match client_id".to_string(), 420 )); 421 } 422 423 let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| { 424 OAuthError::InvalidClient("Missing sub in client_assertion".to_string()) 425 })?; 426 if sub != metadata.client_id { 427 return Err(OAuthError::InvalidClient( 428 "client_assertion sub does not match client_id".to_string(), 429 )); 430 } 431 432 let exp = payload.get("exp").and_then(|e| e.as_i64()).ok_or_else(|| { 433 OAuthError::InvalidClient("Missing exp in client_assertion".to_string()) 434 })?; 435 let now = chrono::Utc::now().timestamp(); 436 if exp < now { 437 return Err(OAuthError::InvalidClient("client_assertion has expired".to_string())); 438 } 439 440 let iat = payload.get("iat").and_then(|i| i.as_i64()); 441 if let Some(iat) = iat { 442 if iat > now + 60 { 443 return Err(OAuthError::InvalidClient( 444 "client_assertion iat is in the future".to_string(), 445 )); 446 } 447 } 448 449 let jwks = cache.get_jwks(metadata).await?; 450 let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| { 451 OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string()) 452 })?; 453 454 let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid { 455 keys.iter() 456 .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid)) 457 .collect() 458 } else { 459 keys.iter().collect() 460 }; 461 462 if matching_keys.is_empty() { 463 return Err(OAuthError::InvalidClient( 464 "No matching key found in client JWKS".to_string(), 465 )); 466 } 467 468 let signing_input = format!("{}.{}", parts[0], parts[1]); 469 let signature_bytes = URL_SAFE_NO_PAD 470 .decode(parts[2]) 471 
.map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?; 472 473 for key in matching_keys { 474 let key_alg = key.get("alg").and_then(|a| a.as_str()); 475 if key_alg.is_some() && key_alg != Some(alg) { 476 continue; 477 } 478 479 let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or(""); 480 481 let verified = match (alg, kty) { 482 ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes), 483 ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes), 484 ("RS256" | "RS384" | "RS512", "RSA") => { 485 verify_rsa(alg, key, &signing_input, &signature_bytes) 486 } 487 ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes), 488 _ => continue, 489 }; 490 491 if verified.is_ok() { 492 return Ok(()); 493 } 494 } 495 496 Err(OAuthError::InvalidClient( 497 "client_assertion signature verification failed".to_string(), 498 )) 499} 500 501fn verify_es256( 502 key: &serde_json::Value, 503 signing_input: &str, 504 signature: &[u8], 505) -> Result<(), OAuthError> { 506 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 507 use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 508 use p256::EncodedPoint; 509 510 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 511 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 512 })?; 513 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 514 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 515 })?; 516 517 let x_bytes = URL_SAFE_NO_PAD.decode(x) 518 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 519 let y_bytes = URL_SAFE_NO_PAD.decode(y) 520 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 521 522 let mut point_bytes = vec![0x04]; 523 point_bytes.extend_from_slice(&x_bytes); 524 point_bytes.extend_from_slice(&y_bytes); 525 526 let point = EncodedPoint::from_bytes(&point_bytes) 527 .map_err(|_| 
OAuthError::InvalidClient("Invalid EC point".to_string()))?; 528 let verifying_key = VerifyingKey::from_encoded_point(&point) 529 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?; 530 531 let sig = Signature::from_slice(signature) 532 .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?; 533 534 verifying_key 535 .verify(signing_input.as_bytes(), &sig) 536 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string())) 537} 538 539fn verify_es384( 540 key: &serde_json::Value, 541 signing_input: &str, 542 signature: &[u8], 543) -> Result<(), OAuthError> { 544 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 545 use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 546 use p384::EncodedPoint; 547 548 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 549 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 550 })?; 551 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 552 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 553 })?; 554 555 let x_bytes = URL_SAFE_NO_PAD.decode(x) 556 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 557 let y_bytes = URL_SAFE_NO_PAD.decode(y) 558 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 559 560 let mut point_bytes = vec![0x04]; 561 point_bytes.extend_from_slice(&x_bytes); 562 point_bytes.extend_from_slice(&y_bytes); 563 564 let point = EncodedPoint::from_bytes(&point_bytes) 565 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?; 566 let verifying_key = VerifyingKey::from_encoded_point(&point) 567 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?; 568 569 let sig = Signature::from_slice(signature) 570 .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?; 571 572 verifying_key 573 
.verify(signing_input.as_bytes(), &sig) 574 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string())) 575} 576 577fn verify_rsa( 578 _alg: &str, 579 _key: &serde_json::Value, 580 _signing_input: &str, 581 _signature: &[u8], 582) -> Result<(), OAuthError> { 583 Err(OAuthError::InvalidClient( 584 "RSA signature verification not yet supported - use EC keys".to_string(), 585 )) 586} 587 588fn verify_eddsa( 589 key: &serde_json::Value, 590 signing_input: &str, 591 signature: &[u8], 592) -> Result<(), OAuthError> { 593 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 594 use ed25519_dalek::{Signature, Verifier, VerifyingKey}; 595 596 let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or(""); 597 if crv != "Ed25519" { 598 return Err(OAuthError::InvalidClient(format!( 599 "Unsupported EdDSA curve: {}", 600 crv 601 ))); 602 } 603 604 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 605 OAuthError::InvalidClient("Missing x in OKP key".to_string()) 606 })?; 607 608 let x_bytes = URL_SAFE_NO_PAD.decode(x) 609 .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?; 610 611 let key_bytes: [u8; 32] = x_bytes.try_into() 612 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?; 613 614 let verifying_key = VerifyingKey::from_bytes(&key_bytes) 615 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?; 616 617 let sig_bytes: [u8; 64] = signature.try_into() 618 .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?; 619 620 let sig = Signature::from_bytes(&sig_bytes); 621 622 verifying_key 623 .verify(signing_input.as_bytes(), &sig) 624 .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string())) 625}