From 6e6b766bc38febfbb015e515e2ed834f26a5f381 Mon Sep 17 00:00:00 2001 From: George Date: Fri, 16 Sep 2022 20:10:31 +0800 Subject: [PATCH] feat: Union types coercion 1 --- datafusion/core/src/physical_plan/union.rs | 32 +++++- datafusion/expr/src/expr_rewriter.rs | 38 ++++++- datafusion/expr/src/logical_plan/builder.rs | 73 ++++++++----- datafusion/sql/src/planner.rs | 108 +++++++++++++++++++- 4 files changed, 219 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/core/src/physical_plan/union.rs index d57fbe0f3df3d..bf9dfbd1b694c 100644 --- a/datafusion/core/src/physical_plan/union.rs +++ b/datafusion/core/src/physical_plan/union.rs @@ -23,8 +23,12 @@ use std::{any::Any, sync::Arc}; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{ + datatypes::{Field, Schema, SchemaRef}, + record_batch::RecordBatch, +}; use futures::StreamExt; +use itertools::Itertools; use log::debug; use super::{ @@ -46,14 +50,38 @@ pub struct UnionExec { inputs: Vec>, /// Execution metrics metrics: ExecutionPlanMetricsSet, + /// Schema of Union + schema: SchemaRef, } impl UnionExec { /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { + let fields: Vec = (0..inputs[0].schema().fields().len()) + .map(|i| { + inputs + .iter() + .filter_map(|input| { + if input.schema().fields().len() > i { + Some(input.schema().field(i).clone()) + } else { + None + } + }) + .find_or_first(|f| f.is_nullable()) + .unwrap() + }) + .collect(); + + let schema = Arc::new(Schema::new_with_metadata( + fields, + inputs[0].schema().metadata().clone(), + )); + UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), + schema, } } @@ -70,7 +98,7 @@ impl ExecutionPlan for UnionExec { } fn schema(&self) -> SchemaRef { - self.inputs[0].schema() + self.schema.clone() } fn children(&self) -> Vec> { diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index 533f31ce1584e..1c81b4c4a60ca 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -18,8 +18,8 @@ //! Expression rewriter use crate::expr::GroupingSet; -use crate::logical_plan::Aggregate; -use crate::utils::grouping_set_to_exprlist; +use crate::logical_plan::{Aggregate, Projection}; +use crate::utils::{from_plan, grouping_set_to_exprlist}; use crate::{Expr, ExprSchemable, LogicalPlan}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; @@ -524,6 +524,40 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { exprs.into_iter().map(unnormalize_col).collect() } +/// Returns plan with expressions coerced to types compatible with +/// schema types +pub fn coerce_plan_expr_for_schema( + plan: &LogicalPlan, + schema: &DFSchema, +) -> Result { + let new_expr = plan + .expressions() + .into_iter() + .enumerate() + .map(|(i, expr)| { + let new_type = schema.field(i).data_type(); + if plan.schema().field(i).data_type() != schema.field(i).data_type() { + match (plan, &expr) { + ( + LogicalPlan::Projection(Projection { input, .. }), + Expr::Alias(e, alias), + ) => Ok(Expr::Alias( + Box::new(e.clone().cast_to(new_type, input.schema())?), + alias.clone(), + )), + _ => expr.cast_to(new_type, plan.schema()), + } + } else { + Ok(expr) + } + }) + .collect::>>()?; + + let new_inputs = plan.inputs().into_iter().cloned().collect::>(); + + from_plan(plan, &new_expr, &new_inputs) +} + #[cfg(test)] mod test { use super::*; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 0125291fda0d9..def9927ab6b32 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -17,7 +17,10 @@ //! This module provides a builder for creating LogicalPlans -use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs}; +use crate::binary_rule::comparison_coercion; +use crate::expr_rewriter::{ + coerce_plan_expr_for_schema, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, +}; use crate::utils::{ columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist, }; @@ -882,43 +885,59 @@ pub fn union_with_alias( right_plan: LogicalPlan, alias: Option, ) -> Result { - let union_schema = left_plan.schema().clone(); - let inputs_iter = vec![left_plan, right_plan] + let union_schema = (0..left_plan.schema().fields().len()) + .map(|i| { + let left_field = left_plan.schema().field(i); + let right_field = right_plan.schema().field(i); + let nullable = left_field.is_nullable() || right_field.is_nullable(); + let data_type = + comparison_coercion(left_field.data_type(), right_field.data_type()) + .ok_or_else(|| { + DataFusionError::Plan(format!( + "UNION Column {} (type: {}) is not compatible with column {} (type: {})", + right_field.name(), + right_field.data_type(), + left_field.name(), + left_field.data_type() + )) + })?; + + Ok(DFField::new( + alias.as_deref(), + left_field.name(), + data_type, + nullable, + )) + }) + .collect::>>()? + .to_dfschema()?; + + let inputs = vec![left_plan, right_plan] .into_iter() .flat_map(|p| match p { LogicalPlan::Union(Union { inputs, .. }) => inputs, x => vec![Arc::new(x)], - }); - - inputs_iter - .clone() - .skip(1) - .try_for_each(|input_plan| -> Result<()> { - union_schema.check_arrow_schema_type_compatible( - &((**input_plan.schema()).clone().into()), - ) - })?; - - let inputs = inputs_iter - .map(|p| match p.as_ref() { - LogicalPlan::Projection(Projection { - expr, input, alias, .. - }) => Ok(Arc::new(project_with_column_index_alias( - expr.to_vec(), - input.clone(), - union_schema.clone(), - alias.clone(), - )?)), - x => Ok(Arc::new(x.clone())), }) - .into_iter() + .map(|p| { + let plan = coerce_plan_expr_for_schema(&p, &union_schema)?; + match plan { + LogicalPlan::Projection(Projection { + expr, input, alias, .. + }) => Ok(Arc::new(project_with_column_index_alias( + expr.to_vec(), + input, + Arc::new(union_schema.clone()), + alias, + )?)), + x => Ok(Arc::new(x)), + } + }) .collect::>>()?; if inputs.is_empty() { return Err(DataFusionError::Plan("Empty UNION".to_string())); } - let union_schema = (**inputs[0].schema()).clone(); let union_schema = Arc::new(match alias { Some(ref alias) => union_schema.replace_qualifier(alias.as_str()), None => union_schema.strip_qualifiers(), diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 04518d81e4575..541dc2520e931 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -4163,7 +4163,7 @@ mod tests { let sql = "SELECT interval '1 year 1 day' UNION ALL SELECT 1"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Column Int64(1) (type: Int64) is \ + "Plan(\"UNION Column Int64(1) (type: Int64) is \ not compatible with column IntervalMonthDayNano\ (\\\"950737950189618795196236955648\\\") \ (type: Interval(MonthDayNano))\")", @@ -4171,6 +4171,112 @@ mod tests { ); } + #[test] + fn union_with_different_decimal_data_types() { + let sql = "SELECT 1 a UNION ALL SELECT 1.1 a"; + let expected = "Union\ + \n Projection: CAST(Int64(1) AS Float64) AS a\ + \n EmptyRelation\ + \n Projection: Float64(1.1) AS a\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn union_with_null() { + let sql = "SELECT NULL a UNION ALL SELECT 1.1 a"; + let expected = "Union\ + \n Projection: CAST(NULL AS Float64) AS a\ + \n EmptyRelation\ + \n Projection: Float64(1.1) AS a\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn union_with_float_and_string() { + let sql = "SELECT 'a' a UNION ALL SELECT 1.1 a"; + let expected = "Union\ + \n Projection: Utf8(\"a\") AS a\ + \n EmptyRelation\ + \n Projection: CAST(Float64(1.1) AS Utf8) AS a\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn union_with_multiply_cols() { + let sql = "SELECT 'a' a, 1 b UNION ALL SELECT 1.1 a, 1.1 b"; + let expected = "Union\ + \n Projection: Utf8(\"a\") AS a, CAST(Int64(1) AS Float64) AS b\ + \n EmptyRelation\ + \n Projection: CAST(Float64(1.1) AS Utf8) AS a, Float64(1.1) AS b\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn sorted_union_with_different_types_and_group_by() { + let sql = "SELECT a FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT a FROM (select 1.1 a) x GROUP BY 1) ORDER BY 1"; + let expected = "Sort: #a ASC NULLS LAST\ + \n Union\ + \n Projection: CAST(#x.a AS Float64) AS a\ + \n Aggregate: groupBy=[[#x.a]], aggr=[[]]\ + \n Projection: #x.a, alias=x\ + \n Projection: Int64(1) AS a, alias=x\ + \n EmptyRelation\ + \n Projection: #x.a\ + \n Aggregate: groupBy=[[#x.a]], aggr=[[]]\ + \n Projection: #x.a, alias=x\ + \n Projection: Float64(1.1) AS a, alias=x\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn union_with_binary_expr_and_cast() { + let sql = "SELECT cast(0.0 + a as integer) FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT 2.1 + a FROM (select 1 a) x GROUP BY 1)"; + let expected = "Union\ + \n Projection: CAST(#Float64(0) + x.a AS Float64) AS Float64(0) + x.a\ + \n Aggregate: groupBy=[[CAST(Float64(0) + #x.a AS Int32)]], aggr=[[]]\ + \n Projection: #x.a, alias=x\ + \n Projection: Int64(1) AS a, alias=x\ + \n EmptyRelation\ + \n Projection: #Float64(2.1) + x.a\ + \n Aggregate: groupBy=[[Float64(2.1) + #x.a]], aggr=[[]]\ + \n Projection: #x.a, alias=x\ + \n Projection: Int64(1) AS a, alias=x\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn union_with_aliases() { + let sql = "SELECT a as a1 FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT a as a1 FROM (select 1.1 a) x GROUP BY 1)"; + let expected = "Union\ + \n Projection: CAST(#x.a AS Float64) AS a1\ + \n Aggregate: groupBy=[[#x.a]], aggr=[[]]\ + \n Projection: #x.a, alias=x\ + \n Projection: Int64(1) AS a, alias=x\ + \n EmptyRelation\ + \n Projection: #x.a AS a1\ + \n Aggregate: groupBy=[[#x.a]], aggr=[[]]\ + \n Projection: #x.a, alias=x\ + \n Projection: Float64(1.1) AS a, alias=x\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn union_with_incompatible_data_types() { + let sql = "SELECT 'a' a UNION ALL SELECT true a"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "Plan(\"UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)\")", + format!("{:?}", err) + ); + } + #[test] fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders";