From fa279f8e2def8ed653d5c01d568927182c155092 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 3 Oct 2022 10:06:26 -0600 Subject: [PATCH 1/6] move optimizer init to optimizer crate --- datafusion/core/src/execution/context.rs | 59 ++++-------------- datafusion/optimizer/src/optimizer.rs | 60 ++++++++++++++++++- .../optimizer/tests/integration-test.rs | 44 +------------- 3 files changed, 68 insertions(+), 95 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 2a805a5fc0e8b..578b6cd494598 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -33,10 +33,7 @@ use crate::{ MemTable, ViewTable, }, logical_plan::{PlanType, ToStringifiedPlan}, - optimizer::{ - eliminate_filter::EliminateFilter, eliminate_limit::EliminateLimit, - optimizer::Optimizer, - }, + optimizer::optimizer::Optimizer, physical_optimizer::{ aggregate_statistics::AggregateStatistics, hash_build_probe_order::HashBuildProbeOrder, optimizer::PhysicalOptimizerRule, @@ -72,16 +69,7 @@ use crate::logical_plan::{ CreateMemoryTable, CreateView, DropTable, FunctionRegistry, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE, }; -use crate::optimizer::common_subexpr_eliminate::CommonSubexprEliminate; -use crate::optimizer::filter_push_down::FilterPushDown; -use crate::optimizer::limit_push_down::LimitPushDown; use crate::optimizer::optimizer::{OptimizerConfig, OptimizerRule}; -use crate::optimizer::projection_push_down::ProjectionPushDown; -use crate::optimizer::reduce_cross_join::ReduceCrossJoin; -use crate::optimizer::reduce_outer_join::ReduceOuterJoin; -use crate::optimizer::simplify_expressions::SimplifyExpressions; -use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; -use crate::optimizer::subquery_filter_to_join::SubqueryFilterToJoin; use datafusion_sql::{ResolvedTableReference, TableReference}; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; @@ -107,13 +95,6 @@ use chrono::{DateTime, Utc}; use datafusion_common::ScalarValue; use datafusion_expr::logical_plan::DropView; use datafusion_expr::{TableSource, TableType}; -use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; -use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; -use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; -use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; -use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; -use datafusion_optimizer::type_coercion::TypeCoercion; -use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -1465,33 +1446,13 @@ impl SessionState { .register_catalog(config.default_catalog.clone(), default_catalog); } - let mut rules: Vec> = vec![ - Arc::new(TypeCoercion::new()), - Arc::new(SimplifyExpressions::new()), - Arc::new(UnwrapCastInComparison::new()), - Arc::new(DecorrelateWhereExists::new()), - Arc::new(DecorrelateWhereIn::new()), - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(SubqueryFilterToJoin::new()), - Arc::new(EliminateFilter::new()), - Arc::new(ReduceCrossJoin::new()), - Arc::new(CommonSubexprEliminate::new()), - Arc::new(EliminateLimit::new()), - Arc::new(ProjectionPushDown::new()), - Arc::new(RewriteDisjunctivePredicate::new()), - ]; - if config - .config_options - .read() - .get_bool(OPT_FILTER_NULL_JOIN_KEYS) - .unwrap_or_default() - { - 
rules.push(Arc::new(FilterNullJoinKeys::default())); - } - rules.push(Arc::new(ReduceOuterJoin::new())); - rules.push(Arc::new(FilterPushDown::new())); - rules.push(Arc::new(LimitPushDown::new())); - rules.push(Arc::new(SingleDistinctToGroupBy::new())); + let x = OptimizerConfig::new().filter_null_keys( + config + .config_options + .read() + .get_bool(OPT_FILTER_NULL_JOIN_KEYS) + .unwrap_or_default(), + ); let mut physical_optimizers: Vec> = vec![ Arc::new(AggregateStatistics::new()), @@ -1518,7 +1479,7 @@ impl SessionState { SessionState { session_id, - optimizer: Optimizer::new(rules), + optimizer: Optimizer::new(&x), physical_optimizers, query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, @@ -1575,7 +1536,7 @@ impl SessionState { mut self, rules: Vec>, ) -> Self { - self.optimizer = Optimizer::new(rules); + self.optimizer = Optimizer::with_rules(rules); self } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index e2ccd49448924..5ef5cfdd59755 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,24 @@ //! Query optimizer traits +use crate::common_subexpr_eliminate::CommonSubexprEliminate; +use crate::decorrelate_where_exists::DecorrelateWhereExists; +use crate::decorrelate_where_in::DecorrelateWhereIn; +use crate::eliminate_filter::EliminateFilter; +use crate::eliminate_limit::EliminateLimit; +use crate::filter_null_join_keys::FilterNullJoinKeys; +use crate::filter_push_down::FilterPushDown; +use crate::limit_push_down::LimitPushDown; +use crate::projection_push_down::ProjectionPushDown; +use crate::reduce_cross_join::ReduceCrossJoin; +use crate::reduce_outer_join::ReduceOuterJoin; +use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; +use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; +use crate::simplify_expressions::SimplifyExpressions; +use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; +use crate::subquery_filter_to_join::SubqueryFilterToJoin; +use crate::type_coercion::TypeCoercion; +use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use chrono::{DateTime, Utc}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; @@ -50,6 +68,8 @@ pub struct OptimizerConfig { next_id: usize, /// Option to skip rules that produce errors skip_failing_rules: bool, + /// Specify whether to enable the filter_null_keys rule + filter_null_keys: bool, } impl OptimizerConfig { @@ -59,9 +79,16 @@ impl OptimizerConfig { query_execution_start_time: chrono::Utc::now(), next_id: 0, // useful for generating things like unique subquery aliases skip_failing_rules: true, + filter_null_keys: true, } } + /// Specify whether to enable the filter_null_keys rule + pub fn filter_null_keys(mut self, filter_null_keys: bool) -> Self { + self.filter_null_keys = filter_null_keys; + self + } + /// Specify whether the optimizer should skip rules that produce /// errors, or fail the query pub fn with_query_execution_start_time( @@ -107,8 +134,35 @@ pub struct Optimizer { } impl Optimizer { + /// Create a new optimizer using the recommended list of rules + pub fn new(config: &OptimizerConfig) -> Self { + let mut rules: Vec> = vec![ + Arc::new(TypeCoercion::new()), + Arc::new(SimplifyExpressions::new()), + Arc::new(UnwrapCastInComparison::new()), + Arc::new(DecorrelateWhereExists::new()), + Arc::new(DecorrelateWhereIn::new()), + Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(SubqueryFilterToJoin::new()), + 
Arc::new(EliminateFilter::new()), + Arc::new(ReduceCrossJoin::new()), + Arc::new(CommonSubexprEliminate::new()), + Arc::new(EliminateLimit::new()), + Arc::new(ProjectionPushDown::new()), + Arc::new(RewriteDisjunctivePredicate::new()), + ]; + if config.filter_null_keys { + rules.push(Arc::new(FilterNullJoinKeys::default())); + } + rules.push(Arc::new(ReduceOuterJoin::new())); + rules.push(Arc::new(FilterPushDown::new())); + rules.push(Arc::new(LimitPushDown::new())); + rules.push(Arc::new(SingleDistinctToGroupBy::new())); + Self::with_rules(rules) + } + /// Create a new optimizer with the given rules - pub fn new(rules: Vec>) -> Self { + pub fn with_rules(rules: Vec>) -> Self { Self { rules } } @@ -172,7 +226,7 @@ mod tests { #[test] fn skip_failing_rule() -> Result<(), DataFusionError> { - let opt = Optimizer::new(vec![Arc::new(BadRule {})]); + let opt = Optimizer::with_rules(vec![Arc::new(BadRule {})]); let mut config = OptimizerConfig::new().with_skip_failing_rules(true); let plan = LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -184,7 +238,7 @@ mod tests { #[test] fn no_skip_failing_rule() -> Result<(), DataFusionError> { - let opt = Optimizer::new(vec![Arc::new(BadRule {})]); + let opt = Optimizer::with_rules(vec![Arc::new(BadRule {})]); let mut config = OptimizerConfig::new().with_skip_failing_rules(false); let plan = LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 7811e475c2de6..86f55e698505f 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -18,25 +18,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}; -use datafusion_optimizer::common_subexpr_eliminate::CommonSubexprEliminate; -use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; -use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; -use datafusion_optimizer::eliminate_filter::EliminateFilter; -use datafusion_optimizer::eliminate_limit::EliminateLimit; -use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; -use datafusion_optimizer::filter_push_down::FilterPushDown; -use datafusion_optimizer::limit_push_down::LimitPushDown; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::projection_push_down::ProjectionPushDown; -use datafusion_optimizer::reduce_cross_join::ReduceCrossJoin; -use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin; -use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; -use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; -use datafusion_optimizer::simplify_expressions::SimplifyExpressions; -use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; -use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin; -use datafusion_optimizer::type_coercion::TypeCoercion; -use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -104,31 +86,6 @@ fn between_date64_plus_interval() -> Result<()> { } fn test_sql(sql: &str) -> Result { - // TODO should make align with rules in the context - // 
https://github.com/apache/arrow-datafusion/issues/3524 - let rules: Vec> = vec![ - Arc::new(TypeCoercion::new()), - Arc::new(SimplifyExpressions::new()), - Arc::new(UnwrapCastInComparison::new()), - Arc::new(DecorrelateWhereExists::new()), - Arc::new(DecorrelateWhereIn::new()), - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(SubqueryFilterToJoin::new()), - Arc::new(EliminateFilter::new()), - Arc::new(CommonSubexprEliminate::new()), - Arc::new(EliminateLimit::new()), - Arc::new(ReduceCrossJoin::new()), - Arc::new(ProjectionPushDown::new()), - Arc::new(RewriteDisjunctivePredicate::new()), - Arc::new(FilterNullJoinKeys::default()), - Arc::new(ReduceOuterJoin::new()), - Arc::new(FilterPushDown::new()), - Arc::new(LimitPushDown::new()), - Arc::new(SingleDistinctToGroupBy::new()), - ]; - - let optimizer = Optimizer::new(rules); - // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); @@ -141,6 +98,7 @@ fn test_sql(sql: &str) -> Result { // optimize the logical plan let mut config = OptimizerConfig::new().with_skip_failing_rules(false); + let optimizer = Optimizer::new(&config); optimizer.optimize(&plan, &mut config, &observe) } From 3b627bb90f62a0de058ececa6900c3b9a35ac796 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 3 Oct 2022 10:08:28 -0600 Subject: [PATCH 2/6] rename variable --- datafusion/core/src/execution/context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 578b6cd494598..7284d3e8ac8f9 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1446,7 +1446,7 @@ impl SessionState { .register_catalog(config.default_catalog.clone(), default_catalog); } - let x = OptimizerConfig::new().filter_null_keys( + let optimizer_config = OptimizerConfig::new().filter_null_keys( config .config_options .read() @@ -1479,7 +1479,7 @@ impl SessionState { SessionState { session_id, - optimizer: Optimizer::new(&x), + optimizer: Optimizer::new(&optimizer_config), physical_optimizers, query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, From ae0a9cfaaa8ba41e3388a8f3c1734ee40dee630a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 3 Oct 2022 10:17:11 -0600 Subject: [PATCH 3/6] add failing test --- datafusion/optimizer/tests/integration-test.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 86f55e698505f..932dd5449b539 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -29,6 +29,15 @@ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +#[test] +fn case_when() -> Result<()> { + let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test"; + let plan = test_sql(sql)?; + let expected = "TBD"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn distribute_by() -> Result<()> { // regression test for https://github.com/apache/arrow-datafusion/issues/3234 From 0eee723625745b9ffa39694d3fa80a4b08837c81 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 3 Oct 2022 10:46:06 -0600 Subject: [PATCH 4/6] Fix optimizer regressions --- datafusion/optimizer/src/optimizer.rs | 2 +- datafusion/optimizer/tests/integration-test.rs | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git 
a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 5ef5cfdd59755..fc6e6be56400c 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -137,9 +137,9 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new(config: &OptimizerConfig) -> Self { let mut rules: Vec> = vec![ + Arc::new(UnwrapCastInComparison::new()), Arc::new(TypeCoercion::new()), Arc::new(SimplifyExpressions::new()), - Arc::new(UnwrapCastInComparison::new()), Arc::new(DecorrelateWhereExists::new()), Arc::new(DecorrelateWhereIn::new()), Arc::new(ScalarSubqueryToJoin::new()), diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 932dd5449b539..71efe86bd3d6d 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -31,9 +31,23 @@ use std::sync::Arc; #[test] fn case_when() -> Result<()> { + // regression test for https://github.com/apache/arrow-datafusion/issues/3690 let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test"; let plan = test_sql(sql)?; - let expected = "TBD"; + let expected = "Projection: CASE WHEN CAST(#test.col_int32 AS Int64) > Int64(0) THEN Int64(1) ELSE Int64(0) END\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + +#[test] +fn unsigned_target_type() -> Result<()> { + // regression test for https://github.com/apache/arrow-datafusion/issues/3690 + let sql = "SELECT * FROM test WHERE col_uint32 > 0"; + let plan = test_sql(sql)?; + let expected = "Projection: #test.col_int32, #test.col_uint32, #test.col_utf8, #test.col_date32, #test.col_date64\ + \n Filter: CAST(#test.col_uint32 AS Int64) > Int64(0)\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]"; assert_eq!(expected, format!("{:?}", plan)); Ok(()) } @@ -123,6 +137,7 @@ impl ContextProvider for MySchemaProvider { let schema = Schema::new_with_metadata( vec![ Field::new("col_int32", DataType::Int32, true), + Field::new("col_uint32", DataType::UInt32, true), Field::new("col_utf8", DataType::Utf8, true), Field::new("col_date32", DataType::Date32, true), Field::new("col_date64", DataType::Date64, true), From 49ba27c30d10c3cb539d90829a3f93d4bd816092 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 3 Oct 2022 11:13:46 -0600 Subject: [PATCH 5/6] revert https://github.com/apache/arrow-datafusion/pull/3662 --- datafusion/core/tests/sql/explain_analyze.rs | 20 +- datafusion/optimizer/src/lib.rs | 2 +- datafusion/optimizer/src/optimizer.rs | 4 +- ...rison.rs => pre_cast_lit_in_comparison.rs} | 340 +++++++----------- 4 files changed, 160 insertions(+), 206 deletions(-) rename datafusion/optimizer/src/{unwrap_cast_in_comparison.rs => pre_cast_lit_in_comparison.rs} (60%) diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 7d09d94834b13..fe51aedc8c954 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -767,6 +767,8 @@ async fn test_physical_plan_display_indent_multi_children() { #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn csv_explain() { + // TODO: https://github.com/apache/arrow-datafusion/issues/3622 refactor the `PreCastLitInComparisonExpressions` + // This test uses the execute function that create full plan cycle: logical, optimized logical, and physical, // then execute the physical plan and 
return the final explain results let ctx = SessionContext::new(); @@ -777,6 +779,23 @@ async fn csv_explain() { // Note can't use `assert_batches_eq` as the plan needs to be // normalized for filenames and number of cores + let expected = vec![ + vec![ + "logical_plan", + "Projection: #aggregate_test_100.c1\ + \n Filter: CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)\ + \n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)]" + ], + vec!["physical_plan", + "ProjectionExec: expr=[c1@0 as c1]\ + \n CoalesceBatchesExec: target_batch_size=4096\ + \n FilterExec: CAST(c2@1 AS Int32) > 10\ + \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ + \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\ + \n" + ]]; + assert_eq!(expected, actual); + let expected = vec![ vec![ "logical_plan", @@ -792,7 +811,6 @@ async fn csv_explain() { \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\ \n" ]]; - assert_eq!(expected, actual); let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10"; let actual = execute(&ctx, sql).await; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 879658c408ad6..bfb5634364d2d 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -35,9 +35,9 @@ pub mod subquery_filter_to_join; pub mod type_coercion; pub mod utils; +pub mod pre_cast_lit_in_comparison; pub mod rewrite_disjunctive_predicate; #[cfg(test)] pub mod test; -pub mod unwrap_cast_in_comparison; pub use optimizer::{OptimizerConfig, OptimizerRule}; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index fc6e6be56400c..1a3174ebf56f8 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -34,12 +34,12 @@ use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::subquery_filter_to_join::SubqueryFilterToJoin; use crate::type_coercion::TypeCoercion; -use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use chrono::{DateTime, Utc}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; use log::{debug, trace, warn}; use std::sync::Arc; +use crate::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions; /// `OptimizerRule` transforms one ['LogicalPlan'] into another which /// computes the same results, but in a potentially more efficient @@ -137,7 +137,7 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new(config: &OptimizerConfig) -> Self { let mut rules: Vec> = vec![ - Arc::new(UnwrapCastInComparison::new()), + Arc::new(PreCastLitInComparisonExpressions::new()), Arc::new(TypeCoercion::new()), Arc::new(SimplifyExpressions::new()), Arc::new(DecorrelateWhereExists::new()), diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs similarity index 60% rename from datafusion/optimizer/src/unwrap_cast_in_comparison.rs rename to datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index 0d5665f29e427..a6d915cf0161e 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! 
Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type -//! of expr can be added if needed. -//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. +//! Pre-cast literal binary comparison rule can be only used to the binary comparison expr. +//! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr. use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, @@ -29,14 +28,14 @@ use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; -/// The rule can be used to the numeric binary comparison with literal expr, like below pattern: -/// `cast(left_expr as data_type) comparison_op literal_expr` or `literal_expr comparison_op cast(right_expr as data_type)`. -/// The data type of two sides must be equal, and must be signed numeric type now, and will support more data type later. +/// The rule can be only used to the numeric binary comparison with literal expr, like below pattern: +/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op right_expr`. +/// The data type of two sides must be signed numeric type now, and will support more data type later. /// /// If the binary comparison expr match above rules, the optimizer will check if the value of `literal` /// is in within range(min,max) which is the range(min,max) of the data type for `left_expr` or `right_expr`. /// -/// If this is true, the literal expr will be casted to the data type of expr on the other side, and the result of +/// If this true, the literal expr will be casted to the data type of expr on the other side, and the result of /// binary comparison will be `left_expr comparison_op cast(literal_expr, left_data_type)` or /// `cast(literal_expr, right_data_type) comparison_op right_expr`. For better optimization, /// the expr of `cast(literal_expr, target_type)` will be precomputed and converted to the new expr `new_literal_expr` @@ -46,19 +45,19 @@ use datafusion_expr::{ /// This is inspired by the optimizer rule `UnwrapCastInBinaryComparison` of Spark. /// # Example /// -/// `Filter: cast(c1 as INT64) > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) AS INT32), +/// `Filter: c1 > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) AS INT32), /// and continue to be converted to `Filter: c1 > INT32(10)`, if the DataType of c1 is INT32. 
/// #[derive(Default)] -pub struct UnwrapCastInComparison {} +pub struct PreCastLitInComparisonExpressions {} -impl UnwrapCastInComparison { +impl PreCastLitInComparisonExpressions { pub fn new() -> Self { Self::default() } } -impl OptimizerRule for UnwrapCastInComparison { +impl OptimizerRule for PreCastLitInComparisonExpressions { fn optimize( &self, plan: &LogicalPlan, @@ -68,7 +67,7 @@ impl OptimizerRule for UnwrapCastInComparison { } fn name(&self) -> &str { - "unwrap_cast_in_comparison" + "pre_cast_lit_in_comparison" } } @@ -81,7 +80,7 @@ fn optimize(plan: &LogicalPlan) -> Result { let schema = plan.schema(); - let mut expr_rewriter = UnwrapCastExprRewriter { + let mut expr_rewriter = PreCastLitExprRewriter { schema: schema.clone(), }; @@ -94,20 +93,17 @@ fn optimize(plan: &LogicalPlan) -> Result { from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) } -struct UnwrapCastExprRewriter { +struct PreCastLitExprRewriter { schema: DFSchemaRef, } -impl ExprRewriter for UnwrapCastExprRewriter { +impl ExprRewriter for PreCastLitExprRewriter { fn pre_visit(&mut self, _expr: &Expr) -> Result { Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { match &expr { - // For case: - // try_cast/cast(expr as data_type) op literal - // literal op try_cast/cast(expr as data_type) Expr::BinaryExpr { left, op, right } => { let left = left.as_ref().clone(); let right = right.as_ref().clone(); @@ -117,48 +113,29 @@ impl ExprRewriter for UnwrapCastExprRewriter { if left_type.is_err() || right_type.is_err() { return Ok(expr.clone()); } - // Because the plan has been done the type coercion, the left and right must be equal let left_type = left_type?; let right_type = right_type?; - if is_support_data_type(&left_type) + if !left_type.eq(&right_type) + && is_support_data_type(&left_type) && is_support_data_type(&right_type) && is_comparison_op(op) { match (&left, &right) { - ( - Expr::Literal(left_lit_value), - Expr::TryCast { expr, .. } | Expr::Cast { expr, .. }, - ) => { - // if the left_lit_value can be casted to the type of expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let expr_type = expr.get_type(&self.schema)?; + (Expr::Literal(_), Expr::Literal(_)) => { + // do nothing + } + (Expr::Literal(left_lit_value), _) => { let casted_scalar_value = - try_cast_literal_to_type(left_lit_value, &expr_type)?; + try_cast_literal_to_type(left_lit_value, &right_type)?; if let Some(value) = casted_scalar_value { - // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( - lit(value), - *op, - expr.as_ref().clone(), - )); + return Ok(binary_expr(lit(value), *op, right)); } } - ( - Expr::TryCast { expr, .. } | Expr::Cast { expr, .. 
}, - Expr::Literal(right_lit_value), - ) => { - // if the right_lit_value can be casted to the type of expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let expr_type = expr.get_type(&self.schema)?; + (_, Expr::Literal(right_lit_value)) => { let casted_scalar_value = - try_cast_literal_to_type(right_lit_value, &expr_type)?; + try_cast_literal_to_type(right_lit_value, &left_type)?; if let Some(value) = casted_scalar_value { - // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( - expr.as_ref().clone(), - *op, - lit(value), - )); + return Ok(binary_expr(left, *op, lit(value))); } } (_, _) => { @@ -169,75 +146,55 @@ impl ExprRewriter for UnwrapCastExprRewriter { // return the new binary op Ok(binary_expr(left, *op, right)) } - // For case: - // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList { expr: left_expr, list, negated, } => { - if let Some( - Expr::TryCast { - expr: internal_left_expr, - .. - } - | Expr::Cast { - expr: internal_left_expr, - .. - }, - ) = Some(left_expr.as_ref()) - { - let internal_left = internal_left_expr.as_ref().clone(); - let internal_left_type = internal_left.get_type(&self.schema); - if internal_left_type.is_err() { - // error data type - return Ok(expr); - } - let internal_left_type = internal_left_type?; - if !is_support_data_type(&internal_left_type) { - // not supported data type - return Ok(expr); - } - let right_exprs = list - .iter() - .map(|right| { - let right_type = right.get_type(&self.schema)?; - if !is_support_data_type(&right_type) { - return Err(DataFusionError::Internal(format!( - "The type of list expr {} not support", - &right_type - ))); - } - match right { - Expr::Literal(right_lit_value) => { - // if the right_lit_value can be casted to the type of internal_left_expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let casted_scalar_value = - try_cast_literal_to_type(right_lit_value, &internal_left_type)?; - if let Some(value) = casted_scalar_value { - Ok(lit(value)) - } else { - Err(DataFusionError::Internal(format!( - "Can't cast the list expr {:?} to type {:?}", - right_lit_value, &internal_left_type - ))) - } + let left = left_expr.as_ref().clone(); + let left_type = left.get_type(&self.schema); + if left_type.is_err() { + // error data type + return Ok(expr); + } + let left_type = left_type?; + if !is_support_data_type(&left_type) { + // not supported data type + return Ok(expr); + } + let right_exprs = list + .iter() + .map(|right| { + let right_type = right.get_type(&self.schema)?; + if !is_support_data_type(&right_type) { + return Err(DataFusionError::Internal(format!( + "The type of list expr {} not support", + &right_type + ))); + } + match right { + Expr::Literal(right_lit_value) => { + let casted_scalar_value = + try_cast_literal_to_type(right_lit_value, &left_type)?; + if let Some(value) = casted_scalar_value { + Ok(lit(value)) + } else { + Err(DataFusionError::Internal(format!( + "Can't cast the list expr {:?} to type {:?}", + right_lit_value, &left_type + ))) } - other_expr => Err(DataFusionError::Internal(format!( - "Only support literal expr to optimize, but the expr is {:?}", - &other_expr - ))), } - }) - .collect::>>(); - match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) + other_expr => Err(DataFusionError::Internal(format!( + "Only support literal expr to optimize, but the expr is {:?}", + &other_expr + ))), } - Err(_) => Ok(expr), - } - } else { - Ok(expr) + }) + 
.collect::>>(); + match right_exprs { + Ok(right_exprs) => Ok(in_list(left, right_exprs, *negated)), + Err(_) => Ok(expr), } } // TODO: handle other expr type and dfs visit them @@ -369,19 +326,23 @@ fn try_cast_literal_to_type( #[cfg(test)] mod tests { - use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter; + use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter; use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; - use datafusion_expr::{cast, col, lit, try_cast, Expr}; + use datafusion_expr::{col, lit, Expr}; use std::collections::HashMap; use std::sync::Arc; #[test] - fn test_not_unwrap_cast_comparison() { + fn test_not_cast_lit_comparison() { let schema = expr_test_schema(); - // cast(INT32(c1), INT64) > INT64(c2) - let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2")); + // INT8(NULL) < INT32(12) + let lit_lt_lit = + lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12)))); + assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit); + // INT32(c1) > INT64(c2) + let c1_gt_c2 = col("c1").gt(col("c2")); assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2); // INT32(c1) < INT32(16), the type is same @@ -389,132 +350,110 @@ mod tests { assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type - let expr_lt = cast(col("c1"), DataType::Int64) - .lt(lit(ScalarValue::Int64(Some(99999999999)))); + let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(99999999999)))); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); } #[test] - fn test_unwrap_cast_comparison() { + fn test_pre_cast_lit_comparison() { let schema = expr_test_schema(); - // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16)) + // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = - cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(Some(16)))); - let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))); - assert_eq!(optimize_test(expr_lt, &schema), expected); - let expr_lt = - try_cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(Some(16)))); + let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))); let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))); assert_eq!(optimize_test(expr_lt, &schema), expected); - // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16) - let c2_eq_lit = - cast(col("c2"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(16)))); + // INT64(c2) = INT32(16) => INT64(c2) = INT64(16) + let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16)))); let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16)))); assert_eq!(optimize_test(c2_eq_lit, &schema), expected); - // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL) - let c1_lt_lit_null = - cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(None))); + // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL) + let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None))); let expected = col("c1").lt(lit(ScalarValue::Int32(None))); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); - - // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) - let lit_lt_lit = cast(lit(ScalarValue::Int8(None)), DataType::Int32) - .lt(lit(ScalarValue::Int32(Some(12)))); - let expected = 
lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int8(Some(12)))); - assert_eq!(optimize_test(lit_lt_lit, &schema), expected); } #[test] - fn test_not_unwrap_cast_with_decimal_comparison() { + fn test_not_cast_with_decimal_lit_comparison() { let schema = expr_test_schema(); // integer to decimal: value is out of the bounds of the decimal - // cast(c3, INT64) = INT64(100000000000000000) - let expr_eq = cast(col("c3"), DataType::Int64) - .eq(lit(ScalarValue::Int64(Some(100000000000000000)))); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); - - // cast(c4, INT64) = INT64(1000) will overflow the i128 - let expr_eq = - cast(col("c4"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1000)))); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + // c3 = INT64(100000000000000000) + let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + assert_eq!(optimize_test(expr_eq, &schema), expected); + // c4 = INT64(1000) will overflow the i128 + let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); + let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); + assert_eq!(optimize_test(expr_eq, &schema), expected); // decimal to decimal: value will lose the scale when convert to the target data type // c3 = DECIMAL(12340,20,4) - let expr_eq = cast(col("c3"), DataType::Decimal128(20, 4)) - .eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + assert_eq!(optimize_test(expr_eq, &schema), expected); // decimal to integer // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type - let expr_eq = cast(col("c1"), DataType::Decimal128(10, 1)) - .eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); - + let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + assert_eq!(optimize_test(expr_eq, &schema), expected); // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type - let expr_eq = cast(col("c1"), DataType::Decimal128(10, 2)) - .eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + assert_eq!(optimize_test(expr_eq, &schema), expected); } #[test] - fn test_unwrap_cast_with_decimal_lit_comparison() { + fn test_pre_cast_with_decimal_lit_comparison() { let schema = expr_test_schema(); // integer to decimal // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); - let expr_lt = - try_cast(col("c3"), DataType::Int64).lt(lit(ScalarValue::Int64(Some(16)))); + let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16)))); let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2))); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < INT64(NULL) - let c1_lt_lit_null = - cast(col("c3"), DataType::Int64).lt(lit(ScalarValue::Int64(None))); + let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None))); let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 
2))); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // decimal to decimal // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) - let expr_lt = cast(col("c3"), DataType::Decimal128(10, 0)) - .lt(lit(ScalarValue::Decimal128(Some(123), 10, 0))); + let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0))); let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2))); assert_eq!(optimize_test(expr_lt, &schema), expected); - // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) - let expr_lt = cast(col("c3"), DataType::Decimal128(10, 3)) - .lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3))); + let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3))); let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2))); assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal to integer // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) - let expr_lt = cast(col("c1"), DataType::Decimal128(10, 2)) - .lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2))); + let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2))); let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123)))); assert_eq!(optimize_test(expr_lt, &schema), expected); } #[test] - fn test_not_unwrap_list_cast_lit_comparison() { + fn test_not_list_cast_lit_comparison() { let schema = expr_test_schema(); - // internal left type is not supported + // left type is not supported // FLOAT32(C5) in ... - let expr_lt = cast(col("c5"), DataType::Int64).in_list( + let expr_lt = col("c5").in_list( vec![ lit(ScalarValue::Int64(Some(12))), - lit(ScalarValue::Int64(Some(12))), + lit(ScalarValue::Int32(Some(12))), ], false, ); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12)) - let expr_lt = cast(col("c1"), DataType::Float32).in_list( + // INT32(C1) in (FLOAT32(1.23), INT32(12), INT64(12)) + let expr_lt = col("c1").in_list( vec![ - lit(ScalarValue::Float32(Some(12.0))), - lit(ScalarValue::Float32(Some(12.0))), + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(12))), lit(ScalarValue::Float32(Some(1.23))), ], false, @@ -522,7 +461,7 @@ mod tests { assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // INT32(C1) in (INT64(99999999999), INT64(12)) - let expr_lt = cast(col("c1"), DataType::Int64).in_list( + let expr_lt = col("c1").in_list( vec![ lit(ScalarValue::Int32(Some(12))), lit(ScalarValue::Int64(Some(99999999999))), @@ -532,10 +471,10 @@ mod tests { assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) - let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list( + let expr_lt = col("c3").in_list( vec![ - lit(ScalarValue::Decimal128(Some(12), 12, 3)), - lit(ScalarValue::Decimal128(Some(12), 12, 3)), + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(12))), lit(ScalarValue::Decimal128(Some(128), 12, 3)), ], false, @@ -544,12 +483,12 @@ mod tests { } #[test] - fn test_unwrap_list_cast_comparison() { + fn test_pre_list_cast_lit_comparison() { let schema = expr_test_schema(); // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = cast(col("c1"), DataType::Int64).in_list( + let expr_lt = col("c1").in_list( vec![ - lit(ScalarValue::Int64(Some(12))), + 
lit(ScalarValue::Int32(Some(12))), lit(ScalarValue::Int64(Some(24))), ], false, @@ -563,9 +502,9 @@ mod tests { ); assert_eq!(optimize_test(expr_lt, &schema), expected); // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = cast(col("c2"), DataType::Int32).in_list( + let expr_lt = col("c2").in_list( vec![ - lit(ScalarValue::Int32(None)), + lit(ScalarValue::Int64(None)), lit(ScalarValue::Int32(Some(14))), ], false, @@ -581,13 +520,12 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal test case - // c3 is decimal(18,2) - let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list( + let expr_lt = col("c3").in_list( vec![ - lit(ScalarValue::Decimal128(Some(12000), 19, 3)), - lit(ScalarValue::Decimal128(Some(24000), 19, 3)), - lit(ScalarValue::Decimal128(Some(1280), 19, 3)), - lit(ScalarValue::Decimal128(Some(1240), 19, 3)), + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(24))), + lit(ScalarValue::Decimal128(Some(128), 10, 2)), + lit(ScalarValue::Decimal128(Some(1280), 10, 3)), ], false, ); @@ -596,23 +534,23 @@ mod tests { lit(ScalarValue::Decimal128(Some(1200), 18, 2)), lit(ScalarValue::Decimal128(Some(2400), 18, 2)), lit(ScalarValue::Decimal128(Some(128), 18, 2)), - lit(ScalarValue::Decimal128(Some(124), 18, 2)), + lit(ScalarValue::Decimal128(Some(128), 18, 2)), ], false, ); assert_eq!(optimize_test(expr_lt, &schema), expected); - // cast(INT32(12), INT64) IN (.....) - let expr_lt = cast(lit(ScalarValue::Int32(Some(12))), DataType::Int64).in_list( + // INT32(12) IN (.....) + let expr_lt = lit(ScalarValue::Int32(Some(12))).in_list( vec![ - lit(ScalarValue::Int64(Some(13))), + lit(ScalarValue::Int32(Some(12))), lit(ScalarValue::Int64(Some(12))), ], false, ); let expected = lit(ScalarValue::Int32(Some(12))).in_list( vec![ - lit(ScalarValue::Int32(Some(13))), + lit(ScalarValue::Int32(Some(12))), lit(ScalarValue::Int32(Some(12))), ], false, @@ -625,9 +563,7 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64) - .lt(lit(ScalarValue::Int64(Some(16)))) - .alias("x"); + let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))).alias("x"); let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x"); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -637,9 +573,9 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 - let expr_lt = cast(col("c1"), DataType::Int64) + let expr_lt = col("c1") .lt(lit(ScalarValue::Int64(Some(16)))) - .or(cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(32))))); + .or(col("c1").gt(lit(ScalarValue::Int64(Some(32))))); let expected = col("c1") .lt(lit(ScalarValue::Int32(Some(16)))) .or(col("c1").gt(lit(ScalarValue::Int32(Some(32))))); @@ -647,7 +583,7 @@ mod tests { } fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - let mut expr_rewriter = UnwrapCastExprRewriter { + let mut expr_rewriter = PreCastLitExprRewriter { schema: schema.clone(), }; expr.rewrite(&mut expr_rewriter).unwrap() From 653ab0425604b046fc12c77830983aa64d1a4dee Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 3 Oct 2022 11:40:34 -0600 Subject: [PATCH 6/6] revert tests --- datafusion/optimizer/src/optimizer.rs | 2 +- 
.../src/pre_cast_lit_in_comparison.rs | 75 ++++++------------- 2 files changed, 24 insertions(+), 53 deletions(-) diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 1a3174ebf56f8..b59fdae59b071 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -25,6 +25,7 @@ use crate::eliminate_limit::EliminateLimit; use crate::filter_null_join_keys::FilterNullJoinKeys; use crate::filter_push_down::FilterPushDown; use crate::limit_push_down::LimitPushDown; +use crate::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions; use crate::projection_push_down::ProjectionPushDown; use crate::reduce_cross_join::ReduceCrossJoin; use crate::reduce_outer_join::ReduceOuterJoin; @@ -39,7 +40,6 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; use log::{debug, trace, warn}; use std::sync::Arc; -use crate::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions; /// `OptimizerRule` transforms one ['LogicalPlan'] into another which /// computes the same results, but in a potentially more efficient diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index dbbf8a5473648..382f5bfb22064 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -330,7 +330,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; - use datafusion_expr::{col, lit, Expr}; + use datafusion_expr::{cast, col, lit, Expr}; use std::collections::HashMap; use std::sync::Arc; @@ -359,34 +359,25 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)); - let expected = col("c1").lt(lit(16i32)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64)); + let expr_lt = col("c1").lt(lit(16i64)); let expected = col("c1").lt(lit(16i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); - // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16) - let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32)); + // // INT64(c2) = INT32(16) => INT64(c2) = INT64(16) + let c2_eq_lit = col("c2").eq(lit(16i32)); let expected = col("c2").eq(lit(16i64)); assert_eq!(optimize_test(c2_eq_lit, &schema), expected); - // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL) - let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64()); + // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL) + let c1_lt_lit_null = col("c1").lt(null_i64()); let expected = col("c1").lt(null_i32()); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); - - // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) - let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); - let expected = null_i8().lt(lit(12i8)); - assert_eq!(optimize_test(lit_lt_lit, &schema), expected); } #[test] fn test_not_cast_with_decimal_lit_comparison() { let schema = expr_test_schema(); // integer to decimal: value is out of the bounds of the decimal - // cast(c3, INT64) = INT64(100000000000000000) let expr_eq = cast(col("c3"), DataType::Int64).eq(lit(100000000000000000i64)); 
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); @@ -418,32 +409,29 @@ mod tests { let schema = expr_test_schema(); // integer to decimal // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); - - let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64)); + let expr_lt = col("c3").lt(lit(16i64)); let expected = col("c3").lt(lit_decimal(1600, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < INT64(NULL) - let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64()); + let c1_lt_lit_null = col("c3").lt(null_i64()); let expected = col("c3").lt(null_decimal(18, 2)); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // decimal to decimal // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) - let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0)); + let expr_lt = col("c3").lt(lit_decimal(123, 10, 0)); let expected = col("c3").lt(lit_decimal(12300, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); + // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) - let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3)); + let expr_lt = col("c3").lt(lit_decimal(1230, 10, 3)); let expected = col("c3").lt(lit_decimal(123, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal to integer // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) - let expr_lt = - cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2)); + let expr_lt = col("c1").lt(lit_decimal(12300, 10, 2)); let expected = col("c1").lt(lit(123i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -453,27 +441,21 @@ mod tests { let schema = expr_test_schema(); // left type is not supported // FLOAT32(C5) in ... 
- let expr_lt = - cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false); + let expr_lt = col("c5").in_list(vec![lit(12i64), lit(12i32)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12)) - let expr_lt = cast(col("c1"), DataType::Float32) - .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false); + // INT32(C1) in (FLOAT32(1.23), INT32(12), INT64(12)) + let expr_lt = + col("c1").in_list(vec![lit(12.0_f32), lit(12_i32), lit(12_i64)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // INT32(C1) in (INT64(99999999999), INT64(12)) - let expr_lt = cast(col("c1"), DataType::Int64) - .in_list(vec![lit(12i32), lit(99999999999i64)], false); + let expr_lt = col("c1").in_list(vec![lit(12i32), lit(99999999999i64)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) let expr_lt = col("c3").in_list( - vec![ - lit_decimal(12, 12, 3), - lit_decimal(12, 12, 3), - lit_decimal(128, 12, 3), - ], + vec![lit(12_i64), lit(12_i32), lit_decimal(128, 12, 3)], false, ); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); @@ -483,13 +465,11 @@ mod tests { fn test_pre_list_cast_lit_comparison() { let schema = expr_test_schema(); // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = - cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64), lit(24i64)], false); + let expr_lt = col("c1").in_list(vec![lit(12i64), lit(24i64)], false); let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = - cast(col("c2"), DataType::Int32).in_list(vec![null_i32(), lit(14i32)], false); + let expr_lt = col("c2").in_list(vec![null_i32(), lit(14i32)], false); let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); @@ -516,8 +496,7 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); // cast(INT32(12), INT64) IN (.....) 
- let expr_lt = cast(lit(12i32), DataType::Int64) - .in_list(vec![lit(13i64), lit(12i64)], false); + let expr_lt = lit(12i32).in_list(vec![lit(13i64), lit(12i64)], false); let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -527,7 +506,7 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x"); + let expr_lt = col("c1").lt(lit(16i64)).alias("x"); let expected = col("c1").lt(lit(16i32)).alias("x"); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -537,11 +516,7 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast( - col("c1"), - DataType::Int64, - ) - .gt(lit(32i64))); + let expr_lt = col("c1").lt(lit(16i64)).or(col("c1").gt(lit(32i64))); let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32))); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -569,10 +544,6 @@ mod tests { ) } - fn null_i8() -> Expr { - lit(ScalarValue::Int8(None)) - } - fn null_i32() -> Expr { lit(ScalarValue::Int32(None)) }
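
After this series, `Optimizer::new(&OptimizerConfig)` builds the recommended rule set from the config, while the old constructor survives as `Optimizer::with_rules` for callers that supply an explicit rule list. Below is a minimal usage sketch of both entry points, not code from the patches: the helper names (`optimize_with_defaults`, `custom_optimizer`), the no-op `observe` body, and the `Send + Sync` bound on the rule trait object (the patch text elides generic parameters) are illustrative assumptions; the crate paths follow the imports shown in the integration test.

    use std::sync::Arc;

    use datafusion_common::Result;
    use datafusion_expr::LogicalPlan;
    use datafusion_optimizer::optimizer::Optimizer;
    use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
    use datafusion_optimizer::{OptimizerConfig, OptimizerRule};

    // No-op observer, mirroring the `observe` helper referenced by the integration test.
    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

    // Build the recommended optimizer from an OptimizerConfig and run it on a plan.
    fn optimize_with_defaults(plan: &LogicalPlan) -> Result<LogicalPlan> {
        // filter_null_keys(false) drops the FilterNullJoinKeys rule, matching what
        // SessionState does when OPT_FILTER_NULL_JOIN_KEYS is disabled.
        let mut config = OptimizerConfig::new()
            .filter_null_keys(false)
            .with_skip_failing_rules(false);
        let optimizer = Optimizer::new(&config);
        optimizer.optimize(plan, &mut config, &observe)
    }

    // Callers that need full control over the rule list still pass it explicitly.
    fn custom_optimizer() -> Optimizer {
        Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())])
    }

Note that a single `OptimizerConfig` value plays two roles: it selects the default rules at construction time (e.g. `filter_null_keys`) and is then passed as mutable state to `optimize`, which is why it is created before the `Optimizer` and handed in again by mutable reference.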