this repo has no description
1use crate::oauth::{
2 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
3 client::ClientMetadataCache, db,
4};
5use crate::state::{AppState, RateLimitKind};
6use axum::{Form, Json, extract::State, http::HeaderMap};
7use chrono::{Duration, Utc};
8use serde::{Deserialize, Serialize};
9
/// Lifetime of a pushed authorization request, in seconds.
///
/// Used both to compute the stored `expires_at` timestamp and as the
/// `expires_in` value reported back to the client.
const PAR_EXPIRY_SECONDS: i64 = 600;
/// Scope values accepted by `validate_scope`; any other requested scope is rejected.
const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"];
12
13#[derive(Debug, Deserialize)]
14pub struct ParRequest {
15 pub response_type: String,
16 pub client_id: String,
17 pub redirect_uri: String,
18 #[serde(default)]
19 pub scope: Option<String>,
20 #[serde(default)]
21 pub state: Option<String>,
22 #[serde(default)]
23 pub code_challenge: Option<String>,
24 #[serde(default)]
25 pub code_challenge_method: Option<String>,
26 #[serde(default)]
27 pub login_hint: Option<String>,
28 #[serde(default)]
29 pub dpop_jkt: Option<String>,
30 #[serde(default)]
31 pub client_secret: Option<String>,
32 #[serde(default)]
33 pub client_assertion: Option<String>,
34 #[serde(default)]
35 pub client_assertion_type: Option<String>,
36}
37
/// JSON body returned by the PAR endpoint on success.
#[derive(Debug, Serialize)]
pub struct ParResponse {
    /// Identifier of the stored authorization request (the generated request id).
    pub request_uri: String,
    /// Number of seconds for which `request_uri` remains valid.
    pub expires_in: u64,
}
43
44pub async fn pushed_authorization_request(
45 State(state): State<AppState>,
46 headers: HeaderMap,
47 Form(request): Form<ParRequest>,
48) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
49 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
50 if !state
51 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
52 .await
53 {
54 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
55 return Err(OAuthError::RateLimited);
56 }
57 if request.response_type != "code" {
58 return Err(OAuthError::InvalidRequest(
59 "response_type must be 'code'".to_string(),
60 ));
61 }
62 let code_challenge = request
63 .code_challenge
64 .as_ref()
65 .filter(|s| !s.is_empty())
66 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?;
67 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
68 if code_challenge_method != "S256" {
69 return Err(OAuthError::InvalidRequest(
70 "code_challenge_method must be 'S256'".to_string(),
71 ));
72 }
73 let client_cache = ClientMetadataCache::new(3600);
74 let client_metadata = client_cache.get(&request.client_id).await?;
75 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
76 let client_auth = determine_client_auth(&request)?;
77 let validated_scope = validate_scope(&request.scope, &client_metadata)?;
78 let request_id = RequestId::generate();
79 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
80 let parameters = AuthorizationRequestParameters {
81 response_type: request.response_type,
82 client_id: request.client_id.clone(),
83 redirect_uri: request.redirect_uri,
84 scope: validated_scope,
85 state: request.state,
86 code_challenge: code_challenge.clone(),
87 code_challenge_method: code_challenge_method.to_string(),
88 login_hint: request.login_hint,
89 dpop_jkt: request.dpop_jkt,
90 extra: None,
91 };
92 let request_data = RequestData {
93 client_id: request.client_id,
94 client_auth: Some(client_auth),
95 parameters,
96 expires_at,
97 did: None,
98 device_id: None,
99 code: None,
100 };
101 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?;
102 tokio::spawn({
103 let pool = state.db.clone();
104 async move {
105 if let Err(e) = db::delete_expired_authorization_requests(&pool).await {
106 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e);
107 }
108 }
109 });
110 Ok((
111 axum::http::StatusCode::CREATED,
112 Json(ParResponse {
113 request_uri: request_id.0,
114 expires_in: PAR_EXPIRY_SECONDS as u64,
115 }),
116 ))
117}
118
119fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
120 if let (Some(assertion), Some(assertion_type)) =
121 (&request.client_assertion, &request.client_assertion_type)
122 {
123 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
124 return Err(OAuthError::InvalidRequest(
125 "Unsupported client_assertion_type".to_string(),
126 ));
127 }
128 return Ok(ClientAuth::PrivateKeyJwt {
129 client_assertion: assertion.clone(),
130 });
131 }
132 if let Some(secret) = &request.client_secret {
133 return Ok(ClientAuth::SecretPost {
134 client_secret: secret.clone(),
135 });
136 }
137 Ok(ClientAuth::None)
138}
139
140fn validate_scope(
141 requested_scope: &Option<String>,
142 client_metadata: &crate::oauth::client::ClientMetadata,
143) -> Result<Option<String>, OAuthError> {
144 let scope_str = match requested_scope {
145 Some(s) if !s.is_empty() => s,
146 _ => return Ok(Some("atproto".to_string())),
147 };
148 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
149 if requested_scopes.is_empty() {
150 return Ok(Some("atproto".to_string()));
151 }
152 for scope in &requested_scopes {
153 if !SUPPORTED_SCOPES.contains(scope) {
154 return Err(OAuthError::InvalidScope(format!(
155 "Unsupported scope: {}. Supported scopes: {}",
156 scope,
157 SUPPORTED_SCOPES.join(", ")
158 )));
159 }
160 }
161 if let Some(client_scope) = &client_metadata.scope {
162 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
163 for scope in &requested_scopes {
164 if !client_scopes.contains(scope) {
165 return Err(OAuthError::InvalidScope(format!(
166 "Scope '{}' not registered for this client",
167 scope
168 )));
169 }
170 }
171 }
172 Ok(Some(requested_scopes.join(" ")))
173}