From 2b094b6e8fb1b0b8ae8bc89782305ac44d172ec3 Mon Sep 17 00:00:00 2001 From: stanzhai Date: Tue, 4 Apr 2017 20:48:53 +0800 Subject: [PATCH 1/3] fix decimal floor/ceil precision bug --- .../spark/sql/catalyst/expressions/mathExpressions.scala | 6 ++++-- .../src/main/scala/org/apache/spark/sql/types/Decimal.scala | 6 ++++-- .../sql/catalyst/expressions/MathExpressionsSuite.scala | 4 ++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index dea5f85cb08cc..980a7149bf9ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -225,7 +225,8 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision - scale + 1, 0) + val boundedPrecision = if (precision < scale) 1 else precision - scale + 1 + DecimalType.bounded(boundedPrecision, 0) case _ => LongType } @@ -340,7 +341,8 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision - scale + 1, 0) + val boundedPrecision = if (precision < scale) 1 else precision - scale + 1 + DecimalType.bounded(boundedPrecision, 0) case _ => LongType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index e8f6884c025c2..b5403400b85b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -376,13 +376,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this def floor: Decimal = if (scale == 0) this else { - val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision + val boundedPrecision = if (precision < scale) 1 else precision - scale + 1 + val newPrecision = DecimalType.bounded(boundedPrecision, 0).precision toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) } def ceil: Decimal = if (scale == 0) this else { - val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision + val boundedPrecision = if (precision < scale) 1 else precision - scale + 1 + val newPrecision = DecimalType.bounded(boundedPrecision, 0).precision toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 6b5bfac94645c..8a118a81d5850 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -248,7 +248,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Ceil, (d: Double) => math.ceil(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) - testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1))) + testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.0001))) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) @@ -258,7 +258,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Floor, (d: Double) => math.floor(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) - testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1))) + testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.0001))) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) From 2d60230b8344b391c3edfeec7c19ad1717e93710 Mon Sep 17 00:00:00 2001 From: stanzhai Date: Tue, 4 Apr 2017 22:28:03 +0800 Subject: [PATCH 2/3] add test case --- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0dd9296a3f0ff..c26ce12ee287a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2606,4 +2606,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) } } + + test("floor(0.0001)") { + val df = Seq(0).toDF("a") + withTempView("tb") { + df.createOrReplaceTempView("tb") + checkAnswer(sql("SELECT 1 > 0.00001 FROM tb"), Row(true)) + checkAnswer(sql("SELECT floor(0.0001) FROM tb"), Row(0)) + checkAnswer(sql("SELECT ceil(0.0001) FROM tb"), Row(1)) + checkAnswer(sql("SELECT floor(0.00123) FROM tb"), Row(0)) + checkAnswer(sql("SELECT floor(0.00010) FROM tb"), Row(0)) + } + } } From 61058b6e69802312bda35cdaf04a5b2af7dcd827 Mon Sep 17 00:00:00 2001 From: stanzhai Date: Tue, 4 Apr 2017 23:02:54 +0800 Subject: [PATCH 3/3] update test case --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c26ce12ee287a..0b199e4d87d83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2607,7 +2607,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("floor(0.0001)") { + test("SPARK-20211: should be able to floor or ceil with a decimal when its precision < scale") { val df = Seq(0).toDF("a") withTempView("tb") { df.createOrReplaceTempView("tb")