this repo has no description
1use rand::Rng;
2use sqlx::PgPool;
3use uuid::Uuid;
4
/// Symbols that token codes are drawn from: the 26 lowercase ASCII letters
/// followed by the digits `2`–`7` (32 symbols in total).
const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
6
7pub fn generate_token_code() -> String {
8 generate_token_code_parts(2, 5)
9}
10
11pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
12 let mut rng = rand::thread_rng();
13 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
14
15 (0..parts)
16 .map(|_| {
17 (0..part_len)
18 .map(|_| chars[rng.gen_range(0..chars.len())])
19 .collect::<String>()
20 })
21 .collect::<Vec<_>>()
22 .join("-")
23}
24
25#[derive(Debug)]
26pub enum DbLookupError {
27 NotFound,
28 DatabaseError(sqlx::Error),
29}
30
31impl From<sqlx::Error> for DbLookupError {
32 fn from(e: sqlx::Error) -> Self {
33 DbLookupError::DatabaseError(e)
34 }
35}
36
37pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
38 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
39 .fetch_optional(db)
40 .await?
41 .ok_or(DbLookupError::NotFound)
42}
43
44pub struct UserInfo {
45 pub id: Uuid,
46 pub did: String,
47 pub handle: String,
48}
49
50pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
51 sqlx::query_as!(
52 UserInfo,
53 "SELECT id, did, handle FROM users WHERE did = $1",
54 did
55 )
56 .fetch_optional(db)
57 .await?
58 .ok_or(DbLookupError::NotFound)
59}
60
61pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> {
62 sqlx::query_as!(
63 UserInfo,
64 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
65 identifier
66 )
67 .fetch_optional(db)
68 .await?
69 .ok_or(DbLookupError::NotFound)
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// The default code is 11 characters: two hyphen-separated groups of
    /// five symbols, all taken from `BASE32_ALPHABET`.
    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));

        // Exactly two groups, each five characters long.
        let group_lens: Vec<usize> = code.split('-').map(str::len).collect();
        assert_eq!(group_lens, vec![5, 5]);

        // Everything except the separator must come from the alphabet.
        assert!(code
            .chars()
            .filter(|&ch| ch != '-')
            .all(|ch| BASE32_ALPHABET.contains(ch)));
    }

    /// A custom layout yields the requested number of groups, each with the
    /// requested length.
    #[test]
    fn test_generate_token_code_parts() {
        let code = generate_token_code_parts(3, 4);

        let group_lens: Vec<usize> = code.split('-').map(str::len).collect();
        assert_eq!(group_lens, vec![4, 4, 4]);
    }
}