This repository has no description.
1use crate::oauth::{ 2 AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, OAuthError, RequestData, 3 RequestId, db, 4 scopes::{ParsedScope, parse_scope}, 5}; 6use crate::state::{AppState, RateLimitKind}; 7use axum::body::Bytes; 8use axum::{Json, extract::State, http::HeaderMap}; 9use chrono::{Duration, Utc}; 10use serde::{Deserialize, Serialize}; 11 12const PAR_EXPIRY_SECONDS: i64 = 600; 13 14#[derive(Debug, Deserialize)] 15pub struct ParRequest { 16 pub response_type: String, 17 pub client_id: String, 18 pub redirect_uri: String, 19 #[serde(default)] 20 pub scope: Option<String>, 21 #[serde(default)] 22 pub state: Option<String>, 23 #[serde(default)] 24 pub code_challenge: Option<String>, 25 #[serde(default)] 26 pub code_challenge_method: Option<String>, 27 #[serde(default)] 28 pub response_mode: Option<String>, 29 #[serde(default)] 30 pub login_hint: Option<String>, 31 #[serde(default)] 32 pub dpop_jkt: Option<String>, 33 #[serde(default)] 34 pub client_secret: Option<String>, 35 #[serde(default)] 36 pub client_assertion: Option<String>, 37 #[serde(default)] 38 pub client_assertion_type: Option<String>, 39} 40 41#[derive(Debug, Serialize)] 42pub struct ParResponse { 43 pub request_uri: String, 44 pub expires_in: u64, 45} 46 47pub async fn pushed_authorization_request( 48 State(state): State<AppState>, 49 headers: HeaderMap, 50 body: Bytes, 51) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { 52 let content_type = headers 53 .get("content-type") 54 .and_then(|v| v.to_str().ok()) 55 .unwrap_or(""); 56 let request: ParRequest = if content_type.starts_with("application/json") { 57 serde_json::from_slice(&body) 58 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid JSON: {}", e)))? 
59 } else if content_type.starts_with("application/x-www-form-urlencoded") { 60 let parsed: ParRequest = serde_urlencoded::from_bytes(&body) 61 .map_err(|e| OAuthError::InvalidRequest(format!("Invalid form data: {}", e)))?; 62 tracing::info!(login_hint = ?parsed.login_hint, "PAR request received (form)"); 63 parsed 64 } else { 65 return Err(OAuthError::InvalidRequest( 66 "Content-Type must be application/json or application/x-www-form-urlencoded" 67 .to_string(), 68 )); 69 }; 70 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 71 if !state 72 .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) 73 .await 74 { 75 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 76 return Err(OAuthError::RateLimited); 77 } 78 if request.response_type != "code" { 79 return Err(OAuthError::InvalidRequest( 80 "response_type must be 'code'".to_string(), 81 )); 82 } 83 let code_challenge = request 84 .code_challenge 85 .as_ref() 86 .filter(|s| !s.is_empty()) 87 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?; 88 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 89 if code_challenge_method != "S256" { 90 return Err(OAuthError::InvalidRequest( 91 "code_challenge_method must be 'S256'".to_string(), 92 )); 93 } 94 let client_cache = ClientMetadataCache::new(3600); 95 let client_metadata = client_cache.get(&request.client_id).await?; 96 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 97 let client_auth = determine_client_auth(&request)?; 98 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 99 let request_id = RequestId::generate(); 100 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 101 let response_mode = match request.response_mode.as_deref() { 102 Some("fragment") => Some("fragment".to_string()), 103 Some("query") | None => None, 104 Some(mode) => { 105 return Err(OAuthError::InvalidRequest(format!( 106 
"Unsupported response_mode: {}", 107 mode 108 ))); 109 } 110 }; 111 let parameters = AuthorizationRequestParameters { 112 response_type: request.response_type, 113 client_id: request.client_id.clone(), 114 redirect_uri: request.redirect_uri, 115 scope: validated_scope, 116 state: request.state, 117 code_challenge: code_challenge.clone(), 118 code_challenge_method: code_challenge_method.to_string(), 119 response_mode, 120 login_hint: request.login_hint, 121 dpop_jkt: request.dpop_jkt, 122 extra: None, 123 }; 124 let request_data = RequestData { 125 client_id: request.client_id, 126 client_auth: Some(client_auth), 127 parameters, 128 expires_at, 129 did: None, 130 device_id: None, 131 code: None, 132 controller_did: None, 133 }; 134 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 135 tokio::spawn({ 136 let pool = state.db.clone(); 137 async move { 138 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 139 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 140 } 141 } 142 }); 143 Ok(( 144 axum::http::StatusCode::CREATED, 145 Json(ParResponse { 146 request_uri: request_id.0, 147 expires_in: PAR_EXPIRY_SECONDS as u64, 148 }), 149 )) 150} 151 152fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 153 if let (Some(assertion), Some(assertion_type)) = 154 (&request.client_assertion, &request.client_assertion_type) 155 { 156 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 157 return Err(OAuthError::InvalidRequest( 158 "Unsupported client_assertion_type".to_string(), 159 )); 160 } 161 return Ok(ClientAuth::PrivateKeyJwt { 162 client_assertion: assertion.clone(), 163 }); 164 } 165 if let Some(secret) = &request.client_secret { 166 return Ok(ClientAuth::SecretPost { 167 client_secret: secret.clone(), 168 }); 169 } 170 Ok(ClientAuth::None) 171} 172 173fn validate_scope( 174 requested_scope: &Option<String>, 175 client_metadata: 
&crate::oauth::ClientMetadata, 176) -> Result<Option<String>, OAuthError> { 177 let scope_str = match requested_scope { 178 Some(s) if !s.is_empty() => s, 179 _ => return Ok(Some("atproto".to_string())), 180 }; 181 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 182 if requested_scopes.is_empty() { 183 return Ok(Some("atproto".to_string())); 184 } 185 if let Some(unknown) = requested_scopes 186 .iter() 187 .find(|s| matches!(parse_scope(s), ParsedScope::Unknown(_))) 188 { 189 return Err(OAuthError::InvalidScope(format!( 190 "Unsupported scope: {}", 191 unknown 192 ))); 193 } 194 195 let has_transition = requested_scopes.iter().any(|s| { 196 matches!( 197 parse_scope(s), 198 ParsedScope::TransitionGeneric 199 | ParsedScope::TransitionChat 200 | ParsedScope::TransitionEmail 201 ) 202 }); 203 let has_granular = requested_scopes.iter().any(|s| { 204 matches!( 205 parse_scope(s), 206 ParsedScope::Repo(_) 207 | ParsedScope::Blob(_) 208 | ParsedScope::Rpc(_) 209 | ParsedScope::Account(_) 210 | ParsedScope::Identity(_) 211 | ParsedScope::Include(_) 212 ) 213 }); 214 215 if has_transition && has_granular { 216 return Err(OAuthError::InvalidScope( 217 "Cannot mix transition scopes with granular scopes. 
Use either transition:* scopes OR granular scopes (repo:*, blob:*, rpc:*, account:*, include:*), not both.".to_string() 218 )); 219 } 220 221 if let Some(client_scope) = &client_metadata.scope { 222 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 223 if let Some(unregistered) = requested_scopes 224 .iter() 225 .find(|scope| !client_scopes.iter().any(|cs| scope_matches(cs, scope))) 226 { 227 return Err(OAuthError::InvalidScope(format!( 228 "Scope '{}' not registered for this client", 229 unregistered 230 ))); 231 } 232 } 233 Ok(Some(requested_scopes.join(" "))) 234} 235 236fn scope_matches(client_scope: &str, requested_scope: &str) -> bool { 237 if client_scope == requested_scope { 238 return true; 239 } 240 241 fn get_resource_type(scope: &str) -> &str { 242 let base = scope.split('?').next().unwrap_or(scope); 243 base.split(':').next().unwrap_or(base) 244 } 245 246 let client_type = get_resource_type(client_scope); 247 let requested_type = get_resource_type(requested_scope); 248 249 if client_type == requested_type { 250 let client_base = client_scope.split('?').next().unwrap_or(client_scope); 251 if client_base.contains('*') { 252 return true; 253 } 254 } 255 256 false 257}