This repository has no description.
1use anyhow::{Result, anyhow}; 2use base64::Engine as _; 3use base64::engine::general_purpose::URL_SAFE_NO_PAD; 4use chrono::Utc; 5use k256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 6use reqwest::Client; 7use serde::{Deserialize, Serialize}; 8use std::time::Duration; 9use tracing::debug; 10 11#[derive(Debug, Clone, Serialize, Deserialize)] 12#[serde(rename_all = "camelCase")] 13pub struct FullDidDocument { 14 pub id: String, 15 #[serde(default)] 16 pub also_known_as: Vec<String>, 17 #[serde(default)] 18 pub verification_method: Vec<VerificationMethod>, 19 #[serde(default)] 20 pub service: Vec<DidService>, 21} 22 23#[derive(Debug, Clone, Serialize, Deserialize)] 24#[serde(rename_all = "camelCase")] 25pub struct VerificationMethod { 26 pub id: String, 27 #[serde(rename = "type")] 28 pub method_type: String, 29 pub controller: String, 30 #[serde(default)] 31 pub public_key_multibase: Option<String>, 32} 33 34#[derive(Debug, Clone, Serialize, Deserialize)] 35#[serde(rename_all = "camelCase")] 36pub struct DidService { 37 pub id: String, 38 #[serde(rename = "type")] 39 pub service_type: String, 40 pub service_endpoint: String, 41} 42 43#[derive(Debug, Clone, Serialize, Deserialize)] 44pub struct ServiceTokenClaims { 45 pub iss: String, 46 #[serde(default)] 47 pub sub: Option<String>, 48 pub aud: String, 49 pub exp: usize, 50 #[serde(default)] 51 pub iat: Option<usize>, 52 #[serde(skip_serializing_if = "Option::is_none")] 53 pub lxm: Option<String>, 54 #[serde(default)] 55 pub jti: Option<String>, 56} 57 58impl ServiceTokenClaims { 59 pub fn subject(&self) -> &str { 60 self.sub.as_deref().unwrap_or(&self.iss) 61 } 62} 63 64#[derive(Debug, Clone, Serialize, Deserialize)] 65struct TokenHeader { 66 pub alg: String, 67 pub typ: String, 68} 69 70pub struct ServiceTokenVerifier { 71 client: Client, 72 plc_directory_url: String, 73 pds_did: String, 74} 75 76impl ServiceTokenVerifier { 77 pub fn new() -> Self { 78 let plc_directory_url = 
std::env::var("PLC_DIRECTORY_URL") 79 .unwrap_or_else(|_| "https://plc.directory".to_string()); 80 81 let pds_hostname = 82 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 83 let pds_did = format!("did:web:{}", pds_hostname); 84 85 let client = Client::builder() 86 .timeout(Duration::from_secs(10)) 87 .connect_timeout(Duration::from_secs(5)) 88 .build() 89 .unwrap_or_else(|_| Client::new()); 90 91 Self { 92 client, 93 plc_directory_url, 94 pds_did, 95 } 96 } 97 98 pub async fn verify_service_token( 99 &self, 100 token: &str, 101 required_lxm: Option<&str>, 102 ) -> Result<ServiceTokenClaims> { 103 let parts: Vec<&str> = token.split('.').collect(); 104 if parts.len() != 3 { 105 return Err(anyhow!("Invalid token format")); 106 } 107 108 let header_bytes = URL_SAFE_NO_PAD 109 .decode(parts[0]) 110 .map_err(|e| anyhow!("Base64 decode of header failed: {}", e))?; 111 112 let header: TokenHeader = serde_json::from_slice(&header_bytes) 113 .map_err(|e| anyhow!("JSON decode of header failed: {}", e))?; 114 115 if header.alg != "ES256K" { 116 return Err(anyhow!("Unsupported algorithm: {}", header.alg)); 117 } 118 119 let claims_bytes = URL_SAFE_NO_PAD 120 .decode(parts[1]) 121 .map_err(|e| anyhow!("Base64 decode of claims failed: {}", e))?; 122 123 let claims: ServiceTokenClaims = serde_json::from_slice(&claims_bytes) 124 .map_err(|e| anyhow!("JSON decode of claims failed: {}", e))?; 125 126 let now = Utc::now().timestamp() as usize; 127 if claims.exp < now { 128 return Err(anyhow!("Token expired")); 129 } 130 131 if claims.aud != self.pds_did { 132 return Err(anyhow!( 133 "Invalid audience: expected {}, got {}", 134 self.pds_did, 135 claims.aud 136 )); 137 } 138 139 if let Some(required) = required_lxm { 140 match &claims.lxm { 141 Some(lxm) if lxm == "*" || lxm == required => {} 142 Some(lxm) => { 143 return Err(anyhow!( 144 "Token lxm '{}' does not permit '{}'", 145 lxm, 146 required 147 )); 148 } 149 None => { 150 return Err(anyhow!("Token 
missing lxm claim")); 151 } 152 } 153 } 154 155 let did = &claims.iss; 156 let public_key = self.resolve_signing_key(did).await?; 157 158 let signature_bytes = URL_SAFE_NO_PAD 159 .decode(parts[2]) 160 .map_err(|e| anyhow!("Base64 decode of signature failed: {}", e))?; 161 162 let signature = Signature::from_slice(&signature_bytes) 163 .map_err(|e| anyhow!("Invalid signature format: {}", e))?; 164 165 let message = format!("{}.{}", parts[0], parts[1]); 166 167 public_key 168 .verify(message.as_bytes(), &signature) 169 .map_err(|e| anyhow!("Signature verification failed: {}", e))?; 170 171 debug!("Service token verified for DID: {}", did); 172 173 Ok(claims) 174 } 175 176 async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey> { 177 let did_doc = self.resolve_did_document(did).await?; 178 179 let atproto_key = did_doc 180 .verification_method 181 .iter() 182 .find(|vm| vm.id.ends_with("#atproto") || vm.id == format!("{}#atproto", did)) 183 .ok_or_else(|| anyhow!("No atproto verification method found in DID document"))?; 184 185 let multibase = atproto_key 186 .public_key_multibase 187 .as_ref() 188 .ok_or_else(|| anyhow!("Verification method missing publicKeyMultibase"))?; 189 190 parse_did_key_multibase(multibase) 191 } 192 193 async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument> { 194 if did.starts_with("did:plc:") { 195 self.resolve_did_plc(did).await 196 } else if did.starts_with("did:web:") { 197 self.resolve_did_web(did).await 198 } else { 199 Err(anyhow!("Unsupported DID method: {}", did)) 200 } 201 } 202 203 async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument> { 204 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did)); 205 debug!("Resolving did:plc {} via {}", did, url); 206 207 let resp = self 208 .client 209 .get(&url) 210 .send() 211 .await 212 .map_err(|e| anyhow!("HTTP request failed: {}", e))?; 213 214 if resp.status() == reqwest::StatusCode::NOT_FOUND { 215 return 
Err(anyhow!("DID not found: {}", did)); 216 } 217 218 if !resp.status().is_success() { 219 return Err(anyhow!("HTTP {}", resp.status())); 220 } 221 222 resp.json::<FullDidDocument>() 223 .await 224 .map_err(|e| anyhow!("Failed to parse DID document: {}", e)) 225 } 226 227 async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument> { 228 let host = did 229 .strip_prefix("did:web:") 230 .ok_or_else(|| anyhow!("Invalid did:web format"))?; 231 232 let decoded_host = host.replace("%3A", ":"); 233 let (host_part, path_part) = if let Some(idx) = decoded_host.find('/') { 234 (&decoded_host[..idx], &decoded_host[idx..]) 235 } else { 236 (decoded_host.as_str(), "") 237 }; 238 239 let scheme = if host_part.starts_with("localhost") 240 || host_part.starts_with("127.0.0.1") 241 || host_part.contains(':') 242 { 243 "http" 244 } else { 245 "https" 246 }; 247 248 let url = if path_part.is_empty() { 249 format!("{}://{}/.well-known/did.json", scheme, host_part) 250 } else { 251 format!("{}://{}{}/did.json", scheme, host_part, path_part) 252 }; 253 254 debug!("Resolving did:web {} via {}", did, url); 255 256 let resp = self 257 .client 258 .get(&url) 259 .send() 260 .await 261 .map_err(|e| anyhow!("HTTP request failed: {}", e))?; 262 263 if !resp.status().is_success() { 264 return Err(anyhow!("HTTP {}", resp.status())); 265 } 266 267 resp.json::<FullDidDocument>() 268 .await 269 .map_err(|e| anyhow!("Failed to parse DID document: {}", e)) 270 } 271} 272 273impl Default for ServiceTokenVerifier { 274 fn default() -> Self { 275 Self::new() 276 } 277} 278 279fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> { 280 if !multibase.starts_with('z') { 281 return Err(anyhow!("Expected base58btc multibase encoding (starts with 'z')")); 282 } 283 284 let (_, decoded) = multibase::decode(multibase) 285 .map_err(|e| anyhow!("Failed to decode multibase: {}", e))?; 286 287 if decoded.len() < 2 { 288 return Err(anyhow!("Invalid multicodec data")); 289 } 290 291 let 
(codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 { 292 (0xe701u16, &decoded[2..]) 293 } else { 294 return Err(anyhow!( 295 "Unsupported key type. Expected secp256k1 (0xe701), got {:02x}{:02x}", 296 decoded[0], 297 decoded[1] 298 )); 299 }; 300 301 if codec != 0xe701 { 302 return Err(anyhow!("Only secp256k1 keys are supported")); 303 } 304 305 VerifyingKey::from_sec1_bytes(key_bytes) 306 .map_err(|e| anyhow!("Invalid public key: {}", e)) 307} 308 309pub fn is_service_token(token: &str) -> bool { 310 let parts: Vec<&str> = token.split('.').collect(); 311 if parts.len() != 3 { 312 return false; 313 } 314 315 let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(parts[1]) else { 316 return false; 317 }; 318 319 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&claims_bytes) else { 320 return false; 321 }; 322 323 claims.get("lxm").is_some() 324} 325 326#[cfg(test)] 327mod tests { 328 use super::*; 329 330 #[test] 331 fn test_is_service_token() { 332 let claims_with_lxm = serde_json::json!({ 333 "iss": "did:plc:test", 334 "sub": "did:plc:test", 335 "aud": "did:web:test.com", 336 "exp": 9999999999i64, 337 "iat": 1000000000i64, 338 "lxm": "com.atproto.repo.uploadBlob", 339 "jti": "test-jti" 340 }); 341 342 let claims_without_lxm = serde_json::json!({ 343 "iss": "did:plc:test", 344 "sub": "did:plc:test", 345 "aud": "did:web:test.com", 346 "exp": 9999999999i64, 347 "iat": 1000000000i64, 348 "jti": "test-jti" 349 }); 350 351 let token_with_lxm = format!( 352 "{}.{}.{}", 353 URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"jwt"}"#), 354 URL_SAFE_NO_PAD.encode(claims_with_lxm.to_string()), 355 URL_SAFE_NO_PAD.encode("fake-sig") 356 ); 357 358 let token_without_lxm = format!( 359 "{}.{}.{}", 360 URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#), 361 URL_SAFE_NO_PAD.encode(claims_without_lxm.to_string()), 362 URL_SAFE_NO_PAD.encode("fake-sig") 363 ); 364 365 assert!(is_service_token(&token_with_lxm)); 366 
assert!(!is_service_token(&token_without_lxm)); 367 } 368 369 #[test] 370 fn test_parse_did_key_multibase() { 371 let test_key = "zQ3shcXtVCEBjUvAhzTW3r12DkpFdR2KmA3rHmuEMFx4GMBDB"; 372 let result = parse_did_key_multibase(test_key); 373 assert!(result.is_ok(), "Failed to parse valid multibase key"); 374 } 375}