this repo has no description
1use axum::{
2 Form, Json,
3 extract::State,
4};
5use chrono::{Duration, Utc};
6use serde::{Deserialize, Serialize};
7
8use crate::state::AppState;
9use crate::oauth::{
10 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId,
11 client::ClientMetadataCache,
12 db,
13};
14
/// How long a pushed authorization request stays valid, in seconds (ten minutes).
///
/// Used both for the stored `expires_at` timestamp and for the `expires_in`
/// value returned to the client.
const PAR_EXPIRY_SECONDS: i64 = 10 * 60;

/// Scope tokens this server accepts in an authorization request.
const SUPPORTED_SCOPES: &[&str] = &[
    "atproto",
    "transition:generic",
    "transition:chat.bsky",
];
18
19#[derive(Debug, Deserialize)]
20pub struct ParRequest {
21 pub response_type: String,
22 pub client_id: String,
23 pub redirect_uri: String,
24 #[serde(default)]
25 pub scope: Option<String>,
26 #[serde(default)]
27 pub state: Option<String>,
28 #[serde(default)]
29 pub code_challenge: Option<String>,
30 #[serde(default)]
31 pub code_challenge_method: Option<String>,
32 #[serde(default)]
33 pub login_hint: Option<String>,
34 #[serde(default)]
35 pub dpop_jkt: Option<String>,
36 #[serde(default)]
37 pub client_secret: Option<String>,
38 #[serde(default)]
39 pub client_assertion: Option<String>,
40 #[serde(default)]
41 pub client_assertion_type: Option<String>,
42}
43
44#[derive(Debug, Serialize)]
45pub struct ParResponse {
46 pub request_uri: String,
47 pub expires_in: u64,
48}
49
50pub async fn pushed_authorization_request(
51 State(state): State<AppState>,
52 Form(request): Form<ParRequest>,
53) -> Result<Json<ParResponse>, OAuthError> {
54 if request.response_type != "code" {
55 return Err(OAuthError::InvalidRequest(
56 "response_type must be 'code'".to_string(),
57 ));
58 }
59
60 let code_challenge = request.code_challenge.as_ref()
61 .filter(|s| !s.is_empty())
62 .ok_or_else(|| OAuthError::InvalidRequest(
63 "code_challenge is required".to_string(),
64 ))?;
65
66 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
67 if code_challenge_method != "S256" {
68 return Err(OAuthError::InvalidRequest(
69 "code_challenge_method must be 'S256'".to_string(),
70 ));
71 }
72
73 let client_cache = ClientMetadataCache::new(3600);
74 let client_metadata = client_cache.get(&request.client_id).await?;
75
76 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
77
78 let client_auth = determine_client_auth(&request)?;
79
80 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() {
81 return Err(OAuthError::InvalidRequest(
82 "dpop_jkt is required for this client".to_string(),
83 ));
84 }
85
86 let validated_scope = validate_scope(&request.scope, &client_metadata)?;
87
88 let request_id = RequestId::generate();
89 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
90
91 let parameters = AuthorizationRequestParameters {
92 response_type: request.response_type,
93 client_id: request.client_id.clone(),
94 redirect_uri: request.redirect_uri,
95 scope: validated_scope,
96 state: request.state,
97 code_challenge: code_challenge.clone(),
98 code_challenge_method: code_challenge_method.to_string(),
99 login_hint: request.login_hint,
100 dpop_jkt: request.dpop_jkt,
101 extra: None,
102 };
103
104 let request_data = RequestData {
105 client_id: request.client_id,
106 client_auth: Some(client_auth),
107 parameters,
108 expires_at,
109 did: None,
110 device_id: None,
111 code: None,
112 };
113
114 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?;
115
116 tokio::spawn({
117 let pool = state.db.clone();
118 async move {
119 if let Err(e) = db::delete_expired_authorization_requests(&pool).await {
120 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e);
121 }
122 }
123 });
124
125 Ok(Json(ParResponse {
126 request_uri: request_id.0,
127 expires_in: PAR_EXPIRY_SECONDS as u64,
128 }))
129}
130
131fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> {
132 if let (Some(assertion), Some(assertion_type)) =
133 (&request.client_assertion, &request.client_assertion_type)
134 {
135 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
136 return Err(OAuthError::InvalidRequest(
137 "Unsupported client_assertion_type".to_string(),
138 ));
139 }
140 return Ok(ClientAuth::PrivateKeyJwt {
141 client_assertion: assertion.clone(),
142 });
143 }
144
145 if let Some(secret) = &request.client_secret {
146 return Ok(ClientAuth::SecretPost {
147 client_secret: secret.clone(),
148 });
149 }
150
151 Ok(ClientAuth::None)
152}
153
154fn validate_scope(
155 requested_scope: &Option<String>,
156 client_metadata: &crate::oauth::client::ClientMetadata,
157) -> Result<Option<String>, OAuthError> {
158 let scope_str = match requested_scope {
159 Some(s) if !s.is_empty() => s,
160 _ => return Ok(Some("atproto".to_string())),
161 };
162
163 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
164
165 if requested_scopes.is_empty() {
166 return Ok(Some("atproto".to_string()));
167 }
168
169 for scope in &requested_scopes {
170 if !SUPPORTED_SCOPES.contains(scope) {
171 return Err(OAuthError::InvalidScope(format!(
172 "Unsupported scope: {}. Supported scopes: {}",
173 scope,
174 SUPPORTED_SCOPES.join(", ")
175 )));
176 }
177 }
178
179 if let Some(client_scope) = &client_metadata.scope {
180 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect();
181 for scope in &requested_scopes {
182 if !client_scopes.contains(scope) {
183 return Err(OAuthError::InvalidScope(format!(
184 "Scope '{}' not registered for this client",
185 scope
186 )));
187 }
188 }
189 }
190
191 Ok(Some(requested_scopes.join(" ")))
192}