this repo has no description
1use axum::{
2 Json,
3 extract::FromRequestParts,
4 http::{StatusCode, request::Parts},
5 response::{IntoResponse, Response},
6};
7use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
8use hmac::{Hmac, Mac};
9use serde_json::json;
10use sha2::Sha256;
11use sqlx::PgPool;
12use subtle::ConstantTimeEq;
13
14use super::OAuthError;
15use super::db;
16use super::dpop::DPoPVerifier;
17use super::scopes::ScopePermissions;
18use crate::config::AuthConfig;
19use crate::state::AppState;
20
/// Claims read from a signature-verified OAuth access token by `extract_oauth_token_info`.
#[derive(Debug, Clone)]
pub struct OAuthTokenInfo {
    /// Subject DID (the `sub` claim).
    pub did: String,
    /// Token identifier (the `jti` claim); used as the database key for the token.
    pub token_id: String,
    /// The `client_id` claim, or an empty string when the claim is absent.
    pub client_id: String,
    /// The `scope` claim, if present.
    pub scope: Option<String>,
    /// DPoP key thumbprint from the `cnf.jkt` claim, if the token carries one.
    pub dpop_jkt: Option<String>,
    /// The `act.sub` claim, if present.
    pub controller_did: Option<String>,
}
28}
29
/// Outcome of a successful `verify_oauth_access_token` call.
#[derive(Debug, Clone)]
pub struct VerifyResult {
    /// DID of the account, taken from the stored token record.
    pub did: String,
    /// The token's `jti`, which is its database key.
    pub token_id: String,
    /// OAuth client the token was issued to, taken from the stored token record.
    pub client_id: String,
    /// Granted scope string, taken from the stored token record.
    pub scope: Option<String>,
}
36
37pub async fn verify_oauth_access_token(
38 pool: &PgPool,
39 access_token: &str,
40 dpop_proof: Option<&str>,
41 http_method: &str,
42 http_uri: &str,
43) -> Result<VerifyResult, OAuthError> {
44 let token_info = extract_oauth_token_info(access_token)?;
45 let token_data = db::get_token_by_id(pool, &token_info.token_id)
46 .await?
47 .ok_or_else(|| OAuthError::InvalidToken("Token not found or revoked".to_string()))?;
48 let now = chrono::Utc::now();
49 if token_data.expires_at < now {
50 return Err(OAuthError::InvalidToken("Token has expired".to_string()));
51 }
52 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
53 let proof = dpop_proof
54 .ok_or_else(|| OAuthError::UseDpopNonce("DPoP proof required".to_string()))?;
55 let config = AuthConfig::get();
56 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
57 let access_token_hash = compute_ath(access_token);
58 let result =
59 verifier.verify_proof(proof, http_method, http_uri, Some(&access_token_hash))?;
60 if !db::check_and_record_dpop_jti(pool, &result.jti).await? {
61 return Err(OAuthError::InvalidDpopProof(
62 "DPoP proof has already been used".to_string(),
63 ));
64 }
65 if &result.jkt != expected_jkt {
66 return Err(OAuthError::InvalidDpopProof(
67 "DPoP key binding mismatch".to_string(),
68 ));
69 }
70 }
71 Ok(VerifyResult {
72 did: token_data.did,
73 token_id: token_info.token_id,
74 client_id: token_data.client_id,
75 scope: token_data.scope,
76 })
77}
78
79pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> {
80 let parts: Vec<&str> = token.split('.').collect();
81 if parts.len() != 3 {
82 return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
83 }
84 let header_bytes = URL_SAFE_NO_PAD
85 .decode(parts[0])
86 .map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
87 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
88 .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
89 if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
90 return Err(OAuthError::InvalidToken(
91 "Not an OAuth access token".to_string(),
92 ));
93 }
94 if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
95 return Err(OAuthError::InvalidToken(
96 "Unsupported algorithm".to_string(),
97 ));
98 }
99 let config = AuthConfig::get();
100 let secret = config.jwt_secret();
101 let signing_input = format!("{}.{}", parts[0], parts[1]);
102 let provided_sig = URL_SAFE_NO_PAD
103 .decode(parts[2])
104 .map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
105 type HmacSha256 = Hmac<Sha256>;
106 let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
107 .map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
108 mac.update(signing_input.as_bytes());
109 let expected_sig = mac.finalize().into_bytes();
110 if !bool::from(expected_sig.ct_eq(&provided_sig)) {
111 return Err(OAuthError::InvalidToken(
112 "Invalid token signature".to_string(),
113 ));
114 }
115 let payload_bytes = URL_SAFE_NO_PAD
116 .decode(parts[1])
117 .map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
118 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
119 .map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
120 let exp = payload
121 .get("exp")
122 .and_then(|e| e.as_i64())
123 .ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
124 let now = chrono::Utc::now().timestamp();
125 if exp < now {
126 return Err(OAuthError::InvalidToken("Token has expired".to_string()));
127 }
128 let token_id = payload
129 .get("jti")
130 .and_then(|j| j.as_str())
131 .ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
132 .to_string();
133 let did = payload
134 .get("sub")
135 .and_then(|s| s.as_str())
136 .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?
137 .to_string();
138 let scope = payload
139 .get("scope")
140 .and_then(|s| s.as_str())
141 .map(|s| s.to_string());
142 let dpop_jkt = payload
143 .get("cnf")
144 .and_then(|c| c.get("jkt"))
145 .and_then(|j| j.as_str())
146 .map(|s| s.to_string());
147 let client_id = payload
148 .get("client_id")
149 .and_then(|c| c.as_str())
150 .map(|s| s.to_string())
151 .unwrap_or_default();
152 let controller_did = payload
153 .get("act")
154 .and_then(|a| a.get("sub"))
155 .and_then(|s| s.as_str())
156 .map(|s| s.to_string());
157 Ok(OAuthTokenInfo {
158 did,
159 token_id,
160 client_id,
161 scope,
162 dpop_jkt,
163 controller_did,
164 })
165}
166
167fn compute_ath(access_token: &str) -> String {
168 use sha2::Digest;
169 let mut hasher = Sha256::new();
170 hasher.update(access_token.as_bytes());
171 let hash = hasher.finalize();
172 URL_SAFE_NO_PAD.encode(hash)
173}
174
175pub fn generate_dpop_nonce() -> String {
176 let config = AuthConfig::get();
177 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
178 verifier.generate_nonce()
179}
180
181pub struct OAuthUser {
182 pub did: String,
183 pub client_id: Option<String>,
184 pub scope: Option<String>,
185 pub is_oauth: bool,
186 pub permissions: ScopePermissions,
187}
188
189pub struct OAuthAuthError {
190 pub status: StatusCode,
191 pub error: String,
192 pub message: String,
193 pub dpop_nonce: Option<String>,
194}
195
196impl IntoResponse for OAuthAuthError {
197 fn into_response(self) -> Response {
198 let mut response = (
199 self.status,
200 Json(json!({
201 "error": self.error,
202 "message": self.message
203 })),
204 )
205 .into_response();
206 if let Some(nonce) = self.dpop_nonce {
207 response
208 .headers_mut()
209 .insert("DPoP-Nonce", nonce.parse().unwrap());
210 }
211 response
212 }
213}
214
215impl FromRequestParts<AppState> for OAuthUser {
216 type Rejection = OAuthAuthError;
217
218 async fn from_request_parts(
219 parts: &mut Parts,
220 state: &AppState,
221 ) -> Result<Self, Self::Rejection> {
222 let auth_header = parts
223 .headers
224 .get("Authorization")
225 .and_then(|v| v.to_str().ok())
226 .ok_or_else(|| OAuthAuthError {
227 status: StatusCode::UNAUTHORIZED,
228 error: "AuthenticationRequired".to_string(),
229 message: "Authorization header required".to_string(),
230 dpop_nonce: None,
231 })?;
232 let auth_header_trimmed = auth_header.trim();
233 let (token, is_dpop_token) = if auth_header_trimmed.len() >= 7
234 && auth_header_trimmed[..7].eq_ignore_ascii_case("bearer ")
235 {
236 (auth_header_trimmed[7..].trim(), false)
237 } else if auth_header_trimmed.len() >= 5
238 && auth_header_trimmed[..5].eq_ignore_ascii_case("dpop ")
239 {
240 (auth_header_trimmed[5..].trim(), true)
241 } else {
242 return Err(OAuthAuthError {
243 status: StatusCode::UNAUTHORIZED,
244 error: "InvalidRequest".to_string(),
245 message: "Invalid authorization scheme".to_string(),
246 dpop_nonce: None,
247 });
248 };
249 let dpop_proof = parts.headers.get("DPoP").and_then(|v| v.to_str().ok());
250 if let Ok(result) = try_legacy_auth(&state.db, token).await {
251 return Ok(OAuthUser {
252 did: result.did,
253 client_id: None,
254 scope: None,
255 is_oauth: false,
256 permissions: ScopePermissions::default(),
257 });
258 }
259 let http_method = parts.method.as_str();
260 let http_uri = parts.uri.to_string();
261 match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await
262 {
263 Ok(result) => {
264 let permissions = ScopePermissions::from_scope_string(result.scope.as_deref());
265 Ok(OAuthUser {
266 did: result.did,
267 client_id: Some(result.client_id),
268 scope: result.scope,
269 is_oauth: true,
270 permissions,
271 })
272 }
273 Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError {
274 status: StatusCode::UNAUTHORIZED,
275 error: "use_dpop_nonce".to_string(),
276 message: "DPoP nonce required".to_string(),
277 dpop_nonce: Some(nonce),
278 }),
279 Err(OAuthError::InvalidDpopProof(msg)) => {
280 let nonce = generate_dpop_nonce();
281 Err(OAuthAuthError {
282 status: StatusCode::UNAUTHORIZED,
283 error: "invalid_dpop_proof".to_string(),
284 message: msg,
285 dpop_nonce: Some(nonce),
286 })
287 }
288 Err(e) => {
289 let nonce = if is_dpop_token {
290 Some(generate_dpop_nonce())
291 } else {
292 None
293 };
294 Err(OAuthAuthError {
295 status: StatusCode::UNAUTHORIZED,
296 error: "AuthenticationFailed".to_string(),
297 message: format!("{:?}", e),
298 dpop_nonce: nonce,
299 })
300 }
301 }
302 }
303}
304
/// Result of a successful legacy (non-OAuth) token check.
struct LegacyAuthResult {
    /// DID of the account the legacy token belongs to.
    did: String,
}
308
309async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> {
310 match crate::auth::validate_bearer_token(pool, token).await {
311 Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
312 _ => Err(()),
313 }
314}