Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ impl Expr {
}

/// Return `self AS name` alias expression
pub fn alias(self, name: &str) -> Expr {
Expr::Alias(Box::new(self), name.to_owned())
pub fn alias(self, name: impl Into<String>) -> Expr {
Expr::Alias(Box::new(self), name.into())
}

/// Return `self IN <list>` if `negated` is false, otherwise
Expand Down
33 changes: 10 additions & 23 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

//! Optimizer rule for type validation and coercion

use crate::utils::rewrite_preserving_name;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::expr::Case;
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion};
use datafusion_expr::type_coercion::functions::data_types;
Expand Down Expand Up @@ -91,30 +92,13 @@ fn optimize_internal(
schema: Arc::new(schema),
};

let original_expr_names: Vec<Option<String>> = plan
.expressions()
.iter()
.map(|expr| expr.name().ok())
.collect();

let new_expr = plan
.expressions()
.into_iter()
.zip(original_expr_names)
.map(|(expr, original_name)| {
let expr = expr.rewrite(&mut expr_rewrite)?;

.map(|expr| {
// ensure aggregate names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
if matches!(expr, Expr::AggregateFunction { .. }) {
if let Some((alias, name)) = original_name.zip(expr.name().ok()) {
if alias != name {
return Ok(expr.alias(&alias));
}
}
}

Ok(expr)
rewrite_preserving_name(expr, &mut expr_rewrite)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

})
.collect::<Result<Vec<_>>>()?;

Expand Down Expand Up @@ -635,7 +619,8 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation",
"Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @liukun4515 -- I think this is an improvement (and maybe a bug fix 🤔 )

\n EmptyRelation",
&format!("{:?}", plan)
);
// a in (1,4,8), a is decimal
Expand All @@ -653,7 +638,8 @@ mod test {
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation",
"Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
Expand Down Expand Up @@ -751,7 +737,8 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config).unwrap();
assert_eq!(
"Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation",
"Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \
\n EmptyRelation",
&format!("{:?}", plan)
);

Expand Down
40 changes: 3 additions & 37 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
//! The unwrap-cast comparison rule currently applies to binary and in-list comparison expressions; other kinds
//! of expressions can be added if needed.
//! This rule avoids adding an `Expr::Cast` to the expression by casting the literal expression instead.
use crate::utils::rewrite_preserving_name;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
Expand Down Expand Up @@ -97,47 +98,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_exprs = plan
.expressions()
.into_iter()
.map(|expr| {
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(&mut expr_rewriter)?;
add_alias_if_changed(&original_name, expr)
})
.map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this PR basically refactors the code into rewrite_preserving_name, adds tests, and calls it in a few places

.collect::<Result<Vec<_>>>()?;

from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}

/// Returns the name used when deciding whether `expr` needs an alias.
///
/// For an `Expr::Sort` the name of the wrapped expression is used
/// (recursively); every other expression reports `Expr::name()`.
fn name_for_alias(expr: &Expr) -> Result<String> {
    if let Expr::Sort { expr: sorted, .. } = expr {
        name_for_alias(sorted)
    } else {
        expr.name()
    }
}

/// Returns `expr` named as `original_name`, aliasing it only when its
/// current name (as computed by `name_for_alias`) differs.
///
/// For an `Expr::Sort` the alias is applied to the wrapped expression
/// and the sort options (`asc`, `nulls_first`) are kept as they were.
fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
    // Nothing to do when the name is already the expected one.
    if name_for_alias(&expr)? == original_name {
        return Ok(expr);
    }

    match expr {
        Expr::Sort {
            expr: sorted,
            asc,
            nulls_first,
        } => {
            // Alias the expression being sorted, then rebuild the sort around it.
            let sorted = add_alias_if_changed(original_name, *sorted)?;
            Ok(Expr::Sort {
                expr: Box::new(sorted),
                asc,
                nulls_first,
            })
        }
        other => Ok(other.alias(original_name)),
    }
}

struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
Expand Down
122 changes: 121 additions & 1 deletion datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use datafusion_expr::{
and, col, combine_filters,
Expand Down Expand Up @@ -315,13 +316,63 @@ pub fn alias_cols(cols: &[Column]) -> Vec<Expr> {
.collect()
}

/// Applies `rewriter` to `expr` and returns the rewritten expression,
/// wrapping it in an alias when the rewrite changed its name so that
/// the result is still named as `expr` was before the rewrite.
///
/// Optimizer passes rely on this so that rewriting expressions does
/// not change the output schema of the plan node that owns them.
///
/// Returns an error if the name of the expression cannot be computed
/// or if the rewrite itself fails.
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
    R: ExprRewriter<Expr>,
{
    // The name must be captured before `expr` is consumed by the rewrite.
    let name_before_rewrite = name_for_alias(&expr)?;
    expr.rewrite(rewriter)
        .and_then(|rewritten| add_alias_if_changed(name_before_rewrite, rewritten))
}

/// Return the name to use for the specific Expr, recursing into
/// `Expr::Sort` as appropriate
fn name_for_alias(expr: &Expr) -> Result<String> {
match expr {
Expr::Sort { expr, .. } => name_for_alias(expr),
Copy link
Contributor

@liukun4515 liukun4515 Oct 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed this issue #3710

But I want to know why we need to do the special branch for Sort Expr?

Copy link
Contributor Author

@alamb alamb Oct 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically because calling Expr::name() on an Expr::Sort will return an error:

https://github.com/apache/arrow-datafusion/blob/7c5c2e5/datafusion/expr/src/expr.rs#L1133-L1135

I am not super thrilled in general about how this works -- I wonder if I should support calling Expr::name() on Expr::Sort 🤔

expr => expr.name(),
}
}

/// Returns `expr` named as `original_name`, adding an alias only when
/// its current name (as computed by `name_for_alias`) differs.
///
/// For an `Expr::Sort` the alias is applied to the wrapped expression
/// and the sort options (`asc`, `nulls_first`) are kept as they were.
fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
    // Nothing to do when the name is already the expected one.
    if name_for_alias(&expr)? == original_name {
        return Ok(expr);
    }

    match expr {
        Expr::Sort {
            expr: sorted,
            asc,
            nulls_first,
        } => {
            // Alias the expression being sorted, then rebuild the sort around it.
            let sorted = add_alias_if_changed(original_name, *sorted)?;
            Ok(Expr::Sort {
                expr: Box::new(sorted),
                asc,
                nulls_first,
            })
        }
        other => Ok(other.alias(original_name)),
    }
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_expr::{col, utils::expr_to_columns};
use datafusion_expr::{col, lit, utils::expr_to_columns};
use std::collections::HashSet;
use std::ops::Add;

#[test]
fn test_collect_expr() -> Result<()> {
Expand All @@ -344,4 +395,73 @@ mod tests {
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}

#[test]
fn test_rewrite_preserving_name() {
    // rewrite to the very same column
    test_rewrite(col("a"), col("a"));

    // rewrite to a different column
    test_rewrite(col("a"), col("b"));

    // cast data types
    let cast_a_to_int32 = Expr::Cast {
        expr: Box::new(col("a")),
        data_type: DataType::Int32,
    };
    test_rewrite(col("a"), cast_a_to_int32);

    // change literal type from i32 to i64
    let a_plus_i32 = col("a").add(lit(1i32));
    let a_plus_i64 = col("a").add(lit(1i64));
    test_rewrite(a_plus_i32, a_plus_i64);

    // SortExpr a+1 ==> b + 2
    let sort_asc_nulls_last = |inner: Expr| Expr::Sort {
        expr: Box::new(inner),
        asc: true,
        nulls_first: false,
    };
    test_rewrite(
        sort_asc_nulls_last(col("a").add(lit(1i32))),
        sort_asc_nulls_last(col("b").add(lit(2i64))),
    );
}

/// Runs `rewrite_preserving_name` on `expr_from` with a rewriter that
/// replaces every expression it mutates with `rewrite_to`, then asserts
/// that the result has the same name `expr_from` had (looking inside
/// `Expr::Sort` on both sides).
fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
    /// Rewriter that ignores its input and always yields `rewrite_to`.
    struct TestRewriter {
        rewrite_to: Expr,
    }

    impl ExprRewriter for TestRewriter {
        fn mutate(&mut self, _: Expr) -> Result<Expr> {
            Ok(self.rewrite_to.clone())
        }
    }

    /// Name of `e`, or of the wrapped expression when `e` is a sort.
    fn output_name(e: &Expr) -> String {
        let named = if let Expr::Sort { expr, .. } = e {
            expr
        } else {
            e
        };
        named.name().unwrap()
    }

    let mut rewriter = TestRewriter {
        rewrite_to: rewrite_to.clone(),
    };
    let rewritten = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();

    let original_name = output_name(&expr_from);
    let new_name = output_name(&rewritten);

    assert_eq!(
        original_name, new_name,
        "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
    )
}
}