use super::{AuthRules, HandleCache, SessionData}; use crate::helpers::json_error_response; use crate::AppState; use axum::extract::{Request, State}; use axum::http::{HeaderMap, StatusCode}; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use jacquard_identity::resolver::IdentityResolver; use jacquard_identity::PublicResolver; use jwt_compact::alg::{Hs256, Hs256Key}; use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError}; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; use tracing::log; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum AuthScheme { Bearer, DPoP, } #[derive(Serialize, Deserialize)] pub struct TokenClaims { pub sub: String, /// OAuth scopes as space-separated string (per OAuth 2.0 spec) #[serde(default)] pub scope: Option, } /// State passed to the auth middleware containing both AppState and auth rules. #[derive(Clone)] pub struct AuthMiddlewareState { pub app_state: AppState, pub rules: AuthRules, } /// Core middleware function that validates authentication and applies auth rules. /// /// Use this with `axum::middleware::from_fn_with_state`: /// ```ignore /// use axum::middleware::from_fn_with_state; /// /// let mw_state = AuthMiddlewareState { /// app_state: state.clone(), /// rules: AuthRules::HandleEndsWith(".blacksky.team".into()), /// }; /// /// .route("/protected", get(handler).layer(from_fn_with_state(mw_state, auth_middleware))) /// ``` pub async fn auth_middleware( State(mw_state): State, req: Request, next: Next, ) -> Response { let AuthMiddlewareState { app_state, rules } = mw_state; // 1. 
Extract DID and scopes from JWT (Bearer token) let extracted = match extract_auth_from_request(req.headers()) { Ok(Some(auth)) => auth, Ok(None) => { return json_error_response(StatusCode::UNAUTHORIZED, "AuthRequired", "Authentication required") .unwrap_or_else(|_| StatusCode::UNAUTHORIZED.into_response()); } Err(e) => { log::error!("Token extraction error: {}", e); return json_error_response(StatusCode::UNAUTHORIZED, "InvalidToken", &e) .unwrap_or_else(|_| StatusCode::UNAUTHORIZED.into_response()); } }; // 2. Resolve DID to handle (check cache first) let handle = match resolve_did_to_handle(&app_state.resolver, &app_state.handle_cache, &extracted.did).await { Ok(handle) => handle, Err(e) => { log::error!("Failed to resolve DID {} to handle: {}", extracted.did, e); return json_error_response( StatusCode::INTERNAL_SERVER_ERROR, "ResolutionError", "Failed to resolve identity", ) .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response()); } }; // 3. Build session data and validate rules let session = SessionData { did: extracted.did, handle, scopes: extracted.scopes, }; if !rules.validate(&session) { return json_error_response(StatusCode::FORBIDDEN, "AccessDenied", "Access denied by authorization rules") .unwrap_or_else(|_| StatusCode::FORBIDDEN.into_response()); } // 4. Pass through on success next.run(req).await } /// Extracted authentication data from JWT struct ExtractedAuth { did: String, scopes: Vec, } /// Extracts the DID and scopes from the Authorization header (Bearer JWT). 
fn extract_auth_from_request(headers: &HeaderMap) -> Result, String> { let auth = extract_auth(headers)?; match auth { None => Ok(None), Some((scheme, token_str)) => { match scheme { AuthScheme::Bearer => { let token = UntrustedToken::new(&token_str) .map_err(|_| "Invalid token format".to_string())?; let _claims: Claims = token .deserialize_claims_unchecked() .map_err(|_| "Failed to parse token claims".to_string())?; let key = Hs256Key::new( env::var("PDS_JWT_SECRET") .map_err(|_| "PDS_JWT_SECRET not configured".to_string())?, ); let validated: Token = Hs256 .validator(&key) .validate(&token) .map_err(|e: ValidationError| format!("Token validation failed: {:?}", e))?; let custom = &validated.claims().custom; // Parse scopes from space-separated string (OAuth 2.0 spec) let scopes: Vec = custom.scope .as_ref() .map(|s| s.split_whitespace().map(|s| s.to_string()).collect()) .unwrap_or_default(); Ok(Some(ExtractedAuth { did: custom.sub.clone(), scopes, })) } AuthScheme::DPoP => { // DPoP tokens are not validated here; pass through without auth data Ok(None) } } } } } /// Extracts the authentication scheme and token from the Authorization header. fn extract_auth(headers: &HeaderMap) -> Result, String> { match headers.get(axum::http::header::AUTHORIZATION) { None => Ok(None), Some(hv) => { let s = hv .to_str() .map_err(|_| "Authorization header is not valid UTF-8".to_string())?; let mut parts = s.splitn(2, ' '); match (parts.next(), parts.next()) { (Some("Bearer"), Some(tok)) if !tok.is_empty() => { Ok(Some((AuthScheme::Bearer, tok.to_string()))) } (Some("DPoP"), Some(tok)) if !tok.is_empty() => { Ok(Some((AuthScheme::DPoP, tok.to_string()))) } _ => Err( "Authorization header must be in format 'Bearer ' or 'DPoP '" .to_string(), ), } } } } /// Resolves a DID to its handle using the PublicResolver, with caching. 
async fn resolve_did_to_handle( resolver: &Arc, cache: &HandleCache, did: &str, ) -> Result { // Check cache first if let Some(handle) = cache.get(did) { return Ok(handle); } // Parse the DID let did_parsed = jacquard_common::types::did::Did::new(did) .map_err(|e| format!("Invalid DID: {:?}", e))?; // Resolve the DID document let did_doc_response = resolver .resolve_did_doc(&did_parsed) .await .map_err(|e| format!("DID resolution failed: {:?}", e))?; let doc = did_doc_response .parse() .map_err(|e| format!("Failed to parse DID document: {:?}", e))?; // Extract handle from alsoKnownAs field // Format is typically: ["at://handle.example.com"] let handle: String = doc .also_known_as .as_ref() .and_then(|aka| { aka.iter() .find(|uri| uri.starts_with("at://")) .map(|uri| uri.strip_prefix("at://").unwrap_or(uri.as_ref()).to_string()) }) .ok_or_else(|| "No ATProto handle found in DID document".to_string())?; // Cache the result cache.insert(did.to_string(), handle.clone()); Ok(handle) } // ============================================================================ // Helper Functions for Creating Middleware State // ============================================================================ /// Creates an `AuthMiddlewareState` for requiring the handle to end with a specific suffix. /// /// # Example /// ```ignore /// use axum::middleware::from_fn_with_state; /// use crate::auth::{auth_middleware, handle_ends_with}; /// /// .route("/protected", get(handler).layer( /// from_fn_with_state(handle_ends_with(".blacksky.team", &state), auth_middleware) /// )) /// ``` pub fn handle_ends_with(suffix: impl Into, state: &AppState) -> AuthMiddlewareState { AuthMiddlewareState { app_state: state.clone(), rules: AuthRules::HandleEndsWith(suffix.into()), } } /// Creates an `AuthMiddlewareState` for requiring the handle to end with any of the specified suffixes. 
///
/// Every item of `suffixes` is converted to a `String` and stored in
/// `AuthRules::HandleEndsWithAny`.
///
/// # Example
/// ```ignore
/// from_fn_with_state(
///     handle_ends_with_any([".blacksky.team", ".blacksky.app"], &state),
///     auth_middleware,
/// )
/// ```
pub fn handle_ends_with_any<I, T>(suffixes: I, state: &AppState) -> AuthMiddlewareState
where
    I: IntoIterator<Item = T>,
    T: Into<String>,
{
    with_rules(
        AuthRules::HandleEndsWithAny(suffixes.into_iter().map(Into::into).collect()),
        state,
    )
}

/// Creates an `AuthMiddlewareState` for requiring the DID to equal a specific value.
pub fn did_equals(did: impl Into<String>, state: &AppState) -> AuthMiddlewareState {
    with_rules(AuthRules::DidEquals(did.into()), state)
}

/// Creates an `AuthMiddlewareState` for requiring the DID to be one of the specified values.
pub fn did_equals_any<I, T>(dids: I, state: &AppState) -> AuthMiddlewareState
where
    I: IntoIterator<Item = T>,
    T: Into<String>,
{
    with_rules(
        AuthRules::DidEqualsAny(dids.into_iter().map(Into::into).collect()),
        state,
    )
}

/// Creates an `AuthMiddlewareState` with custom auth rules.
///
/// The `AppState` is cloned into the returned value.
pub fn with_rules(rules: AuthRules, state: &AppState) -> AuthMiddlewareState {
    AuthMiddlewareState {
        rules,
        app_state: state.clone(),
    }
}

// ============================================================================
// Scope Helper Functions
// ============================================================================

/// Creates an `AuthMiddlewareState` requiring a specific OAuth scope.
///
/// # Example
/// ```ignore
/// .route("/xrpc/com.atproto.repo.createRecord",
///     post(handler).layer(from_fn_with_state(
///         scope_equals("repo:app.bsky.feed.post", &state),
///         auth_middleware
///     )))
/// ```
pub fn scope_equals(scope: impl Into<String>, state: &AppState) -> AuthMiddlewareState {
    with_rules(AuthRules::ScopeEquals(scope.into()), state)
}

/// Creates an `AuthMiddlewareState` requiring ANY of the specified scopes (OR logic).
///
/// The session passes when it holds at least one of `scopes`.
///
/// # Example
/// ```ignore
/// .route("/xrpc/com.atproto.repo.putRecord",
///     post(handler).layer(from_fn_with_state(
///         scope_any(["repo:app.bsky.feed.post", "transition:generic"], &state),
///         auth_middleware
///     )))
/// ```
pub fn scope_any<I, T>(scopes: I, state: &AppState) -> AuthMiddlewareState
where
    I: IntoIterator<Item = T>,
    T: Into<String>,
{
    let app_state = state.clone();
    let rules = AuthRules::ScopeEqualsAny(scopes.into_iter().map(Into::into).collect());
    AuthMiddlewareState { rules, app_state }
}

/// Creates an `AuthMiddlewareState` requiring ALL of the specified scopes (AND logic).
///
/// The session passes only when it holds every one of `scopes`.
///
/// # Example
/// ```ignore
/// .route("/xrpc/com.atproto.admin.updateAccount",
///     post(handler).layer(from_fn_with_state(
///         scope_all(["account:email", "account:repo?action=manage"], &state),
///         auth_middleware
///     )))
/// ```
pub fn scope_all<I, T>(scopes: I, state: &AppState) -> AuthMiddlewareState
where
    I: IntoIterator<Item = T>,
    T: Into<String>,
{
    let app_state = state.clone();
    let rules = AuthRules::ScopeEqualsAll(scopes.into_iter().map(Into::into).collect());
    AuthMiddlewareState { rules, app_state }
}

// ============================================================================
// Combined Rule Helpers (Identity + Scope)
// ============================================================================

/// Creates an `AuthMiddlewareState` requiring handle to end with suffix AND have a specific scope.
///
/// Both rules are wrapped in `AuthRules::All`, handle check first.
///
/// # Example
/// ```ignore
/// .route("/xrpc/community.blacksky.feed.generator",
///     post(handler).layer(from_fn_with_state(
///         handle_ends_with_and_scope(".blacksky.team", "transition:generic", &state),
///         auth_middleware
///     )))
/// ```
pub fn handle_ends_with_and_scope(
    suffix: impl Into<String>,
    scope: impl Into<String>,
    state: &AppState,
) -> AuthMiddlewareState {
    let app_state = state.clone();
    let handle_rule = AuthRules::HandleEndsWith(suffix.into());
    let scope_rule = AuthRules::ScopeEquals(scope.into());
    AuthMiddlewareState {
        rules: AuthRules::All(vec![handle_rule, scope_rule]),
        app_state,
    }
}

/// Creates an `AuthMiddlewareState` requiring handle to end with suffix AND have ALL specified scopes.
/// /// # Example /// ```ignore /// .route("/xrpc/community.blacksky.admin.manage", /// post(handler).layer(from_fn_with_state( /// handle_ends_with_and_scopes(".blacksky.team", ["transition:generic", "identity:*"], &state), /// auth_middleware /// ))) /// ``` pub fn handle_ends_with_and_scopes( suffix: impl Into, scopes: I, state: &AppState, ) -> AuthMiddlewareState where I: IntoIterator, T: Into, { AuthMiddlewareState { app_state: state.clone(), rules: AuthRules::All(vec![ AuthRules::HandleEndsWith(suffix.into()), AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()), ]), } } /// Creates an `AuthMiddlewareState` requiring DID to equal value AND have a specific scope. /// /// # Example /// ```ignore /// .route("/xrpc/com.atproto.admin.deleteAccount", /// post(handler).layer(from_fn_with_state( /// did_with_scope("did:plc:rnpkyqnmsw4ipey6eotbdnnf", "transition:generic", &state), /// auth_middleware /// ))) /// ``` pub fn did_with_scope( did: impl Into, scope: impl Into, state: &AppState, ) -> AuthMiddlewareState { AuthMiddlewareState { app_state: state.clone(), rules: AuthRules::All(vec![ AuthRules::DidEquals(did.into()), AuthRules::ScopeEquals(scope.into()), ]), } } /// Creates an `AuthMiddlewareState` requiring DID to equal value AND have ALL specified scopes. /// /// # Example /// ```ignore /// .route("/xrpc/com.atproto.admin.fullAccess", /// post(handler).layer(from_fn_with_state( /// did_with_scopes("did:plc:rnpkyqnmsw4ipey6eotbdnnf", ["transition:generic", "identity:*"], &state), /// auth_middleware /// ))) /// ``` pub fn did_with_scopes( did: impl Into, scopes: I, state: &AppState, ) -> AuthMiddlewareState where I: IntoIterator, T: Into, { AuthMiddlewareState { app_state: state.clone(), rules: AuthRules::All(vec![ AuthRules::DidEquals(did.into()), AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()), ]), } }