this repo has no description
1use reqwest::Client;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
7use super::OAuthError;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ClientMetadata {
11 pub client_id: String,
12 #[serde(skip_serializing_if = "Option::is_none")]
13 pub client_name: Option<String>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub client_uri: Option<String>,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub logo_uri: Option<String>,
18 pub redirect_uris: Vec<String>,
19 #[serde(default)]
20 pub grant_types: Vec<String>,
21 #[serde(default)]
22 pub response_types: Vec<String>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub scope: Option<String>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub token_endpoint_auth_method: Option<String>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub dpop_bound_access_tokens: Option<bool>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub jwks: Option<serde_json::Value>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub jwks_uri: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub application_type: Option<String>,
35}
36
37impl Default for ClientMetadata {
38 fn default() -> Self {
39 Self {
40 client_id: String::new(),
41 client_name: None,
42 client_uri: None,
43 logo_uri: None,
44 redirect_uris: Vec::new(),
45 grant_types: vec!["authorization_code".to_string()],
46 response_types: vec!["code".to_string()],
47 scope: None,
48 token_endpoint_auth_method: Some("none".to_string()),
49 dpop_bound_access_tokens: None,
50 jwks: None,
51 jwks_uri: None,
52 application_type: None,
53 }
54 }
55}
56
57#[derive(Clone)]
58pub struct ClientMetadataCache {
59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
60 jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
61 http_client: Client,
62 cache_ttl_secs: u64,
63}
64
65struct CachedMetadata {
66 metadata: ClientMetadata,
67 cached_at: std::time::Instant,
68}
69
70struct CachedJwks {
71 jwks: serde_json::Value,
72 cached_at: std::time::Instant,
73}
74
75impl ClientMetadataCache {
76 pub fn new(cache_ttl_secs: u64) -> Self {
77 Self {
78 cache: Arc::new(RwLock::new(HashMap::new())),
79 jwks_cache: Arc::new(RwLock::new(HashMap::new())),
80 http_client: Client::builder()
81 .timeout(std::time::Duration::from_secs(30))
82 .connect_timeout(std::time::Duration::from_secs(10))
83 .build()
84 .unwrap_or_else(|_| Client::new()),
85 cache_ttl_secs,
86 }
87 }
88
89 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
90 {
91 let cache = self.cache.read().await;
92 if let Some(cached) = cache.get(client_id) {
93 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
94 return Ok(cached.metadata.clone());
95 }
96 }
97 }
98
99 let metadata = self.fetch_metadata(client_id).await?;
100
101 {
102 let mut cache = self.cache.write().await;
103 cache.insert(
104 client_id.to_string(),
105 CachedMetadata {
106 metadata: metadata.clone(),
107 cached_at: std::time::Instant::now(),
108 },
109 );
110 }
111
112 Ok(metadata)
113 }
114
115 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
116 if let Some(jwks) = &metadata.jwks {
117 return Ok(jwks.clone());
118 }
119
120 let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| {
121 OAuthError::InvalidClient(
122 "Client using private_key_jwt must have jwks or jwks_uri".to_string(),
123 )
124 })?;
125
126 {
127 let cache = self.jwks_cache.read().await;
128 if let Some(cached) = cache.get(jwks_uri) {
129 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
130 return Ok(cached.jwks.clone());
131 }
132 }
133 }
134
135 let jwks = self.fetch_jwks(jwks_uri).await?;
136
137 {
138 let mut cache = self.jwks_cache.write().await;
139 cache.insert(
140 jwks_uri.clone(),
141 CachedJwks {
142 jwks: jwks.clone(),
143 cached_at: std::time::Instant::now(),
144 },
145 );
146 }
147
148 Ok(jwks)
149 }
150
151 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
152 if !jwks_uri.starts_with("https://") {
153 if !jwks_uri.starts_with("http://")
154 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))
155 {
156 return Err(OAuthError::InvalidClient(
157 "jwks_uri must use https (except for localhost)".to_string(),
158 ));
159 }
160 }
161
162 let response = self
163 .http_client
164 .get(jwks_uri)
165 .header("Accept", "application/json")
166 .send()
167 .await
168 .map_err(|e| {
169 OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e))
170 })?;
171
172 if !response.status().is_success() {
173 return Err(OAuthError::InvalidClient(format!(
174 "Failed to fetch JWKS: HTTP {}",
175 response.status()
176 )));
177 }
178
179 let jwks: serde_json::Value = response
180 .json()
181 .await
182 .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?;
183
184 if jwks.get("keys").and_then(|k| k.as_array()).is_none() {
185 return Err(OAuthError::InvalidClient(
186 "JWKS must contain a 'keys' array".to_string(),
187 ));
188 }
189
190 Ok(jwks)
191 }
192
193 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
194 if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
195 return Err(OAuthError::InvalidClient(
196 "client_id must be a URL".to_string(),
197 ));
198 }
199
200 if client_id.starts_with("http://")
201 && !client_id.contains("localhost")
202 && !client_id.contains("127.0.0.1")
203 {
204 return Err(OAuthError::InvalidClient(
205 "Non-localhost client_id must use https".to_string(),
206 ));
207 }
208
209 let response = self
210 .http_client
211 .get(client_id)
212 .header("Accept", "application/json")
213 .send()
214 .await
215 .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?;
216
217 if !response.status().is_success() {
218 return Err(OAuthError::InvalidClient(format!(
219 "Failed to fetch client metadata: HTTP {}",
220 response.status()
221 )));
222 }
223
224 let mut metadata: ClientMetadata = response
225 .json()
226 .await
227 .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?;
228
229 if metadata.client_id.is_empty() {
230 metadata.client_id = client_id.to_string();
231 } else if metadata.client_id != client_id {
232 return Err(OAuthError::InvalidClient(
233 "client_id in metadata does not match request".to_string(),
234 ));
235 }
236
237 self.validate_metadata(&metadata)?;
238
239 Ok(metadata)
240 }
241
242 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> {
243 if metadata.redirect_uris.is_empty() {
244 return Err(OAuthError::InvalidClient(
245 "redirect_uris is required".to_string(),
246 ));
247 }
248
249 for uri in &metadata.redirect_uris {
250 self.validate_redirect_uri_format(uri)?;
251 }
252
253 if !metadata.grant_types.is_empty()
254 && !metadata.grant_types.contains(&"authorization_code".to_string())
255 {
256 return Err(OAuthError::InvalidClient(
257 "authorization_code grant type is required".to_string(),
258 ));
259 }
260
261 if !metadata.response_types.is_empty()
262 && !metadata.response_types.contains(&"code".to_string())
263 {
264 return Err(OAuthError::InvalidClient(
265 "code response type is required".to_string(),
266 ));
267 }
268
269 Ok(())
270 }
271
272 pub fn validate_redirect_uri(
273 &self,
274 metadata: &ClientMetadata,
275 redirect_uri: &str,
276 ) -> Result<(), OAuthError> {
277 if !metadata.redirect_uris.contains(&redirect_uri.to_string()) {
278 return Err(OAuthError::InvalidRequest(
279 "redirect_uri not registered for client".to_string(),
280 ));
281 }
282 Ok(())
283 }
284
285 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> {
286 if uri.contains('#') {
287 return Err(OAuthError::InvalidClient(
288 "redirect_uri must not contain a fragment".to_string(),
289 ));
290 }
291
292 let parsed = reqwest::Url::parse(uri).map_err(|_| {
293 OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri))
294 })?;
295
296 let scheme = parsed.scheme();
297
298 if scheme == "http" {
299 let host = parsed.host_str().unwrap_or("");
300 if host != "localhost" && host != "127.0.0.1" && host != "[::1]" {
301 return Err(OAuthError::InvalidClient(
302 "http redirect_uri only allowed for localhost".to_string(),
303 ));
304 }
305 } else if scheme == "https" {
306 } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') {
307 if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) {
308 return Err(OAuthError::InvalidClient(format!(
309 "Invalid redirect_uri scheme: {}",
310 scheme
311 )));
312 }
313 } else {
314 return Err(OAuthError::InvalidClient(format!(
315 "Invalid redirect_uri scheme: {}",
316 scheme
317 )));
318 }
319
320 Ok(())
321 }
322}
323
324impl ClientMetadata {
325 pub fn requires_dpop(&self) -> bool {
326 self.dpop_bound_access_tokens.unwrap_or(false)
327 }
328
329 pub fn auth_method(&self) -> &str {
330 self.token_endpoint_auth_method
331 .as_deref()
332 .unwrap_or("none")
333 }
334}
335
336pub async fn verify_client_auth(
337 cache: &ClientMetadataCache,
338 metadata: &ClientMetadata,
339 client_auth: &super::ClientAuth,
340) -> Result<(), OAuthError> {
341 let expected_method = metadata.auth_method();
342
343 match (expected_method, client_auth) {
344 ("none", super::ClientAuth::None) => Ok(()),
345
346 ("none", _) => Err(OAuthError::InvalidClient(
347 "Client is configured for no authentication, but credentials were provided".to_string(),
348 )),
349
350 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
351 verify_private_key_jwt_async(cache, metadata, client_assertion).await
352 }
353
354 ("private_key_jwt", _) => Err(OAuthError::InvalidClient(
355 "Client requires private_key_jwt authentication".to_string(),
356 )),
357
358 ("client_secret_post", super::ClientAuth::SecretPost { .. }) => {
359 Err(OAuthError::InvalidClient(
360 "client_secret_post is not supported for ATProto OAuth".to_string(),
361 ))
362 }
363
364 ("client_secret_basic", super::ClientAuth::SecretBasic { .. }) => {
365 Err(OAuthError::InvalidClient(
366 "client_secret_basic is not supported for ATProto OAuth".to_string(),
367 ))
368 }
369
370 (method, _) => Err(OAuthError::InvalidClient(format!(
371 "Unsupported or mismatched authentication method: {}",
372 method
373 ))),
374 }
375}
376
377async fn verify_private_key_jwt_async(
378 cache: &ClientMetadataCache,
379 metadata: &ClientMetadata,
380 client_assertion: &str,
381) -> Result<(), OAuthError> {
382 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
383
384 let parts: Vec<&str> = client_assertion.split('.').collect();
385 if parts.len() != 3 {
386 return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string()));
387 }
388
389 let header_bytes = URL_SAFE_NO_PAD
390 .decode(parts[0])
391 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?;
392 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
393 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?;
394
395 let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| {
396 OAuthError::InvalidClient("Missing alg in client_assertion".to_string())
397 })?;
398
399 if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") {
400 return Err(OAuthError::InvalidClient(format!(
401 "Unsupported client_assertion algorithm: {}",
402 alg
403 )));
404 }
405
406 let kid = header.get("kid").and_then(|k| k.as_str());
407
408 let payload_bytes = URL_SAFE_NO_PAD
409 .decode(parts[1])
410 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?;
411 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
412 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?;
413
414 let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| {
415 OAuthError::InvalidClient("Missing iss in client_assertion".to_string())
416 })?;
417 if iss != metadata.client_id {
418 return Err(OAuthError::InvalidClient(
419 "client_assertion iss does not match client_id".to_string(),
420 ));
421 }
422
423 let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| {
424 OAuthError::InvalidClient("Missing sub in client_assertion".to_string())
425 })?;
426 if sub != metadata.client_id {
427 return Err(OAuthError::InvalidClient(
428 "client_assertion sub does not match client_id".to_string(),
429 ));
430 }
431
432 let exp = payload.get("exp").and_then(|e| e.as_i64()).ok_or_else(|| {
433 OAuthError::InvalidClient("Missing exp in client_assertion".to_string())
434 })?;
435 let now = chrono::Utc::now().timestamp();
436 if exp < now {
437 return Err(OAuthError::InvalidClient("client_assertion has expired".to_string()));
438 }
439
440 let iat = payload.get("iat").and_then(|i| i.as_i64());
441 if let Some(iat) = iat {
442 if iat > now + 60 {
443 return Err(OAuthError::InvalidClient(
444 "client_assertion iat is in the future".to_string(),
445 ));
446 }
447 }
448
449 let jwks = cache.get_jwks(metadata).await?;
450 let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
451 OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string())
452 })?;
453
454 let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid {
455 keys.iter()
456 .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid))
457 .collect()
458 } else {
459 keys.iter().collect()
460 };
461
462 if matching_keys.is_empty() {
463 return Err(OAuthError::InvalidClient(
464 "No matching key found in client JWKS".to_string(),
465 ));
466 }
467
468 let signing_input = format!("{}.{}", parts[0], parts[1]);
469 let signature_bytes = URL_SAFE_NO_PAD
470 .decode(parts[2])
471 .map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?;
472
473 for key in matching_keys {
474 let key_alg = key.get("alg").and_then(|a| a.as_str());
475 if key_alg.is_some() && key_alg != Some(alg) {
476 continue;
477 }
478
479 let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
480
481 let verified = match (alg, kty) {
482 ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes),
483 ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes),
484 ("RS256" | "RS384" | "RS512", "RSA") => {
485 verify_rsa(alg, key, &signing_input, &signature_bytes)
486 }
487 ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes),
488 _ => continue,
489 };
490
491 if verified.is_ok() {
492 return Ok(());
493 }
494 }
495
496 Err(OAuthError::InvalidClient(
497 "client_assertion signature verification failed".to_string(),
498 ))
499}
500
501fn verify_es256(
502 key: &serde_json::Value,
503 signing_input: &str,
504 signature: &[u8],
505) -> Result<(), OAuthError> {
506 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
507 use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
508 use p256::EncodedPoint;
509
510 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
511 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
512 })?;
513 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
514 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
515 })?;
516
517 let x_bytes = URL_SAFE_NO_PAD.decode(x)
518 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
519 let y_bytes = URL_SAFE_NO_PAD.decode(y)
520 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
521
522 let mut point_bytes = vec![0x04];
523 point_bytes.extend_from_slice(&x_bytes);
524 point_bytes.extend_from_slice(&y_bytes);
525
526 let point = EncodedPoint::from_bytes(&point_bytes)
527 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
528 let verifying_key = VerifyingKey::from_encoded_point(&point)
529 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
530
531 let sig = Signature::from_slice(signature)
532 .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?;
533
534 verifying_key
535 .verify(signing_input.as_bytes(), &sig)
536 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
537}
538
539fn verify_es384(
540 key: &serde_json::Value,
541 signing_input: &str,
542 signature: &[u8],
543) -> Result<(), OAuthError> {
544 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
545 use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
546 use p384::EncodedPoint;
547
548 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
549 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
550 })?;
551 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
552 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
553 })?;
554
555 let x_bytes = URL_SAFE_NO_PAD.decode(x)
556 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
557 let y_bytes = URL_SAFE_NO_PAD.decode(y)
558 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
559
560 let mut point_bytes = vec![0x04];
561 point_bytes.extend_from_slice(&x_bytes);
562 point_bytes.extend_from_slice(&y_bytes);
563
564 let point = EncodedPoint::from_bytes(&point_bytes)
565 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
566 let verifying_key = VerifyingKey::from_encoded_point(&point)
567 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
568
569 let sig = Signature::from_slice(signature)
570 .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?;
571
572 verifying_key
573 .verify(signing_input.as_bytes(), &sig)
574 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
575}
576
577fn verify_rsa(
578 _alg: &str,
579 _key: &serde_json::Value,
580 _signing_input: &str,
581 _signature: &[u8],
582) -> Result<(), OAuthError> {
583 Err(OAuthError::InvalidClient(
584 "RSA signature verification not yet supported - use EC keys".to_string(),
585 ))
586}
587
588fn verify_eddsa(
589 key: &serde_json::Value,
590 signing_input: &str,
591 signature: &[u8],
592) -> Result<(), OAuthError> {
593 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
594 use ed25519_dalek::{Signature, Verifier, VerifyingKey};
595
596 let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or("");
597 if crv != "Ed25519" {
598 return Err(OAuthError::InvalidClient(format!(
599 "Unsupported EdDSA curve: {}",
600 crv
601 )));
602 }
603
604 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
605 OAuthError::InvalidClient("Missing x in OKP key".to_string())
606 })?;
607
608 let x_bytes = URL_SAFE_NO_PAD.decode(x)
609 .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?;
610
611 let key_bytes: [u8; 32] = x_bytes.try_into()
612 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?;
613
614 let verifying_key = VerifyingKey::from_bytes(&key_bytes)
615 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?;
616
617 let sig_bytes: [u8; 64] = signature.try_into()
618 .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?;
619
620 let sig = Signature::from_bytes(&sig_bytes);
621
622 verifying_key
623 .verify(signing_input.as_bytes(), &sig)
624 .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string()))
625}