this repo has no description
1use axum::http::HeaderMap; 2use rand::Rng; 3use sqlx::PgPool; 4use std::sync::OnceLock; 5use uuid::Uuid; 6 7const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 8const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; 9 10static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new(); 11 12pub fn get_max_blob_size() -> usize { 13 *MAX_BLOB_SIZE.get_or_init(|| { 14 std::env::var("MAX_BLOB_SIZE") 15 .ok() 16 .and_then(|s| s.parse().ok()) 17 .unwrap_or(DEFAULT_MAX_BLOB_SIZE) 18 }) 19} 20 21pub fn generate_token_code() -> String { 22 generate_token_code_parts(2, 5) 23} 24 25pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 26 let mut rng = rand::thread_rng(); 27 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 28 29 (0..parts) 30 .map(|_| { 31 (0..part_len) 32 .map(|_| chars[rng.gen_range(0..chars.len())]) 33 .collect::<String>() 34 }) 35 .collect::<Vec<_>>() 36 .join("-") 37} 38 39#[derive(Debug)] 40pub enum DbLookupError { 41 NotFound, 42 DatabaseError(sqlx::Error), 43} 44 45impl From<sqlx::Error> for DbLookupError { 46 fn from(e: sqlx::Error) -> Self { 47 DbLookupError::DatabaseError(e) 48 } 49} 50 51pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 52 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 53 .fetch_optional(db) 54 .await? 55 .ok_or(DbLookupError::NotFound) 56} 57 58pub struct UserInfo { 59 pub id: Uuid, 60 pub did: String, 61 pub handle: String, 62} 63 64pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 65 sqlx::query_as!( 66 UserInfo, 67 "SELECT id, did, handle FROM users WHERE did = $1", 68 did 69 ) 70 .fetch_optional(db) 71 .await? 72 .ok_or(DbLookupError::NotFound) 73} 74 75pub async fn get_user_by_identifier( 76 db: &PgPool, 77 identifier: &str, 78) -> Result<UserInfo, DbLookupError> { 79 sqlx::query_as!( 80 UserInfo, 81 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 82 identifier 83 ) 84 .fetch_optional(db) 85 .await? 86 .ok_or(DbLookupError::NotFound) 87} 88 89pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 90 let row = sqlx::query!( 91 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#, 92 did 93 ) 94 .fetch_optional(db) 95 .await?; 96 Ok(row.map(|r| r.migrated).unwrap_or(false)) 97} 98 99pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> { 100 query 101 .map(|q| { 102 let mut values = Vec::new(); 103 for pair in q.split('&') { 104 if let Some((k, v)) = pair.split_once('=') 105 && k == key 106 && let Ok(decoded) = urlencoding::decode(v) 107 { 108 let decoded = decoded.into_owned(); 109 if decoded.contains(',') { 110 for part in decoded.split(',') { 111 let trimmed = part.trim(); 112 if !trimmed.is_empty() { 113 values.push(trimmed.to_string()); 114 } 115 } 116 } else if !decoded.is_empty() { 117 values.push(decoded); 118 } 119 } 120 } 121 values 122 }) 123 .unwrap_or_default() 124} 125 126pub fn extract_client_ip(headers: &HeaderMap) -> String { 127 if let Some(forwarded) = headers.get("x-forwarded-for") 128 && let Ok(value) = forwarded.to_str() 129 && let Some(first_ip) = value.split(',').next() 130 { 131 return first_ip.trim().to_string(); 132 } 133 if let Some(real_ip) = headers.get("x-real-ip") 134 && let Ok(value) = real_ip.to_str() 135 { 136 return value.trim().to_string(); 137 } 138 "unknown".to_string() 139} 140 141pub fn pds_hostname() -> String { 142 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 143} 144 145pub fn pds_public_url() -> String { 146 format!("https://{}", pds_hostname()) 147} 148 149pub fn build_full_url(path: &str) -> String { 150 format!("{}{}", pds_public_url(), path) 151} 152 153#[cfg(test)] 154mod tests { 155 use super::*; 156 157 #[test] 158 fn test_parse_repeated_query_param_repeated() { 159 let query = "did=test&cids=a&cids=b&cids=c"; 160 let result = parse_repeated_query_param(Some(query), "cids"); 161 assert_eq!(result, vec!["a", "b", "c"]); 162 } 163 164 #[test] 165 fn test_parse_repeated_query_param_comma_separated() { 166 let query = "did=test&cids=a,b,c"; 167 let result = parse_repeated_query_param(Some(query), "cids"); 168 assert_eq!(result, vec!["a", "b", "c"]); 169 } 170 171 #[test] 172 fn test_parse_repeated_query_param_mixed() { 173 let query = "did=test&cids=a,b&cids=c"; 174 let result = parse_repeated_query_param(Some(query), "cids"); 175 assert_eq!(result, vec!["a", "b", "c"]); 176 } 177 178 #[test] 179 fn test_parse_repeated_query_param_single() { 180 let query = "did=test&cids=a"; 181 let result = parse_repeated_query_param(Some(query), "cids"); 182 assert_eq!(result, vec!["a"]); 183 } 184 185 #[test] 186 fn test_parse_repeated_query_param_empty() { 187 let query = "did=test"; 188 let result = parse_repeated_query_param(Some(query), "cids"); 189 assert!(result.is_empty()); 190 } 191 192 #[test] 193 fn test_parse_repeated_query_param_url_encoded() { 194 let query = "did=test&cids=bafyreib%2Btest"; 195 let result = parse_repeated_query_param(Some(query), "cids"); 196 assert_eq!(result, vec!["bafyreib+test"]); 197 } 198 199 #[test] 200 fn test_generate_token_code() { 201 let code = generate_token_code(); 202 assert_eq!(code.len(), 11); 203 assert!(code.contains('-')); 204 205 let parts: Vec<&str> = code.split('-').collect(); 206 assert_eq!(parts.len(), 2); 207 assert_eq!(parts[0].len(), 5); 208 assert_eq!(parts[1].len(), 5); 209 210 for c in code.chars() { 211 if c != '-' { 212 assert!(BASE32_ALPHABET.contains(c)); 213 } 214 } 215 } 216 217 #[test] 218 fn test_generate_token_code_parts() { 219 let code = generate_token_code_parts(3, 4); 220 let parts: Vec<&str> = code.split('-').collect(); 221 assert_eq!(parts.len(), 3); 222 223 for part in parts { 224 assert_eq!(part.len(), 4); 225 } 226 } 227}