From d6963c08a0dcdb3fcd3dbdac5ca8cc9a27e6cebc Mon Sep 17 00:00:00 2001
From: "Jorge C. Leitao"
Date: Sat, 22 Aug 2020 16:19:33 +0200
Subject: [PATCH] Fixed type coercion.

This commit makes all type coercion happen in the physical plan instead of the
logical plan, so that field names no longer change due to coercion rules. The
rationale for this change is that coercions are simplifications of a physical
computation (it is easier to sum two numbers of the same type at the hardware
level). The logical plan therefore no longer needs to worry about type
coercion, only about the resulting type of each operator.

This also addresses an issue where the physical schema could be modified by
coercion rules, causing a RecordBatch's schema to differ from the schema
declared by the logical plan.

It also fixes some inconsistencies in how certain types were coerced for
binary operators; such invalid combinations now error during planning instead
of during execution.

This closes ARROW-9809 and ARROW-4957.
---
Note: illustrative usage sketches of the new coercion helpers are appended after the diff.

 .../execution/physical_plan/expressions.rs    | 468 +++++++++++++++---
 .../src/execution/physical_plan/filter.rs     |  17 +-
 .../src/execution/physical_plan/planner.rs    | 120 ++++-
 rust/datafusion/src/logicalplan.rs            |  34 +-
 .../datafusion/src/optimizer/type_coercion.rs | 170 +------
 rust/datafusion/src/optimizer/utils.rs        | 117 +----
 rust/datafusion/tests/sql.rs                  |  36 +-
 7 files changed, 594 insertions(+), 368 deletions(-)

diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs
index 8364134f8ea..a3304dabc28 100644
--- a/rust/datafusion/src/execution/physical_plan/expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/expressions.rs
@@ -37,9 +37,9 @@ use arrow::array::{
     UInt8Builder,
 };
 use arrow::compute;
+use arrow::compute::kernels;
 use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract};
 use arrow::compute::kernels::boolean::{and, or};
-use arrow::compute::kernels::cast::cast;
 use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
 use arrow::compute::kernels::comparison::{
     eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8,
@@ -991,30 +991,240 @@ impl fmt::Display for BinaryExpr {
     }
 }

