this repo has no description
1use crate::state::AppState; 2use axum::{ 3 Json, 4 extract::{Path, State}, 5 http::StatusCode, 6 response::{IntoResponse, Response}, 7}; 8use base64::Engine; 9use k256::SecretKey; 10use k256::elliptic_curve::sec1::ToEncodedPoint; 11use reqwest; 12use serde_json::json; 13use sqlx::Row; 14use tracing::error; 15 16pub fn get_jwk(key_bytes: &[u8]) -> serde_json::Value { 17 let secret_key = SecretKey::from_slice(key_bytes).expect("Invalid key length"); 18 let public_key = secret_key.public_key(); 19 let encoded = public_key.to_encoded_point(false); 20 let x = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.x().unwrap()); 21 let y = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(encoded.y().unwrap()); 22 23 json!({ 24 "kty": "EC", 25 "crv": "secp256k1", 26 "x": x, 27 "y": y 28 }) 29} 30 31pub async fn well_known_did(State(_state): State<AppState>) -> impl IntoResponse { 32 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 33 // Kinda for local dev, encode hostname if it contains port 34 let did = if hostname.contains(':') { 35 format!("did:web:{}", hostname.replace(':', "%3A")) 36 } else { 37 format!("did:web:{}", hostname) 38 }; 39 40 Json(json!({ 41 "@context": ["https://www.w3.org/ns/did/v1"], 42 "id": did, 43 "service": [{ 44 "id": "#atproto_pds", 45 "type": "AtprotoPersonalDataServer", 46 "serviceEndpoint": format!("https://{}", hostname) 47 }] 48 })) 49} 50 51pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 52 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 53 54 let user = sqlx::query("SELECT id, did FROM users WHERE handle = $1") 55 .bind(&handle) 56 .fetch_optional(&state.db) 57 .await; 58 59 let (user_id, did) = match user { 60 Ok(Some(row)) => { 61 let id: uuid::Uuid = row.get("id"); 62 let d: String = row.get("did"); 63 (id, d) 64 } 65 Ok(None) => { 66 return (StatusCode::NOT_FOUND, Json(json!({"error": "NotFound"}))).into_response(); 67 } 68 Err(e) => { 69 error!("DB Error: {:?}", e); 70 return ( 71 StatusCode::INTERNAL_SERVER_ERROR, 72 Json(json!({"error": "InternalError"})), 73 ) 74 .into_response(); 75 } 76 }; 77 78 if !did.starts_with("did:web:") { 79 return ( 80 StatusCode::NOT_FOUND, 81 Json(json!({"error": "NotFound", "message": "User is not did:web"})), 82 ) 83 .into_response(); 84 } 85 86 let key_row = sqlx::query("SELECT key_bytes FROM user_keys WHERE user_id = $1") 87 .bind(user_id) 88 .fetch_optional(&state.db) 89 .await; 90 91 let key_bytes: Vec<u8> = match key_row { 92 Ok(Some(row)) => row.get("key_bytes"), 93 _ => { 94 return ( 95 StatusCode::INTERNAL_SERVER_ERROR, 96 Json(json!({"error": "InternalError"})), 97 ) 98 .into_response(); 99 } 100 }; 101 102 let jwk = get_jwk(&key_bytes); 103 104 Json(json!({ 105 "@context": ["https://www.w3.org/ns/did/v1", "https://w3id.org/security/suites/jws-2020/v1"], 106 "id": did, 107 "alsoKnownAs": [format!("at://{}", handle)], 108 "verificationMethod": [{ 109 "id": format!("{}#atproto", did), 110 "type": "JsonWebKey2020", 111 "controller": did, 112 "publicKeyJwk": jwk 113 }], 114 "service": [{ 115 "id": "#atproto_pds", 116 "type": "AtprotoPersonalDataServer", 117 "serviceEndpoint": format!("https://{}", hostname) 118 }] 119 })).into_response() 120} 121 122pub async fn verify_did_web(did: &str, hostname: &str, handle: &str) -> Result<(), String> { 123 let expected_prefix = if hostname.contains(':') { 124 format!("did:web:{}", hostname.replace(':', "%3A")) 125 } else { 126 format!("did:web:{}", hostname) 127 }; 128 129 if did.starts_with(&expected_prefix) { 130 let suffix = &did[expected_prefix.len()..]; 131 let expected_suffix = format!(":u:{}", handle); 132 if suffix == expected_suffix { 133 Ok(()) 134 } else { 135 Err(format!( 136 "Invalid DID path for this PDS. Expected {}", 137 expected_suffix 138 )) 139 } 140 } else { 141 let parts: Vec<&str> = did.split(':').collect(); 142 if parts.len() < 3 || parts[0] != "did" || parts[1] != "web" { 143 return Err("Invalid did:web format".into()); 144 } 145 146 let domain_segment = parts[2]; 147 let domain = domain_segment.replace("%3A", ":"); 148 149 let scheme = if domain.starts_with("localhost") || domain.starts_with("127.0.0.1") { 150 "http" 151 } else { 152 "https" 153 }; 154 155 let url = if parts.len() == 3 { 156 format!("{}://{}/.well-known/did.json", scheme, domain) 157 } else { 158 let path = parts[3..].join("/"); 159 format!("{}://{}/{}/did.json", scheme, domain, path) 160 }; 161 162 let client = reqwest::Client::builder() 163 .timeout(std::time::Duration::from_secs(5)) 164 .build() 165 .map_err(|e| format!("Failed to create client: {}", e))?; 166 167 let resp = client 168 .get(&url) 169 .send() 170 .await 171 .map_err(|e| format!("Failed to fetch DID doc: {}", e))?; 172 173 if !resp.status().is_success() { 174 return Err(format!("Failed to fetch DID doc: HTTP {}", resp.status())); 175 } 176 177 let doc: serde_json::Value = resp 178 .json() 179 .await 180 .map_err(|e| format!("Failed to parse DID doc: {}", e))?; 181 182 let services = doc["service"] 183 .as_array() 184 .ok_or("No services found in DID doc")?; 185 186 let pds_endpoint = format!("https://{}", hostname); 187 188 let has_valid_service = services.iter().any(|s| { 189 s["type"] == "AtprotoPersonalDataServer" && s["serviceEndpoint"] == pds_endpoint 190 }); 191 192 if has_valid_service { 193 Ok(()) 194 } else { 195 Err(format!( 196 "DID document does not list this PDS ({}) as AtprotoPersonalDataServer", 197 pds_endpoint 198 )) 199 } 200 } 201}