Noreposts Feed
1use anyhow::{anyhow, Result};
2use atrium_common::resolver::Resolver;
3use atrium_crypto::did::{format_did_key, parse_multikey};
4use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL};
5use atrium_xrpc_client::reqwest::ReqwestClient;
6use base64::Engine;
7use jwt_compact::UntrustedToken;
8use std::sync::Arc;
9use tracing::{debug, warn};
10
11use crate::types::JwtClaims;
12
13// Unused structs kept for reference if needed in future
14// #[derive(Debug, Deserialize)]
15// struct EmptyCustomClaims {}
16//
17// #[derive(Debug, Deserialize)]
18// struct StandardClaims {
19// #[serde(rename = "iss")]
20// issuer: Option<String>,
21// #[serde(rename = "aud")]
22// audience: Option<String>,
23// #[serde(rename = "exp")]
24// expiration: Option<i64>,
25// }
26
27/// Resolves a DID and extracts the atproto signing key as a did:key string
28async fn resolve_signing_key(
29 resolver: &CommonDidResolver<ReqwestClient>,
30 did_str: &str,
31) -> Result<String> {
32 debug!("Resolving DID: {}", did_str);
33
34 // Convert string to Did type
35 let did = did_str.parse().map_err(|e| {
36 warn!("Invalid DID format: {}", e);
37 anyhow!("Invalid DID format: {}", e)
38 })?;
39
40 // Resolve the DID document
41 let did_doc = resolver.resolve(&did).await.map_err(|e| {
42 warn!("Failed to resolve DID {}: {}", did_str, e);
43 anyhow!("Failed to resolve DID: {}", e)
44 })?;
45
46 debug!("DID document resolved: {:?}", did_doc);
47
48 // Use the built-in helper to get the signing key
49 let verification_method = did_doc.get_signing_key().ok_or_else(|| {
50 warn!("No atproto verification method found in DID document");
51 anyhow!("No atproto signing key found in DID document")
52 })?;
53
54 debug!("Found verification method: {:?}", verification_method);
55
56 // Extract publicKeyMultibase
57 let public_key_multibase = verification_method
58 .public_key_multibase
59 .as_ref()
60 .ok_or_else(|| {
61 warn!("Verification method missing publicKeyMultibase");
62 anyhow!("Missing publicKeyMultibase in verification method")
63 })?;
64
65 debug!("Public key multibase: {}", public_key_multibase);
66
67 // Parse the multibase-encoded key
68 let (algorithm, key_bytes) = parse_multikey(public_key_multibase).map_err(|e| {
69 warn!("Failed to parse multikey: {}", e);
70 anyhow!("Invalid publicKeyMultibase format: {}", e)
71 })?;
72
73 debug!(
74 "Parsed key: algorithm={:?}, key_len={}",
75 algorithm,
76 key_bytes.len()
77 );
78
79 // Format as did:key
80 let did_key = format_did_key(algorithm, &key_bytes).map_err(|e| {
81 warn!("Failed to format did:key: {}", e);
82 anyhow!("Failed to convert key to did:key format: {}", e)
83 })?;
84
85 debug!("Formatted did:key: {}", did_key);
86 Ok(did_key)
87}
88
89pub async fn validate_jwt(token: &str, service_did: &str) -> Result<JwtClaims> {
90 // Token should already have "Bearer " prefix stripped by caller
91 debug!("Validating JWT token (length: {})", token.len());
92 debug!("Expected audience: {}", service_did);
93
94 // Parse the untrusted token to extract claims without verification
95 let untrusted = UntrustedToken::new(token).map_err(|e| {
96 warn!("Failed to parse JWT: {}", e);
97 anyhow!("Invalid JWT format: {}", e)
98 })?;
99
100 // First, try to deserialize as raw JSON to see the actual structure
101 let claims_wrapper = untrusted
102 .deserialize_claims_unchecked::<serde_json::Value>()
103 .map_err(|e| {
104 warn!("Failed to deserialize JWT claims: {}", e);
105 anyhow!("Invalid JWT claims: {}", e)
106 })?;
107
108 debug!("Raw JWT claims: {:?}", claims_wrapper);
109
110 // Extract the actual claims from the Value
111 let iss = claims_wrapper
112 .custom
113 .get("iss")
114 .and_then(|v| v.as_str())
115 .ok_or_else(|| anyhow!("Missing 'iss' claim"))?
116 .to_string();
117
118 let aud = claims_wrapper
119 .custom
120 .get("aud")
121 .and_then(|v| v.as_str())
122 .ok_or_else(|| anyhow!("Missing 'aud' claim"))?
123 .to_string();
124
125 let exp = claims_wrapper
126 .custom
127 .get("exp")
128 .and_then(|v| v.as_i64())
129 .or_else(|| claims_wrapper.expiration.map(|ts| ts.timestamp()))
130 .ok_or_else(|| anyhow!("Missing 'exp' claim"))?;
131
132 debug!(
133 "JWT claims extracted - issuer: {}, audience: {}, exp: {}",
134 iss, aud, exp
135 );
136
137 // Validate audience
138 if aud != service_did {
139 warn!(
140 "JWT audience mismatch: expected {}, got {}",
141 service_did, aud
142 );
143 return Err(anyhow!("Invalid JWT audience"));
144 }
145
146 // Validate expiration
147 let now = std::time::SystemTime::now()
148 .duration_since(std::time::UNIX_EPOCH)
149 .unwrap()
150 .as_secs() as i64;
151
152 if exp < now {
153 warn!("JWT expired: exp={}, now={}", exp, now);
154 return Err(anyhow!("JWT has expired"));
155 }
156
157 // Verify signature
158 debug!("Verifying JWT signature for issuer: {}", iss);
159
160 // Create DID resolver
161 // Note: base_uri is not used for DID resolution, so we use a placeholder
162 let http_client = ReqwestClient::new("https://plc.directory");
163 let resolver_config = CommonDidResolverConfig {
164 plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
165 http_client: Arc::new(http_client),
166 };
167 let resolver = CommonDidResolver::new(resolver_config);
168
169 // Resolve the issuer's signing key
170 let did_key = resolve_signing_key(&resolver, &iss).await?;
171
172 // Extract the signed portion of the JWT (header.payload)
173 // JWT format is: header.payload.signature
174 let parts: Vec<&str> = token.split('.').collect();
175 if parts.len() != 3 {
176 warn!("Invalid JWT format: expected 3 parts, got {}", parts.len());
177 return Err(anyhow!("Invalid JWT format"));
178 }
179
180 let signed_data = format!("{}.{}", parts[0], parts[1]);
181 let signature_b64 = parts[2];
182
183 // Decode the base64url signature
184 let signature_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
185 .decode(signature_b64)
186 .map_err(|e| {
187 warn!("Failed to decode JWT signature: {}", e);
188 anyhow!("Invalid JWT signature encoding: {}", e)
189 })?;
190
191 // Verify the signature
192 atrium_crypto::verify::verify_signature(&did_key, signed_data.as_bytes(), &signature_bytes)
193 .map_err(|e| {
194 warn!("JWT signature verification failed: {}", e);
195 anyhow!("Invalid JWT signature: {}", e)
196 })?;
197
198 debug!("JWT signature verified successfully for issuer: {}", iss);
199 Ok(JwtClaims { iss, aud, exp })
200}