From f96fc32060f825b117e44e1e48cf868274c55972 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 13 May 2024 20:35:18 -0700 Subject: [PATCH 1/3] Fix SortMergeJoin with join filter filtering all rows out --- datafusion/core/tests/sql/joins.rs | 29 +++++++++++++++++++ .../src/joins/sort_merge_join.rs | 4 ++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index f7d5205db0d3f..a803ab7ceb36b 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -231,3 +231,32 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn test_smj_with_join_filter_fitering_all() -> Result<()> { + let ctx: SessionContext = SessionContext::new(); + + let sql = "set datafusion.optimizer.prefer_hash_join = false;"; + let _ = ctx.sql(sql).await?.collect().await?; + + let sql = "set datafusion.execution.batch_size = 1"; + let _ = ctx.sql(sql).await?.collect().await?; + + let sql = " + select * from ( + with + t1 as ( + select 12 a, 12 b + ), + t2 as ( + select 12 a, 12 b + ) + select t1.* from t1 join t2 on t1.a = t2.b where t1.a > t2.b + ) order by 1, 2; + "; + + let results = ctx.sql(sql).await?.collect().await?; + assert_eq!(results.len(), 0); + + Ok(()) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 4c928a3d2d8d6..d4cf6864d7e49 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1323,7 +1323,9 @@ impl SMJStream { // If join filter exists, `self.output_size` is not accurate as we don't know the exact // number of rows in the output record batch. If streamed row joined with buffered rows, // once join filter is applied, the number of output rows may be more than 1. 
- if record_batch.num_rows() > self.output_size { + // If `record_batch` is empty, we should reset `self.output_size` to 0. This can happen + // when the join filter is applied and all rows are filtered out. + if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size { self.output_size = 0; } else { self.output_size -= record_batch.num_rows(); From e4cffd637497b14454c92202ae026c8119eff18b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 14 May 2024 08:43:35 -0700 Subject: [PATCH 2/3] Move test --- datafusion/core/tests/sql/joins.rs | 29 ------------------- .../test_files/sort_merge_join.slt | 17 +++++++++++ 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index a803ab7ceb36b..f7d5205db0d3f 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -231,32 +231,3 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { } Ok(()) } - -#[tokio::test] -async fn test_smj_with_join_filter_fitering_all() -> Result<()> { - let ctx: SessionContext = SessionContext::new(); - - let sql = "set datafusion.optimizer.prefer_hash_join = false;"; - let _ = ctx.sql(sql).await?.collect().await?; - - let sql = "set datafusion.execution.batch_size = 1"; - let _ = ctx.sql(sql).await?.collect().await?; - - let sql = " - select * from ( - with - t1 as ( - select 12 a, 12 b - ), - t2 as ( - select 12 a, 12 b - ) - select t1.* from t1 join t2 on t1.a = t2.b where t1.a > t2.b - ) order by 1, 2; - "; - - let results = ctx.sql(sql).await?.collect().await?; - assert_eq!(results.len(), 0); - - Ok(()) -} diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 09a2aa3e74363..c9b4929911654 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -263,5 +263,22 @@ DROP TABLE t1; 
statement ok DROP TABLE t2; +# Set batch size to 1 for sort merge join +statement ok +set datafusion.execution.batch_size = 1; + +query II +SELECT * FROM ( + WITH + t1 AS ( + SELECT 12 a, 12 b + ), + t2 AS ( + SELECT 12 a, 12 b + ) + SELECT t1.* FROM t1 JOIN t2 on t1.a = t2.b WHERE t1.a > t2.b +) ORDER BY 1, 2; +---- + statement ok set datafusion.optimizer.prefer_hash_join = true; From b40131745e0db8ac706abbe7d2309c450dce9524 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Tue, 14 May 2024 09:09:47 -0700 Subject: [PATCH 3/3] Update datafusion/sqllogictest/test_files/sort_merge_join.slt --- datafusion/sqllogictest/test_files/sort_merge_join.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index c9b4929911654..7b7e355fa2b52 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -263,7 +263,7 @@ DROP TABLE t1; statement ok DROP TABLE t2; -# Set batch size to 1 for sort merge join +# Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches statement ok set datafusion.execution.batch_size = 1;