This repository has no description.
1use axum::http::HeaderMap;
2use rand::Rng;
3use sqlx::PgPool;
4use uuid::Uuid;
5
/// Lowercase base32 alphabet (the 26 letters `a`-`z` followed by the digits `2`-`7`),
/// i.e. the RFC 4648 base32 characters in lowercase. All 32 symbols are ASCII.
const BASE32_ALPHABET: &str = concat!("abcdefghijklmnopqrstuvwxyz", "234567");
7
8pub fn generate_token_code() -> String {
9 generate_token_code_parts(2, 5)
10}
11
12pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
13 let mut rng = rand::thread_rng();
14 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
15
16 (0..parts)
17 .map(|_| {
18 (0..part_len)
19 .map(|_| chars[rng.gen_range(0..chars.len())])
20 .collect::<String>()
21 })
22 .collect::<Vec<_>>()
23 .join("-")
24}
25
26#[derive(Debug)]
27pub enum DbLookupError {
28 NotFound,
29 DatabaseError(sqlx::Error),
30}
31
32impl From<sqlx::Error> for DbLookupError {
33 fn from(e: sqlx::Error) -> Self {
34 DbLookupError::DatabaseError(e)
35 }
36}
37
38pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
39 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
40 .fetch_optional(db)
41 .await?
42 .ok_or(DbLookupError::NotFound)
43}
44
45pub struct UserInfo {
46 pub id: Uuid,
47 pub did: String,
48 pub handle: String,
49}
50
51pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
52 sqlx::query_as!(
53 UserInfo,
54 "SELECT id, did, handle FROM users WHERE did = $1",
55 did
56 )
57 .fetch_optional(db)
58 .await?
59 .ok_or(DbLookupError::NotFound)
60}
61
62pub async fn get_user_by_identifier(
63 db: &PgPool,
64 identifier: &str,
65) -> Result<UserInfo, DbLookupError> {
66 sqlx::query_as!(
67 UserInfo,
68 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
69 identifier
70 )
71 .fetch_optional(db)
72 .await?
73 .ok_or(DbLookupError::NotFound)
74}
75
76pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
77 query
78 .map(|q| {
79 let mut values = Vec::new();
80 for pair in q.split('&') {
81 if let Some((k, v)) = pair.split_once('=')
82 && k == key
83 && let Ok(decoded) = urlencoding::decode(v)
84 {
85 let decoded = decoded.into_owned();
86 if decoded.contains(',') {
87 for part in decoded.split(',') {
88 let trimmed = part.trim();
89 if !trimmed.is_empty() {
90 values.push(trimmed.to_string());
91 }
92 }
93 } else if !decoded.is_empty() {
94 values.push(decoded);
95 }
96 }
97 }
98 values
99 })
100 .unwrap_or_default()
101}
102
103pub fn extract_client_ip(headers: &HeaderMap) -> String {
104 if let Some(forwarded) = headers.get("x-forwarded-for")
105 && let Ok(value) = forwarded.to_str()
106 && let Some(first_ip) = value.split(',').next()
107 {
108 return first_ip.trim().to_string();
109 }
110 if let Some(real_ip) = headers.get("x-real-ip")
111 && let Ok(value) = real_ip.to_str()
112 {
113 return value.trim().to_string();
114 }
115 "unknown".to_string()
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    #[test]
    fn test_parse_repeated_query_param_repeated() {
        let query = "did=test&cids=a&cids=b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_comma_separated() {
        let query = "did=test&cids=a,b,c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_mixed() {
        let query = "did=test&cids=a,b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_single() {
        let query = "did=test&cids=a";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a"]);
    }

    #[test]
    fn test_parse_repeated_query_param_empty() {
        let query = "did=test";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_repeated_query_param_url_encoded() {
        let query = "did=test&cids=bafyreib%2Btest";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["bafyreib+test"]);
    }

    #[test]
    fn test_parse_repeated_query_param_none() {
        assert!(parse_repeated_query_param(None, "cids").is_empty());
    }

    #[test]
    fn test_parse_repeated_query_param_skips_empty_values() {
        let query = "cids=&cids=a&cids=,,b";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b"]);
    }

    #[test]
    fn test_parse_repeated_query_param_requires_exact_key() {
        // Similar-looking keys and a bare key without `=` must not match.
        let query = "cidsx=a&xcids=b&cids";
        assert!(parse_repeated_query_param(Some(query), "cids").is_empty());
    }

    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));

        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 5);
        assert_eq!(parts[1].len(), 5);

        for c in code.chars() {
            if c != '-' {
                assert!(BASE32_ALPHABET.contains(c));
            }
        }
    }

    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);
        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 3);

        for part in parts {
            assert_eq!(part.len(), 4);
        }
    }

    #[test]
    fn test_generate_token_code_parts_zero_parts() {
        assert_eq!(generate_token_code_parts(0, 5), "");
    }

    #[test]
    fn test_extract_client_ip_prefers_forwarded_for() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.7, 10.0.0.1"),
        );
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));
        assert_eq!(extract_client_ip(&headers), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_real_ip() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));
        assert_eq!(extract_client_ip(&headers), "198.51.100.2");
    }

    #[test]
    fn test_extract_client_ip_unknown_without_headers() {
        assert_eq!(extract_client_ip(&HeaderMap::new()), "unknown");
    }
}