Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 119 additions & 42 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use std::sync::Arc;

use arrow::datatypes::Schema;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;

use crate::execution::context::SessionConfig;
use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
Expand All @@ -37,6 +38,9 @@ use crate::error::Result;
#[derive(Default)]
pub struct AggregateStatistics {}

/// The name of the column corresponding to [`COUNT_STAR_EXPANSION`]
const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))";
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constant was hard-coded in a few places, and I think giving it a symbolic name makes it easier to understand what it is doing


impl AggregateStatistics {
#[allow(missing_docs)]
pub fn new() -> Self {
Expand Down Expand Up @@ -148,10 +152,10 @@ fn take_optimizable_table_count(
.as_any()
.downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &ScalarValue::UInt8(Some(1)) {
if lit_expr.value() == &COUNT_STAR_EXPANSION {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was an implicit coupling between the SQL planner and this file, which I have now made explicit with a named constant

return Some((
ScalarValue::UInt64(Some(num_rows as u64)),
"COUNT(UInt8(1))",
ScalarValue::Int64(Some(num_rows as i64)),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change from UInt64 to Int64 here and a few lines below is the actual bug fix / change of behavior -- the rest of this PR is testing / readability improvements

COUNT_STAR_NAME,
));
}
}
Expand Down Expand Up @@ -183,7 +187,7 @@ fn take_optimizable_column_count(
{
let expr = format!("COUNT({})", col_expr.name());
return Some((
ScalarValue::UInt64(Some((num_rows - val) as u64)),
ScalarValue::Int64(Some((num_rows - val) as i64)),
expr,
));
}
Expand Down Expand Up @@ -254,9 +258,10 @@ mod tests {
use super::*;
use std::sync::Arc;

use arrow::array::{Int32Array, UInt64Array};
use arrow::array::{Int32Array, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_physical_expr::PhysicalExpr;

use crate::error::Result;
use crate::logical_plan::Operator;
Expand Down Expand Up @@ -291,65 +296,132 @@ mod tests {
}

/// Checks that the count optimization was applied and we still get the right result
async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) -> Result<()> {
async fn assert_count_optim_success(
plan: AggregateExec,
agg: TestAggregate,
) -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let conf = session_ctx.copied_config();
let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?;

let (col, count) = match nulls {
false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 3),
true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
};
let plan = Arc::new(plan) as _;
let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &conf)?;

// A ProjectionExec is a sign that the count optimization was applied
assert!(optimized.as_any().is::<ProjectionExec>());
let result = common::collect(optimized.execute(0, task_ctx)?).await?;
assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));

// run both the optimized and nonoptimized plan
let optimized_result =
common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?;
let nonoptimized_result =
common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
assert_eq!(optimized_result.len(), nonoptimized_result.len());

// and validate the results are the same and expected
assert_eq!(optimized_result.len(), 1);
check_batch(optimized_result.into_iter().next().unwrap(), &agg);
// check the non optimized one too to ensure types and names remain the same
assert_eq!(nonoptimized_result.len(), 1);
check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);

Ok(())
}

fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
let schema = batch.schema();
let fields = schema.fields();
assert_eq!(fields.len(), 1);

let field = &fields[0];
assert_eq!(field.name(), agg.column_name());
assert_eq!(field.data_type(), &DataType::Int64);
// note that nullability differs, so it is not compared here

assert_eq!(
result[0]
batch
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.downcast_ref::<Int64Array>()
.unwrap()
.values(),
&[count]
&[agg.expected_count()]
);
Ok(())
}

fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn AggregateExpr> {
// Return appropriate expr depending if COUNT is for col or table
let expr = match schema {
None => expressions::lit(ScalarValue::UInt8(Some(1))),
Some(s) => expressions::col(col.unwrap(), s).unwrap(),
};
Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
/// Describe the type of aggregate being tested
enum TestAggregate {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now parameterizes the difference between different tests into an explicit enum rather than implicit assumptions. I think it makes the tests easier to follow

/// Testing COUNT(*) type aggregates
CountStar,

/// Testing for COUNT(column) aggregate
ColumnA(Arc<Schema>),
}

impl TestAggregate {
    /// Create the test case describing a `COUNT(*)` aggregate
    fn new_count_star() -> Self {
        TestAggregate::CountStar
    }

    /// Create the test case describing a `COUNT(a)` aggregate, keeping a
    /// handle to `schema` so the column reference can be built later
    fn new_count_column(schema: &Arc<Schema>) -> Self {
        TestAggregate::ColumnA(Arc::clone(schema))
    }

    /// Build the `COUNT` aggregate expression for this test case: it counts
    /// [`Self::column`], is named [`Self::column_name`] and is typed `Int64`
    fn count_expr(&self) -> Arc<dyn AggregateExpr> {
        let input = self.column();
        let name = self.column_name();
        Arc::new(Count::new(input, name, DataType::Int64))
    }

    /// The physical expression passed as the argument of the aggregate
    fn column(&self) -> Arc<dyn PhysicalExpr> {
        match self {
            // COUNT(a): a reference to column "a" resolved against the stored schema
            Self::ColumnA(schema) => expressions::col("a", schema).unwrap(),
            // COUNT(*): the literal that `COUNT(*)` is expanded to
            Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION),
        }
    }

    /// The output column name this aggregate is given in a plan
    fn column_name(&self) -> &'static str {
        match self {
            Self::ColumnA(_) => "COUNT(a)",
            Self::CountStar => COUNT_STAR_NAME,
        }
    }

    /// The count the test expects this aggregate to produce
    fn expected_count(&self) -> i64 {
        match self {
            Self::ColumnA(_) => 2,
            Self::CountStar => 3,
        }
    }
}

