this repo has no description
1use axum::http::HeaderMap;
2use cid::Cid;
3use ipld_core::ipld::Ipld;
4use rand::Rng;
5use serde_json::Value as JsonValue;
6use sqlx::PgPool;
7use std::collections::BTreeMap;
8use std::str::FromStr;
9use std::sync::OnceLock;
10use uuid::Uuid;
11
/// Lowercase RFC 4648 base32 alphabet (`a`-`z` followed by `2`-`7`) from which token-code
/// characters are drawn.
const BASE32_ALPHABET: &str = concat!("abcdefghijklmnopqrstuvwxyz", "234567");
/// Fallback blob size limit in bytes (10 GiB), used when `MAX_BLOB_SIZE` is unset or unparsable.
const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024;

/// Process-wide cache of the effective blob size limit, filled on first use.
static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new();

/// Returns the maximum accepted blob size in bytes.
///
/// The limit is read from the `MAX_BLOB_SIZE` environment variable on the first call and cached
/// for the rest of the process, so later changes to the variable have no effect. Surrounding
/// whitespace is ignored. A missing, non-Unicode, or non-numeric value falls back to
/// [`DEFAULT_MAX_BLOB_SIZE`].
pub fn get_max_blob_size() -> usize {
    *MAX_BLOB_SIZE.get_or_init(|| parse_max_blob_size(std::env::var("MAX_BLOB_SIZE").ok().as_deref()))
}

/// Parses a raw `MAX_BLOB_SIZE` setting, falling back to [`DEFAULT_MAX_BLOB_SIZE`].
///
/// Values are trimmed before parsing. Settings copied from env files often carry a trailing
/// newline or spaces; without trimming they would silently revert to the default.
fn parse_max_blob_size(raw: Option<&str>) -> usize {
    raw.and_then(|s| s.trim().parse().ok())
        .unwrap_or(DEFAULT_MAX_BLOB_SIZE)
}
25
26pub fn generate_token_code() -> String {
27 generate_token_code_parts(2, 5)
28}
29
30pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
31 let mut rng = rand::thread_rng();
32 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
33
34 (0..parts)
35 .map(|_| {
36 (0..part_len)
37 .map(|_| chars[rng.gen_range(0..chars.len())])
38 .collect::<String>()
39 })
40 .collect::<Vec<_>>()
41 .join("-")
42}
43
44#[derive(Debug)]
45pub enum DbLookupError {
46 NotFound,
47 DatabaseError(sqlx::Error),
48}
49
50impl From<sqlx::Error> for DbLookupError {
51 fn from(e: sqlx::Error) -> Self {
52 DbLookupError::DatabaseError(e)
53 }
54}
55
56pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
57 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
58 .fetch_optional(db)
59 .await?
60 .ok_or(DbLookupError::NotFound)
61}
62
/// Minimal identity row (`id`, `did`, `handle`) selected from `users` by
/// `get_user_by_did` and `get_user_by_identifier`.
pub struct UserInfo {
    /// Value of the `users.id` column.
    pub id: Uuid,
    /// Value of the `users.did` column.
    pub did: String,
    /// Value of the `users.handle` column.
    pub handle: String,
}
68
69pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
70 sqlx::query_as!(
71 UserInfo,
72 "SELECT id, did, handle FROM users WHERE did = $1",
73 did
74 )
75 .fetch_optional(db)
76 .await?
77 .ok_or(DbLookupError::NotFound)
78}
79
80pub async fn get_user_by_identifier(
81 db: &PgPool,
82 identifier: &str,
83) -> Result<UserInfo, DbLookupError> {
84 sqlx::query_as!(
85 UserInfo,
86 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
87 identifier
88 )
89 .fetch_optional(db)
90 .await?
91 .ok_or(DbLookupError::NotFound)
92}
93
94pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
95 let row = sqlx::query!(
96 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#,
97 did
98 )
99 .fetch_optional(db)
100 .await?;
101 Ok(row.map(|r| r.migrated).unwrap_or(false))
102}
103
104pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
105 query
106 .map(|q| {
107 let mut values = Vec::new();
108 for pair in q.split('&') {
109 if let Some((k, v)) = pair.split_once('=')
110 && k == key
111 && let Ok(decoded) = urlencoding::decode(v)
112 {
113 let decoded = decoded.into_owned();
114 if decoded.contains(',') {
115 for part in decoded.split(',') {
116 let trimmed = part.trim();
117 if !trimmed.is_empty() {
118 values.push(trimmed.to_string());
119 }
120 }
121 } else if !decoded.is_empty() {
122 values.push(decoded);
123 }
124 }
125 }
126 values
127 })
128 .unwrap_or_default()
129}
130
131pub fn extract_client_ip(headers: &HeaderMap) -> String {
132 if let Some(forwarded) = headers.get("x-forwarded-for")
133 && let Ok(value) = forwarded.to_str()
134 && let Some(first_ip) = value.split(',').next()
135 {
136 return first_ip.trim().to_string();
137 }
138 if let Some(real_ip) = headers.get("x-real-ip")
139 && let Ok(value) = real_ip.to_str()
140 {
141 return value.trim().to_string();
142 }
143 "unknown".to_string()
144}
145
/// Hostname of this PDS, read from the `PDS_HOSTNAME` environment variable.
///
/// Falls back to `"localhost"` when the variable is unset or not valid Unicode.
pub fn pds_hostname() -> String {
    match std::env::var("PDS_HOSTNAME") {
        Ok(host) => host,
        Err(_) => String::from("localhost"),
    }
}

