diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e06405a1a2..7c49d94df6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -954,8 +954,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } None - case rem @ Remainder(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) @@ -987,23 +986,82 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None case EqualTo(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + // this is a workaround for handling -0.0 in double and float + // until https://github.com/apache/datafusion/issues/11108 is fixed + val leftZero = Literal.default(left.dataType) + val rightZero = Literal.default(right.dataType) + val negZeroLeft = UnaryMinus(leftZero) + val negZeroRight = UnaryMinus(rightZero) + + def buildEqualExpr( + leftExpr: Option[Expr], + rightExpr: Option[Expr]): Option[ExprOuterClass.Expr] = { + if (leftExpr.isDefined && rightExpr.isDefined) { + Some( + ExprOuterClass.Expr + .newBuilder() + .setEq( + ExprOuterClass.Equal + .newBuilder() + .setLeft(leftExpr.get) + .setRight(rightExpr.get)) + .build()) + } else { + withInfo(expr, left, right) + None + } + } + if ((left.dataType == DoubleType && + right.dataType == DoubleType) || + (left.dataType == FloatType && + right.dataType == FloatType)) { + (left, right) match { + case (`negZeroLeft`, `negZeroRight`) => + return buildEqualExpr( + exprToProtoInternal(Abs(left).child, inputs), 
exprToProtoInternal(Abs(right).child, inputs)) + case (`negZeroLeft`, _) => + return buildEqualExpr( + exprToProtoInternal(Abs(left).child, inputs), + exprToProtoInternal(right, inputs)) + case (_, `negZeroRight`) => + return buildEqualExpr( + exprToProtoInternal(left, inputs), + exprToProtoInternal(Abs(right).child, inputs)) + case _ => + val doubleNan = Literal(Double.NaN, DoubleType) + val floatNan = Literal(Float.NaN, FloatType) + + // Ensure neither left nor right is -0.0 or 0.0 or NaN + // also return none if one side is nullable and the other is not + if ((left.nullable && !right.nullable) && + (left != negZeroLeft && right != negZeroRight) && + (left != leftZero && right != rightZero) && + (left != doubleNan && right != doubleNan) && + (left != floatNan && right != floatNan)) { + withInfo(expr, left, right) + return None + } + buildEqualExpr( + exprToProtoInternal(left, inputs), + exprToProtoInternal(right, inputs)) + } + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Equal.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + val leftExpr = + if (left.dataType == DoubleType || left.dataType == FloatType) { + exprToProtoInternal(If(EqualTo(left, negZeroLeft), leftZero, left), inputs) + } else { + exprToProtoInternal(left, inputs) + } + val rightExpr = + if (right.dataType == DoubleType || right.dataType == FloatType) { + exprToProtoInternal(If(EqualTo(right, negZeroRight), rightZero, right), inputs) + } else { + exprToProtoInternal(right, inputs) + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setEq(builder) - .build()) - } else { - withInfo(expr, left, right) - None - } + buildEqualExpr(leftExpr, rightExpr) case Not(EqualTo(left, right)) => val leftExpr = exprToProtoInternal(left, inputs) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index c22c6b06af..0402477d02 100644 --- 
a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -850,9 +850,47 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("zero equality") { + withParquetTable( + Seq( + (-0.0, 0.0), + (0.0, -0.0), + (-0.0, -0.0), + (0.0, 0.0), + (1.0, 2.0), + (1.0, 1.0), + (1.0, 0.0), + (0.0, 1.0), + (-0.0, 1.0), + (1.0, -0.0), + (1.0, -1.0), + (-1.0, 1.0), + (-1.0, -0.0), + (-1.0, -1.0), + (-1.0, 0.0), + (0.0, -1.0)), + "t") { + checkSparkAnswerAndOperator("SELECT _1 == _2 FROM t") + } + } + + test("remainder") { + val query = "SELECT _1, _2, _1 % _2 FROM t" + withParquetTable(Seq((21840, -0.0), (21840, 5.0)), "t") { + checkSparkAnswerAndOperator(query) + } + + withParquetTable(Seq((Decimal(21840, 10, 0), Decimal(-0.0, 10, 0))), "t") { + checkSparkAnswerAndOperator(query) + } + + withParquetTable(Seq((21840.0f, -0.0f), (21840.0f, 5.0f)), "t") { + checkSparkAnswerAndOperator(query) + } + } + // https://github.com/apache/datafusion-comet/issues/666 ignore("abs Overflow ansi mode") { - def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { withParquetTable(data, "tbl") { checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match {