this repo has no description
1use anyhow::{Result, anyhow};
2use base64::Engine as _;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::Utc;
5use k256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use tracing::debug;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(rename_all = "camelCase")]
13pub struct FullDidDocument {
14 pub id: String,
15 #[serde(default)]
16 pub also_known_as: Vec<String>,
17 #[serde(default)]
18 pub verification_method: Vec<VerificationMethod>,
19 #[serde(default)]
20 pub service: Vec<DidService>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(rename_all = "camelCase")]
25pub struct VerificationMethod {
26 pub id: String,
27 #[serde(rename = "type")]
28 pub method_type: String,
29 pub controller: String,
30 #[serde(default)]
31 pub public_key_multibase: Option<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(rename_all = "camelCase")]
36pub struct DidService {
37 pub id: String,
38 #[serde(rename = "type")]
39 pub service_type: String,
40 pub service_endpoint: String,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ServiceTokenClaims {
45 pub iss: String,
46 #[serde(default)]
47 pub sub: Option<String>,
48 pub aud: String,
49 pub exp: usize,
50 #[serde(default)]
51 pub iat: Option<usize>,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub lxm: Option<String>,
54 #[serde(default)]
55 pub jti: Option<String>,
56}
57
58impl ServiceTokenClaims {
59 pub fn subject(&self) -> &str {
60 self.sub.as_deref().unwrap_or(&self.iss)
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65struct TokenHeader {
66 pub alg: String,
67 pub typ: String,
68}
69
70pub struct ServiceTokenVerifier {
71 client: Client,
72 plc_directory_url: String,
73 pds_did: String,
74}
75
76impl ServiceTokenVerifier {
77 pub fn new() -> Self {
78 let plc_directory_url = std::env::var("PLC_DIRECTORY_URL")
79 .unwrap_or_else(|_| "https://plc.directory".to_string());
80
81 let pds_hostname =
82 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
83 let pds_did = format!("did:web:{}", pds_hostname);
84
85 let client = Client::builder()
86 .timeout(Duration::from_secs(10))
87 .connect_timeout(Duration::from_secs(5))
88 .build()
89 .unwrap_or_else(|_| Client::new());
90
91 Self {
92 client,
93 plc_directory_url,
94 pds_did,
95 }
96 }
97
98 pub async fn verify_service_token(
99 &self,
100 token: &str,
101 required_lxm: Option<&str>,
102 ) -> Result<ServiceTokenClaims> {
103 let parts: Vec<&str> = token.split('.').collect();
104 if parts.len() != 3 {
105 return Err(anyhow!("Invalid token format"));
106 }
107
108 let header_bytes = URL_SAFE_NO_PAD
109 .decode(parts[0])
110 .map_err(|e| anyhow!("Base64 decode of header failed: {}", e))?;
111
112 let header: TokenHeader = serde_json::from_slice(&header_bytes)
113 .map_err(|e| anyhow!("JSON decode of header failed: {}", e))?;
114
115 if header.alg != "ES256K" {
116 return Err(anyhow!("Unsupported algorithm: {}", header.alg));
117 }
118
119 let claims_bytes = URL_SAFE_NO_PAD
120 .decode(parts[1])
121 .map_err(|e| anyhow!("Base64 decode of claims failed: {}", e))?;
122
123 let claims: ServiceTokenClaims = serde_json::from_slice(&claims_bytes)
124 .map_err(|e| anyhow!("JSON decode of claims failed: {}", e))?;
125
126 let now = Utc::now().timestamp() as usize;
127 if claims.exp < now {
128 return Err(anyhow!("Token expired"));
129 }
130
131 if claims.aud != self.pds_did {
132 return Err(anyhow!(
133 "Invalid audience: expected {}, got {}",
134 self.pds_did,
135 claims.aud
136 ));
137 }
138
139 if let Some(required) = required_lxm {
140 match &claims.lxm {
141 Some(lxm) if lxm == "*" || lxm == required => {}
142 Some(lxm) => {
143 return Err(anyhow!(
144 "Token lxm '{}' does not permit '{}'",
145 lxm,
146 required
147 ));
148 }
149 None => {
150 return Err(anyhow!("Token missing lxm claim"));
151 }
152 }
153 }
154
155 let did = &claims.iss;
156 let public_key = self.resolve_signing_key(did).await?;
157
158 let signature_bytes = URL_SAFE_NO_PAD
159 .decode(parts[2])
160 .map_err(|e| anyhow!("Base64 decode of signature failed: {}", e))?;
161
162 let signature = Signature::from_slice(&signature_bytes)
163 .map_err(|e| anyhow!("Invalid signature format: {}", e))?;
164
165 let message = format!("{}.{}", parts[0], parts[1]);
166
167 public_key
168 .verify(message.as_bytes(), &signature)
169 .map_err(|e| anyhow!("Signature verification failed: {}", e))?;
170
171 debug!("Service token verified for DID: {}", did);
172
173 Ok(claims)
174 }
175
176 async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey> {
177 let did_doc = self.resolve_did_document(did).await?;
178
179 let atproto_key = did_doc
180 .verification_method
181 .iter()
182 .find(|vm| vm.id.ends_with("#atproto") || vm.id == format!("{}#atproto", did))
183 .ok_or_else(|| anyhow!("No atproto verification method found in DID document"))?;
184
185 let multibase = atproto_key
186 .public_key_multibase
187 .as_ref()
188 .ok_or_else(|| anyhow!("Verification method missing publicKeyMultibase"))?;
189
190 parse_did_key_multibase(multibase)
191 }
192
193 async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument> {
194 if did.starts_with("did:plc:") {
195 self.resolve_did_plc(did).await
196 } else if did.starts_with("did:web:") {
197 self.resolve_did_web(did).await
198 } else {
199 Err(anyhow!("Unsupported DID method: {}", did))
200 }
201 }
202
203 async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument> {
204 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did));
205 debug!("Resolving did:plc {} via {}", did, url);
206
207 let resp = self
208 .client
209 .get(&url)
210 .send()
211 .await
212 .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
213
214 if resp.status() == reqwest::StatusCode::NOT_FOUND {
215 return Err(anyhow!("DID not found: {}", did));
216 }
217
218 if !resp.status().is_success() {
219 return Err(anyhow!("HTTP {}", resp.status()));
220 }
221
222 resp.json::<FullDidDocument>()
223 .await
224 .map_err(|e| anyhow!("Failed to parse DID document: {}", e))
225 }
226
227 async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument> {
228 let host = did
229 .strip_prefix("did:web:")
230 .ok_or_else(|| anyhow!("Invalid did:web format"))?;
231
232 let parts: Vec<&str> = host.split(':').collect();
233 if parts.is_empty() {
234 return Err(anyhow!("Invalid did:web format - no host"));
235 }
236
237 let host_part = parts[0].replace("%3A", ":");
238
239 let scheme = if host_part.starts_with("localhost")
240 || host_part.starts_with("127.0.0.1")
241 || host_part.contains(':')
242 {
243 "http"
244 } else {
245 "https"
246 };
247
248 let url = if parts.len() == 1 {
249 format!("{}://{}/.well-known/did.json", scheme, host_part)
250 } else {
251 let path = parts[1..].join("/");
252 format!("{}://{}/{}/did.json", scheme, host_part, path)
253 };
254
255 debug!("Resolving did:web {} via {}", did, url);
256
257 let resp = self
258 .client
259 .get(&url)
260 .send()
261 .await
262 .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
263
264 if !resp.status().is_success() {
265 return Err(anyhow!("HTTP {}", resp.status()));
266 }
267
268 resp.json::<FullDidDocument>()
269 .await
270 .map_err(|e| anyhow!("Failed to parse DID document: {}", e))
271 }
272}
273
274impl Default for ServiceTokenVerifier {
275 fn default() -> Self {
276 Self::new()
277 }
278}
279
280fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> {
281 if !multibase.starts_with('z') {
282 return Err(anyhow!(
283 "Expected base58btc multibase encoding (starts with 'z')"
284 ));
285 }
286
287 let (_, decoded) =
288 multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?;
289
290 if decoded.len() < 2 {
291 return Err(anyhow!("Invalid multicodec data"));
292 }
293
294 let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 {
295 (0xe701u16, &decoded[2..])
296 } else {
297 return Err(anyhow!(
298 "Unsupported key type. Expected secp256k1 (0xe701), got {:02x}{:02x}",
299 decoded[0],
300 decoded[1]
301 ));
302 };
303
304 if codec != 0xe701 {
305 return Err(anyhow!("Only secp256k1 keys are supported"));
306 }
307
308 VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e))
309}
310
311pub fn is_service_token(token: &str) -> bool {
312 let parts: Vec<&str> = token.split('.').collect();
313 if parts.len() != 3 {
314 return false;
315 }
316
317 let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(parts[1]) else {
318 return false;
319 };
320
321 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&claims_bytes) else {
322 return false;
323 };
324
325 claims.get("lxm").is_some()
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a JWT-shaped string from a raw header JSON string, a claims
    /// value, and a placeholder signature.
    fn encode_token(header_json: &str, claims: &serde_json::Value) -> String {
        [
            URL_SAFE_NO_PAD.encode(header_json),
            URL_SAFE_NO_PAD.encode(claims.to_string()),
            URL_SAFE_NO_PAD.encode("fake-sig"),
        ]
        .join(".")
    }

    #[test]
    fn test_is_service_token() {
        let mut claims = serde_json::json!({
            "iss": "did:plc:test",
            "sub": "did:plc:test",
            "aud": "did:web:test.com",
            "exp": 9999999999i64,
            "iat": 1000000000i64,
            "lxm": "com.atproto.repo.uploadBlob",
            "jti": "test-jti"
        });
        let scoped = encode_token(r#"{"alg":"ES256K","typ":"jwt"}"#, &claims);

        claims
            .as_object_mut()
            .expect("json! object literal is always an object")
            .remove("lxm");
        let unscoped = encode_token(r#"{"alg":"ES256K","typ":"at+jwt"}"#, &claims);

        assert!(is_service_token(&scoped));
        assert!(!is_service_token(&unscoped));
    }

    #[test]
    fn test_parse_did_key_multibase() {
        let parsed = parse_did_key_multibase("zQ3shcXtVCEBjUvAhzTW3r12DkpFdR2KmA3rHmuEMFx4GMBDB");
        assert!(parsed.is_ok(), "Failed to parse valid multibase key");
    }
}
377}