this repo has no description
1use crate::oauth::{
2 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
3 client::ClientMetadataCache,
4 db,
5 scopes::{ParsedScope, parse_scope},
6};
7use crate::state::{AppState, RateLimitKind};
8use axum::body::Bytes;
9use axum::{Json, extract::State, http::HeaderMap};
10use chrono::{Duration, Utc};
11use serde::{Deserialize, Serialize};
12
/// Lifetime of a stored pushed authorization request, in seconds (ten minutes).
const PAR_EXPIRY_SECONDS: i64 = 10 * 60;
14
15#[derive(Debug, Deserialize)]
16pub struct ParRequest {
17 pub response_type: String,
18 pub client_id: String,
19 pub redirect_uri: String,
20 #[serde(default)]
21 pub scope: Option<String>,
22 #[serde(default)]
23 pub state: Option<String>,
24 #[serde(default)]
25 pub code_challenge: Option<String>,
26 #[serde(default)]
27 pub code_challenge_method: Option<String>,
28 #[serde(default)]
29 pub response_mode: Option<String>,
30 #[serde(default)]
31 pub login_hint: Option<String>,
32 #[serde(default)]
33 pub dpop_jkt: Option<String>,
34 #[serde(default)]
35 pub client_secret: Option<String>,
36 #[serde(default)]
37 pub client_assertion: Option<String>,
38 #[serde(default)]
39 pub client_assertion_type: Option<String>,
40}
41
42#[derive(Debug, Serialize)]
43pub struct ParResponse {
44 pub request_uri: String,
45 pub expires_in: u64,
46}
47
48pub async fn pushed_authorization_request(
49 State(state): State<AppState>,
50 headers: HeaderMap,
51 body: Bytes,
52) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
53 let content_type = headers
54 .get("content-type")
55 .and_then(|v| v.to_str().ok())
56 .unwrap_or("");
57 let request: ParRequest = if content_type.starts_with("application/json") {
58 serde_json::from_slice(&body)
59 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))?
60 } else if content_type.starts_with("application/x-www-form-urlencoded") {
61 serde_urlencoded::from_bytes(&body)
62 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))?
63 } else {
64 return Err(OAuthError::InvalidRequest(
65 "Content-Type must be application/json or application/x-www-form-urlencoded"
66 .to_string(),
67 ));
68 };
69 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
70 if !state
71 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
72 .await
73 {
74 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
75 return Err(OAuthError::RateLimited);
76 }
77 if request.response_type != "code" {
78 return Err(OAuthError::InvalidRequest(
79 "response_type must be 'code'".to_string(),
80 ));
81 }
82 let code_challenge = request
83 .code_challenge
84 .as_ref()
85 .filter(|s| !s.is_empty())
86 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?;
87 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
88 if code_challenge_method != "S256" {
89 return Err(OAuthError::InvalidRequest(
90 "code_challenge_method must be 'S256'".to_string(),
91 ));
92 }
93 let client_cache = ClientMetadataCache::new(3600);
94 let client_metadata = client_cache.get(&request.client_id).await?;
95 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
96 let client_auth = determine_client_auth(&request)?;
97 let validated_scope = validate_scope(&request.scope, &client_metadata)?;
98 let request_id = RequestId::generate();
99 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
100 let response_mode = match request.response_mode.as_deref() {
101 Some("fragment") => Some("fragment".to_string()),
102 Some("query") | None => None,
103 Some(mode) => {
104 return Err(OAuthError::InvalidRequest(format!(
105 "Unsupported response_mode: {}",
106 mode
107 )));
108 }
109 };
110 let parameters = AuthorizationRequestParameters {
111 response_type: request.response_type,
112 client_id: request.client_id.clone(),
113 redirect_uri: request.redirect_uri,
114 scope: validated_scope,
115 state: request.state,
116 code_challenge: code_challenge.clone(),
117 code_challenge_method: code_challenge_method.to_string(),
118 response_mode,
119 login_hint: request.login_hint,
120 dpop_jkt: request.dpop_jkt,
121 extra: None,
122 };
123 let request_data = RequestData {
124 client_id: request.client_id,
125 client_auth: Some(client_auth),
126 parameters,
127 expires_at,
128 did: None,
129 device_id: None,
130 code: None,
131 };
132 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?;
133 tokio::spawn({
134 let pool = state.db.clone();
135 async move {
136 if let Err(e) = db::delete_expired_authorization_requests(&pool).await {
137 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e);
138 }
139 }
140 });
141 Ok((
142 axum::http::StatusCode::CREATED,
143 Json(ParResponse {
144 request_uri: request_id.0,
145 expires_in: PAR_EXPIRY_SECONDS as u64,
146 }),
147 ))
148}
149
150fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
151 if let (Some(assertion), Some(assertion_type)) =
152 (&request.client_assertion, &request.client_assertion_type)
153 {
154 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
155 return Err(OAuthError::InvalidRequest(
156 "Unsupported client_assertion_type".to_string(),
157 ));
158 }
159 return Ok(ClientAuth::PrivateKeyJwt {
160 client_assertion: assertion.clone(),
161 });
162 }
163 if let Some(secret) = &request.client_secret {
164 return Ok(ClientAuth::SecretPost {
165 client_secret: secret.clone(),
166 });
167 }
168 Ok(ClientAuth::None)
169}
170
171fn validate_scope(
172 requested_scope: &Option<String>,
173 client_metadata: &crate::oauth::client::ClientMetadata,
174) -> Result<Option<String>, OAuthError> {
175 let scope_str = match requested_scope {
176 Some(s) if !s.is_empty() => s,
177 _ => return Ok(Some("atproto".to_string())),
178 };
179 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
180 if requested_scopes.is_empty() {
181 return Ok(Some("atproto".to_string()));
182 }
183 let mut has_transition = false;
184 let mut has_granular = false;
185
186 for scope in &requested_scopes {
187 let parsed = parse_scope(scope);
188 match &parsed {
189 ParsedScope::Unknown(_) => {
190 return Err(OAuthError::InvalidScope(format!(
191 "Unsupported scope: {}",
192 scope
193 )));
194 }
195 ParsedScope::TransitionGeneric
196 | ParsedScope::TransitionChat
197 | ParsedScope::TransitionEmail => {
198 has_transition = true;
199 }
200 ParsedScope::Repo(_)
201 | ParsedScope::Blob(_)
202 | ParsedScope::Rpc(_)
203 | ParsedScope::Account(_)
204 | ParsedScope::Identity(_)
205 | ParsedScope::Include(_) => {
206 has_granular = true;
207 }
208 ParsedScope::Atproto => {}
209 }
210 }
211
212 if has_transition && has_granular {
213 return Err(OAuthError::InvalidScope(
214 "Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string()
215 ));
216 }
217
218 if let Some(client_scope) = &client_metadata.scope {
219 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
220 for scope in &requested_scopes {
221 if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) {
222 return Err(OAuthError::InvalidScope(format!(
223 "Scope '{}' not registered for this client",
224 scope
225 )));
226 }
227 }
228 }
229 Ok(Some(requested_scopes.join(" ")))
230}
231
/// Returns `true` when `requested_scope` is covered by `client_scope`, one of the
/// scopes the client registered in its metadata.
///
/// Scopes have the shape `type[:resource][?key=value&...]`. A requested scope is
/// covered when it is identical to the registered one, or when all of the following hold:
///
/// * both have the same resource type;
/// * the registered resource contains a `*` wildcard and matches the requested resource
///   as a glob. `*` matches any run of characters, so `blob:image/*` covers
///   `blob:image/png` but not `blob:video/mp4`;
/// * every query key the registered scope constrains is also present in the requested
///   scope, and each requested value for that key glob-matches one of the registered
///   values. Keys the registered scope does not mention are unrestricted.
///
/// Previously any `*` in the registered base granted every scope of that type and query
/// restrictions (e.g. `?action=create`, `?aud=...`) were ignored.
///
/// NOTE(review): a requested scope that omits a key the registered scope constrains is
/// rejected, on the assumption that an omitted key means "unrestricted" (as with `repo`
/// `action`). Confirm this holds for every resource type.
fn scope_matches(client_scope: &str, requested_scope: &str) -> bool {
    /// Splits `base?query` into `(base, Some(query))`, or `(scope, None)` without `?`.
    fn split_query(scope: &str) -> (&str, Option<&str>) {
        match scope.split_once('?') {
            Some((base, query)) => (base, Some(query)),
            None => (scope, None),
        }
    }

    /// Splits `type:resource` into `(type, resource)`; the resource is empty when absent.
    fn split_resource(base: &str) -> (&str, &str) {
        base.split_once(':').unwrap_or((base, ""))
    }

    /// Parses `k=v&k2=v2` into key/value pairs; a pair without `=` gets an empty value.
    fn parse_query(query: Option<&str>) -> Vec<(&str, &str)> {
        query
            .into_iter()
            .flat_map(|q| q.split('&'))
            .filter(|pair| !pair.is_empty())
            .map(|pair| pair.split_once('=').unwrap_or((pair, "")))
            .collect()
    }

    /// Glob match in which `*` matches any (possibly empty) run of bytes and every
    /// other byte must match literally.
    fn glob_match(pattern: &str, text: &str) -> bool {
        let (p, t) = (pattern.as_bytes(), text.as_bytes());
        let (mut pi, mut ti) = (0, 0);
        // Position of the last `*` seen in the pattern, and the text index it resumes from.
        let mut star: Option<usize> = None;
        let mut resume = 0;
        while ti < t.len() {
            if pi < p.len() && p[pi] == b'*' {
                star = Some(pi);
                resume = ti;
                pi += 1;
            } else if pi < p.len() && p[pi] == t[ti] {
                pi += 1;
                ti += 1;
            } else if let Some(s) = star {
                // Let the last `*` absorb one more byte and retry from there.
                pi = s + 1;
                resume += 1;
                ti = resume;
            } else {
                return false;
            }
        }
        p[pi..].iter().all(|&b| b == b'*')
    }

    /// True when every key constrained by `client_query` is present in
    /// `requested_query` with only values the client query allows.
    fn query_within(client_query: Option<&str>, requested_query: Option<&str>) -> bool {
        let client_params = parse_query(client_query);
        let requested_params = parse_query(requested_query);
        client_params.iter().all(|&(key, _)| {
            let mut requested_values = requested_params
                .iter()
                .filter(|&&(k, _)| k == key)
                .map(|&(_, v)| v)
                .peekable();
            requested_values.peek().is_some()
                && requested_values.all(|value| {
                    client_params
                        .iter()
                        .any(|&(k, pattern)| k == key && glob_match(pattern, value))
                })
        })
    }

    if client_scope == requested_scope {
        return true;
    }

    let (client_base, client_query) = split_query(client_scope);
    let (requested_base, requested_query) = split_query(requested_scope);
    let (client_type, client_resource) = split_resource(client_base);
    let (requested_type, requested_resource) = split_resource(requested_base);

    // Non-exact matches are only granted through an explicit wildcard in the
    // registered resource.
    client_type == requested_type
        && client_resource.contains('*')
        && glob_match(client_resource, requested_resource)
        && query_within(client_query, requested_query)
}