-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Consolidate and better tests for expression re-rewriting / aliasing #3727
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,12 @@ | |
|
|
||
| //! Optimizer rule for type validation and coercion | ||
|
|
||
| use crate::utils::rewrite_preserving_name; | ||
| use crate::{OptimizerConfig, OptimizerRule}; | ||
| use arrow::datatypes::DataType; | ||
| use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; | ||
| use datafusion_expr::expr::Case; | ||
| use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; | ||
| use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; | ||
| use datafusion_expr::logical_plan::Subquery; | ||
| use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion}; | ||
| use datafusion_expr::type_coercion::functions::data_types; | ||
|
|
@@ -91,30 +92,13 @@ fn optimize_internal( | |
| schema: Arc::new(schema), | ||
| }; | ||
|
|
||
| let original_expr_names: Vec<Option<String>> = plan | ||
| .expressions() | ||
| .iter() | ||
| .map(|expr| expr.name().ok()) | ||
| .collect(); | ||
|
|
||
| let new_expr = plan | ||
| .expressions() | ||
| .into_iter() | ||
| .zip(original_expr_names) | ||
| .map(|(expr, original_name)| { | ||
| let expr = expr.rewrite(&mut expr_rewrite)?; | ||
|
|
||
| .map(|expr| { | ||
| // ensure aggregate names don't change: | ||
| // https://github.com/apache/arrow-datafusion/issues/3555 | ||
| if matches!(expr, Expr::AggregateFunction { .. }) { | ||
| if let Some((alias, name)) = original_name.zip(expr.name().ok()) { | ||
| if alias != name { | ||
| return Ok(expr.alias(&alias)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Ok(expr) | ||
| rewrite_preserving_name(expr, &mut expr_rewrite) | ||
| }) | ||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
|
|
@@ -635,7 +619,8 @@ mod test { | |
| let mut config = OptimizerConfig::default(); | ||
| let plan = rule.optimize(&plan, &mut config)?; | ||
| assert_eq!( | ||
| "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation", | ||
| "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\ | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @liukun4515 -- I think this is an improvement (and maybe a bug fix 🤔 ) |
||
| \n EmptyRelation", | ||
| &format!("{:?}", plan) | ||
| ); | ||
| // a in (1,4,8), a is decimal | ||
|
|
@@ -653,7 +638,8 @@ mod test { | |
| let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?); | ||
| let plan = rule.optimize(&plan, &mut config)?; | ||
| assert_eq!( | ||
| "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation", | ||
| "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\ | ||
| \n EmptyRelation", | ||
| &format!("{:?}", plan) | ||
| ); | ||
| Ok(()) | ||
|
|
@@ -751,7 +737,8 @@ mod test { | |
| let mut config = OptimizerConfig::default(); | ||
| let plan = rule.optimize(&plan, &mut config).unwrap(); | ||
| assert_eq!( | ||
| "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation", | ||
| "Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \ | ||
| \n EmptyRelation", | ||
| &format!("{:?}", plan) | ||
| ); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,12 +18,13 @@ | |
| //! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type | ||
| //! of expr can be added if needed. | ||
| //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. | ||
| use crate::utils::rewrite_preserving_name; | ||
| use crate::{OptimizerConfig, OptimizerRule}; | ||
| use arrow::datatypes::{ | ||
| DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, | ||
| }; | ||
| use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; | ||
| use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; | ||
| use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; | ||
| use datafusion_expr::utils::from_plan; | ||
| use datafusion_expr::{ | ||
| binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, | ||
|
|
@@ -97,47 +98,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> { | |
| let new_exprs = plan | ||
| .expressions() | ||
| .into_iter() | ||
| .map(|expr| { | ||
| let original_name = name_for_alias(&expr)?; | ||
| let expr = expr.rewrite(&mut expr_rewriter)?; | ||
| add_alias_if_changed(&original_name, expr) | ||
| }) | ||
| .map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter)) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this PR basically refactors the code into |
||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
| from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) | ||
| } | ||
|
|
||
| fn name_for_alias(expr: &Expr) -> Result<String> { | ||
| match expr { | ||
| Expr::Sort { expr, .. } => name_for_alias(expr), | ||
| expr => expr.name(), | ||
| } | ||
| } | ||
|
|
||
| fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> { | ||
| let new_name = name_for_alias(&expr)?; | ||
|
|
||
| if new_name == original_name { | ||
| return Ok(expr); | ||
| } | ||
|
|
||
| Ok(match expr { | ||
| Expr::Sort { | ||
| expr, | ||
| asc, | ||
| nulls_first, | ||
| } => { | ||
| let expr = add_alias_if_changed(original_name, *expr)?; | ||
| Expr::Sort { | ||
| expr: Box::new(expr), | ||
| asc, | ||
| nulls_first, | ||
| } | ||
| } | ||
| expr => expr.alias(original_name), | ||
| }) | ||
| } | ||
|
|
||
| struct UnwrapCastExprRewriter { | ||
| schema: DFSchemaRef, | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| use crate::{OptimizerConfig, OptimizerRule}; | ||
| use datafusion_common::Result; | ||
| use datafusion_common::{plan_err, Column, DFSchemaRef}; | ||
| use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; | ||
| use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; | ||
| use datafusion_expr::{ | ||
| and, col, combine_filters, | ||
|
|
@@ -315,13 +316,63 @@ pub fn alias_cols(cols: &[Column]) -> Vec<Expr> { | |
| .collect() | ||
| } | ||
|
|
||
| /// Rewrites `expr` using `rewriter`, ensuring that the output has the | ||
| /// same name as `expr` prior to rewrite, adding an alias if necessary. | ||
| /// | ||
| /// This is important when optimzing plans to ensure the the output | ||
| /// schema of plan nodes don't change after optimization | ||
| pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr> | ||
| where | ||
| R: ExprRewriter<Expr>, | ||
| { | ||
| let original_name = name_for_alias(&expr)?; | ||
| let expr = expr.rewrite(rewriter)?; | ||
| add_alias_if_changed(original_name, expr) | ||
| } | ||
|
|
||
| /// Return the name to use for the specific Expr, recursing into | ||
| /// `Expr::Sort` as appropriate | ||
| fn name_for_alias(expr: &Expr) -> Result<String> { | ||
| match expr { | ||
| Expr::Sort { expr, .. } => name_for_alias(expr), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I missed this issue #3710 But I want to know why we need to do the special branch for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Basically because calling I am not super thrilled in general about how this works -- I wonder if I should support calling |
||
| expr => expr.name(), | ||
| } | ||
| } | ||
|
|
||
| /// Ensure `expr` has the name name as `original_name` by adding an | ||
| /// alias if necessary. | ||
| fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> { | ||
| let new_name = name_for_alias(&expr)?; | ||
|
|
||
| if new_name == original_name { | ||
| return Ok(expr); | ||
| } | ||
|
|
||
| Ok(match expr { | ||
| Expr::Sort { | ||
| expr, | ||
| asc, | ||
| nulls_first, | ||
| } => { | ||
| let expr = add_alias_if_changed(original_name, *expr)?; | ||
| Expr::Sort { | ||
| expr: Box::new(expr), | ||
| asc, | ||
| nulls_first, | ||
| } | ||
| } | ||
| expr => expr.alias(original_name), | ||
| }) | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use arrow::datatypes::DataType; | ||
| use datafusion_common::Column; | ||
| use datafusion_expr::{col, utils::expr_to_columns}; | ||
| use datafusion_expr::{col, lit, utils::expr_to_columns}; | ||
| use std::collections::HashSet; | ||
| use std::ops::Add; | ||
|
|
||
| #[test] | ||
| fn test_collect_expr() -> Result<()> { | ||
|
|
@@ -344,4 +395,73 @@ mod tests { | |
| assert!(accum.contains(&Column::from_name("a"))); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_rewrite_preserving_name() { | ||
| test_rewrite(col("a"), col("a")); | ||
|
|
||
| test_rewrite(col("a"), col("b")); | ||
|
|
||
| // cast data types | ||
| test_rewrite( | ||
| col("a"), | ||
| Expr::Cast { | ||
| expr: Box::new(col("a")), | ||
| data_type: DataType::Int32, | ||
| }, | ||
| ); | ||
|
|
||
| // change literal type from i32 to i64 | ||
| test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); | ||
|
|
||
| // SortExpr a+1 ==> b + 2 | ||
| test_rewrite( | ||
| Expr::Sort { | ||
| expr: Box::new(col("a").add(lit(1i32))), | ||
| asc: true, | ||
| nulls_first: false, | ||
| }, | ||
| Expr::Sort { | ||
| expr: Box::new(col("b").add(lit(2i64))), | ||
| asc: true, | ||
| nulls_first: false, | ||
| }, | ||
| ); | ||
| } | ||
|
|
||
| /// rewrites `expr_from` to `rewrite_to` using | ||
| /// `rewrite_preserving_name` verifying the result is `expected_expr` | ||
| fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { | ||
| struct TestRewriter { | ||
| rewrite_to: Expr, | ||
| } | ||
|
|
||
| impl ExprRewriter for TestRewriter { | ||
| fn mutate(&mut self, _: Expr) -> Result<Expr> { | ||
| Ok(self.rewrite_to.clone()) | ||
| } | ||
| } | ||
|
|
||
| let mut rewriter = TestRewriter { | ||
| rewrite_to: rewrite_to.clone(), | ||
| }; | ||
| let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); | ||
|
|
||
| let original_name = match &expr_from { | ||
| Expr::Sort { expr, .. } => expr.name(), | ||
| expr => expr.name(), | ||
| } | ||
| .unwrap(); | ||
|
|
||
| let new_name = match &expr { | ||
| Expr::Sort { expr, .. } => expr.name(), | ||
| expr => expr.name(), | ||
| } | ||
| .unwrap(); | ||
|
|
||
| assert_eq!( | ||
| original_name, new_name, | ||
| "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" | ||
| ) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fyi @thinkharderdev