this repo has no description
1use anyhow::{Result, anyhow}; 2use base64::Engine as _; 3use base64::engine::general_purpose::URL_SAFE_NO_PAD; 4use chrono::Utc; 5use k256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 6use reqwest::Client; 7use serde::{Deserialize, Serialize}; 8use std::time::Duration; 9use tracing::debug; 10 11#[derive(Debug, Clone, Serialize, Deserialize)] 12#[serde(rename_all = "camelCase")] 13pub struct FullDidDocument { 14 pub id: String, 15 #[serde(default)] 16 pub also_known_as: Vec<String>, 17 #[serde(default)] 18 pub verification_method: Vec<VerificationMethod>, 19 #[serde(default)] 20 pub service: Vec<DidService>, 21} 22 23#[derive(Debug, Clone, Serialize, Deserialize)] 24#[serde(rename_all = "camelCase")] 25pub struct VerificationMethod { 26 pub id: String, 27 #[serde(rename = "type")] 28 pub method_type: String, 29 pub controller: String, 30 #[serde(default)] 31 pub public_key_multibase: Option<String>, 32} 33 34#[derive(Debug, Clone, Serialize, Deserialize)] 35#[serde(rename_all = "camelCase")] 36pub struct DidService { 37 pub id: String, 38 #[serde(rename = "type")] 39 pub service_type: String, 40 pub service_endpoint: String, 41} 42 43#[derive(Debug, Clone, Serialize, Deserialize)] 44pub struct ServiceTokenClaims { 45 pub iss: String, 46 #[serde(default)] 47 pub sub: Option<String>, 48 pub aud: String, 49 pub exp: usize, 50 #[serde(default)] 51 pub iat: Option<usize>, 52 #[serde(skip_serializing_if = "Option::is_none")] 53 pub lxm: Option<String>, 54 #[serde(default)] 55 pub jti: Option<String>, 56} 57 58impl ServiceTokenClaims { 59 pub fn subject(&self) -> &str { 60 self.sub.as_deref().unwrap_or(&self.iss) 61 } 62} 63 64#[derive(Debug, Clone, Serialize, Deserialize)] 65struct TokenHeader { 66 pub alg: String, 67 pub typ: String, 68} 69 70pub struct ServiceTokenVerifier { 71 client: Client, 72 plc_directory_url: String, 73 pds_did: String, 74} 75 76impl ServiceTokenVerifier { 77 pub fn new() -> Self { 78 let plc_directory_url = std::env::var("PLC_DIRECTORY_URL") 79 .unwrap_or_else(|_| "https://plc.directory".to_string()); 80 81 let pds_hostname = 82 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 83 let pds_did = format!("did:web:{}", pds_hostname); 84 85 let client = Client::builder() 86 .timeout(Duration::from_secs(10)) 87 .connect_timeout(Duration::from_secs(5)) 88 .build() 89 .unwrap_or_else(|_| Client::new()); 90 91 Self { 92 client, 93 plc_directory_url, 94 pds_did, 95 } 96 } 97 98 pub async fn verify_service_token( 99 &self, 100 token: &str, 101 required_lxm: Option<&str>, 102 ) -> Result<ServiceTokenClaims> { 103 let parts: Vec<&str> = token.split('.').collect(); 104 if parts.len() != 3 { 105 return Err(anyhow!("Invalid token format")); 106 } 107 108 let header_bytes = URL_SAFE_NO_PAD 109 .decode(parts[0]) 110 .map_err(|e| anyhow!("Base64 decode of header failed: {}", e))?; 111 112 let header: TokenHeader = serde_json::from_slice(&header_bytes) 113 .map_err(|e| anyhow!("JSON decode of header failed: {}", e))?; 114 115 if header.alg != "ES256K" { 116 return Err(anyhow!("Unsupported algorithm: {}", header.alg)); 117 } 118 119 let claims_bytes = URL_SAFE_NO_PAD 120 .decode(parts[1]) 121 .map_err(|e| anyhow!("Base64 decode of claims failed: {}", e))?; 122 123 let claims: ServiceTokenClaims = serde_json::from_slice(&claims_bytes) 124 .map_err(|e| anyhow!("JSON decode of claims failed: {}", e))?; 125 126 let now = Utc::now().timestamp() as usize; 127 if claims.exp < now { 128 return Err(anyhow!("Token expired")); 129 } 130 131 if claims.aud != self.pds_did { 132 return Err(anyhow!( 133 "Invalid audience: expected {}, got {}", 134 self.pds_did, 135 claims.aud 136 )); 137 } 138 139 if let Some(required) = required_lxm { 140 match &claims.lxm { 141 Some(lxm) if lxm == "*" || lxm == required => {} 142 Some(lxm) => { 143 return Err(anyhow!( 144 "Token lxm '{}' does not permit '{}'", 145 lxm, 146 required 147 )); 148 } 149 None => { 150 return Err(anyhow!("Token missing lxm claim")); 151 } 152 } 153 } 154 155 let did = &claims.iss; 156 let public_key = self.resolve_signing_key(did).await?; 157 158 let signature_bytes = URL_SAFE_NO_PAD 159 .decode(parts[2]) 160 .map_err(|e| anyhow!("Base64 decode of signature failed: {}", e))?; 161 162 let signature = Signature::from_slice(&signature_bytes) 163 .map_err(|e| anyhow!("Invalid signature format: {}", e))?; 164 165 let message = format!("{}.{}", parts[0], parts[1]); 166 167 public_key 168 .verify(message.as_bytes(), &signature) 169 .map_err(|e| anyhow!("Signature verification failed: {}", e))?; 170 171 debug!("Service token verified for DID: {}", did); 172 173 Ok(claims) 174 } 175 176 async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey> { 177 let did_doc = self.resolve_did_document(did).await?; 178 179 let atproto_key = did_doc 180 .verification_method 181 .iter() 182 .find(|vm| vm.id.ends_with("#atproto") || vm.id == format!("{}#atproto", did)) 183 .ok_or_else(|| anyhow!("No atproto verification method found in DID document"))?; 184 185 let multibase = atproto_key 186 .public_key_multibase 187 .as_ref() 188 .ok_or_else(|| anyhow!("Verification method missing publicKeyMultibase"))?; 189 190 parse_did_key_multibase(multibase) 191 } 192 193 async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument> { 194 if did.starts_with("did:plc:") { 195 self.resolve_did_plc(did).await 196 } else if did.starts_with("did:web:") { 197 self.resolve_did_web(did).await 198 } else { 199 Err(anyhow!("Unsupported DID method: {}", did)) 200 } 201 } 202 203 async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument> { 204 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did)); 205 debug!("Resolving did:plc {} via {}", did, url); 206 207 let resp = self 208 .client 209 .get(&url) 210 .send() 211 .await 212 .map_err(|e| anyhow!("HTTP request failed: {}", e))?; 213 214 if resp.status() == reqwest::StatusCode::NOT_FOUND { 215 return Err(anyhow!("DID not found: {}", did)); 216 } 217 218 if !resp.status().is_success() { 219 return Err(anyhow!("HTTP {}", resp.status())); 220 } 221 222 resp.json::<FullDidDocument>() 223 .await 224 .map_err(|e| anyhow!("Failed to parse DID document: {}", e)) 225 } 226 227 async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument> { 228 let host = did 229 .strip_prefix("did:web:") 230 .ok_or_else(|| anyhow!("Invalid did:web format"))?; 231 232 let parts: Vec<&str> = host.split(':').collect(); 233 if parts.is_empty() { 234 return Err(anyhow!("Invalid did:web format - no host")); 235 } 236 237 let host_part = parts[0].replace("%3A", ":"); 238 239 let scheme = if host_part.starts_with("localhost") 240 || host_part.starts_with("127.0.0.1") 241 || host_part.contains(':') 242 { 243 "http" 244 } else { 245 "https" 246 }; 247 248 let url = if parts.len() == 1 { 249 format!("{}://{}/.well-known/did.json", scheme, host_part) 250 } else { 251 let path = parts[1..].join("/"); 252 format!("{}://{}/{}/did.json", scheme, host_part, path) 253 }; 254 255 debug!("Resolving did:web {} via {}", did, url); 256 257 let resp = self 258 .client 259 .get(&url) 260 .send() 261 .await 262 .map_err(|e| anyhow!("HTTP request failed: {}", e))?; 263 264 if !resp.status().is_success() { 265 return Err(anyhow!("HTTP {}", resp.status())); 266 } 267 268 resp.json::<FullDidDocument>() 269 .await 270 .map_err(|e| anyhow!("Failed to parse DID document: {}", e)) 271 } 272} 273 274impl Default for ServiceTokenVerifier { 275 fn default() -> Self { 276 Self::new() 277 } 278} 279 280fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> { 281 if !multibase.starts_with('z') { 282 return Err(anyhow!( 283 "Expected base58btc multibase encoding (starts with 'z')" 284 )); 285 } 286 287 let (_, decoded) = 288 multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?; 289 290 if decoded.len() < 2 { 291 return Err(anyhow!("Invalid multicodec data")); 292 } 293 294 let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 { 295 (0xe701u16, &decoded[2..]) 296 } else { 297 return Err(anyhow!( 298 "Unsupported key type. Expected secp256k1 (0xe701), got {:02x}{:02x}", 299 decoded[0], 300 decoded[1] 301 )); 302 }; 303 304 if codec != 0xe701 { 305 return Err(anyhow!("Only secp256k1 keys are supported")); 306 } 307 308 VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e)) 309} 310 311pub fn is_service_token(token: &str) -> bool { 312 let parts: Vec<&str> = token.split('.').collect(); 313 if parts.len() != 3 { 314 return false; 315 } 316 317 let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(parts[1]) else { 318 return false; 319 }; 320 321 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&claims_bytes) else { 322 return false; 323 }; 324 325 claims.get("lxm").is_some() 326} 327 328#[cfg(test)] 329mod tests { 330 use super::*; 331 332 #[test] 333 fn test_is_service_token() { 334 let claims_with_lxm = serde_json::json!({ 335 "iss": "did:plc:test", 336 "sub": "did:plc:test", 337 "aud": "did:web:test.com", 338 "exp": 9999999999i64, 339 "iat": 1000000000i64, 340 "lxm": "com.atproto.repo.uploadBlob", 341 "jti": "test-jti" 342 }); 343 344 let claims_without_lxm = serde_json::json!({ 345 "iss": "did:plc:test", 346 "sub": "did:plc:test", 347 "aud": "did:web:test.com", 348 "exp": 9999999999i64, 349 "iat": 1000000000i64, 350 "jti": "test-jti" 351 }); 352 353 let token_with_lxm = format!( 354 "{}.{}.{}", 355 URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"jwt"}"#), 356 URL_SAFE_NO_PAD.encode(claims_with_lxm.to_string()), 357 URL_SAFE_NO_PAD.encode("fake-sig") 358 ); 359 360 let token_without_lxm = format!( 361 "{}.{}.{}", 362 URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#), 363 URL_SAFE_NO_PAD.encode(claims_without_lxm.to_string()), 364 URL_SAFE_NO_PAD.encode("fake-sig") 365 ); 366 367 assert!(is_service_token(&token_with_lxm)); 368 assert!(!is_service_token(&token_without_lxm)); 369 } 370 371 #[test] 372 fn test_parse_did_key_multibase() { 373 let test_key = "zQ3shcXtVCEBjUvAhzTW3r12DkpFdR2KmA3rHmuEMFx4GMBDB"; 374 let result = parse_did_key_multibase(test_key); 375 assert!(result.is_ok(), "Failed to parse valid multibase key"); 376 } 377}