this repo has no description
1use reqwest::Client;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use super::OAuthError;
/// OAuth client metadata document, (de)serialized as JSON via serde.
///
/// `ClientMetadataCache` fetches this document from the URL used as the
/// `client_id` and validates it before handing it out. Optional fields marked
/// `skip_serializing_if = "Option::is_none"` are omitted from serialized
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientMetadata {
    /// JSON key `client_id`. Has no serde default, so deserialization fails if
    /// the key is absent; when fetched, it must equal the URL it was fetched from.
    pub client_id: String,
    /// JSON key `client_name`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,
    /// JSON key `client_uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_uri: Option<String>,
    /// JSON key `logo_uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logo_uri: Option<String>,
    /// JSON key `redirect_uris`. Required when deserializing (no serde default);
    /// redirect URIs presented later are matched against this list exactly.
    pub redirect_uris: Vec<String>,
    /// JSON key `grant_types`. `#[serde(default)]` yields an empty `Vec` when the
    /// key is absent — NOT the value produced by this type's `Default` impl.
    #[serde(default)]
    pub grant_types: Vec<String>,
    /// JSON key `response_types`. Same serde-default caveat as `grant_types`.
    #[serde(default)]
    pub response_types: Vec<String>,
    /// JSON key `scope`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// JSON key `token_endpoint_auth_method`; `auth_method()` treats `None` as `"none"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_method: Option<String>,
    /// JSON key `dpop_bound_access_tokens`; `requires_dpop()` treats `None` as `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dpop_bound_access_tokens: Option<bool>,
    /// JSON key `jwks`: an inline key set, preferred over `jwks_uri` when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks: Option<serde_json::Value>,
    /// JSON key `jwks_uri`: where the key set is fetched from if `jwks` is absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<String>,
    /// JSON key `application_type`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub application_type: Option<String>,
}
34impl Default for ClientMetadata {
35 fn default() -> Self {
36 Self {
37 client_id: String::new(),
38 client_name: None,
39 client_uri: None,
40 logo_uri: None,
41 redirect_uris: Vec::new(),
42 grant_types: vec!["authorization_code".to_string()],
43 response_types: vec!["code".to_string()],
44 scope: None,
45 token_endpoint_auth_method: Some("none".to_string()),
46 dpop_bound_access_tokens: None,
47 jwks: None,
48 jwks_uri: None,
49 application_type: None,
50 }
51 }
52}
/// Async cache of fetched client metadata documents and JWKS documents.
///
/// Both maps sit behind `Arc<RwLock<..>>`, so cloning the cache is cheap and
/// every clone shares the same entries.
#[derive(Clone)]
pub struct ClientMetadataCache {
    /// Validated metadata documents, keyed by `client_id`.
    cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
    /// Fetched JWKS documents, keyed by `jwks_uri`.
    jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
    /// HTTP client used for metadata and JWKS fetches.
    http_client: Client,
    /// Freshness window, in whole seconds, applied to entries in both maps.
    cache_ttl_secs: u64,
}
/// A cached client metadata document plus the moment it was stored.
struct CachedMetadata {
    /// The validated metadata document.
    metadata: ClientMetadata,
    /// Insertion time; its elapsed age is compared against the cache TTL.
    cached_at: std::time::Instant,
}
/// A cached JWKS document plus the moment it was stored.
struct CachedJwks {
    /// The JWKS JSON as fetched (already checked to contain a `keys` array).
    jwks: serde_json::Value,
    /// Insertion time; its elapsed age is compared against the cache TTL.
    cached_at: std::time::Instant,
}
68impl ClientMetadataCache {
69 pub fn new(cache_ttl_secs: u64) -> Self {
70 Self {
71 cache: Arc::new(RwLock::new(HashMap::new())),
72 jwks_cache: Arc::new(RwLock::new(HashMap::new())),
73 http_client: Client::builder()
74 .timeout(std::time::Duration::from_secs(30))
75 .connect_timeout(std::time::Duration::from_secs(10))
76 .build()
77 .unwrap_or_else(|_| Client::new()),
78 cache_ttl_secs,
79 }
80 }
81 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
82 {
83 let cache = self.cache.read().await;
84 if let Some(cached) = cache.get(client_id) {
85 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
86 return Ok(cached.metadata.clone());
87 }
88 }
89 }
90 let metadata = self.fetch_metadata(client_id).await?;
91 {
92 let mut cache = self.cache.write().await;
93 cache.insert(
94 client_id.to_string(),
95 CachedMetadata {
96 metadata: metadata.clone(),
97 cached_at: std::time::Instant::now(),
98 },
99 );
100 }
101 Ok(metadata)
102 }
103 pub async fn get_jwks(&self, metadata: &ClientMetadata) -> Result<serde_json::Value, OAuthError> {
104 if let Some(jwks) = &metadata.jwks {
105 return Ok(jwks.clone());
106 }
107 let jwks_uri = metadata.jwks_uri.as_ref().ok_or_else(|| {
108 OAuthError::InvalidClient(
109 "Client using private_key_jwt must have jwks or jwks_uri".to_string(),
110 )
111 })?;
112 {
113 let cache = self.jwks_cache.read().await;
114 if let Some(cached) = cache.get(jwks_uri) {
115 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
116 return Ok(cached.jwks.clone());
117 }
118 }
119 }
120 let jwks = self.fetch_jwks(jwks_uri).await?;
121 {
122 let mut cache = self.jwks_cache.write().await;
123 cache.insert(
124 jwks_uri.clone(),
125 CachedJwks {
126 jwks: jwks.clone(),
127 cached_at: std::time::Instant::now(),
128 },
129 );
130 }
131 Ok(jwks)
132 }
133 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<serde_json::Value, OAuthError> {
134 if !jwks_uri.starts_with("https://") {
135 if !jwks_uri.starts_with("http://")
136 || (!jwks_uri.contains("localhost") && !jwks_uri.contains("127.0.0.1"))
137 {
138 return Err(OAuthError::InvalidClient(
139 "jwks_uri must use https (except for localhost)".to_string(),
140 ));
141 }
142 }
143 let response = self
144 .http_client
145 .get(jwks_uri)
146 .header("Accept", "application/json")
147 .send()
148 .await
149 .map_err(|e| {
150 OAuthError::InvalidClient(format!("Failed to fetch JWKS from {}: {}", jwks_uri, e))
151 })?;
152 if !response.status().is_success() {
153 return Err(OAuthError::InvalidClient(format!(
154 "Failed to fetch JWKS: HTTP {}",
155 response.status()
156 )));
157 }
158 let jwks: serde_json::Value = response
159 .json()
160 .await
161 .map_err(|e| OAuthError::InvalidClient(format!("Invalid JWKS JSON: {}", e)))?;
162 if jwks.get("keys").and_then(|k| k.as_array()).is_none() {
163 return Err(OAuthError::InvalidClient(
164 "JWKS must contain a 'keys' array".to_string(),
165 ));
166 }
167 Ok(jwks)
168 }
169 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
170 if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
171 return Err(OAuthError::InvalidClient(
172 "client_id must be a URL".to_string(),
173 ));
174 }
175 if client_id.starts_with("http://")
176 && !client_id.contains("localhost")
177 && !client_id.contains("127.0.0.1")
178 {
179 return Err(OAuthError::InvalidClient(
180 "Non-localhost client_id must use https".to_string(),
181 ));
182 }
183 let response = self
184 .http_client
185 .get(client_id)
186 .header("Accept", "application/json")
187 .send()
188 .await
189 .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?;
190 if !response.status().is_success() {
191 return Err(OAuthError::InvalidClient(format!(
192 "Failed to fetch client metadata: HTTP {}",
193 response.status()
194 )));
195 }
196 let mut metadata: ClientMetadata = response
197 .json()
198 .await
199 .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?;
200 if metadata.client_id.is_empty() {
201 metadata.client_id = client_id.to_string();
202 } else if metadata.client_id != client_id {
203 return Err(OAuthError::InvalidClient(
204 "client_id in metadata does not match request".to_string(),
205 ));
206 }
207 self.validate_metadata(&metadata)?;
208 Ok(metadata)
209 }
210 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> {
211 if metadata.redirect_uris.is_empty() {
212 return Err(OAuthError::InvalidClient(
213 "redirect_uris is required".to_string(),
214 ));
215 }
216 for uri in &metadata.redirect_uris {
217 self.validate_redirect_uri_format(uri)?;
218 }
219 if !metadata.grant_types.is_empty()
220 && !metadata.grant_types.contains(&"authorization_code".to_string())
221 {
222 return Err(OAuthError::InvalidClient(
223 "authorization_code grant type is required".to_string(),
224 ));
225 }
226 if !metadata.response_types.is_empty()
227 && !metadata.response_types.contains(&"code".to_string())
228 {
229 return Err(OAuthError::InvalidClient(
230 "code response type is required".to_string(),
231 ));
232 }
233 Ok(())
234 }
235 pub fn validate_redirect_uri(
236 &self,
237 metadata: &ClientMetadata,
238 redirect_uri: &str,
239 ) -> Result<(), OAuthError> {
240 if !metadata.redirect_uris.contains(&redirect_uri.to_string()) {
241 return Err(OAuthError::InvalidRequest(
242 "redirect_uri not registered for client".to_string(),
243 ));
244 }
245 Ok(())
246 }
247 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> {
248 if uri.contains('#') {
249 return Err(OAuthError::InvalidClient(
250 "redirect_uri must not contain a fragment".to_string(),
251 ));
252 }
253 let parsed = reqwest::Url::parse(uri).map_err(|_| {
254 OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri))
255 })?;
256 let scheme = parsed.scheme();
257 if scheme == "http" {
258 let host = parsed.host_str().unwrap_or("");
259 if host != "localhost" && host != "127.0.0.1" && host != "[::1]" {
260 return Err(OAuthError::InvalidClient(
261 "http redirect_uri only allowed for localhost".to_string(),
262 ));
263 }
264 } else if scheme == "https" {
265 } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') {
266 if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) {
267 return Err(OAuthError::InvalidClient(format!(
268 "Invalid redirect_uri scheme: {}",
269 scheme
270 )));
271 }
272 } else {
273 return Err(OAuthError::InvalidClient(format!(
274 "Invalid redirect_uri scheme: {}",
275 scheme
276 )));
277 }
278 Ok(())
279 }
280}
281impl ClientMetadata {
282 pub fn requires_dpop(&self) -> bool {
283 self.dpop_bound_access_tokens.unwrap_or(false)
284 }
285 pub fn auth_method(&self) -> &str {
286 self.token_endpoint_auth_method
287 .as_deref()
288 .unwrap_or("none")
289 }
290}
291pub async fn verify_client_auth(
292 cache: &ClientMetadataCache,
293 metadata: &ClientMetadata,
294 client_auth: &super::ClientAuth,
295) -> Result<(), OAuthError> {
296 let expected_method = metadata.auth_method();
297 match (expected_method, client_auth) {
298 ("none", super::ClientAuth::None) => Ok(()),
299 ("none", _) => Err(OAuthError::InvalidClient(
300 "Client is configured for no authentication, but credentials were provided".to_string(),
301 )),
302 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
303 verify_private_key_jwt_async(cache, metadata, client_assertion).await
304 }
305 ("private_key_jwt", _) => Err(OAuthError::InvalidClient(
306 "Client requires private_key_jwt authentication".to_string(),
307 )),
308 ("client_secret_post", super::ClientAuth::SecretPost { .. }) => {
309 Err(OAuthError::InvalidClient(
310 "client_secret_post is not supported for ATProto OAuth".to_string(),
311 ))
312 }
313 ("client_secret_basic", super::ClientAuth::SecretBasic { .. }) => {
314 Err(OAuthError::InvalidClient(
315 "client_secret_basic is not supported for ATProto OAuth".to_string(),
316 ))
317 }
318 (method, _) => Err(OAuthError::InvalidClient(format!(
319 "Unsupported or mismatched authentication method: {}",
320 method
321 ))),
322 }
323}
324async fn verify_private_key_jwt_async(
325 cache: &ClientMetadataCache,
326 metadata: &ClientMetadata,
327 client_assertion: &str,
328) -> Result<(), OAuthError> {
329 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
330 let parts: Vec<&str> = client_assertion.split('.').collect();
331 if parts.len() != 3 {
332 return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string()));
333 }
334 let header_bytes = URL_SAFE_NO_PAD
335 .decode(parts[0])
336 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?;
337 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
338 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?;
339 let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| {
340 OAuthError::InvalidClient("Missing alg in client_assertion".to_string())
341 })?;
342 if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") {
343 return Err(OAuthError::InvalidClient(format!(
344 "Unsupported client_assertion algorithm: {}",
345 alg
346 )));
347 }
348 let kid = header.get("kid").and_then(|k| k.as_str());
349 let payload_bytes = URL_SAFE_NO_PAD
350 .decode(parts[1])
351 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?;
352 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
353 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?;
354 let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| {
355 OAuthError::InvalidClient("Missing iss in client_assertion".to_string())
356 })?;
357 if iss != metadata.client_id {
358 return Err(OAuthError::InvalidClient(
359 "client_assertion iss does not match client_id".to_string(),
360 ));
361 }
362 let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| {
363 OAuthError::InvalidClient("Missing sub in client_assertion".to_string())
364 })?;
365 if sub != metadata.client_id {
366 return Err(OAuthError::InvalidClient(
367 "client_assertion sub does not match client_id".to_string(),
368 ));
369 }
370 let exp = payload.get("exp").and_then(|e| e.as_i64()).ok_or_else(|| {
371 OAuthError::InvalidClient("Missing exp in client_assertion".to_string())
372 })?;
373 let now = chrono::Utc::now().timestamp();
374 if exp < now {
375 return Err(OAuthError::InvalidClient("client_assertion has expired".to_string()));
376 }
377 let iat = payload.get("iat").and_then(|i| i.as_i64());
378 if let Some(iat) = iat {
379 if iat > now + 60 {
380 return Err(OAuthError::InvalidClient(
381 "client_assertion iat is in the future".to_string(),
382 ));
383 }
384 }
385 let jwks = cache.get_jwks(metadata).await?;
386 let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
387 OAuthError::InvalidClient("Invalid JWKS: missing keys array".to_string())
388 })?;
389 let matching_keys: Vec<&serde_json::Value> = if let Some(kid) = kid {
390 keys.iter()
391 .filter(|k| k.get("kid").and_then(|v| v.as_str()) == Some(kid))
392 .collect()
393 } else {
394 keys.iter().collect()
395 };
396 if matching_keys.is_empty() {
397 return Err(OAuthError::InvalidClient(
398 "No matching key found in client JWKS".to_string(),
399 ));
400 }
401 let signing_input = format!("{}.{}", parts[0], parts[1]);
402 let signature_bytes = URL_SAFE_NO_PAD
403 .decode(parts[2])
404 .map_err(|_| OAuthError::InvalidClient("Invalid signature encoding".to_string()))?;
405 for key in matching_keys {
406 let key_alg = key.get("alg").and_then(|a| a.as_str());
407 if key_alg.is_some() && key_alg != Some(alg) {
408 continue;
409 }
410 let kty = key.get("kty").and_then(|k| k.as_str()).unwrap_or("");
411 let verified = match (alg, kty) {
412 ("ES256", "EC") => verify_es256(key, &signing_input, &signature_bytes),
413 ("ES384", "EC") => verify_es384(key, &signing_input, &signature_bytes),
414 ("RS256" | "RS384" | "RS512", "RSA") => {
415 verify_rsa(alg, key, &signing_input, &signature_bytes)
416 }
417 ("EdDSA", "OKP") => verify_eddsa(key, &signing_input, &signature_bytes),
418 _ => continue,
419 };
420 if verified.is_ok() {
421 return Ok(());
422 }
423 }
424 Err(OAuthError::InvalidClient(
425 "client_assertion signature verification failed".to_string(),
426 ))
427}
428fn verify_es256(
429 key: &serde_json::Value,
430 signing_input: &str,
431 signature: &[u8],
432) -> Result<(), OAuthError> {
433 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
434 use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
435 use p256::EncodedPoint;
436 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
437 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
438 })?;
439 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
440 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
441 })?;
442 let x_bytes = URL_SAFE_NO_PAD.decode(x)
443 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
444 let y_bytes = URL_SAFE_NO_PAD.decode(y)
445 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
446 let mut point_bytes = vec![0x04];
447 point_bytes.extend_from_slice(&x_bytes);
448 point_bytes.extend_from_slice(&y_bytes);
449 let point = EncodedPoint::from_bytes(&point_bytes)
450 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
451 let verifying_key = VerifyingKey::from_encoded_point(&point)
452 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
453 let sig = Signature::from_slice(signature)
454 .map_err(|_| OAuthError::InvalidClient("Invalid ES256 signature format".to_string()))?;
455 verifying_key
456 .verify(signing_input.as_bytes(), &sig)
457 .map_err(|_| OAuthError::InvalidClient("ES256 signature verification failed".to_string()))
458}
459fn verify_es384(
460 key: &serde_json::Value,
461 signing_input: &str,
462 signature: &[u8],
463) -> Result<(), OAuthError> {
464 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
465 use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
466 use p384::EncodedPoint;
467 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
468 OAuthError::InvalidClient("Missing x coordinate in EC key".to_string())
469 })?;
470 let y = key.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
471 OAuthError::InvalidClient("Missing y coordinate in EC key".to_string())
472 })?;
473 let x_bytes = URL_SAFE_NO_PAD.decode(x)
474 .map_err(|_| OAuthError::InvalidClient("Invalid x coordinate encoding".to_string()))?;
475 let y_bytes = URL_SAFE_NO_PAD.decode(y)
476 .map_err(|_| OAuthError::InvalidClient("Invalid y coordinate encoding".to_string()))?;
477 let mut point_bytes = vec![0x04];
478 point_bytes.extend_from_slice(&x_bytes);
479 point_bytes.extend_from_slice(&y_bytes);
480 let point = EncodedPoint::from_bytes(&point_bytes)
481 .map_err(|_| OAuthError::InvalidClient("Invalid EC point".to_string()))?;
482 let verifying_key = VerifyingKey::from_encoded_point(&point)
483 .map_err(|_| OAuthError::InvalidClient("Invalid EC key".to_string()))?;
484 let sig = Signature::from_slice(signature)
485 .map_err(|_| OAuthError::InvalidClient("Invalid ES384 signature format".to_string()))?;
486 verifying_key
487 .verify(signing_input.as_bytes(), &sig)
488 .map_err(|_| OAuthError::InvalidClient("ES384 signature verification failed".to_string()))
489}
490fn verify_rsa(
491 _alg: &str,
492 _key: &serde_json::Value,
493 _signing_input: &str,
494 _signature: &[u8],
495) -> Result<(), OAuthError> {
496 Err(OAuthError::InvalidClient(
497 "RSA signature verification not yet supported - use EC keys".to_string(),
498 ))
499}
500fn verify_eddsa(
501 key: &serde_json::Value,
502 signing_input: &str,
503 signature: &[u8],
504) -> Result<(), OAuthError> {
505 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
506 use ed25519_dalek::{Signature, Verifier, VerifyingKey};
507 let crv = key.get("crv").and_then(|c| c.as_str()).unwrap_or("");
508 if crv != "Ed25519" {
509 return Err(OAuthError::InvalidClient(format!(
510 "Unsupported EdDSA curve: {}",
511 crv
512 )));
513 }
514 let x = key.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
515 OAuthError::InvalidClient("Missing x in OKP key".to_string())
516 })?;
517 let x_bytes = URL_SAFE_NO_PAD.decode(x)
518 .map_err(|_| OAuthError::InvalidClient("Invalid x encoding".to_string()))?;
519 let key_bytes: [u8; 32] = x_bytes.try_into()
520 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key length".to_string()))?;
521 let verifying_key = VerifyingKey::from_bytes(&key_bytes)
522 .map_err(|_| OAuthError::InvalidClient("Invalid Ed25519 key".to_string()))?;
523 let sig_bytes: [u8; 64] = signature.try_into()
524 .map_err(|_| OAuthError::InvalidClient("Invalid EdDSA signature length".to_string()))?;
525 let sig = Signature::from_bytes(&sig_bytes);
526 verifying_key
527 .verify(signing_input.as_bytes(), &sig)
528 .map_err(|_| OAuthError::InvalidClient("EdDSA signature verification failed".to_string()))
529}