diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 7fa7bfe905..6c46d1b0ae 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -238,62 +238,61 @@ impl PhysicalPlanner {
     ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
         match spark_expr.expr_struct.as_ref().unwrap() {
             ExprStruct::Add(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // TODO respect ANSI eval mode
                 // https://github.com/apache/datafusion-comet/issues/536
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
                     expr.right.as_ref().unwrap(),
                     expr.return_type.as_ref(),
                     DataFusionOperator::Plus,
                     input_schema,
+                    eval_mode,
                 )
             }
             ExprStruct::Subtract(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // TODO respect ANSI eval mode
                 // https://github.com/apache/datafusion-comet/issues/535
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
                     expr.right.as_ref().unwrap(),
                     expr.return_type.as_ref(),
                     DataFusionOperator::Minus,
                     input_schema,
+                    eval_mode,
                 )
             }
             ExprStruct::Multiply(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // TODO respect ANSI eval mode
                 // https://github.com/apache/datafusion-comet/issues/534
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
                     expr.right.as_ref().unwrap(),
                     expr.return_type.as_ref(),
                     DataFusionOperator::Multiply,
                     input_schema,
+                    eval_mode,
                 )
             }
             ExprStruct::Divide(expr) => {
-                // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
+                // TODO respect ANSI eval mode
                 // https://github.com/apache/datafusion-comet/issues/533
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
                     expr.right.as_ref().unwrap(),
                     expr.return_type.as_ref(),
                     DataFusionOperator::Divide,
                     input_schema,
+                    eval_mode,
                 )
             }
             ExprStruct::IntegralDivide(expr) => {
                 // TODO respect eval mode
-                // https://github.com/apache/datafusion-comet/issues/2021
                 // https://github.com/apache/datafusion-comet/issues/533
-                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr_with_options(
                     expr.left.as_ref().unwrap(),
                     expr.right.as_ref().unwrap(),
@@ -303,6 +302,7 @@ impl PhysicalPlanner {
                     BinaryExprOptions {
                         is_integral_div: true,
                     },
+                    eval_mode,
                 )
             }
             ExprStruct::Remainder(expr) => {
@@ -1004,6 +1004,7 @@ impl PhysicalPlanner {
         return_type: Option<&spark_expression::DataType>,
         op: DataFusionOperator,
         input_schema: SchemaRef,
+        eval_mode: EvalMode,
     ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
         self.create_binary_expr_with_options(
             left,
@@ -1012,9 +1013,11 @@ impl PhysicalPlanner {
             op,
             input_schema,
             BinaryExprOptions::default(),
+            eval_mode,
         )
     }
 
+    #[allow(clippy::too_many_arguments)]
     fn create_binary_expr_with_options(
         &self,
         left: &Expr,
@@ -1023,6 +1026,7 @@ impl PhysicalPlanner {
         op: DataFusionOperator,
         input_schema: SchemaRef,
         options: BinaryExprOptions,
+        eval_mode: EvalMode,
     ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
         let left = self.create_expr(left, Arc::clone(&input_schema))?;
         let right = self.create_expr(right, Arc::clone(&input_schema))?;
@@ -1087,7 +1091,34 @@ impl PhysicalPlanner {
                     Arc::new(Field::new(func_name, data_type, true)),
                 )))
             }
-            _ => Ok(Arc::new(BinaryExpr::new(left, op, right))),
+            _ => {
+                let data_type = return_type.map(to_arrow_datatype).unwrap();
+                if eval_mode == EvalMode::Try && data_type.is_integer() {
+                    let op_str = match op {
+                        DataFusionOperator::Plus => "checked_add",
+                        DataFusionOperator::Minus => "checked_sub",
+                        DataFusionOperator::Multiply => "checked_mul",
+                        DataFusionOperator::Divide => "checked_div",
+                        _ => {
+                            todo!("checked arithmetic is not implemented for operator {op:?}");
+                        }
+                    };
+                    let fun_expr = create_comet_physical_fun(
+                        op_str,
+                        data_type.clone(),
+                        &self.session_ctx.state(),
+                        None,
+                    )?;
+                    Ok(Arc::new(ScalarFunctionExpr::new(
+                        op_str,
+                        fun_expr,
+                        vec![left, right],
+                        Arc::new(Field::new(op_str, data_type, true)),
+                    )))
+                } else {
+                    Ok(Arc::new(BinaryExpr::new(left, op, right)))
+                }
+            }
         }
     }
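Only integer-typed results take the new path: `EvalMode::Try` combined with an integer result type swaps the stock DataFusion `BinaryExpr` for a `checked_*` scalar function. A minimal, self-contained sketch of that predicate, with a stand-in `EvalMode` enum (the real type lives in the crate):

```rust
use arrow::datatypes::DataType;

// Stand-in for the crate's EvalMode; only the Try variant matters here.
#[allow(dead_code)]
#[derive(PartialEq)]
enum EvalMode {
    Legacy,
    Ansi,
    Try,
}

// Mirrors the condition in create_binary_expr_with_options: only TRY mode
// with an integer result type is routed to the checked_* kernels.
fn uses_checked_kernel(eval_mode: &EvalMode, result_type: &DataType) -> bool {
    *eval_mode == EvalMode::Try && result_type.is_integer()
}

fn main() {
    assert!(uses_checked_kernel(&EvalMode::Try, &DataType::Int32));
    assert!(!uses_checked_kernel(&EvalMode::Try, &DataType::Float64)); // non-integer: plain BinaryExpr
    assert!(!uses_checked_kernel(&EvalMode::Legacy, &DataType::Int32)); // non-TRY: plain BinaryExpr
}
```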
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 961309b788..75f5689ad5 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::hash_funcs::*;
+use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
@@ -115,6 +116,18 @@ pub fn create_comet_physical_fun(
                 data_type
             )
         }
+        "checked_add" => {
+            make_comet_scalar_udf!("checked_add", checked_add, data_type)
+        }
+        "checked_sub" => {
+            make_comet_scalar_udf!("checked_sub", checked_sub, data_type)
+        }
+        "checked_mul" => {
+            make_comet_scalar_udf!("checked_mul", checked_mul, data_type)
+        }
+        "checked_div" => {
+            make_comet_scalar_udf!("checked_div", checked_div, data_type)
+        }
         "murmur3_hash" => {
             let func = Arc::new(spark_murmur3_hash);
             make_comet_scalar_udf!("murmur3_hash", func, without data_type)
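Each new name resolves by exact string match to one kernel. A dependency-free sketch of the same name-to-kernel dispatch, using plain `i64` kernels instead of the crate's `ColumnarValue` signatures:

```rust
// Each registered name maps to exactly one checked kernel; unknown names
// fail, mirroring how create_comet_physical_fun matches on the function name.
type Kernel = fn(i64, i64) -> Option<i64>;

fn lookup(name: &str) -> Option<Kernel> {
    match name {
        "checked_add" => Some(i64::checked_add),
        "checked_sub" => Some(i64::checked_sub),
        "checked_mul" => Some(i64::checked_mul),
        "checked_div" => Some(i64::checked_div),
        _ => None,
    }
}

fn main() {
    let add = lookup("checked_add").expect("registered above");
    assert_eq!(add(1, 2), Some(3));
    assert_eq!(add(i64::MAX, 1), None); // overflow becomes None, i.e. SQL NULL
    assert!(lookup("checked_pow").is_none()); // hypothetical name, not registered
}
```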
diff --git a/native/spark-expr/src/math_funcs/checked_arithmetic.rs b/native/spark-expr/src/math_funcs/checked_arithmetic.rs
new file mode 100644
index 0000000000..0312cdb0b0
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/checked_arithmetic.rs
@@ -0,0 +1,150 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, ArrowNativeTypeOp, PrimitiveArray, PrimitiveBuilder};
+use arrow::array::{ArrayRef, AsArray};
+
+use arrow::datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type};
+use datafusion::common::DataFusionError;
+use datafusion::physical_plan::ColumnarValue;
+use std::sync::Arc;
+
+pub fn try_arithmetic_kernel<T>(
+    left: &PrimitiveArray<T>,
+    right: &PrimitiveArray<T>,
+    op: &str,
+) -> Result<ArrayRef, DataFusionError>
+where
+    T: ArrowPrimitiveType,
+{
+    let len = left.len();
+    let mut builder = PrimitiveBuilder::<T>::with_capacity(len);
+    match op {
+        "checked_add" => {
+            for i in 0..len {
+                if left.is_null(i) || right.is_null(i) {
+                    builder.append_null();
+                } else {
+                    builder.append_option(left.value(i).add_checked(right.value(i)).ok());
+                }
+            }
+        }
+        "checked_sub" => {
+            for i in 0..len {
+                if left.is_null(i) || right.is_null(i) {
+                    builder.append_null();
+                } else {
+                    builder.append_option(left.value(i).sub_checked(right.value(i)).ok());
+                }
+            }
+        }
+        "checked_mul" => {
+            for i in 0..len {
+                if left.is_null(i) || right.is_null(i) {
+                    builder.append_null();
+                } else {
+                    builder.append_option(left.value(i).mul_checked(right.value(i)).ok());
+                }
+            }
+        }
+        "checked_div" => {
+            for i in 0..len {
+                if left.is_null(i) || right.is_null(i) {
+                    builder.append_null();
+                } else {
+                    builder.append_option(left.value(i).div_checked(right.value(i)).ok());
+                }
+            }
+        }
+        _ => {
+            return Err(DataFusionError::Internal(format!(
+                "Unsupported operation: {:?}",
+                op
+            )))
+        }
+    }
+
+    Ok(Arc::new(builder.finish()) as ArrayRef)
+}
+
+pub fn checked_add(
+    args: &[ColumnarValue],
+    data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+    checked_arithmetic_internal(args, data_type, "checked_add")
+}
+
+pub fn checked_sub(
+    args: &[ColumnarValue],
+    data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+    checked_arithmetic_internal(args, data_type, "checked_sub")
+}
+
+pub fn checked_mul(
+    args: &[ColumnarValue],
+    data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+    checked_arithmetic_internal(args, data_type, "checked_mul")
+}
+
+pub fn checked_div(
+    args: &[ColumnarValue],
+    data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+    checked_arithmetic_internal(args, data_type, "checked_div")
+}
+
+fn checked_arithmetic_internal(
+    args: &[ColumnarValue],
+    data_type: &DataType,
+    op: &str,
+) -> Result<ColumnarValue, DataFusionError> {
+    let left = &args[0];
+    let right = &args[1];
+
+    let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (left, right) {
+        (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
+        (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
+            (l.to_array_of_size(r.len())?, Arc::clone(r))
+        }
+        (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
+            (Arc::clone(l), r.to_array_of_size(l.len())?)
+        }
+        (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
+    };
+
+    // Only Int32 and Int64 results are handled here; the planner only routes
+    // integer-typed TRY arithmetic to these kernels.
+    let result_array = match data_type {
+        DataType::Int32 => try_arithmetic_kernel::<Int32Type>(
+            left_arr.as_primitive::<Int32Type>(),
+            right_arr.as_primitive::<Int32Type>(),
+            op,
+        ),
+        DataType::Int64 => try_arithmetic_kernel::<Int64Type>(
+            left_arr.as_primitive::<Int64Type>(),
+            right_arr.as_primitive::<Int64Type>(),
+            op,
+        ),
+        _ => Err(DataFusionError::Internal(format!(
+            "Unsupported data type: {:?}",
+            data_type
+        ))),
+    };
+
+    Ok(ColumnarValue::Array(result_array?))
+}
diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs
index e7e203de59..873b290ebd 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 mod ceil;
+pub(crate) mod checked_arithmetic;
 mod div;
 mod floor;
 pub(crate) mod hex;
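The module is `pub(crate)`, so callers outside `spark-expr` only reach these kernels through the registered UDF names. A hypothetical in-crate unit test (not part of the patch) that pins down the overflow-to-NULL contract of `checked_add`:

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;

    #[test]
    fn checked_add_turns_overflow_into_null() -> Result<(), DataFusionError> {
        // Row 0 overflows (i32::MAX + 1); row 1 is ordinary arithmetic.
        let left = ColumnarValue::Array(Arc::new(Int32Array::from(vec![i32::MAX, 1])));
        let right = ColumnarValue::Array(Arc::new(Int32Array::from(vec![1, 2])));
        let ColumnarValue::Array(out) = checked_add(&[left, right], &DataType::Int32)? else {
            unreachable!("checked_arithmetic_internal always returns an array");
        };
        let out = out.as_primitive::<Int32Type>();
        assert!(out.is_null(0)); // overflow -> NULL, matching Spark's try_add
        assert_eq!(out.value(1), 3);
        Ok(())
    }
}
```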
diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
index 35acf0f356..63160d6ef7 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -94,10 +94,6 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
     }
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
     createMathExpression(
       expr,
       expr.left,
@@ -119,10 +115,6 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
     }
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
     createMathExpression(
       expr,
       expr.left,
@@ -144,10 +136,6 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
     }
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
     createMathExpression(
       expr,
       expr.left,
@@ -169,15 +157,10 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
     // See https://github.com/apache/arrow-datafusion/pull/6792
     // For now, use NullIf to swap zeros with nulls.
     val rightExpr = nullIfWhenPrimitive(expr.right)
-
     if (!supportedDataType(expr.left.dataType)) {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
     }
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
     createMathExpression(
       expr,
       expr.left,
@@ -199,10 +182,6 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
     }
-    if (expr.evalMode == EvalMode.TRY) {
-      withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
-      return None
-    }
 
     // Precision is set to 19 (max precision for a numerical data type except DecimalType)
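Note that `CometDivide` still routes the divisor through `nullIfWhenPrimitive`, so a zero divisor reaches the native side as NULL; checked division would map it to NULL anyway. A standalone sketch of the `try_divide` semantics both layers converge on:

```rust
fn main() {
    // Divisor of zero: the Scala serde rewrites it to NULL via NullIf, and
    // a checked division returns None for it regardless.
    assert_eq!(10_i64.checked_div(0), None);
    // Overflow: i64::MIN / -1 has no i64 representation, so try_divide yields NULL.
    assert_eq!(i64::MIN.checked_div(-1), None);
    // The ordinary case is unaffected.
    assert_eq!(10_i64.checked_div(3), Some(3));
}
```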
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 60435cee7b..f2dcedb53a 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -315,11 +315,82 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("try_add") {
-    // TODO: we need to implement more comprehensive tests for all try_ arithmetic functions
-    // https://github.com/apache/datafusion-comet/issues/2021
-    val data = Seq((Integer.MAX_VALUE, 1))
+    val data = Seq((1, 1))
     withParquetTable(data, "tbl") {
-      checkSparkAnswer("SELECT try_add(_1, _2) FROM tbl")
+      checkSparkAnswerAndOperator(spark.sql("""
+        |SELECT
+        | try_add(2147483647, 1),
+        | try_add(-2147483648, -1),
+        | try_add(NULL, 5),
+        | try_add(5, NULL),
+        | try_add(9223372036854775807, 1),
+        | try_add(-9223372036854775808, -1)
+        |FROM tbl
+        |""".stripMargin))
     }
   }
 
+  test("try_subtract") {
+    val data = Seq((1, 1))
+    withParquetTable(data, "tbl") {
+      checkSparkAnswerAndOperator(spark.sql("""
+        |SELECT
+        | try_subtract(2147483647, -1),
+        | try_subtract(-2147483648, 1),
+        | try_subtract(NULL, 5),
+        | try_subtract(5, NULL),
+        | try_subtract(9223372036854775807, -1),
+        | try_subtract(-9223372036854775808, 1)
+        |FROM tbl
+        |""".stripMargin))
+    }
+  }
+
+  test("try_multiply") {
+    val data = Seq((1, 1))
+    withParquetTable(data, "tbl") {
+      checkSparkAnswerAndOperator(spark.sql("""
+        |SELECT
+        | try_multiply(1073741824, 4),
+        | try_multiply(-1073741824, 4),
+        | try_multiply(NULL, 5),
+        | try_multiply(5, NULL),
+        | try_multiply(3037000499, 3037000500),
+        | try_multiply(-3037000499, 3037000500)
+        |FROM tbl
+        |""".stripMargin))
+    }
+  }
+
+  test("try_divide") {
+    val data = Seq((15121991, 0))
+    withParquetTable(data, "tbl") {
+      checkSparkAnswerAndOperator("SELECT try_divide(_1, _2) FROM tbl")
+      checkSparkAnswerAndOperator("""
+        |SELECT
+        | try_divide(10, 0),
+        | try_divide(NULL, 5),
+        | try_divide(5, NULL),
+        | try_divide(-2147483648, -1),
+        | try_divide(-9223372036854775808, -1),
+        | try_divide(DECIMAL('9999999999999999999999999999'), 0.1)
+        |FROM tbl
+        |""".stripMargin)
+    }
+  }
+
+  test("try_integral_divide overflow cases") {
+    val data = Seq((15121991, 0))
+    withParquetTable(data, "tbl") {
+      checkSparkAnswerAndOperator("SELECT try_divide(_1, _2) FROM tbl")
+      checkSparkAnswerAndOperator("""
+        |SELECT
+        | try_divide(-128, -1),
+        | try_divide(-32768, -1),
+        | try_divide(-2147483648, -1),
+        | try_divide(-9223372036854775808, -1),
+        | try_divide(CAST(99999 AS DECIMAL(5,0)), CAST(0.0001 AS DECIMAL(5,4)))
+        |FROM tbl
+        |""".stripMargin)
+    }
+  }
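The literals in these tests sit exactly on the 32- and 64-bit integer boundaries; the expected NULLs can be sanity-checked against Rust's checked integer ops, which mirror the arrow kernels used natively:

```rust
fn main() {
    // try_add(2147483647, 1): the literals fit in 32 bits, and the sum does not.
    assert_eq!(i32::MAX.checked_add(1), None);
    // try_subtract(-9223372036854775808, 1): i64::MIN - 1 overflows i64.
    assert_eq!(i64::MIN.checked_sub(1), None);
    // try_multiply(1073741824, 4): 2^30 * 4 = 2^32 does not fit in i32.
    assert_eq!(1_073_741_824_i32.checked_mul(4), None);
    // -2147483648 / -1 is the lone overflowing case of 32-bit integer division.
    assert_eq!(i32::MIN.checked_div(-1), None);
    // Non-boundary arithmetic is unchanged; NULL inputs short-circuit earlier.
    assert_eq!(2_000_000_000_i64.checked_add(1), Some(2_000_000_001));
}
```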