this repo has no description
1use axum::{
2 Form, Json,
3 extract::State,
4 http::HeaderMap,
5};
6use chrono::{Duration, Utc};
7use serde::{Deserialize, Serialize};
8
9use crate::state::AppState;
10use crate::oauth::{
11 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
12 client::ClientMetadataCache,
13 db,
14};
15
/// How long a pushed authorization request stays valid, in seconds (ten minutes).
const PAR_EXPIRY_SECONDS: i64 = 10 * 60;
17
/// Scope tokens this server accepts in a pushed authorization request.
const SUPPORTED_SCOPES: &[&str] = &[
    "atproto",
    "transition:generic",
    "transition:chat.bsky",
];
19
20#[derive(Debug, Deserialize)]
21pub struct ParRequest {
22 pub response_type: String,
23 pub client_id: String,
24 pub redirect_uri: String,
25 #[serde(default)]
26 pub scope: Option<String>,
27 #[serde(default)]
28 pub state: Option<String>,
29 #[serde(default)]
30 pub code_challenge: Option<String>,
31 #[serde(default)]
32 pub code_challenge_method: Option<String>,
33 #[serde(default)]
34 pub login_hint: Option<String>,
35 #[serde(default)]
36 pub dpop_jkt: Option<String>,
37 #[serde(default)]
38 pub client_secret: Option<String>,
39 #[serde(default)]
40 pub client_assertion: Option<String>,
41 #[serde(default)]
42 pub client_assertion_type: Option<String>,
43}
44
45#[derive(Debug, Serialize)]
46pub struct ParResponse {
47 pub request_uri: String,
48 pub expires_in: u64,
49}
50
51pub async fn pushed_authorization_request(
52 State(state): State<AppState>,
53 headers: HeaderMap,
54 Form(request): Form<ParRequest>,
55) -> Result<Json<ParResponse>, OAuthError> {
56 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
57 if !state.distributed_rate_limiter.check_rate_limit(
58 &format!("oauth_par:{}", client_ip),
59 30,
60 60_000,
61 ).await {
62 if state.rate_limiters.oauth_par.check_key(&client_ip).is_err() {
63 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
64 return Err(OAuthError::RateLimited);
65 }
66 }
67
68 if request.response_type != "code" {
69 return Err(OAuthError::InvalidRequest(
70 "response_type must be 'code'".to_string(),
71 ));
72 }
73
74 let code_challenge = request.code_challenge.as_ref()
75 .filter(|s| !s.is_empty())
76 .ok_or_else(|| OAuthError::InvalidRequest(
77 "code_challenge is required".to_string(),
78 ))?;
79
80 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
81 if code_challenge_method != "S256" {
82 return Err(OAuthError::InvalidRequest(
83 "code_challenge_method must be 'S256'".to_string(),
84 ));
85 }
86
87 let client_cache = ClientMetadataCache::new(3600);
88 let client_metadata = client_cache.get(&request.client_id).await?;
89
90 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
91
92 let client_auth = determine_client_auth(&request)?;
93
94 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() {
95 return Err(OAuthError::InvalidRequest(
96 "dpop_jkt is required for this client".to_string(),
97 ));
98 }
99
100 let validated_scope = validate_scope(&request.scope, &client_metadata)?;
101
102 let request_id = RequestId::generate();
103 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
104
105 let parameters = AuthorizationRequestParameters {
106 response_type: request.response_type,
107 client_id: request.client_id.clone(),
108 redirect_uri: request.redirect_uri,
109 scope: validated_scope,
110 state: request.state,
111 code_challenge: code_challenge.clone(),
112 code_challenge_method: code_challenge_method.to_string(),
113 login_hint: request.login_hint,
114 dpop_jkt: request.dpop_jkt,
115 extra: None,
116 };
117
118 let request_data = RequestData {
119 client_id: request.client_id,
120 client_auth: Some(client_auth),
121 parameters,
122 expires_at,
123 did: None,
124 device_id: None,
125 code: None,
126 };
127
128 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?;
129
130 tokio::spawn({
131 let pool = state.db.clone();
132 async move {
133 if let Err(e) = db::delete_expired_authorization_requests(&pool).await {
134 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e);
135 }
136 }
137 });
138
139 Ok(Json(ParResponse {
140 request_uri: request_id.0,
141 expires_in: PAR_EXPIRY_SECONDS as u64,
142 }))
143}
144
145fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
146 if let (Some(assertion), Some(assertion_type)) =
147 (&request.client_assertion, &request.client_assertion_type)
148 {
149 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
150 return Err(OAuthError::InvalidRequest(
151 "Unsupported client_assertion_type".to_string(),
152 ));
153 }
154 return Ok(ClientAuth::PrivateKeyJwt {
155 client_assertion: assertion.clone(),
156 });
157 }
158
159 if let Some(secret) = &request.client_secret {
160 return Ok(ClientAuth::SecretPost {
161 client_secret: secret.clone(),
162 });
163 }
164
165 Ok(ClientAuth::None)
166}
167
168fn validate_scope(
169 requested_scope: &Option<String>,
170 client_metadata: &crate::oauth::client::ClientMetadata,
171) -> Result<Option<String>, OAuthError> {
172 let scope_str = match requested_scope {
173 Some(s) if !s.is_empty() => s,
174 _ => return Ok(Some("atproto".to_string())),
175 };
176
177 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
178
179 if requested_scopes.is_empty() {
180 return Ok(Some("atproto".to_string()));
181 }
182
183 for scope in &requested_scopes {
184 if !SUPPORTED_SCOPES.contains(scope) {
185 return Err(OAuthError::InvalidScope(format!(
186 "Unsupported scope: {}. Supported scopes: {}",
187 scope,
188 SUPPORTED_SCOPES.join(", ")
189 )));
190 }
191 }
192
193 if let Some(client_scope) = &client_metadata.scope {
194 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
195 for scope in &requested_scopes {
196 if !client_scopes.contains(scope) {
197 return Err(OAuthError::InvalidScope(format!(
198 "Scope '{}' not registered for this client",
199 scope
200 )));
201 }
202 }
203 }
204
205 Ok(Some(requested_scopes.join(" ")))
206}