this repo has no description
1use crate::state::AppState;
2use axum::{
3 Json,
4 extract::{Path, State},
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use base64::Engine;
9use k256::SecretKey;
10use k256::elliptic_curve::sec1::ToEncodedPoint;
11use reqwest;
12use serde_json::json;
13use sqlx::Row;
14use tracing::error;
15
16pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value {
17 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length");
18 let public_key = secret_key.public_key();
19 let encoded = public_key.to_encoded_point(false);
20 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap());
21 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap());
22
23 json!({
24 "kty": "EC",
25 "crv": "secp256k1",
26 "x": x,
27 "y": y
28 })
29}
30
31pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse {
32 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
33 // Kinda for local dev, encode hostname if it contains port
34 let did = if hostname.contains(':') {
35 format!("did:web:{}", hostname.replace(':', "%3A"))
36 } else {
37 format!("did:web:{}", hostname)
38 };
39
40 Json(json!({
41 "@context": ["https://www.w3.org/ns/did/v1"],
42 "id": did,
43 "service": [{
44 "id": "#atproto_pds",
45 "type": "AtprotoPersonalDataServer",
46 "serviceEndpoint": format!("https://{}", hostname)
47 }]
48 }))
49}
50
51pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
52 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
53
54 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1")
55 .bind(&handle)
56 .fetch_optional(&state.db)
57 .await;
58
59 let (user_id, did) = match user {
60 Ok(Some(row)) => {
61 let id: uuid::Uuid = row.get("id");
62 let d: String = row.get("did");
63 (id, d)
64 }
65 Ok(None) => {
66 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response();
67 }
68 Err(e) => {
69 error!("DB Error: {:?}", e);
70 return (
71 StatusCode::INTERNAL_SERVER_ERROR,
72 Json(json!({"error": "InternalError"})),
73 )
74 .into_response();
75 }
76 };
77
78 if !did.starts_with("did:web:") {
79 return (
80 StatusCode::NOT_FOUND,
81 Json(json!({"error": "NotFound", "message": "User is not did:web"})),
82 )
83 .into_response();
84 }
85
86 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1")
87 .bind(user_id)
88 .fetch_optional(&state.db)
89 .await;
90
91 let key_bytes: Vec<u8> = match key_row {
92 Ok(Some(row)) => row.get("key_bytes"),
93 _ => {
94 return (
95 StatusCode::INTERNAL_SERVER_ERROR,
96 Json(json!({"error": "InternalError"})),
97 )
98 .into_response();
99 }
100 };
101
102 let jwk = get_jwk(&key_bytes);
103
104 Json(json!({
105 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"],
106 "id": did,
107 "alsoKnownAs": [format!("at://{}", handle)],
108 "verificationMethod": [{
109 "id": format!("{}#atproto", did),
110 "type": "JsonWebKey2020",
111 "controller": did,
112 "publicKeyJwk": jwk
113 }],
114 "service": [{
115 "id": "#atproto_pds",
116 "type": "AtprotoPersonalDataServer",
117 "serviceEndpoint": format!("https://{}", hostname)
118 }]
119 })).into_response()
120}
121
122pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> {
123 let expected_prefix = if hostname.contains(':') {
124 format!("did:web:{}", hostname.replace(':', "%3A"))
125 } else {
126 format!("did:web:{}", hostname)
127 };
128
129 if did.starts_with(&expected_prefix) {
130 let suffix = &did[expected_prefix.len()..];
131 let expected_suffix = format!(":u:{}", handle);
132 if suffix == expected_suffix {
133 Ok(())
134 } else {
135 Err(format!(
136 "Invalid DID path for this PDS. Expected {}",
137 expected_suffix
138 ))
139 }
140 } else {
141 let parts: Vec<&str> = did.split(':').collect();
142 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" {
143 return Err("Invalid did:web format".into());
144 }
145
146 let domain_segment = parts[2];
147 let domain = domain_segment.replace("%3A", ":");
148
149 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") {
150 "http"
151 } else {
152 "https"
153 };
154
155 let url = if parts.len() == 3 {
156 format!("{}://{}/.well-known/did.json", scheme, domain)
157 } else {
158 let path = parts[3..].join("/");
159 format!("{}://{}/{}/did.json", scheme, domain, path)
160 };
161
162 let client = reqwest::Client::builder()
163 .timeout(std::time::Duration::from_secs(5))
164 .build()
165 .map_err(|e| format!("Failed to create client: {}", e))?;
166
167 let resp = client
168 .get(&url)
169 .send()
170 .await
171 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?;
172
173 if !resp.status().is_success() {
174 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status()));
175 }
176
177 let doc: serde_json::Value = resp
178 .json()
179 .await
180 .map_err(|e| format!("Failed to parse DID doc: {}", e))?;
181
182 let services = doc["service"]
183 .as_array()
184 .ok_or("No services found in DID doc")?;
185
186 let pds_endpoint = format!("https://{}", hostname);
187
188 let has_valid_service = services.iter().any(|s| {
189 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint
190 });
191
192 if has_valid_service {
193 Ok(())
194 } else {
195 Err(format!(
196 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer",
197 pds_endpoint
198 ))
199 }
200 }
201}