dev/diffs/4.1.1.diff (10 changes: 0 additions & 10 deletions)

@@ -2034,16 +2034,6 @@ index 050a004a935..96d982f2829 100644
     withTable("t") {
       Seq(2, 3, 1).toDF("c1").write.format("parquet").saveAsTable("t")
       withView("v1") {
-@@ -1334,7 +1335,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
-     }
-   }
-
--  test("SPARK-53968 reading the view after allowPrecisionLoss is changed") {
-+  test("SPARK-53968 reading the view after allowPrecisionLoss is changed",
-+    IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/4124")) {
-     import org.apache.spark.sql.internal.SQLConf
-     val partsTableName = "parts_tbl"
-     val ordersTableName = "orders_tbl"
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
 index aed11badb71..1a365b5aacf 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
QueryPlanSerde.scala

@@ -585,9 +585,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
       inputs: Seq[Attribute],
       binding: Boolean = true): Option[Expr] = {
 
-    val conf = SQLConf.get
-    val newExpr =
-      DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled)
+    val newExpr = DecimalPrecision.promote(expr, !SQLConf.get.ansiEnabled)
     exprToProtoInternal(newExpr, inputs, binding)
   }
 
DecimalPrecision.scala

@@ -19,90 +19,47 @@

 package org.apache.spark.sql.comet
 
-import scala.math.{max, min}
-
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.DecimalType
 
 /**
- * This is mostly copied from the `decimalAndDecimal` method in Spark's [[DecimalPrecision]] which
- * existed before Spark 3.4.
- *
- * In Spark 3.4 and up, the method `decimalAndDecimal` is removed from Spark, and for binary
- * expressions with different decimal precisions from children, the difference is handled in the
- * expression evaluation instead (see SPARK-39316).
- *
- * However in Comet, we still have to rely on the type coercion to ensure the decimal precision is
- * the same for both children of a binary expression, since our arithmetic kernels do not yet
- * handle the case where precision is different. Therefore, this re-apply the logic in the
- * original rule, and rely on `Cast` and `CheckOverflow` for decimal binary operation.
+ * Wraps decimal binary arithmetic expressions in [[CheckOverflow]] so the native side has an
+ * explicit target type for the result.
  *
- * TODO: instead of relying on this rule, it's probably better to enhance arithmetic kernels to
- * handle different decimal precisions
+ * Spark itself stopped wrapping these in `CheckOverflow` in 3.4 (SPARK-39316), but Comet's native
+ * `CheckOverflow` only validates precision (it does not rescale), so the target type must equal
+ * the child's actual `dataType`. Always using `expr.dataType` is the safe choice: on Spark 3.4 -
+ * 4.0 it equals the value the rule would otherwise recompute from `SQLConf`, and on Spark 4.1+
+ * (SPARK-53968) it preserves the per-expression `allowDecimalPrecisionLoss` captured at view
+ * creation time. Recomputing from the live `SQLConf` would re-label a stored DEC(38, 17) result
+ * as DEC(38, 18) (or vice versa) and shift values by 10x (issue #4124).
  */
 object DecimalPrecision {
-  def promote(
-      allowPrecisionLoss: Boolean,
-      expr: Expression,
-      nullOnOverflow: Boolean): Expression = {
+  def promote(expr: Expression, nullOnOverflow: Boolean): Expression = {
     expr.transformUp {
       // This means the binary expression is already optimized with the rule in Spark. This can
       // happen if the Spark version is < 3.4
       case e: BinaryArithmetic if e.left.prettyName == "promote_precision" => e
 
-      case add @ Add(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultScale = max(s1, s2)
-        val resultType = if (allowPrecisionLoss) {
-          DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-        } else {
-          DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-        }
-        CheckOverflow(add, resultType, nullOnOverflow)
+      case add @ Add(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if add.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(add, add.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
-      case sub @ Subtract(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultScale = max(s1, s2)
-        val resultType = if (allowPrecisionLoss) {
-          DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-        } else {
-          DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-        }
-        CheckOverflow(sub, resultType, nullOnOverflow)
+      case sub @ Subtract(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if sub.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(sub, sub.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
-      case mul @ Multiply(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultType = if (allowPrecisionLoss) {
-          DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
-        } else {
-          DecimalType.bounded(p1 + p2 + 1, s1 + s2)
-        }
-        CheckOverflow(mul, resultType, nullOnOverflow)
+      case mul @ Multiply(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if mul.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(mul, mul.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
-      case div @ Divide(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultType = if (allowPrecisionLoss) {
-          // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
-          // Scale: max(6, s1 + p2 + 1)
-          val intDig = p1 - s1 + s2
-          val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
-          val prec = intDig + scale
-          DecimalType.adjustPrecisionScale(prec, scale)
-        } else {
-          var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
-          var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
-          val diff = (intDig + decDig) - DecimalType.MAX_SCALE
-          if (diff > 0) {
-            decDig -= diff / 2 + 1
-            intDig = DecimalType.MAX_SCALE - decDig
-          }
-          DecimalType.bounded(intDig + decDig, decDig)
-        }
-        CheckOverflow(div, resultType, nullOnOverflow)
+      case div @ Divide(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if div.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(div, div.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
-      case rem @ Remainder(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
-        val resultType = if (allowPrecisionLoss) {
-          DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-        } else {
-          DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-        }
-        CheckOverflow(rem, resultType, nullOnOverflow)
+      case rem @ Remainder(DecimalExpression(_, _), DecimalExpression(_, _), _)
+          if rem.dataType.isInstanceOf[DecimalType] =>
+        CheckOverflow(rem, rem.dataType.asInstanceOf[DecimalType], nullOnOverflow)
 
       case e => e
     }
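The DEC(38, 17) vs DEC(38, 18) re-labelling described in the new doc comment falls out of Spark's decimal type arithmetic for Add. A minimal sketch of that arithmetic, not part of the PR: the object name is made up, and it must live under org.apache.spark.sql (as the real rule does) because `adjustPrecisionScale` and `bounded`, both called by the deleted code above, are private[sql].

package org.apache.spark.sql.comet

import org.apache.spark.sql.types.DecimalType

// Illustrative only: shows why the same Add is typed DECIMAL(38, 17) or
// DECIMAL(38, 18) depending on allowPrecisionLoss, the two labels that
// issue #4124 mixed up.
object DecimalScaleSketch extends App {
  // For Add/Subtract, Spark wants scale = max(s1, s2) and
  // precision = max(p1 - s1, p2 - s2) + scale + 1. For
  // DECIMAL(38, 18) + DECIMAL(38, 18) that is (39, 18), over the 38-digit cap.
  val (p1, s1, p2, s2) = (38, 18, 38, 18)
  val scale = s1 max s2
  val precision = ((p1 - s1) max (p2 - s2)) + scale + 1 // 39

  // Over the cap, the final type depends on allowPrecisionLoss:
  println(DecimalType.adjustPrecisionScale(precision, scale)) // DecimalType(38,17)
  println(DecimalType.bounded(precision, scale)) // DecimalType(38,18)

  // Re-labelling is a metadata-only change to the Decimal128 buffer: the same
  // unscaled integer 30 reads as 3.0e-16 at scale 17 but as 3.0e-17 at scale
  // 18, so every value shifts by exactly 10x.
}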
CometDecimalArithmeticViewSuite.scala (new file)

@@ -0,0 +1,66 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet

import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, CheckOverflow, EvalMode, NumericEvalContext}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DecimalType

class CometDecimalArithmeticViewSuite extends CometTestBase {

  // Spark 4.1.1 (SPARK-53968) stores `spark.sql.decimalOperations.allowPrecisionLoss` per
  // arithmetic expression so a view's analyzed plan keeps a stable result type across config
  // changes. Comet's DecimalPrecision rule used to recompute the result type from the current
  // SQLConf, producing a CheckOverflow target that disagreed with the stored Add.dataType and
  // re-labelling the Decimal128 buffer at the wrong scale (issue #4124). The repro requires a
  // mismatch between the stored evalContext and the live SQLConf, which is why we construct
  // Add directly rather than going through SQL parsing.
  test("issue #4124: DecimalPrecision.promote honours per-expression allowPrecisionLoss") {
    val left = AttributeReference("a", DecimalType(38, 18))()
    val right = AttributeReference("b", DecimalType(38, 18))()
    val storedTrue =
      Add(left, right, NumericEvalContext(EvalMode.LEGACY, allowDecimalPrecisionLoss = true))
    val storedFalse =
      Add(left, right, NumericEvalContext(EvalMode.LEGACY, allowDecimalPrecisionLoss = false))
    assert(storedTrue.dataType === DecimalType(38, 17))
    assert(storedFalse.dataType === DecimalType(38, 18))

    // Current SQLConf disagrees with the stored evalContext on each Add. The promoted
    // CheckOverflow's target type must come from Add.dataType (which honours the stored
    // evalContext), not from the current SQLConf. Otherwise the native CheckOverflow
    // re-labels the Decimal128 buffer at the wrong scale and values come out 10x off.
    Seq((true, storedFalse), (false, storedTrue)).foreach { case (currentConf, add) =>
      withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> currentConf.toString) {
        val promoted = org.apache.spark.sql.comet.DecimalPrecision
          .promote(add, nullOnOverflow = true)
        promoted match {
          case CheckOverflow(_, dt, _) =>
            assert(
              dt === add.dataType,
              s"CheckOverflow target $dt must match Add.dataType ${add.dataType}; mismatch " +
                s"causes the decimal buffer to be re-labelled at the wrong scale.")
          case other =>
            fail(s"Expected DecimalPrecision.promote to wrap Add in CheckOverflow, got: $other")
        }
      }
    }
  }
}
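For context, the end-to-end view scenario this unit test models, as a hedged sketch: the table and view names below are illustrative, not taken from the PR, and the full Spark-side coverage is the SQLViewSuite test re-enabled in dev/diffs/4.1.1.diff above.

// Illustrative repro of SPARK-53968 interacting with issue #4124.
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "true")
spark.sql("CREATE TABLE t (a DECIMAL(38, 18), b DECIMAL(38, 18)) USING parquet")
spark.sql("INSERT INTO t VALUES (1.0, 2.0)")
// On Spark 4.1+ the view's analyzed plan captures allowPrecisionLoss = true on
// the Add, so `a + b` stays DECIMAL(38, 17) no matter how the conf changes later.
spark.sql("CREATE VIEW v AS SELECT a + b AS s FROM t")

spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
// Before the fix, Comet recomputed the CheckOverflow target from the live conf
// as DECIMAL(38, 18), re-labelling the DECIMAL(38, 17) buffer so the sum read
// back as 0.3 instead of 3.0. With the fix it reads 3.0 under either setting.
spark.sql("SELECT s FROM v").show()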