Python bindings to oxyroot. Makes reading .root files blazing fast ๐Ÿš€

Add arguments for compression and columns

+43 -8
+43 -8
src/lib.rs
··· 12 12 use arrow::datatypes::{DataType, Field, Schema}; 13 13 use arrow::record_batch::RecordBatch; 14 14 use parquet::arrow::ArrowWriter; 15 + use parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; 16 + use parquet::file::properties::WriterProperties; 15 17 16 18 #[pyclass(name = "RootFile")] 17 19 struct PyRootFile { ··· 92 94 ) 93 95 } 94 96 95 - #[pyo3(signature = (output_file, overwrite = false))] 96 - fn to_parquet(&self, output_file: String, overwrite: bool) -> PyResult<()> { 97 + #[pyo3(signature = (output_file, overwrite = false, compression = "snappy", columns = None))] 98 + fn to_parquet( 99 + &self, 100 + output_file: String, 101 + overwrite: bool, 102 + compression: &str, 103 + columns: Option<Vec<String>>, 104 + ) -> PyResult<()> { 97 105 if !overwrite && Path::new(&output_file).exists() { 98 106 return Err(PyValueError::new_err("File exists, use overwrite=True")); 99 107 } 100 108 109 + let compression = match compression { 110 + "snappy" => Compression::SNAPPY, 111 + "uncompressed" => Compression::UNCOMPRESSED, 112 + "gzip" => Compression::GZIP(GzipLevel::default()), 113 + "lzo" => Compression::LZO, 114 + "brotli" => Compression::BROTLI(BrotliLevel::default()), 115 + "lz4" => Compression::LZ4, 116 + "zstd" => Compression::ZSTD(ZstdLevel::default()), 117 + _ => return Err(PyValueError::new_err("Invalid compression type")), 118 + }; 119 + 101 120 let mut file = 102 121 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 103 122 let tree = file ··· 105 124 .map_err(|e| PyValueError::new_err(e.to_string()))?; 106 125 107 126 let mut fields = Vec::new(); 108 - let mut columns = Vec::new(); 127 + let mut arrays = Vec::new(); 128 + 129 + let branches_to_save = if let Some(columns) = columns { 130 + columns 131 + } else { 132 + tree.branches().map(|b| b.name().to_string()).collect() 133 + }; 109 134 110 - for branch in tree.branches() { 111 - let branch_name = branch.name().to_string(); 135 + for branch_name in branches_to_save { 136 + let branch = match tree.branch(&branch_name) { 137 + Some(branch) => branch, 138 + None => { 139 + println!("Branch '{}' not found, skipping", branch_name); 140 + continue; 141 + } 142 + }; 143 + 112 144 let (field, array) = match branch.item_type_name().as_str() { 113 145 "float" => { 114 146 let data = branch.as_iter::<f32>().unwrap().collect::<Vec<_>>(); ··· 151 183 } 152 184 }; 153 185 fields.push(field); 154 - columns.push(array); 186 + arrays.push(array); 155 187 } 156 188 157 189 let schema = Arc::new(Schema::new(fields)); 158 - let batch = RecordBatch::try_new(schema.clone(), columns).unwrap(); 190 + let props = WriterProperties::builder() 191 + .set_compression(compression) 192 + .build(); 193 + let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap(); 159 194 160 195 let file = File::create(output_file)?; 161 - let mut writer = ArrowWriter::try_new(file, schema, None) 196 + let mut writer = ArrowWriter::try_new(file, schema, Some(props)) 162 197 .map_err(|e| PyValueError::new_err(e.to_string()))?; 163 198 writer 164 199 .write(&batch)