this repo has no description
1use axum::{ 2 Form, Json, 3 extract::State, 4 http::HeaderMap, 5}; 6use chrono::{Duration, Utc}; 7use serde::{Deserialize, Serialize}; 8use crate::state::{AppState, RateLimitKind}; 9use crate::oauth::{ 10 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 11 client::ClientMetadataCache, 12 db, 13}; 14 15const PAR_EXPIRY_SECONDS: i64 = 600; 16const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 17 18#[derive(Debug, Deserialize)] 19pub struct ParRequest { 20 pub response_type: String, 21 pub client_id: String, 22 pub redirect_uri: String, 23 #[serde(default)] 24 pub scope: Option<String>, 25 #[serde(default)] 26 pub state: Option<String>, 27 #[serde(default)] 28 pub code_challenge: Option<String>, 29 #[serde(default)] 30 pub code_challenge_method: Option<String>, 31 #[serde(default)] 32 pub login_hint: Option<String>, 33 #[serde(default)] 34 pub dpop_jkt: Option<String>, 35 #[serde(default)] 36 pub client_secret: Option<String>, 37 #[serde(default)] 38 pub client_assertion: Option<String>, 39 #[serde(default)] 40 pub client_assertion_type: Option<String>, 41} 42 43#[derive(Debug, Serialize)] 44pub struct ParResponse { 45 pub request_uri: String, 46 pub expires_in: u64, 47} 48 49pub async fn pushed_authorization_request( 50 State(state): State<AppState>, 51 headers: HeaderMap, 52 Form(request): Form<ParRequest>, 53) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { 54 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 55 if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await { 56 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 57 return Err(OAuthError::RateLimited); 58 } 59 if request.response_type != "code" { 60 return Err(OAuthError::InvalidRequest( 61 "response_type must be 'code'".to_string(), 62 )); 63 } 64 let code_challenge = request.code_challenge.as_ref() 65 .filter(|s| !s.is_empty()) 66 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?; 67 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 68 if code_challenge_method != "S256" { 69 return Err(OAuthError::InvalidRequest( 70 "code_challenge_method must be 'S256'".to_string(), 71 )); 72 } 73 let client_cache = ClientMetadataCache::new(3600); 74 let client_metadata = client_cache.get(&request.client_id).await?; 75 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 76 let client_auth = determine_client_auth(&request)?; 77 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 78 let request_id = RequestId::generate(); 79 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 80 let parameters = AuthorizationRequestParameters { 81 response_type: request.response_type, 82 client_id: request.client_id.clone(), 83 redirect_uri: request.redirect_uri, 84 scope: validated_scope, 85 state: request.state, 86 code_challenge: code_challenge.clone(), 87 code_challenge_method: code_challenge_method.to_string(), 88 login_hint: request.login_hint, 89 dpop_jkt: request.dpop_jkt, 90 extra: None, 91 }; 92 let request_data = RequestData { 93 client_id: request.client_id, 94 client_auth: Some(client_auth), 95 parameters, 96 expires_at, 97 did: None, 98 device_id: None, 99 code: None, 100 }; 101 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 102 tokio::spawn({ 103 let pool = state.db.clone(); 104 async move { 105 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 106 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 107 } 108 } 109 }); 110 Ok(( 111 axum::http::StatusCode::CREATED, 112 Json(ParResponse { 113 request_uri: request_id.0, 114 expires_in: PAR_EXPIRY_SECONDS as u64, 115 }), 116 )) 117} 118 119fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 120 if let (Some(assertion), Some(assertion_type)) = 121 (&request.client_assertion, &request.client_assertion_type) 122 { 123 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 124 return Err(OAuthError::InvalidRequest( 125 "Unsupported client_assertion_type".to_string(), 126 )); 127 } 128 return Ok(ClientAuth::PrivateKeyJwt { 129 client_assertion: assertion.clone(), 130 }); 131 } 132 if let Some(secret) = &request.client_secret { 133 return Ok(ClientAuth::SecretPost { 134 client_secret: secret.clone(), 135 }); 136 } 137 Ok(ClientAuth::None) 138} 139 140fn validate_scope( 141 requested_scope: &Option<String>, 142 client_metadata: &crate::oauth::client::ClientMetadata, 143) -> Result<Option<String>, OAuthError> { 144 let scope_str = match requested_scope { 145 Some(s) if !s.is_empty() => s, 146 _ => return Ok(Some("atproto".to_string())), 147 }; 148 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 149 if requested_scopes.is_empty() { 150 return Ok(Some("atproto".to_string())); 151 } 152 for scope in &requested_scopes { 153 if !SUPPORTED_SCOPES.contains(scope) { 154 return Err(OAuthError::InvalidScope(format!( 155 "Unsupported scope: {}. Supported scopes: {}", 156 scope, 157 SUPPORTED_SCOPES.join(", ") 158 ))); 159 } 160 } 161 if let Some(client_scope) = &client_metadata.scope { 162 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 163 for scope in &requested_scopes { 164 if !client_scopes.contains(scope) { 165 return Err(OAuthError::InvalidScope(format!( 166 "Scope '{}' not registered for this client", 167 scope 168 ))); 169 } 170 } 171 } 172 Ok(Some(requested_scopes.join(" "))) 173}