diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index e22c73e5794d4..89bcc90bc0752 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -16,12 +16,11 @@ // under the License. //! Optimizer rule to replace nested unions to single union. +use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; -use datafusion_expr::logical_plan::{LogicalPlan, Union}; - -use crate::optimizer::ApplyOrder; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::{Distinct, LogicalPlan, Union}; use std::sync::Arc; #[derive(Default)] @@ -41,22 +40,11 @@ impl OptimizerRule for EliminateNestedUnion { plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - // TODO: Add optimization for nested distinct unions. match plan { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs .iter() - .flat_map(|plan| match plan.as_ref() { - LogicalPlan::Union(Union { inputs, schema }) => inputs - .iter() - .map(|plan| { - Arc::new( - coerce_plan_expr_for_schema(plan, schema).unwrap(), - ) - }) - .collect::>(), - _ => vec![plan.clone()], - }) + .flat_map(extract_plans_from_union) .collect::>(); Ok(Some(LogicalPlan::Union(Union { @@ -64,6 +52,23 @@ impl OptimizerRule for EliminateNestedUnion { schema: schema.clone(), }))) } + LogicalPlan::Distinct(Distinct { input: plan }) => match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .map(extract_plan_from_distinct) + .flat_map(extract_plans_from_union) + .collect::>(); + + Ok(Some(LogicalPlan::Distinct(Distinct { + input: Arc::new(LogicalPlan::Union(Union { + inputs, + schema: schema.clone(), + })), + }))) + } + _ => Ok(None), + }, _ => Ok(None), } } @@ -77,6 +82,23 @@ impl OptimizerRule for EliminateNestedUnion { } } +fn extract_plans_from_union(plan: &Arc) -> Vec> { + match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => inputs + .iter() + .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap())) + .collect::>(), + _ => vec![plan.clone()], + } +} + +fn extract_plan_from_distinct(plan: &Arc) -> &Arc { + match plan.as_ref() { + LogicalPlan::Distinct(Distinct { input: plan }) => plan, + _ => plan, + } +} + #[cfg(test)] mod tests { use super::*; @@ -112,6 +134,22 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn eliminate_distinct_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + #[test] fn eliminate_nested_union() -> Result<()> { let plan_builder = table_scan(Some("table"), &schema(), None)?; @@ -132,6 +170,69 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn eliminate_nested_union_with_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "Union\ + \n Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().distinct()?.build()?)? + .union(plan_builder.clone().distinct()?.build()?)? + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + // We don't need to use project_with_column_index in logical optimizer, // after LogicalPlanBuilder::union, we already have all equal expression aliases #[test] @@ -163,6 +264,36 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn eliminate_nested_distinct_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + #[test] fn eliminate_nested_union_with_type_cast_projection() -> Result<()> { let table_1 = table_scan( @@ -208,4 +339,51 @@ mod tests { \n TableScan: table_1"; assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union_distinct(table_2.build()?)? + .union_distinct(table_3.build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index b11a687d8b9ff..cbb1896efb131 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -186,8 +186,7 @@ Bob_new John John_new -# should be un-nested -# https://github.com/apache/arrow-datafusion/issues/7786 +# should be un-nested, with a single (logical) aggregate query TT EXPLAIN SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name || '_new' from t2) ---- @@ -195,26 +194,19 @@ logical_plan Aggregate: groupBy=[[t1.name]], aggr=[[]] --Union ----TableScan: t1 projection=[name] -----Aggregate: groupBy=[[t2.name]], aggr=[[]] -------Union ---------TableScan: t2 projection=[name] ---------Projection: t2.name || Utf8("_new") AS name -----------TableScan: t2 projection=[name] +----TableScan: t2 projection=[name] +----Projection: t2.name || Utf8("_new") AS name +------TableScan: t2 projection=[name] physical_plan AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] --CoalesceBatchesExec: target_batch_size=8192 -----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=8 +----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=12 ------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] --------UnionExec ----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -----------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=8 -----------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] -------------------UnionExec ---------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------------------ProjectionExec: expr=[name@0 || _new as name] -----------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----------ProjectionExec: expr=[name@0 || _new as name] +------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] # nested_union_all query T rowsort