From b1fb0bd2d6c372f99bfff31b857c1e3d42041ac6 Mon Sep 17 00:00:00 2001 From: vaibhawvipul Date: Mon, 23 Mar 2026 09:32:25 +0530 Subject: [PATCH 1/2] fix: Make SumDecimal and AvgDecimal nullability depend on ANSI mode and input nullability --- native/core/src/execution/planner.rs | 4 ++ .../spark-expr/src/agg_funcs/avg_decimal.rs | 17 +++++++++ .../spark-expr/src/agg_funcs/sum_decimal.rs | 18 ++++++++- .../comet/exec/CometAggregateSuite.scala | 37 +++++++++++++++++++ 4 files changed, 74 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index bd37755922..a513bc0143 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1961,9 +1961,11 @@ impl PhysicalPlanner { let builder = match datatype { DataType::Decimal128(_, _) => { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let input_nullable = child.nullable(&schema)?; let func = AggregateUDF::new_from_impl(SumDecimal::try_new( datatype, eval_mode, + input_nullable, spark_expr.expr_id, Arc::clone(&self.query_context_registry), )?); @@ -1999,10 +2001,12 @@ impl PhysicalPlanner { let builder = match datatype { DataType::Decimal128(_, _) => { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let input_nullable = child.nullable(&schema)?; let func = AggregateUDF::new_from_impl(AvgDecimal::new( datatype, input_datatype, eval_mode, + input_nullable, spark_expr.expr_id, Arc::clone(&self.query_context_registry), )); diff --git a/native/spark-expr/src/agg_funcs/avg_decimal.rs b/native/spark-expr/src/agg_funcs/avg_decimal.rs index 773ddea050..ae9700c4e3 100644 --- a/native/spark-expr/src/agg_funcs/avg_decimal.rs +++ b/native/spark-expr/src/agg_funcs/avg_decimal.rs @@ -62,6 +62,8 @@ pub struct AvgDecimal { sum_data_type: DataType, result_data_type: DataType, eval_mode: EvalMode, + /// Whether the input expression is nullable + input_nullable: bool, expr_id: Option, registry: Arc, } @@ -72,6 +74,7 @@ impl PartialEq for AvgDecimal { self.sum_data_type == other.sum_data_type && self.result_data_type == other.result_data_type && self.eval_mode == other.eval_mode + && self.input_nullable == other.input_nullable && self.expr_id == other.expr_id } } @@ -83,6 +86,7 @@ impl std::hash::Hash for AvgDecimal { self.sum_data_type.hash(state); self.result_data_type.hash(state); self.eval_mode.hash(state); + self.input_nullable.hash(state); self.expr_id.hash(state); } } @@ -93,6 +97,7 @@ impl AvgDecimal { result_type: DataType, sum_type: DataType, eval_mode: EvalMode, + input_nullable: bool, expr_id: Option, registry: Arc, ) -> Self { @@ -101,6 +106,7 @@ impl AvgDecimal { result_data_type: result_type, sum_data_type: sum_type, eval_mode, + input_nullable, expr_id, registry, } @@ -207,6 +213,17 @@ impl AggregateUDFImpl for AvgDecimal { fn return_type(&self, arg_types: &[DataType]) -> Result { avg_return_type(self.name(), &arg_types[0]) } + + fn is_nullable(&self) -> bool { + // In ANSI mode, overflows cause exceptions rather than null values. + // If the input is non-nullable, the result is also non-nullable. + // In non-ANSI mode, overflows produce null values, so always nullable. + if self.eval_mode == EvalMode::Ansi { + self.input_nullable + } else { + true + } + } } /// An accumulator to compute the average for decimals diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs index bf5569b00b..fb7be7d17c 100644 --- a/native/spark-expr/src/agg_funcs/sum_decimal.rs +++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs @@ -41,6 +41,8 @@ pub struct SumDecimal { /// Decimal scale scale: i8, eval_mode: EvalMode, + /// Whether the input expression is nullable + input_nullable: bool, /// Optional expression ID for query context lookup during error creation expr_id: Option, /// Session-scoped query context registry for error reporting @@ -54,6 +56,7 @@ impl PartialEq for SumDecimal { self.precision == other.precision && self.scale == other.scale && self.eval_mode == other.eval_mode + && self.input_nullable == other.input_nullable && self.expr_id == other.expr_id && self.result_type == other.result_type } @@ -66,6 +69,7 @@ impl std::hash::Hash for SumDecimal { self.precision.hash(state); self.scale.hash(state); self.eval_mode.hash(state); + self.input_nullable.hash(state); self.expr_id.hash(state); self.result_type.hash(state); } @@ -75,6 +79,7 @@ impl SumDecimal { pub fn try_new( data_type: DataType, eval_mode: EvalMode, + input_nullable: bool, expr_id: Option, registry: Arc, ) -> DFResult { @@ -92,6 +97,7 @@ impl SumDecimal { precision, scale, eval_mode, + input_nullable, expr_id, registry, }) @@ -164,8 +170,14 @@ impl AggregateUDFImpl for SumDecimal { } fn is_nullable(&self) -> bool { - // SumDecimal is always nullable because overflows can cause null values - true + // In ANSI mode, overflows cause exceptions rather than null values. + // If the input is non-nullable, the result is also non-nullable. + // In non-ANSI mode, overflows produce null values, so always nullable. + if self.eval_mode == EvalMode::Ansi { + self.input_nullable + } else { + true + } } } @@ -630,6 +642,7 @@ mod tests { assert!(SumDecimal::try_new( DataType::Int32, EvalMode::Legacy, + true, None, crate::create_query_context_map(), ) @@ -657,6 +670,7 @@ mod tests { let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( data_type.clone(), EvalMode::Legacy, + true, None, crate::create_query_context_map(), )?)); diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 9426d1c848..98e3d7952b 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1917,6 +1917,43 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("SumDecimal and AvgDecimal nullability depends on ANSI mode and input nullability") { + // Non-nullable input: in ANSI mode the result should be non-nullable because overflows + // throw exceptions instead of producing nulls. + val nonNullableData: Seq[(java.math.BigDecimal, Int)] = Seq( + (new java.math.BigDecimal("10.00"), 1), + (new java.math.BigDecimal("20.00"), 1), + (new java.math.BigDecimal("30.00"), 2)) + + // Nullable input: result should always be nullable regardless of ANSI mode. + val nullableData: Seq[(java.math.BigDecimal, Int)] = Seq( + (new java.math.BigDecimal("10.00"), 1), + (null.asInstanceOf[java.math.BigDecimal], 1), + (new java.math.BigDecimal("30.00"), 2)) + + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + // Test SUM with non-nullable input + withParquetTable(nonNullableData, "tbl") { + val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(sumRes) + + val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(avgRes) + } + + // Test SUM/AVG with nullable input + withParquetTable(nullableData, "tbl") { + val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(sumRes) + + val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(avgRes) + } + } + } + } + protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = { val df = sql(query) checkSparkAnswer(df) From f0b9b67a77328f510bcac1b47a6d60de04b28009 Mon Sep 17 00:00:00 2001 From: vaibhawvipul Date: Tue, 24 Mar 2026 11:41:05 +0530 Subject: [PATCH 2/2] Revert input_nullable changes; keep nullable=true to match Spark behavior In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode. Reverted the ANSI-aware is_nullable logic and added comments documenting this Spark behavior. Test updated to assert nullable is always true. --- native/core/src/execution/planner.rs | 4 ---- .../spark-expr/src/agg_funcs/avg_decimal.rs | 17 +++-------------- .../spark-expr/src/agg_funcs/sum_decimal.rs | 19 +++---------------- .../comet/exec/CometAggregateSuite.scala | 14 ++++++++------ 4 files changed, 14 insertions(+), 40 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 3b174bd145..5af31fcc22 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -2019,11 +2019,9 @@ impl PhysicalPlanner { let builder = match datatype { DataType::Decimal128(_, _) => { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - let input_nullable = child.nullable(&schema)?; let func = AggregateUDF::new_from_impl(SumDecimal::try_new( datatype, eval_mode, - input_nullable, spark_expr.expr_id, Arc::clone(&self.query_context_registry), )?); @@ -2059,12 +2057,10 @@ impl PhysicalPlanner { let builder = match datatype { DataType::Decimal128(_, _) => { let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - let input_nullable = child.nullable(&schema)?; let func = AggregateUDF::new_from_impl(AvgDecimal::new( datatype, input_datatype, eval_mode, - input_nullable, spark_expr.expr_id, Arc::clone(&self.query_context_registry), )); diff --git a/native/spark-expr/src/agg_funcs/avg_decimal.rs b/native/spark-expr/src/agg_funcs/avg_decimal.rs index ae9700c4e3..08e335f427 100644 --- a/native/spark-expr/src/agg_funcs/avg_decimal.rs +++ b/native/spark-expr/src/agg_funcs/avg_decimal.rs @@ -62,8 +62,6 @@ pub struct AvgDecimal { sum_data_type: DataType, result_data_type: DataType, eval_mode: EvalMode, - /// Whether the input expression is nullable - input_nullable: bool, expr_id: Option, registry: Arc, } @@ -74,7 +72,6 @@ impl PartialEq for AvgDecimal { self.sum_data_type == other.sum_data_type && self.result_data_type == other.result_data_type && self.eval_mode == other.eval_mode - && self.input_nullable == other.input_nullable && self.expr_id == other.expr_id } } @@ -86,7 +83,6 @@ impl std::hash::Hash for AvgDecimal { self.sum_data_type.hash(state); self.result_data_type.hash(state); self.eval_mode.hash(state); - self.input_nullable.hash(state); self.expr_id.hash(state); } } @@ -97,7 +93,6 @@ impl AvgDecimal { result_type: DataType, sum_type: DataType, eval_mode: EvalMode, - input_nullable: bool, expr_id: Option, registry: Arc, ) -> Self { @@ -106,7 +101,6 @@ impl AvgDecimal { result_data_type: result_type, sum_data_type: sum_type, eval_mode, - input_nullable, expr_id, registry, } @@ -215,14 +209,9 @@ impl AggregateUDFImpl for AvgDecimal { } fn is_nullable(&self) -> bool { - // In ANSI mode, overflows cause exceptions rather than null values. - // If the input is non-nullable, the result is also non-nullable. - // In non-ANSI mode, overflows produce null values, so always nullable. - if self.eval_mode == EvalMode::Ansi { - self.input_nullable - } else { - true - } + // In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode. + // AvgDecimal is always nullable because overflows can cause null values. + true } } diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs index fb7be7d17c..56a735493c 100644 --- a/native/spark-expr/src/agg_funcs/sum_decimal.rs +++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs @@ -41,8 +41,6 @@ pub struct SumDecimal { /// Decimal scale scale: i8, eval_mode: EvalMode, - /// Whether the input expression is nullable - input_nullable: bool, /// Optional expression ID for query context lookup during error creation expr_id: Option, /// Session-scoped query context registry for error reporting @@ -56,7 +54,6 @@ impl PartialEq for SumDecimal { self.precision == other.precision && self.scale == other.scale && self.eval_mode == other.eval_mode - && self.input_nullable == other.input_nullable && self.expr_id == other.expr_id && self.result_type == other.result_type } @@ -69,7 +66,6 @@ impl std::hash::Hash for SumDecimal { self.precision.hash(state); self.scale.hash(state); self.eval_mode.hash(state); - self.input_nullable.hash(state); self.expr_id.hash(state); self.result_type.hash(state); } @@ -79,7 +75,6 @@ impl SumDecimal { pub fn try_new( data_type: DataType, eval_mode: EvalMode, - input_nullable: bool, expr_id: Option, registry: Arc, ) -> DFResult { @@ -97,7 +92,6 @@ impl SumDecimal { precision, scale, eval_mode, - input_nullable, expr_id, registry, }) @@ -170,14 +164,9 @@ impl AggregateUDFImpl for SumDecimal { } fn is_nullable(&self) -> bool { - // In ANSI mode, overflows cause exceptions rather than null values. - // If the input is non-nullable, the result is also non-nullable. - // In non-ANSI mode, overflows produce null values, so always nullable. - if self.eval_mode == EvalMode::Ansi { - self.input_nullable - } else { - true - } + // In Spark, Sum.nullable and Average.nullable both return true irrespective of ANSI mode. + // SumDecimal is always nullable because overflows can cause null values. + true } } @@ -642,7 +631,6 @@ mod tests { assert!(SumDecimal::try_new( DataType::Int32, EvalMode::Legacy, - true, None, crate::create_query_context_map(), ) @@ -670,7 +658,6 @@ mod tests { let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( data_type.clone(), EvalMode::Legacy, - true, None, crate::create_query_context_map(), )?)); diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 98e3d7952b..be60f4aaee 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1917,15 +1917,15 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("SumDecimal and AvgDecimal nullability depends on ANSI mode and input nullability") { - // Non-nullable input: in ANSI mode the result should be non-nullable because overflows - // throw exceptions instead of producing nulls. + test("SumDecimal and AvgDecimal nullable should always be true") { + // SumDecimal and AvgDecimal currently hardcode nullable=true. + // This matches Spark's Sum.nullable and Average.nullable which always return true, + // regardless of ANSI mode or input nullability. val nonNullableData: Seq[(java.math.BigDecimal, Int)] = Seq( (new java.math.BigDecimal("10.00"), 1), (new java.math.BigDecimal("20.00"), 1), (new java.math.BigDecimal("30.00"), 2)) - // Nullable input: result should always be nullable regardless of ANSI mode. val nullableData: Seq[(java.math.BigDecimal, Int)] = Seq( (new java.math.BigDecimal("10.00"), 1), (null.asInstanceOf[java.math.BigDecimal], 1), @@ -1933,22 +1933,24 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - // Test SUM with non-nullable input withParquetTable(nonNullableData, "tbl") { val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2") checkSparkAnswerAndOperator(sumRes) + assert(sumRes.schema.fields(1).nullable == true) val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") checkSparkAnswerAndOperator(avgRes) + assert(avgRes.schema.fields(1).nullable == true) } - // Test SUM/AVG with nullable input withParquetTable(nullableData, "tbl") { val sumRes = sql("SELECT _2, sum(_1) FROM tbl GROUP BY _2") checkSparkAnswerAndOperator(sumRes) + assert(sumRes.schema.fields(1).nullable == true) val avgRes = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") checkSparkAnswerAndOperator(avgRes) + assert(avgRes.schema.fields(1).nullable == true) } } }