This repository has no description.
1use reqwest::Client; 2use serde::{Deserialize, Serialize}; 3use std::collections::HashMap; 4use std::sync::Arc; 5use tokio::sync::RwLock; 6 7use super::OAuthError; 8 9#[derive(Debug, Clone, Serialize, Deserialize)] 10pub struct ClientMetadata { 11 pub client_id: String, 12 #[serde(skip_serializing_if = "Option::is_none")] 13 pub client_name: Option<String>, 14 #[serde(skip_serializing_if = "Option::is_none")] 15 pub client_uri: Option<String>, 16 #[serde(skip_serializing_if = "Option::is_none")] 17 pub logo_uri: Option<String>, 18 pub redirect_uris: Vec<String>, 19 #[serde(default)] 20 pub grant_types: Vec<String>, 21 #[serde(default)] 22 pub response_types: Vec<String>, 23 #[serde(skip_serializing_if = "Option::is_none")] 24 pub scope: Option<String>, 25 #[serde(skip_serializing_if = "Option::is_none")] 26 pub token_endpoint_auth_method: Option<String>, 27 #[serde(skip_serializing_if = "Option::is_none")] 28 pub dpop_bound_access_tokens: Option<bool>, 29 #[serde(skip_serializing_if = "Option::is_none")] 30 pub jwks: Option<serde_json::Value>, 31 #[serde(skip_serializing_if = "Option::is_none")] 32 pub jwks_uri: Option<String>, 33 #[serde(skip_serializing_if = "Option::is_none")] 34 pub application_type: Option<String>, 35} 36 37impl Default for ClientMetadata { 38 fn default() -> Self { 39 Self { 40 client_id: String::new(), 41 client_name: None, 42 client_uri: None, 43 logo_uri: None, 44 redirect_uris: Vec::new(), 45 grant_types: vec!["authorization_code".to_string()], 46 response_types: vec!["code".to_string()], 47 scope: None, 48 token_endpoint_auth_method: Some("none".to_string()), 49 dpop_bound_access_tokens: None, 50 jwks: None, 51 jwks_uri: None, 52 application_type: None, 53 } 54 } 55} 56 57#[derive(Clone)] 58pub struct ClientMetadataCache { 59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, 60 jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>, 61 http_client: Client, 62 cache_ttl_secs: u64, 63} 64 65struct CachedMetadata { 66 metadata: 
ClientMetadata, 67 cached_at: std::time::Instant, 68} 69 70struct CachedJwks { 71 jwks: serde_json::Value, 72 cached_at: std::time::Instant, 73} 74 75impl ClientMetadataCache { 76 pub fn new(cache_ttl_secs: u64) -> Self { 77 Self { 78 cache: Arc::new(RwLock::new(HashMap::new())), 79 jwks_cache: Arc::new(RwLock::new(HashMap::new())), 80 http_client: Client::builder() 81 .timeout(std::time::Duration::from_secs(30)) 82 .connect_timeout(std::time::Duration::from_secs(10)) 83 .build() 84 .unwrap_or_else(|_| Client::new()), 85 cache_ttl_secs, 86 } 87 } 88 89 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 90 { 91 let cache = self.cache.read().await; 92 if let Some(cached) = cache.get(client_id) { 93 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 94 return Ok(cached.metadata.clone()); 95 } 96 } 97 } 98 let metadata = self.fetch_metadata(client_id).await?; 99 { 100 let mut cache = self.cache.write().await; 101 cache.insert( 102 client_id.to_string(), 103 CachedMetadata { 104 metadata: metadata.clone(), 105 cached_at: std::time::Instant::now(), 106 }, 107 ); 108 } 109 Ok(metadata) 110 } 111 112 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> { 113 if let Some(jwks) = &metadata.jwks { 114 return Ok(jwks.clone()); 115 } 116 let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| { 117 OAuthError::InvalidClient( 118 "Client using private_key_jwt must have jwks or jwks_uri".to_string(), 119 ) 120 })?; 121 { 122 let cache = self.jwks_cache.read().await; 123 if let Some(cached) = cache.get(jwks_uri) { 124 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 125 return Ok(cached.jwks.clone()); 126 } 127 } 128 } 129 let jwks = self.fetch_jwks(jwks_uri).await?; 130 { 131 let mut cache = self.jwks_cache.write().await; 132 cache.insert( 133 jwks_uri.clone(), 134 CachedJwks { 135 jwks: jwks.clone(), 136 cached_at: std::time::Instant::now(), 137 }, 138 ); 139 } 140 Ok(jwks) 141 
} 142 143 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> { 144 if !jwks_uri.starts_with("https://") { 145 if !jwks_uri.starts_with("http://") 146 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1")) 147 { 148 return Err(OAuthError::InvalidClient( 149 "jwks_uri must use https (except for localhost)".to_string(), 150 )); 151 } 152 } 153 let response = self 154 .http_client 155 .get(jwks_uri) 156 .header("Accept", "application/json") 157 .send() 158 .await 159 .map_err(|e| { 160 OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e)) 161 })?; 162 if !response.status().is_success() { 163 return Err(OAuthError::InvalidClient(format!( 164 "Failed to fetch JWKS: HTTP {}", 165 response.status() 166 ))); 167 } 168 let jwks: serde_json::Value = response 169 .json() 170 .await 171 .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?; 172 if jwks.get("keys").and_then(|k| k.as_array()).is_none() { 173 return Err(OAuthError::InvalidClient( 174 "JWKS must contain a 'keys' array".to_string(), 175 )); 176 } 177 Ok(jwks) 178 } 179 180 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 181 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 182 return Err(OAuthError::InvalidClient( 183 "client_id must be a URL".to_string(), 184 )); 185 } 186 if client_id.starts_with("http://") 187 && !client_id.contains("localhost") 188 && !client_id.contains("127.0.0.1") 189 { 190 return Err(OAuthError::InvalidClient( 191 "Non-localhost client_id must use https".to_string(), 192 )); 193 } 194 let response = self 195 .http_client 196 .get(client_id) 197 .header("Accept", "application/json") 198 .send() 199 .await 200 .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?; 201 if !response.status().is_success() { 202 return Err(OAuthError::InvalidClient(format!( 203 "Failed to fetch client 
metadata: HTTP {}", 204 response.status() 205 ))); 206 } 207 let mut metadata: ClientMetadata = response 208 .json() 209 .await 210 .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?; 211 if metadata.client_id.is_empty() { 212 metadata.client_id = client_id.to_string(); 213 } else if metadata.client_id != client_id { 214 return Err(OAuthError::InvalidClient( 215 "client_id in metadata does not match request".to_string(), 216 )); 217 } 218 self.validate_metadata(&metadata)?; 219 Ok(metadata) 220 } 221 222 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> { 223 if metadata.redirect_uris.is_empty() { 224 return Err(OAuthError::InvalidClient( 225 "redirect_uris is required".to_string(), 226 )); 227 } 228 for uri in &metadata.redirect_uris { 229 self.validate_redirect_uri_format(uri)?; 230 } 231 if !metadata.grant_types.is_empty() 232 && !metadata.grant_types.contains(&"authorization_code".to_string()) 233 { 234 return Err(OAuthError::InvalidClient( 235 "authorization_code grant type is required".to_string(), 236 )); 237 } 238 if !metadata.response_types.is_empty() 239 && !metadata.response_types.contains(&"code".to_string()) 240 { 241 return Err(OAuthError::InvalidClient( 242 "code response type is required".to_string(), 243 )); 244 } 245 Ok(()) 246 } 247 248 pub fn validate_redirect_uri( 249 &self, 250 metadata: &ClientMetadata, 251 redirect_uri: &str, 252 ) -> Result<(), OAuthError> { 253 if !metadata.redirect_uris.contains(&redirect_uri.to_string()) { 254 return Err(OAuthError::InvalidRequest( 255 "redirect_uri not registered for client".to_string(), 256 )); 257 } 258 Ok(()) 259 } 260 261 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> { 262 if uri.contains('#') { 263 return Err(OAuthError::InvalidClient( 264 "redirect_uri must not contain a fragment".to_string(), 265 )); 266 } 267 let parsed = reqwest::Url::parse(uri).map_err(|_| { 268 
OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri)) 269 })?; 270 let scheme = parsed.scheme(); 271 if scheme == "http" { 272 let host = parsed.host_str().unwrap_or(""); 273 if host != "localhost" && host != "127.0.0.1" && host != "[::1]" { 274 return Err(OAuthError::InvalidClient( 275 "http redirect_uri only allowed for localhost".to_string(), 276 )); 277 } 278 } else if scheme == "https" { 279 } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') { 280 if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) { 281 return Err(OAuthError::InvalidClient(format!( 282 "Invalid redirect_uri scheme: {}", 283 scheme 284 ))); 285 } 286 } else { 287 return Err(OAuthError::InvalidClient(format!( 288 "Invalid redirect_uri scheme: {}", 289 scheme 290 ))); 291 } 292 Ok(()) 293 } 294} 295 296impl ClientMetadata { 297 pub fn requires_dpop(&self) -> bool { 298 self.dpop_bound_access_tokens.unwrap_or(false) 299 } 300 301 pub fn auth_method(&self) -> &str { 302 self.token_endpoint_auth_method 303 .as_deref() 304 .unwrap_or("none") 305 } 306} 307 308pub async fn verify_client_auth( 309 cache: &ClientMetadataCache, 310 metadata: &ClientMetadata, 311 client_auth: &super::ClientAuth, 312) -> Result<(), OAuthError> { 313 let expected_method = metadata.auth_method(); 314 match (expected_method, client_auth) { 315 ("none", super::ClientAuth::None) => Ok(()), 316 ("none", _) => Err(OAuthError::InvalidClient( 317 "Client is configured for no authentication, but credentials were provided".to_string(), 318 )), 319 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => { 320 verify_private_key_jwt_async(cache, metadata, client_assertion).await 321 } 322 ("private_key_jwt", _) => Err(OAuthError::InvalidClient( 323 "Client requires private_key_jwt authentication".to_string(), 324 )), 325 ("client_secret_post", super::ClientAuth::SecretPost { .. 
}) => { 326 Err(OAuthError::InvalidClient( 327 "client_secret_post is not supported for ATProto OAuth".to_string(), 328 )) 329 } 330 ("client_secret_basic", super::ClientAuth::SecretBasic { .. }) => { 331 Err(OAuthError::InvalidClient( 332 "client_secret_basic is not supported for ATProto OAuth".to_string(), 333 )) 334 } 335 (method, _) => Err(OAuthError::InvalidClient(format!( 336 "Unsupported or mismatched authentication method: {}", 337 method 338 ))), 339 } 340} 341 342async fn verify_private_key_jwt_async( 343 cache: &ClientMetadataCache, 344 metadata: &ClientMetadata, 345 client_assertion: &str, 346) -> Result<(), OAuthError> { 347 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 348 let parts: Vec<&str> = client_assertion.split('.').collect(); 349 if parts.len() != 3 { 350 return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string())); 351 } 352 let header_bytes = URL_SAFE_NO_PAD 353 .decode(parts[0]) 354 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?; 355 let header: serde_json::Value = serde_json::from_slice(&header_bytes) 356 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?; 357 let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| { 358 OAuthError::InvalidClient("Missing alg in client_assertion".to_string()) 359 })?; 360 if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") { 361 return Err(OAuthError::InvalidClient(format!( 362 "Unsupported client_assertion algorithm: {}", 363 alg 364 ))); 365 } 366 let kid = header.get("kid").and_then(|k| k.as_str()); 367 let payload_bytes = URL_SAFE_NO_PAD 368 .decode(parts[1]) 369 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?; 370 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) 371 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?; 372 let iss = 
payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| { 373 OAuthError::InvalidClient("Missing iss in client_assertion".to_string()) 374 })?; 375 if iss != metadata.client_id { 376 return Err(OAuthError::InvalidClient( 377 "client_assertion iss does not match client_id".to_string(), 378 )); 379 } 380 let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| { 381 OAuthError::InvalidClient("Missing sub in client_assertion".to_string()) 382 })?; 383 if sub != metadata.client_id { 384 return Err(OAuthError::InvalidClient( 385 "client_assertion sub does not match client_id".to_string(), 386 )); 387 } 388 let exp = payload.get("exp").and_then(|e| e.as_i64()).ok_or_else(|| { 389 OAuthError::InvalidClient("Missing exp in client_assertion".to_string()) 390 })?; 391 let now = chrono::Utc::now().timestamp(); 392 if exp < now { 393 return Err(OAuthError::InvalidClient("client_assertion has expired".to_string())); 394 } 395 let iat = payload.get("iat").and_then(|i| i.as_i64()); 396 if let Some(iat) = iat { 397 if iat > now + 60 { 398 return Err(OAuthError::InvalidClient( 399 "client_assertion iat is in the future".to_string(), 400 )); 401 } 402 } 403 let jwks = cache.get_jwks(metadata).await?; 404 let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| { 405 OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string()) 406 })?; 407 let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid { 408 keys.iter() 409 .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid)) 410 .collect() 411 } else { 412 keys.iter().collect() 413 }; 414 if matching_keys.is_empty() { 415 return Err(OAuthError::InvalidClient( 416 "No matching key found in client JWKS".to_string(), 417 )); 418 } 419 let signing_input = format!("{}.{}", parts[0], parts[1]); 420 let signature_bytes = URL_SAFE_NO_PAD 421 .decode(parts[2]) 422 .map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?; 423 for key in matching_keys { 424 let 
key_alg = key.get("alg").and_then(|a| a.as_str()); 425 if key_alg.is_some() && key_alg != Some(alg) { 426 continue; 427 } 428 let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or(""); 429 let verified = match (alg, kty) { 430 ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes), 431 ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes), 432 ("RS256" | "RS384" | "RS512", "RSA") => { 433 verify_rsa(alg, key, &signing_input, &signature_bytes) 434 } 435 ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes), 436 _ => continue, 437 }; 438 if verified.is_ok() { 439 return Ok(()); 440 } 441 } 442 Err(OAuthError::InvalidClient( 443 "client_assertion signature verification failed".to_string(), 444 )) 445} 446 447fn verify_es256( 448 key: &serde_json::Value, 449 signing_input: &str, 450 signature: &[u8], 451) -> Result<(), OAuthError> { 452 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 453 use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 454 use p256::EncodedPoint; 455 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 456 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 457 })?; 458 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 459 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 460 })?; 461 let x_bytes = URL_SAFE_NO_PAD.decode(x) 462 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 463 let y_bytes = URL_SAFE_NO_PAD.decode(y) 464 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 465 let mut point_bytes = vec![0x04]; 466 point_bytes.extend_from_slice(&x_bytes); 467 point_bytes.extend_from_slice(&y_bytes); 468 let point = EncodedPoint::from_bytes(&point_bytes) 469 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?; 470 let verifying_key = VerifyingKey::from_encoded_point(&point) 471 .map_err(|_| 
OAuthError::InvalidClient("Invalid EC key".to_string()))?; 472 let sig = Signature::from_slice(signature) 473 .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?; 474 verifying_key 475 .verify(signing_input.as_bytes(), &sig) 476 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string())) 477} 478 479fn verify_es384( 480 key: &serde_json::Value, 481 signing_input: &str, 482 signature: &[u8], 483) -> Result<(), OAuthError> { 484 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 485 use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 486 use p384::EncodedPoint; 487 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 488 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 489 })?; 490 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 491 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 492 })?; 493 let x_bytes = URL_SAFE_NO_PAD.decode(x) 494 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 495 let y_bytes = URL_SAFE_NO_PAD.decode(y) 496 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 497 let mut point_bytes = vec![0x04]; 498 point_bytes.extend_from_slice(&x_bytes); 499 point_bytes.extend_from_slice(&y_bytes); 500 let point = EncodedPoint::from_bytes(&point_bytes) 501 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?; 502 let verifying_key = VerifyingKey::from_encoded_point(&point) 503 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?; 504 let sig = Signature::from_slice(signature) 505 .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?; 506 verifying_key 507 .verify(signing_input.as_bytes(), &sig) 508 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string())) 509} 510 511fn verify_rsa( 512 _alg: &str, 513 _key: 
&serde_json::Value, 514 _signing_input: &str, 515 _signature: &[u8], 516) -> Result<(), OAuthError> { 517 Err(OAuthError::InvalidClient( 518 "RSA signature verification not yet supported - use EC keys".to_string(), 519 )) 520} 521 522fn verify_eddsa( 523 key: &serde_json::Value, 524 signing_input: &str, 525 signature: &[u8], 526) -> Result<(), OAuthError> { 527 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 528 use ed25519_dalek::{Signature, Verifier, VerifyingKey}; 529 let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or(""); 530 if crv != "Ed25519" { 531 return Err(OAuthError::InvalidClient(format!( 532 "Unsupported EdDSA curve: {}", 533 crv 534 ))); 535 } 536 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 537 OAuthError::InvalidClient("Missing x in OKP key".to_string()) 538 })?; 539 let x_bytes = URL_SAFE_NO_PAD.decode(x) 540 .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?; 541 let key_bytes: [u8; 32] = x_bytes.try_into() 542 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?; 543 let verifying_key = VerifyingKey::from_bytes(&key_bytes) 544 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?; 545 let sig_bytes: [u8; 64] = signature.try_into() 546 .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?; 547 let sig = Signature::from_bytes(&sig_bytes); 548 verifying_key 549 .verify(signing_input.as_bytes(), &sig) 550 .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string())) 551}