this repo has no description
1use crate::oauth::{ 2 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 3 client::ClientMetadataCache, 4 db, 5 scopes::{ParsedScope, parse_scope}, 6}; 7use crate::state::{AppState, RateLimitKind}; 8use axum::body::Bytes; 9use axum::{Json, extract::State, http::HeaderMap}; 10use chrono::{Duration, Utc}; 11use serde::{Deserialize, Serialize}; 12 13const PAR_EXPIRY_SECONDS: i64 = 600; 14 15#[derive(Debug, Deserialize)] 16pub struct ParRequest { 17 pub response_type: String, 18 pub client_id: String, 19 pub redirect_uri: String, 20 #[serde(default)] 21 pub scope: Option<String>, 22 #[serde(default)] 23 pub state: Option<String>, 24 #[serde(default)] 25 pub code_challenge: Option<String>, 26 #[serde(default)] 27 pub code_challenge_method: Option<String>, 28 #[serde(default)] 29 pub response_mode: Option<String>, 30 #[serde(default)] 31 pub login_hint: Option<String>, 32 #[serde(default)] 33 pub dpop_jkt: Option<String>, 34 #[serde(default)] 35 pub client_secret: Option<String>, 36 #[serde(default)] 37 pub client_assertion: Option<String>, 38 #[serde(default)] 39 pub client_assertion_type: Option<String>, 40} 41 42#[derive(Debug, Serialize)] 43pub struct ParResponse { 44 pub request_uri: String, 45 pub expires_in: u64, 46} 47 48pub async fn pushed_authorization_request( 49 State(state): State<AppState>, 50 headers: HeaderMap, 51 body: Bytes, 52) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { 53 let content_type = headers 54 .get("content-type") 55 .and_then(|v| v.to_str().ok()) 56 .unwrap_or(""); 57 let request: ParRequest = if content_type.starts_with("application/json") { 58 serde_json::from_slice(&body) 59 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))? 60 } else if content_type.starts_with("application/x-www-form-urlencoded") { 61 serde_urlencoded::from_bytes(&body) 62 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))? 63 } else { 64 return Err(OAuthError::InvalidRequest( 65 "Content-Type must be application/json or application/x-www-form-urlencoded" 66 .to_string(), 67 )); 68 }; 69 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 70 if !state 71 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) 72 .await 73 { 74 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 75 return Err(OAuthError::RateLimited); 76 } 77 if request.response_type != "code" { 78 return Err(OAuthError::InvalidRequest( 79 "response_type must be 'code'".to_string(), 80 )); 81 } 82 let code_challenge = request 83 .code_challenge 84 .as_ref() 85 .filter(|s| !s.is_empty()) 86 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?; 87 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 88 if code_challenge_method != "S256" { 89 return Err(OAuthError::InvalidRequest( 90 "code_challenge_method must be 'S256'".to_string(), 91 )); 92 } 93 let client_cache = ClientMetadataCache::new(3600); 94 let client_metadata = client_cache.get(&request.client_id).await?; 95 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 96 let client_auth = determine_client_auth(&request)?; 97 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 98 let request_id = RequestId::generate(); 99 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 100 let response_mode = match request.response_mode.as_deref() { 101 Some("fragment") => Some("fragment".to_string()), 102 Some("query") | None => None, 103 Some(mode) => { 104 return Err(OAuthError::InvalidRequest(format!( 105 "Unsupported response_mode: {}", 106 mode 107 ))); 108 } 109 }; 110 let parameters = AuthorizationRequestParameters { 111 response_type: request.response_type, 112 client_id: request.client_id.clone(), 113 redirect_uri: request.redirect_uri, 114 scope: validated_scope, 115 state: request.state, 116 code_challenge: code_challenge.clone(), 117 code_challenge_method: code_challenge_method.to_string(), 118 response_mode, 119 login_hint: request.login_hint, 120 dpop_jkt: request.dpop_jkt, 121 extra: None, 122 }; 123 let request_data = RequestData { 124 client_id: request.client_id, 125 client_auth: Some(client_auth), 126 parameters, 127 expires_at, 128 did: None, 129 device_id: None, 130 code: None, 131 }; 132 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 133 tokio::spawn({ 134 let pool = state.db.clone(); 135 async move { 136 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 137 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 138 } 139 } 140 }); 141 Ok(( 142 axum::http::StatusCode::CREATED, 143 Json(ParResponse { 144 request_uri: request_id.0, 145 expires_in: PAR_EXPIRY_SECONDS as u64, 146 }), 147 )) 148} 149 150fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 151 if let (Some(assertion), Some(assertion_type)) = 152 (&request.client_assertion, &request.client_assertion_type) 153 { 154 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 155 return Err(OAuthError::InvalidRequest( 156 "Unsupported client_assertion_type".to_string(), 157 )); 158 } 159 return Ok(ClientAuth::PrivateKeyJwt { 160 client_assertion: assertion.clone(), 161 }); 162 } 163 if let Some(secret) = &request.client_secret { 164 return Ok(ClientAuth::SecretPost { 165 client_secret: secret.clone(), 166 }); 167 } 168 Ok(ClientAuth::None) 169} 170 171fn validate_scope( 172 requested_scope: &Option<String>, 173 client_metadata: &crate::oauth::client::ClientMetadata, 174) -> Result<Option<String>, OAuthError> { 175 let scope_str = match requested_scope { 176 Some(s) if !s.is_empty() => s, 177 _ => return Ok(Some("atproto".to_string())), 178 }; 179 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 180 if requested_scopes.is_empty() { 181 return Ok(Some("atproto".to_string())); 182 } 183 let mut has_transition = false; 184 let mut has_granular = false; 185 186 for scope in &requested_scopes { 187 let parsed = parse_scope(scope); 188 match &parsed { 189 ParsedScope::Unknown(_) => { 190 return Err(OAuthError::InvalidScope(format!( 191 "Unsupported scope: {}", 192 scope 193 ))); 194 } 195 ParsedScope::TransitionGeneric 196 | ParsedScope::TransitionChat 197 | ParsedScope::TransitionEmail => { 198 has_transition = true; 199 } 200 ParsedScope::Repo(_) 201 | ParsedScope::Blob(_) 202 | ParsedScope::Rpc(_) 203 | ParsedScope::Account(_) 204 | ParsedScope::Identity(_) 205 | ParsedScope::Include(_) => { 206 has_granular = true; 207 } 208 ParsedScope::Atproto => {} 209 } 210 } 211 212 if has_transition && has_granular { 213 return Err(OAuthError::InvalidScope( 214 "Cannot mix transition scopes with granular scopes. Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string() 215 )); 216 } 217 218 if let Some(client_scope) = &client_metadata.scope { 219 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 220 for scope in &requested_scopes { 221 if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) { 222 return Err(OAuthError::InvalidScope(format!( 223 "Scope '{}' not registered for this client", 224 scope 225 ))); 226 } 227 } 228 } 229 Ok(Some(requested_scopes.join(" "))) 230} 231 232fn scope_matches(client_scope: &str, requested_scope: &str) -> bool { 233 if client_scope == requested_scope { 234 return true; 235 } 236 237 fn get_resource_type(scope: &str) -> &str { 238 let base = scope.split('?').next().unwrap_or(scope); 239 base.split(':').next().unwrap_or(base) 240 } 241 242 let client_type = get_resource_type(client_scope); 243 let requested_type = get_resource_type(requested_scope); 244 245 if client_type == requested_type { 246 let client_base = client_scope.split('?').next().unwrap_or(client_scope); 247 if client_base.contains('*') { 248 return true; 249 } 250 } 251 252 false 253}