this repo has no description
1use axum::{ 2 Form, Json, 3 extract::State, 4}; 5use chrono::{Duration, Utc}; 6use serde::{Deserialize, Serialize}; 7 8use crate::state::AppState; 9use crate::oauth::{ 10 AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, RequestId, 11 client::ClientMetadataCache, 12 db, 13}; 14 15const PAR_EXPIRY_SECONDS: i64 = 600; 16 17const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic", "transition:chat.bsky"]; 18 19#[derive(Debug, Deserialize)] 20pub struct ParRequest { 21 pub response_type: String, 22 pub client_id: String, 23 pub redirect_uri: String, 24 #[serde(default)] 25 pub scope: Option<String>, 26 #[serde(default)] 27 pub state: Option<String>, 28 #[serde(default)] 29 pub code_challenge: Option<String>, 30 #[serde(default)] 31 pub code_challenge_method: Option<String>, 32 #[serde(default)] 33 pub login_hint: Option<String>, 34 #[serde(default)] 35 pub dpop_jkt: Option<String>, 36 #[serde(default)] 37 pub client_secret: Option<String>, 38 #[serde(default)] 39 pub client_assertion: Option<String>, 40 #[serde(default)] 41 pub client_assertion_type: Option<String>, 42} 43 44#[derive(Debug, Serialize)] 45pub struct ParResponse { 46 pub request_uri: String, 47 pub expires_in: u64, 48} 49 50pub async fn pushed_authorization_request( 51 State(state): State<AppState>, 52 Form(request): Form<ParRequest>, 53) -> Result<Json<ParResponse>, OAuthError> { 54 if request.response_type != "code" { 55 return Err(OAuthError::InvalidRequest( 56 "response_type must be 'code'".to_string(), 57 )); 58 } 59 60 let code_challenge = request.code_challenge.as_ref() 61 .filter(|s| !s.is_empty()) 62 .ok_or_else(|| OAuthError::InvalidRequest( 63 "code_challenge is required".to_string(), 64 ))?; 65 66 let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 67 if code_challenge_method != "S256" { 68 return Err(OAuthError::InvalidRequest( 69 "code_challenge_method must be 'S256'".to_string(), 70 )); 71 } 72 73 let client_cache = ClientMetadataCache::new(3600); 74 let client_metadata = client_cache.get(&request.client_id).await?; 75 76 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; 77 78 let client_auth = determine_client_auth(&request)?; 79 80 if client_metadata.requires_dpop() && request.dpop_jkt.is_none() { 81 return Err(OAuthError::InvalidRequest( 82 "dpop_jkt is required for this client".to_string(), 83 )); 84 } 85 86 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 87 88 let request_id = RequestId::generate(); 89 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 90 91 let parameters = AuthorizationRequestParameters { 92 response_type: request.response_type, 93 client_id: request.client_id.clone(), 94 redirect_uri: request.redirect_uri, 95 scope: validated_scope, 96 state: request.state, 97 code_challenge: code_challenge.clone(), 98 code_challenge_method: code_challenge_method.to_string(), 99 login_hint: request.login_hint, 100 dpop_jkt: request.dpop_jkt, 101 extra: None, 102 }; 103 104 let request_data = RequestData { 105 client_id: request.client_id, 106 client_auth: Some(client_auth), 107 parameters, 108 expires_at, 109 did: None, 110 device_id: None, 111 code: None, 112 }; 113 114 db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 115 116 tokio::spawn({ 117 let pool = state.db.clone(); 118 async move { 119 if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 120 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 121 } 122 } 123 }); 124 125 Ok(Json(ParResponse { 126 request_uri: request_id.0, 127 expires_in: PAR_EXPIRY_SECONDS as u64, 128 })) 129} 130 131fn determine_client_auth(request: &ParRequest) -> Result<ClientAuth, OAuthError> { 132 if let (Some(assertion), Some(assertion_type)) = 133 (&request.client_assertion, &request.client_assertion_type) 134 { 135 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 136 return Err(OAuthError::InvalidRequest( 137 "Unsupported client_assertion_type".to_string(), 138 )); 139 } 140 return Ok(ClientAuth::PrivateKeyJwt { 141 client_assertion: assertion.clone(), 142 }); 143 } 144 145 if let Some(secret) = &request.client_secret { 146 return Ok(ClientAuth::SecretPost { 147 client_secret: secret.clone(), 148 }); 149 } 150 151 Ok(ClientAuth::None) 152} 153 154fn validate_scope( 155 requested_scope: &Option<String>, 156 client_metadata: &crate::oauth::client::ClientMetadata, 157) -> Result<Option<String>, OAuthError> { 158 let scope_str = match requested_scope { 159 Some(s) if !s.is_empty() => s, 160 _ => return Ok(Some("atproto".to_string())), 161 }; 162 163 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); 164 165 if requested_scopes.is_empty() { 166 return Ok(Some("atproto".to_string())); 167 } 168 169 for scope in &requested_scopes { 170 if !SUPPORTED_SCOPES.contains(scope) { 171 return Err(OAuthError::InvalidScope(format!( 172 "Unsupported scope: {}. Supported scopes: {}", 173 scope, 174 SUPPORTED_SCOPES.join(", ") 175 ))); 176 } 177 } 178 179 if let Some(client_scope) = &client_metadata.scope { 180 let client_scopes: Vec<&str> = client_scope.split_whitespace().collect(); 181 for scope in &requested_scopes { 182 if !client_scopes.contains(scope) { 183 return Err(OAuthError::InvalidScope(format!( 184 "Scope '{}' not registered for this client", 185 scope 186 ))); 187 } 188 } 189 } 190 191 Ok(Some(requested_scopes.join(" "))) 192}