forked from
rocksky.app/rocksky
A decentralized music tracking and discovery platform built on AT Protocol 🎵
1use anyhow::Error;
2use jsonwebtoken::DecodingKey;
3use jsonwebtoken::EncodingKey;
4use jsonwebtoken::Header;
5use jsonwebtoken::Validation;
6use serde::{Deserialize, Serialize};
7use sqlx::{Pool, Postgres};
8use std::collections::BTreeMap;
9use std::env;
10
11use crate::cache::Cache;
12use crate::repo;
13use crate::rocksky::ROCKSKY_API;
14use crate::signature::generate_signature;
15use crate::xata::user::User;
16
17#[derive(Debug, Serialize, Deserialize)]
18pub struct Claims {
19 pub exp: usize,
20 pub iat: usize,
21 pub did: String,
22}
23
24pub async fn authenticate_v1(
25 pool: &Pool<Postgres>,
26 api_key: &str,
27 timestamp: &str,
28 password_md5: &str,
29) -> Result<(), Error> {
30 match repo::user::get_user_by_apikey(pool, api_key).await? {
31 Some(user) => {
32 let shared_secret = user
33 .shared_secret
34 .ok_or_else(|| Error::msg("User does not have a shared secret"))?;
35 let hashed_password = md5::compute(format!("{}", shared_secret));
36 let hashed_password = format!("{:x}", hashed_password);
37 let expected_password = format!("{}{}", hashed_password, timestamp);
38 let expected_password = md5::compute(expected_password);
39 let expected_password = format!("{:x}", expected_password);
40 if expected_password != password_md5 {
41 tracing::error!(expected = %expected_password, provided = %password_md5, "Invalid password");
42 return Err(Error::msg("Invalid password"));
43 }
44 Ok(())
45 }
46 None => Err(Error::msg("Invalid API key")),
47 }
48}
49
50pub async fn authenticate(
51 pool: &Pool<Postgres>,
52 api_key: &str,
53 api_sig: &str,
54 session_key: &str,
55 form: &BTreeMap<String, String>,
56) -> Result<(), Error> {
57 let claims = decode_token(session_key)?;
58
59 let user_apikey = repo::api_key::get_apikey(pool, api_key, &claims.did).await?;
60
61 if user_apikey.is_none() {
62 return Err(Error::msg("Invalid API key"));
63 }
64
65 let user_apikey = user_apikey.unwrap();
66
67 let signature = generate_signature(form, &user_apikey.shared_secret);
68
69 if signature != api_sig {
70 return Err(Error::msg("Invalid signature"));
71 }
72
73 Ok(())
74}
75
76pub async fn extract_did(
77 pool: &Pool<Postgres>,
78 form: &BTreeMap<String, String>,
79) -> Result<String, Error> {
80 let apikey = form
81 .get("api_key")
82 .ok_or_else(|| Error::msg("Missing api_key"))?;
83 let user = repo::user::get_user_by_apikey(pool, apikey).await?;
84 let did = user
85 .ok_or_else(|| Error::msg("Corresponding user not found"))?
86 .did;
87 Ok(did)
88}
89
90pub fn generate_token(did: &str) -> Result<String, Error> {
91 if env::var("JWT_SECRET").is_err() {
92 return Err(Error::msg("JWT_SECRET is not set"));
93 }
94
95 let claims = Claims {
96 exp: chrono::Utc::now().timestamp() as usize + 3600,
97 iat: chrono::Utc::now().timestamp() as usize,
98 did: did.to_string(),
99 };
100
101 jsonwebtoken::encode(
102 &Header::default(),
103 &claims,
104 &EncodingKey::from_secret(env::var("JWT_SECRET")?.as_ref()),
105 )
106 .map_err(Into::into)
107}
108
109pub fn decode_token(token: &str) -> Result<Claims, Error> {
110 if env::var("JWT_SECRET").is_err() {
111 return Err(Error::msg("JWT_SECRET is not set"));
112 }
113
114 jsonwebtoken::decode::<Claims>(
115 token,
116 &DecodingKey::from_secret(env::var("JWT_SECRET")?.as_ref()),
117 &Validation::default(),
118 )
119 .map(|data| data.claims)
120 .map_err(Into::into)
121}
122
123pub async fn generate_session_id(
124 pool: &Pool<Postgres>,
125 cache: &Cache,
126 api_key: &str,
127) -> Result<String, Error> {
128 match repo::user::get_user_by_apikey(pool, &api_key).await? {
129 Some(user) => {
130 let mut bytes = [0u8; 16];
131 rand::fill(&mut bytes[..]);
132
133 let session_id = hex::encode(bytes);
134
135 let user =
136 serde_json::to_string(&user).map_err(|_| Error::msg("Failed to serialize user"))?;
137 cache.set(&format!("lastfm:{}", session_id), &user)?;
138 Ok(session_id)
139 }
140 None => Err(Error::msg("Invalid API key")),
141 }
142}
143
144pub fn verify_session_id(cache: &Cache, session_id: &str) -> Result<String, Error> {
145 let user = cache.get(&format!("lastfm:{}", session_id))?;
146 if user.is_none() {
147 return Err(Error::msg("Session ID not found"));
148 }
149 let user: String = user.unwrap();
150 let user: User = serde_json::from_str(&user)
151 .map_err(|e| Error::msg(format!("Failed to deserialize user: {}", e)))?;
152 Ok(user.xata_id)
153}
154
155pub async fn validate_bearer_token(pool: &Pool<Postgres>, token: &str) -> Result<(), Error> {
156 let user = repo::user::get_user_by_apikey(pool, token).await?;
157 if user.is_none() {
158 return Err(Error::msg("Invalid token"));
159 }
160
161 let user = user.unwrap();
162 let jwt = generate_token(&user.did)?;
163 let client = reqwest::Client::new();
164
165 let res = client
166 .get(&format!(
167 "{}/xrpc/app.rocksky.actor.getProfile",
168 ROCKSKY_API
169 ))
170 .bearer_auth(jwt)
171 .send()
172 .await?
173 .error_for_status()?;
174
175 let profile: serde_json::Value = res.json().await?;
176 if profile.as_object().map_or(true, |obj| obj.is_empty()) {
177 return Err(Error::msg(
178 "ATProto session expired, please logout and login in https://rocksky.app and try again",
179 ));
180 }
181
182 Ok(())
183}
184
#[cfg(test)]
mod tests {
    use dotenv::dotenv;

    use super::*;

    /// DID round-tripped through `generate_token` / `decode_token`.
    const SAMPLE_DID: &str = "did:plc:7vdlgi2bflelz7mmuxoqjfcr";

    /// A freshly minted token decodes back to claims carrying the same DID.
    /// Requires `JWT_SECRET`, taken from the environment or a `.env` file.
    #[test]
    fn test_generate_token() {
        dotenv().ok();
        let decoded = generate_token(SAMPLE_DID)
            .and_then(|token| decode_token(&token))
            .unwrap();
        assert_eq!(decoded.did, SAMPLE_DID);
    }
}
199}