From 2b619e6c2fc113d749dcd525d5ea5d5f8b73e17a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 18 Feb 2023 09:40:31 -0700 Subject: [PATCH 1/4] Add Python wrapper for LogicalPlan::Filter --- datafusion/tests/test_imports.py | 3 +- src/expr.rs | 2 + src/expr/filter.rs | 103 +++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 src/expr/filter.rs diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py index e5d958537..808cb50f4 100644 --- a/datafusion/tests/test_imports.py +++ b/datafusion/tests/test_imports.py @@ -35,6 +35,7 @@ Expr, Projection, TableScan, + Filter, ) @@ -55,7 +56,7 @@ def test_class_module_is_datafusion(): ]: assert klass.__module__ == "datafusion" - for klass in [Expr, Projection, TableScan]: + for klass in [Expr, Projection, TableScan, Filter]: assert klass.__module__ == "datafusion.expr" for klass in [DFField, DFSchema]: diff --git a/src/expr.rs b/src/expr.rs index f3695febf..e888434cc 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -27,6 +27,7 @@ use datafusion::scalar::ScalarValue; pub mod logical_node; pub mod projection; pub mod table_scan; +pub mod filter; /// A PyExpr that can be used on a DataFrame #[pyclass(name = "Expr", module = "datafusion.expr", subclass)] @@ -143,5 +144,6 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/expr/filter.rs b/src/expr/filter.rs new file mode 100644 index 000000000..b2113245c --- /dev/null +++ b/src/expr/filter.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::DataFusionError; +use datafusion_expr::logical_plan::Filter; +use pyo3::prelude::*; +use std::fmt::{self, Display, Formatter}; + +use crate::common::df_schema::PyDFSchema; +use crate::errors::py_runtime_err; +use crate::expr::logical_node::LogicalNode; +use crate::expr::PyExpr; +use crate::sql::logical::PyLogicalPlan; + +#[pyclass(name = "Filter", module = "datafusion.expr", subclass)] +#[derive(Clone)] +pub struct PyFilter { + filter: Filter, +} + +impl From for PyFilter { + fn from(filter: Filter) -> PyFilter { + PyFilter { filter } + } +} + +impl TryFrom for Filter { + type Error = DataFusionError; + + fn try_from(py_proj: PyFilter) -> Result { + Filter::try_new( + py_proj.filter.predicate, + py_proj.filter.input.clone(), + ) + } +} + +impl Display for PyFilter { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "Filter + \nExpr(s): {:?} + \nPredicate: {:?}", + &self.filter.predicate, &self.filter.input + ) + } +} + +#[pymethods] +impl PyFilter { + /// Retrieves the predicate expression for this `Filter` + #[pyo3(name = "predicate")] + fn py_predicate(&self) -> PyResult { + Ok(PyExpr::from(self.filter.predicate.clone())) + } + + // Retrieves the input `LogicalPlan` to this `Filter` node + #[pyo3(name = "input")] + fn py_input(&self) -> PyResult { + // DataFusion make a loose guarantee that each Filter should have an input, however + // we check for that hear since we are performing explicit index retrieval + let inputs = LogicalNode::input(self); + if !inputs.is_empty() { + return Ok(inputs[0].clone()); + } + + Err(py_runtime_err(format!( + "Expected `input` field for Filter node: {}", + self + ))) + } + + // Resulting Schema for this `Filter` node instance + #[pyo3(name = "schema")] + fn py_schema(&self) -> PyResult { + Ok(self.filter.input.schema().as_ref().clone().into()) + } + + fn __repr__(&self) -> PyResult { + Ok(format!("Filter({})", self)) + } +} + +impl LogicalNode for PyFilter { + fn input(&self) -> Vec { + vec![PyLogicalPlan::from((*self.filter.input).clone())] + } +} From cffec8b6c2eb9c83f2e91b0a578b4cd0c6255eaa Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 18 Feb 2023 17:25:43 -0700 Subject: [PATCH 2/4] clippy --- src/expr.rs | 2 +- src/expr/filter.rs | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/expr.rs b/src/expr.rs index e888434cc..60eb8c291 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -24,10 +24,10 @@ use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField}; use datafusion::scalar::ScalarValue; +pub mod filter; pub mod logical_node; pub mod projection; pub mod table_scan; -pub mod filter; /// A PyExpr that can be used on a DataFrame #[pyclass(name = "Expr", module = "datafusion.expr", subclass)] diff --git a/src/expr/filter.rs b/src/expr/filter.rs index b2113245c..6baa31028 100644 --- a/src/expr/filter.rs +++ b/src/expr/filter.rs @@ -42,10 +42,7 @@ impl TryFrom for Filter { type Error = DataFusionError; fn try_from(py_proj: PyFilter) -> Result { - Filter::try_new( - py_proj.filter.predicate, - py_proj.filter.input.clone(), - ) + Filter::try_new(py_proj.filter.predicate, py_proj.filter.input.clone()) } } From 6520bbe5587aa3403b01ef3ba10750194db3ffec Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 18 Feb 2023 17:35:30 -0700 Subject: [PATCH 3/4] clippy --- src/expr/filter.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/expr/filter.rs b/src/expr/filter.rs index 6baa31028..e601aa962 100644 --- a/src/expr/filter.rs +++ b/src/expr/filter.rs @@ -42,7 +42,7 @@ impl TryFrom for Filter { type Error = DataFusionError; fn try_from(py_proj: PyFilter) -> Result { - Filter::try_new(py_proj.filter.predicate, py_proj.filter.input.clone()) + Filter::try_new(py_proj.filter.predicate, py_proj.filter.input) } } From 20836dd35c780752fd3c05e632872f56ba6ebb66 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 19 Feb 2023 15:48:18 -0700 Subject: [PATCH 4/4] Update src/expr/filter.rs Co-authored-by: Liang-Chi Hsieh --- src/expr/filter.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/expr/filter.rs b/src/expr/filter.rs index 9587ee0a8..7726c8640 100644 --- a/src/expr/filter.rs +++ b/src/expr/filter.rs @@ -47,8 +47,8 @@ impl Display for PyFilter { write!( f, "Filter - \nExpr(s): {:?} - \nPredicate: {:?}", + \nPredicate: {:?} + \Input: {:?}", &self.filter.predicate, &self.filter.input ) }