From 5f0023b797279986f30abd0afa9008d249406d51 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 10 Apr 2024 11:14:02 -0400 Subject: [PATCH 1/2] fix NamedStructField should be rewritten in OperatorToFunction in subquery --- .../src/analyzer/function_rewrite.rs | 131 ++++++++++++------ .../sqllogictest/test_files/subquery.slt | 55 ++++++++ 2 files changed, 144 insertions(+), 42 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index 78f65c5b82abe..70cb54c24f6cc 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -21,9 +21,10 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite}; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{Expr, LogicalPlan}; +use datafusion_expr::{Expr, LogicalPlan, Subquery}; use std::sync::Arc; /// Analyzer rule that invokes [`FunctionRewrite`]s on expressions @@ -45,54 +46,66 @@ impl AnalyzerRule for ApplyFunctionRewrites { } fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result { - self.analyze_internal(&plan, options) + analyze_internal(&plan, &self.function_rewrites, options) } } -impl ApplyFunctionRewrites { - fn analyze_internal( - &self, - plan: &LogicalPlan, - options: &ConfigOptions, - ) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| self.analyze_internal(p, options)) - .collect::>>()?; - - // get schema representing all available input fields. 
This is used for data type - // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); - - if let LogicalPlan::TableScan(ts) = plan { - let source_schema = DFSchema::try_from_qualified_schema( - ts.table_name.clone(), - &ts.source.schema(), - )?; - schema.merge(&source_schema); - } +fn analyze_internal( + plan: &LogicalPlan, + function_rewrites: &[Arc], + options: &ConfigOptions, +) -> Result { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| analyze_internal(p, function_rewrites, options)) + .collect::>>()?; - let mut expr_rewrite = OperatorToFunctionRewriter { - function_rewrites: &self.function_rewrites, - options, - schema: &schema, - }; + // get schema representing all available input fields. This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(new_inputs.iter().collect()); - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure names don't change: - // https://github.com/apache/arrow-datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; - - plan.with_new_exprs(new_expr, new_inputs) + if let LogicalPlan::TableScan(ts) = plan { + let source_schema = DFSchema::try_from_qualified_schema( + ts.table_name.clone(), + &ts.source.schema(), + )?; + schema.merge(&source_schema); } + + let mut expr_rewrite = OperatorToFunctionRewriter { + function_rewrites, + options, + schema: &schema, + }; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| { + // ensure names don't change: + // https://github.com/apache/arrow-datafusion/issues/3555 + rewrite_preserving_name(expr, &mut expr_rewrite) + }) + .collect::>>()?; + + plan.with_new_exprs(new_expr, new_inputs) } + +fn rewrite_subquery( + mut subquery: Subquery, + function_rewrites: &[Arc], + options: &ConfigOptions, +) -> Result { + subquery.subquery = Arc::new(analyze_internal( + 
&subquery.subquery, + function_rewrites, + options, + )?); + Ok(subquery) +} + struct OperatorToFunctionRewriter<'a> { function_rewrites: &'a [Arc], options: &'a ConfigOptions, @@ -113,6 +126,40 @@ impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> { expr = result.data } + // recurse into subqueries if needed + let expr = match expr { + Expr::ScalarSubquery(subquery) => Expr::ScalarSubquery(rewrite_subquery( + subquery, + self.function_rewrites, + self.options, + )?), + + Expr::Exists(Exists { subquery, negated }) => Expr::Exists(Exists { + subquery: rewrite_subquery( + subquery, + self.function_rewrites, + self.options, + )?, + negated, + }), + + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => Expr::InSubquery(InSubquery { + expr, + subquery: rewrite_subquery( + subquery, + self.function_rewrites, + self.options, + )?, + negated, + }), + + expr => expr, + }; + Ok(if transformed { Transformed::yes(expr) } else { diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index cc6428e514359..1ae89c9159f8b 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1060,3 +1060,58 @@ logical_plan Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) --Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a ----TableScan: t projection=[a] + +### +## Ensure that operators are rewritten in subqueries +### + +statement ok +create table foo(x int) as values (1); + +# Show input data +query ? 
+select struct(1, 'b') +---- +{c0: 1, c1: b} + + +query T +select (select struct(1, 'b')['c1']); +---- +b + +query T +select 'foo' || (select struct(1, 'b')['c1']); +---- +foob + +query I +SELECT * FROM (VALUES (1), (2)) +WHERE column1 IN (SELECT struct(1, 'b')['c0']); +---- +1 + +# also add an expression so the subquery is the output expr +query I +SELECT * FROM (VALUES (1), (2)) +WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']); +---- +1 + + +query I +SELECT * FROM foo +WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1); +---- +1 + +# also add an expression so the subquery is the output expr +query I +SELECT * FROM foo +WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1); +---- +1 + + +statement ok +drop table foo; From a546e69124809547c9a27a0e38e020a48178ac15 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 10 Apr 2024 15:34:18 -0400 Subject: [PATCH 2/2] Use TreeNode rewriter --- .../src/analyzer/function_rewrite.rs | 166 +++++------------- datafusion/optimizer/src/utils.rs | 44 +++++ 2 files changed, 88 insertions(+), 122 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index 70cb54c24f6cc..deb493e09953c 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -19,12 +19,13 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchema, Result}; -use datafusion_expr::expr::{Exists, InSubquery}; -use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite}; + +use crate::utils::NamePreserver; +use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::utils::merge_schema; -use 
datafusion_expr::{Expr, LogicalPlan, Subquery}; +use datafusion_expr::LogicalPlan; use std::sync::Arc; /// Analyzer rule that invokes [`FunctionRewrite`]s on expressions @@ -38,132 +39,53 @@ impl ApplyFunctionRewrites { pub fn new(function_rewrites: Vec>) -> Self { Self { function_rewrites } } -} - -impl AnalyzerRule for ApplyFunctionRewrites { - fn name(&self) -> &str { - "apply_function_rewrites" - } - - fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result { - analyze_internal(&plan, &self.function_rewrites, options) - } -} -fn analyze_internal( - plan: &LogicalPlan, - function_rewrites: &[Arc], - options: &ConfigOptions, -) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(p, function_rewrites, options)) - .collect::>>()?; + /// Rewrite a single plan, and all its expressions using the provided rewriters + fn rewrite_plan( + &self, + plan: LogicalPlan, + options: &ConfigOptions, + ) -> Result> { + // get schema representing all available input fields. This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(plan.inputs()); + + if let LogicalPlan::TableScan(ts) = &plan { + let source_schema = DFSchema::try_from_qualified_schema( + ts.table_name.clone(), + &ts.source.schema(), + )?; + schema.merge(&source_schema); + } - // get schema representing all available input fields. 
This is used for data type - // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); + let name_preserver = NamePreserver::new(&plan); - if let LogicalPlan::TableScan(ts) = plan { - let source_schema = DFSchema::try_from_qualified_schema( - ts.table_name.clone(), - &ts.source.schema(), - )?; - schema.merge(&source_schema); - } + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; - let mut expr_rewrite = OperatorToFunctionRewriter { - function_rewrites, - options, - schema: &schema, - }; + // recursively transform the expression, applying the rewrites at each step + let result = expr.transform_up(&|expr| { + let mut result = Transformed::no(expr); + for rewriter in self.function_rewrites.iter() { + result = result.transform_data(|expr| { + rewriter.rewrite(expr, &schema, options) + })?; + } + Ok(result) + })?; - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure names don't change: - // https://github.com/apache/arrow-datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) + result.map_data(|expr| original_name.restore(expr)) }) - .collect::>>()?; - - plan.with_new_exprs(new_expr, new_inputs) -} - -fn rewrite_subquery( - mut subquery: Subquery, - function_rewrites: &[Arc], - options: &ConfigOptions, -) -> Result { - subquery.subquery = Arc::new(analyze_internal( - &subquery.subquery, - function_rewrites, - options, - )?); - Ok(subquery) -} - -struct OperatorToFunctionRewriter<'a> { - function_rewrites: &'a [Arc], - options: &'a ConfigOptions, - schema: &'a DFSchema, + } } -impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> { - type Node = Expr; - - fn f_up(&mut self, mut expr: Expr) -> Result> { - // apply transforms one by one - let mut transformed = false; - for rewriter in self.function_rewrites.iter() { - let result = rewriter.rewrite(expr, self.schema, self.options)?; - if result.transformed { - transformed = true; - } - 
expr = result.data - } - - // recurse into subqueries if needed - let expr = match expr { - Expr::ScalarSubquery(subquery) => Expr::ScalarSubquery(rewrite_subquery( - subquery, - self.function_rewrites, - self.options, - )?), - - Expr::Exists(Exists { subquery, negated }) => Expr::Exists(Exists { - subquery: rewrite_subquery( - subquery, - self.function_rewrites, - self.options, - )?, - negated, - }), - - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => Expr::InSubquery(InSubquery { - expr, - subquery: rewrite_subquery( - subquery, - self.function_rewrites, - self.options, - )?, - negated, - }), - - expr => expr, - }; +impl AnalyzerRule for ApplyFunctionRewrites { + fn name(&self) -> &str { + "apply_function_rewrites" + } - Ok(if transformed { - Transformed::yes(expr) - } else { - Transformed::no(expr) - }) + fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result { + plan.transform_up_with_subqueries(&|plan| self.rewrite_plan(plan, options)) + .map(|res| res.data) } } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 560c63b18882a..f0605018e6f3b 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -288,3 +288,47 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { expr_utils::merge_schema(inputs) } + +/// Handles ensuring the name of rewritten expressions is not changed. 
+/// +/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the +/// expression should be preserved: `3 as "1 + 2"` +/// +/// See <https://github.com/apache/arrow-datafusion/issues/3555> for details +pub struct NamePreserver { + use_alias: bool, +} + +/// If the name of an expression is remembered, it will be preserved when +/// rewriting the expression +pub struct SavedName(Option); + +impl NamePreserver { + /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan + pub fn new(plan: &LogicalPlan) -> Self { + Self { + use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)), + } + } + + pub fn save(&self, expr: &Expr) -> Result { + let original_name = if self.use_alias { + Some(expr.name_for_alias()?) + } else { + None + }; + + Ok(SavedName(original_name)) + } +} + +impl SavedName { + /// Ensures the name of the rewritten expression is preserved + pub fn restore(self, expr: Expr) -> Result { + let Self(original_name) = self; + match original_name { + Some(name) => expr.alias_if_changed(name), + None => Ok(expr), + } + } +}