diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py index ee47b0e6b..dfc1f6535 100644 --- a/datafusion/tests/test_imports.py +++ b/datafusion/tests/test_imports.py @@ -35,6 +35,7 @@ Expr, Projection, TableScan, + Filter, Limit, Aggregate, Sort, @@ -58,7 +59,7 @@ def test_class_module_is_datafusion(): ]: assert klass.__module__ == "datafusion" - for klass in [Expr, Projection, TableScan, Aggregate, Sort, Limit]: + for klass in [Expr, Projection, TableScan, Aggregate, Sort, Limit, Filter]: assert klass.__module__ == "datafusion.expr" for klass in [DFField, DFSchema]: diff --git a/src/expr.rs b/src/expr.rs index 7ef9407cb..adb9e55a0 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -25,6 +25,7 @@ use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField}; use datafusion::scalar::ScalarValue; pub mod aggregate; +pub mod filter; pub mod limit; pub mod logical_node; pub mod projection; @@ -146,6 +147,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/expr/filter.rs b/src/expr/filter.rs new file mode 100644 index 000000000..b7b48b9d2 --- /dev/null +++ b/src/expr/filter.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::logical_plan::Filter; +use pyo3::prelude::*; +use std::fmt::{self, Display, Formatter}; + +use crate::common::df_schema::PyDFSchema; +use crate::expr::logical_node::LogicalNode; +use crate::expr::PyExpr; +use crate::sql::logical::PyLogicalPlan; + +#[pyclass(name = "Filter", module = "datafusion.expr", subclass)] +#[derive(Clone)] +pub struct PyFilter { + filter: Filter, +} + +impl From for PyFilter { + fn from(filter: Filter) -> PyFilter { + PyFilter { filter } + } +} + +impl From for Filter { + fn from(filter: PyFilter) -> Self { + filter.filter + } +} + +impl Display for PyFilter { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "Filter + \nPredicate: {:?} + \nInput: {:?}", + &self.filter.predicate, &self.filter.input + ) + } +} + +#[pymethods] +impl PyFilter { + /// Retrieves the predicate expression for this `Filter` + fn predicate(&self) -> PyExpr { + PyExpr::from(self.filter.predicate.clone()) + } + + /// Retrieves the input `LogicalPlan` to this `Filter` node + fn input(&self) -> PyLogicalPlan { + PyLogicalPlan::from((*self.filter.input).clone()) + } + + /// Resulting Schema for this `Filter` node instance + fn schema(&self) -> PyDFSchema { + self.filter.input.schema().as_ref().clone().into() + } + + fn __repr__(&self) -> String { + format!("Filter({})", self) + } +} + +impl LogicalNode for PyFilter { + fn input(&self) -> Vec { + vec![PyLogicalPlan::from((*self.filter.input).clone())] + } +}