diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 2a805a5fc0e8b..7284d3e8ac8f9 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -33,10 +33,7 @@ use crate::{ MemTable, ViewTable, }, logical_plan::{PlanType, ToStringifiedPlan}, - optimizer::{ - eliminate_filter::EliminateFilter, eliminate_limit::EliminateLimit, - optimizer::Optimizer, - }, + optimizer::optimizer::Optimizer, physical_optimizer::{ aggregate_statistics::AggregateStatistics, hash_build_probe_order::HashBuildProbeOrder, optimizer::PhysicalOptimizerRule, @@ -72,16 +69,7 @@ use crate::logical_plan::{ CreateMemoryTable, CreateView, DropTable, FunctionRegistry, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE, }; -use crate::optimizer::common_subexpr_eliminate::CommonSubexprEliminate; -use crate::optimizer::filter_push_down::FilterPushDown; -use crate::optimizer::limit_push_down::LimitPushDown; use crate::optimizer::optimizer::{OptimizerConfig, OptimizerRule}; -use crate::optimizer::projection_push_down::ProjectionPushDown; -use crate::optimizer::reduce_cross_join::ReduceCrossJoin; -use crate::optimizer::reduce_outer_join::ReduceOuterJoin; -use crate::optimizer::simplify_expressions::SimplifyExpressions; -use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; -use crate::optimizer::subquery_filter_to_join::SubqueryFilterToJoin; use datafusion_sql::{ResolvedTableReference, TableReference}; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; @@ -107,13 +95,6 @@ use chrono::{DateTime, Utc}; use datafusion_common::ScalarValue; use datafusion_expr::logical_plan::DropView; use datafusion_expr::{TableSource, TableType}; -use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; -use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; -use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; -use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; -use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; -use datafusion_optimizer::type_coercion::TypeCoercion; -use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -1465,33 +1446,13 @@ impl SessionState { .register_catalog(config.default_catalog.clone(), default_catalog); } - let mut rules: Vec> = vec![ - Arc::new(TypeCoercion::new()), - Arc::new(SimplifyExpressions::new()), - Arc::new(UnwrapCastInComparison::new()), - Arc::new(DecorrelateWhereExists::new()), - Arc::new(DecorrelateWhereIn::new()), - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(SubqueryFilterToJoin::new()), - Arc::new(EliminateFilter::new()), - Arc::new(ReduceCrossJoin::new()), - Arc::new(CommonSubexprEliminate::new()), - Arc::new(EliminateLimit::new()), - Arc::new(ProjectionPushDown::new()), - Arc::new(RewriteDisjunctivePredicate::new()), - ]; - if config - .config_options - .read() - .get_bool(OPT_FILTER_NULL_JOIN_KEYS) - .unwrap_or_default() - { - rules.push(Arc::new(FilterNullJoinKeys::default())); - } - rules.push(Arc::new(ReduceOuterJoin::new())); - rules.push(Arc::new(FilterPushDown::new())); - rules.push(Arc::new(LimitPushDown::new())); - rules.push(Arc::new(SingleDistinctToGroupBy::new())); + let optimizer_config = OptimizerConfig::new().filter_null_keys( + config + .config_options + .read() + .get_bool(OPT_FILTER_NULL_JOIN_KEYS) + .unwrap_or_default(), + ); let mut physical_optimizers: Vec> = vec![ Arc::new(AggregateStatistics::new()), @@ -1518,7 +1479,7 @@ impl SessionState { SessionState { session_id, - optimizer: Optimizer::new(rules), + optimizer: Optimizer::new(&optimizer_config), physical_optimizers, query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, @@ -1575,7 +1536,7 @@ impl SessionState { mut self, rules: Vec>, ) -> Self { - self.optimizer = Optimizer::new(rules); + self.optimizer = Optimizer::with_rules(rules); self } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index e2ccd49448924..5ef5cfdd59755 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,24 @@ //! Query optimizer traits +use crate::common_subexpr_eliminate::CommonSubexprEliminate; +use crate::decorrelate_where_exists::DecorrelateWhereExists; +use crate::decorrelate_where_in::DecorrelateWhereIn; +use crate::eliminate_filter::EliminateFilter; +use crate::eliminate_limit::EliminateLimit; +use crate::filter_null_join_keys::FilterNullJoinKeys; +use crate::filter_push_down::FilterPushDown; +use crate::limit_push_down::LimitPushDown; +use crate::projection_push_down::ProjectionPushDown; +use crate::reduce_cross_join::ReduceCrossJoin; +use crate::reduce_outer_join::ReduceOuterJoin; +use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; +use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; +use crate::simplify_expressions::SimplifyExpressions; +use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; +use crate::subquery_filter_to_join::SubqueryFilterToJoin; +use crate::type_coercion::TypeCoercion; +use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use chrono::{DateTime, Utc}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; @@ -50,6 +68,8 @@ pub struct OptimizerConfig { next_id: usize, /// Option to skip rules that produce errors skip_failing_rules: bool, + /// Specify whether to enable the filter_null_keys rule + filter_null_keys: bool, } impl OptimizerConfig { @@ -59,9 +79,16 @@ impl OptimizerConfig { query_execution_start_time: chrono::Utc::now(), next_id: 0, // useful for generating things like unique subquery aliases skip_failing_rules: true, + filter_null_keys: true, } } + /// Specify whether to enable the filter_null_keys rule + pub fn filter_null_keys(mut self, filter_null_keys: bool) -> Self { + self.filter_null_keys = filter_null_keys; + self + } + /// Specify whether the optimizer should skip rules that produce /// errors, or fail the query pub fn with_query_execution_start_time( @@ -107,8 +134,35 @@ pub struct Optimizer { } impl Optimizer { + /// Create a new optimizer using the recommended list of rules + pub fn new(config: &OptimizerConfig) -> Self { + let mut rules: Vec> = vec![ + Arc::new(TypeCoercion::new()), + Arc::new(SimplifyExpressions::new()), + Arc::new(UnwrapCastInComparison::new()), + Arc::new(DecorrelateWhereExists::new()), + Arc::new(DecorrelateWhereIn::new()), + Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(SubqueryFilterToJoin::new()), + Arc::new(EliminateFilter::new()), + Arc::new(ReduceCrossJoin::new()), + Arc::new(CommonSubexprEliminate::new()), + Arc::new(EliminateLimit::new()), + Arc::new(ProjectionPushDown::new()), + Arc::new(RewriteDisjunctivePredicate::new()), + ]; + if config.filter_null_keys { + rules.push(Arc::new(FilterNullJoinKeys::default())); + } + rules.push(Arc::new(ReduceOuterJoin::new())); + rules.push(Arc::new(FilterPushDown::new())); + rules.push(Arc::new(LimitPushDown::new())); + rules.push(Arc::new(SingleDistinctToGroupBy::new())); + Self::with_rules(rules) + } + /// Create a new optimizer with the given rules - pub fn new(rules: Vec>) -> Self { + pub fn with_rules(rules: Vec>) -> Self { Self { rules } } @@ -172,7 +226,7 @@ mod tests { #[test] fn skip_failing_rule() -> Result<(), DataFusionError> { - let opt = Optimizer::new(vec![Arc::new(BadRule {})]); + let opt = Optimizer::with_rules(vec![Arc::new(BadRule {})]); let mut config = OptimizerConfig::new().with_skip_failing_rules(true); let plan = LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -184,7 +238,7 @@ mod tests { #[test] fn no_skip_failing_rule() -> Result<(), DataFusionError> { - let opt = Optimizer::new(vec![Arc::new(BadRule {})]); + let opt = Optimizer::with_rules(vec![Arc::new(BadRule {})]); let mut config = OptimizerConfig::new().with_skip_failing_rules(false); let plan = LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 7811e475c2de6..86f55e698505f 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -18,25 +18,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}; -use datafusion_optimizer::common_subexpr_eliminate::CommonSubexprEliminate; -use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; -use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; -use datafusion_optimizer::eliminate_filter::EliminateFilter; -use datafusion_optimizer::eliminate_limit::EliminateLimit; -use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; -use datafusion_optimizer::filter_push_down::FilterPushDown; -use datafusion_optimizer::limit_push_down::LimitPushDown; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::projection_push_down::ProjectionPushDown; -use datafusion_optimizer::reduce_cross_join::ReduceCrossJoin; -use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin; -use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; -use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; -use datafusion_optimizer::simplify_expressions::SimplifyExpressions; -use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; -use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin; -use datafusion_optimizer::type_coercion::TypeCoercion; -use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -104,31 +86,6 @@ fn between_date64_plus_interval() -> Result<()> { } fn test_sql(sql: &str) -> Result { - // TODO should make align with rules in the context - // https://github.com/apache/arrow-datafusion/issues/3524 - let rules: Vec> = vec![ - Arc::new(TypeCoercion::new()), - Arc::new(SimplifyExpressions::new()), - Arc::new(UnwrapCastInComparison::new()), - Arc::new(DecorrelateWhereExists::new()), - Arc::new(DecorrelateWhereIn::new()), - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(SubqueryFilterToJoin::new()), - Arc::new(EliminateFilter::new()), - Arc::new(CommonSubexprEliminate::new()), - Arc::new(EliminateLimit::new()), - Arc::new(ReduceCrossJoin::new()), - Arc::new(ProjectionPushDown::new()), - Arc::new(RewriteDisjunctivePredicate::new()), - Arc::new(FilterNullJoinKeys::default()), - Arc::new(ReduceOuterJoin::new()), - Arc::new(FilterPushDown::new()), - Arc::new(LimitPushDown::new()), - Arc::new(SingleDistinctToGroupBy::new()), - ]; - - let optimizer = Optimizer::new(rules); - // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); @@ -141,6 +98,7 @@ fn test_sql(sql: &str) -> Result { // optimize the logical plan let mut config = OptimizerConfig::new().with_skip_failing_rules(false); + let optimizer = Optimizer::new(&config); optimizer.optimize(&plan, &mut config, &observe) }