this repo has no description
1use reqwest::Client; 2use serde::{Deserialize, Serialize}; 3use std::collections::HashMap; 4use std::sync::Arc; 5use tokio::sync::RwLock; 6use super::OAuthError; 7#[derive(Debug, Clone, Serialize, Deserialize)] 8pub struct ClientMetadata { 9 pub client_id: String, 10 #[serde(skip_serializing_if = "Option::is_none")] 11 pub client_name: Option<String>, 12 #[serde(skip_serializing_if = "Option::is_none")] 13 pub client_uri: Option<String>, 14 #[serde(skip_serializing_if = "Option::is_none")] 15 pub logo_uri: Option<String>, 16 pub redirect_uris: Vec<String>, 17 #[serde(default)] 18 pub grant_types: Vec<String>, 19 #[serde(default)] 20 pub response_types: Vec<String>, 21 #[serde(skip_serializing_if = "Option::is_none")] 22 pub scope: Option<String>, 23 #[serde(skip_serializing_if = "Option::is_none")] 24 pub token_endpoint_auth_method: Option<String>, 25 #[serde(skip_serializing_if = "Option::is_none")] 26 pub dpop_bound_access_tokens: Option<bool>, 27 #[serde(skip_serializing_if = "Option::is_none")] 28 pub jwks: Option<serde_json::Value>, 29 #[serde(skip_serializing_if = "Option::is_none")] 30 pub jwks_uri: Option<String>, 31 #[serde(skip_serializing_if = "Option::is_none")] 32 pub application_type: Option<String>, 33} 34impl Default for ClientMetadata { 35 fn default() -> Self { 36 Self { 37 client_id: String::new(), 38 client_name: None, 39 client_uri: None, 40 logo_uri: None, 41 redirect_uris: Vec::new(), 42 grant_types: vec!["authorization_code".to_string()], 43 response_types: vec!["code".to_string()], 44 scope: None, 45 token_endpoint_auth_method: Some("none".to_string()), 46 dpop_bound_access_tokens: None, 47 jwks: None, 48 jwks_uri: None, 49 application_type: None, 50 } 51 } 52} 53#[derive(Clone)] 54pub struct ClientMetadataCache { 55 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, 56 jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>, 57 http_client: Client, 58 cache_ttl_secs: u64, 59} 60struct CachedMetadata { 61 metadata: ClientMetadata, 62 cached_at: std::time::Instant, 63} 64struct CachedJwks { 65 jwks: serde_json::Value, 66 cached_at: std::time::Instant, 67} 68impl ClientMetadataCache { 69 pub fn new(cache_ttl_secs: u64) -> Self { 70 Self { 71 cache: Arc::new(RwLock::new(HashMap::new())), 72 jwks_cache: Arc::new(RwLock::new(HashMap::new())), 73 http_client: Client::builder() 74 .timeout(std::time::Duration::from_secs(30)) 75 .connect_timeout(std::time::Duration::from_secs(10)) 76 .build() 77 .unwrap_or_else(|_| Client::new()), 78 cache_ttl_secs, 79 } 80 } 81 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 82 { 83 let cache = self.cache.read().await; 84 if let Some(cached) = cache.get(client_id) { 85 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 86 return Ok(cached.metadata.clone()); 87 } 88 } 89 } 90 let metadata = self.fetch_metadata(client_id).await?; 91 { 92 let mut cache = self.cache.write().await; 93 cache.insert( 94 client_id.to_string(), 95 CachedMetadata { 96 metadata: metadata.clone(), 97 cached_at: std::time::Instant::now(), 98 }, 99 ); 100 } 101 Ok(metadata) 102 } 103 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> { 104 if let Some(jwks) = &metadata.jwks { 105 return Ok(jwks.clone()); 106 } 107 let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| { 108 OAuthError::InvalidClient( 109 "Client using private_key_jwt must have jwks or jwks_uri".to_string(), 110 ) 111 })?; 112 { 113 let cache = self.jwks_cache.read().await; 114 if let Some(cached) = cache.get(jwks_uri) { 115 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 116 return Ok(cached.jwks.clone()); 117 } 118 } 119 } 120 let jwks = self.fetch_jwks(jwks_uri).await?; 121 { 122 let mut cache = self.jwks_cache.write().await; 123 cache.insert( 124 jwks_uri.clone(), 125 CachedJwks { 126 jwks: jwks.clone(), 127 cached_at: std::time::Instant::now(), 128 }, 129 ); 130 } 131 Ok(jwks) 132 } 133 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> { 134 if !jwks_uri.starts_with("https://") { 135 if !jwks_uri.starts_with("http://") 136 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1")) 137 { 138 return Err(OAuthError::InvalidClient( 139 "jwks_uri must use https (except for localhost)".to_string(), 140 )); 141 } 142 } 143 let response = self 144 .http_client 145 .get(jwks_uri) 146 .header("Accept", "application/json") 147 .send() 148 .await 149 .map_err(|e| { 150 OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e)) 151 })?; 152 if !response.status().is_success() { 153 return Err(OAuthError::InvalidClient(format!( 154 "Failed to fetch JWKS: HTTP {}", 155 response.status() 156 ))); 157 } 158 let jwks: serde_json::Value = response 159 .json() 160 .await 161 .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?; 162 if jwks.get("keys").and_then(|k| k.as_array()).is_none() { 163 return Err(OAuthError::InvalidClient( 164 "JWKS must contain a 'keys' array".to_string(), 165 )); 166 } 167 Ok(jwks) 168 } 169 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 170 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 171 return Err(OAuthError::InvalidClient( 172 "client_id must be a URL".to_string(), 173 )); 174 } 175 if client_id.starts_with("http://") 176 && !client_id.contains("localhost") 177 && !client_id.contains("127.0.0.1") 178 { 179 return Err(OAuthError::InvalidClient( 180 "Non-localhost client_id must use https".to_string(), 181 )); 182 } 183 let response = self 184 .http_client 185 .get(client_id) 186 .header("Accept", "application/json") 187 .send() 188 .await 189 .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?; 190 if !response.status().is_success() { 191 return Err(OAuthError::InvalidClient(format!( 192 "Failed to fetch client metadata: HTTP {}", 193 response.status() 194 ))); 195 } 196 let mut metadata: ClientMetadata = response 197 .json() 198 .await 199 .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?; 200 if metadata.client_id.is_empty() { 201 metadata.client_id = client_id.to_string(); 202 } else if metadata.client_id != client_id { 203 return Err(OAuthError::InvalidClient( 204 "client_id in metadata does not match request".to_string(), 205 )); 206 } 207 self.validate_metadata(&metadata)?; 208 Ok(metadata) 209 } 210 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> { 211 if metadata.redirect_uris.is_empty() { 212 return Err(OAuthError::InvalidClient( 213 "redirect_uris is required".to_string(), 214 )); 215 } 216 for uri in &metadata.redirect_uris { 217 self.validate_redirect_uri_format(uri)?; 218 } 219 if !metadata.grant_types.is_empty() 220 && !metadata.grant_types.contains(&"authorization_code".to_string()) 221 { 222 return Err(OAuthError::InvalidClient( 223 "authorization_code grant type is required".to_string(), 224 )); 225 } 226 if !metadata.response_types.is_empty() 227 && !metadata.response_types.contains(&"code".to_string()) 228 { 229 return Err(OAuthError::InvalidClient( 230 "code response type is required".to_string(), 231 )); 232 } 233 Ok(()) 234 } 235 pub fn validate_redirect_uri( 236 &self, 237 metadata: &ClientMetadata, 238 redirect_uri: &str, 239 ) -> Result<(), OAuthError> { 240 if !metadata.redirect_uris.contains(&redirect_uri.to_string()) { 241 return Err(OAuthError::InvalidRequest( 242 "redirect_uri not registered for client".to_string(), 243 )); 244 } 245 Ok(()) 246 } 247 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> { 248 if uri.contains('#') { 249 return Err(OAuthError::InvalidClient( 250 "redirect_uri must not contain a fragment".to_string(), 251 )); 252 } 253 let parsed = reqwest::Url::parse(uri).map_err(|_| { 254 OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri)) 255 })?; 256 let scheme = parsed.scheme(); 257 if scheme == "http" { 258 let host = parsed.host_str().unwrap_or(""); 259 if host != "localhost" && host != "127.0.0.1" && host != "[::1]" { 260 return Err(OAuthError::InvalidClient( 261 "http redirect_uri only allowed for localhost".to_string(), 262 )); 263 } 264 } else if scheme == "https" { 265 } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') { 266 if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) { 267 return Err(OAuthError::InvalidClient(format!( 268 "Invalid redirect_uri scheme: {}", 269 scheme 270 ))); 271 } 272 } else { 273 return Err(OAuthError::InvalidClient(format!( 274 "Invalid redirect_uri scheme: {}", 275 scheme 276 ))); 277 } 278 Ok(()) 279 } 280} 281impl ClientMetadata { 282 pub fn requires_dpop(&self) -> bool { 283 self.dpop_bound_access_tokens.unwrap_or(false) 284 } 285 pub fn auth_method(&self) -> &str { 286 self.token_endpoint_auth_method 287 .as_deref() 288 .unwrap_or("none") 289 } 290} 291pub async fn verify_client_auth( 292 cache: &ClientMetadataCache, 293 metadata: &ClientMetadata, 294 client_auth: &super::ClientAuth, 295) -> Result<(), OAuthError> { 296 let expected_method = metadata.auth_method(); 297 match (expected_method, client_auth) { 298 ("none", super::ClientAuth::None) => Ok(()), 299 ("none", _) => Err(OAuthError::InvalidClient( 300 "Client is configured for no authentication, but credentials were provided".to_string(), 301 )), 302 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => { 303 verify_private_key_jwt_async(cache, metadata, client_assertion).await 304 } 305 ("private_key_jwt", _) => Err(OAuthError::InvalidClient( 306 "Client requires private_key_jwt authentication".to_string(), 307 )), 308 ("client_secret_post", super::ClientAuth::SecretPost { .. }) => { 309 Err(OAuthError::InvalidClient( 310 "client_secret_post is not supported for ATProto OAuth".to_string(), 311 )) 312 } 313 ("client_secret_basic", super::ClientAuth::SecretBasic { .. }) => { 314 Err(OAuthError::InvalidClient( 315 "client_secret_basic is not supported for ATProto OAuth".to_string(), 316 )) 317 } 318 (method, _) => Err(OAuthError::InvalidClient(format!( 319 "Unsupported or mismatched authentication method: {}", 320 method 321 ))), 322 } 323} 324async fn verify_private_key_jwt_async( 325 cache: &ClientMetadataCache, 326 metadata: &ClientMetadata, 327 client_assertion: &str, 328) -> Result<(), OAuthError> { 329 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 330 let parts: Vec<&str> = client_assertion.split('.').collect(); 331 if parts.len() != 3 { 332 return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string())); 333 } 334 let header_bytes = URL_SAFE_NO_PAD 335 .decode(parts[0]) 336 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?; 337 let header: serde_json::Value = serde_json::from_slice(&header_bytes) 338 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?; 339 let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| { 340 OAuthError::InvalidClient("Missing alg in client_assertion".to_string()) 341 })?; 342 if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") { 343 return Err(OAuthError::InvalidClient(format!( 344 "Unsupported client_assertion algorithm: {}", 345 alg 346 ))); 347 } 348 let kid = header.get("kid").and_then(|k| k.as_str()); 349 let payload_bytes = URL_SAFE_NO_PAD 350 .decode(parts[1]) 351 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?; 352 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) 353 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?; 354 let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| { 355 OAuthError::InvalidClient("Missing iss in client_assertion".to_string()) 356 })?; 357 if iss != metadata.client_id { 358 return Err(OAuthError::InvalidClient( 359 "client_assertion iss does not match client_id".to_string(), 360 )); 361 } 362 let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| { 363 OAuthError::InvalidClient("Missing sub in client_assertion".to_string()) 364 })?; 365 if sub != metadata.client_id { 366 return Err(OAuthError::InvalidClient( 367 "client_assertion sub does not match client_id".to_string(), 368 )); 369 } 370 let exp = payload.get("exp").and_then(|e| e.as_i64()).ok_or_else(|| { 371 OAuthError::InvalidClient("Missing exp in client_assertion".to_string()) 372 })?; 373 let now = chrono::Utc::now().timestamp(); 374 if exp < now { 375 return Err(OAuthError::InvalidClient("client_assertion has expired".to_string())); 376 } 377 let iat = payload.get("iat").and_then(|i| i.as_i64()); 378 if let Some(iat) = iat { 379 if iat > now + 60 { 380 return Err(OAuthError::InvalidClient( 381 "client_assertion iat is in the future".to_string(), 382 )); 383 } 384 } 385 let jwks = cache.get_jwks(metadata).await?; 386 let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| { 387 OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string()) 388 })?; 389 let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid { 390 keys.iter() 391 .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid)) 392 .collect() 393 } else { 394 keys.iter().collect() 395 }; 396 if matching_keys.is_empty() { 397 return Err(OAuthError::InvalidClient( 398 "No matching key found in client JWKS".to_string(), 399 )); 400 } 401 let signing_input = format!("{}.{}", parts[0], parts[1]); 402 let signature_bytes = URL_SAFE_NO_PAD 403 .decode(parts[2]) 404 .map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?; 405 for key in matching_keys { 406 let key_alg = key.get("alg").and_then(|a| a.as_str()); 407 if key_alg.is_some() && key_alg != Some(alg) { 408 continue; 409 } 410 let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or(""); 411 let verified = match (alg, kty) { 412 ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes), 413 ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes), 414 ("RS256" | "RS384" | "RS512", "RSA") => { 415 verify_rsa(alg, key, &signing_input, &signature_bytes) 416 } 417 ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes), 418 _ => continue, 419 }; 420 if verified.is_ok() { 421 return Ok(()); 422 } 423 } 424 Err(OAuthError::InvalidClient( 425 "client_assertion signature verification failed".to_string(), 426 )) 427} 428fn verify_es256( 429 key: &serde_json::Value, 430 signing_input: &str, 431 signature: &[u8], 432) -> Result<(), OAuthError> { 433 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 434 use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 435 use p256::EncodedPoint; 436 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 437 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 438 })?; 439 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 440 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 441 })?; 442 let x_bytes = URL_SAFE_NO_PAD.decode(x) 443 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 444 let y_bytes = URL_SAFE_NO_PAD.decode(y) 445 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 446 let mut point_bytes = vec![0x04]; 447 point_bytes.extend_from_slice(&x_bytes); 448 point_bytes.extend_from_slice(&y_bytes); 449 let point = EncodedPoint::from_bytes(&point_bytes) 450 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?; 451 let verifying_key = VerifyingKey::from_encoded_point(&point) 452 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?; 453 let sig = Signature::from_slice(signature) 454 .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?; 455 verifying_key 456 .verify(signing_input.as_bytes(), &sig) 457 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string())) 458} 459fn verify_es384( 460 key: &serde_json::Value, 461 signing_input: &str, 462 signature: &[u8], 463) -> Result<(), OAuthError> { 464 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 465 use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 466 use p384::EncodedPoint; 467 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 468 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string()) 469 })?; 470 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| { 471 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string()) 472 })?; 473 let x_bytes = URL_SAFE_NO_PAD.decode(x) 474 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?; 475 let y_bytes = URL_SAFE_NO_PAD.decode(y) 476 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?; 477 let mut point_bytes = vec![0x04]; 478 point_bytes.extend_from_slice(&x_bytes); 479 point_bytes.extend_from_slice(&y_bytes); 480 let point = EncodedPoint::from_bytes(&point_bytes) 481 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?; 482 let verifying_key = VerifyingKey::from_encoded_point(&point) 483 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?; 484 let sig = Signature::from_slice(signature) 485 .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?; 486 verifying_key 487 .verify(signing_input.as_bytes(), &sig) 488 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string())) 489} 490fn verify_rsa( 491 _alg: &str, 492 _key: &serde_json::Value, 493 _signing_input: &str, 494 _signature: &[u8], 495) -> Result<(), OAuthError> { 496 Err(OAuthError::InvalidClient( 497 "RSA signature verification not yet supported - use EC keys".to_string(), 498 )) 499} 500fn verify_eddsa( 501 key: &serde_json::Value, 502 signing_input: &str, 503 signature: &[u8], 504) -> Result<(), OAuthError> { 505 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 506 use ed25519_dalek::{Signature, Verifier, VerifyingKey}; 507 let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or(""); 508 if crv != "Ed25519" { 509 return Err(OAuthError::InvalidClient(format!( 510 "Unsupported EdDSA curve: {}", 511 crv 512 ))); 513 } 514 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| { 515 OAuthError::InvalidClient("Missing x in OKP key".to_string()) 516 })?; 517 let x_bytes = URL_SAFE_NO_PAD.decode(x) 518 .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?; 519 let key_bytes: [u8; 32] = x_bytes.try_into() 520 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?; 521 let verifying_key = VerifyingKey::from_bytes(&key_bytes) 522 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?; 523 let sig_bytes: [u8; 64] = signature.try_into() 524 .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?; 525 let sig = Signature::from_bytes(&sig_bytes); 526 verifying_key 527 .verify(signing_input.as_bytes(), &sig) 528 .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string())) 529}