From b999feafb8877c06bbeea3153570119bc9d7117c Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 26 Nov 2025 11:20:37 +0530 Subject: [PATCH 1/7] feat: Support decimal for variance --- .../functions-aggregate/src/variance.rs | 76 +++++++++++++++++-- .../sqllogictest/test_files/aggregate.slt | 17 +++++ 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 846c145cb11e7..33dcae872fa71 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -29,8 +29,8 @@ use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarVa use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, - Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, - Volatility, + Accumulator, AggregateUDFImpl, Coercion, Documentation, GroupsAccumulator, Signature, + TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, @@ -55,6 +55,29 @@ make_udaf_expr_and_func!( var_pop_udaf ); +fn variance_signature() -> Signature { + Signature::one_of( + vec![ + TypeSignature::Numeric(1), + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Decimal, + )]), + ], + Volatility::Immutable, + ) +} + +fn is_numeric_or_decimal(data_type: &DataType) -> bool { + data_type.is_numeric() + || matches!( + data_type, + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ) +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the statistical sample variance of a set of numbers.", @@ -86,7 +109,7 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: variance_signature(), } } } @@ -179,7 +202,7 @@ impl VariancePopulation { pub fn new() -> Self { Self { aliases: vec![String::from("var_population")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: variance_signature(), } } } @@ -198,7 +221,7 @@ impl AggregateUDFImpl for VariancePopulation { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { + if !is_numeric_or_decimal(&arg_types[0]) { return plan_err!("Variance requires numeric input types"); } @@ -583,10 +606,53 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { #[cfg(test)] mod tests { + use arrow::array::Decimal128Builder; use datafusion_expr::EmitTo; + use std::sync::Arc; use super::*; + #[test] + fn variance_population_accepts_decimal() -> Result<()> { + let variance = VariancePopulation::new(); + variance.return_type(&[DataType::Decimal128(10, 3)])?; + Ok(()) + } + + #[test] + fn variance_decimal_input() -> Result<()> { + let mut builder = Decimal128Builder::with_capacity(20); + for i in 0..10 { + builder.append_value(110000 + i); + } + for i in 0..10 { + builder.append_value(-((100000 + i) as i128)); + } + let decimal_array = builder.finish().with_precision_and_scale(10, 3).unwrap(); + let array: ArrayRef = Arc::new(decimal_array); + + let mut pop_acc = VarianceAccumulator::try_new(StatsType::Population)?; + let pop_input = [Arc::clone(&array)]; + pop_acc.update_batch(&pop_input)?; + assert_variance(pop_acc.evaluate()?, 11025.9450285); + + let mut sample_acc = VarianceAccumulator::try_new(StatsType::Sample)?; + let sample_input = [array]; + sample_acc.update_batch(&sample_input)?; + assert_variance(sample_acc.evaluate()?, 11606.257924736841); + + Ok(()) + } + + fn assert_variance(value: ScalarValue, expected: f64) { + match value { + ScalarValue::Float64(Some(actual)) => { + assert!((actual - expected).abs() < 1e-9) + } + other => panic!("expected Float64 result, got {other:?}"), + } + } + #[test] fn test_groups_accumulator_merge_empty_states() -> Result<()> { let state_1 = vec![ diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a1b868b0b028f..51e2c4e75c380 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5629,6 +5629,23 @@ select c2, avg(c1), arrow_typeof(avg(c1)) from d_table GROUP BY c2 ORDER BY c2 A 110.0045 Decimal128(14, 7) B -100.0045 Decimal128(14, 7) +# aggregate_decimal_variance +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from d_table +---- +11025.945028500004 Float64 + +query RT +select var(c1), arrow_typeof(var(c1)) from d_table +---- +11606.257924736847 Float64 + +query TRT +select c2, var_pop(c1), arrow_typeof(var_pop(c1)) from d_table GROUP BY c2 ORDER BY c2 +---- +A 0.00000825 Float64 +B 0.00000825 Float64 + # aggregate_decimal_count_distinct query I select count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table From abc67804756ac1b9cb432691f363326446949f5a Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 26 Nov 2025 14:28:07 +0530 Subject: [PATCH 2/7] native decimal support using accumulator --- .../functions-aggregate/src/variance.rs | 575 +++++++++++++++++- .../sqllogictest/test_files/aggregate.slt | 8 +- 2 files changed, 573 insertions(+), 10 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 33dcae872fa71..6673e09f2799d 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -20,12 +20,23 @@ use arrow::datatypes::FieldRef; use arrow::{ - array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, + array::{ + Array, ArrayRef, AsArray, BooleanArray, FixedSizeBinaryArray, + FixedSizeBinaryBuilder, Float64Array, Float64Builder, PrimitiveArray, + UInt64Array, UInt64Builder, + }, buffer::NullBuffer, compute::kernels::cast, - datatypes::{DataType, Field}, + datatypes::i256, + datatypes::{ + ArrowNumericType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, DecimalType, Field, DECIMAL256_MAX_SCALE, + }, +}; +use datafusion_common::{ + downcast_value, exec_err, not_impl_err, plan_err, DataFusionError, Result, + ScalarValue, }; -use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, @@ -36,8 +47,9 @@ use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; use datafusion_macros::user_doc; +use std::convert::TryInto; use std::mem::{size_of, size_of_val}; -use std::{fmt::Debug, sync::Arc}; +use std::{fmt::Debug, marker::PhantomData, ops::Neg, sync::Arc}; make_udaf_expr_and_func!( VarianceSample, @@ -67,6 +79,61 @@ fn variance_signature() -> Signature { ) } +const DECIMAL_VARIANCE_BINARY_SIZE: i32 = 32; + +fn decimal_overflow_err() -> DataFusionError { + DataFusionError::Execution("Decimal variance overflow".to_string()) +} + +fn i256_to_f64_lossy(value: i256) -> f64 { + const SCALE: f64 = 18446744073709551616.0; // 2^64 + let mut abs = value; + let negative = abs < i256::ZERO; + if negative { + abs = abs.neg(); + } + let bytes = abs.to_le_bytes(); + let mut result = 0f64; + for chunk in bytes.chunks_exact(8).rev() { + let chunk_val = u64::from_le_bytes(chunk.try_into().unwrap()); + result = result * SCALE + chunk_val as f64; + } + if negative { + -result + } else { + result + } +} + +fn decimal_scale(dt: &DataType) -> Option { + match dt { + DataType::Decimal32(_, scale) + | DataType::Decimal64(_, scale) + | DataType::Decimal128(_, scale) + | DataType::Decimal256(_, scale) => Some(*scale), + _ => None, + } +} + +fn decimal_variance_state_fields(name: &str) -> Vec { + vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new( + format_state_name(name, "sum"), + DataType::FixedSizeBinary(DECIMAL_VARIANCE_BINARY_SIZE), + true, + ), + Field::new( + format_state_name(name, "sum_squares"), + DataType::FixedSizeBinary(DECIMAL_VARIANCE_BINARY_SIZE), + true, + ), + ] + .into_iter() + .map(Arc::new) + .collect() +} + fn is_numeric_or_decimal(data_type: &DataType) -> bool { data_type.is_numeric() || matches!( @@ -78,6 +145,460 @@ fn is_numeric_or_decimal(data_type: &DataType) -> bool { ) } +fn i256_from_bytes(bytes: &[u8]) -> Result { + if bytes.len() != DECIMAL_VARIANCE_BINARY_LEN { + return exec_err!( + "Decimal variance state expected {} bytes got {}", + DECIMAL_VARIANCE_BINARY_LEN, + bytes.len() + ); + } + let mut buffer = [0u8; DECIMAL_VARIANCE_BINARY_LEN]; + buffer.copy_from_slice(bytes); + Ok(i256::from_le_bytes(buffer)) +} + +const DECIMAL_VARIANCE_BINARY_LEN: usize = DECIMAL_VARIANCE_BINARY_SIZE as usize; + +fn i256_to_scalar(value: i256) -> ScalarValue { + ScalarValue::FixedSizeBinary( + DECIMAL_VARIANCE_BINARY_SIZE, + Some(value.to_le_bytes().to_vec()), + ) +} + +fn create_decimal_variance_accumulator( + data_type: &DataType, + stats_type: StatsType, +) -> Result>> { + let accumulator = match data_type { + DataType::Decimal32(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + Decimal32Type, + >::try_new( + *scale, stats_type + )?) as Box), + DataType::Decimal64(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + Decimal64Type, + >::try_new( + *scale, stats_type + )?) as Box), + DataType::Decimal128(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + Decimal128Type, + >::try_new( + *scale, stats_type + )?) as Box), + DataType::Decimal256(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + Decimal256Type, + >::try_new( + *scale, stats_type + )?) as Box), + _ => None, + }; + Ok(accumulator) +} + +fn create_decimal_variance_groups_accumulator( + data_type: &DataType, + stats_type: StatsType, +) -> Result>> { + let accumulator = match data_type { + DataType::Decimal32(_, scale) => Some(Box::new( + DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + ) as Box), + DataType::Decimal64(_, scale) => Some(Box::new( + DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + ) as Box), + DataType::Decimal128(_, scale) => Some(Box::new( + DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + ) as Box), + DataType::Decimal256(_, scale) => Some(Box::new( + DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + ) as Box), + _ => None, + }; + Ok(accumulator) +} + +trait DecimalNative: Copy { + fn to_i256(self) -> i256; +} + +impl DecimalNative for i32 { + fn to_i256(self) -> i256 { + i256::from(self) + } +} + +impl DecimalNative for i64 { + fn to_i256(self) -> i256 { + i256::from(self) + } +} + +impl DecimalNative for i128 { + fn to_i256(self) -> i256 { + i256::from_i128(self) + } +} + +impl DecimalNative for i256 { + fn to_i256(self) -> i256 { + self + } +} + +#[derive(Clone, Debug, Default)] +struct DecimalVarianceState { + count: u64, + sum: i256, + sum_squares: i256, +} + +impl DecimalVarianceState { + fn update(&mut self, value: i256) -> Result<()> { + self.count = self.count.checked_add(1).ok_or_else(decimal_overflow_err)?; + self.sum = self + .sum + .checked_add(value) + .ok_or_else(decimal_overflow_err)?; + let square = value.checked_mul(value).ok_or_else(decimal_overflow_err)?; + self.sum_squares = self + .sum_squares + .checked_add(square) + .ok_or_else(decimal_overflow_err)?; + Ok(()) + } + + fn retract(&mut self, value: i256) -> Result<()> { + if self.count == 0 { + return exec_err!("Decimal variance retract underflow"); + } + self.count -= 1; + self.sum = self + .sum + .checked_sub(value) + .ok_or_else(decimal_overflow_err)?; + let square = value.checked_mul(value).ok_or_else(decimal_overflow_err)?; + self.sum_squares = self + .sum_squares + .checked_sub(square) + .ok_or_else(decimal_overflow_err)?; + Ok(()) + } + + fn merge(&mut self, other: &Self) -> Result<()> { + self.count = self + .count + .checked_add(other.count) + .ok_or_else(decimal_overflow_err)?; + self.sum = self + .sum + .checked_add(other.sum) + .ok_or_else(decimal_overflow_err)?; + self.sum_squares = self + .sum_squares + .checked_add(other.sum_squares) + .ok_or_else(decimal_overflow_err)?; + Ok(()) + } + + fn variance(&self, stats_type: StatsType, scale: i8) -> Result> { + if self.count == 0 { + return Ok(None); + } + if matches!(stats_type, StatsType::Sample) && self.count <= 1 { + return Ok(None); + } + + let count_i256 = i256::from_i128(self.count as i128); + let scaled_sum_squares = self + .sum_squares + .checked_mul(count_i256) + .ok_or_else(decimal_overflow_err)?; + let sum_squared = self + .sum + .checked_mul(self.sum) + .ok_or_else(decimal_overflow_err)?; + let numerator = scaled_sum_squares + .checked_sub(sum_squared) + .ok_or_else(decimal_overflow_err)?; + + let numerator = if numerator < i256::ZERO { + i256::ZERO + } else { + numerator + }; + + let denominator_counts = match stats_type { + StatsType::Population => { + let count = self.count as f64; + count * count + } + StatsType::Sample => { + let count = self.count as f64; + count * ((self.count - 1) as f64) + } + }; + + if denominator_counts == 0.0 { + return Ok(None); + } + + let numerator_f64 = i256_to_f64_lossy(numerator); + let scale_factor = 10f64.powi(2 * scale as i32); + Ok(Some(numerator_f64 / (denominator_counts * scale_factor))) + } + + fn to_scalar_state(&self) -> Vec { + vec![ + ScalarValue::from(self.count), + i256_to_scalar(self.sum), + i256_to_scalar(self.sum_squares), + ] + } +} + +#[derive(Debug)] +struct DecimalVarianceAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + state: DecimalVarianceState, + scale: i8, + stats_type: StatsType, + _marker: PhantomData, +} + +impl DecimalVarianceAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + fn try_new(scale: i8, stats_type: StatsType) -> Result { + if scale > DECIMAL256_MAX_SCALE { + return exec_err!( + "Decimal variance does not support scale {} greater than {}", + scale, + DECIMAL256_MAX_SCALE + ); + } + Ok(Self { + state: DecimalVarianceState::default(), + scale, + stats_type, + _marker: PhantomData, + }) + } + + fn convert_array(values: &ArrayRef) -> &PrimitiveArray { + values.as_primitive::() + } +} + +impl Accumulator for DecimalVarianceAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + fn state(&mut self) -> Result> { + Ok(self.state.to_scalar_state()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = Self::convert_array(&values[0]); + for value in array.iter().flatten() { + self.state.update(value.to_i256())?; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = Self::convert_array(&values[0]); + for value in array.iter().flatten() { + self.state.retract(value.to_i256())?; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let sums = downcast_value!(states[1], FixedSizeBinaryArray); + let sum_squares = downcast_value!(states[2], FixedSizeBinaryArray); + + for i in 0..counts.len() { + if counts.is_null(i) { + continue; + } + let count = counts.value(i); + if count == 0 { + continue; + } + let sum = i256_from_bytes(sums.value(i))?; + let sum_sq = i256_from_bytes(sum_squares.value(i))?; + let other = DecimalVarianceState { + count, + sum, + sum_squares: sum_sq, + }; + self.state.merge(&other)?; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + match self.state.variance(self.stats_type, self.scale)? { + Some(v) => Ok(ScalarValue::Float64(Some(v))), + None => Ok(ScalarValue::Float64(None)), + } + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + +#[derive(Debug)] +struct DecimalVarianceGroupsAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + states: Vec, + scale: i8, + stats_type: StatsType, + _marker: PhantomData, +} + +impl DecimalVarianceGroupsAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + fn new(scale: i8, stats_type: StatsType) -> Self { + Self { + states: Vec::new(), + scale, + stats_type, + _marker: PhantomData, + } + } + + fn resize(&mut self, total_num_groups: usize) { + if self.states.len() < total_num_groups { + self.states + .resize(total_num_groups, DecimalVarianceState::default()); + } + } +} + +impl GroupsAccumulator for DecimalVarianceGroupsAccumulator +where + T: DecimalType + ArrowNumericType + Debug, + T::Native: DecimalNative, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = values[0].as_primitive::(); + self.resize(total_num_groups); + for (row, group_index) in group_indices.iter().enumerate() { + if let Some(filter) = opt_filter { + if !filter.value(row) { + continue; + } + } + if array.is_null(row) { + continue; + } + let value = array.value(row).to_i256(); + self.states[*group_index].update(value)?; + } + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let counts = downcast_value!(values[0], UInt64Array); + let sums = downcast_value!(values[1], FixedSizeBinaryArray); + let sum_squares = downcast_value!(values[2], FixedSizeBinaryArray); + self.resize(total_num_groups); + + for (row, group_index) in group_indices.iter().enumerate() { + if counts.is_null(row) { + continue; + } + let count = counts.value(row); + if count == 0 { + continue; + } + let sum = i256_from_bytes(sums.value(row))?; + let sum_sq = i256_from_bytes(sum_squares.value(row))?; + let other = DecimalVarianceState { + count, + sum, + sum_squares: sum_sq, + }; + self.states[*group_index].merge(&other)?; + } + Ok(()) + } + + fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { + let states = emit_to.take_needed(&mut self.states); + let mut builder = Float64Builder::with_capacity(states.len()); + for state in &states { + match state.variance(self.stats_type, self.scale)? { + Some(value) => builder.append_value(value), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + + fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result> { + let states = emit_to.take_needed(&mut self.states); + let mut counts = UInt64Builder::with_capacity(states.len()); + let mut sums = FixedSizeBinaryBuilder::with_capacity( + states.len(), + DECIMAL_VARIANCE_BINARY_SIZE, + ); + let mut sum_squares = FixedSizeBinaryBuilder::with_capacity( + states.len(), + DECIMAL_VARIANCE_BINARY_SIZE, + ); + + for state in states { + counts.append_value(state.count); + sums.append_value(state.sum.to_le_bytes())?; + sum_squares.append_value(state.sum_squares.to_le_bytes())?; + } + + Ok(vec![ + Arc::new(counts.finish()), + Arc::new(sums.finish()), + Arc::new(sum_squares.finish()), + ]) + } + + fn size(&self) -> usize { + self.states.capacity() * size_of::() + } +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the statistical sample variance of a set of numbers.", @@ -133,6 +654,14 @@ impl AggregateUDFImpl for VarianceSample { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; + if args + .input_fields + .first() + .and_then(|field| decimal_scale(field.data_type())) + .is_some() + { + return Ok(decimal_variance_state_fields(name)); + } Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), @@ -148,6 +677,13 @@ impl AggregateUDFImpl for VarianceSample { return not_impl_err!("VAR(DISTINCT) aggregations are not available"); } + if let Some(acc) = create_decimal_variance_accumulator( + acc_args.expr_fields[0].data_type(), + StatsType::Sample, + )? { + return Ok(acc); + } + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) } @@ -161,8 +697,14 @@ impl AggregateUDFImpl for VarianceSample { fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { + if let Some(acc) = create_decimal_variance_groups_accumulator( + args.expr_fields[0].data_type(), + StatsType::Sample, + )? { + return Ok(acc); + } Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample))) } @@ -230,6 +772,14 @@ impl AggregateUDFImpl for VariancePopulation { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; + if args + .input_fields + .first() + .and_then(|field| decimal_scale(field.data_type())) + .is_some() + { + return Ok(decimal_variance_state_fields(name)); + } Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), @@ -245,6 +795,13 @@ impl AggregateUDFImpl for VariancePopulation { return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); } + if let Some(acc) = create_decimal_variance_accumulator( + acc_args.expr_fields[0].data_type(), + StatsType::Population, + )? { + return Ok(acc); + } + Ok(Box::new(VarianceAccumulator::try_new( StatsType::Population, )?)) @@ -260,8 +817,14 @@ impl AggregateUDFImpl for VariancePopulation { fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { + if let Some(acc) = create_decimal_variance_groups_accumulator( + args.expr_fields[0].data_type(), + StatsType::Population, + )? { + return Ok(acc); + } Ok(Box::new(VarianceGroupsAccumulator::new( StatsType::Population, ))) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 51e2c4e75c380..a669f26fbc30a 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5633,18 +5633,18 @@ B -100.0045 Decimal128(14, 7) query RT select var_pop(c1), arrow_typeof(var_pop(c1)) from d_table ---- -11025.945028500004 Float64 +11025.9450285 Float64 query RT select var(c1), arrow_typeof(var(c1)) from d_table ---- -11606.257924736847 Float64 +11606.257924736841 Float64 query TRT select c2, var_pop(c1), arrow_typeof(var_pop(c1)) from d_table GROUP BY c2 ORDER BY c2 ---- -A 0.00000825 Float64 -B 0.00000825 Float64 +A 0.000008249999999997783 Float64 +B 0.000008249999999997783 Float64 # aggregate_decimal_count_distinct query I From 5b92361c75e4f185e8860f09a17895575bae7b39 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 26 Nov 2025 15:34:22 +0530 Subject: [PATCH 3/7] fixed aggregate test --- datafusion/sqllogictest/test_files/aggregate.slt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a669f26fbc30a..b248b86e533f1 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5643,8 +5643,8 @@ select var(c1), arrow_typeof(var(c1)) from d_table query TRT select c2, var_pop(c1), arrow_typeof(var_pop(c1)) from d_table GROUP BY c2 ORDER BY c2 ---- -A 0.000008249999999997783 Float64 -B 0.000008249999999997783 Float64 +A 0.00000825 Float64 +B 0.00000825 Float64 # aggregate_decimal_count_distinct query I From 18af96445979d3e8384c6dfbfd4ebff6016b3126 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 27 Nov 2025 10:41:40 +0530 Subject: [PATCH 4/7] fixed incorrect tests and handled edge cases --- .../functions-aggregate/src/variance.rs | 165 ++++++++++++++++-- 1 file changed, 153 insertions(+), 12 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 6673e09f2799d..654a6625fc5e6 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -203,16 +203,24 @@ fn create_decimal_variance_groups_accumulator( ) -> Result>> { let accumulator = match data_type { DataType::Decimal32(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + DecimalVarianceGroupsAccumulator::::try_new( + *scale, stats_type, + )?, ) as Box), DataType::Decimal64(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + DecimalVarianceGroupsAccumulator::::try_new( + *scale, stats_type, + )?, ) as Box), DataType::Decimal128(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + DecimalVarianceGroupsAccumulator::::try_new( + *scale, stats_type, + )?, ) as Box), DataType::Decimal256(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::new(*scale, stats_type), + DecimalVarianceGroupsAccumulator::::try_new( + *scale, stats_type, + )?, ) as Box), _ => None, }; @@ -323,7 +331,12 @@ impl DecimalVarianceState { .checked_sub(sum_squared) .ok_or_else(decimal_overflow_err)?; - let numerator = if numerator < i256::ZERO { + let negative_numerator = numerator < i256::ZERO; + debug_assert!( + !negative_numerator, + "Decimal variance numerator became negative: {numerator:?}. This indicates precision loss or overflow in intermediate calculations." + ); + let numerator = if negative_numerator { i256::ZERO } else { numerator @@ -479,13 +492,20 @@ where T: DecimalType + ArrowNumericType + Debug, T::Native: DecimalNative, { - fn new(scale: i8, stats_type: StatsType) -> Self { - Self { + fn try_new(scale: i8, stats_type: StatsType) -> Result { + if scale > DECIMAL256_MAX_SCALE { + return exec_err!( + "Decimal variance does not support scale {} greater than {}", + scale, + DECIMAL256_MAX_SCALE + ); + } + Ok(Self { states: Vec::new(), scale, stats_type, _marker: PhantomData, - } + }) } fn resize(&mut self, total_num_groups: usize) { @@ -512,7 +532,7 @@ where self.resize(total_num_groups); for (row, group_index) in group_indices.iter().enumerate() { if let Some(filter) = opt_filter { - if !filter.value(row) { + if !filter.is_valid(row) || !filter.value(row) { continue; } } @@ -1169,7 +1189,8 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { #[cfg(test)] mod tests { - use arrow::array::Decimal128Builder; + use arrow::array::{Decimal128Array, Decimal128Builder, Float64Array}; + use arrow::datatypes::DECIMAL256_MAX_PRECISION; use datafusion_expr::EmitTo; use std::sync::Arc; @@ -1194,12 +1215,16 @@ mod tests { let decimal_array = builder.finish().with_precision_and_scale(10, 3).unwrap(); let array: ArrayRef = Arc::new(decimal_array); - let mut pop_acc = VarianceAccumulator::try_new(StatsType::Population)?; + let mut pop_acc = DecimalVarianceAccumulator::::try_new( + 3, + StatsType::Population, + )?; let pop_input = [Arc::clone(&array)]; pop_acc.update_batch(&pop_input)?; assert_variance(pop_acc.evaluate()?, 11025.9450285); - let mut sample_acc = VarianceAccumulator::try_new(StatsType::Sample)?; + let mut sample_acc = + DecimalVarianceAccumulator::::try_new(3, StatsType::Sample)?; let sample_input = [array]; sample_acc.update_batch(&sample_input)?; assert_variance(sample_acc.evaluate()?, 11606.257924736841); @@ -1207,6 +1232,122 @@ mod tests { Ok(()) } + #[test] + fn variance_decimal_handles_nulls() -> Result<()> { + let mut builder = Decimal128Builder::with_capacity(3); + builder.append_value(100); + builder.append_null(); + builder.append_value(300); + let array = builder.finish().with_precision_and_scale(10, 2).unwrap(); + let array: ArrayRef = Arc::new(array); + + let mut acc = DecimalVarianceAccumulator::::try_new( + 2, + StatsType::Population, + )?; + acc.update_batch(&[Arc::clone(&array)])?; + assert_variance(acc.evaluate()?, 1.0); + Ok(()) + } + + #[test] + fn variance_decimal_empty_input() -> Result<()> { + let array = Decimal128Array::from(Vec::>::new()) + .with_precision_and_scale(10, 2) + .unwrap(); + let array: ArrayRef = Arc::new(array); + + let mut acc = DecimalVarianceAccumulator::::try_new( + 2, + StatsType::Population, + )?; + acc.update_batch(&[array])?; + match acc.evaluate()? { + ScalarValue::Float64(None) => Ok(()), + other => panic!("expected NULL variance for empty input, got {other:?}"), + } + } + + #[test] + fn variance_decimal_single_value_sample() -> Result<()> { + let array = Decimal128Array::from(vec![Some(500)]) + .with_precision_and_scale(10, 2) + .unwrap(); + let array: ArrayRef = Arc::new(array); + let mut acc = + DecimalVarianceAccumulator::::try_new(2, StatsType::Sample)?; + acc.update_batch(&[array])?; + match acc.evaluate()? { + ScalarValue::Float64(None) => Ok(()), + other => { + panic!("expected NULL sample variance for single value, got {other:?}") + } + } + } + + #[test] + fn variance_decimal_groups_mixed_values() -> Result<()> { + let array = + Decimal128Array::from(vec![Some(100), Some(300), Some(-200), Some(-400)]) + .with_precision_and_scale(10, 2) + .unwrap(); + let array: ArrayRef = Arc::new(array); + let mut groups = DecimalVarianceGroupsAccumulator::::try_new( + 2, + StatsType::Population, + )?; + let group_indices = vec![0, 0, 1, 1]; + groups.update_batch(&[Arc::clone(&array)], &group_indices, None, 2)?; + let result = groups.evaluate(EmitTo::All)?; + let result = result.as_any().downcast_ref::().unwrap(); + assert!((result.value(0) - 1.0).abs() < 1e-9); + assert!((result.value(1) - 1.0).abs() < 1e-9); + Ok(()) + } + + #[test] + fn variance_decimal_max_scale() -> Result<()> { + let values = vec![ + ScalarValue::Decimal256( + Some(i256::from_i128(1)), + DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, + ), + ScalarValue::Decimal256( + Some(i256::from_i128(-1)), + DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, + ), + ]; + let array = ScalarValue::iter_to_array(values).unwrap(); + let mut acc = DecimalVarianceAccumulator::::try_new( + DECIMAL256_MAX_SCALE, + StatsType::Population, + )?; + acc.update_batch(&[array])?; + assert_variance(acc.evaluate()?, 1e-152); + Ok(()) + } + + #[test] + fn variance_decimal_retract_batch() -> Result<()> { + let update = Decimal128Array::from(vec![Some(100), Some(200), Some(300)]) + .with_precision_and_scale(10, 2) + .unwrap(); + let retract = Decimal128Array::from(vec![Some(100), Some(200)]) + .with_precision_and_scale(10, 2) + .unwrap(); + + let mut acc = DecimalVarianceAccumulator::::try_new( + 2, + StatsType::Population, + )?; + acc.update_batch(&[Arc::new(update)])?; + acc.retract_batch(&[Arc::new(retract)])?; + assert_variance(acc.evaluate()?, 0.0); + Ok(()) + } + fn assert_variance(value: ScalarValue, expected: f64) { match value { ScalarValue::Float64(Some(actual)) => { From 729c82e5016723be94cc607051f4aeb1f917c21d Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 11 Dec 2025 00:08:01 +0530 Subject: [PATCH 5/7] refactored code and added native decimal variance --- .../functions-aggregate/src/variance.rs | 358 ++++++++++++------ .../sqllogictest/test_files/aggregate.slt | 8 +- 2 files changed, 243 insertions(+), 123 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 654a6625fc5e6..9434e0f523ec5 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -21,16 +21,16 @@ use arrow::datatypes::FieldRef; use arrow::{ array::{ - Array, ArrayRef, AsArray, BooleanArray, FixedSizeBinaryArray, - FixedSizeBinaryBuilder, Float64Array, Float64Builder, PrimitiveArray, - UInt64Array, UInt64Builder, + Array, ArrayRef, AsArray, BooleanArray, Decimal256Builder, FixedSizeBinaryArray, + FixedSizeBinaryBuilder, Float64Array, PrimitiveArray, UInt64Array, UInt64Builder, }, buffer::NullBuffer, compute::kernels::cast, datatypes::i256, datatypes::{ ArrowNumericType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, - Decimal64Type, DecimalType, Field, DECIMAL256_MAX_SCALE, + Decimal64Type, DecimalType, Field, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, }, }; use datafusion_common::{ @@ -40,16 +40,17 @@ use datafusion_common::{ use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, - Accumulator, AggregateUDFImpl, Coercion, Documentation, GroupsAccumulator, Signature, - TypeSignature, TypeSignatureClass, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, + Volatility, }; use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; use datafusion_macros::user_doc; -use std::convert::TryInto; use std::mem::{size_of, size_of_val}; -use std::{fmt::Debug, marker::PhantomData, ops::Neg, sync::Arc}; +#[cfg(test)] +use std::{convert::TryInto, ops::Neg}; +use std::{fmt::Debug, marker::PhantomData, sync::Arc}; make_udaf_expr_and_func!( VarianceSample, @@ -68,23 +69,17 @@ make_udaf_expr_and_func!( ); fn variance_signature() -> Signature { - Signature::one_of( - vec![ - TypeSignature::Numeric(1), - TypeSignature::Coercible(vec![Coercion::new_exact( - TypeSignatureClass::Decimal, - )]), - ], - Volatility::Immutable, - ) + Signature::numeric(1, Volatility::Immutable) } const DECIMAL_VARIANCE_BINARY_SIZE: i32 = 32; +const DECIMAL_VARIANCE_SCALE_INCREMENT: i8 = 6; fn decimal_overflow_err() -> DataFusionError { DataFusionError::Execution("Decimal variance overflow".to_string()) } +#[cfg(test)] fn i256_to_f64_lossy(value: i256) -> f64 { const SCALE: f64 = 18446744073709551616.0; // 2^64 let mut abs = value; @@ -115,6 +110,27 @@ fn decimal_scale(dt: &DataType) -> Option { } } +#[derive(Clone, Copy, Debug)] +struct DecimalVarianceParams { + input_scale: i8, + result_precision: u8, + result_scale: i8, +} + +fn decimal_variance_params(data_type: &DataType) -> Option { + decimal_scale(data_type).map(|input_scale| { + let base_scale = input_scale.saturating_mul(2); + let target_scale = base_scale + .saturating_add(DECIMAL_VARIANCE_SCALE_INCREMENT) + .min(DECIMAL256_MAX_SCALE); + DecimalVarianceParams { + input_scale, + result_precision: DECIMAL256_MAX_PRECISION, + result_scale: target_scale, + } + }) +} + fn decimal_variance_state_fields(name: &str) -> Vec { vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -134,15 +150,13 @@ fn decimal_variance_state_fields(name: &str) -> Vec { .collect() } -fn is_numeric_or_decimal(data_type: &DataType) -> bool { - data_type.is_numeric() - || matches!( - data_type, - DataType::Decimal32(_, _) - | DataType::Decimal64(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - ) +fn pow10_i256(exp: u32) -> Result { + let mut value = i256::from_i128(1); + let ten = i256::from_i128(10); + for _ in 0..exp { + value = value.checked_mul(ten).ok_or_else(decimal_overflow_err)?; + } + Ok(value) } fn i256_from_bytes(bytes: &[u8]) -> Result { @@ -171,26 +185,27 @@ fn create_decimal_variance_accumulator( data_type: &DataType, stats_type: StatsType, ) -> Result>> { - let accumulator = match data_type { - DataType::Decimal32(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + let Some(params) = decimal_variance_params(data_type) else { + return Ok(None); + }; + let accumulator: Option> = match data_type { + DataType::Decimal32(_, _) => Some(Box::new(DecimalVarianceAccumulator::< Decimal32Type, - >::try_new( - *scale, stats_type - )?) as Box), - DataType::Decimal64(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + >::try_new(params, stats_type)?) + as Box), + DataType::Decimal64(_, _) => Some(Box::new(DecimalVarianceAccumulator::< Decimal64Type, - >::try_new( - *scale, stats_type - )?) as Box), - DataType::Decimal128(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + >::try_new(params, stats_type)?) + as Box), + DataType::Decimal128(_, _) => Some(Box::new(DecimalVarianceAccumulator::< Decimal128Type, >::try_new( - *scale, stats_type + params, stats_type )?) as Box), - DataType::Decimal256(_, scale) => Some(Box::new(DecimalVarianceAccumulator::< + DataType::Decimal256(_, _) => Some(Box::new(DecimalVarianceAccumulator::< Decimal256Type, >::try_new( - *scale, stats_type + params, stats_type )?) as Box), _ => None, }; @@ -201,27 +216,28 @@ fn create_decimal_variance_groups_accumulator( data_type: &DataType, stats_type: StatsType, ) -> Result>> { - let accumulator = match data_type { - DataType::Decimal32(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::try_new( - *scale, stats_type, - )?, - ) as Box), - DataType::Decimal64(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::try_new( - *scale, stats_type, - )?, - ) as Box), - DataType::Decimal128(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::try_new( - *scale, stats_type, - )?, - ) as Box), - DataType::Decimal256(_, scale) => Some(Box::new( - DecimalVarianceGroupsAccumulator::::try_new( - *scale, stats_type, - )?, - ) as Box), + let Some(params) = decimal_variance_params(data_type) else { + return Ok(None); + }; + let accumulator: Option> = match data_type { + DataType::Decimal32(_, _) => Some(Box::new(DecimalVarianceGroupsAccumulator::< + Decimal32Type, + >::try_new(params, stats_type)?) + as Box), + DataType::Decimal64(_, _) => Some(Box::new(DecimalVarianceGroupsAccumulator::< + Decimal64Type, + >::try_new(params, stats_type)?) + as Box), + DataType::Decimal128(_, _) => Some(Box::new(DecimalVarianceGroupsAccumulator::< + Decimal128Type, + >::try_new( + params, stats_type + )?) as Box), + DataType::Decimal256(_, _) => Some(Box::new(DecimalVarianceGroupsAccumulator::< + Decimal256Type, + >::try_new( + params, stats_type + )?) as Box), _ => None, }; Ok(accumulator) @@ -310,7 +326,11 @@ impl DecimalVarianceState { Ok(()) } - fn variance(&self, stats_type: StatsType, scale: i8) -> Result> { + fn variance_decimal( + &self, + stats_type: StatsType, + params: DecimalVarianceParams, + ) -> Result> { if self.count == 0 { return Ok(None); } @@ -343,23 +363,40 @@ impl DecimalVarianceState { }; let denominator_counts = match stats_type { - StatsType::Population => { - let count = self.count as f64; - count * count - } - StatsType::Sample => { - let count = self.count as f64; - count * ((self.count - 1) as f64) - } + StatsType::Population => count_i256 + .checked_mul(count_i256) + .ok_or_else(decimal_overflow_err)?, + StatsType::Sample => count_i256 + .checked_mul(i256::from_i128((self.count - 1) as i128)) + .ok_or_else(decimal_overflow_err)?, }; - if denominator_counts == 0.0 { + if denominator_counts == i256::ZERO { return Ok(None); } - let numerator_f64 = i256_to_f64_lossy(numerator); - let scale_factor = 10f64.powi(2 * scale as i32); - Ok(Some(numerator_f64 / (denominator_counts * scale_factor))) + let two_scale = params.input_scale.saturating_mul(2); + if params.result_scale >= two_scale { + let up = params.result_scale - two_scale; + let factor = pow10_i256(up as u32)?; + let scaled_numerator = numerator + .checked_mul(factor) + .ok_or_else(decimal_overflow_err)?; + let value = scaled_numerator + .checked_div(denominator_counts) + .ok_or_else(decimal_overflow_err)?; + return Ok(Some(value)); + } + + let down = two_scale - params.result_scale; + let factor = pow10_i256(down as u32)?; + let scaled_numerator = numerator + .checked_div(factor) + .ok_or_else(decimal_overflow_err)?; + let value = scaled_numerator + .checked_div(denominator_counts) + .ok_or_else(decimal_overflow_err)?; + Ok(Some(value)) } fn to_scalar_state(&self) -> Vec { @@ -378,7 +415,7 @@ where T::Native: DecimalNative, { state: DecimalVarianceState, - scale: i8, + params: DecimalVarianceParams, stats_type: StatsType, _marker: PhantomData, } @@ -388,17 +425,17 @@ where T: DecimalType + ArrowNumericType + Debug, T::Native: DecimalNative, { - fn try_new(scale: i8, stats_type: StatsType) -> Result { - if scale > DECIMAL256_MAX_SCALE { + fn try_new(params: DecimalVarianceParams, stats_type: StatsType) -> Result { + if params.input_scale > DECIMAL256_MAX_SCALE { return exec_err!( "Decimal variance does not support scale {} greater than {}", - scale, + params.input_scale, DECIMAL256_MAX_SCALE ); } Ok(Self { state: DecimalVarianceState::default(), - scale, + params, stats_type, _marker: PhantomData, }) @@ -460,9 +497,29 @@ where } fn evaluate(&mut self) -> Result { - match self.state.variance(self.stats_type, self.scale)? { - Some(v) => Ok(ScalarValue::Float64(Some(v))), - None => Ok(ScalarValue::Float64(None)), + let value = self.state.variance_decimal(self.stats_type, self.params)?; + match value { + Some(v) => { + if Decimal256Type::validate_decimal_precision( + v, + self.params.result_precision, + self.params.result_scale, + ) + .is_err() + { + return Err(decimal_overflow_err()); + } + Ok(ScalarValue::Decimal256( + Some(v), + self.params.result_precision, + self.params.result_scale, + )) + } + None => Ok(ScalarValue::Decimal256( + None, + self.params.result_precision, + self.params.result_scale, + )), } } @@ -482,7 +539,7 @@ where T::Native: DecimalNative, { states: Vec, - scale: i8, + params: DecimalVarianceParams, stats_type: StatsType, _marker: PhantomData, } @@ -492,17 +549,17 @@ where T: DecimalType + ArrowNumericType + Debug, T::Native: DecimalNative, { - fn try_new(scale: i8, stats_type: StatsType) -> Result { - if scale > DECIMAL256_MAX_SCALE { + fn try_new(params: DecimalVarianceParams, stats_type: StatsType) -> Result { + if params.input_scale > DECIMAL256_MAX_SCALE { return exec_err!( "Decimal variance does not support scale {} greater than {}", - scale, + params.input_scale, DECIMAL256_MAX_SCALE ); } Ok(Self { states: Vec::new(), - scale, + params, stats_type, _marker: PhantomData, }) @@ -579,14 +636,29 @@ where fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { let states = emit_to.take_needed(&mut self.states); - let mut builder = Float64Builder::with_capacity(states.len()); + let mut builder = Decimal256Builder::with_capacity(states.len()); for state in &states { - match state.variance(self.stats_type, self.scale)? { - Some(value) => builder.append_value(value), + match state.variance_decimal(self.stats_type, self.params)? { + Some(value) => { + if Decimal256Type::validate_decimal_precision( + value, + self.params.result_precision, + self.params.result_scale, + ) + .is_err() + { + return Err(decimal_overflow_err()); + } + builder.append_value(value) + } None => builder.append_null(), } } - Ok(Arc::new(builder.finish())) + let array = builder.finish().with_precision_and_scale( + self.params.result_precision, + self.params.result_scale, + )?; + Ok(Arc::new(array)) } fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result> { @@ -668,7 +740,16 @@ impl AggregateUDFImpl for VarianceSample { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { + if let Some(params) = decimal_variance_params(&arg_types[0]) { + return Ok(DataType::Decimal256( + params.result_precision, + params.result_scale, + )); + } + if !arg_types[0].is_numeric() { + return plan_err!("Variance requires numeric input types"); + } Ok(DataType::Float64) } @@ -783,7 +864,14 @@ impl AggregateUDFImpl for VariancePopulation { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if !is_numeric_or_decimal(&arg_types[0]) { + if let Some(params) = decimal_variance_params(&arg_types[0]) { + return Ok(DataType::Decimal256( + params.result_precision, + params.result_scale, + )); + } + + if !arg_types[0].is_numeric() { return plan_err!("Variance requires numeric input types"); } @@ -1189,7 +1277,9 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { #[cfg(test)] mod tests { - use arrow::array::{Decimal128Array, Decimal128Builder, Float64Array}; + use arrow::array::{ + Decimal128Array, Decimal128Builder, Decimal256Array, Float64Array, + }; use arrow::datatypes::DECIMAL256_MAX_PRECISION; use datafusion_expr::EmitTo; use std::sync::Arc; @@ -1215,19 +1305,27 @@ mod tests { let decimal_array = builder.finish().with_precision_and_scale(10, 3).unwrap(); let array: ArrayRef = Arc::new(decimal_array); + let params = decimal_variance_params(&DataType::Decimal128(10, 3)) + .expect("decimal params"); let mut pop_acc = DecimalVarianceAccumulator::::try_new( - 3, + params, StatsType::Population, )?; let pop_input = [Arc::clone(&array)]; pop_acc.update_batch(&pop_input)?; - assert_variance(pop_acc.evaluate()?, 11025.9450285); + assert_decimal_variance(pop_acc.evaluate()?, 11025.9450285, params.result_scale); - let mut sample_acc = - DecimalVarianceAccumulator::::try_new(3, StatsType::Sample)?; + let mut sample_acc = DecimalVarianceAccumulator::::try_new( + params, + StatsType::Sample, + )?; let sample_input = [array]; sample_acc.update_batch(&sample_input)?; - assert_variance(sample_acc.evaluate()?, 11606.257924736841); + assert_decimal_variance( + sample_acc.evaluate()?, + 11606.257924736841, + params.result_scale, + ); Ok(()) } @@ -1241,12 +1339,14 @@ mod tests { let array = builder.finish().with_precision_and_scale(10, 2).unwrap(); let array: ArrayRef = Arc::new(array); + let params = decimal_variance_params(&DataType::Decimal128(10, 2)) + .expect("decimal params"); let mut acc = DecimalVarianceAccumulator::::try_new( - 2, + params, StatsType::Population, )?; acc.update_batch(&[Arc::clone(&array)])?; - assert_variance(acc.evaluate()?, 1.0); + assert_decimal_variance(acc.evaluate()?, 1.0, params.result_scale); Ok(()) } @@ -1257,13 +1357,15 @@ mod tests { .unwrap(); let array: ArrayRef = Arc::new(array); + let params = decimal_variance_params(&DataType::Decimal128(10, 2)) + .expect("decimal params"); let mut acc = DecimalVarianceAccumulator::::try_new( - 2, + params, StatsType::Population, )?; acc.update_batch(&[array])?; match acc.evaluate()? { - ScalarValue::Float64(None) => Ok(()), + ScalarValue::Decimal256(None, ..) => Ok(()), other => panic!("expected NULL variance for empty input, got {other:?}"), } } @@ -1274,11 +1376,15 @@ mod tests { .with_precision_and_scale(10, 2) .unwrap(); let array: ArrayRef = Arc::new(array); - let mut acc = - DecimalVarianceAccumulator::::try_new(2, StatsType::Sample)?; + let params = decimal_variance_params(&DataType::Decimal128(10, 2)) + .expect("decimal params"); + let mut acc = DecimalVarianceAccumulator::::try_new( + params, + StatsType::Sample, + )?; acc.update_batch(&[array])?; match acc.evaluate()? { - ScalarValue::Float64(None) => Ok(()), + ScalarValue::Decimal256(None, ..) => Ok(()), other => { panic!("expected NULL sample variance for single value, got {other:?}") } @@ -1292,16 +1398,22 @@ mod tests { .with_precision_and_scale(10, 2) .unwrap(); let array: ArrayRef = Arc::new(array); + let params = decimal_variance_params(&DataType::Decimal128(10, 2)) + .expect("decimal params"); let mut groups = DecimalVarianceGroupsAccumulator::::try_new( - 2, + params, StatsType::Population, )?; let group_indices = vec![0, 0, 1, 1]; groups.update_batch(&[Arc::clone(&array)], &group_indices, None, 2)?; let result = groups.evaluate(EmitTo::All)?; - let result = result.as_any().downcast_ref::().unwrap(); - assert!((result.value(0) - 1.0).abs() < 1e-9); - assert!((result.value(1) - 1.0).abs() < 1e-9); + let result = result.as_any().downcast_ref::().unwrap(); + let v0 = + i256_to_f64_lossy(result.value(0)) / 10f64.powi(params.result_scale as i32); + let v1 = + i256_to_f64_lossy(result.value(1)) / 10f64.powi(params.result_scale as i32); + assert!((v0 - 1.0).abs() < 1e-9); + assert!((v1 - 1.0).abs() < 1e-9); Ok(()) } @@ -1320,12 +1432,17 @@ mod tests { ), ]; let array = ScalarValue::iter_to_array(values).unwrap(); - let mut acc = DecimalVarianceAccumulator::::try_new( + let params = decimal_variance_params(&DataType::Decimal256( + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + )) + .expect("decimal params"); + let mut acc = DecimalVarianceAccumulator::::try_new( + params, StatsType::Population, )?; acc.update_batch(&[array])?; - assert_variance(acc.evaluate()?, 1e-152); + assert_decimal_variance(acc.evaluate()?, 1e-152, params.result_scale); Ok(()) } @@ -1338,23 +1455,26 @@ mod tests { .with_precision_and_scale(10, 2) .unwrap(); + let params = decimal_variance_params(&DataType::Decimal128(10, 2)) + .expect("decimal params"); let mut acc = DecimalVarianceAccumulator::::try_new( - 2, + params, StatsType::Population, )?; acc.update_batch(&[Arc::new(update)])?; acc.retract_batch(&[Arc::new(retract)])?; - assert_variance(acc.evaluate()?, 0.0); + assert_decimal_variance(acc.evaluate()?, 0.0, params.result_scale); Ok(()) } - fn assert_variance(value: ScalarValue, expected: f64) { - match value { - ScalarValue::Float64(Some(actual)) => { - assert!((actual - expected).abs() < 1e-9) + fn assert_decimal_variance(value: ScalarValue, expected: f64, scale: i8) { + let actual = match value { + ScalarValue::Decimal256(Some(v), ..) => { + i256_to_f64_lossy(v) / 10f64.powi(scale as i32) } - other => panic!("expected Float64 result, got {other:?}"), - } + other => panic!("expected Decimal256 result, got {other:?}"), + }; + assert!((actual - expected).abs() < 1e-9); } #[test] diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index b248b86e533f1..507e3e5754b61 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5633,18 +5633,18 @@ B -100.0045 Decimal128(14, 7) query RT select var_pop(c1), arrow_typeof(var_pop(c1)) from d_table ---- -11025.9450285 Float64 +11025.9450285 Decimal256(76, 12) query RT select var(c1), arrow_typeof(var(c1)) from d_table ---- -11606.257924736841 Float64 +11606.257924736842 Decimal256(76, 12) query TRT select c2, var_pop(c1), arrow_typeof(var_pop(c1)) from d_table GROUP BY c2 ORDER BY c2 ---- -A 0.00000825 Float64 -B 0.00000825 Float64 +A 0.00000825 Decimal256(76, 12) +B 0.00000825 Decimal256(76, 12) # aggregate_decimal_count_distinct query I From de2f0ac541769bd50177621f9ece1803926e9cba Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 11 Dec 2025 00:20:26 +0530 Subject: [PATCH 6/7] fix clippy issues --- .../functions-aggregate/src/variance.rs | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 48ca5f080b9e9..03a383f746886 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -28,14 +28,13 @@ use arrow::{ compute::kernels::cast, datatypes::i256, datatypes::{ - ArrowNumericType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, - Decimal64Type, DecimalType, Field, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, + ArrowNumericType, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, + Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, Field, }, }; use datafusion_common::{ - downcast_value, exec_err, not_impl_err, plan_err, DataFusionError, Result, - ScalarValue, + DataFusionError, Result, ScalarValue, downcast_value, exec_err, not_impl_err, + plan_err, }; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, @@ -93,11 +92,7 @@ fn i256_to_f64_lossy(value: i256) -> f64 { let chunk_val = u64::from_le_bytes(chunk.try_into().unwrap()); result = result * SCALE + chunk_val as f64; } - if negative { - -result - } else { - result - } + if negative { -result } else { result } } fn decimal_scale(dt: &DataType) -> Option { @@ -588,10 +583,10 @@ where let array = values[0].as_primitive::(); self.resize(total_num_groups); for (row, group_index) in group_indices.iter().enumerate() { - if let Some(filter) = opt_filter { - if !filter.is_valid(row) || !filter.value(row) { - continue; - } + if let Some(filter) = opt_filter + && (!filter.is_valid(row) || !filter.value(row)) + { + continue; } if array.is_null(row) { continue; From 89e8e665799a2ef8cc1d3c7a27fdc16f233913fc Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Fri, 12 Dec 2025 22:49:20 +0530 Subject: [PATCH 7/7] Scale arithmetic now uses wider integers --- .../functions-aggregate/src/variance.rs | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 03a383f746886..0daa3acadf578 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -108,20 +108,22 @@ fn decimal_scale(dt: &DataType) -> Option { #[derive(Clone, Copy, Debug)] struct DecimalVarianceParams { input_scale: i8, + full_two_scale: i32, result_precision: u8, result_scale: i8, } fn decimal_variance_params(data_type: &DataType) -> Option { decimal_scale(data_type).map(|input_scale| { - let base_scale = input_scale.saturating_mul(2); - let target_scale = base_scale - .saturating_add(DECIMAL_VARIANCE_SCALE_INCREMENT) - .min(DECIMAL256_MAX_SCALE); + let input_scale_i32 = input_scale as i32; + let full_two_scale = input_scale_i32.saturating_mul(2); + let target_scale = (full_two_scale + DECIMAL_VARIANCE_SCALE_INCREMENT as i32) + .min(DECIMAL256_MAX_SCALE as i32); DecimalVarianceParams { input_scale, + full_two_scale, result_precision: DECIMAL256_MAX_PRECISION, - result_scale: target_scale, + result_scale: target_scale as i8, } }) } @@ -370,9 +372,10 @@ impl DecimalVarianceState { return Ok(None); } - let two_scale = params.input_scale.saturating_mul(2); - if params.result_scale >= two_scale { - let up = params.result_scale - two_scale; + let two_scale = params.full_two_scale; + let result_scale = params.result_scale as i32; + if result_scale >= two_scale { + let up = result_scale - two_scale; let factor = pow10_i256(up as u32)?; let scaled_numerator = numerator .checked_mul(factor) @@ -383,7 +386,7 @@ impl DecimalVarianceState { return Ok(Some(value)); } - let down = two_scale - params.result_scale; + let down = two_scale - result_scale; let factor = pow10_i256(down as u32)?; let scaled_numerator = numerator .checked_div(factor) @@ -1437,7 +1440,18 @@ mod tests { StatsType::Population, )?; acc.update_batch(&[array])?; - assert_decimal_variance(acc.evaluate()?, 1e-152, params.result_scale); + match acc.evaluate()? { + ScalarValue::Decimal256(Some(raw), precision, scale) => { + // With input scale at the maximum, 2*scale exceeds the maximum representable + // scale for Decimal256, so the result rounds down to zero at scale=76. + assert_eq!( + raw, + i256::ZERO, + "variance should round to zero at max scale (precision={precision}, scale={scale})" + ); + } + other => panic!("expected Decimal256 result, got {other:?}"), + } Ok(()) }