This repository has no description.
1use axum::http::HeaderMap;
2use cid::Cid;
3use ipld_core::ipld::Ipld;
4use rand::Rng;
5use serde_json::Value as JsonValue;
6use sqlx::PgPool;
7use std::collections::BTreeMap;
8use std::str::FromStr;
9use std::sync::OnceLock;
10use uuid::Uuid;
11
12use crate::types::{Did, Handle};
13
/// Lowercase RFC 4648 base32 alphabet (`a`-`z`, `2`-`7`) that token codes are drawn from.
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";

/// Blob size limit used when `MAX_BLOB_SIZE` is unset or does not parse as a `usize`
/// (presumably bytes — confirm against the upload handlers).
///
/// NOTE(review): `10 * 1024 * 1024 * 1024` is 10 GiB; confirm this is intended
/// (10 MiB would be `10 * 1024 * 1024`). The product also exceeds `usize::MAX` on
/// 32-bit targets, where this constant fails to compile.
const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024;

/// Process-wide cache of the resolved limit; filled on the first call to
/// [`get_max_blob_size`].
static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new();

/// Returns the configured maximum blob size.
///
/// The `MAX_BLOB_SIZE` environment variable is read and parsed once, on the first
/// call; later calls return the cached value even if the environment changes. A
/// missing, non-Unicode, or unparsable value falls back to [`DEFAULT_MAX_BLOB_SIZE`].
pub fn get_max_blob_size() -> usize {
    let limit = MAX_BLOB_SIZE.get_or_init(|| match std::env::var("MAX_BLOB_SIZE") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_BLOB_SIZE),
        Err(_) => DEFAULT_MAX_BLOB_SIZE,
    });
    *limit
}
27
28pub fn generate_token_code() -> String {
29 generate_token_code_parts(2, 5)
30}
31
32pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
33 let mut rng = rand::thread_rng();
34 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
35
36 (0..parts)
37 .map(|_| {
38 (0..part_len)
39 .map(|_| chars[rng.gen_range(0..chars.len())])
40 .collect::<String>()
41 })
42 .collect::<Vec<_>>()
43 .join("-")
44}
45
46#[derive(Debug)]
47pub enum DbLookupError {
48 NotFound,
49 DatabaseError(sqlx::Error),
50}
51
52impl From<sqlx::Error> for DbLookupError {
53 fn from(e: sqlx::Error) -> Self {
54 DbLookupError::DatabaseError(e)
55 }
56}
57
58pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
59 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
60 .fetch_optional(db)
61 .await?
62 .ok_or(DbLookupError::NotFound)
63}
64
65pub struct UserInfo {
66 pub id: Uuid,
67 pub did: Did,
68 pub handle: Handle,
69}
70
71pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
72 sqlx::query_as!(
73 UserInfo,
74 "SELECT id, did, handle FROM users WHERE did = $1",
75 did
76 )
77 .fetch_optional(db)
78 .await?
79 .ok_or(DbLookupError::NotFound)
80}
81
82pub async fn get_user_by_identifier(
83 db: &PgPool,
84 identifier: &str,
85) -> Result<UserInfo, DbLookupError> {
86 sqlx::query_as!(
87 UserInfo,
88 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
89 identifier
90 )
91 .fetch_optional(db)
92 .await?
93 .ok_or(DbLookupError::NotFound)
94}
95
96pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> {
97 let row = sqlx::query!(
98 r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#,
99 did
100 )
101 .fetch_optional(db)
102 .await?;
103 Ok(row.map(|r| r.migrated).unwrap_or(false))
104}
105
106pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
107 query
108 .map(|q| {
109 q.split('&')
110 .filter_map(|pair| {
111 pair.split_once('=')
112 .filter(|(k, _)| *k == key)
113 .and_then(|(_, v)| urlencoding::decode(v).ok())
114 .map(|decoded| decoded.into_owned())
115 })
116 .flat_map(|decoded| {
117 if decoded.contains(',') {
118 decoded
119 .split(',')
120 .filter_map(|part| {
121 let trimmed = part.trim();
122 (!trimmed.is_empty()).then(|| trimmed.to_string())
123 })
124 .collect::<Vec<_>>()
125 } else if decoded.is_empty() {
126 vec![]
127 } else {
128 vec![decoded]
129 }
130 })
131 .collect()
132 })
133 .unwrap_or_default()
134}
135
136pub fn extract_client_ip(headers: &HeaderMap) -> String {
137 if let Some(forwarded) = headers.get("x-forwarded-for")
138 && let Ok(value) = forwarded.to_str()
139 && let Some(first_ip) = value.split(',').next()
140 {
141 return first_ip.trim().to_string();
142 }
143 if let Some(real_ip) = headers.get("x-real-ip")
144 && let Ok(value) = real_ip.to_str()
145 {
146 return value.trim().to_string();
147 }
148 "unknown".to_string()
149}
150
/// Returns this PDS's public hostname from the `PDS_HOSTNAME` environment variable,
/// falling back to `"localhost"` when it is unset or not valid Unicode.
///
/// The variable is re-read on every call (no caching).
pub fn pds_hostname() -> String {
    std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
}

