From a1e77445c2675137fbcddf73181c47469f159dbf Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 4 Dec 2018 10:33:27 -0800
Subject: [PATCH] [SPARK-26233][SQL] CheckOverflow when encoding a decimal
 value

## What changes were proposed in this pull request?

When we encode a Decimal coming from an external source, we don't check for
overflow. That check is useful not only to enforce that the value can be
represented in the specified range; it also adjusts the underlying data to the
right precision/scale. Since our code generation assumes that a decimal has
exactly the same precision and scale as its data type, failing to enforce this
can lead to corrupted output/results when there are subsequent
transformations.

## How was this patch tested?

Added a unit test.

Closes #23210 from mgaido91/SPARK-26233.

Authored-by: Marco Gaido
Signed-off-by: Dongjoon Hyun
---
 .../apache/spark/sql/catalyst/encoders/RowEncoder.scala | 4 ++--
 .../test/scala/org/apache/spark/sql/DatasetSuite.scala  | 9 +++++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 789750fd408f2..c2e735231a79c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -108,12 +108,12 @@ object RowEncoder {
       returnNullable = false)
 
     case d: DecimalType =>
-      StaticInvoke(
+      CheckOverflow(StaticInvoke(
         Decimal.getClass,
         d,
         "fromDecimal",
         inputObject :: Nil,
-        returnNullable = false)
+        returnNullable = false), d)
 
     case StringType =>
       StaticInvoke(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index e0f4d2ba685e1..3b7bd84be047f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1458,6 +1458,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()
     intercept[NullPointerException](ds.as[(Int, Int)].collect())
   }
+
+  test("SPARK-26233: serializer should enforce decimal precision and scale") {
+    val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8))))
+    val encoder = RowEncoder(s)
+    implicit val uEnc = encoder
+    val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111)))
+    checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
+      Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
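
For context, here is a standalone sketch (not part of the patch) of the
normalization that the wrapped CheckOverflow enforces. It uses Spark's
internal Decimal API from org.apache.spark.sql.types; Decimal.changePrecision
is roughly the operation CheckOverflow applies at runtime, and the object name
DecimalPrecisionSketch is made up for illustration.

// Standalone sketch, assuming Spark 2.4-era internals: Decimal.changePrecision
// is (roughly) the rescaling that CheckOverflow performs when evaluated.
import org.apache.spark.sql.types.{Decimal, DecimalType}

object DecimalPrecisionSketch {
  def main(args: Array[String]): Unit = {
    val d = DecimalType(38, 8)
    // An external BigDecimal arrives with scale 4, not the type's scale 8,
    // so its underlying representation doesn't match DecimalType(38, 8) yet.
    val in = Decimal(BigDecimal.valueOf(0.1111))
    println(s"before: precision=${in.precision}, scale=${in.scale}") // 4, 4
    // Rescale in place to the type's precision/scale; returns false on overflow.
    val fits = in.changePrecision(d.precision, d.scale)
    println(s"after: precision=${in.precision}, scale=${in.scale}, fits=$fits") // 38, 8, true
  }
}

Without this step, downstream codegen that reads the decimal's bytes assuming
the type's precision/scale would misinterpret the value, which is the
corruption the commit message describes.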