this repo has no description
1use axum::{
2 Form, Json,
3 extract::State,
4 http::{HeaderMap, StatusCode},
5};
6use base64::Engine;
7use base64::engine::general_purpose::URL_SAFE_NO_PAD;
8use chrono::{Duration, Utc};
9use hmac::Mac;
10use serde::{Deserialize, Serialize};
11use sha2::{Digest, Sha256};
12use subtle::ConstantTimeEq;
13
14use crate::config::AuthConfig;
15use crate::state::AppState;
16use crate::oauth::{
17 ClientAuth, OAuthError, RefreshToken, TokenData, TokenId,
18 client::{ClientMetadataCache, verify_client_auth},
19 db,
20 dpop::DPoPVerifier,
21};
22
/// Lifetime of an issued access token, in seconds (one hour). Used both for the
/// JWT `exp` claim and for the `expires_in` field of token responses.
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 60 * 60;
/// Lifetime, in days, of a stored token row and therefore of its refresh token.
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
25
26#[derive(Debug, Deserialize)]
27pub struct TokenRequest {
28 pub grant_type: String,
29 #[serde(default)]
30 pub code: Option<String>,
31 #[serde(default)]
32 pub redirect_uri: Option<String>,
33 #[serde(default)]
34 pub code_verifier: Option<String>,
35 #[serde(default)]
36 pub refresh_token: Option<String>,
37 #[serde(default)]
38 pub client_id: Option<String>,
39 #[serde(default)]
40 pub client_secret: Option<String>,
41 #[serde(default)]
42 pub client_assertion: Option<String>,
43 #[serde(default)]
44 pub client_assertion_type: Option<String>,
45}
46
47#[derive(Debug, Serialize)]
48pub struct TokenResponse {
49 pub access_token: String,
50 pub token_type: String,
51 pub expires_in: u64,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub refresh_token: Option<String>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub scope: Option<String>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub sub: Option<String>,
58}
59
60pub async fn token_endpoint(
61 State(state): State<AppState>,
62 headers: HeaderMap,
63 Form(request): Form<TokenRequest>,
64) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
65 let dpop_proof = headers
66 .get("DPoP")
67 .and_then(|v| v.to_str().ok())
68 .map(|s| s.to_string());
69
70 match request.grant_type.as_str() {
71 "authorization_code" => {
72 handle_authorization_code_grant(state, headers, request, dpop_proof).await
73 }
74 "refresh_token" => {
75 handle_refresh_token_grant(state, headers, request, dpop_proof).await
76 }
77 _ => Err(OAuthError::UnsupportedGrantType(format!(
78 "Unsupported grant_type: {}",
79 request.grant_type
80 ))),
81 }
82}
83
84async fn handle_authorization_code_grant(
85 state: AppState,
86 _headers: HeaderMap,
87 request: TokenRequest,
88 dpop_proof: Option<String>,
89) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
90 let code = request
91 .code
92 .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
93
94 let code_verifier = request
95 .code_verifier
96 .ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?;
97
98 let auth_request = db::consume_authorization_request_by_code(&state.db, &code)
99 .await?
100 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
101
102 if auth_request.expires_at < Utc::now() {
103 return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string()));
104 }
105
106 if let Some(request_client_id) = &request.client_id {
107 if request_client_id != &auth_request.client_id {
108 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
109 }
110 }
111
112 let did = auth_request
113 .did
114 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
115
116 let client_metadata_cache = ClientMetadataCache::new(3600);
117 let client_metadata = client_metadata_cache
118 .get(&auth_request.client_id)
119 .await?;
120 let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None);
121 verify_client_auth(&client_metadata, &client_auth)?;
122
123 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
124
125 if let Some(redirect_uri) = &request.redirect_uri {
126 if redirect_uri != &auth_request.parameters.redirect_uri {
127 return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string()));
128 }
129 }
130
131 let dpop_jkt = if let Some(proof) = &dpop_proof {
132 let config = AuthConfig::get();
133 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
134
135 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
136 let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
137
138 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
139
140 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
141 return Err(OAuthError::InvalidDpopProof(
142 "DPoP proof has already been used".to_string(),
143 ));
144 }
145
146 if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt {
147 if &result.jkt != expected_jkt {
148 return Err(OAuthError::InvalidDpopProof(
149 "DPoP key binding mismatch".to_string(),
150 ));
151 }
152 }
153
154 Some(result.jkt)
155 } else if auth_request.parameters.dpop_jkt.is_some() {
156 return Err(OAuthError::InvalidRequest(
157 "DPoP proof required for this authorization".to_string(),
158 ));
159 } else {
160 None
161 };
162
163 let token_id = TokenId::generate();
164 let refresh_token = RefreshToken::generate();
165 let now = Utc::now();
166
167 let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?;
168
169 let token_data = TokenData {
170 did: did.clone(),
171 token_id: token_id.0.clone(),
172 created_at: now,
173 updated_at: now,
174 expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS),
175 client_id: auth_request.client_id.clone(),
176 client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None),
177 device_id: auth_request.device_id,
178 parameters: auth_request.parameters.clone(),
179 details: None,
180 code: None,
181 current_refresh_token: Some(refresh_token.0.clone()),
182 scope: auth_request.parameters.scope.clone(),
183 };
184
185 db::create_token(&state.db, &token_data).await?;
186
187 tokio::spawn({
188 let pool = state.db.clone();
189 let did_clone = did.clone();
190 async move {
191 if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await {
192 tracing::warn!("Failed to enforce token limit for user: {:?}", e);
193 }
194 }
195 });
196
197 let mut response_headers = HeaderMap::new();
198 let config = AuthConfig::get();
199 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
200 response_headers.insert(
201 "DPoP-Nonce",
202 verifier.generate_nonce().parse().unwrap(),
203 );
204
205 Ok((
206 response_headers,
207 Json(TokenResponse {
208 access_token,
209 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
210 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
211 refresh_token: Some(refresh_token.0),
212 scope: auth_request.parameters.scope,
213 sub: Some(did),
214 }),
215 ))
216}
217
218async fn handle_refresh_token_grant(
219 state: AppState,
220 _headers: HeaderMap,
221 request: TokenRequest,
222 dpop_proof: Option<String>,
223) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
224 let refresh_token_str = request
225 .refresh_token
226 .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
227
228 if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? {
229 db::delete_token_family(&state.db, token_id).await?;
230 return Err(OAuthError::InvalidGrant(
231 "Refresh token reuse detected, token family revoked".to_string(),
232 ));
233 }
234
235 let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str)
236 .await?
237 .ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?;
238
239 if token_data.expires_at < Utc::now() {
240 db::delete_token_family(&state.db, db_id).await?;
241 return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string()));
242 }
243
244 let dpop_jkt = if let Some(proof) = &dpop_proof {
245 let config = AuthConfig::get();
246 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
247
248 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
249 let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
250
251 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
252
253 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
254 return Err(OAuthError::InvalidDpopProof(
255 "DPoP proof has already been used".to_string(),
256 ));
257 }
258
259 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
260 if &result.jkt != expected_jkt {
261 return Err(OAuthError::InvalidDpopProof(
262 "DPoP key binding mismatch".to_string(),
263 ));
264 }
265 }
266
267 Some(result.jkt)
268 } else if token_data.parameters.dpop_jkt.is_some() {
269 return Err(OAuthError::InvalidRequest(
270 "DPoP proof required".to_string(),
271 ));
272 } else {
273 None
274 };
275
276 let new_token_id = TokenId::generate();
277 let new_refresh_token = RefreshToken::generate();
278 let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS);
279
280 db::rotate_token(
281 &state.db,
282 db_id,
283 &new_token_id.0,
284 &new_refresh_token.0,
285 new_expires_at,
286 )
287 .await?;
288
289 let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?;
290
291 let mut response_headers = HeaderMap::new();
292 let config = AuthConfig::get();
293 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
294 response_headers.insert(
295 "DPoP-Nonce",
296 verifier.generate_nonce().parse().unwrap(),
297 );
298
299 Ok((
300 response_headers,
301 Json(TokenResponse {
302 access_token,
303 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
304 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
305 refresh_token: Some(new_refresh_token.0),
306 scope: token_data.scope,
307 sub: Some(token_data.did),
308 }),
309 ))
310}
311
312fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> {
313 use subtle::ConstantTimeEq;
314
315 let mut hasher = Sha256::new();
316 hasher.update(code_verifier.as_bytes());
317 let hash = hasher.finalize();
318 let computed_challenge = URL_SAFE_NO_PAD.encode(&hash);
319
320 if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) {
321 return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string()));
322 }
323
324 Ok(())
325}
326
327fn create_access_token(
328 token_id: &str,
329 sub: &str,
330 dpop_jkt: Option<&str>,
331) -> Result<String, OAuthError> {
332 use serde_json::json;
333
334 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
335 let issuer = format!("https://{}", pds_hostname);
336
337 let now = Utc::now().timestamp();
338 let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
339
340 let mut payload = json!({
341 "iss": issuer,
342 "sub": sub,
343 "aud": issuer,
344 "iat": now,
345 "exp": exp,
346 "jti": token_id,
347 "scope": "atproto"
348 });
349
350 if let Some(jkt) = dpop_jkt {
351 payload["cnf"] = json!({ "jkt": jkt });
352 }
353
354 let header = json!({
355 "alg": "HS256",
356 "typ": "at+jwt"
357 });
358
359 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
360 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap());
361
362 let signing_input = format!("{}.{}", header_b64, payload_b64);
363
364 let config = AuthConfig::get();
365
366 use sha2::Sha256 as HmacSha256;
367 use hmac::{Hmac, Mac};
368 type HmacSha256Type = Hmac<HmacSha256>;
369
370 let mut mac = HmacSha256Type::new_from_slice(config.jwt_secret().as_bytes())
371 .map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?;
372 mac.update(signing_input.as_bytes());
373 let signature = mac.finalize().into_bytes();
374
375 let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
376
377 Ok(format!("{}.{}", signing_input, signature_b64))
378}
379
380pub async fn revoke_token(
381 State(state): State<AppState>,
382 Form(request): Form<RevokeRequest>,
383) -> Result<StatusCode, OAuthError> {
384 if let Some(token) = &request.token {
385 if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
386 db::delete_token_family(&state.db, db_id).await?;
387 } else {
388 db::delete_token(&state.db, token).await?;
389 }
390 }
391
392 Ok(StatusCode::OK)
393}
394
395#[derive(Debug, Deserialize)]
396pub struct RevokeRequest {
397 pub token: Option<String>,
398 #[serde(default)]
399 pub token_type_hint: Option<String>,
400}
401
402#[derive(Debug, Deserialize)]
403pub struct IntrospectRequest {
404 pub token: String,
405 #[serde(default)]
406 pub token_type_hint: Option<String>,
407}
408
409#[derive(Debug, Serialize)]
410pub struct IntrospectResponse {
411 pub active: bool,
412 #[serde(skip_serializing_if = "Option::is_none")]
413 pub scope: Option<String>,
414 #[serde(skip_serializing_if = "Option::is_none")]
415 pub client_id: Option<String>,
416 #[serde(skip_serializing_if = "Option::is_none")]
417 pub username: Option<String>,
418 #[serde(skip_serializing_if = "Option::is_none")]
419 pub token_type: Option<String>,
420 #[serde(skip_serializing_if = "Option::is_none")]
421 pub exp: Option<i64>,
422 #[serde(skip_serializing_if = "Option::is_none")]
423 pub iat: Option<i64>,
424 #[serde(skip_serializing_if = "Option::is_none")]
425 pub nbf: Option<i64>,
426 #[serde(skip_serializing_if = "Option::is_none")]
427 pub sub: Option<String>,
428 #[serde(skip_serializing_if = "Option::is_none")]
429 pub aud: Option<String>,
430 #[serde(skip_serializing_if = "Option::is_none")]
431 pub iss: Option<String>,
432 #[serde(skip_serializing_if = "Option::is_none")]
433 pub jti: Option<String>,
434}
435
436pub async fn introspect_token(
437 State(state): State<AppState>,
438 Form(request): Form<IntrospectRequest>,
439) -> Json<IntrospectResponse> {
440 let inactive_response = IntrospectResponse {
441 active: false,
442 scope: None,
443 client_id: None,
444 username: None,
445 token_type: None,
446 exp: None,
447 iat: None,
448 nbf: None,
449 sub: None,
450 aud: None,
451 iss: None,
452 jti: None,
453 };
454
455 let token_info = match extract_token_claims(&request.token) {
456 Ok(info) => info,
457 Err(_) => return Json(inactive_response),
458 };
459
460 let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
461 Ok(Some(data)) => data,
462 _ => return Json(inactive_response),
463 };
464
465 if token_data.expires_at < Utc::now() {
466 return Json(inactive_response);
467 }
468
469 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
470 let issuer = format!("https://{}", pds_hostname);
471
472 Json(IntrospectResponse {
473 active: true,
474 scope: token_data.scope,
475 client_id: Some(token_data.client_id),
476 username: None,
477 token_type: if token_data.parameters.dpop_jkt.is_some() {
478 Some("DPoP".to_string())
479 } else {
480 Some("Bearer".to_string())
481 },
482 exp: Some(token_info.exp),
483 iat: Some(token_info.iat),
484 nbf: Some(token_info.iat),
485 sub: Some(token_data.did),
486 aud: Some(issuer.clone()),
487 iss: Some(issuer),
488 jti: Some(token_info.jti),
489 })
490}
491
/// The subset of access-token claims needed by introspection and revocation.
struct TokenClaims {
    /// Token identifier; matches the stored token id.
    jti: String,
    /// Expiry, as a Unix timestamp in seconds.
    exp: i64,
    /// Issue time, as a Unix timestamp in seconds.
    iat: i64,
}
497
498fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> {
499 let parts: Vec<&str> = token.split('.').collect();
500 if parts.len() != 3 {
501 return Err(OAuthError::InvalidToken("Invalid token format".to_string()));
502 }
503
504 let header_bytes = URL_SAFE_NO_PAD
505 .decode(parts[0])
506 .map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?;
507 let header: serde_json::Value = serde_json::from_slice(&header_bytes)
508 .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?;
509
510 if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") {
511 return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string()));
512 }
513 if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") {
514 return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string()));
515 }
516
517 let config = AuthConfig::get();
518 let secret = config.jwt_secret();
519
520 let signing_input = format!("{}.{}", parts[0], parts[1]);
521 let provided_sig = URL_SAFE_NO_PAD
522 .decode(parts[2])
523 .map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?;
524
525 type HmacSha256 = hmac::Hmac<Sha256>;
526 let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
527 .map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?;
528 mac.update(signing_input.as_bytes());
529 let expected_sig = mac.finalize().into_bytes();
530
531 if !bool::from(expected_sig.ct_eq(&provided_sig)) {
532 return Err(OAuthError::InvalidToken("Invalid token signature".to_string()));
533 }
534
535 let payload_bytes = URL_SAFE_NO_PAD
536 .decode(parts[1])
537 .map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?;
538 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
539 .map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?;
540
541 let jti = payload
542 .get("jti")
543 .and_then(|j| j.as_str())
544 .ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))?
545 .to_string();
546
547 let exp = payload
548 .get("exp")
549 .and_then(|e| e.as_i64())
550 .ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?;
551
552 let iat = payload
553 .get("iat")
554 .and_then(|i| i.as_i64())
555 .ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?;
556
557 Ok(TokenClaims { jti, exp, iat })
558}