this repo has no description
1use rand::Rng;
2use sqlx::PgPool;
3use uuid::Uuid;
4
/// Alphabet used for token codes: the RFC 4648 base32 alphabet in lowercase
/// (`a`-`z` followed by `2`-`7`), 32 characters in total.
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";

// Compile-time guard: token generation picks characters with a random index in
// `0..len`, so an accidental edit that empties or resizes the alphabet should
// fail the build instead of panicking at runtime or silently changing the
// code space.
const _: () = assert!(BASE32_ALPHABET.len() == 32);
6
7pub fn generate_token_code() -> String {
8 generate_token_code_parts(2, 5)
9}
10
11pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
12 let mut rng = rand::thread_rng();
13 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
14
15 (0..parts)
16 .map(|_| {
17 (0..part_len)
18 .map(|_| chars[rng.gen_range(0..chars.len())])
19 .collect::<String>()
20 })
21 .collect::<Vec<_>>()
22 .join("-")
23}
24
25#[derive(Debug)]
26pub enum DbLookupError {
27 NotFound,
28 DatabaseError(sqlx::Error),
29}
30
31impl From<sqlx::Error> for DbLookupError {
32 fn from(e: sqlx::Error) -> Self {
33 DbLookupError::DatabaseError(e)
34 }
35}
36
37pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
38 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
39 .fetch_optional(db)
40 .await?
41 .ok_or(DbLookupError::NotFound)
42}
43
44pub struct UserInfo {
45 pub id: Uuid,
46 pub did: String,
47 pub handle: String,
48}
49
50pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
51 sqlx::query_as!(
52 UserInfo,
53 "SELECT id, did, handle FROM users WHERE did = $1",
54 did
55 )
56 .fetch_optional(db)
57 .await?
58 .ok_or(DbLookupError::NotFound)
59}
60
61pub async fn get_user_by_identifier(
62 db: &PgPool,
63 identifier: &str,
64) -> Result<UserInfo, DbLookupError> {
65 sqlx::query_as!(
66 UserInfo,
67 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
68 identifier
69 )
70 .fetch_optional(db)
71 .await?
72 .ok_or(DbLookupError::NotFound)
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every non-separator character of `code` comes from
    /// `BASE32_ALPHABET`.
    fn assert_alphabet_only(code: &str) {
        for c in code.chars().filter(|&c| c != '-') {
            assert!(
                BASE32_ALPHABET.contains(c),
                "unexpected character {c:?} in {code:?}"
            );
        }
    }

    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));

        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 5);
        assert_eq!(parts[1].len(), 5);

        assert_alphabet_only(&code);
    }

    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);
        let parts: Vec<&str> = code.split('-').collect();
        assert_eq!(parts.len(), 3);

        for part in parts {
            assert_eq!(part.len(), 4);
        }

        assert_alphabet_only(&code);
    }

    #[test]
    fn test_generate_token_code_parts_edge_cases() {
        // No groups: nothing to generate or join.
        assert_eq!(generate_token_code_parts(0, 5), "");

        // A single group carries no separator.
        let single = generate_token_code_parts(1, 6);
        assert_eq!(single.len(), 6);
        assert!(!single.contains('-'));
        assert_alphabet_only(&single);

        // Zero-length groups still get separators between them.
        assert_eq!(generate_token_code_parts(3, 0), "--");
    }
}