This repository has no description.
1use anyhow::{Result, anyhow}; 2use base64::Engine as _; 3use base64::engine::general_purpose::URL_SAFE_NO_PAD; 4use chrono::Utc; 5use k256::ecdsa::{Signature, VerifyingKey, signature::Verifier}; 6use reqwest::Client; 7use serde::{Deserialize, Serialize}; 8use std::time::Duration; 9use tracing::debug; 10 11#[derive(Debug, Clone, Serialize, Deserialize)] 12#[serde(rename_all = "camelCase")] 13pub struct FullDidDocument { 14 pub id: String, 15 #[serde(default)] 16 pub also_known_as: Vec<String>, 17 #[serde(default)] 18 pub verification_method: Vec<VerificationMethod>, 19 #[serde(default)] 20 pub service: Vec<DidService>, 21} 22 23#[derive(Debug, Clone, Serialize, Deserialize)] 24#[serde(rename_all = "camelCase")] 25pub struct VerificationMethod { 26 pub id: String, 27 #[serde(rename = "type")] 28 pub method_type: String, 29 pub controller: String, 30 #[serde(default)] 31 pub public_key_multibase: Option<String>, 32} 33 34#[derive(Debug, Clone, Serialize, Deserialize)] 35#[serde(rename_all = "camelCase")] 36pub struct DidService { 37 pub id: String, 38 #[serde(rename = "type")] 39 pub service_type: String, 40 pub service_endpoint: String, 41} 42 43#[derive(Debug, Clone, Serialize, Deserialize)] 44pub struct ServiceTokenClaims { 45 pub iss: String, 46 #[serde(default)] 47 pub sub: Option<String>, 48 pub aud: String, 49 pub exp: usize, 50 #[serde(default)] 51 pub iat: Option<usize>, 52 #[serde(skip_serializing_if = "Option::is_none")] 53 pub lxm: Option<String>, 54 #[serde(default)] 55 pub jti: Option<String>, 56} 57 58impl ServiceTokenClaims { 59 pub fn subject(&self) -> &str { 60 self.sub.as_deref().unwrap_or(&self.iss) 61 } 62} 63 64#[derive(Debug, Clone, Serialize, Deserialize)] 65struct TokenHeader { 66 pub alg: String, 67 pub typ: String, 68} 69 70pub struct ServiceTokenVerifier { 71 client: Client, 72 plc_directory_url: String, 73 pds_did: String, 74} 75 76impl ServiceTokenVerifier { 77 pub fn new() -> Self { 78 let plc_directory_url = 
std::env::var("PLC_DIRECTORY_URL") 79 .unwrap_or_else(|_| "https://plc.directory".to_string()); 80 81 let pds_hostname = 82 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 83 let pds_did = format!("did:web:{}", pds_hostname); 84 85 let client = Client::builder() 86 .timeout(Duration::from_secs(10)) 87 .connect_timeout(Duration::from_secs(5)) 88 .pool_max_idle_per_host(10) 89 .pool_idle_timeout(Duration::from_secs(90)) 90 .build() 91 .unwrap_or_else(|_| Client::new()); 92 93 Self { 94 client, 95 plc_directory_url, 96 pds_did, 97 } 98 } 99 100 pub async fn verify_service_token( 101 &self, 102 token: &str, 103 required_lxm: Option<&str>, 104 ) -> Result<ServiceTokenClaims> { 105 let parts: Vec<&str> = token.split('.').collect(); 106 if parts.len() != 3 { 107 return Err(anyhow!("Invalid token format")); 108 } 109 110 let header_bytes = URL_SAFE_NO_PAD 111 .decode(parts[0]) 112 .map_err(|e| anyhow!("Base64 decode of header failed: {}", e))?; 113 114 let header: TokenHeader = serde_json::from_slice(&header_bytes) 115 .map_err(|e| anyhow!("JSON decode of header failed: {}", e))?; 116 117 if header.alg != "ES256K" { 118 return Err(anyhow!("Unsupported algorithm: {}", header.alg)); 119 } 120 121 let claims_bytes = URL_SAFE_NO_PAD 122 .decode(parts[1]) 123 .map_err(|e| anyhow!("Base64 decode of claims failed: {}", e))?; 124 125 let claims: ServiceTokenClaims = serde_json::from_slice(&claims_bytes) 126 .map_err(|e| anyhow!("JSON decode of claims failed: {}", e))?; 127 128 let now = Utc::now().timestamp() as usize; 129 if claims.exp < now { 130 return Err(anyhow!("Token expired")); 131 } 132 133 if claims.aud != self.pds_did { 134 return Err(anyhow!( 135 "Invalid audience: expected {}, got {}", 136 self.pds_did, 137 claims.aud 138 )); 139 } 140 141 if let Some(required) = required_lxm { 142 match &claims.lxm { 143 Some(lxm) if lxm == "*" || lxm == required => {} 144 Some(lxm) => { 145 return Err(anyhow!( 146 "Token lxm '{}' does not permit '{}'", 147 
lxm, 148 required 149 )); 150 } 151 None => { 152 return Err(anyhow!("Token missing lxm claim")); 153 } 154 } 155 } 156 157 let did = &claims.iss; 158 let public_key = self.resolve_signing_key(did).await?; 159 160 let signature_bytes = URL_SAFE_NO_PAD 161 .decode(parts[2]) 162 .map_err(|e| anyhow!("Base64 decode of signature failed: {}", e))?; 163 164 let signature = Signature::from_slice(&signature_bytes) 165 .map_err(|e| anyhow!("Invalid signature format: {}", e))?; 166 167 let message = format!("{}.{}", parts[0], parts[1]); 168 169 public_key 170 .verify(message.as_bytes(), &signature) 171 .map_err(|e| anyhow!("Signature verification failed: {}", e))?; 172 173 debug!("Service token verified for DID: {}", did); 174 175 Ok(claims) 176 } 177 178 async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey> { 179 let did_doc = self.resolve_did_document(did).await?; 180 181 let atproto_key = did_doc 182 .verification_method 183 .iter() 184 .find(|vm| vm.id.ends_with("#atproto") || vm.id == format!("{}#atproto", did)) 185 .ok_or_else(|| anyhow!("No atproto verification method found in DID document"))?; 186 187 let multibase = atproto_key 188 .public_key_multibase 189 .as_ref() 190 .ok_or_else(|| anyhow!("Verification method missing publicKeyMultibase"))?; 191 192 parse_did_key_multibase(multibase) 193 } 194 195 async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument> { 196 if did.starts_with("did:plc:") { 197 self.resolve_did_plc(did).await 198 } else if did.starts_with("did:web:") { 199 self.resolve_did_web(did).await 200 } else { 201 Err(anyhow!("Unsupported DID method: {}", did)) 202 } 203 } 204 205 async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument> { 206 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did)); 207 debug!("Resolving did:plc {} via {}", did, url); 208 209 let resp = self 210 .client 211 .get(&url) 212 .send() 213 .await 214 .map_err(|e| anyhow!("HTTP request failed: {}", e))?; 215 
216 if resp.status() == reqwest::StatusCode::NOT_FOUND { 217 return Err(anyhow!("DID not found: {}", did)); 218 } 219 220 if !resp.status().is_success() { 221 return Err(anyhow!("HTTP {}", resp.status())); 222 } 223 224 resp.json::<FullDidDocument>() 225 .await 226 .map_err(|e| anyhow!("Failed to parse DID document: {}", e)) 227 } 228 229 async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument> { 230 let host = did 231 .strip_prefix("did:web:") 232 .ok_or_else(|| anyhow!("Invalid did:web format"))?; 233 234 let parts: Vec<&str> = host.split(':').collect(); 235 if parts.is_empty() { 236 return Err(anyhow!("Invalid did:web format - no host")); 237 } 238 239 let host_part = parts[0].replace("%3A", ":"); 240 241 let scheme = if host_part.starts_with("localhost") 242 || host_part.starts_with("127.0.0.1") 243 || host_part.contains(':') 244 { 245 "http" 246 } else { 247 "https" 248 }; 249 250 let url = if parts.len() == 1 { 251 format!("{}://{}/.well-known/did.json", scheme, host_part) 252 } else { 253 let path = parts[1..].join("/"); 254 format!("{}://{}/{}/did.json", scheme, host_part, path) 255 }; 256 257 debug!("Resolving did:web {} via {}", did, url); 258 259 let resp = self 260 .client 261 .get(&url) 262 .send() 263 .await 264 .map_err(|e| anyhow!("HTTP request failed: {}", e))?; 265 266 if !resp.status().is_success() { 267 return Err(anyhow!("HTTP {}", resp.status())); 268 } 269 270 resp.json::<FullDidDocument>() 271 .await 272 .map_err(|e| anyhow!("Failed to parse DID document: {}", e)) 273 } 274} 275 276impl Default for ServiceTokenVerifier { 277 fn default() -> Self { 278 Self::new() 279 } 280} 281 282fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> { 283 if !multibase.starts_with('z') { 284 return Err(anyhow!( 285 "Expected base58btc multibase encoding (starts with 'z')" 286 )); 287 } 288 289 let (_, decoded) = 290 multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?; 291 292 if decoded.len() < 
2 { 293 return Err(anyhow!("Invalid multicodec data")); 294 } 295 296 let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 { 297 (0xe701u16, &decoded[2..]) 298 } else { 299 return Err(anyhow!( 300 "Unsupported key type. Expected secp256k1 (0xe701), got {:02x}{:02x}", 301 decoded[0], 302 decoded[1] 303 )); 304 }; 305 306 if codec != 0xe701 { 307 return Err(anyhow!("Only secp256k1 keys are supported")); 308 } 309 310 VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e)) 311} 312 313pub fn is_service_token(token: &str) -> bool { 314 let parts: Vec<&str> = token.split('.').collect(); 315 if parts.len() != 3 { 316 return false; 317 } 318 319 let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(parts[1]) else { 320 return false; 321 }; 322 323 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&claims_bytes) else { 324 return false; 325 }; 326 327 claims.get("lxm").is_some() 328} 329 330#[cfg(test)] 331mod tests { 332 use super::*; 333 334 #[test] 335 fn test_is_service_token() { 336 let claims_with_lxm = serde_json::json!({ 337 "iss": "did:plc:test", 338 "sub": "did:plc:test", 339 "aud": "did:web:test.com", 340 "exp": 9999999999i64, 341 "iat": 1000000000i64, 342 "lxm": "com.atproto.repo.uploadBlob", 343 "jti": "test-jti" 344 }); 345 346 let claims_without_lxm = serde_json::json!({ 347 "iss": "did:plc:test", 348 "sub": "did:plc:test", 349 "aud": "did:web:test.com", 350 "exp": 9999999999i64, 351 "iat": 1000000000i64, 352 "jti": "test-jti" 353 }); 354 355 let token_with_lxm = format!( 356 "{}.{}.{}", 357 URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"jwt"}"#), 358 URL_SAFE_NO_PAD.encode(claims_with_lxm.to_string()), 359 URL_SAFE_NO_PAD.encode("fake-sig") 360 ); 361 362 let token_without_lxm = format!( 363 "{}.{}.{}", 364 URL_SAFE_NO_PAD.encode(r#"{"alg":"ES256K","typ":"at+jwt"}"#), 365 URL_SAFE_NO_PAD.encode(claims_without_lxm.to_string()), 366 URL_SAFE_NO_PAD.encode("fake-sig") 367 ); 368 369 
assert!(is_service_token(&token_with_lxm)); 370 assert!(!is_service_token(&token_without_lxm)); 371 } 372 373 #[test] 374 fn test_parse_did_key_multibase() { 375 let test_key = "zQ3shcXtVCEBjUvAhzTW3r12DkpFdR2KmA3rHmuEMFx4GMBDB"; 376 let result = parse_did_key_multibase(test_key); 377 assert!(result.is_ok(), "Failed to parse valid multibase key"); 378 } 379}