This repository has no description.
1use axum::http::HeaderMap; 2use cid::Cid; 3use ipld_core::ipld::Ipld; 4use rand::Rng; 5use serde_json::Value as JsonValue; 6use sqlx::PgPool; 7use std::collections::BTreeMap; 8use std::str::FromStr; 9use std::sync::OnceLock; 10use uuid::Uuid; 11 12const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 13const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; 14 15static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new(); 16 17pub fn get_max_blob_size() -> usize { 18 *MAX_BLOB_SIZE.get_or_init(|| { 19 std::env::var("MAX_BLOB_SIZE") 20 .ok() 21 .and_then(|s| s.parse().ok()) 22 .unwrap_or(DEFAULT_MAX_BLOB_SIZE) 23 }) 24} 25 26pub fn generate_token_code() -> String { 27 generate_token_code_parts(2, 5) 28} 29 30pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 31 let mut rng = rand::thread_rng(); 32 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 33 34 (0..parts) 35 .map(|_| { 36 (0..part_len) 37 .map(|_| chars[rng.gen_range(0..chars.len())]) 38 .collect::<String>() 39 }) 40 .collect::<Vec<_>>() 41 .join("-") 42} 43 44#[derive(Debug)] 45pub enum DbLookupError { 46 NotFound, 47 DatabaseError(sqlx::Error), 48} 49 50impl From<sqlx::Error> for DbLookupError { 51 fn from(e: sqlx::Error) -> Self { 52 DbLookupError::DatabaseError(e) 53 } 54} 55 56pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 57 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 58 .fetch_optional(db) 59 .await? 60 .ok_or(DbLookupError::NotFound) 61} 62 63pub struct UserInfo { 64 pub id: Uuid, 65 pub did: String, 66 pub handle: String, 67} 68 69pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 70 sqlx::query_as!( 71 UserInfo, 72 "SELECT id, did, handle FROM users WHERE did = $1", 73 did 74 ) 75 .fetch_optional(db) 76 .await? 
77 .ok_or(DbLookupError::NotFound) 78} 79 80pub async fn get_user_by_identifier( 81 db: &PgPool, 82 identifier: &str, 83) -> Result<UserInfo, DbLookupError> { 84 sqlx::query_as!( 85 UserInfo, 86 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 87 identifier 88 ) 89 .fetch_optional(db) 90 .await? 91 .ok_or(DbLookupError::NotFound) 92} 93 94pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 95 let row = sqlx::query!( 96 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#, 97 did 98 ) 99 .fetch_optional(db) 100 .await?; 101 Ok(row.map(|r| r.migrated).unwrap_or(false)) 102} 103 104pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> { 105 query 106 .map(|q| { 107 let mut values = Vec::new(); 108 for pair in q.split('&') { 109 if let Some((k, v)) = pair.split_once('=') 110 && k == key 111 && let Ok(decoded) = urlencoding::decode(v) 112 { 113 let decoded = decoded.into_owned(); 114 if decoded.contains(',') { 115 for part in decoded.split(',') { 116 let trimmed = part.trim(); 117 if !trimmed.is_empty() { 118 values.push(trimmed.to_string()); 119 } 120 } 121 } else if !decoded.is_empty() { 122 values.push(decoded); 123 } 124 } 125 } 126 values 127 }) 128 .unwrap_or_default() 129} 130 131pub fn extract_client_ip(headers: &HeaderMap) -> String { 132 if let Some(forwarded) = headers.get("x-forwarded-for") 133 && let Ok(value) = forwarded.to_str() 134 && let Some(first_ip) = value.split(',').next() 135 { 136 return first_ip.trim().to_string(); 137 } 138 if let Some(real_ip) = headers.get("x-real-ip") 139 && let Ok(value) = real_ip.to_str() 140 { 141 return value.trim().to_string(); 142 } 143 "unknown".to_string() 144} 145 146pub fn pds_hostname() -> String { 147 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 148} 149 150pub fn pds_public_url() -> String { 151 format!("https://{}", pds_hostname()) 152} 
153 154pub fn build_full_url(path: &str) -> String { 155 format!("{}{}", pds_public_url(), path) 156} 157 158pub fn json_to_ipld(value: &JsonValue) -> Ipld { 159 match value { 160 JsonValue::Null => Ipld::Null, 161 JsonValue::Bool(b) => Ipld::Bool(*b), 162 JsonValue::Number(n) => { 163 if let Some(i) = n.as_i64() { 164 Ipld::Integer(i as i128) 165 } else if let Some(f) = n.as_f64() { 166 Ipld::Float(f) 167 } else { 168 Ipld::Null 169 } 170 } 171 JsonValue::String(s) => Ipld::String(s.clone()), 172 JsonValue::Array(arr) => Ipld::List(arr.iter().map(json_to_ipld).collect()), 173 JsonValue::Object(obj) => { 174 if let Some(JsonValue::String(link)) = obj.get("$link") 175 && obj.len() == 1 176 && let Ok(cid) = Cid::from_str(link) 177 { 178 return Ipld::Link(cid); 179 } 180 let map: BTreeMap<String, Ipld> = obj 181 .iter() 182 .map(|(k, v)| (k.clone(), json_to_ipld(v))) 183 .collect(); 184 Ipld::Map(map) 185 } 186 } 187} 188 189#[cfg(test)] 190mod tests { 191 use super::*; 192 193 #[test] 194 fn test_parse_repeated_query_param_repeated() { 195 let query = "did=test&cids=a&cids=b&cids=c"; 196 let result = parse_repeated_query_param(Some(query), "cids"); 197 assert_eq!(result, vec!["a", "b", "c"]); 198 } 199 200 #[test] 201 fn test_parse_repeated_query_param_comma_separated() { 202 let query = "did=test&cids=a,b,c"; 203 let result = parse_repeated_query_param(Some(query), "cids"); 204 assert_eq!(result, vec!["a", "b", "c"]); 205 } 206 207 #[test] 208 fn test_parse_repeated_query_param_mixed() { 209 let query = "did=test&cids=a,b&cids=c"; 210 let result = parse_repeated_query_param(Some(query), "cids"); 211 assert_eq!(result, vec!["a", "b", "c"]); 212 } 213 214 #[test] 215 fn test_parse_repeated_query_param_single() { 216 let query = "did=test&cids=a"; 217 let result = parse_repeated_query_param(Some(query), "cids"); 218 assert_eq!(result, vec!["a"]); 219 } 220 221 #[test] 222 fn test_parse_repeated_query_param_empty() { 223 let query = "did=test"; 224 let result = 
parse_repeated_query_param(Some(query), "cids"); 225 assert!(result.is_empty()); 226 } 227 228 #[test] 229 fn test_parse_repeated_query_param_url_encoded() { 230 let query = "did=test&cids=bafyreib%2Btest"; 231 let result = parse_repeated_query_param(Some(query), "cids"); 232 assert_eq!(result, vec!["bafyreib+test"]); 233 } 234 235 #[test] 236 fn test_generate_token_code() { 237 let code = generate_token_code(); 238 assert_eq!(code.len(), 11); 239 assert!(code.contains('-')); 240 241 let parts: Vec<&str> = code.split('-').collect(); 242 assert_eq!(parts.len(), 2); 243 assert_eq!(parts[0].len(), 5); 244 assert_eq!(parts[1].len(), 5); 245 246 for c in code.chars() { 247 if c != '-' { 248 assert!(BASE32_ALPHABET.contains(c)); 249 } 250 } 251 } 252 253 #[test] 254 fn test_generate_token_code_parts() { 255 let code = generate_token_code_parts(3, 4); 256 let parts: Vec<&str> = code.split('-').collect(); 257 assert_eq!(parts.len(), 3); 258 259 for part in parts { 260 assert_eq!(part.len(), 4); 261 } 262 } 263 264 #[test] 265 fn test_json_to_ipld_cid_link() { 266 let json = serde_json::json!({ 267 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 268 }); 269 let ipld = json_to_ipld(&json); 270 match ipld { 271 Ipld::Link(cid) => { 272 assert_eq!( 273 cid.to_string(), 274 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 275 ); 276 } 277 _ => panic!("Expected Ipld::Link, got {:?}", ipld), 278 } 279 } 280 281 #[test] 282 fn test_json_to_ipld_blob_ref() { 283 let json = serde_json::json!({ 284 "$type": "blob", 285 "ref": { 286 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 287 }, 288 "mimeType": "image/jpeg", 289 "size": 12345 290 }); 291 let ipld = json_to_ipld(&json); 292 match ipld { 293 Ipld::Map(map) => { 294 assert_eq!(map.get("$type"), Some(&Ipld::String("blob".to_string()))); 295 assert_eq!( 296 map.get("mimeType"), 297 Some(&Ipld::String("image/jpeg".to_string())) 298 ); 299 assert_eq!(map.get("size"), 
Some(&Ipld::Integer(12345))); 300 match map.get("ref") { 301 Some(Ipld::Link(cid)) => { 302 assert_eq!( 303 cid.to_string(), 304 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 305 ); 306 } 307 _ => panic!("Expected Ipld::Link in ref field, got {:?}", map.get("ref")), 308 } 309 } 310 _ => panic!("Expected Ipld::Map, got {:?}", ipld), 311 } 312 } 313 314 #[test] 315 fn test_json_to_ipld_nested_blob_refs_serializes_correctly() { 316 let record = serde_json::json!({ 317 "$type": "app.bsky.feed.post", 318 "text": "Hello world", 319 "embed": { 320 "$type": "app.bsky.embed.images", 321 "images": [ 322 { 323 "alt": "Test image", 324 "image": { 325 "$type": "blob", 326 "ref": { 327 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 328 }, 329 "mimeType": "image/jpeg", 330 "size": 12345 331 } 332 } 333 ] 334 } 335 }); 336 let ipld = json_to_ipld(&record); 337 let cbor_bytes = serde_ipld_dagcbor::to_vec(&ipld).expect("CBOR serialization failed"); 338 assert!(!cbor_bytes.is_empty()); 339 let parsed: Ipld = 340 serde_ipld_dagcbor::from_slice(&cbor_bytes).expect("CBOR deserialization failed"); 341 if let Ipld::Map(map) = &parsed 342 && let Some(Ipld::Map(embed)) = map.get("embed") 343 && let Some(Ipld::List(images)) = embed.get("images") 344 && let Some(Ipld::Map(img)) = images.first() 345 && let Some(Ipld::Map(blob)) = img.get("image") 346 && let Some(Ipld::Link(cid)) = blob.get("ref") 347 { 348 assert_eq!( 349 cid.to_string(), 350 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 351 ); 352 return; 353 } 354 panic!("Failed to find CID link in parsed CBOR"); 355 } 356}