Python bindings to oxyroot. Makes reading .root files blazing fast 🚀
1use ::oxyroot::{Named, RootFile};
2use numpy::IntoPyArray;
3use pyo3::{exceptions::PyValueError, prelude::*, IntoPyObjectExt};
4use std::fs::File;
5use std::path::Path;
6use std::sync::Arc;
7
8use arrow::array::{
9 ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, StringArray, UInt32Array,
10 UInt64Array,
11};
12use arrow::datatypes::{DataType, Field, Schema};
13use arrow::record_batch::RecordBatch;
14use parquet::arrow::ArrowWriter;
15use parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
16use parquet::file::properties::WriterProperties;
17
18#[pyclass(name = "RootFile")]
19struct PyRootFile {
20 #[pyo3(get)]
21 path: String,
22}
23
24#[pyclass(name = "Tree")]
25struct PyTree {
26 #[pyo3(get)]
27 path: String,
28 #[pyo3(get)]
29 name: String,
30}
31
32#[pyclass(name = "Branch")]
33struct PyBranch {
34 #[pyo3(get)]
35 path: String,
36 #[pyo3(get)]
37 tree_name: String,
38 #[pyo3(get)]
39 name: String,
40}
41
42#[pymethods]
43impl PyRootFile {
44 #[new]
45 fn new(path: String) -> Self {
46 PyRootFile { path }
47 }
48
49 fn keys(&self) -> PyResult<Vec<String>> {
50 let file = RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
51 Ok(file
52 .keys()
53 .into_iter()
54 .map(|k| k.name().to_string())
55 .collect())
56 }
57
58 fn __getitem__(&self, name: &str) -> PyResult<PyTree> {
59 Ok(PyTree {
60 path: self.path.clone(),
61 name: name.to_string(),
62 })
63 }
64}
65
66#[pymethods]
67impl PyTree {
68 fn branches(&self) -> PyResult<Vec<String>> {
69 let mut file =
70 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
71 let tree = file
72 .get_tree(&self.name)
73 .map_err(|e| PyValueError::new_err(e.to_string()))?;
74 Ok(tree.branches().map(|b| b.name().to_string()).collect())
75 }
76
77 fn __getitem__(&self, name: &str) -> PyResult<PyBranch> {
78 Ok(PyBranch {
79 path: self.path.clone(),
80 tree_name: self.name.clone(),
81 name: name.to_string(),
82 })
83 }
84
85 fn __iter__(slf: PyRef<Self>) -> PyResult<Py<PyBranchIterator>> {
86 let branches = slf.branches()?;
87 Py::new(
88 slf.py(),
89 PyBranchIterator {
90 path: slf.path.clone(),
91 tree_name: slf.name.clone(),
92 branches: branches.into_iter(),
93 },
94 )
95 }
96
97 #[pyo3(signature = (output_file, overwrite = false, compression = "snappy", columns = None))]
98 fn to_parquet(
99 &self,
100 output_file: String,
101 overwrite: bool,
102 compression: &str,
103 columns: Option<Vec<String>>,
104 ) -> PyResult<()> {
105 if !overwrite && Path::new(&output_file).exists() {
106 return Err(PyValueError::new_err("File exists, use overwrite=True"));
107 }
108
109 let compression = match compression {
110 "snappy" => Compression::SNAPPY,
111 "uncompressed" => Compression::UNCOMPRESSED,
112 "gzip" => Compression::GZIP(GzipLevel::default()),
113 "lzo" => Compression::LZO,
114 "brotli" => Compression::BROTLI(BrotliLevel::default()),
115 "lz4" => Compression::LZ4,
116 "zstd" => Compression::ZSTD(ZstdLevel::default()),
117 _ => return Err(PyValueError::new_err("Invalid compression type")),
118 };
119
120 let mut file =
121 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
122 let tree = file
123 .get_tree(&self.name)
124 .map_err(|e| PyValueError::new_err(e.to_string()))?;
125
126 let mut fields = Vec::new();
127 let mut arrays = Vec::new();
128
129 let branches_to_save = if let Some(columns) = columns {
130 columns
131 } else {
132 tree.branches().map(|b| b.name().to_string()).collect()
133 };
134
135 for branch_name in branches_to_save {
136 let branch = match tree.branch(&branch_name) {
137 Some(branch) => branch,
138 None => {
139 println!("Branch '{}' not found, skipping", branch_name);
140 continue;
141 }
142 };
143
144 let (field, array) = match branch.item_type_name().as_str() {
145 "float" => {
146 let data = branch.as_iter::<f32>().unwrap().collect::<Vec<_>>();
147 let array: ArrayRef = Arc::new(Float32Array::from(data));
148 (Field::new(&branch_name, DataType::Float32, false), array)
149 }
150 "double" => {
151 let data = branch.as_iter::<f64>().unwrap().collect::<Vec<_>>();
152 let array: ArrayRef = Arc::new(Float64Array::from(data));
153 (Field::new(&branch_name, DataType::Float64, false), array)
154 }
155 "int32_t" => {
156 let data = branch.as_iter::<i32>().unwrap().collect::<Vec<_>>();
157 let array: ArrayRef = Arc::new(Int32Array::from(data));
158 (Field::new(&branch_name, DataType::Int32, false), array)
159 }
160 "int64_t" => {
161 let data = branch.as_iter::<i64>().unwrap().collect::<Vec<_>>();
162 let array: ArrayRef = Arc::new(Int64Array::from(data));
163 (Field::new(&branch_name, DataType::Int64, false), array)
164 }
165 "uint32_t" => {
166 let data = branch.as_iter::<u32>().unwrap().collect::<Vec<_>>();
167 let array: ArrayRef = Arc::new(UInt32Array::from(data));
168 (Field::new(&branch_name, DataType::UInt32, false), array)
169 }
170 "uint64_t" => {
171 let data = branch.as_iter::<u64>().unwrap().collect::<Vec<_>>();
172 let array: ArrayRef = Arc::new(UInt64Array::from(data));
173 (Field::new(&branch_name, DataType::UInt64, false), array)
174 }
175 "string" => {
176 let data = branch.as_iter::<String>().unwrap().collect::<Vec<_>>();
177 let array: ArrayRef = Arc::new(StringArray::from(data));
178 (Field::new(&branch_name, DataType::Utf8, false), array)
179 }
180 other => {
181 println!("Unsupported branch type: {}, skipping", other);
182 continue;
183 }
184 };
185 fields.push(field);
186 arrays.push(array);
187 }
188
189 let schema = Arc::new(Schema::new(fields));
190 let props = WriterProperties::builder()
191 .set_compression(compression)
192 .build();
193 let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
194
195 let file = File::create(output_file)?;
196 let mut writer = ArrowWriter::try_new(file, schema, Some(props))
197 .map_err(|e| PyValueError::new_err(e.to_string()))?;
198 writer
199 .write(&batch)
200 .map_err(|e| PyValueError::new_err(e.to_string()))?;
201 writer
202 .close()
203 .map_err(|e| PyValueError::new_err(e.to_string()))?;
204
205 Ok(())
206 }
207}
208
209#[pyclass]
210struct PyBranchIterator {
211 path: String,
212 tree_name: String,
213 branches: std::vec::IntoIter<String>,
214}
215
216#[pymethods]
217impl PyBranchIterator {
218 fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
219 slf
220 }
221
222 fn __next__(&mut self) -> Option<PyBranch> {
223 self.branches.next().map(|name| PyBranch {
224 path: self.path.clone(),
225 tree_name: self.tree_name.clone(),
226 name,
227 })
228 }
229}
230
231#[pymethods]
232impl PyBranch {
233 fn array(&self, py: Python) -> PyResult<Py<PyAny>> {
234 let mut file =
235 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
236 let tree = file
237 .get_tree(&self.tree_name)
238 .map_err(|e| PyValueError::new_err(e.to_string()))?;
239 let branch = tree
240 .branch(&self.name)
241 .ok_or_else(|| PyValueError::new_err("Branch not found"))?;
242
243 match branch.item_type_name().as_str() {
244 "float" => {
245 let data = branch
246 .as_iter::<f32>()
247 .map_err(|e| PyValueError::new_err(e.to_string()))?
248 .collect::<Vec<_>>();
249 Ok(data.into_pyarray(py).into())
250 }
251 "double" => {
252 let data = branch
253 .as_iter::<f64>()
254 .map_err(|e| PyValueError::new_err(e.to_string()))?
255 .collect::<Vec<_>>();
256 Ok(data.into_pyarray(py).into())
257 }
258 "int32_t" => {
259 let data = branch
260 .as_iter::<i32>()
261 .map_err(|e| PyValueError::new_err(e.to_string()))?
262 .collect::<Vec<_>>();
263 Ok(data.into_pyarray(py).into())
264 }
265 "int64_t" => {
266 let data = branch
267 .as_iter::<i64>()
268 .map_err(|e| PyValueError::new_err(e.to_string()))?
269 .collect::<Vec<_>>();
270 Ok(data.into_pyarray(py).into())
271 }
272 "uint32_t" => {
273 let data = branch
274 .as_iter::<u32>()
275 .map_err(|e| PyValueError::new_err(e.to_string()))?
276 .collect::<Vec<_>>();
277 Ok(data.into_pyarray(py).into())
278 }
279 "uint64_t" => {
280 let data = branch
281 .as_iter::<u64>()
282 .map_err(|e| PyValueError::new_err(e.to_string()))?
283 .collect::<Vec<_>>();
284 Ok(data.into_pyarray(py).into())
285 }
286 "string" => {
287 let data = branch
288 .as_iter::<String>()
289 .map_err(|e| PyValueError::new_err(e.to_string()))?
290 .collect::<Vec<_>>();
291 Ok(data.into_py_any(py).unwrap())
292 }
293 other => Err(PyValueError::new_err(format!(
294 "Unsupported branch type: {}",
295 other
296 ))),
297 }
298 }
299
300 #[getter]
301 fn typename(&self) -> PyResult<String> {
302 let mut file =
303 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
304 let tree = file
305 .get_tree(&self.tree_name)
306 .map_err(|e| PyValueError::new_err(e.to_string()))?;
307 let branch = tree
308 .branch(&self.name)
309 .ok_or_else(|| PyValueError::new_err("Branch not found"))?;
310 Ok(branch.item_type_name())
311 }
312}
313
314#[pyfunction]
315fn open(path: String) -> PyResult<PyRootFile> {
316 Ok(PyRootFile::new(path))
317}
318
319#[pyfunction]
320fn version() -> PyResult<String> {
321 Ok(env!("CARGO_PKG_VERSION").to_string())
322}
323
324/// A Python module to read root files, implemented in Rust.
325#[pymodule]
326fn oxyroot(m: &Bound<'_, PyModule>) -> PyResult<()> {
327 m.add_function(wrap_pyfunction!(version, m)?)?;
328 m.add_function(wrap_pyfunction!(open, m)?)?;
329 m.add_class::<PyRootFile>()?;
330 m.add_class::<PyTree>()?;
331 m.add_class::<PyBranch>()?;
332 m.add_class::<PyBranchIterator>()?;
333 Ok(())
334}