A better Rust ATProto crate

parallelism maybe?

Orual 5651258d df39d157

+198 -14
+198 -14
crates/jacquard-repo/src/mst/tree.rs
··· 10 10 use jacquard_common::types::recordkey::Rkey; 11 11 use jacquard_common::types::string::{Nsid, RecordKey}; 12 12 use jacquard_common::types::value::RawData; 13 + use n0_future::try_join_all; 13 14 use smol_str::SmolStr; 14 15 use std::fmt::{Display, Formatter}; 16 + use std::future::Future; 15 17 use std::pin::Pin; 16 18 use std::sync::Arc; 17 19 use tokio::sync::RwLock; ··· 307 309 308 310 for entry in &entries { 309 311 if let NodeEntry::Tree(mst) = entry { 310 - let is_outdated = *mst.outdated_pointer.read().await; 311 - if is_outdated { 312 - outdated_children.push(mst.clone()); 312 + if *mst.outdated_pointer.read().await { 313 + let child = mst.clone(); 314 + outdated_children.push(n0_future::task::spawn(async move { 315 + child.get_pointer().await 316 + })); 313 317 } 314 318 } 315 319 } 316 320 317 - // Recursively update outdated children 321 + // Recursively update outdated children concurrently 318 322 if !outdated_children.is_empty() { 319 - for child in &outdated_children { 320 - let _ = child.get_pointer().await?; 321 - } 323 + try_join_all(outdated_children) 324 + .await 325 + .map_err(|e| RepoError::invalid(format!("Task join error: {}", e)))?; 326 + 322 327 // Re-fetch entries with updated child CIDs 323 328 entries = self.get_entries().await?; 324 329 } ··· 869 874 /// 870 875 /// Recursively traverses the tree to collect all leaves. 871 876 /// Used for diff calculation and tree listing. 877 + /// 878 + /// Uses parallel traversal to collect leaves from independent subtrees concurrently. 872 879 pub fn leaves<'a>( 873 880 &'a self, 874 881 ) -> std::pin::Pin< ··· 876 883 dyn std::future::Future<Output = Result<Vec<(smol_str::SmolStr, IpldCid)>>> + Send + 'a, 877 884 >, 878 885 > { 886 + Box::pin(async move { collect_leaves_parallel(self.clone()).await }) 887 + } 888 + 889 + /// Get all leaf entries sequentially (for benchmarking) 890 + pub fn leaves_sequential<'a>( 891 + &'a self, 892 + ) -> std::pin::Pin< 893 + Box< 894 + dyn std::future::Future<Output = Result<Vec<(smol_str::SmolStr, IpldCid)>>> + Send + 'a, 895 + >, 896 + > { 879 897 Box::pin(async move { 880 898 let mut result = Vec::new(); 881 - self.collect_leaves(&mut result).await?; 899 + self.collect_leaves_sequential(&mut result).await?; 882 900 Ok(result) 883 901 }) 884 902 } 885 903 886 - /// Recursively collect all leaves into the result vector 887 - fn collect_leaves<'a>( 904 + /// Recursively collect all leaves into the result vector (sequential) 905 + fn collect_leaves_sequential<'a>( 888 906 &'a self, 889 907 result: &'a mut Vec<(smol_str::SmolStr, IpldCid)>, 890 908 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> { ··· 895 913 match entry { 896 914 NodeEntry::Tree(subtree) => { 897 915 // Recurse into subtree 898 - subtree.collect_leaves(result).await?; 916 + subtree.collect_leaves_sequential(result).await?; 899 917 } 900 918 NodeEntry::Leaf { key, value } => { 901 919 // Add leaf to result ··· 989 1007 /// that aren't already in storage. Skips nodes that are already persisted. 990 1008 /// 991 1009 /// Returns (root_cid, blocks) where blocks is a map of CID → bytes. 1010 + /// 1011 + /// Uses parallel traversal to collect blocks from independent subtrees concurrently. 992 1012 pub fn collect_blocks<'a>( 993 1013 &'a self, 994 1014 ) -> std::pin::Pin< ··· 999 1019 + 'a, 1000 1020 >, 1001 1021 > { 1022 + Box::pin(async move { collect_blocks_parallel(self.clone()).await }) 1023 + } 1024 + 1025 + /// Collect all blocks sequentially (for benchmarking) 1026 + pub fn collect_blocks_sequential<'a>( 1027 + &'a self, 1028 + ) -> std::pin::Pin< 1029 + Box< 1030 + dyn std::future::Future< 1031 + Output = Result<(IpldCid, std::collections::BTreeMap<IpldCid, bytes::Bytes>)>, 1032 + > + Send 1033 + + 'a, 1034 + >, 1035 + > { 1002 1036 Box::pin(async move { 1003 1037 use bytes::Bytes; 1004 1038 use std::collections::BTreeMap; ··· 1021 1055 // Recursively collect from subtrees 1022 1056 for entry in &entries { 1023 1057 if let NodeEntry::Tree(subtree) = entry { 1024 - let (_, subtree_blocks) = subtree.collect_blocks().await?; 1058 + let (_, subtree_blocks) = subtree.collect_blocks_sequential().await?; 1025 1059 blocks.extend(subtree_blocks); 1026 1060 } 1027 1061 } ··· 1048 1082 /// 1049 1083 /// Returns all CIDs for MST nodes (internal nodes), not leaves. 1050 1084 /// Used for diff calculation to determine which MST blocks are removed. 1085 + /// 1086 + /// Uses parallel traversal to collect CIDs from independent subtrees concurrently. 1051 1087 pub async fn collect_node_cids(&self) -> Result<Vec<IpldCid>> { 1088 + collect_node_cids_parallel(self.clone()).await 1089 + } 1090 + 1091 + /// Collect all MST node CIDs sequentially (for benchmarking) 1092 + pub async fn collect_node_cids_sequential(&self) -> Result<Vec<IpldCid>> { 1052 1093 let mut cids = Vec::new(); 1053 1094 let pointer = self.get_pointer().await?; 1054 1095 cids.push(pointer); ··· 1056 1097 let entries = self.get_entries().await?; 1057 1098 for entry in &entries { 1058 1099 if let NodeEntry::Tree(subtree) = entry { 1059 - let subtree_cids = subtree.collect_node_cids().await?; 1100 + let subtree_cids = subtree.collect_node_cids_sequential().await?; 1060 1101 cids.extend(subtree_cids); 1061 1102 } 1062 1103 } 1063 - 1064 1104 Ok(cids) 1065 1105 } 1066 1106 ··· 1189 1229 Ok(()) 1190 1230 }) 1191 1231 } 1232 + } 1233 + 1234 + /// Recursively collect MST node CIDs in parallel 1235 + /// 1236 + /// Spawns concurrent tasks for each subtree branch, then merges results. 1237 + fn collect_node_cids_parallel<S: BlockStore + Sync + Send + 'static>( 1238 + tree: Mst<S>, 1239 + ) -> Pin<Box<dyn Future<Output = Result<Vec<IpldCid>>> + Send>> { 1240 + Box::pin(async move { 1241 + let pointer = tree.get_pointer().await?; 1242 + let entries = tree.get_entries().await?; 1243 + 1244 + // Spawn tasks for each subtree 1245 + let tasks: Vec<_> = entries 1246 + .into_iter() 1247 + .filter_map(|entry| { 1248 + if let NodeEntry::Tree(subtree) = entry { 1249 + Some(n0_future::task::spawn(async move { 1250 + collect_node_cids_parallel(subtree).await 1251 + })) 1252 + } else { 1253 + None 1254 + } 1255 + }) 1256 + .collect(); 1257 + 1258 + // Await all tasks concurrently 1259 + let results = try_join_all(tasks) 1260 + .await 1261 + .map_err(|e| RepoError::invalid(format!("Task join error: {}", e)))?; 1262 + 1263 + // Flatten results 1264 + let mut cids = vec![pointer]; 1265 + for subtree_cids in results { 1266 + cids.extend(subtree_cids?); 1267 + } 1268 + 1269 + Ok(cids) 1270 + }) 1271 + } 1272 + 1273 + /// Recursively collect leaves in parallel 1274 + /// 1275 + /// Spawns concurrent tasks for each subtree branch, preserving lexicographic order. 1276 + fn collect_leaves_parallel<S: BlockStore + Sync + Send + 'static>( 1277 + tree: Mst<S>, 1278 + ) -> Pin<Box<dyn Future<Output = Result<Vec<(smol_str::SmolStr, IpldCid)>>> + Send>> { 1279 + Box::pin(async move { 1280 + let entries = tree.get_entries().await?; 1281 + let mut result = Vec::new(); 1282 + 1283 + // Collect tasks and immediate leaves in order 1284 + let mut tasks = Vec::new(); 1285 + let mut task_positions = Vec::new(); 1286 + 1287 + for (i, entry) in entries.into_iter().enumerate() { 1288 + match entry { 1289 + NodeEntry::Tree(subtree) => { 1290 + task_positions.push(i); 1291 + tasks.push(n0_future::task::spawn(async move { 1292 + collect_leaves_parallel(subtree).await 1293 + })); 1294 + } 1295 + NodeEntry::Leaf { key, value } => { 1296 + result.push((i, vec![(key, value)])); 1297 + } 1298 + } 1299 + } 1300 + 1301 + // Await all tasks concurrently 1302 + if !tasks.is_empty() { 1303 + let subtree_results = try_join_all(tasks) 1304 + .await 1305 + .map_err(|e| RepoError::invalid(format!("Task join error: {}", e)))?; 1306 + 1307 + for (pos, leaves) in task_positions.into_iter().zip(subtree_results) { 1308 + result.push((pos, leaves?)); 1309 + } 1310 + } 1311 + 1312 + // Sort by position and flatten 1313 + result.sort_by_key(|(pos, _)| *pos); 1314 + Ok(result.into_iter().flat_map(|(_, leaves)| leaves).collect()) 1315 + }) 1316 + } 1317 + 1318 + /// Recursively collect blocks in parallel 1319 + /// 1320 + /// Spawns concurrent tasks for each subtree branch, then merges results. 1321 + fn collect_blocks_parallel<S: BlockStore + Sync + Send + 'static>( 1322 + tree: Mst<S>, 1323 + ) -> Pin< 1324 + Box< 1325 + dyn Future<Output = Result<(IpldCid, std::collections::BTreeMap<IpldCid, bytes::Bytes>)>> 1326 + + Send, 1327 + >, 1328 + > { 1329 + Box::pin(async move { 1330 + use bytes::Bytes; 1331 + use std::collections::BTreeMap; 1332 + 1333 + let pointer = tree.get_pointer().await?; 1334 + let mut blocks = BTreeMap::new(); 1335 + 1336 + // Check if already in storage 1337 + if tree.storage.has(&pointer).await? { 1338 + return Ok((pointer, blocks)); 1339 + } 1340 + 1341 + // Serialize this node 1342 + let entries = tree.get_entries().await?; 1343 + let node_data = util::serialize_node_data(&entries).await?; 1344 + let cbor = 1345 + serde_ipld_dagcbor::to_vec(&node_data).map_err(|e| RepoError::serialization(e))?; 1346 + blocks.insert(pointer, Bytes::from(cbor)); 1347 + 1348 + // Spawn tasks for each subtree 1349 + let tasks: Vec<_> = entries 1350 + .into_iter() 1351 + .filter_map(|entry| { 1352 + if let NodeEntry::Tree(subtree) = entry { 1353 + Some(n0_future::task::spawn(async move { 1354 + collect_blocks_parallel(subtree).await 1355 + })) 1356 + } else { 1357 + None 1358 + } 1359 + }) 1360 + .collect(); 1361 + 1362 + // Await all tasks concurrently 1363 + if !tasks.is_empty() { 1364 + let results = try_join_all(tasks) 1365 + .await 1366 + .map_err(|e| RepoError::invalid(format!("Task join error: {}", e)))?; 1367 + 1368 + for subtree_result in results { 1369 + let (_, subtree_blocks) = subtree_result?; 1370 + blocks.extend(subtree_blocks); 1371 + } 1372 + } 1373 + 1374 + Ok((pointer, blocks)) 1375 + }) 1192 1376 } 1193 1377 1194 1378 impl<S: BlockStore> std::fmt::Debug for Mst<S> {