/// Returns this PDS's public base URL (always `https://`, no trailing slash).
pub fn pds_public_url() -> String {
    format!("https://{}", pds_hostname())
}

/// Builds an absolute public URL for `path` on this PDS.
///
/// - A missing leading `/` is added, so `"oauth/token"` becomes
///   `https://<host>/oauth/token` instead of being glued onto the hostname.
/// - Paths that look like bare XRPC method names (`/com.atproto.*`, `/app.bsky.*`,
///   `/_*`) are placed under `/xrpc/`.
/// - Every other path, including ones already under `/xrpc/`, is appended as-is.
/// - An empty `path` yields the bare base URL.
pub fn build_full_url(path: &str) -> String {
    const XRPC_METHOD_PREFIXES: [&str; 3] = ["/com.atproto.", "/app.bsky.", "/_"];

    let rooted = if path.is_empty() || path.starts_with('/') {
        path.to_string()
    } else {
        format!("/{path}")
    };
    let needs_xrpc = !rooted.starts_with("/xrpc/")
        && XRPC_METHOD_PREFIXES
            .iter()
            .any(|&prefix| rooted.starts_with(prefix));
    let xrpc = if needs_xrpc { "/xrpc" } else { "" };
    format!("{}{xrpc}{rooted}", pds_public_url())
}
171
172pub fn json_to_ipld(value: &JsonValue) -> Ipld {
173 match value {
174 JsonValue::Null => Ipld::Null,
175 JsonValue::Bool(b) => Ipld::Bool(*b),
176 JsonValue::Number(n) => {
177 if let Some(i) = n.as_i64() {
178 Ipld::Integer(i as i128)
179 } else if let Some(f) = n.as_f64() {
180 Ipld::Float(f)
181 } else {
182 Ipld::Null
183 }
184 }
185 JsonValue::String(s) => Ipld::String(s.clone()),
186 JsonValue::Array(arr) => Ipld::List(arr.iter().map(json_to_ipld).collect()),
187 JsonValue::Object(obj) => {
188 if let Some(JsonValue::String(link)) = obj.get("$link")
189 && obj.len() == 1
190 && let Ok(cid) = Cid::from_str(link)
191 {
192 return Ipld::Link(cid);
193 }
194 let map: BTreeMap<String, Ipld> = obj
195 .iter()
196 .map(|(k, v)| (k.clone(), json_to_ipld(v)))
197 .collect();
198 Ipld::Map(map)
199 }
200 }
201}
202
203#[cfg(test)]
204mod tests {
205 use super::*;
206
207 #[test]
208 fn test_parse_repeated_query_param_repeated() {
209 let query = "did=test&cids=a&cids=b&cids=c";
210 let result = parse_repeated_query_param(Some(query), "cids");
211 assert_eq!(result, vec!["a", "b", "c"]);
212 }
213
214 #[test]
215 fn test_parse_repeated_query_param_comma_separated() {
216 let query = "did=test&cids=a,b,c";
217 let result = parse_repeated_query_param(Some(query), "cids");
218 assert_eq!(result, vec!["a", "b", "c"]);
219 }
220
221 #[test]
222 fn test_parse_repeated_query_param_mixed() {
223 let query = "did=test&cids=a,b&cids=c";
224 let result = parse_repeated_query_param(Some(query), "cids");
225 assert_eq!(result, vec!["a", "b", "c"]);
226 }
227
228 #[test]
229 fn test_parse_repeated_query_param_single() {
230 let query = "did=test&cids=a";
231 let result = parse_repeated_query_param(Some(query), "cids");
232 assert_eq!(result, vec!["a"]);
233 }
234
235 #[test]
236 fn test_parse_repeated_query_param_empty() {
237 let query = "did=test";
238 let result = parse_repeated_query_param(Some(query), "cids");
239 assert!(result.is_empty());
240 }
241
242 #[test]
243 fn test_parse_repeated_query_param_url_encoded() {
244 let query = "did=test&cids=bafyreib%2Btest";
245 let result = parse_repeated_query_param(Some(query), "cids");
246 assert_eq!(result, vec!["bafyreib+test"]);
247 }
248
249 #[test]
250 fn test_generate_token_code() {
251 let code = generate_token_code();
252 assert_eq!(code.len(), 11);
253 assert!(code.contains('-'));
254
255 let parts: Vec<&str> = code.split('-').collect();
256 assert_eq!(parts.len(), 2);
257 assert_eq!(parts[0].len(), 5);
258 assert_eq!(parts[1].len(), 5);
259
260 for c in code.chars() {
261 if c != '-' {
262 assert!(BASE32_ALPHABET.contains(c));
263 }
264 }
265 }
266
267 #[test]
268 fn test_generate_token_code_parts() {
269 let code = generate_token_code_parts(3, 4);
270 let parts: Vec<&str> = code.split('-').collect();
271 assert_eq!(parts.len(), 3);
272
273 for part in parts {
274 assert_eq!(part.len(), 4);
275 }
276 }
277
278 #[test]
279 fn test_json_to_ipld_cid_link() {
280 let json = serde_json::json!({
281 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
282 });
283 let ipld = json_to_ipld(&json);
284 match ipld {
285 Ipld::Link(cid) => {
286 assert_eq!(
287 cid.to_string(),
288 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
289 );
290 }
291 _ => panic!("Expected Ipld::Link, got {:?}", ipld),
292 }
293 }
294
295 #[test]
296 fn test_json_to_ipld_blob_ref() {
297 let json = serde_json::json!({
298 "$type": "blob",
299 "ref": {
300 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
301 },
302 "mimeType": "image/jpeg",
303 "size": 12345
304 });
305 let ipld = json_to_ipld(&json);
306 match ipld {
307 Ipld::Map(map) => {
308 assert_eq!(map.get("$type"), Some(&Ipld::String("blob".to_string())));
309 assert_eq!(
310 map.get("mimeType"),
311 Some(&Ipld::String("image/jpeg".to_string()))
312 );
313 assert_eq!(map.get("size"), Some(&Ipld::Integer(12345)));
314 match map.get("ref") {
315 Some(Ipld::Link(cid)) => {
316 assert_eq!(
317 cid.to_string(),
318 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
319 );
320 }
321 _ => panic!("Expected Ipld::Link in ref field, got {:?}", map.get("ref")),
322 }
323 }
324 _ => panic!("Expected Ipld::Map, got {:?}", ipld),
325 }
326 }
327
328 #[test]
329 fn test_json_to_ipld_nested_blob_refs_serializes_correctly() {
330 let record = serde_json::json!({
331 "$type": "app.bsky.feed.post",
332 "text": "Hello world",
333 "embed": {
334 "$type": "app.bsky.embed.images",
335 "images": [
336 {
337 "alt": "Test image",
338 "image": {
339 "$type": "blob",
340 "ref": {
341 "$link": "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
342 },
343 "mimeType": "image/jpeg",
344 "size": 12345
345 }
346 }
347 ]
348 }
349 });
350 let ipld = json_to_ipld(&record);
351 let cbor_bytes = serde_ipld_dagcbor::to_vec(&ipld).expect("CBOR serialization failed");
352 assert!(!cbor_bytes.is_empty());
353 let parsed: Ipld =
354 serde_ipld_dagcbor::from_slice(&cbor_bytes).expect("CBOR deserialization failed");
355 if let Ipld::Map(map) = &parsed
356 && let Some(Ipld::Map(embed)) = map.get("embed")
357 && let Some(Ipld::List(images)) = embed.get("images")
358 && let Some(Ipld::Map(img)) = images.first()
359 && let Some(Ipld::Map(blob)) = img.get("image")
360 && let Some(Ipld::Link(cid)) = blob.get("ref")
361 {
362 assert_eq!(
363 cid.to_string(),
364 "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"
365 );
366 return;
367 }
368 panic!("Failed to find CID link in parsed CBOR");
369 }
370
371 #[test]
372 fn test_build_full_url_adds_xrpc_prefix_for_atproto_paths() {
373 unsafe { std::env::set_var("PDS_HOSTNAME", "example.com") };
374 assert_eq!(
375 build_full_url("/com.atproto.server.getSession"),
376 "https://example.com/xrpc/com.atproto.server.getSession"
377 );
378 assert_eq!(
379 build_full_url("/app.bsky.feed.getTimeline"),
380 "https://example.com/xrpc/app.bsky.feed.getTimeline"
381 );
382 assert_eq!(
383 build_full_url("/_health"),
384 "https://example.com/xrpc/_health"
385 );
386 assert_eq!(
387 build_full_url("/xrpc/com.atproto.server.getSession"),
388 "https://example.com/xrpc/com.atproto.server.getSession"
389 );
390 assert_eq!(
391 build_full_url("/oauth/token"),
392 "https://example.com/oauth/token"
393 );
394 }
395}