A decentralized music tracking and discovery platform built on AT Protocol 🎵
at main 199 lines 5.8 kB view raw
1use anyhow::Error; 2use jsonwebtoken::DecodingKey; 3use jsonwebtoken::EncodingKey; 4use jsonwebtoken::Header; 5use jsonwebtoken::Validation; 6use serde::{Deserialize, Serialize}; 7use sqlx::{Pool, Postgres}; 8use std::collections::BTreeMap; 9use std::env; 10 11use crate::cache::Cache; 12use crate::repo; 13use crate::rocksky::ROCKSKY_API; 14use crate::signature::generate_signature; 15use crate::xata::user::User; 16 17#[derive(Debug, Serialize, Deserialize)] 18pub struct Claims { 19 pub exp: usize, 20 pub iat: usize, 21 pub did: String, 22} 23 24pub async fn authenticate_v1( 25 pool: &Pool<Postgres>, 26 api_key: &str, 27 timestamp: &str, 28 password_md5: &str, 29) -> Result<(), Error> { 30 match repo::user::get_user_by_apikey(pool, api_key).await? { 31 Some(user) => { 32 let shared_secret = user 33 .shared_secret 34 .ok_or_else(|| Error::msg("User does not have a shared secret"))?; 35 let hashed_password = md5::compute(format!("{}", shared_secret)); 36 let hashed_password = format!("{:x}", hashed_password); 37 let expected_password = format!("{}{}", hashed_password, timestamp); 38 let expected_password = md5::compute(expected_password); 39 let expected_password = format!("{:x}", expected_password); 40 if expected_password != password_md5 { 41 tracing::error!(expected = %expected_password, provided = %password_md5, "Invalid password"); 42 return Err(Error::msg("Invalid password")); 43 } 44 Ok(()) 45 } 46 None => Err(Error::msg("Invalid API key")), 47 } 48} 49 50pub async fn authenticate( 51 pool: &Pool<Postgres>, 52 api_key: &str, 53 api_sig: &str, 54 session_key: &str, 55 form: &BTreeMap<String, String>, 56) -> Result<(), Error> { 57 let claims = decode_token(session_key)?; 58 59 let user_apikey = repo::api_key::get_apikey(pool, api_key, &claims.did).await?; 60 61 if user_apikey.is_none() { 62 return Err(Error::msg("Invalid API key")); 63 } 64 65 let user_apikey = user_apikey.unwrap(); 66 67 let signature = generate_signature(form, &user_apikey.shared_secret); 68 69 
if signature != api_sig { 70 return Err(Error::msg("Invalid signature")); 71 } 72 73 Ok(()) 74} 75 76pub async fn extract_did( 77 pool: &Pool<Postgres>, 78 form: &BTreeMap<String, String>, 79) -> Result<String, Error> { 80 let apikey = form 81 .get("api_key") 82 .ok_or_else(|| Error::msg("Missing api_key"))?; 83 let user = repo::user::get_user_by_apikey(pool, apikey).await?; 84 let did = user 85 .ok_or_else(|| Error::msg("Corresponding user not found"))? 86 .did; 87 Ok(did) 88} 89 90pub fn generate_token(did: &str) -> Result<String, Error> { 91 if env::var("JWT_SECRET").is_err() { 92 return Err(Error::msg("JWT_SECRET is not set")); 93 } 94 95 let claims = Claims { 96 exp: chrono::Utc::now().timestamp() as usize + 3600, 97 iat: chrono::Utc::now().timestamp() as usize, 98 did: did.to_string(), 99 }; 100 101 jsonwebtoken::encode( 102 &Header::default(), 103 &claims, 104 &EncodingKey::from_secret(env::var("JWT_SECRET")?.as_ref()), 105 ) 106 .map_err(Into::into) 107} 108 109pub fn decode_token(token: &str) -> Result<Claims, Error> { 110 if env::var("JWT_SECRET").is_err() { 111 return Err(Error::msg("JWT_SECRET is not set")); 112 } 113 114 jsonwebtoken::decode::<Claims>( 115 token, 116 &DecodingKey::from_secret(env::var("JWT_SECRET")?.as_ref()), 117 &Validation::default(), 118 ) 119 .map(|data| data.claims) 120 .map_err(Into::into) 121} 122 123pub async fn generate_session_id( 124 pool: &Pool<Postgres>, 125 cache: &Cache, 126 api_key: &str, 127) -> Result<String, Error> { 128 match repo::user::get_user_by_apikey(pool, &api_key).await? 
{ 129 Some(user) => { 130 let mut bytes = [0u8; 16]; 131 rand::fill(&mut bytes[..]); 132 133 let session_id = hex::encode(bytes); 134 135 let user = 136 serde_json::to_string(&user).map_err(|_| Error::msg("Failed to serialize user"))?; 137 cache.set(&format!("lastfm:{}", session_id), &user)?; 138 Ok(session_id) 139 } 140 None => Err(Error::msg("Invalid API key")), 141 } 142} 143 144pub fn verify_session_id(cache: &Cache, session_id: &str) -> Result<String, Error> { 145 let user = cache.get(&format!("lastfm:{}", session_id))?; 146 if user.is_none() { 147 return Err(Error::msg("Session ID not found")); 148 } 149 let user: String = user.unwrap(); 150 let user: User = serde_json::from_str(&user) 151 .map_err(|e| Error::msg(format!("Failed to deserialize user: {}", e)))?; 152 Ok(user.xata_id) 153} 154 155pub async fn validate_bearer_token(pool: &Pool<Postgres>, token: &str) -> Result<(), Error> { 156 let user = repo::user::get_user_by_apikey(pool, token).await?; 157 if user.is_none() { 158 return Err(Error::msg("Invalid token")); 159 } 160 161 let user = user.unwrap(); 162 let jwt = generate_token(&user.did)?; 163 let client = reqwest::Client::new(); 164 165 let res = client 166 .get(&format!( 167 "{}/xrpc/app.rocksky.actor.getProfile", 168 ROCKSKY_API 169 )) 170 .bearer_auth(jwt) 171 .send() 172 .await? 173 .error_for_status()?; 174 175 let profile: serde_json::Value = res.json().await?; 176 if profile.as_object().map_or(true, |obj| obj.is_empty()) { 177 return Err(Error::msg( 178 "ATProto session expired, please logout and login in https://rocksky.app and try again", 179 )); 180 } 181 182 Ok(()) 183} 184 185#[cfg(test)] 186mod tests { 187 use dotenv::dotenv; 188 189 use super::*; 190 191 #[test] 192 fn test_generate_token() { 193 dotenv().ok(); 194 let token = generate_token("did:plc:7vdlgi2bflelz7mmuxoqjfcr").unwrap(); 195 let claims = decode_token(&token).unwrap(); 196 197 assert_eq!(claims.did, "did:plc:7vdlgi2bflelz7mmuxoqjfcr"); 198 } 199}