diff --git a/benchmarks/expected-plans/q6.txt b/benchmarks/expected-plans/q6.txt index ad27ba2b9aebe..55a6174abc4df 100644 --- a/benchmarks/expected-plans/q6.txt +++ b/benchmarks/expected-plans/q6.txt @@ -1,5 +1,5 @@ Projection: SUM(lineitem.l_extendedprice * lineitem.l_discount) AS revenue Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice * lineitem.l_discount)]] Projection: CAST(lineitem.l_discount AS Decimal128(30, 15)) AS CAST(lineitem.l_discount AS Decimal128(30, 15))lineitem.l_discount, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate - Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= Decimal128(Some(49999999999999),30,15) AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= Decimal128(Some(69999999999999),30,15) AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") AND CAST(lineitem.l_discount AS Decimal128(30, 15)) >= Decimal128(Some(49999999999999),30,15) AND CAST(lineitem.l_discount AS Decimal128(30, 15)) <= Decimal128(Some(69999999999999),30,15) AND lineitem.l_quantity < Decimal128(Some(2400),15,2) TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_shipdate] \ No newline at end of file diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 7ced782991e74..29ddffd6ce9a3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -617,6 +617,7 @@ dependencies = [ "ahash 0.8.0", "arrow", "datafusion-common", + "log", "sqlparser", ] diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 8bb1d95a48a6c..88be00cd92f93 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -28,8 +28,8 @@ use 
crate::datasource::source_as_provider; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ - Aggregate, Distinct, EmptyRelation, Filter, Join, Projection, Sort, SubqueryAlias, - TableScan, Window, + Aggregate, Distinct, EmptyRelation, Join, Projection, Sort, SubqueryAlias, TableScan, + Window, }; use crate::logical_plan::{ unalias, unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, @@ -756,15 +756,13 @@ impl DefaultPhysicalPlanner { input_exec, )?) ) } - LogicalPlan::Filter(Filter { - input, predicate, .. - }) => { - let physical_input = self.create_initial_plan(input, session_state).await?; + LogicalPlan::Filter(filter) => { + let physical_input = self.create_initial_plan(filter.input(), session_state).await?; let input_schema = physical_input.as_ref().schema(); - let input_dfschema = input.as_ref().schema(); + let input_dfschema = filter.input().schema(); let runtime_expr = self.create_physical_expr( - predicate, + filter.predicate(), input_dfschema, &input_schema, session_state, @@ -1696,8 +1694,7 @@ mod tests { use arrow::record_batch::RecordBatch; use datafusion_common::{DFField, DFSchema, DFSchemaRef}; use datafusion_expr::expr::GroupingSet; - use datafusion_expr::sum; - use datafusion_expr::{col, lit}; + use datafusion_expr::{col, lit, sum}; use fmt::Debug; use std::collections::HashMap; use std::convert::TryFrom; @@ -1705,7 +1702,10 @@ mod tests { fn make_session_state() -> SessionState { let runtime = Arc::new(RuntimeEnv::default()); - SessionState::with_config_rt(SessionConfig::new(), runtime) + let config = SessionConfig::new(); + // TODO we should really test that no optimizer rules are failing here + // let config = config.set_bool(crate::config::OPT_OPTIMIZER_SKIP_FAILED_RULES, false); + SessionState::with_config_rt(config, runtime) } async fn plan(logical_plan: &LogicalPlan) -> Result> { @@ -1972,6 +1972,11 @@ mod tests { let expected = "expr: [(InListExpr 
{ expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }], negated: false, set: None }"; assert!(format!("{:?}", execution_plan).contains(expected)); + Ok(()) + } + + #[tokio::test] + async fn in_list_types_struct_literal() -> Result<()> { // expression: "a in (struct::null, 'a')" let list = vec![struct_literal(), lit("a")]; diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 1ea530bfb64e6..e111b21ad4e50 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -1211,14 +1211,19 @@ async fn boolean_literal() -> Result<()> { #[tokio::test] async fn unprojected_filter() { - let ctx = SessionContext::new(); + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); let df = ctx.read_table(table_with_sequence(1, 3).unwrap()).unwrap(); let df = df - .select(vec![col("i") + col("i")]) - .unwrap() .filter(col("i").gt(lit(2))) + .unwrap() + .select(vec![col("i") + col("i")]) .unwrap(); + + let plan = df.to_logical_plan().unwrap(); + println!("{}", plan.display_indent()); + let results = df.collect().await.unwrap(); let expected = vec![ diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 3280628a42eb9..5d7350f720cd9 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,4 +38,5 @@ path = "src/lib.rs" ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } arrow = { version = "24.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } +log = "^0.4" sqlparser = "0.25" diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c131682a8ef68..7f75098869b58 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -486,6 +486,14 @@ impl Expr { Expr::Alias(Box::new(self), name.to_owned()) } + /// Remove an alias from an expression if one exists. 
+ pub fn unalias(self) -> Expr { + match self { + Expr::Alias(expr, _) => expr.as_ref().clone(), + _ => self, + } + } + /// Return `self IN ` if `negated` is false, otherwise /// return `self NOT IN `.a pub fn in_list(self, list: Vec, negated: bool) -> Expr { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 5226399840759..300c3b8cb6dcc 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -289,10 +289,10 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(&self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Ok(Self::from(LogicalPlan::Filter(Filter { - predicate: expr, - input: Arc::new(self.plan.clone()), - }))) + Ok(Self::from(LogicalPlan::Filter(Filter::try_new( + expr, + Arc::new(self.plan.clone()), + )?))) } /// Limit the number of rows returned diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index fd828021732e9..fce1c76c42c84 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. +///! Logical plan types use crate::logical_plan::builder::validate_unique_names; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::utils::{ exprlist_to_fields, grouping_set_expr_count, grouping_set_to_exprlist, }; -use crate::{Expr, TableProviderFilterPushDown, TableSource}; +use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{plan_err, Column, DFSchema, DFSchemaRef, DataFusionError}; use std::collections::HashSet; -///! 
Logical plan types use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -1148,18 +1148,59 @@ pub struct SubqueryAlias { #[derive(Clone)] pub struct Filter { /// The predicate expression, which must have Boolean type. - pub predicate: Expr, + predicate: Expr, /// The incoming logical plan - pub input: Arc, + input: Arc, } impl Filter { + /// Create a new filter operator. + pub fn try_new( + predicate: Expr, + input: Arc, + ) -> datafusion_common::Result { + // Filter predicates must return a boolean value so we try and validate that here. + // Note that it is not always possible to resolve the predicate expression during plan + // construction (such as with correlated subqueries) so we make a best effort here and + // ignore errors resolving the expression against the schema. + if let Ok(predicate_type) = predicate.get_type(input.schema()) { + if predicate_type != DataType::Boolean { + return Err(DataFusionError::Plan(format!( + "Cannot create filter with non-boolean predicate '{}' returning {}", + predicate, predicate_type + ))); + } + } + + // filter predicates should not be aliased + if let Expr::Alias(expr, alias) = predicate { + return Err(DataFusionError::Plan(format!( + "Attempted to create Filter predicate with \ + expression `{}` aliased as '{}'. Filter predicates should not be \ + aliased.", + expr, alias + ))); + } + + Ok(Self { predicate, input }) + } + pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Filter> { match plan { LogicalPlan::Filter(it) => Ok(it), _ => plan_err!("Could not coerce into Filter!"), } } + + /// Access the filter predicate expression + pub fn predicate(&self) -> &Expr { + &self.predicate + } + + /// Access the filter input plan + pub fn input(&self) -> &Arc { + &self.input + } } /// Window its input based on a set of window spec and window function (e.g. 
SUM or RANK) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 501b4a8f11fc4..8e2544793b259 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -17,6 +17,7 @@ //! Expression utilities +use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use crate::logical_plan::builder::build_join_schema; use crate::logical_plan::{ @@ -380,10 +381,51 @@ pub fn from_plan( .map(|s| s.to_vec()) .collect::>(), })), - LogicalPlan::Filter { .. } => Ok(LogicalPlan::Filter(Filter { - predicate: expr[0].clone(), - input: Arc::new(inputs[0].clone()), - })), + LogicalPlan::Filter { .. } => { + assert_eq!(1, expr.len()); + let predicate = expr[0].clone(); + + // filter predicates should not contain aliased expressions so we remove any aliases + // before this logic was added we would have aliases within filters such as for + // benchmark q6: + // + // lineitem.l_shipdate >= Date32(\"8766\") + // AND lineitem.l_shipdate < Date32(\"9131\") + // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= + // Decimal128(Some(49999999999999),30,15) + // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= + // Decimal128(Some(69999999999999),30,15) + // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + + struct RemoveAliases {} + + impl ExprRewriter for RemoveAliases { + fn pre_visit(&mut self, expr: &Expr) -> Result { + match expr { + Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::InSubquery { .. 
} => { + // subqueries could contain aliases so we don't recurse into those + Ok(RewriteRecursion::Stop) + } + Expr::Alias(_, _) => Ok(RewriteRecursion::Mutate), + _ => Ok(RewriteRecursion::Continue), + } + } + + fn mutate(&mut self, expr: Expr) -> Result { + Ok(expr.unalias()) + } + } + + let mut remove_aliases = RemoveAliases {}; + let predicate = predicate.rewrite(&mut remove_aliases)?; + + Ok(LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(inputs[0].clone()), + )?)) + } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index cea5e8c46eb5e..811142c9b1e90 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -114,7 +114,9 @@ fn optimize( alias.clone(), )?)) } - LogicalPlan::Filter(Filter { predicate, input }) => { + LogicalPlan::Filter(filter) => { + let input = filter.input(); + let predicate = filter.predicate(); let input_schema = Arc::clone(input.schema()); let all_schemas: Vec = plan.all_schemas().into_iter().cloned().collect(); @@ -131,16 +133,16 @@ fn optimize( let (mut new_expr, new_input) = rewrite_expr( &[&[predicate.clone()]], &[&[id_array]], - input, + filter.input(), &mut expr_set, optimizer_config, )?; if let Some(predicate) = pop_expr(&mut new_expr)?.pop() { - Ok(LogicalPlan::Filter(Filter { + Ok(LogicalPlan::Filter(Filter::try_new( predicate, - input: Arc::new(new_input), - })) + Arc::new(new_input), + )?)) } else { Err(DataFusionError::Internal( "Failed to pop predicate expr".to_string(), diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 671deb0276b6a..d6727ad0fd61b 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -76,19 +76,19 @@ impl OptimizerRule for DecorrelateWhereExists 
{ optimizer_config: &mut OptimizerConfig, ) -> datafusion_common::Result { match plan { - LogicalPlan::Filter(Filter { - predicate, - input: filter_input, - }) => { + LogicalPlan::Filter(filter) => { + let predicate = filter.predicate(); + let filter_input = filter.input(); + // Apply optimizer rule to current input let optimized_input = self.optimize(filter_input, optimizer_config)?; let (subqueries, other_exprs) = self.extract_subquery_exprs(predicate, optimizer_config)?; - let optimized_plan = LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - }); + let optimized_plan = LogicalPlan::Filter(Filter::try_new( + predicate.clone(), + Arc::new(optimized_input), + )?); if subqueries.is_empty() { // regular filter, no subquery exists clause here return Ok(optimized_plan); @@ -153,20 +153,21 @@ fn optimize_exists( // split into filters let mut subqry_filter_exprs = vec![]; - split_conjunction(&subqry_filter.predicate, &mut subqry_filter_exprs); + split_conjunction(subqry_filter.predicate(), &mut subqry_filter_exprs); verify_not_disjunction(&subqry_filter_exprs)?; // Grab column names to join on let (col_exprs, other_subqry_exprs) = - find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())?; + find_join_exprs(subqry_filter_exprs, subqry_filter.input().schema())?; let (outer_cols, subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)?; + exprs_to_join_cols(&col_exprs, subqry_filter.input().schema(), false)?; if subqry_cols.is_empty() || outer_cols.is_empty() { plan_err!("cannot optimize non-correlated subquery")?; } // build subquery side of join - the thing the subquery was querying - let mut subqry_plan = LogicalPlanBuilder::from((*subqry_filter.input).clone()); + let mut subqry_plan = + LogicalPlanBuilder::from(subqry_filter.input().as_ref().clone()); if let Some(expr) = combine_filters(&other_subqry_exprs) { subqry_plan = subqry_plan.filter(expr)? 
// if the subquery had additional expressions, restore them } diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index d5af0911d32ef..a3443eaee882f 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -83,19 +83,19 @@ impl OptimizerRule for DecorrelateWhereIn { optimizer_config: &mut OptimizerConfig, ) -> datafusion_common::Result { match plan { - LogicalPlan::Filter(Filter { - predicate, - input: filter_input, - }) => { + LogicalPlan::Filter(filter) => { + let predicate = filter.predicate(); + let filter_input = filter.input(); + // Apply optimizer rule to current input let optimized_input = self.optimize(filter_input, optimizer_config)?; let (subqueries, other_exprs) = self.extract_subquery_exprs(predicate, optimizer_config)?; - let optimized_plan = LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - }); + let optimized_plan = LogicalPlan::Filter(Filter::try_new( + predicate.clone(), + Arc::new(optimized_input), + )?); if subqueries.is_empty() { // regular filter, no subquery exists clause here return Ok(optimized_plan); @@ -152,18 +152,18 @@ fn optimize_where_in( if let LogicalPlan::Filter(subqry_filter) = (*subqry_input).clone() { // split into filters let mut subqry_filter_exprs = vec![]; - split_conjunction(&subqry_filter.predicate, &mut subqry_filter_exprs); + split_conjunction(subqry_filter.predicate(), &mut subqry_filter_exprs); verify_not_disjunction(&subqry_filter_exprs)?; // Grab column names to join on let (col_exprs, other_exprs) = - find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema()) + find_join_exprs(subqry_filter_exprs, subqry_filter.input().schema()) .map_err(|e| context!("column correlation not found", e))?; if !col_exprs.is_empty() { // it's correlated - subqry_input = subqry_filter.input.clone(); + subqry_input = subqry_filter.input().clone(); (outer_cols, 
subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false) + exprs_to_join_cols(&col_exprs, subqry_filter.input().schema(), false) .map_err(|e| context!("column correlation not found", e))?; other_subqry_exprs = other_exprs; } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 61e9613cfa046..6c0c51b86dc64 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - logical_plan::{EmptyRelation, Filter, LogicalPlan}, + logical_plan::{EmptyRelation, LogicalPlan}, utils::from_plan, Expr, }; @@ -43,21 +43,30 @@ impl OptimizerRule for EliminateFilter { plan: &LogicalPlan, optimizer_config: &mut OptimizerConfig, ) -> Result { - match plan { - LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(Some(v))), - input, - }) => { - if !*v { + let (filter_value, input) = match plan { + LogicalPlan::Filter(filter) => match filter.predicate() { + Expr::Literal(ScalarValue::Boolean(Some(v))) => { + (Some(*v), Some(filter.input())) + } + _ => (None, None), + }, + _ => (None, None), + }; + + match filter_value { + Some(v) => { + // input is guaranteed be Some due to previous code + let input = input.unwrap(); + if v { + self.optimize(input, optimizer_config) + } else { Ok(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: input.schema().clone(), })) - } else { - self.optimize(input, optimizer_config) } } - _ => { + None => { // Apply the optimization to all inputs of the plan let inputs = plan.inputs(); let new_inputs = inputs diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 4d237dd04a664..be33c796ea428 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ 
b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -68,17 +68,17 @@ impl OptimizerRule for FilterNullJoinKeys { if !left_filters.is_empty() { let predicate = create_not_null_predicate(left_filters); - join.left = Arc::new(LogicalPlan::Filter(Filter { + join.left = Arc::new(LogicalPlan::Filter(Filter::try_new( predicate, - input: join.left.clone(), - })); + join.left.clone(), + )?)); } if !right_filters.is_empty() { let predicate = create_not_null_predicate(right_filters); - join.right = Arc::new(LogicalPlan::Filter(Filter { + join.right = Arc::new(LogicalPlan::Filter(Filter::try_new( predicate, - input: join.right.clone(), - })); + join.right.clone(), + )?)); } Ok(LogicalPlan::Join(join)) } diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 129766012531d..d1f6966213732 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -20,8 +20,8 @@ use datafusion_expr::{ col, expr_rewriter::{replace_col, ExprRewritable, ExprRewriter}, logical_plan::{ - Aggregate, CrossJoin, Filter, Join, JoinType, Limit, LogicalPlan, Projection, - TableScan, Union, + Aggregate, CrossJoin, Join, JoinType, Limit, LogicalPlan, Projection, TableScan, + Union, }, utils::{expr_to_columns, exprlist_to_columns, from_plan}, Expr, TableProviderFilterPushDown, @@ -138,7 +138,7 @@ fn issue_filters( return push_down(&state, plan); } - let plan = utils::add_filter(plan.clone(), &predicates); + let plan = utils::add_filter(plan.clone(), &predicates)?; state.filters = remove_filters(&state.filters, &predicate_columns); @@ -326,7 +326,7 @@ fn optimize_join( Ok(plan) } else { // wrap the join on the filter whose predicates must be kept - let plan = utils::add_filter(plan, &to_keep.0); + let plan = utils::add_filter(plan, &to_keep.0)?; state.filters = remove_filters(&state.filters, &to_keep.1); Ok(plan) @@ -340,9 +340,9 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { 
push_down(&state, plan) } LogicalPlan::Analyze { .. } => push_down(&state, plan), - LogicalPlan::Filter(Filter { input, predicate }) => { + LogicalPlan::Filter(filter) => { let mut predicates = vec![]; - utils::split_conjunction(predicate, &mut predicates); + utils::split_conjunction(filter.predicate(), &mut predicates); predicates .into_iter() @@ -353,7 +353,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { Ok(()) })?; - optimize(input, state) + optimize(filter.input(), state) } LogicalPlan::Projection(Projection { input, diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs index 051a0ed745b23..5a048aac70f0c 100644 --- a/datafusion/optimizer/src/projection_push_down.rs +++ b/datafusion/optimizer/src/projection_push_down.rs @@ -580,12 +580,12 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("c"))? + .filter(col("c").gt(lit(1)))? .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ - \n Filter: test.c\ + \n Filter: test.c > Int32(1)\ \n TableScan: test projection=[b, c]"; assert_optimized_plan_eq(&plan, expected); diff --git a/datafusion/optimizer/src/reduce_cross_join.rs b/datafusion/optimizer/src/reduce_cross_join.rs index 4c43188cff5d8..e8c2ff9ecd0a9 100644 --- a/datafusion/optimizer/src/reduce_cross_join.rs +++ b/datafusion/optimizer/src/reduce_cross_join.rs @@ -77,7 +77,9 @@ fn reduce_cross_join( all_join_keys: &mut HashSet<(Column, Column)>, ) -> Result { match plan { - LogicalPlan::Filter(Filter { input, predicate }) => { + LogicalPlan::Filter(filter) => { + let input = filter.input(); + let predicate = filter.predicate(); // join keys are handled locally let mut new_possible_join_keys: Vec<(Column, Column)> = vec![]; let mut new_all_join_keys = HashSet::new(); @@ -93,17 +95,17 @@ fn reduce_cross_join( // if there are no join keys then do nothing. 
if new_all_join_keys.is_empty() { - Ok(LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(new_plan), - })) + Ok(LogicalPlan::Filter(Filter::try_new( + predicate.clone(), + Arc::new(new_plan), + )?)) } else { // remove join expressions from filter match remove_join_expressions(predicate, &new_all_join_keys)? { - Some(filter_expr) => Ok(LogicalPlan::Filter(Filter { - predicate: filter_expr, - input: Arc::new(new_plan), - })), + Some(filter_expr) => Ok(LogicalPlan::Filter(Filter::try_new( + filter_expr, + Arc::new(new_plan), + )?)), _ => Ok(new_plan), } } diff --git a/datafusion/optimizer/src/reduce_outer_join.rs b/datafusion/optimizer/src/reduce_outer_join.rs index 6ca4a5994989a..93b706afe365d 100644 --- a/datafusion/optimizer/src/reduce_outer_join.rs +++ b/datafusion/optimizer/src/reduce_outer_join.rs @@ -69,34 +69,34 @@ fn reduce_outer_join( _optimizer_config: &OptimizerConfig, ) -> Result { match plan { - LogicalPlan::Filter(Filter { input, predicate }) => match &**input { + LogicalPlan::Filter(filter) => match filter.input().as_ref() { LogicalPlan::Join(join) => { extract_nonnullable_columns( - predicate, + filter.predicate(), nonnullable_cols, join.left.schema(), join.right.schema(), true, )?; - Ok(LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(reduce_outer_join( + Ok(LogicalPlan::Filter(Filter::try_new( + filter.predicate().clone(), + Arc::new(reduce_outer_join( _optimizer, - input, + filter.input(), nonnullable_cols, _optimizer_config, )?), - })) + )?)) } - _ => Ok(LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(reduce_outer_join( + _ => Ok(LogicalPlan::Filter(Filter::try_new( + filter.predicate().clone(), + Arc::new(reduce_outer_join( _optimizer, - input, + filter.input(), nonnullable_cols, _optimizer_config, )?), - })), + )?)), }, LogicalPlan::Join(join) => { let mut new_join_type = join.join_type; diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs 
b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 2eadfb3d5f0e0..a4f051e1c2035 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -129,16 +129,16 @@ impl RewriteDisjunctivePredicate { ) -> Result { match plan { LogicalPlan::Filter(filter) => { - let predicate = predicate(&filter.predicate)?; + let predicate = predicate(filter.predicate())?; let rewritten_predicate = rewrite_predicate(predicate); let rewritten_expr = normalize_predicate(rewritten_predicate); - Ok(LogicalPlan::Filter(Filter { - predicate: rewritten_expr, - input: Arc::new(self.rewrite_disjunctive_predicate( - &filter.input, + Ok(LogicalPlan::Filter(Filter::try_new( + rewritten_expr, + Arc::new(self.rewrite_disjunctive_predicate( + filter.input(), _optimizer_config, )?), - })) + )?)) } _ => { let expr = plan.expressions(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 0a2256eba05c7..d148881108257 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -95,23 +95,23 @@ impl OptimizerRule for ScalarSubqueryToJoin { optimizer_config: &mut OptimizerConfig, ) -> Result { match plan { - LogicalPlan::Filter(Filter { predicate, input }) => { + LogicalPlan::Filter(filter) => { // Apply optimizer rule to current input - let optimized_input = self.optimize(input, optimizer_config)?; + let optimized_input = self.optimize(filter.input(), optimizer_config)?; let (subqueries, other_exprs) = - self.extract_subquery_exprs(predicate, optimizer_config)?; + self.extract_subquery_exprs(filter.predicate(), optimizer_config)?; if subqueries.is_empty() { // regular filter, no subquery exists clause here - return Ok(LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - })); + return Ok(LogicalPlan::Filter(Filter::try_new( + 
filter.predicate().clone(), + Arc::new(optimized_input), + )?)); } // iterate through all subqueries in predicate, turning each into a join - let mut cur_input = (**input).clone(); + let mut cur_input = filter.input().as_ref().clone(); for subquery in subqueries { if let Some(optimized_subquery) = optimize_scalar( &subquery, @@ -122,10 +122,10 @@ impl OptimizerRule for ScalarSubqueryToJoin { cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now - return Ok(LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - })); + return Ok(LogicalPlan::Filter(Filter::try_new( + filter.predicate().clone(), + Arc::new(optimized_input), + )?)); } } Ok(cur_input) @@ -228,7 +228,7 @@ fn optimize_scalar( // if there were filters, we use that logical plan, otherwise the plan from the aggregate let input = if let Some(filter) = filter { - &filter.input + filter.input() } else { &aggr.input }; @@ -236,7 +236,7 @@ fn optimize_scalar( // if there were filters, split and capture them let mut subqry_filter_exprs = vec![]; if let Some(filter) = filter { - split_conjunction(&filter.predicate, &mut subqry_filter_exprs); + split_conjunction(filter.predicate(), &mut subqry_filter_exprs); } verify_not_disjunction(&subqry_filter_exprs)?; diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index 95c640b2a7088..e0a5382b6b8f4 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -1998,7 +1998,7 @@ mod tests { assert_optimized_plan_eq( &plan, "\ - Filter: test.b > Int32(1) AS test.b > Int32(1) AND test.b > Int32(1)\ + Filter: test.b > Int32(1)\ \n Projection: test.a\ \n TableScan: test", ); @@ -2022,7 +2022,7 @@ mod tests { assert_optimized_plan_eq( &plan, "\ - Filter: test.a > Int32(5) AND test.b < Int32(6) AS test.a > Int32(5) AND test.b < Int32(6) AND test.a > Int32(5)\ + Filter: 
test.a > Int32(5) AND test.b < Int32(6)\ \n Projection: test.a, test.b\ \n TableScan: test", ); @@ -2043,8 +2043,8 @@ mod tests { let expected = "\ Projection: test.a\ - \n Filter: NOT test.c AS test.c = Boolean(false)\ - \n Filter: test.b AS test.b = Boolean(true)\ + \n Filter: NOT test.c\ + \n Filter: test.b\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2068,8 +2068,8 @@ mod tests { let expected = "\ Projection: test.a\ \n Limit: skip=0, fetch=1\ - \n Filter: test.c AS test.c != Boolean(false)\ - \n Filter: NOT test.b AS test.b != Boolean(true)\ + \n Filter: test.c\ + \n Filter: NOT test.b\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2088,7 +2088,7 @@ mod tests { let expected = "\ Projection: test.a\ - \n Filter: NOT test.b AND test.c AS test.b != Boolean(true) AND test.c = Boolean(true)\ + \n Filter: NOT test.b AND test.c\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2107,7 +2107,7 @@ mod tests { let expected = "\ Projection: test.a\ - \n Filter: NOT test.b OR NOT test.c AS test.b != Boolean(true) OR test.c = Boolean(false)\ + \n Filter: NOT test.b OR NOT test.c\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2126,7 +2126,7 @@ mod tests { let expected = "\ Projection: test.a\ - \n Filter: test.b AS NOT test.b = Boolean(false)\ + \n Filter: test.b\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2360,7 +2360,7 @@ mod tests { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = "Filter: Boolean(true) AS now() < totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) + Int64(50000)\ + let expected = "Filter: Boolean(true)\ \n TableScan: test"; let actual = get_optimized_plan_formatted(&plan, &time); @@ -2408,7 +2408,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d <= Int32(10) AS NOT test.d > Int32(10)\ + let expected = "Filter: test.d <= Int32(10)\ \n TableScan: 
test"; assert_optimized_plan_eq(&plan, expected); @@ -2423,7 +2423,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100) AS NOT test.d > Int32(10) AND test.d < Int32(100)\ + let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2438,7 +2438,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100) AS NOT test.d > Int32(10) OR test.d < Int32(100)\ + let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2453,7 +2453,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d > Int32(10) AS NOT NOT test.d > Int32(10)\ + let expected = "Filter: test.d > Int32(10)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2468,7 +2468,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d IS NOT NULL AS NOT test.d IS NULL\ + let expected = "Filter: test.d IS NOT NULL\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2483,7 +2483,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d IS NULL AS NOT test.d IS NOT NULL\ + let expected = "Filter: test.d IS NULL\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2498,7 +2498,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d NOT IN ([Int32(1), Int32(2), Int32(3)]) AS NOT test.d IN (Map { iter: Iter([Int32(1), Int32(2), Int32(3)]) })\ + let expected = "Filter: test.d NOT IN ([Int32(1), Int32(2), Int32(3)])\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2513,7 +2513,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d IN ([Int32(1), Int32(2), Int32(3)]) AS NOT test.d NOT IN (Map { iter: Iter([Int32(1), Int32(2), Int32(3)]) })\ + let expected = 
"Filter: test.d IN ([Int32(1), Int32(2), Int32(3)])\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2534,7 +2534,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10) AS NOT test.d BETWEEN Int32(1) AND Int32(10)\ + let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2555,7 +2555,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10) AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\ + let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2577,7 +2577,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.a NOT LIKE test.b AS NOT test.a LIKE test.b\ + let expected = "Filter: test.a NOT LIKE test.b\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2599,7 +2599,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.a LIKE test.b AS NOT test.a NOT LIKE test.b\ + let expected = "Filter: test.a LIKE test.b\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2614,7 +2614,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10) AS NOT test.d IS DISTINCT FROM Int32(10)\ + let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2629,7 +2629,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: test.d IS DISTINCT FROM Int32(10) AS NOT test.d IS NOT DISTINCT FROM Int32(10)\ + let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); diff --git a/datafusion/optimizer/src/subquery_filter_to_join.rs b/datafusion/optimizer/src/subquery_filter_to_join.rs index 91d31f28edefd..bd07eeab2f82f 100644 
--- a/datafusion/optimizer/src/subquery_filter_to_join.rs +++ b/datafusion/optimizer/src/subquery_filter_to_join.rs @@ -55,13 +55,13 @@ impl OptimizerRule for SubqueryFilterToJoin { optimizer_config: &mut OptimizerConfig, ) -> Result<LogicalPlan> { match plan { - LogicalPlan::Filter(Filter { predicate, input }) => { + LogicalPlan::Filter(filter) => { // Apply optimizer rule to current input - let optimized_input = self.optimize(input, optimizer_config)?; + let optimized_input = self.optimize(filter.input(), optimizer_config)?; // Splitting filter expression into components by AND let mut filters = vec![]; - utils::split_conjunction(predicate, &mut filters); + utils::split_conjunction(filter.predicate(), &mut filters); // Searching for subquery-based filters let (subquery_filters, regular_filters): (Vec<&Expr>, Vec<&Expr>) = @@ -79,10 +79,10 @@ impl OptimizerRule for SubqueryFilterToJoin { })?; if !subqueries_in_regular.is_empty() { - return Ok(LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - })); + return Ok(LogicalPlan::Filter(Filter::try_new( + filter.predicate().clone(), + Arc::new(optimized_input), + )?)); }; // Add subquery joins to new_input @@ -151,10 +151,10 @@ impl OptimizerRule for SubqueryFilterToJoin { let new_input = match opt_result { Ok(plan) => plan, Err(_) => { - return Ok(LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - })) + return Ok(LogicalPlan::Filter(Filter::try_new( + filter.predicate().clone(), + Arc::new(optimized_input), + )?)) } }; @@ -162,7 +162,7 @@ impl OptimizerRule for SubqueryFilterToJoin { if regular_filters.is_empty() { Ok(new_input) } else { - Ok(utils::add_filter(new_input, &regular_filters)) + utils::add_filter(new_input, &regular_filters) } } _ => { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index d962dd7b45b9a..5e64092a48618 100644 --- a/datafusion/optimizer/src/utils.rs +++
b/datafusion/optimizer/src/utils.rs @@ -104,7 +104,7 @@ pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> { /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with /// its predicate be all `predicates` ANDed. -pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { +pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> { // reduce filters to a single filter with an AND let predicate = predicates .iter() @@ -113,10 +113,10 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { and(acc, (*predicate).to_owned()) }); - LogicalPlan::Filter(Filter { + Ok(LogicalPlan::Filter(Filter::try_new( predicate, - input: Arc::new(plan), - }) + Arc::new(plan), + )?)) } /// Looks for correlating expressions: equality expressions with one field from the subquery, and diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs index a5ddccdb6224e..7a9d635f8152f 100644 --- a/datafusion/proto/src/logical_plan.rs +++ b/datafusion/proto/src/logical_plan.rs @@ -38,9 +38,8 @@ use datafusion_common::{Column, DataFusionError}; use datafusion_expr::{ logical_plan::{ Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, - CrossJoin, Distinct, EmptyRelation, Extension, Filter, Join, JoinConstraint, - JoinType, Limit, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, - Window, + CrossJoin, Distinct, EmptyRelation, Extension, Join, JoinConstraint, JoinType, + Limit, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, Expr, LogicalPlan, LogicalPlanBuilder, }; @@ -806,17 +805,17 @@ impl AsLogicalPlan for LogicalPlanNode { }, ))), }), - LogicalPlan::Filter(Filter { predicate, input }) => { + LogicalPlan::Filter(filter) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), + filter.input().as_ref(), extension_codec, )?; Ok(protobuf::LogicalPlanNode {
logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some(predicate.try_into()?), + expr: Some(filter.predicate().try_into()?), }, ))), }) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b75efd6f47233..400bbe4fcaa92 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -982,10 +982,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { x.as_slice(), &[join_columns], )?; - Ok(LogicalPlan::Filter(Filter { - predicate: filter_expr, - input: Arc::new(left), - })) + Ok(LogicalPlan::Filter(Filter::try_new( + filter_expr, + Arc::new(left), + )?)) } _ => Ok(left), }