This repository has no description.
1use axum::http::HeaderMap;
2use rand::Rng;
3use sqlx::PgPool;
4use std::sync::OnceLock;
5use uuid::Uuid;
6
/// Lowercase RFC 4648 base32 alphabet (`a`–`z`, `2`–`7`) that generated token codes draw from.
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
/// Blob size limit in bytes (10 GiB) used when `MAX_BLOB_SIZE` is unset or unparsable.
/// NOTE(review): this does not fit in a 32-bit `usize`; const evaluation fails there.
const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024usize.pow(3);
9
10static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new();
11
12pub fn get_max_blob_size() -> usize {
13 *MAX_BLOB_SIZE.get_or_init(|| {
14 std::env::var("MAX_BLOB_SIZE")
15 .ok()
16 .and_then(|s| s.parse().ok())
17 .unwrap_or(DEFAULT_MAX_BLOB_SIZE)
18 })
19}
20
21pub fn generate_token_code() -> String {
22 generate_token_code_parts(2, 5)
23}
24
25pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
26 let mut rng = rand::thread_rng();
27 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
28
29 (0..parts)
30 .map(|_| {
31 (0..part_len)
32 .map(|_| chars[rng.gen_range(0..chars.len())])
33 .collect::<String>()
34 })
35 .collect::<Vec<_>>()
36 .join("-")
37}
38
39#[derive(Debug)]
40pub enum DbLookupError {
41 NotFound,
42 DatabaseError(sqlx::Error),
43}
44
45impl From<sqlx::Error> for DbLookupError {
46 fn from(e: sqlx::Error) -> Self {
47 DbLookupError::DatabaseError(e)
48 }
49}
50
51pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
52 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
53 .fetch_optional(db)
54 .await?
55 .ok_or(DbLookupError::NotFound)
56}
57
58pub struct UserInfo {
59 pub id: Uuid,
60 pub did: String,
61 pub handle: String,
62}
63
64pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
65 sqlx::query_as!(
66 UserInfo,
67 "SELECT id, did, handle FROM users WHERE did = $1",
68 did
69 )
70 .fetch_optional(db)
71 .await?
72 .ok_or(DbLookupError::NotFound)
73}
74
75pub async fn get_user_by_identifier(
76 db: &PgPool,
77 identifier: &str,
78) -> Result<UserInfo, DbLookupError> {
79 sqlx::query_as!(
80 UserInfo,
81 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
82 identifier
83 )
84 .fetch_optional(db)
85 .await?
86 .ok_or(DbLookupError::NotFound)
87}
88
89pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
90 let row = sqlx::query!(
91 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#,
92 did
93 )
94 .fetch_optional(db)
95 .await?;
96 Ok(row.map(|r| r.migrated).unwrap_or(false))
97}
98
99pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
100 query
101 .map(|q| {
102 let mut values = Vec::new();
103 for pair in q.split('&') {
104 if let Some((k, v)) = pair.split_once('=')
105 && k == key
106 && let Ok(decoded) = urlencoding::decode(v)
107 {
108 let decoded = decoded.into_owned();
109 if decoded.contains(',') {
110 for part in decoded.split(',') {
111 let trimmed = part.trim();
112 if !trimmed.is_empty() {
113 values.push(trimmed.to_string());
114 }
115 }
116 } else if !decoded.is_empty() {
117 values.push(decoded);
118 }
119 }
120 }
121 values
122 })
123 .unwrap_or_default()
124}
125
126pub fn extract_client_ip(headers: &HeaderMap) -> String {
127 if let Some(forwarded) = headers.get("x-forwarded-for")
128 && let Ok(value) = forwarded.to_str()
129 && let Some(first_ip) = value.split(',').next()
130 {
131 return first_ip.trim().to_string();
132 }
133 if let Some(real_ip) = headers.get("x-real-ip")
134 && let Ok(value) = real_ip.to_str()
135 {
136 return value.trim().to_string();
137 }
138 "unknown".to_string()
139}
140
/// Hostname of this PDS, read from the `PDS_HOSTNAME` environment variable on every call.
///
/// Falls back to `"localhost"` when the variable is unset or not valid Unicode.
pub fn pds_hostname() -> String {
    match std::env::var("PDS_HOSTNAME") {
        Ok(hostname) => hostname,
        Err(_) => String::from("localhost"),
    }
}
144
145pub fn pds_public_url() -> String {
146 format!("https://{}", pds_hostname())
147}
148
149pub fn build_full_url(path: &str) -> String {
150 format!("{}{}", pds_public_url(), path)
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_repeated_query_param_repeated() {
        let query = "did=test&cids=a&cids=b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_comma_separated() {
        let query = "did=test&cids=a,b,c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_mixed() {
        let query = "did=test&cids=a,b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_single() {
        let query = "did=test&cids=a";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a"]);
    }

    #[test]
    fn test_parse_repeated_query_param_empty() {
        let query = "did=test";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_repeated_query_param_url_encoded() {
        let query = "did=test&cids=bafyreib%2Btest";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["bafyreib+test"]);
    }

    #[test]
    fn test_parse_repeated_query_param_none() {
        assert!(parse_repeated_query_param(None, "cids").is_empty());
    }

    #[test]
    fn test_parse_repeated_query_param_skips_empty_values() {
        // Empty values, empty comma segments and bare keys without `=` are all ignored.
        let query = "cids=&cids=a,,b&cids";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b"]);
    }

    #[test]
    fn test_parse_repeated_query_param_requires_exact_key() {
        let query = "cidsx=a&xcids=b&cids=c";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_trims_encoded_comma_list() {
        // `%2C%20` decodes to ", ", so the decoded value is "a, b".
        let query = "cids=a%2C%20b";
        let result = parse_repeated_query_param(Some(query), "cids");
        assert_eq!(result, vec!["a", "b"]);
    }

    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));

        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 5);
        assert_eq!(parts[1].len(), 5);

        for c in code.chars() {
            if c != '-' {
                assert!(BASE32_ALPHABET.contains(c));
            }
        }
    }

    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);
        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 3);

        for part in parts {
            assert_eq!(part.len(), 4);
            assert!(part.chars().all(|c| BASE32_ALPHABET.contains(c)));
        }
    }

    #[test]
    fn test_generate_token_code_parts_zero_parts() {
        assert_eq!(generate_token_code_parts(0, 5), "");
    }
}