// NOTE: "this repo has no description" is a repository page header captured with this file; it is not part of the code.
1use crate::oauth::{
2 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
3 client::ClientMetadataCache,
4 db,
5 scopes::{ParsedScope, parse_scope},
6};
7use crate::state::{AppState, RateLimitKind};
8use axum::body::Bytes;
9use axum::{Json, extract::State, http::HeaderMap};
10use chrono::{Duration, Utc};
11use serde::{Deserialize, Serialize};
12
/// Lifetime of a pushed authorization request, in seconds (ten minutes).
///
/// Used both to compute the stored request's expiry timestamp and as the
/// `expires_in` value returned to the client.
const PAR_EXPIRY_SECONDS: i64 = 10 * 60;
14
15#[derive(Debug, Deserialize)]
16pub struct ParRequest {
17 pub response_type: String,
18 pub client_id: String,
19 pub redirect_uri: String,
20 #[serde(default)]
21 pub scope: Option<String>,
22 #[serde(default)]
23 pub state: Option<String>,
24 #[serde(default)]
25 pub code_challenge: Option<String>,
26 #[serde(default)]
27 pub code_challenge_method: Option<String>,
28 #[serde(default)]
29 pub response_mode: Option<String>,
30 #[serde(default)]
31 pub login_hint: Option<String>,
32 #[serde(default)]
33 pub dpop_jkt: Option<String>,
34 #[serde(default)]
35 pub client_secret: Option<String>,
36 #[serde(default)]
37 pub client_assertion: Option<String>,
38 #[serde(default)]
39 pub client_assertion_type: Option<String>,
40}
41
42#[derive(Debug, Serialize)]
43pub struct ParResponse {
44 pub request_uri: String,
45 pub expires_in: u64,
46}
47
48pub async fn pushed_authorization_request(
49 State(state): State<AppState>,
50 headers: HeaderMap,
51 body: Bytes,
52) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
53 let content_type = headers
54 .get("content-type")
55 .and_then(|v| v.to_str().ok())
56 .unwrap_or("");
57 let request: ParRequest = if content_type.starts_with("application/json") {
58 serde_json::from_slice(&body)
59 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))?
60 } else if content_type.starts_with("application/x-www-form-urlencoded") {
61 let parsed: ParRequest = serde_urlencoded::from_bytes(&body)
62 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))?;
63 tracing::info!(login_hint = ?parsed.login_hint, "PAR request received (form)");
64 parsed
65 } else {
66 return Err(OAuthError::InvalidRequest(
67 "Content-Type must be application/json or application/x-www-form-urlencoded"
68 .to_string(),
69 ));
70 };
71 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
72 if !state
73 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
74 .await
75 {
76 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
77 return Err(OAuthError::RateLimited);
78 }
79 if request.response_type != "code" {
80 return Err(OAuthError::InvalidRequest(
81 "response_type must be 'code'".to_string(),
82 ));
83 }
84 let code_challenge = request
85 .code_challenge
86 .as_ref()
87 .filter(|s| !s.is_empty())
88 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?;
89 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
90 if code_challenge_method != "S256" {
91 return Err(OAuthError::InvalidRequest(
92 "code_challenge_method must be 'S256'".to_string(),
93 ));
94 }
95 let client_cache = ClientMetadataCache::new(3600);
96 let client_metadata = client_cache.get(&request.client_id).await?;
97 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
98 let client_auth = determine_client_auth(&request)?;
99 let validated_scope = validate_scope(&request.scope, &client_metadata)?;
100 let request_id = RequestId::generate();
101 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
102 let response_mode = match request.response_mode.as_deref() {
103 Some("fragment") => Some("fragment".to_string()),
104 Some("query") | None => None,
105 Some(mode) => {
106 return Err(OAuthError::InvalidRequest(format!(
107 "Unsupported response_mode: {}",
108 mode
109 )));
110 }
111 };
112 let parameters = AuthorizationRequestParameters {
113 response_type: request.response_type,
114 client_id: request.client_id.clone(),
115 redirect_uri: request.redirect_uri,
116 scope: validated_scope,
117 state: request.state,
118 code_challenge: code_challenge.clone(),
119 code_challenge_method: code_challenge_method.to_string(),
120 response_mode,
121 login_hint: request.login_hint,
122 dpop_jkt: request.dpop_jkt,
123 extra: None,
124 };
125 let request_data = RequestData {
126 client_id: request.client_id,
127 client_auth: Some(client_auth),
128 parameters,
129 expires_at,
130 did: None,
131 device_id: None,
132 code: None,
133 controller_did: None,
134 };
135 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?;
136 tokio::spawn({
137 let pool = state.db.clone();
138 async move {
139 if let Err(e) = db::delete_expired_authorization_requests(&pool).await {
140 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e);
141 }
142 }
143 });
144 Ok((
145 axum::http::StatusCode::CREATED,
146 Json(ParResponse {
147 request_uri: request_id.0,
148 expires_in: PAR_EXPIRY_SECONDS as u64,
149 }),
150 ))
151}
152
153fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
154 if let (Some(assertion), Some(assertion_type)) =
155 (&request.client_assertion, &request.client_assertion_type)
156 {
157 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
158 return Err(OAuthError::InvalidRequest(
159 "Unsupported client_assertion_type".to_string(),
160 ));
161 }
162 return Ok(ClientAuth::PrivateKeyJwt {
163 client_assertion: assertion.clone(),
164 });
165 }
166 if let Some(secret) = &request.client_secret {
167 return Ok(ClientAuth::SecretPost {
168 client_secret: secret.clone(),
169 });
170 }
171 Ok(ClientAuth::None)
172}
173
174fn validate_scope(
175 requested_scope: &Option<String>,
176 client_metadata: &crate::oauth::client::ClientMetadata,
177) -> Result<Option<String>, OAuthError> {
178 let scope_str = match requested_scope {
179 Some(s) if !s.is_empty() => s,
180 _ => return Ok(Some("atproto".to_string())),
181 };
182 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
183 if requested_scopes.is_empty() {
184 return Ok(Some("atproto".to_string()));
185 }
186 let mut has_transition = false;
187 let mut has_granular = false;
188
189 for scope in &requested_scopes {
190 let parsed = parse_scope(scope);
191 match &parsed {
192 ParsedScope::Unknown(_) => {
193 return Err(OAuthError::InvalidScope(format!(
194 "Unsupported scope: {}",
195 scope
196 )));
197 }
198 ParsedScope::TransitionGeneric
199 | ParsedScope::TransitionChat
200 | ParsedScope::TransitionEmail => {
201 has_transition = true;
202 }
203 ParsedScope::Repo(_)
204 | ParsedScope::Blob(_)
205 | ParsedScope::Rpc(_)
206 | ParsedScope::Account(_)
207 | ParsedScope::Identity(_)
208 | ParsedScope::Include(_) => {
209 has_granular = true;
210 }
211 ParsedScope::Atproto => {}
212 }
213 }
214
215 if has_transition && has_granular {
216 return Err(OAuthError::InvalidScope(
217 "Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string()
218 ));
219 }
220
221 if let Some(client_scope) = &client_metadata.scope {
222 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
223 for scope in &requested_scopes {
224 if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) {
225 return Err(OAuthError::InvalidScope(format!(
226 "Scope '{}' not registered for this client",
227 scope
228 )));
229 }
230 }
231 }
232 Ok(Some(requested_scopes.join(" ")))
233}
234
/// Reports whether a scope registered by the client covers a requested scope.
///
/// A requested scope is covered when it is identical to the client scope, or
/// when both share the same resource type (the text before the first `:`,
/// ignoring any `?query` suffix) and the client scope, with its query part
/// removed, contains a `*` wildcard.
fn scope_matches(client_scope: &str, requested_scope: &str) -> bool {
    // Everything before the first '?', or the whole scope if there is none.
    fn without_query(scope: &str) -> &str {
        scope.split_once('?').map_or(scope, |(head, _)| head)
    }

    // Everything before the first ':' of the query-less scope.
    fn resource_type(scope: &str) -> &str {
        let base = without_query(scope);
        base.split_once(':').map_or(base, |(kind, _)| kind)
    }

    client_scope == requested_scope
        || (resource_type(client_scope) == resource_type(requested_scope)
            && without_query(client_scope).contains('*'))
}