From 2ff71b3b42fe3cd44f8de21e2b8a3d4359f02800 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 17 Sep 2024 16:42:08 -0600 Subject: [PATCH 01/19] agg bench --- native/core/benches/aggregate.rs | 111 +++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 native/core/benches/aggregate.rs diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs new file mode 100644 index 0000000000..10ad5a9e72 --- /dev/null +++ b/native/core/benches/aggregate.rs @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.use arrow::array::{ArrayRef, BooleanBuilder, Int32Builder, RecordBatch, StringBuilder}; + +use arrow::compute::filter_record_batch; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow_array::builder::{BooleanBuilder, Int32Builder, StringBuilder}; +use arrow_array::{ArrayRef, RecordBatch}; +use comet::execution::operators::comet_filter_record_batch; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use std::sync::Arc; +use std::time::Duration; + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("filter"); + + let num_rows = 8192; + let num_int_cols = 4; + let num_string_cols = 4; + + let batch = create_record_batch(num_rows, num_int_cols, num_string_cols); + + // create some different predicates + let mut predicate_select_few = BooleanBuilder::with_capacity(num_rows); + let mut predicate_select_many = BooleanBuilder::with_capacity(num_rows); + let mut predicate_select_all = BooleanBuilder::with_capacity(num_rows); + for i in 0..num_rows { + predicate_select_few.append_value(i % 10 == 0); + predicate_select_many.append_value(i % 10 > 0); + predicate_select_all.append_value(true); + } + let predicate_select_few = predicate_select_few.finish(); + let predicate_select_many = predicate_select_many.finish(); + let predicate_select_all = predicate_select_all.finish(); + + // baseline uses Arrow's filter_record_batch method + group.bench_function("arrow_filter_record_batch - few rows selected", |b| { + b.iter(|| filter_record_batch(black_box(&batch), black_box(&predicate_select_few))) + }); + group.bench_function("arrow_filter_record_batch - many rows selected", |b| { + b.iter(|| filter_record_batch(black_box(&batch), black_box(&predicate_select_many))) + }); + group.bench_function("arrow_filter_record_batch - all rows selected", |b| { + b.iter(|| filter_record_batch(black_box(&batch), black_box(&predicate_select_all))) + }); + + group.bench_function("comet_filter_record_batch - few rows selected", |b| { + b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_few))) + }); + group.bench_function("comet_filter_record_batch - many rows selected", |b| { + b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_many))) + }); + group.bench_function("comet_filter_record_batch - all rows selected", |b| { + b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_all))) + }); + + group.finish(); +} + +fn create_record_batch(num_rows: usize, num_int_cols: i32, num_string_cols: i32) -> RecordBatch { + let mut int32_builder = Int32Builder::with_capacity(num_rows); + let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); + for i in 0..num_rows { + int32_builder.append_value(i as i32); + string_builder.append_value(format!("this is string #{i}")); + } + let int32_array = Arc::new(int32_builder.finish()); + let string_array = Arc::new(string_builder.finish()); + + let mut fields = vec![]; + let mut columns: Vec = vec![]; + let mut i = 0; + for _ in 0..num_int_cols { + fields.push(Field::new(format!("c{i}"), DataType::Int32, false)); + columns.push(int32_array.clone()); // note this is just copying a reference to the array + i += 1; + } + for _ in 0..num_string_cols { + fields.push(Field::new(format!("c{i}"), DataType::Utf8, false)); + columns.push(string_array.clone()); // note this is just copying a reference to the array + i += 1; + } + let schema = Schema::new(fields); + RecordBatch::try_new(Arc::new(schema), columns).unwrap() +} + +fn config() -> Criterion { + Criterion::default() + .measurement_time(Duration::from_millis(500)) + .warm_up_time(Duration::from_millis(500)) +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); From b24794ed44fbd45fd273914db308bf3b590e3425 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 17 Sep 2024 16:53:52 -0600 Subject: [PATCH 02/19] fix --- native/Cargo.lock | 2 + native/core/Cargo.toml | 6 +- native/core/benches/aggregate.rs | 109 +++++++++++++++++-------------- 3 files changed, 66 insertions(+), 51 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index 3692f04883..601da3e3f8 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -707,6 +707,7 @@ dependencies = [ "ciborium", "clap", "criterion-plot", + "futures", "is-terminal", "itertools 0.10.5", "num-traits", @@ -719,6 +720,7 @@ dependencies = [ "serde_derive", "serde_json", "tinytemplate", + "tokio", "walkdir", ] diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 58fe00e758..68c470be65 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -80,7 +80,7 @@ datafusion-comet-proto = { workspace = true } [dev-dependencies] pprof = { version = "0.13.0", features = ["flamegraph"] } -criterion = "0.5.1" +criterion = { version = "0.5.1", features = ["async_tokio"] } jni = { version = "0.21", features = ["invocation"] } lazy_static = "1.4" assertables = "7" @@ -122,3 +122,7 @@ harness = false [[bench]] name = "filter" harness = false + +[[bench]] +name = "aggregate" +harness = false \ No newline at end of file diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs index 10ad5a9e72..bf91fdc286 100644 --- a/native/core/benches/aggregate.rs +++ b/native/core/benches/aggregate.rs @@ -15,62 +15,73 @@ // specific language governing permissions and limitations // under the License.use arrow::array::{ArrayRef, BooleanBuilder, Int32Builder, RecordBatch, StringBuilder}; -use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Field, Schema}; -use arrow_array::builder::{BooleanBuilder, Int32Builder, StringBuilder}; +use arrow_array::builder::{Int32Builder, StringBuilder}; use arrow_array::{ArrayRef, RecordBatch}; -use comet::execution::operators::comet_filter_record_batch; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::functions_aggregate::sum::sum_udaf; +use datafusion::physical_plan::aggregates::{AggregateMode, PhysicalGroupBy}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::Column; use std::sync::Arc; use std::time::Duration; +use tokio::runtime::Runtime; +use futures::StreamExt; fn criterion_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("filter"); - + let mut group = c.benchmark_group("aggregate"); let num_rows = 8192; - let num_int_cols = 4; - let num_string_cols = 4; + let batch = create_record_batch(num_rows); + let mut batches = Vec::new(); + for _ in 0..10 { + batches.push(batch.clone()); + } + let partitions = &[batches]; + let scan : Arc = Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); + let c0 = Arc::new(Column::new("c0", 0)); + let c1 = Arc::new(Column::new("c1", 1)); - let batch = create_record_batch(num_rows, num_int_cols, num_string_cols); + let schema = scan.schema(); - // create some different predicates - let mut predicate_select_few = BooleanBuilder::with_capacity(num_rows); - let mut predicate_select_many = BooleanBuilder::with_capacity(num_rows); - let mut predicate_select_all = BooleanBuilder::with_capacity(num_rows); - for i in 0..num_rows { - predicate_select_few.append_value(i % 10 == 0); - predicate_select_many.append_value(i % 10 > 0); - predicate_select_all.append_value(true); - } - let predicate_select_few = predicate_select_few.finish(); - let predicate_select_many = predicate_select_many.finish(); - let predicate_select_all = predicate_select_all.finish(); + let aggr_expr = AggregateExprBuilder::new(sum_udaf(), vec![c1]) + .schema(schema.clone()) + .alias("sum") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .unwrap(); - // baseline uses Arrow's filter_record_batch method - group.bench_function("arrow_filter_record_batch - few rows selected", |b| { - b.iter(|| filter_record_batch(black_box(&batch), black_box(&predicate_select_few))) - }); - group.bench_function("arrow_filter_record_batch - many rows selected", |b| { - b.iter(|| filter_record_batch(black_box(&batch), black_box(&predicate_select_many))) - }); - group.bench_function("arrow_filter_record_batch - all rows selected", |b| { - b.iter(|| filter_record_batch(black_box(&batch), black_box(&predicate_select_all))) - }); + let aggregate = Arc::new( + datafusion::physical_plan::aggregates::AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), + vec![aggr_expr], + vec![None], // no filter expressions + scan, + Arc::clone(&schema), + ).unwrap() + ); - group.bench_function("comet_filter_record_batch - few rows selected", |b| { - b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_few))) - }); - group.bench_function("comet_filter_record_batch - many rows selected", |b| { - b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_many))) - }); - group.bench_function("comet_filter_record_batch - all rows selected", |b| { - b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_all))) + let rt = Runtime::new().unwrap(); + + group.bench_function("aggregate - sum int", |b| { + b.to_async(&rt).iter(|| async { + let mut x = aggregate.execute(0, Arc::new(TaskContext::default())).unwrap(); + while let Some(batch) = x.next().await { + let _batch = batch.unwrap(); + } + }) }); group.finish(); } -fn create_record_batch(num_rows: usize, num_int_cols: i32, num_string_cols: i32) -> RecordBatch { + + +fn create_record_batch(num_rows: usize) -> RecordBatch { let mut int32_builder = Int32Builder::with_capacity(num_rows); let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); for i in 0..num_rows { @@ -83,16 +94,14 @@ fn create_record_batch(num_rows: usize, num_int_cols: i32, num_string_cols: i32) let mut fields = vec![]; let mut columns: Vec = vec![]; let mut i = 0; - for _ in 0..num_int_cols { - fields.push(Field::new(format!("c{i}"), DataType::Int32, false)); - columns.push(int32_array.clone()); // note this is just copying a reference to the array - i += 1; - } - for _ in 0..num_string_cols { - fields.push(Field::new(format!("c{i}"), DataType::Utf8, false)); - columns.push(string_array.clone()); // note this is just copying a reference to the array - i += 1; - } + // string column + fields.push(Field::new(format!("c{i}"), DataType::Utf8, false)); + columns.push(string_array.clone()); // note this is just copying a reference to the array + i += 1; + // int column + fields.push(Field::new(format!("c{i}"), DataType::Int32, false)); + columns.push(int32_array.clone()); // note this is just copying a reference to the array + // i += 1; let schema = Schema::new(fields); RecordBatch::try_new(Arc::new(schema), columns).unwrap() } From af9fc152b5fd16762c21cd716e7a436a3e9246ed Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 17 Sep 2024 17:15:43 -0600 Subject: [PATCH 03/19] fix --- native/core/benches/aggregate.rs | 128 ++++++++++++++++++++++--------- 1 file changed, 91 insertions(+), 37 deletions(-) diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs index bf91fdc286..da3c54c21a 100644 --- a/native/core/benches/aggregate.rs +++ b/native/core/benches/aggregate.rs @@ -16,20 +16,24 @@ // under the License.use arrow::array::{ArrayRef, BooleanBuilder, Int32Builder, RecordBatch, StringBuilder}; use arrow::datatypes::{DataType, Field, Schema}; -use arrow_array::builder::{Int32Builder, StringBuilder}; +use arrow_array::builder::{Decimal128Builder, StringBuilder}; use arrow_array::{ArrayRef, RecordBatch}; +use arrow_schema::SchemaRef; +use comet::execution::datafusion::expressions::sum_decimal::SumDecimal; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion::functions_aggregate::sum::sum_udaf; -use datafusion::physical_plan::aggregates::{AggregateMode, PhysicalGroupBy}; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion_execution::TaskContext; +use datafusion_expr::AggregateUDF; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::Column; +use futures::StreamExt; use std::sync::Arc; use std::time::Duration; use tokio::runtime::Runtime; -use futures::StreamExt; fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("aggregate"); @@ -40,13 +44,72 @@ fn criterion_benchmark(c: &mut Criterion) { batches.push(batch.clone()); } let partitions = &[batches]; - let scan : Arc = Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); - let c0 = Arc::new(Column::new("c0", 0)); - let c1 = Arc::new(Column::new("c1", 1)); + let scan: Arc = + Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); + let schema = scan.schema().clone(); - let schema = scan.schema(); + let c0: Arc = Arc::new(Column::new("c0", 0)); + let c1: Arc = Arc::new(Column::new("c1", 1)); - let aggr_expr = AggregateExprBuilder::new(sum_udaf(), vec![c1]) + let rt = Runtime::new().unwrap(); + + let datafusion_sum_decimal = sum_udaf(); + group.bench_function("aggregate - sum decimal (DataFusion)", |b| { + b.to_async(&rt).iter(|| async { + let scan: Arc = + Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); + let aggregate = create_aggregate( + scan, + c0.clone(), + c1.clone(), + &schema, + datafusion_sum_decimal.clone(), + ); + let mut x = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = x.next().await { + let _batch = batch.unwrap(); + } + }) + }); + + let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( + "sum", + Arc::clone(&c1), + DataType::Decimal128(7, 2), + ))); + group.bench_function("aggregate - sum decimal (Comet)", |b| { + b.to_async(&rt).iter(|| async { + let scan: Arc = + Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); + let aggregate = create_aggregate( + scan, + c0.clone(), + c1.clone(), + &schema, + comet_sum_decimal.clone(), + ); + let mut x = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = x.next().await { + let _batch = batch.unwrap(); + } + }) + }); + + group.finish(); +} + +fn create_aggregate( + scan: Arc, + c0: Arc, + c1: Arc, + schema: &SchemaRef, + aggregate_udf: Arc, +) -> Arc { + let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) .schema(schema.clone()) .alias("sum") .with_ignore_nulls(false) @@ -55,53 +118,44 @@ fn criterion_benchmark(c: &mut Criterion) { .unwrap(); let aggregate = Arc::new( - datafusion::physical_plan::aggregates::AggregateExec::try_new( + AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), vec![aggr_expr], vec![None], // no filter expressions scan, Arc::clone(&schema), - ).unwrap() + ) + .unwrap(), ); - - let rt = Runtime::new().unwrap(); - - group.bench_function("aggregate - sum int", |b| { - b.to_async(&rt).iter(|| async { - let mut x = aggregate.execute(0, Arc::new(TaskContext::default())).unwrap(); - while let Some(batch) = x.next().await { - let _batch = batch.unwrap(); - } - }) - }); - - group.finish(); + aggregate } - - fn create_record_batch(num_rows: usize) -> RecordBatch { - let mut int32_builder = Int32Builder::with_capacity(num_rows); + let mut decimal_builder = Decimal128Builder::with_capacity(num_rows); let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); for i in 0..num_rows { - int32_builder.append_value(i as i32); - string_builder.append_value(format!("this is string #{i}")); + decimal_builder.append_value(i as i128); + string_builder.append_value(format!("this is string #{}", i % 1024)); } - let int32_array = Arc::new(int32_builder.finish()); + let decimal_array = Arc::new(decimal_builder.finish()); let string_array = Arc::new(string_builder.finish()); let mut fields = vec![]; let mut columns: Vec = vec![]; - let mut i = 0; + // string column - fields.push(Field::new(format!("c{i}"), DataType::Utf8, false)); - columns.push(string_array.clone()); // note this is just copying a reference to the array - i += 1; - // int column - fields.push(Field::new(format!("c{i}"), DataType::Int32, false)); - columns.push(int32_array.clone()); // note this is just copying a reference to the array - // i += 1; + fields.push(Field::new(format!("c0"), DataType::Utf8, false)); + columns.push(Arc::clone(&string_array)); + + // decimal column + fields.push(Field::new( + format!("c1"), + DataType::Decimal128(38, 10), + false, + )); + columns.push(Arc::clone(&decimal_array)); + let schema = Schema::new(fields); RecordBatch::try_new(Arc::new(schema), columns).unwrap() } From 36260ad2ccb3b005208a663389ece02f32d8a62f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 17 Sep 2024 18:25:26 -0600 Subject: [PATCH 04/19] refactor --- native/core/benches/aggregate.rs | 72 +++++++++++++++----------------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs index da3c54c21a..5b31748d3e 100644 --- a/native/core/benches/aggregate.rs +++ b/native/core/benches/aggregate.rs @@ -44,64 +44,60 @@ fn criterion_benchmark(c: &mut Criterion) { batches.push(batch.clone()); } let partitions = &[batches]; - let scan: Arc = - Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); - let schema = scan.schema().clone(); - let c0: Arc = Arc::new(Column::new("c0", 0)); let c1: Arc = Arc::new(Column::new("c1", 1)); let rt = Runtime::new().unwrap(); - let datafusion_sum_decimal = sum_udaf(); group.bench_function("aggregate - sum decimal (DataFusion)", |b| { - b.to_async(&rt).iter(|| async { - let scan: Arc = - Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); - let aggregate = create_aggregate( - scan, + let datafusion_sum_decimal = sum_udaf(); + b.to_async(&rt).iter(|| { + agg_test( + partitions, c0.clone(), c1.clone(), - &schema, datafusion_sum_decimal.clone(), - ); - let mut x = aggregate - .execute(0, Arc::new(TaskContext::default())) - .unwrap(); - while let Some(batch) = x.next().await { - let _batch = batch.unwrap(); - } + ) }) }); - let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( - "sum", - Arc::clone(&c1), - DataType::Decimal128(7, 2), - ))); group.bench_function("aggregate - sum decimal (Comet)", |b| { - b.to_async(&rt).iter(|| async { - let scan: Arc = - Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); - let aggregate = create_aggregate( - scan, + let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( + "sum", + Arc::clone(&c1), + DataType::Decimal128(7, 2), + ))); + b.to_async(&rt).iter(|| { + agg_test( + partitions, c0.clone(), c1.clone(), - &schema, comet_sum_decimal.clone(), - ); - let mut x = aggregate - .execute(0, Arc::new(TaskContext::default())) - .unwrap(); - while let Some(batch) = x.next().await { - let _batch = batch.unwrap(); - } + ) }) }); group.finish(); } +async fn agg_test( + partitions: &[Vec], + c0: Arc, + c1: Arc, + aggregate_udf: Arc, +) { + let schema = &partitions[0][0].schema(); + let scan: Arc = + Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), None).unwrap()); + let aggregate = create_aggregate(scan, c0.clone(), c1.clone(), &schema, aggregate_udf); + let mut stream = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = stream.next().await { + let _batch = batch.unwrap(); + } +} + fn create_aggregate( scan: Arc, c0: Arc, @@ -146,7 +142,7 @@ fn create_record_batch(num_rows: usize) -> RecordBatch { // string column fields.push(Field::new(format!("c0"), DataType::Utf8, false)); - columns.push(Arc::clone(&string_array)); + columns.push(string_array); // decimal column fields.push(Field::new( @@ -154,7 +150,7 @@ fn create_record_batch(num_rows: usize) -> RecordBatch { DataType::Decimal128(38, 10), false, )); - columns.push(Arc::clone(&decimal_array)); + columns.push(decimal_array); let schema = Schema::new(fields); RecordBatch::try_new(Arc::new(schema), columns).unwrap() From e0d3a58a8f0887c242a9cd3ca94c2b9b2c6979f3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Sep 2024 07:32:53 -0600 Subject: [PATCH 05/19] avg --- native/core/benches/aggregate.rs | 62 ++++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 15 deletions(-) diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs index 5b31748d3e..605e4cb00d 100644 --- a/native/core/benches/aggregate.rs +++ b/native/core/benches/aggregate.rs @@ -19,8 +19,10 @@ use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::builder::{Decimal128Builder, StringBuilder}; use arrow_array::{ArrayRef, RecordBatch}; use arrow_schema::SchemaRef; +use comet::execution::datafusion::expressions::avg_decimal::AvgDecimal; use comet::execution::datafusion::expressions::sum_decimal::SumDecimal; use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::functions_aggregate::average::avg_udaf; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; @@ -49,7 +51,38 @@ fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); - group.bench_function("aggregate - sum decimal (DataFusion)", |b| { + group.bench_function("avg_decimal_datafusion", |b| { + let datafusion_sum_decimal = avg_udaf(); + b.to_async(&rt).iter(|| { + agg_test( + partitions, + c0.clone(), + c1.clone(), + datafusion_sum_decimal.clone(), + "avg", + ) + }) + }); + + group.bench_function("avg_decimal_comet", |b| { + let comet_avg_decimal = Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new( + Arc::clone(&c1), + "avg", + DataType::Decimal128(38, 10), + DataType::Decimal128(38, 10), + ))); + b.to_async(&rt).iter(|| { + agg_test( + partitions, + c0.clone(), + c1.clone(), + comet_avg_decimal.clone(), + "avg", + ) + }) + }); + + group.bench_function("sum_decimal_datafusion", |b| { let datafusion_sum_decimal = sum_udaf(); b.to_async(&rt).iter(|| { agg_test( @@ -57,15 +90,16 @@ fn criterion_benchmark(c: &mut Criterion) { c0.clone(), c1.clone(), datafusion_sum_decimal.clone(), + "sum", ) }) }); - group.bench_function("aggregate - sum decimal (Comet)", |b| { + group.bench_function("sum_decimal_comet", |b| { let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( "sum", Arc::clone(&c1), - DataType::Decimal128(7, 2), + DataType::Decimal128(38, 10), ))); b.to_async(&rt).iter(|| { agg_test( @@ -73,6 +107,7 @@ fn criterion_benchmark(c: &mut Criterion) { c0.clone(), c1.clone(), comet_sum_decimal.clone(), + "sum", ) }) }); @@ -85,11 +120,12 @@ async fn agg_test( c0: Arc, c1: Arc, aggregate_udf: Arc, + alias: &str, ) { let schema = &partitions[0][0].schema(); let scan: Arc = Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), None).unwrap()); - let aggregate = create_aggregate(scan, c0.clone(), c1.clone(), &schema, aggregate_udf); + let aggregate = create_aggregate(scan, c0.clone(), c1.clone(), schema, aggregate_udf, alias); let mut stream = aggregate .execute(0, Arc::new(TaskContext::default())) .unwrap(); @@ -104,27 +140,27 @@ fn create_aggregate( c1: Arc, schema: &SchemaRef, aggregate_udf: Arc, + alias: &str, ) -> Arc { let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) .schema(schema.clone()) - .alias("sum") + .alias(alias) .with_ignore_nulls(false) .with_distinct(false) .build() .unwrap(); - let aggregate = Arc::new( + Arc::new( AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), vec![aggr_expr], vec![None], // no filter expressions scan, - Arc::clone(&schema), + Arc::clone(schema), ) .unwrap(), - ); - aggregate + ) } fn create_record_batch(num_rows: usize) -> RecordBatch { @@ -141,15 +177,11 @@ fn create_record_batch(num_rows: usize) -> RecordBatch { let mut columns: Vec = vec![]; // string column - fields.push(Field::new(format!("c0"), DataType::Utf8, false)); + fields.push(Field::new("c0", DataType::Utf8, false)); columns.push(string_array); // decimal column - fields.push(Field::new( - format!("c1"), - DataType::Decimal128(38, 10), - false, - )); + fields.push(Field::new("c1", DataType::Decimal128(38, 10), false)); columns.push(decimal_array); let schema = Schema::new(fields); From 858b9986eebae4f6275983c2e10a8c33cf92a069 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Sep 2024 11:16:35 -0600 Subject: [PATCH 06/19] optimized decimal aggregates with more efficient version of validate_decimal_precision --- .../datafusion/expressions/avg_decimal.rs | 9 ++++----- .../datafusion/expressions/checkoverflow.rs | 15 ++++++++++++++- .../datafusion/expressions/sum_decimal.rs | 6 +++--- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/core/src/execution/datafusion/expressions/avg_decimal.rs index 0462f2d3d5..4693acf013 100644 --- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs @@ -28,10 +28,9 @@ use datafusion_common::{not_impl_err, Result, ScalarValue}; use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr}; use std::{any::Any, sync::Arc}; +use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision; use arrow_array::ArrowNativeTypeOp; -use arrow_data::decimal::{ - validate_decimal_precision, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, -}; +use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION}; use datafusion::logical_expr::Volatility::Immutable; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -212,7 +211,7 @@ impl AvgDecimalAccumulator { None => (v, false), }; - if is_overflow || validate_decimal_precision(new_sum, self.sum_precision).is_err() { + if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) { // Overflow: set buffer accumulator to null self.is_not_null = false; return; @@ -380,7 +379,7 @@ impl AvgDecimalGroupsAccumulator { let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value); self.counts[group_index] += 1; - if is_overflow || validate_decimal_precision(new_sum, self.sum_precision).is_err() { + if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) { // Overflow: set buffer accumulator to null self.is_not_null.set_bit(group_index, false); return; diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs b/native/core/src/execution/datafusion/expressions/checkoverflow.rs index e922171bd2..ea8826c792 100644 --- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs +++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs @@ -27,7 +27,8 @@ use arrow::{ datatypes::{Decimal128Type, DecimalType}, record_batch::RecordBatch, }; -use arrow_schema::{DataType, Schema}; +use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION}; +use arrow_schema::{DataType, Schema, DECIMAL128_MAX_PRECISION}; use datafusion::logical_expr::ColumnarValue; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{DataFusionError, ScalarValue}; @@ -171,3 +172,15 @@ impl PhysicalExpr for CheckOverflow { self.hash(&mut s); } } + +/// Adapted from arrow-rs `validate_decimal_precision` but returns bool +/// instead of Err to avoid the cost of formatting the error strings and is +/// optimized to remove a memcpy that exists in the original function +#[inline] +pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool { + if precision > DECIMAL128_MAX_PRECISION { + return false; + } + let idx = usize::from(precision) - 1; + value >= MIN_DECIMAL_FOR_EACH_PRECISION[idx] && value <= MAX_DECIMAL_FOR_EACH_PRECISION[idx] +} diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index e957bd25e2..fb55607337 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision; use crate::unlikely; use arrow::{ array::BooleanBufferBuilder, @@ -23,7 +24,6 @@ use arrow::{ use arrow_array::{ cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, }; -use arrow_data::decimal::validate_decimal_precision; use arrow_schema::{DataType, Field}; use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; @@ -170,7 +170,7 @@ impl SumDecimalAccumulator { let v = unsafe { values.value_unchecked(idx) }; let (new_sum, is_overflow) = self.sum.overflowing_add(v); - if is_overflow || validate_decimal_precision(new_sum, self.precision).is_err() { + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { // Overflow: set buffer accumulator to null self.is_not_null = false; return; @@ -312,7 +312,7 @@ impl SumDecimalGroupsAccumulator { self.is_empty.set_bit(group_index, false); let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value); - if is_overflow || validate_decimal_precision(new_sum, self.precision).is_err() { + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { // Overflow: set buffer accumulator to null self.is_not_null.set_bit(group_index, false); return; From 1e57c045153b72a0e907d2b4ce2523de91c894f1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Sep 2024 13:03:34 -0600 Subject: [PATCH 07/19] simplify function to remove branch --- .../src/execution/datafusion/expressions/checkoverflow.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs b/native/core/src/execution/datafusion/expressions/checkoverflow.rs index ea8826c792..3cd2ebc663 100644 --- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs +++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs @@ -178,9 +178,6 @@ impl PhysicalExpr for CheckOverflow { /// optimized to remove a memcpy that exists in the original function #[inline] pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool { - if precision > DECIMAL128_MAX_PRECISION { - return false; - } let idx = usize::from(precision) - 1; - value >= MIN_DECIMAL_FOR_EACH_PRECISION[idx] && value <= MAX_DECIMAL_FOR_EACH_PRECISION[idx] + precision <= DECIMAL128_MAX_PRECISION && value >= MIN_DECIMAL_FOR_EACH_PRECISION[idx] && value <= MAX_DECIMAL_FOR_EACH_PRECISION[idx] } From 13312aad3b24922f228bdeea2878be54c2e695df Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Sep 2024 13:22:03 -0600 Subject: [PATCH 08/19] address feedback --- native/core/Cargo.toml | 2 +- native/core/benches/aggregate.rs | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 68c470be65..13f6b135fb 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -125,4 +125,4 @@ harness = false [[bench]] name = "aggregate" -harness = false \ No newline at end of file +harness = false diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs index 605e4cb00d..e6b3e31550 100644 --- a/native/core/benches/aggregate.rs +++ b/native/core/benches/aggregate.rs @@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, RecordBatch}; use arrow_schema::SchemaRef; use comet::execution::datafusion::expressions::avg_decimal::AvgDecimal; use comet::execution::datafusion::expressions::sum_decimal::SumDecimal; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion::functions_aggregate::average::avg_udaf; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::physical_expr::PhysicalExpr; @@ -54,13 +54,13 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("avg_decimal_datafusion", |b| { let datafusion_sum_decimal = avg_udaf(); b.to_async(&rt).iter(|| { - agg_test( + black_box(agg_test( partitions, c0.clone(), c1.clone(), datafusion_sum_decimal.clone(), "avg", - ) + )) }) }); @@ -72,26 +72,26 @@ fn criterion_benchmark(c: &mut Criterion) { DataType::Decimal128(38, 10), ))); b.to_async(&rt).iter(|| { - agg_test( + black_box(agg_test( partitions, c0.clone(), c1.clone(), comet_avg_decimal.clone(), "avg", - ) + )) }) }); group.bench_function("sum_decimal_datafusion", |b| { let datafusion_sum_decimal = sum_udaf(); b.to_async(&rt).iter(|| { - agg_test( + black_box(agg_test( partitions, c0.clone(), c1.clone(), datafusion_sum_decimal.clone(), "sum", - ) + )) }) }); @@ -102,13 +102,13 @@ fn criterion_benchmark(c: &mut Criterion) { DataType::Decimal128(38, 10), ))); b.to_async(&rt).iter(|| { - agg_test( + black_box(agg_test( partitions, c0.clone(), c1.clone(), comet_sum_decimal.clone(), "sum", - ) + )) }) }); From 4ca48f404414989ca47d19431ad07cfb8dea6366 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 07:21:36 -0600 Subject: [PATCH 09/19] format --- .../execution/datafusion/expressions/checkoverflow.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs b/native/core/src/execution/datafusion/expressions/checkoverflow.rs index 3cd2ebc663..c514eea3d7 100644 --- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs +++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs @@ -111,8 +111,9 @@ impl PhysicalExpr for CheckOverflow { let casted_array = if self.fail_on_error { // Returning error if overflow - decimal_array.validate_decimal_precision(*precision)?; decimal_array + .validate_decimal_precision(*precision) + .map(|| decimal_array)? } else { // Overflowing gets null value &decimal_array.null_if_overflow_precision(*precision) @@ -176,8 +177,12 @@ impl PhysicalExpr for CheckOverflow { /// Adapted from arrow-rs `validate_decimal_precision` but returns bool /// instead of Err to avoid the cost of formatting the error strings and is /// optimized to remove a memcpy that exists in the original function +/// we can remove this code once we upgrade to a version of arrow-rs that +/// includes https://github.com/apache/arrow-rs/pull/6419 #[inline] pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool { let idx = usize::from(precision) - 1; - precision <= DECIMAL128_MAX_PRECISION && value >= MIN_DECIMAL_FOR_EACH_PRECISION[idx] && value <= MAX_DECIMAL_FOR_EACH_PRECISION[idx] + precision <= DECIMAL128_MAX_PRECISION + && value >= MIN_DECIMAL_FOR_EACH_PRECISION[idx] + && value <= MAX_DECIMAL_FOR_EACH_PRECISION[idx] } From f281fe4b4214c63fe0db8da98cf202cf4d2ea57a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 07:29:05 -0600 Subject: [PATCH 10/19] Revert a change --- .../core/src/execution/datafusion/expressions/checkoverflow.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs b/native/core/src/execution/datafusion/expressions/checkoverflow.rs index c514eea3d7..5893bfd3d3 100644 --- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs +++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs @@ -111,9 +111,8 @@ impl PhysicalExpr for CheckOverflow { let casted_array = if self.fail_on_error { // Returning error if overflow + decimal_array.validate_decimal_precision(*precision)?; decimal_array - .validate_decimal_precision(*precision) - .map(|| decimal_array)? } else { // Overflowing gets null value &decimal_array.null_if_overflow_precision(*precision) From 0221e9f03c059d29ea60b40858efdfbf5147515b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 07:54:43 -0600 Subject: [PATCH 11/19] add rust unit test --- .../datafusion/expressions/sum_decimal.rs | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index fb55607337..c59294c4ae 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -478,3 +478,94 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { + self.is_not_null.capacity() / 8 } } + +#[cfg(test)] +mod tests { + use arrow::datatypes::*; + use arrow_array::builder::{Decimal128Builder, StringBuilder}; + use arrow_array::RecordBatch; + use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; + use datafusion::physical_plan::memory::MemoryExec; + use datafusion::physical_plan::ExecutionPlan; + use datafusion_common::Result; + use datafusion_execution::TaskContext; + use datafusion_expr::AggregateUDF; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::Column; + + use super::*; + use futures::StreamExt; + #[tokio::test] + async fn sum_no_overflow() -> Result<()> { + let num_rows = 8192; + let batch = create_record_batch(num_rows); + let mut batches = Vec::new(); + for _ in 0..10 { + batches.push(batch.clone()); + } + let partitions = &[batches]; + let c0: Arc = Arc::new(Column::new("c0", 0)); + let c1: Arc = Arc::new(Column::new("c1", 1)); + + let data_type = DataType::Decimal128(8, 2); + let schema = partitions[0][0].schema().clone(); + let scan: Arc = + Arc::new(MemoryExec::try_new(partitions, schema.clone(), None).unwrap()); + + let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( + "sum", + Arc::clone(&c1), + data_type.clone(), + ))); + + let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) + .schema(schema.clone()) + .alias("sum") + .with_ignore_nulls(false) + .with_distinct(false) + .build()?; + + let aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), + vec![aggr_expr], + vec![None], // no filter expressions + scan, + schema.clone(), + )?); + + let mut stream = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = stream.next().await { + let _batch = batch?; + } + + Ok(()) + } + + fn create_record_batch(num_rows: usize) -> RecordBatch { + let mut decimal_builder = Decimal128Builder::with_capacity(num_rows); + let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); + for i in 0..num_rows { + decimal_builder.append_value(i as i128); + string_builder.append_value(format!("this is string #{}", i % 1024)); + } + let decimal_array = Arc::new(decimal_builder.finish()); + let string_array = Arc::new(string_builder.finish()); + + let mut fields = vec![]; + let mut columns: Vec = vec![]; + + // string column + fields.push(Field::new("c0", DataType::Utf8, false)); + columns.push(string_array); + + // decimal column + fields.push(Field::new("c1", DataType::Decimal128(38, 10), false)); + columns.push(decimal_array); + + let schema = Schema::new(fields); + RecordBatch::try_new(Arc::new(schema), columns).unwrap() + } +} From 2357f6281110fafa192e4ab14c3bb8296dd9ff7d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 07:56:41 -0600 Subject: [PATCH 12/19] format --- .../core/src/execution/datafusion/expressions/sum_decimal.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index c59294c4ae..1978a5b9cd 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -481,6 +481,7 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::*; use arrow_array::builder::{Decimal128Builder, StringBuilder}; use arrow_array::RecordBatch; @@ -492,9 +493,8 @@ mod tests { use datafusion_expr::AggregateUDF; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::Column; - - use super::*; use futures::StreamExt; + #[tokio::test] async fn sum_no_overflow() -> Result<()> { let num_rows = 8192; From dcbc88a3c53024b54761fdf9200d464fea044f60 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 08:13:40 -0600 Subject: [PATCH 13/19] code cleanup --- .../datafusion/expressions/avg_decimal.rs | 22 ++----- .../datafusion/expressions/sum_decimal.rs | 62 +++++++++++-------- .../core/src/execution/datafusion/planner.rs | 53 ++++++---------- 3 files changed, 60 insertions(+), 77 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/core/src/execution/datafusion/expressions/avg_decimal.rs index 4693acf013..b2e210ee4d 100644 --- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs @@ -42,7 +42,6 @@ use DataType::*; /// AVG aggregate expression #[derive(Debug, Clone)] pub struct AvgDecimal { - name: String, signature: Signature, expr: Arc, sum_data_type: DataType, @@ -51,14 +50,8 @@ pub struct AvgDecimal { impl AvgDecimal { /// Create a new AVG aggregate function - pub fn new( - expr: Arc, - name: impl Into, - result_type: DataType, - sum_type: DataType, - ) -> Self { + pub fn new(expr: Arc, result_type: DataType, sum_type: DataType) -> Self { Self { - name: name.into(), signature: Signature::user_defined(Immutable), expr, result_data_type: result_type, @@ -94,20 +87,16 @@ impl AggregateUDFImpl for AvgDecimal { fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( - format_state_name(&self.name, "sum"), + format_state_name("sum", "sum"), self.sum_data_type.clone(), true, ), - Field::new( - format_state_name(&self.name, "count"), - DataType::Int64, - true, - ), + Field::new(format_state_name("sum", "count"), DataType::Int64, true), ]) } fn name(&self) -> &str { - &self.name + "avg" } fn reverse_expr(&self) -> ReversedUDAF { @@ -168,8 +157,7 @@ impl PartialEq for AvgDecimal { down_cast_any_ref(other) .downcast_ref::() .map(|x| { - self.name == x.name - && self.sum_data_type == x.sum_data_type + self.sum_data_type == x.sum_data_type && self.result_data_type == x.result_data_type && self.expr.eq(&x.expr) }) diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index 1978a5b9cd..cdb8f9a50e 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -27,7 +27,7 @@ use arrow_array::{ use arrow_schema::{DataType, Field}; use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; -use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature}; @@ -36,37 +36,37 @@ use std::{any::Any, ops::BitAnd, sync::Arc}; #[derive(Debug)] pub struct SumDecimal { - name: String, + /// Aggregate function signature signature: Signature, + /// The expression that provides the input decimal values to be summed expr: Arc, - - /// The data type of the SUM result + /// The data type of the SUM result. This will always be a decimal type + /// with the same precision and scale as specified in this struct result_type: DataType, - - /// Decimal precision and scale + /// Decimal precision precision: u8, + /// Decimal scale scale: i8, - - /// Whether the result is nullable - nullable: bool, } impl SumDecimal { - pub fn new(name: impl Into, expr: Arc, data_type: DataType) -> Self { + pub fn try_new(expr: Arc, data_type: DataType) -> DFResult { // The `data_type` is the SUM result type passed from Spark side let (precision, scale) = match data_type { DataType::Decimal128(p, s) => (p, s), - _ => unreachable!(), + _ => { + return Err(DataFusionError::Internal( + "Invalid data type for SumDecimal".into(), + )) + } }; - Self { - name: name.into(), + Ok(Self { signature: Signature::user_defined(Immutable), expr, result_type: data_type, precision, scale, - nullable: true, - } + }) } } @@ -84,14 +84,14 @@ impl AggregateUDFImpl for SumDecimal { fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { let fields = vec![ - Field::new(&self.name, self.result_type.clone(), self.nullable), + Field::new("sum", self.result_type.clone(), true), Field::new("is_empty", DataType::Boolean, false), ]; Ok(fields) } fn name(&self) -> &str { - &self.name + "sum" } fn signature(&self) -> &Signature { @@ -127,6 +127,11 @@ impl AggregateUDFImpl for SumDecimal { fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Identical } + + fn is_nullable(&self) -> bool { + // SumDecimal is always nullable because overflows can cause null values + true + } } impl PartialEq for SumDecimal { @@ -134,12 +139,10 @@ impl PartialEq for SumDecimal { down_cast_any_ref(other) .downcast_ref::() .map(|x| { - self.name == x.name - && self.precision == x.precision - && self.scale == x.scale - && self.nullable == x.nullable - && self.result_type == x.result_type - && self.expr.eq(&x.expr) + // note that we do not compare data_type because this + // is guaranteed to match if the precision and scale + // match + self.precision == x.precision && self.scale == x.scale && self.expr.eq(&x.expr) }) .unwrap_or(false) } @@ -492,9 +495,15 @@ mod tests { use datafusion_execution::TaskContext; use datafusion_expr::AggregateUDF; use datafusion_physical_expr::aggregate::AggregateExprBuilder; - use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::expressions::{Column, Literal}; use futures::StreamExt; + #[test] + fn invalid_data_type() { + let expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1)))); + assert!(SumDecimal::try_new(expr, DataType::Int32).is_err()); + } + #[tokio::test] async fn sum_no_overflow() -> Result<()> { let num_rows = 8192; @@ -512,11 +521,10 @@ mod tests { let scan: Arc = Arc::new(MemoryExec::try_new(partitions, schema.clone(), None).unwrap()); - let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( - "sum", + let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( Arc::clone(&c1), data_type.clone(), - ))); + )?)); let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) .schema(schema.clone()) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index d7c8d74592..467bf5cf12 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1365,56 +1365,42 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - match datatype { + let builder = match datatype { DataType::Decimal128(_, _) => { - let func = AggregateUDF::new_from_impl(SumDecimal::new( - "sum", + let func = AggregateUDF::new_from_impl(SumDecimal::try_new( Arc::clone(&child), datatype, - )); + )?); AggregateExprBuilder::new(Arc::new(func), vec![child]) - .schema(schema) - .alias("sum") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } _ => { // cast to the result data type of SUM if necessary, we should not expect // a cast failure since it should have already been checked at Spark side let child = Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None)); - AggregateExprBuilder::new(sum_udaf(), vec![child]) - .schema(schema) - .alias("sum") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } - } + }; + builder + .schema(schema) + .alias("sum") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) } AggExprStruct::Avg(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap()); - match datatype { + let builder = match datatype { DataType::Decimal128(_, _) => { let func = AggregateUDF::new_from_impl(AvgDecimal::new( Arc::clone(&child), - "avg", datatype, input_datatype, )); AggregateExprBuilder::new(Arc::new(func), vec![child]) - .schema(schema) - .alias("avg") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } _ => { // cast to the result data type of AVG if the result data type is different @@ -1428,14 +1414,15 @@ impl PhysicalPlanner { datatype, )); AggregateExprBuilder::new(Arc::new(func), vec![child]) - .schema(schema) - .alias("avg") - .with_ignore_nulls(false) - .with_distinct(false) - .build() - .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } - } + }; + builder + .schema(schema) + .alias("avg") + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| e.into()) } AggExprStruct::First(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; From 8487251b50c38a4a2d2dd6346989484e9de06dbd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 08:15:30 -0600 Subject: [PATCH 14/19] fix --- .../core/src/execution/datafusion/expressions/avg_decimal.rs | 4 ++-- .../core/src/execution/datafusion/expressions/sum_decimal.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/core/src/execution/datafusion/expressions/avg_decimal.rs index b2e210ee4d..e4a7f16d79 100644 --- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs @@ -87,11 +87,11 @@ impl AggregateUDFImpl for AvgDecimal { fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( - format_state_name("sum", "sum"), + format_state_name(self.name(), "sum"), self.sum_data_type.clone(), true, ), - Field::new(format_state_name("sum", "count"), DataType::Int64, true), + Field::new(format_state_name(self.name(), "count"), DataType::Int64, true), ]) } diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index cdb8f9a50e..09b08b3a97 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -84,7 +84,7 @@ impl AggregateUDFImpl for SumDecimal { fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { let fields = vec![ - Field::new("sum", self.result_type.clone(), true), + Field::new(self.name(), self.result_type.clone(), true), Field::new("is_empty", DataType::Boolean, false), ]; Ok(fields) From 3038ba60b152590efa557398b943a95dcbc6cf54 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 08:16:38 -0600 Subject: [PATCH 15/19] fix --- native/core/src/execution/datafusion/expressions/sum_decimal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index 09b08b3a97..099e4de84f 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -84,7 +84,7 @@ impl AggregateUDFImpl for SumDecimal { fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { let fields = vec![ - Field::new(self.name(), self.result_type.clone(), true), + Field::new(self.name(), self.result_type.clone(), self.is_nullable()), Field::new("is_empty", DataType::Boolean, false), ]; Ok(fields) From a75f8705bfae06a44bc2c7ef21f8abc739c9ea3d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 08:17:21 -0600 Subject: [PATCH 16/19] fix --- native/core/src/execution/datafusion/expressions/sum_decimal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index 099e4de84f..048f1f8e5c 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -139,7 +139,7 @@ impl PartialEq for SumDecimal { down_cast_any_ref(other) .downcast_ref::() .map(|x| { - // note that we do not compare data_type because this + // note that we do not compare result_type because this // is guaranteed to match if the precision and scale // match self.precision == x.precision && self.scale == x.scale && self.expr.eq(&x.expr) From 284bd99dec5e74d0351187516098227a9fc31534 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 10:02:48 -0600 Subject: [PATCH 17/19] fmt --- .../src/execution/datafusion/expressions/avg_decimal.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/core/src/execution/datafusion/expressions/avg_decimal.rs index e4a7f16d79..a265fdc29e 100644 --- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs @@ -91,7 +91,11 @@ impl AggregateUDFImpl for AvgDecimal { self.sum_data_type.clone(), true, ), - Field::new(format_state_name(self.name(), "count"), DataType::Int64, true), + Field::new( + format_state_name(self.name(), "count"), + DataType::Int64, + true, + ), ]) } From 18d474fdeb761c7ab33028c72c8494f76272781e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 11:07:16 -0600 Subject: [PATCH 18/19] update bench --- native/core/benches/aggregate.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs index e6b3e31550..14425f76c5 100644 --- a/native/core/benches/aggregate.rs +++ b/native/core/benches/aggregate.rs @@ -67,7 +67,6 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("avg_decimal_comet", |b| { let comet_avg_decimal = Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new( Arc::clone(&c1), - "avg", DataType::Decimal128(38, 10), DataType::Decimal128(38, 10), ))); @@ -96,11 +95,9 @@ fn criterion_benchmark(c: &mut Criterion) { }); group.bench_function("sum_decimal_comet", |b| { - let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( - "sum", - Arc::clone(&c1), - DataType::Decimal128(38, 10), - ))); + let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl( + SumDecimal::try_new(Arc::clone(&c1), DataType::Decimal128(38, 10)).unwrap(), + )); b.to_async(&rt).iter(|| { black_box(agg_test( partitions, From 5381327906d8147efe91ff817ec0e30a7fbb96da Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Sep 2024 13:21:40 -0600 Subject: [PATCH 19/19] clippy --- .../src/execution/datafusion/expressions/sum_decimal.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs index 048f1f8e5c..a3ce96b676 100644 --- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -517,9 +517,9 @@ mod tests { let c1: Arc = Arc::new(Column::new("c1", 1)); let data_type = DataType::Decimal128(8, 2); - let schema = partitions[0][0].schema().clone(); + let schema = Arc::clone(&partitions[0][0].schema()); let scan: Arc = - Arc::new(MemoryExec::try_new(partitions, schema.clone(), None).unwrap()); + Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), None).unwrap()); let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( Arc::clone(&c1), @@ -527,7 +527,7 @@ mod tests { )?)); let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) - .schema(schema.clone()) + .schema(Arc::clone(&schema)) .alias("sum") .with_ignore_nulls(false) .with_distinct(false) @@ -539,7 +539,7 @@ mod tests { vec![aggr_expr], vec![None], // no filter expressions scan, - schema.clone(), + Arc::clone(&schema), )?); let mut stream = aggregate