this repo has no description
1use rand::Rng;
2use sqlx::PgPool;
3use uuid::Uuid;
4const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567";
5pub fn generate_token_code() -> String {
6 generate_token_code_parts(2, 5)
7}
8pub fn generate_token_code_parts(parts: usize, part_len: usize) -> String {
9 let mut rng = rand::thread_rng();
10 let chars: Vec<char> = BASE32_ALPHABET.chars().collect();
11 (0..parts)
12 .map(|_| {
13 (0..part_len)
14 .map(|_| chars[rng.gen_range(0..chars.len())])
15 .collect::<String>()
16 })
17 .collect::<Vec<_>>()
18 .join("-")
19}
20#[derive(Debug)]
21pub enum DbLookupError {
22 NotFound,
23 DatabaseError(sqlx::Error),
24}
25impl From<sqlx::Error> for DbLookupError {
26 fn from(e: sqlx::Error) -> Self {
27 DbLookupError::DatabaseError(e)
28 }
29}
30pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> {
31 sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
32 .fetch_optional(db)
33 .await?
34 .ok_or(DbLookupError::NotFound)
35}
36pub struct UserInfo {
37 pub id: Uuid,
38 pub did: String,
39 pub handle: String,
40}
41pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> {
42 sqlx::query_as!(
43 UserInfo,
44 "SELECT id, did, handle FROM users WHERE did = $1",
45 did
46 )
47 .fetch_optional(db)
48 .await?
49 .ok_or(DbLookupError::NotFound)
50}
51pub async fn get_user_by_identifier(db: &PgPool, identifier: &str) -> Result<UserInfo, DbLookupError> {
52 sqlx::query_as!(
53 UserInfo,
54 "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1",
55 identifier
56 )
57 .fetch_optional(db)
58 .await?
59 .ok_or(DbLookupError::NotFound)
60}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits `code` on `-` and checks that it has `expected_groups` groups,
    /// each exactly `expected_len` characters long.
    fn check_groups(code: &str, expected_groups: usize, expected_len: usize) {
        let groups: Vec<&str> = code.split('-').collect();
        assert_eq!(groups.len(), expected_groups);
        for group in &groups {
            assert_eq!(group.len(), expected_len);
        }
    }

    #[test]
    fn test_generate_token_code() {
        let code = generate_token_code();
        assert_eq!(code.len(), 11);
        assert!(code.contains('-'));
        check_groups(&code, 2, 5);
        assert!(code
            .chars()
            .filter(|&c| c != '-')
            .all(|c| BASE32_ALPHABET.contains(c)));
    }

    #[test]
    fn test_generate_token_code_parts() {
        check_groups(&generate_token_code_parts(3, 4), 3, 4);
    }
}