this repo has no description
1use anyhow::{Result, anyhow};
2use base64::Engine as _;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::Utc;
5use k256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use tracing::debug;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(rename_all = "camelCase")]
13pub struct FullDidDocument {
14 pub id: String,
15 #[serde(default)]
16 pub also_known_as: Vec<String>,
17 #[serde(default)]
18 pub verification_method: Vec<VerificationMethod>,
19 #[serde(default)]
20 pub service: Vec<DidService>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(rename_all = "camelCase")]
25pub struct VerificationMethod {
26 pub id: String,
27 #[serde(rename = "type")]
28 pub method_type: String,
29 pub controller: String,
30 #[serde(default)]
31 pub public_key_multibase: Option<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(rename_all = "camelCase")]
36pub struct DidService {
37 pub id: String,
38 #[serde(rename = "type")]
39 pub service_type: String,
40 pub service_endpoint: String,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ServiceTokenClaims {
45 pub iss: String,
46 #[serde(default)]
47 pub sub: Option<String>,
48 pub aud: String,
49 pub exp: usize,
50 #[serde(default)]
51 pub iat: Option<usize>,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub lxm: Option<String>,
54 #[serde(default)]
55 pub jti: Option<String>,
56}
57
58impl ServiceTokenClaims {
59 pub fn subject(&self) -> &str {
60 self.sub.as_deref().unwrap_or(&self.iss)
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65struct TokenHeader {
66 pub alg: String,
67 pub typ: String,
68}
69
70pub struct ServiceTokenVerifier {
71 client: Client,
72 plc_directory_url: String,
73 pds_did: String,
74}
75
76impl ServiceTokenVerifier {
77 pub fn new() -> Self {
78 let plc_directory_url = std::env::var("PLC_DIRECTORY_URL")
79 .unwrap_or_else(|_| "https://plc.directory".to_string());
80
81 let pds_hostname =
82 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
83 let pds_did = format!("did:web:{}", pds_hostname);
84
85 let client = Client::builder()
86 .timeout(Duration::from_secs(10))
87 .connect_timeout(Duration::from_secs(5))
88 .pool_max_idle_per_host(10)
89 .pool_idle_timeout(Duration::from_secs(90))
90 .build()
91 .unwrap_or_else(|_| Client::new());
92
93 Self {
94 client,
95 plc_directory_url,
96 pds_did,
97 }
98 }
99
100 pub async fn verify_service_token(
101 &self,
102 token: &str,
103 required_lxm: Option<&str>,
104 ) -> Result<ServiceTokenClaims> {
105 let parts: Vec<&str> = token.split('.').collect();
106 if parts.len() != 3 {
107 return Err(anyhow!("Invalid token format"));
108 }
109
110 let header_bytes = URL_SAFE_NO_PAD
111 .decode(parts[0])
112 .map_err(|e| anyhow!("Base64 decode of header failed: {}", e))?;
113
114 let header: TokenHeader = serde_json::from_slice(&header_bytes)
115 .map_err(|e| anyhow!("JSON decode of header failed: {}", e))?;
116
117 if header.alg != "ES256K" {
118 return Err(anyhow!("Unsupported algorithm: {}", header.alg));
119 }
120
121 let claims_bytes = URL_SAFE_NO_PAD
122 .decode(parts[1])
123 .map_err(|e| anyhow!("Base64 decode of claims failed: {}", e))?;
124
125 let claims: ServiceTokenClaims = serde_json::from_slice(&claims_bytes)
126 .map_err(|e| anyhow!("JSON decode of claims failed: {}", e))?;
127
128 let now = Utc::now().timestamp() as usize;
129 if claims.exp < now {
130 return Err(anyhow!("Token expired"));
131 }
132
133 if claims.aud != self.pds_did {
134 return Err(anyhow!(
135 "Invalid audience: expected {}, got {}",
136 self.pds_did,
137 claims.aud
138 ));
139 }
140
141 if let Some(required) = required_lxm {
142 match &claims.lxm {
143 Some(lxm) if lxm == "*" || lxm == required => {}
144 Some(lxm) => {
145 return Err(anyhow!(
146 "Token lxm '{}' does not permit '{}'",
147 lxm,
148 required
149 ));
150 }
151 None => {
152 return Err(anyhow!("Token missing lxm claim"));
153 }
154 }
155 }
156
157 let did = &claims.iss;
158 let public_key = self.resolve_signing_key(did).await?;
159
160 let signature_bytes = URL_SAFE_NO_PAD
161 .decode(parts[2])
162 .map_err(|e| anyhow!("Base64 decode of signature failed: {}", e))?;
163
164 let signature = Signature::from_slice(&signature_bytes)
165 .map_err(|e| anyhow!("Invalid signature format: {}", e))?;
166
167 let message = format!("{}.{}", parts[0], parts[1]);
168
169 public_key
170 .verify(message.as_bytes(), &signature)
171 .map_err(|e| anyhow!("Signature verification failed: {}", e))?;
172
173 debug!("Service token verified for DID: {}", did);
174
175 Ok(claims)
176 }
177
178 async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey> {
179 let did_doc = self.resolve_did_document(did).await?;
180
181 let atproto_key = did_doc
182 .verification_method
183 .iter()
184 .find(|vm| vm.id.ends_with("#atproto") || vm.id == format!("{}#atproto", did))
185 .ok_or_else(|| anyhow!("No atproto verification method found in DID document"))?;
186
187 let multibase = atproto_key
188 .public_key_multibase
189 .as_ref()
190 .ok_or_else(|| anyhow!("Verification method missing publicKeyMultibase"))?;
191
192 parse_did_key_multibase(multibase)
193 }
194
195 async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument> {
196 if did.starts_with("did:plc:") {
197 self.resolve_did_plc(did).await
198 } else if did.starts_with("did:web:") {
199 self.resolve_did_web(did).await
200 } else {
201 Err(anyhow!("Unsupported DID method: {}", did))
202 }
203 }
204
205 async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument> {
206 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did));
207 debug!("Resolving did:plc {} via {}", did, url);
208
209 let resp = self
210 .client
211 .get(&url)
212 .send()
213 .await
214 .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
215
216 if resp.status() == reqwest::StatusCode::NOT_FOUND {
217 return Err(anyhow!("DID not found: {}", did));
218 }
219
220 if !resp.status().is_success() {
221 return Err(anyhow!("HTTP {}", resp.status()));
222 }
223
224 resp.json::<FullDidDocument>()
225 .await
226 .map_err(|e| anyhow!("Failed to parse DID document: {}", e))
227 }
228
229 async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument> {
230 let host = did
231 .strip_prefix("did:web:")
232 .ok_or_else(|| anyhow!("Invalid did:web format"))?;
233
234 let parts: Vec<&str> = host.split(':').collect();
235 if parts.is_empty() {
236 return Err(anyhow!("Invalid did:web format - no host"));
237 }
238
239 let host_part = parts[0].replace("%3A", ":");
240
241 let scheme = if host_part.starts_with("localhost")
242 || host_part.starts_with("127.0.0.1")
243 || host_part.contains(':')
244 {
245 "http"
246 } else {
247 "https"
248 };
249
250 let url = if parts.len() == 1 {
251 format!("{}://{}/.well-known/did.json", scheme, host_part)
252 } else {
253 let path = parts[1..].join("/");
254 format!("{}://{}/{}/did.json", scheme, host_part, path)
255 };
256
257 debug!("Resolving did:web {} via {}", did, url);
258
259 let resp = self
260 .client
261 .get(&url)
262 .send()
263 .await
264 .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
265
266 if !resp.status().is_success() {
267 return Err(anyhow!("HTTP {}", resp.status()));
268 }
269
270 resp.json::<FullDidDocument>()
271 .await
272 .map_err(|e| anyhow!("Failed to parse DID document: {}", e))
273 }
274}
275
276impl Default for ServiceTokenVerifier {
277 fn default() -> Self {
278 Self::new()
279 }
280}
281
282fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> {
283 if !multibase.starts_with('z') {
284 return Err(anyhow!(
285 "Expected base58btc multibase encoding (starts with 'z')"
286 ));
287 }
288
289 let (_, decoded) =
290 multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?;
291
292 if decoded.len() < 2 {
293 return Err(anyhow!("Invalid multicodec data"));
294 }
295
296 let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 {
297 (0xe701u16, &decoded[2..])
298 } else {
299 return Err(anyhow!(
300 "Unsupported key type. Expected secp256k1 (0xe701), got {:02x}{:02x}",
301 decoded[0],
302 decoded[1]
303 ));
304 };
305
306 if codec != 0xe701 {
307 return Err(anyhow!("Only secp256k1 keys are supported"));
308 }
309
310 VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e))
311}
312
313pub fn is_service_token(token: &str) -> bool {
314 let parts: Vec<&str> = token.split('.').collect();
315 if parts.len() != 3 {
316 return false;
317 }
318
319 let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(parts[1]) else {
320 return false;
321 };
322
323 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&claims_bytes) else {
324 return false;
325 };
326
327 claims.get("lxm").is_some()
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a three-segment token from a raw header string and a claims
    /// object, with a placeholder signature segment.
    fn encode_token(header: &str, claims: &serde_json::Value) -> String {
        format!(
            "{}.{}.{}",
            URL_SAFE_NO_PAD.encode(header),
            URL_SAFE_NO_PAD.encode(claims.to_string()),
            URL_SAFE_NO_PAD.encode("fake-sig")
        )
    }

    #[test]
    fn test_is_service_token() {
        let claims_with_lxm = serde_json::json!({
            "iss": "did:plc:test",
            "sub": "did:plc:test",
            "aud": "did:web:test.com",
            "exp": 9999999999i64,
            "iat": 1000000000i64,
            "lxm": "com.atproto.repo.uploadBlob",
            "jti": "test-jti"
        });

        let claims_without_lxm = serde_json::json!({
            "iss": "did:plc:test",
            "sub": "did:plc:test",
            "aud": "did:web:test.com",
            "exp": 9999999999i64,
            "iat": 1000000000i64,
            "jti": "test-jti"
        });

        let token_with_lxm = encode_token(r#"{"alg":"ES256K","typ":"jwt"}"#, &claims_with_lxm);
        let token_without_lxm =
            encode_token(r#"{"alg":"ES256K","typ":"at+jwt"}"#, &claims_without_lxm);

        assert!(is_service_token(&token_with_lxm));
        assert!(!is_service_token(&token_without_lxm));
    }

    #[test]
    fn test_is_service_token_rejects_malformed() {
        // Wrong number of segments.
        assert!(!is_service_token(""));
        assert!(!is_service_token("only.two"));
        assert!(!is_service_token("a.b.c.d"));
        // Middle segment is not valid base64url.
        assert!(!is_service_token("aGVhZGVy.%%%.c2ln"));
        // Middle segment decodes, but not to JSON.
        let not_json = format!("h.{}.s", URL_SAFE_NO_PAD.encode("not json"));
        assert!(!is_service_token(&not_json));
    }

    #[test]
    fn test_parse_did_key_multibase() {
        let test_key = "zQ3shcXtVCEBjUvAhzTW3r12DkpFdR2KmA3rHmuEMFx4GMBDB";
        let result = parse_did_key_multibase(test_key);
        assert!(result.is_ok(), "Failed to parse valid multibase key");
    }

    #[test]
    fn test_parse_did_key_multibase_rejects_unsupported() {
        // Not base58btc (multibase prefix other than 'z').
        assert!(parse_did_key_multibase("uAAAA").is_err());
        // P-256 did:key value: a different multicodec than secp256k1.
        assert!(
            parse_did_key_multibase("zDnaerDaTF5BXEavCrfRZEk316dpbLsfPDZ3WJ5hRTPFU2169").is_err()
        );
        // Valid base58btc, but far too short to hold a codec prefix and a key.
        assert!(parse_did_key_multibase("z1").is_err());
    }
}