This repository has no description.
1use axum::{
2 Form, Json,
3 extract::State,
4 http::HeaderMap,
5};
6use chrono::{Duration, Utc};
7use serde::{Deserialize, Serialize};
8use crate::state::{AppState, RateLimitKind};
9use crate::oauth::{
10 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
11 client::ClientMetadataCache,
12 db,
13};
14
/// Lifetime of a pushed authorization request, in seconds (ten minutes).
const PAR_EXPIRY_SECONDS: i64 = 10 * 60;

/// Every scope token this server accepts in a pushed authorization request.
/// The order here is the order shown in "unsupported scope" error messages.
const SUPPORTED_SCOPES: &[&str] = &[
    "atproto",
    "transition:generic",
    "transition:chat.bsky",
];
17
18#[derive(Debug, Deserialize)]
19pub struct ParRequest {
20 pub response_type: String,
21 pub client_id: String,
22 pub redirect_uri: String,
23 #[serde(default)]
24 pub scope: Option<String>,
25 #[serde(default)]
26 pub state: Option<String>,
27 #[serde(default)]
28 pub code_challenge: Option<String>,
29 #[serde(default)]
30 pub code_challenge_method: Option<String>,
31 #[serde(default)]
32 pub login_hint: Option<String>,
33 #[serde(default)]
34 pub dpop_jkt: Option<String>,
35 #[serde(default)]
36 pub client_secret: Option<String>,
37 #[serde(default)]
38 pub client_assertion: Option<String>,
39 #[serde(default)]
40 pub client_assertion_type: Option<String>,
41}
42
43#[derive(Debug, Serialize)]
44pub struct ParResponse {
45 pub request_uri: String,
46 pub expires_in: u64,
47}
48
49pub async fn pushed_authorization_request(
50 State(state): State<AppState>,
51 headers: HeaderMap,
52 Form(request): Form<ParRequest>,
53) -> Result<Json<ParResponse>, OAuthError> {
54 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
55 if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await {
56 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
57 return Err(OAuthError::RateLimited);
58 }
59 if request.response_type != "code" {
60 return Err(OAuthError::InvalidRequest(
61 "response_type must be 'code'".to_string(),
62 ));
63 }
64 let code_challenge = request.code_challenge.as_ref()
65 .filter(|s| !s.is_empty())
66 .ok_or_else(|| OAuthError::InvalidRequest(
67 "code_challenge is required".to_string(),
68 ))?;
69 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
70 if code_challenge_method != "S256" {
71 return Err(OAuthError::InvalidRequest(
72 "code_challenge_method must be 'S256'".to_string(),
73 ));
74 }
75 let client_cache = ClientMetadataCache::new(3600);
76 let client_metadata = client_cache.get(&request.client_id).await?;
77 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
78 let client_auth = determine_client_auth(&request)?;
79 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() {
80 return Err(OAuthError::InvalidRequest(
81 "dpop_jkt is required for this client".to_string(),
82 ));
83 }
84 let validated_scope = validate_scope(&request.scope, &client_metadata)?;
85 let request_id = RequestId::generate();
86 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
87 let parameters = AuthorizationRequestParameters {
88 response_type: request.response_type,
89 client_id: request.client_id.clone(),
90 redirect_uri: request.redirect_uri,
91 scope: validated_scope,
92 state: request.state,
93 code_challenge: code_challenge.clone(),
94 code_challenge_method: code_challenge_method.to_string(),
95 login_hint: request.login_hint,
96 dpop_jkt: request.dpop_jkt,
97 extra: None,
98 };
99 let request_data = RequestData {
100 client_id: request.client_id,
101 client_auth: Some(client_auth),
102 parameters,
103 expires_at,
104 did: None,
105 device_id: None,
106 code: None,
107 };
108 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?;
109 tokio::spawn({
110 let pool = state.db.clone();
111 async move {
112 if let Err(e) = db::delete_expired_authorization_requests(&pool).await {
113 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e);
114 }
115 }
116 });
117 Ok(Json(ParResponse {
118 request_uri: request_id.0,
119 expires_in: PAR_EXPIRY_SECONDS as u64,
120 }))
121}
122
123fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
124 if let (Some(assertion), Some(assertion_type)) =
125 (&request.client_assertion, &request.client_assertion_type)
126 {
127 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
128 return Err(OAuthError::InvalidRequest(
129 "Unsupported client_assertion_type".to_string(),
130 ));
131 }
132 return Ok(ClientAuth::PrivateKeyJwt {
133 client_assertion: assertion.clone(),
134 });
135 }
136 if let Some(secret) = &request.client_secret {
137 return Ok(ClientAuth::SecretPost {
138 client_secret: secret.clone(),
139 });
140 }
141 Ok(ClientAuth::None)
142}
143
144fn validate_scope(
145 requested_scope: &Option<String>,
146 client_metadata: &crate::oauth::client::ClientMetadata,
147) -> Result<Option<String>, OAuthError> {
148 let scope_str = match requested_scope {
149 Some(s) if !s.is_empty() => s,
150 _ => return Ok(Some("atproto".to_string())),
151 };
152 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
153 if requested_scopes.is_empty() {
154 return Ok(Some("atproto".to_string()));
155 }
156 for scope in &requested_scopes {
157 if !SUPPORTED_SCOPES.contains(scope) {
158 return Err(OAuthError::InvalidScope(format!(
159 "Unsupported scope: {}. Supported scopes: {}",
160 scope,
161 SUPPORTED_SCOPES.join(", ")
162 )));
163 }
164 }
165 if let Some(client_scope) = &client_metadata.scope {
166 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
167 for scope in &requested_scopes {
168 if !client_scopes.contains(scope) {
169 return Err(OAuthError::InvalidScope(format!(
170 "Scope '{}' not registered for this client",
171 scope
172 )));
173 }
174 }
175 }
176 Ok(Some(requested_scopes.join(" ")))
177}