From 5bcf37e669d12ffd03df2dc01d5efc990228286e Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Fri, 1 May 2026 16:28:20 -0600
Subject: [PATCH] fix: preserve stored allowDecimalPrecisionLoss in DecimalPrecision rule

Comet's `DecimalPrecision.promote` recomputed the result type for decimal
Add/Subtract/Multiply/Divide/Remainder from the live `SQLConf` and wrapped the
expression in `CheckOverflow` with that type. On Spark 4.1.1+ (SPARK-53968),
arithmetic expressions store their `allowDecimalPrecisionLoss` per-instance so
a view's analyzed plan keeps a stable result type across config changes. The
recomputed type could disagree with `Add.dataType`, and the native
`CheckOverflow` only relabels the decimal buffer (it does not rescale),
shifting values by 10x.

Use `expr.dataType` directly. On older Spark this is equivalent to the
recomputed value; on 4.1+ it honours the stored evalContext.
---
 dev/diffs/4.1.1.diff                          | 10 --
 .../apache/comet/serde/QueryPlanSerde.scala   |  4 +-
 .../spark/sql/comet/DecimalPrecision.scala    | 93 +++++--------------
 .../CometDecimalArithmeticViewSuite.scala     | 66 +++++++++++++
 4 files changed, 92 insertions(+), 81 deletions(-)
 create mode 100644 spark/src/test/spark-4.1/org/apache/spark/sql/comet/CometDecimalArithmeticViewSuite.scala

diff --git a/dev/diffs/4.1.1.diff b/dev/diffs/4.1.1.diff
index 8c2949c293..a6b749e1dc 100644
--- a/dev/diffs/4.1.1.diff
+++ b/dev/diffs/4.1.1.diff
@@ -2034,16 +2034,6 @@ index 050a004a935..96d982f2829 100644
      withTable("t") {
        Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t")
        withView("v1") {
-@@ -1334,7 +1335,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
-     }
-   }
- 
--  test("SPARK-53968 reading the view after allowPrecisionLoss is changed") {
-+  test("SPARK-53968 reading the view after allowPrecisionLoss is changed",
-+    IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/4124")) {
-     import org.apache.spark.sql.internal.SQLConf
-     val partsTableName = "parts_tbl"
-     val ordersTableName = "orders_tbl"
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
 index aed11badb71..1a365b5aacf 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index af3e27c774..4a0f0b00ff 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -585,9 +585,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
       inputs: Seq[Attribute],
       binding: Boolean = true): Option[Expr] = {
-    val conf = SQLConf.get
-    val newExpr =
-      DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled)
+    val newExpr = DecimalPrecision.promote(expr, !SQLConf.get.ansiEnabled)
     exprToProtoInternal(newExpr, inputs, binding)
   }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
index a2cdf421c1..6309a4304a 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
@@ -19,90 +19,47 @@
 
 package org.apache.spark.sql.comet
 
-import scala.math.{max, min}
-
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.DecimalType
 
 /**
- * This is mostly copied from the `decimalAndDecimal` method in Spark's [[DecimalPrecision]] which
- * existed before Spark 3.4.
- *
- * In Spark 3.4 and up, the method `decimalAndDecimal` is removed from Spark, and for binary
- * expressions with different decimal precisions from children, the difference is handled in the
- * expression evaluation instead (see SPARK-39316).
- *
- * However in Comet, we still have to rely on the type coercion to ensure the decimal precision is
- * the same for both children of a binary expression, since our arithmetic kernels do not yet
- * handle the case where precision is different. Therefore, this re-apply the logic in the
- * original rule, and rely on `Cast` and `CheckOverflow` for decimal binary operation.
+ * Wraps decimal binary arithmetic expressions in [[CheckOverflow]] so the native side has an
+ * explicit target type for the result.
  *
- * TODO: instead of relying on this rule, it's probably better to enhance arithmetic kernels to
- * handle different decimal precisions
+ * Spark itself stopped wrapping these in `CheckOverflow` in 3.4 (SPARK-39316), but Comet's native
+ * `CheckOverflow` only validates precision (it does not rescale), so the target type must equal
+ * the child's actual `dataType`. Always using `expr.dataType` is the safe choice: on Spark 3.4 -
+ * 4.0 it equals the value the rule would otherwise recompute from `SQLConf`, and on Spark 4.1+
+ * (SPARK-53968) it preserves the per-expression `allowDecimalPrecisionLoss` captured at view
+ * creation time. Recomputing from the live `SQLConf` would re-label a stored DEC(38, 17) result
+ * as DEC(38, 18) (or vice versa) and shift values by 10x (issue #4124).
  */
 object DecimalPrecision {
-  def promote(
-      allowPrecisionLoss: Boolean,
-      expr: Expression,
-      nullOnOverflow: Boolean): Expression = {
+  def promote(expr: Expression, nullOnOverflow: Boolean): Expression = {
     expr.transformUp {
       // This means the binary expression is already optimized with the rule in Spark. This can
       // happen if the Spark version is < 3.4
       case e: BinaryArithmetic if e.left.prettyName == "promote_precision" => e
 
-      case add @ Add(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultScale = max(s1, s2)
-        val resultType = if (allowPrecisionLoss) {
-          DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-        } else {
-          DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-        }
-        CheckOverflow(add, resultType, nullOnOverflow)
+      case add @ Add(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if add.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(add, add.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
-      case sub @ Subtract(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultScale = max(s1, s2)
-        val resultType = if (allowPrecisionLoss) {
-          DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-        } else {
-          DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-        }
-        CheckOverflow(sub, resultType, nullOnOverflow)
+      case sub @ Subtract(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if sub.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(sub, sub.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
-      case mul @ Multiply(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultType = if (allowPrecisionLoss) {
-          DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
-        } else {
-          DecimalType.bounded(p1 + p2 + 1, s1 + s2)
-        }
-        CheckOverflow(mul, resultType, nullOnOverflow)
+      case mul @ Multiply(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if mul.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(mul, mul.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
-      case div @ Divide(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultType = if (allowPrecisionLoss) {
-          // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
-          // Scale: max(6, s1 + p2 + 1)
-          val intDig = p1 - s1 + s2
-          val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
-          val prec = intDig + scale
-          DecimalType.adjustPrecisionScale(prec, scale)
-        } else {
-          var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
-          var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
-          val diff = (intDig + decDig) - DecimalType.MAX_SCALE
-          if (diff > 0) {
-            decDig -= diff / 2 + 1
-            intDig = DecimalType.MAX_SCALE - decDig
-          }
-          DecimalType.bounded(intDig + decDig, decDig)
-        }
-        CheckOverflow(div, resultType, nullOnOverflow)
+      case div @ Divide(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if div.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(div, div.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
-      case rem @ Remainder(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultType = if (allowPrecisionLoss) {
-          DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-        } else {
-          DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-        }
-        CheckOverflow(rem, resultType, nullOnOverflow)
+      case rem @ Remainder(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if rem.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(rem, rem.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
       case e => e
     }
diff --git a/spark/src/test/spark-4.1/org/apache/spark/sql/comet/CometDecimalArithmeticViewSuite.scala b/spark/src/test/spark-4.1/org/apache/spark/sql/comet/CometDecimalArithmeticViewSuite.scala
new file mode 100644
index 0000000000..d540bb2ba3
--- /dev/null
+++ b/spark/src/test/spark-4.1/org/apache/spark/sql/comet/CometDecimalArithmeticViewSuite.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, CheckOverflow, EvalMode, NumericEvalContext}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DecimalType
+
+class CometDecimalArithmeticViewSuite extends CometTestBase {
+
+  // Spark 4.1.1 (SPARK-53968) stores `spark.sql.decimalOperations.allowPrecisionLoss` per
+  // arithmetic expression so a view's analyzed plan keeps a stable result type across config
+  // changes. Comet's DecimalPrecision rule used to recompute the result type from the current
+  // SQLConf, producing a CheckOverflow target that disagreed with the stored Add.dataType and
+  // re-labelling the Decimal128 buffer at the wrong scale (issue #4124). The repro requires a
+  // mismatch between the stored evalContext and the live SQLConf, which is why we construct
+  // Add directly rather than going through SQL parsing.
+  test("issue #4124: DecimalPrecision.promote honours per-expression allowPrecisionLoss") {
+    val left = AttributeReference("a", DecimalType(38, 18))()
+    val right = AttributeReference("b", DecimalType(38, 18))()
+    val storedTrue =
+      Add(left, right, NumericEvalContext(EvalMode.LEGACY, allowDecimalPrecisionLoss = true))
+    val storedFalse =
+      Add(left, right, NumericEvalContext(EvalMode.LEGACY, allowDecimalPrecisionLoss = false))
+    assert(storedTrue.dataType === DecimalType(38, 17))
+    assert(storedFalse.dataType === DecimalType(38, 18))
+
+    // Current SQLConf disagrees with the stored evalContext on each Add. The promoted
+    // CheckOverflow's target type must come from Add.dataType (which honours the stored
+    // evalContext), not from the current SQLConf. Otherwise the native CheckOverflow
+    // re-labels the Decimal128 buffer at the wrong scale and values come out 10x off.
+    Seq((true, storedFalse), (false, storedTrue)).foreach { case (currentConf, add) =>
+      withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> currentConf.toString) {
+        val promoted = org.apache.spark.sql.comet.DecimalPrecision
+          .promote(add, nullOnOverflow = true)
+        promoted match {
+          case CheckOverflow(_, dt, _) =>
+            assert(
+              dt === add.dataType,
+              s"CheckOverflow target $dt must match Add.dataType ${add.dataType}; mismatch " +
+                s"causes the decimal buffer to be re-labelled at the wrong scale.")
+          case other =>
+            fail(s"Expected DecimalPrecision.promote to wrap Add in CheckOverflow, got: $other")
+        }
+      }
+    }
+  }
+}
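
Note (illustration appended after the patch, not part of it): a minimal Scala sketch of the failure mode the commit message describes, i.e. why re-labelling a decimal's unscaled buffer at a scale one step off, without rescaling, shifts the value by exactly 10x. It uses only java.math; the object name RelabelScaleSketch is made up for this example.

    import java.math.{BigDecimal => JBigDecimal, BigInteger}

    // A Decimal128 value is an unscaled integer plus a scale; the scale is only metadata.
    // The value 1.5 at scale 18 is stored as unscaled = 15 * 10^17.
    object RelabelScaleSketch extends App {
      val unscaled = BigInteger.valueOf(15).multiply(BigInteger.TEN.pow(17))

      // Correct label, matching the stored Add.dataType of DecimalType(38, 18).
      val asScale18 = new JBigDecimal(unscaled, 18)

      // Same buffer re-labelled as scale 17, which is what a CheckOverflow target that
      // disagrees with Add.dataType amounts to when no rescale is performed.
      val asScale17 = new JBigDecimal(unscaled, 17)

      println(s"scale 18: $asScale18") // 1.500000000000000000
      println(s"scale 17: $asScale17") // 15.00000000000000000 -- 10x too large
    }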