Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions datafusion/core/src/physical_plan/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@

use std::{any::Any, sync::Arc};

use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use arrow::{
datatypes::{Field, Schema, SchemaRef},
record_batch::RecordBatch,
};
use futures::StreamExt;
use itertools::Itertools;
use log::debug;

use super::{
Expand All @@ -46,14 +50,38 @@ pub struct UnionExec {
inputs: Vec<Arc<dyn ExecutionPlan>>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
/// Schema of Union
schema: SchemaRef,
}

impl UnionExec {
/// Create a new UnionExec
pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
let fields: Vec<Field> = (0..inputs[0].schema().fields().len())
.map(|i| {
inputs
.iter()
.filter_map(|input| {
if input.schema().fields().len() > i {
Some(input.schema().field(i).clone())
} else {
None
}
})
.find_or_first(|f| f.is_nullable())
.unwrap()
})
.collect();

let schema = Arc::new(Schema::new_with_metadata(
fields,
inputs[0].schema().metadata().clone(),
));

UnionExec {
inputs,
metrics: ExecutionPlanMetricsSet::new(),
schema,
}
}

Expand All @@ -70,7 +98,7 @@ impl ExecutionPlan for UnionExec {
}

fn schema(&self) -> SchemaRef {
self.inputs[0].schema()
self.schema.clone()
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
Expand Down
38 changes: 36 additions & 2 deletions datafusion/expr/src/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! Expression rewriter

use crate::expr::GroupingSet;
use crate::logical_plan::Aggregate;
use crate::utils::grouping_set_to_exprlist;
use crate::logical_plan::{Aggregate, Projection};
use crate::utils::{from_plan, grouping_set_to_exprlist};
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
Expand Down Expand Up @@ -524,6 +524,40 @@ pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
exprs.into_iter().map(unnormalize_col).collect()
}

/// Returns plan with expressions coerced to types compatible with
/// schema types
pub fn coerce_plan_expr_for_schema(
plan: &LogicalPlan,
schema: &DFSchema,
) -> Result<LogicalPlan> {
let new_expr = plan
.expressions()
.into_iter()
.enumerate()
.map(|(i, expr)| {
let new_type = schema.field(i).data_type();
if plan.schema().field(i).data_type() != schema.field(i).data_type() {
match (plan, &expr) {
(
LogicalPlan::Projection(Projection { input, .. }),
Expr::Alias(e, alias),
) => Ok(Expr::Alias(
Box::new(e.clone().cast_to(new_type, input.schema())?),
alias.clone(),
)),
_ => expr.cast_to(new_type, plan.schema()),
}
} else {
Ok(expr)
}
})
.collect::<Result<Vec<_>>>()?;

let new_inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();

from_plan(plan, &new_expr, &new_inputs)
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
73 changes: 46 additions & 27 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

//! This module provides a builder for creating LogicalPlans

use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs};
use crate::binary_rule::comparison_coercion;
use crate::expr_rewriter::{
coerce_plan_expr_for_schema, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs,
};
use crate::utils::{
columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist,
};
Expand Down Expand Up @@ -882,43 +885,59 @@ pub fn union_with_alias(
right_plan: LogicalPlan,
alias: Option<String>,
) -> Result<LogicalPlan> {
let union_schema = left_plan.schema().clone();
let inputs_iter = vec![left_plan, right_plan]
let union_schema = (0..left_plan.schema().fields().len())
.map(|i| {
let left_field = left_plan.schema().field(i);
let right_field = right_plan.schema().field(i);
let nullable = left_field.is_nullable() || right_field.is_nullable();
let data_type =
comparison_coercion(left_field.data_type(), right_field.data_type())
.ok_or_else(|| {
DataFusionError::Plan(format!(
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
right_field.name(),
right_field.data_type(),
left_field.name(),
left_field.data_type()
))
})?;

Ok(DFField::new(
alias.as_deref(),
left_field.name(),
data_type,
nullable,
))
})
.collect::<Result<Vec<_>>>()?
.to_dfschema()?;

let inputs = vec![left_plan, right_plan]
.into_iter()
.flat_map(|p| match p {
LogicalPlan::Union(Union { inputs, .. }) => inputs,
x => vec![Arc::new(x)],
});

inputs_iter
.clone()
.skip(1)
.try_for_each(|input_plan| -> Result<()> {
union_schema.check_arrow_schema_type_compatible(
&((**input_plan.schema()).clone().into()),
)
})?;

let inputs = inputs_iter
.map(|p| match p.as_ref() {
LogicalPlan::Projection(Projection {
expr, input, alias, ..
}) => Ok(Arc::new(project_with_column_index_alias(
expr.to_vec(),
input.clone(),
union_schema.clone(),
alias.clone(),
)?)),
x => Ok(Arc::new(x.clone())),
})
.into_iter()
.map(|p| {
let plan = coerce_plan_expr_for_schema(&p, &union_schema)?;
match plan {
LogicalPlan::Projection(Projection {
expr, input, alias, ..
}) => Ok(Arc::new(project_with_column_index_alias(
expr.to_vec(),
input,
Arc::new(union_schema.clone()),
alias,
)?)),
x => Ok(Arc::new(x)),
}
})
.collect::<Result<Vec<_>>>()?;

if inputs.is_empty() {
return Err(DataFusionError::Plan("Empty UNION".to_string()));
}

let union_schema = (**inputs[0].schema()).clone();
let union_schema = Arc::new(match alias {
Some(ref alias) => union_schema.replace_qualifier(alias.as_str()),
None => union_schema.strip_qualifiers(),
Expand Down
108 changes: 107 additions & 1 deletion datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4163,14 +4163,120 @@ mod tests {
let sql = "SELECT interval '1 year 1 day' UNION ALL SELECT 1";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
"Plan(\"Column Int64(1) (type: Int64) is \
"Plan(\"UNION Column Int64(1) (type: Int64) is \
not compatible with column IntervalMonthDayNano\
(\\\"950737950189618795196236955648\\\") \
(type: Interval(MonthDayNano))\")",
format!("{:?}", err)
);
}

#[test]
fn union_with_different_decimal_data_types() {
let sql = "SELECT 1 a UNION ALL SELECT 1.1 a";
let expected = "Union\
\n Projection: CAST(Int64(1) AS Float64) AS a\
\n EmptyRelation\
\n Projection: Float64(1.1) AS a\
\n EmptyRelation";
quick_test(sql, expected);
}

#[test]
fn union_with_null() {
let sql = "SELECT NULL a UNION ALL SELECT 1.1 a";
let expected = "Union\
\n Projection: CAST(NULL AS Float64) AS a\
\n EmptyRelation\
\n Projection: Float64(1.1) AS a\
\n EmptyRelation";
quick_test(sql, expected);
}

#[test]
fn union_with_float_and_string() {
let sql = "SELECT 'a' a UNION ALL SELECT 1.1 a";
let expected = "Union\
\n Projection: Utf8(\"a\") AS a\
\n EmptyRelation\
\n Projection: CAST(Float64(1.1) AS Utf8) AS a\
\n EmptyRelation";
quick_test(sql, expected);
}

#[test]
fn union_with_multiply_cols() {
let sql = "SELECT 'a' a, 1 b UNION ALL SELECT 1.1 a, 1.1 b";
let expected = "Union\
\n Projection: Utf8(\"a\") AS a, CAST(Int64(1) AS Float64) AS b\
\n EmptyRelation\
\n Projection: CAST(Float64(1.1) AS Utf8) AS a, Float64(1.1) AS b\
\n EmptyRelation";
quick_test(sql, expected);
}

#[test]
fn sorted_union_with_different_types_and_group_by() {
let sql = "SELECT a FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT a FROM (select 1.1 a) x GROUP BY 1) ORDER BY 1";
let expected = "Sort: #a ASC NULLS LAST\
\n Union\
\n Projection: CAST(#x.a AS Float64) AS a\
\n Aggregate: groupBy=[[#x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Int64(1) AS a, alias=x\
\n EmptyRelation\
\n Projection: #x.a\
\n Aggregate: groupBy=[[#x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Float64(1.1) AS a, alias=x\
\n EmptyRelation";
quick_test(sql, expected);
}

#[test]
fn union_with_binary_expr_and_cast() {
let sql = "SELECT cast(0.0 + a as integer) FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT 2.1 + a FROM (select 1 a) x GROUP BY 1)";
let expected = "Union\
\n Projection: CAST(#Float64(0) + x.a AS Float64) AS Float64(0) + x.a\
\n Aggregate: groupBy=[[CAST(Float64(0) + #x.a AS Int32)]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Int64(1) AS a, alias=x\
\n EmptyRelation\
\n Projection: #Float64(2.1) + x.a\
\n Aggregate: groupBy=[[Float64(2.1) + #x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Int64(1) AS a, alias=x\
\n EmptyRelation";
quick_test(sql, expected);
}

#[test]
fn union_with_aliases() {
let sql = "SELECT a as a1 FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT a as a1 FROM (select 1.1 a) x GROUP BY 1)";
let expected = "Union\
\n Projection: CAST(#x.a AS Float64) AS a1\
\n Aggregate: groupBy=[[#x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Int64(1) AS a, alias=x\
\n EmptyRelation\
\n Projection: #x.a AS a1\
\n Aggregate: groupBy=[[#x.a]], aggr=[[]]\
\n Projection: #x.a, alias=x\
\n Projection: Float64(1.1) AS a, alias=x\
\n EmptyRelation";
quick_test(sql, expected);
}

#[test]
fn union_with_incompatible_data_types() {
let sql = "SELECT 'a' a UNION ALL SELECT true a";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
"Plan(\"UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)\")",
format!("{:?}", err)
);
}

#[test]
fn empty_over() {
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";
Expand Down