-
Notifications
You must be signed in to change notification settings - Fork 307
fix: modulo op with negative zero divisor produces Nan #585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
774acef
45b876b
90494da
d82ffbf
b658eac
3585f14
06febb3
447e327
20ac2ba
00e44ed
cd53284
81f4e62
9696aaf
14f493a
2878d71
79b0a4b
6e843da
8a95de9
790e369
683c57c
fd70aa7
080130b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -954,8 +954,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim | |
| } | ||
| None | ||
|
|
||
| case rem @ Remainder(left, right, _) | ||
| if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => | ||
| case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) => | ||
| val leftExpr = exprToProtoInternal(left, inputs) | ||
| val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) | ||
|
|
||
|
|
@@ -987,23 +986,82 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim | |
| None | ||
|
|
||
| case EqualTo(left, right) => | ||
| val leftExpr = exprToProtoInternal(left, inputs) | ||
| val rightExpr = exprToProtoInternal(right, inputs) | ||
| // this is a workaround for handling -0.0 in double and float | ||
| // untill https://github.com/apache/datafusion/issues/11108 is fixed | ||
| val leftZero = Literal.default(left.dataType) | ||
| val rightZero = Literal.default(right.dataType) | ||
| val negZeroLeft = UnaryMinus(leftZero) | ||
| val negZeroRight = UnaryMinus(rightZero) | ||
|
|
||
| def buildEqualExpr( | ||
| leftExpr: Option[Expr], | ||
| rightExpr: Option[Expr]): Option[ExprOuterClass.Expr] = { | ||
| if (leftExpr.isDefined && rightExpr.isDefined) { | ||
| Some( | ||
| ExprOuterClass.Expr | ||
| .newBuilder() | ||
| .setEq( | ||
| ExprOuterClass.Equal | ||
| .newBuilder() | ||
| .setLeft(leftExpr.get) | ||
| .setRight(rightExpr.get)) | ||
| .build()) | ||
| } else { | ||
| withInfo(expr, left, right) | ||
| None | ||
| } | ||
| } | ||
| if ((left.dataType == DoubleType && | ||
| right.dataType == DoubleType) || | ||
| (left.dataType == FloatType && | ||
| right.dataType == FloatType)) { | ||
| (left, right) match { | ||
| case (`negZeroLeft`, `negZeroRight`) => | ||
| return buildEqualExpr( | ||
| exprToProtoInternal(Abs(left).child, inputs), | ||
| exprToProtoInternal(Abs(right).child, inputs)) | ||
| case (`negZeroLeft`, _) => | ||
| return buildEqualExpr( | ||
| exprToProtoInternal(Abs(left).child, inputs), | ||
| exprToProtoInternal(right, inputs)) | ||
| case (_, `negZeroRight`) => | ||
| return buildEqualExpr( | ||
| exprToProtoInternal(left, inputs), | ||
| exprToProtoInternal(Abs(right).child, inputs)) | ||
| case _ => | ||
| val doubleNan = Literal(Double.NaN, DoubleType) | ||
| val floatNan = Literal(Float.NaN, FloatType) | ||
|
|
||
| // Ensure neither left nor right is -0.0 or 0.0 or NaN | ||
| // also return none if one side is nullable and the other is not | ||
| if ((left.nullable && !right.nullable) && | ||
|
Comment on lines
+1036
to
+1037
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the reason of the |
||
| (left != negZeroLeft && right != negZeroRight) && | ||
| (left != leftZero && right != rightZero) && | ||
| (left != doubleNan && right != doubleNan) && | ||
| (left != floatNan && right != floatNan)) { | ||
|
vaibhawvipul marked this conversation as resolved.
|
||
| withInfo(expr, left, right) | ||
| return None | ||
| } | ||
| buildEqualExpr( | ||
| exprToProtoInternal(left, inputs), | ||
| exprToProtoInternal(right, inputs)) | ||
| } | ||
| } | ||
|
|
||
| if (leftExpr.isDefined && rightExpr.isDefined) { | ||
| val builder = ExprOuterClass.Equal.newBuilder() | ||
| builder.setLeft(leftExpr.get) | ||
| builder.setRight(rightExpr.get) | ||
| val leftExpr = | ||
| if (left.dataType == DoubleType || left.dataType == FloatType) { | ||
| exprToProtoInternal(If(EqualTo(left, negZeroLeft), leftZero, left), inputs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm ??
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm still trying to understand how this recursion works... |
||
| } else { | ||
| exprToProtoInternal(left, inputs) | ||
| } | ||
| val rightExpr = | ||
| if (right.dataType == DoubleType || right.dataType == FloatType) { | ||
| exprToProtoInternal(If(EqualTo(right, negZeroRight), rightZero, right), inputs) | ||
| } else { | ||
| exprToProtoInternal(right, inputs) | ||
| } | ||
|
|
||
| Some( | ||
| ExprOuterClass.Expr | ||
| .newBuilder() | ||
| .setEq(builder) | ||
| .build()) | ||
| } else { | ||
| withInfo(expr, left, right) | ||
| None | ||
| } | ||
| buildEqualExpr(leftExpr, rightExpr) | ||
|
|
||
| case Not(EqualTo(left, right)) => | ||
| val leftExpr = exprToProtoInternal(left, inputs) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -850,9 +850,47 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { | |
| } | ||
| } | ||
|
|
||
| test("zero equality") { | ||
| withParquetTable( | ||
| Seq( | ||
| (-0.0, 0.0), | ||
| (0.0, -0.0), | ||
| (-0.0, -0.0), | ||
| (0.0, 0.0), | ||
| (1.0, 2.0), | ||
| (1.0, 1.0), | ||
| (1.0, 0.0), | ||
| (0.0, 1.0), | ||
| (-0.0, 1.0), | ||
| (1.0, -0.0), | ||
| (1.0, -1.0), | ||
| (-1.0, 1.0), | ||
| (-1.0, -0.0), | ||
| (-1.0, -1.0), | ||
| (-1.0, 0.0), | ||
| (0.0, -1.0)), | ||
| "t") { | ||
| checkSparkAnswerAndOperator("SELECT _1 == _2 FROM t") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may need to test
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, I will create a different PR for it. |
||
| } | ||
| } | ||
|
|
||
| test("remainder") { | ||
| val query = "SELECT _1, _2, _1 % _2 FROM t" | ||
| withParquetTable(Seq((21840, -0.0), (21840, 5.0)), "t") { | ||
| checkSparkAnswerAndOperator(query) | ||
| } | ||
|
|
||
| withParquetTable(Seq((Decimal(21840, 10, 0), Decimal(-0.0, 10, 0))), "t") { | ||
| checkSparkAnswerAndOperator(query) | ||
| } | ||
|
|
||
| withParquetTable(Seq((21840.0f, -0.0f), (21840.0f, 5.0f)), "t") { | ||
| checkSparkAnswerAndOperator(query) | ||
| } | ||
| } | ||
|
|
||
| // https://github.com/apache/datafusion-comet/issues/666 | ||
| ignore("abs Overflow ansi mode") { | ||
|
|
||
| def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { | ||
| withParquetTable(data, "tbl") { | ||
| checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mind explaining why we need
Abs(right).child? I thoughtAbs(right).child == right