this repo has no description
1use reqwest::Client;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
7use super::OAuthError;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ClientMetadata {
11 pub client_id: String,
12 #[serde(skip_serializing_if = "Option::is_none")]
13 pub client_name: Option<String>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub client_uri: Option<String>,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub logo_uri: Option<String>,
18 pub redirect_uris: Vec<String>,
19 #[serde(default)]
20 pub grant_types: Vec<String>,
21 #[serde(default)]
22 pub response_types: Vec<String>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub scope: Option<String>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub token_endpoint_auth_method: Option<String>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub dpop_bound_access_tokens: Option<bool>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub jwks: Option<serde_json::Value>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub jwks_uri: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub application_type: Option<String>,
35}
36
37impl Default for ClientMetadata {
38 fn default() -> Self {
39 Self {
40 client_id: String::new(),
41 client_name: None,
42 client_uri: None,
43 logo_uri: None,
44 redirect_uris: Vec::new(),
45 grant_types: vec!["authorization_code".to_string()],
46 response_types: vec!["code".to_string()],
47 scope: None,
48 token_endpoint_auth_method: Some("none".to_string()),
49 dpop_bound_access_tokens: None,
50 jwks: None,
51 jwks_uri: None,
52 application_type: None,
53 }
54 }
55}
56
57#[derive(Clone)]
58pub struct ClientMetadataCache {
59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
60 jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
61 http_client: Client,
62 cache_ttl_secs: u64,
63}
64
65struct CachedMetadata {
66 metadata: ClientMetadata,
67 cached_at: std::time::Instant,
68}
69
70struct CachedJwks {
71 jwks: serde_json::Value,
72 cached_at: std::time::Instant,
73}
74
75impl ClientMetadataCache {
76 pub fn new(cache_ttl_secs: u64) -> Self {
77 Self {
78 cache: Arc::new(RwLock::new(HashMap::new())),
79 jwks_cache: Arc::new(RwLock::new(HashMap::new())),
80 http_client: Client::builder()
81 .timeout(std::time::Duration::from_secs(30))
82 .connect_timeout(std::time::Duration::from_secs(10))
83 .build()
84 .unwrap_or_else(|_| Client::new()),
85 cache_ttl_secs,
86 }
87 }
88
89 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
90 {
91 let cache = self.cache.read().await;
92 if let Some(cached) = cache.get(client_id) {
93 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
94 return Ok(cached.metadata.clone());
95 }
96 }
97 }
98 let metadata = self.fetch_metadata(client_id).await?;
99 {
100 let mut cache = self.cache.write().await;
101 cache.insert(
102 client_id.to_string(),
103 CachedMetadata {
104 metadata: metadata.clone(),
105 cached_at: std::time::Instant::now(),
106 },
107 );
108 }
109 Ok(metadata)
110 }
111
112 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
113 if let Some(jwks) = &metadata.jwks {
114 return Ok(jwks.clone());
115 }
116 let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| {
117 OAuthError::InvalidClient(
118 "Client using private_key_jwt must have jwks or jwks_uri".to_string(),
119 )
120 })?;
121 {
122 let cache = self.jwks_cache.read().await;
123 if let Some(cached) = cache.get(jwks_uri) {
124 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
125 return Ok(cached.jwks.clone());
126 }
127 }
128 }
129 let jwks = self.fetch_jwks(jwks_uri).await?;
130 {
131 let mut cache = self.jwks_cache.write().await;
132 cache.insert(
133 jwks_uri.clone(),
134 CachedJwks {
135 jwks: jwks.clone(),
136 cached_at: std::time::Instant::now(),
137 },
138 );
139 }
140 Ok(jwks)
141 }
142
143 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
144 if !jwks_uri.starts_with("https://") {
145 if !jwks_uri.starts_with("http://")
146 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))
147 {
148 return Err(OAuthError::InvalidClient(
149 "jwks_uri must use https (except for localhost)".to_string(),
150 ));
151 }
152 }
153 let response = self
154 .http_client
155 .get(jwks_uri)
156 .header("Accept", "application/json")
157 .send()
158 .await
159 .map_err(|e| {
160 OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e))
161 })?;
162 if !response.status().is_success() {
163 return Err(OAuthError::InvalidClient(format!(
164 "Failed to fetch JWKS: HTTP {}",
165 response.status()
166 )));
167 }
168 let jwks: serde_json::Value = response
169 .json()
170 .await
171 .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?;
172 if jwks.get("keys").and_then(|k| k.as_array()).is_none() {
173 return Err(OAuthError::InvalidClient(
174 "JWKS must contain a 'keys' array".to_string(),
175 ));
176 }
177 Ok(jwks)
178 }
179
180 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
181 if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
182 return Err(OAuthError::InvalidClient(
183 "client_id must be a URL".to_string(),
184 ));
185 }
186 if client_id.starts_with("http://")
187 && !client_id.contains("localhost")
188 && !client_id.contains("127.0.0.1")
189 {
190 return Err(OAuthError::InvalidClient(
191 "Non-localhost client_id must use https".to_string(),
192 ));
193 }
194 let response = self
195 .http_client
196 .get(client_id)
197 .header("Accept", "application/json")
198 .send()
199 .await
200 .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?;
201 if !response.status().is_success() {
202 return Err(OAuthError::InvalidClient(format!(
203 "Failed to fetch client metadata: HTTP {}",
204 response.status()
205 )));
206 }
207 let mut metadata: ClientMetadata = response
208 .json()
209 .await
210 .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?;
211 if metadata.client_id.is_empty() {
212 metadata.client_id = client_id.to_string();
213 } else if metadata.client_id != client_id {
214 return Err(OAuthError::InvalidClient(
215 "client_id in metadata does not match request".to_string(),
216 ));
217 }
218 self.validate_metadata(&metadata)?;
219 Ok(metadata)
220 }
221
222 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> {
223 if metadata.redirect_uris.is_empty() {
224 return Err(OAuthError::InvalidClient(
225 "redirect_uris is required".to_string(),
226 ));
227 }
228 for uri in &metadata.redirect_uris {
229 self.validate_redirect_uri_format(uri)?;
230 }
231 if !metadata.grant_types.is_empty()
232 && !metadata.grant_types.contains(&"authorization_code".to_string())
233 {
234 return Err(OAuthError::InvalidClient(
235 "authorization_code grant type is required".to_string(),
236 ));
237 }
238 if !metadata.response_types.is_empty()
239 && !metadata.response_types.contains(&"code".to_string())
240 {
241 return Err(OAuthError::InvalidClient(
242 "code response type is required".to_string(),
243 ));
244 }
245 Ok(())
246 }
247
248 pub fn validate_redirect_uri(
249 &self,
250 metadata: &ClientMetadata,
251 redirect_uri: &str,
252 ) -> Result<(), OAuthError> {
253 if !metadata.redirect_uris.contains(&redirect_uri.to_string()) {
254 return Err(OAuthError::InvalidRequest(
255 "redirect_uri not registered for client".to_string(),
256 ));
257 }
258 Ok(())
259 }
260
261 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> {
262 if uri.contains('#') {
263 return Err(OAuthError::InvalidClient(
264 "redirect_uri must not contain a fragment".to_string(),
265 ));
266 }
267 let parsed = reqwest::Url::parse(uri).map_err(|_| {
268 OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri))
269 })?;
270 let scheme = parsed.scheme();
271 if scheme == "http" {
272 let host = parsed.host_str().unwrap_or("");
273 if host != "localhost" && host != "127.0.0.1" && host != "[::1]" {
274 return Err(OAuthError::InvalidClient(
275 "http redirect_uri only allowed for localhost".to_string(),
276 ));
277 }
278 } else if scheme == "https" {
279 } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') {
280 if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) {
281 return Err(OAuthError::InvalidClient(format!(
282 "Invalid redirect_uri scheme: {}",
283 scheme
284 )));
285 }
286 } else {
287 return Err(OAuthError::InvalidClient(format!(
288 "Invalid redirect_uri scheme: {}",
289 scheme
290 )));
291 }
292 Ok(())
293 }
294}
295
296impl ClientMetadata {
297 pub fn requires_dpop(&self) -> bool {
298 self.dpop_bound_access_tokens.unwrap_or(false)
299 }
300
301 pub fn auth_method(&self) -> &str {
302 self.token_endpoint_auth_method
303 .as_deref()
304 .unwrap_or("none")
305 }
306}
307
308pub async fn verify_client_auth(
309 cache: &ClientMetadataCache,
310 metadata: &ClientMetadata,
311 client_auth: &super::ClientAuth,
312) -> Result<(), OAuthError> {
313 let expected_method = metadata.auth_method();
314 match (expected_method, client_auth) {
315 ("none", super::ClientAuth::None) => Ok(()),
316 ("none", _) => Err(OAuthError::InvalidClient(
317 "Client is configured for no authentication, but credentials were provided".to_string(),
318 )),
319 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
320 verify_private_key_jwt_async(cache, metadata, client_assertion).await
321 }
322 ("private_key_jwt", _) => Err(OAuthError::InvalidClient(
323 "Client requires private_key_jwt authentication".to_string(),
324 )),
325 ("client_secret_post", super::ClientAuth::SecretPost { .. }) => {
326 Err(OAuthError::InvalidClient(
327 "client_secret_post is not supported for ATProto OAuth".to_string(),
328 ))
329 }
330 ("client_secret_basic", super::ClientAuth::SecretBasic { .. }) => {
331 Err(OAuthError::InvalidClient(
332 "client_secret_basic is not supported for ATProto OAuth".to_string(),
333 ))
334 }
335 (method, _) => Err(OAuthError::InvalidClient(format!(
336 "Unsupported or mismatched authentication method: {}",
337 method
338 ))),
339 }
340}
341
342async fn verify_private_key_jwt_async(
343 cache: &ClientMetadataCache,
344 metadata: &ClientMetadata,
345 client_assertion: &str,
346) -> Result<(), OAuthError> {
347 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
348 let parts: Vec<&str> = client_assertion.split('.').collect();
349 if parts.len() != 3 {
350 return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string()));
351 }
352 let header_bytes = URL_SAFE_NO_PAD
353 .decode(parts[0])
354 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?;
355 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
356 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?;
357 let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| {
358 OAuthError::InvalidClient("Missing alg in client_assertion".to_string())
359 })?;
360 if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") {
361 return Err(OAuthError::InvalidClient(format!(
362 "Unsupported client_assertion algorithm: {}",
363 alg
364 )));
365 }
366 let kid = header.get("kid").and_then(|k| k.as_str());
367 let payload_bytes = URL_SAFE_NO_PAD
368 .decode(parts[1])
369 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?;
370 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
371 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?;
372 let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| {
373 OAuthError::InvalidClient("Missing iss in client_assertion".to_string())
374 })?;
375 if iss != metadata.client_id {
376 return Err(OAuthError::InvalidClient(
377 "client_assertion iss does not match client_id".to_string(),
378 ));
379 }
380 let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| {
381 OAuthError::InvalidClient("Missing sub in client_assertion".to_string())
382 })?;
383 if sub != metadata.client_id {
384 return Err(OAuthError::InvalidClient(
385 "client_assertion sub does not match client_id".to_string(),
386 ));
387 }
388 let exp = payload.get("exp").and_then(|e| e.as_i64()).ok_or_else(|| {
389 OAuthError::InvalidClient("Missing exp in client_assertion".to_string())
390 })?;
391 let now = chrono::Utc::now().timestamp();
392 if exp < now {
393 return Err(OAuthError::InvalidClient("client_assertion has expired".to_string()));
394 }
395 let iat = payload.get("iat").and_then(|i| i.as_i64());
396 if let Some(iat) = iat {
397 if iat > now + 60 {
398 return Err(OAuthError::InvalidClient(
399 "client_assertion iat is in the future".to_string(),
400 ));
401 }
402 }
403 let jwks = cache.get_jwks(metadata).await?;
404 let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
405 OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string())
406 })?;
407 let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid {
408 keys.iter()
409 .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid))
410 .collect()
411 } else {
412 keys.iter().collect()
413 };
414 if matching_keys.is_empty() {
415 return Err(OAuthError::InvalidClient(
416 "No matching key found in client JWKS".to_string(),
417 ));
418 }
419 let signing_input = format!("{}.{}", parts[0], parts[1]);
420 let signature_bytes = URL_SAFE_NO_PAD
421 .decode(parts[2])
422 .map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?;
423 for key in matching_keys {
424 let key_alg = key.get("alg").and_then(|a| a.as_str());
425 if key_alg.is_some() && key_alg != Some(alg) {
426 continue;
427 }
428 let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
429 let verified = match (alg, kty) {
430 ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes),
431 ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes),
432 ("RS256" | "RS384" | "RS512", "RSA") => {
433 verify_rsa(alg, key, &signing_input, &signature_bytes)
434 }
435 ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes),
436 _ => continue,
437 };
438 if verified.is_ok() {
439 return Ok(());
440 }
441 }
442 Err(OAuthError::InvalidClient(
443 "client_assertion signature verification failed".to_string(),
444 ))
445}
446
447fn verify_es256(
448 key: &serde_json::Value,
449 signing_input: &str,
450 signature: &[u8],
451) -> Result<(), OAuthError> {
452 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
453 use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
454 use p256::EncodedPoint;
455 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
456 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
457 })?;
458 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
459 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
460 })?;
461 let x_bytes = URL_SAFE_NO_PAD.decode(x)
462 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
463 let y_bytes = URL_SAFE_NO_PAD.decode(y)
464 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
465 let mut point_bytes = vec![0x04];
466 point_bytes.extend_from_slice(&x_bytes);
467 point_bytes.extend_from_slice(&y_bytes);
468 let point = EncodedPoint::from_bytes(&point_bytes)
469 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
470 let verifying_key = VerifyingKey::from_encoded_point(&point)
471 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
472 let sig = Signature::from_slice(signature)
473 .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?;
474 verifying_key
475 .verify(signing_input.as_bytes(), &sig)
476 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
477}
478
479fn verify_es384(
480 key: &serde_json::Value,
481 signing_input: &str,
482 signature: &[u8],
483) -> Result<(), OAuthError> {
484 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
485 use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
486 use p384::EncodedPoint;
487 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
488 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
489 })?;
490 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
491 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
492 })?;
493 let x_bytes = URL_SAFE_NO_PAD.decode(x)
494 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
495 let y_bytes = URL_SAFE_NO_PAD.decode(y)
496 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
497 let mut point_bytes = vec![0x04];
498 point_bytes.extend_from_slice(&x_bytes);
499 point_bytes.extend_from_slice(&y_bytes);
500 let point = EncodedPoint::from_bytes(&point_bytes)
501 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
502 let verifying_key = VerifyingKey::from_encoded_point(&point)
503 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
504 let sig = Signature::from_slice(signature)
505 .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?;
506 verifying_key
507 .verify(signing_input.as_bytes(), &sig)
508 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
509}
510
511fn verify_rsa(
512 _alg: &str,
513 _key: &serde_json::Value,
514 _signing_input: &str,
515 _signature: &[u8],
516) -> Result<(), OAuthError> {
517 Err(OAuthError::InvalidClient(
518 "RSA signature verification not yet supported - use EC keys".to_string(),
519 ))
520}
521
522fn verify_eddsa(
523 key: &serde_json::Value,
524 signing_input: &str,
525 signature: &[u8],
526) -> Result<(), OAuthError> {
527 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
528 use ed25519_dalek::{Signature, Verifier, VerifyingKey};
529 let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or("");
530 if crv != "Ed25519" {
531 return Err(OAuthError::InvalidClient(format!(
532 "Unsupported EdDSA curve: {}",
533 crv
534 )));
535 }
536 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
537 OAuthError::InvalidClient("Missing x in OKP key".to_string())
538 })?;
539 let x_bytes = URL_SAFE_NO_PAD.decode(x)
540 .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?;
541 let key_bytes: [u8; 32] = x_bytes.try_into()
542 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?;
543 let verifying_key = VerifyingKey::from_bytes(&key_bytes)
544 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?;
545 let sig_bytes: [u8; 64] = signature.try_into()
546 .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?;
547 let sig = Signature::from_bytes(&sig_bytes);
548 verifying_key
549 .verify(signing_input.as_bytes(), &sig)
550 .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string()))
551}