This repository has no description.
1use axum::{
2 Form, Json,
3 extract::State,
4 http::HeaderMap,
5};
6use chrono::{Duration, Utc};
7use serde::{Deserialize, Serialize};
8
9use crate::state::{AppState, RateLimitKind};
10use crate::oauth::{
11 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
12 client::ClientMetadataCache,
13 db,
14};
15
/// How long a pushed authorization request stays valid, in seconds (ten minutes).
const PAR_EXPIRY_SECONDS: i64 = 10 * 60;

/// Scope values this server is willing to grant in an authorization request.
const SUPPORTED_SCOPES: &[&str] = &[
    "atproto",
    "transition:generic",
    "transition:chat.bsky",
];
19
20#[derive(Debug, Deserialize)]
21pub struct ParRequest {
22 pub response_type: String,
23 pub client_id: String,
24 pub redirect_uri: String,
25 #[serde(default)]
26 pub scope: Option<String>,
27 #[serde(default)]
28 pub state: Option<String>,
29 #[serde(default)]
30 pub code_challenge: Option<String>,
31 #[serde(default)]
32 pub code_challenge_method: Option<String>,
33 #[serde(default)]
34 pub login_hint: Option<String>,
35 #[serde(default)]
36 pub dpop_jkt: Option<String>,
37 #[serde(default)]
38 pub client_secret: Option<String>,
39 #[serde(default)]
40 pub client_assertion: Option<String>,
41 #[serde(default)]
42 pub client_assertion_type: Option<String>,
43}
44
45#[derive(Debug, Serialize)]
46pub struct ParResponse {
47 pub request_uri: String,
48 pub expires_in: u64,
49}
50
51pub async fn pushed_authorization_request(
52 State(state): State<AppState>,
53 headers: HeaderMap,
54 Form(request): Form<ParRequest>,
55) -> Result<Json<ParResponse>, OAuthError> {
56 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
57 if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await {
58 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
59 return Err(OAuthError::RateLimited);
60 }
61
62 if request.response_type != "code" {
63 return Err(OAuthError::InvalidRequest(
64 "response_type must be 'code'".to_string(),
65 ));
66 }
67
68 let code_challenge = request.code_challenge.as_ref()
69 .filter(|s| !s.is_empty())
70 .ok_or_else(|| OAuthError::InvalidRequest(
71 "code_challenge is required".to_string(),
72 ))?;
73
74 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
75 if code_challenge_method != "S256" {
76 return Err(OAuthError::InvalidRequest(
77 "code_challenge_method must be 'S256'".to_string(),
78 ));
79 }
80
81 let client_cache = ClientMetadataCache::new(3600);
82 let client_metadata = client_cache.get(&request.client_id).await?;
83
84 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
85
86 let client_auth = determine_client_auth(&request)?;
87
88 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() {
89 return Err(OAuthError::InvalidRequest(
90 "dpop_jkt is required for this client".to_string(),
91 ));
92 }
93
94 let validated_scope = validate_scope(&request.scope, &client_metadata)?;
95
96 let request_id = RequestId::generate();
97 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
98
99 let parameters = AuthorizationRequestParameters {
100 response_type: request.response_type,
101 client_id: request.client_id.clone(),
102 redirect_uri: request.redirect_uri,
103 scope: validated_scope,
104 state: request.state,
105 code_challenge: code_challenge.clone(),
106 code_challenge_method: code_challenge_method.to_string(),
107 login_hint: request.login_hint,
108 dpop_jkt: request.dpop_jkt,
109 extra: None,
110 };
111
112 let request_data = RequestData {
113 client_id: request.client_id,
114 client_auth: Some(client_auth),
115 parameters,
116 expires_at,
117 did: None,
118 device_id: None,
119 code: None,
120 };
121
122 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?;
123
124 tokio::spawn({
125 let pool = state.db.clone();
126 async move {
127 if let Err(e) = db::delete_expired_authorization_requests(&pool).await {
128 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e);
129 }
130 }
131 });
132
133 Ok(Json(ParResponse {
134 request_uri: request_id.0,
135 expires_in: PAR_EXPIRY_SECONDS as u64,
136 }))
137}
138
139fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
140 if let (Some(assertion), Some(assertion_type)) =
141 (&request.client_assertion, &request.client_assertion_type)
142 {
143 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
144 return Err(OAuthError::InvalidRequest(
145 "Unsupported client_assertion_type".to_string(),
146 ));
147 }
148 return Ok(ClientAuth::PrivateKeyJwt {
149 client_assertion: assertion.clone(),
150 });
151 }
152
153 if let Some(secret) = &request.client_secret {
154 return Ok(ClientAuth::SecretPost {
155 client_secret: secret.clone(),
156 });
157 }
158
159 Ok(ClientAuth::None)
160}
161
162fn validate_scope(
163 requested_scope: &Option<String>,
164 client_metadata: &crate::oauth::client::ClientMetadata,
165) -> Result<Option<String>, OAuthError> {
166 let scope_str = match requested_scope {
167 Some(s) if !s.is_empty() => s,
168 _ => return Ok(Some("atproto".to_string())),
169 };
170
171 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
172
173 if requested_scopes.is_empty() {
174 return Ok(Some("atproto".to_string()));
175 }
176
177 for scope in &requested_scopes {
178 if !SUPPORTED_SCOPES.contains(scope) {
179 return Err(OAuthError::InvalidScope(format!(
180 "Unsupported scope: {}. Supported scopes: {}",
181 scope,
182 SUPPORTED_SCOPES.join(", ")
183 )));
184 }
185 }
186
187 if let Some(client_scope) = &client_metadata.scope {
188 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
189 for scope in &requested_scopes {
190 if !client_scopes.contains(scope) {
191 return Err(OAuthError::InvalidScope(format!(
192 "Scope '{}' not registered for this client",
193 scope
194 )));
195 }
196 }
197 }
198
199 Ok(Some(requested_scopes.join(" ")))
200}