From 653d047fcf3230c40ca6f93c8cf3a711d100bd0b Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 01:56:49 +0800 Subject: [PATCH 01/21] Add math function round --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 97 ++++++++++++++++++- .../catalyst/util/BigDecimalConverter.scala | 60 ++++++++++++ .../ExpressionTypeCheckingSuite.scala | 13 +++ .../expressions/MathFunctionsSuite.scala | 12 +++ .../execution/HiveCompatibilitySuite.scala | 2 +- 6 files changed, 183 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed69c42dcb825..471e8bd68b9cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -127,6 +127,7 @@ object FunctionRegistry { expression[Tanh]("tanh"), expression[ToDegrees]("degrees"), expression[ToRadians]("radians"), + expression[Round]("round"), // misc functions expression[Md5]("md5"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c31890e27fb54..ca31009c99419 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.BigDecimalConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -520,3 +523,95 @@ case class Logarithm(left: Expression, right: Expression) """ } } + +case class Round(children: Seq[Expression]) extends Expression { + + def nullable: Boolean = true + + def dataType: DataType = { + children(0).dataType match { + case StringType | BinaryType => DoubleType + case t => t + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size < 1 || children.size > 2) { + return TypeCheckFailure(s"ROUND require one or two arguments, got ${children.size}") + } + children(0).dataType match { + case _: NumericType | NullType | BinaryType | StringType => // satisfy requirement + case dt => + return TypeCheckFailure(s"Only numeric, string or binary data types" + + s" are allowed for ROUND function, got $dt") + } + if (children.size == 2) { + children(1) match { + case Literal(value, LongType) => + if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) { + return TypeCheckFailure("ROUND scale argument out of allowed range") + } + case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement + case child => + if (child.find { case _: AttributeReference => true; case _ => false } != None) { + return TypeCheckFailure("Only Integral Literal or Null Literal " + + s"are allowed for ROUND scale arguments, got ${child.dataType}") + } + } + } + TypeCheckSuccess + } + + def eval(input: InternalRow): Any = { + val evalE1 = children(0).eval(input) + if (evalE1 == null) { + return null + } + + var _scale: Int = 0 + if (children.size == 2) { + val evalE2 = children(1).eval(input) + if (evalE2 == null) { + return null + } else { + _scale = evalE2.asInstanceOf[Int] + } + } + + children(0).dataType match { + case decimalType: DecimalType => + // TODO: Support Decimal Round + case ByteType => + round(evalE1.asInstanceOf[Byte], _scale) + case ShortType => + round(evalE1.asInstanceOf[Short], _scale) + case IntegerType => + round(evalE1.asInstanceOf[Int], _scale) + case LongType => + round(evalE1.asInstanceOf[Long], _scale) + case FloatType => + round(evalE1.asInstanceOf[Float], _scale) + case DoubleType => + round(evalE1.asInstanceOf[Double], _scale) + case StringType => + round(evalE1.asInstanceOf[UTF8String].toString, _scale) + case BinaryType => + round(UTF8String.fromBytes(evalE1.asInstanceOf[Array[Byte]]).toString, _scale) + } + } + + private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { + input match { + case f: Float if (f.isNaN || f.isInfinite) => return input + case d: Double if (d.isNaN || d.isInfinite) => return input + case _ => + } + bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) + } + + private def round(input: String, scale: Int): Any = { + try round(input.toDouble, scale) catch { + case _ : NumberFormatException => null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala new file mode 100644 index 0000000000000..1320680925c80 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +trait BigDecimalConverter[T] { + def toBigDecimal(in: T) : BigDecimal + def fromBigDecimal(bd: BigDecimal) : T +} + +/** + * Helper type converters to work with BigDecimal + * from http://stackoverflow.com/a/30979266/1115193 + */ +object BigDecimalConverter { + + implicit object ByteConverter extends BigDecimalConverter[Byte] { + def toBigDecimal(in: Byte) = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal) = bd.toByte + } + + implicit object ShortConverter extends BigDecimalConverter[Short] { + def toBigDecimal(in: Short) = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal) = bd.toShort + } + + implicit object IntConverter extends BigDecimalConverter[Int] { + def toBigDecimal(in: Int) = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal) = bd.toInt + } + + implicit object LongConverter extends BigDecimalConverter[Long] { + def toBigDecimal(in: Long) = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal) = bd.toLong + } + + implicit object FloatConverter extends BigDecimalConverter[Float] { + def toBigDecimal(in: Float) = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal) = bd.toFloat + } + + implicit object DoubleConverter extends BigDecimalConverter[Double] { + def toBigDecimal(in: Double) = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal) = bd.toDouble + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8e0551b23eea6..fcefa8f891265 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -171,4 +171,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), "Odd position only allow foldable and not-null StringType expressions") } + + test("check types for ROUND") { + assertError(Round(Seq()), "ROUND require one or two arguments") + assertError(Round(Seq(Literal(null),'booleanField)), + "Only Integral Literal or Null Literal are allowed for ROUND scale argument") + assertError(Round(Seq(Literal(null), 'complexField)), + "Only Integral Literal or Null Literal are allowed for ROUND scale argument") + assertSuccess(Round(Seq(Literal(null), Literal(null)))) + assertError(Round(Seq('booleanField, 'intField)), + "Only numeric, string or binary data types are allowed for ROUND function") + assertError(Round(Seq(Literal(null), Literal(1L + Int.MaxValue))), + "ROUND scale argument out of allowed range") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 7ca9e30b2bcd5..c79b9ca0a3340 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -336,4 +336,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { null, create_row(null)) } + + test("round test") { + val piRounds = Seq( + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, + 3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593, 3.1415927, 3.14159265, 3.141592654, + 3.1415926536, 3.14159265359, 3.14159265359, 3.1415926535898, 3.14159265358979, + 3.141592653589793, 3.141592653589793) + (-16 to 16).zipWithIndex.foreach { + case (scale, i) => + checkEvaluation(Round(Seq(3.141592653589793, scale)), piRounds(i), EmptyRow) + } + } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c884c399281a8..65a6a5023ea62 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -919,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_repeat", "udf_rlike", "udf_round", - // "udf_round_3", TODO: FIX THIS failed due to cast exception + "udf_round_3", "udf_rpad", "udf_rtrim", "udf_second", From 7e163aebaad2f58c6cf376df8e1f908c27979dbd Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 02:55:15 +0800 Subject: [PATCH 02/21] style fix --- .../catalyst/util/BigDecimalConverter.scala | 28 +++++++++---------- .../ExpressionTypeCheckingSuite.scala | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala index 1320680925c80..5ce7758156ccb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.util trait BigDecimalConverter[T] { - def toBigDecimal(in: T) : BigDecimal - def fromBigDecimal(bd: BigDecimal) : T + def toBigDecimal(in: T): BigDecimal + def fromBigDecimal(bd: BigDecimal): T } /** @@ -29,32 +29,32 @@ trait BigDecimalConverter[T] { object BigDecimalConverter { implicit object ByteConverter extends BigDecimalConverter[Byte] { - def toBigDecimal(in: Byte) = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal) = bd.toByte + def toBigDecimal(in: Byte): BigDecimal = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal): Byte = bd.toByte } implicit object ShortConverter extends BigDecimalConverter[Short] { - def toBigDecimal(in: Short) = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal) = bd.toShort + def toBigDecimal(in: Short): BigDecimal = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal): Short = bd.toShort } implicit object IntConverter extends BigDecimalConverter[Int] { - def toBigDecimal(in: Int) = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal) = bd.toInt + def toBigDecimal(in: Int): BigDecimal = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal): Int = bd.toInt } implicit object LongConverter extends BigDecimalConverter[Long] { - def toBigDecimal(in: Long) = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal) = bd.toLong + def toBigDecimal(in: Long): BigDecimal = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal): Long = bd.toLong } implicit object FloatConverter extends BigDecimalConverter[Float] { - def toBigDecimal(in: Float) = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal) = bd.toFloat + def toBigDecimal(in: Float): BigDecimal = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal): Float = bd.toFloat } implicit object DoubleConverter extends BigDecimalConverter[Double] { - def toBigDecimal(in: Double) = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal) = bd.toDouble + def toBigDecimal(in: Double): BigDecimal = BigDecimal(in) + def fromBigDecimal(bd: BigDecimal): Double = bd.toDouble } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index fcefa8f891265..361a2e95907bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -174,7 +174,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for ROUND") { assertError(Round(Seq()), "ROUND require one or two arguments") - assertError(Round(Seq(Literal(null),'booleanField)), + assertError(Round(Seq(Literal(null), 'booleanField)), "Only Integral Literal or Null Literal are allowed for ROUND scale argument") assertError(Round(Seq(Literal(null), 'complexField)), "Only Integral Literal or Null Literal are allowed for ROUND scale argument") From 56db4bb6cf87fb642f7504ef8630ac147638840f Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 11:18:59 +0800 Subject: [PATCH 03/21] Add decimal support to Round --- .../spark/sql/catalyst/expressions/math.scala | 23 ++++++++----------- .../execution/HiveCompatibilitySuite.scala | 5 +--- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index ca31009c99419..db659b7e6f40e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -528,9 +528,13 @@ case class Round(children: Seq[Expression]) extends Expression { def nullable: Boolean = true - def dataType: DataType = { + private lazy val evalE2 = if (children.size == 2) children(1).eval(EmptyRow) else null + private lazy val _scale = if (evalE2 != null) evalE2.asInstanceOf[Int] else 0 + + override lazy val dataType: DataType = { children(0).dataType match { case StringType | BinaryType => DoubleType + case DecimalType.Fixed(p, s) => DecimalType(p, _scale) case t => t } } @@ -564,23 +568,14 @@ case class Round(children: Seq[Expression]) extends Expression { def eval(input: InternalRow): Any = { val evalE1 = children(0).eval(input) - if (evalE1 == null) { - return null - } - var _scale: Int = 0 - if (children.size == 2) { - val evalE2 = children(1).eval(input) - if (evalE2 == null) { - return null - } else { - _scale = evalE2.asInstanceOf[Int] - } - } + if (evalE1 == null) return null + if (children.size == 2 && evalE2 == null) return null children(0).dataType match { case decimalType: DecimalType => - // TODO: Support Decimal Round + val decimal = evalE1.asInstanceOf[Decimal] + if (decimal.changePrecision(decimal.precision, _scale)) decimal else null case ByteType => round(evalE1.asInstanceOf[Byte], _scale) case ShortType => diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 65a6a5023ea62..4ada64bc21966 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -221,9 +221,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_when", "udf_case", - // Needs constant object inspectors - "udf_round", - // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive // is src(key STRING, value STRING), and in the reflect.q, it failed in // Integer.valueOf, which expect the first argument passed as STRING type not INT. @@ -918,7 +915,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - "udf_round", + // "udf_round", turn this on after we figure out null vs nan vs infinity "udf_round_3", "udf_rpad", "udf_rtrim", From 7c83e13bd772e1dee6e94bf1b0456995264110cc Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 13:26:52 +0800 Subject: [PATCH 04/21] more tests on round --- .../expressions/MathFunctionsSuite.scala | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index c79b9ca0a3340..93935dec53039 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.math.BigDecimal.RoundingMode + import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite @@ -338,14 +340,31 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("round test") { - val piRounds = Seq( - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, - 3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593, 3.1415927, 3.14159265, 3.141592654, - 3.1415926536, 3.14159265359, 3.14159265359, 3.1415926535898, 3.14159265358979, - 3.141592653589793, 3.141592653589793) - (-16 to 16).zipWithIndex.foreach { - case (scale, i) => - checkEvaluation(Round(Seq(3.141592653589793, scale)), piRounds(i), EmptyRow) + val domain = -16 to 16 + val doublePi = math.Pi + val stringPi = "3.141592653589793" + val intPi = 314159265 + val bdPi = BigDecimal(31415926535897932L, 10) + + domain.foreach { scale => + checkEvaluation(Round(Seq(doublePi, scale)), + BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) + checkEvaluation(Round(Seq(stringPi, scale)), + BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) + checkEvaluation(Round(Seq(intPi, scale)), + BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow) + } + checkEvaluation(Round(Seq("invalid input")), null, EmptyRow) + + // round_scale > current_scale would result in precision increase + // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + val (validScales, nullScales) = domain.splitAt(27) + validScales.foreach { scale => + checkEvaluation(Round(Seq(bdPi, scale)), + Decimal(bdPi.setScale(scale, RoundingMode.HALF_UP)), EmptyRow) + } + nullScales.foreach { scale => + checkEvaluation(Round(Seq(bdPi, scale)), null, EmptyRow) } } } From 9be894efc4cbba8d0f72c316a23847a9da49f8af Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 14:05:44 +0800 Subject: [PATCH 05/21] add round functions in o.a.s.sql.functions --- .../org/apache/spark/sql/functions.scala | 24 +++++++++++++++++++ .../spark/sql/MathExpressionsSuite.scala | 10 ++++++++ 2 files changed, 34 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ffa52f62588dc..cbae58b303088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1385,6 +1385,30 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Returns the value of the `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = Round(Seq(e.expr)) + + /** + * Returns the value of `e` rounded to the value of `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Column): Column = Round(Seq(e.expr, scale.expr)) + + /** + * Returns the value of `e` rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = round(e, lit(scale)) + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 24bef21b999ea..f8bbc5a032083 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -198,6 +198,16 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(rint, math.rint) } + test("round") { + checkAnswer( + ctx.sql("SELECT round(-32768), round(1809242.3151111344, 9), round(1809242.3151111344, 9)"), + Seq((1, 2)).toDF().select( + round(lit(-32768)), + round(lit(1809242.3151111344), lit(9)), + round(lit(1809242.3151111344), 9)) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } From 6cd9a64a7866f2c794a9c010846ad8e88ca5724d Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 17:15:36 +0800 Subject: [PATCH 06/21] refactor Round's constructor --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 66 +++++++++---------- .../ExpressionTypeCheckingSuite.scala | 11 ++-- .../expressions/MathFunctionsSuite.scala | 12 ++-- .../org/apache/spark/sql/functions.scala | 16 +++-- 5 files changed, 57 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 471e8bd68b9cc..70970169590d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -116,6 +116,7 @@ object FunctionRegistry { expression[Pow]("power"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), + expression[Round]("round"), expression[ShiftLeft]("shiftleft"), expression[ShiftRight]("shiftright"), expression[ShiftRightUnsigned]("shiftrightunsigned"), @@ -127,7 +128,6 @@ object FunctionRegistry { expression[Tanh]("tanh"), expression[ToDegrees]("degrees"), expression[ToRadians]("radians"), - expression[Round]("round"), // misc functions expression[Md5]("md5"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index db659b7e6f40e..d24e564298d75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -524,15 +524,21 @@ case class Logarithm(left: Expression, right: Expression) } } -case class Round(children: Seq[Expression]) extends Expression { +case class Round(child: Expression, scale: Expression) extends Expression { + + def this(child: Expression) = { + this(child, Literal(0)) + } + + def children: Seq[Expression] = Seq(child, scale) def nullable: Boolean = true - private lazy val evalE2 = if (children.size == 2) children(1).eval(EmptyRow) else null - private lazy val _scale = if (evalE2 != null) evalE2.asInstanceOf[Int] else 0 + private lazy val scaleV = scale.asInstanceOf[Literal].value + private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 override lazy val dataType: DataType = { - children(0).dataType match { + child.dataType match { case StringType | BinaryType => DoubleType case DecimalType.Fixed(p, s) => DecimalType(p, _scale) case t => t @@ -540,58 +546,52 @@ case class Round(children: Seq[Expression]) extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - if (children.size < 1 || children.size > 2) { - return TypeCheckFailure(s"ROUND require one or two arguments, got ${children.size}") - } - children(0).dataType match { + child.dataType match { case _: NumericType | NullType | BinaryType | StringType => // satisfy requirement case dt => return TypeCheckFailure(s"Only numeric, string or binary data types" + s" are allowed for ROUND function, got $dt") } - if (children.size == 2) { - children(1) match { - case Literal(value, LongType) => - if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) { - return TypeCheckFailure("ROUND scale argument out of allowed range") - } - case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement - case child => - if (child.find { case _: AttributeReference => true; case _ => false } != None) { - return TypeCheckFailure("Only Integral Literal or Null Literal " + - s"are allowed for ROUND scale arguments, got ${child.dataType}") - } - } + scale match { + case Literal(value, LongType) => + if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) { + return TypeCheckFailure("ROUND scale argument out of allowed range") + } + case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement + case child => + if (child.find { case _: AttributeReference => true; case _ => false } != None) { + return TypeCheckFailure("Only Integral Literal or Null Literal " + + s"are allowed for ROUND scale arguments, got ${child.dataType}") + } } TypeCheckSuccess } def eval(input: InternalRow): Any = { - val evalE1 = children(0).eval(input) + val evalE = child.eval(input) - if (evalE1 == null) return null - if (children.size == 2 && evalE2 == null) return null + if (evalE == null || scaleV == null) return null children(0).dataType match { case decimalType: DecimalType => - val decimal = evalE1.asInstanceOf[Decimal] + val decimal = evalE.asInstanceOf[Decimal] if (decimal.changePrecision(decimal.precision, _scale)) decimal else null case ByteType => - round(evalE1.asInstanceOf[Byte], _scale) + round(evalE.asInstanceOf[Byte], _scale) case ShortType => - round(evalE1.asInstanceOf[Short], _scale) + round(evalE.asInstanceOf[Short], _scale) case IntegerType => - round(evalE1.asInstanceOf[Int], _scale) + round(evalE.asInstanceOf[Int], _scale) case LongType => - round(evalE1.asInstanceOf[Long], _scale) + round(evalE.asInstanceOf[Long], _scale) case FloatType => - round(evalE1.asInstanceOf[Float], _scale) + round(evalE.asInstanceOf[Float], _scale) case DoubleType => - round(evalE1.asInstanceOf[Double], _scale) + round(evalE.asInstanceOf[Double], _scale) case StringType => - round(evalE1.asInstanceOf[UTF8String].toString, _scale) + round(evalE.asInstanceOf[UTF8String].toString, _scale) case BinaryType => - round(UTF8String.fromBytes(evalE1.asInstanceOf[Array[Byte]]).toString, _scale) + round(UTF8String.fromBytes(evalE.asInstanceOf[Array[Byte]]).toString, _scale) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 361a2e95907bc..5467be022ff22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -173,15 +173,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for ROUND") { - assertError(Round(Seq()), "ROUND require one or two arguments") - assertError(Round(Seq(Literal(null), 'booleanField)), + assertError(Round(Literal(null), 'booleanField), "Only Integral Literal or Null Literal are allowed for ROUND scale argument") - assertError(Round(Seq(Literal(null), 'complexField)), + assertError(Round(Literal(null), 'complexField), "Only Integral Literal or Null Literal are allowed for ROUND scale argument") - assertSuccess(Round(Seq(Literal(null), Literal(null)))) - assertError(Round(Seq('booleanField, 'intField)), + assertSuccess(Round(Literal(null), Literal(null))) + assertError(Round('booleanField, 'intField), "Only numeric, string or binary data types are allowed for ROUND function") - assertError(Round(Seq(Literal(null), Literal(1L + Int.MaxValue))), + assertError(Round(Literal(null), Literal(1L + Int.MaxValue)), "ROUND scale argument out of allowed range") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 93935dec53039..b1e26bbe76161 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -347,24 +347,24 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val bdPi = BigDecimal(31415926535897932L, 10) domain.foreach { scale => - checkEvaluation(Round(Seq(doublePi, scale)), + checkEvaluation(Round(doublePi, scale), BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) - checkEvaluation(Round(Seq(stringPi, scale)), + checkEvaluation(Round(stringPi, scale), BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) - checkEvaluation(Round(Seq(intPi, scale)), + checkEvaluation(Round(intPi, scale), BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow) } - checkEvaluation(Round(Seq("invalid input")), null, EmptyRow) + checkEvaluation(new Round(Literal("invalid input")), null, EmptyRow) // round_scale > current_scale would result in precision increase // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null val (validScales, nullScales) = domain.splitAt(27) validScales.foreach { scale => - checkEvaluation(Round(Seq(bdPi, scale)), + checkEvaluation(Round(bdPi, scale), Decimal(bdPi.setScale(scale, RoundingMode.HALF_UP)), EmptyRow) } nullScales.foreach { scale => - checkEvaluation(Round(Seq(bdPi, scale)), null, EmptyRow) + checkEvaluation(Round(bdPi, scale), null, EmptyRow) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cbae58b303088..f6bd19bac61b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1386,20 +1386,20 @@ object functions { def rint(columnName: String): Column = rint(Column(columnName)) /** - * Returns the value of the `e` rounded to 0 decimal places. + * Returns the value of the column `e` rounded to 0 decimal places. * * @group math_funcs * @since 1.5.0 */ - def round(e: Column): Column = Round(Seq(e.expr)) + def round(e: Column): Column = round(e.expr, 0) /** - * Returns the value of `e` rounded to the value of `scale` decimal places. + * Returns the value of the given column `e` rounded to the value of `scale` decimal places. * * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Column): Column = Round(Seq(e.expr, scale.expr)) + def round(e: Column, scale: Column): Column = Round(e.expr, scale.expr) /** * Returns the value of `e` rounded to `scale` decimal places. @@ -1409,6 +1409,14 @@ object functions { */ def round(e: Column, scale: Int): Column = round(e, lit(scale)) + /** + * Returns the value of the given column rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. From 2077888ed6544fa64c83251714c803e3816b3435 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 23:21:14 +0800 Subject: [PATCH 07/21] codegen versioned eval --- .../spark/sql/catalyst/expressions/math.scala | 82 ++++++++++++++++++- .../expressions/MathFunctionsSuite.scala | 9 ++ 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index d24e564298d75..10460c0b2ff20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -572,8 +572,8 @@ case class Round(child: Expression, scale: Expression) extends Expression { if (evalE == null || scaleV == null) return null - children(0).dataType match { - case decimalType: DecimalType => + child.dataType match { + case _: DecimalType => val decimal = evalE.asInstanceOf[Decimal] if (decimal.changePrecision(decimal.precision, _scale)) decimal else null case ByteType => @@ -595,6 +595,84 @@ case class Round(child: Expression, scale: Expression) extends Expression { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val ce = child.gen(ctx) + + def integralRound(primitive: String): String = { + s""" + ${ev.primitive} = new java.math.BigDecimal(${primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)""" + } + + def fractionalRound(primitive: String): String = { + s""" + ${ev.primitive} = java.math.BigDecimal.valueOf(${primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)""" + } + + def check(primitive: String, function: String): String = { + s""" + if (Double.isNaN(${primitive}) || Double.isInfinite(${primitive})){ + ${ev.primitive} = ${primitive}; + } else { + ${fractionalRound(primitive)}.${function}; + }""" + } + + def convert(primitive: String): String = { + val dName = ctx.freshName("converter") + s""" + Double $dName = 0.0; + try { + $dName = Double.valueOf(${primitive}.toString()); + } catch (NumberFormatException e) { + ${ev.isNull} = true; + } + ${check(dName, "doubleValue()")} + """ + } + + def decimalRound(): String = { + s""" + if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.isNull} = true; + } + """ + } + + val roundCode = child.dataType match { + case NullType => ";" + case _: DecimalType => + decimalRound() + case ByteType => + integralRound(ce.primitive) + ".byteValue();" + case ShortType => + integralRound(ce.primitive) + ".shortValue();" + case IntegerType => + integralRound(ce.primitive) + ".intValue();" + case LongType => + integralRound(ce.primitive) + ".longValue();" + case FloatType => + check(ce.primitive, "floatValue()") + case DoubleType => + check(ce.primitive, "doubleValue()") + case StringType => + convert(ce.primitive) + case BinaryType => + convert(s"${ctx.stringType}.fromBytes(${ce.primitive})") + } + + ce.code + s""" + boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${roundCode} + } + """ + } + private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { input match { case f: Float if (f.isNaN || f.isInfinite) => return input diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index b1e26bbe76161..9d95ef5cae35d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -343,7 +343,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val domain = -16 to 16 val doublePi = math.Pi val stringPi = "3.141592653589793" + val arrayPi: Array[Byte] = stringPi.toCharArray.map(_.toByte) + val shortPi: Short = 31415 val intPi = 314159265 + val longPi = 31415926535897932L val bdPi = BigDecimal(31415926535897932L, 10) domain.foreach { scale => @@ -351,8 +354,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) checkEvaluation(Round(stringPi, scale), BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) + checkEvaluation(Round(arrayPi, scale), + BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) + checkEvaluation(Round(shortPi, scale), + BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow) checkEvaluation(Round(intPi, scale), BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow) + checkEvaluation(Round(longPi, scale), + BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow) } checkEvaluation(new Round(Literal("invalid input")), null, EmptyRow) From 5486b2d5c445d1bbbbe1fd643ddd318f470266ae Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 23:40:30 +0800 Subject: [PATCH 08/21] DataFrame API modification --- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 +++--- .../scala/org/apache/spark/sql/MathExpressionsSuite.scala | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f6bd19bac61b2..694cf3b39b09d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1394,12 +1394,12 @@ object functions { def round(e: Column): Column = round(e.expr, 0) /** - * Returns the value of the given column `e` rounded to the value of `scale` decimal places. + * Returns the value of the given column rounded to 0 decimal places. * * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Column): Column = Round(e.expr, scale.expr) + def round(columnName: String): Column = round(Column(columnName), 0) /** * Returns the value of `e` rounded to `scale` decimal places. @@ -1407,7 +1407,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Int): Column = round(e, lit(scale)) + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) /** * Returns the value of the given column rounded to `scale` decimal places. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index f8bbc5a032083..8ccfdd5147680 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -200,10 +200,9 @@ class MathExpressionsSuite extends QueryTest { test("round") { checkAnswer( - ctx.sql("SELECT round(-32768), round(1809242.3151111344, 9), round(1809242.3151111344, 9)"), + ctx.sql("SELECT round(-32768), round(1809242.3151111344, 9)"), Seq((1, 2)).toDF().select( round(lit(-32768)), - round(lit(1809242.3151111344), lit(9)), round(lit(1809242.3151111344), 9)) ) } From 1b87540358f9195bf4a43ab3ade309bd43357f6c Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 24 Jun 2015 15:22:59 +0800 Subject: [PATCH 09/21] modify checkInputDataTypes using foldable --- .../spark/sql/catalyst/expressions/math.scala | 34 +++++++++---------- .../expressions/MathFunctionsSuite.scala | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 10460c0b2ff20..6f4db69d9e4f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -558,8 +558,8 @@ case class Round(child: Expression, scale: Expression) extends Expression { return TypeCheckFailure("ROUND scale argument out of allowed range") } case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement - case child => - if (child.find { case _: AttributeReference => true; case _ => false } != None) { + case _ => + if (!scale.foldable) { return TypeCheckFailure("Only Integral Literal or Null Literal " + s"are allowed for ROUND scale arguments, got ${child.dataType}") } @@ -595,6 +595,21 @@ case class Round(child: Expression, scale: Expression) extends Expression { } } + private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { + input match { + case f: Float if (f.isNaN || f.isInfinite) => return input + case d: Double if (d.isNaN || d.isInfinite) => return input + case _ => + } + bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) + } + + private def round(input: String, scale: Int): Any = { + try round(input.toDouble, scale) catch { + case _ : NumberFormatException => null + } + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val ce = child.gen(ctx) @@ -672,19 +687,4 @@ case class Round(child: Expression, scale: Expression) extends Expression { } """ } - - private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { - input match { - case f: Float if (f.isNaN || f.isInfinite) => return input - case d: Double if (d.isNaN || d.isInfinite) => return input - case _ => - } - bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) - } - - private def round(input: String, scale: Int): Any = { - try round(input.toDouble, scale) catch { - case _ : NumberFormatException => null - } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 9d95ef5cae35d..477ae969240e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -339,7 +339,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { create_row(null)) } - test("round test") { + test("round") { val domain = -16 to 16 val doublePi = math.Pi val stringPi = "3.141592653589793" From e6f44c4c862ce4eb335f8c3653d032b853be9750 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 24 Jun 2015 18:45:37 +0800 Subject: [PATCH 10/21] refactor eval and genCode --- .../spark/sql/catalyst/expressions/math.scala | 98 +++++++++---------- .../ExpressionTypeCheckingSuite.scala | 4 +- 2 files changed, 50 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 6f4db69d9e4f6..70af971440908 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -404,11 +404,7 @@ case class Atan2(left: Expression, right: Expression) case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } @@ -530,20 +526,20 @@ case class Round(child: Expression, scale: Expression) extends Expression { this(child, Literal(0)) } - def children: Seq[Expression] = Seq(child, scale) + override def children: Seq[Expression] = Seq(child, scale) + + override def nullable: Boolean = true - def nullable: Boolean = true + override def foldable: Boolean = child.foldable - private lazy val scaleV = scale.asInstanceOf[Literal].value + private lazy val scaleV = scale.eval(EmptyRow) private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 - override lazy val dataType: DataType = { - child.dataType match { + override lazy val dataType: DataType = child.dataType match { case StringType | BinaryType => DoubleType case DecimalType.Fixed(p, s) => DecimalType(p, _scale) case t => t } - } override def checkInputDataTypes(): TypeCheckResult = { child.dataType match { @@ -557,41 +553,42 @@ case class Round(child: Expression, scale: Expression) extends Expression { if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) { return TypeCheckFailure("ROUND scale argument out of allowed range") } - case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement case _ => - if (!scale.foldable) { - return TypeCheckFailure("Only Integral Literal or Null Literal " + - s"are allowed for ROUND scale arguments, got ${child.dataType}") + if ((scale.dataType.isInstanceOf[IntegralType] || scale.dataType == NullType) && + scale.foldable) { + // TODO: foldable LongType is not checked for out of range + // satisfy requirement + } else { + return TypeCheckFailure("Only Integral or Null foldable Expression " + + s"is allowed for ROUND scale arguments, got ${child.dataType}") } } TypeCheckSuccess } - def eval(input: InternalRow): Any = { - val evalE = child.eval(input) + private lazy val rounding: (Any) => (Any) = roundGen(child.dataType) - if (evalE == null || scaleV == null) return null - - child.dataType match { + def roundGen(dt: DataType)(x: Any): Any = { + dt match { case _: DecimalType => - val decimal = evalE.asInstanceOf[Decimal] + val decimal = x.asInstanceOf[Decimal] if (decimal.changePrecision(decimal.precision, _scale)) decimal else null case ByteType => - round(evalE.asInstanceOf[Byte], _scale) + round(x.asInstanceOf[Byte], _scale) case ShortType => - round(evalE.asInstanceOf[Short], _scale) + round(x.asInstanceOf[Short], _scale) case IntegerType => - round(evalE.asInstanceOf[Int], _scale) + round(x.asInstanceOf[Int], _scale) case LongType => - round(evalE.asInstanceOf[Long], _scale) + round(x.asInstanceOf[Long], _scale) case FloatType => - round(evalE.asInstanceOf[Float], _scale) + round(x.asInstanceOf[Float], _scale) case DoubleType => - round(evalE.asInstanceOf[Double], _scale) + round(x.asInstanceOf[Double], _scale) case StringType => - round(evalE.asInstanceOf[UTF8String].toString, _scale) + round(x.asInstanceOf[UTF8String].toString, _scale) case BinaryType => - round(UTF8String.fromBytes(evalE.asInstanceOf[Array[Byte]]).toString, _scale) + round(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale) } } @@ -606,35 +603,36 @@ case class Round(child: Expression, scale: Expression) extends Expression { private def round(input: String, scale: Int): Any = { try round(input.toDouble, scale) catch { - case _ : NumberFormatException => null + case _: NumberFormatException => null } } + def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null || scaleV == null) return null + rounding(evalE) + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val ce = child.gen(ctx) - def integralRound(primitive: String): String = { - s""" - ${ev.primitive} = new java.math.BigDecimal(${primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)""" - } - - def fractionalRound(primitive: String): String = { + def round(primitive: String, integral: Boolean): String = { + val (p1, p2) = if (integral) ("new", "") else ("", ".valueOf") s""" - ${ev.primitive} = java.math.BigDecimal.valueOf(${primitive}). + ${ev.primitive} = $p1 java.math.BigDecimal$p2(${primitive}). setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)""" } - def check(primitive: String, function: String): String = { + def fractionalCheck(primitive: String, function: String): String = { s""" if (Double.isNaN(${primitive}) || Double.isInfinite(${primitive})){ ${ev.primitive} = ${primitive}; } else { - ${fractionalRound(primitive)}.${function}; + ${round(primitive, false)}.${function}; }""" } - def convert(primitive: String): String = { + def stringLikeConvert(primitive: String): String = { val dName = ctx.freshName("converter") s""" Double $dName = 0.0; @@ -643,7 +641,7 @@ case class Round(child: Expression, scale: Expression) extends Expression { } catch (NumberFormatException e) { ${ev.isNull} = true; } - ${check(dName, "doubleValue()")} + ${fractionalCheck(dName, "doubleValue()")} """ } @@ -662,21 +660,21 @@ case class Round(child: Expression, scale: Expression) extends Expression { case _: DecimalType => decimalRound() case ByteType => - integralRound(ce.primitive) + ".byteValue();" + round(ce.primitive, true) + ".byteValue();" case ShortType => - integralRound(ce.primitive) + ".shortValue();" + round(ce.primitive, true) + ".shortValue();" case IntegerType => - integralRound(ce.primitive) + ".intValue();" + round(ce.primitive, true) + ".intValue();" case LongType => - integralRound(ce.primitive) + ".longValue();" + round(ce.primitive, true) + ".longValue();" case FloatType => - check(ce.primitive, "floatValue()") + fractionalCheck(ce.primitive, "floatValue()") case DoubleType => - check(ce.primitive, "doubleValue()") + fractionalCheck(ce.primitive, "doubleValue()") case StringType => - convert(ce.primitive) + stringLikeConvert(ce.primitive) case BinaryType => - convert(s"${ctx.stringType}.fromBytes(${ce.primitive})") + stringLikeConvert(s"${ctx.stringType}.fromBytes(${ce.primitive})") } ce.code + s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 5467be022ff22..6bae906ee9f57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -174,9 +174,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for ROUND") { assertError(Round(Literal(null), 'booleanField), - "Only Integral Literal or Null Literal are allowed for ROUND scale argument") + "Only Integral or Null foldable Expression is allowed for ROUND scale argument") assertError(Round(Literal(null), 'complexField), - "Only Integral Literal or Null Literal are allowed for ROUND scale argument") + "Only Integral or Null foldable Expression is allowed for ROUND scale argument") assertSuccess(Round(Literal(null), Literal(null))) assertError(Round('booleanField, 'intField), "Only numeric, string or binary data types are allowed for ROUND function") From 9bd6930c928d082faab8ab463ffbe66024242d11 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 24 Jun 2015 20:54:40 +0800 Subject: [PATCH 11/21] revert accidental change --- .../org/apache/spark/sql/catalyst/expressions/math.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 70af971440908..c9e8b72cdb486 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -404,7 +404,11 @@ case class Atan2(left: Expression, right: Expression) case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ } } From b0bff7950969f575ae3756177b5d7acab419042f Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 3 Jul 2015 10:49:30 +0800 Subject: [PATCH 12/21] make round's inner method's name more meaningful --- .../spark/sql/catalyst/expressions/math.scala | 45 ++++++++++--------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c9e8b72cdb486..92d8118c67252 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -536,9 +536,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { override def foldable: Boolean = child.foldable - private lazy val scaleV = scale.eval(EmptyRow) - private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 - override lazy val dataType: DataType = child.dataType match { case StringType | BinaryType => DoubleType case DecimalType.Fixed(p, s) => DecimalType(p, _scale) @@ -570,33 +567,43 @@ case class Round(child: Expression, scale: Expression) extends Expression { TypeCheckSuccess } - private lazy val rounding: (Any) => (Any) = roundGen(child.dataType) + private lazy val scaleV = scale.eval(EmptyRow) + private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null || scaleV == null) return null + round(evalE) + } + + private lazy val round: (Any) => (Any) = typedRound(child.dataType) - def roundGen(dt: DataType)(x: Any): Any = { + // Using dataType info to find an appropriate round method + private def typedRound(dt: DataType)(x: Any): Any = { dt match { case _: DecimalType => val decimal = x.asInstanceOf[Decimal] if (decimal.changePrecision(decimal.precision, _scale)) decimal else null case ByteType => - round(x.asInstanceOf[Byte], _scale) + numericRound(x.asInstanceOf[Byte], _scale) case ShortType => - round(x.asInstanceOf[Short], _scale) + numericRound(x.asInstanceOf[Short], _scale) case IntegerType => - round(x.asInstanceOf[Int], _scale) + numericRound(x.asInstanceOf[Int], _scale) case LongType => - round(x.asInstanceOf[Long], _scale) + numericRound(x.asInstanceOf[Long], _scale) case FloatType => - round(x.asInstanceOf[Float], _scale) + numericRound(x.asInstanceOf[Float], _scale) case DoubleType => - round(x.asInstanceOf[Double], _scale) + numericRound(x.asInstanceOf[Double], _scale) case StringType => - round(x.asInstanceOf[UTF8String].toString, _scale) + stringLikeRound(x.asInstanceOf[UTF8String].toString, _scale) case BinaryType => - round(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale) + stringLikeRound(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale) } } - private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { + private def numericRound[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { input match { case f: Float if (f.isNaN || f.isInfinite) => return input case d: Double if (d.isNaN || d.isInfinite) => return input @@ -605,18 +612,12 @@ case class Round(child: Expression, scale: Expression) extends Expression { bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) } - private def round(input: String, scale: Int): Any = { - try round(input.toDouble, scale) catch { + private def stringLikeRound(input: String, scale: Int): Any = { + try numericRound(input.toDouble, scale) catch { case _: NumberFormatException => null } } - def eval(input: InternalRow): Any = { - val evalE = child.eval(input) - if (evalE == null || scaleV == null) return null - rounding(evalE) - } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val ce = child.gen(ctx) From c3b9839b63affa05e3549d7e8cdb6950e9abb0ba Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 3 Jul 2015 23:48:46 +0800 Subject: [PATCH 13/21] rely on implict cast to handle string input --- .../spark/sql/catalyst/expressions/math.scala | 44 ++++--------------- .../ExpressionTypeCheckingSuite.scala | 19 +++++--- .../expressions/MathFunctionsSuite.scala | 7 --- 3 files changed, 21 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 92d8118c67252..f858650df410d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -524,7 +524,7 @@ case class Logarithm(left: Expression, right: Expression) } } -case class Round(child: Expression, scale: Expression) extends Expression { +case class Round(child: Expression, scale: Expression) extends Expression with ExpectsInputTypes { def this(child: Expression) = { this(child, Literal(0)) @@ -537,17 +537,17 @@ case class Round(child: Expression, scale: Expression) extends Expression { override def foldable: Boolean = child.foldable override lazy val dataType: DataType = child.dataType match { - case StringType | BinaryType => DoubleType case DecimalType.Fixed(p, s) => DecimalType(p, _scale) case t => t } + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegralType) + override def checkInputDataTypes(): TypeCheckResult = { child.dataType match { - case _: NumericType | NullType | BinaryType | StringType => // satisfy requirement + case _: NumericType => // satisfy requirement case dt => - return TypeCheckFailure(s"Only numeric, string or binary data types" + - s" are allowed for ROUND function, got $dt") + return TypeCheckFailure(s"Only numeric type is allowed for ROUND function, got $dt") } scale match { case Literal(value, LongType) => @@ -555,12 +555,11 @@ case class Round(child: Expression, scale: Expression) extends Expression { return TypeCheckFailure("ROUND scale argument out of allowed range") } case _ => - if ((scale.dataType.isInstanceOf[IntegralType] || scale.dataType == NullType) && - scale.foldable) { - // TODO: foldable LongType is not checked for out of range + if (scale.dataType.isInstanceOf[IntegralType] && scale.foldable) { + // TODO: How to check out of range for foldable LongType Expression // satisfy requirement } else { - return TypeCheckFailure("Only Integral or Null foldable Expression " + + return TypeCheckFailure("Only foldable Integral Expression " + s"is allowed for ROUND scale arguments, got ${child.dataType}") } } @@ -596,10 +595,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { numericRound(x.asInstanceOf[Float], _scale) case DoubleType => numericRound(x.asInstanceOf[Double], _scale) - case StringType => - stringLikeRound(x.asInstanceOf[UTF8String].toString, _scale) - case BinaryType => - stringLikeRound(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale) } } @@ -612,12 +607,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) } - private def stringLikeRound(input: String, scale: Int): Any = { - try numericRound(input.toDouble, scale) catch { - case _: NumberFormatException => null - } - } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val ce = child.gen(ctx) @@ -637,19 +626,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { }""" } - def stringLikeConvert(primitive: String): String = { - val dName = ctx.freshName("converter") - s""" - Double $dName = 0.0; - try { - $dName = Double.valueOf(${primitive}.toString()); - } catch (NumberFormatException e) { - ${ev.isNull} = true; - } - ${fractionalCheck(dName, "doubleValue()")} - """ - } - def decimalRound(): String = { s""" if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { @@ -676,10 +652,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { fractionalCheck(ce.primitive, "floatValue()") case DoubleType => fractionalCheck(ce.primitive, "doubleValue()") - case StringType => - stringLikeConvert(ce.primitive) - case BinaryType => - stringLikeConvert(s"${ctx.stringType}.fromBytes(${ce.primitive})") } ce.code + s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 6bae906ee9f57..8b596ff2526d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") } + def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains(errorMessage)) + } + test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") assertError(Abs('stringField), "function abs accepts numeric type") @@ -173,14 +180,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for ROUND") { - assertError(Round(Literal(null), 'booleanField), - "Only Integral or Null foldable Expression is allowed for ROUND scale argument") - assertError(Round(Literal(null), 'complexField), - "Only Integral or Null foldable Expression is allowed for ROUND scale argument") + assertErrorWithImplicitCast(Round(Literal(null), 'booleanField), + "Only foldable Integral Expression is allowed for ROUND scale arguments") + assertErrorWithImplicitCast(Round(Literal(null), 'complexField), + "Only foldable Integral Expression is allowed for ROUND scale arguments") assertSuccess(Round(Literal(null), Literal(null))) assertError(Round('booleanField, 'intField), - "Only numeric, string or binary data types are allowed for ROUND function") - assertError(Round(Literal(null), Literal(1L + Int.MaxValue)), + "Only numeric type is allowed for ROUND function") + assertErrorWithImplicitCast(Round(Literal(null), Literal(1L + Int.MaxValue)), "ROUND scale argument out of allowed range") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 477ae969240e9..7aa924c6d4584 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -342,8 +342,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("round") { val domain = -16 to 16 val doublePi = math.Pi - val stringPi = "3.141592653589793" - val arrayPi: Array[Byte] = stringPi.toCharArray.map(_.toByte) val shortPi: Short = 31415 val intPi = 314159265 val longPi = 31415926535897932L @@ -352,10 +350,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { domain.foreach { scale => checkEvaluation(Round(doublePi, scale), BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) - checkEvaluation(Round(stringPi, scale), - BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) - checkEvaluation(Round(arrayPi, scale), - BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) checkEvaluation(Round(shortPi, scale), BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow) checkEvaluation(Round(intPi, scale), @@ -363,7 +357,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Round(longPi, scale), BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow) } - checkEvaluation(new Round(Literal("invalid input")), null, EmptyRow) // round_scale > current_scale would result in precision increase // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null From d10be4aa883e435cddeb909b4dcc7d1bf28f2d4b Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 4 Jul 2015 00:31:31 +0800 Subject: [PATCH 14/21] use TypeCollection to specify wanted input and implicit cast --- .../org/apache/spark/sql/catalyst/expressions/math.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index f858650df410d..cee440773592a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -541,7 +541,10 @@ case class Round(child: Expression, scale: Expression) extends Expression with E case t => t } - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegralType) + override def inputTypes: Seq[AbstractDataType] = Seq( + //rely on precedence to implicit cast String into Double + TypeCollection(DoubleType, FloatType, LongType, IntegerType, ShortType, ByteType), + TypeCollection(LongType, IntegerType, ShortType, ByteType)) override def checkInputDataTypes(): TypeCheckResult = { child.dataType match { From 9555e35f58bbeab84edecd3190be514e268d11c3 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 4 Jul 2015 00:45:41 +0800 Subject: [PATCH 15/21] tiny style fix --- .../scala/org/apache/spark/sql/catalyst/expressions/math.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index cee440773592a..500f4fad344ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -542,7 +542,7 @@ case class Round(child: Expression, scale: Expression) extends Expression with E } override def inputTypes: Seq[AbstractDataType] = Seq( - //rely on precedence to implicit cast String into Double + // rely on precedence to implicit cast String into Double TypeCollection(DoubleType, FloatType, LongType, IntegerType, ShortType, ByteType), TypeCollection(LongType, IntegerType, ShortType, ByteType)) From 8c7a949be724d2748689948b4ade6dcb64c46a31 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Mon, 6 Jul 2015 18:22:47 +0800 Subject: [PATCH 16/21] rebase & inputTypes update --- .../scala/org/apache/spark/sql/catalyst/expressions/math.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 500f4fad344ca..e28fa98534f5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -543,7 +543,7 @@ case class Round(child: Expression, scale: Expression) extends Expression with E override def inputTypes: Seq[AbstractDataType] = Seq( // rely on precedence to implicit cast String into Double - TypeCollection(DoubleType, FloatType, LongType, IntegerType, ShortType, ByteType), + TypeCollection(DecimalType, DoubleType, FloatType, LongType, IntegerType, ShortType, ByteType), TypeCollection(LongType, IntegerType, ShortType, ByteType)) override def checkInputDataTypes(): TypeCheckResult = { From 31dfe7ce3cc47452cca22ff2f2438af267a2f025 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 10 Jul 2015 20:04:36 +0800 Subject: [PATCH 17/21] refactor round to make it readable --- .../spark/sql/catalyst/expressions/math.scala | 164 ++++++++---------- .../catalyst/util/BigDecimalConverter.scala | 60 ------- .../ExpressionTypeCheckingSuite.scala | 8 +- 3 files changed, 75 insertions(+), 157 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index e28fa98534f5d..cdb2db8c4b046 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.BigDecimalConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -524,144 +523,125 @@ case class Logarithm(left: Expression, right: Expression) } } -case class Round(child: Expression, scale: Expression) extends Expression with ExpectsInputTypes { +case class Round(child: Expression, scale: Expression) + extends BinaryExpression with ExpectsInputTypes { - def this(child: Expression) = { - this(child, Literal(0)) - } + import BigDecimal.RoundingMode.HALF_UP + + def this(child: Expression) = this(child, Literal(0)) + + override def left: Expression = child + override def right: Expression = scale override def children: Seq[Expression] = Seq(child, scale) + // round of Decimal would eval to null if it fails to `changePrecision` override def nullable: Boolean = true override def foldable: Boolean = child.foldable override lazy val dataType: DataType = child.dataType match { - case DecimalType.Fixed(p, s) => DecimalType(p, _scale) - case t => t - } + case DecimalType.Fixed(p, s) => DecimalType(p, _scale) + case t => t + } - override def inputTypes: Seq[AbstractDataType] = Seq( - // rely on precedence to implicit cast String into Double - TypeCollection(DecimalType, DoubleType, FloatType, LongType, IntegerType, ShortType, ByteType), - TypeCollection(LongType, IntegerType, ShortType, ByteType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) override def checkInputDataTypes(): TypeCheckResult = { - child.dataType match { - case _: NumericType => // satisfy requirement - case dt => - return TypeCheckFailure(s"Only numeric type is allowed for ROUND function, got $dt") - } - scale match { - case Literal(value, LongType) => - if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) { - return TypeCheckFailure("ROUND scale argument out of allowed range") - } - case _ => - if (scale.dataType.isInstanceOf[IntegralType] && scale.foldable) { - // TODO: How to check out of range for foldable LongType Expression - // satisfy requirement + super.checkInputDataTypes() match { + case TypeCheckSuccess => + if (scale.foldable) { + TypeCheckSuccess } else { - return TypeCheckFailure("Only foldable Integral Expression " + - s"is allowed for ROUND scale arguments, got ${child.dataType}") + TypeCheckFailure("Only foldable Expression is allowed for scale arguments") } + case f => f } - TypeCheckSuccess } private lazy val scaleV = scale.eval(EmptyRow) private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 - override def eval(input: InternalRow): Any = { - val evalE = child.eval(input) - if (evalE == null || scaleV == null) return null - round(evalE) - } - - private lazy val round: (Any) => (Any) = typedRound(child.dataType) - - // Using dataType info to find an appropriate round method - private def typedRound(dt: DataType)(x: Any): Any = { - dt match { + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + child.dataType match { case _: DecimalType => - val decimal = x.asInstanceOf[Decimal] + val decimal = input1.asInstanceOf[Decimal] if (decimal.changePrecision(decimal.precision, _scale)) decimal else null case ByteType => - numericRound(x.asInstanceOf[Byte], _scale) + BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte case ShortType => - numericRound(x.asInstanceOf[Short], _scale) + BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort case IntegerType => - numericRound(x.asInstanceOf[Int], _scale) + BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt case LongType => - numericRound(x.asInstanceOf[Long], _scale) + BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong case FloatType => - numericRound(x.asInstanceOf[Float], _scale) + val f = input1.asInstanceOf[Float] + if (f.isNaN || f.isInfinite) { + f + } else { + BigDecimal(f).setScale(_scale, HALF_UP).toFloat + } case DoubleType => - numericRound(x.asInstanceOf[Double], _scale) - } - } - - private def numericRound[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { - input match { - case f: Float if (f.isNaN || f.isInfinite) => return input - case d: Double if (d.isNaN || d.isInfinite) => return input - case _ => + val d = input1.asInstanceOf[Double] + if (d.isNaN || d.isInfinite) { + d + } else { + BigDecimal(d).setScale(_scale, HALF_UP).toDouble + } } - bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val ce = child.gen(ctx) - def round(primitive: String, integral: Boolean): String = { - val (p1, p2) = if (integral) ("new", "") else ("", ".valueOf") - s""" - ${ev.primitive} = $p1 java.math.BigDecimal$p2(${primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)""" - } - - def fractionalCheck(primitive: String, function: String): String = { - s""" - if (Double.isNaN(${primitive}) || Double.isInfinite(${primitive})){ - ${ev.primitive} = ${primitive}; - } else { - ${round(primitive, false)}.${function}; - }""" - } - - def decimalRound(): String = { - s""" - if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { - ${ev.primitive} = ${ce.primitive}; - } else { - ${ev.isNull} = true; - } - """ - } - - val roundCode = child.dataType match { - case NullType => ";" + val evaluationCode = child.dataType match { case _: DecimalType => - decimalRound() + s""" + if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.isNull} = true; + }""" case ByteType => - round(ce.primitive, true) + ".byteValue();" + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" case ShortType => - round(ce.primitive, true) + ".shortValue();" + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" case IntegerType => - round(ce.primitive, true) + ".intValue();" + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" case LongType => - round(ce.primitive, true) + ".longValue();" + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" case FloatType => - fractionalCheck(ce.primitive, "floatValue()") + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" case DoubleType => - fractionalCheck(ce.primitive, "doubleValue()") + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" } ce.code + s""" boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${roundCode} + ${evaluationCode} } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala deleted file mode 100644 index 5ce7758156ccb..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -trait BigDecimalConverter[T] { - def toBigDecimal(in: T): BigDecimal - def fromBigDecimal(bd: BigDecimal): T -} - -/** - * Helper type converters to work with BigDecimal - * from http://stackoverflow.com/a/30979266/1115193 - */ -object BigDecimalConverter { - - implicit object ByteConverter extends BigDecimalConverter[Byte] { - def toBigDecimal(in: Byte): BigDecimal = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal): Byte = bd.toByte - } - - implicit object ShortConverter extends BigDecimalConverter[Short] { - def toBigDecimal(in: Short): BigDecimal = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal): Short = bd.toShort - } - - implicit object IntConverter extends BigDecimalConverter[Int] { - def toBigDecimal(in: Int): BigDecimal = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal): Int = bd.toInt - } - - implicit object LongConverter extends BigDecimalConverter[Long] { - def toBigDecimal(in: Long): BigDecimal = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal): Long = bd.toLong - } - - implicit object FloatConverter extends BigDecimalConverter[Float] { - def toBigDecimal(in: Float): BigDecimal = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal): Float = bd.toFloat - } - - implicit object DoubleConverter extends BigDecimalConverter[Double] { - def toBigDecimal(in: Double): BigDecimal = BigDecimal(in) - def fromBigDecimal(bd: BigDecimal): Double = bd.toDouble - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8b596ff2526d4..0f4fccffbd46f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -181,13 +181,11 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for ROUND") { assertErrorWithImplicitCast(Round(Literal(null), 'booleanField), - "Only foldable Integral Expression is allowed for ROUND scale arguments") + "data type mismatch: argument 2 is expected to be of type int") assertErrorWithImplicitCast(Round(Literal(null), 'complexField), - "Only foldable Integral Expression is allowed for ROUND scale arguments") + "data type mismatch: argument 2 is expected to be of type int") assertSuccess(Round(Literal(null), Literal(null))) assertError(Round('booleanField, 'intField), - "Only numeric type is allowed for ROUND function") - assertErrorWithImplicitCast(Round(Literal(null), Literal(1L + Int.MaxValue)), - "ROUND scale argument out of allowed range") + "data type mismatch: argument 1 is expected to be of type numeric") } } From 302a78a2fd3aec89e7a8c8e544f429728d270030 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 10 Jul 2015 21:26:24 +0800 Subject: [PATCH 18/21] Add dataframe function test --- .../apache/spark/sql/MathExpressionsSuite.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8ccfdd5147680..0020d318ecd8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -199,11 +199,17 @@ class MathExpressionsSuite extends QueryTest { } test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") checkAnswer( - ctx.sql("SELECT round(-32768), round(1809242.3151111344, 9)"), - Seq((1, 2)).toDF().select( - round(lit(-32768)), - round(lit(1809242.3151111344), 9)) + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) ) } From 61760eeb92476fe327fd352c3879f71e31289e64 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 11 Jul 2015 22:02:44 +0800 Subject: [PATCH 19/21] address reviews --- .../spark/sql/catalyst/expressions/math.scala | 145 +++++++++++++----- .../expressions/MathFunctionsSuite.scala | 51 +++--- 2 files changed, 141 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index cdb2db8c4b046..7e7a0e280a62d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -523,6 +523,20 @@ case class Logarithm(left: Expression, right: Expression) } } +/** + * Round the `child`'s result to `scale` decimal place when `scale` >= 0 + * or round at integral part when `scale` < 0. + * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * + * Child of IntegralType would eval to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always eval to itself. + * + * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], + * which leads to scale update in DecimalType's [[PrecisionInfo]] + * + * @param child expr to be round, all [[NumericType]] is allowed as Input + * @param scale new scale to be round to, this should be a constant int at runtime + */ case class Round(child: Expression, scale: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -559,10 +573,27 @@ case class Round(child: Expression, scale: Expression) } } - private lazy val scaleV = scale.eval(EmptyRow) - private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 + // Avoid repeated evaluation since `scale` is a constant int, + // avoid unnecessary `child` evaluation in both codegen and non-codegen eval + // by checking if scaleV == null as well. + private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] - protected override def nullSafeEval(input1: Any, input2: Any): Any = { + override def eval(input: InternalRow): Any = { + if (scaleV == null) { // if scale is null, no need to eval its child at all + null + } else { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + nullSafeEval(evalE) + } + } + } + + // not overriding since _scale is a constant int at runtime + def nullSafeEval(input1: Any): Any = { child.dataType match { case _: DecimalType => val decimal = input1.asInstanceOf[Decimal] @@ -604,45 +635,89 @@ case class Round(child: Expression, scale: Expression) ${ev.isNull} = true; }""" case ByteType => - s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } case ShortType => - s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } case IntegerType => - s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } case LongType => - s""" - ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" - case FloatType => - s""" - if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ - ${ev.primitive} = ${ce.primitive}; + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" } else { - ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); - }""" - case DoubleType => - s""" - if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ - ${ev.primitive} = ${ce.primitive}; + s"${ev.primitive} = ${ce.primitive};" + } + case FloatType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" } else { - ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); - }""" + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" + } + case DoubleType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" + } } - ce.code + s""" - boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${evaluationCode} - } + if (scaleV == null) { // if scale is null, no need to eval its child at all + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + s""" + ${ce.code} + boolean ${ev.isNull} = ${ce.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluationCode + } """ + } } + + override def prettyName: String = "round" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 7aa924c6d4584..52a874a9d89ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -340,32 +340,43 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("round") { - val domain = -16 to 16 - val doublePi = math.Pi + val domain = -6 to 6 + val doublePi: Double = math.Pi val shortPi: Short = 31415 - val intPi = 314159265 - val longPi = 31415926535897932L - val bdPi = BigDecimal(31415926535897932L, 10) - - domain.foreach { scale => - checkEvaluation(Round(doublePi, scale), - BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow) - checkEvaluation(Round(shortPi, scale), - BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow) - checkEvaluation(Round(intPi, scale), - BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow) - checkEvaluation(Round(longPi, scale), - BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow) + val intPi: Int = 314159265 + val longPi: Long = 31415926535897932L + val bdPi: BigDecimal = BigDecimal(31415927L, 7) + + val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, + 3.1416, 3.14159, 3.141593) + + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ + Seq.fill[Short](7)(31415) + + val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159270) ++ Seq.fill(7)(314159265) + + val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, + 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ + Seq.fill(7)(31415926535897932L) + + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) + + domain.zipWithIndex.foreach { case (scale, i) => + checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) + checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) } // round_scale > current_scale would result in precision increase // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null - val (validScales, nullScales) = domain.splitAt(27) - validScales.foreach { scale => - checkEvaluation(Round(bdPi, scale), - Decimal(bdPi.setScale(scale, RoundingMode.HALF_UP)), EmptyRow) + (0 to 7).foreach { i => + checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) } - nullScales.foreach { scale => + (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), null, EmptyRow) } } From 392b65baa63b9d3b6597fc7b168c62dd838d8987 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 14 Jul 2015 17:03:43 +0800 Subject: [PATCH 20/21] add negative scale test in DecimalSuite --- .../spark/sql/catalyst/expressions/math.scala | 4 +++- .../sql/types/decimal/DecimalSuite.scala | 23 +++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7e7a0e280a62d..efb7dc50fd8b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -555,7 +555,9 @@ case class Round(child: Expression, scale: Expression) override def foldable: Boolean = child.foldable override lazy val dataType: DataType = child.dataType match { - case DecimalType.Fixed(p, s) => DecimalType(p, _scale) + // if the new scale is bigger which means we are scaling up, + // keep the original scale as `Decimal` does + case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) case t => t } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 5f312964e5bf7..9e8cd026d2a32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester import scala.language.postfixOps class DecimalSuite extends SparkFunSuite with PrivateMethodTester { - test("creating decimals") { - /** Check that a Decimal has the given string representation, precision and scale */ - def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { - assert(d.toString === string) - assert(d.precision === precision) - assert(d.scale === scale) - } + /** Check that a Decimal has the given string representation, precision and scale */ + private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) @@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) } + test("creating decimals with negative scale") { + checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3) + checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10) + checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10) + checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10) + } + test("double and long values") { /** Check that a Decimal converts to the given double and long values */ def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { From 07a124c4ac57e933d9b645fc43953dc035bab147 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 15 Jul 2015 14:21:40 +0800 Subject: [PATCH 21/21] remove useless def children --- .../scala/org/apache/spark/sql/catalyst/expressions/math.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index efb7dc50fd8b1..20c3874fcfe4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -547,8 +547,6 @@ case class Round(child: Expression, scale: Expression) override def left: Expression = child override def right: Expression = scale - override def children: Seq[Expression] = Seq(child, scale) - // round of Decimal would eval to null if it fails to `changePrecision` override def nullable: Boolean = true