diff --git a/native/spark-expr/src/agg_funcs/avg_decimal.rs b/native/spark-expr/src/agg_funcs/avg_decimal.rs
index 773ddea050..08e335f427 100644
--- a/native/spark-expr/src/agg_funcs/avg_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/avg_decimal.rs
@@ -207,6 +207,12 @@ impl AggregateUDFImpl for AvgDecimal {
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
         avg_return_type(self.name(), &arg_types[0])
     }
+
+    fn is_nullable(&self) -> bool {
+        // In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode.
+        // AvgDecimal is always nullable because overflows can cause null values.
+        true
+    }
 }
 
 /// An accumulator to compute the average for decimals
diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs
index bf5569b00b..56a735493c 100644
--- a/native/spark-expr/src/agg_funcs/sum_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs
@@ -164,7 +164,8 @@ impl AggregateUDFImpl for SumDecimal {
     }
 
     fn is_nullable(&self) -> bool {
-        // SumDecimal is always nullable because overflows can cause null values
+        // In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode.
+        // SumDecimal is always nullable because overflows can cause null values.
         true
     }
 }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 9426d1c848..be60f4aaee 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1917,6 +1917,45 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("SumDecimal and AvgDecimal nullable should always be true") {
+    // SumDecimal and AvgDecimal currently hardcode nullable=true.
+    // This matches Spark's Sum.nullable and Average.nullable which always return true,
+    // regardless of ANSI mode or input nullability.
+    val nonNullableData: Seq[(java.math.BigDecimal, Int)] = Seq(
+      (new java.math.BigDecimal("10.00"), 1),
+      (new java.math.BigDecimal("20.00"), 1),
+      (new java.math.BigDecimal("30.00"), 2))
+
+    val nullableData: Seq[(java.math.BigDecimal, Int)] = Seq(
+      (new java.math.BigDecimal("10.00"), 1),
+      (null.asInstanceOf[java.math.BigDecimal], 1),
+      (new java.math.BigDecimal("30.00"), 2))
+
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(nonNullableData, "tbl") {
+          val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(sumRes)
+          assert(sumRes.schema.fields(1).nullable == true)
+
+          val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(avgRes)
+          assert(avgRes.schema.fields(1).nullable == true)
+        }
+
+        withParquetTable(nullableData, "tbl") {
+          val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(sumRes)
+          assert(sumRes.schema.fields(1).nullable == true)
+
+          val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(avgRes)
+          assert(avgRes.schema.fields(1).nullable == true)
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)