Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions native/core/benches/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("avg_decimal_comet", |b| {
let comet_avg_decimal = Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new(
Arc::clone(&c1),
"avg",
DataType::Decimal128(38, 10),
DataType::Decimal128(38, 10),
)));
Expand Down Expand Up @@ -96,11 +95,9 @@ fn criterion_benchmark(c: &mut Criterion) {
});

group.bench_function("sum_decimal_comet", |b| {
let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new(
"sum",
Arc::clone(&c1),
DataType::Decimal128(38, 10),
)));
let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
SumDecimal::try_new(Arc::clone(&c1), DataType::Decimal128(38, 10)).unwrap(),
));
b.to_async(&rt).iter(|| {
black_box(agg_test(
partitions,
Expand Down
27 changes: 9 additions & 18 deletions native/core/src/execution/datafusion/expressions/avg_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr};
use std::{any::Any, sync::Arc};

use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
use arrow_array::ArrowNativeTypeOp;
use arrow_data::decimal::{
validate_decimal_precision, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use datafusion::logical_expr::Volatility::Immutable;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
Expand All @@ -43,7 +42,6 @@ use DataType::*;
/// AVG aggregate expression
#[derive(Debug, Clone)]
pub struct AvgDecimal {
name: String,
signature: Signature,
expr: Arc<dyn PhysicalExpr>,
sum_data_type: DataType,
Expand All @@ -52,14 +50,8 @@ pub struct AvgDecimal {

impl AvgDecimal {
/// Create a new AVG aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
result_type: DataType,
sum_type: DataType,
) -> Self {
pub fn new(expr: Arc<dyn PhysicalExpr>, result_type: DataType, sum_type: DataType) -> Self {
Self {
name: name.into(),
signature: Signature::user_defined(Immutable),
expr,
result_data_type: result_type,
Expand Down Expand Up @@ -95,20 +87,20 @@ impl AggregateUDFImpl for AvgDecimal {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
format_state_name(&self.name, "sum"),
format_state_name(self.name(), "sum"),
self.sum_data_type.clone(),
true,
),
Field::new(
format_state_name(&self.name, "count"),
format_state_name(self.name(), "count"),
DataType::Int64,
true,
),
])
}

fn name(&self) -> &str {
&self.name
"avg"
}

fn reverse_expr(&self) -> ReversedUDAF {
Expand Down Expand Up @@ -169,8 +161,7 @@ impl PartialEq<dyn Any> for AvgDecimal {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.sum_data_type == x.sum_data_type
self.sum_data_type == x.sum_data_type
&& self.result_data_type == x.result_data_type
&& self.expr.eq(&x.expr)
})
Expand Down Expand Up @@ -212,7 +203,7 @@ impl AvgDecimalAccumulator {
None => (v, false),
};

if is_overflow || validate_decimal_precision(new_sum, self.sum_precision).is_err() {
if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
// Overflow: set buffer accumulator to null
self.is_not_null = false;
return;
Expand Down Expand Up @@ -380,7 +371,7 @@ impl AvgDecimalGroupsAccumulator {
let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value);
self.counts[group_index] += 1;

if is_overflow || validate_decimal_precision(new_sum, self.sum_precision).is_err() {
if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
// Overflow: set buffer accumulator to null
self.is_not_null.set_bit(group_index, false);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use arrow::{
datatypes::{Decimal128Type, DecimalType},
record_batch::RecordBatch,
};
use arrow_schema::{DataType, Schema};
use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use arrow_schema::{DataType, Schema, DECIMAL128_MAX_PRECISION};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{DataFusionError, ScalarValue};
Expand Down Expand Up @@ -171,3 +172,15 @@ impl PhysicalExpr for CheckOverflow {
self.hash(&mut s);
}
}

/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
/// instead of Err to avoid the cost of formatting the error strings and is
/// optimized to remove a memcpy that exists in the original function
/// we can remove this code once we upgrade to a version of arrow-rs that
/// includes https://github.com/apache/arrow-rs/pull/6419
#[inline]
pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
precision <= DECIMAL128_MAX_PRECISION
&& value >= MIN_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
&& value <= MAX_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
}
151 changes: 125 additions & 26 deletions native/core/src/execution/datafusion/expressions/sum_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
use crate::unlikely;
use arrow::{
array::BooleanBufferBuilder,
Expand All @@ -23,11 +24,10 @@ use arrow::{
use arrow_array::{
cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
};
use arrow_data::decimal::validate_decimal_precision;
use arrow_schema::{DataType, Field};
use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator};
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{Result as DFResult, ScalarValue};
use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature};
Expand All @@ -36,37 +36,37 @@ use std::{any::Any, ops::BitAnd, sync::Arc};

#[derive(Debug)]
pub struct SumDecimal {
name: String,
/// Aggregate function signature
signature: Signature,
/// The expression that provides the input decimal values to be summed
expr: Arc<dyn PhysicalExpr>,

/// The data type of the SUM result
/// The data type of the SUM result. This will always be a decimal type
/// with the same precision and scale as specified in this struct
result_type: DataType,

/// Decimal precision and scale
/// Decimal precision
precision: u8,
/// Decimal scale
scale: i8,

/// Whether the result is nullable
nullable: bool,
}

impl SumDecimal {
pub fn new(name: impl Into<String>, expr: Arc<dyn PhysicalExpr>, data_type: DataType) -> Self {
pub fn try_new(expr: Arc<dyn PhysicalExpr>, data_type: DataType) -> DFResult<Self> {
// The `data_type` is the SUM result type passed from Spark side
let (precision, scale) = match data_type {
DataType::Decimal128(p, s) => (p, s),
_ => unreachable!(),
_ => {
return Err(DataFusionError::Internal(
"Invalid data type for SumDecimal".into(),
))
}
};
Self {
name: name.into(),
Ok(Self {
signature: Signature::user_defined(Immutable),
expr,
result_type: data_type,
precision,
scale,
nullable: true,
}
})
}
}

Expand All @@ -84,14 +84,14 @@ impl AggregateUDFImpl for SumDecimal {

fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<Field>> {
let fields = vec![
Field::new(&self.name, self.result_type.clone(), self.nullable),
Field::new(self.name(), self.result_type.clone(), self.is_nullable()),
Field::new("is_empty", DataType::Boolean, false),
];
Ok(fields)
}

fn name(&self) -> &str {
&self.name
"sum"
}

fn signature(&self) -> &Signature {
Expand Down Expand Up @@ -127,19 +127,22 @@ impl AggregateUDFImpl for SumDecimal {
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Identical
}

fn is_nullable(&self) -> bool {
// SumDecimal is always nullable because overflows can cause null values
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if this is true for ANSI.
It looks the previous code is also hardcoding true, but this may be a good time to file an issue if there is not yet.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I filed #961

true
}
}

impl PartialEq<dyn Any> for SumDecimal {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.precision == x.precision
&& self.scale == x.scale
&& self.nullable == x.nullable
&& self.result_type == x.result_type
&& self.expr.eq(&x.expr)
// note that we do not compare result_type because this
// is guaranteed to match if the precision and scale
// match
self.precision == x.precision && self.scale == x.scale && self.expr.eq(&x.expr)
})
.unwrap_or(false)
}
Expand Down Expand Up @@ -170,7 +173,7 @@ impl SumDecimalAccumulator {
let v = unsafe { values.value_unchecked(idx) };
let (new_sum, is_overflow) = self.sum.overflowing_add(v);

if is_overflow || validate_decimal_precision(new_sum, self.precision).is_err() {
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
// Overflow: set buffer accumulator to null
self.is_not_null = false;
return;
Expand Down Expand Up @@ -312,7 +315,7 @@ impl SumDecimalGroupsAccumulator {
self.is_empty.set_bit(group_index, false);
let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);

if is_overflow || validate_decimal_precision(new_sum, self.precision).is_err() {
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
// Overflow: set buffer accumulator to null
self.is_not_null.set_bit(group_index, false);
return;
Expand Down Expand Up @@ -478,3 +481,99 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
+ self.is_not_null.capacity() / 8
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::*;
use arrow_array::builder::{Decimal128Builder, StringBuilder};
use arrow_array::RecordBatch;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_expr::AggregateUDF;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::{Column, Literal};
use futures::StreamExt;

#[test]
fn invalid_data_type() {
let expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
assert!(SumDecimal::try_new(expr, DataType::Int32).is_err());
}

#[tokio::test]
async fn sum_no_overflow() -> Result<()> {
let num_rows = 8192;
let batch = create_record_batch(num_rows);
let mut batches = Vec::new();
for _ in 0..10 {
batches.push(batch.clone());
}
let partitions = &[batches];
let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
let c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

let data_type = DataType::Decimal128(8, 2);
let schema = Arc::clone(&partitions[0][0].schema());
let scan: Arc<dyn ExecutionPlan> =
Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), None).unwrap());

let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
Arc::clone(&c1),
data_type.clone(),
)?));

let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
.schema(Arc::clone(&schema))
.alias("sum")
.with_ignore_nulls(false)
.with_distinct(false)
.build()?;

let aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]),
vec![aggr_expr],
vec![None], // no filter expressions
scan,
Arc::clone(&schema),
)?);

let mut stream = aggregate
.execute(0, Arc::new(TaskContext::default()))
.unwrap();
while let Some(batch) = stream.next().await {
let _batch = batch?;
}

Ok(())
}

fn create_record_batch(num_rows: usize) -> RecordBatch {
let mut decimal_builder = Decimal128Builder::with_capacity(num_rows);
let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
for i in 0..num_rows {
decimal_builder.append_value(i as i128);
string_builder.append_value(format!("this is string #{}", i % 1024));
}
let decimal_array = Arc::new(decimal_builder.finish());
let string_array = Arc::new(string_builder.finish());

let mut fields = vec![];
let mut columns: Vec<ArrayRef> = vec![];

// string column
fields.push(Field::new("c0", DataType::Utf8, false));
columns.push(string_array);

// decimal column
fields.push(Field::new("c1", DataType::Decimal128(38, 10), false));
columns.push(decimal_array);

let schema = Schema::new(fields);
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
}
}
Loading