this repo has no description
1use anyhow::{Result, anyhow};
2use base64::Engine as _;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use chrono::Utc;
5use k256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use tracing::debug;
10
/// A DID document as fetched by this module's resolvers (`did:plc` via the
/// PLC directory, `did:web` via the host's `did.json`).
///
/// Only the fields this module needs are modelled. Keys are camelCase on the
/// wire. `alsoKnownAs`, `verificationMethod` and `service` default to empty
/// when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FullDidDocument {
    /// The DID this document describes.
    pub id: String,
    /// Alternative identifiers (wire key `alsoKnownAs`).
    #[serde(default)]
    pub also_known_as: Vec<String>,
    /// Public keys; the `#atproto` entry is the one used to verify service tokens.
    #[serde(default)]
    pub verification_method: Vec<VerificationMethod>,
    /// Service endpoints declared by the DID.
    #[serde(default)]
    pub service: Vec<DidService>,
}
22
/// One entry of a DID document's `verificationMethod` array.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct VerificationMethod {
    /// Method id, e.g. `<did>#atproto`.
    pub id: String,
    /// Key type (JSON key `type`).
    #[serde(rename = "type")]
    pub method_type: String,
    /// DID that controls this key.
    pub controller: String,
    /// Multibase-encoded public key (JSON key `publicKeyMultibase`), if present.
    /// Decoded by `parse_did_key_multibase`, which only accepts base58btc ('z')
    /// secp256k1 keys.
    #[serde(default)]
    pub public_key_multibase: Option<String>,
}
33
/// One entry of a DID document's `service` array.
///
/// NOTE(review): the DID Core spec also allows `serviceEndpoint` to be a map or
/// a set; a document using those forms fails to deserialize into this struct
/// (and so the whole `FullDidDocument` fails) — confirm this is acceptable.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DidService {
    /// Service id, e.g. `#atproto_pds`-style fragment (presumably; not checked here).
    pub id: String,
    /// Service type (JSON key `type`).
    #[serde(rename = "type")]
    pub service_type: String,
    /// Endpoint URL (JSON key `serviceEndpoint`); must be a JSON string.
    pub service_endpoint: String,
}
42
/// Claims carried by an inter-service auth JWT.
///
/// Times are seconds since the Unix epoch; `verify_service_token` compares
/// `exp` against `Utc::now().timestamp()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceTokenClaims {
    /// Issuer DID; the token must be signed by this DID's `#atproto` key.
    pub iss: String,
    /// Optional subject; see [`ServiceTokenClaims::subject`].
    #[serde(default)]
    pub sub: Option<String>,
    /// Audience; the verifier requires `did:web:<PDS_HOSTNAME>`.
    pub aud: String,
    /// Expiry time (Unix seconds).
    pub exp: usize,
    /// Issued-at time (Unix seconds), if present. Not checked by the verifier.
    #[serde(default)]
    pub iat: Option<usize>,
    /// Lexicon method the token is scoped to (`*` permits any method).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lxm: Option<String>,
    /// Token id, if present. Not checked by the verifier.
    #[serde(default)]
    pub jti: Option<String>,
}
57
58impl ServiceTokenClaims {
59 pub fn subject(&self) -> &str {
60 self.sub.as_deref().unwrap_or(&self.iss)
61 }
62}
63
/// JOSE header of a service token.
///
/// Both fields must be present for the header to deserialize. The verifier
/// only accepts `alg == "ES256K"` and never inspects `typ`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TokenHeader {
    /// Signature algorithm.
    pub alg: String,
    /// Token type.
    pub typ: String,
}
69
/// Verifies inter-service auth tokens, resolving issuer keys from DID documents over HTTP.
pub struct ServiceTokenVerifier {
    /// HTTP client used for DID document resolution.
    client: Client,
    /// Base URL of the PLC directory used to resolve `did:plc` identities.
    plc_directory_url: String,
    /// Expected token audience: `did:web:<PDS_HOSTNAME>`.
    pds_did: String,
}
75
76impl ServiceTokenVerifier {
77 pub fn new() -> Self {
78 let plc_directory_url = std::env::var("PLC_DIRECTORY_URL")
79 .unwrap_or_else(|_| "https://plc.directory".to_string());
80
81 let pds_hostname =
82 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
83 let pds_did = format!("did:web:{}", pds_hostname);
84
85 let client = Client::builder()
86 .timeout(Duration::from_secs(10))
87 .connect_timeout(Duration::from_secs(5))
88 .build()
89 .unwrap_or_else(|_| Client::new());
90
91 Self {
92 client,
93 plc_directory_url,
94 pds_did,
95 }
96 }
97
98 pub async fn verify_service_token(
99 &self,
100 token: &str,
101 required_lxm: Option<&str>,
102 ) -> Result<ServiceTokenClaims> {
103 let parts: Vec<&str> = token.split('.').collect();
104 if parts.len() != 3 {
105 return Err(anyhow!("Invalid token format"));
106 }
107
108 let header_bytes = URL_SAFE_NO_PAD
109 .decode(parts[0])
110 .map_err(|e| anyhow!("Base64 decode of header failed: {}", e))?;
111
112 let header: TokenHeader = serde_json::from_slice(&header_bytes)
113 .map_err(|e| anyhow!("JSON decode of header failed: {}", e))?;
114
115 if header.alg != "ES256K" {
116 return Err(anyhow!("Unsupported algorithm: {}", header.alg));
117 }
118
119 let claims_bytes = URL_SAFE_NO_PAD
120 .decode(parts[1])
121 .map_err(|e| anyhow!("Base64 decode of claims failed: {}", e))?;
122
123 let claims: ServiceTokenClaims = serde_json::from_slice(&claims_bytes)
124 .map_err(|e| anyhow!("JSON decode of claims failed: {}", e))?;
125
126 let now = Utc::now().timestamp() as usize;
127 if claims.exp < now {
128 return Err(anyhow!("Token expired"));
129 }
130
131 if claims.aud != self.pds_did {
132 return Err(anyhow!(
133 "Invalid audience: expected {}, got {}",
134 self.pds_did,
135 claims.aud
136 ));
137 }
138
139 if let Some(required) = required_lxm {
140 match &claims.lxm {
141 Some(lxm) if lxm == "*" || lxm == required => {}
142 Some(lxm) => {
143 return Err(anyhow!(
144 "Token lxm '{}' does not permit '{}'",
145 lxm,
146 required
147 ));
148 }
149 None => {
150 return Err(anyhow!("Token missing lxm claim"));
151 }
152 }
153 }
154
155 let did = &claims.iss;
156 let public_key = self.resolve_signing_key(did).await?;
157
158 let signature_bytes = URL_SAFE_NO_PAD
159 .decode(parts[2])
160 .map_err(|e| anyhow!("Base64 decode of signature failed: {}", e))?;
161
162 let signature = Signature::from_slice(&signature_bytes)
163 .map_err(|e| anyhow!("Invalid signature format: {}", e))?;
164
165 let message = format!("{}.{}", parts[0], parts[1]);
166
167 public_key
168 .verify(message.as_bytes(), &signature)
169 .map_err(|e| anyhow!("Signature verification failed: {}", e))?;
170
171 debug!("Service token verified for DID: {}", did);
172
173 Ok(claims)
174 }
175
176 async fn resolve_signing_key(&self, did: &str) -> Result<VerifyingKey> {
177 let did_doc = self.resolve_did_document(did).await?;
178
179 let atproto_key = did_doc
180 .verification_method
181 .iter()
182 .find(|vm| vm.id.ends_with("#atproto") || vm.id == format!("{}#atproto", did))
183 .ok_or_else(|| anyhow!("No atproto verification method found in DID document"))?;
184
185 let multibase = atproto_key
186 .public_key_multibase
187 .as_ref()
188 .ok_or_else(|| anyhow!("Verification method missing publicKeyMultibase"))?;
189
190 parse_did_key_multibase(multibase)
191 }
192
193 async fn resolve_did_document(&self, did: &str) -> Result<FullDidDocument> {
194 if did.starts_with("did:plc:") {
195 self.resolve_did_plc(did).await
196 } else if did.starts_with("did:web:") {
197 self.resolve_did_web(did).await
198 } else {
199 Err(anyhow!("Unsupported DID method: {}", did))
200 }
201 }
202
203 async fn resolve_did_plc(&self, did: &str) -> Result<FullDidDocument> {
204 let url = format!("{}/{}", self.plc_directory_url, urlencoding::encode(did));
205 debug!("Resolving did:plc {} via {}", did, url);
206
207 let resp = self
208 .client
209 .get(&url)
210 .send()
211 .await
212 .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
213
214 if resp.status() == reqwest::StatusCode::NOT_FOUND {
215 return Err(anyhow!("DID not found: {}", did));
216 }
217
218 if !resp.status().is_success() {
219 return Err(anyhow!("HTTP {}", resp.status()));
220 }
221
222 resp.json::<FullDidDocument>()
223 .await
224 .map_err(|e| anyhow!("Failed to parse DID document: {}", e))
225 }
226
227 async fn resolve_did_web(&self, did: &str) -> Result<FullDidDocument> {
228 let host = did
229 .strip_prefix("did:web:")
230 .ok_or_else(|| anyhow!("Invalid did:web format"))?;
231
232 let decoded_host = host.replace("%3A", ":");
233 let (host_part, path_part) = if let Some(idx) = decoded_host.find('/') {
234 (&decoded_host[..idx], &decoded_host[idx..])
235 } else {
236 (decoded_host.as_str(), "")
237 };
238
239 let scheme = if host_part.starts_with("localhost")
240 || host_part.starts_with("127.0.0.1")
241 || host_part.contains(':')
242 {
243 "http"
244 } else {
245 "https"
246 };
247
248 let url = if path_part.is_empty() {
249 format!("{}://{}/.well-known/did.json", scheme, host_part)
250 } else {
251 format!("{}://{}{}/did.json", scheme, host_part, path_part)
252 };
253
254 debug!("Resolving did:web {} via {}", did, url);
255
256 let resp = self
257 .client
258 .get(&url)
259 .send()
260 .await
261 .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
262
263 if !resp.status().is_success() {
264 return Err(anyhow!("HTTP {}", resp.status()));
265 }
266
267 resp.json::<FullDidDocument>()
268 .await
269 .map_err(|e| anyhow!("Failed to parse DID document: {}", e))
270 }
271}
272
273impl Default for ServiceTokenVerifier {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
279fn parse_did_key_multibase(multibase: &str) -> Result<VerifyingKey> {
280 if !multibase.starts_with('z') {
281 return Err(anyhow!(
282 "Expected base58btc multibase encoding (starts with 'z')"
283 ));
284 }
285
286 let (_, decoded) =
287 multibase::decode(multibase).map_err(|e| anyhow!("Failed to decode multibase: {}", e))?;
288
289 if decoded.len() < 2 {
290 return Err(anyhow!("Invalid multicodec data"));
291 }
292
293 let (codec, key_bytes) = if decoded[0] == 0xe7 && decoded[1] == 0x01 {
294 (0xe701u16, &decoded[2..])
295 } else {
296 return Err(anyhow!(
297 "Unsupported key type. Expected secp256k1 (0xe701), got {:02x}{:02x}",
298 decoded[0],
299 decoded[1]
300 ));
301 };
302
303 if codec != 0xe701 {
304 return Err(anyhow!("Only secp256k1 keys are supported"));
305 }
306
307 VerifyingKey::from_sec1_bytes(key_bytes).map_err(|e| anyhow!("Invalid public key: {}", e))
308}
309
310pub fn is_service_token(token: &str) -> bool {
311 let parts: Vec<&str> = token.split('.').collect();
312 if parts.len() != 3 {
313 return false;
314 }
315
316 let Ok(claims_bytes) = URL_SAFE_NO_PAD.decode(parts[1]) else {
317 return false;
318 };
319
320 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&claims_bytes) else {
321 return false;
322 };
323
324 claims.get("lxm").is_some()
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a compact JWT from a raw header, JSON claims and a dummy signature.
    fn compact_jwt(header: &str, claims: &serde_json::Value) -> String {
        [
            URL_SAFE_NO_PAD.encode(header),
            URL_SAFE_NO_PAD.encode(claims.to_string()),
            URL_SAFE_NO_PAD.encode("fake-sig"),
        ]
        .join(".")
    }

    #[test]
    fn test_is_service_token() {
        let service_claims = serde_json::json!({
            "iss": "did:plc:test",
            "sub": "did:plc:test",
            "aud": "did:web:test.com",
            "exp": 9999999999i64,
            "iat": 1000000000i64,
            "lxm": "com.atproto.repo.uploadBlob",
            "jti": "test-jti"
        });
        let access_claims = serde_json::json!({
            "iss": "did:plc:test",
            "sub": "did:plc:test",
            "aud": "did:web:test.com",
            "exp": 9999999999i64,
            "iat": 1000000000i64,
            "jti": "test-jti"
        });

        let service_token = compact_jwt(r#"{"alg":"ES256K","typ":"jwt"}"#, &service_claims);
        let access_token = compact_jwt(r#"{"alg":"ES256K","typ":"at+jwt"}"#, &access_claims);

        assert!(is_service_token(&service_token));
        assert!(!is_service_token(&access_token));
    }

    #[test]
    fn test_parse_did_key_multibase() {
        let encoded_key = "zQ3shcXtVCEBjUvAhzTW3r12DkpFdR2KmA3rHmuEMFx4GMBDB";
        assert!(
            parse_did_key_multibase(encoded_key).is_ok(),
            "Failed to parse valid multibase key"
        );
    }
}