/// Public base URL of this PDS: `https://` followed by [`pds_hostname`]. No slash is appended.
pub fn pds_public_url() -> String {
    let host = pds_hostname();
    format!("https://{host}")
}

/// Appends `path` to [`pds_public_url`] by plain concatenation; no separator is inserted.
///
/// Presumably callers pass a path beginning with `/`. Verify at call sites.
pub fn build_full_url(path: &str) -> String {
    let mut url = pds_public_url();
    url.push_str(path);
    url
}
157
158pub fn json_to_ipld(value: &JsonValue) -> Ipld {
159 match value {
160 JsonValue::Null => Ipld::Null,
161 JsonValue::Bool(b) => Ipld::Bool(*b),
162 JsonValue::Number(n) => {
163 if let Some(i) = n.as_i64() {
164 Ipld::Integer(i as i128)
165 } else if let Some(f) = n.as_f64() {
166 Ipld::Float(f)
167 } else {
168 Ipld::Null
169 }
170 }
171 JsonValue::String(s) => Ipld::String(s.clone()),
172 JsonValue::Array(arr) => Ipld::List(arr.iter().map(json_to_ipld).collect()),
173 JsonValue::Object(obj) => {
174 if let Some(JsonValue::String(link)) = obj.get("$link")
175 && obj.len() == 1
176 && let Ok(cid) = Cid::from_str(link)
177 {
178 return Ipld::Link(cid);
179 }
180 let map: BTreeMap<String, Ipld> = obj
181 .iter()
182 .map(|(k, v)| (k.clone(), json_to_ipld(v)))
183 .collect();
184 Ipld::Map(map)
185 }
186 }
187}
188
189#[cfg(test)]
190mod tests {
191 use super::*;
192
193 #[test]
194 fn test_parse_repeated_query_param_repeated() {
195 let query = "did=test&cids=a&cids=b&cids=c";
196 let result = parse_repeated_query_param(Some(query), "cids");
197 assert_eq!(result, vec!["a", "b", "c"]);
198 }
199
200 #[test]
201 fn test_parse_repeated_query_param_comma_separated() {
202 let query = "did=test&cids=a,b,c";
203 let result = parse_repeated_query_param(Some(query), "cids");
204 assert_eq!(result, vec!["a", "b", "c"]);
205 }
206
207 #[test]
208 fn test_parse_repeated_query_param_mixed() {
209 let query = "did=test&cids=a,b&cids=c";
210 let result = parse_repeated_query_param(Some(query), "cids");
211 assert_eq!(result, vec!["a", "b", "c"]);
212 }
213
214 #[test]
215 fn test_parse_repeated_query_param_single() {
216 let query = "did=test&cids=a";
217 let result = parse_repeated_query_param(Some(query), "cids");
218 assert_eq!(result, vec!["a"]);
219 }
220
221 #[test]
222 fn test_parse_repeated_query_param_empty() {
223 let query = "did=test";
224 let result = parse_repeated_query_param(Some(query), "cids");
225 assert!(result.is_empty());
226 }
227
228 #[test]
229 fn test_parse_repeated_query_param_url_encoded() {
230 let query = "did=test&cids=bafyreib%2Btest";
231 let result = parse_repeated_query_param(Some(query), "cids");
232 assert_eq!(result, vec!["bafyreib+test"]);
233 }
234
235 #[test]
236 fn test_generate_token_code() {
237 let code = generate_token_code();
238 assert_eq!(code.len(), 11);
239 assert!(code.contains('-'));
240
241 let parts: Vec<&str> = code.split('-').collect();
242 assert_eq!(parts.len(), 2);
243 assert_eq!(parts[0].len(), 5);
244 assert_eq!(parts[1].len(), 5);
245
246 for c in code.chars() {
247 if c != '-' {
248 assert!(BASE32_ALPHABET.contains(c));
249 }
250 }
251 }
252
253 #[test]
254 fn test_generate_token_code_parts() {
255 let code = generate_token_code_parts(3, 4);
256 let parts: Vec<&str> = code.split('-').collect();
257 assert_eq!(parts.len(), 3);
258
259 for part in parts {
260 assert_eq!(part.len(), 4);
261 }
262 }
263
264 #[test]
265 fn test_json_to_ipld_cid_link() {
266 let json = serde_json::json!({
267 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
268 });
269 let ipld = json_to_ipld(&json);
270 match ipld {
271 Ipld::Link(cid) => {
272 assert_eq!(
273 cid.to_string(),
274 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
275 );
276 }
277 _ => panic!("Expected Ipld::Link, got {:?}", ipld),
278 }
279 }
280
281 #[test]
282 fn test_json_to_ipld_blob_ref() {
283 let json = serde_json::json!({
284 "$type": "blob",
285 "ref": {
286 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
287 },
288 "mimeType": "image/jpeg",
289 "size": 12345
290 });
291 let ipld = json_to_ipld(&json);
292 match ipld {
293 Ipld::Map(map) => {
294 assert_eq!(map.get("$type"), Some(&Ipld::String("blob".to_string())));
295 assert_eq!(
296 map.get("mimeType"),
297 Some(&Ipld::String("image/jpeg".to_string()))
298 );
299 assert_eq!(map.get("size"), Some(&Ipld::Integer(12345)));
300 match map.get("ref") {
301 Some(Ipld::Link(cid)) => {
302 assert_eq!(
303 cid.to_string(),
304 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
305 );
306 }
307 _ => panic!("Expected Ipld::Link in ref field, got {:?}", map.get("ref")),
308 }
309 }
310 _ => panic!("Expected Ipld::Map, got {:?}", ipld),
311 }
312 }
313
314 #[test]
315 fn test_json_to_ipld_nested_blob_refs_serializes_correctly() {
316 let record = serde_json::json!({
317 "$type": "app.bsky.feed.post",
318 "text": "Hello world",
319 "embed": {
320 "$type": "app.bsky.embed.images",
321 "images": [
322 {
323 "alt": "Test image",
324 "image": {
325 "$type": "blob",
326 "ref": {
327 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
328 },
329 "mimeType": "image/jpeg",
330 "size": 12345
331 }
332 }
333 ]
334 }
335 });
336 let ipld = json_to_ipld(&record);
337 let cbor_bytes = serde_ipld_dagcbor::to_vec(&ipld).expect("CBOR serialization failed");
338 assert!(!cbor_bytes.is_empty());
339 let parsed: Ipld =
340 serde_ipld_dagcbor::from_slice(&cbor_bytes).expect("CBOR deserialization failed");
341 if let Ipld::Map(map) = &parsed
342 && let Some(Ipld::Map(embed)) = map.get("embed")
343 && let Some(Ipld::List(images)) = embed.get("images")
344 && let Some(Ipld::Map(img)) = images.first()
345 && let Some(Ipld::Map(blob)) = img.get("image")
346 && let Some(Ipld::Link(cid)) = blob.get("ref")
347 {
348 assert_eq!(
349 cid.to_string(),
350 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
351 );
352 return;
353 }
354 panic!("Failed to find CID link in parsed CBOR");
355 }
356}