this repo has no description
at main 12 kB view raw
1use axum::http::HeaderMap; 2use cid::Cid; 3use ipld_core::ipld::Ipld; 4use rand::Rng; 5use serde_json::Value as JsonValue; 6use sqlx::PgPool; 7use std::collections::BTreeMap; 8use std::str::FromStr; 9use std::sync::OnceLock; 10use uuid::Uuid; 11 12use crate::types::{Did, Handle}; 13 14const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 15const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; 16 17static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new(); 18 19pub fn get_max_blob_size() -> usize { 20 *MAX_BLOB_SIZE.get_or_init(|| { 21 std::env::var("MAX_BLOB_SIZE") 22 .ok() 23 .and_then(|s| s.parse().ok()) 24 .unwrap_or(DEFAULT_MAX_BLOB_SIZE) 25 }) 26} 27 28pub fn generate_token_code() -> String { 29 generate_token_code_parts(2, 5) 30} 31 32pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 33 let mut rng = rand::thread_rng(); 34 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 35 36 (0..parts) 37 .map(|_| { 38 (0..part_len) 39 .map(|_| chars[rng.gen_range(0..chars.len())]) 40 .collect::<String>() 41 }) 42 .collect::<Vec<_>>() 43 .join("-") 44} 45 46#[derive(Debug)] 47pub enum DbLookupError { 48 NotFound, 49 DatabaseError(sqlx::Error), 50} 51 52impl From<sqlx::Error> for DbLookupError { 53 fn from(e: sqlx::Error) -> Self { 54 DbLookupError::DatabaseError(e) 55 } 56} 57 58pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 59 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 60 .fetch_optional(db) 61 .await? 62 .ok_or(DbLookupError::NotFound) 63} 64 65pub struct UserInfo { 66 pub id: Uuid, 67 pub did: Did, 68 pub handle: Handle, 69} 70 71pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 72 sqlx::query_as!( 73 UserInfo, 74 "SELECT id, did, handle FROM users WHERE did = $1", 75 did 76 ) 77 .fetch_optional(db) 78 .await? 79 .ok_or(DbLookupError::NotFound) 80} 81 82pub async fn get_user_by_identifier( 83 db: &PgPool, 84 identifier: &str, 85) -> Result<UserInfo, DbLookupError> { 86 sqlx::query_as!( 87 UserInfo, 88 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 89 identifier 90 ) 91 .fetch_optional(db) 92 .await? 93 .ok_or(DbLookupError::NotFound) 94} 95 96pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 97 let row = sqlx::query!( 98 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#, 99 did 100 ) 101 .fetch_optional(db) 102 .await?; 103 Ok(row.map(|r| r.migrated).unwrap_or(false)) 104} 105 106pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> { 107 query 108 .map(|q| { 109 q.split('&') 110 .filter_map(|pair| { 111 pair.split_once('=') 112 .filter(|(k, _)| *k == key) 113 .and_then(|(_, v)| urlencoding::decode(v).ok()) 114 .map(|decoded| decoded.into_owned()) 115 }) 116 .flat_map(|decoded| { 117 if decoded.contains(',') { 118 decoded 119 .split(',') 120 .filter_map(|part| { 121 let trimmed = part.trim(); 122 (!trimmed.is_empty()).then(|| trimmed.to_string()) 123 }) 124 .collect::<Vec<_>>() 125 } else if decoded.is_empty() { 126 vec![] 127 } else { 128 vec![decoded] 129 } 130 }) 131 .collect() 132 }) 133 .unwrap_or_default() 134} 135 136pub fn extract_client_ip(headers: &HeaderMap) -> String { 137 if let Some(forwarded) = headers.get("x-forwarded-for") 138 && let Ok(value) = forwarded.to_str() 139 && let Some(first_ip) = value.split(',').next() 140 { 141 return first_ip.trim().to_string(); 142 } 143 if let Some(real_ip) = headers.get("x-real-ip") 144 && let Ok(value) = real_ip.to_str() 145 { 146 return value.trim().to_string(); 147 } 148 "unknown".to_string() 149} 150 151pub fn pds_hostname() -> String { 152 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 153} 154 155pub fn pds_public_url() -> String { 156 format!("https://{}", pds_hostname()) 157} 158 159pub fn build_full_url(path: &str) -> String { 160 let normalized_path = if !path.starts_with("/xrpc/") 161 && (path.starts_with("/com.atproto.") 162 || path.starts_with("/app.bsky.") 163 || path.starts_with("/_")) 164 { 165 format!("/xrpc{}", path) 166 } else { 167 path.to_string() 168 }; 169 format!("{}{}", pds_public_url(), normalized_path) 170} 171 172pub fn json_to_ipld(value: &JsonValue) -> Ipld { 173 match value { 174 JsonValue::Null => Ipld::Null, 175 JsonValue::Bool(b) => Ipld::Bool(*b), 176 JsonValue::Number(n) => { 177 if let Some(i) = n.as_i64() { 178 Ipld::Integer(i as i128) 179 } else if let Some(f) = n.as_f64() { 180 Ipld::Float(f) 181 } else { 182 Ipld::Null 183 } 184 } 185 JsonValue::String(s) => Ipld::String(s.clone()), 186 JsonValue::Array(arr) => Ipld::List(arr.iter().map(json_to_ipld).collect()), 187 JsonValue::Object(obj) => { 188 if let Some(JsonValue::String(link)) = obj.get("$link") 189 && obj.len() == 1 190 && let Ok(cid) = Cid::from_str(link) 191 { 192 return Ipld::Link(cid); 193 } 194 let map: BTreeMap<String, Ipld> = obj 195 .iter() 196 .map(|(k, v)| (k.clone(), json_to_ipld(v))) 197 .collect(); 198 Ipld::Map(map) 199 } 200 } 201} 202 203#[cfg(test)] 204mod tests { 205 use super::*; 206 207 #[test] 208 fn test_parse_repeated_query_param_repeated() { 209 let query = "did=test&cids=a&cids=b&cids=c"; 210 let result = parse_repeated_query_param(Some(query), "cids"); 211 assert_eq!(result, vec!["a", "b", "c"]); 212 } 213 214 #[test] 215 fn test_parse_repeated_query_param_comma_separated() { 216 let query = "did=test&cids=a,b,c"; 217 let result = parse_repeated_query_param(Some(query), "cids"); 218 assert_eq!(result, vec!["a", "b", "c"]); 219 } 220 221 #[test] 222 fn test_parse_repeated_query_param_mixed() { 223 let query = "did=test&cids=a,b&cids=c"; 224 let result = parse_repeated_query_param(Some(query), "cids"); 225 assert_eq!(result, vec!["a", "b", "c"]); 226 } 227 228 #[test] 229 fn test_parse_repeated_query_param_single() { 230 let query = "did=test&cids=a"; 231 let result = parse_repeated_query_param(Some(query), "cids"); 232 assert_eq!(result, vec!["a"]); 233 } 234 235 #[test] 236 fn test_parse_repeated_query_param_empty() { 237 let query = "did=test"; 238 let result = parse_repeated_query_param(Some(query), "cids"); 239 assert!(result.is_empty()); 240 } 241 242 #[test] 243 fn test_parse_repeated_query_param_url_encoded() { 244 let query = "did=test&cids=bafyreib%2Btest"; 245 let result = parse_repeated_query_param(Some(query), "cids"); 246 assert_eq!(result, vec!["bafyreib+test"]); 247 } 248 249 #[test] 250 fn test_generate_token_code() { 251 let code = generate_token_code(); 252 assert_eq!(code.len(), 11); 253 assert!(code.contains('-')); 254 255 let parts: Vec<&str> = code.split('-').collect(); 256 assert_eq!(parts.len(), 2); 257 assert_eq!(parts[0].len(), 5); 258 assert_eq!(parts[1].len(), 5); 259 260 for c in code.chars() { 261 if c != '-' { 262 assert!(BASE32_ALPHABET.contains(c)); 263 } 264 } 265 } 266 267 #[test] 268 fn test_generate_token_code_parts() { 269 let code = generate_token_code_parts(3, 4); 270 let parts: Vec<&str> = code.split('-').collect(); 271 assert_eq!(parts.len(), 3); 272 273 for part in parts { 274 assert_eq!(part.len(), 4); 275 } 276 } 277 278 #[test] 279 fn test_json_to_ipld_cid_link() { 280 let json = serde_json::json!({ 281 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 282 }); 283 let ipld = json_to_ipld(&json); 284 match ipld { 285 Ipld::Link(cid) => { 286 assert_eq!( 287 cid.to_string(), 288 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 289 ); 290 } 291 _ => panic!("Expected Ipld::Link, got {:?}", ipld), 292 } 293 } 294 295 #[test] 296 fn test_json_to_ipld_blob_ref() { 297 let json = serde_json::json!({ 298 "$type": "blob", 299 "ref": { 300 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 301 }, 302 "mimeType": "image/jpeg", 303 "size": 12345 304 }); 305 let ipld = json_to_ipld(&json); 306 match ipld { 307 Ipld::Map(map) => { 308 assert_eq!(map.get("$type"), Some(&Ipld::String("blob".to_string()))); 309 assert_eq!( 310 map.get("mimeType"), 311 Some(&Ipld::String("image/jpeg".to_string())) 312 ); 313 assert_eq!(map.get("size"), Some(&Ipld::Integer(12345))); 314 match map.get("ref") { 315 Some(Ipld::Link(cid)) => { 316 assert_eq!( 317 cid.to_string(), 318 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 319 ); 320 } 321 _ => panic!("Expected Ipld::Link in ref field, got {:?}", map.get("ref")), 322 } 323 } 324 _ => panic!("Expected Ipld::Map, got {:?}", ipld), 325 } 326 } 327 328 #[test] 329 fn test_json_to_ipld_nested_blob_refs_serializes_correctly() { 330 let record = serde_json::json!({ 331 "$type": "app.bsky.feed.post", 332 "text": "Hello world", 333 "embed": { 334 "$type": "app.bsky.embed.images", 335 "images": [ 336 { 337 "alt": "Test image", 338 "image": { 339 "$type": "blob", 340 "ref": { 341 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 342 }, 343 "mimeType": "image/jpeg", 344 "size": 12345 345 } 346 } 347 ] 348 } 349 }); 350 let ipld = json_to_ipld(&record); 351 let cbor_bytes = serde_ipld_dagcbor::to_vec(&ipld).expect("CBOR serialization failed"); 352 assert!(!cbor_bytes.is_empty()); 353 let parsed: Ipld = 354 serde_ipld_dagcbor::from_slice(&cbor_bytes).expect("CBOR deserialization failed"); 355 if let Ipld::Map(map) = &parsed 356 && let Some(Ipld::Map(embed)) = map.get("embed") 357 && let Some(Ipld::List(images)) = embed.get("images") 358 && let Some(Ipld::Map(img)) = images.first() 359 && let Some(Ipld::Map(blob)) = img.get("image") 360 && let Some(Ipld::Link(cid)) = blob.get("ref") 361 { 362 assert_eq!( 363 cid.to_string(), 364 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 365 ); 366 return; 367 } 368 panic!("Failed to find CID link in parsed CBOR"); 369 } 370 371 #[test] 372 fn test_build_full_url_adds_xrpc_prefix_for_atproto_paths() { 373 unsafe { std::env::set_var("PDS_HOSTNAME", "example.com") }; 374 assert_eq!( 375 build_full_url("/com.atproto.server.getSession"), 376 "https://example.com/xrpc/com.atproto.server.getSession" 377 ); 378 assert_eq!( 379 build_full_url("/app.bsky.feed.getTimeline"), 380 "https://example.com/xrpc/app.bsky.feed.getTimeline" 381 ); 382 assert_eq!( 383 build_full_url("/_health"), 384 "https://example.com/xrpc/_health" 385 ); 386 assert_eq!( 387 build_full_url("/xrpc/com.atproto.server.getSession"), 388 "https://example.com/xrpc/com.atproto.server.getSession" 389 ); 390 assert_eq!( 391 build_full_url("/oauth/token"), 392 "https://example.com/oauth/token" 393 ); 394 } 395}