This repository has no description.
1use axum::http::HeaderMap;
2use rand::Rng;
3use sqlx::PgPool;
4use uuid::Uuid;
5
6const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
7
8pub fn generate_token_code() -> String {
9 generate_token_code_parts(2, 5)
10}
11
12pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
13 let mut rng = rand::thread_rng();
14 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
15
16 (0..parts)
17 .map(|_| {
18 (0..part_len)
19 .map(|_| chars[rng.gen_range(0..chars.len())])
20 .collect::<String>()
21 })
22 .collect::<Vec<_>>()
23 .join("-")
24}
25
26#[derive(Debug)]
27pub enum DbLookupError {
28 NotFound,
29 DatabaseError(sqlx::Error),
30}
31
32impl From<sqlx::Error> for DbLookupError {
33 fn from(e: sqlx::Error) -> Self {
34 DbLookupError::DatabaseError(e)
35 }
36}
37
38pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
39 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
40 .fetch_optional(db)
41 .await?
42 .ok_or(DbLookupError::NotFound)
43}
44
45pub struct UserInfo {
46 pub id: Uuid,
47 pub did: String,
48 pub handle: String,
49}
50
51pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
52 sqlx::query_as!(
53 UserInfo,
54 "SELECT id, did, handle FROM users WHERE did = $1",
55 did
56 )
57 .fetch_optional(db)
58 .await?
59 .ok_or(DbLookupError::NotFound)
60}
61
62pub async fn get_user_by_identifier(
63 db: &PgPool,
64 identifier: &str,
65) -> Result<UserInfo, DbLookupError> {
66 sqlx::query_as!(
67 UserInfo,
68 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
69 identifier
70 )
71 .fetch_optional(db)
72 .await?
73 .ok_or(DbLookupError::NotFound)
74}
75
76pub fn extract_client_ip(headers: &HeaderMap) -> String {
77 if let Some(forwarded) = headers.get("x-forwarded-for")
78 && let Ok(value) = forwarded.to_str()
79 && let Some(first_ip) = value.split(',').next()
80 {
81 return first_ip.trim().to_string();
82 }
83 if let Some(real_ip) = headers.get("x-real-ip")
84 && let Ok(value) = real_ip.to_str()
85 {
86 return value.trim().to_string();
87 }
88 "unknown".to_string()
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));

        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 5);
        assert_eq!(parts[1].len(), 5);

        for c in code.chars() {
            if c != '-' {
                assert!(BASE32_ALPHABET.contains(c));
            }
        }
    }

    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);
        // 3 groups of 4 plus 2 separators.
        assert_eq!(code.len(), 14);

        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 3);

        for part in parts {
            assert_eq!(part.len(), 4);
            assert!(part.chars().all(|c| BASE32_ALPHABET.contains(c)));
        }
    }

    #[test]
    fn test_generate_token_code_parts_degenerate_shapes() {
        assert_eq!(generate_token_code_parts(0, 5), "");
        assert_eq!(generate_token_code_parts(1, 0), "");
        assert_eq!(generate_token_code_parts(3, 0), "--");

        let single = generate_token_code_parts(1, 8);
        assert_eq!(single.len(), 8);
        assert!(!single.contains('-'));
    }

    #[test]
    fn test_base32_alphabet_has_32_unique_ascii_symbols() {
        assert_eq!(BASE32_ALPHABET.len(), 32);
        assert!(BASE32_ALPHABET.is_ascii());
        let unique: std::collections::HashSet<char> = BASE32_ALPHABET.chars().collect();
        assert_eq!(unique.len(), 32);
    }

    #[test]
    fn test_extract_client_ip_prefers_first_forwarded_entry() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.7, 10.0.0.1"),
        );
        headers.insert("x-real-ip", HeaderValue::from_static("192.0.2.9"));
        assert_eq!(extract_client_ip(&headers), "203.0.113.7");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_real_ip() {
        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", HeaderValue::from_static("192.0.2.9"));
        assert_eq!(extract_client_ip(&headers), "192.0.2.9");
    }

    #[test]
    fn test_extract_client_ip_unknown_without_headers() {
        let headers = HeaderMap::new();
        assert_eq!(extract_client_ip(&headers), "unknown");
    }
}