From 860631f3ec72535dc67d58aab00908eade4198ca Mon Sep 17 00:00:00 2001 From: yangjiang Date: Wed, 20 Apr 2022 16:54:35 +0800 Subject: [PATCH 1/3] Enable filter pushdown when using In_list on parquet --- .../core/src/physical_optimizer/pruning.rs | 43 +++++++++++++++++++ datafusion/core/tests/parquet_pruning.rs | 15 +++++++ 2 files changed, 58 insertions(+) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 67b7476e55795..79e74f6d04178 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -688,6 +688,29 @@ fn build_predicate_expression( return Ok(unhandled); } } + Expr::InList { + expr, + list, + negated, + } => { + if !list.is_empty() { + let mut or_expr = if *negated { + Expr::not_eq(*expr.clone(), list[0].clone()) + } else { + Expr::eq(*expr.clone(), list[0].clone()) + }; + + for e in list.iter().skip(1) { + if *negated { + or_expr = or_expr.or(Expr::not_eq(*expr.clone(), e.clone())); + } else { + or_expr = or_expr.or(Expr::eq(*expr.clone(), e.clone())); + } + } + return build_predicate_expression(&or_expr, schema, required_columns); + } + return Ok(unhandled); + } _ => { return Ok(unhandled); } @@ -1340,6 +1363,26 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_in_list() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + // test c1 in(1, 2, 3) + let expr = Expr::InList { + expr: Box::new(col("c1")), + list: vec![lit(1), lit(2), lit(3)], + negated: false, + }; + let expected_expr = "#c1_min <= Int32(1) AND Int32(1) <= #c1_max OR #c1_min <= Int32(2) AND Int32(2) <= #c1_max OR #c1_min <= Int32(3) AND Int32(3) <= #c1_max"; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + #[test] fn prune_api() { 
let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/parquet_pruning.rs b/datafusion/core/tests/parquet_pruning.rs index d5392e9dcbff5..21805c40f074f 100644 --- a/datafusion/core/tests/parquet_pruning.rs +++ b/datafusion/core/tests/parquet_pruning.rs @@ -403,6 +403,21 @@ async fn prune_f64_complex_expr_subtract() { assert_eq!(output.result_rows, 9, "{}", output.description()); } +#[tokio::test] +async fn prune_int32_eq_in_list() { + // resulrt of sql "SELECT * FROM t where in (1)" + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where i in (1)") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(3)); + assert_eq!(output.result_rows, 1, "{}", output.description()); +} + // ---------------------- // Begin test fixture // ---------------------- From 688baa73ce8b335e76e927aa330e97beaa930d59 Mon Sep 17 00:00:00 2001 From: yangjiang Date: Fri, 22 Apr 2022 16:55:41 +0800 Subject: [PATCH 2/3] fix negated situation and add UT --- .../core/src/physical_optimizer/pruning.rs | 67 ++++++++++++++----- datafusion/core/tests/parquet_pruning.rs | 54 +++++++++++---- 2 files changed, 91 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 79e74f6d04178..e734153d6ca64 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -692,24 +692,15 @@ fn build_predicate_expression( expr, list, negated, - } => { - if !list.is_empty() { - let mut or_expr = if *negated { - Expr::not_eq(*expr.clone(), list[0].clone()) - } else { - Expr::eq(*expr.clone(), list[0].clone()) - }; - - for e in list.iter().skip(1) { - if *negated { - or_expr = or_expr.or(Expr::not_eq(*expr.clone(), e.clone())); - } else { - or_expr = 
or_expr.or(Expr::eq(*expr.clone(), e.clone())); - } - } - return build_predicate_expression(&or_expr, schema, required_columns); - } - return Ok(unhandled); + } if !list.is_empty() && list.len() < 20 => { + let eq_fun = if *negated { Expr::not_eq } else { Expr::eq }; + let re_fun = if *negated { Expr::and } else { Expr::or }; + let change_expr = list + .into_iter() + .map(|e| eq_fun(*expr.clone(), e.clone())) + .reduce(re_fun) + .unwrap(); + return build_predicate_expression(&change_expr, schema, required_columns); } _ => { return Ok(unhandled); @@ -1383,6 +1374,46 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_in_list_empty() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + // test c1 in() + let expr = Expr::InList { + expr: Box::new(col("c1")), + list: vec![], + negated: false, + }; + let expected_expr = "Boolean(true)"; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_in_list_negated() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + // test c1 not in(1, 2, 3) + let expr = Expr::InList { + expr: Box::new(col("c1")), + list: vec![lit(1), lit(2), lit(3)], + negated: true, + }; + let expected_expr = "#c1_min != Int32(1) OR Int32(1) != #c1_max AND #c1_min != Int32(2) OR Int32(2) != #c1_max AND #c1_min != Int32(3) OR Int32(3) != #c1_max"; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + #[test] fn prune_api() { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/parquet_pruning.rs b/datafusion/core/tests/parquet_pruning.rs index 
21805c40f074f..5ee4fcca44ba4 100644 --- a/datafusion/core/tests/parquet_pruning.rs +++ b/datafusion/core/tests/parquet_pruning.rs @@ -185,7 +185,7 @@ async fn prune_int32_lt() { let (expected_errors, expected_row_group_pruned, expected_results) = (Some(0), Some(1), 11); - // resulrt of sql "SELECT * FROM t where i < 1" is same as + // result of sql "SELECT * FROM t where i < 1" is same as // "SELECT * FROM t where -i > -1" let output = ContextWithParquet::new(Scenario::Int32) .await @@ -222,7 +222,7 @@ async fn prune_int32_lt() { #[tokio::test] async fn prune_int32_eq() { - // resulrt of sql "SELECT * FROM t where i = 1" + // result of sql "SELECT * FROM t where i = 1" let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where i = 1") @@ -237,7 +237,7 @@ async fn prune_int32_eq() { #[tokio::test] async fn prune_int32_scalar_fun_and_eq() { - // resulrt of sql "SELECT * FROM t where abs(i) = 1 and i = 1" + // result of sql "SELECT * FROM t where abs(i) = 1 and i = 1" // only use "i = 1" to prune let output = ContextWithParquet::new(Scenario::Int32) .await @@ -253,7 +253,7 @@ async fn prune_int32_scalar_fun_and_eq() { #[tokio::test] async fn prune_int32_scalar_fun() { - // resulrt of sql "SELECT * FROM t where abs(i) = 1" is not supported + // result of sql "SELECT * FROM t where abs(i) = 1" is not supported let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where abs(i) = 1") @@ -269,7 +269,7 @@ async fn prune_int32_scalar_fun() { #[tokio::test] async fn prune_int32_complex_expr() { - // resulrt of sql "SELECT * FROM t where i+1 = 1" is not supported + // result of sql "SELECT * FROM t where i+1 = 1" is not supported let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where i+1 = 1") @@ -285,7 +285,7 @@ async fn prune_int32_complex_expr() { #[tokio::test] async fn prune_int32_complex_expr_subtract() { - // resulrt of sql "SELECT * FROM t where 1-i > 1" is not supported 
+ // result of sql "SELECT * FROM t where 1-i > 1" is not supported let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where 1-i > 1") .await; @@ -304,7 +304,7 @@ async fn prune_f64_lt() { let (expected_errors, expected_row_group_pruned, expected_results) = (Some(0), Some(1), 11); - // resulrt of sql "SELECT * FROM t where i < 1" is same as + // result of sql "SELECT * FROM t where i < 1" is same as // "SELECT * FROM t where -i > -1" let output = ContextWithParquet::new(Scenario::Float64) .await @@ -341,7 +341,7 @@ async fn prune_f64_lt() { #[tokio::test] async fn prune_f64_scalar_fun_and_gt() { - // resulrt of sql "SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1" + // result of sql "SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1" // only use "f >= 0" to prune let output = ContextWithParquet::new(Scenario::Float64) .await @@ -357,7 +357,7 @@ async fn prune_f64_scalar_fun_and_gt() { #[tokio::test] async fn prune_f64_scalar_fun() { - // resulrt of sql "SELECT * FROM t where abs(f-1) <= 0.000001" is not supported + // result of sql "SELECT * FROM t where abs(f-1) <= 0.000001" is not supported let output = ContextWithParquet::new(Scenario::Float64) .await .query("SELECT * FROM t where abs(f-1) <= 0.000001") .await; @@ -373,7 +373,7 @@ async fn prune_f64_scalar_fun() { #[tokio::test] async fn prune_f64_complex_expr() { - // resulrt of sql "SELECT * FROM t where f+1 > 1.1"" is not supported + // result of sql "SELECT * FROM t where f+1 > 1.1" is not supported let output = ContextWithParquet::new(Scenario::Float64) .await .query("SELECT * FROM t where f+1 > 1.1") .await; @@ -389,7 +389,7 @@ async fn prune_f64_complex_expr() { #[tokio::test] async fn prune_f64_complex_expr_subtract() { - // resulrt of sql "SELECT * FROM t where 1-f > 1" is not supported + // result of sql "SELECT * FROM t where 1-f > 1" is not supported let output = ContextWithParquet::new(Scenario::Float64) .await .query("SELECT * FROM t where 1-f > 1") .await; @@ -405,7 +405,7 @@ 
async fn prune_f64_complex_expr_subtract() { #[tokio::test] async fn prune_int32_eq_in_list() { - // resulrt of sql "SELECT * FROM t where in (1)" + // result of sql "SELECT * FROM t where i in (1)" let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where i in (1)") .await; @@ -418,6 +418,36 @@ async fn prune_int32_eq_in_list() { assert_eq!(output.result_rows, 1, "{}", output.description()); } +#[tokio::test] +async fn prune_int32_eq_in_list_2() { + // result of sql "SELECT * FROM t where i in (1000)", prune all + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where i in (1000)") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(4)); + assert_eq!(output.result_rows, 0, "{}", output.description()); +} + +#[tokio::test] +async fn prune_int32_eq_in_list_negated() { + // result of sql "SELECT * FROM t where i not in (1)" prune nothing + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where i not in (1)") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 19, "{}", output.description()); +} + // ---------------------- // Begin test fixture // ---------------------- From b2f815c83c4edd9e4a86dd687e8a2e4dbddda99a Mon Sep 17 00:00:00 2001 From: yangjiang Date: Fri, 22 Apr 2022 16:57:20 +0800 Subject: [PATCH 3/3] fix clippy --- datafusion/core/src/physical_optimizer/pruning.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index e734153d6ca64..737b4401f7b1a 100644 --- 
a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -696,7 +696,7 @@ fn build_predicate_expression( let eq_fun = if *negated { Expr::not_eq } else { Expr::eq }; let re_fun = if *negated { Expr::and } else { Expr::or }; let change_expr = list - .into_iter() + .iter() .map(|e| eq_fun(*expr.clone(), e.clone())) .reduce(re_fun) .unwrap();