Skip to content
10 changes: 2 additions & 8 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1466,10 +1466,9 @@ impl SessionState {
}

let mut rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
Arc::new(PreCastLitInComparisonExpressions::new()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have TypeCoercion I wonder if we still need PreCastLitInComparisonExpressions

https://github.com/apache/arrow-datafusion/blob/d16457a0ba129b077935078e5cf89d028f598e0b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs#L31-L50

Appears to be a subset of what TypeCoercion is doing

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though perhaps this is something similar to what you have described in #3622

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PreCastLitInComparisonExpressions will be invalid when we move the type coercion to the front of PreCastLitInComparisonExpressions.

In the next pr, I will do #3622 and move the type coercion to the front of the PreCastLitInComparisonExpressions

I think type coercion just do one thing which is make type compatible in all operation, but the PreCastLitInComparisonExpressions or #3622 is to reduce cast for column expr or other expr instead of adding the cast to the literal. This will reduce the cast effort of the runtime.

Arc::new(TypeCoercion::new()),
Arc::new(SimplifyExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Expand All @@ -1490,11 +1489,6 @@ impl SessionState {
rules.push(Arc::new(FilterNullJoinKeys::default()));
}
rules.push(Arc::new(ReduceOuterJoin::new()));
// TODO: https://github.com/apache/arrow-datafusion/issues/3557
// remove this, after the issue fixed.
rules.push(Arc::new(TypeCoercion::new()));
// after the type coercion, can do simplify expression again
rules.push(Arc::new(SimplifyExpressions::new()));
rules.push(Arc::new(FilterPushDown::new()));
rules.push(Arc::new(LimitPushDown::new()));
rules.push(Arc::new(SingleDistinctToGroupBy::new()));
Expand Down
21 changes: 19 additions & 2 deletions datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@ async fn test_physical_plan_display_indent_multi_children() {
#[tokio::test]
#[cfg_attr(tarpaulin, ignore)]
async fn csv_explain() {
// TODO: https://github.com/apache/arrow-datafusion/issues/3622 refactor the `PreCastLitInComparisonExpressions`

// This test uses the execute function that create full plan cycle: logical, optimized logical, and physical,
// then execute the physical plan and return the final explain results
let ctx = SessionContext::new();
Expand All @@ -777,6 +779,23 @@ async fn csv_explain() {

// Note can't use `assert_batches_eq` as the plan needs to be
// normalized for filenames and number of cores
let expected = vec![
vec![
"logical_plan",
"Projection: #aggregate_test_100.c1\
\n Filter: CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)\
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)]"
],
vec!["physical_plan",
"ProjectionExec: expr=[c1@0 as c1]\
\n CoalesceBatchesExec: target_batch_size=4096\
\n FilterExec: CAST(c2@1 AS Int32) > 10\
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
\n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\
\n"
]];
assert_eq!(expected, actual);

let expected = vec![
vec![
"logical_plan",
Expand All @@ -792,9 +811,7 @@ async fn csv_explain() {
\n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\
\n"
]];
assert_eq!(expected, actual);

// Also, expect same result with lowercase explain
let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10";
let actual = execute(&ctx, sql).await;
let actual = normalize_vec_for_explain(actual);
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ async fn csv_in_set_test() -> Result<()> {
}

#[tokio::test]
#[ignore]
// https://github.com/apache/arrow-datafusion/issues/3635
async fn multiple_or_predicates() -> Result<()> {
// TODO https://github.com/apache/arrow-datafusion/issues/3587
let ctx = SessionContext::new();
Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,10 @@ order by s_name;
Projection: #part.p_partkey AS p_partkey, alias=__sq_1
Filter: #part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name], partial_filters=[#part.p_name LIKE Utf8("forest%")]
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a regression somehow -- the CAST hasn't been evaluated into a constant

Copy link
Contributor Author

@liukun4515 liukun4515 Sep 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the desired result.
I think the evaluating of cast is done in the SimplifyExpressions rule, but the rule don't support the subquery now.
In this pr, we move the SimplifyExpressions to the front of the some rule about subquery, so the optimization of SimplifyExpressions can't apply some expr which is in the subquery.

If we want to do the evaluation cast for the expr in the subquery, there are two way:

  1. add additional SimplifyExpressions in the tail of the optimizer rules
  2. support subquery case for SimplifyExpressions rule.

If we don't want the regression in this pr, I can use the method of 1 to fix it temporarily.
@alamb

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should avoid regressions -- so perhaps we can do 1 as a temporary workaround and then support subquery in SimplifyExpressons as a follow on PR (I am trying to get some time next week to help out in some of these issues)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will recovery them in the follow up pr

Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]]
Filter: #lineitem.l_shipdate >= Date32("8766")
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= Date32("8766")]"#
Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise, this isn't right because it should have been evaluated to a constant

TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
.to_string();
assert_eq!(actual, expected);

Expand Down Expand Up @@ -393,8 +393,8 @@ order by cntrycode;"#;
TableScan: orders projection=[o_custkey]
Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(#customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15), substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(#customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)), substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
.to_string();
assert_eq!(actual, expected);

Expand Down Expand Up @@ -453,7 +453,7 @@ order by value desc;
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: #nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")]
Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
Expand Down
5 changes: 5 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1410,6 +1410,11 @@ pub struct Subquery {
}

impl Subquery {
pub fn new(plan: LogicalPlan) -> Self {
Subquery {
subquery: Arc::new(plan),
}
}
pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> {
match plan {
Expr::ScalarSubquery(it) => Ok(it),
Expand Down
164 changes: 118 additions & 46 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::binary_rule::{coerce_types, comparison_coercion};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
Expand Down Expand Up @@ -50,56 +51,70 @@ impl OptimizerRule for TypeCoercion {
plan: &LogicalPlan,
optimizer_config: &mut OptimizerConfig,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| self.optimize(p, optimizer_config))
.collect::<Result<Vec<_>>>()?;

// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let schema = new_inputs.iter().map(|input| input.schema()).fold(
DFSchema::empty(),
|mut lhs, rhs| {
lhs.merge(rhs);
lhs
},
);
optimize_internal(&DFSchema::empty(), plan, optimizer_config)
}
}

let mut expr_rewrite = TypeCoercionRewriter {
schema: Arc::new(schema),
};
fn optimize_internal(
// use the external schema to handle the correlated subqueries case
external_schema: &DFSchema,
plan: &LogicalPlan,
optimizer_config: &mut OptimizerConfig,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| optimize_internal(external_schema, p, optimizer_config))
.collect::<Result<Vec<_>>>()?;

// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = new_inputs.iter().map(|input| input.schema()).fold(
DFSchema::empty(),
|mut lhs, rhs| {
lhs.merge(rhs);
lhs
},
);

// merge the outer schema for correlated subqueries
// like case:
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
schema.merge(external_schema);

let mut expr_rewrite = TypeCoercionRewriter {
schema: Arc::new(schema),
};

let original_expr_names: Vec<Option<String>> = plan
.expressions()
.iter()
.map(|expr| expr.name().ok())
.collect();

let new_expr = plan
.expressions()
.into_iter()
.zip(original_expr_names)
.map(|(expr, original_name)| {
let expr = expr.rewrite(&mut expr_rewrite)?;

// ensure aggregate names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
if matches!(expr, Expr::AggregateFunction { .. }) {
if let Some((alias, name)) = original_name.zip(expr.name().ok()) {
if alias != name {
return Ok(expr.alias(&alias));
}
let original_expr_names: Vec<Option<String>> = plan
.expressions()
.iter()
.map(|expr| expr.name().ok())
.collect();

let new_expr = plan
.expressions()
.into_iter()
.zip(original_expr_names)
.map(|(expr, original_name)| {
let expr = expr.rewrite(&mut expr_rewrite)?;

// ensure aggregate names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
if matches!(expr, Expr::AggregateFunction { .. }) {
if let Some((alias, name)) = original_name.zip(expr.name().ok()) {
if alias != name {
return Ok(expr.alias(&alias));
}
}
}

Ok(expr)
})
.collect::<Result<Vec<_>>>()?;
Ok(expr)
})
.collect::<Result<Vec<_>>>()?;

from_plan(plan, &new_expr, &new_inputs)
}
from_plan(plan, &new_expr, &new_inputs)
}

pub(crate) struct TypeCoercionRewriter {
Expand All @@ -119,6 +134,41 @@ impl ExprRewriter for TypeCoercionRewriter {

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match expr {
Expr::ScalarSubquery(Subquery { subquery }) => {
let mut optimizer_config = OptimizerConfig::new();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should pass through the same OptimizerConfig (so it has the same context needed to evaluate now() rather than creating new ones

Copy link
Contributor Author

@liukun4515 liukun4515 Sep 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After we move the type coercion out of the optimizer framework, we don't need the parameter of OptimizerConfig to iterate the plan when doing the type coercion.

I think the new of OptimizerConfig will not affect the rule of type coercion, because the rule of type coercion don't use the OptimizerConfig to generate some values.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when we complete this issue #3582, this parameter will be dropped.
We don't need to concern this.

let new_plan =
optimize_internal(&self.schema, &subquery, &mut optimizer_config)?;
Ok(Expr::ScalarSubquery(Subquery::new(new_plan)))
}
Expr::Exists { subquery, negated } => {
let mut optimizer_config = OptimizerConfig::new();
let new_plan = optimize_internal(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why the outer query's schema needs to be passed while optimizing the subquery. I would expect that we would recursively optimize the subquery. The subquery doesn't have the same schema as its containing query 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in I think this should be more like:

let optimizer = TypeCoercion::new();
let new_plan = optimizer.optimize(&subquery.subquery, &mut optimizer_config)

Copy link
Contributor Author

@liukun4515 liukun4515 Sep 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also want this, but there are some correlated subquery, like the comments.
I have added the comments about the usage of external_schema schema
// merge the outer schema for correlated subqueries
// like case:
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)

For the subquery select t2.c1 from t2 where t2.c2= t1.c3, if we don't know the schema about the t1, we can't get the data type for t1.c3 and can't do the type coercion for t2.c2=t1.c3.

This is way that I find, do you @andygrove @alamb have any thoughts about this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would need to try it myself

&self.schema,
&subquery.subquery,
&mut optimizer_config,
)?;
Ok(Expr::Exists {
subquery: Subquery::new(new_plan),
negated,
})
}
Expr::InSubquery {
expr,
subquery,
negated,
} => {
let mut optimizer_config = OptimizerConfig::new();
let new_plan = optimize_internal(
&self.schema,
&subquery.subquery,
&mut optimizer_config,
)?;
Ok(Expr::InSubquery {
expr,
subquery: Subquery::new(new_plan),
negated,
})
}
Expr::IsTrue(expr) => {
let expr = is_true(get_casted_expr_for_bool_op(&expr, &self.schema)?);
Ok(expr)
Expand Down Expand Up @@ -368,11 +418,12 @@ fn coerce_arguments_for_signature(

#[cfg(test)]
mod test {
use crate::type_coercion::TypeCoercion;
use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::{col, ColumnarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{cast, col, is_true, ColumnarValue};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expand Down Expand Up @@ -735,4 +786,25 @@ mod test {
),
}))
}

#[test]
fn test_type_coercion_rewrite() -> Result<()> {
let schema = Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Int64, true)],
std::collections::HashMap::new(),
)
.unwrap(),
);
let mut rewriter = TypeCoercionRewriter::new(schema);
let expr = is_true(lit(12i32).eq(lit(13i64)));
let expected = is_true(
cast(lit(ScalarValue::Int32(Some(12))), DataType::Int64)
.eq(lit(ScalarValue::Int64(Some(13)))),
);
let result = expr.rewrite(&mut rewriter)?;
assert_eq!(expected, result);
Ok(())
// TODO add more test for this
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree some more tests would be good

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
}
8 changes: 2 additions & 6 deletions datafusion/optimizer/tests/integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
// TODO should make align with rules in the context
// https://github.com/apache/arrow-datafusion/issues/3524
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(TypeCoercion::new()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Arc::new(SimplifyExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Expand All @@ -125,9 +124,6 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
Arc::new(RewriteDisjunctivePredicate::new()),
Arc::new(FilterNullJoinKeys::default()),
Arc::new(ReduceOuterJoin::new()),
Arc::new(TypeCoercion::new()),
// after the type coercion, can do simplify expression again
Arc::new(SimplifyExpressions::new()),
Arc::new(FilterPushDown::new()),
Arc::new(LimitPushDown::new()),
Arc::new(SingleDistinctToGroupBy::new()),
Expand Down