From 8bdefdf5f990b3bed55a4629d6e054c217204b2f Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 24 Nov 2020 22:17:12 +0100 Subject: [PATCH 1/7] Reduce size of key by boxing string --- .../datafusion/src/physical_plan/group_scalar.rs | 16 ++++++++++------ .../src/physical_plan/hash_aggregate.rs | 5 +++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/rust/datafusion/src/physical_plan/group_scalar.rs b/rust/datafusion/src/physical_plan/group_scalar.rs index bb1e204c7f5..e5a1937cbcf 100644 --- a/rust/datafusion/src/physical_plan/group_scalar.rs +++ b/rust/datafusion/src/physical_plan/group_scalar.rs @@ -25,6 +25,7 @@ use crate::scalar::ScalarValue; /// Enumeration of types that can be used in a GROUP BY expression (all primitives except /// for floating point numerics) #[derive(Debug, PartialEq, Eq, Hash, Clone)] +#[repr(u8)] pub(crate) enum GroupByScalar { UInt8(u8), UInt16(u16), @@ -34,7 +35,7 @@ pub(crate) enum GroupByScalar { Int16(i16), Int32(i32), Int64(i64), - Utf8(String), + Utf8(Box), } impl TryFrom<&ScalarValue> for GroupByScalar { @@ -50,7 +51,7 @@ impl TryFrom<&ScalarValue> for GroupByScalar { ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v), ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v), ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v), - ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(v.clone()), + ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())), ScalarValue::Int8(None) | ScalarValue::Int16(None) | ScalarValue::Int32(None) @@ -86,7 +87,7 @@ impl From<&GroupByScalar> for ScalarValue { GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)), GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)), GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)), - GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.clone())), + GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.clone().to_string())), } } } @@ -122,13 +123,16 @@ mod tests { match result { Err(DataFusionError::Internal(error_message)) => assert_eq!( error_message, - String::from( - "Cannot convert a ScalarValue with associated DataType Float32" - ) + String::from("Cannot convert a ScalarValue with associated DataType Float32") ), _ => panic!("Unexpected result"), } Ok(()) } + + #[test] + fn size_of_group_by_scalar() { + assert_eq!(std::mem::size_of::(), 16); + } } diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 4dc1f903ffd..fed832fd924 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -656,7 +656,8 @@ fn create_batch_from_map( GroupByScalar::UInt16(n) => Arc::new(UInt16Array::from(vec![*n])), GroupByScalar::UInt32(n) => Arc::new(UInt32Array::from(vec![*n])), GroupByScalar::UInt64(n) => Arc::new(UInt64Array::from(vec![*n])), - GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![&**str])), + GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![&***str])) + //GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![**str])), }) .collect::>(); @@ -763,7 +764,7 @@ pub(crate) fn create_key( } DataType::Utf8 => { let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Utf8(String::from(array.value(row))) + vec[i] = GroupByScalar::Utf8(Box::new(String::from(array.value(row)))) } _ => { // This is internal because we should have caught this before. From 8f1238d1ccbf47b0384990d659bffea8d5a7a5db Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 24 Nov 2020 22:38:11 +0100 Subject: [PATCH 2/7] Use Box --- .../src/physical_plan/group_scalar.rs | 9 +-- .../src/physical_plan/hash_aggregate.rs | 78 ++++++------------- 2 files changed, 28 insertions(+), 59 deletions(-) diff --git a/rust/datafusion/src/physical_plan/group_scalar.rs b/rust/datafusion/src/physical_plan/group_scalar.rs index e5a1937cbcf..c20dfcb5840 100644 --- a/rust/datafusion/src/physical_plan/group_scalar.rs +++ b/rust/datafusion/src/physical_plan/group_scalar.rs @@ -25,7 +25,6 @@ use crate::scalar::ScalarValue; /// Enumeration of types that can be used in a GROUP BY expression (all primitives except /// for floating point numerics) #[derive(Debug, PartialEq, Eq, Hash, Clone)] -#[repr(u8)] pub(crate) enum GroupByScalar { UInt8(u8), UInt16(u16), @@ -35,7 +34,7 @@ pub(crate) enum GroupByScalar { Int16(i16), Int32(i32), Int64(i64), - Utf8(Box), + Utf8(Box), } impl TryFrom<&ScalarValue> for GroupByScalar { @@ -51,7 +50,7 @@ impl TryFrom<&ScalarValue> for GroupByScalar { ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v), ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v), ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v), - ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())), + ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(v.clone().into_boxed_str()), ScalarValue::Int8(None) | ScalarValue::Int16(None) | ScalarValue::Int32(None) @@ -87,7 +86,7 @@ impl From<&GroupByScalar> for ScalarValue { GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)), GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)), GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)), - GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.clone().to_string())), + GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())), } } } @@ -133,6 +132,6 @@ mod tests { #[test] fn size_of_group_by_scalar() { - assert_eq!(std::mem::size_of::(), 16); + assert_eq!(std::mem::size_of::(), 24); } } diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index fed832fd924..0941deeae92 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -35,8 +35,8 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use arrow::{ array::{ - ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, }, compute, }; @@ -264,8 +264,7 @@ fn group_aggregate_batch( let accumulator_set = create_accumulators(aggr_expr) .map_err(DataFusionError::into_arrow_external_error)?; - accumulators - .insert(key.clone(), (accumulator_set, Box::new(vec![row as u32]))); + accumulators.insert(key.clone(), (accumulator_set, Box::new(vec![row as u32]))); } // 1.3 Some((_, v)) => v.push(row as u32), @@ -366,14 +365,9 @@ impl GroupedHashAggregateStream { let schema_clone = schema.clone(); tokio::spawn(async move { - let result = compute_grouped_hash_aggregate( - mode, - schema_clone, - group_expr, - aggr_expr, - input, - ) - .await; + let result = + compute_grouped_hash_aggregate(mode, schema_clone, group_expr, aggr_expr, input) + .await; tx.send(result) }); @@ -386,16 +380,12 @@ impl GroupedHashAggregateStream { } type AccumulatorSet = Vec>; -type Accumulators = - HashMap, (AccumulatorSet, Box>), RandomState>; +type Accumulators = HashMap, (AccumulatorSet, Box>), RandomState>; impl Stream for GroupedHashAggregateStream { type Item = ArrowResult; - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.finished { return Poll::Ready(None); } @@ -427,10 +417,7 @@ impl RecordBatchStream for GroupedHashAggregateStream { } /// Evaluates expressions against a record batch. -fn evaluate( - expr: &Vec>, - batch: &RecordBatch, -) -> Result> { +fn evaluate(expr: &Vec>, batch: &RecordBatch) -> Result> { expr.iter() .map(|expr| expr.evaluate(&batch)) .map(|r| r.map(|v| v.into_array(batch.num_rows()))) @@ -448,9 +435,7 @@ fn evaluate_many( } /// uses `state_fields` to build a vec of expressions required to merge the AggregateExpr' accumulator's state. -fn merge_expressions( - expr: &Arc, -) -> Result>> { +fn merge_expressions(expr: &Arc) -> Result>> { Ok(expr .state_fields()? .iter() @@ -470,9 +455,7 @@ fn aggregate_expressions( mode: &AggregateMode, ) -> Result>>> { match mode { - AggregateMode::Partial => { - Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect()) - } + AggregateMode::Partial => Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect()), // in this mode, we build the merge expressions of the aggregation AggregateMode::Final => Ok(aggr_expr .iter() @@ -496,8 +479,8 @@ async fn compute_hash_aggregate( aggr_expr: Vec>, mut input: SendableRecordBatchStream, ) -> ArrowResult { - let mut accumulators = create_accumulators(&aggr_expr) - .map_err(DataFusionError::into_arrow_external_error)?; + let mut accumulators = + create_accumulators(&aggr_expr).map_err(DataFusionError::into_arrow_external_error)?; let expressions = aggregate_expressions(&aggr_expr, &mode) .map_err(DataFusionError::into_arrow_external_error)?; @@ -530,8 +513,7 @@ impl HashAggregateStream { let schema_clone = schema.clone(); tokio::spawn(async move { - let result = - compute_hash_aggregate(mode, schema_clone, aggr_expr, input).await; + let result = compute_hash_aggregate(mode, schema_clone, aggr_expr, input).await; tx.send(result) }); @@ -582,10 +564,7 @@ fn aggregate_batch( impl Stream for HashAggregateStream { type Item = ArrowResult; - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.finished { return Poll::Ready(None); } @@ -646,9 +625,7 @@ fn create_batch_from_map( // 2. let mut groups = (0..num_group_expr) .map(|i| match &k[i] { - GroupByScalar::Int8(n) => { - Arc::new(Int8Array::from(vec![*n])) as ArrayRef - } + GroupByScalar::Int8(n) => Arc::new(Int8Array::from(vec![*n])) as ArrayRef, GroupByScalar::Int16(n) => Arc::new(Int16Array::from(vec![*n])), GroupByScalar::Int32(n) => Arc::new(Int32Array::from(vec![*n])), GroupByScalar::Int64(n) => Arc::new(Int64Array::from(vec![*n])), @@ -656,8 +633,7 @@ fn create_batch_from_map( GroupByScalar::UInt16(n) => Arc::new(UInt16Array::from(vec![*n])), GroupByScalar::UInt32(n) => Arc::new(UInt32Array::from(vec![*n])), GroupByScalar::UInt64(n) => Arc::new(UInt64Array::from(vec![*n])), - GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![&***str])) - //GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![**str])), + GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![&**str])), }) .collect::>(); @@ -682,9 +658,7 @@ fn create_batch_from_map( Ok(batch) } -fn create_accumulators( - aggr_expr: &Vec>, -) -> Result { +fn create_accumulators(aggr_expr: &Vec>) -> Result { aggr_expr .iter() .map(|expr| expr.create_accumulator()) @@ -704,9 +678,8 @@ fn finalize_aggregation( .iter() .map(|accumulator| accumulator.state()) .map(|value| { - value.and_then(|e| { - Ok(e.iter().map(|v| v.to_array()).collect::>()) - }) + value + .and_then(|e| Ok(e.iter().map(|v| v.to_array()).collect::>())) }) .collect::>>()?; Ok(a.iter().flatten().cloned().collect::>()) @@ -764,7 +737,7 @@ pub(crate) fn create_key( } DataType::Utf8 => { let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Utf8(Box::new(String::from(array.value(row)))) + vec[i] = GroupByScalar::Utf8(array.value(row).to_string().into_boxed_str()) } _ => { // This is internal because we should have caught this before. @@ -822,8 +795,7 @@ mod tests { /// build the aggregates on the data from some_data() and check the results async fn check_aggregates(input: Arc) -> Result<()> { - let groups: Vec<(Arc, String)> = - vec![(col("a"), "a".to_string())]; + let groups: Vec<(Arc, String)> = vec![(col("a"), "a".to_string())]; let aggregates: Vec> = vec![Arc::new(Avg::new( col("b"), @@ -971,16 +943,14 @@ mod tests { #[tokio::test] async fn aggregate_source_not_yielding() -> Result<()> { - let input: Arc = - Arc::new(TestYieldingExec { yield_first: false }); + let input: Arc = Arc::new(TestYieldingExec { yield_first: false }); check_aggregates(input).await } #[tokio::test] async fn aggregate_source_with_yielding() -> Result<()> { - let input: Arc = - Arc::new(TestYieldingExec { yield_first: true }); + let input: Arc = Arc::new(TestYieldingExec { yield_first: true }); check_aggregates(input).await } From eb44aaa1a4a47f36c5fa0c5c990e2383ed49d136 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 24 Nov 2020 22:52:44 +0100 Subject: [PATCH 3/7] Simplify using into --- rust/datafusion/src/physical_plan/hash_aggregate.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 0941deeae92..9f1272c3435 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -737,7 +737,7 @@ pub(crate) fn create_key( } DataType::Utf8 => { let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Utf8(array.value(row).to_string().into_boxed_str()) + vec[i] = GroupByScalar::Utf8(array.value(row).into()) } _ => { // This is internal because we should have caught this before. From 02d8da04bddf6d87ede0c1f6f04914a87343186f Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 24 Nov 2020 23:09:49 +0100 Subject: [PATCH 4/7] Improve benchmarks --- .../datafusion/benches/aggregate_query_sql.rs | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/rust/datafusion/benches/aggregate_query_sql.rs b/rust/datafusion/benches/aggregate_query_sql.rs index d6957808bdf..15ee38ad8fe 100644 --- a/rust/datafusion/benches/aggregate_query_sql.rs +++ b/rust/datafusion/benches/aggregate_query_sql.rs @@ -19,8 +19,7 @@ extern crate criterion; use criterion::Criterion; -use rand::seq::SliceRandom; -use rand::Rng; +use rand::{seq::SliceRandom, rngs::StdRng, Rng, SeedableRng}; use std::sync::{Arc, Mutex}; use tokio::runtime::Runtime; @@ -40,6 +39,10 @@ use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); @@ -50,7 +53,7 @@ fn query(ctx: Arc>, sql: &str) { fn create_data(size: usize, null_density: f64) -> Vec> { // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = rand::thread_rng(); + let mut rng = seedable_rng(); (0..size) .map(|_| { @@ -65,7 +68,7 @@ fn create_data(size: usize, null_density: f64) -> Vec> { fn create_integer_data(size: usize, value_density: f64) -> Vec> { // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = rand::thread_rng(); + let mut rng = seedable_rng(); (0..size) .map(|_| { @@ -98,6 +101,8 @@ fn create_context( Field::new("u64_narrow", DataType::UInt64, false), ])); + let mut rng = seedable_rng(); + // define data. let partitions = (0..partitions_len) .map(|_| { @@ -109,7 +114,7 @@ fn create_context( let keys: Vec = (0..batch_size) .map( // use random numbers to avoid spurious compiler optimizations wrt to branching - |_| format!("hi{:?}", vs.choose(&mut rand::thread_rng())), + |_| format!("hi{:?}", vs.choose(&mut rng)), ) .collect(); let keys: Vec<&str> = keys.iter().map(|e| &**e).collect(); @@ -124,7 +129,7 @@ fn create_context( let integer_values_narrow = (0..batch_size) .map(|_| { *integer_values_narrow_choices - .choose(&mut rand::thread_rng()) + .choose(&mut rng) .unwrap() }) .collect::>(); @@ -216,6 +221,27 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + c.bench_function("aggregate_query_group_by_u64 15 12", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT utf8, MIN(f64), AVG(f64), COUNT(f64) \ + FROM t GROUP BY u64_narrow", + ) + }) + }); + + c.bench_function("aggregate_query_group_by_with_filter_u64 15 12", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT utf8, MIN(f64), AVG(f64), COUNT(f64) \ + FROM t \ + WHERE f32 > 10 AND f32 < 20 GROUP BY utf8", + ) + }) + }); } criterion_group!(benches, criterion_benchmark); From 7d4ee701e909007fef924c5ef343a4ef15f89509 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 24 Nov 2020 23:11:21 +0100 Subject: [PATCH 5/7] Format fixes --- .../datafusion/benches/aggregate_query_sql.rs | 14 ++-- .../src/physical_plan/group_scalar.rs | 4 +- .../src/physical_plan/hash_aggregate.rs | 73 +++++++++++++------ 3 files changed, 59 insertions(+), 32 deletions(-) diff --git a/rust/datafusion/benches/aggregate_query_sql.rs b/rust/datafusion/benches/aggregate_query_sql.rs index 15ee38ad8fe..97bdc720478 100644 --- a/rust/datafusion/benches/aggregate_query_sql.rs +++ b/rust/datafusion/benches/aggregate_query_sql.rs @@ -19,7 +19,7 @@ extern crate criterion; use criterion::Criterion; -use rand::{seq::SliceRandom, rngs::StdRng, Rng, SeedableRng}; +use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; use std::sync::{Arc, Mutex}; use tokio::runtime::Runtime; @@ -127,11 +127,7 @@ fn create_context( // Integer values between [0, 9]. let integer_values_narrow_choices = (0..10).collect::>(); let integer_values_narrow = (0..batch_size) - .map(|_| { - *integer_values_narrow_choices - .choose(&mut rng) - .unwrap() - }) + .map(|_| *integer_values_narrow_choices.choose(&mut rng).unwrap()) .collect::>(); RecordBatch::try_new( @@ -226,7 +222,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { query( ctx.clone(), - "SELECT utf8, MIN(f64), AVG(f64), COUNT(f64) \ + "SELECT u64_narrow, MIN(f64), AVG(f64), COUNT(f64) \ FROM t GROUP BY u64_narrow", ) }) @@ -236,9 +232,9 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { query( ctx.clone(), - "SELECT utf8, MIN(f64), AVG(f64), COUNT(f64) \ + "SELECT u64_narrow, MIN(f64), AVG(f64), COUNT(f64) \ FROM t \ - WHERE f32 > 10 AND f32 < 20 GROUP BY utf8", + WHERE f32 > 10 AND f32 < 20 GROUP BY u64_narrow", ) }) }); diff --git a/rust/datafusion/src/physical_plan/group_scalar.rs b/rust/datafusion/src/physical_plan/group_scalar.rs index c20dfcb5840..1543de1df67 100644 --- a/rust/datafusion/src/physical_plan/group_scalar.rs +++ b/rust/datafusion/src/physical_plan/group_scalar.rs @@ -122,7 +122,9 @@ mod tests { match result { Err(DataFusionError::Internal(error_message)) => assert_eq!( error_message, - String::from("Cannot convert a ScalarValue with associated DataType Float32") + String::from( + "Cannot convert a ScalarValue with associated DataType Float32" + ) ), _ => panic!("Unexpected result"), } diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 9f1272c3435..85065aac75b 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -35,8 +35,8 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use arrow::{ array::{ - ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, + ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, compute, }; @@ -264,7 +264,8 @@ fn group_aggregate_batch( let accumulator_set = create_accumulators(aggr_expr) .map_err(DataFusionError::into_arrow_external_error)?; - accumulators.insert(key.clone(), (accumulator_set, Box::new(vec![row as u32]))); + accumulators + .insert(key.clone(), (accumulator_set, Box::new(vec![row as u32]))); } // 1.3 Some((_, v)) => v.push(row as u32), @@ -365,9 +366,14 @@ impl GroupedHashAggregateStream { let schema_clone = schema.clone(); tokio::spawn(async move { - let result = - compute_grouped_hash_aggregate(mode, schema_clone, group_expr, aggr_expr, input) - .await; + let result = compute_grouped_hash_aggregate( + mode, + schema_clone, + group_expr, + aggr_expr, + input, + ) + .await; tx.send(result) }); @@ -380,12 +386,16 @@ impl GroupedHashAggregateStream { } type AccumulatorSet = Vec>; -type Accumulators = HashMap, (AccumulatorSet, Box>), RandomState>; +type Accumulators = + HashMap, (AccumulatorSet, Box>), RandomState>; impl Stream for GroupedHashAggregateStream { type Item = ArrowResult; - fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { if self.finished { return Poll::Ready(None); } @@ -417,7 +427,10 @@ impl RecordBatchStream for GroupedHashAggregateStream { } /// Evaluates expressions against a record batch. -fn evaluate(expr: &Vec>, batch: &RecordBatch) -> Result> { +fn evaluate( + expr: &Vec>, + batch: &RecordBatch, +) -> Result> { expr.iter() .map(|expr| expr.evaluate(&batch)) .map(|r| r.map(|v| v.into_array(batch.num_rows()))) @@ -435,7 +448,9 @@ fn evaluate_many( } /// uses `state_fields` to build a vec of expressions required to merge the AggregateExpr' accumulator's state. -fn merge_expressions(expr: &Arc) -> Result>> { +fn merge_expressions( + expr: &Arc, +) -> Result>> { Ok(expr .state_fields()? .iter() @@ -455,7 +470,9 @@ fn aggregate_expressions( mode: &AggregateMode, ) -> Result>>> { match mode { - AggregateMode::Partial => Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect()), + AggregateMode::Partial => { + Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect()) + } // in this mode, we build the merge expressions of the aggregation AggregateMode::Final => Ok(aggr_expr .iter() @@ -479,8 +496,8 @@ async fn compute_hash_aggregate( aggr_expr: Vec>, mut input: SendableRecordBatchStream, ) -> ArrowResult { - let mut accumulators = - create_accumulators(&aggr_expr).map_err(DataFusionError::into_arrow_external_error)?; + let mut accumulators = create_accumulators(&aggr_expr) + .map_err(DataFusionError::into_arrow_external_error)?; let expressions = aggregate_expressions(&aggr_expr, &mode) .map_err(DataFusionError::into_arrow_external_error)?; @@ -513,7 +530,8 @@ impl HashAggregateStream { let schema_clone = schema.clone(); tokio::spawn(async move { - let result = compute_hash_aggregate(mode, schema_clone, aggr_expr, input).await; + let result = + compute_hash_aggregate(mode, schema_clone, aggr_expr, input).await; tx.send(result) }); @@ -564,7 +582,10 @@ fn aggregate_batch( impl Stream for HashAggregateStream { type Item = ArrowResult; - fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { if self.finished { return Poll::Ready(None); } @@ -625,7 +646,9 @@ fn create_batch_from_map( // 2. let mut groups = (0..num_group_expr) .map(|i| match &k[i] { - GroupByScalar::Int8(n) => Arc::new(Int8Array::from(vec![*n])) as ArrayRef, + GroupByScalar::Int8(n) => { + Arc::new(Int8Array::from(vec![*n])) as ArrayRef + } GroupByScalar::Int16(n) => Arc::new(Int16Array::from(vec![*n])), GroupByScalar::Int32(n) => Arc::new(Int32Array::from(vec![*n])), GroupByScalar::Int64(n) => Arc::new(Int64Array::from(vec![*n])), @@ -658,7 +681,9 @@ fn create_batch_from_map( Ok(batch) } -fn create_accumulators(aggr_expr: &Vec>) -> Result { +fn create_accumulators( + aggr_expr: &Vec>, +) -> Result { aggr_expr .iter() .map(|expr| expr.create_accumulator()) @@ -678,8 +703,9 @@ fn finalize_aggregation( .iter() .map(|accumulator| accumulator.state()) .map(|value| { - value - .and_then(|e| Ok(e.iter().map(|v| v.to_array()).collect::>())) + value.and_then(|e| { + Ok(e.iter().map(|v| v.to_array()).collect::>()) + }) }) .collect::>>()?; Ok(a.iter().flatten().cloned().collect::>()) @@ -795,7 +821,8 @@ mod tests { /// build the aggregates on the data from some_data() and check the results async fn check_aggregates(input: Arc) -> Result<()> { - let groups: Vec<(Arc, String)> = vec![(col("a"), "a".to_string())]; + let groups: Vec<(Arc, String)> = + vec![(col("a"), "a".to_string())]; let aggregates: Vec> = vec![Arc::new(Avg::new( col("b"), @@ -943,14 +970,16 @@ mod tests { #[tokio::test] async fn aggregate_source_not_yielding() -> Result<()> { - let input: Arc = Arc::new(TestYieldingExec { yield_first: false }); + let input: Arc = + Arc::new(TestYieldingExec { yield_first: false }); check_aggregates(input).await } #[tokio::test] async fn aggregate_source_with_yielding() -> Result<()> { - let input: Arc = Arc::new(TestYieldingExec { yield_first: true }); + let input: Arc = + Arc::new(TestYieldingExec { yield_first: true }); check_aggregates(input).await } From b28a6d46b62a10b184a320e129524a742a729de5 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Wed, 25 Nov 2020 12:08:03 +0100 Subject: [PATCH 6/7] Reduce to 16 bytes, use Box<[GroupByScalar]> for key --- rust/datafusion/src/physical_plan/group_scalar.rs | 6 +++--- rust/datafusion/src/physical_plan/hash_aggregate.rs | 12 ++++++++---- rust/datafusion/src/physical_plan/hash_join.rs | 9 ++++++--- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/rust/datafusion/src/physical_plan/group_scalar.rs b/rust/datafusion/src/physical_plan/group_scalar.rs index 1543de1df67..8c11a6be65a 100644 --- a/rust/datafusion/src/physical_plan/group_scalar.rs +++ b/rust/datafusion/src/physical_plan/group_scalar.rs @@ -34,7 +34,7 @@ pub(crate) enum GroupByScalar { Int16(i16), Int32(i32), Int64(i64), - Utf8(Box), + Utf8(Box), } impl TryFrom<&ScalarValue> for GroupByScalar { @@ -50,7 +50,7 @@ impl TryFrom<&ScalarValue> for GroupByScalar { ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v), ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v), ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v), - ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(v.clone().into_boxed_str()), + ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())), ScalarValue::Int8(None) | ScalarValue::Int16(None) | ScalarValue::Int32(None) @@ -134,6 +134,6 @@ mod tests { #[test] fn size_of_group_by_scalar() { - assert_eq!(std::mem::size_of::(), 24); + assert_eq!(std::mem::size_of::(), 16); } } diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 85065aac75b..e6cd3037120 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -250,6 +250,8 @@ fn group_aggregate_batch( key.push(GroupByScalar::UInt32(0)); } + let mut key = key.into_boxed_slice(); + // 1.1 construct the key from the group values // 1.2 construct the mapping key if it does not exist // 1.3 add the row' index to `indices` @@ -387,7 +389,7 @@ impl GroupedHashAggregateStream { type AccumulatorSet = Vec>; type Accumulators = - HashMap, (AccumulatorSet, Box>), RandomState>; + HashMap, (AccumulatorSet, Box>), RandomState>; impl Stream for GroupedHashAggregateStream { type Item = ArrowResult; @@ -656,7 +658,9 @@ fn create_batch_from_map( GroupByScalar::UInt16(n) => Arc::new(UInt16Array::from(vec![*n])), GroupByScalar::UInt32(n) => Arc::new(UInt32Array::from(vec![*n])), GroupByScalar::UInt64(n) => Arc::new(UInt64Array::from(vec![*n])), - GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![&**str])), + GroupByScalar::Utf8(str) => { + Arc::new(StringArray::from(vec![&***str])) + } }) .collect::>(); @@ -724,7 +728,7 @@ fn finalize_aggregation( pub(crate) fn create_key( group_by_keys: &[ArrayRef], row: usize, - vec: &mut Vec, + vec: &mut Box<[GroupByScalar]>, ) -> Result<()> { for i in 0..group_by_keys.len() { let col = &group_by_keys[i]; @@ -763,7 +767,7 @@ pub(crate) fn create_key( } DataType::Utf8 => { let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Utf8(array.value(row).into()) + vec[i] = GroupByScalar::Utf8(Box::new(array.value(row).into())) } _ => { // This is internal because we should have caught this before. diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index 8f86df2c680..7bee9fe7c0c 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -53,7 +53,7 @@ type JoinIndex = Option<(usize, usize)>; // Maps ["on" value] -> [list of indices with this key's value] // E.g. [1, 2] -> [(0, 3), (1, 6), (0, 8)] indicates that (column1, column2) = [1, 2] is true // for rows 3 and 8 from batch 0 and row 6 from batch 1. -type JoinHashMap = HashMap, Vec>; +type JoinHashMap = HashMap, Vec>; type JoinLeftData = (JoinHashMap, Vec); /// join execution plan executes partitions in parallel and combines them into a set of @@ -210,6 +210,8 @@ fn update_hash( key.push(GroupByScalar::UInt32(0)); } + let mut key = key.into_boxed_slice(); + // update the hash map for row in 0..batch.num_rows() { create_key(&keys_values, row, &mut key)?; @@ -364,8 +366,9 @@ fn build_join_indexes( JoinType::Inner => { // inner => key intersection // unfortunately rust does not support intersection of map keys :( - let left_set: HashSet> = left.keys().cloned().collect(); - let left_right: HashSet> = right.keys().cloned().collect(); + let left_set: HashSet> = left.keys().cloned().collect(); + let left_right: HashSet> = + right.keys().cloned().collect(); let inner = left_set.intersection(&left_right); let mut indexes = Vec::new(); // unknown a prior size From 9c8b8f06d209f667930d1d30b95ede2ad4cf73ef Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Wed, 25 Nov 2020 16:51:04 +0100 Subject: [PATCH 7/7] Remove unnecessary Box around indices --- rust/datafusion/src/physical_plan/hash_aggregate.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index e6cd3037120..a3651abaa49 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -266,8 +266,7 @@ fn group_aggregate_batch( let accumulator_set = create_accumulators(aggr_expr) .map_err(DataFusionError::into_arrow_external_error)?; - accumulators - .insert(key.clone(), (accumulator_set, Box::new(vec![row as u32]))); + accumulators.insert(key.clone(), (accumulator_set, vec![row as u32])); } // 1.3 Some((_, v)) => v.push(row as u32), @@ -296,7 +295,7 @@ fn group_aggregate_batch( // 2.3 compute::take( array, - &UInt32Array::from(*indices.clone()), + &UInt32Array::from(indices.clone()), None, // None: no index check ) .unwrap() @@ -389,7 +388,7 @@ impl GroupedHashAggregateStream { type AccumulatorSet = Vec>; type Accumulators = - HashMap, (AccumulatorSet, Box>), RandomState>; + HashMap, (AccumulatorSet, Vec), RandomState>; impl Stream for GroupedHashAggregateStream { type Item = ArrowResult;