#[tokio::test]
async fn test_count_partial_direct_child() -> Result<()> {
// basic test case with the aggregation applied on a source with exact statistics
let source = mock_data()?;
let schema = source.schema();
let agg = TestAggregate::new_count_star();

let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
vec![count_expr(None, None)],
vec![agg.count_expr()],
source,
Arc::clone(&schema),
)?;

let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
vec![count_expr(None, None)],
vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;

assert_count_optim_success(final_agg, false).await?;
assert_count_optim_success(final_agg, agg).await?;

Ok(())
}
Expand All @@ -359,24 +431,25 @@ mod tests {
// basic test case with the aggregation applied on a source with exact statistics
let source = mock_data()?;
let schema = source.schema();
let agg = TestAggregate::new_count_column(&schema);

let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
vec![count_expr(Some(&schema), Some("a"))],
vec![agg.count_expr()],
source,
Arc::clone(&schema),
)?;

let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
vec![count_expr(Some(&schema), Some("a"))],
vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;

assert_count_optim_success(final_agg, true).await?;
assert_count_optim_success(final_agg, agg).await?;

Ok(())
}
Expand All @@ -385,11 +458,12 @@ mod tests {
async fn test_count_partial_indirect_child() -> Result<()> {
let source = mock_data()?;
let schema = source.schema();
let agg = TestAggregate::new_count_star();

let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
vec![count_expr(None, None)],
vec![agg.count_expr()],
source,
Arc::clone(&schema),
)?;
Expand All @@ -400,12 +474,12 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
vec![count_expr(None, None)],
vec![agg.count_expr()],
Arc::new(coalesce),
Arc::clone(&schema),
)?;

assert_count_optim_success(final_agg, false).await?;
assert_count_optim_success(final_agg, agg).await?;

Ok(())
}
Expand All @@ -414,11 +488,12 @@ mod tests {
async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
let source = mock_data()?;
let schema = source.schema();
let agg = TestAggregate::new_count_column(&schema);

let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
vec![count_expr(Some(&schema), Some("a"))],
vec![agg.count_expr()],
source,
Arc::clone(&schema),
)?;
Expand All @@ -429,12 +504,12 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
vec![count_expr(Some(&schema), Some("a"))],
vec![agg.count_expr()],
Arc::new(coalesce),
Arc::clone(&schema),
)?;

assert_count_optim_success(final_agg, true).await?;
assert_count_optim_success(final_agg, agg).await?;

Ok(())
}
Expand All @@ -443,6 +518,7 @@ mod tests {
async fn test_count_inexact_stat() -> Result<()> {
let source = mock_data()?;
let schema = source.schema();
let agg = TestAggregate::new_count_star();

// adding a filter makes the statistics inexact
let filter = Arc::new(FilterExec::try_new(
Expand All @@ -458,15 +534,15 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
vec![count_expr(None, None)],
vec![agg.count_expr()],
filter,
Arc::clone(&schema),
)?;

let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
vec![count_expr(None, None)],
vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
Expand All @@ -485,6 +561,7 @@ mod tests {
async fn test_count_with_nulls_inexact_stat() -> Result<()> {
let source = mock_data()?;
let schema = source.schema();
let agg = TestAggregate::new_count_column(&schema);

// adding a filter makes the statistics inexact
let filter = Arc::new(FilterExec::try_new(
Expand All @@ -500,15 +577,15 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
vec![count_expr(Some(&schema), Some("a"))],
vec![agg.count_expr()],
filter,
Arc::clone(&schema),
)?;

let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
vec![count_expr(Some(&schema), Some("a"))],
vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/tests/custom_sources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Int32Array, PrimitiveArray, UInt64Array};
use arrow::array::{Int32Array, Int64Array, PrimitiveArray};
use arrow::compute::kernels::aggregate;
use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
Expand Down Expand Up @@ -284,12 +284,12 @@ async fn optimizers_catch_all_statistics() {

let expected = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
Field::new("COUNT(UInt8(1))", DataType::Int64, false),
Field::new("MIN(test.c1)", DataType::Int32, false),
Field::new("MAX(test.c1)", DataType::Int32, false),
])),
vec![
Arc::new(UInt64Array::from_slice(&[4])),
Arc::new(Int64Array::from_slice(&[4])),
Arc::new(Int32Array::from_slice(&[1])),
Arc::new(Int32Array::from_slice(&[100])),
],
Expand Down
6 changes: 5 additions & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ use crate::logical_plan::{
};
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use std::collections::HashSet;
use std::sync::Arc;

/// The value to which `COUNT(*)` is expanded in
/// `COUNT(<constant>)` expressions
pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::UInt8(Some(1));

/// Recursively walk a list of expression trees, collecting the unique set of columns
/// referenced in the expression
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
Expand Down
Loading