diff --git a/Cargo.toml b/Cargo.toml
index fd862e1332930..fae33a14af012 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,6 +39,7 @@ arrow-flight = { version = "45.0.0", features = ["flight-sql-experimental"] }
 arrow-schema = { version = "45.0.0", default-features = false }
 parquet = { version = "45.0.0", features = ["arrow", "async", "object_store"] }
 sqlparser = { version = "0.36.1", features = ["visitor"] }
+zerocopy = "0.6.1"
 
 [profile.release]
 codegen-units = 1
diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 5a32349c65e95..70e25ed8c7fb1 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -380,6 +380,10 @@ config_namespace! {
         /// repartitioning to increase parallelism to leverage more CPU cores
         pub enable_round_robin_repartition: bool, default = true
 
+        /// When set to true, the optimizer will attempt to perform limit operations
+        /// during aggregations, if possible
+        pub enable_topk_aggregation: bool, default = true
+
         /// When set to true, the optimizer will insert filters before a join between
         /// a nullable and non-nullable column to filter out nulls on the nullable side. This
         /// filter can add additional overhead when the file format does not fully support
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index 913d3c84beefb..64c7c7a518a8a 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -84,7 +84,8 @@ parking_lot = "0.12"
 parquet = { workspace = true }
 percent-encoding = "2.2.0"
 pin-project-lite = "^0.2.7"
-rand = "0.8"
+rand = { version = "0.8", features = ["small_rng"] }
+rand_distr = "0.4.3"
 smallvec = { version = "1.6", features = ["union"] }
 sqlparser = { workspace = true }
 tempfile = "3"
@@ -93,6 +94,7 @@ tokio-util = { version = "0.7.4", features = ["io"] }
 url = "2.2"
 uuid = { version = "1.0", features = ["v4"] }
 xz2 = { version = "0.1", optional = true }
+zerocopy = { workspace = true }
 zstd = { version = "0.12", optional = true, default-features = false }
 
@@ -107,6 +109,8 @@ env_logger = "0.10"
 half = "2.2.1"
 postgres-protocol = "0.6.4"
 postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] }
+ptree = "0.4.0"
+rand_distr = "0.4.3"
 regex = "1.5.4"
 rstest = "0.18.0"
 rust_decimal = { version = "1.27.0", features = ["tokio-pg"] }
@@ -159,3 +163,7 @@ name = "sql_query_with_io"
 [[bench]]
 harness = false
 name = "sort"
+
+[[bench]]
+harness = false
+name = "topk_aggregate"
diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs
new file mode 100644
index 0000000000000..2dd0d292c9314
--- /dev/null
+++ b/datafusion/core/benches/topk_aggregate.rs
@@ -0,0 +1,239 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
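+
+// Compares the default hash aggregation against the limited "top k"
+// aggregation path on the same query, over both time-series-like and
+// worst-case (monotonically ascending) inputs.
+// Run with: `cargo bench --bench topk_aggregate`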
+
+use arrow::util::pretty::pretty_format_batches;
+use arrow::{datatypes::Schema, record_batch::RecordBatch};
+use arrow_array::builder::{Int64Builder, StringBuilder};
+use arrow_schema::{DataType, Field, SchemaRef};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
+use datafusion::prelude::SessionContext;
+use datafusion::{datasource::MemTable, error::Result};
+use datafusion_common::DataFusionError;
+use datafusion_execution::config::SessionConfig;
+use datafusion_execution::TaskContext;
+use rand_distr::Distribution;
+use rand_distr::{Normal, Pareto};
+use std::sync::Arc;
+use tokio::runtime::Runtime;
+
+async fn create_context(
+    limit: usize,
+    partition_cnt: i32,
+    sample_cnt: i32,
+    asc: bool,
+    use_topk: bool,
+) -> Result<(Arc<dyn ExecutionPlan>, Arc<TaskContext>)> {
+    let (schema, parts) = make_data(partition_cnt, sample_cnt, asc).unwrap();
+    let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap());
+
+    // Create the DataFrame
+    let mut cfg = SessionConfig::new();
+    let opts = cfg.options_mut();
+    opts.optimizer.enable_topk_aggregation = use_topk;
+    let ctx = SessionContext::with_config(cfg);
+    let _ = ctx.register_table("traces", mem_table)?;
+    let sql = format!("select trace_id, max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};");
+    let df = ctx.sql(sql.as_str()).await?;
+    let physical_plan = df.create_physical_plan().await?;
+    let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string();
+    assert_eq!(
+        actual_phys_plan.contains(&format!("lim=[{limit}]")),
+        use_topk
+    );
+
+    Ok((physical_plan, ctx.task_ctx()))
+}
+
+fn run(plan: Arc<dyn ExecutionPlan>, ctx: Arc<TaskContext>, asc: bool) {
+    let rt = Runtime::new().unwrap();
+    criterion::black_box(
+        rt.block_on(async { aggregate(plan.clone(), ctx.clone(), asc).await }),
+    )
+    .unwrap();
+}
+
+async fn aggregate(
+    plan: Arc<dyn ExecutionPlan>,
+    ctx: Arc<TaskContext>,
+    asc: bool,
+) -> Result<()> {
+    let batches = collect(plan, ctx).await?;
+    assert_eq!(batches.len(), 1);
+    let batch = batches.first().unwrap();
+    assert_eq!(batch.num_rows(), 10);
+
+    let actual = format!("{}", pretty_format_batches(&batches)?);
+    let expected_asc = r#"
++----------------------------------+--------------------------+
+| trace_id                         | MAX(traces.timestamp_ms) |
++----------------------------------+--------------------------+
+| 233dda9b90a120947515b7d62f46683b | 16909000999999           |
+| dd5dee780518eac66c2a4bef4bb64536 | 16909000999998           |
+| 67b7d474236429b889b53259dd23c9c6 | 16909000999997           |
+| e93942f02f96759e4f0c17700dbe944e | 16909000999996           |
+| 2b0b92a840de408268dc45534ff14b66 | 16909000999995           |
+| 1557b5073068dc4ea1e65b22af727733 | 16909000999994           |
+| 01b8ea6ba186370da76c1634992326dc | 16909000999993           |
+| 0c5c9fb232710c7f618d8d5719cd9acd | 16909000999992           |
+| b36157951cd7d886514cc49c9ff00987 | 16909000999991           |
+| 2a377b56d35464771442a7fb147108ee | 16909000999990           |
++----------------------------------+--------------------------+
+    "#
+    .trim();
+    if asc {
+        assert_eq!(actual.trim(), expected_asc);
+    }
+
+    Ok(())
+}
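+
+// The synthetic workload below models trace data: each partition cycles
+// through ~2000 concurrent trace ids, the number of samples per trace is
+// Pareto-distributed, and (in the non-ascending case) timestamps advance
+// with Normal-distributed steps.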
+fn make_data(
+    partition_cnt: i32,
+    sample_cnt: i32,
+    asc: bool,
+) -> Result<(Arc<Schema>, Vec<Vec<RecordBatch>>), DataFusionError> {
+    use rand::Rng;
+    use rand::SeedableRng;
+
+    // constants observed from trace data
+    let simultaneous_group_cnt = 2000;
+    let fitted_shape = 12f64;
+    let fitted_scale = 5f64;
+    let mean = 0.1;
+    let stddev = 1.1;
+    let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap();
+    let normal = Normal::new(mean, stddev).unwrap();
+    let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);
+
+    // populate data
+    let schema = test_schema();
+    let mut partitions = vec![];
+    let mut cur_time = 16909000000000i64;
+    for _ in 0..partition_cnt {
+        let mut id_builder = StringBuilder::new();
+        let mut ts_builder = Int64Builder::new();
+        let gen_id = |rng: &mut rand::rngs::SmallRng| {
+            rng.gen::<[u8; 16]>()
+                .iter()
+                .map(|b| format!("{:02x}", b))
+                .collect::<String>()
+        };
+        let gen_sample_cnt =
+            |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32;
+        let mut group_ids = (0..simultaneous_group_cnt)
+            .map(|_| gen_id(&mut rng))
+            .collect::<Vec<_>>();
+        let mut group_sample_cnts = (0..simultaneous_group_cnt)
+            .map(|_| gen_sample_cnt(&mut rng))
+            .collect::<Vec<_>>();
+        for _ in 0..sample_cnt {
+            let random_index = rng.gen_range(0..simultaneous_group_cnt);
+            let trace_id = &mut group_ids[random_index];
+            let sample_cnt = &mut group_sample_cnts[random_index];
+            *sample_cnt -= 1;
+            if *sample_cnt == 0 {
+                *trace_id = gen_id(&mut rng);
+                *sample_cnt = gen_sample_cnt(&mut rng);
+            }
+
+            id_builder.append_value(trace_id);
+            ts_builder.append_value(cur_time);
+
+            if asc {
+                cur_time += 1;
+            } else {
+                let samp: f64 = normal.sample(&mut rng);
+                let samp = samp.round();
+                cur_time += samp as i64;
+            }
+        }
+
+        // convert to MemTable
+        let id_col = Arc::new(id_builder.finish());
+        let ts_col = Arc::new(ts_builder.finish());
+        let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?;
+        partitions.push(vec![batch]);
+    }
+    Ok((schema, partitions))
+}
+
+fn test_schema() -> SchemaRef {
+    Arc::new(Schema::new(vec![
+        Field::new("trace_id", DataType::Utf8, false),
+        Field::new("timestamp_ms", DataType::Int64, false),
+    ]))
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let limit = 10;
+    let partitions = 10;
+    let samples = 100_000;
+
+    let rt = Runtime::new().unwrap();
+    let topk_real = rt.block_on(async {
+        create_context(limit, partitions, samples, false, true)
+            .await
+            .unwrap()
+    });
+    let topk_asc = rt.block_on(async {
+        create_context(limit, partitions, samples, true, true)
+            .await
+            .unwrap()
+    });
+    let real = rt.block_on(async {
+        create_context(limit, partitions, samples, false, false)
+            .await
+            .unwrap()
+    });
+    let asc = rt.block_on(async {
+        create_context(limit, partitions, samples, true, false)
+            .await
+            .unwrap()
+    });
+
+    c.bench_function(
+        format!("aggregate {} time-series rows", partitions * samples).as_str(),
+        |b| b.iter(|| run(real.0.clone(), real.1.clone(), false)),
+    );
+
+    c.bench_function(
+        format!("aggregate {} worst-case rows", partitions * samples).as_str(),
+        |b| b.iter(|| run(asc.0.clone(), asc.1.clone(), true)),
+    );
+
+    c.bench_function(
+        format!(
+            "top k={limit} aggregate {} time-series rows",
+            partitions * samples
+        )
+        .as_str(),
+        |b| b.iter(|| run(topk_real.0.clone(), topk_real.1.clone(), false)),
+    );
+
+    c.bench_function(
+        format!(
+            "top k={limit} aggregate {} worst-case rows",
+            partitions * samples
+        )
+        .as_str(),
+        |b| b.iter(|| run(topk_asc.0.clone(), topk_asc.1.clone(), true)),
+    );
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs
index f74d4ea0c9a6f..ad950e26bbfa5 100644
--- a/datafusion/core/src/physical_optimizer/mod.rs
+++ b/datafusion/core/src/physical_optimizer/mod.rs
@@ -33,6 +33,7 @@ pub mod repartition;
 pub mod replace_with_order_preserving_variants;
 pub mod sort_enforcement;
 mod sort_pushdown;
+pub mod topk_aggregation;
 mod utils;
 
 #[cfg(test)]
diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs
index 3f6698c6cf466..d629eb0c8e0ae 100644
--- a/datafusion/core/src/physical_optimizer/optimizer.rs
+++ b/datafusion/core/src/physical_optimizer/optimizer.rs
@@ -28,6 +28,7 @@ use crate::physical_optimizer::join_selection::JoinSelection;
 use crate::physical_optimizer::pipeline_checker::PipelineChecker;
 use crate::physical_optimizer::repartition::Repartition;
 use crate::physical_optimizer::sort_enforcement::EnforceSorting;
+use crate::physical_optimizer::topk_aggregation::TopKAggregation;
 use crate::{error::Result, physical_plan::ExecutionPlan};
 
 /// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which
@@ -101,6 +102,11 @@ impl PhysicalOptimizer {
             // diagnostic error message when this happens. It makes no changes to the
             // given query plan; i.e. it only acts as a final gatekeeping rule.
             Arc::new(PipelineChecker::new()),
+            // The aggregation limiter will try to find situations where the accumulator count
+            // is not tied to the cardinality, i.e. when the output of the aggregation is passed
+            // into an `order by max(x) limit y`. In this case it will copy the limit value down
+            // to the aggregation, allowing it to use only y accumulators.
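+            // For example, in
+            //   `SELECT trace_id, MAX(ts) FROM traces GROUP BY trace_id
+            //    ORDER BY MAX(ts) DESC LIMIT 4`
+            // only 4 groups can ever reach the output, so the aggregation
+            // needs at most 4 buckets.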
+            Arc::new(TopKAggregation::new()),
         ];
 
         Self::with_rules(rules)
diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs
new file mode 100644
index 0000000000000..e1ad46cb69175
--- /dev/null
+++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+    /// Create a new `TopKAggregation`
+    pub fn new() -> Self {
+        Self {}
+    }
+
+    fn transform_agg(
+        aggr: &AggregateExec,
+        order: &PhysicalSortExpr,
+        limit: usize,
+    ) -> Option<Arc<dyn ExecutionPlan>> {
+        // ensure the sort direction matches aggregate function
+        let (field, desc) = aggr.get_minmax_desc()?;
+        if desc != order.options.descending {
+            return None;
+        }
+        let group_key = match aggr.group_expr().expr() {
+            [expr] => expr, // only one group key
+            _ => return None,
+        };
+        match group_key.0.data_type(&aggr.input_schema).ok() {
+            Some(DataType::Utf8) => {} // only String keys for now
+            _ => return None,
+        }
+        if aggr
+            .filter_expr
+            .iter()
+            .fold(false, |acc, cur| acc | cur.is_some())
+        {
+            return None;
+        }
+
+        // ensure the sort is on the same field as the aggregate output
+        let col = order.expr.as_any().downcast_ref::<Column>()?;
+        if col.name() != field.name() {
+            return None;
+        }
+
+        // We found what we want: clone, copy the limit down, and return modified node
+        let mut new_aggr = AggregateExec::try_new(
+            aggr.mode,
+            aggr.group_by.clone(),
+            aggr.aggr_expr.clone(),
+            aggr.filter_expr.clone(),
+            aggr.order_by_expr.clone(),
+            aggr.input.clone(),
+            aggr.input_schema.clone(),
+        )
+        .expect("Unable to copy Aggregate!");
+        new_aggr.limit = Some(limit);
+        Some(Arc::new(new_aggr))
+    }
+
+    fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
+        let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+        // TODO: support sorting on multiple fields
+        let children = sort.children();
+        let child = match children.as_slice() {
+            [child] => child.clone(),
+            _ => return None,
+        };
+        let order = sort.output_ordering()?;
+        let order = match order {
+            [order] => order,
+            _ => return None,
+        };
+        let limit = sort.fetch()?;
+
+        let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+            plan.as_any()
+                .downcast_ref::<CoalesceBatchesExec>()
+                .is_some()
+                || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+                || plan.as_any().downcast_ref::<FilterExec>().is_some()
+            // TODO: whitelist joins that don't increase row count?
+        };
+
+        let mut cardinality_preserved = true;
+        let mut closure = |plan: Arc<dyn ExecutionPlan>| {
+            if !cardinality_preserved {
+                return Ok(Transformed::No(plan));
+            }
+            if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
+                // either we run into an Aggregate and transform it
+                match Self::transform_agg(aggr, order, limit) {
+                    None => cardinality_preserved = false,
+                    Some(plan) => return Ok(Transformed::Yes(plan)),
+                }
+            } else {
+                // or we continue down whitelisted nodes of other types
+                if !is_cardinality_preserving(plan.clone()) {
+                    cardinality_preserved = false;
+                }
+            }
+            Ok(Transformed::No(plan))
+        };
+        let child = transform_down_mut(child, &mut closure).ok()?;
+        let sort = SortExec::new(sort.expr().to_vec(), child)
+            .with_fetch(sort.fetch())
+            .with_preserve_partitioning(sort.preserve_partitioning());
+        Some(Arc::new(sort))
+    }
+}
+
+fn transform_down_mut<F>(
+    me: Arc<dyn ExecutionPlan>,
+    op: &mut F,
+) -> Result<Arc<dyn ExecutionPlan>>
+where
+    F: FnMut(Arc<dyn ExecutionPlan>) -> Result<Transformed<Arc<dyn ExecutionPlan>>>,
+{
+    let after_op = op(me)?.into();
+    after_op.map_children(|node| transform_down_mut(node, op))
+}
+
+impl Default for TopKAggregation {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl PhysicalOptimizerRule for TopKAggregation {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        config: &ConfigOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let plan = if config.optimizer.enable_topk_aggregation {
+            plan.transform_down(&|plan| {
+                Ok(
+                    if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) {
+                        Transformed::Yes(plan)
+                    } else {
+                        Transformed::No(plan)
+                    },
+                )
+            })?
+        } else {
+            plan
+        };
+        Ok(plan)
+    }
+
+    fn name(&self) -> &str {
+        "LimitAggregation"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
+
+// TODO: tests
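+
+// The rewrite this rule performs, using the plans from the sqllogictests:
+//
+//   SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+//     AggregateExec: gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)]
+//
+// becomes
+//
+//   SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+//     AggregateExec: gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4]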
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 8338da8ed6777..2c3b2e5d122a3 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -49,11 +49,14 @@ use std::sync::Arc;
 mod group_values;
 mod no_grouping;
 mod order;
+mod priority_map;
 mod row_hash;
 
+use crate::physical_plan::aggregates::priority_map::GroupedTopKAggregateStream;
 pub use datafusion_expr::AggregateFunction;
 use datafusion_physical_expr::aggregate::is_order_sensitive;
 pub use datafusion_physical_expr::expressions::create_aggregate_expr;
+use datafusion_physical_expr::expressions::{Max, Min};
 use datafusion_physical_expr::utils::{
     get_finer_ordering, ordering_satisfy_requirement_concrete,
 };
@@ -228,14 +231,16 @@ impl PartialEq for PhysicalGroupBy {
 
 enum StreamType {
     AggregateStream(AggregateStream),
-    GroupedHashAggregateStream(GroupedHashAggregateStream),
+    GroupedHash(GroupedHashAggregateStream),
+    GroupedPriorityQueue(GroupedTopKAggregateStream),
 }
 
 impl From<StreamType> for SendableRecordBatchStream {
     fn from(stream: StreamType) -> Self {
         match stream {
             StreamType::AggregateStream(stream) => Box::pin(stream),
-            StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream),
+            StreamType::GroupedHash(stream) => Box::pin(stream),
+            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
         }
     }
 }
@@ -265,6 +270,8 @@ pub struct AggregateExec {
     pub(crate) filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
     /// (ORDER BY clause) expression for each aggregate expression
     pub(crate) order_by_expr: Vec<Option<LexOrdering>>,
+    /// Set if the output of this aggregation is truncated by an upstream sort/limit clause
+    pub(crate) limit: Option<usize>,
     /// Input plan, could be a partial aggregate or the input to the aggregate
     pub(crate) input: Arc<dyn ExecutionPlan>,
     /// Schema after the aggregate is applied
@@ -669,6 +676,7 @@ impl AggregateExec {
             metrics: ExecutionPlanMetricsSet::new(),
             aggregation_ordering,
             required_input_ordering,
+            limit: None,
         })
     }
 
@@ -717,14 +725,38 @@ impl AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<StreamType> {
+        // no group by at all
         if self.group_by.expr.is_empty() {
-            Ok(StreamType::AggregateStream(AggregateStream::new(
+            return Ok(StreamType::AggregateStream(AggregateStream::new(
                 self, context, partition,
-            )?))
+            )?));
+        }
+
+        // grouping by an expression that has a sort/limit upstream
+        if let Some(limit) = self.limit {
+            return Ok(StreamType::GroupedPriorityQueue(
+                GroupedTopKAggregateStream::new(self, context, partition, limit)?,
+            ));
+        }
+
+        // grouping by something else and we need to just materialize all results
+        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
+            self, context, partition,
+        )?))
+    }
+
+    /// Finds the DataType and SortDirection for this Aggregate, if there is one
+    pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
+        let agg_expr = match self.aggr_expr.as_slice() {
+            [expr] => expr,
+            _ => return None,
+        };
+        if let Some(max) = agg_expr.as_any().downcast_ref::<Max>() {
+            Some((max.field().ok()?, true))
+        } else if let Some(min) = agg_expr.as_any().downcast_ref::<Min>() {
+            Some((min.field().ok()?, false))
         } else {
-            Ok(StreamType::GroupedHashAggregateStream(
-                GroupedHashAggregateStream::new(self, context, partition)?,
-            ))
+            None
         }
     }
 }
@@ -793,6 +825,9 @@ impl DisplayAs for AggregateExec {
                     .map(|agg| agg.name().to_string())
                     .collect();
                 write!(f, ", aggr=[{}]", a.join(", "))?;
+                if let Some(limit) = self.limit {
+                    write!(f, ", lim=[{limit}]")?;
+                }
 
                 if let Some(aggregation_ordering) = &self.aggregation_ordering {
                     write!(f, ", ordering_mode={:?}", aggregation_ordering.mode)?;
@@ -900,7 +935,7 @@ impl ExecutionPlan for AggregateExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(AggregateExec::try_new(
+        let mut me = AggregateExec::try_new(
             self.mode,
             self.group_by.clone(),
             self.aggr_expr.clone(),
@@ -908,7 +943,9 @@ impl ExecutionPlan for AggregateExec {
             self.order_by_expr.clone(),
             children[0].clone(),
             self.input_schema.clone(),
-        )?))
+        )?;
+        me.limit = self.limit;
+        Ok(Arc::new(me))
     }
 
     fn execute(
@@ -1115,7 +1152,7 @@ fn evaluate(
 }
 
 /// Evaluates expressions against a record batch.
-fn evaluate_many(
+pub fn evaluate_many(
     expr: &[Vec<Arc<dyn PhysicalExpr>>],
     batch: &RecordBatch,
 ) -> Result<Vec<Vec<ArrayRef>>> {
@@ -1138,7 +1175,17 @@ fn evaluate_optional(
         .collect::<Result<Vec<_>>>()
 }
 
-fn evaluate_group_by(
+/// Evaluate a group by expression against a `RecordBatch`
+///
+/// Arguments:
+/// `group_by`: the expression to evaluate
+/// `batch`: the `RecordBatch` to evaluate against
+///
+/// Returns: A Vec of Vecs of Array of results
+/// The outer Vec appears to be for grouping sets
+/// The inner Vec contains the results per expression
+/// The inner-inner Array contains the results per row
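+/// For example, `GROUP BY a` produces a single grouping set with a single
+/// expression, so `result[0][0]` holds the evaluated `a` value for every row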
+pub fn evaluate_group_by(
     group_by: &PhysicalGroupBy,
     batch: &RecordBatch,
 ) -> Result<Vec<Vec<ArrayRef>>> {
@@ -1798,10 +1845,10 @@ mod tests {
                     assert!(matches!(stream, StreamType::AggregateStream(_)));
                 }
                 1 => {
-                    assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
+                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                 }
                 2 => {
-                    assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
+                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                 }
                 _ => panic!("Unknown version: {version}"),
             }
diff --git a/datafusion/core/src/physical_plan/aggregates/priority_map.rs b/datafusion/core/src/physical_plan/aggregates/priority_map.rs
new file mode 100644
index 0000000000000..2b3fd3c7c5293
--- /dev/null
+++ b/datafusion/core/src/physical_plan/aggregates/priority_map.rs
@@ -0,0 +1,1106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to a fixed number
+
+use crate::physical_plan::aggregates::{
+    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, RecordBatch,
+    StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use itertools::Itertools;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::fmt::{Debug, Formatter};
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+    partition: usize,
+    row_count: usize,
+    started: bool,
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+        limit: usize,
+    ) -> Result<Self> {
+        let agg_schema = Arc::clone(&agg.schema);
+        let group_by = agg.group_by.clone();
+
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+        let aggregate_arguments =
+            aggregate_expressions(&agg.aggr_expr, &agg.mode, group_by.expr.len())?;
+
+        let (val_field, descending) = agg
+            .get_minmax_desc()
+            .ok_or_else(|| DataFusionError::Execution("Min/max required".to_string()))?;
+
+        let vt = val_field.data_type().clone();
+        let ag = new_group_values(limit, descending, vt)?;
+
+        Ok(GroupedTopKAggregateStream {
+            partition,
+            started: false,
+            row_count: 0,
+            schema: agg_schema,
+            input,
+            aggregate_arguments,
+            group_by,
+            aggregator: ag,
+        })
+    }
+}
+
+pub fn new_group_values(
+    limit: usize,
+    desc: bool,
+    vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+    macro_rules! downcast_helper {
+        ($vt:ty, $d:ident) => {
+            return Ok(Box::new(PrimitiveAggregator::<$vt>::new(
+                limit,
+                limit * 10,
+                desc,
+            )))
+        };
+    }
+
+    downcast_primitive! {
+        vt => (downcast_helper, vt),
+        _ => {}
+    }
+
+    Err(DataFusionError::Execution(format!(
+        "Can't group type: {vt:?}"
+    )))
+}
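+
+// `downcast_primitive!` matches the runtime `DataType` and invokes the helper
+// with the corresponding arrow primitive type, so each primitive type gets its
+// own monomorphized `PrimitiveAggregator`.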
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+pub trait LimitedAggregator: Send {
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+    fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+    fn is_empty(&self) -> bool;
+}
+
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T> ValueType for T where T: ArrowNativeTypeOp + Clone {}
+
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T> KeyType for T where T: Clone + Eq + Hash {}
+
+struct PrimitiveAggregator<VAL: ArrowPrimitiveType>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    priority_map: PriorityMap<Option<String>, VAL::Native>,
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveAggregator<VAL>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    pub fn new(limit: usize, capacity: usize, descending: bool) -> Self {
+        Self {
+            priority_map: PriorityMap::new(limit, capacity, descending),
+        }
+    }
+}
+
+unsafe impl<VAL: ArrowPrimitiveType> Send for PrimitiveAggregator<VAL> where
+    <VAL as ArrowPrimitiveType>::Native: Clone
+{
+}
+
+impl<VAL: ArrowPrimitiveType> LimitedAggregator for PrimitiveAggregator<VAL>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
+        let ids = ids.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
+            DataFusionError::Execution("Expected StringArray".to_string())
+        })?;
+        let vals = vals.as_primitive::<VAL>();
+        let null_count = vals.null_count();
+        let desc = self.priority_map.desc;
+        for row_idx in 0..ids.len() {
+            if null_count > 0 && vals.is_null(row_idx) {
+                continue;
+            }
+            let val = vals.value(row_idx);
+            let id = if ids.is_null(row_idx) {
+                None
+            } else {
+                // Check goes here, because it is generalizable between str/String and Row/OwnedRow
+                let id = ids.value(row_idx);
+                if self.priority_map.is_full() {
+                    if let Some(worst) = self.priority_map.worst_val() {
+                        if desc {
+                            if val < *worst {
+                                continue;
+                            }
+                        } else if val > *worst {
+                            continue;
+                        }
+                    }
+                }
+                Some(id.to_string())
+            };
+
+            self.priority_map.insert(id, val)?;
+        }
+        Ok(())
+    }
+
+    fn emit(&mut self) -> Result<Vec<ArrayRef>> {
+        let (keys, vals): (Vec<_>, Vec<_>) =
+            self.priority_map.drain().into_iter().unzip();
+        let keys = Arc::new(StringArray::from(keys));
+        let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals));
+        Ok(vec![keys, vals])
+    }
+
+    fn is_empty(&self) -> bool {
+        self.priority_map.is_empty()
+    }
+}
+
+impl Stream for GroupedTopKAggregateStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
+            match res {
+                // got a batch, convert to rows and append to our TreeMap
+                Some(Ok(batch)) => {
+                    self.started = true;
+                    trace!(
+                        "partition {} has {} rows and got batch with {} rows",
+                        self.partition,
+                        self.row_count,
+                        batch.num_rows()
+                    );
+                    if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 {
+                        print_batches(&[batch.clone()])?;
+                    }
+                    self.row_count += batch.num_rows();
+                    let batches = &[batch];
+                    let group_by_values =
+                        evaluate_group_by(&self.group_by, batches.first().unwrap())?;
+                    let group_by_values =
+                        group_by_values.into_iter().last().expect("values");
+                    let group_by_values =
+                        group_by_values.into_iter().last().expect("values");
+                    let input_values = evaluate_many(
+                        &self.aggregate_arguments,
+                        batches.first().unwrap(),
+                    )?;
+                    let input_values = match input_values.as_slice() {
+                        [] => {
+                            Err(DataFusionError::Execution("vals required".to_string()))?
+                        }
+                        [vals] => vals,
+                        _ => {
+                            Err(DataFusionError::Execution("1 val required".to_string()))?
+                        }
+                    };
+                    let input_values = match input_values.as_slice() {
+                        [] => {
+                            Err(DataFusionError::Execution("vals required".to_string()))?
+                        }
+                        [vals] => vals,
+                        _ => {
+                            Err(DataFusionError::Execution("1 val required".to_string()))?
+                        }
+                    }
+                    .clone();
+
+                    // iterate over each column of group_by values
+                    (*self.aggregator).intern(group_by_values, input_values)?;
+                }
+                // inner is done, emit all rows and switch to producing output
+                None => {
+                    if self.aggregator.is_empty() {
+                        trace!("partition {} emit None", self.partition);
+                        return Poll::Ready(None);
+                    }
+                    let cols = self.aggregator.emit()?;
+                    let batch = RecordBatch::try_new(self.schema.clone(), cols)?;
+                    trace!(
+                        "partition {} emit batch with {} rows",
+                        self.partition,
+                        batch.num_rows()
+                    );
+                    if log::log_enabled!(Level::Trace) {
+                        print_batches(&[batch.clone()])?;
+                    }
+                    return Poll::Ready(Some(Ok(batch)));
+                }
+                // inner had error, return to caller
+                Some(Err(e)) => {
+                    return Poll::Ready(Some(Err(e)));
+                }
+            }
+        }
+        Poll::Pending
+    }
+}
+
+/// A dual data structure consisting of a bi-directionally linked Map & Heap
+///
+/// The implementation is optimized for performance because `insert()` will be called on billions of
+/// rows. Because traversing between the map & heap will happen frequently, it is important to
+/// be highly optimized.
+///
+/// In order to quickly traverse from heap to map, we use the unsafe raw indexes that `RawTable`
+/// exposes to us to avoid needing to find buckets based on their hash.
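+///
+/// Concretely: each `MapItem` holds a raw pointer (`hi`) to its `HeapItem`,
+/// and each `HeapItem` holds the `RawTable` bucket index (`buk_idx`) of its
+/// `MapItem`, so either side can reach the other in O(1).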
+pub struct PriorityMap<ID: KeyType, VAL: ValueType> {
+    limit: usize,
+    desc: bool,
+    rnd: RandomState,
+    id_to_hi: RawTable<MapItem<ID, VAL>>,
+    root: Option<*mut HeapItem<VAL>>,
+}
+
+pub struct MapItem<ID, VAL> {
+    hash: u64,
+    pub id: ID,
+    hi: *mut HeapItem<VAL>, // TODO: *mut void
+}
+
+impl<ID, VAL> MapItem<ID, VAL> {
+    pub fn new(hash: u64, id: ID, val: *mut HeapItem<VAL>) -> Self {
+        Self { hash, id, hi: val }
+    }
+}
+
+pub struct HeapItem<VAL> {
+    val: VAL,
+    buk_idx: usize,
+    parent: Option<*mut HeapItem<VAL>>,
+    left: Option<*mut HeapItem<VAL>>,
+    right: Option<*mut HeapItem<VAL>>,
+}
+
+impl<VAL: ValueType> HeapItem<VAL> {
+    pub fn new(val: VAL, buk_idx: usize) -> Self {
+        Self {
+            val,
+            buk_idx,
+            parent: None,
+            left: None,
+            right: None,
+        }
+    }
+
+    #[cfg(test)]
+    pub fn tree_print(&self, builder: &mut ptree::TreeBuilder) {
+        unsafe {
+            let ptext = self
+                .parent
+                .map(|p| (*p).buk_idx.to_string())
+                .unwrap_or("".to_string());
+            let label = format!(
+                "bucket={:?} val={:?} parent={}",
+                self.buk_idx, self.val, ptext
+            );
+            builder.begin_child(label);
+        }
+        for child in [&self.left, &self.right] {
+            if let Some(child) = child {
+                unsafe { (**child).tree_print(builder) }
+            } else {
+                builder.add_empty_child("None".to_string());
+            }
+        }
+        builder.end_child();
+    }
+}
+
+impl<VAL: ValueType> Debug for HeapItem<VAL> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.write_str("bucket=")?;
+        self.buk_idx.fmt(f)?;
+        f.write_str(" val=")?;
+        self.val.fmt(f)?;
+        Ok(())
+    }
+}
+
+impl<VAL: ValueType> Eq for HeapItem<VAL> {}
+
+impl<VAL: ValueType> PartialEq for HeapItem<VAL> {
+    fn eq(&self, other: &Self) -> bool {
+        self.cmp(other) == Ordering::Equal
+    }
+}
+
+impl<VAL: ValueType> PartialOrd for HeapItem<VAL> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<VAL: ValueType> Ord for HeapItem<VAL> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        let res = self.val.compare(other.val);
+        if res != Ordering::Equal {
+            return res;
+        }
+        self.buk_idx.cmp(&other.buk_idx)
+    }
+}
+
+impl<ID: KeyType, VAL: ValueType> PriorityMap<ID, VAL>
+where
+    VAL: PartialEq,
+{
+    pub fn new(limit: usize, capacity: usize, desc: bool) -> Self {
+        Self {
+            limit,
+            desc,
+            rnd: Default::default(),
+            id_to_hi: RawTable::with_capacity(capacity),
+            root: None,
+        }
+    }
+
+    pub fn insert(&mut self, new_id: ID, new_val: VAL) -> Result<()> {
+        let is_full = self.is_full();
+        let desc = self.desc;
+        assert!(self.len() <= self.limit, "Overflow");
+
+        // if we're full, and the new val is worse than all our values, just bail
+        if is_full {
+            let worst_val = self.worst_val().expect("Missing value!");
+            if (!desc && new_val > *worst_val) || (desc && new_val < *worst_val) {
+                return Ok(());
+            }
+        }
+
+        // handle new groups we haven't seen yet
+        let new_hash = self.rnd.hash_one(&new_id);
+        let old_bucket = match self.id_to_hi.find(new_hash, |mi| new_id == mi.id) {
+            None => {
+                // we're full and this is a better value, so remove the worst from the map
+                if is_full {
+                    let worst_hi = self.root.expect("Missing value!");
+                    unsafe {
+                        self.id_to_hi
+                            .erase(self.id_to_hi.bucket((*worst_hi).buk_idx))
+                    };
+                }
+
+                // add the new group to the map
+                let new_hi = Box::into_raw(Box::new(HeapItem::new(new_val, 0)));
+                let mi = MapItem::new(new_hash, new_id, new_hi);
+                let bucket = self.id_to_hi.try_insert_no_grow(new_hash, mi);
+                let bucket = match bucket {
+                    Ok(bucket) => bucket,
+                    Err(new_item) => {
+                        // this should basically never happen, but if it does, we must rebuild
+                        let bucket =
+                            self.id_to_hi.insert(new_hash, new_item, |mi| mi.hash);
+                        unsafe {
+                            for bucket in self.id_to_hi.iter() {
+                                let existing_mi = bucket.as_mut();
+                                let existing_hi = &mut *existing_mi.hi;
+                                existing_hi.buk_idx = self.id_to_hi.bucket_index(&bucket);
+                            }
+                        }
+                        bucket
+                    }
+                };
+                unsafe {
+                    (*new_hi).buk_idx = self.id_to_hi.bucket_index(&bucket);
+                }
+
+                // update heap
+                if let Some(root) = self.root {
+                    let root = unsafe { &mut *root };
+                    if self.is_full() {
+                        // replace top node
+                        self.take_children(new_hi, root);
+                        let old_root = self.root.replace(new_hi);
+                        self.heapify_down(new_hi);
+                        if let Some(old_root) = old_root {
+                            let _old_root = unsafe { Box::from_raw(old_root) };
+                        }
+                    } else {
+                        // append to end of tree
+                        let old = put_child(root, new_hi, tree_path(self.len() - 1));
+                        assert!(old.is_none(), "Overwrote node!");
+                        self.heapify_up(new_hi);
+                    }
+                } else {
+                    // first entry ever
+                    self.root = Some(new_hi);
+                }
+                return Ok(());
+            }
+            Some(bucket) => bucket,
+        };
+
+        // this is a value for an existing group
+        let existing_mi = unsafe { old_bucket.as_mut() };
+        let existing_hi = unsafe { &mut *existing_mi.hi };
+        if (!desc && new_val >= existing_hi.val) || (desc && new_val <= existing_hi.val) {
+            // worse than the existing value _for this group_
+            return Ok(());
+        }
+
+        // update heap
+        existing_hi.val = new_val;
+
+        Ok(())
+    }
+
+    fn take_children(
+        &self,
+        new_parent: *mut HeapItem<VAL>,
+        old_parent: &mut HeapItem<VAL>,
+    ) {
+        unsafe {
+            (*new_parent).left = old_parent.left.take();
+            (*new_parent).right = old_parent.right.take();
+            (*new_parent).left.map(|n| (*n).parent.replace(new_parent));
+            (*new_parent).right.map(|n| (*n).parent.replace(new_parent));
+        }
+    }
+
+    fn swap(&mut self, child_ptr: *mut HeapItem<VAL>, parent_ptr: *mut HeapItem<VAL>) {
+        let child = unsafe { &mut *child_ptr };
+        let parent = unsafe { &mut *parent_ptr };
+
+        if child.parent != Some(parent_ptr) {
+            panic!("Child is not of this parent");
+        }
+
+        // store the grand parents and grand children - they are outside the swap
+        let grand_left = child.left.take();
+        let grand_right = child.right.take();
+        let grand_parent = parent.parent.take();
+
+        // transfer parent's children to child
+        if parent.left == Some(child_ptr) {
+            child.left = Some(parent);
+            child.right = parent.right.take();
+            let _ = unsafe { child.right.map(|n| (*n).parent.replace(child)) };
+        } else if parent.right == Some(child_ptr) {
+            child.right = Some(parent);
+            child.left = parent.left.take();
+            let _ = unsafe { child.left.map(|n| (*n).parent.replace(child)) };
+        } else {
+            panic!("Child is illegitimate");
+        }
+        parent.parent = Some(child_ptr);
+
+        // transfer child's children to parent
+        parent.left = grand_left;
+        parent.right = grand_right;
+        let _ = unsafe { grand_left.map(|n| (*n).parent.replace(parent)) };
+        let _ = unsafe { grand_right.map(|n| (*n).parent.replace(parent)) };
+
+        // make the child of the grandparent
+        if let Some(gp_ptr) = grand_parent {
+            let gp = unsafe { &mut *gp_ptr };
+            if gp.left == Some(parent_ptr) {
+                gp.left = Some(child);
+            } else if gp.right == Some(parent_ptr) {
+                gp.right = Some(child);
+            } else {
+                panic!("Parent is illegitimate");
+            }
+            assert_eq!(child.parent.replace(gp_ptr), Some(parent_ptr));
+        } else {
+            let _ = child.parent.take();
+            self.root = Some(child);
+        }
+    }
+
+    fn heapify_up(&mut self, node_ptr: *mut HeapItem<VAL>) {
+        let node = unsafe { &mut *node_ptr };
+        let parent_ptr = match node.parent {
+            None => return,
+            Some(parent) => parent,
+        };
+        let parent = unsafe { &mut *parent_ptr };
+        if !self.desc && node.val <= parent.val {
+            return;
+        }
+        if self.desc && node.val >= parent.val {
+            return;
+        }
+
+        self.swap(node, parent_ptr);
+        self.heapify_up(parent_ptr);
+    }
+
+    fn heapify_down(&mut self, node: *mut HeapItem<VAL>) {
+        unsafe {
+            let mut best_node = node;
+            if let Some(child) = (*node).left {
+                if !self.desc && (*child).val > (*best_node).val {
+                    best_node = child;
+                }
+                if self.desc && (*child).val < (*best_node).val {
+                    best_node = child;
+                }
+            }
+            if let Some(child) = (*node).right {
+                if !self.desc && (*child).val > (*best_node).val {
+                    best_node = child;
+                }
+                if self.desc && (*child).val < (*best_node).val {
+                    best_node = child;
+                }
+            }
+            if node != best_node {
+                self.swap(best_node, node);
+                self.heapify_down(node);
+            }
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.id_to_hi.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.id_to_hi.is_empty()
+    }
+
+    pub fn is_full(&self) -> bool {
+        self.len() >= self.limit
+    }
+
+    pub fn worst_val(&mut self) -> Option<&VAL> {
+        self.root.map(|hi| unsafe { &(*hi).val })
+    }
+
+    pub fn drain(&mut self) -> Vec<(ID, VAL)> {
+        // TODO: drain heap to sort
+        let tups: Vec<_> = unsafe {
+            self.id_to_hi
+                .drain()
+                .map(|mi| {
+                    let val = (*mi.hi).val.clone();
+                    // TODO: free
+                    (mi.id, val)
+                })
+                .collect()
+        };
+        let mut tups: Vec<_> = tups
+            .into_iter()
+            .sorted_by(|a, b| a.1.compare(b.1))
+            .collect();
+        if self.desc {
+            tups.reverse();
+        }
+        tups
+    }
+
+    #[cfg(test)]
+    pub fn tree_print(&self) -> String {
+        let mut builder = if let Some(root) = &self.root {
+            let mut builder = ptree::TreeBuilder::new("BinaryHeap".to_string());
+            unsafe { (**root).tree_print(&mut builder) };
+            builder
+        } else {
+            ptree::TreeBuilder::new("Empty BinaryHeap".to_string())
+        };
+        let mut actual = Vec::new();
+        ptree::write_tree(&builder.build(), &mut actual).unwrap();
+        String::from_utf8(actual).unwrap()
+    }
+}
+
+pub fn put_child<VAL: ValueType>(
+    node_ptr: *mut HeapItem<VAL>,
+    new_child: *mut HeapItem<VAL>,
+    mut path: Vec<bool>,
+) -> Option<*mut HeapItem<VAL>> {
+    let dir = path.pop().expect("empty path");
+    if path.is_empty() {
+        let old_parent = unsafe { (*new_child).parent.replace(node_ptr) };
+        assert!(old_parent.is_none(), "Replaced parent!");
+        if dir {
+            unsafe { (*node_ptr).right.replace(new_child) }
+        } else {
+            unsafe { (*node_ptr).left.replace(new_child) }
+        }
+    } else if dir {
+        unsafe { put_child((*node_ptr).right.unwrap(), new_child, path) }
+    } else {
+        unsafe { put_child((*node_ptr).left.unwrap(), new_child, path) }
+    }
+}
+
+fn tree_path(mut idx: usize) -> Vec<bool> {
+    let mut path = vec![];
+    while idx != 0 {
+        path.push(idx % 2 == 0);
+        idx = (idx - 1) / 2;
+    }
+    path.reverse();
+    path
+}
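+
+// For example, heap slot 4 is the right child of slot 1, so
+// `tree_path(4)` yields `vec![false, true]`: from the root, go
+// left (false) to slot 1, then attach on the right (true).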
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::error::Result;
+    use arrow::util::pretty::pretty_format_batches;
+    use arrow_array::types::Int64Type;
+    use arrow_array::Int64Array;
+    use arrow_schema::DataType;
+    use arrow_schema::Field;
+    use arrow_schema::Schema;
+
+    #[test]
+    fn should_swap() -> Result<()> {
+        let mut map = PriorityMap::<String, i64>::new(10, 100, false);
+        let root = Box::into_raw(Box::new(HeapItem::new(1, 1)));
+        let child = Box::into_raw(Box::new(HeapItem::new(2, 2)));
+        unsafe {
+            (*root).left = Some(child);
+            (*child).parent = Some(root);
+        }
+        map.root = Some(root);
+        let actual = map.tree_print();
+        let expected = r#"
+BinaryHeap
+└─ bucket=1 val=1 parent=
+   ├─ bucket=2 val=2 parent=1
+   │  ├─ None
+   │  └─ None
+   └─ None
+        "#
+        .trim();
+        assert_eq!(actual.trim(), expected);
+
+        // exercise
+        map.swap(child, root);
+
+        // assert
+        let actual = map.tree_print();
+        let expected = r#"
+BinaryHeap
+└─ bucket=2 val=2 parent=
+   ├─ bucket=1 val=1 parent=2
+   │  ├─ None
+   │  └─ None
+   └─ None
+        "#
+        .trim();
+        assert_eq!(actual.trim(), expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_swap_grandchildren() -> Result<()> {
+        let mut map = PriorityMap::<String, i64>::new(10, 100, false);
+        let root = Box::into_raw(Box::new(HeapItem::new(1, 1)));
+        let l = Box::into_raw(Box::new(HeapItem::new(2, 2)));
+        let r = Box::into_raw(Box::new(HeapItem::new(5, 5)));
+        let ll = Box::into_raw(Box::new(HeapItem::new(3, 3)));
+        let lr = Box::into_raw(Box::new(HeapItem::new(4, 4)));
+        unsafe {
+            (*root).left = Some(l);
+            (*l).parent = Some(root);
+
+            (*root).right = Some(r);
+            (*r).parent = Some(root);
+
+            (*l).left = Some(ll);
+            (*ll).parent = Some(l);
+
+            (*l).right = Some(lr);
+            (*lr).parent = Some(l);
+        }
+        map.root = Some(root);
+        let actual = map.tree_print();
+        let expected = r#"
+BinaryHeap
+└─ bucket=1 val=1 parent=
+   ├─ bucket=2 val=2 parent=1
+   │  ├─ bucket=3 val=3 parent=2
+   │  │  ├─ None
+   │  │  └─ None
+   │  └─ bucket=4 val=4 parent=2
+   │     ├─ None
+   │     └─ None
+   └─ bucket=5 val=5 parent=1
+      ├─ None
+      └─ None
+        "#
+        .trim();
+        assert_eq!(actual.trim(), expected);
+
+        // exercise
+        map.swap(l, root);
+
+        // assert
+        let actual = map.tree_print();
+        let expected = r#"
+BinaryHeap
+└─ bucket=2 val=2 parent=
+   ├─ bucket=1 val=1 parent=2
+   │  ├─ bucket=3 val=3 parent=1
+   │  │  ├─ None
+   │  │  └─ None
+   │  └─ bucket=4 val=4 parent=1
+   │     ├─ None
+   │     └─ None
+   └─ bucket=5 val=5 parent=2
+      ├─ None
+      └─ None
+        "#
+        .trim();
+        assert_eq!(actual.trim(), expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_append() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(1, 10, false);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_ignore_higher_group() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(1, 10, false);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_ignore_lower_group() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(1, 10, true);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 2        | 2            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_ignore_higher_same_group() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(2, 10, false);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_ignore_lower_same_group() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(2, 10, true);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 2            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_accept_lower_group() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(1, 10, false);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_accept_higher_group() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(1, 10, true);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 2        | 2            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_accept_lower_for_group() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(2, 10, false);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_accept_higher_for_group() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(2, 10, true);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 2            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_handle_null_ids() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), None, None]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(2, 10, true);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+|          | 3            |
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_ignore_null_vals() -> Result<()> {
+        let ids: ArrayRef =
+            Arc::new(StringArray::from(vec![Some("1"), Some("1"), Some("3")]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(2, 10, false);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
+| 3        | 3            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_retain_state_after_resize() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3", "4", "5"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
+        let mut agg = PrimitiveAggregator::<Int64Type>::new(5, 3, false);
+        agg.intern(ids, vals)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema(), cols)?;
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
+| 2        | 2            |
+| 3        | 3            |
+| 4        | 4            |
+| 5        | 5            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    fn test_schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![
+            Field::new("trace_id", DataType::Utf8, true),
+            Field::new("timestamp_ms", DataType::Int64, true),
+        ]))
+    }
+}
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index 4613a2e46443e..8e728c3338eaa 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -56,7 +56,7 @@ pub(crate) enum ExecutionState {
 use super::order::GroupOrdering;
 use super::AggregateExec;
 
-/// Hash based Grouping Aggregator
+/// HashTable based Grouping Aggregator
 ///
 /// # Design Goals
 ///
@@ -145,7 +145,7 @@ pub(crate) struct GroupedHashAggregateStream {
     /// accumulator. If present, only those rows for which the filter
     /// evaluate to true should be included in the aggregate results.
     ///
-    /// For example, for an aggregate like `SUM(x FILTER x > 100)`,
+    /// For example, for an aggregate like `SUM(x) FILTER (WHERE x > 100)`,
     /// the filter expression is `x > 100`.
     filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
 
@@ -266,7 +266,7 @@ impl GroupedHashAggregateStream {
 /// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if
 /// that is supported by the aggregate, or a
 /// [`GroupsAccumulatorAdapter`] if not.
-fn create_group_accumulator(
+pub fn create_group_accumulator(
     agg_expr: &Arc<dyn AggregateExpr>,
 ) -> Result<Box<dyn GroupsAccumulator>> {
     if agg_expr.groups_accumulator_supported() {
diff --git a/datafusion/core/tests/sql/order.rs b/datafusion/core/tests/sql/order.rs
index 3981fbaa4d7ab..a400a78fc9146 100644
--- a/datafusion/core/tests/sql/order.rs
+++ b/datafusion/core/tests/sql/order.rs
@@ -48,7 +48,9 @@ async fn sort_with_lots_of_repetition_values() -> Result<()> {
 async fn create_external_table_with_order() -> Result<()> {
     let ctx = SessionContext::new();
     let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS CSV WITH ORDER (a_id ASC) LOCATION 'file://path/to/table';";
-    let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = ctx.state().create_logical_plan(sql).await? else {
+    let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) =
+        ctx.state().create_logical_plan(sql).await?
+    else {
         panic!("Wrong command")
     };
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 72cd10f1953c3..532100f37b316 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -16,6 +16,8 @@
 // under the License.
 
 use super::*;
+use arrow::util::pretty::pretty_format_batches;
+use datafusion::physical_plan::aggregates::AggregateExec;
 use datafusion_common::ScalarValue;
 use tempfile::TempDir;
 
@@ -572,6 +574,79 @@ async fn parallel_query_with_filter() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn parallel_query_with_limit() -> Result<()> {
+    let tmp_dir = TempDir::new()?;
+    let partition_count = 4;
+    let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+    let dataframe = ctx
+        .sql("SELECT c3, max(c2) as max FROM test group by c3 order by max desc limit 2")
+        .await?;
+
+    let actual_logical_plan = format!("{:?}", dataframe.logical_plan());
+    let expected_logical_plan = r#"
+Limit: skip=0, fetch=2
+  Sort: max DESC NULLS FIRST
+    Projection: test.c3, MAX(test.c2) AS max
+      Aggregate: groupBy=[[test.c3]], aggr=[[MAX(test.c2)]]
+        TableScan: test
+    "#
+    .trim();
+    assert_eq!(expected_logical_plan, actual_logical_plan);
+
+    let physical_plan = dataframe.create_physical_plan().await?;
+
+    // TODO: find the GroupedHashAggregateStream node and see if we can assert bucket count
+    finder(physical_plan.clone());
+
+    let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string();
+    let mut expected_physical_plan = r#"
+GlobalLimitExec: skip=0, fetch=2
+  SortPreservingMergeExec: [max@1 DESC], fetch=2
+    SortExec: fetch=2, expr=[max@1 DESC]
+      ProjectionExec: expr=[c3@0 as c3, MAX(test.c2)@1 as max]
+        AggregateExec: mode=FinalPartitioned, gby=[c3@0 as c3], aggr=[MAX(test.c2)]
+          CoalesceBatchesExec: target_batch_size=8192
+            RepartitionExec: partitioning=Hash([c3@0], 8), input_partitions=8
+              AggregateExec: mode=Partial, gby=[c3@1 as c3], aggr=[MAX(test.c2)]
+                RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4
+    "#
+    .trim()
+    .to_string();
+    expected_physical_plan += "\n";
+    expected_physical_plan += actual_phys_plan
+        .lines()
+        .last()
+        .expect("Plan should not be empty");
+    expected_physical_plan += "\n";
+    assert_eq!(actual_phys_plan, expected_physical_plan);
+
+    let batches = collect(physical_plan, ctx.task_ctx()).await?;
+    let actual_rows = format!("{}", pretty_format_batches(batches.as_slice())?);
+    let expected = r#"
++-------+-----+
+| c3    | max |
++-------+-----+
+| true  | 10  |
+| false | 9   |
++-------+-----+
+"#
+    .trim();
+    assert_eq!(expected, actual_rows);
+
+    Ok(())
+}
+
+fn finder(plan: Arc<dyn ExecutionPlan>) {
+    if let Some(_aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
+        println!("Found it!");
+    }
+    for child in &plan.children() {
+        finder(child.clone());
+    }
+}
+
 #[tokio::test]
 async fn boolean_literal() -> Result<()> {
     let results =
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 895432026b483..3cf564f367bab 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -412,7 +412,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
         }) if list.len() == 1
             && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
         {
-            let Expr::ScalarSubquery(subquery) = list.remove(0) else { unreachable!() };
+            let Expr::ScalarSubquery(subquery) = list.remove(0) else {
+                unreachable!()
+            };
             Expr::InSubquery(InSubquery::new(expr, subquery, negated))
         }
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index e881acf5755b1..6e2b99057e401 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2261,7 +2261,7 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
 4
 5
 
-# bool aggregtion
+# bool aggregation
 statement ok
 CREATE TABLE value_bool(x boolean, g int) AS VALUES (NULL, 0), (false, 0), (true, 0), (false, 1), (true, 2), (NULL, 3);
 
@@ -2291,7 +2291,131 @@ false
 true
 NULL
 
+# TopK aggregation
+statement ok
+CREATE TABLE traces(trace_id varchar, timestamp bigint) AS VALUES
+(NULL, 0),
+('a', NULL),
+('a', 1),
+('b', 0),
+('c', 1),
+('c', 2),
+('b', 3);
+
+statement ok
+set datafusion.optimizer.enable_topk_aggregation = false;
+
+query TT
+explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
+----
+logical_plan
+Limit: skip=0, fetch=4
+--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4
+----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]]
+------TableScan: traces projection=[trace_id, timestamp]
+physical_plan
+GlobalLimitExec: skip=0, fetch=4
+--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4
+----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)]
+--------CoalesceBatchesExec: target_batch_size=8192
+----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
+------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)]
+--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+----------------MemoryExec: partitions=1, partition_sizes=[1]
+
+query TI
+select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
+----
+b 3
+c 2
+a 1
+NULL 0
+
+query TI
+select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4;
+----
+NULL 0
+b 0
+c 1
+a 1
+
+statement ok
+set datafusion.optimizer.enable_topk_aggregation = true;
+
+query TT
+explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
+----
+logical_plan
+Limit: skip=0, fetch=4
+--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4
+----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]]
+------TableScan: traces projection=[trace_id, timestamp]
+physical_plan
+GlobalLimitExec: skip=0, fetch=4
+--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4
+----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4]
+--------CoalesceBatchesExec: target_batch_size=8192
+----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
+------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4]
+--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+----------------MemoryExec: partitions=1, partition_sizes=[1]
+
+query TT
+explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) asc limit 4;
+----
+logical_plan
+Limit: skip=0, fetch=4
+--Sort: MAX(traces.timestamp) ASC NULLS LAST, fetch=4
+----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]]
+------TableScan: traces projection=[trace_id, timestamp]
+physical_plan
+GlobalLimitExec: skip=0, fetch=4
+--SortPreservingMergeExec: [MAX(traces.timestamp)@1 ASC NULLS LAST], fetch=4
+----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 ASC NULLS LAST]
+------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)]
+--------CoalesceBatchesExec: target_batch_size=8192
+----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
+------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)]
+--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+----------------MemoryExec: partitions=1, partition_sizes=[1]
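+
+# Note that no lim=[4] appears above: the rule only fires when the sort
+# direction matches the aggregate (MAX with DESC, MIN with ASC).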
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index ad9b2be40e9e3..44b67c78ed27e 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -250,4 +250,5 @@ physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE
 physical_plan after EnforceSorting SAME TEXT AS ABOVE
 physical_plan after coalesce_batches SAME TEXT AS ABOVE
 physical_plan after PipelineChecker SAME TEXT AS ABOVE
+physical_plan after LimitAggregation SAME TEXT AS ABOVE
 physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt
index 5db305105f53b..0bb30dc0bd707 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -182,6 +182,7 @@ datafusion.explain.physical_plan_only false
 datafusion.optimizer.allow_symmetric_joins_without_pruning true
 datafusion.optimizer.bounded_order_preserving_variants false
 datafusion.optimizer.enable_round_robin_repartition true
+datafusion.optimizer.enable_topk_aggregation true
 datafusion.optimizer.filter_null_join_keys false
 datafusion.optimizer.hash_join_single_partition_threshold 1048576
 datafusion.optimizer.max_passes 3
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 7580322e2069d..b1ac0437fab92 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -1280,7 +1280,7 @@ async fn make_datafusion_like(
     let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else {
         return Err(DataFusionError::Substrait(format!(
             "Expect Utf8 literal for escape char, but found {escape_char_expr:?}",
-        )))
+        )));
     };
 
     Ok(Arc::new(Expr::Like(Like {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 65fa7c7f9fd9d..8c5ff4ac3a38b 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -1652,7 +1652,10 @@ mod test {
             println!("Checking round trip of {scalar:?}");
 
             let substrait = to_substrait_literal(&scalar)?;
-            let Expression { rex_type: Some(RexType::Literal(substrait_literal)) } = substrait else {
+            let Expression {
+                rex_type: Some(RexType::Literal(substrait_literal)),
+            } = substrait
+            else {
                 panic!("Expected Literal expression, got {substrait:?}");
             };
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index a81bece2b5203..9692bb5974da2 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -76,6 +76,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus
 | datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). |
 | datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. |
 | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores |
+| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible |
 | datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. |
 | datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level |
 | datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. |
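The configs.md hunk above notes that options are also read from the environment during `SessionConfig` initialisation. Under that documented convention (upper-case, dots to underscores, `DATAFUSION_` prefix), the new flag could be disabled process-wide; a sketch, assuming `SessionConfig::from_env` and the env-var naming behave as documented:

```rust
use datafusion::prelude::SessionConfig;

fn main() -> datafusion::error::Result<()> {
    // Assumed mapping of datafusion.optimizer.enable_topk_aggregation
    // to its environment-variable form.
    std::env::set_var("DATAFUSION_OPTIMIZER_ENABLE_TOPK_AGGREGATION", "false");

    let config = SessionConfig::from_env()?;
    assert!(!config.options().optimizer.enable_topk_aggregation);
    Ok(())
}
```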