Python bindings to oxyroot. Makes reading .root files blazing fast 🚀
at main 334 lines 11 kB view raw
1use ::oxyroot::{Named, RootFile}; 2use numpy::IntoPyArray; 3use pyo3::{exceptions::PyValueError, prelude::*, IntoPyObjectExt}; 4use std::fs::File; 5use std::path::Path; 6use std::sync::Arc; 7 8use arrow::array::{ 9 ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, StringArray, UInt32Array, 10 UInt64Array, 11}; 12use arrow::datatypes::{DataType, Field, Schema}; 13use arrow::record_batch::RecordBatch; 14use parquet::arrow::ArrowWriter; 15use parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; 16use parquet::file::properties::WriterProperties; 17 18#[pyclass(name = "RootFile")] 19struct PyRootFile { 20 #[pyo3(get)] 21 path: String, 22} 23 24#[pyclass(name = "Tree")] 25struct PyTree { 26 #[pyo3(get)] 27 path: String, 28 #[pyo3(get)] 29 name: String, 30} 31 32#[pyclass(name = "Branch")] 33struct PyBranch { 34 #[pyo3(get)] 35 path: String, 36 #[pyo3(get)] 37 tree_name: String, 38 #[pyo3(get)] 39 name: String, 40} 41 42#[pymethods] 43impl PyRootFile { 44 #[new] 45 fn new(path: String) -> Self { 46 PyRootFile { path } 47 } 48 49 fn keys(&self) -> PyResult<Vec<String>> { 50 let file = RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 51 Ok(file 52 .keys() 53 .into_iter() 54 .map(|k| k.name().to_string()) 55 .collect()) 56 } 57 58 fn __getitem__(&self, name: &str) -> PyResult<PyTree> { 59 Ok(PyTree { 60 path: self.path.clone(), 61 name: name.to_string(), 62 }) 63 } 64} 65 66#[pymethods] 67impl PyTree { 68 fn branches(&self) -> PyResult<Vec<String>> { 69 let mut file = 70 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 71 let tree = file 72 .get_tree(&self.name) 73 .map_err(|e| PyValueError::new_err(e.to_string()))?; 74 Ok(tree.branches().map(|b| b.name().to_string()).collect()) 75 } 76 77 fn __getitem__(&self, name: &str) -> PyResult<PyBranch> { 78 Ok(PyBranch { 79 path: self.path.clone(), 80 tree_name: self.name.clone(), 81 name: name.to_string(), 82 }) 83 } 84 85 fn __iter__(slf: 
PyRef<Self>) -> PyResult<Py<PyBranchIterator>> { 86 let branches = slf.branches()?; 87 Py::new( 88 slf.py(), 89 PyBranchIterator { 90 path: slf.path.clone(), 91 tree_name: slf.name.clone(), 92 branches: branches.into_iter(), 93 }, 94 ) 95 } 96 97 #[pyo3(signature = (output_file, overwrite = false, compression = "snappy", columns = None))] 98 fn to_parquet( 99 &self, 100 output_file: String, 101 overwrite: bool, 102 compression: &str, 103 columns: Option<Vec<String>>, 104 ) -> PyResult<()> { 105 if !overwrite && Path::new(&output_file).exists() { 106 return Err(PyValueError::new_err("File exists, use overwrite=True")); 107 } 108 109 let compression = match compression { 110 "snappy" => Compression::SNAPPY, 111 "uncompressed" => Compression::UNCOMPRESSED, 112 "gzip" => Compression::GZIP(GzipLevel::default()), 113 "lzo" => Compression::LZO, 114 "brotli" => Compression::BROTLI(BrotliLevel::default()), 115 "lz4" => Compression::LZ4, 116 "zstd" => Compression::ZSTD(ZstdLevel::default()), 117 _ => return Err(PyValueError::new_err("Invalid compression type")), 118 }; 119 120 let mut file = 121 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 122 let tree = file 123 .get_tree(&self.name) 124 .map_err(|e| PyValueError::new_err(e.to_string()))?; 125 126 let mut fields = Vec::new(); 127 let mut arrays = Vec::new(); 128 129 let branches_to_save = if let Some(columns) = columns { 130 columns 131 } else { 132 tree.branches().map(|b| b.name().to_string()).collect() 133 }; 134 135 for branch_name in branches_to_save { 136 let branch = match tree.branch(&branch_name) { 137 Some(branch) => branch, 138 None => { 139 println!("Branch '{}' not found, skipping", branch_name); 140 continue; 141 } 142 }; 143 144 let (field, array) = match branch.item_type_name().as_str() { 145 "float" => { 146 let data = branch.as_iter::<f32>().unwrap().collect::<Vec<_>>(); 147 let array: ArrayRef = Arc::new(Float32Array::from(data)); 148 (Field::new(&branch_name, 
DataType::Float32, false), array) 149 } 150 "double" => { 151 let data = branch.as_iter::<f64>().unwrap().collect::<Vec<_>>(); 152 let array: ArrayRef = Arc::new(Float64Array::from(data)); 153 (Field::new(&branch_name, DataType::Float64, false), array) 154 } 155 "int32_t" => { 156 let data = branch.as_iter::<i32>().unwrap().collect::<Vec<_>>(); 157 let array: ArrayRef = Arc::new(Int32Array::from(data)); 158 (Field::new(&branch_name, DataType::Int32, false), array) 159 } 160 "int64_t" => { 161 let data = branch.as_iter::<i64>().unwrap().collect::<Vec<_>>(); 162 let array: ArrayRef = Arc::new(Int64Array::from(data)); 163 (Field::new(&branch_name, DataType::Int64, false), array) 164 } 165 "uint32_t" => { 166 let data = branch.as_iter::<u32>().unwrap().collect::<Vec<_>>(); 167 let array: ArrayRef = Arc::new(UInt32Array::from(data)); 168 (Field::new(&branch_name, DataType::UInt32, false), array) 169 } 170 "uint64_t" => { 171 let data = branch.as_iter::<u64>().unwrap().collect::<Vec<_>>(); 172 let array: ArrayRef = Arc::new(UInt64Array::from(data)); 173 (Field::new(&branch_name, DataType::UInt64, false), array) 174 } 175 "string" => { 176 let data = branch.as_iter::<String>().unwrap().collect::<Vec<_>>(); 177 let array: ArrayRef = Arc::new(StringArray::from(data)); 178 (Field::new(&branch_name, DataType::Utf8, false), array) 179 } 180 other => { 181 println!("Unsupported branch type: {}, skipping", other); 182 continue; 183 } 184 }; 185 fields.push(field); 186 arrays.push(array); 187 } 188 189 let schema = Arc::new(Schema::new(fields)); 190 let props = WriterProperties::builder() 191 .set_compression(compression) 192 .build(); 193 let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap(); 194 195 let file = File::create(output_file)?; 196 let mut writer = ArrowWriter::try_new(file, schema, Some(props)) 197 .map_err(|e| PyValueError::new_err(e.to_string()))?; 198 writer 199 .write(&batch) 200 .map_err(|e| PyValueError::new_err(e.to_string()))?; 201 writer 202 
.close() 203 .map_err(|e| PyValueError::new_err(e.to_string()))?; 204 205 Ok(()) 206 } 207} 208 209#[pyclass] 210struct PyBranchIterator { 211 path: String, 212 tree_name: String, 213 branches: std::vec::IntoIter<String>, 214} 215 216#[pymethods] 217impl PyBranchIterator { 218 fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { 219 slf 220 } 221 222 fn __next__(&mut self) -> Option<PyBranch> { 223 self.branches.next().map(|name| PyBranch { 224 path: self.path.clone(), 225 tree_name: self.tree_name.clone(), 226 name, 227 }) 228 } 229} 230 231#[pymethods] 232impl PyBranch { 233 fn array(&self, py: Python) -> PyResult<Py<PyAny>> { 234 let mut file = 235 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 236 let tree = file 237 .get_tree(&self.tree_name) 238 .map_err(|e| PyValueError::new_err(e.to_string()))?; 239 let branch = tree 240 .branch(&self.name) 241 .ok_or_else(|| PyValueError::new_err("Branch not found"))?; 242 243 match branch.item_type_name().as_str() { 244 "float" => { 245 let data = branch 246 .as_iter::<f32>() 247 .map_err(|e| PyValueError::new_err(e.to_string()))? 248 .collect::<Vec<_>>(); 249 Ok(data.into_pyarray(py).into()) 250 } 251 "double" => { 252 let data = branch 253 .as_iter::<f64>() 254 .map_err(|e| PyValueError::new_err(e.to_string()))? 255 .collect::<Vec<_>>(); 256 Ok(data.into_pyarray(py).into()) 257 } 258 "int32_t" => { 259 let data = branch 260 .as_iter::<i32>() 261 .map_err(|e| PyValueError::new_err(e.to_string()))? 262 .collect::<Vec<_>>(); 263 Ok(data.into_pyarray(py).into()) 264 } 265 "int64_t" => { 266 let data = branch 267 .as_iter::<i64>() 268 .map_err(|e| PyValueError::new_err(e.to_string()))? 269 .collect::<Vec<_>>(); 270 Ok(data.into_pyarray(py).into()) 271 } 272 "uint32_t" => { 273 let data = branch 274 .as_iter::<u32>() 275 .map_err(|e| PyValueError::new_err(e.to_string()))? 
276 .collect::<Vec<_>>(); 277 Ok(data.into_pyarray(py).into()) 278 } 279 "uint64_t" => { 280 let data = branch 281 .as_iter::<u64>() 282 .map_err(|e| PyValueError::new_err(e.to_string()))? 283 .collect::<Vec<_>>(); 284 Ok(data.into_pyarray(py).into()) 285 } 286 "string" => { 287 let data = branch 288 .as_iter::<String>() 289 .map_err(|e| PyValueError::new_err(e.to_string()))? 290 .collect::<Vec<_>>(); 291 Ok(data.into_py_any(py).unwrap()) 292 } 293 other => Err(PyValueError::new_err(format!( 294 "Unsupported branch type: {}", 295 other 296 ))), 297 } 298 } 299 300 #[getter] 301 fn typename(&self) -> PyResult<String> { 302 let mut file = 303 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 304 let tree = file 305 .get_tree(&self.tree_name) 306 .map_err(|e| PyValueError::new_err(e.to_string()))?; 307 let branch = tree 308 .branch(&self.name) 309 .ok_or_else(|| PyValueError::new_err("Branch not found"))?; 310 Ok(branch.item_type_name()) 311 } 312} 313 314#[pyfunction] 315fn open(path: String) -> PyResult<PyRootFile> { 316 Ok(PyRootFile::new(path)) 317} 318 319#[pyfunction] 320fn version() -> PyResult<String> { 321 Ok(env!("CARGO_PKG_VERSION").to_string()) 322} 323 324/// A Python module to read root files, implemented in Rust. 325#[pymodule] 326fn oxyroot(m: &Bound<'_, PyModule>) -> PyResult<()> { 327 m.add_function(wrap_pyfunction!(version, m)?)?; 328 m.add_function(wrap_pyfunction!(open, m)?)?; 329 m.add_class::<PyRootFile>()?; 330 m.add_class::<PyTree>()?; 331 m.add_class::<PyBranch>()?; 332 m.add_class::<PyBranchIterator>()?; 333 Ok(()) 334}