// This repository has no description.
1use axum::{ 2 Form, Json, 3 extract::State, 4 http::HeaderMap, 5}; 6use chrono::{Duration, Utc}; 7use serde::{Deserialize, Serialize}; 8 9use crate::state::AppState; 10use crate::oauth::{ 11 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 12 client::ClientMetadataCache, 13 db, 14}; 15 16const PAR_EXPIRY_SECONDS: i64 = 600; 17 18const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 19 20#[derive(Debug, Deserialize)] 21pub struct ParRequest { 22 pub response_type: String, 23 pub client_id: String, 24 pub redirect_uri: String, 25 #[serde(default)] 26 pub scope: Option<String>, 27 #[serde(default)] 28 pub state: Option<String>, 29 #[serde(default)] 30 pub code_challenge: Option<String>, 31 #[serde(default)] 32 pub code_challenge_method: Option<String>, 33 #[serde(default)] 34 pub login_hint: Option<String>, 35 #[serde(default)] 36 pub dpop_jkt: Option<String>, 37 #[serde(default)] 38 pub client_secret: Option<String>, 39 #[serde(default)] 40 pub client_assertion: Option<String>, 41 #[serde(default)] 42 pub client_assertion_type: Option<String>, 43} 44 45#[derive(Debug, Serialize)] 46pub struct ParResponse { 47 pub request_uri: String, 48 pub expires_in: u64, 49} 50 51pub async fn pushed_authorization_request( 52 State(state): State<AppState>, 53 headers: HeaderMap, 54 Form(request): Form<ParRequest>, 55) -> Result<Json<ParResponse>, OAuthError> { 56 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 57 if !state.distributed_rate_limiter.check_rate_limit( 58 &format!("oauth_par:{}", client_ip), 59 30, 60 60_000, 61 ).await { 62 if state.rate_limiters.oauth_par.check_key(&client_ip).is_err() { 63 tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 64 return Err(OAuthError::RateLimited); 65 } 66 } 67 68 if request.response_type != "code" { 69 return Err(OAuthError::InvalidRequest( 70 "response_type must be 'code'".to_string(), 71 )); 72 } 73 74 let code_challenge = 
request.code_challenge.as_ref() 75 .filter(|s| !s.is_empty()) 76 .ok_or_else(|| OAuthError::InvalidRequest( 77 "code_challenge is required".to_string(), 78 ))?; 79 80 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 81 if code_challenge_method != "S256" { 82 return Err(OAuthError::InvalidRequest( 83 "code_challenge_method must be 'S256'".to_string(), 84 )); 85 } 86 87 let client_cache = ClientMetadataCache::new(3600); 88 let client_metadata = client_cache.get(&request.client_id).await?; 89 90 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 91 92 let client_auth = determine_client_auth(&request)?; 93 94 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() { 95 return Err(OAuthError::InvalidRequest( 96 "dpop_jkt is required for this client".to_string(), 97 )); 98 } 99 100 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 101 102 let request_id = RequestId::generate(); 103 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 104 105 let parameters = AuthorizationRequestParameters { 106 response_type: request.response_type, 107 client_id: request.client_id.clone(), 108 redirect_uri: request.redirect_uri, 109 scope: validated_scope, 110 state: request.state, 111 code_challenge: code_challenge.clone(), 112 code_challenge_method: code_challenge_method.to_string(), 113 login_hint: request.login_hint, 114 dpop_jkt: request.dpop_jkt, 115 extra: None, 116 }; 117 118 let request_data = RequestData { 119 client_id: request.client_id, 120 client_auth: Some(client_auth), 121 parameters, 122 expires_at, 123 did: None, 124 device_id: None, 125 code: None, 126 }; 127 128 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 129 130 tokio::spawn({ 131 let pool = state.db.clone(); 132 async move { 133 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 134 tracing::warn!("Failed to cleanup expired authorization requests: 
{:?}", e); 135 } 136 } 137 }); 138 139 Ok(Json(ParResponse { 140 request_uri: request_id.0, 141 expires_in: PAR_EXPIRY_SECONDS as u64, 142 })) 143} 144 145fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 146 if let (Some(assertion), Some(assertion_type)) = 147 (&request.client_assertion, &request.client_assertion_type) 148 { 149 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 150 return Err(OAuthError::InvalidRequest( 151 "Unsupported client_assertion_type".to_string(), 152 )); 153 } 154 return Ok(ClientAuth::PrivateKeyJwt { 155 client_assertion: assertion.clone(), 156 }); 157 } 158 159 if let Some(secret) = &request.client_secret { 160 return Ok(ClientAuth::SecretPost { 161 client_secret: secret.clone(), 162 }); 163 } 164 165 Ok(ClientAuth::None) 166} 167 168fn validate_scope( 169 requested_scope: &Option<String>, 170 client_metadata: &crate::oauth::client::ClientMetadata, 171) -> Result<Option<String>, OAuthError> { 172 let scope_str = match requested_scope { 173 Some(s) if !s.is_empty() => s, 174 _ => return Ok(Some("atproto".to_string())), 175 }; 176 177 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 178 179 if requested_scopes.is_empty() { 180 return Ok(Some("atproto".to_string())); 181 } 182 183 for scope in &requested_scopes { 184 if !SUPPORTED_SCOPES.contains(scope) { 185 return Err(OAuthError::InvalidScope(format!( 186 "Unsupported scope: {}. Supported scopes: {}", 187 scope, 188 SUPPORTED_SCOPES.join(", ") 189 ))); 190 } 191 } 192 193 if let Some(client_scope) = &client_metadata.scope { 194 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 195 for scope in &requested_scopes { 196 if !client_scopes.contains(scope) { 197 return Err(OAuthError::InvalidScope(format!( 198 "Scope '{}' not registered for this client", 199 scope 200 ))); 201 } 202 } 203 } 204 205 Ok(Some(requested_scopes.join(" "))) 206}