diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 85022019b5e95..8dfd49e399c2c 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -3230,7 +3230,7 @@ mod tests { SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a UNION ALL SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b - ) AS all + ) GROUP BY cat ORDER BY cat ", diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index c88b25d0a2251..49cfe2b30ea27 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -1075,25 +1075,48 @@ pub fn union_with_alias( } let union_schema = (**inputs[0].schema()).clone(); - let union_schema = Arc::new(match alias { - Some(ref alias) => union_schema.replace_qualifier(alias.as_str()), - None => union_schema.strip_qualifiers(), - }); - - inputs - .iter() - .skip(1) - .try_for_each(|input_plan| -> Result<()> { - union_schema.check_arrow_schema_type_compatible( - &((**input_plan.schema()).clone().into()), - ) - })?; - - Ok(LogicalPlan::Union(Union { - inputs, - schema: union_schema, - alias, - })) + + match alias { + Some(ref alias) => { + let union_schema_copy = union_schema.clone(); + let union_schema = union_schema.strip_qualifiers(); + let alias_schema = union_schema_copy.replace_qualifier(alias.as_str()); + + inputs + .iter() + .skip(1) + .try_for_each(|input_plan| -> Result<()> { + union_schema.check_arrow_schema_type_compatible( + &((**input_plan.schema()).clone().into()), + ) + })?; + + Ok(LogicalPlan::SubqueryAlias(SubqueryAlias { + input: Arc::new(LogicalPlan::Union(Union { + inputs, + schema: Arc::new(union_schema), + })), + alias: alias.to_string(), + schema: Arc::new(alias_schema), + })) + } + None => { + let union_schema = union_schema.strip_qualifiers(); + inputs + .iter() + .skip(1) + .try_for_each(|input_plan| -> Result<()> { + union_schema.check_arrow_schema_type_compatible( + &((**input_plan.schema()).clone().into()), + ) + })?; + + Ok(LogicalPlan::Union(Union { + inputs, + schema: Arc::new(union_schema), + })) + } + } } /// Project with optional alias diff --git a/datafusion/core/src/logical_plan/plan.rs b/datafusion/core/src/logical_plan/plan.rs index 66307c6aba464..3dc98a9d1da3c 100644 --- a/datafusion/core/src/logical_plan/plan.rs +++ b/datafusion/core/src/logical_plan/plan.rs @@ -169,8 +169,6 @@ pub struct Union { pub inputs: Vec, /// Union schema. Should be the same for all inputs. pub schema: DFSchemaRef, - /// Union output relation alias - pub alias: Option, } /// Creates an in memory table. diff --git a/datafusion/core/src/optimizer/filter_push_down.rs b/datafusion/core/src/optimizer/filter_push_down.rs index 30a7ee97328e8..ed6a26a4f49a3 100644 --- a/datafusion/core/src/optimizer/filter_push_down.rs +++ b/datafusion/core/src/optimizer/filter_push_down.rs @@ -16,7 +16,9 @@ use crate::datasource::datasource::TableProviderFilterPushDown; use crate::execution::context::ExecutionProps; -use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection, Union}; +use crate::logical_plan::plan::{ + Aggregate, Filter, Join, Projection, SubqueryAlias, Union, +}; use crate::logical_plan::{ and, col, replace_col, Column, CrossJoin, JoinType, Limit, LogicalPlan, TableScan, }; @@ -393,11 +395,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { // sort is filter-commutable push_down(&state, plan) } - LogicalPlan::Union(Union { - inputs: _, - schema, - alias: _, - }) => { + LogicalPlan::Union(Union { inputs: _, schema }) => { // union changing all qualifiers while building logical plan so we need // to rewrite filters to push unqualified columns to inputs let projection = schema @@ -542,6 +540,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { }), ) } + LogicalPlan::SubqueryAlias(SubqueryAlias { .. }) => push_down(&state, plan), _ => { // all other plans are _not_ filter-commutable let used_columns = plan @@ -935,11 +934,12 @@ mod tests { // filter appears below Union without relation qualifier let expected = "\ - Union\ - \n Filter: #a = Int64(1)\ - \n TableScan: test projection=None\ - \n Filter: #a = Int64(1)\ - \n TableScan: test projection=None"; + SubqueryAlias: t\ + \n Union\ + \n Filter: #t.a = Int64(1)\ + \n TableScan: test projection=None\ + \n Filter: #t.a = Int64(1)\ + \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } diff --git a/datafusion/core/src/optimizer/limit_push_down.rs b/datafusion/core/src/optimizer/limit_push_down.rs index 0c68f1761601d..1e92f3db31672 100644 --- a/datafusion/core/src/optimizer/limit_push_down.rs +++ b/datafusion/core/src/optimizer/limit_push_down.rs @@ -20,7 +20,7 @@ use super::utils; use crate::error::Result; use crate::execution::context::ExecutionProps; -use crate::logical_plan::plan::Projection; +use crate::logical_plan::plan::{Projection, SubqueryAlias}; use crate::logical_plan::{Limit, TableScan}; use crate::logical_plan::{LogicalPlan, Union}; use crate::optimizer::optimizer::OptimizerRule; @@ -100,14 +100,7 @@ fn limit_push_down( alias: alias.clone(), })) } - ( - LogicalPlan::Union(Union { - inputs, - alias, - schema, - }), - Some(upper_limit), - ) => { + (LogicalPlan::Union(Union { inputs, schema }), Some(upper_limit)) => { // Push down limit through UNION let new_inputs = inputs .iter() @@ -125,10 +118,29 @@ fn limit_push_down( .collect::>()?; Ok(LogicalPlan::Union(Union { inputs: new_inputs, - alias: alias.clone(), schema: schema.clone(), })) } + ( + LogicalPlan::SubqueryAlias(SubqueryAlias { + input, + alias, + schema, + }), + upper_limit, + ) => { + // Push down limit directly + Ok(LogicalPlan::SubqueryAlias(SubqueryAlias { + input: Arc::new(limit_push_down( + _optimizer, + upper_limit, + input.as_ref(), + _execution_props, + )?), + schema: schema.clone(), + alias: alias.clone(), + })) + } // For other nodes we can't push down the limit // But try to recurse and find other limit nodes to push down _ => { diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs index 10bf5d10f9602..caa2cd40a27c8 100644 --- a/datafusion/core/src/optimizer/projection_push_down.rs +++ b/datafusion/core/src/optimizer/projection_push_down.rs @@ -383,11 +383,7 @@ fn optimize_plan( schema: a.schema.clone(), })) } - LogicalPlan::Union(Union { - inputs, - schema, - alias, - }) => { + LogicalPlan::Union(Union { inputs, schema }) => { // UNION inputs will reference the same column with different identifiers, so we need // to populate new_required_columns by unqualified column name based on required fields // from the resulting UNION output @@ -429,7 +425,6 @@ fn optimize_plan( Ok(LogicalPlan::Union(Union { inputs: new_inputs, schema: Arc::new(new_schema), - alias: alias.clone(), })) } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { @@ -455,6 +450,22 @@ fn optimize_plan( let expr = vec![]; utils::from_plan(plan, &expr, &new_inputs) } + LogicalPlan::Union(Union { inputs, .. }) => { + let new_inputs = inputs + .iter() + .map(|input_plan| { + optimize_plan( + _optimizer, + input_plan, + &new_required_columns, + has_projection, + _execution_props, + ) + }) + .collect::>>()?; + let expr = vec![]; + utils::from_plan(plan, &expr, &new_inputs) + } _ => Err(DataFusionError::Plan( "SubqueryAlias should only wrap TableScan".to_string(), )), diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 0dab2d3ed7bcb..7742714ebed7c 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -249,13 +249,10 @@ pub fn from_plan( LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { node: e.node.from_template(expr, inputs), })), - LogicalPlan::Union(Union { schema, alias, .. }) => { - Ok(LogicalPlan::Union(Union { - inputs: inputs.to_vec(), - schema: schema.clone(), - alias: alias.clone(), - })) - } + LogicalPlan::Union(Union { schema, .. }) => Ok(LogicalPlan::Union(Union { + inputs: inputs.to_vec(), + schema: schema.clone(), + })), LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); assert_eq!(inputs.len(), 1);