This repository has no description.
1use axum::{
2 extract::FromRequestParts,
3 http::{StatusCode, request::Parts},
4 response::{IntoResponse, Response},
5 Json,
6};
7use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
8use hmac::{Hmac, Mac};
9use serde_json::json;
10use sha2::Sha256;
11use sqlx::PgPool;
12use subtle::ConstantTimeEq;
13
14use crate::config::AuthConfig;
15use crate::state::AppState;
16use super::db;
17use super::dpop::DPoPVerifier;
18use super::OAuthError;
19
/// Claims read from an authenticated OAuth access token by [`extract_oauth_token_info`].
#[derive(Debug, Clone)]
pub struct OAuthTokenInfo {
    /// Subject DID, taken from the `sub` claim.
    pub did: String,
    /// Token identifier, taken from the `jti` claim; used to look the token up in the database.
    pub token_id: String,
    /// Client identifier from the `client_id` claim (empty string when the claim is absent).
    pub client_id: String,
    /// Scope string from the `scope` claim, if present.
    pub scope: Option<String>,
    /// DPoP key thumbprint from the `cnf.jkt` claim, if the token carries one.
    pub dpop_jkt: Option<String>,
}
27
/// Identity returned by a successful [`verify_oauth_access_token`] call.
#[derive(Debug, Clone)]
pub struct VerifyResult {
    /// Account DID recorded for the token in the database.
    pub did: String,
    /// Token identifier (the access token's `jti` claim).
    pub token_id: String,
    /// Client identifier recorded for the token in the database.
    pub client_id: String,
    /// Scope recorded for the token in the database, if any.
    pub scope: Option<String>,
}
34
35pub async fn verify_oauth_access_token(
36 pool: &PgPool,
37 access_token: &str,
38 dpop_proof: Option<&str>,
39 http_method: &str,
40 http_uri: &str,
41) -> Result<VerifyResult, OAuthError> {
42 let token_info = extract_oauth_token_info(access_token)?;
43
44 let token_data = db::get_token_by_id(pool, &token_info.token_id)
45 .await?
46 .ok_or_else(|| OAuthError::InvalidToken("Token not found or revoked".to_string()))?;
47
48 let now = chrono::Utc::now();
49 if token_data.expires_at < now {
50 return Err(OAuthError::InvalidToken("Token has expired".to_string()));
51 }
52
53 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
54 let proof = dpop_proof.ok_or_else(|| {
55 OAuthError::UseDpopNonce("DPoP proof required".to_string())
56 })?;
57
58 let config = AuthConfig::get();
59 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
60
61 let access_token_hash = compute_ath(access_token);
62 let result = verifier.verify_proof(proof, http_method, http_uri, Some(&access_token_hash))?;
63
64 if !db::check_and_record_dpop_jti(pool, &result.jti).await? {
65 return Err(OAuthError::InvalidDpopProof(
66 "DPoP proof has already been used".to_string(),
67 ));
68 }
69
70 if &result.jkt != expected_jkt {
71 return Err(OAuthError::InvalidDpopProof(
72 "DPoP key binding mismatch".to_string(),
73 ));
74 }
75 }
76
77 Ok(VerifyResult {
78 did: token_data.did,
79 token_id: token_info.token_id,
80 client_id: token_data.client_id,
81 scope: token_data.scope,
82 })
83}
84
85pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> {
86 let parts: Vec<&str> = token.split('.').collect();
87 if parts.len() != 3 {
88 return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
89 }
90
91 let header_bytes = URL_SAFE_NO_PAD
92 .decode(parts[0])
93 .map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
94 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
95 .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
96
97 if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
98 return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string()));
99 }
100 if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
101 return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string()));
102 }
103
104 let config = AuthConfig::get();
105 let secret = config.jwt_secret();
106
107 let signing_input = format!("{}.{}", parts[0], parts[1]);
108 let provided_sig = URL_SAFE_NO_PAD
109 .decode(parts[2])
110 .map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
111
112 type HmacSha256 = Hmac<Sha256>;
113 let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
114 .map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
115 mac.update(signing_input.as_bytes());
116 let expected_sig = mac.finalize().into_bytes();
117
118 if !bool::from(expected_sig.ct_eq(&provided_sig)) {
119 return Err(OAuthError::InvalidToken("Invalid token signature".to_string()));
120 }
121
122 let payload_bytes = URL_SAFE_NO_PAD
123 .decode(parts[1])
124 .map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
125 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
126 .map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
127
128 let exp = payload
129 .get("exp")
130 .and_then(|e| e.as_i64())
131 .ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
132 let now = chrono::Utc::now().timestamp();
133 if exp < now {
134 return Err(OAuthError::InvalidToken("Token has expired".to_string()));
135 }
136
137 let token_id = payload
138 .get("jti")
139 .and_then(|j| j.as_str())
140 .ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
141 .to_string();
142
143 let did = payload
144 .get("sub")
145 .and_then(|s| s.as_str())
146 .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?
147 .to_string();
148
149 let scope = payload.get("scope").and_then(|s| s.as_str()).map(|s| s.to_string());
150
151 let dpop_jkt = payload
152 .get("cnf")
153 .and_then(|c| c.get("jkt"))
154 .and_then(|j| j.as_str())
155 .map(|s| s.to_string());
156
157 let client_id = payload
158 .get("client_id")
159 .and_then(|c| c.as_str())
160 .map(|s| s.to_string())
161 .unwrap_or_default();
162
163 Ok(OAuthTokenInfo {
164 did,
165 token_id,
166 client_id,
167 scope,
168 dpop_jkt,
169 })
170}
171
172fn compute_ath(access_token: &str) -> String {
173 use sha2::Digest;
174 let mut hasher = Sha256::new();
175 hasher.update(access_token.as_bytes());
176 let hash = hasher.finalize();
177 URL_SAFE_NO_PAD.encode(&hash)
178}
179
180pub fn generate_dpop_nonce() -> String {
181 let config = AuthConfig::get();
182 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
183 verifier.generate_nonce()
184}
185
/// Authenticated caller, produced by this type's `FromRequestParts` implementation.
#[derive(Debug, Clone)]
pub struct OAuthUser {
    /// Account DID of the caller.
    pub did: String,
    /// OAuth client identifier; `None` for legacy (non-OAuth) tokens.
    pub client_id: Option<String>,
    /// Scope attached to the OAuth token; `None` for legacy tokens or when none is recorded.
    pub scope: Option<String>,
    /// `true` when the request was authenticated with an OAuth access token.
    pub is_oauth: bool,
}
192
193pub struct OAuthAuthError {
194 pub status: StatusCode,
195 pub error: String,
196 pub message: String,
197 pub dpop_nonce: Option<String>,
198}
199
200impl IntoResponse for OAuthAuthError {
201 fn into_response(self) -> Response {
202 let mut response = (
203 self.status,
204 Json(json!({
205 "error": self.error,
206 "message": self.message
207 })),
208 )
209 .into_response();
210
211 if let Some(nonce) = self.dpop_nonce {
212 response.headers_mut().insert(
213 "DPoP-Nonce",
214 nonce.parse().unwrap(),
215 );
216 }
217
218 response
219 }
220}
221
222impl FromRequestParts<AppState> for OAuthUser {
223 type Rejection = OAuthAuthError;
224
225 async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> {
226 let auth_header = parts
227 .headers
228 .get("Authorization")
229 .and_then(|v| v.to_str().ok())
230 .ok_or_else(|| OAuthAuthError {
231 status: StatusCode::UNAUTHORIZED,
232 error: "AuthenticationRequired".to_string(),
233 message: "Authorization header required".to_string(),
234 dpop_nonce: None,
235 })?;
236
237 let auth_header_trimmed = auth_header.trim();
238 let (token, is_dpop_token) = if auth_header_trimmed.len() >= 7 && auth_header_trimmed[..7].eq_ignore_ascii_case("bearer ") {
239 (auth_header_trimmed[7..].trim(), false)
240 } else if auth_header_trimmed.len() >= 5 && auth_header_trimmed[..5].eq_ignore_ascii_case("dpop ") {
241 (auth_header_trimmed[5..].trim(), true)
242 } else {
243 return Err(OAuthAuthError {
244 status: StatusCode::UNAUTHORIZED,
245 error: "InvalidRequest".to_string(),
246 message: "Invalid authorization scheme".to_string(),
247 dpop_nonce: None,
248 });
249 };
250
251 let dpop_proof = parts
252 .headers
253 .get("DPoP")
254 .and_then(|v| v.to_str().ok());
255
256 if let Ok(result) = try_legacy_auth(&state.db, token).await {
257 return Ok(OAuthUser {
258 did: result.did,
259 client_id: None,
260 scope: None,
261 is_oauth: false,
262 });
263 }
264
265 let http_method = parts.method.as_str();
266 let http_uri = parts.uri.to_string();
267
268 match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await {
269 Ok(result) => Ok(OAuthUser {
270 did: result.did,
271 client_id: Some(result.client_id),
272 scope: result.scope,
273 is_oauth: true,
274 }),
275 Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError {
276 status: StatusCode::UNAUTHORIZED,
277 error: "use_dpop_nonce".to_string(),
278 message: "DPoP nonce required".to_string(),
279 dpop_nonce: Some(nonce),
280 }),
281 Err(OAuthError::InvalidDpopProof(msg)) => {
282 let nonce = generate_dpop_nonce();
283 Err(OAuthAuthError {
284 status: StatusCode::UNAUTHORIZED,
285 error: "invalid_dpop_proof".to_string(),
286 message: msg,
287 dpop_nonce: Some(nonce),
288 })
289 }
290 Err(e) => {
291 let nonce = if is_dpop_token { Some(generate_dpop_nonce()) } else { None };
292 Err(OAuthAuthError {
293 status: StatusCode::UNAUTHORIZED,
294 error: "AuthenticationFailed".to_string(),
295 message: format!("{:?}", e),
296 dpop_nonce: nonce,
297 })
298 }
299 }
300 }
301}
302
303struct LegacyAuthResult {
304 did: String,
305}
306
307async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> {
308 match crate::auth::validate_bearer_token(pool, token).await {
309 Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
310 _ => Err(()),
311 }
312}