This repository has no description.
1use axum::{ 2 Form, Json, 3 extract::State, 4 http::HeaderMap, 5}; 6use chrono::{Duration, Utc}; 7use serde::{Deserialize, Serialize}; 8use crate::state::{AppState, RateLimitKind}; 9use crate::oauth::{ 10 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 11 client::ClientMetadataCache, 12 db, 13}; 14const PAR_EXPIRY_SECONDS: i64 = 600; 15const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 16#[derive(Debug, Deserialize)] 17pub struct ParRequest { 18 pub response_type: String, 19 pub client_id: String, 20 pub redirect_uri: String, 21 #[serde(default)] 22 pub scope: Option<String>, 23 #[serde(default)] 24 pub state: Option<String>, 25 #[serde(default)] 26 pub code_challenge: Option<String>, 27 #[serde(default)] 28 pub code_challenge_method: Option<String>, 29 #[serde(default)] 30 pub login_hint: Option<String>, 31 #[serde(default)] 32 pub dpop_jkt: Option<String>, 33 #[serde(default)] 34 pub client_secret: Option<String>, 35 #[serde(default)] 36 pub client_assertion: Option<String>, 37 #[serde(default)] 38 pub client_assertion_type: Option<String>, 39} 40#[derive(Debug, Serialize)] 41pub struct ParResponse { 42 pub request_uri: String, 43 pub expires_in: u64, 44} 45pub async fn pushed_authorization_request( 46 State(state): State<AppState>, 47 headers: HeaderMap, 48 Form(request): Form<ParRequest>, 49) -> Result<Json<ParResponse>, OAuthError> { 50 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 51 if !state.check_rate_limit(RateLimitKind::OAuthPar, &client_ip).await { 52 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 53 return Err(OAuthError::RateLimited); 54 } 55 if request.response_type != "code" { 56 return Err(OAuthError::InvalidRequest( 57 "response_type must be 'code'".to_string(), 58 )); 59 } 60 let code_challenge = request.code_challenge.as_ref() 61 .filter(|s| !s.is_empty()) 62 .ok_or_else(|| OAuthError::InvalidRequest( 63 "code_challenge is 
required".to_string(), 64 ))?; 65 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 66 if code_challenge_method != "S256" { 67 return Err(OAuthError::InvalidRequest( 68 "code_challenge_method must be 'S256'".to_string(), 69 )); 70 } 71 let client_cache = ClientMetadataCache::new(3600); 72 let client_metadata = client_cache.get(&request.client_id).await?; 73 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 74 let client_auth = determine_client_auth(&request)?; 75 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() { 76 return Err(OAuthError::InvalidRequest( 77 "dpop_jkt is required for this client".to_string(), 78 )); 79 } 80 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 81 let request_id = RequestId::generate(); 82 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 83 let parameters = AuthorizationRequestParameters { 84 response_type: request.response_type, 85 client_id: request.client_id.clone(), 86 redirect_uri: request.redirect_uri, 87 scope: validated_scope, 88 state: request.state, 89 code_challenge: code_challenge.clone(), 90 code_challenge_method: code_challenge_method.to_string(), 91 login_hint: request.login_hint, 92 dpop_jkt: request.dpop_jkt, 93 extra: None, 94 }; 95 let request_data = RequestData { 96 client_id: request.client_id, 97 client_auth: Some(client_auth), 98 parameters, 99 expires_at, 100 did: None, 101 device_id: None, 102 code: None, 103 }; 104 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 105 tokio::spawn({ 106 let pool = state.db.clone(); 107 async move { 108 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 109 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 110 } 111 } 112 }); 113 Ok(Json(ParResponse { 114 request_uri: request_id.0, 115 expires_in: PAR_EXPIRY_SECONDS as u64, 116 })) 117} 118fn determine_client_auth(request: 
&ParRequest) -> Result<ClientAuth, OAuthError> { 119 if let (Some(assertion), Some(assertion_type)) = 120 (&request.client_assertion, &request.client_assertion_type) 121 { 122 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 123 return Err(OAuthError::InvalidRequest( 124 "Unsupported client_assertion_type".to_string(), 125 )); 126 } 127 return Ok(ClientAuth::PrivateKeyJwt { 128 client_assertion: assertion.clone(), 129 }); 130 } 131 if let Some(secret) = &request.client_secret { 132 return Ok(ClientAuth::SecretPost { 133 client_secret: secret.clone(), 134 }); 135 } 136 Ok(ClientAuth::None) 137} 138fn validate_scope( 139 requested_scope: &Option<String>, 140 client_metadata: &crate::oauth::client::ClientMetadata, 141) -> Result<Option<String>, OAuthError> { 142 let scope_str = match requested_scope { 143 Some(s) if !s.is_empty() => s, 144 _ => return Ok(Some("atproto".to_string())), 145 }; 146 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 147 if requested_scopes.is_empty() { 148 return Ok(Some("atproto".to_string())); 149 } 150 for scope in &requested_scopes { 151 if !SUPPORTED_SCOPES.contains(scope) { 152 return Err(OAuthError::InvalidScope(format!( 153 "Unsupported scope: {}. Supported scopes: {}", 154 scope, 155 SUPPORTED_SCOPES.join(", ") 156 ))); 157 } 158 } 159 if let Some(client_scope) = &client_metadata.scope { 160 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 161 for scope in &requested_scopes { 162 if !client_scopes.contains(scope) { 163 return Err(OAuthError::InvalidScope(format!( 164 "Scope '{}' not registered for this client", 165 scope 166 ))); 167 } 168 } 169 } 170 Ok(Some(requested_scopes.join(" "))) 171}