Noreposts Feed
at main 200 lines 6.7 kB view raw
1use anyhow::{anyhow, Result}; 2use atrium_common::resolver::Resolver; 3use atrium_crypto::did::{format_did_key, parse_multikey}; 4use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; 5use atrium_xrpc_client::reqwest::ReqwestClient; 6use base64::Engine; 7use jwt_compact::UntrustedToken; 8use std::sync::Arc; 9use tracing::{debug, warn}; 10 11use crate::types::JwtClaims; 12 13// Unused structs kept for reference if needed in future 14// #[derive(Debug, Deserialize)] 15// struct EmptyCustomClaims {} 16// 17// #[derive(Debug, Deserialize)] 18// struct StandardClaims { 19// #[serde(rename = "iss")] 20// issuer: Option<String>, 21// #[serde(rename = "aud")] 22// audience: Option<String>, 23// #[serde(rename = "exp")] 24// expiration: Option<i64>, 25// } 26 27/// Resolves a DID and extracts the atproto signing key as a did:key string 28async fn resolve_signing_key( 29 resolver: &CommonDidResolver<ReqwestClient>, 30 did_str: &str, 31) -> Result<String> { 32 debug!("Resolving DID: {}", did_str); 33 34 // Convert string to Did type 35 let did = did_str.parse().map_err(|e| { 36 warn!("Invalid DID format: {}", e); 37 anyhow!("Invalid DID format: {}", e) 38 })?; 39 40 // Resolve the DID document 41 let did_doc = resolver.resolve(&did).await.map_err(|e| { 42 warn!("Failed to resolve DID {}: {}", did_str, e); 43 anyhow!("Failed to resolve DID: {}", e) 44 })?; 45 46 debug!("DID document resolved: {:?}", did_doc); 47 48 // Use the built-in helper to get the signing key 49 let verification_method = did_doc.get_signing_key().ok_or_else(|| { 50 warn!("No atproto verification method found in DID document"); 51 anyhow!("No atproto signing key found in DID document") 52 })?; 53 54 debug!("Found verification method: {:?}", verification_method); 55 56 // Extract publicKeyMultibase 57 let public_key_multibase = verification_method 58 .public_key_multibase 59 .as_ref() 60 .ok_or_else(|| { 61 warn!("Verification method missing publicKeyMultibase"); 62 anyhow!("Missing publicKeyMultibase in verification method") 63 })?; 64 65 debug!("Public key multibase: {}", public_key_multibase); 66 67 // Parse the multibase-encoded key 68 let (algorithm, key_bytes) = parse_multikey(public_key_multibase).map_err(|e| { 69 warn!("Failed to parse multikey: {}", e); 70 anyhow!("Invalid publicKeyMultibase format: {}", e) 71 })?; 72 73 debug!( 74 "Parsed key: algorithm={:?}, key_len={}", 75 algorithm, 76 key_bytes.len() 77 ); 78 79 // Format as did:key 80 let did_key = format_did_key(algorithm, &key_bytes).map_err(|e| { 81 warn!("Failed to format did:key: {}", e); 82 anyhow!("Failed to convert key to did:key format: {}", e) 83 })?; 84 85 debug!("Formatted did:key: {}", did_key); 86 Ok(did_key) 87} 88 89pub async fn validate_jwt(token: &str, service_did: &str) -> Result<JwtClaims> { 90 // Token should already have "Bearer " prefix stripped by caller 91 debug!("Validating JWT token (length: {})", token.len()); 92 debug!("Expected audience: {}", service_did); 93 94 // Parse the untrusted token to extract claims without verification 95 let untrusted = UntrustedToken::new(token).map_err(|e| { 96 warn!("Failed to parse JWT: {}", e); 97 anyhow!("Invalid JWT format: {}", e) 98 })?; 99 100 // First, try to deserialize as raw JSON to see the actual structure 101 let claims_wrapper = untrusted 102 .deserialize_claims_unchecked::<serde_json::Value>() 103 .map_err(|e| { 104 warn!("Failed to deserialize JWT claims: {}", e); 105 anyhow!("Invalid JWT claims: {}", e) 106 })?; 107 108 debug!("Raw JWT claims: {:?}", claims_wrapper); 109 110 // Extract the actual claims from the Value 111 let iss = claims_wrapper 112 .custom 113 .get("iss") 114 .and_then(|v| v.as_str()) 115 .ok_or_else(|| anyhow!("Missing 'iss' claim"))? 116 .to_string(); 117 118 let aud = claims_wrapper 119 .custom 120 .get("aud") 121 .and_then(|v| v.as_str()) 122 .ok_or_else(|| anyhow!("Missing 'aud' claim"))? 123 .to_string(); 124 125 let exp = claims_wrapper 126 .custom 127 .get("exp") 128 .and_then(|v| v.as_i64()) 129 .or_else(|| claims_wrapper.expiration.map(|ts| ts.timestamp())) 130 .ok_or_else(|| anyhow!("Missing 'exp' claim"))?; 131 132 debug!( 133 "JWT claims extracted - issuer: {}, audience: {}, exp: {}", 134 iss, aud, exp 135 ); 136 137 // Validate audience 138 if aud != service_did { 139 warn!( 140 "JWT audience mismatch: expected {}, got {}", 141 service_did, aud 142 ); 143 return Err(anyhow!("Invalid JWT audience")); 144 } 145 146 // Validate expiration 147 let now = std::time::SystemTime::now() 148 .duration_since(std::time::UNIX_EPOCH) 149 .unwrap() 150 .as_secs() as i64; 151 152 if exp < now { 153 warn!("JWT expired: exp={}, now={}", exp, now); 154 return Err(anyhow!("JWT has expired")); 155 } 156 157 // Verify signature 158 debug!("Verifying JWT signature for issuer: {}", iss); 159 160 // Create DID resolver 161 // Note: base_uri is not used for DID resolution, so we use a placeholder 162 let http_client = ReqwestClient::new("https://plc.directory"); 163 let resolver_config = CommonDidResolverConfig { 164 plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 165 http_client: Arc::new(http_client), 166 }; 167 let resolver = CommonDidResolver::new(resolver_config); 168 169 // Resolve the issuer's signing key 170 let did_key = resolve_signing_key(&resolver, &iss).await?; 171 172 // Extract the signed portion of the JWT (header.payload) 173 // JWT format is: header.payload.signature 174 let parts: Vec<&str> = token.split('.').collect(); 175 if parts.len() != 3 { 176 warn!("Invalid JWT format: expected 3 parts, got {}", parts.len()); 177 return Err(anyhow!("Invalid JWT format")); 178 } 179 180 let signed_data = format!("{}.{}", parts[0], parts[1]); 181 let signature_b64 = parts[2]; 182 183 // Decode the base64url signature 184 let signature_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 185 .decode(signature_b64) 186 .map_err(|e| { 187 warn!("Failed to decode JWT signature: {}", e); 188 anyhow!("Invalid JWT signature encoding: {}", e) 189 })?; 190 191 // Verify the signature 192 atrium_crypto::verify::verify_signature(&did_key, signed_data.as_bytes(), &signature_bytes) 193 .map_err(|e| { 194 warn!("JWT signature verification failed: {}", e); 195 anyhow!("Invalid JWT signature: {}", e) 196 })?; 197 198 debug!("JWT signature verified successfully for issuer: {}", iss); 199 Ok(JwtClaims { iss, aud, exp }) 200}