this repo has no description
1use axum::http::HeaderMap;
2use rand::Rng;
3use sqlx::PgPool;
4use std::sync::OnceLock;
5use uuid::Uuid;
6
/// Symbols that `generate_token_code_parts` draws token characters from:
/// the 26 lowercase ASCII letters followed by the digits `2` through `7`
/// (32 symbols in total).
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";

/// Blob-size limit in bytes (10 GiB) used by `get_max_blob_size` when the
/// `MAX_BLOB_SIZE` environment variable is unset or does not parse.
const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024_usize.pow(3);

/// Process-wide cache for the blob-size limit; filled on the first call to
/// `get_max_blob_size`.
static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new();
11
12pub fn get_max_blob_size() -> usize {
13 *MAX_BLOB_SIZE.get_or_init(|| {
14 std::env::var("MAX_BLOB_SIZE")
15 .ok()
16 .and_then(|s| s.parse().ok())
17 .unwrap_or(DEFAULT_MAX_BLOB_SIZE)
18 })
19}
20
21pub fn generate_token_code() -> String {
22 generate_token_code_parts(2, 5)
23}
24
25pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
26 let mut rng = rand::thread_rng();
27 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
28
29 (0..parts)
30 .map(|_| {
31 (0..part_len)
32 .map(|_| chars[rng.gen_range(0..chars.len())])
33 .collect::<String>()
34 })
35 .collect::<Vec<_>>()
36 .join("-")
37}
38
39#[derive(Debug)]
40pub enum DbLookupError {
41 NotFound,
42 DatabaseError(sqlx::Error),
43}
44
45impl From<sqlx::Error> for DbLookupError {
46 fn from(e: sqlx::Error) -> Self {
47 DbLookupError::DatabaseError(e)
48 }
49}
50
51pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
52 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
53 .fetch_optional(db)
54 .await?
55 .ok_or(DbLookupError::NotFound)
56}
57
58pub struct UserInfo {
59 pub id: Uuid,
60 pub did: String,
61 pub handle: String,
62}
63
64pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
65 sqlx::query_as!(
66 UserInfo,
67 "SELECT id, did, handle FROM users WHERE did = $1",
68 did
69 )
70 .fetch_optional(db)
71 .await?
72 .ok_or(DbLookupError::NotFound)
73}
74
75pub async fn get_user_by_identifier(
76 db: &PgPool,
77 identifier: &str,
78) -> Result<UserInfo, DbLookupError> {
79 sqlx::query_as!(
80 UserInfo,
81 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
82 identifier
83 )
84 .fetch_optional(db)
85 .await?
86 .ok_or(DbLookupError::NotFound)
87}
88
89pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
90 query
91 .map(|q| {
92 let mut values = Vec::new();
93 for pair in q.split('&') {
94 if let Some((k, v)) = pair.split_once('=')
95 && k == key
96 && let Ok(decoded) = urlencoding::decode(v)
97 {
98 let decoded = decoded.into_owned();
99 if decoded.contains(',') {
100 for part in decoded.split(',') {
101 let trimmed = part.trim();
102 if !trimmed.is_empty() {
103 values.push(trimmed.to_string());
104 }
105 }
106 } else if !decoded.is_empty() {
107 values.push(decoded);
108 }
109 }
110 }
111 values
112 })
113 .unwrap_or_default()
114}
115
116pub fn extract_client_ip(headers: &HeaderMap) -> String {
117 if let Some(forwarded) = headers.get("x-forwarded-for")
118 && let Ok(value) = forwarded.to_str()
119 && let Some(first_ip) = value.split(',').next()
120 {
121 return first_ip.trim().to_string();
122 }
123 if let Some(real_ip) = headers.get("x-real-ip")
124 && let Ok(value) = real_ip.to_str()
125 {
126 return value.trim().to_string();
127 }
128 "unknown".to_string()
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map from `(name, value)` pairs for the `extract_client_ip` tests.
    fn header_map(pairs: &[(&'static str, &'static str)]) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, axum::http::HeaderValue::from_static(value));
        }
        headers
    }

    #[test]
    fn test_parse_repeated_query_param_repeated() {
        let query = "did=test&cids=a&cids=b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_comma_separated() {
        let query = "did=test&cids=a,b,c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_mixed() {
        let query = "did=test&cids=a,b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_single() {
        let query = "did=test&cids=a";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a"]);
    }

    #[test]
    fn test_parse_repeated_query_param_empty() {
        let query = "did=test";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_repeated_query_param_url_encoded() {
        let query = "did=test&cids=bafyreib%2Btest";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["bafyreib+test"]);
    }

    #[test]
    fn test_parse_repeated_query_param_none() {
        assert!(parse_repeated_query_param(None, "cids").is_empty());
    }

    #[test]
    fn test_parse_repeated_query_param_skips_empty_segments() {
        let query = "cids=a,,b,&cids=";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b"]);
    }

    #[test]
    fn test_parse_repeated_query_param_ignores_other_keys() {
        // `cid`/`cidsx` are different keys and a bare `cids` has no value.
        let query = "cid=x&cidsx=y&cids";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert!(result.is_empty());
    }

    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));

        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 5);
        assert_eq!(parts[1].len(), 5);

        for c in code.chars() {
            if c != '-' {
                assert!(BASE32_ALPHABET.contains(c));
            }
        }
    }

    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);
        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 3);

        for part in parts {
            assert_eq!(part.len(), 4);
        }
    }

    #[test]
    fn test_extract_client_ip_prefers_first_forwarded_entry() {
        let headers = header_map(&[
            ("x-forwarded-for", "203.0.113.7, 10.0.0.1"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        assert_eq!(extract_client_ip(&headers), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_real_ip() {
        let headers = header_map(&[("x-real-ip", "198.51.100.2")]);
        assert_eq!(extract_client_ip(&headers), "198.51.100.2");
    }

    #[test]
    fn test_extract_client_ip_unknown_without_headers() {
        assert_eq!(extract_client_ip(&header_map(&[])), "unknown");
    }
}