This repository has no description.
1use axum::http::HeaderMap; 2use cid::Cid; 3use ipld_core::ipld::Ipld; 4use rand::Rng; 5use serde_json::Value as JsonValue; 6use sqlx::PgPool; 7use std::collections::BTreeMap; 8use std::str::FromStr; 9use std::sync::OnceLock; 10use uuid::Uuid; 11 12use crate::types::{Did, Handle}; 13 14const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 15const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; 16 17static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new(); 18 19pub fn get_max_blob_size() -> usize { 20 *MAX_BLOB_SIZE.get_or_init(|| { 21 std::env::var("MAX_BLOB_SIZE") 22 .ok() 23 .and_then(|s| s.parse().ok()) 24 .unwrap_or(DEFAULT_MAX_BLOB_SIZE) 25 }) 26} 27 28pub fn generate_token_code() -> String { 29 generate_token_code_parts(2, 5) 30} 31 32pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 33 let mut rng = rand::thread_rng(); 34 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 35 36 (0..parts) 37 .map(|_| { 38 (0..part_len) 39 .map(|_| chars[rng.gen_range(0..chars.len())]) 40 .collect::<String>() 41 }) 42 .collect::<Vec<_>>() 43 .join("-") 44} 45 46#[derive(Debug)] 47pub enum DbLookupError { 48 NotFound, 49 DatabaseError(sqlx::Error), 50} 51 52impl From<sqlx::Error> for DbLookupError { 53 fn from(e: sqlx::Error) -> Self { 54 DbLookupError::DatabaseError(e) 55 } 56} 57 58pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 59 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 60 .fetch_optional(db) 61 .await? 62 .ok_or(DbLookupError::NotFound) 63} 64 65pub struct UserInfo { 66 pub id: Uuid, 67 pub did: Did, 68 pub handle: Handle, 69} 70 71pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 72 sqlx::query_as!( 73 UserInfo, 74 "SELECT id, did, handle FROM users WHERE did = $1", 75 did 76 ) 77 .fetch_optional(db) 78 .await? 
79 .ok_or(DbLookupError::NotFound) 80} 81 82pub async fn get_user_by_identifier( 83 db: &PgPool, 84 identifier: &str, 85) -> Result<UserInfo, DbLookupError> { 86 sqlx::query_as!( 87 UserInfo, 88 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 89 identifier 90 ) 91 .fetch_optional(db) 92 .await? 93 .ok_or(DbLookupError::NotFound) 94} 95 96pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 97 let row = sqlx::query!( 98 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#, 99 did 100 ) 101 .fetch_optional(db) 102 .await?; 103 Ok(row.map(|r| r.migrated).unwrap_or(false)) 104} 105 106pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> { 107 query 108 .map(|q| { 109 let mut values = Vec::new(); 110 for pair in q.split('&') { 111 if let Some((k, v)) = pair.split_once('=') 112 && k == key 113 && let Ok(decoded) = urlencoding::decode(v) 114 { 115 let decoded = decoded.into_owned(); 116 if decoded.contains(',') { 117 for part in decoded.split(',') { 118 let trimmed = part.trim(); 119 if !trimmed.is_empty() { 120 values.push(trimmed.to_string()); 121 } 122 } 123 } else if !decoded.is_empty() { 124 values.push(decoded); 125 } 126 } 127 } 128 values 129 }) 130 .unwrap_or_default() 131} 132 133pub fn extract_client_ip(headers: &HeaderMap) -> String { 134 if let Some(forwarded) = headers.get("x-forwarded-for") 135 && let Ok(value) = forwarded.to_str() 136 && let Some(first_ip) = value.split(',').next() 137 { 138 return first_ip.trim().to_string(); 139 } 140 if let Some(real_ip) = headers.get("x-real-ip") 141 && let Ok(value) = real_ip.to_str() 142 { 143 return value.trim().to_string(); 144 } 145 "unknown".to_string() 146} 147 148pub fn pds_hostname() -> String { 149 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 150} 151 152pub fn pds_public_url() -> String { 153 format!("https://{}", pds_hostname()) 
154} 155 156pub fn build_full_url(path: &str) -> String { 157 format!("{}{}", pds_public_url(), path) 158} 159 160pub fn json_to_ipld(value: &JsonValue) -> Ipld { 161 match value { 162 JsonValue::Null => Ipld::Null, 163 JsonValue::Bool(b) => Ipld::Bool(*b), 164 JsonValue::Number(n) => { 165 if let Some(i) = n.as_i64() { 166 Ipld::Integer(i as i128) 167 } else if let Some(f) = n.as_f64() { 168 Ipld::Float(f) 169 } else { 170 Ipld::Null 171 } 172 } 173 JsonValue::String(s) => Ipld::String(s.clone()), 174 JsonValue::Array(arr) => Ipld::List(arr.iter().map(json_to_ipld).collect()), 175 JsonValue::Object(obj) => { 176 if let Some(JsonValue::String(link)) = obj.get("$link") 177 && obj.len() == 1 178 && let Ok(cid) = Cid::from_str(link) 179 { 180 return Ipld::Link(cid); 181 } 182 let map: BTreeMap<String, Ipld> = obj 183 .iter() 184 .map(|(k, v)| (k.clone(), json_to_ipld(v))) 185 .collect(); 186 Ipld::Map(map) 187 } 188 } 189} 190 191#[cfg(test)] 192mod tests { 193 use super::*; 194 195 #[test] 196 fn test_parse_repeated_query_param_repeated() { 197 let query = "did=test&cids=a&cids=b&cids=c"; 198 let result = parse_repeated_query_param(Some(query), "cids"); 199 assert_eq!(result, vec!["a", "b", "c"]); 200 } 201 202 #[test] 203 fn test_parse_repeated_query_param_comma_separated() { 204 let query = "did=test&cids=a,b,c"; 205 let result = parse_repeated_query_param(Some(query), "cids"); 206 assert_eq!(result, vec!["a", "b", "c"]); 207 } 208 209 #[test] 210 fn test_parse_repeated_query_param_mixed() { 211 let query = "did=test&cids=a,b&cids=c"; 212 let result = parse_repeated_query_param(Some(query), "cids"); 213 assert_eq!(result, vec!["a", "b", "c"]); 214 } 215 216 #[test] 217 fn test_parse_repeated_query_param_single() { 218 let query = "did=test&cids=a"; 219 let result = parse_repeated_query_param(Some(query), "cids"); 220 assert_eq!(result, vec!["a"]); 221 } 222 223 #[test] 224 fn test_parse_repeated_query_param_empty() { 225 let query = "did=test"; 226 let result = 
parse_repeated_query_param(Some(query), "cids"); 227 assert!(result.is_empty()); 228 } 229 230 #[test] 231 fn test_parse_repeated_query_param_url_encoded() { 232 let query = "did=test&cids=bafyreib%2Btest"; 233 let result = parse_repeated_query_param(Some(query), "cids"); 234 assert_eq!(result, vec!["bafyreib+test"]); 235 } 236 237 #[test] 238 fn test_generate_token_code() { 239 let code = generate_token_code(); 240 assert_eq!(code.len(), 11); 241 assert!(code.contains('-')); 242 243 let parts: Vec<&str> = code.split('-').collect(); 244 assert_eq!(parts.len(), 2); 245 assert_eq!(parts[0].len(), 5); 246 assert_eq!(parts[1].len(), 5); 247 248 for c in code.chars() { 249 if c != '-' { 250 assert!(BASE32_ALPHABET.contains(c)); 251 } 252 } 253 } 254 255 #[test] 256 fn test_generate_token_code_parts() { 257 let code = generate_token_code_parts(3, 4); 258 let parts: Vec<&str> = code.split('-').collect(); 259 assert_eq!(parts.len(), 3); 260 261 for part in parts { 262 assert_eq!(part.len(), 4); 263 } 264 } 265 266 #[test] 267 fn test_json_to_ipld_cid_link() { 268 let json = serde_json::json!({ 269 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 270 }); 271 let ipld = json_to_ipld(&json); 272 match ipld { 273 Ipld::Link(cid) => { 274 assert_eq!( 275 cid.to_string(), 276 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 277 ); 278 } 279 _ => panic!("Expected Ipld::Link, got {:?}", ipld), 280 } 281 } 282 283 #[test] 284 fn test_json_to_ipld_blob_ref() { 285 let json = serde_json::json!({ 286 "$type": "blob", 287 "ref": { 288 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 289 }, 290 "mimeType": "image/jpeg", 291 "size": 12345 292 }); 293 let ipld = json_to_ipld(&json); 294 match ipld { 295 Ipld::Map(map) => { 296 assert_eq!(map.get("$type"), Some(&Ipld::String("blob".to_string()))); 297 assert_eq!( 298 map.get("mimeType"), 299 Some(&Ipld::String("image/jpeg".to_string())) 300 ); 301 assert_eq!(map.get("size"), 
Some(&Ipld::Integer(12345))); 302 match map.get("ref") { 303 Some(Ipld::Link(cid)) => { 304 assert_eq!( 305 cid.to_string(), 306 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 307 ); 308 } 309 _ => panic!("Expected Ipld::Link in ref field, got {:?}", map.get("ref")), 310 } 311 } 312 _ => panic!("Expected Ipld::Map, got {:?}", ipld), 313 } 314 } 315 316 #[test] 317 fn test_json_to_ipld_nested_blob_refs_serializes_correctly() { 318 let record = serde_json::json!({ 319 "$type": "app.bsky.feed.post", 320 "text": "Hello world", 321 "embed": { 322 "$type": "app.bsky.embed.images", 323 "images": [ 324 { 325 "alt": "Test image", 326 "image": { 327 "$type": "blob", 328 "ref": { 329 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 330 }, 331 "mimeType": "image/jpeg", 332 "size": 12345 333 } 334 } 335 ] 336 } 337 }); 338 let ipld = json_to_ipld(&record); 339 let cbor_bytes = serde_ipld_dagcbor::to_vec(&ipld).expect("CBOR serialization failed"); 340 assert!(!cbor_bytes.is_empty()); 341 let parsed: Ipld = 342 serde_ipld_dagcbor::from_slice(&cbor_bytes).expect("CBOR deserialization failed"); 343 if let Ipld::Map(map) = &parsed 344 && let Some(Ipld::Map(embed)) = map.get("embed") 345 && let Some(Ipld::List(images)) = embed.get("images") 346 && let Some(Ipld::Map(img)) = images.first() 347 && let Some(Ipld::Map(blob)) = img.get("image") 348 && let Some(Ipld::Link(cid)) = blob.get("ref") 349 { 350 assert_eq!( 351 cid.to_string(), 352 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 353 ); 354 return; 355 } 356 panic!("Failed to find CID link in parsed CBOR"); 357 } 358}