this repo has no description
1use axum::{ 2 Form, Json, 3 extract::State, 4 http::HeaderMap, 5}; 6use chrono::{Duration, Utc}; 7use serde::{Deserialize, Serialize}; 8 9use crate::state::{AppState, RateLimitKind}; 10use crate::oauth::{ 11 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 12 client::ClientMetadataCache, 13 db, 14}; 15 16const PAR_EXPIRY_SECONDS: i64 = 600; 17 18const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 19 20#[derive(Debug, Deserialize)] 21pub struct ParRequest { 22 pub response_type: String, 23 pub client_id: String, 24 pub redirect_uri: String, 25 #[serde(default)] 26 pub scope: Option<String>, 27 #[serde(default)] 28 pub state: Option<String>, 29 #[serde(default)] 30 pub code_challenge: Option<String>, 31 #[serde(default)] 32 pub code_challenge_method: Option<String>, 33 #[serde(default)] 34 pub login_hint: Option<String>, 35 #[serde(default)] 36 pub dpop_jkt: Option<String>, 37 #[serde(default)] 38 pub client_secret: Option<String>, 39 #[serde(default)] 40 pub client_assertion: Option<String>, 41 #[serde(default)] 42 pub client_assertion_type: Option<String>, 43} 44 45#[derive(Debug, Serialize)] 46pub struct ParResponse { 47 pub request_uri: String, 48 pub expires_in: u64, 49} 50 51pub async fn pushed_authorization_request( 52 State(state): State<AppState>, 53 headers: HeaderMap, 54 Form(request): Form<ParRequest>, 55) -> Result<Json<ParResponse>, OAuthError> { 56 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 57 if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await { 58 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 59 return Err(OAuthError::RateLimited); 60 } 61 62 if request.response_type != "code" { 63 return Err(OAuthError::InvalidRequest( 64 "response_type must be 'code'".to_string(), 65 )); 66 } 67 68 let code_challenge = request.code_challenge.as_ref() 69 .filter(|s| !s.is_empty()) 70 .ok_or_else(|| OAuthError::InvalidRequest( 71 "code_challenge is required".to_string(), 72 ))?; 73 74 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 75 if code_challenge_method != "S256" { 76 return Err(OAuthError::InvalidRequest( 77 "code_challenge_method must be 'S256'".to_string(), 78 )); 79 } 80 81 let client_cache = ClientMetadataCache::new(3600); 82 let client_metadata = client_cache.get(&request.client_id).await?; 83 84 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 85 86 let client_auth = determine_client_auth(&request)?; 87 88 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() { 89 return Err(OAuthError::InvalidRequest( 90 "dpop_jkt is required for this client".to_string(), 91 )); 92 } 93 94 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 95 96 let request_id = RequestId::generate(); 97 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 98 99 let parameters = AuthorizationRequestParameters { 100 response_type: request.response_type, 101 client_id: request.client_id.clone(), 102 redirect_uri: request.redirect_uri, 103 scope: validated_scope, 104 state: request.state, 105 code_challenge: code_challenge.clone(), 106 code_challenge_method: code_challenge_method.to_string(), 107 login_hint: request.login_hint, 108 dpop_jkt: request.dpop_jkt, 109 extra: None, 110 }; 111 112 let request_data = RequestData { 113 client_id: request.client_id, 114 client_auth: Some(client_auth), 115 parameters, 116 expires_at, 117 did: None, 118 device_id: None, 119 code: None, 120 }; 121 122 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 123 124 tokio::spawn({ 125 let pool = state.db.clone(); 126 async move { 127 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 128 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 129 } 130 } 131 }); 132 133 Ok(Json(ParResponse { 134 request_uri: request_id.0, 135 expires_in: PAR_EXPIRY_SECONDS as u64, 136 })) 137} 138 139fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 140 if let (Some(assertion), Some(assertion_type)) = 141 (&request.client_assertion, &request.client_assertion_type) 142 { 143 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 144 return Err(OAuthError::InvalidRequest( 145 "Unsupported client_assertion_type".to_string(), 146 )); 147 } 148 return Ok(ClientAuth::PrivateKeyJwt { 149 client_assertion: assertion.clone(), 150 }); 151 } 152 153 if let Some(secret) = &request.client_secret { 154 return Ok(ClientAuth::SecretPost { 155 client_secret: secret.clone(), 156 }); 157 } 158 159 Ok(ClientAuth::None) 160} 161 162fn validate_scope( 163 requested_scope: &Option<String>, 164 client_metadata: &crate::oauth::client::ClientMetadata, 165) -> Result<Option<String>, OAuthError> { 166 let scope_str = match requested_scope { 167 Some(s) if !s.is_empty() => s, 168 _ => return Ok(Some("atproto".to_string())), 169 }; 170 171 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 172 173 if requested_scopes.is_empty() { 174 return Ok(Some("atproto".to_string())); 175 } 176 177 for scope in &requested_scopes { 178 if !SUPPORTED_SCOPES.contains(scope) { 179 return Err(OAuthError::InvalidScope(format!( 180 "Unsupported scope: {}. Supported scopes: {}", 181 scope, 182 SUPPORTED_SCOPES.join(", ") 183 ))); 184 } 185 } 186 187 if let Some(client_scope) = &client_metadata.scope { 188 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 189 for scope in &requested_scopes { 190 if !client_scopes.contains(scope) { 191 return Err(OAuthError::InvalidScope(format!( 192 "Scope '{}' not registered for this client", 193 scope 194 ))); 195 } 196 } 197 } 198 199 Ok(Some(requested_scopes.join(" "))) 200}