this repo has no description
1use reqwest::Client;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
7use super::OAuthError;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ClientMetadata {
11 pub client_id: String,
12 #[serde(skip_serializing_if = "Option::is_none")]
13 pub client_name: Option<String>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub client_uri: Option<String>,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub logo_uri: Option<String>,
18 pub redirect_uris: Vec<String>,
19 #[serde(default)]
20 pub grant_types: Vec<String>,
21 #[serde(default)]
22 pub response_types: Vec<String>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub scope: Option<String>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub token_endpoint_auth_method: Option<String>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub dpop_bound_access_tokens: Option<bool>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub jwks: Option<serde_json::Value>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub jwks_uri: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub application_type: Option<String>,
35}
36
37impl Default for ClientMetadata {
38 fn default() -> Self {
39 Self {
40 client_id: String::new(),
41 client_name: None,
42 client_uri: None,
43 logo_uri: None,
44 redirect_uris: Vec::new(),
45 grant_types: vec!["authorization_code".to_string()],
46 response_types: vec!["code".to_string()],
47 scope: None,
48 token_endpoint_auth_method: Some("none".to_string()),
49 dpop_bound_access_tokens: None,
50 jwks: None,
51 jwks_uri: None,
52 application_type: None,
53 }
54 }
55}
56
57#[derive(Clone)]
58pub struct ClientMetadataCache {
59 cache: Arc<RwLock<HashMap<String, CachedMetadata>>>,
60 http_client: Client,
61 cache_ttl_secs: u64,
62}
63
64struct CachedMetadata {
65 metadata: ClientMetadata,
66 cached_at: std::time::Instant,
67}
68
69impl ClientMetadataCache {
70 pub fn new(cache_ttl_secs: u64) -> Self {
71 Self {
72 cache: Arc::new(RwLock::new(HashMap::new())),
73 http_client: Client::new(),
74 cache_ttl_secs,
75 }
76 }
77
78 pub async fn get(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
79 {
80 let cache = self.cache.read().await;
81 if let Some(cached) = cache.get(client_id) {
82 if cached.cached_at.elapsed().as_secs() < self.cache_ttl_secs {
83 return Ok(cached.metadata.clone());
84 }
85 }
86 }
87
88 let metadata = self.fetch_metadata(client_id).await?;
89
90 {
91 let mut cache = self.cache.write().await;
92 cache.insert(
93 client_id.to_string(),
94 CachedMetadata {
95 metadata: metadata.clone(),
96 cached_at: std::time::Instant::now(),
97 },
98 );
99 }
100
101 Ok(metadata)
102 }
103
104 async fn fetch_metadata(&self, client_id: &str) -> Result<ClientMetadata, OAuthError> {
105 if !client_id.starts_with("http://") && !client_id.starts_with("https://") {
106 return Err(OAuthError::InvalidClient(
107 "client_id must be a URL".to_string(),
108 ));
109 }
110
111 if client_id.starts_with("http://")
112 && !client_id.contains("localhost")
113 && !client_id.contains("127.0.0.1")
114 {
115 return Err(OAuthError::InvalidClient(
116 "Non-localhost client_id must use https".to_string(),
117 ));
118 }
119
120 let response = self
121 .http_client
122 .get(client_id)
123 .header("Accept", "application/json")
124 .send()
125 .await
126 .map_err(|e| OAuthError::InvalidClient(format!("Failed to fetch client metadata: {}", e)))?;
127
128 if !response.status().is_success() {
129 return Err(OAuthError::InvalidClient(format!(
130 "Failed to fetch client metadata: HTTP {}",
131 response.status()
132 )));
133 }
134
135 let mut metadata: ClientMetadata = response
136 .json()
137 .await
138 .map_err(|e| OAuthError::InvalidClient(format!("Invalid client metadata JSON: {}", e)))?;
139
140 if metadata.client_id.is_empty() {
141 metadata.client_id = client_id.to_string();
142 } else if metadata.client_id != client_id {
143 return Err(OAuthError::InvalidClient(
144 "client_id in metadata does not match request".to_string(),
145 ));
146 }
147
148 self.validate_metadata(&metadata)?;
149
150 Ok(metadata)
151 }
152
153 fn validate_metadata(&self, metadata: &ClientMetadata) -> Result<(), OAuthError> {
154 if metadata.redirect_uris.is_empty() {
155 return Err(OAuthError::InvalidClient(
156 "redirect_uris is required".to_string(),
157 ));
158 }
159
160 for uri in &metadata.redirect_uris {
161 self.validate_redirect_uri_format(uri)?;
162 }
163
164 if !metadata.grant_types.is_empty()
165 && !metadata.grant_types.contains(&"authorization_code".to_string())
166 {
167 return Err(OAuthError::InvalidClient(
168 "authorization_code grant type is required".to_string(),
169 ));
170 }
171
172 if !metadata.response_types.is_empty()
173 && !metadata.response_types.contains(&"code".to_string())
174 {
175 return Err(OAuthError::InvalidClient(
176 "code response type is required".to_string(),
177 ));
178 }
179
180 Ok(())
181 }
182
183 pub fn validate_redirect_uri(
184 &self,
185 metadata: &ClientMetadata,
186 redirect_uri: &str,
187 ) -> Result<(), OAuthError> {
188 if !metadata.redirect_uris.contains(&redirect_uri.to_string()) {
189 return Err(OAuthError::InvalidRequest(
190 "redirect_uri not registered for client".to_string(),
191 ));
192 }
193 Ok(())
194 }
195
196 fn validate_redirect_uri_format(&self, uri: &str) -> Result<(), OAuthError> {
197 if uri.contains('#') {
198 return Err(OAuthError::InvalidClient(
199 "redirect_uri must not contain a fragment".to_string(),
200 ));
201 }
202
203 let parsed = reqwest::Url::parse(uri).map_err(|_| {
204 OAuthError::InvalidClient(format!("Invalid redirect_uri: {}", uri))
205 })?;
206
207 let scheme = parsed.scheme();
208
209 if scheme == "http" {
210 let host = parsed.host_str().unwrap_or("");
211 if host != "localhost" && host != "127.0.0.1" && host != "[::1]" {
212 return Err(OAuthError::InvalidClient(
213 "http redirect_uri only allowed for localhost".to_string(),
214 ));
215 }
216 } else if scheme == "https" {
217 } else if scheme.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '+' || c == '.' || c == '-') {
218 if !scheme.chars().next().map(|c| c.is_ascii_lowercase()).unwrap_or(false) {
219 return Err(OAuthError::InvalidClient(format!(
220 "Invalid redirect_uri scheme: {}",
221 scheme
222 )));
223 }
224 } else {
225 return Err(OAuthError::InvalidClient(format!(
226 "Invalid redirect_uri scheme: {}",
227 scheme
228 )));
229 }
230
231 Ok(())
232 }
233}
234
235impl ClientMetadata {
236 pub fn requires_dpop(&self) -> bool {
237 self.dpop_bound_access_tokens.unwrap_or(false)
238 }
239
240 pub fn auth_method(&self) -> &str {
241 self.token_endpoint_auth_method
242 .as_deref()
243 .unwrap_or("none")
244 }
245}
246
247pub fn verify_client_auth(
248 metadata: &ClientMetadata,
249 client_auth: &super::ClientAuth,
250) -> Result<(), OAuthError> {
251 let expected_method = metadata.auth_method();
252
253 match (expected_method, client_auth) {
254 ("none", super::ClientAuth::None) => Ok(()),
255
256 ("none", _) => Err(OAuthError::InvalidClient(
257 "Client is configured for no authentication, but credentials were provided".to_string(),
258 )),
259
260 ("private_key_jwt", super::ClientAuth::PrivateKeyJwt { client_assertion }) => {
261 verify_private_key_jwt(metadata, client_assertion)
262 }
263
264 ("private_key_jwt", _) => Err(OAuthError::InvalidClient(
265 "Client requires private_key_jwt authentication".to_string(),
266 )),
267
268 ("client_secret_post", super::ClientAuth::SecretPost { .. }) => {
269 Err(OAuthError::InvalidClient(
270 "client_secret_post is not supported for ATProto OAuth".to_string(),
271 ))
272 }
273
274 ("client_secret_basic", super::ClientAuth::SecretBasic { .. }) => {
275 Err(OAuthError::InvalidClient(
276 "client_secret_basic is not supported for ATProto OAuth".to_string(),
277 ))
278 }
279
280 (method, _) => Err(OAuthError::InvalidClient(format!(
281 "Unsupported or mismatched authentication method: {}",
282 method
283 ))),
284 }
285}
286
287fn verify_private_key_jwt(
288 metadata: &ClientMetadata,
289 client_assertion: &str,
290) -> Result<(), OAuthError> {
291 use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
292
293 let parts: Vec<&str> = client_assertion.split('.').collect();
294 if parts.len() != 3 {
295 return Err(OAuthError::InvalidClient("Invalid client_assertion format".to_string()));
296 }
297
298 let header_bytes = URL_SAFE_NO_PAD
299 .decode(parts[0])
300 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header encoding".to_string()))?;
301 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
302 .map_err(|_| OAuthError::InvalidClient("Invalid assertion header JSON".to_string()))?;
303
304 let alg = header.get("alg").and_then(|a| a.as_str()).ok_or_else(|| {
305 OAuthError::InvalidClient("Missing alg in client_assertion".to_string())
306 })?;
307
308 if !matches!(alg, "ES256" | "ES384" | "RS256" | "RS384" | "RS512" | "EdDSA") {
309 return Err(OAuthError::InvalidClient(format!(
310 "Unsupported client_assertion algorithm: {}",
311 alg
312 )));
313 }
314
315 let payload_bytes = URL_SAFE_NO_PAD
316 .decode(parts[1])
317 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload encoding".to_string()))?;
318 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
319 .map_err(|_| OAuthError::InvalidClient("Invalid assertion payload JSON".to_string()))?;
320
321 let iss = payload.get("iss").and_then(|i| i.as_str()).ok_or_else(|| {
322 OAuthError::InvalidClient("Missing iss in client_assertion".to_string())
323 })?;
324 if iss != metadata.client_id {
325 return Err(OAuthError::InvalidClient(
326 "client_assertion iss does not match client_id".to_string(),
327 ));
328 }
329
330 let sub = payload.get("sub").and_then(|s| s.as_str()).ok_or_else(|| {
331 OAuthError::InvalidClient("Missing sub in client_assertion".to_string())
332 })?;
333 if sub != metadata.client_id {
334 return Err(OAuthError::InvalidClient(
335 "client_assertion sub does not match client_id".to_string(),
336 ));
337 }
338
339 let exp = payload.get("exp").and_then(|e| e.as_i64()).ok_or_else(|| {
340 OAuthError::InvalidClient("Missing exp in client_assertion".to_string())
341 })?;
342 let now = chrono::Utc::now().timestamp();
343 if exp < now {
344 return Err(OAuthError::InvalidClient("client_assertion has expired".to_string()));
345 }
346
347 let iat = payload.get("iat").and_then(|i| i.as_i64());
348 if let Some(iat) = iat {
349 if iat > now + 60 {
350 return Err(OAuthError::InvalidClient(
351 "client_assertion iat is in the future".to_string(),
352 ));
353 }
354 }
355
356 if metadata.jwks.is_none() && metadata.jwks_uri.is_none() {
357 return Err(OAuthError::InvalidClient(
358 "Client using private_key_jwt must have jwks or jwks_uri".to_string(),
359 ));
360 }
361
362 Err(OAuthError::InvalidClient(
363 "private_key_jwt signature verification not yet implemented - use 'none' auth method".to_string(),
364 ))
365}