Our Personal Data Server from scratch! tranquil.farm
oauth atproto pds rust postgresql objectstorage fun

fix: serialize concurrent per-user repo writes with keyed write locks

+225 -20
+1
crates/tranquil-pds/src/api/repo/import.rs
··· 190 190 .ok() 191 191 .and_then(|s| s.parse().ok()) 192 192 .unwrap_or(DEFAULT_MAX_BLOCKS); 193 + let _write_lock = state.repo_write_locks.lock(user_id).await; 193 194 match apply_import(&state.repo_repo, user_id, root, blocks.clone(), max_blocks).await { 194 195 Ok(import_result) => { 195 196 info!(
+3
crates/tranquil-pds/src/api/repo/record/batch.rs
··· 326 326 .ok() 327 327 .flatten() 328 328 .ok_or_else(|| ApiError::InternalError(Some("User not found".into())))?; 329 + 330 + let _write_lock = state.repo_write_locks.lock(user_id).await; 331 + 329 332 let root_cid_str = state 330 333 .repo_repo 331 334 .get_repo_root_cid_by_user_id(user_id)
+8 -2
crates/tranquil-pds/src/api/repo/record/delete.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 2 + use crate::api::repo::record::utils::{ 3 + CommitParams, RecordOp, commit_and_log, get_current_root_cid, 4 + }; 3 5 use crate::api::repo::record::write::{CommitInfo, prepare_repo_write}; 4 6 use crate::auth::{Active, Auth, VerifyScope}; 5 7 use crate::cid_types::CommitCid; ··· 56 58 57 59 let did = repo_auth.did; 58 60 let user_id = repo_auth.user_id; 59 - let current_root_cid = repo_auth.current_root_cid; 60 61 let controller_did = repo_auth.controller_did; 62 + 63 + let _write_lock = state.repo_write_locks.lock(user_id).await; 64 + let current_root_cid = get_current_root_cid(&state, user_id).await?; 61 65 62 66 if let Some(swap_commit) = &input.swap_commit 63 67 && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) ··· 238 242 collection: &Nsid, 239 243 rkey: &Rkey, 240 244 ) -> Result<(), String> { 245 + let _write_lock = state.repo_write_locks.lock(user_id).await; 246 + 241 247 let root_cid_str = state 242 248 .repo_repo 243 249 .get_repo_root_cid_by_user_id(user_id)
+20
crates/tranquil-pds/src/api/repo/record/utils.rs
··· 1 + use crate::api::error::ApiError; 2 + use crate::cid_types::CommitCid; 1 3 use crate::state::AppState; 2 4 use crate::types::{Did, Handle, Nsid, Rkey}; 3 5 use bytes::Bytes; ··· 8 10 use k256::ecdsa::SigningKey; 9 11 use serde_json::{Value, json}; 10 12 use std::str::FromStr; 13 + use tracing::error; 11 14 use tranquil_db_traits::SequenceNumber; 12 15 use uuid::Uuid; 16 + 17 + pub async fn get_current_root_cid(state: &AppState, user_id: Uuid) -> Result<CommitCid, ApiError> { 18 + let root_cid_str = state 19 + .repo_repo 20 + .get_repo_root_cid_by_user_id(user_id) 21 + .await 22 + .map_err(|e| { 23 + error!("DB error fetching repo root: {}", e); 24 + ApiError::InternalError(None) 25 + })? 26 + .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?; 27 + CommitCid::from_str(&root_cid_str) 28 + .map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into()))) 29 + } 13 30 14 31 pub fn extract_blob_cids(record: &Value) -> Vec<String> { 15 32 let mut blobs = Vec::new(); ··· 328 345 .await 329 346 .map_err(|e| format!("DB error: {}", e))? 330 347 .ok_or_else(|| "User not found".to_string())?; 348 + 349 + let _write_lock = state.repo_write_locks.lock(user_id).await; 350 + 331 351 let root_cid_link = state 332 352 .repo_repo 333 353 .get_repo_root_cid_by_user_id(user_id)
+8 -18
crates/tranquil-pds/src/api/repo/record/write.rs
··· 3 3 use crate::api::error::ApiError; 4 4 use crate::api::repo::record::utils::{ 5 5 CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids, 6 + get_current_root_cid, 6 7 }; 7 8 use crate::auth::{ 8 9 Active, Auth, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated, ··· 31 32 pub struct RepoWriteAuth { 32 33 pub did: Did, 33 34 pub user_id: Uuid, 34 - pub current_root_cid: CommitCid, 35 35 pub is_oauth: bool, 36 36 pub scope: Option<String>, 37 37 pub controller_did: Option<Did>, ··· 62 62 ApiError::InternalError(None).into_response() 63 63 })? 64 64 .ok_or_else(|| ApiError::InternalError(Some("User not found".into())).into_response())?; 65 - let root_cid_str = state 66 - .repo_repo 67 - .get_repo_root_cid_by_user_id(user_id) 68 - .await 69 - .map_err(|e| { 70 - error!("DB error fetching repo root: {}", e); 71 - ApiError::InternalError(None).into_response() 72 - })? 73 - .ok_or_else(|| { 74 - ApiError::InternalError(Some("Repo root not found".into())).into_response() 75 - })?; 76 - let current_root_cid = CommitCid::from_str(&root_cid_str).map_err(|_| { 77 - ApiError::InternalError(Some("Invalid repo root CID".into())).into_response() 78 - })?; 65 + 79 66 Ok(RepoWriteAuth { 80 67 did: principal_did.into_did(), 81 68 user_id, 82 - current_root_cid, 83 69 is_oauth: user.is_oauth(), 84 70 scope: user.scope.clone(), 85 71 controller_did: scope_proof.controller_did().map(|c| c.into_did()), ··· 130 116 131 117 let did = repo_auth.did; 132 118 let user_id = repo_auth.user_id; 133 - let current_root_cid = repo_auth.current_root_cid; 134 119 let controller_did = repo_auth.controller_did; 120 + 121 + let _write_lock = state.repo_write_locks.lock(user_id).await; 122 + let current_root_cid = get_current_root_cid(&state, user_id).await?; 135 123 136 124 if let Some(swap_commit) = &input.swap_commit 137 125 && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) ··· 433 421 434 422 let did = repo_auth.did; 435 423 let 
user_id = repo_auth.user_id; 436 - let current_root_cid = repo_auth.current_root_cid; 437 424 let controller_did = repo_auth.controller_did; 425 + 426 + let _write_lock = state.repo_write_locks.lock(user_id).await; 427 + let current_root_cid = get_current_root_cid(&state, user_id).await?; 438 428 439 429 if let Some(swap_commit) = &input.swap_commit 440 430 && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid)
+1
crates/tranquil-pds/src/lib.rs
··· 16 16 pub mod plc; 17 17 pub mod rate_limit; 18 18 pub mod repo; 19 + pub mod repo_write_lock; 19 20 pub mod scheduled; 20 21 pub mod sso; 21 22 pub mod state;
+180
crates/tranquil-pds/src/repo_write_lock.rs
··· 1 + use std::collections::HashMap; 2 + use std::sync::Arc; 3 + use std::time::Duration; 4 + use tokio::sync::{Mutex, OwnedMutexGuard, RwLock}; 5 + use uuid::Uuid; 6 + 7 + const SWEEP_INTERVAL: Duration = Duration::from_secs(300); 8 + 9 + pub struct RepoWriteLocks { 10 + locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>, 11 + } 12 + 13 + impl Default for RepoWriteLocks { 14 + fn default() -> Self { 15 + Self::new() 16 + } 17 + } 18 + 19 + impl RepoWriteLocks { 20 + pub fn new() -> Self { 21 + let locks = Arc::new(RwLock::new(HashMap::new())); 22 + let sweep_locks = Arc::clone(&locks); 23 + tokio::spawn(async move { 24 + sweep_loop(sweep_locks).await; 25 + }); 26 + Self { locks } 27 + } 28 + 29 + pub async fn lock(&self, user_id: Uuid) -> OwnedMutexGuard<()> { 30 + let mutex = { 31 + let read_guard = self.locks.read().await; 32 + read_guard.get(&user_id).cloned() 33 + }; 34 + 35 + match mutex { 36 + Some(m) => m.lock_owned().await, 37 + None => { 38 + let mut write_guard = self.locks.write().await; 39 + let mutex = write_guard 40 + .entry(user_id) 41 + .or_insert_with(|| Arc::new(Mutex::new(()))) 42 + .clone(); 43 + drop(write_guard); 44 + mutex.lock_owned().await 45 + } 46 + } 47 + } 48 + } 49 + 50 + async fn sweep_loop(locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>) { 51 + tokio::time::sleep(SWEEP_INTERVAL).await; 52 + let mut write_guard = locks.write().await; 53 + let before = write_guard.len(); 54 + write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1); 55 + let evicted = before - write_guard.len(); 56 + if evicted > 0 { 57 + tracing::debug!( 58 + evicted, 59 + remaining = write_guard.len(), 60 + "repo write lock sweep" 61 + ); 62 + } 63 + drop(write_guard); 64 + Box::pin(sweep_loop(locks)).await; 65 + } 66 + 67 + #[cfg(test)] 68 + mod tests { 69 + use super::*; 70 + use std::sync::atomic::{AtomicU32, Ordering}; 71 + use std::time::Duration; 72 + 73 + #[tokio::test] 74 + async fn test_locks_serialize_same_user() { 75 + let locks = 
Arc::new(RepoWriteLocks::new()); 76 + let user_id = Uuid::new_v4(); 77 + let counter = Arc::new(AtomicU32::new(0)); 78 + let max_concurrent = Arc::new(AtomicU32::new(0)); 79 + 80 + let handles: Vec<_> = (0..10) 81 + .map(|_| { 82 + let locks = locks.clone(); 83 + let counter = counter.clone(); 84 + let max_concurrent = max_concurrent.clone(); 85 + 86 + tokio::spawn(async move { 87 + let _guard = locks.lock(user_id).await; 88 + let current = counter.fetch_add(1, Ordering::SeqCst) + 1; 89 + max_concurrent.fetch_max(current, Ordering::SeqCst); 90 + tokio::time::sleep(Duration::from_millis(1)).await; 91 + counter.fetch_sub(1, Ordering::SeqCst); 92 + }) 93 + }) 94 + .collect(); 95 + 96 + futures::future::join_all(handles).await; 97 + 98 + assert_eq!( 99 + max_concurrent.load(Ordering::SeqCst), 100 + 1, 101 + "Only one task should hold the lock at a time for same user" 102 + ); 103 + } 104 + 105 + #[tokio::test] 106 + async fn test_different_users_can_run_concurrently() { 107 + let locks = Arc::new(RepoWriteLocks::new()); 108 + let user1 = Uuid::new_v4(); 109 + let user2 = Uuid::new_v4(); 110 + let concurrent_count = Arc::new(AtomicU32::new(0)); 111 + let max_concurrent = Arc::new(AtomicU32::new(0)); 112 + 113 + let locks1 = locks.clone(); 114 + let count1 = concurrent_count.clone(); 115 + let max1 = max_concurrent.clone(); 116 + let handle1 = tokio::spawn(async move { 117 + let _guard = locks1.lock(user1).await; 118 + let current = count1.fetch_add(1, Ordering::SeqCst) + 1; 119 + max1.fetch_max(current, Ordering::SeqCst); 120 + tokio::time::sleep(Duration::from_millis(50)).await; 121 + count1.fetch_sub(1, Ordering::SeqCst); 122 + }); 123 + 124 + tokio::time::sleep(Duration::from_millis(10)).await; 125 + 126 + let locks2 = locks.clone(); 127 + let count2 = concurrent_count.clone(); 128 + let max2 = max_concurrent.clone(); 129 + let handle2 = tokio::spawn(async move { 130 + let _guard = locks2.lock(user2).await; 131 + let current = count2.fetch_add(1, Ordering::SeqCst) + 
1; 132 + max2.fetch_max(current, Ordering::SeqCst); 133 + tokio::time::sleep(Duration::from_millis(50)).await; 134 + count2.fetch_sub(1, Ordering::SeqCst); 135 + }); 136 + 137 + handle1.await.unwrap(); 138 + handle2.await.unwrap(); 139 + 140 + assert_eq!( 141 + max_concurrent.load(Ordering::SeqCst), 142 + 2, 143 + "Different users should be able to run concurrently" 144 + ); 145 + } 146 + 147 + #[tokio::test] 148 + async fn test_sweep_evicts_idle_entries() { 149 + let locks = Arc::new(RwLock::new(HashMap::new())); 150 + let user_id = Uuid::new_v4(); 151 + 152 + { 153 + let mut write_guard = locks.write().await; 154 + write_guard.insert(user_id, Arc::new(Mutex::new(()))); 155 + } 156 + 157 + assert_eq!(locks.read().await.len(), 1); 158 + 159 + let mut write_guard = locks.write().await; 160 + write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1); 161 + assert_eq!(write_guard.len(), 0, "Idle entry should be evicted"); 162 + } 163 + 164 + #[tokio::test] 165 + async fn test_sweep_preserves_active_entries() { 166 + let locks = Arc::new(RwLock::new(HashMap::new())); 167 + let user_id = Uuid::new_v4(); 168 + let active_mutex = Arc::new(Mutex::new(())); 169 + let _held_ref = active_mutex.clone(); 170 + 171 + { 172 + let mut write_guard = locks.write().await; 173 + write_guard.insert(user_id, active_mutex); 174 + } 175 + 176 + let mut write_guard = locks.write().await; 177 + write_guard.retain(|_, mutex| Arc::strong_count(mutex) > 1); 178 + assert_eq!(write_guard.len(), 1, "Active entry should be preserved"); 179 + } 180 + }
+4
crates/tranquil-pds/src/state.rs
··· 5 5 use crate::config::AuthConfig; 6 6 use crate::rate_limit::RateLimiters; 7 7 use crate::repo::PostgresBlockStore; 8 + use crate::repo_write_lock::RepoWriteLocks; 8 9 use crate::sso::{SsoConfig, SsoManager}; 9 10 use crate::storage::{BackupStorage, BlobStorage, create_backup_storage, create_blob_storage}; 10 11 use crate::sync::firehose::SequencedEvent; ··· 38 39 pub backup_storage: Option<Arc<dyn BackupStorage>>, 39 40 pub firehose_tx: broadcast::Sender<SequencedEvent>, 40 41 pub rate_limiters: Arc<RateLimiters>, 42 + pub repo_write_locks: Arc<RepoWriteLocks>, 41 43 pub circuit_breakers: Arc<CircuitBreakers>, 42 44 pub cache: Arc<dyn Cache>, 43 45 pub distributed_rate_limiter: Arc<dyn DistributedRateLimiter>, ··· 181 183 182 184 let (firehose_tx, _) = broadcast::channel(firehose_buffer_size); 183 185 let rate_limiters = Arc::new(RateLimiters::new()); 186 + let repo_write_locks = Arc::new(RepoWriteLocks::new()); 184 187 let circuit_breakers = Arc::new(CircuitBreakers::new()); 185 188 let (cache, distributed_rate_limiter) = create_cache().await; 186 189 let did_resolver = Arc::new(DidResolver::new()); ··· 209 212 backup_storage, 210 213 firehose_tx, 211 214 rate_limiters, 215 + repo_write_locks, 212 216 circuit_breakers, 213 217 cache, 214 218 distributed_rate_limiter,