This repository has no description.
1use crate::oauth::{ 2 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 3 client::ClientMetadataCache, 4 db, 5 scopes::{ParsedScope, parse_scope}, 6}; 7use crate::state::{AppState, RateLimitKind}; 8use axum::body::Bytes; 9use axum::{Json, extract::State, http::HeaderMap}; 10use chrono::{Duration, Utc}; 11use serde::{Deserialize, Serialize}; 12 13const PAR_EXPIRY_SECONDS: i64 = 600; 14 15#[derive(Debug, Deserialize)] 16pub struct ParRequest { 17 pub response_type: String, 18 pub client_id: String, 19 pub redirect_uri: String, 20 #[serde(default)] 21 pub scope: Option<String>, 22 #[serde(default)] 23 pub state: Option<String>, 24 #[serde(default)] 25 pub code_challenge: Option<String>, 26 #[serde(default)] 27 pub code_challenge_method: Option<String>, 28 #[serde(default)] 29 pub response_mode: Option<String>, 30 #[serde(default)] 31 pub login_hint: Option<String>, 32 #[serde(default)] 33 pub dpop_jkt: Option<String>, 34 #[serde(default)] 35 pub client_secret: Option<String>, 36 #[serde(default)] 37 pub client_assertion: Option<String>, 38 #[serde(default)] 39 pub client_assertion_type: Option<String>, 40} 41 42#[derive(Debug, Serialize)] 43pub struct ParResponse { 44 pub request_uri: String, 45 pub expires_in: u64, 46} 47 48pub async fn pushed_authorization_request( 49 State(state): State<AppState>, 50 headers: HeaderMap, 51 body: Bytes, 52) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { 53 let content_type = headers 54 .get("content-type") 55 .and_then(|v| v.to_str().ok()) 56 .unwrap_or(""); 57 let request: ParRequest = if content_type.starts_with("application/json") { 58 serde_json::from_slice(&body) 59 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))? 
60 } else if content_type.starts_with("application/x-www-form-urlencoded") { 61 let parsed: ParRequest = serde_urlencoded::from_bytes(&body) 62 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))?; 63 tracing::info!(login_hint = ?parsed.login_hint, "PAR request received (form)"); 64 parsed 65 } else { 66 return Err(OAuthError::InvalidRequest( 67 "Content-Type must be application/json or application/x-www-form-urlencoded" 68 .to_string(), 69 )); 70 }; 71 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 72 if !state 73 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) 74 .await 75 { 76 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 77 return Err(OAuthError::RateLimited); 78 } 79 if request.response_type != "code" { 80 return Err(OAuthError::InvalidRequest( 81 "response_type must be 'code'".to_string(), 82 )); 83 } 84 let code_challenge = request 85 .code_challenge 86 .as_ref() 87 .filter(|s| !s.is_empty()) 88 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?; 89 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 90 if code_challenge_method != "S256" { 91 return Err(OAuthError::InvalidRequest( 92 "code_challenge_method must be 'S256'".to_string(), 93 )); 94 } 95 let client_cache = ClientMetadataCache::new(3600); 96 let client_metadata = client_cache.get(&request.client_id).await?; 97 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 98 let client_auth = determine_client_auth(&request)?; 99 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 100 let request_id = RequestId::generate(); 101 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 102 let response_mode = match request.response_mode.as_deref() { 103 Some("fragment") => Some("fragment".to_string()), 104 Some("query") | None => None, 105 Some(mode) => { 106 return Err(OAuthError::InvalidRequest(format!( 107 
"Unsupported response_mode: {}", 108 mode 109 ))); 110 } 111 }; 112 let parameters = AuthorizationRequestParameters { 113 response_type: request.response_type, 114 client_id: request.client_id.clone(), 115 redirect_uri: request.redirect_uri, 116 scope: validated_scope, 117 state: request.state, 118 code_challenge: code_challenge.clone(), 119 code_challenge_method: code_challenge_method.to_string(), 120 response_mode, 121 login_hint: request.login_hint, 122 dpop_jkt: request.dpop_jkt, 123 extra: None, 124 }; 125 let request_data = RequestData { 126 client_id: request.client_id, 127 client_auth: Some(client_auth), 128 parameters, 129 expires_at, 130 did: None, 131 device_id: None, 132 code: None, 133 controller_did: None, 134 }; 135 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 136 tokio::spawn({ 137 let pool = state.db.clone(); 138 async move { 139 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 140 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 141 } 142 } 143 }); 144 Ok(( 145 axum::http::StatusCode::CREATED, 146 Json(ParResponse { 147 request_uri: request_id.0, 148 expires_in: PAR_EXPIRY_SECONDS as u64, 149 }), 150 )) 151} 152 153fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 154 if let (Some(assertion), Some(assertion_type)) = 155 (&request.client_assertion, &request.client_assertion_type) 156 { 157 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 158 return Err(OAuthError::InvalidRequest( 159 "Unsupported client_assertion_type".to_string(), 160 )); 161 } 162 return Ok(ClientAuth::PrivateKeyJwt { 163 client_assertion: assertion.clone(), 164 }); 165 } 166 if let Some(secret) = &request.client_secret { 167 return Ok(ClientAuth::SecretPost { 168 client_secret: secret.clone(), 169 }); 170 } 171 Ok(ClientAuth::None) 172} 173 174fn validate_scope( 175 requested_scope: &Option<String>, 176 client_metadata: 
&crate::oauth::client::ClientMetadata, 177) -> Result<Option<String>, OAuthError> { 178 let scope_str = match requested_scope { 179 Some(s) if !s.is_empty() => s, 180 _ => return Ok(Some("atproto".to_string())), 181 }; 182 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 183 if requested_scopes.is_empty() { 184 return Ok(Some("atproto".to_string())); 185 } 186 let mut has_transition = false; 187 let mut has_granular = false; 188 189 for scope in &requested_scopes { 190 let parsed = parse_scope(scope); 191 match &parsed { 192 ParsedScope::Unknown(_) => { 193 return Err(OAuthError::InvalidScope(format!( 194 "Unsupported scope: {}", 195 scope 196 ))); 197 } 198 ParsedScope::TransitionGeneric 199 | ParsedScope::TransitionChat 200 | ParsedScope::TransitionEmail => { 201 has_transition = true; 202 } 203 ParsedScope::Repo(_) 204 | ParsedScope::Blob(_) 205 | ParsedScope::Rpc(_) 206 | ParsedScope::Account(_) 207 | ParsedScope::Identity(_) 208 | ParsedScope::Include(_) => { 209 has_granular = true; 210 } 211 ParsedScope::Atproto => {} 212 } 213 } 214 215 if has_transition && has_granular { 216 return Err(OAuthError::InvalidScope( 217 "Cannot mix transition scopes with granular scopes. 
Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string() 218 )); 219 } 220 221 if let Some(client_scope) = &client_metadata.scope { 222 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 223 for scope in &requested_scopes { 224 if !client_scopes.iter().any(|cs| scope_matches(cs, scope)) { 225 return Err(OAuthError::InvalidScope(format!( 226 "Scope '{}' not registered for this client", 227 scope 228 ))); 229 } 230 } 231 } 232 Ok(Some(requested_scopes.join(" "))) 233} 234 235fn scope_matches(client_scope: &str, requested_scope: &str) -> bool { 236 if client_scope == requested_scope { 237 return true; 238 } 239 240 fn get_resource_type(scope: &str) -> &str { 241 let base = scope.split('?').next().unwrap_or(scope); 242 base.split(':').next().unwrap_or(base) 243 } 244 245 let client_type = get_resource_type(client_scope); 246 let requested_type = get_resource_type(requested_scope); 247 248 if client_type == requested_type { 249 let client_base = client_scope.split('?').next().unwrap_or(client_scope); 250 if client_base.contains('*') { 251 return true; 252 } 253 } 254 255 false 256}