this repo has no description
1use axum::http::HeaderMap; 2use axum::Json; 3use chrono::{Duration, Utc}; 4use crate::config::AuthConfig; 5use crate::state::AppState; 6use crate::oauth::{ 7 ClientAuth, OAuthError, RefreshToken, TokenData, TokenId, 8 client::{ClientMetadataCache, verify_client_auth}, 9 db, 10 dpop::DPoPVerifier, 11}; 12use super::types::{TokenRequest, TokenResponse}; 13use super::helpers::{create_access_token, verify_pkce}; 14 15const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 16const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60; 17 18pub async fn handle_authorization_code_grant( 19 state: AppState, 20 _headers: HeaderMap, 21 request: TokenRequest, 22 dpop_proof: Option<String>, 23) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 24 let code = request 25 .code 26 .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?; 27 let code_verifier = request 28 .code_verifier 29 .ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?; 30 let auth_request = db::consume_authorization_request_by_code(&state.db, &code) 31 .await? 32 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?; 33 if auth_request.expires_at < Utc::now() { 34 return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string())); 35 } 36 if let Some(request_client_id) = &request.client_id { 37 if request_client_id != &auth_request.client_id { 38 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 39 } 40 } 41 let did = auth_request 42 .did 43 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; 44 let client_metadata_cache = ClientMetadataCache::new(3600); 45 let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?; 46 let client_auth = if let (Some(assertion), Some(assertion_type)) = (&request.client_assertion, &request.client_assertion_type) { 47 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 48 return Err(OAuthError::InvalidClient( 49 "Unsupported client_assertion_type".to_string(), 50 )); 51 } 52 ClientAuth::PrivateKeyJwt { 53 client_assertion: assertion.clone(), 54 } 55 } else if let Some(secret) = &request.client_secret { 56 ClientAuth::SecretPost { 57 client_secret: secret.clone(), 58 } 59 } else { 60 ClientAuth::None 61 }; 62 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 63 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 64 if let Some(redirect_uri) = &request.redirect_uri { 65 if redirect_uri != &auth_request.parameters.redirect_uri { 66 return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string())); 67 } 68 } 69 let dpop_jkt = if let Some(proof) = &dpop_proof { 70 let config = AuthConfig::get(); 71 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 72 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 73 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 74 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 75 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 76 return Err(OAuthError::InvalidDpopProof( 77 "DPoP proof has already been used".to_string(), 78 )); 79 } 80 if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt { 81 if &result.jkt != expected_jkt { 82 return Err(OAuthError::InvalidDpopProof( 83 "DPoP key binding mismatch".to_string(), 84 )); 85 } 86 } 87 Some(result.jkt) 88 } else if auth_request.parameters.dpop_jkt.is_some() { 89 return Err(OAuthError::InvalidRequest( 90 "DPoP proof required for this authorization".to_string(), 91 )); 92 } else { 93 None 94 }; 95 let token_id = TokenId::generate(); 96 let refresh_token = RefreshToken::generate(); 97 let now = Utc::now(); 98 let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?; 99 let token_data = TokenData { 100 did: did.clone(), 101 token_id: token_id.0.clone(), 102 created_at: now, 103 updated_at: now, 104 expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS), 105 client_id: auth_request.client_id.clone(), 106 client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None), 107 device_id: auth_request.device_id, 108 parameters: auth_request.parameters.clone(), 109 details: None, 110 code: None, 111 current_refresh_token: Some(refresh_token.0.clone()), 112 scope: auth_request.parameters.scope.clone(), 113 }; 114 db::create_token(&state.db, &token_data).await?; 115 tokio::spawn({ 116 let pool = state.db.clone(); 117 let did_clone = did.clone(); 118 async move { 119 if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await { 120 tracing::warn!("Failed to enforce token limit for user: {:?}", e); 121 } 122 } 123 }); 124 let mut response_headers = HeaderMap::new(); 125 let config = AuthConfig::get(); 126 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 127 response_headers.insert( 128 "DPoP-Nonce", 129 verifier.generate_nonce().parse().unwrap(), 130 ); 131 Ok(( 132 response_headers, 133 Json(TokenResponse { 134 access_token, 135 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 136 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 137 refresh_token: Some(refresh_token.0), 138 scope: auth_request.parameters.scope, 139 sub: Some(did), 140 }), 141 )) 142} 143 144pub async fn handle_refresh_token_grant( 145 state: AppState, 146 _headers: HeaderMap, 147 request: TokenRequest, 148 dpop_proof: Option<String>, 149) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 150 let refresh_token_str = request 151 .refresh_token 152 .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?; 153 if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? { 154 db::delete_token_family(&state.db, token_id).await?; 155 return Err(OAuthError::InvalidGrant( 156 "Refresh token reuse detected, token family revoked".to_string(), 157 )); 158 } 159 let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str) 160 .await? 161 .ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?; 162 if token_data.expires_at < Utc::now() { 163 db::delete_token_family(&state.db, db_id).await?; 164 return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string())); 165 } 166 let dpop_jkt = if let Some(proof) = &dpop_proof { 167 let config = AuthConfig::get(); 168 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 169 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 170 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 171 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 172 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 173 return Err(OAuthError::InvalidDpopProof( 174 "DPoP proof has already been used".to_string(), 175 )); 176 } 177 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt { 178 if &result.jkt != expected_jkt { 179 return Err(OAuthError::InvalidDpopProof( 180 "DPoP key binding mismatch".to_string(), 181 )); 182 } 183 } 184 Some(result.jkt) 185 } else if token_data.parameters.dpop_jkt.is_some() { 186 return Err(OAuthError::InvalidRequest( 187 "DPoP proof required".to_string(), 188 )); 189 } else { 190 None 191 }; 192 let new_token_id = TokenId::generate(); 193 let new_refresh_token = RefreshToken::generate(); 194 let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS); 195 db::rotate_token( 196 &state.db, 197 db_id, 198 &new_token_id.0, 199 &new_refresh_token.0, 200 new_expires_at, 201 ) 202 .await?; 203 let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?; 204 let mut response_headers = HeaderMap::new(); 205 let config = AuthConfig::get(); 206 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 207 response_headers.insert( 208 "DPoP-Nonce", 209 verifier.generate_nonce().parse().unwrap(), 210 ); 211 Ok(( 212 response_headers, 213 Json(TokenResponse { 214 access_token, 215 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 216 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 217 refresh_token: Some(new_refresh_token.0), 218 scope: token_data.scope, 219 sub: Some(token_data.did), 220 }), 221 )) 222}