this repo has no description
1use axum::{ 2 Form, Json, 3 extract::State, 4 http::HeaderMap, 5}; 6use chrono::{Duration, Utc}; 7use serde::{Deserialize, Serialize}; 8use crate::state::{AppState, RateLimitKind}; 9use crate::oauth::{ 10 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 11 client::ClientMetadataCache, 12 db, 13}; 14 15const PAR_EXPIRY_SECONDS: i64 = 600; 16const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 17 18#[derive(Debug, Deserialize)] 19pub struct ParRequest { 20 pub response_type: String, 21 pub client_id: String, 22 pub redirect_uri: String, 23 #[serde(default)] 24 pub scope: Option<String>, 25 #[serde(default)] 26 pub state: Option<String>, 27 #[serde(default)] 28 pub code_challenge: Option<String>, 29 #[serde(default)] 30 pub code_challenge_method: Option<String>, 31 #[serde(default)] 32 pub login_hint: Option<String>, 33 #[serde(default)] 34 pub dpop_jkt: Option<String>, 35 #[serde(default)] 36 pub client_secret: Option<String>, 37 #[serde(default)] 38 pub client_assertion: Option<String>, 39 #[serde(default)] 40 pub client_assertion_type: Option<String>, 41} 42 43#[derive(Debug, Serialize)] 44pub struct ParResponse { 45 pub request_uri: String, 46 pub expires_in: u64, 47} 48 49pub async fn pushed_authorization_request( 50 State(state): State<AppState>, 51 headers: HeaderMap, 52 Form(request): Form<ParRequest>, 53) -> Result<Json<ParResponse>, OAuthError> { 54 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 55 if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await { 56 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 57 return Err(OAuthError::RateLimited); 58 } 59 if request.response_type != "code" { 60 return Err(OAuthError::InvalidRequest( 61 "response_type must be 'code'".to_string(), 62 )); 63 } 64 let code_challenge = request.code_challenge.as_ref() 65 .filter(|s| !s.is_empty()) 66 .ok_or_else(|| OAuthError::InvalidRequest( 67 "code_challenge is required".to_string(), 68 ))?; 69 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 70 if code_challenge_method != "S256" { 71 return Err(OAuthError::InvalidRequest( 72 "code_challenge_method must be 'S256'".to_string(), 73 )); 74 } 75 let client_cache = ClientMetadataCache::new(3600); 76 let client_metadata = client_cache.get(&request.client_id).await?; 77 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 78 let client_auth = determine_client_auth(&request)?; 79 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() { 80 return Err(OAuthError::InvalidRequest( 81 "dpop_jkt is required for this client".to_string(), 82 )); 83 } 84 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 85 let request_id = RequestId::generate(); 86 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 87 let parameters = AuthorizationRequestParameters { 88 response_type: request.response_type, 89 client_id: request.client_id.clone(), 90 redirect_uri: request.redirect_uri, 91 scope: validated_scope, 92 state: request.state, 93 code_challenge: code_challenge.clone(), 94 code_challenge_method: code_challenge_method.to_string(), 95 login_hint: request.login_hint, 96 dpop_jkt: request.dpop_jkt, 97 extra: None, 98 }; 99 let request_data = RequestData { 100 client_id: request.client_id, 101 client_auth: Some(client_auth), 102 parameters, 103 expires_at, 104 did: None, 105 device_id: None, 106 code: None, 107 }; 108 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 109 tokio::spawn({ 110 let pool = state.db.clone(); 111 async move { 112 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 113 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 114 } 115 } 116 }); 117 Ok(Json(ParResponse { 118 request_uri: request_id.0, 119 expires_in: PAR_EXPIRY_SECONDS as u64, 120 })) 121} 122 123fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 124 if let (Some(assertion), Some(assertion_type)) = 125 (&request.client_assertion, &request.client_assertion_type) 126 { 127 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 128 return Err(OAuthError::InvalidRequest( 129 "Unsupported client_assertion_type".to_string(), 130 )); 131 } 132 return Ok(ClientAuth::PrivateKeyJwt { 133 client_assertion: assertion.clone(), 134 }); 135 } 136 if let Some(secret) = &request.client_secret { 137 return Ok(ClientAuth::SecretPost { 138 client_secret: secret.clone(), 139 }); 140 } 141 Ok(ClientAuth::None) 142} 143 144fn validate_scope( 145 requested_scope: &Option<String>, 146 client_metadata: &crate::oauth::client::ClientMetadata, 147) -> Result<Option<String>, OAuthError> { 148 let scope_str = match requested_scope { 149 Some(s) if !s.is_empty() => s, 150 _ => return Ok(Some("atproto".to_string())), 151 }; 152 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 153 if requested_scopes.is_empty() { 154 return Ok(Some("atproto".to_string())); 155 } 156 for scope in &requested_scopes { 157 if !SUPPORTED_SCOPES.contains(scope) { 158 return Err(OAuthError::InvalidScope(format!( 159 "Unsupported scope: {}. Supported scopes: {}", 160 scope, 161 SUPPORTED_SCOPES.join(", ") 162 ))); 163 } 164 } 165 if let Some(client_scope) = &client_metadata.scope { 166 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 167 for scope in &requested_scopes { 168 if !client_scopes.contains(scope) { 169 return Err(OAuthError::InvalidScope(format!( 170 "Scope '{}' not registered for this client", 171 scope 172 ))); 173 } 174 } 175 } 176 Ok(Some(requested_scopes.join(" "))) 177}