diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index 8364134f8ea..a3304dabc28 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -37,9 +37,9 @@ use arrow::array::{ UInt8Builder, }; use arrow::compute; +use arrow::compute::kernels; use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract}; use arrow::compute::kernels::boolean::{and, or}; -use arrow::compute::kernels::cast::cast; use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::compute::kernels::comparison::{ eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8, @@ -991,30 +991,240 @@ impl fmt::Display for BinaryExpr { } } +// Returns a formatted error about being impossible to coerce types for the binary operator. +fn coercion_error( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + Err(ExecutionError::General( + format!( + "The binary operator '{}' can't evaluate with lhs = '{:?}' and rhs = '{:?}'", + op, lhs_type, rhs_type + ) + .to_string(), + )) +} + +// the type that both lhs and rhs can be casted to for the purpose of a string computation +fn string_coercion( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Utf8, Utf8) => Ok(Utf8), + (LargeUtf8, Utf8) => Ok(LargeUtf8), + (Utf8, LargeUtf8) => Ok(LargeUtf8), + (LargeUtf8, LargeUtf8) => Ok(LargeUtf8), + _ => coercion_error(lhs_type, op, rhs_type), + } +} + +/// coercion rule for numerical values +pub fn numerical_coercion( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + use arrow::datatypes::DataType::*; + + // error on any non-numeric type + if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + return coercion_error(lhs_type, op, rhs_type); + }; + + 
// same type => all good + if lhs_type == rhs_type { + return Ok(lhs_type.clone()); + } + + // these are ordered from most informative to least informative so + // that the coercion removes the least amount of information + match (lhs_type, rhs_type) { + (Float64, _) => Ok(Float64), + (_, Float64) => Ok(Float64), + + (_, Float32) => Ok(Float32), + (Float32, _) => Ok(Float32), + + (Int64, _) => Ok(Int64), + (_, Int64) => Ok(Int64), + + (Int32, _) => Ok(Int32), + (_, Int32) => Ok(Int32), + + (Int16, _) => Ok(Int16), + (_, Int16) => Ok(Int16), + + (Int8, _) => Ok(Int8), + (_, Int8) => Ok(Int8), + + (UInt64, _) => Ok(UInt64), + (_, UInt64) => Ok(UInt64), + + (UInt32, _) => Ok(UInt32), + (_, UInt32) => Ok(UInt32), + + (UInt16, _) => Ok(UInt16), + (_, UInt16) => Ok(UInt16), + + (UInt8, _) => Ok(UInt8), + (_, UInt8) => Ok(UInt8), + + _ => coercion_error(lhs_type, op, rhs_type), + } +} + +// coercion rules for `equal` and `not equal`. This is a superset of all numerical coercion rules. +fn eq_coercion( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + if lhs_type == rhs_type { + // same type => equality is possible + return Ok(lhs_type.clone()); + } + numerical_coercion(lhs_type, op, rhs_type) +} + +// coercion rules for operators that assume an ordered set, such as "less than". +// These are the union of all numerical coercion rules and all string coercion rules +fn order_coercion( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + if lhs_type == rhs_type { + // same type => all good + return Ok(lhs_type.clone()); + } + + match numerical_coercion(lhs_type, op, rhs_type) { + Err(_) => { + // strings are naturally ordered, and thus ordering can be applied to them. 
+ string_coercion(lhs_type, op, rhs_type) + } + t => t, + } +} + +/// Returns the return type of a binary operator or an error +/// when the binary operator cannot correctly perform the computation between the argument's types, even after +/// trying to coerce them. +/// +/// This function makes some assumptions about the underlying available computations. +pub fn binary_operator_data_type( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + // This result MUST be compatible with `binary_coerce` + match op { + // logical binary boolean operators can only be evaluated in bools + Operator::And | Operator::Or => match (lhs_type, rhs_type) { + (DataType::Boolean, DataType::Boolean) => Ok(DataType::Boolean), + _ => coercion_error(lhs_type, op, rhs_type), + }, + // logical equality operators have their own rules, and always return a boolean + Operator::Eq | Operator::NotEq => { + // validate that the types are valid + eq_coercion(lhs_type, op, rhs_type)?; + Ok(DataType::Boolean) + } + // "like" operators operate on strings and always return a boolean + Operator::Like | Operator::NotLike => { + // validate that the types are valid + string_coercion(lhs_type, op, rhs_type)?; + Ok(DataType::Boolean) + } + // order-comparison operators have their own rules + Operator::Lt | Operator::Gt | Operator::GtEq | Operator::LtEq => { + // validate that the types are valid + order_coercion(lhs_type, op, rhs_type)?; + Ok(DataType::Boolean) + } + // for math expressions, the final value of the coercion is also the return type + // because coercion favours higher information types + Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => { + numerical_coercion(lhs_type, op, rhs_type) + } + Operator::Modulus => Err(ExecutionError::NotImplemented( + "Modulus operator is still not supported".to_string(), + )), + Operator::Not => Err(ExecutionError::InternalError( + "Trying to coerce a unary operator".to_string(), + )), + } +} + +/// return a 
binary physical expression that includes any necessary coercion of its arguments. +/// The coercion rule depends on the operator. +// This function MUST be compatible with `binary_operator_data_type` in that the resulting type +// from this function's expression must match `binary_operator_data_type` type: +// binary_coerce(lhs, op, rhs, schema).type === binary_operator_data_type(lhs.type, op, rhs.type) +fn binary_coerce( + lhs: Arc, + op: &Operator, + rhs: Arc, + input_schema: &Schema, +) -> Result<(Arc, Arc)> { + let lhs_type = &lhs.data_type(input_schema)?; + let rhs_type = &rhs.data_type(input_schema)?; + + match op { + // logical binary boolean operators can only be evaluated in bools + Operator::And | Operator::Or => match (lhs_type, rhs_type) { + (DataType::Boolean, DataType::Boolean) => Ok((lhs.clone(), rhs.clone())), + _ => coercion_error(lhs_type, op, rhs_type), + }, + Operator::Eq | Operator::NotEq => { + // validate that the types are valid + let cast_type = eq_coercion(lhs_type, op, rhs_type)?; + Ok(( + cast(lhs, input_schema, cast_type.clone())?, + cast(rhs, input_schema, cast_type)?, + )) + } + Operator::Like | Operator::NotLike => { + let cast_type = string_coercion(lhs_type, op, rhs_type)?; + Ok(( + cast(lhs, input_schema, cast_type.clone())?, + cast(rhs, input_schema, cast_type)?, + )) + } + Operator::Lt | Operator::Gt | Operator::GtEq | Operator::LtEq => { + let cast_type = order_coercion(lhs_type, op, rhs_type)?; + Ok(( + cast(lhs, input_schema, cast_type.clone())?, + cast(rhs, input_schema, cast_type)?, + )) + } + Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => { + let cast_type = numerical_coercion(lhs_type, op, rhs_type)?; + Ok(( + cast(lhs, input_schema, cast_type.clone())?, + cast(rhs, input_schema, cast_type)?, + )) + } + Operator::Modulus => Err(ExecutionError::NotImplemented( + "Modulus operator is still not supported".to_string(), + )), + Operator::Not => Err(ExecutionError::InternalError( + "Trying to coerce 
a unary operator ".to_string(), + )), + } +} + impl PhysicalExpr for BinaryExpr { fn data_type(&self, input_schema: &Schema) -> Result { - Ok(match self.op { - Operator::And - | Operator::Or - | Operator::Not - | Operator::NotLike - | Operator::Like - | Operator::Lt - | Operator::LtEq - | Operator::Eq - | Operator::NotEq - | Operator::Gt - | Operator::GtEq => DataType::Boolean, - Operator::Plus - | Operator::Minus - | Operator::Multiply - | Operator::Divide - | Operator::Modulus => { - // this assumes that the left and right expressions have already been co-coerced - // to the same type - self.left.data_type(input_schema)? - } - }) + binary_operator_data_type( + &self.left.data_type(input_schema)?, + &self.op, + &self.right.data_type(input_schema)?, + ) } fn nullable(&self, input_schema: &Schema) -> Result { @@ -1069,18 +1279,27 @@ impl PhysicalExpr for BinaryExpr { ))); } } - _ => Err(ExecutionError::General("Unsupported operator".to_string())), + Operator::Modulus => Err(ExecutionError::NotImplemented( + "Modulus operator is still not supported".to_string(), + )), + Operator::Not => { + Err(ExecutionError::General("Unsupported operator".to_string())) + } } } } -/// Create a binary expression +/// Create a binary expression whose arguments are correctly coerced. +/// This function errors if it is not possible to coerce the arguments +/// to computational types supported by the operator. 
pub fn binary( - l: Arc, + lhs: Arc, op: Operator, - r: Arc, -) -> Arc { - Arc::new(BinaryExpr::new(l, op, r)) + rhs: Arc, + input_schema: &Schema, +) -> Result> { + let (l, r) = binary_coerce(lhs, &op, rhs, input_schema)?; + Ok(Arc::new(BinaryExpr::new(l, op, r))) } /// Not expression @@ -1151,34 +1370,6 @@ fn is_numeric(dt: &DataType) -> bool { } } -impl CastExpr { - /// Create a CAST expression - pub fn try_new( - expr: Arc, - input_schema: &Schema, - cast_type: DataType, - ) -> Result { - let expr_type = expr.data_type(input_schema)?; - // numbers can be cast to numbers and strings - if is_numeric(&expr_type) - && (is_numeric(&cast_type) || cast_type == DataType::Utf8) - { - Ok(Self { expr, cast_type }) - } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 { - Ok(Self { expr, cast_type }) - } else if is_numeric(&expr_type) - && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None) - { - Ok(Self { expr, cast_type }) - } else { - Err(ExecutionError::General(format!( - "Invalid CAST from {:?} to {:?}", - expr_type, cast_type - ))) - } - } -} - impl fmt::Display for CastExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "CAST({} AS {:?})", self.expr, self.cast_type) @@ -1196,7 +1387,33 @@ impl PhysicalExpr for CastExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; - Ok(cast(&value, &self.cast_type)?) + Ok(kernels::cast::cast(&value, &self.cast_type)?) + } +} + +/// Returns a cast operation, if casting needed. 
+pub fn cast( + expr: Arc, + input_schema: &Schema, + cast_type: DataType, +) -> Result> { + let expr_type = expr.data_type(input_schema)?; + if expr_type == cast_type { + return Ok(expr.clone()); + } + if is_numeric(&expr_type) && (is_numeric(&cast_type) || cast_type == DataType::Utf8) { + Ok(Arc::new(CastExpr { expr, cast_type })) + } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 { + Ok(Arc::new(CastExpr { expr, cast_type })) + } else if is_numeric(&expr_type) + && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None) + { + Ok(Arc::new(CastExpr { expr, cast_type })) + } else { + Err(ExecutionError::General(format!( + "Invalid CAST from {:?} to {:?}", + expr_type, cast_type + ))) } } @@ -1316,6 +1533,16 @@ mod tests { }; use arrow::datatypes::*; + // Create a binary expression without coercion. Used here when we do not want to coerce the expressions + // to valid types. Usage can result in an execution (after plan) error. + fn binary_simple( + l: Arc, + op: Operator, + r: Arc, + ) -> Arc { + Arc::new(BinaryExpr::new(l, op, r)) + } + #[test] fn binary_comparison() -> Result<()> { let schema = Schema::new(vec![ @@ -1330,7 +1557,7 @@ mod tests { )?; // expression: "a < b" - let lt = binary(col("a"), Operator::Lt, col("b")); + let lt = binary_simple(col("a"), Operator::Lt, col("b")); let result = lt.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1360,10 +1587,10 @@ mod tests { )?; // expression: "a < b OR a == b" - let expr = binary( - binary(col("a"), Operator::Lt, col("b")), + let expr = binary_simple( + binary_simple(col("a"), Operator::Lt, col("b")), Operator::Or, - binary(col("a"), Operator::Eq, col("b")), + binary_simple(col("a"), Operator::Eq, col("b")), ); assert_eq!("a < b OR a = b", format!("{}", expr)); @@ -1406,13 +1633,120 @@ mod tests { Ok(()) } + // runs an end-to-end test of physical type coercion: + // 1. construct a record batch with two columns of type A and B + // 2. 
construct a physical expression of A OP B + // 3. evaluate the expression + // 4. verify that the resulting expression is of type C + macro_rules! test_coercion { + ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{ + let schema = Schema::new(vec![ + Field::new("a", $A_TYPE, false), + Field::new("b", $B_TYPE, false), + ]); + let a = $A_ARRAY::from($A_VEC); + let b = $B_ARRAY::from($B_VEC); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b)], + )?; + + // verify that we can construct the expression + let expression = binary(col("a"), $OP, col("b"), &schema)?; + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema)?, $TYPE); + + // compute + let result = expression.evaluate(&batch)?; + + // verify that the array's data_type is correct + assert_eq!(*result.data_type(), $TYPE); + + // verify that the data itself is downcastable + let result = result + .as_any() + .downcast_ref::<$TYPEARRAY>() + .expect("failed to downcast"); + // verify that the result itself is correct + for (i, x) in $VEC.iter().enumerate() { + assert_eq!(result.value(i), *x); + } + }}; + } + + #[test] + fn test_type_coersion() -> Result<()> { + test_coercion!( + Int32Array, + DataType::Int32, + vec![1i32, 2i32], + UInt32Array, + DataType::UInt32, + vec![1u32, 2u32], + Operator::Plus, + Int32Array, + DataType::Int32, + vec![2i32, 4i32] + ); + test_coercion!( + Int32Array, + DataType::Int32, + vec![1i32], + UInt16Array, + DataType::UInt16, + vec![1u16], + Operator::Plus, + Int32Array, + DataType::Int32, + vec![2i32] + ); + test_coercion!( + Float32Array, + DataType::Float32, + vec![1f32], + UInt16Array, + DataType::UInt16, + vec![1u16], + Operator::Plus, + Float32Array, + DataType::Float32, + vec![2f32] + ); + test_coercion!( + Float32Array, + DataType::Float32, + vec![2f32], + UInt16Array, + DataType::UInt16, + vec![1u16], + 
Operator::Multiply, + Float32Array, + DataType::Float32, + vec![2f32] + ); + test_coercion!( + StringArray, + DataType::Utf8, + vec!["hello world", "world"], + StringArray, + DataType::Utf8, + vec!["%hello%", "%hello%"], + Operator::Like, + BooleanArray, + DataType::Boolean, + vec![true, false] + ); + Ok(()) + } + #[test] fn cast_i32_to_u32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let cast = CastExpr::try_new(col("a"), &schema, DataType::UInt32)?; + let cast = cast(col("a"), &schema, DataType::UInt32)?; assert_eq!("CAST(a AS UInt32)", format!("{}", cast)); let result = cast.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1432,7 +1766,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let cast = CastExpr::try_new(col("a"), &schema, DataType::Utf8)?; + let cast = cast(col("a"), &schema, DataType::Utf8)?; let result = cast.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1451,7 +1785,7 @@ mod tests { let a = Int64Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let cast = CastExpr::try_new( + let cast = cast( col("a"), &schema, DataType::Timestamp(TimeUnit::Nanosecond, None), @@ -1471,7 +1805,7 @@ mod tests { #[test] fn invalid_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); - let result = CastExpr::try_new(col("a"), &schema, DataType::Int32); + let result = cast(col("a"), &schema, DataType::Int32); result.expect_err("Invalid CAST from Utf8 to Int32"); Ok(()) } @@ -2085,7 +2419,7 @@ mod tests { op: Operator, expected: PrimitiveArray, ) -> Result<()> { - let arithmetic_op = binary(col("a"), op, col("b")); + let arithmetic_op = binary_simple(col("a"), op, col("b")); let batch = 
RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?; diff --git a/rust/datafusion/src/execution/physical_plan/filter.rs b/rust/datafusion/src/execution/physical_plan/filter.rs index f771bd59e7c..fee7fe4bcaa 100644 --- a/rust/datafusion/src/execution/physical_plan/filter.rs +++ b/rust/datafusion/src/execution/physical_plan/filter.rs @@ -183,10 +183,21 @@ mod tests { CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?; let predicate: Arc = binary( - binary(col("c2"), Operator::Gt, lit(ScalarValue::UInt32(1))), + binary( + col("c2"), + Operator::Gt, + lit(ScalarValue::UInt32(1)), + &schema, + )?, Operator::And, - binary(col("c2"), Operator::Lt, lit(ScalarValue::UInt32(4))), - ); + binary( + col("c2"), + Operator::Lt, + lit(ScalarValue::UInt32(4)), + &schema, + )?, + &schema, + )?; let filter: Arc = Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?); diff --git a/rust/datafusion/src/execution/physical_plan/planner.rs b/rust/datafusion/src/execution/physical_plan/planner.rs index 2e83cd0b9f1..b12304382af 100644 --- a/rust/datafusion/src/execution/physical_plan/planner.rs +++ b/rust/datafusion/src/execution/physical_plan/planner.rs @@ -19,13 +19,15 @@ use std::sync::Arc; +use super::expressions::binary; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContextState; use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::explain::ExplainExec; +use crate::execution::physical_plan::expressions; use crate::execution::physical_plan::expressions::{ - Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum, + Avg, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum, }; use crate::execution::physical_plan::filter::FilterExec; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; @@ -312,16 +314,18 @@ impl 
DefaultPhysicalPlanner { Ok(Arc::new(Column::new(name))) } Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), - Expr::BinaryExpr { left, op, right } => Ok(Arc::new(BinaryExpr::new( - self.create_physical_expr(left, input_schema, ctx_state)?, - op.clone(), - self.create_physical_expr(right, input_schema, ctx_state)?, - ))), - Expr::Cast { expr, data_type } => Ok(Arc::new(CastExpr::try_new( - self.create_physical_expr(expr, input_schema, ctx_state)?, + Expr::BinaryExpr { left, op, right } => { + let lhs = + self.create_physical_expr(left, input_schema, ctx_state.clone())?; + let rhs = + self.create_physical_expr(right, input_schema, ctx_state.clone())?; + binary(lhs, op.clone(), rhs, input_schema) + } + Expr::Cast { expr, data_type } => expressions::cast( + self.create_physical_expr(expr, input_schema, ctx_state.clone())?, input_schema, data_type.clone(), - )?)), + ), Expr::ScalarFunction { name, args, @@ -426,3 +430,101 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { (Err(e), Err(_)) => Err(e), } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::execution::physical_plan::csv::CsvReadOptions; + use crate::logicalplan::{aggregate_expr, col, lit, LogicalPlanBuilder}; + use crate::{prelude::ExecutionConfig, test::arrow_testdata_path}; + use std::collections::HashMap; + + fn plan(logical_plan: &LogicalPlan) -> Result> { + let ctx_state = ExecutionContextState { + datasources: Box::new(HashMap::new()), + scalar_functions: Box::new(HashMap::new()), + config: ExecutionConfig::new(), + }; + + let planer = DefaultPhysicalPlanner {}; + planer.create_physical_plan(logical_plan, &ctx_state) + } + + #[test] + fn test_all_operators() -> Result<()> { + let testdata = arrow_testdata_path(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + + let options = CsvReadOptions::new().schema_infer_max_records(100); + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? 
+ // filter clause needs the type coercion rule applied + .filter(col("c7").lt(lit(5_u8)))? + .project(vec![col("c1"), col("c2")])? + .aggregate(vec![col("c1")], vec![aggregate_expr("SUM", col("c2"))])? + .sort(vec![col("c1").sort(true, true)])? + .limit(10)? + .build()?; + + let plan = plan(&logical_plan)?; + + // verify that the plan correctly casts u8 to i64 + let expected = "BinaryExpr { left: Column { name: \"c7\" }, op: Lt, right: CastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }"; + assert!(format!("{:?}", plan).contains(expected)); + + Ok(()) + } + + #[test] + fn test_with_csv_plan() -> Result<()> { + let testdata = arrow_testdata_path(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + + let options = CsvReadOptions::new().schema_infer_max_records(100); + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? + .filter(col("c7").lt(col("c12")))? + .build()?; + + let plan = plan(&logical_plan)?; + + // c12 is f64, c7 is u8 -> cast c7 to f64 + let expected = "predicate: BinaryExpr { left: CastExpr { expr: Column { name: \"c7\" }, cast_type: Float64 }, op: Lt, right: Column { name: \"c12\" } }"; + assert!(format!("{:?}", plan).contains(expected)); + Ok(()) + } + + #[test] + fn errors() -> Result<()> { + let testdata = arrow_testdata_path(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + let options = CsvReadOptions::new().schema_infer_max_records(100); + + let bool_expr = col("c1").eq(col("c1")); + let cases = vec![ + // utf8 < u32 + col("c1").lt(col("c2")), + // utf8 AND utf8 + col("c1").and(col("c1")), + // u8 AND u8 + col("c3").and(col("c3")), + // utf8 = u32 + col("c1").eq(col("c2")), + // utf8 = bool + col("c1").eq(bool_expr.clone()), + // u32 AND bool + col("c2").and(bool_expr), + // utf8 LIKE u32 + col("c1").like(col("c2")), + ]; + for case in cases { + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? 
+ .project(vec![case.clone()]); + if let Ok(_) = logical_plan { + return Err(ExecutionError::General(format!( + "Expression {:?} expected to error due to impossible coercion", + case + ))); + }; + } + Ok(()) + } +} diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index a56cb46a617..0c8289484c7 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -29,8 +29,10 @@ use crate::datasource::csv::{CsvFile, CsvReadOptions}; use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; use crate::error::{ExecutionError, Result}; -use crate::optimizer::utils; -use crate::sql::parser::FileType; +use crate::{ + execution::physical_plan::expressions::binary_operator_data_type, + sql::parser::FileType, +}; use arrow::record_batch::RecordBatch; /// Enumeration of supported function types (Scalar and Aggregate) @@ -389,19 +391,11 @@ impl Expr { ref left, ref right, ref op, - } => match op { - Operator::Not => Ok(DataType::Boolean), - Operator::Like | Operator::NotLike => Ok(DataType::Boolean), - Operator::Eq | Operator::NotEq => Ok(DataType::Boolean), - Operator::Lt | Operator::LtEq => Ok(DataType::Boolean), - Operator::Gt | Operator::GtEq => Ok(DataType::Boolean), - Operator::And | Operator::Or => Ok(DataType::Boolean), - _ => { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - utils::get_supertype(&left_type, &right_type) - } - }, + } => binary_operator_data_type( + &left.get_type(schema)?, + op, + &right.get_type(schema)?, + ), Expr::Sort { ref expr, .. 
} => expr.get_type(schema), Expr::Wildcard => Err(ExecutionError::General( "Wildcard expressions are not valid in a logical query plan".to_owned(), @@ -544,6 +538,16 @@ impl Expr { binary_expr(self.clone(), Operator::Modulus, other.clone()) } + /// like (string) another expression + pub fn like(&self, other: Expr) -> Expr { + binary_expr(self.clone(), Operator::Like, other.clone()) + } + + /// not like another expression + pub fn not_like(&self, other: Expr) -> Expr { + binary_expr(self.clone(), Operator::NotLike, other.clone()) + } + /// Alias pub fn alias(&self, name: &str) -> Expr { Expr::Alias(Box::new(self.clone()), name.to_owned()) diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 22f3ef020e6..0d7773cc309 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -26,9 +26,11 @@ use std::sync::Arc; use arrow::datatypes::Schema; use crate::error::{ExecutionError, Result}; -use crate::execution::physical_plan::udf::ScalarFunction; +use crate::execution::physical_plan::{ + expressions::numerical_coercion, udf::ScalarFunction, +}; use crate::logicalplan::Expr; -use crate::logicalplan::LogicalPlan; +use crate::logicalplan::{LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use utils::optimize_explain; @@ -61,16 +63,6 @@ impl TypeCoercionRule { // modify `expressions` by introducing casts when necessary match expr { - Expr::BinaryExpr { .. } => { - let left_type = expressions[0].get_type(schema)?; - let right_type = expressions[1].get_type(schema)?; - if left_type != right_type { - let super_type = utils::get_supertype(&left_type, &right_type)?; - - expressions[0] = expressions[0].cast_to(&super_type, schema)?; - expressions[1] = expressions[1].cast_to(&super_type, schema)?; - } - } Expr::ScalarFunction { name, .. 
} => { // cast the inputs of scalar functions to the appropriate type where possible match self.scalar_functions.get(name) { @@ -80,10 +72,19 @@ impl TypeCoercionRule { let actual_type = expressions[i].get_type(schema)?; let required_type = field.data_type(); if &actual_type != required_type { - let super_type = - utils::get_supertype(&actual_type, required_type)?; - expressions[i] = - expressions[i].cast_to(&super_type, schema)? + // attempt to coerce using numerical coercion + // todo: also try string coercion. + if let Ok(cast_to_type) = numerical_coercion( + &actual_type, + // assume that the function behaves like plus + // plus is not special here; the optimizer is just trying its best... + &Operator::Plus, + required_type, + ) { + expressions[i] = + expressions[i].cast_to(&cast_to_type, schema)? + }; + // not possible: do nothing and let the plan fail with a clear error message }; } } @@ -141,140 +142,3 @@ impl OptimizerRule for TypeCoercionRule { return "type_coercion"; } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::execution::context::ExecutionContext; - use crate::execution::physical_plan::csv::CsvReadOptions; - use crate::logicalplan::{aggregate_expr, col, lit, LogicalPlanBuilder, Operator}; - use crate::test::arrow_testdata_path; - use arrow::datatypes::{DataType, Field, Schema}; - - #[test] - fn test_all_operators() -> Result<()> { - let testdata = arrow_testdata_path(); - let path = format!("{}/csv/aggregate_test_100.csv", testdata); - - let options = CsvReadOptions::new().schema_infer_max_records(100); - let plan = LogicalPlanBuilder::scan_csv(&path, options, None)? - // filter clause needs the type coercion rule applied - .filter(col("c7").lt(lit(5_u8)))? - .project(vec![col("c1"), col("c2")])? - .aggregate(vec![col("c1")], vec![aggregate_expr("SUM", col("c2"))])? - .sort(vec![col("c1")])? - .limit(10)? 
- .build()?; - - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(&scalar_functions); - let plan = rule.optimize(&plan)?; - - // check that the filter had a cast added - let plan_str = format!("{:?}", plan); - println!("{}", plan_str); - let expected_plan_str = "Limit: 10 - Sort: #c1 - Aggregate: groupBy=[[#c1]], aggr=[[SUM(#c2)]] - Projection: #c1, #c2 - Filter: #c7 Lt CAST(UInt8(5) AS Int64)"; - assert!(plan_str.starts_with(expected_plan_str)); - - Ok(()) - } - - #[test] - fn test_with_csv_plan() -> Result<()> { - let testdata = arrow_testdata_path(); - let path = format!("{}/csv/aggregate_test_100.csv", testdata); - - let options = CsvReadOptions::new().schema_infer_max_records(100); - let plan = LogicalPlanBuilder::scan_csv(&path, options, None)? - .filter(col("c7").lt(col("c12")))? - .build()?; - - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(&scalar_functions); - let plan = rule.optimize(&plan)?; - - assert!(format!("{:?}", plan).starts_with("Filter: CAST(#c7 AS Float64) Lt #c12")); - - Ok(()) - } - - #[test] - fn test_add_i32_i64() { - binary_cast_test( - DataType::Int32, - DataType::Int64, - "CAST(#c0 AS Int64) Plus #c1", - ); - binary_cast_test( - DataType::Int64, - DataType::Int32, - "#c0 Plus CAST(#c1 AS Int64)", - ); - } - - #[test] - fn test_add_f32_f64() { - binary_cast_test( - DataType::Float32, - DataType::Float64, - "CAST(#c0 AS Float64) Plus #c1", - ); - binary_cast_test( - DataType::Float64, - DataType::Float32, - "#c0 Plus CAST(#c1 AS Float64)", - ); - } - - #[test] - fn test_add_i32_f32() { - binary_cast_test( - DataType::Int32, - DataType::Float32, - "CAST(#c0 AS Float32) Plus #c1", - ); - binary_cast_test( - DataType::Float32, - DataType::Int32, - "#c0 Plus CAST(#c1 AS Float32)", - ); - } - - #[test] - fn test_add_u32_i64() { - binary_cast_test( - DataType::UInt32, - DataType::Int64, - "CAST(#c0 AS Int64) Plus #c1", - ); - binary_cast_test( - DataType::Int64, - DataType::UInt32, - 
"#c0 Plus CAST(#c1 AS Int64)", - ); - } - - fn binary_cast_test(left_type: DataType, right_type: DataType, expected: &str) { - let schema = Schema::new(vec![ - Field::new("c0", left_type, true), - Field::new("c1", right_type, true), - ]); - - let expr = Expr::BinaryExpr { - left: Box::new(col("c0")), - op: Operator::Plus, - right: Box::new(col("c1")), - }; - - let ctx = ExecutionContext::new(); - let rule = TypeCoercionRule::new(ctx.scalar_functions().as_ref()); - - let expr2 = rule.rewrite_expr(&expr, &schema).unwrap(); - - assert_eq!(expected, format!("{:?}", expr2)); - } -} diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 695bb0ea4f3..b8e037f40ca 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -19,7 +19,7 @@ use std::collections::HashSet; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::Schema; use super::optimizer::OptimizerRule; use crate::error::{ExecutionError, Result}; @@ -69,121 +69,6 @@ pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result< } } -/// Given two datatypes, determine the supertype that both types can safely be cast to -pub fn get_supertype(l: &DataType, r: &DataType) -> Result { - match _get_supertype(l, r) { - Some(dt) => Ok(dt), - None => _get_supertype(r, l).ok_or_else(|| { - ExecutionError::InternalError(format!( - "Failed to determine supertype of {:?} and {:?}", - l, r - )) - }), - } -} - -/// Given two datatypes, determine the supertype that both types can safely be cast to -fn _get_supertype(l: &DataType, r: &DataType) -> Option { - use arrow::datatypes::DataType::*; - match (l, r) { - (UInt8, Int8) => Some(Int8), - (UInt8, Int16) => Some(Int16), - (UInt8, Int32) => Some(Int32), - (UInt8, Int64) => Some(Int64), - - (UInt16, Int16) => Some(Int16), - (UInt16, Int32) => Some(Int32), - (UInt16, Int64) => Some(Int64), - - (UInt32, Int32) => Some(Int32), - (UInt32, Int64) => Some(Int64), - - (UInt64, 
Int64) => Some(Int64), - - (Int8, UInt8) => Some(Int8), - - (Int16, UInt8) => Some(Int16), - (Int16, UInt16) => Some(Int16), - - (Int32, UInt8) => Some(Int32), - (Int32, UInt16) => Some(Int32), - (Int32, UInt32) => Some(Int32), - - (Int64, UInt8) => Some(Int64), - (Int64, UInt16) => Some(Int64), - (Int64, UInt32) => Some(Int64), - (Int64, UInt64) => Some(Int64), - - (UInt8, UInt8) => Some(UInt8), - (UInt8, UInt16) => Some(UInt16), - (UInt8, UInt32) => Some(UInt32), - (UInt8, UInt64) => Some(UInt64), - (UInt8, Float32) => Some(Float32), - (UInt8, Float64) => Some(Float64), - - (UInt16, UInt8) => Some(UInt16), - (UInt16, UInt16) => Some(UInt16), - (UInt16, UInt32) => Some(UInt32), - (UInt16, UInt64) => Some(UInt64), - (UInt16, Float32) => Some(Float32), - (UInt16, Float64) => Some(Float64), - - (UInt32, UInt8) => Some(UInt32), - (UInt32, UInt16) => Some(UInt32), - (UInt32, UInt32) => Some(UInt32), - (UInt32, UInt64) => Some(UInt64), - (UInt32, Float32) => Some(Float32), - (UInt32, Float64) => Some(Float64), - - (UInt64, UInt8) => Some(UInt64), - (UInt64, UInt16) => Some(UInt64), - (UInt64, UInt32) => Some(UInt64), - (UInt64, UInt64) => Some(UInt64), - (UInt64, Float32) => Some(Float32), - (UInt64, Float64) => Some(Float64), - - (Int8, Int8) => Some(Int8), - (Int8, Int16) => Some(Int16), - (Int8, Int32) => Some(Int32), - (Int8, Int64) => Some(Int64), - (Int8, Float32) => Some(Float32), - (Int8, Float64) => Some(Float64), - - (Int16, Int8) => Some(Int16), - (Int16, Int16) => Some(Int16), - (Int16, Int32) => Some(Int32), - (Int16, Int64) => Some(Int64), - (Int16, Float32) => Some(Float32), - (Int16, Float64) => Some(Float64), - - (Int32, Int8) => Some(Int32), - (Int32, Int16) => Some(Int32), - (Int32, Int32) => Some(Int32), - (Int32, Int64) => Some(Int64), - (Int32, Float32) => Some(Float32), - (Int32, Float64) => Some(Float64), - - (Int64, Int8) => Some(Int64), - (Int64, Int16) => Some(Int64), - (Int64, Int32) => Some(Int64), - (Int64, Int64) => Some(Int64), - (Int64, 
Float32) => Some(Float32), - (Int64, Float64) => Some(Float64), - - (Float32, Float32) => Some(Float32), - (Float32, Float64) => Some(Float64), - (Float64, Float32) => Some(Float64), - (Float64, Float64) => Some(Float64), - - (Utf8, _) => Some(Utf8), - (_, Utf8) => Some(Utf8), - - (Boolean, Boolean) => Some(Boolean), - - _ => None, - } -} - /// Create a `LogicalPlan::Explain` node by running `optimizer` on the /// input plan and capturing the resulting plan string pub fn optimize_explain( diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 942a781a3a2..e9e09b4e7ca 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -470,11 +470,7 @@ fn csv_explain_verbose() { assert!(actual.contains("logical_plan"), "Actual: '{}'", actual); assert!(actual.contains("physical_plan"), "Actual: '{}'", actual); assert!(actual.contains("type_coercion"), "Actual: '{}'", actual); - assert!( - actual.contains("CAST(#c2 AS Int64) Gt Int64(10)"), - "Actual: '{}'", - actual - ); + assert!(actual.contains("#c2 Gt Int64(10)"), "Actual: '{}'", actual); } fn aggr_test_schema() -> SchemaRef { @@ -626,6 +622,13 @@ fn result_str(results: &[RecordBatch]) -> Vec { str.push_str(&format!("{:?}", s)); } + DataType::Boolean => { + let array = + column.as_any().downcast_ref::().unwrap(); + let s = array.value(row_index); + + str.push_str(&format!("{:?}", s)); + } _ => str.push_str("???"), } } @@ -654,3 +657,26 @@ fn query_length() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[test] +fn csv_query_sum_cast() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx); + // c8 is an i32 column and c9 an i64 column, so the addition mixes two integer widths + let sql = "SELECT c8 + c9 FROM aggregate_test_100"; + // no assertion here: the result is discarded, so any check (presumably the logical/physical schema comparison - TODO confirm) must happen inside execute + execute(&mut ctx, sql); +} + +#[test] +fn like() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx); + let sql = "SELECT COUNT(c1) FROM aggregate_test_100 WHERE c13 
LIKE '%FB%'"; + // run the query and expect its newline-joined output to equal the count 1 (any schema check presumably lives inside execute - TODO confirm) + let actual = execute(&mut ctx, sql).join("\n"); + + let expected = "1".to_string(); + assert_eq!(expected, actual); + Ok(()) +}