this repo has no description
1use axum::http::HeaderMap;
2use cid::Cid;
3use ipld_core::ipld::Ipld;
4use rand::Rng;
5use serde_json::Value as JsonValue;
6use sqlx::PgPool;
7use std::collections::BTreeMap;
8use std::str::FromStr;
9use std::sync::OnceLock;
10use uuid::Uuid;
11
12use crate::types::{Did, Handle};
13
/// The 32 symbols of the base32 alphabet (`a`–`z` followed by `2`–`7`), in lowercase.
///
/// [`generate_token_code_parts`] draws every token character from this set.
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
/// Blob size limit in bytes (10 GiB) used when the `MAX_BLOB_SIZE` environment
/// variable is missing or does not parse as a `usize`.
const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024;

/// Cache for the resolved limit; written once by the first call to
/// [`get_max_blob_size`].
static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new();

/// Returns the maximum accepted blob size in bytes.
///
/// The value is read from the `MAX_BLOB_SIZE` environment variable the first
/// time this function is called and cached afterwards, so later changes to the
/// environment are not observed. An unset or unparsable variable yields
/// [`DEFAULT_MAX_BLOB_SIZE`].
pub fn get_max_blob_size() -> usize {
    let limit = MAX_BLOB_SIZE.get_or_init(|| match std::env::var("MAX_BLOB_SIZE") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_MAX_BLOB_SIZE),
        Err(_) => DEFAULT_MAX_BLOB_SIZE,
    });
    *limit
}
27
28pub fn generate_token_code() -> String {
29 generate_token_code_parts(2, 5)
30}
31
32pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
33 let mut rng = rand::thread_rng();
34 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
35
36 (0..parts)
37 .map(|_| {
38 (0..part_len)
39 .map(|_| chars[rng.gen_range(0..chars.len())])
40 .collect::<String>()
41 })
42 .collect::<Vec<_>>()
43 .join("-")
44}
45
46#[derive(Debug)]
47pub enum DbLookupError {
48 NotFound,
49 DatabaseError(sqlx::Error),
50}
51
52impl From<sqlx::Error> for DbLookupError {
53 fn from(e: sqlx::Error) -> Self {
54 DbLookupError::DatabaseError(e)
55 }
56}
57
58pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
59 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
60 .fetch_optional(db)
61 .await?
62 .ok_or(DbLookupError::NotFound)
63}
64
65pub struct UserInfo {
66 pub id: Uuid,
67 pub did: Did,
68 pub handle: Handle,
69}
70
71pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
72 sqlx::query_as!(
73 UserInfo,
74 "SELECT id, did, handle FROM users WHERE did = $1",
75 did
76 )
77 .fetch_optional(db)
78 .await?
79 .ok_or(DbLookupError::NotFound)
80}
81
82pub async fn get_user_by_identifier(
83 db: &PgPool,
84 identifier: &str,
85) -> Result<UserInfo, DbLookupError> {
86 sqlx::query_as!(
87 UserInfo,
88 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
89 identifier
90 )
91 .fetch_optional(db)
92 .await?
93 .ok_or(DbLookupError::NotFound)
94}
95
96pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
97 let row = sqlx::query!(
98 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#,
99 did
100 )
101 .fetch_optional(db)
102 .await?;
103 Ok(row.map(|r| r.migrated).unwrap_or(false))
104}
105
106pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
107 query
108 .map(|q| {
109 let mut values = Vec::new();
110 for pair in q.split('&') {
111 if let Some((k, v)) = pair.split_once('=')
112 && k == key
113 && let Ok(decoded) = urlencoding::decode(v)
114 {
115 let decoded = decoded.into_owned();
116 if decoded.contains(',') {
117 for part in decoded.split(',') {
118 let trimmed = part.trim();
119 if !trimmed.is_empty() {
120 values.push(trimmed.to_string());
121 }
122 }
123 } else if !decoded.is_empty() {
124 values.push(decoded);
125 }
126 }
127 }
128 values
129 })
130 .unwrap_or_default()
131}
132
133pub fn extract_client_ip(headers: &HeaderMap) -> String {
134 if let Some(forwarded) = headers.get("x-forwarded-for")
135 && let Ok(value) = forwarded.to_str()
136 && let Some(first_ip) = value.split(',').next()
137 {
138 return first_ip.trim().to_string();
139 }
140 if let Some(real_ip) = headers.get("x-real-ip")
141 && let Ok(value) = real_ip.to_str()
142 {
143 return value.trim().to_string();
144 }
145 "unknown".to_string()
146}
147
/// Returns the value of the `PDS_HOSTNAME` environment variable, or
/// `"localhost"` when it is unset or not valid Unicode.
pub fn pds_hostname() -> String {
    match std::env::var("PDS_HOSTNAME") {
        Ok(host) => host,
        Err(_) => "localhost".to_string(),
    }
}
151
152pub fn pds_public_url() -> String {
153 format!("https://{}", pds_hostname())
154}
155
156pub fn build_full_url(path: &str) -> String {
157 format!("{}{}", pds_public_url(), path)
158}
159
160pub fn json_to_ipld(value: &JsonValue) -> Ipld {
161 match value {
162 JsonValue::Null => Ipld::Null,
163 JsonValue::Bool(b) => Ipld::Bool(*b),
164 JsonValue::Number(n) => {
165 if let Some(i) = n.as_i64() {
166 Ipld::Integer(i as i128)
167 } else if let Some(f) = n.as_f64() {
168 Ipld::Float(f)
169 } else {
170 Ipld::Null
171 }
172 }
173 JsonValue::String(s) => Ipld::String(s.clone()),
174 JsonValue::Array(arr) => Ipld::List(arr.iter().map(json_to_ipld).collect()),
175 JsonValue::Object(obj) => {
176 if let Some(JsonValue::String(link)) = obj.get("$link")
177 && obj.len() == 1
178 && let Ok(cid) = Cid::from_str(link)
179 {
180 return Ipld::Link(cid);
181 }
182 let map: BTreeMap<String, Ipld> = obj
183 .iter()
184 .map(|(k, v)| (k.clone(), json_to_ipld(v)))
185 .collect();
186 Ipld::Map(map)
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// CID used by every `$link` fixture in this module.
    const SAMPLE_CID: &str = "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku";

    /// Returns every value supplied for the `cids` key in `query`.
    fn cids_in(query: &str) -> Vec<String> {
        parse_repeated_query_param(Some(query), "cids")
    }

    /// Looks up `key` in `node` when `node` is a map; `None` otherwise.
    fn field<'a>(node: &'a Ipld, key: &str) -> Option<&'a Ipld> {
        match node {
            Ipld::Map(entries) => entries.get(key),
            _ => None,
        }
    }

    #[test]
    fn test_parse_repeated_query_param_repeated() {
        assert_eq!(cids_in("did=test&cids=a&cids=b&cids=c"), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_comma_separated() {
        assert_eq!(cids_in("did=test&cids=a,b,c"), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_mixed() {
        assert_eq!(cids_in("did=test&cids=a,b&cids=c"), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_repeated_query_param_single() {
        assert_eq!(cids_in("did=test&cids=a"), vec!["a"]);
    }

    #[test]
    fn test_parse_repeated_query_param_empty() {
        assert!(cids_in("did=test").is_empty());
    }

    #[test]
    fn test_parse_repeated_query_param_url_encoded() {
        assert_eq!(cids_in("did=test&cids=bafyreib%2Btest"), vec!["bafyreib+test"]);
    }

    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));

        let segments: Vec<&str> = code.split('-').collect();
        assert_eq!(segments.len(), 2);
        assert!(segments.iter().all(|segment| segment.len() == 5));

        // Everything except the separator must come from the base32 alphabet.
        assert!(
            code.chars()
                .filter(|&c| c != '-')
                .all(|c| BASE32_ALPHABET.contains(c))
        );
    }

    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);
        let lengths: Vec<usize> = code.split('-').map(str::len).collect();
        assert_eq!(lengths, vec![4, 4, 4]);
    }

    #[test]
    fn test_json_to_ipld_cid_link() {
        let json = serde_json::json!({ "$link": SAMPLE_CID });
        let ipld = json_to_ipld(&json);

        let Ipld::Link(cid) = &ipld else {
            panic!("Expected Ipld::Link, got {:?}", ipld);
        };
        assert_eq!(cid.to_string(), SAMPLE_CID);
    }

    #[test]
    fn test_json_to_ipld_blob_ref() {
        let json = serde_json::json!({
            "$type": "blob",
            "ref": { "$link": SAMPLE_CID },
            "mimeType": "image/jpeg",
            "size": 12345
        });
        let ipld = json_to_ipld(&json);

        let Ipld::Map(fields) = &ipld else {
            panic!("Expected Ipld::Map, got {:?}", ipld);
        };
        assert_eq!(fields.get("$type"), Some(&Ipld::String("blob".to_string())));
        assert_eq!(
            fields.get("mimeType"),
            Some(&Ipld::String("image/jpeg".to_string()))
        );
        assert_eq!(fields.get("size"), Some(&Ipld::Integer(12345)));

        let Some(Ipld::Link(cid)) = fields.get("ref") else {
            panic!(
                "Expected Ipld::Link in ref field, got {:?}",
                fields.get("ref")
            );
        };
        assert_eq!(cid.to_string(), SAMPLE_CID);
    }

    #[test]
    fn test_json_to_ipld_nested_blob_refs_serializes_correctly() {
        let record = serde_json::json!({
            "$type": "app.bsky.feed.post",
            "text": "Hello world",
            "embed": {
                "$type": "app.bsky.embed.images",
                "images": [
                    {
                        "alt": "Test image",
                        "image": {
                            "$type": "blob",
                            "ref": { "$link": SAMPLE_CID },
                            "mimeType": "image/jpeg",
                            "size": 12345
                        }
                    }
                ]
            }
        });

        // Round-trip through DAG-CBOR and make sure the link survives as a link.
        let encoded = serde_ipld_dagcbor::to_vec(&json_to_ipld(&record))
            .expect("CBOR serialization failed");
        assert!(!encoded.is_empty());
        let decoded: Ipld =
            serde_ipld_dagcbor::from_slice(&encoded).expect("CBOR deserialization failed");

        let link = field(&decoded, "embed")
            .and_then(|embed| field(embed, "images"))
            .and_then(|images| match images {
                Ipld::List(items) => items.first(),
                _ => None,
            })
            .and_then(|entry| field(entry, "image"))
            .and_then(|blob| field(blob, "ref"));

        match link {
            Some(Ipld::Link(cid)) => assert_eq!(cid.to_string(), SAMPLE_CID),
            _ => panic!("Failed to find CID link in parsed CBOR"),
        }
    }
}
358}