+// Returns an error describing that the given types cannot be coerced for the binary operator.
+fn coercion_error( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + Err(ExecutionError::General( + format!( + "The binary operator '{}' can't evaluate with lhs = '{:?}' and rhs = '{:?}'", + op, lhs_type, rhs_type + ) + .to_string(), + )) +} + +// the type that both lhs and rhs can be casted to for the purpose of a string computation +fn string_coercion( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Utf8, Utf8) => Ok(Utf8), + (LargeUtf8, Utf8) => Ok(LargeUtf8), + (Utf8, LargeUtf8) => Ok(LargeUtf8), + (LargeUtf8, LargeUtf8) => Ok(LargeUtf8), + _ => coercion_error(lhs_type, op, rhs_type), + } +} + +/// coercion rule for numerical values +pub fn numerical_coercion( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + use arrow::datatypes::DataType::*; + + // error on any non-numeric type + if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + return coercion_error(lhs_type, op, rhs_type); + }; + + // same type => all good + if lhs_type == rhs_type { + return Ok(lhs_type.clone()); + } + + // these are ordered from most informative to least informative so + // that the coercion removes the least amount of information + match (lhs_type, rhs_type) { + (Float64, _) => Ok(Float64), + (_, Float64) => Ok(Float64), + + (_, Float32) => Ok(Float32), + (Float32, _) => Ok(Float32), + + (Int64, _) => Ok(Int64), + (_, Int64) => Ok(Int64), + + (Int32, _) => Ok(Int32), + (_, Int32) => Ok(Int32), + + (Int16, _) => Ok(Int16), + (_, Int16) => Ok(Int16), + + (Int8, _) => Ok(Int8), + (_, Int8) => Ok(Int8), + + (UInt64, _) => Ok(UInt64), + (_, UInt64) => Ok(UInt64), + + (UInt32, _) => Ok(UInt32), + (_, UInt32) => Ok(UInt32), + + (UInt16, _) => Ok(UInt16), + (_, UInt16) => Ok(UInt16), + + (UInt8, _) => Ok(UInt8), + (_, UInt8) => Ok(UInt8), + + _ => coercion_error(lhs_type, op, rhs_type), + } +} + +// coercion rules for `equal` and `not equal`. This is a superset of all numerical coercion rules. +fn eq_coercion( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + if lhs_type == rhs_type { + // same type => equality is possible + return Ok(lhs_type.clone()); + } + numerical_coercion(lhs_type, op, rhs_type) +} + +// coercion rules for operators that assume an ordered set, such as "less than". +// These are the union of all numerical coercion rules and all string coercion rules +fn order_coercion( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + if lhs_type == rhs_type { + // same type => all good + return Ok(lhs_type.clone()); + } + + match numerical_coercion(lhs_type, op, rhs_type) { + Err(_) => { + // strings are naturally ordered, and thus ordering can be applied to them. + string_coercion(lhs_type, op, rhs_type) + } + t => t, + } +} + +/// Returns the return type of a binary operator or an error +/// when the binary operator cannot correctly perform the computation between the argument's types, even after +/// trying to coerce them. +/// +/// This function makes some assumptions about the underlying available computations. 
+pub fn binary_operator_data_type( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + // This result MUST be compatible with `binary_coerce` + match op { + // logical binary boolean operators can only be evaluated in bools + Operator::And | Operator::Or => match (lhs_type, rhs_type) { + (DataType::Boolean, DataType::Boolean) => Ok(DataType::Boolean), + _ => coercion_error(lhs_type, op, rhs_type), + }, + // logical equality operators have their own rules, and always return a boolean + Operator::Eq | Operator::NotEq => { + // validate that the types are valid + eq_coercion(lhs_type, op, rhs_type)?; + Ok(DataType::Boolean) + } + // "like" operators operate on strings and always return a boolean + Operator::Like | Operator::NotLike => { + // validate that the types are valid + string_coercion(lhs_type, op, rhs_type)?; + Ok(DataType::Boolean) + } + // order-comparison operators have their own rules + Operator::Lt | Operator::Gt | Operator::GtEq | Operator::LtEq => { + // validate that the types are valid + order_coercion(lhs_type, op, rhs_type)?; + Ok(DataType::Boolean) + } + // for math expressions, the final value of the coercion is also the return type + // because coercion favours higher information types + Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => { + numerical_coercion(lhs_type, op, rhs_type) + } + Operator::Modulus => Err(ExecutionError::NotImplemented( + "Modulus operator is still not supported".to_string(), + )), + Operator::Not => Err(ExecutionError::InternalError( + "Trying to coerce a unary operator".to_string(), + )), + } +} + +/// return a binary physical expression that includes any necessary coercion of its arguments. +/// The coercion rule depends on the operator. +// This function MUST be compatible with `binary_operator_data_type` in that the resulting type +// from this function's expression must match `binary_operator_data_type` type: +// binary_coerce(lhs, op, rhs, schema).type === binary_operator_data_type(lhs.type, op, rhs.type) +fn binary_coerce( + lhs: Arc, + op: &Operator, + rhs: Arc, + input_schema: &Schema, +) -> Result<(Arc, Arc)> { + let lhs_type = &lhs.data_type(input_schema)?; + let rhs_type = &rhs.data_type(input_schema)?; + + match op { + // logical binary boolean operators can only be evaluated in bools + Operator::And | Operator::Or => match (lhs_type, rhs_type) { + (DataType::Boolean, DataType::Boolean) => Ok((lhs.clone(), rhs.clone())), + _ => coercion_error(lhs_type, op, rhs_type), + }, + Operator::Eq | Operator::NotEq => { + // validate that the types are valid + let cast_type = eq_coercion(lhs_type, op, rhs_type)?; + Ok(( + cast(lhs, input_schema, cast_type.clone())?, + cast(rhs, input_schema, cast_type)?, + )) + } + Operator::Like | Operator::NotLike => { + let cast_type = string_coercion(lhs_type, op, rhs_type)?; + Ok(( + cast(lhs, input_schema, cast_type.clone())?, + cast(rhs, input_schema, cast_type)?, + )) + } + Operator::Lt | Operator::Gt | Operator::GtEq | Operator::LtEq => { + let cast_type = order_coercion(lhs_type, op, rhs_type)?; + Ok(( + cast(lhs, input_schema, cast_type.clone())?, + cast(rhs, input_schema, cast_type)?, + )) + } + Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => { + let cast_type = numerical_coercion(lhs_type, op, rhs_type)?; + Ok(( + cast(lhs, input_schema, cast_type.clone())?, + cast(rhs, input_schema, cast_type)?, + )) + } + Operator::Modulus => Err(ExecutionError::NotImplemented( + "Modulus operator is still not 
supported".to_string(), + )), + Operator::Not => Err(ExecutionError::InternalError( + "Trying to coerce a unary operator ".to_string(), + )), + } +} + impl PhysicalExpr for BinaryExpr { fn data_type(&self, input_schema: &Schema) -> Result { - Ok(match self.op { - Operator::And - | Operator::Or - | Operator::Not - | Operator::NotLike - | Operator::Like - | Operator::Lt - | Operator::LtEq - | Operator::Eq - | Operator::NotEq - | Operator::Gt - | Operator::GtEq => DataType::Boolean, - Operator::Plus - | Operator::Minus - | Operator::Multiply - | Operator::Divide - | Operator::Modulus => { - // this assumes that the left and right expressions have already been co-coerced - // to the same type - self.left.data_type(input_schema)? - } - }) + binary_operator_data_type( + &self.left.data_type(input_schema)?, + &self.op, + &self.right.data_type(input_schema)?, + ) } fn nullable(&self, input_schema: &Schema) -> Result { @@ -1069,18 +1279,27 @@ impl PhysicalExpr for BinaryExpr { ))); } } - _ => Err(ExecutionError::General("Unsupported operator".to_string())), + Operator::Modulus => Err(ExecutionError::NotImplemented( + "Modulus operator is still not supported".to_string(), + )), + Operator::Not => { + Err(ExecutionError::General("Unsupported operator".to_string())) + } } } } -/// Create a binary expression +/// Create a binary expression whose arguments are correctly coerced. +/// This function errors if it is not possible to coerce the arguments +/// to computational types supported by the operator. pub fn binary( - l: Arc, + lhs: Arc, op: Operator, - r: Arc, -) -> Arc { - Arc::new(BinaryExpr::new(l, op, r)) + rhs: Arc, + input_schema: &Schema, +) -> Result> { + let (l, r) = binary_coerce(lhs, &op, rhs, input_schema)?; + Ok(Arc::new(BinaryExpr::new(l, op, r))) } /// Not expression @@ -1151,34 +1370,6 @@ fn is_numeric(dt: &DataType) -> bool { } } -impl CastExpr { - /// Create a CAST expression - pub fn try_new( - expr: Arc, - input_schema: &Schema, - cast_type: DataType, - ) -> Result { - let expr_type = expr.data_type(input_schema)?; - // numbers can be cast to numbers and strings - if is_numeric(&expr_type) - && (is_numeric(&cast_type) || cast_type == DataType::Utf8) - { - Ok(Self { expr, cast_type }) - } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 { - Ok(Self { expr, cast_type }) - } else if is_numeric(&expr_type) - && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None) - { - Ok(Self { expr, cast_type }) - } else { - Err(ExecutionError::General(format!( - "Invalid CAST from {:?} to {:?}", - expr_type, cast_type - ))) - } - } -} - impl fmt::Display for CastExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "CAST({} AS {:?})", self.expr, self.cast_type) @@ -1196,7 +1387,33 @@ impl PhysicalExpr for CastExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; - Ok(cast(&value, &self.cast_type)?) + Ok(kernels::cast::cast(&value, &self.cast_type)?) + } +} + +/// Returns a cast operation, if casting needed. 
+pub fn cast( + expr: Arc, + input_schema: &Schema, + cast_type: DataType, +) -> Result> { + let expr_type = expr.data_type(input_schema)?; + if expr_type == cast_type { + return Ok(expr.clone()); + } + if is_numeric(&expr_type) && (is_numeric(&cast_type) || cast_type == DataType::Utf8) { + Ok(Arc::new(CastExpr { expr, cast_type })) + } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 { + Ok(Arc::new(CastExpr { expr, cast_type })) + } else if is_numeric(&expr_type) + && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None) + { + Ok(Arc::new(CastExpr { expr, cast_type })) + } else { + Err(ExecutionError::General(format!( + "Invalid CAST from {:?} to {:?}", + expr_type, cast_type + ))) } } @@ -1316,6 +1533,16 @@ mod tests { }; use arrow::datatypes::*; + // Create a binary expression without coercion. Used here when we do not want to coerce the expressions + // to valid types. Usage can result in an execution (after plan) error. + fn binary_simple( + l: Arc, + op: Operator, + r: Arc, + ) -> Arc { + Arc::new(BinaryExpr::new(l, op, r)) + } + #[test] fn binary_comparison() -> Result<()> { let schema = Schema::new(vec![ @@ -1330,7 +1557,7 @@ mod tests { )?; // expression: "a < b" - let lt = binary(col("a"), Operator::Lt, col("b")); + let lt = binary_simple(col("a"), Operator::Lt, col("b")); let result = lt.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1360,10 +1587,10 @@ mod tests { )?; // expression: "a < b OR a == b" - let expr = binary( - binary(col("a"), Operator::Lt, col("b")), + let expr = binary_simple( + binary_simple(col("a"), Operator::Lt, col("b")), Operator::Or, - binary(col("a"), Operator::Eq, col("b")), + binary_simple(col("a"), Operator::Eq, col("b")), ); assert_eq!("a < b OR a = b", format!("{}", expr)); @@ -1406,13 +1633,120 @@ mod tests { Ok(()) } + // runs an end-to-end test of physical type coercion: + // 1. construct a record batch with two columns of type A and B + // 2. construct a physical expression of A OP B + // 3. evaluate the expression + // 4. verify that the resulting expression is of type C + macro_rules! 
test_coercion { + ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{ + let schema = Schema::new(vec![ + Field::new("a", $A_TYPE, false), + Field::new("b", $B_TYPE, false), + ]); + let a = $A_ARRAY::from($A_VEC); + let b = $B_ARRAY::from($B_VEC); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b)], + )?; + + // verify that we can construct the expression + let expression = binary(col("a"), $OP, col("b"), &schema)?; + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema)?, $TYPE); + + // compute + let result = expression.evaluate(&batch)?; + + // verify that the array's data_type is correct + assert_eq!(*result.data_type(), $TYPE); + + // verify that the data itself is downcastable + let result = result + .as_any() + .downcast_ref::<$TYPEARRAY>() + .expect("failed to downcast"); + // verify that the result itself is correct + for (i, x) in $VEC.iter().enumerate() { + assert_eq!(result.value(i), *x); + } + }}; + } + + #[test] + fn test_type_coersion() -> Result<()> { + test_coercion!( + Int32Array, + DataType::Int32, + vec![1i32, 2i32], + UInt32Array, + DataType::UInt32, + vec![1u32, 2u32], + Operator::Plus, + Int32Array, + DataType::Int32, + vec![2i32, 4i32] + ); + test_coercion!( + Int32Array, + DataType::Int32, + vec![1i32], + UInt16Array, + DataType::UInt16, + vec![1u16], + Operator::Plus, + Int32Array, + DataType::Int32, + vec![2i32] + ); + test_coercion!( + Float32Array, + DataType::Float32, + vec![1f32], + UInt16Array, + DataType::UInt16, + vec![1u16], + Operator::Plus, + Float32Array, + DataType::Float32, + vec![2f32] + ); + test_coercion!( + Float32Array, + DataType::Float32, + vec![2f32], + UInt16Array, + DataType::UInt16, + vec![1u16], + Operator::Multiply, + Float32Array, + DataType::Float32, + vec![2f32] + ); + test_coercion!( + StringArray, + DataType::Utf8, + vec!["hello world", "world"], + StringArray, + DataType::Utf8, + vec!["%hello%", "%hello%"], + Operator::Like, + BooleanArray, + DataType::Boolean, + vec![true, false] + ); + Ok(()) + } + #[test] fn cast_i32_to_u32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let cast = CastExpr::try_new(col("a"), &schema, DataType::UInt32)?; + let cast = cast(col("a"), &schema, DataType::UInt32)?; assert_eq!("CAST(a AS UInt32)", format!("{}", cast)); let result = cast.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1432,7 +1766,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let cast = CastExpr::try_new(col("a"), &schema, DataType::Utf8)?; + let cast = cast(col("a"), &schema, DataType::Utf8)?; let result = cast.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1451,7 +1785,7 @@ mod tests { let a = Int64Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let cast = CastExpr::try_new( + let cast = cast( col("a"), &schema, DataType::Timestamp(TimeUnit::Nanosecond, None), @@ -1471,7 +1805,7 @@ mod tests { #[test] fn invalid_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); - let result = CastExpr::try_new(col("a"), &schema, DataType::Int32); + let result = cast(col("a"), &schema, 
DataType::Int32); result.expect_err("Invalid CAST from Utf8 to Int32"); Ok(()) } @@ -2085,7 +2419,7 @@ mod tests { op: Operator, expected: PrimitiveArray, ) -> Result<()> { - let arithmetic_op = binary(col("a"), op, col("b")); + let arithmetic_op = binary_simple(col("a"), op, col("b")); let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?; diff --git a/rust/datafusion/src/execution/physical_plan/filter.rs b/rust/datafusion/src/execution/physical_plan/filter.rs index f771bd59e7c..fee7fe4bcaa 100644 --- a/rust/datafusion/src/execution/physical_plan/filter.rs +++ b/rust/datafusion/src/execution/physical_plan/filter.rs @@ -183,10 +183,21 @@ mod tests { CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?; let predicate: Arc = binary( - binary(col("c2"), Operator::Gt, lit(ScalarValue::UInt32(1))), + binary( + col("c2"), + Operator::Gt, + lit(ScalarValue::UInt32(1)), + &schema, + )?, Operator::And, - binary(col("c2"), Operator::Lt, lit(ScalarValue::UInt32(4))), - ); + binary( + col("c2"), + Operator::Lt, + lit(ScalarValue::UInt32(4)), + &schema, + )?, + &schema, + )?; let filter: Arc = Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?); diff --git a/rust/datafusion/src/execution/physical_plan/planner.rs b/rust/datafusion/src/execution/physical_plan/planner.rs index 2e83cd0b9f1..b12304382af 100644 --- a/rust/datafusion/src/execution/physical_plan/planner.rs +++ b/rust/datafusion/src/execution/physical_plan/planner.rs @@ -19,13 +19,15 @@ use std::sync::Arc; +use super::expressions::binary; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContextState; use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::explain::ExplainExec; +use crate::execution::physical_plan::expressions; use crate::execution::physical_plan::expressions::{ - Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum, + Avg, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum, }; use crate::execution::physical_plan::filter::FilterExec; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; @@ -312,16 +314,18 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(Column::new(name))) } Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), - Expr::BinaryExpr { left, op, right } => Ok(Arc::new(BinaryExpr::new( - self.create_physical_expr(left, input_schema, ctx_state)?, - op.clone(), - self.create_physical_expr(right, input_schema, ctx_state)?, - ))), - Expr::Cast { expr, data_type } => Ok(Arc::new(CastExpr::try_new( - self.create_physical_expr(expr, input_schema, ctx_state)?, + Expr::BinaryExpr { left, op, right } => { + let lhs = + self.create_physical_expr(left, input_schema, ctx_state.clone())?; + let rhs = + self.create_physical_expr(right, input_schema, ctx_state.clone())?; + binary(lhs, op.clone(), rhs, input_schema) + } + Expr::Cast { expr, data_type } => expressions::cast( + self.create_physical_expr(expr, input_schema, ctx_state.clone())?, input_schema, data_type.clone(), - )?)), + ), Expr::ScalarFunction { name, args, @@ -426,3 +430,101 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { (Err(e), Err(_)) => Err(e), } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::execution::physical_plan::csv::CsvReadOptions; + use crate::logicalplan::{aggregate_expr, col, lit, LogicalPlanBuilder}; + use crate::{prelude::ExecutionConfig, 
test::arrow_testdata_path}; + use std::collections::HashMap; + + fn plan(logical_plan: &LogicalPlan) -> Result> { + let ctx_state = ExecutionContextState { + datasources: Box::new(HashMap::new()), + scalar_functions: Box::new(HashMap::new()), + config: ExecutionConfig::new(), + }; + + let planer = DefaultPhysicalPlanner {}; + planer.create_physical_plan(logical_plan, &ctx_state) + } + + #[test] + fn test_all_operators() -> Result<()> { + let testdata = arrow_testdata_path(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + + let options = CsvReadOptions::new().schema_infer_max_records(100); + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? + // filter clause needs the type coercion rule applied + .filter(col("c7").lt(lit(5_u8)))? + .project(vec![col("c1"), col("c2")])? + .aggregate(vec![col("c1")], vec![aggregate_expr("SUM", col("c2"))])? + .sort(vec![col("c1").sort(true, true)])? + .limit(10)? + .build()?; + + let plan = plan(&logical_plan)?; + + // verify that the plan correctly casts u8 to i64 + let expected = "BinaryExpr { left: Column { name: \"c7\" }, op: Lt, right: CastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }"; + assert!(format!("{:?}", plan).contains(expected)); + + Ok(()) + } + + #[test] + fn test_with_csv_plan() -> Result<()> { + let testdata = arrow_testdata_path(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + + let options = CsvReadOptions::new().schema_infer_max_records(100); + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? + .filter(col("c7").lt(col("c12")))? + .build()?; + + let plan = plan(&logical_plan)?; + + // c12 is f64, c7 is u8 -> cast c7 to f64 + let expected = "predicate: BinaryExpr { left: CastExpr { expr: Column { name: \"c7\" }, cast_type: Float64 }, op: Lt, right: Column { name: \"c12\" } }"; + assert!(format!("{:?}", plan).contains(expected)); + Ok(()) + } + + #[test] + fn errors() -> Result<()> { + let testdata = arrow_testdata_path(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + let options = CsvReadOptions::new().schema_infer_max_records(100); + + let bool_expr = col("c1").eq(col("c1")); + let cases = vec![ + // utf8 < u32 + col("c1").lt(col("c2")), + // utf8 AND utf8 + col("c1").and(col("c1")), + // u8 AND u8 + col("c3").and(col("c3")), + // utf8 = u32 + col("c1").eq(col("c2")), + // utf8 = bool + col("c1").eq(bool_expr.clone()), + // u32 AND bool + col("c2").and(bool_expr), + // utf8 LIKE u32 + col("c1").like(col("c2")), + ]; + for case in cases { + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? 
+ .project(vec![case.clone()]); + if let Ok(_) = logical_plan { + return Err(ExecutionError::General(format!( + "Expression {:?} expected to error due to impossible coercion", + case + ))); + }; + } + Ok(()) + } +} diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index a56cb46a617..0c8289484c7 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -29,8 +29,10 @@ use crate::datasource::csv::{CsvFile, CsvReadOptions}; use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; use crate::error::{ExecutionError, Result}; -use crate::optimizer::utils; -use crate::sql::parser::FileType; +use crate::{ + execution::physical_plan::expressions::binary_operator_data_type, + sql::parser::FileType, +}; use arrow::record_batch::RecordBatch; /// Enumeration of supported function types (Scalar and Aggregate) @@ -389,19 +391,11 @@ impl Expr { ref left, ref right, ref op, - } => match op { - Operator::Not => Ok(DataType::Boolean), - Operator::Like | Operator::NotLike => Ok(DataType::Boolean), - Operator::Eq | Operator::NotEq => Ok(DataType::Boolean), - Operator::Lt | Operator::LtEq => Ok(DataType::Boolean), - Operator::Gt | Operator::GtEq => Ok(DataType::Boolean), - Operator::And | Operator::Or => Ok(DataType::Boolean), - _ => { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - utils::get_supertype(&left_type, &right_type) - } - }, + } => binary_operator_data_type( + &left.get_type(schema)?, + op, + &right.get_type(schema)?, + ), Expr::Sort { ref expr, .. } => expr.get_type(schema), Expr::Wildcard => Err(ExecutionError::General( "Wildcard expressions are not valid in a logical query plan".to_owned(), @@ -544,6 +538,16 @@ impl Expr { binary_expr(self.clone(), Operator::Modulus, other.clone()) } + /// like (string) another expression + pub fn like(&self, other: Expr) -> Expr { + binary_expr(self.clone(), Operator::Like, other.clone()) + } + + /// not like another expression + pub fn not_like(&self, other: Expr) -> Expr { + binary_expr(self.clone(), Operator::NotLike, other.clone()) + } + /// Alias pub fn alias(&self, name: &str) -> Expr { Expr::Alias(Box::new(self.clone()), name.to_owned()) diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 22f3ef020e6..0d7773cc309 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -26,9 +26,11 @@ use std::sync::Arc; use arrow::datatypes::Schema; use crate::error::{ExecutionError, Result}; -use crate::execution::physical_plan::udf::ScalarFunction; +use crate::execution::physical_plan::{ + expressions::numerical_coercion, udf::ScalarFunction, +}; use crate::logicalplan::Expr; -use crate::logicalplan::LogicalPlan; +use crate::logicalplan::{LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use utils::optimize_explain; @@ -61,16 +63,6 @@ impl TypeCoercionRule { // modify `expressions` by introducing casts when necessary match expr { - Expr::BinaryExpr { .. } => { - let left_type = expressions[0].get_type(schema)?; - let right_type = expressions[1].get_type(schema)?; - if left_type != right_type { - let super_type = utils::get_supertype(&left_type, &right_type)?; - - expressions[0] = expressions[0].cast_to(&super_type, schema)?; - expressions[1] = expressions[1].cast_to(&super_type, schema)?; - } - } Expr::ScalarFunction { name, .. 
} => { // cast the inputs of scalar functions to the appropriate type where possible match self.scalar_functions.get(name) { @@ -80,10 +72,19 @@ impl TypeCoercionRule { let actual_type = expressions[i].get_type(schema)?; let required_type = field.data_type(); if &actual_type != required_type { - let super_type = - utils::get_supertype(&actual_type, required_type)?; - expressions[i] = - expressions[i].cast_to(&super_type, schema)? + // attempt to coerce using numerical coercion + // todo: also try string coercion. + if let Ok(cast_to_type) = numerical_coercion( + &actual_type, + // assume that the function behaves like plus + // plus is not special here; the optimizer is just trying its best... + &Operator::Plus, + required_type, + ) { + expressions[i] = + expressions[i].cast_to(&cast_to_type, schema)? + }; + // not possible: do nothing and let the plan fail with a clear error message }; } } @@ -141,140 +142,3 @@ impl OptimizerRule for TypeCoercionRule { return "type_coercion"; } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::execution::context::ExecutionContext; - use crate::execution::physical_plan::csv::CsvReadOptions; - use crate::logicalplan::{aggregate_expr, col, lit, LogicalPlanBuilder, Operator}; - use crate::test::arrow_testdata_path; - use arrow::datatypes::{DataType, Field, Schema}; - - #[test] - fn test_all_operators() -> Result<()> { - let testdata = arrow_testdata_path(); - let path = format!("{}/csv/aggregate_test_100.csv", testdata); - - let options = CsvReadOptions::new().schema_infer_max_records(100); - let plan = LogicalPlanBuilder::scan_csv(&path, options, None)? - // filter clause needs the type coercion rule applied - .filter(col("c7").lt(lit(5_u8)))? - .project(vec![col("c1"), col("c2")])? - .aggregate(vec![col("c1")], vec![aggregate_expr("SUM", col("c2"))])? - .sort(vec![col("c1")])? - .limit(10)? - .build()?; - - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(&scalar_functions); - let plan = rule.optimize(&plan)?; - - // check that the filter had a cast added - let plan_str = format!("{:?}", plan); - println!("{}", plan_str); - let expected_plan_str = "Limit: 10 - Sort: #c1 - Aggregate: groupBy=[[#c1]], aggr=[[SUM(#c2)]] - Projection: #c1, #c2 - Filter: #c7 Lt CAST(UInt8(5) AS Int64)"; - assert!(plan_str.starts_with(expected_plan_str)); - - Ok(()) - } - - #[test] - fn test_with_csv_plan() -> Result<()> { - let testdata = arrow_testdata_path(); - let path = format!("{}/csv/aggregate_test_100.csv", testdata); - - let options = CsvReadOptions::new().schema_infer_max_records(100); - let plan = LogicalPlanBuilder::scan_csv(&path, options, None)? - .filter(col("c7").lt(col("c12")))? 
- .build()?; - - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(&scalar_functions); - let plan = rule.optimize(&plan)?; - - assert!(format!("{:?}", plan).starts_with("Filter: CAST(#c7 AS Float64) Lt #c12")); - - Ok(()) - } - - #[test] - fn test_add_i32_i64() { - binary_cast_test( - DataType::Int32, - DataType::Int64, - "CAST(#c0 AS Int64) Plus #c1", - ); - binary_cast_test( - DataType::Int64, - DataType::Int32, - "#c0 Plus CAST(#c1 AS Int64)", - ); - } - - #[test] - fn test_add_f32_f64() { - binary_cast_test( - DataType::Float32, - DataType::Float64, - "CAST(#c0 AS Float64) Plus #c1", - ); - binary_cast_test( - DataType::Float64, - DataType::Float32, - "#c0 Plus CAST(#c1 AS Float64)", - ); - } - - #[test] - fn test_add_i32_f32() { - binary_cast_test( - DataType::Int32, - DataType::Float32, - "CAST(#c0 AS Float32) Plus #c1", - ); - binary_cast_test( - DataType::Float32, - DataType::Int32, - "#c0 Plus CAST(#c1 AS Float32)", - ); - } - - #[test] - fn test_add_u32_i64() { - binary_cast_test( - DataType::UInt32, - DataType::Int64, - "CAST(#c0 AS Int64) Plus #c1", - ); - binary_cast_test( - DataType::Int64, - DataType::UInt32, - "#c0 Plus CAST(#c1 AS Int64)", - ); - } - - fn binary_cast_test(left_type: DataType, right_type: DataType, expected: &str) { - let schema = Schema::new(vec![ - Field::new("c0", left_type, true), - Field::new("c1", right_type, true), - ]); - - let expr = Expr::BinaryExpr { - left: Box::new(col("c0")), - op: Operator::Plus, - right: Box::new(col("c1")), - }; - - let ctx = ExecutionContext::new(); - let rule = TypeCoercionRule::new(ctx.scalar_functions().as_ref()); - - let expr2 = rule.rewrite_expr(&expr, &schema).unwrap(); - - assert_eq!(expected, format!("{:?}", expr2)); - } -} diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 695bb0ea4f3..b8e037f40ca 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -19,7 +19,7 @@ use std::collections::HashSet; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::Schema; use super::optimizer::OptimizerRule; use crate::error::{ExecutionError, Result}; @@ -69,121 +69,6 @@ pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result< } } -/// Given two datatypes, determine the supertype that both types can safely be cast to -pub fn get_supertype(l: &DataType, r: &DataType) -> Result { - match _get_supertype(l, r) { - Some(dt) => Ok(dt), - None => _get_supertype(r, l).ok_or_else(|| { - ExecutionError::InternalError(format!( - "Failed to determine supertype of {:?} and {:?}", - l, r - )) - }), - } -} - -/// Given two datatypes, determine the supertype that both types can safely be cast to -fn _get_supertype(l: &DataType, r: &DataType) -> Option { - use arrow::datatypes::DataType::*; - match (l, r) { - (UInt8, Int8) => Some(Int8), - (UInt8, Int16) => Some(Int16), - (UInt8, Int32) => Some(Int32), - (UInt8, Int64) => Some(Int64), - - (UInt16, Int16) => Some(Int16), - (UInt16, Int32) => Some(Int32), - (UInt16, Int64) => Some(Int64), - - (UInt32, Int32) => Some(Int32), - (UInt32, Int64) => Some(Int64), - - (UInt64, Int64) => Some(Int64), - - (Int8, UInt8) => Some(Int8), - - (Int16, UInt8) => Some(Int16), - (Int16, UInt16) => Some(Int16), - - (Int32, UInt8) => Some(Int32), - (Int32, UInt16) => Some(Int32), - (Int32, UInt32) => Some(Int32), - - (Int64, UInt8) => Some(Int64), - (Int64, UInt16) => Some(Int64), - (Int64, UInt32) => Some(Int64), - (Int64, UInt64) => Some(Int64), - - (UInt8, UInt8) => 
Some(UInt8), - (UInt8, UInt16) => Some(UInt16), - (UInt8, UInt32) => Some(UInt32), - (UInt8, UInt64) => Some(UInt64), - (UInt8, Float32) => Some(Float32), - (UInt8, Float64) => Some(Float64), - - (UInt16, UInt8) => Some(UInt16), - (UInt16, UInt16) => Some(UInt16), - (UInt16, UInt32) => Some(UInt32), - (UInt16, UInt64) => Some(UInt64), - (UInt16, Float32) => Some(Float32), - (UInt16, Float64) => Some(Float64), - - (UInt32, UInt8) => Some(UInt32), - (UInt32, UInt16) => Some(UInt32), - (UInt32, UInt32) => Some(UInt32), - (UInt32, UInt64) => Some(UInt64), - (UInt32, Float32) => Some(Float32), - (UInt32, Float64) => Some(Float64), - - (UInt64, UInt8) => Some(UInt64), - (UInt64, UInt16) => Some(UInt64), - (UInt64, UInt32) => Some(UInt64), - (UInt64, UInt64) => Some(UInt64), - (UInt64, Float32) => Some(Float32), - (UInt64, Float64) => Some(Float64), - - (Int8, Int8) => Some(Int8), - (Int8, Int16) => Some(Int16), - (Int8, Int32) => Some(Int32), - (Int8, Int64) => Some(Int64), - (Int8, Float32) => Some(Float32), - (Int8, Float64) => Some(Float64), - - (Int16, Int8) => Some(Int16), - (Int16, Int16) => Some(Int16), - (Int16, Int32) => Some(Int32), - (Int16, Int64) => Some(Int64), - (Int16, Float32) => Some(Float32), - (Int16, Float64) => Some(Float64), - - (Int32, Int8) => Some(Int32), - (Int32, Int16) => Some(Int32), - (Int32, Int32) => Some(Int32), - (Int32, Int64) => Some(Int64), - (Int32, Float32) => Some(Float32), - (Int32, Float64) => Some(Float64), - - (Int64, Int8) => Some(Int64), - (Int64, Int16) => Some(Int64), - (Int64, Int32) => Some(Int64), - (Int64, Int64) => Some(Int64), - (Int64, Float32) => Some(Float32), - (Int64, Float64) => Some(Float64), - - (Float32, Float32) => Some(Float32), - (Float32, Float64) => Some(Float64), - (Float64, Float32) => Some(Float64), - (Float64, Float64) => Some(Float64), - - (Utf8, _) => Some(Utf8), - (_, Utf8) => Some(Utf8), - - (Boolean, Boolean) => Some(Boolean), - - _ => None, - } -} - /// Create a `LogicalPlan::Explain` node by running `optimizer` on the /// input plan and capturing the resulting plan string pub fn optimize_explain( diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 942a781a3a2..e9e09b4e7ca 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -470,11 +470,7 @@ fn csv_explain_verbose() { assert!(actual.contains("logical_plan"), "Actual: '{}'", actual); assert!(actual.contains("physical_plan"), "Actual: '{}'", actual); assert!(actual.contains("type_coercion"), "Actual: '{}'", actual); - assert!( - actual.contains("CAST(#c2 AS Int64) Gt Int64(10)"), - "Actual: '{}'", - actual - ); + assert!(actual.contains("#c2 Gt Int64(10)"), "Actual: '{}'", actual); } fn aggr_test_schema() -> SchemaRef { @@ -626,6 +622,13 @@ fn result_str(results: &[RecordBatch]) -> Vec { str.push_str(&format!("{:?}", s)); } + DataType::Boolean => { + let array = + column.as_any().downcast_ref::().unwrap(); + let s = array.value(row_index); + + str.push_str(&format!("{:?}", s)); + } _ => str.push_str("???"), } } @@ -654,3 +657,26 @@ fn query_length() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[test] +fn csv_query_sum_cast() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx); + // c8 = i32; c9 = i64 + let sql = "SELECT c8 + c9 FROM aggregate_test_100"; + // check that the physical and logical schemas are equal + execute(&mut ctx, sql); +} + +#[test] +fn like() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx); + let sql 
= "SELECT COUNT(c1) FROM aggregate_test_100 WHERE c13 LIKE '%FB%'"; + // check that the physical and logical schemas are equal + let actual = execute(&mut ctx, sql).join("\n"); + + let expected = "1".to_string(); + assert_eq!(expected, actual); + Ok(()) +}