this repo has no description
1use axum::http::HeaderMap; 2use rand::Rng; 3use sqlx::PgPool; 4use uuid::Uuid; 5 6const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 7 8pub fn generate_token_code() -> String { 9 generate_token_code_parts(2, 5) 10} 11 12pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String { 13 let mut rng = rand::thread_rng(); 14 let chars: Vec<char> = BASE32_ALPHABET.chars().collect(); 15 16 (0..parts) 17 .map(|_| { 18 (0..part_len) 19 .map(|_| chars[rng.gen_range(0..chars.len())]) 20 .collect::<String>() 21 }) 22 .collect::<Vec<_>>() 23 .join("-") 24} 25 26#[derive(Debug)] 27pub enum DbLookupError { 28 NotFound, 29 DatabaseError(sqlx::Error), 30} 31 32impl From<sqlx::Error> for DbLookupError { 33 fn from(e: sqlx::Error) -> Self { 34 DbLookupError::DatabaseError(e) 35 } 36} 37 38pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 39 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 40 .fetch_optional(db) 41 .await? 42 .ok_or(DbLookupError::NotFound) 43} 44 45pub struct UserInfo { 46 pub id: Uuid, 47 pub did: String, 48 pub handle: String, 49} 50 51pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 52 sqlx::query_as!( 53 UserInfo, 54 "SELECT id, did, handle FROM users WHERE did = $1", 55 did 56 ) 57 .fetch_optional(db) 58 .await? 59 .ok_or(DbLookupError::NotFound) 60} 61 62pub async fn get_user_by_identifier( 63 db: &PgPool, 64 identifier: &str, 65) -> Result<UserInfo, DbLookupError> { 66 sqlx::query_as!( 67 UserInfo, 68 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 69 identifier 70 ) 71 .fetch_optional(db) 72 .await? 73 .ok_or(DbLookupError::NotFound) 74} 75 76pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> { 77 query 78 .map(|q| { 79 let mut values = Vec::new(); 80 for pair in q.split('&') { 81 if let Some((k, v)) = pair.split_once('=') 82 && k == key 83 && let Ok(decoded) = urlencoding::decode(v) 84 { 85 let decoded = decoded.into_owned(); 86 if decoded.contains(',') { 87 for part in decoded.split(',') { 88 let trimmed = part.trim(); 89 if !trimmed.is_empty() { 90 values.push(trimmed.to_string()); 91 } 92 } 93 } else if !decoded.is_empty() { 94 values.push(decoded); 95 } 96 } 97 } 98 values 99 }) 100 .unwrap_or_default() 101} 102 103pub fn extract_client_ip(headers: &HeaderMap) -> String { 104 if let Some(forwarded) = headers.get("x-forwarded-for") 105 && let Ok(value) = forwarded.to_str() 106 && let Some(first_ip) = value.split(',').next() 107 { 108 return first_ip.trim().to_string(); 109 } 110 if let Some(real_ip) = headers.get("x-real-ip") 111 && let Ok(value) = real_ip.to_str() 112 { 113 return value.trim().to_string(); 114 } 115 "unknown".to_string() 116} 117 118#[cfg(test)] 119mod tests { 120 use super::*; 121 122 #[test] 123 fn test_parse_repeated_query_param_repeated() { 124 let query = "did=test&cids=a&cids=b&cids=c"; 125 let result = parse_repeated_query_param(Some(query), "cids"); 126 assert_eq!(result, vec!["a", "b", "c"]); 127 } 128 129 #[test] 130 fn test_parse_repeated_query_param_comma_separated() { 131 let query = "did=test&cids=a,b,c"; 132 let result = parse_repeated_query_param(Some(query), "cids"); 133 assert_eq!(result, vec!["a", "b", "c"]); 134 } 135 136 #[test] 137 fn test_parse_repeated_query_param_mixed() { 138 let query = "did=test&cids=a,b&cids=c"; 139 let result = parse_repeated_query_param(Some(query), "cids"); 140 assert_eq!(result, vec!["a", "b", "c"]); 141 } 142 143 #[test] 144 fn test_parse_repeated_query_param_single() { 145 let query = "did=test&cids=a"; 146 let result = parse_repeated_query_param(Some(query), "cids"); 147 assert_eq!(result, vec!["a"]); 148 } 149 150 #[test] 151 fn test_parse_repeated_query_param_empty() { 152 let query = "did=test"; 153 let result = parse_repeated_query_param(Some(query), "cids"); 154 assert!(result.is_empty()); 155 } 156 157 #[test] 158 fn test_parse_repeated_query_param_url_encoded() { 159 let query = "did=test&cids=bafyreib%2Btest"; 160 let result = parse_repeated_query_param(Some(query), "cids"); 161 assert_eq!(result, vec!["bafyreib+test"]); 162 } 163 164 #[test] 165 fn test_generate_token_code() { 166 let code = generate_token_code(); 167 assert_eq!(code.len(), 11); 168 assert!(code.contains('-')); 169 170 let parts: Vec<&str> = code.split('-').collect(); 171 assert_eq!(parts.len(), 2); 172 assert_eq!(parts[0].len(), 5); 173 assert_eq!(parts[1].len(), 5); 174 175 for c in code.chars() { 176 if c != '-' { 177 assert!(BASE32_ALPHABET.contains(c)); 178 } 179 } 180 } 181 182 #[test] 183 fn test_generate_token_code_parts() { 184 let code = generate_token_code_parts(3, 4); 185 let parts: Vec<&str> = code.split('-').collect(); 186 assert_eq!(parts.len(), 3); 187 188 for part in parts { 189 assert_eq!(part.len(), 4); 190 } 191 } 192}