This repository has no description.
1use reqwest::Client; 2use serde::{Deserialize, Serialize}; 3use std::collections::HashMap; 4use std::sync::Arc; 5use tokio::sync::RwLock; 6 7use super::OAuthError; 8 9#[derive(Debug, Clone, Serialize, Deserialize)] 10pub struct ClientMetadata { 11 pub client_id: String, 12 #[serde(skip_serializing_if = "Option::is_none")] 13 pub client_name: Option<String>, 14 #[serde(skip_serializing_if = "Option::is_none")] 15 pub client_uri: Option<String>, 16 #[serde(skip_serializing_if = "Option::is_none")] 17 pub logo_uri: Option<String>, 18 pub redirect_uris: Vec<String>, 19 #[serde(default)] 20 pub grant_types: Vec<String>, 21 #[serde(default)] 22 pub response_types: Vec<String>, 23 #[serde(skip_serializing_if = "Option::is_none")] 24 pub scope: Option<String>, 25 #[serde(skip_serializing_if = "Option::is_none")] 26 pub token_endpoint_auth_method: Option<String>, 27 #[serde(skip_serializing_if = "Option::is_none")] 28 pub dpop_bound_access_tokens: Option<bool>, 29 #[serde(skip_serializing_if = "Option::is_none")] 30 pub jwks: Option<serde_json::Value>, 31 #[serde(skip_serializing_if = "Option::is_none")] 32 pub jwks_uri: Option<String>, 33 #[serde(skip_serializing_if = "Option::is_none")] 34 pub application_type: Option<String>, 35} 36 37impl Default for ClientMetadata { 38 fn default() -> Self { 39 Self { 40 client_id: String::new(), 41 client_name: None, 42 client_uri: None, 43 logo_uri: None, 44 redirect_uris: Vec::new(), 45 grant_types: vec!["authorization_code".to_string()], 46 response_types: vec!["code".to_string()], 47 scope: None, 48 token_endpoint_auth_method: Some("none".to_string()), 49 dpop_bound_access_tokens: None, 50 jwks: None, 51 jwks_uri: None, 52 application_type: None, 53 } 54 } 55} 56 57#[derive(Clone)] 58pub struct ClientMetadataCache { 59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>, 60 http_client: Client, 61 cache_ttl_secs: u64, 62} 63 64struct CachedMetadata { 65 metadata: ClientMetadata, 66 cached_at: std::time::Instant, 67} 68 69impl 
ClientMetadataCache { 70 pub fn new(cache_ttl_secs: u64) -> Self { 71 Self { 72 cache: Arc::new(RwLock::new(HashMap::new())), 73 http_client: Client::new(), 74 cache_ttl_secs, 75 } 76 } 77 78 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 79 { 80 let cache = self.cache.read().await; 81 if let Some(cached) = cache.get(client_id) { 82 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs { 83 return Ok(cached.metadata.clone()); 84 } 85 } 86 } 87 88 let metadata = self.fetch_metadata(client_id).await?; 89 90 { 91 let mut cache = self.cache.write().await; 92 cache.insert( 93 client_id.to_string(), 94 CachedMetadata { 95 metadata: metadata.clone(), 96 cached_at: std::time::Instant::now(), 97 }, 98 ); 99 } 100 101 Ok(metadata) 102 } 103 104 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> { 105 if !client_id.starts_with("http://") && !client_id.starts_with("https://") { 106 return Err(OAuthError::InvalidClient( 107 "client_id must be a URL".to_string(), 108 )); 109 } 110 111 if client_id.starts_with("http://") 112 && !client_id.contains("localhost") 113 && !client_id.contains("127.0.0.1") 114 { 115 return Err(OAuthError::InvalidClient( 116 "Non-localhost client_id must use https".to_string(), 117 )); 118 } 119 120 let response = self 121 .http_client 122 .get(client_id) 123 .header("Accept", "application/json") 124 .send() 125 .await 126 .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?; 127 128 if !response.status().is_success() { 129 return Err(OAuthError::InvalidClient(format!( 130 "Failed to fetch client metadata: HTTP {}", 131 response.status() 132 ))); 133 } 134 135 let mut metadata: ClientMetadata = response 136 .json() 137 .await 138 .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?; 139 140 if metadata.client_id.is_empty() { 141 metadata.client_id = client_id.to_string(); 142 } else if metadata.client_id 
!= client_id { 143 return Err(OAuthError::InvalidClient( 144 "client_id in metadata does not match request".to_string(), 145 )); 146 } 147 148 self.validate_metadata(&metadata)?; 149 150 Ok(metadata) 151 } 152 153 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> { 154 if metadata.redirect_uris.is_empty() { 155 return Err(OAuthError::InvalidClient( 156 "redirect_uris is required".to_string(), 157 )); 158 } 159 160 for uri in &metadata.redirect_uris { 161 self.validate_redirect_uri_format(uri)?; 162 } 163 164 if !metadata.grant_types.is_empty() 165 && !metadata.grant_types.contains(&"authorization_code".to_string()) 166 { 167 return Err(OAuthError::InvalidClient( 168 "authorization_code grant type is required".to_string(), 169 )); 170 } 171 172 if !metadata.response_types.is_empty() 173 && !metadata.response_types.contains(&"code".to_string()) 174 { 175 return Err(OAuthError::InvalidClient( 176 "code response type is required".to_string(), 177 )); 178 } 179 180 Ok(()) 181 } 182 183 pub fn validate_redirect_uri( 184 &self, 185 metadata: &ClientMetadata, 186 redirect_uri: &str, 187 ) -> Result<(), OAuthError> { 188 if !metadata.redirect_uris.contains(&redirect_uri.to_string()) { 189 return Err(OAuthError::InvalidRequest( 190 "redirect_uri not registered for client".to_string(), 191 )); 192 } 193 Ok(()) 194 } 195 196 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> { 197 if uri.contains('#') { 198 return Err(OAuthError::InvalidClient( 199 "redirect_uri must not contain a fragment".to_string(), 200 )); 201 } 202 203 let parsed = reqwest::Url::parse(uri).map_err(|_| { 204 OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri)) 205 })?; 206 207 let scheme = parsed.scheme(); 208 209 if scheme == "http" { 210 let host = parsed.host_str().unwrap_or(""); 211 if host != "localhost" && host != "127.0.0.1" && host != "[::1]" { 212 return Err(OAuthError::InvalidClient( 213 "http redirect_uri only allowed for 
localhost".to_string(), 214 )); 215 } 216 } else if scheme == "https" { 217 } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') { 218 if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) { 219 return Err(OAuthError::InvalidClient(format!( 220 "Invalid redirect_uri scheme: {}", 221 scheme 222 ))); 223 } 224 } else { 225 return Err(OAuthError::InvalidClient(format!( 226 "Invalid redirect_uri scheme: {}", 227 scheme 228 ))); 229 } 230 231 Ok(()) 232 } 233} 234 235impl ClientMetadata { 236 pub fn requires_dpop(&self) -> bool { 237 self.dpop_bound_access_tokens.unwrap_or(false) 238 } 239 240 pub fn auth_method(&self) -> &str { 241 self.token_endpoint_auth_method 242 .as_deref() 243 .unwrap_or("none") 244 } 245} 246 247pub fn verify_client_auth( 248 metadata: &ClientMetadata, 249 client_auth: &super::ClientAuth, 250) -> Result<(), OAuthError> { 251 let expected_method = metadata.auth_method(); 252 253 match (expected_method, client_auth) { 254 ("none", super::ClientAuth::None) => Ok(()), 255 256 ("none", _) => Err(OAuthError::InvalidClient( 257 "Client is configured for no authentication, but credentials were provided".to_string(), 258 )), 259 260 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => { 261 verify_private_key_jwt(metadata, client_assertion) 262 } 263 264 ("private_key_jwt", _) => Err(OAuthError::InvalidClient( 265 "Client requires private_key_jwt authentication".to_string(), 266 )), 267 268 ("client_secret_post", super::ClientAuth::SecretPost { .. }) => { 269 Err(OAuthError::InvalidClient( 270 "client_secret_post is not supported for ATProto OAuth".to_string(), 271 )) 272 } 273 274 ("client_secret_basic", super::ClientAuth::SecretBasic { .. 
}) => { 275 Err(OAuthError::InvalidClient( 276 "client_secret_basic is not supported for ATProto OAuth".to_string(), 277 )) 278 } 279 280 (method, _) => Err(OAuthError::InvalidClient(format!( 281 "Unsupported or mismatched authentication method: {}", 282 method 283 ))), 284 } 285} 286 287fn verify_private_key_jwt( 288 metadata: &ClientMetadata, 289 client_assertion: &str, 290) -> Result<(), OAuthError> { 291 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 292 293 let parts: Vec<&str> = client_assertion.split('.').collect(); 294 if parts.len() != 3 { 295 return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string())); 296 } 297 298 let header_bytes = URL_SAFE_NO_PAD 299 .decode(parts[0]) 300 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?; 301 let header: serde_json::Value = serde_json::from_slice(&header_bytes) 302 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?; 303 304 let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| { 305 OAuthError::InvalidClient("Missing alg in client_assertion".to_string()) 306 })?; 307 308 if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") { 309 return Err(OAuthError::InvalidClient(format!( 310 "Unsupported client_assertion algorithm: {}", 311 alg 312 ))); 313 } 314 315 let payload_bytes = URL_SAFE_NO_PAD 316 .decode(parts[1]) 317 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?; 318 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) 319 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?; 320 321 let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| { 322 OAuthError::InvalidClient("Missing iss in client_assertion".to_string()) 323 })?; 324 if iss != metadata.client_id { 325 return Err(OAuthError::InvalidClient( 326 "client_assertion iss does not match 
client_id".to_string(), 327 )); 328 } 329 330 let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| { 331 OAuthError::InvalidClient("Missing sub in client_assertion".to_string()) 332 })?; 333 if sub != metadata.client_id { 334 return Err(OAuthError::InvalidClient( 335 "client_assertion sub does not match client_id".to_string(), 336 )); 337 } 338 339 let exp = payload.get("exp").and_then(|e| e.as_i64()).ok_or_else(|| { 340 OAuthError::InvalidClient("Missing exp in client_assertion".to_string()) 341 })?; 342 let now = chrono::Utc::now().timestamp(); 343 if exp < now { 344 return Err(OAuthError::InvalidClient("client_assertion has expired".to_string())); 345 } 346 347 let iat = payload.get("iat").and_then(|i| i.as_i64()); 348 if let Some(iat) = iat { 349 if iat > now + 60 { 350 return Err(OAuthError::InvalidClient( 351 "client_assertion iat is in the future".to_string(), 352 )); 353 } 354 } 355 356 if metadata.jwks.is_none() && metadata.jwks_uri.is_none() { 357 return Err(OAuthError::InvalidClient( 358 "Client using private_key_jwt must have jwks or jwks_uri".to_string(), 359 )); 360 } 361 362 Err(OAuthError::InvalidClient( 363 "private_key_jwt signature verification not yet implemented - use 'none' auth method".to_string(), 364 )) 365}