From 15aa56d840f9bb7a646aa1000ab8873529bd9e5e Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Tue, 6 Sep 2022 07:00:32 -0400
Subject: [PATCH 1/3] Add tests for pruning, support pruning with constant
 expressions

---
 .../core/src/physical_optimizer/pruning.rs | 193 ++++++++++++++----
 datafusion/core/tests/parquet_pruning.rs   |  12 +-
 2 files changed, 156 insertions(+), 49 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 73f6c795c9016..468108c2f8317 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -38,11 +38,13 @@ use crate::{
     logical_plan::{Column, DFSchema, Expr, Operator},
     physical_plan::{ColumnarValue, PhysicalExpr},
 };
+use arrow::record_batch::RecordBatchOptions;
 use arrow::{
     array::{new_null_array, ArrayRef, BooleanArray},
     datatypes::{DataType, Field, Schema, SchemaRef},
     record_batch::RecordBatch,
 };
+use datafusion_common::ScalarValue;
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
 use datafusion_expr::utils::expr_to_columns;
 use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
@@ -168,38 +170,51 @@ impl PruningPredicate {
     /// simplified version `b`. The predicates are simplified via the
     /// ConstantFolding optimizer pass
     pub fn prune<S: PruningStatistics>(&self, statistics: &S) -> Result<Vec<bool>> {
-        // build statistics record batch
-        let predicate_array =
-            build_statistics_record_batch(statistics, &self.required_columns)
-                .and_then(|statistics_batch| {
-                    // execute predicate expression
-                    self.predicate_expr.evaluate(&statistics_batch)
-                })
-                .and_then(|v| match v {
-                    ColumnarValue::Array(array) => Ok(array),
-                    ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
-                        "predicate expression didn't return an array".to_string(),
-                    )),
-                })?;
-
-        let predicate_array = predicate_array
-            .as_any()
-            .downcast_ref::<BooleanArray>()
-            .ok_or_else(|| {
-                DataFusionError::Internal(format!(
-                    "Expected pruning predicate evaluation to be BooleanArray, \
-                     but was {:?}",
-                    predicate_array
-                ))
-            })?;
-
-        // when the result of the predicate expression for a row group is null / undefined,
-        // e.g. due to missing statistics, this row group can't be filtered out,
-        // so replace with true
-        Ok(predicate_array
-            .into_iter()
-            .map(|x| x.unwrap_or(true))
-            .collect::<Vec<bool>>())
+        // build a RecordBatch that contains the min/max values in the
+        // appropriate statistics columns
+        let statistics_batch =
+            build_statistics_record_batch(statistics, &self.required_columns)?;
+
+        // Evaluate the pruning predicate on that record batch.
+        //
+        // Use true when the result of evaluating a predicate
+        // expression on a row group is null (aka `None`). Null can
+        // arise when the statistics are unknown or some calculation
+        // in the predicate means we don't know for sure if the row
+        // group can be filtered out or not. To maintain correctness
+        // the row group must be kept and thus `true` is returned.
+        match self.predicate_expr.evaluate(&statistics_batch)? 
{
+            ColumnarValue::Array(array) => {
+                let predicate_array = array
+                    .as_any()
+                    .downcast_ref::<BooleanArray>()
+                    .ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "Expected pruning predicate evaluation to be BooleanArray, \
+                            but was {:?}",
+                            array
+                        ))
+                    })?;
+
+                Ok(predicate_array
+                    .into_iter()
+                    .map(|x| x.unwrap_or(true)) // None -> true per comments above
+                    .collect::<Vec<_>>())
+
+            },
+            // result was a single (scalar) boolean value
+            ColumnarValue::Scalar(ScalarValue::Boolean(v)) => {
+                let v = v.unwrap_or(true); // None -> true per comments above
+                Ok(vec![v; statistics.num_containers()])
+            }
+            other => {
+                Err(DataFusionError::Internal(format!(
+                    "Unexpected result of pruning predicate evaluation. Expected Boolean array \
+                    or scalar but got {:?}",
+                    other
+                )))
+            }
+        }
     }
 
     /// Return a reference to the input schema
@@ -390,8 +405,13 @@ fn build_statistics_record_batch<S: PruningStatistics>(
     }
 
     let schema = Arc::new(Schema::new(fields));
-    RecordBatch::try_new(schema, arrays)
-        .map_err(|err| DataFusionError::Plan(err.to_string()))
+    // provide the row count explicitly in case no statistics columns were needed
+    let mut options = RecordBatchOptions::default();
+    options.row_count = Some(statistics.num_containers());
+
+    RecordBatch::try_new_with_options(schema, arrays, &options).map_err(|err| {
+        DataFusionError::Plan(format!("Can not create statistics record batch: {}", err))
+    })
 }
 
 struct PruningExpressionBuilder<'a> {
@@ -1167,7 +1187,7 @@ mod tests {
     }
 
     #[test]
-    fn test_build_statistics_no_stats() {
+    fn test_build_statistics_no_required_stats() {
         let required_columns = RequiredStatColumns::new();
 
         let statistics = OneContainerStats {
@@ -1176,13 +1196,9 @@ mod tests {
             num_containers: 1,
         };
 
-        let result =
-            build_statistics_record_batch(&statistics, &required_columns).unwrap_err();
-        assert!(
-            result.to_string().contains("Invalid argument error"),
-            "{}",
-            result
-        );
+        let batch =
+            build_statistics_record_batch(&statistics, &required_columns).unwrap();
+        assert_eq!(batch.num_rows(), 1); // had 1 container
     }
 
     #[test]
@@ -1857,7 +1873,15 @@ mod tests {
         assert_eq!(result, expected_false);
     }
 
-    /// Creates setup for int32 chunk pruning
+    /// Creates a setup for chunk pruning, modeling an int32 column "i"
+    /// with 5 different containers (e.g. RowGroups). 
They have [min, + /// max]: + /// + /// i [-5, 5] + /// i [1, 11] + /// i [-11, -1] + /// i [NULL, NULL] + /// i [1, NULL] fn int32_setup() -> (SchemaRef, TestStatistics) { let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); @@ -1921,6 +1945,45 @@ mod tests { assert_eq!(result, expected_ret); } + #[test] + fn prune_int32_col_lte_zero_cast() { + let (schema, statistics) = int32_setup(); + + // Expression "cast(i as utf8) <= '0'" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> no rows can pass in theory, -0.22 (conservatively keep) + // i [-11, -1] ==> no rows could pass in theory (conservatively keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> no rows can pass (conservatively keep) + let expected_ret = vec![true, true, true, true, true]; + + // cast(i as utf8) <= 0 + let expr = cast(col("i"), DataType::Utf8).lt_eq(lit("0")); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + + // try_cast(i as utf8) <= 0 + let expr = try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + + // cast(-i as utf8) >= 0 + let expr = + Expr::Negative(Box::new(cast(col("i"), DataType::Utf8))).gt_eq(lit("0")); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + + // try_cast(-i as utf8) >= 0 + let expr = + Expr::Negative(Box::new(try_cast(col("i"), DataType::Utf8))).gt_eq(lit("0")); + let p = PruningPredicate::try_new(expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + } + #[test] fn prune_int32_col_eq_zero() { let (schema, statistics) = int32_setup(); @@ -1940,6 +2003,50 @@ mod tests { assert_eq!(result, expected_ret); } + #[test] + fn prune_int32_col_eq_zero_cast() { + let (schema, statistics) = int32_setup(); + + // Expression "cast(i as int64) = 0" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> no rows can pass (not keep) + // i [-11, -1] ==> no rows can pass (not keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> no rows can pass (not keep) + let expected_ret = vec![true, false, false, true, false]; + + let expr = cast(col("i"), DataType::Int64).eq(lit(0i64)); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + + let expr = try_cast(col("i"), DataType::Int64).eq(lit(0i64)); + let p = PruningPredicate::try_new(expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + } + + #[test] + fn prune_int32_col_eq_zero_cast_as_str() { + let (schema, statistics) = int32_setup(); + + // Note the cast is to a string where sorting properties are + // not the same as integers + // + // Expression "cast(i as utf8) = '0'" + // i [-5, 5] ==> some rows could pass (keep) + // i [1, 11] ==> no rows can pass (could keep) + // i [-11, -1] ==> no rows can pass (could keep) + // i [NULL, NULL] ==> unknown (keep) + // i [1, NULL] ==> no rows can pass (could keep) + let expected_ret = vec![true, true, true, true, true]; + + let expr = cast(col("i"), DataType::Utf8).eq(lit("0")); + let p = PruningPredicate::try_new(expr, schema).unwrap(); + let result = 
p.prune(&statistics).unwrap();
+        assert_eq!(result, expected_ret);
+    }
+
     #[test]
     fn prune_int32_col_lt_neg_one() {
         let (schema, statistics) = int32_setup();
diff --git a/datafusion/core/tests/parquet_pruning.rs b/datafusion/core/tests/parquet_pruning.rs
index 6681a748ee790..1801803ad174c 100644
--- a/datafusion/core/tests/parquet_pruning.rs
+++ b/datafusion/core/tests/parquet_pruning.rs
@@ -237,7 +237,7 @@ async fn prune_int32_scalar_fun() {
     test_prune(
         Scenario::Int32,
         "SELECT * FROM t where abs(i) = 1",
-        Some(4),
+        Some(0),
         Some(0),
         3,
     )
@@ -249,7 +249,7 @@ async fn prune_int32_complex_expr() {
     test_prune(
         Scenario::Int32,
         "SELECT * FROM t where i+1 = 1",
-        Some(4),
+        Some(0),
         Some(0),
         2,
     )
@@ -261,7 +261,7 @@ async fn prune_int32_complex_expr_subtract() {
     test_prune(
         Scenario::Int32,
         "SELECT * FROM t where 1-i > 1",
-        Some(4),
+        Some(0),
         Some(0),
         9,
     )
@@ -308,7 +308,7 @@ async fn prune_f64_scalar_fun() {
     test_prune(
         Scenario::Float64,
         "SELECT * FROM t where abs(f-1) <= 0.000001",
-        Some(4),
+        Some(0),
         Some(0),
         1,
     )
@@ -321,7 +321,7 @@ async fn prune_f64_complex_expr() {
     test_prune(
         Scenario::Float64,
         "SELECT * FROM t where f+1 > 1.1",
-        Some(4),
+        Some(0),
         Some(0),
         9,
     )
@@ -334,7 +334,7 @@ async fn prune_f64_complex_expr_subtract() {
     test_prune(
         Scenario::Float64,
         "SELECT * FROM t where 1-f > 1",
-        Some(4),
+        Some(0),
         Some(0),
         9,
     )

From 5ad465cec59be724bfcf6f6b9fba53166efb5611 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Wed, 14 Sep 2022 14:42:49 -0400
Subject: [PATCH 2/3] Use downcast_any!

---
 datafusion/core/src/physical_optimizer/pruning.rs | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index cec9e1725147e..3de5514659095 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -28,6 +28,7 @@
 //! entities (e.g. entire files) if the statistics are known via some
 //! other source (e.g. a catalog)
 
+use std::any::type_name;
 use std::convert::TryFrom;
 use std::{collections::HashSet, sync::Arc};
 
@@ -44,7 +45,7 @@ use arrow::{
     datatypes::{DataType, Field, Schema, SchemaRef},
     record_batch::RecordBatch,
 };
-use datafusion_common::ScalarValue;
+use datafusion_common::{downcast_value, ScalarValue};
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
 use datafusion_expr::utils::expr_to_columns;
 
@@ -186,16 +187,7 @@ impl PruningPredicate {
         // the row group must be kept and thus `true` is returned.
         match self.predicate_expr.evaluate(&statistics_batch)? {
             ColumnarValue::Array(array) => {
-                let predicate_array = array
-                    .as_any()
-                    .downcast_ref::<BooleanArray>()
-                    .ok_or_else(|| {
-                        DataFusionError::Internal(format!(
-                            "Expected pruning predicate evaluation to be BooleanArray, \
-                            but was {:?}",
-                            array
-                        ))
-                    })?;
+                let predicate_array = downcast_value!(array, BooleanArray);
 
                 Ok(predicate_array
                     .into_iter()

From 6bb0e96a12b51cd674f3651b2cc8882c8187f232 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Mon, 19 Sep 2022 09:08:54 -0400
Subject: [PATCH 3/3] chore: Remove unneeded use

---
 datafusion/core/src/physical_optimizer/pruning.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 3de5514659095..9c53e0a6a2de9 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -28,7 +28,6 @@
 //! entities (e.g. 
entire files) if the statistics are known via some //! other source (e.g. a catalog) -use std::any::type_name; use std::convert::TryFrom; use std::{collections::HashSet, sync::Arc};
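

A note on the mechanism the first patch relies on: when a pruning predicate
needs no min/max statistics columns (for example one like `abs(i) = 1` that
cannot be rewritten in terms of column statistics), the statistics RecordBatch
built for it has zero columns, and a zero-column batch cannot infer its own
row count. Passing `RecordBatchOptions` with an explicit `row_count` is what
lets `build_statistics_record_batch` still report one row per container. The
sketch below is not part of the patches; it is a minimal standalone
illustration of the same arrow API the patch uses, and the container count of
3 is an arbitrary assumption for the example:

    use std::sync::Arc;
    use arrow::datatypes::Schema;
    use arrow::record_batch::{RecordBatch, RecordBatchOptions};

    fn main() -> arrow::error::Result<()> {
        // No statistics columns are required, so the schema is empty...
        let schema = Arc::new(Schema::empty());

        // ...and the row count must be supplied explicitly; here it models
        // 3 containers (e.g. 3 row groups).
        let mut options = RecordBatchOptions::default();
        options.row_count = Some(3);

        let batch = RecordBatch::try_new_with_options(schema, vec![], &options)?;
        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 0);
        Ok(())
    }

Evaluating a constant boolean predicate against such a batch yields a
`ColumnarValue::Scalar(ScalarValue::Boolean(..))` rather than an array, which
the new match arm in `prune` broadcasts to one `bool` per container via
`vec![v; statistics.num_containers()]`.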