this repo has no description
1use axum::{
2 Json,
3 extract::FromRequestParts,
4 http::{StatusCode, request::Parts},
5 response::{IntoResponse, Response},
6};
7use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
8use hmac::{Hmac, Mac};
9use serde_json::json;
10use sha2::Sha256;
11use sqlx::PgPool;
12use subtle::ConstantTimeEq;
13
14use super::OAuthError;
15use super::db;
16use super::dpop::DPoPVerifier;
17use super::scopes::ScopePermissions;
18use crate::config::AuthConfig;
19use crate::state::AppState;
20
/// Claims extracted from a signature-verified OAuth access token JWT.
///
/// Produced by `extract_oauth_token_info`; no database lookup has been
/// performed, so the token may still have been revoked.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthTokenInfo {
    /// Subject of the token (the `sub` claim).
    pub did: String,
    /// Token identifier (the `jti` claim), used to look the token up in storage.
    pub token_id: String,
    /// The `client_id` claim, or an empty string when the claim is absent.
    pub client_id: String,
    /// The raw `scope` claim, if present.
    pub scope: Option<String>,
    /// DPoP key thumbprint from the `cnf.jkt` claim, if the token carries one.
    pub dpop_jkt: Option<String>,
}
28
/// Result of a fully verified OAuth access token.
///
/// Returned by `verify_oauth_access_token` after the JWT has been validated,
/// the stored token record has been found and is unexpired, and any DPoP
/// binding has been checked.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifyResult {
    /// Account identifier taken from the stored token record.
    pub did: String,
    /// Token identifier (the JWT `jti` claim).
    pub token_id: String,
    /// Client identifier taken from the stored token record.
    pub client_id: String,
    /// Scope string taken from the stored token record, if any.
    pub scope: Option<String>,
}
35
36pub async fn verify_oauth_access_token(
37 pool: &PgPool,
38 access_token: &str,
39 dpop_proof: Option<&str>,
40 http_method: &str,
41 http_uri: &str,
42) -> Result<VerifyResult, OAuthError> {
43 let token_info = extract_oauth_token_info(access_token)?;
44 let token_data = db::get_token_by_id(pool, &token_info.token_id)
45 .await?
46 .ok_or_else(|| OAuthError::InvalidToken("Token not found or revoked".to_string()))?;
47 let now = chrono::Utc::now();
48 if token_data.expires_at < now {
49 return Err(OAuthError::InvalidToken("Token has expired".to_string()));
50 }
51 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
52 let proof = dpop_proof
53 .ok_or_else(|| OAuthError::UseDpopNonce("DPoP proof required".to_string()))?;
54 let config = AuthConfig::get();
55 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
56 let access_token_hash = compute_ath(access_token);
57 let result =
58 verifier.verify_proof(proof, http_method, http_uri, Some(&access_token_hash))?;
59 if !db::check_and_record_dpop_jti(pool, &result.jti).await? {
60 return Err(OAuthError::InvalidDpopProof(
61 "DPoP proof has already been used".to_string(),
62 ));
63 }
64 if &result.jkt != expected_jkt {
65 return Err(OAuthError::InvalidDpopProof(
66 "DPoP key binding mismatch".to_string(),
67 ));
68 }
69 }
70 Ok(VerifyResult {
71 did: token_data.did,
72 token_id: token_info.token_id,
73 client_id: token_data.client_id,
74 scope: token_data.scope,
75 })
76}
77
78pub fn extract_oauth_token_info(token: &str) -> Result<OAuthTokenInfo, OAuthError> {
79 let parts: Vec<&str> = token.split('.').collect();
80 if parts.len() != 3 {
81 return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
82 }
83 let header_bytes = URL_SAFE_NO_PAD
84 .decode(parts[0])
85 .map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
86 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
87 .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
88 if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
89 return Err(OAuthError::InvalidToken(
90 "Not an OAuth access token".to_string(),
91 ));
92 }
93 if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
94 return Err(OAuthError::InvalidToken(
95 "Unsupported algorithm".to_string(),
96 ));
97 }
98 let config = AuthConfig::get();
99 let secret = config.jwt_secret();
100 let signing_input = format!("{}.{}", parts[0], parts[1]);
101 let provided_sig = URL_SAFE_NO_PAD
102 .decode(parts[2])
103 .map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
104 type HmacSha256 = Hmac<Sha256>;
105 let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
106 .map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
107 mac.update(signing_input.as_bytes());
108 let expected_sig = mac.finalize().into_bytes();
109 if !bool::from(expected_sig.ct_eq(&provided_sig)) {
110 return Err(OAuthError::InvalidToken(
111 "Invalid token signature".to_string(),
112 ));
113 }
114 let payload_bytes = URL_SAFE_NO_PAD
115 .decode(parts[1])
116 .map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
117 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
118 .map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
119 let exp = payload
120 .get("exp")
121 .and_then(|e| e.as_i64())
122 .ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
123 let now = chrono::Utc::now().timestamp();
124 if exp < now {
125 return Err(OAuthError::InvalidToken("Token has expired".to_string()));
126 }
127 let token_id = payload
128 .get("jti")
129 .and_then(|j| j.as_str())
130 .ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
131 .to_string();
132 let did = payload
133 .get("sub")
134 .and_then(|s| s.as_str())
135 .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?
136 .to_string();
137 let scope = payload
138 .get("scope")
139 .and_then(|s| s.as_str())
140 .map(|s| s.to_string());
141 let dpop_jkt = payload
142 .get("cnf")
143 .and_then(|c| c.get("jkt"))
144 .and_then(|j| j.as_str())
145 .map(|s| s.to_string());
146 let client_id = payload
147 .get("client_id")
148 .and_then(|c| c.as_str())
149 .map(|s| s.to_string())
150 .unwrap_or_default();
151 Ok(OAuthTokenInfo {
152 did,
153 token_id,
154 client_id,
155 scope,
156 dpop_jkt,
157 })
158}
159
160fn compute_ath(access_token: &str) -> String {
161 use sha2::Digest;
162 let mut hasher = Sha256::new();
163 hasher.update(access_token.as_bytes());
164 let hash = hasher.finalize();
165 URL_SAFE_NO_PAD.encode(hash)
166}
167
168pub fn generate_dpop_nonce() -> String {
169 let config = AuthConfig::get();
170 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
171 verifier.generate_nonce()
172}
173
174pub struct OAuthUser {
175 pub did: String,
176 pub client_id: Option<String>,
177 pub scope: Option<String>,
178 pub is_oauth: bool,
179 pub permissions: ScopePermissions,
180}
181
182pub struct OAuthAuthError {
183 pub status: StatusCode,
184 pub error: String,
185 pub message: String,
186 pub dpop_nonce: Option<String>,
187}
188
189impl IntoResponse for OAuthAuthError {
190 fn into_response(self) -> Response {
191 let mut response = (
192 self.status,
193 Json(json!({
194 "error": self.error,
195 "message": self.message
196 })),
197 )
198 .into_response();
199 if let Some(nonce) = self.dpop_nonce {
200 response
201 .headers_mut()
202 .insert("DPoP-Nonce", nonce.parse().unwrap());
203 }
204 response
205 }
206}
207
208impl FromRequestParts<AppState> for OAuthUser {
209 type Rejection = OAuthAuthError;
210
211 async fn from_request_parts(
212 parts: &mut Parts,
213 state: &AppState,
214 ) -> Result<Self, Self::Rejection> {
215 let auth_header = parts
216 .headers
217 .get("Authorization")
218 .and_then(|v| v.to_str().ok())
219 .ok_or_else(|| OAuthAuthError {
220 status: StatusCode::UNAUTHORIZED,
221 error: "AuthenticationRequired".to_string(),
222 message: "Authorization header required".to_string(),
223 dpop_nonce: None,
224 })?;
225 let auth_header_trimmed = auth_header.trim();
226 let (token, is_dpop_token) = if auth_header_trimmed.len() >= 7
227 && auth_header_trimmed[..7].eq_ignore_ascii_case("bearer ")
228 {
229 (auth_header_trimmed[7..].trim(), false)
230 } else if auth_header_trimmed.len() >= 5
231 && auth_header_trimmed[..5].eq_ignore_ascii_case("dpop ")
232 {
233 (auth_header_trimmed[5..].trim(), true)
234 } else {
235 return Err(OAuthAuthError {
236 status: StatusCode::UNAUTHORIZED,
237 error: "InvalidRequest".to_string(),
238 message: "Invalid authorization scheme".to_string(),
239 dpop_nonce: None,
240 });
241 };
242 let dpop_proof = parts.headers.get("DPoP").and_then(|v| v.to_str().ok());
243 if let Ok(result) = try_legacy_auth(&state.db, token).await {
244 return Ok(OAuthUser {
245 did: result.did,
246 client_id: None,
247 scope: None,
248 is_oauth: false,
249 permissions: ScopePermissions::default(),
250 });
251 }
252 let http_method = parts.method.as_str();
253 let http_uri = parts.uri.to_string();
254 match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await
255 {
256 Ok(result) => {
257 let permissions = ScopePermissions::from_scope_string(result.scope.as_deref());
258 Ok(OAuthUser {
259 did: result.did,
260 client_id: Some(result.client_id),
261 scope: result.scope,
262 is_oauth: true,
263 permissions,
264 })
265 }
266 Err(OAuthError::UseDpopNonce(nonce)) => Err(OAuthAuthError {
267 status: StatusCode::UNAUTHORIZED,
268 error: "use_dpop_nonce".to_string(),
269 message: "DPoP nonce required".to_string(),
270 dpop_nonce: Some(nonce),
271 }),
272 Err(OAuthError::InvalidDpopProof(msg)) => {
273 let nonce = generate_dpop_nonce();
274 Err(OAuthAuthError {
275 status: StatusCode::UNAUTHORIZED,
276 error: "invalid_dpop_proof".to_string(),
277 message: msg,
278 dpop_nonce: Some(nonce),
279 })
280 }
281 Err(e) => {
282 let nonce = if is_dpop_token {
283 Some(generate_dpop_nonce())
284 } else {
285 None
286 };
287 Err(OAuthAuthError {
288 status: StatusCode::UNAUTHORIZED,
289 error: "AuthenticationFailed".to_string(),
290 message: format!("{:?}", e),
291 dpop_nonce: nonce,
292 })
293 }
294 }
295 }
296}
297
/// Successful outcome of legacy (non-OAuth) authentication.
struct LegacyAuthResult {
    /// Account identifier of the authenticated caller.
    did: String,
}
301
302async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> {
303 match crate::auth::validate_bearer_token(pool, token).await {
304 Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }),
305 _ => Err(()),
306 }
307}