diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 67b7476e55795..737b4401f7b1a 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -688,6 +688,20 @@ fn build_predicate_expression( return Ok(unhandled); } } + Expr::InList { + expr, + list, + negated, + } if !list.is_empty() && list.len() < 20 => { + let eq_fun = if *negated { Expr::not_eq } else { Expr::eq }; + let re_fun = if *negated { Expr::and } else { Expr::or }; + let change_expr = list + .iter() + .map(|e| eq_fun(*expr.clone(), e.clone())) + .reduce(re_fun) + .unwrap(); + return build_predicate_expression(&change_expr, schema, required_columns); + } _ => { return Ok(unhandled); } @@ -1340,6 +1354,66 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_in_list() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + // test c1 in(1, 2, 3) + let expr = Expr::InList { + expr: Box::new(col("c1")), + list: vec![lit(1), lit(2), lit(3)], + negated: false, + }; + let expected_expr = "#c1_min <= Int32(1) AND Int32(1) <= #c1_max OR #c1_min <= Int32(2) AND Int32(2) <= #c1_max OR #c1_min <= Int32(3) AND Int32(3) <= #c1_max"; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_in_list_empty() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + // test c1 in() + let expr = Expr::InList { + expr: Box::new(col("c1")), + list: vec![], + negated: false, + }; + let expected_expr = "Boolean(true)"; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", 
predicate_expr), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_in_list_negated() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + // test c1 not in(1, 2, 3) + let expr = Expr::InList { + expr: Box::new(col("c1")), + list: vec![lit(1), lit(2), lit(3)], + negated: true, + }; + let expected_expr = "#c1_min != Int32(1) OR Int32(1) != #c1_max AND #c1_min != Int32(2) OR Int32(2) != #c1_max AND #c1_min != Int32(3) OR Int32(3) != #c1_max"; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + #[test] fn prune_api() { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/parquet_pruning.rs b/datafusion/core/tests/parquet_pruning.rs index d5392e9dcbff5..5ee4fcca44ba4 100644 --- a/datafusion/core/tests/parquet_pruning.rs +++ b/datafusion/core/tests/parquet_pruning.rs @@ -185,7 +185,7 @@ async fn prune_int32_lt() { let (expected_errors, expected_row_group_pruned, expected_results) = (Some(0), Some(1), 11); - // resulrt of sql "SELECT * FROM t where i < 1" is same as + // result of sql "SELECT * FROM t where i < 1" is same as // "SELECT * FROM t where -i > -1" let output = ContextWithParquet::new(Scenario::Int32) .await @@ -222,7 +222,7 @@ async fn prune_int32_lt() { #[tokio::test] async fn prune_int32_eq() { - // resulrt of sql "SELECT * FROM t where i = 1" + // result of sql "SELECT * FROM t where i = 1" let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where i = 1") @@ -237,7 +237,7 @@ async fn prune_int32_eq() { #[tokio::test] async fn prune_int32_scalar_fun_and_eq() { - // resulrt of sql "SELECT * FROM t where abs(i) = 1 and i = 1" + // result of sql "SELECT * FROM t where abs(i) = 1 and i = 1" // only use "i = 1" to prune let output = ContextWithParquet::new(Scenario::Int32) 
.await @@ -253,7 +253,7 @@ async fn prune_int32_scalar_fun_and_eq() { #[tokio::test] async fn prune_int32_scalar_fun() { - // resulrt of sql "SELECT * FROM t where abs(i) = 1" is not supported + // result of sql "SELECT * FROM t where abs(i) = 1" is not supported let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where abs(i) = 1") @@ -269,7 +269,7 @@ async fn prune_int32_scalar_fun() { #[tokio::test] async fn prune_int32_complex_expr() { - // resulrt of sql "SELECT * FROM t where i+1 = 1" is not supported + // result of sql "SELECT * FROM t where i+1 = 1" is not supported let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where i+1 = 1") @@ -285,7 +285,7 @@ async fn prune_int32_complex_expr() { #[tokio::test] async fn prune_int32_complex_expr_subtract() { - // resulrt of sql "SELECT * FROM t where 1-i > 1" is not supported + // result of sql "SELECT * FROM t where 1-i > 1" is not supported let output = ContextWithParquet::new(Scenario::Int32) .await .query("SELECT * FROM t where 1-i > 1") @@ -304,7 +304,7 @@ async fn prune_f64_lt() { let (expected_errors, expected_row_group_pruned, expected_results) = (Some(0), Some(1), 11); - // resulrt of sql "SELECT * FROM t where i < 1" is same as + // result of sql "SELECT * FROM t where i < 1" is same as // "SELECT * FROM t where -i > -1" let output = ContextWithParquet::new(Scenario::Float64) .await @@ -341,7 +341,7 @@ async fn prune_f64_lt() { #[tokio::test] async fn prune_f64_scalar_fun_and_gt() { - // resulrt of sql "SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1" + // result of sql "SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1" // only use "f >= 0" to prune let output = ContextWithParquet::new(Scenario::Float64) .await @@ -357,7 +357,7 @@ async fn prune_f64_scalar_fun_and_gt() { #[tokio::test] async fn prune_f64_scalar_fun() { - // resulrt of sql "SELECT * FROM t where abs(f-1) <= 0.000001" is not supported + // result of sql 
"SELECT * FROM t where abs(f-1) <= 0.000001" is not supported let output = ContextWithParquet::new(Scenario::Float64) .await .query("SELECT * FROM t where abs(f-1) <= 0.000001") @@ -373,7 +373,7 @@ async fn prune_f64_scalar_fun() { #[tokio::test] async fn prune_f64_complex_expr() { - // resulrt of sql "SELECT * FROM t where f+1 > 1.1"" is not supported + // result of sql "SELECT * FROM t where f+1 > 1.1" is not supported let output = ContextWithParquet::new(Scenario::Float64) .await .query("SELECT * FROM t where f+1 > 1.1") @@ -389,7 +389,7 @@ async fn prune_f64_complex_expr() { #[tokio::test] async fn prune_f64_complex_expr_subtract() { - // resulrt of sql "SELECT * FROM t where 1-f > 1" is not supported + // result of sql "SELECT * FROM t where 1-f > 1" is not supported let output = ContextWithParquet::new(Scenario::Float64) .await .query("SELECT * FROM t where 1-f > 1") @@ -403,6 +403,51 @@ async fn prune_f64_complex_expr_subtract() { assert_eq!(output.result_rows, 9, "{}", output.description()); } +#[tokio::test] +async fn prune_int32_eq_in_list() { + // result of sql "SELECT * FROM t where i in (1)" + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where i in (1)") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(3)); + assert_eq!(output.result_rows, 1, "{}", output.description()); +} + +#[tokio::test] +async fn prune_int32_eq_in_list_2() { + // result of sql "SELECT * FROM t where i in (1000)", prune all + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where i in (1000)") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(4)); + assert_eq!(output.result_rows, 0, "{}", 
output.description()); +} + +#[tokio::test] +async fn prune_int32_eq_in_list_negated() { + // result of sql "SELECT * FROM t where i not in (1)" prune nothing + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where i not in (1)") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 19, "{}", output.description()); +} + // ---------------------- // Begin test fixture // ----------------------