// This repository has no description.
1use axum::{ 2 Form, Json, 3 extract::State, 4 http::{HeaderMap, StatusCode}, 5}; 6use base64::Engine; 7use base64::engine::general_purpose::URL_SAFE_NO_PAD; 8use chrono::{Duration, Utc}; 9use hmac::Mac; 10use serde::{Deserialize, Serialize}; 11use sha2::{Digest, Sha256}; 12use subtle::ConstantTimeEq; 13 14use crate::config::AuthConfig; 15use crate::state::AppState; 16use crate::oauth::{ 17 ClientAuth, OAuthError, RefreshToken, TokenData, TokenId, 18 client::{ClientMetadataCache, verify_client_auth}, 19 db, 20 dpop::DPoPVerifier, 21}; 22 23const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 3600; 24const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60; 25 26#[derive(Debug, Deserialize)] 27pub struct TokenRequest { 28 pub grant_type: String, 29 #[serde(default)] 30 pub code: Option<String>, 31 #[serde(default)] 32 pub redirect_uri: Option<String>, 33 #[serde(default)] 34 pub code_verifier: Option<String>, 35 #[serde(default)] 36 pub refresh_token: Option<String>, 37 #[serde(default)] 38 pub client_id: Option<String>, 39 #[serde(default)] 40 pub client_secret: Option<String>, 41 #[serde(default)] 42 pub client_assertion: Option<String>, 43 #[serde(default)] 44 pub client_assertion_type: Option<String>, 45} 46 47#[derive(Debug, Serialize)] 48pub struct TokenResponse { 49 pub access_token: String, 50 pub token_type: String, 51 pub expires_in: u64, 52 #[serde(skip_serializing_if = "Option::is_none")] 53 pub refresh_token: Option<String>, 54 #[serde(skip_serializing_if = "Option::is_none")] 55 pub scope: Option<String>, 56 #[serde(skip_serializing_if = "Option::is_none")] 57 pub sub: Option<String>, 58} 59 60pub async fn token_endpoint( 61 State(state): State<AppState>, 62 headers: HeaderMap, 63 Form(request): Form<TokenRequest>, 64) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 65 let dpop_proof = headers 66 .get("DPoP") 67 .and_then(|v| v.to_str().ok()) 68 .map(|s| s.to_string()); 69 70 match request.grant_type.as_str() { 71 "authorization_code" => { 72 
handle_authorization_code_grant(state, headers, request, dpop_proof).await 73 } 74 "refresh_token" => { 75 handle_refresh_token_grant(state, headers, request, dpop_proof).await 76 } 77 _ => Err(OAuthError::UnsupportedGrantType(format!( 78 "Unsupported grant_type: {}", 79 request.grant_type 80 ))), 81 } 82} 83 84async fn handle_authorization_code_grant( 85 state: AppState, 86 _headers: HeaderMap, 87 request: TokenRequest, 88 dpop_proof: Option<String>, 89) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 90 let code = request 91 .code 92 .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?; 93 94 let code_verifier = request 95 .code_verifier 96 .ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?; 97 98 let auth_request = db::consume_authorization_request_by_code(&state.db, &code) 99 .await? 100 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?; 101 102 if auth_request.expires_at < Utc::now() { 103 return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string())); 104 } 105 106 if let Some(request_client_id) = &request.client_id { 107 if request_client_id != &auth_request.client_id { 108 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 109 } 110 } 111 112 let did = auth_request 113 .did 114 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?; 115 116 let client_metadata_cache = ClientMetadataCache::new(3600); 117 let client_metadata = client_metadata_cache 118 .get(&auth_request.client_id) 119 .await?; 120 let client_auth = auth_request.client_auth.clone().unwrap_or(ClientAuth::None); 121 verify_client_auth(&client_metadata, &client_auth)?; 122 123 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 124 125 if let Some(redirect_uri) = &request.redirect_uri { 126 if redirect_uri != &auth_request.parameters.redirect_uri { 127 return Err(OAuthError::InvalidGrant("redirect_uri 
mismatch".to_string())); 128 } 129 } 130 131 let dpop_jkt = if let Some(proof) = &dpop_proof { 132 let config = AuthConfig::get(); 133 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 134 135 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 136 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 137 138 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 139 140 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 141 return Err(OAuthError::InvalidDpopProof( 142 "DPoP proof has already been used".to_string(), 143 )); 144 } 145 146 if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt { 147 if &result.jkt != expected_jkt { 148 return Err(OAuthError::InvalidDpopProof( 149 "DPoP key binding mismatch".to_string(), 150 )); 151 } 152 } 153 154 Some(result.jkt) 155 } else if auth_request.parameters.dpop_jkt.is_some() { 156 return Err(OAuthError::InvalidRequest( 157 "DPoP proof required for this authorization".to_string(), 158 )); 159 } else { 160 None 161 }; 162 163 let token_id = TokenId::generate(); 164 let refresh_token = RefreshToken::generate(); 165 let now = Utc::now(); 166 167 let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?; 168 169 let token_data = TokenData { 170 did: did.clone(), 171 token_id: token_id.0.clone(), 172 created_at: now, 173 updated_at: now, 174 expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS), 175 client_id: auth_request.client_id.clone(), 176 client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None), 177 device_id: auth_request.device_id, 178 parameters: auth_request.parameters.clone(), 179 details: None, 180 code: None, 181 current_refresh_token: Some(refresh_token.0.clone()), 182 scope: auth_request.parameters.scope.clone(), 183 }; 184 185 db::create_token(&state.db, &token_data).await?; 186 187 tokio::spawn({ 188 let pool = state.db.clone(); 189 let did_clone = 
did.clone(); 190 async move { 191 if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await { 192 tracing::warn!("Failed to enforce token limit for user: {:?}", e); 193 } 194 } 195 }); 196 197 let mut response_headers = HeaderMap::new(); 198 let config = AuthConfig::get(); 199 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 200 response_headers.insert( 201 "DPoP-Nonce", 202 verifier.generate_nonce().parse().unwrap(), 203 ); 204 205 Ok(( 206 response_headers, 207 Json(TokenResponse { 208 access_token, 209 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 210 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 211 refresh_token: Some(refresh_token.0), 212 scope: auth_request.parameters.scope, 213 sub: Some(did), 214 }), 215 )) 216} 217 218async fn handle_refresh_token_grant( 219 state: AppState, 220 _headers: HeaderMap, 221 request: TokenRequest, 222 dpop_proof: Option<String>, 223) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { 224 let refresh_token_str = request 225 .refresh_token 226 .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?; 227 228 if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? { 229 db::delete_token_family(&state.db, token_id).await?; 230 return Err(OAuthError::InvalidGrant( 231 "Refresh token reuse detected, token family revoked".to_string(), 232 )); 233 } 234 235 let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str) 236 .await? 
237 .ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?; 238 239 if token_data.expires_at < Utc::now() { 240 db::delete_token_family(&state.db, db_id).await?; 241 return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string())); 242 } 243 244 let dpop_jkt = if let Some(proof) = &dpop_proof { 245 let config = AuthConfig::get(); 246 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 247 248 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 249 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 250 251 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 252 253 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 254 return Err(OAuthError::InvalidDpopProof( 255 "DPoP proof has already been used".to_string(), 256 )); 257 } 258 259 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt { 260 if &result.jkt != expected_jkt { 261 return Err(OAuthError::InvalidDpopProof( 262 "DPoP key binding mismatch".to_string(), 263 )); 264 } 265 } 266 267 Some(result.jkt) 268 } else if token_data.parameters.dpop_jkt.is_some() { 269 return Err(OAuthError::InvalidRequest( 270 "DPoP proof required".to_string(), 271 )); 272 } else { 273 None 274 }; 275 276 let new_token_id = TokenId::generate(); 277 let new_refresh_token = RefreshToken::generate(); 278 let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS); 279 280 db::rotate_token( 281 &state.db, 282 db_id, 283 &new_token_id.0, 284 &new_refresh_token.0, 285 new_expires_at, 286 ) 287 .await?; 288 289 let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?; 290 291 let mut response_headers = HeaderMap::new(); 292 let config = AuthConfig::get(); 293 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 294 response_headers.insert( 295 "DPoP-Nonce", 296 verifier.generate_nonce().parse().unwrap(), 297 ); 298 299 
Ok(( 300 response_headers, 301 Json(TokenResponse { 302 access_token, 303 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 304 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 305 refresh_token: Some(new_refresh_token.0), 306 scope: token_data.scope, 307 sub: Some(token_data.did), 308 }), 309 )) 310} 311 312fn verify_pkce(code_challenge: &str, code_verifier: &str) -> Result<(), OAuthError> { 313 use subtle::ConstantTimeEq; 314 315 let mut hasher = Sha256::new(); 316 hasher.update(code_verifier.as_bytes()); 317 let hash = hasher.finalize(); 318 let computed_challenge = URL_SAFE_NO_PAD.encode(&hash); 319 320 if !bool::from(computed_challenge.as_bytes().ct_eq(code_challenge.as_bytes())) { 321 return Err(OAuthError::InvalidGrant("PKCE verification failed".to_string())); 322 } 323 324 Ok(()) 325} 326 327fn create_access_token( 328 token_id: &str, 329 sub: &str, 330 dpop_jkt: Option<&str>, 331) -> Result<String, OAuthError> { 332 use serde_json::json; 333 334 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 335 let issuer = format!("https://{}", pds_hostname); 336 337 let now = Utc::now().timestamp(); 338 let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS; 339 340 let mut payload = json!({ 341 "iss": issuer, 342 "sub": sub, 343 "aud": issuer, 344 "iat": now, 345 "exp": exp, 346 "jti": token_id, 347 "scope": "atproto" 348 }); 349 350 if let Some(jkt) = dpop_jkt { 351 payload["cnf"] = json!({ "jkt": jkt }); 352 } 353 354 let header = json!({ 355 "alg": "HS256", 356 "typ": "at+jwt" 357 }); 358 359 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 360 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 361 362 let signing_input = format!("{}.{}", header_b64, payload_b64); 363 364 let config = AuthConfig::get(); 365 366 use sha2::Sha256 as HmacSha256; 367 use hmac::{Hmac, Mac}; 368 type HmacSha256Type = Hmac<HmacSha256>; 369 370 let mut mac = 
HmacSha256Type::new_from_slice(config.jwt_secret().as_bytes()) 371 .map_err(|_| OAuthError::ServerError("HMAC key error".to_string()))?; 372 mac.update(signing_input.as_bytes()); 373 let signature = mac.finalize().into_bytes(); 374 375 let signature_b64 = URL_SAFE_NO_PAD.encode(&signature); 376 377 Ok(format!("{}.{}", signing_input, signature_b64)) 378} 379 380pub async fn revoke_token( 381 State(state): State<AppState>, 382 Form(request): Form<RevokeRequest>, 383) -> Result<StatusCode, OAuthError> { 384 if let Some(token) = &request.token { 385 if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? { 386 db::delete_token_family(&state.db, db_id).await?; 387 } else { 388 db::delete_token(&state.db, token).await?; 389 } 390 } 391 392 Ok(StatusCode::OK) 393} 394 395#[derive(Debug, Deserialize)] 396pub struct RevokeRequest { 397 pub token: Option<String>, 398 #[serde(default)] 399 pub token_type_hint: Option<String>, 400} 401 402#[derive(Debug, Deserialize)] 403pub struct IntrospectRequest { 404 pub token: String, 405 #[serde(default)] 406 pub token_type_hint: Option<String>, 407} 408 409#[derive(Debug, Serialize)] 410pub struct IntrospectResponse { 411 pub active: bool, 412 #[serde(skip_serializing_if = "Option::is_none")] 413 pub scope: Option<String>, 414 #[serde(skip_serializing_if = "Option::is_none")] 415 pub client_id: Option<String>, 416 #[serde(skip_serializing_if = "Option::is_none")] 417 pub username: Option<String>, 418 #[serde(skip_serializing_if = "Option::is_none")] 419 pub token_type: Option<String>, 420 #[serde(skip_serializing_if = "Option::is_none")] 421 pub exp: Option<i64>, 422 #[serde(skip_serializing_if = "Option::is_none")] 423 pub iat: Option<i64>, 424 #[serde(skip_serializing_if = "Option::is_none")] 425 pub nbf: Option<i64>, 426 #[serde(skip_serializing_if = "Option::is_none")] 427 pub sub: Option<String>, 428 #[serde(skip_serializing_if = "Option::is_none")] 429 pub aud: Option<String>, 430 
#[serde(skip_serializing_if = "Option::is_none")] 431 pub iss: Option<String>, 432 #[serde(skip_serializing_if = "Option::is_none")] 433 pub jti: Option<String>, 434} 435 436pub async fn introspect_token( 437 State(state): State<AppState>, 438 Form(request): Form<IntrospectRequest>, 439) -> Json<IntrospectResponse> { 440 let inactive_response = IntrospectResponse { 441 active: false, 442 scope: None, 443 client_id: None, 444 username: None, 445 token_type: None, 446 exp: None, 447 iat: None, 448 nbf: None, 449 sub: None, 450 aud: None, 451 iss: None, 452 jti: None, 453 }; 454 455 let token_info = match extract_token_claims(&request.token) { 456 Ok(info) => info, 457 Err(_) => return Json(inactive_response), 458 }; 459 460 let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await { 461 Ok(Some(data)) => data, 462 _ => return Json(inactive_response), 463 }; 464 465 if token_data.expires_at < Utc::now() { 466 return Json(inactive_response); 467 } 468 469 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 470 let issuer = format!("https://{}", pds_hostname); 471 472 Json(IntrospectResponse { 473 active: true, 474 scope: token_data.scope, 475 client_id: Some(token_data.client_id), 476 username: None, 477 token_type: if token_data.parameters.dpop_jkt.is_some() { 478 Some("DPoP".to_string()) 479 } else { 480 Some("Bearer".to_string()) 481 }, 482 exp: Some(token_info.exp), 483 iat: Some(token_info.iat), 484 nbf: Some(token_info.iat), 485 sub: Some(token_data.did), 486 aud: Some(issuer.clone()), 487 iss: Some(issuer), 488 jti: Some(token_info.jti), 489 }) 490} 491 492struct TokenClaims { 493 jti: String, 494 exp: i64, 495 iat: i64, 496} 497 498fn extract_token_claims(token: &str) -> Result<TokenClaims, OAuthError> { 499 let parts: Vec<&str> = token.split('.').collect(); 500 if parts.len() != 3 { 501 return Err(OAuthError::InvalidToken("Invalid token format".to_string())); 502 } 503 504 let header_bytes = 
URL_SAFE_NO_PAD 505 .decode(parts[0]) 506 .map_err(|_| OAuthError::InvalidToken("Invalid token encoding".to_string()))?; 507 let header: serde_json::Value = serde_json::from_slice(&header_bytes) 508 .map_err(|_| OAuthError::InvalidToken("Invalid token header".to_string()))?; 509 510 if header.get("typ").and_then(|t| t.as_str()) != Some("at+jwt") { 511 return Err(OAuthError::InvalidToken("Not an OAuth access token".to_string())); 512 } 513 if header.get("alg").and_then(|a| a.as_str()) != Some("HS256") { 514 return Err(OAuthError::InvalidToken("Unsupported algorithm".to_string())); 515 } 516 517 let config = AuthConfig::get(); 518 let secret = config.jwt_secret(); 519 520 let signing_input = format!("{}.{}", parts[0], parts[1]); 521 let provided_sig = URL_SAFE_NO_PAD 522 .decode(parts[2]) 523 .map_err(|_| OAuthError::InvalidToken("Invalid signature encoding".to_string()))?; 524 525 type HmacSha256 = hmac::Hmac<Sha256>; 526 let mut mac = HmacSha256::new_from_slice(secret.as_bytes()) 527 .map_err(|_| OAuthError::ServerError("HMAC initialization failed".to_string()))?; 528 mac.update(signing_input.as_bytes()); 529 let expected_sig = mac.finalize().into_bytes(); 530 531 if !bool::from(expected_sig.ct_eq(&provided_sig)) { 532 return Err(OAuthError::InvalidToken("Invalid token signature".to_string())); 533 } 534 535 let payload_bytes = URL_SAFE_NO_PAD 536 .decode(parts[1]) 537 .map_err(|_| OAuthError::InvalidToken("Invalid payload encoding".to_string()))?; 538 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) 539 .map_err(|_| OAuthError::InvalidToken("Invalid token payload".to_string()))?; 540 541 let jti = payload 542 .get("jti") 543 .and_then(|j| j.as_str()) 544 .ok_or_else(|| OAuthError::InvalidToken("Missing jti claim".to_string()))? 
545 .to_string(); 546 547 let exp = payload 548 .get("exp") 549 .and_then(|e| e.as_i64()) 550 .ok_or_else(|| OAuthError::InvalidToken("Missing exp claim".to_string()))?; 551 552 let iat = payload 553 .get("iat") 554 .and_then(|i| i.as_i64()) 555 .ok_or_else(|| OAuthError::InvalidToken("Missing iat claim".to_string()))?; 556 557 Ok(TokenClaims { jti, exp, iat }) 558}