this repo has no description
1use crate::state::AppState; 2use axum::{ 3 Json, 4 extract::{Path, Query, State}, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7}; 8use base64::Engine; 9use k256::SecretKey; 10use k256::elliptic_curve::sec1::ToEncodedPoint; 11use reqwest; 12use serde::Deserialize; 13use serde_json::json; 14use sqlx::Row; 15use tracing::error; 16 17#[derive(Deserialize)] 18pub struct ResolveHandleParams { 19 pub handle: String, 20} 21 22pub async fn resolve_handle( 23 State(state): State<AppState>, 24 Query(params): Query<ResolveHandleParams>, 25) -> Response { 26 let handle = params.handle.trim(); 27 28 if handle.is_empty() { 29 return ( 30 StatusCode::BAD_REQUEST, 31 Json(json!({"error": "InvalidRequest", "message": "handle is required"})), 32 ) 33 .into_response(); 34 } 35 36 let user = sqlx::query("SELECT did FROM users WHERE handle = $1") 37 .bind(handle) 38 .fetch_optional(&state.db) 39 .await; 40 41 match user { 42 Ok(Some(row)) => { 43 let did: String = row.get("did"); 44 (StatusCode::OK, Json(json!({ "did": did }))).into_response() 45 } 46 Ok(None) => ( 47 StatusCode::NOT_FOUND, 48 Json(json!({"error": "HandleNotFound", "message": "Unable to resolve handle"})), 49 ) 50 .into_response(), 51 Err(e) => { 52 error!("DB error resolving handle: {:?}", e); 53 ( 54 StatusCode::INTERNAL_SERVER_ERROR, 55 Json(json!({"error": "InternalError"})), 56 ) 57 .into_response() 58 } 59 } 60} 61 62pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { 63 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); 64 let public_key = secret_key.public_key(); 65 let encoded = public_key.to_encoded_point(false); 66 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); 67 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); 68 69 json!({ 70 "kty": "EC", 71 "crv": "secp256k1", 72 "x": x, 73 "y": y 74 }) 75} 76 77pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 78 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 79 // Kinda for local dev, encode hostname if it contains port 80 let did = if hostname.contains(':') { 81 format!("did:web:{}", hostname.replace(':', "%3A")) 82 } else { 83 format!("did:web:{}", hostname) 84 }; 85 86 Json(json!({ 87 "@context": ["https://www.w3.org/ns/did/v1"], 88 "id": did, 89 "service": [{ 90 "id": "#atproto_pds", 91 "type": "AtprotoPersonalDataServer", 92 "serviceEndpoint": format!("https://{}", hostname) 93 }] 94 })) 95} 96 97pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 98 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 99 100 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1") 101 .bind(&handle) 102 .fetch_optional(&state.db) 103 .await; 104 105 let (user_id, did) = match user { 106 Ok(Some(row)) => { 107 let id: uuid::Uuid = row.get("id"); 108 let d: String = row.get("did"); 109 (id, d) 110 } 111 Ok(None) => { 112 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 113 } 114 Err(e) => { 115 error!("DB Error: {:?}", e); 116 return ( 117 StatusCode::INTERNAL_SERVER_ERROR, 118 Json(json!({"error": "InternalError"})), 119 ) 120 .into_response(); 121 } 122 }; 123 124 if !did.starts_with("did:web:") { 125 return ( 126 StatusCode::NOT_FOUND, 127 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 128 ) 129 .into_response(); 130 } 131 132 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1") 133 .bind(user_id) 134 .fetch_optional(&state.db) 135 .await; 136 137 let key_bytes: Vec<u8> = match key_row { 138 Ok(Some(row)) => row.get("key_bytes"), 139 _ => { 140 return ( 141 StatusCode::INTERNAL_SERVER_ERROR, 142 Json(json!({"error": "InternalError"})), 143 ) 144 .into_response(); 145 } 146 }; 147 148 let jwk = get_jwk(&key_bytes); 149 150 Json(json!({ 151 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 152 "id": did, 153 "alsoKnownAs": [format!("at://{}", handle)], 154 "verificationMethod": [{ 155 "id": format!("{}#atproto", did), 156 "type": "JsonWebKey2020", 157 "controller": did, 158 "publicKeyJwk": jwk 159 }], 160 "service": [{ 161 "id": "#atproto_pds", 162 "type": "AtprotoPersonalDataServer", 163 "serviceEndpoint": format!("https://{}", hostname) 164 }] 165 })).into_response() 166} 167 168pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 169 let expected_prefix = if hostname.contains(':') { 170 format!("did:web:{}", hostname.replace(':', "%3A")) 171 } else { 172 format!("did:web:{}", hostname) 173 }; 174 175 if did.starts_with(&expected_prefix) { 176 let suffix = &did[expected_prefix.len()..]; 177 let expected_suffix = format!(":u:{}", handle); 178 if suffix == expected_suffix { 179 Ok(()) 180 } else { 181 Err(format!( 182 "Invalid DID path for this PDS. Expected {}", 183 expected_suffix 184 )) 185 } 186 } else { 187 let parts: Vec<&str> = did.split(':').collect(); 188 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 189 return Err("Invalid did:web format".into()); 190 } 191 192 let domain_segment = parts[2]; 193 let domain = domain_segment.replace("%3A", ":"); 194 195 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 196 "http" 197 } else { 198 "https" 199 }; 200 201 let url = if parts.len() == 3 { 202 format!("{}://{}/.well-known/did.json", scheme, domain) 203 } else { 204 let path = parts[3..].join("/"); 205 format!("{}://{}/{}/did.json", scheme, domain, path) 206 }; 207 208 let client = reqwest::Client::builder() 209 .timeout(std::time::Duration::from_secs(5)) 210 .build() 211 .map_err(|e| format!("Failed to create client: {}", e))?; 212 213 let resp = client 214 .get(&url) 215 .send() 216 .await 217 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 218 219 if !resp.status().is_success() { 220 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 221 } 222 223 let doc: serde_json::Value = resp 224 .json() 225 .await 226 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 227 228 let services = doc["service"] 229 .as_array() 230 .ok_or("No services found in DID doc")?; 231 232 let pds_endpoint = format!("https://{}", hostname); 233 234 let has_valid_service = services.iter().any(|s| { 235 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 236 }); 237 238 if has_valid_service { 239 Ok(()) 240 } else { 241 Err(format!( 242 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 243 pds_endpoint 244 )) 245 } 246 } 247}