this repo has no description
1use axum::{
2 Json,
3 extract::FromRequestParts,
4 http::{StatusCode, request::Parts},
5 response::{IntoResponse, Response},
6};
7use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
8use hmac::{Hmac, Mac};
9use serde_json::json;
10use sha2::Sha256;
11use sqlx::PgPool;
12use subtle::ConstantTimeEq;
13
14use super::db;
15use super::scopes::ScopePermissions;
16use super::{DPoPVerifier, OAuthError};
17use crate::config::AuthConfig;
18use crate::state::AppState;
19
/// Claims extracted from a signature-verified OAuth access token (`at+jwt`).
///
/// Produced by [`extract_oauth_token_info`]; only the JWT itself has been
/// checked at this point, not the server-side session.
#[derive(Debug, Clone)]
pub struct OAuthTokenInfo {
    /// Account DID from the `sub` claim.
    pub did: String,
    /// Session identifier from the `sid` claim, used to look up the stored token.
    pub token_id: String,
    /// OAuth client identifier from `client_id` (empty string when absent).
    pub client_id: String,
    /// Space-separated scopes from the `scope` claim, if present.
    pub scope: Option<String>,
    /// DPoP key thumbprint from `cnf.jkt`, if the token is DPoP-bound.
    pub dpop_jkt: Option<String>,
    /// Acting controller DID from `act.sub`, if present.
    pub controller_did: Option<String>,
}
28
/// Outcome of a successful [`verify_oauth_access_token`] call.
///
/// `did`, `client_id` and `scope` come from the stored session record;
/// `token_id` is the `sid` claim of the presented token.
#[derive(Debug, Clone)]
pub struct VerifyResult {
    /// DID of the account the session belongs to.
    pub did: String,
    /// Session identifier (`sid`) of the verified token.
    pub token_id: String,
    /// OAuth client the session was issued to.
    pub client_id: String,
    /// Granted scope string, if any.
    pub scope: Option<String>,
}
35
36pub async fn verify_oauth_access_token(
37 pool: &PgPool,
38 access_token: &str,
39 dpop_proof: Option<&str>,
40 http_method: &str,
41 http_uri: &str,
42) -> Result<VerifyResult, OAuthError> {
43 let token_info = extract_oauth_token_info(access_token)?;
44 tracing::debug!(
45 token_id = %token_info.token_id,
46 has_dpop_proof = dpop_proof.is_some(),
47 "Verifying OAuth access token"
48 );
49 let token_data = db::get_token_by_id(pool, &token_info.token_id)
50 .await?
51 .ok_or_else(|| {
52 tracing::warn!(token_id = %token_info.token_id, "Token not found in database");
53 OAuthError::InvalidToken("Token not found or revoked".to_string())
54 })?;
55 let now = chrono::Utc::now();
56 if token_data.expires_at < now {
57 return Err(OAuthError::ExpiredToken(
58 "Token session has expired".to_string(),
59 ));
60 }
61 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
62 tracing::debug!(expected_jkt = %expected_jkt, "Token requires DPoP");
63 let proof = dpop_proof.ok_or_else(|| {
64 tracing::warn!("DPoP proof required but not provided");
65 OAuthError::UseDpopNonce("DPoP proof required".to_string())
66 })?;
67 let config = AuthConfig::get();
68 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
69 let access_token_hash = compute_ath(access_token);
70 let result = verifier
71 .verify_proof(proof, http_method, http_uri, Some(&access_token_hash))
72 .map_err(|e| {
73 tracing::warn!(error = ?e, http_method = %http_method, http_uri = %http_uri, "DPoP proof verification failed");
74 e
75 })?;
76 if !db::check_and_record_dpop_jti(pool, &result.jti).await? {
77 return Err(OAuthError::InvalidDpopProof(
78 "DPoP proof has already been used".to_string(),
79 ));
80 }
81 if result.jkt.as_str() != expected_jkt {
82 return Err(OAuthError::InvalidDpopProof(
83 "DPoP key binding mismatch".to_string(),
84 ));
85 }
86 }
87 Ok(VerifyResult {
88 did: token_data.did,
89 token_id: token_info.token_id,
90 client_id: token_data.client_id,
91 scope: token_data.scope,
92 })
93}
94
95pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> {
96 let parts: Vec<&str> = token.split('.').collect();
97 if parts.len() != 3 {
98 return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
99 }
100 let header_bytes = URL_SAFE_NO_PAD
101 .decode(parts[0])
102 .map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
103 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
104 .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
105 if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
106 return Err(OAuthError::InvalidToken(
107 "Not an OAuth access token".to_string(),
108 ));
109 }
110 if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
111 return Err(OAuthError::InvalidToken(
112 "Unsupported algorithm".to_string(),
113 ));
114 }
115 let config = AuthConfig::get();
116 let secret = config.jwt_secret();
117 let signing_input = format!("{}.{}", parts[0], parts[1]);
118 let provided_sig = URL_SAFE_NO_PAD
119 .decode(parts[2])
120 .map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
121 type HmacSha256 = Hmac<Sha256>;
122 let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
123 .map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
124 mac.update(signing_input.as_bytes());
125 let expected_sig = mac.finalize().into_bytes();
126 if !bool::from(expected_sig.ct_eq(&provided_sig)) {
127 return Err(OAuthError::InvalidToken(
128 "Invalid token signature".to_string(),
129 ));
130 }
131 let payload_bytes = URL_SAFE_NO_PAD
132 .decode(parts[1])
133 .map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
134 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
135 .map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
136 let exp = payload
137 .get("exp")
138 .and_then(|e| e.as_i64())
139 .ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
140 let now = chrono::Utc::now().timestamp();
141 if exp < now {
142 return Err(OAuthError::ExpiredToken("Token has expired".to_string()));
143 }
144 let token_id = payload
145 .get("sid")
146 .and_then(|j| j.as_str())
147 .ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?
148 .to_string();
149 let did = payload
150 .get("sub")
151 .and_then(|s| s.as_str())
152 .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?
153 .to_string();
154 let scope = payload
155 .get("scope")
156 .and_then(|s| s.as_str())
157 .map(|s| s.to_string());
158 let dpop_jkt = payload
159 .get("cnf")
160 .and_then(|c| c.get("jkt"))
161 .and_then(|j| j.as_str())
162 .map(|s| s.to_string());
163 let client_id = payload
164 .get("client_id")
165 .and_then(|c| c.as_str())
166 .map(|s| s.to_string())
167 .unwrap_or_default();
168 let controller_did = payload
169 .get("act")
170 .and_then(|a| a.get("sub"))
171 .and_then(|s| s.as_str())
172 .map(|s| s.to_string());
173 Ok(OAuthTokenInfo {
174 did,
175 token_id,
176 client_id,
177 scope,
178 dpop_jkt,
179 controller_did,
180 })
181}
182
183fn compute_ath(access_token: &str) -> String {
184 use sha2::Digest;
185 let mut hasher = Sha256::new();
186 hasher.update(access_token.as_bytes());
187 let hash = hasher.finalize();
188 URL_SAFE_NO_PAD.encode(hash)
189}
190
191pub fn generate_dpop_nonce() -> String {
192 let config = AuthConfig::get();
193 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
194 verifier.generate_nonce()
195}
196
197pub struct OAuthUser {
198 pub did: String,
199 pub client_id: Option<String>,
200 pub scope: Option<String>,
201 pub is_oauth: bool,
202 pub permissions: ScopePermissions,
203}
204
205pub struct OAuthAuthError {
206 pub status: StatusCode,
207 pub error: String,
208 pub message: String,
209 pub dpop_nonce: Option<String>,
210 pub www_authenticate: Option<String>,
211}
212
213impl IntoResponse for OAuthAuthError {
214 fn into_response(self) -> Response {
215 let mut response = (
216 self.status,
217 Json(json!({
218 "error": self.error,
219 "message": self.message
220 })),
221 )
222 .into_response();
223 if let Some(nonce) = self.dpop_nonce {
224 response
225 .headers_mut()
226 .insert("DPoP-Nonce", nonce.parse().unwrap());
227 }
228 if let Some(www_auth) = self.www_authenticate {
229 response
230 .headers_mut()
231 .insert("WWW-Authenticate", www_auth.parse().unwrap());
232 }
233 response
234 }
235}
236
237impl FromRequestParts<AppState> for OAuthUser {
238 type Rejection = OAuthAuthError;
239
240 async fn from_request_parts(
241 parts: &mut Parts,
242 state: &AppState,
243 ) -> Result<Self, Self::Rejection> {
244 let auth_header = parts
245 .headers
246 .get("Authorization")
247 .and_then(|v| v.to_str().ok())
248 .ok_or_else(|| OAuthAuthError {
249 status: StatusCode::UNAUTHORIZED,
250 error: "AuthenticationRequired".to_string(),
251 message: "Authorization header required".to_string(),
252 dpop_nonce: None,
253 www_authenticate: None,
254 })?;
255 let auth_header_trimmed = auth_header.trim();
256 let (token, is_dpop_token) = if auth_header_trimmed.len() >= 7
257 && auth_header_trimmed[..7].eq_ignore_ascii_case("bearer ")
258 {
259 (auth_header_trimmed[7..].trim(), false)
260 } else if auth_header_trimmed.len() >= 5
261 && auth_header_trimmed[..5].eq_ignore_ascii_case("dpop ")
262 {
263 (auth_header_trimmed[5..].trim(), true)
264 } else {
265 return Err(OAuthAuthError {
266 status: StatusCode::UNAUTHORIZED,
267 error: "InvalidRequest".to_string(),
268 message: "Invalid authorization scheme".to_string(),
269 dpop_nonce: None,
270 www_authenticate: None,
271 });
272 };
273 let dpop_proof = parts.headers.get("DPoP").and_then(|v| v.to_str().ok());
274 if let Ok(result) = try_legacy_auth(&state.db, token).await {
275 return Ok(OAuthUser {
276 did: result.did,
277 client_id: None,
278 scope: None,
279 is_oauth: false,
280 permissions: ScopePermissions::default(),
281 });
282 }
283 let http_method = parts.method.as_str();
284 let http_uri = crate::util::build_full_url(&parts.uri.to_string());
285 match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await
286 {
287 Ok(result) => {
288 let permissions = ScopePermissions::from_scope_string(result.scope.as_deref());
289 Ok(OAuthUser {
290 did: result.did,
291 client_id: Some(result.client_id),
292 scope: result.scope,
293 is_oauth: true,
294 permissions,
295 })
296 }
297 Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError {
298 status: StatusCode::UNAUTHORIZED,
299 error: "use_dpop_nonce".to_string(),
300 message: "DPoP nonce required".to_string(),
301 dpop_nonce: Some(nonce),
302 www_authenticate: Some("DPoP error=\"use_dpop_nonce\"".to_string()),
303 }),
304 Err(OAuthError::InvalidDpopProof(msg)) => {
305 let nonce = generate_dpop_nonce();
306 Err(OAuthAuthError {
307 status: StatusCode::UNAUTHORIZED,
308 error: "invalid_dpop_proof".to_string(),
309 message: msg,
310 dpop_nonce: Some(nonce),
311 www_authenticate: None,
312 })
313 }
314 Err(OAuthError::ExpiredToken(msg)) => {
315 let nonce = if is_dpop_token {
316 Some(generate_dpop_nonce())
317 } else {
318 None
319 };
320 let scheme = if is_dpop_token { "DPoP" } else { "Bearer" };
321 let www_auth = format!(
322 "{} error=\"invalid_token\", error_description=\"{}\"",
323 scheme, msg
324 );
325 Err(OAuthAuthError {
326 status: StatusCode::UNAUTHORIZED,
327 error: "ExpiredToken".to_string(),
328 message: msg,
329 dpop_nonce: nonce,
330 www_authenticate: Some(www_auth),
331 })
332 }
333 Err(OAuthError::InvalidToken(msg)) => {
334 let nonce = if is_dpop_token {
335 Some(generate_dpop_nonce())
336 } else {
337 None
338 };
339 let scheme = if is_dpop_token { "DPoP" } else { "Bearer" };
340 let www_auth = format!(
341 "{} error=\"invalid_token\", error_description=\"{}\"",
342 scheme, msg
343 );
344 Err(OAuthAuthError {
345 status: StatusCode::UNAUTHORIZED,
346 error: "InvalidToken".to_string(),
347 message: msg,
348 dpop_nonce: nonce,
349 www_authenticate: Some(www_auth),
350 })
351 }
352 Err(e) => {
353 let nonce = if is_dpop_token {
354 Some(generate_dpop_nonce())
355 } else {
356 None
357 };
358 Err(OAuthAuthError {
359 status: StatusCode::UNAUTHORIZED,
360 error: "AuthenticationFailed".to_string(),
361 message: format!("{:?}", e),
362 dpop_nonce: nonce,
363 www_authenticate: None,
364 })
365 }
366 }
367 }
368}
369
/// Identity recovered from a legacy (non-OAuth) bearer token.
struct LegacyAuthResult {
    /// DID of the authenticated account.
    did: String,
}
373
374async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> {
375 match crate::auth::validate_bearer_token(pool, token).await {
376 Ok(user) if !user.is_oauth => Ok(LegacyAuthResult {
377 did: user.did.to_string(),
378 }),
379 _ => Err(()),
380 }
381}