// This repository has no description.
1use axum::http::HeaderMap; 2use rand::Rng; 3use sqlx::PgPool; 4use std::sync::OnceLock; 5use uuid::Uuid; 6 7const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 8const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; 9 10static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new(); 11 12pub fn get_max_blob_size() -> usize { 13 *MAX_BLOB_SIZE.get_or_init(|| { 14 std::env::var("MAX_BLOB_SIZE") 15 .ok() 16 .and_then(|s| s.parse().ok()) 17 .unwrap_or(DEFAULT_MAX_BLOB_SIZE) 18 }) 19} 20 21pub fn generate_token_code() -> String { 22 generate_token_code_parts(2, 5) 23} 24 25pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 26 let mut rng = rand::thread_rng(); 27 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 28 29 (0..parts) 30 .map(|_| { 31 (0..part_len) 32 .map(|_| chars[rng.gen_range(0..chars.len())]) 33 .collect::<String>() 34 }) 35 .collect::<Vec<_>>() 36 .join("-") 37} 38 39#[derive(Debug)] 40pub enum DbLookupError { 41 NotFound, 42 DatabaseError(sqlx::Error), 43} 44 45impl From<sqlx::Error> for DbLookupError { 46 fn from(e: sqlx::Error) -> Self { 47 DbLookupError::DatabaseError(e) 48 } 49} 50 51pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 52 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 53 .fetch_optional(db) 54 .await? 55 .ok_or(DbLookupError::NotFound) 56} 57 58pub struct UserInfo { 59 pub id: Uuid, 60 pub did: String, 61 pub handle: String, 62} 63 64pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 65 sqlx::query_as!( 66 UserInfo, 67 "SELECT id, did, handle FROM users WHERE did = $1", 68 did 69 ) 70 .fetch_optional(db) 71 .await? 
72 .ok_or(DbLookupError::NotFound) 73} 74 75pub async fn get_user_by_identifier( 76 db: &PgPool, 77 identifier: &str, 78) -> Result<UserInfo, DbLookupError> { 79 sqlx::query_as!( 80 UserInfo, 81 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 82 identifier 83 ) 84 .fetch_optional(db) 85 .await? 86 .ok_or(DbLookupError::NotFound) 87} 88 89pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> { 90 query 91 .map(|q| { 92 let mut values = Vec::new(); 93 for pair in q.split('&') { 94 if let Some((k, v)) = pair.split_once('=') 95 && k == key 96 && let Ok(decoded) = urlencoding::decode(v) 97 { 98 let decoded = decoded.into_owned(); 99 if decoded.contains(',') { 100 for part in decoded.split(',') { 101 let trimmed = part.trim(); 102 if !trimmed.is_empty() { 103 values.push(trimmed.to_string()); 104 } 105 } 106 } else if !decoded.is_empty() { 107 values.push(decoded); 108 } 109 } 110 } 111 values 112 }) 113 .unwrap_or_default() 114} 115 116pub fn extract_client_ip(headers: &HeaderMap) -> String { 117 if let Some(forwarded) = headers.get("x-forwarded-for") 118 && let Ok(value) = forwarded.to_str() 119 && let Some(first_ip) = value.split(',').next() 120 { 121 return first_ip.trim().to_string(); 122 } 123 if let Some(real_ip) = headers.get("x-real-ip") 124 && let Ok(value) = real_ip.to_str() 125 { 126 return value.trim().to_string(); 127 } 128 "unknown".to_string() 129} 130 131#[cfg(test)] 132mod tests { 133 use super::*; 134 135 #[test] 136 fn test_parse_repeated_query_param_repeated() { 137 let query = "did=test&cids=a&cids=b&cids=c"; 138 let result = parse_repeated_query_param(Some(query), "cids"); 139 assert_eq!(result, vec!["a", "b", "c"]); 140 } 141 142 #[test] 143 fn test_parse_repeated_query_param_comma_separated() { 144 let query = "did=test&cids=a,b,c"; 145 let result = parse_repeated_query_param(Some(query), "cids"); 146 assert_eq!(result, vec!["a", "b", "c"]); 147 } 148 149 #[test] 150 fn 
test_parse_repeated_query_param_mixed() { 151 let query = "did=test&cids=a,b&cids=c"; 152 let result = parse_repeated_query_param(Some(query), "cids"); 153 assert_eq!(result, vec!["a", "b", "c"]); 154 } 155 156 #[test] 157 fn test_parse_repeated_query_param_single() { 158 let query = "did=test&cids=a"; 159 let result = parse_repeated_query_param(Some(query), "cids"); 160 assert_eq!(result, vec!["a"]); 161 } 162 163 #[test] 164 fn test_parse_repeated_query_param_empty() { 165 let query = "did=test"; 166 let result = parse_repeated_query_param(Some(query), "cids"); 167 assert!(result.is_empty()); 168 } 169 170 #[test] 171 fn test_parse_repeated_query_param_url_encoded() { 172 let query = "did=test&cids=bafyreib%2Btest"; 173 let result = parse_repeated_query_param(Some(query), "cids"); 174 assert_eq!(result, vec!["bafyreib+test"]); 175 } 176 177 #[test] 178 fn test_generate_token_code() { 179 let code = generate_token_code(); 180 assert_eq!(code.len(), 11); 181 assert!(code.contains('-')); 182 183 let parts: Vec<&str> = code.split('-').collect(); 184 assert_eq!(parts.len(), 2); 185 assert_eq!(parts[0].len(), 5); 186 assert_eq!(parts[1].len(), 5); 187 188 for c in code.chars() { 189 if c != '-' { 190 assert!(BASE32_ALPHABET.contains(c)); 191 } 192 } 193 } 194 195 #[test] 196 fn test_generate_token_code_parts() { 197 let code = generate_token_code_parts(3, 4); 198 let parts: Vec<&str> = code.split('-').collect(); 199 assert_eq!(parts.len(), 3); 200 201 for part in parts { 202 assert_eq!(part.len(), 4); 203 } 204 } 205}