// This repository has no description.
1use crate::oauth::{
2 AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, OAuthError, RequestData,
3 RequestId, db,
4 scopes::{ParsedScope, parse_scope},
5};
6use crate::state::{AppState, RateLimitKind};
7use axum::body::Bytes;
8use axum::{Json, extract::State, http::HeaderMap};
9use chrono::{Duration, Utc};
10use serde::{Deserialize, Serialize};
11
/// Lifetime of a pushed authorization request, in seconds (ten minutes).
///
/// Used both to compute the stored `expires_at` timestamp and as the
/// `expires_in` value reported back to the client.
const PAR_EXPIRY_SECONDS: i64 = 10 * 60;

14#[derive(Debug, Deserialize)]
15pub struct ParRequest {
16 pub response_type: String,
17 pub client_id: String,
18 pub redirect_uri: String,
19 #[serde(default)]
20 pub scope: Option<String>,
21 #[serde(default)]
22 pub state: Option<String>,
23 #[serde(default)]
24 pub code_challenge: Option<String>,
25 #[serde(default)]
26 pub code_challenge_method: Option<String>,
27 #[serde(default)]
28 pub response_mode: Option<String>,
29 #[serde(default)]
30 pub login_hint: Option<String>,
31 #[serde(default)]
32 pub dpop_jkt: Option<String>,
33 #[serde(default)]
34 pub client_secret: Option<String>,
35 #[serde(default)]
36 pub client_assertion: Option<String>,
37 #[serde(default)]
38 pub client_assertion_type: Option<String>,
39}
40
41#[derive(Debug, Serialize)]
42pub struct ParResponse {
43 pub request_uri: String,
44 pub expires_in: u64,
45}
46
47pub async fn pushed_authorization_request(
48 State(state): State<AppState>,
49 headers: HeaderMap,
50 body: Bytes,
51) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
52 let content_type = headers
53 .get("content-type")
54 .and_then(|v| v.to_str().ok())
55 .unwrap_or("");
56 let request: ParRequest = if content_type.starts_with("application/json") {
57 serde_json::from_slice(&body)
58 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))?
59 } else if content_type.starts_with("application/x-www-form-urlencoded") {
60 let parsed: ParRequest = serde_urlencoded::from_bytes(&body)
61 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))?;
62 tracing::info!(login_hint = ?parsed.login_hint, "PAR request received (form)");
63 parsed
64 } else {
65 return Err(OAuthError::InvalidRequest(
66 "Content-Type must be application/json or application/x-www-form-urlencoded"
67 .to_string(),
68 ));
69 };
70 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
71 if !state
72 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
73 .await
74 {
75 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
76 return Err(OAuthError::RateLimited);
77 }
78 if request.response_type != "code" {
79 return Err(OAuthError::InvalidRequest(
80 "response_type must be 'code'".to_string(),
81 ));
82 }
83 let code_challenge = request
84 .code_challenge
85 .as_ref()
86 .filter(|s| !s.is_empty())
87 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?;
88 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
89 if code_challenge_method != "S256" {
90 return Err(OAuthError::InvalidRequest(
91 "code_challenge_method must be 'S256'".to_string(),
92 ));
93 }
94 let client_cache = ClientMetadataCache::new(3600);
95 let client_metadata = client_cache.get(&request.client_id).await?;
96 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
97 let client_auth = determine_client_auth(&request)?;
98 let validated_scope = validate_scope(&request.scope, &client_metadata)?;
99 let request_id = RequestId::generate();
100 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
101 let response_mode = match request.response_mode.as_deref() {
102 Some("fragment") => Some("fragment".to_string()),
103 Some("query") | None => None,
104 Some(mode) => {
105 return Err(OAuthError::InvalidRequest(format!(
106 "Unsupported response_mode: {}",
107 mode
108 )));
109 }
110 };
111 let parameters = AuthorizationRequestParameters {
112 response_type: request.response_type,
113 client_id: request.client_id.clone(),
114 redirect_uri: request.redirect_uri,
115 scope: validated_scope,
116 state: request.state,
117 code_challenge: code_challenge.clone(),
118 code_challenge_method: code_challenge_method.to_string(),
119 response_mode,
120 login_hint: request.login_hint,
121 dpop_jkt: request.dpop_jkt,
122 extra: None,
123 };
124 let request_data = RequestData {
125 client_id: request.client_id,
126 client_auth: Some(client_auth),
127 parameters,
128 expires_at,
129 did: None,
130 device_id: None,
131 code: None,
132 controller_did: None,
133 };
134 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?;
135 tokio::spawn({
136 let pool = state.db.clone();
137 async move {
138 if let Err(e) = db::delete_expired_authorization_requests(&pool).await {
139 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e);
140 }
141 }
142 });
143 Ok((
144 axum::http::StatusCode::CREATED,
145 Json(ParResponse {
146 request_uri: request_id.0,
147 expires_in: PAR_EXPIRY_SECONDS as u64,
148 }),
149 ))
150}
151
152fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
153 if let (Some(assertion), Some(assertion_type)) =
154 (&request.client_assertion, &request.client_assertion_type)
155 {
156 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
157 return Err(OAuthError::InvalidRequest(
158 "Unsupported client_assertion_type".to_string(),
159 ));
160 }
161 return Ok(ClientAuth::PrivateKeyJwt {
162 client_assertion: assertion.clone(),
163 });
164 }
165 if let Some(secret) = &request.client_secret {
166 return Ok(ClientAuth::SecretPost {
167 client_secret: secret.clone(),
168 });
169 }
170 Ok(ClientAuth::None)
171}
172
173fn validate_scope(
174 requested_scope: &Option<String>,
175 client_metadata: &crate::oauth::ClientMetadata,
176) -> Result<Option<String>, OAuthError> {
177 let scope_str = match requested_scope {
178 Some(s) if !s.is_empty() => s,
179 _ => return Ok(Some("atproto".to_string())),
180 };
181 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
182 if requested_scopes.is_empty() {
183 return Ok(Some("atproto".to_string()));
184 }
185 if let Some(unknown) = requested_scopes
186 .iter()
187 .find(|s| matches!(parse_scope(s), ParsedScope::Unknown(_)))
188 {
189 return Err(OAuthError::InvalidScope(format!(
190 "Unsupported scope: {}",
191 unknown
192 )));
193 }
194
195 let has_transition = requested_scopes.iter().any(|s| {
196 matches!(
197 parse_scope(s),
198 ParsedScope::TransitionGeneric
199 | ParsedScope::TransitionChat
200 | ParsedScope::TransitionEmail
201 )
202 });
203 let has_granular = requested_scopes.iter().any(|s| {
204 matches!(
205 parse_scope(s),
206 ParsedScope::Repo(_)
207 | ParsedScope::Blob(_)
208 | ParsedScope::Rpc(_)
209 | ParsedScope::Account(_)
210 | ParsedScope::Identity(_)
211 | ParsedScope::Include(_)
212 )
213 });
214
215 if has_transition && has_granular {
216 return Err(OAuthError::InvalidScope(
217 "Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string()
218 ));
219 }
220
221 if let Some(client_scope) = &client_metadata.scope {
222 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
223 if let Some(unregistered) = requested_scopes
224 .iter()
225 .find(|scope| !client_scopes.iter().any(|cs| scope_matches(cs, scope)))
226 {
227 return Err(OAuthError::InvalidScope(format!(
228 "Scope '{}' not registered for this client",
229 unregistered
230 )));
231 }
232 }
233 Ok(Some(requested_scopes.join(" ")))
234}
235
/// Returns `true` when a scope registered in client metadata (`client_scope`)
/// permits a scope requested in an authorization request.
///
/// Matching rules:
/// * identical strings always match;
/// * otherwise both scopes must share the same resource type (the text before
///   the first `:` of the part preceding any `?`), and the registered scope's
///   base (the part before `?`) must contain a `*` wildcard;
/// * when both scopes carry a positional value (`type:value`), the requested
///   value must match the registered value as a glob where `*` stands for any
///   run of characters. So `blob:image/*` permits `blob:image/png` but not
///   `blob:*/*` or `blob:video/mp4`. Previously any wildcard granted every
///   scope of the same type;
/// * when either side has no positional value (e.g. `repo?collection=…`),
///   the shared resource type plus the wildcard are sufficient.
///
/// NOTE(review): query parameters (after `?`) are not compared, so e.g.
/// `repo:*?action=create` also permits `repo:*?action=delete`.
fn scope_matches(client_scope: &str, requested_scope: &str) -> bool {
    /// The scope without its `?query` suffix.
    fn base(scope: &str) -> &str {
        scope.split('?').next().unwrap_or(scope)
    }

    /// The resource type: the base up to its first `:`.
    fn resource_type(scope: &str) -> &str {
        let stem = base(scope);
        stem.split(':').next().unwrap_or(stem)
    }

    /// Glob match where each `*` in `pattern` matches any (possibly empty) run of characters.
    fn wildcard_match(pattern: &str, text: &str) -> bool {
        let mut segments = pattern.split('*');
        let head = segments.next().unwrap_or("");
        let Some(mut rest) = text.strip_prefix(head) else {
            return false;
        };
        let tail: Vec<&str> = segments.collect();
        let Some((&last, middle)) = tail.split_last() else {
            // No `*` in the pattern: it must match the whole text exactly.
            return rest.is_empty();
        };
        for &segment in middle {
            match rest.find(segment) {
                Some(pos) => rest = &rest[pos + segment.len()..],
                None => return false,
            }
        }
        rest.ends_with(last)
    }

    if client_scope == requested_scope {
        return true;
    }
    if resource_type(client_scope) != resource_type(requested_scope) {
        return false;
    }
    let client_base = base(client_scope);
    if !client_base.contains('*') {
        return false;
    }
    match (
        client_base.split_once(':'),
        base(requested_scope).split_once(':'),
    ) {
        (Some((_, registered)), Some((_, requested))) => wildcard_match(registered, requested),
        _ => true,
    }
}